Skip to main content

whisperforge_core/
kv_cache.rs

1use anyhow::{Context, Result};
2use burn::tensor::{Int, Tensor, TensorData, backend::Backend, module::embedding};
3
4use crate::model::{Whisper, qkv_attention};
5
6/// Precomputed and growing KV cache for O(n) per-step decoder inference.
7///
8/// `cross_kv` holds the encoder cross-attention K,V projections for every decoder
9/// layer. These are constant for the entire decoding of one audio chunk and are
10/// computed once by [`KvCache::new`] before the decode loop.
11///
12/// `self_kv` holds the growing decoder self-attention K,V cache; one new row is
13/// appended per [`forward_decoder_cached`] call.
14pub struct KvCache<B: Backend> {
15    /// Per-layer static cross-attention K,V from the encoder output.
16    cross_kv: Vec<(Tensor<B, 3>, Tensor<B, 3>)>,
17    /// Per-layer growing self-attention K,V; `None` until the first decode step.
18    self_kv: Vec<Option<(Tensor<B, 3>, Tensor<B, 3>)>>,
19    /// Number of tokens decoded so far (indexes into positional embedding).
20    pub step: usize,
21}
22
23impl<B: Backend> KvCache<B> {
24    /// Build a cache pre-populated with cross-attention K,V from `encoder_output`.
25    ///
26    /// The encoder K,V projections (`n_layers × 2` linear ops) are computed once
27    /// here and reused at every subsequent decode step instead of being recomputed
28    /// each time `forward_decoder` is called from scratch.
29    ///
30    /// # Performance
31    ///
32    /// Call this once per audio chunk. The cost is `n_decoder_layers × 2`
33    /// matrix multiplications of shape `[n_encoder_frames × n_text_state]`.
34    pub fn new(model: &Whisper<B>, encoder_output: Tensor<B, 3>) -> Self {
35        let cross_kv = model
36            .decoder
37            .blocks
38            .iter()
39            .map(|block| {
40                let ca = &block.cross_attn;
41                let k = ca.key.forward(encoder_output.clone());
42                let v = ca.value.forward(encoder_output.clone());
43                (k, v)
44            })
45            .collect::<Vec<_>>();
46
47        let n_layers = cross_kv.len();
48        Self {
49            cross_kv,
50            self_kv: vec![None; n_layers],
51            step: 0,
52        }
53    }
54}
55
56/// Decode one new token using the KV cache, updating the cache in-place.
57///
58/// Returns vocabulary logits for the new position as a flat `Vec<f32>` of
59/// length `n_vocab`.
60///
61/// # Why this is faster than `forward_decoder`
62///
63/// `forward_decoder` processes the **full growing token sequence** at every step
64/// (O(n) work per step → O(n²) total). This function processes only the **single
65/// new token** and reads cached K,V for past positions (O(1) per step → O(n)
66/// total). The cross-attention K,V are also cached across steps (constant).
67///
68/// # Correctness
69///
70/// The decoder causal mask is not needed when Q has sequence length 1 — a single
71/// query position can attend to all cached K,V without violating causality.
72pub fn forward_decoder_cached<B: Backend>(
73    model: &Whisper<B>,
74    token: u32,
75    cache: &mut KvCache<B>,
76    device: &B::Device,
77) -> Result<Vec<f32>> {
78    let decoder = &model.decoder;
79    let step = cache.step;
80
81    // Embed the single new token and add its positional encoding. [1, 1, n_text_state]
82    let token_tensor: Tensor<B, 2, Int> =
83        Tensor::from_data(TensorData::new(vec![token as i32], [1, 1]), device);
84    let mut x = embedding(decoder.token_embedding.val(), token_tensor)
85        + decoder
86            .positional_embedding
87            .val()
88            .slice([step..(step + 1)])
89            .unsqueeze::<3>();
90
91    for (layer_idx, block) in decoder.blocks.iter().enumerate() {
92        // --- Self-attention with growing KV cache ---
93        let x_norm = block.attn_ln.forward(x.clone());
94        let sa = &block.attn;
95
96        // Project only the new token into Q, K, V. [1, 1, n_text_state]
97        let q = sa.query.forward(x_norm.clone());
98        let k_new = sa.key.forward(x_norm.clone());
99        let v_new = sa.value.forward(x_norm);
100
101        // Concatenate new K,V with accumulated cache. [1, step+1, n_text_state]
102        let (k_full, v_full) = match cache.self_kv[layer_idx].take() {
103            Some((k_prev, v_prev)) => (
104                Tensor::cat(vec![k_prev, k_new], 1),
105                Tensor::cat(vec![v_prev, v_new], 1),
106            ),
107            None => (k_new, v_new),
108        };
109        cache.self_kv[layer_idx] = Some((k_full.clone(), v_full.clone()));
110
111        // Attention: Q[1,1,D] × K[1,step+1,D] — no mask needed (Q len = 1).
112        let sa_out = qkv_attention(q, k_full, v_full, None, sa.n_head);
113        x = x + sa.out.forward(sa_out);
114
115        // --- Cross-attention: reuse static K,V precomputed from encoder output ---
116        let x_norm = block.cross_attn_ln.forward(x.clone());
117        let ca = &block.cross_attn;
118        let q = ca.query.forward(x_norm);
119        let (k_cross, v_cross) = &cache.cross_kv[layer_idx];
120        let ca_out = qkv_attention(q, k_cross.clone(), v_cross.clone(), None, ca.n_head);
121        x = x + ca.out.forward(ca_out);
122
123        // --- Feed-forward ---
124        x = x.clone() + block.mlp.forward(block.mlp_ln.forward(x));
125    }
126
127    cache.step += 1;
128
129    // Final layer norm + project to vocabulary. [1, 1, n_vocab] → [n_vocab]
130    let x = decoder.ln.forward(x);
131    let logits = x.matmul(decoder.token_embedding.val().transpose().unsqueeze::<3>());
132
133    let [_, _, vocab_size] = logits.dims();
134    logits
135        .squeeze::<1>()
136        .into_data()
137        .to_vec::<f32>()
138        .map_err(|e| anyhow::anyhow!("logit extraction failed: {:?}", e))
139        .with_context(|| format!("forward_decoder_cached step {step}, vocab_size={vocab_size}"))
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::{Distribution, Int, TensorData};
    use burn_flex::{Flex, FlexDevice};

    /// Backend used by every test in this module.
    type TestBackend = Flex<f32>;

    /// Shape of the encoder output fed to the tests: `[batch, frames, state]`.
    const ENCODER_SHAPE: [usize; 3] = [1, 1500, 384];

    /// Build a `tiny_en` model with freshly initialised (untrained) weights.
    fn tiny_en_random() -> (crate::model::Whisper<TestBackend>, FlexDevice) {
        let device = FlexDevice;
        let model = crate::model::WhisperConfig::tiny_en().init::<TestBackend>(&device);
        (model, device)
    }

    /// Largest element-wise absolute difference between two equally long slices.
    fn max_abs_diff(lhs: &[f32], rhs: &[f32]) -> f32 {
        lhs.iter()
            .zip(rhs)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max)
    }

    /// `step` starts at 0 and grows by exactly one per decoded token.
    #[test]
    fn test_kv_cache_step_counter() {
        let (model, device) = tiny_en_random();
        let encoder_out = Tensor::<TestBackend, 3>::zeros(ENCODER_SHAPE, &device);
        let mut cache = KvCache::new(&model, encoder_out);

        for (decoded_so_far, token) in [50258u32, 50259u32].into_iter().enumerate() {
            assert_eq!(cache.step, decoded_so_far);
            forward_decoder_cached(&model, token, &mut cache, &device).unwrap();
            assert_eq!(cache.step, decoded_so_far + 1);
        }
    }

    /// A single decode step yields one logit per vocabulary entry.
    #[test]
    fn test_kv_cache_logit_shape() {
        let (model, device) = tiny_en_random();
        let encoder_out = Tensor::<TestBackend, 3>::zeros(ENCODER_SHAPE, &device);
        let mut cache = KvCache::new(&model, encoder_out);

        let logits = forward_decoder_cached(&model, 50258u32, &mut cache, &device).unwrap();

        assert_eq!(logits.len(), 51864);
    }

    /// The cached decoder must agree with `forward_decoder` on the logits of
    /// the last position when both see the same tokens and encoder output.
    ///
    /// Weights are random, so no model files are required. The 1e-4 tolerance
    /// allows for float32 rounding caused by the different operation order.
    #[test]
    fn test_kv_cache_matches_forward_decoder() {
        let (model, device) = tiny_en_random();
        let encoder_out = Tensor::<TestBackend, 3>::random(
            ENCODER_SHAPE,
            Distribution::Normal(0.0, 0.1),
            &device,
        );

        // Typical initial context: sot, en, transcribe, no_timestamps
        let init: [u32; 4] = [50258, 50259, 50359, 50363];

        // Reference path: the whole sequence through `forward_decoder` at once,
        // keeping only the logits of the final position.
        let token_ids = init.map(|t| t as i32).to_vec();
        let token_tensor: Tensor<TestBackend, 2, Int> =
            Tensor::from_data(TensorData::new(token_ids, [1, 4]), &device);
        let logits_full = model.forward_decoder(token_tensor, encoder_out.clone());
        let [b, seq, vocab] = logits_full.dims();
        let last_position = logits_full.slice([0..b, (seq - 1)..seq, 0..vocab]);
        let orig: Vec<f32> = last_position
            .squeeze::<1>()
            .into_data()
            .to_vec::<f32>()
            .unwrap();

        // Cached path: feed the tokens one by one; each step's logits replace
        // the previous ones, so only the final step's logits remain.
        let mut cache = KvCache::new(&model, encoder_out);
        let cached = init.iter().fold(Vec::new(), |_, &tok| {
            forward_decoder_cached(&model, tok, &mut cache, &device).unwrap()
        });

        assert_eq!(orig.len(), cached.len());
        let max_diff = max_abs_diff(&orig, &cached);
        assert!(
            max_diff < 1e-4,
            "KV-cached logits diverge from forward_decoder by {max_diff:.2e} (expected < 1e-4)"
        );
    }
}