Adding a New Metal Kernel
This chapter walks through adding a new Metal kernel to Akunu from scratch. We will build a complete example – a hypothetical vector_scale_f16 kernel that multiplies every element of an F16 buffer by a scalar – and trace every step from the .metal source file to the dispatch table integration and test.
The process has six steps:
1. Write the .metal shader
2. Define parameter structs (shared between CPU and GPU)
3. Add to the metallib build
4. Dispatch from Device
5. Integrate into table_builder (if it is part of inference)
6. Write a test
Let us go through each one.
Step 1: Write the .metal Shader
Metal shader files live in Akunu’s shader directory. Create a new file or add to an existing one. For our example, we will create the kernel inline:
// In an existing .metal file, or a new one
#include <metal_stdlib>
using namespace metal;
/// Multiply every element of an F16 buffer by a scalar.
///
/// Grid: 1D, total threads >= count
/// TG: 256 threads (or whatever fits)
///
/// Buffers:
/// [0] device half *data -- input/output (in-place)
///
/// Params:
/// [1] constant float &scale
/// [2] constant uint &count
kernel void vector_scale_f16(
    device half *data     [[buffer(0)]],
    constant float &scale [[buffer(1)]],
    constant uint &count  [[buffer(2)]],
    uint tid [[thread_position_in_grid]])
{
    if (tid >= count) return;
    data[tid] = half(float(data[tid]) * scale);
}
Key points for Metal kernel design in Akunu:
Buffer binding convention. Akunu binds buffers by index using device.set_buffer(buf, offset, index) or cmd.add_buffer(buf, offset, index). The index in the Metal signature ([[buffer(N)]]) must match. By convention, input buffers come first, output buffers next, and parameters last.
Parameter passing. Small parameters (< 4KB) are passed via setBytes (which Akunu wraps as device.set_bytes() or cmd.set_params()). For dispatch tables, params are pre-allocated as GPU buffers for zero-copy patching.
Thread indexing. Use [[thread_position_in_grid]] for 1D kernels, or [[threadgroup_position_in_grid]] + [[thread_position_in_threadgroup]] for kernels that need threadgroup-level coordination. Akunu supports both dispatch() (grid in threadgroups) and dispatch_threads() (grid in threads).
Bounds checking. Always check tid >= count at the top. Akunu’s dispatch grid is rounded up to the threadgroup size, so trailing threads may be out of bounds.
Data types. Use half for F16, bfloat for BF16 (M4+ only via metal_bfloat header), float for F32, uint for U32. For quantized formats, you typically read uint or uchar and manually unpack.
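Before writing the Metal-side unpack for a quantized format, it can help to prototype the bit extraction on the CPU where it is easy to test. The sketch below assumes a simple hypothetical packing -- two 4-bit values per byte, low nibble first, signed offset of -8 -- which you would replace with the actual layout of the format you are targeting:

```cpp
#include <cstdint>
#include <vector>

// CPU reference for a GPU-side unpack pattern: two 4-bit values per
// byte, low nibble first, signed offset of -8. This is a hypothetical
// layout -- match it to the real quantized format before relying on it.
inline int8_t unpack_nibble(uint8_t byte, int which) {
    uint8_t nibble = (which == 0) ? (byte & 0x0F) : (byte >> 4);
    return (int8_t)(nibble) - 8;
}

std::vector<int8_t> unpack_q4(const std::vector<uint8_t>& packed) {
    std::vector<int8_t> out;
    out.reserve(packed.size() * 2);
    for (uint8_t b : packed) {
        out.push_back(unpack_nibble(b, 0));  // low nibble first
        out.push_back(unpack_nibble(b, 1));
    }
    return out;
}
```

The same function, translated to Metal, becomes the inner loop of the quantized kernel; having a CPU twin gives you something to diff against when the GPU output looks wrong.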
Step 2: Define Parameter Structs
If your kernel takes structured parameters (more than a single scalar), define a shared struct. Akunu’s convention is to use C structs that are layout-compatible between CPU and GPU. Metal’s constant buffer binding expects 16-byte aligned access for best performance, so pad structs to 16-byte boundaries:
// Shared between CPU (.cpp) and GPU (.metal)
struct VectorScaleParams {
    uint32_t count;
    float scale;
    uint32_t _pad0; // pad to 16 bytes
    uint32_t _pad1;
};
For simple cases like our example, you can skip the struct and pass individual scalars. But for anything with more than 2-3 parameters, a struct keeps things clean and avoids off-by-one buffer index mistakes.
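Because a layout mismatch between the CPU struct and the Metal constant buffer fails silently, it is worth guarding the layout at compile time. A minimal sketch (the example struct is repeated here so the checks are self-contained):

```cpp
#include <cstddef>
#include <cstdint>

// Compile-time guard that the CPU-side struct matches the layout the
// Metal kernel will read: 16-byte total size, fields at known offsets.
struct VectorScaleParams {
    uint32_t count;
    float scale;
    uint32_t _pad0; // pad to 16 bytes
    uint32_t _pad1;
};

static_assert(sizeof(VectorScaleParams) == 16, "params must be 16 bytes");
static_assert(offsetof(VectorScaleParams, count) == 0, "count at offset 0");
static_assert(offsetof(VectorScaleParams, scale) == 4, "scale at offset 4");
```

If someone later reorders fields or adds one without adjusting the padding, the build fails instead of the kernel reading garbage.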
The parameter structs used throughout Akunu follow a consistent pattern:
// GEMM parameters (used by all GEMV/GEMM kernels)
struct {
    uint32_t M, N, K, lda, ldb, ldc;
    float alpha, beta;
} params; // 32 bytes, naturally aligned

// MLX quantized GEMV parameters
struct {
    uint32_t M, N, K, group_size, bits, weight_bytes, _p0, _p1;
} mlx_params; // 32 bytes

// Norm parameters
struct {
    uint32_t dim;
    float eps;
    uint32_t _p0, _p1;
} norm_params; // 16 bytes
Step 3: Add to the Metallib Build
Akunu compiles all .metal files into a single .metallib using xcrun metal and xcrun metallib. The build process is:
*.metal files
      |
      v
xcrun metal -c -target air64-apple-macos14.0 -o file.air file.metal
      |        (repeated for each .metal file)
      v
xcrun metallib -o akunu.metallib *.air
If you created a new .metal file, add it to the build script or Makefile. If you added to an existing file, it will be picked up automatically.
To verify your kernel compiled successfully:
# List all functions in the metallib
xcrun metal-objdump --disassemble-all akunu.metallib 2>&1 | grep "^_"
You should see _vector_scale_f16 in the output.
Step 4: Dispatch from Device
For one-off or encoder-phase dispatches (not in the per-token dispatch table), you call the kernel directly through Device:
void apply_scale(Device& dev, Buffer data, float scale, int count) {
    // Get (or create and cache) the pipeline state object
    Pipeline pso = dev.get_pipeline("vector_scale_f16");

    // Set up the dispatch
    dev.set_pipeline(pso);
    dev.set_buffer(data, 0, 0);              // buffer index 0
    dev.set_bytes(&scale, sizeof(float), 1); // buffer index 1
    uint32_t n = (uint32_t)count;
    dev.set_bytes(&n, sizeof(uint32_t), 2);  // buffer index 2

    // Dispatch: total threads = count, threadgroup size = 256
    dev.dispatch_threads(Dim3(count), Dim3(256));
}
The get_pipeline() call compiles the kernel function from the metallib on first use and caches the resulting MTLComputePipelineState. Subsequent calls with the same kernel name return the cached pipeline instantly.
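The caching behavior can be sketched with an ordinary hash map. The types below are illustrative stand-ins, not Akunu's actual classes (the real cached value is an MTLComputePipelineState):

```cpp
#include <string>
#include <unordered_map>

// Minimal sketch of name-keyed pipeline caching. `Pipeline` here is a
// placeholder for the real pipeline-state handle.
struct Pipeline { int id = 0; };

class PipelineCache {
    std::unordered_map<std::string, Pipeline> cache_;
    int next_id_ = 1;
public:
    Pipeline get(const std::string& kernel_name) {
        auto it = cache_.find(kernel_name);
        if (it != cache_.end()) return it->second;  // hit: no recompile
        Pipeline pso{next_id_++};  // stands in for metallib lookup + compile
        cache_.emplace(kernel_name, pso);
        return pso;
    }
};
```

The important property is that pipeline creation (the expensive part) happens once per kernel name; everything after the first call is a hash lookup.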
Function constants. If your kernel needs compile-time specialization (e.g., head dimension, group size), use function constants:
// In Metal:
constant uint HEAD_DIM [[function_constant(0)]];

// In C++:
uint32_t fc_indices[] = {0};
uint32_t fc_values[] = {128}; // HEAD_DIM = 128
Pipeline pso = dev.get_pipeline("my_kernel",
                                "my_kernel_hd128", // cache key
                                fc_indices, fc_values, 1);
Function constants create specialized pipeline variants at runtime. Akunu caches each variant by its cache key string.
Step 5: Integrate into the Dispatch Table
If your kernel runs as part of the per-token forward pass, it needs to be added to the dispatch table built by build_dispatch_table() in table_builder.cpp. The dispatch table is built once during model initialization and then replayed for every token.
Using the CmdBuilder helper:
// Inside build_dispatch_table(), at the appropriate point
// in the layer loop:
{
    Pipeline pso = device.get_pipeline("vector_scale_f16");

    struct {
        uint32_t count;
        float scale;
        uint32_t _p0, _p1;
    } params = {(uint32_t)dim, some_scale_value, 0, 0};

    CmdBuilder(table, pso, Dim3((dim + 255) / 256), Dim3(256))
        .threads()              // use dispatch_threads
        .buf(scratch.h0, 0)     // buffer binding 0
        .params(params, 1)      // params at buffer index 1
        .label("layer.%d.scale", layer)
        .emit();
}
Or using the raw DispatchCmd API:
{
    Pipeline pso = device.get_pipeline("vector_scale_f16");

    DispatchCmd cmd = DispatchCmd::make(pso,
        Dim3((dim + 255) / 256), // grid (in threadgroups)
        Dim3(256));              // threadgroup size
    cmd.use_dispatch_threads = true;
    cmd.add_buffer(scratch.h0, 0, 0); // (buffer, offset, index)

    struct { uint32_t count; float scale; uint32_t _p0, _p1; }
        params = {(uint32_t)dim, scale, 0, 0};
    cmd.set_params(&params, sizeof(params), 1);

    // Optional: patch a field at runtime (e.g., position)
    // cmd.patch_type = DispatchCmd::PATCH_POSITION;
    // cmd.patch_offset_1 = offsetof(decltype(params), some_field);

    table.commands.push_back(cmd);
    table.set_last_label("my_scale");
}
Dispatch table patching. Some parameters change every token (position, KV sequence length, argmax output offset). The dispatch table supports patching specific fields in pre-allocated parameter buffers:
| Patch type | What it patches | Used by |
|---|---|---|
| PATCH_POSITION | Position field in RoPE/PE params | RoPE, positional embedding |
| PATCH_KV_SEQ_LEN | KV sequence length in attention params | Flash attention |
| PATCH_TOKEN_OFFSET | Token index in embedding params | Embedding lookup |
| PATCH_ARGMAX_OUTPUT | Output offset in token chain buffer | Argmax |
If your kernel needs a patched field, set the patch_type and patch_offset_1 to indicate which field in the parameter struct should be updated by the chain decoder.
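The mechanism boils down to overwriting one field inside a pre-allocated parameter buffer at a known byte offset. A minimal sketch, with hypothetical struct and field names standing in for the real ones:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Illustrative parameter struct with one per-token field. In the table,
// the params live in a pre-allocated GPU buffer and the chain decoder
// overwrites `position` in place before each replay.
struct RopeParams {
    uint32_t dim;
    uint32_t position; // changes every token
    uint32_t _p0, _p1;
};

// What the chain decoder effectively does with patch_offset_1:
void patch_u32(uint8_t* param_buffer, size_t field_offset, uint32_t value) {
    std::memcpy(param_buffer + field_offset, &value, sizeof(value));
}
```

Recording `offsetof(RopeParams, position)` as the patch offset at build time, then memcpy-ing into the live buffer each token, is what makes the table reusable without re-encoding any dispatches.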
Step 6: Write a Test
Test your kernel with known inputs and expected outputs. A typical test pattern:
void test_vector_scale() {
    Device dev;
    dev.init("akunu.metallib");

    // Allocate and fill input buffer
    const int N = 1024;
    std::vector<uint16_t> input_f16(N);
    for (int i = 0; i < N; i++) {
        __fp16 val = (__fp16)(i * 0.1f);
        memcpy(&input_f16[i], &val, 2);
    }
    Buffer buf = dev.allocate(input_f16.data(), N * 2);

    // Dispatch kernel
    float scale = 2.0f;
    Pipeline pso = dev.get_pipeline("vector_scale_f16");
    dev.begin_encoding();
    dev.set_pipeline(pso);
    dev.set_buffer(buf, 0, 0);
    dev.set_bytes(&scale, sizeof(float), 1);
    uint32_t count = N;
    dev.set_bytes(&count, sizeof(uint32_t), 2);
    dev.dispatch_threads(Dim3(N), Dim3(256));
    dev.end_encoding_sync();

    // Verify results. Use a relative tolerance: F16 rounding error
    // grows with magnitude, so a fixed absolute bound like 0.01 would
    // spuriously fail once values exceed ~20.
    const __fp16 *result = (const __fp16 *)buf.contents;
    for (int i = 0; i < N; i++) {
        float expected = i * 0.1f * 2.0f;
        float actual = (float)result[i];
        assert(fabsf(actual - expected) <= 0.01f * fmaxf(1.0f, expected));
    }

    dev.free_buffer(buf);
    printf("vector_scale_f16: PASSED\n");
}
Key testing tips:
- Always sync before reading. Call dev.end_encoding_sync() to ensure the GPU has finished before reading buf.contents on the CPU. UMA makes the memory accessible, but coherence requires the command buffer to complete.
- Test boundary conditions. Test with N not divisible by the threadgroup size (e.g., N=1000 with TG=256) to verify your bounds check works.
- Test numerical accuracy. F16 has limited precision (about 3 decimal digits). Use tolerances appropriate for the data type.
- Test with dispatch tables too. If your kernel will run in the dispatch table, test it through the table replay path, not just direct dispatch.
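For the accuracy tip, a small comparison helper with an explicit tolerance keeps tests honest about what F16 can and cannot represent. A sketch:

```cpp
#include <cmath>
#include <cstddef>

// Element-wise comparison with an explicit absolute tolerance. For F16
// outputs, atol should scale with the magnitude of the data: the format
// has a ~10-bit mantissa, so values near 100 can legitimately differ
// from an F32 reference by ~0.05.
bool all_close(const float* actual, const float* expected, size_t n,
               float atol) {
    for (size_t i = 0; i < n; i++) {
        if (std::fabs(actual[i] - expected[i]) > atol) return false;
    }
    return true;
}
```

Using a helper like this (instead of inline asserts) also makes it easy to report the worst-offending index when a test fails.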
Complete Walkthrough: Adding a Hypothetical Kernel
Let us trace the complete path for a more realistic example: a fused LayerNorm + linear projection kernel for Whisper. This would combine the FFN layernorm and the up-projection GEMV into a single dispatch.
1. Metal Shader
kernel void layernorm_gemv_f16(
    device const half *input       [[buffer(0)]],
    device const half *norm_weight [[buffer(1)]],
    device const half *norm_bias   [[buffer(2)]],
    device const half *proj_weight [[buffer(3)]],
    device const half *proj_bias   [[buffer(4)]],
    device half *output            [[buffer(5)]],
    constant uint &dim             [[buffer(6)]],
    constant float &eps            [[buffer(7)]],
    constant uint &out_dim         [[buffer(8)]],
    uint2 gid [[threadgroup_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]])
{
    // Phase 1: LayerNorm (cooperative across threadgroup)
    // ... compute mean, variance, normalize in shared memory ...

    // Phase 2: GEMV (each threadgroup handles some output rows)
    // ... dot product of normalized input with weight rows ...
}
2. Parameter Struct
struct LayerNormGEMVParams {
    uint32_t dim;     // input dimension
    float eps;        // norm epsilon
    uint32_t out_dim; // output dimension (proj rows)
    uint32_t _pad;
};
3. Build
Add the .metal file to the build, recompile metallib.
4. Dispatch Integration
static void w_emit_layernorm_gemv(DispatchTable& tbl, Device& dev,
                                  Buffer in, Buffer norm_w, Buffer norm_b,
                                  Buffer proj_w, Buffer proj_b, Buffer out,
                                  int dim, float eps, int out_dim)
{
    Pipeline pso = dev.get_pipeline("layernorm_gemv_f16");
    if (!pso.handle) {
        // Fallback: separate layernorm + GEMV
        w_emit_layernorm(tbl, dev, in, norm_w, norm_b, out, dim, eps);
        w_emit_gemv(tbl, dev, out, proj_w, out, 0, out_dim, dim);
        return;
    }
    // ... fused dispatch ...
}
5. Use in Decoder Table
// In build_whisper_decode_table(), replace:
// w_emit_layernorm(...)
// w_emit_gemv_bias_gelu(...)
// With:
// w_emit_layernorm_gemv(...)
// w_emit_gelu(...) // GELU still separate
6. Test
Compare output of the fused kernel against separate layernorm + GEMV on the same input data.
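A CPU reference for that comparison might look like the following sketch (float math throughout, row-major [out_dim x dim] weights, no projection bias; the function name and signature are illustrative):

```cpp
#include <cmath>
#include <vector>

// CPU reference for the fused kernel: LayerNorm over `dim` elements,
// then a matrix-vector product with row-major weights W[out_dim][dim].
std::vector<float> layernorm_gemv_ref(const std::vector<float>& x,
                                      const std::vector<float>& nw,
                                      const std::vector<float>& nb,
                                      const std::vector<float>& W,
                                      int dim, int out_dim, float eps) {
    // LayerNorm: mean, variance, normalize, scale + shift
    float mean = 0.f;
    for (float v : x) mean += v;
    mean /= dim;
    float var = 0.f;
    for (float v : x) var += (v - mean) * (v - mean);
    var /= dim;
    float inv = 1.f / std::sqrt(var + eps);
    std::vector<float> normed(dim);
    for (int i = 0; i < dim; i++)
        normed[i] = (x[i] - mean) * inv * nw[i] + nb[i];

    // GEMV: one dot product per output row
    std::vector<float> out(out_dim, 0.f);
    for (int r = 0; r < out_dim; r++)
        for (int c = 0; c < dim; c++)
            out[r] += W[r * dim + c] * normed[c];
    return out;
}
```

Run the fused kernel and this reference on the same random input, then compare with an F16-appropriate tolerance; disagreement localizes the bug to either the norm phase or the GEMV phase.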
Performance Considerations
Writing a correct kernel is the first challenge. Making it fast is the second. Here are the performance patterns that matter most on Apple Silicon:
Memory Bandwidth is King
For inference-sized problems (M=1 GEMV), the kernel is almost always memory-bandwidth-bound, not compute-bound. The Apple M2 Pro has ~200 GB/s of memory bandwidth and ~7 TFLOPS of F16 compute. For a 4096x4096 GEMV:
Data to read: 4096 * 4096 * 2 bytes = 32 MB (F16 weights)
Compute: 4096 * 4096 * 2 = 33.5 MFLOPs (one multiply-add per weight)
Time at bandwidth limit: 32 MB / 200 GB/s = 0.16 ms
Time at compute limit: 33.5 MFLOPs / 7 TFLOPS = 0.005 ms
Arithmetic intensity: 33.5M / 32M = ~1 FLOP/byte
The arithmetic intensity of 1 FLOP/byte is far below the crossover point (7 TFLOPS / 200 GB/s = ~35 FLOP/byte for the M2 Pro). This means your kernel’s speed is determined almost entirely by how efficiently it reads memory. Optimize for coalesced reads, minimize bank conflicts, and avoid reading the same data multiple times.
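The back-of-envelope arithmetic above can be wrapped in a small helper so the same roofline estimate works for other problem sizes and quantization widths; the bandwidth and compute figures used in the test are the approximate M2 Pro numbers quoted in the text:

```cpp
// Roofline estimate for an M=1 GEMV over an N x K weight matrix.
// Illustrative helper; plug in your device's bandwidth and peak FLOPS.
struct Roofline {
    double bw_time_ms;      // time if purely bandwidth-bound
    double compute_time_ms; // time if purely compute-bound
    double intensity;       // FLOPs per byte read
};

Roofline gemv_roofline(double n, double k, double bytes_per_weight,
                       double gbps, double tflops) {
    double bytes = n * k * bytes_per_weight;
    double flops = n * k * 2.0;              // multiply + add per weight
    return {
        bytes / (gbps * 1e9) * 1e3,          // bandwidth-limit time (ms)
        flops / (tflops * 1e12) * 1e3,       // compute-limit time (ms)
        flops / bytes,                        // arithmetic intensity
    };
}
```

Whichever limit time is larger is the floor for that dispatch; for inference-sized GEMVs the bandwidth time dominates by orders of magnitude, which is also why 4-bit quantization (bytes_per_weight = 0.5) translates almost directly into a 4x speedup over F16.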
Threadgroup Size Selection
The optimal threadgroup size depends on the kernel’s register and threadgroup memory usage. Some guidelines:
| Kernel type | Recommended TG size | Rationale |
|---|---|---|
| Simple element-wise | 256 | Good occupancy, minimal overhead |
| GEMV (F16) | 128 | 16 rows x 8 threads per row |
| GEMV (quantized) | 256 | More ALU work for dequant, need threads |
| Reduction (norm) | min(dim, 1024) | One TG per row, use all SIMD groups |
| Attention | 32 (1 simdgroup) | SIMD-level matmul, shared memory |
Simdgroup Operations
Apple Silicon GPUs have 32-wide SIMD groups. Use simd_sum(), simd_shuffle(), and simdgroup matrix multiply (simdgroup_matrix<half, 8, 8>) when possible. These operations are dramatically faster than equivalent threadgroup-memory-based reductions because they operate within the register file.
Avoiding Threadgroup Memory Bottlenecks
Threadgroup memory on Apple Silicon is carved from the same SRAM as the L1 cache. Using too much threadgroup memory reduces cache capacity and can hurt performance. For simple kernels, prefer passing data through registers (via SIMD shuffles) over threadgroup memory.
Debugging Metal Kernels
When things go wrong (and they will), here are the tools:
Metal GPU Capture. Xcode’s GPU debugger lets you capture a frame of GPU work and inspect every buffer, every dispatch, and every thread’s execution. This is the single most useful debugging tool for Metal kernels.
Printf from Metal. Metal supports printf in shaders (with performance caveats). Add #include <metal_stdlib> and use printf("tid=%u val=%f\n", tid, float(data[tid])). Output appears in Xcode’s console. Be aware that printf from GPU threads is non-deterministic in ordering and can significantly slow down execution. Use it for debugging specific threads, not for bulk output.
Validation layers. Enable Metal validation (MTL_DEBUG_LAYER=1 environment variable) to catch out-of-bounds buffer access, uninitialized reads, and other errors. You can also set MTL_SHADER_VALIDATION=1 for even stricter checks, though this has a larger performance impact.
Numerical debugging. When you get wrong numerical results, dump the intermediate buffers to files and compare against a Python reference implementation. Many quantization bugs are off-by-one in bit extraction or wrong byte ordering. A useful technique is to write a Python script that reads the same weight file and computes the expected output for a known input vector.
Dispatch geometry. A very common bug is getting the grid or threadgroup dimensions wrong. If your kernel silently produces zeros or garbage, check:
- Is the grid large enough to cover all elements?
- Is the threadgroup size compatible with the kernel’s requirements?
- Are you using dispatch() (grid = number of threadgroups) vs dispatch_threads() (grid = total threads)?
- Does your kernel’s [[threads_per_threadgroup]] attribute match the TG size you are dispatching?
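The first two checks reduce to a couple of lines of ceil-division arithmetic (a sketch; the helper names are illustrative):

```cpp
#include <cstdint>

// Grid for dispatch() is measured in threadgroups; for
// dispatch_threads() it is measured in threads. Either way, the
// rounded-up total thread count must cover every element.
uint32_t threadgroups_for(uint32_t count, uint32_t tg_size) {
    return (count + tg_size - 1) / tg_size; // ceil(count / tg_size)
}

bool grid_covers(uint32_t count, uint32_t tg_size) {
    return threadgroups_for(count, tg_size) * tg_size >= count;
}
```

Mixing up the two units is the classic failure: passing a thread count where a threadgroup count is expected launches 256x too much work, and the reverse launches 1/256th of it, which looks like a kernel that only writes the first few elements.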
Common pitfalls checklist:
| Symptom | Likely cause |
|---|---|
| All zeros output | Wrong buffer binding index, or grid too small |
| NaN/Inf output | Division by zero in normalization, or uninitialized buffer |
| Correct for small N, wrong for large N | Overflow in index calculation (use uint64_t for offsets) |
| Non-deterministic results | Race condition in threadgroup memory, or missing barrier |
| Slightly wrong values | F16 precision loss, or wrong quantization offset (e.g., -8 vs -16) |
| Kernel not found | Typo in kernel name, or .metal file not in metallib build |
Summary
Adding a kernel to Akunu follows a predictable path:
.metal file -> metallib build -> Device::get_pipeline()
                      |
              +-------+-------+
              |               |
      Direct dispatch   DispatchCmd in table
      (encoder, init)   (per-token decode)
              |               |
              +-------+-------+
                      |
                    Test
The data-driven design means the kernel itself is self-contained – it does not need to know about architectures, quantization formats, or model configs. Those concerns are handled by the dispatch table builder and the dtype descriptor table. Your kernel just needs to do one thing correctly: take input buffers, apply a computation, and write output buffers.