DeepExplainer Genomics Example

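This example runs DeepExplainer on a model trained on simulated genomic data from the DeepLIFT repository (https://github.com/kundajelab/deeplift/blob/master/examples/genomics/genomics_simulation.ipynb), using a dynamic reference: the reference changes with the input sequence. Here, the reference is a collection of dinucleotide-shuffled versions of the input sequence.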

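The simulated data is structured as follows:

  • 1/4 of the sequences contain both 1-3 instances of the GATA_disc1 motif and 1-3 instances of the TAL1_known1 motif; these are labeled 1,1,1

  • 1/4 of the sequences have 1-3 GATA_disc1 motifs embedded; these are labeled 0,1,0

  • 1/4 of the sequences have 1-3 TAL1_known1 motifs embedded; these are labeled 0,0,1

  • 1/4 of the sequences contain no motifs; these are labeled 0,0,0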
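The label vectors above can be read as three binary tasks. Reading them off the list (an inference, not something the simulation code is shown to do here), task 0 is "both motifs present", task 1 is "GATA_disc1 present" and task 2 is "TAL1_known1 present". A minimal sketch with a hypothetical helper `simulated_labels`:

```python
def simulated_labels(has_gata: bool, has_tal1: bool) -> tuple:
    """Hypothetical helper: 3-task label vector for one simulated sequence.

    Assumed task order (inferred from the label list above):
      task 0 = both GATA_disc1 and TAL1_known1 present
      task 1 = GATA_disc1 present
      task 2 = TAL1_known1 present
    """
    return (int(has_gata and has_tal1), int(has_gata), int(has_tal1))


# The four groups of the simulation
for has_gata, has_tal1 in [(True, True), (True, False), (False, True), (False, False)]:
    print(has_gata, has_tal1, simulated_labels(has_gata, has_tal1))
```

Under this reading, explaining output 0 of the model attributes the prediction for "sequence carries both motifs".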

[1]:
%matplotlib inline

Get the data and the Keras model

Pull the relevant data

[2]:
! [[ ! -f sequences.simdata.gz ]] && wget https://raw.githubusercontent.com/AvantiShri/model_storage/db919b12f750e5844402153233249bb3d24e9e9a/deeplift/genomics/sequences.simdata.gz
! [[ ! -f keras2_conv1d_record_5_model_PQzyq_modelJson.json ]] && wget https://raw.githubusercontent.com/AvantiShri/model_storage/b6e1d69/deeplift/genomics/keras2_conv1d_record_5_model_PQzyq_modelJson.json
! [[ ! -f keras2_conv1d_record_5_model_PQzyq_modelWeights.h5 ]] && wget https://raw.githubusercontent.com/AvantiShri/model_storage/b6e1d69/deeplift/genomics/keras2_conv1d_record_5_model_PQzyq_modelWeights.h5
! [[ ! -f test.txt.gz ]] && wget https://raw.githubusercontent.com/AvantiShri/model_storage/9aadb769735c60eb90f7d3d896632ac749a1bdd2/deeplift/genomics/test.txt.gz
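The shell cell above downloads each file only if it is not already present. Where `wget` or a bash-compatible shell is not available, the same idea can be sketched in pure Python with the standard library. `fetch_if_missing`, `data_urls` and `download_all` are hypothetical helpers, not part of shap or deeplift; the commits and file names are copied from the `wget` commands.

```python
import os
import urllib.request

BASE = "https://raw.githubusercontent.com/AvantiShri/model_storage"

# (commit, file name) pairs taken from the wget commands above
FILES = [
    ("db919b12f750e5844402153233249bb3d24e9e9a", "sequences.simdata.gz"),
    ("b6e1d69", "keras2_conv1d_record_5_model_PQzyq_modelJson.json"),
    ("b6e1d69", "keras2_conv1d_record_5_model_PQzyq_modelWeights.h5"),
    ("9aadb769735c60eb90f7d3d896632ac749a1bdd2", "test.txt.gz"),
]


def fetch_if_missing(url, path):
    """Download url to path unless path already exists; return True if downloaded."""
    if os.path.exists(path):
        return False
    urllib.request.urlretrieve(url, path)
    return True


def data_urls():
    """Full download URLs paired with their local file names."""
    return [(f"{BASE}/{commit}/deeplift/genomics/{name}", name) for commit, name in FILES]


def download_all():
    """Fetch every file into the current directory, skipping ones already present."""
    for url, name in data_urls():
        fetch_if_missing(url, name)
```

Calling `download_all()` from the notebook's working directory should leave the same four files in place as the shell cell.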

Load the data

[3]:
! pip install simdna
[4]:
import gzip

import simdna.synthetic as synthetic

data_filename = "sequences.simdata.gz"

# read in the IDs of the sequences in the testing set, then load only those
with gzip.open("test.txt.gz", "rb") as test_ids_fh:
    ids_to_load = [x.decode("utf-8").rstrip("\n") for x in test_ids_fh]
data = synthetic.read_simdata_file(data_filename, ids_to_load=ids_to_load)
[5]:
import numpy as np


# this is set up for 1d convolutions where examples
# have dimensions (len, num_channels)
# the channel axis is the axis for one-hot encoding.
def one_hot_encode_along_channel_axis(sequence):
    to_return = np.zeros((len(sequence), 4), dtype=np.int8)
    seq_to_one_hot_fill_in_array(zeros_array=to_return, sequence=sequence, one_hot_axis=1)
    return to_return


def seq_to_one_hot_fill_in_array(zeros_array, sequence, one_hot_axis):
    assert one_hot_axis in (0, 1)
    if one_hot_axis == 0:
        assert zeros_array.shape[1] == len(sequence)
    elif one_hot_axis == 1:
        assert zeros_array.shape[0] == len(sequence)
    # will mutate zeros_array
    for i, char in enumerate(sequence):
        if char in ("A", "a"):
            char_idx = 0
        elif char in ("C", "c"):
            char_idx = 1
        elif char in ("G", "g"):
            char_idx = 2
        elif char in ("T", "t"):
            char_idx = 3
        elif char in ("N", "n"):
            continue  # leave that pos as all 0's
        else:
            raise RuntimeError("Unsupported character: " + str(char))
        if one_hot_axis == 0:
            zeros_array[char_idx, i] = 1
        elif one_hot_axis == 1:
            zeros_array[i, char_idx] = 1


onehot_data = np.array([one_hot_encode_along_channel_axis(seq) for seq in data.sequences])
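The encoder above loops over every character in Python. A vectorized variant can map ASCII codes through a lookup table instead. This is a sketch that assumes the same conventions as `seq_to_one_hot_fill_in_array` (A, C, G, T channel order, case-insensitive, N left as all zeros, anything else rejected); `one_hot_encode_vectorized` is a hypothetical name.

```python
import numpy as np

# Lookup table from ASCII code to one-hot row, using the same base order as
# seq_to_one_hot_fill_in_array (A, C, G, T); N and every other code map to zeros.
_ONE_HOT_LUT = np.zeros((256, 4), dtype=np.int8)
for _idx, _bases in enumerate(("Aa", "Cc", "Gg", "Tt")):
    for _b in _bases:
        _ONE_HOT_LUT[ord(_b), _idx] = 1

_ALLOWED = set("ACGTNacgtn")


def one_hot_encode_vectorized(sequence):
    """(len, 4) int8 one-hot encoding without a per-character Python loop."""
    bad = set(sequence) - _ALLOWED
    if bad:
        raise RuntimeError("Unsupported character: " + str(sorted(bad)[0]))
    # ASCII codes of the sequence as a uint8 array, used to index the table
    codes = np.array(bytearray(sequence, "ascii"), dtype=np.uint8)
    return _ONE_HOT_LUT[codes]
```

For this dataset, `np.array([one_hot_encode_vectorized(s) for s in data.sequences])` should produce the same array as `onehot_data`.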

Load the model

[6]:
from keras.models import model_from_json

# load the keras model
keras_model_weights = "keras2_conv1d_record_5_model_PQzyq_modelWeights.h5"
keras_model_json = "keras2_conv1d_record_5_model_PQzyq_modelJson.json"

with open(keras_model_json) as json_fh:
    keras_model = model_from_json(json_fh.read())
keras_model.load_weights(keras_model_weights)

Install the deeplift package, which provides the dinucleotide shuffling and visualization code

[7]:
!pip install deeplift

Compute importance scores

Define a function that generates the references, in this case by dinucleotide-shuffling the given input sequence

[8]:
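from deeplift.dinuc_shuffle import dinuc_shuffle


def shuffle_several_times(s):
    # squeeze away any singleton axes so dinuc_shuffle gets a (length, 4) one-hot array
    s = np.squeeze(s)
    # generate 100 dinucleotide-shuffled versions of this input to serve as its references
    return dinuc_shuffle(s, num_shufs=100)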
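To make the reference concrete: a dinucleotide shuffle rearranges a sequence while keeping the count of every adjacent letter pair (and the first and last letters) unchanged, so the references keep the input's local composition but tend to break up embedded motifs. The sketch below works on plain strings and uses the "keep each letter's last successor fixed, shuffle the rest, then walk" construction for a random Eulerian path. It is a simplified illustration, not the code inside `deeplift.dinuc_shuffle`; `dinuc_shuffle_string` and `dinuc_counts` are hypothetical names.

```python
import random
from collections import Counter


def dinuc_shuffle_string(seq, rng):
    """Return one shuffle of seq that preserves every dinucleotide count.

    seq is treated as a walk over a multigraph whose nodes are letters and whose
    edges are the adjacent pairs seq[i] -> seq[i+1]. Any Eulerian path through
    that graph starting at seq[0] has the same dinucleotide counts.
    """
    if len(seq) < 3:
        return seq
    # For each letter, the positions (in seq) of the letters that follow it.
    successors = {}
    for i in range(len(seq) - 1):
        successors.setdefault(seq[i], []).append(i + 1)
    # Shuffle each successor list but keep its last entry fixed. The last exits
    # from each letter lead toward the final letter, which guarantees that the
    # walk below uses every edge exactly once.
    for letter, positions in successors.items():
        head = positions[:-1]
        rng.shuffle(head)
        successors[letter] = head + positions[-1:]
    used = {letter: 0 for letter in successors}
    out = [seq[0]]
    pos = 0
    for _ in range(len(seq) - 1):
        letter = seq[pos]
        pos = successors[letter][used[letter]]
        used[letter] += 1
        out.append(seq[pos])
    return "".join(out)


def dinuc_counts(seq):
    """Multiset of adjacent letter pairs in seq."""
    return Counter(zip(seq, seq[1:]))


rng = random.Random(0)
original = "ACGTTGCAGATAAGGCTTATCTGACGT"
shuffled = [dinuc_shuffle_string(original, rng) for _ in range(5)]
```

In the notebook, `deeplift.dinuc_shuffle.dinuc_shuffle` plays this role on one-hot arrays, and `shuffle_several_times` asks it for 100 shuffles per input.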

Run DeepExplainer with the dynamic reference function

[9]:
import shap

np.random.seed(1)

seqs_to_explain = onehot_data[[0, 3, 9]]  # these three are positive for task 0
dinuc_shuff_explainer = shap.DeepExplainer((keras_model.input, keras_model.output[:, 0]), shuffle_several_times)
raw_shap_explanations = dinuc_shuff_explainer.shap_values(seqs_to_explain, check_additivity=False)
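With a dynamic reference, each input's attributions are measured against that input's own references, so they should sum (approximately, for DeepExplainer) to the model output on the input minus the mean output over its references. For a purely linear model the exact SHAP values have the closed form phi = w * (x - mean(references)), which makes this easy to check. The sketch below is a toy illustration with a hypothetical linear scorer (`w`, `b`), not the Keras model or DeepExplainer itself; as a stand-in reference generator it permutes positions (a mononucleotide shuffle), not a dinucleotide shuffle.

```python
import numpy as np


def one_hot(seq):
    """(len, 4) float one-hot in A, C, G, T order."""
    idx = {"A": 0, "C": 1, "G": 2, "T": 3}
    out = np.zeros((len(seq), 4))
    for i, c in enumerate(seq):
        out[i, idx[c]] = 1.0
    return out


def position_shuffle_refs(x, num_shufs, rng):
    """Stand-in reference generator: num_shufs random permutations of the rows of x."""
    return np.stack([x[rng.permutation(len(x))] for _ in range(num_shufs)])


def linear_shap_dynamic_ref(x, w, b, reference_fn):
    """Exact SHAP values of f(x) = sum(w * x) + b with references built from x."""
    refs = reference_fn(x)                                   # (num_refs, len, 4)
    phi = w * (x - refs.mean(axis=0))                        # (len, 4) attributions
    expected_value = (refs * w).sum(axis=(1, 2)).mean() + b  # mean f over the references
    return phi, expected_value


rng = np.random.default_rng(0)
x = one_hot("ACGTTGCA")
w = rng.normal(size=x.shape)  # hypothetical per-position, per-base weights
b = 0.5
phi, expected_value = linear_shap_dynamic_ref(
    x, w, b, lambda s: position_shuffle_refs(s, 100, rng))
f_x = (w * x).sum() + b
```

The identity `phi.sum() + expected_value == f_x` is the additivity property that shap's `check_additivity` option verifies; the notebook call above turns that check off.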

Visualize the scores on individual sequences

[10]:
from deeplift.visualization import viz_sequence

# project the importance at each position onto the base that's actually present
dinuc_shuff_explanations = np.sum(raw_shap_explanations, axis=-1)[:, :, None] * seqs_to_explain
for idx, dinuc_shuff_explanation in zip([0, 3, 9], dinuc_shuff_explanations):
    print("Scores for example", idx)
    highlight = {
        "blue": [
            (embedding.startPos, embedding.startPos + len(embedding.what))
            for embedding in data.embeddings[idx]
            if "GATA_disc1" in embedding.what.getDescription()
        ],
        "green": [
            (embedding.startPos, embedding.startPos + len(embedding.what))
            for embedding in data.embeddings[idx]
            if "TAL1_known1" in embedding.what.getDescription()
        ],
    }
    viz_sequence.plot_weights(dinuc_shuff_explanation, subticks_frequency=20, highlight=highlight)
Scores for example 0
[Figure: importance scores for example 0]
Scores for example 3
[Figure: importance scores for example 3]
Scores for example 9
[Figure: importance scores for example 9]

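The plots above show the importance scores of three example sequences for the task of predicting sequences that contain both GATA_disc1 and TAL1_known1 motifs. Letter height reflects the score. Blue boxes mark the true positions of the inserted GATA_disc1 motifs, and green boxes mark the true positions of the inserted TAL1_known1 motifs.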
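The projection in the visualization cell, `np.sum(raw_shap_explanations, axis=-1)[:, :, None] * seqs_to_explain`, sums the four per-base scores at each position and assigns that total to the base actually present, which is what sets the letter heights. A toy check with made-up numbers:

```python
import numpy as np

# Made-up scores for 1 sequence of length 3 with 4 channels (A, C, G, T)
raw = np.array([[[0.1, 0.2, 0.0, -0.1],
                 [0.0, 0.0, 0.5, 0.0],
                 [-0.3, 0.1, 0.1, 0.0]]])
# One-hot encoding of the sequence "AGT"
seq = np.array([[[1, 0, 0, 0],
                 [0, 0, 1, 0],
                 [0, 0, 0, 1]]])

# Sum over channels, then broadcast back so only the observed base keeps the total
projected = np.sum(raw, axis=-1)[:, :, None] * seq
print(projected[0])
```

Because every row of the one-hot array here has exactly one 1, the per-position totals survive the projection unchanged; a position encoded as N (all zeros) would be projected to zero.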