Tree SHAP 的 Python 版本
这是用 Python 编写的 Tree SHAP 的示例实现,易于阅读。
[1]:
import time
import numba
import numpy as np
import sklearn.ensemble
import xgboost
import shap
加载加州数据集
[2]:
# Load the California housing regression dataset, subsampled to 1000 rows,
# via shap's dataset helper. X is a pandas DataFrame with 8 features; y is the target.
X, y = shap.datasets.california(n_points=1000)
X.shape
[2]:
(1000, 8)
训练 sklearn 随机森林
[3]:
# Train a sklearn random forest regressor; shallow trees (max_depth=4) keep
# the Tree SHAP recursion cheap while 1000 estimators smooth the predictions.
model = sklearn.ensemble.RandomForestRegressor(n_estimators=1000, max_depth=4)
model.fit(X, y)
[3]:
RandomForestRegressor(max_depth=4, n_estimators=1000)在 Jupyter 环境中,请重新运行此单元格以显示 HTML 表示或信任笔记本。
在 GitHub 上,HTML 表示无法渲染,请尝试使用 nbviewer.org 加载此页面。
RandomForestRegressor(max_depth=4, n_estimators=1000)
训练 XGBoost 模型
[4]:
# Train an XGBoost model with comparable settings (1000 rounds, depth 4) so its
# built-in Tree SHAP (pred_contribs) can be benchmarked against the Python version below.
bst = xgboost.train({"learning_rate": 0.01, "max_depth": 4}, xgboost.DMatrix(X, label=y), 1000)
is_sparse is deprecated and will be removed in a future version. Check `isinstance(dtype, pd.SparseDtype)` instead.
is_categorical_dtype is deprecated and will be removed in a future version. Use isinstance(dtype, CategoricalDtype) instead
is_categorical_dtype is deprecated and will be removed in a future version. Use isinstance(dtype, CategoricalDtype) instead
is_categorical_dtype is deprecated and will be removed in a future version. Use isinstance(dtype, CategoricalDtype) instead
Python TreeExplainer
这使用 numba 来加速。
[5]:
class TreeExplainer:
    """Example pure-Python (numba-accelerated) Tree SHAP explainer.

    Only ``sklearn.ensemble.RandomForestRegressor`` models are supported; the
    forest is converted into flat-array ``Tree`` objects and SHAP values are
    computed per tree and averaged (matching the forest's prediction rule).
    """

    def __init__(self, model, **kwargs):
        if str(type(model)).endswith("sklearn.ensemble._forest.RandomForestRegressor'>"):
            self.trees = [
                Tree(
                    children_left=e.tree_.children_left,
                    children_right=e.tree_.children_right,
                    # sklearn trees have no missing-value branch, so route
                    # missing values down the right child by convention
                    children_default=e.tree_.children_right,
                    feature=e.tree_.feature,
                    threshold=e.tree_.threshold,
                    value=e.tree_.value[:, 0, 0],
                    node_sample_weight=e.tree_.weighted_n_node_samples,
                )
                for e in model.estimators_
            ]
        else:
            # Fail fast: previously an unsupported model left this object
            # half-initialized and crashed later with an opaque
            # AttributeError on self.trees.
            raise NotImplementedError("Unsupported model type: " + str(type(model)))

        # Preallocate scratch space for the unique path data shared by all
        # recursive calls; a path of depth d needs d*(d+1)/2 slots, and +2
        # leaves headroom for the root extension.
        maxd = np.max([t.max_depth for t in self.trees]) + 2
        s = (maxd * (maxd + 1)) // 2
        self.feature_indexes = np.zeros(s, dtype=np.int32)
        self.zero_fractions = np.zeros(s, dtype=np.float64)
        self.one_fractions = np.zeros(s, dtype=np.float64)
        self.pweights = np.zeros(s, dtype=np.float64)

    def shap_values(self, X, **kwargs):
        """Return SHAP values for one instance (1-D) or a matrix of instances (2-D).

        The result has one extra trailing column/entry holding the bias
        (expected value) term, i.e. shape ``(n_features + 1,)`` or
        ``(n_samples, n_features + 1)``.
        """
        # convert dataframes / series to raw numpy arrays
        if str(type(X)).endswith("pandas.core.series.Series'>"):
            X = X.values
        elif str(type(X)).endswith("'pandas.core.frame.DataFrame'>"):
            X = X.values

        assert str(type(X)).endswith("'numpy.ndarray'>"), "Unknown instance type: " + str(type(X))
        assert len(X.shape) == 1 or len(X.shape) == 2, "Instance must have 1 or 2 dimensions!"

        # single instance
        if len(X.shape) == 1:
            phi = np.zeros(X.shape[0] + 1)
            x_missing = np.zeros(X.shape[0], dtype=bool)
            for t in self.trees:
                self.tree_shap(t, X, x_missing, phi)
            phi /= len(self.trees)  # forest prediction = mean over trees
        elif len(X.shape) == 2:
            phi = np.zeros((X.shape[0], X.shape[1] + 1))
            x_missing = np.zeros(X.shape[1], dtype=bool)
            for i in range(X.shape[0]):
                for t in self.trees:
                    self.tree_shap(t, X[i, :], x_missing, phi[i, :])
            phi /= len(self.trees)
        return phi

    def tree_shap(self, tree, x, x_missing, phi, condition=0, condition_feature=0):
        """Accumulate one tree's SHAP values for instance ``x`` into ``phi``."""
        # update the bias term, which is the last index in phi
        # (note the paper has this as phi_0 instead of phi_M)
        if condition == 0:
            phi[-1] += tree.values[0]

        # start the recursive algorithm at the root with an empty unique path
        tree_shap_recursive(
            tree.children_left,
            tree.children_right,
            tree.children_default,
            tree.features,
            tree.thresholds,
            tree.values,
            tree.node_sample_weight,
            x,
            x_missing,
            phi,
            0,
            0,
            self.feature_indexes,
            self.zero_fractions,
            self.one_fractions,
            self.pweights,
            1,
            1,
            -1,
            condition,
            condition_feature,
            1,
        )
[6]:
# extend our decision path with a fraction of one and zero extensions
@numba.jit(
    numba.types.void(
        numba.types.int32[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.int32,
        numba.types.float64,
        numba.types.float64,
        numba.types.int32,
    ),
    nopython=True,
    nogil=True,
)
def extend_path(
    feature_indexes,
    zero_fractions,
    one_fractions,
    pweights,
    unique_depth,
    zero_fraction,
    one_fraction,
    feature_index,
):
    """Append one split to the unique decision path.

    Records the split's feature index and its zero/one covering fractions at
    slot ``unique_depth``, then updates the permutation weights ``pweights``
    in place to account for the new path element.
    """
    feature_indexes[unique_depth] = feature_index
    zero_fractions[unique_depth] = zero_fraction
    one_fractions[unique_depth] = one_fraction
    # a fresh path starts with unit weight; deeper slots start empty
    pweights[unique_depth] = 1.0 if unique_depth == 0 else 0.0

    # propagate weights downward; walking from the top keeps pweights[j]
    # un-clobbered when pweights[j + 1] reads it
    for j in range(unique_depth - 1, -1, -1):
        pweights[j + 1] += one_fraction * pweights[j] * (j + 1) / (unique_depth + 1)
        pweights[j] = zero_fraction * pweights[j] * (unique_depth - j) / (unique_depth + 1)
# undo a previous extension of the decision path
@numba.jit(
    numba.types.void(
        numba.types.int32[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.int32,
        numba.types.int32,
    ),
    nopython=True,
    nogil=True,
)
def unwind_path(feature_indexes, zero_fractions, one_fractions, pweights, unique_depth, path_index):
    """Remove the split stored at ``path_index`` from the unique decision path.

    Inverts ``extend_path``: restores ``pweights`` to its state before that
    split was added, then shifts the per-split arrays left to close the gap.
    The caller is responsible for decrementing ``unique_depth`` afterwards.
    """
    one_fraction = one_fractions[path_index]
    zero_fraction = zero_fractions[path_index]
    next_one_portion = pweights[unique_depth]

    # Solve the extend_path recurrence backwards, top-down. Statement order
    # matters: next_one_portion carries the not-yet-unwound remainder from
    # the slot above into the next iteration.
    for i in range(unique_depth - 1, -1, -1):
        if one_fraction != 0:
            tmp = pweights[i]
            pweights[i] = next_one_portion * (unique_depth + 1) / ((i + 1) * one_fraction)
            next_one_portion = tmp - pweights[i] * zero_fraction * (unique_depth - i) / (unique_depth + 1)
        else:
            # one_fraction == 0 means only the zero branch contributed
            pweights[i] = (pweights[i] * (unique_depth + 1)) / (zero_fraction * (unique_depth - i))

    # shift the path entries down over the removed slot
    for i in range(path_index, unique_depth):
        feature_indexes[i] = feature_indexes[i + 1]
        zero_fractions[i] = zero_fractions[i + 1]
        one_fractions[i] = one_fractions[i + 1]
# determine what the total permuation weight would be if
# we unwound a previous extension in the decision path
@numba.jit(
    numba.types.float64(
        numba.types.int32[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.int32,
        numba.types.int32,
    ),
    nopython=True,
    nogil=True,
)
def unwound_path_sum(feature_indexes, zero_fractions, one_fractions, pweights, unique_depth, path_index):
    """Return the total permutation weight with the split at ``path_index`` removed.

    Like ``unwind_path`` but non-destructive: it accumulates the would-be
    weights into ``total`` instead of writing them back into ``pweights``.
    """
    one_fraction = one_fractions[path_index]
    zero_fraction = zero_fractions[path_index]
    next_one_portion = pweights[unique_depth]
    total = 0

    # same backward recurrence as unwind_path, read-only on pweights
    for i in range(unique_depth - 1, -1, -1):
        if one_fraction != 0:
            tmp = next_one_portion * (unique_depth + 1) / ((i + 1) * one_fraction)
            total += tmp
            next_one_portion = pweights[i] - tmp * zero_fraction * ((unique_depth - i) / (unique_depth + 1))
        else:
            total += (pweights[i] / zero_fraction) / ((unique_depth - i) / (unique_depth + 1))
    return total
class Tree:
    """Flat-array representation of a single decision tree for Tree SHAP.

    On construction, internal-node values are overwritten with the
    sample-weighted expected value of their subtree (via
    ``compute_expectations``), and ``max_depth`` is recorded.
    """

    def __init__(
        self,
        children_left,
        children_right,
        children_default,
        feature,
        threshold,
        value,
        node_sample_weight,
    ):
        self.children_left = children_left.astype(np.int32)
        self.children_right = children_right.astype(np.int32)
        self.children_default = children_default.astype(np.int32)
        self.features = feature.astype(np.int32)
        # astype() returns a copy, so compute_expectations below mutates our
        # private buffers instead of views into the caller's arrays (e.g.
        # sklearn's tree_.value — previously that model state was silently
        # corrupted). float64 also matches the jitted kernels' signatures.
        self.thresholds = threshold.astype(np.float64)
        self.values = value.astype(np.float64)
        self.node_sample_weight = node_sample_weight.astype(np.float64)
        self.max_depth = compute_expectations(
            self.children_left,
            self.children_right,
            self.node_sample_weight,
            self.values,
            0,
        )
@numba.jit(nopython=True)
def compute_expectations(children_left, children_right, node_sample_weight, values, i, depth=0):
    """Overwrite ``values`` with per-node expected outputs; return subtree depth.

    Each internal node's value becomes the sample-weighted average of its
    children's (already computed) values; leaf values are left untouched.
    Returns the maximum depth of the subtree rooted at node ``i``.
    """
    # leaf node (sklearn convention: right child == -1): depth 0, value kept
    # (the old code had a redundant self-assignment here)
    if children_right[i] == -1:
        return 0
    li = children_left[i]
    ri = children_right[i]
    depth_left = compute_expectations(children_left, children_right, node_sample_weight, values, li, depth + 1)
    depth_right = compute_expectations(children_left, children_right, node_sample_weight, values, ri, depth + 1)
    left_weight = node_sample_weight[li]
    right_weight = node_sample_weight[ri]
    values[i] = (left_weight * values[li] + right_weight * values[ri]) / (left_weight + right_weight)
    return max(depth_left, depth_right) + 1
# recursive computation of SHAP values for a decision tree
@numba.jit(
    numba.types.void(
        numba.types.int32[:],
        numba.types.int32[:],
        numba.types.int32[:],
        numba.types.int32[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.boolean[:],
        numba.types.float64[:],
        numba.types.int64,
        numba.types.int64,
        numba.types.int32[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.float64,
        numba.types.float64,
        numba.types.int64,
        numba.types.int64,
        numba.types.int64,
        numba.types.float64,
    ),
    nopython=True,
    nogil=True,
)
def tree_shap_recursive(
    children_left,
    children_right,
    children_default,
    features,
    thresholds,
    values,
    node_sample_weight,
    x,
    x_missing,
    phi,
    node_index,
    unique_depth,
    parent_feature_indexes,
    parent_zero_fractions,
    parent_one_fractions,
    parent_pweights,
    parent_zero_fraction,
    parent_one_fraction,
    parent_feature_index,
    condition,
    condition_feature,
    condition_fraction,
):
    """Walk the tree from ``node_index``, accumulating SHAP values into ``phi``.

    Maintains the "unique path" of splits from the root (feature index plus
    zero/one covering fractions per split) and, at each leaf, unwinds each
    path element to compute that feature's weighted contribution.
    """
    # stop if we have no weight coming down to us
    if condition_fraction == 0:
        return

    # extend the unique path: each recursion level claims the next
    # (unique_depth + 1)-sized segment of the shared preallocated buffers and
    # copies the parent's path into it, so sibling subtrees never clobber
    # each other's state
    feature_indexes = parent_feature_indexes[unique_depth + 1 :]
    feature_indexes[: unique_depth + 1] = parent_feature_indexes[: unique_depth + 1]
    zero_fractions = parent_zero_fractions[unique_depth + 1 :]
    zero_fractions[: unique_depth + 1] = parent_zero_fractions[: unique_depth + 1]
    one_fractions = parent_one_fractions[unique_depth + 1 :]
    one_fractions[: unique_depth + 1] = parent_one_fractions[: unique_depth + 1]
    pweights = parent_pweights[unique_depth + 1 :]
    pweights[: unique_depth + 1] = parent_pweights[: unique_depth + 1]

    if condition == 0 or condition_feature != parent_feature_index:
        extend_path(
            feature_indexes,
            zero_fractions,
            one_fractions,
            pweights,
            unique_depth,
            parent_zero_fraction,
            parent_one_fraction,
            parent_feature_index,
        )

    split_index = features[node_index]

    # leaf node
    if children_right[node_index] == -1:
        # each path feature's contribution is its unwound permutation weight
        # times (one_fraction - zero_fraction) times the leaf value
        for i in range(1, unique_depth + 1):
            w = unwound_path_sum(
                feature_indexes,
                zero_fractions,
                one_fractions,
                pweights,
                unique_depth,
                i,
            )
            phi[feature_indexes[i]] += (
                w * (one_fractions[i] - zero_fractions[i]) * values[node_index] * condition_fraction
            )
    # internal node
    else:
        # find which branch is "hot" (meaning x would follow it)
        hot_index = 0
        cleft = children_left[node_index]
        cright = children_right[node_index]
        if x_missing[split_index] == 1:
            hot_index = children_default[node_index]
        elif x[split_index] < thresholds[node_index]:
            hot_index = cleft
        else:
            hot_index = cright
        cold_index = cright if hot_index == cleft else cleft

        # fraction of training samples flowing into each child
        w = node_sample_weight[node_index]
        hot_zero_fraction = node_sample_weight[hot_index] / w
        cold_zero_fraction = node_sample_weight[cold_index] / w
        incoming_zero_fraction = 1
        incoming_one_fraction = 1

        # see if we have already split on this feature,
        # if so we undo that split so we can redo it for this node
        path_index = 0
        while path_index <= unique_depth:
            if feature_indexes[path_index] == split_index:
                break
            path_index += 1

        if path_index != unique_depth + 1:
            # fold the earlier split's fractions into this one
            incoming_zero_fraction = zero_fractions[path_index]
            incoming_one_fraction = one_fractions[path_index]
            unwind_path(
                feature_indexes,
                zero_fractions,
                one_fractions,
                pweights,
                unique_depth,
                path_index,
            )
            unique_depth -= 1

        # divide up the condition_fraction among the recursive calls
        hot_condition_fraction = condition_fraction
        cold_condition_fraction = condition_fraction
        if condition > 0 and split_index == condition_feature:
            # conditioning "on": the cold (unfollowed) branch gets no weight
            cold_condition_fraction = 0
            unique_depth -= 1
        elif condition < 0 and split_index == condition_feature:
            # conditioning "off": weight both branches by their coverage
            hot_condition_fraction *= hot_zero_fraction
            cold_condition_fraction *= cold_zero_fraction
            unique_depth -= 1

        tree_shap_recursive(
            children_left,
            children_right,
            children_default,
            features,
            thresholds,
            values,
            node_sample_weight,
            x,
            x_missing,
            phi,
            hot_index,
            unique_depth + 1,
            feature_indexes,
            zero_fractions,
            one_fractions,
            pweights,
            hot_zero_fraction * incoming_zero_fraction,
            incoming_one_fraction,
            split_index,
            condition,
            condition_feature,
            hot_condition_fraction,
        )

        tree_shap_recursive(
            children_left,
            children_right,
            children_default,
            features,
            thresholds,
            values,
            node_sample_weight,
            x,
            x_missing,
            phi,
            cold_index,
            unique_depth + 1,
            feature_indexes,
            zero_fractions,
            one_fractions,
            pweights,
            cold_zero_fraction * incoming_zero_fraction,
            0,
            split_index,
            condition,
            condition_feature,
            cold_condition_fraction,
        )
比较 XGBoost Tree SHAP 的运行时…
[7]:
# Time XGBoost's built-in (C++) Tree SHAP implementation for comparison.
start = time.time()
shap_values = bst.predict(xgboost.DMatrix(X), pred_contribs=True)
print(time.time() - start)
0.1391134262084961
is_sparse is deprecated and will be removed in a future version. Check `isinstance(dtype, pd.SparseDtype)` instead.
is_categorical_dtype is deprecated and will be removed in a future version. Use isinstance(dtype, CategoricalDtype) instead
is_categorical_dtype is deprecated and will be removed in a future version. Use isinstance(dtype, CategoricalDtype) instead
is_categorical_dtype is deprecated and will be removed in a future version. Use isinstance(dtype, CategoricalDtype) instead
与 Python (numba) Tree SHAP 相比…
[8]:
# Explain a single all-ones instance with the Python implementation
# (also triggers numba JIT compilation of the kernels).
x = np.ones(X.shape[1])
TreeExplainer(model).shap_values(x)
[8]:
array([-0.5140118 , 0.07819144, 0.09588052, -0.01004643, 0.21487504,
0.37846563, 0.15740923, -0.04002657, 2.04526993])
[9]:
# Time the Python (numba) implementation: explainer construction first,
# then SHAP values for the full dataset.
start = time.time()
ex = TreeExplainer(model)
print(time.time() - start)

start = time.time()
ex.shap_values(X.iloc[:, :])
print(time.time() - start)
0.0075643062591552734
10.171732664108276