1use crate::{compile_to_einsum_with_context, CompilerContext};
51use std::collections::{HashMap, HashSet};
52use std::sync::{Arc, Mutex};
53use tensorlogic_ir::{EinsumGraph, IrError, TLExpr, Term};
54
55#[derive(Debug, Clone, PartialEq, Eq)]
57pub struct ExpressionDependencies {
58 pub predicates: HashSet<String>,
60 pub variables: HashSet<String>,
62 pub domains: HashSet<String>,
64 pub config_hash: u64,
66}
67
68impl ExpressionDependencies {
69 pub fn new() -> Self {
71 Self {
72 predicates: HashSet::new(),
73 variables: HashSet::new(),
74 domains: HashSet::new(),
75 config_hash: 0,
76 }
77 }
78
79 pub fn analyze(expr: &TLExpr, ctx: &CompilerContext) -> Self {
81 let mut deps = Self::new();
82 deps.analyze_recursive(expr);
83 deps.config_hash = Self::hash_config(ctx);
84 deps
85 }
86
87 fn analyze_recursive(&mut self, expr: &TLExpr) {
88 match expr {
89 TLExpr::Pred { name, args } => {
90 self.predicates.insert(name.clone());
91 for arg in args {
92 self.analyze_term(arg);
93 }
94 }
95 TLExpr::And(left, right) | TLExpr::Or(left, right) | TLExpr::Imply(left, right) => {
96 self.analyze_recursive(left);
97 self.analyze_recursive(right);
98 }
99 TLExpr::Not(inner) => {
100 self.analyze_recursive(inner);
101 }
102 TLExpr::Exists { var, domain, body } | TLExpr::ForAll { var, domain, body } => {
103 self.variables.insert(var.clone());
104 self.domains.insert(domain.clone());
105 self.analyze_recursive(body);
106 }
107 TLExpr::Score(inner) => {
108 self.analyze_recursive(inner);
109 }
110 TLExpr::Add(left, right)
111 | TLExpr::Sub(left, right)
112 | TLExpr::Mul(left, right)
113 | TLExpr::Div(left, right) => {
114 self.analyze_recursive(left);
115 self.analyze_recursive(right);
116 }
117 TLExpr::Eq(left, right)
118 | TLExpr::Lt(left, right)
119 | TLExpr::Gt(left, right)
120 | TLExpr::Lte(left, right)
121 | TLExpr::Gte(left, right) => {
122 self.analyze_recursive(left);
123 self.analyze_recursive(right);
124 }
125 TLExpr::IfThenElse {
126 condition,
127 then_branch,
128 else_branch,
129 } => {
130 self.analyze_recursive(condition);
131 self.analyze_recursive(then_branch);
132 self.analyze_recursive(else_branch);
133 }
134 TLExpr::Aggregate {
135 op: _,
136 var,
137 domain,
138 body,
139 group_by,
140 } => {
141 self.variables.insert(var.clone());
142 self.domains.insert(domain.clone());
143 self.analyze_recursive(body);
144 if let Some(gb_vars) = group_by {
145 for var_name in gb_vars {
146 self.variables.insert(var_name.clone());
147 }
148 }
149 }
150 TLExpr::TNorm {
151 kind: _,
152 left,
153 right,
154 }
155 | TLExpr::TCoNorm {
156 kind: _,
157 left,
158 right,
159 } => {
160 self.analyze_recursive(left);
161 self.analyze_recursive(right);
162 }
163 TLExpr::FuzzyNot {
164 kind: _,
165 expr: inner,
166 } => {
167 self.analyze_recursive(inner);
168 }
169 TLExpr::FuzzyImplication {
170 kind: _,
171 premise,
172 conclusion,
173 } => {
174 self.analyze_recursive(premise);
175 self.analyze_recursive(conclusion);
176 }
177 TLExpr::SoftExists {
178 var,
179 domain,
180 body,
181 temperature: _,
182 }
183 | TLExpr::SoftForAll {
184 var,
185 domain,
186 body,
187 temperature: _,
188 } => {
189 self.variables.insert(var.clone());
190 self.domains.insert(domain.clone());
191 self.analyze_recursive(body);
192 }
193 TLExpr::WeightedRule { weight: _, rule } => {
194 self.analyze_recursive(rule);
195 }
196 TLExpr::ProbabilisticChoice { alternatives } => {
197 for (_, alt) in alternatives {
198 self.analyze_recursive(alt);
199 }
200 }
201 TLExpr::Let { var, value, body } => {
202 self.variables.insert(var.clone());
203 self.analyze_recursive(value);
204 self.analyze_recursive(body);
205 }
206 TLExpr::Box(inner)
207 | TLExpr::Diamond(inner)
208 | TLExpr::Next(inner)
209 | TLExpr::Eventually(inner)
210 | TLExpr::Always(inner) => {
211 self.analyze_recursive(inner);
212 }
213 TLExpr::Until { before, after } | TLExpr::WeakUntil { before, after } => {
214 self.analyze_recursive(before);
215 self.analyze_recursive(after);
216 }
217 TLExpr::Release { released, releaser }
218 | TLExpr::StrongRelease { released, releaser } => {
219 self.analyze_recursive(released);
220 self.analyze_recursive(releaser);
221 }
222 TLExpr::Abs(inner)
224 | TLExpr::Sqrt(inner)
225 | TLExpr::Exp(inner)
226 | TLExpr::Log(inner)
227 | TLExpr::Sin(inner)
228 | TLExpr::Cos(inner)
229 | TLExpr::Tan(inner)
230 | TLExpr::Floor(inner)
231 | TLExpr::Ceil(inner)
232 | TLExpr::Round(inner) => {
233 self.analyze_recursive(inner);
234 }
235 TLExpr::Pow(left, right)
236 | TLExpr::Min(left, right)
237 | TLExpr::Max(left, right)
238 | TLExpr::Mod(left, right) => {
239 self.analyze_recursive(left);
240 self.analyze_recursive(right);
241 }
242 TLExpr::Constant(_) => {
243 }
245 _ => {
247 }
250 }
251 }
252
253 fn analyze_term(&mut self, term: &Term) {
254 if let Term::Var(name) = term {
255 self.variables.insert(name.clone());
256 }
257 }
258
259 fn hash_config(ctx: &CompilerContext) -> u64 {
260 use std::collections::hash_map::DefaultHasher;
261 use std::hash::{Hash, Hasher};
262
263 let mut hasher = DefaultHasher::new();
264 format!("{:?}", ctx.config).hash(&mut hasher);
266 hasher.finish()
267 }
268}
269
270impl Default for ExpressionDependencies {
271 fn default() -> Self {
272 Self::new()
273 }
274}
275
276#[derive(Debug, Clone)]
278pub struct ChangeDetector {
279 previous_predicates: HashMap<String, (usize, Vec<String>)>,
281 previous_domains: HashMap<String, usize>,
283 previous_config_hash: u64,
285}
286
287impl ChangeDetector {
288 pub fn new() -> Self {
290 Self {
291 previous_predicates: HashMap::new(),
292 previous_domains: HashMap::new(),
293 previous_config_hash: 0,
294 }
295 }
296
297 pub fn update(&mut self, ctx: &CompilerContext) {
299 self.previous_predicates.clear();
300 self.previous_domains.clear();
301
302 for (name, info) in &ctx.domains {
304 self.previous_domains.insert(name.clone(), info.cardinality);
305 }
306
307 self.previous_config_hash = ExpressionDependencies::hash_config(ctx);
308 }
309
310 pub fn detect_changes(&self, ctx: &CompilerContext) -> ChangeSet {
312 let mut changes = ChangeSet::new();
313
314 for (name, info) in &ctx.domains {
316 if let Some(&prev_size) = self.previous_domains.get(name.as_str()) {
317 if prev_size != info.cardinality {
318 changes.changed_domains.insert(name.clone());
319 }
320 } else {
321 changes.new_domains.insert(name.clone());
322 }
323 }
324
325 for name in self.previous_domains.keys() {
327 if !ctx.domains.contains_key(name) {
328 changes.removed_domains.insert(name.clone());
329 }
330 }
331
332 let current_hash = ExpressionDependencies::hash_config(ctx);
334 if current_hash != self.previous_config_hash {
335 changes.config_changed = true;
336 }
337
338 changes
339 }
340}
341
342impl Default for ChangeDetector {
343 fn default() -> Self {
344 Self::new()
345 }
346}
347
348#[derive(Debug, Clone, Default)]
350pub struct ChangeSet {
351 pub new_predicates: HashSet<String>,
353 pub changed_predicates: HashSet<String>,
355 pub removed_predicates: HashSet<String>,
357 pub new_domains: HashSet<String>,
359 pub changed_domains: HashSet<String>,
361 pub removed_domains: HashSet<String>,
363 pub config_changed: bool,
365}
366
367impl ChangeSet {
368 fn new() -> Self {
369 Self::default()
370 }
371
372 pub fn has_changes(&self) -> bool {
374 !self.new_predicates.is_empty()
375 || !self.changed_predicates.is_empty()
376 || !self.removed_predicates.is_empty()
377 || !self.new_domains.is_empty()
378 || !self.changed_domains.is_empty()
379 || !self.removed_domains.is_empty()
380 || self.config_changed
381 }
382
383 pub fn affects(&self, deps: &ExpressionDependencies) -> bool {
385 if self.config_changed {
387 return true;
388 }
389
390 for pred in &deps.predicates {
392 if self.changed_predicates.contains(pred) || self.removed_predicates.contains(pred) {
393 return true;
394 }
395 }
396
397 for domain in &deps.domains {
399 if self.changed_domains.contains(domain) || self.removed_domains.contains(domain) {
400 return true;
401 }
402 }
403
404 false
405 }
406}
407
408#[derive(Debug, Clone)]
410struct CacheEntry {
411 graph: EinsumGraph,
413 dependencies: ExpressionDependencies,
415 #[allow(dead_code)]
417 timestamp: u64,
418}
419
420pub struct IncrementalCompiler {
422 context: CompilerContext,
424 cache: Arc<Mutex<HashMap<String, CacheEntry>>>,
426 change_detector: ChangeDetector,
428 stats: Arc<Mutex<IncrementalStats>>,
430 next_timestamp: Arc<Mutex<u64>>,
432}
433
434impl IncrementalCompiler {
435 pub fn new(context: CompilerContext) -> Self {
437 let mut change_detector = ChangeDetector::new();
438 change_detector.update(&context);
439
440 Self {
441 context,
442 cache: Arc::new(Mutex::new(HashMap::new())),
443 change_detector,
444 stats: Arc::new(Mutex::new(IncrementalStats::default())),
445 next_timestamp: Arc::new(Mutex::new(0)),
446 }
447 }
448
449 pub fn context(&self) -> &CompilerContext {
451 &self.context
452 }
453
454 pub fn context_mut(&mut self) -> &mut CompilerContext {
456 &mut self.context
457 }
458
459 pub fn compile(&mut self, expr: &TLExpr) -> Result<EinsumGraph, IrError> {
461 let changes = self.change_detector.detect_changes(&self.context);
463
464 if changes.has_changes() {
466 self.invalidate_affected(&changes);
467 self.change_detector.update(&self.context);
468 }
469
470 let expr_key = format!("{:?}", expr);
472 let cache = self.cache.lock().unwrap();
473
474 if let Some(entry) = cache.get(&expr_key) {
475 let mut stats = self.stats.lock().unwrap();
477 stats.cache_hits += 1;
478 stats.nodes_reused += entry.graph.nodes.len();
479 drop(stats);
480
481 return Ok(entry.graph.clone());
482 }
483
484 drop(cache);
486
487 let deps = ExpressionDependencies::analyze(expr, &self.context);
488 let graph = compile_to_einsum_with_context(expr, &mut self.context).map_err(|e| {
491 IrError::InvalidEinsumSpec {
492 spec: format!("{:?}", expr),
493 reason: format!("Compilation failed: {}", e),
494 }
495 })?;
496
497 let mut stats = self.stats.lock().unwrap();
499 stats.cache_misses += 1;
500 stats.nodes_compiled += graph.nodes.len();
501 drop(stats);
502
503 let mut timestamp_guard = self.next_timestamp.lock().unwrap();
505 let timestamp = *timestamp_guard;
506 *timestamp_guard += 1;
507 drop(timestamp_guard);
508
509 let mut cache = self.cache.lock().unwrap();
510 cache.insert(
511 expr_key,
512 CacheEntry {
513 graph: graph.clone(),
514 dependencies: deps,
515 timestamp,
516 },
517 );
518
519 Ok(graph)
520 }
521
522 fn invalidate_affected(&mut self, changes: &ChangeSet) {
524 let mut cache = self.cache.lock().unwrap();
525 cache.retain(|_, entry| !changes.affects(&entry.dependencies));
526
527 let mut stats = self.stats.lock().unwrap();
528 stats.invalidations += 1;
529 }
530
531 pub fn clear_cache(&mut self) {
533 let mut cache = self.cache.lock().unwrap();
534 cache.clear();
535 }
536
537 pub fn stats(&self) -> IncrementalStats {
539 self.stats.lock().unwrap().clone()
540 }
541
542 pub fn reset_stats(&mut self) {
544 let mut stats = self.stats.lock().unwrap();
545 *stats = IncrementalStats::default();
546 }
547}
548
/// Counters describing how effective the incremental cache has been.
#[derive(Debug, Clone, Default)]
pub struct IncrementalStats {
    /// Compilations answered from the cache.
    pub cache_hits: usize,
    /// Compilations that ran the underlying compiler successfully.
    pub cache_misses: usize,
    /// Number of invalidation passes (one per detected change set).
    pub invalidations: usize,
    /// Total graph nodes handed back from the cache.
    pub nodes_reused: usize,
    /// Total graph nodes produced by fresh compilations.
    pub nodes_compiled: usize,
}

impl IncrementalStats {
    /// Fraction of compilations served from the cache, or `0.0` if none ran.
    pub fn hit_rate(&self) -> f64 {
        Self::fraction(self.cache_hits, self.total_compilations())
    }

    /// Fraction of all emitted graph nodes that came from the cache, or `0.0`
    /// if no nodes were emitted.
    pub fn reuse_rate(&self) -> f64 {
        Self::fraction(self.nodes_reused, self.nodes_reused + self.nodes_compiled)
    }

    /// Number of recorded compilations, hits and misses together.
    pub fn total_compilations(&self) -> usize {
        self.cache_hits + self.cache_misses
    }

    /// `part / whole` as `f64`, defined as `0.0` when `whole` is zero.
    fn fraction(part: usize, whole: usize) -> f64 {
        match whole {
            0 => 0.0,
            _ => part as f64 / whole as f64,
        }
    }
}
590
#[cfg(test)]
mod tests {
    use super::*;

    /// Context with a single `Person` domain of cardinality 100.
    fn person_context() -> CompilerContext {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);
        ctx
    }

    /// `knows(x, y)`: references no domain.
    fn knows_xy() -> TLExpr {
        TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")])
    }

    /// `exists y in Person: knows(x, y)`: depends on the `Person` domain.
    fn exists_person_knows() -> TLExpr {
        TLExpr::exists("y", "Person", knows_xy())
    }

    #[test]
    fn test_dependency_tracking() {
        let ctx = person_context();
        let deps = ExpressionDependencies::analyze(&knows_xy(), &ctx);

        assert!(deps.predicates.contains("knows"));
        assert!(deps.variables.contains("x"));
        assert!(deps.variables.contains("y"));
    }

    #[test]
    fn test_incremental_compilation_reuse() {
        let mut compiler = IncrementalCompiler::new(person_context());
        let expr = knows_xy();

        compiler.compile(&expr).unwrap();
        assert_eq!(compiler.stats().cache_misses, 1);
        assert_eq!(compiler.stats().cache_hits, 0);

        compiler.compile(&expr).unwrap();
        assert_eq!(compiler.stats().cache_misses, 1);
        assert_eq!(compiler.stats().cache_hits, 1);
        assert_eq!(compiler.stats().hit_rate(), 0.5);
    }

    #[test]
    fn test_change_detection_domain() {
        let mut ctx = person_context();

        let mut detector = ChangeDetector::new();
        detector.update(&ctx);

        // Nothing changed since the snapshot.
        let changes = detector.detect_changes(&ctx);
        assert!(!changes.has_changes());

        // Cardinality change of an existing domain.
        ctx.add_domain("Person", 200);
        let changes = detector.detect_changes(&ctx);
        assert!(changes.has_changes());
        assert!(changes.changed_domains.contains("Person"));

        // A domain that did not exist at snapshot time.
        ctx.add_domain("Animal", 10);
        let changes = detector.detect_changes(&ctx);
        assert!(changes.new_domains.contains("Animal"));
    }

    #[test]
    fn test_invalidation_on_domain_change() {
        let mut compiler = IncrementalCompiler::new(person_context());
        // The expression must actually depend on `Person` for the change to
        // matter: `knows(x, y)` alone references no domain and would stay cached.
        let expr = exists_person_knows();

        compiler.compile(&expr).unwrap();
        assert_eq!(compiler.stats().cache_misses, 1);

        compiler.context_mut().add_domain("Person", 200);

        compiler.compile(&expr).unwrap();
        let stats = compiler.stats();
        assert_eq!(stats.invalidations, 1);
        assert_eq!(stats.cache_hits, 0, "stale entry must not be served");
        assert_eq!(
            stats.cache_misses, 2,
            "entry depending on `Person` must be recompiled"
        );
    }

    #[test]
    fn test_unrelated_domain_change_keeps_cache() {
        let mut compiler = IncrementalCompiler::new(person_context());
        let expr = exists_person_knows();

        compiler.compile(&expr).unwrap();
        compiler.context_mut().add_domain("City", 50);
        compiler.compile(&expr).unwrap();

        let stats = compiler.stats();
        assert_eq!(stats.cache_misses, 1);
        assert_eq!(stats.cache_hits, 1, "`City` is not a dependency of the entry");
    }

    #[test]
    fn test_incremental_stats() {
        let mut compiler = IncrementalCompiler::new(person_context());

        let expr1 = knows_xy();
        let expr2 = TLExpr::pred("likes", vec![Term::var("x"), Term::var("z")]);

        compiler.compile(&expr1).unwrap();
        compiler.compile(&expr1).unwrap();
        compiler.compile(&expr2).unwrap();

        let stats = compiler.stats();
        assert_eq!(stats.total_compilations(), 3);
        assert_eq!(stats.cache_hits, 1, "second compile of expr1 should hit");
        assert_eq!(stats.cache_misses, 2);
        assert!(
            stats.hit_rate() > 0.0,
            "Expected positive hit rate, got {}",
            stats.hit_rate()
        );
    }

    #[test]
    fn test_complex_expression_dependencies() {
        let ctx = person_context();

        let expr = TLExpr::exists(
            "x",
            "Person",
            TLExpr::and(
                TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]),
                TLExpr::pred("likes", vec![Term::var("x"), Term::var("z")]),
            ),
        );

        let deps = ExpressionDependencies::analyze(&expr, &ctx);

        assert!(deps.predicates.contains("knows"));
        assert!(deps.predicates.contains("likes"));
        assert!(deps.variables.contains("x"));
        assert!(deps.variables.contains("y"));
        assert!(deps.variables.contains("z"));
        assert!(deps.domains.contains("Person"));
    }
}