开放式 GPT2 文本生成解释

本笔记本演示了如何获取用于开放式文本生成的 gpt2 输出的解释。在本演示中,我们使用 hugging face 提供的预训练 gpt2 模型 (https://hugging-face.cn/gpt2) 来解释 gpt2 生成的文本。我们进一步展示了如何获取自定义输出生成文本的解释,并绘制任何输出生成 token 的全局输入 token 重要性。

[1]:
from transformers import AutoModelForCausalLM, AutoTokenizer

import shap

加载模型和分词器

[2]:
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
model = AutoModelForCausalLM.from_pretrained("gpt2").cuda()

下面,我们设置某些模型配置。我们需要定义模型是解码器还是编码器-解码器。这可以通过模型配置文件中的 ‘is_decoder’ 或 ‘is_encoder_decoder’ 参数设置。我们还可以设置自定义模型生成参数,这些参数将在输出文本生成解码过程中使用。

[3]:
# set model decoder to true
model.config.is_decoder = True
# set text-generation params under task_specific_params
model.config.task_specific_params["text-generation"] = {
    "do_sample": True,
    "max_length": 50,
    "temperature": 0.7,
    "top_k": 50,
    "no_repeat_ngram_size": 2,
}

定义初始文本

[4]:
s = ["I enjoy walking with my cute dog"]

创建解释器对象并计算 SHAP 值

[5]:
explainer = shap.Explainer(model, tokenizer)
shap_values = explainer(s)
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.

可视化 shap 解释

[6]:
shap.plots.text(shap_values)


[0]
输出
,
但是
确定
是否
永远
能够
能够


0-2-424-4.04941-4.04941base value-1.27522-1.27522f,(inputs)4.064 dog 0.072 with -0.431 enjoy -0.427 walking -0.238 cute -0.15 my -0.117 I
输入
-0.117
-0.431
享受
-0.427
散步
0.072
-0.15
我的
-0.238
可爱
4.064

另一个例子…

[7]:
s = ["Scientists confirmed the worst possible outcome: the massive asteroid will collide with Earth"]
[8]:
explainer = shap.Explainer(model, tokenizer)
shap_values = explainer(s)
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
[9]:
shap.plots.text(shap_values)


[0]
输出
接下来的
几天
科学家
.


-3-5-7-11-4.7396-4.7396base value-1.7384-1.7384fin(inputs)1.436 Earth 1.28 collide 0.489 with 0.442 Scientists 0.24 worst 0.221 asteroid 0.093 massive 0.075 outcome -0.385 the -0.365 : -0.168 the -0.166 will -0.156 possible -0.035 confirmed
输入
0.442
证实
-0.035
最坏的
-0.168
接下来的
0.24
可能
-0.156
结果
0.075
巨大的
-0.365
:
-0.385
接下来的
0.093
小行星
0.221
-0.166
碰撞
1.28
地球
0.489
1.436
地球

自定义文本生成和调试有偏差的输出

下面我们演示如何解释使用模型生成特定输出句子的可能性,给定一个输入句子。例如,我们提出一个问题:在句子 “我知道很多人是 [target]” 中,哪个国家/地区的居民(目标)最有可能在输出句子 “他们喜欢他们的伏特加!” 中生成 token “伏特加”?为此,我们首先定义输入-输出句子对

[10]:
# define input
x = [
    "I know many people who are Russian.",
    "I know many people who are Greek.",
    "I know many people who are Australian.",
    "I know many people who are American.",
    "I know many people who are Italian.",
    "I know many people who are Spanish.",
    "I know many people who are German.",
    "I know many people who are Indian.",
]
[11]:
# define output
y = [
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
]

我们用 Teacher Forcing 评分类包装模型,并创建一个文本掩码器

[12]:
teacher_forcing_model = shap.models.TeacherForcing(model, tokenizer)
masker = shap.maskers.Text(tokenizer, mask_token="...", collapse_mask_token=True)

创建解释器…

[13]:
explainer = shap.Explainer(teacher_forcing_model, masker)

生成 SHAP 解释值!

[14]:
shap_values = explainer(x, y)

现在我们已经生成了 SHAP 值,我们可以看看输入中的 token 对输出句子中 token “伏特加” 的贡献,使用文本图。注意:红色表示正向贡献,而蓝色表示负向贡献,颜色的强度显示其在各自方向上的强度。

[15]:
shap.plots.text(shap_values)


[0]
输出
他们
喜欢
他们的
伏特加
!


-7-8-9-10-6-5-4-3-8.24848-8.24848base value-8.78452-8.78452fThey(inputs)0.375 . 0.124 people 0.109 are 0.035 who -0.488 Russian -0.377 I -0.158 know -0.157 many
输入
-0.377
-0.158
知道
-0.157
很多
0.124
0.035
他们
0.109
-0.488
俄罗斯人
0.375
.


[1]
输出
他们
喜欢
他们的
伏特加
!


-7-8-9-10-6-5-4-3-8.24848-8.24848base value-8.94869-8.94869fThey(inputs)0.387 . 0.149 people 0.144 are 0.054 who -0.716 Greek -0.351 I -0.242 many -0.125 know
输入
-0.351
-0.125
知道
-0.242
很多
0.149
0.054
他们
0.144
-0.716
希腊人
0.387
.


[2]
输出
他们
喜欢
他们的
伏特加
!


-7-8-9-10-6-5-4-3-8.24848-8.24848base value-8.67602-8.67602fThey(inputs)0.701 . 0.144 people 0.015 are -0.529 Australian -0.41 I -0.176 many -0.158 know -0.015 who
输入
-0.41
-0.158
知道
-0.176
很多
0.144
-0.015
他们
0.015
-0.529
澳大利亚人
0.701
.


[3]
输出
他们
喜欢
他们的
伏特加
!


-7-8-9-10-6-5-4-3-8.24848-8.24848base value-9.14276-9.14276fThey(inputs)0.39 . 0.134 people 0.03 are -0.632 American -0.439 I -0.185 know -0.162 many -0.03 who
输入
-0.439
-0.185
知道
-0.162
很多
0.134
-0.03
他们
0.03
-0.632
美国人
0.39
.


[4]
输出
他们
喜欢
他们的
伏特加
!


-7-8-9-10-6-5-4-3-8.24848-8.24848base value-9.08274-9.08274fThey(inputs)0.428 . 0.155 are 0.106 people 0.079 who -0.76 Italian -0.454 I -0.24 many -0.149 know
输入
-0.454
-0.149
知道
-0.24
很多
0.106
0.079
他们
0.155
-0.76
意大利人
0.428
.


[5]
输出
他们
喜欢
他们的
伏特加
!


-7-8-9-10-6-5-4-3-8.24848-8.24848base value-9.0745-9.0745fThey(inputs)0.414 . 0.288 are 0.156 who 0.106 people -1.015 Spanish -0.399 I -0.225 many -0.15 know
输入
-0.399
-0.15
知道
-0.225
很多
0.106
0.156
他们
0.288
-1.015
西班牙人
0.414
.


[6]
输出
他们
喜欢
他们的
伏特加
!


-7-8-9-10-6-5-4-3-8.24848-8.24848base value-8.9994-8.9994fThey(inputs)0.46 . 0.186 are 0.138 people 0.063 who -0.811 German -0.38 I -0.282 many -0.125 know
输入
-0.38
-0.125
知道
-0.282
很多
0.138
0.063
他们
0.186
-0.811
德国人
0.46
.


[7]
输出
他们
喜欢
他们的
伏特加
!


-7-8-9-10-6-5-4-3-8.24848-8.24848base value-8.63055-8.63055fThey(inputs)0.374 . 0.128 people 0.1 Indian -0.484 I -0.227 know -0.21 many -0.054 who -0.011 are
输入
-0.484
-0.227
知道
-0.21
很多
0.128
-0.054
他们
-0.011
0.1
印度人
0.374
.

要查看哪些输入 token 影响(正面/负面)生成单词 “伏特加” 的可能性,我们绘制单词 “伏特加” 的全局 token 重要性。

瞧!俄罗斯人喜欢他们的伏特加,不是吗? :)

[16]:
shap.plots.bar(shap_values[0, :, "vodka"])
../../../_images/example_notebooks_text_examples_text_generation_Open_Ended_GPT2_Text_Generation_Explanations_30_0.png

有更多有用的示例的想法吗?欢迎提交 Pull Request 来添加到此文档笔记本!