Attention dai principi fondamentali
Una derivazione compatta, a livello di shape, della scaled dot-product attention.
I transformer sono più facili da debuggare quando le shape dei tensori sono noiose. Parti da una sequenza di stati nascosti:
Proiettala in query, key e value:
La matrice degli score
Gli score di attention sono similarità a coppie tra ogni query e ogni key:
Se e , allora:
Quella matrice quadrata è tutto il trucco. La riga dice da quali posizioni il token dovrebbe leggere.
Una versione PyTorch minimale
import torch
def attention(q, k, v, mask=None):
scale = q.size(-1) ** -0.5
scores = q @ k.transpose(-2, -1) * scale
if mask is not None:
scores = scores.masked_fill(mask == 0, float("-inf"))
weights = scores.softmax(dim=-1)
return weights @ vIl bug comune
Il modo più facile per rompere attention è normalizzare sull'asse sbagliato:
# Sbagliato per la attention normale.
weights = scores.softmax(dim=-2)
# Corretto: ogni riga di query diventa una distribuzione sulle key.
weights = scores.softmax(dim=-1)L'output mantiene la larghezza dei value: