
Compute Pipelines and Shaders

If you have ever worked with graphics APIs, the word “pipeline” probably conjures images of vertex shaders, rasterizers, fragment shaders, and blend states all wired together into a monolithic object. Compute pipelines are refreshingly simpler. A compute pipeline is, at its core, just one thing: a compiled GPU function plus the metadata the driver needs to launch it. That is it. No rasterizer, no depth test, no blend modes. One function, ready to run.

In this chapter we are going to trace the full lifecycle of a compute pipeline on Metal – from writing a kernel function in the Metal Shading Language, through the compilation stages that turn it into executable GPU code, all the way to encoding a dispatch command that actually runs the thing. Along the way we will see how akunu wraps this machinery to keep things fast and ergonomic from Rust.

What Is a Compute Pipeline State Object?

An MTLComputePipelineState (PSO) is an opaque, immutable object that represents a fully compiled and optimized GPU kernel. Once you have one, you can use it over and over again to dispatch work – the driver does not need to recompile anything. Creating a PSO is expensive (potentially tens of milliseconds), but using one is cheap (microseconds).

Think of it like compiling a C++ program. The compilation step is slow, but once you have the binary, running it is fast. You would never recompile the same source file every time you want to run the program. Same idea here.

+------------------+       +------------------+       +-------------------+
|  MSL Source Code | ----> |  MTLLibrary      | ----> | MTLFunction       |
|  (kernel void    |       |  (collection of  |       | (one specific     |
|   my_kernel...)  |       |   compiled       |       |  kernel function) |
+------------------+       |   functions)     |       +-------------------+
                           +------------------+              |
                                                             v
                                                    +-------------------+
                                                    | MTLComputePipeline|
                                                    | State (PSO)       |
                                                    | (ready to dispatch|
                                                    |  on GPU)          |
                                                    +-------------------+

Let us walk through each step.

The Compilation Flow: MSL to PSO

Metal has a multi-stage compilation pipeline. Understanding these stages matters because each one is a knob you can turn for performance.

Stage 1: MSL Source to AIR (Metal Intermediate Representation)

The Metal Shading Language is a dialect of C++14 with GPU-specific extensions (we will cover MSL in depth in the next chapter). When you compile MSL, the first output is AIR – Apple’s Intermediate Representation. AIR is analogous to LLVM IR (and in fact the Metal compiler is built on LLVM). It is a platform-independent representation of your shader.

+-----------+     metal compiler      +----------+
| .metal    |  ------------------->   | .air     |
| (MSL src) |    (clang frontend)     | (LLVM IR |
+-----------+                         |  variant)|
                                      +----------+

Stage 2: AIR to metallib (Metal Library)

Multiple AIR files are linked together into a .metallib file. This is a binary archive that contains all your compiled functions. You can think of it as the GPU equivalent of a .a static library or a .dylib.

+----------+
| .air     |--+
+----------+  |    metallib tool      +------------+
              +------------------->   | .metallib  |
+----------+  |                       | (binary    |
| .air     |--+                       |  archive)  |
+----------+                          +------------+

Stage 3: metallib to PSO (at runtime)

At runtime, you load the .metallib into an MTLLibrary, look up a function by name, and then create a pipeline state. This final step performs GPU-specific optimizations – register allocation, instruction scheduling, occupancy tuning – that depend on the specific GPU model the code will run on.

+------------+    MTLDevice         +--------------+    MTLDevice
| .metallib  | ------------------>  | MTLFunction  | ----------------->
| (binary)   |  newLibrary(data:)   | (looked up   |  newComputePipeline
+------------+                      |  by name)    |  State(function:)
                                    +--------------+
                                                        +----------+
                                                        | PSO      |
                                                        | (GPU-    |
                                                        |  ready)  |
                                                        +----------+

Putting It All Together

Here is the full pipeline:

MSL Source (.metal)
       |
       | metal compiler (offline or runtime)
       v
AIR (.air)
       |
       | metallib linker
       v
Metal Library (.metallib)
       |
       | Load at runtime: device.newLibrary(data:)
       v
MTLLibrary (in-memory)
       |
       | library.newFunction(name: "my_kernel")
       v
MTLFunction
       |
       | device.newComputePipelineState(function:)
       v
MTLComputePipelineState (PSO)  <-- this is what you dispatch with

Runtime vs Offline Compilation

You have two choices for when compilation happens.

Runtime Compilation (from source)

You embed MSL source code as a string and compile it at runtime:

// Runtime compilation from source string
let source = """
kernel void add_arrays(device float* a [[buffer(0)]],
                       device float* b [[buffer(1)]],
                       device float* c [[buffer(2)]],
                       uint id [[thread_position_in_grid]]) {
    c[id] = a[id] + b[id];
}
"""
let library = try device.makeLibrary(source: source, options: nil)
let function = library.makeFunction(name: "add_arrays")!
let pso = try device.makeComputePipelineState(function: function)

This is convenient for development and for cases where you need to generate shader code dynamically. The downside is that compilation happens at application startup, which can take tens to hundreds of milliseconds per kernel.

Offline Compilation (precompiled metallib)

You compile your shaders ahead of time using Apple’s command-line tools:

# Step 1: Compile MSL to AIR
xcrun -sdk macosx metal -c my_kernels.metal -o my_kernels.air

# Step 2: Link AIR into a metallib
xcrun -sdk macosx metallib my_kernels.air -o my_kernels.metallib

# Or do it in one step:
xcrun -sdk macosx metal my_kernels.metal -o my_kernels.metallib

Then at runtime, you just load the precompiled binary:

let libraryURL = Bundle.main.url(forResource: "my_kernels",
                                  withExtension: "metallib")!
let library = try device.makeLibrary(URL: libraryURL)

This is what production applications should do. The runtime cost drops from “compile everything” to “just load a binary and create PSOs.” Akunu takes the offline compilation approach – its Metal kernels are precompiled into metallib files that ship with the library.

When Would You Use Runtime Compilation?

Runtime compilation is not just for lazy developers. There are legitimate reasons to compile at runtime:

  1. Generated kernels: When the kernel code depends on runtime parameters like tensor shapes, data types, or hardware capabilities.
  2. JIT specialization: Generating code paths optimized for specific input sizes that are only known at inference time.
  3. Plugin systems: Loading user-provided shader code.

That said, for a production ML inference engine, offline compilation is almost always the right default.

Function Specialization Constants

Here is where things get interesting for ML workloads. Metal supports function constants – compile-time constants that let you create specialized variants of the same kernel without duplicating source code.

Consider a GEMV (matrix-vector multiply) kernel. The optimal implementation depends on the matrix dimensions:

// Define function constants
constant uint K [[function_constant(0)]];
constant bool USE_BIAS [[function_constant(1)]];

kernel void gemv(device const half* W [[buffer(0)]],
                 device const half* x [[buffer(1)]],
                 device half* y [[buffer(2)]],
                 device const half* bias [[buffer(3)]],
                 uint row [[thread_position_in_grid]]) {
    
    half sum = 0;
    // The compiler knows K at compile time -- it can fully
    // unroll this loop, choose optimal vector widths, etc.
    for (uint i = 0; i < K; i += 4) {
        half4 w = *((device const half4*)(W + row * K + i));
        half4 v = *((device const half4*)(x + i));
        sum += dot(w, v);
    }
    
    // Dead code elimination removes this branch entirely
    // when USE_BIAS is false
    if (USE_BIAS) {
        sum += bias[row];
    }
    
    y[row] = sum;
}

At PSO creation time, you provide the constant values:

let constants = MTLFunctionConstantValues()

var k: UInt32 = 128
constants.setConstantValue(&k, type: .uint, index: 0)

var useBias: Bool = true
constants.setConstantValue(&useBias, type: .bool, index: 1)

let function = try library.makeFunction(name: "gemv",
                                         constantValues: constants)
let pso = try device.makeComputePipelineState(function: function)

The compiler treats these constants exactly like #define values – it can unroll loops, eliminate dead branches, choose optimal vector sizes, and perform constant propagation. The result is a specialized binary for this specific combination of constant values.
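If you write Rust, const generics are a loose CPU-side analogy (an analogy only, not how Metal works internally): from one generic source, the compiler monomorphizes a separate, fully optimizable body per constant value.

```rust
// One generic source function; the compiler emits a specialized
// body for each K it is instantiated with, so the loop bound is a
// compile-time constant it can unroll -- much like a Metal
// function constant.
fn dot_unrolled<const K: usize>(a: &[f32; K], b: &[f32; K]) -> f32 {
    let mut sum = 0.0;
    for i in 0..K {
        sum += a[i] * b[i];
    }
    sum
}
```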

Why does this matter for ML? Because ML models have fixed dimensions. A LLaMA-7B model always has a hidden dimension of 4096. An attention head is always 128-dimensional. Once you know the model architecture, you know exactly what constants to specialize on, and the compiler can produce code that is perfectly tuned for those dimensions.

Pipeline Caching Strategies

Creating a PSO involves GPU-specific compilation. On an M1 Max, creating a single pipeline takes somewhere between 5 and 50 ms depending on kernel complexity. If you have 50 kernels, that is potentially seconds of startup time. And if you use function constants, each unique combination of constants is a separate PSO.

The solution is caching. Metal provides MTLBinaryArchive for persisting compiled pipelines to disk, but the more common approach in practice is to cache PSOs in memory and be smart about when you create them.

How Akunu Caches PSOs

Akunu maintains a PSO cache in its MetalDevice struct. The cache is a hash map keyed by a string that encodes the kernel name plus all specialization constant values:

PSO Cache (HashMap<String, MTLComputePipelineState>)
+-------------------------------------------+
| Key                    | Value (PSO)       |
|------------------------+-------------------|
| "gemv_k128"            | PSO_0x7f8a...     |
| "gemv_k4096"           | PSO_0x7f8b...     |
| "gemm_tm32_tn64_tk16"  | PSO_0x7f8c...     |
| "softmax_n128"         | PSO_0x7f8d...     |
| "rmsnorm_n4096"        | PSO_0x7f8e...     |
| "rope_dim128_base10000"| PSO_0x7f8f...     |
+-------------------------------------------+

The first time you need a PSO with specific constants, it gets compiled and stored. Every subsequent use is a hash table lookup – essentially free. In Rust, the cache lives inside the MetalDevice:

pub struct MetalDevice {
    device: metal::Device,
    command_queue: metal::CommandQueue,
    library: metal::Library,
    pso_cache: HashMap<String, metal::ComputePipelineState>,
    // ...
}

impl MetalDevice {
    fn get_or_create_pso(
        &mut self,
        kernel_name: &str,
        constants: &[(u32, FunctionConstantValue)],
    ) -> &metal::ComputePipelineState {
        // Build cache key: "kernel_name_const0val_const1val_..."
        let key = build_cache_key(kernel_name, constants);

        self.pso_cache.entry(key).or_insert_with(|| {
            let func_constants = metal::FunctionConstantValues::new();
            for (index, value) in constants {
                match value {
                    FunctionConstantValue::UInt(v) => {
                        func_constants.set_constant_value_at_index(
                            v as *const _ as *const _,
                            metal::MTLDataType::UInt,
                            *index as u64,
                        );
                    }
                    // ... other types
                }
            }

            let function = self.library
                .get_function(kernel_name, Some(func_constants))
                .expect("kernel not found");

            self.device
                .new_compute_pipeline_state_with_function(&function)
                .expect("failed to create PSO")
        })
    }
}
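The snippet above calls a build_cache_key helper that is not shown. Here is a minimal sketch of one possible implementation – the enum variants and key scheme are illustrative assumptions, not akunu's actual code:

```rust
// Hypothetical value enum mirroring the one matched on above.
pub enum FunctionConstantValue {
    UInt(u32),
    Bool(bool),
}

// Build a deterministic cache key from the kernel name plus each
// (index, value) pair, e.g. "gemv_c0:u128_c1:b1". Any encoding
// works as long as distinct constant sets produce distinct keys.
pub fn build_cache_key(
    kernel_name: &str,
    constants: &[(u32, FunctionConstantValue)],
) -> String {
    let mut key = String::from(kernel_name);
    for (index, value) in constants {
        match value {
            FunctionConstantValue::UInt(v) => {
                key.push_str(&format!("_c{}:u{}", index, v))
            }
            FunctionConstantValue::Bool(v) => {
                key.push_str(&format!("_c{}:b{}", index, *v as u8))
            }
        }
    }
    key
}
```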

The key insight is that for ML inference, the set of PSOs you need is determined entirely by the model architecture. A LLaMA model uses the same kernels with the same dimensions for every token it generates. So in practice, all PSOs get created during model loading, and inference itself never hits a cache miss.

Model Loading Phase:             Inference Phase:
+------------------------+      +------------------------+
| For each layer:        |      | For each token:        |
|   - Create GEMM PSOs   |      |   - Look up PSO (hit)  |
|   - Create norm PSOs   |      |   - Encode dispatch    |
|   - Create attn PSOs   |      |   - Execute on GPU     |
|   (one-time cost)      |      |   (no compilation!)    |
+------------------------+      +------------------------+
     ~100-500ms total                ~0ms PSO overhead
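A sketch of this prewarming idea – the ModelConfig type and key names are hypothetical, and real code would create a PSO per key rather than just collecting strings:

```rust
// Hypothetical model description; dimensions are fixed per architecture.
struct ModelConfig {
    hidden_dim: u32,
    head_dim: u32,
    n_layers: u32,
}

// Enumerate every specialized-kernel key the model will ever need.
// The set is finite and fully known at load time, so each key can
// be compiled into a PSO before the first token is generated.
fn prewarm_keys(cfg: &ModelConfig) -> Vec<String> {
    let mut keys = Vec::new();
    for _layer in 0..cfg.n_layers {
        // Every layer requests the same variants...
        keys.push(format!("gemv_k{}", cfg.hidden_dim));
        keys.push(format!("rmsnorm_n{}", cfg.hidden_dim));
        keys.push(format!("rope_dim{}", cfg.head_dim));
    }
    keys.sort();
    keys.dedup(); // ...so the cache ends up with one PSO per variant
    keys
}
```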

Command Encoding Flow

Now that we have a PSO, let us actually use it. In Metal, you do not call GPU functions directly. Instead, you encode commands into a command buffer, and then submit the entire buffer to the GPU for execution.

The Full Encoding Stack

MTLCommandQueue
       |
       | makeCommandBuffer()
       v
MTLCommandBuffer
       |
       | makeComputeCommandEncoder()
       v
MTLComputeCommandEncoder
       |
       | setComputePipelineState()    -- which kernel to run
       | setBuffer(buffer, offset, index) -- bind data
       | setBytes(ptr, length, index)     -- inline small data
       | dispatchThreadgroups(...)        -- launch!
       |
       | endEncoding()
       v
MTLCommandBuffer
       |
       | commit()      -- send to GPU
       | waitUntilCompleted()  -- block until done
       v
[GPU executes the commands]

Let us walk through this with a concrete example. Say we want to run our add_arrays kernel on two arrays of 1024 floats.

Step 1: Create a Command Buffer

let commandBuffer = commandQueue.makeCommandBuffer()!

A command buffer is a container for GPU commands. You can think of it as a recording of work to be done. Nothing executes until you call commit().

Step 2: Create a Compute Command Encoder

let encoder = commandBuffer.makeComputeCommandEncoder()!

The encoder is how you record compute commands. Metal separates encoding (recording) from execution (running) so that you can build up a batch of work and submit it all at once.

Step 3: Set the Pipeline State

encoder.setComputePipelineState(addArraysPSO)

This tells the GPU which kernel function to run for subsequent dispatch calls.

Step 4: Bind Data

encoder.setBuffer(bufferA, offset: 0, index: 0)  // buffer(0) in MSL
encoder.setBuffer(bufferB, offset: 0, index: 1)  // buffer(1) in MSL
encoder.setBuffer(bufferC, offset: 0, index: 2)  // buffer(2) in MSL

Each setBuffer call binds a MTLBuffer to a buffer index. These indices correspond to the [[buffer(N)]] attributes in your MSL kernel.

Step 5: Dispatch

let threadgroupSize = MTLSize(width: 256, height: 1, depth: 1)
let threadgroupCount = MTLSize(width: 1024 / 256, height: 1, depth: 1)  // 4 threadgroups
encoder.dispatchThreadgroups(threadgroupCount,
                             threadsPerThreadgroup: threadgroupSize)

We will cover dispatch geometry in detail in Chapter 9. For now, just know that we are launching 1024 threads organized into groups of 256.
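The division above only works because 1024 is an exact multiple of 256. When it is not, the threadgroup count must round up, and the kernel must bounds-check its thread index. A small helper for the rounding:

```rust
// Number of threadgroups needed to cover `grid` elements with
// groups of `group` threads, rounding up so no element is missed.
fn threadgroup_count(grid: u32, group: u32) -> u32 {
    (grid + group - 1) / group
}
```

With a non-divisible size like 1000 elements and groups of 256, this launches 4 groups (1024 threads), and the last 24 threads must return early inside the kernel.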

Step 6: End Encoding and Commit

encoder.endEncoding()
commandBuffer.commit()
commandBuffer.waitUntilCompleted()

endEncoding() finalizes the encoder – you cannot add more commands after this. commit() submits the command buffer to the GPU. waitUntilCompleted() blocks the CPU until the GPU finishes (in production you would typically register a completion handler via addCompletedHandler instead of blocking).

Encoding Multiple Dispatches

For ML inference, a single command buffer typically contains many dispatch calls – one per layer operation:

let encoder = commandBuffer.makeComputeCommandEncoder()!

// Layer 1: RMSNorm
encoder.setComputePipelineState(rmsNormPSO)
encoder.setBuffer(hiddenState, offset: 0, index: 0)
encoder.setBuffer(normWeights, offset: 0, index: 1)
encoder.setBuffer(normalizedOutput, offset: 0, index: 2)
encoder.dispatchThreadgroups(normGrid, threadsPerThreadgroup: normThreadgroup)

// Layer 1: Q,K,V projection (GEMM)
encoder.setComputePipelineState(gemmPSO)
encoder.setBuffer(normalizedOutput, offset: 0, index: 0)
encoder.setBuffer(qkvWeights, offset: 0, index: 1)
encoder.setBuffer(qkvOutput, offset: 0, index: 2)
encoder.dispatchThreadgroups(gemmGrid, threadsPerThreadgroup: gemmThreadgroup)

// Layer 1: Attention
encoder.setComputePipelineState(attentionPSO)
// ... more setBuffer calls ...
encoder.dispatchThreadgroups(attnGrid, threadsPerThreadgroup: attnThreadgroup)

// ... and so on for all layers ...

encoder.endEncoding()
commandBuffer.commit()

The GPU executes these dispatches in order (within the same encoder), which gives you implicit synchronization between dependent operations. The output of RMSNorm flows into GEMM which flows into attention – no explicit barriers needed within a single encoder.

The setBytes Optimization

For small amounts of data (under 4KB), Metal provides setBytes() as an alternative to setBuffer(). Instead of allocating a GPU buffer and copying data into it, setBytes() inlines the data directly into the command buffer:

struct GEMMParams {
    var M: UInt32
    var N: UInt32
    var K: UInt32
    var alpha: Float
}

var params = GEMMParams(M: 4096, N: 4096, K: 4096, alpha: 1.0)
encoder.setBytes(&params, length: MemoryLayout<GEMMParams>.size, index: 3)

This avoids the overhead of buffer allocation for small, frequently-changing parameter structs. Akunu uses this pattern extensively – kernel parameters like matrix dimensions, strides, and scalar values are all passed via setBytes().

setBuffer path (for large data):
  CPU Memory --> MTLBuffer (GPU-accessible) --> Kernel reads buffer(N)
  [allocate]    [copy or share]                 [dereference pointer]

setBytes path (for small params):
  CPU Memory --> Embedded in command buffer --> Kernel reads buffer(N)
  [no alloc]    [inline copy, < 4KB]            [same interface!]

From the kernel’s perspective, there is no difference – both appear as [[buffer(N)]] arguments. The optimization is entirely on the CPU side.
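On the Rust side, a struct passed via set_bytes needs #[repr(C)] so its byte layout matches what the MSL kernel expects. A sketch mirroring the Swift GEMMParams above (the helper function is illustrative):

```rust
// Mirrors the MSL struct `GEMMParams { uint M; uint N; uint K; float alpha; }`.
// #[repr(C)] pins a C-compatible field order and padding, so the GPU
// reads exactly the bytes the CPU side wrote. Without it, the Rust
// compiler is free to reorder fields.
#[repr(C)]
struct GemmParams {
    m: u32,
    n: u32,
    k: u32,
    alpha: f32,
}

// The length handed to set_bytes: four 4-byte fields, no padding.
fn gemm_params_len() -> usize {
    std::mem::size_of::<GemmParams>()
}
```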

How Akunu Wraps the Metal Pipeline

Akunu’s Rust code wraps all of this Metal machinery behind a clean interface. The MetalDevice struct owns the Metal device, command queue, library, and PSO cache. Individual operations (GEMM, softmax, RMSNorm, etc.) are methods on MetalDevice that handle all the encoding details internally.

Here is a simplified view of how a GEMM dispatch looks in akunu:

impl MetalDevice {
    pub fn gemm(
        &mut self,
        a: &MetalBuffer,    // M x K matrix
        b: &MetalBuffer,    // K x N matrix
        c: &MetalBuffer,    // M x N output matrix
        m: u32,
        n: u32,
        k: u32,
    ) -> Result<()> {
        // 1. Get or create the PSO (cached); clone the handle so the
        //    mutable borrow of self ends before we start encoding
        let pso = self.get_or_create_pso("gemm", &[
            (0, FunctionConstantValue::UInt(k)),
        ]).clone();

        // 2. Create a command buffer and encoder
        let command_buffer = self.command_queue.new_command_buffer();
        let encoder = command_buffer.new_compute_command_encoder();

        // 3. Set pipeline and bind buffers
        encoder.set_compute_pipeline_state(&pso);
        encoder.set_buffer(0, Some(&a.buffer), 0);
        encoder.set_buffer(1, Some(&b.buffer), 0);
        encoder.set_buffer(2, Some(&c.buffer), 0);

        // 4. Set parameters via setBytes
        let params = GemmParams { m, n, k };
        encoder.set_bytes(
            3,
            std::mem::size_of::<GemmParams>() as u64,
            &params as *const _ as *const _,
        );

        // 5. Calculate dispatch geometry
        let threadgroup_size = MTLSize::new(128, 1, 1);  // 4 SIMD groups
        let grid_size = MTLSize::new(
            (n as u64 + 63) / 64,   // tiles along N
            (m as u64 + 31) / 32,   // tiles along M
            1,
        );

        // 6. Dispatch and finalize
        encoder.dispatch_thread_groups(grid_size, threadgroup_size);
        encoder.end_encoding();
        command_buffer.commit();

        Ok(())
    }
}

Notice the pattern: get PSO, create encoder, bind resources, dispatch, finalize. Every operation follows this same template. The differences are in which kernel is used, what buffers are bound, and how the dispatch geometry is calculated.

Batching Multiple Operations

For inference, akunu batches all the operations for a single forward pass into one command buffer:

impl MetalDevice {
    pub fn forward_pass(
        &mut self,
        model: &Model,
        input: &MetalBuffer,
        output: &MetalBuffer,
    ) -> Result<()> {
        let command_buffer = self.command_queue.new_command_buffer();
        let encoder = command_buffer.new_compute_command_encoder();

        for layer in &model.layers {
            // RMSNorm
            self.encode_rmsnorm(&encoder, &layer.norm_weights, ...);

            // QKV projection
            self.encode_gemm(&encoder, &layer.qkv_weights, ...);

            // Attention
            self.encode_attention(&encoder, ...);

            // Feed-forward network
            self.encode_gemm(&encoder, &layer.ff_weights, ...);
        }

        encoder.end_encoding();
        command_buffer.commit();
        command_buffer.wait_until_completed();

        Ok(())
    }
}

This is efficient because:

  1. One command buffer: Only one submission to the GPU driver, reducing CPU-side overhead.
  2. Sequential execution: Within a single encoder, dispatches execute in order, providing implicit synchronization.
  3. Cached PSOs: All pipeline states are already compiled from model loading.
  4. No CPU-GPU round trips: The CPU records all the work and submits it in one shot. The GPU executes everything without waiting for the CPU.

Advanced: Pipeline Reflection and Validation

When you create a PSO, you can ask Metal for reflection information – metadata about the kernel’s resource usage:

var reflection: MTLAutoreleasedComputePipelineReflection?
let pso = try device.makeComputePipelineState(
    function: function,
    options: .argumentInfo,
    reflection: &reflection
)

// Now you can inspect resource usage
if let reflection = reflection {
    for arg in reflection.arguments {
        print("Argument: \(arg.name), index: \(arg.index)")
        print("  Type: \(arg.type), access: \(arg.access)")
        print("  Size: \(arg.bufferDataSize)")
    }
}

This is useful for debugging and for building generic dispatch systems that can validate buffer bindings at runtime. If you pass a buffer that is too small or bind the wrong type, the validation layer will catch it.

In debug builds, Metal’s validation layer (enabled via the MTL_DEBUG_LAYER environment variable) performs extensive checking:

# Enable Metal validation
export MTL_DEBUG_LAYER=1

# Enable shader validation (slower but catches more bugs)
export MTL_SHADER_VALIDATION=1

These are invaluable during development. They catch things like:

  • Buffer overruns (reading/writing past the end of a buffer)
  • Missing buffer bindings
  • Threadgroup memory size mismatches
  • Resource hazards (reading a buffer that is still being written)

Async Pipeline Creation

PSO creation can be done asynchronously, which is important when you need to create many pipelines at startup:

let group = DispatchGroup()

for kernelConfig in allKernelConfigs {
    group.enter()
    device.makeComputePipelineState(function: kernelConfig.function) { pso, error in
        if let pso = pso {
            cache[kernelConfig.key] = pso
        }
        group.leave()
    }
}

group.wait()  // All PSOs ready

This lets the Metal runtime parallelize compilation across CPU cores. For a model with dozens of kernel variants, this can significantly reduce startup time.
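The same fan-out can be expressed from Rust with plain threads. A sketch that simulates the compile step with a stub (real code would call new_compute_pipeline_state_with_function inside each thread against a real device):

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread;

// Stand-in for PSO creation -- the slow, parallelizable part.
fn compile(key: &str) -> String {
    format!("pso:{key}")
}

// Compile every kernel variant on its own thread, collecting the
// results into a shared cache, then wait for all of them.
fn build_cache(keys: Vec<String>) -> HashMap<String, String> {
    let cache = Arc::new(Mutex::new(HashMap::new()));
    let handles: Vec<_> = keys
        .into_iter()
        .map(|key| {
            let cache = Arc::clone(&cache);
            thread::spawn(move || {
                let pso = compile(&key); // runs concurrently
                cache.lock().unwrap().insert(key, pso);
            })
        })
        .collect();
    for h in handles {
        h.join().unwrap(); // equivalent of group.wait(): all PSOs ready
    }
    Arc::try_unwrap(cache).unwrap().into_inner().unwrap()
}
```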

Summary

Let us recap the key points:

  1. A PSO is a compiled GPU kernel – expensive to create, cheap to use.
  2. Compilation is multi-stage: MSL source -> AIR -> metallib -> PSO.
  3. Offline compilation (precompiled metallib) eliminates runtime compilation cost.
  4. Function constants create specialized kernel variants without code duplication – critical for ML where dimensions are fixed per model.
  5. PSO caching (like akunu’s pso_cache) ensures each kernel variant is compiled only once.
  6. Command encoding follows a simple pattern: begin, set pipeline, bind resources, dispatch, end.
  7. setBytes is an optimization for small parameter structs (< 4KB).
  8. Batch multiple operations into a single command buffer for efficiency.

In the next chapter, we will dive into the Metal Shading Language itself – the language you use to write those kernel functions.