Skip to main content

tensorlogic_trustformers/kv_cache/
simple_cache.rs

1use ndarray::{ArrayD, Axis, IxDyn};
2use std::fmt;
3
/// Error for ndarray-based KV-cache operations.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare errors by value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KvCacheError {
    /// Layer index is out of bounds.
    LayerOutOfBounds {
        /// The rejected layer index.
        layer: usize,
        /// Number of layers the cache was created with.
        num_layers: usize,
    },
    /// Cache has reached maximum sequence length.
    CacheFull {
        /// Maximum sequence length the cache was configured with.
        max_seq_len: usize,
    },
    /// Tensor shape does not match expected shape.
    ShapeMismatch {
        /// The shape the cache expected.
        expected: Vec<usize>,
        /// The shape that was actually supplied.
        got: Vec<usize>,
    },
}

impl fmt::Display for KvCacheError {
    /// Writes a single-line, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            KvCacheError::LayerOutOfBounds {
                layer: index,
                num_layers: total,
            } => write!(
                f,
                "Layer index {} is out of bounds (num_layers = {})",
                index, total
            ),
            KvCacheError::CacheFull { max_seq_len: cap } => {
                write!(f, "KV cache is full (max_seq_len = {})", cap)
            }
            KvCacheError::ShapeMismatch {
                expected: want,
                got: have,
            } => write!(f, "Shape mismatch: expected {:?}, got {:?}", want, have),
        }
    }
}

impl std::error::Error for KvCacheError {}
37
38/// Cached key-value pairs for autoregressive inference using ndarray tensors.
39///
40/// Stores per-layer key and value tensors as dynamic-rank `ArrayD<f64>`.
41/// Tensors are concatenated along the sequence dimension on each `append_kv` call.
42#[derive(Debug, Clone)]
43pub struct KvCache {
44    /// Cached K tensors per layer (shape: `[seq_len, num_heads, head_dim]` after appends).
45    pub keys: Vec<ArrayD<f64>>,
46    /// Cached V tensors per layer (shape: `[seq_len, num_heads, head_dim]` after appends).
47    pub values: Vec<ArrayD<f64>>,
48    /// Current cached sequence length.
49    pub seq_len: usize,
50    /// Maximum allowed sequence length.
51    pub max_seq_len: usize,
52    /// Number of transformer layers.
53    pub num_layers: usize,
54    /// Number of attention heads.
55    pub num_heads: usize,
56}
57
58impl KvCache {
59    /// Create a new, empty KV-cache.
60    ///
61    /// Initially all per-layer tensors are zero-sized along the sequence dimension.
62    pub fn new(num_layers: usize, num_heads: usize, head_dim: usize, max_seq_len: usize) -> Self {
63        let empty = ArrayD::<f64>::zeros(IxDyn(&[0, num_heads, head_dim]));
64        Self {
65            keys: vec![empty.clone(); num_layers],
66            values: vec![empty; num_layers],
67            seq_len: 0,
68            max_seq_len,
69            num_layers,
70            num_heads,
71        }
72    }
73
74    /// Append new key and value tensors for the given layer.
75    ///
76    /// `new_k` and `new_v` must have shape `[new_tokens, num_heads, head_dim]`.
77    pub fn append_kv(
78        &mut self,
79        layer: usize,
80        new_k: ArrayD<f64>,
81        new_v: ArrayD<f64>,
82    ) -> std::result::Result<(), KvCacheError> {
83        if layer >= self.num_layers {
84            return Err(KvCacheError::LayerOutOfBounds {
85                layer,
86                num_layers: self.num_layers,
87            });
88        }
89
90        if self.seq_len >= self.max_seq_len {
91            return Err(KvCacheError::CacheFull {
92                max_seq_len: self.max_seq_len,
93            });
94        }
95
96        // Validate shapes match existing cache shape (except axis 0 which is seq).
97        let expected_tail = &self.keys[layer].shape()[1..];
98        let got_tail = &new_k.shape()[1..];
99        if expected_tail != got_tail && !self.keys[layer].shape()[0] == 0 {
100            return Err(KvCacheError::ShapeMismatch {
101                expected: expected_tail.to_vec(),
102                got: got_tail.to_vec(),
103            });
104        }
105
106        let new_tokens = new_k.shape()[0];
107        if self.seq_len + new_tokens > self.max_seq_len {
108            return Err(KvCacheError::CacheFull {
109                max_seq_len: self.max_seq_len,
110            });
111        }
112
113        // Concatenate along axis 0 (sequence dimension).
114        let concat_k = if self.keys[layer].shape()[0] == 0 {
115            new_k
116        } else {
117            let views_k = vec![self.keys[layer].view(), new_k.view()];
118            ndarray::concatenate(Axis(0), &views_k).map_err(|e| KvCacheError::ShapeMismatch {
119                expected: self.keys[layer].shape().to_vec(),
120                got: vec![e.to_string().len()], // encode error into shape slot
121            })?
122        };
123
124        let concat_v = if self.values[layer].shape()[0] == 0 {
125            new_v
126        } else {
127            let views_v = vec![self.values[layer].view(), new_v.view()];
128            ndarray::concatenate(Axis(0), &views_v).map_err(|e| KvCacheError::ShapeMismatch {
129                expected: self.values[layer].shape().to_vec(),
130                got: vec![e.to_string().len()],
131            })?
132        };
133
134        // Only update seq_len for layer 0 to keep a single global counter.
135        // For multi-layer caches, seq_len tracks the common sequence length.
136        if layer == 0 {
137            self.seq_len += new_tokens;
138        } else {
139            // Update seq_len from the actual key length of layer 0.
140            self.seq_len = self.keys[0].shape()[0];
141        }
142
143        self.keys[layer] = concat_k;
144        self.values[layer] = concat_v;
145
146        // Recompute seq_len as the max over all layers.
147        self.seq_len = self.keys.iter().map(|k| k.shape()[0]).max().unwrap_or(0);
148
149        Ok(())
150    }
151
152    /// Retrieve cached keys and values for the given layer.
153    ///
154    /// Returns `None` if the layer index is out of bounds.
155    pub fn get_kv(&self, layer: usize) -> Option<(&ArrayD<f64>, &ArrayD<f64>)> {
156        if layer >= self.num_layers {
157            return None;
158        }
159        Some((&self.keys[layer], &self.values[layer]))
160    }
161
162    /// Reset the cache to empty (seq_len == 0).
163    pub fn reset(&mut self) {
164        let head_dim = if self.num_layers > 0 && !self.keys[0].shape().is_empty() {
165            *self.keys[0].shape().last().unwrap_or(&0)
166        } else {
167            0
168        };
169        let empty = ArrayD::<f64>::zeros(IxDyn(&[0, self.num_heads, head_dim]));
170        for k in &mut self.keys {
171            *k = empty.clone();
172        }
173        for v in &mut self.values {
174            *v = empty.clone();
175        }
176        self.seq_len = 0;
177    }
178
179    /// Current cached sequence length.
180    pub fn current_len(&self) -> usize {
181        self.seq_len
182    }
183
184    /// Returns `true` if the cache is at maximum capacity.
185    pub fn is_full(&self) -> bool {
186        self.seq_len >= self.max_seq_len
187    }
188
189    /// Approximate memory usage in bytes (f64 = 8 bytes per element).
190    pub fn memory_usage_bytes(&self) -> usize {
191        let key_bytes: usize = self.keys.iter().map(|k| k.len() * 8).sum();
192        let val_bytes: usize = self.values.iter().map(|v| v.len() * 8).sum();
193        key_bytes + val_bytes
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Tensor of the requested shape with every element equal to `value`.
    fn uniform(dims: &[usize], value: f64) -> ArrayD<f64> {
        ArrayD::from_elem(IxDyn(dims), value)
    }

    /// Key and value tensors sharing one shape, filled with `k_fill` and `v_fill`.
    fn kv_pair(dims: &[usize], k_fill: f64, v_fill: f64) -> (ArrayD<f64>, ArrayD<f64>) {
        (uniform(dims, k_fill), uniform(dims, v_fill))
    }

    #[test]
    fn test_kv_cache_new_and_append() {
        let mut cache = KvCache::new(2, 4, 8, 16);
        let (keys, vals) = kv_pair(&[3, 4, 8], 1.0, 2.0);
        cache
            .append_kv(0, keys, vals)
            .expect("append should succeed");
        assert_eq!(cache.seq_len, 3, "seq_len should increment");
    }

    #[test]
    fn test_kv_cache_full_returns_error() {
        // A 3-token cache, filled completely by the first append.
        let mut cache = KvCache::new(1, 2, 4, 3);
        let (keys, vals) = kv_pair(&[3, 2, 4], 1.0, 1.0);
        cache.append_kv(0, keys, vals).expect("initial fill");

        // One more token cannot fit.
        let (extra_k, extra_v) = kv_pair(&[1, 2, 4], 1.0, 1.0);
        let outcome = cache.append_kv(0, extra_k, extra_v);
        assert!(
            matches!(outcome, Err(KvCacheError::CacheFull { .. })),
            "expected CacheFull error"
        );
    }

    #[test]
    fn test_kv_cache_reset() {
        let mut cache = KvCache::new(1, 2, 4, 16);
        let (keys, vals) = kv_pair(&[4, 2, 4], 1.0, 1.0);
        cache.append_kv(0, keys, vals).expect("append");
        assert!(cache.seq_len > 0);

        cache.reset();
        assert_eq!(cache.seq_len, 0, "seq_len must be 0 after reset");
    }

    #[test]
    fn test_kv_cache_memory_usage() {
        let mut cache = KvCache::new(1, 2, 4, 16);
        let (keys, vals) = kv_pair(&[2, 2, 4], 1.0, 1.0);
        cache.append_kv(0, keys, vals).expect("append");

        let used = cache.memory_usage_bytes();
        assert!(used > 0, "memory should be non-zero after append");
    }

    #[test]
    fn test_kv_cache_get_kv_valid_layer() {
        let mut cache = KvCache::new(2, 2, 4, 16);
        let (keys, vals) = kv_pair(&[2, 2, 4], 1.0, 2.0);
        cache.append_kv(0, keys, vals).expect("append");

        let found = cache.get_kv(0);
        assert!(found.is_some(), "should return Some for valid layer");
    }

    #[test]
    fn test_kv_cache_get_kv_invalid_layer() {
        let cache = KvCache::new(2, 2, 4, 16);
        let found = cache.get_kv(99);
        assert!(found.is_none(), "should return None for out-of-range layer");
    }

    #[test]
    fn test_kv_cache_layer_out_of_bounds_error() {
        let mut cache = KvCache::new(2, 2, 4, 16);
        let (keys, vals) = kv_pair(&[1, 2, 4], 1.0, 1.0);
        let outcome = cache.append_kv(5, keys, vals);
        assert!(
            matches!(outcome, Err(KvCacheError::LayerOutOfBounds { .. })),
            "layer >= num_layers should return LayerOutOfBounds"
        );
    }
}