Skip to main content

tensorlogic_compiler/
incremental.rs

1//! Incremental compilation for efficient recompilation when expressions change.
2//!
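//! This module provides an incremental compilation system that caches compiled
//! expressions together with their dependencies (predicates, variables, domains,
//! configuration). An expression is recompiled only when it is not in the cache or
//! when a change to the compilation context has invalidated its cached result. This
//! is useful for interactive environments such as REPLs, notebooks, and IDEs, where
//! the same expressions are compiled repeatedly while the context evolves.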
7//!
8//! # Architecture
9//!
10//! The incremental compilation system consists of three main components:
11//!
//! 1. **ExpressionDependencies**: Records what each expression depends on
//!    (predicates, variables, domains, and a hash of the configuration).
//!
//! 2. **ChangeDetector**: Snapshots the compilation context and reports which domains
//!    were added, resized, or removed, and whether the configuration changed.
//!    Predicate-signature tracking is not implemented yet, so the predicate fields of
//!    `ChangeSet` are currently always empty.
//!
//! 3. **IncrementalCompiler**: Caches compiled graphs keyed by the whole expression,
//!    evicts entries whose dependencies are affected by a detected change, and runs
//!    the compiler only for expressions that miss the cache.
20//!
21//! # Example
22//!
23//! ```rust
24//! use tensorlogic_compiler::{CompilerContext, incremental::IncrementalCompiler};
25//! use tensorlogic_ir::{TLExpr, Term};
26//!
27//! let mut ctx = CompilerContext::new();
28//! ctx.add_domain("Person", 100);
29//!
30//! let mut compiler = IncrementalCompiler::new(ctx);
31//!
32//! // Initial compilation
33//! let expr1 = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
34//! let graph1 = compiler.compile(&expr1).unwrap();
35//!
//! // Compile a different expression. The cache is keyed by the whole expression,
//! // so this is a cache miss; compiling `expr1` again would be a cache hit.
37//! let expr2 = TLExpr::and(
38//!     TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]),
39//!     TLExpr::pred("likes", vec![Term::var("x"), Term::var("z")]),
40//! );
41//! let graph2 = compiler.compile(&expr2).unwrap();
42//!
43//! // Check incremental compilation stats
44//! let stats = compiler.stats();
45//! println!("Nodes reused: {}", stats.nodes_reused);
46//! println!("Nodes compiled: {}", stats.nodes_compiled);
47//! println!("Reuse rate: {:.1}%", stats.reuse_rate() * 100.0);
48//! ```
49
50use crate::{compile_to_einsum_with_context, CompilerContext};
51use std::collections::{HashMap, HashSet};
52use std::sync::{Arc, Mutex};
53use tensorlogic_ir::{EinsumGraph, IrError, TLExpr, Term};
54
55/// Tracks dependencies of compiled expressions.
56#[derive(Debug, Clone, PartialEq, Eq)]
57pub struct ExpressionDependencies {
58    /// Predicates referenced in the expression
59    pub predicates: HashSet<String>,
60    /// Variables referenced in the expression
61    pub variables: HashSet<String>,
62    /// Domains used in the expression
63    pub domains: HashSet<String>,
64    /// Configuration hash (to detect strategy changes)
65    pub config_hash: u64,
66}
67
68impl ExpressionDependencies {
69    /// Create a new empty dependency set.
70    pub fn new() -> Self {
71        Self {
72            predicates: HashSet::new(),
73            variables: HashSet::new(),
74            domains: HashSet::new(),
75            config_hash: 0,
76        }
77    }
78
79    /// Analyze an expression and extract its dependencies.
80    pub fn analyze(expr: &TLExpr, ctx: &CompilerContext) -> Self {
81        let mut deps = Self::new();
82        deps.analyze_recursive(expr);
83        deps.config_hash = Self::hash_config(ctx);
84        deps
85    }
86
87    fn analyze_recursive(&mut self, expr: &TLExpr) {
88        match expr {
89            TLExpr::Pred { name, args } => {
90                self.predicates.insert(name.clone());
91                for arg in args {
92                    self.analyze_term(arg);
93                }
94            }
95            TLExpr::And(left, right) | TLExpr::Or(left, right) | TLExpr::Imply(left, right) => {
96                self.analyze_recursive(left);
97                self.analyze_recursive(right);
98            }
99            TLExpr::Not(inner) => {
100                self.analyze_recursive(inner);
101            }
102            TLExpr::Exists { var, domain, body } | TLExpr::ForAll { var, domain, body } => {
103                self.variables.insert(var.clone());
104                self.domains.insert(domain.clone());
105                self.analyze_recursive(body);
106            }
107            TLExpr::Score(inner) => {
108                self.analyze_recursive(inner);
109            }
110            TLExpr::Add(left, right)
111            | TLExpr::Sub(left, right)
112            | TLExpr::Mul(left, right)
113            | TLExpr::Div(left, right) => {
114                self.analyze_recursive(left);
115                self.analyze_recursive(right);
116            }
117            TLExpr::Eq(left, right)
118            | TLExpr::Lt(left, right)
119            | TLExpr::Gt(left, right)
120            | TLExpr::Lte(left, right)
121            | TLExpr::Gte(left, right) => {
122                self.analyze_recursive(left);
123                self.analyze_recursive(right);
124            }
125            TLExpr::IfThenElse {
126                condition,
127                then_branch,
128                else_branch,
129            } => {
130                self.analyze_recursive(condition);
131                self.analyze_recursive(then_branch);
132                self.analyze_recursive(else_branch);
133            }
134            TLExpr::Aggregate {
135                op: _,
136                var,
137                domain,
138                body,
139                group_by,
140            } => {
141                self.variables.insert(var.clone());
142                self.domains.insert(domain.clone());
143                self.analyze_recursive(body);
144                if let Some(gb_vars) = group_by {
145                    for var_name in gb_vars {
146                        self.variables.insert(var_name.clone());
147                    }
148                }
149            }
150            TLExpr::TNorm {
151                kind: _,
152                left,
153                right,
154            }
155            | TLExpr::TCoNorm {
156                kind: _,
157                left,
158                right,
159            } => {
160                self.analyze_recursive(left);
161                self.analyze_recursive(right);
162            }
163            TLExpr::FuzzyNot {
164                kind: _,
165                expr: inner,
166            } => {
167                self.analyze_recursive(inner);
168            }
169            TLExpr::FuzzyImplication {
170                kind: _,
171                premise,
172                conclusion,
173            } => {
174                self.analyze_recursive(premise);
175                self.analyze_recursive(conclusion);
176            }
177            TLExpr::SoftExists {
178                var,
179                domain,
180                body,
181                temperature: _,
182            }
183            | TLExpr::SoftForAll {
184                var,
185                domain,
186                body,
187                temperature: _,
188            } => {
189                self.variables.insert(var.clone());
190                self.domains.insert(domain.clone());
191                self.analyze_recursive(body);
192            }
193            TLExpr::WeightedRule { weight: _, rule } => {
194                self.analyze_recursive(rule);
195            }
196            TLExpr::ProbabilisticChoice { alternatives } => {
197                for (_, alt) in alternatives {
198                    self.analyze_recursive(alt);
199                }
200            }
201            TLExpr::Let { var, value, body } => {
202                self.variables.insert(var.clone());
203                self.analyze_recursive(value);
204                self.analyze_recursive(body);
205            }
206            TLExpr::Box(inner)
207            | TLExpr::Diamond(inner)
208            | TLExpr::Next(inner)
209            | TLExpr::Eventually(inner)
210            | TLExpr::Always(inner) => {
211                self.analyze_recursive(inner);
212            }
213            TLExpr::Until { before, after } | TLExpr::WeakUntil { before, after } => {
214                self.analyze_recursive(before);
215                self.analyze_recursive(after);
216            }
217            TLExpr::Release { released, releaser }
218            | TLExpr::StrongRelease { released, releaser } => {
219                self.analyze_recursive(released);
220                self.analyze_recursive(releaser);
221            }
222            // Math operations
223            TLExpr::Abs(inner)
224            | TLExpr::Sqrt(inner)
225            | TLExpr::Exp(inner)
226            | TLExpr::Log(inner)
227            | TLExpr::Sin(inner)
228            | TLExpr::Cos(inner)
229            | TLExpr::Tan(inner)
230            | TLExpr::Floor(inner)
231            | TLExpr::Ceil(inner)
232            | TLExpr::Round(inner) => {
233                self.analyze_recursive(inner);
234            }
235            TLExpr::Pow(left, right)
236            | TLExpr::Min(left, right)
237            | TLExpr::Max(left, right)
238            | TLExpr::Mod(left, right) => {
239                self.analyze_recursive(left);
240                self.analyze_recursive(right);
241            }
242            TLExpr::Constant(_) => {
243                // No dependencies
244            }
245            // All other expression types (enhancements)
246            _ => {
247                // For unhandled variants, recurse on any child expressions if needed
248                // This is a catch-all for future-proof compilation
249            }
250        }
251    }
252
253    fn analyze_term(&mut self, term: &Term) {
254        if let Term::Var(name) = term {
255            self.variables.insert(name.clone());
256        }
257    }
258
259    fn hash_config(ctx: &CompilerContext) -> u64 {
260        use std::collections::hash_map::DefaultHasher;
261        use std::hash::{Hash, Hasher};
262
263        let mut hasher = DefaultHasher::new();
264        // Hash the config strategies
265        format!("{:?}", ctx.config).hash(&mut hasher);
266        hasher.finish()
267    }
268}
269
270impl Default for ExpressionDependencies {
271    fn default() -> Self {
272        Self::new()
273    }
274}
275
276/// Detects changes to the compilation context.
277#[derive(Debug, Clone)]
278pub struct ChangeDetector {
279    /// Previous predicate signatures
280    previous_predicates: HashMap<String, (usize, Vec<String>)>,
281    /// Previous domain sizes
282    previous_domains: HashMap<String, usize>,
283    /// Previous configuration hash
284    previous_config_hash: u64,
285}
286
287impl ChangeDetector {
288    /// Create a new change detector.
289    pub fn new() -> Self {
290        Self {
291            previous_predicates: HashMap::new(),
292            previous_domains: HashMap::new(),
293            previous_config_hash: 0,
294        }
295    }
296
297    /// Update the snapshot from the current context.
298    pub fn update(&mut self, ctx: &CompilerContext) {
299        self.previous_predicates.clear();
300        self.previous_domains.clear();
301
302        // Snapshot domains
303        for (name, info) in &ctx.domains {
304            self.previous_domains.insert(name.clone(), info.cardinality);
305        }
306
307        self.previous_config_hash = ExpressionDependencies::hash_config(ctx);
308    }
309
310    /// Detect changes and return affected predicates and domains.
311    pub fn detect_changes(&self, ctx: &CompilerContext) -> ChangeSet {
312        let mut changes = ChangeSet::new();
313
314        // Check domain changes
315        for (name, info) in &ctx.domains {
316            if let Some(&prev_size) = self.previous_domains.get(name.as_str()) {
317                if prev_size != info.cardinality {
318                    changes.changed_domains.insert(name.clone());
319                }
320            } else {
321                changes.new_domains.insert(name.clone());
322            }
323        }
324
325        // Check for removed domains
326        for name in self.previous_domains.keys() {
327            if !ctx.domains.contains_key(name) {
328                changes.removed_domains.insert(name.clone());
329            }
330        }
331
332        // Check configuration changes
333        let current_hash = ExpressionDependencies::hash_config(ctx);
334        if current_hash != self.previous_config_hash {
335            changes.config_changed = true;
336        }
337
338        changes
339    }
340}
341
342impl Default for ChangeDetector {
343    fn default() -> Self {
344        Self::new()
345    }
346}
347
348/// Describes what has changed in the compilation context.
349#[derive(Debug, Clone, Default)]
350pub struct ChangeSet {
351    /// Predicates that were added
352    pub new_predicates: HashSet<String>,
353    /// Predicates that were modified
354    pub changed_predicates: HashSet<String>,
355    /// Predicates that were removed
356    pub removed_predicates: HashSet<String>,
357    /// Domains that were added
358    pub new_domains: HashSet<String>,
359    /// Domains that were modified
360    pub changed_domains: HashSet<String>,
361    /// Domains that were removed
362    pub removed_domains: HashSet<String>,
363    /// Whether the configuration changed
364    pub config_changed: bool,
365}
366
367impl ChangeSet {
368    fn new() -> Self {
369        Self::default()
370    }
371
372    /// Check if there are any changes.
373    pub fn has_changes(&self) -> bool {
374        !self.new_predicates.is_empty()
375            || !self.changed_predicates.is_empty()
376            || !self.removed_predicates.is_empty()
377            || !self.new_domains.is_empty()
378            || !self.changed_domains.is_empty()
379            || !self.removed_domains.is_empty()
380            || self.config_changed
381    }
382
383    /// Check if a dependency set is affected by these changes.
384    pub fn affects(&self, deps: &ExpressionDependencies) -> bool {
385        // Config changes affect everything
386        if self.config_changed {
387            return true;
388        }
389
390        // Check if any used predicate changed
391        for pred in &deps.predicates {
392            if self.changed_predicates.contains(pred) || self.removed_predicates.contains(pred) {
393                return true;
394            }
395        }
396
397        // Check if any used domain changed
398        for domain in &deps.domains {
399            if self.changed_domains.contains(domain) || self.removed_domains.contains(domain) {
400                return true;
401            }
402        }
403
404        false
405    }
406}
407
408/// Entry in the incremental compilation cache.
409#[derive(Debug, Clone)]
410struct CacheEntry {
411    /// The compiled graph
412    graph: EinsumGraph,
413    /// Dependencies of this expression
414    dependencies: ExpressionDependencies,
415    /// When this was compiled (for LRU eviction)
416    #[allow(dead_code)]
417    timestamp: u64,
418}
419
420/// Incremental compiler that reuses previously compiled expressions.
421pub struct IncrementalCompiler {
422    /// Compilation context
423    context: CompilerContext,
424    /// Cache of compiled expressions
425    cache: Arc<Mutex<HashMap<String, CacheEntry>>>,
426    /// Change detector
427    change_detector: ChangeDetector,
428    /// Statistics
429    stats: Arc<Mutex<IncrementalStats>>,
430    /// Next timestamp for LRU
431    next_timestamp: Arc<Mutex<u64>>,
432}
433
434impl IncrementalCompiler {
435    /// Create a new incremental compiler with the given context.
436    pub fn new(context: CompilerContext) -> Self {
437        let mut change_detector = ChangeDetector::new();
438        change_detector.update(&context);
439
440        Self {
441            context,
442            cache: Arc::new(Mutex::new(HashMap::new())),
443            change_detector,
444            stats: Arc::new(Mutex::new(IncrementalStats::default())),
445            next_timestamp: Arc::new(Mutex::new(0)),
446        }
447    }
448
449    /// Get the compilation context.
450    pub fn context(&self) -> &CompilerContext {
451        &self.context
452    }
453
454    /// Get a mutable reference to the compilation context.
455    pub fn context_mut(&mut self) -> &mut CompilerContext {
456        &mut self.context
457    }
458
459    /// Compile an expression incrementally.
460    pub fn compile(&mut self, expr: &TLExpr) -> Result<EinsumGraph, IrError> {
461        // Detect changes since last compilation
462        let changes = self.change_detector.detect_changes(&self.context);
463
464        // Invalidate affected entries if there are changes
465        if changes.has_changes() {
466            self.invalidate_affected(&changes);
467            self.change_detector.update(&self.context);
468        }
469
470        // Try to get from cache
471        let expr_key = format!("{:?}", expr);
472        let cache = self.cache.lock().unwrap();
473
474        if let Some(entry) = cache.get(&expr_key) {
475            // Cache hit!
476            let mut stats = self.stats.lock().unwrap();
477            stats.cache_hits += 1;
478            stats.nodes_reused += entry.graph.nodes.len();
479            drop(stats);
480
481            return Ok(entry.graph.clone());
482        }
483
484        // Cache miss - compile from scratch
485        drop(cache);
486
487        let deps = ExpressionDependencies::analyze(expr, &self.context);
488        // Compile from scratch - we can't use ? here because anyhow::Error doesn't convert to IrError
489        // So we'll just propagate the error as InvalidEinsumSpec
490        let graph = compile_to_einsum_with_context(expr, &mut self.context).map_err(|e| {
491            IrError::InvalidEinsumSpec {
492                spec: format!("{:?}", expr),
493                reason: format!("Compilation failed: {}", e),
494            }
495        })?;
496
497        // Update stats
498        let mut stats = self.stats.lock().unwrap();
499        stats.cache_misses += 1;
500        stats.nodes_compiled += graph.nodes.len();
501        drop(stats);
502
503        // Store in cache
504        let mut timestamp_guard = self.next_timestamp.lock().unwrap();
505        let timestamp = *timestamp_guard;
506        *timestamp_guard += 1;
507        drop(timestamp_guard);
508
509        let mut cache = self.cache.lock().unwrap();
510        cache.insert(
511            expr_key,
512            CacheEntry {
513                graph: graph.clone(),
514                dependencies: deps,
515                timestamp,
516            },
517        );
518
519        Ok(graph)
520    }
521
522    /// Invalidate cache entries affected by changes.
523    fn invalidate_affected(&mut self, changes: &ChangeSet) {
524        let mut cache = self.cache.lock().unwrap();
525        cache.retain(|_, entry| !changes.affects(&entry.dependencies));
526
527        let mut stats = self.stats.lock().unwrap();
528        stats.invalidations += 1;
529    }
530
531    /// Clear the cache.
532    pub fn clear_cache(&mut self) {
533        let mut cache = self.cache.lock().unwrap();
534        cache.clear();
535    }
536
537    /// Get incremental compilation statistics.
538    pub fn stats(&self) -> IncrementalStats {
539        self.stats.lock().unwrap().clone()
540    }
541
542    /// Reset statistics.
543    pub fn reset_stats(&mut self) {
544        let mut stats = self.stats.lock().unwrap();
545        *stats = IncrementalStats::default();
546    }
547}
548
549/// Statistics for incremental compilation.
550#[derive(Debug, Clone, Default)]
551pub struct IncrementalStats {
552    /// Number of cache hits
553    pub cache_hits: usize,
554    /// Number of cache misses
555    pub cache_misses: usize,
556    /// Number of invalidations
557    pub invalidations: usize,
558    /// Number of nodes reused from cache
559    pub nodes_reused: usize,
560    /// Number of nodes freshly compiled
561    pub nodes_compiled: usize,
562}
563
564impl IncrementalStats {
565    /// Get the cache hit rate (0.0 to 1.0).
566    pub fn hit_rate(&self) -> f64 {
567        let total = self.cache_hits + self.cache_misses;
568        if total == 0 {
569            0.0
570        } else {
571            self.cache_hits as f64 / total as f64
572        }
573    }
574
575    /// Get the reuse rate for nodes (0.0 to 1.0).
576    pub fn reuse_rate(&self) -> f64 {
577        let total = self.nodes_reused + self.nodes_compiled;
578        if total == 0 {
579            0.0
580        } else {
581            self.nodes_reused as f64 / total as f64
582        }
583    }
584
585    /// Get the total number of compilations.
586    pub fn total_compilations(&self) -> usize {
587        self.cache_hits + self.cache_misses
588    }
589}
590
591#[cfg(test)]
592mod tests {
593    use super::*;
594
595    #[test]
596    fn test_dependency_tracking() {
597        let mut ctx = CompilerContext::new();
598        ctx.add_domain("Person", 100);
599
600        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
601        let deps = ExpressionDependencies::analyze(&expr, &ctx);
602
603        assert!(deps.predicates.contains("knows"));
604        assert!(deps.variables.contains("x"));
605        assert!(deps.variables.contains("y"));
606    }
607
608    #[test]
609    fn test_incremental_compilation_reuse() {
610        let mut ctx = CompilerContext::new();
611        ctx.add_domain("Person", 100);
612
613        let mut compiler = IncrementalCompiler::new(ctx);
614
615        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
616
617        // First compilation
618        let _graph1 = compiler.compile(&expr).unwrap();
619        assert_eq!(compiler.stats().cache_misses, 1);
620        assert_eq!(compiler.stats().cache_hits, 0);
621
622        // Second compilation - should hit cache
623        let _graph2 = compiler.compile(&expr).unwrap();
624        assert_eq!(compiler.stats().cache_misses, 1);
625        assert_eq!(compiler.stats().cache_hits, 1);
626        assert_eq!(compiler.stats().hit_rate(), 0.5);
627    }
628
629    #[test]
630    fn test_change_detection_domain() {
631        let mut ctx = CompilerContext::new();
632        ctx.add_domain("Person", 100);
633
634        let mut detector = ChangeDetector::new();
635        detector.update(&ctx);
636
637        // No changes initially
638        let changes = detector.detect_changes(&ctx);
639        assert!(!changes.has_changes());
640
641        // Change domain size
642        ctx.add_domain("Person", 200);
643        let changes = detector.detect_changes(&ctx);
644        assert!(changes.has_changes());
645        assert!(changes.changed_domains.contains("Person"));
646    }
647
648    #[test]
649    fn test_invalidation_on_domain_change() {
650        let mut ctx = CompilerContext::new();
651        ctx.add_domain("Person", 100);
652
653        let mut compiler = IncrementalCompiler::new(ctx);
654
655        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
656
657        // First compilation
658        let _graph1 = compiler.compile(&expr).unwrap();
659        assert_eq!(compiler.stats().cache_misses, 1);
660
661        // Change domain
662        compiler.context_mut().add_domain("Person", 200);
663
664        // Should recompile due to domain change and invalidate cache
665        let _graph2 = compiler.compile(&expr).unwrap();
666        // After invalidation, this is another cache miss
667        assert!(compiler.stats().cache_misses >= 1);
668        assert!(compiler.stats().invalidations >= 1);
669    }
670
671    #[test]
672    fn test_incremental_stats() {
673        let mut ctx = CompilerContext::new();
674        ctx.add_domain("Person", 100);
675
676        let mut compiler = IncrementalCompiler::new(ctx);
677
678        let expr1 = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
679        let expr2 = TLExpr::pred("likes", vec![Term::var("x"), Term::var("z")]);
680
681        compiler.compile(&expr1).unwrap();
682        compiler.compile(&expr1).unwrap(); // Should be cache hit
683        compiler.compile(&expr2).unwrap();
684
685        let stats = compiler.stats();
686        assert_eq!(stats.total_compilations(), 3);
687        // At least one cache hit from the second expr1 compilation
688        assert!(
689            stats.cache_hits >= 1,
690            "Expected at least 1 cache hit, got {}",
691            stats.cache_hits
692        );
693        // Hit rate should be positive if we have cache hits
694        assert!(
695            stats.hit_rate() > 0.0,
696            "Expected positive hit rate, got {}",
697            stats.hit_rate()
698        );
699    }
700
701    #[test]
702    fn test_complex_expression_dependencies() {
703        let mut ctx = CompilerContext::new();
704        ctx.add_domain("Person", 100);
705
706        let expr = TLExpr::exists(
707            "x",
708            "Person",
709            TLExpr::and(
710                TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]),
711                TLExpr::pred("likes", vec![Term::var("x"), Term::var("z")]),
712            ),
713        );
714
715        let deps = ExpressionDependencies::analyze(&expr, &ctx);
716
717        assert!(deps.predicates.contains("knows"));
718        assert!(deps.predicates.contains("likes"));
719        assert!(deps.variables.contains("x"));
720        assert!(deps.variables.contains("y"));
721        assert!(deps.variables.contains("z"));
722        assert!(deps.domains.contains("Person"));
723    }
724}