Skip to main content

tensorlogic_infer/
rewrite.rs

1//! Graph rewriting engine for pattern-based optimizations.
2//!
3//! This module provides a powerful graph rewriting system:
4//! - **Pattern matching**: Find subgraphs matching specific patterns
5//! - **Rewrite rules**: Transform matched patterns into optimized equivalents
6//! - **Rule application**: Apply rules systematically with strategies
7//! - **Correctness checking**: Validate rewrites preserve semantics
8//! - **Performance tracking**: Measure impact of rewrites
9//!
10//! ## Example
11//!
//! ```rust,ignore
//! use tensorlogic_infer::{Pattern, RewriteEngine, RewriteError, RewriteRule, RewriteStrategy};
//!
//! // Define a rewrite rule: A + 0 -> A (capture `A` so the replacement can return it)
//! let rule = RewriteRule::new("eliminate_add_zero")
//!     .with_pattern(Pattern::binary_op("add", Pattern::variable("a"), Pattern::zero()))
//!     .with_replacement(|matched| {
//!         matched
//!             .get_capture("a")
//!             .ok_or_else(|| RewriteError::ApplicationFailed("operand not captured".into()))
//!     });
//!
//! // Create rewrite engine
//! let mut engine = RewriteEngine::new()
//!     .add_rule(rule)
//!     .with_strategy(RewriteStrategy::Exhaustive);
//!
//! // Apply rewrites (the current engine operates on a simplified node-count model)
//! let remaining_nodes = engine.rewrite_simple(node_count)?;
//! println!("Applied {} rewrites", engine.stats().rewrites_applied);
//! ```
29
30use serde::{Deserialize, Serialize};
31use std::collections::{HashMap, HashSet};
32use thiserror::Error;
33
/// Graph rewriting errors.
///
/// The `Display` text of each variant is generated by `thiserror` from its
/// `#[error(...)]` attribute.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum RewriteError {
    /// Pattern matching failed; the payload describes why.
    #[error("Pattern matching failed: {0}")]
    PatternMatchFailed(String),

    /// A rewrite rule is malformed; the payload describes the problem.
    #[error("Invalid rewrite rule: {0}")]
    InvalidRule(String),

    /// A rule's replacement could not be applied; the payload describes why.
    #[error("Rewrite application failed: {0}")]
    ApplicationFailed(String),

    /// Returned by the engine's exhaustive strategy when rewriting has not
    /// converged within its pass limit.
    #[error("Cycle detected in rewrite application")]
    CycleDetected,

    /// A rewrite failed semantics verification; the payload describes why.
    #[error("Semantics verification failed: {0}")]
    SemanticsViolation(String),
}
52
/// Node identifier in the computation graph.
///
/// A plain `usize` alias; nothing in this module checks that an id refers to
/// an existing node.
pub type NodeId = usize;
55
/// Graph pattern for matching.
///
/// Patterns form a tree: `BinaryOp`, `UnaryOp`, `Sequence` and `Alternative`
/// hold nested sub-patterns, every other variant is a leaf.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Pattern {
    /// Match any node
    Any,

    /// Match a specific operation, identified by name
    Op(String),

    /// Match a binary operation with two subpatterns
    BinaryOp {
        /// Operation name (e.g. `"add"`, `"mul"`)
        op: String,
        /// Pattern the left operand must match
        left: Box<Pattern>,
        /// Pattern the right operand must match
        right: Box<Pattern>,
    },

    /// Match a unary operation (`op`) whose operand matches `operand`
    UnaryOp { op: String, operand: Box<Pattern> },

    /// Match a constant value
    Constant(f64),

    /// Match zero
    Zero,

    /// Match one
    One,

    /// Match a variable (captures the matched node under the given name)
    Variable(String),

    /// Match a sequence of operations
    Sequence(Vec<Pattern>),

    /// Match any of the given patterns
    Alternative(Vec<Pattern>),
}
93
94impl Pattern {
95    /// Create a pattern matching any node.
96    pub fn any() -> Self {
97        Pattern::Any
98    }
99
100    /// Create a pattern matching a specific operation.
101    pub fn op(name: impl Into<String>) -> Self {
102        Pattern::Op(name.into())
103    }
104
105    /// Create a pattern matching a binary operation.
106    pub fn binary_op(op: impl Into<String>, left: Pattern, right: Pattern) -> Self {
107        Pattern::BinaryOp {
108            op: op.into(),
109            left: Box::new(left),
110            right: Box::new(right),
111        }
112    }
113
114    /// Create a pattern matching a unary operation.
115    pub fn unary_op(op: impl Into<String>, operand: Pattern) -> Self {
116        Pattern::UnaryOp {
117            op: op.into(),
118            operand: Box::new(operand),
119        }
120    }
121
122    /// Create a pattern matching a constant.
123    pub fn constant(value: f64) -> Self {
124        Pattern::Constant(value)
125    }
126
127    /// Create a pattern matching zero.
128    pub fn zero() -> Self {
129        Pattern::Zero
130    }
131
132    /// Create a pattern matching one.
133    pub fn one() -> Self {
134        Pattern::One
135    }
136
137    /// Create a pattern matching a variable.
138    pub fn variable(name: impl Into<String>) -> Self {
139        Pattern::Variable(name.into())
140    }
141}
142
143/// A matched pattern instance.
144#[derive(Debug, Clone, PartialEq)]
145pub struct Match {
146    /// Root node of the match
147    pub root: NodeId,
148
149    /// Captured variables
150    pub captures: HashMap<String, NodeId>,
151
152    /// Matched nodes
153    pub matched_nodes: HashSet<NodeId>,
154}
155
156impl Match {
157    /// Create a new match.
158    pub fn new(root: NodeId) -> Self {
159        let mut matched_nodes = HashSet::new();
160        matched_nodes.insert(root);
161
162        Self {
163            root,
164            captures: HashMap::new(),
165            matched_nodes,
166        }
167    }
168
169    /// Get a captured node by variable name.
170    pub fn get_capture(&self, name: &str) -> Option<NodeId> {
171        self.captures.get(name).copied()
172    }
173
174    /// Add a capture.
175    pub fn with_capture(mut self, name: String, node: NodeId) -> Self {
176        self.captures.insert(name, node);
177        self.matched_nodes.insert(node);
178        self
179    }
180
181    /// Get all matched nodes.
182    pub fn nodes(&self) -> &HashSet<NodeId> {
183        &self.matched_nodes
184    }
185}
186
187/// Rewrite rule replacement function.
188pub type ReplacementFn = Box<dyn Fn(&Match) -> Result<NodeId, RewriteError>>;
189
190/// A graph rewrite rule.
191pub struct RewriteRule {
192    /// Rule name
193    pub name: String,
194
195    /// Pattern to match
196    pub pattern: Pattern,
197
198    /// Replacement function
199    pub replacement: ReplacementFn,
200
201    /// Rule priority (higher = applied first)
202    pub priority: i32,
203
204    /// Whether this rule preserves semantics
205    pub preserves_semantics: bool,
206}
207
208impl RewriteRule {
209    /// Create a new rewrite rule.
210    pub fn new(name: impl Into<String>) -> Self {
211        Self {
212            name: name.into(),
213            pattern: Pattern::Any,
214            replacement: Box::new(|m| Ok(m.root)),
215            priority: 0,
216            preserves_semantics: true,
217        }
218    }
219
220    /// Set the pattern to match.
221    pub fn with_pattern(mut self, pattern: Pattern) -> Self {
222        self.pattern = pattern;
223        self
224    }
225
226    /// Set the replacement function.
227    pub fn with_replacement<F>(mut self, f: F) -> Self
228    where
229        F: Fn(&Match) -> Result<NodeId, RewriteError> + 'static,
230    {
231        self.replacement = Box::new(f);
232        self
233    }
234
235    /// Set the rule priority.
236    pub fn with_priority(mut self, priority: i32) -> Self {
237        self.priority = priority;
238        self
239    }
240
241    /// Mark whether this rule preserves semantics.
242    pub fn with_semantics_preservation(mut self, preserves: bool) -> Self {
243        self.preserves_semantics = preserves;
244        self
245    }
246}
247
248/// Strategy for applying rewrite rules.
249#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
250pub enum RewriteStrategy {
251    /// Apply each rule once to each node
252    OnePass,
253
254    /// Apply rules until no more matches found
255    Exhaustive,
256
257    /// Apply rules in a fixed-point manner (until convergence)
258    FixedPoint { max_iterations: usize },
259
260    /// Apply rules in order of priority
261    Prioritized,
262
263    /// Apply rules bottom-up (from leaves to root)
264    BottomUp,
265
266    /// Apply rules top-down (from root to leaves)
267    TopDown,
268}
269
270impl Default for RewriteStrategy {
271    fn default() -> Self {
272        RewriteStrategy::Exhaustive
273    }
274}
275
276/// Graph rewriting engine.
277pub struct RewriteEngine {
278    /// Rewrite rules
279    rules: Vec<RewriteRule>,
280
281    /// Application strategy
282    strategy: RewriteStrategy,
283
284    /// Statistics
285    stats: RewriteStats,
286
287    /// Enable verification
288    verify_semantics: bool,
289}
290
291impl RewriteEngine {
292    /// Create a new rewrite engine.
293    pub fn new() -> Self {
294        Self {
295            rules: Vec::new(),
296            strategy: RewriteStrategy::default(),
297            stats: RewriteStats::default(),
298            verify_semantics: false,
299        }
300    }
301
302    /// Add a rewrite rule.
303    pub fn add_rule(mut self, rule: RewriteRule) -> Self {
304        self.rules.push(rule);
305        self
306    }
307
308    /// Set the rewrite strategy.
309    pub fn with_strategy(mut self, strategy: RewriteStrategy) -> Self {
310        self.strategy = strategy;
311        self
312    }
313
314    /// Enable or disable semantics verification.
315    pub fn with_verification(mut self, enabled: bool) -> Self {
316        self.verify_semantics = enabled;
317        self
318    }
319
320    /// Get rewrite statistics.
321    pub fn stats(&self) -> &RewriteStats {
322        &self.stats
323    }
324
325    /// Reset statistics.
326    pub fn reset_stats(&mut self) {
327        self.stats = RewriteStats::default();
328    }
329
330    /// Sort rules by priority.
331    fn sort_rules_by_priority(&mut self) {
332        self.rules.sort_by(|a, b| b.priority.cmp(&a.priority));
333    }
334
335    /// Apply rewrites to a simplified graph representation.
336    /// In a real implementation, this would work with the actual EinsumGraph.
337    pub fn rewrite_simple(&mut self, node_count: usize) -> Result<usize, RewriteError> {
338        self.stats.graphs_processed += 1;
339
340        match self.strategy {
341            RewriteStrategy::OnePass => self.apply_one_pass(node_count),
342            RewriteStrategy::Exhaustive => self.apply_exhaustive(node_count),
343            RewriteStrategy::FixedPoint { max_iterations } => {
344                self.apply_fixed_point(node_count, max_iterations)
345            }
346            RewriteStrategy::Prioritized => {
347                self.sort_rules_by_priority();
348                self.apply_one_pass(node_count)
349            }
350            RewriteStrategy::BottomUp | RewriteStrategy::TopDown => self.apply_one_pass(node_count),
351        }
352    }
353
354    fn apply_one_pass(&mut self, node_count: usize) -> Result<usize, RewriteError> {
355        let mut rewrites = 0;
356
357        // Simplified: just count how many rules could apply
358        for rule in &self.rules {
359            // In real implementation, would match pattern and apply replacement
360            if self.can_apply_rule(rule, node_count) {
361                rewrites += 1;
362                self.stats.rewrites_applied += 1;
363                self.stats
364                    .rule_applications
365                    .entry(rule.name.clone())
366                    .and_modify(|c| *c += 1)
367                    .or_insert(1);
368            }
369        }
370
371        Ok(node_count.saturating_sub(rewrites))
372    }
373
374    fn apply_exhaustive(&mut self, mut node_count: usize) -> Result<usize, RewriteError> {
375        let mut iteration = 0;
376        let max_iterations = 100; // Safety limit
377
378        loop {
379            iteration += 1;
380            if iteration > max_iterations {
381                return Err(RewriteError::CycleDetected);
382            }
383
384            let before = node_count;
385            node_count = self.apply_one_pass(node_count)?;
386
387            if node_count == before {
388                // Converged
389                break;
390            }
391        }
392
393        Ok(node_count)
394    }
395
396    fn apply_fixed_point(
397        &mut self,
398        mut node_count: usize,
399        max_iterations: usize,
400    ) -> Result<usize, RewriteError> {
401        for iteration in 0..max_iterations {
402            let before = node_count;
403            node_count = self.apply_one_pass(node_count)?;
404
405            if node_count == before {
406                self.stats.fixed_point_iterations = iteration + 1;
407                break;
408            }
409        }
410
411        Ok(node_count)
412    }
413
414    fn can_apply_rule(&self, _rule: &RewriteRule, _node_count: usize) -> bool {
415        // Simplified: in real implementation, would match pattern
416        true
417    }
418}
419
420impl Default for RewriteEngine {
421    fn default() -> Self {
422        Self::new()
423    }
424}
425
/// Rewrite statistics.
///
/// Accumulated by `RewriteEngine` across calls until `reset_stats` replaces
/// them with the default (all zero, no rule entries).
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct RewriteStats {
    /// Number of graphs processed (one per `rewrite_simple` call)
    pub graphs_processed: usize,

    /// Total rewrites applied, summed over every pass and rule
    pub rewrites_applied: usize,

    /// Applications per rule, keyed by rule name
    pub rule_applications: HashMap<String, usize>,

    /// Pass count recorded by the fixed-point strategy
    pub fixed_point_iterations: usize,

    /// Number of semantics-verification failures
    pub verification_failures: usize,
}
444
445impl std::fmt::Display for RewriteStats {
446    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
447        writeln!(f, "Rewrite Statistics")?;
448        writeln!(f, "==================")?;
449        writeln!(f, "Graphs processed:     {}", self.graphs_processed)?;
450        writeln!(f, "Rewrites applied:     {}", self.rewrites_applied)?;
451        writeln!(f, "Fixed-point iters:    {}", self.fixed_point_iterations)?;
452        writeln!(f, "Verification fails:   {}", self.verification_failures)?;
453
454        if !self.rule_applications.is_empty() {
455            writeln!(f, "\nRule Applications:")?;
456            let mut rules: Vec<_> = self.rule_applications.iter().collect();
457            rules.sort_by_key(|(_, count)| std::cmp::Reverse(*count));
458            for (rule, count) in rules {
459                writeln!(f, "  {}: {}", rule, count)?;
460            }
461        }
462
463        Ok(())
464    }
465}
466
467/// Common rewrite rules for optimization.
468pub struct CommonRules;
469
470impl CommonRules {
471    /// Eliminate addition with zero: A + 0 -> A
472    pub fn eliminate_add_zero() -> RewriteRule {
473        RewriteRule::new("eliminate_add_zero")
474            .with_pattern(Pattern::binary_op("add", Pattern::any(), Pattern::zero()))
475            .with_replacement(|m| Ok(m.root))
476            .with_priority(10)
477    }
478
479    /// Eliminate multiplication by one: A * 1 -> A
480    pub fn eliminate_mul_one() -> RewriteRule {
481        RewriteRule::new("eliminate_mul_one")
482            .with_pattern(Pattern::binary_op("mul", Pattern::any(), Pattern::one()))
483            .with_replacement(|m| Ok(m.root))
484            .with_priority(10)
485    }
486
487    /// Eliminate multiplication by zero: A * 0 -> 0
488    pub fn eliminate_mul_zero() -> RewriteRule {
489        RewriteRule::new("eliminate_mul_zero")
490            .with_pattern(Pattern::binary_op("mul", Pattern::any(), Pattern::zero()))
491            .with_replacement(|_m| Ok(0)) // Return zero node
492            .with_priority(10)
493    }
494
495    /// Constant folding: C1 + C2 -> C3
496    pub fn constant_folding() -> RewriteRule {
497        RewriteRule::new("constant_folding")
498            .with_pattern(Pattern::binary_op(
499                "add",
500                Pattern::constant(0.0), // Placeholder
501                Pattern::constant(0.0),
502            ))
503            .with_replacement(|_m| Ok(0)) // Would compute result
504            .with_priority(20)
505    }
506
507    /// Associativity: (A + B) + C -> A + (B + C)
508    pub fn associativity_add() -> RewriteRule {
509        RewriteRule::new("associativity_add")
510            .with_pattern(Pattern::binary_op(
511                "add",
512                Pattern::binary_op("add", Pattern::any(), Pattern::any()),
513                Pattern::any(),
514            ))
515            .with_replacement(|m| Ok(m.root))
516            .with_priority(5)
517    }
518
519    /// Get all common optimization rules.
520    pub fn all() -> Vec<RewriteRule> {
521        vec![
522            Self::eliminate_add_zero(),
523            Self::eliminate_mul_one(),
524            Self::eliminate_mul_zero(),
525            Self::constant_folding(),
526            Self::associativity_add(),
527        ]
528    }
529}
530
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pattern_creation() {
        let pattern = Pattern::binary_op("add", Pattern::any(), Pattern::zero());
        assert!(matches!(pattern, Pattern::BinaryOp { .. }));
        assert_eq!(
            pattern,
            Pattern::BinaryOp {
                op: "add".to_string(),
                left: Box::new(Pattern::Any),
                right: Box::new(Pattern::Zero),
            }
        );
    }

    #[test]
    fn test_pattern_helpers() {
        assert_eq!(Pattern::any(), Pattern::Any);
        assert_eq!(Pattern::op("matmul"), Pattern::Op("matmul".to_string()));
        assert_eq!(Pattern::zero(), Pattern::Zero);
        assert_eq!(Pattern::one(), Pattern::One);
        assert_eq!(Pattern::constant(42.0), Pattern::Constant(42.0));
        assert_eq!(Pattern::variable("x"), Pattern::Variable("x".to_string()));
    }

    #[test]
    fn test_match_creation() {
        let m = Match::new(5);
        assert_eq!(m.root, 5);
        assert!(m.matched_nodes.contains(&5));
        assert_eq!(m.nodes().len(), 1);
        assert!(m.captures.is_empty());
    }

    #[test]
    fn test_match_captures() {
        let m = Match::new(5).with_capture("x".to_string(), 10);
        assert_eq!(m.get_capture("x"), Some(10));
        assert_eq!(m.get_capture("missing"), None);
        assert!(m.matched_nodes.contains(&10));
    }

    #[test]
    fn test_rewrite_rule_creation() {
        let rule = RewriteRule::new("test_rule")
            .with_pattern(Pattern::any())
            .with_priority(10);

        assert_eq!(rule.name, "test_rule");
        assert_eq!(rule.priority, 10);
        assert!(rule.preserves_semantics);
    }

    #[test]
    fn test_rewrite_engine_creation() {
        let engine = RewriteEngine::new();
        assert_eq!(engine.rules.len(), 0);
        assert_eq!(engine.strategy, RewriteStrategy::Exhaustive);
        assert!(!engine.verify_semantics);
    }

    #[test]
    fn test_rewrite_engine_add_rule() {
        let rule = RewriteRule::new("test");
        let engine = RewriteEngine::new().add_rule(rule);
        assert_eq!(engine.rules.len(), 1);
    }

    #[test]
    fn test_rewrite_strategy() {
        let engine = RewriteEngine::new().with_strategy(RewriteStrategy::OnePass);
        assert_eq!(engine.strategy, RewriteStrategy::OnePass);
    }

    #[test]
    fn test_rewrite_stats() {
        let stats = RewriteStats::default();
        assert_eq!(stats.graphs_processed, 0);
        assert_eq!(stats.rewrites_applied, 0);
        assert!(stats.rule_applications.is_empty());
    }

    #[test]
    fn test_rewrite_stats_display() {
        // Struct-update syntax instead of mutating a default value
        // (clippy::field_reassign_with_default).
        let stats = RewriteStats {
            graphs_processed: 5,
            rewrites_applied: 10,
            rule_applications: std::iter::once(("rule1".to_string(), 7)).collect(),
            ..Default::default()
        };

        let display = stats.to_string();
        assert!(display.contains("Graphs processed:     5"));
        assert!(display.contains("Rewrites applied:     10"));
        assert!(display.contains("  rule1: 7"));
    }

    #[test]
    fn test_common_rules() {
        let rules = CommonRules::all();
        let names: Vec<&str> = rules.iter().map(|r| r.name.as_str()).collect();
        assert_eq!(
            names,
            [
                "eliminate_add_zero",
                "eliminate_mul_one",
                "eliminate_mul_zero",
                "constant_folding",
                "associativity_add",
            ]
        );
    }

    #[test]
    fn test_eliminate_add_zero_rule() {
        let rule = CommonRules::eliminate_add_zero();
        assert_eq!(rule.name, "eliminate_add_zero");
        assert_eq!(rule.priority, 10);
    }

    #[test]
    fn test_rewrite_one_pass() {
        let rule = RewriteRule::new("test");
        let mut engine = RewriteEngine::new()
            .add_rule(rule)
            .with_strategy(RewriteStrategy::OnePass);

        let result = engine.rewrite_simple(10).unwrap();
        // One always-applicable rule removes exactly one node per pass.
        assert_eq!(result, 9);
        assert_eq!(engine.stats().graphs_processed, 1);
        assert_eq!(engine.stats().rewrites_applied, 1);
        assert_eq!(engine.stats().rule_applications.get("test"), Some(&1));
    }

    #[test]
    fn test_rewrite_exhaustive() {
        let rule = RewriteRule::new("test");
        let mut engine = RewriteEngine::new()
            .add_rule(rule)
            .with_strategy(RewriteStrategy::Exhaustive);

        let result = engine.rewrite_simple(10).unwrap();
        assert_eq!(result, 0);
    }

    #[test]
    fn test_rewrite_fixed_point() {
        let rule = RewriteRule::new("test");
        let mut engine = RewriteEngine::new()
            .add_rule(rule)
            .with_strategy(RewriteStrategy::FixedPoint { max_iterations: 10 });

        let result = engine.rewrite_simple(10).unwrap();
        assert_eq!(result, 0);
    }

    #[test]
    fn test_rewrite_prioritized() {
        let rule1 = RewriteRule::new("low").with_priority(1);
        let rule2 = RewriteRule::new("high").with_priority(10);

        let mut engine = RewriteEngine::new()
            .add_rule(rule1)
            .add_rule(rule2)
            .with_strategy(RewriteStrategy::Prioritized);

        engine.rewrite_simple(10).unwrap();
        // After sorting, high priority rule should be first
        assert_eq!(engine.rules[0].name, "high");
        assert_eq!(engine.rules[1].name, "low");
    }

    #[test]
    fn test_reset_stats() {
        let rule = RewriteRule::new("test");
        let mut engine = RewriteEngine::new().add_rule(rule);

        engine.rewrite_simple(10).unwrap();
        assert!(engine.stats().graphs_processed > 0);

        engine.reset_stats();
        assert_eq!(engine.stats(), &RewriteStats::default());
    }

    #[test]
    fn test_verification_flag() {
        let engine = RewriteEngine::new().with_verification(true);
        assert!(engine.verify_semantics);
    }
}