1use std::cmp::Reverse;
11use std::collections::{HashMap, HashSet};
12
13use super::{EinsumGraph, EinsumNode, OpType};
14use crate::error::IrError;
15
/// A structural pattern that can be matched against the nodes of an [`EinsumGraph`].
///
/// Patterns are evaluated by `PatternMatcher`, always starting from one candidate node.
#[derive(Debug, Clone, PartialEq)]
pub enum GraphPattern {
    /// Matches any single node.
    AnyNode,
    /// Matches a single node whose operation is accepted by
    /// `PatternMatcher::op_matches` when compared against this [`OpType`].
    OpType(OpType),
    /// Matches the sub-patterns one after another; each following sub-pattern is
    /// tried at the successor of the last node matched so far.
    Sequence(Vec<GraphPattern>),
    /// Tries the alternatives in order and yields the first one that matches.
    Choice(Vec<GraphPattern>),
    /// Matches a single node that has exactly this many inputs.
    WithInputs(usize),
    /// Matches a single node that has exactly this many outputs.
    WithOutputs(usize),
    /// Matches the inner pattern and records the node the match started at
    /// under the given capture name.
    Capture(String, Box<GraphPattern>),
    /// Matches the inner pattern repeatedly along successor links; when the inner
    /// pattern does not match at all, the result is an empty match.
    ZeroOrMore(Box<GraphPattern>),
    /// Matches the inner pattern repeatedly along successor links; it must match
    /// at least once.
    OneOrMore(Box<GraphPattern>),
}
38
/// The outcome of matching a pattern: which nodes took part and which of them
/// were recorded under a capture name.
#[derive(Debug, Clone, Default)]
pub struct PatternMatch {
    /// Indices of the matched nodes, in the order they were matched.
    pub matched_nodes: Vec<usize>,
    /// Node indices recorded per capture name, in insertion order.
    pub captures: HashMap<String, Vec<usize>>,
    /// Tensor indices associated with the match (starts out empty).
    pub matched_tensors: HashSet<usize>,
}

impl PatternMatch {
    /// Creates a match that contains no nodes, captures or tensors.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `node_idx` to the list of matched nodes.
    pub fn add_node(&mut self, node_idx: usize) {
        self.matched_nodes.push(node_idx);
    }

    /// Records `node_idx` under the capture `name`, creating the capture list
    /// on first use.
    pub fn add_capture(&mut self, name: String, node_idx: usize) {
        let slot = self.captures.entry(name).or_insert_with(Vec::new);
        slot.push(node_idx);
    }

    /// Returns the nodes recorded under `name`, or `None` if nothing was
    /// captured with that name.
    pub fn get_capture(&self, name: &str) -> Option<&[usize]> {
        let recorded = self.captures.get(name)?;
        Some(recorded.as_slice())
    }
}
81
82#[derive(Debug, Clone)]
84pub struct GraphRewriteRule {
85 pub name: String,
87 pub pattern: GraphPattern,
89 pub rewriter: fn(&EinsumGraph, &PatternMatch) -> Result<Vec<EinsumNode>, IrError>,
91 pub priority: i32,
93}
94
95impl GraphRewriteRule {
96 pub fn new(
98 name: impl Into<String>,
99 pattern: GraphPattern,
100 rewriter: fn(&EinsumGraph, &PatternMatch) -> Result<Vec<EinsumNode>, IrError>,
101 ) -> Self {
102 Self {
103 name: name.into(),
104 pattern,
105 rewriter,
106 priority: 0,
107 }
108 }
109
110 pub fn with_priority(mut self, priority: i32) -> Self {
112 self.priority = priority;
113 self
114 }
115}
116
/// Counters collected by `PatternMatcher::apply_rules`.
#[derive(Debug, Clone, Default)]
pub struct RewriteStats {
    /// Number of pattern matches that were found.
    pub patterns_matched: usize,
    /// Number of rewrites that reported a change to the graph.
    pub rewrites_applied: usize,
    /// Node count before rewriting.
    pub nodes_before: usize,
    /// Node count after rewriting.
    pub nodes_after: usize,
    /// Number of nodes removed by rewriting.
    pub nodes_eliminated: usize,
}

impl RewriteStats {
    /// Creates statistics with every counter at zero.
    pub fn new() -> Self {
        Default::default()
    }

    /// Share of the original nodes that were eliminated, in percent.
    ///
    /// Yields `0.0` when `nodes_before` is zero, avoiding a division by zero.
    pub fn reduction_percentage(&self) -> f64 {
        match self.nodes_before {
            0 => 0.0,
            before => {
                let ratio = self.nodes_eliminated as f64 / before as f64;
                ratio * 100.0
            }
        }
    }
}
146
147pub struct PatternMatcher {
149 rules: Vec<GraphRewriteRule>,
151}
152
153impl PatternMatcher {
154 pub fn new() -> Self {
156 Self { rules: Vec::new() }
157 }
158
159 pub fn add_rule(&mut self, rule: GraphRewriteRule) {
161 self.rules.push(rule);
162 self.rules.sort_by_key(|r| Reverse(r.priority));
164 }
165
166 pub fn find_matches(&self, graph: &EinsumGraph, pattern: &GraphPattern) -> Vec<PatternMatch> {
168 let mut matches = Vec::new();
169
170 for start_idx in 0..graph.nodes.len() {
172 if let Some(m) = self.try_match_from(graph, pattern, start_idx, &HashSet::new()) {
173 matches.push(m);
174 }
175 }
176
177 matches
178 }
179
180 fn try_match_from(
182 &self,
183 graph: &EinsumGraph,
184 pattern: &GraphPattern,
185 start_idx: usize,
186 visited: &HashSet<usize>,
187 ) -> Option<PatternMatch> {
188 if start_idx >= graph.nodes.len() || visited.contains(&start_idx) {
189 return None;
190 }
191
192 match pattern {
193 GraphPattern::AnyNode => {
194 let mut m = PatternMatch::new();
195 m.add_node(start_idx);
196 Some(m)
197 }
198
199 GraphPattern::OpType(expected_op) => {
200 let node = &graph.nodes[start_idx];
201 if Self::op_matches(&node.op, expected_op) {
202 let mut m = PatternMatch::new();
203 m.add_node(start_idx);
204 Some(m)
205 } else {
206 None
207 }
208 }
209
210 GraphPattern::WithInputs(count) => {
211 let node = &graph.nodes[start_idx];
212 if node.inputs.len() == *count {
213 let mut m = PatternMatch::new();
214 m.add_node(start_idx);
215 Some(m)
216 } else {
217 None
218 }
219 }
220
221 GraphPattern::WithOutputs(count) => {
222 let node = &graph.nodes[start_idx];
223 if node.outputs.len() == *count {
224 let mut m = PatternMatch::new();
225 m.add_node(start_idx);
226 Some(m)
227 } else {
228 None
229 }
230 }
231
232 GraphPattern::Capture(name, sub_pattern) => {
233 if let Some(mut m) = self.try_match_from(graph, sub_pattern, start_idx, visited) {
234 m.add_capture(name.clone(), start_idx);
235 Some(m)
236 } else {
237 None
238 }
239 }
240
241 GraphPattern::Sequence(patterns) => {
242 self.match_sequence(graph, patterns, start_idx, visited)
243 }
244
245 GraphPattern::Choice(patterns) => {
246 for pat in patterns {
247 if let Some(m) = self.try_match_from(graph, pat, start_idx, visited) {
248 return Some(m);
249 }
250 }
251 None
252 }
253
254 GraphPattern::OneOrMore(sub_pattern) => {
255 self.match_one_or_more(graph, sub_pattern, start_idx, visited)
256 }
257
258 GraphPattern::ZeroOrMore(sub_pattern) => {
259 if let Some(m) = self.match_one_or_more(graph, sub_pattern, start_idx, visited) {
260 Some(m)
261 } else {
262 Some(PatternMatch::new())
264 }
265 }
266 }
267 }
268
269 fn match_sequence(
271 &self,
272 graph: &EinsumGraph,
273 patterns: &[GraphPattern],
274 start_idx: usize,
275 visited: &HashSet<usize>,
276 ) -> Option<PatternMatch> {
277 if patterns.is_empty() {
278 return Some(PatternMatch::new());
279 }
280
281 let mut result = PatternMatch::new();
282 let mut current_visited = visited.clone();
283 let mut current_idx = start_idx;
284
285 for pattern in patterns {
286 if let Some(m) = self.try_match_from(graph, pattern, current_idx, ¤t_visited) {
287 for &node in &m.matched_nodes {
289 result.add_node(node);
290 current_visited.insert(node);
291 }
292 for (name, nodes) in m.captures {
293 for node in nodes {
294 result.add_capture(name.clone(), node);
295 }
296 }
297
298 if let Some(&last_node) = m.matched_nodes.last() {
300 if let Some(next) = self.find_successor(graph, last_node) {
301 current_idx = next;
302 } else {
303 return None; }
305 }
306 } else {
307 return None;
308 }
309 }
310
311 Some(result)
312 }
313
314 fn match_one_or_more(
316 &self,
317 graph: &EinsumGraph,
318 pattern: &GraphPattern,
319 start_idx: usize,
320 visited: &HashSet<usize>,
321 ) -> Option<PatternMatch> {
322 let mut result = PatternMatch::new();
323 let mut current_visited = visited.clone();
324 let mut current_idx = start_idx;
325 let mut matched_any = false;
326
327 loop {
328 if let Some(m) = self.try_match_from(graph, pattern, current_idx, ¤t_visited) {
329 matched_any = true;
330
331 for &node in &m.matched_nodes {
333 result.add_node(node);
334 current_visited.insert(node);
335 }
336
337 if let Some(&last_node) = m.matched_nodes.last() {
339 if let Some(next) = self.find_successor(graph, last_node) {
340 current_idx = next;
341 continue;
342 }
343 }
344 }
345 break;
346 }
347
348 if matched_any {
349 Some(result)
350 } else {
351 None
352 }
353 }
354
355 fn find_successor(&self, graph: &EinsumGraph, node_idx: usize) -> Option<usize> {
357 let node = &graph.nodes[node_idx];
358
359 for &output_tensor in &node.outputs {
361 for (idx, other_node) in graph.nodes.iter().enumerate() {
362 if other_node.inputs.contains(&output_tensor) {
363 return Some(idx);
364 }
365 }
366 }
367
368 None
369 }
370
371 fn op_matches(actual: &OpType, expected: &OpType) -> bool {
373 match (actual, expected) {
374 (OpType::Einsum { .. }, OpType::Einsum { .. }) => true,
375 (OpType::ElemUnary { op: a }, OpType::ElemUnary { op: b }) => a == b,
376 (OpType::ElemBinary { op: a }, OpType::ElemBinary { op: b }) => a == b,
377 (OpType::Reduce { op: a, .. }, OpType::Reduce { op: b, .. }) => a == b,
378 _ => false,
379 }
380 }
381
382 pub fn apply_rules(&self, graph: &mut EinsumGraph) -> Result<RewriteStats, IrError> {
384 let mut stats = RewriteStats::new();
385 stats.nodes_before = graph.nodes.len();
386
387 let mut modified = true;
388 let mut iterations = 0;
389 const MAX_ITERATIONS: usize = 100;
390
391 while modified && iterations < MAX_ITERATIONS {
392 modified = false;
393 iterations += 1;
394
395 for rule in &self.rules {
396 let matches = self.find_matches(graph, &rule.pattern);
397
398 for m in matches {
399 stats.patterns_matched += 1;
400
401 if let Ok(new_nodes) = (rule.rewriter)(graph, &m) {
403 if self.apply_rewrite(graph, &m, new_nodes)? {
405 stats.rewrites_applied += 1;
406 modified = true;
407 }
408 }
409 }
410 }
411 }
412
413 stats.nodes_after = graph.nodes.len();
414 stats.nodes_eliminated = stats.nodes_before.saturating_sub(stats.nodes_after);
415
416 Ok(stats)
417 }
418
419 fn apply_rewrite(
421 &self,
422 _graph: &mut EinsumGraph,
423 _pattern_match: &PatternMatch,
424 _new_nodes: Vec<EinsumNode>,
425 ) -> Result<bool, IrError> {
426 Ok(false)
435 }
436}
437
438impl Default for PatternMatcher {
439 fn default() -> Self {
440 Self::new()
441 }
442}
443
444pub mod patterns {
446 use super::*;
447
448 #[allow(dead_code)]
450 pub fn elementwise_chain(min_length: usize) -> GraphPattern {
451 let elem_op = GraphPattern::Choice(vec![
452 GraphPattern::OpType(OpType::ElemUnary { op: String::new() }),
453 GraphPattern::OpType(OpType::ElemBinary { op: String::new() }),
454 ]);
455
456 if min_length == 1 {
457 GraphPattern::OneOrMore(Box::new(elem_op))
458 } else {
459 let mut sequence = Vec::new();
460 for _ in 0..min_length {
461 sequence.push(elem_op.clone());
462 }
463 GraphPattern::Sequence(sequence)
464 }
465 }
466
467 #[allow(dead_code)]
469 pub fn einsum_reduce() -> GraphPattern {
470 GraphPattern::Sequence(vec![
471 GraphPattern::OpType(OpType::Einsum {
472 spec: String::new(),
473 }),
474 GraphPattern::OpType(OpType::Reduce {
475 op: String::new(),
476 axes: Vec::new(),
477 }),
478 ])
479 }
480
481 #[allow(dead_code)]
483 pub fn map_reduce() -> GraphPattern {
484 GraphPattern::Sequence(vec![
485 GraphPattern::Capture(
486 "map".to_string(),
487 Box::new(GraphPattern::OpType(OpType::ElemUnary {
488 op: String::new(),
489 })),
490 ),
491 GraphPattern::Capture(
492 "reduce".to_string(),
493 Box::new(GraphPattern::OpType(OpType::Reduce {
494 op: String::new(),
495 axes: Vec::new(),
496 })),
497 ),
498 ])
499 }
500
501 #[allow(dead_code)]
503 pub fn broadcast_elementwise() -> GraphPattern {
504 GraphPattern::Sequence(vec![
505 GraphPattern::OpType(OpType::ElemBinary {
506 op: "broadcast".to_string(),
507 }),
508 GraphPattern::OpType(OpType::ElemBinary { op: String::new() }),
509 ])
510 }
511}
512
#[cfg(test)]
mod tests {
    use super::*;

    /// Rewriter stub shared by the rule tests; it never produces replacement nodes.
    fn noop_rewriter(
        _graph: &EinsumGraph,
        _m: &PatternMatch,
    ) -> Result<Vec<EinsumNode>, IrError> {
        Ok(Vec::new())
    }

    /// Shorthand for a unary element-wise operation with the given name.
    fn unary(name: &str) -> OpType {
        OpType::ElemUnary {
            op: name.to_string(),
        }
    }

    #[test]
    fn test_pattern_match_creation() {
        let fresh = PatternMatch::new();
        assert!(fresh.matched_nodes.is_empty());
        assert!(fresh.captures.is_empty());
    }

    #[test]
    fn test_pattern_match_add_node() {
        let mut found = PatternMatch::new();
        for idx in 0..2 {
            found.add_node(idx);
        }
        assert_eq!(found.matched_nodes, vec![0, 1]);
    }

    #[test]
    fn test_pattern_match_capture() {
        let mut found = PatternMatch::new();
        found.add_capture("test".to_string(), 5);
        assert_eq!(found.get_capture("test"), Some(&[5][..]));
        assert_eq!(found.get_capture("nonexistent"), None);
    }

    #[test]
    fn test_rewrite_stats_default() {
        let stats = RewriteStats::default();
        assert_eq!(stats.patterns_matched, 0);
        assert_eq!(stats.rewrites_applied, 0);
    }

    #[test]
    fn test_rewrite_stats_reduction() {
        let stats = RewriteStats {
            nodes_before: 100,
            nodes_after: 80,
            nodes_eliminated: 20,
            ..Default::default()
        };
        assert_eq!(stats.reduction_percentage(), 20.0);
    }

    #[test]
    fn test_pattern_matcher_creation() {
        let matcher = PatternMatcher::new();
        assert_eq!(matcher.rules.len(), 0);
    }

    #[test]
    fn test_pattern_matcher_add_rule() {
        let mut matcher = PatternMatcher::new();
        matcher.add_rule(GraphRewriteRule::new(
            "test",
            GraphPattern::AnyNode,
            noop_rewriter,
        ));
        assert_eq!(matcher.rules.len(), 1);
    }

    #[test]
    fn test_rule_priority_ordering() {
        let mut matcher = PatternMatcher::new();

        let low = GraphRewriteRule::new("low", GraphPattern::AnyNode, noop_rewriter)
            .with_priority(1);
        let high = GraphRewriteRule::new("high", GraphPattern::AnyNode, noop_rewriter)
            .with_priority(10);

        // Registered lowest first; the matcher must reorder by priority.
        matcher.add_rule(low);
        matcher.add_rule(high);

        assert_eq!(matcher.rules[0].name, "high");
        assert_eq!(matcher.rules[1].name, "low");
    }

    #[test]
    fn test_op_matches_einsum() {
        let lhs = OpType::Einsum {
            spec: "ij,jk->ik".to_string(),
        };
        let rhs = OpType::Einsum {
            spec: "ik,kl->il".to_string(),
        };
        assert!(PatternMatcher::op_matches(&lhs, &rhs));
    }

    #[test]
    fn test_op_matches_elem_unary() {
        assert!(PatternMatcher::op_matches(&unary("relu"), &unary("relu")));
    }

    #[test]
    fn test_op_not_matches_different_types() {
        let binary_add = OpType::ElemBinary {
            op: "add".to_string(),
        };
        assert!(!PatternMatcher::op_matches(&unary("relu"), &binary_add));
    }

    #[test]
    fn test_patterns_elementwise_chain() {
        let pattern = patterns::elementwise_chain(1);
        if !matches!(pattern, GraphPattern::OneOrMore(_)) {
            panic!("Expected OneOrMore pattern");
        }
    }

    #[test]
    fn test_patterns_map_reduce() {
        if let GraphPattern::Sequence(seq) = patterns::map_reduce() {
            assert_eq!(seq.len(), 2);
        } else {
            panic!("Expected Sequence pattern");
        }
    }
}