Skip to main content

tensorlogic_infer/
rewrite.rs

1//! Graph rewriting engine for pattern-based optimizations.
2//!
3//! This module provides a powerful graph rewriting system:
4//! - **Pattern matching**: Find subgraphs matching specific patterns
5//! - **Rewrite rules**: Transform matched patterns into optimized equivalents
6//! - **Rule application**: Apply rules systematically with strategies
7//! - **Correctness checking**: Validate rewrites preserve semantics
8//! - **Performance tracking**: Measure impact of rewrites
9//!
10//! ## Example
11//!
12//! ```rust,ignore
//! use tensorlogic_infer::{Pattern, RewriteEngine, RewriteError, RewriteRule, RewriteStrategy};
//!
//! // Define a rewrite rule: A + 0 -> A (the left operand is captured as `a`)
//! let rule = RewriteRule::new("eliminate_add_zero")
//!     .with_pattern(Pattern::binary_op("add", Pattern::variable("a"), Pattern::zero()))
//!     .with_replacement(|matched| {
//!         matched
//!             .get_capture("a")
//!             .ok_or_else(|| RewriteError::ApplicationFailed("missing capture `a`".into()))
//!     });
//!
//! // Create rewrite engine
//! let mut engine = RewriteEngine::new()
//!     .add_rule(rule)
//!     .with_strategy(RewriteStrategy::Exhaustive);
//!
//! // Apply rewrites (currently simulated over a node count)
//! let remaining = engine.rewrite_simple(node_count)?;
//! println!("Applied {} rewrites", engine.stats().rewrites_applied);
28//! ```
29
30use serde::{Deserialize, Serialize};
31use std::cmp::Reverse;
32use std::collections::{HashMap, HashSet};
33use thiserror::Error;
34
35/// Graph rewriting errors.
36#[derive(Error, Debug, Clone, PartialEq)]
37pub enum RewriteError {
38    #[error("Pattern matching failed: {0}")]
39    PatternMatchFailed(String),
40
41    #[error("Invalid rewrite rule: {0}")]
42    InvalidRule(String),
43
44    #[error("Rewrite application failed: {0}")]
45    ApplicationFailed(String),
46
47    #[error("Cycle detected in rewrite application")]
48    CycleDetected,
49
50    #[error("Semantics verification failed: {0}")]
51    SemanticsViolation(String),
52}
53
/// Identifier of a node in the computation graph, represented as a `usize`.
pub type NodeId = usize;
56
57/// Graph pattern for matching.
58#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
59pub enum Pattern {
60    /// Match any node
61    Any,
62
63    /// Match a specific operation
64    Op(String),
65
66    /// Match a binary operation with two subpatterns
67    BinaryOp {
68        op: String,
69        left: Box<Pattern>,
70        right: Box<Pattern>,
71    },
72
73    /// Match a unary operation with one subpattern
74    UnaryOp { op: String, operand: Box<Pattern> },
75
76    /// Match a constant value
77    Constant(f64),
78
79    /// Match zero
80    Zero,
81
82    /// Match one
83    One,
84
85    /// Match a variable (captures the matched node)
86    Variable(String),
87
88    /// Match a sequence of operations
89    Sequence(Vec<Pattern>),
90
91    /// Match any of the given patterns
92    Alternative(Vec<Pattern>),
93}
94
95impl Pattern {
96    /// Create a pattern matching any node.
97    pub fn any() -> Self {
98        Pattern::Any
99    }
100
101    /// Create a pattern matching a specific operation.
102    pub fn op(name: impl Into<String>) -> Self {
103        Pattern::Op(name.into())
104    }
105
106    /// Create a pattern matching a binary operation.
107    pub fn binary_op(op: impl Into<String>, left: Pattern, right: Pattern) -> Self {
108        Pattern::BinaryOp {
109            op: op.into(),
110            left: Box::new(left),
111            right: Box::new(right),
112        }
113    }
114
115    /// Create a pattern matching a unary operation.
116    pub fn unary_op(op: impl Into<String>, operand: Pattern) -> Self {
117        Pattern::UnaryOp {
118            op: op.into(),
119            operand: Box::new(operand),
120        }
121    }
122
123    /// Create a pattern matching a constant.
124    pub fn constant(value: f64) -> Self {
125        Pattern::Constant(value)
126    }
127
128    /// Create a pattern matching zero.
129    pub fn zero() -> Self {
130        Pattern::Zero
131    }
132
133    /// Create a pattern matching one.
134    pub fn one() -> Self {
135        Pattern::One
136    }
137
138    /// Create a pattern matching a variable.
139    pub fn variable(name: impl Into<String>) -> Self {
140        Pattern::Variable(name.into())
141    }
142}
143
144/// A matched pattern instance.
145#[derive(Debug, Clone, PartialEq)]
146pub struct Match {
147    /// Root node of the match
148    pub root: NodeId,
149
150    /// Captured variables
151    pub captures: HashMap<String, NodeId>,
152
153    /// Matched nodes
154    pub matched_nodes: HashSet<NodeId>,
155}
156
157impl Match {
158    /// Create a new match.
159    pub fn new(root: NodeId) -> Self {
160        let mut matched_nodes = HashSet::new();
161        matched_nodes.insert(root);
162
163        Self {
164            root,
165            captures: HashMap::new(),
166            matched_nodes,
167        }
168    }
169
170    /// Get a captured node by variable name.
171    pub fn get_capture(&self, name: &str) -> Option<NodeId> {
172        self.captures.get(name).copied()
173    }
174
175    /// Add a capture.
176    pub fn with_capture(mut self, name: String, node: NodeId) -> Self {
177        self.captures.insert(name, node);
178        self.matched_nodes.insert(node);
179        self
180    }
181
182    /// Get all matched nodes.
183    pub fn nodes(&self) -> &HashSet<NodeId> {
184        &self.matched_nodes
185    }
186}
187
188/// Rewrite rule replacement function.
189pub type ReplacementFn = Box<dyn Fn(&Match) -> Result<NodeId, RewriteError>>;
190
191/// A graph rewrite rule.
192pub struct RewriteRule {
193    /// Rule name
194    pub name: String,
195
196    /// Pattern to match
197    pub pattern: Pattern,
198
199    /// Replacement function
200    pub replacement: ReplacementFn,
201
202    /// Rule priority (higher = applied first)
203    pub priority: i32,
204
205    /// Whether this rule preserves semantics
206    pub preserves_semantics: bool,
207}
208
209impl RewriteRule {
210    /// Create a new rewrite rule.
211    pub fn new(name: impl Into<String>) -> Self {
212        Self {
213            name: name.into(),
214            pattern: Pattern::Any,
215            replacement: Box::new(|m| Ok(m.root)),
216            priority: 0,
217            preserves_semantics: true,
218        }
219    }
220
221    /// Set the pattern to match.
222    pub fn with_pattern(mut self, pattern: Pattern) -> Self {
223        self.pattern = pattern;
224        self
225    }
226
227    /// Set the replacement function.
228    pub fn with_replacement<F>(mut self, f: F) -> Self
229    where
230        F: Fn(&Match) -> Result<NodeId, RewriteError> + 'static,
231    {
232        self.replacement = Box::new(f);
233        self
234    }
235
236    /// Set the rule priority.
237    pub fn with_priority(mut self, priority: i32) -> Self {
238        self.priority = priority;
239        self
240    }
241
242    /// Mark whether this rule preserves semantics.
243    pub fn with_semantics_preservation(mut self, preserves: bool) -> Self {
244        self.preserves_semantics = preserves;
245        self
246    }
247}
248
249/// Strategy for applying rewrite rules.
250#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
251pub enum RewriteStrategy {
252    /// Apply each rule once to each node
253    OnePass,
254
255    /// Apply rules until no more matches found
256    Exhaustive,
257
258    /// Apply rules in a fixed-point manner (until convergence)
259    FixedPoint { max_iterations: usize },
260
261    /// Apply rules in order of priority
262    Prioritized,
263
264    /// Apply rules bottom-up (from leaves to root)
265    BottomUp,
266
267    /// Apply rules top-down (from root to leaves)
268    TopDown,
269}
270
271impl Default for RewriteStrategy {
272    fn default() -> Self {
273        RewriteStrategy::Exhaustive
274    }
275}
276
277/// Graph rewriting engine.
278pub struct RewriteEngine {
279    /// Rewrite rules
280    rules: Vec<RewriteRule>,
281
282    /// Application strategy
283    strategy: RewriteStrategy,
284
285    /// Statistics
286    stats: RewriteStats,
287
288    /// Enable verification
289    verify_semantics: bool,
290}
291
292impl RewriteEngine {
293    /// Create a new rewrite engine.
294    pub fn new() -> Self {
295        Self {
296            rules: Vec::new(),
297            strategy: RewriteStrategy::default(),
298            stats: RewriteStats::default(),
299            verify_semantics: false,
300        }
301    }
302
303    /// Add a rewrite rule.
304    pub fn add_rule(mut self, rule: RewriteRule) -> Self {
305        self.rules.push(rule);
306        self
307    }
308
309    /// Set the rewrite strategy.
310    pub fn with_strategy(mut self, strategy: RewriteStrategy) -> Self {
311        self.strategy = strategy;
312        self
313    }
314
315    /// Enable or disable semantics verification.
316    pub fn with_verification(mut self, enabled: bool) -> Self {
317        self.verify_semantics = enabled;
318        self
319    }
320
321    /// Get rewrite statistics.
322    pub fn stats(&self) -> &RewriteStats {
323        &self.stats
324    }
325
326    /// Reset statistics.
327    pub fn reset_stats(&mut self) {
328        self.stats = RewriteStats::default();
329    }
330
331    /// Sort rules by priority.
332    fn sort_rules_by_priority(&mut self) {
333        self.rules.sort_by_key(|b| Reverse(b.priority));
334    }
335
336    /// Apply rewrites to a simplified graph representation.
337    /// In a real implementation, this would work with the actual EinsumGraph.
338    pub fn rewrite_simple(&mut self, node_count: usize) -> Result<usize, RewriteError> {
339        self.stats.graphs_processed += 1;
340
341        match self.strategy {
342            RewriteStrategy::OnePass => self.apply_one_pass(node_count),
343            RewriteStrategy::Exhaustive => self.apply_exhaustive(node_count),
344            RewriteStrategy::FixedPoint { max_iterations } => {
345                self.apply_fixed_point(node_count, max_iterations)
346            }
347            RewriteStrategy::Prioritized => {
348                self.sort_rules_by_priority();
349                self.apply_one_pass(node_count)
350            }
351            RewriteStrategy::BottomUp | RewriteStrategy::TopDown => self.apply_one_pass(node_count),
352        }
353    }
354
355    fn apply_one_pass(&mut self, node_count: usize) -> Result<usize, RewriteError> {
356        let mut rewrites = 0;
357
358        // Simplified: just count how many rules could apply
359        for rule in &self.rules {
360            // In real implementation, would match pattern and apply replacement
361            if self.can_apply_rule(rule, node_count) {
362                rewrites += 1;
363                self.stats.rewrites_applied += 1;
364                self.stats
365                    .rule_applications
366                    .entry(rule.name.clone())
367                    .and_modify(|c| *c += 1)
368                    .or_insert(1);
369            }
370        }
371
372        Ok(node_count.saturating_sub(rewrites))
373    }
374
375    fn apply_exhaustive(&mut self, mut node_count: usize) -> Result<usize, RewriteError> {
376        let mut iteration = 0;
377        let max_iterations = 100; // Safety limit
378
379        loop {
380            iteration += 1;
381            if iteration > max_iterations {
382                return Err(RewriteError::CycleDetected);
383            }
384
385            let before = node_count;
386            node_count = self.apply_one_pass(node_count)?;
387
388            if node_count == before {
389                // Converged
390                break;
391            }
392        }
393
394        Ok(node_count)
395    }
396
397    fn apply_fixed_point(
398        &mut self,
399        mut node_count: usize,
400        max_iterations: usize,
401    ) -> Result<usize, RewriteError> {
402        for iteration in 0..max_iterations {
403            let before = node_count;
404            node_count = self.apply_one_pass(node_count)?;
405
406            if node_count == before {
407                self.stats.fixed_point_iterations = iteration + 1;
408                break;
409            }
410        }
411
412        Ok(node_count)
413    }
414
415    fn can_apply_rule(&self, _rule: &RewriteRule, _node_count: usize) -> bool {
416        // Simplified: in real implementation, would match pattern
417        true
418    }
419}
420
421impl Default for RewriteEngine {
422    fn default() -> Self {
423        Self::new()
424    }
425}
426
427/// Rewrite statistics.
428#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
429pub struct RewriteStats {
430    /// Number of graphs processed
431    pub graphs_processed: usize,
432
433    /// Total rewrites applied
434    pub rewrites_applied: usize,
435
436    /// Applications per rule
437    pub rule_applications: HashMap<String, usize>,
438
439    /// Fixed-point iterations
440    pub fixed_point_iterations: usize,
441
442    /// Verification failures
443    pub verification_failures: usize,
444}
445
446impl std::fmt::Display for RewriteStats {
447    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
448        writeln!(f, "Rewrite Statistics")?;
449        writeln!(f, "==================")?;
450        writeln!(f, "Graphs processed:     {}", self.graphs_processed)?;
451        writeln!(f, "Rewrites applied:     {}", self.rewrites_applied)?;
452        writeln!(f, "Fixed-point iters:    {}", self.fixed_point_iterations)?;
453        writeln!(f, "Verification fails:   {}", self.verification_failures)?;
454
455        if !self.rule_applications.is_empty() {
456            writeln!(f, "\nRule Applications:")?;
457            let mut rules: Vec<_> = self.rule_applications.iter().collect();
458            rules.sort_by_key(|(_, count)| std::cmp::Reverse(*count));
459            for (rule, count) in rules {
460                writeln!(f, "  {}: {}", rule, count)?;
461            }
462        }
463
464        Ok(())
465    }
466}
467
468/// Common rewrite rules for optimization.
469pub struct CommonRules;
470
471impl CommonRules {
472    /// Eliminate addition with zero: A + 0 -> A
473    pub fn eliminate_add_zero() -> RewriteRule {
474        RewriteRule::new("eliminate_add_zero")
475            .with_pattern(Pattern::binary_op("add", Pattern::any(), Pattern::zero()))
476            .with_replacement(|m| Ok(m.root))
477            .with_priority(10)
478    }
479
480    /// Eliminate multiplication by one: A * 1 -> A
481    pub fn eliminate_mul_one() -> RewriteRule {
482        RewriteRule::new("eliminate_mul_one")
483            .with_pattern(Pattern::binary_op("mul", Pattern::any(), Pattern::one()))
484            .with_replacement(|m| Ok(m.root))
485            .with_priority(10)
486    }
487
488    /// Eliminate multiplication by zero: A * 0 -> 0
489    pub fn eliminate_mul_zero() -> RewriteRule {
490        RewriteRule::new("eliminate_mul_zero")
491            .with_pattern(Pattern::binary_op("mul", Pattern::any(), Pattern::zero()))
492            .with_replacement(|_m| Ok(0)) // Return zero node
493            .with_priority(10)
494    }
495
496    /// Constant folding: C1 + C2 -> C3
497    pub fn constant_folding() -> RewriteRule {
498        RewriteRule::new("constant_folding")
499            .with_pattern(Pattern::binary_op(
500                "add",
501                Pattern::constant(0.0), // Placeholder
502                Pattern::constant(0.0),
503            ))
504            .with_replacement(|_m| Ok(0)) // Would compute result
505            .with_priority(20)
506    }
507
508    /// Associativity: (A + B) + C -> A + (B + C)
509    pub fn associativity_add() -> RewriteRule {
510        RewriteRule::new("associativity_add")
511            .with_pattern(Pattern::binary_op(
512                "add",
513                Pattern::binary_op("add", Pattern::any(), Pattern::any()),
514                Pattern::any(),
515            ))
516            .with_replacement(|m| Ok(m.root))
517            .with_priority(5)
518    }
519
520    /// Get all common optimization rules.
521    pub fn all() -> Vec<RewriteRule> {
522        vec![
523            Self::eliminate_add_zero(),
524            Self::eliminate_mul_one(),
525            Self::eliminate_mul_zero(),
526            Self::constant_folding(),
527            Self::associativity_add(),
528        ]
529    }
530}
531
#[cfg(test)]
mod tests {
    use super::*;

    /// Engine holding a single default rule named `test`, using `strategy`.
    fn engine_with_test_rule(strategy: RewriteStrategy) -> RewriteEngine {
        RewriteEngine::new()
            .add_rule(RewriteRule::new("test"))
            .with_strategy(strategy)
    }

    #[test]
    fn test_pattern_creation() {
        let add_zero = Pattern::binary_op("add", Pattern::any(), Pattern::zero());
        assert!(matches!(add_zero, Pattern::BinaryOp { .. }));
    }

    #[test]
    fn test_pattern_helpers() {
        // Only checks that every helper constructor can be called.
        let _built = [
            Pattern::any(),
            Pattern::op("matmul"),
            Pattern::zero(),
            Pattern::one(),
            Pattern::constant(42.0),
            Pattern::variable("x"),
        ];
    }

    #[test]
    fn test_match_creation() {
        let matched = Match::new(5);
        assert_eq!(matched.root, 5);
        assert!(matched.matched_nodes.contains(&5));
    }

    #[test]
    fn test_match_captures() {
        let matched = Match::new(5).with_capture("x".to_string(), 10);
        assert_eq!(matched.get_capture("x"), Some(10));
        assert!(matched.matched_nodes.contains(&10));
    }

    #[test]
    fn test_rewrite_rule_creation() {
        let rule = RewriteRule::new("test_rule")
            .with_pattern(Pattern::any())
            .with_priority(10);

        assert_eq!((rule.name.as_str(), rule.priority), ("test_rule", 10));
    }

    #[test]
    fn test_rewrite_engine_creation() {
        let engine = RewriteEngine::new();
        assert!(engine.rules.is_empty());
        assert_eq!(engine.strategy, RewriteStrategy::Exhaustive);
    }

    #[test]
    fn test_rewrite_engine_add_rule() {
        let engine = RewriteEngine::new().add_rule(RewriteRule::new("test"));
        assert_eq!(engine.rules.len(), 1);
    }

    #[test]
    fn test_rewrite_strategy() {
        let engine = RewriteEngine::new().with_strategy(RewriteStrategy::OnePass);
        assert_eq!(engine.strategy, RewriteStrategy::OnePass);
    }

    #[test]
    fn test_rewrite_stats() {
        let stats = RewriteStats::default();
        assert_eq!((stats.graphs_processed, stats.rewrites_applied), (0, 0));
    }

    #[test]
    fn test_rewrite_stats_display() {
        let stats = RewriteStats {
            graphs_processed: 5,
            rewrites_applied: 10,
            rule_applications: std::iter::once(("rule1".to_string(), 7)).collect(),
            ..RewriteStats::default()
        };

        let rendered = stats.to_string();
        assert!(rendered.contains("Graphs processed:     5"));
        assert!(rendered.contains("Rewrites applied:     10"));
    }

    #[test]
    fn test_common_rules() {
        let rules = CommonRules::all();
        assert!(!rules.is_empty());
        assert_eq!(rules.len(), 5);
    }

    #[test]
    fn test_eliminate_add_zero_rule() {
        let rule = CommonRules::eliminate_add_zero();
        assert_eq!(rule.name, "eliminate_add_zero");
        assert_eq!(rule.priority, 10);
    }

    #[test]
    fn test_rewrite_one_pass() {
        let mut engine = engine_with_test_rule(RewriteStrategy::OnePass);

        let remaining = engine.rewrite_simple(10).expect("unwrap");
        assert!(remaining <= 10);
        assert!(engine.stats().graphs_processed > 0);
    }

    #[test]
    fn test_rewrite_exhaustive() {
        let mut engine = engine_with_test_rule(RewriteStrategy::Exhaustive);

        let remaining = engine.rewrite_simple(10).expect("unwrap");
        assert!(remaining <= 10);
    }

    #[test]
    fn test_rewrite_fixed_point() {
        let mut engine =
            engine_with_test_rule(RewriteStrategy::FixedPoint { max_iterations: 10 });

        let remaining = engine.rewrite_simple(10).expect("unwrap");
        assert!(remaining <= 10);
    }

    #[test]
    fn test_rewrite_prioritized() {
        let mut engine = RewriteEngine::new()
            .add_rule(RewriteRule::new("low").with_priority(1))
            .add_rule(RewriteRule::new("high").with_priority(10))
            .with_strategy(RewriteStrategy::Prioritized);

        engine.rewrite_simple(10).expect("unwrap");
        // Sorting by priority moves the high-priority rule to the front.
        assert_eq!(engine.rules[0].name, "high");
    }

    #[test]
    fn test_reset_stats() {
        let mut engine = RewriteEngine::new().add_rule(RewriteRule::new("test"));

        engine.rewrite_simple(10).expect("unwrap");
        assert!(engine.stats().graphs_processed > 0);

        engine.reset_stats();
        assert_eq!(engine.stats().graphs_processed, 0);
    }

    #[test]
    fn test_verification_flag() {
        assert!(RewriteEngine::new().with_verification(true).verify_semantics);
    }
}