sapient_generate/
kv_cache.rs1use sapient_core::Tensor;
8use std::collections::HashMap;
9
10#[derive(Debug, Clone)]
14pub struct LayerKVCache {
15 pub keys: Vec<Tensor>,
17 pub values: Vec<Tensor>,
19}
20
21impl LayerKVCache {
22 pub fn empty() -> Self {
23 Self {
24 keys: Vec::new(),
25 values: Vec::new(),
26 }
27 }
28
29 pub fn append(&mut self, k: Tensor, v: Tensor) -> usize {
31 self.keys.push(k);
32 self.values.push(v);
33 self.keys.len()
34 }
35
36 pub fn seq_len(&self) -> usize {
38 self.keys.len()
39 }
40
41 pub fn clear(&mut self) {
43 self.keys.clear();
44 self.values.clear();
45 }
46}
47
48#[derive(Debug, Clone)]
52pub struct KVCache {
53 layers: Vec<LayerKVCache>,
54}
55
56impl KVCache {
57 pub fn new(n_layers: usize) -> Self {
59 Self {
60 layers: (0..n_layers).map(|_| LayerKVCache::empty()).collect(),
61 }
62 }
63
64 pub fn layer(&self, idx: usize) -> &LayerKVCache {
65 &self.layers[idx]
66 }
67 pub fn layer_mut(&mut self, idx: usize) -> &mut LayerKVCache {
68 &mut self.layers[idx]
69 }
70
71 pub fn seq_len(&self) -> usize {
73 self.layers.first().map(|l| l.seq_len()).unwrap_or(0)
74 }
75
76 pub fn clear(&mut self) {
78 for l in &mut self.layers {
79 l.clear();
80 }
81 }
82
83 pub fn n_layers(&self) -> usize {
85 self.layers.len()
86 }
87}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn kv_cache_grows() {
        let mut cache = KVCache::new(4);
        assert_eq!(cache.seq_len(), 0);
        assert_eq!(cache.n_layers(), 4);
        let k = Tensor::zeros(vec![1, 2, 1, 64], sapient_core::DType::F32).unwrap();
        let v = Tensor::zeros(vec![1, 2, 1, 64], sapient_core::DType::F32).unwrap();
        // `append` reports the layer's length after the push.
        assert_eq!(cache.layer_mut(0).append(k.clone(), v.clone()), 1);
        assert_eq!(cache.layer_mut(0).append(k, v), 2);
        assert_eq!(cache.layer(0).seq_len(), 2);
        // Keys and values must grow in lockstep.
        assert_eq!(cache.layer(0).values.len(), 2);
        // The cache-wide length is taken from layer 0.
        assert_eq!(cache.seq_len(), 2);
        // Appending to layer 0 must not touch any other layer.
        for idx in 1..cache.n_layers() {
            assert_eq!(cache.layer(idx).seq_len(), 0);
        }
    }

    #[test]
    fn kv_cache_clear() {
        let mut cache = KVCache::new(2);
        let t = Tensor::zeros(vec![1, 1, 1, 64], sapient_core::DType::F32).unwrap();
        cache.layer_mut(0).append(t.clone(), t);
        assert_eq!(cache.seq_len(), 1);
        cache.clear();
        assert_eq!(cache.seq_len(), 0);
        // Clearing drops both keys and values but keeps the layer count.
        assert!(cache.layer(0).keys.is_empty());
        assert!(cache.layer(0).values.is_empty());
        assert_eq!(cache.n_layers(), 2);
    }

    #[test]
    fn zero_layer_cache_is_empty() {
        let cache = KVCache::new(0);
        assert_eq!(cache.n_layers(), 0);
        assert_eq!(cache.seq_len(), 0);
    }

    #[test]
    #[should_panic]
    fn layer_out_of_range_panics() {
        let cache = KVCache::new(2);
        let _ = cache.layer(2);
    }
}