Skip to main content

sapient_generate/
kv_cache.rs

//! KV-cache for incremental autoregressive decoding.
//!
//! Each decoder layer maintains a `LayerKVCache` that grows as tokens
//! are generated. At decode step `t`, the cache holds K and V for
//! positions `[0, t-1]`, so only the new token's Q, K and V need to be computed.

7use sapient_core::Tensor;
8use std::collections::HashMap;
9
10// ── LayerKVCache ──────────────────────────────────────────────────────────────
11
12/// Cached key and value tensors for one decoder layer.
13#[derive(Debug, Clone)]
14pub struct LayerKVCache {
15    /// Accumulated keys — shape grows as (batch, n_kv_heads, seq_k, head_dim).
16    pub keys: Vec<Tensor>,
17    /// Accumulated values — same shape.
18    pub values: Vec<Tensor>,
19}
20
21impl LayerKVCache {
22    pub fn empty() -> Self {
23        Self {
24            keys: Vec::new(),
25            values: Vec::new(),
26        }
27    }
28
29    /// Append a new key/value slice and return the current sequence length.
30    pub fn append(&mut self, k: Tensor, v: Tensor) -> usize {
31        self.keys.push(k);
32        self.values.push(v);
33        self.keys.len()
34    }
35
36    /// Current cached sequence length.
37    pub fn seq_len(&self) -> usize {
38        self.keys.len()
39    }
40
41    /// Clear the cache (e.g., start a new conversation).
42    pub fn clear(&mut self) {
43        self.keys.clear();
44        self.values.clear();
45    }
46}
47
48// ── KVCache ───────────────────────────────────────────────────────────────────
49
50/// Full KV cache for all decoder layers.
51#[derive(Debug, Clone)]
52pub struct KVCache {
53    layers: Vec<LayerKVCache>,
54}
55
56impl KVCache {
57    /// Create an empty KV cache for `n_layers` decoder layers.
58    pub fn new(n_layers: usize) -> Self {
59        Self {
60            layers: (0..n_layers).map(|_| LayerKVCache::empty()).collect(),
61        }
62    }
63
64    pub fn layer(&self, idx: usize) -> &LayerKVCache {
65        &self.layers[idx]
66    }
67    pub fn layer_mut(&mut self, idx: usize) -> &mut LayerKVCache {
68        &mut self.layers[idx]
69    }
70
71    /// Sequence length of the first layer (all layers have the same length).
72    pub fn seq_len(&self) -> usize {
73        self.layers.first().map(|l| l.seq_len()).unwrap_or(0)
74    }
75
76    /// Clear the entire cache (new conversation / context reset).
77    pub fn clear(&mut self) {
78        for l in &mut self.layers {
79            l.clear();
80        }
81    }
82
83    /// Number of layers in the cache.
84    pub fn n_layers(&self) -> usize {
85        self.layers.len()
86    }
87}
88
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn kv_cache_grows() {
        let mut cache = KVCache::new(4);
        assert_eq!(cache.n_layers(), 4);
        assert_eq!(cache.seq_len(), 0);
        let k = Tensor::zeros(vec![1, 2, 1, 64], sapient_core::DType::F32).unwrap();
        let v = Tensor::zeros(vec![1, 2, 1, 64], sapient_core::DType::F32).unwrap();
        // `append` reports the running number of cached entries.
        assert_eq!(cache.layer_mut(0).append(k.clone(), v.clone()), 1);
        assert_eq!(cache.layer_mut(0).append(k, v), 2);
        assert_eq!(cache.layer(0).seq_len(), 2);
        // Layers are independent: appending to layer 0 leaves the others empty.
        assert!((1..cache.n_layers()).all(|i| cache.layer(i).seq_len() == 0));
    }

    #[test]
    fn kv_cache_clear() {
        let mut cache = KVCache::new(2);
        let t = Tensor::zeros(vec![1, 1, 1, 64], sapient_core::DType::F32).unwrap();
        cache.layer_mut(0).append(t.clone(), t);
        assert_eq!(cache.seq_len(), 1);
        cache.clear();
        assert_eq!(cache.seq_len(), 0);
        // Both the key and value lists must be emptied, not just the keys.
        assert!(cache.layer(0).keys.is_empty());
        assert!(cache.layer(0).values.is_empty());
    }

    #[test]
    fn kv_cache_without_layers_reports_zero_len() {
        let cache = KVCache::new(0);
        assert_eq!(cache.n_layers(), 0);
        assert_eq!(cache.seq_len(), 0);
    }
}