Nature Methods 2024 AI / Foundation Model

scGPT: Foundation Model for Single-Cell Multi-Omics

A transformer-based foundation model pre-trained on 33 million single-cell profiles. Use it for zero-shot cell annotation, gene network inference, multi-batch integration, and perturbation response prediction — without retraining from scratch.

~60 min · Advanced · Python + GPU · scRNA-seq

What is scGPT?

scGPT (Cui et al., Nature Methods 2024) is a generative pre-trained transformer (GPT-style) trained on over 33 million single-cell RNA-seq profiles from the CELLxGENE Discover corpus. It learns a generalised representation of cell biology that can be fine-tuned or used directly for a wide range of downstream tasks:

Cell Type Annotation

Annotate cells in new datasets using pre-trained embeddings — even zero-shot without labelled training data

Batch Integration

Remove batch effects across datasets, technologies, and labs using the shared latent space

Gene Network Inference

Extract gene–gene regulatory relationships from attention weights of the transformer

Perturbation Prediction

Predict transcriptional response to genetic knockouts or drug perturbations

Byte key point
Why it matters: Traditional single-cell pipelines (Seurat, Scanpy) treat each dataset independently. scGPT's pre-trained model carries biological knowledge from 33M cells across tissues, diseases, and organisms — enabling transfer learning that outperforms task-specific models on cell annotation and integration benchmarks.

Architecture Overview

scGPT treats each cell as a "sentence" whose "tokens" are genes paired with their expression values. Unlike a text GPT:

  • Gene tokens are embeddings of gene identities (from gene name/ID), not text words
  • Expression tokens encode binned expression values (0, low, medium, high)
  • Condition tokens encode batch, tissue, disease state
  • A cell embedding [CLS] token aggregates the full cell representation

Pre-training uses a masked gene modelling objective — predict masked gene expression values from context (analogous to BERT's masked language modelling).
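As a rough illustration of the expression-token idea, here is a minimal sketch of quantile-based binning. Note that `bin_expression` is illustrative only, not scGPT's actual tokenizer — the real binning scheme differs in its details:

```python
import numpy as np

def bin_expression(values, n_bins=3):
    """Map non-zero expression values to discrete bins; zero keeps its own token.

    Simplified sketch of value binning: bin edges are quantiles of the
    non-zero values within a single cell, so tokens are comparable across
    cells with very different sequencing depths.
    """
    values = np.asarray(values, dtype=float)
    tokens = np.zeros_like(values, dtype=int)   # token 0 = not expressed
    nonzero = values > 0
    if nonzero.any():
        edges = np.quantile(values[nonzero], np.linspace(0, 1, n_bins + 1))
        # np.digitize against the interior edges gives 0..n_bins-1; shift to 1..n_bins
        tokens[nonzero] = np.clip(
            np.digitize(values[nonzero], edges[1:-1]) + 1, 1, n_bins
        )
    return tokens

cell = [0.0, 0.5, 2.0, 8.0, 0.0]
print(bin_expression(cell))   # [0 1 2 3 0] — zero, low, medium, high
```

The per-cell quantile trick is what makes the discrete vocabulary depth-robust: "high" means high *relative to that cell*, not an absolute count.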

Installation

Byte warning
GPU recommended: scGPT inference is feasible on CPU for small datasets but a GPU (16GB+ VRAM) is strongly recommended for fine-tuning and large datasets.
Bash
# Create conda environment
conda create -n scgpt python=3.10 -y
conda activate scgpt

# Install PyTorch (adjust CUDA version to match your GPU driver)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# Install scGPT
pip install scgpt

# Install supporting packages
pip install scanpy anndata matplotlib seaborn wandb

# Verify
python -c "import scgpt; print(scgpt.__version__)"
Download pre-trained model weights
Bash
# Download the whole-human-body pre-trained model (~1.5 GB)
# Model is hosted on HuggingFace
pip install huggingface_hub

python - <<'EOF'
from huggingface_hub import snapshot_download
snapshot_download(
    repo_id="bowang-lab/scGPT_human",
    local_dir="./scGPT_human",
    ignore_patterns=["*.md"]
)
print("Model downloaded to ./scGPT_human/")
EOF

# Available pre-trained models:
# scGPT_human  - whole human body (recommended for human data)
# scGPT_CP     - cell perturbation model
# scGPT_bc     - blood cells
# scGPT_brain  - brain-specific

Prepare Your Data

scGPT expects AnnData objects (the standard scRNA-Seq format used by Scanpy). Input should be raw counts or lightly normalised data.

Python
import scanpy as sc
import numpy as np

# Load example PBMC dataset (or your own .h5ad)
adata = sc.datasets.pbmc3k()  # 2,700 PBMCs, 32,738 genes

# Standard preprocessing
sc.pp.filter_cells(adata, min_genes=200)
sc.pp.filter_genes(adata, min_cells=3)

# Calculate QC metrics
adata.var["mt"] = adata.var_names.str.startswith("MT-")
sc.pp.calculate_qc_metrics(adata, qc_vars=["mt"], inplace=True)

# Filter cells
adata = adata[adata.obs.n_genes_by_counts < 2500, :]
adata = adata[adata.obs.pct_counts_mt < 5, :]

# Store raw counts (scGPT needs raw counts)
adata.layers["counts"] = adata.X.copy()

# Normalize and log-transform (for visualisation only)
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

print(f"Data shape: {adata.shape}")
# (2638 cells, 13714 genes)
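Since scGPT expects raw counts, a quick sanity check before embedding can save a confusing failure later. `looks_like_raw_counts` below is a hypothetical helper, not part of scGPT; for a sparse matrix, pass a dense slice or inspect `.data` directly:

```python
import numpy as np

def looks_like_raw_counts(matrix, tol=1e-6):
    """Heuristic check that an expression matrix holds raw (non-negative
    integer) counts. Normalised or log-transformed data will generally
    contain non-integer values, which the tokenizer does not expect."""
    x = np.asarray(matrix, dtype=float)
    return bool(np.all(x >= 0) and np.all(np.abs(x - np.round(x)) < tol))

print(looks_like_raw_counts([[0, 3, 1], [2, 0, 5]]))    # True  — raw counts
print(looks_like_raw_counts([[0.0, 1.3], [0.7, 2.5]]))  # False — already normalised
```

In the pipeline above you would run this on `adata.layers["counts"]` (densified for a small subset of cells), since `adata.X` has been normalised and log-transformed at that point.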

1. Cell Type Annotation

The most common scGPT use case — annotate cells in a query dataset using the pre-trained model without requiring labelled training data from the same dataset.

Python
import scgpt
from scgpt.tokenizer.gene_tokenizer import GeneVocab
import torch

# Load model
model_dir = "./scGPT_human"
vocab_file = f"{model_dir}/vocab.json"

# Load gene vocabulary
vocab = GeneVocab.from_file(vocab_file)

# Get cell embeddings using scGPT
embeddings = scgpt.get_batch_cell_embeddings(
    adata,
    cell_embedding_mode="cls",    # use [CLS] token as cell embedding
    model_dir=model_dir,
    vocab=vocab,
    max_length=1200,              # max genes per cell
    batch_size=64,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

# Add embeddings to AnnData
adata.obsm["X_scGPT"] = embeddings
print(f"Embeddings shape: {embeddings.shape}")

# Cluster on scGPT embeddings
sc.pp.neighbors(adata, use_rep="X_scGPT", n_neighbors=15)
sc.tl.umap(adata)
sc.tl.leiden(adata, resolution=0.5)

# Visualise
import matplotlib.pyplot as plt
sc.pl.umap(adata,
           color=["leiden", "CD3D", "MS4A1", "LYZ", "PPBP"],
           ncols=3,
           show=False)
plt.savefig("scgpt_umap.pdf", dpi=300, bbox_inches="tight")
plt.show()
Zero-shot cell annotation against an annotated reference atlas
Python
from scgpt.tasks import CellTypeAnnotation

# Fine-tune with a reference dataset that has known cell types
# Or use zero-shot nearest-neighbour search against the pre-trained space

annotator = CellTypeAnnotation(model_dir=model_dir)

# Load reference (e.g. annotated PBMC atlas)
ref = sc.read_h5ad("pbmc_reference_annotated.h5ad")

# Predict cell types in query
predictions = annotator.predict(
    query=adata,
    reference=ref,
    ref_label_col="cell_type",
    method="knn",     # k-nearest neighbour in embedding space
    k=10
)

adata.obs["scGPT_celltype"] = predictions
sc.pl.umap(adata, color="scGPT_celltype", show=False)
plt.savefig("scgpt_annotation.pdf", dpi=300, bbox_inches="tight")
plt.show()
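The kNN transfer that `method="knn"` performs can be sketched generically. This standalone version (plain NumPy; `knn_label_transfer` is a hypothetical helper, not part of scGPT) shows the idea on toy 2-D embeddings:

```python
import numpy as np
from collections import Counter

def knn_label_transfer(ref_emb, ref_labels, query_emb, k=10):
    """Transfer labels from reference to query cells by majority vote
    among the k nearest reference cells in the shared embedding space.
    In practice ref_emb/query_emb would be the scGPT [CLS] embeddings."""
    ref = np.asarray(ref_emb, dtype=float)
    labels = np.asarray(ref_labels)
    out = []
    for q in np.asarray(query_emb, dtype=float):
        nearest = np.argsort(np.linalg.norm(ref - q, axis=1))[:k]
        out.append(Counter(labels[nearest]).most_common(1)[0][0])
    return out

# Toy 2-D example: two well-separated "cell types"
ref = np.array([[0.0, 0.0], [0.1, 0.2], [5.0, 5.0], [5.2, 4.9]])
labels = ["T cell", "T cell", "B cell", "B cell"]
query = np.array([[0.05, 0.1], [5.1, 5.0]])
print(knn_label_transfer(ref, labels, query, k=2))   # ['T cell', 'B cell']
```

The quality of the transfer depends entirely on how well the embedding separates cell types, which is exactly what the pre-training on 33M cells is meant to buy you.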

2. Multi-Batch Integration

Python
import scanpy as sc
import scgpt
import matplotlib.pyplot as plt

# Assume you have two batches
adata_batch1 = sc.read_h5ad("batch1.h5ad")
adata_batch2 = sc.read_h5ad("batch2.h5ad")

# Concatenate
adata_combined = sc.concat(
    [adata_batch1, adata_batch2],
    label="batch",
    keys=["batch1", "batch2"]
)

# Get scGPT embeddings — the model's shared latent space
# inherently reduces batch effects
embeddings = scgpt.get_batch_cell_embeddings(
    adata_combined,
    cell_embedding_mode="cls",
    model_dir=model_dir,
    batch_size=64
)
adata_combined.obsm["X_scGPT"] = embeddings

# UMAP coloured by batch — should show mixing if integration worked
sc.pp.neighbors(adata_combined, use_rep="X_scGPT")
sc.tl.umap(adata_combined)
sc.pl.umap(adata_combined,
           color=["batch", "cell_type"],
           show=False)
plt.savefig("scgpt_integration.pdf", dpi=300, bbox_inches="tight")

# Quantify integration quality with scib metrics
# pip install scib
import scib
results = scib.metrics.metrics(
    adata_combined,             # unintegrated reference
    adata_combined,             # integrated object (embedding in .obsm)
    batch_key="batch",
    label_key="cell_type",
    embed="X_scGPT",
    nmi_=True,                  # clustering agreement with labels
    silhouette_=True,           # ASW on the embedding
    ilisi_=True                 # batch mixing
)
# Metric names are the DataFrame index, not columns
print(results.loc[["NMI_cluster/label", "ASW_label", "iLISI"]])

3. Gene Regulatory Network Inference

scGPT's attention weights encode relationships between genes in a cell context. These can be extracted to infer gene regulatory networks without requiring TF-binding or chromatin data.
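The principle behind attention-based GRN extraction can be sketched independently of the scGPT API. `attention_to_edges` below is a hypothetical helper that keeps the strongest partners per gene from an already-averaged attention matrix:

```python
import numpy as np

def attention_to_edges(attn, genes, top_k=3):
    """Convert a gene-gene attention matrix into a top-k edge list.

    Assumes attn[i, j] is the (head-averaged) attention weight from gene i
    to gene j; for each gene we keep its top_k strongest partners."""
    attn = np.array(attn, dtype=float)   # copy so the caller's matrix is untouched
    np.fill_diagonal(attn, 0.0)          # ignore self-attention
    edges = []
    for i, src in enumerate(genes):
        for j in np.argsort(attn[i])[::-1][:top_k]:
            edges.append((src, genes[j], float(attn[i, j])))
    return edges   # list of (source, target, score) tuples

attn = np.array([[0.0, 0.9, 0.1],
                 [0.8, 0.0, 0.2],
                 [0.1, 0.7, 0.0]])
edges = attention_to_edges(attn, ["TP53", "MDM2", "CDKN1A"], top_k=1)
print(edges)   # strongest partner per gene
```

Sparsifying to top-k per gene is what keeps the resulting network interpretable; without it, attention produces a dense all-to-all matrix.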

Python
from scgpt.tasks import GeneEmbedding
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt

# Get gene-level embeddings (not cell-level)
gene_emb = GeneEmbedding(model_dir=model_dir)

# Extract attention-based gene-gene similarity
# Focus on highly variable genes
sc.pp.highly_variable_genes(adata, n_top_genes=2000)
hvgs = adata.var_names[adata.var["highly_variable"]].tolist()

# Compute gene-gene attention scores
grn_scores = gene_emb.get_gene_gene_attention(
    adata,
    genes_of_interest=hvgs,
    model_dir=model_dir,
    top_k=20   # keep top 20 connections per gene
)

# Build network
G = nx.from_pandas_edgelist(
    grn_scores,
    source="gene1",
    target="gene2",
    edge_attr="attention_score",
    create_using=nx.DiGraph()
)

# Find hub genes (highest degree)
degree_df = pd.DataFrame({
    "gene": list(dict(G.degree()).keys()),
    "degree": list(dict(G.degree()).values())
}).sort_values("degree", ascending=False)

print("Top 10 hub genes (regulators):")
print(degree_df.head(10).to_string(index=False))

# Visualise subnetwork around top TF
top_tf = degree_df.iloc[0]["gene"]
subgraph = G.subgraph([top_tf] + list(G.successors(top_tf)))
pos = nx.spring_layout(subgraph, seed=42)
fig, ax = plt.subplots(figsize=(10, 8))
nx.draw_networkx(subgraph, pos=pos,
                 node_color="lightblue", node_size=500,
                 font_size=8, arrows=True, ax=ax)
ax.set_title(f"Gene regulatory subnetwork around {top_tf}", fontweight="bold")
plt.tight_layout()
plt.savefig("grn_subnetwork.pdf", dpi=300, bbox_inches="tight")
plt.show()

4. Perturbation Response Prediction

Using the scGPT perturbation model (scGPT_CP), you can predict how cells respond to gene knockouts or overexpression — useful for drug target identification.

Python
# Download perturbation model
from huggingface_hub import snapshot_download
snapshot_download(repo_id="bowang-lab/scGPT_CP", local_dir="./scGPT_CP")

from scgpt.tasks import PerturbationPrediction

predictor = PerturbationPrediction(model_dir="./scGPT_CP")

# Predict effect of knocking out TP53 in cancer cells
predicted_expr = predictor.predict(
    adata=adata,
    perturbation="TP53",        # gene to knock out
    perturbation_type="KO",     # KO = knockout, OE = overexpression
    ctrl_key="condition"        # column in adata.obs marking control cells ("ctrl")
)

# Compare predicted vs observed (if you have Perturb-Seq data)
import numpy as np

obs_ctrl = adata[adata.obs["condition"] == "ctrl"].X.mean(axis=0)
obs_ko   = adata[adata.obs["condition"] == "TP53_KO"].X.mean(axis=0)
pred_ko  = predicted_expr.mean(axis=0)

# Pearson correlation of prediction vs observation
from scipy.stats import pearsonr
r, p = pearsonr(np.asarray(obs_ko).flatten(), np.asarray(pred_ko).flatten())
print(f"Pearson r (predicted vs observed): {r:.3f}  p={p:.2e}")
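Beyond a global correlation, it is often useful to ask which genes the model predicts will shift most. `top_shifted_genes` here is a generic helper (not part of scGPT) that ranks genes by the magnitude of the predicted change:

```python
import numpy as np

def top_shifted_genes(pred_ko, ctrl, gene_names, n=10):
    """Rank genes by absolute predicted expression shift
    (predicted knockout mean minus control mean)."""
    delta = (np.asarray(pred_ko, dtype=float).ravel()
             - np.asarray(ctrl, dtype=float).ravel())
    order = np.argsort(np.abs(delta))[::-1][:n]
    return [(gene_names[i], float(delta[i])) for i in order]

# Toy example with three genes
hits = top_shifted_genes([1.0, 5.0, 2.0], [1.1, 2.0, 2.0], ["A", "B", "C"], n=2)
print(hits)   # gene B first — it has the largest predicted shift
```

In the workflow above you would call it as `top_shifted_genes(pred_ko, obs_ctrl, adata.var_names.tolist())` to surface candidate downstream targets of the knockout.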

5. Fine-Tuning on Your Dataset

For best performance on a specific tissue/disease, fine-tune the pre-trained model on your labelled data.

Python
from scgpt.tasks import GeneExpressionClassifier
import torch

# Requires labelled training data
# adata.obs["cell_type"] must exist and have enough cells per class

# Split into train/validation/test sets (80/10/10)
import numpy as np
idx = np.random.default_rng(0).permutation(adata.n_obs)
n_train = int(0.8 * adata.n_obs)
n_valid = int(0.1 * adata.n_obs)
adata_train = adata[idx[:n_train]].copy()
adata_valid = adata[idx[n_train:n_train + n_valid]].copy()
adata_test  = adata[idx[n_train + n_valid:]].copy()

trainer = GeneExpressionClassifier(
    model_dir=model_dir,
    n_classes=adata.obs["cell_type"].nunique(),
    device="cuda" if torch.cuda.is_available() else "cpu"
)

trainer.train(
    train_adata=adata_train,          # 80% of data
    valid_adata=adata_valid,          # 10% of data
    label_key="cell_type",
    max_epochs=10,
    batch_size=32,
    lr=1e-4,
    save_dir="./scGPT_finetuned/"
)

# Predict on held-out test set
preds = trainer.predict(adata_test)
adata_test.obs["predicted_celltype"] = preds

# Evaluate
from sklearn.metrics import classification_report
print(classification_report(
    adata_test.obs["cell_type"],
    adata_test.obs["predicted_celltype"]
))

scGPT vs CellTypist vs Scimilarity

Feature              | scGPT                                      | CellTypist          | Scimilarity
Publication          | Nature Methods 2024                        | Science 2022        | Nature Methods 2023
Pre-training data    | 33M cells (CELLxGENE)                      | Not pre-trained     | 22M cells
Model type           | Transformer (GPT)                          | Logistic regression | Contrastive VAE
Tasks                | Annotation, integration, GRN, perturbation | Annotation only     | Annotation, integration
Zero-shot annotation | Yes                                        | No                  | Yes
GPU required         | Recommended                                | No                  | Recommended
Speed                | Moderate                                   | Very fast           | Fast
Best for             | Complex multi-task analysis                | Quick annotation    | Large atlas integration

Summary

Byte summary
What you learned:
  • scGPT is a GPT-style transformer pre-trained on 33M single-cell profiles
  • Use get_batch_cell_embeddings() to get cell representations for clustering and UMAP
  • Zero-shot cell annotation via kNN in the shared embedding space
  • Gene regulatory networks extracted from transformer attention weights
  • Perturbation response prediction with the scGPT_CP model
  • Fine-tune on your own labelled data for best tissue-specific performance
Cite this tool
Cui H, Wang C, Maan H, et al. scGPT: toward building a foundation model for single-cell multi-omics using generative AI. Nature Methods. 2024. https://doi.org/10.1038/s41592-024-02201-0