Skip to main content

tensorlogic_ir/graph/
pattern.rs

//! Graph pattern matching and rewriting system.
//!
//! This module provides a sophisticated pattern matching and rewriting framework
//! for tensor computation graphs. It enables:
//! - Pattern-based graph optimization
//! - Backend-specific transformations
//! - Custom fusion strategies
//! - Automatic graph simplification
9
10use std::collections::{HashMap, HashSet};
11
12use super::{EinsumGraph, EinsumNode, OpType};
13use crate::error::IrError;
14
15/// A pattern that can match against graph structures
16#[derive(Debug, Clone, PartialEq)]
17pub enum GraphPattern {
18    /// Match any single node
19    AnyNode,
20    /// Match a specific operation type
21    OpType(OpType),
22    /// Match a sequence of operations
23    Sequence(Vec<GraphPattern>),
24    /// Match any of the given patterns
25    Choice(Vec<GraphPattern>),
26    /// Match a node with specific input count
27    WithInputs(usize),
28    /// Match a node with specific output count
29    WithOutputs(usize),
30    /// Match a named subpattern for capture
31    Capture(String, Box<GraphPattern>),
32    /// Match a pattern zero or more times
33    ZeroOrMore(Box<GraphPattern>),
34    /// Match a pattern one or more times
35    OneOrMore(Box<GraphPattern>),
36}
37
/// Result of a successful pattern match.
///
/// Implements `PartialEq`/`Eq` so that two match results can be compared
/// directly (for example in assertions).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PatternMatch {
    /// Indices of the matched nodes, in the order they were matched.
    pub matched_nodes: Vec<usize>,
    /// Node indices recorded for each named capture.
    pub captures: HashMap<String, Vec<usize>>,
    /// Matched tensors (inputs and outputs).
    ///
    /// NOTE(review): the matcher in this module does not appear to insert
    /// tensor ids here yet — confirm before relying on this set.
    pub matched_tensors: HashSet<usize>,
}
48
49impl PatternMatch {
50    /// Create a new empty pattern match
51    pub fn new() -> Self {
52        Self {
53            matched_nodes: Vec::new(),
54            captures: HashMap::new(),
55            matched_tensors: HashSet::new(),
56        }
57    }
58
59    /// Add a matched node
60    pub fn add_node(&mut self, node_idx: usize) {
61        self.matched_nodes.push(node_idx);
62    }
63
64    /// Add a capture
65    pub fn add_capture(&mut self, name: String, node_idx: usize) {
66        self.captures.entry(name).or_default().push(node_idx);
67    }
68
69    /// Get nodes for a capture by name
70    pub fn get_capture(&self, name: &str) -> Option<&[usize]> {
71        self.captures.get(name).map(|v| v.as_slice())
72    }
73}
74
75impl Default for PatternMatch {
76    fn default() -> Self {
77        Self::new()
78    }
79}
80
81/// A rewrite rule that transforms matched patterns
82#[derive(Debug, Clone)]
83pub struct GraphRewriteRule {
84    /// Name of this rule for debugging
85    pub name: String,
86    /// Pattern to match
87    pub pattern: GraphPattern,
88    /// Function to apply the rewrite
89    pub rewriter: fn(&EinsumGraph, &PatternMatch) -> Result<Vec<EinsumNode>, IrError>,
90    /// Priority (higher = applied first)
91    pub priority: i32,
92}
93
94impl GraphRewriteRule {
95    /// Create a new rewrite rule
96    pub fn new(
97        name: impl Into<String>,
98        pattern: GraphPattern,
99        rewriter: fn(&EinsumGraph, &PatternMatch) -> Result<Vec<EinsumNode>, IrError>,
100    ) -> Self {
101        Self {
102            name: name.into(),
103            pattern,
104            rewriter,
105            priority: 0,
106        }
107    }
108
109    /// Set the priority of this rule
110    pub fn with_priority(mut self, priority: i32) -> Self {
111        self.priority = priority;
112        self
113    }
114}
115
/// Statistics about pattern matching and rewriting.
///
/// All fields are plain counters, so the type also implements
/// `PartialEq`/`Eq` to allow whole-struct comparisons.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RewriteStats {
    /// Number of pattern matches that were found.
    pub patterns_matched: usize,
    /// Number of rewrites that were actually applied.
    pub rewrites_applied: usize,
    /// Node count of the graph before rewriting.
    pub nodes_before: usize,
    /// Node count of the graph after rewriting.
    pub nodes_after: usize,
    /// Number of nodes removed by rewriting.
    pub nodes_eliminated: usize,
}
130
131impl RewriteStats {
132    /// Create new empty stats
133    pub fn new() -> Self {
134        Self::default()
135    }
136
137    /// Calculate reduction percentage
138    pub fn reduction_percentage(&self) -> f64 {
139        if self.nodes_before == 0 {
140            return 0.0;
141        }
142        (self.nodes_eliminated as f64 / self.nodes_before as f64) * 100.0
143    }
144}
145
146/// Pattern matcher for graphs
147pub struct PatternMatcher {
148    /// Rules to apply
149    rules: Vec<GraphRewriteRule>,
150}
151
152impl PatternMatcher {
153    /// Create a new pattern matcher
154    pub fn new() -> Self {
155        Self { rules: Vec::new() }
156    }
157
158    /// Add a rewrite rule
159    pub fn add_rule(&mut self, rule: GraphRewriteRule) {
160        self.rules.push(rule);
161        // Sort by priority (descending)
162        self.rules.sort_by(|a, b| b.priority.cmp(&a.priority));
163    }
164
165    /// Find all matches for a pattern in the graph
166    pub fn find_matches(&self, graph: &EinsumGraph, pattern: &GraphPattern) -> Vec<PatternMatch> {
167        let mut matches = Vec::new();
168
169        // Try to match starting from each node
170        for start_idx in 0..graph.nodes.len() {
171            if let Some(m) = self.try_match_from(graph, pattern, start_idx, &HashSet::new()) {
172                matches.push(m);
173            }
174        }
175
176        matches
177    }
178
179    /// Try to match a pattern starting from a specific node
180    fn try_match_from(
181        &self,
182        graph: &EinsumGraph,
183        pattern: &GraphPattern,
184        start_idx: usize,
185        visited: &HashSet<usize>,
186    ) -> Option<PatternMatch> {
187        if start_idx >= graph.nodes.len() || visited.contains(&start_idx) {
188            return None;
189        }
190
191        match pattern {
192            GraphPattern::AnyNode => {
193                let mut m = PatternMatch::new();
194                m.add_node(start_idx);
195                Some(m)
196            }
197
198            GraphPattern::OpType(expected_op) => {
199                let node = &graph.nodes[start_idx];
200                if Self::op_matches(&node.op, expected_op) {
201                    let mut m = PatternMatch::new();
202                    m.add_node(start_idx);
203                    Some(m)
204                } else {
205                    None
206                }
207            }
208
209            GraphPattern::WithInputs(count) => {
210                let node = &graph.nodes[start_idx];
211                if node.inputs.len() == *count {
212                    let mut m = PatternMatch::new();
213                    m.add_node(start_idx);
214                    Some(m)
215                } else {
216                    None
217                }
218            }
219
220            GraphPattern::WithOutputs(count) => {
221                let node = &graph.nodes[start_idx];
222                if node.outputs.len() == *count {
223                    let mut m = PatternMatch::new();
224                    m.add_node(start_idx);
225                    Some(m)
226                } else {
227                    None
228                }
229            }
230
231            GraphPattern::Capture(name, sub_pattern) => {
232                if let Some(mut m) = self.try_match_from(graph, sub_pattern, start_idx, visited) {
233                    m.add_capture(name.clone(), start_idx);
234                    Some(m)
235                } else {
236                    None
237                }
238            }
239
240            GraphPattern::Sequence(patterns) => {
241                self.match_sequence(graph, patterns, start_idx, visited)
242            }
243
244            GraphPattern::Choice(patterns) => {
245                for pat in patterns {
246                    if let Some(m) = self.try_match_from(graph, pat, start_idx, visited) {
247                        return Some(m);
248                    }
249                }
250                None
251            }
252
253            GraphPattern::OneOrMore(sub_pattern) => {
254                self.match_one_or_more(graph, sub_pattern, start_idx, visited)
255            }
256
257            GraphPattern::ZeroOrMore(sub_pattern) => {
258                if let Some(m) = self.match_one_or_more(graph, sub_pattern, start_idx, visited) {
259                    Some(m)
260                } else {
261                    // Zero matches is valid
262                    Some(PatternMatch::new())
263                }
264            }
265        }
266    }
267
268    /// Match a sequence of patterns
269    fn match_sequence(
270        &self,
271        graph: &EinsumGraph,
272        patterns: &[GraphPattern],
273        start_idx: usize,
274        visited: &HashSet<usize>,
275    ) -> Option<PatternMatch> {
276        if patterns.is_empty() {
277            return Some(PatternMatch::new());
278        }
279
280        let mut result = PatternMatch::new();
281        let mut current_visited = visited.clone();
282        let mut current_idx = start_idx;
283
284        for pattern in patterns {
285            if let Some(m) = self.try_match_from(graph, pattern, current_idx, &current_visited) {
286                // Merge matches
287                for &node in &m.matched_nodes {
288                    result.add_node(node);
289                    current_visited.insert(node);
290                }
291                for (name, nodes) in m.captures {
292                    for node in nodes {
293                        result.add_capture(name.clone(), node);
294                    }
295                }
296
297                // Move to next node (follow data flow)
298                if let Some(&last_node) = m.matched_nodes.last() {
299                    if let Some(next) = self.find_successor(graph, last_node) {
300                        current_idx = next;
301                    } else {
302                        return None; // No successor, can't continue sequence
303                    }
304                }
305            } else {
306                return None;
307            }
308        }
309
310        Some(result)
311    }
312
313    /// Match one or more occurrences of a pattern
314    fn match_one_or_more(
315        &self,
316        graph: &EinsumGraph,
317        pattern: &GraphPattern,
318        start_idx: usize,
319        visited: &HashSet<usize>,
320    ) -> Option<PatternMatch> {
321        let mut result = PatternMatch::new();
322        let mut current_visited = visited.clone();
323        let mut current_idx = start_idx;
324        let mut matched_any = false;
325
326        loop {
327            if let Some(m) = self.try_match_from(graph, pattern, current_idx, &current_visited) {
328                matched_any = true;
329
330                // Merge matches
331                for &node in &m.matched_nodes {
332                    result.add_node(node);
333                    current_visited.insert(node);
334                }
335
336                // Try to continue matching
337                if let Some(&last_node) = m.matched_nodes.last() {
338                    if let Some(next) = self.find_successor(graph, last_node) {
339                        current_idx = next;
340                        continue;
341                    }
342                }
343            }
344            break;
345        }
346
347        if matched_any {
348            Some(result)
349        } else {
350            None
351        }
352    }
353
354    /// Find the successor of a node in the dataflow
355    fn find_successor(&self, graph: &EinsumGraph, node_idx: usize) -> Option<usize> {
356        let node = &graph.nodes[node_idx];
357
358        // Find a node that uses the output of this node
359        for &output_tensor in &node.outputs {
360            for (idx, other_node) in graph.nodes.iter().enumerate() {
361                if other_node.inputs.contains(&output_tensor) {
362                    return Some(idx);
363                }
364            }
365        }
366
367        None
368    }
369
370    /// Check if two operation types match
371    fn op_matches(actual: &OpType, expected: &OpType) -> bool {
372        match (actual, expected) {
373            (OpType::Einsum { .. }, OpType::Einsum { .. }) => true,
374            (OpType::ElemUnary { op: a }, OpType::ElemUnary { op: b }) => a == b,
375            (OpType::ElemBinary { op: a }, OpType::ElemBinary { op: b }) => a == b,
376            (OpType::Reduce { op: a, .. }, OpType::Reduce { op: b, .. }) => a == b,
377            _ => false,
378        }
379    }
380
381    /// Apply all rules to a graph and return rewrite statistics
382    pub fn apply_rules(&self, graph: &mut EinsumGraph) -> Result<RewriteStats, IrError> {
383        let mut stats = RewriteStats::new();
384        stats.nodes_before = graph.nodes.len();
385
386        let mut modified = true;
387        let mut iterations = 0;
388        const MAX_ITERATIONS: usize = 100;
389
390        while modified && iterations < MAX_ITERATIONS {
391            modified = false;
392            iterations += 1;
393
394            for rule in &self.rules {
395                let matches = self.find_matches(graph, &rule.pattern);
396
397                for m in matches {
398                    stats.patterns_matched += 1;
399
400                    // Apply the rewrite
401                    if let Ok(new_nodes) = (rule.rewriter)(graph, &m) {
402                        // Replace matched nodes with new nodes
403                        if self.apply_rewrite(graph, &m, new_nodes)? {
404                            stats.rewrites_applied += 1;
405                            modified = true;
406                        }
407                    }
408                }
409            }
410        }
411
412        stats.nodes_after = graph.nodes.len();
413        stats.nodes_eliminated = stats.nodes_before.saturating_sub(stats.nodes_after);
414
415        Ok(stats)
416    }
417
418    /// Apply a rewrite by replacing matched nodes with new nodes
419    fn apply_rewrite(
420        &self,
421        _graph: &mut EinsumGraph,
422        _pattern_match: &PatternMatch,
423        _new_nodes: Vec<EinsumNode>,
424    ) -> Result<bool, IrError> {
425        // This is a simplified implementation
426        // In a full implementation, we would:
427        // 1. Remove the matched nodes
428        // 2. Insert the new nodes
429        // 3. Rewire connections
430        // 4. Update tensor indices
431
432        // For now, just indicate success without modification
433        Ok(false)
434    }
435}
436
437impl Default for PatternMatcher {
438    fn default() -> Self {
439        Self::new()
440    }
441}
442
443/// Common graph rewrite patterns
444pub mod patterns {
445    use super::*;
446
447    /// Pattern for consecutive element-wise operations
448    #[allow(dead_code)]
449    pub fn elementwise_chain(min_length: usize) -> GraphPattern {
450        let elem_op = GraphPattern::Choice(vec![
451            GraphPattern::OpType(OpType::ElemUnary { op: String::new() }),
452            GraphPattern::OpType(OpType::ElemBinary { op: String::new() }),
453        ]);
454
455        if min_length == 1 {
456            GraphPattern::OneOrMore(Box::new(elem_op))
457        } else {
458            let mut sequence = Vec::new();
459            for _ in 0..min_length {
460                sequence.push(elem_op.clone());
461            }
462            GraphPattern::Sequence(sequence)
463        }
464    }
465
466    /// Pattern for einsum followed by reduction
467    #[allow(dead_code)]
468    pub fn einsum_reduce() -> GraphPattern {
469        GraphPattern::Sequence(vec![
470            GraphPattern::OpType(OpType::Einsum {
471                spec: String::new(),
472            }),
473            GraphPattern::OpType(OpType::Reduce {
474                op: String::new(),
475                axes: Vec::new(),
476            }),
477        ])
478    }
479
480    /// Pattern for map-reduce idiom
481    #[allow(dead_code)]
482    pub fn map_reduce() -> GraphPattern {
483        GraphPattern::Sequence(vec![
484            GraphPattern::Capture(
485                "map".to_string(),
486                Box::new(GraphPattern::OpType(OpType::ElemUnary {
487                    op: String::new(),
488                })),
489            ),
490            GraphPattern::Capture(
491                "reduce".to_string(),
492                Box::new(GraphPattern::OpType(OpType::Reduce {
493                    op: String::new(),
494                    axes: Vec::new(),
495                })),
496            ),
497        ])
498    }
499
500    /// Pattern for broadcast followed by element-wise op
501    #[allow(dead_code)]
502    pub fn broadcast_elementwise() -> GraphPattern {
503        GraphPattern::Sequence(vec![
504            GraphPattern::OpType(OpType::ElemBinary {
505                op: "broadcast".to_string(),
506            }),
507            GraphPattern::OpType(OpType::ElemBinary { op: String::new() }),
508        ])
509    }
510}
511
#[cfg(test)]
mod tests {
    use super::*;

    /// Rewriter stub shared by the rule tests; it proposes no replacement nodes.
    fn dummy_rewriter(
        _graph: &EinsumGraph,
        _m: &PatternMatch,
    ) -> Result<Vec<EinsumNode>, IrError> {
        Ok(Vec::new())
    }

    /// Shorthand for a unary element-wise operation with the given name.
    fn unary(name: &str) -> OpType {
        OpType::ElemUnary {
            op: name.to_string(),
        }
    }

    // --- PatternMatch -----------------------------------------------------

    #[test]
    fn test_pattern_match_creation() {
        let fresh = PatternMatch::new();
        assert!(fresh.matched_nodes.is_empty());
        assert!(fresh.captures.is_empty());
    }

    #[test]
    fn test_pattern_match_add_node() {
        let mut found = PatternMatch::new();
        for idx in [0, 1] {
            found.add_node(idx);
        }
        assert_eq!(found.matched_nodes, vec![0, 1]);
    }

    #[test]
    fn test_pattern_match_capture() {
        let mut found = PatternMatch::new();
        found.add_capture("test".to_string(), 5);
        assert_eq!(found.get_capture("test"), Some(&[5][..]));
        assert_eq!(found.get_capture("nonexistent"), None);
    }

    // --- RewriteStats -----------------------------------------------------

    #[test]
    fn test_rewrite_stats_default() {
        let RewriteStats {
            patterns_matched,
            rewrites_applied,
            ..
        } = RewriteStats::default();
        assert_eq!(patterns_matched, 0);
        assert_eq!(rewrites_applied, 0);
    }

    #[test]
    fn test_rewrite_stats_reduction() {
        let mut stats = RewriteStats::default();
        stats.nodes_before = 100;
        stats.nodes_after = 80;
        stats.nodes_eliminated = 20;
        assert_eq!(stats.reduction_percentage(), 20.0);
    }

    // --- PatternMatcher ---------------------------------------------------

    #[test]
    fn test_pattern_matcher_creation() {
        let matcher = PatternMatcher::new();
        assert_eq!(matcher.rules.len(), 0);
    }

    #[test]
    fn test_pattern_matcher_add_rule() {
        let mut matcher = PatternMatcher::new();
        matcher.add_rule(GraphRewriteRule::new(
            "test",
            GraphPattern::AnyNode,
            dummy_rewriter,
        ));
        assert_eq!(matcher.rules.len(), 1);
    }

    #[test]
    fn test_rule_priority_ordering() {
        let mut matcher = PatternMatcher::new();

        // Register the low-priority rule first so the order must be changed.
        for (name, priority) in [("low", 1), ("high", 10)] {
            let rule = GraphRewriteRule::new(name, GraphPattern::AnyNode, dummy_rewriter)
                .with_priority(priority);
            matcher.add_rule(rule);
        }

        // Should be sorted by priority (descending)
        assert_eq!(matcher.rules[0].name, "high");
        assert_eq!(matcher.rules[1].name, "low");
    }

    // --- op_matches -------------------------------------------------------

    #[test]
    fn test_op_matches_einsum() {
        let lhs = OpType::Einsum {
            spec: "ij,jk->ik".to_string(),
        };
        let rhs = OpType::Einsum {
            spec: "ik,kl->il".to_string(),
        };
        assert!(PatternMatcher::op_matches(&lhs, &rhs));
    }

    #[test]
    fn test_op_matches_elem_unary() {
        assert!(PatternMatcher::op_matches(&unary("relu"), &unary("relu")));
    }

    #[test]
    fn test_op_not_matches_different_types() {
        let binary = OpType::ElemBinary {
            op: "add".to_string(),
        };
        assert!(!PatternMatcher::op_matches(&unary("relu"), &binary));
    }

    // --- Ready-made patterns ----------------------------------------------

    #[test]
    fn test_patterns_elementwise_chain() {
        let pattern = patterns::elementwise_chain(1);
        if !matches!(pattern, GraphPattern::OneOrMore(_)) {
            panic!("Expected OneOrMore pattern");
        }
    }

    #[test]
    fn test_patterns_map_reduce() {
        if let GraphPattern::Sequence(seq) = patterns::map_reduce() {
            assert_eq!(seq.len(), 2);
        } else {
            panic!("Expected Sequence pattern");
        }
    }
}