shap.Cohorts

class shap.Cohorts(**kwargs: Explanation)

一个 Explanation 对象的集合,通常每个对象解释相似样本的集群。

示例

Cohorts 对象可以通过多种方式初始化。

通过显式指定 cohorts

>>> exp = Explanation(
...     values=np.random.uniform(low=-1, high=1, size=(500, 5)),
...     data=np.random.normal(loc=1, scale=3, size=(500, 5)),
...     feature_names=list("abcde"),
... )
>>> cohorts = Cohorts(
...     col_a_neg=exp[exp[:, "a"].data < 0],
...     col_a_pos=exp[exp[:, "a"].data >= 0],
... )
>>> cohorts
<shap._explanation.Cohorts object with 2 cohorts of sizes: [(198, 5), (302, 5)]>

或使用 Explanation.cohorts() 方法

>>> cohorts2 = exp.cohorts(3)
>>> cohorts2
<shap._explanation.Cohorts object with 3 cohorts of sizes: [(182, 5), (12, 5), (306, 5)]>

Explanation 接口的大部分也在 Cohorts 中公开。例如,要检索跨所有 cohorts 的列 ‘a’ 对应的 SHAP 值,您可以使用

>>> cohorts[..., 'a'].values
<shap._explanation.Cohorts object with 2 cohorts of sizes: [(198,), (302,)]>

要实际检索特定 Explanation 的值,您需要通过 Cohorts.cohorts() 属性访问它

>>> cohorts.cohorts["col_a_neg"][..., 'a'].values
array([...])  # truncated
__init__(**kwargs: Explanation) None

方法

__init__(**kwargs)

属性

cohorts

cohorts 的内部集合,存储为字典。

property cohorts: dict[str, Explanation]

cohorts 的内部集合,存储为字典。