1use serde::{Deserialize, Serialize};
45use std::collections::HashMap;
46
47use crate::error::{Result, TrustformerError};
48
49#[derive(Clone, Debug, Serialize, Deserialize)]
51pub struct KVCacheConfig {
52 pub num_layers: usize,
54 pub num_heads: usize,
56 pub head_dim: usize,
58 pub max_seq_len: usize,
60 pub max_batch_size: usize,
62 pub enabled: bool,
64}
65
66impl KVCacheConfig {
67 pub fn new(num_layers: usize, num_heads: usize, head_dim: usize) -> Self {
69 Self {
70 num_layers,
71 num_heads,
72 head_dim,
73 max_seq_len: 2048,
74 max_batch_size: 32,
75 enabled: true,
76 }
77 }
78
79 pub fn with_max_seq_len(mut self, max_seq_len: usize) -> Self {
81 self.max_seq_len = max_seq_len;
82 self
83 }
84
85 pub fn with_max_batch_size(mut self, max_batch_size: usize) -> Self {
87 self.max_batch_size = max_batch_size;
88 self
89 }
90
91 pub fn with_enabled(mut self, enabled: bool) -> Self {
93 self.enabled = enabled;
94 self
95 }
96
97 pub fn validate(&self) -> Result<()> {
99 if self.num_layers == 0 {
100 return Err(TrustformerError::InvalidDimension {
101 expected: 1,
102 got: 0,
103 context: "num_layers must be > 0".to_string(),
104 });
105 }
106
107 if self.num_heads == 0 {
108 return Err(TrustformerError::InvalidDimension {
109 expected: 1,
110 got: 0,
111 context: "num_heads must be > 0".to_string(),
112 });
113 }
114
115 if self.head_dim == 0 {
116 return Err(TrustformerError::InvalidDimension {
117 expected: 1,
118 got: 0,
119 context: "head_dim must be > 0".to_string(),
120 });
121 }
122
123 if self.max_seq_len == 0 {
124 return Err(TrustformerError::InvalidDimension {
125 expected: 1,
126 got: 0,
127 context: "max_seq_len must be > 0".to_string(),
128 });
129 }
130
131 if self.max_batch_size == 0 {
132 return Err(TrustformerError::InvalidDimension {
133 expected: 1,
134 got: 0,
135 context: "max_batch_size must be > 0".to_string(),
136 });
137 }
138
139 Ok(())
140 }
141
142 pub fn memory_usage(&self) -> usize {
144 let bytes_per_element = 4;
148 let elements_per_layer =
149 self.max_batch_size * self.num_heads * self.max_seq_len * self.head_dim * 2; elements_per_layer * self.num_layers * bytes_per_element
152 }
153
154 pub fn memory_usage_mb(&self) -> f64 {
156 self.memory_usage() as f64 / (1024.0 * 1024.0)
157 }
158}
159
/// Cached keys and values for one layer.
///
/// `keys` and `values` are flat `f32` buffers; `seq_len` counts cached tokens,
/// not floats.
#[derive(Clone, Debug)]
pub struct CacheEntry {
    /// Flat buffer of cached keys.
    pub keys: Vec<f32>,
    /// Flat buffer of cached values.
    pub values: Vec<f32>,
    /// Number of tokens currently cached.
    pub seq_len: usize,
    /// Batch size this entry was created for.
    pub batch_size: usize,
}

impl CacheEntry {
    /// Creates an empty entry whose key and value buffers each reserve room for
    /// `batch_size * num_heads * max_seq_len * head_dim` floats.
    pub fn new(batch_size: usize, num_heads: usize, head_dim: usize, max_seq_len: usize) -> Self {
        let floats_per_buffer = batch_size * num_heads * max_seq_len * head_dim;
        let keys = Vec::with_capacity(floats_per_buffer);
        let values = Vec::with_capacity(floats_per_buffer);
        Self {
            keys,
            values,
            seq_len: 0,
            batch_size,
        }
    }

    /// Returns `true` when no tokens are cached.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of cached tokens (same as `seq_len`).
    pub fn len(&self) -> usize {
        self.seq_len
    }

    /// Drops all cached data while keeping the buffers' allocations and the
    /// batch size.
    pub fn clear(&mut self) {
        self.seq_len = 0;
        self.values.clear();
        self.keys.clear();
    }
}
202
203#[derive(Clone, Debug)]
205pub struct KVCache {
206 config: KVCacheConfig,
208 cache: HashMap<usize, CacheEntry>,
210 step: usize,
212}
213
214impl KVCache {
215 pub fn new(num_layers: usize, num_heads: usize, head_dim: usize) -> Self {
217 let config = KVCacheConfig::new(num_layers, num_heads, head_dim);
218 Self {
219 config,
220 cache: HashMap::new(),
221 step: 0,
222 }
223 }
224
225 pub fn from_config(config: KVCacheConfig) -> Result<Self> {
227 config.validate()?;
228 Ok(Self {
229 config,
230 cache: HashMap::new(),
231 step: 0,
232 })
233 }
234
235 pub fn config(&self) -> &KVCacheConfig {
237 &self.config
238 }
239
240 pub fn is_enabled(&self) -> bool {
242 self.config.enabled
243 }
244
245 pub fn step(&self) -> usize {
247 self.step
248 }
249
250 pub fn init_layer(&mut self, layer_idx: usize, batch_size: usize) -> Result<()> {
252 if layer_idx >= self.config.num_layers {
253 return Err(TrustformerError::InvalidDimension {
254 expected: self.config.num_layers,
255 got: layer_idx,
256 context: format!(
257 "layer_idx {} >= num_layers {}",
258 layer_idx, self.config.num_layers
259 ),
260 });
261 }
262
263 if batch_size > self.config.max_batch_size {
264 return Err(TrustformerError::InvalidDimension {
265 expected: self.config.max_batch_size,
266 got: batch_size,
267 context: format!(
268 "batch_size {} > max_batch_size {}",
269 batch_size, self.config.max_batch_size
270 ),
271 });
272 }
273
274 let entry = CacheEntry::new(
275 batch_size,
276 self.config.num_heads,
277 self.config.head_dim,
278 self.config.max_seq_len,
279 );
280 self.cache.insert(layer_idx, entry);
281 Ok(())
282 }
283
284 pub fn update_layer(
286 &mut self,
287 layer_idx: usize,
288 new_keys: Vec<f32>,
289 new_values: Vec<f32>,
290 ) -> Result<()> {
291 if !self.config.enabled {
292 return Ok(());
293 }
294
295 if !self.cache.contains_key(&layer_idx) {
297 let expected_size_per_token = self.config.num_heads * self.config.head_dim;
300
301 if !new_keys.len().is_multiple_of(expected_size_per_token) {
302 return Err(TrustformerError::InvalidDimension {
303 expected: expected_size_per_token,
304 got: new_keys.len(),
305 context: "keys size must be divisible by num_heads * head_dim".to_string(),
306 });
307 }
308
309 let batch_size = new_keys.len() / expected_size_per_token;
310 self.init_layer(layer_idx, batch_size)?;
311 }
312
313 let entry = self.cache.get_mut(&layer_idx).unwrap();
314
315 if new_keys.len() != new_values.len() {
317 return Err(TrustformerError::InvalidDimension {
318 expected: new_keys.len(),
319 got: new_values.len(),
320 context: "keys and values must have same size".to_string(),
321 });
322 }
323
324 entry.keys.extend_from_slice(&new_keys);
326 entry.values.extend_from_slice(&new_values);
327
328 let new_tokens =
330 new_keys.len() / (entry.batch_size * self.config.num_heads * self.config.head_dim);
331 entry.seq_len += new_tokens;
332
333 if entry.seq_len > self.config.max_seq_len {
335 return Err(TrustformerError::InvalidDimension {
336 expected: self.config.max_seq_len,
337 got: entry.seq_len,
338 context: format!(
339 "cache exceeded max_seq_len {} (current: {})",
340 self.config.max_seq_len, entry.seq_len
341 ),
342 });
343 }
344
345 Ok(())
346 }
347
348 pub fn get_layer(&self, layer_idx: usize) -> Result<(&[f32], &[f32])> {
350 let entry =
351 self.cache
352 .get(&layer_idx)
353 .ok_or_else(|| TrustformerError::InvalidDimension {
354 expected: 1,
355 got: 0,
356 context: format!("layer {} not found in cache", layer_idx),
357 })?;
358
359 Ok((&entry.keys, &entry.values))
360 }
361
362 pub fn get_seq_len(&self, layer_idx: usize) -> Result<usize> {
364 let entry =
365 self.cache
366 .get(&layer_idx)
367 .ok_or_else(|| TrustformerError::InvalidDimension {
368 expected: 1,
369 got: 0,
370 context: format!("layer {} not found in cache", layer_idx),
371 })?;
372
373 Ok(entry.seq_len)
374 }
375
376 pub fn clear_layer(&mut self, layer_idx: usize) {
378 if let Some(entry) = self.cache.get_mut(&layer_idx) {
379 entry.clear();
380 }
381 }
382
383 pub fn clear_all(&mut self) {
385 for entry in self.cache.values_mut() {
386 entry.clear();
387 }
388 self.step = 0;
389 }
390
391 pub fn next_step(&mut self) {
393 self.step += 1;
394 }
395
396 pub fn reset(&mut self) {
398 self.cache.clear();
399 self.step = 0;
400 }
401
402 pub fn num_cached_layers(&self) -> usize {
404 self.cache.len()
405 }
406
407 pub fn current_memory_usage(&self) -> usize {
409 let bytes_per_element = 4; self.cache
411 .values()
412 .map(|entry| (entry.keys.len() + entry.values.len()) * bytes_per_element)
413 .sum()
414 }
415
416 pub fn current_memory_usage_mb(&self) -> f64 {
418 self.current_memory_usage() as f64 / (1024.0 * 1024.0)
419 }
420
421 pub fn stats(&self) -> CacheStats {
423 CacheStats {
424 num_layers: self.cache.len(),
425 total_seq_len: self
426 .cache
427 .values()
428 .map(|entry| entry.seq_len)
429 .max()
430 .unwrap_or(0),
431 memory_usage_mb: self.current_memory_usage_mb(),
432 max_memory_mb: self.config.memory_usage_mb(),
433 step: self.step,
434 enabled: self.config.enabled,
435 }
436 }
437}
438
/// Point-in-time occupancy figures for a KV cache.
#[derive(Clone, Debug)]
pub struct CacheStats {
    /// Number of layers that have an entry.
    pub num_layers: usize,
    /// Longest cached sequence length across layers.
    pub total_seq_len: usize,
    /// Memory currently held by cached keys and values, in MB (1024 * 1024 bytes).
    pub memory_usage_mb: f64,
    /// Memory a completely full cache would need, in MB (1024 * 1024 bytes).
    pub max_memory_mb: f64,
    /// Step counter at the time of the snapshot.
    pub step: usize,
    /// Whether caching was enabled.
    pub enabled: bool,
}

impl CacheStats {
    /// Formats the stats as a multi-line, human-readable report.
    ///
    /// Utilisation is `memory_usage_mb / max_memory_mb * 100`, shown as `0.0`
    /// when `max_memory_mb` is not positive.
    pub fn summary(&self) -> String {
        let utilisation_pct = if self.max_memory_mb > 0.0 {
            self.memory_usage_mb / self.max_memory_mb * 100.0
        } else {
            0.0
        };

        format!(
            "CacheStats:\n Layers: {}\n Seq len: {}\n Memory: {:.1}/{:.1} MB ({:.1}%)\n Step: {}\n Enabled: {}",
            self.num_layers,
            self.total_seq_len,
            self.memory_usage_mb,
            self.max_memory_mb,
            utilisation_pct,
            self.step,
            self.enabled,
        )
    }
}
475
#[cfg(test)]
mod tests {
    use super::*;

    const NUM_LAYERS: usize = 12;
    const NUM_HEADS: usize = 8;
    const HEAD_DIM: usize = 64;

    /// One token of keys or values for a batch of one, filled with `fill`.
    fn token(fill: f32) -> Vec<f32> {
        vec![fill; NUM_HEADS * HEAD_DIM]
    }

    /// A cache with the dimensions used throughout these tests and default limits.
    fn new_cache() -> KVCache {
        KVCache::new(NUM_LAYERS, NUM_HEADS, HEAD_DIM)
    }

    #[test]
    fn test_kv_cache_config_creation() {
        let config = KVCacheConfig::new(NUM_LAYERS, NUM_HEADS, HEAD_DIM);
        assert_eq!(config.num_layers, 12);
        assert_eq!(config.num_heads, 8);
        assert_eq!(config.head_dim, 64);
        assert!(config.enabled);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_config_builder() {
        let config = KVCacheConfig::new(NUM_LAYERS, NUM_HEADS, HEAD_DIM)
            .with_max_seq_len(4096)
            .with_max_batch_size(16)
            .with_enabled(false);

        assert_eq!(config.max_seq_len, 4096);
        assert_eq!(config.max_batch_size, 16);
        assert!(!config.enabled);
    }

    #[test]
    fn test_config_validation() {
        assert!(KVCacheConfig::new(0, 8, 64).validate().is_err());
        assert!(KVCacheConfig::new(12, 0, 64).validate().is_err());
        assert!(KVCacheConfig::new(12, 8, 0).validate().is_err());
        assert!(KVCacheConfig::new(12, 8, 64)
            .with_max_seq_len(0)
            .validate()
            .is_err());
        assert!(KVCacheConfig::new(12, 8, 64)
            .with_max_batch_size(0)
            .validate()
            .is_err());
    }

    #[test]
    fn test_memory_usage_calculation() {
        let config = KVCacheConfig::new(NUM_LAYERS, NUM_HEADS, HEAD_DIM);
        // batch 32 * heads 8 * seq 2048 * dim 64 * (K + V) 2 * layers 12 * 4 bytes = 3 GiB.
        assert_eq!(config.memory_usage(), 32 * 8 * 2048 * 64 * 2 * 12 * 4);
        assert_eq!(config.memory_usage_mb(), 3072.0);
    }

    #[test]
    fn test_kv_cache_creation() {
        let cache = new_cache();
        assert_eq!(cache.config().num_layers, 12);
        assert_eq!(cache.step(), 0);
        assert!(cache.is_enabled());
    }

    #[test]
    fn test_cache_from_config() {
        let config = KVCacheConfig::new(NUM_LAYERS, NUM_HEADS, HEAD_DIM);
        let cache = KVCache::from_config(config).unwrap();
        assert_eq!(cache.config().num_layers, 12);

        assert!(KVCache::from_config(KVCacheConfig::new(0, 8, 64)).is_err());
    }

    #[test]
    fn test_init_layer() {
        let mut cache = new_cache();
        assert!(cache.init_layer(0, 1).is_ok());
        assert_eq!(cache.num_cached_layers(), 1);
    }

    #[test]
    fn test_init_layer_invalid_index() {
        let mut cache = new_cache();
        assert!(cache.init_layer(20, 1).is_err());
        // Default max_batch_size is 32.
        assert!(cache.init_layer(0, 33).is_err());
        assert!(cache.update_layer(20, token(0.1), token(0.2)).is_err());
    }

    #[test]
    fn test_update_and_get_layer() {
        let mut cache = new_cache();
        let keys = token(0.1);
        let values = token(0.2);

        cache.update_layer(0, keys.clone(), values.clone()).unwrap();

        let (cached_keys, cached_values) = cache.get_layer(0).unwrap();
        assert_eq!(cached_keys, keys.as_slice());
        assert_eq!(cached_values, values.as_slice());
    }

    #[test]
    fn test_update_multiple_steps() {
        let mut cache = new_cache();

        cache.update_layer(0, token(0.1), token(0.2)).unwrap();
        assert_eq!(cache.get_seq_len(0).unwrap(), 1);

        cache.update_layer(0, token(0.3), token(0.4)).unwrap();
        assert_eq!(cache.get_seq_len(0).unwrap(), 2);

        let (cached_keys, _) = cache.get_layer(0).unwrap();
        assert_eq!(cached_keys.len(), 2 * NUM_HEADS * HEAD_DIM);
    }

    #[test]
    fn test_update_beyond_max_seq_len_fails() {
        let config = KVCacheConfig::new(NUM_LAYERS, NUM_HEADS, HEAD_DIM).with_max_seq_len(1);
        let mut cache = KVCache::from_config(config).unwrap();

        cache.update_layer(0, token(0.1), token(0.2)).unwrap();
        assert!(cache.update_layer(0, token(0.3), token(0.4)).is_err());
    }

    #[test]
    fn test_clear_layer() {
        let mut cache = new_cache();

        cache.update_layer(0, token(0.1), token(0.2)).unwrap();
        assert_eq!(cache.get_seq_len(0).unwrap(), 1);

        cache.clear_layer(0);
        assert_eq!(cache.get_seq_len(0).unwrap(), 0);
    }

    #[test]
    fn test_clear_all() {
        let mut cache = new_cache();

        cache.update_layer(0, token(0.1), token(0.2)).unwrap();
        cache.update_layer(1, token(0.1), token(0.2)).unwrap();
        assert_eq!(cache.num_cached_layers(), 2);

        cache.clear_all();
        assert_eq!(cache.get_seq_len(0).unwrap(), 0);
        assert_eq!(cache.get_seq_len(1).unwrap(), 0);
        assert_eq!(cache.step(), 0);
    }

    #[test]
    fn test_reset() {
        let mut cache = new_cache();

        cache.update_layer(0, token(0.1), token(0.2)).unwrap();
        cache.next_step();
        assert_eq!(cache.step(), 1);

        cache.reset();
        assert_eq!(cache.num_cached_layers(), 0);
        assert_eq!(cache.step(), 0);
    }

    #[test]
    fn test_next_step() {
        let mut cache = new_cache();
        assert_eq!(cache.step(), 0);

        cache.next_step();
        assert_eq!(cache.step(), 1);

        cache.next_step();
        assert_eq!(cache.step(), 2);
    }

    #[test]
    fn test_memory_tracking() {
        let mut cache = new_cache();
        assert_eq!(cache.current_memory_usage(), 0);

        cache.update_layer(0, token(0.1), token(0.2)).unwrap();

        // One token of keys plus one of values, 4 bytes per f32.
        assert_eq!(cache.current_memory_usage(), 2 * NUM_HEADS * HEAD_DIM * 4);
        assert!(cache.current_memory_usage_mb() > 0.0);
    }

    #[test]
    fn test_cache_stats() {
        let mut cache = new_cache();

        cache.update_layer(0, token(0.1), token(0.2)).unwrap();
        cache.next_step();

        let stats = cache.stats();
        assert_eq!(stats.num_layers, 1);
        assert_eq!(stats.total_seq_len, 1);
        assert!(stats.memory_usage_mb > 0.0);
        assert_eq!(stats.step, 1);
        assert!(stats.enabled);
    }

    #[test]
    fn test_stats_summary() {
        let mut cache = new_cache();

        cache.update_layer(0, token(0.1), token(0.2)).unwrap();

        let summary = cache.stats().summary();
        assert!(summary.contains("Layers: 1"));
        assert!(summary.contains("Seq len: 1"));
    }

    #[test]
    fn test_disabled_cache() {
        let config = KVCacheConfig::new(NUM_LAYERS, NUM_HEADS, HEAD_DIM).with_enabled(false);
        let mut cache = KVCache::from_config(config).unwrap();
        assert!(!cache.is_enabled());

        cache.update_layer(0, token(0.1), token(0.2)).unwrap();
        // Disabled caches ignore updates entirely.
        assert_eq!(cache.num_cached_layers(), 0);
        assert!(cache.get_layer(0).is_err());
    }

    #[test]
    fn test_mismatched_key_value_sizes() {
        let mut cache = new_cache();
        let keys = token(0.1);
        let values = vec![0.2f32; 4 * HEAD_DIM];
        assert!(cache.update_layer(0, keys, values).is_err());
    }

    #[test]
    fn test_cache_entry_is_empty() {
        let entry = CacheEntry::new(1, NUM_HEADS, HEAD_DIM, 2048);
        assert!(entry.is_empty());
        assert_eq!(entry.len(), 0);
    }

    #[test]
    fn test_get_nonexistent_layer() {
        let cache = new_cache();
        assert!(cache.get_layer(0).is_err());
        assert!(cache.get_seq_len(0).is_err());
    }
}