Adding a New Metal Kernel

This chapter walks through adding a new Metal kernel to Akunu from scratch. We will build a complete example – a hypothetical vector_scale_f16 kernel that multiplies every element of an F16 buffer by a scalar – and trace every step from the .metal source file to the dispatch table integration and test.

The process has six steps:

1. Write the .metal shader
2. Define parameter structs (shared between CPU and GPU)
3. Add to the metallib build
4. Dispatch from Device
5. Integrate into table_builder (if it is part of inference)
6. Write a test

Let us go through each one.

Step 1: Write the .metal Shader

Metal shader files live in Akunu’s shader directory. You can create a new file or add the kernel to an existing one; the example below works either way:

// In an existing .metal file, or a new one

#include <metal_stdlib>
using namespace metal;

/// Multiply every element of an F16 buffer by a scalar.
///
/// Grid:  1D, total threads >= count
/// TG:    256 threads (or whatever fits)
///
/// Buffers:
///   [0] device half *data    -- input/output (in-place)
///
/// Params:
///   [1] constant float &scale
///   [2] constant uint  &count
kernel void vector_scale_f16(
    device half *data         [[buffer(0)]],
    constant float &scale     [[buffer(1)]],
    constant uint &count      [[buffer(2)]],
    uint tid                  [[thread_position_in_grid]])
{
    if (tid >= count) return;
    data[tid] = half(float(data[tid]) * scale);
}

Key points for Metal kernel design in Akunu:

Buffer binding convention. Akunu binds buffers by index using device.set_buffer(buf, offset, index) or cmd.add_buffer(buf, offset, index). The index in the Metal signature ([[buffer(N)]]) must match. By convention, input buffers come first, output buffers next, and parameters last.

Parameter passing. Small parameters (< 4KB) are passed via setBytes (which Akunu wraps as device.set_bytes() or cmd.set_params()). For dispatch tables, params are pre-allocated as GPU buffers for zero-copy patching.

Thread indexing. Use [[thread_position_in_grid]] for 1D kernels, or [[threadgroup_position_in_grid]] + [[thread_position_in_threadgroup]] for kernels that need threadgroup-level coordination. Akunu supports both dispatch() (grid in threadgroups) and dispatch_threads() (grid in threads).

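Bounds checking. Always check tid >= count at the top. Akunu’s dispatch grid is rounded up to a multiple of the threadgroup size, so trailing threads may be out of bounds: 1000 elements in 256-thread groups launch 4 x 256 = 1024 threads, leaving 24 threads with no element to process.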

Data types. Use half for F16, bfloat for BF16 (M4+ only via metal_bfloat header), float for F32, uint for U32. For quantized formats, you typically read uint or uchar and manually unpack.
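To see why the bounds check matters, the following CPU-only sketch models a 1D dispatch specified in whole threadgroups: it launches ceil(count / tg_size) * tg_size "threads", runs the kernel body once per thread index, and counts how many threads the guard turns away. The names `simulate_scale_dispatch`, `DispatchStats`, and `ceil_div` are hypothetical, and `float` stands in for `half` because portable C++ has no standard 16-bit float type.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Number of threadgroups needed to cover `count` items (ceiling division).
constexpr uint32_t ceil_div(uint32_t count, uint32_t tg_size) {
    return (count + tg_size - 1) / tg_size;
}

// Outcome of one simulated dispatch.
struct DispatchStats {
    uint32_t threads_launched;  // groups * tg_size
    uint32_t threads_skipped;   // threads turned away by the bounds check
};

// CPU model of vector_scale_f16 dispatched in whole threadgroups.
// Each loop iteration plays the role of one GPU thread with index `tid`;
// float stands in for half.
DispatchStats simulate_scale_dispatch(std::vector<float>& data, float scale,
                                      uint32_t tg_size) {
    assert(tg_size > 0);
    const uint32_t count = static_cast<uint32_t>(data.size());
    const uint32_t groups = ceil_div(count, tg_size);
    DispatchStats stats{groups * tg_size, 0};
    for (uint32_t tid = 0; tid < stats.threads_launched; ++tid) {
        if (tid >= count) {          // same guard as the Metal kernel
            ++stats.threads_skipped;
            continue;                // the kernel would `return` here
        }
        data[tid] = data[tid] * scale;
    }
    return stats;
}
```

With 1000 elements and 256-thread groups, this launches 1024 threads and the guard skips the last 24; without it, those 24 threads would write past the end of `data`.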

Step 2: Define Parameter Structs

If your kernel takes structured parameters (more than a single scalar), define a shared struct. Akunu’s convention is to use C structs that are layout-compatible between CPU and GPU. Metal’s constant buffer binding expects 16-byte aligned access for best performance, so pad structs to 16-byte boundaries:

// Shared between CPU (.cpp) and GPU (.metal)
struct VectorScaleParams {
    uint32_t count;
    float    scale;
    uint32_t _pad0;   // pad to 16 bytes
    uint32_t _pad1;
};

For simple cases like our example, you can skip the struct and pass individual scalars. But for anything with more than 2-3 parameters, a struct keeps things clean and avoids off-by-one buffer index mistakes.
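Nothing checks that the CPU and GPU declarations of a struct agree, so one option is to pin the CPU-side layout with compile-time assertions. A minimal sketch, assuming the struct is mirrored by hand in C++; the expected offsets are the ones this particular struct should have on both sides.

```cpp
#include <cstddef>
#include <cstdint>

// CPU-side mirror of the struct declared in the .metal file.
struct VectorScaleParams {
    uint32_t count;
    float    scale;
    uint32_t _pad0;   // pad to 16 bytes
    uint32_t _pad1;
};

// Fail the build if the CPU layout drifts from what the kernel reads.
static_assert(sizeof(float) == 4, "Metal float is 32-bit");
static_assert(sizeof(VectorScaleParams) == 16, "VectorScaleParams must be 16 bytes");
static_assert(sizeof(VectorScaleParams) % 16 == 0, "keep params padded to 16 bytes");
static_assert(offsetof(VectorScaleParams, count) == 0, "count must be at byte 0");
static_assert(offsetof(VectorScaleParams, scale) == 4, "scale must be at byte 4");
```

If a field is later added or reordered without updating the padding, the build breaks instead of the kernel silently reading the wrong bytes.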

The parameter structs used throughout Akunu follow a consistent pattern:

// GEMM parameters (used by all GEMV/GEMM kernels)
struct {
    uint32_t M, N, K, lda, ldb, ldc;
    float alpha, beta;
} params;  // 32 bytes, naturally aligned

// MLX quantized GEMV parameters
struct {
    uint32_t M, N, K, group_size, bits, weight_bytes, _p0, _p1;
} mlx_params;  // 32 bytes

// Norm parameters
struct {
    uint32_t dim;
    float eps;
    uint32_t _p0, _p1;
} norm_params;  // 16 bytes

Step 3: Add to the Metallib Build

Akunu compiles all .metal files into a single .metallib using xcrun metal and xcrun metallib. The build process is:

*.metal files
     |
     v
xcrun metal -c -target air64-apple-macos14.0 -o file.air file.metal
     |
     v (repeat for each .metal file)
     |
xcrun metallib -o akunu.metallib *.air

If you created a new .metal file, add it to the build script or Makefile. If you added to an existing file, it will be picked up automatically.

To verify your kernel compiled successfully:

# List all functions in the metallib
xcrun metal-objdump --disassemble-all akunu.metallib 2>&1 | grep "^_"

You should see _vector_scale_f16 in the output.

Step 4: Dispatch from Device

For one-off or encoder-phase dispatches (not in the per-token dispatch table), you call the kernel directly through Device:

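// Assumes an encoding session is already open, as in the Step 6 test
// (dev.begin_encoding() before this call, dev.end_encoding_sync() after).
void apply_scale(Device& dev, Buffer data, float scale, int count) {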
    // Get (or create and cache) the pipeline state object
    Pipeline pso = dev.get_pipeline("vector_scale_f16");

    // Set up the dispatch
    dev.set_pipeline(pso);
    dev.set_buffer(data, 0, 0);       // buffer index 0
    dev.set_bytes(&scale, sizeof(float), 1);  // buffer index 1
    uint32_t n = (uint32_t)count;
    dev.set_bytes(&n, sizeof(uint32_t), 2);   // buffer index 2

    // Dispatch: total threads = count, threadgroup size = 256
    dev.dispatch_threads(Dim3(count), Dim3(256));
}

The get_pipeline() call compiles the kernel function from the metallib on first use and caches the resulting MTLComputePipelineState. Subsequent calls with the same kernel name return the cached pipeline instantly.

Function constants. If your kernel needs compile-time specialization (e.g., head dimension, group size), use function constants:

// In Metal:
constant uint HEAD_DIM [[function_constant(0)]];

// In C++:
uint32_t fc_indices[] = {0};
uint32_t fc_values[] = {128};  // HEAD_DIM = 128
Pipeline pso = dev.get_pipeline("my_kernel",
    "my_kernel_hd128",       // cache key
    fc_indices, fc_values, 1);

Function constants create specialized pipeline variants at runtime. Akunu caches each variant by its cache key string.
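The caching behavior described in this step can be modeled on the CPU as a map from cache key to pipeline. In this sketch `PipelineCache` and `FakePipeline` are hypothetical stand-ins: the real Device stores MTLComputePipelineState objects, and a cache miss goes through Metal rather than bumping a counter.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <unordered_map>

// Stand-in for a compiled pipeline: records its function and specialization.
struct FakePipeline {
    std::string function_name;
    uint32_t    head_dim;  // 0 = no function constant set
};

class PipelineCache {
public:
    // Returns the pipeline cached under `cache_key`, "compiling" it on first use.
    const FakePipeline& get(const std::string& function_name,
                            const std::string& cache_key,
                            uint32_t head_dim = 0) {
        assert(!cache_key.empty());
        auto it = cache_.find(cache_key);
        if (it != cache_.end()) return it->second;  // hit: no compile
        ++compile_count_;                            // miss: compile once
        it = cache_.emplace(cache_key, FakePipeline{function_name, head_dim}).first;
        return it->second;  // unordered_map keeps element references stable
    }
    int compile_count() const { return compile_count_; }

private:
    std::unordered_map<std::string, FakePipeline> cache_;
    int compile_count_ = 0;
};
```

Because the model looks variants up by key string alone, reusing one key with different function-constant values would silently return the first variant. If Akunu’s cache behaves the same way, the key should encode every specialized value, as my_kernel_hd128 does for HEAD_DIM.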

Step 5: Integrate into the Dispatch Table

If your kernel runs as part of the per-token forward pass, it needs to be added to the dispatch table built by build_dispatch_table() in table_builder.cpp. The dispatch table is built once during model initialization and then replayed for every token.

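Because dispatch-table parameters live in a single pre-allocated buffer per command, the two examples below assume a variant of vector_scale_f16 that reads a VectorScaleParams struct (Step 2) from [[buffer(1)]], in place of the two scalar bindings at indices 1 and 2 used in Step 1. Using the CmdBuilder helper: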

// Inside build_dispatch_table(), at the appropriate point
// in the layer loop:
{
    Pipeline pso = device.get_pipeline("vector_scale_f16");

    struct {
        uint32_t count;
        float scale;
        uint32_t _p0, _p1;
    } params = {(uint32_t)dim, some_scale_value, 0, 0};

    CmdBuilder(table, pso, Dim3(dim), Dim3(256))  // grid = dim threads
        .threads()                    // use dispatch_threads (grid in threads)
        .buf(scratch.h0, 0)           // buffer binding 0
        .params(params, 1)            // params at buffer index 1
        .label("layer.%d.scale", layer)
        .emit();
}

Or using the raw DispatchCmd API:

{
    Pipeline pso = device.get_pipeline("vector_scale_f16");

    DispatchCmd cmd = DispatchCmd::make(pso,
        Dim3(dim),                // grid: total threads (dispatch_threads below)
        Dim3(256));               // threadgroup size
    cmd.use_dispatch_threads = true;

    cmd.add_buffer(scratch.h0, 0, 0);  // (buffer, offset, index)

    struct { uint32_t count; float scale; uint32_t _p0, _p1; }
        params = {(uint32_t)dim, scale, 0, 0};
    cmd.set_params(&params, sizeof(params), 1);

    // Optional: patch a field at runtime (e.g., position)
    // cmd.patch_type = DispatchCmd::PATCH_POSITION;
    // cmd.patch_offset_1 = offsetof(decltype(params), some_field);

    table.commands.push_back(cmd);
    table.set_last_label("my_scale");
}

Dispatch table patching. Some parameters change every token (position, KV sequence length, argmax output offset). The dispatch table supports patching specific fields in pre-allocated parameter buffers:

Patch type            What it patches                          Used by
PATCH_POSITION        Position field in RoPE/PE params         RoPE, positional embedding
PATCH_KV_SEQ_LEN      KV sequence length in attention params   Flash attention
PATCH_TOKEN_OFFSET    Token index in embedding params          Embedding lookup
PATCH_ARGMAX_OUTPUT   Output offset in token chain buffer      Argmax

If your kernel needs a patched field, set the patch_type and patch_offset_1 to indicate which field in the parameter struct should be updated by the chain decoder.
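Mechanically, patching amounts to writing a new value at a fixed byte offset inside a parameter buffer before each replay. The sketch below models that with a byte vector standing in for the pre-allocated GPU buffer; `RopeParams` and `patch_u32` are hypothetical names, not Akunu’s.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical params struct for a RoPE-style kernel.
struct RopeParams {
    uint32_t head_dim;
    uint32_t position;   // changes every token
    float    theta;
    uint32_t _pad;
};

// Overwrite a 32-bit field at `offset` in a parameter buffer.
// memcpy avoids the alignment and aliasing issues of reinterpret_cast.
void patch_u32(std::vector<uint8_t>& params_buf, size_t offset, uint32_t value) {
    assert(offset + sizeof(value) <= params_buf.size());
    std::memcpy(params_buf.data() + offset, &value, sizeof(value));
}
```

The offset comes from offsetof(RopeParams, position), which plays the role of patch_offset_1. On a real GPU buffer, the CPU must not rewrite the field while a command buffer that reads it is still in flight.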

Step 6: Write a Test

Test your kernel with known inputs and expected outputs. A typical test pattern:

#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <vector>

void test_vector_scale() {
    Device dev;
    dev.init("akunu.metallib");

    // Allocate and fill input buffer
    const int N = 1024;
    std::vector<uint16_t> input_f16(N);
    for (int i = 0; i < N; i++) {
        __fp16 val = (__fp16)(i * 0.1f);
        memcpy(&input_f16[i], &val, 2);
    }
    Buffer buf = dev.allocate(input_f16.data(), N * 2);

    // Dispatch kernel
    float scale = 2.0f;
    Pipeline pso = dev.get_pipeline("vector_scale_f16");
    dev.begin_encoding();
    dev.set_pipeline(pso);
    dev.set_buffer(buf, 0, 0);
    dev.set_bytes(&scale, sizeof(float), 1);
    uint32_t count = N;
    dev.set_bytes(&count, sizeof(uint32_t), 2);
    dev.dispatch_threads(Dim3(N), Dim3(256));
    dev.end_encoding_sync();

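    // Verify results against the F16-rounded input, not i * 0.1f:
    // near 200, adjacent F16 values are 0.125 apart, so a fixed 0.01 bound
    // against the unrounded value would fail even for a correct kernel.
    const __fp16 *result = (const __fp16 *)buf.contents;
    for (int i = 0; i < N; i++) {
        __fp16 in;
        memcpy(&in, &input_f16[i], 2);
        float expected = (float)in * scale;
        float actual = (float)result[i];
        assert(fabsf(actual - expected) <= 1e-3f * fabsf(expected) + 1e-3f);
    }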

    dev.free_buffer(buf);
    printf("vector_scale_f16: PASSED\n");
}

Key testing tips:

  • Always sync before reading. Call dev.end_encoding_sync() to ensure the GPU has finished before reading buf.contents on CPU. UMA makes the memory accessible, but coherence requires the command buffer to complete.
  • Test boundary conditions. Test with N not divisible by threadgroup size (e.g., N=1000 with TG=256) to verify your bounds check works.
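  • Test numerical accuracy. F16 has limited precision (an 11-bit significand, about 3 decimal digits), and the gap between adjacent values grows with magnitude (0.125 between 128 and 256). Prefer relative tolerances sized for the data type over a fixed absolute bound.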
  • Test with dispatch tables too. If your kernel will run in the dispatch table, test it through the table replay path, not just direct dispatch.
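One way to make "tolerances appropriate for the data type" concrete is to measure error in F16 steps. A sketch of two hypothetical helpers, `f16_spacing` and `close_f16`, assuming finite inputs within F16 range (|x| <= 65504):

```cpp
#include <cassert>
#include <cmath>

// Gap between adjacent F16 values near |x| (the F16 "ulp").
// F16 stores a 10-bit mantissa; normal values have exponents >= -14.
float f16_spacing(float x) {
    x = std::fabs(x);
    if (x < std::ldexp(1.0f, -14)) {
        return std::ldexp(1.0f, -24);        // subnormal range: fixed spacing
    }
    int e = 0;
    std::frexp(x, &e);                         // x = m * 2^e with m in [0.5, 1)
    return std::ldexp(1.0f, (e - 1) - 10);     // x lies in [2^(e-1), 2^e)
}

// True if `actual` is within `ulps` F16 steps of `expected`.
bool close_f16(float actual, float expected, float ulps = 2.0f) {
    return std::fabs(actual - expected) <= ulps * f16_spacing(expected);
}
```

Near 200 the spacing is 0.125, so two steps allow an error of 0.25, while near 1.0 the same two steps allow only about 0.002.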

Complete Walkthrough: Adding a Hypothetical Kernel

Let us trace the complete path for a more realistic example: a fused LayerNorm + linear projection kernel for Whisper. This would combine the FFN layernorm and the up-projection GEMV into a single dispatch.

1. Metal Shader

kernel void layernorm_gemv_f16(
    device const half *input       [[buffer(0)]],
    device const half *norm_weight  [[buffer(1)]],
    device const half *norm_bias    [[buffer(2)]],
    device const half *proj_weight  [[buffer(3)]],
    device const half *proj_bias    [[buffer(4)]],
    device half *output            [[buffer(5)]],
    constant LayerNormGEMVParams &p [[buffer(6)]],  // dim, eps, out_dim; struct declared above the kernel
    uint2 gid  [[threadgroup_position_in_grid]],
    uint  tid  [[thread_position_in_threadgroup]])
{
    // Phase 1: LayerNorm (cooperative across threadgroup)
    // ... compute mean, variance, normalize in shared memory ...

    // Phase 2: GEMV (each threadgroup handles some output rows)
    // ... dot product of normalized input with weight rows ...
}

2. Parameter Struct

struct LayerNormGEMVParams {
    uint32_t dim;       // input dimension
    float    eps;       // norm epsilon
    uint32_t out_dim;   // output dimension (proj rows)
    uint32_t _pad;
};

3. Build

Add the .metal file to the build, recompile metallib.

4. Dispatch Integration

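static void w_emit_layernorm_gemv(DispatchTable& tbl, Device& dev,
    Buffer in, Buffer norm_w, Buffer norm_b,
    Buffer proj_w, Buffer proj_b, Buffer out,
    Buffer scratch,  // dim elements; holds the normalized input on the fallback path
    int dim, float eps, int out_dim)
{
    Pipeline pso = dev.get_pipeline("layernorm_gemv_f16");
    if (!pso.handle) {
        // Fallback: separate layernorm + GEMV. Normalize into scratch, not out:
        // a GEMV that reads and writes the same buffer would race.
        w_emit_layernorm(tbl, dev, in, norm_w, norm_b, scratch, dim, eps);
        w_emit_gemv(tbl, dev, scratch, proj_w, out, 0, out_dim, dim);
        // proj_b is not passed to w_emit_gemv here; apply it on this path
        // too, or the fused and fallback paths will disagree.
        return;
    }
    // ... fused dispatch ...
}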

5. Use in Decoder Table

// In build_whisper_decode_table(), replace:
//   w_emit_layernorm(...)
//   w_emit_gemv_bias_gelu(...)
// With:
//   w_emit_layernorm_gemv(...)
//   w_emit_gelu(...)  // GELU still separate

6. Test

Compare output of the fused kernel against separate layernorm + GEMV on the same input data.
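For that comparison it also helps to have an independent CPU reference. A minimal sketch, assuming row-major [out_dim x dim] projection weights and standard LayerNorm (biased variance, learned scale and shift); `layernorm_gemv_ref` is a hypothetical helper, and Akunu’s actual weight layout may differ. Kernel outputs would be converted from F16 to float before comparing.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// CPU reference: y = W * LayerNorm(x) + b
// x, gamma, beta: [dim]; W: [out_dim x dim] row-major; b: [out_dim].
// Accumulates in double so the reference is more precise than an F16 kernel.
std::vector<float> layernorm_gemv_ref(const std::vector<float>& x,
                                      const std::vector<float>& gamma,
                                      const std::vector<float>& beta,
                                      const std::vector<float>& W,
                                      const std::vector<float>& b,
                                      float eps) {
    const size_t dim = x.size();
    const size_t out_dim = b.size();
    assert(gamma.size() == dim && beta.size() == dim && W.size() == out_dim * dim);

    // Phase 1: LayerNorm with biased variance.
    double mean = 0.0;
    for (float v : x) mean += v;
    mean /= dim;
    double var = 0.0;
    for (float v : x) var += (v - mean) * (v - mean);
    var /= dim;
    const double inv_std = 1.0 / std::sqrt(var + eps);
    std::vector<double> xn(dim);
    for (size_t i = 0; i < dim; ++i)
        xn[i] = (x[i] - mean) * inv_std * gamma[i] + beta[i];

    // Phase 2: GEMV with bias, one dot product per output row.
    std::vector<float> y(out_dim);
    for (size_t r = 0; r < out_dim; ++r) {
        double acc = b[r];
        for (size_t c = 0; c < dim; ++c) acc += W[r * dim + c] * xn[c];
        y[r] = static_cast<float>(acc);
    }
    return y;
}
```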

Performance Considerations

Writing a correct kernel is the first challenge. Making it fast is the second. Here are the performance patterns that matter most on Apple Silicon:

Memory Bandwidth is King

For inference-sized problems (M=1 GEMV), the kernel is almost always memory-bandwidth-bound, not compute-bound. The Apple M2 Pro has ~200 GB/s of memory bandwidth and ~7 TFLOPS of F16 compute. For a 4096x4096 GEMV:

Data to read:  4096 * 4096 * 2 bytes = 33.5 MB (32 MiB of F16 weights)
Compute:       4096 * 4096 * 2 FLOPs = 33.5 MFLOP (one multiply + one add per weight)
Time at bandwidth limit: 33.5 MB / 200 GB/s = ~0.17 ms
Time at compute limit:   33.5 MFLOP / 7 TFLOPS = ~0.0048 ms (about 4.8 microseconds)
Arithmetic intensity:    33.5 MFLOP / 33.5 MB = 1 FLOP/byte

The arithmetic intensity of 1 FLOP/byte is far below the crossover point (~35 FLOP/byte for the M2 Pro: 7 TFLOPS / 200 GB/s). This means your kernel’s speed is determined almost entirely by how efficiently it reads memory. Optimize for coalesced reads, minimize bank conflicts, and avoid reading the same data multiple times.
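The same back-of-envelope arithmetic can be scripted to check other shapes or weight formats. A sketch using the approximate figures quoted above (~200 GB/s, ~7 TFLOPS); `estimate_gemv` is a hypothetical helper that counts only weight traffic and ignores activations and cache effects.

```cpp
#include <cassert>

// Roofline-style estimate for an M=1 GEMV.
struct GemvEstimate {
    double bytes;         // weight bytes read
    double flops;         // 2 FLOPs (multiply + add) per weight
    double t_memory_s;    // time if bandwidth-bound
    double t_compute_s;   // time if compute-bound
    double intensity;     // FLOPs per byte
    bool   memory_bound;  // intensity below the ridge point (peak / bandwidth)
};

GemvEstimate estimate_gemv(double rows, double cols, double bytes_per_weight,
                           double bandwidth_Bps, double peak_flops) {
    assert(bandwidth_Bps > 0 && peak_flops > 0 && bytes_per_weight > 0);
    GemvEstimate e{};
    e.bytes = rows * cols * bytes_per_weight;
    e.flops = 2.0 * rows * cols;
    e.t_memory_s = e.bytes / bandwidth_Bps;
    e.t_compute_s = e.flops / peak_flops;
    e.intensity = e.flops / e.bytes;
    e.memory_bound = e.intensity < peak_flops / bandwidth_Bps;
    return e;
}
```

Passing 0.5 bytes per weight (4-bit weights, ignoring scales) raises the intensity only to 4 FLOP/byte, still well under the 7 TFLOPS / 200 GB/s = 35 FLOP/byte ridge, so a quantized GEMV of this shape is also bandwidth-bound.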

Threadgroup Size Selection

The optimal threadgroup size depends on the kernel’s register and threadgroup memory usage. Some guidelines:

Kernel type            Recommended TG size   Rationale
Simple element-wise    256                   Good occupancy, minimal overhead
GEMV (F16)             128                   16 rows x 8 threads per row
GEMV (quantized)       256                   More ALU work for dequant, need threads
Reduction (norm)       min(dim, 1024)        One TG per row, use all SIMD groups
Attention              32 (1 simdgroup)      SIMD-level matmul, shared memory

Simdgroup Operations

Apple Silicon GPUs have 32-wide SIMD groups. Use simd_sum(), simd_shuffle(), and simdgroup matrix multiply (simdgroup_matrix<half, 8, 8>) when possible. These operations are dramatically faster than equivalent threadgroup-memory-based reductions because they operate within the register file.
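To make the register-level idea concrete, here is a CPU model of a butterfly reduction, a common way to build a SIMD-group sum by hand: at each step every lane adds the value held by the lane whose index differs in one bit, which Metal exposes as simd_shuffle_xor(). This only simulates the data movement and makes no claim about how simd_sum() is implemented in hardware; `butterfly_sum` is a hypothetical name.

```cpp
#include <array>
#include <cassert>

constexpr int kSimdWidth = 32;
using SimdRegs = std::array<float, kSimdWidth>;  // one value per lane

// After log2(32) = 5 xor-shuffle steps, every lane holds the sum of all
// 32 lanes, with no threadgroup memory involved.
SimdRegs butterfly_sum(SimdRegs lanes) {
    for (int mask = kSimdWidth / 2; mask > 0; mask /= 2) {
        SimdRegs shuffled;
        for (int lane = 0; lane < kSimdWidth; ++lane)
            shuffled[lane] = lanes[lane ^ mask];  // read the partner lane's value
        for (int lane = 0; lane < kSimdWidth; ++lane)
            lanes[lane] += shuffled[lane];
    }
    return lanes;
}
```

A threadgroup-memory version of the same reduction needs stores, loads, and a barrier at every step; the shuffle version keeps each partial sum in a register.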

Avoiding Threadgroup Memory Bottlenecks

Threadgroup memory on Apple Silicon is carved from the same SRAM as the L1 cache. Using too much threadgroup memory reduces cache capacity and can hurt performance. For simple kernels, prefer passing data through registers (via SIMD shuffles) over threadgroup memory.

Debugging Metal Kernels

When things go wrong (and they will), here are the tools:

Metal GPU Capture. Xcode’s GPU debugger lets you capture a frame of GPU work and inspect every buffer, every dispatch, and every thread’s execution. This is the single most useful debugging tool for Metal kernels.

Printf from Metal. Metal supports printf in shaders (with performance caveats). Add #include <metal_stdlib> and use printf("tid=%u val=%f\n", tid, float(data[tid])). Output appears in Xcode’s console. Be aware that printf from GPU threads is non-deterministic in ordering and can significantly slow down execution. Use it for debugging specific threads, not for bulk output.

Validation layers. Set MTL_DEBUG_LAYER=1 to enable Metal API validation, which catches host-side misuse such as missing buffer bindings or invalid encoder state. Set MTL_SHADER_VALIDATION=1 to enable shader validation, which instruments kernels to catch out-of-bounds device memory accesses and similar errors inside the shader; it is stricter and has a larger performance impact.

Numerical debugging. When you get wrong numerical results, dump the intermediate buffers to files and compare against a Python reference implementation. Many quantization bugs are off-by-one in bit extraction or wrong byte ordering. A useful technique is to write a Python script that reads the same weight file and computes the expected output for a known input vector.
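When comparing a dumped buffer against the reference, reporting the first bad index is often more useful than a single pass/fail, because the index pattern (every 32nd element, the second half of each block, and so on) hints at which extraction or indexing step is off. A sketch of such a comparison, assuming both dumps have already been converted to float; `first_mismatch` is a hypothetical helper.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Index of the first element where `got` differs from `ref` by more than
// atol + rtol * |ref|, or -1 if every element matches. A NaN in either
// input counts as a mismatch, since any comparison with NaN is false.
long first_mismatch(const std::vector<float>& ref, const std::vector<float>& got,
                    float atol, float rtol) {
    assert(ref.size() == got.size());
    for (size_t i = 0; i < ref.size(); ++i) {
        const float tol = atol + rtol * std::fabs(ref[i]);
        if (!(std::fabs(got[i] - ref[i]) <= tol)) return static_cast<long>(i);
    }
    return -1;
}
```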

Dispatch geometry. A very common bug is getting the grid or threadgroup dimensions wrong. If your kernel silently produces zeros or garbage, check:

  • Is the grid large enough to cover all elements?
  • Is the threadgroup size compatible with the kernel’s requirements?
  • Are you using dispatch() (grid = number of threadgroups) vs dispatch_threads() (grid = total threads)?
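  • If your kernel hard-codes a threadgroup size (threadgroup array lengths, row/lane splits such as 16 rows x 8 threads) or declares [[max_total_threads_per_threadgroup(N)]], does the TG size you are dispatching agree with it?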
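The coverage question and the dispatch() vs dispatch_threads() question can be checked mechanically before a command is recorded. A sketch, where `GridUnit`, `threads_launched`, and `covers` are hypothetical names modeling the two grid conventions described above (dispatch() counts threadgroups, dispatch_threads() counts threads):

```cpp
#include <cassert>
#include <cstdint>

enum class GridUnit { Threadgroups, Threads };  // dispatch() vs dispatch_threads()

// Total threads a 1D dispatch launches, given how its grid is interpreted.
uint64_t threads_launched(uint64_t grid, uint64_t tg_size, GridUnit unit) {
    assert(tg_size > 0);
    return unit == GridUnit::Threadgroups ? grid * tg_size : grid;
}

// True if a 1D dispatch covers all `count` elements.
bool covers(uint64_t count, uint64_t grid, uint64_t tg_size, GridUnit unit) {
    return threads_launched(grid, tg_size, unit) >= count;
}
```

The classic mix-up is computing a threadgroup count, (dim + 255) / 256, and then handing it to a threads-based dispatch: for dim = 4096 that launches 16 threads instead of 4096, and the kernel silently processes only the first 16 elements.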

Common pitfalls checklist:

Symptom                                  Likely cause
All zeros output                         Wrong buffer binding index, or grid too small
NaN/Inf output                           Division by zero in normalization, or uninitialized buffer
Correct for small N, wrong for large N   Overflow in index calculation (use uint64_t for offsets)
Non-deterministic results                Race condition in threadgroup memory, or missing barrier
Slightly wrong values                    F16 precision loss, or wrong quantization offset (e.g., -8 vs -16)
Kernel not found                         Typo in kernel name, or .metal file not in metallib build

Summary

Adding a kernel to Akunu follows a predictable path:

.metal file -> metallib build -> Device::get_pipeline()
                                      |
                              +-------+-------+
                              |               |
                     Direct dispatch    DispatchCmd in table
                     (encoder, init)    (per-token decode)
                              |               |
                              +-------+-------+
                                      |
                                   Test

The data-driven design means the kernel itself is self-contained – it does not need to know about architectures, quantization formats, or model configs. Those concerns are handled by the dispatch table builder and the dtype descriptor table. Your kernel just needs to do one thing correctly: take input buffers, apply a computation, and write output buffers.