
Embedding and Activation Kernels

Before a transformer can normalize or attend or do anything useful, it needs to convert integer token IDs into dense floating-point vectors. And between the big matrix multiplies, it needs to apply nonlinear activation functions to keep the network from collapsing into a single linear transformation. These are the “glue” kernels – individually simple, collectively essential.

In this chapter we will cover akunu’s embedding lookup kernels (FP16, BF16, and quantized variants), the SiLU and GELU activation kernels (including fused gated versions for FFN blocks), and the collection of utility kernels that handle the mundane but necessary work of moving data around: vector addition, residual connections, QKV splits, transposes, bias adds, and head rearrangement.

Embedding Lookup: From Token IDs to Vectors

The embedding table is conceptually a 2D array of shape [vocab_size, dim]. To look up a token, you just index into the row for that token ID and copy it out. On a GPU, you launch enough threads to copy all elements in parallel.

FP16 Embedding Lookup

The simplest case – the embedding table is stored in half-precision:

kernel void embedding_lookup_f16(
    device const uint32_t    *tokens  [[buffer(0)]],
    device const half        *table   [[buffer(1)]],
    device half              *output  [[buffer(2)]],
    constant EmbeddingParams &params  [[buffer(3)]],
    uint2 tid [[thread_position_in_grid]]
) {
    const uint dim_idx   = tid.x;
    const uint token_idx = tid.y;

    if (token_idx >= params.num_tokens || dim_idx >= params.dim) return;

    uint token_id = tokens[token_idx];
    output[token_idx * params.dim + dim_idx] = table[token_id * params.dim + dim_idx];
}

That is it. A single memory read and a single memory write. The dispatch grid is 2D:

Grid: (dim, num_tokens)

For dim=4096, num_tokens=1 (decode):
  4096 threads, each copies one element

For dim=4096, num_tokens=512 (prefill):
  2,097,152 threads

The memory access pattern deserves attention. Threads with consecutive dim_idx values (same token_idx) read consecutive memory locations from the same table row. This is perfectly coalesced:

Token "hello" (id=15043):
  Thread (0, 0): table[15043 * 4096 + 0]    -->  output[0]
  Thread (1, 0): table[15043 * 4096 + 1]    -->  output[1]
  Thread (2, 0): table[15043 * 4096 + 2]    -->  output[2]
  ...
  Thread (4095, 0): table[15043 * 4096 + 4095] --> output[4095]

  128 bytes per memory transaction / 2 bytes per element = 64 elements per fetch
  SIMD width = 32 --> one 128-byte fetch covers two full SIMD groups

However, there is a catch. Different tokens will read from completely different rows of the table, potentially megabytes apart. For a vocab of 32K with dim=4096, the table is 32768 * 4096 * 2 = 256 MB. Only a tiny fraction fits in cache. The first token’s row load will be a cache miss for sure. But since we only read each row once, caching is irrelevant – this is a pure streaming kernel.
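The indexing is easy to model on the CPU. Here is a minimal Python sketch of the same lookup (an illustrative reference, not akunu code; the helper name is made up):

```python
# CPU reference for the embedding lookup: one row copy per token.
# The GPU kernel does the same work with one thread per element.
def embedding_lookup(tokens, table, dim):
    """`table` is a flat list of vocab_size * dim values, row-major."""
    out = []
    for token_id in tokens:
        out.extend(table[token_id * dim : (token_id + 1) * dim])
    return out

# 3-word vocab, dim=4: token 2's row is elements 8..11 of the flat table
table = list(range(3 * 4))
assert embedding_lookup([2, 0], table, 4) == [8, 9, 10, 11, 0, 1, 2, 3]
```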

BF16 Embedding Lookup

For M4 and later chips that have native BF16 support, akunu provides a BF16 variant:

kernel void embedding_lookup_bf16(
    device const uint32_t   *token_ids  [[buffer(0)]],
    device const bfloat     *table      [[buffer(1)]],
    device half             *output     [[buffer(2)]],
    constant uint           &dim        [[buffer(3)]],
    uint2 tid [[thread_position_in_grid]]
) {
    const uint d_idx     = tid.x;
    const uint token_idx = tid.y;
    if (d_idx >= dim) return;

    const uint token_id = token_ids[token_idx];
    output[token_idx * dim + d_idx] = half(table[token_id * dim + d_idx]);
}

Notice it reads BF16 but outputs FP16. BF16 has the same exponent range as FP32 (8 bits) but only 7 bits of mantissa versus FP16’s 10. The conversion happens inline – the half() cast preserves the mantissa (FP16 actually has more mantissa bits than BF16) but narrows the exponent range, so magnitudes beyond FP16’s maximum of 65504 would saturate; embedding weights are nowhere near that in practice. BF16 weights are increasingly common in HuggingFace models, and on M4 hardware this cast is a single-cycle operation.

Quantized Embedding Lookup: Q4_0 On-the-Fly Dequantization

Now for the interesting one. When the embedding table is quantized to Q4_0 format (4 bits per weight), the kernel must dequantize on the fly. Each Q4_0 block contains 32 elements packed into 18 bytes:

Q4_0 Block Structure (18 bytes total):
+------+------------------+
| d    | qs[16]           |
| (2B) | (16 bytes)       |
+------+------------------+
  |          |
  |          +-- 32 four-bit values, nibble-packed:
  |              qs[j] low nibble  = element j      (j = 0..15)
  |              qs[j] high nibble = element j + 16 (j = 0..15)
  |
  +-- FP16 scale factor

Dequantization: value = d * (nibble - 8)
  (Q4_0 stores unsigned 0-15, centered at 8)

Here is the kernel:

kernel void embedding_lookup_q4_0(
    device const uint32_t    *token_ids  [[buffer(0)]],
    device const block_q4_0  *table      [[buffer(1)]],
    device half              *output     [[buffer(2)]],
    constant EmbeddingParams &params     [[buffer(3)]],
    uint2 tid [[thread_position_in_grid]]
) {
    const uint d_idx     = tid.x;
    const uint token_idx = tid.y;

    if (d_idx >= params.dim || token_idx >= params.num_tokens) return;

    const uint token_id = token_ids[token_idx];
    const uint blocks_per_row = params.dim / QK4_0;

    const uint block_idx  = d_idx / QK4_0;
    const uint elem_idx   = d_idx % QK4_0;

    device const block_q4_0 &blk = table[token_id * blocks_per_row + block_idx];
    half scale = blk.d;
    uint is_high = elem_idx / 16;
    uint j = elem_idx % 16;
    uint8_t nibble = is_high ? (blk.qs[j] >> 4) : (blk.qs[j] & 0xF);
    half val = scale * half(int(nibble) - 8);

    output[token_idx * params.dim + d_idx] = val;
}

Let us trace the dequantization for element 21 within a block:

elem_idx = 21
is_high  = 21 / 16 = 1   (it is in the upper half)
j        = 21 % 16 = 5   (byte index within qs[])

nibble   = qs[5] >> 4    (high nibble of byte 5)
value    = scale * (nibble - 8)

Example: qs[5] = 0xA3, scale = 0.125
  nibble = 0xA = 10
  value  = 0.125 * (10 - 8) = 0.125 * 2 = 0.25
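The same extraction logic can be sketched as a CPU reference in Python (illustrative helper, not akunu source) and checked against the worked example:

```python
# CPU sketch of Q4_0 element extraction. d is the FP16 scale; qs is the
# 16-byte nibble-packed payload: low nibble of qs[j] is element j,
# high nibble of qs[j] is element j + 16.
def dequant_q4_0_element(d, qs, elem_idx):
    j = elem_idx % 16
    nibble = (qs[j] >> 4) if elem_idx >= 16 else (qs[j] & 0xF)
    return d * (nibble - 8)

# The worked example above: qs[5] = 0xA3, scale = 0.125, element 21
qs = [0] * 16
qs[5] = 0xA3
assert dequant_q4_0_element(0.125, qs, 21) == 0.25    # high nibble 0xA = 10
assert dequant_q4_0_element(0.125, qs, 5) == -0.625   # low nibble 0x3 = 3
```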

The memory savings are significant:

FP16 table:  vocab_size * dim * 2 bytes
Q4_0 table:  vocab_size * (dim/32) * 18 bytes

For vocab=128256, dim=4096:
  FP16:  128256 * 4096 * 2    = 1,050,673,152 bytes (~1.0 GB)
  Q4_0:  128256 * 128 * 18    =  295,501,824 bytes (~282 MB)

Compression ratio: ~3.6x
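These figures are plain arithmetic and easy to re-derive:

```python
# Table sizes for vocab=128256, dim=4096 (stdlib arithmetic only).
vocab, dim = 128256, 4096
fp16_bytes = vocab * dim * 2              # 2 bytes per FP16 element
q4_0_bytes = vocab * (dim // 32) * 18     # 18 bytes per 32-element block

assert fp16_bytes == 1_050_673_152
assert q4_0_bytes == 295_501_824
assert round(fp16_bytes / q4_0_bytes, 2) == 3.56   # = 32*2 / 18 = 64/18
```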

The downside is slightly more ALU work per element (a shift, a mask, a subtract, a multiply). But since embedding lookup is purely memory-bound – you are reading from a huge table with essentially random access per token – the extra ALU is completely hidden behind memory latency.

Akunu provides similar quantized lookup kernels for Q4_1, Q5_0, Q5_K, Q8_0, Q2_K, Q3_K, Q4_K, and Q6_K formats. Each follows the same pattern: one thread per output element, read the relevant quantized block, extract and dequantize the value, write FP16 output.

Positional Embedding Addition

For models that use learned positional embeddings (like Whisper), there is a simple kernel that adds a position-dependent vector:

kernel void pos_embed_add_f16(
    device const half        *input     [[buffer(0)]],
    device const half        *pos_table [[buffer(1)]],
    device half              *output    [[buffer(2)]],
    constant ElementwiseParams &p       [[buffer(3)]],
    uint tid [[thread_position_in_grid]]
) {
    if (tid >= p.count) return;
    uint pos = p._pad0;
    uint dim = p.count;
    output[tid] = input[tid] + pos_table[pos * dim + tid];
}

One thread per dimension, one addition per thread. The position index is passed via a repurposed padding field in ElementwiseParams (a common trick to avoid defining a new params struct for a one-off kernel).

Activation Functions: SiLU and GELU

After the FFN’s first linear projection, the output needs a nonlinearity. Modern LLMs use two main activation functions:

SiLU (Sigmoid Linear Unit):
  f(x) = x * sigmoid(x) = x / (1 + exp(-x))

GELU (Gaussian Error Linear Unit):
  f(x) = x * Phi(x) ~ 0.5 * x * (1 + tanh(sqrt(2/pi) * x * (1 + 0.044715 * x^2)))

Let us visualize these:

             SiLU                              GELU
    y |                               y |
    2 |               /               2 |               /
      |             /                   |             /
    1 |           /                   1 |           /
      |         /                       |         ./
    0 |---____/                       0 |---___./
      |  /                              |  /
   -1 |/                             -1 |/
      +--+--+--+--+--+--> x            +--+--+--+--+--+--> x
        -4 -2  0  2  4                   -4 -2  0  2  4

Both are smooth, monotonically increasing for positive x, and have a soft “gate” that suppresses negative values. The key difference: both dip slightly negative before recovering, but SiLU’s dip is deeper and wider (a minimum of about -0.28 near x = -1.28, versus about -0.17 near x = -0.75 for GELU), which gives GELU the sharper transition around x = 0.
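Both claims are easy to verify numerically. A Python reference of the two activations (tanh-approximation GELU, matching the formulas above):

```python
import math

# Reference implementations matching the formulas above.
def silu(x):
    return x / (1.0 + math.exp(-x))

def gelu(x):
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + math.tanh(c * x * (1.0 + 0.044715 * x * x)))

# Scan the negative axis: both dip below zero, SiLU more deeply.
xs = [x / 1000 for x in range(-4000, 0)]
silu_min = min(silu(x) for x in xs)
gelu_min = min(gelu(x) for x in xs)
assert -0.279 < silu_min < -0.277
assert -0.18 < gelu_min < -0.16
assert abs(gelu(1.0) - 0.8412) < 1e-3   # near the exact x * Phi(x) ~ 0.8413
```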

SiLU Kernel

The standalone SiLU kernel is trivial:

kernel void silu_f16(
    device const half        *input  [[buffer(0)]],
    device half              *output [[buffer(1)]],
    constant ElementwiseParams &p    [[buffer(2)]],
    uint tid [[thread_position_in_grid]]
) {
    if (tid >= p.count) return;
    output[tid] = act_silu(input[tid]);
}

Where act_silu is defined in KernelCommon.h:

inline half act_silu(half x) {
    return x / (half(1) + exp(-x));
}

One thread, one element, one activation. The exp(-x) is the expensive part – it compiles to a hardware transcendental approximation on Apple GPUs. For FP16 inputs, the fast-math exp is accurate to about 3 ULPs, which is more than adequate.

Fused SiLU Gate: The SwiGLU Pattern

LLaMA and its descendants use SwiGLU (SiLU-Gated Linear Unit) in the FFN. The pattern is:

FFN(x) = (SiLU(W_gate * x)) * (W_up * x)

The gate and up projections are computed separately by GEMM, producing two vectors of size ff_dim. The fused kernel combines the activation and element-wise multiply:

kernel void silu_gate_f16(
    device const half        *gate   [[buffer(0)]],
    device const half        *up     [[buffer(1)]],
    device half              *output [[buffer(2)]],
    constant ElementwiseParams &p    [[buffer(3)]],
    uint tid [[thread_position_in_grid]]
) {
    if (tid >= p.count) return;
    output[tid] = act_silu(gate[tid]) * up[tid];
}

This does more work per thread than standalone SiLU (an extra read and a multiply) but saves a full pass over the data:

Unfused:
  Kernel 1: silu(gate) --> temp[]     // read gate, write temp
  Kernel 2: temp * up  --> output[]   // read temp, read up, write output
  Total: (1 read + 1 write) + (2 reads + 1 write) = 5 memory ops per element

Fused:
  Kernel 1: silu(gate) * up --> output[]  // read gate, read up, write output
  Total: 2 reads + 1 write = 3 memory ops per element

The fused version eliminates the temporary buffer entirely – 40% less memory traffic and one fewer dispatch.
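A toy CPU model (Python, illustrative only) confirms the fused schedule computes exactly the same values as the two-kernel version:

```python
import math

def silu(x):
    return x / (1.0 + math.exp(-x))

gate = [0.5, -1.0, 2.0]
up   = [1.0,  3.0, -0.5]

# Unfused: two passes through memory, with a temporary buffer.
temp    = [silu(g) for g in gate]             # kernel 1: silu(gate) -> temp
unfused = [t * u for t, u in zip(temp, up)]   # kernel 2: temp * up -> output

# Fused: one pass, no temporary buffer at all.
fused = [silu(g) * u for g, u in zip(gate, up)]

assert fused == unfused
```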

Strided SiLU Gate: Batch Mode

During prefill, the gate and up projections may be packed into a single buffer with a specific stride layout:

buf layout: [M rows, 2 * ff_dim columns]

Row structure:
+---------------------------+---------------------------+
| gate[0..ff_dim-1]         | up[0..ff_dim-1]           |
+---------------------------+---------------------------+
<-------- ff_dim ----------><-------- ff_dim ---------->
<------------------- 2 * ff_dim ----------------------->

The strided kernel handles this layout with a 2D grid:

kernel void silu_gate_strided_f16(
    device const half        *buf    [[buffer(0)]],
    device half              *output [[buffer(1)]],
    constant uint32_t        &ffDim  [[buffer(2)]],
    uint2 tid [[thread_position_in_grid]]
) {
    const uint j   = tid.x;  // [0, ffDim)
    const uint row = tid.y;  // [0, M)
    if (j >= ffDim) return;
    const uint gate_idx = row * 2 * ffDim + j;
    const uint up_idx   = row * 2 * ffDim + ffDim + j;
    output[row * ffDim + j] = act_silu(buf[gate_idx]) * buf[up_idx];
}

Each thread computes one output element from the gate and up values at the appropriate offsets within the packed buffer.

GELU and GELU Gate

For Gemma (which uses GeGLU rather than SwiGLU), there are corresponding GELU kernels. The standalone GELU:

kernel void gelu_f16(
    device const half        *input  [[buffer(0)]],
    device half              *output [[buffer(1)]],
    constant ElementwiseParams &p    [[buffer(2)]],
    uint tid [[thread_position_in_grid]]
) {
    if (tid >= p.count) return;
    output[tid] = half(act_gelu_f32(input[tid]));
}

Where act_gelu_f32 in KernelCommon.h uses the tanh approximation:

inline float act_gelu_f32(half x) {
    float xf = float(x);
    constexpr float SQRT_2_OVER_PI = 0.7978845608f;
    constexpr float GELU_COEF_A = 0.044715f;
    return 0.5f * xf * (1.0f + precise::tanh(SQRT_2_OVER_PI * xf
                                              * (1.0f + GELU_COEF_A * xf * xf)));
}

Notice the precise::tanh – this uses the precise (not fast-math) tanh to match the reference implementation exactly. The comment in the source explains why the fused GELU gate kernel does all computation in float32:

/// GELU-gate: output = gelu(gate) * up
/// All computation in float32 to match llama.cpp precision
/// (F16 GELU*up can overflow)
kernel void gelu_gate_f16(
    device const half        *gate   [[buffer(0)]],
    device const half        *up     [[buffer(1)]],
    device half              *output [[buffer(2)]],
    constant ElementwiseParams &p    [[buffer(3)]],
    uint tid [[thread_position_in_grid]]
) {
    if (tid >= p.count) return;
    output[tid] = half(act_gelu_f32(gate[tid]) * float(up[tid]));
}

For large positive inputs GELU(x) ≈ x, so gelu(gate) can be as large as the gate value itself, and multiplying by up (which can also be large) risks exceeding FP16’s maximum finite value of 65504. Doing the arithmetic in float32 and casting back to FP16 only at the end prevents this.
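A quick stdlib demonstration of why the FP32 intermediates matter (Python's struct 'e' format packs IEEE 754 half precision and raises on overflow):

```python
import math
import struct

def act_gelu_f32(x):
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + math.tanh(c * x * (1.0 + 0.044715 * x * x)))

def fits_in_half(x):
    try:
        struct.pack('<e', x)   # 'e' = IEEE 754 binary16 (FP16)
        return True
    except (OverflowError, struct.error):
        return False

# For x = 200, the final GELU value (~200) fits in FP16, but the
# intermediate cubic term x * (1 + 0.044715 * x^2) ~ 357,920 does not:
x = 200.0
inner = x * (1.0 + 0.044715 * x * x)
assert not fits_in_half(inner)           # intermediate overflows FP16
assert fits_in_half(act_gelu_f32(x))     # final value is fine after FP32 math
```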

Utility Kernels: The Supporting Cast

Beyond embeddings and activations, akunu has a collection of utility kernels that handle common data manipulation operations. These are all simple – one thread per element, no reductions, no shared memory – but they are called frequently throughout inference.

Vector Add

kernel void vector_add_f16(
    device const half        *A      [[buffer(0)]],
    device const half        *B      [[buffer(1)]],
    device half              *C      [[buffer(2)]],
    constant ElementwiseParams &p    [[buffer(3)]],
    uint tid [[thread_position_in_grid]]
) {
    if (tid >= p.count) return;
    C[tid] = A[tid] + B[tid];
}

As simple as it gets. There are both FP16 and FP32 variants. Used for general-purpose addition where neither input is an “accumulator” (unlike residual add, which semantically represents a skip connection).

Residual Add

kernel void residual_add_f16(
    device const half        *a      [[buffer(0)]],
    device const half        *b      [[buffer(1)]],
    device half              *output [[buffer(2)]],
    constant ElementwiseParams &p    [[buffer(3)]],
    uint tid [[thread_position_in_grid]]
) {
    if (tid >= p.count) return;
    output[tid] = a[tid] + b[tid];
}

Functionally identical to vector_add_f16. The separate kernel exists for semantic clarity in the pipeline (and potentially for profiling – you can see “residual_add” in GPU trace tools and know exactly what part of the transformer you are looking at).

Bias Add: Broadcast Addition

After a GEMM/GEMV produces output of shape [rows, dim], many models add a bias vector of shape [dim] to every row:

kernel void bias_add_f16(
    device half              *data   [[buffer(0)]],
    device const half        *bias   [[buffer(1)]],
    constant ElementwiseParams &p    [[buffer(2)]],
    uint tid [[thread_position_in_grid]]
) {
    if (tid >= p.count) return;
    uint col = tid % p._pad0;  // _pad0 = dim
    data[tid] = data[tid] + bias[col];
}

The modulo operation tid % dim extracts the column index, which indexes into the 1D bias vector. The same bias value is added to every row. This is an in-place operation – no separate output buffer needed.

Before:                        After:
+--------+--------+--------+  +--------+--------+--------+
| r0,c0  | r0,c1  | r0,c2  |  | +b[0]  | +b[1]  | +b[2]  |
| r1,c0  | r1,c1  | r1,c2  |  | +b[0]  | +b[1]  | +b[2]  |
| r2,c0  | r2,c1  | r2,c2  |  | +b[0]  | +b[1]  | +b[2]  |
+--------+--------+--------+  +--------+--------+--------+
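In Python terms (CPU reference, illustrative only):

```python
# CPU reference of the broadcast bias add: tid % dim selects the column.
rows, dim = 3, 4
data = [float(i) for i in range(rows * dim)]   # 3x4 matrix, row-major
bias = [10.0, 20.0, 30.0, 40.0]

for tid in range(rows * dim):                  # one "thread" per element
    data[tid] += bias[tid % dim]               # in place, no output buffer

# Every row received the same bias vector.
assert data[0:4] == [10.0, 21.0, 32.0, 43.0]
assert data[4:8] == [14.0, 25.0, 36.0, 47.0]
```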

QKV Split

After the fused QKV linear projection, the output is a single buffer of shape [seq_len, q_dim + kv_dim + kv_dim]. The QKV split kernel separates it into three buffers:

kernel void qkv_split_f16(
    device const half       *src    [[buffer(0)]],
    device half             *dst_q  [[buffer(1)]],
    device half             *dst_k  [[buffer(2)]],
    device half             *dst_v  [[buffer(3)]],
    constant QKVSplitParams &params [[buffer(4)]],
    uint tid [[thread_position_in_grid]]
) {
    const uint total = params.seq_len * params.qkv_dim;
    if (tid >= total) return;

    const uint t   = tid / params.qkv_dim;
    const uint col = tid % params.qkv_dim;

    half val = src[tid];

    if (col < params.q_dim) {
        dst_q[t * params.q_dim + col] = val;
    } else if (col < params.q_dim + params.kv_dim) {
        dst_k[t * params.kv_dim + (col - params.q_dim)] = val;
    } else {
        dst_v[t * params.kv_dim + (col - params.q_dim - params.kv_dim)] = val;
    }
}

Visually:

Source: [seq_len, q_dim + kv_dim + kv_dim]

Row t:
+-------------------------------+------------------+------------------+
|   Q columns (0..q_dim-1)      | K (q_dim..q_dim  | V (q_dim+kv_dim  |
|                               |   +kv_dim-1)     |   ..end)         |
+-------------------------------+------------------+------------------+
           |                            |                    |
           v                            v                    v
    dst_q[t, 0..q_dim-1]      dst_k[t, 0..kv_dim-1]  dst_v[t, 0..kv_dim-1]

The branch (if col < q_dim ... else if ... else) means threads in the same SIMD group may take different paths, causing some divergence. However, since the branch is determined purely by col (which varies across threads in a predictable pattern), the divergence is minimal – large contiguous blocks of threads all take the same branch.

For GQA models where kv_dim < q_dim, the Q region is much larger. With n_heads=32, n_kv_heads=8, head_dim=128:

q_dim  = 32 * 128 = 4096
kv_dim =  8 * 128 = 1024
qkv_dim = 4096 + 1024 + 1024 = 6144

Threads 0-4095:    write to dst_q     (66.7%)
Threads 4096-5119: write to dst_k     (16.7%)
Threads 5120-6143: write to dst_v     (16.7%)
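A scaled-down CPU reference (Python, illustrative dims with head_dim = 1) shows the routing:

```python
# CPU reference of the QKV split with toy GQA dims:
# head_dim = 1, so q_dim = 4, kv_dim = 1, qkv_dim = 6.
q_dim, kv_dim = 4, 1
qkv_dim = q_dim + 2 * kv_dim
seq_len = 2
src = list(range(seq_len * qkv_dim))    # each row t is [Q... | K | V]

Q = [0] * (seq_len * q_dim)
K = [0] * (seq_len * kv_dim)
V = [0] * (seq_len * kv_dim)
for tid in range(seq_len * qkv_dim):    # one "thread" per source element
    t, col = divmod(tid, qkv_dim)
    val = src[tid]
    if col < q_dim:
        Q[t * q_dim + col] = val
    elif col < q_dim + kv_dim:
        K[t * kv_dim + (col - q_dim)] = val
    else:
        V[t * kv_dim + (col - q_dim - kv_dim)] = val

assert Q == [0, 1, 2, 3, 6, 7, 8, 9]
assert K == [4, 10]
assert V == [5, 11]
```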

Transpose

A straightforward 2D transpose:

kernel void transpose_f16(
    device const half        *input  [[buffer(0)]],
    device half              *output [[buffer(1)]],
    constant ElementwiseParams &p    [[buffer(2)]],
    uint2 gid [[thread_position_in_grid]]
) {
    uint row = gid.y;
    uint col = gid.x;
    uint rows = p.count;
    uint cols = p._pad0;
    if (row >= rows || col >= cols) return;
    output[col * rows + row] = input[row * cols + col];
}

The classic output[col][row] = input[row][col] pattern. The dispatch grid is (ceil(cols/16), ceil(rows/16)) with threadgroup size (16, 16).

A production-quality transpose would typically use shared memory tiling to avoid uncoalesced writes (writing to output[col * rows + row] is strided when threads have consecutive col values). However, for the matrix sizes akunu typically transposes (attention scores, head rearrangement intermediates), the simple version is fast enough.
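The flat-index arithmetic matches the textbook nested-list transpose; a quick Python check:

```python
# Flat-index transpose: output[col * rows + row] = input[row * cols + col].
rows, cols = 2, 3
inp = [1, 2, 3,
       4, 5, 6]
out = [0] * (rows * cols)
for row in range(rows):
    for col in range(cols):
        out[col * rows + row] = inp[row * cols + col]

# Matches the nested-list transpose [[1,2,3],[4,5,6]] -> [[1,4],[2,5],[3,6]]
assert out == [1, 4, 2, 5, 3, 6]
```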

Head Rearrange

Attention requires reordering data between two layouts:

Position-major:  [seq_len, n_heads * head_dim]   (natural GEMM output)
Head-major:      [n_heads, seq_len, head_dim]    (what attention needs)

The forward kernel converts position-major to head-major:

kernel void head_rearrange_forward_f16(
    device const half      *src    [[buffer(0)]],
    device half            *dst    [[buffer(1)]],
    constant uint          &seq    [[buffer(2)]],
    constant uint          &n_heads [[buffer(3)]],
    constant uint          &head_dim [[buffer(4)]],
    uint tid [[thread_position_in_grid]]
) {
    uint dim = n_heads * head_dim;
    uint total = seq * dim;
    if (tid >= total) return;

    uint pos  = tid / dim;
    uint rem  = tid % dim;
    uint head = rem / head_dim;
    uint d    = rem % head_dim;

    dst[head * seq * head_dim + pos * head_dim + d] = src[tid];
}

Visualized for seq=3, n_heads=2, head_dim=4:

Source (position-major):
  pos 0: [h0d0 h0d1 h0d2 h0d3 | h1d0 h1d1 h1d2 h1d3]
  pos 1: [h0d0 h0d1 h0d2 h0d3 | h1d0 h1d1 h1d2 h1d3]
  pos 2: [h0d0 h0d1 h0d2 h0d3 | h1d0 h1d1 h1d2 h1d3]

Destination (head-major):
  head 0: [pos0: d0 d1 d2 d3 | pos1: d0 d1 d2 d3 | pos2: d0 d1 d2 d3]
  head 1: [pos0: d0 d1 d2 d3 | pos1: d0 d1 d2 d3 | pos2: d0 d1 d2 d3]

The inverse kernel does the opposite transformation. Both are simple scatter/gather operations with integer division and modulo to compute source and destination indices.
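Both directions can be modeled on the CPU; the roundtrip must reproduce the input exactly. A Python sketch (illustrative, not akunu code), using the seq=3, n_heads=2, head_dim=4 example above:

```python
# CPU model of head_rearrange forward + inverse; roundtrip is the identity.
seq, n_heads, head_dim = 3, 2, 4
dim = n_heads * head_dim
src = list(range(seq * dim))     # position-major [seq, n_heads * head_dim]

fwd = [0] * (seq * dim)          # head-major [n_heads, seq, head_dim]
for tid in range(seq * dim):
    pos, rem = divmod(tid, dim)
    head, d = divmod(rem, head_dim)
    fwd[head * seq * head_dim + pos * head_dim + d] = src[tid]

# Head 0's block gathers dims 0..3 from every position:
# src elements 0..3, 8..11, 16..19.
assert fwd[:head_dim] == [0, 1, 2, 3]
assert fwd[head_dim:2 * head_dim] == [8, 9, 10, 11]

inv = [0] * (seq * dim)          # inverse: same index math, read from fwd
for tid in range(seq * dim):
    pos, rem = divmod(tid, dim)
    head, d = divmod(rem, head_dim)
    inv[tid] = fwd[head * seq * head_dim + pos * head_dim + d]

assert inv == src
```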

Performance Characteristics

Let us categorize these kernels by their computational profile:

+---------------------------+----------+--------+-----------+
| Kernel                    | Type     | AI     | Bound by  |
+---------------------------+----------+--------+-----------+
| embedding_lookup_f16      | Gather   | 0      | Memory    |
| embedding_lookup_q4_0     | Gather   | ~0.5   | Memory    |
| silu_f16                  | Map      | ~2     | Memory    |
| silu_gate_f16             | Map      | ~3     | Memory    |
| gelu_f16                  | Map      | ~5     | Memory    |
| gelu_gate_f16             | Map      | ~6     | Memory    |
| vector_add_f16            | Map      | ~0.3   | Memory    |
| residual_add_f16          | Map      | ~0.3   | Memory    |
| bias_add_f16              | Map      | ~0.5   | Memory    |
| qkv_split_f16             | Scatter  | 0      | Memory    |
| transpose_f16             | Scatter  | 0      | Memory    |
| head_rearrange_*_f16      | Scatter  | 0      | Memory    |
+---------------------------+----------+--------+-----------+

AI = Arithmetic Intensity (FLOPs per byte transferred)

Every single one is memory-bound. The most compute-intensive is GELU (which involves a tanh, several multiplies, and an add), but even that has an arithmetic intensity well below what Apple Silicon can sustain. The dominant cost is moving bytes to and from main memory.

This is why fusion matters so much. Fusing silu_gate saves an entire buffer round-trip. Fusing residual_rmsnorm eliminates a data pass. Each fusion does not speed up the math – it reduces the memory traffic.

The Data Flow Through Inference

Let us trace how these kernels fit together in a single decoder layer:

Token IDs
    |
    v
[embedding_lookup_f16 / q4_0]  --> hidden[seq_len, dim]
    |
    v
[residual_rmsnorm_f16]          --> norm_out[seq_len, dim]  (+ res_out for skip)
    |
    v
[GEMM: QKV projection]         --> qkv[seq_len, q_dim + 2*kv_dim]
    |
    v
[qkv_split_f16]                --> Q[seq, q_dim], K[seq, kv_dim], V[seq, kv_dim]
    |     (or fused into rope_*_kv_write for decode)
    v
[rope_neox_f16 / rope_f16]     --> Q', K' (rotated)
    |
    v
[head_rearrange_forward_f16]   --> Q'[n_heads, seq, hd], K'[n_kv, seq, hd]
    |
    v
[Attention GEMM + Softmax]     --> attn_out[n_heads, seq, hd]
    |
    v
[head_rearrange_inverse_f16]   --> attn_out[seq, dim]
    |
    v
[GEMM: output projection]      --> proj_out[seq, dim]
    |
    v
[bias_add_f16]                  --> proj_out += bias   (if model has bias)
    |
    v
[residual_add_f16]              --> hidden = proj_out + res_out
    |
    v
[residual_rmsnorm_f16]          --> norm_out   (for FFN)
    |
    v
[GEMM: gate + up projection]   --> gate[seq, ff_dim], up[seq, ff_dim]
    |
    v
[silu_gate_f16 / gelu_gate_f16]  --> activated[seq, ff_dim]
    |
    v
[GEMM: down projection]        --> ffn_out[seq, dim]
    |
    v
[residual_add_f16]              --> hidden = ffn_out + res_out
    |
    v
(next layer)

The embedding lookup runs once at the beginning. Everything else repeats for each layer. The GEMMs dominate runtime (they are the only compute-bound kernels), while everything else fills in the gaps. But collectively, these “gap” kernels add up – for a 32-layer model, you might have 200+ non-GEMM dispatches per token. Each one needs to be as lean as possible.

Summary

Embedding and activation kernels are the connective tissue of the inference pipeline:

  1. Embedding lookup – pure gather from a table. FP16 is a simple copy, Q4_0 dequantizes on-the-fly with nibble extraction. BF16 variant for M4+ hardware.
  2. SiLU/GELU activations – one thread per element, hardware transcendentals. Fused gate variants (silu_gate, gelu_gate) eliminate a temporary buffer. GELU gate uses FP32 intermediate to prevent overflow.
  3. Utility kernels – vector add, residual add, bias add, QKV split, transpose, head rearrange. All one-thread-per-element, all memory-bound, all essential for moving data between the big GEMMs in the right layout.

The recurring theme: these kernels are individually trivial but collectively critical. They are all memory-bound, so the optimization strategy is always the same – minimize the number of dispatches and the number of memory round-trips. Fusion is the primary weapon.