1use serde::{Deserialize, Serialize};
31use std::collections::{HashMap, HashSet};
32use thiserror::Error;
33
34#[derive(Error, Debug, Clone, PartialEq)]
36pub enum RewriteError {
37 #[error("Pattern matching failed: {0}")]
38 PatternMatchFailed(String),
39
40 #[error("Invalid rewrite rule: {0}")]
41 InvalidRule(String),
42
43 #[error("Rewrite application failed: {0}")]
44 ApplicationFailed(String),
45
46 #[error("Cycle detected in rewrite application")]
47 CycleDetected,
48
49 #[error("Semantics verification failed: {0}")]
50 SemanticsViolation(String),
51}
52
/// Identifier used throughout this module to refer to nodes (a plain `usize`).
pub type NodeId = usize;
55
56#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
58pub enum Pattern {
59 Any,
61
62 Op(String),
64
65 BinaryOp {
67 op: String,
68 left: Box<Pattern>,
69 right: Box<Pattern>,
70 },
71
72 UnaryOp { op: String, operand: Box<Pattern> },
74
75 Constant(f64),
77
78 Zero,
80
81 One,
83
84 Variable(String),
86
87 Sequence(Vec<Pattern>),
89
90 Alternative(Vec<Pattern>),
92}
93
94impl Pattern {
95 pub fn any() -> Self {
97 Pattern::Any
98 }
99
100 pub fn op(name: impl Into<String>) -> Self {
102 Pattern::Op(name.into())
103 }
104
105 pub fn binary_op(op: impl Into<String>, left: Pattern, right: Pattern) -> Self {
107 Pattern::BinaryOp {
108 op: op.into(),
109 left: Box::new(left),
110 right: Box::new(right),
111 }
112 }
113
114 pub fn unary_op(op: impl Into<String>, operand: Pattern) -> Self {
116 Pattern::UnaryOp {
117 op: op.into(),
118 operand: Box::new(operand),
119 }
120 }
121
122 pub fn constant(value: f64) -> Self {
124 Pattern::Constant(value)
125 }
126
127 pub fn zero() -> Self {
129 Pattern::Zero
130 }
131
132 pub fn one() -> Self {
134 Pattern::One
135 }
136
137 pub fn variable(name: impl Into<String>) -> Self {
139 Pattern::Variable(name.into())
140 }
141}
142
143#[derive(Debug, Clone, PartialEq)]
145pub struct Match {
146 pub root: NodeId,
148
149 pub captures: HashMap<String, NodeId>,
151
152 pub matched_nodes: HashSet<NodeId>,
154}
155
156impl Match {
157 pub fn new(root: NodeId) -> Self {
159 let mut matched_nodes = HashSet::new();
160 matched_nodes.insert(root);
161
162 Self {
163 root,
164 captures: HashMap::new(),
165 matched_nodes,
166 }
167 }
168
169 pub fn get_capture(&self, name: &str) -> Option<NodeId> {
171 self.captures.get(name).copied()
172 }
173
174 pub fn with_capture(mut self, name: String, node: NodeId) -> Self {
176 self.captures.insert(name, node);
177 self.matched_nodes.insert(node);
178 self
179 }
180
181 pub fn nodes(&self) -> &HashSet<NodeId> {
183 &self.matched_nodes
184 }
185}
186
187pub type ReplacementFn = Box<dyn Fn(&Match) -> Result<NodeId, RewriteError>>;
189
190pub struct RewriteRule {
192 pub name: String,
194
195 pub pattern: Pattern,
197
198 pub replacement: ReplacementFn,
200
201 pub priority: i32,
203
204 pub preserves_semantics: bool,
206}
207
208impl RewriteRule {
209 pub fn new(name: impl Into<String>) -> Self {
211 Self {
212 name: name.into(),
213 pattern: Pattern::Any,
214 replacement: Box::new(|m| Ok(m.root)),
215 priority: 0,
216 preserves_semantics: true,
217 }
218 }
219
220 pub fn with_pattern(mut self, pattern: Pattern) -> Self {
222 self.pattern = pattern;
223 self
224 }
225
226 pub fn with_replacement<F>(mut self, f: F) -> Self
228 where
229 F: Fn(&Match) -> Result<NodeId, RewriteError> + 'static,
230 {
231 self.replacement = Box::new(f);
232 self
233 }
234
235 pub fn with_priority(mut self, priority: i32) -> Self {
237 self.priority = priority;
238 self
239 }
240
241 pub fn with_semantics_preservation(mut self, preserves: bool) -> Self {
243 self.preserves_semantics = preserves;
244 self
245 }
246}
247
248#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
250pub enum RewriteStrategy {
251 OnePass,
253
254 Exhaustive,
256
257 FixedPoint { max_iterations: usize },
259
260 Prioritized,
262
263 BottomUp,
265
266 TopDown,
268}
269
270impl Default for RewriteStrategy {
271 fn default() -> Self {
272 RewriteStrategy::Exhaustive
273 }
274}
275
276pub struct RewriteEngine {
278 rules: Vec<RewriteRule>,
280
281 strategy: RewriteStrategy,
283
284 stats: RewriteStats,
286
287 verify_semantics: bool,
289}
290
291impl RewriteEngine {
292 pub fn new() -> Self {
294 Self {
295 rules: Vec::new(),
296 strategy: RewriteStrategy::default(),
297 stats: RewriteStats::default(),
298 verify_semantics: false,
299 }
300 }
301
302 pub fn add_rule(mut self, rule: RewriteRule) -> Self {
304 self.rules.push(rule);
305 self
306 }
307
308 pub fn with_strategy(mut self, strategy: RewriteStrategy) -> Self {
310 self.strategy = strategy;
311 self
312 }
313
314 pub fn with_verification(mut self, enabled: bool) -> Self {
316 self.verify_semantics = enabled;
317 self
318 }
319
320 pub fn stats(&self) -> &RewriteStats {
322 &self.stats
323 }
324
325 pub fn reset_stats(&mut self) {
327 self.stats = RewriteStats::default();
328 }
329
330 fn sort_rules_by_priority(&mut self) {
332 self.rules.sort_by(|a, b| b.priority.cmp(&a.priority));
333 }
334
335 pub fn rewrite_simple(&mut self, node_count: usize) -> Result<usize, RewriteError> {
338 self.stats.graphs_processed += 1;
339
340 match self.strategy {
341 RewriteStrategy::OnePass => self.apply_one_pass(node_count),
342 RewriteStrategy::Exhaustive => self.apply_exhaustive(node_count),
343 RewriteStrategy::FixedPoint { max_iterations } => {
344 self.apply_fixed_point(node_count, max_iterations)
345 }
346 RewriteStrategy::Prioritized => {
347 self.sort_rules_by_priority();
348 self.apply_one_pass(node_count)
349 }
350 RewriteStrategy::BottomUp | RewriteStrategy::TopDown => self.apply_one_pass(node_count),
351 }
352 }
353
354 fn apply_one_pass(&mut self, node_count: usize) -> Result<usize, RewriteError> {
355 let mut rewrites = 0;
356
357 for rule in &self.rules {
359 if self.can_apply_rule(rule, node_count) {
361 rewrites += 1;
362 self.stats.rewrites_applied += 1;
363 self.stats
364 .rule_applications
365 .entry(rule.name.clone())
366 .and_modify(|c| *c += 1)
367 .or_insert(1);
368 }
369 }
370
371 Ok(node_count.saturating_sub(rewrites))
372 }
373
374 fn apply_exhaustive(&mut self, mut node_count: usize) -> Result<usize, RewriteError> {
375 let mut iteration = 0;
376 let max_iterations = 100; loop {
379 iteration += 1;
380 if iteration > max_iterations {
381 return Err(RewriteError::CycleDetected);
382 }
383
384 let before = node_count;
385 node_count = self.apply_one_pass(node_count)?;
386
387 if node_count == before {
388 break;
390 }
391 }
392
393 Ok(node_count)
394 }
395
396 fn apply_fixed_point(
397 &mut self,
398 mut node_count: usize,
399 max_iterations: usize,
400 ) -> Result<usize, RewriteError> {
401 for iteration in 0..max_iterations {
402 let before = node_count;
403 node_count = self.apply_one_pass(node_count)?;
404
405 if node_count == before {
406 self.stats.fixed_point_iterations = iteration + 1;
407 break;
408 }
409 }
410
411 Ok(node_count)
412 }
413
414 fn can_apply_rule(&self, _rule: &RewriteRule, _node_count: usize) -> bool {
415 true
417 }
418}
419
420impl Default for RewriteEngine {
421 fn default() -> Self {
422 Self::new()
423 }
424}
425
426#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
428pub struct RewriteStats {
429 pub graphs_processed: usize,
431
432 pub rewrites_applied: usize,
434
435 pub rule_applications: HashMap<String, usize>,
437
438 pub fixed_point_iterations: usize,
440
441 pub verification_failures: usize,
443}
444
445impl std::fmt::Display for RewriteStats {
446 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
447 writeln!(f, "Rewrite Statistics")?;
448 writeln!(f, "==================")?;
449 writeln!(f, "Graphs processed: {}", self.graphs_processed)?;
450 writeln!(f, "Rewrites applied: {}", self.rewrites_applied)?;
451 writeln!(f, "Fixed-point iters: {}", self.fixed_point_iterations)?;
452 writeln!(f, "Verification fails: {}", self.verification_failures)?;
453
454 if !self.rule_applications.is_empty() {
455 writeln!(f, "\nRule Applications:")?;
456 let mut rules: Vec<_> = self.rule_applications.iter().collect();
457 rules.sort_by_key(|(_, count)| std::cmp::Reverse(*count));
458 for (rule, count) in rules {
459 writeln!(f, " {}: {}", rule, count)?;
460 }
461 }
462
463 Ok(())
464 }
465}
466
/// Zero-sized namespace grouping constructors for ready-made [`RewriteRule`]s.
pub struct CommonRules;
469
470impl CommonRules {
471 pub fn eliminate_add_zero() -> RewriteRule {
473 RewriteRule::new("eliminate_add_zero")
474 .with_pattern(Pattern::binary_op("add", Pattern::any(), Pattern::zero()))
475 .with_replacement(|m| Ok(m.root))
476 .with_priority(10)
477 }
478
479 pub fn eliminate_mul_one() -> RewriteRule {
481 RewriteRule::new("eliminate_mul_one")
482 .with_pattern(Pattern::binary_op("mul", Pattern::any(), Pattern::one()))
483 .with_replacement(|m| Ok(m.root))
484 .with_priority(10)
485 }
486
487 pub fn eliminate_mul_zero() -> RewriteRule {
489 RewriteRule::new("eliminate_mul_zero")
490 .with_pattern(Pattern::binary_op("mul", Pattern::any(), Pattern::zero()))
491 .with_replacement(|_m| Ok(0)) .with_priority(10)
493 }
494
495 pub fn constant_folding() -> RewriteRule {
497 RewriteRule::new("constant_folding")
498 .with_pattern(Pattern::binary_op(
499 "add",
500 Pattern::constant(0.0), Pattern::constant(0.0),
502 ))
503 .with_replacement(|_m| Ok(0)) .with_priority(20)
505 }
506
507 pub fn associativity_add() -> RewriteRule {
509 RewriteRule::new("associativity_add")
510 .with_pattern(Pattern::binary_op(
511 "add",
512 Pattern::binary_op("add", Pattern::any(), Pattern::any()),
513 Pattern::any(),
514 ))
515 .with_replacement(|m| Ok(m.root))
516 .with_priority(5)
517 }
518
519 pub fn all() -> Vec<RewriteRule> {
521 vec![
522 Self::eliminate_add_zero(),
523 Self::eliminate_mul_one(),
524 Self::eliminate_mul_zero(),
525 Self::constant_folding(),
526 Self::associativity_add(),
527 ]
528 }
529}
530
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an engine holding one default rule named "test", run with `strategy`.
    fn single_rule_engine(strategy: RewriteStrategy) -> RewriteEngine {
        RewriteEngine::new()
            .add_rule(RewriteRule::new("test"))
            .with_strategy(strategy)
    }

    #[test]
    fn test_pattern_creation() {
        let add_zero = Pattern::binary_op("add", Pattern::any(), Pattern::zero());
        assert!(matches!(add_zero, Pattern::BinaryOp { .. }));
    }

    #[test]
    fn test_pattern_helpers() {
        // Only checks that every helper constructor can be called.
        let _built = [
            Pattern::any(),
            Pattern::op("matmul"),
            Pattern::zero(),
            Pattern::one(),
            Pattern::constant(42.0),
            Pattern::variable("x"),
        ];
    }

    #[test]
    fn test_match_creation() {
        let found = Match::new(5);
        assert_eq!(found.root, 5);
        assert!(found.matched_nodes.contains(&5));
    }

    #[test]
    fn test_match_captures() {
        let found = Match::new(5).with_capture("x".to_string(), 10);
        assert_eq!(found.get_capture("x"), Some(10));
        assert!(found.matched_nodes.contains(&10));
    }

    #[test]
    fn test_rewrite_rule_creation() {
        let rule = RewriteRule::new("test_rule")
            .with_pattern(Pattern::any())
            .with_priority(10);

        assert_eq!(rule.name, "test_rule");
        assert_eq!(rule.priority, 10);
    }

    #[test]
    fn test_rewrite_engine_creation() {
        let engine = RewriteEngine::new();
        assert!(engine.rules.is_empty());
        assert_eq!(engine.strategy, RewriteStrategy::Exhaustive);
    }

    #[test]
    fn test_rewrite_engine_add_rule() {
        let engine = RewriteEngine::new().add_rule(RewriteRule::new("test"));
        assert_eq!(engine.rules.len(), 1);
    }

    #[test]
    fn test_rewrite_strategy() {
        let engine = RewriteEngine::new().with_strategy(RewriteStrategy::OnePass);
        assert_eq!(engine.strategy, RewriteStrategy::OnePass);
    }

    #[test]
    fn test_rewrite_stats() {
        let fresh = RewriteStats::default();
        assert_eq!(fresh.graphs_processed, 0);
        assert_eq!(fresh.rewrites_applied, 0);
    }

    #[test]
    fn test_rewrite_stats_display() {
        let mut stats = RewriteStats {
            graphs_processed: 5,
            rewrites_applied: 10,
            ..RewriteStats::default()
        };
        stats.rule_applications.insert("rule1".to_string(), 7);

        let rendered = stats.to_string();
        assert!(rendered.contains("Graphs processed: 5"));
        assert!(rendered.contains("Rewrites applied: 10"));
    }

    #[test]
    fn test_common_rules() {
        let rules = CommonRules::all();
        assert!(!rules.is_empty());
        assert_eq!(rules.len(), 5);
    }

    #[test]
    fn test_eliminate_add_zero_rule() {
        let rule = CommonRules::eliminate_add_zero();
        assert_eq!(rule.name, "eliminate_add_zero");
        assert_eq!(rule.priority, 10);
    }

    #[test]
    fn test_rewrite_one_pass() {
        let mut engine = single_rule_engine(RewriteStrategy::OnePass);

        let remaining = engine.rewrite_simple(10).unwrap();
        assert!(remaining <= 10);
        assert!(engine.stats().graphs_processed > 0);
    }

    #[test]
    fn test_rewrite_exhaustive() {
        let mut engine = single_rule_engine(RewriteStrategy::Exhaustive);

        let remaining = engine.rewrite_simple(10).unwrap();
        assert!(remaining <= 10);
    }

    #[test]
    fn test_rewrite_fixed_point() {
        let mut engine = single_rule_engine(RewriteStrategy::FixedPoint { max_iterations: 10 });

        let remaining = engine.rewrite_simple(10).unwrap();
        assert!(remaining <= 10);
    }

    #[test]
    fn test_rewrite_prioritized() {
        let mut engine = RewriteEngine::new()
            .add_rule(RewriteRule::new("low").with_priority(1))
            .add_rule(RewriteRule::new("high").with_priority(10))
            .with_strategy(RewriteStrategy::Prioritized);

        engine.rewrite_simple(10).unwrap();
        assert_eq!(engine.rules[0].name, "high");
    }

    #[test]
    fn test_reset_stats() {
        let mut engine = single_rule_engine(RewriteStrategy::default());

        engine.rewrite_simple(10).unwrap();
        assert!(engine.stats().graphs_processed > 0);

        engine.reset_stats();
        assert_eq!(engine.stats().graphs_processed, 0);
    }

    #[test]
    fn test_verification_flag() {
        let engine = RewriteEngine::new().with_verification(true);
        assert!(engine.verify_semantics);
    }
}