// ruvector_attention/attention/kv_cache.rs

//! KV-Cache Compression for inference-time memory efficiency.
//!
//! Inspired by Google's TurboQuant (ICLR 2026), this module implements low-bit
//! quantization of Key-Value caches to reduce memory pressure during autoregressive
//! inference. TurboQuant demonstrates that 3-bit asymmetric per-channel quantization
//! of KV caches achieves up to 6x memory reduction and 8x attention computation
//! speedup with negligible quality loss (<0.5% perplexity degradation).
//!
//! # Design
//!
//! - **Per-channel asymmetric quantization**: Each attention head gets its own
//!   scale and zero-point, preserving head-specific value distributions.
//! - **Banker's rounding**: Round-to-nearest-even reduces systematic bias in
//!   low-bit regimes, critical at 3-bit where every quantum matters.
//! - **Eviction policies**: When the cache exceeds a budget, entries are pruned
//!   using one of three strategies: H2O (attention-score based), Sliding Window
//!   (recency-biased with sink tokens), or PyramidKV (layer-aware budgets).
//!
//! # Example
//!
//! ```rust
//! use ruvector_attention::attention::kv_cache::*;
//!
//! let config = KVCacheConfig {
//!     max_seq_len: 128,
//!     num_heads: 4,
//!     head_dim: 16,
//!     quantization_bits: 4,
//!     eviction_policy: EvictionPolicy::SlidingWindow { window: 64, sink: 4 },
//! };
//! let mut manager = CacheManager::new(config);
//! let key = vec![0.5_f32; 64];
//! let value = vec![-0.3_f32; 64];
//! manager.append(&key, &value, 0);
//! let (k, v) = manager.get(&[0]);
//! assert_eq!(k.len(), 1);
//! ```

use std::collections::VecDeque;

// ---------------------------------------------------------------------------
// Configuration
// ---------------------------------------------------------------------------

/// Eviction policy for pruning the KV-cache when it exceeds its budget.
///
/// Selected via [`KVCacheConfig::eviction_policy`] and applied by
/// `CacheManager::evict`.
#[derive(Debug, Clone, PartialEq)]
pub enum EvictionPolicy {
    /// Heavy Hitter Oracle: retains tokens with the highest cumulative
    /// attention scores, discarding those rarely attended to.
    H2O,
    /// Sliding Window with sink tokens (StreamingLLM). Keeps the first
    /// `sink` tokens and the most recent `window` tokens.
    SlidingWindow {
        /// Number of recent tokens to retain.
        window: usize,
        /// Number of initial "sink" tokens to always keep.
        sink: usize,
    },
    /// PyramidKV: assigns larger cache budgets to lower (earlier) layers
    /// and smaller budgets to upper layers, reflecting the observation
    /// that lower layers capture broader context.
    PyramidKV {
        /// Total number of layers in the model.
        total_layers: usize,
    },
}

/// Configuration for the quantized KV-cache.
#[derive(Debug, Clone)]
pub struct KVCacheConfig {
    /// Maximum sequence length the cache can hold before eviction is required.
    pub max_seq_len: usize,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Dimension per attention head. Keys/values passed to the cache must
    /// have length `num_heads * head_dim`.
    pub head_dim: usize,
    /// Bit-width for quantization. Supported: 2, 3, 4, 8.
    pub quantization_bits: u8,
    /// Policy used when the cache exceeds its budget.
    pub eviction_policy: EvictionPolicy,
}

// ---------------------------------------------------------------------------
// Quantization primitives
// ---------------------------------------------------------------------------

/// A quantized tensor with per-channel scale and zero-point for asymmetric
/// dequantization: `value = scale * (quantized - zero_point)`.
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Packed quantized values stored as u8. For sub-byte widths the values
    /// are stored one-per-byte for simplicity (packing is a future optimisation).
    pub data: Vec<u8>,
    /// Per-channel (per-head) scale factors. One entry per head.
    pub scales: Vec<f32>,
    /// Per-channel (per-head) zero-points in quantized domain. One entry per head.
    pub zero_points: Vec<f32>,
    /// Bit-width used during quantization (informational; data is unpacked).
    pub bits: u8,
}

/// Banker's rounding (round half to even) to reduce systematic bias.
///
/// Values whose fractional part is exactly 0.5 round to the even neighbour;
/// everything else rounds to the nearest integer. Exact halves are always
/// representable in binary floating point, so the tie test can (and must)
/// be an exact comparison.
#[inline]
pub fn round_to_nearest_even(x: f32) -> f32 {
    let floor = x.floor();
    let frac = x - floor;
    if frac > 0.5 {
        floor + 1.0
    } else if frac < 0.5 {
        floor
    } else {
        // Exactly halfway: pick the even neighbour. (The previous
        // epsilon-window tie test misclassified near-halves such as
        // 0.5 + 2^-24 as ties and rounded them the wrong way.)
        if (floor as i64) % 2 == 0 {
            floor
        } else {
            floor + 1.0
        }
    }
}

121/// Asymmetric per-channel quantization.
122///
123/// `tensor` is shaped `[num_heads * head_dim]` (one KV vector across all heads).
124/// Quantisation is performed per-head (channel), each getting its own scale and
125/// zero-point. Returns a [`QuantizedTensor`].
126pub fn quantize_asymmetric(tensor: &[f32], num_heads: usize, bits: u8) -> QuantizedTensor {
127    let head_dim = tensor.len() / num_heads;
128    let qmax = ((1u32 << bits) - 1) as f32;
129
130    let mut data = Vec::with_capacity(tensor.len());
131    let mut scales = Vec::with_capacity(num_heads);
132    let mut zero_points = Vec::with_capacity(num_heads);
133
134    for h in 0..num_heads {
135        let start = h * head_dim;
136        let end = start + head_dim;
137        let channel = &tensor[start..end];
138
139        let min_val = channel.iter().copied().fold(f32::INFINITY, f32::min);
140        let max_val = channel.iter().copied().fold(f32::NEG_INFINITY, f32::max);
141
142        let range = max_val - min_val;
143        let scale = if range.abs() < f32::EPSILON { 1.0 } else { range / qmax };
144        let zp = if range.abs() < f32::EPSILON { 0.0 } else { -min_val / scale };
145
146        scales.push(scale);
147        zero_points.push(zp);
148
149        for &v in channel {
150            let q = round_to_nearest_even(v / scale + zp).clamp(0.0, qmax);
151            data.push(q as u8);
152        }
153    }
154
155    QuantizedTensor { data, scales, zero_points, bits }
156}
157
158/// Symmetric quantization (simpler, useful for comparison).
159///
160/// `value = scale * quantized` with zero-point fixed at the midpoint.
161///
162/// # Panics
163///
164/// Panics if `bits` is less than 2 or greater than 8.
165pub fn quantize_symmetric(tensor: &[f32], bits: u8) -> (Vec<u8>, f32) {
166    assert!(bits >= 2 && bits <= 8, "quantize_symmetric: bits must be in [2, 8], got {}", bits);
167    let qmax = ((1u32 << (bits - 1)) - 1) as f32;
168    let abs_max = tensor.iter().copied().map(f32::abs).fold(0.0_f32, f32::max);
169    let scale = if abs_max < f32::EPSILON { 1.0 } else { abs_max / qmax };
170    let offset = (1u32 << (bits - 1)) as f32; // unsigned offset
171
172    let data: Vec<u8> = tensor
173        .iter()
174        .map(|&v| {
175            let q = round_to_nearest_even(v / scale + offset).clamp(0.0, (1u32 << bits) as f32 - 1.0);
176            q as u8
177        })
178        .collect();
179    (data, scale)
180}
181
/// Dequantize symmetric quantized data back to f32.
///
/// Inverse of [`quantize_symmetric`]: subtracts the midpoint offset implied
/// by `bits`, then rescales.
pub fn dequantize_symmetric(data: &[u8], scale: f32, bits: u8) -> Vec<f32> {
    let offset = (1u32 << (bits - 1)) as f32;
    let mut restored = Vec::with_capacity(data.len());
    for &q in data {
        restored.push((f32::from(q) - offset) * scale);
    }
    restored
}

188/// Dequantize an asymmetrically quantized tensor back to f32.
189pub fn dequantize(qt: &QuantizedTensor, num_heads: usize) -> Vec<f32> {
190    let head_dim = qt.data.len() / num_heads;
191    let mut out = Vec::with_capacity(qt.data.len());
192    for h in 0..num_heads {
193        let start = h * head_dim;
194        let end = start + head_dim;
195        let scale = qt.scales[h];
196        let zp = qt.zero_points[h];
197        for &q in &qt.data[start..end] {
198            out.push(scale * (q as f32 - zp));
199        }
200    }
201    out
202}
203
// ---------------------------------------------------------------------------
// Cache entry
// ---------------------------------------------------------------------------

/// A single cached key-value pair (quantized).
#[derive(Debug, Clone)]
struct CacheEntry {
    /// Quantized key vector, `num_heads * head_dim` elements (one byte each).
    key: QuantizedTensor,
    /// Quantized value vector, same layout as `key`.
    value: QuantizedTensor,
    /// Cumulative attention score for H2O eviction.
    attention_score: f64,
    /// Insertion order (monotonically increasing).
    seq_idx: usize,
}

// ---------------------------------------------------------------------------
// CacheManager
// ---------------------------------------------------------------------------

/// Manages a quantized KV-cache with configurable eviction.
///
/// Provides `append`, `get`, `evict`, and diagnostic methods such as
/// `compression_ratio` and `memory_bytes`.
pub struct CacheManager {
    // Immutable after construction; drives quantization width and eviction.
    config: KVCacheConfig,
    // Entries in insertion order; eviction may remove from the middle.
    entries: VecDeque<CacheEntry>,
    // Next seq_idx to assign; never reused after eviction.
    next_seq: usize,
}

233impl CacheManager {
234    /// Create a new cache manager with the given configuration.
235    pub fn new(config: KVCacheConfig) -> Self {
236        Self {
237            config,
238            entries: VecDeque::new(),
239            next_seq: 0,
240        }
241    }
242
243    /// Number of entries currently in the cache.
244    pub fn len(&self) -> usize {
245        self.entries.len()
246    }
247
248    /// Whether the cache is empty.
249    pub fn is_empty(&self) -> bool {
250        self.entries.is_empty()
251    }
252
253    /// Append a new key-value pair to the cache.
254    ///
255    /// `key` and `value` must each have length `num_heads * head_dim`.
256    /// `_layer_idx` is used by the PyramidKV eviction policy to determine
257    /// the per-layer budget.
258    pub fn append(&mut self, key: &[f32], value: &[f32], _layer_idx: usize) {
259        let bits = self.config.quantization_bits;
260        let heads = self.config.num_heads;
261
262        let qk = quantize_asymmetric(key, heads, bits);
263        let qv = quantize_asymmetric(value, heads, bits);
264
265        self.entries.push_back(CacheEntry {
266            key: qk,
267            value: qv,
268            attention_score: 0.0,
269            seq_idx: self.next_seq,
270        });
271        self.next_seq += 1;
272
273        // Auto-evict if over budget.
274        if self.entries.len() > self.config.max_seq_len {
275            self.evict(self.config.max_seq_len);
276        }
277    }
278
279    /// Retrieve dequantized key-value pairs at the given logical positions.
280    ///
281    /// Returns `(keys, values)` where each inner `Vec<f32>` has length
282    /// `num_heads * head_dim`.
283    pub fn get(&self, positions: &[usize]) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
284        let heads = self.config.num_heads;
285        let mut keys = Vec::with_capacity(positions.len());
286        let mut values = Vec::with_capacity(positions.len());
287
288        for &pos in positions {
289            if pos < self.entries.len() {
290                let entry = &self.entries[pos];
291                keys.push(dequantize(&entry.key, heads));
292                values.push(dequantize(&entry.value, heads));
293            }
294        }
295        (keys, values)
296    }
297
298    /// Evict entries until the cache contains at most `budget` entries.
299    pub fn evict(&mut self, budget: usize) {
300        if self.entries.len() <= budget {
301            return;
302        }
303
304        match &self.config.eviction_policy {
305            EvictionPolicy::H2O => self.evict_h2o(budget),
306            EvictionPolicy::SlidingWindow { window, sink } => {
307                self.evict_sliding_window(budget, *window, *sink);
308            }
309            EvictionPolicy::PyramidKV { .. } => {
310                // PyramidKV adjusts budget externally per layer; here we just
311                // fall back to H2O-style eviction within the given budget.
312                self.evict_h2o(budget);
313            }
314        }
315    }
316
317    /// H2O eviction: remove entries with the lowest cumulative attention score.
318    fn evict_h2o(&mut self, budget: usize) {
319        while self.entries.len() > budget {
320            // Find index of entry with the lowest attention score.
321            let min_idx = self
322                .entries
323                .iter()
324                .enumerate()
325                .min_by(|(_, a), (_, b)| {
326                    a.attention_score
327                        .partial_cmp(&b.attention_score)
328                        .unwrap_or(std::cmp::Ordering::Equal)
329                })
330                .map(|(i, _)| i)
331                .unwrap();
332            self.entries.remove(min_idx);
333        }
334    }
335
336    /// Sliding window eviction: keep first `sink` tokens and last `window` tokens.
337    fn evict_sliding_window(&mut self, budget: usize, window: usize, sink: usize) {
338        let effective_budget = budget.min(sink + window);
339        if self.entries.len() <= effective_budget {
340            return;
341        }
342
343        // Identify indices to keep: first `sink` and last `window`.
344        let len = self.entries.len();
345        let keep_end = window.min(len);
346        let keep_start = sink.min(len.saturating_sub(keep_end));
347
348        let mut kept: VecDeque<CacheEntry> = VecDeque::with_capacity(keep_start + keep_end);
349        for i in 0..keep_start {
350            kept.push_back(self.entries[i].clone());
351        }
352        for i in (len - keep_end)..len {
353            if i >= keep_start {
354                kept.push_back(self.entries[i].clone());
355            }
356        }
357        self.entries = kept;
358    }
359
360    /// Update cumulative attention scores for the H2O eviction policy.
361    ///
362    /// `scores` should have one value per current cache entry.
363    pub fn update_attention_scores(&mut self, scores: &[f64]) {
364        for (entry, &s) in self.entries.iter_mut().zip(scores.iter()) {
365            entry.attention_score += s;
366        }
367    }
368
369    /// Compute the budget for a given layer under PyramidKV.
370    ///
371    /// Lower layers get a proportionally larger share of `max_seq_len`.
372    pub fn pyramid_budget(&self, layer_idx: usize, total_layers: usize) -> usize {
373        if total_layers == 0 {
374            return self.config.max_seq_len;
375        }
376        let weight = (total_layers - layer_idx) as f64 / total_layers as f64;
377        let sum_weights: f64 = (1..=total_layers).map(|i| i as f64 / total_layers as f64).sum();
378        let budget = (weight / sum_weights) * self.config.max_seq_len as f64;
379        (budget.ceil() as usize).max(1)
380    }
381
382    /// Compression ratio: `f32 bytes / quantized bytes` for a single entry.
383    ///
384    /// A 4-bit cache over f32 baseline yields roughly 8x compression
385    /// (before accounting for scale/zero-point overhead).
386    pub fn compression_ratio(&self) -> f64 {
387        let total_elements = self.config.num_heads * self.config.head_dim;
388        let f32_bytes = (total_elements * 4 * 2) as f64; // K + V
389        let q_bytes = self.entry_quantized_bytes() as f64;
390        if q_bytes < f64::EPSILON {
391            return 0.0;
392        }
393        f32_bytes / q_bytes
394    }
395
396    /// Bytes consumed by the quantized data of a single KV entry (approximate).
397    fn entry_quantized_bytes(&self) -> usize {
398        let elements = self.config.num_heads * self.config.head_dim;
399        // 1 byte per element (unpacked) + scales + zero_points per head, times 2 (K+V).
400        let per_tensor = elements + self.config.num_heads * 4 * 2; // scale + zp as f32
401        per_tensor * 2
402    }
403
404    /// Approximate total memory usage of the cache in bytes.
405    pub fn memory_bytes(&self) -> usize {
406        self.entries.len() * self.entry_quantized_bytes()
407    }
408}
409
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: small 2-head, 4-dim cache with an 8-entry budget.
    fn make_config(bits: u8, policy: EvictionPolicy) -> KVCacheConfig {
        KVCacheConfig {
            max_seq_len: 8,
            num_heads: 2,
            head_dim: 4,
            quantization_bits: bits,
            eviction_policy: policy,
        }
    }

    // -- Quantization roundtrip tests --

    #[test]
    fn test_quantize_roundtrip_4bit() {
        let data: Vec<f32> = vec![0.0, 0.5, 1.0, -1.0, 0.25, -0.5, 0.75, -0.25];
        let qt = quantize_asymmetric(&data, 2, 4);
        let restored = dequantize(&qt, 2);
        // 4-bit gives 16 levels over a 2.0-wide range, so the worst-case
        // quantization step is ~0.13; 0.15 allows a little slack.
        for (orig, rest) in data.iter().zip(restored.iter()) {
            assert!((orig - rest).abs() < 0.15, "4-bit error too large: {orig} vs {rest}");
        }
    }

    #[test]
    fn test_quantize_roundtrip_3bit() {
        let data: Vec<f32> = vec![0.0, 0.5, 1.0, -1.0, 0.3, -0.7, 0.8, -0.2];
        let qt = quantize_asymmetric(&data, 2, 3);
        let restored = dequantize(&qt, 2);
        // 3-bit has only 8 levels so error is larger.
        for (orig, rest) in data.iter().zip(restored.iter()) {
            assert!((orig - rest).abs() < 0.35, "3-bit error too large: {orig} vs {rest}");
        }
    }

    #[test]
    fn test_symmetric_quantize_roundtrip() {
        let data: Vec<f32> = vec![0.0, 0.5, -0.5, 1.0, -1.0];
        let (qdata, scale) = quantize_symmetric(&data, 4);
        let restored = dequantize_symmetric(&qdata, scale, 4);
        for (orig, rest) in data.iter().zip(restored.iter()) {
            assert!((orig - rest).abs() < 0.2, "sym roundtrip: {orig} vs {rest}");
        }
    }

    #[test]
    fn test_bankers_rounding() {
        // Exact halves round to the even neighbour; others to the nearest.
        assert_eq!(round_to_nearest_even(2.5), 2.0);
        assert_eq!(round_to_nearest_even(3.5), 4.0);
        assert_eq!(round_to_nearest_even(4.5), 4.0);
        assert_eq!(round_to_nearest_even(1.3), 1.0);
        assert_eq!(round_to_nearest_even(1.7), 2.0);
    }

    // -- Cache operations --

    #[test]
    fn test_cache_append_and_get() {
        let cfg = make_config(4, EvictionPolicy::H2O);
        let mut mgr = CacheManager::new(cfg);
        let k = vec![1.0_f32; 8];
        let v = vec![-1.0_f32; 8];
        mgr.append(&k, &v, 0);
        assert_eq!(mgr.len(), 1);

        let (keys, vals) = mgr.get(&[0]);
        assert_eq!(keys.len(), 1);
        assert_eq!(vals.len(), 1);
        assert_eq!(keys[0].len(), 8);
    }

    #[test]
    fn test_cache_empty() {
        let cfg = make_config(4, EvictionPolicy::H2O);
        let mgr = CacheManager::new(cfg);
        assert!(mgr.is_empty());
        assert_eq!(mgr.len(), 0);
        // Out-of-range positions are skipped rather than panicking.
        let (k, v) = mgr.get(&[0]);
        assert!(k.is_empty());
        assert!(v.is_empty());
    }

    #[test]
    fn test_h2o_eviction() {
        let cfg = make_config(4, EvictionPolicy::H2O);
        let mut mgr = CacheManager::new(cfg);

        // Insert 4 entries.
        for i in 0..4 {
            let k = vec![i as f32; 8];
            let v = vec![i as f32; 8];
            mgr.append(&k, &v, 0);
        }
        // Give them different attention scores: entry 1 gets the lowest.
        mgr.update_attention_scores(&[5.0, 1.0, 3.0, 4.0]);

        // Evict down to 3.
        mgr.evict(3);
        assert_eq!(mgr.len(), 3);

        // The entry with score 1.0 (index 1) should have been removed.
        // Remaining scores should be 5.0, 3.0, 4.0.
        // (Exact float comparison is fine: these values were assigned, not computed.)
        let scores: Vec<f64> = mgr.entries.iter().map(|e| e.attention_score).collect();
        assert!(!scores.contains(&1.0));
    }

    #[test]
    fn test_sliding_window_eviction() {
        let mut cfg = make_config(4, EvictionPolicy::SlidingWindow { window: 3, sink: 2 });
        cfg.max_seq_len = 100; // large so auto-evict doesn't trigger
        let mut mgr = CacheManager::new(cfg);

        // Insert 10 entries with sequential values.
        for i in 0..10 {
            let k = vec![i as f32; 8];
            let v = vec![i as f32; 8];
            mgr.append(&k, &v, 0);
        }
        assert_eq!(mgr.len(), 10);

        // Evict down to 5 (keep sink=2 and window=3).
        mgr.evict(5);
        assert_eq!(mgr.len(), 5);

        // First 2 entries (sink) and last 3 entries should remain;
        // seq_idx records insertion order, so it identifies survivors.
        let seq_idxs: Vec<usize> = mgr.entries.iter().map(|e| e.seq_idx).collect();
        assert_eq!(seq_idxs[0], 0);
        assert_eq!(seq_idxs[1], 1);
        assert!(seq_idxs.contains(&7));
        assert!(seq_idxs.contains(&8));
        assert!(seq_idxs.contains(&9));
    }

    #[test]
    fn test_compression_ratio() {
        let cfg = make_config(4, EvictionPolicy::H2O);
        let mgr = CacheManager::new(cfg);
        let ratio = mgr.compression_ratio();
        // 4-bit in our unpacked scheme: each element uses 1 byte vs 4 bytes in f32,
        // but we also store scales/zero-points. Should still be > 1.0.
        assert!(ratio > 1.0, "compression ratio should be > 1.0, got {ratio}");
    }

    #[test]
    fn test_memory_bytes() {
        let cfg = make_config(4, EvictionPolicy::H2O);
        let mut mgr = CacheManager::new(cfg);
        assert_eq!(mgr.memory_bytes(), 0);

        let k = vec![0.5_f32; 8];
        let v = vec![-0.5_f32; 8];
        mgr.append(&k, &v, 0);
        assert!(mgr.memory_bytes() > 0);

        // Memory usage scales linearly with entry count.
        let bytes_one = mgr.memory_bytes();
        mgr.append(&k, &v, 0);
        assert_eq!(mgr.memory_bytes(), bytes_one * 2);
    }

    #[test]
    fn test_auto_eviction_on_append() {
        let cfg = make_config(4, EvictionPolicy::H2O);
        // max_seq_len = 8
        let mut mgr = CacheManager::new(cfg);
        for i in 0..12 {
            let k = vec![i as f32; 8];
            let v = vec![i as f32; 8];
            mgr.append(&k, &v, 0);
        }
        // Should never exceed max_seq_len.
        assert!(mgr.len() <= 8);
    }

    #[test]
    fn test_pyramid_budget() {
        let cfg = make_config(4, EvictionPolicy::PyramidKV { total_layers: 4 });
        let mgr = CacheManager::new(cfg);
        let b0 = mgr.pyramid_budget(0, 4);
        let b3 = mgr.pyramid_budget(3, 4);
        // Lower layers should get a larger budget.
        assert!(b0 > b3, "layer 0 budget ({b0}) should exceed layer 3 ({b3})");
    }

    #[test]
    fn test_single_entry_operations() {
        let cfg = make_config(3, EvictionPolicy::H2O);
        let mut mgr = CacheManager::new(cfg);
        let k = vec![0.42_f32; 8];
        let v = vec![-0.42_f32; 8];
        mgr.append(&k, &v, 0);

        // Evicting to a budget >= len is a no-op.
        mgr.update_attention_scores(&[1.0]);
        mgr.evict(1);
        assert_eq!(mgr.len(), 1);

        let (keys, vals) = mgr.get(&[0]);
        assert_eq!(keys.len(), 1);
        assert_eq!(vals.len(), 1);
    }
}