Skip to main content

tensorlogic_trustformers/kv_cache/
stats.rs

1use super::simple_cache::KvCache;
2
/// Counters describing an autoregressive generation run that uses a KV-cache.
///
/// All fields are public so callers can read (or reset) individual counters directly.
#[derive(Clone, Debug)]
pub struct InferenceStats {
    /// How many tokens have been produced so far.
    pub tokens_generated: usize,
    /// How many recorded steps were able to reuse cached keys/values.
    pub cache_hits: usize,
    /// How many attention operations have been executed in total.
    pub total_attention_ops: usize,
    /// Mean cache sequence length over every step recorded so far.
    pub avg_cache_len: f64,
    /// Largest memory footprint, in bytes, seen across all recorded steps.
    pub peak_memory_bytes: usize,
}
17
18impl InferenceStats {
19    /// Create a zeroed `InferenceStats`.
20    pub fn new() -> Self {
21        Self {
22            tokens_generated: 0,
23            cache_hits: 0,
24            total_attention_ops: 0,
25            avg_cache_len: 0.0,
26            peak_memory_bytes: 0,
27        }
28    }
29
30    /// Record a single generation step given the current cache state.
31    pub fn record_step(&mut self, cache: &KvCache) {
32        self.tokens_generated += 1;
33        let cache_len = cache.current_len();
34        if cache_len > 0 {
35            self.cache_hits += 1;
36        }
37        self.total_attention_ops += 1;
38
39        // Exponential moving average of cache length.
40        let n = self.total_attention_ops as f64;
41        self.avg_cache_len = ((n - 1.0) * self.avg_cache_len + cache_len as f64) / n;
42
43        let mem = cache.memory_usage_bytes();
44        if mem > self.peak_memory_bytes {
45            self.peak_memory_bytes = mem;
46        }
47    }
48
49    /// Return a human-readable summary string.
50    pub fn summary(&self) -> String {
51        format!(
52            "InferenceStats {{ tokens_generated: {}, cache_hits: {}, total_attention_ops: {}, avg_cache_len: {:.1}, peak_memory_bytes: {} }}",
53            self.tokens_generated,
54            self.cache_hits,
55            self.total_attention_ops,
56            self.avg_cache_len,
57            self.peak_memory_bytes
58        )
59    }
60}
61
62impl Default for InferenceStats {
63    fn default() -> Self {
64        Self::new()
65    }
66}
67
#[cfg(test)]
mod tests {
    use super::*;

    /// One step must update every counter that `record_step` touches, not just
    /// `tokens_generated`.
    #[test]
    fn test_inference_stats_record_step() {
        let mut stats = InferenceStats::new();
        let cache = KvCache::new(1, 2, 4, 16);
        stats.record_step(&cache);
        assert_eq!(
            stats.tokens_generated, 1,
            "tokens_generated should increment"
        );
        assert_eq!(
            stats.total_attention_ops, 1,
            "total_attention_ops should increment"
        );
        assert_eq!(
            stats.cache_hits,
            usize::from(cache.current_len() > 0),
            "cache_hits counts only steps with a non-empty cache"
        );
        assert!(
            (stats.avg_cache_len - cache.current_len() as f64).abs() < 1e-9,
            "after one step the mean equals the observed cache length"
        );
        assert_eq!(
            stats.peak_memory_bytes,
            cache.memory_usage_bytes(),
            "peak memory should reflect the only observation"
        );
    }

    /// Recording the same cache repeatedly must not make the mean drift or the peak grow.
    #[test]
    fn test_inference_stats_repeated_steps_keep_mean_and_peak() {
        let mut stats = InferenceStats::new();
        let cache = KvCache::new(1, 2, 4, 16);
        for _ in 0..3 {
            stats.record_step(&cache);
        }
        assert_eq!(stats.tokens_generated, 3);
        assert_eq!(stats.total_attention_ops, 3);
        assert!(
            (stats.avg_cache_len - cache.current_len() as f64).abs() < 1e-9,
            "mean of identical observations must equal that observation"
        );
        assert_eq!(stats.peak_memory_bytes, cache.memory_usage_bytes());
    }

    #[test]
    fn test_inference_stats_summary_non_empty() {
        let stats = InferenceStats::new();
        let s = stats.summary();
        assert!(!s.is_empty(), "summary must return non-empty string");
        assert!(
            s.contains("tokens_generated"),
            "summary should contain field names"
        );
        assert!(
            s.contains("avg_cache_len: 0.0"),
            "avg_cache_len is printed with one decimal place"
        );
    }

    #[test]
    fn test_inference_stats_default_matches_new() {
        assert_eq!(
            InferenceStats::default().summary(),
            InferenceStats::new().summary()
        );
    }
}