Attention from first principles

A compact shape-level derivation of scaled dot-product attention.

Transformers are easiest to debug when the tensor shapes are boring. Start with a sequence of hidden states:

$$X \in \mathbb{R}^{T \times d_{model}}$$

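Project it into queries, keys, and values with learned matrices $W_Q, W_K \in \mathbb{R}^{d_{model} \times d_k}$ and $W_V \in \mathbb{R}^{d_{model} \times d_v}$:

$$Q = XW_Q,\quad K = XW_K,\quad V = XW_V$$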
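A minimal sketch of the projections in PyTorch. The sizes ($T = 5$, $d_{model} = 16$, $d_k = 8$, $d_v = 12$) are arbitrary assumptions, and bias-free nn.Linear layers stand in for $W_Q$, $W_K$, $W_V$. Note that nn.Linear stores its weight as (out_features, in_features) and computes $XW^\top$, so its transposed weight plays the role of $W$ here.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Illustrative sizes (assumptions, not taken from any particular model).
T, d_model, d_k, d_v = 5, 16, 8, 12

X = torch.randn(T, d_model)  # one sequence of T hidden states

# Bias-free linear layers act as the projection matrices W_Q, W_K, W_V.
proj_q = nn.Linear(d_model, d_k, bias=False)
proj_k = nn.Linear(d_model, d_k, bias=False)
proj_v = nn.Linear(d_model, d_v, bias=False)

Q, K, V = proj_q(X), proj_k(X), proj_v(X)

print(Q.shape, K.shape, V.shape)
# torch.Size([5, 8]) torch.Size([5, 8]) torch.Size([5, 12])

# nn.Linear computes X @ weight.T, so weight.T is the d_model x d_k matrix W_Q.
W_Q = proj_q.weight.T
assert torch.allclose(Q, X @ W_Q)
```

Queries and keys must share the width $d_k$ so they can be dotted together; values only need to share the length $T$.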

The score matrix

The attention scores are pairwise similarities between every query and every key:

$$S = \frac{QK^\top}{\sqrt{d_k}}$$
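Why divide by $\sqrt{d_k}$? If the components of a query and a key are independent with mean 0 and variance 1 (an idealized assumption that trained weights do not guarantee), their dot product has variance $d_k$, so unscaled scores grow with the head width and push the softmax toward one-hot rows. A quick empirical check under that assumption:

```python
import torch

torch.manual_seed(0)

d_k = 64
n = 10_000  # number of random (query, key) pairs

# Assumed: i.i.d. standard-normal components, as in the idealized argument.
q = torch.randn(n, d_k)
k = torch.randn(n, d_k)

dots = (q * k).sum(dim=-1)   # one raw dot product per pair
scaled = dots / d_k ** 0.5   # the same scores divided by sqrt(d_k)

print(f"var(q.k)             = {dots.var().item():.1f}")    # close to d_k = 64
print(f"var(q.k / sqrt(d_k)) = {scaled.var().item():.2f}")  # close to 1
```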

If $Q \in \mathbb{R}^{T \times d_k}$ and $K \in \mathbb{R}^{T \times d_k}$, then:

$$S \in \mathbb{R}^{T \times T}$$

That square matrix is the whole trick. Row $i$ holds the scores of query $i$ against every key; after the softmax, it says how much token $i$ reads from each position.
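A small shape check for the score matrix, assuming random $Q$ and $K$ with illustrative sizes $T = 4$ and $d_k = 8$. Entry $S_{ij}$ is query $i$ dotted with key $j$, then scaled, so row $i$ collects token $i$'s scores against every position:

```python
import torch

torch.manual_seed(0)

T, d_k = 4, 8  # illustrative sizes
Q = torch.randn(T, d_k)
K = torch.randn(T, d_k)

S = Q @ K.T / d_k ** 0.5
print(S.shape)  # torch.Size([4, 4])

# Entry (i, j): scaled similarity of query i with key j.
i, j = 1, 3
s_ij = torch.dot(Q[i], K[j]) / d_k ** 0.5
assert torch.allclose(S[i, j], s_ij)

# Row i holds token i's scores against all T keys.
print(S[i].shape)  # torch.Size([4])
```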

A minimal PyTorch version

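import torch

def attention(q, k, v, mask=None):
    # q, k: (..., T, d_k); v: (..., T, d_v)
    scale = q.size(-1) ** -0.5
    # (..., T, T): one row of scores per query, one column per key.
    scores = q @ k.transpose(-2, -1) * scale

    if mask is not None:
        # Positions where mask == 0 get zero weight after the softmax.
        scores = scores.masked_fill(mask == 0, float("-inf"))

    # Normalize over keys, so each query row sums to 1.
    weights = scores.softmax(dim=-1)
    # (..., T, d_v): each output row is a weighted average of value rows.
    return weights @ v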
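A usage sketch, assuming a causal (lower-triangular) mask where 1 marks an allowed position and a batch dimension of 1. The function is repeated so the block runs on its own. As a cross-check it compares against torch.nn.functional.scaled_dot_product_attention, which exists in PyTorch 2.0 and later; with a boolean attn_mask, True there means the position may be attended to.

```python
import torch
import torch.nn.functional as F

def attention(q, k, v, mask=None):
    scale = q.size(-1) ** -0.5
    scores = q @ k.transpose(-2, -1) * scale
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = scores.softmax(dim=-1)
    return weights @ v

torch.manual_seed(0)
T, d_k, d_v = 4, 8, 6  # illustrative sizes
q = torch.randn(1, T, d_k)
k = torch.randn(1, T, d_k)
v = torch.randn(1, T, d_v)

# Causal mask: token i may read positions 0..i. Broadcasts over the batch.
causal = torch.tril(torch.ones(T, T))

out = attention(q, k, v, mask=causal)
print(out.shape)  # torch.Size([1, 4, 6])

# Token 0 can only see itself, so its output is exactly v[0, 0].
assert torch.allclose(out[0, 0], v[0, 0])

# Cross-check against PyTorch's built-in (assumes PyTorch >= 2.0).
ref = F.scaled_dot_product_attention(q, k, v, attn_mask=causal.bool())
assert torch.allclose(out, ref, atol=1e-5)
```

One consequence of filling with -inf: if a whole row of the mask is zero, every score in that row is -inf and the softmax returns NaN, so each query needs at least one visible key.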

The common bug

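The easiest way to break attention is to normalize over the wrong axis. Softmax over dim=-2 turns each key column into a distribution over queries instead of turning each query row into a distribution over keys, and because the shape is unchanged, nothing raises an error: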

# Wrong: each key column becomes a distribution over queries.
weights = scores.softmax(dim=-2)

# Correct: each query row becomes a distribution over keys.
weights = scores.softmax(dim=-1)
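A small sketch of why the wrong axis is easy to miss: both calls return tensors of the same shape, so the code runs either way. Using a fixed $3 \times 3$ score matrix with arbitrary values, only dim=-1 makes every row sum to 1:

```python
import torch

scores = torch.arange(9.0).reshape(3, 3)  # arbitrary fixed scores

right = scores.softmax(dim=-1)  # each row: distribution over keys
wrong = scores.softmax(dim=-2)  # each column: distribution over queries

assert right.shape == wrong.shape  # same shape, so no error is raised

print(right.sum(dim=-1))  # every row sums to 1
print(wrong.sum(dim=-1))  # rows do not sum to 1
print(wrong.sum(dim=-2))  # columns do
```

Checking that the weights sum to 1 along the last axis is a cheap assertion to keep in a debug build.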

The output keeps the value width:

$$\operatorname{Attention}(Q,K,V) = \operatorname{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V \in \mathbb{R}^{T \times d_v}$$
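The shapes compose step by step, with $Q, K \in \mathbb{R}^{T \times d_k}$ and $V \in \mathbb{R}^{T \times d_v}$; the softmax acts row-wise and leaves the $T \times T$ shape unchanged:

```latex
\underbrace{S}_{T \times T} = \frac{1}{\sqrt{d_k}}\,\underbrace{Q}_{T \times d_k}\,\underbrace{K^\top}_{d_k \times T},
\qquad
\underbrace{\operatorname{softmax}(S)}_{T \times T}\,\underbrace{V}_{T \times d_v} \in \mathbb{R}^{T \times d_v}
```

The first product contracts away $d_k$ and the second contracts away the number of keys $T$, which is why $d_v$ is free to differ from $d_k$.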