DecisionTree,决策树,是一种 non-parametric supervised learning 方法,应该说是 Machine Learning 中相当流行的方法了。
当我们人在做决策的时候,我们是如何思考的?
- 如果 X,则…
很合理。实际上,对于任意的决策,都有一系列的思考路径 A->B->C->…,将每一个思考并选择的事件视作是一种分叉节点,得到最终思考结果的位置作为叶子,我们就得到了一棵决策树。决策树本身结构很简单,但表达力是极强的,理论上可以拟合任意的决策过程。
当然,只是理论上。实际上决策树的训练采用的是启发式贪心算法,所以也没那么厉害。
决策树有许多的优点。它非常灵活,可以拟合非线性关系,可以可视化出每一个节点,还算可以理解。对数据没什么要求,拿起来就能用。能处理 Classification 和 Regression,能处理 numerical variable 和 categorical variable. 可以处理 multi output…
sounds great,那么,代价是什么呢?
- 由于决策树过于灵活,很容易出现 overfit. 一般 DecisionTree 会作为 Base Method,采用 Ensemble 方法去做 RandomForest、GradientBoosting 之类的。
- 决策树的预测输出是不连续的,这意味着它在 Extrapolation 上表现是比较差的。
- 求解最佳决策树是 NPC 问题,一般都采用启发式的训练方法。
- 有一些关系是树模型比较难以学习的,比如 XOR、parity、multiplexer 之类的。
- 如果训练集的 class 不均匀,可能会产生 biased trees.
- 在表达变量之间的 interaction 时有局限性。或者说 Decision Boundaries 必须垂直于坐标轴,表达斜线之类的比较困难,但也可以一定程度上表达出来。
一般而言,决策树适用于:
- 样本量相对较大。
- 对可解释性有一定要求,但又不是那么有要求。
- 对泛化能力要求相对不那么高,或者说样本足够覆盖我们在乎的常见情况。
决策树相对而言还是没那么好用,但其衍生的随机森林之类的 Ensemble Methods 可以说是超级流行的 ML 算法了。
Algorithms
Model
决策树的基本思想是:每个节点是一个 condition,比如说 ,又比如 . 组成的一棵树,最终的叶子节点就是 output.
决策树递归地划分 feature space.
设被归类至节点 m 处的数据为 ,包含 个样本,对于每一个可能的 split ,代表选择第 个 feature,设置阈值为 ,即在这个节点使用条件 划分样本。
如果 split 选定了,就可以划分得到:
可以使用 impurity function 或者 loss function 来评估划分的质量。这里的评估需要根据具体问题来选择, classification 和 regression 是不一样的。可以表述为:
选择最优的参数使 impurity 最小:
递归处理子集直到达到最大深度。
如果是 Classification 问题的话,分类结果可能为 ,对于节点 ,记:
为 class 在节点 中的数量。若 是一个终止节点,则相应的 class probability 设置为 .
常用的 impurity 度量:
Gini:
Log Loss 或 Entropy:
使用信息熵等价于最小化 Log Loss.
如果是 Regression 的话,常用的指标包括 MSE、Mean Poisson deviance 和 Mean Absolute Error.
Training
一般而言,大家使用的是迭代式的贪心算法。常见的实现包括 CART、C5.0、ID3.
不管了,我自己写一个吧。
在每一个节点,对于每一个 feature ,都有一系列可能的取值。排序去重后,得到一个 .
对于 numerical variable,,对于 categorical, 则是一个 class. 这两种需要分开看待。
对于 numerical feature,我们需要遍历可能的阈值 ,筛选 ,并且选择最优的 使得 loss 最小。
对于 categorical feature,我们需要遍历所有可能的 class ,筛选 和 两个分类,然后选择最优的 使得 loss 最小。
每一步都如此操作,就需要遍历 次,过程中 loss 是可以动态更新的。总的复杂度为 .
实现的时候,每个 feature 逐个设置 threshold,更新 left tree 的 class counts. 排序这里可以预处理离散化,然后用桶排序做到线性,如果直接排序的话总复杂度会退化到 .
为了省事,还是这么做了。实现得很粗糙。
一般大家还会设置一些常用的限制,比如 max depth、min samples split、min samples leaf、min purity decrease 等。
std::unique_ptr<TreeNode> DecisionTree::_fit(size_t n, Sample *samples,
int depth) {
#ifdef DEBUG
puts("");
puts("");
printf("fitting with %lu samples\n", n);
#endif
std::vector<size_t> counts;
std::vector<size_t> now_counts;
counts.resize(this->labels);
// count all labels
for (size_t idx = 0; idx < n; ++idx) {
counts[samples[idx].label]++;
}
// check if all labels are the same
size_t most_freq = 0;
for (size_t idx = 0; idx < this->labels; ++idx) {
if (counts[idx] == n) {
return std::make_unique<TreeNode>(leaf_node(idx));
}
if (counts[idx] > counts[most_freq]) {
most_freq = idx;
}
}
if (depth >= this->options.max_depth ||
n <= this->options.min_samples_split) {
return std::make_unique<TreeNode>(leaf_node(most_freq));
}
int split_type = 0; // 1 for num, 2 for cat
size_t feat_idx = 0;
double threshold = 0;
int cate = 0;
double minloss = 1e20;
double init_loss = 0;
for (size_t k = 0; k < this->labels; ++k) {
init_loss -= (counts[k] * 1.0 / n) * log(counts[k] * 1.0 / n);
}
if (this->feat_nums > 0) {
now_counts.resize(this->labels);
for (size_t idx = 0; idx < this->feat_nums; ++idx) {
// sort by this feature
std::sort(samples, samples + n, [&](Sample const &a, Sample const &b) {
return a.nums[idx] < b.nums[idx];
});
// now calculate loss dynamically
double loss = 0;
// reset to zero
std::memset(now_counts.data(), 0, sizeof(size_t) * this->labels);
double last = samples[0].nums[idx];
for (size_t i = 0; i < n; ++i) {
// split between [0, i] and (i, n)
// only ends on unique values
now_counts[samples[i].label]++;
while (i + 1 < n && samples[i + 1].nums[idx] == samples[i].nums[idx]) {
++i;
now_counts[samples[i].label]++;
}
// all nodes in left tree, break
if (i >= n - 1)
break;
// left tree size i + 1, right tree size n - i - -1
loss = 0;
#ifdef DEBUG
std::print("NUM split at feat {} num {}\n", idx, samples[i].nums[idx]);
#endif
for (size_t k = 0; k < this->labels; ++k) {
// left
#ifdef DEBUG
std::print("class {}, now counts {}\n", k, now_counts[k]);
#endif
if (now_counts[k] > 0) {
auto pmf_left = 1.0 * now_counts[k] / (i + 1);
loss -= 1.0*(i+1)/n * pmf_left * std::log(pmf_left);
#ifdef DEBUG
std::print("left total {}, class {}, loss contri {}\n", i + 1,
now_counts[k], -pmf_left * std::log(pmf_left));
#endif
}
// right
if (counts[k] - now_counts[k] > 0) {
auto pmf_right = 1.0 * (counts[k] - now_counts[k]) / (n - i - 1);
loss -= 1.0*(n-i-1)/n * pmf_right * std::log(pmf_right);
#ifdef DEBUG
std::print("right total {}, class {}, loss contri {}\n", n - i - 1,
n - now_counts[k], -pmf_right * std::log(pmf_right));
#endif
}
}
#ifdef DEBUG
std::print("init loss {} split at feat {} num {}, loss {}\n", init_loss,
idx, samples[i].nums[idx], loss);
puts("");
#endif
if (loss < minloss) {
minloss = loss;
split_type = 1;
threshold = samples[i].nums[idx];
feat_idx = idx;
}
}
}
}
if (this->feat_cats > 0) {
now_counts.resize(this->labels);
for (size_t idx = 0; idx < this->feat_cats; ++idx) {
// sort by this feature
std::sort(samples, samples + n, [&](Sample const &a, Sample const &b) {
return a.cats[idx] < b.cats[idx];
});
// now calculate loss dynamically
double loss = 0;
// reset to zero
for (size_t i = 0; i < n; ++i) {
#ifdef DEBUG
std::print("CAT split at feat {} cate {}\n", idx, samples[i].cats[idx]);
#endif
std::memset(now_counts.data(), 0, sizeof(size_t) * this->labels);
// calculate counts in this cateogry
now_counts[samples[i].label]++;
size_t to = i;
while (to + 1 < n &&
samples[to + 1].cats[idx] == samples[to].cats[idx]) {
++to;
now_counts[samples[to].label]++;
}
auto now_size = to - i + 1;
auto remain_size = n - now_size;
i = to;
loss = 0;
for (size_t k = 0; k < this->labels; ++k) {
#ifdef DEBUG
std::print("nowcount[{}] = {}, counts[{}] = {}\n", k, now_counts[k], k, counts[k]);
#endif
if (now_counts[k] > 0) {
auto pmf_now = 1.0 * now_counts[k] / now_size;
loss -= 1.0 * now_size / n * pmf_now * std::log(pmf_now);
#ifdef DEBUG
std::print("left total {}, class {}, loss contri {}\n", now_size,
now_counts[k], -pmf_now * std::log(pmf_now));
#endif
}
if (counts[k] - now_counts[k] > 0) {
auto pmf_remain = 1.0 * (counts[k] - now_counts[k]) / remain_size;
loss -= 1.0 * remain_size / n * pmf_remain * std::log(pmf_remain);
#ifdef DEBUG
std::print("right total {}, class {}, loss contri {}\n", remain_size,
counts[k] - now_counts[k], -pmf_remain * std::log(pmf_remain));
#endif
}
}
if (loss < minloss) {
minloss = loss;
split_type = 2;
cate = samples[i].cats[idx];
feat_idx = idx;
}
}
}
}
if ((1.0*(init_loss - minloss)/init_loss <= this->options.min_purity_decrease)) {
#ifdef DEBUG
std::print("init {}, min {} purged by purity", init_loss, minloss);
#endif
return std::make_unique<TreeNode>(leaf_node(most_freq));
}
if (split_type == 1) {
// partition samples by threshold
size_t p = 0, q = 1;
while (samples[p].nums[feat_idx] <= threshold)
++p;
q = p + 1;
while (q < n) {
if (samples[q].nums[feat_idx] <= threshold) {
std::swap(samples[p], samples[q]);
++p;
}
++q;
}
#ifdef DEBUG
std::print("Best NUM, threshold {} feat {} minloss {}\n", threshold, feat_idx, minloss);
#endif
if (p < this->options.min_samples_leaf ||
n - p < this->options.min_samples_leaf) {
return std::make_unique<TreeNode>(leaf_node(most_freq));
}
// now [0, p) <= threshold, [p, n) > threshold
auto cur = std::make_unique<TreeNode>(num_node(feat_idx, threshold));
cur->ls = this->_fit(p, samples, depth + 1);
cur->rs = this->_fit(n - p, samples + p, depth + 1);
return cur;
}
// partition sample by category
size_t p = 0, q = 1;
while (q < n) {
if (samples[q].cats[feat_idx] == cate) {
std::swap(samples[p], samples[q]);
++p;
}
++q;
}
#ifdef DEBUG
std::print("Best CAT, cateogry {} feat {} minloss {}\n", cate, feat_idx, minloss);
#endif
if (p < this->options.min_samples_leaf ||
n - p < this->options.min_samples_leaf) {
return std::make_unique<TreeNode>(leaf_node(most_freq));
}
auto cur = std::make_unique<TreeNode>(cat_node(feat_idx, cate));
cur->ls = this->_fit(p, samples, depth + 1);
cur->rs = this->_fit(n - p, samples + p, depth + 1);
return cur;
}