use crate::error::{Result, RuvLLMError};
use crate::memory_pool::{BufferPool, BufferSize, PooledBuffer};
use crate::types::Precision;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::alloc::{alloc, dealloc, Layout};
use std::collections::VecDeque;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

const CACHE_LINE_SIZE: usize = 64;

const NEON_ALIGNMENT: usize = 16;

const POOL_BLOCK_SIZE: usize = 4096;

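/// Cache-line-aligned, fixed-capacity `f32` buffer intended for SIMD-friendly
/// key/value storage. The length grows via `extend_from_slice` up to `capacity`.
///
/// Illustrative usage sketch (not taken from the original sources):
///
/// ```ignore
/// let mut buf = AlignedBuffer::new(1024);
/// buf.extend_from_slice(&[1.0, 2.0, 3.0]);
/// assert_eq!(buf.len(), 3);
/// assert_eq!(buf.as_ptr() as usize % 64, 0); // 64-byte (cache-line) aligned
/// buf.clear();
/// assert!(buf.is_empty());
/// ```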
#[derive(Debug)]
pub struct AlignedBuffer {
    ptr: *mut f32,
    len: usize,
    capacity: usize,
    layout: Layout,
}

unsafe impl Send for AlignedBuffer {}
unsafe impl Sync for AlignedBuffer {}

impl AlignedBuffer {
    pub fn new(capacity: usize) -> Self {
        let size = capacity * std::mem::size_of::<f32>();
        let layout = Layout::from_size_align(size.max(CACHE_LINE_SIZE), CACHE_LINE_SIZE)
            .expect("Invalid layout");

        let ptr = unsafe { alloc(layout) as *mut f32 };

        if ptr.is_null() {
            panic!("Failed to allocate aligned buffer");
        }

        Self {
            ptr,
            len: 0,
            capacity,
            layout,
        }
    }

    #[inline(always)]
    pub fn as_slice(&self) -> &[f32] {
        unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
    }

    #[inline(always)]
    pub fn as_mut_slice(&mut self) -> &mut [f32] {
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
    }

    #[inline(always)]
    pub fn extend_from_slice(&mut self, data: &[f32]) {
        let new_len = self.len + data.len();
        assert!(new_len <= self.capacity, "Buffer overflow");

        unsafe {
            std::ptr::copy_nonoverlapping(data.as_ptr(), self.ptr.add(self.len), data.len());
        }
        self.len = new_len;
    }

    #[inline(always)]
    pub fn clear(&mut self) {
        self.len = 0;
    }

    #[inline(always)]
    pub fn as_ptr(&self) -> *const f32 {
        self.ptr
    }

    #[inline(always)]
    pub fn as_mut_ptr(&mut self) -> *mut f32 {
        self.ptr
    }

    #[inline(always)]
    pub fn len(&self) -> usize {
        self.len
    }

    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    #[inline(always)]
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    #[inline(always)]
    pub(crate) unsafe fn set_len_unchecked(&mut self, new_len: usize) {
        debug_assert!(
            new_len <= self.capacity,
            "set_len_unchecked: {} > {}",
            new_len,
            self.capacity
        );
        self.len = new_len;
    }
}

impl Drop for AlignedBuffer {
    fn drop(&mut self) {
        unsafe {
            dealloc(self.ptr as *mut u8, self.layout);
        }
    }
}

impl Clone for AlignedBuffer {
    fn clone(&self) -> Self {
        let mut new_buf = Self::new(self.capacity);
        new_buf.extend_from_slice(self.as_slice());
        new_buf
    }
}

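/// Simple free-list pool of `AlignedBuffer`s for key and value blocks. Buffers are
/// handed out from the pool when available and freshly allocated otherwise; returned
/// buffers are cleared and retained for reuse up to `max_blocks`.
///
/// Illustrative usage sketch (not taken from the original sources):
///
/// ```ignore
/// let pool = KvMemoryPool::new(4096, 8);
/// pool.prewarm(2);
/// let buf = pool.get_key_buffer();
/// // ... fill and use the buffer ...
/// pool.return_key_buffer(buf);
/// assert!(pool.stats().key_pool_size >= 1);
/// ```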
#[derive(Debug)]
pub struct KvMemoryPool {
    key_pool: RwLock<Vec<AlignedBuffer>>,
    value_pool: RwLock<Vec<AlignedBuffer>>,
    block_size: usize,
    max_blocks: usize,
    allocated_blocks: AtomicUsize,
}

impl KvMemoryPool {
    pub fn new(block_size: usize, max_blocks: usize) -> Self {
        Self {
            key_pool: RwLock::new(Vec::with_capacity(max_blocks)),
            value_pool: RwLock::new(Vec::with_capacity(max_blocks)),
            block_size,
            max_blocks,
            allocated_blocks: AtomicUsize::new(0),
        }
    }

    pub fn get_key_buffer(&self) -> AlignedBuffer {
        let mut pool = self.key_pool.write();
        if let Some(buf) = pool.pop() {
            buf
        } else {
            self.allocated_blocks.fetch_add(1, Ordering::Relaxed);
            AlignedBuffer::new(self.block_size)
        }
    }

    pub fn get_value_buffer(&self) -> AlignedBuffer {
        let mut pool = self.value_pool.write();
        if let Some(buf) = pool.pop() {
            buf
        } else {
            self.allocated_blocks.fetch_add(1, Ordering::Relaxed);
            AlignedBuffer::new(self.block_size)
        }
    }

    pub fn return_key_buffer(&self, mut buf: AlignedBuffer) {
        buf.clear();
        let mut pool = self.key_pool.write();
        if pool.len() < self.max_blocks {
            pool.push(buf);
        }
    }

    pub fn return_value_buffer(&self, mut buf: AlignedBuffer) {
        buf.clear();
        let mut pool = self.value_pool.write();
        if pool.len() < self.max_blocks {
            pool.push(buf);
        }
    }

    pub fn prewarm(&self, count: usize) {
        let count = count.min(self.max_blocks);

        let mut key_pool = self.key_pool.write();
        let mut value_pool = self.value_pool.write();

        for _ in 0..count {
            if key_pool.len() < self.max_blocks {
                key_pool.push(AlignedBuffer::new(self.block_size));
                self.allocated_blocks.fetch_add(1, Ordering::Relaxed);
            }
            if value_pool.len() < self.max_blocks {
                value_pool.push(AlignedBuffer::new(self.block_size));
                self.allocated_blocks.fetch_add(1, Ordering::Relaxed);
            }
        }
    }

    pub fn stats(&self) -> PoolStats {
        PoolStats {
            key_pool_size: self.key_pool.read().len(),
            value_pool_size: self.value_pool.read().len(),
            total_allocated: self.allocated_blocks.load(Ordering::Relaxed),
            block_size_bytes: self.block_size * std::mem::size_of::<f32>(),
        }
    }
}

#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    pub key_pool_size: usize,
    pub value_pool_size: usize,
    pub total_allocated: usize,
    pub block_size_bytes: usize,
}

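/// Configuration for the two-tier KV cache.
///
/// Illustrative sketch of a non-default configuration (the field names come from this
/// module; the values are only an example):
///
/// ```ignore
/// let config = KvCacheConfig {
///     tail_length: 128,               // recent tokens kept at high precision
///     tail_precision: Precision::FP16,
///     store_precision: Precision::Q4, // older tokens are quantized
///     max_tokens: 8192,
///     ..Default::default()
/// };
/// ```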
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KvCacheConfig {
    pub tail_length: usize,
    pub tail_precision: Precision,
    pub store_precision: Precision,
    pub max_tokens: usize,
    pub num_kv_heads: usize,
    pub head_dim: usize,
    pub migration_batch: usize,
}

impl Default for KvCacheConfig {
    fn default() -> Self {
        Self {
            tail_length: 256,
            tail_precision: Precision::FP16,
            store_precision: Precision::Q4,
            max_tokens: 4096,
            num_kv_heads: 8,
            head_dim: 128,
            migration_batch: 64,
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CacheTier {
    Hot,
    Warm,
    Cold,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CacheQuantization {
    HighPrecisionTail {
        tail_length: usize,
        precision: Precision,
    },
    QuantizedStore {
        precision: Precision,
        compression_ratio: f32,
    },
    Hybrid {
        tail_length: usize,
        tail_precision: Precision,
        store_precision: Precision,
    },
}

impl Default for CacheQuantization {
    fn default() -> Self {
        Self::Hybrid {
            tail_length: 256,
            tail_precision: Precision::FP16,
            store_precision: Precision::Q4,
        }
    }
}

#[derive(Debug, Clone)]
struct KvPair {
    keys: Vec<f32>,
    values: Vec<f32>,
    position: usize,
}

#[derive(Debug, Clone)]
struct QuantizedKvPair {
    keys: Vec<f32>,
    values: Vec<f32>,
    scale: f32,
    zero_point: f32,
    position: usize,
}

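// Affine quantization used for the store tier (summarized here from the code below as
// a worked example; the numbers are illustrative and not from the original comments):
//
//   scale      = (max - min) / range        range = 255 for Q8, 15 for Q4/Q4K
//   zero_point = min
//   q          = round((v - zero_point) / scale)
//   v'         = q * scale + zero_point
//
// Example: values in [-1.0, 1.0] quantized as Q8 give scale = 2.0 / 255 ≈ 0.00784 and
// zero_point = -1.0, so v = 0.5 maps to round(1.5 / 0.00784) = 191 and dequantizes to
// roughly 0.498. Note that `from_kv_pair` computes scale/zero_point from the keys and
// reuses them for the values of the same pair, and that the quantized codes are kept
// as `f32` in `QuantizedKvPair`, so precision is reduced while the in-memory element
// type stays the same.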
impl QuantizedKvPair {
    fn from_kv_pair(pair: &KvPair, precision: Precision) -> Self {
        let (scale, zero_point) = Self::compute_scale_and_zero(&pair.keys, precision);

        #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
        let quantize = |vals: &[f32]| -> Vec<f32> { Self::quantize_neon(vals, scale, zero_point) };

        #[cfg(not(all(target_arch = "aarch64", target_feature = "neon")))]
        let quantize = |vals: &[f32]| -> Vec<f32> {
            vals.iter()
                .map(|v| ((v - zero_point) / scale).round())
                .collect()
        };

        Self {
            keys: quantize(&pair.keys),
            values: quantize(&pair.values),
            scale,
            zero_point,
            position: pair.position,
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    fn quantize_neon(values: &[f32], scale: f32, zero_point: f32) -> Vec<f32> {
        use std::arch::aarch64::*;

        let mut result = vec![0.0f32; values.len()];
        let inv_scale = 1.0 / scale;

        unsafe {
            let inv_scale_vec = vdupq_n_f32(inv_scale);
            let zero_vec = vdupq_n_f32(zero_point);

            const UNROLL_8X: usize = 8;
            let chunks = values.len() / UNROLL_8X;

            for c in 0..chunks {
                let base = c * UNROLL_8X;

                let v0 = vld1q_f32(values.as_ptr().add(base));
                let v1 = vld1q_f32(values.as_ptr().add(base + 4));

                let sub0 = vsubq_f32(v0, zero_vec);
                let sub1 = vsubq_f32(v1, zero_vec);

                let scaled0 = vmulq_f32(sub0, inv_scale_vec);
                let scaled1 = vmulq_f32(sub1, inv_scale_vec);

                let rounded0 = vrndnq_f32(scaled0);
                let rounded1 = vrndnq_f32(scaled1);

                vst1q_f32(result.as_mut_ptr().add(base), rounded0);
                vst1q_f32(result.as_mut_ptr().add(base + 4), rounded1);
            }

            for i in (chunks * UNROLL_8X)..values.len() {
                result[i] = ((values[i] - zero_point) * inv_scale).round();
            }
        }

        result
    }

    fn compute_scale_and_zero(values: &[f32], precision: Precision) -> (f32, f32) {
        if values.is_empty() {
            return (1.0, 0.0);
        }

        #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
        let (min_val, max_val) = unsafe { Self::minmax_neon(values) };

        #[cfg(not(all(target_arch = "aarch64", target_feature = "neon")))]
        let (min_val, max_val) = {
            let min = values.iter().cloned().fold(f32::INFINITY, f32::min);
            let max = values.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            (min, max)
        };

        let range = match precision {
            Precision::Q8 => 255.0,
            Precision::Q4 | Precision::Q4K => 15.0,
            _ => 255.0,
        };

        let scale = (max_val - min_val) / range;
        let zero_point = min_val;

        (scale.max(1e-8), zero_point)
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    unsafe fn minmax_neon(values: &[f32]) -> (f32, f32) {
        use std::arch::aarch64::*;

        let mut min_vec = vdupq_n_f32(f32::INFINITY);
        let mut max_vec = vdupq_n_f32(f32::NEG_INFINITY);

        const UNROLL_8X: usize = 8;
        let chunks = values.len() / UNROLL_8X;

        for c in 0..chunks {
            let base = c * UNROLL_8X;
            let v0 = vld1q_f32(values.as_ptr().add(base));
            let v1 = vld1q_f32(values.as_ptr().add(base + 4));

            min_vec = vminq_f32(min_vec, vminq_f32(v0, v1));
            max_vec = vmaxq_f32(max_vec, vmaxq_f32(v0, v1));
        }

        let min_val = vminvq_f32(min_vec);
        let max_val = vmaxvq_f32(max_vec);

        let mut final_min = min_val;
        let mut final_max = max_val;
        for i in (chunks * UNROLL_8X)..values.len() {
            final_min = final_min.min(values[i]);
            final_max = final_max.max(values[i]);
        }

        (final_min, final_max)
    }

    fn dequantize(&self) -> KvPair {
        #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
        let dequant =
            |vals: &[f32]| -> Vec<f32> { Self::dequantize_neon(vals, self.scale, self.zero_point) };

        #[cfg(not(all(target_arch = "aarch64", target_feature = "neon")))]
        let dequant = |vals: &[f32]| -> Vec<f32> {
            vals.iter()
                .map(|v| v * self.scale + self.zero_point)
                .collect()
        };

        KvPair {
            keys: dequant(&self.keys),
            values: dequant(&self.values),
            position: self.position,
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    fn dequantize_neon(quantized: &[f32], scale: f32, zero_point: f32) -> Vec<f32> {
        use std::arch::aarch64::*;

        let mut result = vec![0.0f32; quantized.len()];

        unsafe {
            let scale_vec = vdupq_n_f32(scale);
            let zero_vec = vdupq_n_f32(zero_point);

            const UNROLL_8X: usize = 8;
            let chunks = quantized.len() / UNROLL_8X;

            for c in 0..chunks {
                let base = c * UNROLL_8X;

                let q0 = vld1q_f32(quantized.as_ptr().add(base));
                let q1 = vld1q_f32(quantized.as_ptr().add(base + 4));

                let d0 = vfmaq_f32(zero_vec, q0, scale_vec);
                let d1 = vfmaq_f32(zero_vec, q1, scale_vec);

                vst1q_f32(result.as_mut_ptr().add(base), d0);
                vst1q_f32(result.as_mut_ptr().add(base + 4), d1);
            }

            for i in (chunks * UNROLL_8X)..quantized.len() {
                result[i] = quantized[i] * scale + zero_point;
            }
        }

        result
    }

    #[inline(always)]
    fn dequantize_into(&self, key_buf: &mut AlignedBuffer, value_buf: &mut AlignedBuffer) {
        #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
        unsafe {
            let key_new_len = key_buf.len() + self.keys.len();
            let value_new_len = value_buf.len() + self.values.len();

            assert!(
                key_new_len <= key_buf.capacity(),
                "Key buffer overflow: {} > {}",
                key_new_len,
                key_buf.capacity()
            );
            assert!(
                value_new_len <= value_buf.capacity(),
                "Value buffer overflow: {} > {}",
                value_new_len,
                value_buf.capacity()
            );

            Self::dequantize_neon_into(
                &self.keys,
                key_buf.as_mut_ptr().add(key_buf.len()),
                self.scale,
                self.zero_point,
            );
            Self::dequantize_neon_into(
                &self.values,
                value_buf.as_mut_ptr().add(value_buf.len()),
                self.scale,
                self.zero_point,
            );

            key_buf.set_len_unchecked(key_new_len);
            value_buf.set_len_unchecked(value_new_len);
        }

        #[cfg(not(all(target_arch = "aarch64", target_feature = "neon")))]
        {
            let keys: Vec<f32> = self
                .keys
                .iter()
                .map(|v| v * self.scale + self.zero_point)
                .collect();
            let values: Vec<f32> = self
                .values
                .iter()
                .map(|v| v * self.scale + self.zero_point)
                .collect();
            key_buf.extend_from_slice(&keys);
            value_buf.extend_from_slice(&values);
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    #[inline(always)]
    unsafe fn dequantize_neon_into(
        quantized: &[f32],
        output: *mut f32,
        scale: f32,
        zero_point: f32,
    ) {
        use std::arch::aarch64::*;

        let scale_vec = vdupq_n_f32(scale);
        let zero_vec = vdupq_n_f32(zero_point);

        const UNROLL_8X: usize = 8;
        let chunks = quantized.len() / UNROLL_8X;

        for c in 0..chunks {
            let base = c * UNROLL_8X;

            let q0 = vld1q_f32(quantized.as_ptr().add(base));
            let q1 = vld1q_f32(quantized.as_ptr().add(base + 4));

            let d0 = vfmaq_f32(zero_vec, q0, scale_vec);
            let d1 = vfmaq_f32(zero_vec, q1, scale_vec);

            vst1q_f32(output.add(base), d0);
            vst1q_f32(output.add(base + 4), d1);
        }

        for i in (chunks * UNROLL_8X)..quantized.len() {
            *output.add(i) = quantized[i] * scale + zero_point;
        }
    }
}

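/// Two-tier KV cache: a high-precision "tail" of the most recent tokens (kept as
/// `f32` in a `VecDeque`) plus a quantized "store" for older tokens. Appended tokens
/// enter the tail; once the tail exceeds `tail_length`, batches of `migration_batch`
/// tokens are quantized with `store_precision` and migrated into the store.
///
/// Illustrative usage sketch (not taken from the original sources):
///
/// ```ignore
/// let cache = TwoTierKvCache::new(KvCacheConfig::default());
/// let stride = 8 * 128;               // num_kv_heads * head_dim
/// let keys = vec![0.1f32; stride];    // one token of keys
/// let values = vec![0.2f32; stride];  // one token of values
/// cache.append(&keys, &values).unwrap();
/// assert_eq!(cache.stats().total_tokens, 1);
/// ```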
#[derive(Debug)]
pub struct TwoTierKvCache {
    config: KvCacheConfig,
    tail: RwLock<VecDeque<KvPair>>,
    store: RwLock<Vec<QuantizedKvPair>>,
    total_tokens: AtomicUsize,
    quantization_policy: Arc<RwLock<CacheQuantization>>,
    memory_pool: Arc<KvMemoryPool>,
}

impl TwoTierKvCache {
    pub fn new(config: KvCacheConfig) -> Self {
        let quantization_policy = Arc::new(RwLock::new(CacheQuantization::Hybrid {
            tail_length: config.tail_length,
            tail_precision: config.tail_precision,
            store_precision: config.store_precision,
        }));

        let stride = config.num_kv_heads * config.head_dim;
        let block_size = stride * config.tail_length;

        let max_blocks = (config.max_tokens / config.tail_length).max(4);
        let memory_pool = Arc::new(KvMemoryPool::new(block_size, max_blocks));

        memory_pool.prewarm(2);

        Self {
            config,
            tail: RwLock::new(VecDeque::new()),
            store: RwLock::new(Vec::new()),
            total_tokens: AtomicUsize::new(0),
            quantization_policy,
            memory_pool,
        }
    }

    pub fn with_pool(config: KvCacheConfig, pool: Arc<KvMemoryPool>) -> Self {
        let quantization_policy = Arc::new(RwLock::new(CacheQuantization::Hybrid {
            tail_length: config.tail_length,
            tail_precision: config.tail_precision,
            store_precision: config.store_precision,
        }));

        Self {
            config,
            tail: RwLock::new(VecDeque::new()),
            store: RwLock::new(Vec::new()),
            total_tokens: AtomicUsize::new(0),
            quantization_policy,
            memory_pool: pool,
        }
    }

    pub fn append(&self, keys: &[f32], values: &[f32]) -> Result<()> {
        let stride = self.config.num_kv_heads * self.config.head_dim;
        let num_tokens = keys.len() / stride;

        if keys.len() != values.len() {
            return Err(RuvLLMError::KvCache(
                "Key and value lengths must match".to_string(),
            ));
        }

        let current_tokens = self.total_tokens.load(Ordering::SeqCst);

        let mut tail = self.tail.write();
        for i in 0..num_tokens {
            let offset = i * stride;
            tail.push_back(KvPair {
                keys: keys[offset..offset + stride].to_vec(),
                values: values[offset..offset + stride].to_vec(),
                position: current_tokens + i,
            });
        }

        while tail.len() > self.config.tail_length {
            let batch_size = self
                .config
                .migration_batch
                .min(tail.len() - self.config.tail_length);

            let to_migrate: Vec<_> = (0..batch_size).filter_map(|_| tail.pop_front()).collect();

            let mut store = self.store.write();
            for pair in to_migrate {
                let quantized = QuantizedKvPair::from_kv_pair(&pair, self.config.store_precision);
                store.push(quantized);
            }
        }

        self.total_tokens.fetch_add(num_tokens, Ordering::SeqCst);

        self.enforce_max_tokens()?;

        Ok(())
    }

    fn enforce_max_tokens(&self) -> Result<()> {
        let total = self.total_tokens.load(Ordering::SeqCst);

        if total <= self.config.max_tokens {
            return Ok(());
        }

        let to_evict = total - self.config.max_tokens;
        let mut store = self.store.write();

        let store_evict = to_evict.min(store.len());
        store.drain(0..store_evict);

        self.total_tokens.fetch_sub(store_evict, Ordering::SeqCst);

        let remaining = to_evict - store_evict;
        if remaining > 0 {
            let mut tail = self.tail.write();
            // Compute the eviction count before popping so the token counter is not
            // under-decremented by measuring the already-shrunk tail.
            let tail_evict = remaining.min(tail.len());
            for _ in 0..tail_evict {
                tail.pop_front();
            }
            self.total_tokens.fetch_sub(tail_evict, Ordering::SeqCst);
        }

        Ok(())
    }

    pub fn get_all_kv(&self) -> (Vec<f32>, Vec<f32>) {
        let stride = self.config.num_kv_heads * self.config.head_dim;
        let total = self.total_tokens.load(Ordering::SeqCst);

        let mut all_keys = Vec::with_capacity(total * stride);
        let mut all_values = Vec::with_capacity(total * stride);

        let store = self.store.read();
        for qpair in store.iter() {
            let pair = qpair.dequantize();
            all_keys.extend_from_slice(&pair.keys);
            all_values.extend_from_slice(&pair.values);
        }
        drop(store);

        let tail = self.tail.read();
        for pair in tail.iter() {
            all_keys.extend_from_slice(&pair.keys);
            all_values.extend_from_slice(&pair.values);
        }

        (all_keys, all_values)
    }

    pub fn get_all_kv_aligned(&self) -> (AlignedBuffer, AlignedBuffer) {
        let stride = self.config.num_kv_heads * self.config.head_dim;
        let total = self.total_tokens.load(Ordering::SeqCst);

        let mut key_buf = AlignedBuffer::new(total * stride);
        let mut value_buf = AlignedBuffer::new(total * stride);

        let store = self.store.read();
        for qpair in store.iter() {
            qpair.dequantize_into(&mut key_buf, &mut value_buf);
        }
        drop(store);

        let tail = self.tail.read();
        for pair in tail.iter() {
            key_buf.extend_from_slice(&pair.keys);
            value_buf.extend_from_slice(&pair.values);
        }

        (key_buf, value_buf)
    }

    pub fn memory_pool(&self) -> &Arc<KvMemoryPool> {
        &self.memory_pool
    }

    pub fn pool_stats(&self) -> PoolStats {
        self.memory_pool.stats()
    }

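    /// Single-query attention over the full cache contents: store entries are
    /// dequantized, scores are scaled dot products against the keys, a numerically
    /// stable softmax produces the weights, and the output is the weighted sum of
    /// the values.
    ///
    /// Sketch of the computation (notation added here for clarity, not taken from
    /// the original sources):
    ///
    /// ```text
    /// s_t    = scale * dot(query, key_t)
    /// w_t    = exp(s_t - max_u s_u) / sum_u exp(s_u - max_u s_u)
    /// out[i] = sum_t w_t * value_t[i]
    /// ```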
    pub fn attend(&self, query: &[f32], scale: f32) -> Result<Vec<f32>> {
        let (keys, values) = self.get_all_kv();
        let stride = self.config.num_kv_heads * self.config.head_dim;
        let num_tokens = keys.len() / stride;

        if num_tokens == 0 {
            return Ok(vec![0.0; query.len()]);
        }

        let mut scores = Vec::with_capacity(num_tokens);

        for t in 0..num_tokens {
            let k_offset = t * stride;
            let k_slice = &keys[k_offset..k_offset + stride];

            let score: f32 = query
                .iter()
                .zip(k_slice.iter())
                .map(|(q, k)| q * k * scale)
                .sum();

            scores.push(score);
        }

        let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exp_scores: Vec<f32> = scores.iter().map(|s| (s - max_score).exp()).collect();
        let sum_exp: f32 = exp_scores.iter().sum();
        let attn_weights: Vec<f32> = exp_scores.iter().map(|e| e / sum_exp).collect();

        let mut output = vec![0.0; stride];
        for (t, weight) in attn_weights.iter().enumerate() {
            let v_offset = t * stride;
            for (i, v) in values[v_offset..v_offset + stride].iter().enumerate() {
                output[i] += weight * v;
            }
        }

        Ok(output)
    }

    pub fn stats(&self) -> KvCacheStats {
        let tail = self.tail.read();
        let store = self.store.read();
        let stride = self.config.num_kv_heads * self.config.head_dim;

        // Tail entries are held as f32 (4 bytes per element); the factor of 2 counts
        // keys plus values.
        let tail_bytes = tail.len() * stride * 4 * 2;
        let store_bytes =
            store.len() * stride * self.config.store_precision.bytes_per_element() as usize * 2;

        KvCacheStats {
            total_tokens: self.total_tokens.load(Ordering::SeqCst),
            tail_tokens: tail.len(),
            store_tokens: store.len(),
            tail_bytes,
            store_bytes,
            compression_ratio: tail_bytes as f32 / store_bytes.max(1) as f32,
        }
    }

999
1000 pub fn clear(&self) {
1002 let mut tail = self.tail.write();
1003 let mut store = self.store.write();
1004 tail.clear();
1005 store.clear();
1006 self.total_tokens.store(0, Ordering::SeqCst);
1007 }
1008
1009 pub fn update_policy(&self, policy: CacheQuantization) {
1011 let mut current = self.quantization_policy.write();
1012 *current = policy;
1013 }
1014}
1015
1016#[derive(Debug, Clone, Default, Serialize, Deserialize)]
1018pub struct KvCacheStats {
1019 pub total_tokens: usize,
1021 pub tail_tokens: usize,
1023 pub store_tokens: usize,
1025 pub tail_bytes: usize,
1027 pub store_bytes: usize,
1029 pub compression_ratio: f32,
1031}
1032
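/// Fixed-capacity block of contiguous key/value storage backed by `PooledBuffer`s
/// from a shared `BufferPool`. Each token occupies `stride = num_heads * head_dim`
/// consecutive `f32`s; `append` copies as many whole tokens as fit and reports how
/// many were written.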
pub struct PooledKvBlock {
    keys: PooledBuffer,
    values: PooledBuffer,
    token_count: usize,
    stride: usize,
}

impl PooledKvBlock {
    pub fn new(
        pool: &BufferPool,
        max_tokens: usize,
        num_heads: usize,
        head_dim: usize,
    ) -> Option<Self> {
        let stride = num_heads * head_dim;
        let bytes_needed = max_tokens * stride * std::mem::size_of::<f32>();

        let keys = pool.acquire_for_size(bytes_needed).ok()??;
        let values = pool.acquire_for_size(bytes_needed).ok()??;

        Some(Self {
            keys,
            values,
            token_count: 0,
            stride,
        })
    }

    pub fn append(&mut self, keys: &[f32], values: &[f32]) -> usize {
        let capacity_tokens = self.keys.capacity() / (self.stride * std::mem::size_of::<f32>());
        let input_tokens = keys.len() / self.stride;
        let space_remaining = capacity_tokens.saturating_sub(self.token_count);
        let tokens_to_append = input_tokens.min(space_remaining);

        if tokens_to_append == 0 {
            return 0;
        }

        let elements = tokens_to_append * self.stride;
        let offset = self.token_count * self.stride;

        let key_slice = self.keys.as_slice_mut::<f32>();
        key_slice[offset..offset + elements].copy_from_slice(&keys[..elements]);

        let value_slice = self.values.as_slice_mut::<f32>();
        value_slice[offset..offset + elements].copy_from_slice(&values[..elements]);

        self.token_count += tokens_to_append;
        tokens_to_append
    }

    pub fn keys(&self) -> &[f32] {
        let elements = self.token_count * self.stride;
        &self.keys.as_slice::<f32>()[..elements]
    }

    pub fn values(&self) -> &[f32] {
        let elements = self.token_count * self.stride;
        &self.values.as_slice::<f32>()[..elements]
    }

    pub fn token_count(&self) -> usize {
        self.token_count
    }

    pub fn is_full(&self) -> bool {
        let capacity_tokens = self.keys.capacity() / (self.stride * std::mem::size_of::<f32>());
        self.token_count >= capacity_tokens
    }

    pub fn remaining_tokens(&self) -> usize {
        let capacity_tokens = self.keys.capacity() / (self.stride * std::mem::size_of::<f32>());
        capacity_tokens.saturating_sub(self.token_count)
    }

    pub fn clear(&mut self) {
        self.token_count = 0;
    }
}

impl std::fmt::Debug for PooledKvBlock {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PooledKvBlock")
            .field("token_count", &self.token_count)
            .field("stride", &self.stride)
            .field("key_capacity", &self.keys.capacity())
            .field("value_capacity", &self.values.capacity())
            .finish()
    }
}

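/// Block-based KV cache whose storage comes from a shared `BufferPool`. Unlike
/// `TwoTierKvCache`, all tokens are kept at full `f32` precision; once `max_tokens`
/// is exceeded, whole blocks are dropped from the front.
///
/// Illustrative usage sketch (not taken from the original sources):
///
/// ```ignore
/// let cache = PooledKvCache::with_new_pool(KvCacheConfig::default(), 256);
/// let stride = 8 * 128; // num_kv_heads * head_dim from the default config
/// cache.append(&vec![0.1f32; stride], &vec![0.2f32; stride]).unwrap();
/// let (keys, _values) = cache.get_all_kv();
/// assert_eq!(keys.len(), stride);
/// ```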
#[derive(Debug)]
pub struct PooledKvCache {
    config: KvCacheConfig,
    pool: BufferPool,
    blocks: RwLock<Vec<PooledKvBlock>>,
    tokens_per_block: usize,
    total_tokens: AtomicUsize,
}

impl PooledKvCache {
    pub fn new(config: KvCacheConfig, pool: BufferPool, tokens_per_block: usize) -> Self {
        Self {
            config,
            pool,
            blocks: RwLock::new(Vec::new()),
            tokens_per_block,
            total_tokens: AtomicUsize::new(0),
        }
    }

    pub fn with_new_pool(config: KvCacheConfig, tokens_per_block: usize) -> Self {
        let pool = BufferPool::new();
        Self::new(config, pool, tokens_per_block)
    }

    pub fn append(&self, keys: &[f32], values: &[f32]) -> Result<()> {
        let stride = self.config.num_kv_heads * self.config.head_dim;
        let input_tokens = keys.len() / stride;

        if keys.len() != values.len() {
            return Err(RuvLLMError::KvCache(
                "Key and value lengths must match".to_string(),
            ));
        }

        let mut blocks = self.blocks.write();
        let mut remaining_keys = keys;
        let mut remaining_values = values;

        while !remaining_keys.is_empty() {
            let need_new_block = blocks.is_empty() || blocks.last().map_or(true, |b| b.is_full());

            if need_new_block {
                let new_block = PooledKvBlock::new(
                    &self.pool,
                    self.tokens_per_block,
                    self.config.num_kv_heads,
                    self.config.head_dim,
                )
                .ok_or_else(|| {
                    RuvLLMError::OutOfMemory("Failed to allocate KV block from pool".to_string())
                })?;
                blocks.push(new_block);
            }

            let block = blocks
                .last_mut()
                .expect("blocks should be non-empty after allocation");
            let tokens_appended = block.append(remaining_keys, remaining_values);

            if tokens_appended == 0 {
                break;
            }

            let elements = tokens_appended * stride;
            remaining_keys = &remaining_keys[elements..];
            remaining_values = &remaining_values[elements..];

            self.total_tokens
                .fetch_add(tokens_appended, Ordering::SeqCst);
        }

        self.enforce_max_tokens(&mut blocks)?;

        Ok(())
    }

    fn enforce_max_tokens(&self, blocks: &mut Vec<PooledKvBlock>) -> Result<()> {
        let total = self.total_tokens.load(Ordering::SeqCst);

        if total <= self.config.max_tokens {
            return Ok(());
        }

        let mut to_evict = total - self.config.max_tokens;

        while to_evict > 0 && !blocks.is_empty() {
            let first_block_tokens = blocks[0].token_count();

            if first_block_tokens <= to_evict {
                blocks.remove(0);
                to_evict -= first_block_tokens;
                self.total_tokens
                    .fetch_sub(first_block_tokens, Ordering::SeqCst);
            } else {
                // Partial eviction within a block is not supported: drop the whole
                // oldest block (evicting slightly more than required) and stop.
                let removed_tokens = blocks[0].token_count();
                blocks.remove(0);
                self.total_tokens
                    .fetch_sub(removed_tokens, Ordering::SeqCst);
                break;
            }
        }

        Ok(())
    }

    pub fn get_all_kv(&self) -> (Vec<f32>, Vec<f32>) {
        let blocks = self.blocks.read();
        let total = self.total_tokens.load(Ordering::SeqCst);
        let stride = self.config.num_kv_heads * self.config.head_dim;

        let mut all_keys = Vec::with_capacity(total * stride);
        let mut all_values = Vec::with_capacity(total * stride);

        for block in blocks.iter() {
            all_keys.extend_from_slice(block.keys());
            all_values.extend_from_slice(block.values());
        }

        (all_keys, all_values)
    }

    pub fn stats(&self) -> PooledKvCacheStats {
        let blocks = self.blocks.read();
        let total_tokens = self.total_tokens.load(Ordering::SeqCst);
        let stride = self.config.num_kv_heads * self.config.head_dim;

        PooledKvCacheStats {
            total_tokens,
            block_count: blocks.len(),
            tokens_per_block: self.tokens_per_block,
            total_bytes: total_tokens * stride * std::mem::size_of::<f32>() * 2,
            pool_stats: self.pool.stats(),
        }
    }

    pub fn clear(&self) {
        let mut blocks = self.blocks.write();
        blocks.clear();
        self.total_tokens.store(0, Ordering::SeqCst);
    }

    pub fn pool(&self) -> &BufferPool {
        &self.pool
    }
}

#[derive(Debug, Clone)]
pub struct PooledKvCacheStats {
    pub total_tokens: usize,
    pub block_count: usize,
    pub tokens_per_block: usize,
    pub total_bytes: usize,
    pub pool_stats: crate::memory_pool::BufferPoolStats,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kv_cache_append() {
        let config = KvCacheConfig {
            tail_length: 4,
            num_kv_heads: 2,
            head_dim: 4,
            migration_batch: 2,
            ..Default::default()
        };

        let cache = TwoTierKvCache::new(config);

        // One token: num_kv_heads (2) * head_dim (4) = 8 elements.
        let keys = vec![1.0; 2 * 4];
        let values = vec![1.0; 2 * 4];
        cache.append(&keys, &values).unwrap();

        let stats = cache.stats();
        assert_eq!(stats.total_tokens, 1);
        assert_eq!(stats.tail_tokens, 1);
        assert_eq!(stats.store_tokens, 0);
    }

    #[test]
    fn test_kv_cache_migration() {
        let config = KvCacheConfig {
            tail_length: 2,
            num_kv_heads: 2,
            head_dim: 4,
            migration_batch: 1,
            max_tokens: 100,
            ..Default::default()
        };

        let cache = TwoTierKvCache::new(config);

        for _ in 0..5 {
            let keys = vec![1.0; 2 * 4];
            let values = vec![1.0; 2 * 4];
            cache.append(&keys, &values).unwrap();
        }

        let stats = cache.stats();
        assert_eq!(stats.total_tokens, 5);
        assert_eq!(stats.tail_tokens, 2);
        assert_eq!(stats.store_tokens, 3);
    }

    #[test]
    fn test_kv_cache_attend() {
        let config = KvCacheConfig {
            tail_length: 4,
            num_kv_heads: 1,
            head_dim: 4,
            ..Default::default()
        };

        let cache = TwoTierKvCache::new(config);

        let keys = vec![1.0, 0.0, 0.0, 0.0];
        let values = vec![1.0, 2.0, 3.0, 4.0];
        cache.append(&keys, &values).unwrap();

        let query = vec![1.0, 0.0, 0.0, 0.0];
        let output = cache.attend(&query, 1.0).unwrap();

        assert_eq!(output.len(), 4);
        assert!((output[0] - 1.0).abs() < 0.1);
    }

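    // Additional coverage sketch (not part of the original test suite): exercises the
    // aligned-buffer and memory-pool plumbing using only APIs defined in this module.
    #[test]
    fn test_aligned_buffer_and_pool_roundtrip() {
        // AlignedBuffer: cache-line alignment, append, clone, clear.
        let mut buf = AlignedBuffer::new(64);
        assert_eq!(buf.as_ptr() as usize % CACHE_LINE_SIZE, 0);
        buf.extend_from_slice(&[1.0f32, 2.0, 3.0]);
        assert_eq!(buf.as_slice(), &[1.0f32, 2.0, 3.0]);
        let cloned = buf.clone();
        assert_eq!(cloned.as_slice(), buf.as_slice());
        buf.clear();
        assert!(buf.is_empty());

        // KvMemoryPool: prewarmed buffers are handed out and accepted back.
        let pool = KvMemoryPool::new(16, 4);
        pool.prewarm(2);
        let key_buf = pool.get_key_buffer();
        assert_eq!(key_buf.capacity(), 16);
        pool.return_key_buffer(key_buf);
        let stats = pool.stats();
        assert!(stats.key_pool_size >= 1);
        assert_eq!(stats.block_size_bytes, 16 * std::mem::size_of::<f32>());
    }
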
    #[test]
    fn test_pooled_kv_cache_basic() {
        let config = KvCacheConfig {
            tail_length: 4,
            num_kv_heads: 2,
            head_dim: 4,
            max_tokens: 100,
            ..Default::default()
        };

        let cache = PooledKvCache::with_new_pool(config, 16);

        // One token: stride = num_kv_heads * head_dim = 8 elements.
        let stride = 2 * 4;
        let keys = vec![1.0; stride];
        let values = vec![2.0; stride];
        cache.append(&keys, &values).unwrap();

        let stats = cache.stats();
        assert_eq!(stats.total_tokens, 1);
        assert_eq!(stats.block_count, 1);
    }

    #[test]
    fn test_pooled_kv_cache_multiple_blocks() {
        let config = KvCacheConfig {
            tail_length: 4,
            num_kv_heads: 2,
            head_dim: 4,
            max_tokens: 100,
            ..Default::default()
        };

        let cache = PooledKvCache::with_new_pool(config, 2);

        let stride = 2 * 4;

        for i in 0..5 {
            let keys = vec![i as f32; stride];
            let values = vec![(i * 2) as f32; stride];
            cache.append(&keys, &values).unwrap();
        }

        let stats = cache.stats();
        assert_eq!(stats.total_tokens, 5);
        assert!(stats.block_count >= 1, "Should have at least 1 block");
        assert!(stats.block_count <= 5, "Should have at most 5 blocks");

        let (all_keys, all_values) = cache.get_all_kv();
        assert_eq!(all_keys.len(), 5 * stride);
        assert_eq!(all_values.len(), 5 * stride);

        assert_eq!(all_keys[0], 0.0);
        assert_eq!(all_keys[4 * stride], 4.0);
    }

    #[test]
    fn test_pooled_kv_cache_pool_reuse() {
        let config = KvCacheConfig {
            tail_length: 4,
            num_kv_heads: 2,
            head_dim: 4,
            max_tokens: 100,
            ..Default::default()
        };

        let pool = BufferPool::new();
        pool.prewarm(BufferSize::KB4, 4);

        let cache = PooledKvCache::new(config, pool, 16);

        let stride = 2 * 4;
        let keys = vec![1.0; stride];
        let values = vec![2.0; stride];

        for _ in 0..3 {
            cache.append(&keys, &values).unwrap();
            cache.clear();
        }

        let stats = cache.stats();
        assert_eq!(stats.total_tokens, 0);
        assert!(stats.pool_stats.returns > 0 || stats.pool_stats.hits > 0);
    }
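
    // Additional coverage sketch (not part of the original test suite): checks that
    // the affine Q8 quantization used for the store tier round-trips to within about
    // one quantization step.
    #[test]
    fn test_quantized_kv_pair_roundtrip() {
        let pair = KvPair {
            keys: vec![0.0, 0.25, 0.5, 0.75, 1.0],
            values: vec![1.0, 0.75, 0.5, 0.25, 0.0],
            position: 7,
        };

        let quantized = QuantizedKvPair::from_kv_pair(&pair, Precision::Q8);
        let restored = quantized.dequantize();

        assert_eq!(restored.position, 7);
        for (orig, rest) in pair.keys.iter().zip(restored.keys.iter()) {
            assert!((orig - rest).abs() < 0.01, "key error too large: {} vs {}", orig, rest);
        }
        for (orig, rest) in pair.values.iter().zip(restored.values.iter()) {
            assert!((orig - rest).abs() < 0.01, "value error too large: {} vs {}", orig, rest);
        }
    }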
}