决策树回归
决策树模型不仅可以用来进行分类,同样也可以用来进行回归。特征划分的标准即是计算按照每个候选点进行划分后数据的最小二乘误差。决策树的输出即为叶子节点上数据的均值。
回归树的特征划分方法
- 对每一个特征中相邻的数据取均值,作为候选切分点。假设特征有a个取值,则有a - 1 个候选切分点(回归树采用与分类树中连续变量切分类似的方法)。
- 然后针对每个切分点,将该特征的数据分成两部分,r1和r2。
- 计算两部分中数据的均值c1和c2。
- 对两部分做最小二乘。损失函数为为(y - 均值)^2,再求和。
- 将两部分最小二乘的结果相加,得到每个候选切分点的损失函数。
- 找出损失函数最小的切分点,作为该特征的切分点。
- 以此类推,进行特征切分。直到结束。
回归树的输出值
每个叶子节点的取值为该叶子节点上所有数据的均值。损失函数通过最小二乘法得到。
实例
为了便于理解,下面举一个简单实例。训练数据见下表,目标是得到一棵最小二乘回归树。
x | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
---|---|---|---|---|---|---|---|---|---|---|
y | 5.56 | 5.7 | 5.91 | 6.4 | 6.8 | 7.05 | 8.9 | 8.7 | 9 | 9.05 |
- 选择最优切分变量j与最优切分点s。
在本数据集中,只有一个变量,因此最优切分变量自然是x。
接下来我们考虑9个切分点[1.5,2.5,3.5,4.5,5.5,6.5,7.5,8.5,9.5],切分点的选择即为每两个相邻x的均值。
损失函数定义为平方损失函数:
\[Loss(y,f(x))=\sum(f(x)−y)^2\],将上述9个切分点依次 代入下面的公式,其中 $c_m$为每个部分所有数据的均值。
例如,取 s=1.5。此时 R1={1},R2={2,3,4,5,6,7,8,9,10},这两个区域的输出值分别为: $c_1$=5.56,$c_2$=(5.7+5.91+6.4+6.8+7.05+8.9+8.7+9+9.05)/9=7.50。得到下表:
s | 1.5 | 2.5 | 3.5 | 4.5 | 5.5 | 6.5 | 7.5 | 8.5 | 9.5 |
---|---|---|---|---|---|---|---|---|---|
$c_1$ | 5.56 | 5.63 | 5.72 | 5.89 | 6.07 | 6.24 | 6.62 | 6.88 | 7.11 |
$c_2$ | 7.5 | 7.73 | 7.99 | 8.25 | 8.54 | 8.91 | 8.92 | 9.03 | 9.05 |
把$c_1$,$c_2$的值代入到上式,如:m(1.5)=0+15.72=15.72。同理,可获得下表:
s | 1.5 | 2.5 | 3.5 | 4.5 | 5.5 | 6.5 | 7.5 | .5 | 9.5 |
---|---|---|---|---|---|---|---|---|---|
m(s) | 15.72 | 12.07 | 8.36 | 5.78 | 3.91 | 1.93 | 8.01 | 11.73 | 15.74 |
显然取 s=6.5时,m(s)最小。因此,第一个划分变量j=x,s=6.5 用选定的(j,s)划分区域,并决定输出值。
两个区域分别是:R1={1,2,3,4,5,6},R2={7,8,9,10}输出值$c_1$=6.24,$c_2$=8.91(平均值)
对R1继续进行划分:
x | 1 | 2 | 3 | 4 | 5 | 6 |
---|---|---|---|---|---|---|
y | 5.56 | 5.7 | 5.91 | 6.4 | 6.8 | 7.05 |
取切分点[1.5,2.5,3.5,4.5,5.5],则各区域的输出值c如下表
s | 1.5 | 2.5 | 3.5 | 4.5 | 5.5 |
---|---|---|---|---|---|
c1 | 5.56 | 5.63 | 5.72 | 5.89 | 6.07 |
c2 | 6.37 | 6.54 | 6.75 | 6.93 | 7.05 |
计算m(s):
s | 1.5 | 2.5 | 3.5 | 4.5 | 5.5 |
---|---|---|---|---|---|
m(s) | 1.3087 | 0.754 | 0.2771 | 0.4368 | 1.0644 |
s=3.5时m(s)最小。
之后的过程不再赘述。
假设在生成3个区域之后停止划分,那么最终生成的回归树形式如下:
\[T=\left\{ \begin{aligned} x & = & 5.72 (x<3.5)\\ y & = & 6.75(3.5<=x<6.5) \\ z & = & 8.91(x>6.5) \end{aligned} \right.\]参考
- https://blog.csdn.net/weixin_40604987/article/details/79296427