决策树

决策树(Decision Tree) 是广泛用于 分类(classification) 和 回归(regression) 任务的模型。本质上,它从一层层的 if/else 问题中进行学习,并得出结论。

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import mglearn
In [2]:
mglearn.plots.plot_animal_tree()
在这张图中,树的每个结点代表一个问题或一个包含答案的终结点(也叫叶结点)。树的 边将问题的答案与将问的下一个问题连接起来。 用机器学习的语言来说就是,了区分四类动物(鹰、企鹅、海豚和熊),我们利用三个 特征(“有没有羽毛”“会不会飞”和“有没有鳍”)来构建一个模型。我们可以利用监督 学习从数据中学习模型,而无需人为构建模型。

1.决策树的概念

根结点(Root Node):如图所示只有子节点,没有父节点的节点,叫做根节点。

image.png

内部节点(Internal Nodes)(节点):如图所示既有父节点,也有子节点的节点。它也被称为节点。

image.png

叶子结点(Leaf Nodes):如图所示仅仅只有父节点的节点叫做叶子结点。

image.png

节点的 深度 可以理解为节点与决策树根结点的距离,如根节点的子节点的深度为1,因为这些节点与跟节点的距离为1,子节点的深度要比父节点的深度大1。决策树的深度是所有叶子节点的最大深度,当深度到达指定的上限大小时,停止分裂。那么决策树是什么呢? 简单来说决策树就是一棵树,其中跟节点是输入特征的判定条件,叶子结点就是最终结果。

2.构造决策树

学习决策树,就是学习一系列 if/else 问题,使我们能够以最快的速度得到正确答案。在机器 学习中,这些问题叫作 测试 (不要与测试集弄混,测试集是用来测试模型泛化性能的数据)。数据通常并不是像动物的例子那样具有二元特征(是 / 否)的形式,而是表示为连续特征, 比如图所示的二维数据集 图中 圆形:类别0 三角形:类别1

image.png

为了构造决策树,算法搜遍所有可能的测试,找出对目标变量来说信息量最大的那一个。 图 2-24 展示了选出的第一个测试。将数据集在 x[1]=0.0596 处垂直划分可以得到最多信 息,它在最大程度上将 类别0 中的点与 类别1 中的点进行区分。

image.png

尽管第一次划分已经对两个类别做了很 好的区分,但底部区域仍包含属于 类别0 的点,顶部区域也仍包含属于 类别1 的点。我们 可以在两个区域中重复寻找最佳测试的过程,从而构建出更准确的模型。下图 展示了信 息量最大的下一次划分,这次划分是基于 x[0] 做出的,分为左右两个区域。 这一递归过程生成一棵二元决策树,其中每个结点都包含一个测试。或者你可以将每个测 试看成沿着一条轴对当前数据进行划分。这是一种将算法看作分层划分的观点。由于每个 测试仅关注一个特征,所以划分后的区域边界始终与坐标轴平行。

image.png

对数据反复进行递归划分,直到划分后的每个区域(决策树的每个叶结点)只包含单一目 标值(单一类别或单一回归值)。如果树中某个叶结点所包含数据点的目标值都相同,那 么这个叶结点就是纯的(pure)。这个数据集的最终划分结果见下图。

image.png

通常来说,构造决策树直到所有叶结点都是纯的叶结点,这会导致模型非常复杂,并且对 训练数据高度过拟合。纯叶结点的存在说明这棵树在训练集上的精度是 100%。训练集中 的每个数据点都位于分类正确的叶结点中。在上图的左图中可以看出过拟合。你可以看 到,在所有属于类别 0 的点中间有一块属于类别 1 的区域。另一方面,有一小条属于类别 0 的区域,包围着最右侧属于类别 0 的那个点。这并不是人们想象中决策边界的样子,这 个决策边界过于关注远离同类别其他点的单个异常点。防止过拟合有两种常见的策略:一种是及早停止树的生长,也叫 预剪枝(pre-pruning); 另一种是先构造树,但随后删除或折叠信息量很少的结点,也叫 后剪枝(post-pruning)或 剪枝(pruning)。预剪枝的限制条件可能包括限制树的最大深度、限制叶结点的最大数目, 或者规定一个结点中数据点的最小数目来防止继续划分。scikit-learn 的决策树在 DecisionTreeRegressor 类和 DecisionTreeClassifier 类中实现。 scikit-learn 只实现了预剪枝,没有实现后剪枝。

3.控制决策树的复杂度--乳腺癌实例

我们在乳腺癌数据集上更详细地看一下预剪枝的效果。和前面一样,我们导入数据集并将 其分为训练集和测试集。然后利用默认设置来构建模型,默认将树完全展开(树不断分 支,直到所有叶结点都是纯的)。我们固定树的 random_state ,用于在内部解决平局问题:
In [3]:
#导入乳腺癌数据集
from sklearn.datasets import load_breast_cancer
cancer = load_breast_cancer()
scikit-learn 中的 train_test_split 函数可以打乱数据集并进行拆分。这个函数将 75% 的 行数据及对应标签作为训练集,剩下 25% 的数据及其标签作为测试集。训练集与测试集的 分配比例可以是随意的,但使用 25% 的数据作为测试集是很好的经验法则。
In [4]:
#分出 训练数据(training data) 和 测试数据(test data)
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
cancer.data, cancer.target, stratify=cancer.target, random_state=42)
#cancer.data 为 train_data,即所要划分的样本特征集
#cancer.target 为 train_target,即所要划分的样本结果,目标变量
#用了stratify参数,training集和testing集的类的比例是 A:B= 4:1
# random_state:是随机数的种子。
# 随机数种子:其实就是该组随机数的编号,在需要重复试验的时候,保证得到一组一样的随机数。
In [5]:
#使用 训练数据 构建模型,并测试 训练精度
from sklearn.tree import DecisionTreeClassifier

tree = DecisionTreeClassifier(random_state=0)
tree.fit(X_train, y_train)
print("Accuracy on training set: {:.3f}".format(tree.score(X_train, y_train)))
print("Accuracy on test set: {:.3f}".format(tree.score(X_test, y_test)))
Accuracy on training set: 1.000
Accuracy on test set: 0.937
不出所料,训练集上的精度是 100%,这是因为叶结点都是纯的,树的深度很大,足以完 美地记住训练数据的所有标签。graphviz如果我们不限制决策树的深度,它的深度和复杂度都可以变得特别大。因此,未剪枝的树 容易过拟合,对新数据的泛化性能不佳。现在我们将预剪枝应用在决策树上,这可以在完 美拟合训练数据之前阻止树的展开。一种选择是在到达一定深度后停止树的展开。这里我 们设置 max_depth=4。限制树的深度可以减少过拟合。这会降低训练集的精度,但可以 提高测试集的精度。
In [6]:
#限制树的深度为4,并重新构建模型,测试 训练精度
tree = DecisionTreeClassifier(max_depth=4, random_state=0)
tree.fit(X_train, y_train)
print("Accuracy on training set: {:.3f}".format(tree.score(X_train, y_train)))
print("Accuracy on test set: {:.3f}".format(tree.score(X_test, y_test)))
Accuracy on training set: 0.988
Accuracy on test set: 0.951

4.分析决策树

我们可以利用 tree 模块 的 export_graphviz 函数来将树可视化。 这个函数会生成一 个 .dot 格式的文件,这是一种用于保存图形的文本文件格式。我们设置为结点添加颜色 的选项,颜色表示每个结点中的多数类别,同时传入类别名称和特征名称,这样可以对 树正确标记
In [7]:
#将树可视化,生成本地文件
from sklearn.tree import export_graphviz
export_graphviz(tree, out_file="tree.dot", class_names=["malignant","benign"],
feature_names=cancer.feature_names, impurity=False, filled=True)
我们可以利用 graphviz 模块读取这个文件并将其可视化(你也可以使用任何能够读取 .dot 文件的程序)
In [8]:
import graphviz
with open("tree.dot") as f:
    dot_graph = f.read()
graphviz.Source(dot_graph)
Out[8]:
Tree 0 worst radius <= 16.795 samples = 426 value = [159, 267] class = benign 1 worst concave points <= 0.136 samples = 284 value = [25, 259] class = benign 0->1 True 14 texture error <= 0.473 samples = 142 value = [134, 8] class = malignant 0->14 False 2 radius error <= 1.048 samples = 252 value = [4, 248] class = benign 1->2 7 worst texture <= 25.62 samples = 32 value = [21, 11] class = malignant 1->7 3 smoothness error <= 0.003 samples = 251 value = [3, 248] class = benign 2->3 6 samples = 1 value = [1, 0] class = malignant 2->6 4 samples = 4 value = [1, 3] class = benign 3->4 5 samples = 247 value = [2, 245] class = benign 3->5 8 worst smoothness <= 0.179 samples = 12 value = [3, 9] class = benign 7->8 11 worst symmetry <= 0.268 samples = 20 value = [18, 2] class = malignant 7->11 9 samples = 10 value = [1, 9] class = benign 8->9 10 samples = 2 value = [2, 0] class = malignant 8->10 12 samples = 3 value = [1, 2] class = benign 11->12 13 samples = 17 value = [17, 0] class = malignant 11->13 15 samples = 5 value = [0, 5] class = benign 14->15 16 worst concavity <= 0.191 samples = 137 value = [134, 3] class = malignant 14->16 17 worst texture <= 30.975 samples = 5 value = [2, 3] class = benign 16->17 20 samples = 132 value = [132, 0] class = malignant 16->20 18 samples = 3 value = [0, 3] class = benign 17->18 19 samples = 2 value = [2, 0] class = malignant 17->19

5.树的特征重要性

查看整个树可能非常费劲,除此之外,我还可以利用一些有用的属性来总结树的工作原 理。其中最常用的是 特征重要性(feature importance),它为每个特征对树的决策的重要性 进行排序。对于每个特征来说,它都是一个介于 0 和 1 之间的数字,其中 0 表示“根本没 用到”,1 表示“完美预测目标值”。特征重要性的求和始终为 1。
In [9]:
print("Feature importances:\n{}".format(tree.feature_importances_))
Feature importances:
[0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.01019737 0.04839825
 0.         0.         0.0024156  0.         0.         0.
 0.         0.         0.72682851 0.0458159  0.         0.
 0.0141577  0.         0.018188   0.1221132  0.01188548 0.        ]
我们可以将特征重要性可视化
In [10]:
def plot_feature_importances_cancer(model):
    n_features = cancer.data.shape[1]
    plt.barh(range(n_features), model.feature_importances_, align='center')
    plt.yticks(np.arange(n_features), cancer.feature_names)
    plt.xlabel("Feature importance")
    plt.ylabel("Feature")
    
plot_feature_importances_cancer(tree)