迁移到新的“Explanation” API
此笔记本演示了 shap API 在 shap v0.36.0 版本中引入的一些更改。
[1]:
# An example dataset and model
import xgboost
import shap
X, y = shap.datasets.adult(n_points=100)
model = xgboost.XGBClassifier().fit(X, y)
explainer = shap.TreeExplainer(model, X)
总结 API 的主要更改
[2]:
shap_values = explainer.shap_values(X)  # Old style
explanation = explainer(X)  # New style
计算解释
旧样式
在 shap v0.36.0 之前的版本中,解释表示为简单的 numpy 数组,并使用解释器的 .shap_values() 方法计算
[3]:
shap_values = explainer.shap_values(X)
shap_values[:2]  # a numpy array
[3]:
array([[-0.54854601,  0.01639348, -0.46476041,  0.85896822, -1.36168788,
        -0.64692199,  0.0254638 , -0.58422904, -0.02344483,  0.        ,
         0.1224989 ,  0.01079906],
       [-0.83802091,  0.01562196,  0.78349799, -1.10456323, -0.68524691,
        -0.84828204,  0.03734176, -0.86151311, -0.02564897,  0.        ,
        -0.56183428,  0.00415988]])
类似地,旧的绘图函数(如 shap.summary_plot)期望 shap_values 为 numpy 数组。
新样式
从 shap v0.36.0 版本开始,解释现在使用 Explanation 对象表示,并通过直接将解释器作为函数调用来创建
[4]:
explanation = explainer(X)
explanation[:2]  # a shap.Explanation object
[4]:
.values =
array([[-0.54854601,  0.01639348, -0.46476041,  0.85896822, -1.36168788,
        -0.64692199,  0.0254638 , -0.58422904, -0.02344483,  0.        ,
         0.1224989 ,  0.01079906],
       [-0.83802091,  0.01562196,  0.78349799, -1.10456323, -0.68524691,
        -0.84828204,  0.03734176, -0.86151311, -0.02564897,  0.        ,
        -0.56183428,  0.00415988]])
.base_values =
array([-2.70354599, -2.70354599])
.data =
array([[27.,  4., 10.,  0.,  1.,  1.,  4.,  0.,  0.,  0., 44., 39.],
       [27.,  4., 13.,  4., 10.,  0.,  4.,  0.,  0.,  0., 40., 39.]])
shap.Explanation 对象是一个更丰富的表示形式,它包括 shap 值(可通过 .values 属性访问),以及支持上下文信息,例如背景数据集和特征名称。
新样式的绘图函数(如 shap.plot.bar 和 shap.plots.beeswarm)接受这些 Explanation 对象,而不是 numpy 数组。