1use crate::{compile_to_einsum_with_context, CompilerContext};
51use std::collections::{HashMap, HashSet};
52use std::sync::{Arc, Mutex};
53use tensorlogic_ir::{EinsumGraph, IrError, TLExpr, Term};
54
/// The names an expression was found to reference, plus a fingerprint of the
/// compiler configuration it was analyzed under.
///
/// Instances are produced by `ExpressionDependencies::analyze`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExpressionDependencies {
    /// Names of every predicate that appears in the expression.
    pub predicates: HashSet<String>,
    /// Names of variables: those used as predicate arguments and those
    /// introduced by quantifiers, aggregates (including group-by lists) and `let`.
    pub variables: HashSet<String>,
    /// Names of the domains that quantifiers and aggregates range over.
    pub domains: HashSet<String>,
    /// Hash of the compiler configuration at analysis time (`0` for a fresh,
    /// un-analyzed record).
    pub config_hash: u64,
}
67
68impl ExpressionDependencies {
69 pub fn new() -> Self {
71 Self {
72 predicates: HashSet::new(),
73 variables: HashSet::new(),
74 domains: HashSet::new(),
75 config_hash: 0,
76 }
77 }
78
79 pub fn analyze(expr: &TLExpr, ctx: &CompilerContext) -> Self {
81 let mut deps = Self::new();
82 deps.analyze_recursive(expr);
83 deps.config_hash = Self::hash_config(ctx);
84 deps
85 }
86
87 fn analyze_recursive(&mut self, expr: &TLExpr) {
88 match expr {
89 TLExpr::Pred { name, args } => {
90 self.predicates.insert(name.clone());
91 for arg in args {
92 self.analyze_term(arg);
93 }
94 }
95 TLExpr::And(left, right) | TLExpr::Or(left, right) | TLExpr::Imply(left, right) => {
96 self.analyze_recursive(left);
97 self.analyze_recursive(right);
98 }
99 TLExpr::Not(inner) => {
100 self.analyze_recursive(inner);
101 }
102 TLExpr::Exists { var, domain, body } | TLExpr::ForAll { var, domain, body } => {
103 self.variables.insert(var.clone());
104 self.domains.insert(domain.clone());
105 self.analyze_recursive(body);
106 }
107 TLExpr::Score(inner) => {
108 self.analyze_recursive(inner);
109 }
110 TLExpr::Add(left, right)
111 | TLExpr::Sub(left, right)
112 | TLExpr::Mul(left, right)
113 | TLExpr::Div(left, right) => {
114 self.analyze_recursive(left);
115 self.analyze_recursive(right);
116 }
117 TLExpr::Eq(left, right)
118 | TLExpr::Lt(left, right)
119 | TLExpr::Gt(left, right)
120 | TLExpr::Lte(left, right)
121 | TLExpr::Gte(left, right) => {
122 self.analyze_recursive(left);
123 self.analyze_recursive(right);
124 }
125 TLExpr::IfThenElse {
126 condition,
127 then_branch,
128 else_branch,
129 } => {
130 self.analyze_recursive(condition);
131 self.analyze_recursive(then_branch);
132 self.analyze_recursive(else_branch);
133 }
134 TLExpr::Aggregate {
135 op: _,
136 var,
137 domain,
138 body,
139 group_by,
140 } => {
141 self.variables.insert(var.clone());
142 self.domains.insert(domain.clone());
143 self.analyze_recursive(body);
144 if let Some(gb_vars) = group_by {
145 for var_name in gb_vars {
146 self.variables.insert(var_name.clone());
147 }
148 }
149 }
150 TLExpr::TNorm {
151 kind: _,
152 left,
153 right,
154 }
155 | TLExpr::TCoNorm {
156 kind: _,
157 left,
158 right,
159 } => {
160 self.analyze_recursive(left);
161 self.analyze_recursive(right);
162 }
163 TLExpr::FuzzyNot {
164 kind: _,
165 expr: inner,
166 } => {
167 self.analyze_recursive(inner);
168 }
169 TLExpr::FuzzyImplication {
170 kind: _,
171 premise,
172 conclusion,
173 } => {
174 self.analyze_recursive(premise);
175 self.analyze_recursive(conclusion);
176 }
177 TLExpr::SoftExists {
178 var,
179 domain,
180 body,
181 temperature: _,
182 }
183 | TLExpr::SoftForAll {
184 var,
185 domain,
186 body,
187 temperature: _,
188 } => {
189 self.variables.insert(var.clone());
190 self.domains.insert(domain.clone());
191 self.analyze_recursive(body);
192 }
193 TLExpr::WeightedRule { weight: _, rule } => {
194 self.analyze_recursive(rule);
195 }
196 TLExpr::ProbabilisticChoice { alternatives } => {
197 for (_, alt) in alternatives {
198 self.analyze_recursive(alt);
199 }
200 }
201 TLExpr::Let { var, value, body } => {
202 self.variables.insert(var.clone());
203 self.analyze_recursive(value);
204 self.analyze_recursive(body);
205 }
206 TLExpr::Box(inner)
207 | TLExpr::Diamond(inner)
208 | TLExpr::Next(inner)
209 | TLExpr::Eventually(inner)
210 | TLExpr::Always(inner) => {
211 self.analyze_recursive(inner);
212 }
213 TLExpr::Until { before, after } | TLExpr::WeakUntil { before, after } => {
214 self.analyze_recursive(before);
215 self.analyze_recursive(after);
216 }
217 TLExpr::Release { released, releaser }
218 | TLExpr::StrongRelease { released, releaser } => {
219 self.analyze_recursive(released);
220 self.analyze_recursive(releaser);
221 }
222 TLExpr::Abs(inner)
224 | TLExpr::Sqrt(inner)
225 | TLExpr::Exp(inner)
226 | TLExpr::Log(inner)
227 | TLExpr::Sin(inner)
228 | TLExpr::Cos(inner)
229 | TLExpr::Tan(inner)
230 | TLExpr::Floor(inner)
231 | TLExpr::Ceil(inner)
232 | TLExpr::Round(inner) => {
233 self.analyze_recursive(inner);
234 }
235 TLExpr::Pow(left, right)
236 | TLExpr::Min(left, right)
237 | TLExpr::Max(left, right)
238 | TLExpr::Mod(left, right) => {
239 self.analyze_recursive(left);
240 self.analyze_recursive(right);
241 }
242 TLExpr::Constant(_) => {
243 }
245 _ => {
247 }
250 }
251 }
252
253 fn analyze_term(&mut self, term: &Term) {
254 if let Term::Var(name) = term {
255 self.variables.insert(name.clone());
256 }
257 }
258
259 fn hash_config(ctx: &CompilerContext) -> u64 {
260 use std::collections::hash_map::DefaultHasher;
261 use std::hash::{Hash, Hasher};
262
263 let mut hasher = DefaultHasher::new();
264 format!("{:?}", ctx.config).hash(&mut hasher);
266 hasher.finish()
267 }
268}
269
270impl Default for ExpressionDependencies {
271 fn default() -> Self {
272 Self::new()
273 }
274}
275
/// Remembers a snapshot of a compiler context so that a later context state
/// can be compared against it.
#[derive(Debug, Clone)]
pub struct ChangeDetector {
    /// Per-predicate snapshot keyed by predicate name.
    ///
    /// NOTE(review): in the code visible here this map is only ever cleared,
    /// never filled or read — confirm whether predicate tracking is still to
    /// be implemented. The tuple presumably holds (arity, argument domains);
    /// TODO confirm.
    previous_predicates: HashMap<String, (usize, Vec<String>)>,
    /// Cardinality of each domain at the time of the last snapshot, keyed by
    /// domain name.
    previous_domains: HashMap<String, usize>,
    /// Configuration hash at the time of the last snapshot (`0` before any
    /// snapshot has been taken).
    previous_config_hash: u64,
}
286
287impl ChangeDetector {
288 pub fn new() -> Self {
290 Self {
291 previous_predicates: HashMap::new(),
292 previous_domains: HashMap::new(),
293 previous_config_hash: 0,
294 }
295 }
296
297 pub fn update(&mut self, ctx: &CompilerContext) {
299 self.previous_predicates.clear();
300 self.previous_domains.clear();
301
302 for (name, info) in &ctx.domains {
304 self.previous_domains.insert(name.clone(), info.cardinality);
305 }
306
307 self.previous_config_hash = ExpressionDependencies::hash_config(ctx);
308 }
309
310 pub fn detect_changes(&self, ctx: &CompilerContext) -> ChangeSet {
312 let mut changes = ChangeSet::new();
313
314 for (name, info) in &ctx.domains {
316 if let Some(&prev_size) = self.previous_domains.get(name.as_str()) {
317 if prev_size != info.cardinality {
318 changes.changed_domains.insert(name.clone());
319 }
320 } else {
321 changes.new_domains.insert(name.clone());
322 }
323 }
324
325 for name in self.previous_domains.keys() {
327 if !ctx.domains.contains_key(name) {
328 changes.removed_domains.insert(name.clone());
329 }
330 }
331
332 let current_hash = ExpressionDependencies::hash_config(ctx);
334 if current_hash != self.previous_config_hash {
335 changes.config_changed = true;
336 }
337
338 changes
339 }
340}
341
342impl Default for ChangeDetector {
343 fn default() -> Self {
344 Self::new()
345 }
346}
347
/// The differences found between a stored snapshot and the current compiler
/// context. The derived `Default` is the "nothing changed" value.
#[derive(Debug, Clone, Default)]
pub struct ChangeSet {
    /// Predicates that were not present in the snapshot.
    pub new_predicates: HashSet<String>,
    /// Predicates whose definition differs from the snapshot.
    pub changed_predicates: HashSet<String>,
    /// Predicates present in the snapshot but no longer in the context.
    pub removed_predicates: HashSet<String>,
    /// Domains that were not present in the snapshot.
    pub new_domains: HashSet<String>,
    /// Domains whose cardinality differs from the snapshot.
    pub changed_domains: HashSet<String>,
    /// Domains present in the snapshot but no longer in the context.
    pub removed_domains: HashSet<String>,
    /// `true` when the configuration hash differs from the snapshot.
    pub config_changed: bool,
}
366
367impl ChangeSet {
368 fn new() -> Self {
369 Self::default()
370 }
371
372 pub fn has_changes(&self) -> bool {
374 !self.new_predicates.is_empty()
375 || !self.changed_predicates.is_empty()
376 || !self.removed_predicates.is_empty()
377 || !self.new_domains.is_empty()
378 || !self.changed_domains.is_empty()
379 || !self.removed_domains.is_empty()
380 || self.config_changed
381 }
382
383 pub fn affects(&self, deps: &ExpressionDependencies) -> bool {
385 if self.config_changed {
387 return true;
388 }
389
390 for pred in &deps.predicates {
392 if self.changed_predicates.contains(pred) || self.removed_predicates.contains(pred) {
393 return true;
394 }
395 }
396
397 for domain in &deps.domains {
399 if self.changed_domains.contains(domain) || self.removed_domains.contains(domain) {
400 return true;
401 }
402 }
403
404 false
405 }
406}
407
408#[derive(Debug, Clone)]
410struct CacheEntry {
411 graph: EinsumGraph,
413 dependencies: ExpressionDependencies,
415 #[allow(dead_code)]
417 timestamp: u64,
418}
419
420pub struct IncrementalCompiler {
422 context: CompilerContext,
424 cache: Arc<Mutex<HashMap<String, CacheEntry>>>,
426 change_detector: ChangeDetector,
428 stats: Arc<Mutex<IncrementalStats>>,
430 next_timestamp: Arc<Mutex<u64>>,
432}
433
434impl IncrementalCompiler {
435 pub fn new(context: CompilerContext) -> Self {
437 let mut change_detector = ChangeDetector::new();
438 change_detector.update(&context);
439
440 Self {
441 context,
442 cache: Arc::new(Mutex::new(HashMap::new())),
443 change_detector,
444 stats: Arc::new(Mutex::new(IncrementalStats::default())),
445 next_timestamp: Arc::new(Mutex::new(0)),
446 }
447 }
448
449 pub fn context(&self) -> &CompilerContext {
451 &self.context
452 }
453
454 pub fn context_mut(&mut self) -> &mut CompilerContext {
456 &mut self.context
457 }
458
459 pub fn compile(&mut self, expr: &TLExpr) -> Result<EinsumGraph, IrError> {
461 let changes = self.change_detector.detect_changes(&self.context);
463
464 if changes.has_changes() {
466 self.invalidate_affected(&changes);
467 self.change_detector.update(&self.context);
468 }
469
470 let expr_key = format!("{:?}", expr);
472 let cache = self.cache.lock().expect("lock should not be poisoned");
473
474 if let Some(entry) = cache.get(&expr_key) {
475 let mut stats = self.stats.lock().expect("lock should not be poisoned");
477 stats.cache_hits += 1;
478 stats.nodes_reused += entry.graph.nodes.len();
479 drop(stats);
480
481 return Ok(entry.graph.clone());
482 }
483
484 drop(cache);
486
487 let deps = ExpressionDependencies::analyze(expr, &self.context);
488 let graph = compile_to_einsum_with_context(expr, &mut self.context).map_err(|e| {
491 IrError::InvalidEinsumSpec {
492 spec: format!("{:?}", expr),
493 reason: format!("Compilation failed: {}", e),
494 }
495 })?;
496
497 let mut stats = self.stats.lock().expect("lock should not be poisoned");
499 stats.cache_misses += 1;
500 stats.nodes_compiled += graph.nodes.len();
501 drop(stats);
502
503 let mut timestamp_guard = self
505 .next_timestamp
506 .lock()
507 .expect("lock should not be poisoned");
508 let timestamp = *timestamp_guard;
509 *timestamp_guard += 1;
510 drop(timestamp_guard);
511
512 let mut cache = self.cache.lock().expect("lock should not be poisoned");
513 cache.insert(
514 expr_key,
515 CacheEntry {
516 graph: graph.clone(),
517 dependencies: deps,
518 timestamp,
519 },
520 );
521
522 Ok(graph)
523 }
524
525 fn invalidate_affected(&mut self, changes: &ChangeSet) {
527 let mut cache = self.cache.lock().expect("lock should not be poisoned");
528 cache.retain(|_, entry| !changes.affects(&entry.dependencies));
529
530 let mut stats = self.stats.lock().expect("lock should not be poisoned");
531 stats.invalidations += 1;
532 }
533
534 pub fn clear_cache(&mut self) {
536 let mut cache = self.cache.lock().expect("lock should not be poisoned");
537 cache.clear();
538 }
539
540 pub fn stats(&self) -> IncrementalStats {
542 self.stats
543 .lock()
544 .expect("lock should not be poisoned")
545 .clone()
546 }
547
548 pub fn reset_stats(&mut self) {
550 let mut stats = self.stats.lock().expect("lock should not be poisoned");
551 *stats = IncrementalStats::default();
552 }
553}
554
/// Counters describing how the incremental compiler's cache has been used.
#[derive(Debug, Clone, Default)]
pub struct IncrementalStats {
    /// Compilations answered from the cache.
    pub cache_hits: usize,
    /// Compilations that had to run the compiler.
    pub cache_misses: usize,
    /// Number of invalidation passes performed.
    pub invalidations: usize,
    /// Total graph nodes returned from cached graphs.
    pub nodes_reused: usize,
    /// Total graph nodes produced by fresh compilations.
    pub nodes_compiled: usize,
}

impl IncrementalStats {
    /// Fraction of compilations served from the cache, in `0.0..=1.0`.
    /// Returns `0.0` when nothing has been compiled yet.
    pub fn hit_rate(&self) -> f64 {
        match self.cache_hits + self.cache_misses {
            0 => 0.0,
            total => self.cache_hits as f64 / total as f64,
        }
    }

    /// Fraction of graph nodes that came from the cache rather than from a
    /// fresh compilation. Returns `0.0` when no nodes have been counted.
    pub fn reuse_rate(&self) -> f64 {
        match self.nodes_reused + self.nodes_compiled {
            0 => 0.0,
            total => self.nodes_reused as f64 / total as f64,
        }
    }

    /// Number of counted compile requests: hits plus misses.
    pub fn total_compilations(&self) -> usize {
        self.cache_hits + self.cache_misses
    }
}
596
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh context with one domain, registered via `add_domain("Person", 100)`.
    fn person_context() -> CompilerContext {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);
        ctx
    }

    /// Builds the two-argument predicate `name(first, second)` over variables.
    fn binary_pred(name: &str, first: &str, second: &str) -> TLExpr {
        TLExpr::pred(name, vec![Term::var(first), Term::var(second)])
    }

    /// A single predicate contributes its name and its variable arguments.
    #[test]
    fn test_dependency_tracking() {
        let ctx = person_context();
        let knows = binary_pred("knows", "x", "y");

        let deps = ExpressionDependencies::analyze(&knows, &ctx);

        assert!(deps.predicates.contains("knows"));
        for var in ["x", "y"] {
            assert!(deps.variables.contains(var));
        }
    }

    /// Compiling the same expression twice yields one miss followed by one hit.
    #[test]
    fn test_incremental_compilation_reuse() {
        let mut compiler = IncrementalCompiler::new(person_context());
        let knows = binary_pred("knows", "x", "y");

        // First request: nothing cached yet.
        let _first = compiler.compile(&knows).expect("unwrap");
        let after_first = compiler.stats();
        assert_eq!(after_first.cache_misses, 1);
        assert_eq!(after_first.cache_hits, 0);

        // Second request: served from the cache.
        let _second = compiler.compile(&knows).expect("unwrap");
        let after_second = compiler.stats();
        assert_eq!(after_second.cache_misses, 1);
        assert_eq!(after_second.cache_hits, 1);
        assert_eq!(after_second.hit_rate(), 0.5);
    }

    /// Re-registering a domain with a different size is reported as a change.
    #[test]
    fn test_change_detection_domain() {
        let mut ctx = person_context();

        let mut detector = ChangeDetector::new();
        detector.update(&ctx);

        // Same context as the snapshot: nothing to report.
        assert!(!detector.detect_changes(&ctx).has_changes());

        ctx.add_domain("Person", 200);

        let delta = detector.detect_changes(&ctx);
        assert!(delta.has_changes());
        assert!(delta.changed_domains.contains("Person"));
    }

    /// A domain change between two compilations triggers an invalidation pass.
    #[test]
    fn test_invalidation_on_domain_change() {
        let mut compiler = IncrementalCompiler::new(person_context());
        let knows = binary_pred("knows", "x", "y");

        let _before = compiler.compile(&knows).expect("unwrap");
        assert_eq!(compiler.stats().cache_misses, 1);

        compiler.context_mut().add_domain("Person", 200);

        let _after = compiler.compile(&knows).expect("unwrap");
        let stats = compiler.stats();
        assert!(stats.cache_misses >= 1);
        assert!(stats.invalidations >= 1);
    }

    /// Hits and misses add up to the number of compile requests.
    #[test]
    fn test_incremental_stats() {
        let mut compiler = IncrementalCompiler::new(person_context());

        let knows = binary_pred("knows", "x", "y");
        let likes = binary_pred("likes", "x", "z");

        // `knows` twice, then `likes` once.
        for expr in [&knows, &knows, &likes] {
            compiler.compile(expr).expect("unwrap");
        }

        let stats = compiler.stats();
        assert_eq!(stats.total_compilations(), 3);
        assert!(
            stats.cache_hits >= 1,
            "Expected at least 1 cache hit, got {}",
            stats.cache_hits
        );
        assert!(
            stats.hit_rate() > 0.0,
            "Expected positive hit rate, got {}",
            stats.hit_rate()
        );
    }

    /// Dependencies are collected through a quantifier and a conjunction.
    #[test]
    fn test_complex_expression_dependencies() {
        let ctx = person_context();

        let conjunction = TLExpr::and(
            binary_pred("knows", "x", "y"),
            binary_pred("likes", "x", "z"),
        );
        let quantified = TLExpr::exists("x", "Person", conjunction);

        let deps = ExpressionDependencies::analyze(&quantified, &ctx);

        for pred in ["knows", "likes"] {
            assert!(deps.predicates.contains(pred));
        }
        for var in ["x", "y", "z"] {
            assert!(deps.variables.contains(var));
        }
        assert!(deps.domains.contains("Person"));
    }
}