Attention dai principi fondamentali

Una derivazione compatta, a livello di shape, della scaled dot-product attention.

I transformer sono più facili da debuggare quando le shape dei tensori sono noiose. Parti da una sequenza di stati nascosti:

XRT×dmodelX \in \mathbb{R}^{T \times d_{model}}

Proiettala in query, key e value:

Q=XWQ,K=XWK,V=XWVQ = XW_Q,\quad K = XW_K,\quad V = XW_V

La matrice degli score

Gli score di attention sono similarità a coppie tra ogni query e ogni key:

S=QKdkS = \frac{QK^\top}{\sqrt{d_k}}

Se QRT×dkQ \in \mathbb{R}^{T \times d_k} e KRT×dkK \in \mathbb{R}^{T \times d_k}, allora:

SRT×TS \in \mathbb{R}^{T \times T}

Quella matrice quadrata è tutto il trucco. La riga ii dice da quali posizioni il token ii dovrebbe leggere.

Una versione PyTorch minimale

import torch

def attention(q, k, v, mask=None):
    scale = q.size(-1) ** -0.5
    scores = q @ k.transpose(-2, -1) * scale

    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))

    weights = scores.softmax(dim=-1)
    return weights @ v

Il bug comune

Il modo più facile per rompere attention è normalizzare sull'asse sbagliato:

# Sbagliato per la attention normale.
weights = scores.softmax(dim=-2)

# Corretto: ogni riga di query diventa una distribuzione sulle key.
weights = scores.softmax(dim=-1)

L'output mantiene la larghezza dei value:

Attention(Q,K,V)=softmax(QKdk)VRT×dv\operatorname{Attention}(Q,K,V) = \operatorname{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V \in \mathbb{R}^{T \times d_v}