The Annotated AlphaFold2

A line-by-line walkthrough of Jumper et al. 2021, with a minimal PyTorch implementation you can actually run.

Paper
John Jumper, Richard Evans, Alexander Pritzel, et al. (2021)
Annotated by
Chris Hayduk
Published
April 21, 2026
Reading time
~90 min

Over the past couple of years, I’ve been on a journey to understand protein language models — what they actually learn, how they work, and why they represent one of the two largest breakthroughs in the history of AI (alongside GPT-3-style language models). Along the way, I wrote minAlphaFold2, a pedagogical PyTorch reimplementation of AlphaFold2 designed, from the first line, to be readable alongside the paper. Roughly 8,000 lines across 17 modules. Variable names that match the paper’s notation. Comments that cite the supplement’s algorithm numbers inline. Roughly 95% of the paper implemented (the main omission is self-distillation — I’ll be honest about that in §13). 130 passing tests that cover shapes, loss values, and the training loop end-to-end.

The inspiration for minAlphaFold2 was twofold. First, the AlphaFold2 supplementary paper is, in my view, one of the cleanest pieces of technical writing in modern AI — 62 pages of pseudocode, diagrams, and prose that essentially demands a clear code counterpart. Second, Andrej Karpathy’s minGPT showed what a pedagogical reimplementation could look like for GPT. The pedagogical material available for AF2 was essentially non-existent by comparison, and I wanted to do a small part in changing that.

This annotation is the other half of that work: a guided walkthrough of the repo that explains why each piece looks the way it does, one section of the paper at a time. The prose follows AlphaFold2 section by section, but every algorithmic claim is tied back to real, runnable code — specifically, the code that lives in minAlphaFold2 at the pinned commit this site builds against. No pseudocode shortcuts, no hand-waving about what the paper “probably” means.

If you’d like to follow along, clone the repo and run pip install -e . from its root. The scripts/overfit_single_pdb.py trainer runs on a laptop CPU and takes a small protein from random-init to ~2 Å Cα RMSD in about thirty seconds. We’ll walk through that exact run, step by step, in §14. Let’s dive in!

0. Prelims

Before we start, let’s do a quick orientation to minAlphaFold2 — the reference codebase this annotation walks through — and the conventions we’ll use throughout.

The goal of minAlphaFold2 was to make AlphaFold2 legible. The full DeepMind implementation is roughly 50k lines of JAX spread across a dense internal production codebase; minAlphaFold2 is ~8k lines of pure PyTorch written specifically to be read alongside the paper. Every design decision in the repo is downstream of that goal:

Variable names follow the paper. m is the MSA representation, shape $(s, r, c_m)$, where $s$ is the number of sequences, $r$ is the number of residues, and $c_m$ is the channel dimension. z is the pair representation, shape $(r, r, c_z)$. s_i is the single representation in the Structure Module. T_i is a rigid frame. If you have the paper open in a tab, you can cross-reference any tensor in the code directly against its symbol in the supplement.

Algorithms are annotated inline. Every class and function in minAlphaFold2 cites the supplement algorithm it implements — comments like # Algorithm 7 appear right above the forward method. When I say “Algorithm 7” below, you can grep for it in the repo and land on the corresponding class. Shapes are asserted at every forward pass, and the tests run those assertions end-to-end.

Three config profiles live under configs/:

  • tiny.toml — about 1M parameters, CPU-runnable, good for unit tests and for the overfit walkthrough in §14.
  • medium.toml — for overfit experiments on a single GPU.
  • alphafold2.toml — the full 48-block, paper-spec config.

The walkthrough below will mostly stay in the abstract, but when we get to the real example in §14, we’ll load tiny.toml and watch the structure module actually converge.

1. Background

AlphaFold2 was not the first deep-learning approach to protein structure prediction, but it was the first one that worked — at least in the sense of producing structures accurate enough to replace an experimental method for a large fraction of proteins. The reasons it worked have more to do with how biology is arranged than with any single architectural trick, so it’s worth starting there.

1.1 The folding problem

A protein is a chain of amino acids. There are 20 standard ones (plus two rarer proteinogenic additions, selenocysteine and pyrrolysine), each a small organic molecule with a common backbone and a distinct side chain. The cell transcribes DNA into RNA and translates RNA into a chain of these amino acids — a sequence, essentially a string over a 20-letter alphabet. That’s the primary structure.

The primary structure then folds, under thermodynamic constraints, into a specific three-dimensional shape — the tertiary structure. That shape is what the protein does: the arrangement of side chains in 3D determines which other molecules it binds, which reactions it catalyzes, which signals it transduces. In other words, function follows structure, and structure follows sequence.

The folding problem is the problem of going from sequence to structure computationally. It turns out to be hard for a reason that Cyrus Levinthal pointed out in 1969: a sequence of length 100 has something like $3^{100}$ possible conformations if each residue has three rotational states. You cannot brute-force the search. And yet, somehow, proteins fold on millisecond timescales inside cells, reliably and reproducibly. The answer has to be that the landscape is biased — that proteins don’t sample uniformly but are instead funneled, by the physics of their interactions, toward a single energy minimum.
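The arithmetic is worth seeing once. A back-of-envelope sketch in pure Python, where the three-states-per-residue assumption is Levinthal’s and the sampling rate is an illustrative guess, not a measured number:

```python
# Levinthal's back-of-envelope: 3 rotational states per residue,
# a 100-residue chain, and a (very optimistic) sampling rate.
n_conformations = 3 ** 100            # ~5.15e+47 possible conformations

rate = 1e13                            # conformations sampled per second (illustrative)
seconds_per_year = 3.15e7

years_to_enumerate = n_conformations / rate / seconds_per_year
print(f"{n_conformations:.2e} conformations")
print(f"~{years_to_enumerate:.1e} years to enumerate")   # ~1.6e+27 years
```

Even with absurdly generous sampling assumptions, exhaustive search takes many orders of magnitude longer than the age of the universe, while real proteins fold in milliseconds.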

For the last 50 years, the field has tried to model that funnel directly — with physics simulations, statistical potentials, fragment libraries, and every kind of approximation in between. None of them worked well enough, at scale, to be useful. AlphaFold2 skipped the physics entirely and learned the funnel from data.

1.2 MSAs and evolutionary coupling

The data, it turns out, is not just the target sequence itself. The key insight, which predates AlphaFold2 by decades, is that related proteins from different species share the same fold, and the mutations they accumulate over evolutionary time are not independent. When residues $i$ and $j$ are in contact in 3D, a mutation at $i$ that disrupts the contact exerts selection pressure for a compensating mutation at $j$. Over millions of years and thousands of species, that pressure leaves a statistical signature — namely, the columns for residues $i$ and $j$ in a multiple sequence alignment co-vary.

A multiple sequence alignment, or MSA, is exactly what it sounds like — take the target sequence, find a bunch of evolutionarily related sequences, and line them up so that homologous residues sit in the same column.

Homo sapiens     M K L L P V L - - T A L L A L C W L A
Mus musculus     M K L L P V L - - T A L L A L C W L A
Danio rerio      M K L L T V L - - T A L L A L C W L A
Drosophila       M L - - P V L P T T A L L A L T W - A
Saccharomyces    M K V L - V L A T T S L L T L S W M G
                 ^             ^ ^           ^
             conserved        gap        variable

A fragment of an MSA. Each row is a protein from a different organism; each column is a position. Mutations show up as column variation; insertions and deletions as gaps (dashes).

From an MSA, you can infer a lot. Column variation tells you which residues are conserved (evolution has decided they matter) and which are free to drift. Correlated columns tell you which residues are coupled — the hallmark of contact in 3D. Phylogenetic structure tells you which species the signal is coming from. This is why AlphaFold2 treats the MSA as a first-class input rather than as an ancillary feature.

AlphaFold2 builds its MSAs by searching four databases with two tools (HHblits for sensitivity, JackHMMER for coverage):

  • MGnify — metagenomic sequences, where a lot of novel protein diversity lives.
  • UniRef90 — clustered UniProt, the main protein sequence database.
  • Uniclust30 — another clustering of UniProt, at 30% sequence identity.
  • BFD — Big Fantastic Database, billions of sequences assembled from raw metagenomic reads.

For our purposes, the MSA pipeline itself is out of scope — minAlphaFold2 takes a pre-built MSA as input. What matters is the shape of what comes out: a tensor of one-hot-ish per-residue features for each aligned sequence, typically a few hundred to a few thousand sequences deep. Everything we do from here is downstream of that tensor.
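A minimal sketch of that featurization, assuming a toy 21-token vocabulary (20 amino acids plus gap) rather than the repo’s real 49-channel build_msa_feat:

```python
import numpy as np

# Toy MSA featurization: one-hot each aligned sequence over a 21-token
# vocabulary (20 amino acids + '-' for gaps). The real AF2 features add
# cluster profiles and deletion counts on top of this.
VOCAB = "ACDEFGHIKLMNPQRSTVWY-"          # 21 tokens
TOK = {c: i for i, c in enumerate(VOCAB)}

def one_hot_msa(msa):
    """List of s aligned length-r strings -> (s, r, 21) float array."""
    s, r = len(msa), len(msa[0])
    out = np.zeros((s, r, len(VOCAB)), dtype=np.float32)
    for si, seq in enumerate(msa):
        for ri, aa in enumerate(seq):
            out[si, ri, TOK[aa]] = 1.0
    return out

msa = ["MKLLPV", "MKLLTV", "ML--PV"]     # 3 sequences, 6 columns
feat = one_hot_msa(msa)
print(feat.shape)                        # (3, 6, 21)
```

Everything downstream of the search pipeline operates on tensors of this general shape: sequences × residues × feature channels.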

If you’d like to go deeper on MSAs specifically — how AF2 uses row- and column-wise attention on the MSA representation, and why the outer product mean is the right way to push that information into the pair representation — I’ve written a three-part primer on protein language models that covers the MSA pipeline in Part I and the transformer-based replacement (ESMFold) in Parts II and III. This annotation will cover all of that ground too, but from the minAlphaFold2 codebase up rather than from the paper down.

1.3 Templates

Templates are the second source of prior knowledge AlphaFold2 uses. If an evolutionarily related protein already has a solved 3D structure — deposited in the Protein Data Bank — we can use its coordinates as a starting hint. AlphaFold2 finds up to four templates via HHSearch against PDB70, featurizes them into pair-like embeddings, and folds them into the same pipeline as the MSA.

In practice, templates matter less than you’d expect. For proteins with deep MSAs, templates add maybe a point or two of accuracy on average; for proteins with shallow MSAs, templates help a bit more. But the model works without them — in the ablations, the MSA is doing most of the heavy lifting. We’ll treat templates as a detail of the input pipeline and cover them briefly in §8.

1.4 Rigid frames and SE(3)

The output of AlphaFold2 is a 3D structure, but the intermediate representation isn’t quite Cartesian coordinates. Instead, each residue has an associated local frame — a position $t_i \in \mathbb{R}^3$ and a rotation $R_i \in \mathrm{SO}(3)$ that together specify a little coordinate system attached to the backbone. The whole structure is a sequence of these frames, one per residue.

This matters for two reasons. First, frames are the natural object for expressing geometric invariances. The physics of protein folding doesn’t care about the global position or orientation of the molecule — if you translate and rotate the entire structure, everything still works. That’s SE(3) invariance, and any loss or model component that respects it has a much easier time generalizing than one that doesn’t. Second, frames give you a way to express local geometry cleanly: the position of residue $j$ as seen by residue $i$ is just $T_i^{-1} \cdot t_j$ — the translation $t_j$ expressed in residue $i$’s local coordinate system. That “as seen by” operation is the core of Invariant Point Attention, which we’ll get to in §10.
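The “as seen by” operation is easy to sanity-check numerically. A sketch with numpy — the frame construction here is deliberately simplified (real AF2 builds each frame from the N, Cα, and C backbone atoms):

```python
import numpy as np

def rot_z(theta):
    """Rotation by theta about the z-axis (a convenient member of SO(3))."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def local_view(R_i, t_i, t_j):
    """Position of residue j as seen from residue i's frame: T_i^{-1} . t_j."""
    return R_i.T @ (t_j - t_i)

# A toy frame for residue i and a position for residue j.
R_i = rot_z(0.7)
t_i = np.array([1.0, 2.0, 3.0])
t_j = np.array([4.0, -1.0, 0.5])

before = local_view(R_i, t_i, t_j)

# Apply the same global rotation + translation to everything...
R_g, t_g = rot_z(1.9), np.array([-5.0, 8.0, 2.0])
after = local_view(R_g @ R_i, R_g @ t_i + t_g, R_g @ t_j + t_g)

# ...and the local view is unchanged: SE(3) invariance.
assert np.allclose(before, after)
```

Anything computed purely from these local views (as Invariant Point Attention is) inherits the invariance for free.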

For §§1–9, we mostly won’t need to touch frames. They’ll show up properly when we hit the Structure Module. But it’s worth keeping in mind, when we’re looking at the Evoformer updating an abstract pair representation $z$, that the whole apparatus is a wind-up — an elaborate way of preparing the inputs that the Structure Module will need in order to place atoms in space.

1.5 What AlphaFold2 changed

Before diving in, it’s useful to name what was actually new here. The rough outline of a contact-prediction-plus-folding pipeline had existed for years — run an MSA, extract co-evolution signals, predict a residue–residue contact map, then fold the protein by fragment assembly or distance-based minimization. AlphaFold1 was already doing something like that. The pieces AlphaFold2 added:

  1. End-to-end training. No separate “predict contacts, then fold” pipeline. The entire system, from MSA input to atom coordinates, is one computation graph with one loss.
  2. The Evoformer. A transformer-style stack that cross-attends between the MSA representation and the pair representation, letting each update the other over 48 blocks.
  3. Invariant Point Attention. A geometric attention variant that operates on rigid frames, respects SE(3) invariance by construction, and lets the Structure Module reason about 3D without losing equivariance.
  4. FAPE loss. A frame-aligned loss that measures error per-residue-frame rather than in a global coordinate system, giving every residue a local view of the structure it’s trying to build.
  5. Recycling. A mechanism where the model’s own output predictions get fed back as input for another forward pass, enabling iterative refinement at both train and test time.

Each of these gets a section. The rest of the annotation is, essentially, a line-by-line tour of how these five ideas become executable PyTorch.

2. Input representations

AlphaFold2 takes in three kinds of raw data — the target sequence, the MSA, and the templates — and distills them into two working representations that flow through the rest of the network:

  • the MSA representation $m$, shape $(s, r, c_m)$, where $s$ is the number of sequences in the alignment and $r$ is the number of residues.
  • the pair representation $z$, shape $(r, r, c_z)$ — one embedding for every pair of residues in the target sequence.

You can think of $m$ as “what we know about each residue, as seen through evolutionary history” and $z$ as “what we know about each pair of residues.” The whole Evoformer is built around passing information back and forth between these two — using the pair representation to guide attention within the MSA, and using the MSA to refine the pair representation via the outer product mean. We’ll get there.

2.1 Building the features

The feature builders that turn raw inputs (an integer-encoded sequence, a stack of aligned MSA sequences, a set of template structures) into embedding-ready tensors are a handful of standalone functions. They’re mostly one-hot encoding and concatenation, but they matter because they fix what the network gets to see. The headliners:

  • build_target_feat takes the target integer-encoded sequence and produces a per-residue feature vector: a one-hot over the 21 amino-acid types (20 + unknown), plus a per-residue “deletion value” and a few binary flags. Dimension 22.
  • build_msa_feat takes the full MSA and produces a $(s, r, 49)$ tensor — for each sequence and each residue, a one-hot over 23 MSA tokens (the 20 amino acids, plus gap, mask, and unknown), a cluster profile averaged across the sequences in that cluster, plus features recording which residues had gap deletions and how many. This is where evolutionary information enters the model.
  • build_extra_msa_feat does the same thing for the “extra MSA” stack — a larger, cheaper-to-process slice of the alignment that gets its own shallower Evoformer (covered in §8).
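To make the general shape of these builders concrete, here’s a hypothetical, stripped-down stand-in for build_target_feat (the real function packs the extra per-residue channels described above; this sketch just reserves a slot for them):

```python
import numpy as np

AA = "ACDEFGHIKLMNPQRSTVWYX"            # 20 amino acids + X for unknown

def toy_target_feat(sequence):
    """Hypothetical sketch: length-r string -> (r, 22) feature array.

    Channels 0-20: one-hot amino-acid type (unknown maps to X).
    Channel 21: placeholder for the per-residue extras the real builder adds.
    """
    r = len(sequence)
    feat = np.zeros((r, 22), dtype=np.float32)
    for i, aa in enumerate(sequence):
        feat[i, AA.index(aa if aa in AA else "X")] = 1.0
    return feat

feat = toy_target_feat("MKVLAX")
print(feat.shape)                        # (6, 22)
```

The point is only the contract: strings in, fixed-width float tensors out, with the feature layout frozen before any learned parameters ever see the data.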

2.2 The input embedder

Once we have the feature tensors, the InputEmbedder block projects them into $m$ and $z$ (Algorithm 3 of the supplement). The math is straightforward: three linear projections of the target features plus an outer-sum to produce a per-pair embedding for $z$; one linear projection of the MSA features to produce every row of $m$; and a tile of the target features across all rows of $m$, added in.

class InputEmbedder(torch.nn.Module):
    """Initial MSA + pair embedding (Algorithm 3).

    Produces the starting ``m_si`` and ``z_ij`` for the Evoformer by
    combining:

    * three linear projections of ``target_feat`` (shape
      ``(batch, N_res, TARGET_FEAT_DIM)``) — two broadcast into the
      outer-sum for ``z`` and one added to the query row of ``m``;
    * a relative-positional encoding ``RelPos(residue_index)`` added to
      ``z`` (Algorithm 4);
    * a linear projection of ``msa_feat`` (49 channels — cluster profile
      + deletion features, per Table 1) added to every MSA row in ``m``.

    Output shapes: ``m`` ``(batch, N_cluster, N_res, c_m)``, ``z``
    ``(batch, N_res, N_res, c_z)``.
    """

    def __init__(self, config):
        super().__init__()

        self.linear_target_feat_1 = torch.nn.Linear(in_features=TARGET_FEAT_DIM, out_features=config.c_z)
        self.linear_target_feat_2 = torch.nn.Linear(in_features=TARGET_FEAT_DIM, out_features=config.c_z)
        self.linear_target_feat_3 = torch.nn.Linear(in_features=TARGET_FEAT_DIM, out_features=config.c_m)


        self.linear_msa = torch.nn.Linear(in_features=49, out_features=config.c_m)
        init_linear(self.linear_target_feat_1, init="default")
        init_linear(self.linear_target_feat_2, init="default")
        init_linear(self.linear_target_feat_3, init="default")
        init_linear(self.linear_msa, init="default")

        self.rel_pos = RelPos(config)

    def forward(self, target_feat: torch.Tensor, residue_index: torch.Tensor, msa_feat: torch.Tensor):
        # target_feat shape: (batch, N_res, 22)
        # residue_index shape: (batch, N_res)
        # msa_feat shape: (batch, N_cluster, N_res, 49)
        assert target_feat.ndim == 3 and target_feat.shape[-1] == TARGET_FEAT_DIM, \
            f"target_feat must be (batch, N_res, {TARGET_FEAT_DIM}), got {target_feat.shape}"
        assert residue_index.ndim == 2, \
            f"residue_index must be (batch, N_res), got {residue_index.shape}"
        assert msa_feat.ndim == 4 and msa_feat.shape[-1] == 49, \
            f"msa_feat must be (batch, N_cluster, N_res, 49), got {msa_feat.shape}"

        # Output shape: (batch, N_res, c_z)
        a = self.linear_target_feat_1(target_feat)
        b = self.linear_target_feat_2(target_feat)

        # Output shape: (batch, N_res, N_res, c_z)
        # Row i should use element i from a, and col j should use element j from b
        z = a.unsqueeze(-2) + b.unsqueeze(-3)

        z += self.rel_pos(residue_index)

        # Output shape: (batch, N_cluster, N_res, c_m)
        m = self.linear_target_feat_3(target_feat).unsqueeze(1) + self.linear_msa(msa_feat)

        return m, z

Two details are worth lingering on. First, the pair representation is initialized from an outer-sum of the target features, not from the MSA. The MSA features arrive later, via the outer-product mean inside the Evoformer. Right after the embedder, $z$ encodes nothing about co-evolution — it’s a clean slate that the Evoformer will fill in. Second, the target features are tiled across every row of $m$, including rows corresponding to other MSA sequences. This is the model’s way of telling each row “you are aligned to this target,” which gives the row attention and column attention operations a shared coordinate system to reason in.
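The outer-sum broadcast in the embedder (z = a.unsqueeze(-2) + b.unsqueeze(-3)) is worth seeing in isolation. A quick numpy check, on toy shapes rather than the repo’s tensors, that every pair entry is $z_{ij} = a_i + b_j$:

```python
import numpy as np

r, c_z = 5, 4
a = np.random.randn(r, c_z)             # stand-in for projection 1 of target_feat
b = np.random.randn(r, c_z)             # stand-in for projection 2 of target_feat

# Outer-sum via broadcasting: row i contributes a[i], column j contributes b[j].
z = a[:, None, :] + b[None, :, :]       # shape (r, r, c_z)

# Verify against the explicit double loop.
for i in range(r):
    for j in range(r):
        assert np.allclose(z[i, j], a[i] + b[j])
```

Using two separate projections (rather than one shared a) lets the model break the $i \leftrightarrow j$ symmetry from the very first layer.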

2.3 Relative position encoding

The second piece of the input-embedding stage is the relative position encoding (Algorithm 4). Like the positional encoding in a vanilla Transformer, this tells the network where each residue sits in the chain — but AlphaFold2 encodes relative positions rather than absolute ones, which matches the physics (only relative positions along the chain affect structure) and generalizes better to lengths unseen at training time.

class RelPos(torch.nn.Module):
    """Relative-position encoding (Algorithm 4).

    One-hots the clipped residue-index difference
    ``clamp(r_i - r_j, -max_rel, max_rel)`` into ``2·max_rel+1`` bins
    (default ``max_rel = 32`` → 65 bins) and projects to ``c_z``. The
    output is added to the pair representation by :class:`InputEmbedder`
    so the Evoformer trunk has a learned sense of residue adjacency from
    the very first block. Clipping at ±32 matches the supplement.
    """

    def __init__(self, config, max_rel=32):
        super().__init__()
        self.max_rel = max_rel
        self.linear = torch.nn.Linear(2 * max_rel + 1, config.c_z)
        init_linear(self.linear, init="default")

    def forward(self, residue_index: torch.Tensor):
        # residue_index shape: (batch, N_res)
        d = residue_index[:, :, None] - residue_index[:, None, :]  # (batch, N_res, N_res)
        d = d.clamp(-self.max_rel, self.max_rel) + self.max_rel
        oh = torch.nn.functional.one_hot(d.long(), 2 * self.max_rel + 1).float()
        return self.linear(oh)  # (batch, N_res, N_res, c_z)

The encoding is a clipped one-hot over $r_i - r_j$ within a window of $\pm 32$ residues (anything further gets bucketed to the edges), followed by a linear projection to $c_z$. It gets added directly into $z$. Simple, but critical — without it, the pair representation has no notion of sequence adjacency, and the attention heads have to rediscover the chain from the pair features alone, which doesn’t really work.
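To see the clipping concretely, here’s the bin computation on its own, in pure Python with max_rel = 32 as in the module above:

```python
# Clipped relative-position bin: clamp(r_i - r_j, -32, 32) + 32, giving
# a bin index in [0, 64] (65 bins total).
MAX_REL = 32

def relpos_bin(r_i, r_j):
    d = max(-MAX_REL, min(MAX_REL, r_i - r_j))
    return d + MAX_REL

print(relpos_bin(10, 10))    # 32  (same residue -> centre bin)
print(relpos_bin(10, 9))     # 33  (immediate neighbour)
print(relpos_bin(0, 200))    # 0   (far in one direction, clipped)
print(relpos_bin(200, 0))    # 64  (far in the other, clipped)
```

Every separation beyond 32 residues lands in one of the two edge bins, so the network gets fine-grained position information locally and only a coarse “far away” signal globally.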

3. The Evoformer — bird’s-eye view

The Evoformer is the core of AlphaFold2. It’s a stack of 48 identical blocks, each one taking an MSA representation $m$ and a pair representation $z$, applying a fixed sequence of updates, and passing the refined $(m, z)$ to the next block. By the end of the stack, $m$ and $z$ are both dense with structure-relevant information and ready for the Structure Module to consume.

class Evoformer(torch.nn.Module):
    """Evoformer block (Algorithm 6).

    One full iteration over the paired MSA + pair representations:

    1. MSA row-wise attention with pair bias (Alg 7) + row-wise dropout;
    2. MSA column-wise attention (Alg 8) — no dropout;
    3. MSA transition (Alg 9) — no dropout;
    4. Pair update from the MSA via outer-product mean (Alg 10);
    5. Triangle multiplicative updates outgoing/incoming (Alg 11/12) +
       row-wise dropout;
    6. Triangle self-attention around the starting/ending node (Alg 13/14)
       — row-wise dropout on starting, column-wise on ending
       (supplement 1.11.6);
    7. Pair transition (Alg 15) — no dropout.

    Dropout rates are read from ``config.evoformer_msa_dropout`` and
    ``config.evoformer_pair_dropout`` so the Template pair stack
    (:class:`minalphafold.embedders.TemplatePair`) and the Extra MSA stack
    (:class:`minalphafold.embedders.ExtraMsaStack`) can reuse the same
    sub-modules with different dropout schedules. The block is stacked
    ``config.num_evoformer_blocks`` times in :class:`minalphafold.model.AlphaFold2`.
    """

    def __init__(self, config):
        super().__init__()
        self.msa_row_att = MSARowAttentionWithPairBias(config)
        self.msa_col_att = MSAColumnAttention(config)
        self.msa_transition = MSATransition(config)
        self.outer_mean = OuterProductMean(config)

        self.triangle_mult_out = TriangleMultiplicationOutgoing(config)
        self.triangle_mult_in = TriangleMultiplicationIncoming(config)
        self.triangle_att_start = TriangleAttentionStartingNode(config)
        self.triangle_att_end = TriangleAttentionEndingNode(config)

        self.pair_transition = PairTransition(config)

        # Dropout rates from config
        self.msa_dropout = config.evoformer_msa_dropout
        self.pair_dropout = config.evoformer_pair_dropout

    def forward(self, msa_representation: torch.Tensor, pair_representation: torch.Tensor,
                msa_mask: Optional[torch.Tensor] = None, pair_mask: Optional[torch.Tensor] = None):
        # msa_mask: (batch, N_seq, N_res) — 1 for valid, 0 for padding
        # pair_mask: (batch, N_res, N_res) — 1 for valid, 0 for padding
        assert msa_representation.ndim == 4, \
            f"msa_representation must be (batch, N_seq, N_res, c_m), got {msa_representation.shape}"
        assert pair_representation.ndim == 4, \
            f"pair_representation must be (batch, N_res, N_res, c_z), got {pair_representation.shape}"
        # Shape (batch, N_seq, N_res, c_m)
        msa_update = self.msa_row_att(msa_representation, pair_representation, msa_mask=msa_mask)
        msa_representation = msa_representation + dropout_rowwise(msa_update, p=self.msa_dropout, training=self.training)

        # No dropout on column attention or MSA transition per Algorithm 6
        msa_representation = msa_representation + self.msa_col_att(msa_representation, msa_mask=msa_mask)
        msa_representation = msa_representation + self.msa_transition(msa_representation)

        pair_representation = pair_representation + self.outer_mean(msa_representation, msa_mask=msa_mask)

        pair_representation = pair_representation + dropout_rowwise(self.triangle_mult_out(pair_representation, pair_mask=pair_mask), p=self.pair_dropout, training=self.training)
        pair_representation = pair_representation + dropout_rowwise(self.triangle_mult_in(pair_representation, pair_mask=pair_mask), p=self.pair_dropout, training=self.training)
        pair_representation = pair_representation + dropout_rowwise(self.triangle_att_start(pair_representation, pair_mask=pair_mask), p=self.pair_dropout, training=self.training)
        pair_representation = pair_representation + dropout_columnwise(self.triangle_att_end(pair_representation, pair_mask=pair_mask), p=self.pair_dropout, training=self.training)
        # No dropout on pair transition per Algorithm 6
        pair_representation = pair_representation + self.pair_transition(pair_representation)

        return msa_representation, pair_representation

Each block does, in order: (1) MSA row-wise attention biased by $z$, (2) MSA column-wise attention, (3) an MSA transition MLP, (4) an outer-product mean that writes MSA information back into $z$, (5) triangle multiplicative updates (outgoing and incoming) on $z$, (6) triangle self-attention (starting- and ending-node) on $z$, and (7) a pair transition MLP. That’s seven sub-blocks per Evoformer block, 48 blocks total, and every sub-block has a dropout pattern matched to its shape.

It’s worth noticing how lopsided this is. Of the seven sub-blocks, only the first three operate on $m$; the other four operate on $z$. The MSA stack is narrow and mostly one-way (MSA → pair via the outer product mean). The pair stack is where the geometry happens — triangle consistency, triangle-biased attention — and it eats most of the block. In essence, the whole design is: extract co-evolution signal from the MSA, push it into the pair representation, then spend most of the block refining the pair representation into something the Structure Module can use.

Two structural things to flag before we dive in. First, the MSA stack terminates at the end of the Evoformer — only the first row of $m$ (the target sequence’s representation) gets passed forward to the Structure Module. The other $s-1$ rows have done their job; they’ve pushed their evolutionary signal through the outer product mean and into $z$, and we don’t need them anymore. Second, recycling (§9) wraps around the whole thing: the output of block 48 becomes part of the input of block 1 in the next cycle, and we do this up to four times per forward pass. Hence, “48 blocks” is really “up to 192 blocks of compute,” just with weights shared across cycles.
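The recycling wrapper can be sketched as a plain loop with shared weights. The functions below are toy stand-ins, not the repo’s code; the real model also recycles predicted pair distances and stops gradients so that only one cycle backpropagates, which we’ll cover properly in §9:

```python
# Toy sketch of recycling: the same 48-block trunk is applied up to 4
# times, with each cycle's output folded back in as the next cycle's input.
NUM_BLOCKS, NUM_CYCLES = 48, 4
calls = 0

def evoformer_block(m, z):
    """Stand-in for one Evoformer block; just counts invocations."""
    global calls
    calls += 1
    return m + 1, z + 1

def trunk(m, z):
    for _ in range(NUM_BLOCKS):          # same weights every cycle
        m, z = evoformer_block(m, z)
    return m, z

m, z = 0, 0
for _ in range(NUM_CYCLES):
    m, z = trunk(m, z)                   # previous outputs become inputs

print(calls)                             # 192 = 48 blocks x 4 cycles
```

Because the weights are shared, recycling multiplies compute depth without multiplying parameter count.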

Now let’s open up the Evoformer block sub-block by sub-block.

4. MSA row attention with pair bias

This is Algorithm 7 of the supplement, and it’s the most important single operation in the model. Row-wise attention means that, for each sequence $s$ in the MSA, attention runs along the row — residues of the same sequence attend to each other. The “with pair bias” part means we inject information from $z$ into the attention weights as a learned bias term.

Concretely, for a sequence $s$, the attention weight between residues $i$ and $j$ under head $h$ is:

$$a_{sij}^h = \mathrm{softmax}_j\!\left(\frac{1}{\sqrt{c}}\, q^h_{si} \cdot k^h_{sj} + b^h_{ij}\right)$$

where $q^h$ and $k^h$ are query and key projections of $m$ (sequence $s$, residue $i$, head $h$), and $b^h_{ij}$ is a per-head bias term computed linearly from $z_{ij}$. Note that there’s no $s$ subscript on the bias — it’s shared across all sequences in the MSA, because the pair representation is about the target, not about any one alignment row.

class MSARowAttentionWithPairBias(torch.nn.Module):
    """MSA row-wise gated self-attention with pair bias (Algorithm 7).

    For each MSA row s, standard multi-head self-attention over residues
    ``i, j``, with the pair representation z injected as a learned
    per-head bias: ``a_{sij}^h = softmax_j(q · k / sqrt(c) + b_{ij}^h)``
    where ``b_{ij}^h = LinearNoBias(LayerNorm(z_{ij}))`` (line 3). The
    output is gated by ``sigmoid(Linear(m)) ⊙ attention_output`` and
    projected back to ``c_m``. The pair bias is what lets the pair rep
    influence the MSA rep inside a single Evoformer block — the
    symmetric path through :class:`OuterProductMean` fires on the way
    back in step 4.
    """

    def __init__(self, config):
        super().__init__()
        self.layer_norm_msa = torch.nn.LayerNorm(config.c_m)
        self.layer_norm_pair = torch.nn.LayerNorm(config.c_z)

        self.head_dim = config.dim
        self.num_heads = config.num_heads

        self.total_dim = self.head_dim * self.num_heads

        self.linear_q = torch.nn.Linear(in_features=config.c_m, out_features=self.total_dim, bias=False)
        self.linear_k = torch.nn.Linear(in_features=config.c_m, out_features=self.total_dim, bias=False)
        self.linear_v = torch.nn.Linear(in_features=config.c_m, out_features=self.total_dim, bias=False)

        self.linear_pair = torch.nn.Linear(in_features=config.c_z, out_features=self.num_heads, bias=False)

        self.linear_gate = torch.nn.Linear(in_features=config.c_m, out_features=self.total_dim)

        self.linear_output = torch.nn.Linear(in_features=self.total_dim, out_features=config.c_m)
        init_linear(self.linear_q, init="default")
        init_linear(self.linear_k, init="default")
        init_linear(self.linear_v, init="default")
        init_linear(self.linear_pair, init="default")
        init_gate_linear(self.linear_gate)
        init_linear(self.linear_output, init="final")

    def forward(self, msa_representation: torch.Tensor, pair_representation: torch.Tensor,
                msa_mask: Optional[torch.Tensor] = None):
        msa_representation = self.layer_norm_msa(msa_representation)

        # Shape (batch, N_seq, N_res, self.total_dim)
        Q = self.linear_q(msa_representation)
        K = self.linear_k(msa_representation)
        V = self.linear_v(msa_representation)

        # Reshape to (batch, N_seq, N_res, self.num_heads, self.head_dim)
        Q = Q.reshape((Q.shape[0], Q.shape[1], Q.shape[2], self.num_heads, self.head_dim))
        K = K.reshape((K.shape[0], K.shape[1], K.shape[2], self.num_heads, self.head_dim))
        V = V.reshape((V.shape[0], V.shape[1], V.shape[2], self.num_heads, self.head_dim))

        G = self.linear_gate(msa_representation)
        G = G.reshape((G.shape[0], G.shape[1], G.shape[2], self.num_heads, self.head_dim))

        # Squash values in range 0 to 1 to act as gating mechanism
        G = torch.sigmoid(G)

        pair_representation = self.layer_norm_pair(pair_representation)

        # Algorithm 7 line 3: b_ij^h = LinearNoBias(LayerNorm(z_ij)).
        # Shape (batch, N_res, N_res, self.num_heads)
        B = self.linear_pair(pair_representation)

        # Align B's axes with the score tensor below: (batch, num_heads, i, j),
        # then broadcast across the MSA sequence dim.
        B = B.permute(0, 3, 1, 2)            # (batch, num_heads, N_res_i, N_res_j)
        B = B.unsqueeze(1)                    # (batch, 1, num_heads, N_res_i, N_res_j)

        # Algorithm 7 line 5: a_sij^h = softmax_j(1/sqrt(c) q_si^h . k_sj^h + b_ij^h)
        # scores shape: (batch, N_seq, num_heads, N_res_i, N_res_j)
        scores = torch.einsum('bsihd, bsjhd -> bshij', Q, K)
        scores = scores / math.sqrt(self.head_dim) + B

        # Apply MSA mask to key positions (j dimension)
        if msa_mask is not None:
            # msa_mask: (batch, N_seq, N_res) -> (batch, N_seq, 1, 1, N_res)
            mask_bias = (1.0 - msa_mask[:, :, None, None, :]) * (-1e9)
            scores = scores + mask_bias

        attention = torch.nn.functional.softmax(scores, dim=-1)

        # Shape (batch, N_seq, N_res, self.num_heads, self.head_dim)
        values = torch.einsum('bshij, bsjhd -> bsihd', attention, V)

        values = G * values

        values = values.reshape((Q.shape[0], Q.shape[1], Q.shape[2], -1))

        output = self.linear_output(values)

        # Zero out padded query positions
        if msa_mask is not None:
            output = output * msa_mask[..., None]

        return output

Here’s the architectural trick: the pair representation enters the MSA stack only through this bias. That’s the coupling mechanism. Without the bias, each row of the MSA would be processed independently, and the model would have no way to use evolutionary information to guide attention — the alignment would be just a bunch of disconnected sequences. With the bias, every head gets a per-pair channel that says “pay attention to positions $i$ and $j$ together, in proportion to what we currently believe about their pair representation.” As the Evoformer iterates and $z$ accumulates structural information (through the outer-product mean from earlier blocks, the triangle updates, and recycling), the bias gets sharper, and the row attention becomes progressively more structure-aware.

The gating term (a sigmoid on a learned linear projection) matters too, even though it’s easy to breeze past. Without gating, the residual add would dominate — row attention would often compute a small update that got buried by the residual. The gate lets each head say “trust my update for this residue” or “don’t touch this residue, leave it to another head.” In practice, gated attention in AF2 is load-bearing: the ablations in the paper show that removing gates costs real accuracy.
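As a shape sanity check, here’s a minimal sketch of pair-biased, gated row attention on bare random tensors. The dimensions are toy values (not the repo’s config), and the tensors stand in for the module’s projections:

```python
import math
import torch

# Toy dimensions — hypothetical, not the paper's real config.
batch, n_seq, n_res, n_heads, head_dim = 2, 3, 5, 4, 8

Q = torch.randn(batch, n_seq, n_res, n_heads, head_dim)
K = torch.randn(batch, n_seq, n_res, n_heads, head_dim)
V = torch.randn(batch, n_seq, n_res, n_heads, head_dim)
B = torch.randn(batch, n_res, n_res, n_heads)    # pair bias b_ij^h from z
G = torch.sigmoid(torch.randn(batch, n_seq, n_res, n_heads, head_dim))

# Align the bias with the score axes, then broadcast across sequences.
bias = B.permute(0, 3, 1, 2).unsqueeze(1)        # (batch, 1, heads, i, j)
scores = torch.einsum('bsihd,bsjhd->bshij', Q, K) / math.sqrt(head_dim) + bias
attn = scores.softmax(dim=-1)                    # softmax over j
out = G * torch.einsum('bshij,bsjhd->bsihd', attn, V)
print(out.shape)   # torch.Size([2, 3, 5, 4, 8])
```

The single `unsqueeze(1)` is doing the conceptual work: the same per-pair bias is shared by every sequence in the alignment.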

5. Column attention, transition, and the outer product mean

Row attention teaches each sequence about its own residues. The next three sub-blocks teach the MSA about itself across sequences, shape up the representation for downstream consumption, and then — the payoff — push the aggregated evolutionary signal into $z$ so the pair stack can use it.

5.1 MSA column-wise attention (Algorithm 8)

Column attention is the dual of row attention — for each residue position $i$, attention runs down the column, and sequences at the same position attend to each other. That is, for residue $i$ in sequence $k$, the embedding gets updated using information from residue $i$ in sequences $1, \dots, k-1, k+1, \dots, s$.

class MSAColumnAttention(torch.nn.Module):
    """MSA column-wise gated self-attention (Algorithm 8).

    For each residue column i, attend across MSA sequences
    ``s = 1, ..., N_seq`` with standard multi-head attention on ``m_{si}``
    (no pair bias, unlike the row variant). Gated by
    ``sigmoid(Linear(m))`` and projected back to ``c_m``. No dropout
    per Algorithm 6. Used only in the main Evoformer; the extra MSA
    stack uses :class:`MSAColumnGlobalAttention` instead.
    """

    def __init__(self, config):
        super().__init__()
        self.layer_norm_msa = torch.nn.LayerNorm(config.c_m)

        self.head_dim = config.dim
        self.num_heads = config.num_heads

        self.total_dim = self.head_dim * self.num_heads

        self.linear_q = torch.nn.Linear(in_features=config.c_m, out_features=self.total_dim, bias=False)
        self.linear_k = torch.nn.Linear(in_features=config.c_m, out_features=self.total_dim, bias=False)
        self.linear_v = torch.nn.Linear(in_features=config.c_m, out_features=self.total_dim, bias=False)

        self.linear_gate = torch.nn.Linear(in_features=config.c_m, out_features=self.total_dim)

        self.linear_output = torch.nn.Linear(in_features=self.total_dim, out_features=config.c_m)
        init_linear(self.linear_q, init="default")
        init_linear(self.linear_k, init="default")
        init_linear(self.linear_v, init="default")
        init_gate_linear(self.linear_gate)
        init_linear(self.linear_output, init="final")

    def forward(self, msa_representation: torch.Tensor, msa_mask: Optional[torch.Tensor] = None):
        msa_representation = self.layer_norm_msa(msa_representation)

        # Shape (batch, N_seq, N_res, self.total_dim)
        Q = self.linear_q(msa_representation)
        K = self.linear_k(msa_representation)
        V = self.linear_v(msa_representation)

        # Reshape to (batch, N_seq, N_res, self.num_heads, self.head_dim)
        Q = Q.reshape((Q.shape[0], Q.shape[1], Q.shape[2], self.num_heads, self.head_dim))
        K = K.reshape((K.shape[0], K.shape[1], K.shape[2], self.num_heads, self.head_dim))
        V = V.reshape((V.shape[0], V.shape[1], V.shape[2], self.num_heads, self.head_dim))

        G = self.linear_gate(msa_representation)
        G = G.reshape((G.shape[0], G.shape[1], G.shape[2], self.num_heads, self.head_dim))

        # Squash values in range 0 to 1 to act as gating mechanism
        G = torch.sigmoid(G)

        # Shape (batch, N_res, self.num_heads, N_seq, N_seq)
        scores = torch.einsum('bsihd, btihd -> bihst', Q, K)
        scores = scores / math.sqrt(self.head_dim)

        # Apply MSA mask to key positions (t dimension = sequences)
        if msa_mask is not None:
            # msa_mask: (batch, N_seq, N_res) -> (batch, N_res, 1, 1, N_seq)
            mask_bias = (1.0 - msa_mask.permute(0, 2, 1)[:, :, None, None, :]) * (-1e9)
            scores = scores + mask_bias

        attention = torch.nn.functional.softmax(scores, dim=-1)

        # Shape (batch, N_seq, N_res, self.num_heads, self.head_dim)
        values = torch.einsum('bihst, btihd -> bsihd', attention, V)

        values = G * values

        values = values.reshape((Q.shape[0], Q.shape[1], Q.shape[2], -1))

        output = self.linear_output(values)

        # Zero out padded query positions
        if msa_mask is not None:
            output = output * msa_mask[..., None]

        return output

There’s no pair bias here — the pair representation is a per-residue-pair object, and column attention is happening across sequences at the same residue. It’s a vanilla gated multi-head attention, except along the $s$ axis rather than the $r$ axis. Column attention is what lets the model ask “what does residue 47 look like across this whole family of homologs?” and compress that information into the embedding for residue 47 in the target row.

5.2 MSA transition (Algorithm 9)

A standard two-layer MLP applied per-residue, per-sequence. The hidden dimension is 4× the input dimension — the same 4× expansion pattern as a Transformer FFN — and the nonlinearity is ReLU.

class MSATransition(torch.nn.Module):
    """MSA transition (Algorithm 9).

    Two-layer feed-forward applied per MSA cell: ``LayerNorm → Linear(c
    → n·c) → ReLU → Linear(n·c → c)`` with ``n = 4`` by default. The
    widening factor ``n`` is the supplement's ``n`` parameter, kept
    configurable so the extra MSA stack can choose its own (``c_in``
    also overrideable so the same module serves both ``c_m`` and
    ``c_e`` MSA reps). No dropout — the residual connection is purely
    additive per Algorithm 6.
    """

    def __init__(self, config, c_in: Optional[int] = None, n: Optional[int] = None):
        super().__init__()
        self.c_in = config.c_m if c_in is None else c_in
        self.n = config.msa_transition_n if n is None else n

        self.layer_norm = torch.nn.LayerNorm(self.c_in)

        self.linear_up = torch.nn.Linear(in_features=self.c_in, out_features=self.n * self.c_in)
        self.linear_down = torch.nn.Linear(in_features=self.c_in * self.n, out_features=self.c_in)
        init_linear(self.linear_up, init="relu")
        init_linear(self.linear_down, init="final")

    def forward(self, msa_representation: torch.Tensor):
        msa_representation = self.layer_norm(msa_representation)

        activations = self.linear_up(msa_representation)

        return self.linear_down(torch.nn.functional.relu(activations))

Nothing surprising here. The transition is the MSA stack’s “compute” step, analogous to the FFN in a standard Transformer block. It gives the network capacity to combine features that the attention steps have just surfaced.
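Because the transition touches each (sequence, residue) cell independently, it commutes with any permutation of the residue axis. A quick property check with a toy MLP (hypothetical dimensions, same LayerNorm → Linear → ReLU → Linear shape as the module):

```python
import torch

c, n = 16, 4  # hypothetical channel width and widening factor
mlp = torch.nn.Sequential(
    torch.nn.LayerNorm(c),
    torch.nn.Linear(c, n * c),
    torch.nn.ReLU(),
    torch.nn.Linear(n * c, c),
)

msa = torch.randn(2, 3, 7, c)   # (batch, N_seq, N_res, c)
perm = torch.randperm(7)

# Permuting residues then applying the MLP equals applying then
# permuting: there is no cross-residue (or cross-sequence) mixing.
a = mlp(msa)[:, :, perm]
b = mlp(msa[:, :, perm])
print(torch.allclose(a, b, atol=1e-6))   # True
```

This is exactly why the transition counts as “compute” rather than “communication”: all mixing across positions happens in the attention steps.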

5.3 The outer product mean (Algorithm 10)

And now, the payoff. The outer product mean takes the MSA representation — which by this point in the block has been updated by row attention, column attention, and the transition MLP — and converts it into an update for the pair representation. Its job is to take everything we’ve learned about residues across the aligned sequences and distill it back down into a single signal about each residue pair $(i, j)$ in the target sequence, a signal we can then inject into $z$ so the pair stack can reason about contacts.

class OuterProductMean(torch.nn.Module):
    """Outer product mean (Algorithm 10).

    Symmetric MSA → pair update: project each MSA cell to two hidden
    vectors ``a_{si}, b_{si} ∈ R^{c_hidden}``, take the MSA mean of
    their outer product ``mean_s (a_{si} ⊗ b_{sj})``, flatten to
    ``c_hidden^2`` channels, and project back to ``c_z``. This is the
    only channel in the Evoformer where the MSA rep writes into the
    pair rep; the reverse direction (pair → MSA) goes through the
    pair-biased row attention.
    ``c_in``/``c_hidden`` are configurable so the extra MSA stack can
    run a narrower OPM.
    """

    def __init__(self, config, c_in: Optional[int] = None, c_hidden: Optional[int] = None):
        super().__init__()
        self.c_in = config.c_m if c_in is None else c_in
        self.layer_norm = torch.nn.LayerNorm(self.c_in)

        self.c = config.outer_product_dim if c_hidden is None else c_hidden

        self.linear_left = torch.nn.Linear(self.c_in, self.c)
        self.linear_right = torch.nn.Linear(self.c_in, self.c)

        self.linear_out = torch.nn.Linear(in_features=self.c*self.c, out_features=config.c_z)
        init_linear(self.linear_left, init="default")
        init_linear(self.linear_right, init="default")
        init_linear(self.linear_out, init="final")

    def forward(self, msa_representation: torch.Tensor, msa_mask: Optional[torch.Tensor] = None):
        # msa_mask: (batch, N_seq, N_res) — 1 for valid, 0 for padding
        msa_representation = self.layer_norm(msa_representation)

        # Shape (batch, N_seq, N_res, self.c)
        A = self.linear_left(msa_representation)
        B = self.linear_right(msa_representation)

        if msa_mask is not None:
            # Zero out padded MSA rows before outer product
            m = msa_mask.to(A.dtype)              # (batch, N_seq, N_res)
            A = A * m[..., None]                   # (batch, N_seq, N_res, c)
            B = B * m[..., None]

        # Sum over N_seq: (batch, N_res_i, N_res_j, c, c)
        outer = torch.einsum('bsic, bsjd -> bijcd', A, B)

        if msa_mask is not None:
            # Mask-aware normalization: count valid (s,i)*(s,j) pairs
            m = msa_mask.to(A.dtype)
            norm = torch.einsum('bsi, bsj -> bij', m, m).clamp(min=1.0)  # (batch, N_res, N_res)
            mean_val = outer / norm[..., None, None]
        else:
            mean_val = outer / msa_representation.shape[1]

        # Shape (batch, N_res, N_res, self.c*self.c)
        mean_val = mean_val.reshape(mean_val.shape[0], mean_val.shape[1], mean_val.shape[2], -1)

        return self.linear_out(mean_val)

The mechanics work as follows. We start from the MSA representation of shape $(s, r, c_m)$. After a LayerNorm, we apply two separate linear projections to produce a “left” tensor $A$ and a “right” tensor $B$, each of shape $(s, r, c')$, where $c'$ is a smaller hidden dimension. You can think of $A$ as the representation of residue $i$ “as a left-hand partner” and $B$ as the representation of residue $j$ “as a right-hand partner” — two different views of the same residue, each tuned for its role in the pairwise interaction.

Then the core operation: for every sequence $k$ in the MSA and every pair of residue positions $(i, j)$, we form the outer product of $A_{k,i}$ and $B_{k,j}$. That outer product is a $c' \times c'$ matrix that captures every pairwise interaction between the features of residue $i$ and the features of residue $j$, as seen in sequence $k$. Where a simple dot product would collapse everything down to a single number, the outer product preserves the full grid of feature-by-feature interactions — much richer information about how the two residues relate in this particular sequence.

We then sum these outer products over all sequences in the MSA and divide by the number of sequences (with proper masking to ignore padded rows). This gives us, for each pair $(i, j)$, the average feature-by-feature interaction between residue $i$ and residue $j$ across the entire alignment. Intuitively: if residues $i$ and $j$ consistently co-vary across the MSA — the hallmark of co-evolution and likely physical contact — that pattern will show up strongly and consistently in these outer products and survive the averaging. If they vary independently, the signals across sequences will wash each other out.

The resulting tensor has shape $(r, r, c', c')$. We flatten the last two dimensions to get $(r, r, c' \cdot c')$, then apply a final linear projection to map down to $c_z$, the channel dimension of the pair representation. That produces a tensor of shape $(r, r, c_z)$ that can be added directly into $z$.
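To make the averaging concrete, here’s a small equivalence check (toy tensors, batch dimension dropped, not the repo’s module) showing that the einsum form matches the naive per-sequence outer-product loop:

```python
import torch

s, r, c = 4, 3, 2   # hypothetical: sequences, residues, hidden width
A = torch.randn(s, r, c)   # "left" projection a_{si}
B = torch.randn(s, r, c)   # "right" projection b_{sj}

# Naive: for each pair (i, j), average the outer product over sequences.
naive = torch.zeros(r, r, c, c)
for i in range(r):
    for j in range(r):
        for k in range(s):
            naive[i, j] += torch.outer(A[k, i], B[k, j])
naive /= s

# The einsum form used in the module.
fast = torch.einsum('sic,sjd->ijcd', A, B) / s

print(torch.allclose(naive, fast, atol=1e-5))   # True
```

The loop makes the co-variation intuition visible: each sequence contributes one $c \times c$ grid of feature interactions for every pair, and the mean keeps only what is consistent across the alignment.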

This is the cleanest handoff in the whole architecture. The MSA stack uses row- and column-wise attention to figure out which residues behave similarly across evolutionary history, and the outer product mean then translates those cross-sequence patterns into a per-pair signal. It tells the pair representation, for every $(i, j)$, how strongly the evolutionary record suggests these two residues are coupled. Downstream — routed through the triangle updates and the IPA in the Structure Module — this is what lets the model guess which residues are in contact, which is what ultimately drives accurate structure prediction.

Putting it all together: the MSA steps compute a representation that captures similarity of residues within sequences (via row attention) and across sequences (via column attention). The outer product mean then uses that representation to generate a measure of similarity between every pair of residues in the target and adds it into $z$. In essence, we use the MSA to figure out which residues are similar to which other residues, and then hand this information to the pair stack so the Structure Module can guess at which residues are in contact.

6. Triangle multiplicative updates

After the MSA stack has finished — row attention, column attention, transition, and the outer product mean — the pair representation $z$ is holding a first pass at the evolutionary signal. For every residue pair $(i, j)$ in the target, $z_{ij}$ is a vector whose direction and magnitude encode how strongly the alignment thinks those two residues are coupled. The rest of the Evoformer block is about shaping that signal into something a structure module can actually place atoms with. The first two sub-blocks that do the shaping are the triangle multiplicative updates.

The intuition lives in the word “triangle.” Think of $z$ as a directed graph on the residues — residues are vertices, and the pair representation $z_{ij}$ is the vector sitting on the directed edge from $i$ to $j$. Physical contacts are, in essence, a consistency claim about these edges: if residue $i$ and residue $k$ are in contact, and residue $k$ and residue $j$ are in contact, then residues $i$ and $j$ are also likely in contact — because all three are clustered in the same region of 3D space. You can’t just update one edge at a time without enforcing that kind of triangle consistency; if you do, the pair representation will happily hold a globally incoherent story where every edge is locally plausible but the triangles don’t close.

The triangle multiplicative update fixes this. For each target edge $z_{ij}$, it computes an update by summing over every residue $k$, multiplying together the other two edges of a triangle $(i, k, j)$:

$$z_{ij} \leftarrow g_{ij} \odot \mathrm{Linear}\!\left( \sum_k a_{ik} \odot b_{jk} \right) \tag{11}$$

That’s the outgoing variant. Two linear projections of $z$ produce $a$ and $b$ — you can think of $a_{ik}$ as “the view of the $i \to k$ edge that is being used to update outgoing edges at $i$” and $b_{jk}$ as “the view of the $j \to k$ edge for the same purpose.” For each $k$, the elementwise product $a_{ik} \odot b_{jk}$ combines these two views. We then sum over every $k$, linearly project back to pair-rep dimension, gate with a learned sigmoid $g_{ij}$, and add into $z$.

The incoming variant is the symmetric move, operating on columns instead of rows:

$$z_{ij} \leftarrow g_{ij} \odot \mathrm{Linear}\!\left( \sum_k a_{ki} \odot b_{kj} \right) \tag{12}$$

Same operation, different triangle. In outgoing, we use the edges leaving $i$ and $j$ toward $k$; in incoming, we use the edges arriving at $i$ and $j$ from $k$. Because the pair representation is directional, the two variants together cover both orientations of every triangle through $(i, j)$.

6.1 Seeing the operation

Before we look at code, the update is much easier to understand when you can watch a single $k$-sweep fire. The widget below draws the pair representation as an $N \times N$ grid. Click any cell to pick the target pair $(i, j)$, then watch the animation sum over $k$. In outgoing mode, the two contributing cells sit on row $i$ and row $j$; in incoming mode, they sit on column $i$ and column $j$. The dashed triangle connects the three cells in play at the current step of the sum.

[Interactive widget: a $10 \times 10$ pair-representation grid animating the sum over $k$. At each step the outgoing update reads $z_{2,7} \mathrel{+}= a_{2,k} \odot b_{7,k}$: for the target edge $z_{2,7}$, sum over every residue $k$ the elementwise product of row-$i$ entry $a_{2,k}$ and row-$j$ entry $b_{7,k}$. Each $k$ closes a triangle $i \to k \leftarrow j$.]

A few things worth noticing as the animation runs. (1) The target cell $(i, j)$ stays put, but the two contributing cells sweep across the grid together as $k$ advances. (2) In outgoing mode, the two contributing cells share a column ($k$); in incoming mode, they share a row. (3) The $k$-summation touches every other residue in the chain — not just the ones currently believed to be nearby. It’s a dense update, not a sparse attention.

6.2 The code

With the operation in mind, the PyTorch follows directly. The outgoing variant first:

class TriangleMultiplicationOutgoing(torch.nn.Module):
    """Triangle multiplicative update, outgoing edges (Algorithm 11).

    Update ``z_ij`` from the two *outgoing* edges of the triangle
    ``(i, j, k)``: ``z_ij ← g_ij ⊙ Linear(LayerNorm(sum_k a_ik ⊙
    b_jk))`` where ``a = gate_a ⊙ projection_a(z)`` and likewise for
    ``b``. Enforces the triangle-inequality structure across the pair
    rep. Algorithm 11 pools over intermediate node ``k`` via the
    outgoing edges ``z_{ik}`` and ``z_{jk}``; the incoming-edge
    counterpart is :class:`TriangleMultiplicationIncoming`.
    """

    def __init__(self, config, c: Optional[int] = None):
        super().__init__()
        mult_c = config.triangle_mult_c if c is None else c
        self.layer_norm_pair = torch.nn.LayerNorm(config.c_z)
        self.layer_norm_out = torch.nn.LayerNorm(mult_c)

        self.gate1 = torch.nn.Linear(in_features=config.c_z, out_features=mult_c)
        self.gate2 = torch.nn.Linear(in_features=config.c_z, out_features=mult_c)

        self.linear1 = torch.nn.Linear(in_features=config.c_z, out_features=mult_c)
        self.linear2 = torch.nn.Linear(in_features=config.c_z, out_features=mult_c)

        self.gate = torch.nn.Linear(in_features=config.c_z, out_features=config.c_z)

        self.out_linear = torch.nn.Linear(in_features=mult_c, out_features=config.c_z)
        init_gate_linear(self.gate1)
        init_gate_linear(self.gate2)
        init_linear(self.linear1, init="default")
        init_linear(self.linear2, init="default")
        init_gate_linear(self.gate)
        init_linear(self.out_linear, init="final")

    def forward(self, pair_representation: torch.Tensor, pair_mask: Optional[torch.Tensor] = None):
        pair_representation = self.layer_norm_pair(pair_representation)

        # Shape (batch, N_res, N_res, c)
        A = torch.sigmoid(self.gate1(pair_representation)) * self.linear1(pair_representation)
        B = torch.sigmoid(self.gate2(pair_representation)) * self.linear2(pair_representation)

        # Mask out padded positions before contraction
        if pair_mask is not None:
            A = A * pair_mask[..., None]
            B = B * pair_mask[..., None]

        # Shape (batch, N_res, N_res, c_z)
        G = torch.sigmoid(self.gate(pair_representation))

        # A: (batch, N_res_i, N_res_k, c)
        # B: (batch, N_res_j, N_res_k, c)
        # Result: (batch, N_res_i, N_res_j, c)
        vals = torch.einsum('bikc, bjkc -> bijc', A, B)

        # Shape (batch, N_res, N_res, c_z)
        out = G * self.out_linear(self.layer_norm_out(vals))

        if pair_mask is not None:
            out = out * pair_mask[..., None]

        return out

The pattern is exactly what the equation says: LayerNorm, four linear projections (linear1, linear2, plus the two gate linears), compute $a$ and $b$, einsum the sum-over-$k$, LayerNorm, final gated linear, done. The einsum string is where the geometry sits — bikc,bjkc->bijc says “for each batch $b$, sum over $k$ the elementwise product of $a_{ik}$ and $b_{jk}$, producing $z_{ij}$ in channels $c$.” Read it once carefully; this single einsum is the whole triangle multiplicative update.
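If the einsum string still feels opaque, this sketch (toy shapes, batch dimension dropped, hypothetical tensor names) unrolls it into the explicit triangle sum over $k$:

```python
import torch

r, c = 5, 3   # hypothetical residue count and channel width
A = torch.randn(r, r, c)   # a_{ik}: gated "left edge" projection of z
B = torch.randn(r, r, c)   # b_{jk}: gated "right edge" projection of z

# Explicit triangle sum: for each target edge (i, j), accumulate the
# elementwise product of the two outgoing edges i→k and j→k.
loop = torch.zeros(r, r, c)
for i in range(r):
    for j in range(r):
        for k in range(r):
            loop[i, j] += A[i, k] * B[j, k]

# The one-line einsum from the module.
fast = torch.einsum('ikc,jkc->ijc', A, B)

print(torch.allclose(loop, fast, atol=1e-5))   # True
```

The triple loop is $O(r^3)$ in Python and hopeless at real sizes; the einsum dispatches the same contraction as a batched matrix multiply.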

The incoming variant differs only in the einsum indices:

class TriangleMultiplicationIncoming(torch.nn.Module):
    """Triangle multiplicative update, incoming edges (Algorithm 12).

    Symmetric partner of :class:`TriangleMultiplicationOutgoing`: pool
    over intermediate node ``k`` using the *incoming* edges
    ``z_{ki}`` and ``z_{kj}`` (i.e. ``sum_k a_ki ⊙ b_kj``). Outgoing and
    incoming variants fire back-to-back in every Evoformer block so the
    pair rep sees both triangle orientations per iteration.
    """

    def __init__(self, config, c: Optional[int] = None):
        super().__init__()
        mult_c = config.triangle_mult_c if c is None else c
        self.layer_norm_pair = torch.nn.LayerNorm(config.c_z)
        self.layer_norm_out = torch.nn.LayerNorm(mult_c)

        self.gate1 = torch.nn.Linear(in_features=config.c_z, out_features=mult_c)
        self.gate2 = torch.nn.Linear(in_features=config.c_z, out_features=mult_c)

        self.linear1 = torch.nn.Linear(in_features=config.c_z, out_features=mult_c)
        self.linear2 = torch.nn.Linear(in_features=config.c_z, out_features=mult_c)

        self.gate = torch.nn.Linear(in_features=config.c_z, out_features=config.c_z)

        self.out_linear = torch.nn.Linear(in_features=mult_c, out_features=config.c_z)
        init_gate_linear(self.gate1)
        init_gate_linear(self.gate2)
        init_linear(self.linear1, init="default")
        init_linear(self.linear2, init="default")
        init_gate_linear(self.gate)
        init_linear(self.out_linear, init="final")

    def forward(self, pair_representation: torch.Tensor, pair_mask: Optional[torch.Tensor] = None):
        pair_representation = self.layer_norm_pair(pair_representation)

        # Shape (batch, N_res, N_res, c)
        A = torch.sigmoid(self.gate1(pair_representation)) * self.linear1(pair_representation)
        B = torch.sigmoid(self.gate2(pair_representation)) * self.linear2(pair_representation)

        # Mask out padded positions before contraction
        if pair_mask is not None:
            A = A * pair_mask[..., None]
            B = B * pair_mask[..., None]

        # Shape (batch, N_res, N_res, c_z)
        G = torch.sigmoid(self.gate(pair_representation))

        # A: (batch, N_res_k, N_res_i, c)
        # B: (batch, N_res_k, N_res_j, c)
        # Result: (batch, N_res_i, N_res_j, c)
        vals = torch.einsum('bkic, bkjc -> bijc', A, B)

        # Shape (batch, N_res, N_res, c_z)
        out = G * self.out_linear(self.layer_norm_out(vals))

        if pair_mask is not None:
            out = out * pair_mask[..., None]

        return out

bkic,bkjc->bijc — swap which axis $k$ sums over. Everything else is identical.

It’s worth pausing on the cost of this operation. The einsum is $O(r^3 c')$ per call — for a protein of length $r = 256$ at channel width $c' = 128$, that’s roughly 2 billion multiply-adds per triangle update, and the Evoformer applies two of them (outgoing and incoming) per block, across 48 blocks, up to 4 recycling cycles. This is the single most expensive operation in AlphaFold2’s forward pass, and it’s the reason training and inference both benefit dramatically from gradient checkpointing (which minAlphaFold2 applies around the Evoformer stack). Without it, the activation memory alone would blow out even an A100.
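A back-of-the-envelope version of that estimate — the per-block and per-cycle counts are from the configuration described above; this is arithmetic, not a benchmark:

```python
r, c_hidden = 256, 128
# One multiply-add per (i, j, k, channel) tuple in the einsum.
per_update = r**3 * c_hidden
print(f"{per_update / 1e9:.1f}B multiply-adds per triangle update")   # 2.1B

variants, blocks, recycles = 2, 48, 4
total = per_update * variants * blocks * recycles
print(f"{total / 1e12:.2f}T across the full Evoformer stack")         # 0.82T
```

And that is only the triangle multiplications at a modest crop length — the cubic term means doubling $r$ multiplies this cost by eight.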

7. Triangle self-attention

The triangle multiplicative updates enforce triangle consistency in a direct algebraic way — every edge gets refined using a sum over third vertices. The triangle self-attention layers are the softer, attention-weighted version of the same idea. They let each edge attend to other edges that share a residue with it, weighted by a learned compatibility score that also reflects the current pair representation.

There are two variants again, differing in how the “shared residue” is picked.

Starting-node attention (Algorithm 13). For each target edge $(i, j)$, the query comes from $z_{ij}$, and the keys and values come from the set of edges starting at $i$ — that is, $\{z_{ik} : k \in [1, r]\}$. The attention scores additionally get a bias term derived from $z_{jk}$, so that the third vertex of the triangle enters the attention pattern directly.

class TriangleAttentionStartingNode(torch.nn.Module):
    """Triangle self-attention around the starting node (Algorithm 13).

    Gated multi-head self-attention over the pair rep with a
    triangle-consistency bias: for fixed starting node i, the query at
    edge (i, j) attends over third vertices k with keys and values from
    ``z_{ik}``, plus a pair bias ``b_{jk} = LinearNoBias(LayerNorm(
    z_{jk}))``. Row-wise dropout (supplement 1.11.6) matches Algorithm 6.
    """

    def __init__(self, config, c: Optional[int] = None, num_heads: Optional[int] = None):
        super().__init__()
        self.layer_norm = torch.nn.LayerNorm(config.c_z)

        self.head_dim = config.triangle_dim if c is None else c
        self.num_heads = config.triangle_num_heads if num_heads is None else num_heads

        self.total_dim = self.head_dim * self.num_heads

        self.linear_q = torch.nn.Linear(in_features=config.c_z, out_features=self.total_dim, bias=False)
        self.linear_k = torch.nn.Linear(in_features=config.c_z, out_features=self.total_dim, bias=False)
        self.linear_v = torch.nn.Linear(in_features=config.c_z, out_features=self.total_dim, bias=False)

        self.linear_bias = torch.nn.Linear(in_features=config.c_z, out_features=self.num_heads, bias=False)

        self.linear_gate = torch.nn.Linear(in_features=config.c_z, out_features=self.total_dim)

        self.linear_output = torch.nn.Linear(in_features=self.total_dim, out_features=config.c_z)
        init_linear(self.linear_q, init="default")
        init_linear(self.linear_k, init="default")
        init_linear(self.linear_v, init="default")
        init_linear(self.linear_bias, init="default")
        init_gate_linear(self.linear_gate)
        init_linear(self.linear_output, init="final")

    def forward(self, pair_representation: torch.Tensor, pair_mask: Optional[torch.Tensor] = None):
        pair_representation = self.layer_norm(pair_representation)

        # Shape (batch, N_res, N_res, self.total_dim)
        Q = self.linear_q(pair_representation)
        K = self.linear_k(pair_representation)
        V = self.linear_v(pair_representation)

        # Reshape to (batch, N_res, N_res, self.num_heads, self.head_dim)
        Q = Q.reshape((Q.shape[0], Q.shape[1], Q.shape[2], self.num_heads, self.head_dim))
        K = K.reshape((K.shape[0], K.shape[1], K.shape[2], self.num_heads, self.head_dim))
        V = V.reshape((V.shape[0], V.shape[1], V.shape[2], self.num_heads, self.head_dim))

        G = self.linear_gate(pair_representation)
        G = G.reshape((G.shape[0], G.shape[1], G.shape[2], self.num_heads, self.head_dim))

        # Squash values in range 0 to 1 to act as gating mechanism
        G = torch.sigmoid(G)

        # Shape (batch, N_res, N_res, self.num_heads)
        B = self.linear_bias(pair_representation)

        # Q shape (batch, N_res_i, N_res_j, self.num_heads, self.head_dim)
        # K shape (batch, N_res_i, N_res_k, self.num_heads, self.head_dim)
        # B shape (batch, N_res_j, N_res_k, self.num_heads)
        # Output shape (batch, N_res_i, N_res_j, N_res_k, self.num_heads)
        scores = torch.einsum('bijhd, bikhd -> bijkh', Q, K)
        scores = scores / math.sqrt(self.head_dim) + B.unsqueeze(1)

        # Apply pair mask to key positions (k dimension, for a given i)
        if pair_mask is not None:
            # pair_mask: (batch, N_res, N_res) -> (batch, N_res_i, 1, N_res_k, 1)
            mask_bias = (1.0 - pair_mask[:, :, None, :, None]) * (-1e9)
            scores = scores + mask_bias

        attention = torch.nn.functional.softmax(scores, dim=3)

        # Shape (batch, N_res, N_res, self.num_heads, self.head_dim)
        values = torch.einsum('bijkh, bikhd -> bijhd', attention, V)

        values = G * values

        values = values.reshape((Q.shape[0], Q.shape[1], Q.shape[2], -1))

        output = self.linear_output(values)

        # Zero out padded query positions
        if pair_mask is not None:
            output = output * pair_mask[..., None]

        return output

Ending-node attention (Algorithm 14). The same thing, but now the keys and values come from the edges ending at $j$ — $\{z_{kj} : k \in [1, r]\}$ — and the bias comes from $z_{ki}$.

class TriangleAttentionEndingNode(torch.nn.Module):
    """Triangle self-attention around the ending node (Algorithm 14).

    Mirror image of :class:`TriangleAttentionStartingNode`: fix the
    ending node j and attend over third vertices k via the incoming
    edges ``z_{kj}``. The pair bias is
    ``b_{ki} = LinearNoBias(LayerNorm(z_{ki}))``. The supplement
    prescribes column-wise dropout (not row-wise) on this output — the
    Evoformer block applies it accordingly.
    """

    def __init__(self, config, c: Optional[int] = None, num_heads: Optional[int] = None):
        super().__init__()
        self.layer_norm = torch.nn.LayerNorm(config.c_z)

        self.head_dim = config.triangle_dim if c is None else c
        self.num_heads = config.triangle_num_heads if num_heads is None else num_heads

        self.total_dim = self.head_dim * self.num_heads

        self.linear_q = torch.nn.Linear(in_features=config.c_z, out_features=self.total_dim, bias=False)
        self.linear_k = torch.nn.Linear(in_features=config.c_z, out_features=self.total_dim, bias=False)
        self.linear_v = torch.nn.Linear(in_features=config.c_z, out_features=self.total_dim, bias=False)

        self.linear_bias = torch.nn.Linear(in_features=config.c_z, out_features=self.num_heads, bias=False)

        self.linear_gate = torch.nn.Linear(in_features=config.c_z, out_features=self.total_dim)

        self.linear_output = torch.nn.Linear(in_features=self.total_dim, out_features=config.c_z)
        init_linear(self.linear_q, init="default")
        init_linear(self.linear_k, init="default")
        init_linear(self.linear_v, init="default")
        init_linear(self.linear_bias, init="default")
        init_gate_linear(self.linear_gate)
        init_linear(self.linear_output, init="final")

    def forward(self, pair_representation: torch.Tensor, pair_mask: Optional[torch.Tensor] = None):
        pair_representation = self.layer_norm(pair_representation)

        # Shape (batch, N_res, N_res, self.total_dim)
        Q = self.linear_q(pair_representation)
        K = self.linear_k(pair_representation)
        V = self.linear_v(pair_representation)

        # Reshape to (batch, N_res, N_res, self.num_heads, self.head_dim)
        Q = Q.reshape((Q.shape[0], Q.shape[1], Q.shape[2], self.num_heads, self.head_dim))
        K = K.reshape((K.shape[0], K.shape[1], K.shape[2], self.num_heads, self.head_dim))
        V = V.reshape((V.shape[0], V.shape[1], V.shape[2], self.num_heads, self.head_dim))

        G = self.linear_gate(pair_representation)
        G = G.reshape((G.shape[0], G.shape[1], G.shape[2], self.num_heads, self.head_dim))

        # Squash values in range 0 to 1 to act as gating mechanism
        G = torch.sigmoid(G)

        # Shape (batch, N_res, N_res, self.num_heads)
        B = self.linear_bias(pair_representation)

        # Algorithm 14 line 5: a_ijk^h = softmax_k(1/sqrt(c) q_ij^h . k_kj^h + b_ki^h).
        # The highlighted differences from the starting-node version (Algorithm 13)
        # are that keys/values are indexed (k, j) instead of (i, k), and the bias
        # is b_{k,i} instead of b_{j,k}.
        #
        # Q shape (batch, N_res_i, N_res_j, num_heads, head_dim)
        # K shape (batch, N_res_k, N_res_j, num_heads, head_dim)
        # B shape (batch, N_res, N_res, num_heads), with B[b, i, j, h] coming from z_{ij}.
        # Output shape (batch, N_res_i, N_res_j, N_res_k, num_heads)

        scores = torch.einsum('bijhd, bkjhd -> bijkh', Q, K)
        # We need B indexed as b^h_{k,i} at score position (b, i, j, k, h), i.e.
        # B[b, k, i, h]. Swapping axes 1 and 2 of B yields a view whose indexing
        # is B_t[b, x, y, h] = B[b, y, x, h], so B_t[b, i, k, h] = B[b, k, i, h].
        # Inserting a new j-axis with unsqueeze(2) broadcasts that bias to every
        # (i, j, k, h) score.
        scores = scores / math.sqrt(self.head_dim) + B.transpose(1, 2).unsqueeze(2)

        # Apply pair mask to key positions: an attention score at (b, i, j, k, h)
        # should be masked when the *key* pair z_{k,j} is padding, i.e. when
        # pair_mask[b, k, j] == 0. Permute swaps the k/j axes so the view exposes
        # [b, j, k], then broadcast over i and h.
        if pair_mask is not None:
            mask_bias = (1.0 - pair_mask.permute(0, 2, 1)[:, None, :, :, None]) * (-1e9)
            scores = scores + mask_bias

        attention = torch.nn.functional.softmax(scores, dim=3)

        # Shape (batch, N_res, N_res, self.num_heads, self.head_dim)
        values = torch.einsum('bijkh, bkjhd -> bijhd', attention, V)

        values = G * values

        values = values.reshape((Q.shape[0], Q.shape[1], Q.shape[2], -1))

        output = self.linear_output(values)

        # Zero out padded query positions
        if pair_mask is not None:
            output = output * pair_mask[..., None]

        return output

The two operations together cover both ways a third residue $k$ can form an edge that shares a node with $(i, j)$: either the edges $(i, k)$ and $(i, j)$ share their starting node $i$, or the edges $(k, j)$ and $(i, j)$ share their ending node $j$. And the triangle bias — pulling the third edge of the triangle directly into the softmax — is what makes these “triangle” attention rather than vanilla row/column attention on the pair representation. Without the bias term, the two operations would look like familiar self-attention patterns over rows and columns of $z$; the bias is the thing that wires the third edge in.
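To check that the `B.transpose(1, 2).unsqueeze(2)` trick in the ending-node forward really places $b_{ki}$ at score position $(i, j, k)$, here is a standalone sanity check on a random tensor (toy dims, independent of the repo):

```python
import torch

b, n, h = 1, 4, 2
B = torch.randn(b, n, n, h)       # B[b, i, j, h], projected from z_ij

# The same view the forward pass adds to scores of shape (b, i, j, k, h)
bias = B.transpose(1, 2).unsqueeze(2).expand(b, n, n, n, h)

for i in range(n):
    for j in range(n):
        for k in range(n):
            # position (i, j, k) sees b_{ki}, independent of j
            assert torch.equal(bias[0, i, j, k], B[0, k, i])
```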

After triangle attention, the Evoformer block closes with a simple pair transition MLP — the same 2-layer, 4×-expansion FFN pattern we saw in §5.2, but now on the pair representation:

class PairTransition(torch.nn.Module):
    """Pair transition (Algorithm 15).

    Per-pair feed-forward: ``LayerNorm → Linear(c_z → n·c_z) → ReLU →
    Linear(n·c_z → c_z)`` with widening factor ``n = 4``. Same shape as
    :class:`MSATransition` but over the pair rep instead of the MSA
    rep. No dropout per Algorithm 6.
    """

    def __init__(self, config, n: Optional[int] = None):
        super().__init__()
        self.n = config.pair_transition_n if n is None else n

        self.layer_norm = torch.nn.LayerNorm(config.c_z)

        self.linear_up = torch.nn.Linear(in_features=config.c_z, out_features=self.n*config.c_z)
        self.linear_down = torch.nn.Linear(in_features=config.c_z*self.n, out_features=config.c_z)
        init_linear(self.linear_up, init="relu")
        init_linear(self.linear_down, init="final")

    def forward(self, pair_representation: torch.Tensor):
        pair_representation = self.layer_norm(pair_representation)

        activations = self.linear_up(pair_representation)

        return self.linear_down(torch.nn.functional.relu(activations))
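The whole transition contract — per-pair MLP, shape preserved — fits in a standalone sketch (made-up dims, not the repo’s config):

```python
import torch

c_z, n = 8, 4  # hypothetical channel dim and widening factor
transition = torch.nn.Sequential(
    torch.nn.LayerNorm(c_z),
    torch.nn.Linear(c_z, n * c_z),   # up-project by the widening factor
    torch.nn.ReLU(),
    torch.nn.Linear(n * c_z, c_z),   # back down to c_z
)
z = torch.randn(2, 5, 5, c_z)        # (batch, N_res, N_res, c_z)
out = transition(z)
assert out.shape == z.shape          # per-pair MLP: shape is preserved
```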

That’s the full Evoformer block: nine sub-blocks, of which three operate on the MSA stack (row attn with pair bias, column attn, MSA transition), one couples the two representations (OPM), and five operate on the pair representation (triangle mult × 2, triangle attn × 2, pair transition). The pair stack is where the geometric consistency gets enforced, and the sub-block budget reflects that.

8. Extra MSA and templates

Two side pipelines feed additional information into the main representations before the Evoformer runs on them. Both are best thought of as “cheaper stacks that produce pair-representation updates” — neither of them has the full 48-block Evoformer treatment, because neither needs it.

8.1 Extra MSA

AlphaFold2’s main MSA input is clustered and capped at a modest number of sequences (a few hundred in the full config, eight in tiny.toml) — enough to carry rich signal but small enough that the Evoformer’s attention cost stays manageable. The extra MSA is a larger slice of the same alignment (thousands of sequences in the full config) that gets processed by a shallower, cheaper sibling of the Evoformer — different block count, global column attention instead of per-sequence column attention, and a one-way handoff into the pair representation (no feedback into the main MSA).

class ExtraMsaStack(torch.nn.Module):
    """Extra MSA stack (Algorithm 18, supplement 1.7.2).

    Lightweight Evoformer-like block for the unclustered "extra" MSA.
    The extra MSA is much deeper (default ``N_extra_seq = 1024`` vs
    ``N_cluster = 128``) but compressed to a smaller channel dim
    ``c_e`` to stay cheap. Two differences from the main Evoformer:

    * MSA column attention is replaced by
      :class:`MSAColumnGlobalAttention` (Algorithm 19) — across
      thousands of sequences, per-head K/V sharing is what keeps
      the column step tractable.
    * Row attention with pair bias is inlined here rather than
      reusing :class:`~minalphafold.evoformer.MSARowAttentionWithPairBias`
      so ``c_e ≠ c_m`` projections stay self-contained.

    Consumes the extra MSA representation and the pair representation;
    writes updates back to both (triangle updates + pair transition
    apply after the OPM consumes the updated extra MSA).
    """

    def __init__(self, config):
        super().__init__()

        self.layer_norm_msa = torch.nn.LayerNorm(config.c_e)
        self.layer_norm_pair = torch.nn.LayerNorm(config.c_z)

        self.head_dim = config.extra_msa_dim
        self.num_heads = config.num_heads

        self.total_dim = self.head_dim * self.num_heads

        # MSA row attention with pair bias (inline, same as Algorithm 7)
        self.linear_q = torch.nn.Linear(in_features=config.c_e, out_features=self.total_dim, bias=False)
        self.linear_k = torch.nn.Linear(in_features=config.c_e, out_features=self.total_dim, bias=False)
        self.linear_v = torch.nn.Linear(in_features=config.c_e, out_features=self.total_dim, bias=False)

        self.linear_pair = torch.nn.Linear(in_features=config.c_z, out_features=self.num_heads, bias=False)

        self.linear_gate = torch.nn.Linear(in_features=config.c_e, out_features=self.total_dim)

        self.linear_output = torch.nn.Linear(in_features=self.total_dim, out_features=config.c_e)
        init_linear(self.linear_q, init="default")
        init_linear(self.linear_k, init="default")
        init_linear(self.linear_v, init="default")
        init_linear(self.linear_pair, init="default")
        init_gate_linear(self.linear_gate)
        init_linear(self.linear_output, init="final")

        self.msa_col_att = MSAColumnGlobalAttention(config, c_in=config.c_e)
        self.msa_transition = MSATransition(
            config,
            c_in=config.c_e,
            n=getattr(config, "extra_msa_transition_n", config.msa_transition_n),
        )
        self.outer_mean = OuterProductMean(
            config,
            c_in=config.c_e,
            c_hidden=getattr(config, "extra_msa_outer_product_dim", config.outer_product_dim),
        )

        self.triangle_mult_out = TriangleMultiplicationOutgoing(config)
        self.triangle_mult_in = TriangleMultiplicationIncoming(config)
        self.triangle_att_start = TriangleAttentionStartingNode(config)
        self.triangle_att_end = TriangleAttentionEndingNode(config)
        self.pair_transition = PairTransition(config)

        self.msa_dropout_p = config.extra_msa_dropout
        self.pair_dropout_p = config.extra_pair_dropout

    def forward(self, extra_msa_representation: torch.Tensor, pair_representation: torch.Tensor,
                extra_msa_mask: Optional[torch.Tensor] = None, pair_mask: Optional[torch.Tensor] = None):
        # extra_msa_representation shape: (batch, N_extra_seq, N_res, c_e)
        # pair_representation shape: (batch, N_res, N_res, c_z)
        # extra_msa_mask: (batch, N_extra_seq, N_res) — 1 for valid, 0 for padding
        # pair_mask: (batch, N_res, N_res) — 1 for valid, 0 for padding

        msa_representation = self.layer_norm_msa(extra_msa_representation)
        pair_norm = self.layer_norm_pair(pair_representation)

        # --- MSA row attention with pair bias ---

        # Shape (batch, N_extra_seq, N_res, total_dim)
        Q = self.linear_q(msa_representation)
        K = self.linear_k(msa_representation)
        V = self.linear_v(msa_representation)

        # Reshape to (batch, N_extra_seq, N_res, num_heads, head_dim)
        Q = Q.reshape((Q.shape[0], Q.shape[1], Q.shape[2], self.num_heads, self.head_dim))
        K = K.reshape((K.shape[0], K.shape[1], K.shape[2], self.num_heads, self.head_dim))
        V = V.reshape((V.shape[0], V.shape[1], V.shape[2], self.num_heads, self.head_dim))

        G = self.linear_gate(msa_representation)
        G = G.reshape((G.shape[0], G.shape[1], G.shape[2], self.num_heads, self.head_dim))

        # Squash values in range 0 to 1 to act as gating mechanism
        G = torch.sigmoid(G)

        # Pair bias: project pair representation to per-head bias
        # Shape (batch, N_res, N_res, num_heads) -> (batch, num_heads, N_res, N_res)
        B = self.linear_pair(pair_norm)
        B = B.permute(0, 3, 1, 2)

        # Add sequence dim for broadcast: (batch, 1, num_heads, N_res, N_res)
        B = B.unsqueeze(1)

        # Q shape (batch, N_extra_seq, N_res_i, num_heads, head_dim)
        # K shape (batch, N_extra_seq, N_res_j, num_heads, head_dim)
        # Output shape (batch, N_extra_seq, num_heads, N_res, N_res)
        scores = torch.einsum('bsihd, bsjhd -> bshij', Q, K)
        scores = scores / math.sqrt(self.head_dim) + B

        # Apply extra MSA mask to key positions (j dimension)
        if extra_msa_mask is not None:
            mask_bias = (1.0 - extra_msa_mask[:, :, None, None, :]) * (-1e9)
            scores = scores + mask_bias

        attention = torch.nn.functional.softmax(scores, dim=-1)

        # Shape (batch, N_extra_seq, N_res, num_heads, head_dim)
        values = torch.einsum('bshij, bsjhd -> bsihd', attention, V)

        values = G * values

        # Reshape to (batch, N_extra_seq, N_res, total_dim)
        values = values.reshape((values.shape[0], values.shape[1], values.shape[2], -1))

        row_update = self.linear_output(values)

        # Zero out padded query positions
        if extra_msa_mask is not None:
            row_update = row_update * extra_msa_mask[..., None]

        # --- MSA representation updates ---

        extra_msa_representation = extra_msa_representation + dropout_rowwise(
            row_update,
            p=self.msa_dropout_p,
            training=self.training,
        )
        extra_msa_representation = extra_msa_representation + self.msa_col_att(
            extra_msa_representation, msa_mask=extra_msa_mask)
        extra_msa_representation = extra_msa_representation + self.msa_transition(extra_msa_representation)

        # --- Pair representation updates ---

        pair_representation = pair_representation + self.outer_mean(
            extra_msa_representation, msa_mask=extra_msa_mask)
        pair_representation = pair_representation + dropout_rowwise(
            self.triangle_mult_out(pair_representation, pair_mask=pair_mask),
            p=self.pair_dropout_p,
            training=self.training,
        )
        pair_representation = pair_representation + dropout_rowwise(
            self.triangle_mult_in(pair_representation, pair_mask=pair_mask),
            p=self.pair_dropout_p,
            training=self.training,
        )
        pair_representation = pair_representation + dropout_rowwise(
            self.triangle_att_start(pair_representation, pair_mask=pair_mask),
            p=self.pair_dropout_p,
            training=self.training,
        )
        pair_representation = pair_representation + dropout_columnwise(
            self.triangle_att_end(pair_representation, pair_mask=pair_mask),
            p=self.pair_dropout_p,
            training=self.training,
        )
        pair_representation = pair_representation + self.pair_transition(pair_representation)

        return extra_msa_representation, pair_representation
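The `dropout_rowwise` / `dropout_columnwise` helpers used in the forward pass above share a single dropout mask across an entire row or column instead of drawing an independent mask per element. A minimal sketch of that idea (a hypothetical re-implementation, not the repo’s exact helpers):

```python
import torch

def dropout_shared(x, p, dim, training=True):
    # One Bernoulli mask broadcast along `dim`: all slices along that
    # axis are kept or dropped with the identical pattern.
    if not training or p == 0.0:
        return x
    shape = list(x.shape)
    shape[dim] = 1
    keep = torch.bernoulli(torch.full(shape, 1.0 - p)) / (1.0 - p)
    return x * keep

x = torch.ones(1, 4, 5, 3)               # e.g. (batch, rows, cols, channels)
y = dropout_shared(x, p=0.5, dim=-3)     # "row-wise": mask shared across rows
assert torch.equal(y[0, 0], y[0, 1])     # every row sees the identical mask
```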

The critical substitution is here: where the main MSA uses full column attention, the extra MSA uses MSAColumnGlobalAttention, which computes attention weights from a pooled summary of each column rather than from all pairwise scores. This is the same spirit as efficient-attention methods like Performer and Linformer — replace $O(s^2)$ column attention with an $O(s)$ aggregate, at the cost of some expressivity. The extra MSA can afford that cost because its role is to push information into $z$, not to build a refined per-residue representation.
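The pooled-query idea is small enough to sketch in isolation (single head, no gating, made-up dims — not the repo’s MSAColumnGlobalAttention):

```python
import math
import torch

s, r, c = 64, 8, 16                  # sequences, residues, channels (made-up)
x = torch.randn(s, r, c)             # one MSA column per residue, stacked

q_proj = torch.nn.Linear(c, c, bias=False)
k_proj = torch.nn.Linear(c, c, bias=False)
v_proj = torch.nn.Linear(c, c, bias=False)

# Global attention: average the queries over all s sequences, so each
# column needs only s scores instead of s x s pairwise scores.
q = q_proj(x).mean(dim=0)                                   # (r, c) pooled query
k, v = k_proj(x), v_proj(x)                                 # (s, r, c)
scores = torch.einsum('rc,src->sr', q, k) / math.sqrt(c)    # O(s), not O(s^2)
attn = torch.softmax(scores, dim=0)                         # softmax over sequences
out = torch.einsum('sr,src->rc', attn, v)                   # (r, c) column summary
assert out.shape == (r, c)
```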

8.2 Templates

When AlphaFold2 has structural hints from related proteins in the PDB, it processes them with a dedicated template pair stack. Each template is featurized into a pair representation of shape $(r, r, c_t)$ — distances between backbone atoms, unit vectors between Cα positions, sequence identity masks — and fed through a few Evoformer-style blocks specialized for pair-only updates.

class TemplatePair(torch.nn.Module):
    """Template pair stack (Algorithm 16, supplement 1.7.1).

    Per-template shallow Evoformer-like pair stack: each of
    ``config.template_pair_num_blocks`` blocks applies triangle
    self-attention (start + end), triangle multiplication
    (outgoing + incoming), and a pair transition. Dropout matches the
    supplement — row-wise on starting / multiplicative updates, column-
    wise on ending. Batch and template dims are flattened for each
    block so templates evolve independently; the final LayerNorm
    happens once before the pointwise attention pool in
    :class:`TemplatePointwiseAttention`.
    """

    def __init__(self, config):
        super().__init__()

        self.num_blocks = config.template_pair_num_blocks
        self.dropout_p = config.template_pair_dropout

        # Supplement 1.7.1 / Algorithm 16: the template pair stack overrides
        # the main-Evoformer triangle dims (paper default c=64 for both the
        # multiplicative and attention updates, and n=2 on the pair transition)
        # instead of inheriting triangle_mult_c / triangle_dim / pair_transition_n.
        template_tri_mult_c = config.template_triangle_mult_c
        template_tri_attn_c = config.template_triangle_attn_c
        template_tri_attn_heads = config.template_triangle_attn_num_heads
        template_pair_trans_n = config.template_pair_transition_n

        self.layer_norm = torch.nn.LayerNorm(config.c_t)
        self.linear_in = torch.nn.Linear(in_features=config.c_t, out_features=config.c_z)
        init_linear(self.linear_in, init="default")

        self.triangle_mult_out = torch.nn.ModuleList(
            [TriangleMultiplicationOutgoing(config, c=template_tri_mult_c) for _ in range(self.num_blocks)]
        )
        self.triangle_mult_in = torch.nn.ModuleList(
            [TriangleMultiplicationIncoming(config, c=template_tri_mult_c) for _ in range(self.num_blocks)]
        )
        self.triangle_att_start = torch.nn.ModuleList(
            [
                TriangleAttentionStartingNode(config, c=template_tri_attn_c, num_heads=template_tri_attn_heads)
                for _ in range(self.num_blocks)
            ]
        )
        self.triangle_att_end = torch.nn.ModuleList(
            [
                TriangleAttentionEndingNode(config, c=template_tri_attn_c, num_heads=template_tri_attn_heads)
                for _ in range(self.num_blocks)
            ]
        )
        self.pair_transition = torch.nn.ModuleList(
            [PairTransition(config, n=template_pair_trans_n) for _ in range(self.num_blocks)]
        )
        self.final_layer_norm = torch.nn.LayerNorm(config.c_z)

    def forward(self, template_feat: torch.Tensor, pair_mask: Optional[torch.Tensor] = None):
        # template_feat shape: (batch, N_templates, N_res, N_res, c_t)
        # pair_mask shape: (batch, N_templates, N_res, N_res) — 1 valid, 0 padding

        # Project from template feature space to pair representation space
        # Output shape: (batch, N_templates, N_res, N_res, c_z)
        template_feat = self.linear_in(self.layer_norm(template_feat))

        b, t, n_i, n_j, c = template_feat.shape

        # Merge batch and template dims to process each template independently
        # Shape: (batch * N_templates, N_res, N_res, c_z)
        pair_representation = template_feat.reshape(b * t, n_i, n_j, c)
        flat_pair_mask = None
        if pair_mask is not None:
            flat_pair_mask = pair_mask.reshape(b * t, n_i, n_j)
            pair_representation = pair_representation * flat_pair_mask[..., None]

        for block_idx in range(self.num_blocks):
            if flat_pair_mask is not None:
                pair_representation = pair_representation * flat_pair_mask[..., None]
            pair_representation = pair_representation + dropout_rowwise(
                self.triangle_att_start[block_idx](pair_representation, pair_mask=flat_pair_mask),
                p=self.dropout_p,
                training=self.training,
            )
            pair_representation = pair_representation + dropout_columnwise(
                self.triangle_att_end[block_idx](pair_representation, pair_mask=flat_pair_mask),
                p=self.dropout_p,
                training=self.training,
            )
            pair_representation = pair_representation + dropout_rowwise(
                self.triangle_mult_out[block_idx](pair_representation, pair_mask=flat_pair_mask),
                p=self.dropout_p,
                training=self.training,
            )
            pair_representation = pair_representation + dropout_rowwise(
                self.triangle_mult_in[block_idx](pair_representation, pair_mask=flat_pair_mask),
                p=self.dropout_p,
                training=self.training,
            )
            pair_representation = pair_representation + self.pair_transition[block_idx](pair_representation)
            if flat_pair_mask is not None:
                pair_representation = pair_representation * flat_pair_mask[..., None]

        pair_representation = self.final_layer_norm(pair_representation)
        if flat_pair_mask is not None:
            pair_representation = pair_representation * flat_pair_mask[..., None]

        # Restore batch and template dims
        # Output shape: (batch, N_templates, N_res, N_res, c_z)
        pair_representation = pair_representation.reshape(b, t, n_i, n_j, c)

        return pair_representation

The template pair stack is Evoformer-lite: triangle multiplication (outgoing + incoming), triangle attention (starting + ending), pair transition. No MSA side. The output of the stack for each template is a pair tensor, and those tensors get combined across templates by a TemplatePointwiseAttention layer that lets the target sequence attend across its available templates.
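A minimal single-head sketch of that pooling step (hypothetical dims, no gating or output projection — not the repo’s TemplatePointwiseAttention):

```python
import math
import torch

b, t, n, c = 1, 4, 6, 8                       # made-up dims
z = torch.randn(b, n, n, c)                   # target pair rep (queries)
templ = torch.randn(b, t, n, n, c)            # per-template pair tensors (keys/values)

q_proj = torch.nn.Linear(c, c, bias=False)
k_proj = torch.nn.Linear(c, c, bias=False)
v_proj = torch.nn.Linear(c, c, bias=False)

# Each pair entry z_ij attends over the t template embeddings at (i, j).
q = q_proj(z)                                                  # (b, n, n, c)
k, v = k_proj(templ), v_proj(templ)                            # (b, t, n, n, c)
scores = torch.einsum('bijc,btijc->bijt', q, k) / math.sqrt(c)
attn = torch.softmax(scores, dim=-1)                           # over templates
update = torch.einsum('bijt,btijc->bijc', attn, v)
assert update.shape == z.shape                                 # a pair-rep update
```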

For the minimal implementation and for this walkthrough, templates are present but optional. The tiny.toml config runs with zero templates; the full config runs with up to four. The important thing to take away is structural — templates enter the pipeline as pair-representation updates, exactly like the OPM output from the main MSA. By the time the Evoformer starts, zz has already been seeded with (a) relative positional encoding, (b) any template signal, and (c) the extra-MSA contribution. The main MSA will add more in each Evoformer block via the OPM.

9. Recycling

The last structural idea in the Evoformer-level description of AlphaFold2 is recycling — and it’s a surprisingly important one for something so simple. The setup is as follows: run a forward pass, get an MSA representation, a pair representation, and a structure module output. Then take (parts of) those outputs, feed them back into the inputs of the next forward pass, and run the whole thing again. Up to four times at inference, a random number of times during training.

[Figure: input features → input embedder → Evoformer × 48 → structure module → predicted structure, with a detach-and-recycle arrow feeding outputs back into the inputs.]
The recycling loop. On every cycle past the first, the MSA first row, the pair representation, and the pseudo-Cβ coordinates from the previous cycle are fed back and added into the inputs — after LayerNorm and, for coordinates, distance binning. Up to 4 cycles; gradients flow only through the last.

What exactly gets fed back? Three things: the first row of the final MSA representation ($m^{(1)}$, the “target” row), the final pair representation $z$, and a distance-binned version of the pseudo-Cβ coordinates extracted from the predicted structure. All three are run through their own LayerNorm and added into the current-cycle input embeddings. The pseudo-Cβ distances go through one extra step — recycling_distance_bin one-hot-encodes pairwise distances into 15 radial bins, linearly projects the bins into pair-representation dimension, and adds into $z$.
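The binning step can be sketched in a few lines (standalone; 15 bins as in the text, but with illustrative bin edges rather than the repo’s exact values):

```python
import torch

def distance_bin_onehot(x, n_bins=15, min_d=3.375, max_d=21.375):
    # Bucket pairwise distances into n_bins radial bins and one-hot encode.
    # Bin edges here are illustrative, not the repo's exact values.
    d = torch.cdist(x, x)                            # (N, N) pairwise distances
    edges = torch.linspace(min_d, max_d, n_bins - 1) # n_bins - 1 edges -> n_bins buckets
    idx = torch.bucketize(d, edges)                  # bin index in [0, n_bins)
    return torch.nn.functional.one_hot(idx, n_bins).float()

x = torch.randn(7, 3) * 10.0                         # fake pseudo-Cβ coordinates
onehot = distance_bin_onehot(x)                      # (N, N, n_bins), ready for a
assert onehot.shape == (7, 7, 15)                    # linear projection into z
```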


# Algorithm 2 line 6 (= Algorithm 32 / RecyclingEmbedder):
#   m_1i  += LayerNorm(m_1i^prev)
#   z_ij  += LayerNorm(z_ij^prev) + Linear(one_hot(d_ij^prev))
# On the first cycle the prev tensors are zero, so these
# additions vanish. Clone before the in-place write to the
# first MSA row so the embedder's output tensor is untouched.
msa_repr = msa_representation.clone()
pair_repr = pair_representation.clone()
msa_repr[:, 0, :, :] += self.recycle_norm_s(single_rep_prev)
pair_repr += self.recycle_norm_z(z_prev)
pair_repr += self.recycle_linear_d(recycling_distance_bin(x_prev, n_bins=15))

# Algorithm 2 lines 7-8: template torsion-angle embedding,
# concatenated onto the MSA representation as extra rows so

Two details matter for how this trains.

Stop-gradient between cycles. The tensors fed back are detached from the computation graph. Gradients only flow through the final cycle of a forward pass; every earlier cycle is a no-grad refinement. This keeps the memory cost of the whole thing linear in the last cycle rather than in the sum of all cycles. It also has a regularization effect — the model has to learn representations that are useful both (a) when they’re consumed by the current cycle’s structure module and (b) when they’re consumed by the next cycle’s input embedder. Representations that work for both are more robust than representations tuned only for the immediate downstream.

Random number of cycles during training. At inference, four cycles is a fixed choice. At training, the number of cycles is sampled uniformly per batch. This means the model sees “1-cycle” inputs (no feedback yet), “2-cycle” inputs, etc., and must produce a reasonable prediction in all of them — rather than being allowed to specialize for the last cycle alone. The recycling loop is therefore not just an inference-time trick; it’s a training regime, and the weights have to learn to converge iteratively.
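Both details fit naturally in one loop. Here’s a sketch of the regime with a stand-in single-tensor model (the real model recycles three tensors, but the control flow is the same):

```python
import torch

class DummyModel(torch.nn.Module):
    """Stand-in with the same recycling contract: consume a recycled
    tensor alongside the inputs, return a refined output."""
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, recycled):
        return self.linear(x + recycled)

def recycled_forward(model, x, n_cycles):
    prev = torch.zeros_like(x)                   # zero feedback on cycle 1
    out = x
    for cycle in range(n_cycles):
        last = cycle == n_cycles - 1
        ctx = torch.enable_grad() if last else torch.no_grad()
        with ctx:
            out = model(x, recycled=prev)
        if not last:
            prev = out.detach()                  # stop-gradient between cycles
    return out

model = DummyModel()
x = torch.randn(2, 4)
n_cycles = int(torch.randint(1, 5, ()))          # sampled uniformly, 1..4 cycles
out = recycled_forward(model, x, n_cycles)
out.sum().backward()                             # grads flow through the last cycle only
assert model.linear.weight.grad is not None
```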

I find recycling to be the single most underrated component of AlphaFold2. It’s the mechanism that turns a deep-but-not-that-deep Evoformer into an effectively 192-block iterative refiner. The cost at inference is a 4× compute blow-up, but the accuracy improvement is large — the paper’s ablations show it’s worth roughly 4 GDT points on CASP14, comparable to removing half the Evoformer blocks. All that from a small addition to the input embedding. It’s also what makes self-distillation (training on the model’s own predictions of unlabeled sequences) work: the model has already been trained to consume its own output, so distillation inputs live on-manifold.

10. The Structure Module

Everything we’ve done so far — input embedding, the entire Evoformer, recycling — has been a wind-up. The pair representation $z$ is finally dense with structural information, and the first row of the MSA representation holds a structurally-informed per-residue embedding $s_i$ for the target sequence. None of it, however, has a position in 3D. The Structure Module is where those abstract representations become atoms.

It does this in a way that’s genuinely different from the Evoformer. There’s no self-attention on generic tensors here; instead, every operation is built around rigid frames — a local coordinate system attached to each residue — and every attention weight and coordinate update respects the fact that proteins live in 3D space and don’t care about how you’ve oriented the global axes. That constraint, SE(3) equivariance, is what gives the Structure Module its distinctive shape.

The module runs for 8 shared-weight iterations, and each iteration does four things: invariant point attention, a transition MLP, a backbone update, and all-atom coordinate construction. We’ll build up to the whole loop in §10e; first, let’s get the geometry straight.

10a. Rigid frames

Every residue in an AlphaFold2 prediction has an associated rigid frame: a translation $t_i \in \mathbb{R}^3$ and a rotation $R_i \in \mathrm{SO}(3)$. Together, $T_i = (R_i, t_i)$ is an element of the special Euclidean group $\mathrm{SE}(3)$ — the group of rigid motions in 3D. You can think of $T_i$ as a little floating coordinate system attached to the residue. The origin of that coordinate system sits on the residue’s $C_\alpha$ atom; the $x$ axis points roughly along the $C_\alpha \to C$ bond; the $y$ axis lives in the plane of $N$, $C_\alpha$, $C$ and is orthogonal to $x$; and $z$ is their cross product.

def rigid_frame_from_three_points(
    point_on_neg_x_axis: torch.Tensor,
    origin: torch.Tensor,
    point_on_xy_plane: torch.Tensor,

The function computes frames from triples of atom coordinates — in AlphaFold2, these are the $N$, $C_\alpha$, and $C$ backbone atoms — via a three-step Gram–Schmidt dance: translate so $C_\alpha$ is at the origin, orient $x$ along $C_\alpha \to C$, and orient $y$ to project $N$ into the $xy$-plane. The choice is standard (Algorithm 21), and the specific convention matters only inasmuch as it has to match the convention used to build side chains later.
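A standalone sketch of that construction (toy inputs; not the repo’s `rigid_frame_from_three_points`):

```python
import torch

def frame_from_three_points(n, ca, c):
    # Gram–Schmidt, per the three steps in the text: origin at CA,
    # x along CA -> C, y from N projected into the xy-plane.
    v1, v2 = c - ca, n - ca
    e1 = v1 / v1.norm()                       # x axis
    u2 = v2 - (e1 @ v2) * e1                  # strip the x component of N
    e2 = u2 / u2.norm()                       # y axis
    e3 = torch.linalg.cross(e1, e2)           # z = x × y
    R = torch.stack([e1, e2, e3], dim=-1)     # columns are the local axes
    return R, ca                              # rotation + translation

n_atom = torch.tensor([1.0, 1.0, 0.0])
ca = torch.tensor([0.0, 0.0, 0.0])
c_atom = torch.tensor([1.5, 0.0, 0.0])
R, t = frame_from_three_points(n_atom, ca, c_atom)
assert torch.allclose(R @ R.T, torch.eye(3), atol=1e-5)        # orthonormal
assert torch.allclose(R[:, 0], torch.tensor([1.0, 0.0, 0.0]))  # x along CA -> C
```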

Why carry around frames at all, rather than just Cartesian coordinates? Two reasons. First, frames express geometry relatively — for any two residues $i$ and $j$, the position of $j$ “as seen by $i$” is just $T_i^{-1} \cdot t_j$, which is residue $j$’s translation expressed in $i$’s local coordinate system. That operation is invariant under any global rigid motion of the whole protein, and invariance is something we want to bake in rather than learn. Second, frames compose cleanly — if you have an update $\Delta T$ you’d like to apply to a residue, the composition $T_i \leftarrow T_i \cdot \Delta T$ is again a frame. That’s how backbone updates work in §10c, and how side-chain atoms are built by composing torsion-angle rotations with literature frames in §10d.

Quaternions as the rotation parameterization

Rotations have three natural representations: rotation matrices (9 numbers, 6 constraints — orthogonality), Euler angles (3 numbers, all kinds of pathologies — gimbal lock, discontinuous wrap-around), and quaternions (4 numbers, 1 constraint — unit norm). AlphaFold2 uses quaternions everywhere a rotation is being output or updated by the model, and matrices everywhere a rotation is being applied.

The reason is mostly about smooth updates. When the Structure Module emits a backbone update, it outputs three scalars — call them $(b, c, d)$ — and forms an un-normalized quaternion $q = (1, b, c, d)$. You then normalize to get a unit quaternion and convert to a rotation matrix. This parameterization has a very useful property: the identity rotation is at $q = (1, 0, 0, 0)$, which corresponds to $(b, c, d) = (0, 0, 0)$. That means “no update” is the all-zeros output, which is exactly what you want the update head to default to at initialization. Contrast this with Euler angles, where there is no similarly principled “zero” that makes the model start stably, or rotation matrices, where producing a valid $R \in \mathrm{SO}(3)$ from a linear layer requires an awkward projection.

Concretely, a unit quaternion $q = (w, x, y, z)$ maps to a rotation matrix via:

$$R(q) = \begin{pmatrix} 1 - 2(y^2 + z^2) & 2(xy - wz) & 2(xz + wy) \\ 2(xy + wz) & 1 - 2(x^2 + z^2) & 2(yz - wx) \\ 2(xz - wy) & 2(yz + wx) & 1 - 2(x^2 + y^2) \end{pmatrix}$$
(quat→R)
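A standalone sketch of the update parameterization, checking both properties the text relies on — zeros give the identity, and any $(b, c, d)$ gives a valid rotation:

```python
import torch

def quat_to_rot(q):
    # Unit quaternion (w, x, y, z) -> 3x3 rotation matrix (the quat→R formula).
    w, x, y, z = q.unbind()
    return torch.stack([
        torch.stack([1 - 2*(y*y + z*z), 2*(x*y - w*z),     2*(x*z + w*y)]),
        torch.stack([2*(x*y + w*z),     1 - 2*(x*x + z*z), 2*(y*z - w*x)]),
        torch.stack([2*(x*z - w*y),     2*(y*z + w*x),     1 - 2*(x*x + y*y)]),
    ])

def update_rotation(bcd):
    # Structure-module-style update: three scalars -> q = (1, b, c, d),
    # normalize, convert. The all-zeros output is exactly "no rotation".
    q = torch.cat([torch.ones(1), bcd])
    return quat_to_rot(q / q.norm())

assert torch.allclose(update_rotation(torch.zeros(3)), torch.eye(3))
R = update_rotation(torch.randn(3))
assert torch.allclose(R @ R.T, torch.eye(3), atol=1e-5)   # valid rotation
```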

The widget below is the clearest way I know to internalize this formula. Drag the four sliders to change the quaternion components, and the rotation matrix, the normalized quaternion, and the attached local frame (thick red/green/blue axes, plus a small gray cube) all update live. The world frame stays ghosted in the background for reference.

[Interactive widget: four quaternion sliders driving a live local frame (thick red/green/blue axes plus a small gray cube), the normalized quaternion, and the $R(q)$ matrix, with the world frame ghosted for reference.]
Drag inside the scene to rotate the camera. Adjust the sliders to change the quaternion; the matrix, the local frame, and the attached cube all update live. Only the direction of q matters — the rotation is unchanged if you scale all four components uniformly, since we always normalize (and treat the zero quaternion as the identity).

Play with it for a minute. A few things worth noticing: (1) the rotation depends only on the direction of $q$, not its magnitude — since we always normalize, scaling all four sliders proportionally leaves the output unchanged. (2) $q$ and $-q$ give the same rotation matrix; this is the “double cover” of $\mathrm{SO}(3)$ by the unit quaternions $S^3$, and it means the all-zero-but-$w$ quaternion and the all-zero-but-$-w$ quaternion both represent the identity rotation. (3) The preset for “90° about X” (or Y) sits at $q = (1/\sqrt{2}, 1/\sqrt{2}, 0, 0)$ rather than anything you’d expect from degrees — quaternions use the half-angle, because they double-cover. Facts like these come up constantly when reading rotation code.

def backbone_frames(
    atom14_positions: torch.Tensor,
    atom14_mask: torch.Tensor,
    aatype: torch.Tensor | None = None,
):
    ...

The backbone_frames helper just calls rigid_frame_from_three_points in a loop across residues — this is how the ground-truth frames (and, initially, the model’s prior-cycle frames during recycling) get built from atom coordinates.
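For reference, the frame construction underneath (Algorithm 21 in the supplement, rigidFrom3Points) is a small Gram–Schmidt over the three backbone atoms. A minimal sketch of the idea — my own version, not the repo's exact function:

```python
import torch

def rigid_from_three_points(x1, x2, x3):
    """Gram-Schmidt frame from three points (N, CA, C): returns (R, t) with
    orthonormal columns e1, e2, e3 and the origin at x2 (the CA atom)."""
    v1 = x3 - x2
    v2 = x1 - x2
    e1 = v1 / v1.norm()
    u2 = v2 - e1 * (e1 @ v2)      # remove the component of v2 along e1
    e2 = u2 / u2.norm()
    e3 = torch.linalg.cross(e1, e2)
    R = torch.stack([e1, e2, e3], dim=-1)  # columns are the basis vectors
    return R, x2

# toy backbone atoms (not literature geometry)
n  = torch.tensor([1.0, 0.0, 0.0])
ca = torch.tensor([0.0, 0.0, 0.0])
c  = torch.tensor([0.0, 1.5, 0.2])
R, t = rigid_from_three_points(n, ca, c)
assert torch.allclose(R.T @ R, torch.eye(3), atol=1e-6)           # orthonormal
assert torch.isclose(torch.det(R), torch.tensor(1.0), atol=1e-6)  # proper rotation
```

A nice property worth checking: rotate all three input atoms by some $G$ and the output frame rotates by exactly $G$ — which is what makes ground-truth frames well-defined up to the global pose.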

10b. Invariant Point Attention

Invariant Point Attention is the signature innovation of the Structure Module. Built on top of rigid frames, it’s a three-term attention score that is provably invariant to any global rigid motion of the input. This is not a soft inductive bias — it’s an algebraic property of the operation. Rotate the entire protein in space and the IPA scores are numerically unchanged.

The score between a query residue $i$ and a key residue $j$ for head $h$ looks like:

$$a^h_{ij} \propto \underbrace{\frac{1}{\sqrt{c}}\, q^h_i \cdot k^h_j}_{\text{scalar}} + \underbrace{b^h_{ij}}_{\text{pair bias}} - \underbrace{\frac{\gamma^h}{\sqrt{2\,c_p}} \sum_{p=1}^{c_p} \left\| T_i \cdot \vec{q}^{\,h,p}_i - T_j \cdot \vec{k}^{\,h,p}_j \right\|^2}_{\text{point distance}} \tag{22}$$

Three terms, three sources of information. The scalar term is a standard attention score on scalar queries and keys — the same as what the Evoformer does. The pair bias term injects the current pair representation $z_{ij}$ into the attention pattern, just like row attention did in §4. The new piece is the point distance term.

Each query head carries a small bank of 3D query points $\vec{q}^{\,h,p}_i$ — "3D" because they're elements of $\mathbb{R}^3$, not because they're positions in any particular frame. They're expressed in residue $i$'s local frame; applying $T_i$ lifts them into global coordinates. The same goes for key points $\vec{k}^{\,h,p}_j$. The point-distance term is a negative squared-distance between each pair of globally-transformed points, summed across points. If a query point lifted from $i$ ends up close in 3D to a key point lifted from $j$, their pair gets a high attention score.

Why does this give invariance? Suppose we apply a global rigid motion $G \in \mathrm{SE}(3)$ to every residue. Then $T_i \mapsto G \circ T_i$ and $T_j \mapsto G \circ T_j$. The globally-transformed points become $G \cdot (T_i \cdot \vec{q}^{\,h,p}_i)$ and $G \cdot (T_j \cdot \vec{k}^{\,h,p}_j)$, and rigid motions preserve distances, so the squared-distance term is unchanged. The scalar term doesn't touch positions, and the pair bias comes from $z$ (which has no spatial orientation). Hence all three terms are preserved, and the attention pattern is preserved.
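The invariance argument can be checked numerically: build random frames and local points, apply a random global rigid motion to every frame, and confirm the point-distance term doesn't move. A self-contained sketch (not repo code):

```python
import torch

torch.manual_seed(0)

def random_rotation():
    # QR of a random matrix gives an orthogonal Q; fix the sign so det = +1
    q, _ = torch.linalg.qr(torch.randn(3, 3))
    return q * torch.det(q).sign()

n_res, n_points = 5, 4
R = torch.stack([random_rotation() for _ in range(n_res)])  # (n_res, 3, 3)
t = torch.randn(n_res, 3)                                   # (n_res, 3)
q_pts = torch.randn(n_res, n_points, 3)                     # local query points
k_pts = torch.randn(n_res, n_points, 3)                     # local key points

def point_dist_term(R, t, q_pts, k_pts):
    # lift local points into global coordinates: T_i . p = R_i p + t_i
    qg = torch.einsum('iab,ipb->ipa', R, q_pts) + t[:, None, :]
    kg = torch.einsum('iab,ipb->ipa', R, k_pts) + t[:, None, :]
    diff = qg[:, None, :, :] - kg[None, :, :, :]            # (i, j, p, 3)
    return (diff ** 2).sum(-1).sum(-1)                      # (i, j)

before = point_dist_term(R, t, q_pts, k_pts)

# apply a global rigid motion G = (R_g, t_g) to every frame: T_i <- G . T_i
R_g, t_g = random_rotation(), torch.randn(3)
after = point_dist_term(
    torch.einsum('ab,ibc->iac', R_g, R),
    torch.einsum('ab,ib->ia', R_g, t) + t_g,
    q_pts, k_pts,
)

assert torch.allclose(before, after, atol=1e-3)  # scores unchanged under G
```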

Here’s what this looks like, conceptually, on a synthetic chain. Click a residue to select it as the query ii. Top-5 attention weights are drawn as lines from the selected sphere to the others, with thickness proportional to weight, and also listed in the panel on the right. Toggle Auto-rotate scene: the whole chain spins, but the attention pattern stays fixed — individual lines don’t brighten or fade as the scene rotates.

[Interactive widget: synthetic chain with a selectable query residue $i$. For the default $i = 6$, the top-5 attention weights are residue 7: 0.191, residue 5: 0.191, residue 8: 0.129, residue 4: 0.129, residue 9: 0.077.]

The attention weight for $(i \to j)$ depends on $T_i^{-1} \cdot \vec{t}_j$ — the position of residue $j$ as seen from residue $i$'s local frame. This is invariant under any global rotation of the chain, which is why the connecting lines thicken and fade only when you change the query, not when the scene spins.

The score function in the widget is simplified — I use a single point per query/key, namely each residue's own position — but the geometric story it tells is the real one. A few things to watch for: (1) when you toggle "Auto-rotate scene," the view spins but the weights don't. That's invariance in action. (2) Attention is not the same as nearest-neighbor in 3D — the sequence-distance penalty gives the pattern an asymmetry that a pure spatial score wouldn't. (3) When you pick a residue in the middle of the chain, the attention weights are richer than at the ends — another thing a learned attention would mirror.

The code

class InvariantPointAttention(torch.nn.Module):
    """Invariant Point Attention (Algorithm 22, supplement 1.8.2).

    The geometry-aware attention that lets the Structure Module reason
    about 3-D context while staying invariant to global rigid motions
    of the input. For each head, three attention-score contributions are summed:

    1. Standard scalar Q·K attention on channel projections of ``s_i``.
    2. Pair bias ``b_{ij} = Linear(z_{ij})`` — the pair rep as score bias.
    3. "Point" attention on 3-D points ``Q, K, V`` lifted out of each
       residue's local frame into global coordinates by ``T_i``: the
       attention score for (i, j) is
       ``-γ_h · ||T_i(q_i^h) - T_j(k_j^h)||^2 / 2`` (line 7). Invariance
       to rigid motion of the whole protein is guaranteed because frames
       and points move together.

    The combined score is ``(1/sqrt(3)) · (scalar + bias + points)`` per
    supplement 1.8.2 (line 7), softmax'd over j, and the output pools
    scalar values, pair values, and point values (transformed back into
    each residue's local frame). ``seq_mask`` zeros both queries and
    keys before the softmax so padded residues neither attend nor are
    attended to.
    """

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.ipa_num_heads
        self.head_dim = config.ipa_c
        self.total_dim = self.head_dim * self.num_heads
        self.n_query_points = config.ipa_n_query_points
        self.n_value_points = config.ipa_n_value_points
        self.inf = 1e5
        self.eps = 1e-8

        # Canonical AF2/OpenFold monomer IPA uses a biased query projection and
        # combined key/value projections for both scalar and point features.
        self.linear_q = torch.nn.Linear(in_features=config.c_s, out_features=self.total_dim, bias=True)
        self.linear_kv = torch.nn.Linear(in_features=config.c_s, out_features=2 * self.total_dim, bias=True)

        self.linear_q_points = torch.nn.Linear(
            in_features=config.c_s,
            out_features=3 * self.num_heads * self.n_query_points,
            bias=True,
        )
        self.linear_kv_points = torch.nn.Linear(
            in_features=config.c_s,
            out_features=3 * self.num_heads * (self.n_query_points + self.n_value_points),
            bias=True,
        )

        self.linear_bias = torch.nn.Linear(in_features=config.c_z, out_features=self.num_heads, bias=True)

        self.linear_output = torch.nn.Linear(
            in_features=self.total_dim
            + self.num_heads * self.n_value_points * 4
            + self.num_heads * config.c_z,
            out_features=config.c_s,
        )

        _init_linear(self.linear_q, init="default")
        _init_linear(self.linear_kv, init="default")
        _init_linear(self.linear_q_points, init="default")
        _init_linear(self.linear_kv_points, init="default")
        _init_linear(self.linear_bias, init="default")
        _init_linear(self.linear_output, init="final")

        self.head_weights = torch.nn.Parameter(torch.zeros(self.num_heads))

    def _project_points(
        self,
        linear: torch.nn.Linear,
        single_representation: torch.Tensor,
        num_points: int,
    ) -> torch.Tensor:
        raw_points = linear(single_representation)
        x_coords, y_coords, z_coords = torch.chunk(raw_points, 3, dim=-1)
        point_coords = torch.stack([x_coords, y_coords, z_coords], dim=-1)
        return point_coords.reshape(
            single_representation.shape[0],
            single_representation.shape[1],
            self.num_heads,
            num_points,
            3,
        )

    def _assemble_output_features(
        self,
        attention: torch.Tensor,
        value_scalar: torch.Tensor,
        value_points_global: torch.Tensor,
        pair_representation: torch.Tensor,
        rotations: torch.Tensor,
        translation: torch.Tensor,
    ) -> torch.Tensor:
        batch_size = rotations.shape[0]
        num_residues = rotations.shape[1]

        output_scalar = torch.matmul(
            attention,
            value_scalar.permute(0, 2, 1, 3).to(dtype=attention.dtype),
        ).permute(0, 2, 1, 3)
        output_scalar = output_scalar.reshape(batch_size, num_residues, -1)

        result_point_global = torch.einsum(
            "bhij,bjhpc->bihpc",
            attention,
            value_points_global.to(dtype=attention.dtype),
        )
        result_point_local = torch.einsum(
            "biop,bihqp->bihqo",
            rotations.transpose(-1, -2),
            result_point_global - translation[:, :, None, None, :],
        )
        result_point_norms = torch.sqrt(torch.sum(result_point_local ** 2, dim=-1) + self.eps)
        result_point_norms = result_point_norms.reshape(batch_size, num_residues, -1)

        # Canonical monomer IPA concatenates x/y/z point channels separately.
        result_point_local = result_point_local.reshape(batch_size, num_residues, -1, 3)
        result_point_x, result_point_y, result_point_z = result_point_local.unbind(dim=-1)

        output_pair = torch.einsum(
            "bhij,bijd->bihd",
            attention,
            pair_representation.to(dtype=attention.dtype),
        )
        output_pair = output_pair.reshape(batch_size, num_residues, -1)

        return torch.cat(
            [
                output_scalar,
                result_point_x,
                result_point_y,
                result_point_z,
                result_point_norms,
                output_pair,
            ],
            dim=-1,
        )

    def _forward_output_features(
        self,
        single_representation: torch.Tensor,
        pair_representation: torch.Tensor,
        rotations: torch.Tensor,
        translation: torch.Tensor,
        seq_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch_size = single_representation.shape[0]
        num_residues = single_representation.shape[1]

        q = self.linear_q(single_representation).reshape(batch_size, num_residues, self.num_heads, self.head_dim)
        kv = self.linear_kv(single_representation).reshape(batch_size, num_residues, self.num_heads, 2 * self.head_dim)
        k, v = torch.split(kv, self.head_dim, dim=-1)

        q_points = self._project_points(self.linear_q_points, single_representation, self.n_query_points)
        kv_points = self._project_points(
            self.linear_kv_points,
            single_representation,
            self.n_query_points + self.n_value_points,
        )
        k_points, v_points = torch.split(
            kv_points,
            [self.n_query_points, self.n_value_points],
            dim=-2,
        )

        query_points_global = torch.einsum("biop,bihqp->bihqo", rotations, q_points) + translation[:, :, None, None, :]
        key_points_global = torch.einsum("biop,bihqp->bihqo", rotations, k_points) + translation[:, :, None, None, :]
        value_points_global = torch.einsum("biop,bihqp->bihqo", rotations, v_points) + translation[:, :, None, None, :]

        bias = self.linear_bias(pair_representation)

        attention_logits = torch.matmul(
            q.permute(0, 2, 1, 3),
            k.permute(0, 2, 3, 1),
        )
        attention_logits *= math.sqrt(1.0 / (3.0 * self.head_dim))
        attention_logits += math.sqrt(1.0 / 3.0) * bias.permute(0, 3, 1, 2)

        point_attention = query_points_global[:, :, None, :, :, :] - key_points_global[:, None, :, :, :, :]
        point_attention = torch.sum(point_attention ** 2, dim=-1)
        head_weights = torch.nn.functional.softplus(self.head_weights).view(1, 1, 1, self.num_heads, 1)
        head_weights = head_weights * math.sqrt(1.0 / (3.0 * (self.n_query_points * 9.0 / 2.0)))
        point_attention = torch.sum(point_attention * head_weights, dim=-1) * (-0.5)
        attention_logits += point_attention.permute(0, 3, 1, 2)

        if seq_mask is not None:
            square_mask = seq_mask.unsqueeze(-1) * seq_mask.unsqueeze(-2)
            attention_logits = attention_logits + self.inf * (square_mask[:, None, :, :] - 1.0)

        attention = torch.nn.functional.softmax(attention_logits, dim=-1)
        return self._assemble_output_features(
            attention=attention,
            value_scalar=v,
            value_points_global=value_points_global,
            pair_representation=pair_representation,
            rotations=rotations,
            translation=translation,
        )

    def forward(self, single_representation: torch.Tensor, pair_representation: torch.Tensor,
                rotations: torch.Tensor, translation: torch.Tensor,
                seq_mask: Optional[torch.Tensor] = None):
        # single_rep shape: (batch, N_res, c_s)
        # pair_rep shape: (batch, N_res, N_res, c_z)
        # rotations shape: (batch, N_res, 3, 3)
        # translations shape: (batch, N_res, 3)
        # seq_mask shape: (batch, N_res) — 1 for valid, 0 for padding
        assert rotations.shape[-2:] == (3, 3), \
            f"rotations must end with (3, 3), got {rotations.shape}"
        assert translation.shape[-1] == 3, \
            f"translation must end with (3,), got {translation.shape}"

        output_features = self._forward_output_features(
            single_representation,
            pair_representation,
            rotations,
            translation,
            seq_mask=seq_mask,
        )
        output = self.linear_output(output_features)

        # Zero out padded query positions
        if seq_mask is not None:
            output = output * seq_mask[:, :, None]

        return output

Three things to look for as you read. First, the linear projections fall into three clean groups: scalar queries/keys/values (linear_q, linear_kv), 3D query/key/value points (linear_q_points, linear_kv_points), and the pair bias (linear_bias). The scalar ones produce tensors of shape $(r, h, c)$; the point ones produce $(r, h, p, 3)$ — the trailing dimension of size 3 is 3D space. Second, _project_points only reshapes the raw projections into point banks; it's the einsum with rotations plus the added translation that applies each residue's frame — the $T_i \cdot \vec{q}$ operation in the equation. Third, the score is a sum of the three terms — a standard Q·K dot product, a pair-bias projection, and a squared-distance sum across points — with the point term scaled by a learnable per-head weight $\gamma^h$ (the softplus of head_weights) so the model can dial its influence up or down.

The output is a bit more elaborate than usual, too. IPA returns (1) a weighted sum of scalar values (regular attention output), (2) a weighted sum of value points, transformed back into each residue's local frame — a point-cloud output with several 3D vectors per head, (3) the norms of those pooled points, and (4) an attention-weighted sum over rows of the pair representation. All four get concatenated and fed through linear_output, which produces the additive update to the single representation $s_i$. In other words, we get scalar information out of IPA, but we also get geometric information — a key insight the paper drives home.
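The width of that concatenation explains the odd `in_features` expression on `linear_output`. Tallying it once, with what I believe are the monomer defaults (12 heads, head dim 16, 8 value points, `c_z` = 128):

```python
# Width of the concatenated IPA output, mirroring linear_output's in_features.
# Assumed AF2 monomer defaults: 12 heads, head dim c = 16, 8 value points, c_z = 128.
num_heads, head_dim = 12, 16
n_value_points, c_z = 8, 128

scalar_out = num_heads * head_dim            # attention-pooled scalar values
point_out  = num_heads * n_value_points * 3  # x/y/z of pooled value points (local frame)
norm_out   = num_heads * n_value_points      # norms of those pooled points
pair_out   = num_heads * c_z                 # attention-pooled pair-representation rows

total = scalar_out + point_out + norm_out + pair_out
# matches total_dim + num_heads * n_value_points * 4 + num_heads * c_z in __init__
assert total == num_heads * head_dim + num_heads * n_value_points * 4 + num_heads * c_z
print(total)  # 2112
```

The `* 4` in the module's `in_features` is the three local coordinates plus one norm per value point, which the tally above makes explicit.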

IPA is genuinely new. There are plenty of attention variants that operate on graphs or respect some symmetry, but IPA’s particular trick — attention scores as a learnable mixture of scalar, pair-bias, and 3D-point-distance terms, with the point-distance term made invariant by working in local frames — didn’t exist before AF2. Every downstream structure-prediction paper (ESMFold’s folding trunk, OpenFold’s reimplementation, AlphaFold-Multimer, RoseTTAFold’s refinement network) borrows from it. When you see “SE(3)-equivariant attention” in later papers, IPA is one of the ancestors.

10c. Backbone update

IPA and its transition MLP produce an updated single representation $s_i$ per residue. The backbone update is the step that takes that $s_i$ and turns it into an update to the frame $T_i$. The mechanics are simple: project $s_i$ to six scalars, interpret the first three as the imaginary part of an un-normalized quaternion, interpret the last three as a translation, and compose into $T_i$.

class BackboneUpdate(torch.nn.Module):
    """Backbone frame update (Algorithm 23, supplement 1.8.3).

    Projects ``s_i`` to six scalars per residue: three for the
    imaginary part of an un-normalized quaternion ``(1, b, c, d)``
    (normalised in lines 3-4) giving the local rotation update, and
    three for the translation update (in the residue's own frame,
    line 5). The current frame ``T_i`` is then updated as
    ``T_i ← T_i ∘ (R, t)``.
    The output projection is ``final``-initialised (zero) so each IPA
    iteration starts as a no-op — training discovers the updates.
    """

    def __init__(self, config):
        super().__init__()
        self.linear = torch.nn.Linear(in_features=config.c_s, out_features=6)
        _init_linear(self.linear, init="final")

    def forward(self, single_representation: torch.Tensor):
        # output shape: (batch, N_res, 6)
        vals = self.linear(single_representation)

        # Rotation quaternions
        b = vals[:, :, 0]
        c = vals[:, :, 1]
        d = vals[:, :, 2]

        a = torch.ones_like(b)
        norm = torch.sqrt(1 + b**2 + c**2 + d**2)

        a = a / norm
        b = b / norm
        c = c / norm
        d = d / norm

        # Construct pairwise multiplications for rotation matrix
        aa = a*a
        bb = b*b
        cc = c*c
        dd = d*d

        ab = a*b
        ac = a*c
        ad = a*d

        bc = b*c
        bd = b*d

        cd = c*d

        # Construct rotation matrix entries
        r11 = aa + bb - cc - dd
        r12 = 2*bc - 2*ad
        r13 = 2*bd + 2*ac

        r21 = 2*bc + 2*ad
        r22 = aa - bb + cc - dd
        r23 = 2*cd - 2*ab

        r31 = 2*bd - 2*ac
        r32 = 2*cd + 2*ab
        r33 = aa - bb - cc + dd

        # Output shape: (batch, N_res, 3, 3)
        R = torch.stack([r11, r12, r13, r21, r22, r23, r31, r32, r33], dim=-1).reshape((single_representation.shape[0], single_representation.shape[1], 3, 3))

        t = vals[:, :, 3:]

        return R, t

The quaternion construction is the clever bit. Rather than projecting $s_i$ to four scalars $(w, x, y, z)$, AlphaFold2 projects to three — $(b, c, d)$ — and forms $q = (1, b, c, d)$. Why? Because at initialization, the projection (self.linear, zero-initialized via init="final") outputs near-zero values, so $(b, c, d) \approx (0, 0, 0)$ and $q \approx (1, 0, 0, 0)$ — the identity quaternion. Hence the update is the identity at initialization. Without this trick, the model would start predicting random rotations on cycle 1 and have to unlearn them; with it, the model starts at "no change from the previous frame" and learns corrections.

Read that carefully — the backbone update composes frames; it is not additive in the space of quaternion components. The composition $T_i \leftarrow T_i \circ (R(q), t_{\text{pred}})$ multiplies rotations (because $\mathrm{SO}(3)$ isn't a vector space) and adds translations. The QuaternionFrame widget from §10a lets you see concretely what small perturbations in $(b, c, d)$ do to the rotation matrix — set $(b, c, d)$ to small values and watch the rotation stay close to identity.
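A minimal sketch of that composition (`compose_frame` is my name; it mirrors the two einsums in the Structure Module loop) — rotations multiply, translations pass through the current rotation, and the zero update is a no-op:

```python
import torch

def compose_frame(R, t, R_upd, t_upd):
    """T <- T . (R_upd, t_upd): rotations multiply, the translation update is
    rotated into the current frame's orientation before being added."""
    return R @ R_upd, (R @ t_upd.unsqueeze(-1)).squeeze(-1) + t

# a random current frame
R = torch.linalg.qr(torch.randn(3, 3))[0]
R = R * torch.det(R).sign()  # force det = +1
t = torch.randn(3)

# the identity update — (b, c, d) = 0 and t_upd = 0 — leaves the frame unchanged,
# which is exactly the zero-init behaviour of BackboneUpdate
R2, t2 = compose_frame(R, t, torch.eye(3), torch.zeros(3))
assert torch.allclose(R2, R) and torch.allclose(t2, t)
```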

10d. All-atom coordinates

A per-residue backbone frame isn’t enough to call a protein “predicted.” We need all the atoms: backbone (N, Cα, C, O) plus whatever side chain the residue type has. For all 20 amino acids, this adds up to at most 14 atoms per residue — the “atom14” representation used throughout the codebase.

So how do we get from backbone frames to all atoms? The trick AlphaFold2 uses is to reduce side-chain structure to seven torsion angles. Every residue, regardless of type, has (at most) seven rotatable bonds — the three backbone dihedrals ($\phi$, $\psi$, $\omega$) plus four side-chain $\chi$ angles. Given the torsion angles and the residue type, the full all-atom structure is determined — the geometry of every bond length and every bond angle is fixed by chemistry (literature values in residue_constants.py), and only the torsions are free.

AlphaFold2 predicts the seven torsions per residue with a small head attached to the Structure Module output, then composes a chain of small rigid-body transformations to produce the atom coordinates:

def compute_all_atom_coordinates(
    translations: torch.Tensor,   # (batch, N_res, 3)
    rotations: torch.Tensor,      # (batch, N_res, 3, 3)
    torsion_angles: torch.Tensor, # (batch, N_res, 7, 2) — [ω, φ, ψ, χ1, χ2, χ3, χ4]
    aatype: torch.Tensor,         # (batch, N_res) — integer residue type indices
    default_frames: torch.Tensor, # (21, 8, 4, 4) — registered buffer
    lit_positions: torch.Tensor,  # (21, 14, 3) — registered buffer
    atom_frame_idx_table: torch.Tensor,  # (21, 14) — registered buffer
    atom_mask_table: torch.Tensor,       # (21, 14) — registered buffer
):
    ...

For each residue, eight rigid groups get assembled:

  1. Backbone (N, Cα, C): sits in the residue's backbone frame $T_i$ directly.
  2. Group for $\chi_1$: positioned by rotating around a literature-defined axis by the predicted $\chi_1$ angle, with literature-defined atoms attached.
  3. Group for $\chi_2$: positioned relative to the $\chi_1$ group by another torsion rotation.
  4. Groups for $\chi_3$, $\chi_4$: same pattern.
  5. Three more groups for the backbone torsions $\omega$, $\phi$, $\psi$ — the $\psi$ group is what places the backbone carbonyl O.

def rigid_group_frames_from_torsions(
    translations: torch.Tensor,   # (batch, N_res, 3)
    rotations: torch.Tensor,      # (batch, N_res, 3, 3)
    torsion_angles: torch.Tensor, # (batch, N_res, 7, 2) — [ω, φ, ψ, χ1, χ2, χ3, χ4]
    aatype: torch.Tensor,         # (batch, N_res)
    default_frames: torch.Tensor, # (21, 8, 4, 4)
):
    ...

The make_rot_x helper does the geometric work — it rotates around the local $x$-axis by a predicted angle, encoded as its sine and cosine (rather than the raw angle, so the model doesn't have to deal with wrap-around at $\pm\pi$). Each rigid-group frame $T^{\text{group}}_g$ is computed as $T^{\text{parent}}_g \cdot T^{\text{literature}}_g \cdot \mathrm{Rot}_x(\alpha_g)$ — compose the parent group's frame with the literature "default" transformation and a torsion-angle rotation.
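A minimal sketch of what such a helper can look like — my own version, which may differ from the repo's exact make_rot_x:

```python
import torch

def make_rot_x(sin_cos: torch.Tensor) -> torch.Tensor:
    """Rotation about the local x-axis from a (sin, cos) pair.
    Predicting sin/cos instead of a raw angle avoids the wrap-around at +/-pi."""
    sin, cos = sin_cos.unbind(-1)
    norm = torch.sqrt(sin**2 + cos**2 + 1e-8)  # normalize the predicted pair
    sin, cos = sin / norm, cos / norm
    one, zero = torch.ones_like(sin), torch.zeros_like(sin)
    return torch.stack([
        one,  zero, zero,
        zero, cos,  -sin,
        zero, sin,  cos,
    ], dim=-1).reshape(*sin.shape, 3, 3)

# a zero torsion angle, (sin, cos) = (0, 1), is the identity rotation
assert torch.allclose(make_rot_x(torch.tensor([0.0, 1.0])), torch.eye(3), atol=1e-3)
```

Because the rotation axis is always the local $x$-axis, the literature default frame $T^{\text{literature}}_g$ carries all the residue-specific geometry; the torsion only ever spins atoms around the bond that the default frame has aligned with $x$.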

Once the eight group frames are known, the actual atom positions are looked up from residue_constants.py's restype_atom14_rigid_group_positions table — which says, for each amino acid, "atom at index $k$ of atom14 sits at local coordinates $\vec{p}^{\,\text{aa}}_k$ in rigid group $g^{\text{aa}}_k$." Transform $\vec{p}^{\,\text{aa}}_k$ by the group frame and you have the atom's global position.

The torsion head and the atom-coordinate composition together are surprisingly simple given what they accomplish. The whole thing — 8 rigid groups, 7 torsion angles per residue, a lookup table of atom positions — fits in a few hundred lines. And yet it produces full-atom predictions that crystallographers find plausible at angstrom resolution.

What makes this work, in hindsight, is the separation of concerns. The Structure Module doesn’t have to learn bond lengths or bond angles; chemistry handles those. It only has to learn (a) where to place each residue’s backbone frame and (b) which torsion angles to emit. The residue_constants tables do the rest.

10e. The full structure module loop

With IPA, the backbone update, and the all-atom construction in place, the Structure Module itself is a short loop. Eight iterations, weights shared across iterations (like a deep equilibrium model), with each iteration refining both the backbone frames and the torsion angles.

class StructureModule(torch.nn.Module):
    """Structure Module — iterative IPA + frame update (Algorithm 20).

    Given the Evoformer's ``s_i`` and ``z_ij``:

    1. Normalise inputs and project ``s_i`` to the module's internal
       ``c`` channel dim (line 1-2).
    2. Initialise the per-residue rigid frame ``T_i`` to the identity
       ("black-hole initialisation", line 3).
    3. Run ``config.structure_module_layers`` IPA iterations (default
       8): :class:`InvariantPointAttention` updates ``s_i`` with
       geometric context; a transition MLP mixes channels; a
       :class:`BackboneUpdate` applies a local rotation + translation
       to ``T_i`` (supplement 1.8.3).
    4. After every iteration, :class:`MultiRigidSidechain` predicts
       seven torsion angles from ``s_i`` and :func:`compute_all_atom_coordinates`
       rolls them + ``T_i`` out to atom14 coordinates (Algorithm 24).

    Units: internal nm (literature bond lengths divided by
    ``position_scale``, output translations multiplied back to Å).
    The frame rotation is detached between iterations (line 13,
    "stop_gradient" on rotation only) — translations keep gradients so
    the auxiliary FAPE on Cα at every layer has signal through to the
    Evoformer.
    """

    default_frames: torch.Tensor
    lit_positions: torch.Tensor
    atom_frame_idx_table: torch.Tensor
    atom_mask_table: torch.Tensor

    def __init__(self, config):
        super().__init__()
        self.c = config.structure_module_c
        self.num_layers = config.structure_module_layers
        self.position_scale = float(getattr(config, "position_scale", 10.0))

        # Layer Norms
        self.layer_norm_single_rep_1 = torch.nn.LayerNorm(config.c_s)
        self.layer_norm_single_rep_2 = torch.nn.LayerNorm(config.c_s)
        self.layer_norm_single_rep_3 = torch.nn.LayerNorm(config.c_s)

        self.layer_norm_pair_rep = torch.nn.LayerNorm(config.c_z)

        # Dropouts (rates from config)
        self.dropout_1 = torch.nn.Dropout(p=config.structure_module_dropout_ipa)
        self.dropout_2 = torch.nn.Dropout(p=config.structure_module_dropout_transition)

        # Register residue constant tensors as buffers (avoid device bugs, improve speed)
        self.register_buffer('default_frames',
            torch.tensor(restype_rigid_group_default_frame))   # (21, 8, 4, 4)
        self.register_buffer('lit_positions',
            torch.tensor(restype_atom14_rigid_group_positions)) # (21, 14, 3)
        self.register_buffer('atom_frame_idx_table',
            torch.tensor(restype_atom14_to_rigid_group))        # (21, 14)
        self.register_buffer('atom_mask_table',
            torch.tensor(restype_atom14_mask))                  # (21, 14)

        # Linear layers
        self.single_rep_proj = torch.nn.Linear(in_features=config.c_s, out_features=config.c_s)
        self.transition_linear_1 = torch.nn.Linear(in_features=config.c_s, out_features=config.c_s)
        self.transition_linear_2 = torch.nn.Linear(in_features=config.c_s, out_features=config.c_s)
        self.transition_linear_3 = torch.nn.Linear(in_features=config.c_s, out_features=config.c_s)
        _init_linear(self.single_rep_proj, init="default")
        _init_linear(self.transition_linear_1, init="relu")
        _init_linear(self.transition_linear_2, init="relu")
        _init_linear(self.transition_linear_3, init="final")

        # Core blocks
        self.IPA = InvariantPointAttention(config)
        self.backbone_update = BackboneUpdate(config)
        self.sidechain_module = MultiRigidSidechain(config)

        self.relu = torch.nn.ReLU()

        # AF2 keeps structure-module translations in internal units and rescales
        # them by position_scale when materializing coordinates and losses.
        internal_scale = 1.0 / self.position_scale
        self.default_frames[..., :3, 3] *= internal_scale
        self.lit_positions *= internal_scale

    def forward(self, single_representation: torch.Tensor, pair_representation: torch.Tensor,
                aatype: torch.Tensor, seq_mask: Optional[torch.Tensor] = None,
                detach_rotations: bool = True):
        # seq_mask: (batch, N_res) — 1 for valid residues, 0 for padding
        # detach_rotations: if True (default, AF2 standard), apply stopgrad to
        #   the rotation component of T_i between iterations (Algorithm 20
        #   lines 19-21). The detach is placed at the *end* of each non-final
        #   iteration so that the next iteration's IPA sees T_i with no
        #   rotation-gradient path, preventing lever effects through the
        #   chained composition of frames. Set to False to allow full gradient
        #   flow (useful for memorization/debugging).
        assert single_representation.ndim == 3, \
            f"single_representation must be (batch, N_res, c_s), got {single_representation.shape}"
        assert pair_representation.ndim == 4, \
            f"pair_representation must be (batch, N_res, N_res, c_z), got {pair_representation.shape}"
        assert aatype.ndim == 2, \
            f"aatype must be (batch, N_res), got {aatype.shape}"
        assert single_representation.shape[1] == pair_representation.shape[1] == pair_representation.shape[2] == aatype.shape[1], \
            f"N_res mismatch: single={single_representation.shape[1]}, pair={pair_representation.shape[1:3]}, aatype={aatype.shape[1]}"

        single_representation = self.layer_norm_single_rep_1(single_representation)
        initial_single_representation = single_representation

        pair_representation = self.layer_norm_pair_rep(pair_representation)

        s = self.single_rep_proj(single_representation)

        rotations = torch.eye(3, device=s.device, dtype=s.dtype).view(1, 1, 3, 3).expand(s.shape[0], s.shape[1], 3, 3)

        translations = torch.zeros(s.shape[0], s.shape[1], 3, device=s.device, dtype=s.dtype)

        # Collect intermediates for auxiliary losses
        all_rotations = []
        all_translations = []
        all_torsion_angles = []
        all_torsion_angles_unnormalized = []
        sidechain_outputs = None

        for l in range(self.num_layers):
            # Algorithm 20, line 6-7: s += IPA(s); s = LN(Dropout(s))
            s = s + self.IPA(s, pair_representation, rotations, translations, seq_mask)
            s = self.layer_norm_single_rep_3(self.dropout_1(s))

            # Algorithm 20, line 8-9: s += Transition(s); s = LN(Dropout(s))
            s = s + self.transition_linear_3(self.relu(self.transition_linear_2(self.relu(self.transition_linear_1(s)))))
            s = self.layer_norm_single_rep_2(self.dropout_2(s))

            # Algorithm 20, line 10: T_i ← T_i ∘ BackboneUpdate(s_i)
            new_rotations, new_translations = self.backbone_update(s)
            translations = torch.einsum('bsij, bsj -> bsi', rotations, new_translations) + translations
            rotations = torch.einsum('bsij, bsjk -> bsik', rotations, new_rotations)

            # Algorithm 20, lines 11-18: side-chain torsions and auxiliary losses use
            # the updated (un-detached) T_i so gradients reach this iteration's s.
            sidechain_outputs = self.sidechain_module(
                s,
                initial_single_representation,
                rotations,
                translations,
                aatype,
                self.default_frames,
                self.lit_positions,
                self.atom_frame_idx_table,
                self.atom_mask_table,
            )

            all_rotations.append(rotations)
            all_translations.append(translations)
            all_torsion_angles.append(sidechain_outputs["angles_sin_cos"])
            all_torsion_angles_unnormalized.append(sidechain_outputs["unnormalized_angles_sin_cos"])

            # Algorithm 20, lines 19-21: stopgrad on rotations between iterations
            # (but not after the final one). Detaching *after* the backbone update
            # means iteration l+1 receives the rotation without a gradient path,
            # exactly as the supplement prescribes.
            if detach_rotations and l < self.num_layers - 1:
                rotations = rotations.detach()

        all_rotations = torch.stack(all_rotations)
        all_translations = torch.stack(all_translations)
        all_torsion_angles = torch.stack(all_torsion_angles)
        all_torsion_angles_unnormalized = torch.stack(all_torsion_angles_unnormalized)
        if sidechain_outputs is None:
            raise RuntimeError("StructureModule produced no sidechain outputs.")

        all_frames_R = sidechain_outputs["frames_R"]
        all_frames_t = sidechain_outputs["frames_t"]
        atom_coords = sidechain_outputs["atom_pos"]
        mask = sidechain_outputs["atom_mask"]

        # Convert internal structure-module units back to angstroms.
        # Rotations are unitless — no conversion needed
        predictions = {
            # Per-layer backbone frames for auxiliary FAPE loss
            "traj_rotations": all_rotations,               # (num_layers, batch, N_res, 3, 3)
            "traj_translations": all_translations * self.position_scale,

            # Per-layer torsion angles for torsion angle loss
            "traj_torsion_angles": all_torsion_angles,     # (num_layers, batch, N_res, 7, 2)
            "traj_torsion_angles_unnormalized": all_torsion_angles_unnormalized,

            # Final backbone frames
            "final_rotations": rotations,                  # (batch, N_res, 3, 3)
            "final_translations": translations * self.position_scale,

            # Final all-atom outputs (8 rigid-group frames including backbone frame 0)
            "all_frames_R": all_frames_R,                  # (batch, N_res, 8, 3, 3)
            "all_frames_t": all_frames_t * self.position_scale,
            "atom14_coords": atom_coords * self.position_scale,
            "atom14_mask": mask,                           # (batch, N_res, 14)

            # Final single representation (for distogram, pLDDT, etc.)
            "single": s,                                   # (batch, N_res, c_s)
        }

        return predictions

Two details to notice. First, every iteration applies IPA, transition, and backbone update, and all eight iterations see the same pair representation: $z$ is computed once at the top of the Structure Module (post-Evoformer) and held fixed for the whole structure loop. The thing that changes from iteration to iteration is the frames, which feed back into IPA’s point-distance term and shift the attention pattern as the structure settles. Second, every iteration produces a full prediction — backbone frames and torsion-derived coordinates — and per-iteration auxiliary FAPE losses are applied during training, so the model is pressured to make every iteration a reasonable structure, not just the last one.

This is different from vanilla encoders, which reserve the loss for the final layer. Applying FAPE at every iteration has two effects: it prevents the early iterations from becoming “scratch space” that’s only useful for the final iteration, and it lets the model work as a progressive refiner at inference time, producing increasingly accurate structures at every step. Combined with the weight sharing, you can think of the Structure Module as an unrolled fixed-point iteration — and like most such iterations, in practice, it converges well before the eighth step for most proteins.

After the loop finishes, the Structure Module emits the final frames $T_i$, the final atom14 coordinates, and the per-iteration predictions (for auxiliary losses). These, plus the final single and pair representations, feed into the auxiliary heads (§11) for distogram and pLDDT predictions, and — critically — the final pseudo-Cβ positions feed back into the recycling loop (§9) for the next cycle.

11. Auxiliary heads

The Structure Module produces the thing you really want — atom coordinates — but AlphaFold2 also emits a handful of auxiliary predictions from small heads attached to the final representations. Most of them exist to provide supervision signal during training, and some also give users useful information at inference.

The heads all live in a single short file:

| Head | Input | Output | Purpose |
| --- | --- | --- | --- |
| DistogramHead | pair rep $z$ | 64-bin distribution over Cβ–Cβ distances | Dense structural supervision; predicts a probability distribution over pairwise distances. |
| PLDDTHead | single rep $s_i$ | 50-bin distribution over per-residue lDDT | The model’s own confidence estimate; what gets output as the pLDDT score in predicted PDBs. |
| MaskedMSAHead | full MSA rep $m$ | 23-way classification per masked token | BERT-style reconstruction signal on randomly masked MSA residues. Training-time only. |
| ExperimentallyResolvedHead | single rep $s_i$ | per-atom binary mask | Predicts which atoms were experimentally resolved in the X-ray crystal. Helps the model down-weight loosely-defined regions in training. |
| TMScoreHead | pair rep $z$ | 64-bin PAE matrix | Predicted Aligned Error — the pairwise expected error if the predicted structure were aligned at residue $i$. This is the output that justifies “this part of the prediction is trustworthy.” |

class DistogramHead(torch.nn.Module):
    """Distogram head (supplement 1.9.8).

    Projects the pair representation ``z_ij`` to ``n_dist_bins`` bin
    logits and symmetrises across (i, j). The loss target is one-hot
    binned Cβ-Cβ (or Cα for glycine) distance between ground-truth
    residue pairs, cross-entropy averaged (eq 41).
    """

    def __init__(self, config):
        super().__init__()
        self.linear = torch.nn.Linear(config.c_z, config.n_dist_bins)
        # Supplement 1.11.4: zero-init residue distance prediction logits
        _zero_init_linear(self.linear)

    def forward(self, pair_representation: torch.Tensor):
        # pair_representation: (batch, N_res, N_res, c_z)
        logits = self.linear(pair_representation)  # (batch, N_res, N_res, n_dist_bins)
        # Distograms of an unordered pair are symmetric by definition; average
        # the (i, j) and (j, i) logits so predictions match that invariance.
        logits = (logits + logits.transpose(1, 2)) / 2
        return logits

Every head is architecturally trivial — a LayerNorm, a linear projection (sometimes two), and a softmax or sigmoid. The interesting work happens in how the outputs are trained against the right targets, which is where the losses in §12 come in. The key architectural thing to notice here is that the distogram and PAE heads both read from $z$, not from the atom coordinates directly. The pair representation has to simultaneously (a) drive the Structure Module’s IPA via pair bias and point scoring, and (b) carry enough geometric information that you can read pairwise distances straight off it. That dual responsibility is part of why the pair stack is so disproportionately expensive in the Evoformer — $z$ is doing two jobs.
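The symmetrisation in `DistogramHead.forward` is worth a one-line sanity check. A standalone sketch (mirroring the head’s `(logits + logitsᵀ) / 2` averaging, nothing here is from the repo itself):

```python
import torch

# Averaging the (i, j) and (j, i) logits makes the output exactly symmetric,
# matching the fact that a distance distribution over an unordered residue
# pair cannot depend on residue order.
logits = torch.randn(2, 5, 5, 64)                  # (batch, N_res, N_res, bins)
sym = (logits + logits.transpose(1, 2)) / 2
assert torch.allclose(sym, sym.transpose(1, 2))    # symmetric in (i, j)
```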

class PLDDTHead(torch.nn.Module):
    """Per-residue confidence head (Algorithm 29 / supplement 1.9.6).

    Takes the post-Structure-Module single representation ``s_i``, normal-
    ises it, passes it through a two-layer ReLU MLP, and projects to
    ``n_plddt_bins`` (default 50). The scalar pLDDT for residue i is
    ``sum_k p_i^k · v_k`` where ``v_k = k + 0.5`` scaled into [1, 99] —
    that bin-centre transform lives in ``pdbio`` (B-factor write path) and
    ``losses.PLDDTLoss`` (LDDT-Cα supervision).
    """

    def __init__(self, config):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.LayerNorm(config.c_s),
            torch.nn.Linear(config.c_s, config.plddt_hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(config.plddt_hidden_dim, config.plddt_hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(config.plddt_hidden_dim, config.n_plddt_bins),
        )
        # Supplement 1.11.4: zero-init model confidence prediction logits
        final_linear: torch.nn.Linear = self.net[-1]  # type: ignore[assignment]
        _zero_init_linear(final_linear)

    def forward(self, single_representation: torch.Tensor):
        # single_representation: (batch, N_res, c_s)
        return self.net(single_representation)  # (batch, N_res, n_plddt_bins)

The pLDDT head is the one casual users interact with most. The coloring you see on predicted structures in the AlphaFold database — blue for “very high confidence,” green/yellow/orange/red descending — is per-residue pLDDT, with the per-bin probabilities collapsed into a single scalar via the expected bin center. Under the hood, it’s a 50-way classifier trained against lDDT-Cα scores computed against the ground-truth structure (lDDT is a local, alignment-free similarity metric that roughly measures the fraction of residue pairs whose predicted distance stays close to the true distance, scored at several fixed tolerance thresholds). In essence, the model is learning to predict its own error.
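The bin-centre transform mentioned in the docstring is short enough to spell out. A standalone sketch (the helper name is mine, assuming the 50-bin layout above with centres $v_k = 2k + 1$ on a 0–100 scale):

```python
import torch

def plddt_scalar(logits: torch.Tensor) -> torch.Tensor:
    # logits: (..., 50). Softmax over bins, then take the expectation of the
    # bin centres v_k = 2k + 1, i.e. 1, 3, ..., 99 — this matches
    # "v_k = k + 0.5 scaled into [1, 99]" since (k + 0.5) * 2 = 2k + 1.
    probs = torch.softmax(logits, dim=-1)
    centers = torch.arange(logits.shape[-1], dtype=logits.dtype) * 2.0 + 1.0
    return (probs * centers).sum(dim=-1)

# All mass in the top bin → pLDDT ≈ 99; uniform logits → exactly 50
# (the mean of 1, 3, ..., 99).
confident = torch.full((50,), -1e4)
confident[-1] = 1e4
```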

12. Losses

The loss is where AlphaFold2 really spends its design budget. Structure prediction has a particular property that makes the loss non-trivial: the output lives in a space (SE(3)-equivariant atom coordinates) where naive $\ell_2$ distance between predicted and true structures depends on the global orientation of both, which is not something either the model or the supervision should care about. If you predict a structure that’s correct but globally rotated 90°, a coordinate MSE loss will punish you enormously for it, even though the prediction is chemically and biologically identical.

AlphaFold2’s solution is Frame-Aligned Point Error — FAPE. It’s the single most important loss in the paper and the centerpiece of this section. After FAPE, we’ll cover the rest quickly: torsion loss, pLDDT loss, distogram loss, structural violations, and the combined AlphaFoldLoss that weights them all together.

12.1 FAPE

The idea of FAPE is simple once you’ve internalized rigid frames from §10 — rather than comparing predicted and true atom positions in a single global coordinate system, FAPE compares them in every residue’s local frame, one at a time, and averages. Concretely, for a predicted frame $T_i$ and a predicted atom position $x_j$, and their ground-truth counterparts $T_i^*$ and $x_j^*$, FAPE’s per-pair error is:

$$d_{ij} = \left\| T_i^{-1} \cdot x_j \;-\; (T_i^*)^{-1} \cdot x_j^* \right\| \tag{28}$$

Read the two terms from inside out. $T_i^{-1} \cdot x_j$ is “the position of atom $x_j$ as seen by residue $i$‘s local frame” — exactly the same “lift into local frame” operation we met in IPA (§10b). $(T_i^*)^{-1} \cdot x_j^*$ is the same thing, but in the ground-truth structure — where residue $i$ should have seen atom $j$ sit. The error $d_{ij}$ is the Euclidean distance between those two local-frame vectors.

This has two crucial properties. First, it’s SE(3)-invariant by construction. Apply any global rigid motion $G$ to the predicted structure; frames transform as $T_i \mapsto G \cdot T_i$ and atoms as $x_j \mapsto G \cdot x_j$; the local-frame position becomes $(G \cdot T_i)^{-1} \cdot (G \cdot x_j) = T_i^{-1} \cdot G^{-1} \cdot G \cdot x_j = T_i^{-1} \cdot x_j$. Same answer. The loss does not penalize global rotations or translations — exactly as we wanted. Second, it’s per-residue: the error compares residue $i$‘s view of the predicted structure with residue $i$‘s view of the ground truth, not a blob of globally-aggregated error. Every residue gets its own say about how the structure looks locally, and the loss is an average over all these views.
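That invariance is easy to verify numerically. A quick standalone sketch (helper names are mine; it uses $T^{-1} \cdot x = R^\top (x - t)$ for a rigid transform $T = (R, t)$):

```python
import torch

def local_frame_error(R, t, x, R_true, t_true, x_true):
    # d_ij = || T_i^{-1} · x_j − (T_i*)^{-1} · x_j* ||  (eq 28), with
    # T^{-1} · x = Rᵀ (x − t) for a rigid transform T = (R, t).
    lp = R.transpose(-1, -2) @ (x - t)
    lt = R_true.transpose(-1, -2) @ (x_true - t_true)
    return torch.linalg.norm(lp - lt)

def random_rotation():
    q, _ = torch.linalg.qr(torch.randn(3, 3))
    return q if torch.det(q) > 0 else -q   # force det = +1 (proper rotation)

R, t, x = random_rotation(), torch.randn(3), torch.randn(3)
Rs, ts, xs = random_rotation(), torch.randn(3), torch.randn(3)
d = local_frame_error(R, t, x, Rs, ts, xs)

# Apply a global rigid motion G = (Rg, tg) to the *predicted* frame and atom:
Rg, tg = random_rotation(), torch.randn(3)
d_moved = local_frame_error(Rg @ R, Rg @ t + tg, Rg @ x + tg, Rs, ts, xs)
assert torch.allclose(d, d_moved, atol=1e-5)  # invariant under G
```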

def frame_aligned_point_error(
    predicted_rotations: torch.Tensor,
    predicted_translations: torch.Tensor,
    true_rotations: torch.Tensor,
    true_translations: torch.Tensor,
    predicted_positions: torch.Tensor,
    true_positions: torch.Tensor,
    frames_mask: torch.Tensor,
    positions_mask: torch.Tensor,
    *,
    length_scale: float,
    pair_mask: Optional[torch.Tensor] = None,
    l1_clamp_distance: Optional[float] = None,
    eps: float,
):
    ...

Two implementation details matter here.

The distance clamp. The raw error can be arbitrarily large — if the model gets a structure very wrong, some local-frame distances can blow up to hundreds of angstroms. FAPE clamps at d_clamp, default 10 Å. This is more than cosmetic: without clamping, the gradient from an outlier atom swamps everything else and training goes unstable. With clamping, the loss is bounded above by d_clamp, and the model is pressured to bring every atom within 10 Å before it’s rewarded for bringing the already-close atoms even closer. The paper’s ablations show the clamp is load-bearing.

The normalization. FAPE divides the sum of clamped distances by a 10 Å length scale (the code’s `Z`, numerically equal to `d_clamp`) times `mask.sum()`, so the loss is dimensionless and bounded in [0, 1]. This is a small thing, but it matters for how weights in the combined loss are chosen — every term stays on the same scale.
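Putting the clamp and the normalization together, the core of the loss is only a few lines. A deliberately simplified sketch (my own, not the repo’s `frame_aligned_point_error` — it skips the frame-lifting step and assumes positions are already expressed in local frames):

```python
import torch

def fape_core_sketch(local_pred, local_true, mask, d_clamp=10.0, eps=1e-4):
    # local_pred / local_true: (N_frames, N_atoms, 3) — atom positions already
    # lifted into each residue's local frame. mask: (N_frames, N_atoms).
    d = torch.sqrt(((local_pred - local_true) ** 2).sum(-1) + eps)
    d = d.clamp(max=d_clamp)  # bound every (frame, atom) pair's error at 10 Å
    # dimensionless and in [0, 1]: divide by the 10 Å length scale and the
    # number of valid (frame, atom) pairs
    return (d * mask).sum() / (d_clamp * mask.sum().clamp(min=1.0))

a = torch.randn(3, 5, 3)
m = torch.ones(3, 5)
# perfect prediction → near zero; arbitrarily bad prediction → at most 1.0
```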

12.2 Per-iteration backbone FAPE

The Structure Module’s 8 iterations each produce a set of backbone frames (§10e). During training, FAPE is evaluated on every iteration’s output — not just the last one — and the per-iteration losses are averaged into a single “backbone trajectory” loss.

class BackboneTrajectoryLoss(torch.nn.Module):
    """Per-iteration backbone FAPE averaged over layers (L_aux^{FAPE}).

    Algorithm 20 emits backbone frames at every iteration l ∈ [1, N_layer];
    line 17 computes a Cα-only FAPE against ground truth on each iteration
    and line 23 averages them to yield the FAPE component of L_aux.

    `use_clamped_fape ∈ [0, 1]` (supplement 1.11.5): weight of the clamped
    FAPE in a soft mix with the unclamped version. `None` ≡ 1.0 (fully
    clamped). AlphaFold samples 10% of mini-batches to be fully unclamped,
    so the expected loss per batch has ≈0.9 weight on the clamped form;
    passing `use_clamped_fape=0.9` reproduces that expectation directly.
    """

    def __init__(self):
        super().__init__()
        self.fape_loss = BackboneFAPE()

    def forward(
            self,
            structure_model_prediction: dict,
            true_rotations: torch.Tensor,          # (b, N_res, 3, 3)
            true_translations: torch.Tensor,        # (b, N_res, 3)
            backbone_mask: Optional[torch.Tensor] = None,
            seq_mask: Optional[torch.Tensor] = None,
            use_clamped_fape: Optional[torch.Tensor] = None,
        ):
        traj_R = structure_model_prediction["traj_rotations"]          # (L, b, N_res, 3, 3)
        traj_t = structure_model_prediction["traj_translations"]       # (L, b, N_res, 3)

        num_layers = traj_R.shape[0]
        total_loss = torch.zeros(traj_R.shape[1], device=traj_R.device, dtype=traj_R.dtype)
        if backbone_mask is None:
            valid_mask = seq_mask
        elif seq_mask is None:
            valid_mask = backbone_mask
        else:
            valid_mask = backbone_mask * seq_mask

        for l in range(num_layers):
            clamped_fape = self.fape_loss(
                traj_R[l],
                traj_t[l],
                true_rotations,
                true_translations,
                frame_mask=valid_mask,
                position_mask=valid_mask,
                l1_clamp_distance=self.fape_loss.d_clamp_val,
            )
            if use_clamped_fape is None:
                total_loss = total_loss + clamped_fape
            else:
                unclamped_fape = self.fape_loss(
                    traj_R[l],
                    traj_t[l],
                    true_rotations,
                    true_translations,
                    frame_mask=valid_mask,
                    position_mask=valid_mask,
                    l1_clamp_distance=None,
                )
                total_loss = total_loss + (
                    clamped_fape * use_clamped_fape + unclamped_fape * (1.0 - use_clamped_fape)
                )

        # Algorithm 20 line 23: L_aux = mean_l(L_aux^l). This returns just the
        # FAPE component; the torsion component is added separately by the
        # caller as TorsionAngleLoss.
        return total_loss / num_layers

This is a meaningful design choice. It pressures every iteration to produce a structurally reasonable prediction, not just the final one. Without this, the early iterations would be free to emit arbitrary “scratchpad” frames that the later iterations clean up — and you’d lose the interpretability of reading structures off each iteration at inference time. With per-iteration FAPE, each iteration is trained to be an incremental refinement.

12.3 All-atom FAPE

The final piece is all-atom FAPE. Once the Structure Module has emitted its last set of backbone frames and the all-atom reconstruction (§10d) has placed every side-chain atom, FAPE is evaluated on the full atom14 tensor — every backbone and side-chain atom of every residue, as seen by every residue’s backbone frame.

class AllAtomFAPE(torch.nn.Module):
    """Final all-atom FAPE (Algorithm 20 line 28).

    Scores the 8 per-residue rigid-group frames (3 backbone + 4 side-chain
    torsion frames + ψ frame, see Table 2) against all 14 atom positions of
    every residue after the symmetric-ground-truth renaming of Algorithm 26.

    Per supplement 1.11.5, the side-chain FAPE is *always* clamped by
    d_clamp = 10 Å (unlike the backbone FAPE, which is unclamped in 10% of
    mini-batches). Uses ε = 10⁻⁴ per Algorithm 20 line 28.
    """

    def __init__(self, d_clamp: float = 10.0, eps: float = 1e-4, Z: float = 10.0):
        super().__init__()
        self.eps = eps
        self.d_clamp_val = d_clamp
        self.Z = Z

    def forward(self,
                predicted_frames_R,       # (b, N_res, 8, 3, 3)
                predicted_frames_t,       # (b, N_res, 8, 3)
                predicted_atom_positions, # (b, N_res, 14, 3)
                atom_mask,                # (b, N_res, 14) — predicted atom existence
                true_frames_R,            # (b, N_res, 8, 3, 3)
                true_frames_t,            # (b, N_res, 8, 3)
                true_atom_positions,      # (b, N_res, 14, 3)
                true_atom_mask: Optional[torch.Tensor] = None,  # (b, N_res, 14)
                seq_mask: Optional[torch.Tensor] = None,       # (b, N_res)
                frame_mask: Optional[torch.Tensor] = None,     # (b, N_res, 8)
    ):
        b, N_res, n_frames = predicted_frames_R.shape[:3]
        n_atoms = predicted_atom_positions.shape[2]

        # Flatten per-residue rigid groups and atoms into a single list of
        # frames / atoms so FAPE scores every (frame, atom) pair, as required
        # by Algorithm 20 line 28's `mean_{i,j}` over the full structure.
        pred_R = predicted_frames_R.reshape(b, N_res * n_frames, 3, 3)
        pred_t = predicted_frames_t.reshape(b, N_res * n_frames, 3)
        true_R = true_frames_R.reshape(b, N_res * n_frames, 3, 3)
        true_t = true_frames_t.reshape(b, N_res * n_frames, 3)
        pred_pos = predicted_atom_positions.reshape(b, N_res * n_atoms, 3)
        true_pos = true_atom_positions.reshape(b, N_res * n_atoms, 3)

        flat_atom_mask = atom_mask.reshape(b, N_res * n_atoms)
        if true_atom_mask is not None:
            flat_atom_mask = flat_atom_mask * true_atom_mask.reshape(b, N_res * n_atoms)

        group_mask = (
            frame_mask.to(predicted_frames_R.dtype)
            if frame_mask is not None
            else predicted_frames_R.new_ones(b, N_res, n_frames)
        )

        # Fold the residue-level seq_mask into both the per-atom mask and
        # per-frame mask so padded residues do not contribute.
        if seq_mask is not None:
            seq_atom_mask = seq_mask[:, :, None].expand(-1, -1, n_atoms).reshape(b, N_res * n_atoms)
            flat_atom_mask = flat_atom_mask * seq_atom_mask
            flat_frame_mask = (seq_mask[:, :, None] * group_mask).reshape(b, N_res * n_frames)
        else:
            flat_frame_mask = group_mask.reshape(b, N_res * n_frames)

        return frame_aligned_point_error(
            predicted_rotations=pred_R,
            predicted_translations=pred_t,
            true_rotations=true_R,
            true_translations=true_t,
            predicted_positions=pred_pos,
            true_positions=true_pos,
            frames_mask=flat_frame_mask,
            positions_mask=flat_atom_mask,
            length_scale=self.Z,
            # Supplement 1.11.5: side-chain FAPE is always clamped.
            l1_clamp_distance=self.d_clamp_val,
            eps=self.eps,
        )

One subtlety to look for: the reference (ground-truth) frames for this loss come from the true backbone atoms, not the predicted ones. And since some residues have symmetric side chains (tyrosine’s ring, for example), the ground truth is “renamed” — Algorithm 26 of the supplement — to match whichever of two chemically-equivalent permutations the model happened to predict. That permutation search is handled upstream of this loss, so by the time FAPE runs, both predicted and true atoms are in a canonical order.
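Algorithm 26 itself isn’t reproduced in this section, but its effect is easy to sketch. A deliberately simplified stand-in (per-residue squared error rather than the supplement’s inter-residue distance criterion; the function name is mine):

```python
import torch

def rename_symmetric_ground_truth(pred, true, alt_true, is_ambiguous):
    # pred / true / alt_true: (N_res, 14, 3); is_ambiguous: (N_res,) bool.
    # Simplified stand-in for Algorithm 26: for each residue with a
    # chemically symmetric side chain (TYR ring, ASP/GLU carboxylates, ...),
    # keep whichever of the two equivalent atom orderings sits closer to the
    # prediction. (The real algorithm scores distances to neighbouring
    # residues' atoms; per-residue error is a pedagogical shortcut.)
    err_true = ((pred - true) ** 2).sum(dim=(1, 2))
    err_alt = ((pred - alt_true) ** 2).sum(dim=(1, 2))
    use_alt = is_ambiguous & (err_alt < err_true)
    return torch.where(use_alt[:, None, None], alt_true, true)

pred = torch.zeros(2, 14, 3)
true = torch.ones(2, 14, 3)
alt = torch.zeros(2, 14, 3)
amb = torch.tensor([True, False])
out = rename_symmetric_ground_truth(pred, true, alt, amb)
# residue 0 is ambiguous and alt matches better → renamed; residue 1 keeps true
```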

12.4 Torsion angle loss

Side-chain torsions ($\chi_1$ through $\chi_4$) and backbone torsions ($\phi$, $\psi$, $\omega$) need their own supervision. Torsion angles live on the circle $S^1$, so the naive thing — L2 on the angle itself — has a discontinuity at $\pm\pi$ that kills gradients whenever the true angle is near the wrap-around. AlphaFold2 sidesteps this by predicting $(\sin\alpha, \cos\alpha)$ pairs and supervising those:

class TorsionAngleLoss(torch.nn.Module):
    """Side-chain and backbone torsion angle loss (Algorithm 27).

    Scores all 7 torsions f ∈ {ω, φ, ψ, χ1, χ2, χ3, χ4} per the supplement.
    Validity of each torsion is carried by `torsion_mask_true` (shape
    [..., 7]): ω/φ are undefined for the first residue, χ1..χ4 existence
    depends on residue type, and any torsion whose atoms are missing in the
    ground truth is masked out. See `geometry.torsion_angles` for how the
    mask is built from atom14 data.

    The min-of-(true, alt_true) term in line 3 handles 180°-rotation symmetry
    for ASP/GLU/PHE/TYR; for all other torsions alt_true == true, so the
    minimum reduces to the plain L2 term.

    Both `torsion_weight` and `angle_norm_weight` pre-apply the outer 0.5
    factor that equation (7) multiplies L_aux by: they are (1.0, 0.02) from
    Algorithm 27 times 0.5, i.e. (0.5, 0.01). Keeping the pre-multiplication
    here lets `AlphaFoldLoss` add the returned value directly.
    """

    def __init__(self):
        super().__init__()
        self.torsion_weight = 0.5
        self.angle_norm_weight = 0.01

    def forward(
        self,
        torsion_angles: torch.Tensor,
        unnormalized_torsion_angles: torch.Tensor,
        torsion_angles_true: torch.Tensor,
        torsion_angles_true_alt: torch.Tensor,
        torsion_mask_true: torch.Tensor,
        seq_mask: Optional[torch.Tensor] = None,
    ):
        # Prepend the per-layer trajectory dim if only the final iteration was
        # passed (Algorithm 20 averages L_aux over layers; the normalization
        # below sums across L so a single layer broadcasts as L=1).
        if torsion_angles.ndim == 4:
            torsion_angles = torsion_angles.unsqueeze(0)
            unnormalized_torsion_angles = unnormalized_torsion_angles.unsqueeze(0)

        true_angles = torsion_angles_true.unsqueeze(0)
        true_alt = torsion_angles_true_alt.unsqueeze(0)
        mask = torsion_mask_true.unsqueeze(0)  # (1, b, N_res, 7)
        if seq_mask is not None:
            mask = mask * seq_mask.unsqueeze(0).unsqueeze(-1)

        # Algorithm 27 line 3: L_torsion = mean_{i,f} min(||α̂ - α_true||², ||α̂ - α_alt||²).
        true_dist_sq = torch.sum((true_angles - torsion_angles) ** 2, dim=-1)
        alt_dist_sq = torch.sum((true_alt - torsion_angles) ** 2, dim=-1)
        torsion_dist_sq = torch.minimum(true_dist_sq, alt_dist_sq)

        torsion_normalizer = mask.sum(dim=(0, 2, 3)).clamp(min=1.0)
        torsion_loss = torch.sum(torsion_dist_sq * mask, dim=(0, 2, 3)) / torsion_normalizer

        # Algorithm 27 line 4: L_anglenorm = mean_{i,f} |||α_tilde_i^f|| - 1|.
        # Covers all 7 torsions regardless of mask — this regulariser keeps the
        # raw network output close to unit norm before the α̂ = α̃ / ||α̃||
        # normalization, independent of whether the torsion is supervised.
        angle_norm = torch.sqrt(torch.sum(unnormalized_torsion_angles ** 2, dim=-1) + 1e-8)
        if seq_mask is not None:
            angle_norm_mask = seq_mask.unsqueeze(0).unsqueeze(-1).expand_as(angle_norm)
        else:
            angle_norm_mask = torch.ones_like(angle_norm)
        norm_normalizer = angle_norm_mask.sum(dim=(0, 2, 3)).clamp(min=1.0)
        angle_norm_loss = torch.sum(torch.abs(angle_norm - 1.0) * angle_norm_mask, dim=(0, 2, 3)) / norm_normalizer

        return self.torsion_weight * torsion_loss + self.angle_norm_weight * angle_norm_loss

The loss has two terms: $\ell_2$ between the predicted $(\sin, \cos)$ pair and the ground-truth $(\sin, \cos)$, plus a penalty that encourages the predicted pair to lie on the unit circle (i.e., $\sqrt{\sin^2\alpha + \cos^2\alpha} = 1$). The unit-circle penalty matters because the model’s raw output is a 2D vector that’s normalized before being interpreted as an angle — regularizing the raw output to already be approximately unit-norm keeps training stable.
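To see why the $(\sin, \cos)$ embedding fixes the wrap-around, here’s a tiny numeric check (mine, not from the repo):

```python
import math

def sincos_l2(a: float, b: float) -> float:
    # Squared L2 between the (sin, cos) embeddings of two angles —
    # continuous across the ±π seam, unlike a raw angle difference.
    return (math.sin(a) - math.sin(b)) ** 2 + (math.cos(a) - math.cos(b)) ** 2

# Two nearly identical directions that sit on opposite sides of ±π:
near = sincos_l2(math.pi - 0.01, -math.pi + 0.01)    # tiny
naive = ((math.pi - 0.01) - (-math.pi + 0.01)) ** 2  # huge, ≈ (2π)²
assert near < 1e-3 < naive
```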

12.5 pLDDT, distogram, experimentally-resolved, violations

Four more losses round out the full objective. They’re mostly classification losses on the auxiliary heads from §11, plus a physical-plausibility loss.

class PLDDTLoss(torch.nn.Module):
    """Model-confidence loss L_conf (supplement 1.9.6, Algorithm 29 line 4).

    Cross-entropy between the predicted pLDDT distribution (from Algorithm 29
    lines 1-2) and the one-hot discretisation of the per-residue true lDDT-Cα
    score into 50 bins of width 2. The true lDDT-Cα is computed in
    `AlphaFoldLoss.compute_loss_terms`, this module just performs the CE.

    `filter_by_resolution` zeros the loss on examples whose crystal-structure
    resolution falls outside [0.1 Å, 3.0 Å] per supplement 1.9.6.
    """

    def __init__(
        self,
        *,
        filter_by_resolution: bool = False,
        min_resolution: float = 0.1,
        max_resolution: float = 3.0,
    ):
        super().__init__()
        self.filter_by_resolution = filter_by_resolution
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution

    def forward(
        self,
        pred_plddt: torch.Tensor,
        true_plddt: torch.Tensor,
        seq_mask: Optional[torch.Tensor] = None,
        resolution: Optional[torch.Tensor] = None,
    ):
        # pred_plddt, true_plddt: (batch, N_res, n_plddt_bins), latter one-hot.
        log_pred = torch.log_softmax(pred_plddt, dim=-1)

        # Algorithm 29 line 4: L_conf = mean_i(p_i^{true lDDT} · log p_i^{pLDDT}).
        # (The published formula omits the minus sign; cross-entropy is a
        # negative log-likelihood, hence the sign here.)
        conf_loss = -torch.einsum('bic, bic -> bi', true_plddt, log_pred)

        if seq_mask is not None:
            conf_loss = conf_loss * seq_mask
            conf_loss = conf_loss.sum(dim=-1) / seq_mask.sum(dim=-1).clamp(min=1)
        else:
            conf_loss = torch.mean(conf_loss, dim=-1)

        if self.filter_by_resolution and resolution is not None:
            resolution = resolution.to(conf_loss.device, dtype=conf_loss.dtype).reshape(-1)
            if resolution.numel() == 1 and conf_loss.numel() != 1:
                resolution = resolution.expand_as(conf_loss)
            else:
                resolution = resolution.reshape(conf_loss.shape)
            in_range = (
                (resolution >= self.min_resolution)
                & (resolution <= self.max_resolution)
            ).to(conf_loss.dtype)
            conf_loss = conf_loss * in_range

        return conf_loss

pLDDT is cross-entropy on 50 lDDT-Cα bins. lDDT itself is computed by scanning through all Cα pairs within a distance cutoff and checking how close the predicted distance is to the true distance — a local, alignment-free fidelity measure. The model’s pLDDT prediction head is trained to classify which of 50 lDDT bins the per-residue lDDT will fall into. This is a classic “learn to predict your own error” setup, and it gives users the confidence scores they eventually see on predicted structures.
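A rough sketch of the lDDT-Cα target itself may help (my own simplification: the real metric defines the 15 Å inclusion radius on the true structure and uses tolerance thresholds {0.5, 1, 2, 4} Å, but also excludes sequence-local pairs and has further details I skip here):

```python
import torch

def lddt_ca_sketch(pred_ca, true_ca, cutoff=15.0, tol=(0.5, 1.0, 2.0, 4.0)):
    # pred_ca / true_ca: (N_res, 3) Cα coordinates. For residue pairs whose
    # *true* distance is under `cutoff`, score the fraction preserved within
    # each tolerance, averaged over the four thresholds.
    d_true = torch.cdist(true_ca, true_ca)
    d_pred = torch.cdist(pred_ca, pred_ca)
    n = pred_ca.shape[0]
    pair_mask = (d_true < cutoff) & ~torch.eye(n, dtype=torch.bool)
    delta = (d_true - d_pred).abs()
    score = torch.stack([(delta < t) & pair_mask for t in tol]).float().mean(0)
    per_res = score.sum(1) / pair_mask.float().sum(1).clamp(min=1.0)
    return per_res  # (N_res,) in [0, 1]

# A perfect prediction scores 1.0 everywhere:
ca = torch.tensor([[float(i) * 3.0, 0.0, 0.0] for i in range(5)])
```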

class DistogramLoss(torch.nn.Module):
    """Distogram cross-entropy L_dist (supplement 1.9.8 equation 41).

        L_dist = -1/N_res^2 Σ_{i,j} Σ_b y_{ij}^b log p_{ij}^b

    where p_{ij}^b comes from the distogram head applied to the symmetrised
    pair representation and y_{ij}^b is the one-hot encoding of the true
    Cβ-Cβ distance (Cα for glycine) into 64 equal-width bins covering 2-22 Å,
    with the final bin catching anything more distant. Target construction
    lives in `AlphaFoldLoss.compute_loss_terms`.

    When `pair_mask` masks padded residues, the denominator is the number of
    *valid* pairs rather than `N_res^2` so variable-length batches behave
    sensibly; on a full un-padded crop the two agree.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred_distograms: torch.Tensor, true_distograms: torch.Tensor,
                pair_mask: Optional[torch.Tensor] = None):
        # input shapes: (batch, N_res, N_res, num_dist_buckets)
        log_pred = torch.log_softmax(pred_distograms, dim=-1)
        vals = torch.einsum('bijc, bijc -> bij', true_distograms, log_pred)
        if pair_mask is not None:
            vals = vals * pair_mask
            dist_loss = -vals.sum(dim=(1, 2)) / pair_mask.sum(dim=(1, 2)).clamp(min=1)
        else:
            dist_loss = -torch.mean(vals, dim=(1, 2))
        return dist_loss

Distogram is cross-entropy on 64 bins of Cβ–Cβ distance (or pseudo-Cβ for glycines). It reads from $z$, not from the predicted structure, which means the pair representation has to learn to encode pairwise distances before the structure module runs. In practice, this makes the distogram head a rich auxiliary signal that helps the pair stack learn contact-map-like features early in training.
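Target construction is just a binning step. A standalone sketch (my own reading of “64 equal-width bins covering 2–22 Å, with the final bin catching anything more distant”: 63 bins tiling the range plus one catch-all; the helper name is mine):

```python
import torch

def distogram_targets(d_cb: torch.Tensor, n_bins: int = 64,
                      d_min: float = 2.0, d_max: float = 22.0) -> torch.Tensor:
    # d_cb: (N_res, N_res) true Cβ-Cβ distances in Å. Bins 0..62 tile
    # [2 Å, 22 Å] with equal width; bin 63 catches everything farther out.
    width = (d_max - d_min) / (n_bins - 1)
    edges = d_min + width * torch.arange(1, n_bins)  # 63 interior edges
    idx = torch.bucketize(d_cb, edges)               # (N_res, N_res), 0..63
    return torch.nn.functional.one_hot(idx, n_bins).float()

d = torch.tensor([[0.0, 30.0], [5.0, 0.0]])
t = distogram_targets(d)
# 30 Å lands in the final catch-all bin; very short distances in bin 0
```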

class ExperimentallyResolvedLoss(torch.nn.Module):
    """Experimentally-resolved loss L_exp_resolved (supplement 1.9.10 eq 43).

        L_exp_resolved = mean_{(i,a)} (
            -y_i^a log p_i^{exp resolved,a}
            -(1 - y_i^a) log(1 - p_i^{exp resolved,a}))

    Binary cross-entropy per (residue, atom37) slot predicting whether that
    atom was resolved in the crystal structure. Only used during fine-tuning
    (eq 7 fine-tuning row) and only on examples with resolution in
    [0.1 Å, 3.0 Å] (1.9.10).

    The `atom37_exists` mask restricts the mean to atom slots that are
    defined for the residue type — slots that do not exist (e.g. χ atoms on
    ALA) would otherwise inject meaningless BCE terms.
    """

    def __init__(
        self,
        *,
        filter_by_resolution: bool = False,
        min_resolution: float = 0.1,
        max_resolution: float = 3.0,
    ):
        super().__init__()
        self.filter_by_resolution = filter_by_resolution
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution

    def forward(
        self,
        exp_resolved_preds: torch.Tensor,
        exp_resolved_true: torch.Tensor,
        atom37_exists: torch.Tensor,
        resolution: Optional[torch.Tensor] = None,
    ):
        xent = torch.nn.functional.binary_cross_entropy_with_logits(
            exp_resolved_preds,
            exp_resolved_true,
            reduction="none",
        )
        weighted_xent = xent * atom37_exists
        normalizer = atom37_exists.sum(dim=(1, 2)).clamp(min=1.0)
        loss = weighted_xent.sum(dim=(1, 2)) / normalizer

        if self.filter_by_resolution and resolution is not None:
            resolution = resolution.to(loss.device, dtype=loss.dtype).reshape(-1)
            if resolution.numel() == 1 and loss.numel() != 1:
                resolution = resolution.expand_as(loss)
            else:
                resolution = resolution.reshape(loss.shape)
            in_range = (
                (resolution >= self.min_resolution)
                & (resolution <= self.max_resolution)
            ).to(loss.dtype)
            loss = loss * in_range

        return loss

Experimentally-resolved is binary cross-entropy per atom — for each of the 37 atom37 slots per residue, predict whether that atom was resolved in the experimental structure. Loops, terminal regions, and some side chains are often missing from crystal structures, and the model needs to know that “not seeing an atom” in the ground truth isn’t the same as “the atom doesn’t exist.” This loss lets the model explicitly reason about what it should and shouldn’t be confident about.
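The reduction is the interesting part: per-slot BCE, zeroed where the slot doesn’t exist for the residue type, then averaged over existing slots only. A minimal pure-Python sketch of that masked mean (illustrative values, four slots standing in for the 37):

```python
import math

def masked_bce_mean(logits, targets, exists):
    """Masked-mean BCE over atom slots, mirroring the reduction in
    ExperimentallyResolvedLoss.forward: slots with exists=0 contribute
    nothing to either the numerator or the denominator."""
    total, n = 0.0, 0.0
    for z, y, m in zip(logits, targets, exists):
        p = 1.0 / (1.0 + math.exp(-z))                        # sigmoid(logit)
        bce = -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))
        total += bce * m
        n += m
    return total / max(n, 1.0)                                # clamp(min=1.0)

logits  = [4.0, -3.0, 0.0, 2.0]
targets = [1.0,  0.0, 1.0, 0.0]
exists  = [1.0,  1.0, 1.0, 0.0]   # last slot undefined for this residue type
loss = masked_bce_mean(logits, targets, exists)

# The masked slot is inert: flipping its target cannot change the loss.
assert loss == masked_bce_mean(logits, [1.0, 0.0, 1.0, 1.0], exists)
```

This is exactly why the `atom37_exists` mask matters: without it, the undefined slot would inject an arbitrary BCE term into the mean.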

class StructuralViolationLoss(torch.nn.Module):
    """Structural-violation loss L_viol (supplement 1.9.11 equations 44-47).

        L_viol = L_bondlength + L_bondangle + L_clash        (eq 47)

    * `L_bondlength` (eq 44): flat-bottom L1 on inter-residue C-N peptide
      bond lengths relative to literature values, tolerance 12 σ_lit.
    * `L_bondangle` (eq 45): flat-bottom L1 on the cosine of each peptide
      bond angle (CA-C-N and C-N-CA) against literature, tolerance 12 σ_lit.
      Supplement 1.9.11 describes a single bond-angle term; we sum the two
      peptide-bond angles, which is a sum of two paper-faithful flat-bottom
      L1 terms and matches what the DeepMind reference releases compute.
    * `L_clash` (eq 46): one-sided flat-bottom on VDW overlaps between
      non-bonded heavy atoms with tolerance 1.5 Å. Split into between- and
      within-residue halves for memory, averaged per atom so the term sits
      at a sane scale relative to L_bondlength / L_bondangle (eq 7 applies
      an absolute weight of 1.0 to L_viol).

    Used only during fine-tuning (eq 7 fine-tuning row).
    """

    # Declaring registered buffers as class-level Tensor annotations lets type
    # checkers see them as tensors rather than the parent ``Module``.
    vdw_table: torch.Tensor
    distance_lower_bound_table: torch.Tensor
    distance_upper_bound_table: torch.Tensor

    def __init__(
        self,
        violation_tolerance_factor: float = 12.0,
        clash_overlap_tolerance: float = 1.5,
    ):
        super().__init__()
        bounds = make_atom14_dists_bounds(
            overlap_tolerance=clash_overlap_tolerance,
            bond_length_tolerance_factor=violation_tolerance_factor,
        )
        self.register_buffer('vdw_table', torch.tensor(restype_atom14_vdw_radius))
        self.register_buffer('distance_lower_bound_table', torch.tensor(bounds["lower_bound"]))
        self.register_buffer('distance_upper_bound_table', torch.tensor(bounds["upper_bound"]))
        self.violation_tolerance_factor = violation_tolerance_factor
        self.clash_overlap_tolerance = clash_overlap_tolerance

    def forward(
        self,
        predicted_positions,  # (batch, N_res, 14, 3) — all-atom coordinates
        atom_mask,            # (batch, N_res, 14)    — 1 if atom exists, 0 otherwise
        residue_types,        # (batch, N_res)         — integer residue type index (0–20)
        residue_index,        # (batch, N_res)
    ):
        connection_violations = self.between_residue_bond_and_angle_loss(
            predicted_positions,
            atom_mask,
            residue_types,
            residue_index,
        )
        between_residue_clashes = self.between_residue_clash_loss(
            predicted_positions,
            atom_mask,
            residue_types,
            residue_index,
        )
        within_residue_violations = self.within_residue_violation_loss(
            predicted_positions,
            atom_mask,
            residue_types,
        )
        num_atoms = torch.sum(atom_mask, dim=(1, 2)).clamp(min=1e-6)
        per_atom_clash = (
            between_residue_clashes["per_atom_loss_sum"]
            + within_residue_violations["per_atom_loss_sum"]
        )
        clash_loss = torch.sum(per_atom_clash, dim=(1, 2)) / num_atoms
        return (
            connection_violations["c_n_loss_mean"]
            + connection_violations["ca_c_n_loss_mean"]
            + connection_violations["c_n_ca_loss_mean"]
            + clash_loss
        )

    def between_residue_bond_and_angle_loss(
        self,
        predicted_positions,  # (batch, N_res, 14, 3) — all-atom coordinates
        atom_mask,            # (batch, N_res, 14)    — 1 if atom exists, 0 otherwise
        residue_types,        # (batch, N_res)         — integer residue type index (0–20)
        residue_index,        # (batch, N_res)
    ):
        eps = 1e-6
        this_ca_pos = predicted_positions[:, :-1, 1, :]
        this_ca_mask = atom_mask[:, :-1, 1]
        this_c_pos = predicted_positions[:, :-1, 2, :]
        this_c_mask = atom_mask[:, :-1, 2]
        next_n_pos = predicted_positions[:, 1:, 0, :]
        next_n_mask = atom_mask[:, 1:, 0]
        next_ca_pos = predicted_positions[:, 1:, 1, :]
        next_ca_mask = atom_mask[:, 1:, 1]
        has_no_gap_mask = (residue_index[:, 1:] - residue_index[:, :-1] == 1).to(predicted_positions.dtype)

        c_n_bond_length = torch.sqrt(torch.sum((this_c_pos - next_n_pos) ** 2, dim=-1) + eps)
        next_is_proline = (residue_types[:, 1:] == 14).to(predicted_positions.dtype)
        gt_length = (
            (1.0 - next_is_proline) * between_res_bond_length_c_n[0]
            + next_is_proline * between_res_bond_length_c_n[1]
        )
        gt_stddev = (
            (1.0 - next_is_proline) * between_res_bond_length_stddev_c_n[0]
            + next_is_proline * between_res_bond_length_stddev_c_n[1]
        )
        c_n_mask = this_c_mask * next_n_mask * has_no_gap_mask
        c_n_error = torch.sqrt((c_n_bond_length - gt_length) ** 2 + eps)
        c_n_loss_per_residue = torch.relu(c_n_error - self.violation_tolerance_factor * gt_stddev)
        c_n_loss = torch.sum(c_n_mask * c_n_loss_per_residue, dim=-1) / (torch.sum(c_n_mask, dim=-1) + eps)
        c_n_violation_mask = c_n_mask * (c_n_error > (self.violation_tolerance_factor * gt_stddev))

        ca_c_bond_length = torch.sqrt(torch.sum((this_ca_pos - this_c_pos) ** 2, dim=-1) + eps)
        n_ca_bond_length = torch.sqrt(torch.sum((next_n_pos - next_ca_pos) ** 2, dim=-1) + eps)
        c_ca_unit_vec = (this_ca_pos - this_c_pos) / ca_c_bond_length[..., None]
        c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length[..., None]
        n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length[..., None]

        ca_c_n_metric = torch.sum(c_ca_unit_vec * c_n_unit_vec, dim=-1)
        ca_c_n_mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
        ca_c_n_error = torch.sqrt((ca_c_n_metric - between_res_cos_angles_ca_c_n[0]) ** 2 + eps)
        ca_c_n_loss_per_residue = torch.relu(
            ca_c_n_error - self.violation_tolerance_factor * between_res_cos_angles_ca_c_n[1]
        )
        ca_c_n_loss = torch.sum(ca_c_n_mask * ca_c_n_loss_per_residue, dim=-1) / (
            torch.sum(ca_c_n_mask, dim=-1) + eps
        )
        ca_c_n_violation_mask = ca_c_n_mask * (
            ca_c_n_error > (self.violation_tolerance_factor * between_res_cos_angles_ca_c_n[1])
        )

        c_n_ca_metric = torch.sum((-c_n_unit_vec) * n_ca_unit_vec, dim=-1)
        c_n_ca_mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
        c_n_ca_error = torch.sqrt((c_n_ca_metric - between_res_cos_angles_c_n_ca[0]) ** 2 + eps)
        c_n_ca_loss_per_residue = torch.relu(
            c_n_ca_error - self.violation_tolerance_factor * between_res_cos_angles_c_n_ca[1]
        )
        c_n_ca_loss = torch.sum(c_n_ca_mask * c_n_ca_loss_per_residue, dim=-1) / (
            torch.sum(c_n_ca_mask, dim=-1) + eps
        )
        c_n_ca_violation_mask = c_n_ca_mask * (
            c_n_ca_error > (self.violation_tolerance_factor * between_res_cos_angles_c_n_ca[1])
        )

        per_residue_loss_sum = c_n_loss_per_residue + ca_c_n_loss_per_residue + c_n_ca_loss_per_residue
        per_residue_loss_sum = 0.5 * (
            torch.nn.functional.pad(per_residue_loss_sum, (0, 1))
            + torch.nn.functional.pad(per_residue_loss_sum, (1, 0))
        )
        violation_mask = torch.max(
            torch.stack(
                [c_n_violation_mask, ca_c_n_violation_mask, c_n_ca_violation_mask],
                dim=-1,
            ),
            dim=-1,
        ).values
        violation_mask = torch.maximum(
            torch.nn.functional.pad(violation_mask, (0, 1)),
            torch.nn.functional.pad(violation_mask, (1, 0)),
        )
        return {
            "c_n_loss_mean": c_n_loss,
            "ca_c_n_loss_mean": ca_c_n_loss,
            "c_n_ca_loss_mean": c_n_ca_loss,
            "per_residue_loss_sum": per_residue_loss_sum,
            "per_residue_violation_mask": violation_mask,
        }

    def between_residue_clash_loss(
        self,
        predicted_positions,  # (batch, N_res, 14, 3) — all-atom coordinates
        atom_mask,            # (batch, N_res, 14)    — 1 if atom exists, 0 otherwise
        residue_types,        # (batch, N_res)         — integer residue type index (0–20)
        residue_index,        # (batch, N_res)
    ):
        batch, N_res = predicted_positions.shape[:2]
        overlap_tolerance = self.clash_overlap_tolerance

        # Flatten atoms: (batch, N_res*14, 3) and (batch, N_res*14)
        pos_flat = predicted_positions.reshape(batch, N_res * 14, 3)
        mask_flat = atom_mask.reshape(batch, N_res * 14)

        # VDW radii per atom: look up from registered buffer
        # residue_types: (batch, N_res) -> vdw: (batch, N_res, 14)
        residue_types_clamped = residue_types.clamp(max=20)
        vdw = self.vdw_table[residue_types_clamped]  # (batch, N_res, 14)
        vdw_flat = vdw.reshape(batch, N_res * 14)  # (batch, N_res*14)

        # Pairwise distances: (batch, N_res*14, N_res*14)
        diff = pos_flat[:, :, None, :] - pos_flat[:, None, :, :]  # (batch, M, M, 3)
        dist = torch.sqrt(torch.sum(diff ** 2, dim=-1) + 1e-8)   # (batch, M, M)

        # Pair mask: both atoms valid
        pair_mask = mask_flat[:, :, None] * mask_flat[:, None, :]  # (batch, M, M)

        atom_residue_index = residue_index.repeat_interleave(14, dim=1)
        unique_residue_pairs = (atom_residue_index[:, :, None] < atom_residue_index[:, None, :]).to(
            predicted_positions.dtype
        )
        pair_mask = pair_mask * unique_residue_pairs

        atom_slot_index = torch.arange(14, device=predicted_positions.device).repeat(N_res)
        atom_slot_index = atom_slot_index.unsqueeze(0).expand(batch, -1)
        residue_type_flat = residue_types.repeat_interleave(14, dim=1)

        c_n_bond = (
            (atom_slot_index[:, :, None] == 2)
            & (atom_slot_index[:, None, :] == 0)
            & (atom_residue_index[:, None, :] - atom_residue_index[:, :, None] == 1)
        ) | (
            (atom_slot_index[:, :, None] == 0)
            & (atom_slot_index[:, None, :] == 2)
            & (atom_residue_index[:, :, None] - atom_residue_index[:, None, :] == 1)
        )
        disulfide_bond = (
            (residue_type_flat[:, :, None] == 4)
            & (residue_type_flat[:, None, :] == 4)
            & (atom_slot_index[:, :, None] == 5)
            & (atom_slot_index[:, None, :] == 5)
        )
        pair_mask = pair_mask * (~c_n_bond).to(predicted_positions.dtype) * (~disulfide_bond).to(predicted_positions.dtype)

        # Overlap: vdw_i + vdw_j - tolerance - dist
        vdw_sum = vdw_flat[:, :, None] + vdw_flat[:, None, :]  # (batch, M, M)
        overlap = vdw_sum - overlap_tolerance - dist

        clash = torch.clamp(overlap, min=0) * pair_mask
        mean_loss = torch.sum(clash, dim=(1, 2)) / torch.sum(pair_mask, dim=(1, 2)).clamp(min=1e-6)
        per_atom_loss_sum = (torch.sum(clash, dim=1) + torch.sum(clash, dim=2)).reshape(batch, N_res, 14)
        clash_mask = pair_mask * (dist < (vdw_sum - overlap_tolerance))
        per_atom_clash_mask = torch.maximum(
            torch.amax(clash_mask, dim=1),
            torch.amax(clash_mask, dim=2),
        ).reshape(batch, N_res, 14)
        per_atom_num_clash = (torch.sum(clash_mask, dim=1) + torch.sum(clash_mask, dim=2)).reshape(batch, N_res, 14)
        return {
            "mean_loss": mean_loss,
            "per_atom_loss_sum": per_atom_loss_sum,
            "per_atom_clash_mask": per_atom_clash_mask,
            "per_atom_num_clash": per_atom_num_clash,
        }

    def within_residue_violation_loss(
        self,
        predicted_positions,  # (batch, N_res, 14, 3)
        atom_mask,            # (batch, N_res, 14)
        residue_types,        # (batch, N_res)
    ):
        residue_types_clamped = residue_types.clamp(max=20)
        lower_bound = self.distance_lower_bound_table[residue_types_clamped]
        upper_bound = self.distance_upper_bound_table[residue_types_clamped]

        distances = torch.sqrt(
            torch.sum(
                (predicted_positions[:, :, :, None, :] - predicted_positions[:, :, None, :, :]) ** 2,
                dim=-1,
            ) + 1e-8
        )
        pair_mask = atom_mask[:, :, :, None] * atom_mask[:, :, None, :]
        eye = torch.eye(14, device=predicted_positions.device, dtype=predicted_positions.dtype)
        pair_mask = pair_mask * (1.0 - eye.view(1, 1, 14, 14))
        bound_mask = ((lower_bound > 0.0) | (upper_bound > 0.0)).to(predicted_positions.dtype)
        pair_mask = pair_mask * bound_mask

        lower_violation = torch.clamp(lower_bound - distances, min=0.0)
        upper_violation = torch.clamp(distances - upper_bound, min=0.0)
        loss = (lower_violation + upper_violation) * pair_mask
        per_atom_loss_sum = torch.sum(loss, dim=-2) + torch.sum(loss, dim=-1)

        violations = pair_mask * ((distances < lower_bound) | (distances > upper_bound)).to(predicted_positions.dtype)
        per_atom_violations = torch.maximum(
            torch.amax(violations, dim=-2),
            torch.amax(violations, dim=-1),
        )
        per_atom_num_clash = torch.sum(violations, dim=-2) + torch.sum(violations, dim=-1)
        return {
            "per_atom_loss_sum": per_atom_loss_sum,
            "per_atom_violations": per_atom_violations,
            "per_atom_num_clash": per_atom_num_clash,
        }

Structural violations is the physicality constraint. It has three components: bond-length deviations (are the inter-residue C–N peptide bonds the right length?), bond-angle deviations (do the Cα–C–N and C–N–Cα peptide-bond angles match chemistry?), and van der Waals clashes (are any two non-bonded atoms closer than the sum of their VDW radii minus a slack?). Each component is a soft, flat-bottom penalty on physically implausible geometry. In practice, this loss is off for most of training — its weight is zero during initial training and is only set to 1.0 for fine-tuning — because early in training, the structures are too wrong for “no clashes” to be a productive constraint.
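All three components share the same flat-bottom shape: zero penalty inside a tolerance band, linear outside it. A small sketch with illustrative numbers (the real band is 12 σ_lit for bonds and angles and a 1.5 Å slack for clashes; the specific lengths and radii below are rounded demo values, not the tables the class actually loads):

```python
def flat_bottom_l1(x, mu, tol):
    """Flat-bottom L1: no penalty within `tol` of `mu`, linear outside.

    This is the max(0, |x - mu| - tol) shape used (with tol = 12 * sigma_lit)
    for the bond-length and bond-angle terms above."""
    return max(0.0, abs(x - mu) - tol)

def clash_penalty(dist, r_i, r_j, slack=1.5):
    """One-sided variant for clashes: atoms may approach down to
    r_i + r_j - slack before any penalty accrues."""
    return max(0.0, (r_i + r_j - slack) - dist)

# Peptide C-N bond near its literature length (~1.33 A): inside the band,
# so a 1.40 A bond costs nothing with a 0.168 A tolerance.
assert flat_bottom_l1(1.40, 1.33, 0.168) == 0.0
# A 1.60 A bond pays linearly for the excess beyond the band.
assert flat_bottom_l1(1.60, 1.33, 0.168) > 0.0
# Two carbon-ish atoms (r ~ 1.7 A) at 1.0 A separation clearly clash.
assert clash_penalty(1.0, 1.7, 1.7) > 0.0
```

The flat bottom is what keeps this loss from fighting the FAPE terms: geometry that is already within experimental scatter contributes exactly zero gradient.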

12.6 The combined loss

All of the above land in a single weighted sum:

class AlphaFoldLoss(torch.nn.Module):
    """Combined AlphaFold loss (supplement 1.9 equation 7).

    Training:
        L = 0.5 L_FAPE + 0.5 L_aux + 0.3 L_dist + 2.0 L_msa + 0.01 L_conf
    Fine-tuning (equation 7 fine-tuning row + supplement 1.9.7):
        L += 0.01 L_exp_resolved + 1.0 L_viol + 0.1 L_pae

    Term-by-term mapping (all weights below are *absolute* eq 7 weights, not
    relative fractions):

    * `sidechain_weight_frac = 0.5` doubles as the weight of L_FAPE (final
      all-atom FAPE, Algorithm 20 line 28) AND the weight on the FAPE half
      of L_aux (backbone FAPE averaged over iterations, Algorithm 20 line 17).
      We decompose `0.5 L_aux = 0.5 L_aux^{FAPE} + 0.5 L_aux^{torsion}` and
      sum `0.5 L_FAPE + 0.5 L_aux^{FAPE}` directly.
    * `distogram_weight = 0.3`   → 0.3 L_dist (supplement 1.9.8, eq 41).
    * `msa_weight = 2.0`         → 2.0 L_msa  (supplement 1.9.9, eq 42).
    * `confidence_weight = 0.01` → 0.01 L_conf (supplement 1.9.6, Alg 29).
    * `experimentally_resolved_weight = 0.01` → 0.01 L_exp_resolved (1.9.10).
    * `structural_violation_weight = 1.0` → 1.0 L_viol (1.9.11, eq 47).
    * `tm_score_weight = 0.1` → 0.1 L_pae (supplement 1.9.7, paragraph
      following equation 38: "average categorical cross-entropy loss, with
      weight 0.1"). Requires `tm_pred` to be passed; otherwise skipped.

    `TorsionAngleLoss` pre-applies the outer 0.5 factor on L_aux^{torsion}
    and the 0.02 coefficient on L_anglenorm from Algorithm 27, so
    `weighted_torsion_loss = torsion_loss` without an additional weight.

    `use_clamped_fape` implements supplement 1.11.5: in 90% of training
    mini-batches the backbone FAPE is clamped by 10 Å, and unclamped in the
    remaining 10%. Passing a float in [0, 1] mixes the two versions with
    the given clamped weight; `None` defaults to fully clamped. The side-
    chain FAPE is always clamped regardless, so this knob never reaches it.
    """

    def __init__(self, finetune: bool = False, use_clamped_fape: Optional[float] = None):
        super().__init__()
        self.torsion_angle_loss = TorsionAngleLoss()
        self.plddt_loss = PLDDTLoss(
            filter_by_resolution=True,
            min_resolution=0.1,
            max_resolution=3.0,
        )
        self.distogram_loss = DistogramLoss()
        self.msa_loss = MSALoss()
        self.experimentally_resolved_loss = ExperimentallyResolvedLoss(
            filter_by_resolution=True,
            min_resolution=0.1,
            max_resolution=3.0,
        )
        self.structural_violation_loss = StructuralViolationLoss()
        self.tm_score_loss = TMScoreLoss(
            filter_by_resolution=True,
            min_resolution=0.1,
            max_resolution=3.0,
        )
        self.backbone_loss = BackboneTrajectoryLoss()
        self.sidechain_fape_loss = AllAtomFAPE()

        # Equation 7 weights, all absolute (not relative).
        self.sidechain_weight_frac = 0.5
        self.distogram_weight = 0.3
        self.msa_weight = 2.0
        self.confidence_weight = 0.01
        self.experimentally_resolved_weight = 0.01
        self.structural_violation_weight = 1.0
        self.tm_score_weight = 0.1  # supplement 1.9.7 (paragraph after eq 38)

        self.finetune = finetune
        self.use_clamped_fape = use_clamped_fape

    def forward(
            self,
            structure_model_prediction: dict,
            true_rotations: torch.Tensor,           # (b, N_res, 3, 3)
            true_translations: torch.Tensor,        # (b, N_res, 3)
            true_atom_positions: torch.Tensor,      # (b, N_res, 14, 3)
            true_atom_mask: torch.Tensor,           # (b, N_res, 14)
            true_atom_positions_alt: torch.Tensor,
            true_atom_mask_alt: torch.Tensor,
            true_atom_is_ambiguous: torch.Tensor,
            true_torsion_angles: torch.Tensor,      # (b, N_res, 7, 2)
            true_torsion_angles_alt: torch.Tensor,  # (b, N_res, 7, 2)
            true_torsion_mask: torch.Tensor,        # (b, N_res, 7)
            true_rigid_group_frames_R: torch.Tensor,
            true_rigid_group_frames_t: torch.Tensor,
            true_rigid_group_frames_R_alt: torch.Tensor,
            true_rigid_group_frames_t_alt: torch.Tensor,
            true_rigid_group_exists: torch.Tensor,
            experimentally_resolved_pred: torch.Tensor,
            experimentally_resolved_true: torch.Tensor,
            experimentally_resolved_exists: torch.Tensor,
            masked_msa_pred: torch.Tensor,
            masked_msa_target: torch.Tensor,
            masked_msa_mask: torch.Tensor,
            plddt_pred: torch.Tensor,
            distogram_pred: torch.Tensor,
            res_types: torch.Tensor,                # (b, N_res) integer 0-20
            residue_index: torch.Tensor,
            seq_mask: Optional[torch.Tensor] = None,  # (b, N_res) 1=valid, 0=padding
            return_breakdown: bool = False,
            resolution: Optional[torch.Tensor] = None,
            tm_pred: Optional[torch.Tensor] = None,  # (b, N_res, N_res, n_pae_bins)
        ):
        loss_terms = self.compute_loss_terms(
            structure_model_prediction=structure_model_prediction,
            true_rotations=true_rotations,
            true_translations=true_translations,
            true_atom_positions=true_atom_positions,
            true_atom_mask=true_atom_mask,
            true_atom_positions_alt=true_atom_positions_alt,
            true_atom_mask_alt=true_atom_mask_alt,
            true_atom_is_ambiguous=true_atom_is_ambiguous,
            true_torsion_angles=true_torsion_angles,
            true_torsion_angles_alt=true_torsion_angles_alt,
            true_torsion_mask=true_torsion_mask,
            true_rigid_group_frames_R=true_rigid_group_frames_R,
            true_rigid_group_frames_t=true_rigid_group_frames_t,
            true_rigid_group_frames_R_alt=true_rigid_group_frames_R_alt,
            true_rigid_group_frames_t_alt=true_rigid_group_frames_t_alt,
            true_rigid_group_exists=true_rigid_group_exists,
            experimentally_resolved_pred=experimentally_resolved_pred,
            experimentally_resolved_true=experimentally_resolved_true,
            experimentally_resolved_exists=experimentally_resolved_exists,
            resolution=resolution,
            masked_msa_pred=masked_msa_pred,
            masked_msa_target=masked_msa_target,
            masked_msa_mask=masked_msa_mask,
            plddt_pred=plddt_pred,
            distogram_pred=distogram_pred,
            res_types=res_types,
            residue_index=residue_index,
            seq_mask=seq_mask,
            tm_pred=tm_pred,
        )
        if return_breakdown:
            return loss_terms["loss"], loss_terms
        return loss_terms["loss"]

    def compute_loss_terms(
            self,
            structure_model_prediction: dict,
            true_rotations: torch.Tensor,
            true_translations: torch.Tensor,
            true_atom_positions: torch.Tensor,
            true_atom_mask: torch.Tensor,
            true_atom_positions_alt: torch.Tensor,
            true_atom_mask_alt: torch.Tensor,
            true_atom_is_ambiguous: torch.Tensor,
            true_torsion_angles: torch.Tensor,
            true_torsion_angles_alt: torch.Tensor,
            true_torsion_mask: torch.Tensor,
            true_rigid_group_frames_R: torch.Tensor,
            true_rigid_group_frames_t: torch.Tensor,
            true_rigid_group_frames_R_alt: torch.Tensor,
            true_rigid_group_frames_t_alt: torch.Tensor,
            true_rigid_group_exists: torch.Tensor,
            experimentally_resolved_pred: torch.Tensor,
            experimentally_resolved_true: torch.Tensor,
            experimentally_resolved_exists: torch.Tensor,
            masked_msa_pred: torch.Tensor,
            masked_msa_target: torch.Tensor,
            masked_msa_mask: torch.Tensor,
            plddt_pred: torch.Tensor,
            distogram_pred: torch.Tensor,
            res_types: torch.Tensor,
            residue_index: torch.Tensor,
            seq_mask: Optional[torch.Tensor] = None,
            resolution: Optional[torch.Tensor] = None,
            tm_pred: Optional[torch.Tensor] = None,
        ) -> dict[str, torch.Tensor]:
        pred_all_frames_R = structure_model_prediction["all_frames_R"]  # (batch, N_res, 8, 3, 3)
        pred_all_frames_t = structure_model_prediction["all_frames_t"]  # (batch, N_res, 8, 3)
        atom_coords = structure_model_prediction["atom14_coords"]   # (batch, N_res, 14, 3)
        atom_mask = structure_model_prediction["atom14_mask"]       # (batch, N_res, 14)
        # Canonically rename ambiguous sidechains before any atom-derived supervision.
        true_atom_positions, true_atom_mask, alt_naming_is_better = select_best_atom14_ground_truth(
            atom_coords,
            true_atom_positions,
            true_atom_mask,
            true_atom_positions_alt,
            true_atom_mask_alt,
            true_atom_is_ambiguous,
        )

        renamed_rigid_group_frames_R = torch.where(
            alt_naming_is_better[:, :, None, None, None] > 0,
            true_rigid_group_frames_R_alt,
            true_rigid_group_frames_R,
        )
        renamed_rigid_group_frames_t = torch.where(
            alt_naming_is_better[:, :, None, None] > 0,
            true_rigid_group_frames_t_alt,
            true_rigid_group_frames_t,
        )

        backbone_mask = true_atom_mask[:, :, 0] * true_atom_mask[:, :, 1] * true_atom_mask[:, :, 2]
        backbone_loss = self.backbone_loss(
            structure_model_prediction,
            true_rotations,
            true_translations,
            backbone_mask=backbone_mask,
            seq_mask=seq_mask,
            use_clamped_fape=self.use_clamped_fape,
        )
        # Side-chain FAPE is always clamped (supplement 1.11.5); use_clamped_fape
        # controls only the backbone FAPE trajectory loss above.
        sidechain_loss = self.sidechain_fape_loss(
            pred_all_frames_R,
            pred_all_frames_t,
            atom_coords,
            atom_mask,
            renamed_rigid_group_frames_R,
            renamed_rigid_group_frames_t,
            true_atom_positions,
            true_atom_mask=true_atom_mask,
            seq_mask=seq_mask,
            frame_mask=true_rigid_group_exists,
        )
        torsion_loss = self.torsion_angle_loss(
            structure_model_prediction["traj_torsion_angles"],
            structure_model_prediction["traj_torsion_angles_unnormalized"],
            true_torsion_angles,
            true_torsion_angles_alt,
            true_torsion_mask,
            seq_mask=seq_mask,
        )

        # --- Derive distogram target (supplement 1.9.8) ---
        # Targets are the one-hot encoding of Cβ-Cβ distances (Cα for GLY).
        is_gly = (res_types == 7)                       # (batch, N_res)
        cb_idx = torch.where(is_gly, 1, 4)              # atom14 slots: CA=1, CB=4
        cb_pos = torch.gather(
            true_atom_positions, 2,
            cb_idx[:, :, None, None].expand(-1, -1, 1, 3),
        ).squeeze(2)
        n_dist_bins = distogram_pred.shape[-1]
        cb_mask = torch.gather(true_atom_mask, 2, cb_idx[:, :, None]).squeeze(-1)
        distogram_true = distance_bin(cb_pos, n_dist_bins)
        dist_pair_mask = cb_mask[:, :, None] * cb_mask[:, None, :]
        if seq_mask is not None:
            dist_pair_mask = dist_pair_mask * (seq_mask[:, :, None] * seq_mask[:, None, :])

        dist_loss = self.distogram_loss(distogram_pred, distogram_true, pair_mask=dist_pair_mask)
        msa_loss = self.msa_loss(masked_msa_pred, masked_msa_target, masked_msa_mask)

        # --- Derive pLDDT target (supplement 1.9.6) ---
        # Compute per-residue lDDT-Cα of the prediction against ground truth,
        # then discretise into 50 bins of width 2 (v_bins in Algorithm 29).
        # lDDT-Cα is the mean over 4 thresholds (0.5, 1, 2, 4 Å) of the fraction
        # of included Cα-Cα distance pairs that are preserved within tolerance.
        # "Included" = pairs with d_true < 15 Å, excluding self.
        N_res = atom_coords.shape[1]
        with torch.no_grad():
            pred_ca = atom_coords[:, :, 1, :]                 # (batch, N_res, 3)
            true_ca = true_atom_positions[:, :, 1, :]
            true_ca_mask = true_atom_mask[:, :, 1]
            true_ca_dists = torch.cdist(true_ca, true_ca)     # (batch, N_res, N_res)
            pred_ca_dists = torch.cdist(pred_ca, pred_ca)
            inclusion = (true_ca_dists < 15.0).float() * (
                1.0 - torch.eye(N_res, device=pred_ca.device).unsqueeze(0))
            inclusion = inclusion * (true_ca_mask[:, :, None] * true_ca_mask[:, None, :])
            if seq_mask is not None:
                pair_valid = seq_mask[:, :, None] * seq_mask[:, None, :]  # (batch, N_res, N_res)
                inclusion = inclusion * pair_valid
            dist_error = torch.abs(pred_ca_dists - true_ca_dists)
            # Average fraction of preserved distances across four thresholds
            lddt = torch.zeros(pred_ca.shape[:2], device=pred_ca.device)  # (batch, N_res)
            n_included = inclusion.sum(dim=-1).clamp(min=1)
            for thresh in [0.5, 1.0, 2.0, 4.0]:
                lddt = lddt + ((dist_error < thresh).float() * inclusion).sum(dim=-1) / n_included
            lddt = lddt / 4.0  # (batch, N_res) in [0, 1]
            lddt_mask = true_ca_mask if seq_mask is None else true_ca_mask * seq_mask
            n_plddt_bins = plddt_pred.shape[-1]
            plddt_edges = torch.arange(1, n_plddt_bins, device=pred_ca.device).float() / n_plddt_bins
            plddt_bin_idx = torch.bucketize(lddt, plddt_edges)
            plddt_true = torch.nn.functional.one_hot(plddt_bin_idx, n_plddt_bins).float()

        plddt_loss = self.plddt_loss(
            plddt_pred,
            plddt_true,
            seq_mask=lddt_mask,
            resolution=resolution,
        )

        # Equation 7, training row. `backbone_loss` is L_aux^{FAPE} = mean_l(FAPE^l);
        # `sidechain_loss` is L_FAPE (all-atom, final layer); `torsion_loss` already
        # bundles 0.5 * L_aux^{torsion} + 0.01 * L_aux^{anglenorm} per Algorithm 27
        # and the L_aux factor from equation 7.
        weighted_backbone_loss = (1.0 - self.sidechain_weight_frac) * backbone_loss
        weighted_sidechain_fape_loss = self.sidechain_weight_frac * sidechain_loss
        weighted_torsion_loss = torsion_loss
        fape_loss = weighted_backbone_loss + weighted_sidechain_fape_loss
        structure_loss = fape_loss + weighted_torsion_loss
        weighted_distogram_loss = self.distogram_weight * dist_loss
        weighted_msa_loss = self.msa_weight * msa_loss
        weighted_plddt_loss = self.confidence_weight * plddt_loss
        loss = structure_loss + weighted_distogram_loss + weighted_msa_loss + weighted_plddt_loss

        loss_terms = {
            "loss": loss,
            "structure_loss": structure_loss,
            "fape_loss": fape_loss,
            "backbone_loss": backbone_loss,
            "sidechain_fape_loss": sidechain_loss,
            "torsion_loss": torsion_loss,
            "distogram_loss": dist_loss,
            "msa_loss": msa_loss,
            "plddt_loss": plddt_loss,
            "weighted_backbone_loss": weighted_backbone_loss,
            "weighted_sidechain_fape_loss": weighted_sidechain_fape_loss,
            "weighted_torsion_loss": weighted_torsion_loss,
            "weighted_distogram_loss": weighted_distogram_loss,
            "weighted_msa_loss": weighted_msa_loss,
            "weighted_plddt_loss": weighted_plddt_loss,
        }

        if self.finetune:
            structural_violation_loss = self.structural_violation_loss(
                atom_coords,
                atom_mask,
                res_types,
                residue_index,
            )
            weighted_structural_violation_loss = self.structural_violation_weight * structural_violation_loss
            loss = loss + weighted_structural_violation_loss
            loss_terms["structural_violation_loss"] = structural_violation_loss
            loss_terms["weighted_structural_violation_loss"] = weighted_structural_violation_loss

            exp_resolved_loss = self.experimentally_resolved_loss(
                experimentally_resolved_pred,
                experimentally_resolved_true,
                experimentally_resolved_exists,
                resolution=resolution,
            )
            weighted_exp_resolved_loss = self.experimentally_resolved_weight * exp_resolved_loss
            loss = loss + weighted_exp_resolved_loss
            loss_terms["experimentally_resolved_loss"] = exp_resolved_loss
            loss_terms["weighted_experimentally_resolved_loss"] = weighted_exp_resolved_loss

            # Supplement 1.9.7: predicted aligned error / pTM head, fine-tuning
            # only, weight 0.1. Skipped silently if tm_pred is not supplied.
            if tm_pred is not None:
                tm_score_loss = self.tm_score_loss(
                    tm_pred,
                    predicted_rotations=structure_model_prediction["final_rotations"],
                    predicted_translations=structure_model_prediction["final_translations"],
                    true_rotations=true_rotations,
                    true_translations=true_translations,
                    backbone_mask=backbone_mask,
                    seq_mask=seq_mask,
                    resolution=resolution,
                )
                weighted_tm_score_loss = self.tm_score_weight * tm_score_loss
                loss = loss + weighted_tm_score_loss
                loss_terms["tm_score_loss"] = tm_score_loss
                loss_terms["weighted_tm_score_loss"] = weighted_tm_score_loss

        loss_terms["loss"] = loss
        return loss_terms

The supplement gives specific weights: FAPE dominates at 0.5; distogram contributes 0.3 and masked-MSA 2.0; pLDDT sits at 0.01; experimentally-resolved at 0.01; and the violation weight is held at 0 for roughly the first 85% of training before being turned up. The whole thing is one loss number that gets differentiated.

FAPE is the clearest example in the whole AF2 paper of an architectural choice expressed as a loss rather than as a layer. You could imagine building a model that outputs atom coordinates and then applies some canonical alignment (like Kabsch superposition) before comparing to ground truth in a single global frame — and some earlier structure prediction work did exactly this. FAPE skips the alignment step entirely by constructing a loss that is intrinsically alignment-free. Every one of the network’s intermediate frames is supervised by this loss, and because the loss itself is SE(3)-invariant, the frames can live in any global orientation the model finds convenient during training.

This is the same kind of move as, say, choosing a contrastive loss over a classification loss in representation learning — you don’t change the model, you change what “being right” means, and a lot of downstream architectural choices become easier as a consequence. If you take one thing away from the losses section, let it be that FAPE is a load-bearing conceptual move, not a minor implementation detail.
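To make the invariance concrete, here is a minimal FAPE sketch (not the repo's implementation — the real one handles masking, per-frame weighting, and batching): express every atom in every predicted local frame, do the same with the ground truth, and penalize the clamped pointwise distance. Because both sides live in local frames, applying any global rotation and translation to the prediction leaves the loss unchanged.

```python
import torch

def fape(pred_R, pred_t, pred_x, true_R, true_t, true_x,
         clamp=10.0, scale=10.0, eps=1e-8):
    """Minimal FAPE sketch. Frames: R [N,3,3], t [N,3]; atom coords x [M,3]."""
    def to_local(R, t, x):
        # Express every atom j in every local frame i: x_ij = R_i^T (x_j - t_i)
        return torch.einsum('iab,ija->ijb', R, x[None, :, :] - t[:, None, :])

    d = (to_local(pred_R, pred_t, pred_x)
         - to_local(true_R, true_t, true_x)).pow(2).sum(-1).add(eps).sqrt()
    # Clamp at 10 Å and normalize by the same 10 Å scale, as in the paper.
    return d.clamp(max=clamp).mean() / scale
```

A quick sanity check of the invariance: build ground-truth frames and atoms, apply one global rigid transform to get the "prediction", and the loss is (numerically) zero.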

13. Training

The training recipe for AlphaFold2 is where the paper’s ambition shows most clearly. The architecture and the loss are one thing; making them converge at scale on biological data is another. minAlphaFold covers most of the training mechanics (cropping, recycling sampling, gradient checkpointing), but it skips one big piece — self-distillation — which we’ll call out honestly.

13.1 Feature engineering

We saw the input feature builders back in §2: build_target_feat, build_msa_feat, build_extra_msa_feat, and build_template_pair_feat. The training-time additions are the supervision-building pipeline — turning raw mmCIF files into all the targets the losses in §12 need:

def build_processed_example(
    example: Dict[str, Any],
    *,
    crop_size: int,
    msa_depth: int,
    extra_msa_depth: int,
    max_templates: int,
    training: bool,
    block_delete_training_msa: bool = True,
    block_delete_msa_fraction: float = 0.3,
    block_delete_msa_randomize_num_blocks: bool = False,
    block_delete_msa_num_blocks: int = 5,
    masked_msa_probability: float = 0.15,
    random_seed: int | None = None,
):

This extracts atom14 coordinates and masks from the crystal structure, builds ground-truth backbone frames via rigid_frame_from_three_points, computes the torsion angles ϕ/ψ/ω/χ1–4, and produces the per-residue lDDT and distogram targets. The output is a single processed example dict that flows through the rest of the pipeline.
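The geometric primitive underneath the torsion targets is the signed dihedral defined by four atom positions. A self-contained sketch (the repo's torsion code wraps something like this per angle type; the function name here is illustrative):

```python
import torch

def dihedral(p0, p1, p2, p3):
    # Signed torsion angle in radians for four consecutive atoms,
    # e.g. phi(i) is defined by C(i-1), N(i), CA(i), C(i).
    b0, b1, b2 = p1 - p0, p2 - p1, p3 - p2
    n1 = torch.cross(b0, b1, dim=-1)   # normal of the first plane
    n2 = torch.cross(b1, b2, dim=-1)   # normal of the second plane
    m1 = torch.cross(n1, b1 / b1.norm(dim=-1, keepdim=True), dim=-1)
    return torch.atan2((m1 * n2).sum(-1), (n1 * n2).sum(-1))
```

The atan2 formulation is preferable to acos of normalized dot products: it keeps the sign of the angle and avoids NaN gradients near 0 and π.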

13.2 Cropping

Real proteins can be thousands of residues long. The full Evoformer at paper-spec config, with 48 blocks and triangle multiplication's O(r³) scaling, would not fit in any single-GPU memory budget for r ≳ 400. AlphaFold2's answer is to train on crops — random contiguous windows of a fixed length N_crop (usually 256 residues) sampled from each protein.

def build_processed_example_from_cropped(
    example: Dict[str, Any],
    *,
    msa_depth: int,
    extra_msa_depth: int,
    max_templates: int,
    training: bool,
    block_delete_training_msa: bool = True,
    block_delete_msa_fraction: float = 0.3,
    block_delete_msa_randomize_num_blocks: bool = False,
    block_delete_msa_num_blocks: int = 5,
    masked_msa_probability: float = 0.15,
    random_seed: int | None = None,
):

The cropping is random contiguous: pick a start index uniformly, take N_crop consecutive residues, and feature-engineer the crop. This has a nontrivial effect on supervision — the model only ever sees pieces of proteins at training time, and yet it still has to generalize to full-length proteins at inference. The evidence that this works is mostly empirical: the triangle-consistency inductive biases in the Evoformer, combined with the SE(3)-equivariant Structure Module, are apparently enough that local structural relationships learned on crops compose reasonably to full proteins.
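The window sampling itself is a few lines. A sketch under the same assumptions (uniform start, proteins shorter than the crop kept whole — the helper name is mine, not the repo's):

```python
import random

def sample_crop_window(n_res: int, crop_size: int, rng: random.Random):
    # Proteins shorter than the crop are used whole; otherwise pick a
    # uniformly random contiguous window of exactly crop_size residues.
    if n_res <= crop_size:
        return 0, n_res
    start = rng.randint(0, n_res - crop_size)  # randint bounds are inclusive
    return start, start + crop_size
```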

13.3 Batch collation

With variable-length crops (some proteins are shorter than 256), we need to pad and mask. collate_batch does both:

def collate_batch(
    examples: List[Dict[str, Any]],
    *,
    crop_size: int,
    msa_depth: int,
    extra_msa_depth: int,
    max_templates: int,
    training: bool,
    block_delete_training_msa: bool = True,
    block_delete_msa_fraction: float = 0.3,
    block_delete_msa_randomize_num_blocks: bool = False,
    block_delete_msa_num_blocks: int = 5,
    masked_msa_probability: float = 0.15,
    random_seed: int | None = None,
    num_recycling_samples: int = 1,
    num_ensemble_samples: int = 1,
):

Every feature tensor in the batch is padded to the largest example in the batch, and a seq_mask tensor is threaded through every attention operation to mask out the padded positions. Attention scores at padded indices are set to −∞ before the softmax, and outputs at padded positions are discarded before loss computation. You’ll see seq_mask passed into nearly every module in minAlphaFold for this reason.
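The masking trick in isolation (a sketch, not the repo's attention code): fill scores at padded keys with −∞ so the softmax assigns them exactly zero weight.

```python
import torch

def masked_softmax(scores: torch.Tensor, seq_mask: torch.Tensor) -> torch.Tensor:
    # scores: [..., n_queries, n_keys]; seq_mask: [..., n_keys], 1 = real residue.
    scores = scores.masked_fill(seq_mask[..., None, :] == 0, float('-inf'))
    return torch.softmax(scores, dim=-1)
```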

13.4 Recycling sampling

We covered the mechanics of recycling in §9 — feed previous-cycle outputs back into current-cycle inputs, detach gradients between cycles. What changes at training time is that the number of cycles is sampled uniformly for each batch:

n_cycles = random.randint(1, max_cycles)  # uniform during training

This forces the model to produce usable representations at every cycle count, not just the max. A model trained only with 4 cycles would specialize its weights for the 4-cycle case; with uniform cycle sampling, the weights have to handle 1-cycle, 2-cycle, …, n-cycle all reasonably. The combined effect is that recycling works at inference time from the first cycle onward, and the refinement over cycles is typically monotonic.
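The sampling-plus-detach pattern as a sketch (the real loop recycles specific tensors — the pair representation, the first MSA row, and Cβ distances — but the control flow is this):

```python
import random
import torch

def recycled_forward(model_fn, inputs, max_cycles: int, training: bool):
    # During training, sample the cycle count uniformly; at inference, use max.
    n_cycles = random.randint(1, max_cycles) if training else max_cycles
    prev, out = None, None
    for c in range(n_cycles):
        out = model_fn(inputs, prev)
        if c < n_cycles - 1:
            # Detach so gradients flow only through the final cycle.
            prev = {k: v.detach() for k, v in out.items()}
    return out
```

The detach is load-bearing: without it, backprop would unroll through every cycle and the memory cost would multiply by the cycle count.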

13.5 Gradient checkpointing

The full Evoformer at 48 blocks, with 4 cycles, and with triangle multiplication’s O(r³) cost, has an activation memory footprint that dwarfs the parameter count by orders of magnitude. PyTorch’s torch.utils.checkpoint trades compute for memory by not storing intermediate activations during the forward pass — instead, during the backward pass, the forward is re-run for each checkpointed block to reconstruct the activations.

In minAlphaFold, gradient checkpointing wraps every Evoformer block and every extra-MSA block. The practical effect is that training memory scales linearly with the number of checkpointed units rather than with the full computation depth. Without it, paper-spec AF2 would not fit in 80 GB of GPU memory; with it, it does.
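A minimal demonstration of the trade using stock torch.utils.checkpoint (this is not repo code): the gradients come out identical to the plain forward/backward, but activations inside the block are recomputed during backward instead of stored.

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(),
                            torch.nn.Linear(8, 8))

x = torch.randn(4, 8, requires_grad=True)
y_ckpt = checkpoint(block, x, use_reentrant=False)  # forward re-run in backward
y_ckpt.sum().backward()
grad_ckpt = x.grad.clone()

x.grad = None
block(x).sum().backward()                           # plain forward/backward
assert torch.allclose(grad_ckpt, x.grad)
```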

13.6 The honest gap: self-distillation

The full AlphaFold2 training procedure has one piece that minAlphaFold does not implement: self-distillation. In the full procedure, a first model is trained on the PDB (roughly 150k structures). Then that model is used to predict structures for millions of unlabeled protein sequences from UniClust. Predictions with high pLDDT (> 70) are retained as pseudo-ground-truth, and a second model is trained on the union of real PDB structures and self-distilled pseudo-structures.
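The quality filter itself is conceptually tiny — keep only predictions confident enough to serve as pseudo-labels. A sketch (field names and function are illustrative, not the DeepMind pipeline):

```python
def filter_pseudo_labels(predictions, plddt_cutoff=70.0):
    # Keep predictions whose mean per-residue pLDDT clears the cutoff;
    # the survivors become pseudo-ground-truth for the second training stage.
    return [p for p in predictions
            if sum(p["plddt"]) / len(p["plddt"]) > plddt_cutoff]
```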

Self-distillation is worth roughly a point or two of GDT on CASP14 benchmarks — not huge, but not trivial. It’s also a lot of infrastructure: two-stage training, a prediction run at scale across a massive sequence database, and careful quality filtering of the pseudo-labels.

minAlphaFold leaves self-distillation as a deliberate gap — the repo’s README calls it out as a TODO. The practical implications: if you train minAlphaFold end-to-end on just the PDB using the configs provided, you can expect to get to within a few points of CASP14-level accuracy, but not all the way there. For a pedagogical walkthrough, this is the right call — self-distillation is orthogonal to the ideas that make AlphaFold2 work, and folding it in would add roughly 2,000 words of pipeline plumbing without new conceptual content. I’m noting it here so nobody thinks they’ve been handed a faithful reproduction of the full training recipe.

13.7 The training loop itself

Everything above lands in a single training loop:

def train_step(
    model: AlphaFold2,
    loss_fn: AlphaFoldLoss,
    optimizer: torch.optim.Optimizer,
    batch: dict[str, Any],
    training_config: TrainingConfig,
):

Pretty standard PyTorch mechanics — a forward pass through the model with the right n_cycles, an AlphaFoldLoss call on the outputs, .backward(), a gradient clipping step, an optimizer step. The learning rate schedule is warmup-plus-decay (linear warmup for the first ~1000 steps, constant or square-root decay after). Most of the interesting work is in the model and the loss; the training loop itself fits on one screen.
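The schedule shape in one function (a sketch of the recipe described above; the exact decay choice is configurable, and this helper is mine, not the repo's):

```python
def learning_rate(step: int, base_lr: float = 1e-3,
                  warmup_steps: int = 1000) -> float:
    # Linear warmup to base_lr, then constant. A square-root decay is a
    # drop-in alternative: base_lr * (warmup_steps / step) ** 0.5.
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    return base_lr
```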

With that, we have the full pipeline: input features, Evoformer, Structure Module, auxiliary heads, losses, and training procedure. The last two sections of the annotation make it concrete — §14 walks through a real overfit-single-PDB run with a structure viewer, and §15 points at where AF2 has since been extended.

14. A real-world walkthrough

We’ve built up enough machinery. Thirteen sections of architecture, loss, and training. So what does it look like when you actually run this thing?

The cleanest test — the one that convinces you the pipeline works end-to-end — is a single-protein overfit. Pick one PDB, feed it to the model as both input (sequence) and target (coordinates), and train until the model memorizes the structure. If the whole pipeline is coherent — features, Evoformer, Structure Module, loss, optimizer — the predicted coordinates should converge to the ground truth in under a minute on CPU. If anything is broken (a wrong sign in a gradient, a frame convention mismatch, a loss that’s computed in the wrong space), the overfit will fail. It’s a tiny test that exercises everything.

14.1 The setup

The protein I’m using is 1CRN — crambin, a 46-residue plant protein with three disulfide bonds. It’s small enough to train on a laptop CPU in under a minute, and interesting enough geometrically that a random-init model will get it embarrassingly wrong at step 0 and recognizably correct by step 600.

The driver script lives at scripts/overfit_single_pdb.py in the minAlphaFold repo. Simplified to the essentials:

parsed = parse_pdb("1crn.pdb")
example = build_minimal_example("1crn", parsed)
batch = collate_batch([example], crop_size=n_res, msa_depth=1,
                      extra_msa_depth=0, max_templates=0, ...)

model = AlphaFold2(load_model_config("tiny"))    # 0.09M params, 2 Evoformer blocks
loss_fn = AlphaFoldLoss(finetune=False)
optimizer = build_optimizer(model, TrainingConfig(learning_rate=1e-3, ...))

for step in range(1, 601):
    optimizer.zero_grad()
    outputs = model(**model_inputs_from_batch(batch, tc))
    loss = loss_fn(**loss_inputs_from_batch(batch, outputs)).mean()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
    optimizer.step()

A few things worth flagging about this setup. The MSA has exactly one sequence — the query itself. Templates are disabled. Extra MSA is disabled. We’re running the tiny profile (90k parameters, 2 Evoformer blocks, shallow Structure Module) — not full paper-spec AF2, which would be ~93M parameters. And we’re training with the same sequence and same target for every step, so there’s no generalization happening here; the model is literally memorizing one structure.

That’s what makes this a pipeline test, not a science test. If your 90k-parameter model running on a single protein for 600 steps converges to the right geometry, that doesn’t mean AlphaFold2 works — it means your implementation of the pipeline is coherent enough that gradient descent on FAPE can drive coordinates toward ground truth. That’s a low bar. But it’s a necessary bar, and it catches an enormous number of bugs. “Does my loss go down?” is easy. “Does the predicted structure start looking like the real one?” is the real question, and the overfit is the cheapest way I know to answer it.

14.2 What actually happens

The widget below loads 11 checkpoints from a real run — the same code as above, executed on CPU, producing a trajectory from random-init (step 0) to converged (step 600). Drag the scrubber, hit play, or click Reset. The gray cartoon is the ground truth. The colored (or solid red, depending on the toggle) cartoon is the prediction at whichever step you’ve parked on.


The story the scrubber tells:

  • Step 0–10: the prediction is random. Cα RMSD is ~9.7 Å — essentially saying “these residues could be anywhere.” The Structure Module has just produced 46 backbone frames by unrolling IPA from Gaussian-noise weights; there’s no reason they’d correspond to anything.
  • Step 10–100: the model starts to contract. FAPE’s gradient is pulling every residue into rough proximity of every other residue’s ground-truth view. Still very wrong, but the blob is shrinking.
  • Step 100–300: phase transition. Around step 145, the prediction starts to look like a protein — you can see secondary structure emerging in places. Cα RMSD drops below 10 Å consistently, then below 8 Å by step 300.
  • Step 300–600: convergence. The model is now learning the specific fold of 1CRN — the two helices, the small β-sheet, and the disulfide-bonded topology. By step 600, Cα RMSD has dropped to ~2.25 Å, which is close enough that the prediction and the ground truth visibly overlay.
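The Cα RMSD quoted at each step is an optimal-superposition RMSD. A compact Kabsch sketch (illustrative, not the repo's metrics code): center both point sets, solve for the best rotation via SVD, then measure the residual.

```python
import torch

def kabsch_rmsd(P: torch.Tensor, Q: torch.Tensor) -> torch.Tensor:
    # RMSD between two [N, 3] point sets after optimal rigid superposition.
    Pc, Qc = P - P.mean(0), Q - Q.mean(0)
    U, S, Vt = torch.linalg.svd(Pc.T @ Qc)
    d = torch.det(Vt.T @ U.T).sign().item()        # guard against reflections
    D = torch.diag(torch.tensor([1.0, 1.0, d]))
    R = Vt.T @ D @ U.T                             # optimal rotation P -> Q
    return ((Pc @ R.T - Qc) ** 2).sum(-1).mean().sqrt()
```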

Total wall time: roughly 30 seconds on a single CPU core. Total parameters: 90,000 — three orders of magnitude smaller than paper-spec AF2, which has 93 million. This is what the “tiny” profile is good for — a smoke test that exercises every code path in the model with minimal time and memory.

14.3 What this does and doesn’t tell us

The overfit says: the pipeline is coherent. Features get built correctly, the Evoformer maps them through, the Structure Module produces frames, FAPE computes a sensible gradient, the backbone update and all-atom reconstruction compose into plausible coordinates, and optimization drives all of that towards the ground truth. Every line of code from §§2–13 just got exercised in a single forward-backward loop, and the result is a structure that — by the end — looks like a real protein.

It does not say: AlphaFold2 generalizes. The model has seen exactly one protein. Step 600’s 2.25 Å Cα RMSD is memorization, not prediction. Real AF2 training runs on the entire Protein Data Bank (~150k structures), with deep MSAs, templates, self-distillation on millions of unlabeled sequences, and hundreds of GPU-days of compute — a world apart from a laptop CPU and one PDB.

But the overfit is still the most satisfying “does my code actually work” demo I know of, because it makes the progress of the training loop visible in the domain that matters. A loss curve going down could be literally anything; a structure morphing from noise into crambin is a loss curve going down because the model is learning geometry. You can’t fake it.

If you’d like to run it yourself, the command is:

# From the minAlphaFold2 repo root:
python scripts/overfit_single_pdb.py --pdb path/to/1crn.pdb --steps 600 --model-profile tiny

It’ll write its own PDBs and metrics to artifacts/overfit_single_pdb/1crn/. Watch the per-step log — seeing Cα RMSD drop in real time does the same job as the widget above, just without the 3D rendering.

15. What AlphaFold2 left unsolved

AlphaFold2 changed everything in structural biology, and like any result that changes everything, it also sharpened the shape of what comes next. The paper solved monomeric protein structure prediction for the large fraction of proteins with reasonable MSA coverage. That turned out to be a lot of proteins — the AlphaFold Protein Structure Database now covers nearly every sequenced protein in every major organism. But it left open a long list of things that real biology cares about:

Protein complexes. Proteins rarely act alone — they form complexes with other proteins, with DNA, with small molecules, and with membranes. AF2 was trained on monomers and struggles with multimeric structures. AlphaFold-Multimer (2021) extended the training and inference to complexes, and AlphaFold3 (2024) went further — a redesigned architecture, diffusion-based output, and a scope that includes nucleic acids, ligands, and post-translational modifications.

Dynamics. Proteins aren’t static. They breathe, they flex, they transition between conformations that are functionally distinct. AF2 predicts a single “most likely” structure, which is often enough for structural biology but not for enzyme mechanisms, allosteric regulation, or anything involving a motion-dependent function. Extensions like MSA-clustering methods (Wayment-Steele et al.) coax multi-state predictions out of AF2 by perturbing the input MSA, and diffusion-based successors like AlphaFold3 and Boltz-1 have native mechanisms for sampling conformational diversity.

Design. Structure prediction is the forward problem: sequence → structure. The inverse problem — “give me a sequence that folds into this structure” — is protein design, and it’s just as important scientifically but uses different tooling. RFdiffusion, ProteinMPNN, and (more recently) AlphaProteo all live in this inverse space, and they tend to use AF2-descended representations to evaluate their designs but not to generate them.

Language-model shortcuts. Perhaps the most interesting follow-up is the shift toward models that replace the MSA entirely. ESMFold (Meta, 2022) uses a pretrained protein language model in place of the MSA pipeline, arguing that a large enough transformer has already learned the co-evolutionary patterns the MSA was there to surface. ESM3 (EvolutionaryScale, 2024) goes further — a single multimodal model trained jointly on sequence, structure, and function, positioning itself as a general-purpose biological foundation model rather than a structure-prediction specialist. I’ve written a multi-part primer on protein language models that traces this arc from AF2 through ESM3 in much more detail.

Compute and inference cost. AF2’s training run was enormous, and its per-sequence inference cost isn’t cheap either. A lot of real-world work — drug discovery, enzyme engineering, synthetic biology — can’t afford 100 GPU-seconds per prediction. The successors trying to bring these costs down (ESMFold for MSA-free inference, distillation-based approaches, and the compute-optimized architectures in AlphaFold3) are responses to this constraint.

One of the more interesting meta-observations about this space is how many of the specific architectural choices in AF2 turned out to be load-bearing and how many turned out to be incidental. The Evoformer’s triangle updates, IPA, and FAPE are in almost every paper that followed, in roughly the same form. Recycling is universal. The specific MSA pipeline is already halfway obsolete. The 8-iteration Structure Module has been replaced in successors by denoising diffusion. The training recipe (self-distillation, MSA clustering) has been reworked.

If you read AF2 as the 2021 paper it is, you’ll notice design choices that look obvious in hindsight but were not at all obvious at the time. And you’ll notice choices that looked obvious at the time that turned out to be disposable. This is a healthy thing to keep in mind while reading any big result — not every component of a landmark paper deserves to be a permanent fixture of the field. The authors of the next AlphaFold-ish paper will know which is which.

— end —

If you’ve made it this far, thank you for reading! If you found this useful, I write regularly about AI, biology, and adjacent topics at chrishayduk.com. The annotation project continues — the next paper up is open, but ESM-2, AlphaFold3, and Boltz-1 are all on the shortlist. Suggestions welcome.