Attention from first principles
A compact shape-level derivation of scaled dot-product attention.
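Transformers are easiest to debug when the tensor shapes are boring. Start with a sequence of $n$ hidden states, one row per token:

$$X \in \mathbb{R}^{n \times d_{\text{model}}}$$

Project it into queries, keys, and values:

$$Q = X W_Q, \qquad K = X W_K, \qquad V = X W_V,$$

where $W_Q, W_K \in \mathbb{R}^{d_{\text{model}} \times d_k}$ and $W_V \in \mathbb{R}^{d_{\text{model}} \times d_v}$, so $Q, K \in \mathbb{R}^{n \times d_k}$ and $V \in \mathbb{R}^{n \times d_v}$.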
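The sketch below shows the projection shapes in PyTorch. It assumes illustrative sizes ($n = 5$, $d_{\text{model}} = 16$, $d_k = 8$, $d_v = 12$). Bias-free `nn.Linear` layers (named `w_q`, `w_k`, `w_v` here for illustration) stand in for $W_Q$, $W_K$, $W_V$. Biases and multiple heads are left out because only the shapes matter here. `nn.Linear` stores its weight as `(out_features, in_features)` and computes `x @ weight.T`, so `w_q.weight.T` is the matrix that plays the role of $W_Q$.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumed for this sketch): n tokens, model width,
# query/key width, and value width.
n, d_model, d_k, d_v = 5, 16, 8, 12

# Hidden states X: one row per token, shape (n, d_model).
x = torch.randn(n, d_model)

# Bias-free linear layers stand in for W_Q, W_K, W_V.
# nn.Linear computes x @ weight.T, so weight.T plays the role of W_Q.
w_q = nn.Linear(d_model, d_k, bias=False)
w_k = nn.Linear(d_model, d_k, bias=False)
w_v = nn.Linear(d_model, d_v, bias=False)

q, k, v = w_q(x), w_k(x), w_v(x)
print(q.shape, k.shape, v.shape)
# torch.Size([5, 8]) torch.Size([5, 8]) torch.Size([5, 12])

# The layer is a plain matrix product with weight.T.
assert torch.allclose(q, x @ w_q.weight.T)
```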
The score matrix
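The attention scores are pairwise similarities between every query and every key, scaled by $1/\sqrt{d_k}$:

$$S = \frac{Q K^\top}{\sqrt{d_k}}$$

If $Q \in \mathbb{R}^{n \times d_k}$ and $K \in \mathbb{R}^{n \times d_k}$, then:

$$S \in \mathbb{R}^{n \times n}$$

That square matrix is the whole trick. Once a softmax normalizes row $i$, that row says which positions token $i$ should read from.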
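The next sketch makes the $n \times n$ shape and the meaning of a single entry concrete. It uses random $Q$ and $K$ with assumed sizes ($n = 5$, $d_k = 8$). Entry $(i, j)$ is the scaled dot product of query $i$ with key $j$. A softmax over row $i$ yields weights that sum to 1.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes (assumed): n tokens, query/key width d_k.
n, d_k = 5, 8
q = torch.randn(n, d_k)
k = torch.randn(n, d_k)

# S = Q K^T / sqrt(d_k): one scaled score per (query, key) pair.
scores = q @ k.transpose(-2, -1) / d_k ** 0.5
print(scores.shape)  # torch.Size([5, 5])

# Entry (i, j) is the scaled dot product of query i with key j.
i, j = 1, 3
manual = torch.dot(q[i], k[j]) / d_k ** 0.5
assert torch.allclose(scores[i, j], manual, atol=1e-6)

# A softmax over row i gives token i's read weights over all positions.
row = scores[i].softmax(dim=-1)
assert torch.isclose(row.sum(), torch.tensor(1.0))
```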
A minimal PyTorch version
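```python
import torch

def attention(q, k, v, mask=None):
    # q, k: (..., n, d_k); v: (..., n, d_v)
    scale = q.size(-1) ** -0.5
    # (..., n, n): one scaled score per query/key pair.
    scores = q @ k.transpose(-2, -1) * scale
    if mask is not None:
        # Where mask == 0, set the score to -inf so softmax gives it zero weight.
        scores = scores.masked_fill(mask == 0, float("-inf"))
    # Normalize over keys: each query row sums to 1.
    weights = scores.softmax(dim=-1)
    return weights @ v  # (..., n, d_v)
```

The common bug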
The easiest way to break attention is to normalize over the wrong axis:
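```python
# Wrong for normal attention: each key column sums to 1 instead.
weights = scores.softmax(dim=-2)

# Correct: each query row becomes a distribution over keys.
weights = scores.softmax(dim=-1)
```

The output keeps the value width:

$$\operatorname{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V \in \mathbb{R}^{n \times d_v}$$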
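This bug survives because both axes produce an $n \times n$ weight matrix and an output with the value width ($d_v$), so every shape check passes. The small sketch below uses assumed sizes, with $d_v \ne d_k$ on purpose so the value width is visible. It shows that the difference appears only in which axis sums to 1.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes (assumed); d_v deliberately differs from d_k.
n, d_k, d_v = 4, 8, 3
q = torch.randn(n, d_k)
k = torch.randn(n, d_k)
v = torch.randn(n, d_v)

scores = q @ k.transpose(-2, -1) * d_k ** -0.5  # (n, n)

right = scores.softmax(dim=-1)  # each query row sums to 1
wrong = scores.softmax(dim=-2)  # each key column sums to 1 instead

ones = torch.ones(n)
assert torch.allclose(right.sum(dim=-1), ones)
assert torch.allclose(wrong.sum(dim=-2), ones)

# Both versions have the same shapes, so a shape check cannot tell them apart.
assert right.shape == wrong.shape == (n, n)
assert (right @ v).shape == (wrong @ v).shape == (n, d_v)

# Row sums expose the bug: with dim=-2 they are generally not 1.
print(wrong.sum(dim=-1))
```

Comparing `weights.sum(dim=-1)` against ones is a cheap runtime guard that shape checks alone do not provide.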
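The following sketch exercises the `attention` function from the minimal PyTorch version end to end. It repeats the function so the block runs on its own, then applies a causal mask. The batch size, sequence length, width, and the lower-triangular mask are assumptions chosen for illustration. The final check compares the result against `torch.nn.functional.scaled_dot_product_attention` with `is_causal=True`, which assumes PyTorch 2.0 or newer.

```python
import torch
import torch.nn.functional as F


def attention(q, k, v, mask=None):
    # Same computation as the attention() function in this section.
    scale = q.size(-1) ** -0.5
    scores = q @ k.transpose(-2, -1) * scale
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = scores.softmax(dim=-1)
    return weights @ v


torch.manual_seed(0)

# Illustrative sizes (assumed): a batch of 2 sequences of 5 tokens.
batch, n, d = 2, 5, 8
q = torch.randn(batch, n, d)
k = torch.randn(batch, n, d)
v = torch.randn(batch, n, d)

# Causal mask: 1 where query i may read key j (j <= i), 0 elsewhere.
# Its (n, n) shape broadcasts over the batch dimension of the scores.
causal = torch.tril(torch.ones(n, n))

out = attention(q, k, v, mask=causal)
print(out.shape)  # torch.Size([2, 5, 8])

# The first token can only see itself, so its output is its own value row.
assert torch.allclose(out[:, 0], v[:, 0])

# Cross-check against PyTorch's built-in (assumes PyTorch 2.0 or newer).
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(out, ref, atol=1e-5)
```

One caveat of filling with `float("-inf")` concerns queries whose keys are all masked. Such a query's score row is entirely `-inf`, and `softmax` returns NaN for it. Padding masks should therefore leave every query at least one visible key.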