FlashAttention Kernels
Attention is the defining operation of the transformer architecture, and getting it right on Apple Silicon is perhaps the most challenging kernel engineering task in Akunu. The attention operation computes O = softmax(Q @ K^T / sqrt(d)) @ V, but a naive implementation would materialize the full [seq_len, kv_seq_len] attention matrix, which is quadratic in memory. FlashAttention avoids this by tiling the computation and maintaining an online softmax that never materializes the full matrix.1
Akunu implements five attention kernels, each optimized for a different scenario:
| Kernel | File | Use Case | Dispatch | Threads |
|---|---|---|---|---|
| Standard Decode | flash_attention.metal | Short-context decode, short prefill | (seq_len, n_heads) | up to 1024 |
| Fast Decode | flash_attention_decode_fast.metal | M=1 single-token decode, medium context | (1, n_heads) | 32 (1 SG) |
| Parallel Decode | flash_attention_decode_parallel.metal | M=1 decode, very long context | (1, n_heads) | 1024 (32 SG) |
| Prefill V1 | flash_attention.metal | Medium-length prefill | (ceil(S/NQ), n_heads) | 128 (4 SG) |
| Prefill V2 | flash_attention_prefill_v2.metal | Long prefill (>= 1024 tokens) | (ceil(S/32), n_heads) | 128 (4 SG) |
Plus a standalone Softmax kernel in softmax.metal.
This chapter covers all of them in detail – their algorithms, thread assignments, memory strategies, and performance characteristics.
The Online Softmax Algorithm
Before diving into kernels, we need to understand the algorithm they all share: online softmax. The standard softmax requires two passes over the data (find max, then compute exp and sum). FlashAttention’s online variant maintains a running max and sum, allowing it to process KV entries in a single streaming pass.
- Initialize: `max = -inf`, `sum = 0`, `O = 0`
- For each KV block (tile of 32 positions):
  - Compute scores: `S = Q @ K^T * scale`
  - Find `block_max = max(S)`
  - Update running state: `new_max = max(running_max, block_max)`, `correction = exp(old_max - new_max)`, `running_sum = running_sum * correction + sum(exp(S - new_max))`
  - Rescale and accumulate V: `O = O * correction + exp(S - new_max) @ V`
- Finalize: `O = O / running_sum`
The correction factor exp(old_max - new_max) rescales all previously accumulated values when a new maximum is discovered. This is the heart of the algorithm – it allows processing KV entries in arbitrary-sized blocks without ever storing the full attention matrix.
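The streaming update above can be checked against a reference two-pass softmax in a few lines of Python. This is a host-side sketch of the algorithm, not the Metal kernel itself; it processes one query row and one KV entry at a time (block size 1) to keep the logic minimal:

```python
import math

def online_softmax_attention(q, K, V, scale):
    """Single query row; stream over KV entries exactly once."""
    running_max, running_sum = float("-inf"), 0.0
    out = [0.0] * len(V[0])
    for k_row, v_row in zip(K, V):
        score = sum(qi * ki for qi, ki in zip(q, k_row)) * scale
        new_max = max(running_max, score)
        correction = math.exp(running_max - new_max)  # rescales past work (0.0 on first entry)
        exp_score = math.exp(score - new_max)
        running_sum = running_sum * correction + exp_score
        out = [o * correction + exp_score * v for o, v in zip(out, v_row)]
        running_max = new_max
    return [o / running_sum for o in out]  # finalize
```

Note that no special case is needed for the first KV entry: `exp(-inf - score)` evaluates to 0, wiping the (empty) accumulated state.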
Now let’s see how each kernel variant implements this differently. But first, watch the online softmax in action — this is the foundation that all five kernels share.
Interactive: Online Softmax — The Core Algorithm
This animation processes 6 KV positions one at a time. Watch the running max and sum update, and pay attention to what happens at position 3 when a new maximum is discovered — the correction factor rescales all previous work. This is the trick that makes FlashAttention possible.
Now let’s see how the five kernel variants implement this algorithm with different parallelization strategies.
Kernel 1: Standard Decode (flash_attention_decode_f16)
The standard decode kernel handles the general case: one threadgroup per query position, with threads collaborating on the QK dot product and V accumulation.
Thread Assignment
Dispatch: grid = (seq_len, n_heads), threadgroup = (head_dim or 1024)
Each thread “owns” one element of the head dimension: 128 threads for head_dim=128, 256 threads for head_dim=256. The cap at 1024 – Metal’s maximum threadgroup size – accommodates models with very large head dimensions.
Algorithm
float q_val = (tid < head_dim) ? float(q_row[tid]) : 0.0f;
for (uint kv_start = 0; kv_start < kv_seq_len; kv_start += ATTN_KV_TILE) {
for (uint kv = 0; kv < tile_len; kv++) {
const uint kv_pos = kv_start + kv; // absolute KV index used below
// Q·K dot product: each thread multiplies one element
float local_dot = (tid < head_dim) ? q_val * float(k_row[tid]) : 0.0f;
// Cross-SIMD reduction via shared memory
float simd_val = simd_sum(local_dot);
if (slid == 0) shared_reduce[sgid] = simd_val;
threadgroup_barrier(mem_flags::mem_threadgroup);
// Sum across SIMD groups
float score = shared_reduce[0];
for (uint s = 1; s < n_simd_groups; s++)
score += shared_reduce[s];
threadgroup_barrier(mem_flags::mem_threadgroup);
score *= scale;
// Online softmax update
float new_max = max(running_max, score);
float exp_score = exp(score - new_max);
float correction = exp(running_max - new_max); // running_max still holds the old max here
running_sum = running_sum * correction + exp_score;
// V accumulation in register
if (tid < head_dim) {
float v_val = float(v_base[kv_pos * head_dim + tid]);
acc = acc * correction + exp_score * v_val;
}
running_max = new_max;
}
}
Key observations:
- Register-local accumulator: each thread accumulates exactly one output element in `acc`. No threadgroup memory is needed for the output.
- Two barriers per KV entry: one after the SIMD-group reduction write, one after the read. These are the performance bottleneck for long contexts.
- Q value cached in register: The query vector is loaded once and reused across all KV entries.
GQA Support
Grouped-Query Attention is handled by a single integer division:
const uint kv_head = head / (n_heads / n_kv_heads);
Multiple query heads share the same KV head. The KV stride is computed per-head, and all query heads in the same group read from the same K/V buffers.
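The head-to-KV-head mapping is easy to sanity-check on the host. A plain Python sketch, using a hypothetical 32-head / 8-KV-head configuration:

```python
n_heads, n_kv_heads = 32, 8          # e.g. a Llama-style GQA config
group_size = n_heads // n_kv_heads   # 4 query heads share each KV head

# Mirrors the kernel's `head / (n_heads / n_kv_heads)`
kv_head_for = [head // group_size for head in range(n_heads)]
```

Query heads 0–3 map to KV head 0, heads 4–7 to KV head 1, and so on.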
Causal and Tree Masking
The kernel supports three masking modes via function constants:
constant bool FC_USE_TREE_MASK [[function_constant(1)]];
constant bool FC_NON_CAUSAL [[function_constant(3)]];
| Mode | Use Case | Masking Logic |
|---|---|---|
| Causal | Standard prefill | kv_pos > q_pos -> skip |
| Tree | Speculative verification | Bitmask lookup per (q,kv) pair |
| Non-causal | BERT/Whisper encoder | No masking |
Tree masking uses a per-query bitmask packed into uint16_t:
if (kv_pos >= batch_start) {
uint batch_idx = kv_pos - batch_start;
if (batch_idx < seq_len && ((tree_mask[q_pos] >> batch_idx) & 1) == 0)
continue;
}
This enables speculative decoding with tree-structured draft tokens, where each draft token can only attend to its ancestors in the tree.
Interactive: Standard Decode — Barrier-Based QK Broadcast
In the standard decode kernel, multiple SIMD groups share a threadgroup. The QK dot product needs cross-SG communication via threadgroup memory and barriers. Each KV position requires 2 barriers — the dominant cost for long contexts.
Kernel 2: Fast Decode (flash_attention_decode_fast_f16)
The fast decode kernel is a radical simplification: a single SIMD group (32 threads) per head, with zero threadgroup barriers.
Why No Barriers?
The standard decode kernel needs barriers because the QK dot product requires cross-SIMD-group communication. With only one SIMD group, simd_sum() provides the full reduction – no shared memory needed:
float local_dot = 0.0f;
for (uint e = 0; e < elems_per_thread; e++) {
uint idx = slid + e * SIMD_WIDTH;
local_dot += q_vals[e] * float(k_row[idx]);
}
float score = simd_sum(local_dot) * scale;
Each thread handles head_dim / 32 elements (4 elements for head_dim=128, 2 for head_dim=64). The simd_sum broadcasts the result to all lanes instantly, without any barrier.
Performance Profile
| Aspect | Standard Decode | Fast Decode |
|---|---|---|
| Threads per head | up to 1024 | 32 |
| Barriers per KV | 2 | 0 |
| Memory parallelism | High (many threads read KV) | Low (32 threads read KV) |
| Barrier overhead | ~0.2us * 2 * kv_len | 0 |
| Best for | Short contexts (< ~128 KV) | Medium contexts (128-1024 KV) |
For context lengths beyond ~128 KV entries, the barrier overhead in the standard kernel dominates. Each barrier costs roughly 0.2 microseconds, and with 2 barriers per KV entry, a 1024-entry context costs ~400 microseconds in barrier overhead alone. The fast decode kernel eliminates this entirely, at the cost of lower memory bandwidth utilization (32 threads vs 128+).
Multi-Element V Accumulation
float acc[8] = {}; // max 8 elements per thread
for (uint kv_pos = 0; kv_pos < kv_seq_len; kv_pos++) {
// ... compute score ...
for (uint e = 0; e < elems_per_thread; e++) {
uint idx = slid + e * SIMD_WIDTH;
float v_val = float(v_base[kv_pos * head_dim + idx]);
acc[e] = acc[e] * correction + exp_score * v_val;
}
}
Each thread maintains elems_per_thread accumulators in registers. The memory access pattern for V is strided: thread 0 reads elements 0, 32, 64, 96 (for head_dim=128). This is suboptimal for cache lines but acceptable because the KV data is typically in the SLC.
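The ownership pattern is easy to visualize: lane `slid` owns elements `slid`, `slid+32`, `slid+64`, … A small sketch of the indexing, assuming head_dim=128:

```python
SIMD_WIDTH = 32
head_dim = 128
elems_per_thread = head_dim // SIMD_WIDTH  # 4 for head_dim=128

def owned_elements(slid):
    """head_dim indices handled by SIMD lane `slid` (strided, not contiguous)."""
    return [slid + e * SIMD_WIDTH for e in range(elems_per_thread)]
```

Every element is owned by exactly one lane, and consecutive lanes touch consecutive addresses within each stride step, so each 32-lane read is still a contiguous 64-byte segment of V.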
Interactive: Fast Decode — Zero Barriers
The breakthrough: use only 1 SIMD group (32 threads). Each thread holds multiple head_dim elements (4 for head_dim=128). Since all 32 threads are in one SG, simd_sum() gives the full QK dot product — no threadgroup memory, no barriers. Compare the barrier count to Standard Decode above.
Kernel 3: Parallel Decode (flash_attention_decode_parallel_f16)
For very long contexts (thousands of KV entries), even the fast decode kernel is limited by its sequential scan of KV entries. The parallel decode kernel uses 32 SIMD groups (1024 threads) to parallelize across the KV dimension:
constexpr uint NUM_SG = 32;
const uint sgid = tid / 32; // SIMD group = KV position group
const uint slid = tid % 32; // lane = head_dim partition
KV Parallelism
Each SIMD group handles every 32nd KV position:
for (uint kv_pos = sgid; kv_pos < kv_seq_len; kv_pos += NUM_SG) {
// Compute dot product for this KV position
// Online softmax within this SG's partial view
}
SG 0 processes positions 0, 32, 64, …; SG 1 processes 1, 33, 65, …; and so on. This gives 32x memory parallelism for KV reads compared to the single-SG approach.
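The interleaved partitioning can be written down directly (a host-side sketch):

```python
NUM_SG = 32

def kv_positions_for_sg(sgid, kv_seq_len):
    """KV positions scanned by SIMD group `sgid` (every 32nd position)."""
    return list(range(sgid, kv_seq_len, NUM_SG))
```

Each KV position belongs to exactly one SIMD group, so the partial online-softmax states never overlap and can be merged afterwards.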
Cross-SG Reduction
After all SGs finish their partial computations, the results must be merged. This is the tricky part – each SG has a partial (max_score, sum_exp, output[head_dim]) triplet that needs to be combined with correction factors:
// Phase 1: Find global max
if (slid == 0) {
tg_max[sgid] = max_score;
tg_sum[sgid] = sum_exp;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float loaded_max = tg_max[slid];
float global_max = simd_max(loaded_max);
// Phase 2: Correct sums
float loaded_sum = tg_sum[slid] * fast::exp(tg_max[slid] - global_max);
float global_sum = simd_sum(loaded_sum);
float inv_sum = 1.0f / global_sum;
// Phase 3: Reduce output per element
float my_factor = fast::exp(max_score - global_max);
for (uint i = 0; i < elems; i++) {
tg_out[slid * 32 + sgid] = o[i] * my_factor;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sgid == 0) {
float total = 0;
for (uint s = 0; s < 32; s++)
total += tg_out[slid * 32 + s];
o_row[slid * elems + i] = half(total * inv_sum);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
The reduction uses 4KB of threadgroup memory (tg_out[1024]) to stage the per-SG output partials. For each head_dim element, all 32 SGs write their corrected partials, then SG 0 sums them.
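The merge step is where correctness is subtle: each SG's (max, sum, output) triplet must be combined with correction factors, mirroring Phases 1–3 above. A host-side Python model of the reduction (the function name is illustrative, not from the kernel source):

```python
import math

def merge_sg_partials(partials):
    """partials: per-SG (max_score, sum_exp, out_vec) triplets from an
    online-softmax pass over that SG's slice of KV positions.
    out_vec is the un-normalized accumulator sum(exp(s - sg_max) * v)."""
    global_max = max(m for m, _, _ in partials)                  # Phase 1
    global_sum = sum(s * math.exp(m - global_max)                # Phase 2
                     for m, s, _ in partials)
    dim = len(partials[0][2])
    merged = [0.0] * dim                                         # Phase 3
    for m, _, o in partials:
        factor = math.exp(m - global_max)
        for d in range(dim):
            merged[d] += o[d] * factor
    return [x / global_sum for x in merged]
```

Because each partial was computed relative to its own local max, rescaling by `exp(local_max - global_max)` makes all partials commensurable before summation – exactly the same trick as the per-block correction factor, applied once more at the end.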
When to Use Parallel Decode
The parallel decode kernel is fastest for very long contexts (10K+ KV entries) where the KV scan dominates. For shorter contexts, the reduction overhead (barriers + threadgroup memory traffic) makes it slower than the fast decode kernel.
Interactive: Parallel Decode — 32 SGs Divide the KV Cache
For long contexts, even fast decode is slow because one SG must scan all KV positions sequentially. Parallel decode uses 32 SIMD groups, each processing every 32nd KV position. After the parallel scan, a cross-SG reduction merges the results. This is 32x more memory bandwidth for the KV read.
Kernel 4: Prefill V1 (flash_attention_prefill_f16)
The prefill kernel processes multiple query rows simultaneously using SIMD group matrix multiply-accumulate (MMA) for both the QK product and the PV product.
Threadgroup Geometry
constant constexpr uint NQ_ROWS_DEFAULT = 8;
constant constexpr uint KV_BLOCK_SIZE = 32;
constant constexpr uint V2_TG_SIZE = 128; // 4 SIMD groups
Each threadgroup handles NQ query rows (16, 8, or 4, depending on head_dim) and tiles the KV dimension in blocks of 32.
Adaptive NQ via Function Constants
constant uint FC_NQ_ROWS [[function_constant(2)]];
const uint NQ = FC_NQ_SPECIALIZED ? FC_NQ_ROWS : NQ_ROWS_DEFAULT;
| Head Dim | NQ | Passes | TG Memory |
|---|---|---|---|
| <= 64 | 16 | 2 | ~6KB |
| <= 128 | 8 | 1 | ~12KB |
| <= 256 | 4 | 1 | ~24KB |
For small head dimensions, NQ=16 allows processing more queries per threadgroup, amortizing the KV load cost over more query rows.
Threadgroup Memory Layout
threadgroup half *tg_Q // [NQ, HD] half
threadgroup float *tg_S // [NQ, 32] float (score matrix)
threadgroup float *tg_O // [NQ, HD] float (output accumulator)
threadgroup half *tg_K // [HD, 32] half (also reused as P scratch)
threadgroup half *tg_V // [32, HD] half
threadgroup float *tg_row_max // [NQ] float
threadgroup float *tg_row_sum // [NQ] float
The output accumulator is in FP32 to maintain precision across many accumulation steps. The K tile is stored transposed ([HD, 32] instead of [32, HD]) because the MMA operation needs K in column-major format for the Q @ K^T product.
Direct Device K Loading
A key optimization in V1: for full KV blocks (32 entries), K is loaded directly from device memory during the MMA computation, bypassing threadgroup memory entirely:
if (full_block) {
const uint kv_block_pos = kv_start + s_col_tile * SIMD_TILE;
for (uint inner = 0; inner < head_dim; inner += SIMD_TILE) {
simdgroup_load(q_tile, tg_Q + q_row_offset * head_dim + inner, head_dim);
simdgroup_load(k_tile, k_base + kv_block_pos * head_dim + inner,
head_dim, 0, true); // transposed load!
simdgroup_multiply_accumulate(acc_s, q_tile, k_tile, acc_s);
}
}
The true parameter on simdgroup_load requests a transposed load. Apple’s MMA hardware can load matrices in either row-major or column-major order, so the transposition is free. This eliminates the cooperative load + barrier for the K tile in full blocks, saving ~2 microseconds per block.
Causal Skip Optimization
const bool is_causal = (!FC_HAS_TREE_MASK && !FC_IS_NON_CAUSAL
&& kv_seq_len == seq_len);
const uint causal_kv_limit = is_causal
? min(((last_q_pos + KV_BLOCK_SIZE) / KV_BLOCK_SIZE) * KV_BLOCK_SIZE,
kv_seq_len)
: kv_seq_len;
For causal attention, KV blocks beyond the causal limit are entirely masked and can be skipped. This saves roughly half the computation for long sequences (the causal mask is triangular).
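The limit computation rounds the last query position up to the next KV block boundary. A direct Python transcription of the expression above, using the same KV_BLOCK_SIZE = 32:

```python
KV_BLOCK_SIZE = 32

def causal_kv_limit(last_q_pos, kv_seq_len):
    """First KV index past the last block that can contain unmasked entries."""
    limit = ((last_q_pos + KV_BLOCK_SIZE) // KV_BLOCK_SIZE) * KV_BLOCK_SIZE
    return min(limit, kv_seq_len)
```

A threadgroup whose last query sits at position 33, for example, needs KV blocks [0, 32) and [32, 64) but can skip everything from 64 onward.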
Block Classification
Within the causal region, blocks are classified as fully unmasked, partially masked, or fully masked:
const bool block_fully_unmasked = is_causal
? (kv_end <= first_q_pos + 1 && kv_seq_len > KV_BLOCK_SIZE * 4)
: (!FC_HAS_TREE_MASK);
if (block_fully_unmasked && full_block && pass_nq == TILE_ROWS) {
// Fast path: just scale, no per-element masking
for (uint idx = tid_flat; idx < TILE_ROWS * KV_BLOCK_SIZE; idx += V2_TG_SIZE) {
const uint qr = idx / KV_BLOCK_SIZE, kv_col = idx % KV_BLOCK_SIZE;
tg_S[qr * KV_BLOCK_SIZE + kv_col] *= scale;
}
} else {
// Slow path: per-element causal/tree masking
}
Fully unmasked blocks (far from the causal diagonal) use a fast path that just scales the scores. Partially masked blocks (near the diagonal) check each element individually.
Multi-Pass for Large NQ
When NQ > 8, the kernel processes query rows in passes of 8 (matching the SIMD MMA tile size):
for (uint q_pass = 0; q_pass < n_q_passes; q_pass++) {
const uint q_row_offset = q_pass * TILE_ROWS;
// Compute S, softmax, P@V for rows [q_row_offset..+8]
}
The V tile is loaded once per KV block and reused across all query passes, amortizing the memory cost.
Interactive: Prefill Attention — Tiled Q x KV with Causal Mask
Prefill processes all prompt tokens at once. Unlike decode (1 query, all KV), prefill has many queries and many KV positions. It tiles both dimensions: a block of queries (BQ=32) iterates over blocks of KV (BK=32), computing attention scores and accumulating outputs — like a GEMM with online softmax and a causal mask.
Kernel 5: Prefill V2 (flash_attention_prefill_v2_f16)
The V2 kernel is the newest and fastest attention kernel, designed for long sequences. Its key innovations are:
Register-Based Output
Instead of accumulating output in threadgroup memory (tg_O), V2 keeps the output in SIMD group registers:
float2 acc_o[HD_TILES]; // HD/8 output tile fragments
for (uint i = 0; i < HD_TILES; i++)
acc_o[i] = float2(0);
Each thread holds two float values per output tile (matching the thread_elements() of an 8x8 SIMD matrix). For head_dim=128, this is 16 float2 values = 128 bytes per thread – entirely in registers, no threadgroup memory needed for output.
exp2 Instead of exp
const float scale_log2 = params.scale * M_LOG2E_F;
// Pre-scale Q
tg_Q[idx] = half(float(tg_Q[idx]) * scale_log2);
// Later, in softmax:
s_frag[t][0] = fast::exp2(s_frag[t][0] - new_max);
By pre-multiplying the scale factor by log2(e), the kernel can use fast::exp2 instead of exp. On Apple Silicon, fast::exp2 maps directly to the hardware transcendental unit and is approximately 2x faster than exp.2
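The substitution is exact in real arithmetic: exp(x) = 2^(x · log2 e), so folding log2(e) into the scale factor lets every exponential in the softmax become a single exp2. A quick numerical check of the identity:

```python
import math

M_LOG2E = math.log2(math.e)  # ~1.4426950408889634, folded into `scale` by the kernel

def exp_via_exp2(x):
    """exp(x) computed the way the V2 kernel does: one exp2 on a pre-scaled input."""
    return 2.0 ** (x * M_LOG2E)
```

Since the pre-scaling is fused into Q before the main loop, the per-element cost inside the softmax drops to exactly one hardware exp2.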
Row-Level Reduce via SIMD Shuffle
Apple’s 8x8 MMA distributes elements across lanes in a specific pattern. To compute row max and row sum, V2 uses XOR shuffles:
template <typename T>
METAL_FUNC T row_max(T v) {
v = metal::max(v, simd_shuffle_xor(v, 1));
v = metal::max(v, simd_shuffle_xor(v, 8));
return v;
}
template <typename T>
METAL_FUNC T row_sum(T v) {
v += simd_shuffle_xor(v, 1);
v += simd_shuffle_xor(v, 8);
return v;
}
In Apple’s MMA layout, lanes that share a row differ by XOR distances of 1 and 8. Two shuffle operations suffice for a complete row reduction – no threadgroup barriers needed.
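The shuffle pattern can be modeled on the host: lane l first combines with lane l^1, then with lane l^8, so every lane ends up holding the reduction over its 4-lane row group {l, l^1, l^8, l^9}. A simulation, assuming the lane layout described above:

```python
def simd_shuffle_xor(vals, mask):
    """Each lane reads the value held by lane (lane ^ mask)."""
    return [vals[lane ^ mask] for lane in range(len(vals))]

def row_max(vals):
    v = [max(a, b) for a, b in zip(vals, simd_shuffle_xor(vals, 1))]
    v = [max(a, b) for a, b in zip(v, simd_shuffle_xor(v, 8))]
    return v  # every lane now holds the max over {l, l^1, l^8, l^9}
```

After the two steps, all four lanes of a row group hold the same value – the broadcast comes for free, just as with `simd_sum`.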
BQ=32 Query Block Size
V2 processes 32 query rows per threadgroup (vs V1’s 8-16), with 4 SIMD groups each handling 8 rows. This 4x improvement in query batch size means:
- The V tile ([32, head_dim]) is loaded once and reused across all 32 queries
- The KV load cost is amortized over 4x more queries
- The total number of threadgroups is reduced by 4x
KV Tile Reuse for K and V
threadgroup half *tg_KV = (threadgroup half *)(tg_raw + off);
// First: load K transposed → tg_KV
// Compute S = Q @ K^T
threadgroup_barrier(mem_flags::mem_threadgroup);
// Then: load V → tg_KV (reuse same memory)
// Compute O += P @ V
The K and V tiles are never needed simultaneously, so they share the same threadgroup memory region. This halves the threadgroup memory requirement.
P Fragment in Registers
Instead of materializing the full probability matrix P in threadgroup memory, V2 keeps it in SIMD registers:
simdgroup_half8x8 p_half;
thread half2 &p_h = *(thread half2 *)&(p_half.thread_elements());
p_h[0] = half(s_frag[t][0]);
p_h[1] = half(s_frag[t][1]);
simdgroup_load(v_mat, tg_KV + t * SG_TILE * HD + d * SG_TILE, HD);
simdgroup_multiply_accumulate(o_mat, p_half, v_mat, o_mat);
The score fragments are cast to half precision and packed into SIMD matrix format directly from registers, then used in the MMA operation. No threadgroup write/read cycle.
Interactive: Prefill V2 — Register-Only Output Pipeline
Prefill V2 is the fastest attention kernel. Its key insight: keep the output accumulator in SIMD registers instead of threadgroup memory, and reuse the same threadgroup memory region for K and V tiles alternately. This animation shows one KV-block iteration, contrasting V1’s memory-heavy approach with V2’s register pipeline.
Kernel 6: Softmax (softmax_f16)
The standalone softmax kernel operates on pre-computed score matrices:
kernel void softmax_f16(
device half *data, constant SoftmaxParams &params,
uint3 tgid_v, uint3 tid_v, uint sgid, uint slid, uint3 tpg
) {
float local_max = -INFINITY;
for (uint i = tid; i < cols; i += tg_size)
local_max = max(local_max, float(row_data[i]));
float row_max = tg_reduce_max(local_max, sgid, slid, tg_size, shared);
float local_sum = 0.0f;
for (uint i = tid; i < cols; i += tg_size)
local_sum += exp(float(row_data[i]) - row_max);
float total_sum = tg_reduce_sum(local_sum, sgid, slid, tg_size, shared);
float inv_sum = 1.0f / total_sum;
for (uint i = tid; i < cols; i += tg_size)
row_data[i] = half(exp(float(row_data[i]) - row_max) * inv_sum);
}
This is a classic three-pass softmax: find max, compute exp-sum, normalize. It uses tg_reduce_max and tg_reduce_sum (utility functions that combine simd_max/simd_sum with threadgroup shared memory) for efficient cross-SIMD-group reductions.
The kernel operates in-place (reads and writes the same buffer), dispatches one threadgroup per row, and handles arbitrary row lengths via strided access.
Logit Soft-Capping (logit_cap)
Gemma models apply a soft cap to attention logits:
kernel void logit_softcap_f16(
device half *logits, constant float &cap, constant uint &count,
uint tid [[thread_position_in_grid]]
) {
float x = float(logits[tid]);
logits[tid] = half(cap * tanh(x / cap));
}
This bounds the logits to [-cap, +cap] using a smooth tanh function. Applied before the softmax in attention, it prevents extreme attention scores that could lead to numerical instability.3
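A reference implementation of the cap in plain Python (cap = 30 is a typical Gemma value, per the note at the end of this chapter):

```python
import math

def logit_softcap(x, cap=30.0):
    """Smoothly bound a logit to (-cap, +cap); near-identity for |x| << cap."""
    return cap * math.tanh(x / cap)
```

For small logits the function is nearly the identity (tanh(x) ≈ x near zero), so well-behaved scores are essentially unchanged while outliers saturate smoothly.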
Kernel Selection Strategy
The host selects the attention kernel based on the phase and sequence length:
DECODE (M=1):
if fast_decode_available && kv_seq_len < parallel_threshold:
→ flash_attention_decode_fast_f16 (32 threads, 0 barriers)
elif kv_seq_len >= parallel_threshold:
→ flash_attention_decode_parallel_f16 (1024 threads)
else:
→ flash_attention_decode_f16 (128-1024 threads)
PREFILL (M>1):
if seq_len >= 1024 && head_dim <= 128:
→ flash_attention_prefill_v2_f16 (BQ=32, register output)
elif seq_len >= v2_threshold:
→ flash_attention_prefill_f16 (NQ=8 or 16, simd MMA)
else:
→ flash_attention_decode_f16 (per-query TG, scalar dot)
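The selection logic above can be written as a small host-side function. This is a sketch: the threshold values are illustrative assumptions, not Akunu's tuned constants:

```python
def select_attention_kernel(m, kv_seq_len, seq_len, head_dim,
                            fast_decode_available=True,
                            parallel_threshold=4096,   # assumed value
                            v2_threshold=256):         # assumed value
    if m == 1:  # DECODE: single query token
        if fast_decode_available and kv_seq_len < parallel_threshold:
            return "flash_attention_decode_fast_f16"
        if kv_seq_len >= parallel_threshold:
            return "flash_attention_decode_parallel_f16"
        return "flash_attention_decode_f16"
    # PREFILL: many query rows
    if seq_len >= 1024 and head_dim <= 128:
        return "flash_attention_prefill_v2_f16"
    if seq_len >= v2_threshold:
        return "flash_attention_prefill_f16"
    return "flash_attention_decode_f16"
```

The structure matters more than the exact thresholds: decode selection depends on the KV length, prefill selection on the query length and head dimension.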
Performance Comparison
Here is a rough comparison of the five kernels for head_dim=128:
| Kernel | Throughput at KV=512 | Throughput at KV=4096 | Throughput at KV=32K |
|---|---|---|---|
| Standard Decode | High | Medium | Poor |
| Fast Decode | Medium | High | Medium |
| Parallel Decode | Low (overhead) | Medium | High |
| Prefill V1 (NQ=8) | – | Good | Good |
| Prefill V2 (BQ=32) | – | Excellent | Excellent |
The crossover points between kernels are hardware-dependent and are determined by profiling. The general principle: use the simplest kernel that is not bottlenecked by the context length.
Memory Analysis: Why FlashAttention Matters
To understand why FlashAttention is necessary, let’s compare memory usage for a standard attention computation vs. FlashAttention on a concrete example:
Model: Llama 3.1 8B, head_dim=128, n_heads=32, seq_len=4096
Naive Attention
S = Q @ K^T: [4096, 4096] * 32 heads * 4 bytes (FP32) = 2 GB
P = softmax(S): Same as S = 2 GB
O = P @ V: [4096, 128] * 32 heads * 4 bytes = 64 MB
─────────────────────────────────────────────
Total: ~4 GB for attention alone
This is clearly infeasible for on-device inference where total system RAM might be 16-36 GB and most of it is used for model weights.
FlashAttention
Q tile: [NQ, 128] * 2 bytes = NQ * 256 bytes
K tile: [128, 32] * 2 bytes = 8 KB
V tile: [32, 128] * 2 bytes = 8 KB
S scores: [NQ, 32] * 4 bytes = NQ * 128 bytes
O accum: [NQ, 128] * 4 bytes = NQ * 512 bytes
Softmax: [NQ] * 8 bytes = negligible
─────────────────────────────────────────────
Total: ~17 KB per threadgroup (NQ=8)
~33 KB per threadgroup (NQ=32, V2)
FlashAttention reduces memory from O(seq_len^2) to O(NQ * head_dim) per threadgroup – a reduction of over 100,000x for a 4096-token sequence. The attention matrix S is never fully materialized; only a single KV block’s worth of scores exists at any time.4
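The arithmetic behind these figures, reproduced in Python (sizes in bytes; tile shapes taken from the NQ=8 breakdown above):

```python
seq_len, n_heads, head_dim, NQ = 4096, 32, 128, 8

# Naive: two full [seq_len, seq_len] FP32 matrices (S and P) plus the output
naive_bytes = (2 * seq_len * seq_len * n_heads * 4
               + seq_len * head_dim * n_heads * 4)

# FlashAttention: per-threadgroup tiles only
flash_bytes = (NQ * head_dim * 2      # Q tile (half)
               + head_dim * 32 * 2    # K tile (half)
               + 32 * head_dim * 2    # V tile (half)
               + NQ * 32 * 4          # S scores (float)
               + NQ * head_dim * 4    # O accumulator (float)
               + NQ * 8)              # softmax state (max + sum per row)
```

The working set per threadgroup is independent of seq_len, which is the whole point: the quadratic term never materializes.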
Numerical Precision Considerations
Online softmax introduces a subtlety: the correction factor exp(old_max - new_max) is applied multiplicatively to the running accumulator. After many KV blocks, this means the output has been multiplied by a chain of correction factors:
O_final = O_0 * c_1 * c_2 * ... * c_T + ...
Each correction factor is <= 1.0 (since new_max >= old_max), so the chain product decreases monotonically. For very long sequences (tens of thousands of KV entries), the accumulated product can approach the FP32 denormalization threshold.
In practice, this is not a problem because:
- The correction factor is only significantly less than 1.0 when a new maximum exceeds the old by a large margin, which happens rarely after the first few KV blocks.
- Once the running maximum stabilizes (typically within the first 100-200 KV entries), all subsequent corrections are approximately 1.0.
- The final normalization by `1/running_sum` rescales the output, compensating for any cumulative shrinkage.
Akunu’s V2 kernel additionally uses exp2 instead of exp, which maps directly to the hardware transcendental unit and avoids the extra multiply by log2(e) that a standard exp performs internally.
Attention Kernel Selection in the Dispatch Table
The choice of attention kernel is made once during dispatch table construction, not at runtime. The host examines the model configuration and hardware capabilities to select the best kernel:
For DECODE dispatch table:
1. Check if fast decode is suitable:
- head_dim <= 256 (fits in 32 lanes with <= 8 elements/lane)
- M=1 (single token decode)
2. Check if parallel decode is suitable:
- Expected long contexts (> 4096 KV entries)
3. Default: standard decode
For PREFILL (called at runtime based on seq_len):
1. V2 if seq_len >= 1024 and head_dim <= 128
2. V1 if seq_len >= v2_threshold
3. Standard decode fallback for very short sequences
Function constant specialization is used to bake head_dim, NQ, and masking mode into the kernel at pipeline compilation time. This enables the Metal compiler to generate optimized code for the specific configuration.
Function Constant Specialization Strategy
All attention kernels use Metal function constants for compile-time specialization. The specialization strategy differs by kernel:
| Kernel | Function Constants | Benefit |
|---|---|---|
| Standard decode | FC_HEAD_DIM, FC_USE_TREE_MASK, FC_NON_CAUSAL | Eliminates head_dim conditionals, removes unused masking code |
| Fast decode | FC_HEAD_DIM | Enables elems_per_thread as compile-time constant |
| Parallel decode | FC_HEAD_DIM | Same as fast decode |
| Prefill V1 | FC_HEAD_DIM, FC_NQ_ROWS, FC_NON_CAUSAL | Enables loop unrolling for head_dim, NQ pass calculation |
| Prefill V2 | FC_HEAD_DIM, FC_NQ_ROWS, FC_NON_CAUSAL | Same plus HD_TILES becomes compile-time |
The host creates separate PSOs for each unique combination of function constants. For a model with head_dim=128 and no tree masking, a typical PSO cache contains:
- `attn_decode_hd128_notree` (standard decode)
- `attn_decode_fast_hd128` (fast decode)
- `attn_decode_parallel_hd128` (parallel decode)
- `attn_prefill_hd128_nq8` (prefill V1 with NQ=8)
- `attn_pfv2_hd128_nq32` (prefill V2 with NQ=32)
Each PSO is compiled once during model initialization. The compilation cost (~20-50ms per PSO) is amortized over the entire inference session.
The Role of q_stride and kv_stride
The attention kernels support two memory layouts:
Head-major layout: Q[head, seq_len, head_dim] – used when Q comes directly from the GEMV projection output. Each head’s data is contiguous.
Row-major layout: Q[seq_len, n_heads * head_dim] – used in prefill when Q comes from a GEMM output where rows correspond to sequence positions.
The q_stride parameter tells the kernel which layout to use:
device const half *q_row = (q_str > 0)
? Q + q_pos * q_str + head * head_dim // row-major
: Q + (head * seq_len + q_pos) * head_dim; // head-major
Similarly, kv_stride controls the KV cache layout. When kv_stride > 0, the KV cache uses a fixed-stride layout (allocated for the maximum sequence length); when 0, it uses a compact layout.
Sliding Window Attention
Some models (Mistral, Gemma 3) use sliding window attention where each token only attends to the most recent W positions. Akunu handles this at the KV cache level rather than in the attention kernel: the KV cache uses a ring buffer, and the effective kv_seq_len passed to the attention kernel is clamped to the window size.
This design keeps the attention kernels simple and universal. The sliding window logic lives in the KV cache management code, which adjusts the visible range before dispatching attention.
Comparison with Other Implementations
| Feature | Akunu | llama.cpp Metal | MLX |
|---|---|---|---|
| Decode attention variants | 3 (standard, fast, parallel) | 1 | 1 |
| Prefill attention variants | 2 (V1 simd MMA, V2 register) | 1 | 1 |
| Online softmax | Yes (all variants) | Yes | Yes |
| exp2 optimization | Yes (V2) | No | Yes |
| Register output | Yes (V2) | No | Yes |
| Tree masking | Yes | No | No |
| Non-causal mode | Yes | No | Yes |
| GQA support | Yes | Yes | Yes |
| BQ=32 query blocks | Yes (V2) | No | No |
| Direct K device load | Yes (V1 full blocks) | No | No |
Akunu’s attention kernel family is arguably the most diverse of any on-device LLM inference engine, with 5 variants covering the full range of use cases from short-context decode to long-sequence prefill.
Deep Dive: The V1-to-V2 Evolution
The prefill V1 and V2 kernels represent two different approaches to the same problem. Understanding how V2 improves on V1 illuminates the tradeoffs in GPU kernel design.
V1: Threadgroup Memory-Centric
V1 stores all intermediate results in threadgroup memory:
- Output accumulator `tg_O`: NQ * HD * 4 bytes = 4096 bytes (NQ=8, HD=128)
- Score matrix `tg_S`: NQ * 32 * 4 bytes = 1024 bytes
- Total per-threadgroup: ~12-17 KB depending on head_dim
The advantage: any thread can read any accumulator element, enabling flexible work distribution across SIMD groups. The disadvantage: every MMA output must be written to TG memory and every accumulator update requires a TG memory read-modify-write.
V2: Register-Centric
V2 keeps the output in SIMD group registers:
- Output accumulator: `float2 acc_o[HD_TILES]` per thread = (HD / 8) tiles * 8 bytes = HD bytes per thread
- For HD=128: 128 bytes per thread in registers
- Total per-SG: 32 threads * 128 bytes = 4096 bytes in registers (no TG memory)
The advantage: register access is free (0 latency, infinite bandwidth). The disadvantage: each SIMD group can only access its own registers, requiring careful work assignment to avoid cross-SG communication.
V2 Score Handling via thread_elements()
The key insight enabling V2’s register-centric approach is that the MMA instruction’s thread_elements() accessor provides direct access to the 2 elements each thread owns in the 8x8 matrix result. This means:
- After computing S = Q @ K^T via MMA, each thread can directly read its 2 score elements without going through TG memory.
- Row-level max and sum can be computed using `simd_shuffle_xor` (because all lanes sharing a row can communicate without barriers).
- The rescaling `acc_o *= correction` is a local register operation.
This eliminates the TG memory round-trip for scores, the barrier-heavy softmax update, and the TG memory round-trip for output accumulation – three of the four major bottlenecks in V1.
Performance Impact
For a 4096-token prefill on M4 Pro with head_dim=128:
| Aspect | V1 (NQ=8) | V2 (BQ=32) |
|---|---|---|
| Threadgroups | 512 * n_heads | 128 * n_heads |
| TG memory per TG | ~17 KB | ~9 KB |
| KV load overhead | Per NQ=8 queries | Per BQ=32 queries (4x better amortization) |
| Barriers per KV block | ~6 | ~4 |
| Output write | TG memory -> device | Register -> device |
| Relative throughput | 1.0x | ~1.3-1.5x |
The 30-50% improvement comes from three sources: (1) 4x better KV data amortization, (2) fewer barriers per block, and (3) elimination of TG memory traffic for the output accumulator.
Attention and Memory Bandwidth
Attention is unique among transformer operations because it reads from the KV cache, which grows linearly with context length. For a 7B model with 8 KV heads, head_dim=128, and context length L:
KV cache reads per token per layer (all 8 KV heads):
K: 8 * L * 128 * 2 bytes = 2048L bytes
V: 8 * L * 128 * 2 bytes = 2048L bytes
Total: 4096L bytes = 4L KB per layer
32 layers: 128L KB per token
For L=4096: 128 * 4096 KB = 512 MB per token
For L=32K: 128 * 32768 KB = 4 GB per token
At L=32K, the KV cache reads alone consume 4 GB per token. At 200 GB/s memory bandwidth, this is roughly 21 ms, which can dominate the per-token time. This is why attention becomes the bottleneck at long contexts, surpassing even the GEMV weight reads.
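These figures can be reproduced with a short calculation. Note the factor of n_kv_heads: every KV head's cache must be streamed each token, even under GQA (a sketch with the example configuration baked in as defaults):

```python
def kv_read_bytes_per_token(context_len, n_layers=32, n_kv_heads=8,
                            head_dim=128, dtype_bytes=2):
    """Bytes of KV cache streamed per generated token (K and V, all layers)."""
    per_layer = 2 * context_len * n_kv_heads * head_dim * dtype_bytes  # K + V
    return n_layers * per_layer

def kv_read_time_ms(context_len, bandwidth_gb_s=200.0):
    """Lower bound on per-token attention time from KV reads alone."""
    return kv_read_bytes_per_token(context_len) / (bandwidth_gb_s * 1e9) * 1e3
```

Because the result grows linearly in context_len while the weight-read cost is constant, there is always a context length beyond which attention overtakes the GEMVs.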
The parallel decode kernel (32 SIMD groups) addresses this by parallelizing the KV scan, effectively multiplying the read bandwidth by 32x through concurrent memory requests. Each SIMD group reads independent KV positions, saturating the memory subsystem’s request queues.
Attention Dispatch Counts
For a 7B model with 32 layers and 32 attention heads, the total attention dispatch count during decode is:
```
Greedy decode (1 token):
  32 layers × 1 attention dispatch = 32 attention dispatches

Chain decode (64 tokens):
  32 layers × 64 tokens × 1 attention dispatch = 2048 attention dispatches
```
Each dispatch handles all 32 heads (the grid Y dimension covers heads). The attention kernel is typically the 3rd or 4th most expensive dispatch per token (after the FFN GEMVs and the QKV GEMVs), but at long contexts it becomes the most expensive.
For the prefill of 2048 tokens:
```
32 layers × 1 attention dispatch (covers all 2048 queries) = 32 attention dispatches
```
Prefill uses far fewer dispatches because each attention kernel processes all query rows simultaneously. This is another reason prefill is more efficient per-token than decode.
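The counts above reduce to a one-line formula; the function name is illustrative:

```python
# Dispatch-count arithmetic from the text: one attention dispatch per layer
# per forward pass, independent of head count (the grid Y dimension covers
# heads) and of how many query rows a prefill pass carries.
def attention_dispatches(layers, forward_passes=1):
    return layers * forward_passes

assert attention_dispatches(32) == 32            # greedy decode, or one prefill
assert attention_dispatches(32, 64) == 2048      # chain decode, 64 tokens
```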
The Softmax Temperature Connection
The standalone `softmax_f16` kernel and the `logit_softcap_f16` kernel are both attention-adjacent operations:
Softmax is used by non-FlashAttention code paths (e.g., when debugging or when the model requires explicit softmax for cross-attention). It processes one row per threadgroup using the standard three-pass algorithm (find max, compute exp sum, normalize). The `tg_reduce_max` and `tg_reduce_sum` helper functions use the same SIMD-first + threadgroup-memory reduction pattern seen in the sampling kernels.
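The three-pass algorithm, sketched for a single row (a minimal reference, not the Metal kernel itself):

```python
import math

# Standard three-pass row softmax: pass 1 finds the row max for numerical
# stability, pass 2 computes the shifted exponentials and their sum,
# pass 3 normalizes.
def softmax_row(row):
    m = max(row)                              # pass 1: row max
    exps = [math.exp(x - m) for x in row]     # pass 2: shifted exps + sum
    s = sum(exps)
    return [e / s for e in exps]              # pass 3: normalize

out = softmax_row([1.0, 2.0, 3.0])
assert abs(sum(out) - 1.0) < 1e-12
assert out[2] > out[1] > out[0]
```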
Logit soft-capping (`cap * tanh(x / cap)`) is specific to Gemma models and is applied before the softmax within attention. It bounds the attention logits to prevent extreme values that could destabilize the softmax computation. The tanh function saturates at +/-1, so the effective range is [-cap, +cap]. Typical values are cap = 30 or cap = 50.
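The soft-cap in Python form, using the typical cap value mentioned above:

```python
import math

# Logit soft-capping: cap * tanh(x / cap) is near-identity for small x
# and saturates smoothly toward +/-cap for large |x|.
def softcap(x, cap=30.0):
    return cap * math.tanh(x / cap)

assert softcap(0.0) == 0.0
assert abs(softcap(1.0) - 1.0) < 1e-3     # near-identity for small logits
assert abs(softcap(1e6)) <= 30.0          # bounded by the cap
```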
Both kernels are simple 1D dispatches with minimal state, taking <5us per invocation. Their impact on overall performance is negligible, but their presence enables Akunu to support architectures that require these operations.
Kernel Selection Decision Tree
To summarize the complete selection logic, here is the decision tree the host uses:
```
Is this decode (M=1)?
├── YES: Is the dispatch table using fast decode?
│   ├── YES: flash_attention_decode_fast_f16
│   │        (1 SG, 32 threads, 0 barriers, head_dim/32 elems/thread)
│   └── NO: Is parallel decode enabled?
│       ├── YES: flash_attention_decode_parallel_f16
│       │        (32 SG, 1024 threads, cross-SG reduction)
│       └── NO: flash_attention_decode_f16
│                (N SGs, up to 1024 threads, 2 barriers/KV)
└── NO (prefill, M>1): What is seq_len?
    ├── seq_len >= 1024 AND head_dim <= 128:
    │        flash_attention_prefill_v2_f16
    │        (BQ=32, register output, exp2, 128 threads)
    ├── seq_len >= v2_threshold:
    │        flash_attention_prefill_f16
    │        (NQ=8 or 16, TG output, simd MMA, 128 threads)
    └── seq_len < v2_threshold:
             flash_attention_decode_f16 (per-query threadgroup)
```
The `v2_threshold` is model-dependent: for head_dim=64 it is 32 (NQ=16, so need at least 32 queries), for head_dim=128 it is 16 (NQ=8). Below these thresholds, there are not enough queries to fill the prefill kernel's tile efficiently.
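The tree translates directly into host-side selection logic. This is a sketch, not Akunu's actual host code; the function name and the boolean flags (`fast_decode`, `parallel_decode`) are illustrative stand-ins for the dispatch-table state:

```python
# Illustrative host-side kernel selection following the decision tree and
# thresholds described in the text.
def select_attention_kernel(m, seq_len, head_dim,
                            fast_decode=False, parallel_decode=False):
    if m == 1:  # decode
        if fast_decode:
            return "flash_attention_decode_fast_f16"
        if parallel_decode:
            return "flash_attention_decode_parallel_f16"
        return "flash_attention_decode_f16"
    # prefill (M > 1)
    if seq_len >= 1024 and head_dim <= 128:
        return "flash_attention_prefill_v2_f16"
    v2_threshold = 32 if head_dim == 64 else 16   # NQ=16 vs NQ=8
    if seq_len >= v2_threshold:
        return "flash_attention_prefill_f16"
    return "flash_attention_decode_f16"           # too few queries for a tile

assert select_attention_kernel(1, 512, 128, fast_decode=True) == "flash_attention_decode_fast_f16"
assert select_attention_kernel(2048, 2048, 128) == "flash_attention_prefill_v2_f16"
assert select_attention_kernel(8, 8, 128) == "flash_attention_decode_f16"
```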
Summary
Akunu’s attention kernel family demonstrates that there is no single “best” attention algorithm – the optimal approach depends on the sequence length, batch size, and hardware capabilities:
- Standard decode: Simple, uses threadgroup barriers, best for short contexts or fallback.
- Fast decode: Single SIMD group, barrier-free, best for medium contexts (128-1024 KV).
- Parallel decode: 32 SIMD groups with KV parallelism, best for very long contexts (10K+).
- Prefill V1: SIMD MMA with direct K loading, best for medium prefill sequences.
- Prefill V2: Register output, exp2, BQ=32, best for long prefill sequences (1024+).
The online softmax algorithm is the thread that connects all five kernels – the same mathematical principle of running max/sum correction, implemented differently based on the parallelism strategy.
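That shared principle can be captured in a few lines. This is a minimal CPU sketch for a scalar value per position (the real kernels accumulate head_dim-wide rows); all names are illustrative, not taken from Akunu's source:

```python
import math

# Online (streaming) softmax: process scores in blocks, keeping only a
# running max, a running sum, and a rescaled accumulator. The correction
# factor exp(old_max - new_max) rescales previously accumulated state
# whenever a new block raises the running max.
def online_softmax_weighted_sum(scores, values, block=32):
    run_max, run_sum, acc = float("-inf"), 0.0, 0.0
    for start in range(0, len(scores), block):
        s = scores[start:start + block]
        v = values[start:start + block]
        new_max = max(run_max, max(s))
        corr = math.exp(run_max - new_max)
        run_sum = run_sum * corr + sum(math.exp(x - new_max) for x in s)
        acc = acc * corr + sum(math.exp(x - new_max) * y
                               for x, y in zip(s, v))
        run_max = new_max
    return acc / run_sum

# Matches the two-pass reference softmax to floating-point tolerance.
scores = [0.3, 2.0, -1.5, 4.2, 0.9, 3.1]
values = [1.0, -2.0, 0.5, 3.0, -1.0, 2.5]
m = max(scores)
ref = (sum(math.exp(s - m) * v for s, v in zip(scores, values))
       / sum(math.exp(s - m) for s in scores))
assert abs(online_softmax_weighted_sum(scores, values, block=2) - ref) < 1e-12
```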
1. Dao, T., Fu, D.Y., Ermon, S., Rudra, A., and Ré, C. "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS 2022. The key insight is that by tiling attention and maintaining online softmax statistics, the algorithm achieves O(N) memory usage instead of O(N^2) while performing the same mathematical operation. See https://arxiv.org/abs/2205.14135. ↩
2. Apple. "Metal Best Practices Guide." Section "Use Fast Math Functions." The `fast::exp2` function on Apple GPU hardware uses the native transcendental function unit, which computes exp2 in a single pipeline cycle. Standard `exp` requires an additional multiply by log2(e) internally. See https://developer.apple.com/library/archive/documentation/3DDrawing/Conceptual/MTLBestPracticesGuide/index.html. ↩
3. Gemma Team, Google DeepMind. "Gemma: Open Models Based on Gemini Research and Technology." arXiv:2403.08295, 2024. The logit soft-capping technique with `cap * tanh(x/cap)` prevents attention score explosion while maintaining gradient flow during training. See https://arxiv.org/abs/2403.08295. ↩
4. Rabe, M.N. and Staats, C. "Self-attention Does Not Need O(n^2) Memory." arXiv:2112.05682, 2021. This paper independently developed the same memory-efficient attention idea as FlashAttention, showing that O(1) memory suffices for single-query attention and O(log n) for self-attention. ↩