The Dispatch Table: Precompiled GPU Command Sequences

This chapter describes akunu’s central innovation. If you had to distill the entire engine design into a single idea, it would be this: build the GPU command sequence once at model load time, then replay it for every token.

Most inference engines construct GPU commands on-the-fly during inference. For each token, they iterate through the model layers, look up weight tensors by name, create kernel parameter structs, resolve pipeline state objects, and emit dispatch commands. This per-token overhead is small in isolation – maybe a few hundred microseconds – but it adds up. At 70 tokens per second, each token has a 14ms budget. Spending 0.5ms on command construction is 3.5% overhead, and on smaller models it can be much worse.

Akunu eliminates this overhead entirely. The DispatchTable is a flat array of DispatchCmd structs – plain old data, no pointers to chase, no virtual calls, no hash table lookups. Built once, replayed thousands of times. The only per-token work is patching a few position-dependent fields.

The Core Data Structures

DispatchCmd: A Single GPU Command

Every GPU operation in a forward pass – every GEMV, every norm, every attention, every RoPE – is represented as a single DispatchCmd:

struct DispatchCmd {
    Pipeline pso;                   // Compiled compute pipeline

    // Buffer bindings
    static constexpr int MAX_BUFFERS = 8;
    Buffer buffers[MAX_BUFFERS];
    uint32_t offsets[MAX_BUFFERS];
    int buffer_count;

    // Inline params (up to 64 bytes)
    uint8_t param_bytes[64];
    int param_size;
    int param_index;

    // Pre-allocated GPU buffer for static params
    Buffer param_buf;

    // Secondary params
    uint8_t param2_bytes[16];
    int param2_size;
    int param2_index;

    // Threadgroup memory
    int tg_mem_bytes;
    int tg_mem_index;

    // Dispatch geometry
    Dim3 grid;
    Dim3 threadgroup;
    bool use_dispatch_threads;

    // Per-token patching
    enum PatchType : uint8_t {
        PATCH_NONE = 0,
        PATCH_TOKEN_OFFSET,
        PATCH_POSITION,
        PATCH_KV_SEQ_LEN,
        PATCH_POS_AND_KV,
        PATCH_ARGMAX_OUTPUT,
    };
    PatchType patch_type;
    int patch_offset_1;
    int patch_offset_2;
};

Let’s examine each section.

Pipeline (pso): The pre-compiled compute pipeline state object. Resolved once during table building via device.get_pipeline(). No string lookup at dispatch time.

Buffer bindings (buffers[], offsets[], buffer_count): Fixed-size array of up to 8 buffer bindings. Each entry has a Buffer handle and a byte offset. These are the weight buffers, scratch buffers, and KV cache buffers that the kernel reads from and writes to. All resolved at build time.

Inline parameters (param_bytes[64], param_size, param_index): The kernel’s parameter struct (dimensions, epsilon, strides, etc.) stored inline as raw bytes. 64 bytes is enough for every kernel parameter struct in akunu. The param_index is the argument buffer index for setBytes or setBuffer.

Pre-allocated param buffer (param_buf): For commands with static (non-position-dependent) parameters, a GPU buffer is pre-allocated containing the parameter data. At replay time, setBuffer is used instead of setBytes, avoiding the per-dispatch copy entirely.

Secondary params (param2_bytes[16]): Some kernels need two separate setBytes calls at different argument indices. The secondary param slot handles this.

Threadgroup memory (tg_mem_bytes, tg_mem_index): Some kernels require threadgroup memory allocation (e.g., GEMM tile staging). This is pre-computed.

Dispatch geometry (grid, threadgroup, use_dispatch_threads): Pre-computed grid and threadgroup dimensions. use_dispatch_threads selects between dispatchThreadgroups and dispatchThreads (the latter auto-computes grid size from total threads).

Per-token patching (patch_type, patch_offset_1, patch_offset_2): This is the key mechanism that makes replay possible despite per-token variation. More on this below.

Size and Alignment

A single DispatchCmd is approximately 280 bytes. For a 32-layer Llama model, the table contains roughly 200-250 commands, totaling ~56-70 KB. This fits comfortably in L2 cache, meaning the replay loop operates almost entirely from cache.1

DispatchTable: The Complete Sequence

struct DispatchTable {
    std::vector<DispatchCmd> commands;      // Hot path: dense command array
    std::vector<DispatchLabel> labels;      // Cold path: profiling labels
    int tokens_per_tg;
};

The hot/cold split is deliberate. Profiling labels (48-byte strings like “layer.0.attention”) are stored in a parallel vector, separate from the command array. During inference, the labels are never accessed – the inner loop iterates only the dense commands vector. During profiling, labels are accessed by index to annotate timing data.

Per-Token Patching

Here is the fundamental challenge: most of the forward pass is identical for every token – same weights, same kernels, same dispatch geometry. But a few things change:

  1. Position: RoPE needs the current sequence position. Attention needs the KV sequence length.
  2. Token offset: The embedding lookup reads from token_ids[token_index]. The argmax writes to token_ids[token_index + 1].

Akunu solves this with a small enum of patch types:

  Patch type           What changes                             Where
  PATCH_NONE           Nothing                                  Most commands (GEMV, norms, etc.)
  PATCH_TOKEN_OFFSET   buffers[0].offset = token_index * 4      Embedding lookup
  PATCH_POSITION       param_bytes[offset1] = position          RoPE
  PATCH_KV_SEQ_LEN     param_bytes[offset1] = position + 1      Flash attention
  PATCH_POS_AND_KV     Both position and KV length              Combined RoPE+attention params
  PATCH_ARGMAX_OUTPUT  buffers[1].offset = (token_index+1) * 4  Argmax output

The patch_offset_1 and patch_offset_2 fields specify the byte offsets within param_bytes where the position/KV-length values should be written. These offsets are computed at build time using offsetof():

// From table_builder.cpp:
cmd.patch_type = DispatchCmd::PATCH_POSITION;
cmd.patch_offset_1 = offsetof(decltype(rope_params), pos);

At replay time, the patching is a simple memcpy + write:

uint8_t patched[64];
memcpy(patched, cmd.param_bytes, cmd.param_size);
*(uint32_t *)(patched + cmd.patch_offset_1) = (uint32_t)pos;
[enc setBytes:patched length:cmd.param_size atIndex:cmd.param_index];

For the ~200 commands in a typical table, only about 100 need patching (RoPE, attention, embedding, argmax). The other ~100 use PATCH_NONE and bind their pre-allocated param buffer with zero per-token work.

Table Building: The build_dispatch_table Function

The table builder in src/core/table_builder.cpp constructs the dispatch table by walking through the transformer architecture once:

DispatchTable build_dispatch_table(Device& device,
                                   WeightProvider& weights,
                                   const AkunuModelConfig& cfg,
                                   const ArchDescriptor& arch,
                                   const ChipConfig& chip,
                                   const KVCache& kv_cache,
                                   const ScratchBuffers& scratch);

The function takes everything it needs as parameters: the device (for pipeline creation), weights (for buffer handles), model config (for dimensions), arch descriptor (for architecture-specific behavior), chip config (for hardware tuning), KV cache (for cache buffer handles), and scratch buffers (for intermediate buffer handles).

It returns a fully-constructed DispatchTable with all commands, all buffers resolved, all pipelines compiled and cached, and all parameter structs serialized.

The CmdBuilder Helper

To reduce boilerplate, the table builder uses a fluent builder:

struct CmdBuilder {
    DispatchCmd cmd;
    DispatchTable& table;

    CmdBuilder& buf(Buffer b, int offset, int index) { ... }
    template <typename T>
    CmdBuilder& params(const T& p, int index) { ... }
    CmdBuilder& label(const char *fmt, ...) { ... }  // printf-style
    CmdBuilder& threads() { cmd.use_dispatch_threads = true; return *this; }
    CmdBuilder& patch(DispatchCmd::PatchType type, int offset1, int offset2 = 0) { ... }
    CmdBuilder& tg_mem(int bytes, int index) { ... }
    void emit() { table.commands.push_back(cmd); }
};

Usage:

CmdBuilder(table, attn_pso, Dim3(1, n_heads), Dim3(1024))
    .buf(scratch.qkv, scratch.qkv_q_offset, 0)
    .buf(kv_cache.k_buffers[layer], 0, 1)
    .buf(kv_cache.v_buffers[layer], 0, 2)
    .buf(scratch.attn_out, 0, 3)
    .params(attn_params, 4)
    .patch(DispatchCmd::PATCH_KV_SEQ_LEN,
           offsetof(decltype(attn_params), kv_seq_len))
    .label("layer.%d.attention", layer)
    .emit();

The Complete Forward Pass

The table builder emits commands in this order:

Dispatch Table Command Sequence (one token forward pass):

  1. EMBEDDING:  token_ids[i] -> h0           [PATCH_TOKEN_OFFSET]
     (embedding_scale — Gemma only)
  2. INITIAL RMSNORM: h0 -> residual

  3. LAYER LOOP (x n_layers):
     ├─ 3a. QKV GEMV (fused or 3 separate)
     ├─ 3b. QK-Norm + RoPE + KV Write        [PATCH_POSITION]
     ├─ 3c. Flash Attention                   [PATCH_KV_SEQ_LEN]
     ├─ 3d. O Projection GEMV
     ├─ (post_attn_norm — Gemma only)
     ├─ 3e. Fused Residual + FFN Norm
     ├─ 3f. Gate+Up GEMV (fused or 2 separate)
     ├─ 3g. Fused SiLU+Down GEMV (or separate)
     ├─ (post_ffn_norm — Gemma only)
     └─ 3h. Fused Next Attention Norm

  4. OUTPUT NORM: fused residual + RMSNorm
  5. LOGIT PROJECTION: residual -> vocab logits
  6. ARGMAX: logits -> token_ids[i+1]         [PATCH_ARGMAX_OUTPUT]

One Layer’s Commands in Detail

For a 32-layer Llama 3 8B model with Q4_0 quantization, fused QKV, and fused SiLU+down, one layer produces approximately 7 commands:

  #  Command                  Pipeline                   Patch type  Buffers
  1  QKV GEMV (fused)         gemv_q4_0                  NONE        residual, fused_qkv_w, qkv_buf
  2  RoPE + KV Write          rope_qkv_write_f16         POSITION    qkv_buf, k_cache[L], v_cache[L]
  3  Flash Attention          flash_attn_decode_par_f16  KV_SEQ_LEN  qkv_buf, k_cache[L], v_cache[L], attn_out
  4  O Projection             gemv_q4_0                  NONE        attn_out, o_weight, residual
  5  Fused Residual+FFN Norm  residual_rmsnorm_f16       NONE        residual, h0, ffn_norm_w, h1, attn_out
  6  Gate+Up GEMV (fused)     gemv_q4_0                  NONE        attn_out, fused_gu_w, ffn_gate
  7  Fused SiLU+Down          gemv_q4_0_silu             NONE        ffn_gate, ffn_gate+off, down_w, residual

Plus the fused-next-attention-norm for the transition to the next layer (1 more command). So roughly 8 commands per layer.

For 32 layers: 8 * 32 = 256 per-layer commands. Plus embedding (1), initial norm (1), output norm (1), logit projection (1), argmax (1) = 261 total commands.

The actual count varies by architecture and fusion decisions:

  • Llama (fused): ~7-8 commands/layer, ~230-260 total
  • Llama (unfused): ~10-12 commands/layer, ~330-390 total
  • Gemma (fused): ~9-10 commands/layer (extra post-norms), ~300-330 total
  • Qwen3 (fused): ~7-8 commands/layer, ~230-260 total

The table builder prints the count at the end:

printf("Dispatch table built: %zu commands per token\n", cmds.size());

The Replay Loop: encode_chain

The generic replay function in dispatch_table.h is straightforward:

inline void encode_chain(Device& device, const DispatchTable& table,
                         int start_position, int count) {
    const auto& cmds = table.commands;
    const int n_cmds = (int)cmds.size();

    for (int tok = 0; tok < count; tok++) {
        int pos = start_position + tok;

        for (int c = 0; c < n_cmds; c++) {
            const auto& cmd = cmds[c];

            device.set_pipeline(cmd.pso);

            // Set buffers (with per-token offset patching)
            for (int b = 0; b < cmd.buffer_count; b++) {
                int offset = cmd.offsets[b];
                if (cmd.patch_type == PATCH_TOKEN_OFFSET && b == 0)
                    offset = tok * 4;
                if (cmd.patch_type == PATCH_ARGMAX_OUTPUT && b == 1)
                    offset = (tok + 1) * 4;
                device.set_buffer(cmd.buffers[b], offset, b);
            }

            // Set params (with position patching)
            if (cmd.param_size > 0) {
                if (needs_patching(cmd.patch_type)) {
                    uint8_t patched[64];
                    memcpy(patched, cmd.param_bytes, cmd.param_size);
                    patch_position(patched, cmd, pos);
                    device.set_bytes(patched, cmd.param_size, cmd.param_index);
                } else {
                    device.set_bytes(cmd.param_bytes, cmd.param_size, cmd.param_index);
                }
            }

            // Secondary params, threadgroup memory
            if (cmd.param2_size > 0)
                device.set_bytes(cmd.param2_bytes, cmd.param2_size, cmd.param2_index);
            if (cmd.tg_mem_bytes > 0)
                device.set_threadgroup_memory(cmd.tg_mem_bytes, cmd.tg_mem_index);

            // Dispatch
            if (cmd.use_dispatch_threads)
                device.dispatch_threads(cmd.grid, cmd.threadgroup);
            else
                device.dispatch(cmd.grid, cmd.threadgroup);
        }
    }
}

This is the generic version that works through the Device virtual interface. It is correct but slow – each device.set_pipeline() call goes through a virtual function dispatch.

The Metal Fast Path

MetalDevice overrides encode_dispatch_table to eliminate virtual calls:

void MetalDevice::encode_dispatch_table(const void *table_ptr,
                                        int start_position, int count) {
    const DispatchTable &table = *(const DispatchTable *)table_ptr;
    const DispatchCmd *__restrict cmds = table.commands.data();
    const int n_cmds = (int)table.commands.size();
    id<MTLComputeCommandEncoder> enc = STATE.encoder;
    uint8_t patched[64];  // reused stack buffer

    for (int tok = 0; tok < count; tok++) {
        const uint32_t pos = (uint32_t)(start_position + tok);
        const uint32_t kv_len = pos + 1;

        for (int c = 0; c < n_cmds; c++) {
            const DispatchCmd &__restrict cmd = cmds[c];
            const auto pt = cmd.patch_type;

            [enc setComputePipelineState:cmd.pso.handle];

            // Buffers: switch on patch type for offset patching
            switch (pt) {
            case PATCH_TOKEN_OFFSET:
                [enc setBuffer:cmd.buffers[0].handle offset:tok*4 atIndex:0];
                for (int b = 1; b < cmd.buffer_count; b++)
                    [enc setBuffer:cmd.buffers[b].handle offset:cmd.offsets[b] atIndex:b];
                break;
            case PATCH_ARGMAX_OUTPUT:
                [enc setBuffer:cmd.buffers[0].handle offset:cmd.offsets[0] atIndex:0];
                [enc setBuffer:cmd.buffers[1].handle offset:(tok+1)*4 atIndex:1];
                for (int b = 2; b < cmd.buffer_count; b++)
                    [enc setBuffer:cmd.buffers[b].handle offset:cmd.offsets[b] atIndex:b];
                break;
            default:
                for (int b = 0; b < cmd.buffer_count; b++)
                    [enc setBuffer:cmd.buffers[b].handle offset:cmd.offsets[b] atIndex:b];
                break;
            }

            // Params: setBytes for patched, setBuffer for static
            if (cmd.param_size > 0) {
                switch (pt) {
                case PATCH_POSITION:
                    memcpy(patched, cmd.param_bytes, cmd.param_size);
                    *(uint32_t*)(patched + cmd.patch_offset_1) = pos;
                    [enc setBytes:patched length:cmd.param_size atIndex:cmd.param_index];
                    break;
                case PATCH_POS_AND_KV:
                    memcpy(patched, cmd.param_bytes, cmd.param_size);
                    *(uint32_t*)(patched + cmd.patch_offset_1) = pos;
                    *(uint32_t*)(patched + cmd.patch_offset_2) = kv_len;
                    [enc setBytes:patched length:cmd.param_size atIndex:cmd.param_index];
                    break;
                case PATCH_KV_SEQ_LEN:
                    memcpy(patched, cmd.param_bytes, cmd.param_size);
                    *(uint32_t*)(patched + cmd.patch_offset_1) = kv_len;
                    [enc setBytes:patched length:cmd.param_size atIndex:cmd.param_index];
                    break;
                default:
                    if (cmd.param_buf.handle)
                        [enc setBuffer:cmd.param_buf.handle offset:0 atIndex:cmd.param_index];
                    else
                        [enc setBytes:cmd.param_bytes length:cmd.param_size atIndex:cmd.param_index];
                    break;
                }
            }

            // Secondary params + threadgroup memory + dispatch
            // ...
        }
    }
}

Key optimizations in the fast path:

  1. No virtual calls. Direct ObjC message sends to the encoder.
  2. __restrict pointers. Hints to the compiler that cmds does not alias anything else.
  3. Stack-allocated patched buffer. Reused across all commands, no heap allocation.
  4. Switch on patch_type. The common case (PATCH_NONE) falls through to the simple buffer-binding loop.
  5. Pre-allocated param buffers. Static params use setBuffer (zero per-token work) instead of setBytes (which copies data into the command buffer).

The setBuffer/setBytes Split

This deserves special attention because it is the most subtle optimization in the dispatch table.

Most commands (~60%) have PATCH_NONE – their parameters do not change between tokens. For these, the table builder pre-allocates a GPU buffer containing the parameter data:

// At the end of build_dispatch_table():
for (auto& cmd : cmds) {
    if (cmd.param_size > 0) {
        cmd.param_buf = device.allocate(cmd.param_bytes, cmd.param_size);
    }
}

At replay time, these commands use setBuffer to bind the pre-allocated buffer. On UMA, this is essentially free – Metal just records a pointer in the command buffer. No data copy, no coherency cost.

For position-patched commands (~40%), setBytes is used. This copies the patched parameter data (up to 64 bytes) into the command buffer inline. This is also fast (<4KB threshold for Metal inline data), but it does involve a memcpy and a few bytes of command buffer space per dispatch.

The net effect: for a 260-command table, approximately 160 commands use zero-cost setBuffer and 100 commands use low-cost setBytes. The total per-token CPU overhead for parameter binding is roughly:

  100 commands × 64 bytes/memcpy = 6.4 KB of memcpy

At memcpy throughput of ~50 GB/s on Apple Silicon, this takes approximately 0.1 microseconds. Negligible.

Chain Decode: The Complete Picture

Let me trace through a complete chain decode of 128 tokens to show how everything fits together:

Chain Decode: 128 Tokens in One Submission

  [1] CPU: Write first token ID to token_ids[0]
  [2] CPU: device.begin_encoding()
  [3] CPU: encode_dispatch_table(table, start_pos, 128)
      ─── 260 commands x 128 tokens = 33,280 encoded GPU dispatches ───
  [4] CPU: device.end_encoding_sync()
  [5] GPU: executes all 33,280 dispatches sequentially
      Token 0:   embed -> 32 layers -> logit -> argmax -> token_ids[1]
      Token 1:   embed(token_ids[1]) -> 32 layers -> ... -> token_ids[2]
      ...
      Token 127: ... -> token_ids[128]
  [6] CPU: Read 128 tokens from token_ids[1..128]

The crucial insight: tokens are chained on the GPU. Token 0’s argmax writes its output to token_ids[1]. Token 1’s embedding reads from token_ids[1]. This data dependency is resolved by the GPU’s sequential execution within a single command buffer – no CPU round-trip between tokens.2

The CPU’s only job is:

  1. Write the first token (4 bytes)
  2. Encode the command sequence (~2-5ms of Metal API calls for 128 tokens)
  3. Submit and wait (~14ms * 128 = ~1.8 seconds of GPU time)
  4. Read the results (128 * 4 = 512 bytes)

Pre-Allocated Parameter Buffers

After building all commands, the table builder pre-allocates GPU buffers for every command’s parameters:

for (auto& cmd : cmds) {
    if (cmd.param_size > 0) {
        cmd.param_buf = device.allocate(cmd.param_bytes, cmd.param_size);
    }
}

This creates one small Metal buffer (16-64 bytes) per command. For 260 commands, that is about 16 KB of additional GPU memory. The benefit is that setBuffer at replay time is a simple pointer bind, while setBytes would copy data into the command buffer.

For position-patched commands, param_buf is still allocated but is not used during replay (the patched version is passed via setBytes instead). This is a minor waste (~6 KB) but simplifies the code.

Profiling Labels

Labels are stored separately to keep the hot-path array dense:

struct DispatchLabel {
    char text[48];
};

struct DispatchTable {
    std::vector<DispatchCmd> commands;   // HOT: iterated every token
    std::vector<DispatchLabel> labels;   // COLD: accessed only during profiling
};

The label for command i is labels[i].text. Labels are set during table building:

table.set_last_label("layer.0.attention");

And accessed during profiling:

const char *label = table.label_at(cmd_index);

The finalize_labels() call at the end of table building pads the labels vector to match the commands vector size, ensuring safe index access.

Comparison to Other Engines

To appreciate the dispatch table approach, compare it to how other inference engines handle command construction:

  Engine     Command construction                            Per-token overhead
  llama.cpp  Build ggml graph -> schedule -> Metal encode    Graph traversal + scheduling per token
  MLX        Lazy evaluation DAG -> compile -> Metal encode  DAG construction + compilation per sequence
  vLLM       PyTorch eager/compiled -> CUDA kernels          Python overhead + CUDA launch per kernel
  Akunu      Pre-built dispatch table -> flat array replay   Patch ~100 uint32 values + Metal encode

Akunu’s approach trades flexibility (you cannot dynamically change the model graph) for speed (zero per-token overhead beyond patching and Metal API calls). This is the right trade for inference, where the computation graph is static.

Limitations

The dispatch table approach has some constraints:

  1. Static graph only. The forward pass must be identical for every token (modulo position patching). Dynamic architectures with variable-length layers or conditional computation would not fit this model.

  2. Memory overhead. Each command stores buffer handles and param bytes inline. For a large model with many layers, this is ~70 KB – trivial compared to the gigabytes of weights.

  3. Single-device. The table is built for one device and assumes all buffers are on that device. Multi-device (tensor parallel) would require multiple tables with cross-device synchronization.

  4. No dynamic batching during decode. The table processes one token at a time (repeated N times). Batching multiple independent sequences would require separate tables or a more complex patching scheme.

Summary

The dispatch table is akunu’s key innovation. It pre-compiles the entire forward pass into a flat array of ~200-260 DispatchCmd structs, each containing a pre-resolved pipeline, pre-bound buffers, pre-serialized parameters, and pre-computed dispatch geometry. At inference time, the replay loop iterates this array, patching only position-dependent fields. The Metal fast path eliminates virtual call overhead, and the setBuffer/setBytes split minimizes per-token parameter binding cost. The result: the CPU spends almost all of its time doing useful Metal API calls, not constructing or resolving commands.



  1. The L2 cache on Apple Silicon GPU cores is approximately 256 KB per core cluster (estimated from die analysis). A 70 KB dispatch table fits entirely within a single core cluster’s L2. See Dougall Johnson, “Apple GPU Architecture,” https://dougallj.github.io/applegpu/.

  2. Metal guarantees that compute dispatches within a single compute command encoder execute in order. The data dependency between token N’s argmax output and token N+1’s embedding input is automatically satisfied by this ordering guarantee. No explicit barriers are needed. See Apple, “Metal Programming Guide: Command Organization,” https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu.