1use std::collections::VecDeque;
40
/// Strategy used by [`CacheManager`] to choose which cached entries to drop
/// when the cache has to shrink.
#[derive(Debug, Clone, PartialEq)]
pub enum EvictionPolicy {
    /// Score-based eviction: repeatedly drop the entry with the lowest
    /// accumulated attention score.
    H2O,
    /// Position-based eviction: keep entries from the front of the cache plus
    /// the most recent entries, dropping what lies in between.
    SlidingWindow {
        /// Number of most recent entries to keep.
        window: usize,
        /// Number of entries to keep from the front of the cache.
        sink: usize,
    },
    /// Per-layer budgeting variant. Within this file its eviction step is the
    /// same score-based one used for [`EvictionPolicy::H2O`].
    PyramidKV {
        /// Layer count associated with the policy. It is not read by the
        /// eviction step; presumably it is the value callers pass to
        /// `CacheManager::pyramid_budget` — TODO confirm against callers.
        total_layers: usize,
    },
}
67
/// Settings that control how a [`CacheManager`] quantizes and evicts entries.
#[derive(Debug, Clone)]
pub struct KVCacheConfig {
    /// Maximum number of entries tolerated after an append; exceeding it
    /// triggers eviction with this value as the budget.
    pub max_seq_len: usize,
    /// Number of equally sized per-head groups each key/value tensor is split
    /// into for quantization and dequantization.
    pub num_heads: usize,
    /// Elements per head. Only used for the size estimates
    /// (`compression_ratio`, `memory_bytes`); appended tensors are not checked
    /// against it.
    pub head_dim: usize,
    /// Bit width handed to the quantizer for both keys and values.
    pub quantization_bits: u8,
    /// Policy applied whenever the cache must shrink.
    pub eviction_policy: EvictionPolicy,
}
82
/// A tensor quantized per head with an affine (scale / zero-point) mapping.
///
/// A stored code `q` belonging to head `h` is restored as
/// `scales[h] * (q - zero_points[h])`.
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Quantized codes, one byte per element regardless of `bits`, laid out
    /// head after head.
    pub data: Vec<u8>,
    /// One scale factor per head.
    pub scales: Vec<f32>,
    /// One zero point per head, kept as an unrounded `f32`.
    pub zero_points: Vec<f32>,
    /// Bit width the codes were quantized with.
    pub bits: u8,
}
101
/// Rounds `x` to the nearest integer, resolving exact halfway cases towards
/// the even neighbour (banker's rounding).
///
/// Only inputs whose fractional part is exactly `0.5` count as ties; values
/// that are merely close to a half (for example `0.5 + 2^-24`) are rounded to
/// their true nearest integer. NaN and infinities are returned unchanged.
#[inline]
pub fn round_to_nearest_even(x: f32) -> f32 {
    // `round` resolves ties away from zero, `trunc` always moves towards zero,
    // so for a tie these are the two candidate results.
    let away = x.round();
    let toward = x.trunc();

    // The fractional part of a finite float is exactly representable, so an
    // exact comparison identifies true ties without an epsilon window.
    // For NaN / infinite input the difference is NaN and the test is false.
    if (x - toward).abs() == 0.5 {
        // Exactly one of the two neighbouring integers is even.
        if toward % 2.0 == 0.0 {
            toward
        } else {
            away
        }
    } else {
        away
    }
}
124
125pub fn quantize_asymmetric(tensor: &[f32], num_heads: usize, bits: u8) -> QuantizedTensor {
131 let head_dim = tensor.len() / num_heads;
132 let qmax = ((1u32 << bits) - 1) as f32;
133
134 let mut data = Vec::with_capacity(tensor.len());
135 let mut scales = Vec::with_capacity(num_heads);
136 let mut zero_points = Vec::with_capacity(num_heads);
137
138 for h in 0..num_heads {
139 let start = h * head_dim;
140 let end = start + head_dim;
141 let channel = &tensor[start..end];
142
143 let min_val = channel.iter().copied().fold(f32::INFINITY, f32::min);
144 let max_val = channel.iter().copied().fold(f32::NEG_INFINITY, f32::max);
145
146 let range = max_val - min_val;
147 let scale = if range.abs() < f32::EPSILON {
148 1.0
149 } else {
150 range / qmax
151 };
152 let zp = if range.abs() < f32::EPSILON {
153 0.0
154 } else {
155 -min_val / scale
156 };
157
158 scales.push(scale);
159 zero_points.push(zp);
160
161 for &v in channel {
162 let q = round_to_nearest_even(v / scale + zp).clamp(0.0, qmax);
163 data.push(q as u8);
164 }
165 }
166
167 QuantizedTensor {
168 data,
169 scales,
170 zero_points,
171 bits,
172 }
173}
174
175pub fn quantize_symmetric(tensor: &[f32], bits: u8) -> (Vec<u8>, f32) {
183 assert!(
184 bits >= 2 && bits <= 8,
185 "quantize_symmetric: bits must be in [2, 8], got {}",
186 bits
187 );
188 let qmax = ((1u32 << (bits - 1)) - 1) as f32;
189 let abs_max = tensor.iter().copied().map(f32::abs).fold(0.0_f32, f32::max);
190 let scale = if abs_max < f32::EPSILON {
191 1.0
192 } else {
193 abs_max / qmax
194 };
195 let offset = (1u32 << (bits - 1)) as f32; let data: Vec<u8> = tensor
198 .iter()
199 .map(|&v| {
200 let q =
201 round_to_nearest_even(v / scale + offset).clamp(0.0, (1u32 << bits) as f32 - 1.0);
202 q as u8
203 })
204 .collect();
205 (data, scale)
206}
207
/// Restores values produced by symmetric quantization.
///
/// Each code is re-centred by subtracting `2^(bits-1)` and then multiplied by
/// `scale`.
pub fn dequantize_symmetric(data: &[u8], scale: f32, bits: u8) -> Vec<f32> {
    let midpoint = (1u32 << (bits - 1)) as f32;

    let mut restored = Vec::with_capacity(data.len());
    for &code in data {
        restored.push((f32::from(code) - midpoint) * scale);
    }
    restored
}
213
214pub fn dequantize(qt: &QuantizedTensor, num_heads: usize) -> Vec<f32> {
216 let head_dim = qt.data.len() / num_heads;
217 let mut out = Vec::with_capacity(qt.data.len());
218 for h in 0..num_heads {
219 let start = h * head_dim;
220 let end = start + head_dim;
221 let scale = qt.scales[h];
222 let zp = qt.zero_points[h];
223 for &q in &qt.data[start..end] {
224 out.push(scale * (q as f32 - zp));
225 }
226 }
227 out
228}
229
/// One cached key/value pair together with the bookkeeping used for eviction.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// Quantized key tensor.
    key: QuantizedTensor,
    /// Quantized value tensor.
    value: QuantizedTensor,
    /// Running sum of the attention scores reported for this entry; starts at
    /// `0.0` when the entry is appended.
    attention_score: f64,
    /// Sequence number assigned at append time from the manager's counter.
    seq_idx: usize,
}
244
/// Holds quantized key/value entries in insertion order and evicts them
/// according to the configured policy.
pub struct CacheManager {
    /// Quantization and eviction settings.
    config: KVCacheConfig,
    /// Cached entries, oldest at the front.
    entries: VecDeque<CacheEntry>,
    /// Sequence number that the next appended entry will receive.
    next_seq: usize,
}
258
259impl CacheManager {
260 pub fn new(config: KVCacheConfig) -> Self {
262 Self {
263 config,
264 entries: VecDeque::new(),
265 next_seq: 0,
266 }
267 }
268
269 pub fn len(&self) -> usize {
271 self.entries.len()
272 }
273
274 pub fn is_empty(&self) -> bool {
276 self.entries.is_empty()
277 }
278
279 pub fn append(&mut self, key: &[f32], value: &[f32], _layer_idx: usize) {
285 let bits = self.config.quantization_bits;
286 let heads = self.config.num_heads;
287
288 let qk = quantize_asymmetric(key, heads, bits);
289 let qv = quantize_asymmetric(value, heads, bits);
290
291 self.entries.push_back(CacheEntry {
292 key: qk,
293 value: qv,
294 attention_score: 0.0,
295 seq_idx: self.next_seq,
296 });
297 self.next_seq += 1;
298
299 if self.entries.len() > self.config.max_seq_len {
301 self.evict(self.config.max_seq_len);
302 }
303 }
304
305 pub fn get(&self, positions: &[usize]) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
310 let heads = self.config.num_heads;
311 let mut keys = Vec::with_capacity(positions.len());
312 let mut values = Vec::with_capacity(positions.len());
313
314 for &pos in positions {
315 if pos < self.entries.len() {
316 let entry = &self.entries[pos];
317 keys.push(dequantize(&entry.key, heads));
318 values.push(dequantize(&entry.value, heads));
319 }
320 }
321 (keys, values)
322 }
323
324 pub fn evict(&mut self, budget: usize) {
326 if self.entries.len() <= budget {
327 return;
328 }
329
330 match &self.config.eviction_policy {
331 EvictionPolicy::H2O => self.evict_h2o(budget),
332 EvictionPolicy::SlidingWindow { window, sink } => {
333 self.evict_sliding_window(budget, *window, *sink);
334 }
335 EvictionPolicy::PyramidKV { .. } => {
336 self.evict_h2o(budget);
339 }
340 }
341 }
342
343 fn evict_h2o(&mut self, budget: usize) {
345 while self.entries.len() > budget {
346 let min_idx = self
348 .entries
349 .iter()
350 .enumerate()
351 .min_by(|(_, a), (_, b)| {
352 a.attention_score
353 .partial_cmp(&b.attention_score)
354 .unwrap_or(std::cmp::Ordering::Equal)
355 })
356 .map(|(i, _)| i)
357 .unwrap();
358 self.entries.remove(min_idx);
359 }
360 }
361
362 fn evict_sliding_window(&mut self, budget: usize, window: usize, sink: usize) {
364 let effective_budget = budget.min(sink + window);
365 if self.entries.len() <= effective_budget {
366 return;
367 }
368
369 let len = self.entries.len();
371 let keep_end = window.min(len);
372 let keep_start = sink.min(len.saturating_sub(keep_end));
373
374 let mut kept: VecDeque<CacheEntry> = VecDeque::with_capacity(keep_start + keep_end);
375 for i in 0..keep_start {
376 kept.push_back(self.entries[i].clone());
377 }
378 for i in (len - keep_end)..len {
379 if i >= keep_start {
380 kept.push_back(self.entries[i].clone());
381 }
382 }
383 self.entries = kept;
384 }
385
386 pub fn update_attention_scores(&mut self, scores: &[f64]) {
390 for (entry, &s) in self.entries.iter_mut().zip(scores.iter()) {
391 entry.attention_score += s;
392 }
393 }
394
395 pub fn pyramid_budget(&self, layer_idx: usize, total_layers: usize) -> usize {
399 if total_layers == 0 {
400 return self.config.max_seq_len;
401 }
402 let weight = (total_layers - layer_idx) as f64 / total_layers as f64;
403 let sum_weights: f64 = (1..=total_layers)
404 .map(|i| i as f64 / total_layers as f64)
405 .sum();
406 let budget = (weight / sum_weights) * self.config.max_seq_len as f64;
407 (budget.ceil() as usize).max(1)
408 }
409
410 pub fn compression_ratio(&self) -> f64 {
415 let total_elements = self.config.num_heads * self.config.head_dim;
416 let f32_bytes = (total_elements * 4 * 2) as f64; let q_bytes = self.entry_quantized_bytes() as f64;
418 if q_bytes < f64::EPSILON {
419 return 0.0;
420 }
421 f32_bytes / q_bytes
422 }
423
424 fn entry_quantized_bytes(&self) -> usize {
426 let elements = self.config.num_heads * self.config.head_dim;
427 let per_tensor = elements + self.config.num_heads * 4 * 2; per_tensor * 2
430 }
431
432 pub fn memory_bytes(&self) -> usize {
434 self.entries.len() * self.entry_quantized_bytes()
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the small configuration shared by the tests: capacity 8 and
    /// 2 heads of dimension 4, with the given bit width and policy.
    fn make_config(bits: u8, policy: EvictionPolicy) -> KVCacheConfig {
        KVCacheConfig {
            eviction_policy: policy,
            quantization_bits: bits,
            head_dim: 4,
            num_heads: 2,
            max_seq_len: 8,
        }
    }

    /// 4-bit asymmetric quantization restores every element to within 0.15.
    #[test]
    fn test_quantize_roundtrip_4bit() {
        let data = vec![0.0_f32, 0.5, 1.0, -1.0, 0.25, -0.5, 0.75, -0.25];
        let restored = dequantize(&quantize_asymmetric(&data, 2, 4), 2);
        for (orig, rest) in data.iter().zip(&restored) {
            let err = (orig - rest).abs();
            assert!(err < 0.15, "4-bit error too large: {orig} vs {rest}");
        }
    }

    /// 3-bit asymmetric quantization restores every element to within 0.35.
    #[test]
    fn test_quantize_roundtrip_3bit() {
        let data = vec![0.0_f32, 0.5, 1.0, -1.0, 0.3, -0.7, 0.8, -0.2];
        let restored = dequantize(&quantize_asymmetric(&data, 2, 3), 2);
        for (orig, rest) in data.iter().zip(&restored) {
            let err = (orig - rest).abs();
            assert!(err < 0.35, "3-bit error too large: {orig} vs {rest}");
        }
    }

    /// Symmetric 4-bit quantization restores every element to within 0.2.
    #[test]
    fn test_symmetric_quantize_roundtrip() {
        let data = vec![0.0_f32, 0.5, -0.5, 1.0, -1.0];
        let (codes, scale) = quantize_symmetric(&data, 4);
        let restored = dequantize_symmetric(&codes, scale, 4);
        for (orig, rest) in data.iter().zip(&restored) {
            let err = (orig - rest).abs();
            assert!(err < 0.2, "sym roundtrip: {orig} vs {rest}");
        }
    }

    /// Halfway cases go to the even neighbour; everything else to the nearest.
    #[test]
    fn test_bankers_rounding() {
        let cases = [
            (2.5_f32, 2.0_f32),
            (3.5, 4.0),
            (4.5, 4.0),
            (1.3, 1.0),
            (1.7, 2.0),
        ];
        for (input, expected) in cases {
            assert_eq!(round_to_nearest_even(input), expected);
        }
    }

    /// A single appended entry can be read back with the right shape.
    #[test]
    fn test_cache_append_and_get() {
        let mut mgr = CacheManager::new(make_config(4, EvictionPolicy::H2O));
        mgr.append(&vec![1.0_f32; 8], &vec![-1.0_f32; 8], 0);
        assert_eq!(mgr.len(), 1);

        let (keys, vals) = mgr.get(&[0]);
        assert_eq!(keys.len(), 1);
        assert_eq!(vals.len(), 1);
        assert_eq!(keys[0].len(), 8);
    }

    /// A new manager is empty and lookups on it return nothing.
    #[test]
    fn test_cache_empty() {
        let mgr = CacheManager::new(make_config(4, EvictionPolicy::H2O));
        assert!(mgr.is_empty());
        assert_eq!(mgr.len(), 0);

        let (keys, vals) = mgr.get(&[0]);
        assert!(keys.is_empty());
        assert!(vals.is_empty());
    }

    /// Score-based eviction removes the entry with the lowest score.
    #[test]
    fn test_h2o_eviction() {
        let mut mgr = CacheManager::new(make_config(4, EvictionPolicy::H2O));
        for i in 0..4 {
            let payload = vec![i as f32; 8];
            mgr.append(&payload, &payload, 0);
        }
        mgr.update_attention_scores(&[5.0, 1.0, 3.0, 4.0]);

        mgr.evict(3);
        assert_eq!(mgr.len(), 3);

        // The entry that scored 1.0 must be the one that is gone.
        assert!(mgr.entries.iter().all(|e| e.attention_score != 1.0));
    }

    /// Sliding-window eviction keeps the two sinks and the three newest entries.
    #[test]
    fn test_sliding_window_eviction() {
        // Large capacity so that appends alone never trigger eviction.
        let cfg = KVCacheConfig {
            max_seq_len: 100,
            ..make_config(4, EvictionPolicy::SlidingWindow { window: 3, sink: 2 })
        };
        let mut mgr = CacheManager::new(cfg);
        for i in 0..10 {
            let payload = vec![i as f32; 8];
            mgr.append(&payload, &payload, 0);
        }
        assert_eq!(mgr.len(), 10);

        mgr.evict(5);
        assert_eq!(mgr.len(), 5);

        let seq_idxs: Vec<usize> = mgr.entries.iter().map(|e| e.seq_idx).collect();
        assert_eq!((seq_idxs[0], seq_idxs[1]), (0, 1));
        for recent in 7..10 {
            assert!(seq_idxs.contains(&recent));
        }
    }

    /// The configured shape must compress to less than its f32 size.
    #[test]
    fn test_compression_ratio() {
        let ratio = CacheManager::new(make_config(4, EvictionPolicy::H2O)).compression_ratio();
        assert!(
            ratio > 1.0,
            "compression ratio should be > 1.0, got {ratio}"
        );
    }

    /// Reported memory is zero when empty and grows linearly with entries.
    #[test]
    fn test_memory_bytes() {
        let mut mgr = CacheManager::new(make_config(4, EvictionPolicy::H2O));
        assert_eq!(mgr.memory_bytes(), 0);

        let key = vec![0.5_f32; 8];
        let value = vec![-0.5_f32; 8];

        mgr.append(&key, &value, 0);
        let bytes_one = mgr.memory_bytes();
        assert!(bytes_one > 0);

        mgr.append(&key, &value, 0);
        assert_eq!(mgr.memory_bytes(), bytes_one * 2);
    }

    /// Appending past capacity keeps the cache within `max_seq_len`.
    #[test]
    fn test_auto_eviction_on_append() {
        let mut mgr = CacheManager::new(make_config(4, EvictionPolicy::H2O));
        for i in 0..12 {
            let payload = vec![i as f32; 8];
            mgr.append(&payload, &payload, 0);
        }
        assert!(mgr.len() <= 8);
    }

    /// Earlier layers receive a larger budget than later ones.
    #[test]
    fn test_pyramid_budget() {
        let mgr = CacheManager::new(make_config(4, EvictionPolicy::PyramidKV { total_layers: 4 }));
        let (b0, b3) = (mgr.pyramid_budget(0, 4), mgr.pyramid_budget(3, 4));
        assert!(
            b0 > b3,
            "layer 0 budget ({b0}) should exceed layer 3 ({b3})"
        );
    }

    /// Score updates, eviction and lookup all work with a single entry.
    #[test]
    fn test_single_entry_operations() {
        let mut mgr = CacheManager::new(make_config(3, EvictionPolicy::H2O));
        mgr.append(&vec![0.42_f32; 8], &vec![-0.42_f32; 8], 0);

        mgr.update_attention_scores(&[1.0]);
        mgr.evict(1);
        assert_eq!(mgr.len(), 1);

        let (keys, vals) = mgr.get(&[0]);
        assert_eq!(keys.len(), 1);
        assert_eq!(vals.len(), 1);
    }
}