Whisper: Audio Transcription
Whisper is an encoder-decoder transformer trained by OpenAI for automatic speech recognition.¹ It is the odd one out in Akunu’s model lineup – every other model is a decoder-only LLM, but Whisper has a full encoder (for audio) and a decoder (for text) connected by cross-attention. This chapter covers how Akunu implements Whisper, from mel spectrogram computation all the way through beam search decoding, including the fused Metal kernels that make single-token decode fast enough for real-time transcription on Apple Silicon.
Architecture Overview
Whisper follows the classic encoder-decoder transformer design, with a Conv1D audio frontend replacing the standard token embedding:
Audio (PCM 16kHz)
|
v
Mel Spectrogram (80 x 3000)
|
v
+------------------------+
| Conv1D (kernel=3, s=1) | n_mels -> enc_dim, same length
| GELU |
+------------------------+
|
v
+------------------------+
| Conv1D (kernel=3, s=2) | enc_dim -> enc_dim, length / 2
| GELU |
+------------------------+
|
v
Transpose + Positional Embedding
|
v
+----------------------------------+
| Encoder Transformer x enc_layers |
| LayerNorm -> Self-Attention |
| Residual |
| LayerNorm -> FFN (GELU) |
| Residual |
+----------------------------------+
|
v
Final LayerNorm -> encoder_output [1500, enc_dim]
|
| precompute cross K/V for all decoder layers
|
v
+----------------------------------+
| Decoder Transformer x dec_layers |
| LayerNorm -> Self-Attention | (causal, with KV cache)
| Residual |
| LayerNorm -> Cross-Attention | (static K/V from encoder)
| Residual |
| LayerNorm -> FFN (GELU) |
| Residual |
+----------------------------------+
|
v
Final LayerNorm -> Logit Projection -> argmax/beam search
Key differences from a decoder-only LLM:
- LayerNorm instead of RMSNorm, with both weight and bias
- Bias terms on all linear projections (except key projections)
- GELU activation (not SiLU/SwiGLU) with no gate
- Sinusoidal positional embeddings in the encoder and learned positional embeddings in the decoder (no RoPE)
- Cross-attention in every decoder layer
- Tied embeddings – the output projection reuses the embedding matrix
Audio Preprocessing: The Mel Spectrogram
Before any neural network computation happens, raw audio must be converted to a log-mel spectrogram. This is Akunu’s MelSpectrogram class in src/audio/mel.h.
Parameters
| Parameter | Whisper value | Description |
|---|---|---|
| Sample rate | 16000 Hz | Input audio must be 16kHz mono |
| n_fft | 400 | FFT window size (25ms at 16kHz) |
| hop_length | 160 | Stride between windows (10ms) |
| n_mels | 80 | Number of mel frequency bands |
| n_frames | 3000 | Output frames (30 seconds of audio) |
The pipeline:
PCM float samples (480,000 for 30s @ 16kHz)
|
v
+------------------------------+
| Hann window (n_fft=400)      |
| Zero-pad to 512 (power of 2) |
| vDSP FFT (radix-2)           |
| Power spectrum |X(f)|^2      |
+------------------------------+
| repeat for each frame (hop=160)
v
Spectrogram: [3000 frames, 201 freq bins]
|
v
+------------------------------+
| Mel filterbank (cblas_sgemm) |
| [80, 201] x [3000, 201]^T    |
+------------------------------+
|
v
Log-mel: [80, 3000]
|
v
+----------------------------+
| Clamp to 1e-10, log10 |
| Dynamic range: max - 8.0 |
| Scale: (val + 4.0) / 4.0 |
+----------------------------+
|
v
Normalized mel spectrogram [80, 3000]
FFT via Accelerate
Akunu uses Apple’s Accelerate framework for the FFT, specifically vDSP_fft_zrip (in-place radix-2 FFT on split-complex data). The n_fft of 400 is zero-padded to 512 (next power of 2) for the FFT:
ms.n_fft_padded = 512; // next power of 2 >= 400
ms.log2n = 9; // log2(512)
ms.fft_setup = vDSP_create_fftsetup(ms.log2n, FFT_RADIX2);
One subtle detail: vDSP_fft_zrip returns output scaled by 2x relative to the standard DFT definition. Squaring for the power spectrum turns that 2x into 4x, so the code compensates with a scale factor of 0.25:
float scale = 0.25f; // vDSP output is 2x -> power is 4x; 1/4 undoes it
mag_row[0] = split.realp[0] * split.realp[0] * scale;
Mel Filterbank
The mel filterbank is a [80, 201] matrix that maps the 201 FFT frequency bins to 80 mel-spaced bands. It is constructed using the standard HTK mel scale:²
hz_to_mel(f) = 2595 * log10(1 + f/700)
mel_to_hz(m) = 700 * (10^(m/2595) - 1)
The filterbank application is a single matrix multiply using cblas_sgemm:
// mel_filters: [80, 201]   magnitudes: [3000, 201]
// output: [80, 3000] = mel_filters @ magnitudes^T
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
            n_mels, n_frames, n_freq,
            1.0f, mel_filters.data(), n_freq,
            magnitudes.data(), n_freq,
            0.0f, mel.data(), n_frames);
This is one of the rare places in Akunu where CPU-side BLAS is used. The mel spectrogram is small enough that GPU dispatch overhead would exceed the computation time.
Log-Mel Normalization
The final normalization matches OpenAI’s Python implementation exactly:
1. log10(max(mel, 1e-10)) -- log scale with floor
2. clamp to (max_value - 8.0) -- 80 dB dynamic range
3. (value + 4.0) / 4.0 -- normalize to roughly [-1, 1]
This normalization is critical – using a different scheme (e.g., log2 instead of log10, or different clipping) will produce garbled output because the model was trained with exactly this preprocessing.
The Encoder
The encoder processes the mel spectrogram through a Conv1D frontend followed by a standard transformer.
Conv1D Frontend
Two 1D convolution layers reduce the temporal resolution by 2x:
Conv1: (80, 3000) -> (enc_dim, 3000) kernel=3, stride=1, pad=1
GELU activation
Conv2: (enc_dim, 3000) -> (enc_dim, 1500) kernel=3, stride=2, pad=1
GELU activation
These run as custom Metal kernels (conv1d_gelu_f32in_f16 and conv1d_gelu_f16). The first conv takes F32 input (from the mel spectrogram) and outputs F16; the second is F16 throughout. Both fuse the GELU activation into the convolution kernel to save a dispatch.
After convolution, the output is transposed from channel-first [enc_dim, 1500] to sequence-first [1500, enc_dim] for the transformer. A sinusoidal positional embedding is added:
// Transpose [enc_dim, enc_seq] -> [enc_seq, enc_dim]
dispatch("transpose_f16", ...);
// Add positional embedding
enc_add(dev, wb.enc_h1, wb.enc_pos_embed, wb.enc_h0, enc_seq * enc_dim);
Encoder Transformer Layers
Each encoder layer follows the standard pre-norm transformer pattern, but with LayerNorm (not RMSNorm) and GELU activation (not SiLU):
input
|
+-----> LayerNorm(attn_ln) -> Q,K,V projections
| |
| Self-Attention (non-causal)
| |
| O projection
| |
+---- residual add <-----------+
|
+-----> LayerNorm(mlp_ln) -> FFN up (GELU) -> FFN down
| |
+---- residual add <-----------------------------+
|
v
output
The self-attention is non-causal – every position attends to every other position, since the entire audio is available at once. Akunu dispatches this using the prefill attention kernel with non-causal masking:
uint32_t fc_values[] = {(uint32_t)head_dim, NQ, 1}; // 1 = non-causal
Pipeline attn_pso = dev.get_pipeline(
    "flash_attention_prefill_f16",
    "attn_enc_prefill_nc_hd" + std::to_string(head_dim),
    fc_indices, fc_values, 3, fc_types);
The FFN uses plain GELU without a gate:
FFN(x) = down(GELU(up(x)))
This is simpler than the SwiGLU used in LLaMA/Qwen:
SwiGLU(x) = down(SiLU(gate(x)) * up(x))
No gate projection means the FFN has two weight matrices instead of three.
Head Rearrangement on GPU
A notable optimization: Akunu performs the Q/K/V head rearrangement entirely on GPU. The GEMM outputs are [seq, n_heads * head_dim] (all heads concatenated), but the attention kernel expects [n_heads, seq, head_dim] (head-major). Rather than CPU-side transpose, dedicated kernels handle this:
// [seq, n_heads*head_dim] -> [n_heads, seq, head_dim]
enc_head_rearrange_forward(dev, wb.enc_q, wb.enc_attn_out,
enc_seq, n_heads, head_dim);
// After attention: [n_heads, seq, head_dim] -> [seq, n_heads*head_dim]
enc_head_rearrange_inverse(dev, wb.enc_q, wb.enc_attn_out,
enc_seq, n_heads, head_dim);
This avoids a GPU-CPU sync that would kill pipeline parallelism.
Cross-Attention Precomputation
This is the single most important optimization for Whisper performance. The encoder output is fixed for the entire decoding process – it does not change between tokens. That means the cross-attention K and V projections can be computed once and reused for every decoder step.
static void precompute_cross_kv(Device& dev, WhisperModel& wm,
                                const AkunuModelConfig& cfg,
                                WhisperBuffers& wb, int enc_seq)
{
    dev.begin_encoding();
    for (int l = 0; l < n_dec; l++) {
        // K projection: encoder_output @ cross_attn.key.weight^T (no bias)
        enc_gemm(dev, wb.encoder_output,
                 wm.get_tensor("decoder.blocks.%d.cross_attn.key.weight", l),
                 wb.cross_k[l], enc_seq, dec_dim, enc_dim);
        // V projection: encoder_output @ cross_attn.value.weight^T + bias
        enc_gemm(dev, wb.encoder_output,
                 wm.get_tensor("decoder.blocks.%d.cross_attn.value.weight", l),
                 wb.cross_v[l], enc_seq, dec_dim, enc_dim);
        enc_bias_add(dev, wb.cross_v[l], vb, enc_seq, dec_dim);
    }
    dev.end_encoding_sync();

    // Rearrange to head-major [n_heads, enc_seq, head_dim].
    // s = row-major [enc_seq, dec_dim] projection output for layer l,
    // d = its head-major destination buffer.
    for (int l = 0; l < n_dec; l++) {
        // CPU rearrange on UMA is fast after the single GPU sync
        for (int h = 0; h < n_heads; h++)
            for (int p = 0; p < enc_seq; p++)
                memcpy(&d[(h*enc_seq + p)*head_dim],
                       &s[p*dec_dim + h*head_dim],
                       head_dim * sizeof(__fp16));
    }
}
For Whisper Large (enc_seq=1500, dec_dim=1280, 32 decoder layers), this precomputation involves:
GEMM per layer: 2 x (1500 x 1280 x 1280) ≈ 4.9 GFLOPs (K and V, counting a multiply-add as one op)
Total (32 layers): 32 x 4.9 ≈ 157 GFLOPs
Time on M2 Pro: ~15ms (one-time cost)
Without precomputation:
Per token: the full ~157 GFLOPs of K/V projection work over
all 1500 encoder positions would be redone every decoder step
Overhead: 2 extra GEMM dispatches per layer per token
= 64 extra kernel dispatches per token
With precomputation:
Per token: 0 GEMMs, cross-K and cross-V are static buffers
= just the cross-attention dot product
The savings are not just in FLOPS – eliminating 64 kernel dispatches per token significantly reduces command buffer overhead.
Note that the K/V rearrangement to head-major layout is done on CPU after a single GPU sync. On Apple Silicon’s unified memory, this memcpy-based rearrangement completes in microseconds because the data is already in coherent memory.
The Decoder
The decoder uses a dispatch table, just like the LLM decoder. But it has additional complexity: cross-attention, positional embeddings (instead of RoPE), and Whisper-specific token suppression.
Dispatch Table Structure
The Whisper decode table has the following command sequence per token:
1. embedding_lookup_f16 (token -> hidden state)
2. pos_embed_add_f16 (add positional encoding)
For each layer (~16 commands/layer with fused kernels):
3. layernorm_f16 (self-attention norm)
4. whisper_gemv_bias_f16 (Q projection, fused with bias)
5. gemv_f16 (K projection, no bias)
6. whisper_gemv_bias_f16 (V projection, fused with bias)
7. kv_cache_write_f16 x2 (write K,V to cache, no RoPE)
8. flash_attention_decode_f16 (self-attention, causal)
9. whisper_gemv_bias_res_f16 (O + bias + residual, fused)
10. layernorm_f16 (cross-attention norm)
11. whisper_gemv_bias_f16 (cross-Q + bias)
12. flash_attention_decode_f16 (cross-attention, static K/V)
13. whisper_gemv_bias_res_f16 (cross-O + bias + residual)
14. layernorm_f16 (FFN norm)
15. whisper_gemv_bias_gelu_f16 (FFN up + bias + GELU, fused)
16. whisper_gemv_bias_res_f16 (FFN down + bias + residual)
17. layernorm_f16 (output norm)
18. gemv_f16 (logit projection)
19. whisper_suppress_f16 (suppress special tokens)
20. argmax_f16 (greedy decode)
Fused Whisper Kernels
The decoder benefits enormously from fused kernels that combine operations that would otherwise require separate dispatches:
| Fused kernel | Operations combined | Dispatches saved |
|---|---|---|
| whisper_gemv_bias_f16 | GEMV + bias add | 2 -> 1 |
| whisper_gemv_bias_gelu_f16 | GEMV + bias + GELU | 3 -> 1 |
| whisper_gemv_bias_res_f16 | GEMV + bias + residual add | 3 -> 1 |
Without fusion, each decoder layer would need ~26 dispatches. With fusion, it is ~16. For Whisper Large with 32 layers, that is:
Unfused: 2 + 32*26 + 4 = 838 dispatches/token
Fused: 2 + 32*16 + 4 = 518 dispatches/token
Savings: 320 dispatches/token (38% reduction)
On Apple Silicon, each dispatch has a fixed overhead of roughly 1-3 microseconds for command encoding. At 518 dispatches, that is ~0.5-1.5ms of pure overhead – significant when the target is real-time transcription.
The fused kernels gracefully degrade if the specialized Metal function is not available (e.g., on an older metallib):
static void w_emit_gemv_bias(DispatchTable& tbl, Device& dev, ...) {
    Pipeline pso = dev.get_pipeline("whisper_gemv_bias_f16");
    if (!pso.handle) {
        // Fallback: separate GEMV + bias add
        w_emit_gemv(tbl, dev, in, weight, out, out_off, N, K);
        w_emit_bias(tbl, dev, out, out_off, bias, N);
        return;
    }
    // Fused dispatch
    // ...
}
KV Cache Without RoPE
Unlike LLMs, Whisper’s decoder uses learned positional embeddings added to the token embedding (the encoder uses fixed sinusoids), not RoPE applied to Q/K at each layer. This means the KV cache write kernel is simpler – it just stores the projected K and V without any positional rotation:
// KV write: no RoPE (function constant = false)
uint32_t kv_fc_v[] = {0}; // 0 = no positional rotation
uint32_t kv_fc_t[] = {1}; // bool type
Pipeline kv_pso = dev.get_pipeline("kv_cache_write_f16", "kv_write_nopos",
                                   kv_fc_i, kv_fc_v, 1, kv_fc_t);
Cross-Attention
Cross-attention in the decoder uses the precomputed K/V buffers from the encoder. These are static – they do not change between tokens – so there is no KV cache write for cross-attention:
// Cross-attention: static K/V, fixed kv_seq_len = enc_seq
attn_params params = {
    .seq_len = 1,
    .kv_seq_len = (uint32_t)enc_seq, // always 1500
    // ...
};
cmd.add_buffer(wb.cross_k[l], 0, 1); // precomputed K
cmd.add_buffer(wb.cross_v[l], 0, 2); // precomputed V
This is why cross-attention precomputation matters – the GEMV for projecting K and V would otherwise happen at every decoder step.
Token Suppression
Whisper has special tokens (timestamps, language tags, task tokens) that should be suppressed during normal text generation. Akunu handles this with a GPU-side suppression kernel:
// Suppress special tokens in [first_special, suppress_end);
// timestamp tokens are left alone for timestamps mode
struct {
    uint32_t first_special, suppress_end, eot, suppress_blank;
} sp = {wdp.first_special, suppress_end, wdp.eot, 0};
dispatch("whisper_suppress_f16", n_suppress, 256);
The kernel sets logits for suppressed tokens to negative infinity, effectively removing them from consideration by the argmax or sampling step. Timestamps are not suppressed, which allows the model to output timestamp tokens when running in timestamps mode.
The ArchDescriptor for Whisper
Whisper’s unique properties are captured in the ArchDescriptor:
inline ArchDescriptor arch_whisper() {
    ArchDescriptor d = {};
    d.activation_kernel = "gelu_f16"; // plain GELU, no gate
    d.embedding_scale = 0.0f;
    d.has_qk_norm = false;
    d.rope_kernel = nullptr;          // no RoPE
    d.tie_embeddings = true;          // output = embedding^T
    d.is_encoder_decoder = true;
    d.has_cross_attention = true;
    d.has_conv_frontend = true;
    d.has_bias = true;                // biases on linears (except K projections)
    d.norm_type = "layernorm";        // not rmsnorm
    d.encoder_activation = "gelu_f16";
    return d;
}
The is_encoder_decoder and has_cross_attention flags tell the initialization code to allocate cross-attention buffers, run the encoder, and build the decoder dispatch table with cross-attention commands.
Model Loading
Whisper models in Akunu are loaded from whisper.cpp’s custom binary format (magic: "lmgg" or "ggjt"), not GGUF. The WhisperModel struct holds all tensor data on GPU:
File structure:
4 bytes: magic ("lmgg" or "ggjt")
44 bytes: hyperparameters (11 x int32)
mel filters (precomputed filterbank from file)
vocabulary (token strings)
tensor data (name + dims + dtype + raw data)
The hyperparameters encode both encoder and decoder dimensions:
| Index | Field | Example (Large) |
|---|---|---|
| 0 | vocab_size | 51865 |
| 1 | n_audio_ctx (enc_seq) | 1500 |
| 2 | n_audio_state (enc_dim) | 1280 |
| 3 | n_audio_head | 20 |
| 4 | n_audio_layer | 32 |
| 5 | n_text_ctx (dec_seq) | 448 |
| 6 | n_text_state (dec_dim) | 1280 |
| 7 | n_text_head | 20 |
| 8 | n_text_layer | 32 |
| 9 | n_mels | 80 |
| 10 | ftype (quant format) | 1 (F16) |
End-to-End Flow
Putting it all together, here is the complete flow for transcribing 30 seconds of audio:
1. Audio input (480,000 float samples at 16kHz)
|
2. MelSpectrogram::compute() [CPU, ~5ms]
| vDSP FFT + cblas_sgemm + log normalization
|
3. Upload mel to GPU buffer [UMA, ~0.1ms]
|
4. encode_whisper() [GPU, ~30ms]
| Conv1D x2 + 32 encoder layers
| Non-causal self-attention (1500 seq len)
|
5. precompute_cross_kv() [GPU, ~15ms]
| 64 GEMMs (K+V for each of 32 layers)
| CPU rearrange to head-major
|
6. Decode loop (greedy or beam search):
| For each output token:
| dispatch_table.execute() [GPU, ~3ms/token]
| Read argmax token ID
| Check for EOT
| Until EOT or max_tokens
|
7. Detokenize output tokens -> text string
For Whisper Large on M2 Pro, typical performance is:
Encoder: ~30ms (one-time)
Cross-KV: ~15ms (one-time)
Decode per token: ~3ms
Typical output: ~50 tokens for 30s of speech
Total decode: ~150ms
Total latency: ~200ms for 30 seconds of audio
Real-time factor: 0.007x (150x faster than real-time)
This makes Whisper on Apple Silicon more than fast enough for real-time streaming transcription, where audio arrives in chunks and the encoder/decoder pipeline overlaps with audio capture.
1. Radford et al., “Robust Speech Recognition via Large-Scale Weak Supervision,” OpenAI, 2022. Whisper was trained on 680,000 hours of weakly supervised audio-text pairs. See https://arxiv.org/abs/2212.04356.
2. The HTK mel scale (named after the Hidden Markov Model Toolkit) defines mel(f) = 2595 * log10(1 + f/700). An alternative “Slaney” definition uses a piecewise linear/log formula. Whisper uses the HTK definition with Slaney normalization of the filterbank triangles.