tensorlogic_compiler/
incremental.rs

1//! Incremental compilation for efficient recompilation when expressions change.
2//!
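//! This module provides an incremental compilation system that records the
//! dependencies of every compiled expression, caches the resulting graphs, and
//! recompiles an expression only when it is not in the cache, either because it
//! was never compiled or because a change to the compilation context affected one
//! of its dependencies. This matters for interactive environments like REPLs,
//! notebooks, and IDEs, where expressions are compiled repeatedly while the
//! context is being edited.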
7//!
8//! # Architecture
9//!
10//! The incremental compilation system consists of three main components:
11//!
//! 1. **ExpressionDependencies**: Records what each expression depends on
//!    (predicates, variables, domains, and a hash of the configuration).
14//!
//! 2. **ChangeDetector**: Detects changes to the compilation context (domain
//!    cardinalities and the configuration) and reports them as a `ChangeSet`.
//!    Predicate-signature changes are not tracked yet.
17//!
//! 3. **IncrementalCompiler**: Owns the compilation context, caches compiled
//!    graphs keyed by the whole expression, evicts cached entries whose
//!    dependencies are affected by a detected change, and compiles only on a
//!    cache miss.
20//!
21//! # Example
22//!
23//! ```rust
24//! use tensorlogic_compiler::{CompilerContext, incremental::IncrementalCompiler};
25//! use tensorlogic_ir::{TLExpr, Term};
26//!
27//! let mut ctx = CompilerContext::new();
28//! ctx.add_domain("Person", 100);
29//!
30//! let mut compiler = IncrementalCompiler::new(ctx);
31//!
32//! // Initial compilation
33//! let expr1 = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
34//! let graph1 = compiler.compile(&expr1).unwrap();
35//!
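//! // A different expression is a cache miss: entries are keyed by the whole
//! // expression, so the `knows` sub-expression shared with `expr1` is compiled again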
37//! let expr2 = TLExpr::and(
38//!     TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]),
39//!     TLExpr::pred("likes", vec![Term::var("x"), Term::var("z")]),
40//! );
41//! let graph2 = compiler.compile(&expr2).unwrap();
42//!
43//! // Check incremental compilation stats
44//! let stats = compiler.stats();
45//! println!("Nodes reused: {}", stats.nodes_reused);
46//! println!("Nodes compiled: {}", stats.nodes_compiled);
47//! println!("Reuse rate: {:.1}%", stats.reuse_rate() * 100.0);
48//! ```
49
50use crate::{compile_to_einsum_with_context, CompilerContext};
51use std::collections::{HashMap, HashSet};
52use std::sync::{Arc, Mutex};
53use tensorlogic_ir::{EinsumGraph, IrError, TLExpr, Term};
54
/// The names a compiled expression depends on.
///
/// A cached compilation result stays usable only while none of the names
/// recorded here are touched by a change to the compilation context.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub struct ExpressionDependencies {
    /// Names of the predicates applied anywhere in the expression.
    pub predicates: HashSet<String>,
    /// Variable names occurring in the expression, both in predicate arguments
    /// and as binders (quantifiers, aggregates, `let`).
    pub variables: HashSet<String>,
    /// Domain names mentioned by quantifiers and aggregates.
    pub domains: HashSet<String>,
    /// Hash of the compiler configuration, taken when the dependencies were
    /// analyzed, so that a strategy change can be recognised.
    pub config_hash: u64,
}
67
68impl ExpressionDependencies {
69    /// Create a new empty dependency set.
70    pub fn new() -> Self {
71        Self {
72            predicates: HashSet::new(),
73            variables: HashSet::new(),
74            domains: HashSet::new(),
75            config_hash: 0,
76        }
77    }
78
79    /// Analyze an expression and extract its dependencies.
80    pub fn analyze(expr: &TLExpr, ctx: &CompilerContext) -> Self {
81        let mut deps = Self::new();
82        deps.analyze_recursive(expr);
83        deps.config_hash = Self::hash_config(ctx);
84        deps
85    }
86
87    fn analyze_recursive(&mut self, expr: &TLExpr) {
88        match expr {
89            TLExpr::Pred { name, args } => {
90                self.predicates.insert(name.clone());
91                for arg in args {
92                    self.analyze_term(arg);
93                }
94            }
95            TLExpr::And(left, right) | TLExpr::Or(left, right) | TLExpr::Imply(left, right) => {
96                self.analyze_recursive(left);
97                self.analyze_recursive(right);
98            }
99            TLExpr::Not(inner) => {
100                self.analyze_recursive(inner);
101            }
102            TLExpr::Exists { var, domain, body } | TLExpr::ForAll { var, domain, body } => {
103                self.variables.insert(var.clone());
104                self.domains.insert(domain.clone());
105                self.analyze_recursive(body);
106            }
107            TLExpr::Score(inner) => {
108                self.analyze_recursive(inner);
109            }
110            TLExpr::Add(left, right)
111            | TLExpr::Sub(left, right)
112            | TLExpr::Mul(left, right)
113            | TLExpr::Div(left, right) => {
114                self.analyze_recursive(left);
115                self.analyze_recursive(right);
116            }
117            TLExpr::Eq(left, right)
118            | TLExpr::Lt(left, right)
119            | TLExpr::Gt(left, right)
120            | TLExpr::Lte(left, right)
121            | TLExpr::Gte(left, right) => {
122                self.analyze_recursive(left);
123                self.analyze_recursive(right);
124            }
125            TLExpr::IfThenElse {
126                condition,
127                then_branch,
128                else_branch,
129            } => {
130                self.analyze_recursive(condition);
131                self.analyze_recursive(then_branch);
132                self.analyze_recursive(else_branch);
133            }
134            TLExpr::Aggregate {
135                op: _,
136                var,
137                domain,
138                body,
139                group_by,
140            } => {
141                self.variables.insert(var.clone());
142                self.domains.insert(domain.clone());
143                self.analyze_recursive(body);
144                if let Some(gb_vars) = group_by {
145                    for var_name in gb_vars {
146                        self.variables.insert(var_name.clone());
147                    }
148                }
149            }
150            TLExpr::TNorm {
151                kind: _,
152                left,
153                right,
154            }
155            | TLExpr::TCoNorm {
156                kind: _,
157                left,
158                right,
159            } => {
160                self.analyze_recursive(left);
161                self.analyze_recursive(right);
162            }
163            TLExpr::FuzzyNot {
164                kind: _,
165                expr: inner,
166            } => {
167                self.analyze_recursive(inner);
168            }
169            TLExpr::FuzzyImplication {
170                kind: _,
171                premise,
172                conclusion,
173            } => {
174                self.analyze_recursive(premise);
175                self.analyze_recursive(conclusion);
176            }
177            TLExpr::SoftExists {
178                var,
179                domain,
180                body,
181                temperature: _,
182            }
183            | TLExpr::SoftForAll {
184                var,
185                domain,
186                body,
187                temperature: _,
188            } => {
189                self.variables.insert(var.clone());
190                self.domains.insert(domain.clone());
191                self.analyze_recursive(body);
192            }
193            TLExpr::WeightedRule { weight: _, rule } => {
194                self.analyze_recursive(rule);
195            }
196            TLExpr::ProbabilisticChoice { alternatives } => {
197                for (_, alt) in alternatives {
198                    self.analyze_recursive(alt);
199                }
200            }
201            TLExpr::Let { var, value, body } => {
202                self.variables.insert(var.clone());
203                self.analyze_recursive(value);
204                self.analyze_recursive(body);
205            }
206            TLExpr::Box(inner)
207            | TLExpr::Diamond(inner)
208            | TLExpr::Next(inner)
209            | TLExpr::Eventually(inner)
210            | TLExpr::Always(inner) => {
211                self.analyze_recursive(inner);
212            }
213            TLExpr::Until { before, after } | TLExpr::WeakUntil { before, after } => {
214                self.analyze_recursive(before);
215                self.analyze_recursive(after);
216            }
217            TLExpr::Release { released, releaser }
218            | TLExpr::StrongRelease { released, releaser } => {
219                self.analyze_recursive(released);
220                self.analyze_recursive(releaser);
221            }
222            // Math operations
223            TLExpr::Abs(inner)
224            | TLExpr::Sqrt(inner)
225            | TLExpr::Exp(inner)
226            | TLExpr::Log(inner)
227            | TLExpr::Sin(inner)
228            | TLExpr::Cos(inner)
229            | TLExpr::Tan(inner)
230            | TLExpr::Floor(inner)
231            | TLExpr::Ceil(inner)
232            | TLExpr::Round(inner) => {
233                self.analyze_recursive(inner);
234            }
235            TLExpr::Pow(left, right)
236            | TLExpr::Min(left, right)
237            | TLExpr::Max(left, right)
238            | TLExpr::Mod(left, right) => {
239                self.analyze_recursive(left);
240                self.analyze_recursive(right);
241            }
242            TLExpr::Constant(_) => {
243                // No dependencies
244            }
245        }
246    }
247
248    fn analyze_term(&mut self, term: &Term) {
249        if let Term::Var(name) = term {
250            self.variables.insert(name.clone());
251        }
252    }
253
254    fn hash_config(ctx: &CompilerContext) -> u64 {
255        use std::collections::hash_map::DefaultHasher;
256        use std::hash::{Hash, Hasher};
257
258        let mut hasher = DefaultHasher::new();
259        // Hash the config strategies
260        format!("{:?}", ctx.config).hash(&mut hasher);
261        hasher.finish()
262    }
263}
264
265impl Default for ExpressionDependencies {
266    fn default() -> Self {
267        Self::new()
268    }
269}
270
/// Keeps a snapshot of the compilation context so that later edits to it can
/// be reported as a [`ChangeSet`].
#[derive(Clone, Debug)]
pub struct ChangeDetector {
    /// Predicate signatures from the last snapshot.
    /// NOTE(review): `update` empties this map without refilling it, and
    /// `detect_changes` never reads it, so predicate changes go undetected.
    previous_predicates: HashMap<String, (usize, Vec<String>)>,
    /// Cardinality of each domain in the last snapshot, keyed by domain name.
    previous_domains: HashMap<String, usize>,
    /// Configuration hash captured with the last snapshot.
    previous_config_hash: u64,
}
281
282impl ChangeDetector {
283    /// Create a new change detector.
284    pub fn new() -> Self {
285        Self {
286            previous_predicates: HashMap::new(),
287            previous_domains: HashMap::new(),
288            previous_config_hash: 0,
289        }
290    }
291
292    /// Update the snapshot from the current context.
293    pub fn update(&mut self, ctx: &CompilerContext) {
294        self.previous_predicates.clear();
295        self.previous_domains.clear();
296
297        // Snapshot domains
298        for (name, info) in &ctx.domains {
299            self.previous_domains.insert(name.clone(), info.cardinality);
300        }
301
302        self.previous_config_hash = ExpressionDependencies::hash_config(ctx);
303    }
304
305    /// Detect changes and return affected predicates and domains.
306    pub fn detect_changes(&self, ctx: &CompilerContext) -> ChangeSet {
307        let mut changes = ChangeSet::new();
308
309        // Check domain changes
310        for (name, info) in &ctx.domains {
311            if let Some(&prev_size) = self.previous_domains.get(name.as_str()) {
312                if prev_size != info.cardinality {
313                    changes.changed_domains.insert(name.clone());
314                }
315            } else {
316                changes.new_domains.insert(name.clone());
317            }
318        }
319
320        // Check for removed domains
321        for name in self.previous_domains.keys() {
322            if !ctx.domains.contains_key(name) {
323                changes.removed_domains.insert(name.clone());
324            }
325        }
326
327        // Check configuration changes
328        let current_hash = ExpressionDependencies::hash_config(ctx);
329        if current_hash != self.previous_config_hash {
330            changes.config_changed = true;
331        }
332
333        changes
334    }
335}
336
337impl Default for ChangeDetector {
338    fn default() -> Self {
339        Self::new()
340    }
341}
342
343/// Describes what has changed in the compilation context.
344#[derive(Debug, Clone, Default)]
345pub struct ChangeSet {
346    /// Predicates that were added
347    pub new_predicates: HashSet<String>,
348    /// Predicates that were modified
349    pub changed_predicates: HashSet<String>,
350    /// Predicates that were removed
351    pub removed_predicates: HashSet<String>,
352    /// Domains that were added
353    pub new_domains: HashSet<String>,
354    /// Domains that were modified
355    pub changed_domains: HashSet<String>,
356    /// Domains that were removed
357    pub removed_domains: HashSet<String>,
358    /// Whether the configuration changed
359    pub config_changed: bool,
360}
361
362impl ChangeSet {
363    fn new() -> Self {
364        Self::default()
365    }
366
367    /// Check if there are any changes.
368    pub fn has_changes(&self) -> bool {
369        !self.new_predicates.is_empty()
370            || !self.changed_predicates.is_empty()
371            || !self.removed_predicates.is_empty()
372            || !self.new_domains.is_empty()
373            || !self.changed_domains.is_empty()
374            || !self.removed_domains.is_empty()
375            || self.config_changed
376    }
377
378    /// Check if a dependency set is affected by these changes.
379    pub fn affects(&self, deps: &ExpressionDependencies) -> bool {
380        // Config changes affect everything
381        if self.config_changed {
382            return true;
383        }
384
385        // Check if any used predicate changed
386        for pred in &deps.predicates {
387            if self.changed_predicates.contains(pred) || self.removed_predicates.contains(pred) {
388                return true;
389            }
390        }
391
392        // Check if any used domain changed
393        for domain in &deps.domains {
394            if self.changed_domains.contains(domain) || self.removed_domains.contains(domain) {
395                return true;
396            }
397        }
398
399        false
400    }
401}
402
403/// Entry in the incremental compilation cache.
404#[derive(Debug, Clone)]
405struct CacheEntry {
406    /// The compiled graph
407    graph: EinsumGraph,
408    /// Dependencies of this expression
409    dependencies: ExpressionDependencies,
410    /// When this was compiled (for LRU eviction)
411    #[allow(dead_code)]
412    timestamp: u64,
413}
414
415/// Incremental compiler that reuses previously compiled expressions.
416pub struct IncrementalCompiler {
417    /// Compilation context
418    context: CompilerContext,
419    /// Cache of compiled expressions
420    cache: Arc<Mutex<HashMap<String, CacheEntry>>>,
421    /// Change detector
422    change_detector: ChangeDetector,
423    /// Statistics
424    stats: Arc<Mutex<IncrementalStats>>,
425    /// Next timestamp for LRU
426    next_timestamp: Arc<Mutex<u64>>,
427}
428
429impl IncrementalCompiler {
430    /// Create a new incremental compiler with the given context.
431    pub fn new(context: CompilerContext) -> Self {
432        let mut change_detector = ChangeDetector::new();
433        change_detector.update(&context);
434
435        Self {
436            context,
437            cache: Arc::new(Mutex::new(HashMap::new())),
438            change_detector,
439            stats: Arc::new(Mutex::new(IncrementalStats::default())),
440            next_timestamp: Arc::new(Mutex::new(0)),
441        }
442    }
443
444    /// Get the compilation context.
445    pub fn context(&self) -> &CompilerContext {
446        &self.context
447    }
448
449    /// Get a mutable reference to the compilation context.
450    pub fn context_mut(&mut self) -> &mut CompilerContext {
451        &mut self.context
452    }
453
454    /// Compile an expression incrementally.
455    pub fn compile(&mut self, expr: &TLExpr) -> Result<EinsumGraph, IrError> {
456        // Detect changes since last compilation
457        let changes = self.change_detector.detect_changes(&self.context);
458
459        // Invalidate affected entries if there are changes
460        if changes.has_changes() {
461            self.invalidate_affected(&changes);
462            self.change_detector.update(&self.context);
463        }
464
465        // Try to get from cache
466        let expr_key = format!("{:?}", expr);
467        let cache = self.cache.lock().unwrap();
468
469        if let Some(entry) = cache.get(&expr_key) {
470            // Cache hit!
471            let mut stats = self.stats.lock().unwrap();
472            stats.cache_hits += 1;
473            stats.nodes_reused += entry.graph.nodes.len();
474            drop(stats);
475
476            return Ok(entry.graph.clone());
477        }
478
479        // Cache miss - compile from scratch
480        drop(cache);
481
482        let deps = ExpressionDependencies::analyze(expr, &self.context);
483        // Compile from scratch - we can't use ? here because anyhow::Error doesn't convert to IrError
484        // So we'll just propagate the error as InvalidEinsumSpec
485        let graph = compile_to_einsum_with_context(expr, &mut self.context).map_err(|e| {
486            IrError::InvalidEinsumSpec {
487                spec: format!("{:?}", expr),
488                reason: format!("Compilation failed: {}", e),
489            }
490        })?;
491
492        // Update stats
493        let mut stats = self.stats.lock().unwrap();
494        stats.cache_misses += 1;
495        stats.nodes_compiled += graph.nodes.len();
496        drop(stats);
497
498        // Store in cache
499        let mut timestamp_guard = self.next_timestamp.lock().unwrap();
500        let timestamp = *timestamp_guard;
501        *timestamp_guard += 1;
502        drop(timestamp_guard);
503
504        let mut cache = self.cache.lock().unwrap();
505        cache.insert(
506            expr_key,
507            CacheEntry {
508                graph: graph.clone(),
509                dependencies: deps,
510                timestamp,
511            },
512        );
513
514        Ok(graph)
515    }
516
517    /// Invalidate cache entries affected by changes.
518    fn invalidate_affected(&mut self, changes: &ChangeSet) {
519        let mut cache = self.cache.lock().unwrap();
520        cache.retain(|_, entry| !changes.affects(&entry.dependencies));
521
522        let mut stats = self.stats.lock().unwrap();
523        stats.invalidations += 1;
524    }
525
526    /// Clear the cache.
527    pub fn clear_cache(&mut self) {
528        let mut cache = self.cache.lock().unwrap();
529        cache.clear();
530    }
531
532    /// Get incremental compilation statistics.
533    pub fn stats(&self) -> IncrementalStats {
534        self.stats.lock().unwrap().clone()
535    }
536
537    /// Reset statistics.
538    pub fn reset_stats(&mut self) {
539        let mut stats = self.stats.lock().unwrap();
540        *stats = IncrementalStats::default();
541    }
542}
543
/// Counters describing how effective the incremental cache has been.
#[derive(Clone, Debug, Default)]
pub struct IncrementalStats {
    /// Compilations answered from the cache.
    pub cache_hits: usize,
    /// Successful compilations that were not found in the cache.
    pub cache_misses: usize,
    /// Invalidation passes run after a context change was detected.
    pub invalidations: usize,
    /// Total graph nodes returned from cached entries.
    pub nodes_reused: usize,
    /// Total graph nodes produced by fresh compilations.
    pub nodes_compiled: usize,
}

impl IncrementalStats {
    /// Fraction of compilations served from the cache, in `0.0..=1.0`; `0.0`
    /// when nothing has been compiled yet.
    pub fn hit_rate(&self) -> f64 {
        Self::fraction(self.cache_hits, self.cache_misses)
    }

    /// Fraction of graph nodes that came from the cache, in `0.0..=1.0`;
    /// `0.0` when no nodes have been produced yet.
    pub fn reuse_rate(&self) -> f64 {
        Self::fraction(self.nodes_reused, self.nodes_compiled)
    }

    /// Number of compile calls that produced a graph, hits plus misses.
    pub fn total_compilations(&self) -> usize {
        self.cache_hits + self.cache_misses
    }

    /// `part / (part + rest)` as a float, or `0.0` when both are zero.
    fn fraction(part: usize, rest: usize) -> f64 {
        match part + rest {
            0 => 0.0,
            whole => part as f64 / whole as f64,
        }
    }
}
585
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dependency_tracking() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);

        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
        let deps = ExpressionDependencies::analyze(&expr, &ctx);

        assert!(deps.predicates.contains("knows"));
        assert!(deps.variables.contains("x"));
        assert!(deps.variables.contains("y"));
    }

    #[test]
    fn test_incremental_compilation_reuse() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);

        let mut compiler = IncrementalCompiler::new(ctx);

        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);

        // First compilation
        let _graph1 = compiler.compile(&expr).unwrap();
        assert_eq!(compiler.stats().cache_misses, 1);
        assert_eq!(compiler.stats().cache_hits, 0);

        // Second compilation - should hit cache
        let _graph2 = compiler.compile(&expr).unwrap();
        assert_eq!(compiler.stats().cache_misses, 1);
        assert_eq!(compiler.stats().cache_hits, 1);
        assert_eq!(compiler.stats().hit_rate(), 0.5);
    }

    #[test]
    fn test_change_detection_domain() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);

        let mut detector = ChangeDetector::new();
        detector.update(&ctx);

        // No changes initially
        let changes = detector.detect_changes(&ctx);
        assert!(!changes.has_changes());

        // Change domain size
        ctx.add_domain("Person", 200);
        let changes = detector.detect_changes(&ctx);
        assert!(changes.has_changes());
        assert!(changes.changed_domains.contains("Person"));
    }

    #[test]
    fn test_change_detection_new_domain() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);

        let mut detector = ChangeDetector::new();
        detector.update(&ctx);

        ctx.add_domain("City", 10);
        let changes = detector.detect_changes(&ctx);
        assert!(changes.new_domains.contains("City"));
        assert!(changes.changed_domains.is_empty());
        assert!(changes.removed_domains.is_empty());

        // Once the snapshot is refreshed the addition is no longer reported.
        detector.update(&ctx);
        assert!(!detector.detect_changes(&ctx).has_changes());
    }

    #[test]
    fn test_invalidation_on_domain_change() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);

        let mut compiler = IncrementalCompiler::new(ctx);

        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);

        // First compilation
        let _graph1 = compiler.compile(&expr).unwrap();
        assert_eq!(compiler.stats().cache_misses, 1);

        // Change domain
        compiler.context_mut().add_domain("Person", 200);

        // The resize is detected and triggers an invalidation pass. A bare
        // predicate has no quantifier, so its dependency set names no domain
        // and the resize alone does not evict its entry; only the pass itself
        // is asserted here.
        let _graph2 = compiler.compile(&expr).unwrap();
        assert_eq!(compiler.stats().total_compilations(), 2);
        assert!(compiler.stats().invalidations >= 1);
    }

    #[test]
    fn test_changeset_domain_change_affects_dependent_expression() {
        let mut deps = ExpressionDependencies::new();
        deps.domains.insert("Person".to_string());

        let mut changes = ChangeSet::default();
        changes.changed_domains.insert("Person".to_string());
        assert!(changes.affects(&deps));

        let mut unrelated = ChangeSet::default();
        unrelated.changed_domains.insert("City".to_string());
        assert!(!unrelated.affects(&deps));
    }

    #[test]
    fn test_changeset_config_change_affects_everything() {
        let deps = ExpressionDependencies::new();
        let changes = ChangeSet {
            config_changed: true,
            ..ChangeSet::default()
        };
        assert!(changes.has_changes());
        assert!(changes.affects(&deps));
    }

    #[test]
    fn test_incremental_stats() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);

        let mut compiler = IncrementalCompiler::new(ctx);

        let expr1 = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
        let expr2 = TLExpr::pred("likes", vec![Term::var("x"), Term::var("z")]);

        compiler.compile(&expr1).unwrap();
        compiler.compile(&expr1).unwrap(); // Should be cache hit
        compiler.compile(&expr2).unwrap();

        let stats = compiler.stats();
        assert_eq!(stats.total_compilations(), 3);
        // At least one cache hit from the second expr1 compilation
        assert!(
            stats.cache_hits >= 1,
            "Expected at least 1 cache hit, got {}",
            stats.cache_hits
        );
        // Hit rate should be positive if we have cache hits
        assert!(
            stats.hit_rate() > 0.0,
            "Expected positive hit rate, got {}",
            stats.hit_rate()
        );
    }

    #[test]
    fn test_complex_expression_dependencies() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);

        let expr = TLExpr::exists(
            "x",
            "Person",
            TLExpr::and(
                TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]),
                TLExpr::pred("likes", vec![Term::var("x"), Term::var("z")]),
            ),
        );

        let deps = ExpressionDependencies::analyze(&expr, &ctx);

        assert!(deps.predicates.contains("knows"));
        assert!(deps.predicates.contains("likes"));
        assert!(deps.variables.contains("x"));
        assert!(deps.variables.contains("y"));
        assert!(deps.variables.contains("z"));
        assert!(deps.domains.contains("Person"));
    }
}