Skip to main content

tensorlogic_trustformers/kv_cache/
cached_attention.rs

1use ndarray::{s, Array2, Array3, ArrayD, IxDyn};
2use std::fmt;
3
4use super::position::{PositionError, RotaryPositionEmbedding};
5use super::simple_cache::{KvCache, KvCacheError};
6
7/// Errors that can occur during cached attention forward passes.
8#[derive(Debug, Clone)]
9pub enum CachedAttentionError {
10    /// Wrapped KV-cache error.
11    KvCacheError(KvCacheError),
12    /// Wrapped position encoding error.
13    PositionError(PositionError),
14    /// General shape or configuration error.
15    InvalidShape(String),
16}
17
18impl fmt::Display for CachedAttentionError {
19    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20        match self {
21            Self::KvCacheError(e) => write!(f, "KV-cache error: {}", e),
22            Self::PositionError(e) => write!(f, "Position encoding error: {}", e),
23            Self::InvalidShape(msg) => write!(f, "Invalid shape: {}", msg),
24        }
25    }
26}
27
28impl std::error::Error for CachedAttentionError {}
29
30impl From<KvCacheError> for CachedAttentionError {
31    fn from(e: KvCacheError) -> Self {
32        Self::KvCacheError(e)
33    }
34}
35
36impl From<PositionError> for CachedAttentionError {
37    fn from(e: PositionError) -> Self {
38        Self::PositionError(e)
39    }
40}
41
42/// Scaled dot-product multi-head attention with optional KV-cache and RoPE.
43///
44/// Inputs are assumed to have shape `[batch, seq_len, num_heads * head_dim]` and
45/// are internally reshaped to `[batch, seq_len, num_heads, head_dim]`.
46#[derive(Debug, Clone)]
47pub struct CachedAttention {
48    /// Number of attention heads.
49    pub num_heads: usize,
50    /// Dimension of each head.
51    pub head_dim: usize,
52    /// Attention scale factor (defaults to `1 / sqrt(head_dim)`).
53    pub scale: f64,
54    /// Optional Rotary Position Embedding applied to Q and K.
55    pub rope: Option<RotaryPositionEmbedding>,
56    /// If `true`, apply a causal mask to prevent attending to future positions.
57    pub use_causal_mask: bool,
58}
59
60impl CachedAttention {
61    /// Create a new `CachedAttention`.
62    ///
63    /// When `use_rope` is `true`, a `RotaryPositionEmbedding` is pre-built for
64    /// `max_seq_len` positions using the standard base of 10000.
65    pub fn new(
66        num_heads: usize,
67        head_dim: usize,
68        use_rope: bool,
69        max_seq_len: usize,
70    ) -> std::result::Result<Self, CachedAttentionError> {
71        let scale = 1.0 / (head_dim as f64).sqrt();
72        let rope = if use_rope {
73            Some(
74                RotaryPositionEmbedding::new(head_dim, max_seq_len, 10000.0)
75                    .map_err(CachedAttentionError::PositionError)?,
76            )
77        } else {
78            None
79        };
80        Ok(Self {
81            num_heads,
82            head_dim,
83            scale,
84            rope,
85            use_causal_mask: true,
86        })
87    }
88
89    /// Run the forward pass.
90    ///
91    /// * `query`, `key`, `value` — shape `[batch, seq_len, num_heads * head_dim]`
92    /// * `cache` — optional mutable KV-cache; keys/values from previous steps are
93    ///   prepended before computing attention.
94    /// * `layer_idx` — index used when reading/writing the cache.
95    ///
96    /// Returns output of shape `[batch, seq_len, num_heads * head_dim]`.
97    pub fn forward(
98        &self,
99        query: &ArrayD<f64>,
100        key: &ArrayD<f64>,
101        value: &ArrayD<f64>,
102        cache: Option<&mut KvCache>,
103        layer_idx: usize,
104    ) -> std::result::Result<ArrayD<f64>, CachedAttentionError> {
105        let q_shape = query.shape();
106        if q_shape.len() != 3 {
107            return Err(CachedAttentionError::InvalidShape(format!(
108                "query must be 3-D [batch, seq, d], got {} dims",
109                q_shape.len()
110            )));
111        }
112        let batch = q_shape[0];
113        let seq_len = q_shape[1];
114        let d = q_shape[2];
115
116        if d != self.num_heads * self.head_dim {
117            return Err(CachedAttentionError::InvalidShape(format!(
118                "last dim {} != num_heads * head_dim = {}",
119                d,
120                self.num_heads * self.head_dim
121            )));
122        }
123
124        // Reshape Q, K, V to [batch * seq, num_heads, head_dim] for easier ops.
125        let q = query
126            .view()
127            .into_shape_with_order(IxDyn(&[batch * seq_len, self.num_heads, self.head_dim]))
128            .map_err(|e| CachedAttentionError::InvalidShape(e.to_string()))?
129            .to_owned();
130
131        let mut k = key
132            .view()
133            .into_shape_with_order(IxDyn(&[batch * seq_len, self.num_heads, self.head_dim]))
134            .map_err(|e| CachedAttentionError::InvalidShape(e.to_string()))?
135            .to_owned();
136
137        let v = value
138            .view()
139            .into_shape_with_order(IxDyn(&[batch * seq_len, self.num_heads, self.head_dim]))
140            .map_err(|e| CachedAttentionError::InvalidShape(e.to_string()))?
141            .to_owned();
142
143        // Apply RoPE to Q and K if configured.
144        let seq_offset = cache.as_ref().map(|c| c.seq_len).unwrap_or(0);
145
146        let (q_rope, k_rope) = if let Some(rope) = &self.rope {
147            let q_r = rope
148                .apply(&q, seq_offset)
149                .map_err(CachedAttentionError::PositionError)?;
150            let k_r = rope
151                .apply(&k, seq_offset)
152                .map_err(CachedAttentionError::PositionError)?;
153            (q_r, k_r)
154        } else {
155            (q, k.clone())
156        };
157
158        // Append current K, V to cache (if present), then read full K, V.
159        let (full_k, full_v) = if let Some(cache_ref) = cache {
160            cache_ref
161                .append_kv(layer_idx, k_rope.clone(), v.clone())
162                .map_err(CachedAttentionError::KvCacheError)?;
163            let (ck, cv) = cache_ref.get_kv(layer_idx).ok_or({
164                CachedAttentionError::KvCacheError(KvCacheError::LayerOutOfBounds {
165                    layer: layer_idx,
166                    num_layers: cache_ref.num_layers,
167                })
168            })?;
169            (ck.to_owned(), cv.to_owned())
170        } else {
171            k = k_rope;
172            (k, v)
173        };
174
175        let cache_len = full_k.shape()[0] / batch.max(1);
176
177        // Build optional causal mask.
178        let mask = if self.use_causal_mask {
179            Some(Self::causal_mask(seq_len, cache_len))
180        } else {
181            None
182        };
183
184        // Reshape Q to [seq_len, num_heads, head_dim] (single batch for simplicity).
185        // Full attention: Q [seq, heads, d], K [cache+seq, heads, d], V [cache+seq, heads, d].
186        self.scaled_dot_product(&q_rope, &full_k, &full_v, mask.as_ref())
187            .map(|out| {
188                // Reshape output back to [batch, seq_len, num_heads * head_dim].
189                out.into_shape_with_order(IxDyn(&[batch, seq_len, self.num_heads * self.head_dim]))
190                    .unwrap_or_else(|_| {
191                        ArrayD::zeros(IxDyn(&[batch, seq_len, self.num_heads * self.head_dim]))
192                    })
193            })
194    }
195
196    /// Build a lower-triangular causal mask of shape `[seq_len, cache_len + seq_len]`.
197    ///
198    /// Positions where attention is allowed have value `0.0`; masked positions
199    /// have value `-1e9` (a large negative additive bias).
200    pub fn causal_mask(seq_len: usize, cache_len: usize) -> Array2<f64> {
201        let total_k = cache_len + seq_len;
202        let mut mask = Array2::<f64>::zeros((seq_len, total_k));
203        for q in 0..seq_len {
204            // Query position relative to the full key sequence is (cache_len + q).
205            // Allow attention to positions <= cache_len + q.
206            for k in 0..total_k {
207                if k > cache_len + q {
208                    mask[[q, k]] = -1.0e9;
209                }
210            }
211        }
212        mask
213    }
214
215    /// Compute scaled dot-product attention.
216    ///
217    /// * `q` — shape `[total_q, num_heads, head_dim]`
218    /// * `k` — shape `[total_k, num_heads, head_dim]`
219    /// * `v` — shape `[total_k, num_heads, head_dim]`
220    /// * `mask` — optional additive mask of shape `[total_q / num_heads, total_k / num_heads]`
221    ///   or `[seq_q, seq_k]` that is broadcast across heads.
222    pub fn scaled_dot_product(
223        &self,
224        q: &ArrayD<f64>,
225        k: &ArrayD<f64>,
226        v: &ArrayD<f64>,
227        mask: Option<&Array2<f64>>,
228    ) -> std::result::Result<ArrayD<f64>, CachedAttentionError> {
229        let q_shape = q.shape();
230        let k_shape = k.shape();
231
232        if q_shape.len() != 3 || k_shape.len() != 3 {
233            return Err(CachedAttentionError::InvalidShape(
234                "q, k, v must be 3-D [tokens, heads, head_dim]".to_string(),
235            ));
236        }
237
238        let total_q = q_shape[0];
239        let total_k = k_shape[0];
240        let num_heads = q_shape[1];
241        let head_dim = q_shape[2];
242
243        if head_dim == 0 || num_heads == 0 {
244            return Err(CachedAttentionError::InvalidShape(
245                "head_dim and num_heads must be > 0".to_string(),
246            ));
247        }
248
249        // Compute attention scores: [total_q, num_heads, total_k]
250        // scores[i, h, j] = sum_d q[i, h, d] * k[j, h, d] * scale
251        let mut scores = Array3::<f64>::zeros((total_q, num_heads, total_k));
252
253        let q3 = q
254            .view()
255            .into_shape_with_order((total_q, num_heads, head_dim))
256            .map_err(|e| CachedAttentionError::InvalidShape(e.to_string()))?;
257
258        let k3 = k
259            .view()
260            .into_shape_with_order((total_k, num_heads, head_dim))
261            .map_err(|e| CachedAttentionError::InvalidShape(e.to_string()))?;
262
263        for i in 0..total_q {
264            for h in 0..num_heads {
265                for j in 0..total_k {
266                    let mut dot = 0.0_f64;
267                    for d in 0..head_dim {
268                        dot += q3[[i, h, d]] * k3[[j, h, d]];
269                    }
270                    scores[[i, h, j]] = dot * self.scale;
271                }
272            }
273        }
274
275        // Apply mask if provided.
276        if let Some(m) = mask {
277            let mask_q = m.shape()[0];
278            let mask_k = m.shape()[1];
279            for i in 0..total_q.min(mask_q) {
280                for h in 0..num_heads {
281                    for j in 0..total_k.min(mask_k) {
282                        scores[[i, h, j]] += m[[i, j]];
283                    }
284                }
285            }
286        }
287
288        // Softmax over key dimension (axis 2).
289        for i in 0..total_q {
290            for h in 0..num_heads {
291                let row_max = scores
292                    .slice(s![i, h, ..])
293                    .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
294                let mut sum = 0.0_f64;
295                for j in 0..total_k {
296                    scores[[i, h, j]] = (scores[[i, h, j]] - row_max).exp();
297                    sum += scores[[i, h, j]];
298                }
299                let safe_sum = if sum == 0.0 { 1.0 } else { sum };
300                for j in 0..total_k {
301                    scores[[i, h, j]] /= safe_sum;
302                }
303            }
304        }
305
306        // Weighted sum over values: output[i, h, d] = sum_j scores[i, h, j] * v[j, h, d]
307        let v_shape = v.shape();
308        let v3 = v
309            .view()
310            .into_shape_with_order((v_shape[0], num_heads, head_dim))
311            .map_err(|e| CachedAttentionError::InvalidShape(e.to_string()))?;
312
313        let mut output = Array3::<f64>::zeros((total_q, num_heads, head_dim));
314
315        for i in 0..total_q {
316            for h in 0..num_heads {
317                for d in 0..head_dim {
318                    let mut val = 0.0_f64;
319                    for j in 0..total_k {
320                        val += scores[[i, h, j]] * v3[[j, h, d]];
321                    }
322                    output[[i, h, d]] = val;
323                }
324            }
325        }
326
327        Ok(output.into_dyn())
328    }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a dynamic-rank tensor of the given shape with every element set
    /// to `fill`.
    fn make_tensor(shape: &[usize], fill: f64) -> ArrayD<f64> {
        let dims = IxDyn(shape);
        ArrayD::from_elem(dims, fill)
    }

    /// Without a cache the output keeps the `[batch, seq, d]` input shape.
    #[test]
    fn test_cached_attention_forward_no_cache() {
        let attention = CachedAttention::new(2, 4, false, 32).expect("valid config");
        // One constant tensor of shape [batch=1, seq=3, d=8] serves as Q, K and V.
        let input = make_tensor(&[1, 3, 8], 0.5);
        let result = attention.forward(&input, &input, &input, None, 0);
        let output = result.expect("forward should succeed");
        assert_eq!(
            output.shape(),
            &[1, 3, 8],
            "output shape must be [batch, seq, d]"
        );
    }

    /// With no cached positions the mask is square and lower-triangular.
    #[test]
    fn test_cached_attention_causal_mask_shape() {
        let causal = CachedAttention::causal_mask(4, 0);
        assert_eq!(causal.shape(), &[4, 4], "causal mask must be [seq, seq]");
        // Row 0 may not look ahead to column 1.
        let future = causal[[0, 1]];
        assert!(future < -1e8, "future positions should be masked");
        // Row 1 may look back at column 0.
        let past = causal[[1, 0]];
        assert!(past.abs() < 1e-9, "past positions should not be masked");
    }

    /// A forward pass through a cache leaves the cache non-empty.
    #[test]
    fn test_cached_attention_with_cache_extends_seq() {
        let attention = CachedAttention::new(2, 4, false, 64).expect("valid");
        let mut kv = KvCache::new(1, 2, 4, 64);
        let input = make_tensor(&[1, 2, 8], 0.1);
        attention
            .forward(&input, &input, &input, Some(&mut kv), 0)
            .expect("forward with cache");
        assert!(kv.seq_len > 0, "cache seq_len should grow after forward");
    }

    /// The `Display` output contains the message stored in `InvalidShape`.
    #[test]
    fn test_cached_attention_error_display() {
        let rendered = CachedAttentionError::InvalidShape("bad shape".to_string()).to_string();
        assert!(
            rendered.contains("bad shape"),
            "Display impl should include the message"
        );
    }
}
390}