Skip to main content

ruvector_attention/attention/
kv_cache.rs

1//! KV-Cache Compression for inference-time memory efficiency.
2//!
3//! Inspired by Google's TurboQuant (ICLR 2026), this module implements low-bit
4//! quantization of Key-Value caches to reduce memory pressure during autoregressive
5//! inference. TurboQuant demonstrates that 3-bit asymmetric per-channel quantization
6//! of KV caches achieves up to 6x memory reduction and 8x attention computation
7//! speedup with negligible quality loss (<0.5% perplexity degradation).
8//!
9//! # Design
10//!
11//! - **Per-channel asymmetric quantization**: Each attention head gets its own
12//!   scale and zero-point, preserving head-specific value distributions.
13//! - **Banker's rounding**: Round-to-nearest-even reduces systematic bias in
14//!   low-bit regimes, critical at 3-bit where every quantum matters.
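//! - **Eviction policies**: When the cache exceeds a budget, entries are pruned
//!   using one of three strategies: H2O (attention-score based), Sliding Window
//!   (recency-biased with sink tokens), or PyramidKV (layer-aware budgets that the
//!   caller computes with `CacheManager::pyramid_budget` and passes to
//!   `CacheManager::evict`; within a budget, PyramidKV prunes by attention score
//!   like H2O).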
18//!
19//! # Example
20//!
21//! ```rust
22//! use ruvector_attention::attention::kv_cache::*;
23//!
24//! let config = KVCacheConfig {
25//!     max_seq_len: 128,
26//!     num_heads: 4,
27//!     head_dim: 16,
28//!     quantization_bits: 4,
29//!     eviction_policy: EvictionPolicy::SlidingWindow { window: 64, sink: 4 },
30//! };
31//! let mut manager = CacheManager::new(config);
32//! let key = vec![0.5_f32; 64];
33//! let value = vec![-0.3_f32; 64];
34//! manager.append(&key, &value, 0);
35//! let (k, v) = manager.get(&[0]);
36//! assert_eq!(k.len(), 1);
37//! ```
38
39use std::collections::VecDeque;
40
41// ---------------------------------------------------------------------------
42// Configuration
43// ---------------------------------------------------------------------------
44
/// Strategy used to prune the KV-cache once it holds more entries than its budget.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EvictionPolicy {
    /// Heavy Hitter Oracle: keeps the tokens with the highest cumulative
    /// attention scores (accumulated via `CacheManager::update_attention_scores`)
    /// and discards the least-attended ones.
    H2O,
    /// Sliding window with attention-sink tokens (StreamingLLM): keeps the
    /// first `sink` tokens plus the most recent `window` tokens.
    SlidingWindow {
        /// Number of most recent tokens to retain.
        window: usize,
        /// Number of initial "sink" tokens to retain.
        sink: usize,
    },
    /// PyramidKV: lower (earlier) layers get larger cache budgets than upper
    /// layers, reflecting the observation that lower layers capture broader
    /// context.
    ///
    /// The per-layer budget is produced by `CacheManager::pyramid_budget` and
    /// must be passed to `CacheManager::evict` by the caller; within that
    /// budget, entries are pruned H2O-style by attention score.
    PyramidKV {
        /// Total number of layers in the model. `CacheManager::evict` does not
        /// read it; pass it to `CacheManager::pyramid_budget`.
        total_layers: usize,
    },
}
67
68/// Configuration for the quantized KV-cache.
69#[derive(Debug, Clone)]
70pub struct KVCacheConfig {
71    /// Maximum sequence length the cache can hold before eviction is required.
72    pub max_seq_len: usize,
73    /// Number of attention heads.
74    pub num_heads: usize,
75    /// Dimension per attention head.
76    pub head_dim: usize,
77    /// Bit-width for quantization. Supported: 2, 3, 4, 8.
78    pub quantization_bits: u8,
79    /// Policy used when the cache exceeds its budget.
80    pub eviction_policy: EvictionPolicy,
81}
82
83// ---------------------------------------------------------------------------
84// Quantization primitives
85// ---------------------------------------------------------------------------
86
/// Output of asymmetric per-head quantization.
///
/// Decode each code with `value = scale * (code - zero_point)`, using the
/// scale and zero-point of the head the code belongs to.
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Quantized codes, one per byte. Sub-byte widths are not bit-packed yet,
    /// so a 3-bit code still occupies a full `u8`.
    pub data: Vec<u8>,
    /// One scale factor per head (channel).
    pub scales: Vec<f32>,
    /// One zero-point per head (channel), expressed in code units.
    pub zero_points: Vec<f32>,
    /// Bit-width the codes were produced with.
    pub bits: u8,
}
101
/// Banker's rounding: round to the nearest integer, sending exact ties
/// (`.5` fractions) to the even neighbour to avoid the systematic upward bias
/// of `f32::round`.
///
/// Only exact halfway values count as ties. `x - x.trunc()` is computed
/// without rounding error for every finite `x`, so comparing it against `0.5`
/// is exact. An epsilon tolerance would misclassify values such as
/// `0.5 + 2^-24`, which must round up to `1.0`. NaN and infinities are
/// returned unchanged.
#[inline]
pub fn round_to_nearest_even(x: f32) -> f32 {
    if (x - x.trunc()).abs() == 0.5 {
        // Exact tie: halving, rounding and doubling lands on the even neighbour.
        2.0 * (x * 0.5).round()
    } else {
        x.round()
    }
}
124
125/// Asymmetric per-channel quantization.
126///
127/// `tensor` is shaped `[num_heads * head_dim]` (one KV vector across all heads).
128/// Quantisation is performed per-head (channel), each getting its own scale and
129/// zero-point. Returns a [`QuantizedTensor`].
130pub fn quantize_asymmetric(tensor: &[f32], num_heads: usize, bits: u8) -> QuantizedTensor {
131    let head_dim = tensor.len() / num_heads;
132    let qmax = ((1u32 << bits) - 1) as f32;
133
134    let mut data = Vec::with_capacity(tensor.len());
135    let mut scales = Vec::with_capacity(num_heads);
136    let mut zero_points = Vec::with_capacity(num_heads);
137
138    for h in 0..num_heads {
139        let start = h * head_dim;
140        let end = start + head_dim;
141        let channel = &tensor[start..end];
142
143        let min_val = channel.iter().copied().fold(f32::INFINITY, f32::min);
144        let max_val = channel.iter().copied().fold(f32::NEG_INFINITY, f32::max);
145
146        let range = max_val - min_val;
147        let scale = if range.abs() < f32::EPSILON {
148            1.0
149        } else {
150            range / qmax
151        };
152        let zp = if range.abs() < f32::EPSILON {
153            0.0
154        } else {
155            -min_val / scale
156        };
157
158        scales.push(scale);
159        zero_points.push(zp);
160
161        for &v in channel {
162            let q = round_to_nearest_even(v / scale + zp).clamp(0.0, qmax);
163            data.push(q as u8);
164        }
165    }
166
167    QuantizedTensor {
168        data,
169        scales,
170        zero_points,
171        bits,
172    }
173}
174
175/// Symmetric quantization (simpler, useful for comparison).
176///
177/// `value = scale * quantized` with zero-point fixed at the midpoint.
178///
179/// # Panics
180///
181/// Panics if `bits` is less than 2 or greater than 8.
182pub fn quantize_symmetric(tensor: &[f32], bits: u8) -> (Vec<u8>, f32) {
183    assert!(
184        bits >= 2 && bits <= 8,
185        "quantize_symmetric: bits must be in [2, 8], got {}",
186        bits
187    );
188    let qmax = ((1u32 << (bits - 1)) - 1) as f32;
189    let abs_max = tensor.iter().copied().map(f32::abs).fold(0.0_f32, f32::max);
190    let scale = if abs_max < f32::EPSILON {
191        1.0
192    } else {
193        abs_max / qmax
194    };
195    let offset = (1u32 << (bits - 1)) as f32; // unsigned offset
196
197    let data: Vec<u8> = tensor
198        .iter()
199        .map(|&v| {
200            let q =
201                round_to_nearest_even(v / scale + offset).clamp(0.0, (1u32 << bits) as f32 - 1.0);
202            q as u8
203        })
204        .collect();
205    (data, scale)
206}
207
/// Dequantize symmetric quantized data back to f32.
///
/// Inverse of [`quantize_symmetric`]: each code decodes to
/// `(code - 2^(bits-1)) * scale`.
///
/// # Panics
///
/// Panics if `bits` is less than 2 or greater than 8, matching the range
/// accepted by [`quantize_symmetric`]. Without this check, `bits == 0`
/// underflows the `bits - 1` shift.
pub fn dequantize_symmetric(data: &[u8], scale: f32, bits: u8) -> Vec<f32> {
    assert!(
        (2..=8).contains(&bits),
        "dequantize_symmetric: bits must be in [2, 8], got {}",
        bits
    );
    let offset = (1u32 << (bits - 1)) as f32;
    data.iter().map(|&q| (f32::from(q) - offset) * scale).collect()
}
213
214/// Dequantize an asymmetrically quantized tensor back to f32.
215pub fn dequantize(qt: &QuantizedTensor, num_heads: usize) -> Vec<f32> {
216    let head_dim = qt.data.len() / num_heads;
217    let mut out = Vec::with_capacity(qt.data.len());
218    for h in 0..num_heads {
219        let start = h * head_dim;
220        let end = start + head_dim;
221        let scale = qt.scales[h];
222        let zp = qt.zero_points[h];
223        for &q in &qt.data[start..end] {
224            out.push(scale * (q as f32 - zp));
225        }
226    }
227    out
228}
229
230// ---------------------------------------------------------------------------
231// Cache entry
232// ---------------------------------------------------------------------------
233
234/// A single cached key-value pair (quantized).
235#[derive(Debug, Clone)]
236struct CacheEntry {
237    key: QuantizedTensor,
238    value: QuantizedTensor,
239    /// Cumulative attention score for H2O eviction.
240    attention_score: f64,
241    /// Insertion order (monotonically increasing).
242    seq_idx: usize,
243}
244
245// ---------------------------------------------------------------------------
246// CacheManager
247// ---------------------------------------------------------------------------
248
249/// Manages a quantized KV-cache with configurable eviction.
250///
251/// Provides `append`, `get`, `evict`, and diagnostic methods such as
252/// `compression_ratio` and `memory_bytes`.
253pub struct CacheManager {
254    config: KVCacheConfig,
255    entries: VecDeque<CacheEntry>,
256    next_seq: usize,
257}
258
259impl CacheManager {
260    /// Create a new cache manager with the given configuration.
261    pub fn new(config: KVCacheConfig) -> Self {
262        Self {
263            config,
264            entries: VecDeque::new(),
265            next_seq: 0,
266        }
267    }
268
269    /// Number of entries currently in the cache.
270    pub fn len(&self) -> usize {
271        self.entries.len()
272    }
273
274    /// Whether the cache is empty.
275    pub fn is_empty(&self) -> bool {
276        self.entries.is_empty()
277    }
278
279    /// Append a new key-value pair to the cache.
280    ///
281    /// `key` and `value` must each have length `num_heads * head_dim`.
282    /// `_layer_idx` is used by the PyramidKV eviction policy to determine
283    /// the per-layer budget.
284    pub fn append(&mut self, key: &[f32], value: &[f32], _layer_idx: usize) {
285        let bits = self.config.quantization_bits;
286        let heads = self.config.num_heads;
287
288        let qk = quantize_asymmetric(key, heads, bits);
289        let qv = quantize_asymmetric(value, heads, bits);
290
291        self.entries.push_back(CacheEntry {
292            key: qk,
293            value: qv,
294            attention_score: 0.0,
295            seq_idx: self.next_seq,
296        });
297        self.next_seq += 1;
298
299        // Auto-evict if over budget.
300        if self.entries.len() > self.config.max_seq_len {
301            self.evict(self.config.max_seq_len);
302        }
303    }
304
305    /// Retrieve dequantized key-value pairs at the given logical positions.
306    ///
307    /// Returns `(keys, values)` where each inner `Vec<f32>` has length
308    /// `num_heads * head_dim`.
309    pub fn get(&self, positions: &[usize]) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
310        let heads = self.config.num_heads;
311        let mut keys = Vec::with_capacity(positions.len());
312        let mut values = Vec::with_capacity(positions.len());
313
314        for &pos in positions {
315            if pos < self.entries.len() {
316                let entry = &self.entries[pos];
317                keys.push(dequantize(&entry.key, heads));
318                values.push(dequantize(&entry.value, heads));
319            }
320        }
321        (keys, values)
322    }
323
324    /// Evict entries until the cache contains at most `budget` entries.
325    pub fn evict(&mut self, budget: usize) {
326        if self.entries.len() <= budget {
327            return;
328        }
329
330        match &self.config.eviction_policy {
331            EvictionPolicy::H2O => self.evict_h2o(budget),
332            EvictionPolicy::SlidingWindow { window, sink } => {
333                self.evict_sliding_window(budget, *window, *sink);
334            }
335            EvictionPolicy::PyramidKV { .. } => {
336                // PyramidKV adjusts budget externally per layer; here we just
337                // fall back to H2O-style eviction within the given budget.
338                self.evict_h2o(budget);
339            }
340        }
341    }
342
343    /// H2O eviction: remove entries with the lowest cumulative attention score.
344    fn evict_h2o(&mut self, budget: usize) {
345        while self.entries.len() > budget {
346            // Find index of entry with the lowest attention score.
347            let min_idx = self
348                .entries
349                .iter()
350                .enumerate()
351                .min_by(|(_, a), (_, b)| {
352                    a.attention_score
353                        .partial_cmp(&b.attention_score)
354                        .unwrap_or(std::cmp::Ordering::Equal)
355                })
356                .map(|(i, _)| i)
357                .unwrap();
358            self.entries.remove(min_idx);
359        }
360    }
361
362    /// Sliding window eviction: keep first `sink` tokens and last `window` tokens.
363    fn evict_sliding_window(&mut self, budget: usize, window: usize, sink: usize) {
364        let effective_budget = budget.min(sink + window);
365        if self.entries.len() <= effective_budget {
366            return;
367        }
368
369        // Identify indices to keep: first `sink` and last `window`.
370        let len = self.entries.len();
371        let keep_end = window.min(len);
372        let keep_start = sink.min(len.saturating_sub(keep_end));
373
374        let mut kept: VecDeque<CacheEntry> = VecDeque::with_capacity(keep_start + keep_end);
375        for i in 0..keep_start {
376            kept.push_back(self.entries[i].clone());
377        }
378        for i in (len - keep_end)..len {
379            if i >= keep_start {
380                kept.push_back(self.entries[i].clone());
381            }
382        }
383        self.entries = kept;
384    }
385
386    /// Update cumulative attention scores for the H2O eviction policy.
387    ///
388    /// `scores` should have one value per current cache entry.
389    pub fn update_attention_scores(&mut self, scores: &[f64]) {
390        for (entry, &s) in self.entries.iter_mut().zip(scores.iter()) {
391            entry.attention_score += s;
392        }
393    }
394
395    /// Compute the budget for a given layer under PyramidKV.
396    ///
397    /// Lower layers get a proportionally larger share of `max_seq_len`.
398    pub fn pyramid_budget(&self, layer_idx: usize, total_layers: usize) -> usize {
399        if total_layers == 0 {
400            return self.config.max_seq_len;
401        }
402        let weight = (total_layers - layer_idx) as f64 / total_layers as f64;
403        let sum_weights: f64 = (1..=total_layers)
404            .map(|i| i as f64 / total_layers as f64)
405            .sum();
406        let budget = (weight / sum_weights) * self.config.max_seq_len as f64;
407        (budget.ceil() as usize).max(1)
408    }
409
410    /// Compression ratio: `f32 bytes / quantized bytes` for a single entry.
411    ///
412    /// A 4-bit cache over f32 baseline yields roughly 8x compression
413    /// (before accounting for scale/zero-point overhead).
414    pub fn compression_ratio(&self) -> f64 {
415        let total_elements = self.config.num_heads * self.config.head_dim;
416        let f32_bytes = (total_elements * 4 * 2) as f64; // K + V
417        let q_bytes = self.entry_quantized_bytes() as f64;
418        if q_bytes < f64::EPSILON {
419            return 0.0;
420        }
421        f32_bytes / q_bytes
422    }
423
424    /// Bytes consumed by the quantized data of a single KV entry (approximate).
425    fn entry_quantized_bytes(&self) -> usize {
426        let elements = self.config.num_heads * self.config.head_dim;
427        // 1 byte per element (unpacked) + scales + zero_points per head, times 2 (K+V).
428        let per_tensor = elements + self.config.num_heads * 4 * 2; // scale + zp as f32
429        per_tensor * 2
430    }
431
432    /// Approximate total memory usage of the cache in bytes.
433    pub fn memory_bytes(&self) -> usize {
434        self.entries.len() * self.entry_quantized_bytes()
435    }
436}
437
438// ---------------------------------------------------------------------------
439// Tests
440// ---------------------------------------------------------------------------
441
#[cfg(test)]
mod tests {
    use super::*;

    fn make_config(bits: u8, policy: EvictionPolicy) -> KVCacheConfig {
        KVCacheConfig {
            max_seq_len: 8,
            num_heads: 2,
            head_dim: 4,
            quantization_bits: bits,
            eviction_policy: policy,
        }
    }

    /// Append `count` entries whose key and value are filled with the entry's index.
    fn fill(mgr: &mut CacheManager, count: usize) {
        for i in 0..count {
            let kv = vec![i as f32; 8];
            mgr.append(&kv, &kv, 0);
        }
    }

    /// Insertion indices of the surviving entries, oldest first.
    fn seq_indices(mgr: &CacheManager) -> Vec<usize> {
        mgr.entries.iter().map(|e| e.seq_idx).collect()
    }

    // -- Quantization --

    #[test]
    fn test_quantize_roundtrip_4bit() {
        let data: Vec<f32> = vec![0.0, 0.5, 1.0, -1.0, 0.25, -0.5, 0.75, -0.25];
        let qt = quantize_asymmetric(&data, 2, 4);
        let restored = dequantize(&qt, 2);
        assert_eq!(restored.len(), data.len());
        for (orig, rest) in data.iter().zip(restored.iter()) {
            assert!(
                (orig - rest).abs() < 0.15,
                "4-bit error too large: {orig} vs {rest}"
            );
        }
    }

    #[test]
    fn test_quantize_roundtrip_3bit() {
        let data: Vec<f32> = vec![0.0, 0.5, 1.0, -1.0, 0.3, -0.7, 0.8, -0.2];
        let qt = quantize_asymmetric(&data, 2, 3);
        let restored = dequantize(&qt, 2);
        assert_eq!(restored.len(), data.len());
        // 3-bit has only 8 levels so error is larger.
        for (orig, rest) in data.iter().zip(restored.iter()) {
            assert!(
                (orig - rest).abs() < 0.35,
                "3-bit error too large: {orig} vs {rest}"
            );
        }
    }

    #[test]
    fn test_quantize_metadata() {
        let data: Vec<f32> = vec![0.0, 0.5, 1.0, -1.0, 0.25, -0.5, 0.75, -0.25];
        let qt = quantize_asymmetric(&data, 2, 4);
        assert_eq!(qt.bits, 4);
        assert_eq!(qt.data.len(), data.len());
        assert_eq!(qt.scales.len(), 2);
        assert_eq!(qt.zero_points.len(), 2);
        // Every code must fit in 4 bits.
        assert!(qt.data.iter().all(|&q| q <= 15));
    }

    #[test]
    fn test_symmetric_quantize_roundtrip() {
        let data: Vec<f32> = vec![0.0, 0.5, -0.5, 1.0, -1.0];
        let (qdata, scale) = quantize_symmetric(&data, 4);
        let restored = dequantize_symmetric(&qdata, scale, 4);
        assert_eq!(restored.len(), data.len());
        for (orig, rest) in data.iter().zip(restored.iter()) {
            assert!((orig - rest).abs() < 0.2, "sym roundtrip: {orig} vs {rest}");
        }
    }

    #[test]
    fn test_bankers_rounding() {
        assert_eq!(round_to_nearest_even(2.5), 2.0);
        assert_eq!(round_to_nearest_even(3.5), 4.0);
        assert_eq!(round_to_nearest_even(4.5), 4.0);
        assert_eq!(round_to_nearest_even(-2.5), -2.0);
        assert_eq!(round_to_nearest_even(-3.5), -4.0);
        assert_eq!(round_to_nearest_even(0.5), 0.0);
        assert_eq!(round_to_nearest_even(1.3), 1.0);
        assert_eq!(round_to_nearest_even(1.7), 2.0);
    }

    // -- Cache operations --

    #[test]
    fn test_cache_append_and_get() {
        let cfg = make_config(4, EvictionPolicy::H2O);
        let mut mgr = CacheManager::new(cfg);
        let k = vec![1.0_f32; 8];
        let v = vec![-1.0_f32; 8];
        mgr.append(&k, &v, 0);
        assert_eq!(mgr.len(), 1);
        assert!(!mgr.is_empty());

        let (keys, vals) = mgr.get(&[0]);
        assert_eq!(keys.len(), 1);
        assert_eq!(vals.len(), 1);
        assert_eq!(keys[0].len(), 8);
        assert_eq!(vals[0].len(), 8);
    }

    #[test]
    fn test_cache_empty() {
        let cfg = make_config(4, EvictionPolicy::H2O);
        let mgr = CacheManager::new(cfg);
        assert!(mgr.is_empty());
        assert_eq!(mgr.len(), 0);
        let (k, v) = mgr.get(&[0]);
        assert!(k.is_empty());
        assert!(v.is_empty());
    }

    #[test]
    fn test_get_skips_out_of_range_positions() {
        let mut mgr = CacheManager::new(make_config(4, EvictionPolicy::H2O));
        fill(&mut mgr, 2);
        let (keys, vals) = mgr.get(&[1, 5, 0]);
        assert_eq!(keys.len(), 2);
        assert_eq!(vals.len(), 2);
    }

    #[test]
    fn test_evict_is_noop_within_budget() {
        let mut mgr = CacheManager::new(make_config(4, EvictionPolicy::H2O));
        fill(&mut mgr, 3);
        mgr.evict(5);
        assert_eq!(seq_indices(&mgr), vec![0, 1, 2]);
    }

    #[test]
    fn test_h2o_eviction() {
        let cfg = make_config(4, EvictionPolicy::H2O);
        let mut mgr = CacheManager::new(cfg);
        fill(&mut mgr, 4);
        // Entry 1 gets the lowest score.
        mgr.update_attention_scores(&[5.0, 1.0, 3.0, 4.0]);

        mgr.evict(3);
        assert_eq!(mgr.len(), 3);

        // The entry with score 1.0 (index 1) should have been removed.
        let scores: Vec<f64> = mgr.entries.iter().map(|e| e.attention_score).collect();
        assert_eq!(scores, vec![5.0, 3.0, 4.0]);
        assert_eq!(seq_indices(&mgr), vec![0, 2, 3]);
    }

    #[test]
    fn test_sliding_window_eviction() {
        let mut cfg = make_config(4, EvictionPolicy::SlidingWindow { window: 3, sink: 2 });
        cfg.max_seq_len = 100; // large so auto-evict doesn't trigger
        let mut mgr = CacheManager::new(cfg);
        fill(&mut mgr, 10);
        assert_eq!(mgr.len(), 10);

        // Evict down to 5 (keep sink=2 and window=3).
        mgr.evict(5);
        assert_eq!(mgr.len(), 5);
        // First 2 entries (sink) and last 3 entries should remain, in order.
        assert_eq!(seq_indices(&mgr), vec![0, 1, 7, 8, 9]);
    }

    #[test]
    fn test_compression_ratio() {
        let cfg = make_config(4, EvictionPolicy::H2O);
        let mgr = CacheManager::new(cfg);
        let ratio = mgr.compression_ratio();
        // 8 elements * 4 bytes * (K+V) = 64 f32 bytes; stored: (8 codes +
        // 2 heads * 8 bytes of scale/zero-point) * (K+V) = 48 bytes.
        assert!(
            (ratio - 64.0 / 48.0).abs() < 1e-12,
            "unexpected compression ratio {ratio}"
        );
        assert!(ratio > 1.0);
    }

    #[test]
    fn test_memory_bytes() {
        let cfg = make_config(4, EvictionPolicy::H2O);
        let mut mgr = CacheManager::new(cfg);
        assert_eq!(mgr.memory_bytes(), 0);

        let k = vec![0.5_f32; 8];
        let v = vec![-0.5_f32; 8];
        mgr.append(&k, &v, 0);
        let bytes_one = mgr.memory_bytes();
        assert_eq!(bytes_one, 48);

        mgr.append(&k, &v, 0);
        assert_eq!(mgr.memory_bytes(), bytes_one * 2);
    }

    #[test]
    fn test_auto_eviction_on_append() {
        // max_seq_len = 8; all scores are 0, so H2O evicts the oldest first.
        let cfg = make_config(4, EvictionPolicy::H2O);
        let mut mgr = CacheManager::new(cfg);
        fill(&mut mgr, 12);
        assert_eq!(mgr.len(), 8);
        let expected: Vec<usize> = (4..12).collect();
        assert_eq!(seq_indices(&mgr), expected);
    }

    #[test]
    fn test_pyramid_budget() {
        let cfg = make_config(4, EvictionPolicy::PyramidKV { total_layers: 4 });
        let mgr = CacheManager::new(cfg);
        let b0 = mgr.pyramid_budget(0, 4);
        let b3 = mgr.pyramid_budget(3, 4);
        // Lower layers should get a larger budget.
        assert!(
            b0 > b3,
            "layer 0 budget ({b0}) should exceed layer 3 ({b3})"
        );
        assert_eq!(b0, 4);
        assert_eq!(b3, 1);
    }

    #[test]
    fn test_single_entry_operations() {
        let cfg = make_config(3, EvictionPolicy::H2O);
        let mut mgr = CacheManager::new(cfg);
        let k = vec![0.42_f32; 8];
        let v = vec![-0.42_f32; 8];
        mgr.append(&k, &v, 0);

        mgr.update_attention_scores(&[1.0]);
        mgr.evict(1);
        assert_eq!(mgr.len(), 1);

        let (keys, vals) = mgr.get(&[0]);
        assert_eq!(keys.len(), 1);
        assert_eq!(vals.len(), 1);
    }
}