使用 scikit-learn 进行人口普查收入分类
此示例使用来自 UCI 机器学习数据仓库的标准成人人口普查收入数据集。我们使用 scikit-learn 训练了一个 k 近邻分类器,然后解释了预测结果。
[1]:
import sklearn
import shap
加载人口普查数据
[2]:
X, y = shap.datasets.adult()
X["Occupation"] *= 1000 # to show the impact of feature scale on KNN predictions
X_display, y_display = shap.datasets.adult(display=True)
X_train, X_valid, y_train, y_valid = sklearn.model_selection.train_test_split(X, y, test_size=0.2, random_state=7)
训练 k 近邻分类器
在这里,我们直接在数据上进行训练,没有任何标准化。
[4]:
knn = sklearn.neighbors.KNeighborsClassifier()
knn.fit(X_train, y_train)
[4]:
KNeighborsClassifier()
解释预测
通常我们会使用 logit 链接函数,以便加性特征输入更好地映射到模型的概率输出空间,但 k 近邻可能会产生无限的对数几率比,因此在本示例中我们不使用。
重要的是要注意,在我们解释的 1000 个预测中,职业是主要特征。这是因为它比其他特征的值变化更大,因此对 k 近邻计算的影响更大。
[5]:
def f(x):
return knn.predict_proba(x)[:, 1]
med = X_train.median().values.reshape((1, X_train.shape[1]))
explainer = shap.Explainer(f, med)
shap_values = explainer(X_valid.iloc[0:1000, :])
Permutation explainer: 1001it [00:25, 38.69it/s]
[5]:
shap.plots.waterfall(shap_values[0])

摘要蜜蜂图是查看所有特征在整个数据集上的相对影响的更好方法。特征按其 SHAP 值幅度在所有样本上的总和排序。
[7]:
shap.plots.beeswarm(shap_values)

热图提供了模型行为的另一个全局视图,这次重点关注人口子群体。
[8]:
shap.plots.heatmap(shap_values)

在训练模型之前标准化数据
在这里,我们在标准化数据上重新训练了一个 KNN 模型。
[9]:
# normalize data
dtypes = list(zip(X.dtypes.index, map(str, X.dtypes)))
X_train_norm = X_train.copy()
X_valid_norm = X_valid.copy()
for k, dtype in dtypes:
m = X_train[k].mean()
s = X_train[k].std()
X_train_norm[k] -= m
X_train_norm[k] /= s
X_valid_norm[k] -= m
X_valid_norm[k] /= s
[10]:
knn_norm = sklearn.neighbors.KNeighborsClassifier()
knn_norm.fit(X_train_norm, y_train)
[10]:
KNeighborsClassifier()
解释预测
当我们解释来自新的 KNN 模型的预测时,我们发现职业不再是主要特征,而是更具预测性的特征,例如婚姻状况,驱动了大多数预测。这是一个简单的示例,说明解释您的模型为什么做出预测可以揭示训练过程中的问题。
[11]:
def f(x):
return knn_norm.predict_proba(x)[:, 1]
med = X_train_norm.median().values.reshape((1, X_train_norm.shape[1]))
explainer = shap.Explainer(f, med)
shap_values_norm = explainer(X_valid_norm.iloc[0:1000, :])
Permutation explainer: 1001it [01:26, 11.55it/s]
通过摘要图,我们可以看到平均而言婚姻状况是最重要的,但其他特征(例如资本收益)可能对特定个人产生更大的影响。
[12]:
shap.summary_plot(shap_values_norm, X_valid.iloc[0:1000, :])

依赖散点图显示了教育年限如何增加年收入超过 5 万美元的机会。
[14]:
shap.plots.scatter(shap_values_norm[:, "Education-Num"])

有更多有用的示例的想法吗?鼓励提交拉取请求以添加到此文档笔记本!