// tensorlogic_compiler/cache.rs

//! Compilation caching for improved performance.
//!
//! This module provides a caching layer for compiled expressions, reducing
//! redundant compilation when the same expressions are compiled multiple times.

use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex};

use anyhow::Result;
use tensorlogic_ir::{EinsumGraph, TLExpr};

use crate::config::CompilationConfig;
use crate::CompilerContext;
15
/// A composite hash key identifying a cached compilation result.
///
/// Two compilations share a cache entry only when the expression, the
/// compilation configuration, and the context's domain information all match.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct CacheKey {
    /// Hash of the expression's `Debug` representation.
    expr_hash: u64,
    /// Hash of the configuration's `Debug` representation.
    config_hash: u64,
    /// Hash of the context's domain names and cardinalities.
    domain_hash: u64,
}
26
27impl CacheKey {
28    /// Create a new cache key from an expression, configuration, and context.
29    fn new(expr: &TLExpr, config: &CompilationConfig, ctx: &CompilerContext) -> Self {
30        use std::collections::hash_map::DefaultHasher;
31
32        // Hash the expression
33        let mut expr_hasher = DefaultHasher::new();
34        format!("{:?}", expr).hash(&mut expr_hasher);
35        let expr_hash = expr_hasher.finish();
36
37        // Hash the configuration
38        let mut config_hasher = DefaultHasher::new();
39        format!("{:?}", config).hash(&mut config_hasher);
40        let config_hash = config_hasher.finish();
41
42        // Hash the domains
43        let mut domain_hasher = DefaultHasher::new();
44        for (name, domain) in &ctx.domains {
45            name.hash(&mut domain_hasher);
46            domain.cardinality.hash(&mut domain_hasher);
47        }
48        let domain_hash = domain_hasher.finish();
49
50        CacheKey {
51            expr_hash,
52            config_hash,
53            domain_hash,
54        }
55    }
56}
57
/// A single cached compilation result together with usage bookkeeping.
#[derive(Clone)]
struct CachedResult {
    /// The compiled graph, cloned and returned on every cache hit.
    graph: EinsumGraph,
    /// Number of times this entry has been served from the cache; eviction
    /// removes the entry with the lowest count when the cache is full.
    hit_count: usize,
}
66
/// Compilation cache for storing and retrieving compiled expressions.
///
/// The cache is thread-safe and can be shared across multiple compilation
/// operations. When the cache reaches its maximum size, it evicts the
/// least-frequently-used entry (the one with the lowest hit count) — note
/// this is frequency-based, not recency-based (LRU) eviction.
///
/// # Example
///
/// ```
/// use tensorlogic_compiler::{CompilationCache, compile_to_einsum_with_context, CompilerContext};
/// use tensorlogic_ir::{TLExpr, Term};
///
/// let cache = CompilationCache::new(100); // Cache up to 100 entries
/// let mut ctx = CompilerContext::new();
/// ctx.add_domain("Person", 100);
///
/// let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
///
/// // First compilation: miss (not in cache)
/// let graph1 = cache.get_or_compile(&expr, &mut ctx, |expr, ctx| {
///     compile_to_einsum_with_context(expr, ctx)
/// }).unwrap();
///
/// // Second compilation: hit (cached)
/// let graph2 = cache.get_or_compile(&expr, &mut ctx, |expr, ctx| {
///     compile_to_einsum_with_context(expr, ctx)
/// }).unwrap();
///
/// assert_eq!(graph1, graph2);
/// assert_eq!(cache.stats().hits, 1);
/// ```
pub struct CompilationCache {
    /// Shared map from cache key to cached result; guarded by its own lock.
    cache: Arc<Mutex<HashMap<CacheKey, CachedResult>>>,
    /// Maximum number of entries before eviction kicks in.
    max_size: usize,
    /// Shared hit/miss/eviction counters; guarded separately from `cache`.
    stats: Arc<Mutex<CacheStats>>,
}
106
/// Statistics about cache performance.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Number of lookups served from the cache.
    pub hits: u64,
    /// Number of lookups that required a fresh compilation.
    pub misses: u64,
    /// Number of entries removed to make room for new ones.
    pub evictions: u64,
    /// Number of entries currently stored.
    pub current_size: usize,
}

impl CacheStats {
    /// Fraction of lookups served from the cache, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no lookups have been recorded yet.
    pub fn hit_rate(&self) -> f64 {
        match self.total_lookups() {
            0 => 0.0,
            total => self.hits as f64 / total as f64,
        }
    }

    /// Total number of lookups performed (hits plus misses).
    pub fn total_lookups(&self) -> u64 {
        self.hits + self.misses
    }
}
136
137impl CompilationCache {
138    /// Create a new compilation cache with the specified maximum size.
139    ///
140    /// # Arguments
141    ///
142    /// * `max_size` - Maximum number of entries to cache (default: 1000)
143    ///
144    /// # Example
145    ///
146    /// ```
147    /// use tensorlogic_compiler::CompilationCache;
148    ///
149    /// let cache = CompilationCache::new(100);
150    /// assert_eq!(cache.max_size(), 100);
151    /// ```
152    pub fn new(max_size: usize) -> Self {
153        Self {
154            cache: Arc::new(Mutex::new(HashMap::new())),
155            max_size,
156            stats: Arc::new(Mutex::new(CacheStats::default())),
157        }
158    }
159
160    /// Create a new cache with default settings (1000 entries).
161    pub fn default_size() -> Self {
162        Self::new(1000)
163    }
164
165    /// Get the maximum cache size.
166    pub fn max_size(&self) -> usize {
167        self.max_size
168    }
169
170    /// Get or compile an expression.
171    ///
172    /// If the expression is in the cache, returns the cached result.
173    /// Otherwise, compiles the expression using the provided function
174    /// and caches the result.
175    ///
176    /// # Arguments
177    ///
178    /// * `expr` - The expression to compile
179    /// * `ctx` - The compiler context
180    /// * `compile_fn` - Function to compile the expression if not cached
181    ///
182    /// # Example
183    ///
184    /// ```
185    /// use tensorlogic_compiler::{CompilationCache, compile_to_einsum_with_context, CompilerContext};
186    /// use tensorlogic_ir::{TLExpr, Term};
187    ///
188    /// let cache = CompilationCache::new(100);
189    /// let mut ctx = CompilerContext::new();
190    /// ctx.add_domain("Person", 100);
191    ///
192    /// let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
193    ///
194    /// let graph = cache.get_or_compile(&expr, &mut ctx, |expr, ctx| {
195    ///     compile_to_einsum_with_context(expr, ctx)
196    /// }).unwrap();
197    /// ```
198    pub fn get_or_compile<F>(
199        &self,
200        expr: &TLExpr,
201        ctx: &mut CompilerContext,
202        compile_fn: F,
203    ) -> Result<EinsumGraph>
204    where
205        F: FnOnce(&TLExpr, &mut CompilerContext) -> Result<EinsumGraph>,
206    {
207        let key = CacheKey::new(expr, &ctx.config, ctx);
208
209        // Try to get from cache
210        {
211            let mut cache = self.cache.lock().unwrap();
212            if let Some(cached) = cache.get_mut(&key) {
213                // Cache hit
214                cached.hit_count += 1;
215                let mut stats = self.stats.lock().unwrap();
216                stats.hits += 1;
217                return Ok(cached.graph.clone());
218            }
219        }
220
221        // Cache miss - compile
222        let mut stats = self.stats.lock().unwrap();
223        stats.misses += 1;
224        drop(stats);
225
226        let graph = compile_fn(expr, ctx)?;
227
228        // Store in cache
229        {
230            let mut cache = self.cache.lock().unwrap();
231
232            // Evict if necessary
233            if cache.len() >= self.max_size {
234                // Find least-used entry
235                let min_key = cache
236                    .iter()
237                    .min_by_key(|(_, v)| v.hit_count)
238                    .map(|(k, _)| k.clone());
239
240                if let Some(key_to_evict) = min_key {
241                    cache.remove(&key_to_evict);
242                    let mut stats = self.stats.lock().unwrap();
243                    stats.evictions += 1;
244                }
245            }
246
247            cache.insert(
248                key,
249                CachedResult {
250                    graph: graph.clone(),
251                    hit_count: 0,
252                },
253            );
254
255            let mut stats = self.stats.lock().unwrap();
256            stats.current_size = cache.len();
257        }
258
259        Ok(graph)
260    }
261
262    /// Get current cache statistics.
263    ///
264    /// # Example
265    ///
266    /// ```
267    /// use tensorlogic_compiler::CompilationCache;
268    ///
269    /// let cache = CompilationCache::new(100);
270    /// let stats = cache.stats();
271    /// assert_eq!(stats.hits, 0);
272    /// assert_eq!(stats.misses, 0);
273    /// ```
274    pub fn stats(&self) -> CacheStats {
275        self.stats.lock().unwrap().clone()
276    }
277
278    /// Clear the cache.
279    ///
280    /// # Example
281    ///
282    /// ```
283    /// use tensorlogic_compiler::CompilationCache;
284    ///
285    /// let cache = CompilationCache::new(100);
286    /// cache.clear();
287    /// assert_eq!(cache.stats().current_size, 0);
288    /// ```
289    pub fn clear(&self) {
290        let mut cache = self.cache.lock().unwrap();
291        cache.clear();
292        let mut stats = self.stats.lock().unwrap();
293        stats.current_size = 0;
294    }
295
296    /// Get the current number of entries in the cache.
297    pub fn len(&self) -> usize {
298        self.cache.lock().unwrap().len()
299    }
300
301    /// Check if the cache is empty.
302    pub fn is_empty(&self) -> bool {
303        self.len() == 0
304    }
305}
306
impl Default for CompilationCache {
    /// Returns a cache with the default capacity (1000 entries).
    fn default() -> Self {
        Self::default_size()
    }
}
312
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compile_to_einsum_with_context;
    use tensorlogic_ir::Term;

    /// Build a compiler context with a single 100-element `Person` domain.
    fn person_ctx() -> CompilerContext {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);
        ctx
    }

    #[test]
    fn test_cache_new() {
        let fresh = CompilationCache::new(100);
        assert_eq!(fresh.max_size(), 100);
        assert_eq!(fresh.len(), 0);
        assert!(fresh.is_empty());
    }

    #[test]
    fn test_cache_hit() {
        let cache = CompilationCache::new(100);
        let mut ctx = person_ctx();
        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);

        // First compilation goes through `compile_fn`: a miss.
        let first = cache
            .get_or_compile(&expr, &mut ctx, |e, c| compile_to_einsum_with_context(e, c))
            .unwrap();
        let after_first = cache.stats();
        assert_eq!(after_first.misses, 1);
        assert_eq!(after_first.hits, 0);

        // Second compilation is served from the cache: a hit.
        let second = cache
            .get_or_compile(&expr, &mut ctx, |e, c| compile_to_einsum_with_context(e, c))
            .unwrap();
        let after_second = cache.stats();
        assert_eq!(after_second.misses, 1);
        assert_eq!(after_second.hits, 1);
        assert_eq!(after_second.hit_rate(), 0.5);

        // The cached graph must match the freshly compiled one.
        assert_eq!(first, second);
    }

    #[test]
    fn test_cache_different_expressions() {
        let cache = CompilationCache::new(100);
        let mut ctx = person_ctx();

        let knows = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
        let likes = TLExpr::pred("likes", vec![Term::var("x"), Term::var("y")]);

        // Distinct expressions must not share a cache entry.
        let _ = cache
            .get_or_compile(&knows, &mut ctx, compile_to_einsum_with_context)
            .unwrap();
        let _ = cache
            .get_or_compile(&likes, &mut ctx, compile_to_einsum_with_context)
            .unwrap();

        let stats = cache.stats();
        assert_eq!(stats.misses, 2);
        assert_eq!(stats.hits, 0);
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_cache_eviction() {
        // Capacity of two: the third insert must evict an entry.
        let cache = CompilationCache::new(2);
        let mut ctx = person_ctx();

        for pred_name in ["p1", "p2", "p3"] {
            let expr = TLExpr::pred(pred_name, vec![Term::var("x")]);
            let _ = cache.get_or_compile(&expr, &mut ctx, compile_to_einsum_with_context);
        }

        let stats = cache.stats();
        assert_eq!(stats.evictions, 1);
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_cache_clear() {
        let cache = CompilationCache::new(100);
        let mut ctx = person_ctx();
        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);

        let _ = cache.get_or_compile(&expr, &mut ctx, compile_to_einsum_with_context);
        assert_eq!(cache.len(), 1);

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_cache_stats() {
        let stats = CompilationCache::new(100).stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.evictions, 0);
        assert_eq!(stats.current_size, 0);
        assert_eq!(stats.hit_rate(), 0.0);
        assert_eq!(stats.total_lookups(), 0);
    }
}
445}