1use ndarray::{s, Array2, Array3, ArrayD, IxDyn};
2use std::fmt;
3
4use super::position::{PositionError, RotaryPositionEmbedding};
5use super::simple_cache::{KvCache, KvCacheError};
6
7#[derive(Debug, Clone)]
9pub enum CachedAttentionError {
10 KvCacheError(KvCacheError),
12 PositionError(PositionError),
14 InvalidShape(String),
16}
17
18impl fmt::Display for CachedAttentionError {
19 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20 match self {
21 Self::KvCacheError(e) => write!(f, "KV-cache error: {}", e),
22 Self::PositionError(e) => write!(f, "Position encoding error: {}", e),
23 Self::InvalidShape(msg) => write!(f, "Invalid shape: {}", msg),
24 }
25 }
26}
27
28impl std::error::Error for CachedAttentionError {}
29
30impl From<KvCacheError> for CachedAttentionError {
31 fn from(e: KvCacheError) -> Self {
32 Self::KvCacheError(e)
33 }
34}
35
36impl From<PositionError> for CachedAttentionError {
37 fn from(e: PositionError) -> Self {
38 Self::PositionError(e)
39 }
40}
41
42#[derive(Debug, Clone)]
47pub struct CachedAttention {
48 pub num_heads: usize,
50 pub head_dim: usize,
52 pub scale: f64,
54 pub rope: Option<RotaryPositionEmbedding>,
56 pub use_causal_mask: bool,
58}
59
60impl CachedAttention {
61 pub fn new(
66 num_heads: usize,
67 head_dim: usize,
68 use_rope: bool,
69 max_seq_len: usize,
70 ) -> std::result::Result<Self, CachedAttentionError> {
71 let scale = 1.0 / (head_dim as f64).sqrt();
72 let rope = if use_rope {
73 Some(
74 RotaryPositionEmbedding::new(head_dim, max_seq_len, 10000.0)
75 .map_err(CachedAttentionError::PositionError)?,
76 )
77 } else {
78 None
79 };
80 Ok(Self {
81 num_heads,
82 head_dim,
83 scale,
84 rope,
85 use_causal_mask: true,
86 })
87 }
88
89 pub fn forward(
98 &self,
99 query: &ArrayD<f64>,
100 key: &ArrayD<f64>,
101 value: &ArrayD<f64>,
102 cache: Option<&mut KvCache>,
103 layer_idx: usize,
104 ) -> std::result::Result<ArrayD<f64>, CachedAttentionError> {
105 let q_shape = query.shape();
106 if q_shape.len() != 3 {
107 return Err(CachedAttentionError::InvalidShape(format!(
108 "query must be 3-D [batch, seq, d], got {} dims",
109 q_shape.len()
110 )));
111 }
112 let batch = q_shape[0];
113 let seq_len = q_shape[1];
114 let d = q_shape[2];
115
116 if d != self.num_heads * self.head_dim {
117 return Err(CachedAttentionError::InvalidShape(format!(
118 "last dim {} != num_heads * head_dim = {}",
119 d,
120 self.num_heads * self.head_dim
121 )));
122 }
123
124 let q = query
126 .view()
127 .into_shape_with_order(IxDyn(&[batch * seq_len, self.num_heads, self.head_dim]))
128 .map_err(|e| CachedAttentionError::InvalidShape(e.to_string()))?
129 .to_owned();
130
131 let mut k = key
132 .view()
133 .into_shape_with_order(IxDyn(&[batch * seq_len, self.num_heads, self.head_dim]))
134 .map_err(|e| CachedAttentionError::InvalidShape(e.to_string()))?
135 .to_owned();
136
137 let v = value
138 .view()
139 .into_shape_with_order(IxDyn(&[batch * seq_len, self.num_heads, self.head_dim]))
140 .map_err(|e| CachedAttentionError::InvalidShape(e.to_string()))?
141 .to_owned();
142
143 let seq_offset = cache.as_ref().map(|c| c.seq_len).unwrap_or(0);
145
146 let (q_rope, k_rope) = if let Some(rope) = &self.rope {
147 let q_r = rope
148 .apply(&q, seq_offset)
149 .map_err(CachedAttentionError::PositionError)?;
150 let k_r = rope
151 .apply(&k, seq_offset)
152 .map_err(CachedAttentionError::PositionError)?;
153 (q_r, k_r)
154 } else {
155 (q, k.clone())
156 };
157
158 let (full_k, full_v) = if let Some(cache_ref) = cache {
160 cache_ref
161 .append_kv(layer_idx, k_rope.clone(), v.clone())
162 .map_err(CachedAttentionError::KvCacheError)?;
163 let (ck, cv) = cache_ref.get_kv(layer_idx).ok_or({
164 CachedAttentionError::KvCacheError(KvCacheError::LayerOutOfBounds {
165 layer: layer_idx,
166 num_layers: cache_ref.num_layers,
167 })
168 })?;
169 (ck.to_owned(), cv.to_owned())
170 } else {
171 k = k_rope;
172 (k, v)
173 };
174
175 let cache_len = full_k.shape()[0] / batch.max(1);
176
177 let mask = if self.use_causal_mask {
179 Some(Self::causal_mask(seq_len, cache_len))
180 } else {
181 None
182 };
183
184 self.scaled_dot_product(&q_rope, &full_k, &full_v, mask.as_ref())
187 .map(|out| {
188 out.into_shape_with_order(IxDyn(&[batch, seq_len, self.num_heads * self.head_dim]))
190 .unwrap_or_else(|_| {
191 ArrayD::zeros(IxDyn(&[batch, seq_len, self.num_heads * self.head_dim]))
192 })
193 })
194 }
195
196 pub fn causal_mask(seq_len: usize, cache_len: usize) -> Array2<f64> {
201 let total_k = cache_len + seq_len;
202 let mut mask = Array2::<f64>::zeros((seq_len, total_k));
203 for q in 0..seq_len {
204 for k in 0..total_k {
207 if k > cache_len + q {
208 mask[[q, k]] = -1.0e9;
209 }
210 }
211 }
212 mask
213 }
214
215 pub fn scaled_dot_product(
223 &self,
224 q: &ArrayD<f64>,
225 k: &ArrayD<f64>,
226 v: &ArrayD<f64>,
227 mask: Option<&Array2<f64>>,
228 ) -> std::result::Result<ArrayD<f64>, CachedAttentionError> {
229 let q_shape = q.shape();
230 let k_shape = k.shape();
231
232 if q_shape.len() != 3 || k_shape.len() != 3 {
233 return Err(CachedAttentionError::InvalidShape(
234 "q, k, v must be 3-D [tokens, heads, head_dim]".to_string(),
235 ));
236 }
237
238 let total_q = q_shape[0];
239 let total_k = k_shape[0];
240 let num_heads = q_shape[1];
241 let head_dim = q_shape[2];
242
243 if head_dim == 0 || num_heads == 0 {
244 return Err(CachedAttentionError::InvalidShape(
245 "head_dim and num_heads must be > 0".to_string(),
246 ));
247 }
248
249 let mut scores = Array3::<f64>::zeros((total_q, num_heads, total_k));
252
253 let q3 = q
254 .view()
255 .into_shape_with_order((total_q, num_heads, head_dim))
256 .map_err(|e| CachedAttentionError::InvalidShape(e.to_string()))?;
257
258 let k3 = k
259 .view()
260 .into_shape_with_order((total_k, num_heads, head_dim))
261 .map_err(|e| CachedAttentionError::InvalidShape(e.to_string()))?;
262
263 for i in 0..total_q {
264 for h in 0..num_heads {
265 for j in 0..total_k {
266 let mut dot = 0.0_f64;
267 for d in 0..head_dim {
268 dot += q3[[i, h, d]] * k3[[j, h, d]];
269 }
270 scores[[i, h, j]] = dot * self.scale;
271 }
272 }
273 }
274
275 if let Some(m) = mask {
277 let mask_q = m.shape()[0];
278 let mask_k = m.shape()[1];
279 for i in 0..total_q.min(mask_q) {
280 for h in 0..num_heads {
281 for j in 0..total_k.min(mask_k) {
282 scores[[i, h, j]] += m[[i, j]];
283 }
284 }
285 }
286 }
287
288 for i in 0..total_q {
290 for h in 0..num_heads {
291 let row_max = scores
292 .slice(s![i, h, ..])
293 .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
294 let mut sum = 0.0_f64;
295 for j in 0..total_k {
296 scores[[i, h, j]] = (scores[[i, h, j]] - row_max).exp();
297 sum += scores[[i, h, j]];
298 }
299 let safe_sum = if sum == 0.0 { 1.0 } else { sum };
300 for j in 0..total_k {
301 scores[[i, h, j]] /= safe_sum;
302 }
303 }
304 }
305
306 let v_shape = v.shape();
308 let v3 = v
309 .view()
310 .into_shape_with_order((v_shape[0], num_heads, head_dim))
311 .map_err(|e| CachedAttentionError::InvalidShape(e.to_string()))?;
312
313 let mut output = Array3::<f64>::zeros((total_q, num_heads, head_dim));
314
315 for i in 0..total_q {
316 for h in 0..num_heads {
317 for d in 0..head_dim {
318 let mut val = 0.0_f64;
319 for j in 0..total_k {
320 val += scores[[i, h, j]] * v3[[j, h, d]];
321 }
322 output[[i, h, d]] = val;
323 }
324 }
325 }
326
327 Ok(output.into_dyn())
328 }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    fn make_tensor(shape: &[usize], fill: f64) -> ArrayD<f64> {
        ArrayD::from_elem(IxDyn(shape), fill)
    }

    #[test]
    fn test_cached_attention_forward_no_cache() {
        let attn = CachedAttention::new(2, 4, false, 32).expect("valid config");
        let q = make_tensor(&[1, 3, 8], 0.5);
        let k = make_tensor(&[1, 3, 8], 0.5);
        let v = make_tensor(&[1, 3, 8], 0.5);
        let out = attn
            .forward(&q, &k, &v, None, 0)
            .expect("forward should succeed");
        assert_eq!(
            out.shape(),
            &[1, 3, 8],
            "output shape must be [batch, seq, d]"
        );
    }

    #[test]
    fn test_cached_attention_causal_mask_shape() {
        let mask = CachedAttention::causal_mask(4, 0);
        assert_eq!(mask.shape(), &[4, 4], "causal mask must be [seq, seq]");
        assert!(mask[[0, 1]] < -1e8, "future positions should be masked");
        assert!(
            (mask[[1, 0]]).abs() < 1e-9,
            "past positions should not be masked"
        );
    }

    #[test]
    fn test_causal_mask_with_cache_offset() {
        // Two new queries after three cached keys: query 0 is at absolute position 3.
        let mask = CachedAttention::causal_mask(2, 3);
        assert_eq!(mask.shape(), &[2, 5]);
        for col in 0..=3 {
            assert_eq!(mask[[0, col]], 0.0, "query 0 may attend to key {}", col);
        }
        assert!(mask[[0, 4]] < -1e8, "query 0 must not see key 4");
        for col in 0..5 {
            assert_eq!(mask[[1, col]], 0.0, "query 1 may attend to key {}", col);
        }
    }

    #[test]
    fn test_scaled_dot_product_respects_mask() {
        let attn = CachedAttention::new(1, 2, false, 8).expect("valid config");
        // Zero queries/keys give equal scores, so the unmasked result is the mean of `v`.
        let q = ArrayD::<f64>::zeros(IxDyn(&[1, 1, 2]));
        let k = ArrayD::<f64>::zeros(IxDyn(&[2, 1, 2]));
        let v = ArrayD::from_shape_vec(IxDyn(&[2, 1, 2]), vec![1.0, 2.0, 3.0, 4.0])
            .expect("valid shape");

        let unmasked = attn
            .scaled_dot_product(&q, &k, &v, None)
            .expect("unmasked attention");
        assert!((unmasked[[0, 0, 0]] - 2.0).abs() < 1e-12);
        assert!((unmasked[[0, 0, 1]] - 3.0).abs() < 1e-12);

        // Masking key 1 leaves all weight on key 0.
        let mask = Array2::from_shape_vec((1, 2), vec![0.0, -1.0e9]).expect("valid mask");
        let masked = attn
            .scaled_dot_product(&q, &k, &v, Some(&mask))
            .expect("masked attention");
        assert!((masked[[0, 0, 0]] - 1.0).abs() < 1e-12);
        assert!((masked[[0, 0, 1]] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_cached_attention_with_cache_extends_seq() {
        let attn = CachedAttention::new(2, 4, false, 64).expect("valid");
        let mut cache = KvCache::new(1, 2, 4, 64);
        let q = make_tensor(&[1, 2, 8], 0.1);
        let k = make_tensor(&[1, 2, 8], 0.1);
        let v = make_tensor(&[1, 2, 8], 0.1);
        attn.forward(&q, &k, &v, Some(&mut cache), 0)
            .expect("forward with cache");
        assert!(cache.seq_len > 0, "cache seq_len should grow after forward");
    }

    #[test]
    fn test_cached_attention_error_display() {
        let err = CachedAttentionError::InvalidShape("bad shape".to_string());
        let s = err.to_string();
        assert!(
            s.contains("bad shape"),
            "Display impl should include the message"
        );
    }
}