1use std::collections::{HashMap, HashSet};
11
12use super::{EinsumGraph, EinsumNode, OpType};
13use crate::error::IrError;
14
15#[derive(Debug, Clone, PartialEq)]
17pub enum GraphPattern {
18 AnyNode,
20 OpType(OpType),
22 Sequence(Vec<GraphPattern>),
24 Choice(Vec<GraphPattern>),
26 WithInputs(usize),
28 WithOutputs(usize),
30 Capture(String, Box<GraphPattern>),
32 ZeroOrMore(Box<GraphPattern>),
34 OneOrMore(Box<GraphPattern>),
36}
37
/// The result of matching a [`GraphPattern`] against a graph.
#[derive(Debug, Clone)]
pub struct PatternMatch {
    /// Indices into `graph.nodes` of the matched nodes, in the order they were matched.
    pub matched_nodes: Vec<usize>,
    /// Named captures: capture name -> node indices recorded under that name.
    pub captures: HashMap<String, Vec<usize>>,
    /// Tensor ids associated with the match.
    /// NOTE(review): nothing in this file inserts into this set — confirm intended use.
    pub matched_tensors: HashSet<usize>,
}
48
49impl PatternMatch {
50 pub fn new() -> Self {
52 Self {
53 matched_nodes: Vec::new(),
54 captures: HashMap::new(),
55 matched_tensors: HashSet::new(),
56 }
57 }
58
59 pub fn add_node(&mut self, node_idx: usize) {
61 self.matched_nodes.push(node_idx);
62 }
63
64 pub fn add_capture(&mut self, name: String, node_idx: usize) {
66 self.captures.entry(name).or_default().push(node_idx);
67 }
68
69 pub fn get_capture(&self, name: &str) -> Option<&[usize]> {
71 self.captures.get(name).map(|v| v.as_slice())
72 }
73}
74
75impl Default for PatternMatch {
76 fn default() -> Self {
77 Self::new()
78 }
79}
80
81#[derive(Debug, Clone)]
83pub struct GraphRewriteRule {
84 pub name: String,
86 pub pattern: GraphPattern,
88 pub rewriter: fn(&EinsumGraph, &PatternMatch) -> Result<Vec<EinsumNode>, IrError>,
90 pub priority: i32,
92}
93
94impl GraphRewriteRule {
95 pub fn new(
97 name: impl Into<String>,
98 pattern: GraphPattern,
99 rewriter: fn(&EinsumGraph, &PatternMatch) -> Result<Vec<EinsumNode>, IrError>,
100 ) -> Self {
101 Self {
102 name: name.into(),
103 pattern,
104 rewriter,
105 priority: 0,
106 }
107 }
108
109 pub fn with_priority(mut self, priority: i32) -> Self {
111 self.priority = priority;
112 self
113 }
114}
115
/// Counters collected while applying rewrite rules to a graph.
#[derive(Debug, Clone, Default)]
pub struct RewriteStats {
    /// Number of pattern matches found (across all rules and passes).
    pub patterns_matched: usize,
    /// Number of matches that were actually rewritten.
    pub rewrites_applied: usize,
    /// Node count before rewriting.
    pub nodes_before: usize,
    /// Node count after rewriting.
    pub nodes_after: usize,
    /// Nodes removed by rewriting.
    pub nodes_eliminated: usize,
}
130
131impl RewriteStats {
132 pub fn new() -> Self {
134 Self::default()
135 }
136
137 pub fn reduction_percentage(&self) -> f64 {
139 if self.nodes_before == 0 {
140 return 0.0;
141 }
142 (self.nodes_eliminated as f64 / self.nodes_before as f64) * 100.0
143 }
144}
145
146pub struct PatternMatcher {
148 rules: Vec<GraphRewriteRule>,
150}
151
152impl PatternMatcher {
153 pub fn new() -> Self {
155 Self { rules: Vec::new() }
156 }
157
158 pub fn add_rule(&mut self, rule: GraphRewriteRule) {
160 self.rules.push(rule);
161 self.rules.sort_by(|a, b| b.priority.cmp(&a.priority));
163 }
164
165 pub fn find_matches(&self, graph: &EinsumGraph, pattern: &GraphPattern) -> Vec<PatternMatch> {
167 let mut matches = Vec::new();
168
169 for start_idx in 0..graph.nodes.len() {
171 if let Some(m) = self.try_match_from(graph, pattern, start_idx, &HashSet::new()) {
172 matches.push(m);
173 }
174 }
175
176 matches
177 }
178
179 fn try_match_from(
181 &self,
182 graph: &EinsumGraph,
183 pattern: &GraphPattern,
184 start_idx: usize,
185 visited: &HashSet<usize>,
186 ) -> Option<PatternMatch> {
187 if start_idx >= graph.nodes.len() || visited.contains(&start_idx) {
188 return None;
189 }
190
191 match pattern {
192 GraphPattern::AnyNode => {
193 let mut m = PatternMatch::new();
194 m.add_node(start_idx);
195 Some(m)
196 }
197
198 GraphPattern::OpType(expected_op) => {
199 let node = &graph.nodes[start_idx];
200 if Self::op_matches(&node.op, expected_op) {
201 let mut m = PatternMatch::new();
202 m.add_node(start_idx);
203 Some(m)
204 } else {
205 None
206 }
207 }
208
209 GraphPattern::WithInputs(count) => {
210 let node = &graph.nodes[start_idx];
211 if node.inputs.len() == *count {
212 let mut m = PatternMatch::new();
213 m.add_node(start_idx);
214 Some(m)
215 } else {
216 None
217 }
218 }
219
220 GraphPattern::WithOutputs(count) => {
221 let node = &graph.nodes[start_idx];
222 if node.outputs.len() == *count {
223 let mut m = PatternMatch::new();
224 m.add_node(start_idx);
225 Some(m)
226 } else {
227 None
228 }
229 }
230
231 GraphPattern::Capture(name, sub_pattern) => {
232 if let Some(mut m) = self.try_match_from(graph, sub_pattern, start_idx, visited) {
233 m.add_capture(name.clone(), start_idx);
234 Some(m)
235 } else {
236 None
237 }
238 }
239
240 GraphPattern::Sequence(patterns) => {
241 self.match_sequence(graph, patterns, start_idx, visited)
242 }
243
244 GraphPattern::Choice(patterns) => {
245 for pat in patterns {
246 if let Some(m) = self.try_match_from(graph, pat, start_idx, visited) {
247 return Some(m);
248 }
249 }
250 None
251 }
252
253 GraphPattern::OneOrMore(sub_pattern) => {
254 self.match_one_or_more(graph, sub_pattern, start_idx, visited)
255 }
256
257 GraphPattern::ZeroOrMore(sub_pattern) => {
258 if let Some(m) = self.match_one_or_more(graph, sub_pattern, start_idx, visited) {
259 Some(m)
260 } else {
261 Some(PatternMatch::new())
263 }
264 }
265 }
266 }
267
268 fn match_sequence(
270 &self,
271 graph: &EinsumGraph,
272 patterns: &[GraphPattern],
273 start_idx: usize,
274 visited: &HashSet<usize>,
275 ) -> Option<PatternMatch> {
276 if patterns.is_empty() {
277 return Some(PatternMatch::new());
278 }
279
280 let mut result = PatternMatch::new();
281 let mut current_visited = visited.clone();
282 let mut current_idx = start_idx;
283
284 for pattern in patterns {
285 if let Some(m) = self.try_match_from(graph, pattern, current_idx, ¤t_visited) {
286 for &node in &m.matched_nodes {
288 result.add_node(node);
289 current_visited.insert(node);
290 }
291 for (name, nodes) in m.captures {
292 for node in nodes {
293 result.add_capture(name.clone(), node);
294 }
295 }
296
297 if let Some(&last_node) = m.matched_nodes.last() {
299 if let Some(next) = self.find_successor(graph, last_node) {
300 current_idx = next;
301 } else {
302 return None; }
304 }
305 } else {
306 return None;
307 }
308 }
309
310 Some(result)
311 }
312
313 fn match_one_or_more(
315 &self,
316 graph: &EinsumGraph,
317 pattern: &GraphPattern,
318 start_idx: usize,
319 visited: &HashSet<usize>,
320 ) -> Option<PatternMatch> {
321 let mut result = PatternMatch::new();
322 let mut current_visited = visited.clone();
323 let mut current_idx = start_idx;
324 let mut matched_any = false;
325
326 loop {
327 if let Some(m) = self.try_match_from(graph, pattern, current_idx, ¤t_visited) {
328 matched_any = true;
329
330 for &node in &m.matched_nodes {
332 result.add_node(node);
333 current_visited.insert(node);
334 }
335
336 if let Some(&last_node) = m.matched_nodes.last() {
338 if let Some(next) = self.find_successor(graph, last_node) {
339 current_idx = next;
340 continue;
341 }
342 }
343 }
344 break;
345 }
346
347 if matched_any {
348 Some(result)
349 } else {
350 None
351 }
352 }
353
354 fn find_successor(&self, graph: &EinsumGraph, node_idx: usize) -> Option<usize> {
356 let node = &graph.nodes[node_idx];
357
358 for &output_tensor in &node.outputs {
360 for (idx, other_node) in graph.nodes.iter().enumerate() {
361 if other_node.inputs.contains(&output_tensor) {
362 return Some(idx);
363 }
364 }
365 }
366
367 None
368 }
369
370 fn op_matches(actual: &OpType, expected: &OpType) -> bool {
372 match (actual, expected) {
373 (OpType::Einsum { .. }, OpType::Einsum { .. }) => true,
374 (OpType::ElemUnary { op: a }, OpType::ElemUnary { op: b }) => a == b,
375 (OpType::ElemBinary { op: a }, OpType::ElemBinary { op: b }) => a == b,
376 (OpType::Reduce { op: a, .. }, OpType::Reduce { op: b, .. }) => a == b,
377 _ => false,
378 }
379 }
380
381 pub fn apply_rules(&self, graph: &mut EinsumGraph) -> Result<RewriteStats, IrError> {
383 let mut stats = RewriteStats::new();
384 stats.nodes_before = graph.nodes.len();
385
386 let mut modified = true;
387 let mut iterations = 0;
388 const MAX_ITERATIONS: usize = 100;
389
390 while modified && iterations < MAX_ITERATIONS {
391 modified = false;
392 iterations += 1;
393
394 for rule in &self.rules {
395 let matches = self.find_matches(graph, &rule.pattern);
396
397 for m in matches {
398 stats.patterns_matched += 1;
399
400 if let Ok(new_nodes) = (rule.rewriter)(graph, &m) {
402 if self.apply_rewrite(graph, &m, new_nodes)? {
404 stats.rewrites_applied += 1;
405 modified = true;
406 }
407 }
408 }
409 }
410 }
411
412 stats.nodes_after = graph.nodes.len();
413 stats.nodes_eliminated = stats.nodes_before.saturating_sub(stats.nodes_after);
414
415 Ok(stats)
416 }
417
418 fn apply_rewrite(
420 &self,
421 _graph: &mut EinsumGraph,
422 _pattern_match: &PatternMatch,
423 _new_nodes: Vec<EinsumNode>,
424 ) -> Result<bool, IrError> {
425 Ok(false)
434 }
435}
436
437impl Default for PatternMatcher {
438 fn default() -> Self {
439 Self::new()
440 }
441}
442
443pub mod patterns {
445 use super::*;
446
447 #[allow(dead_code)]
449 pub fn elementwise_chain(min_length: usize) -> GraphPattern {
450 let elem_op = GraphPattern::Choice(vec![
451 GraphPattern::OpType(OpType::ElemUnary { op: String::new() }),
452 GraphPattern::OpType(OpType::ElemBinary { op: String::new() }),
453 ]);
454
455 if min_length == 1 {
456 GraphPattern::OneOrMore(Box::new(elem_op))
457 } else {
458 let mut sequence = Vec::new();
459 for _ in 0..min_length {
460 sequence.push(elem_op.clone());
461 }
462 GraphPattern::Sequence(sequence)
463 }
464 }
465
466 #[allow(dead_code)]
468 pub fn einsum_reduce() -> GraphPattern {
469 GraphPattern::Sequence(vec![
470 GraphPattern::OpType(OpType::Einsum {
471 spec: String::new(),
472 }),
473 GraphPattern::OpType(OpType::Reduce {
474 op: String::new(),
475 axes: Vec::new(),
476 }),
477 ])
478 }
479
480 #[allow(dead_code)]
482 pub fn map_reduce() -> GraphPattern {
483 GraphPattern::Sequence(vec![
484 GraphPattern::Capture(
485 "map".to_string(),
486 Box::new(GraphPattern::OpType(OpType::ElemUnary {
487 op: String::new(),
488 })),
489 ),
490 GraphPattern::Capture(
491 "reduce".to_string(),
492 Box::new(GraphPattern::OpType(OpType::Reduce {
493 op: String::new(),
494 axes: Vec::new(),
495 })),
496 ),
497 ])
498 }
499
500 #[allow(dead_code)]
502 pub fn broadcast_elementwise() -> GraphPattern {
503 GraphPattern::Sequence(vec![
504 GraphPattern::OpType(OpType::ElemBinary {
505 op: "broadcast".to_string(),
506 }),
507 GraphPattern::OpType(OpType::ElemBinary { op: String::new() }),
508 ])
509 }
510}
511
#[cfg(test)]
mod tests {
    use super::*;

    /// Rewriter stub shared by the rule-registration tests; it never produces nodes.
    fn noop_rewriter(
        _graph: &EinsumGraph,
        _m: &PatternMatch,
    ) -> Result<Vec<EinsumNode>, IrError> {
        Ok(Vec::new())
    }

    /// Shorthand for an `ElemUnary` op with the given name.
    fn unary(name: &str) -> OpType {
        OpType::ElemUnary {
            op: name.to_string(),
        }
    }

    #[test]
    fn test_pattern_match_creation() {
        let fresh = PatternMatch::new();
        assert!(fresh.matched_nodes.is_empty());
        assert!(fresh.captures.is_empty());
    }

    #[test]
    fn test_pattern_match_add_node() {
        let mut found = PatternMatch::new();
        (0..2).for_each(|idx| found.add_node(idx));
        assert_eq!(found.matched_nodes, vec![0, 1]);
    }

    #[test]
    fn test_pattern_match_capture() {
        let mut found = PatternMatch::new();
        found.add_capture("test".to_string(), 5);
        assert_eq!(found.get_capture("test"), Some(&[5][..]));
        assert_eq!(found.get_capture("nonexistent"), None);
    }

    #[test]
    fn test_rewrite_stats_default() {
        let RewriteStats {
            patterns_matched,
            rewrites_applied,
            ..
        } = RewriteStats::default();
        assert_eq!(patterns_matched, 0);
        assert_eq!(rewrites_applied, 0);
    }

    #[test]
    fn test_rewrite_stats_reduction() {
        let stats = RewriteStats {
            nodes_before: 100,
            nodes_after: 80,
            nodes_eliminated: 20,
            ..Default::default()
        };
        assert_eq!(stats.reduction_percentage(), 20.0);
    }

    #[test]
    fn test_pattern_matcher_creation() {
        assert_eq!(PatternMatcher::new().rules.len(), 0);
    }

    #[test]
    fn test_pattern_matcher_add_rule() {
        let mut matcher = PatternMatcher::new();
        matcher.add_rule(GraphRewriteRule::new(
            "test",
            GraphPattern::AnyNode,
            noop_rewriter,
        ));
        assert_eq!(matcher.rules.len(), 1);
    }

    #[test]
    fn test_rule_priority_ordering() {
        let mut matcher = PatternMatcher::new();
        for (name, priority) in [("low", 1), ("high", 10)] {
            matcher.add_rule(
                GraphRewriteRule::new(name, GraphPattern::AnyNode, noop_rewriter)
                    .with_priority(priority),
            );
        }

        // Highest priority must come first after registration.
        let order: Vec<&str> = matcher.rules.iter().map(|r| r.name.as_str()).collect();
        assert_eq!(order, ["high", "low"]);
    }

    #[test]
    fn test_op_matches_einsum() {
        let lhs = OpType::Einsum {
            spec: "ij,jk->ik".to_string(),
        };
        let rhs = OpType::Einsum {
            spec: "ik,kl->il".to_string(),
        };
        assert!(PatternMatcher::op_matches(&lhs, &rhs));
    }

    #[test]
    fn test_op_matches_elem_unary() {
        assert!(PatternMatcher::op_matches(&unary("relu"), &unary("relu")));
    }

    #[test]
    fn test_op_not_matches_different_types() {
        let binary = OpType::ElemBinary {
            op: "add".to_string(),
        };
        assert!(!PatternMatcher::op_matches(&unary("relu"), &binary));
    }

    #[test]
    fn test_patterns_elementwise_chain() {
        assert!(
            matches!(patterns::elementwise_chain(1), GraphPattern::OneOrMore(_)),
            "Expected OneOrMore pattern"
        );
    }

    #[test]
    fn test_patterns_map_reduce() {
        if let GraphPattern::Sequence(steps) = patterns::map_reduce() {
            assert_eq!(steps.len(), 2);
        } else {
            panic!("Expected Sequence pattern");
        }
    }
}