1use std::collections::HashMap;
7use std::hash::{Hash, Hasher};
8use std::sync::{Arc, Mutex};
9
10use anyhow::Result;
11use tensorlogic_ir::{EinsumGraph, TLExpr};
12
13use crate::config::CompilationConfig;
14use crate::CompilerContext;
15
/// Lookup key for one cached compilation.
///
/// The key stores three 64-bit digests rather than the hashed inputs
/// themselves, so two keys compare equal exactly when all three digests match.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct CacheKey {
    /// Digest of the expression being compiled.
    expr_hash: u64,
    /// Digest of the compilation configuration.
    config_hash: u64,
    /// Digest of the domains registered in the compiler context.
    domain_hash: u64,
}
26
27impl CacheKey {
28 fn new(expr: &TLExpr, config: &CompilationConfig, ctx: &CompilerContext) -> Self {
30 use std::collections::hash_map::DefaultHasher;
31
32 let mut expr_hasher = DefaultHasher::new();
34 format!("{:?}", expr).hash(&mut expr_hasher);
35 let expr_hash = expr_hasher.finish();
36
37 let mut config_hasher = DefaultHasher::new();
39 format!("{:?}", config).hash(&mut config_hasher);
40 let config_hash = config_hasher.finish();
41
42 let mut domain_hasher = DefaultHasher::new();
44 for (name, domain) in &ctx.domains {
45 name.hash(&mut domain_hasher);
46 domain.cardinality.hash(&mut domain_hasher);
47 }
48 let domain_hash = domain_hasher.finish();
49
50 CacheKey {
51 expr_hash,
52 config_hash,
53 domain_hash,
54 }
55 }
56}
57
58#[derive(Clone)]
60struct CachedResult {
61 graph: EinsumGraph,
63 hit_count: usize,
65}
66
67pub struct CompilationCache {
99 cache: Arc<Mutex<HashMap<CacheKey, CachedResult>>>,
101 max_size: usize,
103 stats: Arc<Mutex<CacheStats>>,
105}
106
/// Usage counters reported by [`CompilationCache::stats`].
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups that were answered from the cache.
    pub hits: u64,
    /// Lookups that found no cached entry.
    pub misses: u64,
    /// Entries removed to make room for a new one.
    pub evictions: u64,
    /// Number of entries recorded after the most recent insertion or clear.
    pub current_size: usize,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in the range `0.0..=1.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded, avoiding a division
    /// by zero.
    pub fn hit_rate(&self) -> f64 {
        match self.total_lookups() {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }

    /// Total number of lookups recorded: hits plus misses.
    pub fn total_lookups(&self) -> u64 {
        self.hits + self.misses
    }
}
136
137impl CompilationCache {
138 pub fn new(max_size: usize) -> Self {
153 Self {
154 cache: Arc::new(Mutex::new(HashMap::new())),
155 max_size,
156 stats: Arc::new(Mutex::new(CacheStats::default())),
157 }
158 }
159
160 pub fn default_size() -> Self {
162 Self::new(1000)
163 }
164
165 pub fn max_size(&self) -> usize {
167 self.max_size
168 }
169
170 pub fn get_or_compile<F>(
199 &self,
200 expr: &TLExpr,
201 ctx: &mut CompilerContext,
202 compile_fn: F,
203 ) -> Result<EinsumGraph>
204 where
205 F: FnOnce(&TLExpr, &mut CompilerContext) -> Result<EinsumGraph>,
206 {
207 let key = CacheKey::new(expr, &ctx.config, ctx);
208
209 {
211 let mut cache = self.cache.lock().unwrap();
212 if let Some(cached) = cache.get_mut(&key) {
213 cached.hit_count += 1;
215 let mut stats = self.stats.lock().unwrap();
216 stats.hits += 1;
217 return Ok(cached.graph.clone());
218 }
219 }
220
221 let mut stats = self.stats.lock().unwrap();
223 stats.misses += 1;
224 drop(stats);
225
226 let graph = compile_fn(expr, ctx)?;
227
228 {
230 let mut cache = self.cache.lock().unwrap();
231
232 if cache.len() >= self.max_size {
234 let min_key = cache
236 .iter()
237 .min_by_key(|(_, v)| v.hit_count)
238 .map(|(k, _)| k.clone());
239
240 if let Some(key_to_evict) = min_key {
241 cache.remove(&key_to_evict);
242 let mut stats = self.stats.lock().unwrap();
243 stats.evictions += 1;
244 }
245 }
246
247 cache.insert(
248 key,
249 CachedResult {
250 graph: graph.clone(),
251 hit_count: 0,
252 },
253 );
254
255 let mut stats = self.stats.lock().unwrap();
256 stats.current_size = cache.len();
257 }
258
259 Ok(graph)
260 }
261
262 pub fn stats(&self) -> CacheStats {
275 self.stats.lock().unwrap().clone()
276 }
277
278 pub fn clear(&self) {
290 let mut cache = self.cache.lock().unwrap();
291 cache.clear();
292 let mut stats = self.stats.lock().unwrap();
293 stats.current_size = 0;
294 }
295
296 pub fn len(&self) -> usize {
298 self.cache.lock().unwrap().len()
299 }
300
301 pub fn is_empty(&self) -> bool {
303 self.len() == 0
304 }
305}
306
307impl Default for CompilationCache {
308 fn default() -> Self {
309 Self::default_size()
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compile_to_einsum_with_context;
    use tensorlogic_ir::Term;

    /// Builds a compiler context with the single `Person` domain shared by the
    /// tests that compile expressions.
    fn person_ctx() -> CompilerContext {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);
        ctx
    }

    /// A new cache reports its capacity and starts out empty.
    #[test]
    fn test_cache_new() {
        let cache = CompilationCache::new(100);
        assert_eq!(cache.max_size(), 100);
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    /// Compiling the same expression twice yields one miss, then one hit, and
    /// both calls return equal graphs.
    #[test]
    fn test_cache_hit() {
        let cache = CompilationCache::new(100);
        let mut ctx = person_ctx();

        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);

        let graph1 = cache
            .get_or_compile(&expr, &mut ctx, compile_to_einsum_with_context)
            .unwrap();

        let stats = cache.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);

        let graph2 = cache
            .get_or_compile(&expr, &mut ctx, compile_to_einsum_with_context)
            .unwrap();

        let stats = cache.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.hit_rate(), 0.5);
        assert_eq!(cache.len(), 1);

        assert_eq!(graph1, graph2);
    }

    /// Distinct expressions occupy distinct cache entries.
    #[test]
    fn test_cache_different_expressions() {
        let cache = CompilationCache::new(100);
        let mut ctx = person_ctx();

        let expr1 = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
        let expr2 = TLExpr::pred("likes", vec![Term::var("x"), Term::var("y")]);

        let _graph1 = cache
            .get_or_compile(&expr1, &mut ctx, compile_to_einsum_with_context)
            .unwrap();
        let _graph2 = cache
            .get_or_compile(&expr2, &mut ctx, compile_to_einsum_with_context)
            .unwrap();

        let stats = cache.stats();
        assert_eq!(stats.misses, 2);
        assert_eq!(stats.hits, 0);
        assert_eq!(cache.len(), 2);
    }

    /// Inserting a third entry into a two-entry cache evicts exactly one entry.
    #[test]
    fn test_cache_eviction() {
        let cache = CompilationCache::new(2);
        let mut ctx = person_ctx();

        let expr1 = TLExpr::pred("p1", vec![Term::var("x")]);
        let expr2 = TLExpr::pred("p2", vec![Term::var("x")]);
        let expr3 = TLExpr::pred("p3", vec![Term::var("x")]);

        // Unwrap each result so a compilation failure is reported directly
        // instead of surfacing later as a confusing count mismatch.
        for expr in [&expr1, &expr2, &expr3] {
            cache
                .get_or_compile(expr, &mut ctx, compile_to_einsum_with_context)
                .unwrap();
        }

        let stats = cache.stats();
        assert_eq!(stats.evictions, 1);
        assert_eq!(stats.misses, 3);
        assert_eq!(stats.current_size, 2);
        assert_eq!(cache.len(), 2);
    }

    /// `clear` empties the cache and resets `current_size`, but keeps the
    /// lookup counters.
    #[test]
    fn test_cache_clear() {
        let cache = CompilationCache::new(100);
        let mut ctx = person_ctx();

        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);

        cache
            .get_or_compile(&expr, &mut ctx, compile_to_einsum_with_context)
            .unwrap();

        assert_eq!(cache.len(), 1);
        assert_eq!(cache.stats().current_size, 1);

        cache.clear();

        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());

        let stats = cache.stats();
        assert_eq!(stats.current_size, 0);
        assert_eq!(stats.misses, 1);
    }

    /// A fresh cache reports all-zero statistics.
    #[test]
    fn test_cache_stats() {
        let cache = CompilationCache::new(100);
        let stats = cache.stats();

        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.evictions, 0);
        assert_eq!(stats.current_size, 0);
        assert_eq!(stats.hit_rate(), 0.0);
        assert_eq!(stats.total_lookups(), 0);
    }
}