1use std::collections::{HashMap, VecDeque};
8use std::marker::PhantomData;
9use std::time::{Duration, Instant};
10
/// Cache key identifying one memoised evaluation.
///
/// Two keys are equal exactly when both the expression fingerprint and the
/// input hash are equal.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct MemoKey {
    /// Fingerprint of the expression being evaluated.
    pub expr_fingerprint: u64,
    /// Hash of the evaluation inputs (`0` when the key is built from the expression alone).
    pub input_hash: u64,
}
27
28impl MemoKey {
29 pub fn new(expr_fingerprint: u64, input_hash: u64) -> Self {
31 Self {
32 expr_fingerprint,
33 input_hash,
34 }
35 }
36
37 pub fn from_expr(expr: &tensorlogic_ir::TLExpr) -> Self {
39 let fp = tensorlogic_ir::expr_fingerprint(expr);
40 Self::new(fp, 0)
41 }
42
43 pub fn from_expr_and_hash(expr: &tensorlogic_ir::TLExpr, input_hash: u64) -> Self {
45 let fp = tensorlogic_ir::expr_fingerprint(expr);
46 Self::new(fp, input_hash)
47 }
48
49 pub fn hash_inputs(inputs: &[f64]) -> u64 {
55 let mut state: u64 = 14_695_981_039_346_656_037;
57 for &v in inputs {
58 let bits = v.to_bits();
60 for byte_idx in 0..8u64 {
61 let byte = (bits >> (byte_idx * 8)) & 0xFF;
62 state ^= byte;
63 state = state.wrapping_mul(1_099_511_628_211);
64 }
65 }
66 state
67 }
68}
69
/// Strategy used to choose a victim when the cache is full.
#[derive(Clone, Debug)]
pub enum MemoEvictionPolicy {
    /// Evict the least-recently-used entry; cache hits refresh an entry's position.
    Lru,
    /// Evict the earliest-inserted entry; cache hits do not affect ordering.
    Fifo,
    /// Prefer evicting an entry older than the given duration, falling back to
    /// the earliest-inserted one. The duration is also the expiry window when
    /// `MemoConfig::ttl` is `None`.
    Ttl(Duration),
}
89
90#[derive(Debug, Clone)]
96pub struct MemoConfig {
97 pub max_entries: usize,
99 pub ttl: Option<Duration>,
105 pub eviction: MemoEvictionPolicy,
107}
108
109impl Default for MemoConfig {
110 fn default() -> Self {
111 Self {
112 max_entries: 1024,
113 ttl: None,
114 eviction: MemoEvictionPolicy::Lru,
115 }
116 }
117}
118
/// Running counters describing how effective a memo cache has been.
#[derive(Clone, Debug, Default)]
pub struct MemoStats {
    /// Lookups that returned a live cached value.
    pub hits: u64,
    /// Lookups for keys that were not present.
    pub misses: u64,
    /// Entries removed to make room for new inserts.
    pub evictions: u64,
    /// Lookups that found an entry past its TTL.
    pub expired_on_access: u64,
    /// Number of entries currently stored.
    pub current_entries: usize,
}

impl MemoStats {
    /// Fraction of lookups that were hits, in `[0.0, 1.0]`; `0.0` before any lookup.
    pub fn hit_rate(&self) -> f64 {
        match self.total_lookups() {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }

    /// Every lookup performed: hits, misses and expired-on-access results.
    pub fn total_lookups(&self) -> u64 {
        [self.hits, self.misses, self.expired_on_access].iter().sum()
    }

    /// One-line human-readable report of all counters and the hit rate.
    pub fn summary(&self) -> String {
        let hit_percent = 100.0 * self.hit_rate();
        format!(
            "MemoCache: entries={} hits={} misses={} expired={} evictions={} hit_rate={:.1}%",
            self.current_entries,
            self.hits,
            self.misses,
            self.expired_on_access,
            self.evictions,
            hit_percent,
        )
    }
}
173
/// Outcome of a memo cache lookup.
#[derive(Clone, Debug)]
pub enum MemoLookupResult<V> {
    /// A live entry was found; carries a clone of the cached value.
    Hit(V),
    /// No entry exists for the key.
    Miss,
    /// An entry existed but had outlived its TTL.
    Expired,
}
188
/// Internal bookkeeping stored alongside each cached value.
#[derive(Clone, Debug)]
struct MemoEntry<V> {
    /// The cached value; cloned out on every hit.
    value: V,
    /// Timestamp from which TTL expiry is measured.
    inserted_at: Instant,
    /// Refreshed on every hit and on overwrite.
    last_accessed: Instant,
    /// Starts at 1 on insert; incremented on every hit and overwrite.
    access_count: u64,
}
201
202pub struct MemoCache<V: Clone> {
224 entries: HashMap<MemoKey, MemoEntry<V>>,
225 insertion_order: VecDeque<MemoKey>,
227 config: MemoConfig,
228 stats: MemoStats,
229}
230
231impl<V: Clone + std::fmt::Debug> MemoCache<V> {
232 pub fn new(config: MemoConfig) -> Self {
236 let max = config.max_entries;
237 Self {
238 entries: HashMap::with_capacity(max.min(1024)),
239 insertion_order: VecDeque::with_capacity(max.min(1024)),
240 config,
241 stats: MemoStats::default(),
242 }
243 }
244
245 pub fn with_default() -> Self {
247 Self::new(MemoConfig::default())
248 }
249
250 pub fn with_max_entries(max: usize) -> Self {
252 Self::new(MemoConfig {
253 max_entries: max,
254 ..MemoConfig::default()
255 })
256 }
257
258 pub fn get(&mut self, key: &MemoKey) -> MemoLookupResult<V> {
268 if !self.entries.contains_key(key) {
270 self.stats.misses += 1;
271 return MemoLookupResult::Miss;
272 }
273
274 if self.is_expired_by_key(key) {
276 self.entries.remove(key);
277 self.insertion_order.retain(|k| k != key);
278 self.stats.current_entries = self.entries.len();
279 self.stats.expired_on_access += 1;
280 return MemoLookupResult::Expired;
281 }
282
283 if let Some(entry) = self.entries.get_mut(key) {
285 entry.last_accessed = Instant::now();
286 entry.access_count += 1;
287 let value = entry.value.clone();
288 self.update_lru(key);
290 self.stats.hits += 1;
291 MemoLookupResult::Hit(value)
292 } else {
293 self.stats.misses += 1;
295 MemoLookupResult::Miss
296 }
297 }
298
299 pub fn insert(&mut self, key: MemoKey, value: V) {
304 if self.entries.contains_key(&key) {
306 if let Some(entry) = self.entries.get_mut(&key) {
307 entry.value = value;
308 entry.last_accessed = Instant::now();
309 entry.access_count += 1;
310 }
311 return;
312 }
313
314 if self.entries.len() >= self.config.max_entries {
316 self.evict_one();
317 }
318
319 let now = Instant::now();
320 let entry = MemoEntry {
321 value,
322 inserted_at: now,
323 last_accessed: now,
324 access_count: 1,
325 };
326
327 self.entries.insert(key.clone(), entry);
328 self.insertion_order.push_back(key);
329 self.stats.current_entries = self.entries.len();
330 }
331
332 pub fn invalidate(&mut self, key: &MemoKey) -> bool {
334 let removed = self.entries.remove(key).is_some();
335 if removed {
336 self.insertion_order.retain(|k| k != key);
337 self.stats.current_entries = self.entries.len();
338 }
339 removed
340 }
341
342 pub fn clear(&mut self) {
344 self.entries.clear();
345 self.insertion_order.clear();
346 self.stats.current_entries = 0;
347 }
348
349 pub fn stats(&self) -> &MemoStats {
351 &self.stats
352 }
353
354 pub fn len(&self) -> usize {
356 self.entries.len()
357 }
358
359 pub fn is_empty(&self) -> bool {
361 self.entries.is_empty()
362 }
363
364 fn is_expired_by_key(&self, key: &MemoKey) -> bool {
372 let ttl = match (&self.config.ttl, &self.config.eviction) {
373 (Some(d), _) => Some(*d),
374 (None, MemoEvictionPolicy::Ttl(d)) => Some(*d),
375 _ => None,
376 };
377 if let Some(duration) = ttl {
378 if let Some(entry) = self.entries.get(key) {
379 return entry.inserted_at.elapsed() > duration;
380 }
381 }
382 false
383 }
384
385 fn is_expired(&self, entry: &MemoEntry<V>) -> bool {
387 let ttl = match (&self.config.ttl, &self.config.eviction) {
388 (Some(d), _) => Some(*d),
389 (None, MemoEvictionPolicy::Ttl(d)) => Some(*d),
390 _ => None,
391 };
392 ttl.map(|d| entry.inserted_at.elapsed() > d)
393 .unwrap_or(false)
394 }
395
396 fn evict_one(&mut self) {
398 let key_to_remove = match &self.config.eviction {
399 MemoEvictionPolicy::Lru => self.find_lru_key(),
400 MemoEvictionPolicy::Fifo => self.find_fifo_key(),
401 MemoEvictionPolicy::Ttl(_) => {
402 self.find_expired_key().or_else(|| self.find_fifo_key())
404 }
405 };
406
407 if let Some(key) = key_to_remove {
408 self.entries.remove(&key);
409 self.insertion_order.retain(|k| k != &key);
410 self.stats.evictions += 1;
411 self.stats.current_entries = self.entries.len();
412 }
413 }
414
415 fn find_lru_key(&self) -> Option<MemoKey> {
422 self.insertion_order.front().cloned()
423 }
424
425 fn find_fifo_key(&self) -> Option<MemoKey> {
427 self.insertion_order.front().cloned()
428 }
429
430 fn find_expired_key(&self) -> Option<MemoKey> {
432 self.entries
433 .iter()
434 .find(|(_, e)| self.is_expired(e))
435 .map(|(k, _)| k.clone())
436 }
437
438 fn update_lru(&mut self, key: &MemoKey) {
440 if matches!(self.config.eviction, MemoEvictionPolicy::Lru) {
442 if let Some(pos) = self.insertion_order.iter().position(|k| k == key) {
443 self.insertion_order.remove(pos);
444 self.insertion_order.push_back(key.clone());
445 }
446 }
447 }
448}
449
/// Memo cache whose values are dynamically-shaped `f64` arrays
/// (`ndarray::ArrayD<f64>`).
///
/// NOTE(review): presumably intended for caching expression-evaluation results
/// keyed by `MemoKey::from_expr` / `from_expr_and_hash` — confirm against callers.
pub type ExprMemoCache = MemoCache<ndarray::ArrayD<f64>>;
456
457pub struct MemoCacheBuilder<V: Clone + std::fmt::Debug> {
463 config: MemoConfig,
464 _phantom: PhantomData<V>,
465}
466
467impl<V: Clone + std::fmt::Debug> MemoCacheBuilder<V> {
468 pub fn new() -> Self {
470 Self {
471 config: MemoConfig::default(),
472 _phantom: PhantomData,
473 }
474 }
475
476 pub fn max_entries(mut self, max: usize) -> Self {
478 self.config.max_entries = max;
479 self
480 }
481
482 pub fn ttl(mut self, duration: Duration) -> Self {
484 self.config.ttl = Some(duration);
485 self
486 }
487
488 pub fn eviction(mut self, policy: MemoEvictionPolicy) -> Self {
490 self.config.eviction = policy;
491 self
492 }
493
494 pub fn build(self) -> MemoCache<V> {
496 MemoCache::new(self.config)
497 }
498}
499
500impl<V: Clone + std::fmt::Debug> Default for MemoCacheBuilder<V> {
501 fn default() -> Self {
502 Self::new()
503 }
504}
505
#[cfg(test)]
mod tests {
    // Unit tests for memo keys, statistics, eviction policies and the builder.
    use super::*;
    use std::thread;
    use tensorlogic_ir::{TLExpr, Term};

    /// Simple predicate expression `foo(x)`.
    fn make_expr_a() -> TLExpr {
        TLExpr::pred("foo", vec![Term::var("x")])
    }

    /// A predicate expression that differs from `make_expr_a` in name and variable.
    fn make_expr_b() -> TLExpr {
        TLExpr::pred("bar", vec![Term::var("y")])
    }

    // ---- MemoKey -------------------------------------------------------------

    #[test]
    fn test_memo_key_equality() {
        let k1 = MemoKey::new(42, 99);
        let k2 = MemoKey::new(42, 99);
        let k3 = MemoKey::new(42, 100);
        assert_eq!(k1, k2);
        assert_ne!(k1, k3);
    }

    #[test]
    fn test_memo_key_from_expr() {
        let expr = make_expr_a();
        let key = MemoKey::from_expr(&expr);
        assert_eq!(key.input_hash, 0);
        // Fingerprinting must be deterministic for the same expression.
        let key2 = MemoKey::from_expr(&expr);
        assert_eq!(key.expr_fingerprint, key2.expr_fingerprint);
    }

    #[test]
    fn test_memo_key_hash_inputs_consistent() {
        let inputs = [1.0_f64, 2.0, 3.0];
        let h1 = MemoKey::hash_inputs(&inputs);
        let h2 = MemoKey::hash_inputs(&inputs);
        assert_eq!(h1, h2);
    }

    #[test]
    fn test_memo_key_hash_inputs_different() {
        let h1 = MemoKey::hash_inputs(&[1.0, 2.0, 3.0]);
        let h2 = MemoKey::hash_inputs(&[1.0, 2.0, 4.0]);
        assert_ne!(h1, h2);
    }

    #[test]
    fn test_memo_key_hash_inputs_empty_is_offset_basis() {
        // With no input bytes, FNV-1a returns its 64-bit offset basis unchanged.
        assert_eq!(MemoKey::hash_inputs(&[]), 14_695_981_039_346_656_037);
    }

    #[test]
    fn test_memo_key_from_expr_different_exprs() {
        let ka = MemoKey::from_expr(&make_expr_a());
        let kb = MemoKey::from_expr(&make_expr_b());
        assert_ne!(ka.expr_fingerprint, kb.expr_fingerprint);
    }

    #[test]
    fn test_memo_key_from_expr_and_hash() {
        let expr = make_expr_a();
        let h = MemoKey::hash_inputs(&[1.0, 2.0]);
        let key = MemoKey::from_expr_and_hash(&expr, h);
        assert_eq!(key.input_hash, h);
        assert_ne!(key.input_hash, 0);
    }

    // ---- Basic cache behaviour -----------------------------------------------

    #[test]
    fn test_memo_cache_miss_on_empty() {
        let mut cache: MemoCache<i32> = MemoCache::with_default();
        let key = MemoKey::new(1, 0);
        assert!(matches!(cache.get(&key), MemoLookupResult::Miss));
        assert_eq!(cache.stats().misses, 1);
    }

    #[test]
    fn test_memo_cache_hit_after_insert() {
        let mut cache: MemoCache<i32> = MemoCache::with_default();
        let key = MemoKey::new(7, 0);
        cache.insert(key.clone(), 42);
        assert!(matches!(cache.get(&key), MemoLookupResult::Hit(42)));
        assert_eq!(cache.stats().hits, 1);
    }

    #[test]
    fn test_memo_cache_invalidate_key() {
        let mut cache: MemoCache<i32> = MemoCache::with_default();
        let key = MemoKey::new(5, 0);
        cache.insert(key.clone(), 77);
        assert!(cache.invalidate(&key));
        // A second invalidation finds nothing to remove.
        assert!(!cache.invalidate(&key));
        assert!(matches!(cache.get(&key), MemoLookupResult::Miss));
    }

    #[test]
    fn test_memo_cache_clear() {
        let mut cache: MemoCache<i32> = MemoCache::with_default();
        cache.insert(MemoKey::new(1, 0), 1);
        cache.insert(MemoKey::new(2, 0), 2);
        assert_eq!(cache.len(), 2);
        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.stats().current_entries, 0);
    }

    #[test]
    fn test_memo_cache_len() {
        let mut cache: MemoCache<i32> = MemoCache::with_default();
        assert_eq!(cache.len(), 0);
        cache.insert(MemoKey::new(1, 0), 10);
        assert_eq!(cache.len(), 1);
        cache.insert(MemoKey::new(2, 0), 20);
        assert_eq!(cache.len(), 2);
    }

    // ---- Eviction policies ---------------------------------------------------

    #[test]
    fn test_memo_cache_lru_evicts_oldest_access() {
        let mut cache: MemoCache<i32> = MemoCache::new(MemoConfig {
            max_entries: 2,
            ttl: None,
            eviction: MemoEvictionPolicy::Lru,
        });

        let k1 = MemoKey::new(1, 0);
        let k2 = MemoKey::new(2, 0);
        let k3 = MemoKey::new(3, 0);

        cache.insert(k1.clone(), 1);
        cache.insert(k2.clone(), 2);
        // Touch k1 so that k2 becomes the least recently used entry.
        cache.get(&k1);
        // The cache is full, so this evicts k2.
        cache.insert(k3.clone(), 3);

        assert!(matches!(cache.get(&k1), MemoLookupResult::Hit(1)));
        assert!(matches!(cache.get(&k2), MemoLookupResult::Miss));
        assert!(matches!(cache.get(&k3), MemoLookupResult::Hit(3)));
        assert!(cache.stats().evictions >= 1);
    }

    #[test]
    fn test_memo_cache_fifo_evicts_first_inserted() {
        let mut cache: MemoCache<i32> = MemoCache::new(MemoConfig {
            max_entries: 2,
            ttl: None,
            eviction: MemoEvictionPolicy::Fifo,
        });

        let k1 = MemoKey::new(1, 0);
        let k2 = MemoKey::new(2, 0);
        let k3 = MemoKey::new(3, 0);

        cache.insert(k1.clone(), 10);
        cache.insert(k2.clone(), 20);
        // Under FIFO a hit does not protect k1 from eviction.
        cache.get(&k1);
        cache.insert(k3.clone(), 30);

        assert!(matches!(cache.get(&k1), MemoLookupResult::Miss));
        assert!(matches!(cache.get(&k2), MemoLookupResult::Hit(20)));
        assert!(matches!(cache.get(&k3), MemoLookupResult::Hit(30)));
    }

    #[test]
    fn test_memo_cache_ttl_expires_entry() {
        let ttl = Duration::from_millis(10);
        let mut cache: MemoCache<i32> = MemoCache::new(MemoConfig {
            max_entries: 16,
            ttl: Some(ttl),
            eviction: MemoEvictionPolicy::Ttl(ttl),
        });

        let key = MemoKey::new(99, 0);
        cache.insert(key.clone(), 55);

        // Still within the TTL window.
        assert!(matches!(cache.get(&key), MemoLookupResult::Hit(55)));

        // Sleep for twice the TTL so the entry is definitely stale.
        thread::sleep(Duration::from_millis(20));

        assert!(matches!(cache.get(&key), MemoLookupResult::Expired));
        assert_eq!(cache.stats().expired_on_access, 1);
        // Expired lookups count towards the lookup total (one hit + one expired).
        assert_eq!(cache.stats().total_lookups(), 2);
    }

    // ---- Statistics ----------------------------------------------------------

    #[test]
    fn test_memo_cache_hit_rate_zero_initially() {
        let cache: MemoCache<i32> = MemoCache::with_default();
        assert_eq!(cache.stats().hit_rate(), 0.0);
    }

    #[test]
    fn test_memo_cache_hit_rate_after_hit() {
        let mut cache: MemoCache<i32> = MemoCache::with_default();
        let key = MemoKey::new(1, 0);
        cache.insert(key.clone(), 10);
        cache.get(&key); // hit
        cache.get(&MemoKey::new(2, 0)); // miss
        let rate = cache.stats().hit_rate();
        assert!((rate - 0.5).abs() < 1e-9, "expected 0.5, got {rate}");
    }

    #[test]
    fn test_memo_stats_total_lookups() {
        let mut cache: MemoCache<i32> = MemoCache::with_default();
        let key = MemoKey::new(1, 0);
        cache.insert(key.clone(), 1);
        cache.get(&key); // hit
        cache.get(&MemoKey::new(99, 0)); // miss
        assert_eq!(cache.stats().total_lookups(), 2);
    }

    #[test]
    fn test_memo_stats_summary_nonempty() {
        let mut cache: MemoCache<i32> = MemoCache::with_default();
        cache.insert(MemoKey::new(1, 0), 1);
        cache.get(&MemoKey::new(1, 0));
        let summary = cache.stats().summary();
        assert!(summary.contains("MemoCache"));
        assert!(summary.contains("hits=1"));
    }

    #[test]
    fn test_memo_lookup_result_variants() {
        let hit: MemoLookupResult<i32> = MemoLookupResult::Hit(42);
        let miss: MemoLookupResult<i32> = MemoLookupResult::Miss;
        let expired: MemoLookupResult<i32> = MemoLookupResult::Expired;

        assert!(matches!(hit, MemoLookupResult::Hit(42)));
        assert!(matches!(miss, MemoLookupResult::Miss));
        assert!(matches!(expired, MemoLookupResult::Expired));
    }

    // ---- Builder and aliases -------------------------------------------------

    #[test]
    fn test_memo_cache_builder_default() {
        let cache: MemoCache<i32> = MemoCacheBuilder::new().build();
        assert!(cache.is_empty());
        // An untouched builder must produce the default configuration.
        assert_eq!(cache.config.max_entries, 1024);
        assert_eq!(cache.config.ttl, None);
        assert!(matches!(cache.config.eviction, MemoEvictionPolicy::Lru));
    }

    #[test]
    fn test_memo_cache_builder_custom_config() {
        let cache: MemoCache<i32> = MemoCacheBuilder::new()
            .max_entries(8)
            .ttl(Duration::from_secs(60))
            .eviction(MemoEvictionPolicy::Fifo)
            .build();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        // The builder must forward every setting into the cache's configuration.
        assert_eq!(cache.config.max_entries, 8);
        assert_eq!(cache.config.ttl, Some(Duration::from_secs(60)));
        assert!(matches!(cache.config.eviction, MemoEvictionPolicy::Fifo));
    }

    #[test]
    fn test_expr_memo_cache_type_alias() {
        use ndarray::ArrayD;
        let mut cache: ExprMemoCache = MemoCache::with_default();
        let key = MemoKey::from_expr(&make_expr_a());
        let arr = ArrayD::<f64>::zeros(ndarray::IxDyn(&[2, 3]));
        cache.insert(key.clone(), arr.clone());
        assert!(matches!(cache.get(&key), MemoLookupResult::Hit(_)));
    }
}