1use crate::{compile_to_einsum_with_context, CompilerContext};
51use std::collections::{HashMap, HashSet};
52use std::sync::{Arc, Mutex};
53use tensorlogic_ir::{EinsumGraph, IrError, TLExpr, Term};
54
55#[derive(Debug, Clone, PartialEq, Eq)]
57pub struct ExpressionDependencies {
58 pub predicates: HashSet<String>,
60 pub variables: HashSet<String>,
62 pub domains: HashSet<String>,
64 pub config_hash: u64,
66}
67
68impl ExpressionDependencies {
69 pub fn new() -> Self {
71 Self {
72 predicates: HashSet::new(),
73 variables: HashSet::new(),
74 domains: HashSet::new(),
75 config_hash: 0,
76 }
77 }
78
79 pub fn analyze(expr: &TLExpr, ctx: &CompilerContext) -> Self {
81 let mut deps = Self::new();
82 deps.analyze_recursive(expr);
83 deps.config_hash = Self::hash_config(ctx);
84 deps
85 }
86
87 fn analyze_recursive(&mut self, expr: &TLExpr) {
88 match expr {
89 TLExpr::Pred { name, args } => {
90 self.predicates.insert(name.clone());
91 for arg in args {
92 self.analyze_term(arg);
93 }
94 }
95 TLExpr::And(left, right) | TLExpr::Or(left, right) | TLExpr::Imply(left, right) => {
96 self.analyze_recursive(left);
97 self.analyze_recursive(right);
98 }
99 TLExpr::Not(inner) => {
100 self.analyze_recursive(inner);
101 }
102 TLExpr::Exists { var, domain, body } | TLExpr::ForAll { var, domain, body } => {
103 self.variables.insert(var.clone());
104 self.domains.insert(domain.clone());
105 self.analyze_recursive(body);
106 }
107 TLExpr::Score(inner) => {
108 self.analyze_recursive(inner);
109 }
110 TLExpr::Add(left, right)
111 | TLExpr::Sub(left, right)
112 | TLExpr::Mul(left, right)
113 | TLExpr::Div(left, right) => {
114 self.analyze_recursive(left);
115 self.analyze_recursive(right);
116 }
117 TLExpr::Eq(left, right)
118 | TLExpr::Lt(left, right)
119 | TLExpr::Gt(left, right)
120 | TLExpr::Lte(left, right)
121 | TLExpr::Gte(left, right) => {
122 self.analyze_recursive(left);
123 self.analyze_recursive(right);
124 }
125 TLExpr::IfThenElse {
126 condition,
127 then_branch,
128 else_branch,
129 } => {
130 self.analyze_recursive(condition);
131 self.analyze_recursive(then_branch);
132 self.analyze_recursive(else_branch);
133 }
134 TLExpr::Aggregate {
135 op: _,
136 var,
137 domain,
138 body,
139 group_by,
140 } => {
141 self.variables.insert(var.clone());
142 self.domains.insert(domain.clone());
143 self.analyze_recursive(body);
144 if let Some(gb_vars) = group_by {
145 for var_name in gb_vars {
146 self.variables.insert(var_name.clone());
147 }
148 }
149 }
150 TLExpr::TNorm {
151 kind: _,
152 left,
153 right,
154 }
155 | TLExpr::TCoNorm {
156 kind: _,
157 left,
158 right,
159 } => {
160 self.analyze_recursive(left);
161 self.analyze_recursive(right);
162 }
163 TLExpr::FuzzyNot {
164 kind: _,
165 expr: inner,
166 } => {
167 self.analyze_recursive(inner);
168 }
169 TLExpr::FuzzyImplication {
170 kind: _,
171 premise,
172 conclusion,
173 } => {
174 self.analyze_recursive(premise);
175 self.analyze_recursive(conclusion);
176 }
177 TLExpr::SoftExists {
178 var,
179 domain,
180 body,
181 temperature: _,
182 }
183 | TLExpr::SoftForAll {
184 var,
185 domain,
186 body,
187 temperature: _,
188 } => {
189 self.variables.insert(var.clone());
190 self.domains.insert(domain.clone());
191 self.analyze_recursive(body);
192 }
193 TLExpr::WeightedRule { weight: _, rule } => {
194 self.analyze_recursive(rule);
195 }
196 TLExpr::ProbabilisticChoice { alternatives } => {
197 for (_, alt) in alternatives {
198 self.analyze_recursive(alt);
199 }
200 }
201 TLExpr::Let { var, value, body } => {
202 self.variables.insert(var.clone());
203 self.analyze_recursive(value);
204 self.analyze_recursive(body);
205 }
206 TLExpr::Box(inner)
207 | TLExpr::Diamond(inner)
208 | TLExpr::Next(inner)
209 | TLExpr::Eventually(inner)
210 | TLExpr::Always(inner) => {
211 self.analyze_recursive(inner);
212 }
213 TLExpr::Until { before, after } | TLExpr::WeakUntil { before, after } => {
214 self.analyze_recursive(before);
215 self.analyze_recursive(after);
216 }
217 TLExpr::Release { released, releaser }
218 | TLExpr::StrongRelease { released, releaser } => {
219 self.analyze_recursive(released);
220 self.analyze_recursive(releaser);
221 }
222 TLExpr::Abs(inner)
224 | TLExpr::Sqrt(inner)
225 | TLExpr::Exp(inner)
226 | TLExpr::Log(inner)
227 | TLExpr::Sin(inner)
228 | TLExpr::Cos(inner)
229 | TLExpr::Tan(inner)
230 | TLExpr::Floor(inner)
231 | TLExpr::Ceil(inner)
232 | TLExpr::Round(inner) => {
233 self.analyze_recursive(inner);
234 }
235 TLExpr::Pow(left, right)
236 | TLExpr::Min(left, right)
237 | TLExpr::Max(left, right)
238 | TLExpr::Mod(left, right) => {
239 self.analyze_recursive(left);
240 self.analyze_recursive(right);
241 }
242 TLExpr::Constant(_) => {
243 }
245 }
246 }
247
248 fn analyze_term(&mut self, term: &Term) {
249 if let Term::Var(name) = term {
250 self.variables.insert(name.clone());
251 }
252 }
253
254 fn hash_config(ctx: &CompilerContext) -> u64 {
255 use std::collections::hash_map::DefaultHasher;
256 use std::hash::{Hash, Hasher};
257
258 let mut hasher = DefaultHasher::new();
259 format!("{:?}", ctx.config).hash(&mut hasher);
261 hasher.finish()
262 }
263}
264
265impl Default for ExpressionDependencies {
266 fn default() -> Self {
267 Self::new()
268 }
269}
270
271#[derive(Debug, Clone)]
273pub struct ChangeDetector {
274 previous_predicates: HashMap<String, (usize, Vec<String>)>,
276 previous_domains: HashMap<String, usize>,
278 previous_config_hash: u64,
280}
281
282impl ChangeDetector {
283 pub fn new() -> Self {
285 Self {
286 previous_predicates: HashMap::new(),
287 previous_domains: HashMap::new(),
288 previous_config_hash: 0,
289 }
290 }
291
292 pub fn update(&mut self, ctx: &CompilerContext) {
294 self.previous_predicates.clear();
295 self.previous_domains.clear();
296
297 for (name, info) in &ctx.domains {
299 self.previous_domains.insert(name.clone(), info.cardinality);
300 }
301
302 self.previous_config_hash = ExpressionDependencies::hash_config(ctx);
303 }
304
305 pub fn detect_changes(&self, ctx: &CompilerContext) -> ChangeSet {
307 let mut changes = ChangeSet::new();
308
309 for (name, info) in &ctx.domains {
311 if let Some(&prev_size) = self.previous_domains.get(name.as_str()) {
312 if prev_size != info.cardinality {
313 changes.changed_domains.insert(name.clone());
314 }
315 } else {
316 changes.new_domains.insert(name.clone());
317 }
318 }
319
320 for name in self.previous_domains.keys() {
322 if !ctx.domains.contains_key(name) {
323 changes.removed_domains.insert(name.clone());
324 }
325 }
326
327 let current_hash = ExpressionDependencies::hash_config(ctx);
329 if current_hash != self.previous_config_hash {
330 changes.config_changed = true;
331 }
332
333 changes
334 }
335}
336
337impl Default for ChangeDetector {
338 fn default() -> Self {
339 Self::new()
340 }
341}
342
343#[derive(Debug, Clone, Default)]
345pub struct ChangeSet {
346 pub new_predicates: HashSet<String>,
348 pub changed_predicates: HashSet<String>,
350 pub removed_predicates: HashSet<String>,
352 pub new_domains: HashSet<String>,
354 pub changed_domains: HashSet<String>,
356 pub removed_domains: HashSet<String>,
358 pub config_changed: bool,
360}
361
362impl ChangeSet {
363 fn new() -> Self {
364 Self::default()
365 }
366
367 pub fn has_changes(&self) -> bool {
369 !self.new_predicates.is_empty()
370 || !self.changed_predicates.is_empty()
371 || !self.removed_predicates.is_empty()
372 || !self.new_domains.is_empty()
373 || !self.changed_domains.is_empty()
374 || !self.removed_domains.is_empty()
375 || self.config_changed
376 }
377
378 pub fn affects(&self, deps: &ExpressionDependencies) -> bool {
380 if self.config_changed {
382 return true;
383 }
384
385 for pred in &deps.predicates {
387 if self.changed_predicates.contains(pred) || self.removed_predicates.contains(pred) {
388 return true;
389 }
390 }
391
392 for domain in &deps.domains {
394 if self.changed_domains.contains(domain) || self.removed_domains.contains(domain) {
395 return true;
396 }
397 }
398
399 false
400 }
401}
402
403#[derive(Debug, Clone)]
405struct CacheEntry {
406 graph: EinsumGraph,
408 dependencies: ExpressionDependencies,
410 #[allow(dead_code)]
412 timestamp: u64,
413}
414
415pub struct IncrementalCompiler {
417 context: CompilerContext,
419 cache: Arc<Mutex<HashMap<String, CacheEntry>>>,
421 change_detector: ChangeDetector,
423 stats: Arc<Mutex<IncrementalStats>>,
425 next_timestamp: Arc<Mutex<u64>>,
427}
428
429impl IncrementalCompiler {
430 pub fn new(context: CompilerContext) -> Self {
432 let mut change_detector = ChangeDetector::new();
433 change_detector.update(&context);
434
435 Self {
436 context,
437 cache: Arc::new(Mutex::new(HashMap::new())),
438 change_detector,
439 stats: Arc::new(Mutex::new(IncrementalStats::default())),
440 next_timestamp: Arc::new(Mutex::new(0)),
441 }
442 }
443
444 pub fn context(&self) -> &CompilerContext {
446 &self.context
447 }
448
449 pub fn context_mut(&mut self) -> &mut CompilerContext {
451 &mut self.context
452 }
453
454 pub fn compile(&mut self, expr: &TLExpr) -> Result<EinsumGraph, IrError> {
456 let changes = self.change_detector.detect_changes(&self.context);
458
459 if changes.has_changes() {
461 self.invalidate_affected(&changes);
462 self.change_detector.update(&self.context);
463 }
464
465 let expr_key = format!("{:?}", expr);
467 let cache = self.cache.lock().unwrap();
468
469 if let Some(entry) = cache.get(&expr_key) {
470 let mut stats = self.stats.lock().unwrap();
472 stats.cache_hits += 1;
473 stats.nodes_reused += entry.graph.nodes.len();
474 drop(stats);
475
476 return Ok(entry.graph.clone());
477 }
478
479 drop(cache);
481
482 let deps = ExpressionDependencies::analyze(expr, &self.context);
483 let graph = compile_to_einsum_with_context(expr, &mut self.context).map_err(|e| {
486 IrError::InvalidEinsumSpec {
487 spec: format!("{:?}", expr),
488 reason: format!("Compilation failed: {}", e),
489 }
490 })?;
491
492 let mut stats = self.stats.lock().unwrap();
494 stats.cache_misses += 1;
495 stats.nodes_compiled += graph.nodes.len();
496 drop(stats);
497
498 let mut timestamp_guard = self.next_timestamp.lock().unwrap();
500 let timestamp = *timestamp_guard;
501 *timestamp_guard += 1;
502 drop(timestamp_guard);
503
504 let mut cache = self.cache.lock().unwrap();
505 cache.insert(
506 expr_key,
507 CacheEntry {
508 graph: graph.clone(),
509 dependencies: deps,
510 timestamp,
511 },
512 );
513
514 Ok(graph)
515 }
516
517 fn invalidate_affected(&mut self, changes: &ChangeSet) {
519 let mut cache = self.cache.lock().unwrap();
520 cache.retain(|_, entry| !changes.affects(&entry.dependencies));
521
522 let mut stats = self.stats.lock().unwrap();
523 stats.invalidations += 1;
524 }
525
526 pub fn clear_cache(&mut self) {
528 let mut cache = self.cache.lock().unwrap();
529 cache.clear();
530 }
531
532 pub fn stats(&self) -> IncrementalStats {
534 self.stats.lock().unwrap().clone()
535 }
536
537 pub fn reset_stats(&mut self) {
539 let mut stats = self.stats.lock().unwrap();
540 *stats = IncrementalStats::default();
541 }
542}
543
/// Counters describing how effective the incremental cache has been.
#[derive(Debug, Clone, Default)]
pub struct IncrementalStats {
    /// Compilations answered from the cache.
    pub cache_hits: usize,
    /// Compilations that had to run the compiler.
    pub cache_misses: usize,
    /// Number of invalidation passes performed.
    pub invalidations: usize,
    /// Total graph nodes returned from cached graphs.
    pub nodes_reused: usize,
    /// Total graph nodes produced by fresh compilations.
    pub nodes_compiled: usize,
}

impl IncrementalStats {
    /// Returns `part / (part + rest)`, or `0.0` when both counts are zero.
    fn share(part: usize, rest: usize) -> f64 {
        match part + rest {
            0 => 0.0,
            whole => part as f64 / whole as f64,
        }
    }

    /// Fraction of compilations served from the cache, in `0.0..=1.0`;
    /// `0.0` when nothing has been compiled yet.
    pub fn hit_rate(&self) -> f64 {
        Self::share(self.cache_hits, self.cache_misses)
    }

    /// Fraction of graph nodes that came from the cache, in `0.0..=1.0`;
    /// `0.0` when no nodes have been counted yet.
    pub fn reuse_rate(&self) -> f64 {
        Self::share(self.nodes_reused, self.nodes_compiled)
    }

    /// Total number of successful compile requests, cached or not.
    pub fn total_compilations(&self) -> usize {
        self.cache_misses + self.cache_hits
    }
}
585
#[cfg(test)]
mod tests {
    use super::*;

    /// Context containing a single `Person` domain registered with size 100.
    fn person_context() -> CompilerContext {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);
        ctx
    }

    /// The predicate `knows(x, y)`.
    fn knows_xy() -> TLExpr {
        TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")])
    }

    /// The predicate `likes(x, z)`.
    fn likes_xz() -> TLExpr {
        TLExpr::pred("likes", vec![Term::var("x"), Term::var("z")])
    }

    #[test]
    fn test_dependency_tracking() {
        let ctx = person_context();

        let deps = ExpressionDependencies::analyze(&knows_xy(), &ctx);

        assert!(deps.predicates.contains("knows"));
        assert!(deps.variables.contains("x"));
        assert!(deps.variables.contains("y"));
    }

    #[test]
    fn test_incremental_compilation_reuse() {
        let mut compiler = IncrementalCompiler::new(person_context());
        let expr = knows_xy();

        // First request has to compile.
        let _first = compiler.compile(&expr).unwrap();
        let after_first = compiler.stats();
        assert_eq!(after_first.cache_misses, 1);
        assert_eq!(after_first.cache_hits, 0);

        // Second request for the same expression is served from the cache.
        let _second = compiler.compile(&expr).unwrap();
        let after_second = compiler.stats();
        assert_eq!(after_second.cache_misses, 1);
        assert_eq!(after_second.cache_hits, 1);
        assert_eq!(after_second.hit_rate(), 0.5);
    }

    #[test]
    fn test_change_detection_domain() {
        let mut ctx = person_context();

        let mut detector = ChangeDetector::new();
        detector.update(&ctx);

        // Nothing changed since the snapshot.
        assert!(!detector.detect_changes(&ctx).has_changes());

        // Re-registering the domain with a different size is a change.
        ctx.add_domain("Person", 200);
        let changes = detector.detect_changes(&ctx);
        assert!(changes.has_changes());
        assert!(changes.changed_domains.contains("Person"));
    }

    #[test]
    fn test_invalidation_on_domain_change() {
        let mut compiler = IncrementalCompiler::new(person_context());
        let expr = knows_xy();

        let _first = compiler.compile(&expr).unwrap();
        assert_eq!(compiler.stats().cache_misses, 1);

        compiler.context_mut().add_domain("Person", 200);

        let _second = compiler.compile(&expr).unwrap();
        let stats = compiler.stats();
        assert!(stats.cache_misses >= 1);
        assert!(stats.invalidations >= 1);
    }

    #[test]
    fn test_incremental_stats() {
        let mut compiler = IncrementalCompiler::new(person_context());

        let expr1 = knows_xy();
        let expr2 = likes_xz();

        compiler.compile(&expr1).unwrap();
        compiler.compile(&expr1).unwrap();
        compiler.compile(&expr2).unwrap();

        let stats = compiler.stats();
        assert_eq!(stats.total_compilations(), 3);
        assert!(
            stats.cache_hits >= 1,
            "Expected at least 1 cache hit, got {}",
            stats.cache_hits
        );
        assert!(
            stats.hit_rate() > 0.0,
            "Expected positive hit rate, got {}",
            stats.hit_rate()
        );
    }

    #[test]
    fn test_complex_expression_dependencies() {
        let ctx = person_context();

        let expr = TLExpr::exists("x", "Person", TLExpr::and(knows_xy(), likes_xz()));

        let deps = ExpressionDependencies::analyze(&expr, &ctx);

        for predicate in ["knows", "likes"] {
            assert!(deps.predicates.contains(predicate));
        }
        for variable in ["x", "y", "z"] {
            assert!(deps.variables.contains(variable));
        }
        assert!(deps.domains.contains("Person"));
    }
}