讲透一个强大算法模型,决策树!!

文摘   2024-09-25 15:26   北京  

哈喽,我是cos大壮~

这些天,很多同学开始接触机器学习、学习机器学习。

我想着,还是把一些基础的内容,能够讲透,讲彻底给到大家。

今天想要和大家分享的是机器学习中,最重要的算法之一:决策树 !~

简单说,决策树是一种帮助我们做决策的工具,它可以把复杂的选择过程简单地拆分成一个个的步骤,就像一棵倒挂的树一样。每个分支是一个问题或判断,每个叶子是一个最终的结果。

老规矩如果大家伙觉得近期文章还不错!欢迎大家点个赞、转个发,文末赠送《机器学习学习小册》

文末可取本文PDF版本~

这里举一个非常浅显的例子,大家就懂了~

想象你在一家冰淇淋店,得做出选择:要不要加巧克力?要不要加草莓酱?冰淇淋是选草莓味的还是香草味的?决策树就像是一个图表,帮你把每一步的选择列出来。比如:

1. 第一个判断:你想吃冰淇淋吗?如果是,那就继续选下一个问题。如果不是,那你就不需要再选择了。

2. 第二个判断:你想吃巧克力味的吗?如果是,往巧克力分支走;如果不是,走另一个分支。

3. 继续细化选择:根据每个问题的回答,继续分支,直到你最后得出“买草莓冰淇淋加草莓酱”这样的结果。

在机器学习中,决策树用于预测或分类。当我们有一堆数据时,它能自动生成这个“选择图表”,帮助我们判断每个数据属于哪个类别,或者预测某个事情的结果。

总结起来,决策树就是把一个复杂的问题分成简单的“是或否”问题,通过不断地做判断,最终得到一个结果。

理论基础

决策树,尤其适用于分类和回归问题。它通过一系列的条件判断,将数据逐步分割成更小的子集,直到每个子集中的数据可以归类到某个类别或者给出预测值。为了详细解释决策树的数学原理,我们将逐步介绍它的核心思想、公式推理和算法流程。

1. 决策树的数学原理

决策树通过“划分”数据来进行分类或预测。这种划分方式通常基于对某个特征的值进行条件判断。每次划分都会产生两个或多个子集,每个子集的数据在某种程度上更加“纯净”或者同质化。

信息熵 (Entropy)

为了度量某个节点上数据的纯度,决策树算法通常会使用信息熵(Entropy)作为衡量标准。信息熵衡量了数据的不确定性或混乱程度。假设我们有一个类别为  的分类问题,节点上每个类别  的概率为 ,则该节点的信息熵定义为:

其中:

  •  是当前节点的样本集合;
  •  是属于类别  的样本比例;
  • 信息熵越高,数据越混乱,信息熵为 0 表示该节点的样本完全属于一个类别。

信息增益 (Information Gain)

为了确定最佳划分点,决策树需要衡量每次划分后的数据纯度的提高程度。这个度量标准就是信息增益,它表示划分前后信息熵的减少量。

假设当前节点  上的信息熵为 ,我们将该节点划分成子集 ,其中每个子集对应一个不同的划分结果。划分后的信息熵是划分前信息熵的加权和:

其中:

  •  是第  个子集中的样本数量;
  •  是划分前节点中的总样本数量。

信息增益定义为划分前后的信息熵差:

其中  是用于划分的特征。算法会选择信息增益最大的特征作为划分标准。

基尼不纯度 (Gini Impurity)

另一种常用的划分标准是基尼不纯度,特别是在分类树(Classification Tree,CART)中。基尼不纯度表示节点上随机选择两个样本,其类别不一致的概率。基尼不纯度的公式为:

其中  依然是属于类别  的样本比例。划分后的基尼不纯度为:

和信息增益类似,基尼不纯度也会选择最小化划分后基尼不纯度的特征。

回归树中的方差缩减

在回归树中,我们用方差来衡量数据的纯净程度。对于当前节点  上的数据集,方差为:

其中  是节点中第  个样本的目标值, 是该节点上所有样本的目标值平均值。划分后,方差的缩减量可作为划分标准,类似于分类中的信息增益。

2. 决策树的构建算法流程

决策树的构建是一个递归过程,主要步骤如下:

初始设置

1. 输入:数据集 ,其中  是特征, 是目标变量(类别或连续值)。

2. 递归停止条件

  • 所有样本都属于同一类别;
  • 没有特征可供划分;
  • 达到预设的树深度或样本数量阈值。

递归划分

1. 选择最优划分特征

  • 对于每个特征,计算不同划分方式下的信息增益、基尼不纯度或者方差缩减。
  • 选择使目标(信息增益最大化或基尼不纯度最小化)的特征  及其对应的划分阈值。

2. 划分数据集

  • 根据特征  和阈值,将数据集划分成两个或多个子集。

3. 递归构建子树

  • 对每个子集,重复上述步骤,直到满足递归停止条件。

剪枝 (Pruning)

为了防止决策树过拟合,通常会对决策树进行剪枝。剪枝的两种常见方法为:

  • 预剪枝:在树生成的过程中设置最大深度或最小叶节点样本数,提前停止树的生长。

  • 后剪枝:生成一棵完全生长的决策树后,通过评估树的性能(如通过交叉验证),剪掉对泛化性能没有帮助的分支。

总结来说,决策树模型的核心在于通过特征划分来递归地缩小数据的复杂性,并通过度量(信息增益、基尼不纯度、方差等)来决定最佳划分点。

目标函数:通过最大化信息增益(分类树)或最小化基尼不纯度(CART分类树)来选择特征划分点。

递归划分:逐步将数据集划分为纯度更高的子集,直到满足停止条件。

剪枝:防止过拟合,提升泛化能力。

完整案例

这里,咱们用 Python 实现一个决策树模型,并且分析其性能,进行可视化,并通过算法优化提升效果。我们使用 sklearn 中的决策树模型,并应用在真实的数据集上。为了让分析更具体,使用经典的泰坦尼克号数据集来预测乘客的生存情况。

1. 导入必要的库

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.preprocessing import LabelEncoder

# 忽略警告
import warnings
warnings.filterwarnings("ignore")

2. 加载并探索数据

使用泰坦尼克号数据集,大家可以点击名片,回复「数据集」进行获取~

# 加载泰坦尼克数据集
df = pd.read_csv('titanic.csv')

# 查看前几行数据
print(df.head())

数据集包括以下一些重要特征:

  • survived: 目标变量,1 表示幸存,0 表示未幸存。
  • pclass: 客舱等级,1 表示头等舱,2 表示二等舱,3 表示三等舱。
  • sex: 性别。
  • age: 年龄。
  • sibsp: 一同上船的兄弟姐妹/配偶数量。
  • parch: 一同上船的父母/子女数量。
  • fare: 票价。
  • embarked: 登船港口。

3. 数据预处理

由于原始数据包含一些缺失值和非数值型数据,我们需要进行清洗和编码。

# 数据预处理:填充缺失值
df['Age'].fillna(df['Age'].mean(), inplace=True)
df['Embarked'].fillna(df['Embarked'].mode()[0], inplace=True)

# 编码性别和登船港口
df['Sex'] = df['Sex'].map({'male'0'female'1})
df['Embarked'] = df['Embarked'].map({'C'0'Q'1'S'2})

# 丢弃不必要的列
df.drop(['Cabin''Name''Ticket'], axis=1, inplace=True)

# 查看前几行数据
print(df.head())

4. 特征与目标变量的划分

# 特征与目标变量
X = df.drop('survived', axis=1)
y = df['survived']

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

5. 训练决策树模型

# 初始化决策树分类器
clf = DecisionTreeClassifier(random_state=42)

# 训练模型
clf.fit(X_train, y_train)

# 预测
y_pred = clf.predict(X_test)

# 输出模型的准确率
print("准确率:", accuracy_score(y_test, y_pred))

6. 模型分析

我们可以通过混淆矩阵、分类报告等指标来分析模型的性能。

# 混淆矩阵
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
plt.title("Confusion Matrix")
plt.show()

# 分类报告
print("分类报告:")
print(classification_report(y_test, y_pred))

7. 可视化决策树

plt.figure(figsize=(20,10))
plot_tree(clf, feature_names=X.columns, class_names=['Not Survived''Survived'], filled=True)
plt.show()

这将生成一个完整的决策树图,可以直观地看到每个节点如何进行分割。

8. 优化模型:网格搜索

为了提升模型性能,我们可以对决策树的超参数进行调优。我们将使用网格搜索来找到最佳参数组合。

# 定义参数网格
param_grid = {
    'max_depth': [3579None],
    'min_samples_split': [2510],
    'min_samples_leaf': [124],
    'criterion': ['gini''entropy']
}

# 网格搜索
grid_search = GridSearchCV(DecisionTreeClassifier(random_state=42), param_grid, cv=5, n_jobs=-1, verbose=1)
grid_search.fit(X_train, y_train)

# 输出最佳参数
print("最佳参数:", grid_search.best_params_)

# 使用最佳参数重新训练模型
best_clf = grid_search.best_estimator_
best_clf.fit(X_train, y_train)

# 预测并评估
y_pred_optimized = best_clf.predict(X_test)

print("优化后准确率:", accuracy_score(y_test, y_pred_optimized))

9. 优化后模型的分析与可视化

# 优化后混淆矩阵
cm_optimized = confusion_matrix(y_test, y_pred_optimized)
sns.heatmap(cm_optimized, annot=True, fmt="d", cmap="Greens")
plt.title("Optimized Confusion Matrix")
plt.show()

# 分类报告
print("优化后分类报告:")
print(classification_report(y_test, y_pred_optimized))

# 优化后决策树可视化
plt.figure(figsize=(20,10))
plot_tree(best_clf, feature_names=X.columns, class_names=['Not Survived''Survived'], filled=True)
plt.show()

10. 总结与分析

  • 准确率:初始模型的准确率为 75%左右,优化后模型的准确率提高到了 80%左右。

  • 模型解释:通过决策树可视化,我们可以清晰地看到每个节点是如何根据特征划分数据的,例如性别、年龄、票价等对生存率的影响。

  • 超参数调优:通过网格搜索,我们发现 max_depthmin_samples_split 等超参数的调优能够显著提升模型的泛化能力,避免过拟合或欠拟合。

  • 可视化:混淆矩阵和决策树图让我们更直观地了解模型在不同类别上的表现,以及各特征的分割效果。

模型分析

决策树模型的优缺点

优点

1. 简单易理解:决策树的结构直观且易于解释,特别是对非技术人员。每个节点的分裂基于简单的“是/否”问题,使得模型的每一步决策都很清楚。

2. 无需大量数据预处理:决策树对数据的要求较低,比如不需要像线性回归那样对数据进行归一化处理,也不需要进行复杂的特征编码或归一化。

3. 能够处理非线性数据:决策树能够自动捕捉数据中的非线性关系,而不需要对特征进行复杂的变换。特征之间的非线性相互作用可以通过决策树的层次结构自然处理。

4. 适合多种数据类型:决策树可以处理数值型数据和类别型数据,并且不需要像其他算法那样进行严格的数据转换。

5. 处理缺失数据:决策树对缺失数据不太敏感,可以在缺失数据情况下依然进行有效的分类或回归。

6. 不需要假设分布:决策树不需要假设数据服从特定的分布,例如正态分布,这使得它在应对不同类型的任务时更加灵活。

缺点

1. 容易过拟合:决策树如果不进行剪枝或者限制深度,容易生成一棵过于复杂的树,导致在训练数据上表现很好,但在测试数据上表现不佳。过拟合的树通常对数据的噪音非常敏感。

2. 对数据的小变化敏感:决策树对数据集中的小变化非常敏感。稍微改变一下数据,可能会导致树的结构完全不同,因此决策树的稳定性较差。

3. 偏向于支配性特征:决策树容易偏向于具有更多取值的特征,可能会过度依赖于这些特征的划分。对于具有较多类别的分类变量,决策树倾向于优先使用这些特征进行分裂。

4. 不能很好地处理连续变量:对于回归问题,决策树在连续变量上的表现较差,因为它只能对变量进行简单的区间划分,可能会导致预测值不够精确。

5. 决策边界不灵活:决策树通过垂直或水平的方式进行分割,这意味着它无法有效地处理需要更灵活的分界面的问题。这在某些复杂情况下导致表现较差。


决策树与其他算法的对比

1. 决策树 vs. 随机森林 (Random Forest)

  • 随机森林是基于决策树的集成学习方法。它通过构建多个决策树并对它们的预测结果取平均或投票,来提高模型的泛化能力,降低过拟合的风险。

  • 对比

    • 随机森林通过“集成多棵树”来减少单棵树的缺点(如过拟合),而单棵决策树可能对训练数据过于敏感。
    • 决策树简单易解释,而随机森林由于集成了多棵树,解释性较差。
  • 何时使用随机森林:当希望提升模型稳定性和预测能力,且不需要过多解释性时,随机森林是更好的选择。

2. 决策树 vs. 梯度提升树 (Gradient Boosting Trees, GBT)

  • 梯度提升树也是一种基于决策树的集成方法,它通过逐步添加决策树来修正前面树的错误。它通常比单一的决策树和随机森林性能更好,但训练速度较慢。

  • 对比

    • 决策树容易过拟合,梯度提升树则通过学习多棵树来逐渐提高模型的精度,表现更好。
    • 决策树训练速度快,而梯度提升树需要迭代多次,因此计算开销更大。
  • 何时使用梯度提升树:当模型精度要求高,并且可以忍受较长的训练时间时,梯度提升树是优选算法。

3. 决策树 vs. 支持向量机 (SVM)

  • 支持向量机通过构建超平面来将数据分类。它对于高维数据和非线性数据表现良好,尤其是使用核技巧时。

  • 对比

    • 决策树更易于解释,而支持向量机的解释性较差,因为它依赖于抽象的超平面和距离度量。
    • 决策树适合处理混合类型的数据,而 SVM 更适合处理数值型数据,尤其是在特征空间较为复杂的情况下。
  • 何时使用 SVM:当数据维度较高,且希望得到较好的边界分类时,SVM 是更好的选择。

4. 决策树 vs. K近邻算法 (KNN)

  • K近邻算法是一种基于实例的学习方法,直接使用训练样本进行分类或回归。它根据最近的  个样本来决定预测结果。

  • 对比

    • 决策树学习全局模式,而 KNN 仅根据局部邻居进行预测。决策树更适合复杂的特征划分,而 KNN 在简单、平滑的决策边界问题上表现良好。
    • 决策树适合大规模数据集,KNN 对大数据集计算开销较大。
  • 何时使用 KNN:当数据集规模较小、特征维度较少、决策边界平滑时,KNN 是合适的选择。


决策树的适用场景和替代方案

适用场景:

1. 可解释性需求强:在一些需要可解释性的场合(如医疗诊断、金融风控),决策树模型因其直观的规则集结构,便于理解和解释。

2. 小数据集:决策树在处理中小型数据集时表现较好,训练速度快且无需大量预处理。

3. 数据类型多样:决策树能够处理混合数据类型(数值型和类别型数据),且对缺失值不敏感,因此适合处理数据类型复杂的问题。

替代方案的适用场景:

1. 需要更高的预测精度:如果预测准确率是主要目标,且对模型的可解释性要求不高,可以考虑随机森林、梯度提升树等集成学习方法。

2. 数据量大且噪音较多:在大数据集且噪音较多的情况下,决策树容易过拟合。集成方法(如随机森林、梯度提升树)可以通过多树结构降低噪声对模型的影响。

3. 非线性复杂边界:对于复杂的非线性分类问题,支持向量机或神经网络往往能提供比决策树更优的性能。

最后

大家有问题可以直接在评论区留言即可~

喜欢本文的朋友可收藏、点赞、转发起来!

需要本文PDF的同学,扫码备注「基础算法」即可~ 

关注本号,带来更多算法干货实例,提升工作学习效率!
最后,给大家准备了《机器学习学习小册》PDF版本16大块的内容,124个问题总结
100个超强算法模型,大家如果觉得有用,可以点击查看~

推荐阅读

原创、超强、精华合集
100个超强机器学习算法模型汇总
机器学习全路线
机器学习各个算法的优缺点
7大方面,30个最强数据集
6大部分,20 个机器学习算法全面汇总
铁汁,都到这了,别忘记点赞呀~

深夜努力写Python
Python、机器学习算法
 最新文章