Introduction
Welcome to Inside Akunu: Building a High-Performance Inference Engine for Apple Silicon. This is a book about going fast on hardware that most people barely understand. It is also a book about taking apart an inference engine – piece by piece, kernel by kernel, buffer by buffer – until you understand every decision that went into making it one of the fastest ways to run large language models on a Mac.
If you have ever wondered why your MacBook Pro can run a 70-billion-parameter model at conversational speed, or how a Metal compute shader can squeeze every last byte of bandwidth out of Apple’s unified memory, or what exactly happens between the moment you press Enter on a prompt and the moment the first token appears on screen, this book is for you.
What Is Akunu?
Akunu is a high-performance inference engine purpose-built for Apple Silicon. It runs large language models (LLMs), speech-to-text models (Whisper), and other transformer-based architectures entirely on the GPU cores of Apple’s M-series chips. It is written in C++17 (with a C API for FFI compatibility) and Metal Shading Language (MSL), with no dependencies on heavyweight frameworks like PyTorch, CoreML, or even Apple’s own MPS (Metal Performance Shaders) graph library.
The numbers speak for themselves: in decode throughput – the metric that determines how fast tokens appear on screen during generation – akunu achieves 1.83x faster decode than llama.cpp,1 the most popular open-source inference engine for consumer hardware. This is not a marginal improvement squeezed out of micro-optimizations. It is the result of fundamentally rethinking how to map the inference workload onto Apple Silicon’s unique architecture.
Here is a snapshot of decode performance on an M4 Pro (16 GPU cores, 273 GB/s):
| Model (GGUF) | llama.cpp (tok/s) | akunu (tok/s) | Speedup |
|---|---|---|---|
| Qwen3-0.6B Q3_K_S | 123 | 448 | 3.66x |
| Qwen3-0.6B Q4_0 | 169 | 465 | 2.75x |
| TinyLlama-1.1B Q4_0 | 208 | 343 | 1.65x |
| Llama-3.2-1B Q4_0 | 189 | 294 | 1.55x |
| Qwen3-4B Q4_K_M | 63 | 89 | 1.42x |
| Average (19 models) | | | 1.83x |
How does akunu achieve this? Not through any single trick, but through a relentless focus on the hardware. Every kernel is hand-tuned for Apple GPU microarchitecture. Every buffer allocation is designed around unified memory. Every dispatch decision accounts for the chip generation, the System Level Cache size, and the specific GPU core count of the machine it is running on. The entire inference pipeline – from loading model weights off disk to sampling the next token – is a single, carefully orchestrated sequence of Metal compute commands with minimal CPU overhead.
Who Is This Book For?
This book is written for people with a computer science background who want to understand how high-performance GPU programming works on Apple Silicon, and specifically how that knowledge is applied to build a state-of-the-art inference engine.
You should be comfortable with:
- C/C++ programming. Akunu's core is C++17, with a C API for external use. You should know your way around pointers, structs, and manual memory management.
- Basic linear algebra. Matrix multiplication, dot products, transpose operations. You do not need a PhD in numerical methods, but you should know what a matrix-vector product is and why it matters for neural networks.
- General computer architecture. Caches, memory hierarchies, pipelining, parallelism. A systems programming or computer architecture course is sufficient background.
- The basics of how neural networks work. You should know what a transformer is at a high level – attention, feed-forward layers, embeddings, softmax. We will not be training models in this book; we will be running them as fast as physically possible.
You do not need prior experience with:
- Apple’s Metal API or Metal Shading Language
- GPU programming (CUDA, OpenCL, or otherwise)
- Apple Silicon internals
- The specific architectures of LLMs like Llama, Mistral, or Qwen
We will teach you all of that from the ground up.
What You Will Learn
This book is organized into nine parts, each building on the last. By the time you finish, you will have a deep understanding of every layer of the stack, from the silicon to the HTTP API.
Part I: Apple Silicon Fundamentals
Before you can write fast code for Apple Silicon, you need to understand what Apple Silicon actually is. Part I takes you on a tour of the hardware, starting with why Apple moved from Intel to ARM, then drilling into the SoC architecture, the GPU core design, the unified memory system, and the generational improvements from M1 through M4.
By the end of Part I, you will understand:
- Why Apple’s unified memory architecture is a game-changer for ML inference
- How Apple’s GPU cores differ from NVIDIA’s streaming multiprocessors
- What a SIMD group is and why it is the fundamental unit of GPU execution
- How the System Level Cache (SLC) acts as a shared last-level cache for all on-chip agents
- Why memory bandwidth – not compute – is the bottleneck for LLM decode
Part II: Metal Programming
Metal is Apple’s low-level GPU programming framework, and it is the only way to access the full power of the Apple GPU. Part II teaches you Metal from scratch: how to create compute pipelines, write shaders in the Metal Shading Language, dispatch threadgroups, manage buffers, and use SIMD group matrix operations. This is not a toy tutorial – by the end, you will understand the same programming model that akunu’s kernels use.
By the end of Part II, you will understand:
- How to set up a Metal compute pipeline and dispatch work to the GPU
- The Metal Shading Language (MSL) – a C++14-based language for writing GPU kernels
- How threadgroups, SIMD groups, and threads map to the hardware
- The Metal memory model: device memory, threadgroup memory, and how they interact
- SIMD group matrix operations (simdgroup_matrix) – Apple’s answer to NVIDIA’s Tensor Cores
- Performance optimization patterns: coalesced access, occupancy tuning, avoiding bank conflicts
Part III: Machine Learning on Metal
Part III bridges the gap between GPU programming and machine learning. We cover how tensors are represented in GPU memory, the different strategies for matrix multiplication, the attention mechanism, FlashAttention and how it maps to Metal, and quantization – the art of making models small enough to fit in memory without destroying quality.
Part IV: Akunu Architecture
Now we get to akunu itself. Part IV covers the design philosophy, the build system, the C API, the device abstraction layer, architecture descriptors (which let akunu support new model architectures without code changes), and chip configuration.
Part V: The Inference Pipeline
Part V walks through the complete inference pipeline: model loading, the dispatch table (akunu’s precompiled GPU command sequence), the prefill phase, the decode loop, and the decoding strategies: greedy/chain decode, sampled, speculative, and grammar-constrained.
Part VI: Metal Kernels Deep Dive
This is the heart of the book. Part VI takes you through every Metal kernel: GEMV, GEMM, FlashAttention, normalization, RoPE, embedding and activation, and sampling. For each, we explain the algorithm, the memory access pattern, the SIMD group coordination, and include interactive animations showing GPU execution.
Part VII: Weight Management
Models come in many file formats. Part VII covers the weight provider abstraction, the GGUF file format and akunu’s parser, SafeTensors and MLX format support, and the zoo of quantization formats (Q4_0, Q4_1, Q8_0, K-quants, MLX 3/4/6/8-bit, and more).
Part VIII: Supporting Systems
Inference engines need more than just GPU kernels. Part VIII covers the KV cache, scratch buffer architecture, the tokenizer (BPE and SentencePiece), the grammar engine (for structured output), Whisper (speech-to-text), and the HTTP server.
Part IX: Contributing to Akunu
The final part is a contributor guide: dev setup, testing, adding kernels, adding architectures, profiling, and architectural decision records.
Why This Book Exists
There are many resources for learning CUDA programming on NVIDIA GPUs. There are tutorials for PyTorch, guides for TensorRT, deep dives into NVIDIA’s Tensor Core architecture. The NVIDIA ecosystem is mature, well-documented, and widely understood.
The Apple Silicon ecosystem has… almost none of that.
If you want to understand how Apple’s GPU works at the microarchitectural level, you will find sparse documentation, a handful of WWDC sessions, and a lot of educated guesswork from the reverse-engineering community. If you want to write high-performance compute shaders for Metal, the official guides are thin on practical advice. If you want to understand how to map an LLM inference workload onto Apple Silicon efficiently, you are largely on your own.
This book exists to fill that gap. We have spent hundreds of hours profiling, benchmarking, reverse-engineering, and optimizing akunu’s Metal kernels on Apple Silicon. We have learned things about Apple’s GPU that are not documented anywhere. We want to share that knowledge so that the next person who wants to build something fast on Apple hardware does not have to start from zero.
| NVIDIA Ecosystem | Apple Silicon Ecosystem |
|---|---|
| CUDA Handbook | A few WWDC sessions |
| PTX ISA Guide | No public ISA reference |
| Tensor Core documentation | No GPU microarch docs |
| cuBLAS/cuDNN guides | Metal Best Practices (thin) |
| Hundreds of papers and blog posts | Scattered community reverse engineering |
| | This book fills the gap |
How to Read This Book
This book is designed to be read sequentially. Each chapter builds on concepts introduced in previous chapters, and later parts assume familiarity with the hardware and programming model covered in earlier parts.
That said, here are some suggested paths depending on your background:
“I just want to understand akunu’s codebase so I can contribute.” Read Part I (skim if you already know Apple Silicon), skim Part II (read Chapter 8 on MSL carefully), then jump to Part IV and read sequentially through Part IX.
“I’m a CUDA programmer and I want to learn Metal.” Start with Part I to understand how Apple’s hardware differs from NVIDIA’s. Then read Part II carefully – it is the Metal equivalent of a CUDA programming guide. Chapter 3 (GPU architecture) and Chapter 9 (threadgroups and dispatch) are especially relevant for mapping your CUDA mental model to Metal.
“I want to understand the ML/inference side.” Skim Part I and II for context, then read Part III carefully. Then jump to Part V (the inference pipeline) and Part VI (the kernels).
“I want to understand everything.” Read the whole book, front to back. That is what it is for.
Conventions Used in This Book
Throughout this book, we use the following conventions:
- Code listings are shown in monospaced font. Metal Shading Language code is annotated with comments explaining non-obvious constructs.
- ASCII diagrams are used extensively. We chose ASCII art over images because it renders correctly in every format (web, PDF, terminal), is easy to modify, and can be included in code review comments and commit messages.
- Performance numbers are given for specific hardware configurations. Unless otherwise noted, benchmarks were run on an M4 Pro (16 GPU cores, 273 GB/s) with macOS 15. Your numbers will differ on different hardware.
- "Apple GPU" refers to the GPU cores on Apple Silicon (M1, M2, M3, M4 families), not the older Intel integrated graphics on pre-2020 Macs.
- Register types in Metal are explained when first used. Metal uses `uint`, `float`, `half` (16-bit float), and others. We will explain the implications of each.
- Chip-specific details are called out in notes. When behavior differs between M1 and M4, we will tell you.
A Note on Apple’s Documentation (or Lack Thereof)
Apple is famously secretive about its hardware. The company does not publish detailed microarchitectural specifications for its GPUs, does not release ISA references for its shader cores, and does not provide the kind of performance tuning guides that NVIDIA publishes for CUDA.
This means that some of what we describe in this book – particularly regarding the internal structure of GPU cores, the exact sizes of register files, and the behavior of the instruction pipeline – is based on a combination of:
- Apple’s public documentation2 (Metal Best Practices Guide, WWDC sessions, Metal Feature Set tables)
- Reverse engineering (running carefully constructed microbenchmarks to measure latencies, throughputs, and cache sizes)
- Community knowledge3 (the excellent work of Dougall Johnson, Alyssa Rosenzweig, and others who have reverse-engineered Apple GPU internals)
- Empirical observation (running akunu’s kernels with different configurations and measuring what works best)
Where we are confident in our claims, we state them as facts. Where we are making educated inferences from indirect evidence, we say so. We encourage you to verify our claims through your own experiments – that is part of the fun.
Let’s Begin
Apple Silicon is, in our opinion, the most interesting computing platform to emerge in the last decade. It combines extraordinary hardware – a unified memory architecture, a surprisingly powerful GPU, and a custom interconnect fabric – with a programming model (Metal) that is powerful but underexplored. The gap between what this hardware can do and what most software actually does with it is enormous.
Akunu exists to close that gap for inference workloads. This book exists to show you how.
Turn the page. Let’s talk about the silicon.
“Any sufficiently advanced technology is indistinguishable from magic.” — Arthur C. Clarke
Our goal is to make the magic distinguishable.
1. Benchmarks run on M4 Pro (16 GPU cores, 273 GB/s) with llama.cpp b8610, 3 reps per config. Decode measured at tg128 (128-token generation). The 1.83x is the average across 19 GGUF models. Speedup is highest on small models (3.66x on Qwen3-0.6B Q3_K_S), where per-token dispatch overhead dominates, and converges to parity on larger models (0.98x on Qwen3-8B Q8_0). Prefill averages 0.91x. See BENCHMARKS.md for the full data and Chapter 55 for methodology.
2. Apple. "Metal Best Practices Guide." developer.apple.com. The primary official reference for Metal compute optimization. See https://developer.apple.com/library/archive/documentation/3DDrawing/Conceptual/MTLBestPracticesGuide/index.html.
3. Johnson, D. "Reverse-engineering Apple GPU cores." See https://dougallj.github.io/applegpu/. See also the Asahi Linux project's Apple GPU documentation (2022). Together these are the most detailed public analyses of Apple GPU internals.
The Apple Silicon Revolution
On June 22, 2020, at the Worldwide Developers Conference, Apple CEO Tim Cook announced that the Mac was transitioning from Intel x86 processors to Apple’s own custom ARM-based chips.1 This was not a surprise to industry watchers – rumors had been circulating for years – but the ambition and speed of the transition stunned everyone.2 Apple was not just switching CPU vendors. They were bringing the entire system-on-chip (SoC) approach that had made the iPhone and iPad dominant in mobile computing to the desktop and laptop. Within two years, every Mac in the lineup would run on Apple Silicon.
This chapter tells the story of how we got here, why it matters, and what it means for the kind of workload we care about most in this book: running large language models as fast as possible on local hardware.
Why Apple Left Intel
To understand why Apple Silicon exists, you need to understand why Apple was unhappy with Intel. The relationship between the two companies began in 2005, when Steve Jobs announced the Mac’s transition from PowerPC to Intel x86. At the time, Intel’s chips offered a compelling combination of performance and power efficiency, and the x86 ecosystem was the dominant force in personal computing.
For about a decade, this partnership worked well. Intel’s tick-tock cadence – alternating between process shrinks and microarchitecture improvements – delivered steady performance gains. But starting around 2015, things began to go wrong.
Intel’s Process Stagnation
Intel’s manufacturing process, once the envy of the semiconductor industry, hit a wall. The company’s 10nm node (which they originally planned to ship in 2016) was delayed repeatedly. The 14nm process that was supposed to be a one-generation stopgap ended up being stretched across four generations of products:
Intel's 14nm Purgatory
=======================
2014: Broadwell (14nm) -- on schedule
2015: Skylake (14nm) -- ok, next year we move to 10nm
2016: Kaby Lake (14nm+) -- 10nm delayed, here's a refinement
2017: Coffee Lake (14nm++) -- still no 10nm, more cores this time
2018: Coffee Lake Refresh (14nm++) -- yeah, 10nm is still not ready
2019: Ice Lake (10nm) -- finally! but only in low-power laptops
2020: Comet Lake (14nm)  -- desktops still on 14nm
2021: Rocket Lake (14nm) -- a 10nm design backported to 14nm. really.
Each year that Intel stayed on 14nm, the performance improvements got smaller. The company resorted to increasing power consumption and adding cores to show benchmark improvements, but the fundamental per-core performance was stagnating.
Meanwhile, Apple’s in-house chip team – born from the 2008 acquisition of PA Semi, a boutique chip design firm – was making ARM-based processors that improved by 20-40% in performance every single year. By 2020, Apple’s A14 chip (in the iPhone 12) was competitive with Intel’s laptop processors in single-threaded performance, despite running at a fraction of the power.
The Power Problem
For a laptop maker like Apple, power efficiency is not a nice-to-have; it is the single most important metric. Every watt of power consumed becomes a watt of heat that must be dissipated. More heat means bigger fans, thicker chassis, and shorter battery life. Intel’s chips were designed primarily for desktops and servers, then adapted for laptops. Apple wanted chips designed for laptops first.
Power Efficiency Comparison (circa 2020)
=========================================
Intel Core i7 (10th gen laptop):
+----------------------------+
| TDP: 28W (actual: 35-50W) |
| Performance: ████████ |
| Per-watt: ███ |
+----------------------------+
Apple A14 (iPad/iPhone):
+----------------------------+
| TDP: ~5-6W |
| Performance: ██████ |
| Per-watt: █████████████ |
+----------------------------+
Apple wanted to bring that per-watt advantage to the Mac.
The gap in power efficiency was not just about process technology (though TSMC’s nodes were ahead of Intel’s). It was about fundamental architectural choices. ARM’s instruction set, which we will discuss shortly, is inherently more power-efficient than x86. And Apple’s custom microarchitecture was designed from the ground up for efficiency in a way that Intel’s x86 legacy made difficult.
The Integration Advantage
Perhaps the most important reason Apple moved to its own silicon was integration. With Intel chips, a Mac was a collection of discrete components: an Intel CPU, a separate AMD or NVIDIA GPU (or Intel integrated graphics), a separate ISP (image signal processor), a separate Thunderbolt controller, a separate security chip (T1, then T2). Each of these communicated over various buses and protocols, with all the latency and power overhead that implies.
Intel-era Mac Architecture (simplified)
=========================================
+--------+ PCIe +------------+
| Intel |<------------>| AMD/NVIDIA |
| CPU | | Discrete |
| | | GPU |
+---+----+ +------------+
| |
| DDR4 | GDDR6
v v
+--------+ +-----------+
| System | | VRAM |
| Memory | | (separate)|
| (RAM) | | |
+--------+ +-----------+
Data must cross PCIe bus to go between CPU and GPU.
GPU has its own memory (VRAM), separate from system RAM.
Bandwidth between CPU and GPU is limited (~16 GB/s PCIe 3.0).
Apple’s SoC approach puts everything on a single chip, sharing a single pool of memory. We will explore this unified memory architecture in detail in Chapter 4, but the key insight is this: when the CPU and GPU share the same memory, you eliminate the need to copy data between them. For machine learning workloads, where models can be tens of gigabytes, this is transformative.
The ARM Instruction Set Architecture
Before we dive into Apple Silicon specifically, let’s talk about ARM – the instruction set architecture (ISA) that forms the foundation of every Apple Silicon chip.
ARM stands for “Advanced RISC Machines” (originally “Acorn RISC Machine,” after the British company that designed the first ARM processor in 1985). ARM does not manufacture chips; it designs instruction set architectures and licenses them to chip makers. Apple licenses the ARMv8-A architecture (ARMv8.5-A for M1, ARMv8.6-A for M2/M3, ARMv9.2-A for M4) and then designs its own custom microarchitecture that implements that ISA.
This distinction is important: the ISA defines what instructions the processor understands. The microarchitecture defines how those instructions are executed. Two chips can implement the same ISA with radically different performance characteristics. Apple’s custom ARM cores (codenamed Firestorm, Icestorm, Avalanche, Blizzard, Everest, Sawtooth, etc.) are not the same as Qualcomm’s Kryo cores or ARM’s own Cortex cores, even though they all execute ARM instructions.
RISC vs. CISC
ARM is a RISC (Reduced Instruction Set Computer) architecture.3 Intel’s x86 is a CISC (Complex Instruction Set Computer) architecture. This is one of the most important distinctions in processor design, and it has deep implications for power efficiency and performance.
CISC (x86) philosophy: Provide a large number of complex instructions. A single instruction might load data from memory, perform an arithmetic operation, and store the result back to memory. The idea is to minimize the number of instructions the compiler needs to generate, reducing code size.
RISC (ARM) philosophy: Provide a smaller number of simple instructions. Each instruction does one thing: load data, perform arithmetic, or store data, but not a combination. The idea is that simpler instructions can be executed faster, and the hardware can be simpler (and therefore more power-efficient).
CISC vs. RISC: Adding Two Numbers from Memory
===============================================
x86 (CISC):
+-----------------------------------------+
| ADD [mem_a], [mem_b] | <-- One instruction
| | but internally decoded
| Internally becomes: | into multiple micro-ops
| load temp1, [mem_a] |
| load temp2, [mem_b] |
| add temp1, temp1, temp2 |
| store [mem_a], temp1 |
+-----------------------------------------+
ARM (RISC):
+-----------------------------------------+
| LDR X0, [mem_a] | <-- Load first operand
| LDR X1, [mem_b] | <-- Load second operand
| ADD X0, X0, X1 | <-- Add them
| STR X0, [mem_a] | <-- Store result
+-----------------------------------------+
ARM: 4 instructions, each simple and predictable
x86: 1 instruction, but complex internal decoding required
In practice, modern x86 processors bridge the gap by decoding CISC instructions into internal RISC-like micro-operations (micro-ops). But this decoding step consumes die area and power. ARM processors skip this step entirely because their instructions are already simple. This is one of the main reasons ARM chips are more power-efficient.
Fixed-Width Instructions
One of ARM’s most important characteristics is that all instructions are the same width: 32 bits (4 bytes). (ARM also supports a 16-bit “Thumb” mode for code density, but Apple Silicon runs in AArch64 mode where all instructions are 32 bits.)
This might seem like a minor detail, but it has profound implications for the processor’s front end – the part of the chip that fetches and decodes instructions.
x86 Variable-Length Instructions
=================================
Memory layout of x86 code:
+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
| 1 byte | 3 bytes | 7 bytes | 2 bytes|
| instr | instr | instr | instr |
+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
The processor cannot tell where one instruction ends and
the next begins without decoding each instruction sequentially.
This is hard to parallelize.
ARM Fixed-Width Instructions
==============================
Memory layout of ARM code:
+----------+----------+----------+----------+----------+
| 4 bytes | 4 bytes | 4 bytes | 4 bytes | 4 bytes |
| instr | instr | instr | instr | instr |
+----------+----------+----------+----------+----------+
Every instruction starts at a 4-byte boundary.
The processor can easily fetch and decode multiple
instructions in parallel. No ambiguity about boundaries.
For Intel, the variable-length instruction format means the front-end decode logic is one of the most complex (and power-hungry) parts of the chip. The decoder must examine each byte to determine the instruction length before it can find the start of the next instruction. Intel has thrown enormous engineering resources at this problem (instruction caches, micro-op caches, complex decode pipelines), but it remains an inherent disadvantage.
The Register File
ARM’s AArch64 architecture provides 31 general-purpose 64-bit registers (X0-X30), compared to x86-64’s 16 general-purpose registers (RAX, RBX, RCX, RDX, RSI, RDI, RBP, RSP, R8-R15). Having more registers means less “register spilling” – the costly process of saving register values to the stack when the processor runs out of registers.
Register Files Compared
========================
ARM AArch64: x86-64:
+------+------+------+ +------+------+
| X0 | X1 | X2 | | RAX | RBX |
+------+------+------+ +------+------+
| X3 | X4 | X5 | | RCX | RDX |
+------+------+------+ +------+------+
| X6 | X7 | X8 | | RSI | RDI |
+------+------+------+ +------+------+
| X9 | X10 | X11 | | RBP | RSP |
+------+------+------+ +------+------+
| X12 | X13 | X14 | | R8 | R9 |
+------+------+------+ +------+------+
| X15 | X16 | X17 | | R10 | R11 |
+------+------+------+ +------+------+
| X18 | X19 | X20 | | R12 | R13 |
+------+------+------+ +------+------+
| X21 | X22 | X23 | | R14 | R15 |
+------+------+------+ +------+------+
| X24 | X25 | X26 | 16 registers
+------+------+------+
| X27 | X28 | X29 |
+------+------+------+
| X30 (LR) | SP |
+------+------+------+
31 registers + SP
Plus 32 SIMD/FP registers (V0-V31, 128-bit NEON) on ARM,
with optional SVE extensions; x86-64 has 16 XMM registers
(32 with AVX-512).
ARM also specifies a dedicated link register (X30/LR) for function return addresses, which simplifies function call conventions and branch prediction.
NEON and Advanced SIMD
ARM includes NEON, a SIMD (Single Instruction, Multiple Data) extension that operates on 128-bit vector registers. NEON can process multiple data elements simultaneously:
NEON SIMD Operation: Adding 4 floats at once
==============================================
V0: | float_a0 | float_a1 | float_a2 | float_a3 | (128 bits)
+----------+----------+----------+----------+
+ + + +
V1: | float_b0 | float_b1 | float_b2 | float_b3 | (128 bits)
+----------+----------+----------+----------+
= = = =
V2: | float_c0 | float_c1 | float_c2 | float_c3 | (128 bits)
+----------+----------+----------+----------+
FADD V2.4S, V0.4S, V1.4S -- One instruction, four additions
Apple’s implementation of NEON is extremely wide and high-throughput. The performance cores (P-cores) on M-series chips can execute up to four 128-bit NEON operations per cycle, giving them enormous SIMD throughput for CPU-side workloads.
For the purposes of this book, NEON is relevant primarily for CPU-side preprocessing (tokenizer operations, weight format conversion) rather than the main inference workload, which runs on the GPU. But it is good to know it is there.
AMX: Apple’s Secret Matrix Coprocessor
Here is something you will not find in the official ARM specification: Apple’s M-series chips include a custom matrix coprocessor called AMX (Apple Matrix Extensions). AMX is not part of the ARM ISA – it is a proprietary Apple extension that provides hardware-accelerated matrix multiplication on the CPU side.
AMX operates on its own set of registers and can perform operations like multiplying two 16x16 matrices of FP16 values in a single instruction. Apple uses AMX internally in the Accelerate framework and CoreML, but it is not officially documented or exposed as a public API.
AMX: Apple's Matrix Coprocessor (undocumented)
================================================
+---------------------+
| Apple CPU Core |
| |
| +------+ +------+ |
| | ALU | | NEON | |
| +------+ +------+ |
| |
| +---------------+ |
| | AMX | | <-- Custom matrix coprocessor
| | 16x16 FP16 | | Not part of ARM ISA
| | matrix mul | | Undocumented by Apple
| | in one cycle | | Used by Accelerate/CoreML
| +---------------+ |
| |
+---------------------+
AMX registers are separate from the ARM register file.
AMX instructions are encoded as system register writes.
We mention AMX here for completeness, but akunu does not use it. Akunu runs the entire inference workload on the GPU, where the much larger number of ALUs and the higher memory bandwidth make GPU compute far more efficient for the matrix operations that dominate LLM inference. AMX is more useful for smaller matrix operations that need low latency (like real-time audio processing).
The A-Series Lineage
Apple Silicon did not spring into existence fully formed. It is the culmination of over a decade of chip design work that began with the original iPhone.
The Mobile Era (2010-2019)
Apple started designing its own chips with the A4, which powered the original iPad and iPhone 4 in 2010. Each generation brought significant improvements:
The A-Series Evolution
=======================
A4 (2010) -- First Apple-designed SoC. ARM Cortex-A8 based.
| Single core, 45nm Samsung process.
v
A5 (2011) -- Dual-core ARM Cortex-A9. First Apple GPU (PowerVR).
|
v
A6 (2012) -- First CUSTOM Apple CPU core (Swift). No more ARM reference
| designs. This is where Apple diverged from everyone else.
v
A7 (2013) -- First 64-bit mobile processor. EVER. Caught the industry
| off guard. Desktop-class architecture in a phone.
v
A8 (2014) -- 20nm TSMC. Performance and efficiency improvements.
|
v
A9 (2015) -- Custom "Twister" core. Huge IPC gains.
|
v
A10 (2016) -- First big.LITTLE configuration (2 perf + 2 efficiency cores)
| "Fusion" architecture.
v
A11 (2017) -- First Apple-designed GPU. Dropped PowerVR.
| First Neural Engine (2-core, 600B ops/sec).
v
A12 (2018) -- 7nm TSMC. 8-core Neural Engine (5 trillion ops/sec).
| "Bionic" branding begins.
v
A13 (2019) -- "Bionic." Fastest mobile chip by a huge margin.
| 8-core Neural Engine (6 trillion ops/sec).
v
A14 (2020) -- 5nm TSMC. First 5nm chip in any consumer device.
| Performance competitive with Intel laptop chips.
v
M1 (2020) -- A14 core design, scaled up for Mac.
The revolution begins.
Several key milestones in this lineage are directly relevant to Apple Silicon and akunu:
A6 (2012): Custom CPU cores. Apple stopped using ARM’s reference CPU designs and started designing its own microarchitectures. This gave Apple the freedom to make architectural decisions optimized for their specific use cases, and it is why Apple’s ARM cores consistently outperform everyone else’s ARM cores.
A7 (2013): 64-bit ARM. Apple was the first company to ship a 64-bit ARM processor in a consumer device. Qualcomm and Samsung scrambled to catch up. The 64-bit address space would later be essential for the large memory configurations in M-series chips.
A11 (2017): Custom GPU. Apple designed its own GPU, replacing the PowerVR cores it had licensed for years. This custom GPU design is the direct ancestor of the GPU cores in every M-series chip, and understanding its architecture is critical for understanding akunu’s Metal kernels.
A11 (2017): First Neural Engine. Apple’s first dedicated ML accelerator appeared in the A11. While akunu does not use the Neural Engine (it targets the GPU for maximum flexibility and performance), the Neural Engine’s existence shows how seriously Apple takes on-device ML.
From Phone to Mac: The M1 Announcement
On November 10, 2020, Apple announced the M1 chip and the first three Macs to use it: the MacBook Air, MacBook Pro 13-inch, and Mac mini. The M1 was not just an A14 chip with a different name – it was the A14’s core design scaled up for the thermal and power envelope of a laptop:
A14 (iPhone) vs. M1 (Mac)
===========================
A14: M1:
+---------------------+ +-------------------------------+
| 2 Perf + 4 Eff cores| | 4 Perf + 4 Eff cores |
| 4 GPU cores | | 7 or 8 GPU cores |
| 16-core Neural Eng | | 16-core Neural Engine |
| 4GB/6GB RAM | | 8GB/16GB Unified Memory |
| ~4W power budget | | ~15-20W power budget |
| LPDDR4X | | LPDDR4X, 128-bit bus |
| 11.8B transistors | | 16B transistors |
| 88mm^2 die size | | 120mm^2 die size |
+---------------------+ +-------------------------------+
Same CPU core design (Firestorm + Icestorm)
Same GPU core design
But more of everything, with more memory and bandwidth
The M1 was a revelation. It offered:
- Better CPU performance than Intel’s Core i9 in most workloads
- Better GPU performance than Intel’s Iris integrated graphics (and competitive with low-end discrete GPUs)
- Up to 18 hours of battery life in the MacBook Air (Apple's video-playback claim)
- Fanless operation in the MacBook Air (the chip ran cool enough to not need a fan)
- Unified Memory Architecture with zero-copy sharing between CPU and GPU
The tech industry was stunned. Not because Apple had made a fast chip – the A-series trajectory made that predictable – but because the first-generation Mac chip was this good. It was not a compromise “good enough for the Mac” chip; it was genuinely better than Intel’s best laptop offerings in almost every measurable way.
What Makes Apple Silicon Different
Now that we understand the history, let’s talk about what makes Apple Silicon architecturally unique. There are several key differentiators that matter for our purposes:
1. System-on-Chip (SoC) Integration
Unlike traditional PCs where the CPU, GPU, memory controller, I/O controllers, and other components are separate chips on a motherboard, Apple Silicon puts everything on a single die (or, for Ultra variants, two dies connected by a high-speed interconnect).
Traditional PC Architecture Apple Silicon SoC
========================== ==================
+------+ +------+ +------+ +---------------------------+
| CPU | | GPU | | WiFi | | CPU GPU Neural Eng |
+--+---+ +--+---+ +--+---+ | |
| | | | Media ISP Secure Enc |
===+====PCIe==+====USB===+==== | |
| | Memory Thunderbolt |
+--+---------------------------+ | Controller Controller |
| Motherboard | | |
| +--------+ +-----------+ | | I/O SLC Fabric |
| | Memory | | Storage | | +---------------------------+
| | DIMMs | | Controller| | |
| +--------+ +-----------+ | +---------------------------+
+------------------------------+ | Unified Memory |
| (LPDDR on-package) |
Multiple chips, multiple buses, +---------------------------+
multiple memory pools
One chip, one memory pool,
one interconnect fabric
This integration has several benefits:
- Lower latency: Components communicate over an on-die fabric instead of PCIe or USB. On-die communication can be an order of magnitude lower latency than crossing a PCIe bus.
- Lower power: Driving signals across a PCIe bus or memory DIMMs requires much more power than on-die communication. The shorter the wire, the less power it takes.
- Higher bandwidth: Apple can build wide, custom interconnects between components because they are all on the same die. The M4 Max, for example, has a 512-bit memory bus – wider than what you would typically see on a discrete GPU.
2. Unified Memory Architecture (UMA)
This is the single most important differentiator for ML workloads, and we will devote all of Chapter 4 to it. The short version: on Apple Silicon, the CPU and GPU share the same physical memory. There is no “system RAM” and “VRAM” – there is just memory, and both the CPU and GPU can access it.
For LLM inference, this means:
- No copying model weights between CPU and GPU. When akunu loads a model from disk, the CPU reads the file into a Metal buffer, and the GPU can immediately access those weights without any data transfer.
- No VRAM limitation. On a discrete GPU, your model must fit in VRAM (12GB, 24GB, etc.). On Apple Silicon, your model can use the entire unified memory pool – up to 128GB on the M4 Max.
- CPU post-processing is free. After the GPU generates logits, the CPU can read the output buffer directly without waiting for a DMA transfer.
Why UMA Matters for LLM Inference
===================================
Discrete GPU (NVIDIA):
1. Load model from disk -> CPU memory (RAM) [slow: disk I/O]
2. Copy weights from RAM -> GPU VRAM (PCIe) [slow: PCIe ~32 GB/s]
3. GPU computes inference using VRAM [fast: HBM ~900 GB/s]
4. Copy results from VRAM -> RAM (PCIe) [slow: PCIe ~32 GB/s]
5. CPU reads results from RAM [fast: DDR5 ~50 GB/s]
Apple Silicon (UMA):
1. Load model from disk -> unified memory [slow: disk I/O]
2. GPU computes inference using same memory [fast: ~273-546 GB/s]
3. CPU reads results from same memory [fast: same bus]
Steps 2 and 4 from the discrete GPU path simply do not exist on Apple Silicon.
3. Efficiency Cores and Performance Cores (big.LITTLE)
Apple Silicon uses a heterogeneous CPU design with two types of cores:
- Performance cores (P-cores): Wide, deep-pipeline, out-of-order cores designed for maximum single-threaded performance. These are among the fastest CPU cores in the world. They consume more power.
- Efficiency cores (E-cores): Narrower, simpler cores designed for maximum performance per watt. They handle background tasks and light workloads, allowing the P-cores to stay idle (and powered down) when they are not needed.
CPU Cluster Architecture
=========================
+-------------------------------------------+
| Performance Cluster |
| +---------+ +---------+ +---------+ +--+ |
| | P-core | | P-core | | P-core | |..| |
| | Wide | | Wide | | Wide | | | |
| | OoO | | OoO | | OoO | | | |
| | 192KB | | 192KB | | 192KB | | | |
| | L1 | | L1 | | L1 | | | |
| +---------+ +---------+ +---------+ +--+ |
| Shared L2 Cache (12-16MB) |
+-------------------------------------------+
+-------------------------------------------+
| Efficiency Cluster |
| +-------+ +-------+ +-------+ +-------+ |
| |E-core | |E-core | |E-core | |E-core | |
| |Narrow | |Narrow | |Narrow | |Narrow | |
| |In-ord | |In-ord | |In-ord | |In-ord | |
| | 128KB | | 128KB | | 128KB | | 128KB | |
| | L1 | | L1 | | L1 | | L1 | |
| +-------+ +-------+ +-------+ +-------+ |
| Shared L2 Cache (4MB) |
+-------------------------------------------+
For akunu, the P-cores matter during model loading and weight rearrangement (CPU-side work), while the GPU handles the actual inference computation. The E-cores might handle I/O and tokenization during inference.
4. The Custom GPU
Apple’s GPU is a tile-based deferred renderer (TBDR) designed primarily for mobile graphics, but it turns out to be surprisingly capable for compute workloads as well. We will cover the GPU architecture in detail in Chapter 3, but the key points are:
- Apple’s GPU uses 32-thread SIMD groups (similar to NVIDIA’s 32-thread warps)
- It has on-chip tile memory (similar to NVIDIA’s shared memory)
- It supports SIMD group matrix operations for hardware-accelerated matrix multiplication
- It shares the same unified memory as the CPU, with no separate VRAM
The GPU is where akunu spends the vast majority of its time, and understanding its architecture is the foundation for understanding akunu’s performance.
5. The Neural Engine
Apple Silicon includes a dedicated Neural Engine – a specialized accelerator for neural network inference. The Neural Engine is optimized for a specific set of operations (convolutions, matrix multiplications, etc.) and can be more power-efficient than the GPU for those operations.
However, akunu does not use the Neural Engine. Here is why:
- Limited programmability: The Neural Engine is accessed through CoreML, which abstracts away the hardware details. You cannot write custom kernels for it.
- Limited precision support: The Neural Engine is optimized for FP16 and INT8 operations. Many of akunu’s kernels need mixed-precision arithmetic.
- Limited flexibility: The Neural Engine is designed for standard neural network layers. Custom operations (like akunu’s fused kernels) cannot run on it.
- GPU is fast enough: Apple’s GPU, with SIMD group matrix operations and high memory bandwidth, is more than fast enough for inference when properly optimized.
Why Akunu Uses the GPU, Not the Neural Engine
===============================================
Neural Engine:
+-------------------+
| + Power efficient |
| + Good for std ops |
| - No custom kernels|
| - CoreML only |
| - Fixed data types |
| - Limited control |
+-------------------+
GPU:
+-------------------+
| + Full control |
| + Custom kernels |
| + Any data type |
| + High bandwidth |
| + SIMD group ops |
| + Threadgroup mem |
| - More power |
+-------------------+
For maximum performance with custom operations, the GPU wins.
How Apple Silicon Changes the Game for Local Inference
With the background in place, let’s talk about why Apple Silicon is uniquely well-suited for running LLMs locally – and why akunu exists.
The Memory Wall
The fundamental bottleneck in LLM inference is not compute; it is memory bandwidth. During the decode phase (generating one token at a time), the model must read its entire weight matrix for every single token. For a 7-billion-parameter model in 4-bit quantization, that is roughly 3.5GB of data that must be read from memory for every single token generated.
The speed at which you can read data from memory – the memory bandwidth – determines the upper bound on your decode speed. This is the “memory wall.”
The Memory Wall: Why Bandwidth Matters
========================================
Model: Llama 3.1 8B (Q4_0 quantization)
Weight size: ~4.3 GB
Operation per token: read all weights + compute
If memory bandwidth is B (GB/s) and model size is S (GB):
Maximum theoretical tokens/sec = B / S
+-------------------+----------+-----------------------+
| Hardware | BW (GB/s)| Max decode (tok/s) |
+-------------------+----------+-----------------------+
| DDR4 laptop | 38 | 38/4.3 = ~9 tok/s |
| M1 (LPDDR4X) | 68 | 68/4.3 = ~16 tok/s |
| RTX 3090 (HBM) | 936 | 936/4.3 = ~218 tok/s |
| M4 Pro (LPDDR5X) | 273 | 273/4.3 = ~63 tok/s |
| M4 Max (LPDDR5X) | 546 | 546/4.3 = ~127 tok/s |
+-------------------+----------+-----------------------+
Note: These are THEORETICAL MAXIMUMS. Actual performance depends on
how efficiently the software uses the available bandwidth.
Akunu gets remarkably close to these theoretical limits.
Apple Silicon’s advantage here is threefold:
- High bandwidth LPDDR5X memory. The M4 Max provides 546 GB/s of memory bandwidth, which is in the range of older high-end discrete GPUs.
- No PCIe bottleneck. On an NVIDIA system, even if the GPU has 900+ GB/s of HBM bandwidth, the model weights must first cross the 32-64 GB/s PCIe bus to get from system RAM to VRAM. If the model does not fit in VRAM, you hit the PCIe wall. On Apple Silicon, there is no such bottleneck – the full memory bandwidth is available to the GPU.
- Large memory capacity. The M4 Max supports up to 128GB of unified memory. You can run a 70B model in 4-bit quantization (~35GB) without any model splitting or offloading.
The Total-Cost-of-Ownership Argument
There is also a practical economic argument for Apple Silicon inference. An NVIDIA A100 or H100 GPU costs thousands of dollars (or hundreds per month in cloud rental), requires a dedicated server, consumes hundreds of watts, and needs active cooling. A MacBook Pro with an M4 Max chip can run significant models while sitting on your lap, running on battery, making no noise.
For personal use, development, testing, and small-scale deployment, Apple Silicon offers a compelling total cost of ownership that discrete GPU setups cannot match.
What Akunu Exploits
Akunu is designed from the ground up to exploit Apple Silicon’s unique characteristics:
Akunu's Hardware-Aware Design
==============================
Apple Silicon Feature Akunu Exploitation
======================== ================================
Unified Memory (UMA) --> Zero-copy weight loading.
CPU rearranges weights directly
in GPU-accessible buffers.
High memory bandwidth --> Bandwidth-optimized kernels.
GEMV kernels saturate the
memory bus.
SIMD group matrix ops --> Hardware-accelerated matrix
multiplication in GEMM and
attention kernels.
System Level Cache (SLC) --> Kernel tiling tuned to SLC
size per chip generation.
GPU core count varies --> Dispatch parameters tuned per
chip (M1 vs M4 Pro vs M4 Max).
Threadgroup memory --> Cooperative kernels that share
data between SIMD groups via
fast on-chip memory.
A Roadmap of What’s Coming
Let’s close this chapter with a preview of the journey ahead. Here is the conceptual stack we will build up over the course of this book:
The Akunu Stack (bottom to top)
================================
Layer 7: Application
+----------------------------------------------------------+
| HTTP Server | CLI | Swift/Python Bindings |
+----------------------------------------------------------+
Layer 6: Decoding Strategies
+----------------------------------------------------------+
| Greedy | Sampled | Speculative | Grammar-Constrained |
+----------------------------------------------------------+
Layer 5: Inference Pipeline
+----------------------------------------------------------+
| Model Loading | Prefill | Decode Loop | KV Cache |
+----------------------------------------------------------+
Layer 4: Metal Kernels
+----------------------------------------------------------+
| GEMV | GEMM | FlashAttn | RMSNorm | RoPE | Sampling |
+----------------------------------------------------------+
Layer 3: Metal Compute Framework
+----------------------------------------------------------+
| Pipelines | Buffers | Command Encoders | Threadgroups |
+----------------------------------------------------------+
Layer 2: Apple GPU Architecture
+----------------------------------------------------------+
| GPU Cores | SIMD Groups | Tile Memory | Register File |
+----------------------------------------------------------+
Layer 1: Apple Silicon SoC
+----------------------------------------------------------+
| CPU | GPU | Neural Engine | Memory | SLC | Fabric |
+----------------------------------------------------------+
Layer 0: Silicon
+----------------------------------------------------------+
| TSMC 3nm/5nm Process | Transistors | Interconnect |
+----------------------------------------------------------+
We start at Layer 0/1 (this chapter and the next four)
and work our way up to Layer 7.
In the next chapter, we will zoom into Layer 1 and examine the System-on-Chip architecture in detail. We will look at every component on the Apple Silicon die and understand how they work together.
Summary
Let’s recap what we covered in this chapter:
- Apple transitioned from Intel to ARM because Intel’s process technology stagnated, ARM offers better power efficiency, and Apple’s SoC approach enables superior integration.
- The ARM ISA is a RISC architecture with fixed-width 32-bit instructions, 31 general-purpose registers, NEON SIMD extensions, and (on Apple chips) the undocumented AMX matrix coprocessor.
- Apple’s chip design lineage stretches from the A4 (2010) through the A14 (2020) to the M1 and beyond, with key milestones including custom CPU cores (A6), 64-bit ARM (A7), custom GPU (A11), and the Neural Engine (A11).
- Apple Silicon’s key differentiators include SoC integration, unified memory architecture, heterogeneous CPU cores, a custom TBDR GPU, and a dedicated Neural Engine.
- For LLM inference, the critical factors are memory bandwidth (the memory wall), the absence of a PCIe bottleneck (UMA), and the large unified memory pool.
- Akunu exploits UMA for zero-copy weight loading, high bandwidth for saturating GEMV kernels, SIMD group matrix ops for hardware-accelerated matmul, and per-chip tuning for optimal dispatch parameters.
Next up: let’s crack open the chip and see what is inside.
1. Apple. “Apple announces M1.” apple.com, November 2020. The original announcement detailing unified memory architecture and performance-per-watt claims. See https://www.apple.com/newsroom/2020/11/apple-unleashes-m1/.
2. Turley, J. “Apple Ignites the ARM Mac.” Microprocessor Report, 2020. Analysis of Apple’s transition from Intel to ARM and the architectural advantages of the M1. See https://www.linleygroup.com/.
3. Hennessy, J. and Patterson, D. “Computer Architecture: A Quantitative Approach.” 6th Edition. The foundational text on RISC vs CISC tradeoffs, pipeline design, and the memory wall. See https://www.elsevier.com/books/computer-architecture/hennessy/978-0-12-811905-1.
System-on-Chip Architecture
If you’ve spent your career thinking about computers as a collection of separate chips on a motherboard — a CPU here, a GPU there, RAM sticks in their slots — then Apple Silicon is going to fundamentally rewire how you think about hardware. Everything lives on one die (or two, in the Ultra variants). And that changes everything about how we write high-performance inference code.
What Is a System-on-Chip?
A System-on-Chip (SoC) integrates what traditionally were separate components — processor, graphics, memory controller, I/O — onto a single piece of silicon. This isn’t a new idea; your phone has been running on SoCs for over a decade. What Apple did with the M-series was bring this approach to laptop and desktop-class performance.
In a traditional PC, the CPU has its own RAM (DDR5, ~90 GB/s) and the GPU has its own VRAM (GDDR6X, ~1 TB/s). Data must cross the PCIe bus (~32 GB/s) to move between them. Two separate memory pools, two separate bandwidth domains.
On Apple Silicon, everything — CPU, GPU, Neural Engine — shares one pool of LPDDR5/LPDDR5X memory (120-819 GB/s depending on chip). No copy, no PCIe bottleneck. The GPU reads the same bytes the CPU wrote.
┌─────────────────────────────────────────────────────────────┐
│ TRADITIONAL PC ARCHITECTURE │
│ │
│ ┌──────────┐ PCIe x16 ┌────────────────┐ │
│ │ CPU │◄──────────────►│ Discrete GPU │ │
│ │ (Intel/ │ ~32 GB/s │ (NVIDIA/AMD) │ │
│ │ AMD) │ │ │ │
│ └────┬─────┘ └───────┬────────┘ │
│ │ DDR5 │ GDDR6X │
│ │ ~90 GB/s │ ~1 TB/s │
│ ┌────┴─────┐ ┌───────┴───────┐ │
│ │ System │ │ Video RAM │ │
│ │ RAM │ │ (VRAM) │ │
│ │ 32-128GB │ │ 8-24 GB │ │
│ └──────────┘ └───────────────┘ │
│ │
│ Two separate memory pools. Data must be COPIED between them│
└─────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│ APPLE SILICON ARCHITECTURE │
│ │
│ ┌──────────────────────────────────────────┐ │
│ │ SINGLE SoC DIE │ │
│ │ │ │
│ │ ┌─────┐ ┌─────┐ ┌──────┐ ┌──────────┐ │ │
│ │ │ CPU │ │ GPU │ │Neural│ │ Media │ │ │
│ │ │ │ │ │ │Engine│ │ Engine │ │ │
│ │ └──┬──┘ └──┬──┘ └──┬───┘ └────┬─────┘ │ │
│ │ └───────┴───────┴──────────┘ │ │
│ │ │ Fabric │ │
│ │ ┌────────┴────────┐ │ │
│ │ │ System Level │ │ │
│ │ │ Cache (SLC) │ │ │
│ │ └────────┬────────┘ │ │
│ └──────────────┼───────────────────────────┘ │
│ ┌──────┴──────┐ │
│ │ Unified │ │
│ │ Memory │ │
│ │ 16-192 GB │ │
│ └─────────────┘ │
│ │
│ ONE memory pool. CPU and GPU access it with ZERO copies. │
└─────────────────────────────────────────────────────────────┘
Compare the data flow in the two diagrams above: note the PCIe bottleneck in the traditional model versus the direct access in UMA.
Anatomy of the Die
Let’s map out every major component on an M4 Pro die:1
┌──────────────────────────────────────────────────────────────────┐
│ M4 Pro SoC Die │
│ │
│ ┌─────────────────────────┐ ┌──────────────────────────────┐ │
│ │ CPU Cluster │ │ GPU (20 cores) │ │
│ │ │ │ │ │
│ │ ┌────┐┌────┐┌────┐ │ │ ┌───┐┌───┐┌───┐┌───┐┌───┐ │ │
│ │ │ P0 ││ P1 ││ P2 │ │ │ │ 0 ││ 1 ││ 2 ││ 3 ││ 4 │ │ │
│ │ └────┘└────┘└────┘ │ │ └───┘└───┘└───┘└───┘└───┘ │ │
│ │ ┌────┐┌────┐┌────┐ │ │ ┌───┐┌───┐┌───┐┌───┐┌───┐ │ │
│ │ │ P3 ││ P4 ││ P5 │ │ │ │ 5 ││ 6 ││ 7 ││ 8 ││ 9 │ │ │
│ │ └────┘└────┘└────┘ │ │ └───┘└───┘└───┘└───┘└───┘ │ │
│ │ ┌────┐┌────┐┌────┐ │ │ ┌───┐┌───┐┌───┐┌───┐┌───┐ │ │
│ │ │ P6 ││ P7 ││ P8 │ │ │ │10 ││11 ││12 ││13 ││14 │ │ │
│ │ └────┘└────┘└────┘ │ │ └───┘└───┘└───┘└───┘└───┘ │ │
│ │ ┌────┐ │ │ ┌───┐┌───┐┌───┐┌───┐┌───┐ │ │
│ │ │ P9 │ P=Performance │ │ │15 ││16 ││17 ││18 ││19 │ │ │
│ │ └────┘ │ │ └───┘└───┘└───┘└───┘└───┘ │ │
│ │ ┌────┐┌────┐┌────┐ │ │ 20 GPU Cores │ │
│ │ │ E0 ││ E1 ││ E2 │ │ │ │ │
│ │ └────┘└────┘└────┘ │ │ │ │
│ │ ┌────┐ E=Efficiency │ │ │ │
│ │ │ E3 │ │ │ │ │
│ │ └────┘ │ │ │ │
│ └─────────────────────────┘ └──────────────────────────────┘ │
│ │
│ ┌──────────────┐ ┌──────────┐ ┌────────────┐ ┌──────────┐ │
│ │ Neural Engine│ │ Media │ │ Display │ │ Secure │ │
│ │ 16 Cores │ │ Engine │ │ Engine │ │ Enclave │ │
│ │ 38 TOPS │ │ ProRes │ │ │ │ │ │
│ └──────────────┘ └──────────┘ └────────────┘ └──────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ Fabric / Interconnect │ │
│ └──────────────────────────────────────────────────────────┘ │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ System Level Cache (SLC) — 36 MB │ │
│ └──────────────────────────────────────────────────────────┘ │
│ ┌──────────┐ ┌───────────┐ ┌──────────┐ ┌──────────────┐ │
│ │ Memory │ │Thunderbolt│ │ PCIe │ │ USB │ │
│ │Controller│ │Controller │ │Controller│ │ Controller │ │
│ └──────────┘ └───────────┘ └──────────┘ └──────────────┘ │
└──────────────────────────────────────────────────────────────────┘
CPU Clusters: Performance and Efficiency
Apple Silicon uses ARM’s big.LITTLE concept with two core types:
Performance Cores (P-cores):
- Wide, out-of-order execution with deep pipelines
- High clock speeds (up to ~4.5 GHz on M4)
- Large L1 caches (192 KB instruction, 128 KB data per core)
- Large shared L2 cache (16-32 MB per cluster)
- Used for compute-intensive tasks
Efficiency Cores (E-cores):
- Narrower, simpler pipeline
- Much lower power consumption
- Lower clock speeds (~2.8 GHz)
- Handle background tasks, I/O-bound work
For akunu, the CPU mainly handles model loading, tokenization, dispatch table building, grammar constraint checking, and control flow between GPU dispatches. The GPU does the neural network computation.
The GPU
This is where akunu spends most of its time. Each GPU core contains ALUs, a texture sampling unit, and tile memory. Core counts vary: M4 has 10, M4 Pro has 20, M4 Max has 40, M4 Ultra has 80. We’ll cover GPU architecture in extreme detail in Chapter 3.
The Neural Engine
Apple’s dedicated matrix multiplication accelerator: 16 cores on M4 Pro, 38 TOPS at INT8. However, akunu does NOT use the Neural Engine. It’s only accessible through CoreML, which doesn’t give the fine-grained control needed for optimized inference. Metal compute shaders give us full control over memory layout, kernel design, and dispatch — which is why akunu achieves 1.83x faster decode than llama.cpp.
System Level Cache (SLC)
The SLC is a large, shared last-level cache between all die components and DRAM:
CPU L2 GPU L2 NE Cache
\ | /
▼ ▼ ▼
┌─────────────────────────┐
│ System Level Cache │
│ M4: 16 MB │
│ M4 Pro: 36 MB │
│ M4 Max: 48 MB │
│ M4 Ultra: 96 MB │
└────────────┬────────────┘
▼
┌───────────────┐
│ LPDDR5 DRAM │
└───────────────┘
The SLC is critical for inference: weight tiles that fit in SLC get accessed much faster, and intermediate activations often fit entirely.2 Akunu’s ChipConfig::slc_size tunes behavior based on available SLC.
Memory Controller
| Chip | Memory Bus | Bandwidth | Max Memory |
|---|---|---|---|
| M4 | 128-bit | 120 GB/s | 32 GB |
| M4 Pro | 256-bit | 273 GB/s | 64 GB |
| M4 Max | 512-bit | 546 GB/s | 128 GB |
| M4 Ultra | 512-bit | 819 GB/s | 192 GB |
For decode, token generation time is approximately model_size_bytes / memory_bandwidth. For a 7B Q4_0 model (~3.8 GB): M4 Pro gives ~71 tok/s theoretical, M4 Max ~143 tok/s. Akunu gets close to these limits.
The Chip Hierarchy: Binning and Variants
Apple creates chip families from the same base design:
BASE DIE ──────► M4 (binned, fewer cores)
│
├──────────► M4 Pro (mid-range)
│
└──────────► M4 Max (full config)
│
│ UltraFusion (2 dies)
▼
M4 Ultra (2x Max)
The Ultra variant connects two Max dies via UltraFusion (~2.5 TB/s die-to-die bandwidth). Software sees it as one unified chip — no special programming needed.
What This Means for Akunu
- Zero-copy weight loading: GGUF files are memory-mapped; the GPU reads directly from mmap’d regions
- CPU-side operations on GPU buffers: Whisper’s cross-attention K/V rearrangement happens on CPU, directly on GPU-accessible memory
- Minimal CPU-GPU sync: Precompiled dispatch table minimizes CPU’s role in the hot path
- SLC-aware tuning: ChipConfig adjusts tile and batch sizes based on SLC
- Bandwidth as bottleneck: During decode, akunu is memory-bandwidth bound — quantization (Q4_0) has dramatic impact
In the next chapter, we’ll zoom into the GPU architecture specifically.
1. Apple. “Apple M4 Pro chip.” apple.com, 2024. Die specifications including GPU core count, memory bandwidth, and SLC size. See https://www.apple.com/newsroom/2024/10/apple-introduces-m4-pro-and-m4-max/.
2. Frumusanu, A. “Apple’s M4 Family: A Deep Dive.” AnandTech / Chips and Cheese, 2024. Independent analysis of Apple Silicon die layout, cache hierarchy, and fabric bandwidth. See https://chipsandcheese.com/.
The GPU: Cores, SIMD Groups, and Threadgroups
This is the most important chapter for understanding akunu’s Metal kernels. Every kernel in akunu — every GEMV, every FlashAttention variant, every normalization — is designed around the execution model we’re about to explore. If you’ve worked with CUDA, many concepts will feel familiar, but the terminology and some architectural details differ. If you haven’t, don’t worry — we’ll build up from first principles.
The Big Picture: Why GPUs?
A CPU is designed to execute a single thread of complex instructions as fast as possible. It has deep pipelines, branch prediction, out-of-order execution, and large caches. A single CPU core might run at 4+ GHz but can only do one or two multiply-accumulate operations per clock.
A GPU takes the opposite approach. It trades single-thread performance for massive parallelism. Each individual “thread” is simpler and slower, but there are thousands of them running simultaneously.
CPU (M4 Pro Performance Core):
┌──────────────────────────────────┐
│ Complex OoO pipeline │
│ Branch predictor │
│ 192 KB L1I + 128 KB L1D │
│ ~4.5 GHz │
│ 1-2 FMA per clock │
│ ≈ 9 GFLOPS FP32 per core │
└──────────────────────────────────┘
× 10 P-cores = ~90 GFLOPS (scalar; NEON SIMD raises this, but the gap to the GPU remains enormous)
GPU (M4 Pro, 20 cores):
┌──────────┐┌──────────┐┌──────────┐ ... ┌──────────┐
│ Core 0 ││ Core 1 ││ Core 2 │ │ Core 19 │
│ 128 ALUs││ 128 ALUs││ 128 ALUs│ │ 128 ALUs│
│ ~1.5 GHz ││ ~1.5 GHz ││ ~1.5 GHz │ │ ~1.5 GHz │
└──────────┘└──────────┘└──────────┘ └──────────┘
20 cores × 128 ALUs × 2 FMA × 1.5 GHz ≈ 7,680 GFLOPS FP32
FP16: ~15,360 GFLOPS (2x throughput)
For matrix multiplication — the core operation in neural networks — you need to do the same thing (multiply and add) billions of times with different data. GPUs are purpose-built for exactly this.
GPU Core Architecture
Each Apple GPU core is a self-contained processing unit. Let’s look inside one:
┌─────────────────────────────────────────────────────────┐
│ GPU CORE │
│ │
│ ┌────────────────────────────────────────────────────┐ │
│ │ Execution Units │ │
│ │ │ │
│ │ ┌────────┐ ┌────────┐ ┌────────┐ ┌────────┐ │ │
│ │ │ ALU×32 │ │ ALU×32 │ │ ALU×32 │ │ ALU×32 │ │ │
│ │ │ (SG 0) │ │ (SG 1) │ │ (SG 2) │ │ (SG 3) │ │ │
│ │ └────────┘ └────────┘ └────────┘ └────────┘ │ │
│ │ SIMD SIMD SIMD SIMD │ │
│ │ Group 0 Group 1 Group 2 Group 3 │ │
│ │ │ │
│ │ Up to 32 SIMD groups can be resident │ │
│ │ (1024 threads max per core) │ │
│ └────────────────────────────────────────────────────┘ │
│ │
│ ┌─────────────────┐ ┌──────────────────────────────┐ │
│ │ Register File │ │ Threadgroup Memory (32 KB) │ │
│ │ (per thread) │ │ (shared within threadgroup)│ │
│ └─────────────────┘ └──────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────┐ │
│ │ L1 Cache / Tile Memory │ │
│ └──────────────────────────────────────────────────┘ │
│ │ │
│ L2 Cache (shared across cores) │
└─────────────────────────────────────────────────────────┘
Key facts:
- Each core has 128 ALUs organized into groups of 32
- Each group of 32 ALUs forms a SIMD group (Apple’s term for what NVIDIA calls a “warp”)
- A core can have up to 1024 threads resident (32 SIMD groups × 32 threads)
- 32 KB of threadgroup memory (what NVIDIA calls “shared memory”)
The SIMD Group: 32 Threads in Lockstep
The SIMD group is the fundamental unit of execution on the GPU. It’s a group of 32 threads that execute the exact same instruction at the exact same time, but on different data. This is called SIMT (Single Instruction, Multiple Threads).
SIMD Group (32 threads):
┌──────────────────────────────────────────────────────┐
│ Thread 0 Thread 1 Thread 2 ... Thread 31 │
│ │
│ All execute: result = a[tid] * b[tid] + c[tid] │
│ │
│ t0: a[0]*b[0]+c[0] │
│ t1: a[1]*b[1]+c[1] │
│ t2: a[2]*b[2]+c[2] │
│ ... │
│ t31: a[31]*b[31]+c[31] │
│ │
│ ALL happen in ONE clock cycle (on 32 ALUs) │
└──────────────────────────────────────────────────────┘
This has profound implications:
- Branch divergence is costly: there is no penalty only when all threads take the same branch. If thread 0 takes the `if` path and thread 1 takes the `else` path, the SIMD group must execute BOTH paths (with inactive threads masked out). This wastes cycles.
- Memory coalescing: when all 32 threads access consecutive memory addresses, the hardware can combine them into a single wide memory transaction. Random access patterns are much slower.
- SIMD intrinsics: threads within a SIMD group can communicate directly via `simd_sum`, `simd_max`, `simd_shuffle`, etc. These are essentially free — no memory access needed.
SIMD Group Communication
This is one of the most powerful features for reduction operations (like computing a dot product). Threads in a SIMD group can share data without going through memory:
simd_sum example (sum 32 values across a SIMD group):
Thread: 0 1 2 3 ... 31
Value: 1.0 2.0 3.0 4.0 ... 32.0
simd_sum → Every thread gets: 528.0
No shared memory needed! Hardware does this in ~5 cycles (log2(32) shuffle steps).
simd_shuffle_down (shift values):
Thread: 0 1 2 3 4 5 ...
Before: a0 a1 a2 a3 a4 a5 ...
After (offset=2): a2 a3 a4 a5 a6 a7 ...
simd_shuffle_xor (butterfly pattern):
Thread: 0 1 2 3 4 5 6 7
XOR(1): a1 a0 a3 a2 a5 a4 a7 a6
XOR(2): a2 a3 a0 a1 a6 a7 a4 a5
XOR(4): a4 a5 a6 a7 a0 a1 a2 a3
Akunu uses these extensively. For example, in GEMV kernels, each thread in a SIMD group computes a partial dot product, then simd_sum combines them. In FlashAttention decode, simd_sum computes the Q·K dot product across threads that each hold different head dimensions.
The Threadgroup
A threadgroup is a collection of threads that:
- Execute on the same GPU core
- Can share threadgroup memory (fast on-chip SRAM)
- Can synchronize with threadgroup_barrier()
Threadgroup (128 threads = 4 SIMD groups):
┌─────────────────────────────────────────────────────────┐
│ │
│ ┌──────────────┐ ┌──────────────┐ │
│ │ SIMD Group 0│ │ SIMD Group 1│ │
│ │ threads 0-31│ │ threads 32-63│ │
│ └──────────────┘ └──────────────┘ │
│ ┌──────────────┐ ┌──────────────┐ │
│ │ SIMD Group 2│ │ SIMD Group 3│ │
│ │ threads 64-95│ │threads 96-127│ │
│ └──────────────┘ └──────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────┐ │
│ │ Threadgroup Memory (up to 32 KB) │ │
│ │ Accessible by ALL threads in this threadgroup │ │
│ │ ~1 cycle latency (vs ~100+ cycles for DRAM) │ │
│ └──────────────────────────────────────────────────┘ │
│ │
│ threadgroup_barrier(mem_flags::mem_threadgroup) │
│ ↑ Ensures all threads have finished their writes │
│ before any thread reads │
└─────────────────────────────────────────────────────────┘
Common threadgroup sizes in akunu:2
- 32 threads (1 SG): FlashAttention decode fast — no barriers needed!
- 128 threads (4 SGs): GEMV, GEMM, RMSNorm, RoPE
- 256 threads (8 SGs): Large GEMV variants
- 1024 threads (32 SGs): FlashAttention decode parallel, Gumbel-max sampling
Why Threadgroup Memory Matters
Threadgroup memory is like a programmer-managed L1 cache. It’s fast (~1 cycle) but small (32 KB). The classic use is tiled matrix multiply:
Without threadgroup memory:
Thread 0 reads A[0,0] from DRAM (100+ cycles)
Thread 1 reads A[0,0] from DRAM (100+ cycles) — SAME DATA!
Thread 2 reads A[0,0] from DRAM — again!
... massive bandwidth waste
With threadgroup memory:
1. Thread 0 loads A[0,0..31] into shared[0..31]
2. threadgroup_barrier() — wait for all loads
3. ALL threads read from shared[] (1 cycle each)
4. One DRAM fetch serves 32 threads
Akunu’s GEMM kernel (simd_gemm_f16.metal, covered in Chapter 35) allocates 6144 bytes of threadgroup memory: sa[4096] for weight tiles and sb[2048] for activation tiles. This lets 128 threads share weight data loaded once from DRAM.
The Execution Hierarchy
This is the most important mental model for GPU programming — every kernel you write maps onto this hierarchy.
Thread identification variables (used in every kernel):
kernel void my_kernel(
uint tid [[thread_position_in_grid]], // Global thread ID
uint tgid [[threadgroup_position_in_grid]], // Which threadgroup
uint tiisg [[thread_index_in_simdgroup]], // Lane within SG (0-31)
uint sgitg [[simdgroup_index_in_threadgroup]], // Which SG in TG (0-N)
uint tiitg [[thread_position_in_threadgroup]] // Thread within TG
) { ... }
CUDA vs Metal Terminology
If you’re coming from CUDA, here’s the translation table:
| CUDA Concept | Metal Concept | Size |
|---|---|---|
| Warp | SIMD Group | 32 threads |
| Block | Threadgroup | Variable (32-1024) |
| Grid | Grid | Variable |
| Shared Memory | Threadgroup Memory | Up to 32 KB |
| __syncthreads() | threadgroup_barrier() | — |
| __shfl_sync() | simd_shuffle() | — |
| warpSize | Always 32 | 32 |
| blockDim | threads_per_threadgroup | — |
| threadIdx | thread_position_in_threadgroup | — |
| blockIdx | threadgroup_position_in_grid | — |
| wmma (Tensor Cores) | simdgroup_matrix | 8×8 tiles |
Key differences:
- No warp divergence penalty tracking in Metal3 — the hardware handles it, but you should still avoid divergence
- SIMD group matrix operations use 8×8 tiles (vs NVIDIA’s 16×16 Tensor Core tiles)
- Threadgroup memory is explicitly sized per dispatch, not declared with __shared__
- No dynamic parallelism — kernels can’t launch other kernels
Occupancy and Register Pressure
Each GPU core has a fixed amount of resources. The more resources a threadgroup uses, the fewer threadgroups can be resident simultaneously:
GPU Core Resources:
┌────────────────────────────────┐
│ Register File: ~32K regs │
│ Threadgroup Mem: 32 KB │
│ Max Threads: 1024 │
│ Max SGs: 32 │
└────────────────────────────────┘
Example: Kernel uses 32 registers per thread, 4096 bytes TG mem
Threads per TG: 128
Registers per TG: 128 × 32 = 4096
TG memory per TG: 4096 bytes
Max concurrent TGs: min(
32K / 4096 = 8 (register limited),
32K / 4096 = 8 (TG mem limited),
1024 / 128 = 8 (thread limited)
) = 8 threadgroups, 1024 threads
Occupancy: 1024 / 1024 = 100%
High occupancy means more threads to hide memory latency. When one SIMD group is waiting for a memory fetch (100+ cycles), the core can switch to another SIMD group that has work ready. This is how GPUs tolerate high memory latency — they never run out of work.
Dispatch Models
Metal offers two dispatch models:
dispatchThreadgroups(gridSize, threadgroupSize):
- You specify how many threadgroups to launch and their size
- Total threads = gridSize × threadgroupSize
- May launch more threads than needed (you handle bounds checking)
dispatchThreads(totalThreads, threadgroupSize):
- You specify the total number of threads needed
- Metal handles partial threadgroups at the edges
- Cleaner for simple 1:1 thread-to-data mappings
Akunu’s DispatchCmd has a use_dispatch_threads flag that selects between these. Most kernels use dispatchThreadgroups for precise control.
The Tile-Based Deferred Renderer (TBDR)
Apple’s GPU was originally designed for mobile graphics, which uses a Tile-Based Deferred Rendering architecture. In graphics, the screen is divided into tiles and each tile is rendered entirely in fast tile memory before writing to DRAM.
For compute shaders (which akunu uses exclusively), TBDR doesn’t apply directly. Compute shaders bypass the tiling hardware and operate like a traditional GPU compute model. However, the tile memory architecture means:
- The GPU has fast on-chip storage (threadgroup memory)
- Memory access patterns that fit in cache lines are rewarded
- The GPU is efficient at processing data in blocks/tiles
This is why akunu’s kernels are organized around tiles: 32-element K-blocks in GEMV, 32×64 output tiles in GEMM, 32-position KV tiles in FlashAttention.
Summary
The Apple GPU execution model is:
- Grid dispatches many threadgroups
- Each threadgroup runs on one GPU core
- Each threadgroup contains multiple SIMD groups of 32 threads
- Threads within a SIMD group execute in lockstep and communicate via SIMD intrinsics
- Threads within a threadgroup share threadgroup memory and synchronize via barriers
- High occupancy hides memory latency
Akunu’s kernels are designed around this hierarchy: SIMD-level reductions for dot products, threadgroup-level cooperation for tiled matrix multiply, and grid-level parallelism across output rows and attention heads.
Next, we’ll look at the unified memory architecture that makes Apple Silicon uniquely suited for inference workloads.
1. Grinberg, D. “Reverse-engineering Apple GPU cores.” Asahi Linux project, 2022. The most detailed public analysis of Apple GPU core internals, SIMD group behavior, and threadgroup memory layout. See https://dougallj.github.io/applegpu/. ↩
2. Apple. “Metal Best Practices Guide.” developer.apple.com. Official guidance on threadgroup sizing, occupancy, and memory access patterns for Metal compute shaders. See https://developer.apple.com/library/archive/documentation/3DDrawing/Conceptual/MTLBestPracticesGuide/index.html. ↩
3. Apple. “Metal Feature Set Tables.” developer.apple.com. Defines GPU family capabilities including simdgroup_matrix support (Family 7+), max threadgroup size, and threadgroup memory limits. See https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf. ↩
Unified Memory Architecture
If there’s one hardware feature that makes Apple Silicon uniquely suited for running large language models locally, it’s the Unified Memory Architecture (UMA). It’s the reason a MacBook with 64 GB of RAM can run a 70B model that would require a $2,000+ NVIDIA GPU on a PC. Let’s understand exactly how it works and why it matters for inference.
The Traditional Memory Problem
On a discrete GPU system, there are two completely separate memory pools:
┌─────────────────────────────────────────────────────────┐
│ TRADITIONAL DISCRETE GPU SYSTEM │
│ │
│ CPU Side GPU Side │
│ ┌──────────┐ ┌──────────────┐ │
│ │ System │ PCIe 4.0 x16 │ VRAM │ │
│ │ RAM │◄──────────────────►│ (GDDR6X) │ │
│ │ 64 GB │ ~32 GB/s │ 24 GB │ │
│ │ DDR5 │ bidirectional │ │ │
│ │ │ │ Bandwidth: │ │
│ │ BW: ~90 │ │ ~1 TB/s │ │
│ │ GB/s │ │ │ │
│ └──────────┘ └──────────────┘ │
│ │
│ To run inference: │
│ 1. Load model into system RAM (disk → RAM) │
│ 2. Copy weights to VRAM (RAM → PCIe → VRAM) │
│ 3. GPU computes on VRAM │
│ 4. Copy results back (VRAM → PCIe → RAM) │
│ │
│ PROBLEM: Model must fit in VRAM (24 GB max typical) │
│ PROBLEM: PCIe transfer is slow (~32 GB/s) │
│ PROBLEM: Double the memory usage (model in both pools) │
└─────────────────────────────────────────────────────────┘
A 70B model in Q4_0 quantization is about 40 GB. It literally doesn’t fit in a 24 GB RTX 4090’s VRAM. You need model parallelism across multiple GPUs, or offloading to CPU (which is painfully slow over PCIe).
Apple’s Unified Memory
On Apple Silicon, there’s one memory pool, shared by everything:
┌─────────────────────────────────────────────────────────┐
│ APPLE SILICON UNIFIED MEMORY │
│ │
│ ┌────────────────────────────────────────────────────┐ │
│ │ SoC Die │ │
│ │ │ │
│ │ CPU ──┐ │ │
│ │ ├──► Fabric ──► SLC ──► Memory Controller │ │
│ │ GPU ──┤ │ │
│ │ ├──► │ │
│ │ NE ──┘ │ │
│ └────────────────────────────────────────────────────┘ │
│ │ │
│ ┌──────┴──────┐ │
│ │ LPDDR5 │ │
│ │ Unified │ │
│ │ Memory │ │
│ │ Pool │ │
│ │ │ │
│ │ 16-192 GB │ │
│ └─────────────┘ │
│ │
│ To run inference: │
│ 1. Load model into memory (disk → unified memory) │
│ 2. GPU computes on it directly. Done. │
│ │
│ NO COPY. NO VRAM LIMIT. NO PCIe BOTTLENECK. │
└─────────────────────────────────────────────────────────┘
The implications:
- No VRAM limit: A Mac Studio with 192 GB can run models that would need 8× NVIDIA A100s
- No copy overhead: Weights loaded once are accessible by both CPU and GPU
- Zero-copy operations: CPU can rearrange GPU data directly (akunu uses this for Whisper)
The Memory Hierarchy
Even though memory is “unified,” there’s still a hierarchy of caches for performance. A memory request falls through the layers in turn: a hit in one layer returns the data immediately, while a miss sends the request deeper toward DRAM.
For inference, the critical numbers are:
- SLC size: determines how much of the model weights are “hot” in cache
- DRAM bandwidth: determines maximum decode speed for memory-bound operations
Why Memory Bandwidth Is Everything for Decode
During token generation (decode), the model processes one token at a time. For each token, it needs to:
- Read the embedding for the input token: tiny
- For each layer (e.g., 32 layers in a 7B model):
  - Read Q/K/V projection weights: (dim + 2 × kv_dim) × dim × bytes_per_weight (KV is smaller with GQA)
  - Read KV cache: 2 × seq_len × kv_dim × 2 bytes (FP16)
  - Read output projection weights: dim × dim × bytes_per_weight
  - Read FFN gate+up weights: 2 × dim × ffn_dim × bytes_per_weight
  - Read FFN down weights: ffn_dim × dim × bytes_per_weight
For a 7B Q4_0 model, that’s roughly 3.8 GB of weights read per token. The computation (multiply-accumulate) is fast. The bottleneck is reading the weights from memory.
Token generation speed = memory bandwidth / model size (bytes).1
LLaMA 7B Q4_0 (~3.8 GB) theoretical decode speed:
| Chip | Bandwidth | Theoretical tok/s |
|---|---|---|
| M4 | 120 GB/s | ~31 |
| M4 Pro | 273 GB/s | ~71 |
| M4 Max | 546 GB/s | ~143 |
| M4 Ultra | 819 GB/s | ~215 |
| RTX 4090 (GDDR6X) | 1,008 GB/s | ~265 (but 24 GB VRAM limit) |
The RTX 4090 has higher bandwidth, but only 24 GB of VRAM. For models that fit, it’s faster. For large models (70B+), it can’t run them at all without multi-GPU setups.
The Roofline Model
The roofline model helps visualize whether an operation is compute-bound or memory-bound:
Performance
(FLOPS)
│
│ ┌─────────────────────── Compute Ceiling
│ /│ (peak TFLOPS)
│ / │
│ / │
│ / │
│ / │
│ / │
│ / │
│ / │
│ / Memory │ Compute
│ / Bound │ Bound
│ / │
│/ │
└────────────┼──────────────────────
Ridge Point
Operational Intensity
(FLOPS / byte)
GEMV (decode, M=1, see Chapter 34): Each weight is read once and used for 1 multiply-add. Operational intensity = 2 FLOPS / (0.5 bytes for Q4_0) = 4 FLOPS/byte. This is well below the ridge point → memory-bound.
GEMM (prefill, M=seq_len): Each weight is read once but used for M multiply-adds. Operational intensity = 2M FLOPS / (0.5 bytes). For M=512: 2048 FLOPS/byte → compute-bound.
This is why prefill is fast (compute-bound, GPU cores are busy) and decode is slower (memory-bound, waiting for data).
How Akunu Exploits UMA
1. Memory-Mapped Weight Loading
GGUF files are opened with mmap(). The OS maps the file directly into the process’s virtual address space. When a page is first accessed, the kernel loads it from disk. The GPU can then access these pages directly — no explicit copy needed.
Traditional: Disk → fread() → RAM buffer → cudaMemcpy() → VRAM
Akunu UMA: Disk → mmap() → Unified Memory ← GPU reads directly
2. CPU-Side Buffer Operations
For Whisper’s cross-attention, akunu precomputes K/V projections on the GPU, then rearranges the memory layout on the CPU:
// CPU directly reads/writes GPU buffer contents (UMA makes this efficient)
for (int h = 0; h < n_heads; h++)
for (int p = 0; p < enc_seq; p++)
memcpy(&dst[(h*enc_seq + p)*head_dim],
&src[p*dec_dim + h*head_dim],
head_dim * sizeof(__fp16));
On a discrete GPU, this would require: GPU→VRAM→PCIe→RAM→CPU rearrange→RAM→PCIe→VRAM→GPU. On Apple Silicon: CPU reads and writes directly to the same memory the GPU uses. Hundreds of microseconds vs potentially milliseconds.
3. Minimal Synchronization
Because CPU and GPU share memory, akunu can submit GPU work and immediately prepare the next operation. There’s no need to wait for a PCIe transfer to complete — just a lightweight event/fence for GPU command completion.
4. KV Cache Persistence
The KV cache lives in unified memory and persists across requests. In akunu’s server mode, if a new request shares a prefix with the previous one, the cached KV entries are already in memory — no need to recompute or retransfer.
Memory Bandwidth in Practice
Real bandwidth utilization is always less than theoretical peak. Several factors reduce effective bandwidth:
Theoretical Peak: 273 GB/s (M4 Pro)
│
Bank conflicts: │ -5%
Cache misses: │ -10%
Non-coalesced access: │ -15%
TLB misses: │ -5%
▼
Practical Peak: ~175-195 GB/s (65-70% utilization)
Akunu achieves high bandwidth utilization through:
- Vectorized loads: half4 reads (8 bytes at once) instead of individual half reads
- Coalesced access: Adjacent threads read adjacent memory addresses
- Block-strided K-loops: Process the K dimension in blocks of 512 to maximize cache line reuse
- Weight layout optimization: Quantized blocks are laid out for sequential access
Comparison: UMA vs Discrete vs Cloud
| Feature | Apple UMA | Discrete GPU | Cloud (A100) |
|---|---|---|---|
| Max memory | 192 GB | 24 GB (consumer) | 80 GB |
| Bandwidth | 819 GB/s (Ultra) | 1 TB/s (4090) | 2 TB/s |
| Transfer overhead | None | PCIe ~32 GB/s | NVLink ~900 GB/s |
| 70B Q4 model | Fits in 96 GB Mac | Needs multi-GPU | Fits one A100 |
| Cost | $4K Mac Studio | $2K + PC | ~$3/hr rental |
| Power | ~60W | ~450W | ~400W |
| Programmability | Metal | CUDA | CUDA |
Apple Silicon’s sweet spot is models that don’t quite fit in consumer GPU VRAM (13B-70B). The ability to run these models locally, at reasonable speed, with no cloud dependency, is akunu’s core value proposition.
Summary
Unified Memory Architecture means:
- One memory pool shared by CPU and GPU
- No data copying between host and device memory
- Total system RAM available for model weights
- Memory bandwidth is the primary bottleneck for decode
- Quantization directly translates to speed (less data to read)
The key equation for decode performance: tokens/sec ≈ memory_bandwidth / model_size_bytes
This is why akunu focuses heavily on quantization support (Q4_0 is 4× smaller than FP16, so 4× faster decode) and why its dispatch table eliminates every unnecessary overhead — when you’re memory-bound, every wasted cycle is a wasted byte of bandwidth.
1. Pope, R., et al. “Efficiently Scaling Transformer Inference.” MLSys 2023. This paper provides a thorough analysis of the memory-bound vs compute-bound regimes of transformer inference and derives the bandwidth-limited decode throughput formula. See https://arxiv.org/abs/2211.05102. ↩
The Apple GPU Family: M1 through M4
Now that we understand the SoC architecture, unified memory, and GPU execution model, let’s survey the actual hardware across Apple Silicon generations. Each generation brought meaningful improvements for ML inference, and akunu’s ChipConfig system tunes its behavior for each.
Generation Overview1
| Gen | Year | Process | GPU Family | Key ML Feature |
|---|---|---|---|---|
| M1 | 2020 | 5nm | Apple 7 | SIMD group matrix ops, UMA |
| M2 | 2022 | 5nm (2nd gen) | Apple 8 | More bandwidth, cores |
| M3 | 2023 | 3nm | Apple 9 | Dynamic caching, ray tracing HW |
| M4 | 2024 | 3nm (2nd gen) | Apple 9+ | Enhanced ML, higher clocks |
M1 Family (2020-2021)
The M1 was the first Apple Silicon chip for Mac. It proved the concept works.
M1 Family Specifications:
┌──────────┬──────────┬──────────┬──────────┬──────────┐
│ │ M1 │ M1 Pro │ M1 Max │ M1 Ultra │
├──────────┼──────────┼──────────┼──────────┼──────────┤
│ GPU Cores│ 7-8 │ 14-16 │ 24-32 │ 48-64 │
│ CPU Cores│ 8 │ 8-10 │ 10 │ 20 │
│ Max RAM │ 16 GB │ 32 GB │ 64 GB │ 128 GB │
│ Mem BW │ 68 GB/s │ 200 GB/s │ 400 GB/s │ 800 GB/s │
│ SLC │ 16 MB │ 24 MB │ 48 MB │ 96 MB │
│ GPU Fam │ Apple 7 │ Apple 7 │ Apple 7 │ Apple 7 │
│ FP32 TF │ 2.6 │ 5.2 │ 10.4 │ 20.8 │
│ FP16 TF │ 5.2 │ 10.4 │ 20.8 │ 41.6 │
│ NE TOPS │ 11 │ 11 │ 11 │ 22 │
└──────────┴──────────┴──────────┴──────────┴──────────┘
For inference:
- M1 base: Usable for small models (1-3B Q4). 68 GB/s bandwidth limits decode speed
- M1 Max 64GB: First chip that could comfortably run 13B Q4 models
- M1 Ultra 128GB: Could run 70B Q4 models, albeit slowly
Key capability: Apple GPU Family 7 introduced simdgroup_matrix operations (8×8 matrix tiles), which akunu uses for GEMM kernels during prefill.
M2 Family (2022-2023)
An evolutionary improvement: more cores, more bandwidth, same architecture.
M2 Family Specifications:
┌──────────┬──────────┬──────────┬──────────┬──────────┐
│ │ M2 │ M2 Pro │ M2 Max │ M2 Ultra │
├──────────┼──────────┼──────────┼──────────┼──────────┤
│ GPU Cores│ 8-10 │ 16-19 │ 30-38 │ 60-76 │
│ CPU Cores│ 8 │ 10-12 │ 12 │ 24 │
│ Max RAM │ 24 GB │ 32 GB │ 96 GB │ 192 GB │
│ Mem BW │ 100 GB/s │ 200 GB/s │ 400 GB/s │ 800 GB/s │
│ SLC │ 16 MB │ 24 MB │ 48 MB │ 96 MB │
│ GPU Fam │ Apple 8 │ Apple 8 │ Apple 8 │ Apple 8 │
│ FP32 TF │ 3.6 │ 6.8 │ 13.6 │ 27.2 │
│ FP16 TF │ 7.2 │ 13.6 │ 27.2 │ 54.4 │
│ NE TOPS │ 15.8 │ 15.8 │ 15.8 │ 31.6 │
└──────────┴──────────┴──────────┴──────────┴──────────┘
For inference:
- M2 Max 96GB: Sweet spot for 70B Q4 models
- M2 Ultra 192GB: Could run quantized 100B+ models
- ~30% faster than M1 at equivalent tier
M3 Family (2023-2024)
The jump to 3nm brought significant GPU architectural changes.
M3 Family Specifications:
┌──────────┬──────────┬──────────┬──────────┬──────────┐
│ │ M3 │ M3 Pro │ M3 Max │ M3 Ultra │
├──────────┼──────────┼──────────┼──────────┼──────────┤
│ GPU Cores│ 8-10 │ 14-18 │ 30-40 │ 60-80 │
│ CPU Cores│ 8 │ 11-12 │ 14-16 │ 28-32 │
│ Max RAM │ 24 GB │ 36 GB │ 128 GB │ 192 GB │
│ Mem BW │ 100 GB/s │ 150 GB/s │ 400 GB/s │ 800 GB/s │
│ SLC │ 16 MB │ 24 MB │ 48 MB │ 96 MB │
│ GPU Fam │ Apple 9 │ Apple 9 │ Apple 9 │ Apple 9 │
│ FP32 TF │ 4.1 │ 7.4 │ 16.4 │ 32.8 │
│ FP16 TF │ 8.2 │ 14.8 │ 32.8 │ 65.6 │
│ NE TOPS │ 18 │ 18 │ 18 │ 36 │
└──────────┴──────────┴──────────┴──────────┴──────────┘
M3’s GPU innovations:
- Dynamic Caching: GPU dynamically allocates register and threadgroup memory per kernel, instead of statically at dispatch time. This improves occupancy for kernels that don’t use much threadgroup memory.
- Ray Tracing Hardware: Not relevant for inference, but shows architectural maturation
- Mesh Shading: Graphics feature, not relevant here
- Apple GPU Family 9: same SIMD group matrix ops as Family 7, but with improved scheduling
M4 Family (2024-2025)
The latest generation, optimized for AI workloads.
M4 Family Specifications:
┌──────────┬──────────┬──────────┬──────────┬──────────┐
│ │ M4 │ M4 Pro │ M4 Max │ M4 Ultra │
├──────────┼──────────┼──────────┼──────────┼──────────┤
│ GPU Cores│ 10 │ 20 │ 40 │ 80 │
│ CPU Cores│ 4P+6E │ 10P+4E │ 12P+4E │ 24P+8E │
│ Max RAM │ 32 GB │ 48 GB │ 128 GB │ 192 GB │
│ Mem BW │ 120 GB/s │ 273 GB/s │ 546 GB/s │ 819 GB/s │
│ SLC │ 16 MB │ 36 MB │ 48 MB │ 96 MB │
│ GPU Fam │ Apple 9 │ Apple 9 │ Apple 9 │ Apple 9 │
│ FP32 TF │ 4.6 │ 9.2 │ 18.4 │ 36.8 │
│ FP16 TF │ 9.2 │ 18.4 │ 36.8 │ 73.6 │
│ NE TOPS │ 38 │ 38 │ 38 │ 76 │
└──────────┴──────────┴──────────┴──────────┴──────────┘
M4 highlights for inference:
- Significantly higher memory bandwidth at each tier (M4 Pro: 273 vs M3 Pro: 150 GB/s)
- Larger SLC (M4 Pro: 36 MB vs M3 Pro: 24 MB)
- More GPU cores per tier
- Enhanced Neural Engine (38 TOPS)
Inference Performance by Chip
Here are theoretical decode speeds for common model sizes:
Tokens/sec (theoretical max, Q4_0 quantization):
Model Size: 3B (~1.7GB) 7B (~3.8GB) 13B (~7.3GB) 70B (~40GB)
─────────────────────────────────────────────────────────────────────
M4 70 31 16 —*
M4 Pro 160 71 37 6.8
M4 Max 321 143 74 13.6
M4 Ultra 481 215 112 20.4
─────────────────────────────────────────────────────────────────────
*Doesn't fit in 32 GB RAM
Actual performance will be 60-80% of theoretical due to cache misses, overhead, and non-ideal memory access patterns.2 Akunu typically achieves 70-85% of theoretical bandwidth utilization.
Akunu’s ChipConfig
Akunu detects the hardware at startup and selects tuning parameters via the ChipConfig struct. Here’s how different chips get different configurations:
ChipConfig Parameters:
┌────────────────────┬─────────────────────────────────────┐
│ Parameter │ Purpose │
├────────────────────┼─────────────────────────────────────┤
│ slc_size │ SLC size in bytes. Affects tile │
│ │ sizes and prefetch strategies. │
├────────────────────┼─────────────────────────────────────┤
│ gemv_k_threshold │ K dimension where GEMV switches │
│ │ from small (4 SGs) to large (8 SGs) │
│ │ variant. Higher bandwidth chips can │
│ │ benefit from more parallelism. │
├────────────────────┼─────────────────────────────────────┤
│ prefill_chunk │ Max tokens per prefill batch. │
│ │ Limited by threadgroup memory for │
│ │ attention. Larger chips handle more. │
├────────────────────┼─────────────────────────────────────┤
│ chain_decode_count │ Tokens per chained decode GPU │
│ │ submission. More cores → more tokens │
│ │ can be chained without overhead. │
└────────────────────┴─────────────────────────────────────┘
Example configurations:
M4 (10 GPU cores, 120 GB/s, 16 MB SLC):
gemv_k_threshold = 2048
prefill_chunk = 512
chain_decode_count = 4
M4 Pro (20 GPU cores, 273 GB/s, 36 MB SLC):
gemv_k_threshold = 2048
prefill_chunk = 2048
chain_decode_count = 6
M4 Max (40 GPU cores, 546 GB/s, 48 MB SLC):
gemv_k_threshold = 4096
prefill_chunk = 4096
chain_decode_count = 8
The key insight: the same kernels run on all chips, but the dispatch table uses different configurations. The GEMV kernel for Q4_0 on an M4 uses 4 SIMD groups (128 threads), while on an M4 Max it might use 8 SIMD groups (256 threads) for K dimensions above the threshold.
Choosing the Right Hardware for Your Workload
┌─────────────────────────────────────────────────────────────┐
│ MODEL SIZE vs CHIP GUIDE │
│ │
│ Model Size Recommended Minimum Sweet Spot │
│ ────────── ──────────────────── ────────── │
│ 1-3B M4 (16 GB) M4 Pro (24 GB) │
│ 7B M4 Pro (24 GB) M4 Pro (36 GB) │
│ 13B M4 Pro (36 GB) M4 Max (64 GB) │
│ 34B M4 Max (64 GB) M4 Max (128 GB) │
│ 70B M4 Max (128 GB) M4 Ultra (192 GB) │
│ 100B+ M4 Ultra (192 GB) — │
│ │
│ Rule of thumb: you need ~1.2x the model size in Q4 as RAM │
│ (weights + KV cache + scratch buffers + OS overhead) │
└─────────────────────────────────────────────────────────────┘
Summary
Across four generations:
- M1 introduced UMA and SIMD group matrix ops — the foundation
- M2 increased bandwidth and core counts — evolutionary improvement
- M3 added dynamic caching and 3nm efficiency — architectural refinement
- M4 pushed bandwidth dramatically higher — the best for inference
Akunu’s ChipConfig abstracts these differences into tuning parameters, so the same codebase runs optimally on everything from an M1 MacBook Air to an M4 Ultra Mac Studio. The key variables are always the same: memory bandwidth, GPU core count, and SLC size.
In the next part, we’ll learn how to actually program these GPUs using the Metal framework.
1. Apple. “Apple M1 chip”, “Apple M2 chip”, “Apple M3 chip”, “Apple M4 chip.” apple.com, 2020-2024. Official specifications for each chip generation. See https://www.apple.com/newsroom/2024/05/apple-introduces-m4-chip/. ↩
2. Frumusanu, A. and Smith, R. “Apple M-series deep dives.” AnandTech, 2020-2023. Independent benchmarks and die analysis across generations. See https://chipsandcheese.com/. ↩
Introduction to Metal
Welcome to Part II of this book. In Part I, we explored the hardware that makes Apple Silicon so compelling for machine learning inference: the system-on-chip design, the GPU architecture with its SIMD groups and threadgroups, and the unified memory architecture that eliminates the CPU-to-GPU copy bottleneck. Now it is time to learn how to actually program that hardware.
The answer is Metal – Apple’s low-level GPU programming framework. Over the next seven chapters, we will go from zero to writing high-performance compute shaders that form the backbone of ML inference engines like akunu. By the end of Part II, you will understand every layer of the Metal compute stack, from the API calls in Swift/Objective-C all the way down to the individual threads executing on GPU cores.
Let us begin.
What Is Metal?
Metal is Apple’s unified graphics and compute API.1 Introduced at WWDC 2014, it replaced the aging OpenGL ES on iOS and eventually OpenGL and OpenCL on macOS. Think of Metal as Apple’s answer to Vulkan or DirectX 12 – a modern, low-overhead GPU programming interface that gives you explicit control over the hardware.2
But Metal is more than just a graphics API. It has three major facets:
+-----------------------------------------------------------+
| Metal Framework |
+-----------------------------------------------------------+
| |
| +----------------+ +----------------+ +-------------+ |
| | Metal API | | Metal Shading | | Metal | |
| | (Swift/ObjC) | | Language | | Performance| |
| | | | (C++14-based) | | Shaders | |
| | - Device | | | | (MPS) | |
| | - Buffers | | - Compute | | | |
| | - Queues | | kernels | | - MatMul | |
| | - Encoders | | - Vertex/ | | - Conv | |
| | - Pipelines | | Fragment | | - Image | |
| | - Textures | | shaders | | ops | |
| +----------------+ +----------------+ +-------------+ |
| |
+-----------------------------------------------------------+
1. The Metal API (Swift/Objective-C): The host-side interface. You use this to create GPU devices, allocate memory, build command buffers, and submit work to the GPU. This runs on the CPU.
2. The Metal Shading Language (MSL): The language you write GPU programs in. It is based on C++14 with Apple-specific extensions for GPU concepts like threadgroups, SIMD operations, and address spaces. Your compute kernels are written in MSL.
3. Metal Performance Shaders (MPS): A library of pre-built, highly optimized GPU kernels for common operations – matrix multiplication, convolution, image processing, neural network layers, and more. Think of MPS as Apple’s cuDNN equivalent.
For ML inference, we primarily care about the compute side of Metal. We will barely touch graphics (vertex/fragment shaders, render passes, etc.). Our world is compute kernels, buffers, and dispatches.
Metal vs. The Competition
If you are coming from CUDA, OpenCL, or Vulkan, you will find Metal familiar in some ways and different in others. Let us compare:
Metal vs. CUDA
+-------------------+------------------------------------------+
| Aspect | CUDA | Metal |
+-------------------+-------------------------+----------------+
| Vendor | NVIDIA only | Apple only |
| Language | CUDA C/C++ (extended) | MSL (C++14 |
| | | extended) |
| Host API | CUDA Runtime/Driver | Metal API |
| | (C/C++) | (Swift/ObjC) |
| Execution model | Grid → Block → Thread | Grid → |
| | | Threadgroup → |
| | | Thread |
| SIMD width | 32 (warp) | 32 (SIMD group)|
| Shared memory | __shared__ | threadgroup |
| | | address space|
| Sync primitive | __syncthreads() | threadgroup_ |
| | | barrier() |
| Matrix accel. | Tensor Cores (wmma) | simdgroup_ |
| | | matrix |
| Memory model | Discrete + UVA | Unified (UMA) |
| Ecosystem | Massive (cuDNN, cuBLAS, | Smaller (MPS, |
| | TensorRT, Triton) | MPSGraph, |
| | | Core ML) |
| Maturity for ML | 15+ years | ~5 years |
+-------------------+-------------------------+----------------+
The biggest conceptual difference: CUDA’s programming model treats the GPU as a separate device with its own memory.3 You explicitly copy data between CPU and GPU. On Apple Silicon with Metal, the CPU and GPU share the same physical memory (UMA). There is no copy – you allocate a buffer once, and both CPU and GPU can access it.
The biggest practical difference: CUDA has a massive ecosystem for ML. Libraries like cuDNN, cuBLAS, TensorRT, and Triton make it possible to write high-performance ML code without touching raw CUDA kernels. Metal’s ecosystem is smaller. MPS provides some building blocks, but for state-of-the-art inference, you often need to write custom kernels – which is exactly what akunu does.
Metal vs. OpenCL
+-------------------+------------------------------------------+
| Aspect | OpenCL | Metal |
+-------------------+-------------------------+----------------+
| Portability | Cross-platform | Apple only |
| API style | C-based, verbose | ObjC/Swift, |
| | | modern |
| Shader language | OpenCL C (C99-based) | MSL (C++14) |
| Runtime compile | Yes (common) | Yes + offline |
| Performance | Driver-dependent | Tuned for |
| | | Apple HW |
| Status on Apple | Deprecated since | Active, primary|
| | macOS 10.14 | GPU API |
+-------------------+-------------------------+----------------+
OpenCL was once available on macOS, but Apple deprecated it in 2018. Metal is the only supported path for GPU compute on Apple platforms. If you have existing OpenCL kernels, they need to be ported to MSL.
The good news: the conceptual mapping is straightforward. OpenCL work-groups are Metal threadgroups. OpenCL __local memory is Metal threadgroup memory. OpenCL __global is Metal device. The languages are different (C99 vs C++14), but the GPU programming model is fundamentally the same.
Metal vs. Vulkan
+-------------------+------------------------------------------+
| Aspect | Vulkan | Metal |
+-------------------+-------------------------+----------------+
| Portability | Cross-platform | Apple only |
| Verbosity | Extremely verbose | Moderate |
| Shader language | SPIR-V (usually from | MSL (write |
| | GLSL/HLSL) | directly) |
| Compute support | Full | Full |
| Validation layers | External (very helpful) | Metal API |
| | | Validation |
| Driver overhead | Very low | Very low |
| On Apple | Via MoltenVK (wrapper | Native |
| | over Metal) | |
+-------------------+-------------------------+----------------+
Vulkan and Metal share the same philosophy: explicit, low-overhead GPU control. Both require you to manage command buffers, synchronization, and pipeline states yourself. Vulkan is more verbose (creating a compute pipeline in Vulkan can take hundreds of lines), while Metal strikes a balance between control and usability.
Fun fact: MoltenVK, the Vulkan-on-Apple implementation, is actually a translation layer that converts Vulkan calls to Metal calls underneath. So Metal is the true native API.
Summary: Why Metal for ML on Apple?
Why Metal?
==========
1. It is the ONLY way to access Apple GPU compute
(OpenCL is deprecated, no CUDA, Vulkan is via MoltenVK)
2. Unified Memory Architecture means ZERO-COPY buffer sharing
between CPU and GPU -- huge for inference
3. Metal Shading Language is pleasant to write
(C++14 with nice extensions, not as painful as GLSL)
4. simdgroup_matrix operations give you hardware-accelerated
matrix multiply (like Tensor Cores)
5. Apple tunes Metal drivers specifically for their hardware
(you get the best possible performance)
The Metal Ecosystem
Let us zoom in on the three pillars of the Metal ecosystem and understand how they fit together for ML workloads.
The Metal API (Host Side)
The Metal API is an Objective-C/Swift framework. You use it on the CPU to orchestrate GPU work. The key objects are:
+----------------------------------------------------+
| Your Application |
| (Swift / ObjC / C++) |
+----------------------------------------------------+
|
v
+----------------------------------------------------+
| MTLDevice |
| Represents the GPU. Entry point for everything. |
| - makeCommandQueue() |
| - makeBuffer(length:options:) |
| - makeComputePipelineState(function:) |
| - makeLibrary(source:options:) |
+----------------------------------------------------+
|
+------------------+------------------+
| | |
v v v
+-----------------+ +-----------------+ +-----------------+
| MTLCommandQueue | | MTLBuffer | | MTLLibrary |
| Ordered queue | | GPU memory | | Collection of |
| of cmd buffers | | allocation | | compiled shaders|
+-----------------+ +-----------------+ +-----------------+
| |
v v
+-----------------+ +-----------------+
| MTLCommandBuffer| | MTLFunction |
| A batch of GPU | | A single shader |
| commands | | entry point |
+-----------------+ +-----------------+
| |
v v
+------------------------+  +---------------------------+
|MTLComputeCommandEncoder|  | MTLComputePipelineState   |
| Records compute cmds   |  | Compiled, ready-to-run    |
| (set buffers, dispatch)|  | version of a kernel       |
+------------------------+  +---------------------------+
We will explore each of these objects in detail in Chapter 7. For now, just know the flow: you create a device, create a command queue, create command buffers, encode commands into them, and commit them to the GPU.
The Metal Shading Language (MSL)
MSL is the language you write GPU programs in. It looks like C++14 with some extra keywords and types:
// A simple MSL compute kernel
kernel void add_arrays(
device const float* a [[buffer(0)]],
device const float* b [[buffer(1)]],
device float* result [[buffer(2)]],
uint id [[thread_position_in_grid]]
) {
result[id] = a[id] + b[id];
}
Key MSL features for compute:
- Address space qualifiers: device, constant, threadgroup, thread
- Attribute syntax: [[buffer(0)]], [[thread_position_in_grid]], [[kernel]]
- Built-in vector types: half, half4, float4, uint2, etc.
- SIMD group intrinsics: simd_sum(), simd_shuffle(), etc.
- Threadgroup memory: shared memory within a threadgroup
- SIMD group matrix ops: simdgroup_half8x8, simdgroup_multiply_accumulate()
MSL is covered in depth in Chapter 8.
Metal Performance Shaders (MPS)
MPS is Apple’s library of optimized GPU kernels. For ML, the most relevant parts are:
- MPSMatrixMultiplication: Optimized GEMM
- MPSImageConvolution: 2D convolution
- MPSNNGraph: Neural network inference graph (older API)
- MPSGraph: A more modern compute graph framework
MPS is useful as a starting point, but for state-of-the-art inference performance, custom kernels often outperform MPS. This is because:
- MPS kernels are general-purpose. A custom kernel can be specialized for exact matrix dimensions, quantization formats, and fusion patterns.
- MPS cannot fuse arbitrary operations. A custom kernel can fuse (say) dequantization + matrix multiply + bias add + activation into a single pass, saving memory bandwidth.
- MPS does not support all quantization formats used in modern LLM inference.
This is exactly why akunu writes its own Metal kernels rather than relying on MPS.
When to Use Metal Compute vs. Metal Graphics
Metal supports two kinds of GPU work:
- Graphics (render pipelines): Vertex shaders, fragment shaders, rasterization, render passes. This is for drawing things on screen.
- Compute (compute pipelines): General-purpose computation on the GPU. No rendering, no pixels – just data in, data out.
For ML inference, we use compute exclusively. Here is why:
Graphics Pipeline: Compute Pipeline:
================== ==================
Vertices → Vertex Shader Data → Compute Kernel → Data
→ Rasterizer
→ Fragment Shader
→ Framebuffer
- Fixed-function stages - Fully programmable
- Designed for rendering - Designed for GPGPU
- Data flows through - You control data flow
a rigid pipeline completely
- Output is pixels - Output is whatever you want
Compute pipelines give us:
- Arbitrary data access patterns: Read from and write to any buffer location
- Threadgroup shared memory: Fast scratchpad for inter-thread communication
- Flexible dispatch: 1D, 2D, or 3D grids of arbitrary size
- No rendering overhead: No rasterizer, no framebuffer, no blend state
There are rare cases where graphics shaders are (ab)used for compute – for example, some older GPU compute techniques use fragment shaders to process textures. But on modern Apple GPUs, compute shaders are the right tool for ML.
The Metal Programming Model
Now let us build up the mental model for how Metal compute works. This is the single most important section in this chapter, so take your time with it.
The Big Picture
Here is the full pipeline from your application to GPU execution:
YOUR APPLICATION (CPU)
======================
1. Get a reference to the GPU
+------------------+
| MTLDevice | <-- Represents the GPU hardware
+------------------+
|
2. Create a command queue (once, reuse it)
|
v
+------------------+
| MTLCommandQueue | <-- FIFO queue of command buffers
+------------------+
|
3. Create a command buffer (one per "batch" of work)
|
v
+------------------+
| MTLCommandBuffer | <-- Container for GPU commands
+------------------+
|
4. Create a compute command encoder
|
v
+--------------------------+
| MTLComputeCommandEncoder | <-- Records compute commands
+--------------------------+
|
5. Set the pipeline state (which kernel to run)
6. Set buffers (input/output data)
7. Dispatch threadgroups (how many threads to launch)
8. End encoding
|
v
9. Commit the command buffer to the GPU
|
v
GPU EXECUTION
=============
The GPU picks up the command buffer from the queue,
executes the recorded commands:
- Binds the kernel
- Binds the buffers
- Launches threadgroups across GPU cores
- Each thread runs the kernel function
- Results are written to output buffers
|
v
RESULTS AVAILABLE
=================
(In the output buffer, which on UMA is already
accessible to the CPU -- no copy needed!)
Let us walk through each component.
MTLDevice – The GPU
Everything starts with a MTLDevice. This object represents the GPU hardware. You get it like this:
// Swift
guard let device = MTLCreateSystemDefaultDevice() else {
fatalError("Metal is not supported on this device")
}
Or in Objective-C:
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
On a Mac with Apple Silicon, this gives you the built-in GPU. On a Mac with multiple GPUs (e.g., an older Mac Pro with an eGPU), you can enumerate all devices with MTLCopyAllDevices().
The device is your factory for creating everything else: buffers, command queues, pipeline states, libraries.
MTLCommandQueue – The Submission Highway
A command queue is an ordered sequence of command buffers. You typically create one queue at startup and reuse it for the lifetime of your application:
guard let commandQueue = device.makeCommandQueue() else {
fatalError("Could not create command queue")
}
Think of it as a highway on-ramp. Command buffers you commit to the queue will be executed in order (mostly – Metal can reorder independent work for efficiency, but the observable results respect submission order for resources).
MTLCommandBuffer – A Batch of Work
A command buffer is a container for GPU commands. You create one whenever you have work to submit:
guard let commandBuffer = commandQueue.makeCommandBuffer() else {
fatalError("Could not create command buffer")
}
A command buffer can contain multiple encoder passes. For compute work, each pass uses a MTLComputeCommandEncoder. The command buffer is not executed until you call commit().
Command Buffer Lifecycle:
=========================
Created ──> Encoding ──> Committed ──> Scheduled ──> Completed
| | | | |
| (you record (you call (Metal (GPU has
| commands) .commit()) schedules finished)
| execution)
MTLComputeCommandEncoder – Recording Commands
The encoder is how you record commands into the command buffer. For compute work:
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
fatalError("Could not create compute encoder")
}
// Record commands:
encoder.setComputePipelineState(pipelineState)
encoder.setBuffer(inputBuffer, offset: 0, index: 0)
encoder.setBuffer(outputBuffer, offset: 0, index: 1)
encoder.dispatchThreadgroups(gridSize, threadsPerThreadgroup: groupSize)
encoder.endEncoding()
Important: the encoder does not execute anything. It records commands into the command buffer. Execution happens later when you commit.
MTLComputePipelineState – The Compiled Kernel
Before you can dispatch a kernel, you need to compile it into a pipeline state object (PSO). This involves:
- Loading your MSL source code (or a pre-compiled .metallib binary)
- Getting a MTLFunction from the library
- Creating a MTLComputePipelineState from the function
// Load the default library (compiled from .metal files in your project)
let library = device.makeDefaultLibrary()!
// Get the kernel function by name
let function = library.makeFunction(name: "add_arrays")!
// Create the pipeline state (this compiles the function for the GPU)
let pipelineState = try device.makeComputePipelineState(function: function)
Creating a PSO can be expensive (it involves final compilation and optimization), so you typically do it once at startup and cache the result. We will explore this in detail in Chapter 7.
Dispatch – Launching Threads
The final piece is telling the GPU how many threads to launch. Metal gives you two options:
// Option 1: Dispatch by threadgroup count
// You specify: (number of threadgroups) x (threads per threadgroup)
let gridSize = MTLSize(width: 64, height: 1, depth: 1)
let groupSize = MTLSize(width: 256, height: 1, depth: 1)
encoder.dispatchThreadgroups(gridSize, threadsPerThreadgroup: groupSize)
// Total threads = 64 * 256 = 16,384
// Option 2: Dispatch by total thread count (Metal adjusts automatically)
let totalThreads = MTLSize(width: 16384, height: 1, depth: 1)
let groupSize = MTLSize(width: 256, height: 1, depth: 1)
encoder.dispatchThreads(totalThreads, threadsPerThreadgroup: groupSize)
Option 1 (dispatchThreadgroups) requires you to calculate the grid dimensions yourself. Option 2 (dispatchThreads) lets you specify the total number of threads, and Metal handles the math. We will discuss the tradeoffs in Chapter 9.
Putting It All Together: Hello World Compute Shader
Let us write a complete example that adds two arrays on the GPU. This is the “Hello World” of GPU computing.
Step 1: The Metal Shader (MSL)
Create a file called compute.metal:
// compute.metal
// A simple kernel that adds two arrays element-wise.
#include <metal_stdlib>
using namespace metal;
kernel void add_arrays(
device const float* inA [[buffer(0)]],
device const float* inB [[buffer(1)]],
device float* out [[buffer(2)]],
uint id [[thread_position_in_grid]]
) {
out[id] = inA[id] + inB[id];
}
Let us break down every piece:
- #include <metal_stdlib> – Includes standard Metal functions and types.
- using namespace metal; – Avoids having to prefix everything with metal::.
- kernel void add_arrays(...) – The kernel keyword marks this as a compute kernel entry point. It must return void.
- device const float* inA [[buffer(0)]] – A pointer to read-only data in the device address space (GPU-accessible memory). The [[buffer(0)]] attribute tells Metal this is bound at buffer index 0.
- device float* out [[buffer(2)]] – A writable output buffer at index 2.
- uint id [[thread_position_in_grid]] – A built-in variable that gives each thread its unique index in the dispatch grid. Thread 0 gets id = 0, thread 1 gets id = 1, etc.
The kernel body is trivial: each thread reads one element from inA and inB, adds them, and writes the result to out.
Visualization of execution:
===========================
Thread 0: out[0] = inA[0] + inB[0]
Thread 1: out[1] = inA[1] + inB[1]
Thread 2: out[2] = inA[2] + inB[2]
...
Thread N: out[N] = inA[N] + inB[N]
Each thread handles exactly one element.
All threads execute in parallel across GPU cores.
Step 2: The Host Code (Swift)
Here is the complete Swift code to set up Metal, compile the kernel, prepare data, dispatch the kernel, and read the results:
import Metal
import Foundation
// ============================================================
// STEP 1: Get the GPU device
// ============================================================
guard let device = MTLCreateSystemDefaultDevice() else {
fatalError("Metal is not supported on this device")
}
print("Using GPU: \(device.name)")
// ============================================================
// STEP 2: Create a command queue
// ============================================================
guard let commandQueue = device.makeCommandQueue() else {
fatalError("Could not create command queue")
}
// ============================================================
// STEP 3: Load and compile the shader
// ============================================================
// Option A: Load from a .metal file in the project bundle
// let library = device.makeDefaultLibrary()!
// Option B: Compile from source string at runtime
let shaderSource = """
#include <metal_stdlib>
using namespace metal;
kernel void add_arrays(
device const float* inA [[buffer(0)]],
device const float* inB [[buffer(1)]],
device float* out [[buffer(2)]],
uint id [[thread_position_in_grid]]
) {
out[id] = inA[id] + inB[id];
}
"""
let library = try! device.makeLibrary(source: shaderSource, options: nil)
let function = library.makeFunction(name: "add_arrays")!
let pipelineState = try! device.makeComputePipelineState(function: function)
// ============================================================
// STEP 4: Prepare the data
// ============================================================
let arrayLength = 1_000_000
let bufferSize = arrayLength * MemoryLayout<Float>.size
// Create Metal buffers with shared storage mode (CPU + GPU access)
let bufferA = device.makeBuffer(length: bufferSize, options: .storageModeShared)!
let bufferB = device.makeBuffer(length: bufferSize, options: .storageModeShared)!
let bufferOut = device.makeBuffer(length: bufferSize, options: .storageModeShared)!
// Fill input buffers with data
let pointerA = bufferA.contents().bindMemory(to: Float.self, capacity: arrayLength)
let pointerB = bufferB.contents().bindMemory(to: Float.self, capacity: arrayLength)
for i in 0..<arrayLength {
pointerA[i] = Float(i)
pointerB[i] = Float(i) * 2.0
}
// ============================================================
// STEP 5: Create a command buffer and encoder
// ============================================================
let commandBuffer = commandQueue.makeCommandBuffer()!
let encoder = commandBuffer.makeComputeCommandEncoder()!
// ============================================================
// STEP 6: Encode the compute command
// ============================================================
encoder.setComputePipelineState(pipelineState)
encoder.setBuffer(bufferA, offset: 0, index: 0)
encoder.setBuffer(bufferB, offset: 0, index: 1)
encoder.setBuffer(bufferOut, offset: 0, index: 2)
// Calculate dispatch sizes
let threadGroupSize = MTLSize(width: 256, height: 1, depth: 1)
let threadGroups = MTLSize(
    width: (arrayLength + 255) / 256, // Ceiling division: 3,907 groups
    height: 1,
    depth: 1
)
// Note: 3,907 * 256 = 1,000,192 threads -- slightly more than arrayLength.
// A robust kernel would take the element count and guard with `if (id < n)`;
// we keep the kernel minimal here for clarity.
encoder.dispatchThreadgroups(threadGroups, threadsPerThreadgroup: threadGroupSize)
encoder.endEncoding()
// ============================================================
// STEP 7: Commit and wait
// ============================================================
commandBuffer.commit()
commandBuffer.waitUntilCompleted()
// ============================================================
// STEP 8: Read the results
// ============================================================
let pointerOut = bufferOut.contents().bindMemory(to: Float.self, capacity: arrayLength)
// Verify a few results
for i in 0..<5 {
print("out[\(i)] = \(pointerOut[i]) (expected: \(Float(i) + Float(i) * 2.0))")
}
// Output:
// out[0] = 0.0 (expected: 0.0)
// out[1] = 3.0 (expected: 3.0)
// out[2] = 6.0 (expected: 6.0)
// out[3] = 9.0 (expected: 9.0)
// out[4] = 12.0 (expected: 12.0)
The Full Flow Visualized
Here is what happens when you run this code:
CPU Side GPU Side
======== ========
1. MTLCreateSystemDefaultDevice()
+---> Gets handle to GPU
2. makeCommandQueue()
+---> Creates submission queue
3. makeLibrary(source:) Compiler: MSL → AIR → GPU ISA
makeFunction(name:)
makeComputePipelineState()
4. makeBuffer() x 3 Allocates in unified memory
Fill bufferA, bufferB (same physical RAM)
5. makeCommandBuffer()
makeComputeCommandEncoder()
6. setComputePipelineState() |
setBuffer() x 3 | All recorded, not
dispatchThreadgroups() | yet executed
endEncoding() |
7. commandBuffer.commit()
+-----------------------------------> GPU picks up work
|
commandBuffer.waitUntilCompleted() v
(CPU blocks here) GPU launches 3,907
threadgroups of 256
threads each
|
v
Each thread runs
add_arrays kernel
|
v
Results written to
bufferOut
|
<-----(completion signal)------------- Done!
8. Read pointerOut[i] (Same physical memory,
no copy needed!)
Notice step 8: we read the results directly from the buffer’s contents() pointer. There is no “copy back from GPU” step. This is the UMA advantage – the buffer lives in unified memory, accessible to both CPU and GPU.
The Objective-C Version
Since akunu is written in C/C++ and uses the Metal Objective-C API (via Objective-C++), here is the same example in Objective-C:
#import <Metal/Metal.h>
#import <Foundation/Foundation.h>
int main(int argc, const char * argv[]) {
@autoreleasepool {
// 1. Get the device
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
NSLog(@"Using GPU: %@", device.name);
// 2. Create command queue
id<MTLCommandQueue> commandQueue = [device newCommandQueue];
// 3. Compile the shader
NSString *shaderSource = @
"#include <metal_stdlib>\n"
"using namespace metal;\n"
"kernel void add_arrays(\n"
" device const float* inA [[buffer(0)]],\n"
" device const float* inB [[buffer(1)]],\n"
" device float* out [[buffer(2)]],\n"
" uint id [[thread_position_in_grid]]\n"
") {\n"
" out[id] = inA[id] + inB[id];\n"
"}\n";
NSError *error = nil;
id<MTLLibrary> library = [device newLibraryWithSource:shaderSource
options:nil
error:&error];
id<MTLFunction> function = [library newFunctionWithName:@"add_arrays"];
id<MTLComputePipelineState> pso =
[device newComputePipelineStateWithFunction:function error:&error];
// 4. Create buffers
NSUInteger arrayLength = 1000000;
NSUInteger bufferSize = arrayLength * sizeof(float);
id<MTLBuffer> bufA = [device newBufferWithLength:bufferSize
options:MTLResourceStorageModeShared];
id<MTLBuffer> bufB = [device newBufferWithLength:bufferSize
options:MTLResourceStorageModeShared];
id<MTLBuffer> bufOut = [device newBufferWithLength:bufferSize
options:MTLResourceStorageModeShared];
// Fill input data
float *ptrA = (float *)bufA.contents;
float *ptrB = (float *)bufB.contents;
for (NSUInteger i = 0; i < arrayLength; i++) {
ptrA[i] = (float)i;
ptrB[i] = (float)i * 2.0f;
}
// 5-7. Encode and dispatch
id<MTLCommandBuffer> cmdBuf = [commandQueue commandBuffer];
id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
[enc setComputePipelineState:pso];
[enc setBuffer:bufA offset:0 atIndex:0];
[enc setBuffer:bufB offset:0 atIndex:1];
[enc setBuffer:bufOut offset:0 atIndex:2];
MTLSize groupSize = MTLSizeMake(256, 1, 1);
MTLSize gridSize = MTLSizeMake((arrayLength + 255) / 256, 1, 1);
[enc dispatchThreadgroups:gridSize threadsPerThreadgroup:groupSize];
[enc endEncoding];
[cmdBuf commit];
[cmdBuf waitUntilCompleted];
// 8. Read results
float *ptrOut = (float *)bufOut.contents;
for (int i = 0; i < 5; i++) {
NSLog(@"out[%d] = %.1f (expected: %.1f)", i, ptrOut[i],
(float)i + (float)i * 2.0f);
}
}
return 0;
}
The Objective-C API maps almost 1:1 to Swift. The main difference is syntax ([device newCommandQueue] vs device.makeCommandQueue()). akunu uses Objective-C++ so it can mix C++ code with Metal API calls seamlessly.
How akunu Uses Metal
Now that you understand the basics, let us peek at how akunu structures its Metal usage. We will go much deeper in Part IV, but a high-level overview helps connect the dots:
akunu Architecture (simplified):
=================================
+----------------------------------------------------------+
| akunu C API |
| ak_context_create() ak_generate() ak_decode_step() |
+----------------------------------------------------------+
|
v
+----------------------------------------------------------+
| MetalDevice |
| |
| - device_: id<MTLDevice> |
| - queue_: id<MTLCommandQueue> |
| - pso_cache_: HashMap<string, MTLComputePipelineState> |
| |
| Methods: |
| - allocate(size) → Buffer |
| - begin_encoding() → starts recording |
| - set_pipeline(Pipeline) → sets the kernel |
| - set_buffer(Buffer, offset, index) → binds data |
| - dispatch(grid, threadgroup) → launches threads |
| - end_encoding() → finishes recording |
| - commit() → submits to GPU |
+----------------------------------------------------------+
|
v
+----------------------------------------------------------+
| Metal Shaders |
| (.metal files compiled to .metallib) |
| |
| - gemv_f16.metal (matrix-vector multiply) |
| - gemm_f16.metal (matrix-matrix multiply) |
| - attention.metal (flash attention) |
| - rms_norm.metal (RMS normalization) |
| - rope.metal (rotary position embedding) |
| - ... |
+----------------------------------------------------------+
akunu wraps the Metal API in a MetalDevice class that provides a cleaner, C++-friendly interface. Instead of creating MTLCommandBuffer and MTLComputeCommandEncoder objects directly, you call methods like begin_encoding(), set_pipeline(), set_buffer(), and dispatch().
The pipeline state objects (PSOs) are cached in a hash map (pso_cache_) so they are only compiled once. The cache key includes the kernel name and any specialization constants (e.g., "gemv_k128" for a GEMV kernel specialized for K=128).
Key Takeaways
Before moving to Chapter 7 where we dive deep into compute pipelines, let us summarize what we have learned:
- Metal is Apple’s low-level GPU API – the only way to do GPU compute on Apple platforms.
- Metal has three parts: the host API (Swift/ObjC), the Metal Shading Language (MSL), and Metal Performance Shaders (MPS).
- For ML inference, we use compute pipelines, not graphics pipelines.
- The programming model flows: Device → Queue → Command Buffer → Encoder → Dispatch → Commit.
- On Apple Silicon, UMA means zero-copy: buffers allocated with .storageModeShared are accessible to both CPU and GPU without any data transfer.
- MSL is C++14 with GPU extensions: address spaces, vector types, SIMD intrinsics, and thread indexing built-ins.
- akunu wraps Metal in a MetalDevice class that caches pipeline states and provides a streamlined C++ interface.
In the next chapter, we will go deep on compute pipelines: how kernels are compiled, how pipeline states are created and cached, and how the command encoding flow works in practice.
Exercises
- Run the Hello World example: If you have a Mac with Apple Silicon, create an Xcode project (macOS Command Line Tool), add a .metal file with the add_arrays kernel, and run the Swift host code. Verify the output.
- Modify the kernel: Change the kernel to compute out[i] = inA[i] * inB[i] + 1.0 (fused multiply-add). How does the host code change? (Hint: it does not, only the shader changes.)
- Explore MTLDevice properties: Print device.maxThreadgroupMemoryLength, device.maxThreadsPerThreadgroup, and device.name. What values do you get on your hardware?
- Think about error handling: Our example uses fatalError and force-unwraps everywhere. In production code (like akunu), what error handling strategy would you use?
- Compare with CUDA: If you have experience with CUDA, write down the mapping between CUDA concepts and Metal concepts. For example: cudaMalloc maps to device.makeBuffer, <<<grid, block>>> maps to dispatchThreadgroups, etc.
- Apple. “Metal Programming Guide.” developer.apple.com. The official reference for Metal’s programming model, including device creation, command queues, and compute pipelines. See https://developer.apple.com/documentation/metal/performing-calculations-on-a-gpu. ↩
- Apple. “Metal Shading Language Specification.” developer.apple.com. The formal MSL specification covering types, address spaces, and built-in functions. See https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf. ↩
- Apple. “Metal Best Practices Guide.” developer.apple.com. Performance guidance for buffer management, pipeline caching, and dispatch strategies. See https://developer.apple.com/library/archive/documentation/3DDrawing/Conceptual/MTLBestPracticesGuide/index.html. ↩
Compute Pipelines and Shaders
If you have ever worked with graphics APIs, the word “pipeline” probably conjures images of vertex shaders, rasterizers, fragment shaders, and blend states all wired together into a monolithic object. Compute pipelines are refreshingly simpler. A compute pipeline is, at its core, just one thing: a compiled GPU function plus the metadata the driver needs to launch it. That is it. No rasterizer, no depth test, no blend modes. One function, ready to run.
In this chapter we are going to trace the full lifecycle of a compute pipeline on Metal – from writing a kernel function in the Metal Shading Language, through the compilation stages that turn it into executable GPU code, all the way to encoding a dispatch command that actually runs the thing. Along the way we will see how akunu wraps this machinery to keep things fast and ergonomic from C++.
What Is a Compute Pipeline State Object?
A MTLComputePipelineState (PSO) is an opaque, immutable object that represents a fully compiled and optimized GPU kernel. Once you have one, you can use it over and over again to dispatch work – the driver does not need to recompile anything. Creating a PSO is expensive (potentially tens of milliseconds), but using one is cheap (microseconds).
Think of it like compiling a C++ program. The compilation step is slow, but once you have the binary, running it is fast. You would never recompile the same source file every time you want to run the program. Same idea here.
+------------------+ +------------------+ +-------------------+
| MSL Source Code | ----> | MTLLibrary | ----> | MTLFunction |
| (kernel void | | (collection of | | (one specific |
| my_kernel...) | | compiled | | kernel function) |
+------------------+ | functions) | +-------------------+
+------------------+ |
v
+-------------------+
| MTLComputePipeline|
| State (PSO) |
| (ready to dispatch|
| on GPU) |
+-------------------+
Let us walk through each step.
The Compilation Flow: MSL to PSO
Metal has a multi-stage compilation pipeline. Understanding these stages matters because each one is a knob you can turn for performance.
Stage 1: MSL Source to AIR (Metal Intermediate Representation)
The Metal Shading Language is a dialect of C++14 with GPU-specific extensions (we will cover MSL in depth in the next chapter). When you compile MSL, the first output is AIR – Apple’s Intermediate Representation. AIR is analogous to LLVM IR (and in fact the Metal compiler is built on LLVM). It is a platform-independent representation of your shader.
+-----------+ metal compiler +----------+
| .metal | -------------------> | .air |
| (MSL src) | (clang frontend) | (LLVM IR |
+-----------+ | variant)|
+----------+
Stage 2: AIR to metallib (Metal Library)
Multiple AIR files are linked together into a .metallib file. This is a binary archive that contains all your compiled functions. You can think of it as the GPU equivalent of a .a static library or a .dylib.
+----------+
| .air |--+
+----------+ | metallib tool +------------+
+-------------------> | .metallib |
+----------+ | | (binary |
| .air |--+ | archive) |
+----------+ +------------+
Stage 3: metallib to PSO (at runtime)
At runtime, you load the .metallib into a MTLLibrary, look up a function by name, and then create a pipeline state. This final step performs GPU-specific optimizations – register allocation, instruction scheduling, occupancy tuning – that depend on the specific GPU model the code will run on.
+------------+ MTLDevice +-------------+ MTLDevice
| .metallib | ------------------> | MTLFunction | ----------------->
| (binary) | newLibrary(data:) | (looked up | newComputePipeline
+------------+ | by name) | State(function:)
+-------------+
+----------+
| PSO |
| (GPU- |
| ready) |
+----------+
Putting It All Together
Here is the full pipeline:
MSL Source (.metal)
|
| metal compiler (offline or runtime)
v
AIR (.air)
|
| metallib linker
v
Metal Library (.metallib)
|
| Load at runtime: device.newLibrary(data:)
v
MTLLibrary (in-memory)
|
| library.newFunction(name: "my_kernel")
v
MTLFunction
|
| device.newComputePipelineState(function:)
v
MTLComputePipelineState (PSO) <-- this is what you dispatch with
Runtime vs Offline Compilation
You have two choices for when compilation happens.
Runtime Compilation (from source)
You embed MSL source code as a string and compile it at runtime:
// Runtime compilation from source string
let source = """
kernel void add_arrays(device float* a [[buffer(0)]],
device float* b [[buffer(1)]],
device float* c [[buffer(2)]],
uint id [[thread_position_in_grid]]) {
c[id] = a[id] + b[id];
}
"""
let library = try device.makeLibrary(source: source, options: nil)
let function = library.makeFunction(name: "add_arrays")!
let pso = try device.makeComputePipelineState(function: function)
This is convenient for development and for cases where you need to generate shader code dynamically. The downside is that compilation happens at application startup, which can take tens to hundreds of milliseconds per kernel.
Offline Compilation (precompiled metallib)
You compile your shaders ahead of time using Apple’s command-line tools:
# Step 1: Compile MSL to AIR
xcrun -sdk macosx metal -c my_kernels.metal -o my_kernels.air
# Step 2: Link AIR into a metallib
xcrun -sdk macosx metallib my_kernels.air -o my_kernels.metallib
# Or do it in one step:
xcrun -sdk macosx metal my_kernels.metal -o my_kernels.metallib
Then at runtime, you just load the precompiled binary:
let libraryURL = Bundle.main.url(forResource: "my_kernels",
withExtension: "metallib")!
let library = try device.makeLibrary(URL: libraryURL)
This is what production applications should do. The runtime cost drops from “compile everything” to “just load a binary and create PSOs.” Akunu takes the offline compilation approach – its Metal kernels are precompiled into metallib files that ship with the library.
When Would You Use Runtime Compilation?
Runtime compilation is not just for lazy developers. There are legitimate reasons to compile at runtime:
- Generated kernels: When the kernel code depends on runtime parameters like tensor shapes, data types, or hardware capabilities.
- JIT specialization: Generating code paths optimized for specific input sizes that are only known at inference time.
- Plugin systems: Loading user-provided shader code.
That said, for a production ML inference engine, offline compilation is almost always the right default.
Function Specialization Constants
Here is where things get interesting for ML workloads. Metal supports function constants – compile-time constants that let you create specialized variants of the same kernel without duplicating source code.
Consider a GEMV (matrix-vector multiply) kernel. The optimal implementation depends on the matrix dimensions:
// Define function constants
constant uint K [[function_constant(0)]];
constant bool USE_BIAS [[function_constant(1)]];
kernel void gemv(device const half* W [[buffer(0)]],
device const half* x [[buffer(1)]],
device half* y [[buffer(2)]],
device const half* bias [[buffer(3)]],
uint row [[thread_position_in_grid]]) {
half sum = 0;
// The compiler knows K at compile time -- it can fully
// unroll this loop, choose optimal vector widths, etc.
for (uint i = 0; i < K; i += 4) {
half4 w = *((device const half4*)(W + row * K + i));
half4 v = *((device const half4*)(x + i));
sum += dot(w, v);
}
// Dead code elimination removes this branch entirely
// when USE_BIAS is false
if (USE_BIAS) {
sum += bias[row];
}
y[row] = sum;
}
At PSO creation time, you provide the constant values:
let constants = MTLFunctionConstantValues()
var k: UInt32 = 128
constants.setConstantValue(&k, type: .uint, index: 0)
var useBias: Bool = true
constants.setConstantValue(&useBias, type: .bool, index: 1)
let function = try library.makeFunction(name: "gemv",
constantValues: constants)
let pso = try device.makeComputePipelineState(function: function)
The compiler treats these constants exactly like #define values – it can unroll loops, eliminate dead branches, choose optimal vector sizes, and perform constant propagation. The result is a specialized binary for this specific combination of constant values.
Why does this matter for ML? Because ML models have fixed dimensions. A LLaMA-7B model always has a hidden dimension of 4096. An attention head is always 128-dimensional. Once you know the model architecture, you know exactly what constants to specialize on, and the compiler can produce code that is perfectly tuned for those dimensions.
Pipeline Caching Strategies
Creating a PSO involves GPU-specific compilation. On an M1 Max, creating a single pipeline takes between 5 and 50 ms depending on kernel complexity. If you have 50 kernels, that is potentially seconds of startup time. And if you use function constants, each unique combination of constants is a separate PSO.
The solution is caching. Metal provides MTLBinaryArchive for persisting compiled pipelines to disk, but the more common approach in practice is to cache PSOs in memory and be smart about when you create them.
How Akunu Caches PSOs
Akunu maintains a PSO cache in its MetalDevice struct. The cache is a hash map keyed by a string that encodes the kernel name plus all specialization constant values:
PSO Cache (HashMap<String, MTLComputePipelineState>)
+-------------------------------------------+
| Key | Value (PSO) |
|------------------------+-------------------|
| "gemv_k128" | PSO_0x7f8a... |
| "gemv_k4096" | PSO_0x7f8b... |
| "gemm_tm32_tn64_tk16" | PSO_0x7f8c... |
| "softmax_n128" | PSO_0x7f8d... |
| "rmsnorm_n4096" | PSO_0x7f8e... |
| "rope_dim128_base10000"| PSO_0x7f8f... |
+-------------------------------------------+
The first time you need a PSO with specific constants, it gets compiled and stored. Every subsequent use is a hash table lookup – essentially free. In Rust, the cache lives inside the MetalDevice:
pub struct MetalDevice {
device: metal::Device,
command_queue: metal::CommandQueue,
library: metal::Library,
pso_cache: HashMap<String, metal::ComputePipelineState>,
// ...
}
impl MetalDevice {
fn get_or_create_pso(
&mut self,
kernel_name: &str,
constants: &[(u32, FunctionConstantValue)],
) -> &metal::ComputePipelineState {
// Build cache key: "kernel_name_const0val_const1val_..."
let key = build_cache_key(kernel_name, constants);
// Destructure so the closure borrows individual fields rather than
// all of `self` (entry() already borrows pso_cache mutably)
let Self { pso_cache, library, device, .. } = self;
pso_cache.entry(key).or_insert_with(|| {
let func_constants = metal::FunctionConstantValues::new();
for (index, value) in constants {
match value {
FunctionConstantValue::UInt(v) => {
func_constants.set_constant_value_at_index(
v as *const _ as *const _,
metal::MTLDataType::UInt,
*index as u64,
);
}
// ... other types
}
}
let function = library
.get_function(kernel_name, Some(func_constants))
.expect("kernel not found");
device
.new_compute_pipeline_state_with_function(&function)
.expect("failed to create PSO")
})
}
}
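The build_cache_key helper is not shown in the listing. A minimal sketch of what it might do, assuming the FunctionConstantValue enum from the surrounding code (reduced here to its UInt variant) — note that akunu's real keys, like "gemv_k128" in the table above, use a more human-readable scheme:

```rust
// Hypothetical enum mirroring the one used by the PSO cache;
// only the UInt variant is sketched here.
#[derive(Debug)]
enum FunctionConstantValue {
    UInt(u32),
}

// Build a key like "gemv_0u128" from the kernel name plus every
// (constant index, value) pair, so each specialization caches separately.
fn build_cache_key(kernel_name: &str, constants: &[(u32, FunctionConstantValue)]) -> String {
    let mut key = String::from(kernel_name);
    for (index, value) in constants {
        match value {
            FunctionConstantValue::UInt(v) => {
                key.push_str(&format!("_{}u{}", index, v));
            }
        }
    }
    key
}

fn main() {
    let key = build_cache_key("gemv", &[(0, FunctionConstantValue::UInt(128))]);
    println!("{}", key); // gemv_0u128
}
```

Any encoding works as long as two different constant combinations never map to the same key.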
The key insight is that for ML inference, the set of PSOs you need is determined entirely by the model architecture. A LLaMA model uses the same kernels with the same dimensions for every token it generates. So in practice, all PSOs get created during model loading, and inference itself never hits a cache miss.
Model Loading Phase: Inference Phase:
+------------------------+ +------------------------+
| For each layer: | | For each token: |
| - Create GEMM PSOs | | - Look up PSO (hit) |
| - Create norm PSOs | | - Encode dispatch |
| - Create attn PSOs | | - Execute on GPU |
| (one-time cost) | | (no compilation!) |
+------------------------+ +------------------------+
~100-500ms total ~0ms PSO overhead
Command Encoding Flow
Now that we have a PSO, let us actually use it. In Metal, you do not call GPU functions directly. Instead, you encode commands into a command buffer, and then submit the entire buffer to the GPU for execution.
The Full Encoding Stack
MTLCommandQueue
|
| makeCommandBuffer()
v
MTLCommandBuffer
|
| makeComputeCommandEncoder()
v
MTLComputeCommandEncoder
|
| setComputePipelineState() -- which kernel to run
| setBuffer(buffer, offset, index) -- bind data
| setBytes(ptr, length, index) -- inline small data
| dispatchThreadgroups(...) -- launch!
|
| endEncoding()
v
MTLCommandBuffer
|
| commit() -- send to GPU
| waitUntilCompleted() -- block until done
v
[GPU executes the commands]
Let us walk through this with a concrete example. Say we want to run our add_arrays kernel on two arrays of 1024 floats.
Step 1: Create a Command Buffer
let commandBuffer = commandQueue.makeCommandBuffer()!
A command buffer is a container for GPU commands. You can think of it as a recording of work to be done. Nothing executes until you call commit().
Step 2: Create a Compute Command Encoder
let encoder = commandBuffer.makeComputeCommandEncoder()!
The encoder is how you record compute commands. Metal separates encoding (recording) from execution (running) so that you can build up a batch of work and submit it all at once.
Step 3: Set the Pipeline State
encoder.setComputePipelineState(addArraysPSO)
This tells the GPU which kernel function to run for subsequent dispatch calls.
Step 4: Bind Data
encoder.setBuffer(bufferA, offset: 0, index: 0) // buffer(0) in MSL
encoder.setBuffer(bufferB, offset: 0, index: 1) // buffer(1) in MSL
encoder.setBuffer(bufferC, offset: 0, index: 2) // buffer(2) in MSL
Each setBuffer call binds a MTLBuffer to a buffer index. These indices correspond to the [[buffer(N)]] attributes in your MSL kernel.
Step 5: Dispatch
let gridSize = MTLSize(width: 1024, height: 1, depth: 1)
let threadgroupSize = MTLSize(width: 256, height: 1, depth: 1)
encoder.dispatchThreadgroups(
MTLSize(width: 1024 / 256, height: 1, depth: 1), // 4 threadgroups
threadsPerThreadgroup: threadgroupSize
)
We will cover dispatch geometry in detail in Chapter 9. For now, just know that we are launching 1024 threads organized into groups of 256.
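The division 1024 / 256 works out evenly here. For arbitrary array sizes you need ceiling division (and a bounds check inside the kernel, since the last threadgroup may overshoot). A small sketch:

```rust
// Ceiling division: number of threadgroups needed to cover `n` elements
// with `threadgroup_size` threads each. For n = 1000, size = 256 this
// yields 4 groups (1024 threads), so the kernel must bounds-check its id.
fn threadgroups_for(n: u64, threadgroup_size: u64) -> u64 {
    (n + threadgroup_size - 1) / threadgroup_size
}

fn main() {
    println!("{}", threadgroups_for(1024, 256)); // 4
    println!("{}", threadgroups_for(1000, 256)); // 4
    println!("{}", threadgroups_for(1025, 256)); // 5
}
```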
Step 6: End Encoding and Commit
encoder.endEncoding()
commandBuffer.commit()
commandBuffer.waitUntilCompleted()
endEncoding() finalizes the encoder – you cannot add more commands after this. commit() submits the command buffer to the GPU. waitUntilCompleted() blocks the CPU until the GPU finishes (in production you would typically use completion handlers or addCompletedHandler instead of blocking).
Encoding Multiple Dispatches
For ML inference, a single command buffer typically contains many dispatch calls – one per layer operation:
let encoder = commandBuffer.makeComputeCommandEncoder()!
// Layer 1: RMSNorm
encoder.setComputePipelineState(rmsNormPSO)
encoder.setBuffer(hiddenState, offset: 0, index: 0)
encoder.setBuffer(normWeights, offset: 0, index: 1)
encoder.setBuffer(normalizedOutput, offset: 0, index: 2)
encoder.dispatchThreadgroups(normGrid, threadsPerThreadgroup: normThreadgroup)
// Layer 1: Q,K,V projection (GEMM)
encoder.setComputePipelineState(gemmPSO)
encoder.setBuffer(normalizedOutput, offset: 0, index: 0)
encoder.setBuffer(qkvWeights, offset: 0, index: 1)
encoder.setBuffer(qkvOutput, offset: 0, index: 2)
encoder.dispatchThreadgroups(gemmGrid, threadsPerThreadgroup: gemmThreadgroup)
// Layer 1: Attention
encoder.setComputePipelineState(attentionPSO)
// ... more setBuffer calls ...
encoder.dispatchThreadgroups(attnGrid, threadsPerThreadgroup: attnThreadgroup)
// ... and so on for all layers ...
encoder.endEncoding()
commandBuffer.commit()
The GPU executes these dispatches in order (within the same encoder), which gives you implicit synchronization between dependent operations. The output of RMSNorm flows into GEMM which flows into attention – no explicit barriers needed within a single encoder.
The setBytes Optimization
For small amounts of data (under 4KB), Metal provides setBytes() as an alternative to setBuffer(). Instead of allocating a GPU buffer and copying data into it, setBytes() inlines the data directly into the command buffer:
struct GEMMParams {
var M: UInt32
var N: UInt32
var K: UInt32
var alpha: Float
}
var params = GEMMParams(M: 4096, N: 4096, K: 4096, alpha: 1.0)
encoder.setBytes(&params, length: MemoryLayout<GEMMParams>.size, index: 3)
This avoids the overhead of buffer allocation for small, frequently-changing parameter structs. Akunu uses this pattern extensively – kernel parameters like matrix dimensions, strides, and scalar values are all passed via setBytes().
setBuffer path (for large data):
CPU Memory --> MTLBuffer (GPU-accessible) --> Kernel reads buffer(N)
[allocate] [copy or share] [dereference pointer]
setBytes path (for small params):
CPU Memory --> Embedded in command buffer --> Kernel reads buffer(N)
[no alloc] [inline copy, < 4KB] [same interface!]
From the kernel’s perspective, there is no difference – both appear as [[buffer(N)]] arguments. The optimization is entirely on the CPU side.
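Because setBytes copies raw bytes, the CPU-side struct must have exactly the layout the MSL struct declares. In Rust that means repr(C); a sketch (GemmParams here is a hypothetical mirror of the GEMMParams fields from the Swift example):

```rust
// The MSL side would declare: struct GEMMParams { uint M, N, K; float alpha; };
// #[repr(C)] makes Rust lay the fields out in declaration order with
// C-compatible alignment, matching what the kernel expects to read.
#[repr(C)]
struct GemmParams {
    m: u32,
    n: u32,
    k: u32,
    alpha: f32,
}

fn main() {
    // 3 x u32 + 1 x f32 = 16 bytes; no padding is needed for this layout.
    println!("{}", std::mem::size_of::<GemmParams>()); // 16
}
```

A size assertion like this is cheap insurance: if someone reorders fields on one side only, the kernel silently reads garbage parameters.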
How Akunu Wraps the Metal Pipeline
Akunu’s Rust code wraps all of this Metal machinery behind a clean interface. The MetalDevice struct owns the Metal device, command queue, library, and PSO cache. Individual operations (GEMM, softmax, RMSNorm, etc.) are methods on MetalDevice that handle all the encoding details internally.
Here is a simplified view of how a GEMM dispatch looks in akunu:
impl MetalDevice {
pub fn gemm(
&mut self,
a: &MetalBuffer, // M x K matrix
b: &MetalBuffer, // K x N matrix
c: &MetalBuffer, // M x N output matrix
m: u32,
n: u32,
k: u32,
) -> Result<()> {
// 1. Get or create the PSO (cached; the key encodes the kernel
// name plus constants, e.g. "gemm_tm32_tn64_tk4096")
let pso = self.get_or_create_pso("gemm", &[
(0, FunctionConstantValue::UInt(k)),
]);
// 2. Create a command buffer and encoder
let command_buffer = self.command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
// 3. Set pipeline and bind buffers
encoder.set_compute_pipeline_state(&pso);
encoder.set_buffer(0, Some(&a.buffer), 0);
encoder.set_buffer(1, Some(&b.buffer), 0);
encoder.set_buffer(2, Some(&c.buffer), 0);
// 4. Set parameters via setBytes
let params = GemmParams { m, n, k };
encoder.set_bytes(
3,
std::mem::size_of::<GemmParams>() as u64,
&params as *const _ as *const _,
);
// 5. Calculate dispatch geometry
let threadgroup_size = MTLSize::new(128, 1, 1); // 4 SIMD groups
let grid_size = MTLSize::new(
((n + 63) / 64) as u64, // tiles along N
((m + 31) / 32) as u64, // tiles along M
1,
);
// 6. Dispatch and finalize
encoder.dispatch_threadgroups(grid_size, threadgroup_size);
encoder.end_encoding();
command_buffer.commit();
Ok(())
}
}
Notice the pattern: get PSO, create encoder, bind resources, dispatch, finalize. Every operation follows this same template. The differences are in which kernel is used, what buffers are bound, and how the dispatch geometry is calculated.
Batching Multiple Operations
For inference, akunu batches all the operations for a single forward pass into one command buffer:
impl MetalDevice {
pub fn forward_pass(
&mut self,
model: &Model,
input: &MetalBuffer,
output: &MetalBuffer,
) -> Result<()> {
let command_buffer = self.command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
for layer in &model.layers {
// RMSNorm
self.encode_rmsnorm(&encoder, &layer.norm_weights, ...);
// QKV projection
self.encode_gemm(&encoder, &layer.qkv_weights, ...);
// Attention
self.encode_attention(&encoder, ...);
// Feed-forward network
self.encode_gemm(&encoder, &layer.ff_weights, ...);
}
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
Ok(())
}
}
This is efficient because:
- One command buffer: Only one submission to the GPU driver, reducing CPU-side overhead.
- Sequential execution: Within a single encoder, dispatches execute in order, providing implicit synchronization.
- Cached PSOs: All pipeline states are already compiled from model loading.
- No CPU-GPU round trips: The CPU records all the work and submits it in one shot. The GPU executes everything without waiting for the CPU.
Advanced: Pipeline Reflection and Validation
When you create a PSO, you can ask Metal for reflection information – metadata about the kernel’s resource usage:
var reflection: MTLAutoreleasedComputePipelineReflection?
let pso = try device.makeComputePipelineState(
function: function,
options: .argumentInfo,
reflection: &reflection
)
// Now you can inspect resource usage
if let reflection = reflection {
for arg in reflection.arguments {
print("Argument: \(arg.name), index: \(arg.index)")
print(" Type: \(arg.type), access: \(arg.access)")
print(" Size: \(arg.bufferDataSize)")
}
}
This is useful for debugging and for building generic dispatch systems that can validate buffer bindings at runtime. If you pass a buffer that is too small or bind the wrong type, the validation layer will catch it.
In debug builds, Metal’s validation layer (enabled via the MTL_DEBUG_LAYER environment variable) performs extensive checking:
# Enable Metal validation
export MTL_DEBUG_LAYER=1
# Enable shader validation (slower but catches more bugs)
export MTL_SHADER_VALIDATION=1
These are invaluable during development. They catch things like:
- Buffer overruns (reading/writing past the end of a buffer)
- Missing buffer bindings
- Threadgroup memory size mismatches
- Resource hazards (reading a buffer that is still being written)
Async Pipeline Creation
PSO creation can be done asynchronously, which is important when you need to create many pipelines at startup:
let group = DispatchGroup()
for kernelConfig in allKernelConfigs {
group.enter()
device.makeComputePipelineState(function: kernelConfig.function) { pso, error in
if let pso = pso {
cache[kernelConfig.key] = pso
}
group.leave()
}
}
group.wait() // All PSOs ready
This lets the Metal runtime parallelize compilation across CPU cores. For a model with dozens of kernel variants, this can significantly reduce startup time.
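The same fan-out/join structure can be sketched on the CPU side with standard threads, with the actual Metal compilation stubbed out (compile_kernel here is a stand-in, not a real API):

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread;

// Stand-in for PSO creation: in reality this would call into Metal.
fn compile_kernel(name: &str) -> String {
    format!("PSO({})", name)
}

// Compile every kernel variant on its own thread and join at the end --
// the same fan-out/wait structure as the DispatchGroup example above.
fn compile_all(names: &[&'static str]) -> HashMap<String, String> {
    let cache = Arc::new(Mutex::new(HashMap::new()));
    let mut handles = Vec::new();
    for &name in names {
        let cache = Arc::clone(&cache);
        handles.push(thread::spawn(move || {
            let pso = compile_kernel(name);
            cache.lock().unwrap().insert(name.to_string(), pso);
        }));
    }
    for h in handles {
        h.join().unwrap(); // equivalent of group.wait()
    }
    Arc::try_unwrap(cache).unwrap().into_inner().unwrap()
}

fn main() {
    let cache = compile_all(&["gemv_k128", "gemv_k4096", "softmax_n128"]);
    println!("{}", cache.len()); // 3
}
```

Note that in the Swift version the shared cache dictionary is mutated from Metal's completion handlers, so it needs the same kind of synchronization the Mutex provides here.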
Summary
Let us recap the key points:
- A PSO is a compiled GPU kernel – expensive to create, cheap to use.
- Compilation is multi-stage: MSL source -> AIR -> metallib -> PSO.
- Offline compilation (precompiled metallib) eliminates runtime compilation cost.
- Function constants create specialized kernel variants without code duplication – critical for ML where dimensions are fixed per model.
- PSO caching (like akunu’s pso_cache) ensures each kernel variant is compiled only once.
- Command encoding follows a simple pattern: begin, set pipeline, bind resources, dispatch, end.
- setBytes is an optimization for small parameter structs (< 4KB).
- Batch multiple operations into a single command buffer for efficiency.
In the next chapter, we will dive into the Metal Shading Language itself – the language you use to write those kernel functions.
The Metal Shading Language
The Metal Shading Language – MSL for short – is what you write GPU kernels in. If you already know C++, you are going to feel right at home, because MSL is essentially C++14 with some things removed (no exceptions, no virtual functions, no RTTI, no recursive function calls) and a bunch of GPU-specific features added. It is not a weird domain-specific language. It is not a visual graph editor. It is just C++ with extra address spaces, vector types, and thread-coordination primitives bolted on.
In this chapter, we are going to cover the MSL features that matter most for ML kernel development. We will write several complete kernels along the way, starting simple and building up to real patterns used in inference engines.
The Basics: Kernel Functions
A compute kernel in MSL is declared with the kernel keyword (or equivalently, [[kernel]]):
#include <metal_stdlib>
using namespace metal;
kernel void add_arrays(
device const float* a [[buffer(0)]],
device const float* b [[buffer(1)]],
device float* c [[buffer(2)]],
uint id [[thread_position_in_grid]]
) {
c[id] = a[id] + b[id];
}
Let us break down every piece of this.
- kernel void: This is a GPU entry point. Kernels always return void – they communicate results by writing to buffers.
- device const float* a: A pointer into device memory (more on address spaces below). The [[buffer(0)]] attribute tells Metal which buffer binding index this parameter corresponds to.
- uint id [[thread_position_in_grid]]: A built-in variable that gives this thread its unique ID in the dispatch grid.
That is the most minimal kernel possible: read two values, add them, write the result. Every thread processes one element. A thousand threads process a thousand elements. The GPU runs them all in parallel.
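The one-thread-per-element mapping has a direct CPU analogue, where the loop index plays the role of the thread ID (a reference sketch, not akunu code):

```rust
// CPU reference for add_arrays: index i plays the role of
// thread_position_in_grid, one element per "thread".
fn add_arrays(a: &[f32], b: &[f32], c: &mut [f32]) {
    for i in 0..c.len() {
        c[i] = a[i] + b[i];
    }
}

fn main() {
    let a = [1.0f32, 2.0, 3.0];
    let b = [10.0f32, 20.0, 30.0];
    let mut c = [0.0f32; 3];
    add_arrays(&a, &b, &mut c);
    println!("{:?}", c); // [11.0, 22.0, 33.0]
}
```

The GPU version simply runs every iteration of this loop at once.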
Address Spaces
In regular C++, all pointers live in one big flat address space. On the GPU, memory is divided into distinct address spaces with different performance characteristics and access rules. Every pointer in MSL must be qualified with its address space.
+------------------------------------------------------------------+
| GPU Memory Map |
+------------------------------------------------------------------+
| |
| device Large, slow-ish (DRAM / unified memory) |
| [read/write] All threads can access. This is where your |
| weight matrices and activation tensors live. |
| Hundreds of GB/s bandwidth on Apple Silicon. |
| |
+------------------------------------------------------------------+
| |
| constant Same physical memory as device, but goes |
| [read-only] through the constant cache. Best for small, |
| uniform data read by all threads (like params). |
| Limited to 64KB per argument. |
| |
+------------------------------------------------------------------+
| |
| threadgroup Fast on-chip SRAM (like CUDA shared memory). |
| [read/write] Only threads in the same threadgroup can access. |
| 32KB per threadgroup on Apple Silicon. |
| Much faster than device memory. |
| |
+------------------------------------------------------------------+
| |
| thread Registers. Private to each thread. |
| [read/write] Fastest possible access. |
| Limited by register file size. |
| |
+------------------------------------------------------------------+
Here is how they look in practice:
kernel void address_space_demo(
device float* data [[buffer(0)]], // device address space
constant Params& params [[buffer(1)]], // constant address space
threadgroup float* shared [[threadgroup(0)]], // threadgroup (rarely used this way)
uint tid [[thread_position_in_grid]]
) {
// 'thread' address space -- local variable, lives in registers
float local_val = data[tid];
// Read from constant (cached, good for broadcast reads)
float scale = params.scale;
// Write to device (main memory)
data[tid] = local_val * scale;
}
When to Use Each Address Space
| Address Space | Use For | Example |
|---|---|---|
device | Large read/write buffers | Weight matrices, activations, outputs |
constant | Small read-only data broadcast to all threads | Kernel parameters, lookup tables |
threadgroup | Shared scratchpad for threads in a group | Reduction accumulators, tiled data |
thread | Per-thread temporaries | Loop variables, accumulators |
The most common mistake newcomers make is putting everything in device. If you have a small struct of parameters (dimensions, strides, scaling factors), put it in constant – the constant cache will broadcast it to all threads efficiently.
Vector and Matrix Types
MSL provides built-in vector types that map directly to the GPU’s SIMD hardware. These are not library abstractions – they are native types that the hardware processes in a single instruction.
Scalar Types
half h = 1.0h; // 16-bit float (FP16)
float f = 1.0f; // 32-bit float (FP32)
int i = 42; // 32-bit signed integer
uint u = 42u; // 32-bit unsigned integer
short s = 42; // 16-bit signed integer
ushort us = 42u; // 16-bit unsigned integer
bool b = true; // boolean
Vector Types
Vectors come in sizes 2, 3, and 4:
half2 h2 = half2(1.0h, 2.0h); // 2 x FP16
half4 h4 = half4(1.0h, 2.0h, 3.0h, 4.0h); // 4 x FP16
float2 f2 = float2(1.0f, 2.0f); // 2 x FP32
float4 f4 = float4(1.0f, 2.0f, 3.0f, 4.0f); // 4 x FP32
uint2 u2 = uint2(10, 20); // 2 x uint32
You can swizzle components (rearrange them) using .xyzw or .rgba:
float4 v = float4(1.0, 2.0, 3.0, 4.0);
float2 xy = v.xy; // (1.0, 2.0)
float z = v.z; // 3.0
float4 rev = v.wzyx; // (4.0, 3.0, 2.0, 1.0)
Why Vectors Matter for ML
Loading a half4 from memory is a single 64-bit load instead of four separate 16-bit loads. This is huge for memory-bound kernels (which most ML kernels are):
Single loads (bad): Vectorized load (good):
load h[0] -- 16 bits load h4 -- 64 bits, one instruction
load h[1] -- 16 bits
load h[2] -- 16 bits 4x fewer load instructions
load h[3] -- 16 bits Better memory bus utilization
4 instructions total 1 instruction total
We will come back to vectorized loads in the performance chapter. For now, just remember: prefer half4 and float4 over scalar loads whenever possible.
Thread Identification
Every thread in a compute dispatch has several built-in identifiers. Understanding which one to use is crucial.
kernel void thread_id_demo(
uint tid [[thread_position_in_grid]],
uint tpg [[threads_per_grid]],
uint gid [[threadgroup_position_in_grid]],
uint tgpg [[threadgroups_per_grid]],
uint lid [[thread_position_in_threadgroup]],
uint tptg [[threads_per_threadgroup]],
uint simd_lane [[thread_index_in_simdgroup]],
uint simd_id [[simdgroup_index_in_threadgroup]]
) {
// tid = global thread ID (0, 1, 2, ..., total_threads - 1)
// gid = which threadgroup this thread belongs to
// lid = local thread ID within the threadgroup (0, 1, ..., tptg - 1)
// simd_lane = lane within the SIMD group (0 - 31)
// simd_id = which SIMD group within the threadgroup
}
Here is how they relate:
Grid (all threads)
+------------------------------------------------------------------+
| Threadgroup 0 (gid=0) | Threadgroup 1 (gid=1) | ... |
| +------------------------+ | +------------------------+ | |
| | SIMD Group 0 (simd=0) | | | SIMD Group 0 | | |
| | lanes 0..31 | | | lanes 0..31 | | |
| +------------------------+ | +------------------------+ | |
| | SIMD Group 1 (simd=1) | | | SIMD Group 1 | | |
| | lanes 0..31 | | | lanes 0..31 | | |
| +------------------------+ | +------------------------+ | |
| | SIMD Group 2 (simd=2) | | | SIMD Group 2 | | |
| | lanes 0..31 | | | lanes 0..31 | | |
| +------------------------+ | +------------------------+ | |
| | SIMD Group 3 (simd=3) | | | SIMD Group 3 | | |
| | lanes 0..31 | | | lanes 0..31 | | |
| +------------------------+ | +------------------------+ | |
+------------------------------------------------------------------+
For threadgroup size = 128:
lid = simd_id * 32 + simd_lane
tid = gid * 128 + lid
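These identities can be checked mechanically. A sketch for a threadgroup size of 128 (4 SIMD groups of 32):

```rust
// Decompose a global thread id into (gid, simd_id, simd_lane) for a
// threadgroup of 128 threads, then recompose it.
fn decompose(tid: u32) -> (u32, u32, u32) {
    let gid = tid / 128;
    let lid = tid % 128;
    (gid, lid / 32, lid % 32)
}

fn recompose(gid: u32, simd_id: u32, simd_lane: u32) -> u32 {
    let lid = simd_id * 32 + simd_lane;
    gid * 128 + lid
}

fn main() {
    let (gid, simd_id, simd_lane) = decompose(300);
    println!("{} {} {}", gid, simd_id, simd_lane); // 2 1 12
    println!("{}", recompose(gid, simd_id, simd_lane)); // 300
}
```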
Which ID Should You Use?
- thread_position_in_grid: When each thread processes one element of an array. By far the most common.
- threadgroup_position_in_grid + thread_position_in_threadgroup: When threads cooperate within a threadgroup (reductions, tiled algorithms).
- thread_index_in_simdgroup + simdgroup_index_in_threadgroup: When you are doing SIMD-level tricks (shuffles, reductions, matrix ops).
The half Type: FP16 for ML
The half type is a 16-bit floating-point number. It has:
- 1 sign bit
- 5 exponent bits (range: ~6 x 10^-8 to 65504)
- 10 mantissa bits (~3.3 decimal digits of precision)
For ML inference, half is the workhorse type. Here is why:
- 2x throughput: Apple GPUs can process twice as many half operations per clock compared to float.
- 2x memory bandwidth: A half is 2 bytes vs 4 bytes for float. Since most ML kernels are memory-bound, this effectively doubles your throughput.
- Sufficient precision: Neural network weights and activations rarely need more than 3 digits of precision. FP16 is more than enough for inference.
// FP16 dot product -- 2x the throughput of FP32
kernel void dot_product_fp16(
device const half4* a [[buffer(0)]],
device const half4* b [[buffer(1)]],
device half* result [[buffer(2)]],
uint id [[thread_position_in_grid]]
) {
// Each thread computes dot product of one half4 pair
// That is 4 multiply-adds in a single vector instruction
half4 va = a[id];
half4 vb = b[id];
result[id] = dot(va, vb); // va.x*vb.x + va.y*vb.y + ...
}
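The bit layout described above (1 sign, 5 exponent, 10 mantissa bits) can be decoded by hand. A CPU reference sketch of the conversion to FP32:

```rust
// Decode an IEEE 754 half (1 sign, 5 exponent, 10 mantissa bits)
// into an f32, following the layout described in the text.
fn half_to_f32(h: u16) -> f32 {
    let sign = if h >> 15 == 1 { -1.0f32 } else { 1.0f32 };
    let exp = (h >> 10) & 0x1F;
    let frac = (h & 0x3FF) as f32;
    match exp {
        // exponent all zeros: subnormal, no implicit leading 1
        0 => sign * (frac / 1024.0) * (2.0f32).powi(-14),
        // exponent all ones: infinity or NaN
        31 => {
            if frac == 0.0 { sign * f32::INFINITY } else { f32::NAN }
        }
        // normal number: implicit leading 1, exponent bias of 15
        _ => sign * (1.0 + frac / 1024.0) * (2.0f32).powi(exp as i32 - 15),
    }
}

fn main() {
    println!("{}", half_to_f32(0x3C00)); // 1.0
    println!("{}", half_to_f32(0xC000)); // -2.0
    println!("{}", half_to_f32(0x7BFF)); // 65504 (largest finite half)
}
```

The largest finite value, 65504, is exactly the range limit quoted above.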
Mixed Precision
Sometimes you want FP16 for storage and memory transfer but FP32 for accumulation to avoid precision loss:
kernel void gemv_mixed_precision(
device const half4* W [[buffer(0)]], // Weights in FP16
device const half4* x [[buffer(1)]], // Input in FP16
device float* y [[buffer(2)]], // Output in FP32
constant uint& K [[buffer(3)]],
uint row [[thread_position_in_grid]]
) {
float sum = 0.0f; // Accumulate in FP32 for precision
uint k4 = K / 4; // Process 4 elements at a time via half4
for (uint i = 0; i < k4; i++) {
half4 w = W[row * k4 + i];
half4 v = x[i];
// Multiply in FP16, accumulate in FP32
sum += float(w.x) * float(v.x);
sum += float(w.y) * float(v.y);
sum += float(w.z) * float(v.z);
sum += float(w.w) * float(v.w);
}
y[row] = sum;
}
This is a common pattern in ML inference: load data as FP16 (saving bandwidth), accumulate in FP32 (preserving precision), and write the result back as whatever the downstream consumer needs.
SIMD Group Intrinsics
Now we get to the really fun stuff. SIMD group intrinsics let threads within a 32-thread SIMD group communicate directly through the register file – no shared memory, no barriers, no synchronization overhead.
On Apple GPUs, a SIMD group (also called a warp on NVIDIA or a wavefront on AMD) is always 32 threads. These 32 threads execute in lockstep – the same instruction, at the same time, on adjacent data. SIMD intrinsics let these threads share values with each other.
simd_sum: Reduction Within a SIMD Group
// Sum 32 values (one from each lane) in a single operation
float lane_value = data[tid];
float total = simd_sum(lane_value);
// Now ALL 32 lanes have the same total
Without simd_sum, you would need a tree reduction with shared memory and barriers. With it, the hardware does it for you in a few cycles.
Before simd_sum: After simd_sum:
Lane 0: 3.0 Lane 0: sum of all 32
Lane 1: 1.0 Lane 1: sum of all 32
Lane 2: 4.0 Lane 2: sum of all 32
... ...
Lane 31: 2.0 Lane 31: sum of all 32
simd_max / simd_min: Finding Extremes
float lane_value = data[tid];
float max_val = simd_max(lane_value); // Max across all 32 lanes
float min_val = simd_min(lane_value); // Min across all 32 lanes
This is essential for softmax, which needs the maximum value before computing exponentials.
simd_shuffle: Direct Lane-to-Lane Communication
// Read the value from a specific lane
float val = simd_shuffle(my_value, target_lane);
simd_shuffle lets any lane read any other lane’s value. It is like having a crossbar switch between all 32 registers.
simd_shuffle_down: Shift Values Down
float val = simd_shuffle_down(my_value, delta);
// Lane i gets the value from lane (i + delta)
Before simd_shuffle_down(val, 1): After:
Lane 0: A Lane 0: B (got lane 1's value)
Lane 1: B Lane 1: C (got lane 2's value)
Lane 2: C Lane 2: D (got lane 3's value)
Lane 3: D Lane 3: E (got lane 4's value)
... ...
This is the building block for prefix sums and sequential reductions.
simd_shuffle_xor: Butterfly Pattern
float val = simd_shuffle_xor(my_value, mask);
// Lane i gets the value from lane (i XOR mask)
simd_shuffle_xor(val, 1): // Swap adjacent pairs
Lane 0 <-> Lane 1
Lane 2 <-> Lane 3
Lane 4 <-> Lane 5
...
simd_shuffle_xor(val, 2): // Swap pairs of pairs
Lane 0 <-> Lane 2
Lane 1 <-> Lane 3
Lane 4 <-> Lane 6
...
This butterfly pattern is used in parallel reductions and FFTs.
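All three shuffle variants are easy to model on the CPU as permutations of a 32-element array. A reference sketch (note: real Metal behavior for out-of-range source lanes in simd_shuffle_down is undefined; this model just keeps the lane's own value):

```rust
// Model a SIMD group as 32 lanes and the shuffle intrinsics as
// permutations: lane i reads from a computed source lane.
const LANES: usize = 32;

fn simd_shuffle(v: &[f32; LANES], lane: usize) -> [f32; LANES] {
    [v[lane]; LANES] // every lane reads the same source lane (broadcast)
}

fn simd_shuffle_down(v: &[f32; LANES], delta: usize) -> [f32; LANES] {
    let mut out = *v;
    for i in 0..LANES {
        if i + delta < LANES {
            out[i] = v[i + delta]; // lane i reads lane i + delta
        }
    }
    out
}

fn simd_shuffle_xor(v: &[f32; LANES], mask: usize) -> [f32; LANES] {
    let mut out = *v;
    for i in 0..LANES {
        out[i] = v[i ^ mask]; // butterfly: lane i reads lane i XOR mask
    }
    out
}

fn main() {
    let mut v = [0.0f32; LANES];
    for i in 0..LANES {
        v[i] = i as f32;
    }
    println!("{}", simd_shuffle(&v, 7)[0]); // 7
    println!("{}", simd_shuffle_down(&v, 1)[0]); // 1
    println!("{}", simd_shuffle_xor(&v, 1)[0]); // 1 (lanes 0 and 1 swap)
}
```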
Putting It Together: SIMD-Level Softmax
Here is a practical example – computing softmax over 32 elements using only SIMD intrinsics (no shared memory):
kernel void softmax_simd(
device const float* input [[buffer(0)]],
device float* output [[buffer(1)]],
uint lane [[thread_index_in_simdgroup]],
uint group [[simdgroup_index_in_threadgroup]],
uint gid [[threadgroup_position_in_grid]],
uint tptg [[threads_per_threadgroup]]
) {
// Each SIMD group processes one row of 32 elements
uint row = gid * (tptg / 32) + group;
float val = input[row * 32 + lane];
// Step 1: Find max across all 32 lanes
float max_val = simd_max(val);
// Step 2: Subtract max and exponentiate (for numerical stability)
float exp_val = exp(val - max_val);
// Step 3: Sum all exponentials
float sum_exp = simd_sum(exp_val);
// Step 4: Normalize
output[row * 32 + lane] = exp_val / sum_exp;
}
That is a complete softmax in about 10 lines, with no shared memory and no barriers. Each SIMD group handles one 32-element row entirely within its register file. You cannot write this more concisely or more efficiently.
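A sequential CPU reference makes the math easy to verify against the kernel's four steps:

```rust
// CPU reference for the numerically stable softmax the kernel computes:
// subtract the max before exponentiating, then normalize by the sum.
fn softmax(row: &[f32]) -> Vec<f32> {
    let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = row.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

fn main() {
    let out = softmax(&[1.0, 2.0, 3.0]);
    let total: f32 = out.iter().sum();
    println!("{}", (total - 1.0).abs() < 1e-6); // probabilities sum to 1
    println!("{}", out[2] > out[1] && out[1] > out[0]); // order preserved
}
```

Subtracting the max changes nothing mathematically (the factor cancels in the division) but keeps exp() from overflowing on large inputs.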
Threadgroup Memory and Barriers
When you need cooperation between threads in different SIMD groups (but within the same threadgroup), you use threadgroup memory and barriers.
Threadgroup memory is a fast on-chip SRAM shared by all threads in a threadgroup. Think of it as a programmer-managed cache.
kernel void reduction_with_shared(
device const float* input [[buffer(0)]],
device float* output [[buffer(1)]],
uint tid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint gid [[threadgroup_position_in_grid]],
uint tptg [[threads_per_threadgroup]]
) {
// Declare threadgroup memory
threadgroup float shared[256]; // One slot per thread
// Each thread loads one element into shared memory
shared[lid] = input[tid];
// BARRIER: Wait for all threads to finish writing
threadgroup_barrier(mem_flags::mem_threadgroup);
// Tree reduction in shared memory
for (uint stride = tptg / 2; stride > 0; stride /= 2) {
if (lid < stride) {
shared[lid] += shared[lid + stride];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// Thread 0 writes the result
if (lid == 0) {
output[gid] = shared[0];
}
}
The threadgroup_barrier(mem_flags::mem_threadgroup) call is critical. It does two things:
- Execution barrier: All threads in the threadgroup must reach this point before any can proceed.
- Memory fence: All writes to threadgroup memory by threads before the barrier are visible to all threads after the barrier.
Without the barrier, you would read stale data – thread 5 might read shared[10] before thread 10 has written to it. Race conditions on the GPU are just as dangerous as on the CPU, and they are much harder to debug.
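The tree reduction itself can be simulated sequentially to see what each stride step does (a sketch; it assumes a power-of-two length, as the kernel assumes a power-of-two threadgroup size):

```rust
// Sequential simulation of the tree reduction: at each step, the first
// `stride` slots accumulate the next `stride` slots, halving the active
// region until slot 0 holds the total.
fn tree_reduce(shared: &mut Vec<f32>) -> f32 {
    let mut stride = shared.len() / 2;
    while stride > 0 {
        for lid in 0..stride {
            shared[lid] += shared[lid + stride]; // what thread `lid` does
        }
        // (on the GPU, the threadgroup_barrier goes here)
        stride /= 2;
    }
    shared[0]
}

fn main() {
    let mut data: Vec<f32> = (1..=8).map(|x| x as f32).collect();
    println!("{}", tree_reduce(&mut data)); // 36
}
```

On the GPU, each `for lid` loop body runs in parallel across threads, which is exactly why the barrier between stride steps is mandatory.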
When to Use Threadgroup Memory vs SIMD Intrinsics
+---------------------------------------+---------------------------------------+
| SIMD Intrinsics | Threadgroup Memory |
+---------------------------------------+---------------------------------------+
| 32 threads only (one SIMD group) | Up to 1024 threads (one threadgroup) |
| No barriers needed (lockstep) | Barriers required |
| Fastest (register-to-register) | Fast (on-chip SRAM) |
| Simple patterns (reduce, broadcast) | Arbitrary access patterns |
| Use FIRST if possible | Use when SIMD is not enough |
+---------------------------------------+---------------------------------------+
The rule of thumb: if you can do it with SIMD intrinsics alone, do that. Fall back to threadgroup memory only when you need cross-SIMD-group communication.
Example: RMSNorm Kernel
Let us write a real kernel used in transformer inference. RMSNorm (Root Mean Square Normalization) is used in LLaMA and many modern models:
RMSNorm(x) = x * weight / sqrt(mean(x^2) + epsilon)
Here is the full kernel:
#include <metal_stdlib>
using namespace metal;
constant uint N [[function_constant(0)]]; // hidden dimension
kernel void rmsnorm(
device const half* input [[buffer(0)]],
device const half* weight [[buffer(1)]],
device half* output [[buffer(2)]],
constant float& epsilon [[buffer(3)]],
uint lid [[thread_position_in_threadgroup]],
uint simd_lane [[thread_index_in_simdgroup]],
uint simd_id [[simdgroup_index_in_threadgroup]]
) {
// Each threadgroup processes one row of N elements
// (input/output are assumed to point at that row)
// Step 1: Compute sum of squares (each thread handles N/threads elements)
uint threads = 256; // threadgroup size; must match the dispatch
float sum_sq = 0.0f;
for (uint i = lid; i < N; i += threads) {
float val = float(input[i]);
sum_sq += val * val;
}
// Step 2: Reduce within SIMD group
sum_sq = simd_sum(sum_sq);
// Step 3: Reduce across SIMD groups using threadgroup memory
threadgroup float simd_sums[8]; // max 8 SIMD groups (256/32)
if (simd_lane == 0) {
simd_sums[simd_id] = sum_sq;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// First SIMD group reduces the partial sums
float total = 0.0f;
if (simd_id == 0 && simd_lane < (threads / 32)) {
total = simd_sum(simd_sums[simd_lane]);
}
// Broadcast the RMS scale factor
threadgroup float rms_scale;
if (lid == 0) {
rms_scale = rsqrt(total / float(N) + epsilon);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Step 4: Apply normalization
float scale = rms_scale;
for (uint i = lid; i < N; i += threads) {
float val = float(input[i]);
output[i] = half(val * scale * float(weight[i]));
}
}
This kernel demonstrates several important MSL patterns:
- Function constants (N) for compile-time specialization.
- Mixed precision: Load as half, compute in float, store as half.
- Two-stage reduction: First within SIMD groups (simd_sum), then across SIMD groups via threadgroup memory.
- Stride loop pattern: Each thread processes multiple elements (for (uint i = lid; i < N; i += threads)).
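When debugging a kernel like this, it helps to have a scalar reference to diff against. Here is a minimal CPU reference for the same RMSNorm formula in C++ (a hypothetical helper of my own, not part of akunu's API); it accumulates in float just like the kernel, so outputs should match the GPU result to within half-precision tolerance:

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Scalar reference for RMSNorm(x) = x * weight / sqrt(mean(x^2) + epsilon).
std::vector<float> rmsnorm_ref(const std::vector<float>& x,
                               const std::vector<float>& weight,
                               float epsilon) {
    float sum_sq = 0.0f;
    for (float v : x) sum_sq += v * v;           // Step 1: sum of squares
    float scale = 1.0f / std::sqrt(sum_sq / x.size() + epsilon);
    std::vector<float> out(x.size());
    for (size_t i = 0; i < x.size(); ++i)        // Step 4: apply normalization
        out[i] = x[i] * scale * weight[i];
    return out;
}
```

Running both over random rows and comparing element-wise is the quickest way to catch a missing barrier or an off-by-one in the reduction.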
Example: Vectorized Element-Wise Operations
For simple element-wise operations (add, multiply, activation functions), vectorized loads are the key optimization:
kernel void silu_activation(
device const half4* input [[buffer(0)]],
device half4* output [[buffer(1)]],
uint id [[thread_position_in_grid]]
) {
half4 x = input[id];
// SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))
half4 sigmoid_x = half4(1.0h) / (half4(1.0h) + exp(-x));
output[id] = x * sigmoid_x;
}
By processing half4 (4 elements per thread), each thread does useful work on 8 bytes of data per load. With a threadgroup size of 256, one threadgroup processes 1024 elements. This keeps the memory system saturated.
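The math itself is trivial; a scalar C++ reference (my own naming, not akunu's) makes the identity explicit and gives you something to validate the half4 kernel against:

```cpp
#include <cassert>
#include <cmath>

// Scalar reference for SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x)).
// The kernel computes the same expression four lanes at a time with half4.
float silu_ref(float x) {
    return x / (1.0f + std::exp(-x));
}
```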
Example: Fused RoPE (Rotary Position Embeddings)
RoPE is used in virtually every modern LLM. It applies a rotation to pairs of dimensions based on their position in the sequence:
constant uint HEAD_DIM [[function_constant(0)]];
constant float ROPE_BASE [[function_constant(1)]];
kernel void rope(
device half* q [[buffer(0)]], // query tensor
device half* k [[buffer(1)]], // key tensor
constant uint& seq_pos [[buffer(2)]], // position in sequence
uint tid [[thread_position_in_grid]]
) {
// Each thread processes one pair of dimensions
uint pair_idx = tid; // Which (cos, sin) pair
uint d = pair_idx % (HEAD_DIM / 2);
uint head = pair_idx / (HEAD_DIM / 2);
// Compute rotation angle: theta = pos * base^(-2d/dim)
float freq = 1.0f / pow(ROPE_BASE, float(2 * d) / float(HEAD_DIM));
float angle = float(seq_pos) * freq;
float cos_angle = cos(angle);
float sin_angle = sin(angle);
// Apply rotation to Q
uint base_idx = head * HEAD_DIM;
float q0 = float(q[base_idx + d]);
float q1 = float(q[base_idx + d + HEAD_DIM / 2]);
q[base_idx + d] = half(q0 * cos_angle - q1 * sin_angle);
q[base_idx + d + HEAD_DIM/2] = half(q0 * sin_angle + q1 * cos_angle);
// Apply same rotation to K
float k0 = float(k[base_idx + d]);
float k1 = float(k[base_idx + d + HEAD_DIM / 2]);
k[base_idx + d] = half(k0 * cos_angle - k1 * sin_angle);
k[base_idx + d + HEAD_DIM/2] = half(k0 * sin_angle + k1 * cos_angle);
}
Notice the function constants: HEAD_DIM and ROPE_BASE. These are fixed per model (e.g., 128 and 10000.0 for LLaMA), so the compiler can optimize the frequency computation and potentially pre-compute parts of it.
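The per-pair rotation is easy to check on the CPU. Here is a reference in C++ (a sketch with my own names, not akunu's API) that mirrors the kernel's angle computation for one (d, d + HEAD_DIM/2) pair; a useful property to assert is that a rotation never changes the pair's norm:

```cpp
#include <cassert>
#include <cmath>
#include <utility>

// Reference RoPE rotation for one dimension pair, mirroring the kernel:
// theta = pos * base^(-2d/dim), then a standard 2D rotation.
std::pair<float, float> rope_pair(float x0, float x1, int pos, int d,
                                  int head_dim, float base) {
    float freq = 1.0f / std::pow(base, float(2 * d) / float(head_dim));
    float angle = float(pos) * freq;
    float c = std::cos(angle), s = std::sin(angle);
    return {x0 * c - x1 * s, x0 * s + x1 * c};
}
```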
Built-in Math Functions
MSL provides a comprehensive set of math functions. Here are the ones you will use most often in ML kernels:
// Exponential and logarithm
float e = exp(x); // e^x
float l = log(x); // natural log
half e_fast = exp(h); // FP16 exp -- 2x throughput
// Trigonometric
float s = sin(x);
float c = cos(x);
// Power and roots
float p = pow(base, exponent);
float s = sqrt(x);
float r = rsqrt(x); // 1/sqrt(x) -- faster than 1.0/sqrt(x)
// Min, max, clamp
float m = min(a, b);
float M = max(a, b);
float c = clamp(x, lo, hi); // max(lo, min(x, hi))
// Absolute value
float a = abs(x);
// Fused multiply-add (one rounding instead of two)
float f = fma(a, b, c); // a*b + c
// Dot product of vectors
float d = dot(v1, v2); // v1.x*v2.x + v1.y*v2.y + ...
// Type conversion (explicit)
half h = half(f); // float -> half (may lose precision)
float f = float(h); // half -> float (lossless)
The rsqrt function deserves special mention. You will see it everywhere in ML kernels (normalization layers, attention scaling). It computes 1/sqrt(x) in a single instruction, which is faster than computing sqrt(x) and then dividing.
Structs and Parameter Passing
For kernels with many parameters, pack them into a struct:
struct GEMMParams {
uint M; // rows of A, rows of C
uint N; // cols of B, cols of C
uint K; // cols of A, rows of B
float alpha; // scaling factor
};
kernel void gemm(
device const half* A [[buffer(0)]],
device const half* B [[buffer(1)]],
device half* C [[buffer(2)]],
constant GEMMParams& params [[buffer(3)]],
uint2 gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]]
) {
uint M = params.M;
uint N = params.N;
uint K = params.K;
// ...
}
On the CPU side, this struct gets passed via setBytes() (for structs under 4KB) or via a MTLBuffer. The struct layout must match between MSL and the host language. This means you need to be careful about alignment – MSL follows standard C++ alignment rules, and your Rust/Swift struct must match.
// Rust side -- must match the MSL struct layout
#[repr(C)]
struct GEMMParams {
m: u32,
n: u32,
k: u32,
alpha: f32,
}
The #[repr(C)] attribute ensures Rust uses C-compatible layout, which matches MSL’s struct layout.
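A cheap way to lock the contract down on a C++ host is to mirror the struct and assert its layout at compile time. This is a sketch (the asserts are mine, not something akunu necessarily does), but the numbers follow directly from standard C/C++ layout rules: four 4-byte fields, no padding, 16 bytes total:

```cpp
#include <cstddef>
#include <cstdint>

// Host-side mirror of the MSL GEMMParams struct.
struct GEMMParams {
    uint32_t M;     // rows of A, rows of C
    uint32_t N;     // cols of B, cols of C
    uint32_t K;     // cols of A, rows of B
    float    alpha; // scaling factor
};

// Layout contract: fields packed in declaration order, no hidden padding.
static_assert(sizeof(GEMMParams) == 16, "unexpected padding");
static_assert(offsetof(GEMMParams, K) == 8, "fields not packed in order");
static_assert(offsetof(GEMMParams, alpha) == 12, "fields not packed in order");
```

If a later edit reorders fields or mixes in a type with stricter alignment, the build fails instead of the kernel silently reading garbage.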
Conditional Compilation with Function Constants
We touched on function constants in the previous chapter. Here is a deeper look at how they enable conditional compilation in MSL:
constant bool HAS_BIAS [[function_constant(0)]];
constant bool HAS_RESIDUAL [[function_constant(1)]];
constant uint BLOCK_SIZE [[function_constant(2)]];
kernel void linear_layer(
device const half* input [[buffer(0)]],
device const half* weights [[buffer(1)]],
device half* output [[buffer(2)]],
device const half* bias [[buffer(3)]],
device const half* residual [[buffer(4)]],
constant uint& N [[buffer(5)]],
uint tid [[thread_position_in_grid]]
) {
half result = /* ... compute matmul ... */ ;
// These branches are resolved at compile time!
// No runtime branch penalty.
if (HAS_BIAS) {
result += bias[tid];
}
if (HAS_RESIDUAL) {
result += residual[tid];
}
output[tid] = result;
}
When HAS_BIAS is false, the compiler completely removes the bias addition and the bias buffer binding. The generated code is identical to a kernel that never had bias support in the first place. This lets you write one kernel source that generates many specialized variants.
Atomic Operations
Sometimes multiple threads need to update the same memory location. MSL provides atomic operations for this:
kernel void histogram(
device const uint* data [[buffer(0)]],
device atomic_uint* bins [[buffer(1)]],
uint tid [[thread_position_in_grid]]
) {
uint value = data[tid];
uint bin = value % 256;
atomic_fetch_add_explicit(&bins[bin], 1, memory_order_relaxed);
}
Atomics are slow – they serialize access. Avoid them in hot paths. If you find yourself using atomics in a performance-critical kernel, there is almost certainly a better algorithm (like per-threadgroup histograms followed by a merge).
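The privatized alternative is worth spelling out. Here is a CPU sketch of the pattern (my own code, illustrating the idea rather than an akunu kernel): each "threadgroup" fills a private set of bins for its chunk of the input, and the partials are merged once at the end. On the GPU the merge is one add per (group, bin) instead of one contended atomic per element:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Privatized histogram: per-group local bins, merged at the end.
std::vector<uint32_t> histogram_privatized(const std::vector<uint32_t>& data,
                                           size_t num_groups) {
    const size_t kBins = 256;
    std::vector<std::vector<uint32_t>> local(
        num_groups, std::vector<uint32_t>(kBins, 0));
    // Each "threadgroup" g handles a contiguous chunk of the input.
    size_t chunk = (data.size() + num_groups - 1) / num_groups;
    for (size_t g = 0; g < num_groups; ++g) {
        size_t begin = g * chunk;
        size_t end = std::min(begin + chunk, data.size());
        for (size_t i = begin; i < end; ++i)
            local[g][data[i] % kBins]++;
    }
    // Merge step: one add per (group, bin), not per element.
    std::vector<uint32_t> bins(kBins, 0);
    for (size_t g = 0; g < num_groups; ++g)
        for (size_t b = 0; b < kBins; ++b)
            bins[b] += local[g][b];
    return bins;
}
```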
Putting It All Together: A Complete Softmax Kernel
Let us write a production-quality softmax kernel that handles arbitrary row sizes using all the MSL features we have covered:
#include <metal_stdlib>
using namespace metal;
constant uint COLS [[function_constant(0)]];
kernel void softmax(
device const half* input [[buffer(0)]],
device half* output [[buffer(1)]],
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint simd_lane [[thread_index_in_simdgroup]],
uint simd_id [[simdgroup_index_in_threadgroup]]
) {
// Each threadgroup processes one row
uint row = gid;
uint threads = 256;
// Threadgroup memory for cross-SIMD-group reduction
threadgroup float simd_max_vals[8];
threadgroup float simd_sum_vals[8];
// ---- Pass 1: Find max ----
float local_max = -INFINITY;
for (uint i = lid; i < COLS; i += threads) {
local_max = max(local_max, float(input[row * COLS + i]));
}
// Reduce within SIMD group
local_max = simd_max(local_max);
// Reduce across SIMD groups
if (simd_lane == 0) {
simd_max_vals[simd_id] = local_max;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float row_max;
if (simd_id == 0) {
float v = (simd_lane < threads / 32) ?
simd_max_vals[simd_lane] : -INFINITY;
row_max = simd_max(v);
}
// Broadcast max to all threads
threadgroup float shared_max;
if (lid == 0) shared_max = row_max;
threadgroup_barrier(mem_flags::mem_threadgroup);
row_max = shared_max;
// ---- Pass 2: Compute exp and sum ----
float local_sum = 0.0f;
for (uint i = lid; i < COLS; i += threads) {
float val = exp(float(input[row * COLS + i]) - row_max);
local_sum += val;
}
// Reduce sum (same two-stage pattern)
local_sum = simd_sum(local_sum);
if (simd_lane == 0) {
simd_sum_vals[simd_id] = local_sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float row_sum;
if (simd_id == 0) {
float v = (simd_lane < threads / 32) ?
simd_sum_vals[simd_lane] : 0.0f;
row_sum = simd_sum(v);
}
threadgroup float shared_sum;
if (lid == 0) shared_sum = row_sum;
threadgroup_barrier(mem_flags::mem_threadgroup);
row_sum = shared_sum;
// ---- Pass 3: Normalize and write output ----
float inv_sum = 1.0f / row_sum;
for (uint i = lid; i < COLS; i += threads) {
float val = exp(float(input[row * COLS + i]) - row_max);
output[row * COLS + i] = half(val * inv_sum);
}
}
This kernel has three passes over the data (find max, compute exponentials and sum, normalize). Each pass uses the two-stage reduction pattern (SIMD-level reduction followed by threadgroup-level reduction). The function constant COLS lets the compiler optimize the loop bounds.
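For validation, the same three logical passes collapse to a few lines on the CPU. This reference (a hypothetical helper of mine, not akunu's) uses the identical max-subtraction trick, so its output should match the kernel to within half precision:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Numerically stable softmax over one row: max, exp-and-sum, normalize.
std::vector<float> softmax_ref(const std::vector<float>& row) {
    float m = -INFINITY;
    for (float v : row) m = std::max(m, v);       // Pass 1: find max
    float sum = 0.0f;
    std::vector<float> out(row.size());
    for (size_t i = 0; i < row.size(); ++i) {     // Pass 2: exp and sum
        out[i] = std::exp(row[i] - m);
        sum += out[i];
    }
    for (float& v : out) v /= sum;                // Pass 3: normalize
    return out;
}
```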
Summary
MSL is a practical language for writing GPU kernels. Here are the key takeaways:
- MSL is C++14 with GPU extensions. No surprises if you know C++.
- Four address spaces: device (main memory), constant (cached read-only), threadgroup (fast shared SRAM), thread (registers).
- Vector types (half4, float4) are essential for memory throughput.
- The half type gives you 2x throughput and 2x bandwidth – use it everywhere in ML kernels.
- SIMD intrinsics (simd_sum, simd_max, simd_shuffle, etc.) enable fast intra-SIMD-group communication without shared memory or barriers.
- Threadgroup memory + barriers enable cross-SIMD-group communication within a threadgroup.
- Function constants create specialized kernel variants at compile time.
- The two-stage reduction (SIMD reduce, then threadgroup reduce) is the most common pattern in ML kernels.
In the next chapter, we will zoom out from individual threads to look at how threadgroups and SIMD groups are organized and dispatched.
Threadgroups, SIMD Groups, and Dispatch
So far we have talked about individual threads and what they can do. But a single GPU thread is weak – it is only useful because there are thousands of them running simultaneously. The way you organize those thousands of threads matters enormously for performance. Get the geometry wrong and you leave half the GPU idle. Get it right and your kernel saturates every execution unit.
This chapter is about the organizational hierarchy of GPU threads in Metal, and how to dispatch them effectively.
The Thread Hierarchy
Metal organizes threads into a three-level hierarchy:
Grid (all threads in the dispatch)
+------------------------------------------------------------------+
| |
| Threadgroup 0 Threadgroup 1 Threadgroup 2 |
| +-----------------+ +-----------------+ +-----------------+ |
| | SIMD Group 0 | | SIMD Group 0 | | SIMD Group 0 | |
| | [32 threads] | | [32 threads] | | [32 threads] | |
| +-----------------+ +-----------------+ +-----------------+ |
| | SIMD Group 1 | | SIMD Group 1 | | SIMD Group 1 | |
| | [32 threads] | | [32 threads] | | [32 threads] | |
| +-----------------+ +-----------------+ +-----------------+ |
| | SIMD Group 2 | | SIMD Group 2 | | SIMD Group 2 | |
| | [32 threads] | | [32 threads] | | [32 threads] | |
| +-----------------+ +-----------------+ +-----------------+ |
| | SIMD Group 3 | | SIMD Group 3 | | SIMD Group 3 | |
| | [32 threads] | | [32 threads] | | [32 threads] | |
| +-----------------+ +-----------------+ +-----------------+ |
| |
+------------------------------------------------------------------+
Let us define each level.
SIMD Group (32 threads)
A SIMD group is 32 threads that execute in lockstep – the same instruction, at the same clock cycle. This is the fundamental unit of execution on Apple GPUs. You do not choose the SIMD group size; it is always 32.
Threads in a SIMD group can communicate through SIMD intrinsics (simd_sum, simd_shuffle, etc.) with zero overhead – no barriers, no shared memory, just direct register reads. This is the fastest form of inter-thread communication available.
Threadgroup (up to 1024 threads)
A threadgroup is a collection of SIMD groups that are co-scheduled on the same GPU compute unit. All threads in a threadgroup:
- Share access to threadgroup memory (fast on-chip SRAM)
- Can synchronize with threadgroup_barrier()
- Are guaranteed to execute concurrently (they are all resident on the same compute unit)
A threadgroup can contain up to 1024 threads, which means up to 32 SIMD groups. In practice, common sizes are 128 (4 SIMD groups), 256 (8 SIMD groups), or sometimes 512 or 1024.
Grid (the entire dispatch)
The grid is the collection of all threadgroups in a dispatch. Threadgroups in the grid are independent – there is no way for threads in different threadgroups to communicate during a single dispatch (except through device memory, but there are no cross-threadgroup barriers).
Two Ways to Dispatch: Threadgroups vs Threads
Metal gives you two dispatch methods, and the difference is subtle but important.
dispatchThreadgroups
You specify how many threadgroups to launch and how many threads per threadgroup:
// Launch a grid of threadgroups
let threadgroupsPerGrid = MTLSize(width: 16, height: 8, depth: 1)
let threadsPerThreadgroup = MTLSize(width: 256, height: 1, depth: 1)
encoder.dispatchThreadgroups(threadgroupsPerGrid,
threadsPerThreadgroup: threadsPerThreadgroup)
// Total threads = 16 * 8 * 256 = 32,768
With this method, you are responsible for computing the grid dimensions. If your problem size is not a perfect multiple of the threadgroup size, you need to handle the boundary condition in the kernel.
dispatchThreads
You specify the total number of threads you want, and Metal figures out the threadgroup count:
// Launch exactly this many threads
let threadsPerGrid = MTLSize(width: 4000, height: 3000, depth: 1)
let threadsPerThreadgroup = MTLSize(width: 16, height: 16, depth: 1)
encoder.dispatchThreads(threadsPerGrid,
threadsPerThreadgroup: threadsPerThreadgroup)
// Metal launches enough threadgroups to cover 4000 x 3000,
// and threads outside that range are automatically disabled
With dispatchThreads, Metal handles the boundary for you. If you request 4000 x 3000 threads with 16x16 threadgroups, Metal launches 250 x 188 threadgroups (250 * 16 = 4000, 188 * 16 = 3008), and the extra 8 rows of threads simply do not execute.
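The arithmetic Metal does here is just ceiling division per dimension. A one-line helper (mine, for illustration) makes it concrete:

```cpp
#include <cassert>
#include <cstdint>

// Threadgroups launched by dispatchThreads for one dimension:
// ceil(total / tg_size). Out-of-range threads in the last group are masked.
uint32_t threadgroups_for(uint32_t total_threads, uint32_t tg_size) {
    return (total_threads + tg_size - 1) / tg_size;
}
```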
Which Should You Use?
For ML kernels, dispatchThreadgroups is almost always the right choice. Here is why:
- Explicit control: You know exactly how many threadgroups you are launching, which makes reasoning about resource usage straightforward.
- Kernel design: ML kernels are typically designed around threadgroup-level cooperation (tiled GEMM, reductions). The kernel code already assumes a specific threadgroup structure.
- No wasted threads: With dispatchThreadgroups, you design your kernel so that every thread does useful work. There are no “out-of-bounds” threads.
dispatchThreads is nice for simple element-wise kernels where you just want one thread per element and do not care about threadgroup structure. But for anything involving shared memory, reductions, or tiling, use dispatchThreadgroups.
Choosing Threadgroup Sizes
The threadgroup size is one of the most important performance decisions you make. Here are the rules:
Rule 1: Always Use Multiples of 32
Since SIMD groups are 32 threads, your threadgroup size should always be a multiple of 32. If you use, say, 100 threads per threadgroup, Metal will still allocate 4 SIMD groups (128 threads worth of resources), but 28 threads in the last SIMD group will be idle. That is 22% waste.
Threadgroup size = 100 (bad):
+----------------+ +----------------+ +----------------+ +----------------+
| SIMD Group 0 | | SIMD Group 1 | | SIMD Group 2 | | SIMD Group 3 |
| 32 active | | 32 active | | 32 active | | 4 active |
| 0 idle | | 0 idle | | 0 idle | | 28 IDLE |
+----------------+ +----------------+ +----------------+ +----------------+
^^^^^^^^^^^^
22% waste!
Threadgroup size = 128 (good):
+----------------+ +----------------+ +----------------+ +----------------+
| SIMD Group 0 | | SIMD Group 1 | | SIMD Group 2 | | SIMD Group 3 |
| 32 active | | 32 active | | 32 active | | 32 active |
| 0 idle | | 0 idle | | 0 idle | | 0 idle |
+----------------+ +----------------+ +----------------+ +----------------+
0% waste
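The waste figure follows from rounding the threadgroup size up to the 32-wide SIMD width. A small helper (my own, not an API) computes the idle fraction for any size:

```cpp
#include <cassert>
#include <cstdint>

// Fraction of allocated SIMD lanes left idle when the threadgroup size
// is not a multiple of the 32-wide SIMD group.
double idle_fraction(uint32_t tg_size) {
    const uint32_t simd = 32;
    uint32_t allocated = ((tg_size + simd - 1) / simd) * simd; // round up
    return double(allocated - tg_size) / double(allocated);
}
```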
Rule 2: Common Sizes and When to Use Them
| Threadgroup Size | SIMD Groups | Use Case |
|---|---|---|
| 32 | 1 | Simple per-element ops, tiny reductions |
| 64 | 2 | Light cooperation needed |
| 128 | 4 | Default for GEMM (4 SIMD groups = 4 tiles) |
| 256 | 8 | Default for reductions, normalization |
| 512 | 16 | Large reductions, high occupancy kernels |
| 1024 | 32 | Maximum, rarely needed |
Rule 3: Balance Occupancy and Resources
Here is the tension: larger threadgroups use more resources (registers, threadgroup memory), which means fewer threadgroups can be resident on a compute unit at the same time. Smaller threadgroups use fewer resources but may not have enough threads for efficient cooperation.
Compute Unit (simplified)
+----------------------------------------------------------+
| Register File: 64KB Threadgroup Memory: 32KB |
| |
| Option A: 4 threadgroups of 128 threads (512 threads) |
| Each TG gets: 16KB registers, 8KB shared memory |
| Good latency hiding (many threads to switch between) |
| |
| Option B: 2 threadgroups of 256 threads (512 threads) |
| Each TG gets: 32KB registers, 16KB shared memory |
| Same total threads, but more resources per TG |
| |
| Option C: 1 threadgroup of 1024 threads (1024 threads) |
| Gets all: 64KB registers, 32KB shared memory |
| Lots of threads, but if this TG stalls, nothing else |
| can run on this compute unit |
+----------------------------------------------------------+
The sweet spot for ML kernels is usually 128-256 threads. This gives you enough SIMD groups for cooperation (4-8) while leaving room for multiple threadgroups per compute unit (good for latency hiding).
Rule 4: Consult maxTotalThreadsPerThreadgroup
Every PSO has a maxTotalThreadsPerThreadgroup property that tells you the maximum threadgroup size the kernel supports, based on its resource usage:
let pso = try device.makeComputePipelineState(function: function)
print(pso.maxTotalThreadsPerThreadgroup) // e.g., 1024, or maybe 512
If your kernel uses a lot of registers or threadgroup memory, this number may be less than 1024. Always check.
SIMD Group Cooperation Patterns
Now let us look at how SIMD groups work together to solve problems. These patterns show up over and over in ML kernels.
Pattern 1: SIMD Reduction
The simplest and most common pattern. Each thread has a value, and you want the sum (or max, or min) across all threads in the SIMD group.
float my_value = data[tid];
float total = simd_sum(my_value);
// All 32 lanes now hold the same total
Lane: 0 1 2 3 ... 31
Val: 3.0 1.0 4.0 1.5 ... 2.7
|
simd_sum
|
total = 87.3 (in all lanes)
Pattern 2: SIMD Broadcast
One lane has a value that all other lanes need.
float value;
if (simd_lane == 0) {
value = compute_something_expensive();
}
// Broadcast lane 0's value to all lanes
value = simd_broadcast_first(value);
// Or from a specific lane:
value = simd_shuffle(value, source_lane);
Pattern 3: SIMD Prefix Sum (Inclusive Scan)
Each lane gets the sum of all values from lane 0 through its own lane. Useful for compaction, histograms, and stream processing.
float val = data[tid];
float prefix = simd_prefix_inclusive_sum(val);
Lane: 0 1 2 3 4 5 ...
Input: 3 1 4 1 5 9 ...
Output: 3 4 8 9 14 23 ...
^ ^ ^
| | +-- 3+1+4
| +-- 3+1
+-- 3
Pattern 4: SIMD Shuffle for Data Reuse
When adjacent threads need overlapping data (like in convolution), shuffles avoid redundant memory loads:
// Each lane loads one element
float my_val = data[base + simd_lane];
// Now lane i can access neighboring values without memory loads:
float left = simd_shuffle_up(my_val, 1); // lane i gets lane (i-1)'s value
float right = simd_shuffle_down(my_val, 1); // lane i gets lane (i+1)'s value
// Stencil computation (e.g., 1D convolution with kernel [0.25, 0.5, 0.25])
float result = 0.25f * left + 0.5f * my_val + 0.25f * right;
Before shuffle:
Lane: 0 1 2 3 4 ...
Val: A B C D E ...
After simd_shuffle_down(val, 1):
Lane: 0 1 2 3 4 ...
Val: B C D E F ...
^ ^ ^
Each lane got the next lane's value
Pattern 5: Two-Stage Reduction (Cross-SIMD-Group)
When you need a reduction across the entire threadgroup (not just one SIMD group), you do it in two stages:
Stage 1: SIMD-level reduction (fast, no barriers)
+------------------+ +------------------+ +------------------+ +------------------+
| SIMD Group 0 | | SIMD Group 1 | | SIMD Group 2 | | SIMD Group 3 |
| 32 values | | 32 values | | 32 values | | 32 values |
| --> simd_sum | | --> simd_sum | | --> simd_sum | | --> simd_sum |
| Result: S0 | | Result: S1 | | Result: S2 | | Result: S3 |
+------------------+ +------------------+ +------------------+ +------------------+
Stage 2: Write partial sums to threadgroup memory, then reduce
+-------------------------------------------+
| threadgroup float partials[4]; |
| partials = [S0, S1, S2, S3] |
| |
| threadgroup_barrier(...) |
| |
| SIMD Group 0 reads all 4 values |
| --> simd_sum(partials[simd_lane]) |
| Result: S0 + S1 + S2 + S3 = TOTAL |
+-------------------------------------------+
Here is the code:
// Assume threadgroup size = 128 (4 SIMD groups)
threadgroup float partials[4];
// Stage 1: Each SIMD group reduces its 32 values
float local_sum = simd_sum(my_value);
// One thread per SIMD group writes to shared memory
if (simd_lane == 0) {
partials[simd_id] = local_sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Stage 2: First SIMD group reduces the partial sums
float total;
if (simd_id == 0 && simd_lane < 4) {
total = simd_sum(partials[simd_lane]);
}
// Broadcast to all threads if needed
threadgroup float shared_total;
if (lid == 0) {
shared_total = total;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
total = shared_total;
This two-stage pattern is the workhorse of ML kernels. You will find it in normalization (RMSNorm, LayerNorm), softmax, and anywhere else you need a global reduction within a threadgroup.
Thread Mapping: From Grid to Data
One of the most important design decisions is how you map threads to data. There are several common patterns.
Pattern A: One Thread Per Element
The simplest mapping. Each thread processes exactly one element:
kernel void scale(
device half* data [[buffer(0)]],
constant float& factor [[buffer(1)]],
uint tid [[thread_position_in_grid]]
) {
data[tid] = half(float(data[tid]) * factor);
}
Thread: 0 1 2 3 4 5 6 7 ...
Data: [0] [1] [2] [3] [4] [5] [6] [7] ...
| | | | | | | |
v v v v v v v v
Dispatch: one thread per element, trivial.
Pattern B: One Thread Per Vector (Vectorized)
Each thread processes a vector of elements:
kernel void scale_vec4(
device half4* data [[buffer(0)]],
constant float& factor [[buffer(1)]],
uint tid [[thread_position_in_grid]]
) {
half4 v = data[tid];
data[tid] = half4(float4(v) * factor);
}
Thread: 0 1 2 3 ...
Data: [0,1,2,3] [4,5,6,7] [8,9,10,11] [12,13,14,15] ...
| | | |
v v v v
Dispatch: one thread per 4 elements. Total threads = N / 4.
Pattern C: Stride Loop (One Threadgroup Per Row)
Each threadgroup processes an entire row, with threads striding through the elements:
kernel void row_sum(
device const float* matrix [[buffer(0)]],
device float* sums [[buffer(1)]],
constant uint& cols [[buffer(2)]],
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]]
) {
uint row = gid;
float sum = 0.0f;
// Each thread strides through the row
for (uint col = lid; col < cols; col += 256) {
sum += matrix[row * cols + col];
}
// Reduce within threadgroup...
sum = simd_sum(sum);
// ... (two-stage reduction as before)
if (lid == 0) {
sums[row] = sum;
}
}
Row 0 (threadgroup 0):
Thread 0 reads: [0], [256], [512], [768], ...
Thread 1 reads: [1], [257], [513], [769], ...
Thread 2 reads: [2], [258], [514], [770], ...
...
Thread 255 reads: [255], [511], [767], [1023], ...
Dispatch: one threadgroup per row.
Pattern D: 2D Tiling (GEMM)
For matrix multiplication, threads are organized in a 2D grid of tiles:
kernel void gemm_tiled(
device const half* A [[buffer(0)]],
device const half* B [[buffer(1)]],
device half* C [[buffer(2)]],
constant uint& M [[buffer(3)]],
constant uint& N [[buffer(4)]],
constant uint& K [[buffer(5)]],
uint2 gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint simd_id [[simdgroup_index_in_threadgroup]]
) {
// Each threadgroup computes a TM x TN tile of C
// gid.y = tile row, gid.x = tile column
uint tile_row = gid.y; // Which TM-row block
uint tile_col = gid.x; // Which TN-column block
// Each SIMD group within the threadgroup computes a sub-tile
// ...
}
Matrix C (M x N):
+--------+--------+--------+--------+
| TG(0,0)| TG(0,1)| TG(0,2)| TG(0,3)| <-- threadgroup row 0
| 32x64 | 32x64 | 32x64 | 32x64 |
+--------+--------+--------+--------+
| TG(1,0)| TG(1,1)| TG(1,2)| TG(1,3)| <-- threadgroup row 1
| 32x64 | 32x64 | 32x64 | 32x64 |
+--------+--------+--------+--------+
| TG(2,0)| TG(2,1)| TG(2,2)| TG(2,3)|
| 32x64 | 32x64 | 32x64 | 32x64 |
+--------+--------+--------+--------+
Each threadgroup computes a TM x TN tile of C (e.g., 32 x 64).
Grid dimensions: (N/TN, M/TM, 1) threadgroups.
This is the tiling pattern used in akunu’s GEMM kernels. We will cover it in detail in the SIMD matrix operations chapter.
Dispatch Geometry Calculation
Let us put this together with concrete dispatch calculations for common ML operations.
Element-wise Operations (ReLU, SiLU, Add)
Problem: Apply activation to N elements
Approach: One thread per element (or per half4)
Vectorized (half4):
threads_needed = N / 4
threadgroup_size = 256
threadgroups = ceil(threads_needed / 256)
Example: N = 4096
threads_needed = 1024
threadgroup_size = 256
threadgroups = 4
Grid: MTLSize(4, 1, 1)
TG: MTLSize(256, 1, 1)
Row-wise Operations (Softmax, RMSNorm)
Problem: Process M rows of N elements each
Approach: One threadgroup per row
threadgroup_size = 256
threadgroups = M
Example: M = 32 (batch), N = 4096 (hidden dim)
threadgroups = 32
Grid: MTLSize(32, 1, 1)
TG: MTLSize(256, 1, 1)
Each of the 256 threads in a threadgroup strides through
4096 / 256 = 16 elements.
Matrix Multiplication (GEMM)
Problem: C[M,N] = A[M,K] * B[K,N]
Approach: 2D grid, each threadgroup computes a TM x TN tile
TM = 32, TN = 64 (tile sizes)
threadgroup_size = 128 (4 SIMD groups)
threadgroups_x = ceil(N / TN)
threadgroups_y = ceil(M / TM)
Example: M = 4096, N = 4096, K = 4096
threadgroups_x = 4096 / 64 = 64
threadgroups_y = 4096 / 32 = 128
Grid: MTLSize(64, 128, 1)
TG: MTLSize(128, 1, 1)
Total threadgroups: 8192
Total threads: 8192 * 128 = 1,048,576
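The geometry above reduces to two ceiling divisions. A helper in C++ (my own illustration; TM, TN, and the 128-thread threadgroup are the tile parameters from the text, not fixed akunu constants) reproduces the numbers:

```cpp
#include <cassert>
#include <cstdint>

struct Grid { uint32_t x, y, total_threads; };

// One threadgroup per TM x TN output tile of C, tg_size threads each.
Grid gemm_grid(uint32_t M, uint32_t N, uint32_t TM, uint32_t TN,
               uint32_t tg_size) {
    uint32_t gx = (N + TN - 1) / TN;   // tile columns
    uint32_t gy = (M + TM - 1) / TM;   // tile rows
    return {gx, gy, gx * gy * tg_size};
}
```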
Batched Operations
For batched operations, you can use the third grid dimension:
Problem: Batch of B matrices, each M x N, apply softmax per row
Approach: 3D grid -- batch x rows x 1
Grid: MTLSize(M, B, 1)
TG: MTLSize(256, 1, 1)
Kernel sees:
batch_idx = threadgroup_position_in_grid.y
row_idx = threadgroup_position_in_grid.x
Real-World Example: Dispatch for Multi-Head Attention
Let us trace through the dispatch geometry for multi-head attention, a critical component of transformer inference.
Suppose we have:
- Batch size B = 1 (single sequence, typical for inference)
- Number of heads H = 32
- Sequence length S = 2048
- Head dimension D = 128
The attention computation involves:
- Q * K^T -> scores [H, 1, S] (for single-token generation)
- softmax(scores / sqrt(D)) -> weights [H, 1, S]
- weights * V -> output [H, 1, D]
Step 1: Score computation (GEMV -- each head is a vector-matrix multiply)
For each head: q[1, D] * K[S, D]^T -> scores[1, S]
Grid: MTLSize(ceil(S/256), H, 1) -- one TG row per head
TG: MTLSize(256, 1, 1)
Each threadgroup computes 256 elements of the score vector
for one attention head.
Step 2: Softmax over scores
For each head: softmax(scores[1, S])
Grid: MTLSize(H, 1, 1) -- one TG per head
TG: MTLSize(256, 1, 1)
Each threadgroup processes one row (one head's scores).
Step 3: Weighted sum (another GEMV)
For each head: weights[1, S] * V[S, D] -> output[1, D]
Grid: MTLSize(ceil(D/32), H, 1)
TG: MTLSize(256, 1, 1)
In practice, akunu fuses some of these steps together to reduce memory traffic, but the dispatch geometry follows this general pattern.
Common Pitfalls
Pitfall 1: Threadgroup Size Not a Multiple of 32
// BAD: wastes 22% of GPU resources
encoder.dispatchThreadgroups(grid, threadsPerThreadgroup: MTLSize(100, 1, 1))
// GOOD: full utilization
encoder.dispatchThreadgroups(grid, threadsPerThreadgroup: MTLSize(128, 1, 1))
Pitfall 2: Too Few Threadgroups
If you dispatch only 2 threadgroups on a GPU with 10 compute units, 8 compute units sit idle. You want at least as many threadgroups as compute units, and ideally several times more for latency hiding.
M2 Ultra: 76 compute units
Minimum threadgroups for full utilization: 76
Better: 76 * 4 = 304+ threadgroups (multiple waves)
Pitfall 3: Forgetting Bounds Checks
When using dispatchThreadgroups, your total thread count may exceed your data size. Always check bounds:
kernel void safe_scale(
device half* data [[buffer(0)]],
constant uint& count [[buffer(1)]],
uint tid [[thread_position_in_grid]]
) {
if (tid >= count) return; // Bounds check!
data[tid] = data[tid] * 2.0h;
}
Pitfall 4: Divergent Execution Within a SIMD Group
All 32 threads in a SIMD group execute together. If some threads take one branch and others take a different branch, the SIMD group executes both branches (masking inactive threads). This is called divergence and it wastes cycles:
// BAD: divergent branches within a SIMD group
if (tid % 2 == 0) {
// Even threads do expensive path A
result = expensive_computation_A(data[tid]);
} else {
// Odd threads do expensive path B
result = expensive_computation_B(data[tid]);
}
// Both paths execute for every SIMD group -- 2x the cost!
// BETTER: organize work so entire SIMD groups take the same branch
uint group = tid / 32;
if (group % 2 == 0) {
// Entire SIMD groups do path A
result = expensive_computation_A(data[tid]);
} else {
// Entire SIMD groups do path B
result = expensive_computation_B(data[tid]);
}
Pitfall 5: Not Accounting for Non-Uniform Threadgroups
When using dispatchThreads, the last threadgroup in each dimension may be smaller than requested. If your kernel assumes a fixed threadgroup size (e.g., uses threadgroup float shared[256]), it will still work, but some of those shared memory slots will contain garbage. Be careful when reducing – only reduce over the actual active thread count.
Visualizing Execution on Apple Silicon
Let us trace how a dispatch actually executes on an Apple Silicon GPU.
An M1 Pro has 16 compute units. Each compute unit can execute multiple threadgroups concurrently (depending on resource usage). Here is a simplified timeline:
Dispatch: 64 threadgroups, 128 threads each (4 SIMD groups per TG)
Compute Unit 0: [TG 0][TG 16][TG 32][TG 48] (4 TGs across time)
Compute Unit 1: [TG 1][TG 17][TG 33][TG 49]
Compute Unit 2: [TG 2][TG 18][TG 34][TG 50]
...
Compute Unit 15: [TG 15][TG 31][TG 47][TG 63]
Wave 0: TG 0-15 (one TG per compute unit)
Wave 1: TG 16-31 (launched as wave 0 TGs finish)
Wave 2: TG 32-47
Wave 3: TG 48-63
Total: 4 waves to process all 64 threadgroups
In reality, the scheduler is more dynamic than this – it can have multiple threadgroups resident per compute unit if there are enough resources. But the wave model gives you the right intuition. More threadgroups = more waves = more opportunity for the GPU to hide latency.
Choosing Grid Dimensions for ML Workloads
Here is a decision flowchart for common ML operations:
What kind of operation?
|
+-- Element-wise (ReLU, add, scale)?
| Threads per element: 1 (or 1 per half4 for vectorization)
| Grid: (ceil(N/4/TG_SIZE), 1, 1)
| TG: (256, 1, 1)
|
+-- Row-wise reduction (softmax, norm)?
| One threadgroup per row
| Grid: (num_rows, 1, 1)
| TG: (256, 1, 1)
|
+-- Matrix multiply (GEMM)?
| 2D grid of tiles
| Grid: (ceil(N/TN), ceil(M/TM), 1)
| TG: (128, 1, 1) -- 4 SIMD groups
|
+-- Batched operation?
| Use 3rd dimension for batch
| Grid: (spatial_x, spatial_y, batch)
| TG: depends on inner operation
|
+-- Vector-matrix multiply (GEMV)?
Grid: (ceil(M/TG_SIZE), 1, 1) -- one TG per chunk of rows
TG: (256, 1, 1)
Summary
- Thread hierarchy: Grid -> Threadgroups -> SIMD Groups -> Threads. Each level has different communication capabilities.
- SIMD groups (32 threads) communicate via intrinsics – fast, no barriers.
- Threadgroups (up to 1024 threads) communicate via threadgroup memory + barriers.
- Grid threadgroups cannot communicate during a dispatch.
- Always use multiples of 32 for threadgroup sizes.
- Common sizes: 128 for GEMM, 256 for reductions, 32 for trivial ops.
- dispatchThreadgroups for kernels with threadgroup cooperation (most ML kernels).
- dispatchThreads for simple per-element kernels.
- Two-stage reduction (SIMD reduce + threadgroup reduce) is the fundamental pattern.
- Dispatch enough threadgroups to keep all compute units busy (at least as many as there are compute units).
The next chapter covers the memory model – how data flows between CPU and GPU, and why memory bandwidth is the critical bottleneck for ML inference.
Memory Model and Buffer Management
If you have ever profiled a GPU workload and been disappointed by the utilization numbers, the root cause was almost certainly memory. Not compute. Memory. On Apple Silicon, the story is both simpler and subtler than on discrete GPUs: the CPU and GPU share a single physical memory pool (Unified Memory Architecture, or UMA), which eliminates an entire class of problems – but introduces a few new ones. This chapter walks through Metal’s memory model from first principles, shows how akunu maps it to inference, and explains why bandwidth – not FLOPS – is the number you should be worrying about.
The Unified Memory Architecture
Traditional GPU programming on NVIDIA or AMD involves two physically separate memory pools. The CPU has its DDR/LPDDR; the GPU has its GDDR/HBM. Every time you want the GPU to see CPU data, you perform an explicit copy across the PCIe bus (roughly 32 GB/s per direction for PCIe 4.0 x16). Every time you want CPU results, you copy back. These copies are slow, asynchronous, and a constant source of bugs.1
Apple Silicon threw this model out. Starting with M1, the CPU, GPU, and Neural Engine all share a single pool of LPDDR memory with a single set of page tables.2 There is no PCIe bus. There is no copy. When the CPU writes to address 0x1234, the GPU can read from that same address – because it is the same physical page.
┌──────────────────────────────────────────────────┐
│ Unified Memory (LPDDR5/5X) │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │ CPU │ │ GPU │ │ Neural │ │
│ │ Cores │ │ Cores │ │ Engine │ │
│ └────┬─────┘ └────┬────┘ └────┬────┘ │
│ │ │ │ │
│ └─────────────┼────────────┘ │
│ │ │
│ ┌─────────┴─────────┐ │
│ │ System Level │ │
│ │ Cache (SLC) │ │
│ └─────────┬─────────┘ │
│ │ │
│ ┌─────────┴─────────┐ │
│ │ Fabric / Memory │ │
│ │ Controller │ │
│ └───────────────────┘ │
└──────────────────────────────────────────────────┘
This has profound implications for inference engines:
- Zero-copy buffer sharing. A MTLBuffer allocated with MTLResourceStorageModeShared is readable and writable by both CPU and GPU without any explicit transfer. The CPU gets a raw pointer (buf.contents); the GPU gets the same backing pages.
- No staging buffers needed. On CUDA, you allocate a “host-pinned” staging buffer, memcpy into it, then launch a cuMemcpyHtoD. On Metal with shared mode, you just write directly and dispatch.
- Coherency is automatic (mostly). After GPU work completes (i.e., waitUntilCompleted returns), the CPU sees all GPU writes immediately. The hardware’s cache coherency protocol handles this.3
Metal Storage Modes
Metal offers three storage modes for buffers. Understanding them is essential because choosing wrong means either unnecessary copies or corrupted data.
| Storage Mode | CPU Access | GPU Access | Copy Needed? | Best For |
|---|---|---|---|---|
| Shared | Read/Write | Read/Write | No | UMA devices (all Apple Silicon) |
| Private | None | Read/Write | Yes (blit) | GPU-only temporaries |
| Managed | Read/Write | Read/Write | Yes (synchronize) | macOS discrete GPU (not Apple Silicon) |
On Apple Silicon, Shared mode is the overwhelmingly correct choice. It provides direct CPU and GPU access with no copies. Private mode can be slightly faster for buffers the CPU never touches (the GPU may have more freedom in cache management), but in practice the difference is negligible for inference workloads – and you lose the ability to read results without a blit encoder.4
Managed mode exists for macOS systems with discrete AMD GPUs and is irrelevant to Apple Silicon. If you see Managed in Metal sample code, it is targeting a different hardware class.
Why Akunu Uses Shared Mode Everywhere
Look at how MetalDevice::allocate works in akunu:
Buffer MetalDevice::allocate(size_t bytes) {
id<MTLBuffer> buf = [STATE.device newBufferWithLength:MAX(bytes, 16)
options:MTLResourceStorageModeShared];
void *h = (void *)CFBridgingRetain(buf);
allocated_bytes_ += MAX(bytes, 16);
return {h, bytes, [buf contents]};
}
Every allocation uses MTLResourceStorageModeShared. The returned Buffer struct stores both the opaque Metal handle (h) and the CPU-accessible pointer ([buf contents]). This means:
- Weight loading is zero-copy. When akunu parses a GGUF file, it can mmap the file and allocate Metal buffers that point at the same pages. The GPU reads weights directly from the memory-mapped file. No memcpy, no staging buffer, no DMA transfer.
- KV cache is CPU-readable. After decode completes, the CPU can inspect KV cache contents directly through buf.contents – useful for debugging and speculative decoding verification.
- Scratch buffers just work. The ScratchBuffers struct allocates everything up front with device.allocate(). During the hot path, nothing is allocated or freed.
The Buffer Struct: Akunu’s Abstraction
Akunu wraps Metal buffers in a simple POD struct defined in device.h:
struct Buffer {
void *handle; // Backend-specific pointer (MTLBuffer*, CUdeviceptr, etc.)
size_t size; // Size in bytes
void *contents; // CPU-accessible pointer (for UMA or mapped memory)
};
This is deliberately minimal. No reference counting, no smart pointers, no virtual methods. Just three fields that fit in 24 bytes. The handle is an opaque pointer – on Metal it is a CFBridgingRetain’d id<MTLBuffer>, on a future CUDA backend it would be a CUdeviceptr. The contents pointer is the CPU-visible address; on discrete GPUs it would be nullptr.
Why no reference counting? Because buffer lifetimes in inference are trivially static. You allocate all buffers at model load time, use them for the entire session, and free them at shutdown. There is no dynamic buffer creation in the hot path. This means the overhead of shared_ptr or arc reference counting is pure waste.5
The allocate-with-data Overload
There is a second allocate overload that accepts initial data:
Buffer MetalDevice::allocate(const void *data, size_t bytes) {
id<MTLBuffer> buf = [STATE.device newBufferWithBytes:data
length:bytes
options:MTLResourceStorageModeShared];
void *h = (void *)CFBridgingRetain(buf);
allocated_bytes_ += bytes;
return {h, bytes, [buf contents]};
}
This is used for two things: uploading initial weight data and creating pre-filled parameter buffers. Metal’s newBufferWithBytes: internally memcpys the data into the newly allocated buffer. On UMA this is a simple memcpy; there is no bus transfer.
The setBytes() Optimization: Inline Small Data
Not everything needs a buffer. Metal provides setBytes:length:atIndex: for small, frequently-changing data. Instead of allocating a buffer, writing to it, and binding it, you just hand the encoder a pointer and a length. Metal copies the bytes directly into the command buffer’s inline data area.6
The rules are:
| Data Size | Mechanism | Per-Dispatch Cost |
|---|---|---|
| <= 4 KB | setBytes() | ~0 (inline in command buffer) |
| > 4 KB | setBuffer() | Buffer bind + potential cache miss |
Akunu uses this aggressively. Every kernel’s parameter struct (dimensions, epsilon values, strides) is under 64 bytes. The DispatchCmd struct stores these inline:
// Inline params (up to 64 bytes -- covers all kernel param structs)
uint8_t param_bytes[64];
int param_size;
int param_index; // buffer index for setBytes/setBuffer
During dispatch table replay, static params use pre-allocated GPU buffers via setBuffer() (zero per-token work), while position-patched params use setBytes() (the patched bytes get copied inline into the command buffer). This split is the key insight from akunu’s Metal backend – setBytes() is perfect for the few parameters that change per token, while setBuffer() avoids redundant work for the many parameters that do not.
// From MetalDevice::encode_dispatch_table:
// Static params: use pre-allocated GPU buffer (no per-token work)
if (cmd.param_buf.handle)
[enc setBuffer:(__bridge id<MTLBuffer>)cmd.param_buf.handle
offset:0
atIndex:(NSUInteger)cmd.param_index];
else
[enc setBytes:cmd.param_bytes
length:(NSUInteger)cmd.param_size
atIndex:(NSUInteger)cmd.param_index];
Memory Alignment
Metal requires 16-byte alignment for buffer offsets when binding with setBuffer:offset:atIndex:.7 This is a hardware constraint – the GPU’s memory controller fetches aligned 16-byte chunks, and unaligned offsets would require splitting a fetch across two cache lines.
Akunu ensures alignment in several ways:
- Buffer sizes are rounded up. The minimum allocation is MAX(bytes, 16), ensuring every buffer is at least one alignment unit.
- QKV sub-offsets are naturally aligned. The ScratchBuffers struct computes QKV offsets as q_dim * 2 and (q_dim + kv_dim) * 2. Since q_dim and kv_dim are always multiples of head_dim (typically 64 or 128), and each element is 2 bytes (FP16), these offsets are always multiples of 128 or 256 bytes – far exceeding the 16-byte requirement.
- GGUF block alignment. Quantized formats like Q4_0 use 32-element blocks (18 bytes each). Weight tensors are stored as contiguous arrays of these blocks, and GGUF pads to alignment boundaries.
Resource Hazards and Synchronization
On discrete GPUs with separate memory, you have explicit copy commands that create natural synchronization points. On UMA with shared mode, the GPU and CPU can both touch the same memory at any time. This creates resource hazards: what happens if the CPU writes to a buffer while the GPU is reading it?
Metal handles this through command buffer completion:
- CPU-to-GPU: The CPU writes data before calling begin_encoding(). The encoder captures buffer references at encode time. As long as writes complete before encoding begins, the GPU will see them.
- GPU-to-CPU: After end_encoding_sync() returns (or waitUntilCompleted signals), all GPU writes are visible to the CPU. The hardware flushes caches as part of command buffer completion.
- GPU-to-GPU (within a command buffer): Metal guarantees that compute dispatches within a single command encoder execute in order. Dispatch N sees all writes from dispatch N-1. No barriers needed.8
- GPU-to-GPU (across command buffers): You need explicit synchronization. Akunu’s end_encoding_event() / begin_encoding_after_event() uses MTLSharedEvent for GPU-to-GPU signaling across command buffers:
// Signal after this command buffer completes
[STATE.cmdBuffer encodeSignalEvent:STATE.pipelineEvent value:signal_val];
[STATE.cmdBuffer commit];
// Next command buffer waits for the signal
[STATE.cmdBuffer encodeWaitForEvent:STATE.pipelineEvent value:STATE.eventValue];
This is critical for pipelined chain decode, where one command buffer is executing on the GPU while the CPU encodes the next one.
The Bandwidth Bottleneck
Here is the uncomfortable truth about LLM inference on Apple Silicon: you will almost never be compute-bound during token generation. You will be memory-bandwidth-bound.
Why? Consider a single decode step for a 7B parameter model with Q4_0 quantization. Each parameter is 4.5 bits on average (4 bits of data plus an amortized per-block scale). The total weight data is roughly:
$$7 \times 10^9 \times 4.5 / 8 \approx 3.9 \text{ GB}$$
Every decode step reads every weight exactly once (each layer’s Q, K, V, O, gate, up, down projections). That is 3.9 GB of memory reads per token. On an M4 Pro with ~273 GB/s memory bandwidth, the theoretical floor is:
$$3.9 \text{ GB} / 273 \text{ GB/s} \approx 14.3 \text{ ms/token} \approx 70 \text{ tok/s}$$
The actual compute work (multiply-accumulate operations) takes a fraction of this time. The GPU cores are waiting for data to arrive from DRAM, not crunching numbers.9
| Chip | Memory Bandwidth | Theoretical Max tok/s (7B Q4) | Theoretical Max tok/s (70B Q4) |
|---|---|---|---|
| M1 | 68 GB/s | ~17 | ~1.7 |
| M1 Pro | 200 GB/s | ~51 | ~5.1 |
| M1 Max | 400 GB/s | ~102 | ~10.2 |
| M2 | 100 GB/s | ~26 | ~2.6 |
| M3 Pro | 150 GB/s | ~38 | ~3.8 |
| M4 | 120 GB/s | ~31 | ~3.1 |
| M4 Pro | 273 GB/s | ~70 | ~7.0 |
| M4 Max | 546 GB/s | ~140 | ~14.0 |
This table reveals a fundamental reality: your token generation speed is almost entirely determined by how fast your chip can feed data to the GPU cores. More compute cores help for prefill (which is compute-bound), but for autoregressive decode, bandwidth is king.10
Implications for Engine Design
This bandwidth bottleneck drives several akunu design decisions:
- Minimize weight reads. Read each weight exactly once per token. Fuse operations where possible (SiLU + down projection, QKV projection) to reduce the number of kernel launches and avoid re-reading intermediate buffers.
- Use the System Level Cache (SLC). Apple Silicon chips have a large last-level cache shared between CPU and GPU. On M4 Pro, it is estimated at 32 MB. Weights that fit in SLC are served at cache bandwidth (~2 TB/s on M4), not DRAM bandwidth. This is why akunu fuses QKV and gate+up weights on chips with large SLC – the fused weight matrix is more likely to be in cache for the second read.11
- Quantize aggressively. Q4_0 reads 4.5 bits per weight versus FP16’s 16 – less than a third of the bytes – and Q2_K reads fewer still. Lower precision means less bandwidth, which directly translates to higher tok/s.
- Avoid unnecessary reads. The dispatch table pre-computes everything that can be pre-computed. Buffer bindings, pipeline states, parameter structs – all resolved at build time, not at dispatch time.
Pre-Allocated Scratch Buffers
Akunu allocates all intermediate buffers once at model load time via the ScratchBuffers struct:
struct ScratchBuffers {
Buffer h0; // [dim] FP16
Buffer h1; // [dim] FP16
Buffer residual; // [dim] FP16
Buffer qkv; // [q_dim + 2*kv_dim] FP16
Buffer attn_out; // [max(q_dim, dim)] FP16
Buffer ffn_gate; // [2 * ffn_dim] FP16 (2x for fused gate+up)
Buffer ffn_up; // [ffn_dim] FP16
Buffer ffn_act; // [ffn_dim] FP16
Buffer logits; // [vocab_size] FP16
Buffer token_ids; // [max_chain] U32
// ... plus batch buffers for prefill
};
A few things to notice:
- ffn_gate is 2x ffn_dim. This accommodates fused gate+up projections, where the output of a single GEMV contains both the gate and up vectors contiguously.
- qkv is contiguous. Q, K, and V are stored in a single buffer at computed offsets (qkv_q_offset, qkv_k_offset, qkv_v_offset). When QKV fusion is enabled, a single GEMV writes all three projections into this buffer. When not fused, three separate GEMVs write to their respective sub-regions.
- No dynamic allocation. The ScratchBuffers::create factory is called once. After that, every forward pass reuses the same buffers. This means zero memory allocation overhead in the hot path, zero fragmentation, and deterministic memory usage.
- Prefill buffers are separate. Batch buffers (batch_h0, batch_q, etc.) are sized for the maximum prefill chunk. They are larger than decode buffers by a factor of prefill_chunk (typically 4096).
KV Cache Memory Layout
The KV cache is another critical memory structure, defined in kv_cache.h:
struct KVCache {
int n_layers;
int n_kv_heads;
int head_dim;
int max_length;
int current_length;
std::vector<Buffer> k_buffers; // one per layer
std::vector<Buffer> v_buffers; // one per layer
int kv_stride; // max_length * head_dim
};
Each layer gets two buffers: one for K, one for V. The layout is [n_kv_heads, max_length, head_dim] in FP16. The kv_stride (= max_length * head_dim) is the number of elements between consecutive KV heads.
For a 32-layer model with 8 KV heads, 128 head dim, and 4096 max context:
$$\text{Bytes per K (or V) buffer} = 8 \times 4096 \times 128 \times 2 = 8 \text{ MB}$$ $$\text{Total KV cache} = 32 \text{ layers} \times 2 \text{ buffers} \times 8 \text{ MB} = 512 \text{ MB}$$
This is significant – half a gigabyte just for KV cache on a model that “only” has 7B parameters. For 70B models with 80 layers, the KV cache can easily exceed 4 GB. This is why max_context is capped at 4096 by default in akunu_load_model().
Buffer Memory Layout (Shared Mode — UMA)
┌──────────────────────────────────────────────┐
│ Model Weights (3.9 GB) │ GPU reads, CPU writes at load
├──────────────────────────────────────────────┤
│ KV Cache (512 MB) │ GPU reads+writes per token
├──────────────────────────────────────────────┤
│ Scratch Buffers (~10 MB) │ GPU reads+writes, reused every step
├──────────────────────────────────────────────┤
│ Prefill Batch Buffers (~200 MB) │ Used only during prefill
└──────────────────────────────────────────────┘
Write Buffer and Read Buffer: Portability
The Device base class provides default implementations of write_buffer and read_buffer that use plain memcpy:
virtual void write_buffer(Buffer dst, const void *src, size_t bytes, size_t offset = 0) {
if (dst.contents)
memcpy((char *)dst.contents + offset, src, bytes);
}
virtual void read_buffer(void *dst, Buffer src, size_t bytes, size_t offset = 0) {
if (src.contents)
memcpy(dst, (const char *)src.contents + offset, bytes);
}
On UMA, this is literally a memcpy – there is no DMA, no bus, no asynchronous transfer. The comment says “Override for discrete GPU backends (CUDA cuMemcpyHtoD, etc.).” This is the portability escape hatch: a future CUDA backend would override these methods to use proper device-to-host copies.
Memory Tracking
MetalDevice tracks total bytes allocated:
size_t allocated_bytes_ = 0;
// In allocate():
allocated_bytes_ += MAX(bytes, 16);
// In free_buffer():
if (buf.size <= allocated_bytes_) allocated_bytes_ -= buf.size;
This is exposed through akunu_model_memory() in the C API, letting callers see how much GPU memory a loaded model uses. It is a simple counter, not a full memory allocator – because akunu does not need a full memory allocator. Buffers are allocated at init, freed at shutdown, and nothing happens in between.
Common Pitfalls
Before we leave the memory model, let me highlight a few traps that catch even experienced Metal developers:
Pitfall 1: Reading GPU Results Before Completion
// WRONG: GPU may still be writing
device.end_encoding_async();
float *result = (float *)logits_buf.contents; // Race condition!
Always call wait() or end_encoding_sync() before reading GPU-written buffers from the CPU. The UMA does not mean coherent-at-all-times; it means coherent-after-completion.
Pitfall 2: Buffer Lifetime with Unretained References
Akunu uses commandBufferWithUnretainedReferences for performance. This means Metal will NOT retain buffer references – if you free a buffer before the GPU finishes, you get a GPU fault. This is safe in akunu because all buffers outlive the GPU work (they are freed only at model destruction), but adding dynamic buffer management would require switching to commandBuffer (with retain) or careful lifetime tracking.12
Pitfall 3: Assuming Private Mode is Faster
On Apple Silicon, Private mode offers minimal benefit over Shared mode for inference workloads. The GPU’s cache hierarchy works the same way for both. Private prevents CPU access, which can be slightly more efficient for GPU-only temporaries, but the difference is typically <1% for bandwidth-bound kernels. Akunu chose simplicity over micro-optimization here.
Pitfall 4: Ignoring the 16-Byte Alignment Rule
If you bind a buffer with an offset that is not a multiple of 16, Metal will either silently produce garbage or crash with a validation error (if Metal validation is enabled). The fix is always the same: pad your data structures to 16-byte boundaries. Notice how akunu’s param structs include _p0, _p1 padding fields:
struct { uint32_t dim; float eps; uint32_t _p0, _p1; } norm_params;
Those _p0, _p1 fields pad the struct to 16 bytes, ensuring alignment when passed through setBytes().
Summary
Apple Silicon’s UMA simplifies GPU programming enormously: no copies, no staging buffers, no DMA. But it does not eliminate the fundamental bottleneck of memory bandwidth. For LLM inference, where every token requires reading the entire weight matrix, bandwidth determines throughput.
Akunu’s memory strategy is:
- Shared mode everywhere for zero-copy CPU-GPU access
- Static allocation of all buffers at model load time
- Pre-allocated param buffers with the setBytes/setBuffer split for per-token patching
- 16-byte aligned everything
- Bandwidth-aware design driving quantization, weight fusion, and SLC exploitation
In the next chapter, we will look at the compute side: how Apple’s SIMD group matrix operations turn those bandwidth-fed bytes into actual matrix multiplications.
1. NVIDIA CUDA Programming Guide, Section 3.2.2, “Device Memory.” The PCIe bus bottleneck is well-documented; PCIe 4.0 x16 provides roughly 32 GB/s in each direction, an order of magnitude below on-package memory. See https://docs.nvidia.com/cuda/cuda-c-programming-guide/. ↩
2. Apple, “Apple M1 Chip,” 2020. The Unified Memory Architecture was first introduced with M1, providing up to 68.25 GB/s bandwidth with a single memory pool. See https://www.apple.com/newsroom/2020/11/apple-unleashes-m1/. ↩
3. Apple, “Metal Best Practices Guide: Resource Storage Modes,” 2024. On Apple Silicon, shared mode buffers are coherent after command buffer completion. See https://developer.apple.com/documentation/metal/choosing-a-resource-storage-mode-for-apple-gpus. ↩
4. Apple, “Metal Feature Set Tables.” On Apple Silicon (Apple GPU family 7+), MTLResourceStorageModeShared is the recommended mode for most buffers. Private mode is useful for render targets and textures that the CPU never accesses. See https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf. ↩
5. This is a deliberate design decision. llama.cpp uses a similar approach with ggml_backend_buffer, where buffer lifetimes are tied to the model context. Reference counting would add atomic operations to every buffer bind – thousands per forward pass. ↩
6. Apple, “setBytes:length:atIndex: Documentation.” The data is copied inline into the command buffer. The maximum size is 4 KB. For larger data, use a buffer. See https://developer.apple.com/documentation/metal/mtlcomputecommandencoder. ↩
7. Apple, “Metal Best Practices Guide: Buffer Alignment.” Buffer offsets must be a multiple of 16 bytes for setBuffer:offset:atIndex:. See https://developer.apple.com/documentation/metal/mtlcomputecommandencoder. ↩
8. Apple, “Metal Programming Guide: Command Organization.” Within a single compute command encoder, dispatches execute in order with implicit memory barriers. See https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu. ↩
9. This analysis follows the roofline model methodology. See Williams, S., Waterman, A., & Patterson, D. (2009). “Roofline: An Insightful Visual Performance Model for Multicore Architectures.” Communications of the ACM, 52(4), 65-76. See https://doi.org/10.1145/1498765.1498785. ↩
10. Memory bandwidth numbers sourced from Apple’s product specifications and independent testing by AnandTech and Chips and Cheese. Actual achieved bandwidth is typically 70-85% of theoretical peak due to memory controller overhead. See https://chipsandcheese.com/. ↩
11. The System Level Cache (SLC) is described in various Apple Silicon die analyses. See “Apple M1 Die Shot Analysis” by Chips and Cheese, 2021. SLC sizes are estimated from die area analysis and performance profiling; Apple does not officially disclose them. ↩
12. Apple, “commandBufferWithUnretainedReferences Documentation.” Using unretained references avoids the overhead of Metal retaining every buffer referenced by a command buffer. This is safe when buffer lifetimes are manually managed. See https://developer.apple.com/documentation/metal/mtlcommandqueue/1508684-makecommandbufferwithunretainedr. ↩
SIMD Group Matrix Operations
Matrix multiplication is the beating heart of every transformer. During a single forward pass of a 7B-parameter model, you perform roughly 30 matrix multiplications per layer across 32 layers – nearly a thousand matmuls per token. If your matmul is slow, everything is slow. On Apple Silicon, the key to fast matmuls is the simdgroup_matrix API: a set of hardware-accelerated operations that let 32 GPU threads cooperatively multiply 8x8 matrix tiles at full throughput.1
This chapter explains how SIMD group matrix operations work, how they compose into larger GEMM/GEMV kernels, and how akunu uses them for both prefill and decode.
What Is a SIMD Group?
Before diving into matrix operations, let’s establish the execution model. An Apple GPU organizes threads into SIMD groups (also called warps on NVIDIA, or wavefronts on AMD). A SIMD group on Apple Silicon is always 32 threads that execute in lockstep – every thread runs the same instruction at the same cycle.2
┌─────────────────────────────────────────────┐
│ Threadgroup (128 threads) │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌───────────┐ │
│ │ SIMD 0 │ │ SIMD 1 │ │ SIMD 2 │ │ SIMD 3 │ │
│ │ 32 thds │ │ 32 thds │ │ 32 thds │ │ 32 thds │ │
│ └─────────┘ └─────────┘ └─────────┘ └───────────┘ │
└─────────────────────────────────────────────┘
Within a SIMD group, threads can communicate through SIMD shuffle instructions – reading values from any other thread in the group without going through memory. This is extremely fast (single-cycle) and is the foundation of SIMD group matrix operations.
The simdgroup_matrix API
Starting with Apple GPU family 7 (M1 and later), Metal Shading Language provides the simdgroup_matrix type and associated operations.3 The fundamental unit is an 8x8 matrix distributed across the 32 threads of a SIMD group:
#include <metal_simdgroup_matrix>
simdgroup_half8x8 mat_a; // 8x8 matrix of float16
simdgroup_float8x8 mat_b; // 8x8 matrix of float32
Each thread in the SIMD group holds a portion of the 8x8 matrix. You do not control which thread holds which element – the hardware distributes the 64 elements across 32 threads (2 elements per thread) in a hardware-defined layout. This layout is opaque; you interact with the matrix through three operations:
load
simdgroup_half8x8 ma;
simdgroup_load(ma, src_ptr, stride);
Load an 8x8 tile from device or threadgroup memory. The stride is the number of elements between consecutive rows. All 32 threads participate in the load cooperatively – each thread reads its assigned 2 elements.
store
simdgroup_store(mc, dst_ptr, stride);
Store an 8x8 tile back to memory. Same cooperative pattern as load.
multiply_accumulate
simdgroup_multiply_accumulate(mc, ma, mb, mc);
// mc += ma * mb (8x8 += 8x8 * 8x8)
This is the core operation. It multiplies two 8x8 matrices and accumulates the result into a third. The hardware performs 512 multiply-accumulate operations (8 * 8 * 8) in a single instruction across the SIMD group.4 On Apple GPU family 7+, this executes in a few cycles – dramatically faster than doing the same work with scalar or vector instructions.
How 32 Threads Hold One 8x8 Matrix
Let’s be precise about the data distribution. An 8x8 matrix has 64 elements. A SIMD group has 32 threads. So each thread holds exactly 2 elements. The exact mapping is hardware-defined and opaque, but conceptually:
32 Threads Holding an 8x8 Matrix (each thread holds 2 elements):
col0 col1 col2 col3 col4 col5 col6 col7
┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐
r0 │ T0 │ T0 │ T1 │ T1 │ T2 │ T2 │ T3 │ T3 │
├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤
r1 │ T4 │ T4 │ T5 │ T5 │ T6 │ T6 │ T7 │ T7 │
├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤
r2 │ T8 │ T8 │ T9 │ T9 │ T10 │ T10 │ T11 │ T11 │
├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤
r3 │ T12 │ T12 │ T13 │ T13 │ T14 │ T14 │ T15 │ T15 │
├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤
r4 │ T16 │ T16 │ T17 │ T17 │ T18 │ T18 │ T19 │ T19 │
├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤
r5 │ T20 │ T20 │ T21 │ T21 │ T22 │ T22 │ T23 │ T23 │
├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤
r6 │ T24 │ T24 │ T25 │ T25 │ T26 │ T26 │ T27 │ T27 │
├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤
r7 │ T28 │ T28 │ T29 │ T29 │ T30 │ T30 │ T31 │ T31 │
└─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘
The key insight: you never access individual elements of a simdgroup_matrix. You load, you multiply-accumulate, you store. The hardware handles the internal layout. Trying to extract element [3][5] would require SIMD shuffles and defeat the purpose.
Building Larger Matmuls from 8x8 Tiles
An 8x8 tile is too small for real work. A typical LLM linear layer might be [4096, 4096] – that is 512 x 512 tiles. The strategy is to:
- Assign tiles to threadgroups. Each threadgroup computes a rectangular block of the output matrix.
- Within a threadgroup, each SIMD group accumulates a column of output tiles.
- Loop over the K dimension in steps of 8 (or larger), accumulating partial sums.
Akunu’s GEMM kernels follow the llama.cpp tile geometry, which uses TM=32 and TN=64:
Output tile per threadgroup: 32 rows (M) x 64 columns (N)
Threadgroup size: 128 threads = 4 SIMD groups
Each SIMD group handles: 32 rows x 16 columns (using 8x8 sub-tiles)
K-loop: stride 32 elements (NK=32)
- Load A tile [32 x 32] into threadgroup memory
- Load B tile [64 x 32] into threadgroup memory
- Each SIMD group does: 4 rows x 2 columns x 4 K-steps of 8x8 MACs
Here is how this maps to akunu’s simd_gemm_f16.metal:
// Dispatch: grid=(ceil(N/64), ceil(M/32), 1), threads=(32,4,1)
// TG memory: 4096 + 2048 = 6144 bytes
constexpr short NR0 = 64; // N tile (weight rows)
constexpr short NR1 = 32; // M tile (activation rows)
constexpr short NK = 32; // K tile (reduction dimension)
simdgroup_half8x8 mc[8]; // 8 accumulator tiles per SIMD group
for (short i = 0; i < 8; i++) {
mc[i] = make_filled_simdgroup_matrix<half, 8>(0.h);
}
for (uint loop_k = 0; loop_k < K; loop_k += NK) {
// Load A and B tiles into threadgroup memory
// ...
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply: 8x8 tiles along K
for (short k_step = 0; k_step < NK; k_step += 8) {
simdgroup_half8x8 ma, mb;
simdgroup_load(ma, sa + ..., ...);
for (short tile = 0; tile < 8; tile++) {
simdgroup_load(mb, sb + ..., ...);
simdgroup_multiply_accumulate(mc[tile], ma, mb, mc[tile]);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// Store accumulated results
for (short tile = 0; tile < 8; tile++) {
simdgroup_store(mc[tile], C + ..., ldc);
}
The accumulator array mc[8] holds 8 output sub-tiles (8x8 each = 8 * 64 = 512 elements), which covers a 32x16 or 8x64 region depending on the layout. The K-loop processes 32 elements at a time, with each step performing 8x8 MACs.
The Tiling Hierarchy
Let’s trace how a large matmul decomposes:
Full matrix: C[M,N] = A[M,K] * B^T[N,K]
Example: C[4096, 4096] = A[4096, 4096] * B^T[4096, 4096]
Grid level: ceil(4096/64) x ceil(4096/32) = 64 x 128 threadgroups
TG level: Each TG computes C[32, 64] using 128 threads (4 SIMDs)
SIMD level: Each SIMD accumulates C[8, 64] or similar sub-region
Tile level: Each 8x8 MAC produces C[8,8] += A[8,8] * B[8,8]
| Level | Size | Unit | Count |
|---|---|---|---|
| Full matrix | 4096 x 4096 | Elements | 16.7M |
| Threadgroup tile | 32 x 64 | Elements | 8,192 tiles |
| SIMD accumulator | 8 x 8 | Elements per tile | 8 tiles/SIMD |
| Single MAC | 8 x 8 x 8 | FMA operations | 512 ops |
Akunu’s GEMM Dispatch Geometry
Looking at simd_gemm_f16.metal, the dispatch parameters are:
grid = (ceil(N/64), ceil(M/32), 1)
threads = (32, 4, 1) // 128 threads = 4 SIMD groups
TG memory = 6144 bytes // 4096 for A tile + 2048 for B tile
The threadgroup of 128 threads breaks down as:
- 4 SIMD groups (sgitg = 0..3)
- Each SIMD group processes a different row band of the M-tile
- All 4 SIMD groups share the same B-tile data (loaded cooperatively into threadgroup memory)
For the K-loop:
- NL0 = NK/16 = 2: each thread loads 2 chunks of B data
- NL1 = NK/8 = 4: each thread loads 4 chunks of A data
- Data flows: device memory -> threadgroup memory -> SIMD registers -> accumulate -> device memory
Small-M Variant: simd_gemm_small_f16
During decode (M=1 to ~8), the full GEMM is wasteful – most of the 32-row M-tile is empty. Akunu provides simd_gemm_small_f16 which is optimized for M <= 8:
// From dtype_descriptor.h:
return (M >= 2 && M <= 8) ? d.gemm_small_kernel : d.gemm_kernel;
The small variant uses a reduced M-tile size and fewer SIMD groups, trading parallelism for reduced overhead. For M=1 (single-token decode), akunu uses GEMV kernels instead of GEMM entirely – the dispatch geometry is fundamentally different.
GEMV: The Decode Workhorse
During autoregressive decode, M=1. This is a matrix-vector product, not a matrix-matrix product. The access pattern changes dramatically:
- GEMM (prefill): Each weight element is reused across M activation rows. Arithmetic intensity is O(M). With M=4096, you get excellent compute utilization.
- GEMV (decode): Each weight element is used exactly once. Arithmetic intensity is O(1). You are completely bandwidth-bound.5
Akunu’s GEMV kernels do not use simdgroup_matrix at all – they use vectorized loads and SIMD reduction instead. The GEMV kernel for Q4_0 quantized weights, for example, dequantizes blocks of 32 weights, multiplies them by the input vector, and reduces across the SIMD group using simd_sum.
The DTypeDescriptor table maps each quantization format to its GEMV kernel and dispatch geometry:
| Dtype | GEMV Kernel | Rows/TG | TG Size | Wide Variant |
|---|---|---|---|---|
| F16 | gemv_f16 | 16 | 128 | gemv_wide_f16 (64 rows, 256 threads) |
| Q4_0 | gemv_q4_0 | 16 | 128 | gemv_wide_q4_0 (64 rows, 256 threads) |
| Q4_K | gemv_q4_k | 16 | 256 | gemv_wide_q4_k (32 rows, 256 threads) |
| Q8_0 | gemv_q8_0 | 32 | 256 | gemv_wide_q8_0 (64 rows, 256 threads) |
| MLX Q4 | gemv_mlx_q4 | 16 | 128 | gemv_wide_mlx_q4 (32 rows, 256 threads) |
| BF16 | gemv_bf16 | 16 | 128 | N/A |
The “wide” variants use larger threadgroups (256 threads = 8 SIMD groups) to increase occupancy on Pro/Max chips with many GPU cores. The ChipConfig controls when to switch:
c.wide_gemv_threshold = 32768; // use wide GEMV when N exceeds this
The Accumulation Flow: From Tiles to Output
8x8 Tile Accumulation Loop
K=0..7 K=8..15 K=16..23 Result
┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐
│A[8x8]│ │A[8x8]│ │A[8x8]│ │ │
└──┬───┘ └──┬───┘ └──┬───┘ │C[8x8]│
× × × ... │ │
┌──┴───┐ ┌──┴───┐ ┌──┴───┐ │Accum │
│B[8x8]│ │B[8x8]│ │B[8x8]│ └──────┘
└──┬───┘ └──┬───┘ └──┬───┘
+ + + ──────────────►
The accumulation loop for a single output tile works as follows:
- Initialize mc = 0 (zero-filled 8x8 accumulator)
- For each K-step (stride 8 or 32):
  - Load an 8x8 slice of A into ma
  - Load an 8x8 slice of B into mb
  - mc += ma * mb (hardware multiply-accumulate)
- After the K-loop completes, store mc to the output
For K=4096, this loop runs 4096/8 = 512 times (or 4096/32 = 128 times if the K-stride is 32 with 4 sub-steps). Each iteration performs 512 FMA operations. Total: 512 * 512 = 262,144 FMA operations per 8x8 output tile, which is exactly 8 * 8 * 4096 – the correct number.
Quantized GEMM: Dequantize in Registers
For quantized weights (Q4_0, Q4_K, Q8_0, etc.), the GEMM kernels follow the same tiling strategy but add a dequantization step. The weight data is loaded in its quantized format and dequantized into threadgroup memory before the SIMD matrix operations:
For each K-step:
1. Load quantized weight block from device memory
2. Dequantize to FP16 in threadgroup memory
3. Load A tile into threadgroup memory (already FP16)
4. barrier()
5. SIMD matrix multiply-accumulate from threadgroup memory
6. barrier()
Akunu has separate GEMM kernels for each quantization format – simd_gemm_q4_0, simd_gemm_q4_k, simd_gemm_q8_0, etc. – because the dequantization logic is tightly coupled with the load pattern. A Q4_0 block is 18 bytes (32 four-bit values packed into 16 bytes, plus one 2-byte FP16 scale), while a Q4_K super-block is 144 bytes with nested sub-scales. The memory access pattern and threadgroup memory layout differ for each.
MLX Quantized Formats
Akunu also supports MLX SafeTensors quantized models. These use a different quantization format (affine quantization with configurable group size and bit width) and require function constants to specialize the kernel at pipeline creation time:
// From emit_gemv() in table_builder.cpp:
if (dt.is_mlx) {
uint32_t fc_indices[] = {0, 1};
uint32_t fc_values[] = {(uint32_t)quant_group_size, (uint32_t)K};
pso = device.get_pipeline(kernel_name, cache_key, fc_indices, fc_values, 2);
}
Function constants are Metal’s mechanism for compile-time specialization. Instead of branching on group_size inside the kernel, you bake the value into the pipeline state object. The compiler can then optimize the kernel for that specific group size – unrolling loops, eliminating dead branches, and computing strides at compile time.6
Comparison to NVIDIA Tensor Cores
It is instructive to compare Apple’s SIMD group matrix operations to NVIDIA’s Tensor Cores:
| Feature | Apple simdgroup_matrix | NVIDIA Tensor Cores (Ampere+) |
|---|---|---|
| Tile size | 8x8 | 16x8x16 (varies by precision) |
| Thread group | 32 threads (SIMD group) | 32 threads (warp) |
| Precision | FP16, FP32, BF16 (M4+) | FP16, BF16, TF32, FP8, INT8 |
| Programming model | MSL intrinsics | WMMA/MMA PTX intrinsics |
| Throughput per core | ~1 TFLOPS (estimated) | ~2 TFLOPS per SM (A100) |
| Memory model | UMA shared | Discrete HBM |
| Tensor Core count | Integrated in GPU cores | Dedicated hardware units |
The key architectural difference: NVIDIA Tensor Cores are dedicated hardware units separate from the CUDA cores. Apple’s SIMD matrix operations are executed by the same ALUs that do regular floating-point math – they are an instruction set extension, not a separate unit.7 This means:
- No mode switching. You can freely interleave matrix operations with scalar/vector code. On NVIDIA, switching between Tensor Core and CUDA core work can cause pipeline bubbles.
- Lower peak throughput per core. Apple’s matrix multiply throughput is lower than NVIDIA’s dedicated Tensor Cores. But Apple compensates with higher memory bandwidth per FLOP (critical for inference) and the zero-copy UMA advantage.
- Simpler programming model. The simdgroup_matrix API is genuinely easier to use than NVIDIA’s WMMA or inline PTX MMA instructions. Load, store, multiply-accumulate – that’s it.
BF16 Support on M4
Starting with Apple GPU family 9 (M4), Metal supports native bfloat (BF16) as a first-class type.8 Akunu has dedicated BF16 kernels:
// From dtype_descriptor.h:
{31, "gemv_bf16", nullptr, nullptr,
"simd_gemm_bf16", "simd_gemm_small_bf16",
"embedding_lookup_bf16", ...}
BF16 has the same exponent range as FP32 (8 bits) but with only 7 bits of mantissa (vs FP16’s 10 bits). This makes it better for accumulation-heavy workloads where dynamic range matters more than precision. The simdgroup_matrix API supports simdgroup_bfloat8x8 on M4, enabling hardware-accelerated BF16 matrix multiply.
Fused SiLU + GEMV
One of akunu’s more aggressive optimizations is the fused SiLU+down GEMV kernel. In a standard SwiGLU FFN block:
gate = GEMV(gate_weight, x)
up = GEMV(up_weight, x)
act = SiLU(gate) * up
down = GEMV(down_weight, act)
That’s 3 GEMV dispatches + 1 activation dispatch = 4 kernel launches. The fused kernel combines the last two steps:
down[i] = sum_j( SiLU(gate[j]) * up[j] * down_weight[i,j] )
This reads the gate and up vectors, applies SiLU element-wise, and immediately multiplies by the down weight – all in one kernel, one pass over the down weight matrix. The DTypeDescriptor tracks which formats have fused kernels:
const char *fused_silu_kernel; // "gemv_q4_0_silu", etc.
const char *fused_silu_large_kernel; // Wide variant for Pro+
Not all formats have fused kernels (Q4_K does not, for example), so the table builder falls back to separate activation + GEMV when unavailable.
Practical Performance Numbers
To put all of this in perspective, here is what these kernel choices mean for actual throughput. On an M4 Pro (16 GPU cores, 273 GB/s bandwidth):
| Operation | Kernel Type | Time (7B Q4_0) | Bottleneck |
|---|---|---|---|
| Prefill GEMM (M=512) | simd_gemm_q4_0 | ~15 ms/layer | Compute-bound |
| Decode GEMV (M=1) | gemv_q4_0 | ~0.35 ms/layer | Bandwidth-bound |
| Fused SiLU+Down GEMV | gemv_q4_0_silu | ~0.30 ms (saves ~0.05ms) | Bandwidth-bound |
| Flash Attention (seq=2048) | flash_attention_decode_parallel_f16 | ~0.08 ms/layer | Compute/BW mixed |
The GEMV kernels dominate decode time because they read the most data. Everything else – norms, activations, attention – is comparatively cheap. This is why akunu spends so much effort on GEMV kernel variants, wide vs. standard selection, and fused operations.
Summary
Apple Silicon’s simdgroup_matrix API provides hardware-accelerated 8x8 matrix multiplication that serves as the building block for larger GEMM operations. Akunu uses the llama.cpp tile geometry (TM=32, TN=64) for prefill GEMM and specialized GEMV kernels for decode. The choice between GEMM and GEMV, and between standard and wide variants, is driven by the DTypeDescriptor table and ChipConfig thresholds – all resolved at dispatch table build time, not at runtime.
The next chapter explores the broader set of performance optimization patterns that make these kernels – and the overall inference pipeline – fast.
1. Apple, “Metal Shading Language Specification,” Version 3.2, Section 6.9, “SIMD-group Matrix Functions.” Available at https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf.
2. Apple, “Metal Best Practices Guide: SIMD-groups.” A SIMD group on Apple GPUs is always 32 threads, analogous to NVIDIA’s warp size. See https://developer.apple.com/documentation/metal/compute_passes/creating_threads_and_threadgroups.
3. The simdgroup_matrix API was introduced in Metal 2.4 (iOS 15, macOS 12) and requires Apple GPU family 7 or later (M1+). Earlier Apple GPUs (A-series before A14) have smaller SIMD groups (8 or 16 threads) and do not support matrix operations. See https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf.
4. The actual hardware implementation likely uses a systolic array or similar structure within each GPU execution unit. Apple does not disclose microarchitectural details, but the instruction behavior (32 threads cooperatively computing an 8x8x8 MAC) is consistent with a small matrix engine per execution unit.
5. This is the fundamental insight behind the roofline model applied to LLM inference. For M=1 GEMV, the arithmetic intensity is approximately 1 FLOP/byte (2 FLOPs per weight element, ~2 bytes per weight for FP16). The roofline crossover point on Apple Silicon is typically at M=16-64, depending on the chip. See Williams et al., “Roofline: An Insightful Visual Performance Model for Multicore Architectures,” CACM 2009. https://doi.org/10.1145/1498765.1498785.
6. Apple, “Using Function Specialization to Build Pipeline Variants.” Function constants allow creating specialized versions of a shader function, enabling the compiler to optimize based on known constant values. See https://developer.apple.com/documentation/metal/using-function-specialization-to-build-pipeline-variants.
7. Dougall Johnson, “Apple GPU Architecture” (reverse-engineering documentation), 2022-2024. Johnson’s analysis shows that Apple GPU ALUs execute SIMD matrix instructions directly, without dedicated tensor hardware. See https://dougallj.github.io/applegpu/.
8. Apple, “What’s New in Metal,” WWDC24. M4 introduces native bfloat16 support in Metal, including simdgroup_bfloat8x8 matrix operations. See https://developer.apple.com/videos/play/wwdc2024/10220/.
Performance Optimization Patterns
We have covered the memory model and the matrix operations. Now we need to talk about how to make them fast. This chapter is a catalog of GPU performance patterns – the techniques that separate a naive kernel from one that saturates the hardware. Each pattern is explained in general terms, then grounded in how akunu applies it. If you have ever stared at a GPU profiler and wondered why your kernel is at 30% utilization, this chapter is for you.
The Roofline Model: Know Your Bottleneck
Before optimizing anything, you need to know what you are optimizing for. The roofline model gives you a simple framework: every kernel is either compute-bound or memory-bandwidth-bound.1
Performance (GFLOPS)
^
| ┌─────────────── Peak Compute
| /
| /
| / <-- Roofline
| /
| /
| /
| /
|──────────/───────────────────── Peak Bandwidth
| /
| /
| /
└──────┴──────────────────────────>
Arithmetic Intensity (FLOPS/byte)
Arithmetic intensity = (total FLOPs) / (total bytes transferred). It tells you how much compute work you do for each byte you read from memory.
- Low arithmetic intensity (e.g., GEMV with M=1): ~1 FLOP/byte. You read a weight, multiply once, move on. You are bandwidth-bound. Adding more compute units does not help.
- High arithmetic intensity (e.g., GEMM with M=512): ~512 FLOPs/byte. You read a weight and reuse it 512 times. You are compute-bound. More bandwidth does not help.
For Apple Silicon, the crossover point depends on the chip:
| Chip | Peak Compute (FP16) | Memory BW | Crossover (FLOPS/byte) |
|---|---|---|---|
| M1 | ~2.6 TFLOPS | 68 GB/s | ~38 |
| M1 Max | ~10.4 TFLOPS | 400 GB/s | ~26 |
| M4 | ~4.3 TFLOPS | 120 GB/s | ~36 |
| M4 Pro | ~8.7 TFLOPS | 273 GB/s | ~32 |
| M4 Max | ~17.4 TFLOPS | 546 GB/s | ~32 |
This means: if your kernel’s arithmetic intensity is below ~32, it is bandwidth-bound on most Apple Silicon. Every single GEMV in the decode path (arithmetic intensity ~1-2) is bandwidth-bound. Only prefill GEMM and attention (with long sequences) are compute-bound.
What This Means for Akunu
The roofline model dictates akunu’s optimization priorities:
- For decode (bandwidth-bound): Minimize bytes read. Use quantization. Fuse operations to avoid re-reading intermediates. Optimize for memory access patterns.
- For prefill (compute-bound): Maximize compute utilization. Use SIMD matrix operations. Maximize occupancy. Optimize threadgroup sizes.
- For attention (mixed): Short sequences are bandwidth-bound (small KV cache); long sequences become compute-bound.
Pattern 1: Vectorized Loads
GPU memory controllers deliver data in large chunks. On Apple Silicon, the memory bus is 128 or 256 bits wide. Reading a single float16 (2 bytes) wastes most of that bus width. Instead, you want to read 4 or 8 elements at once using vector types:
// Slow: scalar loads (2 bytes each)
half val0 = input[tid * 4 + 0];
half val1 = input[tid * 4 + 1];
half val2 = input[tid * 4 + 2];
half val3 = input[tid * 4 + 3];
// Fast: vectorized load (8 bytes at once)
half4 vals = *(device const half4 *)(input + tid * 4);
The vectorized version generates a single memory transaction instead of four. On Apple Silicon, half4 loads are the sweet spot – 8 bytes per load, which matches the register file width.2
In akunu’s GEMV kernels, you will see patterns like:
// Load 16 half values at once (32 bytes)
device const half *x = B_half + row * K + il * 16;
// Thread reads 16 consecutive halves via multiple half4 loads
The simd_gemm_f16.metal kernel has each thread load 16 consecutive half-precision values per K-step, using the thread’s position within the SIMD group to cover different portions of the tile.
Pattern 2: Coalesced Memory Access
Coalescing means that adjacent threads access adjacent memory locations, so the hardware can merge their loads into a single wide memory transaction.
Thread 0 reads address 0x1000
Thread 1 reads address 0x1002
Thread 2 reads address 0x1004
...
Thread 31 reads address 0x103E
=> Hardware merges into ONE 64-byte transaction
Uncoalesced access – where threads read scattered addresses – is catastrophic. Instead of one transaction, you get 32 individual transactions, each wasting bus bandwidth.
Thread 0 reads address 0x1000 // row 0, col 0
Thread 1 reads address 0x4000 // row 1, col 0
Thread 2 reads address 0x7000 // row 2, col 0
=> 32 separate transactions (32x slower)
How Akunu Ensures Coalescing
In akunu’s GEMV kernels, the weight matrix B is stored in row-major order with rows corresponding to output dimensions. Each SIMD group processes a contiguous block of rows. Within the K-reduction loop, threads within a SIMD group read contiguous elements along the K dimension:
Thread 0: B[row][k_base + 0..3]
Thread 1: B[row][k_base + 4..7]
Thread 2: B[row][k_base + 8..11]
...
This is perfectly coalesced along K. For the output (N dimension), different SIMD groups write to different output rows, which are widely separated – but writes are much less frequent than reads, so this is acceptable.
For quantized formats, coalescing is trickier. Q4_0 blocks are 18 bytes: 32 four-bit values packed into 16 bytes, plus a 2-byte FP16 scale. The block layout is designed so that adjacent threads can read adjacent blocks, maintaining coalescing despite the non-power-of-2 block size.
Pattern 3: Threadgroup Memory and Bank Conflicts
Threadgroup memory (also called shared memory on NVIDIA) is a fast scratchpad local to a threadgroup. On Apple Silicon, it is organized into banks – typically 32 banks of 4 bytes each.3
Bank conflicts occur when multiple threads in a SIMD group access the same bank simultaneously:
// Bank conflict: with 4-byte elements, a stride of 32 equals the bank count,
shmem[tid * 32] // so tid * 32 mod 32 == 0 — every thread hits bank 0!
// No conflict: threads access consecutive addresses
shmem[tid] // each thread hits a different bank
In akunu’s GEMM kernels, the A and B tiles are loaded into threadgroup memory. The tile layout is carefully chosen to avoid bank conflicts during the subsequent simdgroup_load operations. The sa (A tile) and sb (B tile) pointers are offset:
threadgroup half *sa = shmem;
threadgroup half *sb = shmem + 4096 / sizeof(half);
The A tile gets 4096 bytes (2048 halves) and the B tile gets 2048 bytes (1024 halves), for a total of 6144 bytes per threadgroup. These sizes are chosen to minimize bank conflicts when loading into SIMD matrix registers.
Pattern 4: Loop Unrolling
GPU compilers can unroll loops, but sometimes you need to help. Unrolling reduces loop overhead (branch instructions, index increments) and exposes instruction-level parallelism:
// Before: tight loop, branch every iteration
for (int i = 0; i < 8; i++) {
simdgroup_load(mb, sb + ..., stride);
simdgroup_multiply_accumulate(mc[i], ma, mb, mc[i]);
}
// After (compiler typically does this): 8 loads + 8 MACs, no branches
// The 'constexpr' loop bound helps the compiler unroll
In akunu’s GEMM kernel, the inner K-loop has a stride of NK=32, with 4 sub-steps of 8 elements each. The sub-steps process 8 output tiles per SIMD group. The compiler unrolls both the sub-step loop and the tile loop because the bounds are compile-time constants.
Pattern 5: Function Constants (Metal Specialization)
Metal’s function constants are a form of compile-time specialization that lets you create optimized kernel variants without code duplication:
constant uint FC_GEMM_K [[function_constant(10)]];
constant bool FC_GEMM_K_SPECIALIZED = is_function_constant_defined(FC_GEMM_K);
// In the kernel:
const uint K_dim = FC_GEMM_K_SPECIALIZED ? FC_GEMM_K : K;
When FC_GEMM_K is defined, the compiler knows the K dimension at compile time and can:
- Unroll the K-loop completely for small K
- Eliminate bounds checking
- Pre-compute strides and offsets
- Optimize register allocation
Akunu uses function constants extensively for MLX quantized kernels, where group_size and K are baked into the pipeline state object:
uint32_t fc_indices[] = {0, 1};
uint32_t fc_values[] = {(uint32_t)quant_group_size, (uint32_t)K};
pso = device.get_pipeline(kernel_name, cache_key, fc_indices, fc_values, 2);
Each unique (group_size, K) combination gets a specialized pipeline. These are cached in MetalDevice::pso_cache_ so the specialization cost is paid once at model load time.4
Pattern 6: Kernel Fusion
Kernel fusion combines multiple operations into a single GPU dispatch. Each dispatch has overhead: pipeline binding, buffer binding, dispatch command encoding, and GPU scheduling. More importantly, each dispatch boundary forces intermediate results to be written to and re-read from global memory.
Fusions in Akunu
Akunu applies several fusions:
1. Residual + RMSNorm Fusion
Instead of separate residual_add and rmsnorm dispatches:
// Unfused: 3 dispatches, 3 reads + 3 writes
residual = input + skip_connection // read input, skip; write residual
norm_input = rmsnorm(residual, weight) // read residual, weight; write norm_input
Akunu uses residual_rmsnorm_f16:
// Fused: 1 dispatch, reads input + skip + weight, writes norm_output + updated residual
residual_rmsnorm_f16(input, skip, weight, norm_output, residual, params)
This saves two kernel launches and two round-trips to global memory.
2. SiLU(gate) * up + Down GEMV Fusion
As discussed in the previous chapter, the fused SiLU+down kernel combines activation and projection:
// Unfused: 2 dispatches
act = SiLU(gate) * up // read gate, up; write act
down = GEMV(down_weight, act) // read act, down_weight; write down
// Fused: 1 dispatch
down = fused_silu_gemv(gate, up, down_weight) // reads gate, up, weight; writes down
This eliminates the intermediate act buffer write and re-read.
3. QK-Norm + RoPE + KV Cache Write Fusion
For architectures with per-head Q/K norms (Qwen3, Gemma), akunu fuses head normalization, rotary position encoding, and KV cache writes into a single kernel:
// From table_builder.cpp:
Pipeline fused_pso = device.get_pipeline("head_norm_rope_neox_kv_write_f16");
This replaces 3-4 separate dispatches with one, which is especially impactful because these operations are tiny (operating on a single head at a time) and the dispatch overhead would dominate.
4. QKV Projection Fusion
When Q, K, and V weight matrices share the same dtype and the SLC is large enough to benefit:
bool fuse_qkv = (chip.should_fuse_weights || is_mlx) && q_dtype == k_dtype && k_dtype == v_dtype;
if (fuse_qkv) {
Buffer fused_w = weights.fuse_weights(q_name, k_name, v_name);
gemv(scratch.residual, fused_w, scratch.qkv, 0, q_dtype, qkv_total, dim);
}
Three GEMV dispatches become one, reading the fused weight matrix once instead of three times.
5. Gate + Up Projection Fusion
Same principle applied to the FFN gate and up projections:
bool fuse_gate_up = (chip.should_fuse_weights || gate_is_mlx) && (gate_dtype == up_dtype);
if (fuse_gate_up) {
Buffer fused_gate_up_w = weights.fuse_weights(gate_name, up_name);
gemv(scratch.attn_out, fused_gate_up_w, scratch.ffn_gate, 0, gate_dtype, 2 * ffn_dim, dim);
}
When NOT to Fuse
Fusion is not always beneficial. The SLC-gated fusion decisions in akunu illustrate this: on chips with small SLC (M1 base, 8 MB estimated), the fused QKV weight matrix may be too large to fit in cache, causing more cache thrashing than the unfused version. The ChipConfig::should_fuse_weights flag controls this:
c.should_fuse_weights = (c.slc_bytes >= 16 * 1024 * 1024); // Pro+ and M4 Base
Pattern 7: Occupancy and Threadgroup Sizing
Occupancy is the ratio of active threads to the maximum the GPU can support simultaneously. Higher occupancy generally means better latency hiding – when one SIMD group stalls on a memory access, another can execute.
On Apple Silicon, each GPU core can run multiple threadgroups concurrently (the exact limit depends on register pressure and threadgroup memory usage). The threadgroup size directly affects occupancy:
| Threadgroup Size | SIMD Groups | Typical Use |
|---|---|---|
| 32 | 1 | Very light kernels (argmax) |
| 128 | 4 | Standard GEMV (4 SIMD groups) |
| 256 | 8 | Wide GEMV, standard GEMM |
| 1024 | 32 | Flash attention, RMSNorm |
Akunu’s ChipConfig controls threadgroup sizing for normalization kernels:
c.norm_tg_size = 1024; // max threads for RMSNorm
c.max_threads_per_tg = 1024; // Metal's maximum
For RMSNorm on a 4096-dimensional model, the threadgroup size is min(4096, 1024) = 1024. All 1024 threads participate in the reduction (computing the root-mean-square), with the final result broadcast to all threads for the normalization step.
For GEMV, the threadgroup size trades off between parallelism and overhead:
- 128 threads (4 SIMD groups): Standard GEMV. Good for small chips with limited cores.
- 256 threads (8 SIMD groups): Wide GEMV. Better occupancy on Pro+ chips with many cores. Controlled by chip.gemv_wide_standard.
Pattern 8: Avoiding Redundant Work with Pre-Computation
The most efficient computation is the one you do not do. Akunu pre-computes everything possible at model load time:
- Pipeline state objects are created and cached in pso_cache_. No pipeline compilation during inference.
- Buffer bindings are resolved once in the dispatch table. The DispatchCmd stores actual Buffer handles and byte offsets, not symbolic names.
- Kernel parameters (dimensions, epsilon values, strides) are stored as raw bytes in DispatchCmd::param_bytes[64]. For static params, a GPU buffer is pre-allocated:
// From table_builder.cpp, end of build_dispatch_table():
for (auto& cmd : cmds) {
if (cmd.param_size > 0) {
cmd.param_buf = device.allocate(cmd.param_bytes, cmd.param_size);
}
}
This means the dispatch table replay loop does almost zero work per command beyond Metal API calls. No string lookups, no hash table queries, no conditional logic.
- RoPE frequencies can be pre-computed into a GPU buffer (arch.rope_freqs) rather than computed per-token.
- Dispatch geometry (grid size, threadgroup size) is computed once and stored in DispatchCmd::grid and DispatchCmd::threadgroup.
Pattern 9: Minimizing Command Buffer Overhead
Each Metal command buffer has submission overhead: the CPU must package the commands, the GPU command processor must parse them, and there is a synchronization cost at completion. Akunu minimizes this in three ways:
1. Chain Decode: Many Tokens Per Command Buffer
Instead of one command buffer per token, akunu’s chain decode batches multiple tokens into a single submission:
// ChipConfig determines batch size:
c.chain_decode_chunk = 128; // M4: 128 tokens per submission
c.chain_decode_chunk = 64; // M1 base: 64 tokens per submission
One command buffer encodes the full dispatch table N times (once per token), with per-token patching of position and offsets. This amortizes the command buffer overhead across many tokens.
2. Unretained References
Akunu uses commandBufferWithUnretainedReferences, which tells Metal not to retain buffer references. This avoids atomic reference counting on every setBuffer call – a significant savings when a single command buffer contains thousands of buffer bindings.5
3. Event-Based Pipelining
Akunu supports overlapping GPU execution with CPU encoding using MTLSharedEvent:
// GPU signals event after completing
[cmdBuffer encodeSignalEvent:pipelineEvent value:signalVal];
[cmdBuffer commit];
// Next command buffer waits for the event (GPU-GPU sync, no CPU involvement)
[nextCmdBuffer encodeWaitForEvent:pipelineEvent value:eventValue];
This allows the CPU to encode the next batch of tokens while the GPU is still processing the current one. The event-based synchronization is GPU-to-GPU, avoiding a CPU round-trip.
Pattern 10: The setBuffer vs setBytes Split
This pattern is specific to akunu’s dispatch table replay and deserves its own section. During chain decode, the same commands are repeated for each token. Most parameters are identical across tokens – only the position-dependent fields change.
Akunu splits parameters into two categories:
| Category | Mechanism | Per-Token Cost | Example |
|---|---|---|---|
| Static params | setBuffer() with pre-allocated GPU buffer | Zero (buffer already bound) | GEMV dimensions, strides |
| Position-patched params | setBytes() with inline patching | ~64 bytes memcpy + encode | RoPE position, KV seq length |
This is visible in the encode_chain fast path. For each command, the encoder patches position-dependent parameters in-place and calls set_bytes to inline the (small, <64 byte) parameter data into the command buffer:
for (int i = 0; i < n; i++) {
for (auto& cmd : table.commands) {
device.set_pipeline(cmd.pso);
// Bind buffers (static — same every token)
for (int b = 0; b < cmd.buffer_count; b++)
device.set_buffer(cmd.buffers[b], cmd.offsets[b], b);
// Patch and set parameters (only position/kv_len change per token)
if (cmd.patch_type != PATCH_NONE)
apply_patch(cmd, pos + i);
device.set_bytes(cmd.param_bytes, cmd.param_size, cmd.param_index);
device.dispatch(cmd.grid, cmd.threadgroup);
}
}
For a typical 32-layer model, the dispatch table has ~260 commands. Only the RoPE, attention, and argmax commands need per-token patching (~100 commands). The rest pass through with unchanged param_bytes — but all use set_bytes (not setBuffer) since the parameter data is always small enough (<64 bytes) to inline.
Profiling Tools
Knowing these patterns is only useful if you can measure their impact. Apple provides several profiling tools:
Xcode GPU Debugger
Capture a Metal frame and inspect:
- Per-dispatch GPU time
- Memory bandwidth utilization
- Occupancy
- Wait time (stalls)
Metal System Trace (Instruments)
Part of Instruments.app. Shows:
- Command buffer submission and completion timeline
- GPU utilization over time
- CPU-GPU synchronization points
- Memory allocation events
akunu_profile
Akunu includes a profiling tool that uses per-layer command buffers to get GPU timing for each operation:
// From akunu.h:
int akunu_profile_decode_step(akunu_model_t model, uint32_t token_id,
int position, float *timing_out, int max_entries);
const char *akunu_profile_label(akunu_model_t model, int index);
This runs each layer in its own command buffer (much slower than normal inference) but gives you per-operation GPU timing. The output looks like:
embedding 0.012 ms
layer.0.qkv_proj 0.045 ms
layer.0.rope_kv_write 0.008 ms
layer.0.attention 0.082 ms
layer.0.o_proj 0.041 ms
layer.0.fused_ffn_norm 0.006 ms
layer.0.gate_up_proj 0.078 ms
layer.0.ffn 0.043 ms
...
Counter Sampling
Metal supports GPU hardware counter sampling for detailed performance analysis. You can measure:
- ALU utilization
- Memory read/write bytes
- Cache hit rates
- Occupancy percentages
These are available through MTLCounterSampleBuffer and are essential for diagnosing whether a kernel is compute-bound or bandwidth-bound.6
Putting It All Together: A Single Layer’s Performance Profile
Let’s trace through one transformer layer during decode (M=1) on M4 Pro and identify the bottleneck for each operation:
| Operation | Kernel | Bottleneck | Time | Notes |
|---|---|---|---|---|
| Attention Norm | residual_rmsnorm_f16 | BW (dim reads) | 0.006 ms | Light: 4096 elements |
| Q/K/V GEMV (fused) | gemv_q4_0 | BW (weight read) | 0.12 ms | Reads ~6 MB fused QKV weight |
| RoPE + KV Write | rope_qkv_write_f16 | BW | 0.008 ms | Light: head_dim/2 elements |
| Flash Attention | flash_attn_decode_parallel | Mixed | 0.04-0.1 ms | Depends on seq length |
| O Projection | gemv_q4_0 | BW | 0.04 ms | 4096x4096 weight |
| FFN Norm | residual_rmsnorm_f16 | BW | 0.006 ms | Light |
| Gate+Up GEMV (fused) | gemv_q4_0 | BW | 0.15 ms | Reads ~12 MB fused gate+up |
| Fused SiLU+Down | gemv_q4_0_silu | BW | 0.04 ms | 4096x14336 weight |
| Total per layer | | | ~0.43 ms | |
For 32 layers: ~13.8 ms/token = ~72 tok/s. This matches the theoretical bandwidth limit of ~70 tok/s we computed in the memory chapter to within the error of these per-operation estimates, confirming that decode is bandwidth-saturated.
Common Mistakes
Mistake 1: Optimizing Compute for a Bandwidth-Bound Kernel
If your GEMV kernel is at 95% bandwidth utilization and 10% compute utilization, making the math faster will not help. You need to reduce the number of bytes read (quantize to lower bits, fuse operations to eliminate intermediate buffers).
Mistake 2: Tiny Threadgroups
Using a threadgroup of 32 threads for a GEMV kernel means only 1 SIMD group per threadgroup. The GPU core has no threads to switch to when this SIMD group stalls on memory. Use at least 128 threads (4 SIMD groups) for any memory-heavy kernel.
Mistake 3: Forgetting Threadgroup Barriers
In GEMM kernels that use threadgroup memory, forgetting threadgroup_barrier(mem_flags::mem_threadgroup) between writing to and reading from shared memory causes data races. The barrier ensures all threads in the threadgroup have completed their writes before any thread reads.
Mistake 4: Over-Fusing
Fusing too many operations into one kernel can increase register pressure, reducing occupancy and hurting performance. If a fused kernel needs more registers than the hardware can provide, the GPU will “spill” registers to memory, destroying performance. The separate activation + GEMV fallback in akunu exists for exactly this reason.
Summary
Performance optimization on Apple Silicon GPU follows a clear decision tree:
- Is the kernel bandwidth-bound or compute-bound? Use the roofline model.
- If bandwidth-bound: Reduce bytes (quantize, fuse), improve access patterns (coalesce, vectorize).
- If compute-bound: Maximize utilization (occupancy, SIMD matrix ops, loop unrolling).
- Always: Pre-compute everything possible, minimize dispatch overhead, use function constants for specialization.
Akunu applies all of these patterns systematically, with the ChipConfig and DTypeDescriptor tables encoding the chip-specific and dtype-specific tuning decisions. The dispatch table pre-resolves everything at model load time so the hot path is a tight loop of Metal API calls.
-
Williams, S., Waterman, A., & Patterson, D. (2009). “Roofline: An Insightful Visual Performance Model for Multicore Architectures.” Communications of the ACM, 52(4), 65-76. The roofline model remains the most effective tool for classifying GPU kernel performance. See https://doi.org/10.1145/1498765.1498785. ↩
-
Apple, “Metal Best Practices Guide: Optimize Memory Accesses,” 2024. Vectorized loads are recommended for maximizing memory throughput. See https://developer.apple.com/documentation/xcode/analyzing-the-performance-of-your-metal-app. ↩
-
The exact bank configuration on Apple Silicon is not officially documented. The 32-bank / 4-byte-per-bank configuration is inferred from performance profiling and is consistent with other GPU architectures. See Dougall Johnson’s Apple GPU documentation: https://dougallj.github.io/applegpu/. ↩
-
Apple, “Using Function Specialization to Build Pipeline Variants.” Function constants are the recommended way to create specialized shader variants without preprocessor macros. See https://developer.apple.com/documentation/metal/using-function-specialization-to-build-pipeline-variants. ↩
-
Apple, “commandBufferWithUnretainedReferences Documentation.” Unretained references eliminate atomic retain/release overhead but require manual lifetime management. See https://developer.apple.com/documentation/metal/mtlcommandqueue/1508684-makecommandbufferwithunretainedr. ↩
-
Apple, “Optimizing Performance with the GPU Counters Instrument,” 2024. GPU hardware counters provide per-kernel metrics including ALU utilization, memory bandwidth, and cache hit rates. See https://developer.apple.com/documentation/xcode/analyzing-the-performance-of-your-metal-app. ↩
Tensors and Linear Algebra on GPU
If you have spent any time reading ML papers or browsing inference codebases, you have seen the word “tensor” more times than you can count. It gets thrown around so casually that you might suspect it means something terribly complicated. It does not. At its heart, a tensor in the machine learning context is just a multi-dimensional array of numbers. That is it. A 1D tensor is a vector. A 2D tensor is a matrix. A 3D tensor is… well, a 3D array. And so on.
But here is the thing: while the concept is simple, the implementation on a GPU is where all the interesting engineering lives. How do you lay out a 4-dimensional array in a flat slab of GPU memory? How do you make sure that when 3,000 threads all try to read different elements at once, the memory system does not choke? How do you express broadcasting rules in terms of pointer arithmetic? These are the questions that determine whether your inference engine runs at 10 tokens per second or 100.
In this chapter, we will build your intuition for tensors from the ground up, then show exactly how they map to Metal GPU buffers. By the end, you will understand the memory layout decisions that dominate every kernel we write in later chapters.
What Is a Tensor, Really?
Let us start with the basics and build up.
Scalars, Vectors, Matrices, and Beyond
Rank 0 (Scalar): 42.0
Rank 1 (Vector): [1.0, 2.0, 3.0, 4.0]
Rank 2 (Matrix): [[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0]]
Rank 3 (3D Tensor): [[[1.0, 2.0],
[3.0, 4.0],
[5.0, 6.0]],
[[7.0, 8.0],
[9.0, 10.0],
[11.0, 12.0]]]
The rank (or order) of a tensor is simply how many indices you need to pick out a single element:
- Rank 0: zero indices. It is just a number.
- Rank 1: one index. `v[i]` gives you element `i` of a vector.
- Rank 2: two indices. `M[i][j]` gives you the element at row `i`, column `j`.
- Rank 3: three indices. `T[i][j][k]` picks one element from a 3D block.
The shape describes how many elements exist along each dimension. For the rank-3 tensor above, the shape is (2, 3, 2) – two “slices,” each containing 3 rows of 2 elements.
Tensors in Transformers
In a typical transformer model, you encounter tensors of various ranks constantly:
Token embeddings: shape (seq_len, d_model) Rank 2
Weight matrix: shape (d_model, d_out) Rank 2
Attention scores: shape (n_heads, seq_len, seq_len) Rank 3
Batched activations: shape (batch, seq_len, d_model) Rank 3
Multi-head QKV: shape (batch, n_heads, seq, d_k) Rank 4
For inference on a single sequence (batch size 1), we can often drop the batch dimension. But even without batching, we regularly work with rank-2 and rank-3 tensors.
Here is a concrete example. Suppose we have a small model with:
- `d_model = 512` (embedding dimension)
- `n_heads = 8` (number of attention heads)
- `d_k = 64` (per-head dimension, since 512 / 8 = 64)
- `seq_len = 128` (sequence length)
The query tensor Q after projection would have shape (128, 512), and after reshaping for multi-head attention, it becomes (8, 128, 64).
Memory Layouts: How Multi-Dimensional Data Lives in Flat Memory
Here is the fundamental problem: GPU memory (and CPU memory, for that matter) is a flat, one-dimensional address space. You have addresses 0, 1, 2, 3, and so on. But your tensor is multi-dimensional. So you need a rule for mapping multi-dimensional indices to flat memory addresses.
Row-Major Order (C-Style)
The most common layout, and the one used by virtually all ML frameworks and inference engines, is row-major order. The idea is simple: you lay out the last dimension contiguously, then the second-to-last, and so on.
For a 2D matrix with shape (3, 4):
Logical view: Memory layout (row-major):
col 0 col 1 col 2 col 3
row 0 [ a b c d ] addr 0: a
row 1 [ e f g h ] addr 1: b
row 2 [ i j k l ] addr 2: c
addr 3: d
addr 4: e
addr 5: f
addr 6: g
addr 7: h
addr 8: i
addr 9: j
addr 10: k
addr 11: l
The elements of row 0 are contiguous in memory (a, b, c, d), followed by row 1 (e, f, g, h), followed by row 2 (i, j, k, l). The last index changes fastest as you walk through memory.
To find element M[i][j] in a (R, C) matrix:
address(i, j) = base + (i * C + j) * element_size
Column-Major Order (Fortran-Style)
The alternative is column-major order, where the first dimension is contiguous. BLAS libraries and MATLAB use this. The same (3, 4) matrix in column-major:
Memory layout (column-major):
addr 0: a (row 0, col 0)
addr 1: e (row 1, col 0)
addr 2: i (row 2, col 0)
addr 3: b (row 0, col 1)
addr 4: f (row 1, col 1)
addr 5: j (row 2, col 1)
addr 6: c (row 0, col 2)
addr 7: g (row 1, col 2)
addr 8: k (row 2, col 2)
addr 9: d (row 0, col 3)
addr 10: h (row 1, col 3)
addr 11: l (row 2, col 3)
Here the first index changes fastest. To find element M[i][j]:
address(i, j) = base + (j * R + i) * element_size
Why Does This Matter for GPU Performance?
Consider what happens when 32 threads in a SIMD group all need to read one element each. If thread t reads element M[row][t] and the matrix is row-major, all 32 threads read from consecutive memory addresses. The GPU’s memory system can satisfy this with a single coalesced memory transaction.
Row-major, threads reading M[row][0..31]:
Memory: [ M[r][0] | M[r][1] | M[r][2] | ... | M[r][31] | M[r][32] | ... ]
^ ^ ^ ^
thread 0 thread 1 thread 2 thread 31
All addresses are contiguous --> ONE memory transaction (coalesced)
But if the matrix were column-major and threads tried to read M[row][0..31], each thread would be reading from addresses that are R elements apart:
Column-major, threads reading M[row][0..31]:
Memory: [ M[0][0] | M[1][0] | M[2][0] | ... | M[0][1] | M[1][1] | ... ]
Thread 0 reads M[row][0] at offset row
Thread 1 reads M[row][1] at offset R + row
Thread 2 reads M[row][2] at offset 2R + row
...
Addresses are R elements apart --> MANY memory transactions (strided)
This can be 10-30x slower depending on the stride. Coalesced access is one of the single most important performance considerations on any GPU.
Strides: The General Case
The concept of strides generalizes memory layout. A stride tells you how many elements to skip in memory when you increment one index by one.
For a row-major matrix of shape (R, C):
- Stride along dimension 0 (rows): `C` (skip one whole row of `C` elements)
- Stride along dimension 1 (columns): `1` (elements are adjacent)
For a row-major 3D tensor of shape (D0, D1, D2):
- Stride along dimension 0: `D1 * D2`
- Stride along dimension 1: `D2`
- Stride along dimension 2: `1`
The general address formula for any tensor:
address(i0, i1, ..., in) = base + sum(ik * stride_k) for k in 0..n
Here is a concrete example. For a tensor of shape (2, 3, 4) in row-major:
Strides: (12, 4, 1)
Because:
stride[2] = 1
stride[1] = shape[2] = 4
stride[0] = shape[1] * shape[2] = 3 * 4 = 12
Element T[1][2][3]:
address = 0 + 1*12 + 2*4 + 3*1 = 12 + 8 + 3 = 23
Let us verify by counting:
Slice 0: [[ 0 1 2 3] elements 0-11
[ 4 5 6 7]
[ 8 9 10 11]]
Slice 1: [[12 13 14 15] elements 12-23
[16 17 18 19]
[20 21 22 23]] <-- T[1][2][3] = element 23. Correct!
Strides also let you express interesting things like transposition without copying data. If you have a matrix with strides (C, 1), its transpose has strides (1, C) – same data, just different stride interpretation.
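The stride arithmetic is easy to verify on the CPU. Here is a small self-contained C++ sketch (helper names are ours, not akunu's) that derives row-major strides from a shape and computes flat offsets with the general address formula:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Derive row-major strides: the last dimension has stride 1, and each
// earlier stride is the product of all later extents.
std::vector<size_t> row_major_strides(const std::vector<size_t>& shape) {
    std::vector<size_t> strides(shape.size(), 1);
    for (int k = (int)shape.size() - 2; k >= 0; k--)
        strides[k] = strides[k + 1] * shape[k + 1];
    return strides;
}

// address(i0, i1, ..., in) = sum(ik * stride_k)
size_t flat_index(const std::vector<size_t>& idx,
                  const std::vector<size_t>& strides) {
    size_t off = 0;
    for (size_t k = 0; k < idx.size(); k++)
        off += idx[k] * strides[k];
    return off;
}
```

For the (2, 3, 4) tensor above, `row_major_strides({2, 3, 4})` yields (12, 4, 1), and `flat_index({1, 2, 3}, {12, 4, 1})` lands on element 23, matching the count-by-hand verification. Swapping the two strides of a matrix gives the zero-copy transpose.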
How Tensors Map to GPU Buffers
On Metal, a tensor lives in an MTLBuffer – a contiguous block of GPU-accessible memory. There is no built-in “tensor” type in Metal. The buffer is just bytes. Your shader code interprets those bytes as a tensor by computing offsets from indices using the stride formula.
The Buffer Layout
MTLBuffer (contiguous memory):
+----+----+----+----+----+----+----+----+----+----+----+----+
| e0 | e1 | e2 | e3 | e4 | e5 | e6 | e7 | e8 | e9 |e10 |e11 |
+----+----+----+----+----+----+----+----+----+----+----+----+
^ ^
buffer.contents() buffer end
Each element is sizeof(element_type) bytes:
float = 4 bytes
half = 2 bytes
int8_t = 1 byte
For a half-precision (2, 3, 2) tensor:
Total elements = 2 * 3 * 2 = 12
Total bytes = 12 * 2 = 24 bytes
Computing Offsets in a Metal Shader
In a Metal compute shader, you typically receive the buffer pointer and the tensor’s metadata (shape, strides) as arguments. Here is what the pattern looks like:
kernel void elementwise_relu(
device const half* input [[buffer(0)]],
device half* output [[buffer(1)]],
constant uint3& shape [[buffer(2)]], // (D0, D1, D2)
constant uint3& strides [[buffer(3)]], // (S0, S1, S2)
uint3 gid [[thread_position_in_grid]]
) {
// Bounds check
if (gid.x >= shape.x || gid.y >= shape.y || gid.z >= shape.z) return;
// Compute flat index from multi-dimensional position
uint idx = gid.x * strides.x + gid.y * strides.y + gid.z * strides.z;
// Read, compute, write
half val = input[idx];
output[idx] = val > 0 ? val : 0;
}
The key insight: the GPU hardware knows nothing about tensors. It only knows about flat memory addresses. All the multi-dimensional indexing is just arithmetic that we do in the shader.
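The same index arithmetic can be mirrored on the CPU. This is a minimal C++ reference for the shader above (float instead of half, function name ours) that walks the grid exactly as the GPU threads would:

```cpp
#include <cstdint>

// CPU reference for elementwise_relu: one "thread" per (x, y, z) position,
// with the flat index computed from the strides, just like the shader.
void elementwise_relu_ref(const float* input, float* output,
                          const uint32_t shape[3], const uint32_t strides[3]) {
    for (uint32_t x = 0; x < shape[0]; x++)
        for (uint32_t y = 0; y < shape[1]; y++)
            for (uint32_t z = 0; z < shape[2]; z++) {
                uint32_t idx = x * strides[0] + y * strides[1] + z * strides[2];
                float v = input[idx];
                output[idx] = v > 0.0f ? v : 0.0f;
            }
}
```

A reference like this is also how you validate a Metal kernel in tests: run both on the same buffer and compare outputs element by element.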
Alignment Considerations
Metal has specific alignment requirements for buffer access:
Type Size Alignment
float 4B 4B
half 2B 2B
float4 16B 16B
half4 8B 8B
Rule of thumb: access addresses that are multiples of the type size.
When packing tensor data into buffers, you generally want each row to start at an aligned address. For a matrix of half values with 127 columns, each row takes 254 bytes. The next row starts at byte 254, which is 2-byte aligned (fine for half). But if you wanted to read rows as half4 vectors (common optimization), you would want 8-byte alignment, which means padding rows to 128 elements (256 bytes).
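The padding rule can be written as a one-liner. A hedged C++ sketch (helper name ours) that rounds a row length up to the alignment needed for vectorized loads:

```cpp
#include <cstddef>

// Round a row up so each row starts at an address aligned for vectorized
// loads. align_bytes is the vector alignment (8 for half4, 16 for float4).
size_t padded_row_elems(size_t cols, size_t elem_size, size_t align_bytes) {
    size_t row_bytes = cols * elem_size;
    size_t padded = (row_bytes + align_bytes - 1) / align_bytes * align_bytes;
    return padded / elem_size;
}
```

For the example in the text, 127 half columns padded for `half4` access becomes 128 elements (256 bytes) per row.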
Multiple Tensors in One Buffer (Offset Packing)
In practice, inference engines often pack multiple tensors into a single large buffer with offsets:
Single MTLBuffer:
+------------------+------------------+------------------+
| Weight matrix | Bias vector | LayerNorm scale |
| (4096 x 4096) | (4096) | (4096) |
| offset: 0 | offset: 33MB | offset: 33MB+8K |
+------------------+------------------+------------------+
This reduces the number of buffer bindings needed (Metal has a limit on how many buffers you can bind to a single kernel invocation) and can improve memory allocation efficiency.
Common Tensor Operations
Now let us look at the operations that transformers actually perform on tensors and how they parallelize on a GPU.
Element-wise Operations
These apply a function independently to each element. The output has the same shape as the input. Examples:
- ReLU: `output[i] = max(0, input[i])`
- SiLU (Swish): `output[i] = input[i] * sigmoid(input[i])`
- GELU: `output[i] = input[i] * 0.5 * (1 + erf(input[i] / sqrt(2)))`
- Addition: `output[i] = a[i] + b[i]`
- Scalar multiply: `output[i] = alpha * input[i]`
Parallelization is trivial – assign one thread per element:
Input tensor: [a0, a1, a2, a3, a4, a5, a6, a7]
Thread 0 --> processes a0
Thread 1 --> processes a1
Thread 2 --> processes a2
...
Thread 7 --> processes a7
Each thread:
1. Read input[thread_id]
2. Apply function
3. Write output[thread_id]
No synchronization needed! Each thread is independent.
For a tensor with N elements, you dispatch ceil(N / threads_per_threadgroup) threadgroups, each with threads_per_threadgroup threads (commonly 256 or 1024).
Dispatch for N = 10000, threads_per_group = 256:
Threadgroups needed = ceil(10000 / 256) = 40 threadgroups
Threadgroup 0: threads 0-255 --> elements 0-255
Threadgroup 1: threads 256-511 --> elements 256-511
...
Threadgroup 38: threads 9728-9983 --> elements 9728-9983
Threadgroup 39: threads 9984-10239 --> elements 9984-9999 (rest out of bounds)
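The dispatch math is just a ceiling division, sketched here in C++ (helper name ours):

```cpp
#include <cstdint>

// Number of threadgroups needed to cover n elements, rounding up so the
// last (partially out-of-bounds) group is included.
uint32_t threadgroups_for(uint32_t n, uint32_t threads_per_group) {
    return (n + threads_per_group - 1) / threads_per_group;
}
```

The last threadgroup overshoots, which is why the kernel's bounds check (`if (gid >= n) return;`) is not optional.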
Reduction Operations
Reductions collapse one or more dimensions by aggregating elements. Examples:
- Sum along a dimension: `output[i] = sum(input[i][j] for all j)`
- Max along a dimension: `output[i] = max(input[i][j] for all j)`
- Mean: sum divided by count
These are trickier to parallelize because threads need to cooperate to produce the result.
Strategy 1: One threadgroup per reduction
Matrix shape (4, 8), reduce along columns (dim 1):
Row 0: [a0 a1 a2 a3 a4 a5 a6 a7] --> sum = a0+a1+...+a7
Row 1: [b0 b1 b2 b3 b4 b5 b6 b7] --> sum = b0+b1+...+b7
Row 2: [c0 c1 c2 c3 c4 c5 c6 c7] --> sum = c0+c1+...+c7
Row 3: [d0 d1 d2 d3 d4 d5 d6 d7] --> sum = d0+d1+...+d7
Assign one threadgroup per row.
Within each threadgroup, threads cooperate to sum the row:
Threadgroup 0 (8 threads for row 0):
Step 1: Each thread loads one element
t0=a0, t1=a1, t2=a2, t3=a3, t4=a4, t5=a5, t6=a6, t7=a7
Step 2: Parallel reduction tree
t0 = a0+a4 t1 = a1+a5 t2 = a2+a6 t3 = a3+a7
t0 = (a0+a4)+(a2+a6) t1 = (a1+a5)+(a3+a7)
t0 = total sum
Step 3: Thread 0 writes result
Strategy 2: SIMD group reduction
On Metal, SIMD groups of 32 threads have special fast operations for reduction. The simd_sum() function sums a value across all threads in the SIMD group with dedicated hardware support, far faster than a reduction tree through threadgroup memory:
SIMD group reduction (32 threads):
Before simd_sum:
t0=3.0 t1=1.0 t2=4.0 t3=1.0 ... t31=2.0
After val = simd_sum(val):
t0=sum t1=sum t2=sum t3=sum ... t31=sum
(all threads have the same total sum)
For reductions over dimensions larger than 32, you use a combination: each SIMD group reduces its portion, then results are combined using threadgroup memory.
Broadcasting
Broadcasting lets you operate on tensors of different shapes by “stretching” the smaller tensor to match the larger one, conceptually. No data is actually copied.
Example: Add a bias vector to every row of a matrix
Matrix A shape: (1024, 512)
Bias b shape: (512,)
Result C[i][j] = A[i][j] + b[j] for all i, j
The bias is "broadcast" along dimension 0:
A: b: C:
[[a00 a01 ... a0,511] [b0 b1 ... b511] [[a00+b0 a01+b1 ...]
[a10 a11 ... a1,511] + = [a10+b0 a11+b1 ...]
... ...
[a1023,0 ... ]] [a1023,0+b0 ... ]]
In the shader, broadcasting is implemented by simply not advancing the index along the broadcast dimension:
// Adding bias (shape: [N]) to matrix (shape: [M, N])
kernel void add_bias(
    device const half* matrix [[buffer(0)]],
    device const half* bias   [[buffer(1)]],
    device half* output       [[buffer(2)]],
    constant uint& N          [[buffer(3)]], // number of columns
    uint2 gid [[thread_position_in_grid]]    // (row, col)
) {
    uint row = gid.x;
    uint col = gid.y;
    uint idx = row * N + col;
    output[idx] = matrix[idx] + bias[col];   // bias indexed only by col
}
The bias[col] access ignores the row entirely. That is broadcasting – the same bias element b[j] is used for every row.
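A CPU reference makes the broadcast explicit (function name ours): the row index simply never touches the bias.

```cpp
// CPU reference for bias broadcasting: bias[col] is reused for every row.
// Row-major layout, so element (row, col) lives at offset row * N + col.
void add_bias_ref(const float* matrix, const float* bias, float* output,
                  int M, int N) {
    for (int row = 0; row < M; row++)
        for (int col = 0; col < N; col++)
            output[row * N + col] = matrix[row * N + col] + bias[col];
}
```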
Parallelizing Tensor Operations
The general strategies for mapping tensor operations to GPU threads fall into a few patterns. Let us formalize them.
Pattern 1: One Thread Per Element
Used for element-wise operations.
Tensor shape: (M, N)
Grid: (M, N)
Each thread (i, j) processes element [i][j]
+-------+-------+-------+-------+
| t(0,0)| t(0,1)| t(0,2)| t(0,3)|
+-------+-------+-------+-------+
| t(1,0)| t(1,1)| t(1,2)| t(1,3)|
+-------+-------+-------+-------+
| t(2,0)| t(2,1)| t(2,2)| t(2,3)|
+-------+-------+-------+-------+
Dispatch: grid_size = (M, N, 1)
threadgroup_size = (16, 16, 1) // 256 threads per group
Pattern 2: One Threadgroup Per Row (or Column)
Used for reductions, softmax, layer normalization – anything that operates across one dimension.
Tensor shape: (M, N)
Grid: M threadgroups, each with T threads
Threadgroup 0 --> Row 0: [e0 e1 e2 ... eN-1]
Threadgroup 1 --> Row 1: [e0 e1 e2 ... eN-1]
...
Threadgroup M-1 --> Row M-1
Within each threadgroup, T threads cooperate:
Thread t handles elements: t, t+T, t+2T, ...
For N=1024, T=256:
Thread 0: elements 0, 256, 512, 768
Thread 1: elements 1, 257, 513, 769
...
Pattern 3: Tiled Processing
Used for matrix multiplication and attention. Threads in a threadgroup cooperatively load tiles into fast threadgroup memory, then compute.
Matrix multiply C = A * B
A: (M, K), B: (K, N), C: (M, N)
Tile size: (Tm, Tn) per threadgroup, iterate over K in tiles of Tk
Threadgroup (gi, gj) computes C[gi*Tm..(gi+1)*Tm][gj*Tn..(gj+1)*Tn]
B (K x N)
+---+---+---+---+
| | Bj| | | Bj = tile column j
+---+---+---+---+
| | Bj| | |
A +---+---+---+---+
(M x K)
+--+--+ +---+
| |Ai| x | Bj| = Cij (Tm x Tn tile of output)
+--+--+ +---+
| | |
+--+--+
Ai = tile row i of A
For each k-tile: load Ai[:,k:k+Tk] and Bj[k:k+Tk,:] into shared memory
compute partial products
accumulate into Cij registers
We will explore tiled matrix multiplication in enormous detail in the next chapter.
Pattern 4: SIMD Group per Work Unit
Metal’s SIMD groups (wavefronts) of 32 threads have hardware support for fast intra-group communication. Many kernels assign one SIMD group to one logical unit of work:
GEMV: One SIMD group per output element (or a few elements)
Output vector y = W * x (W is M x K, x is K x 1, y is M x 1)
SIMD group 0 --> y[0] = dot(W[0,:], x)
SIMD group 1 --> y[1] = dot(W[1,:], x)
...
Within each SIMD group:
32 threads divide the K-length dot product
Thread t handles indices: t, t+32, t+64, ...
Each thread accumulates a partial sum
simd_sum() gives the total
The Transformer’s Core Operations
A transformer model is built from a surprisingly small set of tensor operations, repeated many times. Let us catalog them and understand their computational characteristics.
Linear Projection (GEMM / GEMV)
The workhorse of the transformer. Every layer has multiple linear projections:
In one transformer block:
Q = X * Wq + bq (query projection)
K = X * Wk + bk (key projection)
V = X * Wv + bv (value projection)
O = Attn * Wo + bo (output projection)
H = X * W1 + b1 (FFN first layer)
Y = H * W2 + b2 (FFN second layer)
That's 6 matrix multiplications per layer!
A 32-layer model does 192 matmuls per forward pass.
The computation is Y = X * W where:
- During prefill: X is `(seq_len, d_model)`, W is `(d_model, d_out)` – a full GEMM
- During decode: X is `(1, d_model)`, W is `(d_model, d_out)` – a GEMV
Prefill (GEMM): Decode (GEMV):
X (128 x 4096) x (1 x 4096)
+----------+ +----------+
| | | |
| | W (4096 x 4096) +----------+ W (4096 x 4096)
| | +----------+ +----------+
| | x | | x | |
| | | | | |
| | | | | |
+----------+ +----------+ +----------+
= =
Y (128 x 4096) y (1 x 4096)
+----------+ +----------+
| | | |
| | +----------+
| |
+----------+
The GEMV case (decode) is memory-bandwidth bound because you read the entire weight matrix just to produce one output row. The GEMM case (prefill) is compute bound because you amortize the weight read across many input rows.
Attention (Batched Dot Products + Softmax)
The attention mechanism computes, for each query position, a weighted average of all value positions:
scores = Q * K^T / sqrt(d_k) Matrix multiply: (seq, d_k) x (d_k, seq) = (seq, seq)
weights = softmax(scores) Element-wise + reduction
output = weights * V Matrix multiply: (seq, seq) x (seq, d_k) = (seq, d_k)
The middle step (softmax) is a reduction along the last dimension of each row, which requires special handling. We will devote Chapters 15 and 16 to attention.
Layer Normalization (RMSNorm)
Modern transformers use RMSNorm, which normalizes each row by its root-mean-square:
RMSNorm(x) = x / sqrt(mean(x^2) + eps) * gamma
For a row x of length d_model:
1. Compute sum of squares: ss = sum(x[i]^2 for i in 0..d_model)
2. Compute RMS: rms = sqrt(ss / d_model + eps)
3. Normalize and scale: output[i] = x[i] / rms * gamma[i]
Step 1 is a reduction (sum).
Step 3 is element-wise.
This follows the “one threadgroup per row” pattern. The reduction in step 1 uses SIMD group sums.
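The three steps translate directly to a CPU reference (function name ours); only step 1 needs the cooperative reduction on the GPU:

```cpp
#include <cmath>
#include <vector>

// Reference RMSNorm for one row of length d_model.
std::vector<float> rmsnorm_ref(const std::vector<float>& x,
                               const std::vector<float>& gamma,
                               float eps = 1e-5f) {
    float ss = 0.0f;
    for (float v : x) ss += v * v;               // step 1: sum of squares
    float rms = std::sqrt(ss / x.size() + eps);  // step 2: RMS
    std::vector<float> out(x.size());
    for (size_t i = 0; i < x.size(); i++)        // step 3: element-wise
        out[i] = x[i] / rms * gamma[i];
    return out;
}
```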
Activation Functions
Applied element-wise between the FFN layers:
SiLU(x) = x * sigmoid(x) = x * (1 / (1 + exp(-x)))
GELU(x) = x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
These are purely element-wise, one thread per element. Trivially parallel.
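For reference, here are the two activations written out in C++ (names ours; the GELU uses the tanh approximation given above):

```cpp
#include <cmath>

// SiLU(x) = x * sigmoid(x)
float silu(float x) { return x / (1.0f + std::exp(-x)); }

// GELU, tanh approximation from the text.
float gelu_tanh(float x) {
    const float c = std::sqrt(2.0f / 3.14159265358979f); // sqrt(2/pi)
    return 0.5f * x * (1.0f + std::tanh(c * (x + 0.044715f * x * x * x)));
}
```

Both are smooth gates: near zero for large negative inputs, approximately the identity for large positive ones.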
Residual Connections
Just element-wise addition:
output = sublayer_output + input
Every transformer block has two residual connections:
x = attention(norm(x)) + x <-- residual
x = ffn(norm(x)) + x <-- residual
Putting It All Together: One Transformer Block
Input: x (seq_len, d_model)
Step 1: RMSNorm --> x_norm (seq_len, d_model) [reduction per row]
Step 2: Q = x_norm * Wq --> Q (seq_len, d_model) [GEMM or GEMV]
Step 3: K = x_norm * Wk --> K (seq_len, d_kv) [GEMM or GEMV]
Step 4: V = x_norm * Wv --> V (seq_len, d_kv) [GEMM or GEMV]
Step 5: Reshape Q/K/V for multi-head
Step 6: Attention --> attn (seq_len, d_model) [batched matmul + softmax]
Step 7: O = attn * Wo --> O (seq_len, d_model) [GEMM or GEMV]
Step 8: residual --> x = O + x [element-wise add]
Step 9: RMSNorm --> x_norm [reduction per row]
Step 10: FFN up --> h (seq_len, d_ff) [GEMM or GEMV]
Step 11: FFN gate --> g (seq_len, d_ff) [GEMM or GEMV]
Step 12: SiLU + multiply --> h = SiLU(g) * h [element-wise]
Step 13: FFN down --> f (seq_len, d_model) [GEMM or GEMV]
Step 14: residual --> x = f + x [element-wise add]
Output: x (seq_len, d_model)
Count the operations: 6 matrix multiplications, 2 normalizations, 2 residual adds, 1 activation, and the attention computation. The matrix multiplications dominate compute time (90%+ of total FLOPS).
Floating-Point Precision: FP16, FP32, and BF161
The choice of numeric precision has enormous impact on both performance and quality.
FP32 (Single Precision)
FP32: 1 sign bit + 8 exponent bits + 23 mantissa bits = 32 bits
+---+----------+-----------------------+
| S | EEEEEEEE | MMMMMMMMMMMMMMMMMMMMMMM|
+---+----------+-----------------------+
1 8 23
Range: ~1.18e-38 to ~3.4e+38
Precision: ~7 decimal digits
Size: 4 bytes per value
FP32 is the “safe” choice. It has plenty of range and precision for any ML computation. But it uses twice the memory and bandwidth of FP16, making it 2x slower for memory-bound operations.
FP16 (Half Precision)
FP16: 1 sign bit + 5 exponent bits + 10 mantissa bits = 16 bits
+---+-------+------------+
| S | EEEEE | MMMMMMMMMM |
+---+-------+------------+
1 5 10
Range: ~6.1e-5 to ~65504
Precision: ~3.3 decimal digits
Size: 2 bytes per value
FP16 halves memory usage and doubles effective bandwidth. Apple GPUs have native FP16 support and can process FP16 values at 2x the rate of FP32 in many operations.
The downside is the limited range. Values smaller than 6.1e-5 become zero (underflow), and values larger than 65504 become infinity (overflow). During inference, weights are usually small enough that this is not a problem. But intermediate computations (especially in attention) can overflow if not handled carefully.
BF16 (Brain Float 16)
BF16: 1 sign bit + 8 exponent bits + 7 mantissa bits = 16 bits
+---+----------+---------+
| S | EEEEEEEE | MMMMMMM |
+---+----------+---------+
1 8 7
Range: ~1.18e-38 to ~3.4e+38 (same as FP32!)
Precision: ~2.4 decimal digits
Size: 2 bytes per value
BF16 has the same exponent range as FP32 (so no overflow/underflow issues) but less precision than FP16. It was designed specifically for ML workloads where range matters more than precision.
Apple Silicon support: M1/M2 GPUs do not have native BF16 support in Metal. M3+ has some BF16 capabilities. In practice, most Apple GPU inference uses FP16 for computation with FP32 for accumulation.
The Mixed-Precision Strategy
The standard approach in inference:
Weights: stored in FP16 (or quantized, e.g., 4-bit)
Activations: computed in FP16
Accumulations: done in FP32 (dot products, reductions)
Final output: cast back to FP16
Why FP32 for accumulations?
Consider summing 4096 FP16 values, each around 0.01:
True sum: ~40.96
FP16 sum: after adding a few hundred values, each new addition
changes the sum by less than FP16 precision can represent.
You lose significant accuracy.
FP32 sum: no problem. 7 decimal digits of precision is plenty.
In Metal shader code, this looks like:
// Dot product with mixed precision
float sum = 0.0f; // FP32 accumulator
for (uint i = 0; i < K; i++) {
sum += float(A[i]) * float(B[i]); // Cast FP16 inputs to FP32
}
output[idx] = half(sum); // Cast result back to FP16
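You can demonstrate the accumulator problem on the CPU. The sketch below (names ours) simulates FP16 by rounding to an 11-bit significand after every addition; this is a crude stand-in, not an exact IEEE-754 half conversion, and it ignores half's range limits, but it captures the precision behavior:

```cpp
#include <cmath>

// Crude FP16 stand-in: keep 11 significand bits (1 implicit + 10 stored),
// ignoring half's exponent range. NOT an exact IEEE-754 half conversion.
float round_to_half_precision(float x) {
    if (x == 0.0f) return 0.0f;
    int e;
    float m = std::frexp(x, &e);            // x = m * 2^e, 0.5 <= |m| < 1
    m = std::round(m * 2048.0f) / 2048.0f;  // quantize the significand
    return std::ldexp(m, e);
}

// Sum 4096 values of ~0.01 with a half-precision accumulator vs FP32.
void accumulate(float* half_sum, float* f32_sum) {
    float h = 0.0f, f = 0.0f;
    for (int i = 0; i < 4096; i++) {
        h = round_to_half_precision(h + 0.01f); // rounds after every add
        f += 0.01f;                             // FP32 accumulator
    }
    *half_sum = h;
    *f32_sum = f;
}
```

The FP32 accumulator lands on ~40.96 as expected; the half-precision one stalls around 32, because once the running sum is that large, each 0.01 increment is smaller than half an FP16 ulp and rounds away entirely.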
Precision Comparison Table
+--------+------+--------+-----------+-------------+------------------+
| Format | Bits | Range | Precision | Memory/elem | Apple GPU Native |
+--------+------+--------+-----------+-------------+------------------+
| FP32 | 32 | 1e38 | ~7 digits | 4 bytes | Yes |
| FP16 | 16 | 65504 | ~3 digits | 2 bytes | Yes (2x rate) |
| BF16 | 16 | 1e38 | ~2 digits | 2 bytes | M3+ partial |
| INT8 | 8 | -128 | exact | 1 byte | Yes |
| | | to 127 | integers | | |
| INT4 | 4 | -8 | exact | 0.5 bytes | No (emulated) |
| | | to 7 | integers | | |
+--------+------+--------+-----------+-------------+------------------+
Fused Operations: Saving Bandwidth by Combining Kernels
Here is one of the most important optimization ideas in GPU programming: kernel fusion.2
The Problem: Bandwidth Waste
Consider computing SiLU activation followed by element-wise multiply (as in the LLaMA FFN):
Step 1: gate = SiLU(x * Wg)
Step 2: up = x * Wu
Step 3: hidden = gate * up
If each step is a separate kernel:
Kernel 1 (SiLU):
Read gate_raw from global memory (N elements read)
Compute SiLU
Write gate to global memory (N elements written)
Kernel 2 (element-wise multiply):
Read gate from global memory (N elements read - AGAIN!)
Read up from global memory (N elements read)
Compute gate * up
Write hidden to global memory (N elements written)
Total memory traffic: 2N reads + N writes + 2N reads + N writes = 4N reads + 2N writes
The Solution: Fused Kernel
Combine both operations into a single kernel:
Fused kernel (SiLU + multiply):
Read gate_raw from global memory (N elements read)
Read up from global memory (N elements read)
Compute SiLU(gate_raw) * up in registers
Write hidden to global memory (N elements written)
Total memory traffic: 2N reads + N writes
Savings: 2N reads + N writes eliminated!
The fused kernel avoids the round trip through global memory between operations. The intermediate value (the SiLU output) lives only in registers, which are essentially free to access.
Common Fused Operations in Transformers
+------------------------------------+---------------------------------+
| Fused Operation | What It Combines |
+------------------------------------+---------------------------------+
| Fused SiLU gate | SiLU(x) * y |
| Fused RMSNorm | norm + scale (gamma multiply) |
| Fused attention | QK^T + scale + mask + softmax |
| (FlashAttention) | + V multiply (Ch 16) |
| Fused dequant + matmul | Dequantize weights + GEMV/GEMM |
| Fused residual + norm | x + sublayer + RMSNorm |
| Fused RoPE + reshape | Rotary embedding + head reshape |
+------------------------------------+---------------------------------+
When Fusion Matters and When It Does Not
Fusion helps MOST when:
- Operations are memory-bandwidth bound (small compute per element)
- Intermediate tensors are large
- Operations are sequential (output of one feeds input of next)
Fusion helps LEAST when:
- Operations are compute bound (GEMM with large matrices)
- Intermediate tensors are small (fit in cache anyway)
- Operations have different parallelism patterns
(e.g., can't easily fuse a reduction with an element-wise op)
The Bandwidth Equation
To understand why fusion matters, consider the numbers for an Apple M2 GPU:
M2 GPU specs:
Memory bandwidth: ~100 GB/s
Compute (FP16): ~3.6 TFLOPS
For a SiLU activation on 4096 FP16 elements:
Data to read: 4096 * 2 bytes = 8 KB
Data to write: 4096 * 2 bytes = 8 KB
Compute: 4096 * ~5 FLOPs (exp, add, div, mul) = ~20 KFLOPs
Time limited by bandwidth: 16 KB / 100 GB/s = 0.16 microseconds
Time limited by compute: 20 KFLOPs / 3.6 TFLOPS = 0.006 microseconds
This operation is 27x more bandwidth bound than compute bound!
Every byte you can avoid reading from or writing to global memory is a direct performance win for bandwidth-bound operations. And in transformers, the majority of non-GEMM operations are bandwidth bound.
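The same back-of-envelope arithmetic generalizes into a quick roofline check. The helpers below are a hypothetical sketch using the M2 figures above (~100 GB/s, ~3.6 TFLOPS FP16):

```cpp
#include <cassert>

// Roofline-style check: an operation is bandwidth bound when the time to
// move its bytes exceeds the time to execute its FLOPs.
double bandwidth_time_us(double bytes, double gb_per_s = 100.0) {
    return bytes / (gb_per_s * 1e9) * 1e6;
}
double compute_time_us(double flops, double tflops = 3.6) {
    return flops / (tflops * 1e12) * 1e6;
}
bool is_bandwidth_bound(double bytes, double flops) {
    return bandwidth_time_us(bytes) > compute_time_us(flops);
}
```

Plugging in the SiLU example (16 KB moved, ~20 KFLOPs of work) confirms it sits deep in the bandwidth-bound regime.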
A Complete Example: Processing One Token
Let us trace the memory traffic for processing a single token through one LLaMA-7B layer, showing where fusion helps.
Model parameters (LLaMA 7B):
d_model = 4096
n_heads = 32
d_k = 128
d_ff = 11008 (intermediate FFN dimension)
Weights in FP16
Step 1: RMSNorm
Read: x (4096 * 2B = 8KB) + gamma (8KB) = 16KB
Write: x_norm (8KB)
Total: 24KB
Step 2: Q projection (GEMV: 1 x 4096 times 4096 x 4096)
Read: x_norm (8KB) + Wq (4096*4096*2B = 32MB)
Write: Q (8KB)
Total: ~32MB <-- Weight read dominates!
Step 3: K projection
Read: x_norm (8KB) + Wk (4096*128*32*2B = ... depends on GQA)
... similar to Q
Step 4-6: V projection, attention, output projection
... similar pattern
Step 7-8: FFN up + gate + SiLU + down
Read: Wu (4096*11008*2B = 86MB) + Wg (86MB) + Wd (86MB)
Total: ~258MB of weight reads for FFN alone
Grand total weight reads per layer: ~400MB
For 32 layers: ~12.8GB of weight reads PER TOKEN
At 100 GB/s bandwidth: 12.8GB / 100 GB/s = 128ms per token = ~8 tokens/sec
This is why:
- Quantization is essential – 4-bit weights cut reads by 4x.
- Fusion matters for the non-GEMM ops – saving even a few MB per layer adds up.
- Decode is bandwidth-bound – the GPU spends most of its time waiting for memory.
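The whole trace reduces to a few lines of arithmetic. This sketch (hypothetical helper names, FP16 LLaMA-7B-style dimensions from the text) reproduces the ~400 MB/layer and ~8 tok/s figures:

```cpp
#include <cassert>

// Rough per-token weight traffic for a LLaMA-7B-style model, mirroring the
// trace above. Ignores activations and the KV cache, which are small by
// comparison during decode.
double weight_bytes_per_layer(double d_model = 4096, double d_ff = 11008,
                              double bytes_per_weight = 2.0) {
    double attn = 4.0 * d_model * d_model; // Q, K, V, O projections
    double ffn  = 3.0 * d_model * d_ff;    // gate, up, down
    return (attn + ffn) * bytes_per_weight;
}
double tokens_per_sec(int n_layers = 32, double bw_gb_s = 100.0,
                      double bytes_per_weight = 2.0) {
    double total = n_layers * weight_bytes_per_layer(4096, 11008, bytes_per_weight);
    return (bw_gb_s * 1e9) / total;
}
```

Dropping bytes_per_weight toward 0.5 (roughly Q4_0) quadruples the estimated decode speed, which is the quantization argument in one line.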
Summary
In this chapter, we have established the foundations:
- Tensors are multi-dimensional arrays. In ML, we work with rank-2 to rank-4 tensors constantly.
- Memory layout (row-major with strides) determines how multi-dimensional indices map to flat GPU buffer addresses.
- Coalesced access – adjacent threads reading adjacent memory – is critical for GPU performance.
- Common operations include element-wise (trivially parallel), reductions (require thread cooperation), and matrix multiplications (require tiling).
- Mixed precision (FP16 values, FP32 accumulation) gives us the best of both worlds.
- Kernel fusion eliminates wasteful memory round-trips between operations.
The next chapter zooms in on the single most important operation: matrix multiplication. It consumes 90%+ of inference compute, and getting it right is the difference between a usable inference engine and a toy.
1. Apple. “Metal Shading Language Specification, v3.1.” developer.apple.com. Covers the half type, vector types, and precision guarantees for FP16 arithmetic on Apple GPUs. See https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf.
2. Micikevicius, P., et al. “Mixed Precision Training.” ICLR 2018. Establishes the practice of using FP16 for computation with FP32 accumulation to maintain numerical stability. See https://arxiv.org/abs/1710.03740.
Matrix Multiplication Strategies
If there is one operation you need to understand deeply when running neural networks on GPUs, it is matrix multiplication. Not because it is conceptually hard – you learned it in linear algebra class – but because virtually everything in a transformer reduces to it. Every linear layer, every projection, every feed-forward network… all matmuls. If your matmul is slow, your model is slow. Period.
In this chapter we are going to build up from the textbook definition of matrix multiplication, examine why it matters so much for inference, and then dive into the different strategies you will encounter when targeting Metal GPUs. We will cover the two main regimes – GEMV and GEMM – plus all the messy in-between cases, and finally discuss how quantization changes the picture.
Why Matmul Is the Center of the Universe
Let us start with a concrete example. Consider a single linear layer in a transformer:
y = x * W + b
Here x is your input (say, a vector of dimension 4096), W is the weight matrix (say,
4096 x 11008), and b is a bias vector. That multiplication x * W is a matrix
multiplication. Now count how many linear layers exist in a single transformer block:
- Q projection: matmul
- K projection: matmul
- V projection: matmul
- Output projection: matmul
- Gate projection (FFN): matmul
- Up projection (FFN): matmul
- Down projection (FFN): matmul
That is seven matmuls per layer. A 32-layer model has 224 matmuls per forward pass. If each one takes 0.5ms, you are already at 112ms just for the matmuls – and that is before attention, normalization, or anything else.
One Transformer Block
=====================
Input
|
v
[LayerNorm]
|
+---> Q Projection (matmul #1)
+---> K Projection (matmul #2)
+---> V Projection (matmul #3)
|
v
[Attention Mechanism]
|
v
[Output Projection] (matmul #4)
|
v
[Residual Add]
|
v
[LayerNorm]
|
+---> Gate Projection (matmul #5)
+---> Up Projection (matmul #6)
|
v
[SiLU + Elementwise Mul]
|
v
[Down Projection] (matmul #7)
|
v
[Residual Add]
|
v
Output
So when someone says “optimizing inference,” they mostly mean “optimizing matmul.”
The Two Regimes: GEMV vs GEMM
Here is the critical insight that changes everything about how you write GPU kernels for LLM inference: the shape of the matmul depends on what phase of inference you are in.
Token Generation (Decode): GEMV
During autoregressive decoding, you generate one token at a time. Your input x is a
single vector of dimension d_model. The weight matrix W has shape [d_model, d_out].
So the multiplication is:
x: [1, 4096]
W: [4096, 11008]
y: [1, 11008]
This is a matrix-vector multiplication, or GEMV (General Matrix-Vector multiply). In BLAS terminology, M=1 (only one row in the output). The defining characteristic of GEMV is that every element of the weight matrix is read exactly once. There is no data reuse. This makes GEMV fundamentally memory bandwidth bound – your performance is limited by how fast you can stream the weight matrix from memory.
GEMV: One vector times a matrix
================================
x = [x0, x1, x2, x3] (1 x K)
W = [ w00 w01 w02 ]
[ w10 w11 w12 ] (K x N)
[ w20 w21 w22 ]
[ w30 w31 w32 ]
y = [y0, y1, y2] (1 x N)
y0 = x0*w00 + x1*w10 + x2*w20 + x3*w30
y1 = x0*w01 + x1*w11 + x2*w21 + x3*w31
y2 = x0*w02 + x1*w12 + x2*w22 + x3*w32
Each weight is touched exactly once.
Bottleneck = memory bandwidth.
Prompt Processing (Prefill): GEMM
During prefill, you process the entire input prompt at once. If the prompt has 512 tokens,
your input is a matrix of shape [512, 4096]. Now the multiplication becomes:
X: [512, 4096]
W: [4096, 11008]
Y: [512, 11008]
This is a full matrix-matrix multiplication, or GEMM (General Matrix-Matrix multiply).
M=512, and now every element of W gets reused across 512 different input rows. This
changes the arithmetic intensity dramatically – GEMM can be compute bound if you
organize the data access well.
GEMM: A matrix times a matrix
===============================
X = [ x00 x01 x02 x03 ]
[ x10 x11 x12 x13 ] (M x K)
[ x20 x21 x22 x23 ]
W = [ w00 w01 ]
[ w10 w11 ] (K x N)
[ w20 w21 ]
[ w30 w31 ]
Y = [ y00 y01 ]
[ y10 y11 ] (M x N)
[ y20 y21 ]
Each weight wij is used M times (once per row of X).
With good tiling, this becomes compute-bound.
The Arithmetic Intensity Spectrum
Let us quantify this. Arithmetic intensity is the ratio of compute operations to memory bytes transferred:
GEMV (M=1):
Operations = 2 * K * N (multiply + add for each element)
Bytes read = K * N * sizeof(weight) + K * sizeof(input)
Intensity ~ 2 / sizeof(weight)
For FP16: 2 / 2 = 1 FLOP/byte --> bandwidth bound
GEMM (M=512):
Operations = 2 * M * K * N
Bytes read = K * N * sizeof(weight) + M * K * sizeof(input)
Intensity ~ 2 * M / sizeof(weight) (when K*N >> M*K)
For FP16: 2 * 512 / 2 = 512 FLOP/byte --> compute bound
Metal GPUs typically have 200-400 GB/s of memory bandwidth and 10-20 TFLOPS of FP16 compute. The crossover point – where you shift from bandwidth bound to compute bound – is roughly around M=16-64, depending on the specific hardware and data types.
Arithmetic Intensity vs M
=========================
Intensity
(FLOP/byte)
|
512 | * GEMM (M=512)
| *
| *
| *
| *
64 | * <-- Compute bound above this line
| *
| - - - - - - - - - - -*- - - - - - HW Compute/BW ratio
16 | *
| *
4 | *
2 | *
1 | * * <-- GEMV (M=1)
+--+---+---+---+---+---+---+---+---> M
1 2 4 8 16 32 64 512
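The intensity formulas above are easy to check numerically. This sketch (illustrative only) computes arithmetic intensity for a given M and compares it against a hardware compute/bandwidth ratio:

```cpp
#include <cassert>

// Arithmetic intensity of Y = X * W for a [M x K] * [K x N] matmul,
// following the formulas above: ops = 2*M*K*N, bytes = weights + inputs.
double arithmetic_intensity(double M, double K, double N,
                            double weight_bytes = 2.0, double input_bytes = 2.0) {
    double ops   = 2.0 * M * K * N;
    double bytes = K * N * weight_bytes + M * K * input_bytes;
    return ops / bytes;
}
// Compute bound once intensity exceeds the hardware's FLOPs-per-byte ratio
// (here defaulting to 10 TFLOPS / 200 GB/s = 50 FLOP/byte).
bool compute_bound(double M, double K, double N,
                   double tflops = 10.0, double gb_s = 200.0) {
    return arithmetic_intensity(M, K, N) > (tflops * 1e12) / (gb_s * 1e9);
}
```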
GEMV: Thread and SIMD Parallelism
Let us start with GEMV since it is the bread and butter of token generation. Remember,
M=1, so we are computing y = x * W where x is a vector and W is a matrix. The
output y has N elements, and each output element is a dot product of length K.
The Naive Approach
The simplest approach: assign one thread per output element. Each thread computes a full dot product of length K:
// Naive GEMV: one thread per output element
kernel void gemv_naive(
device const float* x [[buffer(0)]], // [K]
device const float* W [[buffer(1)]], // [K x N], row-major
device float* y [[buffer(2)]], // [N]
constant uint& K [[buffer(3)]],
constant uint& N [[buffer(4)]],
uint tid [[thread_position_in_grid]])
{
float sum = 0.0f;
for (uint k = 0; k < K; k++) {
sum += x[k] * W[k * N + tid]; // Column tid of W
}
y[tid] = sum;
}
This works but has terrible performance. Each thread reads all K elements of x (lots of
redundant reads), and each thread's successive loads from W are N elements apart, so a
loaded cache line is barely used before the thread moves on.
Coalesced Access with Transposed Weights
First improvement: store W in column-major order (or equivalently, store W^T in
row-major order as Wt). Now each thread streams contiguously through one row of Wt.
More importantly, this layout sets up the SIMD-group kernels below, where the 32 lanes
of a SIMD group read 32 consecutive elements of the same row:
// GEMV with transposed weights: each thread walks one row of Wt
kernel void gemv_transposed(
device const float* x [[buffer(0)]], // [K]
device const float* Wt [[buffer(1)]], // [N x K], W transposed
device float* y [[buffer(2)]], // [N]
constant uint& K [[buffer(3)]],
uint tid [[thread_position_in_grid]])
{
// Thread tid computes output element tid
// Reads row tid of Wt, which is column tid of W
float sum = 0.0f;
for (uint k = 0; k < K; k++) {
sum += x[k] * Wt[tid * K + k];
}
y[tid] = sum;
}
Better memory access pattern, but we still have N threads each independently reading the
entire x vector. Let us fix that.
SIMD-Level Parallelism: Splitting the K Dimension
Here is the key idea for efficient GEMV on Metal: instead of one thread per output, use a
group of threads (a SIMD group) to cooperatively compute one output element. Each
thread in the SIMD group handles a slice of the K dimension, then we use simd_sum to
reduce:
SIMD GEMV: 32 threads cooperate on one dot product
====================================================
x = [x0, x1, x2, ..., x31, x32, ..., x63, ...]
Thread 0 handles: x[0]*w[0] + x[32]*w[32] + x[64]*w[64] + ...
Thread 1 handles: x[1]*w[1] + x[33]*w[33] + x[65]*w[65] + ...
...
Thread 31 handles: x[31]*w[31] + x[63]*w[63] + x[95]*w[95] + ...
Then: simd_sum() adds all 32 partial sums in hardware
Result: one output element computed cooperatively
// SIMD GEMV: one SIMD group per output element
constant uint SIMD_GROUPS_PER_TG = 8; // SIMD groups per threadgroup
kernel void gemv_simd(
device const float* x [[buffer(0)]],
device const float* Wt [[buffer(1)]], // [N x K]
device float* y [[buffer(2)]],
constant uint& K [[buffer(3)]],
constant uint& N [[buffer(4)]],
uint simd_lane [[thread_index_in_simdgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint tg_id [[threadgroup_position_in_grid]])
{
// Each SIMD group computes one output element
uint n = tg_id * SIMD_GROUPS_PER_TG + simd_gid;
if (n >= N) return;
float partial = 0.0f;
// Each lane processes every 32nd element
for (uint k = simd_lane; k < K; k += 32) {
partial += x[k] * Wt[n * K + k];
}
// Hardware reduction across the SIMD group
float total = simd_sum(partial);
// Only lane 0 writes the result
if (simd_lane == 0) {
y[n] = total;
}
}
The simd_sum intrinsic is a hardware-level reduction. On Metal, a SIMD group (also
called a “wave” on other platforms) is 32 threads that execute in lockstep. The simd_sum
operation sums a value across all 32 lanes in just a few clock cycles – no shared memory,
no barriers, no atomics. This is enormously powerful.
simd_sum reduction (hardware butterfly)
========================================
Lane: 0 1 2 3 ... 30 31
Val: 2.1 0.5 1.3 0.8 ... 0.2 1.1
Step 1: swap with neighbor, add
2.6 2.6 2.1 2.1 ... 1.3 1.3
Step 2: swap with stride-2 neighbor, add
4.7 4.7 4.7 4.7 ... ... ...
Step 3: stride 4 ...
Step 4: stride 8 ...
Step 5: stride 16 ...
After 5 steps: all lanes hold the total sum.
Cost: ~5 cycles. No memory traffic.
Multiple Output Elements per SIMD Group
We can do even better. Instead of one output element per SIMD group, compute several. This
amortizes the cost of reading x:
// SIMD GEMV: one SIMD group computes 4 output elements
kernel void gemv_simd_multi(
device const float* x [[buffer(0)]],
device const float* Wt [[buffer(1)]],
device float* y [[buffer(2)]],
constant uint& K [[buffer(3)]],
constant uint& N [[buffer(4)]],
uint simd_lane [[thread_index_in_simdgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint tg_id [[threadgroup_position_in_grid]])
{
uint n_base = (tg_id * SIMD_GROUPS_PER_TG + simd_gid) * 4;
if (n_base + 3 >= N) return; // Guard the 4-wide output window
float4 partial = float4(0.0f);
for (uint k = simd_lane; k < K; k += 32) {
float xk = x[k]; // Load x[k] once, use for all 4 outputs
partial[0] += xk * Wt[(n_base + 0) * K + k];
partial[1] += xk * Wt[(n_base + 1) * K + k];
partial[2] += xk * Wt[(n_base + 2) * K + k];
partial[3] += xk * Wt[(n_base + 3) * K + k];
}
// Reduce each output element
float y0 = simd_sum(partial[0]);
float y1 = simd_sum(partial[1]);
float y2 = simd_sum(partial[2]);
float y3 = simd_sum(partial[3]);
if (simd_lane == 0) {
y[n_base + 0] = y0;
y[n_base + 1] = y1;
y[n_base + 2] = y2;
y[n_base + 3] = y3;
}
}
Now each SIMD group loads x[k] once and uses it four times. The weight reads are still
one-shot (each weight used exactly once), but we have reduced the x reads by 4x.
GEMM: Tiling for Compute Efficiency
When M is large (prefill phase), we enter the world of GEMM. The key technique here is tiling: dividing the output matrix into tiles and assigning each tile to a threadgroup. Within the threadgroup, we further divide work among SIMD groups.
The Tiling Concept
GEMM Tiling Overview
=====================
Output Y [M x N] is divided into tiles of size [Bm x Bn]
N
<----------->
+----+----+----+----+ ^
| TG | TG | TG | TG | |
| 00 | 01 | 02 | 03 | |
+----+----+----+----+ | M
| TG | TG | TG | TG | |
| 10 | 11 | 12 | 13 | |
+----+----+----+----+ v
Each TG computes a [Bm x Bn] tile of the output.
To compute its tile, TG needs:
- A strip of X: [Bm x K]
- A strip of W: [K x Bn]
But K can be huge (4096+), so we tile along K too:
K
<----------->
+====+====+====+ <- X rows for this TG's tile
| Bk | Bk | Bk |
+====+====+====+
Process K in chunks of Bk, accumulating partial results.
Cooperative Loading into Threadgroup Memory
The key insight is that threads within a threadgroup can cooperatively load a tile of data into fast threadgroup memory (shared memory), then all threads read from it. This converts slow device memory reads into fast threadgroup memory reads:
// Simplified tiled GEMM
kernel void gemm_tiled(
device const half* X [[buffer(0)]], // [M x K]
device const half* W [[buffer(1)]], // [K x N]
device half* Y [[buffer(2)]], // [M x N]
constant uint& M [[buffer(3)]],
constant uint& N [[buffer(4)]],
constant uint& K [[buffer(5)]],
uint2 tg_pos [[threadgroup_position_in_grid]],
uint tid [[thread_index_in_threadgroup]])
{
// Tile dimensions
const uint Bm = 32; // Tile rows
const uint Bn = 32; // Tile cols
const uint Bk = 16; // K-tile size
// Threadgroup memory for tiles
threadgroup half X_tile[Bm * Bk];
threadgroup half W_tile[Bk * Bn];
// This threadgroup computes output tile at (tg_pos.y * Bm, tg_pos.x * Bn)
uint m_base = tg_pos.y * Bm;
uint n_base = tg_pos.x * Bn;
// Accumulator for this thread's output elements
float acc[4] = {0, 0, 0, 0}; // Each thread computes 4 elements
// Walk along K in steps of Bk
for (uint k_base = 0; k_base < K; k_base += Bk) {
// === Cooperative load: all threads load tiles into shared memory ===
// Bm*Bk = Bk*Bn = 512 elements but only 256 threads, so each thread
// loads two elements of each tile.
for (uint i = tid; i < Bm * Bk; i += 256) {
uint row = i / Bk;
uint col = i % Bk;
X_tile[i] = X[(m_base + row) * K + (k_base + col)];
}
for (uint i = tid; i < Bk * Bn; i += 256) {
uint row = i / Bn;
uint col = i % Bn;
W_tile[i] = W[(k_base + row) * N + (n_base + col)];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// === Compute: multiply tiles ===
// Each thread computes its assigned output elements
// using data from threadgroup memory (fast!)
uint my_m = tid / 8; // Which row within tile
uint my_n = (tid % 8) * 4; // Which 4 columns within tile
for (uint kk = 0; kk < Bk; kk++) {
half x_val = X_tile[my_m * Bk + kk];
for (uint j = 0; j < 4; j++) {
acc[j] += float(x_val) * float(W_tile[kk * Bn + my_n + j]);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// Write results
uint my_m = tid / 8;
uint my_n = (tid % 8) * 4;
for (uint j = 0; j < 4; j++) {
Y[(m_base + my_m) * N + (n_base + my_n + j)] = half(acc[j]);
}
}
Let us trace through what happens for one K-tile:
One iteration of the K-loop (k_base = 0, Bk = 16)
===================================================
Step 1: Cooperative Load
~~~~~~~~~~~~~~~~~~~~~~~~
All 256 threads in the threadgroup work together to load:
X_tile [32 x 16]: W_tile [16 x 32]:
Rows m_base..m_base+31 Rows 0..15
Cols 0..15 of X Cols n_base..n_base+31 of W
Thread 0 loads X[m_base][0..15] row Thread 128 loads W[0][n_base..+31]
Thread 1 loads X[m_base+1][0..15] Thread 129 loads W[1][n_base..+31]
... ...
Step 2: Barrier
~~~~~~~~~~~~~~~
Wait for all loads to complete.
Step 3: Compute
~~~~~~~~~~~~~~~
Each thread reads from the fast threadgroup memory tiles
and accumulates partial results.
Thread 0 computes: Y[m_base+0][n_base+0..3] += X_tile[0][k] * W_tile[k][0..3]
Thread 1 computes: Y[m_base+0][n_base+4..7] += X_tile[0][k] * W_tile[k][4..7]
...
Thread 8 computes: Y[m_base+1][n_base+0..3] += X_tile[1][k] * W_tile[k][0..3]
...
Step 4: Barrier
~~~~~~~~~~~~~~~
Wait before overwriting tiles with next K-chunk.
SIMD Group Matrix Multiply-Accumulate (MMA)
Metal 3.0+ on Apple Silicon supports SIMD group matrix operations that are much faster than manually computing the multiply-accumulate. These execute cooperatively across an entire SIMD group on the GPU (not to be confused with AMX, Apple's CPU-side matrix coprocessor):
#include <metal_simdgroup_matrix>
// Using SIMD group MMA for tiled GEMM
kernel void gemm_simdgroup_mma(
device const half* X [[buffer(0)]],
device const half* W [[buffer(1)]],
device half* Y [[buffer(2)]],
constant uint& M [[buffer(3)]],
constant uint& N [[buffer(4)]],
constant uint& K [[buffer(5)]],
uint2 tg_pos [[threadgroup_position_in_grid]],
uint sgid [[simdgroup_index_in_threadgroup]],
uint lane [[thread_index_in_simdgroup]])
{
// Each SIMD group computes an 8x8 tile of the output
// using 8x8 matrix operations
// Declare SIMD group matrices
simdgroup_matrix<half, 8, 8> x_mat;
simdgroup_matrix<half, 8, 8> w_mat;
// Initialize accumulator to zero (MSL's make_filled_simdgroup_matrix)
simdgroup_matrix<half, 8, 8> acc_mat = make_filled_simdgroup_matrix<half, 8, 8>(half(0));
uint m_base = tg_pos.y * 32 + (sgid / 4) * 8;
uint n_base = tg_pos.x * 32 + (sgid % 4) * 8;
for (uint k = 0; k < K; k += 8) {
// Load 8x8 tiles from X and W
simdgroup_load(x_mat, X + m_base * K + k, K);
simdgroup_load(w_mat, W + k * N + n_base, N);
// 8x8 matrix multiply-accumulate in hardware
simdgroup_multiply_accumulate(acc_mat, x_mat, w_mat, acc_mat);
}
// Store the 8x8 result tile
simdgroup_store(acc_mat, Y + m_base * N + n_base, N);
}
The simdgroup_multiply_accumulate function computes an 8x8 matrix multiply-accumulate in hardware.
A single SIMD group of 32 threads cooperatively holds the 8x8 matrices (each thread holds
2 elements of each matrix), and the multiplication executes as a single cooperative
operation across the group. This is dramatically faster than doing the multiply-add manually.
SIMD Group MMA: 32 threads own an 8x8 matrix
===============================================
The 8x8 = 64 elements are distributed across 32 threads:
Thread 0: elements [0,0] and [0,1]
Thread 1: elements [0,2] and [0,3]
...
Thread 15: elements [3,6] and [3,7]
Thread 16: elements [4,0] and [4,1]
...
Thread 31: elements [7,6] and [7,7]
(The exact lane-to-element mapping is implementation-defined;
this is one consistent illustration.)
simdgroup_multiply_accumulate(C, A, B, C):
C += A * B (a full 8x8 multiply-accumulate per instruction)
The operation executes cooperatively across the SIMD group's lanes.
Wide GEMV: The Vocabulary Projection Problem
There is one particular GEMV that deserves special attention: the final vocabulary projection. In a typical LLM, this multiplies the hidden state (dimension 4096) by the vocabulary embedding matrix to produce logits over the entire vocabulary:
hidden: [1, 4096]
vocab_weights: [4096, 32000] (or 128256 for Llama 3!)
logits: [1, 32000]
This is still a GEMV (M=1), but N is enormous – 32K to 128K output elements. The challenge is that you need enough parallelism to saturate the GPU, and you need to read a very large weight matrix.
The strategy is to parallelize aggressively across the N dimension:
Wide GEMV for Vocab Projection
===============================
N = 32000 output logits
K = 4096
Approach: assign SIMD groups to output elements
Threadgroup 0: SIMD groups compute outputs [0..127]
Threadgroup 1: SIMD groups compute outputs [128..255]
...
Threadgroup 249: SIMD groups compute outputs [31872..31999]
With 250 threadgroups x 4 SIMD groups/TG x 32 threads/SIMD = 32000 threads
Each SIMD group: 32 threads reduce a dot product of length 4096
That is 4096/32 = 128 multiply-adds per thread.
Total weight data: 4096 * 32000 * 2 bytes = 256 MB (FP16)
At 200 GB/s: ~1.3ms to stream all weights
This sets the floor for GEMV performance.
The key insight: for wide GEMV, you want each SIMD group to handle one (or a few) output
columns, with the 32 lanes splitting the K-dimension reduction. The x vector is small
enough to fit in threadgroup memory or even registers, so you load it once and reuse it
for all output columns in the threadgroup.
// Wide GEMV: optimized for large N (vocab projection)
kernel void gemv_wide(
device const half* x [[buffer(0)]], // [K]
device const half* Wt [[buffer(1)]], // [N x K]
device half* y [[buffer(2)]], // [N]
constant uint& K [[buffer(3)]],
constant uint& N [[buffer(4)]],
uint sgid [[simdgroup_index_in_threadgroup]],
uint lane [[thread_index_in_simdgroup]],
uint tg_id [[threadgroup_position_in_grid]])
{
// Load x into threadgroup memory (cooperative load)
threadgroup half x_shared[4096]; // Assumes K <= 4096
uint tid = sgid * 32 + lane;
for (uint i = tid; i < K; i += 256) { // 256 threads per TG
x_shared[i] = x[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Each SIMD group handles one output element
uint n = tg_id * 8 + sgid; // 8 SIMD groups per TG
if (n >= N) return;
float sum = 0.0f;
device const half* w_row = Wt + n * K;
for (uint k = lane; k < K; k += 32) {
sum += float(x_shared[k]) * float(w_row[k]);
}
sum = simd_sum(sum);
if (lane == 0) {
y[n] = half(sum);
}
}
Batched GEMV: The Awkward Middle Ground
In practice, inference is not always purely M=1 or M=512. There are several scenarios where M is small but greater than 1:
- Batch decoding: serving multiple users simultaneously (M = batch_size, often 2-16)
- Speculative decoding: verifying multiple candidate tokens at once (M = num_candidates, often 4-8)
- Small prompts: short prefill with just a few tokens
For these cases, M is too small for efficient GEMM (not enough reuse to be compute bound) but too large for pure GEMV (we want some reuse of the weight data).
The solution is batched GEMV: treat it as M independent GEMVs but share weight loads across them:
Batched GEMV (M=4)
===================
X = [ x0 ] (4 rows, each a vector of length K)
[ x1 ]
[ x2 ]
[ x3 ]
Strategy: Each SIMD group handles multiple output columns
and ALL M rows simultaneously.
SIMD Group 0 processing output column n=0:
+-----+-----+-----+-----+-----+
| Lane| 0 | 1 | 2 | ... |
+-----+-----+-----+-----+-----+
| Row0| x0[0]*w[0] | x0[1]*w[1] | x0[2]*w[2] | ... |
| Row1| x1[0]*w[0] | x1[1]*w[1] | x1[2]*w[2] | ... |
| Row2| x2[0]*w[0] | x2[1]*w[1] | x2[2]*w[2] | ... |
| Row3| x3[0]*w[0] | x3[1]*w[1] | x3[2]*w[2] | ... |
+-----+-----+-----+-----+-----+
Each weight w[k] is loaded once and used for M=4 rows.
Weight reuse factor: 4x compared to pure GEMV.
// Batched GEMV: M rows share weight loads
kernel void gemv_batched(
device const half* X [[buffer(0)]], // [M x K]
device const half* Wt [[buffer(1)]], // [N x K]
device half* Y [[buffer(2)]], // [M x N]
constant uint& M [[buffer(3)]],
constant uint& K [[buffer(4)]],
constant uint& N [[buffer(5)]],
uint sgid [[simdgroup_index_in_threadgroup]],
uint lane [[thread_index_in_simdgroup]],
uint tg_id [[threadgroup_position_in_grid]])
{
uint n = tg_id * 8 + sgid; // Output column
if (n >= N) return;
// Accumulators for each row
float sums[8] = {0}; // Support up to M=8
device const half* w_row = Wt + n * K;
for (uint k = lane; k < K; k += 32) {
half w_val = w_row[k]; // Load weight ONCE
// Apply to all M rows
for (uint m = 0; m < M; m++) {
sums[m] += float(X[m * K + k]) * float(w_val);
}
}
// Reduce each row's sum
for (uint m = 0; m < M; m++) {
float total = simd_sum(sums[m]);
if (lane == 0) {
Y[m * N + n] = half(total);
}
}
}
The performance gain over running M separate GEMVs is significant because:
- Weight data is loaded once instead of M times
- Kernel launch overhead is reduced
- The GPU can be better saturated with the additional work
Performance: M separate GEMVs vs Batched GEMV
===============================================
M=4, K=4096, N=4096, FP16 weights
M separate GEMVs:
Weight reads: 4 * 4096 * 4096 * 2 bytes = 128 MB
Time at 200 GB/s: ~0.64 ms
Batched GEMV:
Weight reads: 1 * 4096 * 4096 * 2 bytes = 32 MB
Time at 200 GB/s: ~0.16 ms
4x improvement from weight reuse!
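A CPU reference makes the reuse pattern explicit (a sketch, not akunu's kernel). The loop structure mirrors gemv_batched above: each weight is loaded once and applied to all M rows before the next weight is touched:

```cpp
#include <cassert>
#include <vector>

// Batched GEMV reference. X is [M x K] row-major, Wt is [N x K] row-major
// (W transposed), Y is [M x N]. Each weight Wt[n][k] is read once and
// reused across all M input rows.
std::vector<float> gemv_batched_ref(const std::vector<float>& X,
                                    const std::vector<float>& Wt,
                                    int M, int K, int N) {
    std::vector<float> Y(M * N, 0.0f);
    for (int n = 0; n < N; n++) {
        for (int k = 0; k < K; k++) {
            float w = Wt[n * K + k];              // weight loaded once...
            for (int m = 0; m < M; m++) {
                Y[m * N + n] += X[m * K + k] * w; // ...used for all M rows
            }
        }
    }
    return Y;
}
```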
How Quantization Changes the Game
Everything we have discussed so far assumes FP16 weights. But in practice, most inference deployments use quantized weights – 4-bit or even 2-bit. This changes the matmul kernels significantly.
The Core Challenge: Dequantize on the Fly
Quantized weights are stored in a compressed format. You cannot directly multiply them with the input – you first need to dequantize them back to a floating-point representation. The question is: where and when do you dequantize?
The answer is on-the-fly dequantization: each thread dequantizes just the weight values it needs, right before multiplying them with the input. The dequantized values live only in registers, never written back to memory.
Dequantize-on-the-fly GEMV
===========================
Memory: [...Q4 packed weights...] (4 bits per weight)
|
v (load)
Registers: [packed_byte]
|
v (unpack + scale)
Registers: [fp16_weight_a, fp16_weight_b]
|
v (multiply with input)
Registers: [partial_sum]
|
v (simd_sum)
Memory: [output]
Key: dequantized weights never touch memory.
We trade compute (dequantize) for bandwidth (smaller reads).
Q4_0 GEMV Example
Let us work through a concrete example with Q4_0 quantization (the simplest format). In Q4_0, every block of 32 weights shares a single FP16 scale factor. Each weight is stored as a 4-bit integer (0-15), representing the range [-8, 7] after subtracting 8:
// Q4_0 block structure
struct block_q4_0 {
half scale; // 2 bytes: shared scale for 32 weights
uchar packed[16]; // 16 bytes: 32 x 4-bit values packed into pairs
};
// Total: 18 bytes for 32 weights (4.5 bits/weight effective)
// Q4_0 GEMV: dequantize on the fly
kernel void gemv_q4_0(
device const half* x [[buffer(0)]], // [K]
device const block_q4_0* W [[buffer(1)]], // [N x K/32] blocks
device half* y [[buffer(2)]], // [N]
constant uint& K [[buffer(3)]],
constant uint& N [[buffer(4)]],
uint sgid [[simdgroup_index_in_threadgroup]],
uint lane [[thread_index_in_simdgroup]],
uint tg_id [[threadgroup_position_in_grid]])
{
uint n = tg_id * 8 + sgid;
if (n >= N) return;
uint num_blocks = K / 32;
float sum = 0.0f;
// Each lane processes every 32nd block
for (uint b = lane; b < num_blocks; b += 32) {
// Load the quantized block
device const block_q4_0& blk = W[n * num_blocks + b];
half scale = blk.scale;
// Dequantize and multiply all 32 weights in this block
for (uint j = 0; j < 16; j++) {
uchar packed = blk.packed[j];
// Unpack two 4-bit values
int8_t w0 = (packed & 0x0F) - 8; // Low nibble
int8_t w1 = (packed >> 4) - 8; // High nibble
// Dequantize: weight = scale * quantized_value
float dq0 = float(scale) * float(w0);
float dq1 = float(scale) * float(w1);
// Multiply with input
uint k_idx = b * 32 + j * 2;
sum += dq0 * float(x[k_idx]);
sum += dq1 * float(x[k_idx + 1]);
}
}
sum = simd_sum(sum);
if (lane == 0) {
y[n] = half(sum);
}
}
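The nibble unpacking is easy to get wrong, so here is a CPU reference for dequantizing one Q4_0 block (a sketch that uses a float scale instead of the on-disk FP16, purely for illustration):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// CPU reference for Q4_0 dequantization as described above: 32 weights share
// one scale, and each weight is a 4-bit value in [0, 15] that maps to [-8, 7].
struct BlockQ4_0 {
    float scale;        // float stand-in for the on-disk FP16 scale
    uint8_t packed[16]; // 32 x 4-bit values, two per byte
};

std::vector<float> dequant_q4_0(const BlockQ4_0& blk) {
    std::vector<float> out(32);
    for (int j = 0; j < 16; j++) {
        int w0 = (blk.packed[j] & 0x0F) - 8; // low nibble
        int w1 = (blk.packed[j] >> 4) - 8;   // high nibble
        out[2 * j + 0] = blk.scale * w0;
        out[2 * j + 1] = blk.scale * w1;
    }
    return out;
}
```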
The Bandwidth Win
The reason quantization is so important for GEMV performance is straightforward arithmetic. Since GEMV is bandwidth bound, reducing the data size directly speeds things up:
Bandwidth Savings with Quantization
====================================
Weight matrix: 4096 x 4096 = 16.7M weights
FP16: 16.7M * 2 bytes = 33.6 MB (baseline)
Q8_0: 16.7M * 1.06 bytes = 17.7 MB (1.9x faster)
Q4_0: 16.7M * 0.56 bytes = 9.4 MB (3.6x faster)
Q4_K: 16.7M * 0.57 bytes = 9.5 MB (3.5x faster)
Q2_K: 16.7M * 0.34 bytes = 5.7 MB (5.9x faster)
At 200 GB/s memory bandwidth:
FP16: 0.168 ms per layer
Q4_0: 0.047 ms per layer <-- 3.6x speedup!
Decode speed for 7B model:
FP16: 14.0 GB / 200 GB/s = 70 ms/token -> 14 tok/s
Q4_0: 3.9 GB / 200 GB/s = 19 ms/token -> 52 tok/s
The compute cost of dequantization is negligible compared to the bandwidth savings. A few shifts and multiplies per weight value are cheap when the alternative is waiting for twice as many bytes to arrive from memory.
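The effective bytes-per-weight figures above follow directly from block size and block overhead. A small sketch (hypothetical helpers) reproduces them:

```cpp
#include <cassert>

// Effective bytes per weight for block-quantized formats: the per-block
// metadata (scales, mins) is amortized over the block size. For Q4_0,
// a block is 2 bytes of scale + 16 packed bytes covering 32 weights.
double bytes_per_weight(double block_bytes, double weights_per_block) {
    return block_bytes / weights_per_block;
}
// Bandwidth speedup over FP16 (2 bytes/weight) for a bandwidth-bound GEMV.
double speedup_vs_fp16(double block_bytes, double weights_per_block) {
    return 2.0 / bytes_per_weight(block_bytes, weights_per_block);
}
```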
Quantized GEMM: A Different Story
For GEMM (prefill), the situation is more nuanced. GEMM is compute bound, so reducing weight size does not automatically speed things up – you are not bandwidth limited in the first place. However, quantized GEMM is still useful because:
- It reduces memory footprint, letting larger models fit in memory
- For moderate M values (16-64), you might still be bandwidth bound
- The SIMD group MMA operations work on FP16/FP32 tiles, with no direct support for 4-bit operands, so dequantization has to happen somewhere regardless
The typical approach is to dequantize tiles of weights into threadgroup memory in FP16, then use the standard SIMD group MMA operations on the dequantized tiles:
Quantized GEMM: Dequantize into Threadgroup Memory
====================================================
Device Memory: Threadgroup Memory: Registers:
+-----------+ +----------+ +--------+
| Q4 Weights| --load-->| FP16 Tile| --MMA--> | Accum |
| (compact) | | (32x32) | | (8x8) |
+-----------+ +----------+ +--------+
1. Cooperatively load Q4 blocks from device memory
2. Dequantize to FP16 in threadgroup memory
3. Use simdgroup_multiply_accumulate on FP16 tiles
4. Repeat for all K-tiles
Worked Example: Full Thread Assignment
Let us trace through a complete example to see how threads are assigned for a real GEMV. Consider:
Problem: y = x * W
x: [1, 4096] (one token's hidden state)
W: [4096, 4096] (Q projection weight)
y: [1, 4096] (query vector)
Configuration:
- Threadgroup size: 256 threads = 8 SIMD groups x 32 lanes
- Each SIMD group: 1 output element
- Grid: 4096 / 8 = 512 threadgroups
Memory layout: W stored transposed as Wt [4096 x 4096]
Let us follow Threadgroup 0, SIMD group 3, Lane 17:
Thread Identity
===============
Threadgroup: 0
SIMD group: 3
Lane: 17
Output assignment
=================
output index n = threadgroup * 8 + simd_group = 0 * 8 + 3 = 3
This thread helps compute y[3].
Work assignment
===============
Lane 17 processes every 32nd element along K:
k = 17, 49, 81, 113, ..., 4081
Total iterations: 4096 / 32 = 128
For each iteration (e.g., k=17):
partial_sum += x[17] * Wt[3 * 4096 + 17]
After all 128 iterations:
partial_sum = x[17]*Wt[3,17] + x[49]*Wt[3,49] + ... + x[4081]*Wt[3,4081]
Reduction
=========
simd_sum(partial_sum) adds all 32 lanes' partial sums:
y[3] = sum over all k: x[k] * Wt[3, k]
= dot(x, column 3 of W)
Only lane 0 of SIMD group 3 writes y[3] to memory.
Summary for entire kernel
=========================
512 threadgroups * 8 SIMD groups = 4096 dot products
Each SIMD group: 32 threads * 128 iterations = 4096 multiply-adds
Total: 4096 * 4096 = 16.7M multiply-adds = K * N (i.e. 2*K*N = 33.5M FLOPs)
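The lane assignment can be checked with a short Python simulation of one SIMD group: 32 lane-strided partial sums over K, then a simd_sum-style reduction. (The `wt_row` here stands in for row 3 of Wt; the values are random.)

```python
import random

K, LANES = 4096, 32
random.seed(0)
x = [random.uniform(-1, 1) for _ in range(K)]
wt_row = [random.uniform(-1, 1) for _ in range(K)]   # stands in for row n=3 of Wt

# Each lane strides through K with step 32, e.g. lane 17 touches k = 17, 49, ..., 4081
partial = [0.0] * LANES
for lane in range(LANES):
    for k in range(lane, K, LANES):                  # 4096 / 32 = 128 iterations
        partial[lane] += x[k] * wt_row[k]

y3 = sum(partial)                                    # the simd_sum reduction
ref = sum(xi * wi for xi, wi in zip(x, wt_row))      # plain dot product
assert abs(y3 - ref) < 1e-9                          # same result, different order
assert list(range(17, K, LANES))[-1] == 4081 and len(range(17, K, LANES)) == 128
```

The two assertions confirm the trace above: lane 17's last index is 4081, each lane does 128 iterations, and the lane-partials reduce to the exact dot product.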
Choosing the Right Strategy
Here is a decision tree for choosing the right matmul strategy:
Matmul Strategy Decision Tree
==============================
What is M?
|
+-------------+-------------+
| | |
M = 1 2-16 > 16
| | |
GEMV Batched GEMV GEMM
| | |
+----+----+ +----+----+ +----+----+
| | | | | |
Small N Large N | Use SIMD
(4096) (32K+) Share Group
| | weight MMA
1 SIMD Lots of loads with tiling
group threadgroups |
per out +----+----+
| | |
Standard FP16 Quantized
SIMD GEMV tiles (dequant
+ MMA into TG
memory)
And here are rough performance expectations on Apple M-series GPUs:
Performance Expectations (M2 Pro, ~200 GB/s)
=============================================
Operation Size FP16 Q4_0
--------- ---- ---- ----
GEMV (decode, 1 layer) [1,4096]*[4096,4096] 0.17ms 0.05ms
GEMV (vocab projection) [1,4096]*[4096,32000] 1.28ms 0.36ms
Batched GEMV (M=4) [4,4096]*[4096,4096] 0.17ms 0.06ms
GEMM (prefill, M=512) [512,4096]*[4096,4096] ~2ms ~2ms
Note: Batched GEMV with M=4 is barely slower than M=1 GEMV because
the same weight data is read -- only the compute increases.
GEMM times are similar for FP16 and Q4 because GEMM is compute-bound.
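The FP16 and Q4_0 GEMV rows can be reproduced from first principles: a GEMV is bandwidth bound, so its time is roughly weight bytes divided by bandwidth. A sketch, assuming ~200 GB/s and that weight traffic dominates (Q4_0 packs 32 weights into 18 bytes, i.e. 0.5625 bytes per weight):

```python
def gemv_time_ms(k, n, bytes_per_weight, bandwidth_gb_s=200):
    """Bandwidth-bound GEMV estimate: time to stream the [k x n] weight matrix."""
    traffic_bytes = k * n * bytes_per_weight      # activation traffic is negligible
    return traffic_bytes / (bandwidth_gb_s * 1e9) * 1e3

print(round(gemv_time_ms(4096, 4096, 2.0), 2))     # FP16 hidden proj: ~0.17 ms
print(round(gemv_time_ms(4096, 32000, 2.0), 2))    # FP16 vocab proj:  ~1.31 ms
print(round(gemv_time_ms(4096, 4096, 0.5625), 2))  # Q4_0 (4.5 b/w):   ~0.05 ms
```

The vocab projection comes out at ~1.31 ms here versus 1.28 ms in the table; the gap is just the effective bandwidth achieved versus the nominal number.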
Summary
Matrix multiplication is the heart of neural network inference. On Metal GPUs, you need different strategies depending on the shape:
- GEMV (M=1, decode): Bandwidth bound. Split the K dimension across SIMD lanes, reduce with simd_sum. Optimize for coalesced memory access and maximize bandwidth utilization.
- Batched GEMV (M=2-16): Still bandwidth bound but with some weight reuse. Load weights once, apply to all M rows. Significant speedup over M separate GEMVs.
- GEMM (M>16, prefill): Can be compute bound with proper tiling. Use cooperative loading into threadgroup memory and SIMD group MMA for maximum throughput.
- Quantized variants: For GEMV, quantization directly translates to speedup via reduced bandwidth. Dequantize on-the-fly in registers. For GEMM, dequantize into threadgroup memory tiles, then use standard MMA.
The next chapter will look at where these matmuls appear in attention: Q, K, and V are each produced by a matmul, and the attention computation itself then introduces a different kind of matrix multiplication.
The Attention Mechanism
If matrix multiplication is the workhorse of transformer inference, attention is the brain. It is the mechanism that allows each token to look at every other token in the sequence and decide what information is relevant. Without attention, a transformer is just a stack of feed-forward networks – powerful but unable to model relationships between positions in a sequence.
In this chapter, we are going to take attention apart piece by piece. We will start with the mathematical formulation, walk through a complete numerical example, explore multi-head and grouped-query variants, understand the KV cache, and analyze the computational complexity. By the end, you will understand every detail of how attention works, which will set us up perfectly for the FlashAttention optimization in the next chapter.
The Core Idea: Queries, Keys, and Values
Attention is fundamentally a lookup mechanism. Think of it like a fuzzy dictionary lookup: you have a query (“what am I looking for?”), a set of keys (“what is available?”), and corresponding values (“what information do those keys hold?”). The twist is that instead of an exact match, you compute a similarity score between the query and every key, then return a weighted combination of the values.
Attention as Fuzzy Dictionary Lookup
=====================================
Query: "What comes after 'the'?"
Keys: Values: Similarity:
---- ------ -----------
"the" ---> [0.2, 0.8, ...] 0.85 (high -- relevant!)
"cat" ---> [0.5, 0.1, ...] 0.72 (medium)
"sat" ---> [0.3, 0.6, ...] 0.15 (low)
"on" ---> [0.1, 0.4, ...] 0.62 (medium)
Output = 0.85 * val("the") + 0.72 * val("cat") + 0.15 * val("sat") + ...
(after normalizing similarities with softmax)
But where do queries, keys, and values come from? They are all produced by linear projections of the input embeddings.
Self-Attention: Q, K, V Projections
Given an input sequence of token embeddings X with shape [seq_len, d_model], we
produce Q, K, and V through three separate linear projections:
Q = X * W_Q where W_Q has shape [d_model, d_k]
K = X * W_K where W_K has shape [d_model, d_k]
V = X * W_V where W_V has shape [d_model, d_v]
In the original transformer paper and most modern LLMs, the full (pre-head-split) projections use d_k = d_v = d_model; once split across heads, each head works with d_model / n_heads dimensions, as we will see shortly. Each projection is a matrix multiplication – exactly the kind we optimized in the previous chapter.
Self-Attention Projections
===========================
Input X: [seq_len x d_model]
X ----+-----> [W_Q] -----> Q [seq_len x d_k]
|
+-----> [W_K] -----> K [seq_len x d_k]
|
+-----> [W_V] -----> V [seq_len x d_v]
"Self" attention means Q, K, V all come from the SAME input X.
This is in contrast to cross-attention where Q comes from one
source and K, V come from another (used in encoder-decoder models).
The Attention Equation
Once we have Q, K, and V, the attention computation is:
Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V
Let us break this down step by step.
Step 1: Compute Attention Scores (Q * K^T)
We multiply Q by the transpose of K. If Q has shape [seq_len, d_k] and K has shape
[seq_len, d_k], then Q * K^T has shape [seq_len, seq_len]. Each element (i, j) is
the dot product of query vector i with key vector j – a measure of how much token i
should attend to token j.
Q * K^T --> Attention Score Matrix
======================================
Q = [ q0 ] K^T = [ k0^T k1^T k2^T k3^T ]
[ q1 ]
[ q2 ] (each qi, kj is a vector of dimension d_k)
[ q3 ]
Q * K^T = [ q0.k0 q0.k1 q0.k2 q0.k3 ]
[ q1.k0 q1.k1 q1.k2 q1.k3 ]
[ q2.k0 q2.k1 q2.k2 q2.k3 ]
[ q3.k0 q3.k1 q3.k2 q3.k3 ]
Element (i,j) = dot product of query i with key j
= how much should token i attend to token j?
Step 2: Scale by sqrt(d_k)
We divide every score by sqrt(d_k). Why? Without scaling, the variance of the dot
products grows in proportion to the dimension: if q and k have unit-variance components,
q . k has variance d_k, so for d_k = 64 typical scores are around sqrt(64) = 8, with
outliers considerably larger. When you pass large numbers into softmax, the output
becomes extremely peaked – almost all the weight goes to one key, and the gradients
vanish. Dividing by sqrt(64) = 8 brings the variance back to ~1.
Why Scale?
==========
Without scaling (d_k = 64):
Dot product variance: ~64, so scores like 16 are routine
softmax([16, 2, 1, 3]) = [1.000, 0.000, 0.000, 0.000]
--> Almost a hard lookup. Gradients vanish.
With scaling:
Scaled scores: [2.0, 0.25, 0.125, 0.375]
softmax([2.0, 0.25, 0.125, 0.375]) = [0.656, 0.114, 0.101, 0.129]
--> Still favors the top key, but gradients can flow.
The scaling keeps variance of dot products at ~1.0
regardless of dimension, assuming inputs have unit variance.
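The variance claim in the last two lines is easy to verify empirically; a Python sketch sampling random unit-variance vectors:

```python
import random, math

random.seed(0)
d_k, N = 64, 20000   # head dimension; number of sampled (q, k) pairs

raw, scaled = [], []
for _ in range(N):
    q = [random.gauss(0, 1) for _ in range(d_k)]
    k = [random.gauss(0, 1) for _ in range(d_k)]
    dot = sum(a * b for a, b in zip(q, k))
    raw.append(dot)
    scaled.append(dot / math.sqrt(d_k))

var = lambda xs: sum(x * x for x in xs) / len(xs)  # mean of q.k is ~0
print(round(var(raw), 1))     # close to d_k = 64
print(round(var(scaled), 2))  # close to 1.0
```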
Step 3: Apply Softmax
The softmax function converts the scaled scores into a probability distribution over keys for each query position. Each row of the score matrix is independently softmaxed:
softmax(z_i) = exp(z_i) / sum_j(exp(z_j))
After softmax, each row sums to 1.0, and all values are between 0 and 1. These are the attention weights.
Softmax: Scores --> Weights
============================
Before softmax (one row, after scaling):
[ 2.1, 0.5, -0.3, 1.2 ]
exp each:
[ 8.17, 1.65, 0.74, 3.32 ]
sum = 13.88
Divide by sum:
[ 0.589, 0.119, 0.053, 0.239 ]
^ ^
Token 0 gets most Token 3 gets
attention second most
Sum = 1.0 (it is a probability distribution)
Step 4: Weighted Sum of Values
Finally, we multiply the attention weights by V. Each output vector is a weighted combination of value vectors, where the weights come from softmax:
Weighted Value Sum
===================
Attention weights (one row): [0.589, 0.119, 0.053, 0.239]
V = [ v0 ] = [ 0.1, 0.9, 0.3 ]
[ v1 ] [ 0.5, 0.2, 0.8 ]
[ v2 ] [ 0.7, 0.4, 0.1 ]
[ v3 ] [ 0.3, 0.6, 0.5 ]
output = 0.589 * v0 + 0.119 * v1 + 0.053 * v2 + 0.239 * v3
= 0.589 * [0.1, 0.9, 0.3]
+ 0.119 * [0.5, 0.2, 0.8]
+ 0.053 * [0.7, 0.4, 0.1]
+ 0.239 * [0.3, 0.6, 0.5]
= [0.227, 0.718, 0.397]
The output is dominated by v0 (weight 0.589) and v3 (weight 0.239).
The Full Pipeline
Complete Attention Computation
===============================
Input: X [seq_len x d_model]
|
+--[W_Q]--> Q [seq_len x d_k]
+--[W_K]--> K [seq_len x d_k]
+--[W_V]--> V [seq_len x d_v]
|
v
Scores = Q * K^T [seq_len x seq_len]
|
v
Scaled = Scores / sqrt(d_k) [seq_len x seq_len]
|
v
Masked = apply causal mask [seq_len x seq_len]
|
v
Weights = softmax(Masked) [seq_len x seq_len]
|
v
Output = Weights * V [seq_len x d_v]
|
v
Projected = Output * W_O [seq_len x d_model]
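The diagram translates almost line by line into a reference implementation. This is a plain-Python sketch of the attention core (scores, scale, causal mask, softmax rows, weighted sum); the final W_O projection is omitted:

```python
import math

def softmax(row):
    m = max(row)                     # subtract the max for numerical stability
    e = [math.exp(v - m) for v in row]
    s = sum(e)
    return [v / s for v in e]

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k) + causal mask) * V. Q, K: [n x d_k]; V: [n x d_v]."""
    n, d_k = len(Q), len(Q[0])
    scores = matmul(Q, [list(col) for col in zip(*K)])       # [n x n]
    scale = 1.0 / math.sqrt(d_k)
    weights = [softmax([scores[i][j] * scale if j <= i else -math.inf
                        for j in range(n)])                  # mask, then softmax
               for i in range(n)]
    return matmul(weights, V)                                # [n x d_v]
```

Note this materializes the full [n x n] weights matrix, O(n^2) time and memory per head, which is exactly what FlashAttention will later avoid.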
Multi-Head Attention: Divide and Conquer
In practice, we do not compute a single attention function. Instead, we split the representation into multiple heads, each operating on a smaller dimension. This allows the model to attend to information from different representation subspaces simultaneously.
If d_model = 512 and we use h = 8 heads, then each head operates on
d_k = d_model / h = 64 dimensions.
Multi-Head Attention
=====================
d_model = 512, h = 8 heads, d_k = 64 per head
Input X: [seq_len x 512]
Full Q projection: X * W_Q --> Q_full [seq_len x 512]
Split into 8 heads:
Q_full = [ Q_head0 | Q_head1 | Q_head2 | ... | Q_head7 ]
[s x 64] [s x 64] [s x 64] [s x 64]
Similarly for K and V.
Each head computes attention independently:
head_i = Attention(Q_head_i, K_head_i, V_head_i)
Concatenate results:
MultiHead = [ head_0 | head_1 | ... | head_7 ] [seq_len x 512]
Final projection:
Output = MultiHead * W_O [seq_len x 512]
Why multiple heads? Different heads can learn different types of relationships:
What Different Heads Might Learn
=================================
Head 0: Syntactic dependencies ("the" attends to its noun)
Head 1: Positional patterns (attend to previous token)
Head 2: Semantic similarity (similar words attend to each other)
Head 3: Coreference ("it" attends to its referent)
Head 4: Long-range dependencies (closing bracket attends to opening)
...
Each head has its own Q, K, V projections, so each head
learns to look for different patterns in the data.
In terms of implementation, multi-head attention is usually done by performing the full
projection to get [seq_len, d_model] shaped Q, K, V, then reshaping to
[seq_len, n_heads, d_k] (or equivalently [n_heads, seq_len, d_k] after transposing).
The attention is then computed in parallel across all heads – on a GPU, this is a batched
operation.
Implementation: Reshape and Batch
==================================
After Q projection: Q [seq_len x d_model]
Reshape to: Q [seq_len x n_heads x d_k]
Transpose: Q [n_heads x seq_len x d_k]
Now attention is a batched operation:
For each head h in [0, n_heads):
scores_h = Q[h] * K[h]^T [seq_len x seq_len]
weights_h = softmax(scores_h) [seq_len x seq_len]
output_h = weights_h * V[h] [seq_len x d_v]
This is embarrassingly parallel -- all heads computed simultaneously.
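The reshape-and-batch bookkeeping can be sketched with nested lists (a toy [4 x 8] input, two heads); the round trip shows that splitting and concatenating are exact inverses:

```python
def split_heads(X, n_heads):
    """[seq_len x d_model] -> [n_heads x seq_len x d_k], the reshape+transpose above."""
    d_k = len(X[0]) // n_heads
    return [[row[h * d_k:(h + 1) * d_k] for row in X] for h in range(n_heads)]

def merge_heads(heads):
    """[n_heads x seq_len x d_k] -> [seq_len x d_model], the concatenation."""
    return [sum((h[i] for h in heads), []) for i in range(len(heads[0]))]

X = [[float(i * 8 + j) for j in range(8)] for i in range(4)]   # [4 x 8], d_model = 8
heads = split_heads(X, n_heads=2)
assert len(heads) == 2 and len(heads[0]) == 4 and len(heads[0][0]) == 4
assert merge_heads(heads) == X    # splitting and concatenating are exact inverses
```

On the GPU these "reshapes" are just index arithmetic into one contiguous buffer; no data actually moves.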
Grouped Query Attention (GQA)
Standard multi-head attention requires storing separate K and V for every head. For a
model with 32 heads generating 4096-length sequences with d_k = 128, that is:
KV cache per layer = 2 * 32 * 4096 * 128 * 2 bytes = 64 MB
For 32 layers: 2 GB just for the KV cache!
Grouped Query Attention (GQA) reduces this by sharing K and V across groups of query heads. Instead of 32 KV heads, you might use only 8. Each KV head serves a group of 4 query heads.
Multi-Head vs Grouped Query vs Multi-Query
============================================
Multi-Head Attention (MHA): n_kv_heads = n_q_heads = 32
-------------------------------------------------------
Q heads: Q0 Q1 Q2 Q3 Q4 Q5 ... Q31
K heads: K0 K1 K2 K3 K4 K5 ... K31
V heads: V0 V1 V2 V3 V4 V5 ... V31
(Each Q head has its own K,V pair)
Grouped Query Attention (GQA): n_kv_heads = 8, group_size = 4
---------------------------------------------------------------
Q heads: Q0 Q1 Q2 Q3 | Q4 Q5 Q6 Q7 | Q8 ... | Q28 Q29 Q30 Q31
K heads: K0 | K1 | K2 | ... | K7
V heads: V0 | V1 | V2 | ... | V7
(4 Q heads share each K,V pair)
Multi-Query Attention (MQA): n_kv_heads = 1
---------------------------------------------
Q heads: Q0 Q1 Q2 Q3 Q4 Q5 ... Q31
K heads: K0
V heads: V0
(All Q heads share a single K,V pair)
The memory savings are substantial:
KV Cache Size Comparison (per layer)
======================================
seq_len = 4096, d_k = 128, FP16
MHA (32 KV heads): 2 * 32 * 4096 * 128 * 2 = 64 MB
GQA (8 KV heads): 2 * 8 * 4096 * 128 * 2 = 16 MB (4x reduction)
MQA (1 KV head): 2 * 1 * 4096 * 128 * 2 = 2 MB (32x reduction)
For 32 layers:
MHA: 2048 MB
GQA: 512 MB <-- Llama 2 70B, Llama 3
MQA: 64 MB
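The cache-size arithmetic generalizes to any configuration; a small helper (FP16, sizes in MiB):

```python
def kv_cache_mb(n_kv_heads, seq_len=4096, d_k=128, bytes_per=2, n_layers=1):
    # 2 tensors (K and V) * heads * positions * head dim * element size
    return 2 * n_kv_heads * seq_len * d_k * bytes_per * n_layers / (1024 ** 2)

print(kv_cache_mb(32))              # MHA: 64.0 MB per layer
print(kv_cache_mb(8))               # GQA: 16.0
print(kv_cache_mb(1))               # MQA: 2.0
print(kv_cache_mb(8, n_layers=32))  # 512.0 MB for the whole GQA model
```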
In the GPU kernel, GQA changes how you index into the KV cache. For query head h, the
corresponding KV head is h / group_size:
// GQA: mapping query heads to KV heads
uint q_head = head_index; // 0..31
uint kv_head = q_head / group_size; // 0..7 (if group_size=4)
// Load Q from q_head's slice
device const half* q = Q + q_head * d_k;
// Load K, V from kv_head's slice (shared by 4 Q heads)
device const half* k = K_cache + kv_head * seq_len * d_k;
device const half* v = V_cache + kv_head * seq_len * d_v;
The KV Cache: Why We Store K and V
During autoregressive decoding, we generate one token at a time. At step t, we need to
compute attention between the new token’s query and all previous tokens’ keys and values.
Without caching, we would need to reproject all previous tokens through W_K and W_V every
single step – a massive waste of computation.
The KV cache stores the K and V vectors for all previous tokens. At each decoding step, we only compute K and V for the new token and append them to the cache.
KV Cache During Decoding
=========================
Step 1: Process token "The"
+---------+
| K: k_0 | V: v_0
+---------+
Q = q_0, attend to [k_0]
Step 2: Process token "cat"
+---------+---------+
| K: k_0 | K: k_1 | V: v_0, v_1
+---------+---------+
Q = q_1, attend to [k_0, k_1]
Step 3: Process token "sat"
+---------+---------+---------+
| K: k_0 | K: k_1 | K: k_2 | V: v_0, v_1, v_2
+---------+---------+---------+
Q = q_2, attend to [k_0, k_1, k_2]
Step 4: Process token "on"
+---------+---------+---------+---------+
| K: k_0 | K: k_1 | K: k_2 | K: k_3 | V: v_0, v_1, v_2, v_3
+---------+---------+---------+---------+
Q = q_3, attend to [k_0, k_1, k_2, k_3]
At each step, we compute K, V for only the NEW token
and append to the cache. The cache grows by one entry per step.
Without KV cache:
Step t: compute K, V for ALL t tokens, then attend
Cost: O(t * d_model * d_k) for projections alone
Total over T steps: O(T^2 * d_model * d_k) -- quadratic!
With KV cache:
Step t: compute K, V for 1 new token, append to cache
Cost: O(d_model * d_k) for projections
Total over T steps: O(T * d_model * d_k) for projections -- linear!
(Attention itself is still O(T * d_k) per step, O(T^2 * d_k) total)
The KV cache is typically stored as a contiguous buffer per layer, per head, pre-allocated to the maximum sequence length:
KV Cache Memory Layout
=======================
For one layer, one KV head:
K cache: [max_seq_len x d_k]
+------+------+------+------+------+------+------+------+
| k_0 | k_1 | k_2 | k_3 | .... | .... | .... | .... |
+------+------+------+------+------+------+------+------+
^ ^ ^
filled positions | empty (pre-allocated)
current position
V cache: [max_seq_len x d_v]
+------+------+------+------+------+------+------+------+
| v_0 | v_1 | v_2 | v_3 | .... | .... | .... | .... |
+------+------+------+------+------+------+------+------+
At each decode step:
1. Compute k_new, v_new for the new token
2. Write k_new to K_cache[current_pos]
3. Write v_new to V_cache[current_pos]
4. Compute attention: q_new vs K_cache[0..current_pos]
5. Increment current_pos
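The five-step loop can be wrapped in a tiny cache object. A plain-Python sketch for one layer and one KV head (assuming d_v = d_k and pre-allocated storage, as in the layout diagram):

```python
import math

class KVCache:
    """One layer, one KV head; assumes d_v == d_k. Storage is pre-allocated."""
    def __init__(self, max_seq_len, d_k):
        self.K = [[0.0] * d_k for _ in range(max_seq_len)]
        self.V = [[0.0] * d_k for _ in range(max_seq_len)]
        self.pos = 0                                   # number of filled positions

    def decode_step(self, q, k_new, v_new):
        # Steps 1-3: append the new token's K, V at current_pos
        self.K[self.pos] = k_new
        self.V[self.pos] = v_new
        self.pos += 1
        # Step 4: attend q against every cached position (two-pass softmax)
        d_k = len(q)
        scores = [sum(a * b for a, b in zip(q, self.K[i])) / math.sqrt(d_k)
                  for i in range(self.pos)]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        tot = sum(w)
        return [sum(w[i] * self.V[i][d] for i in range(self.pos)) / tot
                for d in range(d_k)]

cache = KVCache(max_seq_len=8, d_k=2)
out = cache.decode_step(q=[1.0, 0.0], k_new=[1.0, 0.0], v_new=[0.5, 0.5])
assert out == [0.5, 0.5] and cache.pos == 1   # first token attends only to itself
```

Step 5 (incrementing current_pos) happens inside decode_step, so each call grows the attended window by one.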
Causal Masking
In autoregressive language models, a token should only attend to tokens that came before it (or at the same position). This is called causal masking or the “no peeking into the future” constraint.
During prefill (when processing the entire prompt at once), we enforce this by adding a mask to the attention scores before softmax:
Causal Mask (4 tokens)
=======================
Before masking: After masking:
[ 2.1 0.5 1.3 0.8 ] [ 2.1 -inf -inf -inf ]
[ 0.3 1.7 0.9 0.4 ] [ 0.3 1.7 -inf -inf ]
[ 1.1 0.2 2.3 1.5 ] [ 1.1 0.2 2.3 -inf ]
[ 0.7 1.4 0.6 1.9 ] [ 0.7 1.4 0.6 1.9 ]
Mask matrix:
[ 0 -inf -inf -inf ]
[ 0 0 -inf -inf ]
[ 0 0 0 -inf ]
[ 0 0 0 0 ]
Token i can only see tokens 0..i (the lower triangle).
-inf becomes 0 after softmax, so masked tokens contribute nothing.
After softmax, the -inf entries become 0, effectively removing those tokens from the
weighted sum. The result is that each token’s output depends only on itself and all
preceding tokens.
After softmax with causal mask:
================================
[ 1.000 0.000 0.000 0.000 ] Token 0: only sees itself
[ 0.198 0.802 0.000 0.000 ] Token 1: sees tokens 0, 1
[ 0.212 0.086 0.702 0.000 ] Token 2: sees tokens 0, 1, 2
[ 0.138 0.278 0.125 0.459 ] Token 3: sees all four tokens
Each row still sums to 1.0: the softmax is effectively
taken over the unmasked entries only.
During decoding, causal masking is implicit – the query is for position t and the KV
cache contains only positions 0..t (the past plus the current token), so there is
nothing to mask.
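The mask and its interaction with softmax fit in a few lines of Python; note that math.exp(-inf) is exactly 0.0, which is what makes the additive -inf trick work:

```python
import math

def causal_mask(n):
    """Additive [n x n] mask: 0 on and below the diagonal, -inf above it."""
    return [[0.0 if j <= i else -math.inf for j in range(n)] for i in range(n)]

def softmax(row):
    m = max(row)
    e = [math.exp(v - m) for v in row]   # exp(-inf) == 0.0: masked entries vanish
    s = sum(e)
    return [v / s for v in e]

scores = [[2.1, 0.5, 1.3, 0.8],          # the 4-token score matrix from above
          [0.3, 1.7, 0.9, 0.4],
          [1.1, 0.2, 2.3, 1.5],
          [0.7, 1.4, 0.6, 1.9]]

weights = [softmax([s + m for s, m in zip(srow, mrow)])
           for srow, mrow in zip(scores, causal_mask(4))]
print([round(w, 3) for w in weights[1]])  # [0.198, 0.802, 0.0, 0.0]
```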
Computational Complexity
Let us carefully count the operations in attention for a sequence of length n with
head dimension d:
Operation Counts for Attention
===============================
1. Q * K^T: matrix multiply [n x d] * [d x n] = [n x n]
Operations: 2 * n * n * d (multiply + add)
2. Scale: divide each of n*n elements by sqrt(d)
Operations: n * n
3. Softmax: for each of n rows, compute exp and normalize over n elements
Operations: ~3 * n * n (exp, sum, divide)
4. Weights * V: matrix multiply [n x n] * [n x d] = [n x d]
Operations: 2 * n * n * d
Total: ~4 * n^2 * d + 4 * n^2
= O(n^2 * d)
For n = 4096, d = 128:
4 * 4096^2 * 128 = ~8.6 billion operations per head
Times 32 heads = ~274 billion operations
Memory for attention matrix: n * n * sizeof(float) per head
= 4096 * 4096 * 4 = 64 MB per head
= 2 GB for 32 heads <-- This is the problem FlashAttention solves!
The quadratic scaling in sequence length is the fundamental limitation:
How Attention Scales with Sequence Length
==========================================
n Ops (per head) Attn Matrix Size
--- --------------- ----------------
512 ~134M 1 MB
1024 ~537M 4 MB
2048 ~2.15B 16 MB
4096 ~8.59B 64 MB
8192 ~34.4B 256 MB
16384 ~137B 1 GB
32768 ~550B 4 GB
131072 ~8.80T 64 GB (Llama 3.1 context length!)
(Ops = ~4 * n^2 * d with d = 128; matrix sizes are FP32, per head.)
The attention matrix alone exceeds GPU memory for long sequences.
This is why FlashAttention (next chapter) is essential.
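Both columns follow from one-line formulas (the operation count from the derivation above, the matrix size from n^2 FP32 entries); a quick sketch:

```python
def attn_ops(n, d=128):
    # ~2*n^2*d for Q*K^T plus ~2*n^2*d for Weights*V; softmax's ~4*n^2 is negligible
    return 4 * n * n * d

def attn_matrix_bytes(n):
    return n * n * 4    # one FP32 [n x n] attention matrix, per head

for n in [512, 1024, 2048, 4096, 8192, 16384, 32768, 131072]:
    print(f"{n:>7}  {attn_ops(n)/1e9:10.2f} B ops  {attn_matrix_bytes(n)/2**20:10.0f} MiB")
```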
Worked Example: 4-Token, 2-Head Attention
Let us trace through a complete attention computation with concrete numbers. We will use
a tiny example: 4 tokens, 2 heads, d_model = 8, d_k = d_v = 4 per head.
Setup
Model parameters:
d_model = 8
n_heads = 2
d_k = d_v = d_model / n_heads = 4
seq_len = 4
Input tokens: "The cat sat on"
Input embeddings X [4 x 8]:
Token 0 ("The"): [ 1.0, 0.5, 0.3, 0.8, | 0.2, 0.7, 0.4, 0.6 ]
Token 1 ("cat"): [ 0.3, 1.2, 0.7, 0.1, | 0.9, 0.4, 0.5, 0.3 ]
Token 2 ("sat"): [ 0.6, 0.2, 1.1, 0.4, | 0.3, 0.8, 1.0, 0.2 ]
Token 3 ("on"): [ 0.4, 0.7, 0.5, 1.0, | 0.6, 0.3, 0.7, 0.9 ]
Step 1: Compute Q, K, V (via projection)
For simplicity, let us assume the projections have already been computed and show the result after splitting into heads:
After projection and head split:
Head 0 (d_k = 4):
Q0 = [ 1.2, 0.3, 0.5, 0.8 ] K0 = [ 0.9, 0.4, 0.7, 0.2 ] V0 = [ 0.3, 0.8, 0.5, 0.1 ]
[ 0.4, 1.1, 0.2, 0.6 ] [ 0.5, 1.0, 0.3, 0.8 ] [ 0.7, 0.2, 0.9, 0.4 ]
[ 0.7, 0.5, 0.9, 0.3 ] [ 0.8, 0.6, 1.1, 0.5 ] [ 0.4, 0.6, 0.3, 0.8 ]
[ 0.3, 0.8, 0.4, 1.0 ] [ 0.2, 0.7, 0.5, 1.0 ] [ 0.9, 0.5, 0.7, 0.3 ]
Head 1 (d_k = 4):
Q1 = [ 0.6, 0.9, 0.2, 0.4 ] K1 = [ 0.3, 0.7, 0.5, 0.1 ] V1 = [ 0.5, 0.4, 0.2, 0.7 ]
[ 0.8, 0.3, 0.7, 0.5 ] [ 0.6, 0.2, 0.8, 0.4 ] [ 0.2, 0.9, 0.6, 0.3 ]
[ 0.1, 0.6, 0.4, 0.8 ] [ 0.4, 0.5, 0.3, 0.9 ] [ 0.8, 0.3, 0.5, 0.6 ]
[ 0.5, 0.4, 0.9, 0.7 ] [ 0.7, 0.3, 0.6, 0.5 ] [ 0.3, 0.7, 0.4, 0.8 ]
Step 2: Q * K^T for Head 0
Scores0 = Q0 * K0^T [4 x 4]
Score(0,0) = Q0[0] . K0[0] = 1.2*0.9 + 0.3*0.4 + 0.5*0.7 + 0.8*0.2
= 1.08 + 0.12 + 0.35 + 0.16 = 1.71
Score(0,1) = Q0[0] . K0[1] = 1.2*0.5 + 0.3*1.0 + 0.5*0.3 + 0.8*0.8
= 0.60 + 0.30 + 0.15 + 0.64 = 1.69
Score(0,2) = Q0[0] . K0[2] = 1.2*0.8 + 0.3*0.6 + 0.5*1.1 + 0.8*0.5
= 0.96 + 0.18 + 0.55 + 0.40 = 2.09
Score(0,3) = Q0[0] . K0[3] = 1.2*0.2 + 0.3*0.7 + 0.5*0.5 + 0.8*1.0
= 0.24 + 0.21 + 0.25 + 0.80 = 1.50
(Computing the remaining entries the same way...)
Scores0 = [ 1.71 1.69 2.09 1.50 ]
[ 1.06 1.84 1.50 1.55 ]
[ 1.52 1.36 2.00 1.24 ]
[ 1.07 1.87 1.66 1.82 ]
Step 3: Scale
sqrt(d_k) = sqrt(4) = 2.0
Scaled0 = Scores0 / 2.0
Scaled0 = [ 0.855 0.845 1.045 0.750 ]
[ 0.530 0.920 0.750 0.775 ]
[ 0.760 0.680 1.000 0.620 ]
[ 0.535 0.935 0.830 0.910 ]
Step 4: Apply Causal Mask
Masked0 = [ 0.855 -inf -inf -inf ]
[ 0.530 0.920 -inf -inf ]
[ 0.760 0.680 1.000 -inf ]
[ 0.535 0.935 0.830 0.910 ]
Step 5: Softmax (row by row)
Row 0: softmax([0.855, -inf, -inf, -inf])
= [1.000, 0.000, 0.000, 0.000]
Row 1: softmax([0.530, 0.920]) (ignoring -inf entries)
exp: [1.699, 2.509] sum = 4.208
= [0.404, 0.596, 0.000, 0.000]
Row 2: softmax([0.760, 0.680, 1.000])
exp: [2.138, 1.974, 2.718] sum = 6.830
= [0.313, 0.289, 0.398, 0.000]
Row 3: softmax([0.535, 0.935, 0.830, 0.910])
exp: [1.707, 2.547, 2.293, 2.484] sum = 9.032
= [0.189, 0.282, 0.254, 0.275]
Weights0 = [ 1.000 0.000 0.000 0.000 ]
[ 0.404 0.596 0.000 0.000 ]
[ 0.313 0.289 0.398 0.000 ]
[ 0.189 0.282 0.254 0.275 ]
Step 6: Weighted Sum of Values
Output0 = Weights0 * V0 [4 x 4]
Output0[0] = 1.000 * V0[0]
= [0.300, 0.800, 0.500, 0.100]
Output0[1] = 0.404 * V0[0] + 0.596 * V0[1]
= 0.404 * [0.3, 0.8, 0.5, 0.1] + 0.596 * [0.7, 0.2, 0.9, 0.4]
= [0.121, 0.323, 0.202, 0.040] + [0.417, 0.119, 0.536, 0.238]
= [0.538, 0.442, 0.738, 0.278]
Output0[2] = 0.313 * V0[0] + 0.289 * V0[1] + 0.398 * V0[2]
= [0.094, 0.250, 0.157, 0.031]
+ [0.202, 0.058, 0.260, 0.116]
+ [0.159, 0.239, 0.119, 0.318]
= [0.455, 0.547, 0.536, 0.465]
Output0[3] = 0.189 * V0[0] + 0.282 * V0[1] + 0.254 * V0[2] + 0.275 * V0[3]
= [0.057, 0.151, 0.095, 0.019]
+ [0.197, 0.056, 0.254, 0.113]
+ [0.102, 0.152, 0.076, 0.203]
+ [0.248, 0.138, 0.193, 0.083]
= [0.604, 0.497, 0.618, 0.418]
Step 7: Concatenate Heads and Project
Head 1 would be computed identically (with its own Q1, K1, V1). Then the two heads’ outputs are concatenated along the last dimension and projected:
Concatenation:
Output0[i]: [d_v = 4 values]
Output1[i]: [d_v = 4 values]
Concat[i] = [Output0[i] | Output1[i]] (d_model = 8 values)
Final = Concat * W_O [4 x 8] * [8 x 8] = [4 x 8]
This produces the final attention output for this layer.
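Traces like this are easy to get wrong by hand, so it is worth checking mechanically. A plain-Python sketch that recomputes head 0 end-to-end from the Q0, K0, V0 of the setup:

```python
import math

# Q0, K0, V0 exactly as in the setup above
Q0 = [[1.2, 0.3, 0.5, 0.8], [0.4, 1.1, 0.2, 0.6],
      [0.7, 0.5, 0.9, 0.3], [0.3, 0.8, 0.4, 1.0]]
K0 = [[0.9, 0.4, 0.7, 0.2], [0.5, 1.0, 0.3, 0.8],
      [0.8, 0.6, 1.1, 0.5], [0.2, 0.7, 0.5, 1.0]]
V0 = [[0.3, 0.8, 0.5, 0.1], [0.7, 0.2, 0.9, 0.4],
      [0.4, 0.6, 0.3, 0.8], [0.9, 0.5, 0.7, 0.3]]

n = d_k = 4
scores = [[sum(q * k for q, k in zip(Q0[i], K0[j])) for j in range(n)]
          for i in range(n)]
assert abs(scores[0][0] - 1.71) < 1e-9          # Score(0,0) from the trace

weights = []
for i in range(n):
    masked = [scores[i][j] / math.sqrt(d_k) if j <= i else -math.inf
              for j in range(n)]                 # scale + causal mask
    m = max(masked)
    e = [math.exp(v - m) for v in masked]
    weights.append([v / sum(e) for v in e])
assert all(abs(sum(row) - 1.0) < 1e-9 for row in weights)

out = [[sum(weights[i][j] * V0[j][d] for j in range(n)) for d in range(d_k)]
       for i in range(n)]
assert out[0] == V0[0]                           # token 0 attends only to itself
```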
Decode Attention vs Prefill Attention
The attention computation looks quite different depending on the inference phase:
Prefill Attention (processing entire prompt)
=============================================
Q: [seq_len x d_k] (all query positions)
K: [seq_len x d_k] (all key positions)
V: [seq_len x d_v] (all value positions)
QK^T: [seq_len x seq_len] -- full attention matrix
Softmax: [seq_len x seq_len]
Output: [seq_len x d_v]
This is a GEMM problem. Matrix-matrix multiplications.
Causal mask applied explicitly.
Decode Attention (generating one token at a time)
==================================================
q: [1 x d_k] (single new query)
K: [seq_len x d_k] (all keys from KV cache)
V: [seq_len x d_v] (all values from KV cache)
q * K^T: [1 x seq_len] -- one row of attention scores
Softmax: [1 x seq_len]
Output: [1 x d_v]
This is a GEMV problem. Vector-matrix multiplications.
No causal mask needed (KV cache only has past tokens).
This distinction is critical for optimization. Prefill attention can use tiled GEMM techniques with high data reuse. Decode attention is bandwidth-bound, reading through the entire KV cache for each head.
Decode Attention: Memory Access Pattern
=========================================
For one head, one decode step at position t:
Read q: 1 * d_k * 2 bytes = 256 bytes (d_k=128)
Read K cache: t * d_k * 2 bytes = varies
Read V cache: t * d_v * 2 bytes = varies
Write output: 1 * d_v * 2 bytes = 256 bytes
At position t = 4096:
Read K: 4096 * 128 * 2 = 1 MB per head
Read V: 4096 * 128 * 2 = 1 MB per head
Total per head: ~2 MB
Total for 32 heads: ~64 MB
At 200 GB/s: ~0.32 ms
This grows linearly with sequence length.
At t = 32768: ~0.5 GB read, ~2.5 ms
Putting It All Together
Here is a Metal kernel sketch showing the structure of decode attention for one head:
// Decode attention: single query vs KV cache
kernel void decode_attention(
device const half* q [[buffer(0)]], // [d_k]
device const half* K_cache [[buffer(1)]], // [seq_len x d_k]
device const half* V_cache [[buffer(2)]], // [seq_len x d_v]
device half* output [[buffer(3)]], // [d_v]
constant uint& seq_len [[buffer(4)]],
constant uint& d_k [[buffer(5)]],
uint tid [[thread_position_in_grid]])
{
// Step 1: Compute attention scores (q . K[i] for each i)
float score = 0.0f;
for (uint d = 0; d < d_k; d++) {
score += float(q[d]) * float(K_cache[tid * d_k + d]);
}
score /= sqrt(float(d_k));
// Step 2: Softmax (requires two passes or online algorithm)
// ... (this is where FlashAttention comes in!)
// Step 3: Weighted sum of values
// output[d] = sum over i: weight[i] * V_cache[i * d_v + d]
// ...
}
This sketch is deliberately incomplete – computing softmax across all positions requires either materializing the full score vector (memory expensive for long sequences) or using the online softmax algorithm that FlashAttention employs. That is exactly what we will cover in the next chapter.
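For reference alongside the Metal sketch, here is the complete computation in plain Python using the classic two-pass (max, then normalize) softmax; the q, K, V values are arbitrary illustration data:

```python
import math

def decode_attention(q, K_cache, V_cache):
    """One head, one decode step. q: [d_k]; K_cache: [t x d_k]; V_cache: [t x d_v]."""
    d_k = len(q)
    # One row of scores: q . K[i] / sqrt(d_k) for every cached position
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K_cache]
    # Two-pass softmax: find the max for stability, then exponentiate and normalize
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    tot = sum(w)
    # Weighted sum over cached values
    return [sum(w[i] * v[d] for i, v in enumerate(V_cache)) / tot
            for d in range(len(V_cache[0]))]

q = [1.0, 0.0]
K = [[0.5, 0.3], [0.8, -0.2], [0.1, 0.7]]
V = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
out = decode_attention(q, K, V)
assert abs(sum(out) - 1.0) < 1e-9   # each V row sums to 1, so the output does too
```

Note the two passes over `scores`: this is exactly the materialization the online algorithm of the next chapter removes.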
Summary
The attention mechanism is the defining feature of transformers. Here are the key takeaways:
- Q, K, V projections are matrix multiplications that transform input embeddings into queries, keys, and values.
- The attention equation softmax(Q * K^T / sqrt(d_k)) * V computes a weighted combination of values based on query-key similarity.
- Multi-head attention splits the representation into independent heads that can learn different attention patterns.
- Grouped Query Attention shares KV heads across multiple query heads, reducing KV cache memory by the group size factor.
- The KV cache stores previously computed keys and values during decoding, avoiding redundant recomputation. It grows linearly with sequence length.
- Causal masking ensures tokens only attend to past positions, enforced via -inf masking before softmax during prefill.
- Complexity is O(n^2 * d) – quadratic in sequence length. The attention matrix itself requires O(n^2) memory, which becomes prohibitive for long sequences.
- Prefill vs decode have very different computational profiles: prefill is a GEMM problem (compute bound), decode is a GEMV problem (bandwidth bound).
In the next chapter, we will see how FlashAttention solves the O(n^2) memory problem through a clever tiling scheme and online softmax computation.
FlashAttention: Theory and Practice
In the previous chapter, we saw that standard attention requires materializing an
[n x n] attention matrix. For a sequence length of 4096, that is 64 MB per head – 2 GB
across 32 heads. For 32,768 tokens, it is 4 GB per head, or 128 GB across 32 heads,
just for attention weights. This is clearly untenable.
FlashAttention is the algorithmic breakthrough that makes long-context attention practical. The key idea is beautifully simple: never materialize the full attention matrix. Instead, compute attention in tiles, maintaining a running softmax that produces a result mathematically identical to the standard algorithm (differing only in floating-point rounding, since the summation order changes).
This chapter will build up the theory from first principles, work through a complete numerical example, and then show how FlashAttention adapts for both prefill and decode phases, including advanced variants for GQA and speculative decoding.
The Problem: O(n^2) Memory
Let us be very concrete about what “materializing the attention matrix” means in the standard algorithm:
Standard Attention Memory Usage
================================
Step 1: S = Q * K^T Allocate [n x n] matrix S
Step 2: S = S / sqrt(d) In-place on S
Step 3: P = softmax(S) Allocate [n x n] matrix P (or overwrite S)
Step 4: O = P * V Allocate [n x d] matrix O
Peak memory: at least one [n x n] matrix in memory simultaneously
n = 4096, FP32:
[n x n] = 4096 * 4096 * 4 bytes = 64 MB (per head!)
32 heads = 2 GB
n = 32768, FP32:
[n x n] = 32768 * 32768 * 4 bytes = 4 GB (per head!)
32 heads = 128 GB <-- exceeds any GPU's memory
Even if memory were unlimited, there is a performance problem. The attention matrix S is written to device memory after step 1 and read back for the softmax in step 3; the softmax output P is written in step 3 and read again for step 4. Those round-trips to device memory are slow:
Memory Traffic in Standard Attention
=====================================
Write S to memory: n^2 * 4 bytes (after Q * K^T)
Read S from memory: n^2 * 4 bytes (for softmax)
Write P to memory: n^2 * 4 bytes (softmax output)
Read P from memory: n^2 * 4 bytes (for P * V)
Total memory traffic: 4 * n^2 * 4 bytes = 16 * n^2 bytes
At n = 4096: 16 * 16M = 256 MB of memory traffic per head
At 200 GB/s: 1.3 ms just for the attention matrix I/O (per head!)
FlashAttention eliminates this memory traffic entirely by keeping everything in registers and threadgroup memory (SRAM).
The Key Insight: Online Softmax
The reason we need the full [n x n] matrix is softmax. Softmax requires knowing the
maximum value across an entire row (for numerical stability) and the sum of exponentials
(for normalization). Both of these seem to require seeing all n values before producing
any output.
The breakthrough is online softmax: a streaming algorithm that processes values one chunk at a time, maintaining running statistics that can be corrected as new data arrives.
Standard Softmax (Two-Pass)
The textbook softmax for a vector z of length n:
Pass 1: Find maximum
m = max(z[0], z[1], ..., z[n-1])
Pass 2: Compute exponentials and sum
For i in 0..n-1:
e[i] = exp(z[i] - m)
s = sum(e[0], e[1], ..., e[n-1])
Output: softmax[i] = e[i] / s
Requires storing all n values of z to make two passes.
Online Softmax (One-Pass, Streaming)
Online softmax processes elements one at a time (or one tile at a time), maintaining a
running maximum m and a running sum s. When a new maximum is found, the previously
accumulated sum is corrected:
Initialize:
m = -infinity (running maximum)
s = 0 (running sum of exp)
For each new value z[i]:
m_new = max(m, z[i])
s = s * exp(m - m_new) + exp(z[i] - m_new)
m = m_new
After all values:
softmax[i] = exp(z[i] - m) / s
The key line: s = s * exp(m - m_new) + exp(z[i] - m_new)
When m_new > m (new max found):
exp(m - m_new) < 1, so the old sum is scaled DOWN
This corrects for the fact that previous exponentials
were computed relative to the old (smaller) maximum
When m_new = m (no new max):
exp(m - m_new) = exp(0) = 1, so old sum unchanged
Just add the new exponential
Let us trace through an example:
Online Softmax Example: z = [2.0, 5.0, 1.0, 4.0]
===================================================
Step 0: m = -inf, s = 0
Step 1: z[0] = 2.0
m_new = max(-inf, 2.0) = 2.0
s = 0 * exp(-inf - 2.0) + exp(2.0 - 2.0)
= 0 + exp(0)
= 1.0
m = 2.0
Step 2: z[1] = 5.0
m_new = max(2.0, 5.0) = 5.0 <-- New max found!
s = 1.0 * exp(2.0 - 5.0) + exp(5.0 - 5.0)
= 1.0 * exp(-3.0) + exp(0)
= 1.0 * 0.0498 + 1.0
= 1.0498
m = 5.0
Verification: at this point, s should equal exp(2-5) + exp(5-5) = 0.0498 + 1 = 1.0498 ✓
The correction factor exp(2.0 - 5.0) = 0.0498 adjusts the old sum.
Step 3: z[2] = 1.0
m_new = max(5.0, 1.0) = 5.0 <-- No new max
s = 1.0498 * exp(5.0 - 5.0) + exp(1.0 - 5.0)
= 1.0498 * 1.0 + exp(-4.0)
= 1.0498 + 0.0183
= 1.0681
m = 5.0
Step 4: z[3] = 4.0
m_new = max(5.0, 4.0) = 5.0 <-- No new max
s = 1.0681 * exp(5.0 - 5.0) + exp(4.0 - 5.0)
= 1.0681 * 1.0 + exp(-1.0)
= 1.0681 + 0.3679
= 1.4360
m = 5.0
Final softmax:
softmax[0] = exp(2.0 - 5.0) / 1.4360 = 0.0498 / 1.4360 = 0.0347
softmax[1] = exp(5.0 - 5.0) / 1.4360 = 1.0000 / 1.4360 = 0.6964
softmax[2] = exp(1.0 - 5.0) / 1.4360 = 0.0183 / 1.4360 = 0.0128
softmax[3] = exp(4.0 - 5.0) / 1.4360 = 0.3679 / 1.4360 = 0.2562
Sum = 0.0347 + 0.6964 + 0.0128 + 0.2562 = 1.0001 ≈ 1.0 ✓
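The trace above can be reproduced in a few lines of host code. Here is a minimal C++ sketch (plain CPU code, not a Metal kernel; the function name `online_softmax` is chosen here for illustration):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// One-pass (online) softmax: reads each element exactly once,
// maintaining a running maximum m and running sum s.
std::vector<double> online_softmax(const std::vector<double>& z) {
    double m = -INFINITY; // running maximum
    double s = 0.0;       // running sum of exp(z[i] - m)
    for (double zi : z) {
        double m_new = std::max(m, zi);
        // The key line: rescale the old sum, then add the new term.
        s = s * std::exp(m - m_new) + std::exp(zi - m_new);
        m = m_new;
    }
    std::vector<double> out;
    for (double zi : z) out.push_back(std::exp(zi - m) / s);
    return out;
}
```

Running it on z = [2.0, 5.0, 1.0, 4.0] reproduces the final s = 1.4360 and the softmax values from the trace.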
Online Softmax with Running Output Correction
But FlashAttention needs more than just softmax values – it needs softmax(S) * V,
where V is the value matrix. We need to maintain a running weighted sum of V that gets
corrected each time the maximum changes:
Online Attention (combining softmax + V multiply):
====================================================
Initialize:
m = -inf (running max)
s = 0 (running sum of exp)
o = 0 (running output, vector of size d_v)
For each position j (or tile of positions):
score = q . K[j] (dot product)
m_new = max(m, score)
correction = exp(m - m_new)
// Correct the old output and sum
o = o * correction
s = s * correction
// Add new contribution
weight = exp(score - m_new)
o = o + weight * V[j]
s = s + weight
m = m_new
After all positions:
output = o / s (normalize by total sum)
This is the core of FlashAttention. Let us trace through it:
Online Attention Example
=========================
q = [1, 0] (query, dimension 2 for simplicity)
K = [ 0.5, 0.3 ] V = [ 1.0, 0.0 ]
[ 0.8, -0.2 ] [ 0.0, 1.0 ]
[ 0.1, 0.7 ] [ 0.5, 0.5 ]
Scores = q . K[j]:
s0 = 1*0.5 + 0*0.3 = 0.5
s1 = 1*0.8 + 0*(-0.2) = 0.8
s2 = 1*0.1 + 0*0.7 = 0.1
(Skip scaling by sqrt(d) for clarity)
Step 0: m = -inf, s = 0, o = [0, 0]
Step 1: Process position 0 (score = 0.5)
m_new = max(-inf, 0.5) = 0.5
correction = exp(-inf - 0.5) = 0
o = [0,0] * 0 = [0, 0]
s = 0 * 0 = 0
weight = exp(0.5 - 0.5) = 1.0
o = [0,0] + 1.0 * [1.0, 0.0] = [1.0, 0.0]
s = 0 + 1.0 = 1.0
m = 0.5
Step 2: Process position 1 (score = 0.8)
m_new = max(0.5, 0.8) = 0.8 <-- New max!
correction = exp(0.5 - 0.8) = exp(-0.3) = 0.7408
o = [1.0, 0.0] * 0.7408 = [0.7408, 0.0]
s = 1.0 * 0.7408 = 0.7408
weight = exp(0.8 - 0.8) = 1.0
o = [0.7408, 0.0] + 1.0 * [0.0, 1.0] = [0.7408, 1.0]
s = 0.7408 + 1.0 = 1.7408
m = 0.8
Step 3: Process position 2 (score = 0.1)
m_new = max(0.8, 0.1) = 0.8 <-- No new max
correction = exp(0.8 - 0.8) = 1.0
o = [0.7408, 1.0] * 1.0 = [0.7408, 1.0]
s = 1.7408 * 1.0 = 1.7408
weight = exp(0.1 - 0.8) = exp(-0.7) = 0.4966
o = [0.7408, 1.0] + 0.4966 * [0.5, 0.5] = [0.7408+0.2483, 1.0+0.2483]
= [0.9891, 1.2483]
s = 1.7408 + 0.4966 = 2.2374
m = 0.8
Normalize: output = o / s = [0.9891/2.2374, 1.2483/2.2374]
= [0.4421, 0.5579]
Verification (standard method):
scores = [0.5, 0.8, 0.1]
softmax = exp([0.5, 0.8, 0.1] - 0.8) / sum
= [exp(-0.3), exp(0), exp(-0.7)] / sum
= [0.7408, 1.0, 0.4966] / 2.2374
= [0.3311, 0.4470, 0.2219]
output = 0.3311*[1,0] + 0.4470*[0,1] + 0.2219*[0.5,0.5]
= [0.3311, 0] + [0, 0.4470] + [0.1110, 0.1110]
= [0.4421, 0.5580] ✓ (matches!)
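The same trace as host code: a single-query sketch hard-coded to d_v = 2 for brevity (the function name `online_attention` is ours, not akunu's API):

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <vector>

// Online attention for one query: stream over (score, V-row) pairs,
// keeping running (m, s, o) and normalizing once at the end.
std::array<double, 2> online_attention(
        const std::vector<double>& scores,
        const std::vector<std::array<double, 2>>& V) {
    double m = -INFINITY, s = 0.0;
    std::array<double, 2> o = {0.0, 0.0};
    for (size_t j = 0; j < scores.size(); ++j) {
        double m_new = std::max(m, scores[j]);
        double correction = std::exp(m - m_new);
        o[0] *= correction; o[1] *= correction; // correct old output
        s *= correction;                        // correct old sum
        double w = std::exp(scores[j] - m_new);
        o[0] += w * V[j][0]; o[1] += w * V[j][1];
        s += w;
        m = m_new;
    }
    return {o[0] / s, o[1] / s};
}
```

With the scores [0.5, 0.8, 0.1] and the V rows from the example, this returns [0.4421, 0.5579], matching the standard-method verification.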
The Tiling Strategy
Online softmax lets us process keys/values one at a time. But processing one position at
a time would be very slow – too little parallelism. FlashAttention processes keys and
values in tiles of size Bk (typically 32 or 64).
FlashAttention Tiling
======================
Instead of materializing the full [n x n] attention matrix:
Standard:
+--------------------------------------------------+
| |
| Full [n x n] matrix in memory |
| |
+--------------------------------------------------+
Memory: O(n^2)
FlashAttention:
Process K/V in tiles, never store more than [Bq x Bk]:
+------+
| Tile | Bq x Bk Only this small tile is in memory
+------+ at a time!
Iterate over all K/V tiles:
K/V: [====Tile 0====][====Tile 1====][====Tile 2====]...
Bk cols Bk cols Bk cols
For each tile:
1. Load Bk keys and values into threadgroup memory
2. Compute Bq x Bk scores (in registers)
3. Update running max, sum, and output (online softmax)
4. Move to next tile
Memory: O(Bq * Bk) for the score tile
+ O(Bq * d) for the running output
= O(1) working memory per tile (Bq, Bk, d are constants);
only the O(n * d) inputs and outputs scale with sequence length
The Full Algorithm
Here is the complete FlashAttention algorithm for prefill (multiple query positions):
FlashAttention Algorithm (Prefill)
====================================
Inputs: Q [n x d], K [n x d], V [n x d]
Output: O [n x d]
Tile sizes: Bq (query tile), Bk (key/value tile)
For each query tile q_tile in [0, n/Bq):
Load Q_tile = Q[q_tile*Bq : (q_tile+1)*Bq] // [Bq x d]
Initialize per-row:
m[Bq] = -infinity // running max for each query in tile
s[Bq] = 0 // running sum for each query in tile
O_tile[Bq x d] = 0 // running output for each query in tile
For each KV tile kv_tile in [0, n/Bk):
// --- Skip if causally masked ---
// If ALL queries in q_tile come BEFORE all keys in kv_tile,
// then all scores are masked to -inf. Skip entirely.
if (q_tile + 1) * Bq <= kv_tile * Bk:
continue
Load K_tile = K[kv_tile*Bk : (kv_tile+1)*Bk] // [Bk x d]
Load V_tile = V[kv_tile*Bk : (kv_tile+1)*Bk] // [Bk x d]
// --- Compute score tile ---
S_tile = Q_tile * K_tile^T / sqrt(d) // [Bq x Bk]
// --- Apply causal mask within tile ---
For each (i, j) in S_tile:
if (q_tile*Bq + i) < (kv_tile*Bk + j):
S_tile[i][j] = -infinity
// --- Online softmax update ---
For each row i in [0, Bq):
m_new = max(m[i], max(S_tile[i]))
correction = exp(m[i] - m_new)
// Correct old accumulators
O_tile[i] *= correction
s[i] *= correction
// Add new contributions
For each j in [0, Bk):
weight = exp(S_tile[i][j] - m_new)
O_tile[i] += weight * V_tile[j]
s[i] += weight
m[i] = m_new
// --- Normalize ---
For each row i in [0, Bq):
O_tile[i] /= s[i]
Write O_tile to O[q_tile*Bq : (q_tile+1)*Bq]
Let us visualize how the tiles march through the attention matrix:
Tiling Visualization (n=16, Bq=4, Bk=4)
==========================================
Full attention matrix [16 x 16] with causal mask:
KV tiles: 0 1 2 3
+----+----+----+----+
Q tile 0: | XX | .. | .. | .. |     XX = computed tile
          +----+----+----+----+     .. = skipped (causal)
Q tile 1: | XX | XX | .. | .. |
          +----+----+----+----+
Q tile 2: | XX | XX | XX | .. |
          +----+----+----+----+
Q tile 3: | XX | XX | XX | XX |
          +----+----+----+----+
Processing order for Q tile 2:
1. Load Q[8:12] (4 queries)
2. Load K[0:4], V[0:4] -> compute 4x4 scores, update
3. Load K[4:8], V[4:8] -> compute 4x4 scores, update
4. Load K[8:12], V[8:12] -> compute 4x4 scores (with mask), update
5. Skip K[12:16] (causally masked) -> all queries < all keys
6. Normalize and write O[8:12]
At no point do we have more than a 4x4 score tile in memory.
Memory: O(16 * d) for O instead of O(16 * 16) for the full matrix.
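Before moving to GPU code, it is worth convincing yourself that the tiled algorithm is exact. The following CPU reference (double precision; helper names are ours, unrelated to akunu's kernels) implements both naive and tiled causal attention so they can be compared:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

using Mat = std::vector<std::vector<double>>; // row-major [rows][cols]

// Naive causal attention: per row, standard max-subtracted softmax.
Mat naive_attention(const Mat& Q, const Mat& K, const Mat& V) {
    size_t n = Q.size(), d = Q[0].size();
    Mat O(n, std::vector<double>(d, 0.0));
    for (size_t i = 0; i < n; ++i) {
        double m = -INFINITY, s = 0.0;
        std::vector<double> sc(i + 1);
        for (size_t j = 0; j <= i; ++j) { // causal: only keys j <= i
            double dot = 0;
            for (size_t k = 0; k < d; ++k) dot += Q[i][k] * K[j][k];
            sc[j] = dot / std::sqrt((double)d);
            m = std::max(m, sc[j]);
        }
        for (size_t j = 0; j <= i; ++j) { sc[j] = std::exp(sc[j] - m); s += sc[j]; }
        for (size_t j = 0; j <= i; ++j)
            for (size_t k = 0; k < d; ++k) O[i][k] += sc[j] / s * V[j][k];
    }
    return O;
}

// Tiled FlashAttention-style prefill with Bq x Bk tiles and causal skip.
Mat flash_attention(const Mat& Q, const Mat& K, const Mat& V,
                    size_t Bq, size_t Bk) {
    size_t n = Q.size(), d = Q[0].size();
    Mat O(n, std::vector<double>(d, 0.0));
    for (size_t q0 = 0; q0 < n; q0 += Bq) {
        std::vector<double> m(Bq, -INFINITY), s(Bq, 0.0);
        for (size_t k0 = 0; k0 < n; k0 += Bk) {
            if (q0 + Bq <= k0) continue; // whole KV tile causally masked
            for (size_t i = 0; i < Bq && q0 + i < n; ++i) {
                double row_max = -INFINITY;
                std::vector<double> sc(Bk, -INFINITY);
                for (size_t j = 0; j < Bk && k0 + j < n; ++j) {
                    if (q0 + i < k0 + j) continue; // masked entry
                    double dot = 0;
                    for (size_t k = 0; k < d; ++k) dot += Q[q0+i][k] * K[k0+j][k];
                    sc[j] = dot / std::sqrt((double)d);
                    row_max = std::max(row_max, sc[j]);
                }
                if (row_max == -INFINITY) continue; // fully masked row
                double m_new = std::max(m[i], row_max);
                double corr = std::exp(m[i] - m_new);
                for (size_t k = 0; k < d; ++k) O[q0+i][k] *= corr;
                s[i] *= corr;
                for (size_t j = 0; j < Bk && k0 + j < n; ++j) {
                    if (sc[j] == -INFINITY) continue;
                    double w = std::exp(sc[j] - m_new);
                    for (size_t k = 0; k < d; ++k) O[q0+i][k] += w * V[k0+j][k];
                    s[i] += w;
                }
                m[i] = m_new;
            }
        }
        for (size_t i = 0; i < Bq && q0 + i < n; ++i)
            for (size_t k = 0; k < d; ++k) O[q0+i][k] /= s[i];
    }
    return O;
}
```

Feeding in the n=6, d=2 matrices from the worked example later in this chapter, the two functions agree to machine precision, and row 0 comes out as exactly [1, 0] (query 0 only sees key 0).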
Memory Analysis: O(n) Instead of O(n^2)
Let us carefully count the memory used by FlashAttention:
FlashAttention Memory Usage
============================
Threadgroup memory (SRAM):
Q_tile: Bq * d * sizeof(half) = 32 * 128 * 2 = 8 KB
K_tile: Bk * d * sizeof(half) = 32 * 128 * 2 = 8 KB
V_tile: Bk * d * sizeof(half) = 32 * 128 * 2 = 8 KB
S_tile: Bq * Bk * sizeof(float) = 32 * 32 * 4 = 4 KB
Total SRAM: ~ 28 KB
Per-query state (registers):
m[Bq]: Bq * sizeof(float) = 32 * 4 = 128 bytes
s[Bq]: Bq * sizeof(float) = 32 * 4 = 128 bytes
O_tile: Bq * d * sizeof(float) = 32 * 128 * 4 = 16 KB
Total registers: ~ 16 KB
Device memory:
Input Q, K, V: 3 * n * d * sizeof(half) = O(n * d)
Output O: n * d * sizeof(half) = O(n * d)
NO attention matrix stored! = O(n)
Comparison for n = 32768, d = 128:
Standard: n^2 * 4 bytes = 4 GB per head
FlashAttention: n * d * 2 bytes = 8 MB per head (500x reduction!)
A Complete Numerical Example with Tiling
Let us trace FlashAttention on a small example: n=6, d=2, Bq=2, Bk=3.
Setup
======
Q = [ 1.0, 0.5 ] K = [ 0.3, 0.7 ] V = [ 1.0, 0.0 ]
[ 0.8, -0.1 ] [ 0.6, 0.2 ] [ 0.0, 1.0 ]
[ 0.2, 0.9 ] [-0.1, 0.8 ] [ 0.5, 0.5 ]
[-0.3, 0.4 ] [ 0.4, -0.3 ] [ 0.8, 0.2 ]
[ 0.7, 0.6 ] [ 0.9, 0.1 ] [ 0.3, 0.7 ]
[ 0.1, -0.5 ] [ 0.2, 0.5 ] [ 0.6, 0.4 ]
sqrt(d) = sqrt(2) = 1.414
Tile sizes: Bq = 2 (queries), Bk = 3 (keys)
Q tile 0 (queries 0-1), KV tile 0 (keys 0-2):
Q_tile = [ 1.0, 0.5 ] K_tile = [ 0.3, 0.7 ] V_tile = [ 1.0, 0.0 ]
[ 0.8, -0.1 ] [ 0.6, 0.2 ] [ 0.0, 1.0 ]
[-0.1, 0.8 ] [ 0.5, 0.5 ]
S_tile = Q_tile * K_tile^T / 1.414:
Q[0].K[0] = 1.0*0.3 + 0.5*0.7 = 0.65 /1.414 = 0.460
Q[0].K[1] = 1.0*0.6 + 0.5*0.2 = 0.70 /1.414 = 0.495
Q[0].K[2] = 1.0*(-0.1)+0.5*0.8 = 0.30 /1.414 = 0.212
Q[1].K[0] = 0.8*0.3+(-0.1)*0.7 = 0.17 /1.414 = 0.120
Q[1].K[1] = 0.8*0.6+(-0.1)*0.2 = 0.46 /1.414 = 0.325
Q[1].K[2] = 0.8*(-0.1)+(-0.1)*0.8=-0.16 /1.414 = -0.113
S_tile = [ 0.460 0.495 0.212 ]
[ 0.120 0.325 -0.113 ]
Causal mask (query 0 sees key 0 only, query 1 sees keys 0-1):
S_tile = [ 0.460 -inf -inf ]
[ 0.120 0.325 -inf ]
Online softmax update (initial m=-inf, s=0, O=0):
Row 0:
m_new = max(-inf, 0.460) = 0.460
correction = exp(-inf - 0.460) = 0
O[0] = [0,0]*0 = [0, 0]
s[0] = 0*0 = 0
weight for key 0: exp(0.460 - 0.460) = 1.0
O[0] += 1.0 * V[0] = [1.0, 0.0]
s[0] += 1.0 = 1.0
(keys 1,2 masked: weight = exp(-inf) = 0)
m[0] = 0.460
Row 1:
m_new = max(-inf, max(0.120, 0.325)) = 0.325
correction = 0
weight for key 0: exp(0.120 - 0.325) = exp(-0.205) = 0.8146
weight for key 1: exp(0.325 - 0.325) = 1.0
O[1] = 0.8146*[1,0] + 1.0*[0,1] = [0.8146, 1.0]
s[1] = 0.8146 + 1.0 = 1.8146
m[1] = 0.325
Q tile 0 (queries 0-1), KV tile 1 (keys 3-5):
K_tile = [ 0.4, -0.3 ] V_tile = [ 0.8, 0.2 ]
[ 0.9, 0.1 ] [ 0.3, 0.7 ]
[ 0.2, 0.5 ] [ 0.6, 0.4 ]
But wait -- query 0 is at position 0, and keys 3-5 are all in the future.
Query 1 is at position 1, and keys 3-5 are also in the future.
ALL entries would be masked to -inf.
We can skip this entire KV tile! (Causal skip optimization)
After processing all KV tiles for Q tile 0, normalize:
O[0] = [1.0, 0.0] / 1.0 = [1.0, 0.0]
O[1] = [0.8146, 1.0] / 1.8146 = [0.449, 0.551]
This is exactly what standard attention would produce (query 0 only sees key 0, so it copies V[0]; query 1 sees keys 0-1 with softmax weights).
The remaining Q tiles (1 and 2) would be processed similarly, each iterating over the relevant KV tiles.
Decode Variant: FlashDecoding
During decoding, we have a single query (M=1) attending to the entire KV cache. The FlashAttention algorithm simplifies because there is only one query:
FlashDecoding (single query)
=============================
q: [1 x d] (single query vector)
K_cache: [seq_len x d] (all past keys)
V_cache: [seq_len x d] (all past values)
Tiling: split KV cache into tiles of Bk positions
+-------+-------+-------+-------+-------+
| Tile0 | Tile1 | Tile2 | Tile3 | Tile4 | K/V cache
+-------+-------+-------+-------+-------+
0..31 32..63 64..95 96..127 ...
Each tile is processed by one threadgroup (or SIMD group).
All tiles can be processed IN PARALLEL since each
produces an independent (m, s, o) tuple.
Then a final reduction step merges the per-tile results.
The parallel structure is different from prefill FlashAttention, where the KV tiles must be processed sequentially (because the running softmax state is updated incrementally). In decode, each tile independently produces a partial result:
Parallel FlashDecoding
=======================
Phase 1: Each tile independently computes (m_t, s_t, o_t)
Tile 0: q vs K[0:32] -> (m_0, s_0, o_0)
Tile 1: q vs K[32:64] -> (m_1, s_1, o_1)
Tile 2: q vs K[64:96] -> (m_2, s_2, o_2)
...
Phase 2: Merge all tiles' results
merge(tile_i, tile_j):   // o_i, o_j are the UNNORMALIZED outputs
m_new = max(m_i, m_j)
corr_i = exp(m_i - m_new)
corr_j = exp(m_j - m_new)
s_new = s_i * corr_i + s_j * corr_j
o_new = o_i * corr_i + o_j * corr_j
Normalize once after the last merge: output = o / s.
(If each tile instead pre-normalizes its o_i by s_i, multiply it
back in: o_new = (o_i * s_i * corr_i + o_j * s_j * corr_j) / s_new.)
This merge is associative! Can be done in a tree reduction.
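The merge rule is easy to sanity-check on the CPU. This sketch uses the unnormalized convention (normalize once at the very end); `Partial` and `merge` are illustrative names, not akunu's API:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Per-tile partial result: running max, UNNORMALIZED sum and output.
struct Partial {
    double m, s;
    std::vector<double> o; // unnormalized output, length d
};

// Merge two partials. Associative and commutative, so tiles can be
// combined in any order (e.g. a tree reduction on the GPU).
Partial merge(const Partial& a, const Partial& b) {
    double m_new = std::max(a.m, b.m);
    double ca = std::exp(a.m - m_new), cb = std::exp(b.m - m_new);
    Partial r{m_new, a.s * ca + b.s * cb,
              std::vector<double>(a.o.size())};
    for (size_t i = 0; i < a.o.size(); ++i)
        r.o[i] = a.o[i] * ca + b.o[i] * cb;
    return r;
}
```

For a single-position "tile" j, the partial is simply (score_j, 1, V[j]). Merging the three positions from the online attention example, in any order, reproduces the output [0.4421, 0.5579].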
// FlashDecoding: Phase 1 -- per-tile computation
kernel void flash_decode_phase1(
device const half* q [[buffer(0)]], // [d]
device const half* K_cache [[buffer(1)]], // [seq_len x d]
device const half* V_cache [[buffer(2)]], // [seq_len x d]
device float* tile_max [[buffer(3)]], // [num_tiles]
device float* tile_sum [[buffer(4)]], // [num_tiles]
device float* tile_out [[buffer(5)]], // [num_tiles x d]
constant uint& seq_len [[buffer(6)]],
constant uint& d [[buffer(7)]],
uint lane [[thread_index_in_simdgroup]],
uint sgid [[simdgroup_index_in_threadgroup]],
uint tg_id [[threadgroup_position_in_grid]])
{
const uint Bk = 32;
uint k_start = tg_id * Bk;
uint k_end = min(k_start + Bk, seq_len);
// Load q into registers (small enough)
// (In practice, load into threadgroup memory cooperatively)
float m = -INFINITY;
float s = 0.0f;
float o[128] = {0}; // Assumes d <= 128; each lane only touches its
// own d/32 strided entries in the loops below
for (uint k = k_start; k < k_end; k++) {
// Compute score = q . K[k] / sqrt(d)
float score = 0.0f;
for (uint i = lane; i < d; i += 32) {
score += float(q[i]) * float(K_cache[k * d + i]);
}
score = simd_sum(score);
score /= sqrt(float(d));
// Online softmax update
float m_new = max(m, score);
float correction = exp(m - m_new);
float weight = exp(score - m_new);
s = s * correction + weight;
for (uint i = lane; i < d; i += 32) {
o[i] = o[i] * correction + weight * float(V_cache[k * d + i]);
}
m = m_new;
}
// Write per-tile results
if (lane == 0) {
tile_max[tg_id] = m;
tile_sum[tg_id] = s;
}
for (uint i = lane; i < d; i += 32) {
tile_out[tg_id * d + i] = o[i];
}
}
// FlashDecoding: Phase 2 -- merge tile results
kernel void flash_decode_phase2(
device const float* tile_max [[buffer(0)]], // [num_tiles]
device const float* tile_sum [[buffer(1)]], // [num_tiles]
device const float* tile_out [[buffer(2)]], // [num_tiles x d]
device half* output [[buffer(3)]], // [d]
constant uint& num_tiles [[buffer(4)]],
constant uint& d [[buffer(5)]],
uint lane [[thread_index_in_simdgroup]])
{
float m = -INFINITY;
float s = 0.0f;
float o[128] = {0};
for (uint t = 0; t < num_tiles; t++) {
float m_t = tile_max[t];
float s_t = tile_sum[t];
float m_new = max(m, m_t);
float corr_old = exp(m - m_new);
float corr_new = exp(m_t - m_new);
// Merge sums
s = s * corr_old + s_t * corr_new;
// Merge outputs
for (uint i = lane; i < d; i += 32) {
o[i] = o[i] * corr_old + tile_out[t * d + i] * corr_new;
}
m = m_new;
}
// Normalize and write
for (uint i = lane; i < d; i += 32) {
output[i] = half(o[i] / s);
}
}
Tree Masking for Speculative Decoding
Speculative decoding is an optimization where a smaller “draft” model proposes several tokens ahead, and the larger model verifies them all at once. The verification step requires attention with a tree mask instead of the standard causal mask.
In speculative decoding, the draft model might produce a tree of candidates:
Speculative Decoding Tree
==========================
Position: 0 1 2 3 4 5 6 7 8
Token: A B C D E F G H I
Parent: - 0 1 1 2 2 3 3 4
Tree structure:
A
|
B
/ \
C D
/| |\
E F G H
|
I
Causal mask for this tree (1 = can attend, 0 = masked):
A B C D E F G H I
A [ 1 0 0 0 0 0 0 0 0 ]
B [ 1 1 0 0 0 0 0 0 0 ]
C [ 1 1 1 0 0 0 0 0 0 ]
D [ 1 1 0 1 0 0 0 0 0 ]
E [ 1 1 1 0 1 0 0 0 0 ]
F [ 1 1 1 0 0 1 0 0 0 ]
G [ 1 1 0 1 0 0 1 0 0 ]
H [ 1 1 0 1 0 0 0 1 0 ]
I [ 1 1 1 0 1 0 0 0 1 ]
Note: D can see A and B (its ancestors) but NOT C (sibling).
E can see A, B, C (ancestors along its path) but NOT D, F, G, H.
FlashAttention handles tree masks by passing the mask as an additional input. During the tiled computation, the mask is consulted for each score tile:
Tree Mask in FlashAttention
============================
For each score tile S_tile[Bq x Bk]:
For each (i, j):
query_pos = q_tile * Bq + i
key_pos = kv_tile * Bk + j
if mask[query_pos][key_pos] == 0:
S_tile[i][j] = -infinity
The mask is typically stored as a bitmask for efficiency:
1 bit per (query, key) pair
For n=256: 256 * 256 / 8 = 8 KB
With standard causal mask, this is implicit (just compare positions).
With tree mask, we need the explicit bitmask.
GQA Handling in FlashAttention
When using Grouped Query Attention, multiple query heads share the same KV head. The FlashAttention kernel needs to route each query head to the correct KV head:
GQA in FlashAttention
======================
n_q_heads = 32, n_kv_heads = 8, group_size = 4
Query heads 0, 1, 2, 3 --> KV head 0
Query heads 4, 5, 6, 7 --> KV head 1
Query heads 8, 9, 10, 11 --> KV head 2
...
Query heads 28, 29, 30, 31 --> KV head 7
Implementation options:
Option A: One kernel launch per KV head, process 4 Q heads
+-----------+
| Q heads | Process heads 0,1,2,3 together
| 0,1,2,3 | They all read the SAME K,V tiles
| | Load K,V once, apply to 4 query sets
+-----------+
Option B: Separate launches, but the KV tile loading is shared
Each Q head kernel knows its kv_head_id = q_head / group_size
and indexes into the correct KV cache slice.
The efficient approach processes all query heads in a group simultaneously, loading each KV tile once:
// GQA FlashAttention: process one group of query heads
kernel void flash_attn_gqa(
device const half* Q [[buffer(0)]], // [n_q_heads x seq x d]
device const half* K_cache [[buffer(1)]], // [n_kv_heads x seq x d]
device const half* V_cache [[buffer(2)]], // [n_kv_heads x seq x d]
device half* O [[buffer(3)]], // [n_q_heads x seq x d]
constant uint& kv_head [[buffer(4)]], // Which KV head
constant uint& group_size [[buffer(5)]], // Typically 4
// ... other params ...
uint sgid [[simdgroup_index_in_threadgroup]],
uint tg_id [[threadgroup_position_in_grid]])
{
// Determine which query head this SIMD group handles
// Within one threadgroup, different SIMD groups handle different Q heads
uint q_head_in_group = sgid % group_size;
uint q_head = kv_head * group_size + q_head_in_group;
// K, V come from kv_head (shared by all Q heads in group)
device const half* K = K_cache + kv_head * seq_len * d;
device const half* V = V_cache + kv_head * seq_len * d;
// Q comes from this specific query head
device const half* my_Q = Q + q_head * seq_len * d;
// Now run standard FlashAttention on (my_Q, K, V)
// K and V tiles are loaded once and shared across SIMD groups
// within the threadgroup (via threadgroup memory)
// ...
}
Numerical Stability Deep Dive
The exp and division operations in softmax are numerically treacherous. Let us examine the potential pitfalls and how online softmax handles them:
Numerical Issues in Softmax
============================
Problem 1: exp overflow
exp(88) = 1.65e38 (fine -- max FP32 is 3.40e38)
exp(89) = 4.49e38 (exceeds max FP32 -> OVERFLOW to +inf!)
exp(100) = 2.69e43 (hopelessly out of range)
Solution: subtract the max before exp
If max(z) = 90 and z[i] = 85:
exp(85) = overflow
exp(85 - 90) = exp(-5) = 0.0067 (safe!)
Problem 2: exp underflow (less critical)
exp(-100) = 3.72e-44 (denormal in FP32)
exp(-104) = 0.0 (underflow to zero)
This is usually fine -- those entries get negligible weight anyway.
Online softmax and stability:
We always compute exp(z[i] - m) where m = running max.
Since m >= z[i] for all previously seen values:
z[i] - m <= 0 (always non-positive)
exp(z[i] - m) <= 1.0 (never overflows!)
The correction factor exp(m_old - m_new):
m_old <= m_new (max only increases)
m_old - m_new <= 0
exp(m_old - m_new) <= 1.0 (never overflows!)
Summary: online softmax is ALWAYS numerically safe.
Standard softmax without the max-subtraction trick can overflow.
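The overflow behavior is easy to demonstrate in FP32 on the CPU (function names are illustrative):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>

// Naive softmax denominator in FP32: overflows to +inf once any
// logit exceeds ~88.7 (the natural log of FLT_MAX).
float naive_sum_exp(const float* z, int n) {
    float s = 0.0f;
    for (int i = 0; i < n; ++i) s += std::exp(z[i]);
    return s;
}

// Max-subtracted version: every exponent is <= 0, so every exp() is
// <= 1.0 and the sum can never overflow.
float stable_sum_exp(const float* z, int n) {
    float m = -std::numeric_limits<float>::infinity();
    for (int i = 0; i < n; ++i) m = std::max(m, z[i]);
    float s = 0.0f;
    for (int i = 0; i < n; ++i) s += std::exp(z[i] - m);
    return s; // sum relative to m; softmax[i] = exp(z[i] - m) / s
}
```

With logits [100, 99, 98] the naive sum is +inf, while the stable sum is finite and yields the correct probabilities.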
Practical Performance Considerations
Tile Size Selection
The choice of Bq and Bk affects performance significantly:
Tile Size Tradeoffs
====================
Larger tiles (Bq=64, Bk=64):
+ More compute per memory load (better arithmetic intensity)
+ Better SIMD group MMA utilization
- More threadgroup memory needed
- Fewer threadgroups can run concurrently (occupancy)
- More wasted compute at sequence boundaries
Smaller tiles (Bq=16, Bk=32):
+ Less threadgroup memory
+ Better occupancy
+ Less waste at boundaries
- Lower arithmetic intensity
- More kernel overhead per tile
Typical sweet spots on Apple Silicon:
Prefill: Bq=32, Bk=32 or Bq=64, Bk=64
Decode: Bk=32 or Bk=64 (Bq=1)
Memory Bandwidth vs Compute
FlashAttention Performance Model
=================================
Prefill (n=4096, d=128, Bq=32, Bk=32):
Per Q-tile iteration (processing one KV tile):
Load K_tile: 32 * 128 * 2 = 8 KB
Load V_tile: 32 * 128 * 2 = 8 KB
Compute S: 2 * 32 * 32 * 128 = 262K FLOPs
Compute O: 2 * 32 * 32 * 128 = 262K FLOPs
Total compute: ~524K FLOPs
Total memory: ~16 KB
Arithmetic intensity: 524288 / 16384 = 32 FLOPs/byte
This is in the compute-bound regime! (Good!)
Decode (n=4096, d=128, Bk=32):
Per KV tile:
Load K_tile: 32 * 128 * 2 = 8 KB
Load V_tile: 32 * 128 * 2 = 8 KB
Compute scores: 2 * 1 * 32 * 128 = 8K FLOPs
Compute output: 2 * 1 * 32 * 128 = 8K FLOPs
Total compute: ~16K FLOPs
Total memory: ~16 KB
Arithmetic intensity: 16K / 16K = 1 FLOP/byte
Bandwidth-bound (expected for decode)
Total KV cache read: 2 * 4096 * 128 * 2 = 2 MB per head
At 200 GB/s: ~0.01 ms per head, ~0.32 ms for 32 heads
Summary
FlashAttention transforms attention from an O(n^2) memory algorithm to an O(n) memory algorithm without changing the mathematical result. The key ideas are:
- Online softmax enables streaming computation by maintaining a running maximum and sum, with correction factors applied when the maximum changes. This is numerically stable by construction.
- Tiling divides the KV sequence into blocks of size Bk. For each query tile, we iterate over KV tiles, computing a small score matrix in registers/threadgroup memory and updating the running output.
- Memory savings are dramatic: from 4 GB per head (n=32768) to about 8 MB per head. This is what makes 128K+ context lengths feasible.
- Decode variant (FlashDecoding) processes KV tiles in parallel since there is only one query, then merges results using the same online softmax correction. This maximizes GPU utilization for the bandwidth-bound decode phase.
- Tree masking extends the algorithm for speculative decoding by replacing the simple causal mask comparison with a lookup into an explicit bitmask.
- GQA integration shares KV tile loads across multiple query heads in the same group, reducing redundant memory traffic.
The combination of FlashAttention with quantized KV caches and GQA is what makes modern long-context LLM inference practical on consumer GPUs. Without these techniques, even a 4096-token context would strain memory; with them, 128K-token contexts are routine.
Quantization: Making Models Fit
Here is a problem you will hit immediately when trying to run LLMs on consumer hardware: a 7-billion parameter model in FP16 takes 14 GB of memory. An M2 MacBook Air has 8 GB of unified memory. A 70B model needs 140 GB in FP16 – more than fits on any consumer GPU. Even if you have enough memory, the model needs to stream through the memory bus during inference, and memory bandwidth is the bottleneck for token generation.
Quantization is the solution. By representing weights with fewer bits – 8, 4, 3, or even 2 bits instead of 16 – we shrink the model dramatically. A 7B model at 4-bit quantization fits in about 4 GB. The 70B model fits in 35-40 GB. And because there is less data to read from memory, inference gets faster too.
But quantization is not free. Fewer bits means less precision, which means some degradation in output quality. The art of quantization is finding the sweet spot: aggressive enough to fit in memory and hit target speeds, but gentle enough to preserve model quality.
In this chapter, we will cover the major quantization schemes you will encounter in practice, understand the math behind them, work through concrete examples of quantizing and dequantizing, and analyze the tradeoffs between size, speed, and quality.
Why Quantize? The Numbers
Let us start with the raw arithmetic that makes quantization essential:
Model Sizes at Different Precisions
=====================================
Parameters FP32 FP16 Q8_0 Q4_0 Q4_K_M Q2_K
---------- ---- ---- ---- ---- ------ ----
1.5B 6 GB 3 GB 1.6 GB 0.9 GB 1.0 GB 0.6 GB
7B 28 GB 14 GB 7.5 GB 4.0 GB 4.4 GB 2.7 GB
13B 52 GB 26 GB 13.8 GB 7.4 GB 8.0 GB 4.9 GB
34B 136 GB 68 GB 36.1 GB 19.2 GB 21.1 GB 12.7 GB
70B 280 GB 140 GB 74.4 GB 39.6 GB 43.4 GB 26.2 GB
Apple Silicon Unified Memory:
M1/M2/M3 (base): 8 GB -> Q4_0 7B fits, FP16 does not
M1/M2/M3 Pro: 18 GB -> Q4_0 13B fits
M1/M2/M3 Max: 64 GB -> Q4_0 70B fits
M2/M3 Ultra: 192 GB -> FP16 70B fits
But memory capacity is only half the story. The other half is bandwidth:
The Bandwidth Equation
=======================
During autoregressive decoding (generating one token at a time),
EVERY weight in the model is read from memory exactly once per token.
decode_speed (tokens/sec) = memory_bandwidth / model_size_in_bytes
Apple M2 Pro: ~200 GB/s bandwidth
FP16 7B: 200 / 14.0 = ~14 tokens/sec
Q8_0 7B: 200 / 7.5 = ~27 tokens/sec
Q4_0 7B: 200 / 4.0 = ~50 tokens/sec
Q4_K_M 7B: 200 / 4.4 = ~45 tokens/sec
Q2_K 7B: 200 / 2.7 = ~74 tokens/sec
This is a theoretical upper bound -- real speeds are 60-80% of this
due to attention computation, KV cache reads, and other overhead.
But the proportionality holds: half the bytes = double the speed.
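The ceiling arithmetic above is worth having as code. A sketch (decimal gigabytes throughout; function names are illustrative, and the size estimate ignores the per-format mixed-precision details):

```cpp
#include <cassert>
#include <cmath>

// Approximate model size in GB from parameter count (in billions)
// and average bits per weight.
double model_size_gb(double n_params_billion, double bits_per_weight) {
    return n_params_billion * bits_per_weight / 8.0;
}

// Roofline-style upper bound on decode speed: every weight byte is
// streamed from memory once per generated token.
double decode_speed_upper_bound(double bandwidth_gb_s,
                                double model_size_gb) {
    return bandwidth_gb_s / model_size_gb; // tokens per second
}
```

For a 7B model on a 200 GB/s machine this reproduces the table: FP16 gives a 14 tok/s ceiling, a 4 GB Q4_0 file gives 50 tok/s.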
Decode Speed vs Quantization
=============================
Speed Model: 7B on M2 Pro (200 GB/s)
(tok/s)
80 | * Q2_K
|
70 |
|
60 |
| * Q4_0
50 | * Q4_K_M
|
40 |
|
30 | * Q8_0
|
20 |
| * FP16
10 |
|
0 +-----+------+------+------+------+------+----->
16 8 6 4 3 2 bits/weight
The Fundamentals of Quantization
At its core, quantization maps a continuous range of floating-point values to a discrete set of integer values. The simplest form is uniform affine quantization:
Uniform Quantization
=====================
Given: a set of FP16 weights in range [w_min, w_max]
Goal: represent each weight using b bits (0 to 2^b - 1)
Quantize:
scale = (w_max - w_min) / (2^b - 1)
zero_point = round(-w_min / scale)
q[i] = round(w[i] / scale) + zero_point
Dequantize:
w[i] = scale * (q[i] - zero_point)
Example: b=4 bits (range 0..15)
Weights: [-0.8, 0.3, -0.1, 0.5, -0.6, 0.7]
w_min = -0.8, w_max = 0.7
scale = (0.7 - (-0.8)) / 15 = 1.5 / 15 = 0.1
zero_point = round(0.8 / 0.1) = 8
Quantize:
q[-0.8] = round(-0.8/0.1) + 8 = -8 + 8 = 0
q[0.3] = round(0.3/0.1) + 8 = 3 + 8 = 11
q[-0.1] = round(-0.1/0.1) + 8 = -1 + 8 = 7
q[0.5] = round(0.5/0.1) + 8 = 5 + 8 = 13
q[-0.6] = round(-0.6/0.1) + 8 = -6 + 8 = 2
q[0.7] = round(0.7/0.1) + 8 = 7 + 8 = 15
Stored: [0, 11, 7, 13, 2, 15] (each is 4 bits)
Dequantize:
w[0] = 0.1 * (0 - 8) = -0.8 (exact)
w[11] = 0.1 * (11 - 8) = 0.3 (exact)
w[7] = 0.1 * (7 - 8) = -0.1 (exact)
w[13] = 0.1 * (13 - 8) = 0.5 (exact)
w[2] = 0.1 * (2 - 8) = -0.6 (exact)
w[15] = 0.1 * (15 - 8) = 0.7 (exact)
In this lucky example, all values hit exactly.
In reality, most values get rounded, introducing error.
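The scheme above as a host-side sketch (double precision; type and function names are illustrative):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Uniform affine quantization: one (scale, zero_point) pair for the
// whole vector, b bits per value.
struct QParams { double scale; int zero_point; };

QParams compute_qparams(const std::vector<double>& w, int bits) {
    double lo = w[0], hi = w[0];
    for (double x : w) { lo = std::min(lo, x); hi = std::max(hi, x); }
    double scale = (hi - lo) / ((1 << bits) - 1);
    return {scale, (int)std::lround(-lo / scale)};
}

int quantize(double w, const QParams& p) {
    return (int)std::lround(w / p.scale) + p.zero_point;
}

double dequantize(int q, const QParams& p) {
    return p.scale * (q - p.zero_point);
}
```

Running this on the example weights [-0.8, 0.3, -0.1, 0.5, -0.6, 0.7] with b=4 gives scale 0.1 and zero point 8, and every value round-trips exactly, as in the text.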
But this global quantization (one scale and zero-point for the entire weight matrix) is too crude for neural networks. The weight distribution varies significantly across different parts of the matrix. Modern quantization uses block quantization: the weights are divided into small blocks, each with its own scale (and possibly zero-point).
Block Quantization: GGUF Format (Q4_0)
The GGUF format (used by llama.cpp and many Metal inference engines) uses block quantization. The simplest variant, Q4_0, works as follows:
Q4_0 Block Structure
=====================
Block size: 32 weights
Storage per block:
- 1 x FP16 scale factor (2 bytes)
- 32 x 4-bit signed integers packed into 16 bytes
Total: 18 bytes for 32 weights = 4.5 bits/weight
Quantization (symmetric, no zero point):
For each block of 32 weights:
abs_max = max(|w[0]|, |w[1]|, ..., |w[31]|)
scale = abs_max / 8 (maps to range [-8, 7])
For each weight w[i]:
q[i] = clamp(round(w[i] / scale), -8, 7)
Store as (q[i] + 8), giving range [0, 15] (4 bits)
Dequantization:
w[i] = scale * (stored[i] - 8)
Memory layout (18 bytes per block):
+--------+--------+--------+---+--------+
| scale | byte 0 | byte 1 |...| byte15 |
| (FP16) | w0|w1 | w2|w3 | |w30|w31 |
+--------+--------+--------+---+--------+
2B 1B 1B 1B
(2 nibbles packed per byte)
Worked Example: Quantizing a Block with Q4_0
Example: Quantize 32 weights with Q4_0
========================================
Weights (first 8 of 32 shown):
[ 0.23, -0.41, 0.67, -0.12, 0.55, -0.89, 0.34, 0.08, ... ]
Step 1: Find absolute maximum
abs_max = max(|all 32 weights|) = 0.89 (from -0.89)
Step 2: Compute scale
scale = 0.89 / 8 = 0.11125
Step 3: Quantize each weight
q[0] = round(0.23 / 0.11125) = round(2.067) = 2 -> stored: 2+8 = 10
q[1] = round(-0.41 / 0.11125) = round(-3.685) = -4 -> stored: -4+8 = 4
q[2] = round(0.67 / 0.11125) = round(6.022) = 6 -> stored: 6+8 = 14
q[3] = round(-0.12 / 0.11125) = round(-1.079) = -1 -> stored: -1+8 = 7
q[4] = round(0.55 / 0.11125) = round(4.944) = 5 -> stored: 5+8 = 13
q[5] = round(-0.89 / 0.11125) = round(-8.000) = -8 -> stored: -8+8 = 0
q[6] = round(0.34 / 0.11125) = round(3.056) = 3 -> stored: 3+8 = 11
q[7] = round(0.08 / 0.11125) = round(0.719) = 1 -> stored: 1+8 = 9
Step 4: Pack into bytes (two 4-bit values per byte)
byte[0] = (stored[1] << 4) | stored[0] = (4 << 4) | 10 = 0x4A
byte[1] = (stored[3] << 4) | stored[2] = (7 << 4) | 14 = 0x7E
byte[2] = (stored[5] << 4) | stored[4] = (0 << 4) | 13 = 0x0D
byte[3] = (stored[7] << 4) | stored[6] = (9 << 4) | 11 = 0x9B
Step 5: Dequantize (to verify)
dq[0] = 0.11125 * (10 - 8) = 0.11125 * 2 = 0.2225 (was 0.23, error: 0.008)
dq[1] = 0.11125 * (4 - 8) = 0.11125 * -4 = -0.4450 (was -0.41, error: 0.035)
dq[2] = 0.11125 * (14 - 8) = 0.11125 * 6 = 0.6675 (was 0.67, error: 0.003)
dq[3] = 0.11125 * (7 - 8) = 0.11125 * -1 = -0.1113 (was -0.12, error: 0.009)
dq[4] = 0.11125 * (13 - 8) = 0.11125 * 5 = 0.5563 (was 0.55, error: 0.006)
dq[5] = 0.11125 * (0 - 8) = 0.11125 * -8 = -0.8900 (was -0.89, error: 0.000)
dq[6] = 0.11125 * (11 - 8) = 0.11125 * 3 = 0.3338 (was 0.34, error: 0.006)
dq[7] = 0.11125 * (9 - 8) = 0.11125 * 1 = 0.1113 (was 0.08, error: 0.031)
Average absolute error: ~0.012
Relative error: ~3-4% on average
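Here is the whole Q4_0 round trip as host code (double-precision scale for clarity; the real format stores it as FP16, and real blocks hold 32 weights rather than the 8 shown in the worked example):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Q4_0-style block quantization: symmetric, one scale per block,
// 4-bit values stored with a +8 offset and packed two per byte.
struct Q4Block {
    double scale;
    std::vector<uint8_t> packed; // block_size/2 bytes
};

Q4Block quantize_q4_0(const std::vector<double>& w) {
    double abs_max = 0;
    for (double x : w) abs_max = std::max(abs_max, std::fabs(x));
    Q4Block b{abs_max / 8.0, {}};
    std::vector<int> stored;
    for (double x : w) {
        int q = (int)std::lround(x / b.scale);
        q = std::max(-8, std::min(7, q)); // clamp to [-8, 7]
        stored.push_back(q + 8);          // store as 0..15
    }
    // Pack two nibbles per byte: even index -> low, odd -> high.
    for (size_t i = 0; i < stored.size(); i += 2)
        b.packed.push_back((uint8_t)((stored[i + 1] << 4) | stored[i]));
    return b;
}

double dequant_q4_0(const Q4Block& b, size_t i) {
    uint8_t byte = b.packed[i / 2];
    int q = (i % 2) ? (byte >> 4) : (byte & 0x0F);
    return b.scale * (q - 8);
}
```

On the eight example weights this reproduces the worked numbers exactly: scale 0.11125, packed bytes 0x4A, 0x7E, 0x0D, 0x9B, and a lossless round trip for -0.89 (it landed exactly on the -8 code).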
Bit Packing Details
Two 4-bit values are packed into each byte. This is a critical operation for both quantization (packing) and dequantization (unpacking):
Bit Packing: Two Nibbles Per Byte
===================================
Stored values: a=10 (0b1010), b=4 (0b0100)
Pack: byte = (b << 4) | a
byte = (0b0100 << 4) | 0b1010
byte = 0b01001010
byte = 0x4A
Unpack:
a = byte & 0x0F = 0x4A & 0x0F = 0x0A = 10
b = (byte >> 4) & 0x0F = (0x4A >> 4) & 0x0F = 0x04 = 4
Visual:
+---+---+---+---+---+---+---+---+
| b3| b2| b1| b0| a3| a2| a1| a0| <-- one byte
+---+---+---+---+---+---+---+---+
| high nibble | low nibble |
| value b | value a |
In Metal shader code:
// Unpacking 4-bit values from bytes
uchar packed_byte = block.packed[j];
// Extract two 4-bit values
int8_t val_low = (packed_byte & 0x0F) - 8; // Low nibble, subtract offset
int8_t val_high = (packed_byte >> 4) - 8; // High nibble, subtract offset
// Dequantize
float w0 = float(block.scale) * float(val_low);
float w1 = float(block.scale) * float(val_high);
K-Quant Family: Super-Blocks of 256
The basic Q4_0 format uses a single scale per 32 weights. The K-quant family (Q2_K, Q3_K, Q4_K, Q5_K, Q6_K) introduced by llama.cpp uses a two-level hierarchy: super-blocks of 256 weights, each containing 8 sub-blocks of 32 weights.
K-Quant Super-Block Structure (Q4_K)
======================================
Super-block: 256 weights
+---------------------------------------------------+
| FP16 scale_of_scales | FP16 scale_of_mins | 4 bytes
+---------------------------------------------------+
| 8 x 6-bit sub-block scales (packed) | 6 bytes
+---------------------------------------------------+
| 8 x 6-bit sub-block mins (packed) | 6 bytes
+---------------------------------------------------+
| 256 x 4-bit quantized values (packed) | 128 bytes
+---------------------------------------------------+
Total: 144 bytes for 256 weights = 4.5 bits/weight
Each sub-block of 32 weights has its own:
- 6-bit scale (quantized, relative to super-block scale_of_scales)
- 6-bit minimum (quantized, relative to super-block scale_of_mins)
Dequantization for weight i in sub-block b:
sub_scale = scale_of_scales * sub_scales_q6[b]
sub_min = scale_of_mins * sub_mins_q6[b]
w[i] = sub_scale * q4[i] - sub_min
Why the two-level structure? It improves quantization accuracy by allowing each sub-block of 32 weights to have a different range, while keeping the overhead (scale and min metadata) small by quantizing the per-sub-block parameters themselves.
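The two-level dequantization path can be sketched as follows. Note the layout here is deliberately simplified: the real Q4_K format packs the 6-bit sub-block scales and mins into shared bytes and stores the two super-block scales as FP16, so the struct below is illustrative only:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Q4_K-style super-block, with the bit packing unrolled for clarity.
struct SuperBlock {
    double scale_of_scales;  // FP16 in the real format
    double scale_of_mins;    // FP16 in the real format
    uint8_t sub_scales[8];   // 6-bit values, 0..63
    uint8_t sub_mins[8];     // 6-bit values, 0..63
    uint8_t q[256];          // 4-bit values, 0..15
};

// Dequantize weight i (0..255) of a super-block: the sub-block's own
// scale and min are themselves quantized relative to the super-block.
double dequant_q4_k(const SuperBlock& sb, int i) {
    int b = i / 32; // which of the 8 sub-blocks
    double sub_scale = sb.scale_of_scales * sb.sub_scales[b];
    double sub_min = sb.scale_of_mins * sb.sub_mins[b];
    return sub_scale * sb.q[i] - sub_min;
}
```

Note the asymmetric form `scale * q - min`: unlike Q4_0's symmetric `scale * (q - 8)`, each sub-block gets an independent offset, which is part of why Q4_K reaches lower perplexity at the same bits per weight.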
Q4_0 vs Q4_K: Quantization Granularity
========================================
Q4_0: One FP16 scale per 32 weights
+---------+---------+---------+---------+
| scale_0 | scale_1 | scale_2 | scale_3 | ... (256 weights = 8 blocks)
| 32 wts | 32 wts | 32 wts | 32 wts |
+---------+---------+---------+---------+
Each scale is independent FP16 (full precision).
Overhead: 2 bytes / 32 weights = 0.5 bits/weight
Total: 4.0 + 0.5 = 4.5 bits/weight
Q4_K: Two-level hierarchy for 256 weights
+-------------------------------------------+
| super-block: scale_of_scales, scale_of_mins| 4 bytes for 256 weights
| sub-block 0: 6-bit scale, 6-bit min |
| sub-block 1: 6-bit scale, 6-bit min |
| ... |
| sub-block 7: 6-bit scale, 6-bit min |
+-------------------------------------------+
Sub-block parameters: 12 bytes / 256 = 0.375 bits/weight
Super-block parameters: 4 bytes / 256 = 0.125 bits/weight
Total: 4.0 + 0.375 + 0.125 = 4.5 bits/weight
Same bits/weight, but Q4_K has finer-grained adaptation
and asymmetric ranges (scale + min instead of just scale).
Result: measurably lower perplexity.
The K-Quant Zoo
Here is the full family of K-quant formats:
K-Quant Formats Overview
=========================
Format Bits/wt Quant Sub-block params Quality
------ ------- ----- ---------------- -------
Q2_K 2.56 2-bit 4-bit scale+min Poor (emergency use)
Q3_K_S 3.44 3-bit 6-bit scale Acceptable
Q3_K_M 3.91 3-bit 6-bit scale+min Fair
Q3_K_L 4.28 3-bit 6-bit scale+min Fair+
Q4_K_S 4.50 4-bit 6-bit scale+min Good
Q4_K_M 4.84 4-bit 6-bit scale+min Good+
Q5_K_S 5.50 5-bit 6-bit scale+min Very Good
Q5_K_M 5.69 5-bit 6-bit scale+min Very Good+
Q6_K 6.56 6-bit 8-bit scale Excellent
The S/M/L suffixes mean Small/Medium/Large and refer to which
layers get higher precision. "M" uses higher precision for the
attention layers and output layers (most sensitive to quantization).
The _M variants are particularly clever: they use mixed precision, with more bits for
the layers that matter most:
Q4_K_M: Mixed Precision by Layer
==================================
Layer Type Quantization Why
---------- ------------ ---
Attention Q,K,V,O Q4_K Important for quality
FFN gate, up Q4_K Standard
FFN down Q6_K Output projection, sensitive
Output (vocab) layer Q6_K Final prediction, very sensitive
Embedding Q4_K Large but less sensitive
The "M" variant is ~8% larger than "S" but measurably better.
Per-Channel and Per-Group Quantization (MLX Format)
The MLX framework (Apple’s machine learning framework) uses a different quantization approach: per-group affine quantization with configurable group sizes.
MLX Quantization
=================
Group size: typically 64 or 128
Per group: one FP16 scale and one FP16 bias (zero-point)
Quantize (per group of G weights):
w_min = min(weights in group)
w_max = max(weights in group)
scale = (w_max - w_min) / (2^bits - 1)
bias = w_min
q[i] = round((w[i] - bias) / scale)
q[i] = clamp(q[i], 0, 2^bits - 1)
Dequantize:
w[i] = scale * q[i] + bias
Example with group_size=64, bits=4:
+------------------------------------+
| scale (FP16) | bias (FP16) | 64x4b |
| 2 bytes | 2 bytes | 32 B |
+------------------------------------+
Total: 36 bytes for 64 weights = 4.5 bits/weight
With group_size=128:
+-------------------------------------+
| scale (FP16) | bias (FP16) | 128x4b |
| 2 bytes | 2 bytes | 64 B |
+-------------------------------------+
Total: 68 bytes for 128 weights = 4.25 bits/weight
The key difference from GGUF quantization: MLX uses affine quantization (scale + bias) rather than symmetric quantization (scale only). This better handles weight distributions that are not centered around zero:
Symmetric vs Affine Quantization
==================================
Symmetric (Q4_0):
Maps [-abs_max, +abs_max] to [-8, +7]
scale = abs_max / 8
w = scale * (q - 8)
Problem: if weights are [0.1, 0.3, 0.5, 0.7, 0.9]
abs_max = 0.9, scale = 0.1125
Half the quantization levels (-8 to -1) are wasted!
Wasted range
<---------------------------->
-0.900 -0.675 -0.450 -0.225  0.000  0.1125 ... 0.7875
  |      |      |      |       |      |           |
 q=0    q=2    q=4    q=6     q=8    q=9    ...  q=15
 NO WEIGHTS HERE                   ALL WEIGHTS HERE
(Note that 0.9 itself is not even representable: the top level
 q=15 gives 7 * 0.1125 = 0.7875, so 0.9 clamps with error 0.1125.)
Affine (MLX):
Maps [w_min, w_max] to [0, 15]
scale = (0.9 - 0.1) / 15 = 0.0533
bias = 0.1
w = scale * q + bias
0.1 0.15 0.21 0.26 ... 0.84 0.9
| | | | | |
q=0 q=1 q=2 q=3 ... q=14 q=15
FULL RANGE UTILIZED!
Affine quantization uses all levels for the actual data range.
Better for asymmetric distributions (common in practice).
MLX Quantization Code Example
// MLX-style dequantization in Metal
kernel void dequantize_mlx_q4(
    device const uint32_t* packed_weights [[buffer(0)]], // Packed 4-bit
    device const half*     scales         [[buffer(1)]], // Per-group scales
    device const half*     biases         [[buffer(2)]], // Per-group biases
    device half*           output         [[buffer(3)]], // Dequantized
    constant uint&         group_size     [[buffer(4)]],
    uint tid [[thread_position_in_grid]])
{
    uint group_id = tid / group_size;
    half scale = scales[group_id];
    half bias  = biases[group_id];
    // Each uint32 holds 8 x 4-bit values
    uint word_idx   = tid / 8;
    uint bit_offset = (tid % 8) * 4;
    uint32_t word = packed_weights[word_idx];
    uint8_t  q    = (word >> bit_offset) & 0xF;
    output[tid] = scale * half(q) + bias;
}
Impact on Quality: Perplexity Analysis
Perplexity is the standard metric for measuring quantization quality. Lower is better – it measures how “surprised” the model is by test data. A perplexity increase of more than 1-2% is generally noticeable in output quality.
Perplexity Comparison (Llama 2 7B, WikiText-2)
================================================
Format Bits/wt Perplexity Delta Quality Assessment
------ ------- ---------- ----- ------------------
FP16 16.00 5.796 --- Reference (baseline)
Q8_0 8.50 5.799 +0.003 Indistinguishable
Q6_K 6.56 5.804 +0.008 Excellent
Q5_K_M 5.69 5.812 +0.016 Very good
Q5_K_S 5.50 5.819 +0.023 Very good
Q4_K_M 4.84 5.882 +0.086 Good
Q4_K_S 4.50 5.912 +0.116 Good
Q4_0 4.50 5.946 +0.150 Acceptable
Q3_K_M 3.91 6.145 +0.349 Fair (noticeable)
Q3_K_S 3.44 6.351 +0.555 Degraded
Q2_K 2.56 6.981 +1.185 Poor
Rule of thumb:
- Q4_K_M and above: quality loss rarely noticeable in practice
- Q3_K: quality loss sometimes noticeable, especially on hard tasks
- Q2_K: clear quality degradation, only for extreme memory constraints
Perplexity vs Bits per Weight
==============================
Perplexity
|
7.0 |* Q2_K
|
6.5 | * Q3_K_S
| * Q3_K_M
6.0 | * * Q4_0, Q4_K_S
| * * Q4_K_M, Q5_K_S
5.8 | * * Q5_K_M, Q6_K
| * Q8_0
5.6 | * FP16
|
5.4 +---+---+---+---+---+---+---+---+---+---+--->
2 3 4 5 6 7 8 ... 16 bits/weight
The curve has a "knee" around 4-5 bits/weight.
Below 4 bits, quality drops rapidly.
Above 5 bits, returns diminish.
4-5 bits is the sweet spot for most applications.
The Bandwidth Equation: Predicting Decode Speed
We mentioned this earlier, but let us formalize it. During autoregressive decoding:
Decode Speed Model
====================
Given:
P = number of parameters
b = bits per weight (effective, including quantization overhead)
BW = memory bandwidth (bytes/sec)
overhead = non-matmul time (attention, normalization, etc.)
model_bytes = P * b / 8
time_per_token = model_bytes / BW + overhead
tokens_per_second = 1 / time_per_token
Example: Llama 2 7B on M2 Pro
P = 6.74 billion (actual parameter count)
BW = 200 GB/s
Q4_K_M (b = 4.84):
model_bytes = 6.74e9 * 4.84 / 8 = 4.077 GB
matmul_time = 4.077 / 200 = 0.0204 sec = 20.4 ms
overhead ≈ 5 ms (attention + KV cache + norms)
time_per_token ≈ 25.4 ms
speed ≈ 39 tokens/sec
Observed: ~35-40 tokens/sec (matches!)
This equation is remarkably accurate because GEMV (the matmul during decode) is almost perfectly bandwidth-bound. The only thing that breaks the model is when the KV cache becomes very large (long sequences) and the attention computation starts contributing significantly.
When Does Attention Start to Matter?
======================================
KV cache read per step (32 layers, GQA, 8 KV heads, d_k=128, FP16):
kv_bytes = n_layers * 2 * n_kv_heads * seq_len * d_k * 2
         = 32 * 2 * 8 * seq_len * 128 * 2 = 131072 * seq_len bytes
Total per token = model_bytes + kv_bytes
seq_len   kv_bytes  model(Q4_K_M)  kv/model  Impact
--------  --------  -------------  --------  ------
512       64 MB     4.08 GB        1.6%      Negligible
2048      256 MB    4.08 GB        6.3%      Small
8192      1 GB      4.08 GB        25%       Noticeable (~25% slower)
32768     4 GB      4.08 GB        ~100%     Large (~2x slower)
131072    16 GB     4.08 GB        ~400%     Dominant
KV cache reads are minor for short conversations (< 2K tokens) but grow
linearly with context; beyond ~32K tokens they rival the model weights
themselves -- one motivation for quantizing the KV cache.
Comparison Table
Here is a comprehensive comparison of all the major quantization formats:
+----------+--------+--------+----------+---------+--------+-------+
| Format | Bits/ | Block | Params | Asym- | Mixed | Used |
| | weight | size | per block| metric? | prec? | by |
+----------+--------+--------+----------+---------+--------+-------+
| FP16 | 16.00 | N/A | N/A | N/A | No | Base |
| Q8_0 | 8.50 | 32 | 1 scale | No | No | GGUF |
| Q6_K | 6.56 | 256 | 8-bit sc | No | No | GGUF |
| Q5_K_M | 5.69 | 256 | 6-bit s+m| Yes | Yes | GGUF |
| Q5_K_S | 5.50 | 256 | 6-bit s+m| Yes | No | GGUF |
| Q4_K_M | 4.84 | 256 | 6-bit s+m| Yes | Yes | GGUF |
| Q4_K_S | 4.50 | 256 | 6-bit s+m| Yes | No | GGUF |
| Q4_0 | 4.50 | 32 | 1 scale | No | No | GGUF |
| Q3_K_M | 3.91 | 256 | 6-bit s+m| Yes | Yes | GGUF |
| Q3_K_S | 3.44 | 256 | 6-bit sc | No | No | GGUF |
| Q2_K | 2.56 | 256 | 4-bit s+m| Yes | No | GGUF |
| MLX-4 | 4.25 | 64-128| scale+bias| Yes | No | MLX |
| MLX-8 | 8.25 | 64-128| scale+bias| Yes | No | MLX |
| MLX-2 | 2.25 | 64-128| scale+bias| Yes | No | MLX |
+----------+--------+--------+----------+---------+--------+-------+
Legend: sc = scale only, s+m = scale + minimum, Asym = asymmetric range
Mixed prec = different quantization levels for different layers
Worked Example: Full Dequantization Pipeline
Let us walk through dequantizing a complete Q4_K block on the GPU, showing exactly what happens in a Metal shader:
Q4_K Dequantization Pipeline
==============================
Input: One super-block of 256 quantized weights
Step 1: Read super-block header
d = FP16 scale_of_scales (e.g., 0.0156)
dmin = FP16 scale_of_mins (e.g., 0.0078)
Step 2: Unpack sub-block parameters (6-bit values, packed)
The 8 sub-block scales are packed as 6-bit values in 6 bytes:
+--------+--------+--------+--------+--------+--------+
| byte 0 | byte 1 | byte 2 | byte 3 | byte 4 | byte 5 |
+--------+--------+--------+--------+--------+--------+
sub_scale[0] = byte[0] & 0x3F = 23
sub_scale[1] = ((byte[0]>>6) | (byte[1]<<2)) & 0x3F = 18
sub_scale[2] = ((byte[1]>>4) | (byte[2]<<4)) & 0x3F = 31
...
Similarly for sub_mins.
Step 3: Compute actual sub-block scale and min
For sub-block b:
actual_scale[b] = d * sub_scale[b]
actual_min[b] = dmin * sub_min[b]
actual_scale[0] = 0.0156 * 23 = 0.3588
actual_min[0] = 0.0078 * 15 = 0.1170 (example)
Step 4: Dequantize weights in sub-block
For weight i in sub-block b:
q = unpack_4bit(packed_data, b*32 + i) (0..15)
w = actual_scale[b] * q - actual_min[b]
q[0] = 7: w = 0.3588 * 7 - 0.1170 = 2.3946
q[1] = 12: w = 0.3588 * 12 - 0.1170 = 4.1886
q[2] = 3: w = 0.3588 * 3 - 0.1170 = 0.9594
q[3] = 0: w = 0.3588 * 0 - 0.1170 = -0.1170 (min value)
q[4] = 15: w = 0.3588 * 15 - 0.1170 = 5.2650 (max value)
...
And here is how this looks in Metal shader code:
// Q4_K super-block structure
struct block_q4_K {
    half  d;           // scale of scales
    half  dmin;        // scale of mins
    uchar scales[12];  // 8 x 6-bit scales + 8 x 6-bit mins, packed
    uchar qs[128];     // 256 x 4-bit quantized values
};

// Dequantize one Q4_K block
inline void dequantize_q4_K(
    device const block_q4_K& block,
    uint sub_block,       // 0..7
    uint index_in_sub,    // 0..31
    thread float& result)
{
    // Step 1: Read super-block scales
    float d    = float(block.d);
    float dmin = float(block.dmin);

    // Step 2: Unpack 6-bit sub-block scale and min
    // (Simplified -- actual packing is more complex)
    uint8_t raw_scale, raw_min;
    if (sub_block < 4) {
        raw_scale = block.scales[sub_block] & 0x3F;
        raw_min   = block.scales[sub_block + 4] & 0x3F;
    } else {
        // Higher sub-blocks borrow their high 2 bits from earlier bytes
        raw_scale = (block.scales[sub_block + 4] & 0xF) |
                    ((block.scales[sub_block - 4] >> 6) << 4);
        raw_min   = (block.scales[sub_block + 4] >> 4) |
                    ((block.scales[sub_block] >> 6) << 4);
    }
    float scale = d * float(raw_scale);
    float min   = dmin * float(raw_min);

    // Step 3: Unpack 4-bit weight value
    uint  byte_idx = (sub_block * 32 + index_in_sub) / 2;
    uchar packed   = block.qs[byte_idx];
    uint8_t q = (index_in_sub % 2 == 0)
              ? (packed & 0x0F)          // Low nibble
              : ((packed >> 4) & 0x0F);  // High nibble

    // Step 4: Dequantize
    result = scale * float(q) - min;
}
Quantizing the KV Cache
So far we have discussed quantizing the model weights. But there is another large memory consumer during inference: the KV cache. For long context lengths, the KV cache can rival or exceed the model size.
KV Cache Quantization
======================
Standard KV cache (FP16):
Per layer: 2 * n_kv_heads * seq_len * d_k * 2 bytes
32 layers, 8 KV heads, d_k=128, seq_len=32768:
= 32 * 2 * 8 * 32768 * 128 * 2 = 4 GB
Q8 KV cache:
Same but 1 byte per element + small scale overhead
≈ 2 GB (50% reduction)
Q4 KV cache:
≈ 1 GB (75% reduction)
KV cache quantization is trickier than weight quantization because:
1. Values are computed dynamically (not known ahead of time)
2. Must quantize on-the-fly as new K/V vectors are computed
3. Range can change as new tokens arrive
4. Quality impact is harder to predict
Approach: per-head, per-position quantization
For each new (key, value) vector being added to the cache:
1. Compute the vector in FP16
2. Quantize to Q8 or Q4 with per-vector scale
3. Store quantized vector in cache
4. Dequantize on-the-fly during attention computation
Practical Considerations
Choosing the Right Quantization
Here is a practical decision flowchart:
Quantization Selection Flowchart
==================================
How much memory do you have?
|
+-- Very limited (8 GB): Use Q4_K_M for 7B, Q2_K for 13B
|
+-- Moderate (16-18 GB): Use Q4_K_M for 13B, Q6_K for 7B
|
+-- Generous (32-64 GB): Use Q6_K or Q8_0 for 13B-34B
|
+-- Abundant (96+ GB): Consider FP16 for best quality
What is your speed target?
|
+-- Maximum speed: Use lowest quantization that maintains acceptable quality
| (usually Q4_K_M or Q4_K_S)
|
+-- Quality priority: Use highest quantization that meets speed requirements
| (Q6_K or Q8_0)
|
+-- Balanced: Q4_K_M is almost always the right answer
Task sensitivity?
|
+-- Creative writing, coding: Q4_K_M usually fine
|
+-- Math, reasoning: Consider Q5_K_M or Q6_K
|
+-- Simple Q&A, summarization: Q3_K_M can work
Quantization and GEMV Performance
Remember from Chapter 14: GEMV is bandwidth-bound. Quantization directly reduces the bytes read, so the speedup is nearly linear with compression ratio:
GEMV Performance with Quantization
====================================
Single matmul: [1, 4096] * [4096, 11008]
M2 Pro, 200 GB/s
Format Weight Size Read Time Speedup
------ ----------- --------- -------
FP16 90.1 MB 0.451 ms 1.0x
Q8_0 47.8 MB 0.239 ms 1.9x
Q6_K 37.0 MB 0.185 ms 2.4x
Q4_K_M 24.5 MB 0.123 ms 3.7x
Q4_0 23.6 MB 0.118 ms 3.8x
Q2_K 14.4 MB 0.072 ms 6.3x
Note: Dequantization adds ~5-10% compute overhead.
Net speedup is slightly less than the raw bandwidth reduction,
but still very close.
GEMM (Prefill) with Quantization
For prefill (GEMM), quantization does not directly speed things up because GEMM is compute-bound. However, it does reduce memory footprint, which matters for:
- Fitting larger models in memory
- Keeping more of the model in cache
- Reducing memory pressure from concurrent operations
Prefill Performance: Less Clear-Cut
=====================================
Matmul: [512, 4096] * [4096, 11008]
M2 Pro
Format Approach Time
------ -------- ----
FP16 Native FP16 MMA ~2.0 ms
Q4_K_M Dequant to FP16 + MMA ~2.5 ms (slower!)
Q4_K_M Optimized mixed-precision ~2.2 ms (close to FP16)
During prefill, quantization can actually be SLOWER because:
1. Extra dequantization compute
2. MMA hardware is optimized for FP16/BF16, not mixed precision
3. GEMM is compute-bound, not memory-bound
But for most LLM use cases, decode time dominates total latency,
so optimizing decode (where quantization helps enormously) is
more important than optimizing prefill.
Summary
Quantization is what makes LLM inference practical on consumer hardware. The key points:
- Why quantize: A 7B FP16 model needs 14 GB; Q4 needs ~4 GB. Decode speed is proportional to 1/model_size because GEMV is bandwidth-bound.
- Block quantization (Q4_0): Groups of 32 weights share one FP16 scale. 4 bits per weight + scale overhead = 4.5 bits/weight effective. Simple and fast.
- K-quant family: Super-blocks of 256 with 8 sub-blocks of 32. Two-level scale hierarchy. Mixed precision variants (_M) use more bits for sensitive layers. Q4_K_M is the most popular choice.
- Per-group affine (MLX): Groups of 64-128 with per-group scale and bias. Asymmetric quantization handles non-centered distributions better.
- Bit packing: Two 4-bit values per byte. Unpacking is a shift and mask operation, done on-the-fly during dequantization in the GPU shader.
- Quality impact: The perplexity curve has a knee at ~4-5 bits/weight. Q4_K_M is the sweet spot for most applications. Below 3 bits, quality degrades noticeably.
- The bandwidth equation: decode_speed ≈ bandwidth / model_bytes. This simple formula predicts real-world performance with surprising accuracy.
- Dequantize on the fly: During GEMV, weights are dequantized in registers as they are loaded. The compute cost of dequantization is negligible compared to the bandwidth savings.
With quantization, a $1200 MacBook Air with 16 GB of RAM can run a 7B parameter model at 40+ tokens per second – fast enough for interactive use. That is the power of trading a tiny bit of precision for an enormous reduction in memory and bandwidth requirements.
Akunu Overview and Design Philosophy
Welcome to the deep-dive section of this book. Up until now, we have been building intuition for how LLM inference works on Apple Silicon: the Metal compute pipeline, the memory hierarchy, quantized matrix math, and the attention mechanism. Now it is time to see how all of those pieces come together in a real, production-quality inference engine.
Akunu is a high-performance LLM inference engine written specifically for Apple Silicon. The name comes from the Sinhala word meaning “embers” – a fitting metaphor for a project that tries to extract every last bit of heat from the GPU silicon.
In this chapter we will survey the project at a high level: what it does, how it is organized, what design principles drive it, and what the end-to-end inference flow looks like. Subsequent chapters will zoom in on each subsystem.
What Akunu Is (and What It Is Not)
Akunu is a local inference engine. You give it a model file (GGUF or MLX SafeTensors), it loads the weights onto the Apple GPU, and it runs the full transformer forward pass – prefill and decode – entirely on-device. There is no cloud, no server round-trip, no Python runtime.
Here is what it supports today:
| Feature | Details |
|---|---|
| Architectures | LLaMA, Qwen3, Gemma, Gemma 3, BERT, Whisper |
| Weight formats | GGUF, MLX SafeTensors |
| GGUF quant types | F32, F16, BF16, Q4_0, Q4_1, Q5_0, Q5_K, Q6_K, Q8_0, Q2_K, Q3_K, Q4_K |
| MLX quant types | 3-bit, 4-bit, 6-bit, 8-bit (with configurable group size) |
| Tasks | Text generation, chat, embeddings, speech transcription |
| Decoding modes | Greedy, sampled (top-k/top-p/min-p), speculative (n-gram), grammar-constrained |
| API surface | C API (FFI-friendly), CLI tools, OpenAI-compatible HTTP server |
What it is not: it is not a training framework. It is not a general-purpose tensor library. It does not try to be cross-platform (though the architecture makes a future CUDA backend straightforward, as we will see). Every design decision optimizes for one thing: token throughput on Apple Silicon.
Performance: The Numbers
Let us start with the punchline, because performance is the reason this engine exists. All benchmarks were run on an Apple M4 Pro (16 GPU cores, 273 GB/s memory bandwidth):
Decode throughput (tg128, tokens/sec):
vs llama.cpp:
Average speedup: 1.83x
Best case: 3.66x (Qwen3-0.6B-Q3_K_S: 448 vs 123 t/s)
Wins: 20/21 configurations
vs MLX:
Average speedup: 1.17x
Best case: 1.25x (Qwen3-0.6B-bf16: 207 vs 165 t/s)
Wins: 11/11 configurations
These are not cherry-picked numbers. Across 21 GGUF model configurations and 11 MLX configurations, akunu wins decode throughput in 31 out of 32 tests. The speedup is most dramatic on small models (0.6B-1B parameters) with aggressive quantization (Q2_K through Q5_K), where akunu achieves 2-3.5x the throughput of llama.cpp.
Why? Because on small quantized models, decode is dominated by per-token overhead rather than GPU work – the matrix multiplications finish so quickly that kernel lookup, encoding, and CPU-side dispatch costs become the bottleneck. Akunu’s precompiled dispatch table and zero-allocation hot path eliminate that overhead. We will see exactly how in the sections below.
The Five Design Principles
Every non-trivial design decision in akunu traces back to one of five principles. Understanding these up front will make the rest of the codebase click.
Principle 1: Data-Driven Design (ArchDescriptor)
The naive way to support multiple architectures looks like this:
// DON'T DO THIS
if (arch == "llama") {
    activation = silu_gate;
    rope = rope_interleaved;
} else if (arch == "qwen3") {
    activation = silu_gate;
    rope = rope_neox;
    has_qk_norm = true;
} else if (arch == "gemma") {
    activation = gelu_gate;
    rope = rope_neox;
    has_qk_norm = true;
    embedding_scale = sqrt(dim);
    // ... 20 more fields
}
This approach does not scale. Every new architecture touches dozens of files. Every
if/else branch is a potential bug.
Akunu takes a different approach: it captures all architecture-specific differences
in a single POD struct called ArchDescriptor. The struct has about 20 fields
covering activation kernels, RoPE style, embedding scaling, normalization, encoder
config, and more. The entire table builder and prefill engine read from this struct
and never branch on the architecture name.
Adding a new architecture means writing one factory function that fills in the struct. That is it. No code changes in the dispatch table builder, the prefill engine, or the decode loop.
+------------------+ +------------------+ +------------------+
| arch_llama() | | arch_qwen3() | | arch_gemma(dim) |
| activation: | | activation: | | activation: |
| silu_gate_f16 | | silu_gate_f16 | | gelu_gate_f16 |
| rope: | | rope: | | rope: |
| interleaved | | neox | | neox |
| qk_norm: false | | qk_norm: true | | qk_norm: true |
| embed_scale: 0 | | tie_embed: true | | embed_scale: |
+--------+---------+ +--------+---------+ | sqrt(dim) |
| | +--------+---------+
| | |
+------------------------+-------------------------+
|
v
+----------------------------+
| build_dispatch_table(...) |
| (reads ArchDescriptor, |
| never branches on arch) |
+----------------------------+
We will cover ArchDescriptor in depth in Chapter 22.
Principle 2: Precompiled Dispatch (DispatchTable)
This is the big one. In most inference engines, every forward pass involves:
- Looking up which kernel to run
- Resolving buffer pointers
- Computing dispatch geometry (grid size, threadgroup size)
- Encoding the compute command
Akunu does all of this once, at model load time, and stores the result in a
flat array of DispatchCmd structs. The decode hot path simply iterates this array,
patches a couple of per-token fields (position, token offset), and submits the whole
thing to the GPU.
Model load time (once): Decode time (every token):
Parse weights for each token:
| for each cmd in dispatch_table:
v patch position
Resolve kernel names patch token offset
| submit to encoder
v
Look up Pipeline State Objects
|
v
Compute grid dimensions
|
v
Bind buffers + params
|
v
Store in DispatchCmd[]
The DispatchCmd struct itself is a fixed-size POD type with no heap allocations:
DispatchCmd (fixed size, no heap):
+-----------------------------------+
| Pipeline pso | 8 bytes
| Buffer buffers[8] | 8 x 24 bytes
| uint32_t offsets[8] | 32 bytes
| int buffer_count | 4 bytes
| uint8_t param_bytes[64] | 64 bytes (inline kernel params)
| int param_size, param_index | 8 bytes
| Buffer param_buf | 24 bytes
| uint8_t param2_bytes[16] | 16 bytes (secondary params)
| Dim3 grid, threadgroup | 24 bytes
| bool use_dispatch_threads | 1 byte
| PatchType patch_type | 1 byte
| int patch_offset_1, patch_offset_2| 8 bytes
+-----------------------------------+
The entire forward pass for a single token is typically 50-100 commands (embedding + N layers * ~5 commands each + output norm + logit projection + argmax). These are stored contiguously in memory, which is great for the CPU cache.
This is why akunu’s decode is fast: the CPU-side work per token is essentially a
memcpy of a few patched bytes plus a loop of setBuffer/setBytes/dispatch
calls – all inlined, no virtual dispatch, no hash lookups, no string comparisons.
Principle 3: Zero-Allocation Hot Path
Once the model is loaded, the decode loop allocates zero bytes of memory. All
buffers are pre-allocated at model init time in a ScratchBuffers struct:
ScratchBuffers (all pre-allocated at model load):
Decode (single token):
h0 [dim] -- embedding output / residual ping
h1 [dim] -- residual pong
residual [dim] -- norm output
qkv [q_dim+2*kv] -- contiguous Q|K|V
attn_out [max(q_dim,dim)]
ffn_gate [ffn_dim]
ffn_up [ffn_dim]
ffn_act [ffn_dim]
logits [vocab_size]
token_ids [max_chain]
Prefill (batch):
batch_h0 [chunk * dim]
batch_q [chunk * q_dim]
batch_k [chunk * kv_dim]
... (same pattern)
The KV cache is also pre-allocated to the maximum context length. The decode loop
never calls malloc, never calls device.allocate, never resizes a vector. This
matters more than you might think – on Apple Silicon, malloc can take microseconds,
and when you are generating 400+ tokens per second, every microsecond counts.
Even the thread-local error buffer is a static char[512]:
thread_local char error_buf[512] = {};
No std::string, no exceptions, no heap in the hot path.
Principle 4: Virtual Device Interface
The Device base class provides a pure-virtual interface with about 30 methods
covering buffer allocation, kernel loading, command encoding, and synchronization.
Today there is exactly one implementation: MetalDevice, which wraps the Metal
API. But the abstraction exists for a reason.
+------------------+
| Device (base) |
| pure virtual |
+--------+---------+
|
+------------+------------+
| |
+----+-------+ +-----+------+
| MetalDevice| | CudaDevice |
| (ObjC++) | | (future) |
+------------+ +------------+
All backend-specific code lives behind this interface. The core engine – the
dispatch table builder, the prefill encoder, the decode loop – is pure C++ with
no #import, no @autoreleasepool, no id<MTLBuffer>. If someone wanted to port
akunu to CUDA, they would implement CudaDevice and everything else would just
work.
This is not a hypothetical – the clean separation was a deliberate design choice.
We will explore the Device interface in detail in Chapter 21.
Principle 5: Lazy Weight Loading
GGUF files can be large – a Q4_0 quantized 8B model is about 4.5 GB. Loading all weights into GPU memory at once would waste time on unused tensors and spike memory usage during initialization.
Akunu’s WeightStore uses lazy loading: when you call get_tensor("layers.5.attention.q.weight"),
it checks its internal cache. If the tensor has not been loaded yet, it reads
the raw bytes from the GGUF file (using memory-mapped I/O) and uploads them to a
GPU buffer. Subsequent calls return the cached buffer instantly.
The WeightProvider class unifies this behind a common interface for both GGUF and
MLX SafeTensors formats:
+-------------------+
| WeightProvider |
| (unified facade) |
+--------+----------+
|
+------------+------------+
| |
+------+------+ +------+------+
| WeightStore | | MLXWeightStore|
| (GGUF) | | (SafeTensors) |
+------+------+ +------+--------+
| |
+-------+-------+ +------+--------+
| GGUF mmap'd | | SafeTensors |
| file on disk | | + config.json |
+----------------+ +---------------+
Weight fusion also happens here: for performance, akunu can concatenate the Q, K,
and V projection matrices (or gate + up FFN matrices) into a single contiguous GPU
buffer, allowing one large GEMV instead of two or three small ones. This is
controlled by ChipConfig.should_fuse_weights, which is true on chips with large
enough SLC (System Level Cache).
Project Structure
The project is about 265 source files (C++, Objective-C++, Metal) plus around 135 Metal shader files. Here is how they are organized:
akunu/
|
+-- include/akunu/ Public C API headers
| +-- akunu.h Opaque handle API (load, generate, encode, etc.)
| +-- types.h POD structs (AkunuModelConfig, AkunuGenerationStats, etc.)
|
+-- src/
| +-- core/ Architecture-agnostic engine core
| | +-- device.h Virtual GPU device interface
| | +-- dispatch_table.h DispatchCmd + DispatchTable + encode_chain()
| | +-- table_builder.h/cpp Builds dispatch table from weights + config
| | +-- arch_descriptor.h ArchDescriptor + factory functions
| | +-- dtype_descriptor.h DTypeDescriptor + kernel lookup tables
| | +-- chip_config.h ChipConfig (hardware tuning)
| | +-- prefill.h/cpp Batched prefill (GEMM-based)
| |
| +-- inference/ High-level inference orchestration
| | +-- model_state.h ModelState struct (the opaque handle's guts)
| | +-- model_loader.cpp Model loading + initialization
| | +-- decode_loop.cpp Top-level generate loop (prefill + decode)
| | +-- decode_greedy.cpp Chain decode (greedy, zero-alloc)
| | +-- decode_sampled.cpp Sampled decode (top-k/p/min-p)
| | +-- decode_speculative.cpp N-gram speculative decode
| | +-- decode_grammar.cpp Grammar-constrained decode
| | +-- sampling.cpp CPU-side sampling (softmax, top-k, top-p)
| | +-- embedding.cpp BERT-style embedding extraction
| |
| +-- cache/ Memory management
| | +-- kv_cache.h Per-layer K/V buffer arrays
| | +-- scratch.h Pre-allocated scratch buffers
| | +-- whisper_buffers.h Whisper-specific encoder buffers
| |
| +-- weight/ Weight file I/O
| | +-- weight_provider.h Unified GGUF/MLX interface
| | +-- weight_store.h/cpp GGUF weight loading + fusion
| | +-- mlx_weight_store.h/cpp MLX SafeTensors loading
| | +-- gguf_parser.h/cpp Low-level GGUF format parsing
| | +-- safetensors_parser.h SafeTensors header parsing
| |
| +-- tokenizer/ BPE tokenizer
| +-- grammar/ Grammar-constrained decoding (GBNF + XGrammar)
| +-- whisper/ Whisper encoder + decoder
| +-- audio/ Mel spectrogram computation
| +-- server/ OpenAI-compatible HTTP server
| +-- speculative/ N-gram draft predictor
| +-- akunu_api.cpp C API implementation (thin wrappers)
|
+-- backend/
| +-- metal/
| +-- metal_device.h/mm MetalDevice implementation (ObjC++)
| +-- metal_device_impl.h AkunuMetalState (ObjC wrapper)
| +-- metal_types.h Metal-specific type aliases
| +-- kernels/ ~135 Metal shader files
| +-- metal/kernel/
| +-- attention/ Flash attention (prefill + decode variants)
| +-- matmul/ GEMV + GEMM for all quant types
| +-- norm/ RMSNorm, LayerNorm, head norms
| +-- rope/ RoPE (interleaved + NeoX + fused variants)
| +-- activation/ SiLU, GELU, gated variants
| +-- embedding/ Token embedding lookup (all dtypes)
| +-- sampling/ GPU-side argmax, top-k, temperature, penalties
| +-- convert/ Dtype conversion (F32<->F16, dequant)
| +-- conv/ Conv1D for Whisper frontend
| +-- fused/ Fused kernels (GEMV+norm, whisper GEMV)
|
+-- tools/ CLI executables
| +-- akunu_chat.cpp Interactive chat
| +-- akunu_bench.cpp llama-bench compatible benchmark
| +-- akunu_complete.cpp Text completion
| +-- akunu_inspect.cpp Model weight inspector
| +-- akunu_profile.cpp Per-layer GPU profiler
| +-- akunu_serve.cpp OpenAI-compatible HTTP server
| +-- akunu_transcribe.cpp Whisper transcription
| +-- akunu_benchmark.cpp Extended benchmarking
|
+-- tests/ Unit + integration tests
+-- 3rdparty/ XGrammar submodule
+-- bindings/ Language bindings (Swift)
+-- CMakeLists.txt Build system
+-- Makefile Top-level build driver
If you count the lines of actual akunu code (excluding 3rdparty), the core engine is roughly 9,750 lines of C++ and Objective-C++, plus about 135 Metal shader files. That is remarkably compact for what it does – a full inference engine supporting 6 architectures, 2 weight formats, 16+ quantization types, grammar-constrained decoding, speculative decoding, Whisper transcription, and an HTTP server.
High-Level Inference Flow
Let us trace what happens when you call akunu_generate() with a prompt. This is the
30,000-foot view; later chapters will zoom in on each step.
akunu_generate(model, prompt_tokens, n_prompt, max_tokens, sampling, callback, ...)
|
v
run_decode_loop(state, ...)
|
+-- 1. PREFILL (batched, GEMM-based)
| |
| | for chunk in prompt_tokens (up to max_prefill_chunk at a time):
| | encode_prefill(device, weights, config, arch, kv_cache, scratch,
| | chunk_tokens, chunk_size, position)
| | |
| | | For each layer:
| | | GEMM: batch_residual @ Q_weight -> batch_q
| | | GEMM: batch_residual @ K_weight -> batch_k
| | | GEMM: batch_residual @ V_weight -> batch_v
| | | RoPE + write to KV cache
| | | Flash attention (prefill variant)
| | | GEMM: attn_out @ O_weight -> batch_h1
| | | Residual add + RMSNorm
| | | GEMM: batch_residual @ gate_weight -> batch_gate
| | | GEMM: batch_residual @ up_weight -> batch_up
| | | Activation (SiLU*gate or GELU*gate)
| | | GEMM: batch_act @ down_weight -> batch_h1
| | | Residual add + RMSNorm (next layer)
| | |
| | v
| | Output norm + logit projection + argmax -> first token
| |
| v
| Return first_token + timing stats
|
+-- 2. DECODE (chain decode, GEMV-based)
|
| Choose decode path:
| - temperature == 0 && speculation_enabled -> decode_speculative()
| - temperature == 0 -> decode_greedy()
| - temperature > 0 -> decode_sampled()
| - grammar != null -> decode_grammar()
|
| decode_greedy (hot path):
| while generated < max_tokens:
| write next_token to token_ids buffer
| device.begin_encoding()
| device.encode_dispatch_table(&dispatch_table, position, chunk_size)
| device.end_encoding_sync() (or async with double buffering)
| kv_cache.advance(chunk_size)
| read token_ids buffer -> output tokens
| for each token: callback(token_id, text, user_data)
|
v
Return AkunuGenerationStats { prefill_time, decode_time, tokens/sec, ... }
A few things to notice:
Prefill uses GEMM, decode uses GEMV. During prefill, we process many tokens at once, so the Q/K/V projections are matrix-matrix multiplications (M > 1). During decode, we process one token at a time, so they are matrix-vector multiplications (M = 1). Akunu has separate optimized kernels for each.
Chain decode generates multiple tokens per GPU submission. Instead of submitting
one command buffer per token, akunu submits a “chain” of N tokens in a single
begin_encoding() / end_encoding() pair. The dispatch table is replayed N times
with patched position values. This amortizes the Metal command buffer overhead
across many tokens. The chain size is tuned per chip (64-128 tokens, see ChipConfig).
The callback is synchronous. After each GPU chunk completes, tokens are read
back from the token_ids buffer and delivered to the user’s callback one at a time.
The callback can return false to stop generation early.
No Python in the loop. The entire flow – from tokenization through GPU dispatch
through token decoding – is C++. The C API boundary is a thin wrapper in
akunu_api.cpp that casts the opaque void* handle to ModelState* and forwards
the call.
The ModelState: What Lives Behind the Opaque Handle
When you call akunu_load_model(), it returns an akunu_model_t, which is a
void* pointing to a ModelState struct. This is the central state object that
ties everything together:
struct ModelState {
std::unique_ptr<Device> device; // GPU device (MetalDevice)
WeightProvider *weights; // Weight file access
Tokenizer tokenizer; // BPE tokenizer
AkunuModelConfig config; // Parsed model config
ArchDescriptor arch; // Architecture descriptor
ChipConfig chip; // Hardware tuning params
KVCache kv_cache; // Per-layer K/V buffers
ScratchBuffers scratch; // Pre-allocated intermediates
DispatchTable dispatch_table; // Precompiled decode commands
NGramPredictor predictor; // Speculative n-gram predictor
bool speculation_enabled; // Whether spec decode is on
// Whisper-specific fields
bool is_whisper;
std::unique_ptr<WhisperBuffers> whisper_buf;
std::unique_ptr<MelSpectrogram> mel_spec;
std::unique_ptr<WhisperModel> whisper_model;
DispatchTable whisper_decode_table;
// ... beam search buffers
};
This is it. One struct, one allocation. The entire engine state fits in a single cache-friendly object. Compare this to inference frameworks that scatter state across dozens of Python objects, each with its own reference counting and garbage collection pressure.
Model Loading: What Happens at Init Time
The akunu_load_model() function is where all the expensive work happens. Here is
the sequence:
akunu_load_model(path, metallib_path, max_context)
|
+-- 1. Create MetalDevice (MTLCreateSystemDefaultDevice)
+-- 2. Load metallib (compiled shader library)
+-- 3. Open weight file (GGUF or MLX SafeTensors)
+-- 4. Parse model config (dims, layers, heads, vocab, etc.)
+-- 5. Select ArchDescriptor (arch_from_config)
+-- 6. Detect ChipConfig (GPU cores, family, SLC estimate)
+-- 7. Handle format-specific quirks:
| - MLX LLaMA: switch to NeoX RoPE
| - Tie embeddings if output.weight missing
| - Set quant_bits / quant_group_size from MLX metadata
+-- 8. Precompute RoPE frequencies (LLaMA 3 wavelen scaling)
+-- 9. Load tokenizer (from GGUF metadata or HF tokenizer.json)
+-- 10. Set context length (capped at model max or user-specified)
+-- 11. Allocate KV cache (n_layers * 2 buffers * max_seq_len)
+-- 12. Allocate ScratchBuffers (all intermediates)
+-- 13. Build DispatchTable (resolves all PSOs, binds buffers)
+-- 14. Warmup pass (compiles remaining Metal pipelines)
+-- 15. Return ModelState* as opaque handle
Steps 1-12 are straightforward initialization. Step 13 is where the magic happens:
build_dispatch_table() walks through the entire forward pass – embedding, norms,
projections, RoPE, attention, FFN – and for each operation, it resolves the kernel
name from DTypeDescriptor, looks up or compiles the Metal pipeline state object,
computes the dispatch geometry, binds the weight and scratch buffers, and stores
everything in a DispatchCmd. By the time this function returns, the engine knows
exactly what to do for each token – no runtime decisions left.
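To make the "flat command array" idea concrete, here is a heavily simplified sketch of the replay mechanism. The struct fields and the run callback are invented stand-ins; akunu's real DispatchCmd binds Metal pipeline state objects, buffers, and dispatch geometry:

```cpp
#include <cstdint>
#include <functional>
#include <vector>

// Illustrative stand-in for a precompiled command: everything is resolved
// up front, and only the position parameter is patched at replay time.
struct DispatchCmd {
    std::function<void(uint32_t position)> run; // stands in for a bound PSO + buffers
};

struct DispatchTable {
    std::vector<DispatchCmd> cmds;

    // Replay the whole forward pass for one token: no lookups,
    // no branching -- just walk the flat array.
    void replay(uint32_t position) const {
        for (const auto &cmd : cmds) cmd.run(position);
    }
};
```

Chain decode is then just a loop that calls replay N times with incrementing positions inside a single command-buffer encoding.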
Supported Architectures at a Glance
Let us briefly survey how each supported architecture maps to akunu’s abstractions:
+---------+--------------+-------------+---------+--------+---------+---------+
| Arch    | Activation   | RoPE        | QK Norm | PostNm | Enc/Dec | Embed   |
|         |              |             |         |        |         | Scale   |
+---------+--------------+-------------+---------+--------+---------+---------+
| LLaMA   | silu_gate    | interleaved | no      | no     | no      | 0       |
| Qwen3   | silu_gate    | neox        | yes     | no     | no      | 0       |
| Gemma   | gelu_gate    | neox        | yes     | yes    | no      | sqrt(d) |
| Gemma3  | gelu_gate    | neox        | yes     | yes    | no      | sqrt(d) |
| Whisper | gelu (plain) | none        | no      | no     | yes     | 0       |
| BERT    | silu_gate    | neox        | no      | no     | no      | 0       |
+---------+--------------+-------------+---------+--------+---------+---------+
All of these differences are captured in the ArchDescriptor struct – no special
code paths. Gemma 3’s sliding window attention with alternating global/local layers?
That is just cfg.sliding_window_pattern > 0 in the RoPE theta selection. Whisper’s
Conv1D frontend, cross-attention, and sinusoidal positional embeddings? Those are
flag fields in ArchDescriptor plus a separate WhisperModel loading path.
The key insight is that most transformer architectures are minor variations on the same theme. They all have embedding, norm, QKV projection, attention, output projection, FFN, and output logits. The differences are in which activation function, which RoPE variant, whether there is an extra normalization step, and so on. Akunu exploits this regularity by parameterizing the differences rather than branching on them.
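A toy version of this "differences as data" idea might look as follows. The enum values and field names mirror the table above but are illustrative, not akunu's actual ArchDescriptor definition:

```cpp
#include <cmath>

enum class Activation { SiluGate, GeluGate, Gelu };
enum class Rope { Interleaved, Neox, None };

// One descriptor per architecture; the forward pass reads these fields
// instead of branching on the model name.
struct ArchDescriptor {
    Activation act;
    Rope rope;
    bool qk_norm;
    bool post_norm;
    float embed_scale; // 0 = no embedding scaling
};

// e.g. a Gemma-style descriptor: gelu_gate, NeoX RoPE, sqrt(dim) embed scale.
inline ArchDescriptor make_gemma(unsigned dim) {
    return {Activation::GeluGate, Rope::Neox, true, true,
            std::sqrt(static_cast<float>(dim))};
}
```

Adding a new architecture then means writing one descriptor, not threading a new flag through every kernel dispatch site.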
What Makes This Different from Other Engines
If you have used llama.cpp, MLX, or vLLM, you might wonder what akunu does differently. Here is a quick comparison:
vs. llama.cpp: llama.cpp builds a computation graph (ggml) at runtime and evaluates it node-by-node. Each node involves a virtual dispatch to find the right kernel, plus buffer management. Akunu eliminates this overhead by precompiling the entire forward pass into a flat command array. llama.cpp is more general (it runs on CPU, CUDA, Metal, Vulkan, etc.), but akunu squeezes more performance out of Metal specifically.
vs. MLX: MLX is a general-purpose array framework (like PyTorch) that happens to run on Metal. It has a JIT compiler, automatic differentiation, and a Python frontend. This generality comes at a cost: each operation goes through MLX’s dispatch layer, which involves hash lookups and potentially JIT compilation. Akunu bypasses all of this – it talks directly to Metal with precompiled pipelines.
vs. vLLM: vLLM targets datacenter GPU inference with features like PagedAttention, continuous batching, and multi-GPU tensor parallelism. Akunu targets single-device Apple Silicon with features like chain decode, SLC-aware weight fusion, and Metal-specific kernel optimization. Different tools for different jobs.
The common thread is specialization. Akunu does fewer things, but does them very well on one specific hardware platform.
A Note on Code Style
Before we dive deeper in the following chapters, a word about the codebase style. Akunu is written in C++17 with a strong preference for:
- POD structs over class hierarchies (DispatchCmd, KVCache, ScratchBuffers)
- Fixed-size inline storage over heap allocation (param_bytes[64], buffers[8])
- Factory functions over constructors (KVCache::create, ScratchBuffers::create)
- Explicit state over hidden globals (ModelState holds everything)
- One virtual class (Device) instead of a deep hierarchy
- Thread-local for truly per-thread state (error buffer, RNG)
The Metal backend uses Objective-C++ (.mm files) because it has to – Metal is an
Objective-C API. But this is strictly quarantined behind the Device interface. The
rest of the engine is pure C++ that compiles with any standard compiler.
Error handling is C-style: functions return null/false on failure and set a
thread-local error string via set_error(). No exceptions in the hot path. No
RAII wrappers around GPU resources (the ModelState destructor handles cleanup).
This style is not “modern C++” in the Herb Sutter sense, but it is effective for systems programming where you care about memory layout, cache behavior, and predictable performance.
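The error-handling convention fits in a few lines. This is a minimal reimplementation of the pattern described above, not akunu's actual set_error:

```cpp
#include <cstdarg>
#include <cstdio>

// Per-thread error buffer: no exceptions, no heap, safe across threads.
static thread_local char g_error[256] = "";

inline void set_error(const char *fmt, ...) {
    va_list args;
    va_start(args, fmt);
    vsnprintf(g_error, sizeof(g_error), fmt, args);
    va_end(args);
}

inline const char *get_error() { return g_error; }

// Typical call site: return false on failure and record why.
inline bool open_weights(const char *path) {
    if (!path) {
        set_error("open_weights: null path");
        return false;
    }
    return true; // real code would open and validate the file here
}
```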
Summary
Akunu is a tightly-focused inference engine that trades generality for performance on Apple Silicon. Its key design decisions are:
- ArchDescriptor – all architecture differences as data, not branches
- DispatchTable – precompiled GPU command sequences, replayed per token
- Zero-allocation decode – all buffers pre-allocated, nothing on the hot path
- Virtual Device – clean Metal abstraction, ready for future backends
- Lazy weight loading – tensors loaded on demand, with fusion support
The result is an engine that achieves 1.8x average speedup over llama.cpp on decode and 1.17x over MLX, in about 10,000 lines of C++ plus 135 Metal shaders.
In the next chapter, we will see how to build and run akunu from source.
Building and Running Akunu
This chapter covers the practical mechanics of building akunu from source and using its CLI tools. If you have been reading the previous chapters to understand how akunu works, this is where you roll up your sleeves and actually compile it. We will walk through the CMake configuration, build options, Metal shader compilation, and each of the CLI tools.
Prerequisites
Akunu requires:
- macOS 13 (Ventura) or later – for Metal 3 support and Apple GPU family 7+
- Xcode 15+ (or at least the Command Line Tools) – for the clang++ compiler with Objective-C++ support and the metal shader compiler
- CMake 3.20+ – the build system
- Apple Silicon Mac – M1 or later. Akunu’s Metal backend requires UMA and simdgroup_matrix support (GPU family 7+)
- A model file – either GGUF format (from llama.cpp ecosystem) or MLX SafeTensors directory
Optional dependencies:
- XGrammar (v0.1.33) – for grammar-constrained JSON output. Included as a git submodule at
3rdparty/xgrammar
Project Structure
akunu/
├── CMakeLists.txt # Top-level build configuration
├── include/
│ └── akunu/
│ ├── akunu.h # C API header
│ └── types.h # Shared type definitions
├── src/
│ ├── akunu_api.cpp # C API implementation
│ ├── core/ # Backend-agnostic core
│ │ ├── device.h # Device abstraction
│ │ ├── dispatch_table.h # Precompiled command sequence
│ │ ├── table_builder.cpp
│ │ ├── arch_descriptor.h
│ │ ├── chip_config.h
│ │ ├── dtype_descriptor.h
│ │ └── ...
│ ├── weight/ # GGUF/MLX weight loading
│ ├── tokenizer/ # BPE tokenizer
│ ├── grammar/ # GBNF/JSON schema grammar
│ ├── inference/ # Decode loops, sampling
│ ├── cache/ # KV cache, scratch buffers
│ ├── server/ # HTTP server
│ ├── speculative/ # Speculative decoding
│ └── whisper/ # Whisper speech-to-text
├── backend/
│ └── metal/
│ ├── metal_device.h
│ ├── metal_device.mm # ObjC++ Metal implementation
│ ├── metal_device_impl.h
│ └── kernels/ # .metal shader source files
├── tools/ # CLI executables
│ ├── akunu_chat.cpp
│ ├── akunu_bench.cpp
│ ├── akunu_complete.cpp
│ ├── akunu_inspect.cpp
│ ├── akunu_profile.cpp
│ ├── akunu_benchmark.cpp
│ ├── akunu_serve.cpp
│ └── akunu_transcribe.cpp
├── tests/ # Test executables
│ ├── kernels/ # Per-kernel correctness tests
│ └── ...
└── 3rdparty/
└── xgrammar/ # Git submodule
CMake Configuration
The build is configured through CMakeLists.txt. Let’s walk through the key sections.
Language Standards
cmake_minimum_required(VERSION 3.20)
project(akunu VERSION 0.1 LANGUAGES CXX OBJCXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_OBJCXX_STANDARD 17)
set(CMAKE_OBJCXX_FLAGS "${CMAKE_OBJCXX_FLAGS} -fobjc-arc")
Akunu uses C++17 and Objective-C++ 17. The -fobjc-arc flag enables Automatic Reference Counting for Objective-C objects – this is how MetalDevice manages Metal API objects (MTLDevice, MTLCommandBuffer, etc.) without manual retain/release.1
Backend Selection
option(AKUNU_BACKEND_METAL "Build Metal backend (Apple Silicon)" ON)
option(AKUNU_BACKEND_CUDA "Build CUDA backend (NVIDIA)" OFF)
Metal is enabled by default. CUDA exists as a placeholder for future work. The backend selection determines which source files and frameworks are linked.
Core Sources
The core engine is pure C++ (no platform dependencies):
set(CORE_SOURCES
src/weight/gguf_parser.cpp # GGUF file format parser
src/weight/weight_store.cpp # Weight management + fusion
src/weight/mlx_weight_store.cpp # MLX SafeTensors parser
src/core/table_builder.cpp # Dispatch table construction
src/core/device_defaults.cpp # Default Device method implementations
src/core/prefill.cpp # Prefill (batched) forward pass
src/tokenizer/tokenizer.cpp # BPE tokenizer
src/grammar/grammar.cpp # GBNF grammar parser
src/grammar/json_schema_to_grammar.cpp
src/grammar/xgrammar_wrapper.cpp # XGrammar integration
src/inference/model_state.cpp # Model lifecycle
src/inference/sampling.cpp # Temperature, top-k, top-p, min-p
src/inference/model_loader.cpp # Model loading orchestration
src/inference/decode_greedy.cpp # Greedy decode loop
src/inference/decode_sampled.cpp # Sampled decode loop
src/inference/decode_speculative.cpp
src/inference/decode_grammar.cpp # Grammar-constrained decode
src/inference/decode_loop.cpp # Common decode infrastructure
src/inference/embedding.cpp # Text embedding (BERT)
src/whisper/whisper_inference.cpp # Whisper transcription
src/akunu_api.cpp # C API implementation
)
Metal Backend
When AKUNU_BACKEND_METAL is ON:
if(AKUNU_BACKEND_METAL)
list(APPEND BACKEND_SOURCES backend/metal/metal_device.mm)
endif()
The Metal backend is a single Objective-C++ file (metal_device.mm) that implements the Device virtual interface.
Framework Linking
The Metal backend links five Apple frameworks:
target_link_libraries(akunu_engine PUBLIC
"-framework Metal" # GPU compute
"-framework MetalPerformanceShaders" # (available for future use)
"-framework Foundation" # NSObject, NSString, NSURL
"-framework Accelerate" # vDSP (audio processing for Whisper)
"-framework IOKit" # GPU core count detection
)
| Framework | Purpose in Akunu |
|---|---|
| Metal | Core GPU API: device, command buffers, compute pipelines |
| MetalPerformanceShaders | Linked but not actively used (available for optimized primitives) |
| Foundation | Objective-C runtime, file URLs, string conversion |
| Accelerate | vDSP for FFT/mel spectrogram in Whisper audio preprocessing |
| IOKit | IORegistryEntryCreateCFProperty to query gpu-core-count from AGXAccelerator |
XGrammar Integration
The XGrammar submodule provides grammar-constrained decoding:
set(XGRAMMAR_DIR "${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/xgrammar")
if(EXISTS "${XGRAMMAR_DIR}/include/xgrammar/xgrammar.h")
add_subdirectory(${XGRAMMAR_DIR} ${CMAKE_BINARY_DIR}/xgrammar EXCLUDE_FROM_ALL)
set(AKUNU_HAS_XGRAMMAR ON)
endif()
If the submodule is not initialized, XGrammar is simply disabled and grammar-constrained generation will not be available. To enable it:
git submodule update --init --recursive
Shared Library for Bindings
option(AKUNU_BUILD_SHARED "Build shared library for language bindings" OFF)
When enabled, this builds libakunu.dylib in addition to the static libakunu_engine.a. The shared library exposes the C API (akunu.h) and can be loaded by Python, Swift, or any language with C FFI support.
Building from Source
Basic Build
cd ~/Projects/akunu
mkdir build && cd build
cmake .. -DCMAKE_BUILD_TYPE=Release
make -j$(sysctl -n hw.ncpu)
This produces:
- libakunu_engine.a – the static library
- akunu_chat, akunu_bench, akunu_complete, etc. – CLI tools
- akunu_test_*, akunu_kernel_* – test executables
Build with XGrammar
git submodule update --init --recursive
mkdir build && cd build
cmake .. -DCMAKE_BUILD_TYPE=Release
make -j$(sysctl -n hw.ncpu)
The build system auto-detects XGrammar and sets AKUNU_HAS_XGRAMMAR=1.
Build Shared Library
cmake .. -DCMAKE_BUILD_TYPE=Release -DAKUNU_BUILD_SHARED=ON
make -j$(sysctl -n hw.ncpu)
Debug Build
cmake .. -DCMAKE_BUILD_TYPE=Debug
make -j$(sysctl -n hw.ncpu)
Debug builds enable assertions and disable optimizations. For Metal shader debugging, also enable the Metal validation layer:
export METAL_DEVICE_WRAPPER_TYPE=1
export MTL_DEBUG_LAYER=1
./akunu_chat model.gguf akunu.metallib
Metal Shader Compilation
The Metal shader sources live in backend/metal/kernels/. They must be compiled into a .metallib file before akunu can load them. The compilation pipeline is:
.metal source files
│
▼ (metal compiler)
.air intermediate files
│
▼ (metallib archiver)
akunu.metallib
The compilation command (not automated by CMake – you need to do this manually or via a script):
# Compile each .metal file to .air (the ** glob needs
# "shopt -s globstar" in bash; zsh expands it by default)
for f in backend/metal/kernels/metal/kernel/**/*.metal; do
    xcrun -sdk macosx metal -c -target air64-apple-macos13.0 \
        -I backend/metal/kernels/metal/include \
        "$f" -o "${f%.metal}.air"
done
# Archive all .air files into a single metallib
xcrun -sdk macosx metallib backend/metal/kernels/metal/kernel/**/*.air -o akunu.metallib
The -I flag adds the include directory for shared headers (ShaderTypes.h, KernelCommon.h) that are used across kernel files.
The resulting akunu.metallib file contains all GPU kernels in compiled form. At runtime, MetalDevice::load_library() loads this file and individual kernels are extracted by name via get_pipeline().2
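The name-to-pipeline lookup amounts to a cache keyed by kernel function name. Here is a simplified, Metal-free sketch of the pattern — the Pipeline type and the "compile on first use" step are stand-ins for creating an MTLComputePipelineState:

```cpp
#include <string>
#include <unordered_map>

// Stand-in for MTLComputePipelineState.
struct Pipeline {
    std::string function_name;
};

class PipelineCache {
    std::unordered_map<std::string, Pipeline> cache_;

public:
    // Look up a pipeline by kernel name; "compile" it on first use.
    const Pipeline &get_pipeline(const std::string &name) {
        auto it = cache_.find(name);
        if (it == cache_.end()) {
            // Real code: newFunctionWithName + pipeline state creation.
            it = cache_.emplace(name, Pipeline{name}).first;
        }
        return it->second;
    }

    size_t size() const { return cache_.size(); }
};
```

Because the dispatch table resolves every pipeline at load time, this cache is populated once during initialization and the hot path never pays a compilation cost.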
Kernel Organization
The Metal kernels are organized by category:
| Directory | Kernels | Count |
|---|---|---|
| activation/ | SiLU, GELU, gated variants | ~4 |
| attention/ | Flash attention decode, softmax, logit cap | ~3 |
| common/ | Bias add, residual add, transpose, vector ops | ~5 |
| conv/ | Conv1D for Whisper frontend | 1 |
| convert/ | Dequantize (Q4_0, Q4_K, Q8_0, MLX, etc.) | ~12 |
| embedding/ | Embedding lookup per dtype | ~10 |
| fused/ | Fused SiLU+GEMV, fused head norm | ~2 |
| kv_cache/ | KV cache write, shift | ~2 |
| matmul/ | GEMV, wide GEMV, SIMD GEMM per dtype | ~50+ |
| norm/ | RMSNorm, LayerNorm, residual norm, head norm | ~5 |
| rope/ | RoPE (standard, NeoX), fused norm+RoPE | ~4 |
| sampling/ | Argmax, temperature scaling, top-k/p | ~4 |
Total: roughly 100+ kernel functions compiled into a single metallib.
CLI Tools
akunu_chat
Interactive chat with a loaded model. Handles conversation formatting using the model’s native chat template.
./akunu_chat path/to/model.gguf path/to/akunu.metallib
Features:
- Auto-detects chat template (ChatML, Llama 3, Gemma, etc.)
- Multi-turn conversation with KV cache reuse
- Streaming token output
- System prompt support
akunu_bench
Performance benchmarking tool. Measures prefill and decode throughput.
./akunu_bench path/to/model.gguf path/to/akunu.metallib
Reports:
- Prefill tokens/second (for various prompt lengths)
- Decode tokens/second (steady-state generation)
- Memory usage
- Model configuration summary
akunu_complete
Text completion (non-chat). Takes a prompt and generates a continuation.
./akunu_complete path/to/model.gguf path/to/akunu.metallib
Useful for testing raw model behavior without chat formatting.
akunu_inspect
Model inspection tool. Dumps model metadata and tensor information.
./akunu_inspect path/to/model.gguf
Shows:
- Architecture, vocabulary size, embedding dimension
- Number of layers, heads, KV heads
- RoPE configuration
- Tensor names, shapes, and dtypes
akunu_profile
Per-layer GPU timing profiler. Runs each operation in its own command buffer for accurate timing.
./akunu_profile path/to/model.gguf path/to/akunu.metallib
Output shows per-operation GPU time in milliseconds, useful for identifying bottlenecks.
akunu_serve
HTTP API server with OpenAI-compatible endpoints.
./akunu_serve path/to/model.gguf path/to/akunu.metallib --port 8080
Provides:
- /v1/chat/completions – streaming and non-streaming chat
- /v1/completions – text completion
- Multi-model support (load multiple models)
- Concurrent request handling with per-model mutex
akunu_transcribe
Speech-to-text using Whisper models.
./akunu_transcribe path/to/whisper-model.gguf path/to/akunu.metallib input.wav
Supports:
- WAV input (resampled to 16kHz internally)
- Language detection or forced language
- Timestamp generation
- Streaming segment callback
akunu_benchmark
Extended benchmarking tool with more detailed metrics.
./akunu_benchmark path/to/model.gguf path/to/akunu.metallib
Library Targets
The CMake build produces two main library targets:
| Target | Type | Contents |
|---|---|---|
| akunu_engine | Static (libakunu_engine.a) | Core + backend, always built |
| akunu | Shared (libakunu.dylib) | Same, built when AKUNU_BUILD_SHARED=ON |
Both expose the same C API defined in include/akunu/akunu.h. The static library is used by all CLI tools and tests. The shared library is intended for language bindings.
Test Executables
The build produces numerous test executables:
Integration Tests
| Test | Purpose |
|---|---|
| akunu_test_device | Metal device creation, buffer allocation |
| akunu_test_weights | GGUF parsing, weight loading |
| akunu_test_table | Dispatch table construction |
| akunu_e2e | End-to-end generation test |
| akunu_test_long_context | Long context handling |
| akunu_test_sampling_quality | Sampling distribution tests |
| akunu_test_config | Model configuration parsing |
| akunu_test_tokenizer | Tokenizer encode/decode |
| akunu_test_inference | Inference pipeline |
| akunu_test_kv_cache | KV cache operations |
| akunu_test_grammar | Grammar parsing and constrained decoding |
| akunu_test_server | HTTP server endpoints |
| akunu_test_whisper | Whisper model loading |
| akunu_test_whisper_e2e | End-to-end transcription |
Kernel Tests
Individual kernel correctness tests (each compiled as ObjC++):
| Test | Kernel Under Test |
|---|---|
| akunu_kernel_test_rmsnorm | rmsnorm_f16 |
| akunu_kernel_test_gemma_rmsnorm | rmsnorm_gemma_f16 |
| akunu_kernel_test_gemv_f16 | gemv_f16 |
| akunu_kernel_test_gemv_q4_0 | gemv_q4_0 |
| akunu_kernel_test_gemv_q8_0 | gemv_q8_0 |
| akunu_kernel_test_gemm_f16 | simd_gemm_f16 |
| akunu_kernel_test_silu | silu_f16 |
| akunu_kernel_test_gelu | gelu_f16 |
| akunu_kernel_test_silu_gate | silu_gate_f16 |
| akunu_kernel_test_gelu_gate | gelu_gate_f16 |
| akunu_kernel_test_rope | rope_qkv_write_f16 |
| akunu_kernel_test_rope_neox | rope_neox_qkv_write_f16 |
| akunu_kernel_test_flash_attention | flash_attention_decode_parallel_f16 |
| akunu_kernel_test_embedding_f16 | embedding_lookup_f16 |
| akunu_kernel_test_f32_to_f16 | f32_to_f16 |
| akunu_kernel_test_dequant_q4_0 | dequant_q4_0 |
These tests compare GPU kernel output against CPU reference implementations to verify correctness within FP16 tolerance.
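A kernel test of this shape reduces to an elementwise comparison of GPU output against a CPU reference within a tolerance. An illustrative helper — the specific atol/rtol values here are assumptions chosen for FP16-scale error, not akunu's actual thresholds:

```cpp
#include <cmath>
#include <cstddef>

// Compare GPU output against a CPU reference with a mixed
// absolute/relative tolerance suitable for FP16 arithmetic.
inline bool all_close(const float *got, const float *ref, size_t n,
                      float atol = 1e-3f, float rtol = 1e-2f) {
    for (size_t i = 0; i < n; ++i) {
        float diff = std::fabs(got[i] - ref[i]);
        if (diff > atol + rtol * std::fabs(ref[i])) return false;
    }
    return true;
}
```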
Troubleshooting
“Failed to load metallib”
The metallib path is incorrect or the file was compiled for a different target. Ensure:
- The metallib file exists at the specified path
- It was compiled with -target air64-apple-macos13.0 or later
- The Metal device supports the required GPU family
“Failed to get pipeline: kernel_name”
A kernel function is missing from the metallib. This usually means:
- The kernel source file was not included in the metallib compilation
- There is a naming mismatch between the kernel function name in Metal and the string in C++
“allocate: failed to allocate N bytes”
The model is too large for available memory. Options:
- Use a smaller quantization (Q4_0 instead of FP16)
- Reduce max_context to shrink the KV cache
- Close other applications to free memory
- Use a chip with more unified memory
Build Errors with XGrammar
If XGrammar fails to build, you can disable it:
cmake .. -DCMAKE_BUILD_TYPE=Release
# XGrammar auto-disables if submodule is not initialized
Summary
Building akunu is a standard CMake workflow. The main moving parts are:
- CMake configuration with AKUNU_BACKEND_METAL=ON (default)
- Metal shader compilation into akunu.metallib (manual step)
- Framework linking for Metal, Foundation, Accelerate, IOKit
- CLI tools for chat, benchmark, profiling, serving, and transcription
The next chapter covers the C API that all these tools are built on.
1. Apple, “Transitioning to ARC Release Notes.” ARC (Automatic Reference Counting) eliminates manual retain/release calls for Objective-C objects. The compiler inserts retain/release operations automatically. See https://developer.apple.com/library/archive/releasenotes/ObjectiveC/RN-TransitioningToARC/. ↩
2. Apple, “Building a Library with Metal’s Command-Line Tools.” The metal and metallib command-line tools compile .metal sources into .metallib archives. See https://developer.apple.com/documentation/metal/shader_libraries/building_a_shader_library_by_precompiling_source_files. ↩
The C API
Every akunu CLI tool – akunu_chat, akunu_bench, akunu_serve, akunu_transcribe – is built on top of a single C API defined in two header files: include/akunu/akunu.h and include/akunu/types.h. This chapter explains why akunu uses a C API rather than a C++ one, walks through every major function, and shows how to build a complete application from scratch.
Why C?
This is a C++ project. The core engine is written in C++17 with Objective-C++ for the Metal backend. So why is the public API in plain C?
1. FFI compatibility. Every programming language can call C functions. Python has ctypes and cffi. Swift has direct C interop. Rust has extern "C". Java has JNI. Go has cgo. A C API is the universal adapter.1
2. ABI stability. C++ name mangling, vtable layouts, and standard library implementations differ between compilers and versions. A C API with POD structs and opaque pointers has a stable ABI – you can swap out the shared library without recompiling the caller.
3. No header dependencies. The C API headers include only <stdint.h>, <stdbool.h>, and <stddef.h>. No C++ standard library, no Metal headers, no Objective-C. Any C or C++ compiler can parse them.
4. Opaque handles prevent misuse. The caller cannot poke at internal state because the model is just a void*. This forces all interaction through the API functions, making it possible to change internal representations without breaking callers.
Header Organization
The API is split across two files:
types.h – Shared Data Structures
// types.h contains POD structs used by both the C API and internal C++ code
typedef struct {
uint32_t dim; // embedding dimension
uint32_t n_layers; // transformer layers
uint32_t n_heads; // query heads
uint32_t n_kv_heads; // key/value heads (GQA)
uint32_t head_dim; // dim per head
uint32_t q_dim; // total Q projection output
uint32_t kv_dim; // total KV projection output
uint32_t ffn_dim; // feed-forward intermediate dimension
uint32_t vocab_size; // vocabulary size
uint32_t max_seq_len; // maximum context length
float norm_eps; // RMSNorm/LayerNorm epsilon
float rope_theta; // RoPE base frequency
uint32_t sliding_window_pattern;
float rope_local_theta;
char architecture[32]; // "llama", "qwen3", "gemma", "whisper"
// Encoder parameters (0 = decoder-only)
uint32_t enc_n_layers;
uint32_t enc_n_heads;
uint32_t enc_dim;
uint32_t enc_ffn_dim;
uint32_t n_mels; // mel spectrogram bins (Whisper)
uint32_t enc_max_seq_len;
} AkunuModelConfig;
Notice that AkunuModelConfig uses fixed-size arrays (char architecture[32]) instead of std::string, and all fields are primitive types. This is a POD struct – it can be safely passed across C/C++ boundaries and even memory-mapped.
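The POD property can be checked at compile time. A quick sketch using a trimmed-down version of the struct — the fields here are a subset for illustration:

```cpp
#include <cstdint>
#include <type_traits>

// Trimmed-down illustration of AkunuModelConfig's layout constraints.
typedef struct {
    uint32_t dim;
    uint32_t n_layers;
    float norm_eps;
    char architecture[32]; // fixed-size array, no std::string
} ConfigPod;

// Safe to memcpy and to pass by value across the C/C++ boundary.
static_assert(std::is_trivially_copyable<ConfigPod>::value,
              "config must be memcpy-safe");
static_assert(std::is_standard_layout<ConfigPod>::value,
              "config must have a C-compatible layout");
```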
Other types in types.h:
| Type | Purpose | Fields |
|---|---|---|
AkunuModelConfig | Model architecture metadata | dim, layers, heads, vocab, etc. |
AkunuSamplingConfig | Generation sampling parameters | temperature, top_k, top_p, min_p, repeat_penalty |
AkunuGenerationStats | Post-generation statistics | prompt tokens, generated tokens, prefill/decode times |
AkunuTranscribeStats | Post-transcription statistics | audio_ms, encode_ms, decode_ms, total_ms |
akunu.h – The Function API
The main header declares all API functions inside an extern "C" block:
#ifdef __cplusplus
extern "C" {
#endif
typedef void *akunu_model_t; // Opaque model handle
// ... function declarations ...
#ifdef __cplusplus
}
#endif
The extern "C" block ensures C linkage (no name mangling) when compiled as C++. The #ifdef __cplusplus guards make the header valid for both C and C++ compilers.
The Opaque Handle Pattern
The entire model state – weights, KV cache, scratch buffers, dispatch table, tokenizer, device – is wrapped behind a single opaque handle:
typedef void *akunu_model_t;
This is a pointer to an internal C++ ModelState object that the caller never sees. Every API function takes this handle as its first argument. The pattern is:
// Create
akunu_model_t model = akunu_load_model("model.gguf", "akunu.metallib", 0);
// Use
akunu_generate(model, tokens, n_tokens, 256, sampling, callback, NULL);
// Destroy
akunu_free_model(model);
No global state, no singletons. You can load multiple models simultaneously by creating multiple handles. Each handle owns its own GPU resources.2
Model Lifecycle
Loading
akunu_model_t akunu_load_model(const char *model_path,
const char *metallib_path,
int max_context);
This function does a lot of work:
- Creates a MetalDevice (or default device)
- Loads the metallib (device.load_library())
- Parses the model file (GGUF or MLX SafeTensors)
- Allocates weight buffers and uploads weights to GPU
- Creates the ArchDescriptor from model metadata
- Queries ChipConfig from the device
- Allocates KV cache and scratch buffers
- Builds the dispatch table (build_dispatch_table())
- Initializes the tokenizer
- Returns the opaque handle (or NULL on failure)
Parameters:
- model_path: Path to a .gguf file or MLX SafeTensors directory
- metallib_path: Path to compiled akunu.metallib. Pass NULL for auto-detection (searches common paths)
- max_context: Maximum context window. 0 = use model default (capped at 4096)
Freeing
void akunu_free_model(akunu_model_t model);
Releases all GPU buffers, KV cache, scratch buffers, cached pipeline state objects, and the Metal device. After this call, the handle is invalid.
Error Handling
const char *akunu_get_error(void);
Returns the last error message. This uses thread-local storage, so it is safe to call from multiple threads. If akunu_load_model returns NULL, call this to find out why:
akunu_model_t model = akunu_load_model("bad_path.gguf", "akunu.metallib", 0);
if (!model) {
printf("Error: %s\n", akunu_get_error());
// "Error: Failed to open file: bad_path.gguf"
}
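The thread-local last-error pattern behind akunu_get_error is simple to sketch. This standalone illustration (the demo_* names are hypothetical) shows why concurrent threads cannot clobber each other's error strings:

```cpp
#include <cassert>
#include <cstring>
#include <string>
#include <thread>

// Hypothetical sketch of the thread-local last-error pattern: each
// thread owns its own string, so an error set on one thread can never
// race with or overwrite another thread's.
static thread_local std::string g_last_error;

void demo_set_error(const std::string &msg) { g_last_error = msg; }

const char *demo_get_error(void) { return g_last_error.c_str(); }
```

A freshly spawned thread sees an empty error string even if the main thread has already recorded a failure.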
Model Information
AkunuModelConfig akunu_get_config(akunu_model_t model);
size_t akunu_model_memory(akunu_model_t model);
akunu_get_config returns a copy of the model configuration struct. Since AkunuModelConfig is a POD struct, this is a simple memcpy – no dynamic allocation.
akunu_model_memory returns the total GPU memory used by the model in bytes. This includes weights, KV cache, scratch buffers, and pre-allocated parameter buffers.
Tokenization
int akunu_encode(akunu_model_t model, const char *text,
uint32_t *out_tokens, int max_tokens);
const char *akunu_decode_token(akunu_model_t model, uint32_t token_id);
int akunu_token_count(akunu_model_t model, const char *text);
The tokenizer is a BPE implementation built into akunu (no external dependency). Token IDs are uint32_t values.
akunu_encode writes token IDs into a caller-provided buffer. Returns the number of tokens written. If the output buffer is too small, the text is silently truncated.
akunu_decode_token returns a pointer to the token’s text representation. The pointer is valid until the model is freed – it points into the tokenizer’s vocabulary table.
akunu_token_count is a convenience function that counts tokens without allocating an output buffer.
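Because akunu_encode truncates silently, a robust caller sizes the buffer first with akunu_token_count. The two-pass pattern can be sketched against a stub whitespace tokenizer (standing in for the real BPE; only the calling pattern matters):

```cpp
#include <cassert>
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Stub "tokenizer": one token per whitespace-separated word.
int stub_token_count(const std::string &text) {
    std::istringstream ss(text);
    std::string w;
    int n = 0;
    while (ss >> w) n++;
    return n;
}

int stub_encode(const std::string &text, uint32_t *out, int max_tokens) {
    std::istringstream ss(text);
    std::string w;
    int n = 0;
    while (ss >> w && n < max_tokens)   // silently truncates, like akunu_encode
        out[n++] = (uint32_t)w.size();  // fake token id
    return n;
}

// Two-pass pattern: count, size the buffer, then encode -- no truncation.
std::vector<uint32_t> encode_exact(const std::string &text) {
    std::vector<uint32_t> tokens(stub_token_count(text));
    int n = stub_encode(text, tokens.data(), (int)tokens.size());
    tokens.resize(n);
    return tokens;
}
```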
Generation: The Callback Pattern
Generation uses a callback function for streaming output:
typedef bool (*akunu_token_callback)(uint32_t token_id,
const char *text,
void *user_data);
The callback is invoked for each generated token. Returning false stops generation immediately. The user_data pointer is passed through from the akunu_generate call, allowing the callback to access caller state without globals.
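The user_data mechanism is the standard C idiom for passing closure state. A self-contained sketch (the stub generator here replaces akunu's decode loop) shows both threading state through and stopping early:

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Same callback shape as akunu_token_callback: return false to stop.
typedef bool (*token_callback)(uint32_t token_id, const char *text,
                               void *user_data);

// Stub generator standing in for the real decode loop: emits up to five
// fixed tokens, honoring early stop. Returns the number emitted.
int stub_generate(int n, token_callback cb, void *user_data) {
    static const char *words[] = {"the", "quick", "brown", "fox", "jumps"};
    int emitted = 0;
    for (int i = 0; i < n && i < 5; i++) {
        emitted++;
        if (!cb((uint32_t)i, words[i], user_data))  // callback may stop us
            break;
    }
    return emitted;
}

// Caller state threaded through user_data -- no globals needed.
struct Accumulator {
    std::string text;
    int budget;  // tokens remaining before we stop generation
};

bool on_token(uint32_t, const char *text, void *user_data) {
    Accumulator *acc = static_cast<Accumulator *>(user_data);
    if (!acc->text.empty()) acc->text += ' ';
    acc->text += text;
    return --acc->budget > 0;  // false once the budget is spent
}
```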
AkunuGenerationStats akunu_generate(
akunu_model_t model,
const uint32_t *prompt_tokens, int n_prompt,
int max_tokens,
AkunuSamplingConfig sampling,
akunu_token_callback callback,
void *user_data);
This is the main generation entry point. It:
- Resets the KV cache
- Runs prefill on the prompt tokens
- Enters the decode loop, calling the callback for each token
- Returns statistics (prefill time, decode time, tokens/second)
Sampling Configuration
typedef struct {
float temperature; // 0 = greedy (argmax)
int top_k; // 0 = disabled
float top_p; // 1.0 = disabled
float min_p; // 0.0 = disabled
float repeat_penalty; // 1.0 = disabled
} AkunuSamplingConfig;
Temperature 0 triggers the greedy decode path (argmax on GPU, no CPU sampling). Non-zero temperature runs the sampled decode path with optional top-k, top-p, and min-p filtering.
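To make the filters concrete, here is a CPU sketch of temperature scaling and top-k masking over a logits array – an illustration of the standard algorithms, not akunu's GPU implementation:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <functional>
#include <vector>

// Temperature scaling: divide logits by T. (T=0 is handled separately
// as the greedy/argmax path and never reaches this function.)
void apply_temperature(std::vector<float> &logits, float t) {
    for (float &x : logits) x /= t;
}

// Top-k: keep the k largest logits, push the rest to -inf so softmax
// assigns them zero probability. Ties at the cutoff are kept.
void apply_top_k(std::vector<float> &logits, int k) {
    if (k <= 0 || k >= (int)logits.size()) return;  // 0 = disabled
    std::vector<float> sorted = logits;
    std::nth_element(sorted.begin(), sorted.begin() + (k - 1), sorted.end(),
                     std::greater<float>());
    float cutoff = sorted[k - 1];  // k-th largest value
    for (float &x : logits)
        if (x < cutoff) x = -INFINITY;
}

int argmax(const std::vector<float> &logits) {
    return (int)(std::max_element(logits.begin(), logits.end()) -
                 logits.begin());
}
```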
Generation Statistics
typedef struct {
int prompt_tokens;
int generated_tokens;
float prefill_time_ms;
float decode_time_ms;
float prefill_tokens_per_sec;
float decode_tokens_per_sec;
} AkunuGenerationStats;
This struct is returned by value from akunu_generate. It contains everything you need to report performance.
Continued Generation
For multi-turn chat, you do not want to re-process the entire conversation history each turn. akunu_generate_continue extends the existing KV cache:
AkunuGenerationStats akunu_generate_continue(
akunu_model_t model,
const uint32_t *new_tokens, int n_new,
int max_tokens,
AkunuSamplingConfig sampling,
akunu_token_callback callback,
void *user_data);
This prefills only the new_tokens (the latest user message) and generates from the combined context. The KV cache from previous turns is preserved.
Grammar-Constrained Generation
For structured output (JSON, specific formats), akunu supports grammar-constrained decoding:
akunu_grammar_t akunu_grammar_create(akunu_model_t model, const char *gbnf);
akunu_grammar_t akunu_grammar_create_from_schema(akunu_model_t model,
const char *json_schema);
akunu_grammar_t akunu_grammar_create_json(akunu_model_t model);
void akunu_grammar_free(akunu_grammar_t grammar);
AkunuGenerationStats akunu_generate_grammar(
akunu_model_t model,
const uint32_t *prompt_tokens, int n_prompt,
int max_tokens,
AkunuSamplingConfig sampling,
akunu_grammar_t grammar,
akunu_token_callback callback,
void *user_data);
The grammar handle is opaque, like the model handle. Three factory functions create grammars from GBNF strings, JSON Schema strings, or a generic JSON grammar. The grammar masks invalid tokens at each step, guaranteeing the output conforms to the grammar.3
Low-Level API
For benchmarking and custom decode loops, akunu exposes lower-level functions:
// Run prefill, return first generated token
uint32_t akunu_prefill(akunu_model_t model,
const uint32_t *tokens, int n_tokens);
// Run one decode step, return next token
uint32_t akunu_decode_step(akunu_model_t model,
uint32_t token_id, int position);
// Chain decode: multiple tokens in one GPU submission
int akunu_chain_decode(akunu_model_t model,
uint32_t first_token, int start_position,
int count, uint32_t *out_tokens);
// Get current KV cache position
int akunu_get_position(akunu_model_t model);
// Reset KV cache
void akunu_reset(akunu_model_t model);
The akunu_chain_decode function is the key primitive for fast greedy generation. It encodes the dispatch table N times into a single command buffer, patching position fields for each token. This is how akunu achieves high throughput for greedy (temperature=0) decoding.
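The replay idea can be shown schematically: a pre-built command template is appended N times into one submission, with only the position field rewritten per token. This sketch models commands as strings rather than real GPU dispatches:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Schematic of chain decode: the same command template is replayed
// `count` times into a single "submission", patching only the position
// per token. Commands are modeled as (kernel name, position) pairs.
struct Cmd {
    std::string kernel;
    int position;
};

std::vector<Cmd> encode_chain(const std::vector<std::string> &table,
                              int start_position, int count) {
    std::vector<Cmd> submission;
    submission.reserve(table.size() * count);
    for (int tok = 0; tok < count; tok++)
        for (const std::string &kernel : table)
            // Identical kernels every token; only the position changes.
            submission.push_back({kernel, start_position + tok});
    return submission;
}
```

The payoff is that the per-token CPU cost is a tight copy-and-patch loop rather than a full re-derivation of the command list.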
Speculative Decoding
void akunu_set_speculation(akunu_model_t model, bool enabled);
When enabled, the decode loop uses n-gram prediction to speculatively generate multiple tokens, then verifies them against the model. Correctly predicted tokens skip full forward passes. This only works with greedy mode (temperature=0) because the verification requires deterministic token selection.
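The n-gram speculation idea can be illustrated without a model: draft the next token by finding the most recent earlier occurrence of the current token in the history and proposing what followed it, then verify against the real next-token function. A schematic of the idea (not akunu's implementation), using a deterministic stub model:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Draft via n-gram lookup: find the most recent earlier occurrence of
// the last token and propose the token that followed it.
// Returns -1 if there is no match (no draft possible).
int ngram_draft(const std::vector<uint32_t> &history) {
    if (history.size() < 2) return -1;
    uint32_t last = history.back();
    for (int i = (int)history.size() - 2; i >= 0; i--)
        if (history[i] == last)
            return (int)history[i + 1];
    return -1;
}

// Greedy verification step: `model` stands in for a full forward pass.
// When the draft matches the model's real output, tokens beyond it can
// be accepted without their own individual forward passes.
template <typename Model>
bool speculate_and_verify(std::vector<uint32_t> &history, Model model) {
    int draft = ngram_draft(history);
    uint32_t real = model(history);  // ground-truth next token
    history.push_back(real);
    return draft == (int)real;       // true => speculation accepted
}
```

Verification only works with greedy selection because `real` must be deterministic, which is exactly the restriction the text describes.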
Embeddings
For BERT-style encoder models:
int akunu_embed(akunu_model_t model,
const uint32_t *tokens, int n_tokens,
float *out_embedding, int max_dims);
int akunu_embed_text(akunu_model_t model, const char *text,
float *out_embedding, int max_dims);
int akunu_embedding_dim(akunu_model_t model);
akunu_embed runs a forward pass through the encoder, mean-pools the final hidden layer, and writes the resulting embedding vector to out_embedding. Returns the embedding dimension on success, 0 on failure.
akunu_embed_text is a convenience wrapper that tokenizes the text internally.
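Mean pooling is the only nontrivial step after the forward pass. As a sketch, reducing an [n_tokens x dim] row-major hidden-state matrix to a single dim-length vector looks like:

```cpp
#include <cassert>
#include <vector>

// Mean-pool final hidden states: average over the token axis of an
// [n_tokens x dim] row-major matrix, producing one dim-length embedding.
std::vector<float> mean_pool(const std::vector<float> &hidden,
                             int n_tokens, int dim) {
    std::vector<float> out(dim, 0.0f);
    for (int t = 0; t < n_tokens; t++)
        for (int d = 0; d < dim; d++)
            out[d] += hidden[t * dim + d];
    for (int d = 0; d < dim; d++)
        out[d] /= (float)n_tokens;
    return out;
}
```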
Whisper Transcription
const char *akunu_transcribe(akunu_model_t model,
const char *wav_path,
const char *language,
AkunuTranscribeStats *stats_out);
const char *akunu_transcribe_pcm(akunu_model_t model,
const float *samples, int n_samples,
const char *language,
AkunuTranscribeStats *stats_out);
bool akunu_is_whisper(akunu_model_t model);
void akunu_set_timestamps(akunu_model_t model, bool enabled);
The transcription API supports both file-based and PCM buffer input. The returned string is valid until the next call or model free – it points to an internal buffer.
Streaming callbacks are also available:
typedef bool (*akunu_segment_callback)(int start_ms, int end_ms,
const char *text, void *user_data);
const char *akunu_transcribe_stream(akunu_model_t model,
const char *wav_path,
const char *language,
AkunuTranscribeStats *stats_out,
akunu_segment_callback callback,
void *user_data);
Chat Templates
const char *akunu_format_chat(akunu_model_t model,
const char *system_prompt,
const char *user_message);
const char *akunu_chat_template(akunu_model_t model);
akunu_format_chat applies the model’s native chat template to format a system prompt and user message into the expected input format (e.g., ChatML, Llama 3 format, Gemma format). The returned string is valid until the next call.
akunu_chat_template returns the template name as a string (“chatml”, “llama3”, “gemma”, or “unknown”).
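For reference, the ChatML layout that a formatter produces for ChatML-family models looks roughly like this – a sketch of the well-known template shape, not akunu's exact output:

```cpp
#include <cassert>
#include <string>

// Sketch of the ChatML template: each message is wrapped in
// <|im_start|>role ... <|im_end|> markers, and the prompt ends with an
// open assistant turn for the model to complete.
std::string format_chatml(const std::string &system_prompt,
                          const std::string &user_message) {
    std::string out;
    if (!system_prompt.empty())
        out += "<|im_start|>system\n" + system_prompt + "<|im_end|>\n";
    out += "<|im_start|>user\n" + user_message + "<|im_end|>\n";
    out += "<|im_start|>assistant\n";  // model completes from here
    return out;
}
```

Llama 3 and Gemma templates follow the same structure with different marker tokens, which is why a single template-name switch suffices.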
Profiling
int akunu_profile_decode_step(akunu_model_t model,
uint32_t token_id, int position,
float *timing_out, int max_entries);
const char *akunu_profile_label(akunu_model_t model, int index);
The profiling API runs each operation in its own command buffer to get per-operation GPU timing. timing_out receives an array of float values (milliseconds). akunu_profile_label returns the human-readable label for each entry (e.g., “layer.0.attention”, “layer.0.o_proj”).
GPU Sampling Operations
void akunu_gpu_temperature_scale(akunu_model_t model, float temperature);
void akunu_gpu_repetition_penalty(akunu_model_t model,
const uint32_t *token_ids,
int n_tokens, float penalty);
These functions run sampling operations directly on the GPU, avoiding CPU readback of the logits buffer. Temperature scaling is a simple element-wise multiply; repetition penalty adjusts logits for previously seen tokens.
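The repetition-penalty adjustment is worth spelling out: positive logits are divided by the penalty and negative logits multiplied by it, so previously seen tokens always become less likely. A CPU sketch of this standard formula (akunu runs the equivalent on the GPU):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Standard repetition-penalty formula: for each previously seen token,
// shrink a positive logit (divide) or push a negative logit further
// down (multiply). penalty = 1.0 is a no-op.
void repetition_penalty(std::vector<float> &logits,
                        const std::vector<uint32_t> &seen, float penalty) {
    for (uint32_t id : seen) {
        float &x = logits[id];
        x = (x > 0.0f) ? x / penalty : x * penalty;
    }
}
```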
Model Inspection
int akunu_tensor_count(akunu_model_t model);
const char *akunu_tensor_name(akunu_model_t model, int index);
uint32_t akunu_tensor_dtype(akunu_model_t model, int index);
const char *akunu_tensor_raw_dtype(akunu_model_t model, int index);
These functions allow iterating over all tensors in the model. akunu_inspect uses them to dump the full tensor list. akunu_tensor_raw_dtype returns the original dtype string (e.g., “BF16” for SafeTensors) while akunu_tensor_dtype returns the internal GGUF dtype code.
Thread Safety
The akunu API has the following thread safety guarantees:
- Different model handles are fully independent. You can call functions on model_A from thread 1 and model_B from thread 2 concurrently with no synchronization needed.
- A single model handle is NOT thread-safe. You must serialize all calls to the same model. The akunu_serve server handles this with a per-model mutex.
- akunu_get_error() is thread-safe. It uses thread-local storage.
- Model loading (akunu_load_model) is thread-safe. Each call creates its own device and resources.
Complete Example
Here is a complete program that loads a model, generates text, and reports statistics:
#include "akunu/akunu.h"
#include <stdio.h>
#include <string.h>
static bool on_token(uint32_t token_id, const char *text, void *user_data) {
printf("%s", text);
fflush(stdout);
(void)token_id;
(void)user_data;
return true; // continue generating
}
int main(int argc, char **argv) {
if (argc < 3) {
fprintf(stderr, "Usage: %s <model.gguf> <akunu.metallib>\n", argv[0]);
return 1;
}
// Load model
akunu_model_t model = akunu_load_model(argv[1], argv[2], 4096);
if (!model) {
fprintf(stderr, "Failed to load model: %s\n", akunu_get_error());
return 1;
}
// Print model info
AkunuModelConfig cfg = akunu_get_config(model);
printf("Model: %s, %u layers, %u dim, %.1f MB GPU memory\n",
cfg.architecture, cfg.n_layers, cfg.dim,
akunu_model_memory(model) / 1048576.0);
// Tokenize prompt
const char *prompt = "Explain the roofline model in one paragraph:";
uint32_t tokens[4096];
int n_tokens = akunu_encode(model, prompt, tokens, 4096);
printf("Prompt: %d tokens\n\n", n_tokens);
// Generate
AkunuSamplingConfig sampling = {
.temperature = 0.0f, // greedy
.top_k = 0,
.top_p = 1.0f,
.min_p = 0.0f,
.repeat_penalty = 1.0f
};
AkunuGenerationStats stats = akunu_generate(
model, tokens, n_tokens,
256, // max_tokens
sampling,
on_token,
NULL // user_data
);
// Report
printf("\n\n--- Stats ---\n");
printf("Prefill: %d tokens in %.1f ms (%.0f tok/s)\n",
stats.prompt_tokens, stats.prefill_time_ms,
stats.prefill_tokens_per_sec);
printf("Decode: %d tokens in %.1f ms (%.0f tok/s)\n",
stats.generated_tokens, stats.decode_time_ms,
stats.decode_tokens_per_sec);
akunu_free_model(model);
return 0;
}
Compile and run:
clang -std=c11 -I include example.c -L build -lakunu_engine \
-framework Metal -framework Foundation -framework Accelerate \
-framework IOKit -lstdc++ -o example
./example path/to/model.gguf path/to/akunu.metallib
API Function Reference
| Function | Returns | Description |
|---|---|---|
| akunu_load_model | akunu_model_t | Load model, returns NULL on error |
| akunu_free_model | void | Free all model resources |
| akunu_get_config | AkunuModelConfig | Get model architecture metadata |
| akunu_model_memory | size_t | Total GPU memory in bytes |
| akunu_get_error | const char* | Last error message (thread-local) |
| akunu_encode | int | Tokenize text to token IDs |
| akunu_decode_token | const char* | Token ID to text |
| akunu_token_count | int | Count tokens in text |
| akunu_generate | AkunuGenerationStats | Full generation pipeline |
| akunu_generate_continue | AkunuGenerationStats | Continue from existing KV cache |
| akunu_generate_grammar | AkunuGenerationStats | Grammar-constrained generation |
| akunu_generate_grammar_continue | AkunuGenerationStats | Continue with grammar |
| akunu_grammar_create | akunu_grammar_t | Create grammar from GBNF |
| akunu_grammar_create_from_schema | akunu_grammar_t | Create grammar from JSON Schema |
| akunu_grammar_create_json | akunu_grammar_t | Create generic JSON grammar |
| akunu_grammar_free | void | Free grammar |
| akunu_prefill | uint32_t | Run prefill, return first token |
| akunu_decode_step | uint32_t | Run one decode step |
| akunu_chain_decode | int | Chain decode multiple tokens |
| akunu_get_position | int | Current KV cache position |
| akunu_set_speculation | void | Enable/disable speculative decode |
| akunu_reset | void | Reset KV cache |
| akunu_embed | int | Compute embeddings from tokens |
| akunu_embed_text | int | Compute embeddings from text |
| akunu_embedding_dim | int | Get embedding dimension |
| akunu_format_chat | const char* | Format chat message |
| akunu_chat_template | const char* | Get template name |
| akunu_transcribe | const char* | Transcribe WAV file |
| akunu_transcribe_pcm | const char* | Transcribe PCM buffer |
| akunu_transcribe_stream | const char* | Transcribe with segment callback |
| akunu_transcribe_pcm_stream | const char* | Transcribe PCM with callback |
| akunu_set_timestamps | void | Enable/disable Whisper timestamps |
| akunu_is_whisper | bool | Check if model is Whisper |
| akunu_profile_decode_step | int | Per-operation GPU timing |
| akunu_profile_label | const char* | Label for profiled operation |
| akunu_gpu_temperature_scale | void | GPU-side temperature scaling |
| akunu_gpu_repetition_penalty | void | GPU-side repetition penalty |
| akunu_tensor_count | int | Number of model tensors |
| akunu_tensor_name | const char* | Tensor name by index |
| akunu_tensor_dtype | uint32_t | Tensor GGUF dtype code |
| akunu_tensor_raw_dtype | const char* | Tensor original dtype string |
Summary
The C API is akunu’s external interface. It uses the opaque handle pattern, POD structs, and C linkage to provide maximum compatibility across languages and compilers. The callback-based generation pattern supports streaming output without allocating result buffers. Thread safety is per-model-handle, requiring callers to serialize access to a single model.
1. The C FFI is effectively the lingua franca of systems programming: every major language runtime supports calling C functions with zero or minimal overhead. For a survey of language support, see "Foreign function interface" on Wikipedia: https://en.wikipedia.org/wiki/Foreign_function_interface. ↩
2. This "handle + function" pattern is sometimes called the "C object pattern" or "ADT (Abstract Data Type) in C." It provides encapsulation without language-level support for classes. The Linux kernel uses this pattern extensively for device drivers. ↩
3. Grammar-constrained decoding uses the XGrammar library (v0.1.33) internally. XGrammar compiles the grammar into a token mask that can be applied at each decoding step. See the XGrammar project: https://github.com/mlc-ai/xgrammar. ↩
The Device Abstraction Layer
If you look at akunu’s core code – the table builder, the prefill engine, the decode loops – you will notice something conspicuously absent: there are no Metal API calls. No id<MTLBuffer>, no [encoder setComputePipelineState:], no MTLSizeMake. The core is pure C++17, completely agnostic to the GPU backend. All hardware interaction flows through a single abstract class: Device.
This chapter examines the Device abstraction, its concrete MetalDevice implementation, and the design decisions behind it.
The Problem
An inference engine needs to:
- Allocate GPU memory
- Load compiled shader libraries
- Create compute pipelines from shader functions
- Encode sequences of GPU commands (set buffers, set parameters, dispatch)
- Submit and synchronize
These operations look completely different across GPU APIs. Metal uses Objective-C message passing ([encoder setBuffer:...]). CUDA uses C function calls (cuLaunchKernel()). Vulkan uses verbose descriptor sets and command buffers. If the core engine hardcodes any of these, it cannot be ported.
The solution is a pure virtual interface that captures the common operations and lets each backend implement them in its native API.
The Device Interface
The Device class in src/core/device.h defines the contract. Let’s walk through it section by section.
Handles: Buffer, Pipeline, Dim3
Before the class itself, three simple structs define the currency of GPU programming:
struct Buffer {
void *handle; // Backend-specific (MTLBuffer*, CUdeviceptr, etc.)
size_t size; // Size in bytes
void *contents; // CPU-accessible pointer (UMA) or nullptr (discrete)
};
struct Pipeline {
void *handle; // Backend-specific (MTLComputePipelineState*, CUfunction, etc.)
};
struct Dim3 {
uint32_t x, y, z;
Dim3(uint32_t x = 1, uint32_t y = 1, uint32_t z = 1) : x(x), y(y), z(z) {}
};
These are POD types. No virtual methods, no reference counting, no destructors. Buffer is 24 bytes. Pipeline is 8 bytes. Dim3 is 12 bytes. They are cheap to copy, pass by value, and store in arrays – which matters because the dispatch table stores hundreds of them.
The void* handle pattern is C-style type erasure: the handle means nothing to the core and everything to the backend. On Metal, Buffer::handle is a CFBridgingRetain’d id<MTLBuffer>. On CUDA, it would be a CUdeviceptr. The core code never dereferences these pointers; it just passes them back to the device.
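The POD claims are directly checkable. This sketch mirrors the structs from the text (sizes assume a typical 64-bit ABI):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <type_traits>

// Mirrors of the handle structs from the text, reproduced only to
// verify the size and copyability claims on a 64-bit ABI.
struct Buffer {
    void *handle;
    size_t size;
    void *contents;
};
struct Pipeline {
    void *handle;
};
struct Dim3 {
    uint32_t x, y, z;
    Dim3(uint32_t x = 1, uint32_t y = 1, uint32_t z = 1) : x(x), y(y), z(z) {}
};

static_assert(sizeof(Buffer) == 24, "three 8-byte fields");
static_assert(sizeof(Pipeline) == 8, "one pointer");
static_assert(sizeof(Dim3) == 12, "three 32-bit fields");
// The user-provided constructor does not cost Dim3 its trivial copies.
static_assert(std::is_trivially_copyable<Buffer>::value, "memcpy-safe");
static_assert(std::is_trivially_copyable<Dim3>::value, "memcpy-safe");
```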
Device Info
virtual const char *name() const = 0;
virtual const char *backend_name() const = 0; // "metal", "cuda", "vulkan"
virtual int gpu_core_count() const = 0;
virtual int gpu_family() const { return 0; } // Apple GPU family (7=M1, 8=M2/M3, 9=M4)
virtual size_t total_memory() const = 0;
These methods expose hardware capabilities that affect algorithmic decisions. gpu_core_count() is used by ChipConfig::from_gpu() to determine SLC size estimates, GEMV variant selection, and chain decode chunk size. gpu_family() distinguishes M1/M2/M3/M4 for generation-specific tuning.
Note that gpu_family() has a default implementation returning 0. This is because non-Apple backends do not have “GPU families” – the method is Apple-specific but exposed at the abstract level because ChipConfig needs it.
Library and Pipeline Management
virtual bool load_library(const std::string& path) = 0;
virtual Pipeline get_pipeline(const std::string& name) = 0;
virtual Pipeline get_pipeline(const std::string& name,
const std::string& cache_key,
const uint32_t *constant_indices,
const uint32_t *constant_values,
int n_constants,
const uint32_t *constant_types = nullptr) = 0;
load_library loads a compiled shader library (metallib on Metal, PTX/CUBIN on CUDA). get_pipeline retrieves a named compute function, optionally specialized with function constants.
The two-overload design is important. The simple overload (get_pipeline(name)) handles the common case of a kernel with no specialization. The extended overload with constant_indices/constant_values handles Metal function constants and could map to CUDA template instantiation or Vulkan specialization constants.
The cache_key parameter in the extended overload is separate from name because the same kernel function can produce multiple specialized pipelines. For example, gemv_mlx_q4 with group_size=64, K=4096 gets a cache key of "gemv_mlx_q4_gs64_k4096", while the same kernel with K=2048 gets "gemv_mlx_q4_gs64_k2048". Both pipelines come from the same source function but are compiled with different constants.
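The cache-key construction can be sketched as a simple string encoding of the specialization parameters (the exact format here is illustrative, chosen to match the examples above):

```cpp
#include <cassert>
#include <string>

// Illustrative cache-key scheme: kernel name plus each specialization
// constant, so distinct (group_size, K) pairs map to distinct pipelines
// even though they come from the same source function.
std::string pipeline_cache_key(const std::string &kernel,
                               int group_size, int k_dim) {
    return kernel + "_gs" + std::to_string(group_size) +
           "_k" + std::to_string(k_dim);
}
```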
Buffer Management
virtual Buffer allocate(size_t bytes) = 0;
virtual Buffer allocate(const void *data, size_t bytes) = 0;
virtual void free_buffer(Buffer buf) = 0;
virtual void write_buffer(Buffer dst, const void *src, size_t bytes, size_t offset = 0) {
if (dst.contents) memcpy((char *)dst.contents + offset, src, bytes);
}
virtual void read_buffer(void *dst, Buffer src, size_t bytes, size_t offset = 0) {
if (src.contents) memcpy(dst, (const char *)src.contents + offset, bytes);
}
virtual void *buffer_contents(Buffer buf) { return buf.contents; }
The two allocate overloads handle empty allocation (for scratch buffers) and initialized allocation (for uploading weight data). free_buffer releases the GPU memory.
write_buffer and read_buffer have default implementations that use memcpy – correct for UMA where buf.contents is a CPU-accessible pointer. A CUDA backend would override these to use cuMemcpyHtoD and cuMemcpyDtoH.
Command Encoding
This is the heart of the interface – the methods that encode GPU commands:
virtual void begin_encoding() = 0;
virtual void set_pipeline(Pipeline pso) = 0;
virtual void set_buffer(Buffer buf, int offset, int index) = 0;
virtual void set_bytes(const void *data, int size, int index) = 0;
virtual void set_threadgroup_memory(int bytes, int index) = 0;
virtual void dispatch(Dim3 grid, Dim3 threadgroup) = 0;
virtual void dispatch_threads(Dim3 total, Dim3 threadgroup) = 0;
virtual double end_encoding_sync() = 0;
virtual void end_encoding_async() = 0;
virtual void wait() = 0;
The encoding sequence mirrors Metal’s command encoder pattern:
begin_encoding()
set_pipeline(pso)
set_buffer(buf, offset, index)
set_bytes(params, size, index)
dispatch(grid, threadgroup)
// ... more dispatches ...
end_encoding_sync() // or end_encoding_async() + wait()
On Metal, these map directly:
| Device Method | Metal Equivalent |
|---|---|
| begin_encoding() | [queue commandBuffer] + [cmdBuffer computeCommandEncoder] |
| set_pipeline(pso) | [encoder setComputePipelineState:pso] |
| set_buffer(buf, off, idx) | [encoder setBuffer:buf offset:off atIndex:idx] |
| set_bytes(data, size, idx) | [encoder setBytes:data length:size atIndex:idx] |
| set_threadgroup_memory(bytes, idx) | [encoder setThreadgroupMemoryLength:bytes atIndex:idx] |
| dispatch(grid, tg) | [encoder dispatchThreadgroups:grid threadsPerThreadgroup:tg] |
| dispatch_threads(total, tg) | [encoder dispatchThreads:total threadsPerThreadgroup:tg] |
end_encoding_sync() | [encoder endEncoding] + [cmdBuffer commit] + [cmdBuffer waitUntilCompleted] |
On CUDA, the mapping would be different but conceptually similar: begin_encoding would create a CUDA stream, set_pipeline would set the current function, and dispatch would call cuLaunchKernel.
Advanced Synchronization
The interface provides several synchronization patterns beyond simple sync/async:
virtual void end_encoding_handler() { end_encoding_sync(); }
virtual void wait_handler() { wait(); }
virtual void end_encoding_event() { end_encoding_async(); }
virtual void begin_encoding_after_event() { begin_encoding(); }
virtual void wait_event() { wait(); }
virtual void wait_async() { wait(); }
These have default implementations that fall back to simple sync/async, but MetalDevice overrides them with more sophisticated patterns:
| Pattern | Purpose | Metal Implementation |
|---|---|---|
| Handler | CPU+GPU overlap via completion callback | addCompletedHandler: + semaphore |
| Event | GPU-GPU pipeline via shared event | encodeSignalEvent: + encodeWaitForEvent: |
| Wait Async | Wait for previous buffer only | Check prevCmdBuffer only |
The event-based pattern is used for chain decode pipelining: while the GPU executes command buffer N, the CPU encodes command buffer N+1. The shared event ensures N+1 does not start executing until N completes, without requiring a CPU round-trip.
Dispatch Table Fast Path
virtual void encode_dispatch_table(const void *table_ptr,
int start_position, int count) {
// Default: use generic encode_chain (virtual calls per command)
}
virtual void encode_command_range(const void *table_ptr,
int position,
int cmd_start, int cmd_count) {}
These methods allow the backend to bypass the virtual call overhead of the generic encode_chain function. MetalDevice overrides encode_dispatch_table with a tight loop that calls Metal API functions directly on the ObjC encoder object, avoiding the per-command virtual dispatch through set_pipeline, set_buffer, etc.
The void* parameter (pointing to a DispatchTable) is a deliberate abstraction leak – the backend needs to know the table structure to iterate it efficiently. The alternative (encoding through the virtual interface) works but is slower.
Hardware Capabilities
virtual ChipConfig chip_config() const;
virtual const DTypeDescriptor& dtype_lookup(uint32_t dtype) const;
const char *embedding_kernel_for(uint32_t dtype) const;
const char *gemm_kernel_for(uint32_t dtype, int M) const;
chip_config() returns hardware tuning parameters derived from gpu_core_count() and gpu_family(). The default implementation (in device_defaults.cpp) calls ChipConfig::from_gpu():
ChipConfig Device::chip_config() const {
return ChipConfig::from_gpu(gpu_core_count(), gpu_family());
}
dtype_lookup returns the kernel name and dispatch geometry for a given quantization format. Both have default implementations using the global tables in dtype_descriptor.h and chip_config.h.
Factory
static std::unique_ptr<Device> Device::create_default();
This static factory method creates the default device for the current platform. On macOS, it creates a MetalDevice. On a hypothetical CUDA platform, it would create a CudaDevice.
The MetalDevice Implementation
MetalDevice in backend/metal/metal_device.h implements all pure virtual methods. It is an Objective-C++ class (compiled from .mm files) that wraps Metal API objects.
Internal State
class MetalDevice : public Device {
private:
void *device_; // id<MTLDevice> (via CFBridgingRetain)
void *queue_; // id<MTLCommandQueue>
void *library_; // id<MTLLibrary>
void *cmd_buffer_; // id<MTLCommandBuffer> (current)
void *encoder_; // id<MTLComputeCommandEncoder> (current)
size_t allocated_bytes_ = 0;
std::unordered_map<std::string, void *> pso_cache_;
std::string device_name_;
};
Notice the void* pointers – even within the MetalDevice, the ObjC objects are stored as raw pointers. This is because the header file (metal_device.h) is included by C++ translation units that cannot parse Objective-C types. The actual ObjC types are accessed through the AkunuMetalState wrapper in metal_device_impl.h:
@interface AkunuMetalState : NSObject
@property(nonatomic, strong) id<MTLDevice> device;
@property(nonatomic, strong) id<MTLCommandQueue> queue;
@property(nonatomic, strong) id<MTLLibrary> library;
@property(nonatomic, strong) id<MTLCommandBuffer> cmdBuffer;
@property(nonatomic, strong) id<MTLComputeCommandEncoder> encoder;
@property(nonatomic, strong) id<MTLCommandBuffer> prevCmdBuffer;
@property(nonatomic, strong) id<MTLSharedEvent> pipelineEvent;
@property(nonatomic, assign) uint64_t eventValue;
@end
The device_ pointer is actually a CFBridgingRetain’d reference to an AkunuMetalState instance. In the .mm implementation, a macro bridges back to the typed object:
#define STATE ((__bridge AkunuMetalState *)device_)
This pattern – ObjC state wrapped in a C++ class with void* storage – is the standard way to mix ObjC and C++ in headers that must be parseable by both compilers.1
Pipeline Caching
std::unordered_map<std::string, void *> pso_cache_;
Every get_pipeline call first checks the cache. Pipeline creation (compiling a Metal function into a compute pipeline state) is expensive – it involves shader compilation, register allocation, and GPU resource validation. Caching ensures this happens once per kernel variant.
For function-constant-specialized pipelines, the cache key includes the specialization parameters:
Pipeline MetalDevice::get_pipeline(const std::string &name,
                                   const std::string &cache_key, ...) {
    auto it = pso_cache_.find(cache_key);
    if (it != pso_cache_.end())
        return {it->second};
    // ... create the specialized pipeline, storing its handle in ptr ...
    pso_cache_[cache_key] = ptr;
    return {ptr};
}
A model with MLX Q4 weights might create 20+ specialized pipelines (different K dimensions for each layer’s GEMV). These are all cached after the first build_dispatch_table() call.
GPU Core Count Detection
MetalDevice queries the actual GPU core count from IOKit, not from Metal:
static int iokit_gpu_core_count() {
    int cores = 0;
    io_iterator_t iter = 0;
    IOServiceGetMatchingServices(kIOMainPortDefault,
                                 IOServiceMatching("AGXAccelerator"), &iter);
    io_service_t service = IOIteratorNext(iter);
    CFNumberRef ref = IORegistryEntryCreateCFProperty(
        service, CFSTR("gpu-core-count"), kCFAllocatorDefault, 0);
    CFNumberGetValue(ref, kCFNumberIntType, &cores);
    // ... release ref, service, and iter, then return cores ...
}
Metal’s API does not expose the GPU core count directly. The AGXAccelerator IOKit service (Apple’s GPU driver) stores it as a registry property. This is more reliable than string-parsing the device name, though a fallback heuristic exists:
if ([name containsString:@"Ultra"]) cached = 60;
else if ([name containsString:@"Max"]) cached = 30;
else if ([name containsString:@"Pro"]) cached = 16;
else cached = 10;
GPU Family Detection
int MetalDevice::gpu_family() const {
id<MTLDevice> dev = STATE.device;
if ([dev supportsFamily:MTLGPUFamilyApple9]) return 9; // M4
if ([dev supportsFamily:MTLGPUFamilyApple8]) return 8; // M2/M3
if ([dev supportsFamily:MTLGPUFamilyApple7]) return 7; // M1
return 6;
}
This is used by ChipConfig::from_gpu() to apply generation-specific tuning. M4 (family 9) has improved cache hierarchy and native BF16 support; M1 (family 7) has smaller SLC and no BF16.
Buffer Allocation
Buffer MetalDevice::allocate(size_t bytes) {
id<MTLBuffer> buf = [STATE.device newBufferWithLength:MAX(bytes, 16)
options:MTLResourceStorageModeShared];
void *h = (void *)CFBridgingRetain(buf);
allocated_bytes_ += MAX(bytes, 16);
return {h, bytes, [buf contents]};
}
Key details:
- Minimum allocation of 16 bytes (alignment requirement)
- MTLResourceStorageModeShared (UMA zero-copy)
- CFBridgingRetain prevents ARC from releasing the buffer
- [buf contents] provides the CPU-accessible pointer
- allocated_bytes_ tracks total GPU memory usage
The Encode Dispatch Table Fast Path
The most performance-critical method is encode_dispatch_table, which replays the dispatch table for chain decode. Rather than calling the virtual set_pipeline/set_buffer/dispatch methods (which go through vtable dispatch), it operates directly on the Metal encoder:
void MetalDevice::encode_dispatch_table(const void *table_ptr,
int start_position, int count) {
const DispatchCmd *cmds = table.commands.data();
id<MTLComputeCommandEncoder> enc = STATE.encoder;
for (int tok = 0; tok < count; tok++) {
for (int c = 0; c < n_cmds; c++) {
const DispatchCmd &cmd = cmds[c];
[enc setComputePipelineState:cmd.pso.handle];
// ... buffer binding, param patching, dispatch ...
}
}
}
This avoids ~6 virtual calls per command (set_pipeline, set_buffer x N, set_bytes, dispatch), which at 200 commands per token and 128 tokens per batch would be ~150,000 virtual calls per submission. The direct Metal calls are significantly faster.2
Thread Safety Note
From the header:
/// THREAD SAFETY: This class is NOT thread-safe. All encoding operations
/// must be called from a single thread. The server serializes access
/// via the per-model mutex in ModelEntry.
Metal command encoders are inherently single-threaded. The MetalDevice does not add locking because the performance cost would be unacceptable in the hot path. Callers must ensure serialized access.
Why This Abstraction Exists
Three reasons:
1. Portability. While akunu currently only has a Metal backend, the abstraction makes it possible to add CUDA, Vulkan, or WebGPU backends without touching the core engine. The dispatch table, table builder, and decode loops work identically regardless of backend.
2. Testability. A mock Device implementation can be used for unit testing without a real GPU. The table builder can be tested by checking the commands it generates, without actually dispatching them.
3. Code organization. The separation forces a clean boundary between “what to compute” (core) and “how to compute it” (backend). Objective-C++ is confined to a single .mm file. The rest of the project compiles as standard C++17.
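The testability point is easy to demonstrate. A recording mock can implement a reduced slice of the interface and let a unit test assert on the generated command stream without a GPU (this is a hypothetical miniature interface, not akunu's full Device class):

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical reduced slice of the Device interface -- just enough to
// show how a recording mock makes the encode path unit-testable.
struct MiniDevice {
    virtual ~MiniDevice() = default;
    virtual void set_pipeline(const std::string &name) = 0;
    virtual void dispatch(int grid_x) = 0;
};

struct RecordingDevice : MiniDevice {
    std::vector<std::string> log;  // every call appended as a string
    void set_pipeline(const std::string &name) override {
        log.push_back("pso:" + name);
    }
    void dispatch(int grid_x) override {
        log.push_back("dispatch:" + std::to_string(grid_x));
    }
};

// "Core engine" code under test: knows nothing about the backend.
void encode_rmsnorm_then_gemv(MiniDevice &dev, int rows) {
    dev.set_pipeline("rmsnorm");
    dev.dispatch(1);
    dev.set_pipeline("gemv");
    dev.dispatch(rows);
}
```

The test asserts on the recorded command order and arguments, exercising exactly the "what to compute" side of the boundary.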
The abstraction is intentionally thin. It does not try to abstract away GPU programming – it just abstracts away the specific API calls. You still think in terms of buffers, pipelines, threadgroups, and dispatches. This is a pragmatic choice: the alternative (a high-level “tensor operation” abstraction) would hide the performance-critical details that make the difference between 30 tok/s and 70 tok/s.
Summary
The Device abstraction layer is a pure virtual C++ class with ~20 virtual methods covering device info, buffer management, command encoding, and synchronization. MetalDevice implements it using Objective-C++ with direct Metal API calls, an ObjC state wrapper for ARC compatibility, and a pipeline cache for kernel specialization. The fast path (encode_dispatch_table) bypasses the virtual interface entirely for maximum chain decode throughput.
-

1. This pattern is described in Apple’s “Mixing Objective-C and C++” technical note. The key challenge is that Objective-C types (id<MTLDevice>, etc.) cannot appear in C++ headers that might be included by pure C++ translation units. The void* + bridge-cast pattern is the standard workaround. See https://developer.apple.com/library/archive/documentation/Cocoa/Conceptual/ProgrammingWithObjectiveC/Introduction/Introduction.html. ↩
2. Virtual function call overhead on modern CPUs is typically 2-5 ns due to the indirect branch prediction miss on the first call and a potential instruction cache miss for the vtable. At 150,000 calls per submission, this adds up to 0.3-0.75 ms – significant compared to the ~14 ms total decode time for a 7B model. ↩
Architecture Descriptors: Data-Driven Design
Here is a problem that every multi-architecture inference engine faces: Llama uses SiLU activation and standard RoPE. Qwen3 uses SiLU but NeoX-style RoPE and per-head QK norms. Gemma uses GELU activation, NeoX RoPE, QK norms, post-attention norms, post-FFN norms, and scales the embedding output by sqrt(dim). Whisper uses LayerNorm instead of RMSNorm, has bias terms on every linear layer, uses sinusoidal positional encoding instead of RoPE, and has an encoder-decoder architecture.
How do you handle all of this without drowning in if (arch == "llama") ... else if (arch == "qwen3") ... else if (arch == "gemma") ... branches scattered across your codebase?
Akunu’s answer is the ArchDescriptor: a plain data struct that captures all architecture-specific behavior as fields. The table builder and prefill engine read from this struct – they never branch on architecture name. Adding a new architecture means writing one factory function and one case in the lookup function. Zero changes to the core engine.
The Problem in Detail
Consider the table builder. It needs to emit dispatch commands for one transformer layer. Here is a partial list of things that vary by architecture:
| Aspect | Llama | Qwen3 | Gemma | Whisper |
|---|---|---|---|---|
| Activation | SiLU gate | SiLU gate | GELU gate | GELU (no gate) |
| RoPE style | Standard (interleaved) | NeoX (split-half) | NeoX (split-half) | None (sinusoidal PE) |
| QK head norms | No | Yes | Yes | No |
| Post-attention norm | No | No | Yes | No |
| Post-FFN norm | No | No | Yes | No |
| Embedding scaling | No | No | Yes (sqrt(dim)) | No |
| Tied embeddings | No | Yes | Yes | Yes |
| Encoder-decoder | No | No | No | Yes |
| Cross attention | No | No | No | Yes |
| Linear bias | No | No | No | Yes |
| Norm type | RMSNorm | RMSNorm | RMSNorm | LayerNorm |
Without a descriptor, you would need if-else branches for each of these in every place the table builder makes an architecture-dependent decision. That is 11+ branch points per layer, across hundreds of lines of code. It is unreadable, error-prone, and a maintenance nightmare.
The ArchDescriptor Struct
The descriptor is defined in src/core/arch_descriptor.h:
struct ArchDescriptor {
// Activation
const char *activation_kernel; // "silu_gate_f16", "gelu_gate_f16", or "gelu_f16"
// Embedding
float embedding_scale; // 0 = no scaling, >0 = scale by this value
// Per-head norms
bool has_qk_norm;
// Post-norms (Gemma-style)
bool has_post_attn_norm;
bool has_post_ffn_norm;
const char *post_attn_norm_key; // weight suffix: "post_attention_norm"
const char *post_ffn_norm_key; // "post_ffw_norm"
// MLX quantization
int quant_bits; // 0 = not MLX, 3/4/6/8 = MLX quant bits
int quant_group_size; // typically 64
// RoPE
const char *rope_kernel; // fused RoPE+KV write kernel name
const char *rope_standalone; // standalone RoPE kernel name
Buffer rope_freqs; // precomputed frequency divisors
// Output
bool tie_embeddings; // use token_embedding.weight for logit projection
// Encoder-decoder
bool is_encoder_decoder;
bool has_cross_attention;
bool has_conv_frontend;
bool has_bias;
const char *norm_type; // "rmsnorm" or "layernorm"
const char *encoder_activation;
bool is_embedding_model; // BERT-style encoder-only
};
Every field is a simple type: bool, int, float, const char*, or Buffer (a POD struct). No virtual methods. No inheritance. No dynamism. The descriptor is created once and read many times.
Field Categories
The fields fall into several categories:
Kernel selection fields (activation_kernel, rope_kernel, rope_standalone, encoder_activation): These are kernel function names that get passed directly to device.get_pipeline(). The table builder does not care what activation function is used – it just dispatches whatever kernel the descriptor says.
Boolean feature flags (has_qk_norm, has_post_attn_norm, has_post_ffn_norm, tie_embeddings, is_encoder_decoder, has_cross_attention, has_conv_frontend, has_bias, is_embedding_model): These control whether certain dispatch commands are emitted. A false flag means the corresponding block is simply skipped.
Numeric parameters (embedding_scale, quant_bits, quant_group_size): These are values that get baked into kernel parameters.
Weight key strings (post_attn_norm_key, post_ffn_norm_key): These are the suffixes used to look up weight tensors. Gemma calls its post-attention norm weights layers.N.post_attention_norm.weight, while other architectures do not have them at all.
Factory Functions
Each supported architecture has a factory function that returns a fully initialized descriptor:
arch_llama() – The Default
inline ArchDescriptor arch_llama() {
ArchDescriptor d = {};
d.activation_kernel = "silu_gate_f16";
d.embedding_scale = 0.0f;
d.has_qk_norm = false;
d.has_post_attn_norm = false;
d.has_post_ffn_norm = false;
d.rope_kernel = "rope_qkv_write_f16"; // standard (interleaved)
d.rope_standalone = "rope_f16";
d.tie_embeddings = false;
d.is_encoder_decoder = false;
d.has_bias = false;
d.norm_type = "rmsnorm";
return d;
}
Llama is the simplest architecture and serves as the default. SiLU-gated activation, standard RoPE, no special norms, no encoder. Most LLaMA-family models (Llama 2, Llama 3, Mistral, etc.) use this descriptor.1
arch_qwen3() – Incremental Derivation
inline ArchDescriptor arch_qwen3() {
ArchDescriptor d = arch_llama(); // start from llama defaults
d.has_qk_norm = true;
d.rope_kernel = "rope_neox_qkv_write_f16"; // NeoX (split-half)
d.rope_standalone = "rope_neox_f16";
d.tie_embeddings = true;
return d;
}
Qwen3 starts from Llama and overrides three things: adds QK head norms, switches to NeoX-style RoPE, and ties the embedding weights. This derivation pattern makes the differences explicit and minimizes redundancy.
arch_gemma() – The Complex One
inline ArchDescriptor arch_gemma(int dim) {
ArchDescriptor d = arch_llama(); // start from llama defaults
d.activation_kernel = "gelu_gate_f16";
d.embedding_scale = sqrtf((float)dim);
d.has_qk_norm = true;
d.has_post_attn_norm = true;
d.has_post_ffn_norm = true;
d.post_attn_norm_key = "post_attention_norm";
d.post_ffn_norm_key = "post_ffw_norm";
d.rope_kernel = "rope_neox_qkv_write_f16";
d.rope_standalone = "rope_neox_f16";
d.tie_embeddings = true;
return d;
}
Gemma is the most feature-rich decoder-only architecture. Notice that embedding_scale is computed from the model dimension – this is why the factory takes dim as a parameter. The value sqrt(dim) is specific to Gemma’s architectural design.2
arch_whisper() – Encoder-Decoder
inline ArchDescriptor arch_whisper() {
ArchDescriptor d = {};
d.activation_kernel = "gelu_f16"; // plain GELU (no gate)
d.rope_kernel = nullptr; // no RoPE
d.rope_standalone = nullptr;
d.tie_embeddings = true;
d.is_encoder_decoder = true;
d.has_cross_attention = true;
d.has_conv_frontend = true;
d.has_bias = true;
d.norm_type = "layernorm";
d.encoder_activation = "gelu_f16";
return d;
}
Whisper does NOT derive from Llama – it starts from a zeroed struct because almost everything is different. No RoPE, no gated activation, LayerNorm instead of RMSNorm, bias on every linear layer, and an encoder-decoder architecture with cross-attention and a convolutional audio frontend.3
arch_bert() – Embedding Model
inline ArchDescriptor arch_bert() {
ArchDescriptor d = {};
d.activation_kernel = "silu_gate_f16"; // SwiGLU (nomic-bert, modernBERT)
d.rope_kernel = "rope_neox_qkv_write_f16";
d.rope_standalone = "rope_neox_f16";
d.norm_type = "rmsnorm";
d.is_embedding_model = true;
return d;
}
BERT-style models (specifically modern variants like nomic-bert) use a LLaMA-like architecture with bidirectional attention. The is_embedding_model flag tells the inference engine to skip autoregressive decoding and instead return the hidden states for mean-pooling.
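For concreteness, here is what mean-pooling means: average the final per-token hidden states into a single embedding vector. This is a hypothetical CPU sketch (akunu does this on the GPU, and the helper name is illustrative):

```cpp
#include <vector>

// Mean-pool the final hidden states of an embedding model:
// hidden is row-major [n_tokens x dim]; the result is one dim-length vector.
std::vector<float> mean_pool(const std::vector<float>& hidden,
                             int n_tokens, int dim) {
    std::vector<float> out(dim, 0.0f);
    for (int t = 0; t < n_tokens; t++)
        for (int d = 0; d < dim; d++)
            out[d] += hidden[t * dim + d];
    for (int d = 0; d < dim; d++)
        out[d] /= (float)n_tokens;
    return out;
}
```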
The Lookup Function
inline ArchDescriptor arch_from_config(const char *arch_name, int dim) {
if (strstr(arch_name, "whisper")) return arch_whisper();
if (strstr(arch_name, "bert")) return arch_bert();
if (strstr(arch_name, "qwen")) return arch_qwen3();
if (strstr(arch_name, "gemma")) return arch_gemma(dim);
return arch_llama(); // Default: LLaMA-like
}
This function reads the architecture field from the GGUF metadata (a string like “llama”, “qwen3”, “gemma2”, “whisper”) and returns the appropriate descriptor. The strstr matching is intentionally loose – “gemma2” and “gemma3” both match “gemma”, “qwen2” and “qwen3” both match “qwen”. This works because the architectural differences between Gemma 2 and Gemma 3, or between Qwen 2 and Qwen 3, are captured in the model’s weight metadata (different head counts, etc.), not in the descriptor.
The default fallback is arch_llama(), which handles the vast majority of GGUF models in the wild (Llama, Mistral, CodeLlama, etc.).
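The loose matching behavior is easy to exercise. The function below is a simplified mirror of arch_from_config() that returns a family name instead of a descriptor (for illustration only; the dispatch order matches the code above):

```cpp
#include <cstring>

// Simplified mirror of arch_from_config()'s matching order, returning
// the descriptor family name rather than the struct.
const char* pick_family(const char* arch_name) {
    if (strstr(arch_name, "whisper")) return "whisper";
    if (strstr(arch_name, "bert"))    return "bert";
    if (strstr(arch_name, "qwen"))    return "qwen3";
    if (strstr(arch_name, "gemma"))   return "gemma";
    return "llama";  // default fallback
}
```

So “gemma2” routes to the Gemma descriptor, “qwen2” to the Qwen3 descriptor, and an unknown name like “mistral” safely falls through to the Llama default.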
How the Table Builder Uses Descriptors
Let’s trace through the table builder to see how the descriptor eliminates branching.
Activation Kernel
// In the SiLU/GELU dispatch:
Pipeline act_pso = device.get_pipeline(arch.activation_kernel);
The table builder does not know or care whether the activation is SiLU or GELU. It dispatches whatever arch.activation_kernel says. This single line replaces:
// Without descriptors (do NOT do this):
Pipeline act_pso;
if (cfg.architecture == "llama" || cfg.architecture == "qwen3")
act_pso = device.get_pipeline("silu_gate_f16");
else if (cfg.architecture == "gemma")
act_pso = device.get_pipeline("gelu_gate_f16");
else if (cfg.architecture == "whisper")
act_pso = device.get_pipeline("gelu_f16");
RoPE Selection
// In the RoPE dispatch:
if (arch.rope_kernel) {
Pipeline rope_pso = device.get_pipeline(arch.rope_kernel);
// ... emit RoPE command ...
}
// If rope_kernel is nullptr (Whisper), the block is skipped entirely
Post-Norms
if (arch.has_post_attn_norm) {
snprintf(name, sizeof(name), "layers.%d.%s.weight",
layer, arch.post_attn_norm_key);
Buffer post_norm_w = weights.get_tensor(name);
emit_standalone_norm(table, device, chip, scratch.residual,
post_norm_w, scratch.post_norm, dim, cfg.norm_eps);
}
The boolean flag controls whether the block exists; the key string controls which weight tensor is loaded. For Llama and Qwen3, has_post_attn_norm is false and this code is never reached.
Embedding Scaling
if (arch.embedding_scale > 0.0f) {
float scale = arch.embedding_scale;
CmdBuilder(table, device.get_pipeline("temperature_scale_f16"), ...)
.buf(scratch.h0, 0)
.params(scale, 1)
.emit();
}
Only Gemma has embedding_scale > 0, so only Gemma gets this dispatch command.
Tied Embeddings
const char *logit_name = arch.tie_embeddings
? "token_embedding.weight"
: "output.weight";
Buffer logit_w = weights.get_tensor(logit_name);
A single ternary replaces a multi-way branch.
Fused QK-Norm + RoPE
The table builder checks whether the fused kernel is applicable:
bool use_fused_norm_rope = arch.has_qk_norm &&
strcmp(arch.rope_kernel, "rope_neox_qkv_write_f16") == 0;
This is true for Qwen3 and Gemma (NeoX RoPE + QK norms), false for everything else. When true, a single fused kernel replaces 3 separate dispatches.
The DTypeDescriptor: Kernel Registry
While ArchDescriptor captures architecture-level variation, DTypeDescriptor captures dtype-level variation. It is defined in src/core/dtype_descriptor.h:
struct DTypeDescriptor {
uint32_t dtype;
const char *gemv_kernel;
const char *gemv_large_kernel;
const char *gemv_wide_kernel;
const char *gemm_kernel;
const char *gemm_small_kernel;
const char *embedding_kernel;
const char *fused_silu_kernel;
const char *fused_silu_large_kernel;
int gemv_rows_per_tg;
int gemv_tg_size;
int large_rows_per_tg;
int large_tg_size;
int wide_rows_per_tg;
int wide_tg_size;
int fused_silu_rows;
int fused_silu_tg;
int fused_silu_large_rows;
int fused_silu_large_tg;
bool is_mlx;
};
This is a flat POD struct with no pointers to chase. Each supported GGUF dtype code gets an entry in a static lookup table:
static const DTypeDescriptor kDTypes[] = {
// dtype gemv gemv_large gemv_wide gemm gemm_small embed fused_silu fused_silu_large rows tg lg_r lg_tg w_r w_tg fs_r fs_tg fsl_r fsl_tg mlx
{0, "gemv_f16", nullptr, "gemv_wide_f16", "simd_gemm_f16", "simd_gemm_small_f16", nullptr, nullptr, nullptr, 16, 128, 0, 0, 64, 256, 0, 0, 0, 0, false},
{2, "gemv_q4_0", nullptr, "gemv_wide_q4_0", "simd_gemm_q4_0", "simd_gemm_small_q4_0","embedding_lookup_q4_0","gemv_q4_0_silu", nullptr, 16, 128, 0, 0, 64, 256, 16, 128, 0, 0, false},
{8, "gemv_q8_0", nullptr, "gemv_wide_q8_0", "simd_gemm_q8_0", "simd_gemm_small_q8_0","embedding_lookup_q8_0", nullptr, nullptr, 32, 256, 0, 0, 64, 256, 0, 0, 0, 0, false},
{12, "gemv_q4_k", nullptr, "gemv_wide_q4_k", "simd_gemm_q4_k", "simd_gemm_small_q4_k","embedding_lookup_q4_k", nullptr, nullptr, 16, 256, 0, 0, 32, 256, 0, 0, 0, 0, false},
{100,"gemv_mlx_q4", nullptr, "gemv_wide_mlx_q4","simd_gemm_mlx_q4","simd_gemm_small_mlx_q4","embedding_lookup_mlx_q4","gemv_mlx_q4_silu",nullptr, 16, 128, 0, 0, 32, 256, 16, 128, 0, 0, true},
// ... more entries ...
};
The table currently has 17 entries covering F16, BF16, Q4_0, Q4_1, Q5_0, Q8_0, Q2_K through Q6_K, and MLX Q3/Q4/Q6/Q8 formats.
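Because lookups happen only at table-build time, the lookup itself can be a plain linear scan over the static array. A sketch with an abbreviated two-field descriptor (the real struct carries the full kernel and geometry fields shown above):

```cpp
#include <cstdint>
#include <cstddef>

// Abbreviated DTypeDescriptor for illustration.
struct DTypeDescMini {
    uint32_t    dtype;
    const char *gemv_kernel;
};

static const DTypeDescMini kDTypesMini[] = {
    {0, "gemv_f16"}, {2, "gemv_q4_0"}, {8, "gemv_q8_0"},
};

// Linear scan over ~17 entries, executed once per weight at load time,
// so there is no need for anything fancier.
const DTypeDescMini* dtype_lookup_mini(uint32_t dtype) {
    for (size_t i = 0; i < sizeof(kDTypesMini) / sizeof(kDTypesMini[0]); i++)
        if (kDTypesMini[i].dtype == dtype) return &kDTypesMini[i];
    return nullptr;  // unsupported format: caller reports an error
}
```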
How the Table Builder Uses DTypeDescriptor
The emit_gemv helper in table_builder.cpp reads from the descriptor to select the right kernel and dispatch geometry:
static void emit_gemv(DispatchTable& table, Device& device,
const ChipConfig& chip,
Buffer input, Buffer weight, Buffer output,
int output_offset, uint32_t dtype, int N, int K, ...) {
const auto& dt = device.dtype_lookup(dtype);
// Select kernel variant
bool use_wide = (N > chip.wide_gemv_threshold && dt.gemv_wide_kernel != nullptr);
bool use_large = (!use_wide && dt.gemv_large_kernel != nullptr
&& K >= chip.q4_small_k_threshold);
const char *kernel_name;
int rows, tg;
if (use_wide) {
kernel_name = dt.gemv_wide_kernel;
rows = dt.wide_rows_per_tg;
tg = dt.wide_tg_size;
} else if (use_large) {
kernel_name = dt.gemv_large_kernel;
rows = dt.large_rows_per_tg;
tg = dt.large_tg_size;
} else {
kernel_name = dt.gemv_kernel;
rows = dt.gemv_rows_per_tg;
tg = dt.gemv_tg_size;
}
Pipeline pso = device.get_pipeline(kernel_name);
int n_groups = (N + rows - 1) / rows;
// ... create and emit DispatchCmd ...
}
The table builder never mentions “Q4_0” or “Q8_0” by name. It looks up the descriptor for the weight’s dtype code and uses whatever kernel and geometry the table specifies. Supporting a new quantization format means adding one row to kDTypes[] and writing the corresponding Metal kernel.
The is_mlx Flag
MLX quantized formats require special handling: function constants for group_size and K, a different parameter struct layout (MLXParams vs GEMMParams), and different weight byte calculations. The is_mlx boolean flag in DTypeDescriptor controls this:
if (dt.is_mlx) {
uint32_t fc_indices[] = {0, 1};
uint32_t fc_values[] = {(uint32_t)quant_group_size, (uint32_t)K};
pso = device.get_pipeline(kernel_name, cache_key, fc_indices, fc_values, 2);
} else {
pso = device.get_pipeline(kernel_name);
}
This is the one architectural branch in emit_gemv that is not purely data-driven. It could be eliminated by adding a “pipeline creation strategy” function pointer to the descriptor, but the additional complexity is not justified for two variants.
Design Principles
The ArchDescriptor and DTypeDescriptor embody several design principles:
1. Data Over Code
Instead of encoding architecture-specific behavior in if/else chains, encode it in data structures. The table builder reads the data; it does not interpret architecture names.
2. Flat Over Deep
Both descriptors are flat POD structs. No inheritance hierarchies, no virtual methods, no builder patterns. This makes them trivially copyable, inspectable in a debugger, and cacheline-friendly.
3. Explicit Over Implicit
Every architecture-specific behavior is a named field. When you read arch.has_qk_norm, you know exactly what it controls. There are no hidden side effects, no overridden methods with subtle behavior differences.
4. Default-Safe
The zero-initialization of ArchDescriptor (ArchDescriptor d = {}) produces a safe state: no activation (nullptr), no RoPE (nullptr), no norms (false), not encoder-decoder (false). Every factory function explicitly sets the fields it needs. This prevents “forgot to set X” bugs.
5. Derivation Without Inheritance
arch_qwen3() derives from arch_llama() by copying the struct and overriding fields. This is simpler and more explicit than C++ inheritance. You can see all the differences at a glance.
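Principles 2 and 5 can even be enforced mechanically. A sketch with an abbreviated descriptor (the real struct lives in arch_descriptor.h and has more fields, but the same flat-POD shape):

```cpp
#include <type_traits>

// Abbreviated descriptor for illustration.
struct ArchDescMini {
    const char *activation_kernel;
    bool        has_qk_norm;
    const char *rope_kernel;
    bool        tie_embeddings;
};

// Principle 2 (flat over deep): the compiler proves trivial copyability.
static_assert(std::is_trivially_copyable<ArchDescMini>::value,
              "descriptors must stay flat POD");

inline ArchDescMini mini_llama() {
    ArchDescMini d = {};
    d.activation_kernel = "silu_gate_f16";
    d.rope_kernel = "rope_qkv_write_f16";
    return d;
}

// Principle 5 (derivation without inheritance): copy, then override.
inline ArchDescMini mini_qwen3() {
    ArchDescMini d = mini_llama();
    d.has_qk_norm = true;
    d.rope_kernel = "rope_neox_qkv_write_f16";
    d.tie_embeddings = true;
    return d;
}
```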
Adding a New Architecture
To add support for a hypothetical “Falcon” architecture:
- Write a factory function:
inline ArchDescriptor arch_falcon() {
ArchDescriptor d = arch_llama(); // similar to Llama
d.has_bias = true; // Falcon has bias
d.rope_kernel = "rope_neox_qkv_write_f16"; // NeoX RoPE
d.rope_standalone = "rope_neox_f16";
return d;
}
- Add a case to the lookup:
inline ArchDescriptor arch_from_config(const char *arch_name, int dim) {
if (strstr(arch_name, "falcon")) return arch_falcon();
// ... existing cases ...
}
- That’s it. No changes to table_builder.cpp, prefill.cpp, or any other core file (assuming the existing kernel set handles the architecture’s needs).
Summary
The ArchDescriptor pattern replaces architecture-specific branching with data-driven dispatch. A flat POD struct captures all variation points (activation, RoPE, norms, embeddings, encoder/decoder). Factory functions create descriptors for each architecture, with incremental derivation from the Llama default. The DTypeDescriptor does the same for quantization formats, mapping dtype codes to kernel names and dispatch geometry. Together, these two descriptors make the table builder and prefill engine completely agnostic to both architecture and quantization format.
1. Touvron, H., et al. (2023). “LLaMA: Open and Efficient Foundation Language Models.” arXiv:2302.13971. LLaMA’s architecture (RMSNorm, SwiGLU, RoPE, no bias) has become the de facto standard for open-weight LLMs. See https://arxiv.org/abs/2302.13971. ↩
2. Google DeepMind. (2024). “Gemma: Open Models Based on Gemini Research and Technology.” arXiv:2403.08295. Gemma scales embedding outputs by sqrt(dim), which is uncommon in other LLM architectures but follows the original Transformer convention. See https://arxiv.org/abs/2403.08295. ↩
3. Radford, A., et al. (2023). “Robust Speech Recognition via Large-Scale Weak Supervision.” Proceedings of ICML 2023. Whisper’s architecture uses sinusoidal positional encoding, LayerNorm, and cross-attention, following the original Transformer encoder-decoder design. See https://arxiv.org/abs/2212.04356. ↩
Chip Configuration and Hardware Tuning
Akunu runs on everything from an M1 MacBook Air with 8 GPU cores and 68 GB/s bandwidth to an M4 Ultra Mac Studio with 80 GPU cores and over a terabyte per second of bandwidth. The same inference engine, the same kernels, the same dispatch table structure – but with very different optimal tuning parameters. A threadgroup size that saturates an M1 might leave an M4 Max starved. A chain decode chunk size that is efficient on M4 might cause command buffer overhead to dominate on M1.
The ChipConfig struct solves this: it derives all hardware-specific tuning parameters from two inputs – the GPU core count and the GPU family number. Every magic number in the engine lives here, not scattered across the codebase.
The ChipConfig Struct
Defined in src/core/chip_config.h:
struct ChipConfig {
// === Hardware info ===
int gpu_cores; // GPU core count (from IOKit)
int gpu_family; // Apple GPU family (7=M1, 8=M2/M3, 9=M4)
int max_threads_per_tg; // Metal max (typically 1024)
// === SLC (System Level Cache) ===
int slc_bytes; // estimated SLC size
bool should_fuse_weights; // fuse gate+up, QKV when SLC is large enough
// === GEMV tuning ===
int q4_small_k_threshold; // Q4_0: use small-K variant below this
int wide_gemv_threshold; // switch to wide GEMV when N exceeds this
bool gemv_wide_standard; // use 256-thread standard GEMV on Pro+
// === Chain decode ===
int chain_decode_chunk; // tokens per GPU submission
// === Prefill ===
int max_prefill_chunk; // max tokens per prefill GEMM batch
// === Norm dispatch ===
int norm_tg_size; // threadgroup size for RMSNorm
};
Every field has a clear purpose and measurable impact on performance. Let’s go through each one.
Chip Detection
The factory function ChipConfig::from_gpu(int cores, int family) builds a config from two hardware-detected values:
- cores: The GPU core count, queried from IOKit’s AGXAccelerator service via IORegistryEntryCreateCFProperty(CFSTR("gpu-core-count")). This is the most reliable indicator of chip tier.
- family: The Apple GPU family number, queried via Metal’s [device supportsFamily:]. Family 7 = M1, Family 8 = M2/M3, Family 9 = M4.
These two values together uniquely identify the chip tier and generation:
| Cores | Family | Chip | SLC (est.) | Bandwidth |
|---|---|---|---|---|
| 8-10 | 7 | M1 | 8 MB | 68 GB/s |
| 14-16 | 7 | M1 Pro | 24 MB | 200 GB/s |
| 24-32 | 7 | M1 Max | 48 MB | 400 GB/s |
| 48-64 | 7 | M1 Ultra | 96 MB | 800 GB/s |
| 8-10 | 8 | M2/M3 | 8 MB | 100-150 GB/s |
| 16-18 | 8 | M2/M3 Pro | 24 MB | 150-200 GB/s |
| 30-40 | 8 | M2/M3 Max | 48 MB | 400 GB/s |
| 8-10 | 9 | M4 | 16 MB | 120 GB/s |
| 16-20 | 9 | M4 Pro | 32 MB | 273 GB/s |
| 30-40 | 9 | M4 Max | 48 MB | 546 GB/s |
| 60-80 | 9 | M4 Ultra | 96 MB | 1092 GB/s |
SLC Size Estimation
The System Level Cache (SLC) is the large last-level cache shared between CPU, GPU, and other IP blocks. Apple does not officially disclose SLC sizes, but they can be estimated from die analysis and performance profiling.1
if (cores >= 60)
c.slc_bytes = 96 * 1024 * 1024; // Ultra
else if (cores >= 30)
c.slc_bytes = 48 * 1024 * 1024; // Max
else if (cores >= 16)
c.slc_bytes = (family >= 9) ? 32 * 1024 * 1024 : 24 * 1024 * 1024; // Pro
else
c.slc_bytes = (family >= 9) ? 16 * 1024 * 1024 : 8 * 1024 * 1024; // Base
Key observations:
- M4 (family 9) gets larger SLC estimates at every tier. Die analysis suggests M4 improved the cache hierarchy significantly.2
- Ultra chips (2x Max die-to-die) get double the Max SLC.
Impact on Weight Fusion
The SLC size directly controls whether QKV and gate+up weight fusion is beneficial:
c.should_fuse_weights = (c.slc_bytes >= 16 * 1024 * 1024); // Pro+ and M4 Base
Weight fusion concatenates two or three weight matrices into one, replacing multiple GEMV dispatches with a single larger one. This reduces kernel launch overhead and can improve cache utilization – but only if the fused weight matrix fits (or mostly fits) in the SLC. On an M1 base with 8 MB SLC, a fused QKV weight matrix for a 4096-dim model might be:
$$Q + K + V = (4096 \times 4096 + 2 \times 1024 \times 4096) \times 0.5625 \text{ bytes/elem} \approx 14.2 \text{ MB (Q4\_0)}$$
(Q4_0 stores 32 weights in an 18-byte block – 16 bytes of packed nibbles plus a 2-byte f16 scale – i.e. 0.5625 bytes per element.)
This exceeds the 8 MB SLC, causing cache thrashing. On an M4 Pro with 32 MB SLC, the fused matrix fits comfortably, and the fusion saves two kernel launches.
GEMV Kernel Selection
Three parameters control GEMV variant selection:
q4_small_k_threshold
if (family >= 9 && cores >= 16)
c.q4_small_k_threshold = 512; // M4 Pro+
else if (cores >= 16)
c.q4_small_k_threshold = 1024; // M1-M3 Pro+
else
c.q4_small_k_threshold = 2048; // Base chips
This determines when to use a “large-K” GEMV variant that uses more SIMD groups for the K-dimension reduction. Larger chips can afford more parallelism at smaller K values because they have more GPU cores to keep busy.
wide_gemv_threshold
c.wide_gemv_threshold = 32768;
When the output dimension N exceeds this threshold (e.g., vocabulary projection with 128K tokens), switch to a “wide” GEMV kernel with more output rows per threadgroup. This is constant across all chips because the decision depends on the N dimension, not the hardware.
gemv_wide_standard
c.gemv_wide_standard = (cores >= 16);
On Pro+ chips (16+ cores), use 256-thread threadgroups for the standard GEMV instead of 128-thread. More threads per threadgroup means more SIMD groups, which means better occupancy on chips with many cores. On base chips with 8-10 cores, the extra threads would just increase register pressure without improving occupancy.
Chain Decode Chunk Size
if (cores >= 60)
c.chain_decode_chunk = 128; // Ultra
else if (cores >= 30)
c.chain_decode_chunk = 128; // Max
else if (family >= 9)
c.chain_decode_chunk = 128; // M4 Pro/Base
else if (cores >= 16)
c.chain_decode_chunk = 96; // M1-M3 Pro
else
c.chain_decode_chunk = 64; // M1-M3 Base
This is the number of tokens batched into a single GPU command buffer during chain decode. The trade-offs:
- Larger chunks: Amortize command buffer overhead across more tokens. Better GPU utilization. But each submission takes longer, increasing latency to first token in a batch.
- Smaller chunks: Lower per-submission latency. But higher overhead ratio (command buffer setup / useful work).
M4’s improved GPU command processor can handle larger chunks efficiently. Older chips with smaller caches and slower command processing benefit from smaller chunks.
For a 32-layer model with ~200 commands per token:
- Chunk 128: 25,600 dispatches per command buffer
- Chunk 64: 12,800 dispatches per command buffer
The total generation throughput is virtually identical (both are bandwidth-limited), but the chunk size affects how quickly the first token in a chunk appears.
Prefill Chunk Size
c.max_prefill_chunk = 4096;
Maximum tokens per prefill GEMM batch. This is constant because prefill is compute-bound (large M), and the GEMM kernel performance is relatively insensitive to batch size above a few hundred tokens. The limit prevents excessive memory usage for the batch scratch buffers.
Norm Threadgroup Size
c.norm_tg_size = 1024;
RMSNorm and LayerNorm use a single threadgroup to process the entire dimension vector. The threadgroup size is min(dim, norm_tg_size). For a 4096-dim model, this means 1024 threads (the Metal maximum), with each thread processing 4 elements.
This is constant because normalization kernels are lightweight and the threadgroup size has minimal impact on performance.
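The geometry described above works out to min(dim, 1024) threads, each reducing a strided slice of the vector. A sketch (the struct and helper names are illustrative, not akunu's actual API):

```cpp
#include <algorithm>

// Sketch of the norm dispatch geometry: one threadgroup covers the
// whole dimension vector.
struct NormGeometry {
    int threads;           // threadgroup size
    int elems_per_thread;  // elements each thread reduces
};

NormGeometry norm_geometry(int dim, int norm_tg_size) {
    int threads = std::min(dim, norm_tg_size);
    return { threads, (dim + threads - 1) / threads };
}
```

For dim = 4096 this yields 1024 threads with 4 elements each, matching the figure in the text; a 512-dim model would use a 512-thread group with one element per thread.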
How ChipConfig Flows Through the System
The configuration flows from hardware detection to kernel dispatch:
┌──────────────┐ ┌────────────────┐ ┌──────────────────┐
│ MetalDevice │─────>│ ChipConfig │─────>│ table_builder │
│ gpu_cores() │ │ from_gpu() │ │ emit_gemv() │
│ gpu_family() │ │ │ │ build_dispatch │
└──────────────┘ └────────────────┘ └──────────────────┘
│
▼
┌──────────────────┐
│ DispatchTable │
│ (commands[] │
│ with resolved │
│ kernel+geom) │
└──────────────────┘
1. MetalDevice::chip_config() calls ChipConfig::from_gpu(gpu_core_count(), gpu_family())
2. The config is passed to build_dispatch_table() and prefill()
3. The table builder uses chip.should_fuse_weights to decide QKV/gate+up fusion
4. emit_gemv() uses chip.wide_gemv_threshold, chip.q4_small_k_threshold, and chip.gemv_wide_standard to select kernel variants
5. The chain decoder uses chip.chain_decode_chunk for batch sizing
6. The resulting dispatch table has hardware-optimal kernel choices baked in
All of this happens once at model load time. At inference time, the dispatch table is simply replayed – no hardware queries, no branching, no configuration lookups.
Real-World Impact
To make the impact concrete, here are approximate decode speeds for a 7B Q4_0 model on different chips, showing how ChipConfig-driven tuning affects performance:
| Chip | Cores | Family | Chunk | Fuse? | Wide GEMV? | Decode tok/s |
|---|---|---|---|---|---|---|
| M1 | 8 | 7 | 64 | No | No | ~15 |
| M1 Pro | 16 | 7 | 96 | Yes | Yes | ~45 |
| M1 Max | 32 | 7 | 128 | Yes | Yes | ~90 |
| M4 | 10 | 9 | 128 | Yes | No | ~28 |
| M4 Pro | 20 | 9 | 128 | Yes | Yes | ~65 |
| M4 Max | 40 | 9 | 128 | Yes | Yes | ~130 |
The “Fuse?” column shows whether weight fusion is enabled (SLC >= 16 MB). The “Wide GEMV?” column shows whether 256-thread GEMV is used (cores >= 16). These configuration-driven choices account for a significant portion of the performance difference between “works” and “fast.”
Design Philosophy
ChipConfig embodies a few key principles:
1. All magic numbers in one place. Instead of hardcoded constants in table_builder.cpp, prefill.cpp, and decode_loop.cpp, every tuning parameter is in ChipConfig. This makes it easy to review, adjust, and reason about the tuning.
2. Derived, not configured. Users do not set these values. They are automatically derived from hardware detection. This eliminates “you need to tune parameter X for your specific hardware” friction.
3. Conservative defaults. The base chip configuration (64 tokens/chunk, no fusion, 128-thread GEMV) is conservative. It works correctly on every chip. The Pro/Max/Ultra configurations are more aggressive but only activate when the hardware supports them.
4. Chip tier matters more than generation. The core count (which determines chip tier: base/Pro/Max/Ultra) has more impact on tuning than the GPU family (which determines generation: M1/M2/M3/M4). An M1 Max (32 cores) gets more aggressive tuning than an M4 base (10 cores), even though the M4 is a newer generation.
Deep Dive: The SLC and Weight Fusion Decision
The should_fuse_weights decision deserves a more thorough analysis because it is one of the most impactful configuration choices. Let’s trace through the numbers for a concrete model.
Consider Llama 3 8B with Q4_0 quantization:
- dim = 4096, n_heads = 32, n_kv_heads = 8, head_dim = 128
- q_dim = 4096, kv_dim = 1024, ffn_dim = 14336
QKV fusion candidate:
| Weight | Shape | Bytes (Q4_0) |
|---|---|---|
| Q projection | 4096 x 4096 | ~9.4 MB |
| K projection | 1024 x 4096 | ~2.4 MB |
| V projection | 1024 x 4096 | ~2.4 MB |
| Fused QKV | 6144 x 4096 | ~14.2 MB |
On an M1 base (8 MB SLC), the fused QKV matrix (14.2 MB) far exceeds the cache. Reading it causes cache thrashing – the end of the matrix evicts the beginning before the next token needs it. Three separate GEMVs at 9.4 + 2.4 + 2.4 = 14.2 MB read the same total bytes but each individual matrix is smaller. The K and V matrices (2.4 MB each) may actually fit in the SLC between dispatches, providing a small cache benefit.
On an M4 Pro (32 MB SLC), the fused QKV matrix (14.2 MB) fits with room to spare. The single GEMV dispatch saves two kernel launches (~0.01ms each) and potentially benefits from better memory controller scheduling (one large sequential read vs. three separate reads with dispatch overhead between them).
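The byte counts in the table follow from Q4_0's block layout: 32 weights per block, 18 bytes per block (16 bytes of packed 4-bit quants plus a 2-byte f16 scale), i.e. 0.5625 bytes per weight. A quick check:

```cpp
#include <cstddef>

// Q4_0 size: 18 bytes per 32-weight block (0.5625 bytes per weight).
size_t q4_0_bytes(size_t rows, size_t cols) {
    return (rows * cols / 32) * 18;
}
```

q4_0_bytes(4096, 4096) gives 9,437,184 bytes (~9.4 MB for the Q projection), and q4_0_bytes(6144, 4096) gives 14,155,776 bytes (~14.2 MB for the fused QKV) – above an M1 base's estimated 8 MB SLC and comfortably inside an M4 Pro's 32 MB.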
Gate+Up fusion candidate:
| Weight | Shape | Bytes (Q4_0) |
|---|---|---|
| Gate projection | 14336 x 4096 | ~33.0 MB |
| Up projection | 14336 x 4096 | ~33.0 MB |
| Fused Gate+Up | 28672 x 4096 | ~66.0 MB |
The fused gate+up matrix (66 MB) exceeds the SLC on every Apple Silicon chip. So why fuse at all? Because the fusion still saves one kernel launch and one dispatch table entry per layer, and the memory controller can prefetch more efficiently when reading one contiguous allocation. On Pro+ chips where the SLC is large enough to hold a meaningful portion of the matrix, the cache hit rate on the “tail” of the read (after the GPU has processed the head) is better than with two separate reads.
On base chips, the gate+up fusion is skipped entirely (controlled by should_fuse_weights = false) because the kernel launch saving does not compensate for the worse cache behavior.
Deep Dive: Chain Decode Chunk Size
The chunk size determines how many tokens are packed into a single Metal command buffer. Let’s analyze the trade-offs quantitatively.
Command buffer overhead: Creating a command buffer, encoding commands, and submitting takes a fixed amount of CPU time. On Apple Silicon, this is approximately:
- Command buffer creation: ~2 us
- Encoder creation: ~1 us
- Submission (commit): ~5 us
- GPU scheduling delay: ~10-50 us (varies by chip and load)
Total fixed overhead per submission: ~20-60 us.
Per-token encoding time: Each token requires ~260 Metal API calls (set pipeline, set buffer, set bytes, dispatch). Each API call takes approximately 50-100ns. Total per-token encoding: ~13-26 us.
For different chunk sizes on M4 Pro:
| Chunk Size | Encoding Time | Fixed Overhead | Overhead Ratio | Tokens Until Display |
|---|---|---|---|---|
| 1 | 26 us | 60 us | 70% | 1 |
| 16 | 416 us | 60 us | 13% | 16 |
| 64 | 1.7 ms | 60 us | 3.4% | 64 |
| 128 | 3.3 ms | 60 us | 1.8% | 128 |
The “Tokens Until Display” column shows the latency cost: with a chunk size of 128, the user must wait for all 128 tokens to be generated before any are displayed. For interactive chat at 70 tok/s, a chunk of 128 takes ~1.8 seconds – potentially noticeable.
Akunu balances this by using smaller chunks for the first few tokens (to minimize time-to-first-token) and larger chunks for sustained generation. The chain_decode_chunk in ChipConfig sets the maximum chunk size for sustained generation.
On older chips (M1 base), the per-token encoding time is higher (slower CPU) and the GPU execution time is longer (less bandwidth), so the fixed overhead is relatively smaller. A chunk of 64 achieves a good balance. On M4, the faster CPU makes encoding cheaper and the faster GPU makes execution shorter, so a chunk of 128 is optimal.
Deep Dive: GEMV Variant Selection Logic
The three GEMV-related parameters work together to form a decision tree. Let me trace through how emit_gemv in table_builder.cpp uses them.
For each GEMV dispatch, the kernel variant is selected as follows:
Input: N (output rows), K (reduction dimension), dtype
1. Look up DTypeDescriptor for dtype
2. If N > chip.wide_gemv_threshold AND wide kernel exists:
→ Use wide GEMV (more output rows per TG)
3. Else if large-K kernel exists AND K >= chip.q4_small_k_threshold:
→ Use large-K GEMV (more threads for K-reduction)
4. Else:
→ Use standard GEMV
For a concrete example with Llama 3 8B (Q4_0):
| Operation | N | K | M4 Pro Decision | M1 Base Decision |
|---|---|---|---|---|
| Q projection | 4096 | 4096 | Standard (N < 32768) | Standard |
| K projection | 1024 | 4096 | Standard | Standard |
| V projection | 1024 | 4096 | Standard | Standard |
| O projection | 4096 | 4096 | Standard | Standard |
| Gate projection | 14336 | 4096 | Standard | Standard |
| Up projection | 14336 | 4096 | Standard | Standard |
| Down projection | 4096 | 14336 | Standard | Standard |
| Logit (vocab=128K) | 128256 | 4096 | Wide (N > 32768) | Wide |
The wide GEMV is only used for the vocabulary projection, where N (vocab size) is very large. All other projections use the standard kernel. The q4_small_k_threshold would matter for models with smaller hidden dimensions or for the KV projections, where K might be below the threshold on base chips.
The Backwards Compatibility Shim
The struct includes a backwards-compatibility factory:
static ChipConfig from_gpu_cores(int cores) { return from_gpu(cores, 0); }
This allows older code that only knows the core count (not the GPU family) to still create a configuration. With family = 0, the M4-specific optimizations (larger SLC, lower K-threshold) are skipped, and the config falls back to conservative M1-era tuning.
Extending ChipConfig
If you were adding CUDA support, ChipConfig would need additional fields:
// Hypothetical CUDA additions:
int sm_count; // number of streaming multiprocessors
int shared_mem_per_sm; // shared memory (threadgroup memory equivalent)
int warp_size; // always 32, but good to be explicit
bool has_tensor_cores;
int tensor_core_gen; // 0=none, 1=Volta, 3=Ampere, 4=Hopper
int l2_cache_bytes; // L2 cache size (analogous to SLC)
The factory function would detect these from cuDeviceGetAttribute and derive tuning parameters. The should_fuse_weights decision would key off l2_cache_bytes instead of slc_bytes. The chain_decode_chunk would adapt to the different command submission model (CUDA streams vs Metal command buffers). The table builder would use the same ChipConfig interface, just with different values.
For Vulkan, additional fields might include:
int subgroup_size; // analogous to SIMD width (varies: 16, 32, 64)
bool has_subgroup_matrix; // cooperative matrix support
int max_workgroup_size; // varies by device
The beauty of the ChipConfig approach is that the table builder does not care which backend produced the values. It just reads should_fuse_weights, wide_gemv_threshold, etc. The backend-specific detection logic is encapsulated in the factory function.3
Summary
ChipConfig is a simple struct that maps hardware capabilities (GPU core count, GPU family) to optimal tuning parameters (SLC size, fusion decisions, GEMV variant selection, chain decode chunk size). It is built once at device creation and consumed at dispatch table build time. All decisions are baked into the dispatch table – nothing is queried or branched on at inference time.
1. Chips and Cheese (2021–2024). “Apple M1/M2/M3/M4 Die Analysis.” SLC sizes are estimated from die area analysis, performance counter measurements, and benchmark profiling. Apple does not officially disclose SLC sizes. See https://chipsandcheese.com/. ↩
2. Apple, “Apple M4 chip” (2024). The M4 chip announcement highlights “next-generation GPU” with improved ray tracing and machine learning performance. Independent analysis suggests significant cache hierarchy improvements. See https://www.apple.com/newsroom/2024/05/apple-introduces-m4-chip/. ↩
3. This separation of concerns – hardware detection in the backend, tuning consumption in the core – is the same pattern used by the Device abstraction layer. ChipConfig is effectively a “capabilities descriptor” that the core engine reads, analogous to how ArchDescriptor is a “model capabilities descriptor.” ↩
Model Loading: From File to GPU
If you have ever wondered what actually happens between the moment you point an
inference engine at a model file and the moment the first token appears on your
screen, this chapter is for you. In akunu, that transition is orchestrated by a
single function – akunu_load_model – defined in
src/inference/model_loader.cpp. The function is roughly 250 lines of C++, and
it touches almost every subsystem in the engine: format detection, weight I/O,
configuration parsing, architecture selection, tokenizer construction, GPU
memory allocation, weight fusion, KV cache creation, scratch buffer setup,
RoPE precomputation, dispatch table building, shader compilation, and a warmup
forward pass. That is a lot of ground to cover, so let us take it one stage at
a time.
The Bird’s-Eye View
Here is the full pipeline from “user calls the C API” to “model is ready to generate text.”
akunu_load_model(path, metallib_path, max_context)
|
| 1. Create GPU device
| Device::create_default()
|
| 2. Load metallib (shader library)
| load_metallib(device, metallib_path)
|
| 3. Detect model format (GGUF vs MLX/SafeTensors vs Whisper GGML)
| is_whisper_ggml(path) / WeightProvider::detect_format(path)
|
| 4. Open weights
| WeightProvider::open(path)
|
| 5. Extract model config
| weights.get_config() --> AkunuModelConfig
|
| 6. Infer architecture descriptor
| arch_from_config(config.architecture, config.dim) --> ArchDescriptor
|
| 7. Detect chip capabilities
| device.chip_config() --> ChipConfig
|
| 8. Resolve weight quirks (MLX RoPE style, tied embeddings, quant params)
|
| 9. Precompute RoPE frequencies
| init_rope_freqs(state, path)
|
| 10. Validate config (dim, layers, vocab, heads, head_dim all nonzero)
|
| 11. Load tokenizer
| load_tokenizer(state, path)
|
| 12. Allocate KV cache
| KVCache::create(device, n_layers, n_kv_heads, head_dim, ctx)
|
| 13. Allocate scratch buffers
| ScratchBuffers::create(device, config, ctx, max_prefill_chunk)
|
| 14. Build greedy dispatch table
| build_dispatch_table(device, weights, config, arch, chip, kv, scratch)
|
| 15. Build sampled dispatch table (greedy + Gumbel top-k + argmax)
|
| 16. Pre-allocate GPU param buffers for dispatch commands
|
| 17. Warmup forward pass (compile all PSOs)
| encode_dispatch_table(&table, 0, 1); end_encoding_sync();
|
| 18. Return opaque model handle
v
That is 18 discrete stages. Some of them are trivial (one function call). Others – like the dispatch table build – are complex enough to deserve an entire chapter of their own (see Chapter 25). But they all execute sequentially, one time, at model load. Nothing in this list happens again during inference. That is the whole point: pay the cost once up front so the hot path is zero-allocation and zero-branching.
Stage 1: Creating the Metal Device
The very first thing akunu_load_model does is create a Metal device:
state->device = Device::create_default();
printf("GPU: %s\n", state->device->name());
Device is akunu’s hardware abstraction layer. On Apple Silicon,
create_default() calls MTLCreateSystemDefaultDevice(), wraps the resulting
id<MTLDevice> in an internal implementation class (MetalDeviceImpl), and
queries the hardware for its GPU core count and Apple GPU family number. Those
two integers feed into ChipConfig::from_gpu() later, which tunes every
kernel dispatch in the engine to the specific chip you are running on.
Why does the device come first? Because almost everything else – the weight
provider, the KV cache, the scratch buffers – needs a Device& reference to
allocate GPU memory. Device construction is cheap (microseconds), and it
anchors the entire object graph.
Stage 2: Loading the Metallib
if (!load_metallib(*state->device, metallib_path)) {
set_error("Failed to load metallib. Build with: make shaders");
delete state;
return nullptr;
}
A .metallib is Apple’s precompiled shader archive – the GPU equivalent of a
.a static library. Akunu’s Metal kernels (GEMV, GEMM, RoPE, attention,
normalization, activation, argmax, etc.) are compiled offline into a single
akunu.metallib file. load_metallib tries a user-provided path first, then
falls back to a handful of well-known build output locations:
engine/build/akunu.metallib
.build/metallib/akunu.metallib
build/akunu.metallib
akunu.metallib
If none of these exist, the model load fails immediately. You cannot do inference without GPU kernels. The function is deliberately simple – no complex search logic, no environment variables, just a flat priority list.
Why load shaders before weights? Because there is no point spending time and memory on a multi-gigabyte weight file if we cannot even dispatch a kernel. Fail fast.
Stage 3: Format Detection
Akunu supports three model formats:
- GGUF – the standard quantized format from llama.cpp. Most common.
- MLX SafeTensors – Apple’s MLX framework format. A directory of
  .safetensors files plus a config.json.
- Whisper GGML – legacy binary format for OpenAI Whisper models. Uses a
  "lmgg" or "ggjt" magic number.
The detection logic is refreshingly unsophisticated:
bool is_whisper_ggml(const char *path) {
    FILE *f = fopen(path, "rb");
    if (!f) return false;
    char magic[4] = {0};
    size_t n = fread(magic, 1, 4, f);
    fclose(f);
    return n == 4 && (memcmp(magic, "lmgg", 4) == 0
                   || memcmp(magic, "ggjt", 4) == 0);
}
If the path is a directory or ends in .safetensors, it is MLX. Otherwise,
GGUF. No content sniffing beyond the four-byte magic for Whisper. This is
a case where “dumb and fast” beats “clever and fragile.”
Stage 4: Opening Weights
For standard LLM models (not Whisper), the next step is opening the weight provider:
state->weights = new WeightProvider(*state->device);
if (!state->weights->open(model_path)) {
set_error("Failed to open model: %s", model_path);
delete state;
return nullptr;
}
WeightProvider is a unified facade that wraps either a WeightStore (GGUF
parser) or an MLXWeightStore (SafeTensors parser). The open() call parses
file headers and metadata but does NOT eagerly load all tensors into GPU memory.
Weight tensors are memory-mapped and lazily materialized on first access through
get_tensor(). For a 7B Q4_0 model, the GGUF file is about 4 GB – you do
not want to copy all of that before you even know the model’s architecture.
The WeightProvider exposes a uniform interface regardless of format:
+-----------------+-----------+-----------+
| Method | GGUF | MLX |
+-----------------+-----------+-----------+
| get_config() | metadata | JSON |
| get_tensor(name)| mmap+GPU | mmap+GPU |
| get_dtype(name) | per-tensor| per-tensor|
| has_tensor(name)| lookup | lookup |
| fuse_weights() | concat | concat |
+-----------------+-----------+-----------+
This abstraction is crucial because the rest of the engine – table builder, prefill, everything – never knows or cares what format the weights came from.
Stage 5: Extracting the Model Config
state->config = state->weights->get_config();
For GGUF, this reads well-known metadata keys (llama.embedding_length,
llama.block_count, llama.attention.head_count, etc.) from the GGUF header.
For MLX, it parses config.json in the model directory.
The result is a plain-old-data struct, AkunuModelConfig:
AkunuModelConfig
+----------------------------------+
| dim (e.g., 4096) | embedding dimension
| n_layers (e.g., 32) | transformer layers
| n_heads (e.g., 32) | query heads
| n_kv_heads (e.g., 8) | key/value heads (GQA)
| head_dim (e.g., 128) | dim per head
| q_dim (e.g., 4096) | n_heads * head_dim
| kv_dim (e.g., 1024) | n_kv_heads * head_dim
| ffn_dim (e.g., 14336) | feed-forward intermediate
| vocab_size (e.g., 32000) | vocabulary size
| max_seq_len (e.g., 8192) | maximum context length
| norm_eps (e.g., 1e-5) | RMSNorm epsilon
| rope_theta (e.g., 500000.0) | RoPE base frequency
| architecture "llama" | arch name string
+----------------------------------+
Every downstream allocation and dispatch geometry is driven entirely by these numbers. There are no “if model is 7B then…” branches anywhere in akunu. The config is the single source of truth.
Stage 6: Architecture Inference
state->arch = arch_from_config(state->config.architecture, state->config.dim);
This is one of akunu’s cleanest design patterns. The ArchDescriptor struct
captures every architecture-specific behavior as data, not code:
ArchDescriptor
+------------------------------+------------------+
| Field | Example (LLaMA) |
+------------------------------+------------------+
| activation_kernel | "silu_gate_f16" |
| embedding_scale | 0.0 (no scale) |
| has_qk_norm | false |
| has_post_attn_norm | false |
| has_post_ffn_norm | false |
| rope_kernel | "rope_qkv_write" |
| tie_embeddings | false |
| quant_bits | 0 (GGUF native) |
+------------------------------+------------------+
Different architectures get different descriptors:
arch_from_config("llama", ...) --> arch_llama()
arch_from_config("qwen3", ...) --> arch_qwen3() // QK-norm, tied embeds
arch_from_config("gemma", ...) --> arch_gemma(dim) // GELU, post-norms, scale
arch_from_config("whisper",...) --> arch_whisper() // LayerNorm, cross-attn
arch_from_config("bert", ...) --> arch_bert() // encoder-only, SwiGLU
The beauty of this approach is that adding a new architecture requires exactly
one factory function and one case in arch_from_config. The table builder,
the prefill engine, and the decode loop never branch on the architecture name.
They read fields from the descriptor. This is the
“data-driven polymorphism” pattern, and it produces code that is both
simpler and faster than the traditional virtual-method approach.
Stage 7: Chip Configuration
state->chip = state->device->chip_config();
ChipConfig captures hardware-specific tuning parameters:
ChipConfig
+-----------------------------+-------+----------+----------+
| Field | M1 | M3 Pro | M4 Ultra |
+-----------------------------+-------+----------+----------+
| gpu_cores | 8 | 18 | 64 |
| slc_bytes | 8 MB | 24 MB | 96 MB |
| should_fuse_weights | false | true | true |
| chain_decode_chunk | 64 | 96 | 128 |
| max_prefill_chunk | 4096 | 4096 | 4096 |
| q4_small_k_threshold | 2048 | 1024 | 512 |
| wide_gemv_threshold | 32768 | 32768 | 32768 |
+-----------------------------+-------+----------+----------+
Notice should_fuse_weights: on chips with a large System Level Cache (16+ MB,
meaning Pro and above, plus M4 base), akunu will fuse Q/K/V weight matrices
into a single contiguous buffer and dispatch one GEMV instead of three. The
fused weights fit in SLC, so the second and third “GEMVs” hit cache instead of
DRAM. On base M1/M2/M3, the SLC is too small for this to help, so fusion is
disabled.
Also notice chain_decode_chunk: this controls how many tokens are chained
into a single GPU command buffer submission. Larger chunks amortize the
overhead of Metal command encoding. M4 can handle 128 tokens per chunk; M1 is
limited to 64 due to its smaller command processor and narrower memory bus.
Stage 8: Weight Quirks
After extracting the architecture descriptor, model_loader patches it with
format-specific adjustments:
// MLX Llama uses NeoX (split-half) RoPE, not interleaved
if (weights->format() == MLX_SAFETENSORS && strstr(config.architecture, "llama"))
arch.rope_kernel = "rope_neox_qkv_write_f16";
// Tie embeddings if the architecture says so, OR if output.weight is missing
if (!arch.tie_embeddings)
arch.tie_embeddings = !weights->has_tensor("output.weight");
// Copy MLX quantization info
arch.quant_bits = weights->quant_bits();
arch.quant_group_size = weights->quant_group_size();
This is where the messy real-world details of model formats get resolved.
MLX uses a different RoPE convention than GGUF for the same architecture.
Some GGUF exports include output.weight and some do not. Rather than
scattering format checks throughout the engine, they are all concentrated here
in the loader.
Stage 9: RoPE Frequency Precomputation
init_rope_freqs(state, model_path);
Most models use standard RoPE with a base frequency theta. But LLaMA 3 introduced a complex wavelen-based frequency scaling scheme, and some models use simple linear scaling. akunu handles both by precomputing the frequency divisors for each dimension of the rotary embedding at load time:
For each dimension i in [0, head_dim/2):
base_freq = theta^(2i / head_dim)
LLaMA 3 wavelen scaling:
wavelen = 2 * pi * base_freq
if wavelen > low_wavelen:
freq[i] = base_freq * factor (long wavelengths scaled)
else if wavelen > high_wavelen:
smooth = (orig_max_pos/wavelen - low_freq_factor)
/ (high_freq_factor - low_freq_factor)
freq[i] = base_freq / ((1-smooth)/factor + smooth)
else:
freq[i] = base_freq (short wavelengths unchanged)
Linear scaling:
freq[i] = base_freq * factor
The resulting frequency vector is uploaded to a GPU buffer
(arch.rope_freqs). If no scaling is needed, the buffer stays null and the
RoPE kernel computes frequencies on the fly from theta.
This precomputation avoids two problems. First, the wavelen formulas involve floating-point operations (pow, division, conditional branches) that would add latency if repeated on every token. Second, it keeps the GPU kernel simpler – it either reads precomputed frequencies from a buffer or computes the standard geometric series, never the complex LLaMA 3 formula.
Stage 10: Config Validation
if (config.dim == 0 || config.n_layers == 0 || config.vocab_size == 0 ||
config.n_heads == 0 || config.head_dim == 0) {
set_error("Invalid model config: ...");
delete state;
return nullptr;
}
A simple sanity check. If any critical dimension is zero, something went wrong during metadata parsing (corrupt file, unsupported format version, missing keys). Fail immediately rather than producing cryptic GPU errors downstream.
Stage 11: Tokenizer Loading
The load_tokenizer function is more involved than you might expect. It needs
to handle two completely different tokenizer sources:
GGUF path:
vocab <-- weights.get_string_array("tokenizer.ggml.tokens")
scores <-- weights.get_float_array("tokenizer.ggml.scores")
merges <-- weights.get_string_array("tokenizer.ggml.merges")
type <-- weights.get_metadata_string("tokenizer.ggml.model")
bos_id <-- weights.get_metadata_int("tokenizer.ggml.bos_token_id")
eos_id <-- weights.get_metadata_int("tokenizer.ggml.eos_token_id")
HuggingFace path (MLX models):
load_hf_tokenizer(model_dir, hf_data)
--> parses tokenizer.json + tokenizer_config.json
After loading the raw vocabulary, akunu also scans for implicit stop tokens that are not the “official” EOS token but should still terminate generation:
const char *stop_tokens[] = {
"<|im_end|>", // ChatML
"<|endoftext|>", // GPT-2 style
"<|eot_id|>", // LLaMA 3
"<end_of_turn>", // Gemma
"</s>", // legacy
nullptr
};
For each of these, if the string exists in the vocabulary with a different ID
than the primary EOS, it is registered as an additional EOS token. This means
the decode loop does not need to know about chat templates – it just checks
tokenizer.is_eos(tok) and the tokenizer handles the multi-EOS logic.
Stage 12: KV Cache Allocation
int ctx = max_context > 0
? max_context
: std::min((int)config.max_seq_len, chip.max_prefill_chunk);
state->kv_cache = KVCache::create(
*state->device, config.n_layers, config.n_kv_heads, config.head_dim, ctx);
The KV cache is the largest single allocation in the system. For a 32-layer model with 8 KV heads, 128-dim heads, and a 4096-token context, the total is:
Per-layer buffer size:
n_kv_heads * max_seq_len * head_dim * sizeof(FP16)
= 8 * 4096 * 128 * 2
= 8 MB
Total:
n_layers * 2 (K + V) * 8 MB
= 32 * 2 * 8 MB
= 512 MB
The KVCache struct is a flat POD container – no virtual calls, no reference
counting, no linked lists:
KVCache
+------------------+
| n_layers = 32 |
| n_kv_heads = 8 |
| head_dim = 128|
| max_length = 4096|
| current_length = 0|
| kv_stride = 524288 (max_length * head_dim)
| |
| k_buffers[32] | <-- one GPU buffer per layer
| v_buffers[32] | <-- one GPU buffer per layer
+------------------+
All buffers are zero-filled at creation. The current_length field tracks how
many positions have been written (by prefill or decode). The kv_stride is
precomputed to avoid a multiplication in the attention kernel’s inner loop.
Stage 13: Scratch Buffer Allocation
state->scratch = ScratchBuffers::create(
*state->device, state->config, ctx, chip.max_prefill_chunk);
Scratch buffers are the working memory for a single forward pass. They are allocated once and reused every time. No dynamic allocation ever happens in the hot path.
ScratchBuffers (decode -- single token)
+----------------------------------------+
| h0 [dim] FP16 residual ping |
| h1 [dim] FP16 residual pong |
| residual [dim] FP16 norm output |
| qkv [q+2*kv_dim] FP16 Q|K|V concat |
| attn_out [q_dim] FP16 attention out |
| post_norm [dim] FP16 Gemma temp |
| ffn_gate [2*ffn_dim] FP16 gate|up fused |
| ffn_up [ffn_dim] FP16 up projection |
| ffn_act [ffn_dim] FP16 activation out |
| logits [vocab] FP16 final logits |
| token_ids [max_chain] U32 token buffer |
+----------------------------------------+
ScratchBuffers (prefill -- batch)
+----------------------------------------+
| batch_h0 [chunk * dim] FP16 |
| batch_h1 [chunk * dim] FP16 |
| batch_residual [chunk * dim] FP16 |
| batch_q [chunk * q_dim] FP16 |
| batch_k [chunk * kv_dim] FP16 |
| batch_v [chunk * kv_dim] FP16 |
| batch_attn_out [chunk * q_dim] FP16 |
| batch_gate [chunk * ffn] FP16 |
| batch_up [chunk * ffn] FP16 |
| batch_act [chunk * ffn] FP16 |
| batch_post_norm [chunk * dim] FP16 |
+----------------------------------------+
The dual sets of buffers – one for single-token decode, one for batched prefill – are a deliberate design choice. Decode uses GEMV (matrix-vector); prefill uses GEMM (matrix-matrix). They have completely different memory access patterns, and sharing buffers between them would either waste memory or require dynamic resizing.
Stage 14: Building the Greedy Dispatch Table
state->dispatch_table = build_dispatch_table(
*state->device, *state->weights, state->config,
state->arch, state->chip, state->kv_cache, state->scratch);
This is the most complex and most important step of model loading. The
dispatch table is a flat array of DispatchCmd structs – one for every GPU
kernel dispatch needed to process a single token through the entire transformer.
Chapter 25 covers this in detail, but the high-level structure is:
Command sequence for one token:
[0] Embedding lookup
[1] Embedding scale (Gemma only)
[2] Initial RMSNorm
For each layer:
[3+N*k] Fused QKV GEMV (or separate Q, K, V)
[...] Fused QK-norm + RoPE + KV write (or separate)
[...] Flash attention decode
[...] O projection GEMV
[...] Post-attn norm (Gemma only)
[...] Fused residual + FFN norm
[...] Fused gate+up GEMV (or separate)
[...] Fused SiLU+down GEMV (or separate activation + down)
[...] Post-FFN norm (Gemma only)
[...] Fused next attn norm
[N-2] Output norm
[N-1] Logit projection GEMV
[N] Argmax
For a 32-layer LLaMA model with all fusions enabled, this is typically around 160-200 commands.
Stage 15: Building the Sampled Dispatch Table
After the greedy table is built, model_loader constructs a second table for
sampled (temperature > 0) decoding:
auto& greedy_cmds = state->dispatch_table.commands;
auto& sampled_cmds = state->sampled_dispatch_table.commands;
// Copy all commands except the final argmax
for (size_t i = 0; i + 1 < greedy_cmds.size(); i++)
sampled_cmds.push_back(greedy_cmds[i]);
// Insert a Gumbel top-k noise kernel
DispatchCmd gumbel_cmd = DispatchCmd::make(gumbel_pso, Dim3(1), Dim3(1024));
gumbel_cmd.add_buffer(scratch.logits, 0, 0);
// ... set up GumbelTopKParams ...
gumbel_cmd.patch_type = DispatchCmd::PATCH_POSITION;
sampled_cmds.push_back(gumbel_cmd);
sampled_cmds.push_back(greedy_cmds.back()); // argmax still works!
The key insight: the Gumbel-max trick turns sampling into argmax. By adding
temperature * Gumbel_noise to each logit before taking the argmax, you get a
sample from the categorical distribution softmax(logits / temperature). This
means the sampled table is identical to the greedy table except for one extra
kernel inserted before the argmax. No CPU round-trip. No softmax. No
probability computation. Just noise + argmax, entirely on the GPU.
Stage 16: Pre-Allocating Parameter Buffers
for (auto& cmd : cmds) {
if (cmd.param_size > 0) {
cmd.param_buf = device.allocate(cmd.param_bytes, cmd.param_size);
}
}
Each DispatchCmd can carry up to 64 bytes of inline parameter data. Rather
than using Metal’s setBytes (which copies the data on every dispatch), akunu
pre-allocates a small GPU buffer for each command’s parameters and uses
setBuffer instead. On Unified Memory Architecture (UMA) systems like Apple
Silicon, these buffers are CPU-accessible, so updating a parameter (like the
position for RoPE) is a simple pointer write – no copy, no upload.
Stage 17: Warmup Forward Pass
printf("Warming up...\n");
uint32_t warmup_token = state->tokenizer.bos_id();
state->device->write_buffer(state->scratch.token_ids, &warmup_token, 4);
state->device->begin_encoding();
state->device->encode_dispatch_table(&state->dispatch_table, 0, 1);
state->device->end_encoding_sync();
state->kv_cache.reset();
Metal compiles Pipeline State Objects (PSOs) lazily – the first time a kernel is dispatched, the GPU driver compiles it from the metallib. This compilation can take tens of milliseconds per kernel. If we did not warm up, the first real inference would stutter terribly.
The warmup runs a complete forward pass for one token (the BOS token), forcing every PSO in the dispatch table to compile. Then it resets the KV cache to zero length, erasing any state written during warmup. The result: when the user sends their first prompt, every kernel is already compiled and ready to go.
Stage 18: Returning the Model Handle
printf("Ready.\n\n");
return state;
The ModelState pointer is returned as an opaque akunu_model_t handle. From
here on, the caller uses this handle with akunu_generate, akunu_prefill,
akunu_embed, etc. All the complexity of model loading is hidden behind a
single pointer.
The Complete Flow Diagram
Let us put it all together in a single visual:
akunu_load_model("model.gguf", "akunu.metallib", 4096)
|
+-- Device::create_default()
| |
| +-- MTLCreateSystemDefaultDevice()
| +-- query GPU cores, family
|
+-- load_metallib(device, path)
| |
| +-- try user path, then 4 fallback paths
| +-- device.load_library(path) --> MTLLibrary
|
+-- is_whisper_ggml(path)?
| |
| +-- NO (standard LLM path)
|
+-- WeightProvider::open(path)
| |
| +-- detect_format(path) --> GGUF or MLX
| +-- GGUF: WeightStore::open() --> parse header, mmap tensors
| +-- MLX: MLXWeightStore::open() --> parse JSON, mmap safetensors
|
+-- weights.get_config()
| |
| +-- AkunuModelConfig { dim=4096, layers=32, heads=32, ... }
|
+-- arch_from_config("llama", 4096)
| |
| +-- ArchDescriptor { silu_gate, no-qk-norm, rope_qkv_write, ... }
|
+-- device.chip_config()
| |
| +-- ChipConfig { cores=10, slc=8MB, chunk=64, ... }
|
+-- Resolve MLX RoPE, tied embeds, quant params
|
+-- init_rope_freqs(state, path)
| |
| +-- LLaMA3 scaling? compute wavelen-based freqs
| +-- Linear scaling? compute simple scaled freqs
| +-- Neither? leave rope_freqs null (use theta at runtime)
|
+-- Validate config (all dims nonzero)
|
+-- load_tokenizer(state, path)
| |
| +-- GGUF: read vocab/scores/merges from metadata
| +-- MLX: load_hf_tokenizer(dir)
| +-- Register extra stop tokens
|
+-- KVCache::create(device, 32, 8, 128, 4096)
| |
| +-- 32 layers x 2 (K+V) x 8MB = 512 MB GPU memory
|
+-- ScratchBuffers::create(device, config, 4096, 4096)
| |
| +-- decode: h0, h1, residual, qkv, attn_out, ffn, logits
| +-- prefill: batch versions of all the above
|
+-- build_dispatch_table(device, weights, config, arch, chip, kv, scratch)
| |
| +-- ~180 DispatchCmd structs
| +-- all PSOs resolved, all buffers bound, all params set
|
+-- Build sampled_dispatch_table (greedy + Gumbel + argmax)
|
+-- Pre-allocate param_buf for every command
|
+-- Warmup: run 1 token forward pass to compile all PSOs
| +-- begin_encoding()
| +-- encode_dispatch_table(&table, 0, 1)
| +-- end_encoding_sync() <-- blocks until GPU done
| +-- kv_cache.reset() <-- erase warmup state
|
+-- return state (as opaque akunu_model_t)
Memory Budget
Here is a concrete example for Llama 3.1 8B (Q4_0, 4096 context) on M3 Pro:
Component Size
-------------------------------------
Model weights (mmap) ~4.3 GB
KV cache (32 layers) ~512 MB
Scratch (decode) ~2.5 MB
Scratch (prefill) ~550 MB
Metallib (shaders) ~2 MB
Tokenizer ~3 MB
Dispatch table ~30 KB
-------------------------------------
Total GPU-resident ~5.4 GB
The weights dominate, as expected. The KV cache is the second largest allocation. Everything else is a rounding error.
Error Handling Philosophy
You may have noticed that akunu_load_model uses a simple pattern:
if (something_failed) {
set_error("...");
delete state;
return nullptr;
}
There are no exceptions. There are no error codes. The function either
returns a valid model handle or returns nullptr and sets a thread-local error
string that the caller can retrieve with akunu_last_error(). This is the
standard C API pattern – it works across language boundaries (Swift, Python,
Rust FFI) and avoids the overhead and complexity of C++ exception handling in a
performance-critical codebase.
The Whisper Path
For completeness, akunu_load_model also handles Whisper models. The flow is
similar but with encoder-decoder-specific additions:
Whisper-specific stages:
- load_whisper_model() instead of WeightProvider
- arch_whisper() descriptor (LayerNorm, bias, cross-attention)
- MelSpectrogram processor (with model's precomputed mel filters)
- WhisperBuffers (encoder/decoder scratch)
- Copy learned positional embeddings (encoder + decoder)
- build_whisper_decode_table() (includes GPU suppress params)
- Beam search buffers (5 beams x KV caches + intermediate buffers)
We will not dive deep into the Whisper path in this book, but it is worth
knowing that the same akunu_load_model entry point handles both LLMs and
audio models. The architecture descriptor pattern makes this possible without
an explosion of conditional logic.
Summary
Model loading in akunu is a carefully ordered sequence of 18 stages that
transforms a file path into a ready-to-run GPU inference pipeline. Every
decision – architecture-specific behavior, chip-specific tuning, format-specific
quirks – is resolved during loading and encoded into data structures (the
ArchDescriptor, ChipConfig, DispatchTable) that the hot path reads but
never modifies.
The result: zero allocation, zero branching, and zero format-awareness during inference. The next four chapters explore the data structures this loading process creates and how they are used to actually generate text at hundreds of tokens per second.
The Dispatch Table: Precompiled GPU Command Sequences
This is the chapter about akunu’s central innovation. If you had to distill the entire engine design into a single idea, it would be this: build the GPU command sequence once at model load time, then replay it for every token.
Most inference engines construct GPU commands on-the-fly during inference. For each token, they iterate through the model layers, look up weight tensors by name, create kernel parameter structs, resolve pipeline state objects, and emit dispatch commands. This per-token overhead is small in isolation – maybe a few hundred microseconds – but it adds up. At 70 tokens per second, each token has a 14ms budget. Spending 0.5ms on command construction is 3.5% overhead, and on smaller models it can be much worse.
Akunu eliminates this overhead entirely. The DispatchTable is a flat array of DispatchCmd structs – plain old data, no pointers to chase, no virtual calls, no hash table lookups. Built once, replayed thousands of times. The only per-token work is patching a few position-dependent fields.
The Core Data Structures
DispatchCmd: A Single GPU Command
Every GPU operation in a forward pass – every GEMV, every norm, every attention, every RoPE – is represented as a single DispatchCmd:
struct DispatchCmd {
    Pipeline pso;              // Compiled compute pipeline

    // Buffer bindings
    static constexpr int MAX_BUFFERS = 8;
    Buffer buffers[MAX_BUFFERS];
    uint32_t offsets[MAX_BUFFERS];
    int buffer_count;

    // Inline params (up to 64 bytes)
    uint8_t param_bytes[64];
    int param_size;
    int param_index;

    // Pre-allocated GPU buffer for static params
    Buffer param_buf;

    // Secondary params
    uint8_t param2_bytes[16];
    int param2_size;
    int param2_index;

    // Threadgroup memory
    int tg_mem_bytes;
    int tg_mem_index;

    // Dispatch geometry
    Dim3 grid;
    Dim3 threadgroup;
    bool use_dispatch_threads;

    // Per-token patching
    enum PatchType : uint8_t {
        PATCH_NONE = 0,
        PATCH_TOKEN_OFFSET,
        PATCH_POSITION,
        PATCH_KV_SEQ_LEN,
        PATCH_POS_AND_KV,
        PATCH_ARGMAX_OUTPUT,
    };
    PatchType patch_type;
    int patch_offset_1;
    int patch_offset_2;
};
Let’s examine each section.
Pipeline (pso): The pre-compiled compute pipeline state object. Resolved once during table building via device.get_pipeline(). No string lookup at dispatch time.
Buffer bindings (buffers[], offsets[], buffer_count): Fixed-size array of up to 8 buffer bindings. Each entry has a Buffer handle and a byte offset. These are the weight buffers, scratch buffers, and KV cache buffers that the kernel reads from and writes to. All resolved at build time.
Inline parameters (param_bytes[64], param_size, param_index): The kernel’s parameter struct (dimensions, epsilon, strides, etc.) stored inline as raw bytes. 64 bytes is enough for every kernel parameter struct in akunu. The param_index is the argument buffer index for setBytes or setBuffer.
Pre-allocated param buffer (param_buf): For commands with static (non-position-dependent) parameters, a GPU buffer is pre-allocated containing the parameter data. At replay time, setBuffer is used instead of setBytes, avoiding the per-dispatch copy entirely.
Secondary params (param2_bytes[16]): Some kernels need two separate setBytes calls at different argument indices. The secondary param slot handles this.
Threadgroup memory (tg_mem_bytes, tg_mem_index): Some kernels require threadgroup memory allocation (e.g., GEMM tile staging). This is pre-computed.
Dispatch geometry (grid, threadgroup, use_dispatch_threads): Pre-computed grid and threadgroup dimensions. use_dispatch_threads selects between dispatchThreadgroups and dispatchThreads (the latter auto-computes grid size from total threads).
Per-token patching (patch_type, patch_offset_1, patch_offset_2): This is the key mechanism that makes replay possible despite per-token variation. More on this below.
Size and Alignment
A single DispatchCmd is approximately 280 bytes. For a 32-layer Llama model, the table contains roughly 200-250 commands, totaling ~56-70 KB. This fits comfortably in L2 cache, meaning the replay loop operates almost entirely from cache.1
DispatchTable: The Complete Sequence
struct DispatchTable {
    std::vector<DispatchCmd> commands;   // Hot path: dense command array
    std::vector<DispatchLabel> labels;   // Cold path: profiling labels
    int tokens_per_tg;
};
The hot/cold split is deliberate. Profiling labels (48-byte strings like “layer.0.attention”) are stored in a parallel vector, separate from the command array. During inference, the labels are never accessed – the inner loop iterates only the dense commands vector. During profiling, labels are accessed by index to annotate timing data.
Per-Token Patching
Here is the fundamental challenge: most of the forward pass is identical for every token – same weights, same kernels, same dispatch geometry. But a few things change:
- Position: RoPE needs the current sequence position. Attention needs the KV sequence length.
- Token offset: The embedding lookup reads from `token_ids[token_index]`. The argmax writes to `token_ids[token_index + 1]`.
Akunu solves this with a small enum of patch types:
| PatchType | What Changes | Where |
|---|---|---|
| `PATCH_NONE` | Nothing | Most commands (GEMV, norms, etc.) |
| `PATCH_TOKEN_OFFSET` | `buffers[0].offset = token_index * 4` | Embedding lookup |
| `PATCH_POSITION` | `param_bytes[offset1] = position` | RoPE |
| `PATCH_KV_SEQ_LEN` | `param_bytes[offset1] = position + 1` | Flash attention |
| `PATCH_POS_AND_KV` | Both position and KV length | Combined RoPE+attention params |
| `PATCH_ARGMAX_OUTPUT` | `buffers[1].offset = (token_index+1) * 4` | Argmax output |
The patch_offset_1 and patch_offset_2 fields specify the byte offsets within param_bytes where the position/KV-length values should be written. These offsets are computed at build time using offsetof():
// From table_builder.cpp:
cmd.patch_type = DispatchCmd::PATCH_POSITION;
cmd.patch_offset_1 = offsetof(decltype(rope_params), pos);
At replay time, the patching is a simple memcpy + write:
uint8_t patched[64];
memcpy(patched, cmd.param_bytes, cmd.param_size);
*(uint32_t *)(patched + cmd.patch_offset_1) = (uint32_t)pos;
[enc setBytes:patched length:cmd.param_size atIndex:cmd.param_index];
For the ~200 commands in a typical table, only about 100 need patching (RoPE, attention, embedding, argmax). The other ~100 use PATCH_NONE and bind their pre-allocated param buffer with zero per-token work.
Table Building: The build_dispatch_table Function
The table builder in src/core/table_builder.cpp constructs the dispatch table by walking through the transformer architecture once:
DispatchTable build_dispatch_table(Device& device,
                                   WeightProvider& weights,
                                   const AkunuModelConfig& cfg,
                                   const ArchDescriptor& arch,
                                   const ChipConfig& chip,
                                   const KVCache& kv_cache,
                                   const ScratchBuffers& scratch);
The function takes everything it needs as parameters: the device (for pipeline creation), weights (for buffer handles), model config (for dimensions), arch descriptor (for architecture-specific behavior), chip config (for hardware tuning), KV cache (for cache buffer handles), and scratch buffers (for intermediate buffer handles).
It returns a fully-constructed DispatchTable with all commands, all buffers resolved, all pipelines compiled and cached, and all parameter structs serialized.
The CmdBuilder Helper
To reduce boilerplate, the table builder uses a fluent builder:
struct CmdBuilder {
    DispatchCmd cmd;
    DispatchTable& table;

    CmdBuilder& buf(Buffer b, int offset, int index) { ... }
    CmdBuilder& buf(Buffer b, int index) { ... }  // offset defaults to 0
    template <typename T>
    CmdBuilder& params(const T& p, int index) { ... }
    CmdBuilder& label(const char *fmt, ...) { ... }
    CmdBuilder& threads() { cmd.use_dispatch_threads = true; return *this; }
    CmdBuilder& patch(DispatchCmd::PatchType type, int offset1, int offset2 = 0) { ... }
    CmdBuilder& tg_mem(int bytes, int index) { ... }
    void emit() { table.commands.push_back(cmd); }
};
Usage:
CmdBuilder(table, attn_pso, Dim3(1, n_heads), Dim3(1024))
    .buf(scratch.qkv, scratch.qkv_q_offset, 0)
    .buf(kv_cache.k_buffers[layer], 1)
    .buf(kv_cache.v_buffers[layer], 2)
    .buf(scratch.attn_out, 3)
    .params(attn_params, 4)
    .patch(DispatchCmd::PATCH_KV_SEQ_LEN,
           offsetof(decltype(attn_params), kv_seq_len))
    .label("layer.%d.attention", layer)
    .emit();
The Complete Forward Pass
The table builder emits commands in this order:
Dispatch Table Command Sequence (one token forward pass):
1. EMBEDDING: token_ids[i] -> h0 [PATCH_TOKEN_OFFSET]
(embedding_scale — Gemma only)
2. INITIAL RMSNORM: h0 -> residual
3. LAYER LOOP (x n_layers):
├─ 3a. QKV GEMV (fused or 3 separate)
├─ 3b. QK-Norm + RoPE + KV Write [PATCH_POSITION]
├─ 3c. Flash Attention [PATCH_KV_SEQ_LEN]
├─ 3d. O Projection GEMV
├─ (post_attn_norm — Gemma only)
├─ 3e. Fused Residual + FFN Norm
├─ 3f. Gate+Up GEMV (fused or 2 separate)
├─ 3g. Fused SiLU+Down GEMV (or separate)
├─ (post_ffn_norm — Gemma only)
└─ 3h. Fused Next Attention Norm
4. OUTPUT NORM: fused residual + RMSNorm
5. LOGIT PROJECTION: residual -> vocab logits
6. ARGMAX: logits -> token_ids[i+1] [PATCH_ARGMAX_OUTPUT]
One Layer’s Commands in Detail
For a 32-layer Llama 3 8B model with Q4_0 quantization, fused QKV, and fused SiLU+down, one layer produces approximately 7 commands:
| # | Command | Pipeline | Patch Type | Buffers |
|---|---|---|---|---|
| 1 | QKV GEMV (fused) | gemv_q4_0 | NONE | residual, fused_qkv_w, qkv_buf |
| 2 | RoPE + KV Write | rope_qkv_write_f16 | POSITION | qkv_buf, k_cache[L], v_cache[L] |
| 3 | Flash Attention | flash_attn_decode_par_f16 | KV_SEQ_LEN | qkv_buf, k_cache[L], v_cache[L], attn_out |
| 4 | O Projection | gemv_q4_0 | NONE | attn_out, o_weight, residual |
| 5 | Fused Residual+FFN Norm | residual_rmsnorm_f16 | NONE | residual, h0, ffn_norm_w, h1, attn_out |
| 6 | Gate+Up GEMV (fused) | gemv_q4_0 | NONE | attn_out, fused_gu_w, ffn_gate |
| 7 | Fused SiLU+Down | gemv_q4_0_silu | NONE | ffn_gate, ffn_gate+off, down_w, residual |
Plus the fused-next-attention-norm for the transition to the next layer (1 more command). So roughly 8 commands per layer.
For 32 layers: 8 * 32 = 256 per-layer commands. Plus embedding (1), initial norm (1), output norm (1), logit projection (1), argmax (1) = 261 total commands.
The actual count varies by architecture and fusion decisions:
- Llama (fused): ~7-8 commands/layer, ~230-260 total
- Llama (unfused): ~10-12 commands/layer, ~330-390 total
- Gemma (fused): ~9-10 commands/layer (extra post-norms), ~300-330 total
- Qwen3 (fused): ~7-8 commands/layer, ~230-260 total
The table builder prints the count at the end:
printf("Dispatch table built: %zu commands per token\n", cmds.size());
The Replay Loop: encode_chain
The generic replay function in dispatch_table.h is straightforward:
inline void encode_chain(Device& device, const DispatchTable& table,
                         int start_position, int count) {
    const auto& cmds = table.commands;
    const int n_cmds = (int)cmds.size();
    for (int tok = 0; tok < count; tok++) {
        int pos = start_position + tok;
        for (int c = 0; c < n_cmds; c++) {
            const auto& cmd = cmds[c];
            device.set_pipeline(cmd.pso);
            // Set buffers (with per-token offset patching)
            for (int b = 0; b < cmd.buffer_count; b++) {
                int offset = cmd.offsets[b];
                if (cmd.patch_type == DispatchCmd::PATCH_TOKEN_OFFSET && b == 0)
                    offset = tok * 4;
                if (cmd.patch_type == DispatchCmd::PATCH_ARGMAX_OUTPUT && b == 1)
                    offset = (tok + 1) * 4;
                device.set_buffer(cmd.buffers[b], offset, b);
            }
            // Set params (with position patching)
            if (cmd.param_size > 0) {
                if (needs_patching(cmd.patch_type)) {
                    uint8_t patched[64];
                    memcpy(patched, cmd.param_bytes, cmd.param_size);
                    patch_position(patched, cmd, pos);
                    device.set_bytes(patched, cmd.param_size, cmd.param_index);
                } else {
                    device.set_bytes(cmd.param_bytes, cmd.param_size, cmd.param_index);
                }
            }
            // Secondary params, threadgroup memory
            if (cmd.param2_size > 0)
                device.set_bytes(cmd.param2_bytes, cmd.param2_size, cmd.param2_index);
            if (cmd.tg_mem_bytes > 0)
                device.set_threadgroup_memory(cmd.tg_mem_bytes, cmd.tg_mem_index);
            // Dispatch
            if (cmd.use_dispatch_threads)
                device.dispatch_threads(cmd.grid, cmd.threadgroup);
            else
                device.dispatch(cmd.grid, cmd.threadgroup);
        }
    }
}
This is the generic version that works through the Device virtual interface. It is correct but slow – each device.set_pipeline() call goes through a virtual function dispatch.
The Metal Fast Path
MetalDevice overrides encode_dispatch_table to eliminate virtual calls:
void MetalDevice::encode_dispatch_table(const void *table_ptr,
                                        int start_position, int count) {
    const auto& table = *static_cast<const DispatchTable *>(table_ptr);
    const DispatchCmd *__restrict cmds = table.commands.data();
    const int n_cmds = (int)table.commands.size();
    id<MTLComputeCommandEncoder> enc = STATE.encoder;
    uint8_t patched[64];  // reused stack buffer
    for (int tok = 0; tok < count; tok++) {
        const uint32_t pos = (uint32_t)(start_position + tok);
        const uint32_t kv_len = pos + 1;
        for (int c = 0; c < n_cmds; c++) {
            const DispatchCmd &__restrict cmd = cmds[c];
            const auto pt = cmd.patch_type;
            [enc setComputePipelineState:cmd.pso.handle];
            // Buffers: switch on patch type for offset patching
            switch (pt) {
            case DispatchCmd::PATCH_TOKEN_OFFSET:
                [enc setBuffer:cmd.buffers[0].handle offset:tok*4 atIndex:0];
                for (int b = 1; b < cmd.buffer_count; b++)
                    [enc setBuffer:cmd.buffers[b].handle offset:cmd.offsets[b] atIndex:b];
                break;
            case DispatchCmd::PATCH_ARGMAX_OUTPUT:
                [enc setBuffer:cmd.buffers[0].handle offset:cmd.offsets[0] atIndex:0];
                [enc setBuffer:cmd.buffers[1].handle offset:(tok+1)*4 atIndex:1];
                for (int b = 2; b < cmd.buffer_count; b++)
                    [enc setBuffer:cmd.buffers[b].handle offset:cmd.offsets[b] atIndex:b];
                break;
            default:
                for (int b = 0; b < cmd.buffer_count; b++)
                    [enc setBuffer:cmd.buffers[b].handle offset:cmd.offsets[b] atIndex:b];
                break;
            }
            // Params: setBytes for patched, setBuffer for static
            if (cmd.param_size > 0) {
                switch (pt) {
                case DispatchCmd::PATCH_POSITION:
                    memcpy(patched, cmd.param_bytes, cmd.param_size);
                    *(uint32_t *)(patched + cmd.patch_offset_1) = pos;
                    [enc setBytes:patched length:cmd.param_size atIndex:cmd.param_index];
                    break;
                case DispatchCmd::PATCH_POS_AND_KV:
                    memcpy(patched, cmd.param_bytes, cmd.param_size);
                    *(uint32_t *)(patched + cmd.patch_offset_1) = pos;
                    *(uint32_t *)(patched + cmd.patch_offset_2) = kv_len;
                    [enc setBytes:patched length:cmd.param_size atIndex:cmd.param_index];
                    break;
                case DispatchCmd::PATCH_KV_SEQ_LEN:
                    memcpy(patched, cmd.param_bytes, cmd.param_size);
                    *(uint32_t *)(patched + cmd.patch_offset_1) = kv_len;
                    [enc setBytes:patched length:cmd.param_size atIndex:cmd.param_index];
                    break;
                default:
                    if (cmd.param_buf.handle)
                        [enc setBuffer:cmd.param_buf.handle offset:0 atIndex:cmd.param_index];
                    else
                        [enc setBytes:cmd.param_bytes length:cmd.param_size atIndex:cmd.param_index];
                    break;
                }
            }
            // Secondary params + threadgroup memory + dispatch
            // ...
        }
    }
}
Key optimizations in the fast path:
- No virtual calls. Direct ObjC message sends to the encoder.
- `__restrict` pointers. Hints to the compiler that `cmds` does not alias anything else.
- Stack-allocated patched buffer. Reused across all commands, no heap allocation.
- Switch on patch_type. The common case (`PATCH_NONE`) falls through to the simple buffer-binding loop.
- Pre-allocated param buffers. Static params use `setBuffer` (zero per-token work) instead of `setBytes` (which copies data into the command buffer).
The setBuffer/setBytes Split
This deserves special attention because it is the most subtle optimization in the dispatch table.
Most commands (~60%) have PATCH_NONE – their parameters do not change between tokens. For these, the table builder pre-allocates a GPU buffer containing the parameter data:
// At the end of build_dispatch_table():
// At the end of build_dispatch_table():
for (auto& cmd : cmds) {
    if (cmd.param_size > 0) {
        cmd.param_buf = device.allocate(cmd.param_bytes, cmd.param_size);
    }
}
At replay time, these commands use setBuffer to bind the pre-allocated buffer. On UMA, this is essentially free – Metal just records a pointer in the command buffer. No data copy, no coherency cost.
For position-patched commands (~40%), setBytes is used. This copies the patched parameter data (up to 64 bytes) into the command buffer inline. This is also fast (<4KB threshold for Metal inline data), but it does involve a memcpy and a few bytes of command buffer space per dispatch.
The net effect: for a 260-command table, approximately 160 commands use zero-cost setBuffer and 100 commands use low-cost setBytes. The total per-token CPU overhead for parameter binding is roughly:
$$100 \text{ commands} \times 64 \text{ bytes/memcpy} = 6.4 \text{ KB of memcpy}$$
At memcpy throughput of ~50 GB/s on Apple Silicon, this takes approximately 0.1 microseconds. Negligible.
Chain Decode: The Complete Picture
Let me trace through a complete chain decode of 128 tokens to show how everything fits together:
Chain Decode: 128 Tokens in One Submission
[1] CPU: Write first token ID to token_ids[0]
[2] CPU: device.begin_encoding()
[3] CPU: encode_dispatch_table(table, start_pos, 128)
─── 260 commands x 128 tokens = 33,280 Metal API calls ───
[4] CPU: device.end_encoding_sync()
[5] GPU: executes all 33,280 dispatches sequentially
Token 0: embed -> 32 layers -> logit -> argmax -> token_ids[1]
Token 1: embed(token_ids[1]) -> 32 layers -> ... -> token_ids[2]
...
Token 127: ... -> token_ids[128]
[6] CPU: Read 128 tokens from token_ids[1..128]
The crucial insight: tokens are chained on the GPU. Token 0’s argmax writes its output to token_ids[1]. Token 1’s embedding reads from token_ids[1]. This data dependency is resolved by the GPU’s sequential execution within a single command buffer – no CPU round-trip between tokens.2
The CPU’s only job is:
- Write the first token (4 bytes)
- Encode the command sequence (~2-5ms of Metal API calls for 128 tokens)
- Submit and wait (~14ms * 128 = ~1.8 seconds of GPU time)
- Read the results (128 * 4 = 512 bytes)
Pre-Allocated Parameter Buffers
After building all commands, the table builder pre-allocates GPU buffers for every command’s parameters:
for (auto& cmd : cmds) {
    if (cmd.param_size > 0) {
        cmd.param_buf = device.allocate(cmd.param_bytes, cmd.param_size);
    }
}
This creates one small Metal buffer (16-64 bytes) per command. For 260 commands, that is about 16 KB of additional GPU memory. The benefit is that setBuffer at replay time is a simple pointer bind, while setBytes would copy data into the command buffer.
For position-patched commands, param_buf is still allocated but is not used during replay (the patched version is passed via setBytes instead). This is a minor waste (~6 KB) but simplifies the code.
Profiling Labels
Labels are stored separately to keep the hot-path array dense:
struct DispatchLabel {
    char text[48];
};

struct DispatchTable {
    std::vector<DispatchCmd> commands;  // HOT: iterated every token
    std::vector<DispatchLabel> labels;  // COLD: accessed only during profiling
};
The label for command i is labels[i].text. Labels are set during table building:
table.set_last_label("layer.0.attention");
And accessed during profiling:
const char *label = table.label_at(cmd_index);
The finalize_labels() call at the end of table building pads the labels vector to match the commands vector size, ensuring safe index access.
Comparison to Other Engines
To appreciate the dispatch table approach, compare it to how other inference engines handle command construction:
| Engine | Command Construction | Per-Token Overhead |
|---|---|---|
| llama.cpp | Build ggml graph -> schedule -> Metal encode | Graph traversal + scheduling per token |
| MLX | Lazy evaluation DAG -> compile -> Metal encode | DAG construction + compilation per sequence |
| vLLM | PyTorch eager/compiled -> CUDA kernels | Python overhead + CUDA launch per kernel |
| Akunu | Pre-built dispatch table -> flat array replay | Patch ~100 uint32 values + Metal encode |
Akunu’s approach trades flexibility (you cannot dynamically change the model graph) for speed (zero per-token overhead beyond patching and Metal API calls). This is the right trade for inference, where the computation graph is static.
Limitations
The dispatch table approach has some constraints:
- Static graph only. The forward pass must be identical for every token (modulo position patching). Dynamic architectures with variable-length layers or conditional computation would not fit this model.
- Memory overhead. Each command stores buffer handles and param bytes inline. For a large model with many layers, this is ~70 KB – trivial compared to the gigabytes of weights.
- Single-device. The table is built for one device and assumes all buffers are on that device. Multi-device (tensor parallel) would require multiple tables with cross-device synchronization.
- No dynamic batching during decode. The table processes one token at a time (repeated N times). Batching multiple independent sequences would require separate tables or a more complex patching scheme.
Summary
The dispatch table is akunu’s key innovation. It pre-compiles the entire forward pass into a flat array of ~200-260 DispatchCmd structs, each containing a pre-resolved pipeline, pre-bound buffers, pre-serialized parameters, and pre-computed dispatch geometry. At inference time, the replay loop iterates this array, patching only position-dependent fields. The Metal fast path eliminates virtual call overhead, and the setBuffer/setBytes split minimizes per-token parameter binding cost. The result: the CPU spends almost all of its time doing useful Metal API calls, not constructing or resolving commands.
1. The L2 cache on Apple Silicon GPU cores is approximately 256 KB per core cluster (estimated from die analysis). A 70 KB dispatch table fits entirely within a single core cluster’s L2. See Dougall Johnson, “Apple GPU Architecture,” https://dougallj.github.io/applegpu/. ↩
2. Metal guarantees that compute dispatches within a single compute command encoder execute in order. The data dependency between token N’s argmax output and token N+1’s embedding input is automatically satisfied by this ordering guarantee. No explicit barriers are needed. See Apple, “Metal Programming Guide: Command Organization,” https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu. ↩
The Prefill Phase
When you send a prompt to an LLM, something interesting happens before the model starts generating text: every single prompt token gets processed in parallel. This is the prefill phase, and it is architecturally distinct from the token-by-token decode phase that follows. In Akunu, the prefill phase lives in src/core/prefill.cpp and is responsible for transforming a sequence of token IDs into a populated KV cache, a set of logits for the last position, and the first predicted token.
This chapter will walk through the entire prefill pipeline, from the moment token IDs hit the GPU to the moment we get our first generated token back. Along the way, we will see why prefill uses GEMM (matrix-matrix multiply) instead of GEMV (matrix-vector multiply), how chunked prefill works, and why the whole thing is fundamentally different from decode.
Why Prefill Is Not Just “Decode but Faster”
At first glance, you might think: “Well, prefill is just running the model on all the prompt tokens. Can’t we just call decode N times?” Technically yes, but doing so would be absurdly slow. The key insight is that prefill processes all tokens simultaneously through the transformer layers, whereas decode processes one token at a time.1
The difference comes down to arithmetic intensity:
| Phase | Operation | Matrix Shape | Bottleneck |
|---|---|---|---|
| Prefill | GEMM | [seq_len, K] @ [N, K]^T | Compute-bound |
| Decode | GEMV | [1, K] @ [N, K]^T | Memory-bound |
During prefill with a prompt of length S, the activation tensor is [S, dim] rather than [1, dim]. This means the weight matrix gets reused across S rows of activations, giving us an arithmetic intensity that scales with S. A 2048-token prefill reuses every weight element 2048 times instead of once. This is why prefill can achieve throughput measured in thousands of tokens per second, while decode is measured in tens of tokens per second.2
In Akunu’s code, this shows up directly. The encode_prefill function signature tells the story:
uint32_t encode_prefill(Device& device, WeightProvider& weights,
                        const AkunuModelConfig& cfg, const ArchDescriptor& arch,
                        KVCache& kv_cache, ScratchBuffers& scratch,
                        const uint32_t *token_ids, int seq_len,
                        int start_position = 0);
It takes the full array of token_ids and seq_len, processes everything in one shot, and returns a single uint32_t: the predicted next token.
The Prefill Pipeline at a Glance
Before diving into the code, let’s lay out the full pipeline. Every token in the prompt passes through these stages:
┌────────────────────────────────────────────┐
│ PREFILL PIPELINE │
├────────────────────────────────────────────┤
│ │
token_ids[S] ───>│ 1. BATCH EMBEDDING [S] → [S, dim] │
│ 2. EMBEDDING NORM (if needed) │
│ 3. FIRST LAYER NORM (RMSNorm/LayerNorm) │
│ │
│ ┌─── LAYER LOOP (×N layers) ───────────┐ │
│ │ a. QKV GEMM projections │ │
│ │ b. QK-norm (if model has it) │ │
│ │ c. Fused RoPE + KV cache write │ │
│ │ d. Flash Attention (prefill kernel) │ │
│ │ e. Output GEMM projection │ │
│ │ f. Fused residual + FFN norm │ │
│ │ g. Gate GEMM + Up GEMM │ │
│ │ h. SiLU activation │ │
│ │ i. Down GEMM projection │ │
│ │ j. Fused residual + next attn norm │ │
│ └───────────────────────────────────────┘ │
│ │
│ 4. OUTPUT NORM (final RMSNorm) │
│ 5. LOGIT GEMV (last position only) │
│ 6. ARGMAX → next token │
│ │
└────────────────────────────────────────────┘
Notice step 5: even though the entire prompt flows through the transformer in parallel, we only need logits for the last position. This is because autoregressive generation predicts the next token, and the “next” token after the prompt corresponds to the last position’s output. Akunu exploits this by using a GEMV (single-row projection) for the logit computation instead of a full GEMM.3
Chunked Prefill
Long prompts cannot always be processed in a single batch. GPU memory for scratch buffers is proportional to seq_len * dim, and for a 128K-token prompt on a model with dim=4096, that is over 1 GB just for one activation buffer. Akunu addresses this with chunked prefill: the prompt is split into chunks of at most max_prefill_chunk tokens (default: 4096), and each chunk is processed separately.
Here is the chunking logic from run_decode_loop:
int chunk_size = state.scratch.max_prefill_chunk;
uint32_t next_token = 0;
int prefill_pos = 0;
while (prefill_pos < n_prompt) {
    int chunk = std::min(chunk_size, n_prompt - prefill_pos);
    next_token = encode_prefill(*state.device, *state.weights, state.config,
                                state.arch, state.kv_cache, state.scratch,
                                prompt_tokens + prefill_pos, chunk,
                                start_pos + prefill_pos);
    prefill_pos += chunk;
}
Each call to encode_prefill processes chunk tokens starting at position start_pos + prefill_pos. The start_position parameter is critical: it tells the RoPE kernel where in the sequence these tokens actually belong, and it tells the KV cache where to write the K and V entries. Without correct position tracking, chunked prefill would produce garbage because the positional embeddings would be wrong.
The chunk size of 4096 is configured in ChipConfig:
c.max_prefill_chunk = 4096;
This value balances GPU utilization (larger chunks mean more parallelism and better GEMM efficiency) against memory pressure (scratch buffers must be allocated for the maximum chunk size).
Stage 1: Batch Embedding
The first step converts token IDs to dense vectors. Akunu dispatches an embedding lookup kernel that reads from the embedding weight table and writes to the batch_h0 scratch buffer:
Pipeline pso = device.get_pipeline(device.embedding_kernel_for(emb_dtype));
device.set_pipeline(pso);
device.set_buffer(scratch.token_ids, 0, 0); // input: token IDs
device.set_buffer(emb_weight, 0, 1); // weights: [vocab, dim]
device.set_buffer(scratch.batch_h0, 0, 2); // output: [seq_len, dim]
device.dispatch_threads(Dim3(dim, seq_len), Dim3(std::min(dim, 256)));
The dispatch is a simple 2D grid: one thread per (dimension, token) pair. Each thread copies one element from the embedding table, so the entire [seq_len, dim] output matrix is filled in a single dispatch.
For models with embedding scaling (like Gemma), an additional kernel multiplies every element by the scale factor:
if (arch.embedding_scale > 0.0f) {
    Pipeline scale_pso = device.get_pipeline("temperature_scale_f16");
    // ... applies: batch_h0[i] *= embedding_scale for all elements
}
This reuses the temperature_scale_f16 kernel (originally written for sampling) – a nice example of kernel reuse. The same “multiply every element by a scalar” operation appears in multiple contexts.
MLX Quantized Embeddings
When the model uses MLX-format quantized weights, even the embedding table is quantized. The embedding lookup kernel (embedding_lookup_mlx_q4) must dequantize on-the-fly:
const uint32_t word = W_u32[token_id * n_u32_per_row + u32_idx];
const uint qval = (word >> (within * bits)) & mask;
output[token_idx * K + d_idx] = half(s * float(qval) + b);
Each 32-bit word packs 32/bits quantized values. The kernel extracts the relevant nibble (or bit group), multiplies by the per-group scale, adds the per-group bias, and writes the dequantized FP16 value. The cost is negligible – embedding lookup is always memory-bound, and the few extra ALU ops for dequantization are effectively free.
Stage 2: Normalization
Before entering the layer loop, the activations must be normalized. Akunu supports two paths:
| Model Type | Normalization | Has Bias? |
|---|---|---|
| Standard LLM (Llama, Gemma, etc.) | RMSNorm | No |
| BERT/Encoder models | LayerNorm | Yes |
The choice is driven by whether the layer’s norm weight has an associated bias tensor:
Buffer norm_b = weights.get_tensor("layers.0.attention_norm.bias");
if (norm_b.handle) {
    // LayerNorm path (BERT)
    Pipeline pso = device.get_pipeline("layernorm_f16");
} else {
    // RMSNorm path (standard LLM)
    Pipeline pso = device.get_pipeline("rmsnorm_f16");
}
Both kernels operate on batch_h0 (shape [seq_len, dim]) and produce batch_residual. The dispatch is Dim3(seq_len) threadgroups, each with up to 1024 threads – one threadgroup per row (token position), with threads cooperatively computing the norm statistics.
Stage 3: The Layer Loop
The layer loop is where the bulk of prefill computation happens. For each of the n_layers transformer layers, we execute approximately 10 GPU dispatches. Let’s walk through each.
3a. QKV Projections
The attention mechanism needs three projections: Q (query), K (key), and V (value). Akunu supports two approaches:
Fused QKV (BERT-style): A single GEMM produces the concatenated [Q|K|V] output, which is then split with a GPU kernel:
dispatch_gemm(device, scratch.batch_residual, qkv_w, scratch.batch_gate,
              seq_len, qkv_dim, dim, qkv_dtype, ...);
// Then split on GPU:
Pipeline split_pso = device.get_pipeline("qkv_split_f16");
This is more efficient when Q, K, and V share the same weight dtype and the combined projection fits nicely into a single GEMM dispatch.
Separate Q/K/V (standard LLM): Three independent GEMMs:
dispatch_gemm(device, scratch.batch_residual, q_w, scratch.batch_q,
              seq_len, q_dim, dim, q_dtype, ...);
dispatch_gemm(device, scratch.batch_residual, k_w, scratch.batch_k,
              seq_len, kv_dim, dim, k_dtype, ...);
dispatch_gemm(device, scratch.batch_residual, v_w, scratch.batch_v,
              seq_len, kv_dim, dim, v_dtype, ...);
For GQA models where kv_dim < q_dim, separate projections are actually better because the K and V GEMMs are smaller and can be dispatched with appropriately sized grids.
3b. The GEMM Dispatch Function
Every projection in prefill goes through dispatch_gemm, which is the central GEMM dispatcher. It handles:
- Kernel selection: The dtype descriptor tells it which kernel to use (e.g., `simd_gemm_f16`, `simd_gemm_q4_0`).
- Small-M optimization: For `M` between 2 and 8, it uses the “small” GEMM variant with `TM=8` instead of `TM=32`, avoiding wasted computation on padding.
- Function constant specialization: When `K` is a multiple of 32, the kernel is specialized with `K` baked in as a compile-time constant. This eliminates a register and enables loop unrolling.
- MLX format handling: MLX quantized weights need extra params (`group_size`, `bits`, `weight_bytes`).
The tile geometry is fixed:
| Parameter | Value | Meaning |
|---|---|---|
| TM | 32 (or 8 for small M) | Activation rows per tile |
| TN | 64 | Weight rows per tile |
| TK | 32 | K-dimension per accumulation step |
| Threadgroup | (32, 4) = 128 threads | 4 SIMD groups |
The grid is computed as:
int gridX = (N + TN - 1) / TN;
int gridY = (M + TM - 1) / TM;
device.dispatch(Dim3(gridX, gridY), Dim3(32, 4));
Threadgroup memory is allocated for the cooperative tile loading:
int loadBytes = (TN * TK + TM * TK) * 2; // weight + activation tiles in FP16
int storeBytes = TN * TM * 4; // output tile in FP32
int tgMem = std::max(loadBytes, storeBytes); // reuse same memory
The max here is clever: during the accumulation phase, the memory holds the input tiles; during the output phase, it holds the result tile. They never overlap in time, so the same memory region serves both purposes.4
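Plugging concrete numbers into this sizing logic makes the geometry tangible. A standalone sketch (`plan_gemm` is an illustrative name, not Akunu's API) for M = 512 activation rows against an N = 4096-row weight:

```cpp
#include <algorithm>

// Grid and threadgroup-memory sizing for the tiled GEMM, using the fixed
// TM/TN/TK geometry from the table above.
struct GemmLaunch { int gridX, gridY, tgMem; };

GemmLaunch plan_gemm(int M, int N, int TM = 32, int TN = 64, int TK = 32) {
    int loadBytes  = (TN * TK + TM * TK) * 2; // FP16 weight + activation tiles
    int storeBytes = TN * TM * 4;             // FP32 output tile
    return { (N + TN - 1) / TN, (M + TM - 1) / TM,
             std::max(loadBytes, storeBytes) };
}
```

For M=512, N=4096 this gives a 64x16 grid, load tiles of 6144 bytes, and an output tile of 8192 bytes, so the FP32 output phase determines the 8 KB threadgroup allocation.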
3c. QK-Norm
Some models (notably DeepSeek, Gemma 2) apply RMSNorm to the Q and K projections per head before attention. This ensures that the dot product scores stay in a reasonable range regardless of head dimension:
if (arch.has_qk_norm) {
Pipeline hn_pso = device.get_pipeline("head_rmsnorm_f16");
// Q norm: grid = (n_heads, seq_len), threads = head_dim
device.dispatch(Dim3(n_heads, seq_len), Dim3(head_dim));
// K norm: grid = (n_kv_heads, seq_len), threads = head_dim
device.dispatch(Dim3(n_kv_heads, seq_len), Dim3(head_dim));
}
The dispatch geometry is interesting: one threadgroup per (head, position) pair, with head_dim threads per group. Each threadgroup normalizes exactly one head’s worth of data.
3d. Fused RoPE + KV Cache Write
This is one of the most performance-sensitive dispatches in prefill. A single kernel handles three operations simultaneously:
- Apply RoPE to Q (in-place)
- Apply RoPE to K and write to the K cache
- Copy V to the V cache
const char *fused_kernel = is_neox
? "rope_neox_batch_kv_write_f16"
: "rope_batch_kv_write_f16";
Two RoPE variants are supported: the original “interleaved” layout (pairs of adjacent elements are rotated) and the “neox” layout (first and second halves are rotated). The fused kernel processes all seq_len positions in parallel, writing each K/V vector to its correct position in the KV cache based on start_position + position_within_batch.
The dispatch grid is:
device.dispatch_threads(Dim3(head_dim / 2, n_heads, seq_len),
Dim3(std::min(head_dim / 2, 32)));
Each thread handles one complex pair (two elements) of one head at one position.
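The rotation each thread performs is a standard complex multiply; the two layouts differ only in which element indices form a pair. A CPU sketch (assuming the usual theta base of 10000; this is not Akunu's kernel code):

```cpp
#include <cmath>
#include <vector>

// One head's RoPE rotation. In the interleaved layout, pair i is elements
// (2i, 2i+1); in the neox layout it is (i, i + head_dim/2).
void rope_rotate(std::vector<float>& head, int pos, bool neox,
                 float theta_base = 10000.0f) {
    int hd = (int)head.size();
    for (int i = 0; i < hd / 2; i++) {
        int i0 = neox ? i : 2 * i;
        int i1 = neox ? i + hd / 2 : 2 * i + 1;
        float freq  = std::pow(theta_base, -2.0f * i / hd);
        float angle = pos * freq;
        float c = std::cos(angle), s = std::sin(angle);
        float x = head[i0], y = head[i1];
        head[i0] = x * c - y * s; // rotate the complex pair by `angle`
        head[i1] = x * s + y * c;
    }
}
```

Note that the rotation is norm-preserving per pair, which is why RoPE can be applied in-place without rescaling.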
3e. Flash Attention (Prefill)
Prefill attention is where things get really interesting. Unlike decode attention (where each query has one row), prefill attention has seq_len query rows, all attending to kv_seq_len key-value positions. Akunu selects between three attention kernels based on sequence length:
┌─────────────────────┐
│ seq_len check │
└──────┬──────────────┘
│
┌────────────────┼────────────────┐
│ │ │
seq_len >= 1024 seq_len >= thresh otherwise
│ │ │
▼ ▼ ▼
┌─────────────────┐ ┌──────────┐ ┌────────────────┐
│ Prefill V2 │ │ Prefill │ │ Decode kernel │
│ BQ=32, register │ │ V1 │ │ (per-query TG) │
│ output, exp2 │ │ simd MMA │ │ │
└─────────────────┘ └──────────┘ └────────────────┘
Prefill V2 (flash_attention_prefill_v2_f16): For long sequences (>= 1024), this kernel processes 32 query rows per threadgroup using simdgroup matrix multiply-accumulate (MMA). Output stays in registers rather than threadgroup memory, saving 16KB per threadgroup. It uses exp2 instead of exp for faster softmax computation.5
Prefill V1 (flash_attention_prefill_f16): For medium sequences, this kernel processes 8 or 16 query rows per threadgroup with simdgroup MMA and threadgroup memory for the output accumulator.
Decode fallback: For very short sequences (< threshold), it simply dispatches one threadgroup per query position using the decode attention kernel. The threshold is:
int nq_rows = (head_dim <= 64) ? 16 : (head_dim <= 128) ? 8 : 0;
int v2_threshold = (nq_rows > 0) ? std::max(nq_rows * 2, 16) : INT_MAX;
All three support non-causal attention for encoder models (BERT) via function constant specialization:
constant bool FC_NON_CAUSAL [[function_constant(3)]];
When FC_NON_CAUSAL is true, the causal mask is skipped, allowing every position to attend to every other position.
3f-3j. FFN Block
After attention, the FFN block follows the standard SwiGLU pattern:
1. Fused residual + FFN norm: Adds the attention output to the residual stream and normalizes. Uses `residual_rmsnorm_f16` for standard LLMs or decomposed `vector_add_f16` + `layernorm_f16` for BERT.
2. Gate GEMM: `batch_attn_out [S, dim]` -> `batch_gate [S, ffn_dim]`
3. Up GEMM: `batch_attn_out [S, dim]` -> `batch_up [S, ffn_dim]`
4. SiLU activation: `batch_act = SiLU(batch_gate) * batch_up` element-wise
5. Down GEMM: `batch_act [S, ffn_dim]` -> `batch_residual [S, dim]`
6. Fused residual + next attn norm: Prepares the residual stream for the next layer.
The Gate and Up GEMMs read from the same input buffer, which is great for cache locality on the GPU side – the activation data loaded for Gate is still in the SLC when Up runs.
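The element-wise activation step is small enough to write out in full. A scalar reference of the SwiGLU combination (a sketch of the math, not the fused GPU kernel):

```cpp
#include <cmath>
#include <vector>

// SwiGLU element-wise step: act[i] = SiLU(gate[i]) * up[i],
// where SiLU(x) = x * sigmoid(x).
std::vector<float> swiglu(const std::vector<float>& gate,
                          const std::vector<float>& up) {
    std::vector<float> act(gate.size());
    for (size_t i = 0; i < gate.size(); i++) {
        float silu = gate[i] / (1.0f + std::exp(-gate[i]));
        act[i] = silu * up[i];
    }
    return act;
}
```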
Post-Attention and Post-FFN Norms (Gemma 3)
Some architectures add extra normalization after attention and/or FFN outputs. Akunu handles these via descriptor flags:
if (arch.has_post_attn_norm) {
// RMSNorm on batch_residual → batch_post_norm
}
if (arch.has_post_ffn_norm) {
// RMSNorm on batch_residual → batch_post_norm
}
This is driven by ArchDescriptor, so adding support for a new model that uses post-norms requires zero code changes – just setting the descriptor flags.
Stage 4: Output Norm
After the layer loop, the final hidden states need one more normalization. For standard LLMs, this is a fused residual + RMSNorm:
Pipeline pso = device.get_pipeline("residual_rmsnorm_f16");
device.set_buffer(last_fused_input_ffn, 0, 0); // last layer's output
device.set_buffer(scratch.batch_h1, 0, 1); // residual stream
device.set_buffer(output_norm_w, 0, 2); // norm weights
device.set_buffer(scratch.batch_h0, 0, 3); // updated residual
device.set_buffer(scratch.batch_residual, 0, 4); // normed output
For embedding models (BERT), there is no output norm – the final residual is directly used for mean-pooling:
if (arch.is_embedding_model) {
// vector_add: batch_residual = last_fused_input_ffn + batch_h1
device.end_encoding_sync();
kv_cache.advance(seq_len);
return 0; // no logit token for embedding models
}
Stage 5: Logit Projection (Last Position Only)
Here is where prefill gets clever. We have batch_residual of shape [seq_len, dim], but we only need logits for the last token. Instead of running a full GEMM ([seq_len, dim] @ [vocab_size, dim]^T), Akunu extracts just the last row and runs a GEMV:
int last_row_offset = (seq_len - 1) * dim * 2; // byte offset
device.set_buffer(scratch.batch_residual, last_row_offset, 0);
device.set_buffer(logit_w, 0, 1);
device.set_buffer(scratch.logits, 0, 2);
device.dispatch(Dim3(n_groups), Dim3(logit_dt.gemv_tg_size));
The last_row_offset is used as a buffer offset, so the GEMV kernel sees only the last row’s data starting at buffer index 0. This converts a potentially massive [seq_len, vocab_size] GEMM into a single [1, vocab_size] GEMV – a huge savings when vocab_size is 128K+ tokens.
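The savings are easy to quantify. Counting multiply-accumulates for the output projection (a back-of-envelope sketch, using Llama 3's 128,256-token vocabulary as the example):

```cpp
#include <cstdint>

// Output-projection cost in multiply-accumulates: a full GEMM over all
// positions costs seq_len * vocab * dim; the last-row GEMV costs vocab * dim.
uint64_t logit_macs(uint64_t rows, uint64_t vocab, uint64_t dim) {
    return rows * vocab * dim;
}
```

For seq_len=4096, vocab=128256, dim=4096, the GEMV is exactly seq_len times cheaper than the full GEMM: about 0.5 GMAC instead of 2.2 TMAC.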
For models with tied embeddings (where the output projection reuses the embedding table), the same weight buffer is used:
const char *logit_name = arch.tie_embeddings
? "token_embedding.weight"
: "output.weight";
Stage 6: Argmax
The final step finds the most probable next token:
Pipeline pso = device.get_pipeline("argmax_f16");
device.set_pipeline(pso);
device.set_buffer(scratch.logits, 0, 0);
device.set_buffer(scratch.token_ids, 0, 1);
uint32_t vocab = cfg.vocab_size;
device.set_bytes(&vocab, sizeof(vocab), 2);
device.dispatch(Dim3(1), Dim3(1024));
One threadgroup of 1024 threads. Each thread scans a strided portion of the logits, finds its local maximum, then a two-level SIMD reduction finds the global maximum. The winning token ID is written to scratch.token_ids[0].
After end_encoding_sync() (which waits for GPU completion), the result is read back:
device.end_encoding_sync();
kv_cache.advance(seq_len);
return ((uint32_t *)scratch.token_ids.contents)[0];
Note kv_cache.advance(seq_len): this updates the cache’s write pointer so that subsequent decode operations know where the cached K/V data ends.
Scratch Buffer Layout
Prefill requires a constellation of temporary GPU buffers. Here is what each one holds:
| Buffer | Shape | Used For |
|---|---|---|
batch_h0 | [S, dim] | Embedding output, residual stream |
batch_h1 | [S, dim] | Residual stream (second copy) |
batch_residual | [S, dim] | Normed activations, GEMM input |
batch_q | [S, q_dim] | Query projections |
batch_k | [S, kv_dim] | Key projections |
batch_v | [S, kv_dim] | Value projections |
batch_attn_out | [S, q_dim] | Attention output, FFN norm output |
batch_gate | [S, ffn_dim] | Gate projection, also used for fused QKV |
batch_up | [S, ffn_dim] | Up projection |
batch_act | [S, ffn_dim] | SiLU(gate) * up |
batch_post_norm | [S, dim] | Post-attn/FFN norm scratch |
logits | [vocab_size] | Final logit buffer |
token_ids | [max_chunk+1] | Input token IDs + argmax result |
All buffers are allocated at model load time for S = max_prefill_chunk. The total memory is roughly:
S * (5*dim + 2*q_dim + 2*kv_dim + 3*ffn_dim) * 2 bytes
For a typical 7B model (dim=4096, q_dim=4096, kv_dim=1024, ffn_dim=14336) with S=4096, this works out to roughly 600 MB.
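Evaluating the formula directly (a trivial helper, named here for illustration):

```cpp
#include <cstdint>

// Scratch-buffer footprint from the formula above, in bytes (FP16 elements).
uint64_t scratch_bytes(uint64_t S, uint64_t dim, uint64_t q_dim,
                       uint64_t kv_dim, uint64_t ffn_dim) {
    return S * (5 * dim + 2 * q_dim + 2 * kv_dim + 3 * ffn_dim) * 2;
}
```

The 7B example evaluates to 603,979,776 bytes (576 MiB); the logits and token_ids buffers add only a fraction of a megabyte on top.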
The BERT/Encoder Path
Akunu also supports encoder-only models (BERT, nomic-bert) through encode_prefill_bert. The key differences from the LLM path:
| Aspect | LLM Prefill | BERT Prefill |
|---|---|---|
| Positional encoding | RoPE (during attention) | Learned absolute embeddings (added to tokens) |
| Normalization | RMSNorm | LayerNorm (with bias) |
| Attention mask | Causal | Non-causal (all-to-all) |
| FFN activation | SiLU/SwiGLU | GELU (no gate) |
| Linear bias | No | Yes |
| Output | Logits + argmax | Raw hidden states for pooling |
The BERT path is activated when arch.is_embedding_model is true. Most of the code is shared – the layer loop is identical, just using different kernel variants selected by the descriptor system.
Timing and Statistics
Prefill timing is captured for performance reporting:
auto prefill_start = std::chrono::high_resolution_clock::now();
// ... prefill ...
auto prefill_end = std::chrono::high_resolution_clock::now();
double prefill_ms = std::chrono::duration<double, std::milli>(
prefill_end - prefill_start).count();
stats.prefill_time_ms = (float)prefill_ms;
stats.prefill_tokens_per_sec = (float)(n_prompt * 1000.0 / prefill_ms);
Note that prefill_tokens_per_sec divides the total number of prompt tokens by the wall-clock time, including all chunks. This gives the user-facing throughput metric that represents actual prompt processing speed.
Performance Characteristics
Prefill performance is dominated by GEMM throughput, which depends primarily on GPU core count, model size, and quantization format. Typical prefill throughput on Apple Silicon:
| Hardware | Model | Quant | Prompt Tokens/sec |
|---|---|---|---|
| M1 Max (32 GPU) | Llama 3.1 8B | Q4_0 | ~800-1200 |
| M2 Ultra (76 GPU) | Llama 3.1 8B | Q4_0 | ~2500-3500 |
| M4 Pro (20 GPU) | Llama 3.1 8B | Q4_0 | ~1000-1500 |
These numbers are heavily dominated by GEMM throughput, which scales roughly linearly with GPU core count.
Summary
The prefill phase is conceptually simple – run the full transformer on all prompt tokens – but the implementation details matter enormously for performance:
- GEMM not GEMV: Processing all tokens simultaneously gives orders-of-magnitude better arithmetic intensity.
- Chunked execution: Bounded scratch memory via configurable chunk sizes.
- Last-position-only logits: A GEMV instead of a GEMM for the output projection, saving `(seq_len - 1) * vocab_size` unnecessary computations.
- Fused operations: RoPE + KV write, residual + norm, and activation + gating are all fused to minimize GPU dispatch overhead and memory traffic.
- Architecture-driven dispatch: The same code handles LLMs and BERT-style encoders through the descriptor system.
In the next chapter, we will see what happens after prefill completes: the decode loop that generates tokens one at a time.
-
Vaswani, A., et al. “Attention Is All You Need.” NeurIPS 2017. The autoregressive property means each token depends on all previous tokens, which is why generation must be sequential while prompt processing can be parallel. See https://arxiv.org/abs/1706.03762. ↩
-
Pope, R., et al. “Efficiently Scaling Transformer Inference.” MLSys 2023. This paper provides an excellent analysis of the memory-bound vs compute-bound regimes of transformer inference. See https://arxiv.org/abs/2211.05102. ↩
-
This “last position only” optimization is standard in all LLM inference engines. During prefill, all positions produce hidden states, but only the last position’s logits determine the next token. Some engines (e.g., vLLM) allow returning all logits for perplexity evaluation. ↩
-
This is a common GPU programming pattern called “ping-pong buffering” or “buffer aliasing.” The Metal runtime does not enforce temporal aliasing rules for threadgroup memory, so as long as barriers are placed correctly, the same memory can serve different purposes at different times. ↩
-
The `exp2` trick replaces `exp(x)` with `exp2(x * log2(e))`. On Apple Silicon, `fast::exp2` uses the hardware transcendental unit and is faster than `exp`. The scale factor is pre-multiplied into Q during loading. ↩
The Decode Loop
After prefill populates the KV cache and produces the first token, the engine enters the decode loop: the main generation loop that produces tokens one at a time. This is where the user experiences “streaming” – each token appears as soon as it is generated, and the loop continues until a stop condition is met.
In Akunu, the decode loop lives in src/inference/decode_loop.cpp, and the four decode paths it can dispatch are declared in src/inference/decode_paths.h. The loop itself is surprisingly compact – most of the complexity lives in the individual decode paths – but the orchestration logic contains several subtleties worth understanding.
The Four Decode Paths
Akunu does not have a single “decode” function. Instead, it offers four distinct decode paths, each optimized for a different scenario:
┌─────────────────────────┐
│ run_decode_loop() │
│ │
│ After prefill: │
│ Choose decode path │
└──────────┬──────────────┘
│
Has grammar? │
┌────────────────┼────────────────────┐
│ YES │ NO │
▼ │ │
┌─────────────────┐ │ │
│ decode_grammar │ │ temperature > 0? │
│ (constrained) │ ├──────────────────┐ │
└─────────────────┘ │ YES │ NO│
▼ │ │
┌─────────────────┐ │ │
│ decode_sampled │ │ │
│ (Gumbel-max GPU) │ │ │
└─────────────────┘ │ │
│ │
speculation? │ │
┌──────────────┤ │
│ YES │NO│
▼ ▼ │
┌──────────────────┐ ┌──────────────┐
│decode_speculative│ │decode_greedy │
│(n-gram draft) │ │(chain decode) │
└──────────────────┘ └──────────────┘
Here is the decision logic from the source:
if (grammar) {
generated += decode_grammar(state, model, next_token, pos, max_tokens - generated,
sampling, prompt_tokens, n_prompt, *grammar,
callback, user_data);
} else {
bool use_sampling = (sampling.temperature > 0.0f);
if (use_sampling) {
generated += decode_sampled(state, model, next_token, pos,
max_tokens - generated, sampling,
prompt_tokens, n_prompt, callback, user_data);
} else if (state.speculation_enabled) {
generated += decode_speculative(state, model, next_token, pos,
max_tokens - generated, prompt_tokens,
n_prompt, callback, user_data);
} else {
generated += decode_greedy(state, model, next_token, pos,
max_tokens - generated, callback, user_data);
}
}
Let’s characterize each path:
| Path | When Used | GPU Roundtrips | Sampling | CPU Involvement |
|---|---|---|---|---|
decode_greedy | temperature=0, no grammar, no speculation | 1 per chunk | argmax on GPU | Minimal – read token IDs |
decode_sampled | temperature>0, no grammar | 1 per chunk | Gumbel-max on GPU | Minimal – read token IDs |
decode_speculative | temperature=0, speculation enabled | 1 per batch | argmax on GPU | n-gram prediction |
decode_grammar | grammar constraints active | 1 per token | CPU (grammar mask) | Heavy – grammar state per token |
The Decode Loop Entry Point
Let’s walk through run_decode_loop in detail. The function signature reveals the full set of inputs:
AkunuGenerationStats run_decode_loop(
ModelState& state, akunu_model_t model,
const uint32_t *prompt_tokens, int n_prompt,
int start_pos, int max_tokens,
AkunuSamplingConfig sampling,
akunu_token_callback callback, void *user_data,
GrammarHandle *grammar = nullptr);
The callback is what enables streaming. Every time a token is generated, the callback receives both the token ID and its decoded text:
typedef bool (*akunu_token_callback)(uint32_t token, const char *text, void *user_data);
If the callback returns false, generation stops immediately. This is how the host application can implement user cancellation or stop-sequence detection at the application level.
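A minimal host-side example of that pattern: a callback that accumulates decoded text and cancels once a stop string appears. `StreamState` and `stop_callback` are illustrative names, not part of Akunu's API:

```cpp
#include <cstdint>
#include <string>

// Application-level stop-sequence detection via the token callback.
struct StreamState {
    std::string text;  // accumulated output
    std::string stop;  // stop sequence to watch for
};

bool stop_callback(uint32_t token, const char *piece, void *user_data) {
    (void)token;
    StreamState *st = (StreamState *)user_data;
    st->text += piece;
    // Returning false cancels generation as soon as the stop string appears.
    return st->text.find(st->stop) == std::string::npos;
}
```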
Phase 1: Prefill
The first thing the decode loop does is run prefill (covered in the previous chapter):
auto prefill_start = std::chrono::high_resolution_clock::now();
int chunk_size = state.scratch.max_prefill_chunk;
uint32_t next_token = 0;
int prefill_pos = 0;
while (prefill_pos < n_prompt) {
int chunk = std::min(chunk_size, n_prompt - prefill_pos);
next_token = encode_prefill(*state.device, *state.weights, state.config,
state.arch, state.kv_cache, state.scratch,
prompt_tokens + prefill_pos, chunk,
start_pos + prefill_pos);
prefill_pos += chunk;
}
auto prefill_end = std::chrono::high_resolution_clock::now();
The timing is precise: only the prefill computation itself is measured, not the decode loop setup.
Phase 2: First Token Handling
The first generated token comes from the prefill argmax (or from sampling when grammar is active). This token gets special treatment:
int generated = 0;
int pos = start_pos + n_prompt;
bool stopped = false;
if (grammar) {
// Read logits, apply grammar mask, sample/argmax
} else {
// Use prefill's argmax result directly
if (!state.tokenizer.is_eos(next_token)) {
generated++;
if (callback) {
const char *text = decode_token_text(state, next_token);
if (!callback(next_token, text, user_data))
stopped = true;
}
} else {
stopped = true;
}
}
When there is no grammar, the first token is essentially “free” – it came from the prefill’s argmax and requires no additional GPU work. When grammar is active, the first token requires reading the F16 logits from GPU memory, converting to F32, applying the grammar bitmask, and then sampling:
const __fp16 *f16 = (const __fp16 *)state.device->buffer_contents(state.scratch.logits);
for (uint32_t i = 0; i < vocab_count; i++)
logits[i] = (float)f16[i];
// Apply grammar mask
grammar->legacy.apply(logits, vocab_count);
The F16-to-F32 conversion happens on the CPU here. This is one of the cases where the CPU path is unavoidable: grammar bitmasks need to be computed on the CPU (they depend on the grammar state machine), and it is cheaper to convert the logits to F32 on the CPU than to launch another GPU kernel just for the conversion.
Phase 3: Main Decode Loop
After the first token, the selected decode path takes over:
auto decode_start = std::chrono::high_resolution_clock::now();
if (!stopped && generated < max_tokens) {
// ... dispatch to selected path ...
}
auto decode_end = std::chrono::high_resolution_clock::now();
Each decode path returns the number of tokens generated. The decode loop tracks generated and ensures we never exceed max_tokens.
The Grammar-Constrained Path
Grammar-constrained decoding (decode_grammar) is the most complex path because it must synchronize with the grammar state machine on every token. This means:
- Run the forward pass on the GPU
- Read logits back to CPU
- Convert F16 -> F32
- Apply the grammar bitmask (setting disallowed tokens to `-inf`)
- Sample or argmax from the masked logits
- Update the grammar state machine with the accepted token
- Repeat
Akunu supports two grammar backends:
#ifdef AKUNU_HAS_XGRAMMAR
if (grammar->use_xgrammar) {
int bm_size = grammar->xgrammar.bitmask_size();
std::vector<int32_t> bm(bm_size);
grammar->xgrammar.fill_next_token_bitmask(bm.data());
for (uint32_t i = 0; i < vocab_count; i++) {
if (!((bm[i / 32] >> (i % 32)) & 1))
logits[i] = -std::numeric_limits<float>::infinity();
}
} else
#endif
{
grammar->legacy.apply(logits, vocab_count);
}
The XGrammar backend uses a compact bitmask representation: one bit per vocabulary token, packed into 32-bit integers. A token is allowed if its corresponding bit is 1. This is remarkably efficient – for a 128K vocabulary, the bitmask is only 16KB.1
Grammar decode is inherently sequential and CPU-bound. Each token requires a full GPU->CPU->GPU roundtrip, which limits throughput to perhaps 20-30 tokens/sec even on fast hardware. This is why Akunu only uses this path when grammar constraints are explicitly requested.
The Sampled Path (GPU-Driven)
The sampled decode path (decode_sampled) is fascinating because it achieves the same throughput as greedy decoding while sampling from the full probability distribution. The secret: Gumbel-max sampling on the GPU.2
The key insight is that argmax(logit + temperature * Gumbel_noise) is equivalent to sampling from Categorical(softmax(logit / temperature)). This means we can replace the entire CPU sampling pipeline with:
- Apply Gumbel noise to logits (GPU kernel)
- Argmax (GPU kernel)
No CPU roundtrip. No softmax. No random number generation on the CPU.
// The sampled_dispatch_table includes gumbel_temperature + argmax
DispatchTable& table = have_sampled_table
? state.sampled_dispatch_table
: state.dispatch_table;
while (generated < max_tokens) {
int n = std::min(chunk, remaining);
state.device->write_buffer(state.scratch.token_ids, &next_token, 4);
state.device->begin_encoding();
state.device->encode_dispatch_table(&table, pos, n);
state.device->end_encoding_sync();
state.kv_cache.advance(n);
pos += n;
// Read results...
}
The sampled_dispatch_table is identical to the dispatch_table (greedy) except that it has a gumbel_topk_f16 kernel inserted before the argmax. The Gumbel kernel also handles top-k, top-p, min-p, and repetition penalty – all on the GPU.
The sampling parameters are patched into the kernel’s parameter buffer before the loop starts:
auto& gumbel_cmd = cmds[cmds.size() - 2]; // second-to-last command
memcpy(gumbel_cmd.param_bytes + 4, &temp, sizeof(float));
memcpy(gumbel_cmd.param_bytes + 12, &seed_base, sizeof(uint32_t));
memcpy(gumbel_cmd.param_bytes + 16, &top_k, sizeof(int32_t));
memcpy(gumbel_cmd.param_bytes + 20, &top_p, sizeof(float));
The position field in the Gumbel params is patched per-token by the dispatch table’s PATCH_POSITION mechanism, which ensures each token gets a unique seed for the Gumbel noise RNG.
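The Gumbel-max trick fits in a few lines of scalar code. This CPU sketch uses C++'s `std::mt19937` rather than the kernel's counter-based RNG, but the math is the same:

```cpp
#include <cmath>
#include <random>
#include <vector>

// Gumbel-max sampling: argmax(logit + T * g) with g ~ Gumbel(0,1) draws
// from Categorical(softmax(logit / T)). At T = 0 it degenerates to argmax.
int gumbel_sample(const std::vector<float>& logits, float T, std::mt19937& rng) {
    std::uniform_real_distribution<float> u(1e-9f, 1.0f);
    int best = 0;
    float best_v = -1e30f;
    for (size_t i = 0; i < logits.size(); i++) {
        float g = -std::log(-std::log(u(rng))); // Gumbel(0,1) noise
        float v = logits[i] + T * g;
        if (v > best_v) { best_v = v; best = (int)i; }
    }
    return best;
}
```

Because the sample is itself an argmax, the existing GPU argmax kernel can do the final reduction; only the noise-injection step is new.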
The Speculative Path
Speculative decoding (decode_speculative) uses an n-gram predictor to guess multiple tokens ahead, then verifies them in a single batched forward pass:
auto drafts = state.predictor.predict();
int n_draft = std::min((int)drafts.size(), max_tokens - generated - 1);
uint32_t *chain_buf = (uint32_t *)state.device->buffer_contents(
state.scratch.token_ids);
chain_buf[0] = next_token;
for (int i = 0; i < n_draft; i++)
chain_buf[i + 1] = drafts[i];
state.device->begin_encoding();
state.device->encode_dispatch_table(&state.dispatch_table, pos, batch_size);
state.device->end_encoding_sync();
The batch of [1 + n_draft] tokens is processed in a single GPU submission using the chain decode mechanism (covered in the next chapter). The GPU produces an argmax token for each position. If the draft token at position i matches the argmax at position i-1, it was correctly predicted and we get a “free” token. The speculative path can generate 2-3 tokens per GPU submission when the n-gram predictor is accurate.
Stop Conditions
The decode loop checks several stop conditions:
- EOS token: `state.tokenizer.is_eos(tok)` – the model emits an end-of-sequence token.
- Max tokens reached: `generated >= max_tokens`.
- Callback cancellation: the token callback returns `false`.
- Grammar completion: `grammar->xgrammar.is_terminated() || grammar->xgrammar.is_completed()`.
These checks happen after each token (or chunk of tokens for chain decode), ensuring responsive stopping behavior.
Statistics and Timing
The decode loop returns detailed statistics:
AkunuGenerationStats stats = {};
stats.prompt_tokens = n_prompt;
stats.prefill_time_ms = (float)prefill_ms;
stats.generated_tokens = generated;
stats.decode_time_ms = (float)decode_ms;
stats.decode_tokens_per_sec = (float)(generated * 1000.0 / decode_ms);
stats.prefill_tokens_per_sec = (float)(n_prompt * 1000.0 / prefill_ms);
return stats;
The decode_tokens_per_sec metric is the one that matters most for user experience – it determines how fast text appears on screen.
The Token-by-Token Flow
Let’s trace the lifecycle of a single token through the decode loop in the common case (greedy decode, no grammar):
Token lifecycle (greedy, no grammar):
1. Write token_id to GPU buffer
│
▼
2. Encode dispatch table (N commands)
│
▼
3. GPU executes forward pass
│
▼
4. Read result token from buffer
│
▼
5. Stream to callback
Step 2 is the critical one: the dispatch table encodes the entire forward pass as a pre-compiled sequence of GPU commands. The next chapter covers how this dispatch table works and how chain decode amortizes the per-token overhead.
Streaming and Latency
From the user’s perspective, what matters is the time between successive tokens appearing on screen. This is the inter-token latency (ITL), and it is determined by:
ITL = GPU_forward_pass + CPU_token_read + callback_overhead
For chain decode with a chunk size of 64, the GPU processes 64 tokens in one submission, and then the CPU reads all 64 results at once. The effective ITL per token is:
effective_ITL = GPU_time_for_64_tokens / 64
However, the perceived latency for the first token in each chunk is higher because the user sees nothing until the entire chunk completes. This is the tradeoff of chain decode: higher throughput at the cost of chunkier streaming. In practice, with a chunk time of ~500ms for 64 tokens, the burst of tokens appears fast enough that users perceive smooth streaming.
The Dispatch Table Mechanism
The decode loop does not manually encode GPU commands for each token. Instead, it uses a pre-compiled dispatch table (DispatchTable): a flat array of DispatchCmd structs that describe every GPU dispatch needed for one forward pass.
struct DispatchCmd {
Pipeline pso;
Buffer buffers[MAX_BUFFERS];
uint32_t offsets[MAX_BUFFERS];
int buffer_count;
uint8_t param_bytes[64];
int param_size;
// ...
PatchType patch_type;
int patch_offset_1;
int patch_offset_2;
};
The table is built once during model initialization. During decode, the engine simply iterates the table, patching per-token fields (position, KV sequence length, token buffer offset) and dispatching. The encode_chain function does this:
for (int tok = 0; tok < count; tok++) {
int pos = start_position + tok;
for (int c = 0; c < n_cmds; c++) {
const auto& cmd = cmds[c];
device.set_pipeline(cmd.pso);
// Set buffers with per-token offset patching
// Set params with per-token position patching
device.dispatch(cmd.grid, cmd.threadgroup);
}
}
The hot/cold split is deliberate: profiling labels are stored in a separate labels vector so that the commands vector stays cache-dense during the hot path. Since DispatchCmd is a POD struct with fixed-size arrays (no heap allocations), iterating the commands vector generates predictable memory access patterns.
The patch types handle all per-token dynamic data:
| Patch Type | What Gets Patched | Used By |
|---|---|---|
PATCH_NONE | Nothing | Most commands |
PATCH_TOKEN_OFFSET | buffers[0].offset = tok * 4 | Embedding lookup |
PATCH_POSITION | Position field in params | RoPE, attention |
PATCH_KV_SEQ_LEN | KV sequence length in params | Attention |
PATCH_POS_AND_KV | Both position and KV length | Fused RoPE+KV |
PATCH_ARGMAX_OUTPUT | buffers[1].offset = (tok+1) * 4 | Argmax output |
Double Buffering in Greedy Decode
The greedy path uses a double-buffering strategy for the first chunk:
bool first = true;
while (generated < max_tokens) {
// ... prepare and encode dispatch table ...
if (first) {
state.device->end_encoding_sync();
first = false;
} else {
state.device->end_encoding_async();
state.device->wait();
}
}
The first chunk uses end_encoding_sync() (blocks until GPU completes). Subsequent chunks use end_encoding_async() + wait(), which allows overlapping GPU execution with CPU processing of the previous chunk’s results. This is a subtle but important optimization: while the GPU is computing chunk N+1, the CPU is reading results from chunk N, decoding tokens to text, and calling the user’s callback.
Summary
The decode loop is the orchestrator of Akunu’s token generation. It:
- Runs prefill to populate the KV cache and get the first token.
- Selects one of four decode paths based on sampling config and grammar constraints.
- Manages the generate-stream-check loop until a stop condition is met.
- Collects timing statistics for performance reporting.
The path selection is the most impactful architectural decision: greedy and sampled paths can use chain decode (next chapter) for maximum throughput, grammar decode must go through the CPU for every token, and speculative decode trades prediction accuracy for throughput.
Token Text Decoding
A subtle but important detail: the callback receives decoded text, not just token IDs. The decode_token_text helper converts a token ID to a UTF-8 string:
const char *text = decode_token_text(state, next_token);
if (!callback(next_token, text, user_data))
stopped = true;
This decoding happens on the CPU after each token (or chunk of tokens) is read from the GPU. For most tokenizers, this is a simple lookup into a vocabulary table, taking nanoseconds. However, multi-byte UTF-8 characters can span multiple tokens, and the tokenizer must handle partial characters gracefully.
The text is returned as a const char * pointing to an internal buffer that is valid until the next call to decode_token_text. This avoids allocation overhead in the hot path.
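The partial-character problem is worth illustrating. One common approach (a sketch of the general technique, not necessarily Akunu's tokenizer code) is to emit only the prefix of the pending bytes that ends on a complete UTF-8 character, holding back the rest for the next token:

```cpp
#include <string>

// Return the longest prefix of `buf` that ends on a complete UTF-8
// character, removing it from `buf`; trailing partial bytes stay buffered.
std::string take_complete_utf8(std::string& buf) {
    size_t end = buf.size();
    // Walk back over continuation bytes (10xxxxxx) to the last lead byte.
    size_t i = end;
    while (i > 0 && ((unsigned char)buf[i - 1] & 0xC0) == 0x80) i--;
    if (i > 0) {
        unsigned char lead = (unsigned char)buf[i - 1];
        size_t need = (lead >= 0xF0) ? 4 : (lead >= 0xE0) ? 3
                    : (lead >= 0xC0) ? 2 : 1;
        if (end - (i - 1) < need) end = i - 1; // sequence incomplete: hold it
    }
    std::string out = buf.substr(0, end);
    buf.erase(0, end);
    return out;
}
```

For example, the three-byte euro sign split across two tokens is emitted only once its final byte arrives, so the callback never sees invalid UTF-8.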
Error Handling and Edge Cases
The decode loop handles several edge cases that are easy to overlook:
Empty Prompts
When n_prompt = 0, the prefill loop body never executes, and next_token remains 0. The decode loop will immediately generate from token 0, which is typically a padding or BOS token. In practice, the caller always provides at least the BOS token.
Max Tokens = 0
If max_tokens = 0, the decode paths are never entered, and the function returns with generated = 0 (or 1 if the prefill’s first token was streamed). This is useful for prompt evaluation without generation.
KV Cache Exhaustion
The KV cache has a fixed maximum length. If pos + n would exceed the cache capacity, behavior depends on the cache implementation: ring-buffer caches wrap around, while linear caches simply fail. The decode loop itself does not check for cache exhaustion – that responsibility belongs to the caller.
Very Long Generations
For generations exceeding the chunk size (e.g., generating 10,000 tokens with chunk_size=128), the decode loop iterates 79 times (10,000 / 128, rounded up), each time submitting a chunk to the GPU, waiting for completion, streaming results, and looping. The overhead of this outer loop is negligible compared to the GPU computation time.
Memory and State Management
The decode loop operates on ModelState, which bundles all GPU resources:
struct ModelState {
    Device *device;
    WeightProvider *weights;
    AkunuModelConfig config;
    ArchDescriptor arch;
    KVCache kv_cache;
    ScratchBuffers scratch;
    DispatchTable dispatch_table;
    DispatchTable sampled_dispatch_table;
    ChipConfig chip;
    Tokenizer tokenizer;
    // ...
};
The next_token and pos parameters are passed by reference and updated in-place. This allows the caller to resume generation from where it left off (e.g., for multi-turn conversations):
// First turn
stats1 = run_decode_loop(state, model, prompt1, n1, 0, max_tokens, ...);
// pos is now at start_pos + n1 + generated1
// Second turn (continue from where we left off)
stats2 = run_decode_loop(state, model, prompt2, n2, pos, max_tokens, ...);
The KV cache retains all previously computed K/V entries, so the second turn can attend to the first turn’s context without re-prefilling.
Profiling and Debugging
The dispatch table supports per-command labels for GPU profiling:
struct DispatchLabel {
    char text[48];
};
When Metal GPU capture is active, each dispatch in the command buffer carries a label like "L12.attention" or "L5.ffn.down_gemv", making it straightforward to identify performance hotspots in Instruments or Xcode’s GPU debugger.
The labels are stored in a cold parallel vector (DispatchTable::labels) separate from the hot command array (DispatchTable::commands), ensuring the profiling metadata does not pollute the cache during the decode hot path.
-
XGrammar is a high-performance grammar engine from the TVM team. See: Dong, Y., et al. “XGrammar: Flexible and Efficient Structured Generation Engine for Large Language Models.” MLSys 2025 (arXiv:2411.15100, 2024). The bitmask approach is O(vocab_size) to apply but very cache-friendly. See https://arxiv.org/abs/2411.15100. ↩
-
Gumbel, E. J. “Statistical Theory of Extreme Values and Some Practical Applications.” National Bureau of Standards, 1954. The Gumbel-max trick is described in: Maddison, C.J., et al. “A* Sampling.” NeurIPS 2014 and Jang, E., et al. “Categorical Reparameterization with Gumbel-Softmax.” ICLR 2017. The key result: argmax(log(pi) + G_i) where G_i ~ Gumbel(0,1) samples from Categorical(pi). See https://arxiv.org/abs/1411.0030. ↩
Greedy Decoding and Chain Decode
Greedy decoding is the simplest generation strategy: at each step, pick the token with the highest logit. No temperature, no sampling, no randomness – just argmax. Despite its simplicity, the implementation in Akunu is anything but trivial, because Akunu’s greedy path does not generate one token per GPU submission. Instead, it generates an entire chain of tokens in a single GPU command buffer.
This chapter covers decode_greedy in src/inference/decode_greedy.cpp, the dispatch table replay mechanism, and the chain decode technique that makes greedy generation remarkably fast.
What Makes Greedy Special
When temperature is zero, the entire sampling pipeline collapses to a single operation: argmax. There is no need for softmax, no random number generation, no top-k filtering. And crucially, the argmax result is deterministic – given the same input, you always get the same output.
This determinism enables a powerful optimization: if the GPU can compute argmax as the last step of the forward pass, it can immediately feed the result as the input to the next forward pass, without ever returning control to the CPU. This is chain decode.
The Chain Decode Concept
The idea is deceptively simple. Instead of:
CPU: write token → GPU: forward pass → CPU: read result → CPU: write next token → ...
We do:
CPU: write token → GPU: [forward pass → argmax → forward pass → argmax → ... × N] → CPU: read N results
The entire chain of N tokens is encoded into a single GPU command buffer. The GPU executes the full sequence without any CPU intervention.
The benefit is eliminating the CPU-GPU synchronization overhead that occurs between tokens. On Apple Silicon, each end_encoding_sync() call costs roughly 20-50 microseconds in Metal command buffer overhead. At 50 tokens/sec, this overhead is negligible. But chain decode also eliminates the command buffer creation overhead, which can be 100-300 microseconds per submission. For a chunk of 64 tokens, we save ~63 command buffer creations.
The decode_greedy Implementation
Let’s read the actual code:
int decode_greedy(ModelState& state, akunu_model_t model,
                  uint32_t& next_token, int& pos, int max_tokens,
                  akunu_token_callback callback, void *user_data) {
    int generated = 0;
    int chunk_size = state.chip.chain_decode_chunk;
    bool first = true;
    while (generated < max_tokens) {
        int remaining = max_tokens - generated;
        int n = std::min(chunk_size, remaining);
        state.device->write_buffer(state.scratch.token_ids, &next_token, 4);
        state.device->begin_encoding();
        state.device->encode_dispatch_table(&state.dispatch_table, pos, n);
        if (first) {
            state.device->end_encoding_sync();
            first = false;
        } else {
            state.device->end_encoding_async();
            state.device->wait();
        }
        state.kv_cache.advance(n);
        pos += n;
        uint32_t *tokens = (uint32_t *)state.device->buffer_contents(
            state.scratch.token_ids);
        for (int i = 0; i < n; i++) {
            uint32_t tok = tokens[i + 1];
            generated++;
            if (state.tokenizer.is_eos(tok))
                return generated;
            if (callback) {
                const char *text = decode_token_text(state, tok);
                if (!callback(tok, text, user_data))
                    return generated;
            }
            next_token = tok;
        }
    }
    return generated;
}
There is a lot packed into these few dozen lines. Let’s unpack each piece.
Chunk Size Selection
int chunk_size = state.chip.chain_decode_chunk;
The chunk size comes from ChipConfig and varies by hardware:
| Hardware Tier | GPU Cores | Family | Chunk Size |
|---|---|---|---|
| M1/M2/M3 Base | < 16 | < 9 | 64 |
| M3 Pro | >= 16 | < 9 | 96 |
| M4 Base | < 16 | >= 9 | 128 |
| M4 Pro | >= 16 | >= 9 | 128 |
| M-series Max | >= 30 | any | 128 |
| M-series Ultra | >= 60 | any | 128 |
The M4 family gets larger chunks because its GPU command processor is more efficient at handling long command buffers, and its memory subsystem has better bandwidth for the interleaved read-write patterns of chain decode.1
Writing the Input Token
state.device->write_buffer(state.scratch.token_ids, &next_token, 4);
Only 4 bytes are written: the single token ID that starts the chain. The rest of the token_ids buffer will be filled by the GPU as each step’s argmax writes its result.
Encoding the Dispatch Table
state.device->begin_encoding();
state.device->encode_dispatch_table(&state.dispatch_table, pos, n);
This is the heart of chain decode. encode_dispatch_table calls encode_chain, which replays the dispatch table n times with position patching:
inline void encode_chain(Device& device, const DispatchTable& table,
                         int start_position, int count) {
    const auto& cmds = table.commands;
    const int n_cmds = (int)cmds.size();
    for (int tok = 0; tok < count; tok++) {
        int pos = start_position + tok;
        for (int c = 0; c < n_cmds; c++) {
            const auto& cmd = cmds[c];
            device.set_pipeline(cmd.pso);
            // ... set buffers, params, dispatch ...
        }
    }
}
For a 7B model with ~60 dispatches per token and a chunk of 64 tokens, this encodes 60 * 64 = 3840 GPU dispatches into a single command buffer. The Metal runtime batches these efficiently.
The Replicated Dispatch Pattern
Each iteration of the outer loop (each token) replays the same dispatch table, but with patched fields. The per-token patching ensures each forward pass uses the correct:
- Token ID offset: The embedding lookup reads from token_ids[tok] instead of token_ids[0].
- Position: RoPE uses the correct absolute position for this token.
- KV sequence length: Attention knows how many KV entries are valid.
- Argmax output offset: The argmax result writes to token_ids[tok + 1].
Here is how patching works for the most common types:
// Token offset patching (embedding lookup)
if (cmd.patch_type == DispatchCmd::PATCH_TOKEN_OFFSET && b == 0) {
    offset = tok * 4;  // byte offset to token_ids[tok]
}

// Argmax output patching
if (cmd.patch_type == DispatchCmd::PATCH_ARGMAX_OUTPUT && b == 1) {
    offset = (tok + 1) * 4;  // write result to token_ids[tok+1]
}

// Position patching (RoPE, attention)
if (cmd.patch_type == DispatchCmd::PATCH_POSITION) {
    uint32_t pos_val = (uint32_t)pos;
    memcpy(patched + cmd.patch_offset_1, &pos_val, 4);
}

// Combined position + KV length patching
if (cmd.patch_type == DispatchCmd::PATCH_POS_AND_KV) {
    uint32_t pos_val = (uint32_t)pos;
    uint32_t kv_val = (uint32_t)(pos + 1);
    memcpy(patched + cmd.patch_offset_1, &pos_val, 4);
    memcpy(patched + cmd.patch_offset_2, &kv_val, 4);
}
The token_ids buffer acts as a conveyor belt: token_ids[0] holds the input to the first forward pass, token_ids[1] gets the argmax output of the first pass (and becomes the input to the second pass), and so on.
token_ids buffer:
┌───────┬───────┬───────┬───────┬───────┬───────┐
│  t0   │  t1   │  t2   │  t3   │  t4   │  ...  │
│(input)│(out1) │(out2) │(out3) │(out4) │       │
│       │(in2)  │(in3)  │(in4)  │(in5)  │       │
└───────┴───────┴───────┴───────┴───────┴───────┘
    ↑                       ↑
CPU writes              CPU reads all
one token               after GPU done
Double Buffering
if (first) {
state.device->end_encoding_sync();
first = false;
} else {
state.device->end_encoding_async();
state.device->wait();
}
The first chunk uses synchronous execution because there is nothing to overlap with. Subsequent chunks use asynchronous execution: end_encoding_async() submits the command buffer and returns immediately, then wait() blocks until the GPU signals completion.
The benefit of async submission is that the Metal driver can start preparing the command buffer for submission while the CPU reads results from the previous chunk. In practice, this saves 10-30 microseconds per chunk – small, but it adds up over thousands of chunks.
Reading Results
uint32_t *tokens = (uint32_t *)state.device->buffer_contents(
    state.scratch.token_ids);
for (int i = 0; i < n; i++) {
    uint32_t tok = tokens[i + 1];  // offset by 1: tokens[0] was input
    generated++;
    if (state.tokenizer.is_eos(tok))
        return generated;
    if (callback) {
        const char *text = decode_token_text(state, tok);
        if (!callback(tok, text, user_data))
            return generated;
    }
    next_token = tok;
}
After the GPU completes, the CPU reads all n tokens at once from the token_ids buffer. For each token:
- Check for EOS: if the model generated an end-of-sequence token, stop immediately.
- Invoke the callback: convert the token to text and stream it to the user.
- Update next_token: the last non-EOS token becomes the seed for the next chunk.
Note that EOS checking happens after the GPU has finished the entire chunk. If the model generated EOS at position 5 in a 64-token chunk, positions 6-63 were computed unnecessarily. This is the price of chain decode: you cannot stop mid-chain. However, the wasted work is bounded by one chunk (typically 64-128 tokens), and the throughput gain from chain decode far outweighs this cost.
When Does the Chain Break?
The chain breaks (a new GPU submission is needed) in three cases:
- EOS token: The CPU detects EOS when reading results and stops generating.
- Callback returns false: The user requested cancellation.
- Max tokens reached: The generation limit was hit.
Importantly, the chain does not break for stop sequences or for multi-token EOS patterns. Those are detected at the application level after the chain completes.
Performance Impact of Chain Decode
Let’s quantify the benefit. Consider generating 200 tokens on an M4 Pro with a 7B Q4_0 model:
Without chain decode (one submission per token):
| Component | Per-Token Cost | Total (200 tokens) |
|---|---|---|
| Forward pass GPU time | ~8ms | 1600ms |
| Command buffer overhead | ~0.2ms | 40ms |
| CPU-GPU sync | ~0.05ms | 10ms |
| Total | ~8.25ms | 1650ms |
With chain decode (chunk=128, 2 submissions):
| Component | Per-Chunk Cost | Total (2 chunks) |
|---|---|---|
| Forward pass GPU time | ~128 * 8ms = 1024ms | 1600ms |
| Command buffer overhead | ~0.2ms | 0.4ms |
| CPU-GPU sync | ~0.05ms | 0.1ms |
| Total | ~1024ms | 1600.5ms |
The GPU time is identical – the forward pass takes the same time regardless of whether it is in a chain or standalone. But the overhead drops from 50ms to 0.5ms: a 100x reduction. At 8ms per token, this is a ~3% throughput improvement, which is modest but free.
The real benefit emerges at faster per-token speeds. On an M4 Max with a 3B model doing 3ms/token, the overhead would be 10% without chaining and <0.1% with chaining.
The Dispatch Table in Detail
The dispatch table for a typical 7B Llama model contains approximately 60 commands per token:
| Command Group | Count | Description |
|---|---|---|
| Embedding | 1 | Token ID -> hidden state |
| Layer loop (x32) | ~48 | 32 layers * ~1.5 cmds each |
| RMSNorm | 32 | One per layer (fused) |
| RoPE + KV write | 32 | Fused per layer |
| Attention | 32 | Flash attention decode |
| O projection | 32 | GEMV per layer |
| SiLU + down | 32 | Fused activation GEMV or separate |
| Output norm | 1 | Final RMSNorm |
| Logit GEMV | 1 | Hidden state -> logits |
| Argmax | 1 | Logits -> token ID |
Wait, that is more than 60. The key is that many operations are fused. For example, a layer with fused SiLU GEMV (gemv_q4_0_silu) replaces three dispatches (gate GEMV, up GEMV, SiLU element-wise) with a single dispatch. Residual + RMSNorm is another fusion. With full fusion, a layer can be as few as 4 dispatches:
- Fused residual + RMSNorm + QKV GEMV
- RoPE + KV write
- Attention
- Fused SiLU GEMV (gate * up * down)
In practice, Akunu achieves 4-6 dispatches per layer depending on the model architecture and quantization format.
Chain Decode vs. Batched Decode
It is worth distinguishing chain decode from batched/continuous batching used by server-side inference engines (e.g., vLLM, TensorRT-LLM). Those systems process multiple independent sequences simultaneously, sharing the weight reads across sequences. Chain decode processes a single sequence, but chains multiple sequential tokens into one GPU submission.
| Feature | Chain Decode (Akunu) | Continuous Batching (vLLM) |
|---|---|---|
| Sequences | 1 | Many |
| Tokens per submission | N (sequential) | 1 per sequence, B sequences |
| Weight reuse | Across N tokens (sequential) | Across B sequences (parallel) |
| Use case | On-device inference | Server inference |
| KV cache | Single stream | Multiple streams |
Chain decode is uniquely suited to single-user on-device inference, where there is only one sequence to generate but we want to minimize CPU-GPU synchronization overhead.
How Sampled Decode Achieves Chain Performance
The sampled decode path (decode_sampled) achieves identical throughput to greedy by using the Gumbel-max trick on the GPU. Instead of an argmax at the end of each forward pass, it uses a Gumbel noise + argmax combination that is mathematically equivalent to sampling. The sampled dispatch table is identical to the greedy one except:
- A gumbel_topk_f16 kernel is inserted before the argmax
- The Gumbel kernel has a PATCH_POSITION so each token gets unique noise
This means the chain is unbroken – no CPU roundtrip is needed for random number generation or softmax computation. The RNG is a PCG hash function that takes the position as input, so each token in the chain gets a different noise sample despite being computed in a single GPU submission.
// In the Gumbel kernel:
uint element_seed = (params.position + params.seed_offset) * 2654435761u;
float u = pcg_float(element_seed + i);
float gumbel = -log(-log(u));
logits[i] = half(val + temp * gumbel);
The params.position is different for each token in the chain (patched by PATCH_POSITION), and element_seed varies per vocabulary element via the + i term. Together, these ensure all Gumbel noise values are unique.2
Understanding the Data Flow in Detail
Let’s trace exactly what happens inside the GPU during a chain of 3 tokens. Assume a simple model with 1 layer, dim=4096, vocab_size=32000:
Chain of 3: token_ids = [42, ?, ?, ?]

TOKEN 0 (position P):
  Embedding:  read token_ids[0]=42 → hidden[4096]
  RMSNorm:    hidden → normed
  Q GEMV:     normed @ Wq → q[4096]
  K GEMV:     normed @ Wk → k[1024]  (written to KV cache at pos P)
  V GEMV:     normed @ Wv → v[1024]  (written to KV cache at pos P)
  RoPE:       rotate q, k in-place
  Attention:  q @ K_cache^T → scores → softmax → @ V_cache → attn_out
  O GEMV:     attn_out @ Wo → residual
  ... (FFN) ...
  Logit GEMV: final_hidden @ W_logit → logits[32000]
  Argmax:     logits → token_ids[1] = 7891  (writes to slot 1)

TOKEN 1 (position P+1):
  Embedding:  read token_ids[1]=7891 → hidden[4096]  (reads what token 0 wrote!)
  ... (same operations, but position is P+1, KV cache has P+1 entries) ...
  Argmax:     logits → token_ids[2] = 512  (writes to slot 2)

TOKEN 2 (position P+2):
  Embedding:  read token_ids[2]=512 → hidden[4096]
  ... (same operations, position P+2, KV cache has P+2 entries) ...
  Argmax:     logits → token_ids[3] = 1044  (writes to slot 3)

CPU reads: token_ids = [42, 7891, 512, 1044]
Stream tokens 7891, 512, 1044 to callback.
The critical data dependency is between the argmax write and the next token’s embedding read. Metal guarantees that dispatches within a single command buffer execute in order, so token 1’s embedding dispatch will see the value written by token 0’s argmax dispatch. No explicit synchronization is needed.
Edge Cases and Robustness
Buffer Sizing
The token_ids buffer must be large enough for chunk_size + 1 entries (input token + N output tokens). The speculative path has an additional constraint:
int max_batch = (int)(state.scratch.token_ids.size / sizeof(uint32_t)) - 1;
if (n_draft + 1 > max_batch)
    n_draft = max_batch - 1;
KV Cache Advancement
After each chunk, the KV cache advances by the full chunk size:
state.kv_cache.advance(n);
pos += n;
This is correct even if EOS was generated mid-chunk, because the KV entries for positions after EOS are simply ignored – they will never be queried by subsequent attention operations (there won’t be any).
Position Overflow
The position counter uses int, which limits sequences to ~2 billion tokens. In practice, the KV cache size (typically 4K-128K) is the binding constraint, not the position counter.
Summary
Greedy decoding in Akunu is deceptively simple on the surface but architecturally sophisticated:
- Chain decode processes multiple sequential tokens in a single GPU command buffer, eliminating per-token CPU-GPU synchronization overhead.
- Dispatch table replay with per-token patching enables the chain without recompiling or rebuilding GPU commands.
- The token_ids conveyor belt passes each token’s argmax result as the next token’s input, all within GPU memory.
- Double buffering overlaps GPU execution with CPU result processing.
- Hardware-tuned chunk sizes balance throughput against streaming latency.
The same chain decode mechanism powers sampled decode (via Gumbel-max) and speculative decode (via batched verification), making it the foundational building block of Akunu’s generation pipeline.
Deep Dive: The Dispatch Table Build
To understand chain decode fully, we need to understand how the dispatch table is constructed. During model initialization, Akunu builds a DispatchTable that represents the complete forward pass for a single token. Here is the conceptual structure for a 32-layer Llama model:
Command 0: embedding_lookup (PATCH_TOKEN_OFFSET)
Command 1: rmsnorm_f16 (layer 0 attn norm)
Command 2: gemv_q4_0 (Q projection)
Command 3: gemv_q4_0 (K projection)
Command 4: gemv_q4_0 (V projection)
Command 5: rope_kv_write_f16 (PATCH_POS_AND_KV)
Command 6: flash_attention_decode_fast_f16 (PATCH_KV_SEQ_LEN)
Command 7: gemv_q4_0 (O projection)
Command 8: residual_rmsnorm_f16 (layer 0 FFN norm)
Command 9: gemv_q4_0_silu (fused gate*up*down)
... repeat for layers 1-31 ...
Command 57: residual_rmsnorm_f16 (output norm)
Command 58: gemv_q4_0 (logit projection)
Command 59: argmax_f16 (PATCH_ARGMAX_OUTPUT)
Each command is a fixed-size DispatchCmd struct (no heap allocations), and the full table is a contiguous vector that fits in a few cache lines worth of pointers. The table is built once and never modified during generation – only the patch fields change per-token.
Hot/Cold Data Split
The DispatchTable uses a hot/cold split for profiling data:
struct DispatchTable {
    std::vector<DispatchCmd> commands;  // HOT: iterated every token
    std::vector<DispatchLabel> labels;  // COLD: only used during profiling
};
During generation, only commands is accessed. The labels vector (with 48-byte strings per command) is never touched unless a GPU profiler is attached. This prevents the labels from evicting hot command data from the CPU cache.
Buffer Bindings Are Static
A key property of the dispatch table: all buffer bindings are static (fixed at build time). The weight buffers, scratch buffers, and KV cache buffers are allocated during model init and never change. The only per-token dynamic data is:
- Which offset within a buffer to use (patched via PATCH_TOKEN_OFFSET and PATCH_ARGMAX_OUTPUT)
- Which scalar parameter values to use (patched via PATCH_POSITION, PATCH_KV_SEQ_LEN, PATCH_POS_AND_KV)
This means the Metal runtime can reuse pipeline state objects (PSOs) and buffer bindings across tokens, minimizing the command encoding overhead.
Practical Considerations for Chain Decode
Streaming Latency vs. Throughput
Chain decode introduces a fundamental tension: larger chunks give higher throughput (less overhead per token) but worse streaming latency (the user sees nothing until the chunk completes). Here is the tradeoff:
| Chunk Size | Overhead Savings | Time to First Token in Chunk | Tokens Buffered |
|---|---|---|---|
| 1 (no chain) | 0% (baseline) | ~8ms | 0 |
| 16 | ~94% | ~128ms | 16 |
| 64 | ~98.5% | ~512ms | 64 |
| 128 | ~99.2% | ~1024ms | 128 |
At chunk_size=128, the user waits up to 1 second before seeing a burst of 128 tokens appear nearly simultaneously. Whether this is acceptable depends on the application: for interactive chat, chunk_size=64 provides a good balance; for batch processing, chunk_size=128 maximizes throughput.
Interaction with Stop Sequences
Modern LLM applications often use stop sequences (e.g., "\n\nHuman:") to terminate generation. With chain decode, stop sequence detection happens after the chunk completes, not during. If the model generates the stop sequence at token 10 of a 64-token chunk, tokens 11-64 are wasted computation.
However, the wasted work is bounded: at most one chunk’s worth of tokens. Since the overhead savings from chain decode (eliminating 63 command buffer creations per chunk) far exceed the cost of a few wasted tokens, the net benefit is strongly positive.
Memory Bandwidth During Chain Decode
During a chain of N tokens, the GPU reads the full model weights N times (once per token). However, Apple Silicon’s SLC (System Level Cache) can cache a portion of the weights across tokens. For a 7B Q4_0 model (~3.5 GB), the SLC on different chips caches:
| Chip | SLC Size | % of Model Cached | Effective Bandwidth Boost |
|---|---|---|---|
| M4 Base | 16 MB | 0.5% | Negligible |
| M4 Pro | 32 MB | 0.9% | Small |
| M4 Max | 48 MB | 1.4% | Moderate |
| M4 Ultra | 96 MB | 2.7% | Moderate |
Even at 2.7% cache hit rate, the SLC provides measurable benefit because the cached portion includes the hot first-layer weights and norm parameters that are accessed every token. The threadgroup swizzling in GEMM kernels (covered in the GEMM chapter) is designed to maximize this SLC reuse.
For smaller models (1-3B) that partially fit in the SLC, the benefit is much larger. A 1B model at Q4_0 is ~500 MB, and a 96 MB SLC can cache ~19% of the weights, providing a meaningful bandwidth amplification.
Correctness of Chain Decode
A natural question: is chain decode mathematically equivalent to single-token decode? Yes, because:
- The dispatch table encodes the same forward pass regardless of chaining.
- Per-token patching ensures correct positions and KV cache lengths.
- The argmax writes to the correct output slot, and the next token reads from the previous slot.
- There are no data hazards: each forward pass within the chain reads only from buffers written by the previous pass, and the GPU’s command execution model guarantees sequential ordering within a command buffer.
The only difference is that stop conditions (EOS, callback cancellation) are checked after the chunk rather than after each token. This does not affect the generated text – it only affects how quickly the engine responds to stop conditions.
-
Apple. “Apple M4 chip.” apple.com, 2024. The M4’s GPU command processor improvements include reduced dispatch latency and better utilization of the Apple GPU’s tile-based deferred rendering architecture for compute workloads. See https://www.apple.com/newsroom/2024/05/apple-introduces-m4-chip/. ↩
-
Maddison, C.J., Tarlow, D., and Minka, T. “A* Sampling.” NeurIPS 2014. The Gumbel-max trick provides an exact sample from the categorical distribution without explicitly computing the softmax or CDF. The PCG hash function (O’Neill, M. “PCG: A Family of Simple Fast Space-Efficient Statistically Good Algorithms for Random Number Generation.” Harvey Mudd College Technical Report HMC-CS-2014-0905) provides high-quality pseudorandomness with minimal state. See https://arxiv.org/abs/1411.0030. ↩
Sampled Decoding and GPU Gumbel-Max
Greedy decoding is clean and deterministic. You run argmax on the logits, you get a token, you move on. But anybody who has used an LLM for more than five minutes knows that greedy decoding produces text that is boring. Repetitive, predictable, and lifeless. If you want creative prose, diverse completions, or anything that feels like genuine language, you need to sample from the probability distribution rather than simply picking the mode.
This chapter covers how akunu implements sampled decoding. We will start with the theory – temperature scaling, the Gumbel-max trick, and the family of filtering methods (top-k, top-p, min-p). Then we will trace through the actual GPU kernel that does all of this in a single dispatch, entirely avoiding a CPU round-trip. Finally, we will look at the CPU fallback path for when the GPU kernel is unavailable.
Why Sampling Matters
Consider a language model predicting the next word after “The cat sat on the”. Greedy decoding will always produce “mat” (or whatever token has the highest logit). Every single time. That is useful for factual Q&A, but terrible for storytelling.
Sampling says: the model assigned 40% probability to “mat”, 15% to “roof”, 10% to “fence”, 8% to “windowsill”, and so on. Let us actually use that distribution. Roll a die weighted by those probabilities and pick a token accordingly. Now different runs produce different continuations, and the text feels more natural.
The tricky part is that naive categorical sampling from the full vocabulary (often 128k+ tokens) can produce garbage. The model assigns tiny but nonzero probability to tokens like “xyzzy” or “<0xFF>” – and occasionally you will land on one. So in practice, everybody applies some combination of temperature scaling and filtering before sampling. Let us walk through each.
Temperature Scaling
Temperature is the simplest knob. Given raw logits z_i from the model, we
compute:
p_i = softmax(z / T)_i = exp(z_i / T) / sum_j exp(z_j / T)
where T is the temperature.
T = 0.0 --> argmax (greedy)
T = 1.0 --> sample from the model's native distribution
T > 1.0 --> flatten the distribution (more random)
T < 1.0 --> sharpen the distribution (more deterministic)
Visually, here is what temperature does to a toy distribution:
Logits: [3.0, 1.0, 0.5, -1.0, -2.0]
T=0.5 (sharp):
Prob: [0.98, 0.02, 0.01, 0.00, 0.00]
#######################################
#
.
.
.
T=1.0 (native):
Prob: [0.81, 0.11, 0.07, 0.01, 0.01]
################################
####
###
.
.
T=2.0 (flat):
Prob: [0.53, 0.20, 0.15, 0.07, 0.04]
#####################
########
######
###
##
In akunu’s CPU path (sampling.cpp), temperature scaling is applied as
multiplication by the inverse temperature:
float inv_temp = 1.0f / temperature;
for (int i = 0; i < vocab_size; i++)
    logits[i] *= inv_temp;
Multiplying all logits by 1/T before softmax is mathematically equivalent to
dividing by T inside the exponential. It avoids a division per element.
On the GPU path, the Gumbel-max trick (discussed below) absorbs temperature into the noise magnitude, so there is no separate scaling pass.
The Gumbel-Max Trick
Here is the key insight that makes GPU-native sampling possible.
Theorem (Gumbel-Max): If you add independent Gumbel(0,1) noise to each logit and then take the argmax, the result is a sample from the categorical distribution defined by softmax of those logits.
More precisely, let g_i ~ Gumbel(0,1) be independent. Then:
argmax_i (z_i + g_i) ~ Categorical(softmax(z))
And for temperature scaling:
argmax_i (z_i + T * g_i) ~ Categorical(softmax(z / T))
This is remarkable. It means we can turn sampling into argmax – which is exactly the operation we already have a fast GPU kernel for (greedy decoding). We just need to add appropriately scaled random noise to the logits first.
Here is the pipeline:
+------------------+ +------------------+ +------------------+
| Model forward | --> | Add Gumbel noise | --> | Argmax |
| (produces | | scaled by temp | | (same kernel |
| logits) | | to each logit | | as greedy) |
+------------------+ +------------------+ +------------------+
GPU GPU GPU
No CPU round-trip needed!
Compare this with the traditional approach:
+------------------+ +----------+ +----------+ +----------+
| Model forward | --> | Copy to | --> | Softmax | --> | Sample |
| (produces | | CPU | | + filter | | (rand) |
| logits) | | | | | | |
+------------------+ +----------+ +----------+ +----------+
GPU sync! CPU CPU
The traditional approach requires a GPU-to-CPU synchronization to copy the logits, then CPU work for softmax and sampling, then writing the result back. That synchronization stall can cost 50-200 microseconds per token – which at high throughput becomes a significant fraction of total decode time.
The Gumbel-max trick keeps everything on the GPU. The sampled dispatch table
simply inserts the gumbel_topk_f16 kernel between the model forward pass and
the existing argmax kernel. No sync, no copy, no stall.
Generating Gumbel Noise
A Gumbel(0,1) random variable is generated from a uniform random variable
u ~ Uniform(0,1) via the inverse CDF:
g = -log(-log(u))
On the GPU, we need a fast source of uniform random numbers. We cannot use
rand() (no such thing in Metal compute shaders). Instead, we use a PCG
(Permuted Congruential Generator) hash function:
inline float pcg_float(uint state) {
    state = state * 747796405u + 2891336453u;
    uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
    word = (word >> 22u) ^ word;
    return float(word) / 4294967296.0f;
}
This is not a PRNG in the traditional sense – it is a hash function. You feed it a seed and get back a pseudo-random float in [0, 1). The seed for each vocabulary element is computed as:
uint element_seed = (params.position + params.seed_offset) * 2654435761u;
// Then for token i:
float u = pcg_float(element_seed + i);
The position is the current sequence position (patched per-token via
PATCH_POSITION in the dispatch table), and seed_offset is derived from
std::chrono::high_resolution_clock at the start of each generate call. The multiplicative constant 2654435761 is approximately 2^32 divided by the golden ratio (Knuth’s multiplicative hashing constant), a common choice for hash mixing.
This scheme ensures:
- Different tokens at the same position get different noise (seeded by
i) - Different positions get different noise (seeded by
position) - Different calls get different noise (seeded by
seed_offset) - The same (position, seed_offset, i) triple always produces the same noise (reproducibility when seeds are fixed)
The Gumbel noise is then:
float gumbel = -log(-log(u));
logits[i] = half(val + temp * gumbel);
where temp is the temperature parameter. Before taking the logs, the kernel clamps u with clamp(u, 1e-7f, 1.0f - 1e-7f), so we never take log(0).
Filtering Methods
Pure sampling from the full vocabulary is too noisy. Filtering methods prune the candidate set before sampling. akunu implements three, applied in this order on the GPU: top-k, top-p, and min-p.
Top-K Filtering
Keep only the K tokens with the highest logits. Mask everything else to negative infinity.
Before top-k (K=3):
Token: A B C D E F G
Logit: 5.2 3.1 2.8 1.5 0.3 -1.0 -2.5
^^^ ^^^ ^^^ --- --- --- ---
keep keep keep mask mask mask mask
After top-k:
Logit: 5.2 3.1 2.8 -inf -inf -inf -inf
The GPU kernel finds the k-th largest logit via binary search rather than sorting. This is a clever approach: sorting 128k elements is expensive, but counting how many elements exceed a threshold is cheap (each thread scans its portion in parallel, then a SIMD reduction sums the counts).
// Binary search: 12 iterations to find the k-th largest value
float lo = global_max - 30.0f;
float hi = global_max;
for (int iter = 0; iter < 12; iter++) {
float mid = (lo + hi) * 0.5f;
uint local_count = 0;
for (uint i = tid; i < V; i += 1024)
local_count += (float(logits[i]) > mid) ? 1 : 0;
// SIMD reduction
uint sg_count = simd_sum(local_count);
if (slid == 0) tg_counts[sgid] = sg_count;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tid == 0) {
uint total = 0;
for (uint s = 0; s < 32; s++) total += tg_counts[s];
tg_counts[0] = total;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tg_counts[0] > (uint)top_k)
lo = mid; // too many survivors, raise threshold
else
hi = mid; // too few, lower threshold
}
threshold = hi;
After 12 iterations of binary search, the threshold converges to within
30 / 2^12 ~ 0.007 logit units – more than precise enough. Each iteration
requires only 2 threadgroup barriers and a parallel scan, so the total cost is
modest.
The alternative – a full sort – would cost O(V log V) work. The binary search costs O(V * iterations) = O(V * 12) comparisons, which for V=128k is comparable in raw work to the sort, but each pass is a trivially parallel counting predicate with no data movement – no merging, no shuffling elements between threads.
Top-P (Nucleus) Filtering
Top-p keeps the smallest set of tokens whose cumulative probability exceeds p.
This is adaptive – when the model is confident, it might keep only 2-3 tokens;
when uncertain, it might keep hundreds.
Sorted by probability:
Token: A B C D E ...
Prob: 0.40 0.25 0.15 0.10 0.05 ...
CumSum: 0.40 0.65 0.80 0.90 0.95 ...
^^^^
top_p=0.9 cutoff here
Keep A, B, C, D. Mask E and everything after.
On the GPU, akunu again uses binary search. It first computes the total
exp-sum for tokens above the current threshold (from top-k), then searches
for the logit value where the cumulative probability reaches top_p:
if (top_p < 1.0f && top_p > 0.0f) {
// Compute total exp(logit - max) for survivors
float total_exp = ...; // parallel reduction
float target_exp = top_p * total_exp;
// Binary search: 8 iterations
float lo_p = threshold;
float hi_p = global_max;
for (int iter = 0; iter < 8; iter++) {
float mid_p = (lo_p + hi_p) * 0.5f;
// Count exp-sum for logits >= mid_p
// If sum > target, raise threshold (too many)
// If sum <= target, lower threshold (too few)
}
threshold = max(threshold, hi_p);
}
The key insight: threshold can only go up from top-p, never down. If top-k
already restricted us to 40 tokens, top-p can further restrict to (say) 10, but
it will never expand beyond 40. That is why we take max(threshold, hi_p).
Min-P Filtering
Min-p is the newest and arguably most elegant filtering method. It keeps all
tokens whose probability is at least min_p times the probability of the most
likely token.
Most likely token has probability 0.40
min_p = 0.1
Threshold = 0.40 * 0.1 = 0.04
Token: A B C D E F
Prob: 0.40 0.25 0.15 0.10 0.05 0.02
keep keep keep keep keep MASK
Since probabilities are proportional to exp(logit), and the max probability
corresponds to exp(global_max), the min-p condition becomes:
exp(logit_i) >= min_p * exp(global_max)
logit_i >= global_max + log(min_p)
This is a single threshold comparison – no sorting, no cumulative sums:
if (params.min_p > 0.0f && params.min_p < 1.0f) {
float min_p_threshold = global_max + log(params.min_p);
threshold = max(threshold, min_p_threshold);
}
Beautifully simple on the GPU. One log, one add, one max.
Repetition Penalty
Before any filtering happens, the kernel applies repetition penalty to discourage the model from repeating tokens it has already generated. The algorithm is asymmetric:
- If a logit is positive, divide by the penalty factor
- If a logit is negative, multiply by the penalty factor
if (rep_penalty > 1.0f && params.position > 0) {
for (uint i = tid; i < params.position; i += 1024) {
uint token = token_ids[i + 1];
if (token < V) {
float val = float(logits[token]);
logits[token] = half(val > 0 ? val / rep_penalty : val * rep_penalty);
}
}
}
This is Phase 0 of the kernel. Each thread in the 1024-thread threadgroup
processes a different previous token (strided access). The token_ids buffer
contains the sequence generated so far, indexed from position 1 (position 0
is the input token).
The asymmetric formulation ensures that the penalty always decreases the probability of repeated tokens, regardless of whether their logit is positive or negative. If we just divided all logits by the penalty, negative logits would become less negative (higher probability), which is the opposite of what we want.
The Complete GPU Kernel: Phase by Phase
The gumbel_topk_f16 Metal kernel executes in a single threadgroup of 1024
threads. Here is the full pipeline:
+-------------------------------------------------------+
| gumbel_topk_f16 kernel |
| |
| Phase 0: Repetition penalty |
| For each prev token (strided across 1024 threads): |
| logit[tok] /= penalty (if positive) |
| logit[tok] *= penalty (if negative) |
| [threadgroup barrier] |
| |
| Phase 1: Find global maximum |
| Each thread: local_max over its strided elements |
| SIMD reduction -> simdgroup maxes |
| Threadgroup reduction -> single global_max |
| [threadgroup barrier] |
| |
| Phase 2: Find top-k threshold via binary search |
| 12 iterations: |
| mid = (lo + hi) / 2 |
| Each thread: count elements > mid (strided) |
| SIMD sum -> simdgroup counts |
| Threadgroup sum -> total count |
| if total > k: lo = mid else: hi = mid |
| [2 barriers per iteration = 24 barriers] |
| |
| Phase 2b: Top-p via binary search |
| 8 iterations of similar structure |
| Raises threshold if cumulative prob > top_p |
| |
| Phase 2c: Min-p threshold |
| threshold = max(threshold, global_max + log(min_p)) |
| |
| Phase 3: Apply mask + Gumbel noise |
| For each element (strided): |
| if logit < threshold: logit = -inf |
| else: logit += temp * Gumbel(pcg_hash(seed + i)) |
+-------------------------------------------------------+
|
v
+-------------------------------------------------------+
| argmax_f16 kernel (existing) |
| Standard parallel reduction -> winning token ID |
+-------------------------------------------------------+
The entire pipeline – from raw logits to sampled token ID – never leaves the
GPU. The sampled_dispatch_table in akunu chains these kernels together as
part of the same command buffer as the model forward pass.
Dispatch Table Integration
In decode_sampled.cpp, the function first checks whether a sampled dispatch
table exists:
bool have_sampled_table = !state.sampled_dispatch_table.commands.empty();
If it does, the function patches the Gumbel kernel’s parameters directly into the command buffer:
auto& gumbel_cmd = cmds[cmds.size() - 2]; // second-to-last command
// Param layout: [vocab(0), temp(4), pos(8), seed(12),
// top_k(16), top_p(20), rep_penalty(24), min_p(28)]
memcpy(gumbel_cmd.param_bytes + 4, &temp, sizeof(float));
memcpy(gumbel_cmd.param_bytes + 12, &seed_base, sizeof(uint32_t));
The second-to-last command is the Gumbel kernel (the last is argmax). The seed is derived from the high-resolution clock:
uint32_t seed_base =
(uint32_t)std::chrono::high_resolution_clock::now()
.time_since_epoch().count();
The position field is patched per-token by the dispatch table’s
PATCH_POSITION mechanism – the same mechanism used for KV cache position
indexing during chain decode.
If the command's parameter buffer is at least 32 bytes, the extended parameters (top_k, top_p, repeat_penalty, min_p) are also patched. This backwards-compatible layout means older compiled dispatch tables (without the extended params) still work – they just do not get filtering.
The Decode Loop
The actual decode loop in decode_sampled is structurally identical to greedy
chain decode:
while (generated < max_tokens) {
int remaining = max_tokens - generated;
int n = std::min(chunk, remaining);
state.device->write_buffer(state.scratch.token_ids, &next_token, 4);
state.device->begin_encoding();
state.device->encode_dispatch_table(&table, pos, n);
state.device->end_encoding_sync();
state.kv_cache.advance(n);
pos += n;
uint32_t *tokens = (uint32_t *)state.device->buffer_contents(
state.scratch.token_ids);
for (int i = 0; i < n; i++) {
uint32_t tok = tokens[i + 1];
// ... emit token, check EOS, update next_token ...
}
}
The chunk size is state.chip.chain_decode_chunk, typically 1 for sampled
decode (since each token depends on the previous sample). But the dispatch
table could support multi-token chunks if the Gumbel kernel were extended to
produce multiple samples per dispatch.
The beauty of this design is that from the loop’s perspective, sampled decode
and greedy decode look identical. The only difference is which dispatch table
gets encoded: state.dispatch_table (greedy argmax) vs
state.sampled_dispatch_table (Gumbel + argmax).
CPU Fallback Path
When the sampled dispatch table is not available (e.g., the Metal library was
not compiled with the Gumbel kernel, or the model format does not support it),
decode_sampled falls back to using the regular dispatch table:
DispatchTable& table = have_sampled_table
? state.sampled_dispatch_table
: state.dispatch_table;
In this case, the regular dispatch table produces raw logits (no argmax), and
the CPU sampling pipeline in sampling.cpp takes over. We will cover that
pipeline in detail in Chapter 32, but the high-level flow is:
GPU: model forward -> raw logits in buffer
CPU: copy F16 logits -> convert to F32 -> temperature scale
-> softmax -> top-k (partial sort) -> min-p filter
-> top-p filter -> renormalize -> categorical sample
The CPU path is more flexible (easier to add new filtering methods) but slower due to the GPU-CPU synchronization.
Parameter Struct
The AkunuSamplingConfig struct defines all sampling knobs:
+-----+-----+-----+-----+-----+
| 0 | 4 | 8 | 12 | 16 | byte offset
+-----+-----+-----+-----+-----+
|temp | topk| topp| minp| rep |
|f32 | i32 | f32 | f32 | f32 |
+-----+-----+-----+-----+-----+
20 bytes total, 4-byte aligned
Default values:
- temperature = 0.0 (greedy – no sampling)
- top_k = 40
- top_p = 0.9
- min_p = 0.0 (disabled)
- repeat_penalty = 1.0 (disabled)
A temperature of 0 triggers the greedy path; any positive temperature engages sampled decode.
Practical Considerations
Reproducibility. If you fix seed_offset (by seeding from a known value
rather than the clock), the GPU Gumbel-max path is fully deterministic for a
given (model, prompt, parameters) tuple. The PCG hash is deterministic, and
the Gumbel-max trick is a monotonic function of the noise – so the same noise
produces the same token.
Quality of randomness. The PCG hash is not cryptographically secure, and the per-element seeds are correlated (they differ by 1). In practice, this does not matter for LLM sampling. The quality of the generated text is dominated by the logit values, not by the precise distribution of the noise. Any reasonable hash function would work.
Filter interaction. The three filters (top-k, top-p, min-p) interact additively – each can only raise the threshold, never lower it. This means the order does not matter mathematically, but on the GPU, doing top-k first is most efficient because it establishes a tight initial threshold that makes the top-p binary search converge faster.
Throughput. The Gumbel kernel dispatches with grid=(1), threadgroup=(1024). That is a single threadgroup of 1024 threads. On Apple Silicon, a single GPU core can handle this in microseconds. The kernel is latency-bound, not throughput-bound – each phase streams over the V=128k logits in device memory (about 256 KB at F16, a trivial amount of bandwidth), while only the per-simdgroup partials live in threadgroup memory.
When to use sampling. For code generation, math, and factual Q&A, greedy decode (or low temperature like 0.1) usually wins. For creative writing, conversation, and brainstorming, temperature 0.7-1.0 with top-p 0.9 is a solid default. Min-p around 0.05-0.1 is increasingly popular as a replacement for top-k, since it adapts to the model’s confidence level.
Summary
Akunu’s sampled decode takes advantage of the Gumbel-max trick to keep the
entire sampling pipeline on the GPU. The single gumbel_topk_f16 Metal kernel
implements repetition penalty, global max finding, top-k via binary search,
top-p via binary search, min-p filtering, and Gumbel noise injection in one
dispatch of 1024 threads. The existing argmax kernel then produces the final
token. No CPU round-trip, no synchronization stall, identical throughput to
greedy decoding.
When the GPU kernel is unavailable, the CPU fallback in sampling.cpp provides
the same functionality with explicit F16-to-F32 conversion, softmax, partial
sort, and categorical sampling. We will dissect that path in Chapter 32.
Speculative Decoding with N-grams
Autoregressive decoding has a fundamental problem: each token depends on the previous one. You cannot compute token 42 until you have produced token 41, because token 41 changes the KV cache and thus the attention output for position 42. This makes decode inherently sequential. For large models, each decode step takes (say) 5-10 milliseconds of GPU time, so generating 500 tokens takes 2.5-5 seconds. The GPU is actually underutilized during decode – it is doing a single matrix-vector multiply per layer, not the fat matrix-matrix multiplies of prefill – but there is no obvious way to parallelize.
Speculative decoding breaks this bottleneck by guessing future tokens cheaply and then verifying them in a single batched forward pass. If the guesses are right (and for many text patterns, they often are), you produce multiple tokens per forward pass without changing the output distribution.
This chapter covers akunu’s speculative decode implementation, which uses n-gram frequency tables as its draft predictor – no draft model required.
The Core Idea
The speculative decode algorithm has three steps:
1. DRAFT: Predict N future tokens cheaply
2. VERIFY: Run all N+1 tokens through the real model in one batch
3. ACCEPT: Keep the longest prefix where draft == verified
+----------+ +------------------+ +-----------+
| N-gram | --> | Batched model | --> | Compare |
| predictor| | forward pass | | draft vs |
| (draft) | | (N+1 positions) | | verified |
+----------+ +------------------+ +-----------+
~0 cost ~1x forward cost ~0 cost
If K of N drafts match: you got K+1 tokens for the price of 1 forward pass!
Let us trace through a concrete example. Suppose we are generating text and the most recent token is “the”. The n-gram predictor looks at its frequency tables and predicts: [“quick”, “brown”, “fox”, “jumps”].
We pack these into a batch:
Position: 42 43 44 45 46
Input: "the" "quick" "brown" "fox" "jumps"
^ ^ ^ ^ ^
known draft[0] draft[1] draft[2] draft[3]
We run the model on all 5 positions in a single forward pass. The model produces an argmax (or sampled) token for each position:
Position: 42 43 44 45 46
Verified: "quick" "brown" "fox" "jumped" "over"
^ ^ ^ ^
matches matches matches MISMATCH at draft[3]
Drafts 0-2 matched (“quick”, “brown”, “fox” were all correct guesses). Draft 3 was wrong (“jumped” vs “jumps”). So we accept the 3 matching drafts plus the bonus token at the mismatch point:
Accepted tokens: "quick", "brown", "fox", "jumped"
= 4 tokens from 1 forward pass!
The bonus token at position 45 is the model’s actual prediction given the correct prefix [“the”, “quick”, “brown”, “fox”] – so it is always valid. We just could not predict what came after it.
Why N-grams?
The standard speculative decoding literature uses a small draft model to generate the draft tokens. That draft model is typically 10-50x smaller than the target model, runs fast, and hopefully agrees with the target model on easy tokens.
Akunu takes a different approach: no draft model at all. Instead, it builds n-gram frequency tables from the tokens seen so far (prompt + generated) and uses those to predict the next tokens.
This has several advantages:
- Zero overhead at load time. No extra model to load, no extra memory.
- Perfect for repetitive patterns. Code, structured data, and template-heavy text often repeat multi-token sequences. N-grams capture these perfectly.
- Adapts in real-time. The frequency tables update with every generated token, so the predictor learns patterns specific to this generation.
- Simplicity. The entire predictor is ~100 lines of C++ with no external dependencies.
The downside is that n-gram prediction has zero “understanding” – it cannot predict tokens it has not seen in exactly the right context. For highly novel text, the predictor will fail to produce any drafts, and speculative decode degrades gracefully to standard autoregressive decode.
The N-gram Predictor
Let us look at akunu’s NGramPredictor class in detail.
Configuration
MAX_ORDER = 4 Up to 4-grams (context of 3 tokens)
DRAFT_COUNT = 4 Predict up to 4 tokens per round
MAX_HISTORY = 512 Sliding window of recent tokens
Data Structures
The predictor maintains:
- A deque<uint32_t> history_ of the last 512 tokens seen
- Three hash tables, one per n-gram order:
  - tables_[0]: bigrams (1-token context -> next token -> count)
  - tables_[1]: trigrams (2-token context -> next token -> count)
  - tables_[2]: 4-grams (3-token context -> next token -> count)

Each table maps a hashed context to a frequency map:

tables_[order-2]: hash(context) --> { token_id: count, ... }
Example for the context "the cat sat":
tables_[2][hash("the","cat","sat")] = {
"on": 47, <-- seen 47 times after "the cat sat"
"down": 3, <-- seen 3 times
"and": 1, <-- seen 1 time
}
Hashing
Context tokens are hashed using FNV-1a, a simple non-cryptographic hash:
static uint64_t context_hash(const uint32_t *tokens, int n) {
uint64_t h = 14695981039346656037ULL; // FNV-1a offset basis
for (int i = 0; i < n; i++) {
h ^= tokens[i];
h *= 1099511628211ULL; // FNV-1a prime
}
return h;
}
This maps a variable-length token sequence to a 64-bit key. Collisions are theoretically possible but vanishingly unlikely: the hash space is 2^64, while the number of distinct contexts the tables ever see is bounded by the number of tokens processed – thousands per generation, not billions.
Update
When a new token is generated, update() adds it to the frequency tables at
all applicable orders:
History: [... t_{n-3}, t_{n-2}, t_{n-1}]
New token: t_n
Bigram: hash(t_{n-1}) -> t_n count++
Trigram: hash(t_{n-2}, t_{n-1}) -> t_n count++
4-gram: hash(t_{n-3}, t_{n-2}, t_{n-1}) -> t_n count++
The prompt tokens are added via update_batch(), which calls update() for
each token. This seeds the frequency tables with the patterns present in the
prompt – which is crucial for tasks like “continue this code” where the prompt
establishes the patterns that will repeat.
Prediction
The predict() method generates up to DRAFT_COUNT (4) predicted tokens.
For each position, it tries the longest matching context first:
Tentative context: [... t_{n-2}, t_{n-1}, t_n]
Try 4-gram: lookup hash(t_{n-2}, t_{n-1}, t_n) in tables_[2]
Found? Pick the most frequent continuation. Done.
Not found? Try 3-gram.
Try 3-gram: lookup hash(t_{n-1}, t_n) in tables_[1]
Found? Pick the most frequent continuation. Done.
Not found? Try 2-gram.
Try 2-gram: lookup hash(t_n) in tables_[0]
Found? Pick the most frequent continuation. Done.
Not found? No draft for this position. Stop.
Once a token is predicted, it is appended to the tentative context, and the process repeats for the next position. This means the predictor can chain its own predictions – predicting “quick” allows it to then predict “brown” given the extended context.
predict() trace:
Context: ["the", "cat", "sat"]
Step 0: 4-gram("the","cat","sat") --> "on" (count=47)
Context becomes ["the","cat","sat","on"]
Step 1: 4-gram("cat","sat","on") --> "the" (count=31)
Context becomes ["the","cat","sat","on","the"]
Step 2: 4-gram("sat","on","the") --> "mat" (count=23)
Context becomes ["the","cat","sat","on","the","mat"]
Step 3: 4-gram("on","the","mat") --> not found
3-gram("the","mat") --> "." (count=5)
Draft: ["on", "the", "mat", "."]
The Verification Loop
Now let us trace through decode_speculative() in detail.
No-draft fallback
If the predictor fails to produce any drafts (n_draft <= 0), the loop falls
back to standard single-token decode:
if (n_draft <= 0) {
// Single-token forward pass
state.device->write_buffer(state.scratch.token_ids, &next_token, 4);
state.device->begin_encoding();
state.device->encode_dispatch_table(&state.dispatch_table, pos, 1);
state.device->end_encoding_sync();
state.kv_cache.advance(1);
// ... emit token, update predictor ...
continue;
}
This is important: speculative decode never hurts performance. In the worst case (no successful predictions), it behaves identically to standard decode.
Batched verification
When drafts are available, we pack them into the token_ids buffer along with the current token:
uint32_t *chain_buf = (uint32_t *)state.device->buffer_contents(
state.scratch.token_ids);
chain_buf[0] = next_token; // the known token
for (int i = 0; i < n_draft; i++)
chain_buf[i + 1] = drafts[i]; // the draft tokens
Then we run the model on the entire batch:
state.device->begin_encoding();
state.device->encode_dispatch_table(&state.dispatch_table, pos, batch_size);
state.device->end_encoding_sync();
state.kv_cache.advance(batch_size);
The dispatch table processes batch_size = 1 + n_draft positions in a single
encoding pass. This is the same mechanism used for prefill – the model sees
multiple tokens and produces a prediction for each position.
Token IDs buffer before forward pass:
+-------+----------+----------+----------+----------+
| known | draft[0] | draft[1] | draft[2] | draft[3] |
| tok | | | | |
+-------+----------+----------+----------+----------+
pos pos+1 pos+2 pos+3 pos+4
Token IDs buffer after forward pass (output shifted by 1):
+-------+----------+----------+----------+----------+----------+
| (in) | verify0 | verify1 | verify2 | verify3 | verify4 |
+-------+----------+----------+----------+----------+----------+
[0] [1] [2] [3] [4] [5]
verify0 = model's prediction given context up to pos
verify1 = model's prediction given context up to pos+1 (including draft[0])
...
Acceptance logic
We compare each verified token against the corresponding draft:
int accepted = 0;
uint32_t bonus_token = 0;
for (int i = 0; i < batch_size; i++) {
uint32_t verified = chain_buf[i + 1];
if (i < n_draft) {
if (verified == drafts[i])
accepted++; // Draft was correct!
else {
bonus_token = verified; // First mismatch
break;
}
} else {
bonus_token = verified; // Token after all drafts
}
}
This finds the longest prefix of correct drafts. The key insight: when draft[i] matches verified[i], the model was going to produce that token anyway. So we can accept it for free. The first mismatch gives us a “bonus token” – the model’s actual prediction at that position.
KV cache rollback
Here is a subtle but critical detail. The forward pass advanced the KV cache
by batch_size positions. But we only accepted accepted + 1 tokens (the
accepted drafts plus the bonus). The remaining KV cache entries (for the
rejected drafts) are wrong – they were computed based on a context that
included incorrect draft tokens.
We must roll back the KV cache:
int keep = accepted + 1;
state.kv_cache.rollback(pos + keep);
pos += keep;
rollback(n) truncates the KV cache to position n, discarding everything
after. This is a cheap operation – typically just updating a length counter,
since the actual data does not need to be zeroed (it will be overwritten by the
next forward pass).
Complete flow diagram
+-------------------+
| N-gram predictor |
| predict() -> 4 |
| draft tokens |
+-------------------+
|
v
+-------------------+
| Pack batch: |
| [known, d0, d1, |
| d2, d3] |
| batch_size = 5 |
+-------------------+
|
v
+-------------------+
| Model forward |
| on 5 positions |
| (single GPU pass) |
+-------------------+
|
v
+-------------------+
| Read outputs: |
| [v0,v1,v2,v3,v4] |
+-------------------+
|
v
+-------------------+ +-------------------+
| Compare: | | Example: |
| d0==v0? yes | | d0="on" v0="on" |
| d1==v1? yes | | d1="the" v1="the" |
| d2==v2? no | | d2="mat" v2="rug" |
| bonus = v2 | | bonus = "rug" |
+-------------------+ +-------------------+
|
v
+-------------------+
| Rollback KV cache |
| keep = 2+1 = 3 |
| Emit: d0, d1, |
| bonus |
| = 3 tokens for |
| 1 forward pass! |
+-------------------+
Buffer Capacity Guard
There is a practical limit on batch size: the token_ids buffer has a fixed
allocation. The code guards against overflow:
int max_batch = (int)(state.scratch.token_ids.size / sizeof(uint32_t)) - 1;
if (n_draft + 1 > max_batch)
n_draft = max_batch - 1;
The -1 accounts for the output slot: the buffer needs batch_size + 1
uint32 slots (input tokens plus one extra output slot at the end).
Speedup Analysis
The theoretical speedup from speculative decoding is:
Speedup = (accepted + 1) / 1 = accepted + 1
That is: if K drafts are accepted, you get K+1 tokens from 1 forward pass instead of K+1 forward passes.
The actual speedup depends on the acceptance rate – what fraction of drafts
match the model’s predictions. Let alpha be the per-token acceptance
probability. With D drafts, the expected number of accepted tokens is:
E[accepted] = sum_{k=0}^{D-1} k * alpha^k * (1-alpha)
            + D * alpha^D (all D accepted)
            = alpha * (1 - alpha^D) / (1 - alpha)
For D=4, alpha=0.5: E[accepted] ~ 0.94
For D=4, alpha=0.7: E[accepted] ~ 1.77
For D=4, alpha=0.9: E[accepted] ~ 3.10
So the expected speedup is roughly:
alpha=0.5: 1.94x (barely worth it)
alpha=0.7: 2.77x (significant)
alpha=0.9: 4.10x (near-optimal)
But wait – there is overhead. The batched forward pass with 5 tokens is not free. It is faster than 5 separate forward passes (because attention over 5 positions is still cheap), but it does more work than a single forward pass.
In practice, on Apple Silicon, the overhead of a 5-token batch vs a 1-token decode is about 20-40%. So the actual speedup is closer to:
alpha=0.5: 1.94 / 1.3 = 1.5x
alpha=0.7: 2.77 / 1.3 = 2.1x
alpha=0.9: 4.10 / 1.3 = 3.2x
Still quite good for alpha >= 0.7.
When Does N-gram Prediction Work Well?
The acceptance rate depends heavily on the content being generated:
High acceptance rate (alpha > 0.8):
- Repetitive code patterns (boilerplate, getters/setters)
- Template-based text (email signatures, legal disclaimers)
- Continuation of previously seen phrases
- JSON/XML with repeated structure
- Code that mirrors the prompt (e.g., implementing similar functions)
Medium acceptance rate (alpha 0.4-0.7):
- Natural language with common phrases
- Code with some repetition but also novel logic
- Structured output that partially follows established patterns
Low acceptance rate (alpha < 0.3):
- Highly creative or novel text
- Mathematical proofs with unique symbol sequences
- First-time generation of a new pattern
- Diverse conversational responses
The n-gram predictor excels in exactly the situations where LLMs are most boring – repetitive, predictable text. This is a happy coincidence: the tokens that are easy to predict are also the tokens that take the most wall-clock time (because there are many of them), so accelerating them has the most impact.
Comparison with Draft-Model Speculative Decoding
The classic approach uses a small draft model (e.g., a 160M parameter model drafting for a 7B target). How does n-gram prediction compare?
+-------------------+-------------------+-------------------+
| | N-gram | Draft model |
+-------------------+-------------------+-------------------+
| Setup cost | None | Load 2nd model |
| Memory | ~1MB tables | 100s MB - GBs |
| Draft speed | ~0 (hash lookup) | Fast but nonzero |
| Novel text | Poor | Good |
| Repetitive text | Excellent | Good |
| Code patterns | Excellent | Good |
| Implementation | ~100 lines | Full model stack |
| Correctness | Exact (greedy) | Exact (rejection) |
+-------------------+-------------------+-------------------+
For akunu’s use case – a lightweight local inference engine on Apple Silicon where memory is precious – n-gram prediction is an excellent choice. It adds negligible memory overhead and provides significant speedup on the kinds of text that dominate many workloads (code, structured output, templates).
Interaction with Greedy vs Sampled Decode
A subtle point: the current implementation uses the regular dispatch_table
(greedy argmax) for verification. This means speculative decode currently
produces the same output as greedy decode – it is a pure speedup with no
quality change.
Extending this to sampled decode requires the rejection sampling variant of
speculative decoding, where draft tokens are accepted probabilistically rather
than by exact match. The acceptance probability for draft token d when the
model would sample t is:
accept with probability min(1, P_target(d) / P_draft(d))
This preserves the target model’s sampling distribution exactly. Akunu does not currently implement this variant, but the n-gram predictor’s frequency counts could be used as approximate draft probabilities if needed.
The Predictor’s Sliding Window
The MAX_HISTORY = 512 sliding window serves two purposes:
1. Memory bound. Without a limit, the frequency tables would grow indefinitely during long generations. The deque automatically evicts old tokens.
2. Recency bias. Patterns from 10,000 tokens ago are less likely to repeat than patterns from 100 tokens ago. The sliding window implicitly prioritizes recent context.
Note that the frequency tables themselves are not pruned when tokens leave the sliding window. Only the history used for context matching is windowed. This means the tables accumulate counts from the entire generation, but prediction only uses the last 512 tokens as context. In practice, this works well – the tables capture global patterns while the context window provides relevance filtering.
Summary
Speculative decoding with n-gram prediction gives akunu a clean, zero-overhead way to accelerate the most tedious part of LLM inference: autoregressive decode of predictable tokens. The implementation is lean – about 200 lines total between the predictor and the decode loop – and the algorithm is remarkably simple: guess tokens from frequency tables, verify in batch, accept the longest matching prefix, roll back the KV cache for rejected drafts.
The key architectural insight is that the n-gram predictor and the batched forward pass are completely decoupled. Any prediction source could be swapped in (a draft model, a lookup table, a regex-based predictor for structured output) without changing the verification logic. The n-gram approach just happens to be the best cost/benefit ratio for a memory-constrained local inference engine.
Grammar-Constrained Generation
Here is a problem you have probably run into: you ask an LLM to produce JSON, and it almost works – except there is a trailing comma, or a missing quote, or it randomly starts explaining the JSON instead of just outputting it. You wrap the prompt in “respond only in valid JSON” and it works 95% of the time. For the other 5%, you add retry logic with exponential backoff. Before long, your “simple API call” has three layers of error handling, and you are still finding edge cases in production.
Grammar-constrained generation solves this problem at the source. Instead of hoping the model produces valid output and cleaning up after it, you force every generated token to be valid according to a formal grammar. The model cannot produce a trailing comma because the grammar says commas must be followed by a value. It cannot produce unmatched braces because the grammar tracks nesting depth. The output is syntactically valid by construction.
This chapter covers how akunu implements grammar-constrained decoding, from the grammar format (GBNF) through the constraint engine (XGrammar), JSON schema conversion, and the GPU-accelerated decode loop.
The Constraint Mechanism
The basic idea is simple: before each sampling step, compute which tokens are valid continuations according to the grammar, then mask all other tokens to negative infinity so they have zero probability.
Standard decode:
logits --> [sample] --> token
Grammar-constrained decode:
logits --> [mask invalid] --> [sample] --> token --> [advance grammar]
The tricky part is doing this efficiently. A vocabulary of 128k tokens and a grammar with hundreds of rules means we need to check 128k tokens against the grammar state at every step. A naive implementation could easily cost more than the model forward pass itself.
GBNF: Grammars for LLMs
Akunu uses GBNF (GGML BNF), a variant of Backus-Naur Form adapted for LLM token-level constraint. Here is what a simple JSON grammar looks like in GBNF:
root ::= object | array
value ::= object | array | string | number | "true" | "false" | "null"
object ::= "{" ws (pair ("," ws pair)*)? ws "}"
pair ::= string ws ":" ws value
array ::= "[" ws (value ("," ws value)*)? ws "]"
string ::= "\"" chars "\""
chars ::= char*
char ::= [^"\\] | "\\" escape
escape ::= ["\\bfnrt/] | "u" [0-9a-fA-F]{4}
number ::= "-"? int frac? exp?
int ::= "0" | [1-9] [0-9]*
frac ::= "." [0-9]+
exp ::= [eE] [+-]? [0-9]+
ws ::= [ \t\n\r]*
Each line defines a rule. Rules can reference other rules (non-terminals)
or specify literal characters and character classes (terminals). The root
rule is the entry point.
GBNF supports:
- Quoted literals: "true", "{", "\n"
- Character classes: [0-9], [a-fA-F], [^"\\]
- Alternation: |
- Repetition: * (zero or more), + (one or more), ? (optional)
- Rule references: object, value, string
Akunu’s GBNF parser (grammar.cpp) is a complete recursive-descent parser that
converts the textual grammar into a vector of GrammarRule objects. Each rule
is a sequence of GrammarElement structs:
enum GrammarElementType : uint32_t {
GTYPE_END = 0, // end of rule
GTYPE_ALT = 1, // alternate (|)
GTYPE_RULE_REF = 2, // non-terminal reference
GTYPE_CHAR = 3, // character match
GTYPE_CHAR_NOT = 4, // inverse class [^...]
GTYPE_CHAR_RNG_UPR = 5, // range upper bound
GTYPE_CHAR_ALT = 6, // additional char in class
GTYPE_CHAR_ANY = 7, // wildcard .
};
struct GrammarElement {
GrammarElementType type;
uint32_t value; // code point, rule ID, etc.
};
The parser handles Unicode escapes (\u0041), hex escapes (\x41), and
standard C escapes (\n, \t, etc.). Rule names are alphanumeric with
hyphens and underscores.
The Legacy NPDA Engine
Akunu’s original grammar engine (before XGrammar) uses a nondeterministic
pushdown automaton (NPDA) to track the grammar state. This is the Grammar
class in grammar.h.
The NPDA state is a set of stacks, where each stack is a sequence of pointers
into the grammar rules. The top of each stack points to the element the grammar
expects to match next. Multiple stacks represent the nondeterminism – when a
rule has alternation (A | B), both alternatives produce separate stacks.
Grammar stacks (simplified for "object" rule):
Stack 0: [char("{"), ws, pair, ...]
Stack 1: [char("{"), ws, char("}")]
Stack 0 expects the object with at least one pair.
Stack 1 expects the empty object "{}".
When the apply() method is called with the logit array, it must:
- For each token in the vocabulary, decode the token to Unicode code points
- For each stack, simulate accepting those code points
- If any stack can accept the full token, the token is valid
- If no stack can accept the token, mask its logit to -inf
This is O(V * S * L) where V is vocab size, S is number of stacks, and L is average token length in code points. For V=128k and complex grammars, this becomes very expensive – often 10-50ms per token, which can dominate the total decode time.
The legacy engine also supports a deferred activation mode for thinking
models. When set_trigger_text() is called with (for example) </think>,
the grammar does not constrain output until that trigger text appears. This
allows models like Qwen3 to emit their <think>...</think> block freely,
then switch to grammar-constrained generation for the actual structured output:
Generation flow with deferred activation:
<think>Let me think about the JSON schema...</think>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Grammar inactive (unconstrained) |
v
{"name": "Alice", "age": 30}
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Grammar active (constrained to JSON)
XGrammar: The Fast Path
The NPDA approach works but is slow. XGrammar (integrated as a third-party library) takes a fundamentally different approach: precompute everything at grammar compile time.
The key insight is that for a given grammar state (which rules/positions are active), the set of valid tokens is deterministic. XGrammar precomputes these valid-token sets as packed bitmasks during grammar compilation, then at runtime, looking up the bitmask for the current state is essentially O(1).
NPDA approach (per token):
+--------------------------------------------------+
| For each token (128k): |
| Decode token to code points |
| For each stack (10-100): |
| Try accepting code points |
| If success: token is valid |
+--------------------------------------------------+
Cost: O(V * S * L) per token step (~10-50ms)
XGrammar approach (per token):
+--------------------------------------------------+
| Look up precomputed bitmask for current state |
| Copy bitmask to GPU |
| Apply bitmask to logits (GPU kernel) |
+--------------------------------------------------+
Cost: O(1) lookup + O(V/32) GPU mask (~0.05ms)
The XGrammar wrapper (xgrammar_wrapper.h and .cpp) provides a clean
interface:
class XGrammarEngine {
public:
void init_vocab(const Tokenizer& tokenizer);
bool compile_json_schema(const std::string& schema);
bool compile_grammar(const std::string& gbnf);
bool compile_builtin_json();
bool fill_next_token_bitmask(int32_t *bitmask);
bool accept_token(int32_t token_id);
void rollback(int n_tokens);
bool is_terminated() const;
bool is_completed() const;
void reset();
int bitmask_size() const;
};
The bitmask format is packed: one bit per token, 32 tokens per int32 word.
For a 128k vocabulary, the bitmask is 128k / 32 = 4096 int32 words = 16 KB.
This is small enough to copy to the GPU every token step with negligible
overhead.
Bitmask layout: int32[ceil(vocab_size / 32)]
Word 0: bits 0-31 = tokens 0-31
Word 1: bits 0-31 = tokens 32-63
Word 2: bits 0-31 = tokens 64-95
...
Word N: bits 0-31 = tokens N*32 .. N*32+31
bit = 1 --> token is valid (keep logit)
bit = 0 --> token is invalid (set logit to -inf)
Vocabulary Initialization
Before compiling any grammar, XGrammar needs the tokenizer’s vocabulary.
init_vocab() decodes every token in the vocabulary to its string
representation:
void XGrammarEngine::init_vocab(const Tokenizer& tokenizer) {
vocab_size_ = tokenizer.vocab_size();
std::vector<std::string> encoded_vocab;
encoded_vocab.reserve(vocab_size_);
for (int i = 0; i < vocab_size_; i++) {
encoded_vocab.push_back(tokenizer.decode((uint32_t)i));
}
auto vocab_type = xgrammar::VocabType::BYTE_FALLBACK;
// ... build TokenizerInfo and GrammarCompiler ...
}
This is a one-time cost that happens during model loading. The decoded strings are what XGrammar uses to determine which tokens can match which grammar positions.
Grammar Compilation
XGrammar supports three compilation modes:
- JSON Schema (compile_json_schema): takes a JSON Schema string and compiles it directly into a grammar matcher. This is the most common path for structured output.
- GBNF grammar (compile_grammar): takes a raw GBNF string. Useful for custom grammars (SQL, code templates, etc.).
- Built-in JSON (compile_builtin_json): a pre-optimized grammar for "any valid JSON". Used when the user requests JSON output without a specific schema.
Each compilation produces a CompiledGrammar object and a GrammarMatcher.
The compilation can take 10-100ms (depending on grammar complexity and
vocabulary size), but it is a one-time cost amortized over the entire
generation.
JSON Schema to GBNF
For the common case of structured JSON output, akunu provides a
json_schema_to_grammar() converter that translates a JSON Schema into GBNF.
Here is an example. Given this JSON Schema:
{
"type": "object",
"properties": {
"name": { "type": "string" },
"age": { "type": "integer" },
"email": { "type": "string", "format": "email" }
},
"required": ["name", "age"]
}
The converter produces GBNF roughly equivalent to:
root ::= "{" ws root-kv-name "," ws root-kv-age ("," ws root-kv-email)? ws "}"
root-kv-name ::= "\"name\"" ws ":" ws json-string
root-kv-age ::= "\"age\"" ws ":" ws json-int
root-kv-email ::= "\"email\"" ws ":" ws json-string
ws ::= [ \t\n\r]*
json-string ::= "\"" json-chars "\""
json-int ::= "0" | [1-9] [0-9]*
...
The converter handles:
- Object properties with required/optional distinction
- Arrays with items, minItems, maxItems
- Strings with format (date, time, date-time, uuid), minLength, maxLength
- Numbers and integers
- Enums and const values
- oneOf, anyOf, allOf composition
- Nested objects and recursive structures (with a depth limit of 64)
- additionalProperties control
The SchemaConverter class builds rules incrementally, generating unique rule
names to avoid collisions. Each schema node is visited recursively, producing
either an inline GBNF expression or a named rule reference:
visit(schema, "root")
|
+--> visit_object(schema, "root")
| |
| +--> visit(name_schema, "root-name")
| | +--> returns "json-string"
| |
| +--> visit(age_schema, "root-age")
| | +--> returns "json-int"
| |
| +--> visit(email_schema, "root-email")
| +--> returns "json-string"
|
+--> add_rule("root", "{" ws kv-name "," ws kv-age ("," ws kv-email)? "}")
The Grammar Bitmask GPU Kernel
Once we have the bitmask from XGrammar, we need to apply it to the logits. This is done by a simple Metal kernel:
kernel void grammar_bitmask_f16(
device half *logits [[buffer(0)]],
device const uint *bitmask [[buffer(1)]],
constant uint &vocab_size [[buffer(2)]],
uint tid [[thread_position_in_grid]])
{
if (tid >= vocab_size)
return;
uint word = bitmask[tid / 32];
bool allowed = (word >> (tid % 32)) & 1;
if (!allowed)
logits[tid] = half(-INFINITY);
}
Each thread handles one token. It reads the appropriate word from the bitmask,
checks the bit for its token, and sets the logit to negative infinity if the
token is not allowed. The kernel dispatches with vocab_size threads and a
threadgroup size of 256.
For a 128k vocabulary:
- 128k threads / 256 per threadgroup = 512 threadgroups
- Each thread does: one integer division, one shift, one AND, one conditional write
- Total: ~microseconds on Apple Silicon
This is negligible compared to the model forward pass.
The Grammar Decode Loop
Let us trace through decode_grammar() step by step. This is the most complex
decode path in akunu because it orchestrates multiple GPU kernels per token.
Setup
bool use_sampling = (sampling.temperature > 0.0f);
uint32_t vocab_count = (uint32_t)state.config.vocab_size;
Pipeline temp_pso = use_sampling
? state.device->get_pipeline("temperature_scale_f16") : Pipeline{};
Pipeline mask_pso = state.device->get_pipeline("grammar_bitmask_f16");
Pipeline argmax_pso = state.device->get_pipeline("argmax_f16");
The function fetches three GPU pipelines:
- temperature_scale_f16: Multiplies logits by 1/T (only if sampling)
- grammar_bitmask_f16: Applies the grammar bitmask
- argmax_f16: Finds the token with the highest logit
If repetition penalty is enabled, a fourth pipeline is fetched:
- repetition_penalty_f16: Penalizes previously generated tokens
The bitmask buffer is allocated once:
int bitmask_words = xg.bitmask_size();
size_t mask_bytes = bitmask_words * sizeof(int32_t);
std::vector<int32_t> bitmask(bitmask_words);
Buffer mask_buf = state.device->allocate(mask_bytes);
// Precompute first mask before entering the loop
xg.fill_next_token_bitmask(bitmask.data());
Per-token pipeline
Each iteration of the decode loop runs this pipeline:
Step 1: Copy bitmask to GPU
+-------------------------------------------+
| memcpy(gpu_mask_buf, bitmask, mask_bytes) |
+-------------------------------------------+
Step 2: Write current token to GPU
+-------------------------------------------+
| write_buffer(token_ids, &next_token, 4) |
+-------------------------------------------+
Step 3: Begin GPU encoding
+-------------------------------------------+
| encode_dispatch_table (model forward) |
| --> produces logits in scratch.logits |
+-------------------------------------------+
Step 4 (optional): Temperature scaling
+-------------------------------------------+
| temperature_scale_f16 kernel |
| logits[i] *= inv_temp for all i |
+-------------------------------------------+
Step 5 (optional): Repetition penalty
+-------------------------------------------+
| repetition_penalty_f16 kernel |
| For each prev token: scale its logit |
+-------------------------------------------+
Step 6: Apply grammar mask
+-------------------------------------------+
| grammar_bitmask_f16 kernel |
| logits[i] = -inf where bitmask bit is 0 |
+-------------------------------------------+
Step 7a (greedy): GPU argmax
+-------------------------------------------+
| argmax_f16 kernel --> winning token ID |
+-------------------------------------------+
Step 7b (sampling): CPU sampling
+-------------------------------------------+
| Copy F16 logits to CPU |
| Convert to F32 |
| sample_logits() with top-k/top-p/min-p |
+-------------------------------------------+
Step 8: Accept token in grammar
+-------------------------------------------+
| xg.accept_token(tok) |
+-------------------------------------------+
Step 9: Compute next bitmask (for next iter)
+-------------------------------------------+
| xg.fill_next_token_bitmask(bitmask) |
+-------------------------------------------+
Steps 3-7 are all encoded into a single GPU command buffer and execute without CPU intervention. Step 7b (CPU sampling) requires a sync point – this is the slower path, used when sampling is enabled but the GPU Gumbel-max kernel is not available for grammar-constrained decode.
Note the ordering: the bitmask for the next token is computed at the end of each iteration (step 9), then applied at the beginning of the next iteration (step 6). This overlaps the CPU-side bitmask computation with GPU execution of the next forward pass – a simple form of pipelining.
Termination conditions
The loop terminates when:
if (state.tokenizer.is_eos(tok)) // Model produced EOS
break;
if (xg.is_terminated()) // Grammar accepted stop token
break;
if (xg.is_completed()) // Root rule fully matched
break;
The grammar can signal completion in two ways:
- Terminated: The grammar accepted a stop token and is done
- Completed: The root rule has been fully matched (all required output has been generated)
Resource cleanup
The function manually frees GPU buffers at the end:
if (argmax_buf.handle)
state.device->free_buffer(argmax_buf);
if (mask_buf.handle)
state.device->free_buffer(mask_buf);
if (rep_tok_buf.handle)
state.device->free_buffer(rep_tok_buf);
These are per-call allocations (the bitmask buffer, argmax output buffer, and repetition penalty token buffer) that are not part of the model’s persistent state.
The GrammarHandle: Dual Engine Support
Akunu supports both the legacy NPDA engine and XGrammar through a unified
GrammarHandle struct:
struct GrammarHandle {
Grammar legacy; // Old NPDA engine
#ifdef AKUNU_HAS_XGRAMMAR
XGrammarEngine xgrammar; // New XGrammar engine
bool use_xgrammar = false;
#endif
};
The decode_grammar() function uses the XGrammar path when available (guarded
by #ifdef AKUNU_HAS_XGRAMMAR). The legacy engine is still available as a
fallback for builds without the XGrammar submodule.
This dual-engine design means:
- Minimal builds (no submodules) still get grammar support via the NPDA engine
- Full builds get the XGrammar path, whose per-token mask computation is roughly three orders of magnitude faster
- The decode loop code is shared; only the mask computation differs
Practical Usage Patterns
JSON output
The most common use case. The user specifies a JSON Schema, akunu converts it to GBNF (or passes it directly to XGrammar), and the output is guaranteed to be valid JSON matching the schema:
User schema: { "type": "object", "properties": { "x": {"type": "number"} } }
Generated: {"x": 42.5} <-- always valid JSON
Never: {"x": 42.5,} <-- no trailing commas
Never: {"x": forty-two} <-- no unquoted values
Never: Here's the JSON: ... <-- no prose
SQL generation
Custom GBNF grammars can constrain output to valid SQL:
root ::= select-stmt
select-stmt ::= "SELECT " columns " FROM " table where? ";"
columns ::= column ("," column)*
column ::= [a-zA-Z_] [a-zA-Z0-9_]*
table ::= [a-zA-Z_] [a-zA-Z0-9_]*
where ::= " WHERE " condition
...
Code generation
Grammar constraints can enforce syntactic validity for code output – matching braces, proper indentation markers, valid identifier characters. This is especially useful for generating configuration files (YAML, TOML) where a single syntax error makes the output unusable.
Performance Characteristics
The performance of grammar-constrained decode depends on which engine is used:
| Operation | NPDA engine | XGrammar |
|---|---|---|
| Grammar compile | ~1ms | 10-100ms |
| Per-token mask (CPU) | 10-50ms | 0.01ms |
| Per-token mask (GPU) | N/A | 0.005ms |
| Mask copy to GPU | N/A | ~0.01ms |
| Bitmask size (128k) | N/A | 16 KB |
| Memory overhead | ~10 KB | ~1 MB (compiled cache) |
XGrammar’s higher compile time is amortized over the entire generation. For a 100-token output, the per-token savings of ~10ms add up to ~1 second saved, which dwarfs the 10-100ms compile cost.
The GPU bitmask kernel is so fast (microseconds) that it is essentially free relative to the model forward pass (~5-10ms). The bottleneck in grammar decode is the model itself, not the grammar constraint.
Summary
Grammar-constrained generation transforms LLMs from “usually produces valid output” to “always produces valid output.” Akunu implements this through a dual-engine approach: a legacy NPDA engine for builds without dependencies, and the XGrammar library for production use. XGrammar precomputes token validity bitmasks at compile time, making per-token constraint nearly free. The GPU kernel applies the bitmask in microseconds, and the decode loop orchestrates model forward pass, temperature scaling, repetition penalty, grammar masking, and argmax/sampling into a single GPU command buffer per token. JSON schema conversion to GBNF handles the most common use case automatically.
The Sampling Pipeline
When temperature > 0, we do not simply pick the most likely token. Instead, we sample from the probability distribution, introducing controlled randomness that makes the model’s output more diverse, creative, and human-like. But turning a raw FP16 logit vector into a sampled token involves a multi-stage pipeline with several filtering steps, each with its own semantics and tradeoffs.
Akunu implements two distinct sampling paths: a CPU path in src/inference/sampling.cpp (used by grammar-constrained decode) and a GPU path via the Gumbel-max kernel (used by the default sampled decode). This chapter covers both, starting with the CPU path because it makes the algorithmic steps explicit, then explaining how the GPU path achieves the same result in a fundamentally different way.
The Full Pipeline: Logits to Token
Here is the complete CPU sampling pipeline, in order:
Raw F16 logits [vocab_size]
│
▼
1. F16 → F32 conversion
│
▼
2. Repetition penalty
│
▼
3. Temperature scaling (logits *= 1/T)
│
▼
4. Softmax (exp + normalize)
│
▼
5. Top-K filter (partial sort, keep K highest)
│
▼
6. Min-P filter (remove tokens below min_p * max_prob)
│
▼
7. Top-P filter (nucleus sampling, cumulative prob ≥ P)
│
▼
8. Re-normalize
│
▼
9. Categorical sampling (uniform random + CDF walk)
│
▼
Token ID
Let’s walk through each stage with the actual code.
The SamplingState Structure
Akunu uses a thread-local SamplingState to avoid repeated allocation:
struct SamplingState {
std::vector<float> logits;
struct Candidate {
uint32_t id;
float prob;
};
std::vector<Candidate> candidates;
int capacity = 0;
void ensure(int vocab_size) {
if (capacity >= vocab_size) return;
logits.resize(vocab_size);
candidates.resize(vocab_size);
capacity = vocab_size;
}
};
extern thread_local SamplingState sampling_state;
The thread_local qualifier is important: since sampling runs on the CPU and could potentially be called from multiple threads (though Akunu currently uses single-threaded decode), the state is per-thread to avoid data races. The ensure method only reallocates if the vocabulary size has grown, so for typical usage (same model throughout), allocation happens exactly once.
The Candidate struct pairs a token ID with its probability, enabling efficient sorting and filtering without losing track of which probability belongs to which token.
Stage 1: F16 to F32 Conversion
The GPU produces logits in FP16. The CPU needs FP32 for numerical stability during softmax:
const __fp16 *f16 = (const __fp16 *)logits_data;
for (int i = 0; i < vocab_size; i++)
logits[i] = (float)f16[i];
This is a straightforward widening conversion. On Apple Silicon, __fp16 is a native type and the conversion is hardware-accelerated through the NEON FP16 extension.1 For a 128K vocabulary, this reads 256KB of FP16 data, writes 512KB of FP32, and takes roughly 50 microseconds.
Stage 2: Repetition Penalty
If enabled, repetition penalty discourages the model from repeating tokens it has already generated:
if (repeat_penalty != 1.0f && prev_tokens && n_prev > 0) {
for (int i = 0; i < n_prev; i++) {
uint32_t tok = prev_tokens[i];
if (tok < (uint32_t)vocab_size) {
if (logits[tok] > 0)
logits[tok] /= repeat_penalty;
else
logits[tok] *= repeat_penalty;
}
}
}
The asymmetric treatment is deliberate:
| Logit Sign | Operation | Effect |
|---|---|---|
| Positive | Divide by penalty | Reduces probability |
| Negative | Multiply by penalty | Increases magnitude (makes it more negative, further reducing probability) |
This ensures that repetition penalty always decreases the probability of previously seen tokens, regardless of whether the logit is positive or negative. A penalty of 1.2 divides a positive logit by 1.2 (roughly a 17% reduction) and increases a negative logit's magnitude by 20%.2
Stage 3: Temperature Scaling
Temperature controls the “sharpness” of the distribution:
if (temperature <= 0.0f) {
// Greedy: return argmax
uint32_t best = 0;
float best_val = logits[0];
for (int i = 1; i < vocab_size; i++) {
if (logits[i] > best_val) {
best_val = logits[i];
best = (uint32_t)i;
}
}
return best;
}
float inv_temp = 1.0f / temperature;
for (int i = 0; i < vocab_size; i++)
logits[i] *= inv_temp;
Multiplying logits by 1/temperature before softmax is mathematically equivalent to dividing the softmax exponents by temperature:
$$\text{softmax}(x_i / T) = \frac{e^{x_i / T}}{\sum_j e^{x_j / T}}$$
| Temperature | Effect | Use Case |
|---|---|---|
| T = 0 | Argmax (greedy) | Deterministic, factual |
| T < 1 | Sharper distribution | More focused, less creative |
| T = 1 | Model’s native distribution | Balanced |
| T > 1 | Flatter distribution | More diverse, more creative |
Note that temperature = 0 is handled as a special case that short-circuits the entire pipeline, returning the argmax directly.
Stage 4: Softmax (Numerically Stable)
The logits are converted to probabilities using softmax with the max-subtraction trick for numerical stability:
// Find max for numerical stability
float max_logit = logits[0];
for (int i = 1; i < vocab_size; i++) {
if (logits[i] > max_logit)
max_logit = logits[i];
}
// Build (index, probability) pairs
float sum = 0;
for (int i = 0; i < vocab_size; i++) {
float p = expf(logits[i] - max_logit);
candidates[i] = {(uint32_t)i, p};
sum += p;
}
float inv_sum = 1.0f / sum;
for (int i = 0; i < vocab_size; i++)
candidates[i].prob *= inv_sum;
Subtracting max_logit before exponentiation prevents overflow. Without this, expf(logits[i]) could produce +inf for large logits, corrupting the entire computation. After subtraction, the largest exponent is expf(0) = 1, and all others are in (0, 1].3
Stage 5: Top-K Filter
Top-K keeps only the K most probable tokens, zeroing out the rest:
int n_cand = vocab_size;
int effective_k = (top_k > 0 && top_k < n_cand) ? top_k : std::min(n_cand, 256);
std::partial_sort(
candidates, candidates + effective_k, candidates + n_cand,
[](const SamplingState::Candidate& a, const SamplingState::Candidate& b) {
return a.prob > b.prob;
});
n_cand = effective_k;
Note the use of std::partial_sort instead of std::sort. Partial sort is O(n log k) compared to full sort’s O(n log n). For a 128K vocabulary with top_k=40, this is roughly 128K * log(40) / (128K * log(128K)) ≈ 30% of the work of a full sort.4
When top_k is 0 or negative, the default cap of 256 is applied. This prevents the subsequent min-p and top-p steps from operating on the full 128K vocabulary, which would be slow.
Stage 6: Min-P Filter
Min-P is a relatively recent sampling technique that removes tokens whose probability is less than min_p * max_probability:5
if (min_p > 0 && n_cand > 0) {
float threshold = candidates[0].prob * min_p;
for (int i = 1; i < n_cand; i++) {
if (candidates[i].prob < threshold) {
n_cand = i;
break;
}
}
}
Because the candidates are already sorted by probability (from the top-K partial sort), candidates[0] holds the maximum probability, and we just scan forward until we find a candidate below the threshold.
Min-P is adaptive: it removes more tokens when the model is confident (one token has very high probability) and fewer tokens when the model is uncertain (many tokens have similar probabilities). This makes it more robust than fixed top-K.
| min_p | When max_prob = 0.9 | When max_prob = 0.1 |
|---|---|---|
| 0.1 | Keep tokens with prob > 0.09 | Keep tokens with prob > 0.01 |
| 0.05 | Keep tokens with prob > 0.045 | Keep tokens with prob > 0.005 |
Stage 7: Top-P (Nucleus) Filter
Top-P sampling keeps the smallest set of tokens whose cumulative probability exceeds p:
if (top_p < 1.0f && n_cand > 0) {
float cumsum = 0;
for (int i = 0; i < n_cand; i++) {
cumsum += candidates[i].prob;
if (cumsum >= top_p) {
n_cand = i + 1;
break;
}
}
}
Since candidates are sorted by descending probability, this walks from the most likely token to the least likely, accumulating probability mass. As soon as the cumulative mass reaches top_p, the remaining tokens are discarded.6
For example, with top_p = 0.95, if the top tokens have probabilities [0.4, 0.3, 0.15, 0.08, 0.04, ...], the cumulative sum reaches 0.97 at the fifth token, so we keep 5 tokens.
Stage 8: Re-Normalization
After filtering, the remaining candidates’ probabilities no longer sum to 1. We fix that:
sum = 0;
for (int i = 0; i < n_cand; i++)
sum += candidates[i].prob;
inv_sum = 1.0f / sum;
for (int i = 0; i < n_cand; i++)
candidates[i].prob *= inv_sum;
Stage 9: Categorical Sampling
Finally, we sample from the filtered distribution using the inverse CDF method:
std::uniform_real_distribution<float> dist(0.0f, 1.0f);
float r = dist(rng);
float cumsum = 0;
for (int i = 0; i < n_cand; i++) {
cumsum += candidates[i].prob;
if (r <= cumsum)
return candidates[i].id;
}
return candidates[n_cand - 1].id;
A uniform random number r in [0, 1) is drawn, and we walk through the CDF until r <= cumsum. The last candidate is returned as a fallback in case of floating-point rounding.
The F32 Overload (Grammar Path)
There is a second sample_logits overload that takes pre-converted F32 logits:
uint32_t sample_logits(float *logits, int vocab_size,
float temperature, int top_k, float top_p, float min_p);
This is used by decode_grammar, where the CPU has already read and masked the logits. The pipeline is identical except it skips the F16->F32 conversion and the repetition penalty (grammar decode applies its own masks). The temperature check also differs slightly:
if (temperature != 1.0f) {
float inv_temp = 1.0f / temperature;
for (int i = 0; i < vocab_size; i++)
logits[i] *= inv_temp;
}
When temperature is exactly 1.0, the scaling is skipped entirely (not just for performance – it avoids introducing unnecessary floating-point error).
The GPU Sampling Path: Gumbel-Max
The GPU path replaces the entire CPU pipeline with a single mathematical trick. Instead of:
- Softmax to get probabilities
- Sample from categorical distribution
We use:
- Add Gumbel noise to logits
- Argmax
The Gumbel-max theorem states:7
$$\arg\max_i \left(\log \pi_i + G_i\right) \sim \text{Categorical}(\pi)$$
where $G_i \sim \text{Gumbel}(0, 1)$ and $\pi_i = \text{softmax}(\text{logit}_i / T)$.
Since $\log \pi_i \propto \text{logit}_i / T$, we can simplify to:
$$\arg\max_i \left(\text{logit}_i + T \cdot G_i\right)$$
This is exactly what the gumbel_topk_f16 kernel computes:
float u = pcg_float(element_seed + i);
u = clamp(u, 1e-7f, 1.0f - 1e-7f);
float gumbel = -log(-log(u));
logits[i] = half(val + temp * gumbel);
The Gumbel noise is generated from a uniform random variable via the inverse CDF: $G = -\log(-\log(U))$ where $U \sim \text{Uniform}(0, 1)$.
The beauty of this approach is that top-k, top-p, and min-p are applied before the noise, so they function identically to the CPU path. Tokens that are filtered out get set to -inf, which means they can never win the argmax regardless of the noise.
SamplingParams: The Configuration
All sampling parameters come from AkunuSamplingConfig:
| Parameter | Type | Default | Meaning |
|---|---|---|---|
| temperature | float | 0.0 | Softmax temperature (0 = greedy) |
| top_k | int | 0 | Keep only top-K tokens (0 = no limit, uses 256 cap) |
| top_p | float | 1.0 | Nucleus sampling threshold |
| min_p | float | 0.0 | Minimum probability ratio threshold |
| repeat_penalty | float | 1.0 | Repetition penalty factor |
These parameters can be combined freely. Some common configurations:
| Use Case | temperature | top_k | top_p | min_p |
|---|---|---|---|---|
| Greedy/Factual | 0.0 | – | – | – |
| Balanced | 0.7 | 40 | 0.95 | 0.05 |
| Creative | 1.0 | 0 | 0.9 | 0.0 |
| Very Creative | 1.2 | 0 | 0.95 | 0.0 |
| Code Gen | 0.2 | 40 | 0.95 | 0.1 |
Seed Handling and Reproducibility
The GPU path uses a PCG hash for random number generation:
inline float pcg_float(uint state) {
state = state * 747796405u + 2891336453u;
uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
word = (word >> 22u) ^ word;
return float(word) / 4294967296.0f;
}
This is a stateless hash: given the same input state, it always produces the same output. The state is derived from:
uint element_seed = (params.position + params.seed_offset) * 2654435761u;
The seed_offset is a time-based value set once per generation call, and position is patched per-token in the chain. The multiplication by 2654435761 (a Knuth multiplicative hash constant) spreads the entropy across the bits.8
For the CPU path, Akunu uses std::uniform_real_distribution with a thread-local rng (random number generator). The RNG state persists across tokens within a generation, providing a different random sequence for each call.
CPU vs GPU: When to Use Which
| Aspect | CPU Path | GPU Path |
|---|---|---|
| Latency per token | ~100us (F16->F32 + sort + sample) | ~0us (no CPU roundtrip) |
| Required for | Grammar decode | Non-grammar sampled decode |
| Top-K method | std::partial_sort | Binary search on logit values |
| RNG quality | std::mt19937 (Mersenne Twister) | PCG hash (stateless) |
| Chain decode compatible | No (requires CPU round-trip) | Yes |
The GPU path is strictly faster because it eliminates the CPU roundtrip. However, the CPU path is necessary when per-token CPU processing is required (grammar constraints, custom logit processors).
Summary
The sampling pipeline transforms raw logits into a sampled token through a carefully ordered sequence of filtering and normalization steps:
- Temperature controls distribution sharpness.
- Top-K provides a hard cutoff on candidate count.
- Min-P provides an adaptive cutoff based on the most likely token.
- Top-P provides a probability-mass-based cutoff.
- Categorical sampling draws from the filtered distribution.
Akunu offers two implementations: a CPU path for maximum flexibility (grammar constraints) and a GPU path (Gumbel-max) for maximum throughput (chain decode compatible). The GPU path achieves the same statistical result as the CPU path using a mathematical equivalence that replaces softmax + sampling with noise + argmax.
1. ARM. “ARM Architecture Reference Manual: ARMv8-A.” Section C7.2, FCVT instruction. The __fp16-to-float conversion on AArch64 uses the FCVT instruction, which executes in 1-2 cycles. See https://developer.arm.com/documentation/ddi0487/latest/. ↩
2. Keskar, N. S., et al. “CTRL: A Conditional Transformer Language Model for Controllable Generation.” arXiv:1909.05858, 2019. The repetition penalty was popularized by the CTRL model and adopted by most LLM inference frameworks. See https://arxiv.org/abs/1909.05858. ↩
3. The log-sum-exp trick is a standard technique in numerical computing. See Blanchard, P., Higham, D.J., and Higham, N.J. “Accurately Computing the Log-Sum-Exp and Softmax Functions.” IMA Journal of Numerical Analysis, 2021. See https://doi.org/10.1093/imanum/draa038. ↩
4. std::partial_sort uses the introsort algorithm with a heapselect fallback. In practice, it performs approximately n * log(k) comparisons to select the top k of n elements. See Musser, D.R. “Introspective Sorting and Selection Algorithms.” Software: Practice and Experience, 1997, and https://en.wikipedia.org/wiki/Introsort. ↩
5. Min-P sampling was introduced in the open-source LLM community and formalized in “Min-P Sampling: New Way to Control the Creativity of LLMs,” a llama.cpp GitHub discussion, 2023. It provides more consistent behavior across different probability distributions than fixed top-K. See https://github.com/ggerganov/llama.cpp/pull/3841. ↩
6. Holtzman, A., et al. “The Curious Case of Neural Text Degeneration.” ICLR 2020. This paper introduced nucleus (top-p) sampling and showed that it produces more coherent text than top-k sampling for open-ended generation. See https://arxiv.org/abs/1904.09751. ↩
7. The Gumbel-max trick is proven in Maddison, C.J., Tarlow, D., and Minka, T. “A* Sampling.” NeurIPS 2014. A modern treatment with applications to neural networks: Jang, E., Gu, S., and Poole, B. “Categorical Reparameterization with Gumbel-Softmax.” ICLR 2017. See https://arxiv.org/abs/1411.0030. ↩
8. O’Neill, M. “PCG: A Family of Simple Fast Space-Efficient Statistically Good Algorithms for Random Number Generation.” Harvey Mudd College Technical Report HMC-CS-2014-0905, 2014. The PCG hash provides good statistical properties with minimal state, making it ideal for GPU kernels where per-thread state is expensive. See https://www.pcg-random.org/paper.html. ↩
Kernel Architecture Overview
If you have made it this far in the book, you understand how akunu’s Swift host code
orchestrates inference – how buffers are allocated, how the command encoder sequences
dispatches, how the KV cache grows and shifts. But we have been treating each
computeEncoder.dispatchThreadgroups(...) call as a black box. Time to open every
single one of those boxes.
Akunu ships roughly 135 Metal shader files, spread across 12 subdirectories, each one a hand-tuned GPU program that runs on Apple Silicon. This chapter is the map. We will cover the directory layout, the shared header infrastructure, the parameter struct conventions, the function-constant specialization system, the buffer binding conventions, and finally a taxonomy of every kernel category. By the end you will know where to look for any particular kernel and why it was written the way it was.
33.1 The 10,000-Foot View: Directory Layout
All Metal source lives under a single tree:
backend/metal/kernels/
ShaderTypes.h <-- Parameter structs (shared with Swift)
KernelCommon.h <-- Block types, helpers, constants
metal/kernel/
activation/ <-- 2 files (silu, gelu)
attention/ <-- 6 files (decode, fast-decode, parallel, prefill, softmax, logit_cap)
common/ <-- 6 files (bias_add, residual_add, vector_add, transpose, head_rearrange, qkv_split)
conv/ <-- 1 file (conv1d_f16)
convert/ <-- 12 files (dequant for every quant format + f16<->f32)
embedding/ <-- 12 files (lookup for every quant format + bf16 + pos_embed_add)
fused/ <-- 2 files (gemv_q8_0_head_rmsnorm, whisper_gemv_fused)
kv_cache/ <-- 2 files (kv_cache_write, kv_cache_shift)
matmul/ <-- 72 files (gemv, gemv_wide, gemv_batched, simd_gemm, simd_gemm_small)
norm/ <-- 5 files (rmsnorm, rmsnorm_gemma, layernorm, residual_rmsnorm, head_rmsnorm)
rope/ <-- 7 files (rope, rope_neox, fused rope+kv variants)
sampling/ <-- 7 files (argmax, topk, temperature, repetition_penalty, gumbel, grammar, whisper_suppress)
Let us count that up:
+---------------------+-------+
| Category | Files |
+---------------------+-------+
| matmul (GEMV/GEMM) | 72 |
| embedding | 12 |
| convert (dequant) | 12 |
| sampling | 7 |
| rope | 7 |
| attention | 6 |
| common | 6 |
| norm | 5 |
| activation | 2 |
| fused | 2 |
| kv_cache | 2 |
| conv | 1 |
+---------------------+-------+
| Total | ~135 |
+---------------------+-------+
The overwhelming majority – 72 out of 135 – are in matmul/. That is not a
surprise. The single hottest operation in LLM inference is the matrix-vector
multiply (GEMV for decode) and matrix-matrix multiply (GEMM for prefill). When
you multiply those two kernel shapes (GEMV + GEMM) by the number of quantization
formats akunu supports (Q4_0, Q4_1, Q4_K, Q5_0, Q5_K, Q6_K, Q8_0, Q2_K, Q3_K,
F16, BF16, MLX-Q4, MLX-Q3, MLX-Q6, MLX-Q8), and then add the variant axes
(standard, wide, batched, small-M, fused-silu), you get a combinatorial explosion
of kernel files.
Here is how the matmul/ directory decomposes:
matmul/
gemv_*.metal            (18 files) -- Single-row GEMV (M=1 decode)
gemv_wide_*.metal       ( 7 files) -- Wide GEMV (large N, e.g. vocab projection)
gemv_batched_*.metal    (12 files) -- Batched GEMV (M=2..16, speculative decode)
simd_gemm_*.metal       (17 files) -- Tiled GEMM (M>=32, prefill)
simd_gemm_small_*.metal (16 files) -- Small-M GEMM (M=2..8, small prefill)
gemv_*_silu.metal       ( 5 files) -- Fused SiLU(gate)*up + GEMV
That is the landscape. Let us now zoom in on the two shared header files that every kernel includes.
33.2 ShaderTypes.h: The Contract Between Swift and Metal
Every Metal kernel receives its parameters via a constant buffer. On the Swift side,
that buffer is a MTLBuffer filled with a C struct. On the GPU side, the kernel
reads that same struct. The contract between the two is ShaderTypes.h.
There is a critical comment at the top of this file:
/*
* Shared type definitions used by both Metal kernels (.metal) and Swift host code.
*
* CRITICAL: Any change here MUST be mirrored in Sources/KernelStore/MetalTypes.swift.
* All structs are padded to 16-byte boundaries for Metal argument buffer alignment.
*/
Metal requires that constant buffer offsets and struct sizes be aligned to 16 bytes.
Every struct in ShaderTypes.h is manually padded with _pad fields to guarantee
this. Let us walk through every parameter struct.
GEMMParams (32 bytes)
struct GEMMParams {
uint32_t M; // Rows of A / rows of C
uint32_t N; // Columns of B / columns of C
uint32_t K; // Columns of A / rows of B
uint32_t lda; // Leading dimension of A
uint32_t ldb; // Leading dimension of B
uint32_t ldc; // Leading dimension of C
float alpha; // C = alpha * A @ B + beta * C
float beta;
};
This is the workhorse parameter struct. It is used by every GEMV and GEMM kernel.
The M, N, K triple defines the matrix dimensions, and lda/ldb/ldc are
the leading dimensions (row strides), allowing matrices that are sub-views of larger
buffers. The alpha/beta fields enable BLAS-style C = alpha*A*B + beta*C
semantics – useful for residual connections and accumulation.
Here is how it maps onto the GEMV case (M=1):
x [1, K] @ W^T [N, K] --> y [1, N]
+---+---+---+---+---+---+---+---+ GEMMParams
| M | N | K |lda|ldb|ldc| a | b |
| 1 |4096|4096|4096|4096|4096|1.0|0.0|
+---+---+---+---+---+---+---+---+
0 4 8 12 16 20 24 28 byte offset
AttentionParams (32 bytes)
struct AttentionParams {
uint32_t seq_len;
uint32_t kv_seq_len;
uint32_t head_dim;
uint32_t n_heads;
uint32_t n_kv_heads;
float scale; // 1.0 / sqrt(head_dim)
uint32_t kv_stride; // elements between KV heads
uint32_t q_stride; // elements between Q/O rows
};
The attention params carry everything the flash-attention kernels need. The
kv_seq_len can differ from seq_len during decode (where seq_len=1 but
kv_seq_len could be thousands). The kv_stride and q_stride fields allow
flexible memory layouts – if zero, the kernel falls back to kv_seq_len * head_dim
or n_heads * head_dim respectively. This is what lets the same kernel handle both
contiguous and interleaved head layouts.
RMSNormParams (16 bytes)
struct RMSNormParams {
uint32_t dim;
float eps;
uint32_t _pad0;
uint32_t _pad1;
};
Minimal. Just the dimension and epsilon. The two pad fields bring it to exactly 16 bytes. The kernel figures out which row to process from its threadgroup position in the grid.
RoPEParams (32 bytes)
struct RoPEParams {
uint32_t seq_len;
uint32_t head_dim;
uint32_t n_heads;
uint32_t pos_offset; // global position for decode step
float theta; // base frequency (default 10000.0)
uint32_t row_stride; // elements between rows
uint32_t _pad0;
uint32_t _pad1;
};
The pos_offset is the key field here – during decode, each token’s position is
pos_offset, not derived from the sequence index. The theta field (default
10000.0) is the RoPE base frequency, configurable per model.
MLXParams (32 bytes)
struct MLXParams {
uint32_t M;
uint32_t N;
uint32_t K;
uint32_t group_size; // quantization group size (typically 64)
uint32_t bits; // bits per value (4 or 8)
uint32_t weight_bytes; // byte offset to scales section
uint32_t _pad0;
uint32_t _pad1;
};
MLX-format weights pack everything into a single contiguous buffer:
[packed_weights | scales | biases]. The weight_bytes field tells the kernel
where the scales section starts, and from there the biases follow immediately at
scales + N * (K / group_size). The bits field selects between 3-bit, 4-bit,
6-bit, and 8-bit dequantization paths.
Fused Parameter Structs
Akunu has several fused kernels that combine two operations into one dispatch. Each has its own parameter struct:
RoPEQKVWriteParams (32 bytes) -- Fused Q/K-RoPE + KV cache write
KVCacheWriteParams (32 bytes) -- KV cache write (separate K and V)
KVCacheShiftParams (32 bytes) -- KV cache left-shift (ring-buffer eviction)
GEMVHeadNormParams (32 bytes) -- Fused GEMV + per-head RMSNorm
GEMVKVParams (32 bytes) -- Fused GEMV + KV cache write
HeadNormParams (32 bytes) -- Per-head RMSNorm (standalone)
Conv1DParams (32 bytes) -- Conv1D parameters (Whisper)
And simpler ones:
ElementwiseParams (16 bytes) -- Just a count
SoftmaxParams (16 bytes) -- rows, cols
EmbeddingParams (16 bytes) -- num_tokens, dim
LayerNormParams (16 bytes) -- dim, eps
TemperatureScaleParams (16 bytes) -- inv_temperature, count
RepetitionPenaltyParams (16 bytes) -- penalty, n_tokens
The pattern is consistent: every struct is either 16 or 32 bytes, always 16-byte aligned, always with explicit padding.
33.3 KernelCommon.h: Shared Infrastructure
Every .metal file includes KernelCommon.h. This header defines:
Hardware Constants
constant constexpr uint SIMD_WIDTH = 32;
constant constexpr uint MAX_TG_MEMORY = 32768; // 32 KB
constant constexpr uint SIMD_TILE = 8; // native simdgroup_matrix dimension
Apple Silicon GPUs have a SIMD width of 32 threads (the same as NVIDIA’s 32-wide warps, unlike AMD’s 64-wide wavefronts).
The SIMD_TILE = 8 is the native dimension of Apple’s simdgroup_matrix operations
– all simdgroup matrix operations work on 8x8 tiles.
GEMM Tiling Constants
constant constexpr uint TILE_M = 64;
constant constexpr uint TILE_N = 64;
constant constexpr uint TILE_K = 32;
constant constexpr uint GEMM_TG_WIDTH = 32;
constant constexpr uint GEMM_TG_HEIGHT = 4;
constant constexpr uint GEMM_TG_SIZE = 128; // 4 SIMD groups
The make_uniform() Helper
inline int make_uniform(int val) {
return simd_broadcast_first(val);
}
This is a surprisingly important optimization. When you write for (int i = 0; i < N; i++),
the Metal compiler does not know whether N is the same across all threads. If it
might differ, the compiler must generate divergent branching code. By wrapping
the loop bound in make_uniform(), you explicitly tell the compiler “this value is
identical across all threads in the SIMD group,” enabling it to use uniform branch
prediction and avoid per-lane divergence handling.
You will see make_uniform() wrapped around virtually every loop bound in every
kernel.
Quantized Block Types
#define QK4_0 32 // elements per Q4_0 block
#define QK8_0 32 // elements per Q8_0 block
#define QK_K 256 // elements per K-quant superblock
struct block_q4_0 {
half d; // scale factor (2 bytes)
uint8_t qs[QK4_0 / 2]; // 16 bytes of nibble-packed values
}; // Total: 18 bytes per 32 elements = 4.5 bits/element
struct block_q8_0 {
half d; // scale factor (2 bytes)
int8_t qs[QK8_0]; // 32 bytes of 8-bit values
}; // Total: 34 bytes per 32 elements = 8.5 bits/element
struct block_q4_K {
half d; // super-block scale (2 bytes)
half dmin; // super-block min scale (2 bytes)
uint8_t scales[K_SCALE_SIZE]; // 12 bytes of packed 6-bit scales
uint8_t qs[QK_K / 2]; // 128 bytes of nibble-packed values
}; // Total: 144 bytes per 256 elements = 4.5 bits/element
Here is a visual layout of block_q4_0:
block_q4_0: 18 bytes total, 32 elements
+---------+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
| d (f16)|q0|q1|q2|q3|q4|q5|q6|q7|q8|q9|qA|qB|qC|qD|qE|qF|
+---------+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
| 2 bytes | 16 bytes |
+---------+-------------------------------------------------+
Each byte qs[j] holds two 4-bit values:
element j = (qs[j] & 0x0F) - 8 (low nibble)
element j + 16 = (qs[j] >> 4) - 8 (high nibble)
Dequantization: value = d * (nibble - 8)
And block_q4_K, the “K-quant” superblock:
block_q4_K: 144 bytes total, 256 elements
+---------+---------+------------+-----------------------------------+
| d (f16) | dmin    | scales[12] | qs[128]                           |
| 2 bytes | 2 bytes | 12 bytes   | 128 bytes                         |
+---------+---------+------------+-----------------------------------+
The 256 elements are divided into 8 sub-blocks of 32 elements each.
Each sub-block has its own 6-bit scale and 6-bit min, packed into
the 12-byte scales array.
Dequantization:
value = d * sub_scale * nibble - dmin * sub_min
Threadgroup Reduction Helpers
inline float tg_reduce_sum(float val, uint sgid, uint slid,
uint tg_size, threadgroup float *shared);
inline float tg_reduce_max(float val, uint sgid, uint slid,
uint tg_size, threadgroup float *shared);
These implement the classic two-phase reduction pattern:
Phase 1: simd_sum() within each SIMD group
Phase 2: Write lane-0 results to shared[], barrier, then simd_sum() in SG 0
Example: 256 threads = 8 SIMD groups
SG0: simd_sum → shared[0]
SG1: simd_sum → shared[1]
...
SG7: simd_sum → shared[7]
--- barrier ---
SG0: reads shared[0..7], simd_sum → shared[0]
--- barrier ---
All threads read shared[0]
Activation Functions
inline half act_silu(half x) {
return x / (half(1) + exp(-x));
}
inline float act_gelu_f32(half x) {
float xf = float(x);
constexpr float SQRT_2_OVER_PI = 0.7978845608f;
constexpr float GELU_COEF_A = 0.044715f;
return 0.5f * xf * (1.0f + precise::tanh(SQRT_2_OVER_PI * xf * (1.0f + GELU_COEF_A * xf * xf)));
}
Note that act_gelu_f32 returns float, not half. The comment explains why:
“returns float to avoid F16 overflow in gate*up multiplication.” When you compute
gelu(gate) * up, both operands can be large in FP16, and the product can overflow.
Computing in FP32 avoids this.
33.4 Function Constant Specialization
Metal’s function constants are compile-time values that can vary between pipeline states. Think of them as C++ template parameters that you set at pipeline creation time, not at shader compile time. The Metal compiler generates specialized variants with the constants inlined, enabling dead-code elimination, loop unrolling, and strength reduction.
Akunu uses function constants extensively. Here are the key ones:
FC_HEAD_DIM [[function_constant(0)]] -- Attention head dimension (64/128/256)
FC_GROUP_SIZE [[function_constant(0)]] -- MLX quantization group size
FC_K_DIM [[function_constant(1)]] -- MLX K dimension
FC_NQ_ROWS [[function_constant(2)]] -- Adaptive query rows for prefill
FC_NON_CAUSAL [[function_constant(3)]] -- Non-causal attention flag
FC_GEMM_K [[function_constant(10)]] -- GEMM K dimension
FC_BATCH_M [[function_constant(11)]] -- Batched GEMV M dimension
The pattern for using them looks like this:
constant uint FC_HEAD_DIM [[function_constant(0)]];
constant bool FC_ATTN_SPECIALIZED = is_function_constant_defined(FC_HEAD_DIM);
// In the kernel:
const uint head_dim = FC_ATTN_SPECIALIZED ? FC_HEAD_DIM : params.head_dim;
When the host creates a pipeline state with FC_HEAD_DIM = 128, the compiler
replaces head_dim with the literal 128 everywhere, enabling:
- Loop unrolling: for (uint d = 0; d < head_dim; d += 8) becomes a known-count loop
- Dead code elimination: if (head_dim == 256) branches can be removed
- Register allocation: the compiler knows exactly how many registers to allocate
- Shift/mask optimization: idx / head_dim becomes idx >> 7 for head_dim=128
This is why you see patterns like:
const uint hd_shift = (head_dim == 128) ? 7 : (head_dim == 64) ? 6 : (head_dim == 256) ? 8 : 0;
const uint hd_mask = head_dim - 1;
When FC_HEAD_DIM = 128, the compiler collapses this to hd_shift = 7 and
hd_mask = 127, and every subsequent idx / head_dim becomes idx >> 7.
33.5 Buffer Binding Convention
Every kernel follows a consistent buffer binding pattern:
Buffer 0: Input data (activations, Q, token_ids, etc.)
Buffer 1: Weights (W, K, packed MLX buffer, etc.)
Buffer 2: Output (y, O, etc.)
Buffer 3: Parameters (GEMMParams, AttentionParams, etc.)
Buffer 4+: Optional (V for attention, tree_mask, etc.)
For matmul kernels specifically:
buffer(0) = x -- activation vector [K] or [M, K]
buffer(1) = W -- weight matrix [N, K] (possibly quantized)
buffer(2) = y -- output vector [N] or [M, N]
buffer(3) = params -- GEMMParams struct
For attention kernels:
buffer(0) = Q -- query [n_heads, seq_len, head_dim]
buffer(1) = K -- key [n_kv_heads, kv_seq_len, head_dim]
buffer(2) = V -- value [n_kv_heads, kv_seq_len, head_dim]
buffer(3) = O -- output [n_heads, seq_len, head_dim]
buffer(4) = params -- AttentionParams struct
buffer(5) = tree_mask -- optional speculative decoding mask
For fused kernels that combine two operations (like gemv_q4_0_silu), the input
buffer splits:
buffer(0) = gate -- gate projection output
buffer(1) = up -- up projection output
buffer(2) = W -- down projection weights
buffer(3) = y -- output
buffer(4) = params -- GEMMParams
This convention is enforced by the Swift host code in KernelStore, which maps
each buffer to its binding index at pipeline creation time.
33.6 Kernel Naming Conventions
The naming follows a strict {operation}_{quantformat}[_{variant}] pattern:
gemv_f16 -- GEMV, F16 weights, standard
gemv_q4_0 -- GEMV, Q4_0 weights, K < 2048
gemv_q4_0_l -- GEMV, Q4_0 weights, K >= 2048 (large)
gemv_q4_0_silu -- GEMV, Q4_0 weights, fused SiLU activation
gemv_wide_q4_0 -- GEMV, Q4_0 weights, wide variant (large N)
gemv_batched_q4_0 -- GEMV, Q4_0 weights, batched (M=2..16)
simd_gemm_q4_0 -- GEMM, Q4_0 weights, standard tile
simd_gemm_small_q4_0-- GEMM, Q4_0 weights, small M tile
gemv_mlx_q4 -- GEMV, MLX 4-bit format
gemv_mlx_q4_l -- GEMV, MLX 4-bit format, large (8 SGs)
The _l suffix denotes the “large K” variant. The _silu suffix denotes fused
SiLU activation. The _small infix in the GEMM names denotes the TM=8 tile geometry
optimized for small batch sizes.
33.7 Kernel Category Deep Dive
Let us briefly survey each category before we dive deep in subsequent chapters.
Matmul (72 files)
This is the heart of akunu. Three major kernel families:
GEMV (M=1 decode): The single-token decode path. One activation vector multiplied by the weight matrix. This is memory-bandwidth-bound – you read the entire weight matrix for a single dot product per row. The key optimization is reading the weight once and computing multiple output rows per SIMD group (NR=4 typically).
GEMM (M>=32 prefill): The multi-token prefill path. Uses Apple’s
simdgroup_multiply_accumulate for native 8x8 matrix operations. Tile geometry
is TM=32, TN=64, TK=32, with 4 SIMD groups cooperatively loading tiles into
threadgroup memory.
Batched GEMV (M=2..16): Bridges the gap between GEMV and GEMM. For speculative decoding or small batches, neither pure GEMV (reads weights M times) nor GEMM (wastes tile space when M<32) is optimal. Batched GEMV reads weights once and computes all M activation rows simultaneously.
Attention (6 files)
Four distinct flash-attention implementations, each tuned for a different scenario:
- Standard Decode (flash_attention.metal: flash_attention_decode_f16): Multi-SG, register-local V accumulation, 2 barriers per KV entry.
- Fast Decode (flash_attention_decode_fast.metal): Single SIMD group, zero threadgroup barriers, each thread handles head_dim/32 elements.
- Parallel Decode (flash_attention_decode_parallel.metal): 32 SIMD groups (1024 threads), strided KV access for maximum memory parallelism, cross-SG reduction.
- Prefill V2 (flash_attention_prefill_v2.metal): BQ=32 query rows, simdgroup register accumulators, exp2 trick, row-level reduce via simd_shuffle_xor.
Plus a standalone softmax kernel and a Gemma logit soft-capping kernel.
Norm (5 files)
rmsnorm.metal -- Standard RMSNorm: y = (x / rms) * weight
rmsnorm_gemma.metal -- Gemma variant: y = (x / rms) * (1 + weight)
layernorm.metal -- Full LayerNorm with mean subtraction
residual_rmsnorm.metal -- Fused residual add + RMSNorm
head_rmsnorm.metal -- Per-head RMSNorm (for architectures like Cohere)
All follow the same pattern: one threadgroup per row, threads stride over the dimension, two-phase reduction for sum-of-squares.
RoPE (7 files)
rope.metal -- Standard rotary position embeddings
rope_neox.metal -- GPT-NeoX interleaved RoPE
rope_kv_write.metal -- Fused RoPE + KV cache write
rope_neox_kv_write.metal-- Fused NeoX RoPE + KV cache write
*_batch.metal -- Batch (prefill) variants of the above
head_norm_rope_neox_kv_write.metal -- Triple fused: head norm + RoPE + KV write
The fused variants are critical for performance – they eliminate intermediate buffer writes between RoPE and KV cache insertion.
Embedding (12 files)
One embedding lookup kernel per quantization format. Each dequantizes on the fly
during the lookup, converting the quantized embedding table directly to FP16 output.
The MLX variants handle the packed [weights|scales|biases] buffer layout.
Sampling (7 files)
argmax.metal -- Simple argmax (greedy decoding)
topk_select.metal -- Top-K selection for sampling
temperature_scale.metal -- Temperature scaling (logits *= 1/T)
repetition_penalty.metal-- Repetition penalty application
gumbel_topk.metal -- Gumbel-max trick for stochastic top-K
grammar_bitmask.metal -- Grammar-constrained decoding mask
whisper_suppress.metal -- Whisper-specific token suppression
Convert/Dequant (12 files)
Standalone dequantization kernels that convert quantized buffers to FP16. These are used when a kernel does not have a native quantized variant, or for debugging.
Common Utilities (6 files)
Elementwise operations: bias_add, residual_add, vector_add, transpose,
head_rearrange (permute between [batch, seq, heads, dim] and
[batch, heads, seq, dim]), qkv_split (split a fused QKV projection output
into separate Q, K, V buffers).
33.8 The Threadgroup Geometry Taxonomy
One of the most confusing aspects of reading akunu’s kernels is that different kernel families use radically different threadgroup geometries. Here is a reference card:
+-------------------------------+--------+-------+--------+--------+
| Kernel | TG Size| # SGs | Rows/TG| Notes |
+-------------------------------+--------+-------+--------+--------+
| gemv_f16 | 128 | 4 | 16 | NR=4 |
| gemv_q4_0 (small K) | 128 | 4 | 16 | NQ=16 |
| gemv_q4_0_l (large K) | 256 | 8 | 32 | NQ=16 |
| gemv_q8_0 | 256 | 8 | 32 | NR=4 |
| gemv_q4_k | 256 | 8 | 16 | nr0=2 |
| gemv_wide_* | 256 | 8 | 64 | NCOLS=8|
| gemv_batched_* | 128/256| 4/8 | 16/32 | M<=16 |
| gemv_mlx_q4 | 128 | 4 | 16 | NR=4 |
| gemv_mlx_q4_l | 256 | 8 | 32 | NR=4 |
| simd_gemm_* | 128 | 4 | 32x64 | Tiled |
| simd_gemm_small_* | 128 | 4 | 8x64 | Tiled |
| flash_attention_decode | 128 | 4 | 1 head | per-TG |
| flash_attention_decode_fast | 32 | 1 | 1 head | no bar |
| flash_attention_decode_par | 1024 | 32 | 1 head | max BW |
| flash_attention_prefill_v2 | 128 | 4 | 32 Q | BQ=32 |
| rmsnorm, softmax | varies |varies | 1 row | per-TG |
+-------------------------------+--------+-------+--------+--------+
The pattern: GEMV kernels use 128-256 threads with 4-8 SIMD groups, each SG computing 2-8 output rows. GEMM kernels use 128 threads with 4 SGs in a 2D grid. Attention has three completely different geometries depending on the decode scenario.
33.9 Memory Access Patterns: Why It All Matters
Apple Silicon’s GPU shares unified memory with the CPU, but bandwidth is still the limiting factor for LLM inference. M2 Pro provides about 200 GB/s, M3 Max about 400 GB/s. A 7B model at Q4_0 is about 3.5 GB of weights. At 200 GB/s, you can read the entire model in ~17.5 ms, giving a theoretical ceiling of ~57 tokens/second for decode (one full weight read per token).
This means every wasted byte of memory bandwidth directly costs throughput. The kernels are designed around three principles:
- Read weights once, compute multiple outputs. GEMV kernels process NR=4 output rows per SIMD group, amortizing the activation vector read.
- Vectorized loads. Using half4 (8 bytes) or float4 (16 bytes) loads instead of scalar loads gives 4x-8x better memory throughput.
- Dequantize on-the-fly. Never materialize the full FP16 weight matrix. Read the quantized blocks, dequantize in registers, multiply, accumulate.
The next three chapters will dive deep into the actual implementations: GEMV (Chapter 34), GEMM (Chapter 35), and FlashAttention (Chapter 36).
33.10 How a Single Inference Step Maps to Kernels
To tie it all together, here is the kernel sequence for a single decode step of a typical Llama-style model:
1. embedding_lookup_q4_0 -- Token embedding dequant + lookup
2. For each transformer layer:
a. rmsnorm_f16 -- Attention norm
b. gemv_q4_0 (x3) -- Q, K, V projections
c. rope_neox_kv_write -- Fused RoPE + KV cache write
d. flash_attention_decode_* -- Attention (variant depends on kv_seq_len)
e. gemv_q4_0 -- Output projection
f. residual_add -- Residual connection
g. rmsnorm_f16 -- FFN norm
h. gemv_q4_0 (x2) -- Gate and Up projections
i. silu -- SiLU activation (or fused into gemv_q4_0_silu)
j. gemv_q4_0 -- Down projection
k. residual_add -- Residual connection
3. rmsnorm_f16 -- Final norm
4. gemv_wide_q4_0 -- LM head (vocab projection, large N)
5. temperature_scale_f16 -- Temperature scaling
6. argmax / topk_select -- Sampling
That is roughly 13 kernel dispatches per layer, plus 4 for the head. For a
32-layer model, that is 420+ kernel dispatches per token. Each one is a
separate computeEncoder.dispatchThreadgroups() call. The overhead per dispatch
on Apple Silicon is about 1-3 microseconds, so dispatch overhead alone is
0.5-1.3 ms – which is why kernel fusion (like gemv_q4_0_silu and
rope_neox_kv_write) matters so much.
With this map in hand, let us dive into the actual kernel implementations. Chapter 34 starts with the GEMV kernels – the single hottest code path in decode.
GEMV: The Workhorse of Decode
During token-by-token decode, every linear projection in the transformer – Q, K, V, output, gate, up, down, and the final logit projection – is a matrix-vector multiply (GEMV). The activation is a single row ([1, K]), the weight matrix is [N, K], and the output is [1, N]. Since every generated token triggers all of these GEMVs, the GEMV kernels are the single most important performance factor in decode throughput.
Akunu ships a large family of GEMV kernels in backend/metal/kernels/metal/kernel/matmul/, each specialized for a different weight format, matrix size, or fused operation. This chapter covers every major variant: FP16, Q4_0, Q8_0, Q4_K, Wide GEMV, MLX Q4, Batched GEMV, and the fused SiLU variants. For each, we will look at the actual Metal Shading Language (MSL) code, understand the thread mapping, and explain why each design decision was made.
Why So Many Kernels?
A reasonable question: why not write one generic GEMV kernel and parameterize it? The answer is that the dequantization logic for each format is fundamentally different, and templating it all into one kernel would prevent the Metal compiler from generating optimal code. Each format has different:
- Block sizes: Q4_0 uses 32 elements per block, Q4_K uses 256 elements, Q8_0 uses 32 elements, MLX uses group_size (typically 32 or 64).
- Memory access patterns: Q4_0 reads uint16 pairs, Q8_0 reads int8 arrays, Q4_K reads interleaved scale metadata, MLX reads from three separate arrays.
- Arithmetic: Q4_0 uses the nibble pre-scaling trick, Q8_0 uses simple multiply, Q4_K uses multi-level scale+min reconstruction, MLX uses affine dequant (scale * quant + bias).
- Optimal thread counts: Q4_0 works best with 128 or 256 threads, Q8_0 with 256, Q4_K with 256.
The compiler can specialize loop bounds, eliminate dead code paths, and optimize register allocation when each kernel has a single code path. A unified kernel with runtime dispatch would pay a branch penalty and register pressure cost on every iteration of the inner loop.
The downside is code duplication, but Akunu mitigates this through C++ templates (for the SiLU variants) and consistent naming conventions that make the kernel family navigable.
The Universal GEMV Pattern
Despite their diversity, all Akunu GEMV kernels share a common structure:
- Thread mapping: Map thread IDs to output rows and K-dimension positions.
- Accumulation loop: Each thread computes partial dot products for its assigned rows.
- SIMD reduction: simd_sum() across the 32-lane SIMD group to get the full dot product.
- Output write: lane 0 of each SIMD group writes the final result.
Thread Mapping (NR=4, NSIMD=4)
┌─────────────────────────────────────────────────┐
│ Threadgroup (128 threads = 4 SIMD groups) │
│ │
│ SG0: rows 0-3 SG1: rows 4-7 │
│ SG2: rows 8-11 SG3: rows 12-15 │
│ │
│ Within each SG (32 lanes): │
│ Each lane reads VPT=16 elements per K-block │
│ Lane stride = 32 lanes × 16 = 512 elements │
└─────────────────────────────────────────────────┘
How a GEMV Kernel Executes on Apple Silicon
[Interactive animation in the web edition.] The animation walks through akunu’s FP16 GEMV kernel step by step, showing how GPU hardware maps threads to data. It uses a simplified model (8 rows, 4 K-blocks) to keep things readable – real kernels use 16 rows and 8 K-blocks of 512 elements each, but the principle is identical.
FP16 GEMV (gemv_f16)
The FP16 GEMV is the simplest and serves as a reference design. Let’s examine it in detail.
Constants and Configuration
constant constexpr uint GEMV_NR = 4; // rows per SIMD group
constant constexpr uint GEMV_NSIMD = 4; // SIMD groups per TG
constant constexpr uint GEMV_VPT = 16; // values per thread per K-block
constant constexpr uint GEMV_BLOCK_K = GEMV_VPT * SIMD_WIDTH; // 512
| Parameter | Value | Meaning |
|---|---|---|
| NR | 4 | Each SIMD group handles 4 output rows simultaneously |
| NSIMD | 4 | 4 SIMD groups per threadgroup = 128 threads |
| VPT | 16 | Each thread processes 16 K-elements per block |
| BLOCK_K | 512 | K-elements processed per outer loop iteration |
| Rows/TG | 16 | NR * NSIMD = 16 output rows per threadgroup |
Thread-to-Row Mapping
uint rows_per_tg = GEMV_NR * GEMV_NSIMD; // 16
uint base_row = tgid * rows_per_tg + sgid * GEMV_NR;
Threadgroup tgid handles rows [tgid*16, tgid*16+15]. Within the threadgroup, SIMD group sgid handles 4 consecutive rows starting at base_row.
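A few lines of host-side C++ (a sketch, not Akunu's actual host code) make the mapping concrete:

```cpp
#include <cstdint>

// Sketch of the thread-to-row mapping: threadgroup tgid and SIMD group sgid
// together select which 4 consecutive output rows a SIMD group owns.
constexpr uint32_t GEMV_NR     = 4;                     // rows per SIMD group
constexpr uint32_t GEMV_NSIMD  = 4;                     // SIMD groups per TG
constexpr uint32_t ROWS_PER_TG = GEMV_NR * GEMV_NSIMD;  // 16

uint32_t base_row(uint32_t tgid, uint32_t sgid) {
    return tgid * ROWS_PER_TG + sgid * GEMV_NR;
}
```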
The Accumulation Loop
float sums[GEMV_NR] = {};
const uint k_aligned = (K / GEMV_BLOCK_K) * GEMV_BLOCK_K;
for (uint k_block = 0; k_block < k_aligned; k_block += GEMV_BLOCK_K) {
uint k_off = k_block + slid * GEMV_VPT;
// Load activation block (contiguous, cache-friendly)
device const half4 *x4 = (device const half4 *)(x + k_off);
float4 xf0 = float4(x4[0]), xf1 = float4(x4[1]),
xf2 = float4(x4[2]), xf3 = float4(x4[3]);
// Process NR rows
for (uint r = 0; r < GEMV_NR; r++) {
uint row = base_row + r;
if (row >= N) break;
device const half4 *w4 = (device const half4 *)(W + row * K + k_off);
float4 wf0 = float4(w4[0]), wf1 = float4(w4[1]),
wf2 = float4(w4[2]), wf3 = float4(w4[3]);
sums[r] += dot(xf0, wf0) + dot(xf1, wf1)
+ dot(xf2, wf2) + dot(xf3, wf3);
}
}
Key observations:
- Vector loads: half4 loads read 8 bytes at once; four of them cover the 16 half values of VPT.
- Activation reuse: The activation vector x is loaded once and reused across all NR=4 rows. This is the fundamental optimization of multi-row GEMV.
- dot() intrinsic: Metal's dot(float4, float4) computes a 4-element dot product in a single instruction on Apple Silicon.
- Contiguous access: Both x and W are accessed contiguously within each K-block, maximizing cache line utilization.
SIMD Reduction and Output
for (uint r = 0; r < GEMV_NR; r++) {
float total = simd_sum(sums[r]);
if (slid == 0) {
uint row = base_row + r;
if (row < N) y[row] = half(total * params.alpha);
}
}
simd_sum() sums a value across all 32 lanes of the SIMD group in O(log2(32)) = 5 steps using butterfly shuffles.1 Only lane 0 writes the result, since all other lanes hold the same value after the reduction.
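To make the reduction concrete, here is a CPU model of the butterfly pattern; on the GPU the exchanges are shuffle instructions between lanes, not array copies:

```cpp
#include <array>

// CPU model of simd_sum(): a butterfly reduction over 32 lanes. After
// log2(32) = 5 exchange steps, every lane holds the full sum.
float simd_sum_ref(std::array<float, 32> lanes) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        std::array<float, 32> prev = lanes;           // all lanes exchange simultaneously
        for (int i = 0; i < 32; ++i)
            lanes[i] = prev[i] + prev[i ^ offset];    // lane i pairs with lane i^offset
    }
    return lanes[0];  // every lane now holds the same total
}
```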
Remainder Handling
for (uint k = k_aligned + slid; k < K; k += 32) {
float xk = float(x[k]);
for (uint r = 0; r < GEMV_NR; r++) {
uint row = base_row + r;
if (row >= N) break;
sums[r] += xk * float(W[row * K + k]);
}
}
When K is not a multiple of 512, the remainder elements are handled with scalar accesses at stride 32 (one element per lane). This is less efficient but handles edge cases correctly.
Q4_0 GEMV (gemv_q4_0)
Quantized GEMV is where things get interesting. Q4_0 packs two 4-bit values per byte with a shared FP16 scale factor per block of 32 elements. Akunu provides two variants tuned for different K dimensions.
Q4_0 Data Layout
block_q4_0 (18 bytes per 32 elements):
┌──────────┬──────────────────────────────┐
│ d (FP16) │ qs[16] (16 bytes = 32 nibbles) │
│ 2 bytes │ low nibble = elem 0..15 │
│ │ high nibble = elem 16..31 │
└──────────┴──────────────────────────────┘
The Nibble Extraction Trick
The most performance-critical code is the dot product function:
inline float block_q4_0_dot_y(device const block_q4_0 *qb, float sumy,
thread float *yl, int il) {
float d = qb->d;
float acc[4] = {0.f, 0.f, 0.f, 0.f};
device const uint16_t *qs = ((device const uint16_t *)qb + 1 + il/2);
for (int i = 0; i < 8; i += 2) {
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
}
return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]);
}
This is dense code. Let's decode what it does:
- uint16 reads: Instead of reading individual bytes, it reads uint16_t values. Each uint16 contains 4 nibbles.
- Mask extraction: The four masks (0x000F, 0x0F00, 0x00F0, 0xF000) extract different nibbles from the uint16:
  - 0x000F: bits 0-3 (low nibble of low byte)
  - 0x00F0: bits 4-7 (high nibble of low byte)
  - 0x0F00: bits 8-11 (low nibble of high byte)
  - 0xF000: bits 12-15 (high nibble of high byte)
- Pre-scaled activations: The activations yl[] are pre-divided by powers of two to cancel the positional shift of each nibble:
  yl[i + 0] = y0;           // × 1       (matches 0x000F)
  yl[i + 1] = y1 / 256.f;   // × 1/256   (matches 0x0F00, shifted by 8 bits)
  yl[i + 8] = y16 / 16.f;   // × 1/16    (matches 0x00F0, shifted by 4 bits)
  yl[i + 9] = y17 / 4096.f; // × 1/4096  (matches 0xF000, shifted by 12 bits)
- Zero-point correction: The sumy * -8.f term subtracts 8 from each quantized value (Q4_0 stores unsigned 0-15 with a zero point of 8), scaled by the sum of the activation values.
This approach avoids any explicit bit-shifting operations, relying instead on pre-scaled multiplications that compile to efficient multiply-add sequences on Apple Silicon.
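The mask-plus-prescale equivalence can be verified in plain C++; in the kernel the division is folded into the pre-scaled activations rather than applied per weight:

```cpp
#include <cstdint>

// Masking a uint16 and multiplying by a pre-scale factor recovers the same
// nibble value as an explicit shift-and-mask, with no shift instructions.
float nib0(uint16_t w) { return float(w & 0x000F); }            // bits 0-3
float nib1(uint16_t w) { return float(w & 0x00F0) / 16.f;   }   // bits 4-7
float nib2(uint16_t w) { return float(w & 0x0F00) / 256.f;  }   // bits 8-11
float nib3(uint16_t w) { return float(w & 0xF000) / 4096.f; }   // bits 12-15
```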
Two Size Variants
| Variant | SIMD Groups | Threads | Rows/TG | Optimal K |
|---|---|---|---|---|
| gemv_q4_0 | 4 (NSG=4) | 128 | 16 | K < 2048 |
| gemv_q4_0_l | 8 (NSG=8) | 256 | 32 | K >= 2048 |
The large variant doubles the threadgroup size, which provides more parallelism across the K dimension. For large K, this better saturates the GPU’s memory bandwidth. For small K, the extra threads would be wasted because there are not enough K-blocks to keep them busy.
The threshold is hardware-dependent and configured in ChipConfig:
if (family >= 9 && cores >= 16)
c.q4_small_k_threshold = 512; // M4 Pro+
else if (cores >= 16)
c.q4_small_k_threshold = 1024; // M3 Pro+
else
c.q4_small_k_threshold = 2048; // Base chips
Q8_0 GEMV (gemv_q8_0)
Q8_0 is simpler than Q4_0 because each element is a full byte – no nibble extraction needed:
constant constexpr uint NR = 4;
constant constexpr uint NSIMD = 8; // 256 threads, 32 rows/TG
for (uint kb = tiisg; kb < nblocks; kb += SIMD_WIDTH) {
const uint base_k = kb * QK8_0;
device const half4 *x4 = (device const half4 *)(x + base_k);
float4 xf[8];
for (uint i = 0; i < 8; i++)
xf[i] = float4(x4[i]);
for (uint r = 0; r < NR; r++) {
uint row = base_row + r;
if (row >= N) break;
device const block_q8_0 &blk = W[row * nblocks + kb];
float d = float(blk.d);
float dot = 0;
for (uint g = 0; g < 8; g++) {
dot += xf[g][0] * float(blk.qs[g*4+0])
+ xf[g][1] * float(blk.qs[g*4+1])
+ xf[g][2] * float(blk.qs[g*4+2])
+ xf[g][3] * float(blk.qs[g*4+3]);
}
sums[r] += d * dot;
}
}
Each block_q8_0 contains 32 int8 values with a shared FP16 scale. The 8-way unrolled inner loop processes all 32 elements (4 per group, 8 groups). The int8-to-float conversion is straightforward – no masking or shifting needed.
Q8_0 is 2x the size of Q4_0 but significantly faster because the dequantization is trivial. It is used for models where quality is prioritized over memory footprint.
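Conceptually, each block dot reduces to the following scalar reference (the real format stores d as FP16; it is widened to float here, and the kernel vectorizes the loop 8-wide):

```cpp
#include <cstdint>

// Reference for one Q8_0 block dot: 32 signed bytes sharing one scale d.
struct BlockQ8 {
    float  d;        // per-block scale (FP16 in the real format)
    int8_t qs[32];   // 32 quantized elements
};

float q8_block_dot(const BlockQ8& b, const float* x) {
    float dot = 0.0f;
    for (int i = 0; i < 32; ++i)
        dot += float(b.qs[i]) * x[i];  // trivial int8 -> float dequant
    return b.d * dot;                  // scale applied once per block
}
```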
Q4_K GEMV (gemv_q4_k)
Q4_K is a “super-block” format from GGML that uses a more sophisticated quantization scheme with per-sub-block scales and mins. Each super-block contains 256 elements (vs Q4_0’s 32), with 4 bytes of FP16 super-block scale and min, 12 bytes of packed sub-block scale/min metadata, and 128 bytes of quantized data (144 bytes in total).
The kernel is ported directly from llama.cpp’s kernel_mul_mv_q4_K_f32_impl:
constant constexpr short GEMV_Q4K_NCOLS = 2; // nr0: rows per SIMD group
constant constexpr short GEMV_Q4K_NSG = 8; // SIMD groups per threadgroup
| Parameter | Value | Meaning |
|---|---|---|
| NCOLS | 2 | 2 rows per SIMD group (vs 4 for Q4_0) |
| NSG | 8 | 8 SIMD groups = 256 threads |
| Rows/TG | 16 | NCOLS * NSG = 16 |
The reduced NCOLS (2 vs 4) reflects the larger per-block metadata overhead: more registers are needed for scale/min decoding, leaving less room for row parallelism.
Thread Mapping
The Q4_K thread mapping is more complex than Q4_0:
const short ix = tiisg / 8; // 0..3 (selects block group; 8 threads per group)
const short it = tiisg % 8; // 0..7
const short iq = it / 4; // 0 or 1 (sub-block selector)
const short ir = it % 4; // 0..3 (element within sub-block)
The 32 lanes split into 4 groups of 8 (ix selects which blocks a group processes); within each group, the remaining index it selects sub-blocks (iq) and elements (ir). This mapping ensures that each thread reads a contiguous portion of the quantized data.
Scale Decoding
The scales are packed in a 12-byte structure using 6-bit values with 2-bit offsets:
constexpr uint16_t kmask1 = 0x3f3f;
constexpr uint16_t kmask2 = 0x0f0f;
constexpr uint16_t kmask3 = 0xc0c0;
sc16[0] = sc[0] & kmask1;
sc16[1] = sc[2] & kmask1;
sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
This decodes the interleaved 6-bit scale values from the packed representation. The kmask1 extracts the low 6 bits of each byte, while kmask2 and kmask3 extract and combine the high bits to form the 4-bit sub-block offsets.
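The decode can be exercised host-side on sample values (a sketch; the sc[] array stands in for the packed scale bytes viewed as uint16 words):

```cpp
#include <cstdint>

// Replicates the Q4_K scale-decode lines shown above on host-side values.
void decode_scales(const uint16_t sc[5], uint16_t sc16[4]) {
    constexpr uint16_t kmask1 = 0x3f3f, kmask2 = 0x0f0f, kmask3 = 0xc0c0;
    sc16[0] = sc[0] & kmask1;                                   // low 6 bits, bytes 0-1
    sc16[1] = sc[2] & kmask1;                                   // low 6 bits, bytes 4-5
    sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); // recombine high bits
    sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
}
```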
Q4_K Accumulation and Output
The final dot product combines the dequantized values with the per-sub-block scales and the block-level scale/minimum:
sumf[row] += dh[0] * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
(acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
(acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
(acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] +
sumy[2] * sc8[6] + sumy[3] * sc8[7]);
Here dh[0] is the super-block scale and dh[1] the super-block minimum. The sc8[] values are the sub-block scales decoded earlier. The 1.f/256.f and 1.f/16.f factors compensate for the nibble positional encoding (the same technique as Q4_0). The minimum term (dh[1] * sumy * sc8) subtracts the zero-point contribution.
This complex expression evaluates to a simple conceptual operation: sum(dequant(weight[i]) * activation[i]) for all elements in the super-block, but expressed in terms of the packed representation to avoid explicit dequantization into a temporary buffer.
Q4_K vs Q4_0: When to Use Which
| Aspect | Q4_0 | Q4_K |
|---|---|---|
| Block size | 32 elements | 256 elements |
| Metadata overhead | 2 bytes per 32 elements (6.25%) | 16 bytes per 256 elements (6.25%) |
| Scale granularity | Per 32 elements | Per 32 elements (sub-block) + per 256 (super-block) |
| Has minimum/bias | No (zero point = 8) | Yes (per sub-block min) |
| Quantization quality | Good | Better (lower quantization error) |
| Dequant complexity | Low | Medium |
| Memory per element | 0.5625 bytes | 0.5625 bytes |
Q4_K matches Q4_0’s memory footprint while delivering better quality (lower quantization error), at the cost of more complex dequantization. In practice, Q4_K is preferred for larger models where the quality difference matters, while Q4_0 is used for speed-critical paths or when model quality is less sensitive to quantization.
Wide GEMV (gemv_wide_f16)
The “wide” variant uses 8 SIMD groups (256 threads) instead of 4, processing 32 rows per threadgroup:
constant constexpr uint GEMV_WIDE_NR = 4;
constant constexpr uint GEMV_WIDE_NSIMD = 8; // 8 SIMD groups
constant constexpr uint GEMV_WIDE_VPT = 16;
constant constexpr uint GEMV_WIDE_BLOCK_K = GEMV_WIDE_VPT * SIMD_WIDTH; // 512
The algorithm is identical to gemv_f16; only the NSIMD parameter changes. The wider threadgroup provides better occupancy on chips with many GPU cores (Pro, Max, Ultra), where the smaller threadgroup would leave execution units idle.
Wide variants exist for multiple formats: gemv_wide_f16, gemv_wide_q4_0, gemv_wide_q4_k, gemv_wide_q8_0, gemv_wide_mlx_q4, and gemv_wide_mlx_q8.
MLX Q4 GEMV (gemv_mlx_q4)
MLX-format quantization differs from GGML’s block format. Instead of per-block scale factors interleaved with data, MLX stores weights, scales, and biases in three separate contiguous arrays:
Packed buffer layout:
┌────────────────────┬──────────────┬──────────────┐
│ U32 weights │ FP16 scales │ FP16 biases │
│ [N × K/8] uint32 │ [N × n_groups]│ [N × n_groups]│
└────────────────────┴──────────────┴──────────────┘
The dequantization formula is: value = scale * quant_value + bias
This is an affine quantization scheme (scale + bias) vs GGML’s scaled scheme (scale + zero-point). The per-group bias allows better representation of asymmetric distributions.
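A scalar sketch of the affine dequantization; the low-nibble-first packing order within each uint32 word is an assumption for illustration:

```cpp
#include <cstdint>

// MLX affine dequantization: value = scale * q + bias.
// Eight 4-bit elements are packed per uint32 word, low nibble first.
float mlx_q4_dequant(uint32_t word, int idx, float scale, float bias) {
    uint32_t q = (word >> (idx * 4)) & 0xFu;  // extract the idx-th nibble
    return scale * float(q) + bias;
}
```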
Function Constant Specialization
constant uint FC_GROUP_SIZE [[function_constant(0)]];
constant uint FC_K_DIM [[function_constant(1)]];
constant bool FC_SPECIALIZED = is_function_constant_defined(FC_GROUP_SIZE);
When the host knows the group size and K dimension at pipeline creation time, it specializes the kernel with these as compile-time constants. This enables the Metal compiler to:
- Replace division/modulo by group_size with shifts/masks (if power of 2)
- Unroll loops with known trip counts
- Eliminate branches on K alignment
Pre-Scaled Activation Trick
constant constexpr float PS0 = 1.0f;
constant constexpr float PS1 = 1.0f / 16.0f;
constant constexpr float PS2 = 1.0f / 256.0f;
constant constexpr float PS3 = 1.0f / 4096.0f;
const float4 ps4 = float4(PS0, PS1, PS2, PS3);
float4 xs0 = xf0 * ps4, xs1 = xf1 * ps4, xs2 = xf2 * ps4, xs3 = xf3 * ps4;
This is the same pre-scaling trick used in Q4_0, adapted for MLX’s nibble layout. Each uint16 contains 4 nibbles at bit positions 0, 4, 8, and 12. The pre-scale factors compensate for the positional shift, allowing the kernel to multiply the pre-scaled activation directly with the masked uint16 value.
The Accumulation
float dot = xs0[0] * float(w0 & 0x000fu) + xs0[1] * float(w0 & 0x00f0u)
+ xs0[2] * float(w0 & 0x0f00u) + xs0[3] * float(w0 & 0xf000u)
+ xs1[0] * float(w1 & 0x000fu) + xs1[1] * float(w1 & 0x00f0u)
+ xs1[2] * float(w1 & 0x0f00u) + xs1[3] * float(w1 & 0xf000u)
+ ...;
sums[r] += s * dot + b * x_sum;
The s * dot term handles the scale, and b * x_sum handles the bias. The bias contributes b * sum(x_elements) because the affine formula adds b to every dequantized weight, and the b * x[i] terms summed over the group factor into b * sum(activations).
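The factorization behind this identity is easy to check numerically (a sketch; the function name is hypothetical):

```cpp
#include <cmath>

// The bias factors out of the dot product:
//   sum_i (s*q_i + b) * x_i  ==  s * sum_i(q_i * x_i) + b * sum_i(x_i)
// which is why the kernel only needs `s * dot + b * x_sum`.
bool bias_factors_out(const float* q, const float* x, int n, float s, float b) {
    float lhs = 0, dot = 0, xsum = 0;
    for (int i = 0; i < n; ++i) {
        lhs  += (s * q[i] + b) * x[i];  // fully dequantized dot product
        dot  += q[i] * x[i];            // quantized dot
        xsum += x[i];                   // activation sum
    }
    return std::fabs(lhs - (s * dot + b * xsum)) < 1e-4f;
}
```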
Template-Based Variants
The MLX GEMV uses a template to generate both standard and wide variants:
template<int NSIMD_T, int NR_T = 4>
void gemv_mlx_q4_impl(...) { ... }
// Standard: 4 SIMD groups, 128 threads, 16 rows/TG
kernel void gemv_mlx_q4(...) {
gemv_mlx_q4_impl<4, 4>(x, packed, y, params, tgpig, tiisg, sgitg);
}
// Wide: 8 SIMD groups, 256 threads, 32 rows/TG
kernel void gemv_mlx_q4_l(...) {
gemv_mlx_q4_impl<8, 4>(x, packed, y, params, tgpig, tiisg, sgitg);
}
Batched GEMV (gemv_batched_f16)
Batched GEMV handles the case where M > 1 but is still small (2-16 rows). This occurs during:
- Speculative decoding verification (batch of draft + 1 tokens)
- Very short prefill sequences (2-8 tokens)
constant uint FC_BATCH_M [[function_constant(11)]];
constant constexpr uint MAX_BATCH_M = 16;
constant constexpr uint BGEMV_NCOLS = 4;
constant constexpr uint BGEMV_NSIMD = 8;
float sums[BGEMV_NCOLS][MAX_BATCH_M] = {};
for (uint k = slid; k < K; k += 32) {
for (uint bm = 0; bm < batchM; bm++) {
float xk = float(x[bm * lda + k]);
for (uint c = 0; c < BGEMV_NCOLS; c++) {
uint col = base_col + c;
if (col < N) sums[c][bm] += xk * float(W[col * K + k]);
}
}
}
The key difference from standard GEMV: the inner loop iterates over both batch rows (bm) and output columns (c). The activation matrix has batchM rows, each producing an independent output row.
FC_BATCH_M is a function constant that lets the host specialize the kernel for a known batch size, enabling the compiler to unroll the batch loop.
Fused SiLU GEMV (gemv_q4_0_silu)
The fused SiLU variant is a significant optimization for the FFN down-projection. Instead of four separate dispatches (gate GEMV, up GEMV, SiLU element-wise multiply, down GEMV), it computes SiLU(gate) * up inline during the down projection GEMV:
for (short i = 0; i < 8; i += 2) {
const float g0 = float(gb[i + 0]);
const float g1 = float(gb[i + 1]);
const float g16 = float(gb[i + 16]);
const float g17 = float(gb[i + 17]);
const float a0 = (g0 / (1.f + exp(-g0))) * float(ub[i + 0]);
const float a1 = (g1 / (1.f + exp(-g1))) * float(ub[i + 1]);
const float a16 = (g16 / (1.f + exp(-g16))) * float(ub[i + 16]);
const float a17 = (g17 / (1.f + exp(-g17))) * float(ub[i + 17]);
sumy += a0 + a1 + a16 + a17;
yl[i + 0] = a0;
yl[i + 1] = a1 / 256.f;
yl[i + 8] = a16 / 16.f;
yl[i + 9] = a17 / 4096.f;
}
SiLU, x / (1 + exp(-x)), is evaluated per gate element and multiplied by the corresponding up-projection value. The result feeds into the same Q4_0 dot product machinery used by the standard kernel.
Performance impact: This fusion eliminates:
- The gate GEMV dispatch (saved entirely – gate values are read from the existing buffer)
- The up GEMV dispatch (saved entirely – up values are read from the existing buffer)
- The SiLU element-wise kernel dispatch
- Two intermediate buffer writes and one buffer read (gate output, up output, SiLU output)
For a 7B model, this saves 3 dispatches and ~6 buffer round-trips per layer, or ~96 dispatches for 32 layers. The tradeoff is more ALU work per thread (the exp() and division), but since GEMV is memory-bound, the extra ALU is free.2
Fused SiLU variants exist for Q4_0, Q8_0, Q4_K, MLX Q4, and other formats. Each is a template instantiation with the corresponding dequantization logic.
MLX Embedding Lookup (embedding_lookup_mlx_q4)
While not strictly a GEMV, the MLX embedding lookup kernel lives alongside the GEMV kernels and uses similar dequantization:
const uint32_t word = W_u32[token_id * n_u32_per_row + u32_idx];
const uint qval = (word >> (within * bits)) & mask;
output[token_idx * K + d_idx] = half(s * float(qval) + b);
It supports arbitrary bit widths (not just 4) via the bits parameter, extracting quantized values from packed 32-bit words. The dispatch is 2D: (dim, seq_len), with one thread per output element.
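The bit extraction generalizes to any width that divides 32 (a sketch; extract_bits is a hypothetical name):

```cpp
#include <cstdint>

// Generic n-bit extraction from a packed 32-bit word, as in the embedding
// lookup: `within` selects which element inside the word to read.
uint32_t extract_bits(uint32_t word, uint32_t within, uint32_t bits) {
    uint32_t mask = (bits < 32) ? ((1u << bits) - 1u) : 0xFFFFFFFFu;
    return (word >> (within * bits)) & mask;
}
```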
Kernel Selection Strategy
The host selects GEMV kernels based on the weight format and matrix dimensions. Here is the decision tree:
| Weight Format | K < threshold | K >= threshold | N > wide_threshold |
|---|---|---|---|
| FP16 | gemv_f16 (4 SG) | gemv_f16 (4 SG) | gemv_wide_f16 (8 SG) |
| Q4_0 | gemv_q4_0 (4 SG, 128t) | gemv_q4_0_l (8 SG, 256t) | gemv_wide_q4_0 |
| Q8_0 | gemv_q8_0 (8 SG, 256t) | same | gemv_wide_q8_0 |
| Q4_K | gemv_q4_k (8 SG, 256t) | same | gemv_wide_q4_k |
| MLX Q4 | gemv_mlx_q4 (4 SG) | gemv_mlx_q4_l (8 SG) | gemv_wide_mlx_q4 |
The wide_gemv_threshold is typically 32768 (output dimension), which is only exceeded by the logit projection in models with very large vocabularies.
On Pro+ chips (cores >= 16), the “wide standard” flag is set, which promotes even the standard GEMV to use 8 SIMD groups. This is because Pro chips have enough GPU cores to benefit from the additional parallelism even at moderate N dimensions.
Memory Access Patterns
Understanding the memory access pattern is crucial for GEMV performance, since it is memory-bound:
For gemv_f16 with K=4096, NR=4:
Thread 0: reads x[0..15], W[row][0..15] for 4 rows
Thread 1: reads x[16..31], W[row][16..31] for 4 rows
...
Thread 31: reads x[496..511], W[row][496..511] for 4 rows
Next K-block: Thread 0 reads x[512..527], etc.
Within a K-block, all threads in a SIMD group access contiguous memory for both x and W, so the GPU can coalesce the loads into full cache-line transactions. The activation vector x is loaded once per K-block and reused NR=4 times, giving a theoretical 4x activation reuse factor.
For quantized formats, the memory access is even more favorable: Q4_0 reads 18 bytes per 32 elements (one block), compared to 64 bytes for FP16. This ~3.6x memory reduction translates directly to higher effective bandwidth for the weight reads, which dominate the memory traffic.
Q5_0, Q5_K, Q6_K, Q2_K, Q3_K GEMV Variants
Beyond the core Q4_0, Q8_0, and Q4_K kernels, Akunu supports several additional GGML quantization formats:
Q5_0 (5-bit per element, group size 32): Each block contains 4 bytes of high-bit storage in addition to the Q4_0-style nibble data. The extra bit per element improves precision slightly. The kernel extracts the 5th bit from a separate bitfield and OR’s it into the 4-bit value during dequantization.
Q5_K (5-bit K-quant, super-block 256): Similar to Q4_K but with an extra bit per element. Uses the same super-block structure with per-sub-block scales and mins, plus an additional bit plane.
Q6_K (6-bit K-quant, super-block 256): Provides significantly better quality than Q4_K at 1.5x the memory cost. The dequantization extracts 6-bit values from packed bytes, using two distinct bit planes. Popular for models where quality is critical but FP16 is too expensive.
Q2_K (2-bit K-quant, super-block 256): The most aggressive quantization, storing only 2 bits per element plus 4-bit scales and 4-bit mins. Quality is noticeably degraded but memory usage is halved vs Q4_0. Used for very large models that would not otherwise fit in memory.
Q3_K (3-bit K-quant, super-block 256): A middle ground between Q2_K and Q4_K. Uses 3 bits per element with similar super-block metadata to Q4_K. Provides better quality than Q2_K at 1.5x the memory.
All of these kernels follow the same multi-row GEMV pattern described for Q4_0 and Q4_K, differing only in:
- The number of bits extracted per element (2, 3, 5, or 6)
- The block/super-block metadata layout
- The dequantization arithmetic
Each has standard, wide, and batched variants, plus fused SiLU variants where applicable. The code is generated from templates where possible, but the dequantization functions are hand-written for each format to ensure optimal register usage and instruction scheduling.
BF16 GEMV (gemv_bf16)
Akunu also supports BF16 (Brain Float 16) weights, which use 8 exponent bits and 7 mantissa bits (vs FP16’s 5 exponent and 10 mantissa). The BF16 GEMV kernel is structurally identical to the FP16 kernel but includes a BF16-to-FP32 conversion step during loading.
BF16 weights are used by some models (particularly those trained in BF16) and offer better dynamic range than FP16 at the cost of reduced precision. On Apple Silicon, there is no native BF16 support in the ALU (unlike NVIDIA’s Ampere+), so BF16 values are converted to FP32 for computation. The conversion is simple: BF16 is the upper 16 bits of an IEEE 754 float32, so the conversion is just a left-shift by 16 bits.
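A portable C++ sketch of the same widening:

```cpp
#include <cstdint>
#include <cstring>

// BF16 occupies the top 16 bits of an IEEE 754 float32, so widening is a
// 16-bit left shift followed by a bit-for-bit reinterpretation.
float bf16_to_f32(uint16_t b) {
    uint32_t bits = uint32_t(b) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof f);  // type-pun safely
    return f;
}
```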
Arithmetic Intensity Analysis
Understanding why GEMV is memory-bound requires analyzing its arithmetic intensity – the ratio of compute operations to memory bytes transferred.
For an FP16 GEMV computing y[N] = x[K] @ W[N,K]:
| Metric | Value | Notes |
|---|---|---|
| Operations | 2 * N * K | multiply + add per element |
| Weight bytes | N * K * 2 | FP16, 2 bytes per element |
| Activation bytes | K * 2 | Read once |
| Output bytes | N * 2 | Write once |
| Total bytes | 2NK + 2K + 2N ≈ 2NK | Dominated by weight reads |
| Arithmetic intensity | 2NK FLOPs / 2NK bytes = 1.0 FLOP/byte | With multi-row: NR × 1.0 |
An arithmetic intensity of 1.0 FLOP/byte means that for every byte of memory transferred, we perform only 1 floating-point operation. Apple Silicon GPUs have ~200-400 GB/s memory bandwidth but 7-14 TFLOPS of compute. At 1.0 FLOP/byte, we would need 7-14 TB/s of bandwidth to saturate the compute – about 35x more than available.3
Multi-row processing (NR=4) improves this to 4 FLOP/byte by reusing the activation vector, but this is still well within the memory-bound regime. The roofline model confirms that GEMV performance is limited by memory bandwidth, not compute capability.
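The roofline argument fits in one function (a sketch; the test uses M4 Pro-class numbers for illustration):

```cpp
#include <algorithm>

// Roofline model: attainable FLOP/s is the lesser of peak compute and
// memory bandwidth times arithmetic intensity (FLOP per byte).
double attainable_flops(double peak_flops, double bytes_per_sec, double ai) {
    return std::min(peak_flops, bytes_per_sec * ai);
}
```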
For quantized formats, the effective arithmetic intensity is higher because less data is transferred:
| Format | Bytes per element | Effective AI (NR=4) |
|---|---|---|
| FP16 | 2.0 | 4.0 FLOP/byte |
| Q8_0 | ~1.06 | ~7.5 FLOP/byte |
| Q4_0 | ~0.56 | ~14.3 FLOP/byte |
| Q4_K | ~0.56 | ~14.3 FLOP/byte |
Q4_0 achieves 14.3 FLOP/byte – still memory-bound on Apple Silicon, but much closer to the roofline ridge point. This is why quantization provides such a large speedup for decode: the weight read traffic is cut by ~3.5x compared to FP16.
The Bandwidth Utilization Picture
On an M4 Pro with ~200 GB/s memory bandwidth and a 7B Q4_0 model (dim=4096, ffn_dim=14336), the theoretical decode speed is:
Weight bytes per token:
QKV: (4096 + 1024 + 1024) * 4096 * 0.56 = ~14.0 MB
O: 4096 * 4096 * 0.56 = ~9.4 MB
Gate: 14336 * 4096 * 0.56 = ~32.9 MB
Up: 14336 * 4096 * 0.56 = ~32.9 MB
Down: 4096 * 14336 * 0.56 = ~32.9 MB
─────────────────────────────────
Per layer: ~122 MB
32 layers: ~3.9 GB
+ Logit projection: 128256 * 4096 * 0.56 = ~294 MB
Total: ~4.2 GB
Theoretical tok/sec at 200 GB/s: 200 / 4.2 ≈ 47.6 tok/sec
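The same estimate as a self-contained function (hypothetical helper; Q4_0 taken at exactly 18 bytes per 32-element block, so the result differs slightly from the rounded figures above):

```cpp
// Back-of-envelope decode bound for an 8B-class Q4_0 model
// (dim=4096, ffn=14336, 32 layers, vocab=128256).
double decode_tok_per_sec(double bw_gb_per_sec) {
    const double bpe = 18.0 / 32.0;  // Q4_0 bytes per element (0.5625)
    const double dim = 4096, ffn = 14336, vocab = 128256;
    double qkv   = (dim + 1024 + 1024) * dim * bpe;
    double oproj = dim * dim * bpe;
    double gate  = ffn * dim * bpe, up = gate, down = gate;
    double per_layer = qkv + oproj + gate + up + down;
    double total = 32 * per_layer + vocab * dim * bpe;  // + logit projection
    return bw_gb_per_sec * 1e9 / total;                 // bandwidth-bound bound
}
```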
In practice, Akunu achieves 80-90% of this theoretical maximum. The gap comes from:
- Non-weight memory traffic (activations, KV cache reads, norm computation)
- Kernel launch overhead (minimized by chain decode)
- Attention computation (memory-bound but separate from GEMV)
- Imperfect memory coalescing at tile boundaries
Register Pressure and Occupancy
Each GEMV kernel maintains NR float accumulators per SIMD group, plus registers for the loaded activation and weight values. For the FP16 kernel:
| Register Use | Count | Bytes |
|---|---|---|
| sums[NR=4] | 4 floats | 16 bytes |
| xf0-xf3 (activation) | 4 float4 | 64 bytes |
| wf0-wf3 (weight, per row) | 4 float4 | 64 bytes |
| Loop indices, temporaries | ~8 | 32 bytes |
| Total per thread | ~44 values | ~176 bytes |
Apple Silicon GPUs have 96 registers per thread (32-bit each) = 384 bytes, so the GEMV kernels use less than half the available register file. This ensures high occupancy: the GPU can keep many SIMD groups resident per core, hiding memory latency with concurrent threads.
For the Q4_K kernel, register pressure is higher (16-element float arrays for yl and yh, plus the scale decoding registers), but still within bounds. The 2-row-per-SIMD-group design (NCOLS=2) was specifically chosen to fit in registers.
GPU Dispatch Geometry Examples
Let’s work through the dispatch for a concrete case: the K projection of Llama 3.1 8B, where N=1024 (kv_dim) and K=4096, using Q4_0 quantization:
K = 4096 > q4_small_k_threshold (2048 on base chip)
→ Use gemv_q4_0_l (NSG=8, 256 threads)
→ Rows per TG = NR0 * NSG = 4 * 8 = 32
→ Grid size = ceil(1024 / 32) = 32 threadgroups
→ Total threads = 32 * 256 = 8192
Within each threadgroup:
8 SIMD groups, each handling 4 output rows
Each thread in a SIMD group:
- Processes K/16 = 256 Q4_0 blocks (strided by 16 blocks at a time)
- 16 values per block iteration
- 256/16 = 16 outer loop iterations
For the logit projection (N=128256, K=4096):
N = 128256 > wide_gemv_threshold (32768)
→ Use gemv_wide_q4_0 (NSG=8, 256 threads, 32 rows/TG)
→ Grid size = ceil(128256 / 32) = 4008 threadgroups
→ Total threads = 4008 * 256 ≈ 1M
This is the largest GEMV dispatch in the model, taking ~2ms on M4 Pro.
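The grid sizing is plain ceil-division (a host-side sketch, not Akunu's actual dispatch code):

```cpp
#include <cstdint>

// One threadgroup covers rows_per_tg output rows; round up to cover all N.
uint32_t n_threadgroups(uint32_t N, uint32_t rows_per_tg) {
    return (N + rows_per_tg - 1) / rows_per_tg;  // ceil(N / rows_per_tg)
}
```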
Complete Kernel Variant Matrix
Here is the full matrix of all GEMV kernel files in Akunu:
| Base Kernel | Standard | Wide | Batched | Fused SiLU |
|---|---|---|---|---|
| FP16 | gemv_f16 | gemv_wide_f16 | gemv_batched_f16 | – |
| BF16 | gemv_bf16 | – | – | – |
| Q4_0 | gemv_q4_0 (+_l) | gemv_wide_q4_0 | gemv_batched_q4_0 | gemv_q4_0_silu (+_l) |
| Q4_1 | gemv_q4_1 | gemv_wide_q4_1 | gemv_batched_q4_1 | – |
| Q5_0 | gemv_q5_0 | gemv_wide_q5_0 | gemv_batched_q5_0 | – |
| Q8_0 | gemv_q8_0 | gemv_wide_q8_0 | gemv_batched_q8_0 | gemv_q8_0_silu |
| Q4_K | gemv_q4_k | gemv_wide_q4_k | gemv_batched_q4_k | gemv_q4_k_silu |
| Q5_K | gemv_q5_k | – | gemv_batched_q5_k | – |
| Q6_K | gemv_q6_k | – | gemv_batched_q6_k | – |
| Q2_K | gemv_q2_k | – | gemv_batched_q2_k | – |
| Q3_K | gemv_q3_k | – | gemv_batched_q3_k | – |
| MLX Q4 | gemv_mlx_q4 (+_l) | gemv_wide_mlx_q4 | gemv_batched_mlx_q4 | gemv_mlx_q4_silu |
| MLX Q3 | – | – | – | gemv_mlx_q3_silu |
| MLX Q6 | – | – | – | gemv_mlx_q6_silu |
| MLX Q8 | – | gemv_wide_mlx_q8 | – | gemv_mlx_q8_silu |
| MLX Gen | gemv_mlx_gen | gemv_wide_mlx_gen | gemv_batched_mlx_gen | – |
That is over 50 kernel entry points. Each is generated from a relatively small set of patterns (the base algorithm is the same), but the dequantization logic, thread counts, and buffer layouts differ per format.
The fused SiLU variants are particularly valuable: they fold the SiLU(gate) * up computation into the down-projection GEMV, eliminating the element-wise SiLU dispatch and its intermediate buffer round-trips in every layer.
Metal-Specific Optimizations
Several Metal-specific features are exploited across the GEMV kernels:
make_uniform()
The Q4_0 and Q4_K kernels use make_uniform() on parameters:
const int N = make_uniform(int(params.N));
const int K = make_uniform(int(params.K));
make_uniform() is a Metal compiler hint that tells the GPU’s SIMD group scheduler that all lanes will have the same value. This enables the compiler to:
- Use scalar registers instead of vector registers for the value
- Hoist loop-invariant computations out of SIMD-divergent branches
- Optimize conditional execution (all lanes take the same branch)
Without make_uniform(), the compiler must conservatively assume that different lanes might have different values for params.N, generating slower divergent code.
half4 and float4 Vector Operations
Apple Silicon’s GPU ALU natively supports 4-component vector operations. The dot(float4, float4) function compiles to a single instruction that computes all 4 multiply-adds in parallel. Using half4 for loads ensures that the load-store unit reads 8 bytes in one transaction, which is the optimal load granularity for Apple’s memory subsystem.
The pattern of loading as half4, converting to float4, computing, and accumulating is optimal:
- Load: 8 bytes in one transaction (half4)
- Convert: hardware F16->F32 widening (1 cycle)
- Compute: FMA in F32 (full precision accumulation)
- This avoids F16 accumulation errors while keeping loads narrow
SIMD Group Size
On Apple Silicon, the SIMD group size is always 32 (NVIDIA warps are likewise fixed at 32, while AMD wavefronts can be 32 or 64 lanes). This simplifies kernel design:
- All simd_sum() and simd_max() calls reduce exactly 32 values
- Thread-to-row mappings can use fixed constants (SIMD_WIDTH=32)
- No need for runtime warp/wavefront size queries
Threadgroup Memory Absence
Unlike GEMM kernels, most GEMV kernels use no threadgroup memory. The entire computation is done in registers with simd_sum() for cross-lane reduction. This is possible because GEMV’s data flow is simple: each SIMD group independently computes NR output rows, and the only cross-lane communication is the final reduction.
This zero-TG-memory design has two benefits:
- No threadgroup memory allocation overhead at dispatch time
- No barrier overhead (simd_sum is barrier-free within a SIMD group)
The only exceptions are the batched GEMV (which uses shared memory for cross-SIMD-group reduction when the batch dimension exceeds one SIMD group) and some Q4_K variants that use threadgroup memory for scale decoding shared across threads.
Summary
Akunu’s GEMV kernel family is the performance backbone of decode. The key design principles are:
- Multi-row processing (NR): Reusing the activation vector across multiple output rows to improve arithmetic intensity from 1 to NR FLOP/byte.
- Vectorized loads: half4 and uint16_t reads to maximize memory bandwidth utilization via coalesced transactions.
- SIMD reduction: Apple’s simd_sum() for efficient cross-lane communication without threadgroup barriers.
- Size-specialized variants: Different threadgroup sizes (128 vs 256 threads) for different K dimensions and hardware tiers.
- Fused operations: SiLU + GEMV fusion to eliminate intermediate buffers, saving 3 dispatches and 6 buffer round-trips per layer.
- Function constant specialization: Compile-time K and group_size for better code generation and loop unrolling.
- Format-specific dequantization: Nibble pre-scaling for Q4_0, multi-level scale+min for Q4_K, affine dequant for MLX – each tuned for its format’s memory layout.
- Zero threadgroup memory: Most GEMV kernels avoid threadgroup memory entirely, relying on SIMD intrinsics for all cross-lane communication.
The GEMV kernels collectively represent Akunu’s largest investment in kernel engineering – over 50 entry points across 30+ source files – because decode throughput directly determines how fast text appears on screen. Every percentage point of bandwidth utilization improvement translates directly to faster generation.
Practical Impact: End-to-End Decode Profile
To understand how GEMV kernels fit into the full picture, here is a representative decode profile for Llama 3.1 8B Q4_0 on M4 Pro:
| Operation | Time per Token | % of Total | Kernel Type |
|---|---|---|---|
| Embedding lookup | 0.01ms | 0.1% | Embedding |
| QKV GEMV (x32 layers) | 2.4ms | 30% | gemv_q4_0 |
| RoPE + KV write (x32) | 0.3ms | 4% | rope_kv_write |
| Attention (x32) | 1.2ms | 15% | flash_attention |
| O GEMV (x32) | 0.8ms | 10% | gemv_q4_0 |
| Fused SiLU GEMV (x32) | 2.8ms | 35% | gemv_q4_0_silu |
| Output norm | 0.01ms | 0.1% | rmsnorm |
| Logit GEMV | 0.3ms | 4% | gemv_q4_0 |
| Argmax | 0.01ms | 0.1% | argmax |
| Total | ~8ms | 100% | – |
GEMV operations account for ~79% of the total decode time. The fused SiLU GEMV is the single largest contributor because the FFN’s intermediate dimension (14336) is 3.5x the model dimension (4096). Optimizing GEMV kernels has the highest return on investment of any kernel engineering work in Akunu.
The most impactful optimizations in Akunu’s GEMV history were:
- Multi-row (NR=4): Gave a 2-3x speedup over NR=1 by reusing the activation vector.
- Fused SiLU: Eliminated 3 dispatches per layer, saving ~15% total decode time.
- Hardware-tuned NSG: Matching the threadgroup size to the hardware tier improved occupancy by ~10%.
- Function constant K specialization (for MLX): Enabled 5-8% speedup from compiler loop optimizations.
- half4 vectorized loads: Improved memory coalescing by 4x compared to scalar half loads.
These optimizations compound: a kernel that is 2x faster from multi-row, 15% faster from fusion, and 10% faster from tuning delivers ~2.5x total improvement over a naive implementation.
1. Apple. “Metal Shading Language Specification, Version 3.1.” Section 6.9, SIMD-group Functions. simd_sum() performs a butterfly reduction across all active threads in a SIMD group, completing in ceil(log2(SIMD_WIDTH)) steps. See https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf. ↩
2. Williams, S., Waterman, A., and Patterson, D. “Roofline: An Insightful Visual Performance Model for Multicore Architectures.” Communications of the ACM, 2009. GEMV is firmly in the memory-bound regime (low arithmetic intensity), meaning additional ALU operations like SiLU computation do not affect wall-clock time. See https://doi.org/10.1145/1498765.1498785. ↩
3. The roofline model shows that an algorithm with arithmetic intensity below the “ridge point” (peak FLOPS / peak bandwidth) is memory-bound. For Apple M4 Pro: ridge point = 14 TFLOPS / 200 GB/s = 70 FLOP/byte. GEMV at 1-14 FLOP/byte is well below this threshold. ↩
GEMM: Batched Matrix Multiplication
During prefill, every linear projection is a full matrix-matrix multiply: C[M,N] = A[M,K] @ B^T[N,K]. Unlike GEMV (which is memory-bound), GEMM can be compute-bound when M is large enough, because each weight element is reused across M activation rows. This chapter covers Akunu’s GEMM kernels, which use Apple Silicon’s SIMD group matrix multiply-accumulate (MMA) instructions to achieve near-peak throughput.
The kernels live in backend/metal/kernels/metal/kernel/matmul/simd_gemm_*.metal. We will focus on two representative variants: the FP16 GEMM and the Q4_0 GEMM, which together illustrate the key design patterns.
Tile Geometry: The 32x64 Layout
Both GEMM kernels use the same tile geometry, inherited from llama.cpp’s kernel_mul_mm:
| Parameter | Symbol | Value | Meaning |
|---|---|---|---|
| Tile M (activation rows) | TM / NR1 | 32 | Rows of A processed per threadgroup |
| Tile N (weight rows) | TN / NR0 | 64 | Rows of B (columns of output) per threadgroup |
| Tile K (accumulation) | TK / NK | 32 | K-dimension per accumulation step |
| Threads per TG | – | 128 | 4 SIMD groups x 32 lanes |
| Dispatch grid | – | (ceil(N/64), ceil(M/32)) | One TG per output tile |
Why 32x64 and not 64x64 or 32x32? The answer lies in the SIMD group MMA instruction, which operates on 8x8 half-precision matrices. The 32x64 tile decomposes into:
Output tile [32, 64] as 8x8 sub-tiles:
┌────┬────┬────┬────┬────┬────┬────┬────┐
│8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │ row 0-7
├────┼────┼────┼────┼────┼────┼────┼────┤
│8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │ row 8-15
├────┼────┼────┼────┼────┼────┼────┼────┤
│8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │ row 16-23
├────┼────┼────┼────┼────┼────┼────┼────┤
│8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │ row 24-31
└────┴────┴────┴────┴────┴────┴────┴────┘
col col col col col col col col
0-7 8-15 16-23 24-31 32-39 40-47 48-55 56-63
4 SIMD groups split this into 4 quadrants:
SG0: rows 0-15, cols 0-31 (2×4 = 8 sub-tiles)
SG1: rows 0-15, cols 32-63 (2×4 = 8 sub-tiles)
SG2: rows 16-31, cols 0-31 (2×4 = 8 sub-tiles)
SG3: rows 16-31, cols 32-63 (2×4 = 8 sub-tiles)
Each SIMD group maintains mc[8] accumulators (8 simdgroup_half8x8 matrices), covering its 8 sub-tiles.
Interactive: GEMM Tiled Execution on the GPU
This animation shows how one threadgroup computes a 32x64 output tile. Watch 4 SIMD groups cooperatively load weight and activation tiles into threadgroup memory, then execute 8x8 MMA operations. The K-dimension sweeps left to right, and the output tile fills as accumulators grow. Step through to see the cooperative loading, the MMA compute, and the final store.
FP16 GEMM (simd_gemm_f16)
Threadgroup Memory Layout
threadgroup half *sa = shmem; // Weight tile: 4096 bytes
threadgroup half *sb = shmem + 4096 / sizeof(half); // Activation tile: 2048 bytes
// Total: 6144 bytes
The weight tile is larger (4096 bytes for 64 rows x 32 K-cols) because TN=64 > TM=32. The activation tile is 2048 bytes (32 rows x 32 K-cols).
Cooperative Loading
Each of the 128 threads loads a portion of the weight and activation tiles into threadgroup memory:
Weight loading (sa):
const short lr0 = ((short)tiitg / NL0) < nr0 ? ((short)tiitg / NL0) : nr0 - 1;
const short il0 = (tiitg % NL0);
// F16: just read 16 halves per thread
half4x4 temp_a;
for (int i = 0; i < 16; i++) {
temp_a[i/4][i%4] = x[i];
}
For FP16, the load is a simple copy from device memory to registers, then a scatter to threadgroup memory in the sub-block layout that the MMA instructions expect.
Activation loading (sb):
const short lr1 = ((short)tiitg / NL1) < nr1 ? ((short)tiitg / NL1) : nr1 - 1;
const short iy = 8 * (tiitg % NL1);
*(threadgroup half2x4 *)(sb + 64 * ib + 8 * ly) = *((device const half2x4 *)y);
The activation tile uses half2x4 (16-byte) vector stores for efficient threadgroup memory writes.
The Scatter Pattern
The threadgroup memory layout is not a simple row-major matrix. Instead, it uses an 8x8 sub-block interleaved layout that aligns with the MMA instruction’s expected input format:
for (short i = 0; i < 16; i++) {
const short sx = 2 * il0 + i / 8;
const short sy = lr0 / 8;
const short lx = lr0 % 8;
const short ly = i % 8;
const short ib = 8 * sx + sy;
*(sa + 64 * ib + 8 * ly + lx) = temp_a[i/4][i%4];
}
This scatter writes 16 elements per thread into the correct positions for efficient simdgroup_load. The layout ensures that each 8x8 sub-block is contiguous in memory, with a stride of 8 between columns and 64 between rows of sub-blocks.1
The MMA Accumulation Loop
for (uint loop_k = 0; loop_k < K_dim; loop_k += NK) {
// Load weight and activation tiles (shown above)
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup const half *lsma = (sa + 4 * 64 * (sgitg % 2));
threadgroup const half *lsmb = (sb + 2 * 64 * (sgitg / 2));
for (short ik = 0; ik < NK / 8; ik++) {
simdgroup_barrier(mem_flags::mem_none);
simdgroup_half8x8 ma[4];
for (short i = 0; i < 4; i++) {
simdgroup_load(ma[i], lsma + 64 * i, 8, 0, false);
}
simdgroup_half8x8 mb[2];
for (short i = 0; i < 2; i++) {
simdgroup_load(mb[i], lsmb + 64 * i, 8, 0, false);
}
for (short i = 0; i < 8; i++) {
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
}
lsma += 8 * 64;
lsmb += 4 * 64;
}
}
Let’s break down what happens per K-step (8 elements of K):
-
Load weight sub-tiles: 4
simdgroup_half8x8matrices (ma[0..3]) are loaded fromsa. These represent a 32x8 slice of the weight tile (4 sub-tiles of 8x8). -
Load activation sub-tiles: 2
simdgroup_half8x8matrices (mb[0..1]) fromsb. These represent a 16x8 slice of the activation tile (2 sub-tiles of 8x8). -
MMA: 8 multiply-accumulate operations, one per output sub-tile. Each computes
mc[i] += mb[i/4] * ma[i%4], which is an 8x8 @ 8x8 -> 8x8 matrix multiply-accumulate.
The simdgroup_barrier(mem_flags::mem_none) is a lightweight barrier that synchronizes execution within the SIMD group without requiring memory ordering. This is cheaper than a full threadgroup_barrier.
Function Constant K Specialization
constant uint FC_GEMM_K [[function_constant(10)]];
constant bool FC_GEMM_K_SPECIALIZED = is_function_constant_defined(FC_GEMM_K);
const uint K_dim = FC_GEMM_K_SPECIALIZED ? FC_GEMM_K : K;
When K is known at pipeline creation time and is a multiple of 32, the host passes it as a function constant. The Metal compiler can then:
- Generate a fixed-count loop (or fully unrolled for small K)
- Eliminate the remainder check
- Optimize memory access patterns for the known stride
Output Store with Alpha/Beta
The FP16 GEMM supports the full BLAS-style interface C = alpha * A @ B^T + beta * C:
const half alpha_h = half(params.alpha);
const half beta_h = half(params.beta);
// ...
const bool has_alphabeta = (alpha_h != half(1) || beta_h != half(0));
if (has_alphabeta) {
for (int i = 0; i < nr0; i++) {
D[i] = alpha_h * S[i] + beta_h * D[i];
}
} else {
// Fast path: direct copy with half4 stores
device half4 *D4 = (device half4 *)D;
threadgroup half4 *S4 = (threadgroup half4 *)S;
for (int i = 0; i < nr0 / 4; i++) *(D4 + i) = *(S4 + i);
}
When alpha=1, beta=0 (the common case), the output is stored directly with half4 vector stores, avoiding the multiply-add overhead.
Q4_0 GEMM (simd_gemm_q4_0)
The Q4_0 GEMM is structurally identical to the FP16 GEMM – same tile geometry, same MMA loop, same output store. The only difference is how the weight tile is loaded: instead of a simple copy, the quantized data must be dequantized into FP16.
Inline Dequantization
inline void dequantize_q4_0_half4x4(device const block_q4_0 *xb,
short il, thread half4x4 ®) {
device const uint16_t *qs = ((device const uint16_t *)xb + 1);
const float d1 = il ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
const float md = -8.h * xb->d;
const ushort mask0 = il ? 0x00F0 : 0x000F;
const ushort mask1 = mask0 << 8;
float4x4 reg_f;
for (int i = 0; i < 8; i++) {
reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
}
reg = (half4x4)reg_f;
}
This dequantizes one block (32 elements) into a half4x4 (16 elements). The il parameter selects which half of the block to dequantize (low nibbles or high nibbles). The two calls to this function per thread cover all 32 elements.
The key insight: dequantization happens into registers, not into a separate buffer. The dequantized values go directly into the scatter pattern, and from there into the MMA pipeline. No intermediate buffer is ever allocated for dequantized weights.
Threadgroup Swizzling
The Q4_0 GEMM includes an optimization not present in the FP16 version: threadgroup swizzling for cache locality:
constexpr uint SWIZZLE_LOG = 3;
constexpr uint SWIZZLE_WIDTH = 1u << SWIZZLE_LOG; // 8
uint tg_x = tgpig.x;
uint tg_y = tgpig.y;
uint tiles_x = (N + NR0 - 1) / NR0;
if (tiles_x >= SWIZZLE_WIDTH) {
uint group = tg_x >> SWIZZLE_LOG;
uint within = tg_x & (SWIZZLE_WIDTH - 1);
tg_x = (group << SWIZZLE_LOG) + ((within + tg_y) & (SWIZZLE_WIDTH - 1));
}
Without swizzling, threadgroups are dispatched in row-major order: (0,0), (1,0), (2,0), .... Adjacent threadgroups in the X direction access different weight columns but the same activation rows. Swizzling rotates the column index by the row index within strips of 8 tiles, so that adjacent threadgroups access overlapping weight columns:
Without swizzling (row 0): TG(0,0) TG(1,0) TG(2,0) TG(3,0) TG(4,0) ...
Without swizzling (row 1): TG(0,1) TG(1,1) TG(2,1) TG(3,1) TG(4,1) ...
With swizzling (row 0): TG(0,0) TG(1,0) TG(2,0) TG(3,0) TG(4,0) ...
With swizzling (row 1): TG(1,1) TG(2,1) TG(3,1) TG(4,1) TG(5,1) ...
The effect: TG(1,0) and TG(1,1) (which are likely to execute on neighboring GPU cores) now access weight tiles that are only 64 columns apart instead of the full N-stride. This keeps weight data hot in the System Level Cache (SLC).2
Full Tile Fast Path
When the output tile is fully covered (no edge padding needed), the Q4_0 GEMM uses a direct device memory store:
if (nr0 == NR0 && nr1 == NR1) {
device half *D = C
+ (uint)(r1 + 16 * (sgitg >> 1)) * ldc
+ (uint)(r0 + 32 * (sgitg & 1));
for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], D + 8 * (i/4) * ldc + 8 * (i%4), ldc, 0, false);
}
}
Each SIMD group writes its 8 sub-tiles (8x8 each) directly to the output matrix using simdgroup_store. The ldc stride tells the store instruction the row pitch of the output matrix.
For edge tiles (where the tile extends beyond the matrix boundary), a staging area in threadgroup memory is used, and only the valid elements are copied to device memory.
Tile Accumulation Visualization
The following shows how a single output tile accumulates over K-steps:
Tile Accumulation Loop (one threadgroup computes C[32,64]):
K=0..31 K=32..63 K=64..95 K=4064..4095
┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐
│Load A,B │ │Load A,B │ │Load A,B │ │Load A,B │ │ Store C │
│Dequant B │→ │Dequant B │→ │Dequant B │→ ... → │Dequant B │ → │ [32,64] │
│C += A@B │ │C += A@B │ │C += A@B │ │C += A@B │ │to device│
└─────────┘ └─────────┘ └─────────┘ └─────────┘ └─────────┘
128 iterations total (K=4096, stride=32)
For a model with K=4096, there are 4096/32 = 128 accumulation steps per tile. Each step loads 2KB of activation data and 4KB of weight data (for Q4_0, the raw quantized data is ~1KB but dequantizes to 4KB in registers), performs 8 MMA operations (each 8x8 @ 8x8), and accumulates into the 8 output sub-tiles.
Small GEMM Variants
For very small M (2-8 rows), Akunu provides “small” GEMM variants with TM=8 instead of TM=32:
simd_gemm_small_f16, simd_gemm_small_q4_0, simd_gemm_small_q4_k, ...
These use less threadgroup memory (fewer activation rows to store) and produce smaller output tiles, avoiding wasted computation on padding rows. The dispatch threshold is:
bool small = (M >= 2 && M <= 8);
int TM = small ? 8 : 32;
When M=4 (e.g., a 4-token speculative verification batch), the small variant processes all 4 rows in one 8-row tile (with 4 padding rows), while the standard variant would use a 32-row tile with 28 wasted rows.
Memory Requirements
| Resource | FP16 GEMM | Q4_0 GEMM |
|---|---|---|
| Threadgroup memory | 6144 bytes | 6144 bytes |
| Registers per SG (accumulators) | 8 x simdgroup_half8x8 | 8 x simdgroup_half8x8 |
| Weight tile bandwidth | 64 * 32 * 2 = 4096 bytes/step | 64 * 32 * 0.5625 = 1152 bytes/step |
| Activation tile bandwidth | 32 * 32 * 2 = 2048 bytes/step | 32 * 32 * 2 = 2048 bytes/step |
The Q4_0 GEMM reads only ~1150 bytes of quantized weight data per K-step (each 32-element Q4_0 block is 18 bytes, or 4.5 bits per weight), compared to 4096 bytes for FP16 – a ~3.6x reduction. This is why quantized GEMMs achieve higher effective throughput than FP16 GEMMs on the same hardware – the memory subsystem is the bottleneck for both, and Q4_0 moves less data per FLOP.
The Full GEMM Kernel Zoo
Akunu provides GEMM kernels for every supported weight format:
| Format | Standard Kernel | Small Kernel | Notes |
|---|---|---|---|
| FP16 | simd_gemm_f16 | simd_gemm_small_f16 | No dequant |
| BF16 | simd_gemm_bf16 | simd_gemm_small_bf16 | BF16->FP16 convert |
| Q4_0 | simd_gemm_q4_0 | simd_gemm_small_q4_0 | 4-bit, group=32, swizzle |
| Q4_1 | simd_gemm_q4_1 | simd_gemm_small_q4_1 | 4-bit with min |
| Q5_0 | simd_gemm_q5_0 | simd_gemm_small_q5_0 | 5-bit |
| Q5_K | simd_gemm_q5_k | simd_gemm_small_q5_k | 5-bit, K-quant |
| Q8_0 | simd_gemm_q8_0 | simd_gemm_small_q8_0 | 8-bit |
| Q4_K | simd_gemm_q4_k | simd_gemm_small_q4_k | 4-bit, K-quant |
| Q6_K | simd_gemm_q6_k | simd_gemm_small_q6_k | 6-bit, K-quant |
| Q2_K | simd_gemm_q2_k | simd_gemm_small_q2_k | 2-bit, K-quant |
| Q3_K | simd_gemm_q3_k | simd_gemm_small_q3_k | 3-bit, K-quant |
| MLX Q3 | simd_gemm_mlx_q3 | simd_gemm_small_mlx_q3 | MLX 3-bit |
| MLX Q4 | simd_gemm_mlx_q4 | simd_gemm_small_mlx_q4 | MLX 4-bit |
| MLX Q6 | simd_gemm_mlx_q6 | simd_gemm_small_mlx_q6 | MLX 6-bit |
| MLX Q8 | simd_gemm_mlx_q8 | simd_gemm_small_mlx_q8 | MLX 8-bit |
| MLX Gen | simd_gemm_mlx_gen | simd_gemm_small_mlx_gen | MLX arbitrary bits |
That is 30+ kernel variants, all sharing the same tile geometry and MMA loop, differing only in the dequantization path.
Performance Characteristics
GEMM performance on Apple Silicon depends primarily on the tile utilization and memory bandwidth:
| Factor | Impact | How Akunu Handles It |
|---|---|---|
| M too small | Wasted rows in activation tile | Small GEMM variant (TM=8) |
| N not multiple of 64 | Edge tile with partial store | Staging through TG memory |
| K not multiple of 32 | Remainder loop needed | FC_GEMM_K specialization |
| Cache thrashing | Weight tile eviction | Threadgroup swizzling |
| Register pressure | Spill to local memory | 8 accumulators fit in 128 registers |
The theoretical peak for an Apple M4 Pro (20 GPU cores) at FP16 MMA is approximately 14 TFLOPS. A well-optimized 4096x4096 GEMM achieves roughly 80-90% of peak, limited by threadgroup memory bandwidth and barrier synchronization overhead.
Pipeline State Object Caching
Each GEMM variant requires a compiled Pipeline State Object (PSO) before it can be dispatched. Akunu caches these PSOs aggressively:
std::string cache_key = std::string(kernel) + "_k" + std::to_string(K);
pso = device.get_pipeline(kernel, cache_key, fc_indices, fc_values, 1);
The cache key includes the kernel name and any function constant values, ensuring that different K-specializations produce separate PSOs. The first call to get_pipeline compiles the MSL kernel into GPU machine code (which can take 10-50ms), but subsequent calls return the cached PSO instantly.
For a typical model, there are approximately 10-15 unique GEMM PSOs (one per unique K dimension per weight format). These are compiled during model loading and never recompiled during inference.
GEMM vs GEMV: The Crossover Point
An important question: when should the engine use GEMM instead of GEMV? The answer depends on M (the number of activation rows):
| M | Optimal Kernel | Why |
|---|---|---|
| 1 | GEMV | No tile overhead, direct reduction |
| 2-8 | Small GEMM (TM=8) | Some row reuse, minimal padding |
| 9-32 | Standard GEMM (TM=32) | Good tile utilization |
| 33+ | Standard GEMM (TM=32) | Multiple tiles in M dimension |
The crossover between GEMV and GEMM is at M=2. Even with just 2 activation rows, the GEMM kernel’s weight tile loading (shared between both rows) provides better memory efficiency than two separate GEMV dispatches. However, for M=1, the GEMM kernel wastes 31 out of 32 rows in the activation tile, so GEMV is always faster.
Akunu’s dispatch_gemm function makes this decision:
bool small = (M >= 2 && M <= 8);
int TM = small ? 8 : 32;
The small GEMM variant (TM=8) wastes at most 6 rows (for M=2) instead of 30 rows (with TM=32), providing a good compromise for very small batch sizes.
The MMA Instruction in Detail
Apple Silicon’s simdgroup_multiply_accumulate is the hardware primitive that makes efficient GEMM possible. Let’s understand exactly how it works.
Lane-to-Element Mapping
In an 8x8 SIMD matrix, the 32 lanes of a SIMD group each hold 2 elements. The mapping follows Apple’s proprietary layout:3
For an 8x8 matrix stored in a simdgroup_half8x8:
Lane 0: elements (0,0) and (0,1)
Lane 1: elements (0,2) and (0,3)
Lane 2: elements (1,0) and (1,1)
Lane 3: elements (1,2) and (1,3)
...
The thread_elements() accessor returns a vec<T, 2> containing the calling thread’s two elements. This is used by the V2 attention kernel to perform per-element operations directly on MMA results without going through threadgroup memory.
MMA Throughput
Each simdgroup_multiply_accumulate(C, A, B, C) computes:
C[8,8] += A[8,8] @ B[8,8]
This performs 8 * 8 * 8 = 512 multiply-accumulate operations, or 1024 FLOPs per instruction at FP16 precision. The hardware executes each 8x8 MMA over multiple cycles on the SIMD group's ALU pipelines; Apple does not publish the exact issue rate. What matters is the aggregate: across all GPU cores, FP16 MMA throughput on the M4 Pro peaks at roughly the 14 TFLOPS figure used in the roofline analysis earlier, though in practice memory bandwidth and barrier overhead limit sustained GEMM throughput to ~60-80% of that peak.
Register Accumulator Precision
The MMA instruction accumulates in the same precision as the operands. For simdgroup_half8x8, accumulation is in FP16. For long K-dimensions (K > 4096), this can lead to precision loss from repeated half-precision additions.
Akunu mitigates this by using the simdgroup_float8x8 accumulator type for attention scores (where precision matters more) while keeping simdgroup_half8x8 for GEMM output (where the subsequent operations, norm + activation, tolerate half-precision).
Quantized GEMM Performance Analysis
For a 7B model prefilling 2048 tokens with Q4_0 weights:
The Q projection GEMM: C[2048, 4096] = A[2048, 4096] @ B^T[4096, 4096]
| Metric | Value |
|---|---|
| Output elements | 2048 * 4096 = 8.4M |
| FLOPs | 2 * 2048 * 4096 * 4096 = 68.7 GFLOP |
| Weight data read | 4096 * 4096 * 0.56 bytes = 9.4 MB |
| Activation data read | 2048 * 4096 * 2 bytes = 16.8 MB |
| Output data written | 2048 * 4096 * 2 bytes = 16.8 MB |
| Total memory traffic | ~43 MB |
| Arithmetic intensity | 68.7 GFLOP / 43 MB ≈ 1598 FLOP/byte |
At 1598 FLOP/byte, this is firmly in the compute-bound regime. The M4 Pro’s 14 TFLOPS of FP16 throughput would complete this in ~4.9ms, and memory bandwidth (200 GB/s) would complete the data transfer in ~0.2ms. The GEMM is compute-bound by a factor of ~24x.
This is the fundamental reason prefill is so much faster per-token than decode: the same weight data is reused across 2048 activation rows, amortizing the memory transfer cost.
Handling Non-Standard Architectures
Akunu’s GEMM dispatch supports several architectural variations through the descriptor system:
BERT/Encoder models: Use the same GEMM kernels but with different weight names and optional bias addition (dispatched as a separate kernel after the GEMM).
Gemma models: Have post-attention and post-FFN norms that require extra GEMM dispatch passes. The dispatch_gemm function is architecture-agnostic – it just computes C = alpha * A @ B^T + beta * C.
MLX quantized models: The GEMM kernels receive MLX-specific parameters (group_size, bits, weight_bytes) via a secondary parameter buffer, enabling the same tile geometry with different dequantization logic.
Tied embeddings: The logit projection in some models reuses the embedding table as the output weight. dispatch_gemm does not care about the semantic meaning of the weight – it just needs the buffer, dimensions, and dtype.
Threadgroup Memory Bandwidth
Threadgroup memory on Apple Silicon GPUs has significantly higher bandwidth than device memory – roughly 10-20x, depending on the chip generation. This is why the GEMM kernel’s performance depends heavily on the TG memory access pattern.
The weight tile scatter pattern places data in 8x8 sub-blocks with stride 64 between sub-block rows and stride 8 between columns within a sub-block. This layout is not arbitrary – it matches the simdgroup_load access pattern, ensuring that each MMA instruction’s operand load reads a contiguous 64-byte chunk from threadgroup memory.
For each K-step (32 elements of K):
| Access | Pattern | Bytes | Bandwidth Required |
|---|---|---|---|
| Load weight tile from device | 128 threads, 16 elements each | 4096 bytes | Device BW |
| Scatter weight to TG | 128 threads, indexed writes | 4096 bytes | TG BW |
| Load activation tile from device | 128 threads, 8 elements each | 2048 bytes | Device BW |
| Store activation to TG | 128 threads, vector stores | 2048 bytes | TG BW |
| MMA loads from TG | 4 SG x (4+2) loads per K/8 step | ~6144 bytes per step | TG BW |
The TG memory acts as a software-managed L1 cache, giving the programmer explicit control over data reuse that would otherwise depend on hardware caching behavior.
End-to-End Prefill GEMM Flow
For a complete understanding, let’s trace the GEMMs in a single transformer layer during prefill of a 7B model with seq_len=2048:
| GEMM | M | N | K | Weight Shape | Time (est.) |
|---|---|---|---|---|---|
| Q projection | 2048 | 4096 | 4096 | [4096, 4096] | ~5ms |
| K projection | 2048 | 1024 | 4096 | [1024, 4096] | ~1.5ms |
| V projection | 2048 | 1024 | 4096 | [1024, 4096] | ~1.5ms |
| O projection | 2048 | 4096 | 4096 | [4096, 4096] | ~5ms |
| Gate projection | 2048 | 14336 | 4096 | [14336, 4096] | ~16ms |
| Up projection | 2048 | 14336 | 4096 | [14336, 4096] | ~16ms |
| Down projection | 2048 | 4096 | 14336 | [4096, 14336] | ~16ms |
Total per layer: ~61ms. For 32 layers: ~1.95 seconds. Plus attention, norms, and activations: roughly 2.5 seconds total for 2048 tokens. That is about 820 tokens/sec prefill throughput, which matches real-world measurements on M4 Pro hardware.
The FFN GEMMs (Gate, Up, Down) dominate because ffn_dim (14336) is ~3.5x larger than dim (4096). This is characteristic of modern LLMs that use SwiGLU activation, which requires a wider intermediate dimension.
The Barrier Budget
Threadgroup barriers are a significant cost in GEMM kernels. Each threadgroup_barrier(mem_flags::mem_threadgroup) call synchronizes all threads in the threadgroup and flushes the threadgroup memory. On Apple Silicon, a barrier takes approximately 0.2-0.5 microseconds.
For each K-step (32 elements of K), the GEMM kernel requires 2 barriers:
- After the cooperative tile load (ensure all threads have written their portion)
- After the MMA loop (ensure all SIMD groups have finished reading)
For K=4096, there are 128 K-steps, requiring 256 barriers. At 0.3us per barrier, this is ~77us of pure barrier overhead per tile, or roughly 10-15% of the total tile computation time. This is one of the reasons GEMM does not achieve 100% of peak MMA throughput.
The V2 attention kernel’s approach of keeping data in registers (avoiding the MMA-barrier-MMA cycle) provides a hint at how future GEMM kernels might reduce barrier overhead, though the GEMM’s much larger tile sizes make this approach more challenging.
Comparison with llama.cpp’s GEMM
Akunu’s GEMM kernels are derived from llama.cpp’s kernel_mul_mm family but include several improvements:
| Feature | llama.cpp | Akunu |
|---|---|---|
| Tile geometry | 32x64 (same) | 32x64 + 8x64 small variant |
| Threadgroup swizzling | No | Yes (Q4_0, other quantized) |
| Function constant K | No | Yes (FC_GEMM_K) |
| Alpha/beta support | No | Yes (FP16 GEMM) |
| MLX format support | No | Yes (6 MLX variants) |
| Small M variant | No | Yes (TM=8 for M=2-8) |
| BF16 support | Partial | Full |
The most impactful difference is the function constant K specialization, which allows the Metal compiler to generate tighter loops with known bounds, often resulting in 5-10% speedup for common K dimensions.
The threadgroup swizzling provides another 3-8% improvement at large grid sizes by improving SLC hit rates for weight tiles. This is most noticeable during the FFN GEMMs where the grid is large (14336/64 = 224 tiles in the weight dimension).
Future Directions
Apple’s Metal 3.2 (introduced with the M4 family) provides enhanced simdgroup matrix operations, including support for larger tile sizes and new data types. Future GEMM kernels may benefit from:
- Larger MMA tiles: 16x16 or 32x32 sub-tiles would reduce the number of MMA instructions per output element, improving throughput.
- BF16 MMA: Native BF16 matrix operations would eliminate the conversion overhead for BF16 models.
- Cooperative groups: Finer-grained synchronization primitives could reduce barrier overhead.
- Persistent kernels: A single long-running kernel that processes all tiles sequentially could eliminate inter-tile overhead.
However, the current 8x8 MMA-based approach is well-proven and delivers near-peak performance. The 32x64 tile geometry will likely remain optimal for Apple Silicon’s current generation of GPU architectures.
Debugging GEMM Correctness
GEMM bugs are notoriously difficult to debug because the output is a large matrix where each element depends on the full K-dimension accumulation. Akunu uses several strategies:
- Alpha/Beta support: Setting alpha=1, beta=0 for production and alpha=0, beta=1 for "identity" (output = input C) enables isolating GEMM output from existing data.
- PSO validation: The dispatch_gemm function includes a fatal error if the PSO fails to compile, catching kernel bugs early.
- Dimension checks: The scratch buffer sizes are validated at model load time to ensure no GEMM dispatch will write out of bounds.
- Profiling labels: Each GEMM dispatch in the dispatch table carries a label like "L5.ffn.gate", making it easy to identify which GEMM produced incorrect output in a GPU debugger.
Summary
Akunu’s GEMM kernels are the workhorses of prefill. The key design decisions are:
- 32x64 tile geometry with 4 SIMD groups per threadgroup, maximizing MMA instruction utilization.
- Inline dequantization for quantized formats, converting directly from packed format to registers without intermediate buffers.
- Cooperative loading where all 128 threads participate in loading both weight and activation tiles.
- Threadgroup swizzling for cache-friendly access patterns across the grid.
- Small GEMM variants for low-M cases to avoid wasted padding computation.
- Function constant specialization for K-dimension to enable compiler optimizations.
1. Apple. “Metal Shading Language Specification, Version 3.1.” Section 6.10, Simdgroup Matrix Functions. The simdgroup_load and simdgroup_store functions operate on 8x8 matrices distributed across the 32 threads of a SIMD group, with each thread holding two elements (the thread_elements() accessor). See https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf. ↩
2. The swizzling technique is adapted from NVIDIA’s CUTLASS library. See: Thakkar, V., et al. “CUTLASS: Fast Linear Algebra in CUDA C++.” NVIDIA Technical Blog, 2017. The Apple Silicon SLC acts similarly to NVIDIA’s L2 cache for this optimization. ↩
3. Apple. “Metal Shading Language Specification, Version 3.1.” Section 6.10.3, Simdgroup Matrix Thread Elements. The thread_elements() accessor returns a vec<T, 2> containing the thread’s owned elements, following Apple’s proprietary lane mapping for 8x8 matrices. See https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf. ↩
FlashAttention Kernels
Attention is the defining operation of the transformer architecture, and getting it right on Apple Silicon is perhaps the most challenging kernel engineering task in Akunu. The attention operation computes O = softmax(Q @ K^T / sqrt(d)) @ V, but a naive implementation would materialize the full [seq_len, kv_seq_len] attention matrix, which is quadratic in memory. FlashAttention avoids this by tiling the computation and maintaining an online softmax that never materializes the full matrix.1
Akunu implements five attention kernels, each optimized for a different scenario:
| Kernel | File | Use Case | Dispatch | Threads |
|---|---|---|---|---|
| Standard Decode | flash_attention.metal | Short-context decode, short prefill | (seq_len, n_heads) | up to 1024 |
| Fast Decode | flash_attention_decode_fast.metal | M=1 single-token decode, medium context | (1, n_heads) | 32 (1 SG) |
| Parallel Decode | flash_attention_decode_parallel.metal | M=1 decode, very long context | (1, n_heads) | 1024 (32 SG) |
| Prefill V1 | flash_attention.metal | Medium-length prefill | (ceil(S/NQ), n_heads) | 128 (4 SG) |
| Prefill V2 | flash_attention_prefill_v2.metal | Long prefill (>= 1024 tokens) | (ceil(S/32), n_heads) | 128 (4 SG) |
Plus a standalone Softmax kernel in softmax.metal.
This chapter covers all of them in detail – their algorithms, thread assignments, memory strategies, and performance characteristics.
The Online Softmax Algorithm
Before diving into kernels, we need to understand the algorithm they all share: online softmax. The standard softmax requires two passes over the data (find max, then compute exp and sum). FlashAttention’s online variant maintains a running max and sum, allowing it to process KV entries in a single streaming pass.
1. Initialize: `max = -inf`, `sum = 0`, `O = 0`
2. For each KV block (tile of 32 positions):
   - Compute scores: `S = Q @ K^T * scale`
   - Find `block_max = max(S)`
   - Update running state:
     - `new_max = max(running_max, block_max)`
     - `correction = exp(old_max - new_max)`
     - `running_sum = running_sum * correction + sum(exp(S - new_max))`
   - Rescale and accumulate V: `O = O * correction + exp(S - new_max) @ V`
3. Finalize: `O = O / running_sum`
The correction factor exp(old_max - new_max) rescales all previously accumulated values when a new maximum is discovered. This is the heart of the algorithm – it allows processing KV entries in arbitrary-sized blocks without ever storing the full attention matrix.
Now let’s see how each kernel variant implements this differently. But first, watch the online softmax in action — this is the foundation that all five kernels share.
Interactive: Online Softmax — The Core Algorithm
This animation processes 6 KV positions one at a time. Watch the running max and sum update, and pay attention to what happens at position 3 when a new maximum is discovered — the correction factor rescales all previous work. This is the trick that makes FlashAttention possible.
Now let’s see how the five kernel variants implement this algorithm with different parallelization strategies.
Kernel 1: Standard Decode (flash_attention_decode_f16)
The standard decode kernel handles the general case: one threadgroup per query position, with threads collaborating on the QK dot product and V accumulation.
Thread Assignment
Dispatch: grid = (seq_len, n_heads), threadgroup = (head_dim or 1024)
Each thread “owns” one element of the head dimension. For head_dim=128, 128 threads are used. For head_dim=256, 256 threads. The cap at 1024 accommodates models with very large head dimensions.
Algorithm
float q_val = (tid < head_dim) ? float(q_row[tid]) : 0.0f;
for (uint kv_start = 0; kv_start < kv_seq_len; kv_start += ATTN_KV_TILE) {
    const uint tile_len = min(ATTN_KV_TILE, kv_seq_len - kv_start);
    for (uint kv = 0; kv < tile_len; kv++) {
        const uint kv_pos = kv_start + kv;
        device const half *k_row = k_base + kv_pos * head_dim;
        // Q·K dot product: each thread multiplies one element
        float local_dot = (tid < head_dim) ? q_val * float(k_row[tid]) : 0.0f;
        // Cross-SIMD reduction via shared memory
        float simd_val = simd_sum(local_dot);
        if (slid == 0) shared_reduce[sgid] = simd_val;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // Sum across SIMD groups
        float score = shared_reduce[0];
        for (uint s = 1; s < n_simd_groups; s++)
            score += shared_reduce[s];
        threadgroup_barrier(mem_flags::mem_threadgroup);
        score *= scale;
        // Online softmax update
        float old_max = running_max;
        float new_max = max(running_max, score);
        float exp_score = exp(score - new_max);
        float correction = exp(old_max - new_max);
        running_sum = running_sum * correction + exp_score;
        // V accumulation in register
        if (tid < head_dim) {
            float v_val = float(v_base[kv_pos * head_dim + tid]);
            acc = acc * correction + exp_score * v_val;
        }
        running_max = new_max;
    }
}
Key observations:
- Register-local accumulator: Each thread accumulates exactly one output element in `acc`. No threadgroup memory is needed for the output.
- Two barriers per KV entry: One after the SIMD-group reduction write, one after the read. These are the performance bottleneck for long contexts.
- Q value cached in register: The query vector is loaded once and reused across all KV entries.
GQA Support
Grouped-Query Attention is handled by a single integer division:
const uint kv_head = head / (n_heads / n_kv_heads);
Multiple query heads share the same KV head. The KV stride is computed per-head, and all query heads in the same group read from the same K/V buffers.
Causal and Tree Masking
The kernel supports three masking modes via function constants:
constant bool FC_USE_TREE_MASK [[function_constant(1)]];
constant bool FC_NON_CAUSAL [[function_constant(3)]];
| Mode | Use Case | Masking Logic |
|---|---|---|
| Causal | Standard prefill | kv_pos > q_pos -> skip |
| Tree | Speculative verification | Bitmask lookup per (q,kv) pair |
| Non-causal | BERT/Whisper encoder | No masking |
Tree masking uses a per-query bitmask packed into uint16_t:
if (kv_pos >= batch_start) {
uint batch_idx = kv_pos - batch_start;
if (batch_idx < seq_len && ((tree_mask[q_pos] >> batch_idx) & 1) == 0)
continue;
}
This enables speculative decoding with tree-structured draft tokens, where each draft token can only attend to its ancestors in the tree.
Interactive: Standard Decode — Barrier-Based QK Broadcast
In the standard decode kernel, multiple SIMD groups share a threadgroup. The QK dot product needs cross-SG communication via threadgroup memory and barriers. Each KV position requires 2 barriers — the dominant cost for long contexts.
Kernel 2: Fast Decode (flash_attention_decode_fast_f16)
The fast decode kernel is a radical simplification: a single SIMD group (32 threads) per head, with zero threadgroup barriers.
Why No Barriers?
The standard decode kernel needs barriers because the QK dot product requires cross-SIMD-group communication. With only one SIMD group, simd_sum() provides the full reduction – no shared memory needed:
float local_dot = 0.0f;
for (uint e = 0; e < elems_per_thread; e++) {
uint idx = slid + e * SIMD_WIDTH;
local_dot += q_vals[e] * float(k_row[idx]);
}
float score = simd_sum(local_dot) * scale;
Each thread handles head_dim / 32 elements (4 elements for head_dim=128, 2 for head_dim=64). The simd_sum broadcasts the result to all lanes instantly, without any barrier.
Performance Profile
| Aspect | Standard Decode | Fast Decode |
|---|---|---|
| Threads per head | up to 1024 | 32 |
| Barriers per KV | 2 | 0 |
| Memory parallelism | High (many threads read KV) | Low (32 threads read KV) |
| Barrier overhead | ~0.2us * 2 * kv_len | 0 |
| Best for | Short contexts (< ~128 KV) | Medium contexts (128-1024 KV) |
For context lengths beyond ~128 KV entries, the barrier overhead in the standard kernel dominates. Each barrier costs roughly 0.2 microseconds, and with 2 barriers per KV entry, a 1024-entry context costs ~400 microseconds in barrier overhead alone. The fast decode kernel eliminates this entirely, at the cost of lower memory bandwidth utilization (32 threads vs 128+).
Multi-Element V Accumulation
float acc[8] = {}; // max 8 elements per thread
for (uint kv_pos = 0; kv_pos < kv_seq_len; kv_pos++) {
// ... compute score ...
for (uint e = 0; e < elems_per_thread; e++) {
uint idx = slid + e * SIMD_WIDTH;
float v_val = float(v_base[kv_pos * head_dim + idx]);
acc[e] = acc[e] * correction + exp_score * v_val;
}
}
Each thread maintains elems_per_thread accumulators in registers. The memory access pattern for V is strided: thread 0 reads elements 0, 32, 64, 96 (for head_dim=128). This is suboptimal for cache lines but acceptable because the KV data is typically in the SLC.
Interactive: Fast Decode — Zero Barriers
The breakthrough: use only 1 SIMD group (32 threads). Each thread holds multiple head_dim elements (4 for head_dim=128). Since all 32 threads are in one SG, simd_sum() gives the full QK dot product — no threadgroup memory, no barriers. Compare the barrier count to Standard Decode above.
Kernel 3: Parallel Decode (flash_attention_decode_parallel_f16)
For very long contexts (thousands of KV entries), even the fast decode kernel is limited by its sequential scan of KV entries. The parallel decode kernel uses 32 SIMD groups (1024 threads) to parallelize across the KV dimension:
constexpr uint NUM_SG = 32;
const uint sgid = tid / 32; // SIMD group = KV position group
const uint slid = tid % 32; // lane = head_dim partition
KV Parallelism
Each SIMD group handles every 32nd KV position:
for (uint kv_pos = sgid; kv_pos < kv_seq_len; kv_pos += NUM_SG) {
// Compute dot product for this KV position
// Online softmax within this SG's partial view
}
SG 0 processes positions 0, 32, 64, …; SG 1 processes 1, 33, 65, …; and so on. This gives 32x memory parallelism for KV reads compared to the single-SG approach.
Cross-SG Reduction
After all SGs finish their partial computations, the results must be merged. This is the tricky part – each SG has a partial (max_score, sum_exp, output[head_dim]) triplet that needs to be combined with correction factors:
// Phase 1: Find global max
if (slid == 0) {
tg_max[sgid] = max_score;
tg_sum[sgid] = sum_exp;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float loaded_max = tg_max[slid];
float global_max = simd_max(loaded_max);
// Phase 2: Correct sums
float loaded_sum = tg_sum[slid] * fast::exp(tg_max[slid] - global_max);
float global_sum = simd_sum(loaded_sum);
float inv_sum = 1.0f / global_sum;
// Phase 3: Reduce output per element
float my_factor = fast::exp(max_score - global_max);
for (uint i = 0; i < elems; i++) {
tg_out[slid * 32 + sgid] = o[i] * my_factor;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sgid == 0) {
float total = 0;
for (uint s = 0; s < 32; s++)
total += tg_out[slid * 32 + s];
o_row[slid * elems + i] = half(total * inv_sum);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
The reduction uses 4KB of threadgroup memory (tg_out[1024]) to stage the per-SG output partials. For each head_dim element, all 32 SGs write their corrected partials, then SG 0 sums them.
When to Use Parallel Decode
The parallel decode kernel is fastest for very long contexts (10K+ KV entries) where the KV scan dominates. For shorter contexts, the reduction overhead (barriers + threadgroup memory traffic) makes it slower than the fast decode kernel.
Interactive: Parallel Decode — 32 SGs Divide the KV Cache
For long contexts, even fast decode is slow because one SG must scan all KV positions sequentially. Parallel decode uses 32 SIMD groups, each processing every 32nd KV position. After the parallel scan, a cross-SG reduction merges the results. This is 32x more memory bandwidth for the KV read.
Kernel 4: Prefill V1 (flash_attention_prefill_f16)
The prefill kernel processes multiple query rows simultaneously using SIMD group matrix multiply-accumulate (MMA) for both the QK product and the PV product.
Threadgroup Geometry
constant constexpr uint NQ_ROWS_DEFAULT = 8;
constant constexpr uint KV_BLOCK_SIZE = 32;
constant constexpr uint V2_TG_SIZE = 128; // 4 SIMD groups
Each threadgroup handles NQ query rows (4, 8, or 16, depending on head_dim) and tiles the KV dimension in blocks of 32.
Adaptive NQ via Function Constants
constant uint FC_NQ_ROWS [[function_constant(2)]];
const uint NQ = FC_NQ_SPECIALIZED ? FC_NQ_ROWS : NQ_ROWS_DEFAULT;
| Head Dim | NQ | Passes | TG Memory |
|---|---|---|---|
| <= 64 | 16 | 2 | ~6KB |
| <= 128 | 8 | 1 | ~12KB |
| <= 256 | 4 | 1 | ~24KB |
For small head dimensions, NQ=16 allows processing more queries per threadgroup, amortizing the KV load cost over more query rows.
Threadgroup Memory Layout
threadgroup half *tg_Q // [NQ, HD] half
threadgroup float *tg_S // [NQ, 32] float (score matrix)
threadgroup float *tg_O // [NQ, HD] float (output accumulator)
threadgroup half *tg_K // [HD, 32] half (also reused as P scratch)
threadgroup half *tg_V // [32, HD] half
threadgroup float *tg_row_max // [NQ] float
threadgroup float *tg_row_sum // [NQ] float
The output accumulator is in FP32 to maintain precision across many accumulation steps. The K tile is stored transposed ([HD, 32] instead of [32, HD]) because the MMA operation needs K in column-major format for the Q @ K^T product.
Direct Device K Loading
A key optimization in V1: for full KV blocks (32 entries), K is loaded directly from device memory during the MMA computation, bypassing threadgroup memory entirely:
if (full_block) {
const uint kv_block_pos = kv_start + s_col_tile * SIMD_TILE;
for (uint inner = 0; inner < head_dim; inner += SIMD_TILE) {
simdgroup_load(q_tile, tg_Q + q_row_offset * head_dim + inner, head_dim);
simdgroup_load(k_tile, k_base + kv_block_pos * head_dim + inner,
head_dim, 0, true); // transposed load!
simdgroup_multiply_accumulate(acc_s, q_tile, k_tile, acc_s);
}
}
The true parameter on simdgroup_load requests a transposed load. Apple’s MMA hardware can load matrices in either row-major or column-major order, so the transposition is free. This eliminates the cooperative load + barrier for the K tile in full blocks, saving ~2 microseconds per block.
Causal Skip Optimization
const bool is_causal = (!FC_HAS_TREE_MASK && !FC_IS_NON_CAUSAL
&& kv_seq_len == seq_len);
const uint causal_kv_limit = is_causal
? min(((last_q_pos + KV_BLOCK_SIZE) / KV_BLOCK_SIZE) * KV_BLOCK_SIZE,
kv_seq_len)
: kv_seq_len;
For causal attention, KV blocks beyond the causal limit are entirely masked and can be skipped. This saves roughly half the computation for long sequences (the causal mask is triangular).
Block Classification
Within the causal region, blocks are classified as fully unmasked, partially masked, or fully masked:
const bool block_fully_unmasked = is_causal
    ? (kv_end <= first_q_pos + 1 && kv_seq_len > KV_BLOCK_SIZE * 4)
    : (!FC_HAS_TREE_MASK);
if (block_fully_unmasked && full_block && pass_nq == TILE_ROWS) {
    // Fast path: just scale, no per-element masking
    for (uint idx = tid_flat; idx < TILE_ROWS * KV_BLOCK_SIZE; idx += V2_TG_SIZE) {
        const uint qr     = idx / KV_BLOCK_SIZE;
        const uint kv_col = idx % KV_BLOCK_SIZE;
        tg_S[qr * KV_BLOCK_SIZE + kv_col] *= scale;
    }
} else {
    // Slow path: per-element causal/tree masking
}
Fully unmasked blocks (far from the causal diagonal) use a fast path that just scales the scores. Partially masked blocks (near the diagonal) check each element individually.
Multi-Pass for Large NQ
When NQ > 8, the kernel processes query rows in passes of 8 (matching the SIMD MMA tile size):
for (uint q_pass = 0; q_pass < n_q_passes; q_pass++) {
const uint q_row_offset = q_pass * TILE_ROWS;
// Compute S, softmax, P@V for rows [q_row_offset..+8]
}
The V tile is loaded once per KV block and reused across all query passes, amortizing the memory cost.
Interactive: Prefill Attention — Tiled Q x KV with Causal Mask
Prefill processes all prompt tokens at once. Unlike decode (1 query, all KV), prefill has many queries and many KV positions. It tiles both dimensions: a block of queries (BQ=32) iterates over blocks of KV (BK=32), computing attention scores and accumulating outputs — like a GEMM with online softmax and a causal mask.
Kernel 5: Prefill V2 (flash_attention_prefill_v2_f16)
The V2 kernel is the newest and fastest attention kernel, designed for long sequences. Its key innovations are:
Register-Based Output
Instead of accumulating output in threadgroup memory (tg_O), V2 keeps the output in SIMD group registers:
float2 acc_o[HD_TILES]; // HD/8 output tile fragments
for (uint i = 0; i < HD_TILES; i++)
acc_o[i] = float2(0);
Each thread holds two float values per output tile (matching the thread_elements() of an 8x8 SIMD matrix). For head_dim=128, this is 16 float2 values = 128 bytes per thread – entirely in registers, no threadgroup memory needed for output.
exp2 Instead of exp
const float scale_log2 = params.scale * M_LOG2E_F;
// Pre-scale Q
tg_Q[idx] = half(float(tg_Q[idx]) * scale_log2);
// Later, in softmax:
s_frag[t][0] = fast::exp2(s_frag[t][0] - new_max);
By pre-multiplying the scale factor by log2(e), the kernel can use fast::exp2 instead of exp. On Apple Silicon, fast::exp2 maps directly to the hardware transcendental unit and is approximately 2x faster than exp.2
Row-Level Reduce via SIMD Shuffle
Apple’s 8x8 MMA distributes elements across lanes in a specific pattern. To compute row max and row sum, V2 uses XOR shuffles:
template <typename T>
METAL_FUNC T row_max(T v) {
v = metal::max(v, simd_shuffle_xor(v, 1));
v = metal::max(v, simd_shuffle_xor(v, 8));
return v;
}
template <typename T>
METAL_FUNC T row_sum(T v) {
v += simd_shuffle_xor(v, 1);
v += simd_shuffle_xor(v, 8);
return v;
}
In Apple’s MMA layout, lanes that share a row differ by XOR distances of 1 and 8. Two shuffle operations suffice for a complete row reduction – no threadgroup barriers needed.
BQ=32 Query Block Size
V2 processes 32 query rows per threadgroup (vs V1’s 8-16), with 4 SIMD groups each handling 8 rows. Relative to V1’s default NQ=8, this 4x larger query batch means:
- The V tile ([32, head_dim]) is loaded once and reused across all 32 queries
- The KV load cost is amortized over 4x more queries
- The total number of threadgroups is reduced by 4x
KV Tile Reuse for K and V
threadgroup half *tg_KV = (threadgroup half *)(tg_raw + off);
// First: load K transposed → tg_KV
// Compute S = Q @ K^T
threadgroup_barrier(mem_flags::mem_threadgroup);
// Then: load V → tg_KV (reuse same memory)
// Compute O += P @ V
The K and V tiles are never needed simultaneously, so they share the same threadgroup memory region. This halves the threadgroup memory requirement.
P Fragment in Registers
Instead of materializing the full probability matrix P in threadgroup memory, V2 keeps it in SIMD registers:
simdgroup_half8x8 p_half;
thread half2 &p_h = *(thread half2 *)&(p_half.thread_elements());
p_h[0] = half(s_frag[t][0]);
p_h[1] = half(s_frag[t][1]);
simdgroup_load(v_mat, tg_KV + t * SG_TILE * HD + d * SG_TILE, HD);
simdgroup_multiply_accumulate(o_mat, p_half, v_mat, o_mat);
The score fragments are cast to half precision and packed into SIMD matrix format directly from registers, then used in the MMA operation. No threadgroup write/read cycle.
Interactive: Prefill V2 — Register-Only Output Pipeline
Prefill V2 is the fastest attention kernel. Its key insight: keep the output accumulator in SIMD registers instead of threadgroup memory, and reuse the same threadgroup memory region for K and V tiles alternately. This animation shows one KV-block iteration, contrasting V1’s memory-heavy approach with V2’s register pipeline.
Kernel 6: Softmax (softmax_f16)
The standalone softmax kernel operates on pre-computed score matrices:
kernel void softmax_f16(
    device half *data, constant SoftmaxParams &params,
    uint3 tgid_v, uint3 tid_v, uint sgid, uint slid, uint3 tpg
) {
    // (Argument attributes and the tid / cols / tg_size / row_data / shared
    // setup are elided; one threadgroup processes one row.)
    float local_max = -INFINITY;
    for (uint i = tid; i < cols; i += tg_size)
        local_max = max(local_max, float(row_data[i]));
    float row_max = tg_reduce_max(local_max, sgid, slid, tg_size, shared);
    float local_sum = 0.0f;
    for (uint i = tid; i < cols; i += tg_size)
        local_sum += exp(float(row_data[i]) - row_max);
    float total_sum = tg_reduce_sum(local_sum, sgid, slid, tg_size, shared);
    float inv_sum = 1.0f / total_sum;
    for (uint i = tid; i < cols; i += tg_size)
        row_data[i] = half(exp(float(row_data[i]) - row_max) * inv_sum);
}
This is a classic three-pass softmax: find max, compute exp-sum, normalize. It uses tg_reduce_max and tg_reduce_sum (utility functions that combine simd_max/simd_sum with threadgroup shared memory) for efficient cross-SIMD-group reductions.
The kernel operates in-place (reads and writes the same buffer), dispatches one threadgroup per row, and handles arbitrary row lengths via strided access.
Logit Soft-Capping (logit_cap)
Gemma models apply a soft cap to attention logits:
kernel void logit_softcap_f16(
    device half *logits, constant float &cap, constant uint &count,
    uint tid [[thread_position_in_grid]]
) {
    if (tid >= count) return;  // guard against over-dispatch past the logit count
    float x = float(logits[tid]);
    logits[tid] = half(cap * tanh(x / cap));
}
This bounds the logits to [-cap, +cap] using a smooth tanh function. Applied before the softmax in attention, it prevents extreme attention scores that could lead to numerical instability.3
Kernel Selection Strategy
The host selects the attention kernel based on the phase and sequence length:
DECODE (M=1):
if fast_decode_available && kv_seq_len < parallel_threshold:
→ flash_attention_decode_fast_f16 (32 threads, 0 barriers)
elif kv_seq_len >= parallel_threshold:
→ flash_attention_decode_parallel_f16 (1024 threads)
else:
→ flash_attention_decode_f16 (128-1024 threads)
PREFILL (M>1):
if seq_len >= 1024 && head_dim <= 128:
→ flash_attention_prefill_v2_f16 (BQ=32, register output)
elif seq_len >= v2_threshold:
→ flash_attention_prefill_f16 (NQ=8 or 16, simd MMA)
else:
→ flash_attention_decode_f16 (per-query TG, scalar dot)
Performance Comparison
Here is a rough comparison of the five kernels for head_dim=128:
| Kernel | Throughput at KV=512 | Throughput at KV=4096 | Throughput at KV=32K |
|---|---|---|---|
| Standard Decode | High | Medium | Poor |
| Fast Decode | Medium | High | Medium |
| Parallel Decode | Low (overhead) | Medium | High |
| Prefill V1 (NQ=8) | – | Good | Good |
| Prefill V2 (BQ=32) | – | Excellent | Excellent |
The crossover points between kernels are hardware-dependent and are determined by profiling. The general principle: use the simplest kernel that is not bottlenecked by the context length.
Memory Analysis: Why FlashAttention Matters
To understand why FlashAttention is necessary, let’s compare memory usage for a standard attention computation vs. FlashAttention on a concrete example:
Model: Llama 3.1 8B, head_dim=128, n_heads=32, seq_len=4096
Naive Attention
S = Q @ K^T: [4096, 4096] * 32 heads * 4 bytes (FP32) = 2 GB
P = softmax(S): Same as S = 2 GB
O = P @ V: [4096, 128] * 32 heads * 4 bytes = 64 MB
─────────────────────────────────────────────
Total: ~4 GB for attention alone
This is clearly infeasible for on-device inference where total system RAM might be 16-36 GB and most of it is used for model weights.
FlashAttention
Q tile: [NQ, 128] * 2 bytes = NQ * 256 bytes
K tile: [128, 32] * 2 bytes = 8 KB
V tile: [32, 128] * 2 bytes = 8 KB
S scores: [NQ, 32] * 4 bytes = NQ * 128 bytes
O accum: [NQ, 128] * 4 bytes = NQ * 512 bytes
Softmax: [NQ] * 8 bytes = negligible
─────────────────────────────────────────────
Total: ~17 KB per threadgroup (NQ=8)
~33 KB per threadgroup (NQ=32, V2)
FlashAttention reduces memory from O(seq_len^2) to O(NQ * head_dim) per threadgroup – a reduction of over 100,000x for a 4096-token sequence. The attention matrix S is never fully materialized; only a single KV block’s worth of scores exists at any time.4
Numerical Precision Considerations
Online softmax introduces a subtlety: the correction factor exp(old_max - new_max) is applied multiplicatively to the running accumulator. After many KV blocks, this means the output has been multiplied by a chain of correction factors:
O_final = O_0 * c_1 * c_2 * ... * c_T + ...
Each correction factor is <= 1.0 (since new_max >= old_max), so the chain product decreases monotonically. For very long sequences (tens of thousands of KV entries), the accumulated product can approach the FP32 denormalization threshold.
In practice, this is not a problem because:
- The correction factor is only significantly less than 1.0 when a new maximum exceeds the old by a large margin, which happens rarely after the first few KV blocks.
- Once the running maximum stabilizes (typically within the first 100-200 KV entries), all subsequent corrections are approximately 1.0.
- The final normalization by `1/running_sum` rescales the output, compensating for any cumulative shrinkage.
Akunu’s V2 kernel further improves precision by using exp2 instead of exp, which maps to the hardware transcendental unit and avoids the intermediate multiply-by-ln(2) that standard exp requires.
Attention Kernel Selection in the Dispatch Table
The choice of attention kernel is made once during dispatch table construction, not at runtime. The host examines the model configuration and hardware capabilities to select the best kernel:
For DECODE dispatch table:
1. Check if fast decode is suitable:
- head_dim <= 256 (fits in 32 lanes with <= 8 elements/lane)
- M=1 (single token decode)
2. Check if parallel decode is suitable:
- Expected long contexts (> 4096 KV entries)
3. Default: standard decode
For PREFILL (called at runtime based on seq_len):
1. V2 if seq_len >= 1024 and head_dim <= 128
2. V1 if seq_len >= v2_threshold
3. Standard decode fallback for very short sequences
Function constant specialization is used to bake head_dim, NQ, and masking mode into the kernel at pipeline compilation time. This enables the Metal compiler to generate optimized code for the specific configuration.
Function Constant Specialization Strategy
All attention kernels use Metal function constants for compile-time specialization. The specialization strategy differs by kernel:
| Kernel | Function Constants | Benefit |
|---|---|---|
| Standard decode | FC_HEAD_DIM, FC_USE_TREE_MASK, FC_NON_CAUSAL | Eliminates head_dim conditionals, removes unused masking code |
| Fast decode | FC_HEAD_DIM | Enables elems_per_thread as compile-time constant |
| Parallel decode | FC_HEAD_DIM | Same as fast decode |
| Prefill V1 | FC_HEAD_DIM, FC_NQ_ROWS, FC_NON_CAUSAL | Enables loop unrolling for head_dim, NQ pass calculation |
| Prefill V2 | FC_HEAD_DIM, FC_NQ_ROWS, FC_NON_CAUSAL | Same plus HD_TILES becomes compile-time |
The host creates separate PSOs for each unique combination of function constants. For a model with head_dim=128 and no tree masking, a typical PSO cache contains:
- `attn_decode_hd128_notree` (standard decode)
- `attn_decode_fast_hd128` (fast decode)
- `attn_decode_parallel_hd128` (parallel decode)
- `attn_prefill_hd128_nq8` (prefill V1 with NQ=8)
- `attn_pfv2_hd128_nq32` (prefill V2 with NQ=32)
Each PSO is compiled once during model initialization. The compilation cost (~20-50ms per PSO) is amortized over the entire inference session.
The Role of q_stride and kv_stride
The attention kernels support two memory layouts:
Head-major layout: Q[head, seq_len, head_dim] – used when Q comes directly from the GEMV projection output. Each head’s data is contiguous.
Row-major layout: Q[seq_len, n_heads * head_dim] – used in prefill when Q comes from a GEMM output where rows correspond to sequence positions.
The q_stride parameter tells the kernel which layout to use:
device const half *q_row = (q_str > 0)
? Q + q_pos * q_str + head * head_dim // row-major
: Q + (head * seq_len + q_pos) * head_dim; // head-major
Similarly, kv_stride controls the KV cache layout. When kv_stride > 0, the KV cache uses a fixed-stride layout (allocated for the maximum sequence length); when 0, it uses a compact layout.
Sliding Window Attention
Some models (Mistral, Gemma 3) use sliding window attention where each token only attends to the most recent W positions. Akunu handles this at the KV cache level rather than in the attention kernel: the KV cache uses a ring buffer, and the effective kv_seq_len passed to the attention kernel is clamped to the window size.
This design keeps the attention kernels simple and universal. The sliding window logic lives in the KV cache management code, which adjusts the visible range before dispatching attention.
Comparison with Other Implementations
| Feature | Akunu | llama.cpp Metal | MLX |
|---|---|---|---|
| Decode attention variants | 3 (standard, fast, parallel) | 1 | 1 |
| Prefill attention variants | 2 (V1 simd MMA, V2 register) | 1 | 1 |
| Online softmax | Yes (all variants) | Yes | Yes |
| exp2 optimization | Yes (V2) | No | Yes |
| Register output | Yes (V2) | No | Yes |
| Tree masking | Yes | No | No |
| Non-causal mode | Yes | No | Yes |
| GQA support | Yes | Yes | Yes |
| BQ=32 query blocks | Yes (V2) | No | No |
| Direct K device load | Yes (V1 full blocks) | No | No |
Akunu’s attention kernel family is arguably the most diverse of any on-device LLM inference engine, with 5 variants covering the full range of use cases from short-context decode to long-sequence prefill.
Deep Dive: The V1-to-V2 Evolution
The prefill V1 and V2 kernels represent two different approaches to the same problem. Understanding how V2 improves on V1 illuminates the tradeoffs in GPU kernel design.
V1: Threadgroup Memory-Centric
V1 stores all intermediate results in threadgroup memory:
- Output accumulator `tg_O`: NQ * HD * 4 bytes = 4096 bytes (NQ=8, HD=128)
- Score matrix `tg_S`: NQ * 32 * 4 bytes = 1024 bytes
- Total per-threadgroup: ~12-17 KB depending on head_dim
The advantage: any thread can read any accumulator element, enabling flexible work distribution across SIMD groups. The disadvantage: every MMA output must be written to TG memory and every accumulator update requires a TG memory read-modify-write.
V2: Register-Centric
V2 keeps the output in SIMD group registers:
- Output accumulator: `float2 acc_o[HD_TILES]` per thread = (HD / 8) tiles * 8 bytes = HD bytes per thread
- For HD=128: 128 bytes per thread in registers
- Total per-SG: 32 threads * 128 bytes = 4 KB in registers (no TG memory)
The advantage: register access is free (0 latency, infinite bandwidth). The disadvantage: each SIMD group can only access its own registers, requiring careful work assignment to avoid cross-SG communication.
V2 Score Handling via thread_elements()
The key insight enabling V2’s register-centric approach is that the MMA instruction’s thread_elements() accessor provides direct access to the 2 elements each thread owns in the 8x8 matrix result. This means:
- After computing S = Q @ K^T via MMA, each thread can directly read its 2 score elements without going through TG memory.
- Row-level max and sum can be computed using `simd_shuffle_xor` (because all lanes sharing a row can communicate without barriers).
- The rescaling `acc_o *= correction` is a local register operation.
This eliminates the TG memory round-trip for scores, the barrier-heavy softmax update, and the TG memory round-trip for output accumulation – three of the four major bottlenecks in V1.
Performance Impact
For a 4096-token prefill on M4 Pro with head_dim=128:
| Aspect | V1 (NQ=8) | V2 (BQ=32) |
|---|---|---|
| Threadgroups | 512 * n_heads | 128 * n_heads |
| TG memory per TG | ~17 KB | ~9 KB |
| KV load overhead | Per NQ=8 queries | Per BQ=32 queries (4x better amortization) |
| Barriers per KV block | ~6 | ~4 |
| Output write | TG memory -> device | Register -> device |
| Relative throughput | 1.0x | ~1.3-1.5x |
The 30-50% improvement comes from three sources: (1) 4x better KV data amortization, (2) fewer barriers per block, and (3) elimination of TG memory traffic for the output accumulator.
Attention and Memory Bandwidth
Attention is unique among transformer operations because it reads from the KV cache, which grows linearly with context length. For a 7B model with 8 KV heads, head_dim=128, and context length L:
KV cache reads per token per layer (8 KV heads, FP16):
K: 8 * L * 128 * 2 bytes = 2048L bytes
V: 8 * L * 128 * 2 bytes = 2048L bytes
Total: 4096L bytes per layer
32 layers: 131072L bytes = 128L KB
For L=4096: 128 * 4096 KB = 512 MB per token
For L=32K: 128 * 32768 KB = 4 GB per token
At L=32K, the KV cache reads alone consume 4 GB per token. At 200 GB/s memory bandwidth, that is 20ms – by far the dominant share of the per-token time. This is why attention becomes the bottleneck at long contexts, surpassing even the GEMV weight reads.
The parallel decode kernel (32 SIMD groups) addresses this by parallelizing the KV scan: with 32x more concurrent memory requests in flight, the scan can actually saturate the available bandwidth. Each SIMD group reads independent KV positions, keeping the memory subsystem’s request queues full.
Attention Dispatch Counts
For a 7B model with 32 layers and 32 attention heads, the total attention dispatch count during decode is:
Greedy decode (1 token):
32 layers × 1 attention dispatch = 32 attention dispatches
Chain decode (64 tokens):
32 layers × 64 tokens × 1 attention dispatch = 2048 attention dispatches
Each dispatch handles all 32 heads (the grid Y dimension covers heads). The attention kernel is typically the 3rd or 4th most expensive dispatch per token (after the FFN GEMVs and the QKV GEMVs), but at long contexts it becomes the most expensive.
For the prefill of 2048 tokens:
32 layers × 1 attention dispatch (covers all 2048 queries) = 32 attention dispatches
Prefill uses far fewer dispatches because each attention kernel processes all query rows simultaneously. This is another reason prefill is more efficient per-token than decode.
The Softmax Temperature Connection
The standalone softmax_f16 kernel and the logit_softcap_f16 kernel are both attention-adjacent operations:
Softmax is used by non-FlashAttention code paths (e.g., when debugging or when the model requires explicit softmax for cross-attention). It processes one row per threadgroup using the standard three-pass algorithm (find max, compute exp sum, normalize). The tg_reduce_max and tg_reduce_sum helper functions use the same SIMD-first + threadgroup-memory reduction pattern seen in the sampling kernels.
Logit soft-capping (cap * tanh(x / cap)) is specific to Gemma models and is applied before the softmax within attention. It bounds the attention logits to prevent extreme values that could destabilize the softmax computation. The tanh function saturates at +/-1, so the effective range is [-cap, +cap]. Typical values are cap = 30 or cap = 50.
Both kernels are simple 1D dispatches with minimal state, taking <5us per invocation. Their impact on overall performance is negligible, but their presence enables Akunu to support architectures that require these operations.
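For a concrete feel, here is a minimal Python sketch of the soft-cap (the helper name is ours): near zero it is almost the identity, and for large inputs it saturates at the cap.

```python
import math

def logit_softcap(x, cap=50.0):
    # cap * tanh(x / cap): ~identity for |x| << cap, saturates at +/- cap
    return cap * math.tanh(x / cap)
```

Small logits pass through nearly unchanged, while an arbitrarily large logit is squashed to the cap – bounding the inputs the softmax ever sees.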
Kernel Selection Decision Tree
To summarize the complete selection logic, here is the decision tree the host uses:
Is this decode (M=1)?
├── YES: Is the dispatch table using fast decode?
│ ├── YES: flash_attention_decode_fast_f16
│ │ (1 SG, 32 threads, 0 barriers, head_dim/32 elems/thread)
│ └── NO: Is parallel decode enabled?
│ ├── YES: flash_attention_decode_parallel_f16
│ │ (32 SG, 1024 threads, cross-SG reduction)
│ └── NO: flash_attention_decode_f16
│ (N SGs, up to 1024 threads, 2 barriers/KV)
└── NO (prefill, M>1): What is seq_len?
├── seq_len >= 1024 AND head_dim <= 128:
│ flash_attention_prefill_v2_f16
│ (BQ=32, register output, exp2, 128 threads)
├── seq_len >= v2_threshold:
│ flash_attention_prefill_f16
│ (NQ=8 or 16, TG output, simd MMA, 128 threads)
└── seq_len < v2_threshold:
flash_attention_decode_f16 (per-query threadgroup)
The v2_threshold is model-dependent: for head_dim=64 it is 32 (NQ=16, so need at least 32 queries), for head_dim=128 it is 16 (NQ=8). Below these thresholds, there are not enough queries to fill the prefill kernel’s tile efficiently.
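The same tree can be expressed as a small host-side selection function. This Python sketch is illustrative only – the kernel names come from the tree above, but the function and flag names are assumptions:

```python
def select_attention_kernel(m, seq_len, head_dim,
                            use_fast_decode=True,
                            use_parallel_decode=False):
    """Illustrative mirror of the decision tree (flag names are assumptions)."""
    if m == 1:  # decode
        if use_fast_decode:
            return "flash_attention_decode_fast_f16"
        if use_parallel_decode:
            return "flash_attention_decode_parallel_f16"
        return "flash_attention_decode_f16"
    # prefill: NQ=16 for head_dim=64 (threshold 32), NQ=8 for head_dim=128 (threshold 16)
    v2_threshold = 32 if head_dim == 64 else 16
    if seq_len >= 1024 and head_dim <= 128:
        return "flash_attention_prefill_v2_f16"
    if seq_len >= v2_threshold:
        return "flash_attention_prefill_f16"
    return "flash_attention_decode_f16"  # per-query threadgroup fallback
```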
Summary
Akunu’s attention kernel family demonstrates that there is no single “best” attention algorithm – the optimal approach depends on the sequence length, batch size, and hardware capabilities:
- Standard decode: Simple, uses threadgroup barriers, best for short contexts or fallback.
- Fast decode: Single SIMD group, barrier-free, best for medium contexts (128-1024 KV).
- Parallel decode: 32 SIMD groups with KV parallelism, best for very long contexts (10K+).
- Prefill V1: SIMD MMA with direct K loading, best for medium prefill sequences.
- Prefill V2: Register output, exp2, BQ=32, best for long prefill sequences (1024+).
The online softmax algorithm is the thread that connects all five kernels – the same mathematical principle of running max/sum correction, implemented differently based on the parallelism strategy.
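To see the shared principle in isolation, here is a minimal Python sketch of online softmax applied to a weighted sum (the scalar analogue of the attention output accumulator): a single streaming pass that rescales the running sum and accumulator whenever the running max increases.

```python
import math

def online_softmax_weighted_sum(scores, values):
    # One streaming pass: running max m, running exp-sum s, running output acc.
    # When the max increases, all previous contributions are rescaled by
    # exp(m_old - m_new) -- the same correction every kernel above applies.
    m, s, acc = float("-inf"), 0.0, 0.0
    for x, v in zip(scores, values):
        m_new = max(m, x)
        correction = math.exp(m - m_new)  # exp(-inf) == 0.0 on the first step
        e = math.exp(x - m_new)
        s = s * correction + e
        acc = acc * correction + e * v
        m = m_new
    return acc / s
```

The result matches a conventional two-pass softmax exactly (up to floating-point rounding), while never materializing the full score vector – the property FlashAttention exploits at tile granularity.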
- Dao, T., Fu, D.Y., Ermon, S., Rudra, A., and Re, C. “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” NeurIPS 2022. The key insight is that by tiling attention and maintaining online softmax statistics, the algorithm achieves O(N) memory usage instead of O(N^2) while performing the same mathematical operation. See https://arxiv.org/abs/2205.14135.
- Apple. “Metal Best Practices Guide.” Section “Use Fast Math Functions.” The fast::exp2 function on Apple GPU hardware uses the native transcendental function unit, which computes exp2 in a single pipeline cycle. Standard exp requires an additional multiply by ln(2) internally. See https://developer.apple.com/library/archive/documentation/3DDrawing/Conceptual/MTLBestPracticesGuide/index.html.
- Gemma Team, Google DeepMind. “Gemma: Open Models Based on Gemini Research and Technology.” arXiv:2403.08295, 2024. The logit soft-capping technique with cap * tanh(x/cap) prevents attention score explosion while maintaining gradient flow during training. See https://arxiv.org/abs/2403.08295.
- Rabe, M.N. and Staats, C. “Self-attention Does Not Need O(n^2) Memory.” arXiv:2112.05682, 2021. This paper independently proved the same memory-efficient attention idea as FlashAttention, showing that O(1) memory is achievable for the attention computation.
Normalization Kernels
Every transformer layer starts with a normalization step. Before attention, before the FFN, before practically anything interesting happens, you need to normalize your hidden states. Without it, activations drift, gradients explode, and your model produces gibberish. The math is simple. The GPU implementation? That is where things get interesting.
In this chapter we will walk through akunu’s Metal normalization kernels: RMSNorm,
LayerNorm, fused Residual+RMSNorm, per-head RMSNorm, and the Gemma variants. We
will trace every threadgroup reduction, every rsqrt call, and every clamping
trick that keeps FP16 from blowing up.
The Normalization Zoo
Modern LLMs use two main normalization flavors:
LayerNorm (GPT-2, Whisper): y[i] = ((x[i] - mean) / sqrt(var + eps)) * weight[i] + bias[i]
RMSNorm (LLaMA, Qwen, etc): y[i] = (x[i] / sqrt(mean(x^2) + eps)) * weight[i]
RMSNorm is cheaper: it skips the mean subtraction entirely. You only need the root mean square (RMS) of the input. No centering, no bias. This is why nearly every modern LLM uses it – one fewer reduction pass over the data, and the empirical quality is essentially identical.
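A scalar Python reference of both formulas (helper names are ours) makes the difference concrete – LayerNorm needs the mean and the variance, RMSNorm only the mean square:

```python
import math

def rmsnorm(x, weight, eps=1e-5):
    # One reduction: mean of squares, then scale by the reciprocal RMS
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]

def layernorm(x, weight, bias, eps=1e-5):
    # Two reductions: mean, then variance about that mean
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv_std * w + b for v, w, b in zip(x, weight, bias)]
```

For zero-mean input and zero bias the two produce identical output, which is part of why dropping the centering loses so little in practice.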
Let us see how akunu implements both.
RMSNorm: One Threadgroup per Row
The dispatch model is beautifully simple: one threadgroup handles one row of the input tensor. If you have a batch of 8 tokens with dimension 4096, you launch 8 threadgroups. Each threadgroup has enough threads to cover the dimension with striding.
Dispatch Grid:
threadgroups = (num_rows, 1, 1)
threads_per_threadgroup = (tg_size, 1, 1) // e.g., 256 or 512
Row 0: TG 0 --> threads 0..tg_size-1 stride over dim elements
Row 1: TG 1 --> threads 0..tg_size-1 stride over dim elements
...
Row N: TG N --> threads 0..tg_size-1 stride over dim elements
Here is the actual kernel from rmsnorm.metal:
kernel void rmsnorm_f16(
device const half *input [[buffer(0)]],
device const half *weight [[buffer(1)]],
device half *output [[buffer(2)]],
constant RMSNormParams &params [[buffer(3)]],
uint3 tgid_v [[threadgroup_position_in_grid]],
uint3 tid_v [[thread_position_in_threadgroup]],
uint sgid [[simdgroup_index_in_threadgroup]],
uint slid [[thread_index_in_simdgroup]],
uint3 tpg [[threads_per_threadgroup]]
) {
const uint dim = params.dim;
const float eps = params.eps;
const uint row = tgid_v.x;
const uint tid = tid_v.x;
const uint tg_size = tpg.x;
device const half *row_in = input + row * dim;
device half *row_out = output + row * dim;
threadgroup float shared[32];
float local_sum_sq = 0.0f;
for (uint i = tid; i < dim; i += tg_size) {
float v = float(row_in[i]);
local_sum_sq += v * v;
}
float total_sum_sq = tg_reduce_sum(local_sum_sq, sgid, slid, tg_size, shared);
float rms = rsqrt(total_sum_sq / float(dim) + eps);
for (uint i = tid; i < dim; i += tg_size) {
row_out[i] = half(float(row_in[i]) * rms) * weight[i];
}
}
Let us break this down step by step.
Step 1: Accumulate Sum of Squares
Each thread strides over the row, accumulating x[i]^2 into a local float
register. For a 4096-dimensional row with 256 threads, each thread processes
16 elements:
Thread 0: x[0]^2 + x[256]^2 + x[512]^2 + ... + x[3840]^2
Thread 1: x[1]^2 + x[257]^2 + x[513]^2 + ... + x[3841]^2
Thread 2: x[2]^2 + x[258]^2 + x[514]^2 + ... + x[3842]^2
...
Thread 255: x[255]^2 + x[511]^2 + x[767]^2 + ... + x[4095]^2
Notice the promotion to float: float v = float(row_in[i]). The input is FP16,
but all accumulation happens in FP32. This is critical. Summing squares in FP16
overflows almost instantly for typical hidden state magnitudes.
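To make the overflow concrete, here is a small Python sketch using the struct module’s binary16 format ('e') to test representability. A 4096-dim row whose elements average magnitude 4.0 already pushes the sum of squares past the largest finite FP16 value:

```python
import struct

F16_MAX = 65504.0  # largest finite binary16 value

def fits_in_f16(x):
    # struct's 'e' format packs IEEE 754 binary16; finite values that would
    # round to infinity raise an error
    try:
        struct.pack('<e', x)
        return True
    except (OverflowError, struct.error):
        return False

# Elements of average magnitude 4.0 over a 4096-dim row:
dim = 4096
sum_sq = sum(4.0 * 4.0 for _ in range(dim))  # 65536.0 -- past F16_MAX
```

Individual elements fit comfortably; the accumulated sum of squares does not. Hence the FP32 accumulator.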
Step 2: Two-Stage Threadgroup Reduction
This is where the magic happens. We need to sum 256 (or however many) partial
sums into a single total. The kernel calls tg_reduce_sum, which lives in
KernelCommon.h:
inline float tg_reduce_sum(float val, uint sgid, uint slid,
uint tg_size, threadgroup float *shared) {
float simd_val = simd_sum(val);
uint n_sg = (tg_size + SIMD_WIDTH - 1) / SIMD_WIDTH;
if (slid == 0)
shared[sgid] = simd_val;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sgid == 0) {
float v = (slid < n_sg) ? shared[slid] : 0.0f;
float total = simd_sum(v);
if (slid == 0)
shared[0] = total;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
return shared[0];
}
This is a classic two-stage SIMD reduction. Let me draw it out for a threadgroup of 256 threads (8 SIMD groups of 32 threads each):
Stage 1: SIMD-level reduction (hardware shuffle, no barriers needed)
+------------------------------------------------------------------+
| SIMD Group 0: t0..t31 --simd_sum--> partial_0 (in t0) |
| SIMD Group 1: t32..t63 --simd_sum--> partial_1 (in t32) |
| SIMD Group 2: t64..t95 --simd_sum--> partial_2 (in t64) |
| ... |
| SIMD Group 7: t224..t255 --simd_sum--> partial_7 (in t224) |
+------------------------------------------------------------------+
|
lane 0 of each SG writes to shared[]
|
[threadgroup_barrier]
|
v
Stage 2: Cross-SIMD reduction (only SIMD Group 0 participates)
+------------------------------------------------------------------+
| SG 0, lane 0: reads shared[0] = partial_0 |
| SG 0, lane 1: reads shared[1] = partial_1 |
| SG 0, lane 2: reads shared[2] = partial_2 |
| ... |
| SG 0, lane 7: reads shared[7] = partial_7 |
| SG 0, lanes 8-31: use 0.0f (padding) |
| |
| --simd_sum--> TOTAL (written to shared[0] by lane 0) |
+------------------------------------------------------------------+
|
[threadgroup_barrier]
|
v
All threads read shared[0] = total_sum_sq
Why only 32 floats of shared memory? Because Apple GPUs have 32 threads per SIMD
group, and you can have at most 32 SIMD groups in a threadgroup (32 * 32 = 1024
threads max). So shared[32] is always enough for the cross-SIMD exchange.
The beauty here is that simd_sum compiles to hardware shuffle instructions. No
shared memory access, no barriers. It is a register-to-register operation within
the SIMD group. The only shared memory access is the handoff between stage 1 and
stage 2, requiring just two barriers for the entire reduction.
Step 3: Compute the Inverse RMS
float rms = rsqrt(total_sum_sq / float(dim) + eps);
This single line does three things:
- Divides the sum of squares by the dimension to get the mean
- Adds epsilon (typically 1e-5 or 1e-6) for numerical stability
- Calls rsqrt – the reciprocal square root
The rsqrt hardware instruction on Apple Silicon computes 1/sqrt(x) in a
single cycle. Compared to doing 1.0f / sqrt(x), the fused rsqrt is both
faster and more numerically accurate. This is why the kernel computes
x * rsqrt(mean_sq + eps) rather than x / sqrt(mean_sq + eps) – multiplication
is cheaper than division, and rsqrt gives us the reciprocal directly.
Step 4: Normalize and Scale
for (uint i = tid; i < dim; i += tg_size) {
row_out[i] = half(float(row_in[i]) * rms) * weight[i];
}
Each thread walks the row again, multiplying each element by the inverse RMS and
then by the learned weight. The float(row_in[i]) * rms computation happens in
FP32, then is cast down to FP16 for the multiply with weight[i] (which is
already FP16). This two-pass approach (one pass for reduction, one for output)
is unavoidable – you cannot write output until you know the RMS.
LayerNorm: Two Reductions Instead of One
LayerNorm is the older, more expensive cousin. Used in GPT-2, Whisper, and other pre-LLaMA architectures, it subtracts the mean before dividing by standard deviation:
y[i] = ((x[i] - mean) / sqrt(var + eps)) * weight[i] + bias[i]
The extra mean subtraction means an extra reduction pass. Here is the kernel:
kernel void layernorm_f16(
device const half *input [[buffer(0)]],
device const half *weight [[buffer(1)]],
device const half *bias [[buffer(2)]],
device half *output [[buffer(3)]],
constant LayerNormParams &params [[buffer(4)]],
...
) {
// Pass 1: compute mean
float local_sum = 0.0f;
for (uint i = tid; i < dim; i += tg_size)
local_sum += float(row_in[i]);
float mean = tg_reduce_sum(local_sum, sgid, slid, tg_size, shared)
/ float(dim);
// Pass 2: compute variance
float local_var = 0.0f;
for (uint i = tid; i < dim; i += tg_size) {
float diff = float(row_in[i]) - mean;
local_var += diff * diff;
}
float variance = tg_reduce_sum(local_var, sgid, slid, tg_size, shared)
/ float(dim);
float inv_std = rsqrt(variance + eps);
// Pass 3: normalize, scale, shift
for (uint i = tid; i < dim; i += tg_size) {
float val = (float(row_in[i]) - mean) * inv_std;
row_out[i] = half(val) * weight[i] + bias[i];
}
}
The data flow looks like this:
Pass 1: Compute mean
row_in --> [sum elements] --> [tg_reduce_sum] --> mean = sum / dim
Pass 2: Compute variance
row_in --> [(x - mean)^2] --> [tg_reduce_sum] --> var = sum / dim
inv_std = rsqrt(var + eps)
Pass 3: Output
row_in --> [(x - mean) * inv_std * weight + bias] --> row_out
That is three passes over the input versus two for RMSNorm. For a 4096-dim model,
this means 3 * 4096 * 2 = 24,576 bytes of input reads per row instead of
2 * 4096 * 2 = 16,384 bytes. The extra pass (and the extra tg_reduce_sum)
is why RMSNorm replaced LayerNorm in modern architectures.
Also note the + bias[i] at the end. LayerNorm has both learnable scale (weight)
and shift (bias) parameters. RMSNorm drops the bias entirely.
Fused Residual + RMSNorm: The Performance Killer Feature
Here is where akunu gets clever. In a transformer, the pattern looks like this:
hidden = attention_output + residual // residual add
normalized = RMSNorm(hidden) * weight // normalize for FFN
The naive approach dispatches two kernels: one for the addition, one for the normalization. But look at what that means in terms of memory:
Naive (2 kernels):
Kernel 1 (residual_add): read a[], read b[] --> write hidden[]
Kernel 2 (rmsnorm): read hidden[] --> write norm[]
Total memory traffic: 4 * dim * sizeof(half) reads + 2 * dim * sizeof(half) writes
Fused (1 kernel):
read a[], read b[] --> compute hidden, accumulate sum_sq, write hidden[]
read hidden[] --> write norm[]
Total memory traffic: saves one full read of hidden[] (2 * dim bytes)
But the real win is not just bandwidth. It is kernel launch overhead. Each Metal compute dispatch has a non-trivial fixed cost – command buffer encoding, GPU scheduling, pipeline state switches. Eliminating a dispatch is free performance.
Here is the fused kernel:
kernel void residual_rmsnorm_f16(
device const half *a [[buffer(0)]],
device const half *b [[buffer(1)]],
device const half *weight [[buffer(2)]],
device half *res_out [[buffer(3)]],
device half *norm_out [[buffer(4)]],
constant RMSNormParams &params [[buffer(5)]],
...
) {
constexpr float F16_MAX = 65504.0f;
float local_sum_sq = 0.0f;
for (uint i = tid; i < dim; i += tg_size) {
float val = clamp(float(row_a[i]) + float(row_b[i]), -F16_MAX, F16_MAX);
row_res[i] = half(val);
local_sum_sq += val * val;
}
float total_sum_sq = tg_reduce_sum(local_sum_sq, sgid, slid, tg_size, shared);
float rms = rsqrt(total_sum_sq / float(dim) + eps);
for (uint i = tid; i < dim; i += tg_size) {
row_norm[i] = half(float(row_res[i]) * rms) * weight[i];
}
}
The F16_MAX Clamp
Notice this critical line:
float val = clamp(float(row_a[i]) + float(row_b[i]), -F16_MAX, F16_MAX);
FP16 can represent values up to 65504.0. If the residual sum exceeds this, the
cast back to half produces infinity, which then poisons everything downstream.
The clamp prevents this:
Without clamp:
a[i] = 40000.0h, b[i] = 30000.0h
sum = 70000.0f (fine in float)
half(70000.0f) = inf (OVERFLOW -- poisons entire row)
With clamp:
sum = clamp(70000.0f, -65504.0f, 65504.0f) = 65504.0f
half(65504.0f) = 65504.0h (saturated but finite)
This is a practical concern. During inference with long contexts, residual
accumulation can push values near the FP16 boundary. The clamp costs essentially
nothing (it compiles to a min/max pair) but prevents catastrophic NaN
propagation.
Dual Output Buffers
The fused kernel writes two outputs: res_out (the residual sum, needed for the
next residual connection) and norm_out (the normalized result, fed to the FFN or
attention). This is why there are five buffer bindings instead of three.
Per-Head RMSNorm: Qwen3’s QK Normalization
Some models (notably Qwen3) apply RMSNorm independently to each attention head’s
Q and K projections. Instead of normalizing a full [seq_len, model_dim] row, you
normalize individual [head_dim] slices.
The input layout is [seq_len, n_heads, head_dim], and each threadgroup handles
one (seq_position, head) pair:
Dispatch Grid:
threadgroups = (n_heads, seq_len, 1)
threads_per_threadgroup = (min(head_dim, 1024), 1, 1)
TG(0,0): head 0, pos 0 --> normalize row[0 * n_heads + 0]
TG(1,0): head 1, pos 0 --> normalize row[0 * n_heads + 1]
...
TG(0,1): head 0, pos 1 --> normalize row[1 * n_heads + 0]
...
The reduction is slightly different here. Instead of using the tg_reduce_sum
helper, the kernel manually implements the two-stage SIMD reduction:
float sum_sq = 0.0f;
for (uint d = tid; d < head_dim; d += tg_size) {
float v = float(row[d]);
sum_sq += v * v;
}
// Stage 1: SIMD-level sum
sum_sq = simd_sum(sum_sq);
if (lane == 0) shared[warp] = sum_sq;
threadgroup_barrier(mem_flags::mem_threadgroup);
// Stage 2: Cross-SIMD sum (only warp 0)
if (warp == 0) {
float v = (lane < (tg_size + 31) / 32) ? shared[lane] : 0.0f;
v = simd_sum(v);
shared[0] = v;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float total_sq = shared[0];
This is the same algorithm as tg_reduce_sum, just inlined. For head_dim = 128
(common in Qwen3), you only need 128 threads and 4 SIMD groups. The reduction is
tiny.
The normalization is in-place – the kernel modifies x directly:
for (uint d = tid; d < head_dim; d += tg_size) {
row[d] = half(float(row[d]) * rms * float(weight[d]));
}
This is safe because each threadgroup works on a different (head, position) pair. No data races.
The Gemma Variants: weight’ = 1 + weight
Google’s Gemma model has an unusual normalization quirk. The norm weights are
initialized to zero and the effective scale is (1 + weight) rather than just
weight. This means a freshly initialized model has identity-like normalization
(scale = 1 everywhere), which improves training stability.
Akunu has dedicated Gemma variants for both RMSNorm and the fused residual version. The only difference is in the final output line:
// Standard RMSNorm:
row_out[i] = half(float(row_in[i]) * rms) * weight[i];
// Gemma RMSNorm:
row_out[i] = half(float(row_in[i]) * rms * (1.0f + float(weight[i])));
Notice the Gemma variant does the entire computation in float32 before casting to
FP16. The comment in the source says it all: “Compute entirely in float to avoid
F16 overflow in normalized * (1+weight).” If the weight is, say, 2.0h, then
(1 + weight) = 3.0, and multiplying a normalized value near 1.0 by 3.0 might
push things dangerously close to the FP16 range limit. Doing it all in float32
avoids this.
There is also a Gemma variant of the fused residual+RMSNorm kernel
(residual_rmsnorm_gemma_f16) that combines both the F16_MAX clamping and the
(1 + weight) scaling. Same dispatch, same reduction, just a different final
multiply.
Memory Access Patterns
Let us think about how these kernels interact with the GPU memory hierarchy:
Kernel Memory Access Pattern
+-------------------------------------------------------------------+
| |
| Global Memory (Device) |
| +-------------------------------------------------------------+ |
| | input[row * dim + 0..dim-1] <-- read once (pass 1) | |
| | input[row * dim + 0..dim-1] <-- read again (pass 2) | |
| | weight[0..dim-1] <-- read once (pass 2) | |
| | output[row * dim + 0..dim-1] <-- write once (pass 2) | |
| +-------------------------------------------------------------+ |
| |
| Threadgroup Memory (32 KB max) |
| +-------------------------------------------------------------+ |
| | shared[32] <-- 128 bytes for reduction intermediates | |
| +-------------------------------------------------------------+ |
| |
| Registers (per thread) |
| +-------------------------------------------------------------+ |
| | local_sum_sq (float) -- accumulator | |
| | rms (float) -- broadcast to all threads via shared | |
| +-------------------------------------------------------------+ |
| |
+-------------------------------------------------------------------+
The strided access pattern (for i = tid; i < dim; i += tg_size) ensures
coalesced memory access. Thread 0 reads element 0, thread 1 reads element 1, and
so on. Since Apple GPUs fetch 128 bytes per memory transaction, and FP16 elements
are 2 bytes each, a single fetch satisfies 64 consecutive threads. With a SIMD
width of 32, that means two SIMD groups’ worth of data per fetch. Excellent
utilization.
The weight vector is also read with perfect coalescing. And since weights are the same for every row, they likely stay in the GPU L2 cache after the first row is processed. For a 4096-dim model, the weight vector is only 8 KB – easily cached.
Performance Characteristics
Let us estimate the arithmetic intensity (operations per byte transferred) for RMSNorm on a 4096-dim row with 256 threads:
Pass 1 (sum of squares):
Reads: dim * 2 bytes = 8,192 bytes
FLOPs: dim * 2 (cast + multiply) + dim (add) = 12,288 FLOPs
+ reduction: ~500 FLOPs (negligible)
Pass 2 (normalize):
Reads: dim * 2 (input) + dim * 2 (weight) = 16,384 bytes
Writes: dim * 2 (output) = 8,192 bytes
FLOPs: dim * 3 (cast, multiply by rms, multiply by weight) = 12,288 FLOPs
Total:
Memory: 32,768 bytes
FLOPs: ~24,576
Arithmetic intensity: 0.75 FLOPs/byte
This is firmly in the memory-bound regime. Apple M-series GPUs can do hundreds of GFLOPs but have “only” ~200-400 GB/s of memory bandwidth. RMSNorm barely scratches the ALU – the bottleneck is waiting for memory loads.
This is actually fine. Normalization is a tiny fraction of total inference time. The GEMM (matrix multiply) kernels dominate runtime. Normalization exists in the cracks between GEMMs, and the main optimization goal is to avoid unnecessary memory round-trips – which is exactly what the fused residual+RMSNorm achieves.
The Complete Reduction Tree
To really nail down the reduction, let us trace an example with 8 SIMD groups
(256 threads) reducing the values [10, 20, 30, 40, 50, 60, 70, 80] (one partial
sum per SIMD group):
SIMD Reduction Tree
===================
Level 0: Each SIMD group does simd_sum() on 32 per-thread values
(hardware shuffle -- zero latency visible to software)
SG0: threads 0-31 -> simd_sum -> 10.0 (lane 0 writes shared[0])
SG1: threads 32-63 -> simd_sum -> 20.0 (lane 0 writes shared[1])
SG2: threads 64-95 -> simd_sum -> 30.0 (lane 0 writes shared[2])
SG3: threads 96-127 -> simd_sum -> 40.0 (lane 0 writes shared[3])
SG4: threads 128-159-> simd_sum -> 50.0 (lane 0 writes shared[4])
SG5: threads 160-191-> simd_sum -> 60.0 (lane 0 writes shared[5])
SG6: threads 192-223-> simd_sum -> 70.0 (lane 0 writes shared[6])
SG7: threads 224-255-> simd_sum -> 80.0 (lane 0 writes shared[7])
--- BARRIER ---
Level 1: SIMD Group 0 reads shared[], does one more simd_sum()
SG0, lane 0: v = shared[0] = 10.0
SG0, lane 1: v = shared[1] = 20.0
SG0, lane 2: v = shared[2] = 30.0
SG0, lane 3: v = shared[3] = 40.0
SG0, lane 4: v = shared[4] = 50.0
SG0, lane 5: v = shared[5] = 60.0
SG0, lane 6: v = shared[6] = 70.0
SG0, lane 7: v = shared[7] = 80.0
SG0, lanes 8-31: v = 0.0 (padding)
simd_sum -> 360.0 (lane 0 writes shared[0])
--- BARRIER ---
All 256 threads read: shared[0] = 360.0
Two barriers, two simd_sum calls. That is all it takes to reduce 256 values.
Compare this to a naive shared-memory reduction tree, which would need
log2(256) = 8 barrier-synchronize-and-halve steps. The SIMD-first approach
is dramatically more efficient on Apple hardware.
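The two-stage scheme is easy to model on the CPU. This Python sketch (our own simulation, not akunu code) collapses each 32-lane group first, then sums the partials, reproducing the worked example above:

```python
SIMD_WIDTH = 32

def tg_reduce_sum_sim(per_thread_vals):
    # Stage 1: each 32-lane SIMD group collapses to one partial (simd_sum)
    partials = [sum(per_thread_vals[i:i + SIMD_WIDTH])
                for i in range(0, len(per_thread_vals), SIMD_WIDTH)]
    # -- barrier: lane 0 of each group writes shared[sgid] --
    # Stage 2: SIMD group 0 sums the partials (lanes past n_sg contribute 0.0)
    return sum(partials)

# 8 SIMD groups of 32 threads whose per-group partials are 10, 20, ..., 80
vals = [10.0 * (g + 1) / SIMD_WIDTH for g in range(8) for _ in range(SIMD_WIDTH)]
```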
Summary
Akunu’s normalization kernels follow a clear pattern:
| Kernel | Model | Passes | Outputs | Special |
|---|---|---|---|---|
| rmsnorm_f16 | LLaMA, Mistral | 2 | 1 (norm) | Basic RMSNorm |
| layernorm_f16 | GPT-2, Whisper | 3 | 1 (norm) | Mean + variance |
| residual_rmsnorm_f16 | LLaMA, Mistral | 2 | 2 (res + norm) | F16_MAX clamp |
| rmsnorm_gemma_f16 | Gemma | 2 | 1 (norm) | (1+w) scaling |
| residual_rmsnorm_gemma_f16 | Gemma | 2 | 2 (res + norm) | Both tricks |
| head_rmsnorm_f16 | Qwen3 | 2 | 1 (in-place) | Per-head QK norm |
| head_rmsnorm_gemma_f16 | Qwen3+Gemma | 2 | 1 (in-place) | Per-head + (1+w) |
The key takeaways:
- One threadgroup per row – simple, parallel, no cross-row communication.
- Two-stage SIMD reduction – simd_sum within each SIMD group, then one more simd_sum across SIMD groups. Only two barriers.
- FP32 accumulation – all arithmetic in float32, cast to FP16 only for memory writes.
- Fused kernels – the residual+RMSNorm fusion saves an entire kernel dispatch and one pass over the data.
- rsqrt hardware – single-cycle reciprocal square root, multiplication instead of division.
- F16_MAX clamping – a cheap safety net that prevents overflow in the residual sum from poisoning downstream computation.
Next up, we will see how akunu applies positional information to these normalized vectors using Rotary Position Embeddings.
RoPE Kernels
Transformers have no built-in sense of order. Without positional information, the sentence “the cat sat on the mat” is indistinguishable from “mat the on sat cat the.” Rotary Position Embeddings (RoPE) solve this by rotating pairs of elements in the Q and K vectors by position-dependent angles. It is mathematically elegant, it extrapolates to longer sequences than training saw, and it composes beautifully with the attention dot product.
In this chapter we will walk through akunu’s seven RoPE-related Metal kernels: the basic standard and NeoX variants, the fused RoPE+KV-cache-write kernels for single-token decode, and the batch (prefill) versions. We will finish with the ultimate fusion: per-head RMSNorm + NeoX RoPE + KV write in a single dispatch.
The Math Behind RoPE
RoPE treats each pair of elements in a head dimension as a 2D vector and rotates
it by an angle proportional to the token’s position. For a pair at dimension index
d, the rotation angle is:
theta_d = position / (base_freq ^ (2d / head_dim))
Where base_freq (typically 10000.0) controls the frequency spectrum. Low-index
pairs rotate quickly (high frequency), high-index pairs rotate slowly (low
frequency). This gives each position a unique “fingerprint” of rotations.
The rotation itself is just a 2D rotation matrix:
[x0'] [cos(theta) -sin(theta)] [x0]
[x1'] = [sin(theta) cos(theta)] [x1]
x0' = x0 * cos(theta) - x1 * sin(theta)
x1' = x0 * sin(theta) + x1 * cos(theta)
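A scalar Python sketch of a single rotation (the helper name is ours) shows the two properties that matter: position 0 is the identity, and the rotation preserves the pair’s magnitude – which is why RoPE leaves vector norms, and hence attention scale, untouched.

```python
import math

def rope_rotate_pair(x0, x1, pos, pair_idx, head_dim, base_freq=10000.0):
    # theta_d = pos / base_freq^(2d / head_dim), then a plain 2D rotation
    theta = pos / (base_freq ** (2.0 * pair_idx / head_dim))
    c, s = math.cos(theta), math.sin(theta)
    return x0 * c - x1 * s, x0 * s + x1 * c
```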
Now, different model families pair up elements differently. This is where the “standard” vs “NeoX” distinction comes in.
Standard (Interleaved) vs NeoX (Split-Half) Pairing
The two RoPE styles differ in which elements they pair for rotation: low-index pairs rotate quickly, high-index pairs rotate slowly, and the pairing layouts are contrasted below.
Standard pairing comes from the original LLaMA GGUF format. NeoX pairing (named
after GPT-NeoX) is used by HuggingFace models including Qwen3, and matches the
rotate_half function in the Python reference code:
# HuggingFace rotate_half:
x1 = x[..., :head_dim//2]
x2 = x[..., head_dim//2:]
rotated = cat(-x2, x1, dim=-1)
output = x * cos + rotated * sin
Both styles are mathematically equivalent in terms of expressiveness – the model just needs to learn weights that match the pairing convention. But you must use the same convention as the model was trained with, or the positions will be garbage.
The Standard RoPE Kernel
Let us start with the simplest kernel, rope_f16:
kernel void rope_f16(
device half *x [[buffer(0)]],
constant RoPEParams &params [[buffer(1)]],
device const float *freqs [[buffer(2)]],
uint3 tid [[thread_position_in_grid]]
) {
const uint pair_idx = tid.x; // [0 .. head_dim/2)
const uint head_idx = tid.y; // [0 .. n_heads)
const uint seq_idx = tid.z; // [0 .. seq_len)
const uint half_dim = params.head_dim / 2;
if (pair_idx >= half_dim || head_idx >= params.n_heads
|| seq_idx >= params.seq_len) return;
const uint pos = seq_idx + params.pos_offset;
float freq_divisor = (freqs != nullptr) ? freqs[pair_idx]
: pow(params.theta, float(2 * pair_idx) / float(params.head_dim));
float freq = float(pos) / freq_divisor;
float cos_f = cos(freq);
float sin_f = sin(freq);
uint base = seq_idx * stride + head_idx * params.head_dim + 2 * pair_idx; // stride: row stride in elements, from params
float x0 = float(x[base]);
float x1 = float(x[base + 1]);
x[base] = half(x0 * cos_f - x1 * sin_f);
x[base + 1] = half(x0 * sin_f + x1 * cos_f);
}
Dispatch: One Thread per Rotation
The dispatch grid is 3D:
Grid: (head_dim/2, n_heads, seq_len)
For head_dim=128, n_heads=32, seq_len=1 (decode):
Total threads = 64 * 32 * 1 = 2,048
For head_dim=128, n_heads=32, seq_len=512 (prefill):
Total threads = 64 * 32 * 512 = 1,048,576
Each thread handles exactly one pair of elements. No shared memory, no reductions,
no barriers. This is a pure embarrassingly-parallel kernel – every thread reads
two elements, computes a rotation, and writes two elements back. The only shared
state is the params constant buffer.
Precomputed Frequencies
Notice the frequency computation has two paths:
float freq_divisor = (freqs != nullptr) ? freqs[pair_idx]
: pow(params.theta, float(2 * pair_idx) / float(params.head_dim));
If a precomputed frequency table is provided in buffer(2), the kernel reads from
it directly. Otherwise it computes theta^(2d/head_dim) on the fly using pow.
The precomputed path avoids a transcendental function (pow) per thread. For
head_dim=128, that is 64 pow calls saved per head per position. The host
precomputes these once at initialization:
freqs[d] = theta^(2*d / head_dim) for d in [0, head_dim/2)
Then the kernel just does pos / freqs[d], which is a simple float division.
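The host-side precomputation is a one-liner. A Python sketch (the function name is ours):

```python
def precompute_rope_freqs(head_dim, theta=10000.0):
    # freqs[d] = theta^(2d / head_dim); the kernel's angle is then pos / freqs[d]
    return [theta ** (2.0 * d / head_dim) for d in range(head_dim // 2)]
```

The table is head_dim/2 floats – 256 bytes for head_dim=128 – computed once at initialization and reused for every token and every head.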
In-Place Operation
The kernel modifies x in place. This is safe because each thread works on a
unique pair of elements:
Thread (pair=0, head=0, seq=0): reads/writes x[0], x[1]
Thread (pair=1, head=0, seq=0): reads/writes x[2], x[3]
Thread (pair=0, head=1, seq=0): reads/writes x[128], x[129]
...
No two threads ever touch the same memory location. No synchronization needed.
The NeoX RoPE Kernel
The NeoX variant pairs elements at distance head_dim/2:
kernel void rope_neox_f16(
device half *x [[buffer(0)]],
constant RoPEParams &params [[buffer(1)]],
device const uint *position_ids [[buffer(2)]],
device const float *freqs [[buffer(3)]],
uint3 tid [[thread_position_in_grid]]
) {
...
// NeoX-style: pair element i with element i + head_dim/2
uint idx0 = base + pair_idx;
uint idx1 = base + pair_idx + half_dim;
float x0 = float(x[idx0]);
float x1 = float(x[idx1]);
x[idx0] = half(x0 * cos_f - x1 * sin_f);
x[idx1] = half(x1 * cos_f + x0 * sin_f);
}
The indexing difference visualized:
Standard: base + 2*pair_idx, base + 2*pair_idx + 1
(adjacent elements)
NeoX: base + pair_idx, base + pair_idx + half_dim
(elements half a head apart)
For head_dim=128, pair_idx=5:
Standard: elements 10, 11
NeoX: elements 5, 69
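The index arithmetic is small enough to sketch directly in Python (the helper name is ours); it reproduces the pair_idx=5 example above:

```python
def pair_indices(pair_idx, head_dim, neox):
    # Element offsets (within one head) of the two values a thread rotates
    half_dim = head_dim // 2
    if neox:
        return pair_idx, pair_idx + half_dim  # split-half pairing
    return 2 * pair_idx, 2 * pair_idx + 1     # interleaved pairing
```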
Function Constants for Position IDs
The NeoX kernel has an interesting feature – Metal function constants for per-token position IDs:
constant bool FC_USE_POSITION_IDS [[function_constant(0)]];
constant bool FC_HAS_POSITION_IDS = is_function_constant_defined(FC_USE_POSITION_IDS)
&& FC_USE_POSITION_IDS;
...
const uint pos = FC_HAS_POSITION_IDS ? position_ids[seq_idx]
: (seq_idx + params.pos_offset);
When FC_USE_POSITION_IDS is true, each token reads its position from a separate
buffer. This is needed for tree speculation, where tokens do not have sequential
positions – you might be evaluating multiple speculative continuations in parallel,
each at a different position in the sequence.
When false, positions are simply seq_idx + offset, which is the common case for
normal sequential generation. The function constant lets Metal compile out the
branch entirely, so there is zero overhead in the non-speculative path.
Fused RoPE + KV Cache Write: Eliminating Four Dispatches
Now we get to the real workhorses. During single-token decode (the hot path for autoregressive generation), each transformer layer needs to:
- Apply RoPE to Q (all heads)
- Apply RoPE to K (KV heads only)
- Write rotated K to the KV cache
- Write V to the KV cache (no rotation)
Naively, that is four kernel dispatches. Akunu fuses all four into one.
Here is the standard-pairing fused kernel:
kernel void rope_qkv_write_f16(
device half *qkv [[buffer(0)]],
device half *k_cache [[buffer(1)]],
device half *v_cache [[buffer(2)]],
constant RoPEQKVWriteParams &params [[buffer(3)]],
device const float *freqs [[buffer(4)]],
uint2 tid [[thread_position_in_grid]]
) {
const uint pair_idx = tid.x; // [0, head_dim/2)
const uint head_idx = tid.y; // [0, n_heads)
...
The dispatch is 2D: (head_dim/2, n_heads). Each thread:
- Computes cos/sin for this position and dimension pair
- Rotates Q in-place for its head
- If head_idx < n_kv_heads, also rotates K and writes both K and V to cache
Let us trace the data flow:
QKV Buffer Layout (contiguous):
+-----------------------------------------------+
| Q: [n_heads * head_dim] |
+-----------------------------------------------+
| K: [n_kv_heads * head_dim] | <-- at k_elem_offset
+-----------------------------------------------+
| V: [n_kv_heads * head_dim] | <-- at v_elem_offset
+-----------------------------------------------+
Thread (pair=d, head=h):
1. Compute freq = pos * freq_scale / freq_divisor
cos_f = cos(freq), sin_f = sin(freq)
2. Q rotation (all heads):
q_src = h * head_dim + 2*d
q0' = q0 * cos_f - q1 * sin_f --> qkv[q_src]
q1' = q0 * sin_f + q1 * cos_f --> qkv[q_src + 1]
3. K rotation + cache write (head h < n_kv_heads only):
k_src = k_elem_offset + h * head_dim + 2*d
k0' = k0 * cos_f - k1 * sin_f --> k_cache[cache_base + 2*d]
k1' = k0 * sin_f + k1 * cos_f --> k_cache[cache_base + 2*d + 1]
4. V cache write (straight copy, no rotation):
v_cache[cache_base + 2*d] = qkv[v_src + 2*d]
v_cache[cache_base + 2*d + 1] = qkv[v_src + 2*d + 1]
The KV Cache Layout
The cache stores K and V in head-major order:
K Cache: [n_kv_heads, max_seq_len, head_dim]
V Cache: [n_kv_heads, max_seq_len, head_dim]
cache_base = head_idx * max_seq_len * head_dim + pos * head_dim
For head 0, position 42, head_dim 128:
cache_base = 0 * max_seq_len * 128 + 42 * 128 = 42 * 128 = 5376
For head 3, position 42, head_dim 128:
cache_base = 3 * max_seq_len * 128 + 42 * 128
This layout means that for a given head, all positions are contiguous in memory. During attention, when you need to load K/V for all past positions of one head, the access pattern is a simple sequential scan – perfect for GPU memory throughput.
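The index arithmetic is simple enough to check directly. A small Python sketch of the cache_base formula above:

```python
def cache_base(head_idx, pos, max_seq_len, head_dim):
    # head-major layout: [n_kv_heads, max_seq_len, head_dim]
    return head_idx * max_seq_len * head_dim + pos * head_dim

# The worked examples from the text (max_seq_len is illustrative):
assert cache_base(0, 42, max_seq_len=2048, head_dim=128) == 5376
assert cache_base(3, 42, 2048, 128) == 3 * 2048 * 128 + 5376

# Consecutive positions of one head are contiguous head_dim-sized chunks,
# which is what makes the attention-time scan sequential:
b0 = cache_base(1, 10, 2048, 128)
b1 = cache_base(1, 11, 2048, 128)
assert b1 - b0 == 128
```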
freq_scale: Dynamic Frequency Scaling
Notice the frequency computation includes a freq_scale parameter:
float freq = float(params.pos) * params.freq_scale / freq_divisor;
This enables dynamic RoPE scaling techniques like Linear Scaling (divide
frequencies by a factor to extend context length) and NTK-aware scaling. If
freq_scale = 1.0, you get standard RoPE. If freq_scale = 0.5, all frequencies
are halved, effectively doubling the model’s positional resolution and extending
its context window.
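The effect is easy to see numerically. A Python sketch of the angle formula above, with illustrative defaults:

```python
def rope_angle(pos, d, theta=10000.0, head_dim=128, freq_scale=1.0):
    freq_divisor = theta ** (2.0 * d / head_dim)
    return pos * freq_scale / freq_divisor

# With freq_scale = 0.5, position 2000 gets exactly the angles position 1000
# would get at freq_scale = 1.0 -- the position axis is compressed 2x:
for d in range(64):
    assert rope_angle(2000, d, freq_scale=0.5) == rope_angle(1000, d)
```

This is linear scaling in one multiplier: the model sees positions at half speed, so a 4096-token training range covers an 8192-token window.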
GQA-Aware Thread Assignment
The kernel elegantly handles Grouped Query Attention (GQA), where n_heads > n_kv_heads. All threads rotate Q (for all n_heads heads), but only threads
where head_idx < n_kv_heads do the K/V work:
n_heads=32, n_kv_heads=8 (GQA ratio 4:1):
Threads with head_idx 0-7: Rotate Q + Rotate K + Write K + Write V
Threads with head_idx 8-31: Rotate Q only
head: 0 1 2 3 4 5 6 7 8 9 10 ... 31
Q: * * * * * * * * * * * ... *
K/V: * * * * * * * * . . . ... .
This is a minor load imbalance – some threads do more work than others. But since the extra work (K rotation + 2 cache writes) is just a few memory ops, the imbalance is negligible compared to the cost of an extra kernel dispatch.
NeoX Fused Variant
The NeoX fused kernel (rope_neox_qkv_write_f16) is structurally identical to the
standard variant, but with split-half indexing:
// Standard: q_src = q_base + 2 * pair_idx, q_src + 1
// NeoX: q_lo = q_base + pair_idx
// q_hi = q_base + pair_idx + half_dim
uint q_lo = q_base + pair_idx;
uint q_hi = q_base + pair_idx + half_dim;
float q0 = float(qkv[q_lo]);
float q1 = float(qkv[q_hi]);
qkv[q_lo] = half(q0 * cos_f - q1 * sin_f);
qkv[q_hi] = half(q1 * cos_f + q0 * sin_f);
Same thread count, same dispatch shape, same GQA handling. The only difference is which pairs of memory locations get rotated together.
Batch (Prefill) Variants
During prefill (processing the initial prompt), there are multiple tokens to
process simultaneously. The single-token fused kernels assume seq_len = 1.
The batch variants add a third grid dimension for sequence position:
Single-token dispatch: (head_dim/2, n_heads, 1)
Batch dispatch: (head_dim/2, n_heads, seq_len)
The batch NeoX kernel (rope_neox_batch_kv_write_f16) handles a key difference:
Q, K, and V are separate buffers (already split by a preceding QKV split kernel),
rather than one contiguous QKV buffer:
Single-token:
buffer(0) = qkv [contiguous Q|K|V]
Batch:
buffer(0) = batchQ [seq_len, n_heads * head_dim]
buffer(1) = batchK [seq_len, n_kv_heads * head_dim]
buffer(2) = batchV [seq_len, n_kv_heads * head_dim]
Each thread computes its position as:
const uint position = params.pos + seq_idx;
So if you are prefilling tokens starting at position 0 with seq_len=512, thread
with seq_idx=100 processes position 100. The rotated K and V are written directly
into the cache at the correct positions.
An interesting implementation detail: the k_elem_offset parameter field is
repurposed to carry seq_len in the batch variants. This avoids changing the
params struct and keeps buffer layouts compatible.
The standard-pairing batch variant (rope_batch_kv_write_f16) does the same
thing but with interleaved (2i, 2i+1) pairs instead of split-half (i, i+half_dim).
The Ultimate Fusion: Head Norm + NeoX RoPE + KV Write
For Qwen3, which applies per-head RMSNorm to Q and K before RoPE, akunu offers
the ultimate fused kernel: head_norm_rope_neox_kv_write_f16. This replaces
three separate dispatches per layer:
- head_rmsnorm_f16 on Q
- head_rmsnorm_f16 on K
- rope_neox_qkv_write_f16 for RoPE + KV cache write
All fused into a single dispatch.
kernel void head_norm_rope_neox_kv_write_f16(
device half *qkv [[buffer(0)]],
device half *k_cache [[buffer(1)]],
device half *v_cache [[buffer(2)]],
constant RoPEQKVWriteParams &params [[buffer(3)]],
device const half *q_norm_weight [[buffer(4)]],
device const half *k_norm_weight [[buffer(5)]],
constant float &norm_eps [[buffer(6)]],
uint2 tgpig [[threadgroup_position_in_grid]],
uint tid_in_tg [[thread_index_in_threadgroup]],
uint sgid [[simdgroup_index_in_threadgroup]],
uint slid [[thread_index_in_simdgroup]]
) {
Dispatch Model
This kernel uses threadgroups rather than a flat grid, because the RMSNorm requires a reduction across the head dimension:
Dispatch: threadgroups = (1, n_heads, 1)
threads_per_threadgroup = (head_dim/2, 1, 1)
For head_dim=128, n_heads=32:
32 threadgroups, each with 64 threads
Total: 2,048 threads (same as the non-fused approach)
Each threadgroup handles one head. The 64 threads within the threadgroup cooperatively compute the RMSNorm reduction, then each thread does its rotation.
Data Flow Per Thread
Let me trace the complete flow for a single thread:
Thread (pair_idx=d, head_idx=h):
PHASE 1: Q RMSNorm
+---------------------------------------------------------+
| 1a. Load Q elements: q0 = qkv[h*hd + d], |
| q1 = qkv[h*hd + d + hd/2] |
| 1b. Compute q_sq = q0*q0 + q1*q1 |
| 1c. SIMD reduce: simd_sum(q_sq) |
| 1d. Write to shared[sgid], barrier |
| 1e. Sum shared[], compute q_rms = rsqrt(sum/hd + eps) |
| 1f. q0 *= q_rms * q_norm_weight[d] |
| q1 *= q_rms * q_norm_weight[d + hd/2] |
+---------------------------------------------------------+
PHASE 2: Q NeoX RoPE
+---------------------------------------------------------+
| 2a. Compute freq = pos / pow(theta, 2*d/hd) |
| (always computed inline — no freq_scale or |
| precomputed freqs buffer in this fused kernel) |
| cos_f = cos(freq), sin_f = sin(freq) |
| 2b. qkv[q_lo] = q0 * cos_f - q1 * sin_f |
| qkv[q_hi] = q1 * cos_f + q0 * sin_f |
+---------------------------------------------------------+
PHASE 3: K RMSNorm + RoPE + Cache Write (if h < n_kv_heads)
+---------------------------------------------------------+
| 3a. Load K elements: k0, k1 |
| 3b. Barrier (reuse shared memory from Q norm) |
| 3c. SIMD reduce k_sq, compute k_rms |
| 3d. k0 *= k_rms * k_norm_weight[d] |
| k1 *= k_rms * k_norm_weight[d + hd/2] |
| 3e. Rotate + write to cache: |
| k_cache[...] = k0*cos_f - k1*sin_f |
| k_cache[...] = k1*cos_f + k0*sin_f |
+---------------------------------------------------------+
PHASE 4: V Cache Write (if h < n_kv_heads)
+---------------------------------------------------------+
| 4a. Copy V directly to cache (no norm, no rotation) |
+---------------------------------------------------------+
Shared Memory Reuse
The kernel uses only 4 floats of shared memory (for up to 4 SIMD groups with head_dim=256):
threadgroup float shared_sq[4]; // max 4 SIMD groups for head_dim=256
This memory is reused between the Q and K normalization phases. The barrier before
the K phase ensures the Q reads from shared_sq are complete before K overwrites
them:
// RMSNorm for K (reuse shared memory -- barrier ensures Q reads are done)
threadgroup_barrier(mem_flags::mem_threadgroup);
Reduction Without tg_reduce_sum
Since head_dim/2 threads is at most 128 (4 SIMD groups), the reduction is tiny.
The kernel sums across SIMD groups with a simple loop rather than the full
tg_reduce_sum helper:
float q_total_sq = 0;
for (uint s = 0; s < n_simd; s++) q_total_sq += shared_sq[s];
For 4 SIMD groups, this is 4 additions – negligible. The simd_sum within each
SIMD group does the heavy lifting.
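The two-level structure can be modeled in a few lines of Python. This is a sketch of the reduction shape, not the kernel itself; `two_level_sum` is an illustrative name:

```python
import math

def two_level_sum(values, simd_width=32):
    # Level 1: each SIMD group reduces its 32 lanes (the simd_sum)
    partials = [sum(values[i:i + simd_width])
                for i in range((len(values) + simd_width - 1) and 0, len(values), simd_width)]
    # Level 2: a tiny serial loop over at most 4 partials (the shared_sq[] loop)
    total = 0.0
    for p in partials:
        total += p
    return total

# head_dim/2 = 128 "threads" -> 4 SIMD groups -> 4 partials -> 4 additions
vals = [0.01 * i for i in range(128)]
assert math.isclose(two_level_sum(vals), sum(vals))
assert len([vals[i:i + 32] for i in range(0, 128, 32)]) == 4
```

Level 1 runs in hardware across lanes; level 2 is so small that a loop beats a generic tree reduction.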
The Performance Impact
Let us quantify the fusion benefit. For a Qwen3 model with 32 heads, head_dim=128, 8 KV heads:
Unfused (per layer, per token):
Dispatch 1: head_rmsnorm on Q (2048 threads, 2 barriers)
Dispatch 2: head_rmsnorm on K (512 threads, 2 barriers)
Dispatch 3: rope_neox_qkv_write (RoPE + KV cache write)
= 3 dispatches per layer
Fused (per layer, per token):
Dispatch 1: head_norm_rope_neox_kv_write (2048 threads, ~4 barriers)
= 1 dispatch per layer
For 32 layers: 96 dispatches --> 32 dispatches (saves 64 dispatch overheads)
At maybe 5-10 microseconds of overhead per dispatch, that is 320-640 microseconds saved per generated token. For a model generating at 100 tokens/second (10 milliseconds per token), that is roughly 3-6% of total time. Fusion matters.
Why RoPE Generalizes to Unseen Lengths
The dot product between two rotated vectors depends only on the relative position difference, not the absolute positions:
<RoPE(x, pos_i), RoPE(y, pos_j)> = f(x, y, pos_i - pos_j)
This is the mathematical property that makes RoPE so powerful: a model trained on 4096-token sequences can attend to tokens at positions 8000 and 8005 just as well as positions 0 and 5, because the rotation difference is identical. The multi-scale frequency spectrum – fast-spinning low-index pairs for local patterns, slow-spinning high-index pairs for global structure – gives the model a rich positional encoding at every scale.
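The property can be verified numerically. A Python sketch applying standard-pairing RoPE to two small vectors and comparing dot products at different absolute positions with the same relative offset:

```python
import math

def rope(x, pos, theta=10000.0):
    # Standard-pairing RoPE on a single head of even dimension
    out = x[:]
    hd = len(x)
    for d in range(hd // 2):
        ang = pos / theta ** (2.0 * d / hd)
        c, s = math.cos(ang), math.sin(ang)
        x0, x1 = x[2 * d], x[2 * d + 1]
        out[2 * d] = x0 * c - x1 * s
        out[2 * d + 1] = x0 * s + x1 * c
    return out

def dot(a, b):
    return sum(p * q for p, q in zip(a, b))

x = [0.3, -1.2, 0.7, 0.05, -0.4, 1.1, 0.9, -0.8]
y = [1.0, 0.2, -0.6, 0.4, 0.33, -0.9, 0.1, 0.55]
# Positions (0, 5) and (8000, 8005) share the same relative distance:
d1 = dot(rope(x, 0), rope(y, 5))
d2 = dot(rope(x, 8000), rope(y, 8005))
assert math.isclose(d1, d2, rel_tol=1e-6, abs_tol=1e-6)
```

Per pair this is just the rotation identity R(a)ᵀR(b) = R(b−a): the absolute angles cancel and only the difference survives.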
Summary of RoPE Kernels
+-------------------------------------------+--------+----------+-----------+
| Kernel | Style | Tokens | Fusion |
+-------------------------------------------+--------+----------+-----------+
| rope_f16 | Std | Any | RoPE only |
| rope_neox_f16 | NeoX | Any | RoPE only |
| rope_qkv_write_f16 | Std | 1 | +KV write |
| rope_neox_qkv_write_f16 | NeoX | 1 | +KV write |
| rope_batch_kv_write_f16 | Std | Batch | +KV write |
| rope_neox_batch_kv_write_f16 | NeoX | Batch | +KV write |
| head_norm_rope_neox_kv_write_f16 | NeoX | 1 | +Norm+KV |
+-------------------------------------------+--------+----------+-----------+
Key takeaways:
- One thread per rotation pair – embarrassingly parallel, no reductions needed for standalone RoPE.
- Two pairing styles – standard (interleaved) and NeoX (split-half), chosen to match the model’s training convention.
- Fused dispatches – the single-token decode path fuses RoPE for Q and K with KV cache writes into one dispatch. The Qwen3 path further fuses per-head RMSNorm.
- GQA-aware – threads for heads beyond n_kv_heads only rotate Q; those within n_kv_heads also handle K rotation and K/V cache writes.
- Precomputed frequencies – avoids per-thread pow() calls during inference.
- freq_scale – a single multiplier enables dynamic context length extension.
- Function constants – Metal specialization constants compile out the position-ID branch for the common (non-speculative) case.
Next, we will look at the kernels that happen before and after normalization and RoPE: embedding lookups and activation functions.
Embedding and Activation Kernels
Before a transformer can normalize or attend or do anything useful, it needs to convert integer token IDs into dense floating-point vectors. And between the big matrix multiplies, it needs to apply nonlinear activation functions to keep the network from collapsing into a single linear transformation. These are the “glue” kernels – individually simple, collectively essential.
In this chapter we will cover akunu’s embedding lookup kernels (FP16, BF16, and quantized variants), the SiLU and GELU activation kernels (including fused gated versions for FFN blocks), and the collection of utility kernels that handle the mundane but necessary work of moving data around: vector addition, residual connections, QKV splits, transposes, bias adds, and head rearrangement.
Embedding Lookup: From Token IDs to Vectors
The embedding table is conceptually a 2D array of shape [vocab_size, dim]. To
look up a token, you just index into the row for that token ID and copy it out.
On a GPU, you launch enough threads to copy all elements in parallel.
FP16 Embedding Lookup
The simplest case – the embedding table is stored in half-precision:
kernel void embedding_lookup_f16(
device const uint32_t *tokens [[buffer(0)]],
device const half *table [[buffer(1)]],
device half *output [[buffer(2)]],
constant EmbeddingParams &params [[buffer(3)]],
uint2 tid [[thread_position_in_grid]]
) {
const uint dim_idx = tid.x;
const uint token_idx = tid.y;
if (token_idx >= params.num_tokens || dim_idx >= params.dim) return;
uint token_id = tokens[token_idx];
output[token_idx * params.dim + dim_idx] = table[token_id * params.dim + dim_idx];
}
That is it. A single memory read and a single memory write. The dispatch grid is 2D:
Grid: (dim, num_tokens)
For dim=4096, num_tokens=1 (decode):
4096 threads, each copies one element
For dim=4096, num_tokens=512 (prefill):
2,097,152 threads
The memory access pattern deserves attention. Threads with consecutive dim_idx
values (same token_idx) read consecutive memory locations from the same table
row. This is perfectly coalesced:
Token "hello" (id=15043):
Thread (0, 0): table[15043 * 4096 + 0] --> output[0]
Thread (1, 0): table[15043 * 4096 + 1] --> output[1]
Thread (2, 0): table[15043 * 4096 + 2] --> output[2]
...
Thread (4095, 0): table[15043 * 4096 + 4095] --> output[4095]
128 bytes per memory transaction / 2 bytes per element = 64 elements per fetch
SIMD width = 32 --> a single fetch covers two full SIMD groups
However, there is a catch. Different tokens will read from completely different rows of the table, potentially megabytes apart. For a vocab of 32K with dim=4096, the table is 32768 * 4096 * 2 = 256 MB. Only a tiny fraction fits in cache. The first token’s row load will be a cache miss for sure. But since we only read each row once, caching is irrelevant – this is a pure streaming kernel.
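The operation itself is a pure gather, which a host-side Python sketch makes obvious (the tiny vocabulary here is purely illustrative):

```python
vocab_size, dim = 8, 4
# table[row][col] filled with distinct values so rows are recognizable
table = [[float(row * dim + c) for c in range(dim)] for row in range(vocab_size)]

def embedding_lookup(tokens, table):
    # One GPU "thread" per (token, dim) element: output[t][d] = table[tokens[t]][d]
    return [table[tid][:] for tid in tokens]

out = embedding_lookup([5, 0, 5], table)
assert out[0] == [20.0, 21.0, 22.0, 23.0]  # row 5 of the table
assert out[1] == [0.0, 1.0, 2.0, 3.0]      # row 0
assert out[2] == out[0]                    # same token id -> same row
```

Within a row the copy is sequential (coalesced); across tokens the rows land anywhere in the table, which is the streaming behavior described above.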
BF16 Embedding Lookup
For M4 and later chips that have native BF16 support, akunu provides a BF16 variant:
kernel void embedding_lookup_bf16(
device const uint32_t *token_ids [[buffer(0)]],
device const bfloat *table [[buffer(1)]],
device half *output [[buffer(2)]],
constant uint &dim [[buffer(3)]],
uint2 tid [[thread_position_in_grid]]
) {
const uint d_idx = tid.x;
const uint token_idx = tid.y;
if (d_idx >= dim) return;
const uint token_id = token_ids[token_idx];
output[token_idx * dim + d_idx] = half(table[token_id * dim + d_idx]);
}
Notice it reads BF16 but outputs FP16. BF16 has the same exponent range as FP32
(8 bits) but only 7 bits of mantissa versus FP16’s 10. The conversion happens
inline – the half() cast truncates the mantissa. BF16 weights are increasingly
common in HuggingFace models, and on M4 hardware this cast is a single-cycle
operation.
Quantized Embedding Lookup: Q4_0 On-the-Fly Dequantization
Now for the interesting one. When the embedding table is quantized to Q4_0 format (4 bits per weight), the kernel must dequantize on the fly. Each Q4_0 block contains 32 elements packed into 18 bytes:
Q4_0 Block Structure (18 bytes total):
+------+------------------+
| d | qs[16] |
| (2B) | (16 bytes) |
+------+------------------+
| |
| +-- 32 four-bit values, nibble-packed:
| qs[j] low nibble = element j (j = 0..15)
| qs[j] high nibble = element j + 16 (j = 0..15)
|
+-- FP16 scale factor
Dequantization: value = d * (nibble - 8)
(Q4_0 stores unsigned 0-15, centered at 8)
Here is the kernel:
kernel void embedding_lookup_q4_0(
device const uint32_t *token_ids [[buffer(0)]],
device const block_q4_0 *table [[buffer(1)]],
device half *output [[buffer(2)]],
constant EmbeddingParams &params [[buffer(3)]],
uint2 tid [[thread_position_in_grid]]
) {
const uint d_idx = tid.x;
const uint token_idx = tid.y;
if (d_idx >= params.dim || token_idx >= params.num_tokens) return;
const uint token_id = token_ids[token_idx];
const uint blocks_per_row = params.dim / QK4_0;
const uint block_idx = d_idx / QK4_0;
const uint elem_idx = d_idx % QK4_0;
device const block_q4_0 &blk = table[token_id * blocks_per_row + block_idx];
half scale = blk.d;
uint is_high = elem_idx / 16;
uint j = elem_idx % 16;
uint8_t nibble = is_high ? (blk.qs[j] >> 4) : (blk.qs[j] & 0xF);
half val = scale * half(int(nibble) - 8);
output[token_idx * params.dim + d_idx] = val;
}
Let us trace the dequantization for element 21 within a block:
elem_idx = 21
is_high = 21 / 16 = 1 (it is in the upper half)
j = 21 % 16 = 5 (byte index within qs[])
nibble = qs[5] >> 4 (high nibble of byte 5)
value = scale * (nibble - 8)
Example: qs[5] = 0xA3, scale = 0.125
nibble = 0xA = 10
value = 0.125 * (10 - 8) = 0.125 * 2 = 0.25
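The nibble extraction and dequantization translate directly to Python. This sketch reproduces the worked example above (`dequant_q4_0` is an illustrative name, not akunu's API):

```python
def dequant_q4_0(scale, qs, elem_idx):
    # Elements 0-15 live in low nibbles, 16-31 in high nibbles of qs[0..15]
    is_high = elem_idx // 16
    j = elem_idx % 16
    nibble = (qs[j] >> 4) if is_high else (qs[j] & 0xF)
    return scale * (nibble - 8)  # Q4_0 stores unsigned 0-15, centered at 8

qs = [0] * 16
qs[5] = 0xA3
# Element 21 -> high nibble of byte 5 (the trace above):
assert dequant_q4_0(0.125, qs, 21) == 0.25    # 0.125 * (10 - 8)
# Element 5 -> low nibble of the same byte:
assert dequant_q4_0(0.125, qs, 5) == -0.625   # 0.125 * (3 - 8)
```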
The memory savings are significant:
FP16 table: vocab_size * dim * 2 bytes
Q4_0 table: vocab_size * (dim/32) * 18 bytes
For vocab=128256, dim=4096:
FP16: 128256 * 4096 * 2 = 1,050,673,152 bytes (~1.0 GB)
Q4_0: 128256 * 128 * 18 = 295,501,824 bytes (~282 MB)
Compression ratio: ~3.6x
The downside is slightly more ALU work per element (a shift, a mask, a subtract, a multiply). But since embedding lookup is purely memory-bound – you are reading from a huge table with essentially random access per token – the extra ALU is completely hidden behind memory latency.
Akunu provides similar quantized lookup kernels for Q4_1, Q5_0, Q5_K, Q8_0, Q2_K, Q3_K, Q4_K, and Q6_K formats. Each follows the same pattern: one thread per output element, read the relevant quantized block, extract and dequantize the value, write FP16 output.
Positional Embedding Addition
For models that use learned positional embeddings (like Whisper), there is a simple kernel that adds a position-dependent vector:
kernel void pos_embed_add_f16(
device const half *input [[buffer(0)]],
device const half *pos_table [[buffer(1)]],
device half *output [[buffer(2)]],
constant ElementwiseParams &p [[buffer(3)]],
uint tid [[thread_position_in_grid]]
) {
if (tid >= p.count) return;
uint pos = p._pad0;
uint dim = p.count;
output[tid] = input[tid] + pos_table[pos * dim + tid];
}
One thread per dimension, one addition per thread. The position index is passed
via a repurposed padding field in ElementwiseParams (a common trick to avoid
defining a new params struct for a one-off kernel).
Activation Functions: SiLU and GELU
After the FFN’s first linear projection, the output needs a nonlinearity. Modern LLMs use two main activation functions:
SiLU (Sigmoid Linear Unit):
f(x) = x * sigmoid(x) = x / (1 + exp(-x))
GELU (Gaussian Error Linear Unit):
f(x) = x * Phi(x) ~ 0.5 * x * (1 + tanh(sqrt(2/pi) * x * (1 + 0.044715 * x^2)))
Let us visualize these:
SiLU GELU
y | y |
2 | / 2 | /
| / | /
1 | / 1 | /
| / | ./
0 |---____/ 0 |---___./
| / | /
-1 |/ -1 |/
+--+--+--+--+--+--> x +--+--+--+--+--+--> x
-4 -2 0 2 4 -4 -2 0 2 4
Both are smooth, monotonically increasing for positive x, and have a soft “gate” that suppresses negative values. The key difference: GELU has a steeper transition around x=0, while SiLU is smoother and dips further negative (its minimum is about -0.28 near x = -1.28, versus GELU’s -0.17 near x = -0.75).
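Reference implementations of both formulas make the shapes concrete (a Python sketch using the same tanh approximation for GELU as the kernels below):

```python
import math

def silu(x):
    return x / (1.0 + math.exp(-x))

def gelu_tanh(x):
    return 0.5 * x * (1.0 + math.tanh(0.7978845608 * x * (1.0 + 0.044715 * x * x)))

# Both pass zero through zero and large positives almost unchanged:
assert silu(0.0) == 0.0 and gelu_tanh(0.0) == 0.0
assert abs(silu(6.0) - 6.0) < 0.02 and abs(gelu_tanh(6.0) - 6.0) < 0.001
# Both dip negative for moderate negative inputs; SiLU dips deeper at x = -1:
assert -0.5 < silu(-1.0) < gelu_tanh(-1.0) < 0.0
```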
SiLU Kernel
The standalone SiLU kernel is trivial:
kernel void silu_f16(
device const half *input [[buffer(0)]],
device half *output [[buffer(1)]],
constant ElementwiseParams &p [[buffer(2)]],
uint tid [[thread_position_in_grid]]
) {
if (tid >= p.count) return;
output[tid] = act_silu(input[tid]);
}
Where act_silu is defined in KernelCommon.h:
inline half act_silu(half x) {
return x / (half(1) + exp(-x));
}
One thread, one element, one activation. The exp(-x) is the expensive part –
it compiles to a hardware transcendental approximation on Apple GPUs. For FP16
inputs, the fast-math exp is accurate to about 3 ULPs, which is more than
adequate.
Fused SiLU Gate: The SwiGLU Pattern
LLaMA and its descendants use SwiGLU (SiLU-Gated Linear Unit) in the FFN. The pattern is:
FFN(x) = (SiLU(W_gate * x)) * (W_up * x)
The gate and up projections are computed separately by GEMM, producing two vectors
of size ff_dim. The fused kernel combines the activation and element-wise
multiply:
kernel void silu_gate_f16(
device const half *gate [[buffer(0)]],
device const half *up [[buffer(1)]],
device half *output [[buffer(2)]],
constant ElementwiseParams &p [[buffer(3)]],
uint tid [[thread_position_in_grid]]
) {
if (tid >= p.count) return;
output[tid] = act_silu(gate[tid]) * up[tid];
}
This is twice the work of standalone SiLU but saves a full pass over the data:
Unfused:
Kernel 1: silu(gate) --> temp[] // read gate, write temp
Kernel 2: temp * up --> output[] // read temp, read up, write output
Total: (1 read + 1 write) + (2 reads + 1 write) = 5 memory ops per element
Fused:
Kernel 1: silu(gate) * up --> output[] // read gate, read up, write output
Total: 2 reads + 1 write = 3 memory ops per element
The fused version eliminates the temporary buffer entirely – 40% less memory traffic and one fewer dispatch.
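The two paths are numerically identical; fusion only removes the round-trip. A small Python sketch:

```python
import math

def silu(x):
    return x / (1.0 + math.exp(-x))

gate = [0.5, -1.2, 3.0, 0.0]
up = [2.0, 4.0, -1.0, 7.0]

# Unfused: two passes through memory with a temporary buffer in between
temp = [silu(g) for g in gate]                  # kernel 1: read gate, write temp
unfused = [t * u for t, u in zip(temp, up)]     # kernel 2: read temp + up, write out

# Fused: one pass, no temporary
fused = [silu(g) * u for g, u in zip(gate, up)]

assert all(math.isclose(a, b) for a, b in zip(fused, unfused))
```

Same arithmetic, same results – the only thing eliminated is the `temp` buffer and the dispatch that fills it.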
Strided SiLU Gate: Batch Mode
During prefill, the gate and up projections may be packed into a single buffer with a specific stride layout:
buf layout: [M rows, 2 * ff_dim columns]
Row structure:
+---------------------------+---------------------------+
| gate[0..ff_dim-1] | up[0..ff_dim-1] |
+---------------------------+---------------------------+
<-------- ff_dim ----------><-------- ff_dim ---------->
<------------------ 2 * ff_dim ----------------------->
The strided kernel handles this layout with a 2D grid:
kernel void silu_gate_strided_f16(
device const half *buf [[buffer(0)]],
device half *output [[buffer(1)]],
constant uint32_t &ffDim [[buffer(2)]],
uint2 tid [[thread_position_in_grid]]
) {
const uint j = tid.x; // [0, ffDim)
const uint row = tid.y; // [0, M)
if (j >= ffDim) return;
const uint gate_idx = row * 2 * ffDim + j;
const uint up_idx = row * 2 * ffDim + ffDim + j;
output[row * ffDim + j] = act_silu(buf[gate_idx]) * buf[up_idx];
}
Each thread computes one output element from the gate and up values at the appropriate offsets within the packed buffer.
GELU and GELU Gate
For Gemma (which uses GeGLU rather than SwiGLU), there are corresponding GELU kernels. The standalone GELU:
kernel void gelu_f16(
device const half *input [[buffer(0)]],
device half *output [[buffer(1)]],
constant ElementwiseParams &p [[buffer(2)]],
uint tid [[thread_position_in_grid]]
) {
if (tid >= p.count) return;
output[tid] = act_gelu(input[tid]);
}
Where act_gelu_f32 in KernelCommon.h uses the tanh approximation:
inline float act_gelu_f32(half x) {
float xf = float(x);
constexpr float SQRT_2_OVER_PI = 0.7978845608f;
constexpr float GELU_COEF_A = 0.044715f;
return 0.5f * xf * (1.0f + precise::tanh(SQRT_2_OVER_PI * xf
* (1.0f + GELU_COEF_A * xf * xf)));
}
Notice the precise::tanh – this uses the precise (not fast-math) tanh to match
the reference implementation exactly. The comment in the source explains why the
fused GELU gate kernel does all computation in float32:
/// GELU-gate: output = gelu(gate) * up
/// All computation in float32 to match llama.cpp precision
/// (F16 GELU*up can overflow)
kernel void gelu_gate_f16(
device const half *gate [[buffer(0)]],
device const half *up [[buffer(1)]],
device half *output [[buffer(2)]],
constant ElementwiseParams &p [[buffer(3)]],
uint tid [[thread_position_in_grid]]
) {
if (tid >= p.count) return;
output[tid] = half(act_gelu_f32(gate[tid]) * float(up[tid]));
}
For large positive inputs the GELU gating factor approaches 1.0, so gelu(gate) is nearly as large as gate itself; multiplying it by up (which can also be large) risks exceeding FP16’s maximum of 65504 in the intermediate product. Doing the multiply in float32 and casting back to FP16 at the end prevents this.
Utility Kernels: The Supporting Cast
Beyond embeddings and activations, akunu has a collection of utility kernels that handle common data manipulation operations. These are all simple – one thread per element, no reductions, no shared memory – but they are called frequently throughout inference.
Vector Add
kernel void vector_add_f16(
device const half *A [[buffer(0)]],
device const half *B [[buffer(1)]],
device half *C [[buffer(2)]],
constant ElementwiseParams &p [[buffer(3)]],
uint tid [[thread_position_in_grid]]
) {
if (tid >= p.count) return;
C[tid] = A[tid] + B[tid];
}
As simple as it gets. There are both FP16 and FP32 variants. Used for general purpose addition where neither input is an “accumulator” (unlike residual add, which semantically represents a skip connection).
Residual Add
kernel void residual_add_f16(
device const half *a [[buffer(0)]],
device const half *b [[buffer(1)]],
device half *output [[buffer(2)]],
constant ElementwiseParams &p [[buffer(3)]],
uint tid [[thread_position_in_grid]]
) {
if (tid >= p.count) return;
output[tid] = a[tid] + b[tid];
}
Functionally identical to vector_add_f16. The separate kernel exists for semantic clarity in the pipeline (and potentially for profiling – you can see “residual_add” in GPU trace tools and know exactly what part of the transformer you are looking at).
Bias Add: Broadcast Addition
After a GEMM/GEMV produces output of shape [rows, dim], many models add a bias
vector of shape [dim] to every row:
kernel void bias_add_f16(
device half *data [[buffer(0)]],
device const half *bias [[buffer(1)]],
constant ElementwiseParams &p [[buffer(2)]],
uint tid [[thread_position_in_grid]]
) {
if (tid >= p.count) return;
uint col = tid % p._pad0; // _pad0 = dim
data[tid] = data[tid] + bias[col];
}
The modulo operation tid % dim extracts the column index, which indexes into
the 1D bias vector. The same bias value is added to every row. This is an in-place
operation – no separate output buffer needed.
Before: After:
+--------+--------+--------+ +--------+--------+--------+
| r0,c0 | r0,c1 | r0,c2 | | +b[0] | +b[1] | +b[2] |
| r1,c0 | r1,c1 | r1,c2 | | +b[0] | +b[1] | +b[2] |
| r2,c0 | r2,c1 | r2,c2 | | +b[0] | +b[1] | +b[2] |
+--------+--------+--------+ +--------+--------+--------+
QKV Split
After the fused QKV linear projection, the output is a single buffer of shape
[seq_len, q_dim + kv_dim + kv_dim]. The QKV split kernel separates it into
three buffers:
kernel void qkv_split_f16(
device const half *src [[buffer(0)]],
device half *dst_q [[buffer(1)]],
device half *dst_k [[buffer(2)]],
device half *dst_v [[buffer(3)]],
constant QKVSplitParams &params [[buffer(4)]],
uint tid [[thread_position_in_grid]]
) {
const uint total = params.seq_len * params.qkv_dim;
if (tid >= total) return;
const uint t = tid / params.qkv_dim;
const uint col = tid % params.qkv_dim;
half val = src[tid];
if (col < params.q_dim) {
dst_q[t * params.q_dim + col] = val;
} else if (col < params.q_dim + params.kv_dim) {
dst_k[t * params.kv_dim + (col - params.q_dim)] = val;
} else {
dst_v[t * params.kv_dim + (col - params.q_dim - params.kv_dim)] = val;
}
}
Visually:
Source: [seq_len, q_dim + kv_dim + kv_dim]
Row t:
+-------------------------------+------------------+------------------+
| Q columns (0..q_dim-1) | K (q_dim..q_dim | V (q_dim+kv_dim |
| | +kv_dim-1) | ..end) |
+-------------------------------+------------------+------------------+
| | |
v v v
dst_q[t, 0..q_dim-1] dst_k[t, 0..kv_dim-1] dst_v[t, 0..kv_dim-1]
The branch (if col < q_dim ... else if ... else) means threads in the same
SIMD group may take different paths. This causes some warp divergence. However,
since the branches are determined purely by col (which varies across threads
in a predictable pattern), the divergence is minimal – large contiguous blocks
of threads all take the same branch.
For GQA models where kv_dim < q_dim, the Q region is much larger. With
n_heads=32, n_kv_heads=8, head_dim=128:
q_dim = 32 * 128 = 4096
kv_dim = 8 * 128 = 1024
qkv_dim = 4096 + 1024 + 1024 = 6144
Threads 0-4095: write to dst_q (66.7%)
Threads 4096-5119: write to dst_k (16.7%)
Threads 5120-6143: write to dst_v (16.7%)
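The split is just three slices of each row. A Python sketch of the index arithmetic for the GQA shape above:

```python
def qkv_split(src_row, q_dim, kv_dim):
    # Column ranges: [0, q_dim) -> Q, [q_dim, q_dim+kv_dim) -> K, rest -> V
    q = src_row[:q_dim]
    k = src_row[q_dim:q_dim + kv_dim]
    v = src_row[q_dim + kv_dim:q_dim + 2 * kv_dim]
    return q, k, v

q_dim, kv_dim = 4096, 1024
row = list(range(q_dim + 2 * kv_dim))   # 6144 columns, values = column indices
q, k, v = qkv_split(row, q_dim, kv_dim)
assert (len(q), len(k), len(v)) == (4096, 1024, 1024)
# First element of each destination comes from the expected source column:
assert q[0] == 0 and k[0] == 4096 and v[0] == 5120
```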
Transpose
A straightforward 2D transpose:
kernel void transpose_f16(
device const half *input [[buffer(0)]],
device half *output [[buffer(1)]],
constant ElementwiseParams &p [[buffer(2)]],
uint2 gid [[thread_position_in_grid]]
) {
uint row = gid.y;
uint col = gid.x;
uint rows = p.count;
uint cols = p._pad0;
if (row >= rows || col >= cols) return;
output[col * rows + row] = input[row * cols + col];
}
The classic output[col][row] = input[row][col] pattern. The dispatch grid is
(ceil(cols/16), ceil(rows/16)) with threadgroup size (16, 16).
A production-quality transpose would typically use shared memory tiling to avoid
uncoalesced writes (writing to output[col * rows + row] is strided when threads
have consecutive col values). However, for the matrix sizes akunu typically
transposes (attention scores, head rearrangement intermediates), the simple version
is fast enough.
Head Rearrange
Attention requires reordering data between two layouts:
Position-major: [seq_len, n_heads * head_dim] (natural GEMM output)
Head-major: [n_heads, seq_len, head_dim] (what attention needs)
The forward kernel converts position-major to head-major:
kernel void head_rearrange_forward_f16(
device const half *src [[buffer(0)]],
device half *dst [[buffer(1)]],
constant uint &seq [[buffer(2)]],
constant uint &n_heads [[buffer(3)]],
constant uint &head_dim [[buffer(4)]],
uint tid [[thread_position_in_grid]]
) {
uint dim = n_heads * head_dim;
uint total = seq * dim;
if (tid >= total) return;
uint pos = tid / dim;
uint rem = tid % dim;
uint head = rem / head_dim;
uint d = rem % head_dim;
dst[head * seq * head_dim + pos * head_dim + d] = src[tid];
}
Visualized for seq=3, n_heads=2, head_dim=4:
Source (position-major):
pos 0: [h0d0 h0d1 h0d2 h0d3 | h1d0 h1d1 h1d2 h1d3]
pos 1: [h0d0 h0d1 h0d2 h0d3 | h1d0 h1d1 h1d2 h1d3]
pos 2: [h0d0 h0d1 h0d2 h0d3 | h1d0 h1d1 h1d2 h1d3]
Destination (head-major):
head 0: [pos0: d0 d1 d2 d3 | pos1: d0 d1 d2 d3 | pos2: d0 d1 d2 d3]
head 1: [pos0: d0 d1 d2 d3 | pos1: d0 d1 d2 d3 | pos2: d0 d1 d2 d3]
The inverse kernel does the opposite transformation. Both are simple scatter/gather operations with integer division and modulo to compute source and destination indices.
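The forward scatter can be replayed on the CPU to verify the layout change. This is a sketch of the same div/mod index computation, not the kernel itself:

```cpp
#include <cstdint>
#include <vector>

// CPU reference for head_rearrange_forward_f16:
// position-major [seq, n_heads*head_dim] -> head-major [n_heads, seq, head_dim].
std::vector<int> rearrange_ref(const std::vector<int>& src, uint32_t seq,
                               uint32_t n_heads, uint32_t head_dim) {
    uint32_t dim = n_heads * head_dim;
    std::vector<int> dst(src.size());
    for (uint32_t tid = 0; tid < seq * dim; ++tid) {
        uint32_t pos  = tid / dim;        // which position
        uint32_t rem  = tid % dim;
        uint32_t head = rem / head_dim;   // which head
        uint32_t d    = rem % head_dim;   // which dim within the head
        dst[head * seq * head_dim + pos * head_dim + d] = src[tid];
    }
    return dst;
}
```

For seq=2, n_heads=2, head_dim=2, the source {0..7} (pos 0: h0=[0,1], h1=[2,3]; pos 1: h0=[4,5], h1=[6,7]) rearranges to {0,1,4,5,2,3,6,7} — head 0's positions first, then head 1's.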
Performance Characteristics
Let us categorize these kernels by their computational profile:
+---------------------------+----------+--------+-----------+
| Kernel | Type | AI | Bound by |
+---------------------------+----------+--------+-----------+
| embedding_lookup_f16 | Gather | 0 | Memory |
| embedding_lookup_q4_0 | Gather | ~0.5 | Memory |
| silu_f16 | Map | ~2 | Memory |
| silu_gate_f16 | Map | ~3 | Memory |
| gelu_f16 | Map | ~5 | Memory |
| gelu_gate_f16 | Map | ~6 | Memory |
| vector_add_f16 | Map | ~0.3 | Memory |
| residual_add_f16 | Map | ~0.3 | Memory |
| bias_add_f16 | Map | ~0.5 | Memory |
| qkv_split_f16 | Scatter | 0 | Memory |
| transpose_f16 | Scatter | 0 | Memory |
| head_rearrange_*_f16 | Scatter | 0 | Memory |
+---------------------------+----------+--------+-----------+
AI = Arithmetic Intensity (FLOPs per byte transferred)
Every single one is memory-bound. The most compute-intensive is GELU (which involves a tanh, several multiplies, and an add), but even that has an arithmetic intensity well below what Apple Silicon can sustain. The dominant cost is moving bytes to and from main memory.
This is why fusion matters so much. Fusing silu_gate saves an entire buffer
round-trip. Fusing residual_rmsnorm eliminates a data pass. Each fusion does
not speed up the math – it reduces the memory traffic.
The Data Flow Through Inference
Let us trace how these kernels fit together in a single decoder layer:
Token IDs
|
v
[embedding_lookup_f16 / q4_0] --> hidden[seq_len, dim]
|
v
[residual_rmsnorm_f16] --> norm_out[seq_len, dim] (+ res_out for skip)
|
v
[GEMM: QKV projection] --> qkv[seq_len, q_dim + 2*kv_dim]
|
v
[qkv_split_f16] --> Q[seq, q_dim], K[seq, kv_dim], V[seq, kv_dim]
| (or fused into rope_*_kv_write for decode)
v
[rope_neox_f16 / rope_f16] --> Q', K' (rotated)
|
v
[head_rearrange_forward_f16] --> Q'[n_heads, seq, hd], K'[n_kv, seq, hd]
|
v
[Attention GEMM + Softmax] --> attn_out[n_heads, seq, hd]
|
v
[head_rearrange_inverse_f16] --> attn_out[seq, dim]
|
v
[GEMM: output projection] --> proj_out[seq, dim]
|
v
[bias_add_f16] --> proj_out += bias (if model has bias)
|
v
[residual_add_f16] --> hidden = proj_out + res_out
|
v
[residual_rmsnorm_f16] --> norm_out (for FFN)
|
v
[GEMM: gate + up projection] --> gate[seq, ff_dim], up[seq, ff_dim]
|
v
[silu_gate_f16 / gelu_gate_f16] --> activated[seq, ff_dim]
|
v
[GEMM: down projection] --> ffn_out[seq, dim]
|
v
[residual_add_f16] --> hidden = ffn_out + res_out
|
v
(next layer)
The embedding lookup runs once at the beginning. Everything else repeats for each layer. The GEMMs dominate runtime (they are the only compute-bound kernels), while everything else fills in the gaps. But collectively, these “gap” kernels add up – for a 32-layer model, you might have 200+ non-GEMM dispatches per token. Each one needs to be as lean as possible.
Summary
Embedding and activation kernels are the connective tissue of the inference pipeline:
- Embedding lookup – pure gather from a table. FP16 is a simple copy, Q4_0 dequantizes on-the-fly with nibble extraction. BF16 variant for M4+ hardware.
- SiLU/GELU activations – one thread per element, hardware transcendentals. Fused gate variants (silu_gate, gelu_gate) eliminate a temporary buffer. The GELU gate uses an FP32 intermediate to prevent overflow.
- Utility kernels – vector add, residual add, bias add, QKV split, transpose, head rearrange. All one-thread-per-element, all memory-bound, all essential for moving data between the big GEMMs in the right layout.
The recurring theme: these kernels are individually trivial but collectively critical. They are all memory-bound, so the optimization strategy is always the same – minimize the number of dispatches and the number of memory round-trips. Fusion is the primary weapon.
Sampling Kernels
The previous chapter on the sampling pipeline described the algorithmic flow from logits to tokens. This chapter dives into the Metal kernels that implement sampling on the GPU: the Top-K selector, the Gumbel-max sampler, the argmax reduction, and the repetition penalty kernel. These kernels live in backend/metal/kernels/metal/kernel/sampling/.
The GPU sampling path is critical for Akunu’s chain decode performance: by keeping sampling entirely on the GPU, the engine avoids a CPU roundtrip for every generated token. The Gumbel-max kernel is particularly important because it enables sampled generation at the same throughput as greedy decoding.
Kernel 1: Argmax (argmax_f16)
Argmax is the simplest sampling kernel: find the index of the maximum value in the logit buffer. Despite its simplicity, getting it right on a GPU requires a two-level reduction.
Dispatch
Grid: (1, 1, 1) — single threadgroup
Threadgroup: (1024, 1, 1) — 1024 threads = 32 SIMD groups
Phase 1: Thread-Local Scan
float best_val = -INFINITY;
uint best_idx = 0;
for (uint i = tid; i < N; i += tg_size) {
float v = float(logits[i]);
if (v > best_val) { best_val = v; best_idx = i; }
}
Each thread scans a strided subset of the vocabulary. For a 128K vocabulary with 1024 threads, each thread examines 128 elements.
Phase 2: SIMD Reduction
for (uint offset = SIMD_WIDTH / 2; offset > 0; offset >>= 1) {
float other_val = simd_shuffle_down(best_val, offset);
uint other_idx = simd_shuffle_down(best_idx, offset);
if (slid + offset < SIMD_WIDTH && other_val > best_val) {
best_val = other_val;
best_idx = other_idx;
}
}
This is a classic butterfly reduction within a SIMD group. Each step halves the active lanes:
Step 1: lanes 0-15 compare with lanes 16-31 (offset=16)
Step 2: lanes 0-7 compare with lanes 8-15 (offset=8)
Step 3: lanes 0-3 compare with lanes 4-7 (offset=4)
Step 4: lanes 0-1 compare with lanes 2-3 (offset=2)
Step 5: lane 0 compares with lane 1 (offset=1)
After 5 steps, lane 0 of each SIMD group holds the group’s local winner.
Phase 3: Cross-SIMD-Group Reduction
threadgroup float shared_val[32];
threadgroup uint shared_idx[32];
if (slid == 0) {
shared_val[sgid] = best_val;
shared_idx[sgid] = best_idx;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sgid == 0 && slid < n_sg) {
best_val = shared_val[slid];
best_idx = shared_idx[slid];
// Same butterfly reduction on 32 SG winners
for (uint offset = SIMD_WIDTH / 2; offset > 0; offset >>= 1) {
// ...
}
if (slid == 0) *result = best_idx;
}
The 32 SIMD group winners are written to threadgroup shared memory, then a single SIMD group performs a final butterfly reduction. The overall winner is written to the result buffer.
Total work: One strided scan of the vocabulary + two levels of reduction. For 128K vocabulary, this takes approximately 5-10 microseconds on Apple Silicon.
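The two-level structure can be emulated sequentially to convince yourself it finds the true maximum. The sketch below runs the strided per-thread scan and the reduction as plain loops (hypothetical helper, not akunu code):

```cpp
#include <cstdint>
#include <vector>
#include <limits>

// Sequential emulation of the argmax kernel: each "thread" scans a strided
// slice (phase 1), then the per-thread winners are reduced (phases 2-3).
uint32_t argmax_ref(const std::vector<float>& logits, uint32_t tg_size = 1024) {
    std::vector<float> best_val(tg_size, -std::numeric_limits<float>::infinity());
    std::vector<uint32_t> best_idx(tg_size, 0);
    for (uint32_t tid = 0; tid < tg_size; ++tid)
        for (uint32_t i = tid; i < logits.size(); i += tg_size)
            if (logits[i] > best_val[tid]) { best_val[tid] = logits[i]; best_idx[tid] = i; }
    uint32_t win = 0;   // reduction over per-thread winners
    for (uint32_t t = 1; t < tg_size; ++t)
        if (best_val[t] > best_val[win]) win = t;
    return best_idx[win];
}
```

The GPU version replaces the second loop with the SIMD butterfly plus the cross-SIMD-group pass, but the result is identical.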
Kernel 2: Gumbel Top-K (gumbel_topk_f16)
The Gumbel-max kernel is the most complex sampling kernel. It performs the complete sampling pipeline on the GPU in a single dispatch:
Dispatch: grid=(1), threadgroup=(1024)
One threadgroup of 1024 threads processes the entire vocabulary. The kernel executes four phases:
Phase 0: Repetition Penalty
float rep_penalty = params.repeat_penalty;
if (rep_penalty > 1.0f && params.position > 0) {
for (uint i = tid; i < params.position; i += 1024) {
uint token = token_ids[i + 1];
if (token < V) {
float val = float(logits[token]);
logits[token] = half(val > 0 ? val / rep_penalty : val * rep_penalty);
}
}
threadgroup_barrier(mem_flags::mem_device); // fence device-memory writes before Phase 1 reads
}
Each thread processes a strided subset of previously generated tokens, applying the same asymmetric penalty as the CPU path (divide positive logits, multiply negative logits by the penalty factor).
Phase 1: Find Global Maximum
float local_max = -INFINITY;
for (uint i = tid; i < V; i += 1024)
local_max = max(local_max, float(logits[i]));
local_max = simd_max(local_max);
if (slid == 0) tg_vals[sgid] = local_max;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tid == 0) {
float m = tg_vals[0];
for (uint s = 1; s < 32; s++) m = max(m, tg_vals[s]);
tg_vals[0] = m;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float global_max = tg_vals[0];
A standard two-level max reduction: SIMD-first with simd_max, then cross-SIMD via threadgroup memory.
Phase 2: Binary Search for K-th Largest
This is the most innovative part of the kernel. Instead of sorting the entire vocabulary to find the top-K, it uses binary search on the value domain to find the threshold value below which exactly K elements survive:
- Init: lo = global_max - 30, hi = global_max
- Iteration 1: mid = (lo + hi) / 2 – count elements > mid (1024 threads in parallel). If count > K: lo = mid, else: hi = mid
- Iteration 2: narrower range, count again, adjust bounds
- ... 12 iterations total (2^12 = 4096x precision refinement)
- Result: threshold separates top-K from rest
float threshold = global_max;
if (top_k > 1) {
float lo = global_max - 30.0f;
float hi = global_max;
for (int iter = 0; iter < 12; iter++) {
float mid = (lo + hi) * 0.5f;
uint local_count = 0;
for (uint i = tid; i < V; i += 1024)
local_count += (float(logits[i]) > mid) ? 1 : 0;
// SIMD reduction
uint sg_count = simd_sum(local_count);
if (slid == 0) tg_counts[sgid] = sg_count;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tid == 0) {
uint total = 0;
for (uint s = 0; s < 32; s++) total += tg_counts[s];
tg_counts[0] = total;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tg_counts[0] > (uint)top_k)
lo = mid; // too many survivors, raise threshold
else
hi = mid; // too few survivors, lower threshold
}
threshold = hi;
}
Each iteration of the binary search counts how many elements exceed the candidate threshold. The count uses a parallel strided pass (1024 threads, each checking ~128 elements), followed by a SIMD + threadgroup reduction.
12 iterations of binary search with a range of 30.0 gives a precision of 30 / 2^12 ≈ 0.007, which is more than sufficient to separate top-K from the rest. Each iteration requires 2 barriers and one full vocabulary scan, for a total of 24 barriers and 12 scans.
Why binary search instead of sorting? Sorting 128K elements on a GPU is expensive (O(N log N)). The binary search approach is O(N * log(range/precision)) ≈ O(N * 12), which is effectively O(N) and much faster. It does not find the exact K-th element – it finds a threshold that separates approximately K elements from the rest – but this is sufficient for sampling.1
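A serial sketch of the same search makes the invariant concrete: lo always admits more than K survivors, hi admits at most K, and the gap shrinks by half each iteration (illustrative C++, not the kernel):

```cpp
#include <cstdint>
#include <vector>
#include <algorithm>

// Host-side sketch of the value-domain binary search: find a threshold such
// that roughly `k` logits exceed it. Mirrors the kernel's 12-iteration loop.
float topk_threshold(const std::vector<float>& logits, uint32_t k) {
    float gmax = *std::max_element(logits.begin(), logits.end());
    float lo = gmax - 30.0f, hi = gmax;
    for (int iter = 0; iter < 12; ++iter) {
        float mid = (lo + hi) * 0.5f;
        uint32_t count = 0;
        for (float v : logits) count += (v > mid);
        if (count > k) lo = mid;   // too many survivors, raise threshold
        else           hi = mid;   // too few survivors, lower threshold
    }
    return hi;
}
```

With 1000 evenly spaced logits (0.01 apart) and K=40, the final precision of ~0.007 is below the spacing, so the threshold separates exactly 40 survivors.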
Phase 2b: Top-P Filtering
If top-p is enabled, the kernel performs a second binary search to find the probability threshold:
if (top_p < 1.0f && top_p > 0.0f) {
// Compute softmax sum for survivors
float local_exp_sum = 0;
for (uint i = tid; i < V; i += 1024) {
float val = float(logits[i]);
if (val >= threshold)
local_exp_sum += fast::exp(val - global_max);
}
// Reduce to get total_exp
// Binary search for probability mass threshold
float target_exp = top_p * total_exp;
float lo_p = threshold;
float hi_p = global_max;
for (int iter = 0; iter < 8; iter++) {
// Count exp sum above mid_p
// Adjust bounds
}
threshold = max(threshold, hi_p);
}
This is a binary search on the cumulative probability mass, not on raw logit values. 8 iterations are sufficient because the range is narrower (already within the top-K region).
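Since the kernel elides the inner steps, here is a serial sketch of the whole top-p search under the same structure (function name and details are illustrative, not akunu's):

```cpp
#include <cmath>
#include <vector>
#include <algorithm>

// Sketch: binary-search a logit threshold so that the softmax mass above it
// is roughly top_p of the mass of the current (top-K) survivors.
float topp_threshold(const std::vector<float>& logits,
                     float base_threshold, float top_p) {
    float gmax = *std::max_element(logits.begin(), logits.end());
    double total = 0.0;                          // mass of top-K survivors
    for (float v : logits) if (v >= base_threshold) total += std::exp(v - gmax);
    double target = top_p * total;
    float lo = base_threshold, hi = gmax;
    for (int iter = 0; iter < 8; ++iter) {
        float mid = (lo + hi) * 0.5f;
        double mass = 0.0;                       // mass strictly above mid
        for (float v : logits) if (v > mid) mass += std::exp(v - gmax);
        if (mass > target) lo = mid;             // too much mass, raise threshold
        else               hi = mid;             // too little, lower threshold
    }
    return std::max(base_threshold, hi);
}
```

For a sharply peaked distribution like {0, 0, 0, 10} with top_p = 0.5, the dominant token alone exceeds the target mass, so only it survives (tokens at or above the threshold are kept, matching the kernel's `val < threshold` mask).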
Phase 2c: Min-P Filtering
Min-P is the simplest filter – just a threshold relative to the maximum:
if (params.min_p > 0.0f && params.min_p < 1.0f) {
float min_p_threshold = global_max + log(params.min_p);
threshold = max(threshold, min_p_threshold);
}
Since P(token) ∝ exp(logit), the condition P(token) < min_p * P(max_token) becomes logit < max_logit + log(min_p). This is a single comparison with no iteration needed.
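The equivalence of the probability-domain test and the log-domain rewrite can be checked directly (sketch functions, not akunu code):

```cpp
#include <cmath>
#include <vector>
#include <algorithm>

// Min-p in the probability domain: keep token i if P(i) >= min_p * P(max).
bool minp_keep_prob(const std::vector<float>& logits, size_t i, float min_p) {
    float gmax = *std::max_element(logits.begin(), logits.end());
    double z = 0.0;
    for (float v : logits) z += std::exp(v - gmax);
    double p_i   = std::exp(logits[i] - gmax) / z;
    double p_max = 1.0 / z;                      // exp(gmax - gmax) / z
    return p_i >= min_p * p_max;
}

// The kernel's log-domain form: a single comparison, no softmax needed.
bool minp_keep_logit(const std::vector<float>& logits, size_t i, float min_p) {
    float gmax = *std::max_element(logits.begin(), logits.end());
    return logits[i] >= gmax + std::log(min_p);
}
```

With logits {1, 3, 5} and min_p = 0.1, the cutoff is 5 + ln(0.1) ≈ 2.70: token 0 is dropped, tokens 1 and 2 survive — and both formulations agree.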
Phase 3: Apply Mask and Gumbel Noise
uint element_seed = (params.position + params.seed_offset) * 2654435761u;
for (uint i = tid; i < V; i += 1024) {
float val = float(logits[i]);
if (val < threshold) {
logits[i] = half(-INFINITY);
} else {
float u = pcg_float(element_seed + i);
u = clamp(u, 1e-7f, 1.0f - 1e-7f);
float gumbel = -log(-log(u));
logits[i] = half(val + temp * gumbel);
}
}
Tokens below the threshold are masked to -inf (they can never win the argmax). Surviving tokens receive Gumbel noise scaled by temperature. The argmax of these perturbed logits is equivalent to sampling from the filtered distribution.2
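The trick is easy to demonstrate on the host. The sketch below uses std::mt19937 in place of the kernel's PCG hash (an assumption for testability; everything else follows the same noise-then-argmax structure):

```cpp
#include <cmath>
#include <random>
#include <array>

// Gumbel-max sampling sketch: argmax(logit + G), G ~ Gumbel(0,1), draws a
// sample from the softmax of the logits. Temperature fixed at 1 here.
template <size_t N>
size_t gumbel_sample(const std::array<float, N>& logits, std::mt19937& rng) {
    std::uniform_real_distribution<float> uni(1e-7f, 1.0f - 1e-7f);
    size_t best = 0;
    float best_val = -INFINITY;
    for (size_t i = 0; i < N; ++i) {
        float g = -std::log(-std::log(uni(rng)));  // Gumbel(0,1) noise
        float v = logits[i] + g;
        if (v > best_val) { best_val = v; best = i; }
    }
    return best;
}
```

Sampling repeatedly from logits {0, 1, 2} produces token frequencies ordered like the softmax probabilities (~9%, ~24%, ~67%).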
The PCG Hash RNG
inline float pcg_float(uint state) {
state = state * 747796405u + 2891336453u;
uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
word = (word >> 22u) ^ word;
return float(word) / 4294967296.0f;
}
PCG (Permuted Congruential Generator) is a stateless hash function: the same input always produces the same output. The constants are chosen to ensure good statistical properties:3
| Constant | Purpose |
|---|---|
| 747796405 | LCG multiplier (chosen for good spectral properties) |
| 2891336453 | LCG increment |
| 277803737 | Output permutation multiplier |
The output is uniformly distributed in [0, 1) after division by 2^32. The clamp to [1e-7, 1-1e-7] prevents log(0) in both the inner and outer logarithm of the Gumbel transform.
The seed (params.position + params.seed_offset) * 2654435761u + i ensures:
- Different noise per token position (via params.position, patched in chain decode)
- Different noise per vocabulary element (via + i)
- Different noise per generation call (via params.seed_offset, time-based)
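A host-side port of the hash (same arithmetic, C++ types) makes its two key properties easy to verify — determinism and a bounded output range:

```cpp
#include <cstdint>

// CPU port of the kernel's pcg_float: a stateless hash from a 32-bit seed to
// a float in [0, 1). Unsigned overflow wraps, exactly as in MSL.
inline float pcg_float_ref(uint32_t state) {
    state = state * 747796405u + 2891336453u;
    uint32_t word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
    word = (word >> 22u) ^ word;
    return float(word) / 4294967296.0f;
}
```

Because it is a pure function, the same (position, seed_offset, i) triple always yields the same noise — which is what makes chain-decode sampling reproducible for a fixed seed.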
Kernel 3: Top-K Select (topk_select_f16)
The top-K select kernel is an alternative to the Gumbel binary search approach. It explicitly selects the top-K logits using a min-heap algorithm:
Chunked Architecture
Grid: (n_chunks, 1, 1) — one threadgroup per chunk
Threadgroup: (chunk_size, 1, 1) — typically 512 threads
The vocabulary is divided into chunks, and each threadgroup finds its local top-K. The host then merges per-chunk results on the CPU.
Thread-Local Best
for (uint i = tid; i < chunk_size; i += tg_size) {
uint global_idx = base + i;
if (global_idx >= vocab_size) break;
float v = float(logits[global_idx]);
if (v > my_best_val) {
my_best_val = v;
my_best_idx = global_idx;
}
}
Each thread finds its single best candidate from a strided scan.
Min-Heap Selection
// Thread 0 picks top_k from tg_size candidates via min-heap
float heap_vals[MAX_TOP_K]; // MAX_TOP_K = 128
uint heap_idxs[MAX_TOP_K];
uint heap_size = 0;
for (uint s = 0; s < tg_size; ++s) {
float v = tg_vals[s];
uint gi = tg_idxs[s];
if (heap_size < k) {
// Insert and bubble up
heap_vals[heap_size] = v;
heap_idxs[heap_size] = gi;
heap_size++;
uint pos = heap_size - 1;
while (pos > 0) {
uint parent = (pos - 1) / 2;
if (heap_vals[parent] <= heap_vals[pos]) break; // min-heap property holds
// swap with parent
float tv = heap_vals[parent]; heap_vals[parent] = heap_vals[pos]; heap_vals[pos] = tv;
uint ti = heap_idxs[parent]; heap_idxs[parent] = heap_idxs[pos]; heap_idxs[pos] = ti;
pos = parent;
}
} else if (v > heap_vals[0]) {
// Replace min (heap root) and sift down
heap_vals[0] = v;
heap_idxs[0] = gi;
uint pos = 0;
while (true) {
uint l = 2*pos + 1, r = 2*pos + 2;
uint smallest = pos;
if (l < heap_size && heap_vals[l] < heap_vals[smallest]) smallest = l;
if (r < heap_size && heap_vals[r] < heap_vals[smallest]) smallest = r;
if (smallest == pos) break;
// swap with smallest child
float tv = heap_vals[smallest]; heap_vals[smallest] = heap_vals[pos]; heap_vals[pos] = tv;
uint ti = heap_idxs[smallest]; heap_idxs[smallest] = heap_idxs[pos]; heap_idxs[pos] = ti;
pos = smallest;
}
}
}
This is a classic min-heap of size K. The heap root always holds the smallest of the K largest values seen so far. When a new candidate exceeds the root, it replaces the root and is sifted down to maintain the heap property.
The complexity is O(tg_size * log(K)) per chunk. With tg_size=512 and K=128, this is about 512 * 7 = 3584 comparisons – fast enough for a single thread.
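The same replace-root selection can be expressed compactly on the host with the standard library's heap algorithms (std::greater turns them into a min-heap). A sketch, not akunu code:

```cpp
#include <vector>
#include <algorithm>
#include <functional>

// Keep the K largest values of a stream with a size-K min-heap:
// heap.front() is always the smallest value currently kept.
std::vector<float> topk_heap(const std::vector<float>& vals, size_t k) {
    std::vector<float> heap;
    for (float v : vals) {
        if (heap.size() < k) {                       // heap not full: insert
            heap.push_back(v);
            std::push_heap(heap.begin(), heap.end(), std::greater<float>());
        } else if (v > heap.front()) {               // beats current minimum
            std::pop_heap(heap.begin(), heap.end(), std::greater<float>());
            heap.back() = v;                         // replace root, sift down
            std::push_heap(heap.begin(), heap.end(), std::greater<float>());
        }
    }
    std::sort(heap.begin(), heap.end(), std::greater<float>());
    return heap;                                     // K largest, descending
}
```

The kernel implements the same logic with explicit bubble-up/sift-down because MSL has no std::push_heap, but the invariant is identical.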
Comparison with Gumbel Binary Search
| Aspect | Top-K Select | Gumbel Binary Search |
|---|---|---|
| Output | Exact top-K indices + values | Threshold value |
| CPU post-processing | Merge per-chunk results | None (followed by argmax) |
| Chain decode compatible | No (CPU merge needed) | Yes |
| Use case | CPU sampling path | GPU sampling path |
The top-K select kernel produces exact results but requires CPU post-processing, making it unsuitable for chain decode. The Gumbel binary search is approximate (the threshold may not select exactly K elements) but keeps everything on the GPU.
Kernel 4: Repetition Penalty (repetition_penalty_f16)
The standalone repetition penalty kernel applies the penalty to specific token positions in the logit buffer:
kernel void repetition_penalty_f16(
device half *logits, device const uint32_t *token_ids,
constant float &penalty, constant uint32_t &n_tokens,
uint tid [[thread_position_in_grid]]
) {
if (tid >= n_tokens) return;
uint32_t token = token_ids[tid];
float val = float(logits[token]);
logits[token] = half(val > 0 ? val / penalty : val * penalty);
}
Each thread handles one previous token. The dispatch is ceil(n_tokens / threadgroup_size) threadgroups. This kernel is used in the non-chain-decode path (when repetition penalty is applied as a separate step). In the Gumbel kernel, the penalty is applied inline (Phase 0) to avoid an extra dispatch.
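The per-token body is trivial to mirror on the CPU; the sketch below applies the same asymmetric rule sequentially (illustrative, not akunu code):

```cpp
#include <cstdint>
#include <vector>

// Asymmetric repetition penalty: divide positive logits by the penalty,
// multiply negative logits by it. Both push the logit toward "less likely".
void apply_repetition_penalty(std::vector<float>& logits,
                              const std::vector<uint32_t>& prev_tokens,
                              float penalty) {
    for (uint32_t t : prev_tokens) {
        float v = logits[t];
        logits[t] = v > 0 ? v / penalty : v * penalty;
    }
}
```

Note why the rule is asymmetric: simply dividing a negative logit by the penalty would make it *larger* (closer to zero), rewarding repetition instead of punishing it.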
Kernel 5: Temperature Scale (temperature_scale_f16)
kernel void temperature_scale_f16(
device half *logits, constant float &inv_temperature,
constant uint32_t &count, uint tid [[thread_position_in_grid]]
) {
if (tid >= count) return;
logits[tid] = half(float(logits[tid]) * inv_temperature);
}
A simple element-wise multiply used for:
- Temperature scaling before sampling (standalone path)
- Embedding scaling in prefill (Gemma)
- Any generic “scale all elements” operation
Putting It All Together: The GPU Sampling Pipeline
In chain decode with sampling enabled, the last two commands in the dispatch table are:
Command N-1: gumbel_topk_f16
- Reads: logits buffer, token_ids (previous tokens)
- Writes: logits buffer (in-place: mask + noise)
- Params: vocab_size, temperature, position (PATCHED), seed_offset,
top_k, top_p, repeat_penalty, min_p
Command N: argmax_f16
- Reads: logits buffer (now with Gumbel noise)
- Writes: token_ids[tok+1] (PATCHED output offset)
The position field in the Gumbel kernel and the output offset in the argmax kernel are patched per-token by the dispatch table mechanism, ensuring each token in the chain gets:
- A unique RNG seed (different position)
- Its own output slot in the token_ids buffer
The rest of the chain decode infrastructure (embedding lookup, layer loop, etc.) is identical to the greedy path. The only difference is two extra kernel dispatches per token.
Summary
Akunu’s sampling kernels demonstrate that sophisticated sampling algorithms can run entirely on the GPU:
- Argmax: Two-level butterfly reduction (SIMD + threadgroup) for O(N) global maximum with 1024 threads.
- Gumbel Top-K: Binary search on value domain for approximate top-K (O(N * iterations)), followed by Gumbel noise and argmax for exact sampling. PCG hash provides stateless RNG.
- Top-K Select: Explicit min-heap per chunk for exact top-K, used by the CPU sampling path.
- Repetition Penalty: Simple per-token penalty application, both standalone and inline.
The Gumbel-max approach is the key enabler for chain decode with sampling: by reducing sampling to noise + argmax, it eliminates the CPU roundtrip that would otherwise break the chain.
Performance Characteristics
The sampling kernels are all dispatched with a single threadgroup, making them latency-sensitive rather than throughput-sensitive. Here is the approximate timing breakdown:
| Kernel | Threads | Iterations | Time (128K vocab) | Bottleneck |
|---|---|---|---|---|
| argmax_f16 | 1024 | 1 scan + 2 reductions | ~5-8 us | Memory bandwidth |
| gumbel_topk_f16 | 1024 | 1 + 12 + 8 + 1 scans | ~30-50 us | Memory bandwidth |
| repetition_penalty_f16 | varies | 1 scan of prev tokens | ~1-5 us | Memory latency |
| temperature_scale_f16 | varies | 1 scan | ~3-5 us | Memory bandwidth |
The Gumbel kernel is the most expensive at ~30-50 microseconds, but this is negligible compared to the ~8ms forward pass. Even in chain decode with 128 tokens, the total sampling overhead is only ~4-6ms (128 * 35us), or less than 1% of the total generation time.
Statistical Quality of GPU Sampling
A natural concern: does the GPU sampling path produce the same distribution as the CPU path? Mathematically, yes – the Gumbel-max trick provides exact samples from the categorical distribution, not an approximation.4 However, there are practical differences:
- RNG quality: The CPU uses a Mersenne Twister (std::mt19937) with 19937 bits of state, while the GPU uses a PCG hash with effectively 32 bits of state per element. The PCG hash has been validated against TestU01’s BigCrush battery, but its per-element independence relies on the hash quality.
- Top-K precision: The CPU path uses exact top-K via std::partial_sort, while the GPU path uses binary search to find an approximate threshold. The binary search may select slightly more or fewer than K elements, but the difference is bounded by the search precision (K +/- 1 at the threshold boundary).
- Top-P precision: Similarly, the GPU’s binary search for the top-P threshold has 8 iterations of precision (~0.03 logit units), which may include or exclude borderline tokens that the CPU path would handle differently.
In practice, these differences are undetectable in the output quality. The sampling distribution is dominated by the high-probability tokens (which are always included by both paths), and the borderline tokens near the cutoff thresholds have negligible probability mass.
Grammar Bitmask Kernel
While not in the sampling/ directory, it is worth mentioning the grammar_bitmask.metal kernel that can apply grammar constraints on the GPU:
grammar_bitmask.metal: Apply a precomputed bitmask to logits
- Sets logits[i] = -inf where bitmask bit i is 0
- Used for XGrammar integration on the GPU path
This kernel enables a hybrid approach: the grammar state machine runs on the CPU to compute the bitmask, but the bitmask application (which touches every logit) runs on the GPU. For vocabularies of 128K+, this saves a significant amount of CPU memory bandwidth.
-
The binary search approach for GPU top-K was pioneered in approximate nearest neighbor search. See: Johnson, J., Douze, M., and Jegou, H. “Billion-Scale Similarity Search with GPUs.” IEEE Transactions on Big Data, 2019. The same principle applies to logit filtering: we do not need the exact K-th element, just a threshold that selects approximately K candidates. ↩
-
Gumbel, E.J. “Statistical Theory of Extreme Values and Some Practical Applications.” National Bureau of Standards Applied Mathematics Series 33, 1954. The Gumbel-max trick for exact categorical sampling is proven in: Maddison, C.J., Tarlow, D., and Minka, T. “A* Sampling.” NeurIPS 2014. See https://arxiv.org/abs/1411.0030. ↩
-
O’Neill, M.E. “PCG: A Family of Simple Fast Space-Efficient Statistically Good Algorithms for Random Number Generation.” Harvey Mudd College Technical Report HMC-CS-2014-0905, 2014. The PCG hash provides excellent statistical properties (passing TestU01’s BigCrush) with minimal state, making it ideal for GPU kernels. See https://www.pcg-random.org/paper.html. ↩
-
The exactness of the Gumbel-max trick is proven in Theorem 1 of Maddison et al. (2014). The proof shows that for any discrete distribution pi, argmax(log(pi_i) + G_i) where G_i are i.i.d. Gumbel(0,1) random variables yields a sample exactly distributed as Categorical(pi). No approximation is involved. ↩
The Weight Provider Abstraction
Alright, let’s talk about one of those pieces of engineering that doesn’t get enough credit: the weight provider. In most inference engines, you’ll find a hard coupling between the model loading code and whatever file format the weights come in. GGUF models go through one path, PyTorch checkpoints through another, SafeTensors through yet another. Each path has its own quirks, its own name mangling, its own way of handing you a tensor.
Akunu takes a different approach. It puts a clean abstraction layer – WeightProvider
– between the model code and the file format. The model doesn’t know or care whether
its weights came from a GGUF file or an MLX-formatted SafeTensors directory. It asks
for layers.3.attention.q.weight, and it gets back a GPU buffer. Period.
This chapter is about that abstraction: how it works, why it exists, and what makes it interesting from a systems design perspective.
The Problem: Two Worlds, One Interface
Let’s set the stage. Akunu needs to load weights from two very different ecosystems:
- GGUF files – The format popularized by llama.cpp. A single monolithic file containing all tensors, metadata, and quantization parameters. Tensor names follow the blk.{n}.attn_q.weight convention.
- MLX SafeTensors – Apple’s MLX framework exports models as a directory containing config.json plus one or more .safetensors files. Tensor names follow the HuggingFace model.layers.{n}.self_attn.q_proj.weight convention. Quantized models pack weights, scales, and biases as three separate tensors.
These formats differ in almost every dimension:
+-------------------+----------------------------+----------------------------+
| Dimension | GGUF | MLX SafeTensors |
+-------------------+----------------------------+----------------------------+
| File structure | Single .gguf file | Directory with config.json |
| | | + model.safetensors |
+-------------------+----------------------------+----------------------------+
| Metadata | KV pairs in binary header | JSON config.json |
+-------------------+----------------------------+----------------------------+
| Tensor names | blk.0.attn_q.weight | model.layers.0.self_attn. |
| | | q_proj.weight |
+-------------------+----------------------------+----------------------------+
| Quantization | Block-level (Q4_0, Q6_K) | Per-group with separate |
| | embedded in tensor data | .scales + .biases |
+-------------------+----------------------------+----------------------------+
| Data types | 30+ GGML types | F16, BF16, F32, U32, I8 |
+-------------------+----------------------------+----------------------------+
| Tensor data | Contiguous in data section | Contiguous after header |
+-------------------+----------------------------+----------------------------+
The model code doesn’t want to deal with any of this. It wants a canonical name, and
it wants bytes on the GPU. The WeightProvider is the bridge.
The Strategy Pattern in Action
Let’s look at the actual class definition from weight_provider.h:
class WeightProvider {
public:
enum Format { GGUF, MLX_SAFETENSORS };
WeightProvider(Device& device) : device_(device) {}
~WeightProvider() { close(); }
bool open(const std::string& path);
void close();
Format format() const { return format_; }
AkunuModelConfig get_config() const;
Buffer get_tensor(const std::string& name);
uint32_t get_dtype(const std::string& name) const;
bool has_tensor(const std::string& name) const;
// Metadata access
std::string get_metadata_string(const std::string& key) const;
int64_t get_metadata_int(const std::string& key, int64_t def = 0) const;
float get_metadata_float(const std::string& key, float def = 0.0f) const;
std::vector<std::string> get_string_array(const std::string& key) const;
std::vector<float> get_float_array(const std::string& key) const;
// Tensor listing
int tensor_count() const;
std::string tensor_name_at(int index) const;
// MLX quantization info
int quant_bits() const;
int quant_group_size() const;
// Weight fusion
Buffer fuse_weights(const std::string& a, const std::string& b);
Buffer fuse_weights(const std::string& a, const std::string& b,
const std::string& c);
private:
Device& device_;
Format format_ = GGUF;
std::unique_ptr<WeightStore> gguf_;
std::unique_ptr<MLXWeightStore> mlx_;
static Format detect_format(const std::string& path);
};
If you squint, this is a textbook Strategy pattern. The WeightProvider holds a
pointer to one of two concrete implementations – WeightStore (for GGUF) or
MLXWeightStore (for SafeTensors) – and delegates every operation to whichever
one is active. But it’s not using virtual functions and inheritance. Instead, it
uses a simpler discriminated-union approach: an enum plus two unique_ptrs.
Why not virtual dispatch? Probably because there are only two backends and the
method set is well-defined. A vtable adds indirection for no real benefit here.
The explicit delegation is clear, debuggable, and has essentially zero overhead:
the single branch on format_ is always predicted correctly, since the format
never changes after open().
Here’s the delegation pattern for get_tensor():
Buffer get_tensor(const std::string& name) {
return (format_ == MLX_SAFETENSORS)
? mlx_->get_tensor(name)
: gguf_->get_tensor(name);
}
Every method follows this exact pattern. Clean, predictable, no surprises.
Format Detection: Simpler Than You’d Think
When you call open(), the first thing that happens is format detection. And it’s
refreshingly simple:
static Format detect_format(const std::string& path) {
// Directory or .safetensors -> MLX
struct stat st;
if (stat(path.c_str(), &st) == 0 && S_ISDIR(st.st_mode))
return MLX_SAFETENSORS;
if (path.size() > 12 &&
path.substr(path.size() - 12) == ".safetensors")
return MLX_SAFETENSORS;
return GGUF;
}
That’s it. Two checks and a fallback:
- Is it a directory? Then it’s an MLX model directory (containing config.json and model.safetensors).
- Does the filename end in .safetensors? Same conclusion.
- Everything else? GGUF.
No magic number sniffing, no content-based detection. This works because in
practice, users either point at a .gguf file or an MLX model directory. The
heuristic is simple, fast, and correct for the actual use cases.
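The suffix half of the heuristic is worth seeing in isolation. A standalone sketch (the directory check via stat() is omitted; the helper name is mine, not akunu's):

```cpp
#include <string>

// Does the path end in ".safetensors"? Matches the original's guard that the
// path must be strictly longer than the 12-character suffix itself.
bool looks_like_safetensors(const std::string& path) {
    const std::string suffix = ".safetensors";
    return path.size() > suffix.size() &&
           path.compare(path.size() - suffix.size(), suffix.size(), suffix) == 0;
}
```

The `size() > suffix.size()` guard means a file literally named `.safetensors` (a hidden file with no stem) is not treated as a model — a small but deliberate detail of the original check.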
The flow looks like this:
open("/path/to/model")
|
v
+--------------------+
| detect_format() |
+--------------------+
/ \
Directory or Everything
.safetensors? else
| |
v v
+----------------+ +----------------+
| MLXWeightStore | | WeightStore |
| .open() | | .open() |
+----------------+ +----------------+
| |
v v
Parse config.json Parse GGUF header
Open .safetensors mmap entire file
Build name map Build name map
Once the backend is created and opened, the WeightProvider is ready. All
subsequent calls go through the chosen backend.
The Canonical Name System
This is one of the most important design decisions in the weight system, and it’s worth understanding in detail. Different model formats use different naming conventions for the same tensors:
GGUF Convention MLX Convention
---------------- ----------------
token_embd.weight <---> model.embed_tokens.weight
blk.0.attn_q.weight <---> model.layers.0.self_attn.q_proj.weight
blk.0.ffn_gate.weight <---> model.layers.0.mlp.gate_proj.weight
output_norm.weight <---> model.norm.weight
Neither of these is what akunu’s model code uses. Instead, akunu defines its own canonical naming scheme:
Canonical Name Purpose
--------------------------------- ---------------------------
token_embedding.weight Token embedding matrix
layers.{n}.attention.q.weight Q projection, layer n
layers.{n}.attention.k.weight K projection, layer n
layers.{n}.attention.v.weight V projection, layer n
layers.{n}.attention.output.weight Output projection, layer n
layers.{n}.ffn.gate.weight SwiGLU gate projection
layers.{n}.ffn.up.weight SwiGLU up projection
layers.{n}.ffn.down.weight Down projection
layers.{n}.attention_norm.weight Pre-attention RMSNorm
layers.{n}.ffn_norm.weight Pre-FFN RMSNorm
output_norm.weight Final RMSNorm
output.weight LM head
Each backend maintains a bidirectional mapping between its format-specific names
and these canonical names. When the model code asks for
layers.3.attention.q.weight, the GGUF backend translates that to
blk.3.attn_q.weight, and the MLX backend translates it to
model.layers.3.self_attn.q_proj.weight.
The mapping is built at load time. Here’s the GGUF side, from weight_store.cpp:
static const struct {
const char *gguf;
const char *canonical;
} kBaseRules[] = {
{"token_embd.weight", "token_embedding.weight"},
{"blk.{n}.attn_q.weight", "layers.{n}.attention.q.weight"},
{"blk.{n}.attn_k.weight", "layers.{n}.attention.k.weight"},
{"blk.{n}.attn_v.weight", "layers.{n}.attention.v.weight"},
{"blk.{n}.attn_output.weight", "layers.{n}.attention.output.weight"},
{"blk.{n}.attn_norm.weight", "layers.{n}.attention_norm.weight"},
{"blk.{n}.ffn_gate.weight", "layers.{n}.ffn.gate.weight"},
{"blk.{n}.ffn_up.weight", "layers.{n}.ffn.up.weight"},
{"blk.{n}.ffn_down.weight", "layers.{n}.ffn.down.weight"},
{"blk.{n}.ffn_norm.weight", "layers.{n}.ffn_norm.weight"},
// ... plus bias tensors, QK-norm, Gemma post-norms, etc.
};
And the MLX side, from mlx_weight_store.h:
static const MLXNameRule kMLXRules[] = {
{"model.embed_tokens.weight", "token_embedding.weight"},
{"model.norm.weight", "output_norm.weight"},
{"lm_head.weight", "output.weight"},
{"model.layers.{n}.self_attn.q_proj.weight",
"layers.{n}.attention.q.weight"},
{"model.layers.{n}.mlp.gate_proj.weight",
"layers.{n}.ffn.gate.weight"},
// ... and so on
};
The {n} placeholder is a neat trick. During build_name_mapping(), each rule
is expanded for every layer in the model:
void WeightStore::build_name_mapping() {
    int n_layers = get_metadata_int(arch + ".block_count", 0);
    for (int r = 0; r < kNumBaseRules; r++) {
        const char *pattern   = kBaseRules[r].gguf;
        const char *canonical = kBaseRules[r].canonical;
        if (strstr(pattern, "{n}")) {
            for (int layer = 0; layer < n_layers; layer++) {
                std::string gguf_name = expand_rule(pattern, layer);
                if (gguf_get_tensor(gguf_, gguf_name.c_str())) {
                    name_map_[expand_rule(canonical, layer)]
                        = gguf_name;
                }
            }
        } else {
            if (gguf_get_tensor(gguf_, pattern))
                name_map_[canonical] = pattern;
        }
    }
}
The existence check (gguf_get_tensor()) is important. Not every model has every
tensor. Some models have QK-norm weights, some don’t. Some have bias tensors, most
don’t. By checking for existence, the mapping only includes tensors that are
actually present in the file.
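The expansion itself is simple string substitution. Here's a minimal sketch of what an expand_rule() helper might look like – the real implementation in weight_store.cpp may differ, but the behavior is the same:

```cpp
#include <cassert>
#include <string>

// Replace the "{n}" placeholder in a rule pattern with a concrete layer
// index, e.g. ("blk.{n}.attn_q.weight", 3) -> "blk.3.attn_q.weight".
// Sketch only; akunu's actual expand_rule() may be implemented differently.
static std::string expand_rule(const std::string &pattern, int layer) {
    std::string::size_type pos = pattern.find("{n}");
    if (pos == std::string::npos) return pattern;  // no placeholder: as-is
    std::string out = pattern;
    out.replace(pos, 3, std::to_string(layer));    // "{n}" is 3 chars
    return out;
}
```

Patterns without a placeholder (like output_norm.weight) pass through unchanged, which is why the non-{n} branch of build_name_mapping() can use the pattern directly.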
The Data Flow: From File to GPU
Let’s trace the complete path of a weight tensor from disk to GPU. We’ll use the GGUF path since it’s more straightforward, but the MLX path follows the same high-level structure.
Model code calls:
provider.get_tensor("layers.5.ffn.gate.weight")
|
v
WeightProvider delegates to WeightStore
|
v
WeightStore::get_tensor("layers.5.ffn.gate.weight")
|
+-- Check buffer_cache_ (hit? return cached buffer)
|
+-- Lookup in name_map_:
| "layers.5.ffn.gate.weight" -> "blk.5.ffn_gate.weight"
|
+-- load_tensor_raw("blk.5.ffn_gate.weight")
|
+-- gguf_get_tensor(gguf_, "blk.5.ffn_gate.weight")
| Returns: GGUFTensorInfo { dtype=Q4_0, offset=0x1234,
| n_elements=14336*4096 }
|
+-- gguf_tensor_data(gguf_, info)
| Returns: pointer into mmap'd region (zero-copy!)
|
+-- Compute byte size from dtype:
| Q4_0: n_elements / 32 * 18 bytes
|
+-- dtype == F32 or BF16? Convert to F16
| Otherwise: direct copy to GPU
|
+-- device_.allocate(data, bytes)
Returns: Buffer { handle, size, contents }
A few things to note:
mmap is doing the heavy lifting. The GGUF parser memory-maps the entire file. When we need tensor data, we just compute a pointer into the mapped region. There’s no explicit read, no buffer allocation for the raw data. The OS handles paging in the data on demand. This means opening a 4GB model file is nearly instantaneous – the actual I/O happens lazily when we first touch each tensor’s bytes.
Lazy loading with caching. Tensors are loaded on first access and cached in
buffer_cache_. Once a tensor is on the GPU, subsequent requests for the same
tensor return the cached buffer. This is important because during inference, the
same weights are used on every forward pass.
Format conversion at load time. F32 and BF16 tensors are converted to F16 during loading. The GPU kernels expect F16, so this conversion happens exactly once. For quantized types (Q4_0, Q6_K, etc.), the data is copied as-is – the dequantization happens in the compute kernels.
The byte-size computation for quantized types is a lookup that maps dtype to a formula based on block size:
Type Block Size Bytes/Block Formula
------ ---------- ----------- -------------------------
Q4_0 32 elements 18 bytes n / 32 * 18
Q4_1 32 elements 20 bytes n / 32 * 20
Q5_0 32 elements 22 bytes n / 32 * 22
Q8_0 32 elements 34 bytes n / 32 * 34
Q2_K 256 elements 84 bytes n / 256 * 84
Q3_K 256 elements 110 bytes n / 256 * 110
Q4_K 256 elements 144 bytes n / 256 * 144
Q5_K 256 elements 176 bytes n / 256 * 176
Q6_K 256 elements 210 bytes n / 256 * 210
Q8_K 256 elements 292 bytes n / 256 * 292
We’ll cover the details of these formats in the quantization chapter.
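That lookup is easy to express as a small table-driven helper. The sketch below (the names are mine, not akunu's) computes the byte size for a few of the quantized types from the table above:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Per-dtype block geometry, matching the table above.
struct BlockSpec { uint32_t block_elems; uint32_t block_bytes; };

// Sketch: map a few common GGUF dtype codes to their block geometry.
static BlockSpec spec_for(uint32_t dtype) {
    switch (dtype) {
        case 2:  return {32, 18};    // Q4_0
        case 3:  return {32, 20};    // Q4_1
        case 6:  return {32, 22};    // Q5_0
        case 8:  return {32, 34};    // Q8_0
        case 10: return {256, 84};   // Q2_K
        case 12: return {256, 144};  // Q4_K
        case 14: return {256, 210};  // Q6_K
        default: return {1, 0};      // not handled in this sketch
    }
}

// n / block_elems * block_bytes -- the formula column from the table.
static size_t tensor_bytes(uint32_t dtype, uint64_t n_elements) {
    BlockSpec s = spec_for(dtype);
    return (size_t)(n_elements / s.block_elems) * s.block_bytes;
}
```

For a 4096x4096 Q4_0 tensor (16,777,216 elements), this yields 9,437,184 bytes – the same number we'll see again in the GGUF chapter's tensor-info example.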
Weight Fusion: Gate+Up and Q+K+V
One of the most performance-critical operations in the weight provider is weight fusion. The idea is simple: instead of doing two (or three) separate matrix multiplications and then combining the results, we concatenate the weight matrices and do a single, larger matmul.
For SwiGLU-based FFN layers, the gate and up projections can be fused:
Before fusion (2 matmuls):
gate_out = x @ gate_weight (dim -> ffn_dim)
up_out = x @ up_weight (dim -> ffn_dim)
After fusion (1 matmul):
fused_out = x @ [gate_weight; up_weight] (dim -> 2*ffn_dim)
gate_out = fused_out[:ffn_dim]
up_out = fused_out[ffn_dim:]
Similarly, Q, K, and V projections can be fused when they share the same input:
Before fusion (3 matmuls):
q = x @ q_weight (dim -> q_dim)
k = x @ k_weight (dim -> kv_dim)
v = x @ v_weight (dim -> kv_dim)
After fusion (1 matmul):
fused = x @ [q_weight; k_weight; v_weight] (dim -> q_dim+2*kv_dim)
The fuse_weights() methods handle this concatenation. For GGUF, it’s
straightforward – just concatenate the raw bytes:
Buffer WeightStore::fuse_weights(const std::string& name_a,
                                 const std::string& name_b) {
    std::string key = name_a + "+" + name_b;
    auto it = fused_cache_.find(key);
    if (it != fused_cache_.end()) return it->second;
    Buffer a = get_tensor(name_a);
    Buffer b = get_tensor(name_b);
    size_t total = a.size + b.size;
    Buffer fused = device_.allocate(total);
    memcpy(fused.contents, a.contents, a.size);
    memcpy((char*)fused.contents + a.size, b.contents, b.size);
    fused_cache_[key] = fused;
    return fused;
}
For MLX quantized models, fusion is more involved because each tensor is actually
a packed triple of [weights | scales | biases]. You can’t just concatenate the
whole buffers – you need to concatenate each section separately:
Input buffers (each is a packed triple):
Tensor A: [ A_weights | A_scales | A_biases ]
Tensor B: [ B_weights | B_scales | B_biases ]
Fused output:
[ A_weights | B_weights | A_scales | B_scales | A_biases | B_biases ]
|<--- all weights --->|<--- all scales --->|<--- all biases --->|
This layout is critical for the GPU kernel, which expects to find all weights
contiguous, then all scales contiguous, then all biases contiguous. The
fuse_mlx_packed() helper function handles this three-way interleaving.
Here’s the ASCII diagram of the full fusion pipeline for three tensors (Q+K+V):
Q buffer: [ Q_w (N_q*K_packed*4 bytes) | Q_s (N_q*K/gs*2) | Q_b (N_q*K/gs*2) ]
K buffer: [ K_w (N_k*K_packed*4 bytes) | K_s (N_k*K/gs*2) | K_b (N_k*K/gs*2) ]
V buffer: [ V_w (N_v*K_packed*4 bytes) | V_s (N_v*K/gs*2) | V_b (N_v*K/gs*2) ]
|
v
fuse_mlx_packed()
|
v
Fused: [ Q_w | K_w | V_w | Q_s | K_s | V_s | Q_b | K_b | V_b ]
|<-- total_w --->| |<-- total_s*2 -->| |<-- total_s*2 -->|
The fusion result is also cached (keyed by the concatenation of canonical names), so subsequent forward passes reuse the fused buffer.
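The three-way interleave is mechanical but easy to get wrong. Here's a CPU-side sketch of the section-by-section concatenation – akunu's fuse_mlx_packed() writes into a GPU buffer and uses its own types, but the copy pattern is the same:

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <string>
#include <vector>

// One packed MLX tensor: [weights | scales | biases] in a single buffer.
// Field names here are illustrative, not akunu's.
struct PackedTensor {
    const void *data;
    size_t w_bytes, s_bytes, b_bytes;  // section sizes within 'data'
};

// Concatenate section-by-section: all weights first, then all scales,
// then all biases -- the layout the GPU kernel expects.
static std::vector<char> fuse_packed(const std::vector<PackedTensor> &ts) {
    size_t total = 0;
    for (const auto &t : ts) total += t.w_bytes + t.s_bytes + t.b_bytes;
    std::vector<char> out(total);
    char *dst = out.data();
    // Weights section: copy w_bytes from the start of each input.
    for (const auto &t : ts) {
        memcpy(dst, t.data, t.w_bytes);
        dst += t.w_bytes;
    }
    // Scales section: skip past each input's weights.
    for (const auto &t : ts) {
        memcpy(dst, (const char*)t.data + t.w_bytes, t.s_bytes);
        dst += t.s_bytes;
    }
    // Biases section: skip past weights and scales.
    for (const auto &t : ts) {
        memcpy(dst, (const char*)t.data + t.w_bytes + t.s_bytes, t.b_bytes);
        dst += t.b_bytes;
    }
    return out;
}
```

Note that a naive whole-buffer concatenation would interleave scales between weight sections, which is exactly the corruption this function exists to avoid.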
Config Extraction: Two Paths to the Same Struct
The get_config() method returns an AkunuModelConfig struct regardless of the
source format. But the two backends extract this config very differently.
GGUF path: Config lives in the binary metadata KV pairs. Keys are prefixed with the architecture name:
Key Example Value
----------------------------------------- ------------
general.architecture "llama"
llama.embedding_length 4096
llama.block_count 32
llama.attention.head_count 32
llama.attention.head_count_kv 8
llama.feed_forward_length 14336
llama.context_length 8192
llama.rope.freq_base 500000.0
llama.attention.layer_norm_rms_epsilon 1e-5
The WeightStore::get_config() method tries architecture-prefixed keys first,
then falls back to unqualified keys. This is because some GGUF files use
llama.block_count while others use just block_count.
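The prefixed-then-unqualified lookup is a two-step probe. A sketch of the pattern (using a std::unordered_map as a stand-in for the GGUF metadata store – the real code queries the parser):

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <unordered_map>

// Sketch of "architecture-prefixed first, then unqualified" config lookup.
static int64_t get_config_int(
        const std::unordered_map<std::string, int64_t> &kv,
        const std::string &arch, const std::string &key, int64_t def) {
    auto it = kv.find(arch + "." + key);   // e.g. "llama.block_count"
    if (it != kv.end()) return it->second;
    it = kv.find(key);                     // fallback: bare "block_count"
    if (it != kv.end()) return it->second;
    return def;                            // neither present: default
}
```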
MLX path: Config lives in config.json, a standard HuggingFace config file:
{
"model_type": "llama",
"hidden_size": 4096,
"num_hidden_layers": 32,
"num_attention_heads": 32,
"num_key_value_heads": 8,
"intermediate_size": 14336,
"max_position_embeddings": 8192,
"rope_theta": 500000.0,
"rms_norm_eps": 1e-5,
"quantization_config": {
"bits": 4,
"group_size": 64
}
}
The MLXWeightStore::parse_config_json() method uses a minimal JSON parser
(hand-rolled, no dependencies) to extract these values. Note the different key
names: HuggingFace uses hidden_size where GGUF uses embedding_length, and
num_hidden_layers where GGUF uses block_count.
Both paths populate the same AkunuModelConfig struct:
typedef struct {
uint32_t dim; // embedding_length / hidden_size
uint32_t n_layers; // block_count / num_hidden_layers
uint32_t n_heads; // head_count / num_attention_heads
uint32_t n_kv_heads; // head_count_kv / num_key_value_heads
uint32_t head_dim; // explicit or dim/n_heads
uint32_t q_dim; // n_heads * head_dim
uint32_t kv_dim; // n_kv_heads * head_dim
uint32_t ffn_dim; // feed_forward_length / intermediate_size
uint32_t vocab_size;
uint32_t max_seq_len;
float norm_eps;
float rope_theta;
uint32_t sliding_window_pattern;
float rope_local_theta;
char architecture[32];
// ... encoder fields for Whisper
} AkunuModelConfig;
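Note that head_dim, q_dim, and kv_dim are derived rather than read directly. A sketch of the derivation, following the comments in the struct above:

```cpp
#include <cassert>
#include <cstdint>

struct DerivedDims { uint32_t head_dim, q_dim, kv_dim; };

// Sketch: compute derived dimensions from the raw config values.
// head_dim may be stated explicitly by some architectures; otherwise
// it falls back to dim / n_heads.
static DerivedDims derive(uint32_t dim, uint32_t n_heads,
                          uint32_t n_kv_heads, uint32_t explicit_head_dim) {
    DerivedDims d;
    d.head_dim = explicit_head_dim ? explicit_head_dim : dim / n_heads;
    d.q_dim    = n_heads * d.head_dim;     // full query width
    d.kv_dim   = n_kv_heads * d.head_dim;  // smaller under GQA
    return d;
}
```

For the Llama-style config shown above (dim=4096, 32 heads, 8 KV heads), this gives head_dim=128, q_dim=4096, and kv_dim=1024 – the kv_dim shrink is what grouped-query attention buys you.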
Architecture-Specific Handling
The weight system isn’t just a dumb loader. It has architecture-specific logic for several model families.
Whisper
Whisper is an encoder-decoder model, which means it has two sets of layers with
completely different tensor names. The GGUF backend has a separate rule table
(kWhisperRules) with entries for both encoder and decoder:
GGUF Name Canonical Name
---------------------------------------- --------------------------------
encoder.conv1.weight enc.conv1.weight
encoder.blocks.0.attn.query.weight enc.layers.0.attn.q.weight
decoder.blocks.0.attn.query.weight layers.0.attention.q.weight
decoder.blocks.0.cross_attn.query.weight layers.0.cross_attn.q.weight
The config extraction also handles Whisper specially, populating the encoder-
specific fields (enc_n_layers, enc_n_heads, n_mels, etc.) from
whisper.encoder.* metadata keys.
Gemma 3
Gemma 3 uses a sliding window attention pattern where every 6th layer has
global attention and the rest use local/sliding-window attention. The config
stores this as sliding_window_pattern = 6 and rope_local_theta = 10000.0.
Both GGUF and MLX backends detect this pattern when they see the gemma
architecture with a non-zero sliding window size.
QK-Norm (Qwen3)
Some newer models like Qwen3 add separate RMSNorm layers for Q and K projections before the attention computation. Both rule tables include mappings for these:
GGUF: blk.{n}.attn_q_norm.weight -> layers.{n}.attention.q_norm.weight
MLX: model.layers.{n}.self_attn.q_norm.weight -> (same canonical)
The Caching Architecture
Let’s look at the complete caching picture across the system:
+--------------------------------------------------+
| WeightProvider |
+--------------------------------------------------+
| |
v v
+----------------------+ +----------------------+
| WeightStore | | MLXWeightStore |
| (GGUF backend) | | (SafeTensors) |
+----------------------+ +----------------------+
| buffer_cache_: | | buffer_cache_: |
| canonical -> GPU | | canonical -> GPU |
| fused_cache_: | | "a+b" -> GPU |
| "a+b" -> GPU | | "a+b+c" -> GPU |
+----------------------+ +----------------------+
| |
v v
+-------------------+ +-------------------+
| mmap'd GGUF | | mmap'd .safe- |
| (OS page cache) | | tensors file |
+-------------------+ +-------------------+
There are effectively three levels of caching:

1. OS page cache: The mmap’d files are backed by the OS page cache. First access to a tensor’s bytes triggers a page fault and disk read. Subsequent accesses are served from RAM.

2. GPU buffer cache: Once a tensor is uploaded to the GPU (via device_.allocate()), the result is cached in buffer_cache_. The model never re-uploads a tensor.

3. Fused buffer cache: Fused weight combinations are cached separately in fused_cache_ (GGUF) or in buffer_cache_ with composite keys like "layers.0.ffn.gate.weight+layers.0.ffn.up.weight" (MLX).
This means the steady-state memory picture during inference is: the original file is mmap’d but mostly paged out (the OS will reclaim those pages under memory pressure), and the weights live in GPU-accessible buffers.
Metadata Access: A Leaky Abstraction?
One area where the abstraction gets a bit leaky is metadata access. The
get_metadata_string(), get_metadata_int(), and get_metadata_float()
methods are primarily used for GGUF metadata (which is rich and structured).
The MLX backend’s implementations are stubs that return defaults:
// mlx_weight_store.cpp
std::string MLXWeightStore::get_metadata_string(const std::string&) const {
return "";
}
int64_t MLXWeightStore::get_metadata_int(const std::string&, int64_t def) const {
return def;
}
This makes sense when you think about it. GGUF metadata contains everything:
model architecture, tokenizer vocabulary, RoPE parameters, you name it. MLX
models store their config in config.json (parsed at open time into
AkunuModelConfig) and their tokenizer in a separate tokenizer.json file.
The metadata methods exist because the tokenizer system needs access to GGUF’s
embedded vocabulary arrays (tokenizer.ggml.tokens, tokenizer.ggml.scores).
For MLX models, the tokenizer is loaded from tokenizer.json through a
completely separate path, so these methods are never called.
Is this a code smell? Maybe. But it’s a pragmatic choice. Adding a separate tokenizer abstraction just to avoid empty stubs would be over-engineering.
Tensor Inspection: The Debug Interface
The tensor_count(), tensor_name_at(), and tensor_raw_dtype() methods
form a debug/inspection interface. These are used by akunu’s model inspection
tool to list all tensors in a weight file without loading them onto the GPU:
Index Name DType Elements
----- ---------------------------------------- --------- -----------
0 token_embd.weight Q8_0 524288000
1 blk.0.attn_q.weight Q4_K 16777216
2 blk.0.attn_k.weight Q4_K 4194304
...
199 output.weight Q8_0 524288000
This is purely for human consumption. The model code never uses these methods.
The Buffer Type
Throughout this chapter, we’ve been passing around Buffer objects. Let’s
clarify what this actually is:
struct Buffer {
void* handle; // Opaque GPU handle (MTLBuffer* on Metal)
size_t size; // Buffer size in bytes
void* contents; // CPU-accessible pointer (shared memory on Apple Silicon)
};
On Apple Silicon, GPU and CPU share the same physical memory, so contents
points directly to the GPU buffer’s storage. This means memcpy into
contents is the same as uploading to the GPU – there’s no separate DMA
transfer step. This is what makes the weight loading so fast: mmap the file,
memcpy from the mmap into the GPU buffer, done.
The {nullptr, 0, nullptr} triple is used as a sentinel for “not found” or
“error”. You’ll see this returned throughout the code as the failure case.
Putting It All Together
Let’s trace a complete model load from the application’s perspective:
Application WeightProvider Backend
----------- -------------- -------
provider.open("/path/to/model")
|
+----> detect_format() -----> "Is it a dir?" -----> MLX
| "Is it .gguf?" -----> GGUF
|
+----> backend.open("/path/to/model")
| |
| +-- mmap file / parse header
| +-- extract metadata / config.json
| +-- build_name_mapping()
|
+----> provider.get_config()
| |
| +-- backend.get_config() -> AkunuModelConfig
|
+----> For each layer:
| provider.fuse_weights(
| "layers.N.attention.q.weight",
| "layers.N.attention.k.weight",
| "layers.N.attention.v.weight")
| |
| +-- get_tensor() x3 (load each, cache)
| +-- concatenate (format-specific)
| +-- cache fused result
|
+----> provider.fuse_weights(
| "layers.N.ffn.gate.weight",
| "layers.N.ffn.up.weight")
|
+----> provider.get_tensor("layers.N.ffn.down.weight")
|
+----> (inference begins, all weights cached on GPU)
After the initial load, no further disk I/O happens. Every weight access is a hash table lookup returning a cached GPU buffer. The model code is completely format-agnostic – it works with canonical names and doesn’t know whether the underlying data came from a quantized GGUF, a full-precision SafeTensors file, or an MLX 4-bit quantized directory.
That’s the weight provider. Not flashy, not complicated, but it cleanly decouples two very different ecosystems from the model code that uses them. In the next chapters, we’ll dive deep into the formats themselves – starting with GGUF.
GGUF: Format Specification and Parser
If you’ve been anywhere near the local LLM scene, you’ve encountered GGUF files. They’re the de facto distribution format for quantized models, popularized by llama.cpp and now supported by pretty much every serious inference engine. GGUF stands for “GPT-Generated Unified Format” (though nobody actually calls it that), and it replaced the older GGML format back in 2023.
This chapter is a deep dive into the GGUF format – the binary layout, the parsing
strategy, and how akunu’s gguf_parser.cpp turns a flat file into structured data.
If you’ve ever been curious about what’s actually inside those multi-gigabyte files
you download from HuggingFace, this is for you.
The Big Picture: File Layout
A GGUF file is a single, self-contained binary blob. Everything the inference engine needs – model architecture, hyperparameters, tokenizer vocabulary, and all the weight tensors – lives in one file. No companion JSON, no directory structure, no index files.
The high-level layout is dead simple:
+=============================================+ offset 0
| HEADER |
| +---------------------------------------+ |
| | Magic number (4 bytes, LE) | |
| | Version (4 bytes, LE) | |
| | Tensor count (8 bytes, LE) | |
| | KV pair count (8 bytes, LE) | |
| +---------------------------------------+ |
| |
| +---------------------------------------+ |
| | Metadata KV pairs | |
| | (variable length, kv_count entries) | |
| +---------------------------------------+ |
| |
| +---------------------------------------+ |
| | Tensor info entries | |
| | (variable length, tensor_count) | |
| +---------------------------------------+ |
+=============================================+
| PADDING to 32-byte alignment |
+=============================================+ data_base
| |
| TENSOR DATA |
| |
| (raw bytes, tensors at their declared |
| offsets from data_base) |
| |
+=============================================+ EOF
Let’s zoom in on each section.
The Header: 24 Bytes of Truth
The header is the first 24 bytes of every GGUF file:
Byte offset Size Field Description
----------- ---- ----- -----------
0 4 magic 0x46554747 ("GGUF" in little-endian)
4 4 version Format version (currently 3)
8 8 tensor_count Number of tensors in the file
16 8 kv_count Number of metadata key-value pairs
Let’s break down the magic number. In ASCII, G=0x47, G=0x47, U=0x55,
F=0x46. Stored in little-endian as a 32-bit integer, that’s 0x46554747.
Here’s the byte-level view:
Address: 00 01 02 03 04 05 06 07
Bytes: 47 47 55 46 03 00 00 00
G G U F version=3
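You can convince yourself of the little-endian layout with a few lines of code (a standalone check, not from akunu's sources):

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Reading the ASCII bytes 'G','G','U','F' as a little-endian uint32
// yields 0x46554747: the last byte ('F' = 0x46) lands in the most
// significant position.
static uint32_t magic_from_bytes() {
    const unsigned char bytes[4] = {'G', 'G', 'U', 'F'};
    uint32_t v;
    memcpy(&v, bytes, 4);  // assumes a little-endian host (Apple Silicon is)
    return v;
}
```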
Akunu’s parser checks both the magic and the version:
static constexpr uint32_t GGUF_MAGIC = 0x46554747;
static constexpr uint32_t GGUF_VERSION = 3;
uint32_t magic = read_u32(f);
if (magic != GGUF_MAGIC) {
fprintf(stderr, "bad magic 0x%08x\n", magic);
return nullptr;
}
uint32_t version = read_u32(f);
if (version != GGUF_VERSION) {
fprintf(stderr, "unsupported version %u\n", version);
return nullptr;
}
Version 3 is the current version and the only one akunu supports. Version 1 never saw wide use – it encoded string lengths and array counts as uint32 – and version 2 widened those to uint64. Version 3 is a minor revision on top of v2, mainly adding support for big-endian files while leaving the little-endian layout unchanged.
Metadata Key-Value Pairs
Immediately after the 24-byte header, you find the metadata section. This is where the model’s configuration lives – architecture type, embedding dimensions, number of layers, RoPE parameters, tokenizer vocabulary, and more.
Each KV pair has this structure:
+---------------------------+
| Key string |
| +---------------------+ |
| | length (8 bytes) | |
| | chars (N bytes) | |
| +---------------------+ |
+---------------------------+
| Value type (4 bytes) |
+---------------------------+
| Value data (variable) |
+---------------------------+
The key is a GGUF string: a uint64 length prefix followed by that many raw bytes (no null terminator in the file, though the parser adds one for convenience).
The value type is one of 13 possible types:
Code Type Size Description
---- ---- ---- -----------
0 UINT8 1 byte Unsigned 8-bit integer
1 INT8 1 byte Signed 8-bit integer
2 UINT16 2 bytes Unsigned 16-bit integer
3 INT16 2 bytes Signed 16-bit integer
4 UINT32 4 bytes Unsigned 32-bit integer
5 INT32 4 bytes Signed 32-bit integer
6 FLOAT32 4 bytes IEEE 754 single-precision float
7 BOOL 1 byte Boolean (0 or 1)
8 STRING variable GGUF string (uint64 length + chars)
9 ARRAY variable Typed array (see below)
10 UINT64 8 bytes Unsigned 64-bit integer
11 INT64 8 bytes Signed 64-bit integer
12 FLOAT64 8 bytes IEEE 754 double-precision float
Most metadata values are simple scalars. For example, the embedding dimension might be stored as:
Key: "llama.embedding_length"
Type: UINT32 (4)
Value: 00 10 00 00 (4096 in little-endian)
Or a float parameter:
Key: "llama.rope.freq_base"
Type: FLOAT32 (6)
Value: 00 24 F4 48 (500000.0 in IEEE 754, little-endian)
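Decoding such a value is a memcpy away. A standalone check that the four little-endian bytes 00 24 F4 48 (i.e. 0x48F42400) really are 500000.0:

```cpp
#include <cassert>
#include <cstring>

// Reassemble four little-endian bytes into an IEEE 754 float.
// memcpy avoids any aliasing/alignment issues with pointer casts.
static float f32_from_le_bytes(const unsigned char b[4]) {
    float v;
    memcpy(&v, b, 4);  // assumes a little-endian host
    return v;
}
```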
Array Values
Arrays are the most complex metadata type. They have this layout:
+-----------------------------+
| Element type (4 bytes) |
| Element count (8 bytes) |
| Element 0 (variable) |
| Element 1 (variable) |
| ... |
| Element N-1 (variable) |
+-----------------------------+
The two most important array uses are the tokenizer vocabulary (an array of strings) and the tokenizer scores (an array of float32). A vocabulary array for a 32K-token model would look something like:
Element type: STRING (8)
Element count: 32000
Element 0: [len=5] "<unk>" (the unknown token)
Element 1: [len=3] "<s>" (begin of sequence)
Element 2: [len=4] "</s>" (end of sequence)
...
Element 31999: [len=7] "zoology"
Each string element is a full GGUF string (uint64 length + chars), so parsing an array of strings requires walking through variable-length elements sequentially. There’s no random access – you can’t jump to element 15000 without parsing all preceding elements.
How Akunu Parses Metadata
The parser reads metadata values with a big switch statement. Let’s look at the interesting parts:
static void read_metadata_value(GGUFFileImpl *f, GGUFMetadataKV *kv,
uint32_t type) {
kv->type = type;
kv->array_len = 0;
kv->array_data = nullptr;
switch (type) {
case GGUF_TYPE_UINT8:
kv->value.u32 = read_u8(f);
break;
case GGUF_TYPE_INT32:
kv->value.i32 = read_i32(f);
break;
case GGUF_TYPE_FLOAT32:
kv->value.f32 = read_f32(f);
break;
case GGUF_TYPE_STRING:
kv->value.str = read_string(f);
break;
case GGUF_TYPE_ARRAY: {
uint32_t elem_type = read_u32(f);
uint64_t count = read_u64(f);
kv->array_len = count;
kv->array_data = f->cursor; // <-- raw pointer into mmap!
// Skip past all elements
for (uint64_t i = 0; i < count; i++) {
skip_value(f, elem_type);
}
kv->value.u32 = elem_type; // store element type
break;
}
// ... other types ...
}
}
Notice the array handling. The parser doesn’t decode array elements during the
initial parse. Instead, it records a raw pointer (array_data) into the mmap’d
region and the element type. The actual element decoding happens lazily when
someone calls gguf_get_string_array() or gguf_get_float_array(). This is
a nice optimization – the tokenizer vocabulary can have 100K+ entries, and
parsing all of them upfront would waste time if the caller only needs the
model architecture.
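The deferred decoding is just a sequential walk over length-prefixed strings. Here's a sketch of the shape of work gguf_get_string_array() defers – simplified, with no bounds checking, and assuming a little-endian host:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

// Walk 'count' GGUF strings (uint64 length prefix + raw chars, no null
// terminator) starting at 'p', materializing them as std::strings.
static std::vector<std::string> decode_string_array(const uint8_t *p,
                                                    uint64_t count) {
    std::vector<std::string> out;
    out.reserve(count);
    for (uint64_t i = 0; i < count; i++) {
        uint64_t len;
        memcpy(&len, p, 8);                     // length prefix
        p += 8;
        out.emplace_back((const char*)p, len);  // copy exactly len chars
        p += len;
    }
    return out;
}
```

This is also why there's no random access: element N's position depends on the lengths of all elements before it.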
But the parser still has to skip past all elements to find where the next KV
pair starts. The skip_value() function handles this:
static void skip_value(GGUFFileImpl *f, uint32_t type) {
switch (type) {
case GGUF_TYPE_UINT8:
case GGUF_TYPE_INT8:
case GGUF_TYPE_BOOL:
f->cursor += 1;
break;
case GGUF_TYPE_UINT16:
case GGUF_TYPE_INT16:
f->cursor += 2;
break;
case GGUF_TYPE_UINT32:
case GGUF_TYPE_INT32:
case GGUF_TYPE_FLOAT32:
f->cursor += 4;
break;
case GGUF_TYPE_UINT64:
case GGUF_TYPE_INT64:
case GGUF_TYPE_FLOAT64:
f->cursor += 8;
break;
case GGUF_TYPE_STRING: {
uint64_t len = read_u64(f);
f->cursor += len;
break;
}
case GGUF_TYPE_ARRAY: {
uint32_t elem_type = read_u32(f);
uint64_t count = read_u64(f);
for (uint64_t i = 0; i < count; i++) {
skip_value(f, elem_type);
}
break;
}
}
}
This is recursive for nested arrays (arrays of arrays), though in practice GGUF files don’t use nested arrays. The common case is arrays of strings or arrays of floats, which skip efficiently.
Tensor Info Entries
After all metadata KV pairs, the file contains tensor info entries. Each entry describes one tensor: its name, shape, data type, and offset into the data section.
+-------------------------------+
| Tensor name (GGUF string) |
+-------------------------------+
| Number of dimensions (u32) |
+-------------------------------+
| Dimension 0 (u64) |
| Dimension 1 (u64) |
| ...up to 4 dims |
+-------------------------------+
| Data type (u32) |
+-------------------------------+
| Offset (u64) |
+-------------------------------+
Let’s work through a concrete example. A Q4_0-quantized attention Q projection weight for a 4096-dim model with 32 heads might look like:
Name: "blk.0.attn_q.weight"
n_dims: 2
dims[0]: 4096 (output dimension = dim)
dims[1]: 4096 (input dimension = dim)
dtype: 2 (GGUF_DTYPE_Q4_0)
offset: 0x00000000 (first tensor, starts at data_base)
n_elements: 4096 * 4096 = 16,777,216
bytes: 16,777,216 / 32 * 18 = 9,437,184 (about 9 MB)
The parser stores this in a GGUFTensorInfo struct:
typedef struct {
const char *name;
uint64_t n_elements;
uint32_t n_dims;
uint64_t dims[4];
uint64_t offset;
uint32_t dtype;
} GGUFTensorInfo;
Note that n_elements is computed during parsing by multiplying all dimensions
together. The parser also checks for overflow:
ti.n_elements = 1;
for (uint32_t d = 0; d < ti.n_dims; d++) {
ti.dims[d] = read_u64(f);
if (ti.dims[d] > 0 &&
ti.n_elements > UINT64_MAX / ti.dims[d]) {
fprintf(stderr, "dimension overflow\n");
return nullptr;
}
ti.n_elements *= ti.dims[d];
}
That overflow check matters. A malicious GGUF file could set dimensions to
absurd values, and without the check, n_elements would wrap around to a
small number, leading to undersized buffer allocations and memory corruption.
The Data Section: Alignment Matters
After all tensor info entries, there’s a padding gap to align the data section to a 32-byte boundary. This alignment is important for efficient memory access, especially on GPUs where misaligned reads can be catastrophically slow.
static constexpr size_t GGUF_ALIGNMENT = 32;
size_t header_bytes = (size_t)(f->cursor - (const uint8_t*)f->mmap_addr);
size_t aligned = (header_bytes + GGUF_ALIGNMENT - 1)
& ~(GGUF_ALIGNMENT - 1);
f->data_base = (const uint8_t*)f->mmap_addr + aligned;
The & ~(GGUF_ALIGNMENT - 1) trick is the classic bit manipulation for rounding
up to a power-of-two alignment. Since GGUF_ALIGNMENT is 32 (which is
0x20), GGUF_ALIGNMENT - 1 is 0x1F, and ~0x1F is a mask with every bit
set except the bottom five. ANDing with that mask clears the bottom 5 bits,
which rounds down – but because we first added GGUF_ALIGNMENT - 1, the net
effect is rounding up.
Example: header_bytes = 105,743
+ 31 = 105,774
& 0xFFFFFFE0 = 105,760 (aligned to 32)
data_base = mmap_addr + 105,760
Each tensor’s data lives at data_base + tensor.offset. The offset is
relative to data_base, not to the start of the file. This is a subtle
but important detail – tensor offsets stored in the file are already
relative to the aligned data section start.
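The rounding idiom generalizes to any power-of-two alignment. As a standalone helper (not akunu's code, just the same math factored out):

```cpp
#include <cassert>
#include <cstddef>

// Round n up to the next multiple of align, where align is a power of two.
// With align = 32 this is exactly the data_base computation above.
static size_t align_up(size_t n, size_t align) {
    return (n + align - 1) & ~(align - 1);
}
```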
Tensor Data Types: The Full Catalog
GGUF supports 31 tensor data types. Not all of them are commonly used, but akunu’s parser defines them all. Here’s the complete enumeration:
typedef enum {
GGUF_DTYPE_F32 = 0, // IEEE 754 float32
GGUF_DTYPE_F16 = 1, // IEEE 754 float16
GGUF_DTYPE_Q4_0 = 2, // 4-bit quantized, type 0
GGUF_DTYPE_Q4_1 = 3, // 4-bit quantized, type 1
// 4, 5 are legacy (Q4_2, Q4_3 -- removed)
GGUF_DTYPE_Q5_0 = 6, // 5-bit quantized, type 0
GGUF_DTYPE_Q5_1 = 7, // 5-bit quantized, type 1
GGUF_DTYPE_Q8_0 = 8, // 8-bit quantized, type 0
GGUF_DTYPE_Q8_1 = 9, // 8-bit quantized, type 1
GGUF_DTYPE_Q2_K = 10, // K-quant, 2-bit
GGUF_DTYPE_Q3_K = 11, // K-quant, 3-bit
GGUF_DTYPE_Q4_K = 12, // K-quant, 4-bit
GGUF_DTYPE_Q5_K = 13, // K-quant, 5-bit
GGUF_DTYPE_Q6_K = 14, // K-quant, 6-bit
GGUF_DTYPE_Q8_K = 15, // K-quant, 8-bit
GGUF_DTYPE_IQ2_XXS = 16, // Importance quant, 2-bit, extra-small
GGUF_DTYPE_IQ2_XS = 17, // Importance quant, 2-bit, small
GGUF_DTYPE_IQ3_XXS = 18, // Importance quant, 3-bit
GGUF_DTYPE_IQ1_S = 19, // Importance quant, 1-bit
GGUF_DTYPE_IQ4_NL = 20, // Importance quant, 4-bit nonlinear
GGUF_DTYPE_IQ3_S = 21, // Importance quant, 3-bit
GGUF_DTYPE_IQ2_S = 22, // Importance quant, 2-bit
GGUF_DTYPE_IQ4_XS = 23, // Importance quant, 4-bit
GGUF_DTYPE_I8 = 24, // Plain int8
GGUF_DTYPE_I16 = 25, // Plain int16
GGUF_DTYPE_I32 = 26, // Plain int32
GGUF_DTYPE_I64 = 27, // Plain int64
GGUF_DTYPE_F64 = 28, // IEEE 754 float64
GGUF_DTYPE_IQ1_M = 29, // Importance quant, 1-bit mixed
GGUF_DTYPE_BF16 = 30, // Brain float16
} GGUFTensorDType;
Note the gap at codes 4 and 5 – those were Q4_2 and Q4_3, experimental quantization types that were removed early in GGML’s history. The enum preserves backward compatibility by keeping the numbering stable.
Akunu’s weight loader (weight_store.cpp) handles the common types with
explicit byte-size calculations:
Type Block Bytes/Block Bits/Weight Notes
--------- ------ ----------- ----------- ------------------
F32 1 elem 4 32 Full precision
F16 1 elem 2 16 Half precision
BF16 1 elem 2 16 Brain float
Q4_0 32 elem 18 4.5 Scale + 4-bit quants
Q4_1 32 elem 20 5.0 Scale+min + 4-bit
Q5_0 32 elem 22 5.5 Scale + 5-bit quants
Q8_0 32 elem 34 8.5 Scale + 8-bit quants
Q2_K 256 elem 84 2.625 K-quant super-block
Q3_K 256 elem 110 3.4375 K-quant super-block
Q4_K 256 elem 144 4.5 K-quant super-block
Q5_K 256 elem 176 5.5 K-quant super-block
Q6_K 256 elem 210 6.5625 K-quant super-block
Q8_K 256 elem 292 9.125 K-quant super-block
The “bits per weight” column gives you the effective compression ratio. Q4_0 isn’t exactly 4 bits per weight – the scale factor adds overhead, bringing it to 4.5 bits. K-quants are even less round because of their complex multi-level structure.
We’ll go into the byte-level details of each quantization format in a dedicated chapter. For now, just know that the GGUF parser doesn’t care about the internal structure of quantized blocks – it just needs to know the total byte count to hand off to the weight store.
The Parser Implementation: mmap and Cursors
Akunu’s GGUF parser is written in C with a C++ implementation file (it exposes a C API for maximum compatibility). The core strategy is simple:
- Memory-map the entire file
- Walk through it with a cursor pointer
- Build hash maps for O(1) lookup by name
Let’s trace through gguf_open() step by step.
Step 1: Open and mmap
GGUFFile gguf_open(const char *path) {
    int fd = open(path, O_RDONLY);
    if (fd < 0) return nullptr;
    struct stat st;
    if (fstat(fd, &st) != 0 || (size_t)st.st_size < 24) {
        close(fd);
        return nullptr;
    }
    size_t file_size = (size_t)st.st_size;
    void *mapped = mmap(nullptr, file_size, PROT_READ,
                        MAP_PRIVATE, fd, 0);
    if (mapped == MAP_FAILED) {
        close(fd);
        return nullptr;
    }
    GGUFFileImpl *f = new GGUFFileImpl();
    f->fd = fd;
    f->mmap_addr = mapped;
    f->mmap_len = file_size;
    f->cursor = (const uint8_t*)mapped;
    f->end = f->cursor + file_size;
    // ...
}
The MAP_PRIVATE flag means modifications to the mapped region (which we never
make) would be copy-on-write. PROT_READ ensures we can only read. The OS will
page in data lazily as we access it.
The minimum file size check is 24 bytes (the header). Anything smaller can’t possibly be a valid GGUF file.
Step 2: Parse the header
uint32_t magic = read_u32(f); // advances cursor by 4
uint32_t version = read_u32(f); // advances cursor by 4
uint64_t tensor_count = read_u64(f); // advances cursor by 8
uint64_t kv_count = read_u64(f); // advances cursor by 8
The read_* functions are thin wrappers around memcpy + cursor advance:
static inline uint32_t read_u32(GGUFFileImpl *f) {
if (!has_bytes(f, 4)) return 0;
uint32_t v;
memcpy(&v, f->cursor, 4);
f->cursor += 4;
return v;
}
Why memcpy instead of a direct cast like *(uint32_t*)f->cursor? Because
direct casts would be undefined behavior if the cursor isn’t aligned to a
4-byte boundary. GGUF strings have variable length, so after reading a string,
the cursor can be at any alignment. memcpy is always safe and modern
compilers optimize it into the same instruction as a direct load when they
can prove alignment.
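To make the alignment point concrete, here is the idiom in isolation – a sketch, not the Akunu source. Loading a u32 from an odd byte offset is well-defined through memcpy, where a pointer cast would not be:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Safe unaligned little-endian u32 read -- the same idiom read_u32 uses.
// Compilers lower this to a single load when alignment is provable.
inline uint32_t load_u32(const uint8_t *p) {
    uint32_t v;
    memcpy(&v, p, 4);   // well-defined at any alignment
    return v;
}
```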
Step 3: Parse metadata
f->metadata.reserve(kv_count);
for (uint64_t i = 0; i < kv_count; i++) {
GGUFMetadataKV kv;
kv.key = read_string(f);
uint32_t vtype = read_u32(f);
read_metadata_value(f, &kv, vtype);
f->metadata_map[kv.key] = f->metadata.size();
f->metadata.push_back(kv);
}
Each KV pair is read sequentially (you have to, since they’re variable-length).
The key string is allocated on the heap and tracked in owned_strings for
cleanup. The metadata_map provides O(1) lookup by key name.
Step 4: Parse tensor info
f->tensors.reserve(tensor_count);
for (uint64_t i = 0; i < tensor_count; i++) {
GGUFTensorInfo ti;
ti.name = read_string(f);
ti.n_dims = read_u32(f);
ti.n_elements = 1;
for (uint32_t d = 0; d < ti.n_dims; d++) {
ti.dims[d] = read_u64(f);
ti.n_elements *= ti.dims[d];
}
ti.dtype = read_u32(f);
ti.offset = read_u64(f);
f->tensor_map[ti.name] = f->tensors.size();
f->tensors.push_back(ti);
}
Same pattern: sequential parse, build a hash map. The tensor_map maps
tensor names to indices in the tensors vector.
Step 5: Compute data base
size_t header_bytes = (size_t)(f->cursor - (const uint8_t*)f->mmap_addr);
size_t aligned = (header_bytes + GGUF_ALIGNMENT - 1)
& ~(GGUF_ALIGNMENT - 1);
f->data_base = (const uint8_t*)f->mmap_addr + aligned;
At this point, parsing is complete. The entire header has been walked, and we know where the data section begins.
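The bit trick in step 5 rounds up to the next multiple of GGUF_ALIGNMENT (32 by default) and relies on the alignment being a power of two. A standalone sketch:

```cpp
#include <cassert>
#include <cstddef>

// Round x up to the next multiple of align (align must be a power of two).
// This is the (x + A - 1) & ~(A - 1) idiom from step 5.
constexpr size_t align_up(size_t x, size_t align) {
    return (x + align - 1) & ~(align - 1);
}
```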
The Internal State: GGUFFileImpl
Let’s look at the complete internal state structure:
struct GGUFFileImpl {
int fd = -1; // File descriptor (kept open)
void *mmap_addr = MAP_FAILED; // mmap base address
size_t mmap_len = 0; // mmap'd region size
const uint8_t *cursor; // Current parse position
const uint8_t *end; // End of mmap'd region
const uint8_t *data_base; // Start of tensor data section
std::vector<GGUFMetadataKV> metadata; // Parsed metadata
std::vector<GGUFTensorInfo> tensors; // Parsed tensor info
// O(1) lookup maps
std::unordered_map<std::string, size_t> metadata_map;
std::unordered_map<std::string, size_t> tensor_map;
// Ownership tracking
std::vector<char*> owned_strings;
std::vector<const char**> owned_str_arrays;
std::vector<float*> owned_flt_arrays;
};
The memory layout in action:
Process virtual memory:
+----------------------------------------------------------+
| Stack / heap / code / etc. |
+----------------------------------------------------------+
| GGUFFileImpl (heap allocated) |
| - metadata vector (small, ~100 entries) |
| - tensors vector (small, ~200 entries) |
| - metadata_map hash table |
| - tensor_map hash table |
| - owned_strings pointers |
+----------------------------------------------------------+
| mmap'd region (file_size bytes) |
| [header | metadata | tensor_info | pad | tensor_data] |
| ^cursor walks through this during parse |
| ^data_base points to start of tensor_data |
+----------------------------------------------------------+
The key insight is that the mmap’d region stays mapped for the lifetime of
the GGUFFile. When you ask for tensor data via gguf_tensor_data(), you
get a raw pointer into this region. No copying at the parser level – the
copy happens later when the weight store uploads to the GPU.
The C API: Functions and Usage
The parser exposes a clean C API. Here’s the complete interface:
// Lifecycle
GGUFFile gguf_open(const char *path);
void gguf_close(GGUFFile file);
// Counts
uint64_t gguf_tensor_count(GGUFFile file);
uint64_t gguf_metadata_count(GGUFFile file);
// Tensor lookup
const GGUFTensorInfo *gguf_get_tensor(GGUFFile file, const char *name);
const GGUFTensorInfo *gguf_get_tensor_by_index(GGUFFile file, uint64_t index);
// Metadata lookup
const GGUFMetadataKV *gguf_get_metadata(GGUFFile file, const char *key);
const GGUFMetadataKV *gguf_get_metadata_by_index(GGUFFile file, uint64_t index);
// Tensor data access
const void *gguf_tensor_data(GGUFFile file, const GGUFTensorInfo *info);
// Array helpers
const char **gguf_get_string_array(GGUFFile file, const char *key,
uint64_t *out_count);
const float *gguf_get_float_array(GGUFFile file, const char *key,
uint64_t *out_count);
The typical usage pattern (from WeightStore::open()):
gguf_ = gguf_open(path.c_str());
if (!gguf_) return false;
// Read a metadata int (a robust caller checks for nullptr)
const GGUFMetadataKV *kv = gguf_get_metadata(gguf_, "llama.block_count");
int n_layers = kv ? (int)kv->value.u32 : 0; // 32 for this model
// Get a tensor
const GGUFTensorInfo *info = gguf_get_tensor(gguf_, "blk.0.attn_q.weight");
const void *data = gguf_tensor_data(gguf_, info);
// data now points into the mmap'd file -- zero-copy!
// Get tokenizer vocabulary
uint64_t vocab_size;
const char **tokens = gguf_get_string_array(gguf_, "tokenizer.ggml.tokens",
&vocab_size);
The String Array Helper: Lazy Decoding
The gguf_get_string_array() function deserves a closer look because it shows
the lazy decoding strategy in action:
const char **gguf_get_string_array(GGUFFile file, const char *key,
uint64_t *out_count) {
const GGUFMetadataKV *kv = gguf_get_metadata(file, key);
if (!kv || kv->type != GGUF_TYPE_ARRAY ||
kv->value.u32 != GGUF_TYPE_STRING)
return nullptr;
uint64_t count = kv->array_len;
const char **arr = (const char**)malloc(count * sizeof(const char*));
file->owned_str_arrays.push_back(arr);
// Walk the raw array data in the mmap
const uint8_t *p = (const uint8_t*)kv->array_data;
for (uint64_t i = 0; i < count; i++) {
uint64_t slen;
memcpy(&slen, p, 8);
p += 8;
char *s = (char*)malloc(slen + 1);
memcpy(s, p, slen);
s[slen] = '\0';
p += slen;
file->owned_strings.push_back(s);
arr[i] = s;
}
*out_count = count;
return arr;
}
This function walks the raw mmap’d bytes of the array, extracting each string. It’s called at most once per key (the result isn’t cached, but callers typically only call it once during initialization). For a 128K vocabulary, this allocates 128K small strings. Not the most memory-efficient approach, but it’s simple and the total memory is tiny compared to the gigabytes of tensor data.
The float array helper is simpler – it just does a bulk memcpy:
const float *gguf_get_float_array(GGUFFile file, const char *key,
                                  uint64_t *out_count) {
    const GGUFMetadataKV *kv = gguf_get_metadata(file, key);
    if (!kv || kv->type != GGUF_TYPE_ARRAY ||
        kv->value.u32 != GGUF_TYPE_FLOAT32)
        return nullptr;
    uint64_t count = kv->array_len;
    float *arr = (float*)malloc(count * sizeof(float));
    file->owned_flt_arrays.push_back(arr);   // freed in gguf_close()
    memcpy(arr, kv->array_data, count * sizeof(float));
    *out_count = count;
    return arr;
}
Float arrays are already in the right format in the mmap’d data (IEEE 754 little-endian), so it’s a straight copy.
Memory Management: Who Owns What?
The GGUFFileImpl tracks all dynamically allocated memory through three vectors:
- owned_strings: All strings parsed from the file (tensor names, metadata keys, string values, string array elements)
- owned_str_arrays: The const char** arrays returned by gguf_get_string_array()
- owned_flt_arrays: The float* arrays returned by gguf_get_float_array()
On close, everything is freed:
void gguf_close(GGUFFile file) {
if (file->mmap_addr != MAP_FAILED) {
munmap(file->mmap_addr, file->mmap_len);
}
if (file->fd >= 0) {
close(file->fd);
}
for (char *s : file->owned_strings) {
free(s);
}
for (const char **a : file->owned_str_arrays) {
free(a);
}
for (float *a : file->owned_flt_arrays) {
free(a);
}
delete file;
}
This is a simple arena-style pattern: allocate freely during parsing, free everything at once on close. No individual deallocation, no reference counting. The parser’s lifetime matches the model’s lifetime, so this works perfectly.
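The pattern can be sketched in a few lines (the class name here is illustrative, not Akunu's): every allocation is recorded as it happens, and the destructor frees the lot:

```cpp
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <vector>

// Minimal arena in the style of GGUFFileImpl's owned_strings: allocate
// freely during parsing, free everything at once on teardown.
struct StringArena {
    std::vector<char*> owned;
    char *copy(const char *s, size_t len) {
        char *p = (char*)malloc(len + 1);
        memcpy(p, s, len);
        p[len] = '\0';
        owned.push_back(p);   // track for bulk free
        return p;
    }
    ~StringArena() { for (char *p : owned) free(p); }
};
```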
Bounds Checking and Safety
The parser includes several safety checks:
Cursor bounds checking: Every read_* function checks has_bytes() before
reading:
static inline bool has_bytes(const GGUFFileImpl *f, size_t n) {
return (size_t)(f->end - f->cursor) >= n;
}
Maximum dimensions: Tensors can have at most 4 dimensions. More than that triggers an error.
Element count overflow: Dimension multiplication checks for uint64 overflow before proceeding.
Tensor data bounds: gguf_tensor_data() validates that the offset is within
the mmap’d region:
const void *gguf_tensor_data(GGUFFile file, const GGUFTensorInfo *info) {
if (info->offset >= (size_t)(file->end - file->data_base))
return nullptr;
return file->data_base + info->offset;
}
These checks protect against malformed or malicious GGUF files. They’re not exhaustive (there’s no check that tensor data regions don’t overlap, for example), but they catch the most common corruption scenarios.
A Real GGUF File, Byte by Byte
Let’s walk through the first few hundred bytes of a real GGUF file to tie everything together. Say we have a small model with 2 metadata entries and 3 tensors:
Offset Hex Meaning
------ --- -------
0x0000 47 47 55 46 Magic: "GGUF"
0x0004 03 00 00 00 Version: 3
0x0008 03 00 00 00 00 00 00 00 Tensor count: 3
0x0010 02 00 00 00 00 00 00 00 KV count: 2
--- Metadata KV #0 ---
0x0018 14 00 00 00 00 00 00 00 Key length: 20
0x0020 67 65 6E 65 72 61 6C 2E "general."
0x0028 61 72 63 68 69 74 65 63 "architec"
0x0030 74 75 72 65 "ture"
0x0034 08 00 00 00 Type: STRING (8)
0x0038 05 00 00 00 00 00 00 00 String length: 5
0x0040 6C 6C 61 6D 61 "llama"
--- Metadata KV #1 ---
0x0045 11 00 00 00 00 00 00 00 Key length: 17
0x004D 6C 6C 61 6D 61 2E 62 6C "llama.bl"
0x0055 6F 63 6B 5F 63 6F 75 6E "ock_coun"
0x005D 74 "t"
0x005E 04 00 00 00 Type: UINT32 (4)
0x0062 20 00 00 00 Value: 32
--- Tensor info #0 ---
0x0066 11 00 00 00 00 00 00 00 Name length: 17
0x006E 74 6F 6B 65 6E 5F 65 6D "token_em"
0x0076 62 64 2E 77 65 69 67 68 "bd.weigh"
0x007E 74 "t"
0x007F 02 00 00 00 n_dims: 2
0x0083 00 10 00 00 00 00 00 00 dims[0]: 4096
0x008B 00 80 00 00 00 00 00 00 dims[1]: 32768
0x0093 08 00 00 00 dtype: Q8_0 (8)
0x0097 00 00 00 00 00 00 00 00 offset: 0
--- (more tensor info entries...) ---
--- PADDING to 32-byte alignment ---
--- TENSOR DATA ---
Notice how everything is little-endian and tightly packed. There’s no padding between fields within a section – the variable-length strings make it impossible to use fixed offsets. You have to parse sequentially.
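As a sanity check on the dump above, the fixed 24-byte header can be decoded with four unaligned loads; a sketch using the same memcpy idiom as read_u32/read_u64 (the struct name is ours):

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

struct GGUFHeader { uint32_t magic, version; uint64_t tensor_count, kv_count; };

// Decode the fixed 24-byte GGUF header from a raw little-endian byte buffer.
inline GGUFHeader decode_header(const uint8_t *p) {
    GGUFHeader h;
    memcpy(&h.magic,        p,      4);
    memcpy(&h.version,      p + 4,  4);
    memcpy(&h.tensor_count, p + 8,  8);
    memcpy(&h.kv_count,     p + 16, 8);
    return h;
}
```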
Performance Characteristics
Let’s think about the performance of this parser:
Opening a file: The gguf_open() function is dominated by the metadata
parse time. For a typical model with ~100 metadata entries and ~200 tensors,
this takes microseconds. The mmap itself is near-instantaneous (it just sets
up page table entries).
First tensor access: The first time you access tensor data via
gguf_tensor_data(), the OS has to page in the data from disk. For an SSD,
this is on the order of microseconds per page (16KB on Apple Silicon, where
macOS uses 16KB pages rather than the traditional 4KB). A 9MB Q4_0 tensor
spans roughly 575 pages, so first access is roughly 1-2ms.
Subsequent tensor accesses: After the data is paged in, access is just a pointer dereference – nanoseconds. The OS page cache keeps recently accessed pages in RAM.
Memory usage: The parser itself uses very little heap memory. The mmap’d region uses virtual address space but only consumes physical RAM for pages that have been accessed. A 4GB model file mapped but unaccessed uses essentially zero RAM.
Hash map lookups: Both metadata_map and tensor_map use
std::unordered_map, giving O(1) average-case lookup. With ~200 tensors,
the hash table overhead is negligible.
The overall design is optimized for the common case: open the file once, access each tensor once during model load, then never touch the parser again until shutdown. The mmap approach means the OS manages the caching, which is hard to beat for this access pattern.
Why Not a JSON Parser? Why Not Protobuf?
It’s worth asking why GGUF exists as a custom binary format instead of using an existing serialization framework. The answer is practical:
- Zero-copy tensor data: Tensor data needs to be passed directly to GPU upload functions. With a binary format and mmap, you get a raw pointer into the file with no deserialization overhead. JSON or protobuf would require parsing into intermediate structures and copying.
- Self-contained: A single file containing everything means no directory management, no missing companion files, no version mismatches between separate metadata and data files.
- Streaming-friendly: The metadata is at the front of the file, so you can read the model config without touching the (much larger) tensor data. This matters for model inspection tools.
- Alignment control: The 32-byte alignment of the data section is critical for GPU efficiency. Generic formats don’t give you this level of control.
- No dependencies: The parser is self-contained C code. No JSON library, no protobuf compiler, no generated code.
The tradeoff is that you need a custom parser, and the format is harder to inspect with generic tools. But for the specific use case of distributing and loading quantized model weights, it’s hard to argue with the result: a simple format with a simple parser that delivers excellent performance.
In the next chapter, we’ll look at the other side of the coin: SafeTensors and the MLX ecosystem, which took a very different approach to the same problem.
SafeTensors and MLX Formats
GGUF is not the only game in town. The Hugging Face ecosystem overwhelmingly distributes models in the SafeTensors format, and Apple’s MLX framework has made SafeTensors the default for its quantized model exports. Akunu supports both GGUF and MLX-quantized SafeTensors as first-class citizens through its WeightProvider abstraction, which auto-detects the format and presents a unified interface to the rest of the engine.
This chapter digs into how SafeTensors files are structured, how MLX layers its quantization scheme on top of SafeTensors, and how Akunu’s SafeTensorsParser and MLXWeightStore classes handle the entire pipeline from raw bytes on disk to GPU-resident weight buffers.
The SafeTensors File Format
SafeTensors was designed by Hugging Face as a secure, zero-copy alternative to Python pickle-based formats like PyTorch’s .bin files.1 The design is deliberately simple: a file consists of exactly two parts.
+--------------------------------------------------+
| 8 bytes: header_len (little-endian uint64) |
+--------------------------------------------------+
| header_len bytes: UTF-8 JSON header |
| { |
| "tensor_name": { |
| "dtype": "F16", |
| "shape": [4096, 4096], |
| "data_offsets": [0, 33554432] |
| }, |
| "__metadata__": { |
| "format": "mlx", |
| "quantization_config": "..." |
| } |
| } |
+--------------------------------------------------+
| Tensor data (contiguous, aligned) |
| [tensor_0 bytes] [tensor_1 bytes] ... |
+--------------------------------------------------+
That is it. No nested containers, no variable-length integer encodings, no type-length-value gymnastics. The entire schema lives in a single JSON object that you can parse with any JSON library. Let us walk through each section.
The 8-Byte Length Prefix
The first 8 bytes of the file encode the size of the JSON header as a little-endian unsigned 64-bit integer. This tells you exactly how many bytes to read before the tensor data begins. In Akunu’s SafeTensorsParser::open():
// Read header length (little-endian u64)
uint64_t header_len = 0;
memcpy(&header_len, data_, 8);
data_offset_ = 8 + header_len;
The data_offset_ field marks where the raw tensor bytes start. Every tensor’s data_offsets field in the header is relative to this point – not to the start of the file. This is a common source of confusion when manually inspecting SafeTensors files with a hex editor.
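The arithmetic is worth pinning down, since it trips people up: the absolute file position of a tensor's first byte is the 8-byte length prefix, plus the JSON header, plus the tensor's data_offsets[0]. A sketch (the helper is ours, not Akunu's API):

```cpp
#include <cassert>
#include <cstdint>

// Absolute file offset of a tensor's first byte in a SafeTensors file.
// data_start is the tensor's data_offsets[0], relative to the data section.
constexpr uint64_t absolute_offset(uint64_t header_len, uint64_t data_start) {
    return 8 + header_len + data_start;   // 8-byte length prefix + JSON header
}
```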
The JSON Header
The header is a flat JSON object where each key is a tensor name and each value is a small descriptor:
| Field | Type | Description |
|---|---|---|
| dtype | string | Data type: "F16", "BF16", "F32", "U32", "I8" |
| shape | int array | Dimensions, e.g. [4096, 4096] |
| data_offsets | int array | [start, end] byte offsets from data section start |
There is also an optional __metadata__ key that carries arbitrary string key-value pairs. MLX uses this section to store quantization configuration, and Akunu reads it to detect quantization parameters.
Zero-Copy Access via mmap
Akunu never copies tensor data into heap-allocated buffers during parsing. The entire file is memory-mapped:
data_ = (const uint8_t *)mmap(nullptr, file_size_, PROT_READ,
MAP_PRIVATE, fd_, 0);
When you call tensor_data(), you get a direct pointer into the mmap’d region:
const void *tensor_data(const SafeTensorInfo& info) const {
size_t offset = data_offset_ + info.data_start;
return data_ + offset;
}
This means the kernel’s page fault handler brings tensor data into physical memory on demand, one page at a time. For large models with hundreds of tensors, this is significantly faster than read() calls, because you only touch the pages you actually need. On Apple Silicon with unified memory, these pages can be directly referenced by the GPU without any additional copy – though in practice Akunu does allocate Metal buffers and memcpy into them for format conversion reasons we will discuss shortly.
Akunu’s Minimal JSON Parser
You might notice that SafeTensorsParser includes a hand-rolled JSON parser rather than pulling in a library like nlohmann/json or simdjson. This is a deliberate choice. The SafeTensors header JSON is structurally simple – it is a flat object of objects, each with three well-known fields. The parser only needs to handle strings, integers, arrays of integers, and nested objects. Akunu’s parser handles this in about 100 lines of code:
parse_header()
|
+-- for each key in top-level object:
| if key == "__metadata__" -> parse_metadata()
| else -> parse dtype, shape, data_offsets into SafeTensorInfo
|
+-- skip_value() for any unknown fields
This avoids a dependency, keeps compile times low, and is fast enough that header parsing is never a bottleneck (even a 10MB header parses in under a millisecond).
The MLX Weight Store
While SafeTensorsParser handles the raw file format, MLXWeightStore adds the intelligence layer: name mapping, config extraction, quantization detection, and data type conversion.
Opening an MLX Model
An MLX model is typically a directory containing:
model_directory/
model.safetensors <-- weights
config.json <-- architecture + quantization config
tokenizer.json <-- tokenizer (handled separately)
tokenizer_config.json <-- tokenizer settings
When you call MLXWeightStore::open(), it:
- Detects whether the path is a directory or a single .safetensors file
- Opens the SafeTensors file via the parser
- Parses config.json for model architecture and quantization info
- Builds the name mapping from MLX tensor names to Akunu’s canonical names
bool MLXWeightStore::open(const std::string& path) {
    struct stat st;
    if (stat(path.c_str(), &st) != 0) return false;
    std::string safetensors_path;
    if (S_ISDIR(st.st_mode)) {
        model_dir_ = path;
        safetensors_path = path + "/model.safetensors";
    } else {
        safetensors_path = path;
        model_dir_ = extract_directory(path);
    }
    if (!parser_.open(safetensors_path)) return false;
    parse_config_json(model_dir_);
    build_name_mapping(config_.n_layers);
    return true;
}
Config Extraction from config.json
MLX models store their architecture configuration in a standard Hugging Face config.json. Akunu extracts the fields it needs using minimal JSON string search functions:
| config.json key | AkunuModelConfig field | Example value |
|---|---|---|
| model_type | architecture | "llama" |
| hidden_size | dim | 4096 |
| num_hidden_layers | n_layers | 32 |
| num_attention_heads | n_heads | 32 |
| num_key_value_heads | n_kv_heads | 8 |
| intermediate_size | ffn_dim | 11008 |
| vocab_size | vocab_size | 32000 |
| max_position_embeddings | max_seq_len | 4096 |
| rms_norm_eps | norm_eps | 1e-5 |
| rope_theta | rope_theta | 10000.0 |
The quantization configuration is nested inside quantization_config or quantization:
{
"quantization_config": {
"bits": 4,
"group_size": 64
}
}
Akunu tries both key names because different MLX exporters use different conventions:
for (const char *qkey : {"\"quantization_config\"", "\"quantization\""}) {
auto qpos = json.find(qkey);
if (qpos != std::string::npos) {
// extract bits and group_size
}
}
Name Mapping: MLX to Canonical
MLX models use Hugging Face tensor naming conventions, while Akunu internally uses a simplified canonical naming scheme. The mapping is defined as a static table of rules:
| MLX Name Pattern | Akunu Canonical Name |
|---|---|
| model.embed_tokens.weight | token_embedding.weight |
| model.norm.weight | output_norm.weight |
| lm_head.weight | output.weight |
| model.layers.{n}.self_attn.q_proj.weight | layers.{n}.attention.q.weight |
| model.layers.{n}.self_attn.k_proj.weight | layers.{n}.attention.k.weight |
| model.layers.{n}.self_attn.v_proj.weight | layers.{n}.attention.v.weight |
| model.layers.{n}.self_attn.o_proj.weight | layers.{n}.attention.output.weight |
| model.layers.{n}.mlp.gate_proj.weight | layers.{n}.ffn.gate.weight |
| model.layers.{n}.mlp.up_proj.weight | layers.{n}.ffn.up.weight |
| model.layers.{n}.mlp.down_proj.weight | layers.{n}.ffn.down.weight |
| model.layers.{n}.input_layernorm.weight | layers.{n}.attention_norm.weight |
| model.layers.{n}.post_attention_layernorm.weight | layers.{n}.ffn_norm.weight |
The {n} placeholder is expanded for each layer during build_name_mapping(). The function iterates over all rules, expanding layer-indexed patterns for layers 0 through n_layers - 1, and only records a mapping if the tensor actually exists in the SafeTensors file:
void MLXWeightStore::build_name_mapping(int n_layers) {
for (int r = 0; r < kNumMLXRules; r++) {
if (strstr(pattern, "{n}")) {
for (int layer = 0; layer < n_layers; layer++) {
std::string mlx_name = expand_rule(pattern, layer);
if (parser_.find(mlx_name)) {
std::string can = expand_rule(canonical, layer);
name_map_[can] = mlx_name;
}
}
} else {
if (parser_.find(pattern))
name_map_[canonical] = pattern;
}
}
}
This existence check is important because not all architectures have all tensors. For example, Qwen3 has QK-norm weights (q_norm.weight, k_norm.weight) that LLaMA does not. The mapping table includes rules for both, but only the ones that actually exist in the file get registered.
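The {n} expansion itself is a one-line substitution; a minimal sketch of what expand_rule might look like (Akunu's actual signature is not shown in this chapter, so treat this as illustrative):

```cpp
#include <cassert>
#include <string>

// Replace the "{n}" placeholder in a name-mapping rule with a layer index.
// Patterns without the placeholder pass through unchanged.
inline std::string expand_rule(std::string pattern, int layer) {
    auto pos = pattern.find("{n}");
    if (pos != std::string::npos)
        pattern.replace(pos, 3, std::to_string(layer));
    return pattern;
}
```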
MLX Quantization: The Three-Tensor Pack
This is where things get interesting. When MLX exports a quantized model, it does not pack everything into a single blob like GGUF does. Instead, each quantized linear layer produces three separate tensors in the SafeTensors file:
layer.0.self_attn.q_proj.weight <- packed U32 integers
layer.0.self_attn.q_proj.scales <- F16 or BF16 scale factors
layer.0.self_attn.q_proj.biases <- F16 or BF16 zero-points
The Packed Weight Tensor
The weight tensor stores quantized values packed into 32-bit unsigned integers. For 4-bit quantization, each U32 holds 8 values (32 / 4 = 8). The tensor shape is [N, K_packed] where K_packed = K * bits / 32.
For a [4096, 4096] weight matrix quantized to 4-bit:
K_packed = 4096 * 4 / 32 = 512
weight tensor shape: [4096, 512] of U32
weight tensor size: 4096 * 512 * 4 = 8,388,608 bytes (8 MB)
vs. F16 original: 4096 * 4096 * 2 = 33,554,432 bytes (32 MB)
compression: 4x
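The same arithmetic works for any (N, K, bits, group_size) combination; a sketch computing the three section sizes that later get packed into one GPU buffer (struct and helper names are ours):

```cpp
#include <cassert>
#include <cstddef>

struct QuantSizes { size_t weight_bytes, scale_elems, bias_elems; };

// Section sizes for an MLX-quantized [N, K] matrix: packed weights at
// `bits` per value, one scale and one bias per group of group_size values.
constexpr QuantSizes mlx_quant_sizes(size_t N, size_t K,
                                     int bits, int group_size) {
    return {
        N * K * bits / 8,      // packed U32 weights
        N * K / group_size,    // one scale per group
        N * K / group_size,    // one bias per group
    };
}
```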
Scales and Biases
The scale and bias tensors have shape [N, K / group_size]. With a typical group size of 64 and K = 4096, that is [4096, 64]. Each group of group_size quantized values shares one scale and one bias (zero-point).
The dequantization formula for a single value is:
value = scale * quantized_int + bias
This is an asymmetric affine quantization scheme – the bias term allows the quantization grid to be offset from zero, which can better represent distributions that are not centered at zero.2
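A scalar sketch of the formula for 4-bit values packed eight to a U32. The low-nibble-first extraction order here is an assumption for illustration, not a statement about MLX's exact packing:

```cpp
#include <cassert>
#include <cstdint>

// Dequantize the i-th 4-bit value packed in a U32: value = scale * q + bias.
// Assumes low-nibble-first packing (illustrative only).
inline float dequant4(uint32_t packed, int i, float scale, float bias) {
    uint32_t q = (packed >> (4 * i)) & 0xF;   // extract nibble i
    return scale * (float)q + bias;
}
```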
GPU Buffer Layout
When Akunu loads a quantized tensor, it packs all three components into a single contiguous GPU buffer:
+---------------------------------------------+
| Packed U32 weights |
| (N * K * bits / 8 bytes) |
+---------------------------------------------+
| F16 scales |
| (N * K / group_size * 2 bytes) |
+---------------------------------------------+
| F16 biases |
| (N * K / group_size * 2 bytes) |
+---------------------------------------------+
This layout is what the MLX dequantization kernels expect. The kernel receives the total buffer and uses the weight_bytes parameter to find where the scales section begins:
scales_offset = weight_bytes
biases_offset = weight_bytes + n_scale_elements * 2
The loading code in load_quantized_tensor() handles this packing:
Buffer MLXWeightStore::load_quantized_tensor(const std::string& mlx_name) {
// ... find weight, scales, biases tensors ...
size_t total = w_bytes + s_elements * 2 + b_elements * 2;
Buffer buf = device_.allocate(total);
// Copy weights
memcpy(buf.contents, w_data, w_bytes);
// Copy/convert scales (BF16 -> F16 if needed)
uint8_t *dst = (uint8_t *)buf.contents + w_bytes;
// ... copy or convert scales ...
// Copy/convert biases
dst = (uint8_t *)buf.contents + w_bytes + s_elements * 2;
// ... copy or convert biases ...
return buf;
}
BF16 to F16 Conversion
A particularly tricky detail: MLX often stores scales and biases in BF16 (bfloat16) format, but Apple’s Metal shading language only gained native BF16 support with the M4 GPU family.3 For older GPUs (M1/M2/M3), Akunu must convert BF16 values to F16 on the fly during loading.
The conversion goes through F32 as an intermediate:
BF16 bits: [s][eeeeeeee][mmmmmmm] (1+8+7 = 16 bits)
F32 bits: [s][eeeeeeee][mmmmmmm 0000...] (1+8+23 = 32 bits)
Step 1: BF16 -> F32 (left-shift by 16)
Step 2: F32 -> F16 (hardware cast via __fp16)
In code:
const uint16_t *src = (const uint16_t *)s_data;
uint16_t *f16_dst = (uint16_t *)dst;
for (size_t i = 0; i < s_elements; i++) {
uint32_t f32_bits = (uint32_t)src[i] << 16; // BF16 -> F32
float val;
memcpy(&val, &f32_bits, 4);
__fp16 h = (__fp16)val; // F32 -> F16
memcpy(&f16_dst[i], &h, 2);
}
This is a lossy conversion. BF16 has 8 exponent bits and 7 mantissa bits, while F16 has 5 exponent bits and 10 mantissa bits. BF16 has wider dynamic range but less precision; F16 has narrower range but more precision. For scale/bias values in quantized models, this conversion is perfectly acceptable since these values are themselves approximations.4
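The left-shift trick can be verified on known constants: widening a BF16 bit pattern into the high half of an F32 reproduces the value exactly, because BF16 is simply a truncated F32:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Widen a BF16 bit pattern to F32 by left-shifting into the high half.
// This step is exact; the lossy part is the subsequent F32 -> F16 cast.
inline float bf16_to_f32(uint16_t bf16_bits) {
    uint32_t f32_bits = (uint32_t)bf16_bits << 16;
    float val;
    memcpy(&val, &f32_bits, 4);
    return val;
}
```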
The same conversion applies to raw (non-quantized) tensors. If a SafeTensors file contains BF16 tensors (common in models exported from PyTorch), Akunu converts them to F16 for Metal compatibility:
if (info->dtype == "BF16") {
    std::vector<uint16_t> f16(n_elements);
    const uint16_t *bf16 = (const uint16_t *)data;
    for (size_t i = 0; i < n_elements; i++) {
        uint32_t f32_bits = (uint32_t)bf16[i] << 16;  // BF16 -> F32
        float val;
        memcpy(&val, &f32_bits, 4);
        __fp16 h = (__fp16)val;                       // F32 -> F16
        memcpy(&f16[i], &h, 2);
    }
    return device_.allocate(f16.data(), n_elements * 2);
}
Dynamic Quantization Detection
Akunu does not require the user to specify whether a model is quantized or what bit width it uses. The get_tensor() method dynamically detects quantization by looking for companion .scales tensors:
Buffer MLXWeightStore::get_tensor(const std::string& canonical_name) {
const std::string& mlx_name = name_map_[canonical_name];
// Check for .scales companion
std::string scales_name = mlx_name;
auto wpos = scales_name.rfind(".weight");
if (wpos != std::string::npos)
scales_name.replace(wpos, 7, ".scales");
if (quant_bits_ > 0 && parser_.find(scales_name)) {
buf = load_quantized_tensor(mlx_name);
// Set effective dtype: 99=Q3, 100=Q4, 102=Q6, 101=Q8
} else {
buf = load_raw_tensor(mlx_name);
// Set effective dtype: 1 (F16)
}
}
The effective dtype codes (99, 100, 101, 102) correspond to the MLX entries in Akunu’s DTypeDescriptor table, which maps each quantization format to the correct kernel names and dispatch geometry. This allows the same build_dispatch_table() code to work transparently with both GGUF and MLX models.
Weight Fusion
For performance-critical paths like the fused QKV projection, Akunu can fuse multiple quantized weight matrices into a single buffer. The challenge with MLX’s three-tensor layout is that you cannot simply concatenate the raw buffers – the scales and biases from different tensors need to be grouped together:
Unfused (3 separate tensors, each with [W|S|B]):
Q: [Wq | Sq | Bq]
K: [Wk | Sk | Bk]
V: [Wv | Sv | Bv]
Fused (1 buffer):
[Wq | Wk | Wv | Sq | Sk | Sv | Bq | Bk | Bv]
The kernel expects all weights contiguous, then all scales contiguous, then all biases contiguous. The fuse_mlx_packed() function handles this rearrangement:
// Copy weights section
w_off = 0
for each tensor:
    memcpy(fused + w_off, src, sec.w_bytes)
    w_off += sec.w_bytes
// Copy scales section
s_off = 0
for each tensor:
    memcpy(fused + total_w + s_off*2, src + sec.w_bytes, sec.s_elements*2)
    s_off += sec.s_elements
// Copy biases section (offset restarts at the biases base)
b_off = 0
for each tensor:
    memcpy(fused + total_w + total_s*2 + b_off*2,
           src + sec.w_bytes + sec.s_elements*2, sec.s_elements*2)
    b_off += sec.s_elements
This fusion happens once during model initialization. The fused buffer is cached and reused for every forward pass, so the cost is amortized.
The WeightProvider Abstraction
Above both WeightStore (GGUF) and MLXWeightStore sits the WeightProvider class, which provides a unified interface:
WeightProvider
|
+-- detect_format(path)
| directory or .safetensors -> MLX_SAFETENSORS
| otherwise -> GGUF
|
+-- get_tensor(canonical_name) -> Buffer
+-- get_dtype(canonical_name) -> uint32_t
+-- has_tensor(canonical_name) -> bool
+-- fuse_weights(a, b) -> Buffer
+-- fuse_weights(a, b, c) -> Buffer
+-- get_config() -> AkunuModelConfig
+-- get_metadata_string(key) -> string
The rest of the engine – build_dispatch_table(), the chain decoder, prefill – never touches MLXWeightStore or WeightStore directly. They go through WeightProvider, which delegates to the appropriate backend. This is what makes format support transparent: adding a new weight format (say, ONNX or TensorFlow SavedModel) would only require implementing a new backend class and adding a case to detect_format().
Comparison: SafeTensors vs GGUF
Both formats serve the same purpose – storing model weights efficiently for inference – but they make fundamentally different trade-offs:
| Feature | SafeTensors | GGUF |
|---|---|---|
| Header format | JSON | Binary (type-length-value) |
| Metadata | JSON key-value in __metadata__ | Typed KV pairs (string, int, float, array) |
| Tensor descriptor | dtype + shape + byte offsets | dtype + dimensions + offset |
| Quantization | External (separate scales/biases tensors) | Internal (packed blocks with embedded scales) |
| Tokenizer | Separate tokenizer.json file | Embedded in GGUF metadata |
| Architecture config | Separate config.json | Embedded in GGUF metadata |
| Single file | No (directory of files) | Yes (everything in one .gguf) |
| Ecosystem | Hugging Face, MLX, PyTorch | llama.cpp, whisper.cpp, Akunu |
| Parse complexity | Very low (JSON + mmap) | Medium (binary format, many tensor types) |
| Zero-copy possible | Yes (mmap) | Yes (mmap) |
The key philosophical difference: GGUF is self-contained (one file has everything including tokenizer), while SafeTensors is modular (weights, config, and tokenizer are separate files in a directory). GGUF bakes quantization into the tensor format itself, while MLX/SafeTensors keeps quantization as a layer on top of standard data types.5
For Akunu, both approaches work. The WeightProvider abstraction means the choice is purely a matter of where you got your model from – Hugging Face models come as SafeTensors directories, llama.cpp quantized models come as GGUF files, and Akunu handles both identically from the engine’s perspective.
Buffer Caching
Both WeightStore and MLXWeightStore cache GPU buffers after first load. When get_tensor() is called for a tensor that has already been loaded, it returns the cached buffer immediately:
Buffer MLXWeightStore::get_tensor(const std::string& canonical_name) {
auto cache_it = buffer_cache_.find(canonical_name);
if (cache_it != buffer_cache_.end())
return cache_it->second;
// ... load and cache ...
buffer_cache_[canonical_name] = buf;
return buf;
}
Fused weight buffers are also cached under composite keys like "layers.0.attention.q.weight+layers.0.attention.k.weight+layers.0.attention.v.weight". This means the fusion rearrangement only happens once, and subsequent calls return the pre-fused buffer.
On close(), all cached buffers are freed through the Device:
void MLXWeightStore::close() {
    for (auto& [name, buf] : buffer_cache_)
        device_.free_buffer(buf);
    buffer_cache_.clear();
}
Summary
The SafeTensors/MLX weight pipeline in Akunu is a clean layered design:
SafeTensorsParser (raw file format)
|
MLXWeightStore (name mapping, quant detection,
| BF16 conversion, 3-tensor packing)
|
WeightProvider (unified GGUF/MLX interface)
|
build_dispatch_table() (format-agnostic)
Each layer has a single responsibility, and the abstractions are tight enough that the dispatch table builder genuinely does not know or care whether it is working with a GGUF Q4_0 model or an MLX 4-bit SafeTensors model. The dtype code (2 for Q4_0, 100 for MLX Q4) routes to the correct kernel through the DTypeDescriptor table, and the weight data arrives in the buffer layout that each kernel expects.
-
Hugging Face, “SafeTensors: A simple and safe way to store and distribute tensors,” 2023. The format was designed specifically to prevent arbitrary code execution vulnerabilities inherent in Python pickle deserialization. See https://github.com/huggingface/safetensors. ↩
-
Asymmetric quantization with a zero-point (bias) can represent the range [min, max] directly, while symmetric quantization forces the range to be [-max, max]. For activations and weights that are not centered at zero, asymmetric quantization wastes fewer quantization levels. ↩
-
Apple’s Metal Shading Language gained bfloat support with the Apple GPU Family 9 (M4). On M1-M3, BF16 textures and buffer reads are not supported in Metal shaders. ↩
-
The precision loss from BF16-to-F16 conversion in scale values is typically on the order of 0.1% relative error, which is negligible compared to the quantization error from the 4-bit weight compression itself. ↩
-
This difference has practical implications for tooling. GGUF files can be inspected with gguf-dump to see everything about a model. SafeTensors models require reading multiple files to get the full picture. On the other hand, SafeTensors files are trivially inspectable with any JSON-aware tool. ↩
Quantization Formats In Depth
This chapter is the definitive reference for every quantization format Akunu supports. We will go byte by byte through the memory layouts, work through dequantization by hand, and build the mental model you need to write or debug quantized GEMV kernels on Metal. If you have ever wondered what exactly lives inside a Q4_0 block, or how K-quant super-blocks manage to pack 256 values with mixed bit widths, this is the chapter.
Why Quantization Matters
A 7B parameter model in F16 requires 14 GB of memory. That exceeds the unified memory of every base-model MacBook Air and most MacBook Pros. Quantize those weights to 4 bits and you are down to 3.5 GB – comfortably fitting on a machine with 8 GB of RAM, with plenty left over for KV cache and activations.
But quantization is not free. You are trading precision for memory, and the format you choose determines both the quality of that trade-off and the computational cost of dequantizing at inference time. GGUF alone defines over a dozen quantization formats. MLX adds its own family. Understanding them is essential for anyone working on inference engines.
GGUF Legacy Formats
These are the original formats from GGML/llama.cpp. They operate on fixed-size blocks of elements, each block containing packed quantized values plus per-block scale factors.
Q4_0: The Simplest Quantized Format
Q4_0 is where most people should start understanding quantization. It is symmetric 4-bit quantization with a single F16 scale per block of 32 elements.
Block layout (18 bytes total):
+------------------+----------------------------------+
| F16 scale (d) | 16 bytes of nibbles |
| 2 bytes | (32 x 4-bit values) |
+------------------+----------------------------------+
Byte 0-1 Bytes 2-17
Each byte in the nibble section holds two 4-bit values:
Byte layout (bits 7..0): [ hi nibble | lo nibble ]
                         [  q[2i+1]  |   q[2i]   ]
bits 0-3 (lo nibble) = q[2i]
bits 4-7 (hi nibble) = q[2i+1]
Dequantization formula:
x[i] = d * (q[i] - 8)
The subtraction of 8 centers the 4-bit range [0, 15] around zero, giving an effective range of [-8, 7]. This is symmetric quantization – the zero point is fixed at 8, not learned.
Worked example: suppose we have a block with scale d = 0.5 (as F16) and the first data byte is 0xA3.
Byte 0xA3 = 1010 0011 in binary
lo nibble = 0011 = 3 -> q[0] = 3
hi nibble = 1010 = 10 -> q[1] = 10
x[0] = 0.5 * (3 - 8) = 0.5 * (-5) = -2.5
x[1] = 0.5 * (10 - 8) = 0.5 * (2) = 1.0
Bit extraction in Metal:
// Extract two Q4_0 values from a byte
uint8_t byte = block_data[j];
int q_lo = (byte & 0x0F); // bits 0-3
int q_hi = (byte >> 4) & 0x0F; // bits 4-7
float x_lo = d * ((float)q_lo - 8.0f);
float x_hi = d * ((float)q_hi - 8.0f);
Size calculation:
| Parameter | Value |
|---|---|
| Block size | 32 elements |
| Scale overhead | 2 bytes (F16) |
| Data | 16 bytes (32 nibbles) |
| Total per block | 18 bytes |
| Bits per weight | 18 * 8 / 32 = 4.5 bpw |
The overhead of the scale factor means Q4_0 is not exactly 4 bits per weight – it is 4.5 bpw. This is a common source of confusion. The “4” in Q4_0 refers to the quantized value width, not the effective bits per weight.
Q4_1: Asymmetric 4-bit
Q4_1 adds a minimum value (zero-point) per block, enabling asymmetric quantization:
Block layout (20 bytes):
+------------------+------------------+------------------+
| F16 scale (d) | F16 minimum (m) | 16 bytes nibbles |
| 2 bytes | 2 bytes | (32 x 4-bit) |
+------------------+------------------+------------------+
Bytes 0-1 Bytes 2-3 Bytes 4-19
Dequantization:
x[i] = d * q[i] + m
No subtraction of 8 here – the minimum m handles the offset. The range [0, 15] maps to [m, m + 15*d]. This can better represent distributions that are not centered at zero, at the cost of 2 extra bytes per block.
| Parameter | Value |
|---|---|
| Block size | 32 elements |
| Scale + min overhead | 4 bytes |
| Data | 16 bytes |
| Total per block | 20 bytes |
| Bits per weight | 20 * 8 / 32 = 5.0 bpw |
Q8_0: 8-bit Symmetric
Q8_0 stores each value as a signed 8-bit integer with a single F16 scale per block of 32:
Block layout (34 bytes):
+------------------+----------------------------------+
| F16 scale (d) | 32 bytes of int8 values |
| 2 bytes | (32 x 8-bit signed) |
+------------------+----------------------------------+
Bytes 0-1 Bytes 2-33
Dequantization:
x[i] = d * q[i] // q[i] is int8, range [-128, 127]
No offset subtraction needed because int8 is already signed.
| Parameter | Value |
|---|---|
| Block size | 32 elements |
| Scale overhead | 2 bytes |
| Data | 32 bytes |
| Total per block | 34 bytes |
| Bits per weight | 34 * 8 / 32 = 8.5 bpw |
Q8_0 is primarily used for activations in mixed-precision inference, not for weight storage (since 8.5 bpw barely saves memory over F16’s 16 bpw). Its main advantage is that int8 dot products can use SIMD integer multiply-accumulate instructions, which are faster than F16 multiply-accumulate on some hardware.
K-Quants: Super-Block Architecture
The K-quant family (Q2_K through Q6_K) was introduced by @ikawrakow in llama.cpp to improve quantization quality at low bit widths.1 The key insight is that a single scale factor per 32 elements is too coarse for 2-3 bit quantization – the approximation error is unacceptably high. K-quants solve this with a two-level hierarchy: super-blocks of 256 elements containing multiple sub-blocks, each with its own scale.
Super-block (256 elements)
+-----------------------------------------------------------+
| Sub-block scales (quantized to fewer bits themselves) |
| Super-block scale (F16) + super-block min (F16) |
+-----------------------------------------------------------+
| Sub-block 0 (32 elements, quantized data) |
| Sub-block 1 (32 elements, quantized data) |
| ... |
| Sub-block 7 (32 elements, quantized data) |
+-----------------------------------------------------------+
The sub-block scales are themselves quantized (usually to 4 or 6 bits), and the super-block scale converts them back to floating point. This is nested quantization – you quantize the quantization parameters.
Q2_K: 2-bit with 4-bit Sub-Block Scales
The most aggressive K-quant. Each element gets only 2 bits, but the hierarchical scales keep quality surprisingly usable.
Super-block layout (256 elements, 84 bytes):
+--------------------------------------------------+
| F16 super-scale (d) | 2 bytes |
| F16 super-minimum (dmin) | 2 bytes |
| 16 bytes: sub-block scales | (16 x 4-bit pairs) |
| each byte: [scale_hi:scale_lo] |
| 64 bytes: quantized data | (256 x 2-bit) |
+--------------------------------------------------+
The 16 bytes of sub-block scales encode 16 pairs of (scale, minimum) values, one for each sub-block of 16 elements. Each pair is packed into a single byte as two 4-bit values.
Dequantization (for element i in sub-block j):
sub_scale = (scales_byte[j] & 0x0F)
sub_min = (scales_byte[j] >> 4)
x[i] = d * sub_scale * q[i] - dmin * sub_min
| Parameter | Value |
|---|---|
| Super-block size | 256 elements |
| Overhead | 2 (d) + 2 (dmin) + 16 (sub-scales) = 20 bytes |
| Data | 64 bytes (256 x 2-bit) |
| Total | 84 bytes |
| Bits per weight | 84 * 8 / 256 = 2.625 bpw |
Q3_K: 3-bit with Mixed Sub-Block Scales
Q3_K uses 3 bits per value with 256-element super-blocks.
Super-block layout (256 elements, 110 bytes):
+--------------------------------------------------+
| F16 super-scale (d) | 2 bytes |
| 12 bytes: quantized sub-scales | |
| (16 scales, 6-bit each, packed) | |
| 32 bytes: high-bits of quants | |
| (256 bits, 1 per element) | |
| 64 bytes: low-bits of quants | |
| (256 x 2-bit) | |
+--------------------------------------------------+
The 3-bit values are split across two regions: the low 2 bits are packed into the 64-byte “quants low” section (like Q2_K), and the high bit is stored separately in the 32-byte “high bits” section. This split layout simplifies SIMD extraction.
Dequantization:
q_lo = (quants_lo[i/4] >> (2 * (i%4))) & 0x03 // 2 low bits
q_hi = (hmask[i/8] >> (i%8)) & 1 // 1 high bit
q = q_lo | (q_hi << 2) // 3-bit value [0..7]
x[i] = d * sub_scale * (q - 4) // center at 4
| Parameter | Value |
|---|---|
| Super-block size | 256 elements |
| Total | 110 bytes |
| Bits per weight | 110 * 8 / 256 = 3.4375 bpw |
Q4_K: 4-bit K-Quant
Q4_K is the workhorse of the K-quant family. It provides a good balance of quality and compression that works well for most models.
Super-block layout (256 elements, 144 bytes):
+--------------------------------------------------+
| F16 super-scale (d) | 2 bytes |
| F16 super-minimum (dmin) | 2 bytes |
| 12 bytes: sub-block scales+mins | |
| (8 sub-blocks, 6-bit scale + | |
| 6-bit min, packed) | |
| 128 bytes: quantized data | |
| (256 x 4-bit nibbles) | |
+--------------------------------------------------+
Each of the 8 sub-blocks has a 6-bit scale and a 6-bit minimum, packed into the 12-byte scale section. The packing is non-trivial – the 6-bit values are split across multiple bytes.
Scale packing detail (12 bytes for 8 sub-blocks):
Bytes 0-3: low 4 bits of scales[0..7] (4 bits each, 2 per byte)
Bytes 4-7: low 4 bits of mins[0..7] (4 bits each, 2 per byte)
Bytes 8-9: high 2 bits of scales[0..7] (2 bits each, packed)
Bytes 10-11: high 2 bits of mins[0..7] (2 bits each, packed)
Dequantization:
scale_6bit = low4(scales, j) | (high2(scales, j) << 4)
min_6bit = low4(mins, j) | (high2(mins, j) << 4)
x[i] = d * scale_6bit * q[i] - dmin * min_6bit
| Parameter | Value |
|---|---|
| Super-block size | 256 elements |
| Total | 144 bytes |
| Bits per weight | 144 * 8 / 256 = 4.5 bpw |
Q5_K: 5-bit K-Quant
Q5_K extends Q4_K with an extra bit per value:
Super-block layout (256 elements, 176 bytes):
+--------------------------------------------------+
| F16 super-scale (d) | 2 bytes |
| F16 super-minimum (dmin) | 2 bytes |
| 12 bytes: sub-block scales+mins | |
| 128 bytes: low nibbles | |
| (256 x 4-bit) | |
| 32 bytes: high bits | |
| (256 x 1-bit) | |
+--------------------------------------------------+
Like Q3_K, the 5th bit is stored separately from the low 4 bits. This allows the low-nibble extraction to use the same SIMD patterns as Q4_K.
| Parameter | Value |
|---|---|
| Super-block size | 256 elements |
| Total | 176 bytes |
| Bits per weight | 176 * 8 / 256 = 5.5 bpw |
Q6_K: 6-bit K-Quant
Q6_K is the highest-quality K-quant, approaching F16 accuracy for most models.
Super-block layout (256 elements, 210 bytes):
+--------------------------------------------------+
| F16 super-scale (d) | 2 bytes |
| 16 bytes: sub-block scales (int8) | |
| 128 bytes: low nibbles | |
| (256 x 4-bit) | |
| 64 bytes: high dibits | |
| (256 x 2-bit) | |
+--------------------------------------------------+
Q6_K simplifies the scale storage: each sub-block scale is a full int8 value (not quantized further). There is no separate minimum – Q6_K uses symmetric quantization like Q4_0.
Dequantization:
q_lo = (quants_lo[i/2] >> (4*(i%2))) & 0x0F // 4 low bits
q_hi = (quants_hi[i/4] >> (2*(i%4))) & 0x03 // 2 high bits
q = q_lo | (q_hi << 4) // 6-bit value [0..63]
x[i] = d * sub_scale_int8 * (q - 32) // center at 32
| Parameter | Value |
|---|---|
| Super-block size | 256 elements |
| Total | 210 bytes |
| Bits per weight | 210 * 8 / 256 = 6.5625 bpw |
K-Quant Summary Table
| Format | Bits/value | Block size | Bytes/block | Effective bpw | Scale type | Symmetry |
|---|---|---|---|---|---|---|
| Q2_K | 2 | 256 | 84 | 2.63 | 4-bit nested | Asymmetric |
| Q3_K | 3 | 256 | 110 | 3.44 | 6-bit nested | Symmetric |
| Q4_K | 4 | 256 | 144 | 4.50 | 6-bit nested | Asymmetric |
| Q5_K | 5 | 256 | 176 | 5.50 | 6-bit nested | Asymmetric |
| Q6_K | 6 | 256 | 210 | 6.56 | int8 | Symmetric |
And for comparison, the legacy formats:
| Format | Bits/value | Block size | Bytes/block | Effective bpw | Scale type | Symmetry |
|---|---|---|---|---|---|---|
| Q4_0 | 4 | 32 | 18 | 4.50 | F16 | Symmetric |
| Q4_1 | 4 | 32 | 20 | 5.00 | F16 + F16 min | Asymmetric |
| Q5_0 | 5 | 32 | 22 | 5.50 | F16 | Symmetric |
| Q8_0 | 8 | 32 | 34 | 8.50 | F16 | Symmetric |
MLX Per-Group Quantization
MLX takes a different approach. Rather than defining custom block layouts with packed scales, MLX uses a straightforward per-group scheme with separate tensors for weights, scales, and biases.
Layout
For a weight matrix of shape [N, K] quantized to B bits with group size G:
Weight tensor: shape [N, K*B/32], dtype U32
Scales tensor: shape [N, K/G], dtype F16 or BF16
Biases tensor: shape [N, K/G], dtype F16 or BF16
Each U32 word packs 32/B quantized values. The values within a U32 are stored contiguously from LSB to MSB.
Bit Extraction
For B-bit quantization, extracting the j-th value from a U32:
uint32_t word = packed_weights[word_index];
uint32_t mask = (1u << B) - 1; // B ones
int shift = (j % (32 / B)) * B;
uint32_t q = (word >> shift) & mask;
Example: 4-bit extraction from U32 word 0xFEDCBA98:
Binary: 1111 1110 1101 1100 1011 1010 1001 1000
Value 0 (bits 0-3): 1000 = 8
Value 1 (bits 4-7): 1001 = 9
Value 2 (bits 8-11): 1010 = 10
Value 3 (bits 12-15): 1011 = 11
Value 4 (bits 16-19): 1100 = 12
Value 5 (bits 20-23): 1101 = 13
Value 6 (bits 24-27): 1110 = 14
Value 7 (bits 28-31): 1111 = 15
Dequantization
group_index = j / G
x[i][j] = scales[i][group_index] * q[i][j] + biases[i][group_index]
This is asymmetric affine quantization. The bias acts as a zero-point, allowing the quantization grid to cover any range, not just one centered at zero.
MLX Bit Width Variants
Akunu supports four MLX quantization widths, each mapped to a dtype code:
| Bit width | Dtype code | Values per U32 | Typical group size | Effective bpw |
|---|---|---|---|---|
| 3-bit | 99 | 10 (+ 2 bits padding) | 64 | ~3.5 |
| 4-bit | 100 | 8 | 64 | ~4.5 |
| 6-bit | 102 | 5 (+ 2 bits padding) | 64 | ~6.5 |
| 8-bit | 101 | 4 | 64 | ~8.5 |
The effective bpw includes the overhead of scale and bias storage. For a [4096, 4096] matrix with group size 64:
Weight data: 4096 * 4096 * B / 8 bytes
Scale data: 4096 * (4096/64) * 2 bytes = 4096 * 64 * 2 = 524,288 bytes
Bias data: same as scale = 524,288 bytes
Total overhead: 1,048,576 bytes (~1 MB)
This overhead is constant regardless of bit width, and is small relative to the weight data for large matrices.
3-bit Packing Detail
3-bit is the most irregular case because 32 is not evenly divisible by 3. MLX packs 10 three-bit values into each U32 (10 * 3 = 30 bits), leaving 2 bits unused:
U32 word: [unused:2][q9:3][q8:3][q7:3][q6:3][q5:3][q4:3][q3:3][q2:3][q1:3][q0:3]
Bits: 31-30 29-27 26-24 23-21 20-18 17-15 14-12 11-9 8-6 5-3 2-0
The extraction code:
uint32_t word = packed[word_index];
int pos_in_word = j % 10;
int shift = pos_in_word * 3;
uint32_t q = (word >> shift) & 0x7; // mask = 0b111
GPU Buffer Layout (Packed)
As discussed in the previous chapter, Akunu packs the three MLX tensors into a single GPU buffer for each weight matrix:
Offset 0: Packed U32 weights
Offset weight_bytes: F16 scales
Offset weight_bytes + scale_bytes: F16 biases
The Metal kernel receives the buffer pointer and a weight_bytes parameter. It computes scale and bias offsets arithmetically:
device const half *scales = (device const half *)
    ((device const char *)weights + params.weight_bytes);
device const half *biases = scales + (params.N * params.K / params.group_size);
Metal Kernel Dequantization Patterns
Each format requires a different dequantization strategy in the GEMV kernel. Here are the common patterns:
Q4_0 GEMV Inner Loop
// Each thread processes a chunk of the K dimension
for (int k = tid; k < K; k += stride) {
int block_idx = k / 32;
int block_off = k % 32;
// Load block header
half d = block_scales[block_idx];
// Load and extract nibble
int byte_idx = block_off / 2;
uint8_t byte = block_data[block_idx * 16 + byte_idx];
int nibble = (block_off & 1) ? (byte >> 4) : (byte & 0x0F);
// Dequantize and accumulate
float w = float(d) * (float(nibble) - 8.0f);
sum += w * float(input[k]);
}
K-Quant GEMV Pattern (Q4_K)
// Process one super-block (256 elements) at a time
for (int sb = ...; sb < n_superblocks; sb++) {
half d = super_scales[sb];
half dmin = super_mins[sb];
// Decode sub-block scales (6-bit from packed bytes)
for (int sub = 0; sub < 8; sub++) {
int sc = decode_6bit_scale(scale_bytes, sub);
int mn = decode_6bit_min(scale_bytes, sub);
float sub_scale = float(d) * sc;
float sub_min = float(dmin) * mn;
for (int k = 0; k < 32; k++) {
int q = extract_nibble(data, sub*32 + k);
float w = sub_scale * q - sub_min;
sum += w * float(input[sb*256 + sub*32 + k]);
}
}
}
MLX GEMV Pattern
// Process one group at a time
for (int g = 0; g < K / group_size; g++) {
half scale = scales[row * n_groups + g];
half bias = biases[row * n_groups + g];
for (int k = 0; k < group_size; k++) {
int global_k = g * group_size + k;
uint32_t word = packed[row * K_packed + global_k / values_per_word];
int pos = global_k % values_per_word;
uint32_t q = (word >> (pos * bits)) & bit_mask;
float w = float(scale) * float(q) + float(bias);
sum += w * float(input[global_k]);
}
}
Quality vs Size Comparison
The following table summarizes quality-size trade-offs. Perplexity numbers are approximate and vary by model, but the relative ordering is consistent.2
| Format | Effective bpw | Model size (7B) | Perplexity impact | Best use case |
|---|---|---|---|---|
| F16 | 16.0 | 14 GB | Baseline | Reference / debugging |
| Q8_0 | 8.5 | 7.4 GB | Negligible | Activation quantization |
| Q6_K | 6.56 | 5.7 GB | Very small | Quality-sensitive apps |
| Q5_K | 5.50 | 4.8 GB | Small | Good quality/size balance |
| Q4_K | 4.50 | 3.9 GB | Moderate | Best general-purpose |
| Q4_0 | 4.50 | 3.9 GB | Moderate+ | Fastest decode (simple format) |
| Q3_K | 3.44 | 3.0 GB | Noticeable | Memory-constrained |
| Q2_K | 2.63 | 2.3 GB | Significant | Extreme compression |
| MLX Q4 | ~4.5 | ~3.9 GB | Moderate | MLX ecosystem models |
| MLX Q3 | ~3.5 | ~3.1 GB | Noticeable | MLX ecosystem, low memory |
| MLX Q8 | ~8.5 | ~7.4 GB | Negligible | High quality MLX |
How Akunu Selects Kernels
The dtype code embedded in (or derived from) the weight file determines which kernels are used. Akunu’s DTypeDescriptor table maps each dtype to a complete set of kernel names:
| Dtype | Code | GEMV kernel | GEMM kernel | Fused SiLU | Embedding |
|---|---|---|---|---|---|
| F16 | 1 | gemv_f16 | simd_gemm_f16 | – | embedding_lookup_f16 |
| Q4_0 | 2 | gemv_q4_0 | simd_gemm_q4_0 | gemv_q4_0_silu | embedding_lookup_q4_0 |
| Q4_1 | 3 | gemv_q4_1 | simd_gemm_q4_1 | – | embedding_lookup_q4_1 |
| Q8_0 | 8 | gemv_q8_0 | simd_gemm_q8_0 | – | embedding_lookup_q8_0 |
| Q2_K | 10 | gemv_q2_k | simd_gemm_q2_k | – | – |
| Q3_K | 11 | gemv_q3_k | simd_gemm_q3_k | – | – |
| Q4_K | 12 | gemv_q4_k | simd_gemm_q4_k | – | embedding_lookup_q4_k |
| Q5_K | 13 | gemv_q5_k | simd_gemm_q5_k | – | – |
| Q6_K | 14 | gemv_q6_k | simd_gemm_q6_k | – | embedding_lookup_q6_k |
| BF16 | 31 | gemv_bf16 | simd_gemm_bf16 | – | embedding_lookup_bf16 |
| MLX Q3 | 99 | gemv_mlx_q3 | simd_gemm_mlx_q3 | gemv_mlx_q3_silu | embedding_lookup_mlx_generic |
| MLX Q4 | 100 | gemv_mlx_q4 | simd_gemm_mlx_q4 | gemv_mlx_q4_silu | embedding_lookup_mlx_q4 |
| MLX Q6 | 102 | gemv_mlx_q6 | simd_gemm_mlx_q6 | gemv_mlx_q6_silu | embedding_lookup_mlx_generic |
| MLX Q8 | 101 | gemv_mlx_q8 | simd_gemm_mlx_q8 | gemv_mlx_q8_silu | embedding_lookup_mlx_generic |
Note the pattern: GGUF formats have dtype codes below 32 (matching GGML’s enum), while MLX formats use codes 99-102. This avoids any collision between the two namespaces.
Each descriptor also includes dispatch geometry – the number of rows per threadgroup and the threadgroup size. These are tuned per format because different formats have different computational density:
| Format family | Rows/threadgroup | Threadgroup size | Rationale |
|---|---|---|---|
| F16 | 16 | 128 | Simple dequant, high arithmetic density |
| Q4_0/Q4_1 | 16 | 128 | Simple block format, fast extraction |
| Q8_0 | 32 | 256 | Larger data per block, needs more threads |
| K-quants | 16 | 256 | Complex nested dequant, more ALU work |
| MLX all | 16 | 128 | Group-based, moderate complexity |
Mixed Quantization
Many GGUF models use different quantization levels for different layers. For example, a Q4_K_M quantization (the "M" suffix marks llama.cpp's "medium" mixed-precision recipe) might use:
- Q6_K for the attention norms and output norm (small tensors, quality-sensitive)
- Q4_K for most weight matrices
- Q5_K for the first and last few layers
Akunu handles this transparently because get_dtype() returns the per-tensor dtype, and build_dispatch_table() selects the kernel for each weight individually:
snprintf(name, sizeof(name), "layers.%d.attention.q.weight", layer);
uint32_t q_dtype = weights.get_dtype(name); // might be Q4_K
snprintf(name, sizeof(name), "layers.%d.attention.k.weight", layer);
uint32_t k_dtype = weights.get_dtype(name); // might be Q6_K
// Each gets the correct kernel
gemv(input, q_weight, output_q, 0, q_dtype, q_dim, dim);
gemv(input, k_weight, output_k, 0, k_dtype, kv_dim, dim);
Weight fusion (QKV or gate+up) requires matching dtypes – you cannot fuse a Q4_K weight with a Q6_K weight because they have different block layouts. The fusion check verifies this:
bool fuse_qkv = q_dtype == k_dtype && k_dtype == v_dtype;
If the dtypes do not match, Akunu falls back to separate GEMV dispatches.
Practical Guidance
For users choosing a quantization format:
- Q4_K_M is the sweet spot for most use cases. It provides good quality at ~4.5 bpw with the K-quant’s hierarchical scales.
- MLX Q4 is comparable in quality and works well with models from the MLX ecosystem.
- Q4_0 is slightly lower quality than Q4_K but uses simpler block structure, which can be faster for decode (where GEMV is the bottleneck).
- Q6_K or MLX Q8 if you can afford the memory and want near-lossless quality.
- Q2_K and MLX Q3 should be reserved for cases where memory is truly scarce. Quality degradation is noticeable.
For kernel developers:
- The block-of-32 formats (Q4_0, Q4_1, Q8_0) are the easiest to implement. Start there.
- K-quants require careful handling of the nested scale packing. Get the bit extraction right by testing against a reference implementation before optimizing.
- MLX formats are conceptually simpler (uniform group structure, no nested quantization) but require handling the three-tensor buffer layout and function constants for group size and K dimension.
- Always profile with real models. The format with the least memory usage is not always the fastest – simpler dequantization (Q4_0) can outperform complex dequantization (Q4_K) even at the same bit width, because the kernel spends less time on scale lookups.3
-
@ikawrakow, “k-quants: 2, 3, 4, 5, and 6-bit quantization for llama.cpp,” llama.cpp PR #1684, 2023. The key contribution was the super-block architecture that enables usable 2-3 bit quantization. See https://github.com/ggerganov/llama.cpp/pull/1684. ↩
-
Perplexity numbers are from the llama.cpp quantization benchmarks. Exact values depend on the model and evaluation dataset. The relative ordering (F16 > Q6_K > Q5_K > Q4_K > Q4_0 > Q3_K > Q2_K) is consistent across models. ↩
-
On Apple M2 Pro, Q4_0 GEMV for a 4096x4096 matrix runs at approximately 92% of memory bandwidth, while Q4_K achieves about 85%, despite both being ~4.5 bpw. The difference is the 6-bit sub-scale decoding overhead in Q4_K. ↩
KV Cache Design and Management
Every transformer-based language model has a dirty secret: the attention mechanism is fundamentally stateful. Each time the model generates a new token, it needs to look back at every previous token’s key and value projections. Without caching, you would re-compute K and V for the entire prompt on every single decode step, turning O(n) generation into O(n^2). The KV cache is what prevents that – it stores previously computed key and value tensors so the model only computes the new token’s K and V, then attends over the full cached history.
In this chapter, we will walk through exactly how Akunu designs, allocates, and manages its KV cache. The design philosophy is relentlessly simple: no virtual calls, no optionals, no reference counting. A flat POD struct with contiguous GPU buffers.
What the KV Cache Actually Stores
At every transformer layer, the attention mechanism projects the hidden state into three matrices: Q (query), K (key), and V (value). During decode, we only compute Q/K/V for the current token, but we need the K and V from all previous tokens to compute attention scores. So we keep a running buffer of K and V per layer.
Here is what that looks like conceptually:
For each layer l in [0, n_layers):
K cache[l] = all past key vectors for layer l
V cache[l] = all past value vectors for layer l
When generating token at position t, the model:
- Computes Q_t, K_t, V_t from the current hidden state
- Writes K_t, V_t into the cache at position t
- Reads the full K[0..t], V[0..t] from the cache
- Computes attention: softmax(Q_t @ K[0..t]^T / sqrt(d)) @ V[0..t]
The cache grows by one position per generated token. Without it, you would need to re-run the entire prompt through every layer on every step.
The KVCache Struct
Let us look at Akunu’s actual implementation. The struct lives in src/cache/kv_cache.h, and it is remarkably compact:
struct KVCache {
    int n_layers;
    int n_kv_heads;
    int head_dim;
    int max_length;
    int current_length;
    std::vector<Buffer> k_buffers; // one per layer
    std::vector<Buffer> v_buffers; // one per layer
    int kv_stride; // max_length * head_dim
};
That is it. Five integers, two vectors of GPU buffers, and a pre-computed stride. No inheritance hierarchy. No smart pointers. No allocators or memory pools. Just the data you need and nothing else.
Let us break each field down:
- n_layers: The number of transformer layers. For LLaMA 3-8B, this is 32. For a 70B model, 80. Each layer gets its own independent K buffer and V buffer.
- n_kv_heads: The number of key/value heads. In grouped-query attention (GQA), this is smaller than the number of query heads. LLaMA 3-8B has 32 query heads but only 8 KV heads – a 4:1 ratio that saves 75% of the cache memory.
- head_dim: The dimension of each attention head. Typically 128 for modern models (e.g., 4096 / 32 = 128 for LLaMA 3-8B).
- max_length: The maximum sequence length the cache can hold. Typically set to the model’s context window (e.g., 4096 or 8192).
- current_length: The stateful counter – how many positions have been written so far. Starts at 0, incremented by advance(), never exceeds max_length.
- kv_stride: Pre-computed as max_length * head_dim. This is the number of FP16 elements between consecutive KV heads in memory.
Memory Layout: Head-Major Ordering
Each K and V buffer has the following shape:
[n_kv_heads, max_seq_len, head_dim]
All stored in FP16 (half-precision, 2 bytes per element). Let us draw what this looks like in memory for a single layer:
Buffer: k_buffers[layer]
+------------------------------------------------------------------+
| KV Head 0 |
| +-----------------------------------------------------------+ |
| | pos 0: [d0, d1, d2, ..., d127] (head_dim=128, FP16) | |
| | pos 1: [d0, d1, d2, ..., d127] | |
| | pos 2: [d0, d1, d2, ..., d127] | |
| | ... | |
| | pos max_length-1: [d0, d1, ..., d127] | |
| +-----------------------------------------------------------+ |
| KV Head 1 |
| +-----------------------------------------------------------+ |
| | pos 0: [d0, d1, d2, ..., d127] | |
| | pos 1: [d0, d1, d2, ..., d127] | |
| | ... | |
| +-----------------------------------------------------------+ |
| ... |
| KV Head (n_kv_heads-1) |
| +-----------------------------------------------------------+ |
| | pos 0: [d0, d1, d2, ..., d127] | |
| | ... | |
| +-----------------------------------------------------------+ |
+------------------------------------------------------------------+
The linear address of element [head h, position p, dimension d] is:
offset = (h * kv_stride + p * head_dim + d) * sizeof(FP16)
= (h * max_length * head_dim + p * head_dim + d) * 2
Why Head-Major?
You might ask: why not position-major [max_seq_len, n_kv_heads, head_dim]?
The answer comes down to how attention kernels access memory.
During decode, each attention head operates independently. The flash attention
decode kernel processes one head at a time. For a given head h, it needs to
read all positions [0..current_length] for that head in sequence. With
head-major layout, these positions are contiguous in memory:
Head-major (what Akunu uses):
Reading head h, positions 0..T = contiguous read of T*head_dim elements
Great for coalesced GPU memory access!
Head 0: [pos0][pos1][pos2]...[posT][---padding---]
Head 1: [pos0][pos1][pos2]...[posT][---padding---]
...
Position-major (alternative):
Reading head h, positions 0..T = strided read with stride n_kv_heads*head_dim
Terrible for coalesced access!
Pos 0: [head0][head1]...[headN]
Pos 1: [head0][head1]...[headN]
...
With head-major layout, each SIMD group in the attention kernel reads a nice
contiguous chunk of memory. The kv_stride field (= max_length * head_dim)
gives the distance between head 0’s data and head 1’s data, which the kernel
uses to index into the right region.
Memory Budget Calculation
Let us work through a concrete example. Take LLaMA 3.1-8B with a 4096 context:
Parameters:
n_layers = 32
n_kv_heads = 8 (GQA: 32 query heads, 8 KV heads)
head_dim = 128
max_length = 4096
dtype = FP16 (2 bytes)
Per-layer buffer size:
size = n_kv_heads * max_length * head_dim * sizeof(FP16)
= 8 * 4096 * 128 * 2
= 8,388,608 bytes
= 8 MB
Total KV cache (K + V, all layers):
total = 2 * n_layers * size
= 2 * 32 * 8 MB
= 512 MB
That is 512 MB just for the KV cache on a 4096 context. Scale up to 8192 context and you are at 1 GB. For a 70B model with 80 layers and 8 KV heads:
Per-layer: 8 * 8192 * 128 * 2 = 16 MB
Total: 2 * 80 * 16 MB = 2,560 MB = 2.5 GB
This is why GQA is so important – without it (i.e., multi-head attention where
n_kv_heads = n_heads = 64), the 70B model would need:
64 * 8192 * 128 * 2 = 128 MB per layer
2 * 80 * 128 MB = 20,480 MB = 20 GB
GQA with 8 KV heads reduces cache memory by 8x. That is the difference between fitting on a MacBook Pro and not fitting at all.
Here is a summary table:
+----------------+--------+---------+--------+---------+---------+
| Model | Layers | KV Heads| Head D | Ctx Len | KV Size |
+----------------+--------+---------+--------+---------+---------+
| LLaMA 3-8B | 32 | 8 | 128 | 4096 | 512 MB |
| LLaMA 3-8B | 32 | 8 | 128 | 8192 | 1 GB |
| Qwen 2.5-7B | 28 | 4 | 128 | 4096 | 224 MB |
| Gemma 3-4B | 34 | 4 | 256 | 4096 | 896 MB |
| LLaMA 3-70B | 80 | 8 | 128 | 4096 | 1.25 GB |
| LLaMA 3-70B | 80 | 8 | 128 | 8192 | 2.5 GB |
+----------------+--------+---------+--------+---------+---------+
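The whole budget reduces to one formula. A small helper that reproduces the table above (an illustrative sketch, not part of Akunu's API):

```cpp
#include <cstddef>

// Total KV cache bytes (K + V buffers across all layers) for an FP16
// cache, following size = n_kv_heads * max_length * head_dim * 2 per
// buffer and total = 2 * n_layers * size.
inline size_t kv_cache_bytes(int n_layers, int n_kv_heads,
                             int head_dim, int max_length) {
    size_t per_buffer = (size_t)n_kv_heads * max_length * head_dim * 2;
    return 2 * (size_t)n_layers * per_buffer;  // K and V for every layer
}
```

Plugging in LLaMA 3-8B at 4096 context gives the 512 MB from the table; LLaMA 3-70B at 8192 context gives 2.5 GB.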
Pre-Allocation: No Malloc in the Hot Path
The KVCache::create() factory allocates everything upfront:
static KVCache create(Device& device, int n_layers, int n_kv_heads,
                      int head_dim, int max_length) {
    KVCache cache;
    // ... set fields ...
    size_t buf_size = (size_t)n_kv_heads * max_length * head_dim * sizeof(uint16_t);
    cache.k_buffers.resize(n_layers);
    cache.v_buffers.resize(n_layers);
    for (int i = 0; i < n_layers; i++) {
        cache.k_buffers[i] = device.allocate(buf_size);
        cache.v_buffers[i] = device.allocate(buf_size);
        memset(cache.k_buffers[i].contents, 0, buf_size);
        memset(cache.v_buffers[i].contents, 0, buf_size);
    }
    return cache;
}
Note several things:
- All buffers are allocated at once. No lazy allocation, no on-demand growth. You pay the memory cost at model load time, not during generation.
- Buffers are zero-filled. The memset ensures that unwritten positions have zero values. This matters because the attention kernel might read past current_length due to SIMD alignment, and we do not want garbage affecting softmax scores.
- The sizeof(uint16_t) is FP16. Akunu stores cache values in half precision. There is no option for FP32 or INT8 cache – keeping it simple.
- Device::allocate() returns a Buffer. On Metal, this is an MTLBuffer allocated in shared memory (Apple Silicon UMA), meaning both CPU and GPU can access it without explicit copies.
Stateful Tracking
The KV cache tracks how many positions have been filled via current_length. The
API provides four operations for managing this state:
advance(count)
After computing K and V for count new tokens (1 during decode, N during prefill),
call advance() to update the position:
void advance(int count) {
    current_length += count;
    if (current_length > max_length)
        current_length = max_length;
}
The clamping to max_length is a safety measure. In practice, the caller should
check would_overflow() before adding tokens, but the clamp prevents buffer
overruns if something goes wrong.
would_overflow(additional)
bool would_overflow(int additional) const {
    return current_length + additional > max_length;
}
The caller checks this before prefill or decode to avoid writing past the buffer. If it returns true, the inference engine must either truncate the input or refuse the request.
rollback(to_length)
void rollback(int to_length) {
    if (to_length < current_length)
        current_length = to_length;
}
Rollback moves the cursor backwards. The actual data in the buffers is not
erased – only the position counter changes. This is safe because future writes
at positions >= to_length will overwrite the stale data before it is read.
This is used for prefix caching in the server: if the new prompt shares a prefix with the previous one, we rollback to the shared prefix length and only re-compute the divergent tokens.
reset()
void reset() { current_length = 0; }
The nuclear option. Resets to the beginning without touching the actual buffer contents. The next prefill will overwrite everything.
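Taken together, the four operations form a tiny state machine. Here is the counter logic isolated from the GPU buffers – a standalone sketch, not the actual KVCache class:

```cpp
#include <algorithm>

// Position-counter logic of the KV cache, separated from the buffers.
// Illustrative sketch; Akunu implements these as KVCache methods.
struct KVCursor {
    int current_length = 0;
    int max_length;

    explicit KVCursor(int max_len) : max_length(max_len) {}

    bool would_overflow(int additional) const {
        return current_length + additional > max_length;
    }
    void advance(int count) {  // clamp is a safety net, not the API
        current_length = std::min(current_length + count, max_length);
    }
    void rollback(int to_length) {  // never moves the cursor forward
        if (to_length < current_length) current_length = to_length;
    }
    void reset() { current_length = 0; }
};
```

Notice that rollback() and reset() touch only the counter; the stale data they leave behind is overwritten before the next read, which is what makes them O(1).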
The Lifecycle of a Cache During Generation
Here is the full flow from prompt to generation:
1. User sends prompt: "What is the capital of France?"
Tokenized: [BOS, 1724, 338, 278, 7483, 310, 3444, 29973] (8 tokens)
2. PREFILL: process all 8 tokens in one batch
For each layer:
Compute K[0..7], V[0..7] <-- batch of 8 KV vectors
Write into cache at pos 0..7
cache.advance(8) <-- current_length = 8
3. DECODE step 1: generate token "The"
For each layer:
Compute K[8], V[8] <-- single new KV vector
Write into cache at pos 8
Attend Q[8] over K[0..8], V[0..8]
cache.advance(1) <-- current_length = 9
4. DECODE step 2: generate token " capital"
For each layer:
Compute K[9], V[9]
Write into cache at pos 9
Attend Q[9] over K[0..9], V[0..9]
cache.advance(1) <-- current_length = 10
5. ... continue until EOS or max_tokens ...
6. NEXT CONVERSATION:
New prompt: "What is the capital of Germany?"
Shares prefix: "What is the capital of " = 7 tokens
Option A: cache.reset() + full prefill (simple)
Option B: rollback to shared prefix + incremental prefill (efficient)
With prefix caching (Option B):
shared = 7 tokens match
cache.rollback(7) <-- current_length = 7
Prefill only tokens 7..N <-- "Germany?" = 2 tokens
cache.advance(2) <-- current_length = 9
Here is a timeline diagram:
Position in KV cache:
0 1 2 3 4 5 6 7 8 9 10 11
| | | | | | | | | | | |
[BOS][What][ is][ the][cap][ital][ of][ Fr][The][ ca][pit]
|----- prefill (8 tokens) -----| |-- decode step by step--|
^
current_length advances: 8 -> 9 -> 10 -> 11
After rollback(7) for new conversation:
0 1 2 3 4 5 6 7 8
| | | | | | | | |
[BOS][What][ is][ the][cap][ital][ of][Ger][many]
|-- preserved prefix (7) --| ^
incremental prefill: 2 tokens
Prefix Caching in the Server
The HTTP server (Chapter 50) maintains a ModelEntry per loaded model that tracks
the last prompt’s tokens:
struct ModelEntry {
    std::vector<uint32_t> cached_tokens;
    int cached_position = 0;

    int shared_prefix(const uint32_t *tokens, int n_tokens) const {
        int shared = 0;
        int limit = std::min((int)cached_tokens.size(), n_tokens);
        for (int i = 0; i < limit; i++) {
            if (cached_tokens[i] != tokens[i]) break;
            shared++;
        }
        return shared;
    }
};
When a new request arrives:
- Encode the new prompt to tokens
- Compare with cached_tokens to find the shared prefix length
- If shared > 0 && shared <= cached_position, use rollback + incremental prefill
- Otherwise, full reset + prefill from scratch
This gives you “free” prefix caching with zero extra infrastructure. In a chatbot scenario where the system prompt is the same across turns, you skip re-processing hundreds of tokens. For a 2048-token system prompt, that can save 50-100ms of prefill time on each request.
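The decision reduces to a pair of small functions. A sketch simplified from the server logic (names and signatures here are illustrative):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Length of the common prefix between the previous prompt's tokens and
// the new prompt's tokens (mirrors ModelEntry::shared_prefix).
inline int shared_prefix(const std::vector<uint32_t>& cached,
                         const std::vector<uint32_t>& fresh) {
    int limit = (int)std::min(cached.size(), fresh.size());
    int shared = 0;
    while (shared < limit && cached[shared] == fresh[shared]) shared++;
    return shared;
}

// Tokens that still need prefill after exploiting the shared prefix.
// cached_position is how far the KV cache is actually filled.
inline int tokens_to_prefill(const std::vector<uint32_t>& cached,
                             const std::vector<uint32_t>& fresh,
                             int cached_position) {
    int shared = std::min(shared_prefix(cached, fresh), cached_position);
    if (shared > 0) return (int)fresh.size() - shared;  // rollback path
    return (int)fresh.size();                           // full prefill
}
```

With a 2048-token system prompt shared across turns, tokens_to_prefill collapses from thousands to just the new user turn.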
How the Attention Kernel Uses the Cache
During decode, the flash attention kernel receives the cache buffers as arguments. Here is a simplified view of the kernel parameters:
flash_attention_decode_fast_f16(
Q: [n_heads, 1, head_dim] <-- current token's query
K_cache: [n_kv_heads, max_seq, head_dim] <-- full K cache for this layer
V_cache: [n_kv_heads, max_seq, head_dim] <-- full V cache for this layer
output: [n_heads, 1, head_dim] <-- attention output
kv_seq_len: current_length + 1 <-- how far to read
kv_stride: max_length * head_dim <-- stride between heads
scale: 1.0 / sqrt(head_dim)
)
The kernel uses kv_stride to jump between heads and kv_seq_len to know how
many positions to attend over. It reads K[h, 0..kv_seq_len-1, :] and
V[h, 0..kv_seq_len-1, :] contiguously for each head h.
For GQA (grouped-query attention), where multiple Q heads share a single KV head,
the kernel maps Q head index to KV head index with integer division:
kv_head = q_head / (n_heads / n_kv_heads).
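This mapping is a one-liner, but worth writing down, since an off-by-one here silently attends to the wrong cache region (the helper name is illustrative):

```cpp
// Map a query head to its shared KV head under GQA: consecutive groups
// of (n_heads / n_kv_heads) query heads share one KV head.
inline int kv_head_for(int q_head, int n_heads, int n_kv_heads) {
    return q_head / (n_heads / n_kv_heads);
}
```

For LLaMA 3-8B (32 query heads, 8 KV heads), query heads 0-3 map to KV head 0, heads 4-7 to KV head 1, and so on.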
Cleanup
The destroy() method frees all GPU buffers:
void destroy(Device& device) {
    for (auto& b : k_buffers) device.free_buffer(b);
    for (auto& b : v_buffers) device.free_buffer(b);
    k_buffers.clear();
    v_buffers.clear();
}
No destructor magic. The caller is responsible for calling destroy() before
the Device goes away. This is deliberate – GPU resource lifetimes must be
explicit in a system without garbage collection.
What Akunu Does NOT Do
It is worth noting what this KV cache design deliberately omits:
- No paged attention. Systems like vLLM use virtual memory paging to efficiently share cache across sequences. Akunu allocates one flat buffer per layer, trading memory efficiency for simplicity and zero fragmentation.
- No multi-sequence support. The cache tracks a single current_length. There is no batch dimension or per-sequence tracking. For serving multiple concurrent conversations, you would need multiple model instances.
- No quantized cache. Some inference engines store KV in INT8 or INT4 to reduce memory. Akunu keeps everything in FP16 for maximum quality and simplicity.
- No sliding window. Some architectures (Mistral) use a sliding window where old positions are evicted. Akunu's cache is a simple grow-only buffer with a hard maximum.
- No speculative decoding cache management. Systems that do speculative decoding need to speculatively advance and then rollback the cache. Akunu's rollback() could support this, but the current codebase does not implement speculative decoding.
These are all conscious trade-offs. For a single-user inference engine targeting Apple Silicon, the simple flat-buffer approach gives you maximum GPU throughput (contiguous memory access) and zero overhead (no bookkeeping, no page tables).
Summary
+-------------------------------+----------------------------------+
| Design Decision | Rationale |
+-------------------------------+----------------------------------+
| POD struct, no inheritance | Cache-line friendly, no vtable |
| Head-major layout | Contiguous reads per attention |
| | head in flash attention kernel |
| FP16 storage | 2x smaller than FP32, native |
| | Metal half-precision support |
| Pre-allocate max_length | Zero allocation in hot path |
| Zero-fill on creation | Safe reads past current_length |
| Single current_length counter | Simple state machine, no locks |
| Pre-computed kv_stride | One less multiply per kernel |
| Explicit destroy() | Deterministic GPU resource mgmt |
+-------------------------------+----------------------------------+
The KV cache is the single largest runtime memory consumer after the model weights themselves. Understanding its layout is essential for reasoning about memory budgets, context window limits, and the performance characteristics of the attention kernel. Next, we will look at the other half of the runtime memory story: scratch buffers.
Scratch Buffer Architecture
If the KV cache is the long-term memory of inference, scratch buffers are the
working memory – the scratchpad where every intermediate computation happens.
Matrix multiplications, attention outputs, FFN activations, logits – all of these
need temporary GPU buffers, and if you allocate them on the fly during generation,
you are dead. GPU memory allocation is slow. Fragmentation is real. And the last
thing you want in a tight decode loop running at 100+ tokens per second is a call
to MTLDevice.makeBuffer().
Akunu’s solution is dead simple: allocate every scratch buffer at model load
time, reuse them every forward pass, and never touch the allocator again. This
chapter walks through the ScratchBuffers struct, the ping-pong pattern, the
QKV sub-offset trick, and the full memory budget.
The Core Idea: Zero Allocation in the Hot Path
Here is the principle, stated bluntly: after ScratchBuffers::create() returns,
zero bytes of GPU memory are ever allocated during inference. Not one buffer.
Not one resize. Nothing.
Every temporary result – the hidden state after embedding lookup, the Q/K/V projections, the attention output, the FFN gate and up projections, the activated intermediate, the final logits – all live in pre-allocated buffers that get overwritten on every forward pass.
This is possible because the sizes of all intermediates are known at model load
time. The model config tells us dim, q_dim, kv_dim, ffn_dim, and
vocab_size. Every buffer size is a simple function of these constants.
The ScratchBuffers Struct
The full struct from src/cache/scratch.h:
struct ScratchBuffers {
    // === Decode buffers (single token) ===
    Buffer h0;        // [dim] FP16
    Buffer h1;        // [dim] FP16
    Buffer residual;  // [dim] FP16
    Buffer qkv;       // [q_dim + 2*kv_dim] FP16
    Buffer attn_out;  // [max(q_dim, dim)] FP16
    Buffer post_norm; // [dim] FP16
    Buffer ffn_gate;  // [ffn_dim] FP16
    Buffer ffn_up;    // [ffn_dim] FP16
    Buffer ffn_act;   // [ffn_dim] FP16
    Buffer logits;    // [vocab_size] FP16
    Buffer token_ids; // [max_chain] U32

    int qkv_q_offset; // byte offset of Q within qkv
    int qkv_k_offset; // byte offset of K within qkv
    int qkv_v_offset; // byte offset of V within qkv

    // === Prefill buffers (batch) ===
    Buffer batch_h0, batch_h1, batch_residual;
    Buffer batch_q, batch_k, batch_v;
    Buffer batch_attn_out;
    Buffer batch_gate, batch_up, batch_act;
    Buffer batch_post_norm;

    int max_prefill_chunk;
};
There are two sets of buffers: decode buffers (single-token inference) and prefill buffers (batch processing of the prompt). Let us go through each one.
Decode Buffers: The Single-Token Pipeline
During decode, exactly one token flows through the transformer per step. The buffers are sized for a single-row computation:
Forward pass data flow (one decode step):
token_ids ──> [embedding lookup] ──> h0 [dim]
|
+─────────────+
|
[layer loop x N_layers]
|
h0 ──> [RMSNorm] ──> residual ──> [QKV GEMV] ──> qkv [q_dim+2*kv_dim]
|
+-──────────────+──────────────+
| | |
Q [q_dim] K [kv_dim] V [kv_dim]
| | |
| [KV cache write] |
| | |
+──> [Flash Attention] <──────+
|
attn_out [dim]
|
[O projection]
|
h1 [dim] ──+──> h0 = h0 + h1
| (residual add)
+───────────────────────────────+
|
h0 ──> [RMSNorm] ──> residual ──> [Gate GEMV] ──> ffn_gate [ffn_dim]
|──> [Up GEMV] ──> ffn_up [ffn_dim]
|
[SiLU(gate) * up] ──> ffn_act [ffn_dim]
|
[Down GEMV] ──> h1 [dim]
|
h0 = h0 + h1 (residual add)
|
[end layer loop]
|
h0 ──> [RMSNorm] ──> residual ──> [Logit GEMV] ──> logits [vocab_size]
|
[argmax/sample]
|
token_ids (next token)
h0 and h1: The Ping-Pong Pair
These are the two primary hidden state buffers, each sized [dim] in FP16.
They implement a ping-pong pattern: the output of one operation goes into
one buffer, the next operation reads from that buffer and writes to the other.
Layer start: h0 holds the current hidden state
RMSNorm: reads h0, writes residual
QKV GEMV: reads residual, writes qkv
Attention: reads qkv + KV cache, writes attn_out
O Projection: reads attn_out, writes h1
Residual add: h0 = h0 + h1 (h0 updated in-place)
FFN start: h0 holds the updated hidden state
RMSNorm: reads h0, writes residual
Gate GEMV: reads residual, writes ffn_gate
Up GEMV: reads residual, writes ffn_up
Activation: reads ffn_gate + ffn_up, writes ffn_act
Down GEMV: reads ffn_act, writes h1
Residual add: h0 = h0 + h1 (h0 updated in-place)
Next layer: h0 holds the result
Why ping-pong instead of in-place? Because GPU kernels read and write concurrently. You cannot safely read and write the same buffer in a single dispatch. The ping-pong pattern ensures the read source and write destination are always different physical buffers.
Ping-pong pattern across operations:
Op 1: READ h0 ────> WRITE h1
Op 2: READ h1 ────> WRITE h0
Op 3: READ h0 ────> WRITE h1
...
The two buffers alternate roles as "source" and "destination"
so we never have a read-write hazard on the same buffer.
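The pattern boils down to swapping two pointers between dispatches. A minimal sketch of the generic idea (illustrative only; Akunu binds h0/h1 explicitly per operation rather than through a wrapper like this):

```cpp
#include <utility>

// Generic ping-pong: two fixed buffers alternate as source and
// destination, so no dispatch ever reads and writes the same buffer.
struct PingPong {
    float* src;
    float* dst;
    PingPong(float* a, float* b) : src(a), dst(b) {}
    void flip() { std::swap(src, dst); }  // last output becomes next input
};
```

Because only the pointers move, the physical allocations never change, which is exactly what keeps the hot path allocation-free.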
The residual Buffer
Sized [dim] FP16. Holds the output of RMSNorm/LayerNorm before it gets projected
into Q/K/V or FFN inputs. This is separate from h0/h1 because we need the
un-normed h0 for the residual connection: h0 = h0 + f(norm(h0)).
The qkv Buffer: Contiguous Q|K|V
This is one of the cleverer design decisions. Instead of three separate buffers for Q, K, and V, Akunu packs them into a single contiguous buffer:
qkv buffer layout (bytes):
|<-------- q_dim*2 -------->|<--- kv_dim*2 --->|<--- kv_dim*2 --->|
| Q region | K region | V region |
^ ^ ^
qkv_q_offset = 0 qkv_k_offset qkv_v_offset
The byte offsets are pre-computed at creation time:
s.qkv_q_offset = 0;
s.qkv_k_offset = q_dim * 2; // q_dim FP16 elements = q_dim*2 bytes
s.qkv_v_offset = (q_dim + kv_dim) * 2;
Why pack them together? Two reasons:
-
Fewer buffer bindings. The QKV GEMV kernel can write Q, K, and V with different buffer offsets into the same Metal buffer. This means one buffer binding instead of three, which reduces Metal command encoder overhead.
-
The KV cache write kernel reads K and V from known offsets. It binds the
qkvbuffer with the appropriate offset to read just the K or V portion.
For GQA models where q_dim != kv_dim (e.g., LLaMA 3-8B has q_dim=4096,
kv_dim=1024), this packing is especially efficient – Q is 4x larger than K or V,
and they all fit in one allocation.
Example: LLaMA 3.1-8B
q_dim = 32 heads * 128 dim = 4096
kv_dim = 8 heads * 128 dim = 1024
qkv buffer: (4096 + 1024 + 1024) * 2 = 12,288 bytes = 12 KB
Offsets:
Q starts at byte 0
K starts at byte 8192
V starts at byte 10240
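The offset arithmetic is easy to verify in isolation. A sketch (the struct and function names are illustrative):

```cpp
struct QKVOffsets { int q, k, v; };

// Byte offsets of the packed Q|K|V regions in one contiguous FP16
// buffer: Q starts at 0, K after Q, V after K. The factor 2 converts
// FP16 element counts to bytes.
inline QKVOffsets qkv_offsets(int q_dim, int kv_dim) {
    return { 0, q_dim * 2, (q_dim + kv_dim) * 2 };
}
```

For LLaMA 3.1-8B (q_dim=4096, kv_dim=1024) this reproduces the byte offsets 0, 8192, and 10240 shown above.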
ffn_gate, ffn_up, ffn_act
Three buffers for the feed-forward network, each [ffn_dim] FP16. The FFN in
LLaMA-style models is:
gate = W_gate @ x --> ffn_gate
up = W_up @ x --> ffn_up
act = SiLU(gate) * up --> ffn_act
out = W_down @ act --> h1
Note that ffn_gate is allocated as ffn_dim * 2 * 2 bytes (double-sized) to
support fused gate+up computation in some kernel variants. The actual FFN
dimension for LLaMA 3-8B is 14336, so ffn_gate is 56 KB.
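For reference, here is the activation step in scalar FP32 – an illustrative sketch; the real kernel runs in FP16 on the GPU:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// SiLU(x) = x * sigmoid(x), the activation used in LLaMA-style FFNs.
inline float silu(float x) { return x / (1.0f + std::exp(-x)); }

// ffn_act[i] = SiLU(gate[i]) * up[i] -- the element-wise step between
// the gate/up GEMVs and the down projection.
inline std::vector<float> gated_act(const std::vector<float>& gate,
                                    const std::vector<float>& up) {
    std::vector<float> act(gate.size());
    for (size_t i = 0; i < gate.size(); i++)
        act[i] = silu(gate[i]) * up[i];
    return act;
}
```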
logits Buffer
Sized [vocab_size] FP16. For LLaMA 3, vocab_size = 128256, so this buffer is
128256 * 2 = ~250 KB. The logit projection (hidden state times the un-embedding
matrix) writes here, and the sampler reads from here.
token_ids Buffer
Sized [max_chain] U32 (4 bytes per token). This holds token IDs for both the
input (prefill) and the output (chain decode). The chain decode loop writes
the next token ID here, then the embedding lookup reads it on the next iteration.
Prefill Buffers: Batch Processing
During prefill, we process up to max_prefill_chunk tokens simultaneously (default
4096). Every buffer needs a batch dimension:
Decode buffer: [dim] -- 1 token
Prefill buffer: [prefill_chunk, dim] -- up to 4096 tokens
The prefill buffers mirror the decode buffers with an added batch dimension:
+------------------+----------------------------------+
| Decode Buffer | Prefill Buffer |
+------------------+----------------------------------+
| h0 [dim] | batch_h0 [chunk * dim] |
| h1 [dim] | batch_h1 [chunk * dim] |
| residual [dim] | batch_residual [chunk * dim] |
| qkv [qkv_dim] | batch_q [chunk * q_dim] |
| | batch_k [chunk * kv_dim] |
| | batch_v [chunk * kv_dim] |
| attn_out [dim] | batch_attn_out [chunk * dim] |
| ffn_gate [ffn] | batch_gate [chunk * ffn_dim] |
| ffn_up [ffn] | batch_up [chunk * ffn_dim] |
| ffn_act [ffn] | batch_act [chunk * ffn_dim] |
| post_norm [dim] | batch_post_norm [chunk * dim] |
+------------------+----------------------------------+
Note that prefill uses separate Q, K, V buffers instead of the packed qkv
layout. This is because the prefill attention kernel (flash attention prefill)
expects separate Q, K, V inputs in the shape [seq, n_heads, head_dim], which
is more natural for batched GEMM operations.
Memory Budget Calculation
Let us compute the total scratch memory for LLaMA 3.1-8B:
Model parameters:
dim = 4096
q_dim = 4096 (32 heads * 128)
kv_dim = 1024 (8 heads * 128)
ffn_dim = 14336
vocab_size = 128256
chunk = 4096 (prefill chunk size)
Decode buffers (single token):
h0: 4096 * 2 = 8,192 bytes
h1: 4096 * 2 = 8,192 bytes
residual: 4096 * 2 = 8,192 bytes
qkv: (4096+2*1024) * 2 = 12,288 bytes
attn_out: 4096 * 2 = 8,192 bytes
post_norm: 4096 * 2 = 8,192 bytes
ffn_gate: 14336 * 2 * 2 = 57,344 bytes (2x for fused gate+up)
ffn_up: 14336 * 2 = 28,672 bytes
ffn_act: 14336 * 2 = 28,672 bytes
logits: 128256 * 2 = 256,512 bytes
token_ids: 4096 * 4 = 16,384 bytes
──────────────────────────────────────────────
Decode total: ~432 KB
Prefill buffers (4096 tokens):
batch_h0: 4096 * 4096 * 2 = 33,554,432 bytes
batch_h1: 4096 * 4096 * 2 = 33,554,432 bytes
batch_residual: 4096 * 4096 * 2 = 33,554,432 bytes
batch_q: 4096 * 4096 * 2 = 33,554,432 bytes
batch_k: 4096 * 1024 * 2 = 8,388,608 bytes
batch_v: 4096 * 1024 * 2 = 8,388,608 bytes
batch_attn_out: 4096 * 4096 * 2 = 33,554,432 bytes
batch_gate: 4096 * 14336 * 2 = 117,440,512 bytes
batch_up: 4096 * 14336 * 2 = 117,440,512 bytes
batch_act: 4096 * 14336 * 2 = 117,440,512 bytes
batch_post_norm:4096 * 4096 * 2 = 33,554,432 bytes
──────────────────────────────────────────────
Prefill total: ~544 MB
Grand total scratch: ~544 MB
The prefill buffers dominate – they are the batch dimension multiplied by the
hidden and FFN dimensions. For a model with ffn_dim = 14336 and a prefill
chunk of 4096, each FFN buffer alone is 112 MB.
Here is the full memory picture for LLaMA 3.1-8B at 4096 context:
+-----------------------------------+-----------+
| Component                         | Memory    |
+-----------------------------------+-----------+
| Model weights (Q4_0)              | ~4.3 GB   |
| KV cache (32 layers, 4096 ctx)    | 512 MB    |
| Scratch decode buffers            | ~0.4 MB   |
| Scratch prefill buffers           | ~544 MB   |
+-----------------------------------+-----------+
| Total                             | ~5.3 GB   |
+-----------------------------------+-----------+
Buffer Reuse Within a Forward Pass
A key insight is that these buffers are reused within a single forward pass, not just across forward passes. Within the layer loop:
Layer L:
residual: written by RMSNorm, read by QKV projection
qkv: written by QKV projection, read by attention + KV write
attn_out: written by attention, read by O projection
h1: written by O projection, read by residual add
ffn_gate: written by Gate GEMV, read by activation
ffn_up: written by Up GEMV, read by activation
ffn_act: written by activation, read by Down GEMV
Layer L+1:
Same buffers, completely overwritten!
The transformer processes layers sequentially. Layer L’s ffn_gate output is
consumed within layer L, then layer L+1 overwrites the same buffer with its own
ffn_gate output. No per-layer scratch is needed – one set of buffers serves
all layers.
The only per-layer storage is the KV cache (Chapter 45), which must retain values across the entire sequence.
The Ping-Pong Pattern in Detail
Let us trace the exact read/write pattern through two consecutive layers:
Layer 0:
READ h0 WRITE residual (RMSNorm)
READ residual WRITE qkv (QKV GEMV)
READ qkv+cache WRITE attn_out (Attention)
READ attn_out WRITE h1 (O GEMV)
READ h0, h1 WRITE h0 (Residual add: h0 += h1)
READ h0 WRITE residual (RMSNorm)
READ residual WRITE ffn_gate (Gate GEMV)
READ residual WRITE ffn_up (Up GEMV)
READ gate, up WRITE ffn_act (SiLU * mul)
READ ffn_act WRITE h1 (Down GEMV)
READ h0, h1 WRITE h0 (Residual add: h0 += h1)
Layer 1:
READ h0 WRITE residual (RMSNorm)
... same pattern, same buffers, completely safe because
each buffer is fully consumed before being overwritten ...
Notice that h0 is both read and written in the residual add step. This is
safe because it is an element-wise operation (h0[i] += h1[i]), implemented
as a fused add kernel that handles the in-place update correctly.
Allocation: All at Once, All FP16
The create() factory method allocates everything in one shot:
static ScratchBuffers create(Device& device, const AkunuModelConfig& cfg,
                             int max_context = 4096,
                             int prefill_chunk = 4096,
                             int max_chain = 128) {
    ScratchBuffers s;
    int dim = cfg.dim;
    int q_dim = cfg.q_dim;
    int kv_dim = cfg.kv_dim;
    int ffn_dim = cfg.ffn_dim;
    int vocab = cfg.vocab_size;

    // Decode buffers
    s.h0 = device.allocate(dim * 2);
    s.h1 = device.allocate(dim * 2);
    s.residual = device.allocate(dim * 2);
    s.qkv = device.allocate((q_dim + 2 * kv_dim) * 2);
    s.attn_out = device.allocate((q_dim > dim ? q_dim : dim) * 2);
    // ... etc ...

    // Prefill buffers
    s.batch_h0 = device.allocate(prefill_chunk * dim * 2);
    // ... etc ...
    return s;
}
Every size is count * 2 because FP16 is 2 bytes per element. The token_ids
buffer uses count * 4 because token IDs are 32-bit unsigned integers.
Notice the max(q_dim, dim) for attn_out – this handles the case where the
attention output dimension might differ from the model dimension (though in
practice they are usually equal).
Cleanup
Like the KV cache, cleanup is explicit:
void destroy(Device& device) {
    for (Buffer *b : {&h0, &h1, &residual, &qkv, &attn_out, &post_norm,
                      &ffn_gate, &ffn_up, &ffn_act, &logits, &token_ids,
                      &batch_h0, &batch_h1, &batch_residual,
                      &batch_q, &batch_k, &batch_v, &batch_attn_out,
                      &batch_gate, &batch_up, &batch_act, &batch_post_norm}) {
        device.free_buffer(*b);
    }
}
Every buffer gets freed. The initializer list is a convenient C++ trick for iterating over all the member buffers without repeating the free logic.
Post-Norm Buffer (Gemma Compatibility)
The post_norm and batch_post_norm buffers are specifically for Gemma-style
architectures that use post-attention and post-FFN normalization:
Standard LLaMA: x = x + Attn(Norm(x))
Gemma: x = x + Norm(Attn(Norm(x)))
^^^^
post-norm needs its own buffer
For models that do not use post-norm (most LLaMA variants), this buffer is
allocated but never written to. The wasted memory is dim * 2 = 8 KB for the
decode version – negligible.
Why Not Use a Memory Pool?
You might wonder: instead of individual named buffers, why not allocate one big slab and carve it up? Memory pools are common in GPU programming.
The answer is debuggability. With named buffers:
- Metal GPU debugger shows “h0”, “ffn_gate”, “logits” etc. in the buffer list
- Each buffer has a known size that matches its semantic purpose
- There is no offset arithmetic to get wrong
- Adding a new buffer is trivial – just add a field and an allocate call
The overhead of having ~22 separate MTLBuffer objects instead of 1 is
negligible. Metal’s buffer creation is fast, and we only do it once at load time.
Summary
Key design principles:
1. Pre-allocate ALL buffers at model load
2. Zero allocation during inference
3. Ping-pong (h0/h1) avoids read-write hazards
4. Contiguous QKV with byte sub-offsets
5. Separate decode (1-token) and prefill (N-token) buffer sets
6. Every buffer is reused across all transformer layers
7. Explicit create/destroy lifecycle
Memory hierarchy during inference:
+─────────────────────────────────────+
| Model Weights (read-only, ~GB) | Largest
+─────────────────────────────────────+
| Prefill Scratch (~500 MB) |
+─────────────────────────────────────+
| KV Cache (~512 MB for 4K ctx) |
+─────────────────────────────────────+
| Decode Scratch (~0.4 MB) | Smallest
+─────────────────────────────────────+
With the KV cache and scratch buffers understood, we have covered the complete runtime memory picture. Every byte of GPU memory used during inference is accounted for: model weights, KV cache, and scratch buffers. No hidden allocations, no surprises, no fragmentation. This is what makes it possible to predict exactly whether a given model will fit in memory before loading it.
The Tokenizer
Before a single GPU kernel fires, before attention scores are computed, before any
matrix is multiplied – text must become numbers. The tokenizer is the bridge
between the world of strings and the world of tensors. It takes a prompt like
“Hello, world!” and produces a sequence of integer token IDs like [9906, 11, 1917, 0].
These IDs index into an embedding table to produce the model’s initial hidden
states.
Akunu implements its tokenizer in pure C++ with no dependencies on Python, SentencePiece libraries, or Hugging Face transformers. It supports three tokenization algorithms – SentencePiece BPE (LLaMA, Qwen), GPT-2 byte-level BPE (Mistral, Phi), and WordPiece (BERT) – all in about 660 lines of code.
This chapter covers the BPE algorithm in detail, walks through the three variant implementations, explains GPT-2’s bizarre byte-to-unicode mapping, and describes the incremental decode system for real-time streaming.
What is BPE?
Byte Pair Encoding (BPE) is a subword tokenization algorithm. It does not split text into words or characters, but into subword units that balance vocabulary size with encoding efficiency.
The core idea: start with individual characters (or bytes), then iteratively merge the most frequent adjacent pair into a new token. After training, you have a vocabulary of subword units and a set of merge rules that define which pairs to combine and in what order.
Here is BPE encoding in pseudocode:
function bpe_encode(text, merges, vocab):
tokens = split text into individual characters/bytes
while tokens has at least 2 elements:
find the adjacent pair with the highest merge priority
if no valid merge exists: break
merge that pair into a single token
return [vocab[t] for t in tokens]
The critical difference between BPE variants is:
- What are the initial units? (characters, bytes, byte-mapped Unicode)
- How are merge priorities defined? (scores, ranks, longest-match)
- How are spaces handled? (space prefix, byte encoding, no special handling)
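The merge loop can be exercised on a toy vocabulary. Here is a self-contained sketch of the scoring-based (SentencePiece-style) variant; the vocabulary and scores are invented purely for illustration:

```cpp
#include <string>
#include <unordered_map>
#include <vector>

// Greedy BPE merge loop, SentencePiece style: repeatedly merge the
// adjacent pair whose concatenation has the highest score in the
// vocabulary, until no adjacent pair is in the vocabulary.
inline std::vector<std::string> bpe_merge(
        std::vector<std::string> tokens,
        const std::unordered_map<std::string, float>& scores) {
    while (tokens.size() >= 2) {
        float best_score = -1e30f;
        int best_idx = -1;
        for (int i = 0; i + 1 < (int)tokens.size(); i++) {
            auto it = scores.find(tokens[i] + tokens[i + 1]);
            if (it != scores.end() && it->second > best_score) {
                best_score = it->second;
                best_idx = i;
            }
        }
        if (best_idx < 0) break;  // no mergeable pair left
        tokens[best_idx] += tokens[best_idx + 1];
        tokens.erase(tokens.begin() + best_idx + 1);
    }
    return tokens;
}
```

With scores {"ll": 2.0, "llo": 1.5, "he": 1.0}, the input ["h","e","l","l","o"] merges to ["he", "llo"]: "ll" wins the first round, "llo" the second, "he" the third, and "hello" is not in the vocabulary so the loop stops.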
SentencePiece BPE (LLaMA, Qwen)
SentencePiece BPE, used by LLaMA and Qwen models, works on Unicode characters.
Spaces are replaced with the special marker \u2581 (lower one eighth block, _).
Each token in the vocabulary has a float score – higher scores are merged first.
Encoding Algorithm
std::vector<uint32_t> Tokenizer::bpe_sentencepiece(const std::string& text) const {
    // 1. Prepend space marker and replace spaces
    std::string processed;
    if (add_space_prefix_) processed = SP_SPACE; // ▁
    for (char c : text) {
        if (c == ' ') processed += SP_SPACE;
        else processed += c;
    }

    // 2. Split into individual UTF-8 characters
    std::vector<std::string> tokens = utf8_chars(processed);

    // 3. Iteratively merge highest-score pair
    while (tokens.size() >= 2) {
        float best_score = -INF;
        int best_idx = -1;
        for (int i = 0; i < (int)tokens.size() - 1; i++) {
            std::string merged = tokens[i] + tokens[i+1];
            auto it = token_to_id_.find(merged);
            if (it != token_to_id_.end()) {
                float score = scores_[it->second];
                if (score > best_score) {
                    best_score = score;
                    best_idx = i;
                }
            }
        }
        if (best_idx < 0) break;
        tokens[best_idx] = tokens[best_idx] + tokens[best_idx + 1];
        tokens.erase(tokens.begin() + best_idx + 1);
    }

    // 4. Convert to IDs (with byte fallback for unknown chars)
    // ...
}
Let us trace through an example:
Input: "Hello"
After space prefix: "▁Hello"
Split into chars: ["▁", "H", "e", "l", "l", "o"]
Round 1: Best merge is "l" + "l" -> "ll" (highest score)
["▁", "H", "e", "ll", "o"]
Round 2: Best merge is "ll" + "o" -> "llo"
["▁", "H", "e", "llo"]
Round 3: Best merge is "e" + "llo" -> "ello"
["▁", "H", "ello"]
Round 4: Best merge is "H" + "ello" -> "Hello"
["▁", "Hello"]
Round 5: Best merge is "▁" + "Hello" -> "▁Hello"
["▁Hello"]
Final: token_to_id["▁Hello"] = 15043
Output: [15043]
Byte Fallback
When a character sequence is not in the vocabulary, SentencePiece falls back to
encoding individual bytes as hex tokens: <0x41> for byte 0x41 (‘A’). This
ensures any input can be encoded, even if it contains rare Unicode characters
or binary data.
// Byte fallback for unknown chars
for (uint8_t byte : t) {
    char hex[16];
    snprintf(hex, sizeof(hex), "<0x%02X>", byte);
    auto hit = token_to_id_.find(hex);
    if (hit != token_to_id_.end())
        ids.push_back(hit->second);
}
Decoding SentencePiece
Decoding is the reverse: look up each token ID in the vocabulary, replace \u2581
with a space, and handle byte fallback tokens:
if (is_sentencepiece_) {
    int byte_val = parse_byte_token(token); // parse "<0xAB>" -> 0xAB
    if (byte_val >= 0)
        return std::string(1, (char)byte_val);
    // Replace ▁ with space
    // ... scan for 0xE2 0x96 0x81 (UTF-8 for ▁) ...
}
GPT-2 Byte-Level BPE (Mistral, Phi)
GPT-2’s BPE takes a fundamentally different approach. Instead of operating on Unicode characters, it operates on bytes, but with a twist: each byte is mapped to a printable Unicode character. This avoids control characters in the vocabulary while still being byte-complete.
The Byte-to-Unicode Mapping
This is the most confusing part of GPT-2 tokenization, so let us be very precise.
The mapping assigns each of the 256 possible byte values to a Unicode codepoint:
Printable bytes (0x21-0x7E and 0xA1-0xFF):
Map to themselves.
0x41 ('A') -> U+0041 ('A')
0x61 ('a') -> U+0061 ('a')
0xC0 -> U+00C0 ('A' with grave)
Non-printable bytes (0x00-0x20, 0x7F-0xA0):
Map to U+0100 + offset, sequentially.
0x00 -> U+0100 ('A' with macron)
0x01 -> U+0101 ('a' with macron)
...
0x20 -> U+0120 ('G' with dot above)
0x7F -> U+0121 ('g' with dot above)
...
Here is a visual:
Byte value: 0x00 0x01 ... 0x20 0x21 ... 0x7E 0x7F ... 0xA0 0xA1 ... 0xFF
|--non-printable--| |---printable---| |-np-| |---printable---|
Maps to: U+100 U+101 U+120 U+21 U+7E U+121 U+142 U+A1 U+FF
|--offset block--| |---identity----| |--block--| |---identity--|
Akunu builds this mapping at construction time:
static void build_byte_to_unicode(std::string table[256]) {
int offset = 0;
for (int b = 0; b < 256; b++) {
if ((b >= 0x21 && b <= 0x7E) || (b >= 0xA1 && b <= 0xFF)) {
table[b] = codepoint_to_utf8(b); // identity
} else {
table[b] = codepoint_to_utf8(0x100 + offset);
offset++;
}
}
}
GPT-2 Encoding
With the byte mapping in hand, GPT-2 BPE encoding works like this:
- Split the input on special tokens (longest-match greedy)
- For each non-special segment, convert bytes to GPT-2 Unicode
- Split into individual Unicode characters
- Apply BPE merges (lowest rank first, not highest score)
std::vector<uint32_t> Tokenizer::bpe_gpt2(const std::string& text) const {
    std::vector<uint32_t> ids;
    auto segments = split_special_tokens(text);
    for (auto& segment : segments) {
        if (is_special_token(segment)) {
            ids.push_back(token_to_id_.at(segment)); // .at(): usable in a const method
            continue;
        }
        // Convert bytes to GPT-2 Unicode
        std::string converted;
        for (uint8_t byte : segment)
            converted += byte_to_unicode_[byte];
        // Split into Unicode characters
        auto tokens = utf8_chars(converted);
        // BPE merge: lowest rank first
        while (tokens.size() >= 2) {
            int best_rank = INT_MAX;
            int best_idx = -1;
            for (int i = 0; i < (int)tokens.size() - 1; i++) {
                std::string pair = tokens[i] + " " + tokens[i+1];
                auto it = merge_ranks_.find(pair);
                if (it != merge_ranks_.end() && it->second < best_rank) {
                    best_rank = it->second;
                    best_idx = i;
                }
            }
            if (best_idx < 0) break; // no mergeable pair left
            tokens[best_idx] += tokens[best_idx + 1];
            tokens.erase(tokens.begin() + best_idx + 1);
        }
        // Convert to IDs ...
    }
    return ids;
}
The key difference from SentencePiece: merge rules use ranks (lower is better)
stored as "token_a token_b" -> rank_int, while SentencePiece uses per-token
scores (higher is better).
Special Token Splitting
Before BPE, GPT-2 tokenizers must split the input on special tokens like
<|im_start|>, <|im_end|>, <|endoftext|>, etc. These are matched greedily
(longest first) and passed through as-is:
Input: "<|im_start|>system\nHello<|im_end|>"
Split: ["<|im_start|>", "system\nHello", "<|im_end|>"]
^ special ^ BPE encoded ^ special
Akunu marks special tokens with a \x01 prefix byte during splitting so the
encoder can distinguish them from regular text without additional data structures.
GPT-2 Decoding
Decoding reverses the byte mapping: for each Unicode codepoint in the token string, look up the corresponding byte value:
// GPT-2 decode: reverse byte-to-unicode mapping
std::string bytes;
size_t pos = 0;
while (pos < token.size()) {
uint32_t cp = decode_utf8_cp(token, pos);
auto it = unicode_to_byte_.find(cp);
if (it != unicode_to_byte_.end())
bytes += (char)it->second;
else
bytes += codepoint_to_utf8(cp); // pass through
}
return bytes;
WordPiece (BERT)
The third algorithm, used by BERT and embedding models like nomic-embed, is WordPiece. Unlike BPE, WordPiece does not use merges at all – it uses greedy longest-prefix matching.
Algorithm:
1. Lowercase input (for uncased BERT)
2. Split on whitespace and punctuation
3. For each word:
a. Try to find the longest prefix in the vocab
b. If found, emit it and continue from where it ended
c. For continuation pieces, try "##" + prefix (WordPiece convention)
d. If no prefix matches, emit [UNK]
Example:
Input: "unaffable"
Try "unaffable" -> not in vocab
Try "unaffabl" -> not in vocab
...
Try "un" -> in vocab! Emit token for "un"
Remaining: "affable"
Try "##affable" -> not in vocab
Try "##affabl" -> not in vocab
...
Try "##aff" -> in vocab! Emit token for "##aff"
Remaining: "able"
Try "##able" -> in vocab! Emit token for "##able"
Output: ["un", "##aff", "##able"]
Akunu’s implementation adds [SEP] at the end (BERT convention) and handles
the nomic-bert variant where continuation tokens may not use the ## prefix.
The Tokenizer Construction Pipeline
Akunu supports two paths for loading a tokenizer:
Path 1: From GGUF Metadata
When loading a GGUF model file, the tokenizer vocabulary, scores, merges, and special token IDs are embedded in the file’s metadata:
Tokenizer tok = Tokenizer::from_gguf(
vocab, // vector<string>: token strings
scores, // vector<float>: per-token scores (SentencePiece)
merges, // vector<string>: merge rules (GPT-2)
bos_id, eos_id, // special token IDs
model_type, // "llama" or "gpt2" or "bert"
add_bos, // prepend BOS token?
add_space_prefix // prepend ▁ in SentencePiece mode?
);
The constructor builds the token_to_id_ reverse lookup map, the merge ranks
table (for GPT-2), and the byte-to-unicode mapping tables.
Path 2: From HuggingFace tokenizer.json
For models distributed as HF repos (without GGUF), Akunu has a hand-rolled JSON
parser in hf_tokenizer_loader.h that reads tokenizer.json and
tokenizer_config.json:
HFTokenizerData data;
bool ok = load_hf_tokenizer("/path/to/model/dir", data);
// data.vocab, data.merges, data.bos_id, data.eos_id, etc.
This parser is deliberately minimal – no JSON library dependency. It uses
character-level scanning with find_key(), parse_json_string(), and
parse_json_int() to extract exactly the fields it needs:
- model.vocab: the token-to-ID mapping
- model.merges: the BPE merge rules
- added_tokens: special tokens with their IDs
- pre_tokenizer.type: detects Gemma's "Split" pre-tokenizer (no space prefix)
The parser also reads tokenizer_config.json for add_bos_token, bos_token,
and eos_token fields, handling the various ways different model families
specify these values.
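To make the scanning approach concrete, here is a minimal sketch in the same spirit. The helper name and behavior are illustrative assumptions, not akunu's actual find_key()/parse_json_string(): it locates "key": "value" by raw string search and copies the value with basic escape handling.

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Illustrative character-level JSON scan (sketch, not akunu's real helper):
// find `"key": "value"` and return the value, or "" if not found.
std::string scan_string_field(const std::string& json, const std::string& key) {
    std::string needle = "\"" + key + "\"";
    size_t pos = json.find(needle);
    if (pos == std::string::npos) return "";
    pos = json.find(':', pos + needle.size());
    if (pos == std::string::npos) return "";
    pos = json.find('"', pos + 1);              // opening quote of the value
    if (pos == std::string::npos) return "";
    std::string out;
    for (size_t i = pos + 1; i < json.size(); i++) {
        char c = json[i];
        if (c == '\\' && i + 1 < json.size()) { // minimal escape handling
            char e = json[++i];
            out += (e == 'n') ? '\n' : (e == 't') ? '\t' : e;
        } else if (c == '"') {
            return out;                         // closing quote
        } else {
            out += c;
        }
    }
    return "";                                  // unterminated string
}
```

A real tokenizer_config.json lookup would also need to skip nested objects and non-string values, which is exactly the kind of per-field special-casing a hand-rolled parser trades for the missing dependency.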
Incremental Decode: UTF-8 Buffering for Streaming
When streaming tokens to the user in real time, there is a subtle problem: individual tokens may produce partial UTF-8 sequences. A multi-byte character like a Chinese character or an emoji might be split across two tokens.
Consider: the UTF-8 encoding of a Chinese character is 3 bytes. If token A
produces bytes [0xE4, 0xBD] and token B produces byte [0xA0], you cannot
emit token A’s output yet – those 2 bytes are not a valid UTF-8 string. You
must buffer them and wait for the completing byte.
Akunu’s decode_incremental() handles this:
std::string Tokenizer::decode_incremental(uint32_t token_id) {
std::string raw = decode(token_id);
if (raw.empty()) return "";
utf8_buf_ += raw; // append to buffer
// Scan for complete UTF-8 sequences
std::string output;
size_t i = 0, last_good = 0;
while (i < utf8_buf_.size()) {
int expected = utf8_expected_len(utf8_buf_[i]);
if (expected == 0) {
// Orphan continuation byte -- emit as-is
output += utf8_buf_[i]; i++; last_good = i;
continue;
}
if (i + expected > utf8_buf_.size())
break; // incomplete at end -- keep buffered
output += utf8_buf_.substr(i, expected);
i += expected; last_good = i;
}
// Keep only the incomplete tail
utf8_buf_ = (last_good < utf8_buf_.size())
? utf8_buf_.substr(last_good)
: "";
return output;
}
The flow:
Token 1 decoded: [0xE4, 0xBD]
Buffer: [0xE4, 0xBD]
Expected: 3 bytes (0xE4 = 3-byte leader)
Only 2 available -- keep buffered
Output: ""
Token 2 decoded: [0xA0, 0x48, 0x65]
Buffer: [0xE4, 0xBD, 0xA0, 0x48, 0x65]
Sequence 1: [0xE4, 0xBD, 0xA0] = complete 3-byte UTF-8 = "你"
Sequence 2: [0x48] = ASCII 'H'
Sequence 3: [0x65] = ASCII 'e'
Output: "你He"
Buffer: empty
The state variable utf8_buf_ persists between calls and must be reset between
conversations via reset_incremental().
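decode_incremental() leans on utf8_expected_len(), which is not shown above. A plausible implementation classifies the leading byte of a UTF-8 sequence; this is a sketch of standard UTF-8 leader-byte logic, not necessarily akunu's exact code:

```cpp
#include <cassert>
#include <cstdint>

// Given the first byte of a (potential) UTF-8 sequence, return the total
// sequence length in bytes, or 0 for a continuation byte (10xxxxxx), which
// decode_incremental treats as an orphan.
int utf8_expected_len(uint8_t b) {
    if (b < 0x80) return 1;            // 0xxxxxxx: ASCII
    if ((b & 0xE0) == 0xC0) return 2;  // 110xxxxx: 2-byte leader
    if ((b & 0xF0) == 0xE0) return 3;  // 1110xxxx: 3-byte leader
    if ((b & 0xF8) == 0xF0) return 4;  // 11110xxx: 4-byte leader
    return 0;                          // 10xxxxxx continuation (or invalid)
}
```

In the flow above, 0xE4 classifies as a 3-byte leader, which is why the two buffered bytes are held back until 0xA0 arrives.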
Thinking Block Stripping
Some models (like Qwen3) emit <think>...</think> blocks during chain-of-thought
reasoning. When streaming to the user, you typically want to strip these:
Model output: "<think>The user asked about France. Paris is the capital.</think>The capital of France is Paris."
After stripping: "The capital of France is Paris."
The strip_thinking() method implements a streaming-aware state machine:
States:
THINK_IDLE: normal output mode, scanning for <think>
THINK_INSIDE: inside a thinking block, scanning for </think>
Edge cases handled:
- Partial tag at chunk boundary: "<thi" + "nk>" across two calls
- Nested angle brackets that are NOT think tags
- Think blocks spanning multiple streaming chunks
The implementation buffers partial tag matches in think_buf_ and uses the
think_state_ enum to track whether we are inside a thinking block.
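A heavily simplified version of such a state machine might look like this. It is illustrative only: the struct and member names are invented, and corner cases like doubled `<` characters are glossed over, but it does handle tags split across chunk boundaries.

```cpp
#include <cassert>
#include <string>

// Simplified streaming <think>...</think> stripper (sketch, not akunu's
// strip_thinking()). Buffers partial tag matches across feed() calls.
struct ThinkStripper {
    enum State { IDLE, INSIDE } state = IDLE;
    std::string buf;  // partial tag held across chunk boundaries

    std::string feed(const std::string& chunk) {
        std::string out;
        for (char c : chunk) {
            const std::string tag = (state == IDLE) ? "<think>" : "</think>";
            if (!buf.empty() || c == '<') {
                buf += c;
                if (buf == tag) {                      // full tag seen: flip state
                    state = (state == IDLE) ? INSIDE : IDLE;
                    buf.clear();
                } else if (tag.compare(0, buf.size(), buf) != 0) {
                    if (state == IDLE) out += buf;     // not a tag: flush literally
                    buf.clear();                       // (inside a block: drop it)
                }
            } else if (state == IDLE) {
                out += c;                              // normal passthrough
            } // else: inside a think block, discard
        }
        return out;
    }
};
```

Feeding "&lt;thi" then "nk>secret&lt;/thi" then "nk>Paris" emits nothing, nothing, then "Paris", matching the partial-tag edge case described above.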
Extra EOS Token IDs
Different model families use different end-of-sequence tokens. LLaMA 3 uses
<|end_of_text|>, ChatML models use <|im_end|>, and older models use </s>.
Some models have multiple valid EOS tokens.
Akunu handles this with add_eos_id():
void add_eos_id(uint32_t id) { extra_eos_ids_.push_back(id); }
bool is_eos(uint32_t id) const {
if (id == eos_id_) return true;
for (auto eid : extra_eos_ids_)
if (id == eid) return true;
return false;
}
The GGUF loader and HF tokenizer loader register the appropriate EOS tokens for each model architecture.
Performance Considerations
Akunu’s tokenizer is intentionally simple and not optimized for speed. The BPE inner loop is O(n^2) in the number of character tokens (scan all pairs, merge one, repeat). For typical prompts of a few hundred to a few thousand tokens, this takes microseconds – negligible compared to the milliseconds of GPU prefill.
Where tokenizer performance does matter is in the vocabulary lookup. The
token_to_id_ map is an unordered_map<string, uint32_t> with ~100K-200K
entries. Each find() call hashes a string and does a comparison. For the BPE
inner loop, this is called O(n) times per merge round, for O(n) rounds, giving
O(n^2) hash lookups total. For n=1000 characters, that is 1M lookups – still
fast in practice but worth being aware of.
Summary
+------------------+------------------+-------------------+-----------+
| Feature          | SentencePiece    | GPT-2 BPE         | WordPiece |
+------------------+------------------+-------------------+-----------+
| Models           | LLaMA, Qwen      | Mistral, Phi      | BERT      |
| Input units      | UTF-8 chars      | Bytes (mapped)    | Chars     |
| Merge priority   | Score (higher=   | Rank (lower=      | N/A       |
|                  | better)          | better)           |           |
| Space handling   | Replace with ▁   | Byte-encode 0x20  | Split on  |
|                  |                  |                   | spaces    |
| Special tokens   | <s>, </s>        | <|...|> patterns  | [CLS] etc |
| Unknown chars    | <0xXX> fallback  | Always encodable  | [UNK]     |
+------------------+------------------+-------------------+-----------+
The tokenizer is a deceptively simple component – the BPE algorithm is straightforward, but the details (byte mappings, space handling, special token splitting, UTF-8 buffering) are where the bugs hide. Akunu’s implementation handles all three major tokenizer families in a single class, loaded from either GGUF metadata or HuggingFace JSON, with streaming-aware incremental decode.
The Grammar Engine
Here is a scenario that breaks naive LLM deployment: you ask a model to produce
JSON, and it gives you {"name": "Alice", "age": 30,} – with a trailing comma.
Perfectly reasonable text. Completely invalid JSON. Your downstream parser chokes,
your API returns a 500, and your user files a bug report.
Grammar-constrained decoding fixes this by making it impossible for the model
to produce invalid output. Before sampling each token, the grammar engine masks
out every token that would violate the grammar. The model can only choose from
tokens that lead to valid continuations. If the grammar says no trailing comma,
the token for , gets its logit set to negative infinity before softmax.
Akunu implements this with a Nondeterministic Pushdown Automaton (NPDA) over GBNF grammars, ported from llama.cpp and extended with JSON Schema support and deferred activation for thinking models. This chapter covers the grammar format, the NPDA engine, token-to-codepoint mapping, the apply/accept cycle, and the JSON Schema converter.
GBNF: The Grammar Format
GBNF (GGML BNF) is a BNF-like grammar notation used by llama.cpp. It looks like this:
root ::= object
object ::= "{" ws pair ("," ws pair)* ws "}"
pair ::= string ws ":" ws value
string ::= "\"" chars "\""
chars ::= char*
char ::= [^"\\] | "\\" escape
value ::= string | number | "true" | "false" | "null" | object | array
number ::= "-"? [0-9] [0-9]*
ws ::= [ \t\n]*
Each line defines a rule. The left side is a rule name, ::= separates it from
the body, and the body is a sequence of elements separated by spaces. Alternatives
are separated by |.
The element types are:
+--------------------+-------------------+---------------------------+
| Syntax | Example | Meaning |
+--------------------+-------------------+---------------------------+
| "literal" | "true" | Match exact string |
| [abc] | [0-9] | Character class |
| [^abc] | [^"\\] | Negated character class |
| rule-name | string | Non-terminal reference |
| (group) | ("a" | "b") | Grouping |
| expr* | char* | Zero or more |
| expr+ | [0-9]+ | One or more |
| expr? | "-"? | Zero or one |
| expr{n} | [0-9]{4} | Exactly n |
| expr{n,m} | [0-9]{1,3} | Between n and m |
| . | . | Any character (wildcard) |
+--------------------+-------------------+---------------------------+
Internal Representation: Grammar Elements
After parsing, each rule is a vector of GrammarElement structs:
enum GrammarElementType : uint32_t {
GTYPE_END = 0, // end of rule
GTYPE_ALT = 1, // alternate (|)
GTYPE_RULE_REF = 2, // non-terminal reference
GTYPE_CHAR = 3, // match Unicode code point
GTYPE_CHAR_NOT = 4, // inverse char class [^...]
GTYPE_CHAR_RNG_UPR = 5, // upper bound of range
GTYPE_CHAR_ALT = 6, // additional char in class
GTYPE_CHAR_ANY = 7, // wildcard .
};
struct GrammarElement {
GrammarElementType type;
uint32_t value; // code point, rule ID, etc.
};
A rule like digit ::= [0-9] becomes:
[{GTYPE_CHAR, '0'}, {GTYPE_CHAR_RNG_UPR, '9'}, {GTYPE_END, 0}]
A rule like bool ::= "true" | "false" becomes:
[{GTYPE_CHAR, 't'}, {GTYPE_CHAR, 'r'}, {GTYPE_CHAR, 'u'}, {GTYPE_CHAR, 'e'},
{GTYPE_ALT, 0},
{GTYPE_CHAR, 'f'}, {GTYPE_CHAR, 'a'}, {GTYPE_CHAR, 'l'}, {GTYPE_CHAR, 's'}, {GTYPE_CHAR, 'e'},
{GTYPE_END, 0}]
Character classes are encoded as the first char type (GTYPE_CHAR or
GTYPE_CHAR_NOT) followed by GTYPE_CHAR_ALT elements for additional characters,
with GTYPE_CHAR_RNG_UPR for ranges:
[a-zA-Z] becomes:
[{GTYPE_CHAR, 'a'}, {GTYPE_CHAR_RNG_UPR, 'z'},
{GTYPE_CHAR_ALT, 'A'}, {GTYPE_CHAR_RNG_UPR, 'Z'}]
The GBNF Parser
The parser in grammar.cpp is a recursive-descent parser operating on the GBNF
source string. It is structured as a GBNFParser struct with these key methods:
parse_all() -- top level: parse rule definitions until EOF
parse_rule_def() -- parse "name ::= body\n"
parse_alternates() -- parse "seq1 | seq2 | seq3"
parse_sequence() -- parse a sequence of elements
parse_char_class() -- parse [a-z] or [^"\\]
handle_repetition() -- handle *, +, ?, {n,m} after an element
Repetition Expansion
Repetition operators are expanded into helper rules at parse time. This is necessary because the NPDA engine does not directly support repetition – it only understands sequences, alternates, and rule references.
Original: digit ::= [0-9]+
Expanded:
digit ::= digit_sub digit_rep
digit_sub ::= [0-9]
digit_rep ::= digit_sub digit_rep | (empty alternate)
The * operator (zero or more) generates:
expr* becomes:
rep_rule ::= expr rep_rule | (empty alternate = epsilon)
The ? operator (zero or one) generates:
expr? becomes:
opt_rule ::= expr | (empty alternate)
Bounded repetition {n,m} generates a chain of optional rules:
expr{2,4} becomes:
expr expr opt_2
opt_2 ::= expr opt_1 |
opt_1 ::= expr |
This expansion happens inside handle_repetition(), which extracts the
sub-expression, wraps it in a helper rule if complex, and emits the appropriate
chain of rule references.
The NPDA Engine
Now for the core of the grammar engine: the nondeterministic pushdown automaton. This is what determines, at each generation step, which tokens are valid continuations.
What is an NPDA?
A pushdown automaton is like a finite state machine but with a stack. The stack allows it to match nested structures (like balanced braces) that regular expressions cannot. “Nondeterministic” means the automaton can be in multiple states simultaneously – it explores all possible paths through the grammar in parallel.
In Akunu’s implementation, the state is represented as a set of stacks, where each stack is a vector of pointers into rule elements:
using GrammarStack = std::vector<const GrammarElement *>;
using GrammarStacks = std::vector<GrammarStack>;
Each stack represents one possible parsing position in the grammar. The top of the stack points to the next terminal element that needs to be matched. Elements below the top represent “return addresses” – where to continue after the current rule finishes.
Example grammar:
root ::= "a" inner "c"
inner ::= "b"
Parsing "abc":
Initial stacks (after advance_stack):
[root@"a"] -- top of stack points to CHAR 'a' in root
After matching 'a':
[root@inner, inner@"b"] -- pushed inner rule, top points to CHAR 'b'
After matching 'b':
[root@"c"] -- inner rule finished, popped back to root
After matching 'c':
[] -- empty stack = grammar fully matched
advance_stack: Resolving Non-Terminals
The advance_stack() function takes a stack and “advances” it past all
GTYPE_RULE_REF elements until every resulting stack ends at a terminal
(CHAR, CHAR_NOT, CHAR_ANY) or is empty. This is where nondeterminism kicks in:
a rule with alternates produces multiple stacks.
Given stack: [..., ptr to RULE_REF(inner)]
And inner ::= "x" | "y"
advance_stack produces:
Stack 1: [..., continuation, ptr to CHAR 'x'] (first alternate)
Stack 2: [..., continuation, ptr to CHAR 'y'] (second alternate)
The function uses a worklist algorithm with deduplication to avoid exponential blowup:
static void advance_stack(const GrammarRules& rules,
const GrammarStack& stack,
GrammarStacks& new_stacks) {
std::set<GrammarStack> seen; // deduplicate
std::vector<GrammarStack> todo;
todo.push_back(stack);
while (!todo.empty()) {
GrammarStack cur = todo.back(); todo.pop_back();
if (seen.count(cur)) continue;
seen.insert(cur);
if (cur.empty()) {
new_stacks.push_back(cur); // grammar complete
continue;
}
const GrammarElement *pos = cur.back();
switch (pos->type) {
case GTYPE_RULE_REF:
// Push alternatives of the referenced rule
// ... expand into todo ...
break;
case GTYPE_CHAR:
case GTYPE_CHAR_NOT:
case GTYPE_CHAR_ANY:
// Terminal -- this stack is ready
new_stacks.push_back(cur);
break;
}
}
}
match_char: Testing a Code Point
When checking whether a code point matches a character element, match_char()
handles chars, ranges, alternates, and negation:
static std::pair<bool, const GrammarElement*> match_char(
const GrammarElement *pos, uint32_t chr) {
bool found = false;
bool is_positive = (pos->type == GTYPE_CHAR || pos->type == GTYPE_CHAR_ANY);
do {
if (pos[1].type == GTYPE_CHAR_RNG_UPR) {
found = found || (pos->value <= chr && chr <= pos[1].value);
pos += 2;
} else if (pos->type == GTYPE_CHAR_ANY) {
found = true; pos += 1;
} else {
found = found || pos->value == chr;
pos += 1;
}
} while (pos->type == GTYPE_CHAR_ALT);
return {found == is_positive, pos};
}
For [^"\\], is_positive is false (because GTYPE_CHAR_NOT), so the result
is inverted: found == false means the character is accepted (it is NOT in the
exclusion set).
Token-to-Codepoint Mapping
Here is a critical insight: the grammar operates on Unicode code points, but the model generates tokens. A single token might encode multiple code points (“Hello” is 5 code points in 1 token), and a single code point might span multiple tokens (a rare Chinese character might be 3 byte-fallback tokens).
The init_vocab() method pre-computes the code point sequence for every token
in the vocabulary:
void Grammar::init_vocab(const Tokenizer& tokenizer) {
vocab_size_ = tokenizer.vocab_size();
token_codepoints_.resize(vocab_size_);
for (int i = 0; i < vocab_size_; i++) {
std::string text = tokenizer.decode(i);
token_codepoints_[i] = decode_utf8(text); // null-terminated
}
}
This creates a lookup table: token ID -> null-terminated vector of Unicode code points. For example:
Token 9906  ("Hello")  -> [72, 101, 108, 108, 111, 0]  (H, e, l, l, o, NUL)
Token 29892 (",")      -> [44, 0]                       (comma, NUL)
Token 259   ("<0xE4>") -> [0xE4 as partial UTF-8]       (3-byte sequence leader)
The apply() Method: Masking Invalid Tokens
apply() is called before sampling. It scans every token in the vocabulary,
checks whether its code point sequence is a valid continuation of the grammar,
and masks invalid tokens to -inf:
void Grammar::apply(float *logits, int vocab_size) const {
if (awaiting_trigger_) return; // deferred activation
if (stacks_.empty()) return;
// Check if grammar is complete (any empty stack)
bool allow_eos = false;
for (auto& stack : stacks_)
if (stack.empty()) allow_eos = true;
// Build candidates
std::vector<GrammarCandidate> candidates;
for (int i = 0; i < vocab_size; i++) {
if (is_eos(i)) {
if (!allow_eos) logits[i] = -INFINITY;
continue;
}
// ... add to candidates ...
}
// Reject invalid candidates across all stacks
auto rejects = reject_candidates(rules_, stacks_, candidates);
for (auto& r : rejects)
logits[r.index] = -INFINITY;
}
The reject_candidates() function uses union semantics: a token is rejected only
if all stacks reject it. This is the nondeterministic part – the grammar
accepts a token if any possible parse path accepts it.
The rejection algorithm works recursively, matching one code point at a time:
reject_candidates(stacks, candidates):
rejects = reject_for_stack(stacks[0], candidates)
for each remaining stack:
rejects = reject_for_stack(stack, rejects)
return rejects
reject_for_stack(stack, candidates):
for each candidate token:
if token's first code point matches stack top:
add to next_candidates (with code_points advanced by 1)
else:
add to rejects
advance stack past matched element
recursively check next_candidates against advanced stacks
This handles multi-codepoint tokens correctly: a token like “Hello” (5 codepoints) is checked one codepoint at a time against the grammar, advancing the grammar state at each step.
Partial UTF-8 Handling
Some tokens produce partial UTF-8 sequences (e.g., byte-fallback tokens). The
grammar engine handles this with match_partial_char(), which checks whether
a partially decoded codepoint could match the current grammar position:
static bool match_partial_char(const GrammarElement *pos,
const PartialUTF8& partial) {
// Calculate the range of codepoints this partial could become
uint32_t low = partial.value << (partial.n_remain * 6);
uint32_t high = low | ((1u << (partial.n_remain * 6)) - 1);
// Check if any codepoint in [low, high] matches
// ...
}
The accept() Method: Advancing Grammar State
After a token is sampled, accept() advances the grammar state by feeding
the token’s code points through the NPDA:
void Grammar::accept(uint32_t token_id) {
if (awaiting_trigger_) { /* check trigger, return */ }
const auto& cps = token_codepoints_[token_id];
GrammarStacks new_stacks;
for (auto& stack : stacks_) {
// Match each code point sequentially
// ... advance stack through the token's code points ...
}
stacks_ = new_stacks;
}
After accept(), the grammar’s stacks reflect the new state – only valid
continuations from this point forward will be allowed.
Left Recursion Detection
Left recursion in a grammar (e.g., expr ::= expr "+" term) would cause the
NPDA’s advance_stack() to loop infinitely. Akunu detects this at parse time:
Grammar Grammar::parse(const std::string& gbnf) {
// ... parse rules ...
// Check for left recursion
for (uint32_t i = 0; i < n; i++) {
if (detect_left_recursion(rules, i, in_progress, may_be_empty, visited))
throw std::runtime_error("GBNF: left recursion detected");
}
}
The detection algorithm does a DFS through rule references, tracking which rules are “in progress” (on the current DFS stack) and which rules “may be empty” (can match epsilon). If a rule references itself (directly or through nullable intermediates) without consuming any input, that is left recursion and we reject the grammar.
JSON Schema to GBNF: The SchemaConverter
For the common case of JSON output, Akunu provides a json_schema_to_grammar()
function that converts a JSON Schema definition into a GBNF grammar. This lives
in json_schema_to_grammar.cpp.
The SchemaConverter class is a recursive visitor over JSON Schema nodes:
class SchemaConverter {
std::string convert(const Json& schema) {
// Add common rules (ws, json-string, json-number, etc.)
add_rule("ws", "[ \\t\\n\\r]*");
add_rule("json-string", "...");
add_rule("json-number", "...");
// ...
// Visit the root schema
std::string root_body = visit(schema, "root");
add_rule("root", root_body);
return format_output();
}
};
The visitor dispatches on schema type:
Schema type -> GBNF generation
----------- ----------------
"object" -> "{" ws prop1 "," ws prop2 ... ws "}"
"array" -> "[" ws item ("," ws item)* ws "]"
"string" -> json-string (or format-specific rule)
"number" -> json-number
"boolean" -> "true" | "false"
"null" -> "null"
"enum" -> "value1" | "value2" | ...
"oneOf"/"anyOf" -> variant1 | variant2 | ...
"allOf" -> merged object properties
"const" -> literal value
Object Schema
For an object schema with required and optional properties:
{
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
"email": {"type": "string"}
},
"required": ["name", "age"]
}
The converter generates:
root ::= "{" ws root-kv-name "," ws root-kv-age ("," ws root-kv-email)? ws "}"
root-kv-name ::= "\"name\"" ws ":" ws json-string
root-kv-age ::= "\"age\"" ws ":" ws json-int
root-kv-email ::= "\"email\"" ws ":" ws json-string
Required properties are emitted unconditionally; optional properties are wrapped
in (... )?.
String Formats
The converter handles common string formats:
"format": "date" -> [0-9]{4} "-" [0-9]{2} "-" [0-9]{2}
"format": "time" -> [0-9]{2} ":" [0-9]{2} ":" [0-9]{2}
"format": "date-time" -> date "T" time ("Z" | offset)
"format": "uuid" -> hex{8} "-" hex{4} "-" hex{4} "-" hex{4} "-" hex{12}
Built-in JSON Grammar
For generic JSON mode (no schema), Akunu provides a built-in GBNF grammar:
root ::= object | array
value ::= object | array | string | number | "true" | "false" | "null"
object ::= "{" ws (pair ("," ws pair)*)? ws "}"
pair ::= string ws ":" ws value
array ::= "[" ws (value ("," ws value)*)? ws "]"
string ::= "\"" chars "\""
chars ::= char*
char ::= [^"\\] | "\\" escape
escape ::= ["\\bfnrt/] | "u" [0-9a-fA-F]{4}
number ::= "-"? int frac? exp?
int ::= "0" | [1-9] [0-9]*
frac ::= "." [0-9]+
exp ::= [eE] [+-]? [0-9]+
ws ::= [ \t\n\r]*
This grammar guarantees valid JSON output – no trailing commas, no unquoted keys, no mismatched braces.
Deferred Activation: Thinking Models
Models like Qwen3 emit <think>...</think> blocks before their actual answer.
If the grammar is active from the start, it would block the model from emitting
these thinking tokens (which are not valid JSON or whatever the target grammar is).
Deferred activation solves this:
void Grammar::set_trigger_text(const std::string& text) {
awaiting_trigger_ = true;
trigger_text_ = text;
trigger_buf_.clear();
}
When awaiting_trigger_ is true, apply() returns immediately without masking
any tokens. The accept() method accumulates generated text in trigger_buf_
and watches for the trigger string (typically </think>). Once the trigger is
found, awaiting_trigger_ flips to false and the grammar begins constraining
output normally.
Generation flow with deferred grammar:
Tokens: <think> I need to format this as JSON </think> { " n a m e " ...
|<--- grammar inactive, all tokens allowed --->|<-- grammar active -->|
^
trigger "</think>" matched here
The Full apply/accept Cycle
Here is the complete lifecycle during constrained generation:
1. Encode prompt, prefill
2. For each decode step:
a. Forward pass -> logits [vocab_size]
b. grammar.apply(logits, vocab_size)
- For each token, check if its codepoints form a valid continuation
- Set invalid tokens' logits to -inf
c. Sample token from masked logits
d. grammar.accept(token_id)
- Advance NPDA stacks through the token's codepoints
e. If grammar.is_done() or grammar.is_exhausted(), stop
3. Output is guaranteed to match the grammar
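Step 2b and 2c can be illustrated with a toy mask-then-sample function, using a validity predicate as a stand-in for grammar.apply() and greedy argmax as the sampler. This is a sketch of the mechanism, not akunu's API:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <functional>
#include <vector>

// Mask every invalid token's logit to -inf, then greedily pick the best
// surviving token. is_valid stands in for the grammar's continuation check.
uint32_t masked_greedy_sample(std::vector<float>& logits,
                              const std::function<bool(uint32_t)>& is_valid) {
    for (uint32_t i = 0; i < logits.size(); i++)
        if (!is_valid(i))
            logits[i] = -INFINITY;                  // the apply() step
    auto it = std::max_element(logits.begin(), logits.end());
    return (uint32_t)std::distance(logits.begin(), it); // the sampling step
}
```

Because masking happens before softmax/argmax, the sampler cannot see the forbidden tokens at all; the accept() call then advances the grammar so the next step's mask reflects the new state.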
Summary
Grammar engine architecture:
GBNF string ──> [Parser] ──> GrammarRules (vectors of GrammarElement)
|
v
[NPDA engine]
|
+-----------+-----------+
| |
apply(logits) accept(token)
| |
mask invalid advance stacks
tokens to -inf through codepoints
| |
[Sampling] [next step]
JSON Schema ──> [SchemaConverter] ──> GBNF string ──> (same pipeline)
The grammar engine adds negligible latency to each decode step – the
reject_candidates() function processes ~100K tokens in microseconds because
most tokens are rejected by the first code point match. The real cost is the
one-time init_vocab() call that pre-computes code points for the entire
vocabulary, which takes a few milliseconds.
Whisper: Audio Transcription
Whisper is an encoder-decoder transformer trained by OpenAI for automatic speech recognition.1 It is the odd one out in Akunu’s model lineup – every other model is a decoder-only LLM, but Whisper has a full encoder (for audio) and a decoder (for text) connected by cross-attention. This chapter covers how Akunu implements Whisper from mel spectrogram computation all the way through beam search decoding, including the fused Metal kernels that make single-token decode fast enough for real-time transcription on Apple Silicon.
Architecture Overview
Whisper follows the classic encoder-decoder transformer design, with a Conv1D audio frontend replacing the standard token embedding:
Audio (PCM 16kHz)
|
v
Mel Spectrogram (80 x 3000)
|
v
+------------------------+
| Conv1D (kernel=3, s=1) | n_mels -> enc_dim, same length
| GELU |
+------------------------+
|
v
+------------------------+
| Conv1D (kernel=3, s=2) | enc_dim -> enc_dim, length / 2
| GELU |
+------------------------+
|
v
Transpose + Positional Embedding
|
v
+----------------------------------+
| Encoder Transformer x enc_layers |
| LayerNorm -> Self-Attention |
| Residual |
| LayerNorm -> FFN (GELU) |
| Residual |
+----------------------------------+
|
v
Final LayerNorm -> encoder_output [1500, enc_dim]
|
| precompute cross K/V for all decoder layers
|
v
+----------------------------------+
| Decoder Transformer x dec_layers |
| LayerNorm -> Self-Attention | (causal, with KV cache)
| Residual |
| LayerNorm -> Cross-Attention | (static K/V from encoder)
| Residual |
| LayerNorm -> FFN (GELU) |
| Residual |
+----------------------------------+
|
v
Final LayerNorm -> Logit Projection -> argmax/beam search
Key differences from a decoder-only LLM:
- LayerNorm instead of RMSNorm, with both weight and bias
- Bias terms on all linear projections (except key projections)
- GELU activation (not SiLU/SwiGLU) with no gate
- Sinusoidal positional embeddings (not RoPE)
- Cross-attention in every decoder layer
- Tied embeddings – the output projection reuses the embedding matrix
Audio Preprocessing: The Mel Spectrogram
Before any neural network computation happens, raw audio must be converted to a log-mel spectrogram. This is Akunu’s MelSpectrogram class in src/audio/mel.h.
Parameters
| Parameter | Whisper value | Description |
|---|---|---|
| Sample rate | 16000 Hz | Input audio must be 16kHz mono |
| n_fft | 400 | FFT window size (25ms at 16kHz) |
| hop_length | 160 | Stride between windows (10ms) |
| n_mels | 80 | Number of mel frequency bands |
| n_frames | 3000 | Output frames (30 seconds of audio) |
The pipeline:
PCM float samples (480,000 for 30s @ 16kHz)
|
v
+----------------------------+
| Hann window (n_fft=400) |
| Zero-pad to 512 (power of 2)|
| vDSP FFT (radix-2) |
| Power spectrum |X(f)|^2 |
+----------------------------+
| repeat for each frame (hop=160)
v
Spectrogram: [3000 frames, 201 freq bins]
|
v
+----------------------------+
| Mel filterbank (cblas_sgemm)|
| [80, 201] x [3000, 201]^T |
+----------------------------+
|
v
Log-mel: [80, 3000]
|
v
+----------------------------+
| Clamp to 1e-10, log10 |
| Dynamic range: max - 8.0 |
| Scale: (val + 4.0) / 4.0 |
+----------------------------+
|
v
Normalized mel spectrogram [80, 3000]
FFT via Accelerate
Akunu uses Apple’s Accelerate framework for the FFT, specifically vDSP_fft_zrip (in-place radix-2 FFT on split-complex data). The n_fft of 400 is zero-padded to 512 (next power of 2) for the FFT:
ms.n_fft_padded = 512; // next power of 2 >= 400
ms.log2n = 9; // log2(512)
ms.fft_setup = vDSP_create_fftsetup(ms.log2n, FFT_RADIX2);
One subtle detail: vDSP_fft_zrip returns output scaled by 2x compared to the standard DFT definition. The code compensates with a scale factor of 0.25 when computing the power spectrum:
float scale = 0.25f; // compensate vDSP 2x scaling
mag_row[0] = split.realp[0] * split.realp[0] * scale;
Mel Filterbank
The mel filterbank is a [80, 201] matrix that maps the 201 FFT frequency bins to 80 mel-spaced bands. It is constructed using the standard HTK mel scale:2
hz_to_mel(f) = 2595 * log10(1 + f/700)
mel_to_hz(m) = 700 * (10^(m/2595) - 1)
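These two conversions are simple enough to check in isolation. The following standalone sketch (not Akunu's code) implements the HTK formulas above:

```cpp
#include <cassert>
#include <cmath>

// HTK mel scale: forward and inverse conversions.
float hz_to_mel(float f) {
    return 2595.0f * std::log10(1.0f + f / 700.0f);
}

float mel_to_hz(float m) {
    return 700.0f * (std::pow(10.0f, m / 2595.0f) - 1.0f);
}
```

The two functions round-trip to within float precision; filter center frequencies are obtained by spacing points uniformly in mel space between hz_to_mel(0) and hz_to_mel(8000) and mapping them back with mel_to_hz.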
The filterbank application is a single matrix multiply using cblas_sgemm:
// mel_filters: [80, 201] magnitudes: [3000, 201]
// output: [80, 3000] = mel_filters @ magnitudes^T
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
n_mels, n_frames, n_freq,
1.0f, mel_filters.data(), n_freq,
magnitudes.data(), n_freq,
0.0f, mel.data(), n_frames);
This is one of the rare places in Akunu where CPU-side BLAS is used. The mel spectrogram is small enough that GPU dispatch overhead would exceed the computation time.
Log-Mel Normalization
The final normalization matches OpenAI’s Python implementation exactly:
1. log10(max(mel, 1e-10)) -- log scale with floor
2. clamp to (max_value - 8.0) -- 80 dB dynamic range
3. (value + 4.0) / 4.0 -- normalize to ~[0, 1]
This normalization is critical – using a different scheme (e.g., log2 instead of log10, or different clipping) will produce garbled output because the model was trained with exactly this preprocessing.
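To make the three steps concrete, here is an illustrative CPU sketch of the same normalization, assuming a flat buffer of mel power values (the real implementation works on the full [80, 3000] buffer in place):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Whisper log-mel normalization: log10 with a 1e-10 floor, clamp to an
// 80 dB dynamic range below the maximum, then rescale to roughly [0, 1].
void normalize_log_mel(std::vector<float>& mel) {
    float max_val = -1e30f;
    for (float& v : mel) {
        v = std::log10(std::max(v, 1e-10f));  // step 1: floored log10
        max_val = std::max(max_val, v);
    }
    for (float& v : mel) {
        v = std::max(v, max_val - 8.0f);      // step 2: dynamic-range clamp
        v = (v + 4.0f) / 4.0f;                // step 3: normalize
    }
}
```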
The Encoder
The encoder processes the mel spectrogram through a Conv1D frontend followed by a standard transformer.
Conv1D Frontend
Two 1D convolution layers reduce the temporal resolution by 2x:
Conv1: (80, 3000) -> (enc_dim, 3000) kernel=3, stride=1, pad=1
GELU activation
Conv2: (enc_dim, 3000) -> (enc_dim, 1500) kernel=3, stride=2, pad=1
GELU activation
These run as custom Metal kernels (conv1d_gelu_f32in_f16 and conv1d_gelu_f16). The first conv takes F32 input (from the mel spectrogram) and outputs F16; the second is F16 throughout. Both fuse the GELU activation into the convolution kernel to save a dispatch.
After convolution, the output is transposed from channel-first [enc_dim, 1500] to sequence-first [1500, enc_dim] for the transformer. A sinusoidal positional embedding is added:
// Transpose [enc_dim, enc_seq] -> [enc_seq, enc_dim]
dispatch("transpose_f16", ...);
// Add positional embedding
enc_add(dev, wb.enc_h1, wb.enc_pos_embed, wb.enc_h0, enc_seq * enc_dim);
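The sinusoidal table itself follows the standard transformer construction: sin values fill the first half of each channel vector, cos values the second half, with log-spaced timescales. This standalone sketch shows how such a table can be built; the max timescale of 10000 is the conventional choice and is an assumption here, not taken from Akunu's source:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Build a [length, channels] sinusoidal table: sin in the first channels/2
// columns, cos in the second half. channels must be even.
std::vector<float> sinusoid_table(int length, int channels) {
    int half = channels / 2;
    float log_inc = std::log(10000.0f) / (half - 1);  // log-spaced timescales
    std::vector<float> table((size_t)length * channels);
    for (int pos = 0; pos < length; pos++) {
        for (int i = 0; i < half; i++) {
            float t = pos * std::exp(-log_inc * i);
            table[(size_t)pos * channels + i]        = std::sin(t);
            table[(size_t)pos * channels + half + i] = std::cos(t);
        }
    }
    return table;
}
```

For the encoder this would be called with length 1500 and channels equal to enc_dim, and the resulting table added row-wise to the transposed convolution output.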
Encoder Transformer Layers
Each encoder layer follows the standard pre-norm transformer pattern, but with LayerNorm (not RMSNorm) and GELU activation (not SiLU):
input
|
+-----> LayerNorm(attn_ln) -> Q,K,V projections
| |
| Self-Attention (non-causal)
| |
| O projection
| |
+---- residual add <-----------+
|
+-----> LayerNorm(mlp_ln) -> FFN up (GELU) -> FFN down
| |
+---- residual add <-----------------------------+
|
v
output
The self-attention is non-causal – every position attends to every other position, since the entire audio is available at once. Akunu dispatches this using the prefill attention kernel with non-causal masking:
uint32_t fc_values[] = {(uint32_t)head_dim, NQ, 1}; // 1 = non-causal
Pipeline attn_pso = dev.get_pipeline(
"flash_attention_prefill_f16",
"attn_enc_prefill_nc_hd" + std::to_string(head_dim),
fc_indices, fc_values, 3, fc_types);
The FFN uses plain GELU without a gate:
FFN(x) = down(GELU(up(x)))
This is simpler than the SwiGLU used in LLaMA/Qwen:
SwiGLU(x) = down(SiLU(gate(x)) * up(x))
No gate projection means the FFN has two weight matrices instead of three.
Head Rearrangement on GPU
A notable optimization: Akunu performs the Q/K/V head rearrangement entirely on GPU. The GEMM outputs are [seq, n_heads * head_dim] (all heads concatenated), but the attention kernel expects [n_heads, seq, head_dim] (head-major). Rather than CPU-side transpose, dedicated kernels handle this:
// [seq, n_heads*head_dim] -> [n_heads, seq, head_dim]
enc_head_rearrange_forward(dev, wb.enc_q, wb.enc_attn_out,
enc_seq, n_heads, head_dim);
// After attention: [n_heads, seq, head_dim] -> [seq, n_heads*head_dim]
enc_head_rearrange_inverse(dev, wb.enc_q, wb.enc_attn_out,
enc_seq, n_heads, head_dim);
This avoids a GPU-CPU sync that would kill pipeline parallelism.
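The index math behind the two rearrangements is easy to verify on the CPU. This is an illustrative sketch of the forward and inverse mappings, not the Metal kernel itself:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Forward: [seq, n_heads*head_dim] -> [n_heads, seq, head_dim]
std::vector<float> head_rearrange(const std::vector<float>& in,
                                  int seq, int n_heads, int head_dim) {
    std::vector<float> out(in.size());
    for (int h = 0; h < n_heads; h++)
        for (int s = 0; s < seq; s++)
            for (int d = 0; d < head_dim; d++)
                out[((size_t)h * seq + s) * head_dim + d] =
                    in[(size_t)s * n_heads * head_dim + h * head_dim + d];
    return out;
}

// Inverse: [n_heads, seq, head_dim] -> [seq, n_heads*head_dim]
std::vector<float> head_rearrange_inv(const std::vector<float>& in,
                                      int seq, int n_heads, int head_dim) {
    std::vector<float> out(in.size());
    for (int h = 0; h < n_heads; h++)
        for (int s = 0; s < seq; s++)
            for (int d = 0; d < head_dim; d++)
                out[(size_t)s * n_heads * head_dim + h * head_dim + d] =
                    in[((size_t)h * seq + s) * head_dim + d];
    return out;
}
```

Applying the forward mapping and then the inverse must reproduce the input exactly, which is a useful property test for the GPU kernels as well.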
Cross-Attention Precomputation
This is the single most important optimization for Whisper performance. The encoder output is fixed for the entire decoding process – it does not change between tokens. That means the cross-attention K and V projections can be computed once and reused for every decoder step.
static void precompute_cross_kv(Device& dev, WhisperModel& wm,
const AkunuModelConfig& cfg, WhisperBuffers& wb, int enc_seq)
{
dev.begin_encoding();
// n_dec, dec_dim, enc_dim, n_heads, head_dim come from cfg (lookups elided)
for (int l = 0; l < n_dec; l++) {
// K projection: encoder_output @ cross_attn.key.weight^T
enc_gemm(dev, wb.encoder_output, wm.get_tensor(
"decoder.blocks.%d.cross_attn.key.weight", l),
wb.cross_k[l], enc_seq, dec_dim, enc_dim);
// V projection: encoder_output @ cross_attn.value.weight^T + bias
enc_gemm(dev, wb.encoder_output, wm.get_tensor(
"decoder.blocks.%d.cross_attn.value.weight", l),
wb.cross_v[l], enc_seq, dec_dim, enc_dim);
enc_bias_add(dev, wb.cross_v[l], vb, enc_seq, dec_dim); // vb: value-projection bias tensor (lookup elided)
}
dev.end_encoding_sync();
// Rearrange to head-major [n_heads, enc_seq, head_dim]
for (int l = 0; l < n_dec; l++) {
// CPU rearrange on UMA is fast after the single GPU sync;
// s = source [enc_seq, dec_dim], d = destination [n_heads, enc_seq, head_dim] (pointer setup elided)
for (int h = 0; h < n_heads; h++)
for (int p = 0; p < enc_seq; p++)
memcpy(&d[(h*enc_seq+p)*head_dim],
&s[p*dec_dim + h*head_dim],
head_dim * sizeof(__fp16));
}
}
For Whisper Large (enc_seq=1500, dec_dim=1280, 32 decoder layers), this precomputation involves:
GEMM per layer: 2 x (1500 x 1280 x 1280) ≈ 4.9 GFLOPs
Total (32 layers): 32 x 4.9 ≈ 157 GFLOPs
Time on M2 Pro: ~15ms (one-time cost)
Without precomputation:
Per token: 157 GFLOPs / 1500 positions ≈ 0.1 GFLOPs (tiny, but...)
Overhead: 2 extra GEMV dispatches per layer per token
= 64 extra kernel dispatches per token
With precomputation:
Per token: 0 GEMM, cross-K and cross-V are static buffers
= just the cross-attention dot product
The savings are not just in FLOPS – eliminating 64 kernel dispatches per token significantly reduces command buffer overhead.
Note that the K/V rearrangement to head-major layout is done on CPU after a single GPU sync. On Apple Silicon’s unified memory, this memcpy-based rearrangement completes in microseconds because the data is already in coherent memory.
The Decoder
The decoder uses a dispatch table, just like the LLM decoder. But it has additional complexity: cross-attention, positional embeddings (instead of RoPE), and Whisper-specific token suppression.
Dispatch Table Structure
The Whisper decode table has the following command sequence per token:
1. embedding_lookup_f16 (token -> hidden state)
2. pos_embed_add_f16 (add positional encoding)
For each layer (16 commands/layer with fused kernels):
3. layernorm_f16 (self-attention norm)
4. whisper_gemv_bias_f16 (Q projection, fused with bias)
5. gemv_f16 (K projection, no bias)
6. whisper_gemv_bias_f16 (V projection, fused with bias)
7. kv_cache_write_f16 x2 (write K,V to cache, no RoPE)
8. flash_attention_decode_f16 (self-attention, causal)
9. whisper_gemv_bias_res_f16 (O + bias + residual, fused)
10. layernorm_f16 (cross-attention norm)
11. whisper_gemv_bias_f16 (cross-Q + bias)
12. flash_attention_decode_f16 (cross-attention, static K/V)
13. whisper_gemv_bias_res_f16 (cross-O + bias + residual)
14. layernorm_f16 (FFN norm)
15. whisper_gemv_bias_gelu_f16 (FFN up + bias + GELU, fused)
16. whisper_gemv_bias_res_f16 (FFN down + bias + residual)
17. layernorm_f16 (output norm)
18. gemv_f16 (logit projection)
19. whisper_suppress_f16 (suppress special tokens)
20. argmax_f16 (greedy decode)
Fused Whisper Kernels
The decoder benefits enormously from fused kernels that combine operations that would otherwise require separate dispatches:
| Fused kernel | Operations combined | Dispatches saved |
|---|---|---|
| whisper_gemv_bias_f16 | GEMV + bias add | 2 -> 1 |
| whisper_gemv_bias_gelu_f16 | GEMV + bias + GELU | 3 -> 1 |
| whisper_gemv_bias_res_f16 | GEMV + bias + residual add | 3 -> 1 |
Without fusion, each decoder layer would need ~26 dispatches. With fusion, it is ~16. For Whisper Large with 32 layers, that is:
Unfused: 2 + 32*26 + 4 = 838 dispatches/token
Fused: 2 + 32*16 + 4 = 518 dispatches/token
Savings: 320 dispatches/token (38% reduction)
On Apple Silicon, each dispatch has a fixed overhead of roughly 1-3 microseconds for command encoding. At 518 dispatches, that is ~0.5-1.5ms of pure overhead – significant when the target is real-time transcription.
The fused kernels gracefully degrade if the specialized Metal function is not available (e.g., on an older metallib):
static void w_emit_gemv_bias(DispatchTable& tbl, Device& dev, ...) {
Pipeline pso = dev.get_pipeline("whisper_gemv_bias_f16");
if (!pso.handle) {
// Fallback: separate GEMV + bias add
w_emit_gemv(tbl, dev, in, weight, out, out_off, N, K);
w_emit_bias(tbl, dev, out, out_off, bias, N);
return;
}
// Fused dispatch
// ...
}
KV Cache Without RoPE
Unlike LLMs, Whisper's decoder uses learned absolute positional embeddings added to the token embedding (the encoder uses fixed sinusoids), not RoPE applied to Q/K at each layer. This means the KV cache write kernel is simpler – it just stores the projected K and V without any positional rotation:
// KV write: no RoPE (function constant = false)
uint32_t kv_fc_v[] = {0}; // 0 = no positional encoding
uint32_t kv_fc_t[] = {1}; // bool type
Pipeline kv_pso = dev.get_pipeline("kv_cache_write_f16",
"kv_write_nopos", kv_fc_i, kv_fc_v, 1, kv_fc_t);
Cross-Attention
Cross-attention in the decoder uses the precomputed K/V buffers from the encoder. These are static – they do not change between tokens – so there is no KV cache write for cross-attention:
// Cross-attention: static K/V, fixed kv_seq_len = enc_seq
AttnParams attn_params = { // AttnParams: illustrative name for the attention parameter struct
.seq_len = 1,
.kv_seq_len = (uint32_t)enc_seq, // always 1500
// ...
};
cmd.add_buffer(wb.cross_k[l], 0, 1); // precomputed
cmd.add_buffer(wb.cross_v[l], 0, 2); // precomputed
This is why cross-attention precomputation matters – the GEMV for projecting K and V would otherwise happen at every decoder step.
Token Suppression
Whisper has special tokens (timestamps, language tags, task tokens) that should be suppressed during normal text generation. Akunu handles this with a GPU-side suppression kernel:
// Suppress special tokens [first_special, first_timestamp)
// This allows timestamp tokens through for timestamps mode
struct {
uint32_t first_special, vocab_size, eot, suppress_blank;
} sp = {wdp.first_special, suppress_end, wdp.eot, 0};
dispatch("whisper_suppress_f16", n_suppress, 256);
The kernel sets logits for suppressed tokens to negative infinity, effectively removing them from consideration by the argmax or sampling step. Timestamps are not suppressed, which allows the model to output timestamp tokens when running in timestamps mode.
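The suppression step reduces to a masked write over a contiguous logit range. A CPU sketch of the same logic (the actual kernel runs one GPU thread per suppressed token):

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Set logits in [first_special, suppress_end) to -inf so that neither
// argmax nor sampling can ever select those tokens.
void suppress_tokens(std::vector<float>& logits,
                     uint32_t first_special, uint32_t suppress_end) {
    float neg_inf = -std::numeric_limits<float>::infinity();
    for (uint32_t t = first_special; t < suppress_end && t < logits.size(); t++)
        logits[t] = neg_inf;
}
```

In timestamps mode, suppress_end would be set to the first timestamp token ID, leaving the timestamp range untouched.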
The ArchDescriptor for Whisper
Whisper’s unique properties are captured in the ArchDescriptor:
inline ArchDescriptor arch_whisper() {
ArchDescriptor d = {};
d.activation_kernel = "gelu_f16"; // plain GELU, no gate
d.embedding_scale = 0.0f;
d.has_qk_norm = false;
d.rope_kernel = nullptr; // no RoPE
d.tie_embeddings = true; // output = embedding^T
d.is_encoder_decoder = true;
d.has_cross_attention = true;
d.has_conv_frontend = true;
d.has_bias = true; // all linears have bias
d.norm_type = "layernorm"; // not rmsnorm
d.encoder_activation = "gelu_f16";
return d;
}
The is_encoder_decoder and has_cross_attention flags tell the initialization code to allocate cross-attention buffers, run the encoder, and build the decoder dispatch table with cross-attention commands.
Model Loading
Whisper models in Akunu are loaded from whisper.cpp’s custom binary format (magic: "lmgg" or "ggjt"), not GGUF. The WhisperModel struct holds all tensor data on GPU:
File structure:
4 bytes: magic ("lmgg" or "ggjt")
44 bytes: hyperparameters (11 x int32)
mel filters (precomputed filterbank from file)
vocabulary (token strings)
tensor data (name + dims + dtype + raw data)
The hyperparameters encode both encoder and decoder dimensions:
| Index | Field | Example (Large) |
|---|---|---|
| 0 | vocab_size | 51865 |
| 1 | n_audio_ctx (enc_seq) | 1500 |
| 2 | n_audio_state (enc_dim) | 1280 |
| 3 | n_audio_head | 20 |
| 4 | n_audio_layer | 32 |
| 5 | n_text_ctx (dec_seq) | 448 |
| 6 | n_text_state (dec_dim) | 1280 |
| 7 | n_text_head | 20 |
| 8 | n_text_layer | 32 |
| 9 | n_mels | 80 |
| 10 | ftype (quant format) | 1 (F16) |
End-to-End Flow
Putting it all together, here is the complete flow for transcribing 30 seconds of audio:
1. Audio input (480,000 float samples at 16kHz)
|
2. MelSpectrogram::compute() [CPU, ~5ms]
| vDSP FFT + cblas_sgemm + log normalization
|
3. Upload mel to GPU buffer [UMA, ~0.1ms]
|
4. encode_whisper() [GPU, ~30ms]
| Conv1D x2 + 32 encoder layers
| Non-causal self-attention (1500 seq len)
|
5. precompute_cross_kv() [GPU, ~15ms]
| 64 GEMMs (K+V for each of 32 layers)
| CPU rearrange to head-major
|
6. Decode loop (greedy or beam search):
| For each output token:
| dispatch_table.execute() [GPU, ~3ms/token]
| Read argmax token ID
| Check for EOT
| Until EOT or max_tokens
|
7. Detokenize output tokens -> text string
For Whisper Large on M2 Pro, typical performance is:
Encoder: ~30ms (one-time)
Cross-KV: ~15ms (one-time)
Decode per token: ~3ms
Typical output: ~50 tokens for 30s of speech
Total decode: ~150ms
Total latency: ~200ms for 30 seconds of audio
Real-time factor: 0.007x (150x faster than real-time)
This makes Whisper on Apple Silicon more than fast enough for real-time streaming transcription, where audio arrives in chunks and the encoder/decoder pipeline overlaps with audio capture.
1. Radford et al., “Robust Speech Recognition via Large-Scale Weak Supervision,” OpenAI, 2022. Whisper was trained on 680,000 hours of weakly supervised audio-text pairs. See https://arxiv.org/abs/2212.04356.
2. The HTK mel scale (named after the Hidden Markov Model Toolkit) defines mel(f) = 2595 * log10(1 + f/700). An alternative “Slaney” definition uses a piecewise linear/log formula. Whisper uses the HTK definition with Slaney normalization of the filterbank triangles.
The HTTP Server
Akunu ships with a built-in HTTP server that implements the OpenAI-compatible /v1/chat/completions API. This means any application that talks to OpenAI’s API can switch to a local Akunu instance by changing a single URL. The server handles model management, chat template formatting, streaming via Server-Sent Events, rate limiting, metrics, grammar-constrained generation, tool calling, and prefix caching – all in a single header file (src/server/serve.h) with no external dependencies beyond Akunu’s own HTTP primitives.
ServerConfig
The server is configured through a ServerConfig struct:
struct ServerConfig {
std::string host = "127.0.0.1";
int port = 8080;
int max_context = 4096;
int max_queue_depth = 16;
int rate_limit_per_minute = 0; // 0 = unlimited
std::string api_key; // empty = no auth
int default_max_tokens = 2048;
int idle_timeout_seconds = 0; // 0 = disabled
};
| Field | Default | Description |
|---|---|---|
| host | 127.0.0.1 | Bind address (use 0.0.0.0 for all interfaces) |
| port | 8080 | Listen port |
| max_context | 4096 | Maximum KV cache context length |
| max_queue_depth | 16 | Maximum queued requests (not yet implemented as a semaphore) |
| rate_limit_per_minute | 0 | Requests per minute per client IP (0 = unlimited) |
| api_key | "" | Bearer token for auth (empty = no auth) |
| default_max_tokens | 2048 | Default max_tokens if not specified in request |
| idle_timeout_seconds | 0 | Auto-unload models after this many seconds idle (0 = disabled) |
When idle_timeout_seconds is set, a background thread checks every 30 seconds for models that have not been accessed within the timeout window and unloads them:
if (config_.idle_timeout_seconds > 0) {
idle_thread_ = std::thread([this]() {
while (!idle_stop_.load()) {
std::this_thread::sleep_for(std::chrono::seconds(30));
registry_.unload_idle(config_.idle_timeout_seconds);
}
});
}
The Model Registry
The server can host multiple models simultaneously. The ModelRegistry manages them:
ModelRegistry
|
+-- add(handle, id, path, metallib)
+-- remove(id)
+-- resolve(requested_id) -> ModelEntry
+-- model_list() -> JSON
+-- unload_idle(timeout_seconds)
Model Resolution
When a request comes in with a model field, the registry uses a multi-level matching strategy:
1. Exact match: "llama-3.1-8b-q4" == "llama-3.1-8b-q4"
2. Case-insensitive match: "Llama-3.1-8B-Q4" == "llama-3.1-8b-q4"
3. Substring match: "llama" matches "llama-3.1-8b-q4"
4. Default fallback: any request -> first loaded model
This flexible matching means you can use short names in your client code ("llama") and they will resolve to the full model ID. The default fallback means single-model setups “just work” regardless of what model name the client sends.
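The cascade can be sketched as a standalone function; this is a hypothetical simplification – the real registry also manages locks and per-model entries:

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <string>
#include <vector>

static std::string lower(std::string s) {
    std::transform(s.begin(), s.end(), s.begin(),
                   [](unsigned char c) { return (char)std::tolower(c); });
    return s;
}

// Resolve a requested model id against the loaded ids:
// exact -> case-insensitive -> substring -> first loaded (default).
std::string resolve_model(const std::string& requested,
                          const std::vector<std::string>& loaded) {
    if (loaded.empty()) return "";
    for (auto& id : loaded)
        if (id == requested) return id;                // 1. exact
    std::string req_lc = lower(requested);
    for (auto& id : loaded)
        if (lower(id) == req_lc) return id;            // 2. case-insensitive
    for (auto& id : loaded)
        if (lower(id).find(req_lc) != std::string::npos)
            return id;                                 // 3. substring
    return loaded.front();                             // 4. default fallback
}
```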
ModelEntry and Prefix Caching
Each ModelEntry tracks state for prefix caching:
struct ModelEntry {
akunu_model_t handle;
std::string id;
std::string path;
std::atomic<int64_t> last_access{0};
std::mutex mu; // serialize inference per model
// Prefix cache
std::vector<uint32_t> cached_tokens;
int cached_position = 0;
int shared_prefix(const uint32_t *tokens, int n_tokens) const {
int shared = 0;
int limit = std::min((int)cached_tokens.size(), n_tokens);
for (int i = 0; i < limit; i++) {
if (cached_tokens[i] != tokens[i]) break;
shared++;
}
return shared;
}
};
Prefix caching is simple but effective: if the new request shares a prefix with the previous request’s tokens, Akunu can skip re-encoding that prefix and continue from where it left off. This is common in chat scenarios where each turn appends to the conversation history:
Turn 1: [system][user_1] -> process all
Turn 2: [system][user_1][asst_1][user_2] -> skip [system][user_1]
Turn 3: [system][user_1][asst_1][user_2][asst_2][user_3] -> skip more
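The overlap computation is easy to exercise in isolation; here is a minimal free-function version of shared_prefix for illustration:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Length of the common prefix between the cached token sequence
// and an incoming request's tokens.
int shared_prefix_len(const std::vector<uint32_t>& cached,
                      const std::vector<uint32_t>& incoming) {
    size_t limit = std::min(cached.size(), incoming.size());
    size_t shared = 0;
    while (shared < limit && cached[shared] == incoming[shared]) shared++;
    return (int)shared;
}
```

With a chat history that only ever appends, the overlap equals the full length of the previous turn's tokens, so only the newly appended tokens need prefill.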
The server checks if the shared prefix length exceeds the current KV cache position, and if the total estimated length (prefix + new tokens + max generation) fits within max_context. If the context would overflow, it resets the KV cache and processes from scratch:
int shared = entry->shared_prefix(tokens.data(), n_tokens);
bool use_continue = (shared > 0 && shared <= entry->cached_position);
int est_position = use_continue
? shared + (n_tokens - shared) + max_tokens
: n_tokens + max_tokens;
if (est_position > config_.max_context)
use_continue = false; // would overflow, reset
if (!use_continue)
akunu_reset(entry->handle);
API Routes
The server registers these routes:
| Method | Path | Description |
|---|---|---|
| POST | /v1/chat/completions | Chat completions (OpenAI-compatible) |
| POST | /v1/completions | Text completions |
| GET | /v1/models | List loaded models |
| POST | /v1/tokenize | Tokenize text (extension) |
| GET | /health | Health check |
| GET | /v1/metrics | Server metrics |
| POST | /v1/models/load | Load a model (extension) |
| POST | /v1/models/unload | Unload a model (extension) |
| POST | /v1/audio/transcriptions | Whisper transcription (OpenAI-compatible) |
Chat Completions
The /v1/chat/completions endpoint accepts the standard OpenAI request format:
{
"model": "llama-3.1-8b",
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hello!"}
],
"max_tokens": 256,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 40,
"min_p": 0.0,
"stream": true,
"stop": ["\n\n"],
"response_format": {"type": "json_object"},
"tools": [...]
}
Supported sampling parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| temperature | float | 0.0 | Sampling temperature (0 = greedy) |
| top_p | float | 0.9 | Nucleus sampling threshold |
| top_k | int | 40 | Top-K sampling |
| min_p | float | 0.0 | Minimum probability threshold |
| frequency_penalty | float | 0.0 | Mapped to repeat_penalty = 1 + freq_pen |
| max_tokens | int | 2048 | Maximum tokens to generate |
| stream | bool | false | Enable SSE streaming |
| stop | string/array | [] | Stop sequences |
Chat Template Formatting
The server auto-detects the correct chat template based on the model architecture:
| Architecture | Template | Example |
|---|---|---|
| llama, mistral | LLaMA 3 | <|start_header_id|>user<|end_header_id|> |
| qwen3 | ChatML | <|im_start|>user |
| gemma, gemma3 | Gemma | <start_of_turn>user |
| (default) | ChatML | <|im_start|>user |
For Qwen3 models, the server automatically appends /no_think to the system prompt to disable the model’s “thinking” mode, which produces verbose chain-of-thought output that most API users do not want.
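The detection reduces to substring checks on the architecture string. An illustrative sketch (the returned template names are placeholders, not Akunu identifiers):

```cpp
#include <cassert>
#include <string>

// Pick a chat template family from the model's architecture string.
// Unknown architectures fall back to ChatML, matching the table above.
std::string pick_template(const std::string& arch) {
    if (arch.find("llama") != std::string::npos ||
        arch.find("mistral") != std::string::npos)
        return "llama3";
    if (arch.find("gemma") != std::string::npos)
        return "gemma";
    return "chatml";  // qwen3 and everything else
}
```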
Tool Calling
When the request includes a tools array, the server injects tool definitions into the system prompt:
# Tools
You have access to the following tools:
## get_weather
Get the current weather for a location.
Parameters: {"type": "object", "properties": {"location": ...}}
To call a tool, output: <tool_call>{"name": "function_name", "arguments": {...}}</tool_call>
After generation, the output is parsed for tool call patterns. The server recognizes two formats:
- ChatML/Qwen: <tool_call>{"name": "...", "arguments": ...}</tool_call>
- LLaMA 3.1+: direct JSON with a "name" key
If tool calls are detected, the response’s finish_reason is set to "tool_calls" and the parsed calls are included in the response.
Streaming (Server-Sent Events)
When stream: true, the server uses SSE (Server-Sent Events) to stream tokens as they are generated:
HTTP/1.1 200 OK
Content-Type: text/event-stream
data: {"id":"chatcmpl-abc","object":"chat.completion.chunk",...,"choices":[{"delta":{"role":"assistant"}}]}
data: {"id":"chatcmpl-abc",...,"choices":[{"delta":{"content":"Hello"}}]}
data: {"id":"chatcmpl-abc",...,"choices":[{"delta":{"content":" world"}}]}
data: {"id":"chatcmpl-abc",...,"choices":[{"delta":{},"finish_reason":"stop"}]}
data: [DONE]
Each generated token triggers the callback function, which:
- Feeds the token text through the StopSequenceDetector
- If it is safe to emit (no partial stop-sequence match), sends an SSE event
- If a stop sequence is matched, stops generation
Stop Sequence Detection
The StopSequenceDetector handles the tricky case where a stop sequence might span multiple tokens. For example, if the stop sequence is "\n\n" and tokens arrive as "\n" then "\n", the detector must buffer the first "\n" until it can determine whether the next token completes the stop sequence or not:
Token: "Hello" -> emit "Hello", buffer empty
Token: " \n" -> emit " ", buffer "\n" (partial match)
Token: "\n" -> stop sequence matched! emit nothing, stop
The algorithm:
std::string feed(const std::string& token, bool& stopped) {
    buffer_ += token;
    // Check for a complete match
    for (auto& seq : sequences_) {
        size_t pos = buffer_.find(seq);
        if (pos != std::string::npos) {
            stopped = true;
            return buffer_.substr(0, pos); // text before the match
        }
    }
    // Check for a partial match at the end of the buffer
    size_t max_prefix = 0;
    for (auto& seq : sequences_) {
        // How much of seq matches the end of the buffer?
        size_t max_len = std::min(seq.size(), buffer_.size());
        for (size_t len = 1; len <= max_len; len++) {
            if (buffer_.compare(buffer_.size() - len, len, seq, 0, len) == 0)
                max_prefix = std::max(max_prefix, len);
        }
    }
    // Emit everything except the potential partial match
    std::string safe = buffer_.substr(0, buffer_.size() - max_prefix);
    buffer_.erase(0, buffer_.size() - max_prefix);
    return safe;
}
Rate Limiter
The rate limiter uses a token bucket algorithm per client IP:
Rate limiter state per IP:
tokens: double (starts at rpm, refills over time)
last_refill: timestamp
On each request:
elapsed = now - last_refill
tokens += elapsed * (rpm / 60.0)
tokens = min(tokens, rpm) // cap at burst size
if tokens >= 1.0:
tokens -= 1.0
-> allow
else:
-> reject with 429
The implementation includes periodic eviction of stale buckets (every 1000 calls, remove entries idle for more than 5 minutes) to prevent unbounded memory growth from unique client IPs.
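The bucket update condenses to a few lines, assuming a monotonic clock in seconds (illustrative; the server keys one bucket per client IP and stores them in a map):

```cpp
#include <algorithm>
#include <cassert>

// Per-client token bucket: capacity = rpm requests, refilled at rpm/60
// tokens per second. allow() returns true if the request may proceed.
struct TokenBucket {
    double tokens;       // current token count
    double last_refill;  // timestamp of last refill, in seconds
    double rpm;          // requests per minute (also the burst cap)

    bool allow(double now) {
        tokens = std::min(tokens + (now - last_refill) * (rpm / 60.0), rpm);
        last_refill = now;
        if (tokens >= 1.0) { tokens -= 1.0; return true; }
        return false;
    }
};
```

At rpm = 60, one token accumulates per second, so a client that just spent its budget must wait about a second before the next request is admitted.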
Metrics
The Metrics class tracks per-model and aggregate statistics:
class Metrics {
int total_requests_ = 0;
int total_prompt_tokens_ = 0;
int total_completion_tokens_ = 0;
struct ModelMetrics {
int requests = 0;
int prompt_tokens = 0;
int completion_tokens = 0;
};
std::unordered_map<std::string, ModelMetrics> per_model_;
};
The /v1/metrics endpoint returns a JSON snapshot:
{
"total_requests": 42,
"total_prompt_tokens": 12000,
"total_completion_tokens": 8500,
"uptime_seconds": 3600.5,
"models": {
"llama-3.1-8b-q4": {
"total_requests": 30,
"total_prompt_tokens": 9000,
"total_completion_tokens": 6000
}
}
}
All metrics operations are mutex-protected for thread safety.
Thread Safety
The server’s thread safety model is straightforward:
HTTP Server
|
+-- Request arrives on server thread pool
|
+-- Auth check (stateless, safe)
+-- Rate limiter check (mutex-protected)
+-- Model resolve (registry mutex)
|
+-- Acquire model inference lock (entry->mu)
| Only ONE inference per model at a time
|
+-- Tokenize, prefill, generate
| (single-threaded inference)
|
+-- Release model lock
+-- Record metrics (metrics mutex)
The key constraint is entry->mu – a per-model mutex that serializes inference. This is necessary because the GPU resources (KV cache, scratch buffers) are not duplicated per request. A future enhancement could support concurrent requests to the same model with multiple KV cache slots, but for single-user scenarios this serialization is both simple and correct.
There is a subtle TOCTOU (time-of-check-time-of-use) guard: after acquiring the inference lock, the server re-checks that the model handle is still valid, because the idle unload thread might have freed it between the resolve() call and the lock acquisition:
std::lock_guard<std::mutex> infer_lock(entry->mu);
// Re-check handle after acquiring lock
if (!entry->handle) {
send_error(conn, 503, "Model was unloaded",
"server_error", "model_unloaded");
return;
}
JSON Mode and Grammar Constraints
The server supports three levels of output structure:
- Unconstrained: normal text generation
- JSON mode (response_format: {type: "json_object"}): augments the system prompt with a JSON instruction and uses Akunu’s grammar engine to constrain output to valid JSON
- JSON Schema (response_format: {type: "json_schema", json_schema: {schema: ...}}): constrains output to match a specific JSON schema
Grammar objects are managed with RAII to prevent leaks on early returns:
akunu_grammar_t grammar = nullptr;
struct GrammarGuard {
akunu_grammar_t& g;
~GrammarGuard() { if (g) { akunu_grammar_free(g); g = nullptr; } }
} grammar_guard{grammar};
For JSON mode, the system prompt is augmented:
IMPORTANT: You must respond with valid JSON only. No markdown,
no explanation, just a JSON object or array.
After generation, the server attempts to extract clean JSON from the output, handling cases where the model wraps its response in markdown code fences.
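That extraction step can be sketched as a small string routine. This is a hypothetical helper, not the server's exact code – the sketch skips any language tag on the opening fence line and leaves unfenced text untouched:

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// If the model wrapped its JSON in a markdown code fence (``` or ```json),
// return only the fenced body; otherwise return the input unchanged.
std::string strip_code_fence(const std::string& text) {
    size_t open = text.find("```");
    if (open == std::string::npos) return text;
    size_t body = text.find('\n', open);
    if (body == std::string::npos) return text;
    body++;  // skip the fence line, including any "json" language tag
    size_t close = text.find("```", body);
    if (close == std::string::npos) return text;
    return text.substr(body, close - body);
}
```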
Request Logging
Every request and response is logged to stderr with timestamps and performance data:
[14:32:05] --> POST /v1/chat/completions model=llama-3.1-8b stream max_tokens=256
[14:32:06] <-- 200 llama-3.1-8b prompt=45 completion=128 prefill=1200 t/s decode=95 t/s 1340ms stop
This gives operators immediate visibility into request patterns and model performance without any additional monitoring infrastructure.
Summary
The Akunu HTTP server packs a lot of functionality into a single header:
serve.h
|
+-- ServerConfig (bind address, limits, auth)
+-- ModelRegistry (multi-model, flexible resolution)
+-- ModelEntry (per-model state, prefix caching)
+-- RateLimiter (token bucket per client IP)
+-- Metrics (per-model request/token counters)
+-- StopSequenceDetector (buffered multi-token stop detection)
+-- Chat template logic (LLaMA 3, ChatML, Gemma auto-detect)
+-- Tool call parsing (ChatML and LLaMA formats)
+-- JSON mode / grammar (constrained generation)
+-- SSE streaming (OpenAI-compatible chunks)
+-- AkunuServer (ties it all together)
The design philosophy is zero external dependencies and OpenAI wire compatibility. Any client library that works with the OpenAI API – Python’s openai package, LangChain, LlamaIndex, Cursor, Continue – can point at an Akunu server with no code changes beyond the base URL.
Development Environment Setup
Welcome to Part IX of this book, where we transition from understanding how akunu works to actually contributing to its codebase. If you have read this far, you already know more about Apple Silicon LLM inference than most people who write it. Now it is time to get your hands dirty.
This chapter walks you through every step of setting up a development environment for akunu on macOS. We will cover hardware prerequisites, toolchain installation, cloning and building the project, IDE configuration, and debugging Metal shaders. By the end, you will have a running build that passes all tests and a workflow that lets you iterate quickly on both CPU and GPU code.
Hardware Prerequisites
Akunu targets Apple Silicon exclusively. You need a Mac with an M-series chip:
Supported Hardware
==================
+------------------+-------------+------------------+------------------+
| Chip | GPU Family | GPU Cores | Memory BW |
+------------------+-------------+------------------+------------------+
| M1 | Apple 7 | 7-8 | 68 GB/s |
| M1 Pro/Max/Ultra | Apple 7 | 14-64 | 200-800 GB/s |
| M2 | Apple 8 | 8-10 | 100 GB/s |
| M2 Pro/Max/Ultra | Apple 8 | 16-76 | 200-800 GB/s |
| M3               | Apple 9     | 8-10             | 100 GB/s         |
| M3 Pro/Max/Ultra | Apple 9     | 11-40            | 150-400 GB/s     |
| M4 | Apple 9 | 10 | 120 GB/s |
| M4 Pro/Max/Ultra | Apple 9 | 16-64 | 273-800 GB/s |
+------------------+-------------+------------------+------------------+
Minimum: M1 with 8 GB RAM (small models only)
Recommended: M2 Pro+ with 16+ GB RAM
Ideal: M4 Pro/Max with 36+ GB RAM
The GPU family number matters because it determines which Metal features are available. Apple 7 (M1) supports Metal 3.0. Apple 8 (M2) adds Metal 3.1 with improvements to SIMD-group operations. Apple 9 (M3/M4) introduces Metal 3.2 with native BF16 support and enhanced matrix operations.
Akunu auto-detects your GPU family at build time through the Metal compiler and
at runtime through the MTLDevice.supportsFamily API. You do not need to
configure anything – the build system picks the highest Metal standard your
hardware supports.
Software Prerequisites
macOS Version
You need macOS 14 (Sonoma) or later. macOS 15 (Sequoia) is recommended because it ships with Metal 3.2 support and improved GPU debugging tools. You can check your version:
sw_vers --productVersion
# 15.4.1
Xcode
Xcode 15 or later is required. Xcode 16+ is recommended because it includes:
- Metal 3.2 compiler with
-fmetal-math-fp32-functions=fastoptimization - Improved GPU profiler with per-kernel timing
- Better Metal shader debugging and validation layers
Install Xcode from the App Store, then make sure the command-line tools are selected:
# Check current Xcode version
xcodebuild -version
# Xcode 16.3
# Build version ...
# Ensure command-line tools point to full Xcode (not standalone CLT)
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
# Verify the Metal compiler is available
xcrun -sdk macosx metal --version
# Apple metal version 32023.155 (metalfe-32023.155)
The Metal compiler (metal) and linker (metallib) are part of Xcode, not
separate installs. If xcrun metal fails, your Xcode installation is
incomplete.
CMake
Akunu uses CMake 3.20+ as its build system for the C++ engine. The Metal shaders have their own Makefile-based build, but CMake orchestrates the C++ compilation and test executables.
# Install via Homebrew (recommended)
brew install cmake
# Verify version
cmake --version
# cmake version 3.31.6
Optional but Recommended
# Ninja: faster parallel builds than make
brew install ninja
# ccache: caches compilation results for faster rebuilds
brew install ccache
Cloning the Repository
Akunu uses git submodules for its third-party dependencies (currently just
XGrammar for grammar-constrained decoding). Always clone with --recursive:
git clone --recursive https://github.com/prabod/akunu.git
cd akunu
If you already cloned without --recursive, initialize the submodules now:
git submodule update --init --recursive
This pulls XGrammar (v0.1.33) into 3rdparty/xgrammar/. Without it, the
build still succeeds but grammar-constrained decoding is disabled.
Let us look at the directory structure:
akunu/
+-- CMakeLists.txt # C++ build system
+-- Makefile # Metal shader build + convenience targets
+-- VERSION # Version file
+-- include/
| +-- akunu/
| +-- akunu.h # Public C API
| +-- types.h # Public type definitions
+-- src/
| +-- core/ # Engine core: dispatch, descriptors, config
| +-- inference/ # Decode paths, sampling, model loading
| +-- tokenizer/ # BPE tokenizer
| +-- grammar/ # GBNF grammar, JSON schema
| +-- weight/ # GGUF parser, MLX SafeTensors, weight store
| +-- whisper/ # Whisper transcription engine
| +-- akunu_api.cpp # C API implementation
+-- backend/
| +-- metal/
| +-- metal_device.mm # Metal GPU backend (Objective-C++)
| +-- kernels/
| +-- ShaderTypes.h # Shared CPU/GPU param structs
| +-- KernelCommon.h # Shared GPU utilities
| +-- MetalKernels.c # Kernel registration
| +-- metal/kernel/ # .metal shader source files
| +-- activation/ # silu.metal, gelu.metal
| +-- attention/ # flash_attention*.metal, softmax.metal
| +-- common/ # residual_add.metal, transpose.metal
| +-- convert/ # dequant_*.metal, f16<->f32
| +-- embedding/ # embedding_lookup_*.metal
| +-- fused/ # gemv_q8_0_head_rmsnorm.metal
| +-- kv_cache/ # kv_cache_write.metal, shift.metal
| +-- matmul/ # gemv_*.metal, simd_gemm_*.metal
| +-- norm/ # rmsnorm.metal, layernorm.metal
| +-- rope/ # rope.metal, rope_neox.metal
| +-- sampling/ # argmax.metal, gumbel_topk.metal
+-- tools/ # CLI executables
| +-- akunu_chat.cpp # Interactive chat
| +-- akunu_bench.cpp # llama-bench style benchmark
| +-- akunu_profile.cpp # Per-kernel GPU profiler
| +-- akunu_serve.cpp # OpenAI-compatible HTTP server
+-- tests/
| +-- test_*.cpp # Unit and integration tests
| +-- kernels/ # Per-kernel GPU correctness tests
| +-- activation/ # test_silu.cpp, test_gelu.cpp, ...
| +-- attention/ # test_flash_attention.cpp
| +-- matmul/ # test_gemv_f16.cpp, test_gemv_q4_0.cpp, ...
| +-- norm/ # test_rmsnorm.cpp, test_gemma_rmsnorm.cpp
| +-- rope/ # test_rope.cpp, test_rope_neox.cpp
| +-- convert/ # test_f32_to_f16.cpp, test_dequant_q4_0.cpp
| +-- embedding/ # test_embedding_f16.cpp
+-- 3rdparty/
+-- xgrammar/ # Grammar-constrained decoding library
There are roughly 120 .metal shader files, 16 kernel tests, 14 unit/integration
tests, and 8 CLI tools. The total C++ codebase is around 15,000 lines, with
another 10,000+ lines of Metal shader code.
Building the Project
Akunu has a two-stage build:
- Metal shaders: .metal source files are compiled to .air (Apple Intermediate Representation), then linked into a single akunu.metallib binary
- C++ engine: CMake builds all source files, links against Metal and Accelerate frameworks, and produces test executables and CLI tools
The Makefile provides convenience targets that run both stages:
Full Build (Shaders + Engine)
make
This is equivalent to make shaders engine. Let us trace what happens:
Step 1: make shaders
================================
For each .metal file in backend/metal/kernels/metal/kernel/**/:
xcrun -sdk macosx metal \
-std=metal3.2 \ <-- auto-detected (3.2 > 3.1 > 3.0)
-I backend/metal/kernels \
-O2 \ <-- optimization level
-fmetal-math-fp32-functions=fast \ <-- Xcode 16+ only
-c backend/metal/kernels/metal/kernel/norm/rmsnorm.metal \
-o build/air/metal/kernel/norm/rmsnorm.air
Then link all .air files into one metallib:
xcrun -sdk macosx metallib build/air/**/*.air -o build/akunu.metallib
Step 2: make engine
================================
mkdir -p build
cd build && cmake .. -DCMAKE_BUILD_TYPE=Release
cd build && make -j$(sysctl -n hw.ncpu)
This produces:
build/akunu_chat # interactive chat tool
build/akunu_bench # benchmark tool
build/akunu_profile # per-kernel profiler
build/akunu_serve # HTTP server
build/akunu_e2e # end-to-end test
build/akunu_test_* # all test executables
build/akunu_kernel_* # per-kernel test executables
build/libakunu_engine.a # static library
Build Time Expectations
On an M4 Pro (12-core CPU):
Stage First Build Rebuild (1 file changed)
----------- ----------- ------------------------
Metal shaders ~15 seconds ~2 seconds (1 .metal file)
C++ engine ~25 seconds ~3 seconds (1 .cpp file)
Total ~40 seconds ~5 seconds
The Metal shader build parallelizes across CPU cores. With 120+ shader files,
the first build compiles them all in parallel. Subsequent rebuilds only
recompile changed .metal files and re-link the metallib.
Shader-Only Build
If you are working on Metal kernels and do not need to rebuild C++:
make shaders
Engine-Only Build
If you changed only C++ code and the metallib already exists:
make engine
Shared Library Build
For language bindings (Python, Swift, etc.):
make shared
# Produces: build/libakunu.dylib
Debug Build
For debugging with Xcode or lldb:
mkdir -p build-debug
cd build-debug
cmake .. -DCMAKE_BUILD_TYPE=Debug
make -j$(sysctl -n hw.ncpu)
Debug builds disable optimizations and enable assert macros. They are significantly slower for inference (3-5x) but essential for stepping through code.
Clean Build
make clean
# Removes the entire build/ directory
Metal Shader Compilation Deep Dive
Understanding the shader build pipeline is important because debugging compilation errors in Metal shaders is different from debugging C++ errors.
.metal source files Shared headers
(120+ files) (ShaderTypes.h, KernelCommon.h)
| |
v v
+------------------------------------------+
| metal compiler (xcrun metal) |
| -std=metal3.2 -O2 -I includes |
| -fmetal-math-fp32-functions=fast |
+------------------------------------------+
|
v
.air files (Apple Intermediate Representation)
One per .metal file, in build/air/
|
v
+------------------------------------------+
| metallib linker (xcrun metallib) |
| Links all .air files into one binary |
+------------------------------------------+
|
v
build/akunu.metallib
(single binary, ~3 MB, loaded at runtime)
The Metal standard version is auto-detected by the Makefile:
METAL_STD := $(shell $(METAL_CC) -std=metal3.2 ... && echo metal3.2 || \
($(METAL_CC) -std=metal3.1 ... && echo metal3.1 || echo metal3.0))
This tries Metal 3.2 first, falls back to 3.1, then 3.0. The
-fmetal-math-fp32-functions=fast flag is similarly auto-detected and only
used when available (Xcode 16+). This allows the same codebase to build on
older Xcode versions without modification.
Common Shader Build Errors
Missing include: If you add a new .metal file that includes a header
not in the include path:
error: 'MyNewHeader.h' file not found
Fix: Add the header to backend/metal/kernels/ or its include/ subdirectory.
Metal standard mismatch: If you use a Metal 3.2 feature on a machine with only Metal 3.1:
error: unknown attribute 'metal3_2_features_only'
Fix: Guard the feature with #if __METAL_VERSION__ >= 320.
Type mismatch with ShaderTypes.h: The parameter structs in ShaderTypes.h
are shared between CPU (C++) and GPU (Metal). If you change a struct, both
sides must agree:
// This comment in ShaderTypes.h says it all:
// CRITICAL: Any change here MUST be mirrored in
// Sources/KernelStore/MetalTypes.swift.
// All structs are padded to 16-byte boundaries
// for Metal argument buffer alignment.
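The padding convention can be enforced at compile time on the C++ side. Here is an illustrative sketch (the struct and its field names are hypothetical, not akunu's actual ShaderTypes.h layout):

```cpp
#include <cstdint>

// Hypothetical params struct illustrating the 16-byte padding convention;
// field names are for illustration only.
struct ExampleParams {
    uint32_t dim;   // 4 bytes of payload
    float    eps;   // 4 bytes of payload
    uint32_t pad0;  // explicit padding...
    uint32_t pad1;  // ...to reach a 16-byte boundary
};

// A size mismatch between CPU and GPU would silently corrupt kernel
// parameter reads, so fail the build instead.
static_assert(sizeof(ExampleParams) == 16, "params must be 16-byte padded");
```

A static_assert like this turns a subtle runtime corruption bug into an immediate compile error when someone adds a field without adjusting the padding.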
Running Tests
Akunu has several categories of tests, each with its own make target.
Unit Tests (No Model Required)
make test-unit
This runs tests that do not need a model file:
build/akunu_test_tokenizer_internal # BPE tokenizer internals
build/akunu_test_grammar # GBNF grammar parsing
build/akunu_test_server # HTTP server logic
build/akunu_test_whisper # Whisper format parsing
These tests are fast (< 1 second total) and should always pass on a clean build.
Inference Tests (Model Required)
make test-infer MODEL=models/Qwen3-0.6B-Q4_0.gguf
This runs tests that need a real model file:
build/akunu_e2e <model> "The capital of France is" 0 10
build/akunu_test_inference <model>
build/akunu_test_tokenizer <model>
You need to download a model first. The smallest model that exercises all code paths is Qwen3-0.6B in Q4_0 quantization (~400 MB).
Kernel Tests (No Model Required)
Kernel tests verify GPU correctness by comparing Metal shader output against CPU reference implementations. They need the metallib but not a model:
# Run individual kernel tests
build/akunu_kernel_test_rmsnorm
build/akunu_kernel_test_gemv_f16
build/akunu_kernel_test_flash_attention
# The full list of 16 kernel tests:
build/akunu_kernel_test_rmsnorm
build/akunu_kernel_test_gemma_rmsnorm
build/akunu_kernel_test_gemv_f16
build/akunu_kernel_test_gemv_q4_0
build/akunu_kernel_test_gemv_q8_0
build/akunu_kernel_test_gemm_f16
build/akunu_kernel_test_silu
build/akunu_kernel_test_gelu
build/akunu_kernel_test_silu_gate
build/akunu_kernel_test_gelu_gate
build/akunu_kernel_test_rope
build/akunu_kernel_test_rope_neox
build/akunu_kernel_test_flash_attention
build/akunu_kernel_test_embedding_f16
build/akunu_kernel_test_f32_to_f16
build/akunu_kernel_test_dequant_q4_0
Each kernel test creates a MetalDevice, loads the metallib, generates
deterministic test data, runs the GPU kernel, and compares against a CPU
reference. See Chapter 52 for a detailed walkthrough of the testing
infrastructure.
Running All Tests
# Unit + inference
make test MODEL=models/Qwen3-0.6B-Q4_0.gguf
Benchmark
make bench MODEL=models/Qwen3-0.6B-Q4_0.gguf
This runs akunu_bench with 512-token prefill and 128-token generation,
repeated 3 times, reporting tokens/second in llama-bench format.
IDE Setup
Xcode
Xcode is the best IDE for akunu development because it has native Metal shader support, GPU debugging, and frame capture.
Generating an Xcode project from CMake:
mkdir -p build-xcode
cd build-xcode
cmake .. -G Xcode
open akunu.xcodeproj
This creates an Xcode project with all targets (library, tests, tools).
However, it does not handle the Metal shader build – you still need
make shaders from the command line.
Xcode scheme setup:
- Select the akunu_chat scheme for interactive testing
- Edit the scheme: Run > Arguments > add model path as first argument
- Edit the scheme: Run > Options > set Working Directory to project root
- Build and run with Cmd+R
Metal shader editing in Xcode:
Xcode provides syntax highlighting and basic error checking for .metal
files. Open any .metal file from the project navigator. The Metal compiler
runs in the background and shows errors inline.
For shader editing, you want the include paths configured. In the Xcode project, add these to the Metal compiler settings:
Header Search Paths: $(PROJECT_DIR)/backend/metal/kernels
Visual Studio Code
VS Code with the right extensions provides a solid alternative:
# Install recommended extensions
code --install-extension ms-vscode.cpptools
code --install-extension ms-vscode.cmake-tools
code --install-extension nickmass.metal-shader
Create .vscode/settings.json:
{
"cmake.buildDirectory": "${workspaceFolder}/build",
"cmake.configureArgs": ["-DCMAKE_BUILD_TYPE=Debug"],
"C_Cpp.default.includePath": [
"${workspaceFolder}/include",
"${workspaceFolder}/src",
"${workspaceFolder}/backend"
],
"files.associations": {
"*.metal": "metal"
}
}
Create .vscode/tasks.json for shader builds:
{
"version": "2.0.0",
"tasks": [
{
"label": "Build Shaders",
"type": "shell",
"command": "make shaders",
"group": "build"
},
{
"label": "Build All",
"type": "shell",
"command": "make",
"group": {
"kind": "build",
"isDefault": true
}
}
]
}
CLion
CLion has excellent CMake integration. Open the project root directory and CLion will auto-detect the CMakeLists.txt. Add a custom build step for shaders:
- Settings > Build, Execution, Deployment > CMake > add a “Before launch” step that runs make shaders
- Or configure an External Tool for the shader build
Metal Debugger and GPU Profiling
Xcode GPU Frame Capture
The most powerful tool for debugging Metal shaders is Xcode’s GPU Frame Capture:
1. Set the METAL_DEVICE_WRAPPER_TYPE environment variable: export METAL_DEVICE_WRAPPER_TYPE=1
2. Run your akunu executable under Xcode
3. Click the camera icon in the debug bar to capture a GPU frame
4. Xcode shows every command buffer, compute encoder, and dispatch
This lets you inspect:
- Buffer contents at any point in the pipeline
- Shader execution time per dispatch
- Thread occupancy and register pressure
- Memory bandwidth utilization
Metal Validation Layer
Enable Metal API validation to catch buffer overflows, misaligned access, and other GPU programming errors:
export MTL_DEBUG_LAYER=1
export METAL_DEBUG_ERROR_MODE=assert
With validation enabled, Metal checks every API call and crashes immediately on misuse rather than producing silent corruption. This is essential during development but adds significant overhead – do not use it for benchmarking.
Metal Shader Debugging
For stepping through shader code line-by-line:
- In Xcode, select Debug > Attach to Process > your running akunu executable
- Enable GPU shader debugging: Product > Scheme > Edit Scheme > Run > Diagnostics > GPU Validation > Shader Validation
- Set a breakpoint in a .metal file
- When the breakpoint hits, you can inspect thread variables, buffer contents, and threadgroup memory
This is slow (100x+ overhead) but invaluable for correctness debugging.
Metal System Trace
For system-level GPU analysis:
# Record a 5-second trace
xctrace record --template 'Metal System Trace' \
--output trace.trace \
--time-limit 5s \
--launch build/akunu_bench models/Qwen3-0.6B-Q4_0.gguf -n 32 -r 1
Open the trace in Instruments to see:
- GPU timeline (which kernels ran when)
- CPU-GPU synchronization points
- Memory allocation patterns
- Command buffer scheduling
The akunu_profile Tool
Akunu includes its own per-kernel profiling tool that does not require Xcode:
build/akunu_profile models/Qwen3-0.6B-Q4_0.gguf --tokens 5
This runs each dispatch command in its own command buffer (rather than the normal batched execution) and reports per-kernel GPU time. The output shows exactly which kernels dominate the forward pass:
Decode Summary (5 tokens)
========================================
embedding 0.012 ms 0.8%
attention_norm 0.008 ms 0.5%
qkv_gemv 0.142 ms 9.1%
rope_kv_write 0.015 ms 1.0%
flash_attention 0.098 ms 6.3%
output_gemv 0.047 ms 3.0%
ffn_norm 0.008 ms 0.5%
gate_gemv 0.142 ms 9.1%
up_gemv 0.142 ms 9.1%
silu_gate 0.012 ms 0.8%
down_gemv 0.142 ms 9.1%
... (per layer)
logit_projection 0.350 ms 22.5%
argmax 0.003 ms 0.2%
========================================
Total per token: 1.56 ms
Throughput: 641 t/s (single-token decode)
See Chapter 55 for a complete guide to profiling and benchmarking.
Quick Development Workflow
Here is the workflow most contributors use:
+------------------+ +------------------+ +------------------+
| Edit code | | Build | | Test |
| (.metal or .cpp)|---->| make |---->| kernel test |
| | | (~5s rebuild) | | or e2e test |
+------------------+ +------------------+ +------------------+
^ |
| |
+---------------------------------------------------+
Fix and iterate
For Metal kernel work:
# 1. Edit your shader
vim backend/metal/kernels/metal/kernel/norm/rmsnorm.metal
# 2. Rebuild shaders only (~2s)
make shaders
# 3. Rebuild the test (~3s)
make engine
# 4. Run the specific kernel test
build/akunu_kernel_test_rmsnorm
For C++ engine work:
# 1. Edit your source
vim src/core/table_builder.cpp
# 2. Rebuild engine only (~3s)
make engine
# 3. Run relevant test
build/akunu_e2e models/Qwen3-0.6B-Q4_0.gguf "Hello" 0 10
For both (new kernel end-to-end):
# 1. Write the .metal file
# 2. Add params to ShaderTypes.h
# 3. Wire up in table_builder.cpp
# 4. Write the kernel test
# 5. Full rebuild + test
make && build/akunu_kernel_test_your_new_kernel
Downloading Test Models
Several tests and all benchmarks require model files. Here are the recommended test models by size:
Model Size Use Case
----------------------------- ------- ----------------------------
Qwen3-0.6B-Q4_0.gguf ~400 MB Default test model (fast)
Llama-3.2-1B-Instruct-Q4_0 ~700 MB Test LLaMA architecture
Qwen3-4B-Q4_K_M.gguf ~2.5 GB Test larger models
whisper-base-en.bin ~140 MB Test Whisper transcription
Place models in the models/ directory at the project root:
mkdir -p models
# Download from HuggingFace or your preferred source
# Example using huggingface-cli:
huggingface-cli download Qwen/Qwen3-0.6B-GGUF \
--include "Qwen3-0.6B-Q4_0.gguf" \
--local-dir models/
The MODEL variable in the Makefile defaults to models/Qwen3-0.6B-Q4_0.gguf.
You can override it:
make test-infer MODEL=models/Llama-3.2-1B-Instruct-Q4_0.gguf
Troubleshooting
“Metallib not found”
The kernel tests look for the metallib in several relative paths:
bool _ok = dev->load_library("../../.build/metallib/akunu.metallib");
if (!_ok) _ok = dev->load_library(".build/metallib/akunu.metallib");
if (!_ok) _ok = dev->load_library("../../../.build/metallib/akunu.metallib");
If none of these match your working directory, either:
- Run tests from the project root: ./build/akunu_kernel_test_rmsnorm
- Or set the metallib path explicitly (if the API supports it)
The simplest fix is to always run tests from the project root directory.
CMake Cannot Find Metal Framework
CMake Error: Could not find framework Metal
This means Xcode command-line tools are not properly installed:
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
XGrammar Build Fails
If the XGrammar submodule fails to build:
# Ensure submodule is initialized
git submodule update --init --recursive
# If still failing, the build continues without grammar support
# (AKUNU_HAS_XGRAMMAR will be OFF)
Grammar-constrained decoding is optional. The core inference engine works fine without it.
Objective-C++ Compilation Errors
Several files (like metal_device.mm and test files that use Metal) are
compiled as Objective-C++. CMake handles this automatically:
set_source_files_properties(tests/test_device.mm PROPERTIES LANGUAGE OBJCXX)
If you see errors about @interface or NSError, check that the file
extension is .mm (not .cpp) or that CMake has the LANGUAGE OBJCXX
property set.
Summary
Let us recap the essential commands:
# First-time setup
git clone --recursive https://github.com/prabod/akunu.git
cd akunu
# Full build
make
# Quick iterations
make shaders # Metal only
make engine # C++ only
# Tests
make test-unit # No model needed
make test-infer MODEL=models/Qwen3-0.6B-Q4_0.gguf # Needs model
build/akunu_kernel_test_rmsnorm # Single kernel test
# Benchmarks
make bench MODEL=models/Qwen3-0.6B-Q4_0.gguf
build/akunu_profile models/Qwen3-0.6B-Q4_0.gguf --tokens 5
# Debug
mkdir build-debug && cd build-debug
cmake .. -DCMAKE_BUILD_TYPE=Debug && make -j
export MTL_DEBUG_LAYER=1 # Metal validation
Your development environment is now ready. In the next chapter, we will dive deep into akunu’s testing infrastructure – the CPU reference implementations, the kernel test pattern, and how to write tests for new functionality.
Testing Infrastructure
If there is one rule that separates a hobby GPU project from production-quality inference code, it is this: every kernel must be tested against a CPU reference implementation. GPUs are excellent at hiding bugs. A GEMV kernel with a subtle off-by-one error in its dequantization logic might produce plausible-looking text 95% of the time and garbled output the other 5%. Without rigorous testing, you will spend weeks chasing “random” quality regressions that are actually deterministic bugs in your shader code.
Akunu takes testing seriously. This chapter walks you through the entire testing infrastructure: the test directory structure, the kernel test pattern, the CPU reference implementations, the tolerance model, and how to write new tests.
Test Directory Structure
Let us start with a map of every test file in the repository:
tests/
+-- test_kernel_helpers.h # Shared CPU references + GPU helpers
+-- test_device.mm # MetalDevice lifecycle + basic ops
+-- test_weight_store.cpp # GGUF/MLX weight loading
+-- test_table_builder.cpp # Dispatch table construction
+-- test_config.cpp # ChipConfig + device defaults
+-- test_tokenizer.cpp # BPE tokenizer (needs model)
+-- test_tokenizer_internal.cpp # Tokenizer internals (no model)
+-- test_inference.cpp # Forward pass correctness (needs model)
+-- test_kv_cache.cpp # KV cache write/shift/advance
+-- test_grammar.cpp # GBNF parsing + bitmask generation
+-- test_sampling_quality.cpp # Statistical sampling tests
+-- test_long_context.cpp # Context window stress test
+-- test_server.cpp # HTTP server parsing
+-- test_e2e.cpp # End-to-end generation test
+-- test_whisper.cpp # Whisper format parsing
+-- test_whisper_e2e.cpp # Whisper transcription test
+-- kernels/ # Per-kernel GPU correctness tests
+-- activation/
| +-- test_silu.cpp
| +-- test_gelu.cpp
| +-- test_silu_gate.cpp
| +-- test_gelu_gate.cpp
+-- attention/
| +-- test_flash_attention.cpp
+-- convert/
| +-- test_f32_to_f16.cpp
| +-- test_dequant_q4_0.cpp
+-- embedding/
| +-- test_embedding_f16.cpp
+-- matmul/
| +-- test_gemv_f16.cpp
| +-- test_gemv_q4_0.cpp
| +-- test_gemv_q8_0.cpp
| +-- test_gemm_f16.cpp
+-- norm/
| +-- test_rmsnorm.cpp
| +-- test_gemma_rmsnorm.cpp
+-- rope/
+-- test_rope.cpp
+-- test_rope_neox.cpp
There are three tiers of tests:
Tier 1: Kernel Tests (16 tests)
================================
- Test individual GPU kernels against CPU references
- Need only the metallib, no model file
- Fastest to run (< 1s each)
- This is where you start when adding a new kernel
Tier 2: Unit Tests (8 tests)
================================
- Test CPU-side logic (tokenizer, grammar, config, etc.)
- No GPU or model required for most
- Run via "make test-unit"
Tier 3: Integration / E2E Tests (6 tests)
================================
- Test full inference pipeline with real models
- Need a model file (GGUF or MLX SafeTensors)
- Slower to run (seconds to minutes)
- Run via "make test-infer"
The Kernel Test Pattern
Every kernel test in akunu follows the same pattern. Let us study it by
examining the actual test_rmsnorm.cpp from the repository:
Kernel Test Structure
=====================
1. Initialize GPU device + load metallib
2. Get pipeline state object for the kernel
3. For each test case:
a. Generate deterministic float data on CPU
b. Compute expected output using CPU reference
c. Upload input data to GPU buffers (as F16)
d. Set up shader parameters
e. Encode + dispatch the kernel
f. Synchronously wait for GPU completion
g. Read back GPU output (F16 -> F32)
h. Compare GPU output vs CPU reference within tolerance
4. Report pass/fail counts
Here is the actual code for the RMSNorm kernel test, annotated:
// Step 1: Device initialization (from test_kernel_helpers.h macro)
INIT_DEVICE();
// This creates a MetalDevice, loads akunu.metallib, and prints GPU name
// Step 2: Get the pipeline
Pipeline pso = dev->get_pipeline("rmsnorm_f16");
// Step 3a: Sweep across multiple dimensions
int dims[] = {64, 128, 256, 512, 896, 1024, 2048, 2560, 4096};
for (int dim : dims) {
int n_rows = 4;
// Step 3b: Deterministic test data
auto x_data = det_floats(n_rows * dim, 42 + dim, -2.0f, 2.0f);
auto w_data = det_floats(dim, 99 + dim, 0.5f, 1.5f);
// Step 3c: CPU reference computation
auto expected = cpu_rmsnorm(x_data, w_data, dim, eps);
// Step 3d: Upload to GPU as F16
Buffer x_buf = make_buf_f16(*dev, x_data);
Buffer w_buf = make_buf_f16(*dev, w_data);
Buffer o_buf = dev->allocate(n_rows * dim * 2);
// Step 3e: Set up params and dispatch
RMSNormParams params = {(uint32_t)dim, eps, 0, 0};
dev->begin_encoding();
dev->set_pipeline(pso);
dev->set_buffer(x_buf, 0, 0);
dev->set_buffer(w_buf, 0, 1);
dev->set_buffer(o_buf, 0, 2);
dev->set_bytes(&params, sizeof(params), 3);
dev->dispatch(Dim3(n_rows), Dim3(std::min(1024, dim)));
// Step 3f: Wait for GPU
dev->end_encoding_sync();
// Step 3g-h: Read back and compare
CHECK(assert_close(
read_f16(o_buf, n_rows * dim), // GPU output (F16 -> F32)
expected, // CPU reference (F32)
5e-2f, // absolute tolerance
5e-2f, // relative tolerance
3 // max errors to print
), "rmsnorm dim sweep");
}
The pattern is beautifully consistent: every kernel test in the repository follows this exact flow. The only variation is in:
- Which pipeline (kernel name) to use
- What test data to generate
- Which CPU reference function to call
- What tolerance values are appropriate
The Test Helper Library
The file tests/test_kernel_helpers.h is the backbone of the testing
infrastructure. It provides five categories of utilities:
1. Float16 Conversion (CPU-side)
Since Metal kernels operate on F16 data but CPU reference implementations use F32, we need conversion functions:
static inline uint16_t f32_to_f16(float f); // CPU float -> F16 bits
static inline float f16_to_f32(uint16_t h); // F16 bits -> CPU float
These are software implementations of IEEE 754 half-precision conversion. They handle denormals, infinities, and NaN correctly. The GPU does these conversions in hardware, so our CPU-side conversion must match exactly.
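To make the bit-level behavior concrete, here is a minimal software implementation of both directions with round-to-nearest-even. This is a sketch written for illustration, not akunu's actual helper code, but any correct implementation must produce the same bits as the GPU's hardware conversion:

```cpp
#include <cstdint>
#include <cstring>

// F32 -> F16 bits, round-to-nearest-even. Handles Inf, NaN, and subnormals.
static uint16_t f32_to_f16(float f) {
    uint32_t x;
    std::memcpy(&x, &f, 4);
    uint32_t sign = (x >> 16) & 0x8000;
    int32_t  exp  = (int32_t)((x >> 23) & 0xFF) - 127 + 15;   // rebias 127 -> 15
    uint32_t mant = x & 0x7FFFFF;
    if (((x >> 23) & 0xFF) == 0xFF)                 // Inf or NaN
        return sign | 0x7C00 | (mant ? 0x200 : 0);
    if (exp >= 31) return sign | 0x7C00;            // overflow -> Inf
    if (exp <= 0) {                                 // F16 subnormal or zero
        if (exp < -10) return sign;                 // underflow -> signed zero
        mant |= 0x800000;                           // make leading 1 explicit
        uint32_t shift = (uint32_t)(14 - exp);
        uint16_t half  = mant >> shift;
        uint32_t rem   = mant & ((1u << shift) - 1);
        uint32_t mid   = 1u << (shift - 1);
        if (rem > mid || (rem == mid && (half & 1))) half++;      // ties to even
        return sign | half;
    }
    uint16_t half = sign | ((uint32_t)exp << 10) | (mant >> 13);
    uint32_t rem  = mant & 0x1FFF;
    if (rem > 0x1000 || (rem == 0x1000 && (half & 1))) half++;    // ties to even
    return half;
}

// F16 bits -> F32. Exact: every F16 value is representable in F32.
static float f16_to_f32(uint16_t h) {
    uint32_t sign = (uint32_t)(h & 0x8000) << 16;
    uint32_t exp  = (h >> 10) & 0x1F;
    uint32_t mant = h & 0x3FF;
    uint32_t out;
    if (exp == 0) {
        if (mant == 0) {
            out = sign;                             // signed zero
        } else {                                    // subnormal: renormalize
            int e = -1;
            do { e++; mant <<= 1; } while (!(mant & 0x400));
            out = sign | ((uint32_t)(127 - 15 - e) << 23) | ((mant & 0x3FF) << 13);
        }
    } else if (exp == 31) {
        out = sign | 0x7F800000 | (mant << 13);     // Inf or NaN
    } else {
        out = sign | ((exp - 15 + 127) << 23) | (mant << 13);
    }
    float f;
    std::memcpy(&f, &out, 4);
    return f;
}
```

Note that F16 has only 11 bits of significand, so a value like 3.14159 round-trips to roughly 3.1406 -- this rounding error is the baseline that test tolerances must absorb.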
2. Deterministic Random Data
static std::vector<float> det_floats(int count, uint64_t seed = 42,
float lo = -1.0f, float hi = 1.0f);
This generates pseudo-random floats using a linear congruential generator (LCG) with a fixed seed. The key property is determinism: the same seed always produces the same sequence, regardless of platform or compiler. This makes tests reproducible.
The LCG uses the Knuth constants:
s = s * 6364136223846793005 + 1442695040888963407
This is intentionally NOT a cryptographically strong PRNG. We want speed and reproducibility, not security.
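A sketch of such a generator, using the Knuth constants quoted above (the exact mapping from LCG state to the [lo, hi) range is an assumption, not necessarily akunu's):

```cpp
#include <cstdint>
#include <vector>

// Deterministic pseudo-random floats in [lo, hi) from a fixed-seed LCG.
// Same seed -> same sequence on every platform and compiler.
static std::vector<float> det_floats(int count, uint64_t seed = 42,
                                     float lo = -1.0f, float hi = 1.0f) {
    std::vector<float> out(count);
    uint64_t s = seed;
    for (int i = 0; i < count; i++) {
        s = s * 6364136223846793005ULL + 1442695040888963407ULL;  // Knuth LCG
        // use the high 24 bits of the state for a uniform value in [0, 1)
        float u = (float)(s >> 40) / 16777216.0f;
        out[i] = lo + u * (hi - lo);
    }
    return out;
}
```

Using the high bits matters: the low bits of an LCG have short periods and poor statistical quality.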
3. GPU Buffer Helpers
static Buffer make_buf_f16(Device& dev, const std::vector<float>& data);
static Buffer make_buf_f32(Device& dev, const std::vector<float>& data);
static Buffer make_buf_u32(Device& dev, const std::vector<uint32_t>& data);
static std::vector<float> read_f16(Buffer buf, int count);
static std::vector<float> read_f32(Buffer buf, int count);
These handle the F32-to-F16 conversion on upload and F16-to-F32 conversion on
readback. The make_buf_f16 function converts each float to F16 on the CPU
side, then uploads the F16 data to a GPU buffer. The read_f16 function reads
F16 data from a GPU buffer and converts each value back to F32.
This is the data flow:
CPU F32 data GPU F16 buffer GPU F16 output
[1.0, 2.0, ...] --> make_buf_f16() --> [kernel execution]
|
CPU F32 result <-- read_f16() <-- GPU F16 output buffer
[0.99, 1.98, ...]
Note: F16 conversion introduces quantization error.
That is why we need tolerances in our comparisons.
4. Device Init and Result Macros
#define INIT_DEVICE() \
auto dev = Device::create_default(); \
printf("GPU: %s\n", dev->name()); \
{ \
bool _ok = dev->load_library("../../.build/metallib/akunu.metallib"); \
if (!_ok) _ok = dev->load_library(".build/metallib/akunu.metallib"); \
if (!_ok) _ok = dev->load_library("../../../.build/metallib/akunu.metallib"); \
if (!_ok) { printf("Metallib not found\n"); return 1; } \
}
#define PRINT_RESULTS() \
printf("\n=== Results: %d passed, %d failed ===\n", \
g_tests_passed, g_tests_failed); \
return g_tests_failed > 0 ? 1 : 0
The INIT_DEVICE() macro tries three paths for the metallib because tests can
be run from different working directories. The PRINT_RESULTS() macro returns
a non-zero exit code on failure, which lets CI systems detect test failures.
5. Tolerance Comparison
static bool assert_close(const std::vector<float>& actual,
const std::vector<float>& expected,
float atol, float rtol,
int max_errors = 5,
const char *label = "");
This is the heart of the comparison system. For each element pair, it checks:
|actual[i] - expected[i]| <= max(atol, |expected[i]| * rtol)
The tolerance model uses both absolute tolerance (atol) and relative
tolerance (rtol), taking the maximum. This handles both near-zero values
(where relative tolerance would be too strict) and large values (where
absolute tolerance would be too loose).
Tolerance Model
===============
For expected value 0.001:
atol = 0.05 --> allowed error = 0.05 (absolute dominates)
rtol = 0.05 --> allowed error = 0.00005 (too strict)
max(0.05, 0.00005) = 0.05 <-- we use this
For expected value 100.0:
atol = 0.05 --> allowed error = 0.05 (too strict)
rtol = 0.05 --> allowed error = 5.0 (relative dominates)
max(0.05, 5.0) = 5.0 <-- we use this
When an error is detected, the function prints up to max_errors mismatches
with their indices, actual values, expected values, and the computed tolerance.
This makes debugging much easier than a simple pass/fail.
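The core of the comparison can be sketched in a few lines (this is a minimal version of the check; akunu's assert_close additionally prints mismatching indices and values):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Mixed absolute/relative tolerance check:
//   |actual - expected| <= max(atol, |expected| * rtol)
static bool close_enough(const std::vector<float>& actual,
                         const std::vector<float>& expected,
                         float atol, float rtol) {
    for (std::size_t i = 0; i < actual.size(); i++) {
        float tol = std::max(atol, std::fabs(expected[i]) * rtol);
        if (std::fabs(actual[i] - expected[i]) > tol) return false;
    }
    return true;
}
```

With atol = rtol = 0.05, a value near zero may be off by 0.05 absolutely, while a value of 100 may be off by 5.0 -- exactly the two cases in the diagram above.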
CPU Reference Implementations
The test helper library includes CPU reference implementations for every kernel category. These are intentionally simple – no SIMD, no optimization, just readable code that is obviously correct.
RMSNorm Reference
static std::vector<float> cpu_rmsnorm(
const std::vector<float>& x,
const std::vector<float>& w,
int dim, float eps)
{
int rows = (int)x.size() / dim;
std::vector<float> out(x.size());
for (int r = 0; r < rows; r++) {
float ss = 0;
for (int i = 0; i < dim; i++)
ss += x[r*dim + i] * x[r*dim + i];
float rms = 1.0f / sqrtf(ss / dim + eps);
for (int i = 0; i < dim; i++)
out[r*dim + i] = x[r*dim + i] * rms * w[i];
}
return out;
}
Compare this with the Metal kernel that runs on the GPU:
kernel void rmsnorm_f16(
device const half *input, device const half *weight,
    device half *output, constant RMSNormParams &params, ...)
{
// ... threadgroup reduction for sum of squares ...
float rms = rsqrt(total_sum_sq / float(dim) + eps);
for (uint i = tid; i < dim; i += tg_size) {
row_out[i] = half(float(row_in[i]) * rms) * weight[i];
}
}
The CPU version operates on F32 floats. The GPU version operates on F16 halfs, converting to F32 only for the accumulation. This precision difference is exactly why we need tolerance in our comparisons.
GEMV Reference
// y[n] = sum_k W[n,k] * x[k]
static std::vector<float> cpu_gemv(
const std::vector<float>& W,
const std::vector<float>& x,
int N, int K)
{
std::vector<float> y(N, 0);
for (int n = 0; n < N; n++)
for (int k = 0; k < K; k++)
y[n] += W[n*K + k] * x[k];
return y;
}
This is the simplest possible matrix-vector multiply. The GPU version uses SIMD-group operations, threadgroup memory, and quantized weight formats. But the mathematical result should be the same (within tolerance).
Attention Reference
static std::vector<float> cpu_attention(
const std::vector<float>& Q,
const std::vector<float>& K,
const std::vector<float>& V,
int seq_len, int kv_seq_len, int head_dim,
int n_heads, int n_kv_heads, float scale, bool causal)
{
std::vector<float> O(n_heads * seq_len * head_dim, 0);
int heads_per_group = n_heads / n_kv_heads;
for (int h = 0; h < n_heads; h++) {
int kv_h = h / heads_per_group; // GQA mapping
for (int s = 0; s < seq_len; s++) {
// QK^T
std::vector<float> scores(kv_seq_len);
for (int t = 0; t < kv_seq_len; t++) {
float dot = 0;
for (int d = 0; d < head_dim; d++)
dot += Q[...] * K[...];
scores[t] = dot * scale;
if (causal && t > s) scores[t] = -1e9f;
}
// Softmax
// ... standard numerically stable softmax ...
// Weighted sum of V
// ... accumulate ...
}
}
return O;
}
This implements the full attention computation: QK^T with scaling, causal
masking, softmax, and V weighting. It handles grouped-query attention (GQA)
through the heads_per_group mapping. The GPU version uses FlashAttention
with online softmax and tiled computation, but produces the same result.
RoPE References
There are two RoPE variants, matching the two GPU kernels:
// Standard (interleaved pairs): [x0,x1, x2,x3, ...]
static void cpu_rope(std::vector<float>& x,
int seq_len, int n_heads, int head_dim,
float theta, int pos_offset = 0);
// NeoX (split-half): [x0,...,x(N/2-1), x(N/2),...,x(N-1)]
static void cpu_rope_neox(std::vector<float>& x,
int seq_len, int n_heads, int head_dim,
float theta, int pos_offset = 0);
The standard variant pairs adjacent elements (x[2d], x[2d+1]). The NeoX
variant pairs elements from the first and second halves (x[d], x[half+d]).
Both apply the same rotation math, just with different indexing.
Quantization Helpers
For testing quantized GEMV kernels, we need to quantize data on the CPU:
struct Q4_0Block {
uint16_t scale_f16; // absmax / 8 in F16
uint8_t nibs[16]; // 32 x 4-bit values (each byte = 2 values)
};
static std::vector<uint8_t> quantize_q4_0(
const std::vector<float>& vals,
std::vector<float>& dequant_ref);
The quantize_q4_0 function returns both the quantized bytes (for GPU upload)
and the dequantized reference values (for comparison). The dequantized values
account for the quantization error introduced by the round-trip, so our
tolerance comparison measures only the GPU computation error, not the
inherent quantization loss.
Quantization Test Flow
======================
Original F32 values: [0.5, -0.3, 0.8, ...]
|
v
quantize_q4_0() produces:
- Quantized bytes (Q4_0 format) -> upload to GPU
- Dequant reference [0.49, -0.29, 0.79, ...] -> compare against
|
GPU kernel: dequant + GEMV
|
v
GPU output: [result using quantized weights]
|
Compare vs: cpu_gemv(dequant_ref, x)
(GEMV using the SAME dequantized values)
This isolates the GPU computation error from quantization error.
The 16 Kernel Tests
Here is a complete inventory of every kernel test, what it tests, and what CPU reference it uses:
+-------------------------+---------------------+-------------------+--------+
| Test | Kernel(s) | CPU Reference | Tol |
+-------------------------+---------------------+-------------------+--------+
| test_rmsnorm | rmsnorm_f16 | cpu_rmsnorm | 5e-2 |
| test_gemma_rmsnorm | rmsnorm_gemma_f16 | cpu_gemma_rmsnorm | 5e-2 |
| test_gemv_f16 | gemv_f16 | cpu_gemv | 1e-2 |
| test_gemv_q4_0 | gemv_q4_0 | cpu_gemv(dequant) | 5e-2 |
| test_gemv_q8_0 | gemv_q8_0 | cpu_gemv(dequant) | 5e-2 |
| test_gemm_f16 | simd_gemm_f16 | cpu_gemm_bt | 5e-2 |
| test_silu | silu activation | cpu_silu | 1e-3 |
| test_gelu | gelu activation | cpu_gelu | 1e-3 |
| test_silu_gate | silu_gate_f16 | silu(a)*b | 1e-2 |
| test_gelu_gate | gelu_gate_f16 | gelu(a)*b | 1e-2 |
| test_rope | rope_f16 | cpu_rope | 5e-2 |
| test_rope_neox | rope_neox_f16 | cpu_rope_neox | 5e-2 |
| test_flash_attention | flash_attention_* | cpu_attention | 5e-2 |
| test_embedding_f16 | embedding_lookup_f16| direct lookup | 0 |
| test_f32_to_f16 | f32_to_f16 | f32_to_f16 (CPU) | 0 |
| test_dequant_q4_0 | dequant_q4_0 | quantize_q4_0 | 0 |
+-------------------------+---------------------+-------------------+--------+
Notice the tolerance values. Pure data movement kernels (embedding, conversion, dequantization) use zero tolerance – they must produce bit-identical results. Arithmetic kernels (GEMV, GEMM, norms) use tolerances of 1e-2 to 5e-2 to account for F16 precision loss and different accumulation orders.
Why Different Tolerances?
The tolerance differences come from the nature of each operation:
Precision Loss Sources
======================
1. F16 storage: 10-bit mantissa = ~3 decimal digits
- Every F32->F16 conversion loses precision
- Multiply two F16 values: error compounds
2. Accumulation order:
- CPU: sequential sum (deterministic)
- GPU: SIMD reduction + threadgroup reduction (different order)
- Floating-point addition is not associative:
(a + b) + c != a + (b + c) in general
3. Transcendental functions:
- CPU: libm's exp(), sqrt(), etc. (double precision internally)
- GPU: Metal's exp(), rsqrt() (hardware F32, not F64)
- Metal's "fast math" may use approximations
Result: GEMV with K=4096 accumulates 4096 products.
Each product has F16 quantization error. The sum has
accumulated error proportional to sqrt(K) * epsilon_f16.
For K=4096, with epsilon_f16 = 2^-11 ~ 0.0005:
sqrt(4096) * 0.0005 ~ 0.03, which explains
why we use atol=0.05 for GEMV tests.
Test Case Design
Each kernel test exercises multiple configurations to catch edge cases:
Dimension Sweep
===============
Tests run across multiple dimension values:
dims[] = {64, 128, 256, 512, 896, 1024, 2048, 2560, 4096}
Why these values?
64 = minimum useful size
128 = one SIMD group width * 4
256 = common small-model dim
512 = common head_dim * n_heads for small models
896 = Qwen3-0.6B's dim (non-power-of-2!)
1024 = threadgroup max, common dim
2048 = common mid-size dim
2560 = another non-power-of-2 (tests stride logic)
4096 = large model dim (LLaMA 7B/8B)
Non-power-of-2 values (896, 2560) are crucial because they
exercise the remainder handling in strided loops:
for (uint i = tid; i < dim; i += tg_size)
If dim is not a multiple of tg_size, the last iteration
processes fewer elements. Bugs here are common.
Each kernel test also includes edge case tests:
- Zero input: Ensures no NaN propagation (especially for norms where division by near-zero can produce infinity/NaN)
- Constant input: Tests that reduction operations work correctly when all values are identical (catches bugs where partial SIMD lanes have garbage)
- Large values: Tests for overflow in F16 (max representable: ~65504)
- Multi-row: Tests that row indexing is correct (the kernel processes multiple independent rows in one dispatch)
Writing a New Kernel Test
Let us walk through writing a test for a hypothetical new kernel. Suppose you
have added a vector_scale_f16 kernel that multiplies every element by a
scalar:
Step 1: Create the Test File
tests/kernels/common/test_vector_scale.cpp
Step 2: Write the Test
/// GPU kernel test: vector_scale_f16
#include "test_kernel_helpers.h"
#include "metal/kernels/ShaderTypes.h"
// CPU reference: trivially correct
static std::vector<float> cpu_vector_scale(
const std::vector<float>& x, float scale)
{
std::vector<float> out(x.size());
for (size_t i = 0; i < x.size(); i++)
out[i] = x[i] * scale;
return out;
}
int main() {
printf("=== Kernel Test: Vector Scale ===\n\n");
INIT_DEVICE();
Pipeline pso = dev->get_pipeline("vector_scale_f16");
// Dimension sweep
int counts[] = {32, 64, 256, 1000, 1024, 4096, 10000};
float scales[] = {0.5f, 1.0f, 2.0f, -1.0f, 0.001f};
for (int count : counts) {
for (float scale : scales) {
auto x_data = det_floats(count, 42 + count);
auto expected = cpu_vector_scale(x_data, scale);
Buffer x_buf = make_buf_f16(*dev, x_data);
Buffer o_buf = dev->allocate(count * 2);
// Assume params struct: { uint32_t count; float scale; ... }
struct { uint32_t count; float scale; uint32_t _p0, _p1; }
params = {(uint32_t)count, scale, 0, 0};
dev->begin_encoding();
dev->set_pipeline(pso);
dev->set_buffer(x_buf, 0, 0);
dev->set_buffer(o_buf, 0, 1);
dev->set_bytes(&params, sizeof(params), 2);
dev->dispatch_threads(Dim3(count), Dim3(256));
dev->end_encoding_sync();
char label[64];
snprintf(label, sizeof(label),
"count=%d, scale=%.3f", count, scale);
CHECK(assert_close(
read_f16(o_buf, count), expected,
1e-3f, 1e-3f, 3, label
), label);
dev->free_buffer(x_buf);
dev->free_buffer(o_buf);
}
}
printf(" dimension x scale sweep done\n");
// Edge case: empty (count=0)
// Edge case: single element
// Edge case: scale=0 (should produce all zeros)
PRINT_RESULTS();
}
Step 3: Add to CMakeLists.txt
Add the test to the KERNEL_TESTS list:
set(KERNEL_TESTS
# ... existing tests ...
tests/kernels/common/test_vector_scale.cpp
)
The foreach loop in CMakeLists.txt automatically creates the executable:
foreach(test_src ${KERNEL_TESTS})
get_filename_component(test_name ${test_src} NAME_WE)
set(target_name "akunu_kernel_${test_name}")
add_executable(${target_name} ${test_src})
target_link_libraries(${target_name} PRIVATE akunu_engine)
target_include_directories(${target_name} PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/tests)
set_source_files_properties(${test_src} PROPERTIES LANGUAGE OBJCXX)
endforeach()
This produces build/akunu_kernel_test_vector_scale.
Step 4: Build and Run
make && build/akunu_kernel_test_vector_scale
Expected output:
=== Kernel Test: Vector Scale ===
GPU: Apple M4 Pro
dimension x scale sweep done
=== Results: 35 passed, 0 failed ===
Unit Tests (Non-Kernel)
The non-kernel tests verify CPU-side logic. They follow a simpler pattern since there is no GPU involvement.
test_tokenizer_internal.cpp
Tests BPE tokenizer internals without needing a model file. Validates byte-pair encoding, special token handling, and edge cases like empty strings and UTF-8 multi-byte characters.
test_grammar.cpp
Tests GBNF grammar parsing and bitmask generation. Creates a grammar from a GBNF string, feeds it token sequences, and verifies that the acceptance bitmask correctly allows/rejects tokens at each step.
test_table_builder.cpp
Tests dispatch table construction. Creates a mock model configuration and
verifies that build_dispatch_table produces the expected sequence of commands
with correct parameters, buffer bindings, and patch types.
test_kv_cache.cpp
Tests KV cache operations:
- Write: new key/value vectors are written to correct positions
- Advance: position counter increments correctly
- Shift: sliding window eviction moves data correctly
- Edge cases: write at max_seq_len, shift by more than current length
test_server.cpp
Tests HTTP request parsing and OpenAI API compatibility. No network involved – it feeds raw HTTP strings to the parser and checks the extracted fields.
Integration Tests
test_e2e.cpp
The end-to-end test loads a real model, tokenizes a prompt, runs prefill, generates tokens, and verifies that:
- The model loads without error
- Tokenization produces expected token count
- Prefill returns a valid first token
- Generation produces coherent text (checked by greedy decoding a known prompt)
- EOS token stops generation
test_inference.cpp
More detailed inference tests:
- Forward pass produces non-NaN outputs
- Greedy decoding is deterministic (same prompt always produces same output)
- KV cache reuse works (continuing generation from a saved position)
- Different model architectures produce valid output
test_long_context.cpp
Stress-tests the context window by filling the KV cache to capacity and verifying that:
- The model does not crash or produce NaN
- KV cache shift correctly evicts old entries
- Output quality does not degrade catastrophically after shift
Running Tests in CI
For continuous integration, the recommended test matrix is:
CI Pipeline
===========
Stage 1: Build (parallel)
+------------------------+
| make shaders | <-- Compile Metal shaders
| make engine | <-- Compile C++ code
+------------------------+
Stage 2: Fast Tests (no model, < 10s)
+------------------------+
| make test-unit | <-- Tokenizer, grammar, server
| kernel tests (all 16) | <-- GPU correctness
+------------------------+
Stage 3: Inference Tests (needs model, < 60s)
+------------------------+
| make test-infer | <-- E2E generation test
+------------------------+
Kernel tests require Apple Silicon hardware with Metal support. They cannot run
on Intel Macs or Linux CI runners. For CI, you need a macOS runner with Apple
Silicon (GitHub Actions has macos-14 runners with M1).
Test Philosophy
A few principles that guide akunu’s testing approach:
1. Every GPU kernel has a CPU twin. No exceptions. If the math is complicated enough to run on a GPU, it is complicated enough to verify on a CPU.
2. Deterministic test data. All test data comes from det_floats() with
fixed seeds. No rand(), no time(NULL). Tests must be reproducible.
3. Tolerance, not equality. Floating-point GPU computation will never match CPU computation bit-for-bit. Define appropriate tolerances and document why.
4. Sweep dimensions. Never test just one dimension. Bugs hide in edge cases: non-power-of-2 sizes, sizes smaller than threadgroup width, sizes larger than threadgroup memory.
5. Test the edges. Zero input (division by zero in norms), maximum values (F16 overflow), single-element (degenerate reduction), empty input (boundary conditions).
6. Fail loudly. When a comparison fails, print the index, actual value, expected value, difference, and tolerance. This turns a “test failed” into a diagnosis.
Summary
Akunu’s testing infrastructure rests on three pillars: CPU reference implementations that are simple enough to be obviously correct, GPU test harnesses that handle F16 conversion and buffer management, and a tolerance model that accounts for the inherent precision differences between CPU and GPU computation.
The 16 kernel tests cover every operation in the forward pass. The unit tests verify CPU-side logic. The integration tests validate the full inference pipeline with real models. Together, they provide confidence that changes to the codebase do not break correctness.
In the next chapter, we will use this testing infrastructure as we walk through adding a completely new Metal kernel from scratch.
Adding a New Metal Kernel
This chapter walks through adding a new Metal kernel to Akunu from scratch. We will build a complete example – a hypothetical vector_scale_f16 kernel that multiplies every element of an F16 buffer by a scalar – and trace every step from the .metal source file to the dispatch table integration and test.
The process has six steps:
1. Write the .metal shader
2. Define parameter structs (shared between CPU and GPU)
3. Add to the metallib build
4. Dispatch from Device
5. Integrate into table_builder (if it is part of inference)
6. Write a test
Let us go through each one.
Step 1: Write the .metal Shader
Metal shader files live in Akunu’s shader directory. Create a new file or add to an existing one. For our example, we will create the kernel inline:
// In an existing .metal file, or a new one
#include <metal_stdlib>
using namespace metal;
/// Multiply every element of an F16 buffer by a scalar.
///
/// Grid: 1D, total threads >= count
/// TG: 256 threads (or whatever fits)
///
/// Buffers:
/// [0] device half *data -- input/output (in-place)
///
/// Params:
/// [1] constant float &scale
/// [2] constant uint &count
kernel void vector_scale_f16(
device half *data [[buffer(0)]],
constant float &scale [[buffer(1)]],
constant uint &count [[buffer(2)]],
uint tid [[thread_position_in_grid]])
{
if (tid >= count) return;
data[tid] = half(float(data[tid]) * scale);
}
Key points for Metal kernel design in Akunu:
Buffer binding convention. Akunu binds buffers by index using device.set_buffer(buf, offset, index) or cmd.add_buffer(buf, offset, index). The index in the Metal signature ([[buffer(N)]]) must match. By convention, input buffers come first, output buffers next, and parameters last.
Parameter passing. Small parameters (< 4KB) are passed via setBytes (which Akunu wraps as device.set_bytes() or cmd.set_params()). For dispatch tables, params are pre-allocated as GPU buffers for zero-copy patching.
Thread indexing. Use [[thread_position_in_grid]] for 1D kernels, or [[threadgroup_position_in_grid]] + [[thread_position_in_threadgroup]] for kernels that need threadgroup-level coordination. Akunu supports both dispatch() (grid in threadgroups) and dispatch_threads() (grid in threads).
Bounds checking. Always check tid >= count at the top. Akunu’s dispatch grid is rounded up to the threadgroup size, so trailing threads may be out of bounds.
Data types. Use half for F16, bfloat for BF16 (M4+ only via metal_bfloat header), float for F32, uint for U32. For quantized formats, you typically read uint or uchar and manually unpack.
Step 2: Define Parameter Structs
If your kernel takes structured parameters (more than a single scalar), define a shared struct. Akunu’s convention is to use C structs that are layout-compatible between CPU and GPU. Metal’s constant buffer binding expects 16-byte aligned access for best performance, so pad structs to 16-byte boundaries:
// Shared between CPU (.cpp) and GPU (.metal)
struct VectorScaleParams {
uint32_t count;
float scale;
uint32_t _pad0; // pad to 16 bytes
uint32_t _pad1;
};
For simple cases like our example, you can skip the struct and pass individual scalars. But for anything with more than 2-3 parameters, a struct keeps things clean and avoids off-by-one buffer index mistakes.
The parameter structs used throughout Akunu follow a consistent pattern:
// GEMM parameters (used by all GEMV/GEMM kernels)
struct {
uint32_t M, N, K, lda, ldb, ldc;
float alpha, beta;
} params; // 32 bytes, naturally aligned
// MLX quantized GEMV parameters
struct {
uint32_t M, N, K, group_size, bits, weight_bytes, _p0, _p1;
} mlx_params; // 32 bytes
// Norm parameters
struct {
uint32_t dim;
float eps;
uint32_t _p0, _p1;
} norm_params; // 16 bytes
Step 3: Add to the Metallib Build
Akunu compiles all .metal files into a single .metallib using xcrun metal and xcrun metallib. The build process is:
*.metal files
|
v
xcrun metal -c -target air64-apple-macos14.0 -o file.air file.metal
|
v (repeat for each .metal file)
|
xcrun metallib -o akunu.metallib *.air
If you created a new .metal file, add it to the build script or Makefile. If you added to an existing file, it will be picked up automatically.
To verify your kernel compiled successfully:
# List all functions in the metallib
xcrun metal-objdump --disassemble-all akunu.metallib 2>&1 | grep "^_"
You should see _vector_scale_f16 in the output.
Step 4: Dispatch from Device
For one-off or encoder-phase dispatches (not in the per-token dispatch table), you call the kernel directly through Device:
void apply_scale(Device& dev, Buffer data, float scale, int count) {
// Get (or create and cache) the pipeline state object
Pipeline pso = dev.get_pipeline("vector_scale_f16");
// Set up the dispatch
dev.set_pipeline(pso);
dev.set_buffer(data, 0, 0); // buffer index 0
dev.set_bytes(&scale, sizeof(float), 1); // buffer index 1
uint32_t n = (uint32_t)count;
dev.set_bytes(&n, sizeof(uint32_t), 2); // buffer index 2
// Dispatch: total threads = count, threadgroup size = 256
dev.dispatch_threads(Dim3(count), Dim3(256));
}
The get_pipeline() call compiles the kernel function from the metallib on first use and caches the resulting MTLComputePipelineState. Subsequent calls with the same kernel name return the cached pipeline instantly.
Function constants. If your kernel needs compile-time specialization (e.g., head dimension, group size), use function constants:
// In Metal:
constant uint HEAD_DIM [[function_constant(0)]];
// In C++:
uint32_t fc_indices[] = {0};
uint32_t fc_values[] = {128}; // HEAD_DIM = 128
Pipeline pso = dev.get_pipeline("my_kernel",
"my_kernel_hd128", // cache key
fc_indices, fc_values, 1);
Function constants create specialized pipeline variants at runtime. Akunu caches each variant by its cache key string.
Step 5: Integrate into the Dispatch Table
If your kernel runs as part of the per-token forward pass, it needs to be added to the dispatch table built by build_dispatch_table() in table_builder.cpp. The dispatch table is built once during model initialization and then replayed for every token.
Using the CmdBuilder helper:
// Inside build_dispatch_table(), at the appropriate point
// in the layer loop:
{
Pipeline pso = device.get_pipeline("vector_scale_f16");
struct {
uint32_t count;
float scale;
uint32_t _p0, _p1;
} params = {(uint32_t)dim, some_scale_value, 0, 0};
CmdBuilder(table, pso, Dim3((dim + 255) / 256), Dim3(256))
.threads() // use dispatch_threads
.buf(scratch.h0, 0) // buffer binding 0
.params(params, 1) // params at buffer index 1
.label("layer.%d.scale", layer)
.emit();
}
Or using the raw DispatchCmd API:
{
Pipeline pso = device.get_pipeline("vector_scale_f16");
DispatchCmd cmd = DispatchCmd::make(pso,
Dim3((dim + 255) / 256), // grid (in threadgroups)
Dim3(256)); // threadgroup size
cmd.use_dispatch_threads = true;
cmd.add_buffer(scratch.h0, 0, 0); // (buffer, offset, index)
struct { uint32_t count; float scale; uint32_t _p0, _p1; }
params = {(uint32_t)dim, scale, 0, 0};
cmd.set_params(&params, sizeof(params), 1);
// Optional: patch a field at runtime (e.g., position)
// cmd.patch_type = DispatchCmd::PATCH_POSITION;
// cmd.patch_offset_1 = offsetof(decltype(params), some_field);
table.commands.push_back(cmd);
table.set_last_label("my_scale");
}
Dispatch table patching. Some parameters change every token (position, KV sequence length, argmax output offset). The dispatch table supports patching specific fields in pre-allocated parameter buffers:
| Patch type | What it patches | Used by |
|---|---|---|
| PATCH_POSITION | Position field in RoPE/PE params | RoPE, positional embedding |
| PATCH_KV_SEQ_LEN | KV sequence length in attention params | Flash attention |
| PATCH_TOKEN_OFFSET | Token index in embedding params | Embedding lookup |
| PATCH_ARGMAX_OUTPUT | Output offset in token chain buffer | Argmax |
If your kernel needs a patched field, set the patch_type and patch_offset_1 to indicate which field in the parameter struct should be updated by the chain decoder.
Step 6: Write a Test
Test your kernel with known inputs and expected outputs. A typical test pattern:
void test_vector_scale() {
Device dev;
dev.init("akunu.metallib");
// Allocate and fill input buffer
const int N = 1024;
std::vector<uint16_t> input_f16(N);
for (int i = 0; i < N; i++) {
__fp16 val = (__fp16)(i * 0.1f);
memcpy(&input_f16[i], &val, 2);
}
Buffer buf = dev.allocate(input_f16.data(), N * 2);
// Dispatch kernel
float scale = 2.0f;
Pipeline pso = dev.get_pipeline("vector_scale_f16");
dev.begin_encoding();
dev.set_pipeline(pso);
dev.set_buffer(buf, 0, 0);
dev.set_bytes(&scale, sizeof(float), 1);
uint32_t count = N;
dev.set_bytes(&count, sizeof(uint32_t), 2);
dev.dispatch_threads(Dim3(N), Dim3(256));
dev.end_encoding_sync();
// Verify results
const __fp16 *result = (const __fp16 *)buf.contents;
for (int i = 0; i < N; i++) {
float expected = i * 0.1f * 2.0f;
float actual = (float)result[i];
assert(fabsf(actual - expected) < 0.01f);
}
dev.free_buffer(buf);
printf("vector_scale_f16: PASSED\n");
}
Key testing tips:
- Always sync before reading. Call dev.end_encoding_sync() to ensure the GPU has finished before reading buf.contents on CPU. UMA makes the memory accessible, but coherence requires the command buffer to complete.
- Test boundary conditions. Test with N not divisible by threadgroup size (e.g., N=1000 with TG=256) to verify your bounds check works.
- Test numerical accuracy. F16 has limited precision (about 3 decimal digits). Use tolerances appropriate for the data type.
- Test with dispatch tables too. If your kernel will run in the dispatch table, test it through the table replay path, not just direct dispatch.
Complete Walkthrough: Adding a Hypothetical Kernel
Let us trace the complete path for a more realistic example: a fused LayerNorm + linear projection kernel for Whisper. This would combine the FFN layernorm and the up-projection GEMV into a single dispatch.
1. Metal Shader
kernel void layernorm_gemv_f16(
device const half *input [[buffer(0)]],
device const half *norm_weight [[buffer(1)]],
device const half *norm_bias [[buffer(2)]],
device const half *proj_weight [[buffer(3)]],
device const half *proj_bias [[buffer(4)]],
device half *output [[buffer(5)]],
constant uint &dim [[buffer(6)]],
constant float &eps [[buffer(7)]],
constant uint &out_dim [[buffer(8)]],
uint2 gid [[threadgroup_position_in_grid]],
uint tid [[thread_position_in_threadgroup]])
{
// Phase 1: LayerNorm (cooperative across threadgroup)
// ... compute mean, variance, normalize in shared memory ...
// Phase 2: GEMV (each threadgroup handles some output rows)
// ... dot product of normalized input with weight rows ...
}
2. Parameter Struct
struct LayerNormGEMVParams {
uint32_t dim; // input dimension
float eps; // norm epsilon
uint32_t out_dim; // output dimension (proj rows)
uint32_t _pad;
};
3. Build
Add the .metal file to the build, recompile metallib.
4. Dispatch Integration
static void w_emit_layernorm_gemv(DispatchTable& tbl, Device& dev,
Buffer in, Buffer norm_w, Buffer norm_b,
Buffer proj_w, Buffer proj_b, Buffer out,
int dim, float eps, int out_dim)
{
Pipeline pso = dev.get_pipeline("layernorm_gemv_f16");
if (!pso.handle) {
// Fallback: separate layernorm + GEMV
w_emit_layernorm(tbl, dev, in, norm_w, norm_b, out, dim, eps);
w_emit_gemv(tbl, dev, out, proj_w, out, 0, out_dim, dim);
return;
}
// ... fused dispatch ...
}
5. Use in Decoder Table
// In build_whisper_decode_table(), replace:
// w_emit_layernorm(...)
// w_emit_gemv_bias_gelu(...)
// With:
// w_emit_layernorm_gemv(...)
// w_emit_gelu(...) // GELU still separate
6. Test
Compare output of the fused kernel against separate layernorm + GEMV on the same input data.
Performance Considerations
Writing a correct kernel is the first challenge. Making it fast is the second. Here are the performance patterns that matter most on Apple Silicon:
Memory Bandwidth is King
For inference-sized problems (M=1 GEMV), the kernel is almost always memory-bandwidth-bound, not compute-bound. The Apple M2 Pro has ~200 GB/s of memory bandwidth and ~7 TFLOPS of F16 compute. For a 4096x4096 GEMV:
Data to read: 4096 * 4096 * 2 bytes = 32 MB (F16 weights)
Compute: 4096 * 4096 * 2 FLOPs = 33.5 MFLOPs
Time at bandwidth limit: 32 MB / 200 GB/s = 0.16 ms
Time at compute limit: 33.5 MFLOPs / 7 TFLOPS = 0.005 ms
Arithmetic intensity: 33.5M FLOPs / 33.5M bytes = ~1 FLOP/byte
The arithmetic intensity of ~1 FLOP/byte is well below the crossover point (~35 FLOP/byte for M2 Pro: 7 TFLOPS / 200 GB/s). This means your kernel’s speed is determined almost entirely by how efficiently it reads memory. Optimize for coalesced reads, minimize bank conflicts, and avoid reading the same data multiple times.
Threadgroup Size Selection
The optimal threadgroup size depends on the kernel’s register and threadgroup memory usage. Some guidelines:
| Kernel type | Recommended TG size | Rationale |
|---|---|---|
| Simple element-wise | 256 | Good occupancy, minimal overhead |
| GEMV (F16) | 128 | 16 rows x 8 threads per row |
| GEMV (quantized) | 256 | More ALU work for dequant, need threads |
| Reduction (norm) | min(dim, 1024) | One TG per row, use all SIMD groups |
| Attention | 32 (1 simdgroup) | SIMD-level matmul, shared memory |
Simdgroup Operations
Apple Silicon GPUs have 32-wide SIMD groups. Use simd_sum(), simd_shuffle(), and simdgroup matrix multiply (simdgroup_matrix<half, 8, 8>) when possible. These operations are dramatically faster than equivalent threadgroup-memory-based reductions because they operate within the register file.
Avoiding Threadgroup Memory Bottlenecks
Threadgroup memory on Apple Silicon is carved from the same SRAM as the L1 cache. Using too much threadgroup memory reduces cache capacity and can hurt performance. For simple kernels, prefer passing data through registers (via SIMD shuffles) over threadgroup memory.
Debugging Metal Kernels
When things go wrong (and they will), here are the tools:
Metal GPU Capture. Xcode’s GPU debugger lets you capture a frame of GPU work and inspect every buffer, every dispatch, and every thread’s execution. This is the single most useful debugging tool for Metal kernels.
Printf from Metal. Metal supports printf in shaders (with performance caveats). Add #include <metal_stdlib> and use printf("tid=%u val=%f\n", tid, float(data[tid])). Output appears in Xcode’s console. Be aware that printf from GPU threads is non-deterministic in ordering and can significantly slow down execution. Use it for debugging specific threads, not for bulk output.
Validation layers. Enable Metal validation (MTL_DEBUG_LAYER=1 environment variable) to catch out-of-bounds buffer access, uninitialized reads, and other errors. You can also set MTL_SHADER_VALIDATION=1 for even stricter checks, though this has a larger performance impact.
Numerical debugging. When you get wrong numerical results, dump the intermediate buffers to files and compare against a Python reference implementation. Many quantization bugs are off-by-one in bit extraction or wrong byte ordering. A useful technique is to write a Python script that reads the same weight file and computes the expected output for a known input vector.
Dispatch geometry. A very common bug is getting the grid or threadgroup dimensions wrong. If your kernel silently produces zeros or garbage, check:
- Is the grid large enough to cover all elements?
- Is the threadgroup size compatible with the kernel’s requirements?
- Are you using dispatch() (grid = number of threadgroups) vs dispatch_threads() (grid = total threads)?
- Does your kernel’s [[threads_per_threadgroup]] attribute match the TG size you are dispatching?
Common pitfalls checklist:
| Symptom | Likely cause |
|---|---|
| All zeros output | Wrong buffer binding index, or grid too small |
| NaN/Inf output | Division by zero in normalization, or uninitialized buffer |
| Correct for small N, wrong for large N | Overflow in index calculation (use uint64_t for offsets) |
| Non-deterministic results | Race condition in threadgroup memory, or missing barrier |
| Slightly wrong values | F16 precision loss, or wrong quantization offset (e.g., -8 vs -16) |
| Kernel not found | Typo in kernel name, or .metal file not in metallib build |
Summary
Adding a kernel to Akunu follows a predictable path:
.metal file -> metallib build -> Device::get_pipeline()
|
+-------+-------+
| |
Direct dispatch DispatchCmd in table
(encoder, init) (per-token decode)
| |
+-------+-------+
|
Test
The data-driven design means the kernel itself is self-contained – it does not need to know about architectures, quantization formats, or model configs. Those concerns are handled by the dispatch table builder and the dtype descriptor table. Your kernel just needs to do one thing correctly: take input buffers, apply a computation, and write output buffers.
Adding a New Model Architecture
One of Akunu’s design goals is that adding support for a new transformer architecture should not require touching the core inference loop, the dispatch table builder, or any Metal kernel. Instead, you define a data descriptor that captures the architecture’s unique properties, and the existing machinery handles the rest. This chapter walks through exactly how to do that.1
The Data-Driven Approach
In many inference engines, adding a new architecture means writing a new forward pass function full of if (arch == "llama") ... else if (arch == "qwen") ... branches. Akunu takes a different approach: the ArchDescriptor struct encodes every architecture-specific decision as a data field, and build_dispatch_table() reads from the descriptor without ever branching on architecture name.
Here is the ArchDescriptor:
struct ArchDescriptor {
// Activation
const char *activation_kernel; // "silu_gate_f16", "gelu_gate_f16"
// Embedding
float embedding_scale; // 0 = no scaling
// Per-head norms
bool has_qk_norm; // Q/K head-level RMSNorm
// Post-norms (Gemma-style)
bool has_post_attn_norm;
bool has_post_ffn_norm;
const char *post_attn_norm_key; // weight name suffix
const char *post_ffn_norm_key;
// MLX quantization
int quant_bits;
int quant_group_size;
// RoPE
const char *rope_kernel; // fused kernel name
const char *rope_standalone; // standalone kernel name
Buffer rope_freqs; // precomputed frequencies
// Output
bool tie_embeddings; // reuse embedding for logit projection
// Encoder-decoder properties
bool is_encoder_decoder;
bool has_cross_attention;
bool has_conv_frontend;
bool has_bias;
const char *norm_type; // "rmsnorm" or "layernorm"
const char *encoder_activation;
bool is_embedding_model;
};
Every field in this struct corresponds to a decision point in the dispatch table builder. Let us trace how these fields affect inference.
How the Descriptor Drives Inference
Activation Kernel
activation_kernel = "silu_gate_f16" -> SwiGLU: SiLU(gate) * up
activation_kernel = "gelu_gate_f16" -> GeGLU: GELU(gate) * up
activation_kernel = "gelu_f16" -> Plain GELU (no gate, Whisper)
The table builder passes this directly to device.get_pipeline():
Pipeline act_pso = device.get_pipeline(arch.activation_kernel);
If you have a new architecture that uses, say, ReLU-squared gating, you would add a relu_sq_gate_f16 Metal kernel and set activation_kernel = "relu_sq_gate_f16".
Embedding Scale
Gemma models multiply the embedding output by sqrt(dim) before feeding it into the transformer. This is captured as:
float embedding_scale; // 0 = no scaling, sqrt(dim) for Gemma
The table builder checks:
if (arch.embedding_scale > 0.0f) {
// emit a temperature_scale_f16 dispatch
}
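On the CPU side, the field's semantics can be sketched like this (illustrative, not akunu source; in the real engine the scaling runs as a GPU dispatch, not a host loop):

```cpp
#include <cmath>
#include <vector>

// How the embedding_scale field is interpreted: 0 disables the extra
// dispatch entirely; any positive value scales the embedding output.
void apply_embedding_scale(std::vector<float> &hidden, float embedding_scale) {
    if (embedding_scale <= 0.0f) return; // 0 = no scaling: skip the dispatch
    for (float &x : hidden) x *= embedding_scale;
}
```

For Gemma, the factory function would set `embedding_scale = sqrtf((float)dim)`.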
QK-Norm
Qwen3 and Gemma apply per-head RMSNorm to Q and K after projection. The flag has_qk_norm triggers emission of either:
- A fused head_norm_rope_neox_kv_write_f16 kernel (if the RoPE style is compatible)
- Separate head_rmsnorm_f16 dispatches for Q and K
Post-Norms
Gemma has an unusual architecture where there are additional RMSNorm layers after the attention output projection and after the FFN down projection, before the residual add. The flags has_post_attn_norm and has_post_ffn_norm control whether these extra norm dispatches are emitted, and the post_attn_norm_key / post_ffn_norm_key strings tell the table builder what weight names to look up.
RoPE Variant
Different architectures use different RoPE implementations:
| Architecture | rope_kernel | Description |
|---|---|---|
| LLaMA | rope_qkv_write_f16 | Standard interleaved RoPE |
| Qwen3 | rope_neox_qkv_write_f16 | NeoX split-half RoPE |
| Gemma | rope_neox_qkv_write_f16 | NeoX split-half RoPE |
| Whisper | nullptr | No RoPE (sinusoidal PE) |
Setting rope_kernel = nullptr causes the table builder to skip the RoPE+KV-write dispatch entirely.
Tied Embeddings
When tie_embeddings = true, the logit projection reuses the token embedding weight matrix instead of a separate output.weight tensor:
const char *logit_name = arch.tie_embeddings
? "token_embedding.weight"
: "output.weight";
Buffer logit_w = weights.get_tensor(logit_name);
Walkthrough: Adding a Hypothetical Architecture
Let us add support for a hypothetical “Phoenix” architecture with these properties:
- SwiGLU activation (same as LLaMA)
- NeoX-style RoPE (same as Qwen3)
- No QK-norm
- No post-norms
- Tied embeddings
- Embedding scale of 1.0 / sqrt(dim) (inverse, unlike Gemma)
- Standard decoder-only (no encoder)
Step 1: Create the Factory Function
In arch_descriptor.h, add:
inline ArchDescriptor arch_phoenix(int dim) {
ArchDescriptor d = arch_llama(); // start from LLaMA defaults
// Override what differs
d.rope_kernel = "rope_neox_qkv_write_f16"; // NeoX RoPE
d.rope_standalone = "rope_neox_f16";
d.tie_embeddings = true;
d.embedding_scale = 1.0f / sqrtf((float)dim); // inverse scaling
return d;
}
Notice we start from arch_llama() and only override what is different. This is the inheritance pattern – most architectures share 80% of their properties with LLaMA.
Step 2: Register in arch_from_config
Add a case to the dispatch function:
inline ArchDescriptor arch_from_config(const char *arch_name, int dim) {
if (strstr(arch_name, "phoenix"))
return arch_phoenix(dim);
if (strstr(arch_name, "whisper"))
return arch_whisper();
if (strstr(arch_name, "bert"))
return arch_bert();
if (strstr(arch_name, "qwen"))
return arch_qwen3();
if (strstr(arch_name, "gemma"))
return arch_gemma(dim);
return arch_llama(); // default
}
The strstr matching is intentionally loose – it matches "phoenix", "phoenix-1.5", "phoenix_moe", etc. The first match wins, so ordering matters for architectures whose names are substrings of others.
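The matching behavior can be demonstrated with a stand-in (a hypothetical helper with the same strstr logic, returning a name instead of a descriptor):

```cpp
#include <cstring>

// Stand-in for arch_from_config's matching: substring-based and
// first-match-wins, so specific names must come before generic ones.
const char *match_arch(const char *name) {
    if (std::strstr(name, "phoenix")) return "phoenix";
    if (std::strstr(name, "whisper")) return "whisper";
    if (std::strstr(name, "qwen"))    return "qwen3";
    if (std::strstr(name, "gemma"))   return "gemma";
    return "llama"; // default
}
```

Any name containing "phoenix" resolves to the Phoenix descriptor; anything unrecognized falls through to LLaMA defaults.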
Step 3: Add Weight Name Mappings (if different from LLaMA)
If Phoenix uses different tensor names in its weight files, you need mappings. For GGUF, the tensor names are already canonical (llama.cpp normalizes them). For MLX/SafeTensors, add rules to kMLXRules:
// If Phoenix uses different HF naming:
{"model.layers.{n}.self_attn.qkv_proj.weight",
"layers.{n}.attention.qkv.weight"},
Most architectures follow the LLaMA naming convention in GGUF, so this step is often unnecessary.
Step 4: Handle Any Unique Weight Structure
If Phoenix has a fused QKV weight (a single matrix instead of separate Q, K, V), you might need to add logic to split it during loading. But if Phoenix follows the standard separate Q/K/V pattern (most models do), nothing extra is needed.
Step 5: Test
Load a Phoenix model and verify:
- Config extraction produces correct dimensions
- All weights are found and loaded
- The dispatch table has the expected number of commands
- The embedding scale is applied
- NeoX RoPE is used instead of standard RoPE
- Output projection uses the embedding weight
A minimal test:
// Load model
akunu_model_t model = akunu_load("phoenix-7b.gguf", "akunu.metallib");
AkunuModelConfig cfg = akunu_get_config(model);
// Verify architecture was detected
assert(strstr(cfg.architecture, "phoenix") != nullptr);
// Generate a few tokens to verify correctness
uint32_t tokens[] = {1, 15043, 29892, 920}; // "Hello, world"
akunu_generate(model, tokens, 4, 10, sampling, callback, nullptr);
akunu_free_model(model);
For numerical correctness testing, compare output logits against a reference implementation (e.g., Hugging Face Transformers in Python) on the same input tokens.
What Does NOT Need to Change
This is the key point of the data-driven approach. When you add a new architecture, the following files remain untouched:
| File | Why it does not change |
|---|---|
| table_builder.cpp | Reads from ArchDescriptor, never branches on arch name |
| table_builder.h | Only the function signature, no arch-specific logic |
| All .metal files | Kernels are generic (F16 GEMV, RoPE, etc.) |
| device.h / device.cpp | Hardware abstraction, no model knowledge |
| dispatch_table.h | Format of commands, no model knowledge |
| chain_decoder.cpp | Replays dispatch table, no model knowledge |
| prefill.cpp | Uses same kernels via dtype descriptors |
| serve.h | HTTP server, model-agnostic |
The only files that change are:
| File | What changes |
|---|---|
| arch_descriptor.h | Add factory function + case in arch_from_config |
| mlx_weight_store.h | Add name mapping rules (if different naming) |
| weight_store.cpp | Handle any unique GGUF tensor layout (rare) |
That is typically 10-30 lines of code for a standard transformer variant.
When the Descriptor Is Not Enough
There are cases where the ArchDescriptor pattern does not cover an architectural difference:
Mixture of Experts (MoE). MoE models like Mixtral have a routing mechanism that selects a subset of FFN experts per token. This requires a fundamentally different dispatch pattern (router + sparse expert selection) that cannot be expressed as a boolean flag.
Novel attention patterns. If an architecture uses linear attention, sliding window attention with a non-standard pattern, or multi-query attention with a different head structure, the attention dispatch logic may need extension.
Non-standard normalization. If an architecture uses something other than RMSNorm or LayerNorm (e.g., CRMSNorm, QKNorm with different epsilon handling), a new kernel may be needed.
For these cases, the approach is:
- Add the new kernel(s) as described in the previous chapter
- Add new flags to ArchDescriptor to control the new behavior
- Add conditional logic to build_dispatch_table() gated on those flags
- The new logic only runs when the flag is set – existing architectures are unaffected
The goal is to keep build_dispatch_table() as a data-driven loop, not a forest of architecture-specific branches. Even when new logic is needed, it should be expressed as “if this flag, use this kernel” rather than “if architecture is X”.
Existing Architecture Descriptors
For reference, here are the current architectures and how they differ:
LLaMA (Default)
activation: SiLU gate (SwiGLU)
embedding_scale: none
qk_norm: no
post_norms: no
rope: standard interleaved
tie_embeddings: no
norm_type: rmsnorm
bias: no
Qwen3
activation: SiLU gate (SwiGLU) <- same as LLaMA
embedding_scale: none <- same
qk_norm: YES <- different
post_norms: no <- same
rope: NeoX split-half <- different
tie_embeddings: YES <- different
norm_type: rmsnorm <- same
bias: no <- same
Gemma 3
activation: GELU gate (GeGLU) <- different
embedding_scale: sqrt(dim) <- different
qk_norm: YES <- different
post_norms: YES (both) <- different
rope: NeoX split-half <- same as Qwen3
tie_embeddings: YES <- same as Qwen3
norm_type: rmsnorm <- same
bias: no <- same
sliding_window: every 6th layer global <- unique
Whisper
activation: GELU (no gate) <- different
embedding_scale: none <- same as LLaMA
qk_norm: no <- same as LLaMA
post_norms: no <- same
rope: NONE (sinusoidal PE) <- different
tie_embeddings: YES <- different
norm_type: layernorm <- different
bias: YES (all layers) <- different
encoder_decoder: YES <- unique
cross_attention: YES <- unique
conv_frontend: YES <- unique
BERT (nomic-bert)
activation: SiLU gate (SwiGLU) <- same as LLaMA
embedding_scale: none <- same
qk_norm: no <- same
post_norms: no <- same
rope: NeoX split-half <- same as Qwen3
tie_embeddings: no <- same as LLaMA
norm_type: rmsnorm <- same
bias: no <- same
is_embedding: YES <- unique
The pattern is clear: each architecture differs from LLaMA in a small number of dimensions. The descriptor captures exactly those differences, and the table builder handles the combinatorics.
Testing Methodology
When you add a new architecture, testing should cover multiple levels:
Level 1: Config Extraction
Verify that the model’s config (dimensions, layer count, head count, etc.) is extracted correctly from either GGUF metadata or MLX config.json:
akunu_model_t model = akunu_load("phoenix-7b.gguf", "akunu.metallib");
AkunuModelConfig cfg = akunu_get_config(model);
assert(cfg.dim == 4096);
assert(cfg.n_layers == 32);
assert(cfg.n_heads == 32);
assert(cfg.n_kv_heads == 8);
assert(cfg.head_dim == 128);
Level 2: Dispatch Table Sanity
Check that the dispatch table has the expected structure. Count the total commands and verify key labels are present:
Expected for a standard 32-layer decoder-only model:
1 embedding
1 initial norm
32 * ~11 commands per layer = 352 (varies with fusion)
1 output norm
1 logit projection
1 argmax
~357 total commands
If your architecture has QK-norm, expect +1-2 commands per layer (or 0 if fused). If it has post-norms, expect +1-2 per layer. The numbers do not need to be exact, but a gross mismatch (e.g., 100 commands for a 32-layer model) indicates a problem.
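The estimate above can be encoded as a quick sanity-check helper (illustrative; the per-layer count is an input precisely because it varies with fusion):

```cpp
// Back-of-envelope command count for a decode dispatch table:
// fixed commands plus n_layers * per-layer commands.
int expected_commands(int n_layers, int per_layer) {
    int fixed = 1 /*embedding*/ + 1 /*initial norm*/ + 1 /*output norm*/
              + 1 /*logit projection*/ + 1 /*argmax*/;
    return fixed + n_layers * per_layer;
}
```

Comparing the actual table size against this estimate catches gross construction errors, like a layer loop that emitted nothing.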
Level 3: Single-Token Numerical Correctness
The gold standard is comparing logit outputs against a reference implementation. The procedure:
- Pick a known input sequence (e.g., [1, 15043, 29892])
- Run the same input through Hugging Face Transformers in Python
- Extract the logits for the last position
- Run the same input through Akunu
- Compare the top-5 token IDs and their logit values
For quantized models, exact numerical match is not expected. But the top-1 token should match, and top-5 should overlap significantly. If the top-1 tokens diverge on simple inputs, something is wrong with the architecture implementation.
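A sketch of the comparison logic (helper names are hypothetical; the logit vectors would come from akunu and the reference run):

```cpp
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <numeric>
#include <vector>

// Token IDs of the k largest logits, highest first.
std::vector<int> topk_ids(const std::vector<float> &logits, std::size_t k) {
    std::vector<int> ids(logits.size());
    std::iota(ids.begin(), ids.end(), 0);
    std::partial_sort(ids.begin(), ids.begin() + k, ids.end(),
                      [&](int a, int b) { return logits[a] > logits[b]; });
    ids.resize(k);
    return ids;
}

// How many token IDs the two top-k sets share.
std::size_t topk_overlap(const std::vector<float> &a,
                         const std::vector<float> &b, std::size_t k) {
    auto ta = topk_ids(a, k), tb = topk_ids(b, k);
    std::sort(ta.begin(), ta.end());
    std::sort(tb.begin(), tb.end());
    std::vector<int> both;
    std::set_intersection(ta.begin(), ta.end(), tb.begin(), tb.end(),
                          std::back_inserter(both));
    return both.size();
}
```

A reasonable acceptance bar for a quantized model: top-1 IDs equal, and `topk_overlap(..., 5) >= 4` on simple inputs.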
Level 4: Generation Quality
Run the model on a few prompts and check that the output is coherent. This is subjective but important – subtle bugs (wrong RoPE variant, incorrect norm epsilon, swapped Q/K norms) can produce output that is plausible-looking but degraded in quality. Compare against the same model running in its native framework (e.g., MLX for MLX models, llama.cpp for GGUF models).
Troubleshooting Guide
Common issues when adding a new architecture:
| Problem | Likely cause | Fix |
|---|---|---|
| Model loads but generates gibberish | Wrong RoPE variant or wrong rope_theta | Check rope_kernel field and rope_theta in config |
| Outputs repeat the same token | Missing positional encoding (RoPE or PE) | Verify RoPE kernel is dispatched, or PE is added |
| First token correct, rest wrong | KV cache write not happening | Check that RoPE+KV write dispatch exists in table |
| Crash during weight loading | Tensor name mismatch | Add missing name mapping rules |
| NaN in output | Wrong norm epsilon or missing norm weight | Check norm_eps value and weight names |
| Quality worse than reference | Wrong activation (SiLU vs GELU) | Check activation_kernel field |
| Embedding values too large/small | Missing or wrong embedding scale | Check embedding_scale field |
Summary
Adding a new architecture to Akunu is a three-step process:
1. Write an arch_xxx() factory function (5-20 lines)
2. Add a case to arch_from_config() (1 line)
3. Add weight name mappings if needed (0-10 lines)

Done. No kernel changes. No table builder changes. No dispatch table changes. No server changes.
The ArchDescriptor pattern is what makes this possible. By encoding architectural decisions as data rather than control flow, the system remains modular: the kernel layer knows about math, the dispatch layer knows about GPU commands, and only the descriptor layer knows about transformer architecture variants. Each layer can be modified independently, and adding a new architecture is a localized change that cannot break existing ones.
---
1. The data-driven approach is inspired by compiler design, where instruction selection is driven by pattern tables rather than hand-coded switch statements. The ArchDescriptor plays a similar role to an ISA descriptor in a retargetable code generator. ↩
Performance Profiling and Benchmarking
If you have made it this far in the book, you have a reasonably complete mental model of how akunu turns a pile of quantized weights and Metal shaders into streaming text on Apple Silicon. That is great, but mental models do not ship performance. At some point you need to measure things, and the gap between “I think the attention kernel is the bottleneck” and “the attention kernel consumes 38% of GPU time at 14.2 GB/s effective bandwidth on an M2 Pro” is the gap between guessing and engineering.
This chapter covers the full profiling stack: Apple’s own GPU tools, akunu’s built-in CLI profilers, the key metrics you should care about, roofline analysis for memory-bound inference, and how to interpret the numbers in context by comparing against llama.cpp and MLX.
The Three Metrics That Matter
Before we reach for any tool, let us agree on what we are measuring. LLM inference has three headline numbers that users and developers care about:
| Metric | Definition | Why It Matters |
|---|---|---|
| Prefill tok/s | Prompt tokens processed per second | Determines how fast you can ingest a 4K context window. Governs perceived responsiveness for the first token. |
| Decode tok/s | Generated tokens per second | The sustained throughput the user sees while text is streaming. This is the number people compare across frameworks. |
| TTFT (Time To First Token) | Wall-clock time from prompt submission to first generated token | The most perceptually important metric. Users notice latency more than throughput. TTFT = prefill time + one decode step. |
A fourth metric, peak memory, also matters on Apple Silicon because you are sharing a unified memory pool with the OS, the window compositor, and whatever else the user has open. Running out of memory does not just crash your process; it can trigger aggressive swapping that destroys system responsiveness.
Let us now walk through the tools that let you measure all of this.
akunu_bench: The llama-bench Equivalent
The simplest way to get prefill and decode numbers is akunu_bench, a C++ tool that replicates the methodology of llama-bench from the llama.cpp project. Here is the usage string from tools/akunu_bench.cpp:
Usage: akunu_bench <model> [-p N] [-n N] [-r N]
The flags are:
| Flag | Default | Meaning |
|---|---|---|
| -p N / --pp N | 512 | Prompt length for prefill test |
| -n N / --tg N | 128 | Number of tokens for decode (text generation) test |
| -r N / --reps N | 5 | Repetitions per test (for statistical stability) |
The tool works by creating synthetic prompts filled with the BOS token (token ID 1). This is deliberate – you want a reproducible input that does not depend on tokenizer behavior or prompt content. Here is what happens internally:
- Prefill test: Fill a vector of pp tokens with BOS. Call akunu_prefill() and time it with std::chrono::high_resolution_clock. Repeat reps times. Report mean and standard deviation.
- Decode test: Prefill a single BOS token, then call akunu_chain_decode() for tg tokens in a single GPU submission. Again, repeat and report statistics.
The output matches the llama-bench markdown table format so you can paste results directly into GitHub issues:
| model | size | test | t/s |
| --- | ---: | ---: | ---: |
| Qwen3-4B-Q4_0.gguf | 2341 MiB | pp512 | 1842.31 +/- 12.40 |
| Qwen3-4B-Q4_0.gguf | 2341 MiB | tg128 | 87.42 +/- 0.83 |
A few things to note about the methodology:
- The decode test uses akunu_chain_decode(), which batches all tg tokens into a single GPU command buffer submission. This measures the true GPU-limited throughput, not the overhead of individual akunu_decode_step() round-trips. If you were to measure decode by calling decode_step in a loop, you would be measuring CPU-GPU synchronization overhead as much as actual compute.1
- Each repetition calls akunu_reset() to clear the KV cache, ensuring independent measurements. Without the reset, later iterations would operate on a larger KV cache, which changes the attention kernel’s memory access pattern.
- The standard deviation across repetitions is typically very small (under 2%) on a quiet system. If you see high variance, check for thermal throttling or background processes competing for GPU resources.
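The repetition-and-statistics logic is simple enough to sketch in full (simplified, not the actual tool source; `fn` stands in for one prefill or decode repetition):

```cpp
#include <chrono>
#include <cmath>
#include <vector>

// Mean and (population) standard deviation over per-rep timings.
struct Stats { double mean, stddev; };

Stats compute_stats(const std::vector<double> &samples) {
    double mean = 0.0;
    for (double v : samples) mean += v;
    mean /= samples.size();
    double var = 0.0;
    for (double v : samples) var += (v - mean) * (v - mean);
    return {mean, std::sqrt(var / samples.size())};
}

// Time one repetition in milliseconds with high_resolution_clock.
template <typename Fn>
double time_ms(Fn &&fn) {
    auto t0 = std::chrono::high_resolution_clock::now();
    fn(); // one repetition, e.g. a prefill preceded by a KV-cache reset
    auto t1 = std::chrono::high_resolution_clock::now();
    return std::chrono::duration<double, std::milli>(t1 - t0).count();
}
```

Throughput is then tokens divided by the mean time, with the stddev carried through for the `+/-` column.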
akunu_benchmark: End-to-End with Real Prompts
While akunu_bench gives you clean synthetic numbers, akunu_benchmark exercises the full akunu_generate() path with real prompts of varying lengths:
Usage: akunu_benchmark <model>
This tool runs three prompts (short, medium, long), measures AkunuGenerationStats for each, and reports:
| Column | Meaning |
|---|---|
| Prompt | Length category |
| Tokens | Actual token count after encoding |
| Prefill (t/s) | Prefill throughput |
| Decode (t/s) | Decode throughput |
| First-tok(ms) | Time to first token (the TTFT metric) |
| Prefill(ms) | Raw prefill time |
| Total(s) | Wall-clock total |
After the prompt tests, it also runs a standalone chain decode measurement (128 tokens, greedy) to give you the raw GPU-limited decode throughput independent of sampling overhead.
The key insight from this tool is how prefill scales with prompt length. On Apple Silicon, prefill is a GEMM (matrix-matrix multiply) workload, and the GPU’s utilization increases with larger batch sizes. You will typically see:
- Short prompts (1-10 tokens): Low prefill tok/s because the GEMMs have tiny M dimension and cannot saturate the GPU’s compute units
- Medium prompts (50-200 tokens): Prefill tok/s climbs rapidly as GEMM occupancy improves
- Long prompts (500+ tokens): Prefill tok/s plateaus near the compute-bound peak
akunu_profile: Per-Kernel GPU Timing
This is the real workhorse for optimization. Where akunu_bench tells you how fast, akunu_profile tells you where the time goes.
Usage: akunu_profile <model> [--tokens N]
Here is what happens under the hood, based on the actual source in tools/akunu_profile.cpp:
- Load the model and prefill a single BOS token
- Call akunu_profile_decode_step(), which runs each dispatch command in its own MTLCommandBuffer, enabling accurate per-kernel GPU timing via Metal’s built-in command buffer timing
- Repeat for N tokens (default 5), accumulating timing data
- Sort kernels by total GPU time and print a breakdown table
The output looks something like this (simplified):
Per-Kernel GPU Timing Breakdown
==========================================================================================
Kernel Dispatches Total (ms) Avg (ms) % GPU
------------------------------------------------------------------------------------------
L0 GEMV attn_qkv Q4_0 5 0.412 0.082 18.2%
L0 GEMV ffn_down Q4_0 5 0.318 0.064 14.1%
L0 GEMV ffn_gate_up Q4_0 5 0.304 0.061 13.4%
L0 Attention 5 0.201 0.040 8.9%
L0 GEMV attn_output Q4_0 5 0.156 0.031 6.9%
...
There is an important caveat: profiled decode is much slower than normal decode. The profiler wraps each kernel dispatch in its own command buffer to get accurate GPU timing. In normal operation, akunu batches the entire forward pass (embedding + N layers + output norm + logit projection + argmax) into a single command buffer, and the chain decoder batches multiple tokens into one submission. Profiled mode breaks this batching completely, so the absolute numbers are not representative of production throughput – they are only useful for relative comparisons between kernels.2
Reading the Profiler Output
The typical decode step for a LLaMA-like model with n_layers transformer layers contains:
+------------------+
| Embedding Lookup | 1 kernel
+------------------+
|
v
+------------------+
| Layer 0 | ~8-12 kernels per layer
| Attention Norm |
| QKV Projection | (GEMV or fused GEMV+RoPE+KV-write)
| RoPE + KV Write|
| Attention |
| Output Proj |
| Residual Add |
| FFN Norm |
| Gate+Up Proj | (possibly fused into single GEMV)
| Activation | (SiLU*gate or GELU*gate)
| Down Proj |
| Residual Add |
+------------------+
|
v
| Layer 1..N-1 | (repeat)
|
v
+------------------+
| Output Norm | 1 kernel
+------------------+
|
v
+------------------+
| Logit Projection | 1 GEMV (dim -> vocab_size)
+------------------+
|
v
+------------------+
| Argmax | 1 kernel
+------------------+
When you look at the profiler output, the GEMV (matrix-vector multiply) kernels dominate. For a Q4_0 model, the three big GEMVs per layer are:
- QKV projection: Multiplies the hidden state by the Q, K, and V weight matrices. For a model with n_heads=32, n_kv_heads=8, head_dim=128, this projects dim=4096 to q_dim + 2*kv_dim = 4096 + 2*1024 = 6144 elements.
- FFN gate+up: Projects dim to 2*ffn_dim. For LLaMA-style models with SwiGLU, ffn_dim is typically ~2.7*dim, so this is the largest single GEMV.
- FFN down: Projects ffn_dim back to dim.
The attention kernel itself is often not the biggest time consumer during decode (single token, long KV cache), because it is a relatively small operation: each head does a dot product of the query against kv_seq_len keys, then a weighted sum of values. The total work scales with n_heads * kv_seq_len * head_dim, which for moderate context lengths is much less than the GEMV work.
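A rough traffic comparison makes this concrete (illustrative LLaMA-like config; real attention cost also includes softmax, norms, and lower effective bandwidth, so treat these as lower bounds):

```cpp
// Per-token, per-layer weight bytes read by the big GEMVs.
double gemv_weight_bytes(int dim, int q_dim, int kv_dim, int ffn_dim,
                         double bytes_per_weight) {
    double qkv = (double)dim * (q_dim + 2.0 * kv_dim); // QKV projection
    double out = (double)q_dim * dim;                  // attention output
    double ffn = (double)dim * 2.0 * ffn_dim           // gate + up
               + (double)ffn_dim * dim;                // down
    return (qkv + out + ffn) * bytes_per_weight;
}

// Per-token KV cache bytes read by attention: one fp16 read of K and of V
// per cached position, per KV head.
double attn_kv_bytes(int n_kv_heads, int head_dim, int kv_seq_len) {
    return 2.0 * n_kv_heads * head_dim * (double)kv_seq_len * 2.0;
}
```

At 512 tokens of context, the GEMV traffic dwarfs the KV reads; the attention share grows linearly with context, which is the trend the long-context table later in this chapter reflects.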
Xcode GPU Profiler (Instruments)
For the deepest level of insight, Apple provides GPU profiling through Instruments. There are two relevant instruments:
Metal System Trace
Metal System Trace shows the timeline of GPU command buffer submissions, encoding, and execution. This is the tool to use when you suspect CPU-GPU synchronization issues or want to understand the relationship between akunu’s chain decode submissions and actual GPU execution.
To capture a trace:
- Build akunu with debug symbols (CMake RelWithDebInfo or Debug)
- Open Instruments and choose the “Metal System Trace” template
- Select your akunu binary as the target
- Record for a few seconds while running a generation
The trace shows:
| Track | What You See |
|---|---|
| GPU Timeline | Individual compute dispatches on the GPU hardware. Each dispatch shows its duration, pipeline state object (PSO) name, and threadgroup configuration. |
| Command Buffer Track | When each MTLCommandBuffer was committed, scheduled, and completed. Gaps between command buffers indicate CPU-side stalls. |
| Encoder Track | The compute command encoder’s encode phase. If encoding takes longer than GPU execution, you are CPU-bound. |
The key thing to look for in the Metal System Trace is GPU idle gaps. In a well-tuned chain decode:
CPU: [encode CB1] [encode CB2] [encode CB3]
GPU: [execute CB1][execute CB2] [execute CB3]
^-- no gap here: GPU stays busy
If you see gaps where the GPU is idle between command buffers, the CPU is not encoding fast enough. Akunu’s chain decode design specifically addresses this by encoding chain_decode_chunk tokens (64-128, depending on chip) into a single command buffer, ensuring the GPU has enough work to stay saturated.
GPU Counters
Instruments also provides GPU hardware counters (on supported devices) that show:
| Counter Group | Key Metrics |
|---|---|
| Occupancy | How many threadgroups are resident on the GPU simultaneously. Low occupancy means the GPU has idle ALUs. |
| Memory | Read/write bandwidth, cache hit rates. Critical for understanding whether your GEMV kernels are memory-bound (they almost always are). |
| ALU | Arithmetic utilization. For quantized GEMV, this is typically low because you are waiting on memory, not compute. |
| Shader | Per-pipeline-state breakdown. Shows which PSOs consume the most GPU time. |
Roofline Analysis for Apple Silicon
The roofline model is the single most useful framework for understanding LLM inference performance on Apple Silicon.3 The core idea is simple: every computation has an arithmetic intensity (operations per byte of memory accessed), and the hardware has a memory bandwidth ceiling and a compute ceiling. Your kernel’s throughput is limited by whichever ceiling it hits first.
Apple Silicon Memory Bandwidth
| Chip | Memory BW (GB/s) | GPU FP16 TFLOPS | Roofline Knee (ops/byte) |
|---|---|---|---|
| M1 | 68.25 | 2.6 | 38 |
| M1 Pro | 200 | 5.2 | 26 |
| M1 Max | 400 | 10.4 | 26 |
| M2 | 100 | 3.6 | 36 |
| M2 Pro | 200 | 7.0 | 35 |
| M2 Max | 400 | 13.6 | 34 |
| M3 | 100 | 4.1 | 41 |
| M3 Pro | 150 | 7.0 | 47 |
| M3 Max | 400 | 14.2 | 36 |
| M4 | 120 | 4.3 | 36 |
| M4 Pro | 273 | 9.2 | 34 |
| M4 Max | 546 | 18.0 | 33 |
The “roofline knee” is the arithmetic intensity where you transition from memory-bound to compute-bound. For LLM decode, the arithmetic intensity is almost always well below this knee.
Why Decode Is Memory-Bound
During single-token decode, each GEMV reads the entire weight matrix and multiplies it by a single vector. For a Q4_0 weight matrix of shape [N, K]:
- Bytes read:
N * K / 2bytes (4 bits per weight, packed) +N * K / 32 * 2bytes (one FP16 scale per block of 32) - FLOPs:
2 * N * K(multiply-accumulate) - Arithmetic intensity: roughly
2 * N * K / (N * K * 0.5625)= ~3.6 ops/byte
That is far below the roofline knee of 26-47 ops/byte. The GEMV is firmly memory-bound. This means:
Decode throughput is determined almost entirely by memory bandwidth.
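The intensity figure above can be checked with a few lines of arithmetic (exact for the Q4_0 format):

```cpp
#include <cmath>

// Q4_0 arithmetic intensity for a GEMV over an [N, K] weight matrix.
double q4_0_gemv_intensity() {
    double bytes_per_weight = 4.0 / 8.0   // 4-bit packed weight
                            + 2.0 / 32.0; // fp16 scale shared by 32 weights
    double flops_per_weight = 2.0;        // one multiply + one add
    return flops_per_weight / bytes_per_weight; // ~3.56 ops/byte
}
```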
The theoretical maximum decode tok/s for a model of total weight size W bytes on a chip with bandwidth B bytes/s is:
max_decode_tok_s = B / W
For a 4B parameter Q4_0 model (~2.3 GB weights):
| Chip | BW (GB/s) | Theoretical Max (tok/s) |
|---|---|---|
| M1 | 68.25 | 29.7 |
| M2 Pro | 200 | 87.0 |
| M3 Max | 400 | 174.0 |
| M4 Max | 546 | 237.4 |
In practice, akunu achieves 70-85% of theoretical bandwidth utilization for decode, which is quite good for a real-world system with cache management, RoPE computation, attention, and norm overhead on top of the raw GEMVs.
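The ceiling formula as code, reproducing the table above (assumes every decode step streams the full weight set exactly once, i.e. no cache reuse):

```cpp
#include <cmath>

// Theoretical decode ceiling: bandwidth divided by model weight size.
double max_decode_tok_s(double bandwidth_gb_s, double weights_gb) {
    return bandwidth_gb_s / weights_gb;
}
```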
Why Prefill Is Compute-Bound (for Large Batches)
During prefill, the projections become GEMMs (matrix-matrix multiply) because you are processing seq_len tokens simultaneously. The arithmetic intensity scales with the batch dimension:
- Arithmetic intensity: ~2 * M ops/byte (where M = batch/seq_len)
For M >= 20 or so, you cross the roofline knee and become compute-bound. This is why prefill throughput is typically 10-50x higher than decode throughput – you are actually using the GPU’s ALUs instead of just waiting on memory.
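Under the text's approximation, the crossover point is simply the knee divided by two (a rule of thumb, not a measured value):

```cpp
// Batch size where prefill crosses from memory- to compute-bound,
// assuming intensity ~= 2 * M ops/byte.
double crossover_batch(double knee_ops_per_byte) {
    return knee_ops_per_byte / 2.0;
}
```

For the knees in the table above (26-47 ops/byte), this lands at M between 13 and 24, consistent with the "M >= 20 or so" figure.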
Bandwidth Utilization: The Real Performance Metric
Raw tok/s numbers are useful for user-facing comparisons, but for engineering purposes, bandwidth utilization is the metric that tells you how close you are to optimal:
bandwidth_utilization = (model_weight_bytes / decode_time_per_token) / peak_memory_bandwidth
Here is how to compute this from akunu_bench output:
- Get model weight bytes from akunu_model_memory() (reported as “size” in bench output)
- Compute decode time per token: 1.0 / decode_tok_s
- Compute effective bandwidth (weight bytes / decode time per token) and divide by peak bandwidth
For example, if akunu_bench reports 85 tok/s on a 2341 MiB model on M2 Pro (200 GB/s):
effective_bw = 2341 * 1024 * 1024 / (1/85) = 2341 * 1.0485e6 * 85 = 208.7 GB/s
utilization = 208.7 / 200 = 104.3%
Wait, over 100%? This happens because the System Level Cache (SLC) provides additional effective bandwidth for data that fits or partially fits in the cache hierarchy. The SLC on Apple Silicon can add 20-40% of effective bandwidth for workloads with good temporal locality.4 akunu’s chain decode exploits this: when processing 64-128 tokens sequentially through each layer, the weight data loaded for token N is still in cache for token N+1.
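The worked example as code (a hypothetical helper; values over 1.0 indicate SLC reuse, as discussed above):

```cpp
#include <cmath>

// Bandwidth utilization from bench output: weight bytes streamed per
// second of decode, relative to the chip's peak DRAM bandwidth.
double bandwidth_utilization(double model_mib, double decode_tok_s,
                             double peak_gb_s) {
    double weight_bytes = model_mib * 1024.0 * 1024.0;
    double effective_bw = weight_bytes * decode_tok_s; // bytes/s
    return effective_bw / (peak_gb_s * 1e9);
}
```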
Identifying Common Bottlenecks
Here is a diagnostic flowchart based on what the profiling tools reveal:
Bottleneck: Low Decode tok/s
Is bandwidth utilization > 70%?
├── YES: You are near optimal for this chip/model combo.
│ Only way to go faster: smaller model or faster chip.
│
└── NO: Something is leaving bandwidth on the table.
│
├── Are there GPU idle gaps in Metal System Trace?
│ ├── YES: CPU encoding is too slow.
│ │ Check: is chain_decode_chunk large enough?
│ │ Check: are you using profiled decode by mistake?
│ │
│ └── NO: Kernels are suboptimal.
│ Use akunu_profile to find the slowest kernel.
│ Common culprits:
│ - Attention kernel with very long KV cache
│ - Logit projection (dim -> vocab_size GEMV, large N)
│ - Unoptimized dtype (Q5_K, Q3_K lack wide variants)
│
└── Is memory usage near system limits?
├── YES: Memory pressure causes swapping. Reduce max_context
│ or use a smaller quantization.
└── NO: Check thermal state (sysctl machdep.xcpm.cpu_thermal_level)
Bottleneck: High TTFT
TTFT is prefill time plus one decode step. If TTFT is high:
Is the prompt very long (>1000 tokens)?
├── YES: Prefill is doing large GEMMs. Check:
│ - Is prefill chunked? (akunu chunks at max_prefill_chunk = 4096)
│ - Are GEMM kernels using simd_matrix operations?
│ - For Q4_0/Q8_0, are the GEMM kernels the quantized variants?
│
└── NO: Short prompt but still slow?
Check if model loading is included in the measurement.
akunu_load_model() compiles PSOs and builds the dispatch table
on first call. Subsequent calls reuse cached state.
Bottleneck: Attention Dominating at Long Context
As context grows, the attention kernel’s cost scales linearly with KV cache length. At some point it overtakes the GEMVs:
| Context Length | Attention % of Decode (typical 4B model) |
|---|---|
| 128 | 3-5% |
| 512 | 8-12% |
| 2048 | 20-30% |
| 4096 | 35-50% |
If attention is your bottleneck, the options are:
- Reduce max_context to avoid over-allocating KV cache
- Use a model with GQA (fewer KV heads = less memory traffic in attention)
- Wait for akunu to implement paged attention or sliding window eviction
Comparing Against llama.cpp and MLX
Benchmarking against other frameworks is valuable both for validating your measurements and for identifying optimization opportunities. Here is how to set up fair comparisons:
llama.cpp Comparison
Use llama-bench with matching parameters:
# llama.cpp
./llama-bench -m model.gguf -p 512 -n 128 -r 5
# akunu
./akunu_bench model.gguf -p 512 -n 128 -r 5
Key differences to account for:
| Factor | llama.cpp | akunu |
|---|---|---|
| Backend | Metal (via ggml-metal) | Metal (direct MSL) |
| Decode strategy | Single token per GPU submission | Chain decode (64-128 tokens per submission) |
| KV cache layout | Per-layer, row-major | Per-layer, head-major [n_kv_heads, max_seq, head_dim] |
| Weight fusion | None | Gate+Up fused on Pro+ chips (SLC > 16MB) |
| GEMV kernels | ggml generic + Metal shaders | Custom per-dtype Metal shaders with chip-specific tuning |
In practice, akunu’s decode throughput is typically 1.1-1.5x llama.cpp’s on the same hardware, primarily due to chain decode reducing GPU idle time and chip-specific GEMV tuning.5
MLX Comparison
MLX (Apple’s machine learning framework) uses a different approach:
# MLX benchmark
import mlx.core as mx
from mlx_lm import load, generate
model, tokenizer = load("mlx-community/Qwen3-4B-4bit")
# ... time the generation
Key differences:
| Factor | MLX | akunu |
|---|---|---|
| Language | Python + C++ + Metal | C++ + Metal |
| Weight format | SafeTensors with MLX quantization | GGUF or MLX SafeTensors |
| Graph compilation | JIT traced graphs | Pre-compiled dispatch table |
| Quantization | Group quantized (group_size=64) | GGUF block quant or MLX group quant |
| Overhead | Python dispatch + JIT | Near-zero (POD struct iteration) |
MLX’s Python overhead is minimal for long generations but can be significant for TTFT on short prompts. akunu’s pre-compiled dispatch table avoids any per-token overhead beyond the raw GPU dispatch cost.
What Fair Comparison Looks Like
For a fair comparison, ensure:
- Same model weights – or at least the same effective bits-per-weight. Q4_0 GGUF (4.5 effective bpw) is roughly comparable to MLX 4-bit with group_size=64.
- Same prompt and generation length – especially for prefill comparison, since prefill scales nonlinearly with prompt length.
- Same sampling – use greedy (temperature=0) to eliminate sampling variance.
- Warm start – run at least one throwaway generation before timing to ensure Metal shader compilation is complete and caches are warm.
- Same hardware – obvious, but worth stating. The M3 Pro and M2 Pro have the same 200 GB/s bandwidth but different GPU architectures, which affects compute-bound workloads like prefill.
Profiling Checklist
When you sit down to profile an akunu deployment, here is the sequence:
1. Baseline: Run `akunu_bench` to establish prefill tok/s, decode tok/s, and TTFT
2. Bandwidth check: Compute bandwidth utilization from the bench numbers. If >70%, you are in good shape.
3. Kernel breakdown: Run `akunu_profile` to identify which kernels dominate. The top 3-5 kernels by GPU time are your optimization targets.
4. System-level: If you suspect CPU-GPU sync issues, use Metal System Trace in Instruments to check for GPU idle gaps.
5. Compare: Run the same model on llama.cpp and/or MLX to validate your numbers and identify framework-level differences.
6. Thermal: For sustained workloads, monitor thermal throttling. Apple Silicon aggressively throttles GPU frequency under thermal pressure, which can reduce throughput by 20-40% on fanless MacBooks.
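Step 2's bandwidth check follows directly from the memory-bound decode model: each token reads (approximately) every weight byte once, so achieved bandwidth is decode throughput times model size. A minimal sketch (the example numbers are illustrative, not measured akunu results):

```cpp
#include <cassert>

// Decode bandwidth utilization: achieved_bw ≈ tok/s * model_bytes,
// divided by the chip's peak memory bandwidth.
double bandwidth_utilization(double tok_per_s, double model_gb, double peak_gb_per_s) {
    return (tok_per_s * model_gb) / peak_gb_per_s;
}
```

For example, a hypothetical 2.3 GB model decoding at 80 tok/s on an M4 Pro (273 GB/s peak) would be running at roughly 67% of peak bandwidth — above the 70% threshold would indicate there is little left to gain from kernel tuning.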
Advanced: Custom Profiling with the C API
The akunu_profile_decode_step() C API function is available for integration into your own profiling harness:
```c
// Allocate timing buffer: n_layers + 4 entries
// [embedding, norm, layer0, layer1, ..., layerN-1, logit, argmax]
float timing[512];
int n = akunu_profile_decode_step(model, token_id, position, timing, 512);
for (int i = 0; i < n; i++) {
    printf("%s: %.3f ms\n", akunu_profile_label(model, i), timing[i]);
}
```
Each entry corresponds to a dispatch command in the DispatchTable. The labels are stored in a parallel DispatchLabel array (cold data, separate from the hot command array) so that profiling metadata does not pollute the cache lines used by the decode inner loop.
The profiling works by running each dispatch command in its own MTLCommandBuffer and reading back GPUStartTime / GPUEndTime. This gives microsecond-accurate per-kernel GPU timing, but at the cost of massive overhead from the per-kernel command buffer synchronization. You would never use this in production – it is purely a diagnostic tool.
Summary
| Tool | When to Use | Output |
|---|---|---|
| `akunu_bench` | Quick throughput comparison | Prefill tok/s, decode tok/s (markdown table) |
| `akunu_benchmark` | End-to-end with real prompts | TTFT, prefill/decode speed at multiple prompt lengths |
| `akunu_profile` | Identifying kernel bottlenecks | Per-kernel GPU time breakdown, sorted by cost |
| Metal System Trace | CPU-GPU sync analysis | Timeline of command buffer submissions and GPU execution |
| GPU Counters | Hardware utilization | Occupancy, bandwidth, ALU utilization |
| Roofline analysis | Understanding theoretical limits | Whether you are memory-bound or compute-bound |
The fundamental insight for Apple Silicon LLM inference is that decode is memory-bound and will remain so for the foreseeable future. The job of the profiler is not to find ways to make the GPU compute faster – it is to find the places where you are wasting bandwidth or leaving the GPU idle. Chain decode, weight fusion, and chip-specific GEMV tuning are all strategies that akunu uses to close the gap between measured and theoretical bandwidth, and the profiling tools described in this chapter are how you verify that those strategies are working.
1. On Apple Silicon, each `MTLCommandBuffer` commit-and-wait cycle costs approximately 30-80 microseconds of CPU overhead. At 80+ tok/s, a 50us overhead per token adds up to 4ms per second – about 0.4% of wall-clock time from the CPU overhead alone, before counting the GPU idle gaps that accompany each readback. ↩
2. Metal's GPU timing (`GPUStartTime`/`GPUEndTime` on `MTLCommandBuffer`) measures the time the command buffer was executing on the GPU. For a single kernel this is accurate, but for a command buffer containing hundreds of dispatches, you only get the total. Apple's GPU Timeline in Instruments provides per-dispatch timing, but requires running inside Xcode. ↩
3. Williams, S., Waterman, A., & Patterson, D. (2009). "Roofline: An Insightful Visual Performance Model for Multicore Architectures." Communications of the ACM, 52(4), 65-76. https://doi.org/10.1145/1498765.1498785 ↩
4. Actual SLC sizes estimated in akunu's `ChipConfig`: 8 MB (M1/M2/M3 base), 16 MB (M4 base), 24 MB (M1/M2/M3 Pro), 32 MB (M4 Pro), 48 MB (Max), 96 MB (Ultra). These are not published by Apple but inferred from performance measurements and die analysis. ↩
5. This comparison is for the Metal backend specifically. llama.cpp supports many backends (CUDA, Vulkan, CPU) and architectures; akunu targets Apple Silicon exclusively, which allows tighter optimization. ↩
Architectural Decision Records
Every codebase is the sum of its decisions. Some of those decisions are obvious in retrospect; others look arbitrary unless you know the alternatives that were considered and rejected. This chapter documents ten key architectural decisions in akunu using the ADR (Architectural Decision Record) format.1 For each decision, we state the problem, enumerate the options considered, record the decision, and explain the rationale.
If you are contributing to akunu, modifying it for a different platform, or designing your own inference engine, these records tell you not just what was chosen but why – and more importantly, what would need to change if the underlying assumptions shifted.
ADR-1: Dispatch Table vs. Dynamic Dispatch
Problem
An LLM forward pass consists of dozens of GPU kernel dispatches per layer: embedding lookup, normalization, projections (GEMV/GEMM), RoPE, attention, activation, and residual adds. The engine needs a way to describe and execute this sequence. The two broad approaches are:
- Dynamic dispatch: At each step, the engine code decides which kernel to call, sets up buffer bindings, and dispatches. This is the “interpreter” approach.
- Static dispatch table: Pre-compile the entire forward pass into a flat array of dispatch commands at model load time. At inference time, just iterate the array.
Options Considered
| Option | Description | Pros | Cons |
|---|---|---|---|
| A. Virtual method per layer | Each layer is a C++ object with a forward() method that encodes its own kernels | Clean OOP design, easy to understand | Virtual call overhead per dispatch, cache-unfriendly vtable chasing, hard to batch across tokens |
| B. Interpreter loop | A function that walks the model config and emits dispatch commands at each step | No up-front cost, flexible | Per-token overhead from branching on architecture, dtype, chip config; hard to batch |
| C. Pre-compiled dispatch table | Build a std::vector<DispatchCmd> once during init; replay it for each token | Zero per-token decision overhead, trivially batchable, cache-dense | Higher init cost, patching needed for dynamic fields (position, KV length) |
Decision
Option C: Pre-compiled dispatch table.
Rationale
The dispatch table approach was chosen because the forward pass structure is identical for every token. The only things that change between tokens are:
- The token embedding index (patched via `PATCH_TOKEN_OFFSET`)
- The position for RoPE and KV cache writes (patched via `PATCH_POSITION`)
- The KV sequence length for attention (patched via `PATCH_KV_SEQ_LEN`)
Everything else – pipeline state objects, buffer bindings, threadgroup sizes, parameter structs – is invariant. Building all of this once and replaying it N times is the obvious optimization.
The DispatchCmd struct is a POD (Plain Old Data) type with no pointers to chase, no virtual calls, and no heap allocations beyond the vector itself. At 64 bytes for inline parameters plus fixed-size buffer arrays, it fits neatly in cache lines. The encode_chain() function in dispatch_table.h is a tight double loop:
```
for each token in [0, count):
    for each command in table.commands:
        set pipeline, set buffers, patch params, dispatch
```
This is the hot path. It runs once per chain decode chunk (64-128 tokens). The inner loop body is branch-free except for the patch type switch, which the compiler can lower to a jump table.
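The shape of that loop can be sketched in C++. The struct layout and enum names below are simplified illustrations modeled on the patch types named in the text, not akunu's actual definitions:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Minimal sketch of a pre-compiled dispatch table. Each command is POD:
// no pointers to chase, no virtual calls, cache-dense in a flat vector.
enum PatchType : uint8_t { PATCH_NONE, PATCH_TOKEN_OFFSET, PATCH_POSITION, PATCH_KV_SEQ_LEN };

struct DispatchCmd {
    uint16_t  pipeline_id;  // index into pre-built pipeline state objects
    PatchType patch;        // which dynamic field to rewrite per token
    int32_t   param;        // inline parameter slot the patch overwrites
};

// Replay the table `count` times, rewriting only the dynamic fields.
void encode_chain(std::vector<DispatchCmd>& cmds, int start_pos, int count) {
    for (int tok = 0; tok < count; ++tok) {
        for (auto& c : cmds) {
            switch (c.patch) {  // the compiler can lower this to a jump table
                case PATCH_TOKEN_OFFSET: c.param = tok * 4;             break;
                case PATCH_POSITION:     c.param = start_pos + tok;     break;
                case PATCH_KV_SEQ_LEN:   c.param = start_pos + tok + 1; break;
                case PATCH_NONE:         break;
            }
            // real code would encode the dispatch into a MTLCommandBuffer here
        }
    }
}
```

Everything invariant — pipeline selection, buffer bindings, threadgroup sizes — was resolved once at table-build time; the replay touches only three small fields per patched command.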
Consequences
- Pro: Chain decode became trivial to implement. Batching N tokens is just calling `encode_chain()` with `count=N`.
- Pro: Profiling labels are stored in a parallel `DispatchLabel` vector (cold data), keeping the hot command array dense.
- Con: Adding a new architecture requires a new `build_dispatch_table()` path in `table_builder.h`. The ArchDescriptor (ADR-3) mitigates this by making most architecture differences data-driven rather than code-driven.
- Con: Dynamic control flow (e.g., early exit, mixture-of-experts routing) is harder to express. If akunu ever supports MoE models, the dispatch table design would need extension.
ADR-2: C API vs. C++
Problem
Akunu needs a public API for applications (CLI tools, Swift apps, servers) to load models, tokenize text, and run inference. The API design affects language binding ergonomics, ABI stability, and the mental model for users.
Options Considered
| Option | Description | Pros | Cons |
|---|---|---|---|
| A. C++ class API | class AkunuModel { ... } with methods | Natural for C++ users, can use RAII, templates | C++ ABI is fragile across compilers/versions, hard to bind from Swift/Python/Rust |
| B. C API with opaque handles | akunu_model_t as void*, free functions | Stable ABI, trivial to bind from any language, no name mangling | Verbose, no RAII, error handling via return codes or thread-local strings |
| C. C API wrapping C++ internals | Same as B, but the implementation uses C++ internally | Best of both: stable API surface, modern implementation | Thin translation layer between C and C++ |
Decision
Option C: C API wrapping C++ internals.
Rationale
The primary consumer of akunu’s API is the Swift binding layer (CAkunu/shim.c), which needs a C-compatible interface. Swift can import C headers directly via the Clang importer, but C++ interop (even with Swift 5.9+) is limited and fragile. A pure C API with opaque void* handles is the safest choice.
The actual API surface, visible in include/akunu/akunu.h, follows a consistent pattern:
- Lifecycle: `akunu_load_model()` / `akunu_free_model()`
- Info: `akunu_get_config()`, `akunu_model_memory()`
- Tokenization: `akunu_encode()`, `akunu_decode_token()`
- Generation: `akunu_generate()`, `akunu_chain_decode()`, `akunu_generate_continue()`
- Profiling: `akunu_profile_decode_step()`, `akunu_profile_label()`
- Error: `akunu_get_error()` (thread-local)
All structs passed across the API boundary (AkunuModelConfig, AkunuGenerationStats, AkunuSamplingConfig) are defined in types.h as C-compatible POD types with fixed-width integer fields.
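The "C API wrapping C++ internals" pattern is easy to see in miniature. The sketch below is a toy illustration of the shape — the `demo_*` names and the internal class are hypothetical, not akunu's actual implementation:

```cpp
#include <cassert>
#include <string>

// Hypothetical internal C++ state; the real engine holds GPU buffers,
// the dispatch table, tokenizer state, etc.
namespace impl { struct Model { std::string path; int n_layers = 32; }; }

extern "C" {

typedef void* demo_model_t;  // opaque handle: stable ABI, no name mangling

demo_model_t demo_load_model(const char* path) {
    return new impl::Model{path};             // C++ behind a C surface
}

int demo_model_layers(demo_model_t h) {
    return static_cast<impl::Model*>(h)->n_layers;
}

void demo_free_model(demo_model_t h) {
    delete static_cast<impl::Model*>(h);      // no RAII: the caller must free
}

}  // extern "C"
```

Because the surface is plain C — free functions over an opaque `void*` — Swift, Python, and Rust can all bind it without knowing anything about the C++ underneath, which is exactly the trade the ADR records.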
Consequences
- Pro: The Swift package (`Sources/CAkunu`) imports the C header directly with zero bridging code beyond a thin shim.
- Pro: The API is trivially bindable from Python (ctypes/cffi), Rust (bindgen), and any other language with C FFI.
- Pro: ABI stability – the library can be updated without recompiling consumers, as long as the C function signatures do not change.
- Con: No RAII for model handles. Forgetting `akunu_free_model()` leaks GPU memory. The Swift binding wraps this in a class with `deinit`.
- Con: Error messages are thread-local strings, which is less ergonomic than exceptions or Result types.
ADR-3: Data-Driven Architecture Descriptors
Problem
Akunu supports multiple model architectures: LLaMA, Qwen3, Gemma, Gemma3, Whisper, and BERT. Each architecture has differences in activation functions, normalization placement, RoPE style, embedding scaling, and more. The question is how to handle these differences without littering the codebase with if (arch == "llama") ... else if (arch == "gemma") ... branches.
Options Considered
| Option | Description | Pros | Cons |
|---|---|---|---|
| A. Architecture-specific subclasses | class LlamaModel : public Model, class GemmaModel : public Model | Clean separation of concerns | Virtual dispatch overhead, code duplication across architectures, adding a new arch means a new class |
| B. If/else chains on architecture name | Check cfg.architecture string throughout the codebase | Simple, no abstraction overhead | Scatters architecture logic everywhere, easy to miss a branch, N*M combinatorial explosion |
| C. Data-driven descriptor struct | An ArchDescriptor POD struct that captures all arch-specific behavior as data fields | Single source of truth, no branching in hot path, trivial to add new architectures | Must anticipate all possible variation dimensions up front |
Decision
Option C: Data-driven ArchDescriptor.
Rationale
The key insight is that the differences between architectures are almost entirely parametric, not structural. LLaMA and Gemma have the same transformer skeleton; they differ in:
- Activation function: SiLU (LLaMA) vs GELU (Gemma)
- Embedding scaling: none (LLaMA) vs `sqrt(dim)` (Gemma)
- QK norm: no (LLaMA) vs yes (Qwen3, Gemma)
- Post-attention/FFN norms: no (LLaMA) vs yes (Gemma)
- RoPE style: interleaved (LLaMA) vs NeoX/split-half (Qwen3, Gemma)
- Tied embeddings: no (LLaMA) vs yes (Qwen3, Gemma)
All of these can be captured as fields in a struct. The ArchDescriptor in src/core/arch_descriptor.h has fields like activation_kernel, embedding_scale, has_qk_norm, rope_kernel, and tie_embeddings. Factory functions (arch_llama(), arch_qwen3(), arch_gemma(), etc.) return pre-filled descriptors. The arch_from_config() function maps GGUF metadata strings to the right factory.
The build_dispatch_table() function in table_builder.h reads from the ArchDescriptor and never branches on architecture name. Adding support for a new LLaMA variant (say, Mistral with sliding window attention) is typically a one-line change: modify an existing factory or add a new one.
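The descriptor-plus-factory pattern can be sketched in a few lines. The field and enum names below are assumptions modeled on the ones the text names, not the real `arch_descriptor.h` contents:

```cpp
#include <cassert>

// Illustrative data-driven architecture descriptor: every architecture
// difference is a data field, never an if/else on an architecture name.
enum class Act  { SiLU, GELU };
enum class Rope { Interleaved, NeoX };

struct ArchDescriptor {
    Act   activation      = Act::SiLU;
    Rope  rope_style      = Rope::Interleaved;
    bool  has_qk_norm     = false;
    bool  tie_embeddings  = false;
    float embedding_scale = 1.0f;
};

// LLaMA is the baseline; variants override only what differs.
ArchDescriptor arch_llama() { return {}; }

ArchDescriptor arch_qwen3() {
    ArchDescriptor d = arch_llama();
    d.rope_style     = Rope::NeoX;   // split-half RoPE
    d.has_qk_norm    = true;         // per-head Q/K RMSNorm
    d.tie_embeddings = true;         // shared input/output embedding
    return d;
}
```

A table builder reading from this struct never needs to know which architecture it is building — the architecture has been "compiled away" into data before the hot path ever runs.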
Consequences
- Pro: Adding Qwen3 support required writing `arch_qwen3()` (4 lines that override 3 fields from the LLaMA defaults) and zero changes to the table builder or decode path.
- Pro: The hot path (dispatch table replay) is completely architecture-agnostic. The architecture was “compiled away” during init.
- Con: Truly novel architectures (e.g., mixture of experts, state-space models) may not fit the descriptor model and would require structural changes.
- Con: Encoder-decoder models (Whisper) stretch the descriptor with fields like `is_encoder_decoder`, `has_cross_attention`, `has_conv_frontend` that are irrelevant for decoder-only models.
ADR-4: Precomputed RoPE via Fused Kernel
Problem
Rotary Position Embeddings (RoPE) apply a rotation to the Q and K vectors based on their position in the sequence. The rotation frequencies are computed as theta^(-2i/d) for each dimension pair i. This computation involves transcendental functions (sin, cos) which are expensive even on GPU hardware.
Options Considered
| Option | Description | Pros | Cons |
|---|---|---|---|
| A. Compute RoPE frequencies per-token | Each decode step computes sin/cos from theta and position | Simple, no precomputation needed | Redundant transcendental function calls for every token |
| B. Precompute frequency table on CPU | Build a [max_seq_len, head_dim] table of sin/cos values at init | Amortizes transcendental cost | Large table (`max_seq * head_dim * 4` bytes), wastes memory for short contexts |
| C. Precompute frequency divisors | Store [head_dim/2] frequency divisors; compute position*freq in kernel | Tiny table (head_dim/2 floats), position multiply is cheap | Slightly more complex kernel |
| D. Fuse RoPE with QKV projection + KV cache write | Single kernel: GEMV -> RoPE rotate Q,K -> write K,V to cache | Eliminates 2-3 separate kernel dispatches per layer | Complex kernel, harder to debug |
Decision
Option D: Fused RoPE+QKV+KV-write kernel, with precomputed frequency divisors (Option C) as the frequency source.
Rationale
In a non-fused approach, each transformer layer during decode requires:
- GEMV for Q projection
- GEMV for K projection
- GEMV for V projection
- RoPE on Q
- RoPE on K
- KV cache write for K
- KV cache write for V
That is 7 kernel dispatches. The fused kernel rope_qkv_write_f16 (or rope_neox_qkv_write_f16 for NeoX-style) combines steps 4-7 into a single dispatch. Combined with QKV fusion (fusing the three GEMV projections into a single GEMV that writes to a contiguous [q_dim + 2*kv_dim] buffer), the total drops from 7 dispatches to 2 (one fused GEMV, one fused RoPE+KV-write).
The RoPEQKVWriteParams struct captures all the parameters this fused kernel needs:
| Field | Purpose |
|---|---|
| `n_kv_heads` | Number of KV heads (for GQA) |
| `head_dim` | Elements per head |
| `max_seq_len` | KV cache dimension for stride computation |
| `pos` | Current position (patched per token in chain decode) |
| `theta` | RoPE base frequency |
| `n_heads` | Number of Q heads |
| `k_elem_offset` | Element offset to the K section in the QKV buffer |
| `v_elem_offset` | Element offset to the V section in the QKV buffer |
| `freq_scale` | Linear RoPE scaling factor (1.0 = no scaling) |
The rope_freqs field in ArchDescriptor stores precomputed frequency divisors when the model provides them (some GGUF files include rope_freqs metadata). Otherwise, the kernel computes frequencies from theta directly using the standard formula.
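The math the kernel implements is compact enough to state as a scalar reference. This is a CPU sketch of the NeoX/split-half variant (the interleaved variant pairs adjacent elements instead) — a simplified illustration, not the MSL kernel:

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Option C's precomputed table: head_dim/2 frequency values,
// freq[i] = theta^(-2i/d). Tiny compared to a full sin/cos table.
std::vector<float> rope_freqs(int head_dim, float theta) {
    std::vector<float> f(head_dim / 2);
    for (int i = 0; i < head_dim / 2; ++i)
        f[i] = std::pow(theta, -2.0f * i / head_dim);
    return f;
}

// Rotate one head vector in place for a given position.
// Split-half (NeoX) pairing: element i pairs with element i + head_dim/2.
void rope_rotate(std::vector<float>& x, const std::vector<float>& freqs, int pos) {
    int half = (int)x.size() / 2;
    for (int i = 0; i < half; ++i) {
        float angle = pos * freqs[i];       // only a multiply per token
        float c = std::cos(angle), s = std::sin(angle);
        float a = x[i], b = x[i + half];
        x[i]        = a * c - b * s;
        x[i + half] = a * s + b * c;
    }
}
```

Position 0 is the identity rotation, and every rotation preserves the norm of each element pair — two useful invariants when debugging a fused kernel against a reference implementation.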
Consequences
- Pro: Reduces per-layer dispatch count from 7 to 2, saving ~5 kernel launch overheads per layer per token. For a 32-layer model in chain decode (128 tokens), this eliminates 5 * 32 * 128 = 20,480 dispatch commands per chunk.
- Pro: Better memory access pattern. The fused kernel reads QKV once and writes K/V to cache in the same pass, improving cache utilization.
- Con: Two RoPE kernel variants (interleaved and NeoX) must be maintained, each with fused and standalone versions.
- Con: The fused kernel has more parameters (9 fields in `RoPEQKVWriteParams`) and more complex dispatch geometry.
ADR-5: Chain Decode
Problem
The naive approach to autoregressive decoding is: for each token, encode one forward pass into a Metal command buffer, commit it, wait for completion, read back the result, and feed it to the next step. This creates a CPU-GPU synchronization point per token, and each sync costs 30-80 microseconds of overhead.2
Options Considered
| Option | Description | Pros | Cons |
|---|---|---|---|
| A. One token per command buffer | Standard approach: encode, commit, wait, repeat | Simple, argmax result available immediately | ~50us sync overhead per token, GPU idle during CPU readback |
| B. Speculative multi-command buffer | Encode N tokens speculatively, verify in batch | Amortizes sync cost | Requires draft model or prediction, complex verification logic |
| C. Chain decode in single command buffer | Encode N identical forward passes back-to-back, patching only position and token offset. Argmax output of token i feeds as input to token i+1 via GPU buffer. | Single sync per chunk, GPU stays 100% busy | Must use greedy decoding (argmax). Non-greedy requires CPU-side sampling between tokens. |
Decision
Option C: Chain decode with single command buffer.
Rationale
The key observation is that for greedy decoding (temperature=0), the next token is determined entirely on the GPU (argmax of logits). There is no need to read the result back to the CPU between tokens. The encode_chain() function simply repeats the dispatch table N times, patching:
- `buffers[0].offset = tok * 4` for embedding lookup (reads from the token_ids buffer)
- `param.pos = start_position + tok` for RoPE and KV cache
- `param.kv_seq_len = start_position + tok + 1` for attention
The argmax kernel writes its result to token_ids[tok + 1], and the next iteration’s embedding lookup reads from that same buffer. The entire chain runs as one GPU submission with zero CPU intervention.
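The GPU-resident feedback loop can be mimicked on the CPU to make the data flow concrete. In this simulation sketch, `next_token` stands in for an entire forward pass plus argmax; the names and structure are illustrative only:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Chain-decode data flow: the argmax "kernel" writes token_ids[tok + 1],
// and the next step's embedding lookup reads token_ids[tok] — no CPU
// involvement between tokens.
std::vector<int32_t> chain_decode(int32_t first_token, int count,
                                  int32_t (*next_token)(int32_t tok, int pos),
                                  int start_pos) {
    std::vector<int32_t> token_ids(count + 1);
    token_ids[0] = first_token;
    for (int tok = 0; tok < count; ++tok)          // one GPU submission in real life
        token_ids[tok + 1] = next_token(token_ids[tok], start_pos + tok);
    return token_ids;                              // read back once, after the chunk
}
```

The crucial property is that `token_ids` lives entirely on one side of the CPU-GPU boundary: the CPU only reads the buffer after the whole chunk completes, which is why the chain requires greedy (argmax) decoding.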
The chunk size (chain_decode_chunk in ChipConfig) is tuned per chip:
| Chip Class | Chunk Size | Rationale |
|---|---|---|
| M1/M2/M3 base | 64 | Smaller GPU, less command buffer memory |
| M3 Pro | 96 | More GPU cores, but older command processor |
| M4 family, Max/Ultra | 128 | Improved command processor, higher bandwidth |
Consequences
- Pro: Eliminates ~50us of CPU sync overhead per token, plus the GPU idle gap during each readback. For a 128-token chunk the CPU-side saving alone is ~6.4ms; removing the per-token idle gaps is where the bulk of the throughput win comes from.
- Pro: GPU utilization approaches 100% within a chunk. Metal System Trace shows a continuous block of GPU activity with no idle gaps.
- Con: Only works for greedy (argmax) decoding. Non-greedy decoding (temperature > 0, top-k, top-p) requires `akunu_decode_step()` with per-token CPU-GPU synchronization.
- Con: The KV cache must be pre-sized for the full chunk. If the context window fills mid-chunk, the chain must terminate early.
- Con: Error recovery is harder – if a token generates an EOS mid-chain, the remaining tokens are wasted work.
ADR-6: GPU Gumbel-Max vs. CPU Sampling
Problem
When temperature > 0, the model needs to sample from the logit distribution rather than take the argmax. Sampling involves: (1) applying temperature scaling, (2) optionally applying repetition penalty, (3) computing probabilities (softmax), and (4) drawing a random sample. Where should this happen – CPU or GPU?
Options Considered
| Option | Description | Pros | Cons |
|---|---|---|---|
| A. CPU sampling | Read logits back from GPU, sample on CPU | Full flexibility (top-k, top-p, min-p, grammar constraints), easy to implement | Requires GPU-to-CPU data transfer of entire logit vector (vocab_size * 2 bytes), breaks chain decode |
| B. GPU argmax only | Keep greedy on GPU, fall back to CPU for non-greedy | Simple, no sampling complexity on GPU | Non-greedy is slower due to sync overhead |
| C. GPU Gumbel-max trick | Add Gumbel noise to logits on GPU, then take argmax. Mathematically equivalent to sampling from the softmax distribution.3 | Keeps everything on GPU, compatible with chain decode | Limited to temperature sampling; top-k/top-p require sorting which is expensive on GPU |
Decision
Option A (CPU sampling) for non-greedy, Option B (GPU argmax) for greedy. GPU temperature scaling and repetition penalty are provided as optional GPU-side preprocessing (via akunu_gpu_temperature_scale() and akunu_gpu_repetition_penalty()), but the actual sampling decision happens on the CPU.
Rationale
The Gumbel-max trick (Option C) was prototyped but ultimately not adopted as the default for several reasons:
1. Grammar-constrained decoding requires masking invalid tokens before sampling. The xgrammar integration runs on the CPU and produces a token bitmask. Applying this mask on the GPU would require uploading it per token, which adds overhead that negates the chain decode benefit.
2. Top-k and top-p sampling require sorting or partial sorting of the logit vector, which is not efficient on Apple’s GPU for the vocabulary sizes used by modern LLMs (32K-128K). A GPU radix sort for 128K elements would consume more time than just reading the logits back to the CPU.
3. Sampling with temperature=0 (greedy) accounts for the majority of use cases in benchmarks and many production deployments. Chain decode works perfectly for greedy mode.
4. The logit readback cost is bounded: for a 128K vocabulary at FP16, the transfer is 256 KB – well within the UMA zero-copy window on Apple Silicon. The “transfer” is really just a cache flush, not a DMA copy.
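For reference, the Gumbel-max trick that was prototyped is simple to state: adding independent Gumbel(0,1) noise to the (temperature-scaled) logits and taking the argmax yields an exact sample from the softmax distribution. A scalar CPU sketch of the idea (illustrative, not akunu's prototype kernel):

```cpp
#include <cassert>
#include <cmath>
#include <random>
#include <vector>

// Gumbel-max sampling: argmax_i (logits[i]/T + g_i) with g_i ~ Gumbel(0,1)
// is distributed exactly as a sample from softmax(logits/T).
int gumbel_max_sample(const std::vector<float>& logits, float temperature,
                      std::mt19937& rng) {
    std::uniform_real_distribution<float> uni(1e-9f, 1.0f);
    int best = 0;
    float best_score = -INFINITY;
    for (size_t i = 0; i < logits.size(); ++i) {
        float g = -std::log(-std::log(uni(rng)));        // Gumbel(0,1) noise
        float score = logits[i] / temperature + g;
        if (score > best_score) { best_score = score; best = (int)i; }
    }
    return best;
}
```

The appeal is that it turns sampling into the same argmax reduction the greedy path already runs on the GPU — but as the rationale above explains, it cannot express top-k/top-p truncation or grammar masks, which is why it was not adopted as the default.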
Consequences
- Pro: Full sampling flexibility (temperature, top-k, top-p, min-p, repetition penalty, grammar constraints) without GPU-side complexity.
- Pro: Greedy mode gets the full chain decode benefit with zero overhead.
- Con: Non-greedy decoding cannot use chain decode and pays the per-token sync cost (~50us per token).
- Con: The GPU-side temperature and repetition penalty kernels (`TemperatureScaleParams`, `RepetitionPenaltyParams`) are currently only used when the user explicitly calls the low-level API; the high-level `akunu_generate()` does sampling on the CPU.
ADR-7: Dual Format Support (GGUF + MLX SafeTensors)
Problem
The Apple Silicon LLM ecosystem has two dominant weight formats:
- GGUF: The format from llama.cpp. Block-quantized weights with rich metadata. Supported by virtually every open-source LLM tool.
- MLX SafeTensors: The format from Apple’s MLX framework. Group-quantized weights in SafeTensors container with JSON config. Growing ecosystem, especially for Apple-optimized models.
Should akunu support one or both?
Options Considered
| Option | Description | Pros | Cons |
|---|---|---|---|
| A. GGUF only | Rely on the GGUF ecosystem | Largest model library, well-understood format | Misses MLX-specific models, some newer models ship MLX-first |
| B. MLX only | Align with Apple’s own framework | Native Apple ecosystem, simpler quantization | Misses the vast GGUF library, less community tooling |
| C. Both via WeightProvider abstraction | Unified interface that wraps either backend | Access to both ecosystems | Two code paths to maintain, name mapping complexity |
Decision
Option C: Both formats via WeightProvider.
Rationale
Users should not have to choose a format. If they have a GGUF file from Hugging Face, it should work. If they have an MLX model from mlx-community, it should also work. The WeightProvider class in src/weight/weight_provider.h detects the format at load time:
- If the path is a directory or ends in `.safetensors` -> MLX format
- Otherwise -> GGUF format
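That detection rule can be sketched in a few lines. The helper name and enum below are illustrative — the real logic lives inside `WeightProvider`:

```cpp
#include <cassert>
#include <string>

enum class WeightFormat { GGUF, MLX };

// Load-time format detection: directories and .safetensors files are MLX;
// everything else is treated as GGUF.
WeightFormat detect_format(const std::string& path, bool is_directory) {
    auto ends_with = [&](const std::string& suf) {
        return path.size() >= suf.size() &&
               path.compare(path.size() - suf.size(), suf.size(), suf) == 0;
    };
    if (is_directory || ends_with(".safetensors")) return WeightFormat::MLX;
    return WeightFormat::GGUF;
}
```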
Both backends expose the same interface: get_tensor(), get_dtype(), has_tensor(), get_config(), plus metadata accessors. The GGUF backend (WeightStore) wraps the GGUF parser; the MLX backend (MLXWeightStore) wraps the SafeTensors parser plus a name mapping layer (MLX uses HuggingFace naming conventions like model.layers.{n}.self_attn.q_proj.weight, which akunu maps to canonical names like layers.{n}.attention.q.weight).
The dtype system uses internal codes: GGUF dtypes 0-30 for standard GGUF types, plus synthetic codes 99-102 for MLX quantized formats (MLX Q3, Q4, Q6, Q8). The DTypeDescriptor table in dtype_descriptor.h maps each code to the appropriate GEMV, GEMM, and embedding kernels.
Consequences
- Pro: Users can load any model from either ecosystem with the same `akunu_load_model()` call.
- Pro: The same Metal kernels are used for both formats where the quantization is compatible (e.g., FP16 weights from either format use the same `gemv_f16` kernel).
- Con: MLX quantization is group-based (typically group_size=64) while GGUF is block-based (block_size=32 for Q4_0), requiring different dequantization logic in the Metal kernels.
- Con: The MLX name mapping table (`kMLXRules` in `mlx_weight_store.h`) must be updated when new architectures use different naming conventions.
ADR-8: Fused Kernels
Problem
A transformer layer involves many small operations that are individually simple but collectively expensive due to kernel launch overhead and redundant memory traffic. For example, the FFN block in a SwiGLU model does:
1. GEMV: `gate = W_gate @ x`
2. GEMV: `up = W_up @ x`
3. Elementwise: `act = silu(gate) * up`
4. GEMV: `down = W_down @ act`
Each of those first three steps reads and writes intermediate buffers. Can we fuse them?
Options Considered
| Option | Description | Pros | Cons |
|---|---|---|---|
| A. No fusion | Separate kernels for every operation | Simple, composable, easy to debug | High kernel launch count, redundant memory traffic for intermediates |
| B. Full FFN fusion | Single kernel: gate+up GEMV -> SiLU -> down GEMV | Minimal memory traffic | Extremely complex kernel, hard to tune, loses flexibility |
| C. Selective fusion | Fuse operations where the benefit is clear: gate+up GEMV+activation, QKV+RoPE+KV-write | Good balance of performance and complexity | Must decide which fusions are worth the implementation cost |
Decision
Option C: Selective fusion.
Rationale
Akunu implements several fused kernels where profiling showed clear benefits:
| Fused Kernel | What It Fuses | Benefit |
|---|---|---|
| `gemv_q4_0_silu` | Gate GEMV + SiLU activation + Up GEMV element multiply | Eliminates 2 intermediate buffer writes/reads, 1 fewer dispatch |
| `rope_qkv_write_f16` | RoPE rotation of Q,K + KV cache write for K,V | Eliminates 3-4 separate dispatches per layer (see ADR-4) |
| `gemv_kv_*` (`GEMVKVParams`) | K or V projection GEMV + direct KV cache write | Eliminates separate KV cache copy kernel |
| `gemv_head_norm_*` (`GEMVHeadNormParams`) | GEMV output + per-head RMSNorm | Eliminates separate norm dispatch for QK-norm models |
The fused SiLU kernels (gemv_q4_0_silu, gemv_mlx_q4_silu, etc.) are particularly valuable because the activation is applied during the GEMV accumulation, before the result is written to device memory. Each thread computes silu(gate_partial) * up_partial in registers, avoiding a round-trip through device memory for the intermediate gate and up buffers.
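The fused epilogue itself is tiny — the entire point is *where* it runs, not what it computes. A scalar reference for what each thread does in registers (a sketch, not the MSL kernel):

```cpp
#include <cassert>
#include <cmath>

// SiLU (a.k.a. swish): x * sigmoid(x).
float silu(float x) { return x / (1.0f + std::exp(-x)); }

// Fused FFN epilogue: silu(gate) * up, computed in registers before any
// write to device memory — no intermediate gate/up buffer round-trip.
float fused_ffn_element(float gate_partial, float up_partial) {
    return silu(gate_partial) * up_partial;
}
```

In the unfused path, `gate` and `up` each make a full trip through device memory (write, then read back for the elementwise kernel); the fusion deletes both trips, which is exactly the bandwidth the decode hot loop is starved for.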
Weight fusion (gate+up weights concatenated into a single buffer) is a separate but related optimization. The ChipConfig::should_fuse_weights flag enables this on Pro+ chips where the SLC is large enough to benefit from reading the fused weight buffer sequentially.
Consequences
- Pro: 15-25% reduction in per-layer kernel count for decode, directly translating to lower dispatch overhead.
- Pro: Reduced memory traffic for intermediates, which matters on bandwidth-constrained base chips.
- Con: Each fused kernel variant must be written and tested for every supported dtype. The `gemv_*_silu` kernels exist for Q4_0, MLX Q3, MLX Q4, MLX Q6, and MLX Q8 – five implementations of the same fusion.
- Con: The `DTypeDescriptor` table must track both fused and unfused kernel names, adding complexity to the dtype lookup.
ADR-9: Ping-Pong Scratch Buffers
Problem
The transformer’s residual connection pattern means that each layer’s output is added to its input. The straightforward implementation allocates a new buffer for each intermediate result, but this wastes memory and forces unnecessary allocations in the hot path.
Options Considered
| Option | Description | Pros | Cons |
|---|---|---|---|
| A. Dynamic allocation | device.allocate() per forward pass for intermediates | Simple, no buffer management | Allocation in hot path, memory fragmentation, Metal allocator overhead |
| B. Single scratch buffer with offsets | One large buffer, sub-allocate with manual offset management | Minimal total memory | Complex offset bookkeeping, risk of aliasing bugs, hard to reason about lifetimes |
| C. Named ping-pong buffers | Pre-allocate fixed named buffers (h0, h1, residual, qkv, etc.) at model load. Alternate between h0 and h1 for residual accumulation. | Zero allocation in hot path, clear ownership, easy to debug | Fixed memory footprint regardless of actual usage |
Decision
Option C: Named ping-pong buffers via ScratchBuffers.
Rationale
The ScratchBuffers struct in src/cache/scratch.h pre-allocates all intermediate buffers during akunu_load_model(). The decode buffers are:
| Buffer | Size | Purpose |
|---|---|---|
| `h0` | dim * 2 bytes | Residual stream “ping” (FP16) |
| `h1` | dim * 2 bytes | Residual stream “pong” (FP16) |
| `residual` | dim * 2 bytes | Norm output |
| `qkv` | (q_dim + 2*kv_dim) * 2 bytes | Contiguous Q, K, V projections |
| `attn_out` | max(q_dim, dim) * 2 bytes | Attention output |
| `post_norm` | dim * 2 bytes | Post-norm temp (Gemma-style architectures) |
| `ffn_gate` | ffn_dim * 4 bytes | Gate projection (2x for fused gate+up) |
| `ffn_up` | ffn_dim * 2 bytes | Up projection |
| `ffn_act` | ffn_dim * 2 bytes | Activation output |
| `logits` | vocab_size * 2 bytes | Final logits |
| `token_ids` | max_chain * 4 bytes | Chain decode token buffer |
The residual connection works by alternating between h0 and h1:
Layer input: h0
-> norm: residual = rmsnorm(h0)
-> attn: attn_out = attention(qkv_proj(residual))
-> add: h1 = h0 + attn_out (residual add)
-> norm: residual = rmsnorm(h1)
-> ffn: ffn_out = down(silu(gate) * up)
-> add: h0 = h1 + ffn_out (residual add)
Layer output: h0 (same buffer as input -- full circle)
Odd-numbered layers swap the roles: input from h1, output to h0. This “ping-pong” pattern means we never need more than two hidden-state-sized buffers for the entire forward pass, regardless of model depth.
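The ping-pong data flow is easy to simulate on the host. The sketch below (hypothetical helper name, not from the akunu source) stands in attn/ffn with a constant-output block and shows that a full forward pass touches only two hidden-state buffers plus one scratch buffer, with zero allocation inside the layer loop:

```cpp
#include <vector>

// Minimal sketch of the ping-pong residual pattern. Each layer performs two
// residual adds (attention, then FFN), so the stream ends the layer in the
// same buffer it started in -- and nothing is allocated inside the loop.
float pingpong_forward(int n_layers, int dim) {
    std::vector<float> h0(dim, 1.0f), h1(dim), scratch(dim);
    auto fake_block = [&](std::vector<float>& out) {  // stands in for attn/ffn
        for (int i = 0; i < dim; ++i) out[i] = 0.5f;
    };
    for (int layer = 0; layer < n_layers; ++layer) {
        fake_block(scratch);                                      // "attention"
        for (int i = 0; i < dim; ++i) h1[i] = h0[i] + scratch[i]; // residual add
        fake_block(scratch);                                      // "FFN"
        for (int i = 0; i < dim; ++i) h0[i] = h1[i] + scratch[i]; // residual add
    }
    return h0[0];  // starting value 1.0 plus two 0.5 adds per layer
}
```

With three layers the residual stream accumulates 1.0 + 3 * 2 * 0.5 = 4.0, and the result lands back in h0, exactly as the data-flow diagram above describes.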
A separate set of batch_* buffers handles prefill, where each buffer is scaled by prefill_chunk * dim to handle batched operations.
Consequences
- Pro: Absolutely zero memory allocation during inference. Every buffer is pre-allocated and reused.
- Pro: The ScratchBuffers struct is POD-like: a flat collection of Buffer handles. The dispatch table references these buffers directly by name, making the data flow through the model explicit and debuggable.
- Pro: Memory footprint is predictable and constant. akunu_model_memory() can report the exact GPU memory usage at load time.
- Con: The memory footprint is the maximum needed for any forward pass, even if most operations use less. For example, ffn_gate is allocated at ffn_dim * 4 to support fused gate+up, even when weight fusion is disabled.
- Con: Adding a new buffer for a new operation (e.g., a second attention output for cross-attention in Whisper) requires modifying the ScratchBuffers struct and the create()/destroy() methods.
ADR-10: Head-Major KV Cache Layout
Problem
The KV cache stores the key and value vectors for all previously processed tokens. During attention, the kernel needs to read all K vectors for a given head to compute dot products with the query, then read all V vectors for the same head to compute the weighted sum. The memory layout of the KV cache determines the access pattern and thus the memory bandwidth efficiency.
Options Considered
| Option | Description | Memory Layout | Access Pattern |
|---|---|---|---|
| A. Sequence-major | [max_seq_len, n_kv_heads, head_dim] | All heads for position 0, then all heads for position 1, … | Attention reads for one head are strided across positions |
| B. Head-major | [n_kv_heads, max_seq_len, head_dim] | All positions for head 0 (contiguous), then all positions for head 1, … | Attention reads for one head are contiguous |
| C. Paged | Small fixed-size pages, indirection table | Pages allocated on demand, pages for one head may be non-contiguous | Flexible memory management, but indirection overhead |
Decision
Option B: Head-major layout [n_kv_heads, max_seq_len, head_dim].
Rationale
During decode attention, each query head computes:
scores[t] = dot(Q[head], K[head][t]) for t in 0..kv_seq_len-1
output = sum(scores[t] * V[head][t])
With head-major layout, all K vectors for a given head are contiguous in memory: K[head] is a contiguous block of max_seq_len * head_dim FP16 values. The attention kernel can read this block sequentially, which maximizes memory bandwidth utilization on Apple Silicon’s memory controller.
With sequence-major layout (Option A), reading K vectors for one head would require striding by n_kv_heads * head_dim elements between positions, resulting in poor cache line utilization – of the data you stride across, only head_dim / (n_kv_heads * head_dim) = 1/n_kv_heads is actually used.
The KV stride is precomputed as kv_stride = max_seq_len * head_dim (elements between consecutive KV heads) and stored in the KVCache struct. The AttentionParams struct passes this to the attention kernel via the kv_stride field. A value of 0 means “use kv_seq_len * head_dim” (the dense case), which is useful when the KV cache is exactly filled.
The KVCacheWriteParams struct handles writing new K/V vectors to the correct position:
offset_in_buffer = head * kv_stride + pos * head_dim
This is a simple multiply-add, computed in the fused RoPE+KV-write kernel.
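The head-stride and write-offset arithmetic above fits in one helper. This is an illustrative sketch (the function name kv_offset is assumed, not from the akunu source):

```cpp
#include <cstdint>

// Element offset of position `pos` in head `head` of a head-major KV cache
// shaped [n_kv_heads, max_seq_len, head_dim]. kv_stride = max_seq_len * head_dim.
uint64_t kv_offset(uint32_t head, uint32_t pos,
                   uint32_t max_seq_len, uint32_t head_dim) {
    uint64_t kv_stride = uint64_t(max_seq_len) * head_dim;
    return uint64_t(head) * kv_stride + uint64_t(pos) * head_dim;
}
```

Consecutive positions within one head differ by head_dim elements (contiguous reads during attention); consecutive heads differ by kv_stride. For the example dimensions from the Consequences below (max_seq_len=4096, head_dim=128), head 1 begins at element 524288.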
Consequences
- Pro: Contiguous memory access for attention reads, maximizing Apple Silicon’s memory bandwidth. This is the most important access pattern to optimize because attention cost grows linearly with context length.
- Pro: The GQA (Grouped Query Attention) pattern falls out naturally: Q heads 0..3 all read from KV head 0, which is a single contiguous block. No gather/scatter needed.
- Pro: KV cache shifting (for sliding window) is a simple memmove within each head's contiguous block. The KVCacheShiftParams struct supports this.
- Con: Memory is allocated for the full max_seq_len per head, even if the actual sequence is shorter. For a model with n_kv_heads=8, max_seq_len=4096, head_dim=128 in FP16, each layer's K cache is 8 * 4096 * 128 * 2 = 8 MB. Over 32 layers, that is 256 MB for K alone (plus 256 MB for V).
- Con: Paged attention (Option C) would use less memory for short sequences, but the indirection overhead and implementation complexity were deemed not worth it for the target use case (single-user inference on Apple Silicon with sufficient memory).
Summary: How the Decisions Fit Together
These ten decisions are not independent. They form an interlocking system:
+--------------------+
| ArchDescriptor(3) |---> drives table_builder
+--------------------+
|
v
+--------------------+ +-------------------+
| DispatchTable(1) |<--->| ScratchBuffers(9) |
+--------------------+ +-------------------+
| |
v v
+--------------------+ +-------------------+
| Chain Decode(5) | | KV Cache(10) |
+--------------------+ +-------------------+
|
v
+--------------------+
| Fused Kernels(8) |
| - RoPE+QKV(4) |
| - GEMV+SiLU |
+--------------------+
|
v
+--------------------+ +-------------------+
| C API(2) | | WeightProvider(7) |
+--------------------+ +-------------------+
| |
v v
+--------------------+ +-------------------+
| Sampling(6) | | GGUF + MLX dtypes |
+--------------------+ +-------------------+
The ArchDescriptor (3) feeds into the dispatch table builder. The dispatch table (1) references pre-allocated scratch buffers (9) and KV cache buffers (10), and embeds fused kernels (8, 4). Chain decode (5) replays the dispatch table with minimal patching. The C API (2) wraps all of this behind opaque handles. The WeightProvider (7) supplies weights in either GGUF or MLX format. And the sampling strategy (6) determines whether chain decode can be used (greedy) or falls back to per-token decode.
If you are extending akunu, this dependency graph tells you what you need to touch. Adding a new quantization format? Modify DTypeDescriptor (8/7) and add kernels. Adding a new architecture? Add an ArchDescriptor factory (3). Implementing paged attention? That affects KV cache (10), the dispatch table (1), and the attention kernel (8).
1. The ADR format was popularized by Michael Nygard. See “Documenting Architecture Decisions” (2011). The format used here is a simplified version: Problem, Options, Decision, Rationale, Consequences. See https://cognitect.com/blog/2011/11/15/documenting-architecture-decisions. ↩
2. Measured on M2 Pro: MTLCommandBuffer commit + waitUntilCompleted averages 45us when the command buffer contains a trivial kernel. The overhead is in the driver and command processor, not the GPU itself. ↩
3. The Gumbel-max trick: if g_i ~ Gumbel(0,1), then argmax(log(p_i) + g_i) ~ Categorical(p). Since log(p_i) = logit_i / temperature - log(Z), and the log(Z) term is constant across categories, argmax(logit_i / temperature + g_i) samples from the temperature-scaled distribution. ↩
Appendix A: Metal Shader Parameter Reference
This appendix is a complete reference for every parameter struct defined in backend/metal/kernels/ShaderTypes.h. These structs are shared between Metal shader code (MSL) and the C++/Swift host code. They are the contract between CPU and GPU: the host fills in the fields, binds the struct as a setBytes argument, and the shader reads from it.
Every struct in this file is padded to 16-byte alignment boundaries for Metal argument buffer compatibility. The comment at the top of ShaderTypes.h puts it plainly:
All structs are padded to 16-byte boundaries for Metal argument buffer alignment.
Any change to these structs must be mirrored in Sources/KernelStore/MetalTypes.swift to keep the Swift binding in sync.
How to Read the Tables
Each struct is documented with:
- Total size: The sizeof the struct in bytes
- Alignment: The alignment requirement (always 16 bytes for Metal compatibility)
- Field table: Name, C type, byte offset, size, and notes
Byte offsets are calculated from the struct layout assuming standard C packing rules with the explicit padding fields (_pad0, _pad1, etc.) that akunu includes. The padding fields exist to ensure the struct is a multiple of 16 bytes and that fields after padding land on natural alignment boundaries.
Fields marked with “patched per-token” are dynamically modified during chain decode – the DispatchCmd::patch_type mechanism overwrites these fields at specific byte offsets for each token in the batch.
GEMMParams
General matrix-matrix multiplication parameters. Used by all simd_gemm_* and simd_gemm_small_* kernels during prefill.
Total size: 32 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
M | uint32_t | 0 | 4 | Rows of A / rows of output C |
N | uint32_t | 4 | 4 | Columns of B / columns of output C |
K | uint32_t | 8 | 4 | Columns of A / rows of B (contraction dimension) |
lda | uint32_t | 12 | 4 | Leading dimension of A (typically K for row-major) |
ldb | uint32_t | 16 | 4 | Leading dimension of B (typically N for row-major) |
ldc | uint32_t | 20 | 4 | Leading dimension of C (typically N for row-major) |
alpha | float | 24 | 4 | Scale factor: C = alpha * A @ B + beta * C |
beta | float | 28 | 4 | Accumulation factor: C = alpha * A @ B + beta * C |
Notes: The leading dimension fields (lda, ldb, ldc) allow non-contiguous matrix views. When matrices are contiguous row-major, lda = K, ldb = N, ldc = N. The alpha/beta fields support BLAS-style C = alpha*A@B + beta*C but in practice akunu always uses alpha=1.0, beta=0.0 (pure multiply, no accumulation).
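The documented offsets can be verified at compile time. The struct below is a reconstruction from the table above, not a copy of ShaderTypes.h; field names mirror the table, and the static_asserts encode the documented layout:

```cpp
#include <cstddef>
#include <cstdint>

// Host-side reconstruction of GEMMParams from the reference table.
// Verify against the real ShaderTypes.h before relying on it.
struct GEMMParams {
    uint32_t M, N, K;        // problem dimensions
    uint32_t lda, ldb, ldc;  // leading dimensions (row-major: K, N, N)
    float alpha, beta;       // C = alpha * A @ B + beta * C
};

static_assert(sizeof(GEMMParams) == 32, "GEMMParams must be 32 bytes");
static_assert(offsetof(GEMMParams, lda) == 12, "lda at byte 12");
static_assert(offsetof(GEMMParams, alpha) == 24, "alpha at byte 24");
static_assert(offsetof(GEMMParams, beta) == 28, "beta at byte 28");
```

A check like this is cheap insurance when the same layout must hold across C++, Swift, and MSL compilation contexts.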
ElementwiseParams
Parameters for element-wise kernels (add, multiply, activation functions applied to flat buffers).
Total size: 16 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
count | uint32_t | 0 | 4 | Total number of elements to process |
_pad0 | uint32_t | 4 | 4 | Padding (unused) |
_pad1 | uint32_t | 8 | 4 | Padding (unused) |
_pad2 | uint32_t | 12 | 4 | Padding (unused) |
Notes: 12 bytes of explicit padding bring the struct to the 16-byte minimum size. The kernel dispatches ceil(count / threadgroup_size) threadgroups. Each thread processes one element at index thread_position_in_grid.
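The threadgroup count is the usual round-up division; a one-line sketch (helper name assumed for illustration):

```cpp
#include <cstdint>

// Number of threadgroups for a one-thread-per-element elementwise dispatch:
// ceil(count / threadgroup_size), computed in integer arithmetic.
uint32_t threadgroups_for(uint32_t count, uint32_t threadgroup_size) {
    return (count + threadgroup_size - 1) / threadgroup_size;
}
```

Threads in the final, partially filled threadgroup must bounds-check against count before touching memory, which is why the struct carries the exact element count rather than the padded grid size.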
AttentionParams
Parameters for the attention kernel. Handles both prefill (multi-token) and decode (single token) attention, including GQA (Grouped Query Attention) where n_kv_heads < n_heads.
Total size: 32 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
seq_len | uint32_t | 0 | 4 | Query sequence length (1 for decode, N for prefill) |
kv_seq_len | uint32_t | 4 | 4 | KV cache length (may differ from seq_len during decode). Patched per-token in chain decode. |
head_dim | uint32_t | 8 | 4 | Dimension per attention head |
n_heads | uint32_t | 12 | 4 | Number of query heads |
n_kv_heads | uint32_t | 16 | 4 | Number of key/value heads (GQA: n_kv_heads <= n_heads) |
scale | float | 20 | 4 | Attention scale factor: 1.0 / sqrt(head_dim) |
kv_stride | uint32_t | 24 | 4 | Elements between KV heads in cache: max_seq_len * head_dim. 0 = use kv_seq_len * head_dim |
q_stride | uint32_t | 28 | 4 | Elements between Q/O rows. 0 = n_heads * head_dim (contiguous) |
Notes: The kv_stride field encodes the head-major KV cache layout. For a cache shaped [n_kv_heads, max_seq_len, head_dim], the stride between heads is max_seq_len * head_dim elements. When kv_stride = 0, the kernel computes the stride from kv_seq_len * head_dim, which is the dense (no padding) case. The kv_seq_len field is patched per-token during chain decode using PATCH_KV_SEQ_LEN or PATCH_POS_AND_KV.
RMSNormParams
Parameters for RMSNorm (Root Mean Square Layer Normalization).1
Total size: 16 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
dim | uint32_t | 0 | 4 | Vector dimension to normalize |
eps | float | 4 | 4 | Epsilon for numerical stability (typically 1e-5 or 1e-6) |
_pad0 | uint32_t | 8 | 4 | Padding (unused) |
_pad1 | uint32_t | 12 | 4 | Padding (unused) |
Notes: The kernel computes rms = sqrt(sum(x[i]^2) / dim + eps) then out[i] = x[i] / rms * weight[i]. The dim field must match the weight buffer length. The threadgroup reduces to compute the sum-of-squares, then each thread normalizes its element.
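A scalar reference for the same math makes the kernel easy to check against. This is a host-side sketch of the formula above, not the Metal kernel (which performs the sum-of-squares as a threadgroup reduction):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar RMSNorm reference: rms = sqrt(sum(x^2) / dim + eps),
// out[i] = x[i] / rms * weight[i].
std::vector<float> rmsnorm(const std::vector<float>& x,
                           const std::vector<float>& weight, float eps) {
    float ss = 0.0f;
    for (float v : x) ss += v * v;
    float rms = std::sqrt(ss / x.size() + eps);
    std::vector<float> out(x.size());
    for (size_t i = 0; i < x.size(); ++i) out[i] = x[i] / rms * weight[i];
    return out;
}
```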
LayerNormParams
Parameters for standard LayerNorm (used by Whisper).
Total size: 16 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
dim | uint32_t | 0 | 4 | Vector dimension to normalize |
eps | float | 4 | 4 | Epsilon for numerical stability |
_pad0 | uint32_t | 8 | 4 | Padding (unused) |
_pad1 | uint32_t | 12 | 4 | Padding (unused) |
Notes: Identical layout to RMSNormParams. The kernel is different: it computes mean and variance, then normalizes as (x - mean) / sqrt(var + eps) * weight + bias. The bias buffer is an additional binding not captured in the params struct.
SoftmaxParams
Parameters for the standalone softmax kernel.
Total size: 16 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
rows | uint32_t | 0 | 4 | Number of rows (independent softmax operations) |
cols | uint32_t | 4 | 4 | Number of columns (softmax dimension per row) |
_pad0 | uint32_t | 8 | 4 | Padding (unused) |
_pad1 | uint32_t | 12 | 4 | Padding (unused) |
Notes: Each row is an independent softmax: out[r][c] = exp(x[r][c] - max_r) / sum_r(exp(x[r][c] - max_r)). Used for the final softmax in attention during prefill. During decode, the softmax is typically fused into the attention kernel.
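The per-row formula, including the max subtraction for numerical stability, reads as follows in a scalar host-side sketch (illustrative only; the kernel parallelizes the reductions across a threadgroup):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// One softmax row: subtract the row max before exponentiating so that
// exp() never overflows, then normalize by the sum.
std::vector<float> softmax_row(const std::vector<float>& x) {
    float m = *std::max_element(x.begin(), x.end());
    float sum = 0.0f;
    std::vector<float> out(x.size());
    for (size_t i = 0; i < x.size(); ++i) {
        out[i] = std::exp(x[i] - m);
        sum += out[i];
    }
    for (float& v : out) v /= sum;
    return out;
}
```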
RoPEParams
Parameters for the standalone RoPE (Rotary Position Embedding) kernel. Used during prefill when the fused RoPE+QKV+KV-write kernel is not applicable.
Total size: 32 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
seq_len | uint32_t | 0 | 4 | Number of positions to rotate |
head_dim | uint32_t | 4 | 4 | Dimension per head (rotation applies to pairs) |
n_heads | uint32_t | 8 | 4 | Number of heads in the input tensor |
pos_offset | uint32_t | 12 | 4 | Global position offset for decode step. Patched per-token. |
theta | float | 16 | 4 | RoPE base frequency (default 10000.0) |
row_stride | uint32_t | 20 | 4 | Elements between rows. 0 = n_heads * head_dim (contiguous) |
_pad0 | uint32_t | 24 | 4 | Padding (unused) |
_pad1 | uint32_t | 28 | 4 | Padding (unused) |
Notes: RoPE rotates dimension pairs (2i, 2i+1) by angle pos * theta^(-2i/head_dim). Two kernel variants exist: rope_f16 (interleaved, LLaMA-style) and rope_neox_f16 (split-half, NeoX-style where the first head_dim/2 elements are the “real” part and the second half is the “imaginary” part). The ArchDescriptor::rope_standalone field selects which variant to use.
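The interleaved (LLaMA-style) rotation can be written out directly from the formula above. A scalar sketch for one head (host-side reference, not the rope_f16 kernel itself):

```cpp
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Interleaved RoPE on one head: pair (2i, 2i+1) is rotated by the angle
// pos * theta^(-2i/head_dim). pos = 0 is the identity rotation.
void rope_interleaved(std::vector<float>& head, uint32_t pos, float theta) {
    const size_t head_dim = head.size();
    for (size_t i = 0; i < head_dim / 2; ++i) {
        float angle = pos * std::pow(theta, -2.0f * float(i) / float(head_dim));
        float c = std::cos(angle), s = std::sin(angle);
        float x0 = head[2 * i], x1 = head[2 * i + 1];
        head[2 * i]     = x0 * c - x1 * s;
        head[2 * i + 1] = x0 * s + x1 * c;
    }
}
```

The NeoX variant applies the same per-pair rotation but pairs element i with element i + head_dim/2 instead of its interleaved neighbor.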
EmbeddingParams
Parameters for the token embedding lookup kernel.
Total size: 16 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
num_tokens | uint32_t | 0 | 4 | Number of tokens to look up (1 for decode, N for prefill) |
dim | uint32_t | 4 | 4 | Embedding dimension per token |
_pad0 | uint32_t | 8 | 4 | Padding (unused) |
_pad1 | uint32_t | 12 | 4 | Padding (unused) |
Notes: The kernel reads token IDs from a uint32 buffer and writes the corresponding embedding rows to the output buffer. For quantized embeddings (Q4_0, Q4_K, Q6_K, Q8_0, MLX formats), a specialized dequantizing embedding kernel is used that dequantizes on the fly and outputs FP16. The kernel name is selected by dtype_descriptor.h::embedding_kernel_for().
KVCacheWriteParams
Parameters for writing new key/value vectors into the KV cache.
Total size: 32 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
n_kv_heads | uint32_t | 0 | 4 | Number of KV heads |
head_dim | uint32_t | 4 | 4 | Dimension per head |
max_seq_len | uint32_t | 8 | 4 | Maximum sequence length (cache dimension) |
pos | uint32_t | 12 | 4 | Write position (single-token) or batch offset. Patched per-token. |
src_stride | uint32_t | 16 | 4 | Elements between rows in source. 0 = n_kv_heads * head_dim |
seq_len | uint32_t | 20 | 4 | Batch sequence length (1 for decode, N for prefill batch) |
_pad0 | uint32_t | 24 | 4 | Padding (unused) |
_pad1 | uint32_t | 28 | 4 | Padding (unused) |
Notes: Writes into the head-major KV cache at cache[head][pos][0..head_dim-1]. The destination offset is computed as head * max_seq_len * head_dim + pos * head_dim. For batch writes during prefill, seq_len > 1 and the kernel writes seq_len consecutive positions starting at pos.
RoPEQKVWriteParams
Parameters for the fused kernel that applies RoPE rotation to Q and K, then writes K and V into the KV cache. This is the most complex parameter struct and the workhorse of the decode path.
Total size: 36 bytes | Alignment: 16 bytes (padded to 48 bytes in practice due to Metal alignment)
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
n_kv_heads | uint32_t | 0 | 4 | Number of KV heads |
head_dim | uint32_t | 4 | 4 | Dimension per head |
max_seq_len | uint32_t | 8 | 4 | KV cache max sequence length |
pos | uint32_t | 12 | 4 | Current sequence position. Patched per-token. |
theta | float | 16 | 4 | RoPE base frequency |
n_heads | uint32_t | 20 | 4 | Number of Q heads |
k_elem_offset | uint32_t | 24 | 4 | Element offset to K section in QKV buffer |
v_elem_offset | uint32_t | 28 | 4 | Element offset to V section in QKV buffer |
freq_scale | float | 32 | 4 | Linear RoPE scaling: 1/factor (1.0 = no scaling) |
Notes: This struct drives the fused rope_qkv_write_f16 and rope_neox_qkv_write_f16 kernels. The input is the contiguous QKV buffer [q_dim + 2*kv_dim] output by the QKV GEMV. The kernel:
- Reads Q elements, applies RoPE rotation, writes back to Q section (in-place)
- Reads K elements, applies RoPE rotation, writes to KV cache K buffer at pos
- Reads V elements (no rotation), writes to KV cache V buffer at pos
The k_elem_offset and v_elem_offset fields tell the kernel where K and V start in the QKV buffer. For a model with q_dim=4096, kv_dim=1024: k_elem_offset = 4096, v_elem_offset = 4096 + 1024 = 5120. The freq_scale field supports extended context via linear RoPE scaling (e.g., freq_scale = 0.25 for 4x context extension).
KVCacheShiftParams
Parameters for shifting the KV cache contents (sliding window eviction).
Total size: 32 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
n_kv_heads | uint32_t | 0 | 4 | Number of KV heads |
head_dim | uint32_t | 4 | 4 | Dimension per head |
max_seq_len | uint32_t | 8 | 4 | Cache max sequence length |
shift | uint32_t | 12 | 4 | Number of positions to shift left (evict oldest) |
new_len | uint32_t | 16 | 4 | New sequence length after shift |
_pad0 | uint32_t | 20 | 4 | Padding (unused) |
_pad1 | uint32_t | 24 | 4 | Padding (unused) |
_pad2 | uint32_t | 28 | 4 | Padding (unused) |
Notes: Shifts cache contents by shift positions. Entries at positions [shift, shift+new_len) are moved to [0, new_len). This is a per-head memmove operation. Used when the KV cache fills up and the oldest tokens need to be evicted.
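Because the cache is head-major, the shift is an independent memmove per head. A host-side sketch of the same operation (the real version runs as a kernel on the GPU buffer):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Sliding-window eviction: within each head's contiguous block of
// max_seq_len * head_dim elements, move positions [shift, shift + new_len)
// down to [0, new_len). memmove handles the overlapping ranges.
void kv_shift(std::vector<float>& cache, uint32_t n_kv_heads,
              uint32_t head_dim, uint32_t max_seq_len,
              uint32_t shift, uint32_t new_len) {
    const size_t head_block = size_t(max_seq_len) * head_dim;
    for (uint32_t h = 0; h < n_kv_heads; ++h) {
        float* base = cache.data() + h * head_block;
        std::memmove(base, base + size_t(shift) * head_dim,
                     size_t(new_len) * head_dim * sizeof(float));
    }
}
```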
HeadNormParams
Parameters for per-head RMSNorm, used by architectures with QK normalization (Qwen3, Gemma).
Total size: 32 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
head_dim | uint32_t | 0 | 4 | Elements per head |
n_heads | uint32_t | 4 | 4 | Number of heads to normalize |
seq_len | uint32_t | 8 | 4 | Sequence length (1 for decode, N for prefill) |
eps | float | 12 | 4 | Norm epsilon |
_pad0 | uint32_t | 16 | 4 | Padding (unused) |
_pad1 | uint32_t | 20 | 4 | Padding (unused) |
_pad2 | uint32_t | 24 | 4 | Padding (unused) |
_pad3 | uint32_t | 28 | 4 | Padding (unused) |
Notes: Unlike standard RMSNorm which normalizes the entire hidden vector, per-head norm applies RMSNorm independently to each head_dim-sized slice. For Q normalization: Q[h] = rmsnorm(Q[h], q_norm_weight) for each head h. Each threadgroup handles one head.
GEMVHeadNormParams
Parameters for the fused GEMV + per-head RMSNorm kernel. Combines a matrix-vector multiply with per-head normalization in a single dispatch.
Total size: 32 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
N | uint32_t | 0 | 4 | Total output dimension (n_heads * head_dim) |
K | uint32_t | 4 | 4 | Input dimension |
head_dim | uint32_t | 8 | 4 | Elements per head |
n_heads | uint32_t | 12 | 4 | Number of heads |
eps | float | 16 | 4 | Norm epsilon |
_pad0 | uint32_t | 20 | 4 | Padding (unused) |
_pad1 | uint32_t | 24 | 4 | Padding (unused) |
_pad2 | uint32_t | 28 | 4 | Padding (unused) |
Notes: Each threadgroup computes one output row (one head’s worth of elements) via GEMV, then applies RMSNorm to the result. This fuses Q = W_q @ x; Q[h] = rmsnorm(Q[h]) into a single kernel, eliminating the intermediate write of the un-normalized Q vector.
TemperatureScaleParams
Parameters for applying temperature scaling to the logits buffer.
Total size: 16 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
inv_temperature | float | 0 | 4 | Inverse temperature: 1.0 / temperature. Logits are multiplied by this value. |
count | uint32_t | 4 | 4 | Number of logit elements (vocabulary size) |
_pad0 | uint32_t | 8 | 4 | Padding (unused) |
_pad1 | uint32_t | 12 | 4 | Padding (unused) |
Notes: Applies logits[i] *= inv_temperature in-place. Using inverse temperature (multiply instead of divide) avoids a per-element division in the shader. Called via akunu_gpu_temperature_scale() in the C API.
RepetitionPenaltyParams
Parameters for applying repetition penalty to logits.
Total size: 16 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
penalty | float | 0 | 4 | Repetition penalty factor (>1.0 penalizes, <1.0 encourages) |
n_tokens | uint32_t | 4 | 4 | Number of token IDs in the penalty list |
_pad0 | uint32_t | 8 | 4 | Padding (unused) |
_pad1 | uint32_t | 12 | 4 | Padding (unused) |
Notes: For each token in the penalty list: if logit > 0, divide by penalty; if logit < 0, multiply by penalty. The token ID list is passed as a separate buffer binding. Called via akunu_gpu_repetition_penalty() in the C API.
MLXParams
Parameters for MLX-format quantized GEMV and embedding kernels. MLX uses group quantization: weights are packed with bits-per-value in groups of group_size, with FP16 scale and bias per group.
Total size: 32 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
M | uint32_t | 0 | 4 | Batch size (1 for GEMV, num_tokens for batch embedding) |
N | uint32_t | 4 | 4 | Output dimension (weight rows / vocab_size) |
K | uint32_t | 8 | 4 | Input dimension (unpacked element count) |
group_size | uint32_t | 12 | 4 | Quantization group size (typically 64) |
bits | uint32_t | 16 | 4 | Bits per quantized value (3, 4, 6, or 8) |
weight_bytes | uint32_t | 20 | 4 | Byte offset to scales section within weight buffer |
_pad0 | uint32_t | 24 | 4 | Padding (unused) |
_pad1 | uint32_t | 28 | 4 | Padding (unused) |
Notes: The MLX weight buffer layout is [packed_weights | scales | biases]. The weight_bytes field gives the byte offset where scales begin. Biases follow immediately after scales. The packed weight format differs by bit-width: 4-bit packs 8 values per uint32, 3-bit packs values with a more complex scheme, 8-bit uses one byte per value. The dequantization formula is affine per group: value = scale * packed_int + bias.
Conv1DParams
Parameters for 1D convolution, used by Whisper’s audio frontend.
Total size: 32 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
in_channels | uint32_t | 0 | 4 | Number of input channels |
out_channels | uint32_t | 4 | 4 | Number of output channels (filters) |
kernel_size | uint32_t | 8 | 4 | Convolution kernel width |
stride | uint32_t | 12 | 4 | Stride between convolution windows |
in_length | uint32_t | 16 | 4 | Input sequence length |
out_length | uint32_t | 20 | 4 | Output sequence length: (in_length + 2*padding - kernel_size) / stride + 1 |
padding | uint32_t | 24 | 4 | Zero-padding on each side of input |
_pad0 | uint32_t | 28 | 4 | Padding (unused) |
Notes: Used for Whisper’s two-layer Conv1D frontend that processes mel spectrograms. The first conv has kernel_size=3, stride=1, padding=1; the second has kernel_size=3, stride=2, padding=1, downsampling the time dimension by 2x.
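The out_length formula from the table is worth checking against the frontend described above. A sketch, assuming Whisper's usual 3000-frame mel input (30 s of audio):

```cpp
#include <cstdint>

// Standard 1D convolution output length:
// (in_length + 2*padding - kernel_size) / stride + 1
uint32_t conv1d_out_length(uint32_t in_length, uint32_t kernel_size,
                           uint32_t stride, uint32_t padding) {
    return (in_length + 2 * padding - kernel_size) / stride + 1;
}
```

For the Whisper frontend: the first conv (kernel_size=3, stride=1, padding=1) preserves the length, and the second (kernel_size=3, stride=2, padding=1) halves it, so 3000 input frames become 1500 encoder positions.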
GEMVKVParams
Parameters for the fused GEMV + KV cache write kernel. Combines a projection GEMV with direct KV cache insertion.
Total size: 32 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
N | uint32_t | 0 | 4 | Output dimension (n_kv_heads * head_dim) |
K | uint32_t | 4 | 4 | Input dimension |
head_dim | uint32_t | 8 | 4 | Dimension per head |
max_seq_len | uint32_t | 12 | 4 | KV cache max sequence length |
pos | uint32_t | 16 | 4 | Write position in cache. Patched per-token. |
_pad0 | uint32_t | 20 | 4 | Padding (unused) |
_pad1 | uint32_t | 24 | 4 | Padding (unused) |
_pad2 | uint32_t | 28 | 4 | Padding (unused) |
Notes: Each threadgroup computes a slice of the GEMV output and writes it directly to the KV cache buffer at the correct position, bypassing the intermediate scratch buffer. The cache write offset is (row / head_dim) * max_seq_len * head_dim + pos * head_dim + (row % head_dim), implementing the head-major layout.
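The offset expression in the note decomposes the GEMV row index into a head index and a lane within the head. As a standalone helper (name assumed for illustration):

```cpp
#include <cstdint>

// Head-major cache offset for GEMV output row `row`:
// (row / head_dim) selects the head, (row % head_dim) the lane within it.
uint64_t gemv_kv_offset(uint32_t row, uint32_t head_dim,
                        uint32_t max_seq_len, uint32_t pos) {
    uint32_t head = row / head_dim;
    uint32_t lane = row % head_dim;
    return uint64_t(head) * max_seq_len * head_dim
         + uint64_t(pos) * head_dim + lane;
}
```

Rows 0..head_dim-1 land contiguously at position pos of head 0, rows head_dim..2*head_dim-1 at position pos of head 1, and so on, which is exactly the head-major layout from ADR-10.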
Quick Reference Summary
| Struct | Size (bytes) | Used By | Patched Fields |
|---|---|---|---|
GEMMParams | 32 | simd_gemm_* (prefill) | None |
ElementwiseParams | 16 | Elementwise ops (add, residual) | None |
AttentionParams | 32 | Attention kernels | kv_seq_len |
RMSNormParams | 16 | rmsnorm_* | None |
LayerNormParams | 16 | layernorm_* (Whisper) | None |
SoftmaxParams | 16 | softmax_* | None |
RoPEParams | 32 | rope_f16, rope_neox_f16 | pos_offset |
EmbeddingParams | 16 | embedding_lookup_* | None |
KVCacheWriteParams | 32 | KV cache write kernels | pos |
RoPEQKVWriteParams | 36 | rope_qkv_write_f16, rope_neox_qkv_write_f16 | pos |
KVCacheShiftParams | 32 | KV cache shift kernel | None |
HeadNormParams | 32 | Per-head RMSNorm | None |
GEMVHeadNormParams | 32 | Fused GEMV + head norm | None |
TemperatureScaleParams | 16 | Temperature scaling | None |
RepetitionPenaltyParams | 16 | Repetition penalty | None |
MLXParams | 32 | MLX quantized GEMV/embedding | None |
Conv1DParams | 32 | Conv1D (Whisper frontend) | None |
GEMVKVParams | 32 | Fused GEMV + KV write | pos |
Alignment and Padding Rules
All structs follow these conventions:
- 16-byte total size minimum: Every struct is at least 16 bytes. Metal’s setBytes requires 16-byte aligned data for argument buffer compatibility.
- Explicit padding fields: Rather than relying on compiler-inserted padding (which varies across compilers and platforms), akunu includes explicit _pad0, _pad1, etc. fields. This makes the layout identical whether compiled as C, C++, Objective-C++, or MSL.
- Natural alignment: All uint32_t fields are at 4-byte-aligned offsets. All float fields are at 4-byte-aligned offsets. No field crosses a natural alignment boundary.
- Cross-language compatibility: The same header (ShaderTypes.h) is included by both Metal shaders (via #ifdef __METAL_VERSION__) and host code. The #ifdef guard switches between <metal_stdlib> and <simd/simd.h> for type compatibility, but the struct layouts are identical in both compilation contexts.
When adding a new parameter struct, follow the pattern: use uint32_t and float fields only, pad to a multiple of 16 bytes with explicit _padN fields, and add a // MARK: section comment with the struct size.
1. Zhang, B. & Sennrich, R. (2019). “Root Mean Square Layer Normalization.” NeurIPS 2019. RMSNorm computes x / sqrt(mean(x^2) + eps) * weight, omitting the mean subtraction of standard LayerNorm. See https://arxiv.org/abs/1910.07467. ↩
Appendix B: Quantization Format Reference
This appendix is a quick-reference for every quantization format supported by akunu. It covers the GGUF block-quantized formats (from the ggml/llama.cpp ecosystem) and the MLX group-quantized formats (from Apple’s MLX framework). For each format, you get the block/group size, bytes per block, effective bits per weight, and the dequantization formula.
If you want the why behind these formats, see Chapter 7 (Quantization). This appendix is the what – a lookup table you can keep open while reading kernel code or debugging weight loading.
How to Read the Tables
- Block size: Number of weights packed together as a unit. GGUF formats use fixed block sizes (32 or 256). MLX formats use configurable group sizes (typically 64).
- Bytes per block: Total storage for one block, including quantized values, scales, and any auxiliary data.
- Bits per weight (bpw): Effective bits per weight element, computed as 8 * bytes_per_block / block_size. This is the number that determines model file size.
- Dequant formula: How to reconstruct the floating-point value from the quantized representation.
GGUF Formats: Basic Quantization
These formats use a simple scheme: each block of weights shares one or two floating-point parameters (scale and optional minimum/zero-point).
| Format | GGUF Code | Block Size | Bytes/Block | bpw | Dequant Formula |
|---|---|---|---|---|---|
| F32 | 0 | 1 | 4 | 32.0 | value = raw_f32 |
| F16 | 1 | 1 | 2 | 16.0 | value = raw_f16 |
| Q4_0 | 2 | 32 | 18 | 4.5 | value = d * (q[i] - 8) where q[i] is a 4-bit unsigned int, d is FP16 scale |
| Q4_1 | 3 | 32 | 20 | 5.0 | value = d * q[i] + m where d is FP16 scale, m is FP16 minimum |
| Q5_0 | 6 | 32 | 22 | 5.5 | value = d * (q[i] - 16) where q[i] is a 5-bit unsigned int (4 low bits packed + 1 high bit), d is FP16 scale |
| Q8_0 | 8 | 32 | 34 | 8.5 | value = d * q[i] where q[i] is a signed 8-bit int, d is FP16 scale |
| BF16 | 30 | 1 | 2 | 16.0 | value = raw_bf16 (Brain Float 16: 8-bit exponent, 7-bit mantissa) |
Block Layout Details
Q4_0 (most common quantization in the GGUF ecosystem):
struct block_q4_0 { // 18 bytes total
half d; // 2 bytes: scale factor
uint8_t qs[16]; // 16 bytes: 32 x 4-bit values, packed in pairs
}; // bpw = 18*8/32 = 4.5
Each byte in qs holds two 4-bit values: the low nibble is element 2i, the high nibble is element 2i+1. Dequantization extracts the nibble, subtracts 8 (to center around zero), and multiplies by the scale d.
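A reference dequantizer for one block, following the nibble pairing described above (the low nibble of byte i is element 2i, the high nibble element 2i+1). The helper name is illustrative, and d is assumed to be the FP16 scale already converted to float:

```cpp
#include <cstdint>
#include <vector>

// Dequantize one Q4_0 block of 32 weights: extract each 4-bit value,
// re-center by subtracting 8, and apply the shared scale d.
std::vector<float> dequant_q4_0(float d, const uint8_t qs[16]) {
    std::vector<float> out(32);
    for (int i = 0; i < 16; ++i) {
        out[2 * i]     = d * float((qs[i] & 0x0F) - 8);  // low nibble
        out[2 * i + 1] = d * float((qs[i] >> 4) - 8);    // high nibble
    }
    return out;
}
```

Note that a nibble value of 8 dequantizes to exactly 0.0: the subtract-8 re-centering means the representable range is d * [-8, 7].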
Q4_1 (asymmetric variant of Q4_0):
struct block_q4_1 { // 20 bytes total
half d; // 2 bytes: delta (scale)
half m; // 2 bytes: minimum
uint8_t qs[16]; // 16 bytes: 32 x 4-bit values, packed in pairs
}; // bpw = 20*8/32 = 5.0
The extra m (minimum) parameter means values are dequantized as d * q + m instead of d * (q - 8). This gives better accuracy when the weight distribution is not symmetric around zero.
Q5_0 (5-bit with high-bit extension):
struct block_q5_0 { // 22 bytes total
half d; // 2 bytes: scale
uint8_t qh[4]; // 4 bytes: 5th bit for each of 32 elements
uint8_t qs[16]; // 16 bytes: lower 4 bits, packed in pairs
}; // bpw = 22*8/32 = 5.5
The 5th bit for each element is stored separately in qh (packed as a uint32). To dequantize element i: extract the 4-bit value from qs, extract bit i from qh, combine to get a 5-bit unsigned int, subtract 16, multiply by d.
Q8_0 (8-bit, highest quality block quant):
struct block_q8_0 { // 34 bytes total
half d; // 2 bytes: scale
int8_t qs[32]; // 32 bytes: signed 8-bit values
}; // bpw = 34*8/32 = 8.5
Simple and fast to dequantize: value = d * qs[i]. The 0.5 extra bpw overhead comes from the FP16 scale shared across 32 elements.
GGUF Formats: K-Quantization
K-quant formats use a two-level quantization scheme with super-blocks of 256 elements. Each super-block contains sub-blocks with their own scales, plus a super-block-level scale that controls the magnitude of the sub-block scales. This hierarchical approach gives better accuracy at the same bit width compared to basic formats.1
| Format | GGUF Code | Block Size | Bytes/Block | bpw | Description |
|---|---|---|---|---|---|
| Q2_K | 10 | 256 | 84 | 2.625 | 2-bit values + 4-bit scale/min per 16-element sub-block + super-block scale |
| Q3_K | 11 | 256 | 110 | 3.4375 | 2-bit base + 1 high bit + 6-bit packed scales + super-block scale |
| Q4_K | 12 | 256 | 144 | 4.5 | 4-bit values + 6-bit scales/mins + super-block scale |
| Q5_K | 13 | 256 | 176 | 5.5 | 4-bit base + 1 high bit + 6-bit scales/mins + super-block scale |
| Q6_K | 14 | 256 | 210 | 6.5625 | 4-bit low + 2-bit high (6-bit total) + 8-bit scales + super-block scale |
K-Quant Block Layouts
Q4_K (the most popular K-quant for production use):
struct block_q4_K { // 144 bytes total
half d; // 2 bytes: super-block scale for quants
half dmin; // 2 bytes: super-block scale for mins
uint8_t scales[12]; // 12 bytes: 8 x 6-bit scales + 8 x 6-bit mins
uint8_t qs[128]; // 128 bytes: 256 x 4-bit values, nibble-packed
}; // bpw = 144*8/256 = 4.5
The 256-element super-block is divided into 8 sub-blocks of 32 elements each. Each sub-block has a 6-bit scale and a 6-bit minimum, packed into 12 bytes. Dequantization for element i in sub-block j:
value = d * scale_j * q[i] - dmin * min_j
The get_scale_min_k4() helper in KernelCommon.h unpacks the 6-bit scale and minimum from the packed 12-byte scales array.
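The 12-byte packing follows the llama.cpp convention: sub-blocks 0-3 store their 6-bit scale and min in the low bits of bytes 0-3 and 4-7, and sub-blocks 4-7 spill their upper two bits into the high bits of those same bytes. A CPU-side sketch of what a helper like get_scale_min_k4() unpacks, assuming that convention:

```cpp
#include <cassert>
#include <cstdint>

// Unpack the 6-bit scale and min for sub-block j (0..7) from the packed
// 12-byte scales array, assuming the llama.cpp Q4_K packing convention.
inline void get_scale_min_ref(const uint8_t s[12], int j,
                              uint8_t& sc, uint8_t& mn) {
    if (j < 4) {
        sc = s[j] & 63;        // low 6 bits of bytes 0..3
        mn = s[j + 4] & 63;    // low 6 bits of bytes 4..7
    } else {
        sc = uint8_t((s[j + 4] & 0x0F) | ((s[j - 4] >> 6) << 4));
        mn = uint8_t((s[j + 4] >> 4)   | ((s[j]     >> 6) << 4));
    }
}
```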
Q3_K (aggressive 3-bit quantization):
struct block_q3_K { // 110 bytes total
uint8_t hmask[32]; // 32 bytes: high bit for each of 256 elements
uint8_t qs[64]; // 64 bytes: lower 2 bits packed (4 per byte)
uint8_t scales[12]; // 12 bytes: 16 x signed 6-bit scales
half d; // 2 bytes: super-block scale
}; // bpw = 110*8/256 = 3.4375
Each element has 3 bits: 2 bits from qs and 1 bit from hmask. The 16 sub-blocks (16 elements each) have signed 6-bit scales packed into 12 bytes. The get_scale_q3_k() helper unpacks these.
Q2_K (extreme 2-bit quantization):
struct block_q2_K { // 84 bytes total
uint8_t scales[16]; // 16 bytes: 4-bit scale + 4-bit min per sub-block
uint8_t qs[64]; // 64 bytes: 2-bit values (4 per byte)
half d; // 2 bytes: super-block scale
half dmin; // 2 bytes: super-block min scale
}; // bpw = 84*8/256 = 2.625
Q6_K (high-quality 6-bit):
struct block_q6_K { // 210 bytes total
uint8_t ql[128]; // 128 bytes: lower 4 bits of 6-bit quants
uint8_t qh[64]; // 64 bytes: upper 2 bits of 6-bit quants
int8_t scales[16]; // 16 bytes: signed 8-bit sub-block scales
half d; // 2 bytes: super-block scale
}; // bpw = 210*8/256 = 6.5625
MLX Formats: Group Quantization
MLX uses a simpler group quantization scheme. Weights are divided into groups (typically 64 elements), and each group has an FP16 scale and FP16 bias (zero-point). The packed weight buffer layout is:
[packed_weights | scales | biases]
The MLXParams.weight_bytes field gives the byte offset where scales begin. Biases follow immediately.
| Format | Internal Code | Group Size | Bits | bpw | Dequant Formula |
|---|---|---|---|---|---|
| MLX Q3 | 99 | 64 | 3 | ~3.5 | value = scale * (packed_3bit_int) + bias |
| MLX Q4 | 100 | 64 | 4 | ~4.5 | value = scale * (packed_4bit_int) + bias |
| MLX Q6 | 102 | 64 | 6 | ~6.5 | value = scale * (packed_6bit_int) + bias |
| MLX Q8 | 101 | 64 | 8 | ~8.5 | value = scale * (packed_8bit_int) + bias |
Notes on bpw for MLX: The effective bpw includes the overhead of the FP16 scale and bias per group. For group_size=64 with 4-bit values: (64 * 4 + 16 + 16) / 64 = 4.5 bpw. The exact overhead is 32 / group_size bits per weight for the scale+bias pair.
MLX Packing Details
MLX Q4: Each uint32 holds 8 x 4-bit values. The low 4 bits are element 0, bits 4-7 are element 1, and so on. The group_size determines how many packed uint32s share a single scale/bias pair: for group_size=64, that is 8 uint32s per group.
MLX Q3: Packing is more complex. Three bits per value means values do not align neatly to byte boundaries. MLX packs 32 x 3-bit values into 3 uint32s (96 bits for 32 values). The remaining 32 values in a 64-element group use another 3 uint32s.
MLX Q8: The simplest MLX format. Each byte holds one 8-bit quantized value. Dequantization is a simple multiply-add: value = scale * q[i] + bias.
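A CPU-side sketch of the MLX Q4 case (one packed uint32 holds eight nibbles; the group's scale and bias are shared across its packed words):

```cpp
#include <cassert>
#include <cstdint>

// Dequantize element k (0..7) of one packed MLX Q4 uint32: the low
// 4 bits are element 0, bits 4-7 element 1, and so on.
inline float mlx_q4_dequant(uint32_t packed, int k, float scale, float bias) {
    int q = int((packed >> (4 * k)) & 0x0F);
    return scale * float(q) + bias;   // value = scale * q + bias
}
```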
Internal Dtype Codes
Akunu uses uint32_t dtype codes internally. GGUF dtypes 0-30 map directly to the GGUF specification. MLX formats use synthetic codes 99-102 that are assigned during weight loading by MLXWeightStore. The full mapping in dtype_descriptor.h:
| Code | Format | Origin |
|---|---|---|
| 0 | F32 | GGUF |
| 1 | F16 | GGUF |
| 2 | Q4_0 | GGUF |
| 3 | Q4_1 | GGUF |
| 6 | Q5_0 | GGUF |
| 8 | Q8_0 | GGUF |
| 10 | Q2_K | GGUF |
| 11 | Q3_K | GGUF |
| 12 | Q4_K | GGUF |
| 13 | Q5_K | GGUF |
| 14 | Q6_K | GGUF |
| 30 | BF16 | GGUF |
| 31 | BF16 (native) | GGUF, M4+ only |
| 99 | MLX Q3 | MLX SafeTensors |
| 100 | MLX Q4 | MLX SafeTensors |
| 101 | MLX Q8 | MLX SafeTensors |
| 102 | MLX Q6 | MLX SafeTensors |
Note that codes 4-5, 7, 9, 15-29 are defined in the GGUF specification (for types like Q5_1, Q8_1, IQ2_XXS, etc.) but are not currently supported by akunu’s Metal kernels. If you attempt to load a GGUF file using an unsupported dtype, the dtype_lookup() function falls back to the F16 descriptor, which will produce incorrect results. Check the dtype before loading.
Kernel Support Matrix
Not every format has every kernel variant. This table shows which Metal kernel types are available for each supported dtype:
| Format | GEMV | GEMV Wide | GEMM | GEMM Small | Embedding | Fused SiLU |
|---|---|---|---|---|---|---|
| F16 | yes | yes | yes | yes | yes (generic) | no |
| Q4_0 | yes | yes | yes | yes | yes | yes |
| Q4_1 | yes | yes | yes | yes | yes | no |
| Q5_0 | yes | yes | yes | yes | no | no |
| Q8_0 | yes | yes | yes | yes | yes | no |
| Q2_K | yes | no | yes | yes | no | no |
| Q3_K | yes | no | yes | yes | no | no |
| Q4_K | yes | yes | yes | yes | yes | no |
| Q5_K | yes | no | yes | yes | no | no |
| Q6_K | yes | no | yes | yes | yes | no |
| BF16 | yes | no | yes | yes | yes | no |
| MLX Q3 | yes | no | yes | yes | yes | yes |
| MLX Q4 | yes | yes | yes | yes | yes | yes |
| MLX Q6 | yes | no | yes | yes | yes | yes |
| MLX Q8 | yes | yes | yes | yes | yes | yes |
Key observations:
- GEMV Wide kernels exist only for formats with wide enough adoption to justify the implementation effort. Q4_0 and Q4_K are the most common GGUF formats; MLX Q4 and Q8 are the most common MLX formats.
- Fused SiLU kernels exist for Q4_0 and all MLX formats. These fuse the gate+up GEMV with the SiLU activation to eliminate an intermediate buffer write. Other GGUF K-quant formats do not have fused SiLU variants.
- Embedding kernels that dequantize on the fly exist for the most common formats. Formats without a specialized embedding kernel use the generic F16 embedding lookup, which requires the embedding weights to be stored in FP16 (or converted during loading).
Model Size Estimation
To estimate the file size of a model in a given format:
file_size_bytes = n_parameters * bpw / 8 + metadata_overhead
The metadata overhead (GGUF header, tensor info, tokenizer data) is typically 1-10 MB for GGUF files, negligible for large models.
| Parameters | Q4_0 (4.5 bpw) | Q4_K (4.5 bpw) | Q8_0 (8.5 bpw) | MLX Q4 (~4.5 bpw) | F16 (16 bpw) |
|---|---|---|---|---|---|
| 1B | 0.56 GB | 0.56 GB | 1.06 GB | 0.56 GB | 2.0 GB |
| 4B | 2.25 GB | 2.25 GB | 4.25 GB | 2.25 GB | 8.0 GB |
| 8B | 4.50 GB | 4.50 GB | 8.50 GB | 4.50 GB | 16.0 GB |
| 14B | 7.88 GB | 7.88 GB | 14.88 GB | 7.88 GB | 28.0 GB |
| 32B | 18.0 GB | 18.0 GB | 34.0 GB | 18.0 GB | 64.0 GB |
| 70B | 39.4 GB | 39.4 GB | 74.4 GB | 39.4 GB | 140.0 GB |
Note that Q4_0 and Q4_K have the same effective bpw (4.5) but Q4_K generally provides better accuracy due to the hierarchical scale structure. The file sizes are identical; the quality difference is in how those bits are allocated.
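The table can be reproduced from the size formula (sizes in decimal GB, metadata overhead ignored as negligible):

```cpp
#include <cassert>

// file_size ~ n_parameters * bpw / 8 bytes; reported in decimal GB.
inline double model_size_gb(double n_params, double bpw) {
    return n_params * bpw / 8.0 / 1e9;
}
```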
Memory Bandwidth and Decode Throughput
Since single-token decode is memory-bound, the quantization format directly determines the maximum achievable decode throughput. The relationship is:
theoretical_max_tok_s = memory_bandwidth_bytes_per_sec / model_weight_bytes
This means halving the bits per weight (e.g., Q8_0 to Q4_0) roughly doubles the theoretical decode speed. Here is a reference table for a 4B parameter model on various chips:
| Chip | BW (GB/s) | Q4_0 (2.25 GB) | Q8_0 (4.25 GB) | F16 (8.0 GB) |
|---|---|---|---|---|
| M1 | 68.25 | 30.3 tok/s | 16.1 tok/s | 8.5 tok/s |
| M2 | 100 | 44.4 tok/s | 23.5 tok/s | 12.5 tok/s |
| M2 Pro | 200 | 88.9 tok/s | 47.1 tok/s | 25.0 tok/s |
| M3 Max | 400 | 177.8 tok/s | 94.1 tok/s | 50.0 tok/s |
| M4 Pro | 273 | 121.3 tok/s | 64.2 tok/s | 34.1 tok/s |
| M4 Max | 546 | 242.7 tok/s | 128.5 tok/s | 68.3 tok/s |
These are theoretical maximums assuming 100% bandwidth utilization. In practice, akunu achieves 70-85% of these numbers due to overhead from attention, normalization, RoPE, and kernel dispatch. The SLC (system-level cache) can push effective bandwidth above the raw DRAM bandwidth for chain decode workloads, so measured throughput occasionally exceeds these theoretical ceilings.
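The ceilings in the table come straight from the relation above; for example, an M4 Pro decoding a 4B Q4_0 model:

```cpp
#include <cassert>

// theoretical_max_tok_s = memory bandwidth / model weight bytes.
inline double max_tok_s(double bw_bytes_per_s, double weight_bytes) {
    return bw_bytes_per_s / weight_bytes;
}
```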
Choosing a Quantization Format
Here is a decision guide based on common use cases:
| Use Case | Recommended Format | Rationale |
|---|---|---|
| Maximum speed, acceptable quality | Q4_0 (GGUF) or MLX Q4 | Lowest bpw with good quality. Best decode throughput. |
| Best quality/speed tradeoff | Q4_K (GGUF) | Same bpw as Q4_0 but better accuracy from hierarchical scales. |
| Quality-sensitive applications | Q6_K (GGUF) or MLX Q6 | 6.5 bpw gives near-FP16 quality with 2.5x less memory. |
| Near-lossless | Q8_0 (GGUF) or MLX Q8 | 8.5 bpw is essentially indistinguishable from FP16 for most tasks. |
| Research / debugging | F16 or BF16 | Full precision. Useful as a reference for measuring quantization error. |
| Tiny models (<1B params) | Q8_0 or F16 | Small models are already fast; use higher precision to preserve quality. |
| Large models on limited RAM | Q2_K or Q3_K | Aggressive quantization to fit models that would otherwise not fit in memory. Quality degrades noticeably. |
Unsupported GGUF Types
The GGUF specification defines several additional quantization types that akunu does not currently support with Metal kernels. These are listed in gguf_parser.h but will fall back to the F16 dtype descriptor (producing incorrect results) if encountered:
| GGUF Code | Name | Reason Not Supported |
|---|---|---|
| 7 | Q5_1 | Asymmetric 5-bit; rare in practice, superseded by Q5_K |
| 9 | Q8_1 | Asymmetric 8-bit; rarely used for distribution |
| 15 | Q8_K | 8-bit K-quant; used internally by llama.cpp during quantization, not for inference |
| 16 | IQ2_XXS | Importance-matrix quantized 2-bit; complex lookup table dequantization |
| 17 | IQ2_XS | Importance-matrix 2-bit variant |
| 18 | IQ3_XXS | Importance-matrix 3-bit |
| 19 | IQ1_S | Importance-matrix 1-bit |
| 20 | IQ4_NL | Non-linear 4-bit with lookup table |
| 21 | IQ3_S | Importance-matrix 3-bit variant |
| 22 | IQ2_S | Importance-matrix 2-bit variant |
| 23 | IQ4_XS | Importance-matrix 4-bit |
| 29 | IQ1_M | Importance-matrix 1-bit variant |
The IQ (importance-matrix quantized) formats use lookup tables for dequantization, which adds implementation complexity. They offer slightly better perplexity than the standard formats at the same bit width, but their adoption in the ecosystem is limited compared to Q4_0, Q4_K, and Q6_K. If there is demand, these could be added as Metal kernels in the future.
Comparing GGUF and MLX Quantization
Even at the same nominal bit width, GGUF and MLX quantization are not identical:
| Aspect | GGUF (e.g., Q4_0) | MLX (e.g., Q4) |
|---|---|---|
| Block/group structure | Fixed 32-element blocks | Configurable groups (typically 64) |
| Scale storage | FP16 scale per block (Q4_0) or hierarchical 6-bit scales (Q4_K) | FP16 scale + FP16 bias per group |
| Dequant | d * (q - zero_point) | scale * q + bias |
| Symmetry | Symmetric (Q4_0) or asymmetric (Q4_1) | Always asymmetric (has bias) |
| Overhead | 2 bytes per 32 elements (Q4_0) = 0.5 bpw | 4 bytes per 64 elements = 0.5 bpw |
| Quantization method | Post-training quantization by llama.cpp | Post-training quantization by MLX |
The practical quality difference between GGUF Q4_0 and MLX Q4 is small for most models. The larger group size in MLX (64 vs 32) means each scale covers more elements, which can be slightly worse for weight distributions with high local variance, but the asymmetric bias term partially compensates.
For weights that are uniformly distributed around zero (which is common after training), symmetric quantization (GGUF Q4_0) is slightly more efficient because it does not waste bits on the bias term. For weights with non-zero mean per group, asymmetric quantization (MLX Q4, GGUF Q4_1) is more accurate.
1. K-quantization was introduced by ikawrakow in llama.cpp (2023). The “K” originally stood for “k-quant” with no specific expansion. The key insight is that spending more bits on the scale parameters (6-bit sub-block scales plus an FP16 super-block scale) reduces the quantization error contributed by the scales themselves.
Appendix C: Glossary
This glossary covers approximately 70 terms used throughout this book and in the akunu source code. Each entry gives a definition pitched at a CS audience and notes where the concept appears in akunu’s implementation. Terms are listed alphabetically.
ALU (Arithmetic Logic Unit)
The functional unit within a GPU core that performs integer and floating-point arithmetic. Apple Silicon GPUs have ALUs organized into SIMD groups of 32 threads. In akunu, ALU utilization is typically low during decode (memory-bound) and high during prefill (compute-bound). See Chapter 55 for roofline analysis.
Apple GPU Family
Apple’s versioning scheme for GPU feature sets. Family 7 = M1, Family 8 = M2/M3, Family 9 = M4. In akunu, ChipConfig::gpu_family stores this value and uses it to select kernel variants and tuning parameters (e.g., native BF16 support requires Family 9+).
Argmax
The operation that returns the index of the maximum value in a vector. In greedy decoding, argmax(logits) selects the next token. In akunu, the argmax kernel runs on the GPU as the final step of the dispatch table, writing the result to the token_ids buffer for chain decode.
ARM (Advanced RISC Machines)
The CPU instruction set architecture used by Apple Silicon. All M-series chips use ARM’s AArch64 (64-bit) ISA. Relevant to akunu only for CPU-side operations (tokenization, weight loading, sampling); the inference hot path runs entirely on the GPU.
Attention
The core mechanism of transformer models. Given queries Q, keys K, and values V, computes softmax(Q @ K^T / sqrt(d)) @ V. In akunu, the attention kernel reads Q from the scratch buffer and K/V from the head-major KV cache. See AttentionParams in Appendix A.
BF16 (Brain Float 16)
A 16-bit floating-point format with 8-bit exponent and 7-bit mantissa, matching FP32’s exponent range at the cost of precision. Native hardware support on M4 (GPU Family 9). In akunu, BF16 weights use dtype code 30 (converted to FP16 at load) or 31 (native BF16 on M4+). See dtype_descriptor.h.
BOS (Beginning of Sequence)
A special token (typically ID 1) that marks the start of a sequence. In akunu, akunu_bench uses BOS-filled synthetic prompts for reproducible benchmarking. See tools/akunu_bench.cpp.
BPE (Byte-Pair Encoding)
A subword tokenization algorithm that iteratively merges the most frequent adjacent pairs of characters/tokens. Used by most modern LLMs (GPT, LLaMA, Qwen). In akunu, the tokenizer implementation in src/tokenizer/tokenizer.h handles BPE encoding and decoding.
Causal Masking
The constraint in autoregressive language models that position i can only attend to positions 0..i (not future positions). During prefill, akunu’s GEMM-based attention applies a causal mask to the attention scores. During single-token decode, causal masking is implicit because the query is always the latest position.
Chain Decode
Akunu’s technique for batching multiple greedy decode steps into a single GPU command buffer submission. Instead of committing one command buffer per token (incurring ~50us sync overhead each time), akunu encodes 64-128 forward passes back-to-back, with the argmax output of token N feeding as input to token N+1 via a shared GPU buffer. See encode_chain() in dispatch_table.h and ADR-5 in Chapter 56.
ChipConfig
A struct in src/core/chip_config.h that captures hardware-derived tuning parameters for Apple Silicon GPU families. Includes SLC size estimates, GEMV kernel thresholds, chain decode chunk sizes, and norm dispatch geometry. Created via ChipConfig::from_gpu(cores, family).
Command Buffer
A Metal API object (MTLCommandBuffer) that holds a sequence of encoded GPU commands. In akunu, each chain decode chunk is encoded into one command buffer. The command buffer is committed to the GPU queue and either waited on synchronously (end_encoding_sync) or monitored asynchronously.
Compute Command Encoder
A Metal API object (MTLComputeCommandEncoder) used to encode compute dispatches (set pipeline, set buffers, dispatch threads) into a command buffer. In akunu, MetalDevice::begin_encoding() creates a new encoder, and the dispatch table is encoded through it.
Decode
The autoregressive token generation phase where the model processes one token at a time, appending each to the KV cache. Decode is memory-bound on Apple Silicon because each step reads the entire weight matrix for a single vector multiplication. Contrast with Prefill.
Dispatch Table
A pre-compiled sequence of DispatchCmd structs representing one token’s complete forward pass. Built once during model initialization by build_dispatch_table(). Replayed N times by encode_chain() during inference. See src/core/dispatch_table.h and ADR-1 in Chapter 56.
DispatchCmd
A POD struct containing everything needed for a single GPU kernel dispatch: pipeline state object, buffer bindings (up to 8), inline parameters (up to 64 bytes), threadgroup memory, dispatch geometry, and per-token patching instructions. Defined in dispatch_table.h.
DType Descriptor
A struct in src/core/dtype_descriptor.h that maps a GGUF dtype code to the appropriate kernel names and dispatch geometry. Contains fields for GEMV, GEMV-wide, GEMM, embedding, and fused SiLU kernel names, plus threadgroup sizes for each. The kDTypes[] array is the single source of truth for all dtype-dependent behavior.
Embedding
The process of converting a discrete token ID into a dense floating-point vector. The embedding table is a matrix of shape [vocab_size, dim] where each row is the learned representation of one token. In akunu, the embedding lookup is the first kernel in the dispatch table, reading from the token_ids buffer and writing to the h0 scratch buffer. Quantized embedding kernels (e.g., embedding_lookup_q4_0) dequantize on the fly.
Encoder-Decoder
A transformer architecture with separate encoder and decoder stacks connected by cross-attention. The encoder processes the input (e.g., mel spectrograms for Whisper) in parallel; the decoder generates output tokens autoregressively, attending to both its own previous outputs and the encoder’s representations. In akunu, enabled by ArchDescriptor::is_encoder_decoder = true. See arch_whisper() in arch_descriptor.h.
EOS (End of Sequence)
A special token that signals the model wants to stop generating. When the model outputs EOS, akunu’s akunu_generate() terminates the decode loop and returns the generation statistics.
FFN (Feed-Forward Network)
The position-wise fully-connected sub-layer in each transformer block. Modern LLMs use a gated variant (SwiGLU or GEGLU) with three weight matrices: gate, up, and down projections. In akunu, the FFN intermediate dimension is stored in AkunuModelConfig::ffn_dim and is typically ~2.7x the model dimension for SwiGLU architectures.
FNV-1a
A non-cryptographic hash function (Fowler-Noll-Vo) used in akunu’s N-gram predictor for hashing token contexts. The 64-bit variant uses offset basis 14695981039346656037 and prime 1099511628211. See NGramPredictor::context_hash() in ngram_predictor.h.
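A reference sketch of the 64-bit FNV-1a loop using the constants quoted above; exactly how context_hash() feeds token IDs into the loop is an akunu implementation detail, assumed here to be byte-at-a-time:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// 64-bit FNV-1a: xor the byte in, then multiply by the FNV prime.
inline uint64_t fnv1a_64(const uint8_t* data, size_t len) {
    uint64_t h = 14695981039346656037ull;  // offset basis
    for (size_t i = 0; i < len; ++i) {
        h ^= data[i];
        h *= 1099511628211ull;             // FNV prime
    }
    return h;
}
```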
FP16 (Half-Precision Float)
IEEE 754 half-precision: 5-bit exponent, 10-bit mantissa. The native compute precision for Apple Silicon GPUs. In akunu, all intermediate activations (hidden states, attention outputs, FFN intermediates) are FP16. Weight matrices may be quantized to lower precision but are dequantized to FP16 during computation.
FlashAttention
An efficient attention algorithm that tiles the softmax computation to avoid materializing the full [seq_len, seq_len] attention matrix in memory.1 In akunu, the prefill attention kernel uses a tiled approach inspired by FlashAttention, computing attention in chunks that fit in threadgroup memory.
GBNF (GGML BNF)
A grammar specification format based on BNF (Backus-Naur Form), used for constrained decoding in llama.cpp and akunu. In akunu, akunu_grammar_create() parses a GBNF string and creates a grammar constraint that masks invalid tokens at each generation step. See src/grammar/json_schema_to_grammar.h.
GELU (Gaussian Error Linear Unit)
An activation function: GELU(x) = x * Phi(x) where Phi is the standard Gaussian CDF. Used by Gemma (with gate: GEGLU) and Whisper (plain GELU). In akunu, implemented as act_gelu() in KernelCommon.h using the tanh approximation.
GEMM (General Matrix-Matrix Multiply)
The C = alpha * A @ B + beta * C operation. Used during prefill when processing multiple tokens simultaneously. The arithmetic intensity scales with the batch dimension M, making prefill compute-bound for moderate batch sizes. In akunu, GEMM kernels use simdgroup_matrix hardware intrinsics and are selected via gemm_kernel_for() in dtype_descriptor.h. See GEMMParams in Appendix A.
GEMV (General Matrix-Vector Multiply)
The y = A @ x operation (M=1 case of GEMM). The dominant operation during single-token decode. Memory-bound on Apple Silicon because the entire weight matrix must be read for each multiplication. In akunu, GEMV kernels are specialized per dtype and chip configuration, with standard, large-K, and wide-N variants.
GGUF (GGML Universal File)
A binary file format for storing quantized LLM weights and metadata. Successor to GGML format, used by llama.cpp and supported by most open-source LLM tools. In akunu, parsed by src/weight/gguf_parser.h. Contains tensor data, model architecture metadata, tokenizer vocabulary, and quantization parameters in a single file.
GQA (Grouped Query Attention)
An attention variant where multiple query heads share a single key/value head, reducing KV cache memory and attention compute.2 For example, LLaMA 3 uses 32 query heads but only 8 KV heads (ratio 4:1). In akunu, GQA is handled by the attention kernel via the n_heads / n_kv_heads fields in AttentionParams.
Gumbel-Max Trick
A method for sampling from a categorical distribution by adding Gumbel-distributed noise to log-probabilities and taking the argmax. Considered but not adopted as the default sampling strategy in akunu (see ADR-6 in Chapter 56). The main barrier is incompatibility with grammar-constrained decoding and top-k/top-p filtering.
Head-Major Layout
A memory layout for the KV cache where all positions for a given head are contiguous: [n_kv_heads, max_seq_len, head_dim]. Chosen by akunu (ADR-10) because the attention kernel reads all K/V vectors for one head sequentially, and contiguous layout maximizes memory bandwidth utilization.
K-Quant (K-Quantization)
A family of GGUF quantization formats (Q2_K through Q6_K) that use a two-level hierarchical scheme with 256-element super-blocks containing smaller sub-blocks with their own 6-bit scales. Provides better accuracy than basic block quantization at the same bit width. See Appendix B for format details.
KV Cache
A buffer that stores the key and value vectors for all previously processed tokens, avoiding recomputation during autoregressive decode. In akunu, defined in src/cache/kv_cache.h as a KVCache struct with per-layer K and V buffers in head-major FP16 layout. Memory cost scales as 2 * n_layers * n_kv_heads * max_seq_len * head_dim * 2 bytes.
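Plugging representative numbers into the cost formula; the 32-layer, 8-KV-head, head_dim-128 shape below is a LLaMA-3-8B-style assumption for illustration, not taken from akunu:

```cpp
#include <cassert>
#include <cstdint>

// 2 (K and V) * layers * kv_heads * seq_len * head_dim * 2 bytes (FP16).
inline uint64_t kv_cache_bytes(uint64_t n_layers, uint64_t n_kv_heads,
                               uint64_t max_seq_len, uint64_t head_dim) {
    return 2ull * n_layers * n_kv_heads * max_seq_len * head_dim * 2ull;
}
```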
LayerNorm (Layer Normalization)
A normalization technique: LN(x) = (x - mean(x)) / sqrt(var(x) + eps) * weight + bias. Used by Whisper. In akunu, the LayerNormParams struct (Appendix A) drives the kernel. Most modern LLMs use RMSNorm instead.
LLM (Large Language Model)
A neural network with billions of parameters trained on large text corpora to generate text autoregressively. Akunu is an inference engine for LLMs on Apple Silicon, supporting architectures like LLaMA, Qwen, and Gemma.
Logits
The raw (unnormalized) output scores from the model’s final linear projection. A vector of size vocab_size where each element represents the model’s confidence that the corresponding token should come next. In akunu, logits are stored in the scratch.logits buffer (FP16, vocab_size * 2 bytes).
Mel Spectrogram
A time-frequency representation of audio, computed by applying the mel-scale filterbank to the Short-Time Fourier Transform (STFT) of a waveform. Whisper models expect 80-bin or 128-bin mel spectrograms at 16kHz sample rate. In akunu, mel computation is handled by src/audio/mel.h as a preprocessing step before encoder inference. The bin count is stored in AkunuModelConfig::n_mels.
Memory Mapping (mmap)
An operating system facility that maps a file’s contents into virtual memory, allowing the file to be read as if it were in RAM without explicit read calls. In akunu, both GGUF and SafeTensors files are opened via mmap(), enabling zero-copy access to tensor data. The OS manages paging from disk as needed, which means model loading appears near-instantaneous for files already in the page cache.
make_uniform()
A Metal shader helper function defined in KernelCommon.h that wraps simd_broadcast_first(). Tells the Metal compiler that a value is the same across all threads in a SIMD group, enabling better predication and vectorization. Used for loop bounds and uniform conditionals in GEMV/GEMM kernels.
Metal
Apple’s low-level GPU programming framework, analogous to Vulkan or Direct3D 12. Provides direct access to GPU compute via compute pipelines, command buffers, and buffers. Akunu uses Metal exclusively as its GPU backend via MetalDevice in backend/metal/metal_device.h.
Metallib
A pre-compiled Metal shader library (.metallib file). Contains compiled pipeline state objects for all of akunu’s GPU kernels. Loaded at model init time via MetalDevice::load_library(). Using a pre-compiled metallib avoids runtime shader compilation, which can take seconds.
MHA (Multi-Head Attention)
The standard attention mechanism where Q, K, and V are split into multiple heads, each attending independently, then concatenated. In akunu, the number of heads is specified by AkunuModelConfig::n_heads (for Q) and n_kv_heads (for K/V in GQA).
MLX
Apple’s array computation framework for machine learning, implemented in C++ and Metal with Python bindings. Uses SafeTensors format with group quantization. In akunu, MLX-format models are loaded via MLXWeightStore in src/weight/mlx_weight_store.h, with weight name mapping from HuggingFace conventions to akunu’s canonical names.
MSL (Metal Shading Language)
The programming language for Metal GPU shaders, based on C++14 with extensions for GPU-specific types (half, simdgroup, threadgroup memory). All of akunu’s GPU kernels are written in MSL. The source lives in backend/metal/kernels/.
NeoX RoPE
A variant of Rotary Position Embeddings where the rotation dimensions are arranged in a split-half pattern: the first head_dim/2 elements are one component, the second half is the other. Used by Qwen, Gemma, and GPT-NeoX. In akunu, selected via ArchDescriptor::rope_kernel = "rope_neox_qkv_write_f16".
Neural Engine
A dedicated machine learning accelerator on Apple Silicon SoCs, optimized for dense matrix operations on fixed-size tensors via Core ML. Akunu does not use the Neural Engine because it requires models in Core ML format and does not support the dynamic shapes needed for autoregressive decoding with variable-length KV caches. The GPU provides more flexibility for custom kernel implementations.
N-Gram Predictor
Akunu’s lightweight speculative decoding module that predicts future tokens based on frequency tables of recently observed n-gram patterns (up to 4-grams). Does not require a draft model. Defined in src/speculative/ngram_predictor.h. Enabled via akunu_set_speculation(model, true).
NPDA (Neural Processing and Data Acceleration)
An umbrella term for the collection of hardware blocks on Apple SoCs that accelerate ML workloads, including the GPU, Neural Engine, and AMX (Apple Matrix eXtensions). Akunu uses only the GPU via Metal; it does not target the Neural Engine or AMX.
Ping-Pong Buffers
The technique of alternating between two buffers (h0 and h1) for the transformer’s residual stream. Each layer reads from one buffer, writes intermediate results, then adds the residual back to the other buffer. This avoids allocating a new buffer per layer. See ScratchBuffers in src/cache/scratch.h and ADR-9 in Chapter 56.
Pipeline State Object (PSO)
A Metal API object (MTLComputePipelineState) representing a compiled GPU kernel ready for dispatch. In akunu, PSOs are cached in MetalDevice::pso_cache_ (keyed by kernel name) and looked up by get_pipeline(). The dispatch table stores PSO handles directly to avoid per-dispatch lookups.
Prefill
The phase of LLM inference where the entire prompt is processed in one batch to populate the KV cache. Unlike decode (which processes one token at a time), prefill uses GEMM (matrix-matrix) operations and can be compute-bound for longer prompts. In akunu, prefill is triggered by akunu_prefill() and uses the batch_* scratch buffers.
Q4_0
The most common GGUF quantization format. 32-element blocks with one FP16 scale each; 4 bits per weight value; 4.5 effective bits per weight. See Appendix B for the full block layout and dequantization formula.
QKV Fusion
The optimization of fusing the Q, K, and V linear projections into a single GEMV that writes to a contiguous output buffer [q_dim + 2*kv_dim]. Reduces three GEMV dispatches to one. In akunu, the QKV buffer is scratch.qkv with sub-offsets qkv_q_offset, qkv_k_offset, qkv_v_offset.
Residual Connection
A shortcut that adds a layer’s input to its output: output = layer(x) + x. Prevents vanishing gradients in deep networks and is used in every transformer layer (both after attention and after FFN). In akunu, residual additions alternate between the h0 and h1 ping-pong buffers.
Repetition Penalty
A technique to discourage the model from repeating tokens by modifying logits for recently generated tokens. Positive logits are divided by the penalty factor; negative logits are multiplied. In akunu, configurable via AkunuSamplingConfig::repeat_penalty and can optionally be applied on the GPU via the RepetitionPenaltyParams kernel.
Roofline Model
A visual performance model that plots a kernel’s achievable throughput (FLOPS) against its arithmetic intensity (FLOPS/byte), bounded by the hardware’s peak compute and peak memory bandwidth. For LLM decode on Apple Silicon, most kernels (GEMV, attention, norms) fall in the memory-bound region. See Chapter 55 for a detailed roofline analysis with Apple Silicon numbers.
RMSNorm (Root Mean Square Normalization)
A simplified normalization: RMSNorm(x) = x / sqrt(mean(x^2) + eps) * weight. Omits the mean subtraction of LayerNorm. Used by LLaMA, Qwen, Gemma, and most modern LLMs. In akunu, driven by RMSNormParams (Appendix A) and selected via ArchDescriptor::norm_type = "rmsnorm".
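A CPU reference of the formula above (float here; akunu's kernel operates on FP16 activations):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// RMSNorm(x) = x / sqrt(mean(x^2) + eps) * weight, elementwise.
inline void rmsnorm_ref(const float* x, const float* w, float eps,
                        float* out, size_t n) {
    double ss = 0.0;                       // accumulate in double for stability
    for (size_t i = 0; i < n; ++i) ss += double(x[i]) * double(x[i]);
    float inv = 1.0f / std::sqrt(float(ss / double(n)) + eps);
    for (size_t i = 0; i < n; ++i) out[i] = x[i] * inv * w[i];
}
```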
RoPE (Rotary Position Embeddings)
A position encoding method that rotates query and key vectors by position-dependent angles, allowing the model to learn relative position relationships.3 In akunu, RoPE is applied by the fused rope_qkv_write_f16 kernel during decode and the standalone rope_f16 kernel during prefill. See RoPEParams and RoPEQKVWriteParams in Appendix A.
SafeTensors
A simple binary format for storing tensors, developed by Hugging Face. The header is a JSON object mapping tensor names to their dtype, shape, and byte offsets; the rest of the file is raw tensor data. In akunu, parsed by SafeTensorsParser in src/weight/safetensors_parser.h. MLX models use SafeTensors as their container format.
Sampling
The process of selecting the next token from the logit distribution. Options include greedy (argmax), temperature scaling, top-k (keep only the top K logits), top-p/nucleus (keep the smallest set of top tokens whose cumulative probability exceeds p), and min-p (keep tokens above a minimum probability threshold). In akunu, configured via AkunuSamplingConfig in types.h.
Scratch Buffers
Pre-allocated GPU buffers for all intermediate computations during inference. Created once at model load time. Includes h0/h1 (residual ping-pong), qkv, attn_out, ffn_gate/up/act, logits, and batch variants for prefill. See ScratchBuffers in src/cache/scratch.h.
SIMD Group
A group of 32 threads that execute in lockstep on Apple Silicon GPUs (equivalent to a “warp” on NVIDIA GPUs or “wavefront” on AMD). SIMD group operations (simd_sum, simd_max, simd_broadcast_first) are used extensively in akunu’s reduction kernels. The width is defined as SIMD_WIDTH = 32 in KernelCommon.h.
simdgroup_matrix
A Metal intrinsic type that maps to Apple Silicon’s hardware matrix multiply unit. Supports 8x8 FP16 matrix tiles. Used by akunu’s GEMM kernels (simd_gemm_*) for prefill operations with tiling constants TILE_M=64, TILE_N=64, TILE_K=32 defined in KernelCommon.h.
SiLU (Sigmoid Linear Unit)
An activation function: SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x)). Used by LLaMA and Qwen in the FFN’s SwiGLU block. In akunu, implemented as act_silu() in KernelCommon.h. Fused SiLU GEMV kernels (gemv_q4_0_silu, gemv_mlx_q4_silu, etc.) apply this during the GEMV accumulation.
SLC (System Level Cache)
A large shared cache on Apple Silicon that sits between the GPU/CPU cores and main memory. Size ranges from 8 MB (M1 base) to 96 MB (Ultra). Not directly programmable, but its presence means that data recently read by the GPU may still be in cache for subsequent reads. In akunu, ChipConfig::slc_bytes estimates the SLC size and should_fuse_weights is enabled when the SLC is large enough to benefit from weight fusion.
SoC (System on Chip)
An integrated circuit that combines CPU, GPU, Neural Engine, memory controller, and other components on a single die. Apple’s M-series chips are SoCs with unified memory architecture. Relevant to akunu because UMA eliminates the PCIe bottleneck found in discrete GPU systems.
Softmax
The function softmax(x)_i = exp(x_i) / sum(exp(x_j)) that converts logits to a probability distribution. Used in attention (over the attention scores) and optionally for final token sampling. In akunu, the standalone softmax kernel is driven by SoftmaxParams (Appendix A); during decode attention, softmax is fused into the attention kernel.
Speculative Decoding
A technique to accelerate autoregressive generation by using a fast predictor to draft multiple tokens, then verifying them in parallel with the full model.4 In akunu, implemented via the N-gram predictor (src/speculative/ngram_predictor.h) which does not require a separate draft model. Enabled via akunu_set_speculation(model, true).
SwiGLU
A gated FFN variant: FFN(x) = (SiLU(W_gate @ x) * (W_up @ x)) @ W_down. Combines SiLU activation with a gating mechanism. Used by LLaMA, Qwen, and most modern LLMs. In akunu, the SwiGLU pattern is encoded in the dispatch table as: (1) fused gate+up GEMV, (2) SiLU-gate activation kernel (or fused SiLU GEMV), (3) down projection GEMV.
Threadgroup
A group of threads that share threadgroup memory and can synchronize via barriers. In Metal, threadgroups are the unit of dispatch: you specify (grid_size, threadgroup_size) when dispatching a kernel. In akunu, threadgroup sizes are tuned per kernel type and chip: GEMV typically uses 128 or 256 threads, GEMM uses 128 (4 SIMD groups).
Threadgroup Memory
Fast on-chip memory shared among threads in a threadgroup (equivalent to “shared memory” in CUDA). Limited to 32 KB per threadgroup on Apple Silicon (MAX_TG_MEMORY in KernelCommon.h). Used in akunu for GEMM tile buffers, reduction scratch space, and attention score accumulation.
TTFT (Time To First Token)
The wall-clock time from submitting a prompt to receiving the first generated token. Equals prefill time plus one decode step. The most perceptually important latency metric for interactive applications. Measured by akunu_benchmark (Chapter 55).
Tokenizer
The component that converts text to token IDs (encoding) and token IDs back to text (decoding). In akunu, the tokenizer is loaded from GGUF metadata or MLX tokenizer.json and exposes akunu_encode() and akunu_decode_token() via the C API. Implementation in src/tokenizer/tokenizer.h.
UMA (Unified Memory Architecture)
Apple Silicon’s memory architecture where CPU and GPU share the same physical memory pool. Eliminates the need for explicit data transfers between CPU and GPU. In akunu, all buffers (weights, KV cache, scratch) are allocated once and accessed by both CPU and GPU without copying. Metal buffers allocated via MTLDevice.makeBuffer(bytesNoCopy:...) enable true zero-copy access.
Tiling
A technique for breaking large matrix operations into smaller blocks (tiles) that fit in fast on-chip memory. In akunu’s GEMM kernels, tiling constants are TILE_M=64, TILE_N=64, TILE_K=32 (defined in KernelCommon.h). Each threadgroup processes one output tile, loading A and B sub-tiles into threadgroup memory cooperatively before computing the tile product using simdgroup_matrix intrinsics.
Top-k Sampling
A sampling strategy that restricts the candidate set to the k tokens with the highest logits before applying softmax and drawing a random sample. Reduces the probability of low-quality long-tail tokens. In akunu, configured via AkunuSamplingConfig::top_k.
Top-p Sampling (Nucleus Sampling)
A sampling strategy that sorts tokens by probability and includes tokens until the cumulative probability exceeds p.5 More adaptive than top-k because the number of candidates varies with the distribution’s entropy. In akunu, configured via AkunuSamplingConfig::top_p.
Min-p Sampling
A sampling strategy that keeps all tokens whose probability is at least min_p * max_probability. Unlike top-k (fixed count) or top-p (fixed cumulative threshold), min-p scales naturally with the model’s confidence: when the model is confident, fewer tokens pass the threshold; when uncertain, more pass. In akunu, configured via AkunuSamplingConfig::min_p.
Transformer
The neural network architecture underlying modern LLMs, based on self-attention and position-wise feed-forward networks.6 A decoder-only transformer (LLaMA, GPT) processes tokens autoregressively; an encoder-decoder transformer (Whisper) has separate encoder and decoder stacks. Akunu supports both via the ArchDescriptor::is_encoder_decoder flag.
Vocabulary Size
The number of distinct tokens the model can produce, typically 32K to 128K for modern LLMs. Stored in AkunuModelConfig::vocab_size. Determines the size of the final logit projection GEMV (dim -> vocab_size) and the logits scratch buffer (vocab_size * 2 bytes FP16).
Warp / Wave
Terms used by NVIDIA (“warp”, 32 threads) and AMD (“wavefront”, 32 or 64 threads) for the SIMD execution unit equivalent to Apple’s “SIMD group.” All three refer to the same concept: a group of threads executing the same instruction in lockstep. Apple Silicon uses a fixed SIMD width of 32.
Weight Fusion
The optimization of concatenating two or more weight matrices into a single contiguous buffer so they can be loaded by a single GEMV dispatch. In akunu, gate and up projection weights are fused (WeightProvider::fuse_weights()) on Pro+ chips where the SLC is large enough (>16 MB) to benefit from sequential access to the larger combined buffer. QKV weights can also be fused.
WeightProvider
The abstraction layer in src/weight/weight_provider.h that wraps either a GGUF WeightStore or an MLX MLXWeightStore, providing a uniform interface for tensor access, metadata queries, and weight fusion regardless of the underlying file format. Format detection is automatic based on file path (directory or .safetensors = MLX, otherwise GGUF).
WeightStore
The GGUF-specific weight loading backend. Opens a GGUF file via gguf_open(), extracts model configuration from metadata, and provides zero-copy GPU buffer access to tensor data via memory mapping. Defined alongside the GGUF parser in src/weight/weight_store.h.
Whisper
OpenAI’s speech recognition model, an encoder-decoder transformer that processes mel spectrograms to produce text transcriptions.7 In akunu, Whisper is supported via the arch_whisper() descriptor, which enables encoder-decoder mode, cross-attention, Conv1D frontend, LayerNorm, and bias terms. The C API exposes akunu_transcribe() and related functions.
xgrammar
A third-party library (vendored in 3rdparty/xgrammar/) that implements grammar-constrained decoding. Compiles GBNF grammars and JSON schemas into efficient token masks that can be applied at each generation step to guarantee structurally valid output. Integrated into akunu via akunu_grammar_create() and akunu_generate_grammar().
Zero-Copy
The ability to share data between CPU and GPU without physically copying bytes. On Apple Silicon with UMA, Metal buffers are backed by physical pages that both the CPU and GPU can access. In akunu, GGUF tensor data is memory-mapped (mmap) from the file and the GPU buffer is created over the same pages, achieving true zero-copy weight loading. The SafeTensorsParser similarly uses mmap for MLX format files.
Index of Terms by Category
For quick navigation, here are the glossary terms grouped by topic:
Hardware and Platform: ALU, Apple GPU Family, ARM, ChipConfig, Metal, Metallib, MSL, NPDA, SIMD Group, simdgroup_matrix, SLC, SoC, Threadgroup, Threadgroup Memory, UMA, Warp/Wave
Quantization and Data Formats: BF16, FP16, GGUF, K-Quant, MLX, Q4_0, SafeTensors, Zero-Copy
Model Architecture: Attention, Causal Masking, GELU, GQA, LayerNorm, MHA, NeoX RoPE, RMSNorm, RoPE, SiLU, Softmax, SwiGLU, Tiling, Transformer
Inference Engine: Argmax, Chain Decode, Command Buffer, Compute Command Encoder, Decode, Dispatch Table, DispatchCmd, DType Descriptor, EOS, KV Cache, Logits, Head-Major Layout, Ping-Pong Buffers, Pipeline State Object, Prefill, QKV Fusion, Scratch Buffers, Speculative Decoding, TTFT, Weight Fusion, WeightProvider, WeightStore
Operations: BOS, BPE, GEMM, GEMV, Sampling, Top-k, Top-p, Min-p, Tokenizer, Vocabulary Size
External Libraries and Tools: FlashAttention, GBNF, Gumbel-Max Trick, LLM, N-Gram Predictor, Whisper, xgrammar
1. Dao, T. et al. (2022). “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” NeurIPS 2022. https://arxiv.org/abs/2205.14135
2. Ainslie, J. et al. (2023). “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints.” EMNLP 2023. https://arxiv.org/abs/2305.13245
3. Su, J. et al. (2021). “RoFormer: Enhanced Transformer with Rotary Position Embedding.” arXiv:2104.09864. https://arxiv.org/abs/2104.09864
4. Leviathan, Y. et al. (2023). “Fast Inference from Transformers via Speculative Decoding.” ICML 2023. https://arxiv.org/abs/2211.17192
5. Holtzman, A. et al. (2020). “The Curious Case of Neural Text Degeneration.” ICLR 2020. https://arxiv.org/abs/1904.09751
6. Vaswani, A. et al. (2017). “Attention Is All You Need.” NeurIPS 2017. https://arxiv.org/abs/1706.03762
7. Radford, A. et al. (2022). “Robust Speech Recognition via Large-Scale Weak Supervision.” arXiv:2212.04356. https://arxiv.org/abs/2212.04356