1use std::collections::VecDeque;
40
/// Policy consulted by `CacheManager::evict` to decide which cached
/// entries are dropped once the cache exceeds its budget.
#[derive(Debug, Clone, PartialEq)]
pub enum EvictionPolicy {
    /// Heavy-Hitter (H2O) style: repeatedly evict the entry with the
    /// lowest accumulated attention score.
    H2O,
    /// Keep the earliest `sink` entries plus the `window` most recent
    /// ones; everything in between is evicted.
    SlidingWindow {
        /// Number of most-recent entries to retain.
        window: usize,
        /// Number of initial "attention sink" entries always retained.
        sink: usize,
    },
    /// Layer-aware budgeting (see `CacheManager::pyramid_budget`);
    /// within a layer, eviction currently falls back to H2O scoring.
    PyramidKV {
        /// Total number of model layers used to scale per-layer budgets.
        total_layers: usize,
    },
}
67
/// Static configuration for a `CacheManager`.
#[derive(Debug, Clone)]
pub struct KVCacheConfig {
    /// Entry budget; `CacheManager::append` auto-evicts past this.
    pub max_seq_len: usize,
    /// Attention heads per tensor; quantization parameters (scale,
    /// zero-point) are kept per head.
    pub num_heads: usize,
    /// Elements per head; used for memory accounting.
    pub head_dim: usize,
    /// Bit width handed to the quantizers (codes are still stored one
    /// per byte regardless of this value).
    pub quantization_bits: u8,
    /// Strategy used when the cache must shrink.
    pub eviction_policy: EvictionPolicy,
}
82
/// A tensor quantized per head with asymmetric (scale + zero-point)
/// parameters, as produced by `quantize_asymmetric`.
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Quantized codes, one byte per element even when `bits < 8`
    /// (no sub-byte packing).
    pub data: Vec<u8>,
    /// Per-head scale factors (one entry per head).
    pub scales: Vec<f32>,
    /// Per-head zero points; fractional values are allowed because they
    /// are only used in dequantization arithmetic, never stored as codes.
    pub zero_points: Vec<f32>,
    /// Bit width the codes were quantized to.
    pub bits: u8,
}
101
/// Rounds `x` to the nearest integer, breaking exact `.5` ties toward
/// the even neighbour (banker's rounding). This avoids the systematic
/// upward bias of round-half-away-from-zero during quantization.
#[inline]
pub fn round_to_nearest_even(x: f32) -> f32 {
    let nearest = x.round();
    let frac = (x - x.floor()).abs();
    // Not an exact .5 tie: `f32::round` already picked the nearest integer.
    if (frac - 0.5).abs() >= f32::EPSILON {
        return nearest;
    }
    // Exact tie: `round` went away from zero; if that landed on an odd
    // integer, step one back toward zero to reach the even neighbour.
    if (nearest as i64) % 2 == 0 {
        nearest
    } else if x > 0.0 {
        nearest - 1.0
    } else {
        nearest + 1.0
    }
}
120
121pub fn quantize_asymmetric(tensor: &[f32], num_heads: usize, bits: u8) -> QuantizedTensor {
127 let head_dim = tensor.len() / num_heads;
128 let qmax = ((1u32 << bits) - 1) as f32;
129
130 let mut data = Vec::with_capacity(tensor.len());
131 let mut scales = Vec::with_capacity(num_heads);
132 let mut zero_points = Vec::with_capacity(num_heads);
133
134 for h in 0..num_heads {
135 let start = h * head_dim;
136 let end = start + head_dim;
137 let channel = &tensor[start..end];
138
139 let min_val = channel.iter().copied().fold(f32::INFINITY, f32::min);
140 let max_val = channel.iter().copied().fold(f32::NEG_INFINITY, f32::max);
141
142 let range = max_val - min_val;
143 let scale = if range.abs() < f32::EPSILON { 1.0 } else { range / qmax };
144 let zp = if range.abs() < f32::EPSILON { 0.0 } else { -min_val / scale };
145
146 scales.push(scale);
147 zero_points.push(zp);
148
149 for &v in channel {
150 let q = round_to_nearest_even(v / scale + zp).clamp(0.0, qmax);
151 data.push(q as u8);
152 }
153 }
154
155 QuantizedTensor { data, scales, zero_points, bits }
156}
157
158pub fn quantize_symmetric(tensor: &[f32], bits: u8) -> (Vec<u8>, f32) {
166 assert!(bits >= 2 && bits <= 8, "quantize_symmetric: bits must be in [2, 8], got {}", bits);
167 let qmax = ((1u32 << (bits - 1)) - 1) as f32;
168 let abs_max = tensor.iter().copied().map(f32::abs).fold(0.0_f32, f32::max);
169 let scale = if abs_max < f32::EPSILON { 1.0 } else { abs_max / qmax };
170 let offset = (1u32 << (bits - 1)) as f32; let data: Vec<u8> = tensor
173 .iter()
174 .map(|&v| {
175 let q = round_to_nearest_even(v / scale + offset).clamp(0.0, (1u32 << bits) as f32 - 1.0);
176 q as u8
177 })
178 .collect();
179 (data, scale)
180}
181
/// Inverse of `quantize_symmetric`: subtracts the mid-range offset
/// implied by `bits` and multiplies by `scale`.
pub fn dequantize_symmetric(data: &[u8], scale: f32, bits: u8) -> Vec<f32> {
    let offset = (1u32 << (bits - 1)) as f32;
    let mut out = Vec::with_capacity(data.len());
    for &code in data {
        out.push((code as f32 - offset) * scale);
    }
    out
}
187
188pub fn dequantize(qt: &QuantizedTensor, num_heads: usize) -> Vec<f32> {
190 let head_dim = qt.data.len() / num_heads;
191 let mut out = Vec::with_capacity(qt.data.len());
192 for h in 0..num_heads {
193 let start = h * head_dim;
194 let end = start + head_dim;
195 let scale = qt.scales[h];
196 let zp = qt.zero_points[h];
197 for &q in &qt.data[start..end] {
198 out.push(scale * (q as f32 - zp));
199 }
200 }
201 out
202}
203
/// One cached sequence position: quantized key/value plus eviction
/// bookkeeping.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// Quantized key tensor for this position.
    key: QuantizedTensor,
    /// Quantized value tensor for this position.
    value: QuantizedTensor,
    /// Attention mass accumulated via
    /// `CacheManager::update_attention_scores`; drives H2O eviction.
    attention_score: f64,
    /// Monotonic insertion index assigned at append time (stable across
    /// evictions, unlike the entry's position in the deque).
    seq_idx: usize,
}
218
/// Quantized KV cache with pluggable eviction.
///
/// Entries are kept in insertion order; eviction may remove entries
/// from anywhere in the deque depending on the configured
/// `EvictionPolicy`.
pub struct CacheManager {
    /// Quantization and eviction settings.
    config: KVCacheConfig,
    /// Cached positions in insertion order.
    entries: VecDeque<CacheEntry>,
    /// Next `seq_idx` to assign on append.
    next_seq: usize,
}
232
233impl CacheManager {
234 pub fn new(config: KVCacheConfig) -> Self {
236 Self {
237 config,
238 entries: VecDeque::new(),
239 next_seq: 0,
240 }
241 }
242
243 pub fn len(&self) -> usize {
245 self.entries.len()
246 }
247
248 pub fn is_empty(&self) -> bool {
250 self.entries.is_empty()
251 }
252
253 pub fn append(&mut self, key: &[f32], value: &[f32], _layer_idx: usize) {
259 let bits = self.config.quantization_bits;
260 let heads = self.config.num_heads;
261
262 let qk = quantize_asymmetric(key, heads, bits);
263 let qv = quantize_asymmetric(value, heads, bits);
264
265 self.entries.push_back(CacheEntry {
266 key: qk,
267 value: qv,
268 attention_score: 0.0,
269 seq_idx: self.next_seq,
270 });
271 self.next_seq += 1;
272
273 if self.entries.len() > self.config.max_seq_len {
275 self.evict(self.config.max_seq_len);
276 }
277 }
278
279 pub fn get(&self, positions: &[usize]) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
284 let heads = self.config.num_heads;
285 let mut keys = Vec::with_capacity(positions.len());
286 let mut values = Vec::with_capacity(positions.len());
287
288 for &pos in positions {
289 if pos < self.entries.len() {
290 let entry = &self.entries[pos];
291 keys.push(dequantize(&entry.key, heads));
292 values.push(dequantize(&entry.value, heads));
293 }
294 }
295 (keys, values)
296 }
297
298 pub fn evict(&mut self, budget: usize) {
300 if self.entries.len() <= budget {
301 return;
302 }
303
304 match &self.config.eviction_policy {
305 EvictionPolicy::H2O => self.evict_h2o(budget),
306 EvictionPolicy::SlidingWindow { window, sink } => {
307 self.evict_sliding_window(budget, *window, *sink);
308 }
309 EvictionPolicy::PyramidKV { .. } => {
310 self.evict_h2o(budget);
313 }
314 }
315 }
316
317 fn evict_h2o(&mut self, budget: usize) {
319 while self.entries.len() > budget {
320 let min_idx = self
322 .entries
323 .iter()
324 .enumerate()
325 .min_by(|(_, a), (_, b)| {
326 a.attention_score
327 .partial_cmp(&b.attention_score)
328 .unwrap_or(std::cmp::Ordering::Equal)
329 })
330 .map(|(i, _)| i)
331 .unwrap();
332 self.entries.remove(min_idx);
333 }
334 }
335
336 fn evict_sliding_window(&mut self, budget: usize, window: usize, sink: usize) {
338 let effective_budget = budget.min(sink + window);
339 if self.entries.len() <= effective_budget {
340 return;
341 }
342
343 let len = self.entries.len();
345 let keep_end = window.min(len);
346 let keep_start = sink.min(len.saturating_sub(keep_end));
347
348 let mut kept: VecDeque<CacheEntry> = VecDeque::with_capacity(keep_start + keep_end);
349 for i in 0..keep_start {
350 kept.push_back(self.entries[i].clone());
351 }
352 for i in (len - keep_end)..len {
353 if i >= keep_start {
354 kept.push_back(self.entries[i].clone());
355 }
356 }
357 self.entries = kept;
358 }
359
360 pub fn update_attention_scores(&mut self, scores: &[f64]) {
364 for (entry, &s) in self.entries.iter_mut().zip(scores.iter()) {
365 entry.attention_score += s;
366 }
367 }
368
369 pub fn pyramid_budget(&self, layer_idx: usize, total_layers: usize) -> usize {
373 if total_layers == 0 {
374 return self.config.max_seq_len;
375 }
376 let weight = (total_layers - layer_idx) as f64 / total_layers as f64;
377 let sum_weights: f64 = (1..=total_layers).map(|i| i as f64 / total_layers as f64).sum();
378 let budget = (weight / sum_weights) * self.config.max_seq_len as f64;
379 (budget.ceil() as usize).max(1)
380 }
381
382 pub fn compression_ratio(&self) -> f64 {
387 let total_elements = self.config.num_heads * self.config.head_dim;
388 let f32_bytes = (total_elements * 4 * 2) as f64; let q_bytes = self.entry_quantized_bytes() as f64;
390 if q_bytes < f64::EPSILON {
391 return 0.0;
392 }
393 f32_bytes / q_bytes
394 }
395
396 fn entry_quantized_bytes(&self) -> usize {
398 let elements = self.config.num_heads * self.config.head_dim;
399 let per_tensor = elements + self.config.num_heads * 4 * 2; per_tensor * 2
402 }
403
404 pub fn memory_bytes(&self) -> usize {
406 self.entries.len() * self.entry_quantized_bytes()
407 }
408}
409
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture: 2 heads x 4 dims with an 8-entry budget.
    fn make_config(bits: u8, policy: EvictionPolicy) -> KVCacheConfig {
        KVCacheConfig {
            max_seq_len: 8,
            num_heads: 2,
            head_dim: 4,
            quantization_bits: bits,
            eviction_policy: policy,
        }
    }

    // 4-bit quantization keeps per-element error within the channel
    // range divided by 15 codes (tolerance 0.15 for range [-1, 1]).
    #[test]
    fn test_quantize_roundtrip_4bit() {
        let data: Vec<f32> = vec![0.0, 0.5, 1.0, -1.0, 0.25, -0.5, 0.75, -0.25];
        let qt = quantize_asymmetric(&data, 2, 4);
        let restored = dequantize(&qt, 2);
        for (orig, rest) in data.iter().zip(restored.iter()) {
            assert!((orig - rest).abs() < 0.15, "4-bit error too large: {orig} vs {rest}");
        }
    }

    // 3-bit gives only 7 codes, so the tolerance is looser (0.35).
    #[test]
    fn test_quantize_roundtrip_3bit() {
        let data: Vec<f32> = vec![0.0, 0.5, 1.0, -1.0, 0.3, -0.7, 0.8, -0.2];
        let qt = quantize_asymmetric(&data, 2, 3);
        let restored = dequantize(&qt, 2);
        for (orig, rest) in data.iter().zip(restored.iter()) {
            assert!((orig - rest).abs() < 0.35, "3-bit error too large: {orig} vs {rest}");
        }
    }

    #[test]
    fn test_symmetric_quantize_roundtrip() {
        let data: Vec<f32> = vec![0.0, 0.5, -0.5, 1.0, -1.0];
        let (qdata, scale) = quantize_symmetric(&data, 4);
        let restored = dequantize_symmetric(&qdata, scale, 4);
        for (orig, rest) in data.iter().zip(restored.iter()) {
            assert!((orig - rest).abs() < 0.2, "sym roundtrip: {orig} vs {rest}");
        }
    }

    // Exact .5 ties must round to the even neighbour; non-ties round
    // normally.
    #[test]
    fn test_bankers_rounding() {
        assert_eq!(round_to_nearest_even(2.5), 2.0);
        assert_eq!(round_to_nearest_even(3.5), 4.0);
        assert_eq!(round_to_nearest_even(4.5), 4.0);
        assert_eq!(round_to_nearest_even(1.3), 1.0);
        assert_eq!(round_to_nearest_even(1.7), 2.0);
    }

    #[test]
    fn test_cache_append_and_get() {
        let cfg = make_config(4, EvictionPolicy::H2O);
        let mut mgr = CacheManager::new(cfg);
        let k = vec![1.0_f32; 8];
        let v = vec![-1.0_f32; 8];
        mgr.append(&k, &v, 0);
        assert_eq!(mgr.len(), 1);

        let (keys, vals) = mgr.get(&[0]);
        assert_eq!(keys.len(), 1);
        assert_eq!(vals.len(), 1);
        assert_eq!(keys[0].len(), 8);
    }

    // get() on an empty cache skips the out-of-range position rather
    // than panicking.
    #[test]
    fn test_cache_empty() {
        let cfg = make_config(4, EvictionPolicy::H2O);
        let mgr = CacheManager::new(cfg);
        assert!(mgr.is_empty());
        assert_eq!(mgr.len(), 0);
        let (k, v) = mgr.get(&[0]);
        assert!(k.is_empty());
        assert!(v.is_empty());
    }

    // The entry with the lowest accumulated score (1.0) must be the one
    // evicted by the H2O policy.
    #[test]
    fn test_h2o_eviction() {
        let cfg = make_config(4, EvictionPolicy::H2O);
        let mut mgr = CacheManager::new(cfg);

        for i in 0..4 {
            let k = vec![i as f32; 8];
            let v = vec![i as f32; 8];
            mgr.append(&k, &v, 0);
        }
        mgr.update_attention_scores(&[5.0, 1.0, 3.0, 4.0]);

        mgr.evict(3);
        assert_eq!(mgr.len(), 3);

        let scores: Vec<f64> = mgr.entries.iter().map(|e| e.attention_score).collect();
        assert!(!scores.contains(&1.0));
    }

    // With sink=2 and window=3, eviction keeps seq 0-1 plus seq 7-9.
    #[test]
    fn test_sliding_window_eviction() {
        let mut cfg = make_config(4, EvictionPolicy::SlidingWindow { window: 3, sink: 2 });
        // Raise the budget so append() never auto-evicts before the
        // explicit evict() call below.
        cfg.max_seq_len = 100;
        let mut mgr = CacheManager::new(cfg);

        for i in 0..10 {
            let k = vec![i as f32; 8];
            let v = vec![i as f32; 8];
            mgr.append(&k, &v, 0);
        }
        assert_eq!(mgr.len(), 10);

        mgr.evict(5);
        assert_eq!(mgr.len(), 5);

        let seq_idxs: Vec<usize> = mgr.entries.iter().map(|e| e.seq_idx).collect();
        assert_eq!(seq_idxs[0], 0);
        assert_eq!(seq_idxs[1], 1);
        assert!(seq_idxs.contains(&7));
        assert!(seq_idxs.contains(&8));
        assert!(seq_idxs.contains(&9));
    }

    // 4-bit codes stored as u8 still beat raw f32 storage (< 4x but > 1x
    // once per-head scale/zero-point overhead is counted).
    #[test]
    fn test_compression_ratio() {
        let cfg = make_config(4, EvictionPolicy::H2O);
        let mgr = CacheManager::new(cfg);
        let ratio = mgr.compression_ratio();
        assert!(ratio > 1.0, "compression ratio should be > 1.0, got {ratio}");
    }

    // memory_bytes scales linearly with the entry count.
    #[test]
    fn test_memory_bytes() {
        let cfg = make_config(4, EvictionPolicy::H2O);
        let mut mgr = CacheManager::new(cfg);
        assert_eq!(mgr.memory_bytes(), 0);

        let k = vec![0.5_f32; 8];
        let v = vec![-0.5_f32; 8];
        mgr.append(&k, &v, 0);
        assert!(mgr.memory_bytes() > 0);

        let bytes_one = mgr.memory_bytes();
        mgr.append(&k, &v, 0);
        assert_eq!(mgr.memory_bytes(), bytes_one * 2);
    }

    // append() must keep the cache at or below max_seq_len (8) even
    // after 12 inserts.
    #[test]
    fn test_auto_eviction_on_append() {
        let cfg = make_config(4, EvictionPolicy::H2O);
        let mut mgr = CacheManager::new(cfg);
        for i in 0..12 {
            let k = vec![i as f32; 8];
            let v = vec![i as f32; 8];
            mgr.append(&k, &v, 0);
        }
        assert!(mgr.len() <= 8);
    }

    // Earlier (shallower) layers get larger pyramid budgets.
    #[test]
    fn test_pyramid_budget() {
        let cfg = make_config(4, EvictionPolicy::PyramidKV { total_layers: 4 });
        let mgr = CacheManager::new(cfg);
        let b0 = mgr.pyramid_budget(0, 4);
        let b3 = mgr.pyramid_budget(3, 4);
        assert!(b0 > b3, "layer 0 budget ({b0}) should exceed layer 3 ({b3})");
    }

    // Eviction with a budget equal to the length must be a no-op.
    #[test]
    fn test_single_entry_operations() {
        let cfg = make_config(3, EvictionPolicy::H2O);
        let mut mgr = CacheManager::new(cfg);
        let k = vec![0.42_f32; 8];
        let v = vec![-0.42_f32; 8];
        mgr.append(&k, &v, 0);

        mgr.update_attention_scores(&[1.0]);
        mgr.evict(1);
        assert_eq!(mgr.len(), 1);

        let (keys, vals) = mgr.get(&[0]);
        assert_eq!(keys.len(), 1);
        assert_eq!(vals.len(), 1);
    }
}