情感分类多类别示例

本笔记本演示了如何在多类别文本分类场景中使用 Partition 解释器。一旦计算出一组句子的 SHAP 值,我们就可以可视化特征对各个类别的归因。我们使用的文本分类模型是 BERT,它在情感数据集上进行了微调,以将句子分类为六个类别:喜悦、悲伤、愤怒、恐惧、爱和惊讶。

[1]:
import datasets
import pandas as pd
import transformers

import shap

# load the emotion dataset
dataset = datasets.load_dataset("emotion", split="train")
data = pd.DataFrame({"text": dataset["text"], "emotion": dataset["label"]})
Using custom data configuration default
Reusing dataset emotion (/home/slundberg/.cache/huggingface/datasets/emotion/default/0.0.0/aa34462255cd487d04be8387a2d572588f6ceee23f784f37365aa714afeb8fe6)

构建 transformers 流水线

请注意,我们为流水线设置了 return_all_scores=True,以便我们可以观察模型对所有类别的行为,而不仅仅是顶部输出。

[2]:
# load the model and tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained("nateraw/bert-base-uncased-emotion", use_fast=True)
model = transformers.AutoModelForSequenceClassification.from_pretrained("nateraw/bert-base-uncased-emotion").cuda()

# build a pipeline object to do predictions
pred = transformers.pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    device=0,
    return_all_scores=True,
)

为流水线创建解释器

transformers 流水线对象可以直接传递给 shap.Explainer,这会将流水线模型包装为 shap.models.TransformersPipeline 模型,并将流水线分词器包装为 shap.maskers.Text 掩码器。

[3]:
explainer = shap.Explainer(pred)

计算 SHAP 值

解释器与它们正在解释的模型具有相同的方法签名,因此我们只需传递一个字符串列表,用于解释分类。

[4]:
shap_values = explainer(data["text"][:3])

可视化所有输出类别的影响

在下面的图中,当您将鼠标悬停在输出类别上时,您将获得该输出类别的解释。当您单击输出类别名称时,该类别将保持解释可视化的焦点,直到您单击另一个类别。

基准值是当整个输入文本被掩盖时模型输出的值,而 \(f_{output class}(inputs)\) 是模型对于完整原始输入的输出。SHAP 值以累加的方式解释了取消掩盖每个词如何将模型输出从基准值(整个输入被掩盖时)更改为最终预测值。

[5]:
shap.plots.text(shap_values)


[0]
输出
悲伤
喜悦
愤怒
恐惧
惊讶


0.30.1-0.1-0.30.50.70.90.1316720.131672base value0.9964080.996408fsadness(inputs)0.855 humiliated 0.009 didn 0.003 i 0.001 t 0.0 -0.004 feel -0.0
输入
-0.0
0.003
0.009
0.001
-0.004
感觉
0.855
羞辱
0.0


[1]
输出
悲伤
喜悦
愤怒
恐惧
惊讶


0.30.1-0.1-0.30.50.70.90.1441950.144195base value0.9952920.995292fsadness(inputs)0.599 hopeless 0.28 feeling 0.039 so 0.004 from 0.004 damned 0.002 from 0.002 awake 0.002 i 0.001 to 0.001 who 0.0 go 0.0 0.0 -0.045 hopeful -0.011 cares -0.006 just -0.006 is -0.004 someone -0.003 can -0.002 and -0.002 being -0.002 around -0.001 so
输入
0.0
0.002
-0.003
0.0
0.004
0.28
感觉
0.039
如此
0.599
绝望
0.001
-0.001
如此
0.004
该死
-0.045
充满希望
-0.006
只是
0.002
-0.002
-0.002
周围
-0.004
某人
0.001
-0.011
关心
-0.002
-0.006
0.002
醒着
0.0


[2]
输出
悲伤
喜悦
愤怒
恐惧
惊讶


0.30.1-0.1-0.30.50.70.90.152610.15261base value0.002277240.00227724fsadness(inputs)0.0 i 0.0 0.0 -0.097 greedy -0.019 feel -0.013 grabbing -0.007 im -0.005 a -0.005 to -0.003 post -0.001 wrong -0.0 minute
输入
0.0
-0.007
-0.013
抓住
-0.005
一个
-0.0
分钟
-0.005
-0.003
帖子
0.0
-0.019
感觉
-0.097
贪婪
-0.001
错误
0.0

可视化单个类别的影响

由于 Explanation 对象是可切片的,我们可以切出一个输出类别来可视化模型对该类别的输出。

[11]:
shap.plots.text(shap_values[:, :, "anger"])


[0]
0.50.30.10.70.90.2789150.278915base value0.001233210.00123321fanger(inputs)0.028 didn 0.015 i 0.008 t -0.199 humiliated -0.13 feel -0.0 -0.0
输入
-0.0
0.015
0.028
0.008
-0.13
感觉
-0.199
羞辱
-0.0


[1]
0.50.30.10.70.90.2716290.271629base value0.000462820.00046282fanger(inputs)0.015 damned 0.005 from 0.005 to 0.004 so 0.004 around 0.002 i 0.002 being 0.001 is 0.0 -0.097 hopeful -0.08 hopeless -0.045 feeling -0.028 awake -0.021 cares -0.016 so -0.008 someone -0.004 just -0.004 who -0.003 and -0.003 can go -0.001 from -0.0
输入
-0.0
0.002
-0.003 / 2
能去
0.005
-0.045
感觉
-0.016
如此
-0.08
绝望
0.005
0.004
如此
0.015
该死
-0.097
充满希望
-0.004
只是
-0.001
0.002
0.004
周围
-0.008
某人
-0.004
-0.021
关心
-0.003
0.001
-0.028
醒着
0.0


[2]
0.50.30.10.70.90.2303730.230373base value0.9914620.991462fanger(inputs)0.545 greedy 0.118 wrong 0.07 grabbing 0.023 post 0.015 im 0.006 feel 0.005 minute 0.0 -0.016 to -0.004 i -0.001 a -0.0
输入
-0.0
0.015
0.07
抓住
-0.001
一个
0.005
分钟
-0.016
0.023
帖子
-0.004
0.006
感觉
0.545
贪婪
0.118
错误
0.0

绘制影响特定类别的顶部词汇

除了切片之外,Explanation 对象还支持一组缩减方法。 这里我们使用 .mean(0) 来获取所有词语对“喜悦”类别的平均影响。请注意,这里我们还对三个示例进行了平均,为了获得更好的总结,您需要使用数据集的更大部分。

[12]:
shap.plots.bar(shap_values[:, :, "joy"].mean(0))
../../../_images/example_notebooks_text_examples_sentiment_analysis_Emotion_classification_multiclass_example_14_0.png
[13]:
# we can sort the bar chart in decending order
shap.plots.bar(shap_values[:, :, "joy"].mean(0), order=shap.Explanation.argsort)
../../../_images/example_notebooks_text_examples_sentiment_analysis_Emotion_classification_multiclass_example_15_0.png
[14]:
# ...or acending order
shap.plots.bar(shap_values[:, :, "joy"].mean(0), order=shap.Explanation.argsort.flip)
../../../_images/example_notebooks_text_examples_sentiment_analysis_Emotion_classification_multiclass_example_16_0.png

解释对数几率而非概率

在上面的示例中,我们解释了流水线对象的直接输出,即类别概率。有时在对数几率空间中工作更有意义,在其中添加和减去效果是很自然的(加法和减法对应于证据信息位的加法或减法)。要使用 logits,我们可以使用 shap.models.TransformersPipeline 对象的一个参数

[15]:
logit_explainer = shap.Explainer(shap.models.TransformersPipeline(pred, rescale_to_logits=True))

logit_shap_values = logit_explainer(data["text"][:3])
shap.plots.text(logit_shap_values)


[0]
输出
悲伤
喜悦
愤怒
恐惧
惊讶


-1-4-725-1.88626-1.88626base value5.625445.62544fsadness(inputs)6.901 humiliated 0.201 feel 0.173 didn 0.16 i 0.076 t 0.0 -0.0
输入
-0.0
0.16
0.173
0.076
0.201
感觉
6.901
羞辱
0.0


[1]
输出
悲伤
喜悦
愤怒
恐惧
惊讶


-1-4-725-1.78088-1.78088base value5.353885.35388fsadness(inputs)5.914 hopeless 2.741 feeling 0.248 so 0.079 to so 0.063 can go 0.053 damned 0.029 from -1.3 hopeful -0.172 just from -0.135 awake -0.119 cares -0.11 someone who -0.071 being around -0.054 is -0.025 i -0.006 and -0.0
输入
-0.025 / 2
0.063 / 2
能去
0.029
2.741
感觉
0.248
如此
5.914
绝望
0.079 / 2
到 如此
0.053
该死
-1.3
充满希望
-0.172 / 2
只是 从
-0.071 / 2
在周围
-0.11 / 2
某人 谁
-0.119
关心
-0.006
-0.054
-0.135
醒着
-0.0


[2]
输出
悲伤
喜悦
愤怒
恐惧
惊讶


-1-4-725-1.71428-1.71428base value-6.08251-6.08251fsadness(inputs)0.212 wrong 0.009 post 0.0 0.0 -3.174 greedy -0.528 feel -0.518 grabbing -0.152 im -0.131 a -0.067 to -0.02 i -0.0 minute
输入
0.0
-0.152
-0.518
抓住
-0.131
一个
-0.0
分钟
-0.067
0.009
帖子
-0.02
-0.528
感觉
-3.174
贪婪
0.212
错误
0.0

有更多有用的示例的想法吗? 鼓励提交 pull request 来为此文档笔记本添加内容!