比随机森林强! 利用catboost做临床预测模型

文摘   科学   2024-07-21 20:35   湖南  

大家好。新的机器学习方法层出不穷,那么,这些方法对于我们护理或医学研究领域有什么启发呢?对于这个问题,笔者翻阅了不少机器学习相关的SCI文献,总结了一点小心得:拥抱新技术,积极探索交叉领域,精准干预离不开精准且实用的预测模型。


一、CatBoost的亮点

CatBoost是什么?CatBoost是一种基于梯度提升决策树的机器学习算法,由俄罗斯的Yandex公司开发。Yandex公司官方的话来说,CatBoost是一个高性能开源库,是基于决策树的梯度提升算法实现。
机器学习算法那么多,CatBoost有什么优点吗?当然是有的,而且这个优点非常有意思,笔者很感兴趣!不然也不会去分享这种技术了。下面,重头戏来了,根据官网给的信息,CatBoost有如下主要特性

无需参数调整即可实现卓越品质

分类功能支持

快速且可扩展的 GPU 版本

提高准确性,减少过拟合

快速预测

下图是官方给出的基准测试,展示的是与其他几个流行框架之间的比较结果。表中的数字表示分类模式的对数损失值(越低越好)。百分比是根据调整后的 CatBoost 结果测量的指标差异。

二、CatBoost的R语言版本
CatBoost作为开源框架,不仅支持python语言,也支持R语言,所以提供了可以让我们轻松开发和测试模型的包(CatBoost包),所以本期笔者会演示R语言的实现方法,大家不妨与笔者一起体验下这款看起来很亮眼的模型,到底是否名副其实。

安装CatBoost包的路子并不是很轻松,因为需要从世界上最大的开源网站gthub下载,由于大家都懂的原因,访问这个网站时经常抽风,所以下面这种方法需要点技巧,有网络要求。
install.packages('remotes')remotes::install_url('https://github.com/catboost/catboost/releases/download/v1.2.5/catboost-R-windows-x86_64-1.2.5.tgz',                     INSTALL_opts = c("--no-multiarch", "--no-test-load"))

如果你已经成功安装这个R包,那么恭喜你,已经成功了一半,因为它用起来是so easy,正如官方给出的“宣言”:不需要调参,默认的超参数下就有很好的效果。总所周知(个人觉得),机器学习(预测模型)最难的几个地方,有一个就是超参数的调优,不像是logistic或cox这类传统回归模型,大部分机器学习模型是需要探索最优超参数的(期刊也对此有所要求)。
三、CatBoost上手体验
关于数据集:笔者从竞赛平台kaggle获取了一份泰坦尼克号乘客生存情况数据,样本量不大,几百行,比较适合用于演示各类机器学习算法,有需要的朋友可以回复关键词泰坦尼克自动获取文件(注意,这是未做预处理的)
但是大多数机器学习算法对数据格式有些要求,所以笔者对此数据做了点小小的预处理,包括因子化、分层随机拆分、填补缺失值等,当然只是简单处理,未必专业,但也不重要,因为大家如果想做预测模型研究的话,肯定是用自己的数据集来做分析的,本次演示用的数据集仅供参考。数据集的预处理步骤见前几天发布的xgboost+shap推文:xgboost + shap可加性解释(R版本):优秀的机器学习解决方案
处理后的数据集长这样:

除了Age之外,其余全部是因子型(没做编码),其中survived是因变量(用机器学习术语,survived是标签)。
下面是全文最精彩的部分,笔者参考了官网的教程(网站在文末)和R包文档,用演示数据跑了一遍,感觉确实很方便。

#### 导入R包和演示数据 ####library(tidyverse)library(catboost)
# 直接导入R文件,这是上面的泰坦尼克数据预处理后的文档# 为了方便,我把它保存起来了,这里就直接读入,如果你没有# 可以直接读入泰坦尼克数据load('dataset/train.rdata')load('dataset/test.rdata')

vars = newtrain[,-1]y = as.numeric(newtrain$Survived) - 1 #标签需要是0或1

# 需要对数据进行特殊封装preprocessdata = catboost.load_pool(vars,y)
#建模,内部有默认的随机种子0
cat_mod = catboost.train(preprocessdata,NULL,                         param= list(loss_function = 'Logloss',  iterations = 100, metric_period=10, eval_metric='AUC'))
# 查看特征重要性排序catboost.get_feature_importance(cat_mod)
# 交叉验证cat_cv = catboost.cv(preprocessdata, fold_count = 10,param= list(loss_function = 'Logloss', iterations = 100, metric_period=10, eval_metric='AUC'))
# 预测新数据,不需要封装标签进去pooltest = catboost.load_pool(newtest[,-1])
# 预测,可以直接获取概率pp = catboost.predict(cat_mod,pooltest,prediction_type = 'Probability')
# 再获取预测类别,方便做混淆矩阵pclass = catboost.predict(cat_mod, pooltest, prediction_type = "Class")
# 混淆矩阵caret::confusionMatrix(as.factor(pclass), newtest$Survived,positive ='1')


训练速度确实快,但也可能是数据量小,看不出来什么上述代码还输出了特征重要性排序,大家可以用xgboost、随机森林等模型比较下。另外,与教程不同的是,笔者修改了评估指标为AUC,默认应该不是这个指标,有兴趣的话,大家可以看看文档。

10折交叉验证做起来就需要点时间了,大约1分钟,与电脑配置有关。结果显示,auc最高是约0.925,最低也有0.8,这个结果确实不错了,而且每折内部也有多轮次的评估,模型的稳定性还可以。

在新数据中预测一下,并输出一份混淆矩阵看看:

结果还可以。关于其他评价结果,以及一些可视化方法,比如ROC\PR\校准\DCA等,可以参考其他预测模型的代码,因为是很类似的(原理一模一样),基本是只需要改几个变量名就能搞定的事。如果您觉得实在搞不定或需要辅导,可以联系笔者游,微信号:hulitongjisuibi或者我们团队的其他成员。

此外,机器学习多数是黑箱模型,无法直接用于临床,考虑采用shap等方法进行解释,或转换成列线图等可视化工具,下面就是一份shap解释热图:

图片来源:见参考文献

觉得您觉得此文有帮助的话,可以帮忙点个赞吗?谢谢


主要参考来源:

官网 https://catboost.ai/

官方教程 https://catboost.ai/en/docs/concepts/r-quickstart

相关R包(catboost)的帮助文档;

Valsaraj A, Kalmady SV, Sharma V, Frost M, Sun W, Sepehrvand N, Ong M, Equilbec C, Dyck JRB, Anderson T, Becher H, Weeks S, Tromp J, Hung CL, Ezekowitz JA, Kaul P. Development and validation of echocardiography-based machine-learning models to predict mortality. EBioMedicine. 2023 Apr;90:104479. doi: 10.1016/j.ebiom.2023.104479. Epub 2023 Feb 28. PMID: 36857967; PMCID: PMC10006431.

免责声明:仅供科研学习和分享,不做商业用途,如有侵权,请联系我们删除,谢谢。


护理统计随笔
未来是精准护理的时代,护理研究的发展不仅在于基于证据的理论创新及实践,更在于大数据和人工智能。这里是一个从0到1的学习平台,关注我们,不但可以夯实科研基础,更可以开阔研究视野,让你的护理科研之路走得更远、更广。
 最新文章