xgboost + shap可加性解释(R版本):优秀的机器学习解决方案

文摘   科学   2024-07-15 08:42   湖南  

大家好,我是北游。如果你在做一份预测模型相关研究,大概率不会对xgboost + shap感到陌生。如果还不太了解的话,不妨看看我们去年发布的一些推文。链接如下:

        临床预测模型建模神器:XGBoost!

        临床预测模型进阶:XGBoost的自动化调参调优方法

        python机器学习:超越随机森林(XGBoost篇)

        预测模型| 黑箱模型解释困难?一种流行的模型解释方法

总之就是比较零散,包括现有的各类机器学习资料或网络资源,我认为也有一定的问题,因为很多资料只是分享了建模、调参或验证等步骤,相对较为零散,但我们真正要做研究,是需要从头到尾进行数据处理的。各种辛酸,只有亲身经历过才懂。

本文分享了XGBoost的建模、评价、解释相关代码,虽然也不是非常完整或正确,但起码连起来了。姑且读之。




一、数据预处理

笔者从竞赛平台kaggle获取了一份关于泰坦尼克号乘客生存情况的数据,样本量不大,几百行,比较适合用于演示各类机器学习算法,有需要的朋友可以回复关键词自动获取。但是大多数机器学习算法对数据格式有些要求,所以我们可以先探索下这份数据并进行一些预处理,使之可用性变得更好。
这些预处理包括选取关注的变量、因子化、分层随机拆分、填补缺失值,具体步骤如下(只是简单处理,未必专业,仅供参考)。
library(tidyverse)library(tidymodels)# 导入数据,后台回复“泰坦尼克”即可获取data = read_csv('train.csv')

#### 数据清洗 ##### 选取几个变量,并对分类变量进行因子化varname = colnames(data)data = data %>% select(Survived,Pclass,Sex,Age,SibSp,Embarked) %>% mutate( across(c('Survived','Pclass','Sex','Embarked'),factor) )str(data)
# 查看类别比例table(data$Survived)

# 分割数据集(分层随机)set.seed(123)data_split = initial_split(data,strata = Survived)data_train = training(data_split)data_test = testing(data_split)
table(data_train$Survived)table(data_test$Survived)

#查看缺失值is.na(data_train) %>% colSums()is.na(data_test) %>% colSums()

# 简单插补mean_value = data_train %>% summarise(mean_value = mean(Age, na.rm = T)) %>% pull(mean_value)
# 使用mutate和if_else函数进行均值插补newtrain = data_train %>% mutate(Age = if_else(is.na(Age), mean_value, Age))
is.na(newtrain) %>% colSums()# 对分类变量进行随机插补newtrain = newtrain %>% mutate(Embarked = if_else(is.na(Embarked), sample(na.omit(Embarked), size = 1, replace = F), Embarked))is.na(newtrain) %>% colSums()
# 对测试集插补
m_value = data_test %>% summarise(m_value = mean(Age, na.rm = T)) %>% pull(m_value)newtest = data_test %>% mutate(Age = if_else(is.na(Age), m_value, Age))is.na(newtest) %>% colSums()
# 保存备用,需要提前创建datasets文件夹save(newtrain, file = "datasets/train.rdata")save(newtest, file = "datasets/test.rdata")

二、xgboost建模

xgboost是GBDT (Gradient Boosting Decision Tree) 的一种高效实现,曾经一度是机器学习(集成学习)领域最强方法,强于随机森林等一众经典模型。放在现在来看,如果不与深度学习众多方法比较的话,性能也是很强悍的,尤其适合我们这种样本量通常不大的临床预测模型。当然,现在也有了一些可以与之媲美甚至超越的模型,主流的有lightgbm、catboost等,这些会在后面介绍。

在介绍如何用shap对这类黑箱模型进行解释之前我们回顾下xgboost建模方法。

#### 建模 ####library(xgboost)library(Matrix)
# 设计稀疏矩阵,分类变量设置哑变量,标签转数值head(newtrain)str(newtrain)sparse_matrix = sparse.model.matrix(Survived ~ ., data = newtrain)[,-1]head(sparse_matrix)output_vector = as.numeric(newtrain$Survived)-1
# 建模(不调参)xgb = xgboost(data = sparse_matrix, label = output_vector, max_depth = 4, eta = 1, nthread = 2, nrounds = 10,objective = "binary:logistic")


上面的代码没有调参,如果想要调整超参数,可以用caret、mlr3、tidymodels等框架来做。

查看下哪些自变量比较重要:

# 特征重要性排序imp = xgb.importance(colnames(sparse_matrix),xgb)head(imp)xgb.plot.importance(imp)

xgboost包自带的特征重要性排序函数比较简陋,也可以提取重要性数值自己用ggplot2画,会美观很多。

下面做个简单的模型评价(混淆矩阵及其他)。

#### 模型评价 ##### 转换测试集test_sparse_matrix = sparse.model.matrix(Survived ~ ., data = newtest)[,-1]head(test_sparse_matrix)test_output_vector = as.numeric(newtest$Survived)-1
# 获取预测值pred = predict(xgb, test_sparse_matrix)
# 输出混淆矩阵library(caret)confusionMatrix(as.factor(ifelse(pred>0.5,1,0)), newtest$Survived)

基本的指标有了。如果想要获得更多的指标,包括精确度、召回率、F1分数等,可以使用tidymodels生态的yardstick包实现。至于临床预测模型中很常用的一些可视化技术,比如ROC曲线、校准曲线、DCA、HL检验、brier评分等,方法与常规预测模型的做法类似,大家可以自行摸索,如确实需要私人辅导,也可以联系我们。

、基于SHAP的模型解释

这里我们演示下SHAPforxgboost包的用法,第一步依然是按照并导入改包(没演示了)。由于前面的建模是通过matrix函数作的分类转数值(one-hot编码),笔者尝试了,不适用于SHAPforxgboost包,干脆重新建模(实际别这么干)。

# xgb+shap解释  重新建模library(SHAPforxgboost)X = data.matrix(newtrain[, -1])dtrain = xgb.DMatrix(X, label = newtrain$Survived)fit = xgb.train(  params = list(    objective = "reg:squarederror",    learning_rate = 0.1  ),   data = dtrain,  nrounds = 50)

计算shap值:

shap = shap.prep(fit,X_train=X)

可视化:

shap.plot.summary(shap)

解释:初次接触的同学可能不懂如何看此图,其实很简单的,蓝色表示高(对预测有正面影响,如果预测的是疾病或并发症,可以认为是危险因素),黄色反之。有些工具输出的图是用红色表示高的,没关系,一样解读。每个散点其实就是每个样本(患者),所以主要看这些散点集中的位置,距离0越远,说明这个特征(自变量)的影响越大。

再画个条形图:

shap.plot.summary(shap, kind = "bar")

其他图也可以自己摸索下,这个包还是比较好用的,推荐。

另外,我们也可能会在新开的同名视频号中,以视频的形式分享更多的预测模型知识,大家不妨关注下。这期内容不少,暂时就写到这里了,我们下期再见!

觉得此文有帮助的话,可以帮忙点个赞吗?谢谢


主要参考来源:

相关R包(xgboost、SHAPforxgboost)的帮助文档;

免责声明:仅供科研学习和分享,不做商业用途,如有侵权,请联系我们删除,谢谢。

补充:护理统计随笔平台的内容现在已经非常丰富,很多方面都有涉及到,如果你觉得你没有看到往期的相关文章,不妨打开公众号的菜单页,在各级目录中查找你想要的内容。如果你在科研学习中遇到了疑问,恰好也想跟网友们交流,可以加入我们建立的“护理科研交流群”。这是一个完全自由、开放、没有套路的纯交流群。入群方式:私信发送关键词“加群”。


护理统计随笔
专注护理科研设计和统计分析。别人不会告诉你的干货,可以来这里找!
 最新文章