// tensorlogic_trustformers/kv_cache/simple_cache.rs
use ndarray::{ArrayD, Axis, IxDyn};
use std::fmt;

/// Errors returned by [`KvCache`] operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KvCacheError {
    /// A layer index was not less than the cache's layer count.
    LayerOutOfBounds { layer: usize, num_layers: usize },
    /// Appending would exceed the cache's `max_seq_len` capacity.
    CacheFull { max_seq_len: usize },
    /// An incoming tensor's shape is incompatible with the cached data.
    ShapeMismatch {
        /// The shape the cache expected.
        expected: Vec<usize>,
        /// The shape that was actually supplied.
        got: Vec<usize>,
    },
}

impl fmt::Display for KvCacheError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::LayerOutOfBounds { layer, num_layers } => write!(
                f,
                "Layer index {} is out of bounds (num_layers = {})",
                layer, num_layers
            ),
            Self::CacheFull { max_seq_len } => {
                write!(f, "KV cache is full (max_seq_len = {})", max_seq_len)
            }
            Self::ShapeMismatch { expected, got } => {
                write!(f, "Shape mismatch: expected {:?}, got {:?}", expected, got)
            }
        }
    }
}

// All variants are leaf errors, so the default `source()` (returning `None`) applies.
impl std::error::Error for KvCacheError {}

38#[derive(Debug, Clone)]
43pub struct KvCache {
44 pub keys: Vec<ArrayD<f64>>,
46 pub values: Vec<ArrayD<f64>>,
48 pub seq_len: usize,
50 pub max_seq_len: usize,
52 pub num_layers: usize,
54 pub num_heads: usize,
56}
57
58impl KvCache {
59 pub fn new(num_layers: usize, num_heads: usize, head_dim: usize, max_seq_len: usize) -> Self {
63 let empty = ArrayD::<f64>::zeros(IxDyn(&[0, num_heads, head_dim]));
64 Self {
65 keys: vec![empty.clone(); num_layers],
66 values: vec![empty; num_layers],
67 seq_len: 0,
68 max_seq_len,
69 num_layers,
70 num_heads,
71 }
72 }
73
74 pub fn append_kv(
78 &mut self,
79 layer: usize,
80 new_k: ArrayD<f64>,
81 new_v: ArrayD<f64>,
82 ) -> std::result::Result<(), KvCacheError> {
83 if layer >= self.num_layers {
84 return Err(KvCacheError::LayerOutOfBounds {
85 layer,
86 num_layers: self.num_layers,
87 });
88 }
89
90 if self.seq_len >= self.max_seq_len {
91 return Err(KvCacheError::CacheFull {
92 max_seq_len: self.max_seq_len,
93 });
94 }
95
96 let expected_tail = &self.keys[layer].shape()[1..];
98 let got_tail = &new_k.shape()[1..];
99 if expected_tail != got_tail && !self.keys[layer].shape()[0] == 0 {
100 return Err(KvCacheError::ShapeMismatch {
101 expected: expected_tail.to_vec(),
102 got: got_tail.to_vec(),
103 });
104 }
105
106 let new_tokens = new_k.shape()[0];
107 if self.seq_len + new_tokens > self.max_seq_len {
108 return Err(KvCacheError::CacheFull {
109 max_seq_len: self.max_seq_len,
110 });
111 }
112
113 let concat_k = if self.keys[layer].shape()[0] == 0 {
115 new_k
116 } else {
117 let views_k = vec![self.keys[layer].view(), new_k.view()];
118 ndarray::concatenate(Axis(0), &views_k).map_err(|e| KvCacheError::ShapeMismatch {
119 expected: self.keys[layer].shape().to_vec(),
120 got: vec![e.to_string().len()], })?
122 };
123
124 let concat_v = if self.values[layer].shape()[0] == 0 {
125 new_v
126 } else {
127 let views_v = vec![self.values[layer].view(), new_v.view()];
128 ndarray::concatenate(Axis(0), &views_v).map_err(|e| KvCacheError::ShapeMismatch {
129 expected: self.values[layer].shape().to_vec(),
130 got: vec![e.to_string().len()],
131 })?
132 };
133
134 if layer == 0 {
137 self.seq_len += new_tokens;
138 } else {
139 self.seq_len = self.keys[0].shape()[0];
141 }
142
143 self.keys[layer] = concat_k;
144 self.values[layer] = concat_v;
145
146 self.seq_len = self.keys.iter().map(|k| k.shape()[0]).max().unwrap_or(0);
148
149 Ok(())
150 }
151
152 pub fn get_kv(&self, layer: usize) -> Option<(&ArrayD<f64>, &ArrayD<f64>)> {
156 if layer >= self.num_layers {
157 return None;
158 }
159 Some((&self.keys[layer], &self.values[layer]))
160 }
161
162 pub fn reset(&mut self) {
164 let head_dim = if self.num_layers > 0 && !self.keys[0].shape().is_empty() {
165 *self.keys[0].shape().last().unwrap_or(&0)
166 } else {
167 0
168 };
169 let empty = ArrayD::<f64>::zeros(IxDyn(&[0, self.num_heads, head_dim]));
170 for k in &mut self.keys {
171 *k = empty.clone();
172 }
173 for v in &mut self.values {
174 *v = empty.clone();
175 }
176 self.seq_len = 0;
177 }
178
179 pub fn current_len(&self) -> usize {
181 self.seq_len
182 }
183
184 pub fn is_full(&self) -> bool {
186 self.seq_len >= self.max_seq_len
187 }
188
189 pub fn memory_usage_bytes(&self) -> usize {
191 let key_bytes: usize = self.keys.iter().map(|k| k.len() * 8).sum();
192 let val_bytes: usize = self.values.iter().map(|v| v.len() * 8).sum();
193 key_bytes + val_bytes
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tensor of the given shape with every element set to `fill`.
    fn make_tensor(shape: &[usize], fill: f64) -> ArrayD<f64> {
        ArrayD::from_elem(IxDyn(shape), fill)
    }

    #[test]
    fn test_kv_cache_new_and_append() {
        let mut cache = KvCache::new(2, 4, 8, 16);
        let new_k = make_tensor(&[3, 4, 8], 1.0);
        let new_v = make_tensor(&[3, 4, 8], 2.0);
        cache
            .append_kv(0, new_k, new_v)
            .expect("append should succeed");
        assert_eq!(cache.seq_len, 3, "seq_len should increment");

        let (k, v) = cache.get_kv(0).expect("layer 0 should exist");
        assert_eq!(k.shape(), &[3, 4, 8]);
        assert_eq!(v.shape(), &[3, 4, 8]);
        // A layer that was never appended to stays empty.
        let (k1, v1) = cache.get_kv(1).expect("layer 1 should exist");
        assert_eq!(k1.shape(), &[0, 4, 8]);
        assert_eq!(v1.shape(), &[0, 4, 8]);
    }

    #[test]
    fn test_kv_cache_appends_concatenate_along_token_axis() {
        let mut cache = KvCache::new(1, 2, 4, 16);
        cache
            .append_kv(0, make_tensor(&[2, 2, 4], 1.0), make_tensor(&[2, 2, 4], 1.0))
            .expect("first append");
        cache
            .append_kv(0, make_tensor(&[3, 2, 4], 2.0), make_tensor(&[3, 2, 4], 2.0))
            .expect("second append");
        assert_eq!(cache.current_len(), 5);

        let (k, v) = cache.get_kv(0).expect("layer 0 should exist");
        assert_eq!(k.shape(), &[5, 2, 4]);
        assert_eq!(v.shape(), &[5, 2, 4]);
        // Older tokens come first, newer tokens are appended after them.
        assert!(k.index_axis(Axis(0), 0).iter().all(|&x| x == 1.0));
        assert!(k.index_axis(Axis(0), 4).iter().all(|&x| x == 2.0));
    }

    #[test]
    fn test_kv_cache_seq_len_tracks_longest_layer() {
        let mut cache = KvCache::new(2, 2, 4, 16);
        cache
            .append_kv(0, make_tensor(&[3, 2, 4], 1.0), make_tensor(&[3, 2, 4], 1.0))
            .expect("append layer 0");
        cache
            .append_kv(1, make_tensor(&[2, 2, 4], 1.0), make_tensor(&[2, 2, 4], 1.0))
            .expect("append layer 1");
        assert_eq!(cache.current_len(), 3);
        let (k1, _) = cache.get_kv(1).expect("layer 1 should exist");
        assert_eq!(k1.shape(), &[2, 2, 4]);
    }

    #[test]
    fn test_kv_cache_full_returns_error() {
        let mut cache = KvCache::new(1, 2, 4, 3);
        let k = make_tensor(&[3, 2, 4], 1.0);
        let v = make_tensor(&[3, 2, 4], 1.0);
        cache.append_kv(0, k, v).expect("initial fill");
        assert!(cache.is_full());

        let k2 = make_tensor(&[1, 2, 4], 1.0);
        let v2 = make_tensor(&[1, 2, 4], 1.0);
        let result = cache.append_kv(0, k2, v2);
        assert!(
            matches!(result, Err(KvCacheError::CacheFull { max_seq_len: 3 })),
            "expected CacheFull error"
        );
        assert_eq!(cache.current_len(), 3, "a rejected append must not modify the cache");
    }

    #[test]
    fn test_kv_cache_shape_mismatch_is_rejected() {
        let mut cache = KvCache::new(1, 2, 4, 16);
        cache
            .append_kv(0, make_tensor(&[2, 2, 4], 1.0), make_tensor(&[2, 2, 4], 1.0))
            .expect("append");
        let result =
            cache.append_kv(0, make_tensor(&[1, 3, 4], 1.0), make_tensor(&[1, 3, 4], 1.0));
        assert!(
            matches!(result, Err(KvCacheError::ShapeMismatch { .. })),
            "mismatched trailing dims should return ShapeMismatch"
        );
        assert_eq!(cache.current_len(), 2, "a rejected append must not modify the cache");
        let (k, _) = cache.get_kv(0).expect("layer 0 should exist");
        assert_eq!(k.shape(), &[2, 2, 4]);
    }

    #[test]
    fn test_kv_cache_reset() {
        let mut cache = KvCache::new(1, 2, 4, 16);
        let k = make_tensor(&[4, 2, 4], 1.0);
        let v = make_tensor(&[4, 2, 4], 1.0);
        cache.append_kv(0, k, v).expect("append");
        assert!(cache.seq_len > 0);

        cache.reset();
        assert_eq!(cache.seq_len, 0, "seq_len must be 0 after reset");
        assert_eq!(cache.memory_usage_bytes(), 0);
        let (k, v) = cache.get_kv(0).expect("layer 0 should exist");
        assert_eq!(k.shape(), &[0, 2, 4]);
        assert_eq!(v.shape(), &[0, 2, 4]);
    }

    #[test]
    fn test_kv_cache_memory_usage() {
        let mut cache = KvCache::new(1, 2, 4, 16);
        let k = make_tensor(&[2, 2, 4], 1.0);
        let v = make_tensor(&[2, 2, 4], 1.0);
        cache.append_kv(0, k, v).expect("append");
        // 16 key elements + 16 value elements, each an f64.
        assert_eq!(
            cache.memory_usage_bytes(),
            2 * (2 * 2 * 4) * std::mem::size_of::<f64>()
        );
    }

    #[test]
    fn test_kv_cache_get_kv_valid_layer() {
        let mut cache = KvCache::new(2, 2, 4, 16);
        let k = make_tensor(&[2, 2, 4], 1.0);
        let v = make_tensor(&[2, 2, 4], 2.0);
        cache.append_kv(0, k, v).expect("append");

        let (keys, values) = cache
            .get_kv(0)
            .expect("should return Some for valid layer");
        assert_eq!(keys.shape(), &[2, 2, 4]);
        assert!(keys.iter().all(|&x| x == 1.0));
        assert!(values.iter().all(|&x| x == 2.0));
    }

    #[test]
    fn test_kv_cache_get_kv_invalid_layer() {
        let cache = KvCache::new(2, 2, 4, 16);
        let result = cache.get_kv(99);
        assert!(
            result.is_none(),
            "should return None for out-of-range layer"
        );
    }

    #[test]
    fn test_kv_cache_layer_out_of_bounds_error() {
        let mut cache = KvCache::new(2, 2, 4, 16);
        let k = make_tensor(&[1, 2, 4], 1.0);
        let v = make_tensor(&[1, 2, 4], 1.0);
        let result = cache.append_kv(5, k, v);
        assert!(
            matches!(
                result,
                Err(KvCacheError::LayerOutOfBounds { layer: 5, num_layers: 2 })
            ),
            "layer >= num_layers should return LayerOutOfBounds"
        );
    }
}