use std::collections::{HashMap, VecDeque};
use std::hash::{Hash, Hasher};

/// Key identifying a cached value: an optional graph id, the node id, and a
/// hash of the node's inputs.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CacheKey {
    pub graph_id: Option<String>,
    pub node_id: usize,
    pub input_hash: u64,
}

impl CacheKey {
    /// Creates a key for `node_id` with no graph id and an input hash of 0.
    pub fn new(node_id: usize) -> Self {
        CacheKey {
            graph_id: None,
            node_id,
            input_hash: 0,
        }
    }

    /// Attaches a graph identifier to the key.
    pub fn with_graph(mut self, graph_id: impl Into<String>) -> Self {
        self.graph_id = Some(graph_id.into());
        self
    }

    /// Hashes the given inputs and stores the result in `input_hash`.
    pub fn with_inputs<T: Hash>(mut self, inputs: &[T]) -> Self {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        for input in inputs {
            input.hash(&mut hasher);
        }
        self.input_hash = hasher.finish();
        self
    }
}

/// Strategy used to choose which entry to evict when the cache is full.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EvictionPolicy {
    /// Evict the least recently used entry.
    LRU,
    /// Evict the oldest inserted entry.
    FIFO,
    /// Evict the least frequently used entry.
    LFU,
    /// Never evict.
    None,
}

/// Counters describing cache behavior.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    pub hits: usize,
    pub misses: usize,
    pub evictions: usize,
    pub current_size: usize,
    pub peak_size: usize,
    pub total_bytes: usize,
}

impl CacheStats {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn hit_rate(&self) -> f64 {
        let total = self.hits + self.misses;
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }

    pub fn summary(&self) -> String {
        format!(
            "Cache Stats:\n\
             - Hits: {} ({:.1}%)\n\
             - Misses: {}\n\
             - Evictions: {}\n\
             - Current size: {} entries\n\
             - Peak size: {} entries\n\
             - Total bytes: {} ({:.2} MB)",
            self.hits,
            self.hit_rate() * 100.0,
            self.misses,
            self.evictions,
            self.current_size,
            self.peak_size,
            self.total_bytes,
            self.total_bytes as f64 / (1024.0 * 1024.0)
        )
    }
}

/// A single cached value together with bookkeeping metadata.
#[derive(Debug, Clone)]
struct CacheEntry<T> {
    value: T,
    size_bytes: usize,
    access_count: usize,
    last_access: usize,
}

/// A keyed cache for tensors (or any cloneable value) with a configurable
/// eviction policy and optional entry-count and byte-size limits.
pub struct TensorCache<T> {
    cache: HashMap<CacheKey, CacheEntry<T>>,
    eviction_policy: EvictionPolicy,
    max_size: Option<usize>,
    max_bytes: Option<usize>,
    stats: CacheStats,
    access_counter: usize,
    access_order: VecDeque<CacheKey>,
}

impl<T: Clone> TensorCache<T> {
    pub fn new(eviction_policy: EvictionPolicy) -> Self {
        TensorCache {
            cache: HashMap::new(),
            eviction_policy,
            max_size: None,
            max_bytes: None,
            stats: CacheStats::new(),
            access_counter: 0,
            access_order: VecDeque::new(),
        }
    }

    /// Limits the cache to at most `max_entries` entries.
    pub fn with_max_size(mut self, max_entries: usize) -> Self {
        self.max_size = Some(max_entries);
        self
    }

    /// Limits the cache to at most `max_bytes` of cached data.
    pub fn with_max_bytes(mut self, max_bytes: usize) -> Self {
        self.max_bytes = Some(max_bytes);
        self
    }

    /// Inserts `value` under `key`, evicting entries first if a size or byte
    /// limit would otherwise be exceeded.
    pub fn insert(&mut self, key: CacheKey, value: T, size_bytes: usize) {
        // Evict until the new entry fits; stop once the cache is empty so an
        // entry larger than the byte limit cannot cause an infinite loop.
        while !self.cache.is_empty() && self.should_evict(size_bytes) {
            self.evict_one();
        }

        if self.cache.contains_key(&key) {
            // Update the existing entry in place and adjust the byte count.
            if let Some(entry) = self.cache.get_mut(&key) {
                self.stats.total_bytes -= entry.size_bytes;
                entry.value = value;
                entry.size_bytes = size_bytes;
                entry.access_count += 1;
                entry.last_access = self.access_counter;
                self.stats.total_bytes += size_bytes;
            }
        } else {
            let entry = CacheEntry {
                value,
                size_bytes,
                access_count: 1,
                last_access: self.access_counter,
            };

            self.cache.insert(key.clone(), entry);
            self.stats.current_size += 1;
            self.stats.peak_size = self.stats.peak_size.max(self.stats.current_size);
            self.stats.total_bytes += size_bytes;

            // Record insertion order for FIFO eviction.
            self.access_order.push_back(key);
        }

        self.access_counter += 1;
    }

    /// Looks up `key`, returning a clone of the cached value on a hit.
    pub fn get(&mut self, key: &CacheKey) -> Option<T> {
        if let Some(entry) = self.cache.get_mut(key) {
            self.stats.hits += 1;
            entry.access_count += 1;
            entry.last_access = self.access_counter;
            self.access_counter += 1;

            // For LRU, move the key to the back of the access order.
            if self.eviction_policy == EvictionPolicy::LRU {
                self.access_order.retain(|k| k != key);
                self.access_order.push_back(key.clone());
            }

            Some(entry.value.clone())
        } else {
            self.stats.misses += 1;
            None
        }
    }

    pub fn contains(&self, key: &CacheKey) -> bool {
        self.cache.contains_key(key)
    }

    pub fn remove(&mut self, key: &CacheKey) -> Option<T> {
        if let Some(entry) = self.cache.remove(key) {
            self.stats.current_size -= 1;
            self.stats.total_bytes -= entry.size_bytes;
            self.access_order.retain(|k| k != key);
            Some(entry.value)
        } else {
            None
        }
    }

    pub fn clear(&mut self) {
        self.cache.clear();
        self.access_order.clear();
        self.stats.current_size = 0;
        self.stats.total_bytes = 0;
    }

    pub fn stats(&self) -> &CacheStats {
        &self.stats
    }

    /// Resets the hit/miss/eviction counters; size counters are kept.
    pub fn reset_stats(&mut self) {
        self.stats.hits = 0;
        self.stats.misses = 0;
        self.stats.evictions = 0;
    }

    fn should_evict(&self, new_size_bytes: usize) -> bool {
        if self.eviction_policy == EvictionPolicy::None {
            return false;
        }

        let size_exceeded = self
            .max_size
            .map(|max| self.stats.current_size >= max)
            .unwrap_or(false);

        let bytes_exceeded = self
            .max_bytes
            .map(|max| self.stats.total_bytes + new_size_bytes > max)
            .unwrap_or(false);

        size_exceeded || bytes_exceeded
    }

    fn evict_one(&mut self) {
        let key_to_evict = match self.eviction_policy {
            EvictionPolicy::LRU => self.find_lru_key(),
            EvictionPolicy::FIFO => self.find_fifo_key(),
            EvictionPolicy::LFU => self.find_lfu_key(),
            EvictionPolicy::None => return,
        };

        if let Some(key) = key_to_evict {
            self.remove(&key);
            self.stats.evictions += 1;
        }
    }

    fn find_lru_key(&self) -> Option<CacheKey> {
        self.cache
            .iter()
            .min_by_key(|(_, entry)| entry.last_access)
            .map(|(key, _)| key.clone())
    }

    fn find_fifo_key(&self) -> Option<CacheKey> {
        self.access_order.front().cloned()
    }

    fn find_lfu_key(&self) -> Option<CacheKey> {
        self.cache
            .iter()
            .min_by_key(|(_, entry)| entry.access_count)
            .map(|(key, _)| key.clone())
    }

    pub fn len(&self) -> usize {
        self.stats.current_size
    }

    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }
}

impl<T: Clone> Default for TensorCache<T> {
    fn default() -> Self {
        Self::new(EvictionPolicy::LRU)
    }
}

/// A simple size-class based pool for reusing allocations such as tensor
/// buffers.
pub struct MemoryPool<T> {
    pools: HashMap<usize, Vec<T>>,
    stats: PoolStats,
    max_pool_size: Option<usize>,
}

/// Counters describing pool usage.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    pub allocations: usize,
    pub reuses: usize,
    pub releases: usize,
    pub peak_allocations: usize,
}

impl PoolStats {
    pub fn reuse_rate(&self) -> f64 {
        let total = self.allocations + self.reuses;
        if total == 0 {
            0.0
        } else {
            self.reuses as f64 / total as f64
        }
    }

    pub fn summary(&self) -> String {
        format!(
            "Memory Pool Stats:\n\
             - Allocations: {}\n\
             - Reuses: {} ({:.1}%)\n\
             - Releases: {}\n\
             - Peak allocations: {}",
            self.allocations,
            self.reuses,
            self.reuse_rate() * 100.0,
            self.releases,
            self.peak_allocations
        )
    }
}

impl<T> MemoryPool<T> {
    pub fn new() -> Self {
        MemoryPool {
            pools: HashMap::new(),
            stats: PoolStats::default(),
            max_pool_size: Some(100),
        }
    }

    /// Sets the maximum number of buffers retained per size class.
    pub fn with_max_pool_size(mut self, max_size: usize) -> Self {
        self.max_pool_size = Some(max_size);
        self
    }

    /// Returns a pooled buffer for `size_class` if one is available,
    /// otherwise calls `allocator` to create a new one.
    pub fn acquire<F>(&mut self, size_class: usize, allocator: F) -> T
    where
        F: FnOnce() -> T,
    {
        if let Some(pool) = self.pools.get_mut(&size_class) {
            if let Some(tensor) = pool.pop() {
                self.stats.reuses += 1;
                return tensor;
            }
        }

        self.stats.allocations += 1;
        // Track the peak number of outstanding allocations; saturate so that
        // releasing more buffers than were allocated here cannot underflow.
        self.stats.peak_allocations = self
            .stats
            .peak_allocations
            .max(self.stats.allocations.saturating_sub(self.stats.releases));

        allocator()
    }

    /// Returns a buffer to the pool for `size_class`. If that pool is already
    /// at its maximum size, the buffer is simply dropped.
    pub fn release(&mut self, size_class: usize, tensor: T) {
        let pool = self.pools.entry(size_class).or_default();

        if let Some(max_size) = self.max_pool_size {
            if pool.len() >= max_size {
                self.stats.releases += 1;
                return;
            }
        }

        pool.push(tensor);
        self.stats.releases += 1;
    }

    pub fn clear(&mut self) {
        self.pools.clear();
        self.stats = PoolStats::default();
    }

    pub fn stats(&self) -> &PoolStats {
        &self.stats
    }

    /// Total number of buffers currently held across all size classes.
    pub fn total_pooled(&self) -> usize {
        self.pools.values().map(|v| v.len()).sum()
    }
}

impl<T> Default for MemoryPool<T> {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_key_creation() {
        let key1 = CacheKey::new(0);
        assert_eq!(key1.node_id, 0);
        assert_eq!(key1.input_hash, 0);

        let key2 = CacheKey::new(1).with_graph("graph1");
        assert_eq!(key2.graph_id, Some("graph1".to_string()));

        let inputs = vec![1, 2, 3];
        let key3 = CacheKey::new(2).with_inputs(&inputs);
        assert!(key3.input_hash != 0);
    }

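    // Usage sketch: `with_inputs` hashes through `DefaultHasher::new()`, so
    // identical inputs should produce identical keys within a single run.
    // The test name and values here are illustrative, not from the original
    // suite.
    #[test]
    fn test_cache_key_input_hash_is_consistent() {
        let inputs = vec![1, 2, 3];
        let key_a = CacheKey::new(7).with_inputs(&inputs);
        let key_b = CacheKey::new(7).with_inputs(&inputs);

        // Same node id and same inputs should hash to the same key.
        assert_eq!(key_a, key_b);
        assert_eq!(key_a.input_hash, key_b.input_hash);
    }
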
    #[test]
    fn test_cache_basic_operations() {
        let mut cache: TensorCache<i32> = TensorCache::new(EvictionPolicy::LRU);

        let key = CacheKey::new(0);
        cache.insert(key.clone(), 42, 4);

        assert_eq!(cache.get(&key), Some(42));
        assert_eq!(cache.stats().hits, 1);
        assert_eq!(cache.stats().misses, 0);

        let missing_key = CacheKey::new(1);
        assert_eq!(cache.get(&missing_key), None);
        assert_eq!(cache.stats().misses, 1);
    }

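    // Sketch of re-inserting under an existing key: the entry count should
    // stay at one while the byte accounting follows the latest size.
    #[test]
    fn test_cache_insert_existing_key_updates_bytes() {
        let mut cache: TensorCache<i32> = TensorCache::new(EvictionPolicy::LRU);

        let key = CacheKey::new(0);
        cache.insert(key.clone(), 1, 4);
        cache.insert(key.clone(), 2, 16);

        assert_eq!(cache.len(), 1);
        assert_eq!(cache.stats().total_bytes, 16);
        assert_eq!(cache.get(&key), Some(2));
    }
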
    #[test]
    fn test_cache_lru_eviction() {
        let mut cache: TensorCache<i32> = TensorCache::new(EvictionPolicy::LRU).with_max_size(2);

        cache.insert(CacheKey::new(0), 1, 4);
        cache.insert(CacheKey::new(1), 2, 4);
        cache.insert(CacheKey::new(2), 3, 4);

        assert!(!cache.contains(&CacheKey::new(0)));
        assert!(cache.contains(&CacheKey::new(1)));
        assert!(cache.contains(&CacheKey::new(2)));
        assert_eq!(cache.stats().evictions, 1);
    }

    #[test]
    fn test_cache_fifo_eviction() {
        let mut cache: TensorCache<i32> = TensorCache::new(EvictionPolicy::FIFO).with_max_size(2);

        cache.insert(CacheKey::new(0), 1, 4);
        cache.insert(CacheKey::new(1), 2, 4);
        cache.insert(CacheKey::new(2), 3, 4);

        assert!(!cache.contains(&CacheKey::new(0)));
        assert!(cache.contains(&CacheKey::new(1)));
        assert!(cache.contains(&CacheKey::new(2)));
    }

    #[test]
    fn test_cache_lfu_eviction() {
        let mut cache: TensorCache<i32> = TensorCache::new(EvictionPolicy::LFU).with_max_size(2);

        cache.insert(CacheKey::new(0), 1, 4);
        cache.insert(CacheKey::new(1), 2, 4);

        cache.get(&CacheKey::new(0));
        cache.get(&CacheKey::new(0));

        cache.insert(CacheKey::new(2), 3, 4);

        assert!(cache.contains(&CacheKey::new(0)));
        assert!(!cache.contains(&CacheKey::new(1)));
        assert!(cache.contains(&CacheKey::new(2)));
    }

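    // Sketch of EvictionPolicy::None: with eviction disabled the cache is
    // allowed to grow past its configured limits.
    #[test]
    fn test_cache_none_policy_never_evicts() {
        let mut cache: TensorCache<i32> = TensorCache::new(EvictionPolicy::None).with_max_size(2);

        cache.insert(CacheKey::new(0), 1, 4);
        cache.insert(CacheKey::new(1), 2, 4);
        cache.insert(CacheKey::new(2), 3, 4);

        assert_eq!(cache.len(), 3);
        assert_eq!(cache.stats().evictions, 0);
    }
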
    #[test]
    fn test_cache_byte_limit() {
        let mut cache: TensorCache<Vec<u8>> =
            TensorCache::new(EvictionPolicy::LRU).with_max_bytes(20);

        cache.insert(CacheKey::new(0), vec![0; 8], 8);
        cache.insert(CacheKey::new(1), vec![0; 8], 8);
        cache.insert(CacheKey::new(2), vec![0; 8], 8);

        assert!(cache.len() <= 2);
        assert!(cache.stats().total_bytes <= 20);
    }

    #[test]
    fn test_cache_stats() {
        let mut cache: TensorCache<i32> = TensorCache::new(EvictionPolicy::LRU);

        cache.insert(CacheKey::new(0), 42, 4);
        cache.get(&CacheKey::new(0));
        cache.get(&CacheKey::new(1));

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hit_rate(), 0.5);
    }

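    // Sketch of the stats helpers: an untouched cache should report a 0.0 hit
    // rate rather than dividing by zero, and summary() is plain text.
    #[test]
    fn test_cache_stats_defaults_and_summary() {
        let cache: TensorCache<i32> = TensorCache::new(EvictionPolicy::LRU);

        assert_eq!(cache.stats().hit_rate(), 0.0);
        assert!(cache.stats().summary().contains("Hits"));
    }
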
    #[test]
    fn test_cache_remove() {
        let mut cache: TensorCache<i32> = TensorCache::new(EvictionPolicy::LRU);

        cache.insert(CacheKey::new(0), 42, 4);
        assert_eq!(cache.len(), 1);

        let removed = cache.remove(&CacheKey::new(0));
        assert_eq!(removed, Some(42));
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_cache_clear() {
        let mut cache: TensorCache<i32> = TensorCache::new(EvictionPolicy::LRU);

        cache.insert(CacheKey::new(0), 1, 4);
        cache.insert(CacheKey::new(1), 2, 4);
        assert_eq!(cache.len(), 2);

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.stats().total_bytes, 0);
    }

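    // Sketch of reset_stats(): it should clear the hit/miss counters while
    // leaving the cached entries themselves in place.
    #[test]
    fn test_cache_reset_stats_keeps_entries() {
        let mut cache: TensorCache<i32> = TensorCache::new(EvictionPolicy::LRU);

        cache.insert(CacheKey::new(0), 42, 4);
        cache.get(&CacheKey::new(0));
        assert_eq!(cache.stats().hits, 1);

        cache.reset_stats();
        assert_eq!(cache.stats().hits, 0);
        assert_eq!(cache.stats().misses, 0);
        assert_eq!(cache.len(), 1);
    }
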
    #[test]
    fn test_memory_pool_basic() {
        let mut pool: MemoryPool<Vec<u8>> = MemoryPool::new();

        let vec1 = pool.acquire(100, || vec![0u8; 100]);
        assert_eq!(vec1.len(), 100);
        assert_eq!(pool.stats().allocations, 1);

        pool.release(100, vec1);
        assert_eq!(pool.stats().releases, 1);

        let vec2 = pool.acquire(100, || vec![0u8; 100]);
        assert_eq!(vec2.len(), 100);
        assert_eq!(pool.stats().reuses, 1);
    }

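    // Sketch of the pool's reuse accounting: one fresh allocation followed by
    // one reuse should give a 50% reuse rate.
    #[test]
    fn test_memory_pool_reuse_rate() {
        let mut pool: MemoryPool<Vec<u8>> = MemoryPool::new();

        let buf = pool.acquire(64, || vec![0u8; 64]);
        pool.release(64, buf);
        let _buf = pool.acquire(64, || vec![0u8; 64]);

        assert_eq!(pool.stats().allocations, 1);
        assert_eq!(pool.stats().reuses, 1);
        assert!((pool.stats().reuse_rate() - 0.5).abs() < 1e-9);
    }
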
    #[test]
    fn test_memory_pool_size_classes() {
        let mut pool: MemoryPool<Vec<u8>> = MemoryPool::new();

        let vec1 = pool.acquire(100, || vec![0u8; 100]);
        let vec2 = pool.acquire(200, || vec![0u8; 200]);

        pool.release(100, vec1);
        pool.release(200, vec2);

        assert_eq!(pool.total_pooled(), 2);
    }

    #[test]
    fn test_memory_pool_max_size() {
        let mut pool: MemoryPool<Vec<u8>> = MemoryPool::new().with_max_pool_size(2);

        pool.release(100, vec![0u8; 100]);
        pool.release(100, vec![0u8; 100]);
        pool.release(100, vec![0u8; 100]);

        assert_eq!(pool.total_pooled(), 2);
    }

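    // Sketch of clear(): pooled buffers are dropped and the counters start
    // over from zero.
    #[test]
    fn test_memory_pool_clear() {
        let mut pool: MemoryPool<Vec<u8>> = MemoryPool::new();

        pool.release(100, vec![0u8; 100]);
        assert_eq!(pool.total_pooled(), 1);

        pool.clear();
        assert_eq!(pool.total_pooled(), 0);
        assert_eq!(pool.stats().releases, 0);
    }
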
    #[test]
    fn test_pool_stats() {
        let mut pool: MemoryPool<Vec<u8>> = MemoryPool::new();

        pool.acquire(100, || vec![0u8; 100]);
        pool.acquire(100, || vec![0u8; 100]);

        let stats = pool.stats();
        assert_eq!(stats.allocations, 2);
        assert!(stats.reuse_rate() == 0.0);
    }
}