

Refactoring the Attention Equations: Patterns and Messages

I figure I’ll publish another post on this each day until it’s fully birthed. Up ahead: practical applications, illustrations, experimental results, and who knows what else!

The original formulation of multi-head attention from Vaswani et al. concatenates multiple attention heads before applying the output projection matrix $W^O$:

\begin{align} \text{MultiHead}(Q, K, V) &= \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O \\
\text{head}_i &= \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V) \\
\text{Attention}(Q, K, V) &= \text{softmax} \left(\frac{Q K^T}{\sqrt{d_k}}\right) V \end{align}
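
To keep things concrete, here's a minimal PyTorch sketch of that original formulation. The dimensions, seed, and variable names are my own toy choices, not anything from the paper:

```python
import torch

torch.manual_seed(0)

seq_len, d_model, h = 4, 8, 2
d_k = d_model // h                         # per-head dimension (d_v = d_k here)

X   = torch.randn(seq_len, d_model)        # token representations
W_Q = torch.randn(h, d_model, d_k)         # per-head query projections W_i^Q
W_K = torch.randn(h, d_model, d_k)         # per-head key projections   W_i^K
W_V = torch.randn(h, d_model, d_k)         # per-head value projections W_i^V
W_O = torch.randn(h * d_k, d_model)        # output projection W^O

def attention(Q, K, V):
    scores = torch.softmax(Q @ K.T / d_k**0.5, dim=-1)
    return scores @ V

# head_i = Attention(X W_i^Q, X W_i^K, X W_i^V)   (self-attention: Q = K = V = X)
heads = [attention(X @ W_Q[i], X @ W_K[i], X @ W_V[i]) for i in range(h)]

# MultiHead = Concat(head_1, ..., head_h) W^O
multihead = torch.cat(heads, dim=-1) @ W_O
print(multihead.shape)                     # torch.Size([4, 8])
```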

1. Splitting $W^O$ by Head

By splitting $W^O$ into row blocks $W^O_i$, one per head, we can replace the concatenation with a sum:

\[\text{MultiHead}(X) = \sum_{i=1}^{h} \text{head}_i W^O_i\]

Instead of concatenating the heads and applying one large output projection, we project each head separately with its own block $W^O_i$ and sum the results.

Writing $\text{MultiHead}(Q, K, V)$ indicates that these three attention inputs can come from different sources, as in cross-attention or multimodal attention. Here I’ll simply write $X$ as the single input to keep the notation light.
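
Here's a quick numerical check that summing the per-head projections matches the original concatenation. It's just a sketch: the head outputs are random stand-ins, and the row-block split of $W^O$ is the only new piece:

```python
import torch

torch.manual_seed(0)

seq_len, d_model, h = 4, 8, 2
d_k = d_model // h

# Stand-ins for the per-head attention outputs and the output projection.
heads = [torch.randn(seq_len, d_k) for _ in range(h)]
W_O = torch.randn(h * d_k, d_model)

# Original: concatenate the heads, then apply W^O.
concat_version = torch.cat(heads, dim=-1) @ W_O

# Refactored: split W^O into row blocks W^O_i and sum the per-head projections.
W_O_blocks = W_O.split(d_k, dim=0)          # each W^O_i is (d_k, d_model)
sum_version = sum(head_i @ W_O_i for head_i, W_O_i in zip(heads, W_O_blocks))

print(torch.allclose(concat_version, sum_version, atol=1e-5))  # True
```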

2. Separating the Value Projection from the Scores

Next, we replace the “Attention” function with just its “scores” and pull the value projection out into its own term:

\begin{align} \text{MultiHead}(X) &= \sum_{i=1}^{h} \text{scores}_i \cdot \text{values}_i \\
\text{scores}_i &= \text{softmax} \left(\frac{X W_i^Q (X W_i^K)^T}{\sqrt{d_k}}\right) \\
\text{values}_i &= X W_i^V W^O_i \end{align}
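
This move is just associativity: multiplying the scores by $X W_i^V$ and then by $W^O_i$ gives the same result as multiplying the scores by the pre-combined $\text{values}_i = X W_i^V W^O_i$. A small sketch to confirm, with random weights and my own toy sizes:

```python
import torch

torch.manual_seed(0)

seq_len, d_model, h = 4, 8, 2
d_k = d_model // h

X   = torch.randn(seq_len, d_model)
W_Q = torch.randn(h, d_model, d_k)
W_K = torch.randn(h, d_model, d_k)
W_V = torch.randn(h, d_model, d_k)
W_O_blocks = torch.randn(h, d_k, d_model)   # W^O already split into per-head blocks

def scores(i):
    # softmax(X W_i^Q (X W_i^K)^T / sqrt(d_k))
    return torch.softmax((X @ W_Q[i]) @ (X @ W_K[i]).T / d_k**0.5, dim=-1)

# Step 1 form: (scores_i @ X W_i^V) @ W^O_i, summed over heads.
v1 = sum(scores(i) @ (X @ W_V[i]) @ W_O_blocks[i] for i in range(h))

# Step 2 form: scores_i @ values_i, where values_i = X W_i^V W^O_i.
v2 = sum(scores(i) @ (X @ W_V[i] @ W_O_blocks[i]) for i in range(h))

print(torch.allclose(v1, v2, atol=1e-4))    # True
```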

3. Introducing Messages $M$

So far, we have treated $W^V_i$ and $W^O_i$ as separate transformations. However, their product can be seen as a low-rank decomposition of a single, larger transformation. Defining one matrix $W^M_i = W^V_i W^O_i$ per head, we rewrite:

\[M_i = X W^M_i\]

Each row in $M_i$ represents a message passed between tokens in the attention mechanism. Attention scores determine how much influence each message has on a given token.

\begin{align} \text{MultiHead}(X) &= \sum_{i=1}^{h} \text{scores}_i \cdot \text{messages}_i \\
\text{scores}_i &= \text{softmax} \left(\frac{X W_i^Q (X W_i^K)^T}{\sqrt{d_k}}\right) \\
\text{messages}_i &= X W^M_i \end{align}
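
Here's a sketch of the message construction, assuming $W^M_i = W^V_i W^O_i$ as defined above (toy dimensions of my own choosing). Note that $W^M_i$ is a full $d_{\text{model}} \times d_{\text{model}}$ matrix whose rank is at most $d_k$:

```python
import torch

torch.manual_seed(0)

seq_len, d_model, h = 4, 8, 2
d_k = d_model // h

X   = torch.randn(seq_len, d_model)
W_V = torch.randn(h, d_model, d_k)
W_O_blocks = torch.randn(h, d_k, d_model)

# W^M_i = W_i^V W^O_i: a single d_model x d_model map of rank at most d_k.
W_M = torch.stack([W_V[i] @ W_O_blocks[i] for i in range(h)])

i = 0
M_i = X @ W_M[i]                             # messages for head i, in model space
print(M_i.shape)                             # torch.Size([4, 8])
print(torch.linalg.matrix_rank(W_M[i]))      # tensor(4) -- rank <= d_k
```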

4. Introducing Patterns $P$

Instead of separately computing queries and keys, we can merge their weight matrices into a single matrix $W^P_i$:

\[W^P_i = W^Q_i (W^K_i)^T\]

To compute attention scores we first compute the attention head patterns:

\[P_i = X W^P_i\]

Now, attention logits can be rewritten as:

\[\text{logits}_i = P_i X^T\]

Each row of $P_i$ represents a pattern that the attention head is searching for in the sequence.

This reformulation keeps the computation in model space: each row of $P_i$ is a $d_{\text{model}}$-dimensional vector that can be compared directly against the token representations in $X$, which makes the heads easier to interpret.
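
And a small check that the merged pattern matrix reproduces the original query-key logits (again a sketch with random weights and toy sizes):

```python
import torch

torch.manual_seed(0)

seq_len, d_model, h = 4, 8, 2
d_k = d_model // h

X   = torch.randn(seq_len, d_model)
W_Q = torch.randn(h, d_model, d_k)
W_K = torch.randn(h, d_model, d_k)

i = 0
# Merged pattern matrix: W^P_i = W_i^Q (W_i^K)^T, a d_model x d_model map.
W_P_i = W_Q[i] @ W_K[i].T

# Each row of P_i is the pattern head i is looking for at that position.
P_i = X @ W_P_i                              # (seq_len, d_model)

# The logits match the original query-key formulation.
logits_merged   = P_i @ X.T
logits_original = (X @ W_Q[i]) @ (X @ W_K[i]).T
print(torch.allclose(logits_merged, logits_original, atol=1e-4))  # True
```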

5. Updated Multi-Head Attention Formulation

\begin{align} \text{MultiHead}(X) &= \sum_{i=1}^{h} \text{scores}_i \cdot \text{messages}_i \\
\text{patterns}_i &= X W^P_i \\
\text{scores}_i &= \text{softmax} \left(\frac{\text{patterns}_i \, X^T}{\sqrt{d_k}} \right) \\
\text{messages}_i &= X W^M_i \end{align}

Or, written more compactly with the symbols $P_i$ and $M_i$:

\begin{align} \text{MultiHead}(X) &= \sum_{i=1}^{h} \text{scores}_i \cdot M_i \\
P_i &= X W^P_i \\
\text{scores}_i &= \text{softmax} \left(\frac{P_i X^T}{\sqrt{d_k}} \right) \\
M_i &= X W^M_i \end{align}
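
Finally, here's a sketch that checks the full patterns-and-messages formulation against the original concatenation-based multi-head attention (toy sizes and random weights of my own choosing):

```python
import torch

torch.manual_seed(0)

seq_len, d_model, h = 4, 8, 2
d_k = d_model // h

X   = torch.randn(seq_len, d_model)
W_Q = torch.randn(h, d_model, d_k)
W_K = torch.randn(h, d_model, d_k)
W_V = torch.randn(h, d_model, d_k)
W_O = torch.randn(h * d_k, d_model)
W_O_blocks = W_O.split(d_k, dim=0)           # the row blocks W^O_i

# Original formulation: concatenate the heads, then apply W^O.
def original_multihead():
    heads = []
    for i in range(h):
        s = torch.softmax((X @ W_Q[i]) @ (X @ W_K[i]).T / d_k**0.5, dim=-1)
        heads.append(s @ (X @ W_V[i]))
    return torch.cat(heads, dim=-1) @ W_O

# Patterns-and-messages formulation: sum over heads of scores_i @ M_i.
def refactored_multihead():
    out = torch.zeros(seq_len, d_model)
    for i in range(h):
        P_i = X @ (W_Q[i] @ W_K[i].T)                    # patterns
        s   = torch.softmax(P_i @ X.T / d_k**0.5, dim=-1)
        M_i = X @ (W_V[i] @ W_O_blocks[i])               # messages
        out = out + s @ M_i
    return out

print(torch.allclose(original_multihead(), refactored_multihead(), atol=1e-4))  # True
```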