Skip to main content

tensorlogic_compiler/
cache.rs

1//! Compilation cache for TensorLogic expressions.
2//!
3//! Provides two complementary caching mechanisms:
4//!
5//! 1. **[`CompilationCache`]** — A thread-safe, key-based cache that stores compiled
6//!    `EinsumGraph` instances keyed by a composite hash of expression structure,
7//!    compilation configuration, and domain information. Designed for concurrent use.
8//!
9//! 2. **[`LruCompilationCache`]** — A single-threaded LRU cache keyed by
10//!    [`ExprFingerprint`] (a structural content-address of an expression). Evicts the
11//!    least-recently-used entry when capacity is exceeded. Designed for use inside a
12//!    [`CachingCompiler`] wrapper.
13//!
14//! # Choosing the right cache
15//!
16//! | Scenario | Recommended type |
17//! |----------|-----------------|
18//! | Single-threaded compilation loop | [`LruCompilationCache`] / [`CachingCompiler`] |
19//! | Multi-threaded compilation (shared) | [`CompilationCache`] |
20//! | Batch compilation of related exprs | [`CachingCompiler::compile_batch`] |
21//!
22//! # Example — LRU cache via `CachingCompiler`
23//!
24//! ```rust
25//! use tensorlogic_compiler::cache::{CachingCompiler, CacheStats};
26//! use tensorlogic_compiler::compile_to_einsum;
27//! use tensorlogic_ir::{TLExpr, Term};
28//!
29//! let mut compiler = CachingCompiler::new(64, |expr| {
30//!     compile_to_einsum(expr).map_err(|e| e.to_string())
31//! });
32//!
33//! let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
34//!
35//! let _g1 = compiler.compile(&expr).expect("first compile");
36//! let _g2 = compiler.compile(&expr).expect("second compile (cache hit)");
37//!
38//! assert_eq!(compiler.cache_stats().hits, 1);
39//! assert_eq!(compiler.cache_stats().misses, 1);
40//! ```
41
42use std::collections::HashMap;
43use std::hash::{Hash, Hasher};
44use std::sync::{Arc, Mutex};
45
46use anyhow::Result;
47use tensorlogic_ir::{EinsumGraph, TLExpr};
48
49use crate::config::CompilationConfig;
50use crate::CompilerContext;
51
52// ──────────────────────────────────────────────────────────────────────────────
53// ExprFingerprint
54// ──────────────────────────────────────────────────────────────────────────────
55
/// A compact fingerprint of a `TLExpr` structure (not values).
///
/// Two fingerprints compare equal exactly when the strings they were built
/// from are equal. Used as a content-addressable cache key in
/// [`LruCompilationCache`] and [`CachingCompiler`].
///
/// [`CachingCompiler`] builds fingerprints from the `Debug` rendering of an
/// expression, so structurally identical expressions share a fingerprint.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ExprFingerprint {
    /// Serialised structural representation, stored verbatim.
    pub(crate) data: String,
}

impl ExprFingerprint {
    /// Compute a fingerprint from an arbitrary string representation.
    ///
    /// The string is copied as-is; no hashing or normalisation is applied.
    /// In practice this is called with `format!("{:?}", expr)` so that the
    /// fingerprint captures the full recursive structure of the expression.
    pub fn compute(expr_repr: &str) -> Self {
        ExprFingerprint {
            data: expr_repr.to_string(),
        }
    }
}

impl std::fmt::Display for ExprFingerprint {
    /// Writes `fp:` followed by a preview of at most 32 bytes of the stored
    /// representation.
    ///
    /// The preview never ends in the middle of a multi-byte UTF-8 character:
    /// when byte 32 is not a character boundary the preview is shortened to
    /// the nearest boundary below it.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const PREVIEW_MAX_BYTES: usize = 32;

        let mut preview_len = self.data.len().min(PREVIEW_MAX_BYTES);
        // Slicing a `str` at a non-boundary offset panics, so back off until
        // the cut is valid. Offset 0 is always a boundary, so this terminates.
        while !self.data.is_char_boundary(preview_len) {
            preview_len -= 1;
        }
        write!(f, "fp:{}", &self.data[..preview_len])
    }
}
88
89// ──────────────────────────────────────────────────────────────────────────────
90// CachedResult  (public, used by LruCompilationCache / CachingCompiler)
91// ──────────────────────────────────────────────────────────────────────────────
92
/// A cached compilation result stored in an [`LruCompilationCache`].
///
/// Entries are created by [`LruCompilationCache::insert`] and handed out by
/// reference from [`LruCompilationCache::get`].
#[derive(Debug, Clone)]
pub struct CachedResult {
    /// The compiled graph.
    pub graph: EinsumGraph,
    /// Number of successful [`LruCompilationCache::get`] calls for this entry.
    /// Starts at zero when the entry is first inserted.
    pub hit_count: u64,
    /// Approximate memory used by the graph, estimated at insertion time as
    /// `nodes.len() * 256` bytes.
    pub memory_bytes: usize,
}
103
104// ──────────────────────────────────────────────────────────────────────────────
105// CacheStats  (shared by both cache types)
106// ──────────────────────────────────────────────────────────────────────────────
107
/// Counters and gauges describing how a compilation cache has been used.
///
/// A default-constructed value has every field set to zero.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups that found an entry.
    pub hits: u64,
    /// Lookups that did not find an entry.
    pub misses: u64,
    /// Entries removed to make room for a new entry.
    pub evictions: u64,
    /// Entry count as of the most recent insert, eviction, or clear.
    pub current_entries: usize,
    /// Estimated bytes held by all cached graphs.
    pub total_memory_bytes: usize,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `[0.0, 1.0]`.
    ///
    /// Yields `0.0` if nothing has been looked up so far.
    pub fn hit_rate(&self) -> f64 {
        match self.total_lookups() {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }

    /// Sum of `hits` and `misses`.
    pub fn total_lookups(&self) -> u64 {
        self.hits + self.misses
    }
}
141
142// ──────────────────────────────────────────────────────────────────────────────
143// LruCompilationCache
144// ──────────────────────────────────────────────────────────────────────────────
145
146/// LRU compilation cache with configurable capacity.
147///
148/// Stores compiled `EinsumGraph` instances keyed by [`ExprFingerprint`].
149/// When capacity is exceeded the least-recently-used entry is evicted.
150///
151/// This cache is **not** thread-safe — wrap it in `Arc<Mutex<_>>` or use
152/// [`CompilationCache`] if you need concurrent access.
153///
154/// # Example
155///
156/// ```rust
157/// use tensorlogic_compiler::cache::{LruCompilationCache, ExprFingerprint};
158/// use tensorlogic_ir::EinsumGraph;
159///
160/// let mut cache = LruCompilationCache::new(4);
161/// let fp = ExprFingerprint::compute("pred(x)");
162/// cache.insert(fp.clone(), EinsumGraph::new());
163/// assert!(cache.get(&fp).is_some());
164/// ```
165pub struct LruCompilationCache {
166    /// Maximum number of entries.
167    capacity: usize,
168    /// The cache storage.
169    entries: HashMap<ExprFingerprint, CachedResult>,
170    /// LRU order: oldest at the **front**, newest at the **back**.
171    lru_order: std::collections::VecDeque<ExprFingerprint>,
172    /// Accumulated statistics.
173    stats: CacheStats,
174}
175
176impl LruCompilationCache {
177    /// Create a new LRU cache with the given capacity (minimum 1).
178    pub fn new(capacity: usize) -> Self {
179        LruCompilationCache {
180            capacity: capacity.max(1),
181            entries: HashMap::new(),
182            lru_order: std::collections::VecDeque::new(),
183            stats: CacheStats::default(),
184        }
185    }
186
187    /// Insert a compiled result for the given fingerprint.
188    ///
189    /// If the fingerprint already exists the stored graph is updated and the
190    /// entry is promoted to the most-recently-used position.
191    ///
192    /// If the cache is at capacity the least-recently-used entry is evicted
193    /// before the new entry is inserted.
194    pub fn insert(&mut self, fp: ExprFingerprint, graph: EinsumGraph) {
195        // Estimate memory: proportional to node count.
196        let memory_bytes = graph.nodes.len() * 256;
197
198        if self.entries.contains_key(&fp) {
199            // Update the existing entry in-place.
200            if let Some(entry) = self.entries.get_mut(&fp) {
201                self.stats.total_memory_bytes = self
202                    .stats
203                    .total_memory_bytes
204                    .saturating_sub(entry.memory_bytes);
205                entry.graph = graph;
206                entry.memory_bytes = memory_bytes;
207                self.stats.total_memory_bytes += memory_bytes;
208            }
209            // Promote to most-recently-used.
210            if let Some(pos) = self.lru_order.iter().position(|x| x == &fp) {
211                self.lru_order.remove(pos);
212            }
213            self.lru_order.push_back(fp);
214        } else {
215            // Evict the LRU entry when at capacity.
216            if self.entries.len() >= self.capacity {
217                if let Some(oldest) = self.lru_order.pop_front() {
218                    if let Some(evicted) = self.entries.remove(&oldest) {
219                        self.stats.total_memory_bytes = self
220                            .stats
221                            .total_memory_bytes
222                            .saturating_sub(evicted.memory_bytes);
223                    }
224                    self.stats.evictions += 1;
225                }
226            }
227            self.stats.total_memory_bytes += memory_bytes;
228            self.lru_order.push_back(fp.clone());
229            self.entries.insert(
230                fp,
231                CachedResult {
232                    graph,
233                    hit_count: 0,
234                    memory_bytes,
235                },
236            );
237        }
238        self.stats.current_entries = self.entries.len();
239    }
240
241    /// Look up a fingerprint.
242    ///
243    /// On a hit the entry is promoted to the most-recently-used position,
244    /// its `hit_count` is incremented, and a reference to it is returned.
245    /// On a miss `None` is returned.
246    pub fn get(&mut self, fp: &ExprFingerprint) -> Option<&CachedResult> {
247        if self.entries.contains_key(fp) {
248            // Promote to most-recently-used.
249            if let Some(pos) = self.lru_order.iter().position(|x| x == fp) {
250                self.lru_order.remove(pos);
251            }
252            self.lru_order.push_back(fp.clone());
253            // Increment hit counter.
254            if let Some(entry) = self.entries.get_mut(fp) {
255                entry.hit_count += 1;
256            }
257            self.stats.hits += 1;
258            self.entries.get(fp)
259        } else {
260            self.stats.misses += 1;
261            None
262        }
263    }
264
265    /// Check if a fingerprint is present **without** updating LRU order or stats.
266    pub fn contains(&self, fp: &ExprFingerprint) -> bool {
267        self.entries.contains_key(fp)
268    }
269
270    /// Remove a specific entry by fingerprint.
271    ///
272    /// Returns `true` if the entry existed and was removed, `false` otherwise.
273    pub fn invalidate(&mut self, fp: &ExprFingerprint) -> bool {
274        if let Some(evicted) = self.entries.remove(fp) {
275            self.stats.total_memory_bytes = self
276                .stats
277                .total_memory_bytes
278                .saturating_sub(evicted.memory_bytes);
279            if let Some(pos) = self.lru_order.iter().position(|x| x == fp) {
280                self.lru_order.remove(pos);
281            }
282            self.stats.current_entries = self.entries.len();
283            true
284        } else {
285            false
286        }
287    }
288
289    /// Clear all cached entries, resetting memory accounting.
290    ///
291    /// Statistics counters (hits, misses, evictions) are **not** reset.
292    pub fn clear(&mut self) {
293        self.entries.clear();
294        self.lru_order.clear();
295        self.stats.current_entries = 0;
296        self.stats.total_memory_bytes = 0;
297    }
298
299    /// Reference to the current statistics snapshot.
300    pub fn stats(&self) -> &CacheStats {
301        &self.stats
302    }
303
304    /// Number of cached entries.
305    pub fn len(&self) -> usize {
306        self.entries.len()
307    }
308
309    /// Returns `true` when the cache contains no entries.
310    pub fn is_empty(&self) -> bool {
311        self.entries.is_empty()
312    }
313
314    /// The maximum number of entries this cache can hold before eviction.
315    pub fn capacity(&self) -> usize {
316        self.capacity
317    }
318}
319
320impl Default for LruCompilationCache {
321    fn default() -> Self {
322        Self::new(256)
323    }
324}
325
326// ──────────────────────────────────────────────────────────────────────────────
327// CachingCompiler
328// ──────────────────────────────────────────────────────────────────────────────
329
330/// A compiler wrapper that caches results keyed by expression fingerprint.
331///
332/// Uses structural fingerprinting of [`TLExpr`] to detect identical expressions.
333/// Falls back to fresh compilation on a cache miss and stores the result for
334/// subsequent calls.
335///
336/// # Example
337///
338/// ```rust
339/// use tensorlogic_compiler::cache::CachingCompiler;
340/// use tensorlogic_compiler::compile_to_einsum;
341/// use tensorlogic_ir::{TLExpr, Term};
342///
343/// let mut cc = CachingCompiler::new(32, |expr| {
344///     compile_to_einsum(expr).map_err(|e| e.to_string())
345/// });
346///
347/// let e = TLExpr::pred("p", vec![Term::var("x")]);
348/// let g1 = cc.compile(&e).unwrap();
349/// let g2 = cc.compile(&e).unwrap(); // cache hit
350///
351/// assert_eq!(cc.cache_stats().hits, 1);
352/// assert_eq!(g1, g2);
353/// ```
354/// Type alias for the compile function stored in a [`CachingCompiler`].
355type CompileFn =
356    Box<dyn Fn(&TLExpr) -> std::result::Result<EinsumGraph, String> + Send + Sync + 'static>;
357
358pub struct CachingCompiler {
359    cache: LruCompilationCache,
360    compile_fn: CompileFn,
361}
362
363impl CachingCompiler {
364    /// Create a `CachingCompiler` with a custom compile function and cache capacity.
365    ///
366    /// # Arguments
367    ///
368    /// * `capacity` – Maximum number of entries held in the LRU cache.
369    /// * `compile_fn` – A closure (or function) that compiles a [`TLExpr`] into an
370    ///   [`EinsumGraph`], returning `Err(String)` on failure.
371    pub fn new<F>(capacity: usize, compile_fn: F) -> Self
372    where
373        F: Fn(&TLExpr) -> std::result::Result<EinsumGraph, String> + Send + Sync + 'static,
374    {
375        CachingCompiler {
376            cache: LruCompilationCache::new(capacity),
377            compile_fn: Box::new(compile_fn),
378        }
379    }
380
381    /// Compile an expression, returning the cached result when available.
382    ///
383    /// # Errors
384    ///
385    /// Propagates any error produced by the underlying compile function on a cache miss.
386    pub fn compile(&mut self, expr: &TLExpr) -> std::result::Result<EinsumGraph, String> {
387        let fp = Self::fingerprint(expr);
388
389        if let Some(cached) = self.cache.get(&fp) {
390            return Ok(cached.graph.clone());
391        }
392
393        let result = (self.compile_fn)(expr)?;
394        self.cache.insert(fp, result.clone());
395        Ok(result)
396    }
397
398    /// Compile multiple expressions in order, sharing the cache across all of them.
399    ///
400    /// Returns one `Result` per input expression in the same order.
401    pub fn compile_batch(
402        &mut self,
403        exprs: &[TLExpr],
404    ) -> Vec<std::result::Result<EinsumGraph, String>> {
405        exprs.iter().map(|e| self.compile(e)).collect()
406    }
407
408    /// Returns a reference to the current cache statistics.
409    pub fn cache_stats(&self) -> &CacheStats {
410        self.cache.stats()
411    }
412
413    /// Invalidate the cached result for a specific expression.
414    ///
415    /// Returns `true` if an entry was present and removed.
416    pub fn invalidate(&mut self, expr: &TLExpr) -> bool {
417        let fp = Self::fingerprint(expr);
418        self.cache.invalidate(&fp)
419    }
420
421    /// Compute a structural [`ExprFingerprint`] for an expression.
422    ///
423    /// Two structurally identical expressions will produce equal fingerprints.
424    pub fn fingerprint(expr: &TLExpr) -> ExprFingerprint {
425        ExprFingerprint::compute(&Self::structural_repr(expr))
426    }
427
428    /// Produce a deterministic string representation of an expression's structure.
429    ///
430    /// This uses the `Debug` implementation of [`TLExpr`] which is deterministic
431    /// for the same expression tree. Future enhancements may switch to a custom
432    /// canonical serialisation if `Debug` output format changes.
433    fn structural_repr(expr: &TLExpr) -> String {
434        // `Debug` for TLExpr is stable within a single build and deterministic
435        // for identical expression trees, making it a reliable fingerprint source.
436        format!("{:?}", expr)
437    }
438}
439
440// ──────────────────────────────────────────────────────────────────────────────
441// Legacy thread-safe CompilationCache  (original implementation, retained)
442// ──────────────────────────────────────────────────────────────────────────────
443
444/// A hash key for the thread-safe compilation cache.
445#[derive(Debug, Clone, PartialEq, Eq, Hash)]
446struct CacheKey {
447    expr_hash: u64,
448    config_hash: u64,
449    domain_hash: u64,
450}
451
452impl CacheKey {
453    fn new(expr: &TLExpr, config: &CompilationConfig, ctx: &CompilerContext) -> Self {
454        use std::collections::hash_map::DefaultHasher;
455
456        let mut expr_hasher = DefaultHasher::new();
457        format!("{:?}", expr).hash(&mut expr_hasher);
458        let expr_hash = expr_hasher.finish();
459
460        let mut config_hasher = DefaultHasher::new();
461        format!("{:?}", config).hash(&mut config_hasher);
462        let config_hash = config_hasher.finish();
463
464        let mut domain_hasher = DefaultHasher::new();
465        for (name, domain) in &ctx.domains {
466            name.hash(&mut domain_hasher);
467            domain.cardinality.hash(&mut domain_hasher);
468        }
469        let domain_hash = domain_hasher.finish();
470
471        CacheKey {
472            expr_hash,
473            config_hash,
474            domain_hash,
475        }
476    }
477}
478
/// Internal cached result for the thread-safe cache.
#[derive(Clone)]
struct ThreadSafeCachedResult {
    /// The compiled graph; cloned out to callers on every hit.
    graph: EinsumGraph,
    /// Number of cache hits served from this entry. The entry with the lowest
    /// value is the one chosen for eviction.
    hit_count: usize,
}
485
486/// Thread-safe compilation cache for storing and retrieving compiled expressions.
487///
488/// Stores compiled `EinsumGraph` instances keyed by a composite hash that includes
489/// the expression structure, compilation configuration, and domain information.
490/// This cache **is** thread-safe and can be shared across compilation threads.
491///
492/// When capacity is exceeded the cache evicts the least-frequently-used entry
493/// (lowest `hit_count`). For strict LRU eviction use [`LruCompilationCache`] or
494/// [`CachingCompiler`] instead.
495///
496/// # Example
497///
498/// ```rust
499/// use tensorlogic_compiler::{CompilationCache, compile_to_einsum_with_context, CompilerContext};
500/// use tensorlogic_ir::{TLExpr, Term};
501///
502/// let cache = CompilationCache::new(100);
503/// let mut ctx = CompilerContext::new();
504/// ctx.add_domain("Person", 100);
505///
506/// let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
507///
508/// // First compilation: miss (not in cache)
509/// let graph1 = cache.get_or_compile(&expr, &mut ctx, |expr, ctx| {
510///     compile_to_einsum_with_context(expr, ctx)
511/// }).expect("compile");
512///
513/// // Second compilation: hit (cached)
514/// let graph2 = cache.get_or_compile(&expr, &mut ctx, |expr, ctx| {
515///     compile_to_einsum_with_context(expr, ctx)
516/// }).expect("compile");
517///
518/// assert_eq!(graph1, graph2);
519/// assert_eq!(cache.stats().hits, 1);
520/// ```
521pub struct CompilationCache {
522    cache: Arc<Mutex<HashMap<CacheKey, ThreadSafeCachedResult>>>,
523    max_size: usize,
524    stats: Arc<Mutex<CacheStats>>,
525}
526
527impl CompilationCache {
528    /// Create a new compilation cache with the specified maximum size.
529    ///
530    /// # Arguments
531    ///
532    /// * `max_size` – Maximum number of entries to cache.
533    ///
534    /// # Example
535    ///
536    /// ```rust
537    /// use tensorlogic_compiler::CompilationCache;
538    ///
539    /// let cache = CompilationCache::new(100);
540    /// assert_eq!(cache.max_size(), 100);
541    /// ```
542    pub fn new(max_size: usize) -> Self {
543        Self {
544            cache: Arc::new(Mutex::new(HashMap::new())),
545            max_size,
546            stats: Arc::new(Mutex::new(CacheStats::default())),
547        }
548    }
549
550    /// Create a cache with the default size of 1 000 entries.
551    pub fn default_size() -> Self {
552        Self::new(1000)
553    }
554
555    /// Maximum number of entries the cache can hold.
556    pub fn max_size(&self) -> usize {
557        self.max_size
558    }
559
560    /// Get or compile an expression.
561    ///
562    /// On a cache hit the stored result is returned immediately.
563    /// On a miss `compile_fn` is called and the result is stored before returning.
564    ///
565    /// # Example
566    ///
567    /// ```rust
568    /// use tensorlogic_compiler::{CompilationCache, compile_to_einsum_with_context, CompilerContext};
569    /// use tensorlogic_ir::{TLExpr, Term};
570    ///
571    /// let cache = CompilationCache::new(100);
572    /// let mut ctx = CompilerContext::new();
573    /// ctx.add_domain("Person", 100);
574    ///
575    /// let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
576    ///
577    /// let graph = cache.get_or_compile(&expr, &mut ctx, |expr, ctx| {
578    ///     compile_to_einsum_with_context(expr, ctx)
579    /// }).expect("compile");
580    /// ```
581    pub fn get_or_compile<F>(
582        &self,
583        expr: &TLExpr,
584        ctx: &mut CompilerContext,
585        compile_fn: F,
586    ) -> Result<EinsumGraph>
587    where
588        F: FnOnce(&TLExpr, &mut CompilerContext) -> Result<EinsumGraph>,
589    {
590        let key = CacheKey::new(expr, &ctx.config, ctx);
591
592        // Try cache first.
593        {
594            let mut cache = self
595                .cache
596                .lock()
597                .map_err(|e| anyhow::anyhow!("cache lock poisoned: {}", e))?;
598            if let Some(cached) = cache.get_mut(&key) {
599                cached.hit_count += 1;
600                let mut stats = self
601                    .stats
602                    .lock()
603                    .map_err(|e| anyhow::anyhow!("stats lock poisoned: {}", e))?;
604                stats.hits += 1;
605                return Ok(cached.graph.clone());
606            }
607        }
608
609        // Cache miss — compile.
610        {
611            let mut stats = self
612                .stats
613                .lock()
614                .map_err(|e| anyhow::anyhow!("stats lock poisoned: {}", e))?;
615            stats.misses += 1;
616        }
617
618        let graph = compile_fn(expr, ctx)?;
619
620        // Store result (evict if necessary).
621        {
622            let mut cache = self
623                .cache
624                .lock()
625                .map_err(|e| anyhow::anyhow!("cache lock poisoned: {}", e))?;
626
627            if cache.len() >= self.max_size {
628                // Evict least-frequently used entry.
629                let min_key = cache
630                    .iter()
631                    .min_by_key(|(_, v)| v.hit_count)
632                    .map(|(k, _)| k.clone());
633
634                if let Some(key_to_evict) = min_key {
635                    cache.remove(&key_to_evict);
636                    let mut stats = self
637                        .stats
638                        .lock()
639                        .map_err(|e| anyhow::anyhow!("stats lock poisoned: {}", e))?;
640                    stats.evictions += 1;
641                }
642            }
643
644            cache.insert(
645                key,
646                ThreadSafeCachedResult {
647                    graph: graph.clone(),
648                    hit_count: 0,
649                },
650            );
651
652            let mut stats = self
653                .stats
654                .lock()
655                .map_err(|e| anyhow::anyhow!("stats lock poisoned: {}", e))?;
656            stats.current_entries = cache.len();
657        }
658
659        Ok(graph)
660    }
661
662    /// Current cache statistics snapshot.
663    ///
664    /// # Example
665    ///
666    /// ```rust
667    /// use tensorlogic_compiler::CompilationCache;
668    ///
669    /// let cache = CompilationCache::new(100);
670    /// let stats = cache.stats();
671    /// assert_eq!(stats.hits, 0);
672    /// ```
673    pub fn stats(&self) -> CacheStats {
674        self.stats.lock().map(|g| g.clone()).unwrap_or_default()
675    }
676
677    /// Clear all cached entries.
678    ///
679    /// # Example
680    ///
681    /// ```rust
682    /// use tensorlogic_compiler::CompilationCache;
683    ///
684    /// let cache = CompilationCache::new(100);
685    /// cache.clear();
686    /// assert_eq!(cache.stats().current_entries, 0);
687    /// ```
688    pub fn clear(&self) {
689        if let Ok(mut cache) = self.cache.lock() {
690            cache.clear();
691        }
692        if let Ok(mut stats) = self.stats.lock() {
693            stats.current_entries = 0;
694            stats.total_memory_bytes = 0;
695        }
696    }
697
698    /// Current number of entries in the cache.
699    pub fn len(&self) -> usize {
700        self.cache.lock().map(|g| g.len()).unwrap_or(0)
701    }
702
703    /// Returns `true` when the cache is empty.
704    pub fn is_empty(&self) -> bool {
705        self.len() == 0
706    }
707}
708
709impl Default for CompilationCache {
710    fn default() -> Self {
711        Self::default_size()
712    }
713}
714
715// ──────────────────────────────────────────────────────────────────────────────
716// Tests
717// ──────────────────────────────────────────────────────────────────────────────
718
719#[cfg(test)]
720mod tests {
721    use super::*;
722    use crate::compile_to_einsum_with_context;
723    use tensorlogic_ir::Term;
724
725    // ── helpers ────────────────────────────────────────────────────────────────
726
727    fn make_graph(node_count: usize) -> EinsumGraph {
728        use tensorlogic_ir::EinsumNode;
729        let mut g = EinsumGraph::new();
730        for i in 0..node_count {
731            let a = g.add_tensor(format!("t{}", i));
732            let b = g.add_tensor(format!("u{}", i));
733            let c = g.add_tensor(format!("v{}", i));
734            g.add_node(EinsumNode::einsum("i,i->i", vec![a, b], vec![c]))
735                .ok();
736        }
737        g
738    }
739
    /// Shorthand for building a fingerprint directly from a string literal.
    fn simple_fp(s: &str) -> ExprFingerprint {
        ExprFingerprint::compute(s)
    }
743
744    // ── LruCompilationCache tests ──────────────────────────────────────────────
745
746    /// insert then get returns Some
747    #[test]
748    fn test_cache_basic_insert_get() {
749        let mut cache = LruCompilationCache::new(8);
750        let fp = simple_fp("pred(x)");
751        cache.insert(fp.clone(), EinsumGraph::new());
752        assert!(
753            cache.get(&fp).is_some(),
754            "entry should be present after insert"
755        );
756    }
757
758    /// get on empty cache returns None
759    #[test]
760    fn test_cache_miss() {
761        let mut cache = LruCompilationCache::new(8);
762        let fp = simple_fp("pred(x)");
763        assert!(cache.get(&fp).is_none(), "empty cache must return None");
764    }
765
766    /// hit_count increments on each successful get
767    #[test]
768    fn test_cache_hit_increments_hit_count() {
769        let mut cache = LruCompilationCache::new(8);
770        let fp = simple_fp("pred(x)");
771        cache.insert(fp.clone(), EinsumGraph::new());
772
773        cache.get(&fp);
774        cache.get(&fp);
775
776        // hit_count inside the entry should reflect two reads.
777        assert!(cache.contains(&fp), "entry must still exist after reads");
778        // Obtain hit_count via a final get.
779        let entry = cache.get(&fp).expect("entry must be present");
780        // Three gets were performed (two above + this one) → hit_count == 3.
781        assert_eq!(entry.hit_count, 3, "hit_count should be 3 after three gets");
782    }
783
784    /// 1 hit + 1 miss → hit_rate == 0.5
785    #[test]
786    fn test_cache_stats_hit_rate() {
787        let mut cache = LruCompilationCache::new(8);
788        let fp = simple_fp("pred(x)");
789        cache.insert(fp.clone(), EinsumGraph::new());
790
791        cache.get(&fp); // hit
792        cache.get(&simple_fp("missing")); // miss
793
794        let stats = cache.stats();
795        assert_eq!(stats.hits, 1);
796        assert_eq!(stats.misses, 1);
797        assert!(
798            (stats.hit_rate() - 0.5).abs() < f64::EPSILON,
799            "hit rate must be 0.5"
800        );
801    }
802
803    /// capacity=2, three inserts → oldest evicted
804    #[test]
805    fn test_cache_lru_eviction() {
806        let mut cache = LruCompilationCache::new(2);
807        let fp1 = simple_fp("a");
808        let fp2 = simple_fp("b");
809        let fp3 = simple_fp("c");
810
811        cache.insert(fp1.clone(), EinsumGraph::new());
812        cache.insert(fp2.clone(), EinsumGraph::new());
813        cache.insert(fp3.clone(), EinsumGraph::new()); // should evict fp1
814
815        assert!(
816            !cache.contains(&fp1),
817            "oldest entry (fp1) must have been evicted"
818        );
819        assert!(cache.contains(&fp2), "fp2 must still be present");
820        assert!(cache.contains(&fp3), "fp3 must be present");
821        assert_eq!(cache.len(), 2);
822    }
823
824    /// Access the oldest entry so it becomes newest; the next eviction removes the other one
825    #[test]
826    fn test_cache_lru_access_updates_order() {
827        let mut cache = LruCompilationCache::new(2);
828        let fp1 = simple_fp("a");
829        let fp2 = simple_fp("b");
830        let fp3 = simple_fp("c");
831
832        cache.insert(fp1.clone(), EinsumGraph::new());
833        cache.insert(fp2.clone(), EinsumGraph::new());
834
835        // Access fp1 → it becomes MRU; fp2 is now LRU.
836        cache.get(&fp1);
837
838        // Insert fp3 → fp2 should be evicted (LRU), not fp1.
839        cache.insert(fp3.clone(), EinsumGraph::new());
840
841        assert!(cache.contains(&fp1), "fp1 was accessed so it must survive");
842        assert!(
843            !cache.contains(&fp2),
844            "fp2 is LRU after fp1 was accessed; it must be evicted"
845        );
846        assert!(cache.contains(&fp3), "fp3 must be present");
847    }
848
849    /// invalidate removes an entry
850    #[test]
851    fn test_cache_invalidate() {
852        let mut cache = LruCompilationCache::new(8);
853        let fp = simple_fp("pred(x)");
854        cache.insert(fp.clone(), EinsumGraph::new());
855
856        let removed = cache.invalidate(&fp);
857        assert!(removed, "invalidate must return true when entry existed");
858        assert!(
859            !cache.contains(&fp),
860            "entry must be gone after invalidation"
861        );
862    }
863
864    /// clear empties the cache
865    #[test]
866    fn test_cache_clear() {
867        let mut cache = LruCompilationCache::new(8);
868        cache.insert(simple_fp("a"), EinsumGraph::new());
869        cache.insert(simple_fp("b"), EinsumGraph::new());
870
871        cache.clear();
872
873        assert!(cache.is_empty(), "cache must be empty after clear");
874        assert_eq!(cache.len(), 0);
875        assert_eq!(cache.stats().total_memory_bytes, 0);
876    }
877
878    /// len / is_empty reflect the actual entry count
879    #[test]
880    fn test_cache_len_and_is_empty() {
881        let mut cache = LruCompilationCache::new(8);
882        assert!(cache.is_empty());
883        assert_eq!(cache.len(), 0);
884
885        cache.insert(simple_fp("x"), EinsumGraph::new());
886        assert!(!cache.is_empty());
887        assert_eq!(cache.len(), 1);
888    }
889
890    /// capacity() returns the configured value
891    #[test]
892    fn test_cache_capacity() {
893        let cache = LruCompilationCache::new(42);
894        assert_eq!(cache.capacity(), 42);
895    }
896
897    /// evictions counter is updated correctly
898    #[test]
899    fn test_cache_eviction_stat() {
900        let mut cache = LruCompilationCache::new(2);
901        cache.insert(simple_fp("a"), EinsumGraph::new());
902        cache.insert(simple_fp("b"), EinsumGraph::new());
903        cache.insert(simple_fp("c"), EinsumGraph::new()); // one eviction
904        cache.insert(simple_fp("d"), EinsumGraph::new()); // second eviction
905
906        assert_eq!(
907            cache.stats().evictions,
908            2,
909            "two evictions must have occurred"
910        );
911    }
912
913    /// total_memory_bytes is positive after inserting a non-empty graph
914    #[test]
915    fn test_cache_memory_estimate() {
916        let mut cache = LruCompilationCache::new(8);
917        // Graph with 4 nodes → 4 * 256 = 1024 bytes estimated.
918        let graph = make_graph(4);
919        cache.insert(simple_fp("g"), graph);
920
921        assert!(
922            cache.stats().total_memory_bytes > 0,
923            "memory estimate must be > 0 for a non-empty graph"
924        );
925    }
926
927    // ── ExprFingerprint tests ──────────────────────────────────────────────────
928
929    /// Same expression structure → same fingerprint
930    #[test]
931    fn test_fingerprint_same_for_same_expr() {
932        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
933        let fp1 = CachingCompiler::fingerprint(&expr);
934        let fp2 = CachingCompiler::fingerprint(&expr);
935        assert_eq!(
936            fp1, fp2,
937            "identical expressions must produce identical fingerprints"
938        );
939    }
940
941    /// Display format starts with "fp:"
942    #[test]
943    fn test_fingerprint_display() {
944        let fp = ExprFingerprint::compute("pred(x, y)");
945        let display = format!("{}", fp);
946        assert!(display.starts_with("fp:"), "Display must start with 'fp:'");
947    }
948
949    // ── CachingCompiler tests ─────────────────────────────────────────────────
950
951    fn make_caching_compiler(capacity: usize) -> CachingCompiler {
952        CachingCompiler::new(capacity, |expr| {
953            let mut ctx = CompilerContext::new();
954            compile_to_einsum_with_context(expr, &mut ctx).map_err(|e| e.to_string())
955        })
956    }
957
958    /// Second compile of the same expression should use the cache
959    #[test]
960    fn test_caching_compiler_cache_hit() {
961        let mut cc = make_caching_compiler(32);
962        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
963
964        cc.compile(&expr).expect("first compile");
965        cc.compile(&expr).expect("second compile");
966
967        assert_eq!(
968            cc.cache_stats().hits,
969            1,
970            "second compile must be a cache hit"
971        );
972    }
973
974    /// First compile counts as a miss
975    #[test]
976    fn test_caching_compiler_cache_miss_count() {
977        let mut cc = make_caching_compiler(32);
978        let expr = TLExpr::pred("likes", vec![Term::var("a"), Term::var("b")]);
979
980        cc.compile(&expr).expect("compile");
981
982        assert_eq!(
983            cc.cache_stats().misses,
984            1,
985            "first compile must be a cache miss"
986        );
987        assert_eq!(cc.cache_stats().hits, 0);
988    }
989
990    /// compile_batch processes all expressions
991    #[test]
992    fn test_caching_compiler_batch() {
993        let mut cc = make_caching_compiler(32);
994        let exprs = vec![
995            TLExpr::pred("p", vec![Term::var("x")]),
996            TLExpr::pred("q", vec![Term::var("y")]),
997            TLExpr::pred("r", vec![Term::var("z")]),
998        ];
999
1000        let results = cc.compile_batch(&exprs);
1001        assert_eq!(results.len(), 3, "batch must return one result per input");
1002        for (i, r) in results.iter().enumerate() {
1003            assert!(r.is_ok(), "result[{}] must be Ok", i);
1004        }
1005    }
1006
1007    /// invalidate clears the entry for a specific expression
1008    #[test]
1009    fn test_caching_compiler_invalidate() {
1010        let mut cc = make_caching_compiler(32);
1011        let expr = TLExpr::pred("p", vec![Term::var("x")]);
1012
1013        cc.compile(&expr).expect("compile");
1014        let removed = cc.invalidate(&expr);
1015        assert!(removed, "invalidate must return true when entry existed");
1016
1017        // Re-compiling should be a miss again.
1018        cc.compile(&expr).expect("re-compile");
1019        assert_eq!(
1020            cc.cache_stats().misses,
1021            2,
1022            "re-compile after invalidation must be another miss"
1023        );
1024    }
1025
1026    // ── Default / misc tests ──────────────────────────────────────────────────
1027
1028    /// Default LRU cache capacity is 256
1029    #[test]
1030    fn test_cache_default_capacity() {
1031        let cache = LruCompilationCache::default();
1032        assert_eq!(cache.capacity(), 256, "default capacity must be 256");
1033    }
1034
1035    /// ExprFingerprint implements Hash and can be used as a HashMap key
1036    #[test]
1037    fn test_expr_fingerprint_hash() {
1038        let mut map: HashMap<ExprFingerprint, u32> = HashMap::new();
1039        let fp = ExprFingerprint::compute("some_expr");
1040        map.insert(fp.clone(), 42);
1041        assert_eq!(
1042            map.get(&fp),
1043            Some(&42),
1044            "fingerprint must work as HashMap key"
1045        );
1046    }
1047
1048    // ── Legacy CompilationCache tests ─────────────────────────────────────────
1049
1050    #[test]
1051    fn test_ts_cache_new() {
1052        let cache = CompilationCache::new(100);
1053        assert_eq!(cache.max_size(), 100);
1054        assert_eq!(cache.len(), 0);
1055        assert!(cache.is_empty());
1056    }
1057
1058    #[test]
1059    fn test_ts_cache_hit() {
1060        let cache = CompilationCache::new(100);
1061        let mut ctx = CompilerContext::new();
1062        ctx.add_domain("Person", 100);
1063
1064        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
1065
1066        let graph1 = cache
1067            .get_or_compile(&expr, &mut ctx, compile_to_einsum_with_context)
1068            .expect("compile");
1069
1070        let stats = cache.stats();
1071        assert_eq!(stats.misses, 1);
1072        assert_eq!(stats.hits, 0);
1073
1074        let graph2 = cache
1075            .get_or_compile(&expr, &mut ctx, compile_to_einsum_with_context)
1076            .expect("compile");
1077
1078        let stats = cache.stats();
1079        assert_eq!(stats.misses, 1);
1080        assert_eq!(stats.hits, 1);
1081        assert!(
1082            (stats.hit_rate() - 0.5).abs() < f64::EPSILON,
1083            "hit rate must be 0.5"
1084        );
1085
1086        assert_eq!(graph1, graph2);
1087    }
1088
1089    #[test]
1090    fn test_ts_cache_different_expressions() {
1091        let cache = CompilationCache::new(100);
1092        let mut ctx = CompilerContext::new();
1093        ctx.add_domain("Person", 100);
1094
1095        let expr1 = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
1096        let expr2 = TLExpr::pred("likes", vec![Term::var("x"), Term::var("y")]);
1097
1098        let _ = cache
1099            .get_or_compile(&expr1, &mut ctx, compile_to_einsum_with_context)
1100            .expect("compile");
1101        let _ = cache
1102            .get_or_compile(&expr2, &mut ctx, compile_to_einsum_with_context)
1103            .expect("compile");
1104
1105        let stats = cache.stats();
1106        assert_eq!(stats.misses, 2);
1107        assert_eq!(stats.hits, 0);
1108        assert_eq!(cache.len(), 2);
1109    }
1110
1111    #[test]
1112    fn test_ts_cache_eviction() {
1113        let cache = CompilationCache::new(2);
1114        let mut ctx = CompilerContext::new();
1115        ctx.add_domain("Person", 100);
1116
1117        let _ = cache.get_or_compile(
1118            &TLExpr::pred("p1", vec![Term::var("x")]),
1119            &mut ctx,
1120            compile_to_einsum_with_context,
1121        );
1122        let _ = cache.get_or_compile(
1123            &TLExpr::pred("p2", vec![Term::var("x")]),
1124            &mut ctx,
1125            compile_to_einsum_with_context,
1126        );
1127        let _ = cache.get_or_compile(
1128            &TLExpr::pred("p3", vec![Term::var("x")]),
1129            &mut ctx,
1130            compile_to_einsum_with_context,
1131        );
1132
1133        let stats = cache.stats();
1134        assert_eq!(stats.evictions, 1);
1135        assert_eq!(cache.len(), 2);
1136    }
1137
1138    #[test]
1139    fn test_ts_cache_clear() {
1140        let cache = CompilationCache::new(100);
1141        let mut ctx = CompilerContext::new();
1142        ctx.add_domain("Person", 100);
1143
1144        let _ = cache.get_or_compile(
1145            &TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]),
1146            &mut ctx,
1147            compile_to_einsum_with_context,
1148        );
1149
1150        assert_eq!(cache.len(), 1);
1151        cache.clear();
1152        assert_eq!(cache.len(), 0);
1153        assert!(cache.is_empty());
1154    }
1155
1156    #[test]
1157    fn test_ts_cache_stats() {
1158        let cache = CompilationCache::new(100);
1159        let stats = cache.stats();
1160
1161        assert_eq!(stats.hits, 0);
1162        assert_eq!(stats.misses, 0);
1163        assert_eq!(stats.evictions, 0);
1164        assert_eq!(stats.current_entries, 0);
1165        assert_eq!(stats.hit_rate(), 0.0);
1166        assert_eq!(stats.total_lookups(), 0);
1167    }
1168}