Skip to main content

tensorlogic_compiler/
incremental.rs

1//! Incremental compilation for efficient recompilation when expressions change.
2//!
3//! This module provides an incremental compilation system that tracks dependencies
4//! and recompiles only the parts of expressions that have changed. This is crucial
5//! for interactive environments like REPLs, notebooks, and IDEs where expressions
6//! are frequently modified.
7//!
8//! # Architecture
9//!
10//! The incremental compilation system consists of three main components:
11//!
//! 1. **ExpressionDependencies**: Records what each expression depends on
//!    (predicates, variables, domains, and a hash of the configuration).
14//!
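//! 2. **ChangeDetector**: Snapshots the compilation context and reports what has
//!    changed since the snapshot as a `ChangeSet`. Domain cardinalities and the
//!    configuration are compared; predicate signatures have fields reserved for
//!    them but are not currently snapshotted or compared.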
17//!
//! 3. **IncrementalCompiler**: Owns the context and a cache of compiled graphs
//!    keyed by the whole expression, drops the cached entries that a `ChangeSet`
//!    affects, and recompiles only expressions that are not (or no longer) cached.
20//!
21//! # Example
22//!
23//! ```rust
24//! use tensorlogic_compiler::{CompilerContext, incremental::IncrementalCompiler};
25//! use tensorlogic_ir::{TLExpr, Term};
26//!
27//! let mut ctx = CompilerContext::new();
28//! ctx.add_domain("Person", 100);
29//!
30//! let mut compiler = IncrementalCompiler::new(ctx);
31//!
32//! // Initial compilation
33//! let expr1 = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
34//! let graph1 = compiler.compile(&expr1).expect("unwrap");
35//!
//! // Compile a larger expression - it is cached under its own key; only an
//! // expression identical to an earlier one is served from the cache
37//! let expr2 = TLExpr::and(
38//!     TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]),
39//!     TLExpr::pred("likes", vec![Term::var("x"), Term::var("z")]),
40//! );
41//! let graph2 = compiler.compile(&expr2).expect("unwrap");
42//!
43//! // Check incremental compilation stats
44//! let stats = compiler.stats();
45//! println!("Nodes reused: {}", stats.nodes_reused);
46//! println!("Nodes compiled: {}", stats.nodes_compiled);
47//! println!("Reuse rate: {:.1}%", stats.reuse_rate() * 100.0);
48//! ```
49
50use crate::{compile_to_einsum_with_context, CompilerContext};
51use std::collections::{HashMap, HashSet};
52use std::sync::{Arc, Mutex};
53use tensorlogic_ir::{EinsumGraph, IrError, TLExpr, Term};
54
/// The context items a compiled expression relies on.
///
/// Each cached graph stores one of these so that a [`ChangeSet`] can decide
/// whether the graph must be discarded after the context has been modified.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExpressionDependencies {
    /// Names of every predicate that appears in the expression.
    pub predicates: HashSet<String>,
    /// Names of variables used as predicate arguments or introduced by a
    /// quantifier, aggregate (including its group-by list), or `let`.
    pub variables: HashSet<String>,
    /// Names of the domains that quantifiers and aggregates range over.
    pub domains: HashSet<String>,
    /// Hash of the compiler configuration at analysis time; a different value
    /// signals that compilation strategies may have changed.
    pub config_hash: u64,
}
67
68impl ExpressionDependencies {
69    /// Create a new empty dependency set.
70    pub fn new() -> Self {
71        Self {
72            predicates: HashSet::new(),
73            variables: HashSet::new(),
74            domains: HashSet::new(),
75            config_hash: 0,
76        }
77    }
78
79    /// Analyze an expression and extract its dependencies.
80    pub fn analyze(expr: &TLExpr, ctx: &CompilerContext) -> Self {
81        let mut deps = Self::new();
82        deps.analyze_recursive(expr);
83        deps.config_hash = Self::hash_config(ctx);
84        deps
85    }
86
87    fn analyze_recursive(&mut self, expr: &TLExpr) {
88        match expr {
89            TLExpr::Pred { name, args } => {
90                self.predicates.insert(name.clone());
91                for arg in args {
92                    self.analyze_term(arg);
93                }
94            }
95            TLExpr::And(left, right) | TLExpr::Or(left, right) | TLExpr::Imply(left, right) => {
96                self.analyze_recursive(left);
97                self.analyze_recursive(right);
98            }
99            TLExpr::Not(inner) => {
100                self.analyze_recursive(inner);
101            }
102            TLExpr::Exists { var, domain, body } | TLExpr::ForAll { var, domain, body } => {
103                self.variables.insert(var.clone());
104                self.domains.insert(domain.clone());
105                self.analyze_recursive(body);
106            }
107            TLExpr::Score(inner) => {
108                self.analyze_recursive(inner);
109            }
110            TLExpr::Add(left, right)
111            | TLExpr::Sub(left, right)
112            | TLExpr::Mul(left, right)
113            | TLExpr::Div(left, right) => {
114                self.analyze_recursive(left);
115                self.analyze_recursive(right);
116            }
117            TLExpr::Eq(left, right)
118            | TLExpr::Lt(left, right)
119            | TLExpr::Gt(left, right)
120            | TLExpr::Lte(left, right)
121            | TLExpr::Gte(left, right) => {
122                self.analyze_recursive(left);
123                self.analyze_recursive(right);
124            }
125            TLExpr::IfThenElse {
126                condition,
127                then_branch,
128                else_branch,
129            } => {
130                self.analyze_recursive(condition);
131                self.analyze_recursive(then_branch);
132                self.analyze_recursive(else_branch);
133            }
134            TLExpr::Aggregate {
135                op: _,
136                var,
137                domain,
138                body,
139                group_by,
140            } => {
141                self.variables.insert(var.clone());
142                self.domains.insert(domain.clone());
143                self.analyze_recursive(body);
144                if let Some(gb_vars) = group_by {
145                    for var_name in gb_vars {
146                        self.variables.insert(var_name.clone());
147                    }
148                }
149            }
150            TLExpr::TNorm {
151                kind: _,
152                left,
153                right,
154            }
155            | TLExpr::TCoNorm {
156                kind: _,
157                left,
158                right,
159            } => {
160                self.analyze_recursive(left);
161                self.analyze_recursive(right);
162            }
163            TLExpr::FuzzyNot {
164                kind: _,
165                expr: inner,
166            } => {
167                self.analyze_recursive(inner);
168            }
169            TLExpr::FuzzyImplication {
170                kind: _,
171                premise,
172                conclusion,
173            } => {
174                self.analyze_recursive(premise);
175                self.analyze_recursive(conclusion);
176            }
177            TLExpr::SoftExists {
178                var,
179                domain,
180                body,
181                temperature: _,
182            }
183            | TLExpr::SoftForAll {
184                var,
185                domain,
186                body,
187                temperature: _,
188            } => {
189                self.variables.insert(var.clone());
190                self.domains.insert(domain.clone());
191                self.analyze_recursive(body);
192            }
193            TLExpr::WeightedRule { weight: _, rule } => {
194                self.analyze_recursive(rule);
195            }
196            TLExpr::ProbabilisticChoice { alternatives } => {
197                for (_, alt) in alternatives {
198                    self.analyze_recursive(alt);
199                }
200            }
201            TLExpr::Let { var, value, body } => {
202                self.variables.insert(var.clone());
203                self.analyze_recursive(value);
204                self.analyze_recursive(body);
205            }
206            TLExpr::Box(inner)
207            | TLExpr::Diamond(inner)
208            | TLExpr::Next(inner)
209            | TLExpr::Eventually(inner)
210            | TLExpr::Always(inner) => {
211                self.analyze_recursive(inner);
212            }
213            TLExpr::Until { before, after } | TLExpr::WeakUntil { before, after } => {
214                self.analyze_recursive(before);
215                self.analyze_recursive(after);
216            }
217            TLExpr::Release { released, releaser }
218            | TLExpr::StrongRelease { released, releaser } => {
219                self.analyze_recursive(released);
220                self.analyze_recursive(releaser);
221            }
222            // Math operations
223            TLExpr::Abs(inner)
224            | TLExpr::Sqrt(inner)
225            | TLExpr::Exp(inner)
226            | TLExpr::Log(inner)
227            | TLExpr::Sin(inner)
228            | TLExpr::Cos(inner)
229            | TLExpr::Tan(inner)
230            | TLExpr::Floor(inner)
231            | TLExpr::Ceil(inner)
232            | TLExpr::Round(inner) => {
233                self.analyze_recursive(inner);
234            }
235            TLExpr::Pow(left, right)
236            | TLExpr::Min(left, right)
237            | TLExpr::Max(left, right)
238            | TLExpr::Mod(left, right) => {
239                self.analyze_recursive(left);
240                self.analyze_recursive(right);
241            }
242            TLExpr::Constant(_) => {
243                // No dependencies
244            }
245            // All other expression types (enhancements)
246            _ => {
247                // For unhandled variants, recurse on any child expressions if needed
248                // This is a catch-all for future-proof compilation
249            }
250        }
251    }
252
253    fn analyze_term(&mut self, term: &Term) {
254        if let Term::Var(name) = term {
255            self.variables.insert(name.clone());
256        }
257    }
258
259    fn hash_config(ctx: &CompilerContext) -> u64 {
260        use std::collections::hash_map::DefaultHasher;
261        use std::hash::{Hash, Hasher};
262
263        let mut hasher = DefaultHasher::new();
264        // Hash the config strategies
265        format!("{:?}", ctx.config).hash(&mut hasher);
266        hasher.finish()
267    }
268}
269
270impl Default for ExpressionDependencies {
271    fn default() -> Self {
272        Self::new()
273    }
274}
275
/// Remembers a snapshot of the compilation context so later states can be
/// compared against it.
#[derive(Debug, Clone)]
pub struct ChangeDetector {
    /// Predicate signatures from the last snapshot, keyed by predicate name.
    /// NOTE(review): nothing in this module inserts into this map; the tuple
    /// presumably holds (arity, argument domains) - confirm before relying on it.
    previous_predicates: HashMap<String, (usize, Vec<String>)>,
    /// Cardinality of each domain at the time of the last snapshot.
    previous_domains: HashMap<String, usize>,
    /// Configuration hash at the time of the last snapshot.
    previous_config_hash: u64,
}
286
287impl ChangeDetector {
288    /// Create a new change detector.
289    pub fn new() -> Self {
290        Self {
291            previous_predicates: HashMap::new(),
292            previous_domains: HashMap::new(),
293            previous_config_hash: 0,
294        }
295    }
296
297    /// Update the snapshot from the current context.
298    pub fn update(&mut self, ctx: &CompilerContext) {
299        self.previous_predicates.clear();
300        self.previous_domains.clear();
301
302        // Snapshot domains
303        for (name, info) in &ctx.domains {
304            self.previous_domains.insert(name.clone(), info.cardinality);
305        }
306
307        self.previous_config_hash = ExpressionDependencies::hash_config(ctx);
308    }
309
310    /// Detect changes and return affected predicates and domains.
311    pub fn detect_changes(&self, ctx: &CompilerContext) -> ChangeSet {
312        let mut changes = ChangeSet::new();
313
314        // Check domain changes
315        for (name, info) in &ctx.domains {
316            if let Some(&prev_size) = self.previous_domains.get(name.as_str()) {
317                if prev_size != info.cardinality {
318                    changes.changed_domains.insert(name.clone());
319                }
320            } else {
321                changes.new_domains.insert(name.clone());
322            }
323        }
324
325        // Check for removed domains
326        for name in self.previous_domains.keys() {
327            if !ctx.domains.contains_key(name) {
328                changes.removed_domains.insert(name.clone());
329            }
330        }
331
332        // Check configuration changes
333        let current_hash = ExpressionDependencies::hash_config(ctx);
334        if current_hash != self.previous_config_hash {
335            changes.config_changed = true;
336        }
337
338        changes
339    }
340}
341
342impl Default for ChangeDetector {
343    fn default() -> Self {
344        Self::new()
345    }
346}
347
/// The differences between a context snapshot and the current context.
///
/// The default value has every set empty and `config_changed` set to `false`,
/// i.e. "nothing changed".
#[derive(Debug, Clone, Default)]
pub struct ChangeSet {
    /// Predicates present now but absent from the snapshot.
    pub new_predicates: HashSet<String>,
    /// Predicates present in both whose definition differs.
    pub changed_predicates: HashSet<String>,
    /// Predicates present in the snapshot but absent now.
    pub removed_predicates: HashSet<String>,
    /// Domains present now but absent from the snapshot.
    pub new_domains: HashSet<String>,
    /// Domains present in both whose cardinality differs.
    pub changed_domains: HashSet<String>,
    /// Domains present in the snapshot but absent now.
    pub removed_domains: HashSet<String>,
    /// `true` when the configuration hash differs from the snapshot's.
    pub config_changed: bool,
}
366
367impl ChangeSet {
368    fn new() -> Self {
369        Self::default()
370    }
371
372    /// Check if there are any changes.
373    pub fn has_changes(&self) -> bool {
374        !self.new_predicates.is_empty()
375            || !self.changed_predicates.is_empty()
376            || !self.removed_predicates.is_empty()
377            || !self.new_domains.is_empty()
378            || !self.changed_domains.is_empty()
379            || !self.removed_domains.is_empty()
380            || self.config_changed
381    }
382
383    /// Check if a dependency set is affected by these changes.
384    pub fn affects(&self, deps: &ExpressionDependencies) -> bool {
385        // Config changes affect everything
386        if self.config_changed {
387            return true;
388        }
389
390        // Check if any used predicate changed
391        for pred in &deps.predicates {
392            if self.changed_predicates.contains(pred) || self.removed_predicates.contains(pred) {
393                return true;
394            }
395        }
396
397        // Check if any used domain changed
398        for domain in &deps.domains {
399            if self.changed_domains.contains(domain) || self.removed_domains.contains(domain) {
400                return true;
401            }
402        }
403
404        false
405    }
406}
407
408/// Entry in the incremental compilation cache.
409#[derive(Debug, Clone)]
410struct CacheEntry {
411    /// The compiled graph
412    graph: EinsumGraph,
413    /// Dependencies of this expression
414    dependencies: ExpressionDependencies,
415    /// When this was compiled (for LRU eviction)
416    #[allow(dead_code)]
417    timestamp: u64,
418}
419
420/// Incremental compiler that reuses previously compiled expressions.
421pub struct IncrementalCompiler {
422    /// Compilation context
423    context: CompilerContext,
424    /// Cache of compiled expressions
425    cache: Arc<Mutex<HashMap<String, CacheEntry>>>,
426    /// Change detector
427    change_detector: ChangeDetector,
428    /// Statistics
429    stats: Arc<Mutex<IncrementalStats>>,
430    /// Next timestamp for LRU
431    next_timestamp: Arc<Mutex<u64>>,
432}
433
434impl IncrementalCompiler {
435    /// Create a new incremental compiler with the given context.
436    pub fn new(context: CompilerContext) -> Self {
437        let mut change_detector = ChangeDetector::new();
438        change_detector.update(&context);
439
440        Self {
441            context,
442            cache: Arc::new(Mutex::new(HashMap::new())),
443            change_detector,
444            stats: Arc::new(Mutex::new(IncrementalStats::default())),
445            next_timestamp: Arc::new(Mutex::new(0)),
446        }
447    }
448
449    /// Get the compilation context.
450    pub fn context(&self) -> &CompilerContext {
451        &self.context
452    }
453
454    /// Get a mutable reference to the compilation context.
455    pub fn context_mut(&mut self) -> &mut CompilerContext {
456        &mut self.context
457    }
458
459    /// Compile an expression incrementally.
460    pub fn compile(&mut self, expr: &TLExpr) -> Result<EinsumGraph, IrError> {
461        // Detect changes since last compilation
462        let changes = self.change_detector.detect_changes(&self.context);
463
464        // Invalidate affected entries if there are changes
465        if changes.has_changes() {
466            self.invalidate_affected(&changes);
467            self.change_detector.update(&self.context);
468        }
469
470        // Try to get from cache
471        let expr_key = format!("{:?}", expr);
472        let cache = self.cache.lock().expect("lock should not be poisoned");
473
474        if let Some(entry) = cache.get(&expr_key) {
475            // Cache hit!
476            let mut stats = self.stats.lock().expect("lock should not be poisoned");
477            stats.cache_hits += 1;
478            stats.nodes_reused += entry.graph.nodes.len();
479            drop(stats);
480
481            return Ok(entry.graph.clone());
482        }
483
484        // Cache miss - compile from scratch
485        drop(cache);
486
487        let deps = ExpressionDependencies::analyze(expr, &self.context);
488        // Compile from scratch - we can't use ? here because anyhow::Error doesn't convert to IrError
489        // So we'll just propagate the error as InvalidEinsumSpec
490        let graph = compile_to_einsum_with_context(expr, &mut self.context).map_err(|e| {
491            IrError::InvalidEinsumSpec {
492                spec: format!("{:?}", expr),
493                reason: format!("Compilation failed: {}", e),
494            }
495        })?;
496
497        // Update stats
498        let mut stats = self.stats.lock().expect("lock should not be poisoned");
499        stats.cache_misses += 1;
500        stats.nodes_compiled += graph.nodes.len();
501        drop(stats);
502
503        // Store in cache
504        let mut timestamp_guard = self
505            .next_timestamp
506            .lock()
507            .expect("lock should not be poisoned");
508        let timestamp = *timestamp_guard;
509        *timestamp_guard += 1;
510        drop(timestamp_guard);
511
512        let mut cache = self.cache.lock().expect("lock should not be poisoned");
513        cache.insert(
514            expr_key,
515            CacheEntry {
516                graph: graph.clone(),
517                dependencies: deps,
518                timestamp,
519            },
520        );
521
522        Ok(graph)
523    }
524
525    /// Invalidate cache entries affected by changes.
526    fn invalidate_affected(&mut self, changes: &ChangeSet) {
527        let mut cache = self.cache.lock().expect("lock should not be poisoned");
528        cache.retain(|_, entry| !changes.affects(&entry.dependencies));
529
530        let mut stats = self.stats.lock().expect("lock should not be poisoned");
531        stats.invalidations += 1;
532    }
533
534    /// Clear the cache.
535    pub fn clear_cache(&mut self) {
536        let mut cache = self.cache.lock().expect("lock should not be poisoned");
537        cache.clear();
538    }
539
540    /// Get incremental compilation statistics.
541    pub fn stats(&self) -> IncrementalStats {
542        self.stats
543            .lock()
544            .expect("lock should not be poisoned")
545            .clone()
546    }
547
548    /// Reset statistics.
549    pub fn reset_stats(&mut self) {
550        let mut stats = self.stats.lock().expect("lock should not be poisoned");
551        *stats = IncrementalStats::default();
552    }
553}
554
/// Counters describing how well the incremental cache is performing.
#[derive(Debug, Clone, Default)]
pub struct IncrementalStats {
    /// Compilations answered from the cache.
    pub cache_hits: usize,
    /// Compilations that had to be performed from scratch.
    pub cache_misses: usize,
    /// Invalidation passes run against the cache.
    pub invalidations: usize,
    /// Total graph nodes returned from cached graphs.
    pub nodes_reused: usize,
    /// Total graph nodes produced by fresh compilations.
    pub nodes_compiled: usize,
}

impl IncrementalStats {
    /// `part / whole` as a float, or `0.0` when `whole` is zero.
    fn fraction(part: usize, whole: usize) -> f64 {
        match whole {
            0 => 0.0,
            _ => part as f64 / whole as f64,
        }
    }

    /// Fraction of compilations served from the cache, in `0.0..=1.0`.
    /// Returns `0.0` when nothing has been compiled yet.
    pub fn hit_rate(&self) -> f64 {
        Self::fraction(self.cache_hits, self.cache_hits + self.cache_misses)
    }

    /// Fraction of graph nodes that came from the cache, in `0.0..=1.0`.
    /// Returns `0.0` when no nodes have been counted yet.
    pub fn reuse_rate(&self) -> f64 {
        Self::fraction(self.nodes_reused, self.nodes_reused + self.nodes_compiled)
    }

    /// Number of compile requests counted so far (hits plus misses).
    pub fn total_compilations(&self) -> usize {
        self.cache_hits + self.cache_misses
    }
}
596
#[cfg(test)]
mod tests {
    use super::*;

    /// Context containing a single `Person` domain of size 100.
    fn person_context() -> CompilerContext {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);
        ctx
    }

    /// The predicate `knows(x, y)`.
    fn knows_xy() -> TLExpr {
        TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")])
    }

    /// The predicate `likes(x, z)`.
    fn likes_xz() -> TLExpr {
        TLExpr::pred("likes", vec![Term::var("x"), Term::var("z")])
    }

    #[test]
    fn test_dependency_tracking() {
        let ctx = person_context();

        let deps = ExpressionDependencies::analyze(&knows_xy(), &ctx);

        assert!(deps.predicates.contains("knows"));
        assert!(deps.variables.contains("x"));
        assert!(deps.variables.contains("y"));
    }

    #[test]
    fn test_incremental_compilation_reuse() {
        let mut compiler = IncrementalCompiler::new(person_context());
        let expr = knows_xy();

        // The first request has nothing to reuse.
        let _first = compiler.compile(&expr).expect("unwrap");
        assert_eq!(compiler.stats().cache_misses, 1);
        assert_eq!(compiler.stats().cache_hits, 0);

        // The identical request is answered from the cache.
        let _second = compiler.compile(&expr).expect("unwrap");
        assert_eq!(compiler.stats().cache_misses, 1);
        assert_eq!(compiler.stats().cache_hits, 1);
        assert_eq!(compiler.stats().hit_rate(), 0.5);
    }

    #[test]
    fn test_change_detection_domain() {
        let mut ctx = person_context();

        let mut detector = ChangeDetector::new();
        detector.update(&ctx);

        // Comparing a context with its own snapshot reports nothing.
        let unchanged = detector.detect_changes(&ctx);
        assert!(!unchanged.has_changes());

        // Re-adding the domain with a different size counts as a change.
        ctx.add_domain("Person", 200);
        let resized = detector.detect_changes(&ctx);
        assert!(resized.has_changes());
        assert!(resized.changed_domains.contains("Person"));
    }

    #[test]
    fn test_invalidation_on_domain_change() {
        let mut compiler = IncrementalCompiler::new(person_context());
        let expr = knows_xy();

        let _before = compiler.compile(&expr).expect("unwrap");
        assert_eq!(compiler.stats().cache_misses, 1);

        // Resize the domain between the two compilations.
        compiler.context_mut().add_domain("Person", 200);

        let _after = compiler.compile(&expr).expect("unwrap");
        // NOTE(review): `cache_misses >= 1` already holds after the first
        // compilation, so this does not show that the second call recompiled.
        assert!(compiler.stats().cache_misses >= 1);
        assert!(compiler.stats().invalidations >= 1);
    }

    #[test]
    fn test_incremental_stats() {
        let mut compiler = IncrementalCompiler::new(person_context());

        let first = knows_xy();
        let second = likes_xz();

        compiler.compile(&first).expect("unwrap");
        compiler.compile(&first).expect("unwrap"); // repeat of the same expression
        compiler.compile(&second).expect("unwrap");

        let stats = compiler.stats();
        assert_eq!(stats.total_compilations(), 3);
        // The repeated `first` compilation must have been a hit.
        assert!(
            stats.cache_hits >= 1,
            "Expected at least 1 cache hit, got {}",
            stats.cache_hits
        );
        // Any hit at all makes the hit rate positive.
        assert!(
            stats.hit_rate() > 0.0,
            "Expected positive hit rate, got {}",
            stats.hit_rate()
        );
    }

    #[test]
    fn test_complex_expression_dependencies() {
        let ctx = person_context();

        let expr = TLExpr::exists("x", "Person", TLExpr::and(knows_xy(), likes_xz()));

        let deps = ExpressionDependencies::analyze(&expr, &ctx);

        for predicate in ["knows", "likes"] {
            assert!(deps.predicates.contains(predicate));
        }
        for variable in ["x", "y", "z"] {
            assert!(deps.variables.contains(variable));
        }
        assert!(deps.domains.contains("Person"));
    }
}