Architectural Decision Records

Every codebase is the sum of its decisions. Some of those decisions are obvious in retrospect; others look arbitrary unless you know the alternatives that were considered and rejected. This chapter documents ten key architectural decisions in akunu using the ADR (Architectural Decision Record) format.1 For each decision, we state the problem, enumerate the options considered, record the decision, and explain the rationale.

If you are contributing to akunu, modifying it for a different platform, or designing your own inference engine, these records tell you not just what was chosen but why – and more importantly, what would need to change if the underlying assumptions shifted.


ADR-1: Dispatch Table vs. Dynamic Dispatch

Problem

An LLM forward pass consists of dozens of GPU kernel dispatches per layer: embedding lookup, normalization, projections (GEMV/GEMM), RoPE, attention, activation, and residual adds. The engine needs a way to describe and execute this sequence. The two broad approaches are:

  1. Dynamic dispatch: At each step, the engine code decides which kernel to call, sets up buffer bindings, and dispatches. This is the “interpreter” approach.
  2. Static dispatch table: Pre-compile the entire forward pass into a flat array of dispatch commands at model load time. At inference time, just iterate the array.

Options Considered

| Option | Description | Pros | Cons |
|---|---|---|---|
| A. Virtual method per layer | Each layer is a C++ object with a forward() method that encodes its own kernels | Clean OOP design, easy to understand | Virtual call overhead per dispatch, cache-unfriendly vtable chasing, hard to batch across tokens |
| B. Interpreter loop | A function that walks the model config and emits dispatch commands at each step | No up-front cost, flexible | Per-token overhead from branching on architecture, dtype, chip config; hard to batch |
| C. Pre-compiled dispatch table | Build a std::vector<DispatchCmd> once during init; replay it for each token | Zero per-token decision overhead, trivially batchable, cache-dense | Higher init cost, patching needed for dynamic fields (position, KV length) |

Decision

Option C: Pre-compiled dispatch table.

Rationale

The dispatch table approach was chosen because the forward pass structure is identical for every token. The only things that change between tokens are:

  • The token embedding index (patched via PATCH_TOKEN_OFFSET)
  • The position for RoPE and KV cache writes (patched via PATCH_POSITION)
  • The KV sequence length for attention (patched via PATCH_KV_SEQ_LEN)

Everything else – pipeline state objects, buffer bindings, threadgroup sizes, parameter structs – is invariant. Building all of this once and replaying it N times is the obvious optimization.

The DispatchCmd struct is a POD (Plain Old Data) type with no pointers to chase, no virtual calls, and no heap allocations beyond the vector itself. At 64 bytes for inline parameters plus fixed-size buffer arrays, it fits neatly in cache lines. The encode_chain() function in dispatch_table.h is a tight double loop:

for each token in [0, count):
    for each command in table.commands:
        set pipeline, set buffers, patch params, dispatch

This is the hot path. It runs once per chain decode chunk (64-128 tokens). The inner loop body is branch-free except for the patch type switch, which the compiler can lower to a jump table.
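The shape of this hot path can be sketched in plain C++. The struct below is illustrative – the real DispatchCmd in dispatch_table.h binds Metal pipeline and buffer objects rather than integer IDs – but it shows the POD layout and the patch-then-dispatch replay loop:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative stand-in for the real DispatchCmd: a 64-byte POD with no
// pointers to chase. The real fields bind Metal pipelines and buffers.
struct DispatchCmd {
    uint32_t pipeline_id;   // index of a pre-built pipeline state
    uint32_t patch_type;    // 0 = static, 1 = token offset, 2 = position
    uint32_t grid_x;        // dispatch geometry, fixed at build time
    uint32_t grid_y;
    uint8_t  params[48];    // inline parameter bytes, patched in place
};
static_assert(sizeof(DispatchCmd) == 64, "POD stays 64 bytes");

// Replay loop: the only per-token work is patching a few fields.
uint64_t encode_chain(std::vector<DispatchCmd>& table,
                      uint32_t start_pos, uint32_t count) {
    uint64_t dispatched = 0;
    for (uint32_t tok = 0; tok < count; ++tok) {
        for (DispatchCmd& cmd : table) {
            switch (cmd.patch_type) {   // lowerable to a jump table
                case 1: /* patch embedding offset = tok * 4 */ break;
                case 2: /* patch position = start_pos + tok */ break;
                default: break;         // fully static command
            }
            ++dispatched;   // here: setPipeline / setBuffers / dispatch
        }
    }
    return dispatched;
}
```

The total dispatch count is simply table.size() * count, which is why batching falls out for free.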

Consequences

  • Pro: Chain decode became trivial to implement. Batching N tokens is just calling encode_chain() with count=N.
  • Pro: Profiling labels are stored in a parallel DispatchLabel vector (cold data), keeping the hot command array dense.
  • Con: Adding a new architecture requires a new build_dispatch_table() path in table_builder.h. The ArchDescriptor (ADR-3) mitigates this by making most architecture differences data-driven rather than code-driven.
  • Con: Dynamic control flow (e.g., early exit, mixture-of-experts routing) is harder to express. If akunu ever supports MoE models, the dispatch table design would need extension.

ADR-2: C API vs. C++

Problem

Akunu needs a public API for applications (CLI tools, Swift apps, servers) to load models, tokenize text, and run inference. The API design affects language binding ergonomics, ABI stability, and the mental model for users.

Options Considered

| Option | Description | Pros | Cons |
|---|---|---|---|
| A. C++ class API | class AkunuModel { ... } with methods | Natural for C++ users, can use RAII, templates | C++ ABI is fragile across compilers/versions, hard to bind from Swift/Python/Rust |
| B. C API with opaque handles | akunu_model_t as void*, free functions | Stable ABI, trivial to bind from any language, no name mangling | Verbose, no RAII, error handling via return codes or thread-local strings |
| C. C API wrapping C++ internals | Same as B, but the implementation uses C++ internally | Best of both: stable API surface, modern implementation | Thin translation layer between C and C++ |

Decision

Option C: C API wrapping C++ internals.

Rationale

The primary consumer of akunu’s API is the Swift binding layer (CAkunu/shim.c), which needs a C-compatible interface. Swift can import C headers directly via the Clang importer, but C++ interop (even with Swift 5.9+) is limited and fragile. A pure C API with opaque void* handles is the safest choice.

The actual API surface, visible in include/akunu/akunu.h, follows a consistent pattern:

  • Lifecycle: akunu_load_model() / akunu_free_model()
  • Info: akunu_get_config(), akunu_model_memory()
  • Tokenization: akunu_encode(), akunu_decode_token()
  • Generation: akunu_generate(), akunu_chain_decode(), akunu_generate_continue()
  • Profiling: akunu_profile_decode_step(), akunu_profile_label()
  • Error: akunu_get_error() (thread-local)

All structs passed across the API boundary (AkunuModelConfig, AkunuGenerationStats, AkunuSamplingConfig) are defined in types.h as C-compatible POD types with fixed-width integer fields.
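For readers unfamiliar with the pattern, option C reduces to an opaque struct declared in the C header and defined as a C++ class in the implementation. The demo_* names below are a hypothetical miniature, not akunu's actual API surface:

```cpp
#include <cassert>
#include <string>

// --- what a public C header would declare (hypothetical names) ---
extern "C" {
typedef struct demo_model demo_model_t;           // opaque handle
demo_model_t* demo_load_model(const char* path);  // NULL on failure
int           demo_vocab_size(const demo_model_t* m);
void          demo_free_model(demo_model_t* m);
}

// --- C++ implementation hidden behind the handle ---
struct demo_model {            // the opaque struct IS the C++ class
    std::string path;
    int vocab_size = 32000;
};

extern "C" demo_model_t* demo_load_model(const char* path) {
    if (!path) return nullptr;
    return new demo_model{path};   // C++ object, handed out as opaque pointer
}
extern "C" int demo_vocab_size(const demo_model_t* m) { return m->vocab_size; }
extern "C" void demo_free_model(demo_model_t* m) { delete m; }
```

Callers never see the struct layout, so the implementation can change freely without breaking the ABI.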

Consequences

  • Pro: The Swift package (Sources/CAkunu) imports the C header directly with zero bridging code beyond a thin shim.
  • Pro: The API is trivially bindable from Python (ctypes/cffi), Rust (bindgen), and any other language with C FFI.
  • Pro: ABI stability – the library can be updated without recompiling consumers, as long as the C function signatures do not change.
  • Con: No RAII for model handles. Forgetting akunu_free_model() leaks GPU memory. The Swift binding wraps this in a class with deinit.
  • Con: Error messages are thread-local strings, which is less ergonomic than exceptions or Result types.

ADR-3: Data-Driven Architecture Descriptors

Problem

Akunu supports multiple model architectures: LLaMA, Qwen3, Gemma, Gemma3, Whisper, and BERT. Each architecture has differences in activation functions, normalization placement, RoPE style, embedding scaling, and more. The question is how to handle these differences without littering the codebase with if (arch == "llama") ... else if (arch == "gemma") ... branches.

Options Considered

| Option | Description | Pros | Cons |
|---|---|---|---|
| A. Architecture-specific subclasses | class LlamaModel : public Model, class GemmaModel : public Model | Clean separation of concerns | Virtual dispatch overhead, code duplication across architectures, adding a new arch means a new class |
| B. If/else chains on architecture name | Check cfg.architecture string throughout the codebase | Simple, no abstraction overhead | Scatters architecture logic everywhere, easy to miss a branch, N*M combinatorial explosion |
| C. Data-driven descriptor struct | An ArchDescriptor POD struct that captures all arch-specific behavior as data fields | Single source of truth, no branching in hot path, trivial to add new architectures | Must anticipate all possible variation dimensions up front |

Decision

Option C: Data-driven ArchDescriptor.

Rationale

The key insight is that the differences between architectures are almost entirely parametric, not structural. LLaMA and Gemma have the same transformer skeleton; they differ in:

  • Activation function: SiLU (LLaMA) vs GELU (Gemma)
  • Embedding scaling: none (LLaMA) vs sqrt(dim) (Gemma)
  • QK norm: no (LLaMA) vs yes (Qwen3, Gemma)
  • Post-attention/FFN norms: no (LLaMA) vs yes (Gemma)
  • RoPE style: interleaved (LLaMA) vs NeoX/split-half (Qwen3, Gemma)
  • Tied embeddings: no (LLaMA) vs yes (Qwen3, Gemma)

All of these can be captured as fields in a struct. The ArchDescriptor in src/core/arch_descriptor.h has fields like activation_kernel, embedding_scale, has_qk_norm, rope_kernel, and tie_embeddings. Factory functions (arch_llama(), arch_qwen3(), arch_gemma(), etc.) return pre-filled descriptors. The arch_from_config() function maps GGUF metadata strings to the right factory.

The build_dispatch_table() function in table_builder.h reads from the ArchDescriptor and never branches on architecture name. Adding support for a new LLaMA variant (say, Mistral with sliding window attention) is typically a one-line change: modify an existing factory or add a new one.
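A minimal sketch of the descriptor-and-factory pattern follows. The field subset is invented for illustration; the real ArchDescriptor in arch_descriptor.h carries many more dimensions:

```cpp
#include <cassert>

// Illustrative slice of a data-driven architecture descriptor.
enum class Activation { SiLU, GELU };
enum class RopeStyle  { Interleaved, NeoX };

struct ArchDescriptor {
    Activation activation = Activation::SiLU;
    RopeStyle  rope_style = RopeStyle::Interleaved;
    bool  has_qk_norm     = false;
    bool  tie_embeddings  = false;
    float embedding_scale = 1.0f;   // Gemma scales embeddings by sqrt(dim)
};

// LLaMA is the baseline: all defaults.
inline ArchDescriptor arch_llama() { return {}; }

// Qwen3 overrides a handful of fields from the LLaMA defaults --
// the whole "new architecture" is a few assignments.
inline ArchDescriptor arch_qwen3() {
    ArchDescriptor d = arch_llama();
    d.has_qk_norm    = true;
    d.rope_style     = RopeStyle::NeoX;
    d.tie_embeddings = true;
    return d;
}
```

Downstream code reads only the fields, never an architecture name, which is what keeps the table builder branch-free.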

Consequences

  • Pro: Adding Qwen3 support required writing arch_qwen3() (4 lines that override 3 fields from the LLaMA defaults) and zero changes to the table builder or decode path.
  • Pro: The hot path (dispatch table replay) is completely architecture-agnostic. The architecture was “compiled away” during init.
  • Con: Truly novel architectures (e.g., mixture of experts, state-space models) may not fit the descriptor model and would require structural changes.
  • Con: Encoder-decoder models (Whisper) stretch the descriptor with fields like is_encoder_decoder, has_cross_attention, has_conv_frontend that are irrelevant for decoder-only models.

ADR-4: Precomputed RoPE via Fused Kernel

Problem

Rotary Position Embeddings (RoPE) apply a rotation to the Q and K vectors based on their position in the sequence. The rotation frequencies are computed as theta^(-2i/d) for each dimension pair i. This computation involves transcendental functions (sin, cos) which are expensive even on GPU hardware.

Options Considered

| Option | Description | Pros | Cons |
|---|---|---|---|
| A. Compute RoPE frequencies per-token | Each decode step computes sin/cos from theta and position | Simple, no precomputation needed | Redundant transcendental function calls for every token |
| B. Precompute frequency table on CPU | Build a [max_seq_len, head_dim] table of sin/cos values at init | Amortizes transcendental cost | Large table (max_seq_len * head_dim * 4 bytes), wastes memory for short contexts |
| C. Precompute frequency divisors | Store [head_dim/2] frequency divisors; compute position*freq in kernel | Tiny table (head_dim/2 floats), position multiply is cheap | Slightly more complex kernel |
| D. Fuse RoPE with QKV projection + KV cache write | Single kernel: GEMV -> RoPE rotate Q,K -> write K,V to cache | Eliminates 2-3 separate kernel dispatches per layer | Complex kernel, harder to debug |

Decision

Option D: Fused RoPE+QKV+KV-write kernel, with precomputed frequency divisors (Option C) as the frequency source.

Rationale

In a non-fused approach, each transformer layer during decode requires:

  1. GEMV for Q projection
  2. GEMV for K projection
  3. GEMV for V projection
  4. RoPE on Q
  5. RoPE on K
  6. KV cache write for K
  7. KV cache write for V

That is 7 kernel dispatches. The fused kernel rope_qkv_write_f16 (or rope_neox_qkv_write_f16 for NeoX-style) combines steps 4-7 into a single dispatch. Combined with QKV fusion (fusing the three GEMV projections into a single GEMV that writes to a contiguous [q_dim + 2*kv_dim] buffer), the total drops from 7 dispatches to 2 (one fused GEMV, one fused RoPE+KV-write).

The RoPEQKVWriteParams struct captures all the parameters this fused kernel needs:

| Field | Purpose |
|---|---|
| n_kv_heads | Number of KV heads (for GQA) |
| head_dim | Elements per head |
| max_seq_len | KV cache dimension for stride computation |
| pos | Current position (patched per token in chain decode) |
| theta | RoPE base frequency |
| n_heads | Number of Q heads |
| k_elem_offset | Byte offset to K section in QKV buffer |
| v_elem_offset | Byte offset to V section in QKV buffer |
| freq_scale | Linear RoPE scaling factor (1.0 = no scaling) |

The rope_freqs field in ArchDescriptor stores precomputed frequency divisors when the model provides them (some GGUF files include rope_freqs metadata). Otherwise, the kernel computes frequencies from theta directly using the standard formula.
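A scalar reference for the option-C frequency scheme, assuming interleaved pairing: rope_freqs and rope_rotate below are hypothetical names that mirror what the kernel does per dimension pair, not akunu code:

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Precompute the [head_dim/2] frequency divisors once: theta^(-2i/d).
std::vector<float> rope_freqs(int head_dim, float theta) {
    std::vector<float> f(head_dim / 2);
    for (int i = 0; i < head_dim / 2; ++i)
        f[i] = std::pow(theta, -2.0f * i / head_dim);
    return f;
}

// Interleaved-style rotation of one head in place; pairs are (x[2i], x[2i+1]).
// The only per-position transcendental work is sin/cos of pos * freq.
void rope_rotate(float* x, int head_dim, int pos,
                 const std::vector<float>& freqs) {
    for (int i = 0; i < head_dim / 2; ++i) {
        float a = pos * freqs[i];   // cheap multiply per pair
        float c = std::cos(a), s = std::sin(a);
        float x0 = x[2 * i], x1 = x[2 * i + 1];
        x[2 * i]     = x0 * c - x1 * s;
        x[2 * i + 1] = x0 * s + x1 * c;
    }
}
```

At position 0 every angle is zero, so the rotation is the identity – a handy sanity check when debugging the fused kernel.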

Consequences

  • Pro: Reduces per-layer dispatch count from 7 to 2, saving ~5 kernel launch overheads per layer per token. For a 32-layer model in chain decode (128 tokens), this eliminates 5 * 32 * 128 = 20,480 dispatch commands per chunk.
  • Pro: Better memory access pattern. The fused kernel reads QKV once and writes K/V to cache in the same pass, improving cache utilization.
  • Con: Two RoPE kernel variants (interleaved and NeoX) must be maintained, each with fused and standalone versions.
  • Con: The fused kernel has more parameters (9 fields in RoPEQKVWriteParams) and more complex dispatch geometry.

ADR-5: Chain Decode

Problem

The naive approach to autoregressive decoding is: for each token, encode one forward pass into a Metal command buffer, commit it, wait for completion, read back the result, and feed it to the next step. This creates a CPU-GPU synchronization point per token, and each sync costs 30-80 microseconds of overhead.2

Options Considered

| Option | Description | Pros | Cons |
|---|---|---|---|
| A. One token per command buffer | Standard approach: encode, commit, wait, repeat | Simple, argmax result available immediately | ~50us sync overhead per token, GPU idle during CPU readback |
| B. Speculative multi-command buffer | Encode N tokens speculatively, verify in batch | Amortizes sync cost | Requires draft model or prediction, complex verification logic |
| C. Chain decode in single command buffer | Encode N identical forward passes back-to-back, patching only position and token offset. Argmax output of token i feeds as input to token i+1 via GPU buffer. | Single sync per chunk, GPU stays 100% busy | Must use greedy decoding (argmax). Non-greedy requires CPU-side sampling between tokens. |

Decision

Option C: Chain decode with single command buffer.

Rationale

The key observation is that for greedy decoding (temperature=0), the next token is determined entirely on the GPU (argmax of logits). There is no need to read the result back to the CPU between tokens. The encode_chain() function simply repeats the dispatch table N times, patching:

  • buffers[0].offset = tok * 4 for embedding lookup (reads from token_ids buffer)
  • param.pos = start_position + tok for RoPE and KV cache
  • param.kv_seq_len = start_position + tok + 1 for attention

The argmax kernel writes its result to token_ids[tok + 1], and the next iteration’s embedding lookup reads from that same buffer. The entire chain runs as one GPU submission with zero CPU intervention.
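The three patch formulas can be collected in one helper; ChainPatch and chain_patch are hypothetical names used only for illustration:

```cpp
#include <cassert>
#include <cstdint>

// Per-token patch values for chain decode, as listed in the text.
struct ChainPatch {
    uint32_t embed_offset;  // buffers[0].offset for the embedding lookup
    uint32_t pos;           // RoPE / KV-cache-write position
    uint32_t kv_seq_len;    // attention length including this token
};

inline ChainPatch chain_patch(uint32_t start_position, uint32_t tok) {
    return { tok * 4,                     // token_ids is a u32 buffer
             start_position + tok,
             start_position + tok + 1 };
}
```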

The chunk size (chain_decode_chunk in ChipConfig) is tuned per chip:

| Chip Class | Chunk Size | Rationale |
|---|---|---|
| M1/M2/M3 base | 64 | Smaller GPU, less command buffer memory |
| M3 Pro | 96 | More GPU cores, but older command processor |
| M4 family, Max/Ultra | 128 | Improved command processor, higher bandwidth |

Consequences

  • Pro: Eliminates ~50us * N sync overhead per chunk. For 128 tokens at 80 tok/s, this saves ~6.4ms out of a ~1.6s chunk – a ~0.4% gain at that speed, and proportionally more on smaller, faster models, since the sync cost is fixed per token.
  • Pro: GPU utilization approaches 100% within a chunk. Metal System Trace shows a continuous block of GPU activity with no idle gaps.
  • Con: Only works for greedy (argmax) decoding. Non-greedy decoding (temperature > 0, top-k, top-p) requires akunu_decode_step() with per-token CPU-GPU synchronization.
  • Con: The KV cache must be pre-sized for the full chunk. If the context window fills mid-chunk, the chain must terminate early.
  • Con: Error recovery is harder – if a token generates an EOS mid-chain, the remaining tokens are wasted work.

ADR-6: GPU Gumbel-Max vs. CPU Sampling

Problem

When temperature > 0, the model needs to sample from the logit distribution rather than take the argmax. Sampling involves: (1) applying temperature scaling, (2) optionally applying repetition penalty, (3) computing probabilities (softmax), and (4) drawing a random sample. Where should this happen – CPU or GPU?

Options Considered

| Option | Description | Pros | Cons |
|---|---|---|---|
| A. CPU sampling | Read logits back from GPU, sample on CPU | Full flexibility (top-k, top-p, min-p, grammar constraints), easy to implement | Requires GPU-to-CPU data transfer of entire logit vector (vocab_size * 2 bytes), breaks chain decode |
| B. GPU argmax only | Keep greedy on GPU, fall back to CPU for non-greedy | Simple, no sampling complexity on GPU | Non-greedy is slower due to sync overhead |
| C. GPU Gumbel-max trick | Add Gumbel noise to logits on GPU, then take argmax. Mathematically equivalent to sampling from the softmax distribution.3 | Keeps everything on GPU, compatible with chain decode | Limited to temperature sampling; top-k/top-p require sorting which is expensive on GPU |

Decision

Option A (CPU sampling) for non-greedy, Option B (GPU argmax) for greedy. GPU temperature scaling and repetition penalty are provided as optional GPU-side preprocessing (via akunu_gpu_temperature_scale() and akunu_gpu_repetition_penalty()), but the actual sampling decision happens on the CPU.

Rationale

The Gumbel-max trick (Option C) was prototyped but ultimately not adopted as the default for several reasons:

  1. Grammar-constrained decoding requires masking invalid tokens before sampling. The xgrammar integration runs on the CPU and produces a token bitmask. Applying this mask on the GPU would require uploading it per token, which adds overhead that negates the chain decode benefit.

  2. Top-k and top-p sampling require sorting or partial sorting of the logit vector, which is not efficient on Apple’s GPU for the vocabulary sizes used by modern LLMs (32K-128K). A GPU radix sort for 128K elements would consume more time than just reading the logits back to the CPU.

  3. Sampling with temperature=0 (greedy) accounts for the majority of use cases in benchmarks and many production deployments. Chain decode works perfectly for greedy mode.

  4. The logit readback cost is bounded: for a 128K vocabulary at FP16, the transfer is 256 KB – well within the UMA zero-copy window on Apple Silicon. The “transfer” is really just a cache flush, not a DMA copy.
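For reference, the Gumbel-max trick from footnote 3 fits in a few lines on the CPU too. This is a generic sketch of the technique, not akunu's sampler:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Gumbel-max sampling: argmax(logit_i / T + g_i) with g_i ~ Gumbel(0,1)
// draws from the temperature-T softmax without ever computing softmax.
size_t gumbel_max_sample(const std::vector<float>& logits,
                         float temperature, std::mt19937& rng) {
    std::uniform_real_distribution<float> u(1e-20f, 1.0f);
    size_t best = 0;
    float best_val = -INFINITY;
    for (size_t i = 0; i < logits.size(); ++i) {
        float g = -std::log(-std::log(u(rng)));   // Gumbel(0,1) noise
        float v = logits[i] / temperature + g;
        if (v > best_val) { best_val = v; best = i; }
    }
    return best;
}
```

The GPU version is the same idea with the argmax done by the existing reduction kernel, which is what makes it chain-decode compatible in principle.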

Consequences

  • Pro: Full sampling flexibility (temperature, top-k, top-p, min-p, repetition penalty, grammar constraints) without GPU-side complexity.
  • Pro: Greedy mode gets the full chain decode benefit with zero overhead.
  • Con: Non-greedy decoding cannot use chain decode and pays the per-token sync cost (~50us per token).
  • Con: The GPU-side temperature and repetition penalty kernels (TemperatureScaleParams, RepetitionPenaltyParams) are currently only used when the user explicitly calls the low-level API; the high-level akunu_generate() does sampling on the CPU.

ADR-7: Dual Format Support (GGUF + MLX SafeTensors)

Problem

The Apple Silicon LLM ecosystem has two dominant weight formats:

  1. GGUF: The format from llama.cpp. Block-quantized weights with rich metadata. Supported by virtually every open-source LLM tool.
  2. MLX SafeTensors: The format from Apple’s MLX framework. Group-quantized weights in SafeTensors container with JSON config. Growing ecosystem, especially for Apple-optimized models.

Should akunu support one or both?

Options Considered

| Option | Description | Pros | Cons |
|---|---|---|---|
| A. GGUF only | Rely on the GGUF ecosystem | Largest model library, well-understood format | Misses MLX-specific models, some newer models ship MLX-first |
| B. MLX only | Align with Apple’s own framework | Native Apple ecosystem, simpler quantization | Misses the vast GGUF library, less community tooling |
| C. Both via WeightProvider abstraction | Unified interface that wraps either backend | Access to both ecosystems | Two code paths to maintain, name mapping complexity |

Decision

Option C: Both formats via WeightProvider.

Rationale

Users should not have to choose a format. If they have a GGUF file from Hugging Face, it should work. If they have an MLX model from mlx-community, it should also work. The WeightProvider class in src/weight/weight_provider.h detects the format at load time:

  • If the path is a directory or ends in .safetensors -> MLX format
  • Otherwise -> GGUF format

Both backends expose the same interface: get_tensor(), get_dtype(), has_tensor(), get_config(), plus metadata accessors. The GGUF backend (WeightStore) wraps the GGUF parser; the MLX backend (MLXWeightStore) wraps the SafeTensors parser plus a name mapping layer (MLX uses HuggingFace naming conventions like model.layers.{n}.self_attn.q_proj.weight, which akunu maps to canonical names like layers.{n}.attention.q.weight).
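The detection rule is small enough to sketch directly; detect_format is a hypothetical stand-in for the probe inside WeightProvider, with the directory check passed in so the function stays self-contained:

```cpp
#include <cassert>
#include <string>

enum class WeightFormat { GGUF, MLX };

// Load-time format probe, as described in the text: directories and
// .safetensors files are treated as MLX models, everything else as GGUF.
WeightFormat detect_format(const std::string& path, bool is_directory) {
    const std::string ext = ".safetensors";
    bool has_ext = path.size() >= ext.size() &&
                   path.compare(path.size() - ext.size(),
                                ext.size(), ext) == 0;
    return (is_directory || has_ext) ? WeightFormat::MLX
                                     : WeightFormat::GGUF;
}
```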

The dtype system uses internal codes: GGUF dtypes 0-30 for standard GGUF types, plus synthetic codes 99-102 for MLX quantized formats (MLX Q3, Q4, Q6, Q8). The DTypeDescriptor table in dtype_descriptor.h maps each code to the appropriate GEMV, GEMM, and embedding kernels.

Consequences

  • Pro: Users can load any model from either ecosystem with the same akunu_load_model() call.
  • Pro: The same Metal kernels are used for both formats where the quantization is compatible (e.g., FP16 weights from either format use the same gemv_f16 kernel).
  • Con: MLX quantization is group-based (group_size=64 typically) while GGUF is block-based (block_size=32 for Q4_0). Different dequantization logic in the Metal kernels.
  • Con: The MLX name mapping table (kMLXRules in mlx_weight_store.h) must be updated when new architectures use different naming conventions.

ADR-8: Fused Kernels

Problem

A transformer layer involves many small operations that are individually simple but collectively expensive due to kernel launch overhead and redundant memory traffic. For example, the FFN block in a SwiGLU model does:

  1. GEMV: gate = W_gate @ x
  2. GEMV: up = W_up @ x
  3. Elementwise: act = silu(gate) * up
  4. GEMV: down = W_down @ act

Each of those first three steps reads and writes intermediate buffers. Can we fuse them?

Options Considered

| Option | Description | Pros | Cons |
|---|---|---|---|
| A. No fusion | Separate kernels for every operation | Simple, composable, easy to debug | High kernel launch count, redundant memory traffic for intermediates |
| B. Full FFN fusion | Single kernel: gate+up GEMV -> SiLU -> down GEMV | Minimal memory traffic | Extremely complex kernel, hard to tune, loses flexibility |
| C. Selective fusion | Fuse operations where the benefit is clear: gate+up GEMV+activation, QKV+RoPE+KV-write | Good balance of performance and complexity | Must decide which fusions are worth the implementation cost |

Decision

Option C: Selective fusion.

Rationale

Akunu implements several fused kernels where profiling showed clear benefits:

| Fused Kernel | What It Fuses | Benefit |
|---|---|---|
| gemv_q4_0_silu | Gate GEMV + SiLU activation + Up GEMV element multiply | Eliminates 2 intermediate buffer writes/reads, 1 fewer dispatch |
| rope_qkv_write_f16 | RoPE rotation of Q,K + KV cache write for K,V | Eliminates 3-4 separate dispatches per layer (see ADR-4) |
| gemv_kv_* (GEMVKVParams) | K or V projection GEMV + direct KV cache write | Eliminates separate KV cache copy kernel |
| gemv_head_norm_* (GEMVHeadNormParams) | GEMV output + per-head RMSNorm | Eliminates separate norm dispatch for QK-norm models |

The fused SiLU kernels (gemv_q4_0_silu, gemv_mlx_q4_silu, etc.) are particularly valuable because the activation is applied during the GEMV accumulation, before the result is written to device memory. Each thread computes silu(gate_partial) * up_partial in registers, avoiding a round-trip through device memory for the intermediate gate and up buffers.
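A scalar CPU reference of the fused computation makes the memory-traffic argument concrete; fused_gate_up is a hypothetical name, and the point is that gate and up never touch an intermediate buffer:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

inline float silu(float x) { return x / (1.0f + std::exp(-x)); }

// Reference for the fused gate/up GEMV: each output element is
// silu(gate_row . x) * (up_row . x), accumulated in one pass over x.
std::vector<float> fused_gate_up(const std::vector<std::vector<float>>& w_gate,
                                 const std::vector<std::vector<float>>& w_up,
                                 const std::vector<float>& x) {
    std::vector<float> act(w_gate.size());
    for (size_t r = 0; r < w_gate.size(); ++r) {
        float gate = 0.0f, up = 0.0f;
        for (size_t c = 0; c < x.size(); ++c) {   // single pass over x
            gate += w_gate[r][c] * x[c];
            up   += w_up[r][c]   * x[c];
        }
        act[r] = silu(gate) * up;   // stays in registers on the GPU
    }
    return act;
}
```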

Weight fusion (gate+up weights concatenated into a single buffer) is a separate but related optimization. The ChipConfig::should_fuse_weights flag enables this on Pro+ chips where the SLC is large enough to benefit from reading the fused weight buffer sequentially.

Consequences

  • Pro: 15-25% reduction in per-layer kernel count for decode, directly translating to lower dispatch overhead.
  • Pro: Reduced memory traffic for intermediates, which matters on bandwidth-constrained base chips.
  • Con: Each fused kernel variant must be written and tested for every supported dtype. The gemv_*_silu kernels exist for Q4_0, MLX Q3, MLX Q4, MLX Q6, and MLX Q8 – five implementations of the same fusion.
  • Con: The DTypeDescriptor table must track both fused and unfused kernel names, adding complexity to the dtype lookup.

ADR-9: Ping-Pong Scratch Buffers

Problem

The transformer’s residual connection pattern means that each layer’s output is added to its input. The straightforward implementation allocates a new buffer for each intermediate result, but this wastes memory and forces unnecessary allocations in the hot path.

Options Considered

| Option | Description | Pros | Cons |
|---|---|---|---|
| A. Dynamic allocation | device.allocate() per forward pass for intermediates | Simple, no buffer management | Allocation in hot path, memory fragmentation, Metal allocator overhead |
| B. Single scratch buffer with offsets | One large buffer, sub-allocate with manual offset management | Minimal total memory | Complex offset bookkeeping, risk of aliasing bugs, hard to reason about lifetimes |
| C. Named ping-pong buffers | Pre-allocate fixed named buffers (h0, h1, residual, qkv, etc.) at model load. Alternate between h0 and h1 for residual accumulation. | Zero allocation in hot path, clear ownership, easy to debug | Fixed memory footprint regardless of actual usage |

Decision

Option C: Named ping-pong buffers via ScratchBuffers.

Rationale

The ScratchBuffers struct in src/cache/scratch.h pre-allocates all intermediate buffers during akunu_load_model(). The decode buffers are:

| Buffer | Size | Purpose |
|---|---|---|
| h0 | dim * 2 bytes | Residual stream “ping” (FP16) |
| h1 | dim * 2 bytes | Residual stream “pong” (FP16) |
| residual | dim * 2 bytes | Norm output |
| qkv | (q_dim + 2*kv_dim) * 2 bytes | Contiguous Q, K, V projections |
| attn_out | max(q_dim, dim) * 2 bytes | Attention output |
| post_norm | dim * 2 bytes | Post-norm temp (Gemma-style architectures) |
| ffn_gate | ffn_dim * 4 bytes | Gate projection (2x for fused gate+up) |
| ffn_up | ffn_dim * 2 bytes | Up projection |
| ffn_act | ffn_dim * 2 bytes | Activation output |
| logits | vocab_size * 2 bytes | Final logits |
| token_ids | max_chain * 4 bytes | Chain decode token buffer |

The residual connection works by alternating between h0 and h1:

Layer input:  h0
  -> norm:    residual = rmsnorm(h0)
  -> attn:    attn_out = attention(qkv_proj(residual))
  -> add:     h1 = h0 + attn_out     (residual add)
  -> norm:    residual = rmsnorm(h1)
  -> ffn:     ffn_out = down(silu(gate) * up)
  -> add:     h0 = h1 + ffn_out      (residual add)
Layer output: h0  (same buffer as input -- full circle)

Odd-numbered layers swap the roles: input from h1, output to h0. This “ping-pong” pattern means we never need more than two hidden-state-sized buffers for the entire forward pass, regardless of model depth.

A separate set of batch_* buffers handles prefill, where each buffer is scaled by prefill_chunk * dim to handle batched operations.

Consequences

  • Pro: Absolutely zero memory allocation during inference. Every buffer is pre-allocated and reused.
  • Pro: The ScratchBuffers struct is POD-like: a flat collection of Buffer handles. The dispatch table references these buffers directly by name, making the data flow through the model explicit and debuggable.
  • Pro: Memory footprint is predictable and constant. akunu_model_memory() can report the exact GPU memory usage at load time.
  • Con: The memory footprint is the maximum needed for any forward pass, even if most operations use less. For example, ffn_gate is allocated at ffn_dim * 4 to support fused gate+up, even when weight fusion is disabled.
  • Con: Adding a new buffer for a new operation (e.g., a second attention output for cross-attention in Whisper) requires modifying the ScratchBuffers struct and the create() / destroy() methods.

ADR-10: Head-Major KV Cache Layout

Problem

The KV cache stores the key and value vectors for all previously processed tokens. During attention, the kernel needs to read all K vectors for a given head to compute dot products with the query, then read all V vectors for the same head to compute the weighted sum. The memory layout of the KV cache determines the access pattern and thus the memory bandwidth efficiency.

Options Considered

| Option | Memory Layout | Description | Access Pattern |
|---|---|---|---|
| A. Sequence-major | [max_seq_len, n_kv_heads, head_dim] | All heads for position 0, then all heads for position 1, … | Attention reads for one head are strided across positions |
| B. Head-major | [n_kv_heads, max_seq_len, head_dim] | All positions for head 0 (contiguous), then all positions for head 1, … | Attention reads for one head are contiguous |
| C. Paged | Small fixed-size pages, indirection table | Pages allocated on demand, pages for one head may be non-contiguous | Flexible memory management, but indirection overhead |

Decision

Option B: Head-major layout [n_kv_heads, max_seq_len, head_dim].

Rationale

During decode attention, each query head computes:

scores[t] = dot(Q[head], K[head][t])  for t in 0..kv_seq_len-1
output = sum(scores[t] * V[head][t])

With head-major layout, all K vectors for a given head are contiguous in memory: K[head] is a contiguous block of max_seq_len * head_dim FP16 values. The attention kernel can read this block sequentially, which maximizes memory bandwidth utilization on Apple Silicon’s memory controller.

With sequence-major layout (Option A), reading K vectors for one head would require striding by n_kv_heads * head_dim elements between positions, resulting in poor bandwidth utilization – only a 1/n_kv_heads fraction of each contiguous region fetched belongs to the head being processed.

The KV stride is precomputed as kv_stride = max_seq_len * head_dim (elements between consecutive KV heads) and stored in the KVCache struct. The AttentionParams struct passes this to the attention kernel via the kv_stride field. A value of 0 means “use kv_seq_len * head_dim” (the dense case), which is useful when the KV cache is exactly filled.

The KVCacheWriteParams struct handles writing new K/V vectors to the correct position:

offset_in_buffer = head * kv_stride + pos * head_dim

This is a simple multiply-add, computed in the fused RoPE+KV-write kernel.
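The addressing reduces to one helper; kv_elem_offset is a hypothetical name for the multiply-add the kernel performs:

```cpp
#include <cassert>
#include <cstdint>

// Head-major KV addressing from the text: consecutive heads are
// kv_stride = max_seq_len * head_dim elements apart, consecutive
// positions within a head are head_dim elements apart.
inline uint64_t kv_elem_offset(uint32_t head, uint32_t pos,
                               uint32_t max_seq_len, uint32_t head_dim) {
    uint64_t kv_stride = uint64_t(max_seq_len) * head_dim;
    return head * kv_stride + uint64_t(pos) * head_dim;
}
```

The contiguity claim is visible directly: for a fixed head, stepping pos by one advances the offset by exactly head_dim elements.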

Consequences

  • Pro: Contiguous memory access for attention reads, maximizing Apple Silicon’s memory bandwidth. This is the most important access pattern to optimize because attention cost grows linearly with context length.
  • Pro: The GQA (Grouped Query Attention) pattern falls out naturally: Q heads 0..3 all read from KV head 0, which is a single contiguous block. No gather/scatter needed.
  • Pro: KV cache shifting (for sliding window) is a simple memmove within each head’s contiguous block. The KVCacheShiftParams struct supports this.
  • Con: Memory is allocated for the full max_seq_len per head, even if the actual sequence is shorter. For a model with n_kv_heads=8, max_seq_len=4096, head_dim=128 in FP16, each layer’s K cache is 8 * 4096 * 128 * 2 = 8 MB. Over 32 layers, that is 256 MB for K alone (plus 256 MB for V).
  • Con: Paged attention (Option C) would use less memory for short sequences, but the indirection overhead and implementation complexity were deemed not worth it for the target use case (single-user inference on Apple Silicon with sufficient memory).

Summary: How the Decisions Fit Together

These ten decisions are not independent. They form an interlocking system:

+--------------------+
| ArchDescriptor(3)  |---> drives table_builder
+--------------------+
         |
         v
+--------------------+     +-------------------+
| DispatchTable(1)   |<--->| ScratchBuffers(9) |
+--------------------+     +-------------------+
         |                          |
         v                          v
+--------------------+     +-------------------+
| Chain Decode(5)    |     | KV Cache(10)      |
+--------------------+     +-------------------+
         |
         v
+--------------------+
| Fused Kernels(8)   |
| - RoPE+QKV(4)      |
| - GEMV+SiLU        |
+--------------------+
         |
         v
+--------------------+     +-------------------+
| C API(2)           |     | WeightProvider(7) |
+--------------------+     +-------------------+
         |                          |
         v                          v
+--------------------+     +-------------------+
| Sampling(6)        |     | GGUF + MLX dtypes |
+--------------------+     +-------------------+

The ArchDescriptor (3) feeds into the dispatch table builder. The dispatch table (1) references pre-allocated scratch buffers (9) and KV cache buffers (10), and embeds fused kernels (8, 4). Chain decode (5) replays the dispatch table with minimal patching. The C API (2) wraps all of this behind opaque handles. The WeightProvider (7) supplies weights in either GGUF or MLX format. And the sampling strategy (6) determines whether chain decode can be used (greedy) or falls back to per-token decode.

If you are extending akunu, this dependency graph tells you what you need to touch. Adding a new quantization format? Modify DTypeDescriptor (8/7) and add kernels. Adding a new architecture? Add an ArchDescriptor factory (3). Implementing paged attention? That affects KV cache (10), the dispatch table (1), and the attention kernel (8).


  1. The ADR format was popularized by Michael Nygard. See “Documenting Architecture Decisions” (2011). The format used here is a simplified version: Problem, Options, Decision, Rationale, Consequences. See https://cognitect.com/blog/2011/11/15/documenting-architecture-decisions.

  2. Measured on M2 Pro: MTLCommandBuffer commit + waitUntilCompleted averages 45us when the command buffer contains a trivial kernel. The overhead is in the driver and command processor, not the GPU itself.

  3. The Gumbel-max trick: if g_i ~ Gumbel(0,1), then argmax(log(p_i) + g_i) ~ Categorical(p). Since log(p_i) = logit_i / temperature - log(Z), and the log(Z) term is constant across categories, argmax(logit_i / temperature + g_i) samples from the temperature-scaled distribution.