开放式 GPT2 文本生成解释
本笔记本演示了如何获取用于开放式文本生成的 gpt2 输出的解释。在本演示中,我们使用 hugging face 提供的预训练 gpt2 模型 (https://hugging-face.cn/gpt2) 来解释 gpt2 生成的文本。我们进一步展示了如何获取自定义输出生成文本的解释,并绘制任何输出生成 token 的全局输入 token 重要性。
[1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import shap
加载模型和分词器
[2]:
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
model = AutoModelForCausalLM.from_pretrained("gpt2").cuda()
下面,我们设置某些模型配置。我们需要定义模型是解码器还是编码器-解码器。这可以通过模型配置文件中的 ‘is_decoder’ 或 ‘is_encoder_decoder’ 参数设置。我们还可以设置自定义模型生成参数,这些参数将在输出文本生成解码过程中使用。
[3]:
# set model decoder to true
model.config.is_decoder = True
# set text-generation params under task_specific_params
model.config.task_specific_params["text-generation"] = {
"do_sample": True,
"max_length": 50,
"temperature": 0.7,
"top_k": 50,
"no_repeat_ngram_size": 2,
}
定义初始文本
[4]:
s = ["I enjoy walking with my cute dog"]
创建解释器对象并计算 SHAP 值
[5]:
explainer = shap.Explainer(model, tokenizer)
shap_values = explainer(s)
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
可视化 shap 解释
[6]:
shap.plots.text(shap_values)
[0]
输出
,
但是
我
是
不
确定
是否
我
会
永远
能够
能够
去
输入
-0.117
我
-0.431
享受
-0.427
散步
0.072
和
-0.15
我的
-0.238
可爱
4.064
狗
输入
-0.117
我
-0.431
享受
-0.427
散步
0.072
和
-0.15
我的
-0.238
可爱
4.064
狗
输入
0.873
我
1.015
享受
-0.029
散步
0.311
和
-0.112
我的
0.378
可爱
-0.781
狗
输入
0.873
我
1.015
享受
-0.029
散步
0.311
和
-0.112
我的
0.378
可爱
-0.781
狗
输入
1.005
我
0.023
享受
-0.002
散步
-0.107
和
0.405
我的
-0.16
可爱
0.016
狗
输入
1.005
我
0.023
享受
-0.002
散步
-0.107
和
0.405
我的
-0.16
可爱
0.016
狗
输入
0.126
我
-0.158
享受
-0.035
散步
-0.082
和
0.106
我的
0.168
可爱
-0.068
狗
输入
0.126
我
-0.158
享受
-0.035
散步
-0.082
和
0.106
我的
0.168
可爱
-0.068
狗
输入
0.199
我
0.167
享受
-0.196
散步
0.018
和
-0.014
我的
0.003
可爱
-0.041
狗
输入
0.199
我
0.167
享受
-0.196
散步
0.018
和
-0.014
我的
0.003
可爱
-0.041
狗
输入
-0.41
我
0.149
享受
-0.323
散步
0.408
和
-0.469
我的
0.067
可爱
0.08
狗
输入
-0.41
我
0.149
享受
-0.323
散步
0.408
和
-0.469
我的
0.067
可爱
0.08
狗
输入
-0.064
我
0.036
享受
0.188
散步
0.053
和
0.009
我的
0.24
可爱
0.062
狗
输入
-0.064
我
0.036
享受
0.188
散步
0.053
和
0.009
我的
0.24
可爱
0.062
狗
输入
0.406
我
0.356
享受
0.171
散步
-0.094
和
0.204
我的
-0.016
可爱
-0.183
狗
输入
0.406
我
0.356
享受
0.171
散步
-0.094
和
0.204
我的
-0.016
可爱
-0.183
狗
输入
-0.13
我
0.457
享受
-0.046
散步
-0.005
和
-0.047
我的
0.061
可爱
0.031
狗
输入
-0.13
我
0.457
享受
-0.046
散步
-0.005
和
-0.047
我的
0.061
可爱
0.031
狗
输入
0.013
我
0.012
享受
0.232
散步
-0.151
和
-0.038
我的
0.07
可爱
0.261
狗
输入
0.013
我
0.012
享受
0.232
散步
-0.151
和
-0.038
我的
0.07
可爱
0.261
狗
输入
0.078
我
0.016
享受
0.392
散步
0.032
和
0.128
我的
0.03
可爱
-0.005
狗
输入
0.078
我
0.016
享受
0.392
散步
0.032
和
0.128
我的
0.03
可爱
-0.005
狗
输入
-0.333
我
-0.023
享受
0.203
散步
0.088
和
0.11
我的
-0.241
可爱
0.145
狗
输入
-0.333
我
-0.023
享受
0.203
散步
0.088
和
0.11
我的
-0.241
可爱
0.145
狗
输入
0.089
我
0.058
享受
0.171
散步
-0.156
和
0.123
我的
0.758
可爱
0.421
狗
输入
0.089
我
0.058
享受
0.171
散步
-0.156
和
0.123
我的
0.758
可爱
0.421
狗
另一个例子…
[7]:
s = ["Scientists confirmed the worst possible outcome: the massive asteroid will collide with Earth"]
[8]:
explainer = shap.Explainer(model, tokenizer)
shap_values = explainer(s)
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
[9]:
shap.plots.text(shap_values)
[0]
输出
在
接下来的
几天
科学家
.
输入
0.442
证实
-0.035
最坏的
-0.168
接下来的
0.24
可能
-0.156
结果
0.075
巨大的
-0.365
:
-0.385
接下来的
0.093
小行星
0.221
将
-0.166
碰撞
1.28
地球
0.489
和
1.436
地球
输入
0.442
证实
-0.035
最坏的
-0.168
接下来的
0.24
可能
-0.156
结果
0.075
巨大的
-0.365
:
-0.385
接下来的
0.093
小行星
0.221
将
-0.166
碰撞
1.28
地球
0.489
和
1.436
地球
输入
-0.26
证实
-0.074
最坏的
0.251
接下来的
-0.062
可能
0.065
结果
-0.117
巨大的
-0.009
:
0.308
接下来的
-0.011
小行星
-0.203
将
-0.308
碰撞
0.252
地球
0.291
和
-0.145
地球
输入
-0.26
证实
-0.074
最坏的
0.251
接下来的
-0.062
可能
0.065
结果
-0.117
巨大的
-0.009
:
0.308
接下来的
-0.011
小行星
-0.203
将
-0.308
碰撞
0.252
地球
0.291
和
-0.145
地球
输入
0.325
证实
0.202
最坏的
0.054
接下来的
-0.493
可能
0.407
结果
0.318
巨大的
-0.351
:
-0.077
接下来的
-0.146
小行星
0.207
将
2.257
碰撞
2.382
地球
-0.043
和
0.427
地球
输入
0.325
证实
0.202
最坏的
0.054
接下来的
-0.493
可能
0.407
结果
0.318
巨大的
-0.351
:
-0.077
接下来的
-0.146
小行星
0.207
将
2.257
碰撞
2.382
地球
-0.043
和
0.427
地球
输入
0.001
证实
0.446
最坏的
0.128
接下来的
-0.074
可能
0.107
结果
-0.122
巨大的
-0.015
:
-0.098
接下来的
0.06
小行星
0.12
将
0.123
碰撞
0.337
地球
0.06
和
-0.411
地球
输入
0.001
证实
0.446
最坏的
0.128
接下来的
-0.074
可能
0.107
结果
-0.122
巨大的
-0.015
:
-0.098
接下来的
0.06
小行星
0.12
将
0.123
碰撞
0.337
地球
0.06
和
-0.411
地球
输入
-0.578
证实
0.36
最坏的
0.012
接下来的
0.043
可能
-0.247
结果
-0.11
巨大的
-0.245
:
0.284
接下来的
0.032
小行星
0.062
将
-0.973
碰撞
1.145
地球
0.796
和
0.217
地球
输入
-0.578
证实
0.36
最坏的
0.012
接下来的
0.043
可能
-0.247
结果
-0.11
巨大的
-0.245
:
0.284
接下来的
0.032
小行星
0.062
将
-0.973
碰撞
1.145
地球
0.796
和
0.217
地球
输入
0.016
证实
0.292
最坏的
0.12
接下来的
0.172
可能
0.169
结果
0.194
巨大的
-0.078
:
0.215
接下来的
0.122
小行星
0.089
将
-0.078
碰撞
0.122
地球
0.168
和
-0.109
地球
输入
0.016
证实
0.292
最坏的
0.12
接下来的
0.172
可能
0.169
结果
0.194
巨大的
-0.078
:
0.215
接下来的
0.122
小行星
0.089
将
-0.078
碰撞
0.122
地球
0.168
和
-0.109
地球
自定义文本生成和调试有偏差的输出
下面我们演示如何解释使用模型生成特定输出句子的可能性,给定一个输入句子。例如,我们提出一个问题:在句子 “我知道很多人是 [target]” 中,哪个国家/地区的居民(目标)最有可能在输出句子 “他们喜欢他们的伏特加!” 中生成 token “伏特加”?为此,我们首先定义输入-输出句子对
[10]:
# define input
x = [
"I know many people who are Russian.",
"I know many people who are Greek.",
"I know many people who are Australian.",
"I know many people who are American.",
"I know many people who are Italian.",
"I know many people who are Spanish.",
"I know many people who are German.",
"I know many people who are Indian.",
]
[11]:
# define output
y = [
"They love their vodka!",
"They love their vodka!",
"They love their vodka!",
"They love their vodka!",
"They love their vodka!",
"They love their vodka!",
"They love their vodka!",
"They love their vodka!",
]
我们用 Teacher Forcing 评分类包装模型,并创建一个文本掩码器
[12]:
teacher_forcing_model = shap.models.TeacherForcing(model, tokenizer)
masker = shap.maskers.Text(tokenizer, mask_token="...", collapse_mask_token=True)
创建解释器…
[13]:
explainer = shap.Explainer(teacher_forcing_model, masker)
生成 SHAP 解释值!
[14]:
shap_values = explainer(x, y)
现在我们已经生成了 SHAP 值,我们可以看看输入中的 token 对输出句子中 token “伏特加” 的贡献,使用文本图。注意:红色表示正向贡献,而蓝色表示负向贡献,颜色的强度显示其在各自方向上的强度。
[15]:
shap.plots.text(shap_values)
[0]
输出
他们
喜欢
他们的
伏特加
!
输入
-0.377
我
-0.158
知道
-0.157
很多
0.124
人
0.035
他们
0.109
是
-0.488
俄罗斯人
0.375
.
输入
-0.377
我
-0.158
知道
-0.157
很多
0.124
人
0.035
他们
0.109
是
-0.488
俄罗斯人
0.375
.
输入
0.126
我
0.448
知道
0.248
很多
0.45
人
0.032
他们
0.061
是
-0.089
俄罗斯人
-0.082
.
输入
0.126
我
0.448
知道
0.248
很多
0.45
人
0.032
他们
0.061
是
-0.089
俄罗斯人
-0.082
.
输入
-0.069
我
0.088
知道
0.297
很多
0.144
人
0.175
他们
0.253
是
-0.024
俄罗斯人
-0.087
.
输入
-0.069
我
0.088
知道
0.297
很多
0.144
人
0.175
他们
0.253
是
-0.024
俄罗斯人
-0.087
.
输入
0.036
我
-0.013
知道
-0.132
很多
0.021
人
-0.062
他们
-0.164
是
2.648
俄罗斯人
0.05
.
输入
0.036
我
-0.013
知道
-0.132
很多
0.021
人
-0.062
他们
-0.164
是
2.648
俄罗斯人
0.05
.
输入
-0.449
我
-0.182
知道
-0.125
很多
-0.309
人
-0.122
他们
-0.071
是
0.183
俄罗斯人
0.202
.
输入
-0.449
我
-0.182
知道
-0.125
很多
-0.309
人
-0.122
他们
-0.071
是
0.183
俄罗斯人
0.202
.
[1]
输出
他们
喜欢
他们的
伏特加
!
输入
-0.351
我
-0.125
知道
-0.242
很多
0.149
人
0.054
他们
0.144
是
-0.716
希腊人
0.387
.
输入
-0.351
我
-0.125
知道
-0.242
很多
0.149
人
0.054
他们
0.144
是
-0.716
希腊人
0.387
.
输入
0.192
我
0.511
知道
0.229
很多
0.516
人
-0.004
他们
-0.029
是
0.407
希腊人
-0.088
.
输入
0.192
我
0.511
知道
0.229
很多
0.516
人
-0.004
他们
-0.029
是
0.407
希腊人
-0.088
.
输入
-0.044
我
0.076
知道
0.277
很多
0.147
人
0.169
他们
0.339
是
0.141
希腊人
-0.106
.
输入
-0.044
我
0.076
知道
0.277
很多
0.147
人
0.169
他们
0.339
是
0.141
希腊人
-0.106
.
输入
0.011
我
0.001
知道
-0.311
很多
0.031
人
-0.15
他们
-0.445
是
0.162
希腊人
0.061
.
输入
0.011
我
0.001
知道
-0.311
很多
0.031
人
-0.15
他们
-0.445
是
0.162
希腊人
0.061
.
输入
-0.445
我
-0.14
知道
-0.125
很多
-0.218
人
-0.131
他们
-0.041
是
0.339
希腊人
0.241
.
输入
-0.445
我
-0.14
知道
-0.125
很多
-0.218
人
-0.131
他们
-0.041
是
0.339
希腊人
0.241
.
[2]
输出
他们
喜欢
他们的
伏特加
!
输入
-0.41
我
-0.158
知道
-0.176
很多
0.144
人
-0.015
他们
0.015
是
-0.529
澳大利亚人
0.701
.
输入
-0.41
我
-0.158
知道
-0.176
很多
0.144
人
-0.015
他们
0.015
是
-0.529
澳大利亚人
0.701
.
输入
0.148
我
0.457
知道
0.248
很多
0.453
人
0.032
他们
0.042
是
0.365
澳大利亚人
-0.057
.
输入
0.148
我
0.457
知道
0.248
很多
0.453
人
0.032
他们
0.042
是
0.365
澳大利亚人
-0.057
.
输入
-0.031
我
0.115
知道
0.32
很多
0.177
人
0.184
他们
0.298
是
0.089
澳大利亚人
-0.053
.
输入
-0.031
我
0.115
知道
0.32
很多
0.177
人
0.184
他们
0.298
是
0.089
澳大利亚人
-0.053
.
输入
-0.14
我
-0.093
知道
-0.265
很多
-0.371
人
-0.11
他们
-0.648
是
-0.393
澳大利亚人
0.123
.
输入
-0.14
我
-0.093
知道
-0.265
很多
-0.371
人
-0.11
他们
-0.648
是
-0.393
澳大利亚人
0.123
.
输入
-0.455
我
-0.201
知道
-0.14
很多
-0.315
人
-0.121
他们
-0.125
是
0.119
澳大利亚人
0.227
.
输入
-0.455
我
-0.201
知道
-0.14
很多
-0.315
人
-0.121
他们
-0.125
是
0.119
澳大利亚人
0.227
.
[3]
输出
他们
喜欢
他们的
伏特加
!
输入
-0.439
我
-0.185
知道
-0.162
很多
0.134
人
-0.03
他们
0.03
是
-0.632
美国人
0.39
.
输入
-0.439
我
-0.185
知道
-0.162
很多
0.134
人
-0.03
他们
0.03
是
-0.632
美国人
0.39
.
输入
0.13
我
0.451
知道
0.174
很多
0.398
人
-0.019
他们
-0.072
是
0.474
美国人
-0.095
.
输入
0.13
我
0.451
知道
0.174
很多
0.398
人
-0.019
他们
-0.072
是
0.474
美国人
-0.095
.
输入
-0.04
我
0.109
知道
0.343
很多
0.212
人
0.18
他们
0.275
是
0.372
美国人
-0.041
.
输入
-0.04
我
0.109
知道
0.343
很多
0.212
人
0.18
他们
0.275
是
0.372
美国人
-0.041
.
输入
-0.094
我
-0.055
知道
-0.366
很多
-0.43
人
-0.082
他们
-0.514
是
-0.519
美国人
0.027
.
输入
-0.094
我
-0.055
知道
-0.366
很多
-0.43
人
-0.082
他们
-0.514
是
-0.519
美国人
0.027
.
输入
-0.484
我
-0.182
知道
-0.129
很多
-0.34
人
-0.116
他们
-0.117
是
-0.212
美国人
0.283
.
输入
-0.484
我
-0.182
知道
-0.129
很多
-0.34
人
-0.116
他们
-0.117
是
-0.212
美国人
0.283
.
[4]
输出
他们
喜欢
他们的
伏特加
!
输入
-0.454
我
-0.149
知道
-0.24
很多
0.106
人
0.079
他们
0.155
是
-0.76
意大利人
0.428
.
输入
-0.454
我
-0.149
知道
-0.24
很多
0.106
人
0.079
他们
0.155
是
-0.76
意大利人
0.428
.
输入
0.138
我
0.485
知道
0.258
很多
0.472
人
-0.004
他们
0.056
是
0.561
意大利人
-0.141
.
输入
0.138
我
0.485
知道
0.258
很多
0.472
人
-0.004
他们
0.056
是
0.561
意大利人
-0.141
.
输入
-0.056
我
0.119
知道
0.3
很多
0.192
人
0.172
他们
0.285
是
0.163
意大利人
-0.124
.
输入
-0.056
我
0.119
知道
0.3
很多
0.192
人
0.172
他们
0.285
是
0.163
意大利人
-0.124
.
输入
-0.012
我
-0.115
知道
-0.23
很多
-0.142
人
-0.084
他们
-0.444
是
0.779
意大利人
0.203
.
输入
-0.012
我
-0.115
知道
-0.23
很多
-0.142
人
-0.084
他们
-0.444
是
0.779
意大利人
0.203
.
输入
-0.467
我
-0.172
知道
-0.11
很多
-0.266
人
-0.12
他们
-0.054
是
0.41
意大利人
0.248
.
输入
-0.467
我
-0.172
知道
-0.11
很多
-0.266
人
-0.12
他们
-0.054
是
0.41
意大利人
0.248
.
[5]
输出
他们
喜欢
他们的
伏特加
!
输入
-0.399
我
-0.15
知道
-0.225
很多
0.106
人
0.156
他们
0.288
是
-1.015
西班牙人
0.414
.
输入
-0.399
我
-0.15
知道
-0.225
很多
0.106
人
0.156
他们
0.288
是
-1.015
西班牙人
0.414
.
输入
0.149
我
0.526
知道
0.225
很多
0.427
人
-0.003
他们
-0.01
是
0.353
西班牙人
-0.117
.
输入
0.149
我
0.526
知道
0.225
很多
0.427
人
-0.003
他们
-0.01
是
0.353
西班牙人
-0.117
.
输入
-0.06
我
0.101
知道
0.297
很多
0.157
人
0.172
他们
0.327
是
0.01
西班牙人
-0.081
.
输入
-0.06
我
0.101
知道
0.297
很多
0.157
人
0.172
他们
0.327
是
0.01
西班牙人
-0.081
.
输入
-0.048
我
-0.099
知道
-0.258
很多
-0.167
人
-0.103
他们
-0.376
是
-0.028
西班牙人
0.129
.
输入
-0.048
我
-0.099
知道
-0.258
很多
-0.167
人
-0.103
他们
-0.376
是
-0.028
西班牙人
0.129
.
输入
-0.482
我
-0.176
知道
-0.1
很多
-0.276
人
-0.129
他们
-0.04
是
0.221
西班牙人
0.23
.
输入
-0.482
我
-0.176
知道
-0.1
很多
-0.276
人
-0.129
他们
-0.04
是
0.221
西班牙人
0.23
.
[6]
输出
他们
喜欢
他们的
伏特加
!
输入
-0.38
我
-0.125
知道
-0.282
很多
0.138
人
0.063
他们
0.186
是
-0.811
德国人
0.46
.
输入
-0.38
我
-0.125
知道
-0.282
很多
0.138
人
0.063
他们
0.186
是
-0.811
德国人
0.46
.
输入
0.135
我
0.482
知道
0.231
很多
0.44
人
0.026
他们
0.054
是
0.113
德国人
-0.122
.
输入
0.135
我
0.482
知道
0.231
很多
0.44
人
0.026
他们
0.054
是
0.113
德国人
-0.122
.
输入
-0.059
我
0.133
知道
0.317
很多
0.205
人
0.201
他们
0.294
是
0.229
德国人
-0.08
.
输入
-0.059
我
0.133
知道
0.317
很多
0.205
人
0.201
他们
0.294
是
0.229
德国人
-0.08
.
输入
-0.079
我
-0.071
知道
-0.269
很多
-0.182
人
-0.065
他们
-0.401
是
0.726
德国人
0.157
.
输入
-0.079
我
-0.071
知道
-0.269
很多
-0.182
人
-0.065
他们
-0.401
是
0.726
德国人
0.157
.
输入
-0.461
我
-0.171
知道
-0.117
很多
-0.293
人
-0.135
他们
-0.06
是
0.329
德国人
0.22
.
输入
-0.461
我
-0.171
知道
-0.117
很多
-0.293
人
-0.135
他们
-0.06
是
0.329
德国人
0.22
.
[7]
输出
他们
喜欢
他们的
伏特加
!
输入
-0.484
我
-0.227
知道
-0.21
很多
0.128
人
-0.054
他们
-0.011
是
0.1
印度人
0.374
.
输入
-0.484
我
-0.227
知道
-0.21
很多
0.128
人
-0.054
他们
-0.011
是
0.1
印度人
0.374
.
输入
0.111
我
0.487
知道
0.202
很多
0.438
人
0.006
他们
-0.076
是
0.184
印度人
-0.02
.
输入
0.111
我
0.487
知道
0.202
很多
0.438
人
0.006
他们
-0.076
是
0.184
印度人
-0.02
.
输入
-0.065
我
0.104
知道
0.337
很多
0.176
人
0.178
他们
0.277
是
0.245
印度人
-0.085
.
输入
-0.065
我
0.104
知道
0.337
很多
0.176
人
0.178
他们
0.277
是
0.245
印度人
-0.085
.
输入
-0.07
我
-0.026
知道
-0.341
很多
-0.216
人
-0.151
他们
-0.571
是
-0.666
印度人
0.175
.
输入
-0.07
我
-0.026
知道
-0.341
很多
-0.216
人
-0.151
他们
-0.571
是
-0.666
印度人
0.175
.
输入
-0.429
我
-0.175
知道
-0.132
很多
-0.305
人
-0.088
他们
-0.11
是
-0.067
印度人
0.261
.
输入
-0.429
我
-0.175
知道
-0.132
很多
-0.305
人
-0.088
他们
-0.11
是
-0.067
印度人
0.261
.
要查看哪些输入 token 影响(正面/负面)生成单词 “伏特加” 的可能性,我们绘制单词 “伏特加” 的全局 token 重要性。
瞧!俄罗斯人喜欢他们的伏特加,不是吗? :)
[16]:
shap.plots.bar(shap_values[0, :, "vodka"])

有更多有用的示例的想法吗?欢迎提交 Pull Request 来添加到此文档笔记本!