Explanatory Model Analysis (1):局部解释triplot(可解释型预测模型)

文摘   2024-07-23 14:27   马来西亚  

介绍

Break DownShapLIME方法用于解释黑盒模型的局部解释。但是当预测变量存在相关关系时,这些模型的解释结果就可能出现问题。triplot包提供的predict_aspects 的目的是通过提供解释变量组的实例级和数据级解释器来提高黑盒模型的可解释性。它能够将预测变量分组成称为分组的实体。然后,它可以计算这些分组对预测的贡献。具体而言,triplot包提供了用于探索机器学习预测模型的工具。它包含了一个称为predict_aspects(也称为aspects_importance)的实例级解释器,能够解释整组解释变量的贡献。此外,该包还提供了triplot功能,它展示了不同大小的分组(预测变量组)的重要性如何变化。关键功能包括:

  • predict_triplot() 和 model_triplot():用于自动分组重要性分组的实例级和数据级汇总。

  • predict_aspects():用于计算选定观察值的特征组重要性(称为分组重要性)。

  • group_variables():用于将相关联的数值型特征分组。

假设我们在一个高维数据集上构建了一个模型,并使用这个模型来预测一个新的观测结果。我们希望了解哪些特征对计算出的预测结果贡献更大,哪些贡献较小。然而,当我们计算每个单独特征的重要性时,图像可能仍然不清晰,因为一些特征可能是相关的。为了获得更好的理解,我们可以将这些特征进行分组。之后,我们可以计算每个特征组(分组)对预测的贡献(重要性)。这正是"predict aspects"方法的作用

加载R包

导入所需要的R包,在导入前需要用户自己安装。

library(tidyverse)

# devtools::install_github("ModelOriented/triplot")
library(DALEX)
library(triplot)
library(mlbench)
library(randomForest)

# rm(list = ls())
options(stringsAsFactors = F)
options(future.globals.maxSize = 1000 * 1024^2)

导入数据

我们将使用来自mlbench包的波士顿住房数据集。这个著名的数据集包含了波士顿506个人口普查区域的住房数据。我们将预测cmedv——业主自住房屋的校正中位价值

data("BostonHousing2")

head(BostonHousing2)
towntractlonlatmedvcmedv
Nahant2011-70.955042.255024.024.0
Swampscott2021-70.950042.287521.621.6
Swampscott2022-70.936042.283034.734.7
Marblehead2031-70.928042.293033.433.4
Marblehead2032-70.922042.298036.236.2
Marblehead2033-70.916542.304028.728.7

数据预处理

  • predict_aspects仅适合连续型变量

  • 筛选变量

  • 筛选预测对象

metadata <- BostonHousing2[, -c(1:5, 10)]

trainData <- metadata[-4, , ]
testData <- metadata[4, , ]

x_train <- trainData[, -1]
y_train <- trainData[, 1]

x_test <- testData[, -1]
y_test <- testData[, 1]


head(trainData)
cmedvcrimznindusnoxrm
24.00.0063218.02.310.5386.575
21.60.027310.07.070.4696.421
34.70.027290.07.070.4697.185
36.20.069050.02.180.4587.147
28.70.029850.02.180.4586.430
22.90.0882912.57.870.5246.012

构建随机森林预测模型

cmedv变量构建回归模型,自变量存在相关关系且均是连续型变量

rf_fit <- randomForest(cmedv ~ ., data = trainData)

paste0("PredictedValue = ", predict(rf_fit, x_test))
paste0("TrueValue = ", y_test)
[1] "PredictedValue = 34.9478866666667"
[1] "TrueValue = 33.4"

结果:预测值和真实值差别较小

分组的重要性

针对单个样本预测每个分组的贡献情况

  • 分组设置可以:1)根据先验经验设定或2)根据相关系数设定(triplot::group_variables);

  • 使用DALEX::explain构建解释器;

  • triplot::predict_aspects计算分组的局部解释结果

set.seed(123)

## 手动设置分组
# rf_aspects <- list(
# geo = c("dis", "nox", "rad"),
# wealth = c("rm", "lstat"),
# structure = c("indus", "age", "zn"),
# ptratio = "ptratio",
# b = "b",
# tax = "tax",
# crim = "crim")

## 自动设置分组,以相关系数0.6作为阈值
rf_aspects <- triplot::group_variables(
x = trainData,
h = 0.6)

explain_rf <- DALEX::explain(
model = rf_fit,
data = trainData,
verbose = FALSE)

ai_rf <- triplot::predict_aspects(
x = explain_rf,
new_observation = testData,
variable_groups = rf_aspects,
N = 5000,
show_cor = TRUE,
label = "RF")

print(ai_rf, show_features = TRUE)

plot(ai_rf, show_features = FALSE)
 variable_groups importance                   features
2 aspect.group1 11.3257 cmedv, rm, lstat
3 aspect.group2 0.8630 crim, indus, nox, age, dis
5 aspect.group4 0.7812 rad, tax
4 aspect.group3 -0.3285 zn
6 aspect.group5 -0.2007 ptratio
7 aspect.group6 0.0945 b

结果:每个分组变量的模型贡献程度大小(接下来可以查看分组变量之间的相关性)

  • **aspect.group3 (rm, lstat)**的贡献度最大,是11.3257

  • **aspect.group5 (ptratio)的贡献度最小,是-0.2007

变量整体解释结果

Triplot 是一个基于aspects_importance函数构建的工具,它使我们能够更深入地理解黑盒模型的内部工作机制。它在一个地方展示了以下内容:

  • 每个单独特征的重要性;

  • 层次化分组的重要性;

  • group_variables()中将特征分组的顺序。

层次化分组的重要性允许我们检查不同变量分组级别的分组重要性。该方法首先观察每个分组都有一个单一变量的分组重要性。然后,它通过迭代地将具有最高绝对相关性的分组合并为一个更大的分组,并计算其对预测的贡献,来创建更大的分组。

需要注意的是,与group_variables()类似,calculate_triplot()仅适用于包含数值变量的数据集。

explain_rf_global <- DALEX::explain(
model = rf_fit,
data = x_train,
y = y_train,
verbose = FALSE)

tri_rf_global <- triplot::model_triplot(
x = explain_rf_global)

plot(tri_rf_global) +
patchwork::plot_annotation(title = "Global triplot for variables in the RF model")

结果:随机森林回归模型变量的整体解释的triplot图。

  • 左侧面板显示了各个变量的整体重要性(参考randomForest给出的importance score);

  • 中间面板显示了由层次聚类确定的变量组的局部重要性(通过层次聚类评估不同变量的聚类结果);

  • 右侧面板展示了通过层次聚类可视化的全局相关性结构。

变量局部解释结果(针对单个预测样本而言)

Triplot 是一个基于aspects_importance函数构建的工具,它使我们能够更深入地理解黑盒模型的内部工作机制。它在一个地方展示了以下内容:

  • 每个单独特征的重要性;

  • 层次化分组的重要性;

  • group_variables()中将特征分组为分组的顺序。

层次化分组的重要性允许我们检查不同变量分组级别的分组重要性。该方法首先观察每个分组都有一个单一变量的分组重要性。然后,它通过迭代地将具有最高绝对相关性的分组合并为一个更大的分组,并计算其对预测的贡献,来创建更大的分组。

需要注意的是,与group_variables()类似,calculate_triplot()仅适用于包含数值变量的数据集。

tri_rf_local <- triplot::predict_triplot(
x = explain_rf,
new_observation = testData,
N = 10000)

plot(tri_rf_local) +
patchwork::plot_annotation(title = "Local triplot for variables in the RF model")

结果:随机森林回归模型对testData的变量的局部解释的triplot图(针对单个样本而言)。

  • 左侧面板显示了各个变量的局部重要性(lstat对模型贡献最大);

  • 中间面板显示了由层次聚类确定的变量组的局部重要性;

  • 右侧面板展示了通过层次聚类可视化的全局相关性结构。

总结

predict_aspects是一种强大的工具,它通过分析特征组对预测的贡献来提高黑盒模型的可解释性。它具备自动分组特征、控制结果中非零值的数量以及评估特征组重要性变化的能力。通过进一步的调整和优化,predict_aspects 可以提供更准确、更透明的模型解释,帮助研究人员和实践者更好地理解和信任他们的模型。

相比常用的SHAP算法,triplot提供了另一种解析变量之间相关的模型解释方法。以下是它的优缺点:

  • 优势

  1. 多变量组的贡献解释:triplot包中的predict_aspects函数能够解释整组解释变量的贡献,这对于那些使用大量相关变量的模型尤其有用;

  2. 实例级和数据级摘要:通过predict_triplot()model_triplot()函数,可以提供实例级和数据级自动分组重要性分组的摘要;

  3. 可视化相关性结构:triplot能够展示特征之间的相关性结构,并通过层次聚类进行可视化,这有助于理解变量组如何共同影响模型的预测;

  4. 解决相关特征问题:在特征高度相关的情况下,传统的特征重要性方法可能会产生误导性结果,而triplot通过考虑变量组的相关性来评估重要性,从而解决了这一问题;

  • 局限性

  1. 特定类型的数据:triplot主要设计用于数值特征,这可能限制了它在处理非数值数据时的应用范围;

  2. 计算复杂性:在处理具有大量变量的模型时,计算变量组的重要性可能需要复杂的算法和较长的处理时间;

  3. 解释的局限性:尽管triplot提供了变量组重要性的解释,但可能不如个别变量重要性那样直观或易于理解,特别是对于不熟悉该方法的研究人员;

  4. 可能需要领域知识:为了充分利用triplot包,用户可能需要对数据和模型有深入的理解,包括变量之间的相关性以及它们如何影响模型的预测。

参考

  • https://modeloriented.github.io/triplot/articles/vignette_aspect_importance_indepth.html

  • https://ema.drwhy.ai/ 


生信学习者
生信教程分享,专注数据分析和科研绘图方向欢迎大家关注,也可一起探讨生信问题