Architectural Decision Records
Every codebase is the sum of its decisions. Some of those decisions are obvious in retrospect; others look arbitrary unless you know the alternatives that were considered and rejected. This chapter documents ten key architectural decisions in akunu using the ADR (Architectural Decision Record) format.[^1] For each decision, we state the problem, enumerate the options considered, record the decision, and explain the rationale.
If you are contributing to akunu, modifying it for a different platform, or designing your own inference engine, these records tell you not just what was chosen but why – and more importantly, what would need to change if the underlying assumptions shifted.
ADR-1: Dispatch Table vs. Dynamic Dispatch
Problem
An LLM forward pass consists of dozens of GPU kernel dispatches per layer: embedding lookup, normalization, projections (GEMV/GEMM), RoPE, attention, activation, and residual adds. The engine needs a way to describe and execute this sequence. The two broad approaches are:
- Dynamic dispatch: At each step, the engine code decides which kernel to call, sets up buffer bindings, and dispatches. This is the “interpreter” approach.
- Static dispatch table: Pre-compile the entire forward pass into a flat array of dispatch commands at model load time. At inference time, just iterate the array.
Options Considered
| Option | Description | Pros | Cons |
|---|---|---|---|
| A. Virtual method per layer | Each layer is a C++ object with a forward() method that encodes its own kernels | Clean OOP design, easy to understand | Virtual call overhead per dispatch, cache-unfriendly vtable chasing, hard to batch across tokens |
| B. Interpreter loop | A function that walks the model config and emits dispatch commands at each step | No up-front cost, flexible | Per-token overhead from branching on architecture, dtype, chip config; hard to batch |
| C. Pre-compiled dispatch table | Build a std::vector<DispatchCmd> once during init; replay it for each token | Zero per-token decision overhead, trivially batchable, cache-dense | Higher init cost, patching needed for dynamic fields (position, KV length) |
Decision
Option C: Pre-compiled dispatch table.
Rationale
The dispatch table approach was chosen because the forward pass structure is identical for every token. The only things that change between tokens are:
- The token embedding index (patched via `PATCH_TOKEN_OFFSET`)
- The position for RoPE and KV cache writes (patched via `PATCH_POSITION`)
- The KV sequence length for attention (patched via `PATCH_KV_SEQ_LEN`)
Everything else – pipeline state objects, buffer bindings, threadgroup sizes, parameter structs – is invariant. Building all of this once and replaying it N times is the obvious optimization.
The DispatchCmd struct is a POD (Plain Old Data) type with no pointers to chase, no virtual calls, and no heap allocations beyond the vector itself. At 64 bytes for inline parameters plus fixed-size buffer arrays, it fits neatly in cache lines. The encode_chain() function in dispatch_table.h is a tight double loop:
```
for each token in [0, count):
    for each command in table.commands:
        set pipeline, set buffers, patch params, dispatch
```
This is the hot path. It runs once per chain decode chunk (64-128 tokens). The inner loop body is branch-free except for the patch type switch, which the compiler can lower to a jump table.
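As a concrete (if simplified) illustration, the sketch below models the replay loop in plain C++. This `DispatchCmd` keeps only a patch tag and one parameter field — the real struct also carries pipeline state and buffer bindings — and the `value` formulas are the per-token patches described in this chapter. Where the comment stands, real code would issue a Metal dispatch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Toy model of the pre-compiled dispatch table (ADR-1). Only the
// patch-and-replay mechanics are shown; everything else is invariant.
enum PatchType : uint8_t {
    PATCH_NONE, PATCH_TOKEN_OFFSET, PATCH_POSITION, PATCH_KV_SEQ_LEN
};

struct DispatchCmd {
    PatchType patch = PATCH_NONE;
    uint32_t  value = 0;   // stand-in for the field being patched
};

// Replay the table once per token, patching only the dynamic fields.
inline void encode_chain(std::vector<DispatchCmd>& table,
                         uint32_t start_pos, uint32_t count) {
    for (uint32_t tok = 0; tok < count; ++tok) {
        for (auto& cmd : table) {
            switch (cmd.patch) {          // lowers to a jump table
            case PATCH_TOKEN_OFFSET: cmd.value = tok * 4; break;
            case PATCH_POSITION:     cmd.value = start_pos + tok; break;
            case PATCH_KV_SEQ_LEN:   cmd.value = start_pos + tok + 1; break;
            case PATCH_NONE:         break;  // invariant command
            }
            // real code: set pipeline, set buffers, dispatch
        }
    }
}
```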
Consequences
- Pro: Chain decode became trivial to implement. Batching N tokens is just calling `encode_chain()` with `count=N`.
- Pro: Profiling labels are stored in a parallel `DispatchLabel` vector (cold data), keeping the hot command array dense.
- Con: Adding a new architecture requires a new `build_dispatch_table()` path in `table_builder.h`. The ArchDescriptor (ADR-3) mitigates this by making most architecture differences data-driven rather than code-driven.
- Con: Dynamic control flow (e.g., early exit, mixture-of-experts routing) is harder to express. If akunu ever supports MoE models, the dispatch table design would need extension.
ADR-2: C API vs. C++
Problem
Akunu needs a public API for applications (CLI tools, Swift apps, servers) to load models, tokenize text, and run inference. The API design affects language binding ergonomics, ABI stability, and the mental model for users.
Options Considered
| Option | Description | Pros | Cons |
|---|---|---|---|
| A. C++ class API | class AkunuModel { ... } with methods | Natural for C++ users, can use RAII, templates | C++ ABI is fragile across compilers/versions, hard to bind from Swift/Python/Rust |
| B. C API with opaque handles | akunu_model_t as void*, free functions | Stable ABI, trivial to bind from any language, no name mangling | Verbose, no RAII, error handling via return codes or thread-local strings |
| C. C API wrapping C++ internals | Same as B, but the implementation uses C++ internally | Best of both: stable API surface, modern implementation | Thin translation layer between C and C++ |
Decision
Option C: C API wrapping C++ internals.
Rationale
The primary consumer of akunu’s API is the Swift binding layer (CAkunu/shim.c), which needs a C-compatible interface. Swift can import C headers directly via the Clang importer, but C++ interop (even with Swift 5.9+) is limited and fragile. A pure C API with opaque void* handles is the safest choice.
The actual API surface, visible in include/akunu/akunu.h, follows a consistent pattern:
- Lifecycle: `akunu_load_model()` / `akunu_free_model()`
- Info: `akunu_get_config()`, `akunu_model_memory()`
- Tokenization: `akunu_encode()`, `akunu_decode_token()`
- Generation: `akunu_generate()`, `akunu_chain_decode()`, `akunu_generate_continue()`
- Profiling: `akunu_profile_decode_step()`, `akunu_profile_label()`
- Error: `akunu_get_error()` (thread-local)
All structs passed across the API boundary (AkunuModelConfig, AkunuGenerationStats, AkunuSamplingConfig) are defined in types.h as C-compatible POD types with fixed-width integer fields.
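The opaque-handle pattern itself is easy to see in miniature. The toy below uses invented `toy_*` names — not akunu's real signatures — to show the three ingredients the API relies on: an opaque struct typedef, free functions with C linkage, and a thread-local error string.

```cpp
#include <cassert>
#include <cstring>
#include <string>

// Toy illustration of the opaque-handle C API pattern (ADR-2).
// Callers only ever see a forward-declared handle type.
extern "C" {
typedef struct toy_model_s toy_model_t;   // opaque to callers
}

// The definition lives on the C++ side and can use any C++ features.
struct toy_model_s { std::string path; };

static thread_local std::string g_last_error;

extern "C" toy_model_t* toy_load_model(const char* path) {
    if (path == nullptr || *path == '\0') {
        g_last_error = "empty path";
        return nullptr;                   // errors via null + error string
    }
    return new toy_model_s{path};
}

extern "C" void toy_free_model(toy_model_t* m) { delete m; }

extern "C" const char* toy_get_error(void) { return g_last_error.c_str(); }
```

A Swift or Python binding would wrap the handle in a class whose destructor calls the free function, restoring RAII at the binding layer.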
Consequences
- Pro: The Swift package (`Sources/CAkunu`) imports the C header directly with zero bridging code beyond a thin shim.
- Pro: The API is trivially bindable from Python (ctypes/cffi), Rust (bindgen), and any other language with C FFI.
- Pro: ABI stability – the library can be updated without recompiling consumers, as long as the C function signatures do not change.
- Con: No RAII for model handles. Forgetting `akunu_free_model()` leaks GPU memory. The Swift binding wraps this in a class with `deinit`.
- Con: Error messages are thread-local strings, which is less ergonomic than exceptions or Result types.
ADR-3: Data-Driven Architecture Descriptors
Problem
Akunu supports multiple model architectures: LLaMA, Qwen3, Gemma, Gemma3, Whisper, and BERT. Each architecture has differences in activation functions, normalization placement, RoPE style, embedding scaling, and more. The question is how to handle these differences without littering the codebase with if (arch == "llama") ... else if (arch == "gemma") ... branches.
Options Considered
| Option | Description | Pros | Cons |
|---|---|---|---|
| A. Architecture-specific subclasses | class LlamaModel : public Model, class GemmaModel : public Model | Clean separation of concerns | Virtual dispatch overhead, code duplication across architectures, adding a new arch means a new class |
| B. If/else chains on architecture name | Check cfg.architecture string throughout the codebase | Simple, no abstraction overhead | Scatters architecture logic everywhere, easy to miss a branch, N*M combinatorial explosion |
| C. Data-driven descriptor struct | An ArchDescriptor POD struct that captures all arch-specific behavior as data fields | Single source of truth, no branching in hot path, trivial to add new architectures | Must anticipate all possible variation dimensions up front |
Decision
Option C: Data-driven ArchDescriptor.
Rationale
The key insight is that the differences between architectures are almost entirely parametric, not structural. LLaMA and Gemma have the same transformer skeleton; they differ in:
- Activation function: SiLU (LLaMA) vs GELU (Gemma)
- Embedding scaling: none (LLaMA) vs `sqrt(dim)` (Gemma)
- QK norm: no (LLaMA) vs yes (Qwen3, Gemma)
- Post-attention/FFN norms: no (LLaMA) vs yes (Gemma)
- RoPE style: interleaved (LLaMA) vs NeoX/split-half (Qwen3, Gemma)
- Tied embeddings: no (LLaMA) vs yes (Qwen3, Gemma)
All of these can be captured as fields in a struct. The ArchDescriptor in src/core/arch_descriptor.h has fields like activation_kernel, embedding_scale, has_qk_norm, rope_kernel, and tie_embeddings. Factory functions (arch_llama(), arch_qwen3(), arch_gemma(), etc.) return pre-filled descriptors. The arch_from_config() function maps GGUF metadata strings to the right factory.
The build_dispatch_table() function in table_builder.h reads from the ArchDescriptor and never branches on architecture name. Adding support for a new LLaMA variant (say, Mistral with sliding window attention) is typically a one-line change: modify an existing factory or add a new one.
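A minimal sketch of the descriptor idea follows. The field names match the ones listed above; the enums, defaults, and the `dim` parameter to the Gemma factory are illustrative assumptions, not akunu's actual definitions.

```cpp
#include <cassert>
#include <cmath>

// Data-driven architecture descriptor (ADR-3), reduced to a handful of
// the variation dimensions named in the text.
enum class Act  { SILU, GELU };
enum class Rope { INTERLEAVED, NEOX };

struct ArchDescriptor {
    Act   activation_kernel = Act::SILU;
    float embedding_scale   = 1.0f;   // multiplier after embedding lookup
    bool  has_qk_norm       = false;
    Rope  rope_kernel       = Rope::INTERLEAVED;
    bool  tie_embeddings    = false;
};

// LLaMA is the baseline: the defaults above are its settings.
inline ArchDescriptor arch_llama() { return {}; }

// Qwen3 overrides a few fields of the LLaMA defaults.
inline ArchDescriptor arch_qwen3() {
    ArchDescriptor d = arch_llama();
    d.has_qk_norm    = true;
    d.rope_kernel    = Rope::NEOX;
    d.tie_embeddings = true;
    return d;
}

// Gemma additionally swaps the activation and scales embeddings.
inline ArchDescriptor arch_gemma(int dim) {
    ArchDescriptor d = arch_qwen3();
    d.activation_kernel = Act::GELU;
    d.embedding_scale   = std::sqrt(static_cast<float>(dim));
    return d;
}
```

The table builder would read these fields and never branch on an architecture name — adding a variant means adding or tweaking a factory.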
Consequences
- Pro: Adding Qwen3 support required writing `arch_qwen3()` (4 lines that override 3 fields from the LLaMA defaults) and zero changes to the table builder or decode path.
- Pro: The hot path (dispatch table replay) is completely architecture-agnostic. The architecture was “compiled away” during init.
- Con: Truly novel architectures (e.g., mixture of experts, state-space models) may not fit the descriptor model and would require structural changes.
- Con: Encoder-decoder models (Whisper) stretch the descriptor with fields like `is_encoder_decoder`, `has_cross_attention`, `has_conv_frontend` that are irrelevant for decoder-only models.
ADR-4: Precomputed RoPE via Fused Kernel
Problem
Rotary Position Embeddings (RoPE) apply a rotation to the Q and K vectors based on their position in the sequence. The rotation frequencies are computed as theta^(-2i/d) for each dimension pair i. This computation involves transcendental functions (sin, cos) which are expensive even on GPU hardware.
Options Considered
| Option | Description | Pros | Cons |
|---|---|---|---|
| A. Compute RoPE frequencies per-token | Each decode step computes sin/cos from theta and position | Simple, no precomputation needed | Redundant transcendental function calls for every token |
| B. Precompute frequency table on CPU | Build a [max_seq_len, head_dim] table of sin/cos values at init | Amortizes transcendental cost | Large table (max_seq_len * head_dim * 4 bytes), wastes memory for short contexts |
| C. Precompute frequency divisors | Store [head_dim/2] frequency divisors; compute position*freq in kernel | Tiny table (head_dim/2 floats), position multiply is cheap | Slightly more complex kernel |
| D. Fuse RoPE with QKV projection + KV cache write | Single kernel: GEMV -> RoPE rotate Q,K -> write K,V to cache | Eliminates 2-3 separate kernel dispatches per layer | Complex kernel, harder to debug |
Decision
Option D: Fused RoPE+QKV+KV-write kernel, with precomputed frequency divisors (Option C) as the frequency source.
Rationale
In a non-fused approach, each transformer layer during decode requires:
- GEMV for Q projection
- GEMV for K projection
- GEMV for V projection
- RoPE on Q
- RoPE on K
- KV cache write for K
- KV cache write for V
That is 7 kernel dispatches. The fused kernel rope_qkv_write_f16 (or rope_neox_qkv_write_f16 for NeoX-style) combines steps 4-7 into a single dispatch. Combined with QKV fusion (fusing the three GEMV projections into a single GEMV that writes to a contiguous [q_dim + 2*kv_dim] buffer), the total drops from 7 dispatches to 2 (one fused GEMV, one fused RoPE+KV-write).
The RoPEQKVWriteParams struct captures all the parameters this fused kernel needs:
| Field | Purpose |
|---|---|
| n_kv_heads | Number of KV heads (for GQA) |
| head_dim | Elements per head |
| max_seq_len | KV cache dimension for stride computation |
| pos | Current position (patched per token in chain decode) |
| theta | RoPE base frequency |
| n_heads | Number of Q heads |
| k_elem_offset | Element offset to K section in QKV buffer |
| v_elem_offset | Element offset to V section in QKV buffer |
| freq_scale | Linear RoPE scaling factor (1.0 = no scaling) |
The rope_freqs field in ArchDescriptor stores precomputed frequency divisors when the model provides them (some GGUF files include rope_freqs metadata). Otherwise, the kernel computes frequencies from theta directly using the standard formula.
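A CPU reference for the Option C frequency scheme makes the trade-off concrete: the `head_dim/2` divisors are computed once at init, and the per-token work is one multiply plus a sin/cos per dimension pair. This is a sketch of the standard RoPE formula, not akunu's Metal kernel.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Precompute the frequency divisors theta^(-2i/d), one per dimension pair.
// Done once at model load; the table is only head_dim/2 floats.
inline std::vector<float> precompute_freqs(int head_dim, float theta) {
    std::vector<float> freqs(head_dim / 2);
    for (int i = 0; i < head_dim / 2; ++i)
        freqs[i] = std::pow(theta, -2.0f * i / head_dim);
    return freqs;
}

// Interleaved-style rotation of one head: pairs (x[2i], x[2i+1]) are
// rotated by angle pos * freqs[i]. The kernel does the same per thread.
inline void rope_rotate(std::vector<float>& x,
                        const std::vector<float>& freqs, int pos) {
    for (size_t i = 0; i < freqs.size(); ++i) {
        float ang = pos * freqs[i];
        float c = std::cos(ang), s = std::sin(ang);
        float a = x[2 * i], b = x[2 * i + 1];
        x[2 * i]     = a * c - b * s;
        x[2 * i + 1] = a * s + b * c;
    }
}
```

At position 0 the rotation is the identity, which is a handy sanity check when debugging a fused kernel.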
Consequences
- Pro: Reduces per-layer dispatch count from 7 to 2, saving ~5 kernel launch overheads per layer per token. For a 32-layer model in chain decode (128 tokens), this eliminates 5 * 32 * 128 = 20,480 dispatch commands per chunk.
- Pro: Better memory access pattern. The fused kernel reads QKV once and writes K/V to cache in the same pass, improving cache utilization.
- Con: Two RoPE kernel variants (interleaved and NeoX) must be maintained, each with fused and standalone versions.
- Con: The fused kernel has more parameters (9 fields in `RoPEQKVWriteParams`) and more complex dispatch geometry.
ADR-5: Chain Decode
Problem
The naive approach to autoregressive decoding is: for each token, encode one forward pass into a Metal command buffer, commit it, wait for completion, read back the result, and feed it to the next step. This creates a CPU-GPU synchronization point per token, and each sync costs 30-80 microseconds of overhead.[^2]
Options Considered
| Option | Description | Pros | Cons |
|---|---|---|---|
| A. One token per command buffer | Standard approach: encode, commit, wait, repeat | Simple, argmax result available immediately | ~50us sync overhead per token, GPU idle during CPU readback |
| B. Speculative multi-command buffer | Encode N tokens speculatively, verify in batch | Amortizes sync cost | Requires draft model or prediction, complex verification logic |
| C. Chain decode in single command buffer | Encode N identical forward passes back-to-back, patching only position and token offset. Argmax output of token i feeds as input to token i+1 via GPU buffer. | Single sync per chunk, GPU stays 100% busy | Must use greedy decoding (argmax). Non-greedy requires CPU-side sampling between tokens. |
Decision
Option C: Chain decode with single command buffer.
Rationale
The key observation is that for greedy decoding (temperature=0), the next token is determined entirely on the GPU (argmax of logits). There is no need to read the result back to the CPU between tokens. The encode_chain() function simply repeats the dispatch table N times, patching:
- `buffers[0].offset = tok * 4` for embedding lookup (reads from the token_ids buffer)
- `param.pos = start_position + tok` for RoPE and KV cache
- `param.kv_seq_len = start_position + tok + 1` for attention
The argmax kernel writes its result to token_ids[tok + 1], and the next iteration’s embedding lookup reads from that same buffer. The entire chain runs as one GPU submission with zero CPU intervention.
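A CPU simulation of the chaining idea shows why no host readback is needed between tokens. Here `next_token()` stands in for an entire GPU forward pass plus argmax (its arithmetic is arbitrary, purely for testability); the essential point is that step `tok` writes `token_ids[tok + 1]` and step `tok + 1` reads it, all within one submission.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for one full forward pass + argmax on the GPU. The formula is
// arbitrary; a hypothetical vocab size of 32000 keeps results in range.
inline uint32_t next_token(uint32_t prev, uint32_t pos) {
    return (prev + pos) % 32000;
}

// Chain decode (ADR-5): token i's argmax feeds token i+1's embedding
// lookup through the token_ids buffer, with zero CPU intervention.
inline void chain_decode(std::vector<uint32_t>& token_ids,
                         uint32_t start_pos, uint32_t count) {
    for (uint32_t tok = 0; tok < count; ++tok) {
        uint32_t pos = start_pos + tok;              // PATCH_POSITION
        token_ids[tok + 1] = next_token(token_ids[tok], pos);
    }
}
```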
The chunk size (chain_decode_chunk in ChipConfig) is tuned per chip:
| Chip Class | Chunk Size | Rationale |
|---|---|---|
| M1/M2/M3 base | 64 | Smaller GPU, less command buffer memory |
| M3 Pro | 96 | More GPU cores, but older command processor |
| M4 family, Max/Ultra | 128 | Improved command processor, higher bandwidth |
Consequences
- Pro: Eliminates ~50us * N of sync overhead per chunk. For a 128-token chunk this saves ~6.4ms of raw sync cost; counting the GPU idle gaps around each sync as well, the measured gain is roughly an 8% throughput improvement.
- Pro: GPU utilization approaches 100% within a chunk. Metal System Trace shows a continuous block of GPU activity with no idle gaps.
- Con: Only works for greedy (argmax) decoding. Non-greedy decoding (temperature > 0, top-k, top-p) requires `akunu_decode_step()` with per-token CPU-GPU synchronization.
- Con: The KV cache must be pre-sized for the full chunk. If the context window fills mid-chunk, the chain must terminate early.
- Con: Error recovery is harder – if a token generates an EOS mid-chain, the remaining tokens are wasted work.
ADR-6: GPU Gumbel-Max vs. CPU Sampling
Problem
When temperature > 0, the model needs to sample from the logit distribution rather than take the argmax. Sampling involves: (1) applying temperature scaling, (2) optionally applying repetition penalty, (3) computing probabilities (softmax), and (4) drawing a random sample. Where should this happen – CPU or GPU?
Options Considered
| Option | Description | Pros | Cons |
|---|---|---|---|
| A. CPU sampling | Read logits back from GPU, sample on CPU | Full flexibility (top-k, top-p, min-p, grammar constraints), easy to implement | Requires GPU-to-CPU data transfer of entire logit vector (vocab_size * 2 bytes), breaks chain decode |
| B. GPU argmax only | Keep greedy on GPU, fall back to CPU for non-greedy | Simple, no sampling complexity on GPU | Non-greedy is slower due to sync overhead |
| C. GPU Gumbel-max trick | Add Gumbel noise to logits on GPU, then take argmax. Mathematically equivalent to sampling from the softmax distribution.[^3] | Keeps everything on GPU, compatible with chain decode | Limited to temperature sampling; top-k/top-p require sorting which is expensive on GPU |
Decision
Option A (CPU sampling) for non-greedy, Option B (GPU argmax) for greedy. GPU temperature scaling and repetition penalty are provided as optional GPU-side preprocessing (via akunu_gpu_temperature_scale() and akunu_gpu_repetition_penalty()), but the actual sampling decision happens on the CPU.
Rationale
The Gumbel-max trick (Option C) was prototyped but ultimately not adopted as the default for several reasons:
- Grammar-constrained decoding requires masking invalid tokens before sampling. The xgrammar integration runs on the CPU and produces a token bitmask. Applying this mask on the GPU would require uploading it per token, which adds overhead that negates the chain decode benefit.
- Top-k and top-p sampling require sorting or partial sorting of the logit vector, which is not efficient on Apple’s GPU for the vocabulary sizes used by modern LLMs (32K-128K). A GPU radix sort for 128K elements would consume more time than just reading the logits back to the CPU.
- Sampling with temperature=0 (greedy) accounts for the majority of use cases in benchmarks and many production deployments. Chain decode works perfectly for greedy mode.
- The logit readback cost is bounded: for a 128K vocabulary at FP16, the transfer is 256 KB – well within the UMA zero-copy window on Apple Silicon. The “transfer” is really just a cache flush, not a DMA copy.
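The chosen CPU-side path can be sketched in a few lines. This is an illustrative temperature sampler, not akunu's actual sampler (which also handles top-k/top-p, min-p, repetition penalty, and grammar masks); temperature <= 0 falls back to greedy argmax, matching the chain-decode path.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <random>
#include <vector>

// CPU-side sampling of the kind ADR-6 settles on: logits are read back
// from the GPU (cheap under UMA) and the host makes the decision.
inline int sample_token(const std::vector<float>& logits, float temperature,
                        std::mt19937& rng) {
    if (temperature <= 0.0f)              // greedy: plain argmax
        return static_cast<int>(
            std::max_element(logits.begin(), logits.end()) - logits.begin());

    // Temperature-scaled softmax weights (max-subtracted for stability);
    // discrete_distribution normalizes them itself.
    std::vector<float> w(logits.size());
    float mx = *std::max_element(logits.begin(), logits.end());
    for (size_t i = 0; i < logits.size(); ++i)
        w[i] = std::exp((logits[i] - mx) / temperature);
    std::discrete_distribution<int> dist(w.begin(), w.end());
    return dist(rng);
}
```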
Consequences
- Pro: Full sampling flexibility (temperature, top-k, top-p, min-p, repetition penalty, grammar constraints) without GPU-side complexity.
- Pro: Greedy mode gets the full chain decode benefit with zero overhead.
- Con: Non-greedy decoding cannot use chain decode and pays the per-token sync cost (~50us per token).
- Con: The GPU-side temperature and repetition penalty kernels (`TemperatureScaleParams`, `RepetitionPenaltyParams`) are currently only used when the user explicitly calls the low-level API; the high-level `akunu_generate()` does sampling on the CPU.
ADR-7: Dual Format Support (GGUF + MLX SafeTensors)
Problem
The Apple Silicon LLM ecosystem has two dominant weight formats:
- GGUF: The format from llama.cpp. Block-quantized weights with rich metadata. Supported by virtually every open-source LLM tool.
- MLX SafeTensors: The format from Apple’s MLX framework. Group-quantized weights in SafeTensors container with JSON config. Growing ecosystem, especially for Apple-optimized models.
Should akunu support one or both?
Options Considered
| Option | Description | Pros | Cons |
|---|---|---|---|
| A. GGUF only | Rely on the GGUF ecosystem | Largest model library, well-understood format | Misses MLX-specific models, some newer models ship MLX-first |
| B. MLX only | Align with Apple’s own framework | Native Apple ecosystem, simpler quantization | Misses the vast GGUF library, less community tooling |
| C. Both via WeightProvider abstraction | Unified interface that wraps either backend | Access to both ecosystems | Two code paths to maintain, name mapping complexity |
Decision
Option C: Both formats via WeightProvider.
Rationale
Users should not have to choose a format. If they have a GGUF file from Hugging Face, it should work. If they have an MLX model from mlx-community, it should also work. The WeightProvider class in src/weight/weight_provider.h detects the format at load time:
- If the path is a directory or ends in `.safetensors` -> MLX format
- Otherwise -> GGUF format
Both backends expose the same interface: get_tensor(), get_dtype(), has_tensor(), get_config(), plus metadata accessors. The GGUF backend (WeightStore) wraps the GGUF parser; the MLX backend (MLXWeightStore) wraps the SafeTensors parser plus a name mapping layer (MLX uses HuggingFace naming conventions like model.layers.{n}.self_attn.q_proj.weight, which akunu maps to canonical names like layers.{n}.attention.q.weight).
The dtype system uses internal codes: GGUF dtypes 0-30 for standard GGUF types, plus synthetic codes 99-102 for MLX quantized formats (MLX Q3, Q4, Q6, Q8). The DTypeDescriptor table in dtype_descriptor.h maps each code to the appropriate GEMV, GEMM, and embedding kernels.
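The name-mapping layer amounts to string rewriting. The toy below applies two illustrative rules of the kind the real `kMLXRules` table would contain — the rule set and helper here are invented for the example, using the sample names from the text.

```cpp
#include <cassert>
#include <string>

// Illustrative MLX -> canonical tensor name mapping (ADR-7). A real table
// would hold many rules and handle the layer index generically.
inline std::string map_mlx_name(std::string name) {
    auto replace_once = [&](const std::string& from, const std::string& to) {
        size_t at = name.find(from);
        if (at != std::string::npos) name.replace(at, from.size(), to);
    };
    replace_once("model.layers.", "layers.");                       // HF prefix
    replace_once(".self_attn.q_proj.weight", ".attention.q.weight"); // Q proj
    return name;
}
```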
Consequences
- Pro: Users can load any model from either ecosystem with the same `akunu_load_model()` call.
- Pro: The same Metal kernels are used for both formats where the quantization is compatible (e.g., FP16 weights from either format use the same `gemv_f16` kernel).
- Con: MLX quantization is group-based (group_size=64 typically) while GGUF is block-based (block_size=32 for Q4_0), so the Metal kernels need separate dequantization logic for each.
- Con: The MLX name mapping table (`kMLXRules` in `mlx_weight_store.h`) must be updated when new architectures use different naming conventions.
ADR-8: Fused Kernels
Problem
A transformer layer involves many small operations that are individually simple but collectively expensive due to kernel launch overhead and redundant memory traffic. For example, the FFN block in a SwiGLU model does:
1. GEMV: `gate = W_gate @ x`
2. GEMV: `up = W_up @ x`
3. Elementwise: `act = silu(gate) * up`
4. GEMV: `down = W_down @ act`
Each of those first three steps reads and writes intermediate buffers. Can we fuse them?
Options Considered
| Option | Description | Pros | Cons |
|---|---|---|---|
| A. No fusion | Separate kernels for every operation | Simple, composable, easy to debug | High kernel launch count, redundant memory traffic for intermediates |
| B. Full FFN fusion | Single kernel: gate+up GEMV -> SiLU -> down GEMV | Minimal memory traffic | Extremely complex kernel, hard to tune, loses flexibility |
| C. Selective fusion | Fuse operations where the benefit is clear: gate+up GEMV+activation, QKV+RoPE+KV-write | Good balance of performance and complexity | Must decide which fusions are worth the implementation cost |
Decision
Option C: Selective fusion.
Rationale
Akunu implements several fused kernels where profiling showed clear benefits:
| Fused Kernel | What It Fuses | Benefit |
|---|---|---|
| gemv_q4_0_silu | Gate GEMV + SiLU activation + Up GEMV element multiply | Eliminates 2 intermediate buffer writes/reads, 1 fewer dispatch |
| rope_qkv_write_f16 | RoPE rotation of Q,K + KV cache write for K,V | Eliminates 3-4 separate dispatches per layer (see ADR-4) |
| gemv_kv_* (GEMVKVParams) | K or V projection GEMV + direct KV cache write | Eliminates separate KV cache copy kernel |
| gemv_head_norm_* (GEMVHeadNormParams) | GEMV output + per-head RMSNorm | Eliminates separate norm dispatch for QK-norm models |
The fused SiLU kernels (gemv_q4_0_silu, gemv_mlx_q4_silu, etc.) are particularly valuable because the activation is applied during the GEMV accumulation, before the result is written to device memory. Each thread computes silu(gate_partial) * up_partial in registers, avoiding a round-trip through device memory for the intermediate gate and up buffers.
Weight fusion (gate+up weights concatenated into a single buffer) is a separate but related optimization. The ChipConfig::should_fuse_weights flag enables this on Pro+ chips where the SLC is large enough to benefit from reading the fused weight buffer sequentially.
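A scalar reference implementation clarifies what the fused SiLU kernels compute per output element. The point is that `gate` and `up` live in accumulator registers and only the activated product is ever stored; this is a CPU sketch of the fusion semantics, not the Metal kernel itself.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// silu(x) = x * sigmoid(x), the activation used by SwiGLU FFNs.
inline float silu(float x) { return x / (1.0f + std::exp(-x)); }

// One output element of the fused gate/up op (ADR-8):
//   silu(W_gate[row] . x) * (W_up[row] . x)
// Both dot products accumulate in registers; the intermediate gate and
// up values are never written to memory.
inline float fused_gate_up_row(const std::vector<float>& w_gate_row,
                               const std::vector<float>& w_up_row,
                               const std::vector<float>& x) {
    float gate = 0.0f, up = 0.0f;
    for (size_t i = 0; i < x.size(); ++i) {
        gate += w_gate_row[i] * x[i];
        up   += w_up_row[i]   * x[i];
    }
    return silu(gate) * up;   // only this value leaves the "kernel"
}
```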
Consequences
- Pro: 15-25% reduction in per-layer kernel count for decode, directly translating to lower dispatch overhead.
- Pro: Reduced memory traffic for intermediates, which matters on bandwidth-constrained base chips.
- Con: Each fused kernel variant must be written and tested for every supported dtype. The `gemv_*_silu` kernels exist for Q4_0, MLX Q3, MLX Q4, MLX Q6, and MLX Q8 – five implementations of the same fusion.
- Con: The `DTypeDescriptor` table must track both fused and unfused kernel names, adding complexity to the dtype lookup.
ADR-9: Ping-Pong Scratch Buffers
Problem
The transformer’s residual connection pattern means that each layer’s output is added to its input. The straightforward implementation allocates a new buffer for each intermediate result, but this wastes memory and forces unnecessary allocations in the hot path.
Options Considered
| Option | Description | Pros | Cons |
|---|---|---|---|
| A. Dynamic allocation | device.allocate() per forward pass for intermediates | Simple, no buffer management | Allocation in hot path, memory fragmentation, Metal allocator overhead |
| B. Single scratch buffer with offsets | One large buffer, sub-allocate with manual offset management | Minimal total memory | Complex offset bookkeeping, risk of aliasing bugs, hard to reason about lifetimes |
| C. Named ping-pong buffers | Pre-allocate fixed named buffers (h0, h1, residual, qkv, etc.) at model load. Alternate between h0 and h1 for residual accumulation. | Zero allocation in hot path, clear ownership, easy to debug | Fixed memory footprint regardless of actual usage |
Decision
Option C: Named ping-pong buffers via ScratchBuffers.
Rationale
The ScratchBuffers struct in src/cache/scratch.h pre-allocates all intermediate buffers during akunu_load_model(). The decode buffers are:
| Buffer | Size | Purpose |
|---|---|---|
| h0 | dim * 2 bytes | Residual stream “ping” (FP16) |
| h1 | dim * 2 bytes | Residual stream “pong” (FP16) |
| residual | dim * 2 bytes | Norm output |
| qkv | (q_dim + 2*kv_dim) * 2 bytes | Contiguous Q, K, V projections |
| attn_out | max(q_dim, dim) * 2 bytes | Attention output |
| post_norm | dim * 2 bytes | Post-norm temp (Gemma-style architectures) |
| ffn_gate | ffn_dim * 4 bytes | Gate projection (2x for fused gate+up) |
| ffn_up | ffn_dim * 2 bytes | Up projection |
| ffn_act | ffn_dim * 2 bytes | Activation output |
| logits | vocab_size * 2 bytes | Final logits |
| token_ids | max_chain * 4 bytes | Chain decode token buffer |
The residual connection works by alternating between h0 and h1:
```
Layer input: h0
  -> norm: residual = rmsnorm(h0)
  -> attn: attn_out = attention(qkv_proj(residual))
  -> add:  h1 = h0 + attn_out          (residual add)
  -> norm: residual = rmsnorm(h1)
  -> ffn:  ffn_out = down(silu(gate) * up)
  -> add:  h0 = h1 + ffn_out           (residual add)
Layer output: h0 (same buffer as input -- full circle)
```
Odd-numbered layers swap the roles: input from h1, output to h0. This “ping-pong” pattern means we never need more than two hidden-state-sized buffers for the entire forward pass, regardless of model depth.
A separate set of batch_* buffers handles prefill, where each buffer is scaled by prefill_chunk * dim to handle batched operations.
Consequences
- Pro: Absolutely zero memory allocation during inference. Every buffer is pre-allocated and reused.
- Pro: The `ScratchBuffers` struct is POD-like: a flat collection of `Buffer` handles. The dispatch table references these buffers directly by name, making the data flow through the model explicit and debuggable.
- Pro: Memory footprint is predictable and constant. `akunu_model_memory()` can report the exact GPU memory usage at load time.
- Con: The memory footprint is the maximum needed for any forward pass, even if most operations use less. For example, `ffn_gate` is allocated at `ffn_dim * 4` to support fused gate+up, even when weight fusion is disabled.
- Con: Adding a new buffer for a new operation (e.g., a second attention output for cross-attention in Whisper) requires modifying the `ScratchBuffers` struct and the `create()`/`destroy()` methods.
ADR-10: Head-Major KV Cache Layout
Problem
The KV cache stores the key and value vectors for all previously processed tokens. During attention, the kernel needs to read all K vectors for a given head to compute dot products with the query, then read all V vectors for the same head to compute the weighted sum. The memory layout of the KV cache determines the access pattern and thus the memory bandwidth efficiency.
Options Considered
| Option | Description | Memory Layout | Access Pattern |
|---|---|---|---|
| A. Sequence-major | [max_seq_len, n_kv_heads, head_dim] | All heads for position 0, then all heads for position 1, … | Attention reads for one head are strided across positions |
| B. Head-major | [n_kv_heads, max_seq_len, head_dim] | All positions for head 0 (contiguous), then all positions for head 1, … | Attention reads for one head are contiguous |
| C. Paged | Small fixed-size pages, indirection table | Pages allocated on demand, pages for one head may be non-contiguous | Flexible memory management, but indirection overhead |
Decision
Option B: Head-major layout [n_kv_heads, max_seq_len, head_dim].
Rationale
During decode attention, each query head computes:
```
scores[t] = dot(Q[head], K[head][t])   for t in 0..kv_seq_len-1
output    = sum(scores[t] * V[head][t])
```
With head-major layout, all K vectors for a given head are contiguous in memory: K[head] is a contiguous block of max_seq_len * head_dim FP16 values. The attention kernel can read this block sequentially, which maximizes memory bandwidth utilization on Apple Silicon’s memory controller.
With sequence-major layout (Option A), reading K vectors for one head would require striding by n_kv_heads * head_dim elements between positions – streaming many blocks but using only head_dim elements from each, a useful fraction of just 1/n_kv_heads of the data touched.
The KV stride is precomputed as kv_stride = max_seq_len * head_dim (elements between consecutive KV heads) and stored in the KVCache struct. The AttentionParams struct passes this to the attention kernel via the kv_stride field. A value of 0 means “use kv_seq_len * head_dim” (the dense case), which is useful when the KV cache is exactly filled.
The KVCacheWriteParams struct handles writing new K/V vectors to the correct position:
offset_in_buffer = head * kv_stride + pos * head_dim
This is a simple multiply-add, computed in the fused RoPE+KV-write kernel.
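The addressing arithmetic is small enough to state directly. A sketch with strides in elements, as in the text (the `KVLayout` struct itself is illustrative):

```cpp
#include <cassert>
#include <cstddef>

// Head-major KV cache addressing (ADR-10). All positions for one head
// are contiguous: consecutive positions for the same head differ by
// exactly head_dim elements, which is what the attention kernel streams.
struct KVLayout {
    std::size_t max_seq_len;
    std::size_t head_dim;

    // Elements between consecutive KV heads.
    std::size_t kv_stride() const { return max_seq_len * head_dim; }

    // Where a new K/V vector for (head, pos) lands: one multiply-add.
    std::size_t write_offset(std::size_t head, std::size_t pos) const {
        return head * kv_stride() + pos * head_dim;
    }
};
```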
Consequences
- Pro: Contiguous memory access for attention reads, maximizing Apple Silicon’s memory bandwidth. This is the most important access pattern to optimize because attention cost grows linearly with context length.
- Pro: The GQA (Grouped Query Attention) pattern falls out naturally: Q heads 0..3 all read from KV head 0, which is a single contiguous block. No gather/scatter needed.
- Pro: KV cache shifting (for sliding window) is a simple memmove within each head’s contiguous block. The `KVCacheShiftParams` struct supports this.
- Con: Memory is allocated for the full `max_seq_len` per head, even if the actual sequence is shorter. For a model with n_kv_heads=8, max_seq_len=4096, head_dim=128 in FP16, each layer’s K cache is 8 * 4096 * 128 * 2 bytes = 8 MB. Over 32 layers, that is 256 MB for K alone (plus 256 MB for V).
- Con: Paged attention (Option C) would use less memory for short sequences, but the indirection overhead and implementation complexity were deemed not worth it for the target use case (single-user inference on Apple Silicon with sufficient memory).
Summary: How the Decisions Fit Together
These ten decisions are not independent. They form an interlocking system:
```
+--------------------+
| ArchDescriptor(3)  |---> drives table_builder
+--------------------+
          |
          v
+--------------------+     +-------------------+
| DispatchTable(1)   |<--->| ScratchBuffers(9) |
+--------------------+     +-------------------+
          |                          |
          v                          v
+--------------------+     +-------------------+
| Chain Decode(5)    |     | KV Cache(10)      |
+--------------------+     +-------------------+
          |
          v
+--------------------+
| Fused Kernels(8)   |
|  - RoPE+QKV(4)     |
|  - GEMV+SiLU       |
+--------------------+
          |
          v
+--------------------+     +-------------------+
| C API(2)           |     | WeightProvider(7) |
+--------------------+     +-------------------+
          |                          |
          v                          v
+--------------------+     +-------------------+
| Sampling(6)        |     | GGUF + MLX dtypes |
+--------------------+     +-------------------+
```
The ArchDescriptor (3) feeds into the dispatch table builder. The dispatch table (1) references pre-allocated scratch buffers (9) and KV cache buffers (10), and embeds fused kernels (8, 4). Chain decode (5) replays the dispatch table with minimal patching. The C API (2) wraps all of this behind opaque handles. The WeightProvider (7) supplies weights in either GGUF or MLX format. And the sampling strategy (6) determines whether chain decode can be used (greedy) or falls back to per-token decode.
If you are extending akunu, this dependency graph tells you what you need to touch. Adding a new quantization format? Modify DTypeDescriptor (8/7) and add kernels. Adding a new architecture? Add an ArchDescriptor factory (3). Implementing paged attention? That affects KV cache (10), the dispatch table (1), and the attention kernel (8).
[^1]: The ADR format was popularized by Michael Nygard. See “Documenting Architecture Decisions” (2011). The format used here is a simplified version: Problem, Options, Decision, Rationale, Consequences. See https://cognitect.com/blog/2011/11/15/documenting-architecture-decisions.

[^2]: Measured on M2 Pro: `MTLCommandBuffer` commit + `waitUntilCompleted` averages 45us when the command buffer contains a trivial kernel. The overhead is in the driver and command processor, not the GPU itself.

[^3]: The Gumbel-max trick: if `g_i ~ Gumbel(0,1)`, then `argmax(log(p_i) + g_i) ~ Categorical(p)`. Since `log(p_i) = logit_i / temperature - log(Z)`, and the `log(Z)` term is constant across categories, `argmax(logit_i / temperature + g_i)` samples from the temperature-scaled distribution.