Skip to main content

tensorlogic_ir/graph/
pattern.rs

1//! Graph pattern matching and rewriting system.
2//!
3//! This module provides a sophisticated pattern matching and rewriting framework
4//! for tensor computation graphs. It enables:
5//! - Pattern-based graph optimization
6//! - Backend-specific transformations
7//! - Custom fusion strategies
8//! - Automatic graph simplification
9
10use std::cmp::Reverse;
11use std::collections::{HashMap, HashSet};
12
13use super::{EinsumGraph, EinsumNode, OpType};
14use crate::error::IrError;
15
16/// A pattern that can match against graph structures
17#[derive(Debug, Clone, PartialEq)]
18pub enum GraphPattern {
19    /// Match any single node
20    AnyNode,
21    /// Match a specific operation type
22    OpType(OpType),
23    /// Match a sequence of operations
24    Sequence(Vec<GraphPattern>),
25    /// Match any of the given patterns
26    Choice(Vec<GraphPattern>),
27    /// Match a node with specific input count
28    WithInputs(usize),
29    /// Match a node with specific output count
30    WithOutputs(usize),
31    /// Match a named subpattern for capture
32    Capture(String, Box<GraphPattern>),
33    /// Match a pattern zero or more times
34    ZeroOrMore(Box<GraphPattern>),
35    /// Match a pattern one or more times
36    OneOrMore(Box<GraphPattern>),
37}
38
/// The outcome of one successful pattern match.
///
/// Starts empty; matchers push node indices and named captures into it as
/// sub-patterns succeed.
#[derive(Debug, Clone, Default)]
pub struct PatternMatch {
    /// Indices of matched nodes, in the order they were matched.
    pub matched_nodes: Vec<usize>,
    /// Node indices recorded for each named capture.
    pub captures: HashMap<String, Vec<usize>>,
    /// Tensor indices involved in the match (inputs and outputs).
    pub matched_tensors: HashSet<usize>,
}

impl PatternMatch {
    /// Build a match with no nodes, captures or tensors.
    pub fn new() -> Self {
        Self::default()
    }

    /// Append `node_idx` to the list of matched nodes.
    pub fn add_node(&mut self, node_idx: usize) {
        self.matched_nodes.push(node_idx)
    }

    /// Append `node_idx` to the capture called `name`, creating it if needed.
    pub fn add_capture(&mut self, name: String, node_idx: usize) {
        let slot = self.captures.entry(name).or_insert_with(Vec::new);
        slot.push(node_idx);
    }

    /// Borrow the node indices recorded under `name`, or `None` if that
    /// capture was never recorded.
    pub fn get_capture(&self, name: &str) -> Option<&[usize]> {
        self.captures.get(name).map(Vec::as_slice)
    }
}
81
82/// A rewrite rule that transforms matched patterns
83#[derive(Debug, Clone)]
84pub struct GraphRewriteRule {
85    /// Name of this rule for debugging
86    pub name: String,
87    /// Pattern to match
88    pub pattern: GraphPattern,
89    /// Function to apply the rewrite
90    pub rewriter: fn(&EinsumGraph, &PatternMatch) -> Result<Vec<EinsumNode>, IrError>,
91    /// Priority (higher = applied first)
92    pub priority: i32,
93}
94
95impl GraphRewriteRule {
96    /// Create a new rewrite rule
97    pub fn new(
98        name: impl Into<String>,
99        pattern: GraphPattern,
100        rewriter: fn(&EinsumGraph, &PatternMatch) -> Result<Vec<EinsumNode>, IrError>,
101    ) -> Self {
102        Self {
103            name: name.into(),
104            pattern,
105            rewriter,
106            priority: 0,
107        }
108    }
109
110    /// Set the priority of this rule
111    pub fn with_priority(mut self, priority: i32) -> Self {
112        self.priority = priority;
113        self
114    }
115}
116
/// Counters describing one run of rule application.
#[derive(Debug, Clone, Default)]
pub struct RewriteStats {
    /// How many pattern matches were found.
    pub patterns_matched: usize,
    /// How many of those matches were actually rewritten.
    pub rewrites_applied: usize,
    /// Graph size (node count) before rewriting.
    pub nodes_before: usize,
    /// Graph size (node count) after rewriting.
    pub nodes_after: usize,
    /// Nodes removed by rewriting.
    pub nodes_eliminated: usize,
}

impl RewriteStats {
    /// All counters set to zero.
    pub fn new() -> Self {
        Default::default()
    }

    /// Eliminated nodes as a percentage of `nodes_before`.
    ///
    /// Returns `0.0` when the starting graph was empty instead of dividing by zero.
    pub fn reduction_percentage(&self) -> f64 {
        match self.nodes_before {
            0 => 0.0,
            before => {
                let fraction = self.nodes_eliminated as f64 / before as f64;
                fraction * 100.0
            }
        }
    }
}
146
147/// Pattern matcher for graphs
148pub struct PatternMatcher {
149    /// Rules to apply
150    rules: Vec<GraphRewriteRule>,
151}
152
153impl PatternMatcher {
154    /// Create a new pattern matcher
155    pub fn new() -> Self {
156        Self { rules: Vec::new() }
157    }
158
159    /// Add a rewrite rule
160    pub fn add_rule(&mut self, rule: GraphRewriteRule) {
161        self.rules.push(rule);
162        // Sort by priority (descending)
163        self.rules.sort_by_key(|r| Reverse(r.priority));
164    }
165
166    /// Find all matches for a pattern in the graph
167    pub fn find_matches(&self, graph: &EinsumGraph, pattern: &GraphPattern) -> Vec<PatternMatch> {
168        let mut matches = Vec::new();
169
170        // Try to match starting from each node
171        for start_idx in 0..graph.nodes.len() {
172            if let Some(m) = self.try_match_from(graph, pattern, start_idx, &HashSet::new()) {
173                matches.push(m);
174            }
175        }
176
177        matches
178    }
179
180    /// Try to match a pattern starting from a specific node
181    fn try_match_from(
182        &self,
183        graph: &EinsumGraph,
184        pattern: &GraphPattern,
185        start_idx: usize,
186        visited: &HashSet<usize>,
187    ) -> Option<PatternMatch> {
188        if start_idx >= graph.nodes.len() || visited.contains(&start_idx) {
189            return None;
190        }
191
192        match pattern {
193            GraphPattern::AnyNode => {
194                let mut m = PatternMatch::new();
195                m.add_node(start_idx);
196                Some(m)
197            }
198
199            GraphPattern::OpType(expected_op) => {
200                let node = &graph.nodes[start_idx];
201                if Self::op_matches(&node.op, expected_op) {
202                    let mut m = PatternMatch::new();
203                    m.add_node(start_idx);
204                    Some(m)
205                } else {
206                    None
207                }
208            }
209
210            GraphPattern::WithInputs(count) => {
211                let node = &graph.nodes[start_idx];
212                if node.inputs.len() == *count {
213                    let mut m = PatternMatch::new();
214                    m.add_node(start_idx);
215                    Some(m)
216                } else {
217                    None
218                }
219            }
220
221            GraphPattern::WithOutputs(count) => {
222                let node = &graph.nodes[start_idx];
223                if node.outputs.len() == *count {
224                    let mut m = PatternMatch::new();
225                    m.add_node(start_idx);
226                    Some(m)
227                } else {
228                    None
229                }
230            }
231
232            GraphPattern::Capture(name, sub_pattern) => {
233                if let Some(mut m) = self.try_match_from(graph, sub_pattern, start_idx, visited) {
234                    m.add_capture(name.clone(), start_idx);
235                    Some(m)
236                } else {
237                    None
238                }
239            }
240
241            GraphPattern::Sequence(patterns) => {
242                self.match_sequence(graph, patterns, start_idx, visited)
243            }
244
245            GraphPattern::Choice(patterns) => {
246                for pat in patterns {
247                    if let Some(m) = self.try_match_from(graph, pat, start_idx, visited) {
248                        return Some(m);
249                    }
250                }
251                None
252            }
253
254            GraphPattern::OneOrMore(sub_pattern) => {
255                self.match_one_or_more(graph, sub_pattern, start_idx, visited)
256            }
257
258            GraphPattern::ZeroOrMore(sub_pattern) => {
259                if let Some(m) = self.match_one_or_more(graph, sub_pattern, start_idx, visited) {
260                    Some(m)
261                } else {
262                    // Zero matches is valid
263                    Some(PatternMatch::new())
264                }
265            }
266        }
267    }
268
269    /// Match a sequence of patterns
270    fn match_sequence(
271        &self,
272        graph: &EinsumGraph,
273        patterns: &[GraphPattern],
274        start_idx: usize,
275        visited: &HashSet<usize>,
276    ) -> Option<PatternMatch> {
277        if patterns.is_empty() {
278            return Some(PatternMatch::new());
279        }
280
281        let mut result = PatternMatch::new();
282        let mut current_visited = visited.clone();
283        let mut current_idx = start_idx;
284
285        for pattern in patterns {
286            if let Some(m) = self.try_match_from(graph, pattern, current_idx, &current_visited) {
287                // Merge matches
288                for &node in &m.matched_nodes {
289                    result.add_node(node);
290                    current_visited.insert(node);
291                }
292                for (name, nodes) in m.captures {
293                    for node in nodes {
294                        result.add_capture(name.clone(), node);
295                    }
296                }
297
298                // Move to next node (follow data flow)
299                if let Some(&last_node) = m.matched_nodes.last() {
300                    if let Some(next) = self.find_successor(graph, last_node) {
301                        current_idx = next;
302                    } else {
303                        return None; // No successor, can't continue sequence
304                    }
305                }
306            } else {
307                return None;
308            }
309        }
310
311        Some(result)
312    }
313
314    /// Match one or more occurrences of a pattern
315    fn match_one_or_more(
316        &self,
317        graph: &EinsumGraph,
318        pattern: &GraphPattern,
319        start_idx: usize,
320        visited: &HashSet<usize>,
321    ) -> Option<PatternMatch> {
322        let mut result = PatternMatch::new();
323        let mut current_visited = visited.clone();
324        let mut current_idx = start_idx;
325        let mut matched_any = false;
326
327        loop {
328            if let Some(m) = self.try_match_from(graph, pattern, current_idx, &current_visited) {
329                matched_any = true;
330
331                // Merge matches
332                for &node in &m.matched_nodes {
333                    result.add_node(node);
334                    current_visited.insert(node);
335                }
336
337                // Try to continue matching
338                if let Some(&last_node) = m.matched_nodes.last() {
339                    if let Some(next) = self.find_successor(graph, last_node) {
340                        current_idx = next;
341                        continue;
342                    }
343                }
344            }
345            break;
346        }
347
348        if matched_any {
349            Some(result)
350        } else {
351            None
352        }
353    }
354
355    /// Find the successor of a node in the dataflow
356    fn find_successor(&self, graph: &EinsumGraph, node_idx: usize) -> Option<usize> {
357        let node = &graph.nodes[node_idx];
358
359        // Find a node that uses the output of this node
360        for &output_tensor in &node.outputs {
361            for (idx, other_node) in graph.nodes.iter().enumerate() {
362                if other_node.inputs.contains(&output_tensor) {
363                    return Some(idx);
364                }
365            }
366        }
367
368        None
369    }
370
371    /// Check if two operation types match
372    fn op_matches(actual: &OpType, expected: &OpType) -> bool {
373        match (actual, expected) {
374            (OpType::Einsum { .. }, OpType::Einsum { .. }) => true,
375            (OpType::ElemUnary { op: a }, OpType::ElemUnary { op: b }) => a == b,
376            (OpType::ElemBinary { op: a }, OpType::ElemBinary { op: b }) => a == b,
377            (OpType::Reduce { op: a, .. }, OpType::Reduce { op: b, .. }) => a == b,
378            _ => false,
379        }
380    }
381
382    /// Apply all rules to a graph and return rewrite statistics
383    pub fn apply_rules(&self, graph: &mut EinsumGraph) -> Result<RewriteStats, IrError> {
384        let mut stats = RewriteStats::new();
385        stats.nodes_before = graph.nodes.len();
386
387        let mut modified = true;
388        let mut iterations = 0;
389        const MAX_ITERATIONS: usize = 100;
390
391        while modified && iterations < MAX_ITERATIONS {
392            modified = false;
393            iterations += 1;
394
395            for rule in &self.rules {
396                let matches = self.find_matches(graph, &rule.pattern);
397
398                for m in matches {
399                    stats.patterns_matched += 1;
400
401                    // Apply the rewrite
402                    if let Ok(new_nodes) = (rule.rewriter)(graph, &m) {
403                        // Replace matched nodes with new nodes
404                        if self.apply_rewrite(graph, &m, new_nodes)? {
405                            stats.rewrites_applied += 1;
406                            modified = true;
407                        }
408                    }
409                }
410            }
411        }
412
413        stats.nodes_after = graph.nodes.len();
414        stats.nodes_eliminated = stats.nodes_before.saturating_sub(stats.nodes_after);
415
416        Ok(stats)
417    }
418
419    /// Apply a rewrite by replacing matched nodes with new nodes
420    fn apply_rewrite(
421        &self,
422        _graph: &mut EinsumGraph,
423        _pattern_match: &PatternMatch,
424        _new_nodes: Vec<EinsumNode>,
425    ) -> Result<bool, IrError> {
426        // This is a simplified implementation
427        // In a full implementation, we would:
428        // 1. Remove the matched nodes
429        // 2. Insert the new nodes
430        // 3. Rewire connections
431        // 4. Update tensor indices
432
433        // For now, just indicate success without modification
434        Ok(false)
435    }
436}
437
438impl Default for PatternMatcher {
439    fn default() -> Self {
440        Self::new()
441    }
442}
443
444/// Common graph rewrite patterns
445pub mod patterns {
446    use super::*;
447
448    /// Pattern for consecutive element-wise operations
449    #[allow(dead_code)]
450    pub fn elementwise_chain(min_length: usize) -> GraphPattern {
451        let elem_op = GraphPattern::Choice(vec![
452            GraphPattern::OpType(OpType::ElemUnary { op: String::new() }),
453            GraphPattern::OpType(OpType::ElemBinary { op: String::new() }),
454        ]);
455
456        if min_length == 1 {
457            GraphPattern::OneOrMore(Box::new(elem_op))
458        } else {
459            let mut sequence = Vec::new();
460            for _ in 0..min_length {
461                sequence.push(elem_op.clone());
462            }
463            GraphPattern::Sequence(sequence)
464        }
465    }
466
467    /// Pattern for einsum followed by reduction
468    #[allow(dead_code)]
469    pub fn einsum_reduce() -> GraphPattern {
470        GraphPattern::Sequence(vec![
471            GraphPattern::OpType(OpType::Einsum {
472                spec: String::new(),
473            }),
474            GraphPattern::OpType(OpType::Reduce {
475                op: String::new(),
476                axes: Vec::new(),
477            }),
478        ])
479    }
480
481    /// Pattern for map-reduce idiom
482    #[allow(dead_code)]
483    pub fn map_reduce() -> GraphPattern {
484        GraphPattern::Sequence(vec![
485            GraphPattern::Capture(
486                "map".to_string(),
487                Box::new(GraphPattern::OpType(OpType::ElemUnary {
488                    op: String::new(),
489                })),
490            ),
491            GraphPattern::Capture(
492                "reduce".to_string(),
493                Box::new(GraphPattern::OpType(OpType::Reduce {
494                    op: String::new(),
495                    axes: Vec::new(),
496                })),
497            ),
498        ])
499    }
500
501    /// Pattern for broadcast followed by element-wise op
502    #[allow(dead_code)]
503    pub fn broadcast_elementwise() -> GraphPattern {
504        GraphPattern::Sequence(vec![
505            GraphPattern::OpType(OpType::ElemBinary {
506                op: "broadcast".to_string(),
507            }),
508            GraphPattern::OpType(OpType::ElemBinary { op: String::new() }),
509        ])
510    }
511}
512
#[cfg(test)]
mod tests {
    use super::*;

    /// Rewriter stub for the rule tests; it never produces replacement nodes.
    fn noop_rewriter(
        _graph: &EinsumGraph,
        _m: &PatternMatch,
    ) -> Result<Vec<EinsumNode>, IrError> {
        Ok(Vec::new())
    }

    fn unary(name: &str) -> OpType {
        OpType::ElemUnary {
            op: name.to_string(),
        }
    }

    fn binary(name: &str) -> OpType {
        OpType::ElemBinary {
            op: name.to_string(),
        }
    }

    fn einsum(spec: &str) -> OpType {
        OpType::Einsum {
            spec: spec.to_string(),
        }
    }

    #[test]
    fn test_pattern_match_creation() {
        let fresh = PatternMatch::new();
        assert!(fresh.matched_nodes.is_empty());
        assert!(fresh.captures.is_empty());
    }

    #[test]
    fn test_pattern_match_add_node() {
        let mut m = PatternMatch::new();
        m.add_node(0);
        m.add_node(1);
        assert_eq!(m.matched_nodes, vec![0, 1]);
    }

    #[test]
    fn test_pattern_match_capture() {
        let mut m = PatternMatch::new();
        m.add_capture("test".to_string(), 5);
        assert_eq!(m.get_capture("test"), Some(&[5][..]));
        assert_eq!(m.get_capture("nonexistent"), None);
    }

    #[test]
    fn test_rewrite_stats_default() {
        let stats = RewriteStats::default();
        assert_eq!(
            (stats.patterns_matched, stats.rewrites_applied),
            (0, 0)
        );
    }

    #[test]
    fn test_rewrite_stats_reduction() {
        let stats = RewriteStats {
            nodes_before: 100,
            nodes_after: 80,
            nodes_eliminated: 20,
            ..Default::default()
        };
        assert_eq!(stats.reduction_percentage(), 20.0);
    }

    #[test]
    fn test_pattern_matcher_creation() {
        assert!(PatternMatcher::new().rules.is_empty());
    }

    #[test]
    fn test_pattern_matcher_add_rule() {
        let mut matcher = PatternMatcher::new();
        matcher.add_rule(GraphRewriteRule::new(
            "test",
            GraphPattern::AnyNode,
            noop_rewriter,
        ));
        assert_eq!(matcher.rules.len(), 1);
    }

    #[test]
    fn test_rule_priority_ordering() {
        let mut matcher = PatternMatcher::new();
        let low =
            GraphRewriteRule::new("low", GraphPattern::AnyNode, noop_rewriter).with_priority(1);
        let high =
            GraphRewriteRule::new("high", GraphPattern::AnyNode, noop_rewriter).with_priority(10);

        matcher.add_rule(low);
        matcher.add_rule(high);

        // Highest priority first, regardless of insertion order.
        let names: Vec<&str> = matcher.rules.iter().map(|r| r.name.as_str()).collect();
        assert_eq!(names, ["high", "low"]);
    }

    #[test]
    fn test_op_matches_einsum() {
        assert!(PatternMatcher::op_matches(
            &einsum("ij,jk->ik"),
            &einsum("ik,kl->il")
        ));
    }

    #[test]
    fn test_op_matches_elem_unary() {
        assert!(PatternMatcher::op_matches(&unary("relu"), &unary("relu")));
    }

    #[test]
    fn test_op_not_matches_different_types() {
        assert!(!PatternMatcher::op_matches(&unary("relu"), &binary("add")));
    }

    #[test]
    fn test_patterns_elementwise_chain() {
        let pattern = patterns::elementwise_chain(1);
        assert!(
            matches!(pattern, GraphPattern::OneOrMore(_)),
            "Expected OneOrMore pattern"
        );
    }

    #[test]
    fn test_patterns_map_reduce() {
        if let GraphPattern::Sequence(seq) = patterns::map_reduce() {
            assert_eq!(seq.len(), 2);
        } else {
            panic!("Expected Sequence pattern");
        }
    }
}