Skip to main content

tensorlogic_trustformers/
kv_cache.rs

1//! Key-Value cache for efficient autoregressive inference.
2//!
3//! During autoregressive generation (e.g., text generation), transformers repeatedly
4//! compute attention over the same prefix tokens. KV-caching stores the key and value
5//! projections from previous steps, avoiding redundant computation.
6//!
7//! ## Performance Impact
8//!
9//! Without KV-cache:
10//! ```text
11//! Step 1: Compute attention for token 1
12//! Step 2: Compute attention for tokens 1,2    (redundant!)
13//! Step 3: Compute attention for tokens 1,2,3  (redundant!)
14//! ```
15//!
16//! With KV-cache:
17//! ```text
18//! Step 1: Compute K,V for token 1, cache them
19//! Step 2: Compute K,V for token 2, append to cache
20//! Step 3: Compute K,V for token 3, append to cache
21//! ```
22//!
23//! **Speedup**: ~10-100x for long sequences!
24//!
25//! ## Usage
26//!
//! ```rust,no_run
//! use tensorlogic_trustformers::KVCache;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // Create cache for a 12-layer model with 12 heads of dimension 64
//! let mut cache = KVCache::new(12, 12, 64);
//!
//! // During generation, update the cache for each layer. One new token for a
//! // single sequence is `num_heads * head_dim` elements.
//! # let (new_keys, new_values) = (vec![0.0_f32; 12 * 64], vec![0.0_f32; 12 * 64]);
//! cache.update_layer(0, new_keys, new_values)?;
//!
//! // Retrieve cached keys/values for attention
//! let (cached_keys, cached_values) = cache.get_layer(0)?;
//! # Ok(())
//! # }
//! ```
43
44use serde::{Deserialize, Serialize};
45use std::collections::HashMap;
46
47use crate::error::{Result, TrustformerError};
48
49/// Configuration for KV-cache
50#[derive(Clone, Debug, Serialize, Deserialize)]
51pub struct KVCacheConfig {
52    /// Number of layers in the model
53    pub num_layers: usize,
54    /// Number of attention heads per layer
55    pub num_heads: usize,
56    /// Dimension per attention head (d_k)
57    pub head_dim: usize,
58    /// Maximum sequence length to cache
59    pub max_seq_len: usize,
60    /// Maximum batch size
61    pub max_batch_size: usize,
62    /// Whether to enable cache
63    pub enabled: bool,
64}
65
66impl KVCacheConfig {
67    /// Create a new KV-cache configuration
68    pub fn new(num_layers: usize, num_heads: usize, head_dim: usize) -> Self {
69        Self {
70            num_layers,
71            num_heads,
72            head_dim,
73            max_seq_len: 2048,
74            max_batch_size: 32,
75            enabled: true,
76        }
77    }
78
79    /// Set maximum sequence length
80    pub fn with_max_seq_len(mut self, max_seq_len: usize) -> Self {
81        self.max_seq_len = max_seq_len;
82        self
83    }
84
85    /// Set maximum batch size
86    pub fn with_max_batch_size(mut self, max_batch_size: usize) -> Self {
87        self.max_batch_size = max_batch_size;
88        self
89    }
90
91    /// Enable or disable cache
92    pub fn with_enabled(mut self, enabled: bool) -> Self {
93        self.enabled = enabled;
94        self
95    }
96
97    /// Validate configuration
98    pub fn validate(&self) -> Result<()> {
99        if self.num_layers == 0 {
100            return Err(TrustformerError::InvalidDimension {
101                expected: 1,
102                got: 0,
103                context: "num_layers must be > 0".to_string(),
104            });
105        }
106
107        if self.num_heads == 0 {
108            return Err(TrustformerError::InvalidDimension {
109                expected: 1,
110                got: 0,
111                context: "num_heads must be > 0".to_string(),
112            });
113        }
114
115        if self.head_dim == 0 {
116            return Err(TrustformerError::InvalidDimension {
117                expected: 1,
118                got: 0,
119                context: "head_dim must be > 0".to_string(),
120            });
121        }
122
123        if self.max_seq_len == 0 {
124            return Err(TrustformerError::InvalidDimension {
125                expected: 1,
126                got: 0,
127                context: "max_seq_len must be > 0".to_string(),
128            });
129        }
130
131        if self.max_batch_size == 0 {
132            return Err(TrustformerError::InvalidDimension {
133                expected: 1,
134                got: 0,
135                context: "max_batch_size must be > 0".to_string(),
136            });
137        }
138
139        Ok(())
140    }
141
142    /// Calculate memory usage in bytes
143    pub fn memory_usage(&self) -> usize {
144        // Each cache entry: [batch, num_heads, seq_len, head_dim]
145        // We store both keys and values
146        // Assume f32 (4 bytes per element)
147        let bytes_per_element = 4;
148        let elements_per_layer =
149            self.max_batch_size * self.num_heads * self.max_seq_len * self.head_dim * 2; // keys + values
150
151        elements_per_layer * self.num_layers * bytes_per_element
152    }
153
154    /// Human-readable memory usage
155    pub fn memory_usage_mb(&self) -> f64 {
156        self.memory_usage() as f64 / (1024.0 * 1024.0)
157    }
158}
159
/// Per-layer storage: the accumulated keys and values plus bookkeeping.
#[derive(Debug, Clone)]
pub struct CacheEntry {
    /// Flat buffer of cached key elements. The intended logical shape is
    /// `[batch, num_heads, seq_len, head_dim]`.
    pub keys: Vec<f32>,
    /// Flat buffer of cached value elements, same intended shape as `keys`.
    pub values: Vec<f32>,
    /// Number of tokens currently held in this entry.
    pub seq_len: usize,
    /// Batch size this entry was created for.
    pub batch_size: usize,
}
172
173impl CacheEntry {
174    /// Create a new empty cache entry
175    pub fn new(batch_size: usize, num_heads: usize, head_dim: usize, max_seq_len: usize) -> Self {
176        let capacity = batch_size * num_heads * max_seq_len * head_dim;
177        Self {
178            keys: Vec::with_capacity(capacity),
179            values: Vec::with_capacity(capacity),
180            seq_len: 0,
181            batch_size,
182        }
183    }
184
185    /// Check if cache is empty
186    pub fn is_empty(&self) -> bool {
187        self.seq_len == 0
188    }
189
190    /// Get current sequence length
191    pub fn len(&self) -> usize {
192        self.seq_len
193    }
194
195    /// Clear the cache
196    pub fn clear(&mut self) {
197        self.keys.clear();
198        self.values.clear();
199        self.seq_len = 0;
200    }
201}
202
203/// Key-Value cache for efficient transformer inference
204#[derive(Clone, Debug)]
205pub struct KVCache {
206    /// Configuration
207    config: KVCacheConfig,
208    /// Cache entries per layer
209    cache: HashMap<usize, CacheEntry>,
210    /// Current generation step
211    step: usize,
212}
213
214impl KVCache {
215    /// Create a new KV-cache
216    pub fn new(num_layers: usize, num_heads: usize, head_dim: usize) -> Self {
217        let config = KVCacheConfig::new(num_layers, num_heads, head_dim);
218        Self {
219            config,
220            cache: HashMap::new(),
221            step: 0,
222        }
223    }
224
225    /// Create KV-cache from configuration
226    pub fn from_config(config: KVCacheConfig) -> Result<Self> {
227        config.validate()?;
228        Ok(Self {
229            config,
230            cache: HashMap::new(),
231            step: 0,
232        })
233    }
234
235    /// Get configuration
236    pub fn config(&self) -> &KVCacheConfig {
237        &self.config
238    }
239
240    /// Check if cache is enabled
241    pub fn is_enabled(&self) -> bool {
242        self.config.enabled
243    }
244
245    /// Get current generation step
246    pub fn step(&self) -> usize {
247        self.step
248    }
249
250    /// Initialize cache for a layer
251    pub fn init_layer(&mut self, layer_idx: usize, batch_size: usize) -> Result<()> {
252        if layer_idx >= self.config.num_layers {
253            return Err(TrustformerError::InvalidDimension {
254                expected: self.config.num_layers,
255                got: layer_idx,
256                context: format!(
257                    "layer_idx {} >= num_layers {}",
258                    layer_idx, self.config.num_layers
259                ),
260            });
261        }
262
263        if batch_size > self.config.max_batch_size {
264            return Err(TrustformerError::InvalidDimension {
265                expected: self.config.max_batch_size,
266                got: batch_size,
267                context: format!(
268                    "batch_size {} > max_batch_size {}",
269                    batch_size, self.config.max_batch_size
270                ),
271            });
272        }
273
274        let entry = CacheEntry::new(
275            batch_size,
276            self.config.num_heads,
277            self.config.head_dim,
278            self.config.max_seq_len,
279        );
280        self.cache.insert(layer_idx, entry);
281        Ok(())
282    }
283
284    /// Update cache for a layer with new keys and values
285    pub fn update_layer(
286        &mut self,
287        layer_idx: usize,
288        new_keys: Vec<f32>,
289        new_values: Vec<f32>,
290    ) -> Result<()> {
291        if !self.config.enabled {
292            return Ok(());
293        }
294
295        // Initialize layer if not present
296        if !self.cache.contains_key(&layer_idx) {
297            // Infer batch size from keys shape
298            // Assuming keys shape: [batch, num_heads, new_seq_len, head_dim]
299            let expected_size_per_token = self.config.num_heads * self.config.head_dim;
300
301            if !new_keys.len().is_multiple_of(expected_size_per_token) {
302                return Err(TrustformerError::InvalidDimension {
303                    expected: expected_size_per_token,
304                    got: new_keys.len(),
305                    context: "keys size must be divisible by num_heads * head_dim".to_string(),
306                });
307            }
308
309            let batch_size = new_keys.len() / expected_size_per_token;
310            self.init_layer(layer_idx, batch_size)?;
311        }
312
313        let entry = self.cache.get_mut(&layer_idx).unwrap();
314
315        // Validate sizes
316        if new_keys.len() != new_values.len() {
317            return Err(TrustformerError::InvalidDimension {
318                expected: new_keys.len(),
319                got: new_values.len(),
320                context: "keys and values must have same size".to_string(),
321            });
322        }
323
324        // Append new keys and values
325        entry.keys.extend_from_slice(&new_keys);
326        entry.values.extend_from_slice(&new_values);
327
328        // Update sequence length
329        let new_tokens =
330            new_keys.len() / (entry.batch_size * self.config.num_heads * self.config.head_dim);
331        entry.seq_len += new_tokens;
332
333        // Check if we exceeded max sequence length
334        if entry.seq_len > self.config.max_seq_len {
335            return Err(TrustformerError::InvalidDimension {
336                expected: self.config.max_seq_len,
337                got: entry.seq_len,
338                context: format!(
339                    "cache exceeded max_seq_len {} (current: {})",
340                    self.config.max_seq_len, entry.seq_len
341                ),
342            });
343        }
344
345        Ok(())
346    }
347
348    /// Get cached keys and values for a layer
349    pub fn get_layer(&self, layer_idx: usize) -> Result<(&[f32], &[f32])> {
350        let entry =
351            self.cache
352                .get(&layer_idx)
353                .ok_or_else(|| TrustformerError::InvalidDimension {
354                    expected: 1,
355                    got: 0,
356                    context: format!("layer {} not found in cache", layer_idx),
357                })?;
358
359        Ok((&entry.keys, &entry.values))
360    }
361
362    /// Get sequence length for a layer
363    pub fn get_seq_len(&self, layer_idx: usize) -> Result<usize> {
364        let entry =
365            self.cache
366                .get(&layer_idx)
367                .ok_or_else(|| TrustformerError::InvalidDimension {
368                    expected: 1,
369                    got: 0,
370                    context: format!("layer {} not found in cache", layer_idx),
371                })?;
372
373        Ok(entry.seq_len)
374    }
375
376    /// Clear cache for a specific layer
377    pub fn clear_layer(&mut self, layer_idx: usize) {
378        if let Some(entry) = self.cache.get_mut(&layer_idx) {
379            entry.clear();
380        }
381    }
382
383    /// Clear all cache entries
384    pub fn clear_all(&mut self) {
385        for entry in self.cache.values_mut() {
386            entry.clear();
387        }
388        self.step = 0;
389    }
390
391    /// Increment generation step
392    pub fn next_step(&mut self) {
393        self.step += 1;
394    }
395
396    /// Reset to initial state
397    pub fn reset(&mut self) {
398        self.cache.clear();
399        self.step = 0;
400    }
401
402    /// Get number of cached layers
403    pub fn num_cached_layers(&self) -> usize {
404        self.cache.len()
405    }
406
407    /// Calculate current memory usage
408    pub fn current_memory_usage(&self) -> usize {
409        let bytes_per_element = 4; // f32
410        self.cache
411            .values()
412            .map(|entry| (entry.keys.len() + entry.values.len()) * bytes_per_element)
413            .sum()
414    }
415
416    /// Calculate memory usage in MB
417    pub fn current_memory_usage_mb(&self) -> f64 {
418        self.current_memory_usage() as f64 / (1024.0 * 1024.0)
419    }
420
421    /// Get cache statistics
422    pub fn stats(&self) -> CacheStats {
423        CacheStats {
424            num_layers: self.cache.len(),
425            total_seq_len: self
426                .cache
427                .values()
428                .map(|entry| entry.seq_len)
429                .max()
430                .unwrap_or(0),
431            memory_usage_mb: self.current_memory_usage_mb(),
432            max_memory_mb: self.config.memory_usage_mb(),
433            step: self.step,
434            enabled: self.config.enabled,
435        }
436    }
437}
438
/// Point-in-time summary of a KV-cache's state.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// How many layers currently have a cache entry.
    pub num_layers: usize,
    /// Longest cached sequence over all layers (a maximum, not a sum).
    pub total_seq_len: usize,
    /// Memory currently in use, in MB.
    pub memory_usage_mb: f64,
    /// Memory a completely full cache would need, in MB.
    pub max_memory_mb: f64,
    /// Generation step at the time of the snapshot.
    pub step: usize,
    /// Whether caching was enabled.
    pub enabled: bool,
}
455
456impl CacheStats {
457    /// Format statistics as human-readable string
458    pub fn summary(&self) -> String {
459        format!(
460            "CacheStats:\n  Layers: {}\n  Seq len: {}\n  Memory: {:.1}/{:.1} MB ({:.1}%)\n  Step: {}\n  Enabled: {}",
461            self.num_layers,
462            self.total_seq_len,
463            self.memory_usage_mb,
464            self.max_memory_mb,
465            if self.max_memory_mb > 0.0 {
466                (self.memory_usage_mb / self.max_memory_mb) * 100.0
467            } else {
468                0.0
469            },
470            self.step,
471            self.enabled
472        )
473    }
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture shape: 12 layers, 8 heads, head dimension 64.
    const LAYERS: usize = 12;
    const HEADS: usize = 8;
    const HEAD_DIM: usize = 64;
    /// Elements in one token for a batch of one.
    const TOKEN_ELEMS: usize = HEADS * HEAD_DIM;

    fn base_config() -> KVCacheConfig {
        KVCacheConfig::new(LAYERS, HEADS, HEAD_DIM)
    }

    fn fresh_cache() -> KVCache {
        KVCache::new(LAYERS, HEADS, HEAD_DIM)
    }

    /// One token's worth of data (batch = 1), filled with `fill`.
    fn token(fill: f32) -> Vec<f32> {
        vec![fill; TOKEN_ELEMS]
    }

    /// A cache that already holds one token in layer 0.
    fn cache_with_one_token() -> KVCache {
        let mut cache = fresh_cache();
        cache.update_layer(0, token(0.1), token(0.2)).unwrap();
        cache
    }

    #[test]
    fn test_kv_cache_config_creation() {
        let config = base_config();
        assert_eq!(config.num_layers, 12);
        assert_eq!(config.num_heads, 8);
        assert_eq!(config.head_dim, 64);
        assert!(config.enabled);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_config_builder() {
        let config = base_config()
            .with_max_seq_len(4096)
            .with_max_batch_size(16)
            .with_enabled(false);

        assert_eq!(config.max_seq_len, 4096);
        assert_eq!(config.max_batch_size, 16);
        assert!(!config.enabled);
    }

    #[test]
    fn test_config_validation() {
        // Each shape has exactly one zero dimension.
        for &(layers, heads, dim) in [(0, 8, 64), (12, 0, 64), (12, 8, 0)].iter() {
            assert!(KVCacheConfig::new(layers, heads, dim).validate().is_err());
        }
    }

    #[test]
    fn test_memory_usage_calculation() {
        let config = base_config();
        assert!(config.memory_usage() > 0);
        assert!(config.memory_usage_mb() > 0.0);
    }

    #[test]
    fn test_kv_cache_creation() {
        let cache = fresh_cache();
        assert_eq!(cache.config().num_layers, 12);
        assert_eq!(cache.step(), 0);
        assert!(cache.is_enabled());
    }

    #[test]
    fn test_cache_from_config() {
        let cache = KVCache::from_config(base_config()).unwrap();
        assert_eq!(cache.config().num_layers, 12);
    }

    #[test]
    fn test_init_layer() {
        let mut cache = fresh_cache();
        assert!(cache.init_layer(0, 1).is_ok());
        assert_eq!(cache.num_cached_layers(), 1);
    }

    #[test]
    fn test_init_layer_invalid_index() {
        let mut cache = fresh_cache();
        assert!(cache.init_layer(20, 1).is_err());
    }

    #[test]
    fn test_update_and_get_layer() {
        // batch=1, heads=8, tokens=1, dim=64
        let cache = cache_with_one_token();

        let (cached_keys, cached_values) = cache.get_layer(0).unwrap();
        assert_eq!(cached_keys.len(), TOKEN_ELEMS);
        assert_eq!(cached_values.len(), TOKEN_ELEMS);
    }

    #[test]
    fn test_update_multiple_steps() {
        // Step 1: first token
        let mut cache = cache_with_one_token();
        assert_eq!(cache.get_seq_len(0).unwrap(), 1);

        // Step 2: second token
        cache.update_layer(0, token(0.3), token(0.4)).unwrap();
        assert_eq!(cache.get_seq_len(0).unwrap(), 2);

        // Both tokens are stored back to back.
        let (cached_keys, _) = cache.get_layer(0).unwrap();
        assert_eq!(cached_keys.len(), 2 * 8 * 64);
    }

    #[test]
    fn test_clear_layer() {
        let mut cache = cache_with_one_token();
        assert_eq!(cache.get_seq_len(0).unwrap(), 1);

        cache.clear_layer(0);
        assert_eq!(cache.get_seq_len(0).unwrap(), 0);
    }

    #[test]
    fn test_clear_all() {
        let mut cache = cache_with_one_token();
        cache.update_layer(1, token(0.1), token(0.2)).unwrap();
        assert_eq!(cache.num_cached_layers(), 2);

        cache.clear_all();
        for layer in 0..2 {
            assert_eq!(cache.get_seq_len(layer).unwrap(), 0);
        }
        assert_eq!(cache.step(), 0);
    }

    #[test]
    fn test_reset() {
        let mut cache = cache_with_one_token();
        cache.next_step();
        assert_eq!(cache.step(), 1);

        cache.reset();
        assert_eq!(cache.num_cached_layers(), 0);
        assert_eq!(cache.step(), 0);
    }

    #[test]
    fn test_next_step() {
        let mut cache = fresh_cache();
        for expected in 0..3 {
            assert_eq!(cache.step(), expected);
            cache.next_step();
        }
    }

    #[test]
    fn test_memory_tracking() {
        let mut cache = fresh_cache();
        assert_eq!(cache.current_memory_usage(), 0);

        cache.update_layer(0, token(0.1), token(0.2)).unwrap();

        assert!(cache.current_memory_usage() > 0);
        assert!(cache.current_memory_usage_mb() > 0.0);
    }

    #[test]
    fn test_cache_stats() {
        let mut cache = cache_with_one_token();
        cache.next_step();

        let stats = cache.stats();
        assert_eq!(stats.num_layers, 1);
        assert_eq!(stats.total_seq_len, 1);
        assert!(stats.memory_usage_mb > 0.0);
        assert_eq!(stats.step, 1);
        assert!(stats.enabled);
    }

    #[test]
    fn test_stats_summary() {
        let summary = cache_with_one_token().stats().summary();
        assert!(summary.contains("Layers: 1"));
        assert!(summary.contains("Seq len: 1"));
    }

    #[test]
    fn test_disabled_cache() {
        let mut cache = KVCache::from_config(base_config().with_enabled(false)).unwrap();
        assert!(!cache.is_enabled());

        // Should succeed but not actually cache
        cache.update_layer(0, token(0.1), token(0.2)).unwrap();
        assert_eq!(cache.num_cached_layers(), 0);
    }

    #[test]
    fn test_mismatched_key_value_sizes() {
        let mut cache = fresh_cache();
        let short_values = vec![0.2f32; 4 * 64]; // Wrong size

        assert!(cache.update_layer(0, token(0.1), short_values).is_err());
    }

    #[test]
    fn test_cache_entry_is_empty() {
        let entry = CacheEntry::new(1, 8, 64, 2048);
        assert!(entry.is_empty());
        assert_eq!(entry.len(), 0);
    }

    #[test]
    fn test_get_nonexistent_layer() {
        let cache = fresh_cache();
        assert!(cache.get_layer(0).is_err());
        assert!(cache.get_seq_len(0).is_err());
    }
}
724}