IF18.8nature子刊--基于树的 Explainable AI:从局部解释到全局理解

文摘   2024-12-19 12:00   北京  

摘要

基于树的机器学习模型,如随机森林、决策树和梯度提升树,是流行的非线性预测模型,但相比之下,对解释它们的预测结果的关注相对较少。在这里,我们通过三个主要贡献提高了基于树的模型的可解释性:1)首个基于博弈论的多项式时间算法,用于计算最优解释。2)一种新型解释,直接测量局部特征交互效应。3)一套新工具,用于基于结合每个预测的许多局部解释来理解全局模型结构。我们将这些工具应用于三个医学机器学习问题,并展示了如何通过结合许多高质量的局部解释来表示全局结构,同时保留对原始模型的局部忠实度。这些工具使我们能够:i)在美国人群中识别高幅度但低频率的非线性死亡风险因素;ii)突出具有共同风险特征的不同人群亚群;iii)识别慢性肾脏病风险因素之间的非线性交互效应;iv)通过识别随时间推移哪些特征正在降低模型性能,监控医院中部署的机器学习模型。鉴于基于树的机器学习模型的流行,这些对其可解释性的改进在广泛的领域都有影响。



01

方法
01

Shapley值

Shapley值是博弈论中的一个解决方案概念,用于确定在n人合作游戏中每个参与者应得的收益。在机器学习模型解释中,它被用来量化每个特征对模型预测的贡献度。Shapley值满足三个关键属性:局部准确性(所有特征贡献的总和应等于模型的预测输出)、一致性(如果一个特征在所有可能的上下文中的贡献增加或保持不变,则其归因值不应减少)和缺失性(对模型输出没有影响的特征应被赋予零影响)。

02

TreeExplainer算法的三个阶段

第一阶段:路径依赖特征扰动(Algorithm 1)

算法描述:这个算法通过模拟每个特征对模型输出的影响来计算Shapley值。它通过递归地沿着决策树的路径模拟每个特征的影响,并计算每个特征对模型预测的贡献。

计算过程:对于每个特征,算法会计算在包含该特征和不包含该特征的情况下模型输出的期望值,然后根据这些值计算Shapley值。

第二阶段:多项式时间Tree SHAP算法(Algorithm 2)

算法描述:这个算法通过递归跟踪所有可能的特征子集在树的每个叶子节点中的比例,从而在多项式时间内计算Shapley值。

计算过程:算法使用EXTENDUNWIND方法来递归地跟踪和更新特征子集。EXTEND方法用于在树的每个分支上扩展特征子集,而UNWIND方法用于撤销之前的扩展。这种方法允许算法在多项式时间内计算Shapley值。

第三阶段:干预性特征扰动(Algorithm 3)

算法描述:这个算法允许对树模型的输出进行非线性变换的解释,例如模型的损失函数。

计算过程:算法通过遍历单个前景样本和背景样本在树中的混合路径来计算Shapley值。在每个内部节点,算法会根据前景和背景样本的路径决定如何分配正负贡献,并在叶子节点计算最终的贡献。


03

算法复杂度分析

Algorithm 2将计算Shapley值的复杂度从指数级降低到多项式级别,这对于大规模数据集和复杂模型尤为重要。复杂度分析考虑了树的最大深度、树的数量、叶子节点的数量和输入特征的数量。

04

基准评估指标



  1. Local Accuracy (局部准确性):

    评估解释方法是否能够准确地反映模型的预测输出。

  2. Consistency (一致性):

    评估当模型输出变化时,解释方法是否能够一致地反映这些变化。

  3. Missingness (缺失性):

    评估对模型输出没有影响的特征是否被正确地赋予零影响。

  4. Runtime (运行时间):

    衡量解释方法在实际应用中的效率。

  5. Average Score (平均分数):

    综合多个指标来评估解释方法的整体性能。

  6. Keep Positive (保持正值):

    评估当特征值增加时,模型输出增加的情况下,解释方法是否能够正确地反映这一关系。

  7. Keep Negative (保持负值):

    评估当特征值增加时,模型输出减少的情况下,解释方法是否能够正确地反映这一关系。

  8. Mask Positive (掩蔽正值):

    通过将特征值掩蔽为平均值,评估解释方法对模型输出变化的敏感性。

  9. Mask Negative (掩蔽负值):

    通过将特征值掩蔽为平均值,评估解释方法对模型输出变化的敏感性。

  10. Resample Positive (重采样正值):

    通过随机重采样特征值,评估解释方法对模型输出变化的稳定性。

  11. Resample Negative (重采样负值):

    通过随机重采样特征值,评估解释方法对模型输出变化的稳定性。

  12. Keep Absolute (保持绝对值):

    评估解释方法是否能够正确地反映特征值的绝对变化对模型输出的影响。

  13. Remove Positive (移除正值):

    评估当特征值被移除时,解释方法是否能够正确地反映这一变化。

  14. Remove Negative (移除负值):

    评估当特征值被移除时,解释方法是否能够正确地反映这一变化。

  15. Model-Specific Metrics (模型特定指标):

    根据特定模型的特性,评估解释方法的性能。


05

SHAP交互值


SHAP交互值是一种基于Shapley交互指数的解释模型,用于捕捉特征之间的局部交互效应。

定义:SHAP交互值:通过考虑特征对之间的交互效应,SHAP交互值提供了一个矩阵,其中对角线上的值表示特征的独立效应,而非对角线上的值表示特征之间的交互效应。

计算过程

基于Shapley值:SHAP交互值可以通过计算包含和不包含特定特征对的情况下模型输出的差异来计算。这种方法允许算法在考虑特征交互效应的同时,为每个特征提供更细致的解释。

算法优化:通过利用TreeExplainer算法,可以显著降低计算SHAP交互值的复杂度。特别是,可以通过两次运行Algorithm 2或3来计算每个特征的交互效应,一次是将特定特征视为固定存在,另一次视为固定不存在



02

结果
01
基于树的模型的Shapley值解释的精确计算

经典Shapley值可以被认为是“最优”的,因为在一大类方法中,它们是唯一一种在保持合作博弈论中几个自然属性的同时衡量特征重要性的方法。不幸的是,通常只能近似计算这些值,因为精确计算它们是特别困难的,需要对所有特征子集进行求和。已经提出了基于采样的近似方法;然而,使用它们来计算本文中即使最小数据集的低方差版本的结果也会消耗数年的CPU时间(特别是对于交互效应)。通过特别关注树模型,我们开发了一种算法,它可以在多项式时间内计算基于精确Shapley值的局部解释。这提供了理论上保证局部准确性和一致性的局部解释。


02
扩展局部解释以直接捕获特征交互

为每个输入特征分配一个数字的局部解释虽然非常直观,但不能直接表示交互效果。我们提供了一种基于博弈论文献中提出的 Shapley 值的泛化来测量局部交互效应的理论方法。我们认为这种方法为模型的行为提供了有价值的见解。


03

基于许多局部解释的全球模型结构解释工具

使用整个数据集中的Shapley值高效、准确地计算局部解释的能力,使开发一系列工具来解释模型的全局行为成为可能(图1B)。我们表明,结合许多局部解释可以让我们表示全局结构,同时保持对原始模型的局部忠实度,从而产生模型行为的详细和准确的表示。


在医学应用中,解释树模型的预测尤为重要,因为模型揭示的模式可能比模型的预测性能更重要。为了证明 TreeExplainer 的价值,我们使用了三个医学数据集,它们代表了三种类型的损失函数:1) 死亡率,一个包含 14,407 个个体和 79 个特征的数据集,基于 NHANES I 流行病学随访研究 ,我们在其中模拟了 20 年随访中的死亡风险。2) 慢性肾病,该数据集跟踪了慢性肾功能不全队列研究中的 3,939 名慢性肾病患者,超过 10,745 次就诊,我们使用 333 个特征对患者是否会在 4 年内进展为终末期肾病进行分类。3) 医院手术持续时间,一个具有 147,000 个手术和 2,185 个特征的电子病历数据集,我们在其中预测患者即将进行的手术的住院时间。


在本文中,我们讨论了基于树的模型的准确性和可解释性如何使其适用于许多应用程序。然后,我们描述了为什么这些模型需要更精确的局部解释,以及我们如何使用 TreeExplainer 解决这一需求。接下来,我们扩展局部解释以捕获交互效应。最后,我们展示了可解释的 AI 工具的价值,这些工具结合了 TreeExplainer (https://github.com/suinleelab/treeexplainer-study) 中的许多本地解释。

04
基于树的模型的优点

在许多应用程序中,基于树的模型可能比神经网络更准确。虽然深度学习模型更适合于图像识别、语音识别和自然语言处理等领域,但基于树的模型在表格风格的数据集上始终优于标准的深度模型,其中特征具有单独的意义,并且缺乏强大的多尺度时间或空间结构.我们在这里检查的三个医学数据集都代表表格样式的数据。在所有三个数据集中,梯度提升树的性能都优于纯深度学习和线性回归(图 2A)。


由于模型不匹配效应,基于树的模型也可能比线性模型更容易解释。众所周知,机器学习中的偏差/方差权衡对模型准确性有影响。但不太被理解的是,这种权衡也会影响可解释性。简单的高偏差模型(如线性模型)似乎很容易理解,但它们对模型不匹配很敏感,即模型的形式与其在数据中的真实关系不匹配。这种不匹配可能会产生难以解释的模型伪影。


为了说明为什么低偏差模型比高偏差模型更容易解释,我们使用死亡率数据集将梯度提升树与线性 logistic 回归进行了比较。我们根据参与者的年龄和体重指数 (BMI) 模拟二元结果,并改变模拟关系中的非线性量(图 2B)。正如预期的那样,线性模型的偏差会增加非线性,从而导致精度下降(图 2C)。也许出乎意料的是,它还会导致可解释性下降(图 2D)。我们知道模型应该只取决于年龄和 BMI,但即使是真实关系中适度的非线性也会导致线性模型开始使用其他不相关的特征(图 2D),并且这些特征的权重是由不易解释的复杂抵消效应驱动的.当线性模型依赖于不相关特征之间的抵消效应时,函数本身并不复杂,但它所依赖的特征的含义变得微妙:它们不再主要用于边际效应,而是用于交互效应。因此,即使更简单的高偏差模型实现了高精度,低偏差模型也可能更可取,甚至更具可解释性,因为它们可能更好地代表真实的数据生成机制,并且更自然地依赖于它们的输入特征。

05
树的局部解释

当前基于树的模型的局部解释不一致。据我们所知,只有两种特定于树的方法可以量化特征对单个预测的局部重要性。第一种是简单地报告决策路径,这对于许多树的集成没有帮助。第二种是未发表的启发式方法(由Saabas  提出),它通过遵循决策路径并将模型预期输出的变化归因于路径上的每个特征来解释预测。Saabas 方法尚未得到很好的研究,我们在这里证明,根据特征与树根的距离来改变特征的影响是有偏见的。这种偏差使 Saabas 值不一致,其中增加模型对某个特征的依赖性实际上可能会降低该特征的 Saabas 值。这与有效的归因方法应该做的事情相反。我们通过检查表示多向 AND 函数的树来显示这种差异,对于这些函数,任何特征的功劳都不应超过另一个特征。然而,Saabas 值给出的根附近的分裂比叶附近的分裂要少得多。一致性对于解释方法至关重要,因为它使特征重要性值之间的比较变得有意义。


与模型无关的局部解释方法缓慢且多变.虽然与模型无关的局部解释方法可以解释树模型,但它们依赖于任意函数的事后建模,因此在应用于具有许多输入特征的模型时可能会很慢和/或受到采样可变性的影响。为了说明这一点,我们生成了大小不断增加的随机数据集,然后解释了具有 1,000 棵树的(过)拟合 XGBoost 模型。此实验的运行时间显示,随着特征数量的增加,复杂性呈线性增加;与模型无关的方法需要大量时间在这些数据集上运行,即使我们允许非平凡的估计可变性并且只使用了中等数量的特征。虽然通常对于单个解释很实用,但与模型无关的方法对于解释整个数据集很快就会变得不切实际。


TreeExplainer 提供快速的本地解释,并保证一致性。它通过将精确 Shapley 值计算的复杂性从指数时间降低到多项式时间,将理论与实践联系起来。这很重要,因为在加法特征归因方法类中,我们已经展示的一类包含许多以前的局部特征归因方法 ,博弈论的结果表明 Shapley 值是满足三个重要属性的唯一方法:局部准确性、一致性和缺失性。局部准确性(在博弈论中称为加性)指出,当为特定输入 x 近似原始模型 f 时,解释的归因值应求和为输出 f(x)。一致性(在博弈论中称为单调性)指出,如果模型发生变化,使得某些特征的贡献增加或保持不变,而不管其他输入如何,则该输入的归因不应减少。缺失性(博弈论中的零效应和对称性)是所有先前解释方法都满足的微不足道的属性。


TreeExplainer 通过利用基于树的模型的内部结构,可以在低阶多项式时间内精确计算 Shapley 值。Shapley 值需要所有可能的特征子集的项求和,TreeExplainer 将此求和折叠为一组特定于树中每个叶子的计算(方法)。这表示与以前的精确 Shapley 方法相比,复杂性呈指数级改进。为了计算 Shapley 值计算过程中特定特征子集的影响,TreeExplainer 对用户提供的背景数据集使用干预性期望 [11]。但它也可以通过仅依赖存储在模型中的路径覆盖率信息(通常来自训练数据集)来避免对用户提供的后台数据集的需求。


高效、准确地计算 Shapley 值可以保证解释始终一致且局部准确,从而在以下几个方面比以前的局部解释方法改进结果:

公正的要素信用分配,而不考虑树深度。与 Saabas 值相比,Shapley 值在参与多向 AND 运算的所有特征之间均匀分配信用,从而避免了不一致问题。

无估计变异性。由于与模型无关的抽样方法的解是近似的,因此 TreeExplainer 的精确解释消除了检查它们的收敛性和在估计中接受一定量噪声(来自选择背景数据集的噪声除外)的额外负担。

强劲的基准性能(图 3)。我们设计了 15 个指标来全面评估局部解释方法的性能;我们将这些指标应用于三种不同模型类型和三个数据集的十种不同的解释方法。慢性肾病数据集的结果(如图 3 所示)表明 TreeExplainer 的性能得到了一致的改进。

与人类直觉的一致性。我们通过将解释方法的输出与基于简单模型的 12 种场景的人类共识解释进行比较,评估了解释方法与人类直觉的匹配程度。与启发式 Saabas 值不同,基于 Shapley 值的解释方法在所有测试场景中都与人类直觉一致。


TreeExplainer 还扩展了局部解释以测量交互效应。传统上,基于特征属性的本地解释会为每个输入特征分配一个数字。这种自然表示的简单性是以混淆主效应和交互效应为代价的。虽然特征之间的交互效应可以反映在许多局部解释的全局模式中,但它们与主效应的区别在每个局部解释中都丢失了(图4B-G)。


我们建议将 SHAP 交互值作为一种更丰富的局部解释类型。这些值使用博弈论中的“Shapley 交互指数”来捕获局部交互效应。它们遵循原始 Shapley 价值属性的泛化,不仅在游戏的每个玩家之间分配信用,而且在所有对子的玩家之间分配信用。SHAP 交互值由特征属性矩阵组成(交互作用对非对角线的影响和对对角线的其余影响)。通过为单个模型预测单独考虑交互效应,TreeExplainer 可以发现可能被遗漏的重要模式。

06
局部解释是全局解释的基石

以前理解模型全局的方法侧重于使用简单的全局近似 、寻找新的可解释特征 或量化深度网络中特定内部节点的影响 。我们提出了一些方法,这些方法结合了许多局部解释,以提供对模型行为的全局洞察。这样,我们就可以在保持对模型的局部忠实度,同时仍能捕获全局模式,从而更丰富、更准确地表示模型的行为。


局部模型总结揭示了对死亡风险的罕见高幅度影响,并增加了特征选择能力。将 TreeExplainer 中来自整个数据集的局部解释组合起来,通过以下方式增强了特征重要性的传统全局表示:(1) 避免了当前方法的不一致问题,(2) 提高了检测数据集中真正特征依赖关系的能力,以及 (3) 使我们能够构建 SHAP 摘要图,它简洁地显示特征效果的大小、普遍性和方向。SHAP 汇总图避免将效应的大小和普遍性混为一个数字,从而揭示罕见的高幅度效应。图 4A(右)揭示了影响的方向,例如男性(蓝色)的死亡风险高于女性(红色);以及效应大小的分布,例如许多医学检验值的长右尾。这些长尾意味着全局重要性较低的特征对于特定个体来说可能非常重要。有趣的是,罕见死亡率效应总是向右延伸,这意味着当医学测量超出范围时,有很多方法可以异常提前死亡,但异常延长寿命的方法并不多。


局部特征依赖性揭示了死亡风险和慢性肾病的全球模式和个体变异性。SHAP 依赖性图显示了特征的值(x 轴)如何影响数据集中每个样本(每个点)的预测(y 轴)(图 4B 和 E)。它们提供了比传统的部分依赖图更丰富的信息。对于死亡率模型,SHAP 依赖图再现了收缩压的标准风险拐点,同时也强调了血压对不同年龄人群死亡风险的影响不同(图 4B)。这些类型的交互作用效应在 SHAP 依赖性图中显示为垂直离散。


对于慢性肾病模型,依赖图再次清楚地揭示了收缩压的风险拐点。然而,在这个数据集中,相互作用效应的垂直分散似乎部分是由血尿素氮的差异驱动的(图 4E)。在保持可解释性的同时正确模拟血压风险至关重要,因为特定慢性肾脏病 (CKD) 人群的血压控制可能会延缓肾脏疾病的进展并降低心血管事件的风险。


局部相互作用揭示了衰老过程中性别特异性预期寿命的变化以及慢性肾病的炎症影响。使用 SHAP 交互值,我们可以将特征对特定样本的影响分解为与其他特征的交互效应。这有助于我们测量全局交互强度,并将 SHAP 依赖性图分解为局部(即每个样本)水平的交互效应(图 4B-D)。在死亡率数据集中,绘制年龄和性别之间的 SHAP 交互值显示男性和女性在一生中的相对风险发生了明显变化(图 4G)。男性和女性之间的最大风险差异发生在 60 岁;男性风险增加可能是由于相对于接近该年龄的女性,他们的心血管死亡率增加。如果没有 SHAP 交互值,则无法清楚地捕捉到这种模式,因为男性总是比女性带来更大的死亡风险(图 4A)。


在慢性肾病模型中,我们发现了一个有趣的相互作用(图 4F):当高白细胞计数伴有高血尿素氮时,模型更关心它们。这支持了炎症可能与高血尿素氮相互作用以加速肾功能下降的观点 。


本地模型监控揭示了已部署机器学习模型以前看不见的问题。使用 TreeExplainer 来解释模型的损失,而不是模型的预测,可以提高我们监控已部署模型的能力。监控模型具有挑战性,因为输入和模型目标之间的关系在部署后可以通过多种方式发生变化。检测此类变化何时发生是很困难的,因此机器学习管道中的许多错误都没有被检测到,即使是在顶级科技公司的核心软件中也是如此 。我们演示了本地模型监控有助于调试模型部署,并通过分解模型输入特征之间的损失来直接识别有问题的特征(如果有)。


我们使用医院手术持续时间数据集模拟了模型部署,使用第一年的数据进行训练,并使用接下来的三年数据进行部署。我们举了三个例子:一个是故意的错误和两个以前未被发现的问题。(1) 我们在部署过程中特意交换了 6 号和 13 号手术室的标签,以模拟典型的 Feature Pipeline 错误。模型预测的总体损失没有给出错误的指示(图 5A),而 6 号房间特征的 SHAP 监控图清楚地识别了标记错误(图 5B)。(2) 图 5C 显示了部署窗口开始后不久全身麻醉功能的误差峰值。此峰值对应于受以前未发现的临时电子病历配置问题影响的程序子集。(3) 图 5D 显示了特征随时间漂移的示例,而不是处理错误的示例。在训练期间和部署初期,使用“心房颤动”功能可降低损失;但是,随着时间的推移,该功能的用处逐渐降低,并最终降低模型的性能。我们发现这种漂移是由技术和人员变化驱动的心房颤动消融手术持续时间的显着变化引起的。当前的部署实践既监测模型随时间推移的总体损失(图 5A),也监测有关输入特征的潜在统计数据。相反,TreeExplainer 让我们直接监控单个特征对模型损失的影响。


局部解释嵌入揭示了与慢性肾病的死亡风险和补充诊断指标相关的人群亚组。无监督聚类和降维被广泛用于发现表征样本亚组(例如研究参与者)的模式,例如疾病亚型。这些技术有两个缺点:1) 距离度量没有考虑特征的单位/含义之间的差异(例如,权重与年龄),以及 2) 无监督方法无法知道哪些特征与感兴趣的结果相关,因此应该更强地加权。我们使用本地解释嵌入来解决这两个限制,将每个样本嵌入到一个新的“解释空间”中。在这个新空间中运行聚类将产生一个监督聚类,其中样本根据其解释进行分组。监督聚类自然会考虑各种特征的不同单元,仅突出显示与特定结果相关的更改。


使用死亡率模型运行分层监督聚类会导致许多人群由于类似的原因而具有相似的死亡风险(图 6A)。这种样本分组可以揭示数据集中的高级结构,而这些结构是使用标准无监督聚类无法揭示的,并且具有各种应用,从客户细分到模型调试,再到疾病子类型。类似地,我们也可以在慢性肾病样本的局部解释嵌入上运行 PCA。这揭示了识别处于终末期肾病风险的独特个体的两大类主要风险因素:(1) 基于尿液测量的因素,以及 (2) 基于血液测量的因素(图 6B-D)。这种模式值得注意,因为它会随着我们绘制更多主要特征而继续。血液和尿液特征之间的分离与临床上应同时测量这些因素的事实是一致的。这种对肾脏风险整体结构的洞察在标准的无监督嵌入中根本不明显。




END


扫码关注

公众号:小猪的科研生活



排版:王倩倩

文字:王倩倩


小猪的科研生活
分享日常科研生活和统计以及机器学习知识
 最新文章