临床预测模型/机器学习-随机森林树RSF(RandomForest/RandomForestSRC)算法学习

文摘   2024-11-04 15:12   日本  

随机森林(Random Forest)是一种集成机器学习方法,由多棵决策树组成。它通过训练大量的决策树并结合这些树的预测结果,来提高模型的准确性和稳健性。随机森林常用于分类、回归和其他预测任务,尤其适合处理高维数据和噪声数据。

在随机森林树种,每种生成的树指的是决策树,多棵决策树组成了"森林"(随机森林),每颗树单独对数据进行分类或预测,最后通过多数投票(分类)或平均(回归)得到最终结果,换句话说,每颗树可以看做是数据的特征,不同的特征可以组成数据的全貌,每种特征又都会被用来对数据进行分类或者预测(如果这个理解有误请尽管批评指正)。

决策树的生成:

  1. 每棵树是从训练集数据中随机抽样生成的,这个抽样是有放回的。
  2. 每棵树在节点分裂时随机选择部分特征,以减少树之间的相关性并增强模型的泛化能力。
  3. 每棵树会尽可能完全生长(没有剪枝),以提高个体树的强度。

树的作用:

  1. 每棵树是一个独立的分类或回归模型,整体的随机森林通过这些树的组合来进行更为稳健的预测。
  2. 通过让每棵树“投票”或者输出预测值,随机森林能够降低单棵树的过拟合问题,并提升整体的预测准确性和鲁棒性。

错误率依赖于树之间的相关性和单棵树的强度:

  1. 树之间的相关性越低,整体模型的误差越低。
  2. 单棵树的强度越高,模型的误差也越低。因此,通过调整每棵树的特征选择数量来平衡这两者,以获得最佳表现的随机森林模型。

多数投票(分类)

  1. 对于分类任务,每棵决策树会对输入数据进行预测,输出一个类别(例如,分类为“猫”或“狗”)。当所有树都做出预测后,随机森林会统计每个类别的预测次数,并选择A获得最多“票数”,A的类别作为最终结果。这种方式称为“多数投票”。
  2. 举例:假设一个随机森林包含100棵树,预测一个图像属于“猫”或“狗”。其中,70棵树预测是“猫”,30棵树预测是“狗”。最终结果是“猫”,因为“猫”得到了多数票。

平均(回归)

  1. 对于回归任务,每棵决策树会对输入数据输出一个连续的数值(例如,房价的预测)。当所有树都做出预测后,随机森林会对所有预测值求平均值,并将这个平均值作为最终的预测结果。这种方式称为“平均”。
  2. 举例:假设一个随机森林包含100棵树,用于预测房价。每棵树对房价的预测值可能不同,比如一棵树预测300,000,另一棵预测320,000,以此类推。最终的预测结果是所有100棵树的预测值的平均值(例如,305,000),这个平均值作为房价的最终预测值。

从开发者对它的描述来看,RSF可强大了,不过作为使用者我们还是需要留个心,毕竟没有最好的算法/工具,只有符合自己数据情况的相对合适的算法/工具。

随机森林的工作原理:随机森林使用“袋外数据”(out-of-bag, OOB)来估计分类误差和变量重要性。每棵树通过有放回抽样的方式从原始数据中随机抽取训练样本,这导致约三分之一的数据未被选入,用作 OOB 数据。这些 OOB 数据用于提供"无偏倚"误差估计,并帮助评估变量的重要性。

袋外(OOB)误差估计:在随机森林中,不需要额外的交叉验证/测试集来估计模型的泛化误差。每棵树在训练时有一部分数据未被使用,作为 OOB 数据。OOB 数据通过各树投票得到预测,OOB 误差率就是预测错误的比例,这是经过验证的可靠误差估计方法。

变量重要性:为了评估变量的重要性,将每棵树的 OOB 数据通过树模型并记录投票次数,然后随机置换某个变量的值,计算分类正确的投票次数变化量,变化越大,说明该变量的重要性越高。通过所有树的平均结果,得出每个变量的重要性分数。

基尼重要性:也称为基尼指数重要性或基尼不纯度减少量,是一种用于评估特征(变量)在决策树或随机森林模型中的重要性的方法。在决策树中,基尼不纯度(Gini Impurity)用于衡量一个节点的“纯度”——也就是说,节点中样本的类别有多么一致。基尼不纯度越低,节点中的样本越趋于相同的类别。每当一个节点使用某个特征进行分裂时,这个特征会降低基尼不纯度,这种减少的量越大,说明这个特征越重要。

Interactions(变量交互): 在随机森林中,变量之间的交互定义为:如果某一变量(如 mmm)的分裂影响了另一变量(如 kkk)的分裂可能性,则这两个变量存在交互。每棵树中计算变量的基尼值排名差并取绝对值,最后在所有树上取平均值来度量交互强度。此方法基于变量独立的假设,并具有实验性质,仅在少量数据集上测试,结果需谨慎解释。

Proximities(接近度): 接近度是随机森林中的一种重要工具。通过记录数据对在相同终端节点出现的频次,构成一个 N×NN \times NN×N 矩阵,并在所有树中取平均值进行归一化。对于大数据集,接近度矩阵可能超出内存限制,可以只保留最近邻的接近度。当有测试集时,也可以计算测试集和训练集之间的接近度,额外的计算量适中。总之,随机森林这个强大的工具可以通过反复抽样(袋装法)的方式获得多棵决策树模型并综合这些模型的结果,RSF能够进行分类,回归,生存分析等多种任务。

分析流程

randomForest包
1.导入
rm(list = ls())
library(randomForest)
load("consensus.Rdata")
2.数据预处理
# 把基因数据转置之后跟生存信息整合在一起
# 行为样本,列为生存信息+变量
meta <- meta[,c(1:3)]
head(meta)
#                                ID OS  OS.time
# TCGA-CR-7374-01A TCGA-CR-7374-01A  0 1.000000
# TCGA-CV-A45V-01A TCGA-CV-A45V-01A  1 1.066667
# TCGA-CV-7102-01A TCGA-CV-7102-01A  1 1.866667
# TCGA-MT-A67D-01A TCGA-MT-A67D-01A  0 1.866667
# TCGA-P3-A6T4-01A TCGA-P3-A6T4-01A  1 2.066667
# TCGA-CV-7255-01A TCGA-CV-7255-01A  1 2.133333
identical(rownames(meta),colnames(exprSet))
# [1] TRUE
meta <- cbind(meta,t(exprSet))
head(meta)[1:5,1:5]
#                                ID OS  OS.time    WASH7P AL627309.6
# TCGA-CR-7374-01A TCGA-CR-7374-01A  0 1.000000 0.5808846   3.117962
# TCGA-CV-A45V-01A TCGA-CV-A45V-01A  1 1.066667 1.4177642   6.250413
# TCGA-CV-7102-01A TCGA-CV-7102-01A  1 1.866667 0.6501330   1.219729
# TCGA-MT-A67D-01A TCGA-MT-A67D-01A  0 1.866667 1.2045780   3.038835
# TCGA-P3-A6T4-01A TCGA-P3-A6T4-01A  1 2.066667 1.3470145   3.799571
str(meta)
# 'data.frame': 493 obs. of  18238 variables:
#  $ ID                       : chr  "TCGA-CR-7374-01A" "TCGA-CV-A45V-01A" "TCGA-CV-7102-01A" "TCGA-MT-A67D-01A" ...
#  $ OS                       : int  0 1 1 0 1 1 1 1 1 0 ...
#  $ OS.time                  : num  1 1.07 1.87 1.87 2.07 ...
#  $ WASH7P                   : num  0.581 1.418 0.65 1.205 1.347 ...
#  $ AL627309.6               : num  3.12 6.25 1.22 3.04 3.8 ...
#  $ AL627309.7               : num  3.73 6.38 2.13 2.96 4.43 ...

# 为了减少运算符合,对变量进行随机抽样
meta <- meta[, c("ID","OS","OS.time",names(meta)[4:204])]

# 因子化
# 规范命名
meta$OS <- factor(meta$OS,levels = c("0","1"))
colnames(meta) <- gsub("-","_",colnames(meta))

# 数据分割 7:3,8:2 均可
# 划分是随机的,设置种子数可以让结果复现
set.seed(123)
ind <- sample(1:nrow(meta), size = 0.7*nrow(meta))
train <- meta[ind,]
test <- meta[-ind, ]
3.RandomForest分析

建立模型

dat <- train[,-c(1,3)] # 这里是表格
rf <- randomForest(OS~., 
                   data=dat, 
                   proximity=TRUE,
                   importance = T # 需要计算变量的重要性
                   ) 
print(rf)
# Call:
#  randomForest(formula = OS ~ ., data = dat, proximity = TRUE,      importance = T) 
#                Type of random forest: classification
#                      Number of trees: 500
# No. of variables tried at each split: 14

#         OOB estimate of  error rate: 39.13%
# Confusion matrix:
#     0  1 class.error
# 0 166 37   0.1822660
# 1  98 44   0.6901408

optionTrees <- which.min(rf$err.rate[,1])  # 选择误差最小树模型
optionTrees
# [1] 25


# 可视化
plot(rf)

Number of trees: 500:随机森林中生成的决策树的数量是500棵。这是默认值,也可以通过设置 ntree 参数来更改。

No. of variables tried at each split: 14:每次分裂时,从所有特征中随机选择14个特征用于寻找最佳分裂。这个数字是根据随机森林的规则自动选择的(通常是特征总数的平方根),但可以通过 mtry 参数手动设置。

OOB(Out-of-Bag)estimate of error rate: 39.13%: 袋外误差(Out-of-Bag Error)是用来估计模型泛化误差的。袋外误差率为39.71%,表示大约有39.13%的样本被错误分类。较高的袋外误差率(接近40%)可能表明模型在数据上的表现较差,可能是因为数据不够线性可分,或者模型复杂度不足。

Confusion matrix(混淆矩阵) :混淆矩阵显示了模型预测和实际分类的比较结果。它按行表示实际的类别,按列表示模型的预测类别。0(负类)实际类别为0的样本有203个,其中166个被正确预测为0,37个被误分类为1。class.error 为 0.1823(18.23%),表示实际类别为0的样本中大约有18.23%被错误分类。1(正类)实际类别为1的样本有142个,其中44个被正确预测为1,98个被误分类为0。class.error 为 0.6901(69.01%),表示实际类别为1的样本中有69.01%被错误分类。这是一个较高的误差率,表明模型在识别类别1时表现较差。

随机森林模型的错误率随树的数量变化的曲线图

黑色实线(Overall Error):代表随机森林的总体袋外误差(OOB error),即所有类别的平均误差率。可以看到,随着树的数量增加,总体误差率逐渐趋于稳定,但在 500 棵树后仍保持在大约 0.4(即40%)左右。这与之前的袋外误差率一致,表示模型在这组数据上的总体表现并不理想。

红色虚线(Class 0 Error):表示类别0的误差率。从图中可以看到,类别0的误差率逐渐下降,并在增加到一定数量的树后趋于稳定,大约在 0.2 左右(即20%)。这表明模型在识别类别0时的表现相对较好,错误率较低。

绿色虚线(Class 1 Error):表示类别1的误差率。类别1的误差率较高,一直保持在 0.7 左右。这表明模型在识别类别1时有明显的困难,误分类率较高。这个高误差率可能是由于类别1的样本数量较少或类别特征不明显,导致模型难以正确分类。提取关键变量,这里得到的变量可以按照分值进行自行筛选

randomForest::importance(rf)[1:5,]
#                     0           1 MeanDecreaseAccuracy MeanDecreaseGini
# WASH7P      2.3630045 -2.40316028         -0.006391703        0.7960579
# AL627309.6  0.9272169 -1.98664513         -0.781005992        0.4949906
# AL627309.7 -0.7136037 -0.73516331         -0.990568361        0.6711430
# WASH9P      3.4953492 -0.05153207          2.813529742        0.9358831
# AL732372.2  1.4735856 -0.99531545          0.306859094        0.5746815

varImpPlot(rf)

0 和 1:这些列表示每个特征对分类类别 0 和 1 的重要性度量。在这里,这些值可能表示当该特征被随机置换时,模型在对应类别上的分类准确性下降程度。正值 表示该特征对分类有正向贡献(即置换后错误率增加),负值 则表示置换后错误率反而减少,可能表明该特征对该类别的贡献较小,甚至是噪声。

MeanDecreaseAccuracy:这是 平均准确性减少量。它表示当该特征被随机置换时,模型整体准确性的降低程度。值越大表示该特征对模型整体准确性影响越大,即特征越重要。如果值为负,说明该特征对模型准确性贡献不大,甚至可能带来噪声。

MeanDecreaseGini:这是 基尼指数减少量,是另一种重要性指标。它衡量了该特征在分裂节点时对基尼不纯度的贡献,表示模型在使用该特征分裂后纯度的提升。值越大表示该特征在决策树的分裂中越重要,对最终的分类结果影响越大。

4.构建最佳模型及预测
dat <- train[,-c(1,3)] # 这里是表格
rf_best <- randomForest(OS~., 
                        data = dat,
                        ntree = optionTrees)
rf_best
# Call:
#  randomForest(formula = OS ~ ., data = dat, ntree = optionTrees) 
#                Type of random forest: classification
#                      Number of trees: 25
# No. of variables tried at each split: 14

#         OOB estimate of  error rate: 45.8%
# Confusion matrix:
#     0  1 class.error
# 0 143 60   0.2955665
# 1  98 44   0.6901408

# 验证集
data <- test[,-c(1,3)] # 这里是表格
pred.test <- predict(rf_best,newdata = data)
pred.test.value <- predict(rf_best,newdata = data,type = "prob")
head(pred.test.value)
#                     0    1
# TCGA-CR-7374-01A 0.68 0.32
# TCGA-CV-7102-01A 0.44 0.56
# TCGA-CV-7255-01A 0.48 0.52
# TCGA-BA-A6DG-01A 0.52 0.48
# TCGA-CV-6961-01A 0.44 0.56

后续可以根据预测结果进行ROC曲线等其他内容的分析。

randomForestSRC包

这个包值得深度探索一下,里面有很多功能

1.导入
rm(list = ls())
library(randomForestSRC)
load("data.Rdata")
2.数据预处理

这里去掉了因子化的一步

# 把基因数据转置之后跟生存信息整合在一起
# 行为样本,列为生存信息+变量
meta <- meta[,c(1:3)]
head(meta)
#                                ID OS  OS.time
# TCGA-CR-7374-01A TCGA-CR-7374-01A  0 1.000000
# TCGA-CV-A45V-01A TCGA-CV-A45V-01A  1 1.066667
# TCGA-CV-7102-01A TCGA-CV-7102-01A  1 1.866667
# TCGA-MT-A67D-01A TCGA-MT-A67D-01A  0 1.866667
# TCGA-P3-A6T4-01A TCGA-P3-A6T4-01A  1 2.066667
# TCGA-CV-7255-01A TCGA-CV-7255-01A  1 2.133333
identical(rownames(meta),colnames(exprSet))
# [1] TRUE
meta <- cbind(meta,t(exprSet))
head(meta)[1:5,1:5]
#                                ID OS  OS.time    WASH7P AL627309.6
# TCGA-CR-7374-01A TCGA-CR-7374-01A  0 1.000000 0.5808846   3.117962
# TCGA-CV-A45V-01A TCGA-CV-A45V-01A  1 1.066667 1.4177642   6.250413
# TCGA-CV-7102-01A TCGA-CV-7102-01A  1 1.866667 0.6501330   1.219729
# TCGA-MT-A67D-01A TCGA-MT-A67D-01A  0 1.866667 1.2045780   3.038835
# TCGA-P3-A6T4-01A TCGA-P3-A6T4-01A  1 2.066667 1.3470145   3.799571
str(meta)
# 'data.frame': 493 obs. of  18238 variables:
#  $ ID                       : chr  "TCGA-CR-7374-01A" "TCGA-CV-A45V-01A" "TCGA-CV-7102-01A" "TCGA-MT-A67D-01A" ...
#  $ OS                       : int  0 1 1 0 1 1 1 1 1 0 ...
#  $ OS.time                  : num  1 1.07 1.87 1.87 2.07 ...
#  $ WASH7P                   : num  0.581 1.418 0.65 1.205 1.347 ...
#  $ AL627309.6               : num  3.12 6.25 1.22 3.04 3.8 ...
#  $ AL627309.7               : num  3.73 6.38 2.13 2.96 4.43 ...

# 为了减少运算符合,对变量进行随机抽样
meta <- meta[, c("ID","OS","OS.time",names(meta)[4:204])]

# 规范命名
colnames(meta) <- gsub("-","_",colnames(meta))

# 数据分割 7:3,8:2 均可
# 划分是随机的,设置种子数可以让结果复现
set.seed(123)
ind <- sample(1:nrow(meta), size = 0.7*nrow(meta))
train <- meta[ind,]
test <- meta[-ind, ]
3.RandomForestSRC分析

建立模型

seed <- 123
rf_nodesize <- 10

dat <- train[,-c(1)] # 这里是表格
fit <- rfsrc(Surv(OS.time,OS)~.,data = dat,
             ntree = 1000,
             nodesize = rf_nodesize, #3-15 之间调整
             splitrule = 'logrank'#分割的规则,logrank是生存分析常用
             importance = T#计算每个预测变量对模型预测能力的贡献度
             proximity = T#计算样本之间的接近度
             forest = T#保存整个随机森林模型
             seed = seed)
print(fit)
#                          Sample size: 345
#                     Number of deaths: 142
#                      Number of trees: 1000
#            Forest terminal node size: 10
#        Average no. of terminal nodes: 32.661
# No. of variables tried at each split: 15
#               Total no. of variables: 201
#        Resampling used to grow trees: swor
#     Resample size used to grow trees: 218
#                             Analysis: RSF
#                               Family: surv
#                       Splitting rule: logrank *random*
#        Number of random split points: 10
#                           (OOB) CRPS: 35.03819086
#                    (OOB) stand. CRPS: 0.16380641
#    (OOB) Requested performance error: 0.44121949
4.可视化
plot(fit)

#绘图2---基因
library(ggplot2)
library(dplyr)
library(tibble)
library(viridis)

importance_gene <- data.frame(fit$importance) %>% 
  rownames_to_column("gene") %>% 
  arrange(- fit.importance)
  #head(20)  #可以调整
importance_gene

ggplot(data=importance_gene, aes(x = reorder(gene,  fit.importance), 
                                 y=fit.importance,fill=gene)) +
  geom_bar(stat="identity") + 
  theme_classic() + 
  theme(legend.position = 'none') + 
  scale_fill_viridis(discrete = TRUE, option = "D") +
  #scale_fill_brewer(palette = "Set3") +
  coord_flip()
ggsave("model_genes.pdf",width = 9,height = 7)

两种可视化

5.探索变量
oo <- subsample(fit, verbose = FALSE)
vimpCI <- extract.subsample(oo)$var.jk.sel.Z
vimpCI
#                     lower          mean        upper     pvalue signif
# WASH7P       -0.0030876511 -1.227215e-04 0.0028422081 0.53232870  FALSE
# AL627309.6   -0.0016648692  4.370141e-04 0.0025388974 0.34181789  FALSE
# AL627309.7   -0.0015882484  4.877482e-04 0.0025637447 0.32258344  FALSE
# WASH9P       -0.0030185608  3.628007e-04 0.0037441623 0.41671953  FALSE
# AL732372.2   -0.0039735002 -7.430470e-04 0.0024874062 0.67393951  FALSE

# Confidence Intervals for VIMP
plot.subsample(oo)
# take the variable "Month" for example for partial plot
plot.variable(fit, 
              xvar.names = "ICMT"
              partial = TRUE)

也是用于观察每个变量的重要性~后续就用predict函数去分析测试数据集就可以啦~

同时可以点进参考资料中RandomForestSRC的资料网站,里面还有其他功能可以使用哦~

参考资料:

  1. Random Forests(Leo Breiman and Adele Cutler): https://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm
  2. RandomForestSRC:https://github.com/cran/randomForestSRC https://www.randomforestsrc.org/
  3. RandomForest:https://github.com/cran/randomForest
  4. 医学和生信笔记:https://mp.weixin.qq.com/s/T2SSdL5OTkuVcx7bTRwloQ https://mp.weixin.qq.com/s/Yh444t3IoEc5SVKn5I4gpw https://mp.weixin.qq.com/s/tnITiARh2MbW49VfB5u71A
  5. 生信碱移:https://mp.weixin.qq.com/s/0QuxJEsch-mesN-8coAbkw
  6. 生信补给站:https://mp.weixin.qq.com/s/3htyj2AAELzOHB40ri2Ryg

:若对内容有疑惑或者有发现明确错误的朋友,请联系后台(欢迎交流)。更多内容可关注公众号:生信方舟

- END -


生信方舟
执着医学,热爱科研。站在巨人的肩膀上,学习和整理各种知识。
 最新文章