1use anyhow::{Context, Result};
2use burn::tensor::{Int, Tensor, TensorData, backend::Backend, module::embedding};
3
4use crate::model::{Whisper, qkv_attention};
5
6pub struct KvCache<B: Backend> {
15 cross_kv: Vec<(Tensor<B, 3>, Tensor<B, 3>)>,
17 self_kv: Vec<Option<(Tensor<B, 3>, Tensor<B, 3>)>>,
19 pub step: usize,
21}
22
23impl<B: Backend> KvCache<B> {
24 pub fn new(model: &Whisper<B>, encoder_output: Tensor<B, 3>) -> Self {
35 let cross_kv = model
36 .decoder
37 .blocks
38 .iter()
39 .map(|block| {
40 let ca = &block.cross_attn;
41 let k = ca.key.forward(encoder_output.clone());
42 let v = ca.value.forward(encoder_output.clone());
43 (k, v)
44 })
45 .collect::<Vec<_>>();
46
47 let n_layers = cross_kv.len();
48 Self {
49 cross_kv,
50 self_kv: vec![None; n_layers],
51 step: 0,
52 }
53 }
54}
55
56pub fn forward_decoder_cached<B: Backend>(
73 model: &Whisper<B>,
74 token: u32,
75 cache: &mut KvCache<B>,
76 device: &B::Device,
77) -> Result<Vec<f32>> {
78 let decoder = &model.decoder;
79 let step = cache.step;
80
81 let token_tensor: Tensor<B, 2, Int> =
83 Tensor::from_data(TensorData::new(vec![token as i32], [1, 1]), device);
84 let mut x = embedding(decoder.token_embedding.val(), token_tensor)
85 + decoder
86 .positional_embedding
87 .val()
88 .slice([step..(step + 1)])
89 .unsqueeze::<3>();
90
91 for (layer_idx, block) in decoder.blocks.iter().enumerate() {
92 let x_norm = block.attn_ln.forward(x.clone());
94 let sa = &block.attn;
95
96 let q = sa.query.forward(x_norm.clone());
98 let k_new = sa.key.forward(x_norm.clone());
99 let v_new = sa.value.forward(x_norm);
100
101 let (k_full, v_full) = match cache.self_kv[layer_idx].take() {
103 Some((k_prev, v_prev)) => (
104 Tensor::cat(vec![k_prev, k_new], 1),
105 Tensor::cat(vec![v_prev, v_new], 1),
106 ),
107 None => (k_new, v_new),
108 };
109 cache.self_kv[layer_idx] = Some((k_full.clone(), v_full.clone()));
110
111 let sa_out = qkv_attention(q, k_full, v_full, None, sa.n_head);
113 x = x + sa.out.forward(sa_out);
114
115 let x_norm = block.cross_attn_ln.forward(x.clone());
117 let ca = &block.cross_attn;
118 let q = ca.query.forward(x_norm);
119 let (k_cross, v_cross) = &cache.cross_kv[layer_idx];
120 let ca_out = qkv_attention(q, k_cross.clone(), v_cross.clone(), None, ca.n_head);
121 x = x + ca.out.forward(ca_out);
122
123 x = x.clone() + block.mlp.forward(block.mlp_ln.forward(x));
125 }
126
127 cache.step += 1;
128
129 let x = decoder.ln.forward(x);
131 let logits = x.matmul(decoder.token_embedding.val().transpose().unsqueeze::<3>());
132
133 let [_, _, vocab_size] = logits.dims();
134 logits
135 .squeeze::<1>()
136 .into_data()
137 .to_vec::<f32>()
138 .map_err(|e| anyhow::anyhow!("logit extraction failed: {:?}", e))
139 .with_context(|| format!("forward_decoder_cached step {step}, vocab_size={vocab_size}"))
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::{Distribution, Int, TensorData};
    use burn_flex::Flex;
    use burn_flex::FlexDevice;

    /// Backend used by every test in this module.
    type TestBackend = Flex<f32>;

    /// Encoder output length (frames) fed to the decoder in these tests.
    const ENCODER_FRAMES: usize = 1500;
    /// Encoder hidden width assumed for the tiny.en configuration.
    const TINY_EN_WIDTH: usize = 384;
    /// Number of logits expected per step for the tiny.en configuration.
    const TINY_EN_VOCAB: usize = 51864;

    /// Builds a randomly initialised tiny.en model on the test device.
    fn tiny_en_random() -> (crate::model::Whisper<TestBackend>, FlexDevice) {
        let device = FlexDevice;
        let model = crate::model::WhisperConfig::tiny_en().init::<TestBackend>(&device);
        (model, device)
    }

    /// An all-zero encoder output of the shape the tests feed to `KvCache::new`.
    fn zero_encoder_output(device: &FlexDevice) -> Tensor<TestBackend, 3> {
        Tensor::zeros([1, ENCODER_FRAMES, TINY_EN_WIDTH], device)
    }

    #[test]
    fn test_kv_cache_step_counter() {
        let (model, device) = tiny_en_random();
        let mut cache = KvCache::new(&model, zero_encoder_output(&device));
        assert_eq!(cache.step, 0);

        for (expected_step, token) in [(1, 50258u32), (2, 50259u32)] {
            forward_decoder_cached(&model, token, &mut cache, &device).unwrap();
            assert_eq!(cache.step, expected_step);
        }
    }

    #[test]
    fn test_kv_cache_logit_shape() {
        let (model, device) = tiny_en_random();
        let mut cache = KvCache::new(&model, zero_encoder_output(&device));
        let logits = forward_decoder_cached(&model, 50258u32, &mut cache, &device).unwrap();
        assert_eq!(logits.len(), TINY_EN_VOCAB);
    }

    #[test]
    fn test_kv_cache_matches_forward_decoder() {
        let (model, device) = tiny_en_random();
        let encoder_out = Tensor::<TestBackend, 3>::random(
            [1, ENCODER_FRAMES, TINY_EN_WIDTH],
            Distribution::Normal(0.0, 0.1),
            &device,
        );

        let prompt: [u32; 4] = [50258, 50259, 50359, 50363];

        // Reference: one uncached pass over the whole prompt, keeping only the
        // logits at the last position.
        let ids: Vec<i32> = prompt.iter().map(|&t| t as i32).collect();
        let tokens: Tensor<TestBackend, 2, Int> =
            Tensor::from_data(TensorData::new(ids, [1, prompt.len()]), &device);
        let full = model.forward_decoder(tokens, encoder_out.clone());
        let [batch, seq_len, vocab] = full.dims();
        let reference: Vec<f32> = full
            .slice([0..batch, (seq_len - 1)..seq_len, 0..vocab])
            .squeeze::<1>()
            .into_data()
            .to_vec::<f32>()
            .unwrap();

        // Incremental: feed the prompt one token at a time, keeping the final logits.
        let mut cache = KvCache::new(&model, encoder_out);
        let mut incremental = Vec::new();
        for &token in prompt.iter() {
            incremental = forward_decoder_cached(&model, token, &mut cache, &device).unwrap();
        }

        assert_eq!(reference.len(), incremental.len());
        let max_diff = reference
            .iter()
            .zip(&incremental)
            .map(|(r, c)| (r - c).abs())
            .fold(0.0f32, f32::max);
        assert!(
            max_diff < 1e-4,
            "KV-cached logits diverge from forward_decoder by {max_diff:.2e} (expected < 1e-4)"
        );
    }
}