Skip to main content

tensorlogic_infer/
cache.rs

1//! Tensor caching and memory pooling for efficient execution.
2
3use std::collections::{HashMap, VecDeque};
4use std::hash::{Hash, Hasher};
5
/// Identifies a cached tensor by graph, node, and a hash of the node's inputs.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CacheKey {
    /// Optional identifier of the graph the node belongs to.
    pub graph_id: Option<String>,
    /// Identifier of the node whose output is cached.
    pub node_id: usize,
    /// Hash of the node's inputs; `0` unless [`CacheKey::with_inputs`] was called.
    pub input_hash: u64,
}

impl CacheKey {
    /// Creates a key for `node_id` with no graph and an input hash of `0`.
    pub fn new(node_id: usize) -> Self {
        Self {
            node_id,
            graph_id: None,
            input_hash: 0,
        }
    }

    /// Returns this key scoped to `graph_id`, replacing any previous graph id.
    pub fn with_graph(self, graph_id: impl Into<String>) -> Self {
        Self {
            graph_id: Some(graph_id.into()),
            ..self
        }
    }

    /// Returns this key with `input_hash` set by hashing every element of `inputs` in order.
    ///
    /// The hash comes from the standard library's `DefaultHasher`. Its algorithm is
    /// unspecified and may change between Rust releases, so the value should not be persisted.
    pub fn with_inputs<T: Hash>(self, inputs: &[T]) -> Self {
        let mut state = std::collections::hash_map::DefaultHasher::new();
        inputs.iter().for_each(|item| item.hash(&mut state));
        Self {
            input_hash: state.finish(),
            ..self
        }
    }
}
37
/// Cache eviction policy: how a bounded [`TensorCache`] chooses which entry to drop
/// when a size or byte limit would be exceeded.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EvictionPolicy {
    /// Least Recently Used: evicts the entry with the oldest last-access timestamp.
    LRU,
    /// First In First Out: evicts the entry that was inserted earliest.
    FIFO,
    /// Least Frequently Used: evicts the entry with the lowest access count.
    LFU,
    /// No eviction: configured size and byte limits are ignored, so the cache grows unbounded.
    None,
}
50
/// Counters describing how a [`TensorCache`] has been used.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups that found an entry.
    pub hits: usize,
    /// Lookups that found nothing.
    pub misses: usize,
    /// Entries removed by the eviction policy to make room.
    pub evictions: usize,
    /// Number of entries currently cached.
    pub current_size: usize,
    /// Largest `current_size` observed.
    pub peak_size: usize,
    /// Sum of the caller-reported sizes of the cached entries.
    pub total_bytes: usize,
}

impl CacheStats {
    /// Creates an all-zero set of counters.
    pub fn new() -> Self {
        Default::default()
    }

    /// Fraction of lookups that were hits, in `[0.0, 1.0]`; `0.0` before any lookup.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }

    /// Multi-line, human-readable report of all counters.
    pub fn summary(&self) -> String {
        let hit_percent = self.hit_rate() * 100.0;
        let megabytes = self.total_bytes as f64 / (1024.0 * 1024.0);
        format!(
            "Cache Stats:\n\
             - Hits: {} ({:.1}%)\n\
             - Misses: {}\n\
             - Evictions: {}\n\
             - Current size: {} entries\n\
             - Peak size: {} entries\n\
             - Total bytes: {} ({:.2} MB)",
            self.hits,
            hit_percent,
            self.misses,
            self.evictions,
            self.current_size,
            self.peak_size,
            self.total_bytes,
            megabytes
        )
    }
}
96
97/// Cached entry metadata
98#[derive(Debug, Clone)]
99struct CacheEntry<T> {
100    value: T,
101    size_bytes: usize,
102    access_count: usize,
103    last_access: usize, // Timestamp
104}
105
106/// Tensor cache with configurable eviction policy
107pub struct TensorCache<T> {
108    cache: HashMap<CacheKey, CacheEntry<T>>,
109    eviction_policy: EvictionPolicy,
110    max_size: Option<usize>,
111    max_bytes: Option<usize>,
112    stats: CacheStats,
113    access_counter: usize,
114    access_order: VecDeque<CacheKey>,
115}
116
117impl<T: Clone> TensorCache<T> {
118    pub fn new(eviction_policy: EvictionPolicy) -> Self {
119        TensorCache {
120            cache: HashMap::new(),
121            eviction_policy,
122            max_size: None,
123            max_bytes: None,
124            stats: CacheStats::new(),
125            access_counter: 0,
126            access_order: VecDeque::new(),
127        }
128    }
129
130    pub fn with_max_size(mut self, max_entries: usize) -> Self {
131        self.max_size = Some(max_entries);
132        self
133    }
134
135    pub fn with_max_bytes(mut self, max_bytes: usize) -> Self {
136        self.max_bytes = Some(max_bytes);
137        self
138    }
139
140    /// Insert a value into the cache
141    pub fn insert(&mut self, key: CacheKey, value: T, size_bytes: usize) {
142        // Check if eviction is needed
143        while self.should_evict(size_bytes) {
144            self.evict_one();
145        }
146
147        // Insert or update entry
148        if self.cache.contains_key(&key) {
149            // Update existing entry
150            if let Some(entry) = self.cache.get_mut(&key) {
151                self.stats.total_bytes -= entry.size_bytes;
152                entry.value = value;
153                entry.size_bytes = size_bytes;
154                entry.access_count += 1;
155                entry.last_access = self.access_counter;
156                self.stats.total_bytes += size_bytes;
157            }
158        } else {
159            // Insert new entry
160            let entry = CacheEntry {
161                value,
162                size_bytes,
163                access_count: 1,
164                last_access: self.access_counter,
165            };
166
167            self.cache.insert(key.clone(), entry);
168            self.stats.current_size += 1;
169            self.stats.peak_size = self.stats.peak_size.max(self.stats.current_size);
170            self.stats.total_bytes += size_bytes;
171
172            // Track access order for FIFO/LRU
173            self.access_order.push_back(key);
174        }
175
176        self.access_counter += 1;
177    }
178
179    /// Get a value from the cache
180    pub fn get(&mut self, key: &CacheKey) -> Option<T> {
181        if let Some(entry) = self.cache.get_mut(key) {
182            self.stats.hits += 1;
183            entry.access_count += 1;
184            entry.last_access = self.access_counter;
185            self.access_counter += 1;
186
187            // Update access order for LRU
188            if self.eviction_policy == EvictionPolicy::LRU {
189                self.access_order.retain(|k| k != key);
190                self.access_order.push_back(key.clone());
191            }
192
193            Some(entry.value.clone())
194        } else {
195            self.stats.misses += 1;
196            None
197        }
198    }
199
200    /// Check if a key exists in the cache without updating access stats
201    pub fn contains(&self, key: &CacheKey) -> bool {
202        self.cache.contains_key(key)
203    }
204
205    /// Remove a specific entry
206    pub fn remove(&mut self, key: &CacheKey) -> Option<T> {
207        if let Some(entry) = self.cache.remove(key) {
208            self.stats.current_size -= 1;
209            self.stats.total_bytes -= entry.size_bytes;
210            self.access_order.retain(|k| k != key);
211            Some(entry.value)
212        } else {
213            None
214        }
215    }
216
217    /// Clear all entries
218    pub fn clear(&mut self) {
219        self.cache.clear();
220        self.access_order.clear();
221        self.stats.current_size = 0;
222        self.stats.total_bytes = 0;
223    }
224
225    /// Get cache statistics
226    pub fn stats(&self) -> &CacheStats {
227        &self.stats
228    }
229
230    /// Reset statistics (keep cached entries)
231    pub fn reset_stats(&mut self) {
232        self.stats.hits = 0;
233        self.stats.misses = 0;
234        self.stats.evictions = 0;
235    }
236
237    fn should_evict(&self, new_size_bytes: usize) -> bool {
238        if self.eviction_policy == EvictionPolicy::None {
239            return false;
240        }
241
242        let size_exceeded = self
243            .max_size
244            .map(|max| self.stats.current_size >= max)
245            .unwrap_or(false);
246
247        let bytes_exceeded = self
248            .max_bytes
249            .map(|max| self.stats.total_bytes + new_size_bytes > max)
250            .unwrap_or(false);
251
252        size_exceeded || bytes_exceeded
253    }
254
255    fn evict_one(&mut self) {
256        let key_to_evict = match self.eviction_policy {
257            EvictionPolicy::LRU => self.find_lru_key(),
258            EvictionPolicy::FIFO => self.find_fifo_key(),
259            EvictionPolicy::LFU => self.find_lfu_key(),
260            EvictionPolicy::None => return,
261        };
262
263        if let Some(key) = key_to_evict {
264            self.remove(&key);
265            self.stats.evictions += 1;
266        }
267    }
268
269    fn find_lru_key(&self) -> Option<CacheKey> {
270        self.cache
271            .iter()
272            .min_by_key(|(_, entry)| entry.last_access)
273            .map(|(key, _)| key.clone())
274    }
275
276    fn find_fifo_key(&self) -> Option<CacheKey> {
277        self.access_order.front().cloned()
278    }
279
280    fn find_lfu_key(&self) -> Option<CacheKey> {
281        self.cache
282            .iter()
283            .min_by_key(|(_, entry)| entry.access_count)
284            .map(|(key, _)| key.clone())
285    }
286
287    pub fn len(&self) -> usize {
288        self.stats.current_size
289    }
290
291    pub fn is_empty(&self) -> bool {
292        self.cache.is_empty()
293    }
294}
295
296impl<T: Clone> Default for TensorCache<T> {
297    fn default() -> Self {
298        Self::new(EvictionPolicy::LRU)
299    }
300}
301
/// Memory pool for tensor allocation reuse.
///
/// Released values are kept in per-size-class free lists. [`MemoryPool::acquire`] hands
/// them back out before it calls the allocator closure.
pub struct MemoryPool<T> {
    /// Free lists keyed by size class.
    pools: HashMap<usize, Vec<T>>,
    /// Usage counters.
    stats: PoolStats,
    /// Maximum number of values kept per size class. Releases beyond it are dropped;
    /// `None` means no limit.
    max_pool_size: Option<usize>,
}
308
/// Counters describing how a [`MemoryPool`] has been used.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Times the allocator closure had to be called.
    pub allocations: usize,
    /// Acquisitions satisfied from a free list.
    pub reuses: usize,
    /// Values handed back via `release`, including ones dropped because the pool was full.
    pub releases: usize,
    /// High-water mark of outstanding allocations, as tracked by `MemoryPool::acquire`.
    pub peak_allocations: usize,
}

impl PoolStats {
    /// Fraction of acquisitions served from the pool, in `[0.0, 1.0]`;
    /// `0.0` before anything has been acquired.
    pub fn reuse_rate(&self) -> f64 {
        let acquisitions = self.allocations + self.reuses;
        if acquisitions > 0 {
            self.reuses as f64 / acquisitions as f64
        } else {
            0.0
        }
    }

    /// Multi-line, human-readable report of all counters.
    pub fn summary(&self) -> String {
        let reuse_percent = self.reuse_rate() * 100.0;
        format!(
            "Memory Pool Stats:\n\
             - Allocations: {}\n\
             - Reuses: {} ({:.1}%)\n\
             - Releases: {}\n\
             - Peak allocations: {}",
            self.allocations, self.reuses, reuse_percent, self.releases, self.peak_allocations
        )
    }
}
343
344impl<T> MemoryPool<T> {
345    pub fn new() -> Self {
346        MemoryPool {
347            pools: HashMap::new(),
348            stats: PoolStats::default(),
349            max_pool_size: Some(100), // Default max 100 per size class
350        }
351    }
352
353    pub fn with_max_pool_size(mut self, max_size: usize) -> Self {
354        self.max_pool_size = Some(max_size);
355        self
356    }
357
358    /// Acquire a tensor from the pool or allocate new
359    pub fn acquire<F>(&mut self, size_class: usize, allocator: F) -> T
360    where
361        F: FnOnce() -> T,
362    {
363        if let Some(pool) = self.pools.get_mut(&size_class) {
364            if let Some(tensor) = pool.pop() {
365                self.stats.reuses += 1;
366                return tensor;
367            }
368        }
369
370        self.stats.allocations += 1;
371        self.stats.peak_allocations = self
372            .stats
373            .peak_allocations
374            .max(self.stats.allocations - self.stats.releases);
375
376        allocator()
377    }
378
379    /// Release a tensor back to the pool
380    pub fn release(&mut self, size_class: usize, tensor: T) {
381        let pool = self.pools.entry(size_class).or_default();
382
383        // Check pool size limit
384        if let Some(max_size) = self.max_pool_size {
385            if pool.len() >= max_size {
386                // Pool is full, drop the tensor
387                self.stats.releases += 1;
388                return;
389            }
390        }
391
392        pool.push(tensor);
393        self.stats.releases += 1;
394    }
395
396    /// Clear all pools
397    pub fn clear(&mut self) {
398        self.pools.clear();
399        self.stats = PoolStats::default();
400    }
401
402    /// Get pool statistics
403    pub fn stats(&self) -> &PoolStats {
404        &self.stats
405    }
406
407    /// Get total number of pooled tensors
408    pub fn total_pooled(&self) -> usize {
409        self.pools.values().map(|v| v.len()).sum()
410    }
411}
412
413impl<T> Default for MemoryPool<T> {
414    fn default() -> Self {
415        Self::new()
416    }
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a key that only sets `node_id`.
    fn node(id: usize) -> CacheKey {
        CacheKey::new(id)
    }

    #[test]
    fn test_cache_key_creation() {
        let plain = node(0);
        assert_eq!(plain.node_id, 0);
        assert_eq!(plain.input_hash, 0);

        let scoped = node(1).with_graph("graph1");
        assert_eq!(scoped.graph_id.as_deref(), Some("graph1"));

        let hashed = node(2).with_inputs(&[1, 2, 3]);
        assert_ne!(hashed.input_hash, 0);
    }

    #[test]
    fn test_cache_basic_operations() {
        let mut cache = TensorCache::<i32>::new(EvictionPolicy::LRU);
        let present = node(0);
        cache.insert(present.clone(), 42, 4);

        assert_eq!(cache.get(&present), Some(42));
        assert_eq!(cache.stats().hits, 1);
        assert_eq!(cache.stats().misses, 0);

        assert_eq!(cache.get(&node(1)), None);
        assert_eq!(cache.stats().misses, 1);
    }

    #[test]
    fn test_cache_lru_eviction() {
        let mut cache = TensorCache::<i32>::new(EvictionPolicy::LRU).with_max_size(2);
        for &(id, value) in &[(0, 1), (1, 2), (2, 3)] {
            cache.insert(node(id), value, 4);
        }

        // The third insert must push out node 0, the least recently used entry.
        assert!(!cache.contains(&node(0)));
        assert!(cache.contains(&node(1)));
        assert!(cache.contains(&node(2)));
        assert_eq!(cache.stats().evictions, 1);
    }

    #[test]
    fn test_cache_fifo_eviction() {
        let mut cache = TensorCache::<i32>::new(EvictionPolicy::FIFO).with_max_size(2);
        for &(id, value) in &[(0, 1), (1, 2), (2, 3)] {
            cache.insert(node(id), value, 4);
        }

        // Node 0 was inserted first, so it is the one evicted.
        assert!(!cache.contains(&node(0)));
        assert!(cache.contains(&node(1)));
        assert!(cache.contains(&node(2)));
    }

    #[test]
    fn test_cache_lfu_eviction() {
        let mut cache = TensorCache::<i32>::new(EvictionPolicy::LFU).with_max_size(2);
        cache.insert(node(0), 1, 4);
        cache.insert(node(1), 2, 4);

        // Bump node 0's access count so node 1 becomes the least frequently used.
        for _ in 0..2 {
            cache.get(&node(0));
        }

        cache.insert(node(2), 3, 4);

        assert!(cache.contains(&node(0)));
        assert!(!cache.contains(&node(1)));
        assert!(cache.contains(&node(2)));
    }

    #[test]
    fn test_cache_byte_limit() {
        let mut cache = TensorCache::<Vec<u8>>::new(EvictionPolicy::LRU).with_max_bytes(20);
        for id in 0..3 {
            cache.insert(node(id), vec![0; 8], 8);
        }

        // Three 8-byte entries cannot fit in 20 bytes, so at most two remain.
        assert!(cache.len() <= 2);
        assert!(cache.stats().total_bytes <= 20);
    }

    #[test]
    fn test_cache_stats() {
        let mut cache = TensorCache::<i32>::new(EvictionPolicy::LRU);
        cache.insert(node(0), 42, 4);
        cache.get(&node(0)); // hit
        cache.get(&node(1)); // miss

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hit_rate(), 0.5);
    }

    #[test]
    fn test_cache_remove() {
        let mut cache = TensorCache::<i32>::new(EvictionPolicy::LRU);
        cache.insert(node(0), 42, 4);
        assert_eq!(cache.len(), 1);

        assert_eq!(cache.remove(&node(0)), Some(42));
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_cache_clear() {
        let mut cache = TensorCache::<i32>::new(EvictionPolicy::LRU);
        cache.insert(node(0), 1, 4);
        cache.insert(node(1), 2, 4);
        assert_eq!(cache.len(), 2);

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.stats().total_bytes, 0);
    }

    #[test]
    fn test_memory_pool_basic() {
        let mut pool = MemoryPool::<Vec<u8>>::new();
        let make = || vec![0u8; 100];

        // First acquisition has nothing to reuse and must allocate.
        let first = pool.acquire(100, make);
        assert_eq!(first.len(), 100);
        assert_eq!(pool.stats().allocations, 1);

        pool.release(100, first);
        assert_eq!(pool.stats().releases, 1);

        // Second acquisition is served from the pool.
        let second = pool.acquire(100, make);
        assert_eq!(second.len(), 100);
        assert_eq!(pool.stats().reuses, 1);
    }

    #[test]
    fn test_memory_pool_size_classes() {
        let mut pool = MemoryPool::<Vec<u8>>::new();
        let small = pool.acquire(100, || vec![0u8; 100]);
        let large = pool.acquire(200, || vec![0u8; 200]);

        pool.release(100, small);
        pool.release(200, large);

        assert_eq!(pool.total_pooled(), 2);
    }

    #[test]
    fn test_memory_pool_max_size() {
        let mut pool = MemoryPool::<Vec<u8>>::new().with_max_pool_size(2);
        for _ in 0..3 {
            pool.release(100, vec![0u8; 100]);
        }

        // The third release exceeds the per-class limit and is dropped.
        assert_eq!(pool.total_pooled(), 2);
    }

    #[test]
    fn test_pool_stats() {
        let mut pool = MemoryPool::<Vec<u8>>::new();
        for _ in 0..2 {
            pool.acquire(100, || vec![0u8; 100]);
        }

        let stats = pool.stats();
        assert_eq!(stats.allocations, 2);
        assert_eq!(stats.reuse_rate(), 0.0);
    }
}