
DecisionTree


A DecisionTree is a non-parametric supervised learning method, and arguably one of the more popular methods in Machine Learning.

When we humans make a decision, how do we actually think?

Sounds reasonable. In fact, any decision follows some chain of thought A->B->C->…: treat every point where we weigh options and choose as a branching node, and the place where we reach the final conclusion as a leaf, and we get a decision tree. The structure of a decision tree is very simple, yet its expressive power is extremely strong; in theory it can fit any decision process.

Only in theory, of course. In practice decision trees are trained with heuristic greedy algorithms, so they are not quite that impressive.

Decision trees have many advantages. They are very flexible, can fit nonlinear relationships, and every node can be visualized, so they are reasonably interpretable. They make few demands on the data and can be used almost out of the box. They handle both Classification and Regression, both numerical variables and categorical variables, and even multi-output problems…

Sounds great. So, what is the catch?

Generally speaking, decision trees are suitable for:

On their own, decision trees are still not that effective, but the Ensemble Methods derived from them, such as Random Forests, are among the most popular ML algorithms.



Algorithms

Model

The basic idea of a decision tree: each internal node holds a condition, e.g. $x_1 > 0.5$ or $x_2 < 0.3$. These conditions form a tree, and the leaf nodes are the outputs.

A decision tree recursively partitions the feature space.
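As a toy illustration, a small tree is nothing more than nested conditions; the features and thresholds below are just the hypothetical ones from the example above, not part of any real model:

```cpp
// A depth-2 decision tree written out as nested conditions.
// x[0] and x[1] are hypothetical features; the returned int is the leaf output.
int toy_tree(const double x[2]) {
  if (x[0] > 0.5) {            // root: condition x_1 > 0.5
    return x[1] < 0.3 ? 1 : 0; // two leaves under the condition x_2 < 0.3
  }
  return 0;                    // leaf for x_1 <= 0.5
}
```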

Let $Q_m$ denote the data that reaches node $m$, containing $n_m$ samples. Each candidate split $\theta = (j, t_m)$ consists of choosing the $j$-th feature and a threshold $t_m$, i.e. splitting the samples at this node with the condition $x_j \le t_m$.

Once the split is fixed, the node is partitioned into:

$$
\begin{aligned}
Q^{\text{left}}_m(\theta) &= \{(\mathbf{x},y) \mid x_j \le t_m \} \\
Q^{\text{right}}_m(\theta) &= \{(\mathbf{x},y) \mid x_j > t_m \}
\end{aligned}
$$

The quality of a split is evaluated with an impurity function or loss function $H(\cdot)$. The right choice depends on the problem: classification and regression use different criteria. The split quality can be written as:

$$
G(Q_m, \theta) = \frac{n^{\text{left}}_m}{n_m} H\!\left( Q^{\text{left}}_m(\theta) \right) + \frac{n^{\text{right}}_m}{n_m} H\!\left( Q^{\text{right}}_m(\theta) \right)
$$

Choose the parameters that minimize the impurity:

$$
\theta^* = \operatorname*{arg\,min}_\theta G(Q_m, \theta)
$$

Then recurse on the resulting subsets until the maximum depth is reached.
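As a minimal sketch of this greedy step (illustrative names only, not the implementation shown later), the weighted impurity $G$ of a candidate split can be computed directly from the class counts of the two children; training simply evaluates it for every candidate $\theta$ at the node and keeps the $\arg\min$:

```cpp
#include <cstddef>
#include <vector>

// G(Q_m, theta): child impurities weighted by child size.
// `H` is any impurity function mapping a vector of class counts to a double.
template <typename Impurity>
double split_quality(const std::vector<std::size_t> &left_counts,
                     const std::vector<std::size_t> &right_counts, Impurity H) {
  std::size_t n_left = 0, n_right = 0;
  for (std::size_t c : left_counts)  n_left  += c;
  for (std::size_t c : right_counts) n_right += c;
  const double n = static_cast<double>(n_left + n_right);
  return (n_left / n) * H(left_counts) + (n_right / n) * H(right_counts);
}
```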


For a Classification problem, the class labels take values $0, 1, \cdots, K-1$. For node $m$, define:

$$
p_{mk} = \frac{1}{n_m}\sum_{y \in Q_m} I(y = k)
$$

as the proportion of class $k$ among the samples in node $m$. If $m$ is a terminal node, the corresponding class probability is set to $p_{mk}$.
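As a toy example, suppose node $m$ holds $n_m = 10$ samples with class counts $6$, $3$ and $1$. Then

$$
p_{m0} = \frac{6}{10} = 0.6, \qquad p_{m1} = \frac{3}{10} = 0.3, \qquad p_{m2} = \frac{1}{10} = 0.1,
$$

and if $m$ were a terminal node it would predict class $0$ with probability $0.6$.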

Commonly used impurity measures:

Gini:

$$
H(Q_m) = \sum_k p_{mk}(1 - p_{mk})
$$

Log Loss or Entropy:

$$
H(Q_m) = - \sum_k p_{mk} \log (p_{mk})
$$

Using entropy as the criterion is equivalent to minimizing the Log Loss.
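Both measures are easy to compute from the class counts of a node. A sketch (illustrative helpers, not the ones used in the implementation below); for the toy node above with counts $(6, 3, 1)$ they give a Gini impurity of $0.54$ and an entropy of about $0.898$ nats:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Gini impurity: sum_k p_k * (1 - p_k), computed from raw class counts.
double gini(const std::vector<std::size_t> &counts) {
  std::size_t n = 0;
  for (std::size_t c : counts) n += c;
  double h = 0.0;
  for (std::size_t c : counts) {
    const double p = static_cast<double>(c) / n;
    h += p * (1.0 - p);
  }
  return h;
}

// Entropy: -sum_k p_k * log(p_k); empty classes contribute nothing.
double entropy(const std::vector<std::size_t> &counts) {
  std::size_t n = 0;
  for (std::size_t c : counts) n += c;
  double h = 0.0;
  for (std::size_t c : counts) {
    if (c == 0) continue;
    const double p = static_cast<double>(c) / n;
    h -= p * std::log(p);
  }
  return h;
}
```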


For Regression, commonly used criteria include MSE, Mean Poisson deviance, and Mean Absolute Error.
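For example, with MSE as the criterion the impurity of a node is just the variance of the targets around the node mean, and a regression leaf predicts that mean. A minimal sketch (illustrative only):

```cpp
#include <cstddef>
#include <vector>

// MSE impurity: mean squared deviation of the targets from the node mean.
double mse_impurity(const std::vector<double> &targets) {
  double mean = 0.0;
  for (double y : targets) mean += y;
  mean /= static_cast<double>(targets.size());
  double h = 0.0;
  for (double y : targets) h += (y - mean) * (y - mean);
  return h / static_cast<double>(targets.size());
}
```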

Training

In practice, people generally use iterative greedy algorithms. Common implementations include CART, C5.0, and ID3.

Never mind all that, let me just write one myself.

At each node, every feature $x_i$ takes a set of possible values; sorting and deduplicating them gives a vector $\bm{v}_i$.

For a numerical variable, $v_{ij} \in \mathbb{R}$; for a categorical one, $v_{ij}$ is a class. The two cases need to be handled separately.

For a numerical feature, we iterate over the candidate thresholds $t$, take the samples $(\bm{x},y)$ satisfying $x_i \le v_{it}$ as one side, and pick the $t$ that minimizes the loss.

For a categorical feature, we iterate over every possible class $t \in \bm{v}_i$, split the samples into the two groups $x_i = t$ and $x_i \neq t$, and pick the $t$ that minimizes the loss.

Doing this at every node requires $O(|\text{features}| \times n_m)$ candidate evaluations, and the loss can be updated incrementally during the sweep. Since the nodes on each level of the tree together contain at most $n$ samples, the total cost is $O(|\text{features}| \times n \times \text{depth})$.

In the implementation, for each feature we sweep the thresholds one by one and update the class counts of the left subtree as we go. The sorting could be avoided by pre-discretizing the values and using bucket sort to make each sweep linear; if we simply sort at every node, the total complexity degrades to $O(|\text{features}| \times n\log n \times \text{depth})$.

To save effort, I went with plain sorting anyway. The implementation is quite rough.
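For reference, the pre-discretization idea mentioned above might look roughly like this (a sketch only; the implementation below does not do this): map each feature column to rank indices into its sorted unique values once, up front, so that the per-node ordering can later be recovered with a counting sort over ranks instead of a comparison sort.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Replace one numerical feature column by ranks into its sorted unique values.
// With ranks available, per-node ordering can be rebuilt with a counting sort.
std::vector<std::size_t> discretize(const std::vector<double> &column) {
  std::vector<double> uniq(column);
  std::sort(uniq.begin(), uniq.end());
  uniq.erase(std::unique(uniq.begin(), uniq.end()), uniq.end());

  std::vector<std::size_t> ranks(column.size());
  for (std::size_t i = 0; i < column.size(); ++i) {
    ranks[i] = static_cast<std::size_t>(
        std::lower_bound(uniq.begin(), uniq.end(), column[i]) - uniq.begin());
  }
  return ranks;
}
```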

As usual, a few common constraints are also imposed, such as max depth, min samples split, min samples leaf, and min purity decrease.
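The code below reads these limits from an options member whose definition is not shown here; presumably it looks something like the following (field names taken from the code, types and defaults are assumptions):

```cpp
#include <cstddef>

// Assumed layout of the options consulted by _fit below; only the field
// names come from the code, the types and defaults are guesses.
struct Options {
  int max_depth = 32;                    // stop splitting beyond this depth
  std::size_t min_samples_split = 2;     // do not split nodes smaller than this
  std::size_t min_samples_leaf = 1;      // reject splits that create smaller leaves
  double min_purity_decrease = 0.0;      // reject splits with too little relative gain
};
```
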

```cpp
std::unique_ptr<TreeNode> DecisionTree::_fit(size_t n, Sample *samples,
                                             int depth) {
  #ifdef DEBUG
  puts("");
  puts("");
  printf("fitting with %lu samples\n", n);
  #endif
  std::vector<size_t> counts;
  std::vector<size_t> now_counts;
  counts.resize(this->labels);

  // count all labels
  for (size_t idx = 0; idx < n; ++idx) {
    counts[samples[idx].label]++;
  }

  // check if all labels are the same
  size_t most_freq = 0;
  for (size_t idx = 0; idx < this->labels; ++idx) {
    if (counts[idx] == n) {
      return std::make_unique<TreeNode>(leaf_node(idx));
    }
    if (counts[idx] > counts[most_freq]) {
      most_freq = idx;
    }
  }

  if (depth >= this->options.max_depth ||
      n <= this->options.min_samples_split) {
    return std::make_unique<TreeNode>(leaf_node(most_freq));
  }
  int split_type = 0; // 1 for num, 2 for cat
  size_t feat_idx = 0;
  double threshold = 0;
  int cate = 0;

  double minloss = 1e20;
  double init_loss = 0;
  for (size_t k = 0; k < this->labels; ++k) {
    if (counts[k] == 0) // skip absent classes: 0 * log(0) would yield NaN
      continue;
    init_loss -= (counts[k] * 1.0 / n) * log(counts[k] * 1.0 / n);
  }

  if (this->feat_nums > 0) {
    now_counts.resize(this->labels);
    for (size_t idx = 0; idx < this->feat_nums; ++idx) {
      // sort by this feature
      std::sort(samples, samples + n, [&](Sample const &a, Sample const &b) {
        return a.nums[idx] < b.nums[idx];
      });
      // now calculate loss dynamically
      double loss = 0;
      // reset to zero
      std::memset(now_counts.data(), 0, sizeof(size_t) * this->labels);
      for (size_t i = 0; i < n; ++i) {
        // split between [0, i] and (i, n)
        // only ends on unique values
        now_counts[samples[i].label]++;
        while (i + 1 < n && samples[i + 1].nums[idx] == samples[i].nums[idx]) {
          ++i;
          now_counts[samples[i].label]++;
        }
        // all nodes in left tree, break
        if (i >= n - 1)
          break;
        // left tree size i + 1, right tree size n - i - 1
        loss = 0;
        #ifdef DEBUG
        std::print("NUM split at feat {} num {}\n", idx, samples[i].nums[idx]);
        #endif 
        for (size_t k = 0; k < this->labels; ++k) {
          // left
        #ifdef DEBUG
          std::print("class {}, now counts {}\n", k, now_counts[k]);
        #endif 
          if (now_counts[k] > 0) {
            auto pmf_left = 1.0 * now_counts[k] / (i + 1);
            loss -= 1.0*(i+1)/n * pmf_left * std::log(pmf_left);
        #ifdef DEBUG
            std::print("left total {}, class {}, loss contri {}\n", i + 1,
                       now_counts[k], -pmf_left * std::log(pmf_left));
        #endif 
          }
          // right
          if (counts[k] - now_counts[k] > 0) {
            auto pmf_right = 1.0 * (counts[k] - now_counts[k]) / (n - i - 1);
            loss -= 1.0*(n-i-1)/n * pmf_right * std::log(pmf_right);
        #ifdef DEBUG
            std::print("right total {}, class {}, loss contri {}\n", n - i - 1,
                       counts[k] - now_counts[k], -pmf_right * std::log(pmf_right));
        #endif 
          }
        }
        #ifdef DEBUG
        std::print("init loss {} split at feat {} num {}, loss {}\n", init_loss,
                   idx, samples[i].nums[idx], loss);
        puts("");
        #endif 
        if (loss < minloss) {
          minloss = loss;
          split_type = 1;
          threshold = samples[i].nums[idx];
          feat_idx = idx;
        }
      }
    }
  }
  if (this->feat_cats > 0) {
    now_counts.resize(this->labels);
    for (size_t idx = 0; idx < this->feat_cats; ++idx) {
      // sort by this feature
      std::sort(samples, samples + n, [&](Sample const &a, Sample const &b) {
        return a.cats[idx] < b.cats[idx];
      });
      // now calculate loss dynamically
      double loss = 0;
      // reset to zero
      for (size_t i = 0; i < n; ++i) {
        #ifdef DEBUG
        std::print("CAT split at feat {} cate {}\n", idx, samples[i].cats[idx]);
        #endif
        std::memset(now_counts.data(), 0, sizeof(size_t) * this->labels);
        // calculate counts in this category
        now_counts[samples[i].label]++;
        size_t to = i;
        while (to + 1 < n &&
               samples[to + 1].cats[idx] == samples[to].cats[idx]) {
          ++to;
          now_counts[samples[to].label]++;
        }
        auto now_size = to - i + 1;
        auto remain_size = n - now_size;
        i = to;
        loss = 0;
        for (size_t k = 0; k < this->labels; ++k) {
        #ifdef DEBUG
          std::print("nowcount[{}] = {}, counts[{}] = {}\n", k, now_counts[k], k, counts[k]);
          #endif
          if (now_counts[k] > 0) {
            auto pmf_now = 1.0 * now_counts[k] / now_size;
            loss -= 1.0 * now_size / n * pmf_now * std::log(pmf_now);
        #ifdef DEBUG
            std::print("left total {}, class {}, loss contri {}\n", now_size,
                       now_counts[k], -pmf_now * std::log(pmf_now));
            #endif
          }
          if (counts[k] - now_counts[k] > 0) {
            auto pmf_remain = 1.0 * (counts[k] - now_counts[k]) / remain_size;
            loss -= 1.0 * remain_size / n * pmf_remain * std::log(pmf_remain);
        #ifdef DEBUG
            std::print("right total {}, class {}, loss contri {}\n", remain_size,
                       counts[k] - now_counts[k], -pmf_remain * std::log(pmf_remain));
            #endif
          }
        }
        if (loss < minloss) {
          minloss = loss;
          split_type = 2;
          cate = samples[i].cats[idx];
          feat_idx = idx;
        }
      }
    }
  }
  if ((1.0*(init_loss - minloss)/init_loss <= this->options.min_purity_decrease)) {
    #ifdef DEBUG
    std::print("init {}, min {} purged by purity", init_loss, minloss);
    #endif
    return std::make_unique<TreeNode>(leaf_node(most_freq));
  }
  if (split_type == 1) {
    // partition samples by threshold
    size_t p = 0, q = 1;
    while (samples[p].nums[feat_idx] <= threshold)
      ++p;
    q = p + 1;
    while (q < n) {
      if (samples[q].nums[feat_idx] <= threshold) {
        std::swap(samples[p], samples[q]);
        ++p;
      }
      ++q;
    }
    #ifdef DEBUG
    std::print("Best NUM, threshold {} feat {} minloss {}\n", threshold, feat_idx, minloss);
    #endif
    if (p < this->options.min_samples_leaf ||
        n - p < this->options.min_samples_leaf) {
      return std::make_unique<TreeNode>(leaf_node(most_freq));
    }
    // now [0, p) <= threshold, [p, n) > threshold
    auto cur = std::make_unique<TreeNode>(num_node(feat_idx, threshold));
    cur->ls = this->_fit(p, samples, depth + 1);
    cur->rs = this->_fit(n - p, samples + p, depth + 1);
    return cur;
  }
  // partition sample by category
  // move every sample matching the chosen category to the front of the array
  size_t p = 0;
  for (size_t q = 0; q < n; ++q) {
    if (samples[q].cats[feat_idx] == cate) {
      std::swap(samples[p], samples[q]);
      ++p;
    }
  }
  #ifdef DEBUG
  std::print("Best CAT, cateogry {} feat {} minloss {}\n", cate, feat_idx, minloss);
  #endif
  if (p < this->options.min_samples_leaf ||
      n - p < this->options.min_samples_leaf) {
    return std::make_unique<TreeNode>(leaf_node(most_freq));
  }
  auto cur = std::make_unique<TreeNode>(cat_node(feat_idx, cate));
  cur->ls = this->_fit(p, samples, depth + 1);
  cur->rs = this->_fit(n - p, samples + p, depth + 1);
  return cur;
}
```
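
Prediction then just walks the tree from the root down to a leaf. The exact TreeNode layout is not shown above, so the following is only a sketch assuming hypothetical fields kind, label, feat, threshold and cate alongside the ls / rs children used in _fit:

```cpp
// Sketch only: assumes a hypothetical TreeNode with fields
//   kind (0 = leaf, 1 = numerical split, 2 = categorical split),
//   label, feat, threshold, cate, plus the ls / rs children used in _fit.
size_t DecisionTree::_predict(const TreeNode *node, const Sample &s) const {
  while (node->kind != 0) {
    const bool go_left = (node->kind == 1)
                             ? (s.nums[node->feat] <= node->threshold)
                             : (s.cats[node->feat] == node->cate);
    node = go_left ? node->ls.get() : node->rs.get();
  }
  return node->label; // class index stored at the leaf
}
```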

