Step 3. Predict TCR and peptide-HLA binding

T cell receptors (TCR), expressed on the surface of T cells, recognize and bind to immunogenic peptides or epitopes presented by human leukocyte antigen (HLA). This recognition process is fundamental to the adaptive immune system and serves as a critical mechanism for identifying and responding to pathogenic threats. A diverse array of TCRs ensures protection against a wide range of pathogens and malignant cells. Mechanisms exist to eliminate T cells that recognize self-antigens (central and peripheral tolerance), but self-tolerance can fail, leading to autoimmunity, especially when hidden self-antigens are exposed or when microbial antigens mimic self-antigens. Identifying the specific epitopes that trigger T cell activation in different disease contexts can provide crucial insights into disease pathogenesis and could enable the development of personalized therapies, such as tolerance-inducing vaccines or TCR-specific T cell depletion.

Model results figure

Each individual carries an estimated 5 million unique TCRs, so the TCR repertoire is highly diverse. It is shaped by the individual’s genetics, environmental exposures, and immune history. This diversity allows the immune system to recognize a wide range of antigens, but it also makes it difficult to predict which TCRs will bind a specific peptide-HLA complex.

Here we use a public dataset from VDJdb to train a model that predicts TCR-peptide-HLA binding. The dataset contains TCR sequences, peptide sequences, and HLA alleles associated with T cell responses. The trained model predicts whether a given TCR will bind a specific peptide-HLA complex.

Load TCR and peptide-HLA binding data.

Load our training data (Download here), which contains:

  1. TCRs with alpha and beta chain, V and J gene information

  2. Peptide and HLA information
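
Before loading a custom table, it can help to confirm that it carries the columns used in this step. The following is a minimal sketch: `check_columns` is a hypothetical helper (not part of trimap), the column names are taken from the code in this tutorial, and the one-row table is made up for illustration.

```python
import pandas as pd

# Column names used by the code in this tutorial
REQUIRED_COLUMNS = ["alpha", "beta", "V_alpha", "J_alpha",
                    "V_beta", "J_beta", "HLA", "Epitope"]

def check_columns(df, required=REQUIRED_COLUMNS):
    """Return the required columns that are missing from df (hypothetical helper)."""
    return [col for col in required if col not in df.columns]

# Toy one-row table for illustration only (the CDR3 strings are made up)
toy = pd.DataFrame({
    "alpha": ["CAMREGDSSYKLIF"], "beta": ["CASSVGQGNYGYTF"],
    "V_alpha": ["TRAV14/DV4*01"], "J_alpha": ["TRAJ12*01"],
    "V_beta": ["TRBV9*01"], "J_beta": ["TRBJ1-2*01"],
    "HLA": ["HLA-A*03"],
})
missing = check_columns(toy)
print(missing)  # ['Epitope']
```

An empty list means the table has every column that the later cells access by name.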

import pandas as pd
from trimap import utils
import trimap
from trimap.model import TCRbind
import torch
import numpy as np

import warnings
warnings.filterwarnings("ignore")

print(f"Using trimap version {trimap.__version__}")

seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

df_train = pd.read_csv('VDJdb_train.csv')
Using trimap version 1.0.5

Reconstruct full TCR sequences from CDR3 and V/J gene information

The TCR sequences in the dataset are often represented by their CDR3 regions, which are the most variable and antigen-specific parts of the TCR. To reconstruct the full TCR sequences, we need to combine the CDR3 regions with their corresponding V and J gene segments. This process involves using the V and J gene information to identify the appropriate gene segments and then concatenating them with the CDR3 regions to form the complete TCR sequences.

If you want to use your own TCR sequences, make sure they follow the same format as the dataset: CDR3 regions in the TCR columns, with V and J gene information provided separately. Every V and J gene must appear in the Allele column of the corresponding reference file (for example trajs_aa.tsv or trbjs_aa.tsv) in the ‘library/’ folder.

Alternatively, you can train the model using only CDR3 regions, though performance may be lower.
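
To make the concatenation idea concrete, here is a toy sketch of stitching a CDR3 between V and J segments. It is not the implementation of `utils.determine_tcr_seq_vj`; `stitch_tcr` is a hypothetical function, the V and J strings are shortened made-up sequences, and the anchoring rule (last Cys in V, `[FW]GXG` motif in J) is a simplifying assumption.

```python
import re

def stitch_tcr(cdr3, v_region, j_region):
    """Toy reconstruction: V up to its last Cys + CDR3 + J after the F/W of [FW]GXG."""
    v_cut = v_region.rfind("C")            # assumed CDR3 start: last Cys in the V segment
    if v_cut == -1:
        raise ValueError("no Cys found in V region")
    motif = re.search(r"[FW]G.G", j_region)  # assumed CDR3 end: first residue of the J motif
    if motif is None:
        raise ValueError("no [FW]GXG motif found in J region")
    # Drop the V tail and J head that overlap the CDR3, then concatenate
    return v_region[:v_cut] + cdr3 + j_region[motif.start() + 1:]

# Hypothetical, shortened segments for illustration only
v_toy = "AQKITQTQPGMFVQEKEAVTLDCAMRE"
j_toy = "DSSYKLIFGSGTRLLVRP"
cdr3 = "CAMREGDSSYKLIF"

full = stitch_tcr(cdr3, v_toy, j_toy)
print(full)  # AQKITQTQPGMFVQEKEAVTLDCAMREGDSSYKLIFGSGTRLLVRP
```

The overlap handling is the key design point: the CDR3 already includes the residues shared with the end of V and the start of J, so those must be trimmed rather than duplicated.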

df_train['alpha'] = utils.determine_tcr_seq_vj(df_train['alpha'].tolist(), df_train['V_alpha'].tolist(), df_train['J_alpha'].tolist(), chain='A')
df_train['beta'] = utils.determine_tcr_seq_vj(df_train['beta'].tolist(), df_train['V_beta'].tolist(), df_train['J_beta'].tolist(), chain='B')
df_train = df_train[['alpha', 'beta', 'V_alpha', 'J_alpha', 'V_beta', 'J_beta', 'HLA', 'Epitope']]
print(df_train.iloc[0])
alpha      AQKITQTQPGMFVQEKEAVTLDCTYDTSDPSYGLFWYKQPSSGEMI...
beta       DSGVTQTPKHLITATGQRVTLRCSPRSGDLSVYWYQQSLDQGLQFL...
V_alpha                                        TRAV14/DV4*01
J_alpha                                            TRAJ12*01
V_beta                                              TRBV9*01
J_beta                                            TRBJ1-2*01
HLA                                                 HLA-A*03
Epitope                                            KLGGALQAK
Name: 0, dtype: object

Show the TCR distribution for the top 15 epitopes in the dataset. It is dominated by the ‘KLGGALQAK’ epitope.

df_train['Epitope'].value_counts()[0:15].plot(kind='bar', figsize=(10, 5))
<Axes: xlabel='Epitope'>
../_images/f494817c7e3a5f547bb416dd95b1634c5a022b59967daebe1f565ac50e0183e8.png
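
To put a number on how strongly one epitope dominates, the counts can be turned into fractions. This is a minimal sketch on made-up counts; `epitope_share` is a hypothetical helper, and on the real data the input would be `df_train['Epitope']`.

```python
import pandas as pd

def epitope_share(epitopes, top_n=3):
    """Fraction of all pairs contributed by each of the top_n most frequent epitopes."""
    return pd.Series(epitopes).value_counts(normalize=True).head(top_n)

# Toy epitope column with made-up counts, for illustration only
toy = ["KLGGALQAK"] * 6 + ["GILGFVFTL"] * 3 + ["NLVPMVATV"] * 1
shares = epitope_share(toy)
print(shares)
```

A large share for a single epitope suggests checking per-epitope metrics in addition to the overall scores, which is what the later cells in this step do.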

Train and save the model

To run the TCR model, you need the file pHLA_model.pt. If it is not available, please complete Step 2 first to generate it.

The files alpha_dict.pt and beta_dict.pt are optional: if they are missing, the model generates them automatically at runtime. Make sure you have enough disk space (at least 20GB) to store the dictionaries.
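
A quick check of these prerequisites before starting a long training run might look like the sketch below. `preflight` is a hypothetical helper built only on the Python standard library; the file names and the 20GB figure come from the text above, and the working directory is assumed to be where those files live.

```python
import shutil
from pathlib import Path

def preflight(workdir=".", required=("pHLA_model.pt",),
              optional=("alpha_dict.pt", "beta_dict.pt"), min_free_gb=20):
    """Report missing files and free disk space in workdir (hypothetical helper)."""
    workdir = Path(workdir)
    free_gb = shutil.disk_usage(workdir).free / 1e9  # bytes -> GB
    return {
        "missing_required": [f for f in required if not (workdir / f).exists()],
        "missing_optional": [f for f in optional if not (workdir / f).exists()],
        "free_gb": free_gb,
        "enough_space": free_gb >= min_free_gb,
    }

report = preflight()
print(report)
```

If `missing_required` is non-empty, complete Step 2 first; if `missing_optional` is non-empty, expect the dictionaries to be generated and written to disk during training.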

model = TCRbind().to(device)
model.train_model(df_train, num_epochs=20, device=device, phla_model_dir='pHLA_model.pt')
torch.save(model.state_dict(), 'TCR_model.pt')
INFO:trimap.model:Training...
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
INFO:trimap.model:alpha_dict.pt not found, generating...
100%|██████████| 66/66 [05:08<00:00,  4.67s/it]
INFO:trimap.model:Saved alpha_dict.pt
INFO:trimap.model:No new alpha sequences found
INFO:trimap.model:beta_dict.pt not found, generating...
100%|██████████| 69/69 [05:24<00:00,  4.70s/it]
INFO:trimap.model:Saved beta_dict.pt
INFO:trimap.model:No new beta sequences found
100%|██████████| 253/253 [01:38<00:00,  2.57it/s]
Epoch [1/20], Loss: 0.5968, ROC: 0.6886
100%|██████████| 252/252 [01:32<00:00,  2.72it/s]
Epoch [2/20], Loss: 0.5190, ROC: 0.7571
100%|██████████| 253/253 [01:34<00:00,  2.68it/s]
Epoch [3/20], Loss: 0.4786, ROC: 0.7846
100%|██████████| 253/253 [01:25<00:00,  2.95it/s]
Epoch [4/20], Loss: 0.6179, ROC: 0.7987
100%|██████████| 253/253 [01:29<00:00,  2.84it/s]
Epoch [5/20], Loss: 0.3444, ROC: 0.8125
100%|██████████| 252/252 [01:30<00:00,  2.79it/s]
Epoch [6/20], Loss: 0.4474, ROC: 0.8223
100%|██████████| 253/253 [01:31<00:00,  2.78it/s]
Epoch [7/20], Loss: 0.3058, ROC: 0.8320
100%|██████████| 252/252 [01:26<00:00,  2.91it/s]
Epoch [8/20], Loss: 0.5200, ROC: 0.8405
100%|██████████| 253/253 [01:25<00:00,  2.97it/s]
Epoch [9/20], Loss: 0.5777, ROC: 0.8506
100%|██████████| 253/253 [01:23<00:00,  3.04it/s]
Epoch [10/20], Loss: 0.3730, ROC: 0.8578
100%|██████████| 253/253 [01:24<00:00,  2.99it/s]
Epoch [11/20], Loss: 0.4330, ROC: 0.8651
100%|██████████| 253/253 [01:27<00:00,  2.91it/s]
Epoch [12/20], Loss: 0.1742, ROC: 0.8771
100%|██████████| 253/253 [01:24<00:00,  3.00it/s]
Epoch [13/20], Loss: 0.5427, ROC: 0.8836
100%|██████████| 252/252 [01:22<00:00,  3.06it/s]
Epoch [14/20], Loss: 0.4505, ROC: 0.8918
100%|██████████| 253/253 [01:25<00:00,  2.97it/s]
Epoch [15/20], Loss: 0.1103, ROC: 0.8999
100%|██████████| 252/252 [01:22<00:00,  3.06it/s]
Epoch [16/20], Loss: 0.3280, ROC: 0.9075
100%|██████████| 252/252 [01:21<00:00,  3.10it/s]
Epoch [17/20], Loss: 0.3075, ROC: 0.9154
100%|██████████| 253/253 [01:29<00:00,  2.82it/s]
Epoch [18/20], Loss: 0.4465, ROC: 0.9234
100%|██████████| 252/252 [01:23<00:00,  3.00it/s]
Epoch [19/20], Loss: 0.3247, ROC: 0.9297
100%|██████████| 252/252 [01:22<00:00,  3.05it/s]
Epoch [20/20], Loss: 0.2876, ROC: 0.9367
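
In the log above, the printed loss fluctuates from epoch to epoch while the reported ROC rises steadily from 0.6886 to 0.9367. To inspect such trends programmatically, the epoch lines can be parsed from captured output. This is a minimal sketch; `parse_epoch_lines` is a hypothetical helper, and the regular expression assumes the exact line format shown above.

```python
import re

# Matches lines such as "Epoch [1/20], Loss: 0.5968, ROC: 0.6886"
EPOCH_RE = re.compile(r"Epoch \[(\d+)/(\d+)\], Loss: ([\d.]+), ROC: ([\d.]+)")

def parse_epoch_lines(log_text):
    """Extract epoch number, loss, and ROC from training log text."""
    return [
        {"epoch": int(m.group(1)), "loss": float(m.group(3)), "roc": float(m.group(4))}
        for m in EPOCH_RE.finditer(log_text)
    ]

# First four epoch lines copied from the log above
log = """Epoch [1/20], Loss: 0.5968, ROC: 0.6886
Epoch [2/20], Loss: 0.5190, ROC: 0.7571
Epoch [3/20], Loss: 0.4786, ROC: 0.7846
Epoch [4/20], Loss: 0.6179, ROC: 0.7987"""

rows = parse_epoch_lines(log)
roc_increasing = all(a["roc"] < b["roc"] for a, b in zip(rows, rows[1:]))
print(roc_increasing)  # True
```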

(Optional) Load a pretrained model

If you want to use a pretrained model, you can download it from here and load it using the following code:

model = TCRbind().to(device)
model.load_state_dict(torch.load('TCR_model.pt', map_location=device))
<All keys matched successfully>

Test the model

Load the test data (Download here)

df_test = pd.read_csv('VDJdb_test.csv')
df_test['alpha'] = utils.determine_tcr_seq_vj(df_test['alpha'].tolist(), df_test['V_alpha'].tolist(), df_test['J_alpha'].tolist(), chain='A')
df_test['beta'] = utils.determine_tcr_seq_vj(df_test['beta'].tolist(), df_test['V_beta'].tolist(), df_test['J_beta'].tolist(), chain='B')
result, cdr3a_attn, cdr3b_attn = model.test_model(df_test=df_test, device=device, phla_model_dir='pHLA_model.pt')
df_test['pred'] = result
INFO:trimap.model:Loading alpha_dict.pt
INFO:trimap.model:No new alpha sequences found
INFO:trimap.model:Loading beta_dict.pt
INFO:trimap.model:No new beta sequences found
INFO:trimap.model:Predicting...
100%|██████████| 44/44 [00:14<00:00,  3.14it/s]

Plot ROC and PR curves for the model.

import matplotlib.pyplot as plt
from sklearn.metrics import (
    roc_curve,
    precision_recall_curve,
    roc_auc_score,
    average_precision_score
)

def plot_roc_prc_curve(y_true, y_scores, title_prefix=''):
    """
    Plot ROC and Precision-Recall curves and print their summary scores.

    Args:
        y_true (array-like): Ground truth binary labels (0 or 1).
        y_scores (array-like): Predicted scores (e.g., probabilities).
        title_prefix (str): Optional prefix for plot titles.
    """
    # Use a clean style
    plt.style.use('seaborn-v0_8-muted')

    # Compute metrics
    fpr, tpr, _ = roc_curve(y_true, y_scores)
    roc_auc = roc_auc_score(y_true, y_scores)

    precision, recall, _ = precision_recall_curve(y_true, y_scores)
    # Average precision summarizes the PR curve; it is reported below as PRC-AUC
    prc_auc = average_precision_score(y_true, y_scores)

    # Print AUC values
    print(f'{title_prefix}ROC-AUC: {roc_auc:.4f}')
    print(f'{title_prefix}PRC-AUC: {prc_auc:.4f}')

    # Set up plot
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # ROC Curve
    axes[0].plot(fpr, tpr, color='#E24A33', lw=2.5, label=f'AUC = {roc_auc:.4f}')
    axes[0].plot([0, 1], [0, 1], linestyle='--', color='gray', lw=1)
    axes[0].set_xlim([-0.01, 1.01])
    axes[0].set_ylim([-0.01, 1.01])
    axes[0].set_xlabel('False Positive Rate', fontsize=12)
    axes[0].set_ylabel('True Positive Rate', fontsize=12)
    axes[0].set_title(f'{title_prefix}ROC Curve', fontsize=14, fontweight='bold')
    axes[0].legend(loc='lower right', fontsize=10)
    axes[0].grid(alpha=0.3)

    # PRC Curve
    axes[1].plot(recall, precision, color='#348ABD', lw=2.5, label=f'AUC = {prc_auc:.4f}')
    axes[1].set_xlim([-0.01, 1.01])
    axes[1].set_ylim([-0.01, 1.01])
    axes[1].set_xlabel('Recall', fontsize=12)
    axes[1].set_ylabel('Precision', fontsize=12)
    axes[1].set_title(f'{title_prefix}Precision-Recall Curve', fontsize=14, fontweight='bold')
    axes[1].legend(loc='lower left', fontsize=10)
    axes[1].grid(alpha=0.3)

    plt.tight_layout()
    plt.show()

# Plot ROC and Precision-Recall curves
plot_roc_prc_curve(df_test['label'], df_test['pred'])
ROC-AUC: 0.8565
PRC-AUC: 0.4687
../_images/6a3d3acc82cc627f7c9fe9a1191cbb3a838feece150740eda5e195ffe1ce6acc.png
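
The gap between ROC-AUC and PRC-AUC is easier to read with a baseline in mind: for uninformative scores, ROC-AUC is about 0.5 regardless of class balance, whereas average precision is roughly the fraction of positive pairs. The sketch below illustrates this on synthetic labels; the 10% positive rate is an arbitrary assumption, not the ratio in VDJdb_test.csv.

```python
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score

rng = np.random.default_rng(0)
n = 20000
y_true = (rng.random(n) < 0.10).astype(int)  # synthetic labels, ~10% positives (assumed)
y_rand = rng.random(n)                       # uninformative random scores

prevalence = y_true.mean()
ap_rand = average_precision_score(y_true, y_rand)
auc_rand = roc_auc_score(y_true, y_rand)
print(f"prevalence={prevalence:.3f}  random AP={ap_rand:.3f}  random ROC-AUC={auc_rand:.3f}")
```

On the real test set, comparing `df_test['label'].mean()` with the reported PRC-AUC gives the same kind of reference point.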

Mean predictive score of positive and negative pairs for each peptide

def compute_score_stats(df_list, top_epitopes):
    # "tp"/"fp" hold mean scores of label-1 (positive) and label-0 (negative) pairs
    tp_all, fp_all = [], []
    for df in df_list:
        tp, fp = [], []
        for pep in top_epitopes:
            preds = df[df['Epitope'] == pep]
            tp.append(np.mean(preds[preds['label'] == 1]['pred']))
            fp.append(np.mean(preds[preds['label'] == 0]['pred']))
        tp_all.append(tp)
        fp_all.append(fp)
    tp_arr = np.array(tp_all)
    fp_arr = np.array(fp_all)
    return (
        np.nanmean(tp_arr, axis=0), np.nanstd(tp_arr, axis=0),
        np.nanmean(fp_arr, axis=0), np.nanstd(fp_arr, axis=0)
    )

def plot_score_bars(top_epitopes, tp_mean, tp_std, fp_mean, fp_std):
    x = np.arange(len(top_epitopes))
    width = 0.35
    fig, ax = plt.subplots(figsize=(12, 4))
    ax.bar(x - width/2, tp_mean, width, yerr=tp_std, label='Positive', color='indianred', capsize=5, edgecolor='black')
    ax.bar(x + width/2, fp_mean, width, yerr=fp_std, label='Negative', color='skyblue', capsize=5, edgecolor='black')
    ax.set_ylabel('Predictive Scores')
    ax.set_title('Mean predictive scores of positive and negative pairs per peptide')
    ax.set_xticks(x)
    ax.set_xticklabels(top_epitopes, rotation='vertical')
    ax.legend(frameon=False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    fig.tight_layout()

def plot_peptide_counts(df_train, top_epitopes):
    counts = [df_train[df_train['Epitope'] == pep].shape[0] for pep in top_epitopes]
    x = np.arange(len(top_epitopes))
    fig, ax = plt.subplots(figsize=(12, 4))
    ax.bar(x, counts, width=0.5, color='#1f77b4', edgecolor='black')
    ax.set_ylabel('Number of positive pairs')
    ax.set_title('Training sample count per peptide')
    ax.set_xticks(x)
    ax.set_xticklabels(top_epitopes, rotation='vertical')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    fig.tight_layout()

# Prepare data
top_epitopes = df_test['Epitope'].value_counts().head(15).index.tolist()

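# Placeholder for multiple folds: the same test set is copied 9 times, so the
# standard deviations (error bars) are 0. Replace with per-fold prediction
# DataFrames to obtain meaningful error bars.
df_test_list = [df_test.copy() for _ in range(9)]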

# Compute mean and std of predictive scores
tp_mean, tp_std, fp_mean, fp_std = compute_score_stats(df_test_list, top_epitopes)

# Plot predictive scores
plot_score_bars(top_epitopes, tp_mean, tp_std, fp_mean, fp_std)

# Plot training sample counts
plot_peptide_counts(df_train, top_epitopes)

plt.show()
../_images/ef12a0ba8af23465649db41f82a72285cb4f1e37ca909160379b37ceb5ef28d7.png ../_images/c2f0333976da5b0b0f6f035570b2c7d7b15852f77e1fc6d5a89d76f98b5e5b2d.png
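
For a single test set, the same per-peptide means can be obtained with a pandas `groupby`. This is a sketch on a made-up table with the same `Epitope`, `label`, and `pred` columns; it is an alternative formulation, not the code used to produce the figures above.

```python
import pandas as pd

# Toy predictions for illustration only
toy = pd.DataFrame({
    "Epitope": ["KLGGALQAK", "KLGGALQAK", "KLGGALQAK", "GILGFVFTL", "GILGFVFTL"],
    "label":   [1, 0, 0, 1, 0],
    "pred":    [0.9, 0.2, 0.4, 0.7, 0.1],
})

# Mean score per (epitope, label), pivoted so each label becomes a column
mean_scores = (
    toy.groupby(["Epitope", "label"])["pred"].mean()
       .unstack("label")
       .rename(columns={1: "positive", 0: "negative"})
)
print(mean_scores)
```

Peptides that have only positive or only negative pairs show up as NaN in the missing column, which mirrors the `np.nanmean` handling in `compute_score_stats`.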

Plot the precision for each peptide across different thresholds

This evaluates whether the model’s high-confidence predictions are indeed true positives.
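
As a small worked example of precision at a threshold, consider six made-up pairs. `precision_at_threshold` is a hypothetical helper; unlike the cell below, this sketch returns `None` when nothing is scored above the threshold, to keep that case distinct from a true precision of zero.

```python
import numpy as np

def precision_at_threshold(y_true, y_score, threshold):
    """TP / (TP + FP) among pairs scored above threshold; None if there are none."""
    y_true = np.asarray(y_true)
    predicted_pos = np.asarray(y_score) > threshold
    n_pred = predicted_pos.sum()
    if n_pred == 0:
        return None
    return float(y_true[predicted_pos].sum() / n_pred)

# Toy labels and scores for illustration only
y_true = [1, 1, 0, 0, 1, 0]
y_score = [0.95, 0.80, 0.75, 0.40, 0.30, 0.10]

for t in (0.2, 0.5, 0.9, 0.99):
    print(t, precision_at_threshold(y_true, y_score, t))
```

Here precision rises from 0.6 at threshold 0.2 to 1.0 at threshold 0.9, the pattern expected when high scores are reliable.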

# Precision is set to 0 when there are no positive predictions at a threshold
import seaborn as sns
from sklearn.metrics import confusion_matrix

thresholds = [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,0.95]
precision_all = []
for pep in top_epitopes:
    df_pep = df_test[df_test['Epitope'] == pep]  # rows for this peptide
    precision_list = []
    for threshold in thresholds:
        y_pred = (df_pep['pred'].values > threshold).astype(int)
        # labels=[0, 1] keeps the matrix 2x2 even if one class is absent
        cm = confusion_matrix(df_pep['label'].values, y_pred, labels=[0, 1])
        TP = cm[1, 1]
        FP = cm[0, 1]
        precision = TP / (TP + FP) if (TP + FP) > 0 else 0
        precision_list.append(precision)
    precision_all.append(precision_list)


# Create a DataFrame
df = pd.DataFrame(precision_all, columns=thresholds, index=top_epitopes)

# Create the heatmap
# Mask zero precision (including thresholds with no positive predictions) as NaN so those cells are left blank
df = df.replace(0, np.nan)
plt.figure(figsize=(10, 5))
sns.heatmap(df, cmap='YlGnBu', annot=True, fmt=".2f", linewidths=0.5)
plt.xlabel('Threshold')
plt.title('Precision scores for common peptides under different thresholds', fontsize=15)
plt.show()
../_images/e477fa9aee33df5d16515cc901d732bf784146052b7c2eeaf970ced66e81cd8f.png