1use serde::{Deserialize, Serialize};
31use std::cmp::Reverse;
32use std::collections::{HashMap, HashSet};
33use thiserror::Error;
34
35#[derive(Error, Debug, Clone, PartialEq)]
37pub enum RewriteError {
38 #[error("Pattern matching failed: {0}")]
39 PatternMatchFailed(String),
40
41 #[error("Invalid rewrite rule: {0}")]
42 InvalidRule(String),
43
44 #[error("Rewrite application failed: {0}")]
45 ApplicationFailed(String),
46
47 #[error("Cycle detected in rewrite application")]
48 CycleDetected,
49
50 #[error("Semantics verification failed: {0}")]
51 SemanticsViolation(String),
52}
53
/// Identifier of a node in the graph being rewritten.
///
/// NOTE(review): nothing in this file indexes storage with it; presumably it is an index
/// into the caller's node table — confirm.
pub type NodeId = usize;
56
57#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
59pub enum Pattern {
60 Any,
62
63 Op(String),
65
66 BinaryOp {
68 op: String,
69 left: Box<Pattern>,
70 right: Box<Pattern>,
71 },
72
73 UnaryOp { op: String, operand: Box<Pattern> },
75
76 Constant(f64),
78
79 Zero,
81
82 One,
84
85 Variable(String),
87
88 Sequence(Vec<Pattern>),
90
91 Alternative(Vec<Pattern>),
93}
94
95impl Pattern {
96 pub fn any() -> Self {
98 Pattern::Any
99 }
100
101 pub fn op(name: impl Into<String>) -> Self {
103 Pattern::Op(name.into())
104 }
105
106 pub fn binary_op(op: impl Into<String>, left: Pattern, right: Pattern) -> Self {
108 Pattern::BinaryOp {
109 op: op.into(),
110 left: Box::new(left),
111 right: Box::new(right),
112 }
113 }
114
115 pub fn unary_op(op: impl Into<String>, operand: Pattern) -> Self {
117 Pattern::UnaryOp {
118 op: op.into(),
119 operand: Box::new(operand),
120 }
121 }
122
123 pub fn constant(value: f64) -> Self {
125 Pattern::Constant(value)
126 }
127
128 pub fn zero() -> Self {
130 Pattern::Zero
131 }
132
133 pub fn one() -> Self {
135 Pattern::One
136 }
137
138 pub fn variable(name: impl Into<String>) -> Self {
140 Pattern::Variable(name.into())
141 }
142}
143
144#[derive(Debug, Clone, PartialEq)]
146pub struct Match {
147 pub root: NodeId,
149
150 pub captures: HashMap<String, NodeId>,
152
153 pub matched_nodes: HashSet<NodeId>,
155}
156
157impl Match {
158 pub fn new(root: NodeId) -> Self {
160 let mut matched_nodes = HashSet::new();
161 matched_nodes.insert(root);
162
163 Self {
164 root,
165 captures: HashMap::new(),
166 matched_nodes,
167 }
168 }
169
170 pub fn get_capture(&self, name: &str) -> Option<NodeId> {
172 self.captures.get(name).copied()
173 }
174
175 pub fn with_capture(mut self, name: String, node: NodeId) -> Self {
177 self.captures.insert(name, node);
178 self.matched_nodes.insert(node);
179 self
180 }
181
182 pub fn nodes(&self) -> &HashSet<NodeId> {
184 &self.matched_nodes
185 }
186}
187
188pub type ReplacementFn = Box<dyn Fn(&Match) -> Result<NodeId, RewriteError>>;
190
191pub struct RewriteRule {
193 pub name: String,
195
196 pub pattern: Pattern,
198
199 pub replacement: ReplacementFn,
201
202 pub priority: i32,
204
205 pub preserves_semantics: bool,
207}
208
209impl RewriteRule {
210 pub fn new(name: impl Into<String>) -> Self {
212 Self {
213 name: name.into(),
214 pattern: Pattern::Any,
215 replacement: Box::new(|m| Ok(m.root)),
216 priority: 0,
217 preserves_semantics: true,
218 }
219 }
220
221 pub fn with_pattern(mut self, pattern: Pattern) -> Self {
223 self.pattern = pattern;
224 self
225 }
226
227 pub fn with_replacement<F>(mut self, f: F) -> Self
229 where
230 F: Fn(&Match) -> Result<NodeId, RewriteError> + 'static,
231 {
232 self.replacement = Box::new(f);
233 self
234 }
235
236 pub fn with_priority(mut self, priority: i32) -> Self {
238 self.priority = priority;
239 self
240 }
241
242 pub fn with_semantics_preservation(mut self, preserves: bool) -> Self {
244 self.preserves_semantics = preserves;
245 self
246 }
247}
248
249#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
251pub enum RewriteStrategy {
252 OnePass,
254
255 Exhaustive,
257
258 FixedPoint { max_iterations: usize },
260
261 Prioritized,
263
264 BottomUp,
266
267 TopDown,
269}
270
271impl Default for RewriteStrategy {
272 fn default() -> Self {
273 RewriteStrategy::Exhaustive
274 }
275}
276
277pub struct RewriteEngine {
279 rules: Vec<RewriteRule>,
281
282 strategy: RewriteStrategy,
284
285 stats: RewriteStats,
287
288 verify_semantics: bool,
290}
291
292impl RewriteEngine {
293 pub fn new() -> Self {
295 Self {
296 rules: Vec::new(),
297 strategy: RewriteStrategy::default(),
298 stats: RewriteStats::default(),
299 verify_semantics: false,
300 }
301 }
302
303 pub fn add_rule(mut self, rule: RewriteRule) -> Self {
305 self.rules.push(rule);
306 self
307 }
308
309 pub fn with_strategy(mut self, strategy: RewriteStrategy) -> Self {
311 self.strategy = strategy;
312 self
313 }
314
315 pub fn with_verification(mut self, enabled: bool) -> Self {
317 self.verify_semantics = enabled;
318 self
319 }
320
321 pub fn stats(&self) -> &RewriteStats {
323 &self.stats
324 }
325
326 pub fn reset_stats(&mut self) {
328 self.stats = RewriteStats::default();
329 }
330
331 fn sort_rules_by_priority(&mut self) {
333 self.rules.sort_by_key(|b| Reverse(b.priority));
334 }
335
336 pub fn rewrite_simple(&mut self, node_count: usize) -> Result<usize, RewriteError> {
339 self.stats.graphs_processed += 1;
340
341 match self.strategy {
342 RewriteStrategy::OnePass => self.apply_one_pass(node_count),
343 RewriteStrategy::Exhaustive => self.apply_exhaustive(node_count),
344 RewriteStrategy::FixedPoint { max_iterations } => {
345 self.apply_fixed_point(node_count, max_iterations)
346 }
347 RewriteStrategy::Prioritized => {
348 self.sort_rules_by_priority();
349 self.apply_one_pass(node_count)
350 }
351 RewriteStrategy::BottomUp | RewriteStrategy::TopDown => self.apply_one_pass(node_count),
352 }
353 }
354
355 fn apply_one_pass(&mut self, node_count: usize) -> Result<usize, RewriteError> {
356 let mut rewrites = 0;
357
358 for rule in &self.rules {
360 if self.can_apply_rule(rule, node_count) {
362 rewrites += 1;
363 self.stats.rewrites_applied += 1;
364 self.stats
365 .rule_applications
366 .entry(rule.name.clone())
367 .and_modify(|c| *c += 1)
368 .or_insert(1);
369 }
370 }
371
372 Ok(node_count.saturating_sub(rewrites))
373 }
374
375 fn apply_exhaustive(&mut self, mut node_count: usize) -> Result<usize, RewriteError> {
376 let mut iteration = 0;
377 let max_iterations = 100; loop {
380 iteration += 1;
381 if iteration > max_iterations {
382 return Err(RewriteError::CycleDetected);
383 }
384
385 let before = node_count;
386 node_count = self.apply_one_pass(node_count)?;
387
388 if node_count == before {
389 break;
391 }
392 }
393
394 Ok(node_count)
395 }
396
397 fn apply_fixed_point(
398 &mut self,
399 mut node_count: usize,
400 max_iterations: usize,
401 ) -> Result<usize, RewriteError> {
402 for iteration in 0..max_iterations {
403 let before = node_count;
404 node_count = self.apply_one_pass(node_count)?;
405
406 if node_count == before {
407 self.stats.fixed_point_iterations = iteration + 1;
408 break;
409 }
410 }
411
412 Ok(node_count)
413 }
414
415 fn can_apply_rule(&self, _rule: &RewriteRule, _node_count: usize) -> bool {
416 true
418 }
419}
420
421impl Default for RewriteEngine {
422 fn default() -> Self {
423 Self::new()
424 }
425}
426
427#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
429pub struct RewriteStats {
430 pub graphs_processed: usize,
432
433 pub rewrites_applied: usize,
435
436 pub rule_applications: HashMap<String, usize>,
438
439 pub fixed_point_iterations: usize,
441
442 pub verification_failures: usize,
444}
445
446impl std::fmt::Display for RewriteStats {
447 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
448 writeln!(f, "Rewrite Statistics")?;
449 writeln!(f, "==================")?;
450 writeln!(f, "Graphs processed: {}", self.graphs_processed)?;
451 writeln!(f, "Rewrites applied: {}", self.rewrites_applied)?;
452 writeln!(f, "Fixed-point iters: {}", self.fixed_point_iterations)?;
453 writeln!(f, "Verification fails: {}", self.verification_failures)?;
454
455 if !self.rule_applications.is_empty() {
456 writeln!(f, "\nRule Applications:")?;
457 let mut rules: Vec<_> = self.rule_applications.iter().collect();
458 rules.sort_by_key(|(_, count)| std::cmp::Reverse(*count));
459 for (rule, count) in rules {
460 writeln!(f, " {}: {}", rule, count)?;
461 }
462 }
463
464 Ok(())
465 }
466}
467
468pub struct CommonRules;
470
471impl CommonRules {
472 pub fn eliminate_add_zero() -> RewriteRule {
474 RewriteRule::new("eliminate_add_zero")
475 .with_pattern(Pattern::binary_op("add", Pattern::any(), Pattern::zero()))
476 .with_replacement(|m| Ok(m.root))
477 .with_priority(10)
478 }
479
480 pub fn eliminate_mul_one() -> RewriteRule {
482 RewriteRule::new("eliminate_mul_one")
483 .with_pattern(Pattern::binary_op("mul", Pattern::any(), Pattern::one()))
484 .with_replacement(|m| Ok(m.root))
485 .with_priority(10)
486 }
487
488 pub fn eliminate_mul_zero() -> RewriteRule {
490 RewriteRule::new("eliminate_mul_zero")
491 .with_pattern(Pattern::binary_op("mul", Pattern::any(), Pattern::zero()))
492 .with_replacement(|_m| Ok(0)) .with_priority(10)
494 }
495
496 pub fn constant_folding() -> RewriteRule {
498 RewriteRule::new("constant_folding")
499 .with_pattern(Pattern::binary_op(
500 "add",
501 Pattern::constant(0.0), Pattern::constant(0.0),
503 ))
504 .with_replacement(|_m| Ok(0)) .with_priority(20)
506 }
507
508 pub fn associativity_add() -> RewriteRule {
510 RewriteRule::new("associativity_add")
511 .with_pattern(Pattern::binary_op(
512 "add",
513 Pattern::binary_op("add", Pattern::any(), Pattern::any()),
514 Pattern::any(),
515 ))
516 .with_replacement(|m| Ok(m.root))
517 .with_priority(5)
518 }
519
520 pub fn all() -> Vec<RewriteRule> {
522 vec![
523 Self::eliminate_add_zero(),
524 Self::eliminate_mul_one(),
525 Self::eliminate_mul_zero(),
526 Self::constant_folding(),
527 Self::associativity_add(),
528 ]
529 }
530}
531
// Unit tests covering:
// - the pattern helpers;
// - `Match`;
// - `RewriteRule` builders;
// - the engine's strategies and statistics;
// - the stock `CommonRules`.
#[cfg(test)]
mod tests {
    use super::*;

    // `binary_op` produces a `BinaryOp` variant.
    #[test]
    fn test_pattern_creation() {
        let pattern = Pattern::binary_op("add", Pattern::any(), Pattern::zero());
        assert!(matches!(pattern, Pattern::BinaryOp { .. }));
    }

    // Smoke test: the remaining constructor helpers can all be called.
    #[test]
    fn test_pattern_helpers() {
        let _ = Pattern::any();
        let _ = Pattern::op("matmul");
        let _ = Pattern::zero();
        let _ = Pattern::one();
        let _ = Pattern::constant(42.0);
        let _ = Pattern::variable("x");
    }

    // A fresh match records its root as a matched node.
    #[test]
    fn test_match_creation() {
        let m = Match::new(5);
        assert_eq!(m.root, 5);
        assert!(m.matched_nodes.contains(&5));
    }

    // Captures are retrievable by name and also added to the matched-node set.
    #[test]
    fn test_match_captures() {
        let m = Match::new(5).with_capture("x".to_string(), 10);
        assert_eq!(m.get_capture("x"), Some(10));
        assert!(m.matched_nodes.contains(&10));
    }

    // Builder methods set the name and priority.
    #[test]
    fn test_rewrite_rule_creation() {
        let rule = RewriteRule::new("test_rule")
            .with_pattern(Pattern::any())
            .with_priority(10);

        assert_eq!(rule.name, "test_rule");
        assert_eq!(rule.priority, 10);
    }

    // A new engine has no rules and uses the Exhaustive strategy.
    #[test]
    fn test_rewrite_engine_creation() {
        let engine = RewriteEngine::new();
        assert_eq!(engine.rules.len(), 0);
        assert_eq!(engine.strategy, RewriteStrategy::Exhaustive);
    }

    // `add_rule` appends to the rule list.
    #[test]
    fn test_rewrite_engine_add_rule() {
        let rule = RewriteRule::new("test");
        let engine = RewriteEngine::new().add_rule(rule);
        assert_eq!(engine.rules.len(), 1);
    }

    // `with_strategy` overrides the default strategy.
    #[test]
    fn test_rewrite_strategy() {
        let engine = RewriteEngine::new().with_strategy(RewriteStrategy::OnePass);
        assert_eq!(engine.strategy, RewriteStrategy::OnePass);
    }

    // Default statistics start at zero.
    #[test]
    fn test_rewrite_stats() {
        let stats = RewriteStats::default();
        assert_eq!(stats.graphs_processed, 0);
        assert_eq!(stats.rewrites_applied, 0);
    }

    // The Display report includes the headline counters.
    #[test]
    fn test_rewrite_stats_display() {
        let mut stats = RewriteStats::default();
        stats.graphs_processed = 5;
        stats.rewrites_applied = 10;
        stats.rule_applications.insert("rule1".to_string(), 7);

        let display = format!("{}", stats);
        assert!(display.contains("Graphs processed: 5"));
        assert!(display.contains("Rewrites applied: 10"));
    }

    // `CommonRules::all` returns all five stock rules.
    #[test]
    fn test_common_rules() {
        let rules = CommonRules::all();
        assert!(!rules.is_empty());
        assert_eq!(rules.len(), 5);
    }

    // The add-zero rule carries the expected name and priority.
    #[test]
    fn test_eliminate_add_zero_rule() {
        let rule = CommonRules::eliminate_add_zero();
        assert_eq!(rule.name, "eliminate_add_zero");
        assert_eq!(rule.priority, 10);
    }

    // OnePass never increases the node count and counts the processed graph.
    #[test]
    fn test_rewrite_one_pass() {
        let rule = RewriteRule::new("test");
        let mut engine = RewriteEngine::new()
            .add_rule(rule)
            .with_strategy(RewriteStrategy::OnePass);

        let result = engine.rewrite_simple(10).expect("unwrap");
        assert!(result <= 10);
        assert!(engine.stats().graphs_processed > 0);
    }

    // Exhaustive settles (10 nodes is well under the pass cap) without a cycle error.
    #[test]
    fn test_rewrite_exhaustive() {
        let rule = RewriteRule::new("test");
        let mut engine = RewriteEngine::new()
            .add_rule(rule)
            .with_strategy(RewriteStrategy::Exhaustive);

        let result = engine.rewrite_simple(10).expect("unwrap");
        assert!(result <= 10);
    }

    // FixedPoint with a bounded iteration count returns Ok.
    #[test]
    fn test_rewrite_fixed_point() {
        let rule = RewriteRule::new("test");
        let mut engine = RewriteEngine::new()
            .add_rule(rule)
            .with_strategy(RewriteStrategy::FixedPoint { max_iterations: 10 });

        let result = engine.rewrite_simple(10).expect("unwrap");
        assert!(result <= 10);
    }

    // Prioritized re-sorts the engine's rules so the highest priority comes first.
    #[test]
    fn test_rewrite_prioritized() {
        let rule1 = RewriteRule::new("low").with_priority(1);
        let rule2 = RewriteRule::new("high").with_priority(10);

        let mut engine = RewriteEngine::new()
            .add_rule(rule1)
            .add_rule(rule2)
            .with_strategy(RewriteStrategy::Prioritized);

        engine.rewrite_simple(10).expect("unwrap");
        assert_eq!(engine.rules[0].name, "high");
    }

    // `reset_stats` zeroes counters accumulated by a previous run.
    #[test]
    fn test_reset_stats() {
        let rule = RewriteRule::new("test");
        let mut engine = RewriteEngine::new().add_rule(rule);

        engine.rewrite_simple(10).expect("unwrap");
        assert!(engine.stats().graphs_processed > 0);

        engine.reset_stats();
        assert_eq!(engine.stats().graphs_processed, 0);
    }

    // `with_verification` stores the flag.
    #[test]
    fn test_verification_flag() {
        let engine = RewriteEngine::new().with_verification(true);
        assert!(engine.verify_semantics);
    }
}