Census income classification with Keras
To download a copy of this notebook visit github.
[1]:
from keras.layers import (
    Dense,
    Dropout,
    Flatten,
    Input,
    concatenate,
)
from keras.layers.embeddings import Embedding
from keras.models import Model
from sklearn.model_selection import train_test_split

import shap

# print the JS visualization code to the notebook
shap.initjs()
Using TensorFlow backend.
Load dataset
[2]:
X, y = shap.datasets.adult()
X_display, y_display = shap.datasets.adult(display=True)
# normalize data (this is important for model convergence)
dtypes = list(zip(X.dtypes.index, map(str, X.dtypes)))
for k, dtype in dtypes:
    if dtype == "float32":
        X[k] -= X[k].mean()
        X[k] /= X[k].std()
X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, random_state=7)
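The normalization above only touches the float32 columns; the int8 columns are categorical codes and are handled with embeddings in the model below. A quick look at the dtypes list defined above (a minimal sketch, not part of the original notebook) shows which column falls into which group:

# Which features are categorical (int8, embedded later) vs. continuous (float32, normalized above)?
for name, dtype in dtypes:
    print(f"{name:20s} {dtype}")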
Train Keras model
[3]:
# build model
input_els = []
encoded_els = []
for k, dtype in dtypes:
    input_els.append(Input(shape=(1,)))
    if dtype == "int8":
        # categorical feature: learn a 1-dimensional embedding per category
        e = Flatten()(Embedding(X_train[k].max() + 1, 1)(input_els[-1]))
    else:
        # continuous feature: pass the (normalized) value through unchanged
        e = input_els[-1]
    encoded_els.append(e)
encoded_els = concatenate(encoded_els)
layer1 = Dropout(0.5)(Dense(100, activation="relu")(encoded_els))
out = Dense(1)(layer1)

# train model
regression = Model(inputs=input_els, outputs=[out])
regression.compile(optimizer="adam", loss="binary_crossentropy")
regression.fit(
    [X_train[k].values for k, t in dtypes],
    y_train,
    epochs=50,
    batch_size=512,
    shuffle=True,
    validation_data=([X_valid[k].values for k, t in dtypes], y_valid),
)
Train on 26048 samples, validate on 6513 samples
Epoch 1/50
26048/26048 [==============================] - 1s 28us/step - loss: 2.3308 - val_loss: 0.4450
Epoch 2/50
26048/26048 [==============================] - 0s 8us/step - loss: 1.5018 - val_loss: 0.5353
Epoch 3/50
26048/26048 [==============================] - 0s 9us/step - loss: 1.3662 - val_loss: 0.5634
Epoch 4/50
26048/26048 [==============================] - 0s 8us/step - loss: 1.3522 - val_loss: 0.6502
Epoch 5/50
26048/26048 [==============================] - 0s 8us/step - loss: 1.3053 - val_loss: 0.5451
Epoch 6/50
26048/26048 [==============================] - 0s 8us/step - loss: 1.2348 - val_loss: 0.5146
Epoch 7/50
26048/26048 [==============================] - 0s 8us/step - loss: 1.2083 - val_loss: 0.4880
Epoch 8/50
26048/26048 [==============================] - 0s 9us/step - loss: 1.2280 - val_loss: 0.7679
Epoch 9/50
26048/26048 [==============================] - 0s 8us/step - loss: 1.1979 - val_loss: 0.4658
Epoch 10/50
26048/26048 [==============================] - 0s 7us/step - loss: 1.1313 - val_loss: 0.5112
Epoch 11/50
26048/26048 [==============================] - 0s 8us/step - loss: 1.1138 - val_loss: 0.5580
Epoch 12/50
26048/26048 [==============================] - 0s 9us/step - loss: 1.2020 - val_loss: 0.4981
Epoch 13/50
26048/26048 [==============================] - 0s 7us/step - loss: 1.0844 - val_loss: 0.4940
Epoch 14/50
26048/26048 [==============================] - 0s 10us/step - loss: 1.0802 - val_loss: 0.5090
Epoch 15/50
26048/26048 [==============================] - 0s 7us/step - loss: 1.0761 - val_loss: 0.5058
Epoch 16/50
26048/26048 [==============================] - 0s 7us/step - loss: 1.0470 - val_loss: 0.5143
Epoch 17/50
26048/26048 [==============================] - 0s 7us/step - loss: 1.0285 - val_loss: 0.5553
Epoch 18/50
26048/26048 [==============================] - 0s 7us/step - loss: 1.0215 - val_loss: 0.5479
Epoch 19/50
26048/26048 [==============================] - 0s 8us/step - loss: 1.0137 - val_loss: 0.5628
Epoch 20/50
26048/26048 [==============================] - 0s 8us/step - loss: 1.0022 - val_loss: 0.5426
Epoch 21/50
26048/26048 [==============================] - 0s 8us/step - loss: 0.9641 - val_loss: 0.5291
Epoch 22/50
26048/26048 [==============================] - 0s 8us/step - loss: 0.9765 - val_loss: 0.7090
Epoch 23/50
26048/26048 [==============================] - 0s 8us/step - loss: 1.0097 - val_loss: 0.4819
Epoch 24/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.9100 - val_loss: 0.4874
Epoch 25/50
26048/26048 [==============================] - 0s 9us/step - loss: 0.8821 - val_loss: 0.4724
Epoch 26/50
26048/26048 [==============================] - 0s 9us/step - loss: 0.8653 - val_loss: 0.5671
Epoch 27/50
26048/26048 [==============================] - 0s 8us/step - loss: 1.0496 - val_loss: 0.6884
Epoch 28/50
26048/26048 [==============================] - 0s 8us/step - loss: 0.9529 - val_loss: 0.5993
Epoch 29/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.9255 - val_loss: 0.5297
Epoch 30/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.8726 - val_loss: 0.4880
Epoch 31/50
26048/26048 [==============================] - 0s 8us/step - loss: 0.8523 - val_loss: 0.4730
Epoch 32/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.8526 - val_loss: 0.4683
Epoch 33/50
26048/26048 [==============================] - 0s 8us/step - loss: 0.7988 - val_loss: 0.4655
Epoch 34/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.7920 - val_loss: 0.4560
Epoch 35/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.7629 - val_loss: 0.4449
Epoch 36/50
26048/26048 [==============================] - 0s 8us/step - loss: 0.7506 - val_loss: 0.4388
Epoch 37/50
26048/26048 [==============================] - 0s 8us/step - loss: 0.7266 - val_loss: 0.4366
Epoch 38/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.7460 - val_loss: 0.4239
Epoch 39/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.7268 - val_loss: 0.4159
Epoch 40/50
26048/26048 [==============================] - 0s 10us/step - loss: 0.7199 - val_loss: 0.4025
Epoch 41/50
26048/26048 [==============================] - 0s 9us/step - loss: 0.6725 - val_loss: 0.4090
Epoch 42/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.7740 - val_loss: 0.4576
Epoch 43/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.7491 - val_loss: 0.4111
Epoch 44/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.6639 - val_loss: 0.4068
Epoch 45/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.6734 - val_loss: 0.4218
Epoch 46/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.6580 - val_loss: 0.3993
Epoch 47/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.6516 - val_loss: 0.4000
Epoch 48/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.6464 - val_loss: 0.3989
Epoch 49/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.6258 - val_loss: 0.4004
Epoch 50/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.6157 - val_loss: 0.4005
[3]:
<keras.callbacks.History at 0x10d720390>
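Before explaining predictions it can be useful to sanity-check the fit. A minimal sketch (not part of the original notebook): compute the ROC AUC on the validation split from the raw model outputs. The final Dense layer has no activation, so these are unnormalized scores rather than probabilities, which is fine for AUC.

from sklearn.metrics import roc_auc_score

# Raw scores from the model on the validation split.
val_scores = regression.predict([X_valid[k].values for k, t in dtypes]).flatten()
print("validation ROC AUC:", roc_auc_score(y_valid, val_scores))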
Explain predictions
Here we take the Keras model trained above and explain why it makes different predictions for different individuals. SHAP expects model functions to take a 2D numpy array as input, so we define a wrapper function around the original Keras prediction function.
[4]:
def f(X):
    return regression.predict([X[:, i] for i in range(X.shape[1])]).flatten()
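A quick sanity check of the wrapper (a sketch, not in the original notebook): f should accept a 2D numpy array and return a 1D array of scores, which is the interface KernelExplainer expects.

# The wrapper takes an (n_samples, n_features) array and returns an (n_samples,) array.
print(f(X.values[:3]).shape)  # expected: (3,)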
Explain a single prediction
Here we use a selection of 50 samples from the dataset to represent "typical" feature values, and then use 500 perturbation samples to estimate the SHAP values for a given prediction. Note that this requires 500 * 50 evaluations of the model.
[5]:
explainer = shap.KernelExplainer(f, X.iloc[:50, :])
shap_values = explainer.shap_values(X.iloc[299, :], nsamples=500)
shap.force_plot(explainer.expected_value, shap_values, X_display.iloc[299, :])
[5]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security.
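As a variant (not used in this notebook), the background data passed to KernelExplainer can be summarized with shap.kmeans instead of taking the first 50 rows, which makes the "typical" feature values less dependent on row order while keeping the cost per explanation similar.

# Summarize the background with k-means (weighted cluster means) instead of the first 50 rows.
background = shap.kmeans(X, 50)
explainer_km = shap.KernelExplainer(f, background)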
Explain many predictions
Here we repeat the above explanation process for 50 individuals. Since we are using a sampling-based approximation, each explanation can take a couple of seconds depending on your machine setup.
[6]:
shap_values50 = explainer.shap_values(X.iloc[280:330, :], nsamples=500)
100%|██████████| 50/50 [00:53<00:00, 1.08s/it]
[7]:
shap.force_plot(explainer.expected_value, shap_values50, X_display.iloc[280:330, :])
[7]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security.
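When the interactive force plot cannot be rendered (as in a static view of this page), a static summary of the same 50 explanations can be drawn with shap.summary_plot (a sketch, not part of the original notebook):

# Static alternative to the interactive force plot: rank features by the
# magnitude of their SHAP values across the 50 explained rows.
shap.summary_plot(shap_values50, X.iloc[280:330, :])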