使用 scikit-learn 进行鸢尾花分类
在这里,我们使用著名的鸢尾花物种数据集来说明 SHAP 如何解释多种不同模型类型的输出,从 K-最近邻到神经网络。这是一个非常小的数据集,只有 150 个样本。我们随机使用 130 个样本进行训练,20 个样本用于测试模型。由于这是一个特征很少的小数据集,我们使用整个训练数据集作为背景数据。在特征更多的问题中,我们可能只想传递训练数据集的中位数或加权的 K-中位数。虽然我们只有少量样本,但这个预测问题相当简单,所有方法都达到了完美的准确率。有趣的是,不同的方法有时会依赖于不同的特征集来进行预测。
加载数据
[1]:
import time
import numpy as np
import sklearn
from sklearn.model_selection import train_test_split
import shap
X_train, X_test, Y_train, Y_test = train_test_split(*shap.datasets.iris(), test_size=0.2, random_state=0)
# rather than use the whole training set to estimate expected values, we could summarize with
# a set of weighted kmeans, each weighted by the number of points they represent. But this dataset
# is so small we don't worry about it
# X_train_summary = shap.kmeans(X_train, 50)
def print_accuracy(f):
    print(f"Accuracy = {100 * np.sum(f(X_test) == Y_test) / len(Y_test)}%")
    time.sleep(0.5)  # to let the print get out before any progress bars
shap.initjs()
K-最近邻
[2]:
knn = sklearn.neighbors.KNeighborsClassifier()
knn.fit(X_train, Y_train)
print_accuracy(knn.predict)
Accuracy = 96.66666666666667%
解释测试集中的单个预测
[3]:
explainer = shap.KernelExplainer(knn.predict_proba, X_train)
shap_values = explainer.shap_values(X_test.iloc[0, :])
shap.force_plot(explainer.expected_value[0], shap_values[:, 0], X_test.iloc[0, :])
Using 120 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.
[3]:
  可视化已省略,Javascript 库未加载!
你是否已在此 notebook 中运行 `initjs()`?如果此 notebook 来自其他用户,你还必须信任此 notebook (文件 -> 信任 notebook)。如果你在 github 上查看此 notebook,则 Javascript 已为安全起见被剥离。如果你正在使用 JupyterLab,此错误是因为尚未编写 JupyterLab 扩展。
你是否已在此 notebook 中运行 `initjs()`?如果此 notebook 来自其他用户,你还必须信任此 notebook (文件 -> 信任 notebook)。如果你在 github 上查看此 notebook,则 Javascript 已为安全起见被剥离。如果你正在使用 JupyterLab,此错误是因为尚未编写 JupyterLab 扩展。
解释测试集中的所有预测
[4]:
shap_values = explainer.shap_values(X_test)
shap.force_plot(explainer.expected_value[0], shap_values[..., 0], X_test)
[4]:
  可视化已省略,Javascript 库未加载!
你是否已在此 notebook 中运行 `initjs()`?如果此 notebook 来自其他用户,你还必须信任此 notebook (文件 -> 信任 notebook)。如果你在 github 上查看此 notebook,则 Javascript 已为安全起见被剥离。如果你正在使用 JupyterLab,此错误是因为尚未编写 JupyterLab 扩展。
你是否已在此 notebook 中运行 `initjs()`?如果此 notebook 来自其他用户,你还必须信任此 notebook (文件 -> 信任 notebook)。如果你在 github 上查看此 notebook,则 Javascript 已为安全起见被剥离。如果你正在使用 JupyterLab,此错误是因为尚未编写 JupyterLab 扩展。
支持向量机(线性核)
[5]:
svc_linear = sklearn.svm.SVC(kernel="linear", probability=True)
svc_linear.fit(X_train, Y_train)
print_accuracy(svc_linear.predict)
# explain all the predictions in the test set
explainer = shap.KernelExplainer(svc_linear.predict_proba, X_train)
shap_values = explainer.shap_values(X_test)
# this is multiclass so we only visualize the contributions to first class (hence index 0)
shap.force_plot(explainer.expected_value[0], shap_values[..., 0], X_test)
Accuracy = 100.0%
Using 120 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.
[5]:
  可视化已省略,Javascript 库未加载!
你是否已在此 notebook 中运行 `initjs()`?如果此 notebook 来自其他用户,你还必须信任此 notebook (文件 -> 信任 notebook)。如果你在 github 上查看此 notebook,则 Javascript 已为安全起见被剥离。如果你正在使用 JupyterLab,此错误是因为尚未编写 JupyterLab 扩展。
你是否已在此 notebook 中运行 `initjs()`?如果此 notebook 来自其他用户,你还必须信任此 notebook (文件 -> 信任 notebook)。如果你在 github 上查看此 notebook,则 Javascript 已为安全起见被剥离。如果你正在使用 JupyterLab,此错误是因为尚未编写 JupyterLab 扩展。
支持向量机(径向基函数核)
[6]:
svc_linear = sklearn.svm.SVC(kernel="rbf", probability=True)
svc_linear.fit(X_train, Y_train)
print_accuracy(svc_linear.predict)
# explain all the predictions in the test set
explainer = shap.KernelExplainer(svc_linear.predict_proba, X_train)
shap_values = explainer.shap_values(X_test)
# this is multiclass so we only visualize the contributions to first class (hence index 0)
shap.force_plot(explainer.expected_value[0], shap_values[..., 0], X_test)
Accuracy = 100.0%
Using 120 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.
[6]:
  可视化已省略,Javascript 库未加载!
你是否已在此 notebook 中运行 `initjs()`?如果此 notebook 来自其他用户,你还必须信任此 notebook (文件 -> 信任 notebook)。如果你在 github 上查看此 notebook,则 Javascript 已为安全起见被剥离。如果你正在使用 JupyterLab,此错误是因为尚未编写 JupyterLab 扩展。
你是否已在此 notebook 中运行 `initjs()`?如果此 notebook 来自其他用户,你还必须信任此 notebook (文件 -> 信任 notebook)。如果你在 github 上查看此 notebook,则 Javascript 已为安全起见被剥离。如果你正在使用 JupyterLab,此错误是因为尚未编写 JupyterLab 扩展。
逻辑回归
[7]:
linear_lr = sklearn.linear_model.LogisticRegression(solver="newton-cg")
linear_lr.fit(X_train, Y_train)
print_accuracy(linear_lr.predict)
# explain all the predictions in the test set
explainer = shap.KernelExplainer(linear_lr.predict_proba, X_train)
shap_values = explainer.shap_values(X_test)
shap.force_plot(explainer.expected_value[0], shap_values[..., 0], X_test)
Accuracy = 100.0%
Using 120 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.
[7]:
  可视化已省略,Javascript 库未加载!
你是否已在此 notebook 中运行 `initjs()`?如果此 notebook 来自其他用户,你还必须信任此 notebook (文件 -> 信任 notebook)。如果你在 github 上查看此 notebook,则 Javascript 已为安全起见被剥离。如果你正在使用 JupyterLab,此错误是因为尚未编写 JupyterLab 扩展。
你是否已在此 notebook 中运行 `initjs()`?如果此 notebook 来自其他用户,你还必须信任此 notebook (文件 -> 信任 notebook)。如果你在 github 上查看此 notebook,则 Javascript 已为安全起见被剥离。如果你正在使用 JupyterLab,此错误是因为尚未编写 JupyterLab 扩展。
决策树
[8]:
import sklearn.tree
dtree = sklearn.tree.DecisionTreeClassifier(min_samples_split=2)
dtree.fit(X_train, Y_train)
print_accuracy(dtree.predict)
# explain all the predictions in the test set
explainer = shap.KernelExplainer(dtree.predict_proba, X_train)
shap_values = explainer.shap_values(X_test)
shap.force_plot(explainer.expected_value[0], shap_values[..., 0], X_test)
Accuracy = 100.0%
Using 120 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.
[8]:
  可视化已省略,Javascript 库未加载!
你是否已在此 notebook 中运行 `initjs()`?如果此 notebook 来自其他用户,你还必须信任此 notebook (文件 -> 信任 notebook)。如果你在 github 上查看此 notebook,则 Javascript 已为安全起见被剥离。如果你正在使用 JupyterLab,此错误是因为尚未编写 JupyterLab 扩展。
你是否已在此 notebook 中运行 `initjs()`?如果此 notebook 来自其他用户,你还必须信任此 notebook (文件 -> 信任 notebook)。如果你在 github 上查看此 notebook,则 Javascript 已为安全起见被剥离。如果你正在使用 JupyterLab,此错误是因为尚未编写 JupyterLab 扩展。
随机森林
[9]:
from sklearn.ensemble import RandomForestClassifier
rforest = RandomForestClassifier(n_estimators=100, max_depth=None, min_samples_split=2, random_state=0)
rforest.fit(X_train, Y_train)
print_accuracy(rforest.predict)
# explain all the predictions in the test set
explainer = shap.KernelExplainer(rforest.predict_proba, X_train)
shap_values = explainer.shap_values(X_test)
shap.force_plot(explainer.expected_value[0], shap_values[..., 0], X_test)
Accuracy = 100.0%
Using 120 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.
[9]:
  可视化已省略,Javascript 库未加载!
你是否已在此 notebook 中运行 `initjs()`?如果此 notebook 来自其他用户,你还必须信任此 notebook (文件 -> 信任 notebook)。如果你在 github 上查看此 notebook,则 Javascript 已为安全起见被剥离。如果你正在使用 JupyterLab,此错误是因为尚未编写 JupyterLab 扩展。
你是否已在此 notebook 中运行 `initjs()`?如果此 notebook 来自其他用户,你还必须信任此 notebook (文件 -> 信任 notebook)。如果你在 github 上查看此 notebook,则 Javascript 已为安全起见被剥离。如果你正在使用 JupyterLab,此错误是因为尚未编写 JupyterLab 扩展。
神经网络
[10]:
from sklearn.neural_network import MLPClassifier
nn = MLPClassifier(solver="lbfgs", alpha=1e-1, hidden_layer_sizes=(5, 2), random_state=0)
nn.fit(X_train, Y_train)
print_accuracy(nn.predict)
# explain all the predictions in the test set
explainer = shap.KernelExplainer(nn.predict_proba, X_train)
shap_values = explainer.shap_values(X_test)
shap.force_plot(explainer.expected_value[0], shap_values[..., 0], X_test)
Accuracy = 96.66666666666667%
Using 120 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.
[10]:
  可视化已省略,Javascript 库未加载!
你是否已在此 notebook 中运行 `initjs()`?如果此 notebook 来自其他用户,你还必须信任此 notebook (文件 -> 信任 notebook)。如果你在 github 上查看此 notebook,则 Javascript 已为安全起见被剥离。如果你正在使用 JupyterLab,此错误是因为尚未编写 JupyterLab 扩展。
你是否已在此 notebook 中运行 `initjs()`?如果此 notebook 来自其他用户,你还必须信任此 notebook (文件 -> 信任 notebook)。如果你在 github 上查看此 notebook,则 Javascript 已为安全起见被剥离。如果你正在使用 JupyterLab,此错误是因为尚未编写 JupyterLab 扩展。