Model Loading: From File to GPU

If you have ever wondered what actually happens between the moment you point an inference engine at a model file and the moment the first token appears on your screen, this chapter is for you. In akunu, that transition is orchestrated by a single function – akunu_load_model – defined in src/inference/model_loader.cpp. The function is roughly 250 lines of C++, and it touches almost every subsystem in the engine: format detection, weight I/O, configuration parsing, architecture selection, tokenizer construction, GPU memory allocation, weight fusion, KV cache creation, scratch buffer setup, RoPE precomputation, dispatch table building, shader compilation, and a warmup forward pass. That is a lot of ground to cover, so let us take it one stage at a time.

The Bird’s-Eye View

Here is the full pipeline from “user calls the C API” to “model is ready to generate text.”

  akunu_load_model(path, metallib_path, max_context)
  |
  |  1. Create GPU device
  |     Device::create_default()
  |
  |  2. Load metallib (shader library)
  |     load_metallib(device, metallib_path)
  |
  |  3. Detect model format (GGUF vs MLX/SafeTensors vs Whisper GGML)
  |     is_whisper_ggml(path) / WeightProvider::detect_format(path)
  |
  |  4. Open weights
  |     WeightProvider::open(path)
  |
  |  5. Extract model config
  |     weights.get_config() --> AkunuModelConfig
  |
  |  6. Infer architecture descriptor
  |     arch_from_config(config.architecture, config.dim) --> ArchDescriptor
  |
  |  7. Detect chip capabilities
  |     device.chip_config() --> ChipConfig
  |
  |  8. Resolve weight quirks (MLX RoPE style, tied embeddings, quant params)
  |
  |  9. Precompute RoPE frequencies
  |     init_rope_freqs(state, path)
  |
  | 10. Validate config (dim, layers, vocab, heads, head_dim all nonzero)
  |
  | 11. Load tokenizer
  |     load_tokenizer(state, path)
  |
  | 12. Allocate KV cache
  |     KVCache::create(device, n_layers, n_kv_heads, head_dim, ctx)
  |
  | 13. Allocate scratch buffers
  |     ScratchBuffers::create(device, config, ctx, max_prefill_chunk)
  |
  | 14. Build greedy dispatch table
  |     build_dispatch_table(device, weights, config, arch, chip, kv, scratch)
  |
  | 15. Build sampled dispatch table (greedy + Gumbel top-k + argmax)
  |
  | 16. Pre-allocate GPU param buffers for dispatch commands
  |
  | 17. Warmup forward pass (compile all PSOs)
  |     encode_dispatch_table(&table, 0, 1); end_encoding_sync();
  |
  | 18. Return opaque model handle
  v

That is 18 discrete stages. Some of them are trivial (one function call). Others – like the dispatch table build – are complex enough to deserve an entire chapter of their own (see Chapter 25). But they all execute sequentially, one time, at model load. Nothing in this list happens again during inference. That is the whole point: pay the cost once up front so the hot path is zero-allocation and zero-branching.

Stage 1: Creating the Metal Device

The very first thing akunu_load_model does is create a Metal device:

state->device = Device::create_default();
printf("GPU: %s\n", state->device->name());

Device is akunu’s hardware abstraction layer. On Apple Silicon, create_default() calls MTLCreateSystemDefaultDevice(), wraps the resulting id<MTLDevice> in an internal implementation class (MetalDeviceImpl), and queries the hardware for its GPU core count and Apple GPU family number. Those two integers feed into ChipConfig::from_gpu() later, which tunes every kernel dispatch in the engine to the specific chip you are running on.

Why does the device come first? Because almost everything else – the weight provider, the KV cache, the scratch buffers – needs a Device& reference to allocate GPU memory. Device construction is cheap (microseconds), and it anchors the entire object graph.

Stage 2: Loading the Metallib

if (!load_metallib(*state->device, metallib_path)) {
    set_error("Failed to load metallib. Build with: make shaders");
    delete state;
    return nullptr;
}

A .metallib is Apple’s precompiled shader archive – the GPU equivalent of a .a static library. Akunu’s Metal kernels (GEMV, GEMM, RoPE, attention, normalization, activation, argmax, etc.) are compiled offline into a single akunu.metallib file. load_metallib tries a user-provided path first, then falls back to a handful of well-known build output locations:

engine/build/akunu.metallib
.build/metallib/akunu.metallib
build/akunu.metallib
akunu.metallib

If none of these exist, the model load fails immediately. You cannot do inference without GPU kernels. The function is deliberately simple – no complex search logic, no environment variables, just a flat priority list.

Why load shaders before weights? Because there is no point spending time and memory on a multi-gigabyte weight file if we cannot even dispatch a kernel. Fail fast.

Stage 3: Format Detection

Akunu supports three model formats:

  1. GGUF – the standard quantized format from llama.cpp. Most common.
  2. MLX SafeTensors – Apple’s MLX framework format. Directory of .safetensors files plus a config.json.
  3. Whisper GGML – legacy binary format for OpenAI Whisper models. Identified by an "lmgg" or "ggjt" magic number.

The detection logic is refreshingly unsophisticated:

bool is_whisper_ggml(const char *path) {
    FILE *f = fopen(path, "rb");
    if (!f) return false;
    char magic[4];
    size_t n = fread(magic, 1, 4, f);
    fclose(f);
    if (n != 4) return false;
    return memcmp(magic, "lmgg", 4) == 0
        || memcmp(magic, "ggjt", 4) == 0;
}

If the path is a directory or ends in .safetensors, it is MLX. Otherwise, GGUF. No content sniffing beyond the four-byte magic for Whisper. This is a case where “dumb and fast” beats “clever and fragile.”

Stage 4: Opening Weights

For standard LLM models (not Whisper), the next step is opening the weight provider:

state->weights = new WeightProvider(*state->device);
if (!state->weights->open(model_path)) {
    set_error("Failed to open model: %s", model_path);
    delete state;
    return nullptr;
}

WeightProvider is a unified facade that wraps either a WeightStore (GGUF parser) or an MLXWeightStore (SafeTensors parser). The open() call parses file headers and metadata but does NOT eagerly load all tensors into GPU memory. Weight tensors are memory-mapped and lazily materialized on first access through get_tensor(). For a 7B Q4_0 model, the GGUF file is about 4 GB – you do not want to copy all of that before you even know the model’s architecture.

The WeightProvider exposes a uniform interface regardless of format:

+-----------------+-----------+-----------+
| Method          | GGUF      | MLX       |
+-----------------+-----------+-----------+
| get_config()    | metadata  | JSON      |
| get_tensor(name)| mmap+GPU  | mmap+GPU  |
| get_dtype(name) | per-tensor| per-tensor|
| has_tensor(name)| lookup    | lookup    |
| fuse_weights()  | concat    | concat    |
+-----------------+-----------+-----------+

This abstraction is crucial because the rest of the engine – table builder, prefill, everything – never knows or cares what format the weights came from.

Stage 5: Extracting the Model Config

state->config = state->weights->get_config();

For GGUF, this reads well-known metadata keys (llama.embedding_length, llama.block_count, llama.attention.head_count, etc.) from the GGUF header. For MLX, it parses config.json in the model directory.

The result is a plain-old-data struct, AkunuModelConfig:

AkunuModelConfig
+----------------------------------+
| dim            (e.g., 4096)      |  embedding dimension
| n_layers       (e.g., 32)       |  transformer layers
| n_heads        (e.g., 32)       |  query heads
| n_kv_heads     (e.g., 8)        |  key/value heads (GQA)
| head_dim       (e.g., 128)      |  dim per head
| q_dim          (e.g., 4096)     |  n_heads * head_dim
| kv_dim         (e.g., 1024)     |  n_kv_heads * head_dim
| ffn_dim        (e.g., 14336)    |  feed-forward intermediate
| vocab_size     (e.g., 32000)    |  vocabulary size
| max_seq_len    (e.g., 8192)     |  maximum context length
| norm_eps       (e.g., 1e-5)     |  RMSNorm epsilon
| rope_theta     (e.g., 500000.0) |  RoPE base frequency
| architecture   "llama"          |  arch name string
+----------------------------------+

Every downstream allocation and dispatch geometry is driven entirely by these numbers. There are no “if model is 7B then…” branches anywhere in akunu. The config is the single source of truth.

Stage 6: Architecture Inference

state->arch = arch_from_config(state->config.architecture, state->config.dim);

This is one of akunu’s cleanest design patterns. The ArchDescriptor struct captures every architecture-specific behavior as data, not code:

ArchDescriptor
+------------------------------+------------------+
| Field                        | Example (LLaMA)  |
+------------------------------+------------------+
| activation_kernel            | "silu_gate_f16"  |
| embedding_scale              | 0.0 (no scale)   |
| has_qk_norm                  | false            |
| has_post_attn_norm           | false            |
| has_post_ffn_norm            | false            |
| rope_kernel                  | "rope_qkv_write" |
| tie_embeddings               | false            |
| quant_bits                   | 0 (GGUF native)  |
+------------------------------+------------------+

Different architectures get different descriptors:

arch_from_config("llama", ...)  --> arch_llama()
arch_from_config("qwen3", ...)  --> arch_qwen3()    // QK-norm, tied embeds
arch_from_config("gemma", ...)  --> arch_gemma(dim)  // GELU, post-norms, scale
arch_from_config("whisper",...) --> arch_whisper()   // LayerNorm, cross-attn
arch_from_config("bert", ...)   --> arch_bert()      // encoder-only, SwiGLU

The beauty of this approach is that adding a new architecture requires exactly one factory function and one case in arch_from_config. The table builder, the prefill engine, and the decode loop never branch on the architecture name. They read fields from the descriptor. This is the “data-driven polymorphism” pattern, and it produces code that is both simpler and faster than the traditional virtual-method approach.
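To make the shape of the pattern concrete, here is a minimal sketch of the factory-plus-switch approach. The field names and architecture details below are simplified assumptions for illustration; the real ArchDescriptor carries kernel-name strings, norm flags, and quant parameters as shown above.

```cpp
#include <cmath>
#include <cstring>

// Simplified stand-in for akunu's ArchDescriptor: behavior as data, not code.
struct ArchDescriptor {
    const char *activation_kernel;
    bool has_qk_norm;
    bool tie_embeddings;
    float embedding_scale;   // 0.0 = no scaling
};

// One factory function per architecture, each returning pure data.
static ArchDescriptor arch_llama()  { return {"silu_gate_f16", false, false, 0.0f}; }
static ArchDescriptor arch_qwen3()  { return {"silu_gate_f16", true,  true,  0.0f}; }
static ArchDescriptor arch_gemma(int dim) {
    // Gemma-style models scale token embeddings by sqrt(dim).
    return {"gelu_gate_f16", false, true, std::sqrt((float)dim)};
}

// One case per architecture; nothing downstream ever branches on the name.
ArchDescriptor arch_from_config(const char *name, int dim) {
    if (std::strstr(name, "qwen"))  return arch_qwen3();
    if (std::strstr(name, "gemma")) return arch_gemma(dim);
    return arch_llama();   // default: vanilla LLaMA-style decoder
}
```

Adding a new architecture means writing one factory and one case here; the table builder and decode loop remain untouched.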

Stage 7: Chip Configuration

state->chip = state->device->chip_config();

ChipConfig captures hardware-specific tuning parameters:

ChipConfig
+-----------------------------+-------+----------+----------+
| Field                       | M1    | M3 Pro   | M4 Ultra |
+-----------------------------+-------+----------+----------+
| gpu_cores                   | 8     | 18       | 64       |
| slc_bytes                   | 8 MB  | 24 MB    | 96 MB    |
| should_fuse_weights         | false | true     | true     |
| chain_decode_chunk          | 64    | 96       | 128      |
| max_prefill_chunk           | 4096  | 4096     | 4096     |
| q4_small_k_threshold        | 2048  | 1024     | 512      |
| wide_gemv_threshold         | 32768 | 32768    | 32768    |
+-----------------------------+-------+----------+----------+

Notice should_fuse_weights: on chips with a large System Level Cache (16+ MB, meaning Pro and above, plus M4 base), akunu will fuse Q/K/V weight matrices into a single contiguous buffer and dispatch one GEMV instead of three. The fused weights fit in SLC, so the second and third “GEMVs” hit cache instead of DRAM. On base M1/M2/M3, the SLC is too small for this to help, so fusion is disabled.

Also notice chain_decode_chunk: this controls how many tokens are chained into a single GPU command buffer submission. Larger chunks amortize the overhead of Metal command encoding. M4 can handle 128 tokens per chunk; M1 is limited to 64 due to its smaller command processor and narrower memory bus.
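The fusion heuristic can be expressed as a pure function of the SLC size. The 16 MB threshold comes from the text above; the function itself is an illustrative assumption, not akunu's actual ChipConfig construction.

```cpp
#include <cstdint>

// Fuse Q/K/V into one GEMV only when the System Level Cache can plausibly
// hold the fused weight block across all three projections.
bool should_fuse_weights(uint64_t slc_bytes) {
    return slc_bytes >= 16ull * 1024 * 1024;   // 16 MB: Pro-class chips and up
}
```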

Stage 8: Weight Quirks

After extracting the architecture descriptor, model_loader patches it with format-specific adjustments:

// MLX Llama uses NeoX (split-half) RoPE, not interleaved
if (weights->format() == MLX_SAFETENSORS && strstr(config.architecture, "llama"))
    arch.rope_kernel = "rope_neox_qkv_write_f16";

// Tie embeddings if the architecture says so, OR if output.weight is missing
if (!arch.tie_embeddings)
    arch.tie_embeddings = !weights->has_tensor("output.weight");

// Copy MLX quantization info
arch.quant_bits = weights->quant_bits();
arch.quant_group_size = weights->quant_group_size();

This is where the messy real-world details of model formats get resolved. MLX uses a different RoPE convention than GGUF for the same architecture. Some GGUF exports include output.weight and some do not. Rather than scattering format checks throughout the engine, they are all concentrated here in the loader.

Stage 9: RoPE Frequency Precomputation

init_rope_freqs(state, model_path);

Most models use standard RoPE with a base frequency theta. But LLaMA 3 introduced a complex wavelen-based frequency scaling scheme, and some models use simple linear scaling. akunu handles all three cases by precomputing the frequency divisors for each dimension of the rotary embedding at load time:

For each dimension i in [0, head_dim/2):
    base_freq = theta^(2i / head_dim)

    LLaMA 3 wavelen scaling:
        wavelen = 2 * pi * base_freq
        if wavelen > low_wavelen:
            freq[i] = base_freq * factor         (long wavelengths scaled)
        else if wavelen > high_wavelen:
            smooth = (orig_max_pos/wavelen - low_freq_factor)
                   / (high_freq_factor - low_freq_factor)
            freq[i] = base_freq / ((1-smooth)/factor + smooth)
        else:
            freq[i] = base_freq                  (short wavelengths unchanged)

    Linear scaling:
        freq[i] = base_freq * factor

The resulting frequency vector is uploaded to a GPU buffer (arch.rope_freqs). If no scaling is needed, the buffer stays null and the RoPE kernel computes frequencies on the fly from theta.

This precomputation avoids two problems. First, the wavelen formulas involve floating-point operations (pow, division, conditional branches) that would add latency if repeated on every token. Second, it keeps the GPU kernel simpler – it either reads precomputed frequencies from a buffer or computes the standard geometric series, never the complex LLaMA 3 formula.
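The pseudocode above translates almost line for line into C++. This is a standalone sketch, not akunu's init_rope_freqs; the parameter names (factor, low_freq_factor, high_freq_factor, orig_max_pos) mirror the pseudocode, and the return value is the per-dimension divisor vector that would be uploaded to the GPU buffer.

```cpp
#include <cmath>
#include <vector>

// Precompute per-dimension RoPE divisors with LLaMA 3 wavelen scaling.
// The rotation angle at position p for dimension i is p / div[i].
std::vector<float> rope_divisors(int head_dim, float theta, float factor,
                                 float low_freq_factor, float high_freq_factor,
                                 float orig_max_pos) {
    const float pi = 3.14159265358979f;
    float low_wavelen  = orig_max_pos / low_freq_factor;
    float high_wavelen = orig_max_pos / high_freq_factor;
    std::vector<float> div(head_dim / 2);
    for (int i = 0; i < head_dim / 2; i++) {
        float base = std::pow(theta, (2.0f * i) / head_dim);   // standard divisor
        float wavelen = 2.0f * pi * base;
        if (wavelen > low_wavelen) {
            div[i] = base * factor;                 // long wavelengths scaled
        } else if (wavelen > high_wavelen) {
            float smooth = (orig_max_pos / wavelen - low_freq_factor)
                         / (high_freq_factor - low_freq_factor);
            div[i] = base / ((1.0f - smooth) / factor + smooth);
        } else {
            div[i] = base;                          // short wavelengths unchanged
        }
    }
    return div;
}
```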

Stage 10: Config Validation

if (config.dim == 0 || config.n_layers == 0 || config.vocab_size == 0 ||
    config.n_heads == 0 || config.head_dim == 0) {
    set_error("Invalid model config: ...");
    delete state;
    return nullptr;
}

A simple sanity check. If any critical dimension is zero, something went wrong during metadata parsing (corrupt file, unsupported format version, missing keys). Fail immediately rather than producing cryptic GPU errors downstream.

Stage 11: Tokenizer Loading

The load_tokenizer function is more involved than you might expect. It needs to handle two completely different tokenizer sources:

GGUF path:
    vocab  <-- weights.get_string_array("tokenizer.ggml.tokens")
    scores <-- weights.get_float_array("tokenizer.ggml.scores")
    merges <-- weights.get_string_array("tokenizer.ggml.merges")
    type   <-- weights.get_metadata_string("tokenizer.ggml.model")
    bos_id <-- weights.get_metadata_int("tokenizer.ggml.bos_token_id")
    eos_id <-- weights.get_metadata_int("tokenizer.ggml.eos_token_id")

HuggingFace path (MLX models):
    load_hf_tokenizer(model_dir, hf_data)
    --> parses tokenizer.json + tokenizer_config.json

After loading the raw vocabulary, akunu also scans for implicit stop tokens that are not the “official” EOS token but should still terminate generation:

const char *stop_tokens[] = {
    "<|im_end|>",     // ChatML
    "<|endoftext|>",  // GPT-2 style
    "<|eot_id|>",     // LLaMA 3
    "<end_of_turn>",  // Gemma
    "</s>",           // legacy
    nullptr
};

For each of these, if the string exists in the vocabulary with a different ID than the primary EOS, it is registered as an additional EOS token. This means the decode loop does not need to know about chat templates – it just checks tokenizer.is_eos(tok) and the tokenizer handles the multi-EOS logic.
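Here is a minimal sketch of that multi-EOS bookkeeping, assuming a simple set-based representation. The real tokenizer also carries the vocabulary, scores, and merges; only the stop-token logic is shown.

```cpp
#include <cstdint>
#include <unordered_set>

// Tracks the primary EOS token plus any implicit stop tokens found
// in the vocabulary (<|im_end|>, <|eot_id|>, and friends).
struct StopSet {
    int32_t primary_eos = -1;
    std::unordered_set<int32_t> extra_eos;

    // Register an additional stop token; the primary EOS is never duplicated.
    void register_stop(int32_t id) {
        if (id >= 0 && id != primary_eos) extra_eos.insert(id);
    }
    // The single check the decode loop performs per generated token.
    bool is_eos(int32_t tok) const {
        return tok == primary_eos || extra_eos.count(tok) != 0;
    }
};
```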

Stage 12: KV Cache Allocation

int ctx = max_context > 0
    ? max_context
    : std::min((int)config.max_seq_len, chip.max_prefill_chunk);

state->kv_cache = KVCache::create(
    *state->device, config.n_layers, config.n_kv_heads, config.head_dim, ctx);

The KV cache is the largest single allocation in the system. For a 32-layer model with 8 KV heads, 128-dim heads, and a 4096-token context, the total is:

Per-layer buffer size:
    n_kv_heads * max_seq_len * head_dim * sizeof(FP16)
    = 8 * 4096 * 128 * 2
    = 8 MB

Total:
    n_layers * 2 (K + V) * 8 MB
    = 32 * 2 * 8 MB
    = 512 MB
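The arithmetic above, wrapped in a helper for clarity (a sketch; akunu's KVCache::create does this sizing internally):

```cpp
#include <cstdint>

// Total GPU bytes for the KV cache: per-layer K (or V) buffer is
// n_kv_heads * ctx * head_dim * sizeof(FP16), times 2 (K + V), times layers.
uint64_t kv_cache_bytes(int n_layers, int n_kv_heads, int head_dim, int ctx) {
    uint64_t per_layer = (uint64_t)n_kv_heads * ctx * head_dim * 2;  // FP16 = 2 bytes
    return (uint64_t)n_layers * 2 * per_layer;                       // K + V
}
```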

The KVCache struct is a flat POD container – no virtual calls, no reference counting, no linked lists:

KVCache
+--------------------------+
| n_layers       = 32      |
| n_kv_heads     = 8       |
| head_dim       = 128     |
| max_length     = 4096    |
| current_length = 0       |
| kv_stride      = 524288  |  (max_length * head_dim)
|                          |
| k_buffers[32]            |  <-- one GPU buffer per layer
| v_buffers[32]            |  <-- one GPU buffer per layer
+--------------------------+

All buffers are zero-filled at creation. The current_length field tracks how many positions have been written (by prefill or decode). The kv_stride is precomputed to avoid a multiplication in the attention kernel’s inner loop.

Stage 13: Scratch Buffer Allocation

state->scratch = ScratchBuffers::create(
    *state->device, state->config, ctx, chip.max_prefill_chunk);

Scratch buffers are the working memory for a single forward pass. They are allocated once and reused every time. No dynamic allocation ever happens in the hot path.

ScratchBuffers (decode -- single token)
+------------------------------------------------------+
| h0         [dim]              FP16   residual ping   |
| h1         [dim]              FP16   residual pong   |
| residual   [dim]              FP16   norm output     |
| qkv        [q_dim + 2*kv_dim] FP16   Q|K|V concat    |
| attn_out   [q_dim]            FP16   attention out   |
| post_norm  [dim]              FP16   Gemma temp      |
| ffn_gate   [2*ffn_dim]        FP16   gate|up fused   |
| ffn_up     [ffn_dim]          FP16   up projection   |
| ffn_act    [ffn_dim]          FP16   activation out  |
| logits     [vocab]            FP16   final logits    |
| token_ids  [max_chain]        U32    token buffer    |
+------------------------------------------------------+

ScratchBuffers (prefill -- batch)
+----------------------------------------+
| batch_h0        [chunk * dim]    FP16  |
| batch_h1        [chunk * dim]    FP16  |
| batch_residual  [chunk * dim]    FP16  |
| batch_q         [chunk * q_dim]  FP16  |
| batch_k         [chunk * kv_dim] FP16  |
| batch_v         [chunk * kv_dim] FP16  |
| batch_attn_out  [chunk * q_dim]  FP16  |
| batch_gate      [chunk * ffn]    FP16  |
| batch_up        [chunk * ffn]    FP16  |
| batch_act       [chunk * ffn]    FP16  |
| batch_post_norm [chunk * dim]    FP16  |
+----------------------------------------+

The dual sets of buffers – one for single-token decode, one for batched prefill – are a deliberate design choice. Decode uses GEMV (matrix-vector); prefill uses GEMM (matrix-matrix). They have completely different memory access patterns, and sharing buffers between them would either waste memory or require dynamic resizing.

Stage 14: Building the Greedy Dispatch Table

state->dispatch_table = build_dispatch_table(
    *state->device, *state->weights, state->config,
    state->arch, state->chip, state->kv_cache, state->scratch);

This is the most complex and most important step of model loading. The dispatch table is a flat array of DispatchCmd structs – one for every GPU kernel dispatch needed to process a single token through the entire transformer. Chapter 25 covers this in detail, but the high-level structure is:

Command sequence for one token:

  [0]   Embedding lookup
  [1]   Embedding scale (Gemma only)
  [2]   Initial RMSNorm

  For each layer:
    [3+N*k]   Fused QKV GEMV (or separate Q, K, V)
    [...]     Fused QK-norm + RoPE + KV write (or separate)
    [...]     Flash attention decode
    [...]     O projection GEMV
    [...]     Post-attn norm (Gemma only)
    [...]     Fused residual + FFN norm
    [...]     Fused gate+up GEMV (or separate)
    [...]     Fused SiLU+down GEMV (or separate activation + down)
    [...]     Post-FFN norm (Gemma only)
    [...]     Fused next attn norm

  [N-2] Output norm
  [N-1] Logit projection GEMV
  [N]   Argmax

For a 32-layer LLaMA model with all fusions enabled, this is typically around 160-200 commands.

Stage 15: Building the Sampled Dispatch Table

After the greedy table is built, model_loader constructs a second table for sampled (temperature > 0) decoding:

auto& greedy_cmds = state->dispatch_table.commands;
auto& sampled_cmds = state->sampled_dispatch_table.commands;

// Copy all commands except the final argmax
for (size_t i = 0; i + 1 < greedy_cmds.size(); i++)
    sampled_cmds.push_back(greedy_cmds[i]);

// Insert a Gumbel top-k noise kernel
DispatchCmd gumbel_cmd = DispatchCmd::make(gumbel_pso, Dim3(1), Dim3(1024));
gumbel_cmd.add_buffer(scratch.logits, 0, 0);
// ... set up GumbelTopKParams ...
gumbel_cmd.patch_type = DispatchCmd::PATCH_POSITION;

sampled_cmds.push_back(gumbel_cmd);
sampled_cmds.push_back(greedy_cmds.back());  // argmax still works!

The key insight: the Gumbel-max trick turns sampling into argmax. By adding temperature * Gumbel_noise to each logit before taking the argmax, you get a sample from the categorical distribution softmax(logits / temperature). This means the sampled table is identical to the greedy table except for one extra kernel inserted before the argmax. No CPU round-trip. No softmax. No probability computation. Just noise + argmax, entirely on the GPU.
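A CPU sketch of the trick makes the equivalence easy to see. This is not akunu's GPU kernel; the noise formula -log(-log(u)) for u ~ Uniform(0,1) is the standard Gumbel(0,1) sampler, and with temperature 0 the function degenerates to plain greedy argmax.

```cpp
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Draw a sample from softmax(logits / temperature) via the Gumbel-max trick:
// argmax over (logit + temperature * gumbel_noise). No softmax, no cumsum.
size_t gumbel_sample(const std::vector<float> &logits,
                     float temperature, std::mt19937 &rng) {
    std::uniform_real_distribution<float> u(1e-9f, 1.0f);
    size_t best = 0;
    float best_val = -INFINITY;
    for (size_t i = 0; i < logits.size(); i++) {
        float g = -std::log(-std::log(u(rng)));     // Gumbel(0,1) noise
        float v = logits[i] + temperature * g;
        if (v > best_val) { best_val = v; best = i; }
    }
    return best;
}
```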

Stage 16: Pre-Allocating Parameter Buffers

for (auto& cmd : cmds) {
    if (cmd.param_size > 0) {
        cmd.param_buf = device.allocate(cmd.param_bytes, cmd.param_size);
    }
}

Each DispatchCmd can carry up to 64 bytes of inline parameter data. Rather than using Metal’s setBytes (which copies the data on every dispatch), akunu pre-allocates a small GPU buffer for each command’s parameters and uses setBuffer instead. On Unified Memory Architecture (UMA) systems like Apple Silicon, these buffers are CPU-accessible, so updating a parameter (like the position for RoPE) is a simple pointer write – no copy, no upload.

Stage 17: Warmup Forward Pass

printf("Warming up...\n");
uint32_t warmup_token = state->tokenizer.bos_id();
state->device->write_buffer(state->scratch.token_ids, &warmup_token, 4);
state->device->begin_encoding();
state->device->encode_dispatch_table(&state->dispatch_table, 0, 1);
state->device->end_encoding_sync();
state->kv_cache.reset();

Metal compiles Pipeline State Objects (PSOs) lazily – the first time a kernel is dispatched, the GPU driver compiles it from the metallib. This compilation can take tens of milliseconds per kernel. If we did not warm up, the first real inference would stutter terribly.

The warmup runs a complete forward pass for one token (the BOS token), forcing every PSO in the dispatch table to compile. Then it resets the KV cache to zero length, erasing any state written during warmup. The result: when the user sends their first prompt, every kernel is already compiled and ready to go.

Stage 18: Returning the Model Handle

printf("Ready.\n\n");
return state;

The ModelState pointer is returned as an opaque akunu_model_t handle. From here on, the caller uses this handle with akunu_generate, akunu_prefill, akunu_embed, etc. All the complexity of model loading is hidden behind a single pointer.

The Complete Flow Diagram

Let us put it all together in a single visual:

   akunu_load_model("model.gguf", "akunu.metallib", 4096)
   |
   +-- Device::create_default()
   |     |
   |     +-- MTLCreateSystemDefaultDevice()
   |     +-- query GPU cores, family
   |
   +-- load_metallib(device, path)
   |     |
   |     +-- try user path, then 4 fallback paths
   |     +-- device.load_library(path) --> MTLLibrary
   |
   +-- is_whisper_ggml(path)?
   |     |
   |     +-- NO (standard LLM path)
   |
   +-- WeightProvider::open(path)
   |     |
   |     +-- detect_format(path)  --> GGUF or MLX
   |     +-- GGUF: WeightStore::open() --> parse header, mmap tensors
   |     +-- MLX:  MLXWeightStore::open() --> parse JSON, mmap safetensors
   |
   +-- weights.get_config()
   |     |
   |     +-- AkunuModelConfig { dim=4096, layers=32, heads=32, ... }
   |
   +-- arch_from_config("llama", 4096)
   |     |
   |     +-- ArchDescriptor { silu_gate, no-qk-norm, rope_qkv_write, ... }
   |
   +-- device.chip_config()
   |     |
   |     +-- ChipConfig { cores=10, slc=8MB, chunk=64, ... }
   |
   +-- Resolve MLX RoPE, tied embeds, quant params
   |
   +-- init_rope_freqs(state, path)
   |     |
   |     +-- LLaMA3 scaling? compute wavelen-based freqs
   |     +-- Linear scaling? compute simple scaled freqs
   |     +-- Neither? leave rope_freqs null (use theta at runtime)
   |
   +-- Validate config (all dims nonzero)
   |
   +-- load_tokenizer(state, path)
   |     |
   |     +-- GGUF: read vocab/scores/merges from metadata
   |     +-- MLX: load_hf_tokenizer(dir)
   |     +-- Register extra stop tokens
   |
   +-- KVCache::create(device, 32, 8, 128, 4096)
   |     |
   |     +-- 32 layers x 2 (K+V) x 8MB = 512 MB GPU memory
   |
   +-- ScratchBuffers::create(device, config, 4096, 4096)
   |     |
   |     +-- decode: h0, h1, residual, qkv, attn_out, ffn, logits
   |     +-- prefill: batch versions of all the above
   |
   +-- build_dispatch_table(device, weights, config, arch, chip, kv, scratch)
   |     |
   |     +-- ~180 DispatchCmd structs
   |     +-- all PSOs resolved, all buffers bound, all params set
   |
   +-- Build sampled_dispatch_table (greedy + Gumbel + argmax)
   |
   +-- Pre-allocate param_buf for every command
   |
   +-- Warmup: run 1 token forward pass to compile all PSOs
   |     +-- begin_encoding()
   |     +-- encode_dispatch_table(&table, 0, 1)
   |     +-- end_encoding_sync()  <-- blocks until GPU done
   |     +-- kv_cache.reset()     <-- erase warmup state
   |
   +-- return state  (as opaque akunu_model_t)

Memory Budget

Here is a concrete example for Llama 3.1 8B (Q4_0, 4096 context) on M3 Pro:

Component               Size
-------------------------------------
Model weights (mmap)    ~4.3 GB
KV cache (32 layers)    ~512 MB
Scratch (decode)        ~2.5 MB
Scratch (prefill)       ~550 MB
Metallib (shaders)      ~2 MB
Tokenizer               ~3 MB
Dispatch table           ~30 KB
-------------------------------------
Total GPU-resident      ~5.4 GB

The weights dominate, as expected. The KV cache is the second largest allocation. Everything else is a rounding error.

Error Handling Philosophy

You may have noticed that akunu_load_model uses a simple pattern:

if (something_failed) {
    set_error("...");
    delete state;
    return nullptr;
}

There are no exceptions. There are no error codes. The function either returns a valid model handle or returns nullptr and sets a thread-local error string that the caller can retrieve with akunu_last_error(). This is the standard C API pattern – it works across language boundaries (Swift, Python, Rust FFI) and avoids the overhead and complexity of C++ exception handling in a performance-critical codebase.
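A minimal sketch of this pattern, assuming a fixed-size thread-local buffer. akunu_last_error is named in the text; the buffer size and the exact signature of set_error are assumptions.

```cpp
#include <cstdarg>
#include <cstdio>
#include <cstring>

// One error string per thread; overwritten by each failing load.
static thread_local char g_last_error[256] = {0};

// printf-style error setter used at every failure site in the loader.
static void set_error(const char *fmt, ...) {
    va_list ap;
    va_start(ap, fmt);
    vsnprintf(g_last_error, sizeof(g_last_error), fmt, ap);
    va_end(ap);
}

// The C API the caller uses after receiving a nullptr handle.
const char *akunu_last_error() { return g_last_error; }
```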

The Whisper Path

For completeness, akunu_load_model also handles Whisper models. The flow is similar but with encoder-decoder-specific additions:

Whisper-specific stages:
  - load_whisper_model() instead of WeightProvider
  - arch_whisper() descriptor (LayerNorm, bias, cross-attention)
  - MelSpectrogram processor (with model's precomputed mel filters)
  - WhisperBuffers (encoder/decoder scratch)
  - Copy learned positional embeddings (encoder + decoder)
  - build_whisper_decode_table() (includes GPU suppress params)
  - Beam search buffers (5 beams x KV caches + intermediate buffers)

We will not dive deep into the Whisper path in this book, but it is worth knowing that the same akunu_load_model entry point handles both LLMs and audio models. The architecture descriptor pattern makes this possible without an explosion of conditional logic.

Summary

Model loading in akunu is a carefully ordered sequence of 18 stages that transforms a file path into a ready-to-run GPU inference pipeline. Every decision – architecture-specific behavior, chip-specific tuning, format-specific quirks – is resolved during loading and encoded into data structures (the ArchDescriptor, ChipConfig, DispatchTable) that the hot path reads but never modifies.

The result: zero allocation, zero branching, and zero format-awareness during inference. The next four chapters explore the data structures this loading process creates and how they are used to actually generate text at hundreds of tokens per second.