Basic Principles of Decision Tree Classification

For classification tasks, using a tree structure to make decisions is a very natural idea.

For example, to decide whether an object is an apple, a pear, or a banana, we would check its shape, then its color, and so on.
Each check narrows the object down to a subset of the original set of categories; if we draw this process out, it forms a tree, which is exactly how decision tree classification works.

(figure: decision_tree_sample)

Formally, a decision tree consists of a root node, a number of internal nodes, and a number of leaf nodes.
Each internal node represents a splitting decision, each leaf node represents a decision (classification) outcome, and the path from the root to a leaf corresponds to one complete classification process.

Assuming all sample attributes are discrete, the figure below shows the pseudocode for building a decision tree.
How to choose the splitting criterion so that the tree uses less time and space and generalizes better is precisely the question that different decision tree algorithms address.

(figure: decision_tree_algorithm)

ID3

The ID3 algorithm uses information gain as its splitting criterion.

Let $D_k$ be the set of samples in the current sample set $D$ that belong to class $k$ (with $K$ classes in total). The information entropy of $D$, i.e., its uncertainty, is

$$\mathrm{Ent}(D)=-\sum_{k=1}^{K}\frac{|D_k|}{|D|}\log_2\frac{|D_k|}{|D|}$$

If some attribute $A$ with $V$ possible values is used as the splitting criterion, the split produces $V$ branches.

Let $D^{(i)}$ be the subset of samples falling into the $i$-th branch; the conditional entropy is then

$$\mathrm{Ent}(D\mid A)=\sum_{i=1}^{V}\frac{|D^{(i)}|}{|D|}\mathrm{Ent}(D^{(i)})$$

The information gain measures how much the uncertainty of the sample set $D$ decreases once the value of attribute $A$ is known, i.e.

$$\mathrm{Gain}(D,A)=\mathrm{Ent}(D)-\mathrm{Ent}(D\mid A)$$

Clearly, the more the uncertainty decreases, the better the split, so the split with the largest information gain is the optimal one.
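
As a small sanity check of these formulas, the following numpy sketch computes the entropy of a toy label array and the information gain of one hypothetical discrete attribute (the data and the helper names entropy / info_gain are made up for illustration):

import numpy as np

# toy labels: 5 positive and 5 negative samples
y = np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0])
# a hypothetical discrete attribute with values 'a' and 'b'
a = np.array(['a', 'a', 'a', 'b', 'b', 'a', 'b', 'b', 'b', 'b'])

def entropy(labels):
    # Ent(D) = -sum_k p_k * log2(p_k)
    _, counts = np.unique(labels, return_counts=True)
    p = counts / labels.size
    return -np.sum(p * np.log2(p))

def info_gain(labels, attr):
    # Gain(D, A) = Ent(D) - sum_i |D^(i)|/|D| * Ent(D^(i))
    cond = sum(np.mean(attr == v) * entropy(labels[attr == v])
               for v in np.unique(attr))
    return entropy(labels) - cond

print(entropy(y))       # 1.0: two equally likely classes, maximal uncertainty
print(info_gain(y, a))  # about 0.125: splitting on the attribute reduces uncertainty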

C4.5

In practice, using information gain as the splitting criterion biases the tree toward attributes with a large number of distinct values.

To reduce this bias, the C4.5 algorithm uses the gain ratio as its splitting criterion

$$\mathrm{Gain\_ratio}(D,A)=\frac{\mathrm{Gain}(D,A)}{\mathrm{IV}(A)}$$

where the intrinsic value of attribute $A$ is

$$\mathrm{IV}(A)=-\sum_{i=1}^{V}\frac{|D^{(i)}|}{|D|}\log_2\frac{|D^{(i)}|}{|D|}$$

Note, however, that always choosing the attribute with the largest gain ratio may instead bias the tree toward attributes with few distinct values, and in particular toward splits where one subtree is much smaller than the others.

C4.5 therefore uses a heuristic to avoid both problems: it first selects the attributes whose information gain is above average, and then picks, among those, the one with the largest gain ratio.
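
A minimal sketch of this two-step heuristic, reusing the entropy and info_gain helpers from the sketch above (the attribute container and function names here are illustrative, not C4.5's actual implementation):

def intrinsic_value(attr):
    # IV(A) = -sum_i |D^(i)|/|D| * log2(|D^(i)|/|D|)
    _, counts = np.unique(attr, return_counts=True)
    p = counts / attr.size
    return -np.sum(p * np.log2(p))

def c45_choose(labels, attrs):
    # attrs: dict mapping attribute name -> 1-D array of that attribute's values
    gains = {name: info_gain(labels, a) for name, a in attrs.items()}
    # step 1: keep only attributes whose information gain is above average
    avg = np.mean(list(gains.values()))
    candidates = [name for name, g in gains.items() if g >= avg]
    # step 2: among them, pick the attribute with the largest gain ratio
    return max(candidates,
               key=lambda name: gains[name] / intrinsic_value(attrs[name]))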

CART

The CART algorithm uses the Gini index to measure the purity of a sample set $D$

$$\mathrm{Gini}(D)=1-\sum_{k=1}^{K}\left(\frac{|D_k|}{|D|}\right)^2$$

Intuitively, the Gini index is the probability that two samples drawn at random from $D$ belong to different classes.
The smaller the Gini index, the purer $D$ is, so when splitting we simply choose the attribute that reduces the Gini index the most.
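
A short numpy sketch of the Gini index in the same style (the toy arrays are made up for illustration):

import numpy as np

def gini(labels):
    # Gini(D) = 1 - sum_k p_k^2
    _, counts = np.unique(labels, return_counts=True)
    p = counts / labels.size
    return 1.0 - np.sum(p ** 2)

print(gini(np.array([1, 1, 0, 0])))  # 0.5: maximally impure for two classes
print(gini(np.array([1, 1, 1, 1])))  # 0.0: pure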

Binary Splitting

Clearly, the ID3 and C4.5 algorithms described above handle only discrete attributes; CART introduces binary splitting to handle continuous attributes as well.

For a continuous attribute $A$ with $n$ distinct values, written in increasing order as $a_1, a_2, \dots, a_n$, the candidate split points of $A$ form the set $\left\{\frac{a_i+a_{i+1}}{2} \mid 1\leq i\leq n-1\right\}$.
For each split point $a$, the sample set $D$ can be split into $D^{-}$ and $D^{+}$, containing the samples whose value on $A$ is at most $a$ and greater than $a$, respectively.

With this binary splitting strategy, discrete attributes can also be treated as if they were continuous.

Moreover, under the binary splitting strategy the same attribute may serve as the splitting criterion more than once.
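
A minimal sketch of enumerating the candidate split points of one continuous feature column (the array values are made up; the impurity evaluation is left as a placeholder):

import numpy as np

x_col = np.array([2.5, 1.0, 3.5, 1.0, 4.0])  # one continuous attribute
values = np.unique(x_col)                    # sorted distinct values a_1 < ... < a_n
candidates = (values[:-1] + values[1:]) / 2  # midpoints (a_i + a_{i+1}) / 2

for a in candidates:
    d_minus = x_col <= a  # D^-: samples with value <= a
    d_plus = x_col > a    # D^+: samples with value > a
    # evaluate the chosen impurity criterion on the two subsets here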

Pruning

Pruning is the main way a decision tree copes with overfitting; it comes in two forms, pre-pruning and post-pruning.

Pre-pruning restricts the size of the tree while it is being built, for example by:

  • limiting the maximum depth of the tree
  • requiring a minimum number of samples per leaf node and a minimum number of samples for a split
  • stopping further splitting when no candidate split at the current node brings a significant improvement (e.g., all information gains are tiny)

Post-pruning first grows the complete tree and then replaces overfitted subtrees with leaf nodes.
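
For reference, these pre-pruning conditions (and cost-complexity post-pruning) correspond directly to constructor parameters of scikit-learn's DecisionTreeClassifier; the concrete values below are arbitrary examples, not recommendations:

from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier(
    max_depth=5,                 # limit the maximum depth of the tree
    min_samples_split=10,        # minimum samples required to split a node
    min_samples_leaf=5,          # minimum samples required in a leaf
    min_impurity_decrease=1e-3,  # skip splits that barely reduce impurity
    ccp_alpha=0.01,              # cost-complexity (post-)pruning strength
)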

Multivariate Decision Trees

It is easy to see that every split made by the decision tree algorithms above is an axis-parallel hyperplane in the multi-dimensional attribute space.

Real classification boundaries, however, are often far more complex, and a decision tree then needs many such splits to approximate them reasonably well.

A natural idea is therefore to use "oblique" splitting boundaries, i.e., to let each split consider a linear combination of several attributes.

Decision trees of this kind include the OC1 algorithm and the multivariate decision trees proposed by Brodley, among others; they are not covered further here.

Code Implementation

Finally, here is the code for a simple decision tree implemented with numpy.

import numpy as np

class BaseDecisionTree:
    def entropy(self, y):
        # Ent(D) = -sum_k p_k * log2(p_k)
        _, cnt = np.unique(y, return_counts=True)
        probs = cnt / y.shape[0]
        return -np.sum(probs * np.log2(probs))

    def gini(self, y):
        # Gini(D) = 1 - sum_k p_k^2
        _, cnt = np.unique(y, return_counts=True)
        probs = cnt / y.shape[0]
        return 1 - np.sum(np.square(probs))

    def calcInfoGainRatio(self, y, y_sub):
        n_samples = y.shape[0]

        entropy = self.entropy(y)
        new_entropy = np.sum([ny.size / n_samples * self.entropy(ny) for ny in y_sub])
        gain = entropy - new_entropy

        # intrinsic value IV(A); guard against division by zero when the
        # split produces only a single non-empty subset
        iv = -np.sum([ny.size / n_samples * np.log2(ny.size / n_samples) for ny in y_sub])
        if iv == 0:
            return 0.0
        gain_ratio = gain / iv

        return gain_ratio

    def calcInfoGain(self, y, y_sub):
        n_samples = y.shape[0]

        entropy = self.entropy(y)
        new_entropy = np.sum([ny.size / n_samples * self.entropy(ny) for ny in y_sub])

        return entropy - new_entropy

    def calcGiniGain(self, y, y_sub):
        n_samples = y.shape[0]

        gini = self.gini(y)
        new_gini = np.sum([ny.size / n_samples * self.gini(ny) for ny in y_sub])

        return gini - new_gini

    def storeTree(self):
        import pickle
        with open('DecisionTree.txt', 'wb') as fw:
            pickle.dump(self.decisionTree, fw)

    def loadTree(self):
        import pickle
        with open('DecisionTree.txt', 'rb') as fr:
            self.decisionTree = pickle.load(fr)

class BinaryDecisionTree(BaseDecisionTree):
    """
    Binary (CART-style) tree; all features are treated as continuous.
    Labels y are expected as an (n_samples, 1) array of non-negative integers.

    :param loss: {"info_gain", "info_gain_ratio", "gini"}, default="info_gain"
    """
    def __init__(self, loss='info_gain'):
        self.loss = loss
        self.decisionTree = {}

    def fit(self, x, y):
        self.decisionTree = self.createTree(x, y)

    def createTree(self, x, y):
        # all labels are the same: return a leaf
        if np.unique(y).size == 1:
            return y[0][0]

        index, threshold = self.split(x, y)

        # no split improves the criterion: return the most frequent label
        if index == -1:
            return np.argmax(np.bincount(y.flatten()))

        decisionTree = {'split_index': index, 'threshold': threshold}

        pos = x[:, index] < threshold
        decisionTree['left'] = self.createTree(x[pos], y[pos])     # samples with x[:, index] < threshold
        decisionTree['right'] = self.createTree(x[~pos], y[~pos])  # samples with x[:, index] >= threshold

        return decisionTree

    def split(self, x, y):
        n_samples, n_feat = x.shape[:]
        opt_loss, split_index, threshold = 0, -1, None

        # iterate over all features to find the optimal split
        for i in range(n_feat):
            # candidate thresholds: midpoints between adjacent distinct values
            unique = np.unique(np.sort(x[:, i]))
            split_val = (unique[:-1] + unique[1:]) / 2
            for val in split_val:
                pos = x[:, i] < val
                y_sub = [y[pos], y[~pos]]

                if self.loss == 'info_gain':
                    new_loss = self.calcInfoGain(y, y_sub)
                elif self.loss == 'info_gain_ratio':
                    new_loss = self.calcInfoGainRatio(y, y_sub)
                else:
                    new_loss = self.calcGiniGain(y, y_sub)

                if new_loss > opt_loss:
                    opt_loss, split_index, threshold = new_loss, i, val

        return split_index, threshold

    def classify(self, inputs, node):
        index = node['split_index']
        nxt = node['left'] if inputs[index] < node['threshold'] else node['right']

        # internal nodes are dicts, leaves are labels
        if isinstance(nxt, dict):
            return self.classify(inputs, nxt)
        else:
            return nxt

    def predict(self, inputs):
        res = [self.classify(sample, self.decisionTree) for sample in inputs]
        return np.array(res)

class MultiwayDecisionTree(BaseDecisionTree):
    """
    Multiway (ID3/C4.5-style) tree; all features are treated as discrete.
    Labels y are expected as an (n_samples, 1) array of non-negative integers.

    :param loss: {"info_gain", "info_gain_ratio", "gini"}, default="info_gain"
    """
    def __init__(self, loss='info_gain'):
        self.loss = loss
        self.decisionTree = {}

    def fit(self, x, y):
        used = np.full((x.shape[1], ), False)
        self.decisionTree = self.createTree(x, y, used)

    def createTree(self, x, y, used):
        # all labels are the same: return a leaf
        if np.unique(y).size == 1:
            return y[0][0]

        # most frequent label, used as fallback prediction
        most_freq = np.argmax(np.bincount(y.flatten()))

        # no usable feature left
        if np.all(used):
            return most_freq

        index = self.split(x, y, used)

        # no split improves the criterion (e.g. all remaining samples are identical)
        if index == -1:
            return most_freq

        # copy before marking, so other branches of the tree are not affected
        used = used.copy()
        used[index] = True
        decisionTree = {'split_index': index, 'others': most_freq}

        unique = np.unique(x[:, index])
        for val in unique:
            pos = np.where(x[:, index] == val)
            decisionTree[val] = self.createTree(x[pos], y[pos], used)

        return decisionTree

    def split(self, x, y, used):
        n_samples, n_feat = x.shape[:]
        opt_loss, split_index = 0., -1

        # iterate over all unused features to find the optimal split
        for i in range(n_feat):
            if used[i]:
                continue

            unique = np.unique(x[:, i])
            y_sub = [y[np.where(x[:, i] == val)] for val in unique]

            if self.loss == 'info_gain':
                new_loss = self.calcInfoGain(y, y_sub)
            elif self.loss == 'info_gain_ratio':
                new_loss = self.calcInfoGainRatio(y, y_sub)
            else:
                new_loss = self.calcGiniGain(y, y_sub)

            if new_loss > opt_loss:
                opt_loss, split_index = new_loss, i

        return split_index

    def classify(self, inputs, node):
        index = node['split_index']

        # unseen attribute value: fall back to the most frequent label
        if inputs[index] not in node:
            return node['others']

        nxt = node[inputs[index]]

        # internal nodes are dicts, leaves are labels
        if isinstance(nxt, dict):
            return self.classify(inputs, nxt)
        else:
            return nxt

    def predict(self, inputs):
        res = [self.classify(sample, self.decisionTree) for sample in inputs]
        return np.array(res)
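
A quick usage sketch for the classes above; the toy data is made up for illustration, and labels are passed as an (n, 1) integer column as the implementation expects:

if __name__ == '__main__':
    # two continuous features, binary labels
    x = np.array([[2.7, 1.0],
                  [1.3, 3.4],
                  [3.6, 0.2],
                  [0.8, 2.9],
                  [3.1, 1.5]])
    y = np.array([[1], [0], [1], [0], [1]])

    tree = BinaryDecisionTree(loss='gini')
    tree.fit(x, y)
    print(tree.predict(x))  # should reproduce the training labels on this tiny set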