FlashAttention: Theory and Practice

In the previous chapter, we saw that standard attention requires materializing an [n x n] attention matrix. In FP32, for a sequence length of 4096, that is 64 MB per head, or 2 GB across 32 heads. For 32,768 tokens it is 4 GB per head and 128 GB across 32 heads, just for attention weights. This is clearly untenable.

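FlashAttention is the algorithmic breakthrough that makes long-context attention practical. The key idea is beautifully simple: never materialize the full attention matrix. Instead, compute attention in tiles, maintaining a running softmax that produces the same result as the standard algorithm. This is exact attention, not an approximation. However, because the floating-point operations run in a different order, the output can differ from a naive implementation in the last few bits.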

This chapter will build up the theory from first principles, work through a complete numerical example, and then show how FlashAttention adapts for both prefill and decode phases, including advanced variants for GQA and speculative decoding.

The Problem: O(n^2) Memory

Let us be very concrete about what “materializing the attention matrix” means in the standard algorithm:

  Standard Attention Memory Usage
  ================================

  Step 1: S = Q * K^T         Allocate [n x n] matrix S
  Step 2: S = S / sqrt(d)     In-place on S
  Step 3: P = softmax(S)      Allocate [n x n] matrix P (or overwrite S)
  Step 4: O = P * V           Allocate [n x d] matrix O

  Peak memory: at least one [n x n] matrix in memory simultaneously

  n = 4096, FP32:
    [n x n] = 4096 * 4096 * 4 bytes = 64 MB (per head!)
    32 heads = 2 GB

  n = 32768, FP32:
    [n x n] = 32768 * 32768 * 4 bytes = 4 GB (per head!)
    32 heads = 128 GB  <-- more memory than nearly any single GPU has
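These sizes follow from n * n * bytes_per_element. Below is a minimal C++ sketch for reproducing them. The function names are illustrative and not from any library. It uses binary units, so 64 MB means 64 * 2^20 bytes:

```cpp
#include <cassert>
#include <cstdint>

// Bytes to materialize one [n x n] score matrix for a single head.
// bytes_per_elem: 4 for FP32, 2 for FP16.
std::uint64_t attn_matrix_bytes(std::uint64_t n, std::uint64_t bytes_per_elem) {
    assert(n > 0 && bytes_per_elem > 0);
    return n * n * bytes_per_elem;
}

// Same, summed over every head of one layer (all matrices live at once).
std::uint64_t attn_matrix_bytes_all_heads(std::uint64_t n, std::uint64_t bytes_per_elem,
                                          std::uint64_t n_heads) {
    return attn_matrix_bytes(n, bytes_per_elem) * n_heads;
}
```

For example, attn_matrix_bytes(32768, 4) is exactly 4 * 2^30 bytes, the 4 GB per head quoted above.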

Even if memory were unlimited, there is a performance problem. S is written to device memory in step 1 and read back in step 3; P is written in step 3 and read back in step 4. Those round-trips to device memory are slow:

  Memory Traffic in Standard Attention
  =====================================

  Write S to memory:    n^2 * 4 bytes      (after Q * K^T)
  Read S from memory:   n^2 * 4 bytes      (for softmax)
  Write P to memory:    n^2 * 4 bytes      (softmax output)
  Read P from memory:   n^2 * 4 bytes      (for P * V)

  Total memory traffic: 4 * n^2 * 4 bytes = 16 * n^2 bytes

  At n = 4096: 16 * 16M = 256 MB of memory traffic per head
  At 200 GB/s: 1.3 ms just for the attention matrix I/O (per head!)
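The traffic estimate can be scripted the same way. Below is a small C++ sketch. It assumes the four [n x n] passes listed above and treats bandwidth as decimal GB/s (10^9 bytes per second). The helper names are hypothetical:

```cpp
#include <cassert>
#include <cstdint>

// Device-memory traffic for the [n x n] intermediates of standard attention:
// write S, read S, write P, read P, i.e. four passes over n^2 elements.
std::uint64_t attn_matrix_traffic_bytes(std::uint64_t n, std::uint64_t bytes_per_elem) {
    const std::uint64_t passes = 4;
    return passes * n * n * bytes_per_elem;
}

// Milliseconds to move `bytes` at `gb_per_s`, with 1 GB/s = 10^9 bytes per second.
double transfer_ms(std::uint64_t bytes, double gb_per_s) {
    assert(gb_per_s > 0.0);
    return static_cast<double>(bytes) / (gb_per_s * 1e9) * 1e3;
}
```

At n = 4096 in FP32 this gives 256 MB of traffic and about 1.34 ms at 200 GB/s, consistent with the ~1.3 ms figure above.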

FlashAttention eliminates this O(n^2) traffic by keeping each score tile in registers and threadgroup memory (SRAM). Only Q, K, V, and the output O travel through device memory.

The Key Insight: Online Softmax

The reason we need the full [n x n] matrix is softmax. Softmax requires knowing the maximum value across an entire row (for numerical stability) and the sum of exponentials (for normalization). Both of these seem to require seeing all n values before producing any output.

The breakthrough is online softmax: a streaming algorithm that processes values one chunk at a time, maintaining running statistics that can be corrected as new data arrives.

Standard Softmax (Two-Pass)

The textbook softmax for a vector z of length n:

  Pass 1: Find maximum
    m = max(z[0], z[1], ..., z[n-1])

  Pass 2: Compute exponentials and sum
    For i in 0..n-1:
      e[i] = exp(z[i] - m)
    s = sum(e[0], e[1], ..., e[n-1])

  Output: softmax[i] = e[i] / s

  Requires storing all n values of z to make two passes.
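For comparison with the streaming version that follows, here is the textbook algorithm as a short C++ sketch (`softmax_two_pass` is an illustrative name). It needs the whole vector z up front, because the max must be known before any exponential is taken:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Textbook softmax with max subtraction: one pass for the max,
// one pass for the exponentials and their sum, then a normalization pass.
std::vector<float> softmax_two_pass(const std::vector<float>& z) {
    assert(!z.empty());
    const float m = *std::max_element(z.begin(), z.end());   // Pass 1: max
    std::vector<float> e(z.size());
    float s = 0.0f;
    for (std::size_t i = 0; i < z.size(); ++i) {              // Pass 2: exp and sum
        e[i] = std::exp(z[i] - m);
        s += e[i];
    }
    for (float& x : e) x /= s;                                 // normalize
    return e;
}
```

On z = [2.0, 5.0, 1.0, 4.0] it returns approximately [0.0347, 0.6964, 0.0128, 0.2562].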

Online Softmax (One-Pass, Streaming)

Online softmax processes elements one at a time (or one tile at a time), maintaining a running maximum m and a running sum s. When a new maximum is found, the previously accumulated sum is corrected:

  Initialize:
    m = -infinity    (running maximum)
    s = 0            (running sum of exp)

  For each new value z[i]:
    m_new = max(m, z[i])
    s = s * exp(m - m_new) + exp(z[i] - m_new)
    m = m_new

  After all values:
    softmax[i] = exp(z[i] - m) / s

  The key line: s = s * exp(m - m_new) + exp(z[i] - m_new)

  When m_new > m (new max found):
    exp(m - m_new) < 1, so the old sum is scaled DOWN
    This corrects for the fact that previous exponentials
    were computed relative to the old (smaller) maximum

  When m_new = m (no new max):
    exp(m - m_new) = exp(0) = 1, so old sum unchanged
    Just add the new exponential
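The update rule fits in a few lines. Below is a C++ sketch of the running statistics (`OnlineSoftmax` is a hypothetical name). This version still reads z a second time to produce the final probabilities. That second read is exactly the step FlashAttention removes by folding V into the running state:

```cpp
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

// Running statistics for online softmax.
struct OnlineSoftmax {
    float m = -std::numeric_limits<float>::infinity();   // running max
    float s = 0.0f;                                       // running sum of exp(z - m)

    // Consume one value; if it raises the max, rescale the old sum first.
    void push(float z) {
        const float m_new = std::fmax(m, z);
        s = s * std::exp(m - m_new) + std::exp(z - m_new);   // exp(-inf) = 0 on the first call
        m = m_new;
    }
};

// One streaming pass for (m, s); producing the outputs still needs z again.
std::vector<float> softmax_online(const std::vector<float>& z) {
    assert(!z.empty());
    OnlineSoftmax st;
    for (float x : z) st.push(x);
    std::vector<float> out;
    out.reserve(z.size());
    for (float x : z) out.push_back(std::exp(x - st.m) / st.s);
    return out;
}
```

Pushing 2.0, 5.0, 1.0, 4.0 in order leaves m = 5.0 and s ≈ 1.4360, matching the trace below.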

Let us trace through an example:

  Online Softmax Example: z = [2.0, 5.0, 1.0, 4.0]
  ===================================================

  Step 0: m = -inf, s = 0

  Step 1: z[0] = 2.0
    m_new = max(-inf, 2.0) = 2.0
    s = 0 * exp(-inf - 2.0) + exp(2.0 - 2.0)
      = 0 + exp(0)
      = 1.0
    m = 2.0

  Step 2: z[1] = 5.0
    m_new = max(2.0, 5.0) = 5.0       <-- New max found!
    s = 1.0 * exp(2.0 - 5.0) + exp(5.0 - 5.0)
      = 1.0 * exp(-3.0) + exp(0)
      = 1.0 * 0.0498 + 1.0
      = 1.0498
    m = 5.0

    Verification: at this point, s should equal exp(2-5) + exp(5-5) = 0.0498 + 1 = 1.0498  ✓
    The correction factor exp(2.0 - 5.0) = 0.0498 adjusts the old sum.

  Step 3: z[2] = 1.0
    m_new = max(5.0, 1.0) = 5.0       <-- No new max
    s = 1.0498 * exp(5.0 - 5.0) + exp(1.0 - 5.0)
      = 1.0498 * 1.0 + exp(-4.0)
      = 1.0498 + 0.0183
      = 1.0681
    m = 5.0

  Step 4: z[3] = 4.0
    m_new = max(5.0, 4.0) = 5.0       <-- No new max
    s = 1.0681 * exp(5.0 - 5.0) + exp(4.0 - 5.0)
      = 1.0681 * 1.0 + exp(-1.0)
      = 1.0681 + 0.3679
      = 1.4360
    m = 5.0

  Final softmax:
    softmax[0] = exp(2.0 - 5.0) / 1.4360 = 0.0498 / 1.4360 = 0.0347
    softmax[1] = exp(5.0 - 5.0) / 1.4360 = 1.0000 / 1.4360 = 0.6964
    softmax[2] = exp(1.0 - 5.0) / 1.4360 = 0.0183 / 1.4360 = 0.0128
    softmax[3] = exp(4.0 - 5.0) / 1.4360 = 0.3679 / 1.4360 = 0.2562

  Sum = 0.0347 + 0.6964 + 0.0128 + 0.2562 = 1.0001 ≈ 1.0  ✓

Online Softmax with Running Output Correction

But FlashAttention needs more than just softmax values – it needs softmax(S) * V, where V is the value matrix. We need to maintain a running weighted sum of V that gets corrected each time the maximum changes:

  Online Attention (combining softmax + V multiply):
  ====================================================

  Initialize:
    m = -inf          (running max)
    s = 0             (running sum of exp)
    o = 0             (running output, vector of size d_v)

  For each position j (or tile of positions):
    score = q . K[j]                       (dot product)
    m_new = max(m, score)
    correction = exp(m - m_new)

    // Correct the old output and sum
    o = o * correction
    s = s * correction

    // Add new contribution
    weight = exp(score - m_new)
    o = o + weight * V[j]
    s = s + weight

    m = m_new

  After all positions:
    output = o / s    (normalize by total sum)
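Here is the same loop in C++, as a CPU sketch for a single query. The name `online_attention` is illustrative. The 1/sqrt(d) scaling is left out to match the worked example that follows:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

using Vec = std::vector<float>;

// Online attention for one query: stream over (K[j], V[j]) pairs, keeping a
// running max m, a running sum s, and an unnormalized running output o.
Vec online_attention(const Vec& q, const std::vector<Vec>& K, const std::vector<Vec>& V) {
    assert(!K.empty() && K.size() == V.size());
    const std::size_t dv = V[0].size();
    float m = -std::numeric_limits<float>::infinity();
    float s = 0.0f;
    Vec o(dv, 0.0f);
    for (std::size_t j = 0; j < K.size(); ++j) {
        float score = 0.0f;
        for (std::size_t i = 0; i < q.size(); ++i) score += q[i] * K[j][i];
        const float m_new = std::fmax(m, score);
        const float correction = std::exp(m - m_new);   // 0 on the first step
        const float weight = std::exp(score - m_new);
        for (std::size_t c = 0; c < dv; ++c) o[c] = o[c] * correction + weight * V[j][c];
        s = s * correction + weight;
        m = m_new;
    }
    for (float& x : o) x /= s;   // normalize once, at the end
    return o;
}
```

With q = [1, 0] and the K and V from the example below, it returns approximately [0.4421, 0.5579].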

This is the core of FlashAttention. Let us trace through it:

  Online Attention Example
  =========================

  q = [1, 0]  (query, dimension 2 for simplicity)

  K = [ 0.5,  0.3 ]    V = [ 1.0,  0.0 ]
      [ 0.8, -0.2 ]        [ 0.0,  1.0 ]
      [ 0.1,  0.7 ]        [ 0.5,  0.5 ]

  Scores = q . K[j]:
    s0 = 1*0.5 + 0*0.3 = 0.5
    s1 = 1*0.8 + 0*(-0.2) = 0.8
    s2 = 1*0.1 + 0*0.7 = 0.1

  (Skip scaling by sqrt(d) for clarity)

  Step 0: m = -inf, s = 0, o = [0, 0]

  Step 1: Process position 0  (score = 0.5)
    m_new = max(-inf, 0.5) = 0.5
    correction = exp(-inf - 0.5) = 0
    o = [0,0] * 0 = [0, 0]
    s = 0 * 0 = 0
    weight = exp(0.5 - 0.5) = 1.0
    o = [0,0] + 1.0 * [1.0, 0.0] = [1.0, 0.0]
    s = 0 + 1.0 = 1.0
    m = 0.5

  Step 2: Process position 1  (score = 0.8)
    m_new = max(0.5, 0.8) = 0.8           <-- New max!
    correction = exp(0.5 - 0.8) = exp(-0.3) = 0.7408
    o = [1.0, 0.0] * 0.7408 = [0.7408, 0.0]
    s = 1.0 * 0.7408 = 0.7408
    weight = exp(0.8 - 0.8) = 1.0
    o = [0.7408, 0.0] + 1.0 * [0.0, 1.0] = [0.7408, 1.0]
    s = 0.7408 + 1.0 = 1.7408
    m = 0.8

  Step 3: Process position 2  (score = 0.1)
    m_new = max(0.8, 0.1) = 0.8           <-- No new max
    correction = exp(0.8 - 0.8) = 1.0
    o = [0.7408, 1.0] * 1.0 = [0.7408, 1.0]
    s = 1.7408 * 1.0 = 1.7408
    weight = exp(0.1 - 0.8) = exp(-0.7) = 0.4966
    o = [0.7408, 1.0] + 0.4966 * [0.5, 0.5] = [0.7408+0.2483, 1.0+0.2483]
      = [0.9891, 1.2483]
    s = 1.7408 + 0.4966 = 2.2374
    m = 0.8

  Normalize: output = o / s = [0.9891/2.2374, 1.2483/2.2374]
                             = [0.4421, 0.5579]

  Verification (standard method):
    scores = [0.5, 0.8, 0.1]
    softmax = exp([0.5, 0.8, 0.1] - 0.8) / sum
            = [exp(-0.3), exp(0), exp(-0.7)] / sum
            = [0.7408, 1.0, 0.4966] / 2.2374
            = [0.3311, 0.4470, 0.2219]
    output = 0.3311*[1,0] + 0.4470*[0,1] + 0.2219*[0.5,0.5]
           = [0.3311, 0] + [0, 0.4470] + [0.1110, 0.1110]
           = [0.4421, 0.5580]  ✓ (matches, up to rounding in the last digit)

The Tiling Strategy

Online softmax lets us process keys/values one at a time. But processing one position at a time would be very slow – too little parallelism. FlashAttention processes keys and values in tiles of size Bk (typically 32 or 64).

  FlashAttention Tiling
  ======================

  Instead of materializing the full [n x n] attention matrix:

  Standard:
  +--------------------------------------------------+
  |                                                    |
  |            Full [n x n] matrix in memory           |
  |                                                    |
  +--------------------------------------------------+
  Memory: O(n^2)


  FlashAttention:
  Process K/V in tiles, never store more than [Bq x Bk]:

  +------+
  | Tile |  Bq x Bk     Only this small tile is in memory
  +------+  at a time!

  Iterate over all K/V tiles:

  K/V:  [====Tile 0====][====Tile 1====][====Tile 2====]...
             Bk cols         Bk cols         Bk cols

  For each tile:
  1. Load Bk keys and values into threadgroup memory
  2. Compute Bq x Bk scores (in registers)
  3. Update running max, sum, and output (online softmax)
  4. Move to next tile

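  Memory: O(Bq * Bk) for the score tile
        + O(Bq * d) for the running output
        = O(1) per threadgroup (Bq, Bk, d are all constants)
  Nothing of size [n x n] is ever stored: the only arrays that grow
  with n are Q, K, V and the output O, all O(n * d).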
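Per query row, the tiled update differs from the element-wise one in a single detail: the max is taken over the whole tile first, so the old accumulators are rescaled once per tile rather than once per key. Below is a C++ sketch of that per-row step. The `RowState` and `absorb_tile` names are hypothetical. The sketch assumes the first tile a row absorbs contains at least one unmasked score:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Running state for one query row of a Q tile.
struct RowState {
    float m = -std::numeric_limits<float>::infinity();   // running max
    float s = 0.0f;                                       // running sum of exp
    std::vector<float> o;                                 // unnormalized output, length d
};

// Absorb one KV tile: `scores` holds the row's Bk scores (masked ones = -inf),
// `v_tile` the matching Bk value rows. The tile max is found first, so the old
// accumulators are rescaled once per tile. If the first tile were fully masked,
// m_new would be -inf and exp(-inf - (-inf)) would be NaN.
void absorb_tile(RowState& st, const std::vector<float>& scores,
                 const std::vector<std::vector<float>>& v_tile) {
    assert(!scores.empty() && scores.size() == v_tile.size());
    if (st.o.empty()) st.o.assign(v_tile[0].size(), 0.0f);
    const float tile_max = *std::max_element(scores.begin(), scores.end());
    const float m_new = std::max(st.m, tile_max);
    const float correction = std::exp(st.m - m_new);
    st.s *= correction;
    for (float& x : st.o) x *= correction;
    for (std::size_t j = 0; j < scores.size(); ++j) {
        const float w = std::exp(scores[j] - m_new);
        st.s += w;
        for (std::size_t c = 0; c < st.o.size(); ++c) st.o[c] += w * v_tile[j][c];
    }
    st.m = m_new;
}
```

The check below reuses the three-key example from the previous section, split into tiles of 2 and 1. Feeding the scores [0.5, 0.8] and then [0.1] gives the same s ≈ 2.2374 and normalized output ≈ [0.4421, 0.5579] as the one-key-at-a-time trace.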

The Full Algorithm

Here is the complete FlashAttention algorithm for prefill (multiple query positions):

  FlashAttention Algorithm (Prefill)
  ====================================

  Inputs: Q [n x d], K [n x d], V [n x d]
  Output: O [n x d]
  Tile sizes: Bq (query tile), Bk (key/value tile)

  For each query tile q_tile in [0, n/Bq):
    Load Q_tile = Q[q_tile*Bq : (q_tile+1)*Bq]    // [Bq x d]

    Initialize per-row:
      m[Bq] = -infinity      // running max for each query in tile
      s[Bq] = 0              // running sum for each query in tile
      O_tile[Bq x d] = 0    // running output for each query in tile

    For each KV tile kv_tile in [0, n/Bk):

      // --- Skip if causally masked ---
      // If ALL queries in q_tile come BEFORE all keys in kv_tile,
      // then all scores are masked to -inf. Skip entirely.
      if (q_tile + 1) * Bq <= kv_tile * Bk:
        continue

      Load K_tile = K[kv_tile*Bk : (kv_tile+1)*Bk]    // [Bk x d]
      Load V_tile = V[kv_tile*Bk : (kv_tile+1)*Bk]    // [Bk x d]

      // --- Compute score tile ---
      S_tile = Q_tile * K_tile^T / sqrt(d)              // [Bq x Bk]

      // --- Apply causal mask within tile ---
      For each (i, j) in S_tile:
        if (q_tile*Bq + i) < (kv_tile*Bk + j):
          S_tile[i][j] = -infinity

      // --- Online softmax update ---
      For each row i in [0, Bq):
        m_new = max(m[i], max(S_tile[i]))
        correction = exp(m[i] - m_new)

        // Correct old accumulators
        O_tile[i] *= correction
        s[i] *= correction

        // Add new contributions
        For each j in [0, Bk):
          weight = exp(S_tile[i][j] - m_new)
          O_tile[i] += weight * V_tile[j]
          s[i] += weight

        m[i] = m_new

    // --- Normalize ---
    For each row i in [0, Bq):
      O_tile[i] /= s[i]

    Write O_tile to O[q_tile*Bq : (q_tile+1)*Bq]
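The pseudocode translates almost line for line into a CPU reference. The sketch below differs from the pseudocode in three ways:

- It is plain C++ rather than Metal.
- It computes each score row with a scalar loop rather than a matrix-multiply unit.
- It assumes V has the same width d as Q and K.

`flash_attention_causal` is an illustrative name. Setting Bq = Bk = n collapses it to a single tile, which is ordinary softmax attention and gives a convenient self-check:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

using Matrix = std::vector<std::vector<float>>;   // row-major, [rows][cols]

// CPU reference for tiled, causal FlashAttention (forward pass, prefill).
// Q, K, V: [n x d]. Returns O: [n x d]. n need not be a multiple of Bq or Bk.
Matrix flash_attention_causal(const Matrix& Q, const Matrix& K, const Matrix& V,
                              std::size_t Bq, std::size_t Bk) {
    const std::size_t n = Q.size();
    assert(n > 0 && K.size() == n && V.size() == n && Bq > 0 && Bk > 0);
    const std::size_t d = Q[0].size();
    const float scale = 1.0f / std::sqrt(static_cast<float>(d));
    const float NEG_INF = -std::numeric_limits<float>::infinity();
    Matrix O(n, std::vector<float>(d, 0.0f));

    for (std::size_t q0 = 0; q0 < n; q0 += Bq) {               // one Q tile
        const std::size_t q1 = std::min(q0 + Bq, n);
        std::vector<float> m(q1 - q0, NEG_INF), s(q1 - q0, 0.0f);
        Matrix acc(q1 - q0, std::vector<float>(d, 0.0f));       // unnormalized O_tile

        for (std::size_t k0 = 0; k0 < n; k0 += Bk) {           // KV tiles, in order
            if (k0 >= q1) break;   // causal skip: all keys from k0 on follow every query
            const std::size_t k1 = std::min(k0 + Bk, n);

            for (std::size_t i = q0; i < q1; ++i) {
                // Score row for query i against keys [k0, k1); keys j > i stay -inf.
                std::vector<float> row(k1 - k0, NEG_INF);
                float row_max = NEG_INF;
                for (std::size_t j = k0; j < k1 && j <= i; ++j) {
                    float dot = 0.0f;
                    for (std::size_t c = 0; c < d; ++c) dot += Q[i][c] * K[j][c];
                    row[j - k0] = dot * scale;
                    row_max = std::max(row_max, row[j - k0]);
                }
                if (row_max == NEG_INF) continue;   // row fully masked in this tile

                // Online softmax update for row i.
                const std::size_t r = i - q0;
                const float m_new = std::max(m[r], row_max);
                const float correction = std::exp(m[r] - m_new);
                s[r] *= correction;
                for (float& x : acc[r]) x *= correction;
                for (std::size_t j = k0; j < k1; ++j) {
                    const float w = std::exp(row[j - k0] - m_new);   // masked: exp(-inf) = 0
                    s[r] += w;
                    for (std::size_t c = 0; c < d; ++c) acc[r][c] += w * V[j][c];
                }
                m[r] = m_new;
            }
        }
        for (std::size_t i = q0; i < q1; ++i)                   // normalize, write back
            for (std::size_t c = 0; c < d; ++c) O[i][c] = acc[i - q0][c] / s[i - q0];
    }
    return O;
}
```

Run on the n = 6 example worked by hand later in this chapter with Bq = 2 and Bk = 3, it reproduces O[0] = [1.0, 0.0] and O[1] ≈ [0.449, 0.551]. Any other tile sizes give the same rows up to rounding.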

Let us visualize how the tiles march through the attention matrix:

  Tiling Visualization (n=16, Bq=4, Bk=4)
  ==========================================

  Full attention matrix [16 x 16] with causal mask:

     KV tiles:  0    1    2    3
             +----+----+----+----+
  Q tile 0:  | XX | .. | .. | .. |    XX = computed tile
             +----+----+----+----+    .. = skipped (causal)
  Q tile 1:  | XX | XX | .. | .. |
             +----+----+----+----+
  Q tile 2:  | XX | XX | XX | .. |
             +----+----+----+----+
  Q tile 3:  | XX | XX | XX | XX |
             +----+----+----+----+

  Processing order for Q tile 2:
  1. Load Q[8:12]                    (4 queries)
  2. Load K[0:4], V[0:4]            -> compute 4x4 scores, update
  3. Load K[4:8], V[4:8]            -> compute 4x4 scores, update
  4. Load K[8:12], V[8:12]          -> compute 4x4 scores (with mask), update
  5. Skip K[12:16] (causally masked) -> all queries < all keys
  6. Normalize and write O[8:12]

  At no point do we have more than a 4x4 score tile in memory.
  Memory: O(16 * d) for O instead of O(16 * 16) for the full matrix.
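The skip rule can be tabulated directly. Below is a C++ sketch that marks each (Q tile, KV tile) pair as computed or skipped, using the same test as the pseudocode. The helper names are hypothetical:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Marks which (Q tile, KV tile) pairs a causal FlashAttention pass computes:
// 1 = computed, 0 = skipped because every key in the KV tile comes after
// every query in the Q tile. Same condition as the pseudocode above.
std::vector<std::vector<int>> causal_tile_schedule(std::size_t n, std::size_t Bq, std::size_t Bk) {
    assert(n > 0 && Bq > 0 && Bk > 0);
    const std::size_t nq = (n + Bq - 1) / Bq;   // ceil(n / Bq)
    const std::size_t nk = (n + Bk - 1) / Bk;   // ceil(n / Bk)
    std::vector<std::vector<int>> grid(nq, std::vector<int>(nk, 0));
    for (std::size_t qt = 0; qt < nq; ++qt)
        for (std::size_t kt = 0; kt < nk; ++kt)
            grid[qt][kt] = ((qt + 1) * Bq <= kt * Bk) ? 0 : 1;
    return grid;
}

std::size_t count_computed(const std::vector<std::vector<int>>& grid) {
    std::size_t total = 0;
    for (const auto& row : grid)
        for (int flag : row) total += static_cast<std::size_t>(flag);
    return total;
}
```

For n = 16 and Bq = Bk = 4, it marks 10 of the 16 tiles as computed, matching the diagram. For the n = 6, Bq = 2, Bk = 3 example in the next section, it skips KV tile 1 for Q tile 0.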

Memory Analysis: O(n) Instead of O(n^2)

Let us carefully count the memory used by FlashAttention:

  FlashAttention Memory Usage
  ============================

  Threadgroup memory (SRAM):
    Q_tile:      Bq * d * sizeof(half)     = 32 * 128 * 2  =   8 KB
    K_tile:      Bk * d * sizeof(half)     = 32 * 128 * 2  =   8 KB
    V_tile:      Bk * d * sizeof(half)     = 32 * 128 * 2  =   8 KB
    S_tile:      Bq * Bk * sizeof(float)   = 32 * 32 * 4   =   4 KB
    Total SRAM:                                             ~  28 KB

  Per-query state (registers):
    m[Bq]:       Bq * sizeof(float)        = 32 * 4        = 128 bytes
    s[Bq]:       Bq * sizeof(float)        = 32 * 4        = 128 bytes
    O_tile:      Bq * d * sizeof(float)    = 32 * 128 * 4  =  16 KB
    Total registers:                                        ~  16 KB

  Device memory:
    Input Q, K, V:  3 * n * d * sizeof(half)               = O(n * d)
    Output O:       n * d * sizeof(half)                    = O(n * d)
    NO [n x n] attention matrix stored!                     (no O(n^2) term)

  Comparison for n = 32768, d = 128:
    Standard:      n^2 * 4 bytes = 4 GB per head
    FlashAttention: n * d * 2 bytes = 8 MB per head   (512x reduction!)
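Below is the same table as arithmetic, in a short C++ sketch with illustrative function names. Like the table above, it assumes FP16 Q, K, V tiles and an FP32 score tile:

```cpp
#include <cassert>
#include <cstdint>

// Threadgroup (SRAM) bytes for one set of tiles, as in the table above:
// Q, K, V tiles in FP16 (2 bytes each) plus an FP32 score tile (4 bytes each).
std::uint64_t sram_bytes(std::uint64_t Bq, std::uint64_t Bk, std::uint64_t d) {
    assert(Bq > 0 && Bk > 0 && d > 0);
    const std::uint64_t half_bytes = 2, float_bytes = 4;
    return (Bq * d + 2 * Bk * d) * half_bytes + Bq * Bk * float_bytes;
}

// Per-head device memory beyond the inputs: the FP32 [n x n] matrix that
// standard attention materializes vs. the FP16 [n x d] output FlashAttention writes.
std::uint64_t standard_matrix_bytes(std::uint64_t n) { return n * n * 4; }
std::uint64_t flash_output_bytes(std::uint64_t n, std::uint64_t d) { return n * d * 2; }
```

sram_bytes(32, 32, 128) comes to 28,672 bytes (28 KB). At n = 32768, d = 128, the ratio of the two device-memory figures is exactly 512.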

A Complete Numerical Example with Tiling

Let us trace FlashAttention on a small example: n=6, d=2, Bq=2, Bk=3.

  Setup
  ======

  Q = [ 1.0,  0.5 ]     K = [ 0.3,  0.7 ]     V = [ 1.0,  0.0 ]
      [ 0.8, -0.1 ]         [ 0.6,  0.2 ]         [ 0.0,  1.0 ]
      [ 0.2,  0.9 ]         [-0.1,  0.8 ]         [ 0.5,  0.5 ]
      [-0.3,  0.4 ]         [ 0.4, -0.3 ]         [ 0.8,  0.2 ]
      [ 0.7,  0.6 ]         [ 0.9,  0.1 ]         [ 0.3,  0.7 ]
      [ 0.1, -0.5 ]         [ 0.2,  0.5 ]         [ 0.6,  0.4 ]

  sqrt(d) = sqrt(2) = 1.414

  Tile sizes: Bq = 2 (queries), Bk = 3 (keys)

Q tile 0 (queries 0-1), KV tile 0 (keys 0-2):

  Q_tile = [ 1.0,  0.5 ]     K_tile = [ 0.3,  0.7 ]     V_tile = [ 1.0,  0.0 ]
           [ 0.8, -0.1 ]              [ 0.6,  0.2 ]              [ 0.0,  1.0 ]
                                       [-0.1,  0.8 ]              [ 0.5,  0.5 ]

  S_tile = Q_tile * K_tile^T / 1.414:

  Q[0].K[0] = 1.0*0.3 + 0.5*0.7 = 0.65    /1.414 = 0.460
  Q[0].K[1] = 1.0*0.6 + 0.5*0.2 = 0.70    /1.414 = 0.495
  Q[0].K[2] = 1.0*(-0.1)+0.5*0.8 = 0.30   /1.414 = 0.212
  Q[1].K[0] = 0.8*0.3+(-0.1)*0.7 = 0.17   /1.414 = 0.120
  Q[1].K[1] = 0.8*0.6+(-0.1)*0.2 = 0.46   /1.414 = 0.325
  Q[1].K[2] = 0.8*(-0.1)+(-0.1)*0.8=-0.16 /1.414 = -0.113

  S_tile = [ 0.460   0.495   0.212 ]
           [ 0.120   0.325  -0.113 ]

  Causal mask (query 0 sees key 0 only, query 1 sees keys 0-1):
  S_tile = [ 0.460  -inf    -inf   ]
           [ 0.120   0.325  -inf   ]

  Online softmax update (initial m=-inf, s=0, O=0):

  Row 0:
    m_new = max(-inf, 0.460) = 0.460
    correction = exp(-inf - 0.460) = 0
    O[0] = [0,0]*0 = [0, 0]
    s[0] = 0*0 = 0
    weight for key 0: exp(0.460 - 0.460) = 1.0
    O[0] += 1.0 * V[0] = [1.0, 0.0]
    s[0] += 1.0 = 1.0
    (keys 1,2 masked: weight = exp(-inf) = 0)
    m[0] = 0.460

  Row 1:
    m_new = max(-inf, max(0.120, 0.325)) = 0.325
    correction = 0
    weight for key 0: exp(0.120 - 0.325) = exp(-0.205) = 0.8146
    weight for key 1: exp(0.325 - 0.325) = 1.0
    O[1] = 0.8146*[1,0] + 1.0*[0,1] = [0.8146, 1.0]
    s[1] = 0.8146 + 1.0 = 1.8146
    m[1] = 0.325

Q tile 0 (queries 0-1), KV tile 1 (keys 3-5):

  K_tile = [ 0.4, -0.3 ]     V_tile = [ 0.8,  0.2 ]
           [ 0.9,  0.1 ]              [ 0.3,  0.7 ]
           [ 0.2,  0.5 ]              [ 0.6,  0.4 ]

  But wait -- query 0 is at position 0, and keys 3-5 are all in the future.
  Query 1 is at position 1, and keys 3-5 are also in the future.
  ALL entries would be masked to -inf.

  We can skip this entire KV tile! (Causal skip optimization)

After processing all KV tiles for Q tile 0, normalize:

  O[0] = [1.0, 0.0] / 1.0 = [1.0, 0.0]
  O[1] = [0.8146, 1.0] / 1.8146 = [0.449, 0.551]

This is exactly what standard attention would produce (query 0 only sees key 0, so it copies V[0]; query 1 sees keys 0-1 with softmax weights).

The remaining Q tiles (1 and 2) would be processed similarly, each iterating over the relevant KV tiles.
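The claim that the tiled result matches standard attention can be checked directly. Below is a C++ sketch of plain causal attention, using the hypothetical helper name `standard_attention_causal`. Run it on the same Q, K, V as the trace above:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<float>>;   // row-major, [rows][cols]

// Standard (untiled) causal attention, for cross-checking the tiled trace.
// Each query's full score row is built, softmaxed with max subtraction,
// and then used to weight the rows of V.
Matrix standard_attention_causal(const Matrix& Q, const Matrix& K, const Matrix& V) {
    assert(!Q.empty() && K.size() == Q.size() && V.size() == Q.size());
    const std::size_t n = Q.size(), d = Q[0].size(), dv = V[0].size();
    const float scale = 1.0f / std::sqrt(static_cast<float>(d));
    Matrix O(n, std::vector<float>(dv, 0.0f));
    for (std::size_t i = 0; i < n; ++i) {
        std::vector<float> p(i + 1);                  // query i attends to keys 0..i
        for (std::size_t j = 0; j <= i; ++j) {
            float dot = 0.0f;
            for (std::size_t c = 0; c < d; ++c) dot += Q[i][c] * K[j][c];
            p[j] = dot * scale;
        }
        const float m = *std::max_element(p.begin(), p.end());
        float s = 0.0f;
        for (float& x : p) { x = std::exp(x - m); s += x; }
        for (std::size_t j = 0; j <= i; ++j)
            for (std::size_t c = 0; c < dv; ++c) O[i][c] += (p[j] / s) * V[j][c];
    }
    return O;
}
```

Every row of this V sums to 1, so every output row must as well, which makes a handy sanity check alongside O[1] ≈ [0.449, 0.551].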

Decode Variant: FlashDecoding

During decoding, we have a single query (M=1) attending to the entire KV cache. The FlashAttention algorithm simplifies because there is only one query:

  FlashDecoding (single query)
  =============================

  q: [1 x d]                     (single query vector)
  K_cache: [seq_len x d]         (all past keys)
  V_cache: [seq_len x d]         (all past values)

  Tiling: split KV cache into tiles of Bk positions

  +-------+-------+-------+-------+-------+
  | Tile0 | Tile1 | Tile2 | Tile3 | Tile4 |   K/V cache
  +-------+-------+-------+-------+-------+
   0..31   32..63  64..95  96..127  ...

  Each tile is processed by one threadgroup (or SIMD group).
  All tiles can be processed IN PARALLEL since each
  produces an independent (m, s, o) tuple.

  Then a final reduction step merges the per-tile results.

The parallel structure is different from prefill FlashAttention, where the KV tiles must be processed sequentially (because the running softmax state is updated incrementally). In decode, each tile independently produces a partial result:

  Parallel FlashDecoding
  =======================

  Phase 1: Each tile independently computes (m_t, s_t, o_t)

  Tile 0: q vs K[0:32]   -> (m_0, s_0, o_0)
  Tile 1: q vs K[32:64]  -> (m_1, s_1, o_1)
  Tile 2: q vs K[64:96]  -> (m_2, s_2, o_2)
  ...

  Phase 2: Merge all tiles' results

  merge(tile_i, tile_j):
    m_new = max(m_i, m_j)
    corr_i = exp(m_i - m_new)
    corr_j = exp(m_j - m_new)
    s_new = s_i * corr_i + s_j * corr_j
    o_new = (o_i * s_i * corr_i + o_j * s_j * corr_j) / s_new
    (or equivalently, keep unnormalized and normalize at the end)

  This merge is associative! Can be done in a tree reduction.

// FlashDecoding: Phase 1 -- per-tile computation
// Dispatch assumption: one threadgroup of 32 threads (a single SIMD group) per
// KV tile of Bk = 32 positions, and head dimension d <= 128. Each lane owns the
// elements i = lane, lane + 32, lane + 64, lane + 96 of q and of the output.
#include <metal_stdlib>
using namespace metal;

kernel void flash_decode_phase1(
    device const half*  q          [[buffer(0)]],    // [d]
    device const half*  K_cache    [[buffer(1)]],    // [seq_len x d]
    device const half*  V_cache    [[buffer(2)]],    // [seq_len x d]
    device float*       tile_max   [[buffer(3)]],    // [num_tiles]
    device float*       tile_sum   [[buffer(4)]],    // [num_tiles]
    device float*       tile_out   [[buffer(5)]],    // [num_tiles x d], unnormalized
    constant uint&      seq_len    [[buffer(6)]],
    constant uint&      d          [[buffer(7)]],
    uint                lane       [[thread_index_in_simdgroup]],
    uint                tg_id      [[threadgroup_position_in_grid]])
{
    constexpr uint Bk = 32;
    constexpr uint SIMD_W = 32;
    constexpr uint PER_LANE = 128 / SIMD_W;   // max d / lanes = 4 elements per lane

    uint k_start = tg_id * Bk;
    uint k_end = min(k_start + Bk, seq_len);
    float scale = rsqrt(float(d));

    // Load this lane's slice of q into registers once; zero the accumulator.
    float q_reg[PER_LANE];
    float o[PER_LANE];
    for (uint c = 0; c < PER_LANE; c++) {
        uint i = lane + c * SIMD_W;
        q_reg[c] = (i < d) ? float(q[i]) : 0.0f;
        o[c] = 0.0f;
    }

    float m = -INFINITY;
    float s = 0.0f;

    for (uint k = k_start; k < k_end; k++) {
        // score = (q . K[k]) / sqrt(d): per-lane partial dot product, then SIMD reduction
        float partial = 0.0f;
        for (uint c = 0; c < PER_LANE; c++) {
            uint i = lane + c * SIMD_W;
            if (i < d) partial += q_reg[c] * float(K_cache[k * d + i]);
        }
        float score = simd_sum(partial) * scale;

        // Online softmax update (exp(-INFINITY) = 0 on the first key)
        float m_new = max(m, score);
        float correction = exp(m - m_new);
        float weight = exp(score - m_new);

        s = s * correction + weight;
        for (uint c = 0; c < PER_LANE; c++) {
            uint i = lane + c * SIMD_W;
            if (i < d) o[c] = o[c] * correction + weight * float(V_cache[k * d + i]);
        }
        m = m_new;
    }

    // Write per-tile results; o stays unnormalized, phase 2 divides by the merged sum
    if (lane == 0) {
        tile_max[tg_id] = m;
        tile_sum[tg_id] = s;
    }
    for (uint c = 0; c < PER_LANE; c++) {
        uint i = lane + c * SIMD_W;
        if (i < d) tile_out[tg_id * d + i] = o[c];
    }
}

// FlashDecoding: Phase 2 -- merge tile results
// Dispatch assumption: a single SIMD group of 32 threads; every lane walks all
// tiles and merges its own slice of the output (d <= 128).
kernel void flash_decode_phase2(
    device const float* tile_max   [[buffer(0)]],    // [num_tiles]
    device const float* tile_sum   [[buffer(1)]],    // [num_tiles]
    device const float* tile_out   [[buffer(2)]],    // [num_tiles x d], unnormalized
    device half*        output     [[buffer(3)]],    // [d]
    constant uint&      num_tiles  [[buffer(4)]],
    constant uint&      d          [[buffer(5)]],
    uint                lane       [[thread_index_in_simdgroup]])
{
    constexpr uint SIMD_W = 32;
    constexpr uint PER_LANE = 128 / SIMD_W;

    float m = -INFINITY;
    float s = 0.0f;
    float o[PER_LANE];
    for (uint c = 0; c < PER_LANE; c++) o[c] = 0.0f;

    for (uint t = 0; t < num_tiles; t++) {
        float m_t = tile_max[t];
        float s_t = tile_sum[t];
        if (s_t == 0.0f) continue;   // empty tile (m_t = -inf): skip, avoids -inf - (-inf) = NaN

        float m_new = max(m, m_t);
        float corr_old = exp(m - m_new);
        float corr_new = exp(m_t - m_new);

        // Merge sums
        s = s * corr_old + s_t * corr_new;

        // Merge unnormalized outputs
        for (uint c = 0; c < PER_LANE; c++) {
            uint i = lane + c * SIMD_W;
            if (i < d) o[c] = o[c] * corr_old + tile_out[t * d + i] * corr_new;
        }

        m = m_new;
    }

    // Normalize and write
    for (uint c = 0; c < PER_LANE; c++) {
        uint i = lane + c * SIMD_W;
        if (i < d) output[i] = half(o[c] / s);
    }
}
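The two kernels can be mirrored on the CPU to check the split-and-merge arithmetic. To stay short, this C++ sketch takes precomputed scores (q . K[j] / sqrt(d)). `Partial`, `tile_partial`, `merge_partials`, and `flash_decode` are illustrative names. Like the kernels, it keeps each tile's output unnormalized and divides only once at the end:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Partial result for one KV tile: running max, sum of exp, unnormalized output.
struct Partial {
    float m = -std::numeric_limits<float>::infinity();
    float s = 0.0f;
    std::vector<float> o;
};

// Phase 1 for one tile: online softmax over scores[k0, k1) and value rows V[k0, k1).
Partial tile_partial(const std::vector<float>& scores,
                     const std::vector<std::vector<float>>& V, std::size_t k0, std::size_t k1) {
    assert(k0 < k1 && k1 <= scores.size() && scores.size() == V.size());
    Partial p;
    p.o.assign(V[k0].size(), 0.0f);
    for (std::size_t k = k0; k < k1; ++k) {
        const float m_new = std::max(p.m, scores[k]);
        const float corr = std::exp(p.m - m_new);
        const float w = std::exp(scores[k] - m_new);
        p.s = p.s * corr + w;
        for (std::size_t c = 0; c < p.o.size(); ++c) p.o[c] = p.o[c] * corr + w * V[k][c];
        p.m = m_new;
    }
    return p;
}

// Phase 2 building block: merge two partials kept in unnormalized form.
Partial merge_partials(const Partial& a, const Partial& b) {
    if (a.s == 0.0f) return b;   // empty partial contributes nothing (avoids -inf - (-inf))
    if (b.s == 0.0f) return a;
    Partial r;
    r.m = std::max(a.m, b.m);
    const float ca = std::exp(a.m - r.m), cb = std::exp(b.m - r.m);
    r.s = a.s * ca + b.s * cb;
    r.o.resize(a.o.size());
    for (std::size_t c = 0; c < r.o.size(); ++c) r.o[c] = a.o[c] * ca + b.o[c] * cb;
    return r;
}

// Split into tiles of Bk, compute each partial independently, merge, normalize once.
std::vector<float> flash_decode(const std::vector<float>& scores,
                                const std::vector<std::vector<float>>& V, std::size_t Bk) {
    assert(!scores.empty() && scores.size() == V.size() && Bk > 0);
    std::vector<Partial> parts;                       // independent: could run in parallel
    for (std::size_t k0 = 0; k0 < scores.size(); k0 += Bk)
        parts.push_back(tile_partial(scores, V, k0, std::min(k0 + Bk, scores.size())));
    Partial total = parts[0];
    for (std::size_t t = 1; t < parts.size(); ++t) total = merge_partials(total, parts[t]);
    std::vector<float> out = total.o;
    for (float& x : out) x /= total.s;
    return out;
}
```

Any tile size gives the same output. Merging as ((p0, p1), p2) or as (p0, (p1, p2)) also gives the same result up to rounding. That associativity is what permits a tree reduction.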

Tree Masking for Speculative Decoding

Speculative decoding is an optimization where a smaller “draft” model proposes several tokens ahead, and the larger model verifies them all at once. The verification step requires attention with a tree mask instead of the standard causal mask.

In speculative decoding, the draft model might produce a tree of candidates:

  Speculative Decoding Tree
  ==========================

  Position:  0  1  2  3  4  5  6  7  8
  Token:     A  B  C  D  E  F  G  H  I
  Parent:    -  0  1  1  2  2  3  3  4

  Tree structure:
       A
       |
       B
      / \
     C   D
    /|   |\
   E  F  G  H
   |
   I

  Attention mask for this tree (1 = can attend, 0 = masked):

       A  B  C  D  E  F  G  H  I
  A  [ 1  0  0  0  0  0  0  0  0 ]
  B  [ 1  1  0  0  0  0  0  0  0 ]
  C  [ 1  1  1  0  0  0  0  0  0 ]
  D  [ 1  1  0  1  0  0  0  0  0 ]
  E  [ 1  1  1  0  1  0  0  0  0 ]
  F  [ 1  1  1  0  0  1  0  0  0 ]
  G  [ 1  1  0  1  0  0  1  0  0 ]
  H  [ 1  1  0  1  0  0  0  1  0 ]
  I  [ 1  1  1  0  1  0  0  0  1 ]

  Note: D can see A and B (its ancestors) but NOT C (sibling).
  E can see A, B, C (ancestors along its path) but NOT D, F, G, H.

FlashAttention handles tree masks by passing the mask as an additional input. During the tiled computation, the mask is consulted for each score tile:

  Tree Mask in FlashAttention
  ============================

  For each score tile S_tile[Bq x Bk]:
    For each (i, j):
      query_pos = q_tile * Bq + i
      key_pos   = kv_tile * Bk + j
      if mask[query_pos][key_pos] == 0:
        S_tile[i][j] = -infinity

  The mask is typically stored as a bitmask for efficiency:
    1 bit per (query, key) pair
    For n=256: 256 * 256 / 8 = 8 KB

  With standard causal mask, this is implicit (just compare positions).
  With tree mask, we need the explicit bitmask.
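Building the bitmask from the parent array is a simple recurrence: a node sees everything its parent sees, plus itself. Below is a C++ sketch with hypothetical names. It uses one 64-bit word per row and is therefore limited to 64 tree nodes; a larger tree would need several words per row:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Build a tree attention mask from a parent array (parent[0] = -1 for the root).
// Row i is a bitmask over keys: bit j is set when j is i itself or an ancestor of i.
std::vector<std::uint64_t> tree_mask(const std::vector<int>& parent) {
    const std::size_t n = parent.size();
    assert(n <= 64);                               // one 64-bit word per row
    std::vector<std::uint64_t> rows(n, 0);
    for (std::size_t i = 0; i < n; ++i) {
        assert(parent[i] < static_cast<int>(i));   // parents precede children
        const std::uint64_t inherited =
            parent[i] >= 0 ? rows[static_cast<std::size_t>(parent[i])] : 0;
        rows[i] = inherited | (std::uint64_t{1} << i);
    }
    return rows;
}

// The per-score test used inside the tiled loop: may query i attend to key j?
bool can_attend(const std::vector<std::uint64_t>& rows, std::size_t i, std::size_t j) {
    return (rows[i] >> j) & 1u;
}
```

For the nine-node tree above, row D is 0b1011 (A, B, D) and row I sets bits 0, 1, 2, 4, 8 (A, B, C, E, I), matching the table. A linear chain of parents reproduces the ordinary causal mask.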

GQA Handling in FlashAttention

When using Grouped Query Attention, multiple query heads share the same KV head. The FlashAttention kernel needs to route each query head to the correct KV head:

  GQA in FlashAttention
  ======================

  n_q_heads = 32, n_kv_heads = 8, group_size = 4

  Query heads  0, 1, 2, 3   --> KV head 0
  Query heads  4, 5, 6, 7   --> KV head 1
  Query heads  8, 9, 10, 11 --> KV head 2
  ...
  Query heads 28, 29, 30, 31 --> KV head 7

  Implementation options:

  Option A: One kernel launch per KV head, process 4 Q heads
  +-----------+
  | Q heads   |  Process heads 0,1,2,3 together
  | 0,1,2,3   |  They all read the SAME K,V tiles
  |           |  Load K,V once, apply to 4 query sets
  +-----------+

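  Option B: One threadgroup (or launch) per Q head
  Each Q head kernel computes its kv_head_id = q_head / group_size
  and indexes into the correct KV cache slice. Simpler, but the
  group_size heads each load the same K,V tiles separately
  (any reuse comes only from the GPU cache).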

The efficient approach processes all query heads in a group simultaneously, loading each KV tile once:

// GQA FlashAttention: process one group of query heads
kernel void flash_attn_gqa(
    device const half*  Q          [[buffer(0)]],    // [n_q_heads x seq x d]
    device const half*  K_cache    [[buffer(1)]],    // [n_kv_heads x seq x d]
    device const half*  V_cache    [[buffer(2)]],    // [n_kv_heads x seq x d]
    device half*        O          [[buffer(3)]],    // [n_q_heads x seq x d]
    constant uint&      kv_head    [[buffer(4)]],    // Which KV head
    constant uint&      group_size [[buffer(5)]],    // Typically 4
    // ... other params ...
    uint                sgid       [[simdgroup_index_in_threadgroup]],
    uint                tg_id      [[threadgroup_position_in_grid]])
{
    // Determine which query head this SIMD group handles
    // Within one threadgroup, different SIMD groups handle different Q heads
    uint q_head_in_group = sgid % group_size;
    uint q_head = kv_head * group_size + q_head_in_group;

    // K, V come from kv_head (shared by all Q heads in group)
    device const half* K = K_cache + kv_head * seq_len * d;
    device const half* V = V_cache + kv_head * seq_len * d;

    // Q comes from this specific query head
    device const half* my_Q = Q + q_head * seq_len * d;

    // Now run standard FlashAttention on (my_Q, K, V)
    // K and V tiles are loaded once and shared across SIMD groups
    // within the threadgroup (via threadgroup memory)
    // ...
}
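The kernel above elides the tile loop. On the CPU, the same sharing pattern can be sketched for a single decode step: each KV head's rows are read once and applied to all `group_size` query heads that map to it via q_head / group_size. This is a hypothetical C++ reference (`gqa_decode` is an illustrative name). It has no tiling and assumes d_v equals d:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

using Matrix = std::vector<std::vector<float>>;   // [rows][d]

// Decode-time GQA attention for one new token.
// q: [n_q_heads][d]; K, V: one [seq][d] matrix per KV head.
// Query head h uses KV head h / group_size. Each K/V row is read once per KV
// head, then reused by all group_size query heads of that group.
Matrix gqa_decode(const Matrix& q, const std::vector<Matrix>& K,
                  const std::vector<Matrix>& V, std::size_t group_size) {
    const std::size_t n_kv = K.size();
    assert(group_size > 0 && !q.empty() && q.size() == n_kv * group_size && V.size() == n_kv);
    const std::size_t d = q[0].size();
    const float scale = 1.0f / std::sqrt(static_cast<float>(d));
    Matrix out(q.size(), std::vector<float>(d, 0.0f));   // assumes d_v == d

    for (std::size_t kv = 0; kv < n_kv; ++kv) {
        assert(!K[kv].empty() && K[kv].size() == V[kv].size());
        std::vector<float> m(group_size, -std::numeric_limits<float>::infinity());
        std::vector<float> s(group_size, 0.0f);
        for (std::size_t j = 0; j < K[kv].size(); ++j) {
            const std::vector<float>& k = K[kv][j];          // loaded once ...
            const std::vector<float>& v = V[kv][j];
            for (std::size_t g = 0; g < group_size; ++g) {    // ... used by every head in the group
                const std::size_t h = kv * group_size + g;
                float score = 0.0f;
                for (std::size_t c = 0; c < d; ++c) score += q[h][c] * k[c];
                score *= scale;
                const float m_new = std::fmax(m[g], score);
                const float corr = std::exp(m[g] - m_new);
                const float w = std::exp(score - m_new);
                s[g] = s[g] * corr + w;
                for (std::size_t c = 0; c < d; ++c) out[h][c] = out[h][c] * corr + w * v[c];
                m[g] = m_new;
            }
        }
        for (std::size_t g = 0; g < group_size; ++g)
            for (float& x : out[kv * group_size + g]) x /= s[g];
    }
    return out;
}
```

Duplicating each KV head group_size times and running with group_size = 1 (plain multi-head attention) produces identical outputs. The grouped version simply reads the K/V cache once per KV head instead of once per query head.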

Numerical Stability Deep Dive

The exp and division operations in softmax are numerically treacherous. Let us examine the potential pitfalls and how online softmax handles them:

  Numerical Issues in Softmax
  ============================

  Problem 1: exp overflow
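    exp(88)  = 1.65e38   (fine: max FP32 is 3.40e38)
    exp(89)  = 4.49e38   -> +inf in FP32 (OVERFLOW!)
    exp(100) = 2.69e43   -> +inf in FP32 (OVERFLOW!)
    (FP16 is far tighter: its max is 65504, so exp(12) = 1.6e5 already overflows)

  Solution: subtract the max before exp
    If max(z) = 100 and z[i] = 95:
      exp(95) = 1.8e41 -> +inf (overflow)
      exp(95 - 100) = exp(-5) = 0.0067   (safe!)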

  Problem 2: exp underflow (less critical)
    exp(-100) = 3.72e-44  (denormal in FP32)
    exp(-104) = 0.0       (underflow to zero)

  This is usually fine -- those entries get negligible weight anyway.

  Online softmax and stability:
    We always compute exp(z[i] - m) where m = running max.
    Since m >= z[i] for all previously seen values:
      z[i] - m <= 0    (always non-positive)
      exp(z[i] - m) <= 1.0    (never overflows!)

  The correction factor exp(m_old - m_new):
    m_old <= m_new   (max only increases)
    m_old - m_new <= 0
    exp(m_old - m_new) <= 1.0   (never overflows!)

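  Summary: online softmax never overflows in exp, as long as each row
  has seen at least one unmasked score. (If every score so far is -inf,
  then m_new = -inf and exp(m - m_new) = exp(-inf - (-inf)) = NaN, so
  kernels guard fully masked rows explicitly.)
  Standard softmax without the max-subtraction trick can overflow.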
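The overflow argument is easy to see in code. Below is a C++ sketch with illustrative names that contrasts softmax without and with max subtraction on FP32 scores of 100 and 95. It assumes IEEE 754 arithmetic, compiled without fast-math flags:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Softmax WITHOUT max subtraction. In FP32, exp(z) overflows to +inf once
// z exceeds about 88.7, and the division then yields inf / inf = NaN.
std::vector<float> softmax_naive(const std::vector<float>& z) {
    std::vector<float> e(z.size());
    float s = 0.0f;
    for (std::size_t i = 0; i < z.size(); ++i) { e[i] = std::exp(z[i]); s += e[i]; }
    for (float& x : e) x /= s;
    return e;
}

// Softmax WITH max subtraction: every exponent is <= 0, so each exp() is <= 1.
std::vector<float> softmax_stable(const std::vector<float>& z) {
    assert(!z.empty());
    const float m = *std::max_element(z.begin(), z.end());
    std::vector<float> e(z.size());
    float s = 0.0f;
    for (std::size_t i = 0; i < z.size(); ++i) { e[i] = std::exp(z[i] - m); s += e[i]; }
    for (float& x : e) x /= s;
    return e;
}
```

On [100, 95], the naive version returns NaN for both entries because both exponentials overflow to +inf. The stable version returns approximately [0.9933, 0.0067].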

Practical Performance Considerations

Tile Size Selection

The choice of Bq and Bk affects performance significantly:

  Tile Size Tradeoffs
  ====================

  Larger tiles (Bq=64, Bk=64):
    + More compute per memory load (better arithmetic intensity)
    + Better SIMD group MMA utilization
    - More threadgroup memory needed
    - Fewer threadgroups can run concurrently (occupancy)
    - More wasted compute at sequence boundaries

  Smaller tiles (Bq=16, Bk=32):
    + Less threadgroup memory
    + Better occupancy
    + Less waste at boundaries
    - Lower arithmetic intensity
    - More kernel overhead per tile

  Common starting points on Apple Silicon (tune per chip and head dim):
    Prefill: Bq=32, Bk=32 or Bq=64, Bk=64
    Decode:  Bk=32 or Bk=64 (Bq=1)

Memory Bandwidth vs Compute

  FlashAttention Performance Model
  =================================

  Prefill (n=4096, d=128, Bq=32, Bk=32):

  Per Q-tile iteration (processing one KV tile):
    Load K_tile: 32 * 128 * 2 = 8 KB
    Load V_tile: 32 * 128 * 2 = 8 KB
    Compute S:   2 * 32 * 32 * 128 = 262K FLOPs
    Compute O:   2 * 32 * 32 * 128 = 262K FLOPs
    Total compute: ~524K FLOPs
    Total memory: ~16 KB

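    Arithmetic intensity: 524K / 16K = 32 FLOPs/byte
    That is 32x the decode figure below. Whether it is fully
    compute-bound depends on the GPU's ratio of peak FLOPs to
    memory bandwidth, but it is far less bandwidth-limited than decode.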

  Decode (n=4096, d=128, Bk=32):
    Per KV tile:
      Load K_tile: 32 * 128 * 2 = 8 KB
      Load V_tile: 32 * 128 * 2 = 8 KB
      Compute scores: 2 * 1 * 32 * 128 = 8K FLOPs
      Compute output: 2 * 1 * 32 * 128 = 8K FLOPs
      Total compute: ~16K FLOPs
      Total memory: ~16 KB

    Arithmetic intensity: 16K / 16K = 1 FLOP/byte
    Bandwidth-bound (expected for decode)

    Total KV cache read: 2 * 4096 * 128 * 2 = 2 MB per head
    At 200 GB/s: ~0.01 ms per head, ~0.34 ms for 32 heads
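The performance model above reduces to a few lines of arithmetic. Below is a C++ sketch with hypothetical names. It follows the same assumptions: two matmuls per step, FP16 K and V tiles, and no reuse of loaded tiles across query tiles:

```cpp
#include <cassert>
#include <cstdint>

// FLOPs and bytes for one (query tile, KV tile) step of the model above.
struct StepCost {
    std::uint64_t flops;
    std::uint64_t bytes;
};

StepCost step_cost(std::uint64_t Bq, std::uint64_t Bk, std::uint64_t d) {
    assert(Bq > 0 && Bk > 0 && d > 0);
    const std::uint64_t matmul = 2 * Bq * Bk * d;   // one of S = Q K^T or O += P V
    const std::uint64_t kv_bytes = 2 * Bk * d * 2;  // K and V tiles, 2 bytes per FP16 value
    return {2 * matmul, kv_bytes};
}

double intensity(const StepCost& c) {
    return static_cast<double>(c.flops) / static_cast<double>(c.bytes);
}

// Milliseconds to stream one head's FP16 K and V cache once at gb_per_s (10^9 B/s).
double kv_read_ms(std::uint64_t seq_len, std::uint64_t d, double gb_per_s) {
    assert(gb_per_s > 0.0);
    const double bytes = 2.0 * static_cast<double>(seq_len * d) * 2.0;
    return bytes / (gb_per_s * 1e9) * 1e3;
}
```

In this model the arithmetic intensity simplifies to 4 * Bq * Bk * d / (4 * Bk * d) = Bq FLOPs per byte. That is why decode (Bq = 1) sits at 1 FLOP/byte, and why larger query tiles raise prefill's intensity.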

Summary

FlashAttention transforms attention from an O(n^2) memory algorithm to an O(n) memory algorithm without changing the mathematical result. The key ideas are:

  1. Online softmax enables streaming computation by maintaining a running maximum and sum, with correction factors applied when the maximum changes. This is numerically stable by construction.

  2. Tiling divides the KV sequence into blocks of size Bk. For each query tile, we iterate over KV tiles, computing a small score matrix in registers/threadgroup memory and updating the running output.

  3. Memory savings are dramatic: from 4 GB per head (n=32768) to about 8 MB per head. This is what makes 128K+ context lengths feasible.

  4. Decode variant (FlashDecoding) processes KV tiles in parallel since there is only one query, then merges results using the same online softmax correction. This maximizes GPU utilization for the bandwidth-bound decode phase.

  5. Tree masking extends the algorithm for speculative decoding by replacing the simple causal mask comparison with a lookup into an explicit bitmask.

  6. GQA integration shares KV tile loads across multiple query heads in the same group, reducing redundant memory traffic.

The combination of FlashAttention with quantized KV caches and GQA is what makes modern long-context LLM inference practical on consumer GPUs. Without these techniques, even a 4096-token context would strain memory; with them, 128K-token contexts become feasible.