The Attention Mechanism
If matrix multiplication is the workhorse of transformer inference, attention is the brain. It is the mechanism that allows each token to look at every other token in the sequence and decide what information is relevant. Without attention, a transformer is just a stack of feed-forward networks – powerful but unable to model relationships between positions in a sequence.
In this chapter, we are going to take attention apart piece by piece. We will start with the mathematical formulation, walk through a complete numerical example, explore multi-head and grouped-query variants, understand the KV cache, and analyze the computational complexity. By the end, you will understand every detail of how attention works, which will set us up perfectly for the FlashAttention optimization in the next chapter.
The Core Idea: Queries, Keys, and Values
Attention is fundamentally a lookup mechanism. Think of it like a fuzzy dictionary lookup: you have a query (“what am I looking for?”), a set of keys (“what is available?”), and corresponding values (“what information do those keys hold?”). The twist is that instead of an exact match, you compute a similarity score between the query and every key, then return a weighted combination of the values.
Attention as Fuzzy Dictionary Lookup
=====================================
Query: "What comes after 'the'?"
Keys: Values: Similarity:
---- ------ -----------
"the" ---> [0.2, 0.8, ...] 0.85 (high -- relevant!)
"cat" ---> [0.5, 0.1, ...] 0.72 (medium)
"sat" ---> [0.3, 0.6, ...] 0.15 (low)
"on" ---> [0.1, 0.4, ...] 0.62 (medium)
Output = w("the") * val("the") + w("cat") * val("cat") + w("sat") * val("sat") + w("on") * val("on")
where w = softmax([0.85, 0.72, 0.15, 0.62]), so the weights are positive and sum to 1
But where do queries, keys, and values come from? They are all produced by linear projections of the input embeddings.
Self-Attention: Q, K, V Projections
Given an input sequence of token embeddings X with shape [seq_len, d_model], we
produce Q, K, and V through three separate linear projections:
Q = X * W_Q where W_Q has shape [d_model, d_k]
K = X * W_K where W_K has shape [d_model, d_k]
V = X * W_V where W_V has shape [d_model, d_v]
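In the original transformer paper and most modern LLMs, the full projections have width
d_model (so, viewed as a single head, d_k = d_v = d_model); multi-head attention, described
below, splits that width into h heads of d_model / h dimensions each. Each projection is a
matrix multiplication, exactly the kind we optimized in the previous chapter.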
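The three projections are ordinary matrix multiplications. Below is a minimal pure-Python sketch in which matrices are lists of rows. The sizes and weight values are arbitrary placeholders, not taken from any model. W_V is set to the identity so the result is easy to check by eye.

```python
def matmul(A, B):
    """Multiply an [m x k] matrix by a [k x n] matrix (lists of rows)."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def shape(M):
    return (len(M), len(M[0]))

# Placeholder sizes: seq_len = 3, d_model = d_k = d_v = 4.
X = [[1.0, 0.0, 2.0, 1.0],
     [0.0, 1.0, 1.0, 0.0],
     [1.0, 1.0, 0.0, 2.0]]                                   # [seq_len x d_model]

W_Q = [[0.1 * (i + j) for j in range(4)] for i in range(4)]  # [d_model x d_k]
W_K = [[0.1 * (i - j) for j in range(4)] for i in range(4)]  # [d_model x d_k]
W_V = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]  # identity

Q = matmul(X, W_Q)   # [seq_len x d_k]
K = matmul(X, W_K)   # [seq_len x d_k]
V = matmul(X, W_V)   # [seq_len x d_v]; identity weights reproduce X

print(shape(Q), shape(K), shape(V))   # (3, 4) (3, 4) (3, 4)
```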
Self-Attention Projections
===========================
Input X: [seq_len x d_model]
X ----+-----> [W_Q] -----> Q [seq_len x d_k]
|
+-----> [W_K] -----> K [seq_len x d_k]
|
+-----> [W_V] -----> V [seq_len x d_v]
"Self" attention means Q, K, V all come from the SAME input X.
This is in contrast to cross-attention where Q comes from one
source and K, V come from another (used in encoder-decoder models).
The Attention Equation
Once we have Q, K, and V, the attention computation is:
Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V
Let us break this down step by step.
Step 1: Compute Attention Scores (Q * K^T)
We multiply Q by the transpose of K. If Q has shape [seq_len, d_k] and K has shape
[seq_len, d_k], then Q * K^T has shape [seq_len, seq_len]. Each element (i, j) is
the dot product of query vector i with key vector j – a measure of how much token i
should attend to token j.
Q * K^T --> Attention Score Matrix
======================================
Q = [ q0 ] K^T = [ k0^T k1^T k2^T k3^T ]
[ q1 ]
[ q2 ] (each qi, kj is a vector of dimension d_k)
[ q3 ]
Q * K^T = [ q0.k0 q0.k1 q0.k2 q0.k3 ]
[ q1.k0 q1.k1 q1.k2 q1.k3 ]
[ q2.k0 q2.k1 q2.k2 q2.k3 ]
[ q3.k0 q3.k1 q3.k2 q3.k3 ]
Element (i,j) = dot product of query i with key j
= how much should token i attend to token j?
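Each element (i, j) is just a dot product, so the score matrix can be written straight from that definition. The small Q and K below are made-up values chosen so the result can be checked by hand.

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def score_matrix(Q, K):
    """S = Q * K^T: S[i][j] = dot(q_i, k_j), shape [len(Q) x len(K)]."""
    return [[dot(q, k) for k in K] for q in Q]

Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # 3 queries, d_k = 2
K = [[1.0, 2.0], [3.0, 0.0], [0.0, 1.0]]   # 3 keys,    d_k = 2
print(score_matrix(Q, K))   # [[1.0, 3.0, 0.0], [2.0, 0.0, 1.0], [3.0, 3.0, 1.0]]
```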
Step 2: Scale by sqrt(d_k)
We divide every score by sqrt(d_k). Why? Suppose the components of q and k are independent,
with zero mean and unit variance. Their dot product is then a sum of d_k terms, each with
variance 1, so its variance is d_k and its typical magnitude (standard deviation) grows as
sqrt(d_k). With d_k = 64, scores of +/-8 or more are common. Gaps of that size make softmax
extremely peaked: almost all the weight goes to one key, and the gradients vanish. Dividing
by sqrt(64) = 8 brings the variance back to about 1.
Why Scale?
==========
Without scaling (d_k = 64):
Typical dot product magnitude: ~sqrt(64) = ~8
softmax([16, -8, 4, 8]) = [1.000, 0.000, 0.000, 0.000]
--> Almost a hard lookup. Gradients vanish.
With scaling:
Scaled scores: [16, -8, 4, 8] / 8 = [2.0, -1.0, 0.5, 1.0]
softmax([2.0, -1.0, 0.5, 1.0]) = [0.609, 0.030, 0.136, 0.224]
--> Still favors key 0, but the distribution is soft and gradients can flow.
The scaling keeps the variance of the scores at ~1.0
regardless of dimension, assuming q and k have unit-variance components.
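The variance argument can be checked empirically. The sketch below draws q and k with independent standard-normal entries, matching the assumption stated above, using a fixed seed. It estimates the variance of q.k before and after dividing by sqrt(d_k). The sample count is an arbitrary choice, so the printed values are estimates, not exact results.

```python
import random

def dot_product_variance(d_k, n_samples=2000, seed=0):
    """Sample variance of q.k where q and k have i.i.d. N(0, 1) entries."""
    rng = random.Random(seed)
    samples = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        samples.append(sum(a * b for a, b in zip(q, k)))
    mean = sum(samples) / n_samples
    return sum((s - mean) ** 2 for s in samples) / n_samples

for d_k in (16, 64, 256):
    var = dot_product_variance(d_k)
    # Dividing every score by sqrt(d_k) divides the variance by d_k.
    print(f"d_k={d_k:3d}  var(q.k) ~ {var:6.1f}  var(q.k / sqrt(d_k)) ~ {var / d_k:.2f}")
```

The unscaled variance tracks d_k, while the scaled variance stays near 1 for every dimension.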
Step 3: Apply Softmax
The softmax function converts the scaled scores into a probability distribution over keys for each query position. Each row of the score matrix is independently softmaxed:
softmax(z_i) = exp(z_i) / sum_j(exp(z_j))
After softmax, each row sums to 1.0, and all values are between 0 and 1. These are the attention weights.
Softmax: Scores --> Weights
============================
Before softmax (one row, after scaling):
[ 2.1, 0.5, -0.3, 1.2 ]
exp each:
[ 8.17, 1.65, 0.74, 3.32 ]
sum = 13.88
Divide by sum:
[ 0.589, 0.119, 0.053, 0.239 ]
Token 0 gets the most attention (0.589); token 3 gets the second most (0.239).
Sum = 1.0 (it is a probability distribution)
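Here is a direct translation of the softmax formula, shown as a sketch. It subtracts the row maximum before exponentiating. This standard trick is not part of the formula above, but it leaves the result unchanged, because the factor exp(-max) cancels in the ratio, and it keeps exp() from overflowing for large scores.

```python
import math

def softmax(z):
    """Numerically stable softmax of a list of scores."""
    m = max(z)                           # exp(z_i - m) <= 1, so no overflow
    exps = [math.exp(x - m) for x in z]
    total = sum(exps)
    return [e / total for e in exps]

w = softmax([2.1, 0.5, -0.3, 1.2])
print([round(x, 3) for x in w])          # [0.589, 0.119, 0.053, 0.239]
print(softmax([1000.0, 1000.0]))         # [0.5, 0.5]; a naive math.exp(1000) would overflow
```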
Step 4: Weighted Sum of Values
Finally, we multiply the attention weights by V. Each output vector is a weighted combination of value vectors, where the weights come from softmax:
Weighted Value Sum
===================
Attention weights (one row): [0.589, 0.119, 0.053, 0.239]
V = [ v0 ] = [ 0.1, 0.9, 0.3 ]
[ v1 ] [ 0.5, 0.2, 0.8 ]
[ v2 ] [ 0.7, 0.4, 0.1 ]
[ v3 ] [ 0.3, 0.6, 0.5 ]
output = 0.589 * v0 + 0.119 * v1 + 0.053 * v2 + 0.239 * v3
= 0.589 * [0.1, 0.9, 0.3]
+ 0.119 * [0.5, 0.2, 0.8]
+ 0.053 * [0.7, 0.4, 0.1]
+ 0.239 * [0.3, 0.6, 0.5]
= [0.227, 0.718, 0.397]
The output is dominated by v0 (weight 0.589) and v3 (weight 0.239).
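The same weighted sum can be written in code as a minimal sketch. Each output component is the weight row dotted with one column of V. Because it uses the rounded weights from the example, it prints four-decimal values that agree with the hand computation up to rounding.

```python
def weighted_sum(weights, V):
    """One output row of Weights * V: sum_j weights[j] * V[j]."""
    d_v = len(V[0])
    return [sum(w * v_row[c] for w, v_row in zip(weights, V)) for c in range(d_v)]

weights = [0.589, 0.119, 0.053, 0.239]
V = [[0.1, 0.9, 0.3],   # v0
     [0.5, 0.2, 0.8],   # v1
     [0.7, 0.4, 0.1],   # v2
     [0.3, 0.6, 0.5]]   # v3
out = weighted_sum(weights, V)
print([f"{x:.4f}" for x in out])   # ['0.2272', '0.7185', '0.3967']
```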
The Full Pipeline
Complete Attention Computation
===============================
Input: X [seq_len x d_model]
|
+--[W_Q]--> Q [seq_len x d_k]
+--[W_K]--> K [seq_len x d_k]
+--[W_V]--> V [seq_len x d_v]
|
v
Scores = Q * K^T [seq_len x seq_len]
|
v
Scaled = Scores / sqrt(d_k) [seq_len x seq_len]
|
v
Masked = apply causal mask [seq_len x seq_len]
|
v
Weights = softmax(Masked) [seq_len x seq_len]
|
v
Output = Weights * V [seq_len x d_v]
|
v
Projected = Output * W_O [seq_len x d_model]
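Strung together, the pipeline fits in a few lines. The sketch below is single-head, unbatched, pure Python; the name single_head_attention is made up for illustration. It applies an optional causal mask before softmax, as in the diagram. Identity weights are used only so the result is easy to verify.

```python
import math

def matmul(A, B):
    """[m x k] * [k x n] -> [m x n], matrices as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(M):
    return [list(col) for col in zip(*M)]

def softmax(z):
    m = max(z)                            # finite while a row keeps one unmasked entry
    exps = [math.exp(x - m) for x in z]   # math.exp(-inf) == 0.0 for masked entries
    total = sum(exps)
    return [e / total for e in exps]

def single_head_attention(X, W_Q, W_K, W_V, W_O, causal=True):
    Q, K, V = matmul(X, W_Q), matmul(X, W_K), matmul(X, W_V)
    d_k = len(W_Q[0])
    scores = matmul(Q, transpose(K))                        # [seq_len x seq_len]
    scaled = [[s / math.sqrt(d_k) for s in row] for row in scores]
    if causal:                                              # token i sees 0..i
        scaled = [[s if j <= i else -math.inf for j, s in enumerate(row)]
                  for i, row in enumerate(scaled)]
    weights = [softmax(row) for row in scaled]              # each row sums to 1
    output = matmul(weights, V)                             # [seq_len x d_v]
    return matmul(output, W_O)                              # [seq_len x d_model]

I2 = [[1.0, 0.0], [0.0, 1.0]]
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
Y = single_head_attention(X, I2, I2, I2, I2)
print(Y[0])   # [1.0, 0.0]: token 0 can only attend to itself
```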
Multi-Head Attention: Divide and Conquer
In practice, we do not compute a single attention function. Instead, we split the representation into multiple heads, each operating on a smaller dimension. This allows the model to attend to information from different representation subspaces simultaneously.
If d_model = 512 and we use h = 8 heads, then each head operates on
d_k = d_model / h = 64 dimensions.
Multi-Head Attention
=====================
d_model = 512, h = 8 heads, d_k = 64 per head
Input X: [seq_len x 512]
Full Q projection: X * W_Q --> Q_full [seq_len x 512]
Split into 8 heads:
Q_full = [ Q_head0 | Q_head1 | Q_head2 | ... | Q_head7 ]
[s x 64] [s x 64] [s x 64] [s x 64]
Similarly for K and V.
Each head computes attention independently:
head_i = Attention(Q_head_i, K_head_i, V_head_i)
Concatenate results:
MultiHead = [ head_0 | head_1 | ... | head_7 ] [seq_len x 512]
Final projection:
Output = MultiHead * W_O [seq_len x 512]
Why multiple heads? Different heads can learn different types of relationships:
What Different Heads Might Learn
=================================
Head 0: Syntactic dependencies ("the" attends to its noun)
Head 1: Positional patterns (attend to previous token)
Head 2: Semantic similarity (similar words attend to each other)
Head 3: Coreference ("it" attends to its referent)
Head 4: Long-range dependencies (closing bracket attends to opening)
...
Each head has its own Q, K, V projections, so each head
learns to look for different patterns in the data.
In terms of implementation, multi-head attention is usually done by performing the full
projection to get [seq_len, d_model] shaped Q, K, V, then reshaping to
[seq_len, n_heads, d_k] (or equivalently [n_heads, seq_len, d_k] after transposing).
The attention is then computed in parallel across all heads – on a GPU, this is a batched
operation.
Implementation: Reshape and Batch
==================================
After Q projection: Q [seq_len x d_model]
Reshape to: Q [seq_len x n_heads x d_k]
Transpose: Q [n_heads x seq_len x d_k]
Now attention is a batched operation:
For each head h in [0, n_heads):
scores_h = Q[h] * K[h]^T / sqrt(d_k) [seq_len x seq_len]
weights_h = softmax(scores_h) [seq_len x seq_len]
output_h = weights_h * V[h] [seq_len x d_v]
This is embarrassingly parallel -- all heads computed simultaneously.
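In list form, the reshape-and-transpose amounts to slicing each row into n_heads contiguous chunks of d_k columns; concatenation is the inverse. This is a sketch with hypothetical helper names. Real frameworks typically express the same thing as a reshape plus a transpose on a tensor rather than copying lists.

```python
def split_heads(M, n_heads):
    """[seq_len x d_model] -> [n_heads x seq_len x d_k], d_k = d_model // n_heads.
    Same result as reshaping to [seq_len, n_heads, d_k] and swapping the first two axes."""
    d_model = len(M[0])
    if d_model % n_heads != 0:
        raise ValueError("d_model must be divisible by n_heads")
    d_k = d_model // n_heads
    return [[row[h * d_k:(h + 1) * d_k] for row in M] for h in range(n_heads)]

def merge_heads(H):
    """[n_heads x seq_len x d_k] -> [seq_len x n_heads * d_k] (concatenate heads)."""
    n_heads, seq_len = len(H), len(H[0])
    return [[x for h in range(n_heads) for x in H[h][i]] for i in range(seq_len)]

# seq_len = 3, d_model = 8; entry (i, c) = 10*i + c makes the layout easy to read
Q_full = [[float(10 * i + c) for c in range(8)] for i in range(3)]
heads = split_heads(Q_full, n_heads=2)   # 2 heads, each [3 x 4]
print(heads[1][0])                       # [4.0, 5.0, 6.0, 7.0]
print(merge_heads(heads) == Q_full)      # True
```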
Grouped Query Attention (GQA)
Standard multi-head attention requires storing separate K and V for every head. For a
model with 32 heads generating 4096-length sequences with d_k = 128, that is:
KV cache per layer = 2 (K and V) * 32 heads * 4096 tokens * 128 dims * 2 bytes (FP16) = 64 MB
For 32 layers: 2 GB just for the KV cache!
Grouped Query Attention (GQA) reduces this by sharing K and V across groups of query heads. Instead of 32 KV heads, you might use only 8. Each KV head serves a group of 4 query heads.
Multi-Head vs Grouped Query vs Multi-Query
============================================
Multi-Head Attention (MHA): n_kv_heads = n_q_heads = 32
-------------------------------------------------------
Q heads: Q0 Q1 Q2 Q3 Q4 Q5 ... Q31
K heads: K0 K1 K2 K3 K4 K5 ... K31
V heads: V0 V1 V2 V3 V4 V5 ... V31
(Each Q head has its own K,V pair)
Grouped Query Attention (GQA): n_kv_heads = 8, group_size = 4
---------------------------------------------------------------
Q heads: Q0 Q1 Q2 Q3 | Q4 Q5 Q6 Q7 | Q8 ... | Q28 Q29 Q30 Q31
K heads:     K0      |     K1      | K2 ... |       K7
V heads:     V0      |     V1      | V2 ... |       V7
(4 Q heads share each K,V pair)
Multi-Query Attention (MQA): n_kv_heads = 1
---------------------------------------------
Q heads: Q0 Q1 Q2 Q3 Q4 Q5 ... Q31
K heads: K0
V heads: V0
(All Q heads share a single K,V pair)
The memory savings are substantial:
KV Cache Size Comparison (per layer)
======================================
seq_len = 4096, d_k = 128, FP16
MHA (32 KV heads): 2 * 32 * 4096 * 128 * 2 = 64 MB
GQA (8 KV heads): 2 * 8 * 4096 * 128 * 2 = 16 MB (4x reduction)
MQA (1 KV head): 2 * 1 * 4096 * 128 * 2 = 2 MB (32x reduction)
For 32 layers:
MHA: 2048 MB
GQA: 512 MB <-- e.g. Llama 3 8B (32 layers, 8 KV heads); Llama 2 70B also uses 8 KV heads, but has 80 layers
MQA: 64 MB
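The whole table follows from one formula. Here is a small calculator with a hypothetical function name. As in the text, "MB" means 2^20 bytes, and FP16 (2 bytes per element) is assumed.

```python
def kv_cache_bytes(n_layers, n_kv_heads, seq_len, head_dim, bytes_per_elem=2):
    """Total K + V cache size; the leading 2 counts K and V separately."""
    return 2 * n_layers * n_kv_heads * seq_len * head_dim * bytes_per_elem

MB = 2 ** 20   # the text's "MB" is 2^20 bytes
for name, n_kv in (("MHA", 32), ("GQA", 8), ("MQA", 1)):
    per_layer = kv_cache_bytes(1, n_kv, 4096, 128) // MB
    all_layers = kv_cache_bytes(32, n_kv, 4096, 128) // MB
    print(f"{name}: {per_layer} MB per layer, {all_layers} MB for 32 layers")
# MHA: 64 MB per layer, 2048 MB for 32 layers
# GQA: 16 MB per layer, 512 MB for 32 layers
# MQA: 2 MB per layer, 64 MB for 32 layers
```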
In the GPU kernel, GQA changes how you index into the KV cache. For query head h, the
corresponding KV head is h / group_size:
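// GQA: mapping query heads to KV heads
// group_size = n_q_heads / n_kv_heads (32 / 8 = 4 here)
uint q_head = head_index; // 0..31
uint kv_head = q_head / group_size; // integer division: 0..7 (if group_size=4)
// Load Q from q_head's slice
device const half* q = Q + q_head * d_k;
// Load K, V from kv_head's slice (shared by 4 Q heads).
// Assumed layout: [n_kv_heads x seq_len x d_k]; with a pre-allocated cache
// the per-head stride would be max_seq_len rather than seq_len.
device const half* k = K_cache + kv_head * seq_len * d_k;
device const half* v = V_cache + kv_head * seq_len * d_v;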
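Writing the same mapping in Python makes the grouping easy to inspect. This is an illustrative sketch; kv_head_for is a made-up name. It assumes n_q_heads is a multiple of n_kv_heads, which the integer division in the kernel fragment also relies on.

```python
def kv_head_for(q_head, n_q_heads, n_kv_heads):
    """KV head read by a given query head under GQA.
    MHA is n_kv_heads == n_q_heads; MQA is n_kv_heads == 1."""
    if n_q_heads % n_kv_heads != 0:
        raise ValueError("n_q_heads must be a multiple of n_kv_heads")
    group_size = n_q_heads // n_kv_heads
    return q_head // group_size

groups = {}
for h in range(32):
    groups.setdefault(kv_head_for(h, 32, 8), []).append(h)
print(groups[0], groups[7])   # [0, 1, 2, 3] [28, 29, 30, 31]
```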
The KV Cache: Why We Store K and V
During autoregressive decoding, we generate one token at a time. At step t, we need to
compute attention between the new token’s query and all previous tokens’ keys and values.
Without caching, we would need to reproject all previous tokens through W_K and W_V every
single step – a massive waste of computation.
The KV cache stores the K and V vectors for all previous tokens. At each decoding step, we only compute K and V for the new token and append them to the cache.
KV Cache During Decoding
=========================
Step 1: Process token "The"
+---------+
| K: k_0 | V: v_0
+---------+
Q = q_0, attend to [k_0]
Step 2: Process token "cat"
+---------+---------+
| K: k_0 | K: k_1 | V: v_0, v_1
+---------+---------+
Q = q_1, attend to [k_0, k_1]
Step 3: Process token "sat"
+---------+---------+---------+
| K: k_0 | K: k_1 | K: k_2 | V: v_0, v_1, v_2
+---------+---------+---------+
Q = q_2, attend to [k_0, k_1, k_2]
Step 4: Process token "on"
+---------+---------+---------+---------+
| K: k_0 | K: k_1 | K: k_2 | K: k_3 | V: v_0, v_1, v_2, v_3
+---------+---------+---------+---------+
Q = q_3, attend to [k_0, k_1, k_2, k_3]
At each step, we compute K, V for only the NEW token
and append to the cache. The cache grows by one entry per step.
Without KV cache:
Step t: compute K, V for ALL t tokens, then attend
Cost: O(t * d_model * d_k) for projections alone
Total over T steps: O(T^2 * d_model * d_k) -- quadratic!
With KV cache:
Step t: compute K, V for 1 new token, append to cache
Cost: O(d_model * d_k) for projections
Total over T steps: O(T * d_model * d_k) for projections -- linear!
(Attention itself is still O(T * d_k) per step, O(T^2 * d_k) total)
The KV cache is typically stored as a contiguous buffer per layer, per head, pre-allocated to the maximum sequence length:
KV Cache Memory Layout
=======================
For one layer, one KV head:
K cache: [max_seq_len x d_k]
+------+------+------+------+------+------+------+------+
| k_0  | k_1  | k_2  | k_3  | .... | .... | .... | .... |
+------+------+------+------+------+------+------+------+
  filled positions 0..3        ^    empty (pre-allocated)
                               |
               current_pos (next slot to write)
V cache: [max_seq_len x d_v]
+------+------+------+------+------+------+------+------+
| v_0  | v_1  | v_2  | v_3  | .... | .... | .... | .... |
+------+------+------+------+------+------+------+------+
At each decode step:
1. Compute k_new, v_new for the new token
2. Write k_new to K_cache[current_pos]
3. Write v_new to V_cache[current_pos]
4. Compute attention: q_new vs K_cache[0..current_pos]
5. Increment current_pos
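The five decode steps can be sketched for a single layer and a single KV head. The class below is hypothetical, written in pure Python with no batching. Step 1 (projecting the new token to q, k, v) is left to the caller. Only the new token's K and V are written, while attention reads every cached position.

```python
import math

class KVCache:
    """Pre-allocated K/V buffers for one layer and one KV head."""

    def __init__(self, max_seq_len, d_k, d_v):
        self.max_seq_len = max_seq_len
        self.K = [[0.0] * d_k for _ in range(max_seq_len)]
        self.V = [[0.0] * d_v for _ in range(max_seq_len)]
        self.current_pos = 0

    def decode_step(self, q_new, k_new, v_new):
        """Steps 2-5: write k_new/v_new, attend over positions 0..current_pos, advance."""
        if self.current_pos >= self.max_seq_len:
            raise IndexError("KV cache is full")
        t = self.current_pos
        self.K[t] = list(k_new)                                  # step 2
        self.V[t] = list(v_new)                                  # step 3
        scale = 1.0 / math.sqrt(len(q_new))
        scores = [scale * sum(a * b for a, b in zip(q_new, self.K[j]))
                  for j in range(t + 1)]                         # step 4: q . K[0..t]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        out = [sum(exps[j] / total * self.V[j][c] for j in range(t + 1))
               for c in range(len(v_new))]
        self.current_pos += 1                                    # step 5
        return out

cache = KVCache(max_seq_len=8, d_k=2, d_v=2)
print(cache.decode_step([1.0, 0.0], [1.0, 0.0], [0.5, 0.5]))   # [0.5, 0.5]
print(cache.current_pos)                                        # 1
```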
Causal Masking
In autoregressive language models, a token should only attend to tokens that came before it (or at the same position). This is called causal masking or the “no peeking into the future” constraint.
During prefill (when processing the entire prompt at once), we enforce this by adding a mask to the attention scores before softmax:
Causal Mask (4 tokens)
=======================
Before masking: After masking:
[ 2.1 0.5 1.3 0.8 ] [ 2.1 -inf -inf -inf ]
[ 0.3 1.7 0.9 0.4 ] [ 0.3 1.7 -inf -inf ]
[ 1.1 0.2 2.3 1.5 ] [ 1.1 0.2 2.3 -inf ]
[ 0.7 1.4 0.6 1.9 ] [ 0.7 1.4 0.6 1.9 ]
Mask matrix:
[ 0 -inf -inf -inf ]
[ 0 0 -inf -inf ]
[ 0 0 0 -inf ]
[ 0 0 0 0 ]
Token i can only see tokens 0..i (the lower triangle).
-inf becomes 0 after softmax, so masked tokens contribute nothing.
After softmax, the -inf entries become 0, effectively removing those tokens from the
weighted sum. The result is that each token’s output depends only on itself and all
preceding tokens.
After softmax with causal mask:
================================
[ 1.000 0.000 0.000 0.000 ] Token 0: only sees itself
[ 0.198 0.802 0.000 0.000 ] Token 1: sees tokens 0, 1
[ 0.212 0.086 0.702 0.000 ] Token 2: sees tokens 0, 1, 2
[ 0.138 0.278 0.125 0.459 ] Token 3: sees all tokens
Each row sums to 1.0 (up to rounding).
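The mask-then-softmax step can be reproduced directly. This sketch (causal_softmax is an illustrative name) sets every entry above the diagonal to -inf and then softmaxes each row. It uses the 4 x 4 score matrix from the causal-mask example, and the expected printout in the comments is rounded to three decimals.

```python
import math

def causal_softmax(scores):
    """Set entries with j > i to -inf, then softmax each row."""
    weights = []
    for i, row in enumerate(scores):
        masked = [s if j <= i else -math.inf for j, s in enumerate(row)]
        m = max(masked)                               # finite: the diagonal is never masked
        exps = [math.exp(s - m) for s in masked]      # math.exp(-inf) == 0.0
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

scores = [[2.1, 0.5, 1.3, 0.8],
          [0.3, 1.7, 0.9, 0.4],
          [1.1, 0.2, 2.3, 1.5],
          [0.7, 1.4, 0.6, 1.9]]
for row in causal_softmax(scores):
    print([round(w, 3) for w in row])
# [1.0, 0.0, 0.0, 0.0]
# [0.198, 0.802, 0.0, 0.0]
# [0.212, 0.086, 0.702, 0.0]
# [0.138, 0.278, 0.125, 0.459]
```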
During decoding, causal masking is implicit: the query is for position t, and the KV cache
contains only positions 0..t (the new token's own K and V are appended just before attention
is computed), so there is nothing to mask.
Computational Complexity
Let us carefully count the operations in attention for a sequence of length n with
head dimension d:
Operation Counts for Attention
===============================
1. Q * K^T: matrix multiply [n x d] * [d x n] = [n x n]
Operations: 2 * n * n * d (multiply + add)
2. Scale: divide each of n*n elements by sqrt(d)
Operations: n * n
3. Softmax: for each of n rows, compute exp and normalize over n elements
Operations: ~3 * n * n (exp, sum, divide)
4. Weights * V: matrix multiply [n x n] * [n x d] = [n x d]
Operations: 2 * n * n * d
Total: ~4 * n^2 * d + 4 * n^2
= O(n^2 * d)
For n = 4096, d = 128:
4 * 4096^2 * 128 = ~8.6 billion operations per head
Times 32 heads = ~275 billion operations
Memory for attention matrix: n * n * sizeof(float) per head
= 4096 * 4096 * 4 = 64 MB per head
= 2 GB for 32 heads <-- This is the problem FlashAttention solves!
The quadratic scaling in sequence length is the fundamental limitation:
How Attention Scales with Sequence Length
==========================================
n        Ops per head (4 * n^2 * d, d = 128)   Attn matrix (FP32, per head)
------   -----------------------------------   ----------------------------
512      ~134M                                 1 MB
1024     ~537M                                 4 MB
2048     ~2.15B                                16 MB
4096     ~8.59B                                64 MB
8192     ~34.4B                                256 MB
16384    ~137.4B                               1 GB
32768    ~550B                                 4 GB
131072   ~8.80T                                64 GB (Llama 3.1 context length!)
The attention matrix alone exceeds GPU memory for long sequences.
This is why FlashAttention (next chapter) is essential.
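The operation count and the score-matrix size are easy to tabulate for any sequence length. Below is a minimal calculator with a hypothetical name. It keeps the lower-order 4 * n^2 term from the operation count above and assumes an FP32 score matrix.

```python
def attention_cost(n, d, bytes_per_score=4):
    """Per-head (FLOPs, score-matrix bytes) for sequence length n, head dim d."""
    flops = 4 * n * n * d + 4 * n * n     # two matmuls + scale/softmax
    return flops, n * n * bytes_per_score

for n in (512, 4096, 131072):
    flops, mem = attention_cost(n, 128)
    print(f"n={n:6d}  {flops / 1e9:10.2f} GFLOP/head  {mem / 2**20:8.0f} MiB/head")
```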
Worked Example: 4-Token, 2-Head Attention
Let us trace through a complete attention computation with concrete numbers. We will use
a tiny example: 4 tokens, 2 heads, d_model = 8, d_k = d_v = 4 per head.
Setup
Model parameters:
d_model = 8
n_heads = 2
d_k = d_v = d_model / n_heads = 4
seq_len = 4
Input tokens: "The cat sat on"
Input embeddings X [4 x 8]:
Token 0 ("The"): [ 1.0, 0.5, 0.3, 0.8, | 0.2, 0.7, 0.4, 0.6 ]
Token 1 ("cat"): [ 0.3, 1.2, 0.7, 0.1, | 0.9, 0.4, 0.5, 0.3 ]
Token 2 ("sat"): [ 0.6, 0.2, 1.1, 0.4, | 0.3, 0.8, 1.0, 0.2 ]
Token 3 ("on"): [ 0.4, 0.7, 0.5, 1.0, | 0.6, 0.3, 0.7, 0.9 ]
Step 1: Compute Q, K, V (via projection)
For simplicity, let us assume the projections have already been computed and show the result after splitting into heads:
After projection and head split:
Head 0 (d_k = 4):
Q0 = [ 1.2, 0.3, 0.5, 0.8 ] K0 = [ 0.9, 0.4, 0.7, 0.2 ] V0 = [ 0.3, 0.8, 0.5, 0.1 ]
[ 0.4, 1.1, 0.2, 0.6 ] [ 0.5, 1.0, 0.3, 0.8 ] [ 0.7, 0.2, 0.9, 0.4 ]
[ 0.7, 0.5, 0.9, 0.3 ] [ 0.8, 0.6, 1.1, 0.5 ] [ 0.4, 0.6, 0.3, 0.8 ]
[ 0.3, 0.8, 0.4, 1.0 ] [ 0.2, 0.7, 0.5, 1.0 ] [ 0.9, 0.5, 0.7, 0.3 ]
Head 1 (d_k = 4):
Q1 = [ 0.6, 0.9, 0.2, 0.4 ] K1 = [ 0.3, 0.7, 0.5, 0.1 ] V1 = [ 0.5, 0.4, 0.2, 0.7 ]
[ 0.8, 0.3, 0.7, 0.5 ] [ 0.6, 0.2, 0.8, 0.4 ] [ 0.2, 0.9, 0.6, 0.3 ]
[ 0.1, 0.6, 0.4, 0.8 ] [ 0.4, 0.5, 0.3, 0.9 ] [ 0.8, 0.3, 0.5, 0.6 ]
[ 0.5, 0.4, 0.9, 0.7 ] [ 0.7, 0.3, 0.6, 0.5 ] [ 0.3, 0.7, 0.4, 0.8 ]
Step 2: Q * K^T for Head 0
Scores0 = Q0 * K0^T [4 x 4]
Score(0,0) = Q0[0] . K0[0] = 1.2*0.9 + 0.3*0.4 + 0.5*0.7 + 0.8*0.2
= 1.08 + 0.12 + 0.35 + 0.16 = 1.71
Score(0,1) = Q0[0] . K0[1] = 1.2*0.5 + 0.3*1.0 + 0.5*0.3 + 0.8*0.8
= 0.60 + 0.30 + 0.15 + 0.64 = 1.69
Score(0,2) = Q0[0] . K0[2] = 1.2*0.8 + 0.3*0.6 + 0.5*1.1 + 0.8*0.5
= 0.96 + 0.18 + 0.55 + 0.40 = 2.09
Score(0,3) = Q0[0] . K0[3] = 1.2*0.2 + 0.3*0.7 + 0.5*0.5 + 0.8*1.0
= 0.24 + 0.21 + 0.25 + 0.80 = 1.50
(Computing all 16 entries similarly...)
Scores0 = [ 1.71 1.69 2.09 1.50 ]
[ 1.06 1.84 1.50 1.55 ]
[ 1.52 1.36 2.00 1.24 ]
[ 1.07 1.87 1.66 1.82 ]
Step 3: Scale
sqrt(d_k) = sqrt(4) = 2.0
Scaled0 = Scores0 / 2.0
Scaled0 = [ 0.855 0.845 1.045 0.750 ]
[ 0.530 0.920 0.750 0.775 ]
[ 0.760 0.680 1.000 0.620 ]
[ 0.535 0.935 0.830 0.910 ]
Step 4: Apply Causal Mask
Masked0 = [ 0.855 -inf -inf -inf ]
[ 0.530 0.920 -inf -inf ]
[ 0.760 0.680 1.000 -inf ]
[ 0.535 0.935 0.830 0.910 ]
Step 5: Softmax (row by row)
Row 0: softmax([0.855, -inf, -inf, -inf])
= [1.000, 0.000, 0.000, 0.000]
Row 1: softmax([0.530, 0.920]) (ignoring -inf entries)
exp: [1.699, 2.509] sum = 4.208
= [0.404, 0.596, 0.000, 0.000]
Row 2: softmax([0.760, 0.680, 1.000])
exp: [2.138, 1.974, 2.718] sum = 6.830
= [0.313, 0.289, 0.398, 0.000]
Row 3: softmax([0.535, 0.935, 0.830, 0.910])
exp: [1.707, 2.547, 2.293, 2.484] sum = 9.032
= [0.189, 0.282, 0.254, 0.275]
Weights0 = [ 1.000 0.000 0.000 0.000 ]
[ 0.404 0.596 0.000 0.000 ]
[ 0.313 0.289 0.398 0.000 ]
[ 0.189 0.282 0.254 0.275 ]
Step 6: Weighted Sum of Values
Output0 = Weights0 * V0 [4 x 4]
Output0[0] = 1.000 * V0[0]
= [0.300, 0.800, 0.500, 0.100]
Output0[1] = 0.404 * V0[0] + 0.596 * V0[1]
= 0.404 * [0.3, 0.8, 0.5, 0.1] + 0.596 * [0.7, 0.2, 0.9, 0.4]
= [0.1212, 0.3232, 0.2020, 0.0404] + [0.4172, 0.1192, 0.5364, 0.2384]
= [0.538, 0.442, 0.738, 0.279]
Output0[2] = 0.313 * V0[0] + 0.289 * V0[1] + 0.398 * V0[2]
= [0.0939, 0.2504, 0.1565, 0.0313]
+ [0.2023, 0.0578, 0.2601, 0.1156]
+ [0.1592, 0.2388, 0.1194, 0.3184]
= [0.455, 0.547, 0.536, 0.465]
Output0[3] = 0.189 * V0[0] + 0.282 * V0[1] + 0.254 * V0[2] + 0.275 * V0[3]
= [0.0567, 0.1512, 0.0945, 0.0189]
+ [0.1974, 0.0564, 0.2538, 0.1128]
+ [0.1016, 0.1524, 0.0762, 0.2032]
+ [0.2475, 0.1375, 0.1925, 0.0825]
= [0.603, 0.498, 0.617, 0.417]
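Hand arithmetic like this is easy to get wrong, so it helps to check it mechanically. The sketch below recomputes head 0 from Q0, K0 and V0 in full precision; causal_head is an illustrative name. Because it does not round the weights between steps, its third decimal can differ slightly from hand-computed values.

```python
import math

Q0 = [[1.2, 0.3, 0.5, 0.8], [0.4, 1.1, 0.2, 0.6], [0.7, 0.5, 0.9, 0.3], [0.3, 0.8, 0.4, 1.0]]
K0 = [[0.9, 0.4, 0.7, 0.2], [0.5, 1.0, 0.3, 0.8], [0.8, 0.6, 1.1, 0.5], [0.2, 0.7, 0.5, 1.0]]
V0 = [[0.3, 0.8, 0.5, 0.1], [0.7, 0.2, 0.9, 0.4], [0.4, 0.6, 0.3, 0.8], [0.9, 0.5, 0.7, 0.3]]

def causal_head(Q, K, V):
    """Return (scores, weights, output) for one head with a causal mask."""
    d_k = len(Q[0])
    scores, weights, output = [], [], []
    for i, q in enumerate(Q):
        row = [sum(a * b for a, b in zip(q, k)) for k in K]               # Q * K^T
        scaled = [s / math.sqrt(d_k) if j <= i else -math.inf             # scale + mask
                  for j, s in enumerate(row)]
        m = max(scaled)
        exps = [math.exp(s - m) for s in scaled]
        total = sum(exps)
        w = [e / total for e in exps]                                      # softmax
        output.append([sum(w[j] * V[j][c] for j in range(len(V)))
                       for c in range(len(V[0]))])                         # Weights * V
        scores.append(row)
        weights.append(w)
    return scores, weights, output

scores, weights, output = causal_head(Q0, K0, V0)
for name, M in (("Scores0", scores), ("Weights0", weights), ("Output0", output)):
    print(name)
    for r in M:
        print("  ", [round(x, 3) for x in r])
```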
Step 7: Concatenate Heads and Project
Head 1 would be computed identically (with its own Q1, K1, V1). Then the two heads’ outputs are concatenated along the last dimension and projected:
Concatenation:
Output0[i]: [d_v = 4 values]
Output1[i]: [d_v = 4 values]
Concat[i] = [Output0[i] | Output1[i]] (d_model = 8 values)
Final = Concat * W_O [4 x 8] * [8 x 8] = [4 x 8]
This produces the final attention output for this layer.
Decode Attention vs Prefill Attention
The attention computation looks quite different depending on the inference phase:
Prefill Attention (processing entire prompt)
=============================================
Q: [seq_len x d_k] (all query positions)
K: [seq_len x d_k] (all key positions)
V: [seq_len x d_v] (all value positions)
QK^T: [seq_len x seq_len] -- full attention matrix
Softmax: [seq_len x seq_len]
Output: [seq_len x d_v]
This is a GEMM problem. Matrix-matrix multiplications.
Causal mask applied explicitly.
Decode Attention (generating one token at a time)
==================================================
q: [1 x d_k] (single new query)
K: [seq_len x d_k] (all keys from KV cache)
V: [seq_len x d_v] (all values from KV cache)
q * K^T: [1 x seq_len] -- one row of attention scores
Softmax: [1 x seq_len]
Output: [1 x d_v]
This is a GEMV problem. Vector-matrix multiplications.
No causal mask needed (the KV cache holds only past tokens and the current one).
This distinction is critical for optimization. Prefill attention can use tiled GEMM techniques with high data reuse. Decode attention is bandwidth-bound, reading through the entire KV cache for each head.
Decode Attention: Memory Access Pattern
=========================================
For one head, one decode step at position t:
Read q: 1 * d_k * 2 bytes = 256 bytes (d_k=128)
Read K cache: t * d_k * 2 bytes = varies
Read V cache: t * d_v * 2 bytes = varies
Write output: 1 * d_v * 2 bytes = 256 bytes
At position t = 4096:
Read K: 4096 * 128 * 2 = 1 MB per head
Read V: 4096 * 128 * 2 = 1 MB per head
Total per head: ~2 MB
Total for 32 heads: ~64 MB (one layer)
At 200 GB/s: ~0.34 ms per layer
This grows linearly with sequence length.
At t = 32768: ~0.5 GB read per layer, ~2.7 ms
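The traffic estimate generalizes to any position and head count. Below is a back-of-the-envelope calculator with hypothetical function names. As in the text, it assumes 32 KV heads, d_k = d_v = 128, FP16, and that K and V are each streamed once at the quoted bandwidth. The q and output traffic is ignored as negligible. With GQA, pass the number of KV heads instead (for example 8).

```python
def decode_kv_bytes(t, n_kv_heads, head_dim, bytes_per_elem=2):
    """K + V bytes read for one decode step at position t, for one layer."""
    return 2 * t * n_kv_heads * head_dim * bytes_per_elem

def read_time_ms(n_bytes, bandwidth_gb_s):
    """Idealized time to stream n_bytes at bandwidth_gb_s (1 GB/s = 1e9 B/s)."""
    return n_bytes / (bandwidth_gb_s * 1e9) * 1e3

for t in (4096, 32768):
    b = decode_kv_bytes(t, n_kv_heads=32, head_dim=128)
    print(f"t={t}: {b / 2**20:.0f} MiB per layer, ~{read_time_ms(b, 200):.2f} ms at 200 GB/s")
# t=4096: 64 MiB per layer, ~0.34 ms at 200 GB/s
# t=32768: 512 MiB per layer, ~2.68 ms at 200 GB/s
```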
Putting It All Together
Here is a Metal kernel sketch showing the structure of decode attention for one head:
// Decode attention: single query vs KV cache (one head)
#include <metal_stdlib>
using namespace metal;
kernel void decode_attention(
    device const half* q       [[buffer(0)]],  // [d_k]
    device const half* K_cache [[buffer(1)]],  // [seq_len x d_k]
    device const half* V_cache [[buffer(2)]],  // [seq_len x d_v]
    device half*       output  [[buffer(3)]],  // [d_v]
    constant uint&     seq_len [[buffer(4)]],
    constant uint&     d_k     [[buffer(5)]],
    uint tid [[thread_position_in_grid]])      // one thread per key position
{
    // The dispatch may be rounded up past seq_len; surplus threads do nothing.
    if (tid >= seq_len) return;
    // Step 1: attention score for key position tid: (q . K[tid]) / sqrt(d_k)
    float score = 0.0f;
    for (uint d = 0; d < d_k; d++) {
        score += float(q[d]) * float(K_cache[tid * d_k + d]);
    }
    score /= sqrt(float(d_k));
    // Step 2: Softmax needs the max and the sum over ALL seq_len scores,
    // which no single thread holds (two passes or online algorithm)
    // ... (this is where FlashAttention comes in!)
    // Step 3: Weighted sum of values
    // output[d] = sum over i: weight[i] * V_cache[i * d_v + d]
    // ...
}
This sketch is deliberately incomplete – computing softmax across all positions requires either materializing the full score vector (memory expensive for long sequences) or using the online softmax algorithm that FlashAttention employs. That is exactly what we will cover in the next chapter.
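In the meantime, a slow but straightforward reference is useful as a test oracle for any optimized kernel. The Python sketch below is not Metal, and decode_attention_ref is a hypothetical name. It takes the first route described above: it materializes all seq_len scores, finds their maximum in one pass, then exponentiates and sums in a second pass, before forming the weighted sum of values.

```python
import math

def decode_attention_ref(q, K_cache, V_cache, seq_len):
    """One head, one decode step: q attends to cache rows 0..seq_len-1."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, K_cache[i]))
              for i in range(seq_len)]                    # full score vector
    m = max(scores)                                       # pass 1: maximum
    exps = [math.exp(s - m) for s in scores]              # pass 2: exponentials ...
    total = sum(exps)                                     # ... and their sum
    d_v = len(V_cache[0])
    return [sum(exps[i] / total * V_cache[i][c] for i in range(seq_len))
            for c in range(d_v)]                          # weighted sum of values

K = [[1.0, 0.0]] * 4                                      # identical keys -> equal weights
V = [[0.0, 4.0], [4.0, 0.0], [2.0, 2.0], [2.0, 2.0]]
print(decode_attention_ref([0.3, 0.7], K, V, seq_len=4))  # [2.0, 2.0], the mean of V
```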
Summary
The attention mechanism is the defining feature of transformers. Here are the key takeaways:
- Q, K, V projections are matrix multiplications that transform input embeddings into queries, keys, and values.
- The attention equation softmax(QK^T / sqrt(d_k)) * V computes a weighted combination of values based on query-key similarity.
- Multi-head attention splits the representation into independent heads that can learn different attention patterns.
- Grouped Query Attention shares KV heads across multiple query heads, reducing KV cache memory by the group size factor.
- The KV cache stores previously computed keys and values during decoding, avoiding redundant recomputation. It grows linearly with sequence length.
- Causal masking ensures tokens only attend to past positions, enforced via -inf masking before softmax during prefill.
- Complexity is O(n^2 * d): quadratic in sequence length. The attention matrix itself requires O(n^2) memory, which becomes prohibitive for long sequences.
- Prefill vs decode have very different computational profiles: prefill is a GEMM problem (compute bound); decode is a GEMV problem (bandwidth bound).
In the next chapter, we will see how FlashAttention solves the O(n^2) memory problem through a clever tiling scheme and online softmax computation.