A transformer-based foundation model pre-trained on 33 million single-cell profiles. Use it for zero-shot cell annotation, gene network inference, multi-batch integration, and perturbation response prediction — without retraining from scratch.
scGPT (Cui et al., Nature Methods 2024) is a generative pre-trained transformer (GPT-style) trained on over 33 million single-cell RNA-seq profiles from the CELLxGENE Discover corpus. It learns a generalised representation of cell biology that can be fine-tuned or used directly for a wide range of downstream tasks:
Annotate cells in new datasets using pre-trained embeddings — even zero-shot without labelled training data
Remove batch effects across datasets, technologies, and labs using the shared latent space
Extract gene–gene regulatory relationships from attention weights of the transformer
Predict transcriptional response to genetic knockouts or drug perturbations
scGPT treats each cell as a "sentence" and each gene — together with its expression value — as a "token". Unlike text GPT:
Pre-training uses a masked gene modelling objective — predict masked gene expression values from context (analogous to BERT's masked language modelling).
# Create conda environment
conda create -n scgpt python=3.10 -y
conda activate scgpt

# Install PyTorch (adjust CUDA version to match your GPU driver)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# Install scGPT
pip install scgpt

# Install supporting packages
pip install scanpy anndata matplotlib seaborn wandb

# Verify the install
python -c "import scgpt; print(scgpt.__version__)"
# Download the whole-human-body pre-trained model (~1.5 GB)
# Model is hosted on HuggingFace
pip install huggingface_hub

# The heredoc below runs the download via the huggingface_hub Python API,
# fetching every file in the repo except markdown docs into ./scGPT_human
python - <<'EOF'
from huggingface_hub import snapshot_download
snapshot_download(
repo_id="bowang-lab/scGPT_human",
local_dir="./scGPT_human",
ignore_patterns=["*.md"]  # skip README/docs; only weights + vocab are needed
)
print("Model downloaded to ./scGPT_human/")
EOF

# Available pre-trained models:
# scGPT_human - whole human body (recommended for human data)
# scGPT_CP - cell perturbation model
# scGPT_bc - blood cells
# scGPT_brain - brain-specific
scGPT expects AnnData objects (the standard scRNA-seq format used by Scanpy). Input should be raw counts or lightly normalised data.
import scanpy as sc
import numpy as np

# Load example PBMC dataset (or your own .h5ad)
adata = sc.datasets.pbmc3k()  # 2,700 PBMCs, 32,738 genes

# Standard preprocessing: drop near-empty cells and rarely-detected genes
sc.pp.filter_cells(adata, min_genes=200)
sc.pp.filter_genes(adata, min_cells=3)

# Calculate QC metrics (human mitochondrial genes are prefixed "MT-")
adata.var["mt"] = adata.var_names.str.startswith("MT-")
sc.pp.calculate_qc_metrics(adata, qc_vars=["mt"], inplace=True)

# Filter out likely doublets (too many genes) and dying cells (high mito %).
# Subsetting an AnnData returns a *view*; .copy() materialises it so the
# layer assignment below does not mutate a view of the original object.
adata = adata[
    (adata.obs.n_genes_by_counts < 2500) & (adata.obs.pct_counts_mt < 5), :
].copy()

# Store raw counts BEFORE normalisation (scGPT needs raw counts)
adata.layers["counts"] = adata.X.copy()

# Normalize and log-transform (for visualisation only)
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

print(f"Data shape: {adata.shape}")
# (2638 cells, 13714 genes)
The most common scGPT use case — annotate cells in a query dataset using the pre-trained model without requiring labelled training data from the same dataset.
import scgpt
from scgpt.tasks import GeneEmbedding
from scgpt.tokenizer.gene_tokenizer import GeneVocab
import torch

# Load model
model_dir = "./scGPT_human"
vocab_file = f"{model_dir}/vocab.json"

# Load gene vocabulary (maps gene symbols to the token ids used in pre-training)
vocab = GeneVocab.from_file(vocab_file)

# Get cell embeddings using scGPT.
# NOTE(review): assumes `adata` was prepared as above with raw counts in
# .layers["counts"] — confirm against the installed scgpt version's docs.
embeddings = scgpt.get_batch_cell_embeddings(
adata,
cell_embedding_mode="cls", # use [CLS] token as cell embedding
model_dir=model_dir,
vocab=vocab,
max_length=1200, # max genes per cell
batch_size=64,
device="cuda" if torch.cuda.is_available() else "cpu"
)

# Add embeddings to AnnData so Scanpy can cluster on them
adata.obsm["X_scGPT"] = embeddings
print(f"Embeddings shape: {embeddings.shape}")

# Cluster on scGPT embeddings (neighbour graph built in the scGPT latent space)
sc.pp.neighbors(adata, use_rep="X_scGPT", n_neighbors=15)
sc.tl.umap(adata)
sc.tl.leiden(adata, resolution=0.5)

# Visualise clusters alongside canonical PBMC marker genes
# (CD3D: T cells, MS4A1: B cells, LYZ: monocytes, PPBP: platelets)
import matplotlib.pyplot as plt
sc.pl.umap(adata,
color=["leiden", "CD3D", "MS4A1", "LYZ", "PPBP"],
ncols=3,
show=False)
plt.savefig("scgpt_umap.pdf", dpi=300, bbox_inches="tight")
plt.show()
from scgpt.tasks import CellTypeAnnotation

# Fine-tune with a reference dataset that has known cell types,
# or use zero-shot nearest-neighbour search against the pre-trained space
annotator = CellTypeAnnotation(model_dir=model_dir)

# Load reference (e.g. annotated PBMC atlas)
ref = sc.read_h5ad("pbmc_reference_annotated.h5ad")

# Predict cell types in the query by transferring labels from the reference.
# NOTE(review): assumes ref.obs["cell_type"] exists — verify for your atlas.
predictions = annotator.predict(
query=adata,
reference=ref,
ref_label_col="cell_type",
method="knn", # k-nearest neighbour in embedding space
k=10
)

# Store the transferred labels and plot them on the UMAP computed earlier
adata.obs["scGPT_celltype"] = predictions
sc.pl.umap(adata, color="scGPT_celltype", show=False)
plt.savefig("scgpt_annotation.pdf", dpi=300, bbox_inches="tight")
plt.show()
import scanpy as sc

# Assume you have two batches (e.g. different labs, technologies, or runs)
adata_batch1 = sc.read_h5ad("batch1.h5ad")
adata_batch2 = sc.read_h5ad("batch2.h5ad")

# Concatenate; adds an obs column "batch" with values "batch1"/"batch2"
adata_combined = sc.concat(
[adata_batch1, adata_batch2],
label="batch",
keys=["batch1", "batch2"]
)

# Get scGPT embeddings — the model's shared latent space
# inherently reduces batch effects
# (relies on `scgpt` and `model_dir` defined in the earlier snippets)
embeddings = scgpt.get_batch_cell_embeddings(
adata_combined,
cell_embedding_mode="cls",
model_dir=model_dir,
batch_size=64
)
adata_combined.obsm["X_scGPT"] = embeddings

# UMAP coloured by batch — should show mixing if integration worked
sc.pp.neighbors(adata_combined, use_rep="X_scGPT")
sc.tl.umap(adata_combined)
sc.pl.umap(adata_combined,
color=["batch", "cell_type"],
show=False)
plt.savefig("scgpt_integration.pdf", dpi=300, bbox_inches="tight")

# Quantify integration quality with scib metrics
# pip install scib
import scib

# scib expects (unintegrated, integrated) AnnData objects; the same object is
# passed twice here because the integration lives in the "X_scGPT" embedding.
results = scib.metrics.metrics(
adata_combined,
adata_combined,
batch_key="batch",
label_key="cell_type",
embed="X_scGPT"
)

# ASW_label / NMI score biological conservation; iLISI scores batch mixing
# (see scib docs for score interpretation — higher is better for all three)
print(results[["ASW_label", "iLISI", "NMI_cluster/label"]])
scGPT's attention weights encode relationships between genes in a cell context. These can be extracted to infer gene regulatory networks without requiring TF-binding or chromatin data.
from scgpt.tasks import GeneEmbedding
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt

# Get gene-level embeddings (not cell-level)
gene_emb = GeneEmbedding(model_dir=model_dir)

# Extract attention-based gene-gene similarity.
# Focus on highly variable genes to keep the network tractable
sc.pp.highly_variable_genes(adata, n_top_genes=2000)
hvgs = adata.var_names[adata.var["highly_variable"]].tolist()

# Compute gene-gene attention scores.
# NOTE(review): assumed to return an edge list with columns
# "gene1", "gene2", "attention_score" — those names are consumed below.
grn_scores = gene_emb.get_gene_gene_attention(
adata,
genes_of_interest=hvgs,
model_dir=model_dir,
top_k=20 # keep top 20 connections per gene
)

# Build a directed network from the edge list
G = nx.from_pandas_edgelist(
grn_scores,
source="gene1",
target="gene2",
edge_attr="attention_score",
create_using=nx.DiGraph()
)

# Find hub genes (highest degree = most attention connections in/out)
degree_df = pd.DataFrame({
"gene": list(dict(G.degree()).keys()),
"degree": list(dict(G.degree()).values())
}).sort_values("degree", ascending=False)

print("Top 10 hub genes (regulators):")
print(degree_df.head(10).to_string(index=False))

# Visualise the subnetwork around the top hub: the hub plus its direct targets
top_tf = degree_df.iloc[0]["gene"]
subgraph = G.subgraph([top_tf] + list(G.successors(top_tf)))

pos = nx.spring_layout(subgraph, seed=42)  # fixed seed for reproducible layout
fig, ax = plt.subplots(figsize=(10, 8))
nx.draw_networkx(subgraph, pos=pos,
node_color="lightblue", node_size=500,
font_size=8, arrows=True, ax=ax)
ax.set_title(f"Gene regulatory subnetwork around {top_tf}", fontweight="bold")
plt.tight_layout()
plt.savefig("grn_subnetwork.pdf", dpi=300, bbox_inches="tight")
plt.show()
Using the scGPT perturbation model (scGPT_CP), you can predict how cells respond to gene knockouts or overexpression — useful for drug target identification.
# Download perturbation model
from huggingface_hub import snapshot_download
snapshot_download(repo_id="bowang-lab/scGPT_CP", local_dir="./scGPT_CP")

from scgpt.tasks import PerturbationPrediction

predictor = PerturbationPrediction(model_dir="./scGPT_CP")

# Predict effect of knocking out TP53 in cancer cells.
# NOTE(review): ctrl_key must name the adata.obs column that marks control
# cells — it should be the same "condition" column read below, not a
# separate "control" column; confirm against the installed scgpt API.
predicted_expr = predictor.predict(
    adata=adata,
    perturbation="TP53",        # gene to knock out
    perturbation_type="KO",     # KO = knockout, OE = overexpression
    ctrl_key="condition"        # column in adata.obs indicating control cells
)

# Compare predicted vs observed (if you have Perturb-Seq data)
import numpy as np
from scipy.stats import pearsonr

# Mean expression profile across observed KO cells and across predictions
obs_ko = adata[adata.obs["condition"] == "TP53_KO"].X.mean(axis=0)
pred_ko = predicted_expr.mean(axis=0)

# Pearson correlation of mean predicted vs mean observed expression
r, p = pearsonr(np.asarray(obs_ko).flatten(), np.asarray(pred_ko).flatten())
print(f"Pearson r (predicted vs observed): {r:.3f} p={p:.2e}")
For best performance on a specific tissue/disease, fine-tune the pre-trained model on your labelled data.
from scgpt.tasks import GeneExpressionClassifier
import torch
import numpy as np

# Requires labelled training data:
# adata.obs["cell_type"] must exist and have enough cells per class.

# Build a reproducible 80/10/10 train/valid/test split (the original snippet
# referenced adata_train / adata_valid / adata_test without defining them).
rng = np.random.default_rng(0)
order = rng.permutation(adata.n_obs)
n_train = int(0.8 * adata.n_obs)
n_valid = int(0.9 * adata.n_obs)
adata_train = adata[order[:n_train]].copy()
adata_valid = adata[order[n_train:n_valid]].copy()
adata_test = adata[order[n_valid:]].copy()

trainer = GeneExpressionClassifier(
    model_dir=model_dir,
    n_classes=len(adata.obs["cell_type"].unique()),
    device="cuda" if torch.cuda.is_available() else "cpu"  # no hard GPU requirement
)

trainer.train(
    train_adata=adata_train,
    valid_adata=adata_valid,
    label_key="cell_type",
    max_epochs=10,
    batch_size=32,
    lr=1e-4,
    save_dir="./scGPT_finetuned/"
)

# Predict on the held-out test set
preds = trainer.predict(adata_test)
adata_test.obs["predicted_celltype"] = preds

# Evaluate per-class precision/recall/F1 against the true labels
from sklearn.metrics import classification_report
print(classification_report(
    adata_test.obs["cell_type"],
    adata_test.obs["predicted_celltype"]
))
| Feature | scGPT | CellTypist | Scimilarity |
|---|---|---|---|
| Publication | Nature Methods 2024 | Science 2022 | Nature Methods 2023 |
| Pre-training data | 33M cells (CELLxGENE) | Not pre-trained | 22M cells |
| Model type | Transformer (GPT) | Logistic regression | Contrastive VAE |
| Tasks | Annotation, integration, GRN, perturbation | Annotation only | Annotation, integration |
| Zero-shot annotation | Yes | No | Yes |
| GPU required | Recommended | No | Recommended |
| Speed | Moderate | Very fast | Fast |
| Best for | Complex multi-task analysis | Quick annotation | Large atlas integration |
Use get_batch_cell_embeddings() to obtain cell representations for clustering and UMAP.