1use std::cmp::Reverse;
38use std::collections::{HashMap, HashSet};
39use std::hash::{Hash, Hasher};
40
41use super::TLExpr;
42use crate::util::ExprStats;
43
/// Priority level of a rewrite rule.
///
/// The explicit discriminants are also the ordering key: the derived `Ord` compares
/// variants by discriminant, so `Critical > High > Normal > Low > Minimal`. A rewrite
/// system sorts its rules in descending priority, so higher levels are tried first.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub enum RulePriority {
    /// Highest level; tried before every other rule.
    Critical = 100,
    /// Tried before rules of `Normal` priority.
    High = 75,
    /// Level assigned to newly constructed rules.
    #[default]
    Normal = 50,
    /// Tried after rules of `Normal` priority.
    Low = 25,
    /// Lowest level; tried last.
    Minimal = 0,
}
62
63pub type GuardPredicate = fn(&HashMap<String, TLExpr>) -> bool;
67
68pub type TransformFn = fn(&TLExpr) -> Option<TLExpr>;
72
73#[derive(Clone)]
80pub struct ConditionalRule {
81 pub name: String,
83 pub transform: TransformFn,
85 pub guard: GuardPredicate,
87 pub priority: RulePriority,
89 pub description: Option<String>,
91 applications: usize,
93}
94
95impl ConditionalRule {
96 pub fn new(name: impl Into<String>, transform: TransformFn, guard: GuardPredicate) -> Self {
98 Self {
99 name: name.into(),
100 transform,
101 guard,
102 priority: RulePriority::default(),
103 description: None,
104 applications: 0,
105 }
106 }
107
108 pub fn with_priority(mut self, priority: RulePriority) -> Self {
110 self.priority = priority;
111 self
112 }
113
114 pub fn with_description(mut self, description: impl Into<String>) -> Self {
116 self.description = Some(description.into());
117 self
118 }
119
120 pub fn apply(&mut self, expr: &TLExpr) -> Option<TLExpr> {
124 let bindings = HashMap::new(); if (self.guard)(&bindings) {
126 if let Some(result) = (self.transform)(expr) {
127 self.applications += 1;
128 return Some(result);
129 }
130 }
131 None
132 }
133
134 pub fn application_count(&self) -> usize {
136 self.applications
137 }
138
139 pub fn reset_counter(&mut self) {
141 self.applications = 0;
142 }
143}
144
145impl std::fmt::Debug for ConditionalRule {
146 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147 f.debug_struct("ConditionalRule")
148 .field("name", &self.name)
149 .field("priority", &self.priority)
150 .field("description", &self.description)
151 .field("applications", &self.applications)
152 .finish()
153 }
154}
155
/// Order in which an `AdvancedRewriteSystem` visits subterms and applies rules.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RewriteStrategy {
    /// Rewrites subterms before their enclosing node; repeats until no rule applies
    /// or a limit is hit.
    Innermost,
    /// Tries rules at the root first and descends into subterms only when nothing
    /// applies there; repeats until no rule applies or a limit is hit.
    Outermost,
    /// Single post-order pass: children first, then at most one rule per node.
    #[default]
    BottomUp,
    /// Single pre-order pass: at most one rule per node, then its children.
    TopDown,
    /// Rewrites each node until no rule applies, then descends into its children.
    FixpointPerNode,
    /// Repeats bottom-up passes until the expression stops changing or a limit is hit.
    GlobalFixpoint,
}
173
174#[derive(Debug, Clone)]
176pub struct RewriteConfig {
177 pub max_steps: usize,
179 pub strategy: RewriteStrategy,
181 pub detect_cycles: bool,
183 pub trace: bool,
185 pub max_expr_size: Option<usize>,
187}
188
189impl Default for RewriteConfig {
190 fn default() -> Self {
191 Self {
192 max_steps: 10000,
193 strategy: RewriteStrategy::default(),
194 detect_cycles: true,
195 trace: false,
196 max_expr_size: Some(100000), }
198 }
199}
200
/// Counters and outcome flags collected during one rewrite run.
#[derive(Debug, Clone, Default)]
pub struct RewriteStats {
    /// Steps charged against the configured budget.
    pub steps: usize,
    /// Total number of successful rule applications.
    pub rule_applications: usize,
    /// Successful applications keyed by rule name.
    pub rule_counts: HashMap<String, usize>,
    /// Whether the run ended on its own rather than being cut short.
    pub reached_fixpoint: bool,
    /// Whether a repeated expression stopped the run.
    pub cycle_detected: bool,
    /// Whether the size limit stopped the run.
    pub size_limit_exceeded: bool,
    /// Node count of the input expression.
    pub initial_size: usize,
    /// Node count of the resulting expression.
    pub final_size: usize,
}

impl RewriteStats {
    /// Percentage by which the expression shrank (negative when it grew).
    ///
    /// Returns `0.0` for an empty input to avoid dividing by zero.
    pub fn reduction_percentage(&self) -> f64 {
        match self.initial_size {
            0 => 0.0,
            initial => {
                let remaining = self.final_size as f64 / initial as f64;
                (1.0 - remaining) * 100.0
            }
        }
    }

    /// `true` when a fixpoint was reached and neither a cycle nor the size limit
    /// interrupted the run.
    pub fn is_successful(&self) -> bool {
        !(self.cycle_detected || self.size_limit_exceeded) && self.reached_fixpoint
    }
}
236
237fn expr_hash(expr: &TLExpr) -> u64 {
239 let mut hasher = std::collections::hash_map::DefaultHasher::new();
240 format!("{:?}", expr).hash(&mut hasher);
242 hasher.finish()
243}
244
245pub struct AdvancedRewriteSystem {
247 rules: Vec<ConditionalRule>,
249 config: RewriteConfig,
251 seen_hashes: HashSet<u64>,
253}
254
255impl AdvancedRewriteSystem {
256 pub fn new() -> Self {
258 Self {
259 rules: Vec::new(),
260 config: RewriteConfig::default(),
261 seen_hashes: HashSet::new(),
262 }
263 }
264
265 pub fn with_config(config: RewriteConfig) -> Self {
267 Self {
268 rules: Vec::new(),
269 config,
270 seen_hashes: HashSet::new(),
271 }
272 }
273
274 pub fn add_rule(mut self, rule: ConditionalRule) -> Self {
276 self.rules.push(rule);
277 self.rules.sort_by_key(|r| Reverse(r.priority));
279 self
280 }
281
282 pub fn apply(&mut self, expr: &TLExpr) -> (TLExpr, RewriteStats) {
284 let initial_stats = ExprStats::compute(expr);
285 let mut stats = RewriteStats {
286 initial_size: initial_stats.node_count,
287 ..Default::default()
288 };
289
290 self.seen_hashes.clear();
291
292 let result = match self.config.strategy {
293 RewriteStrategy::Innermost => self.apply_innermost(expr, &mut stats),
294 RewriteStrategy::Outermost => self.apply_outermost(expr, &mut stats),
295 RewriteStrategy::BottomUp => self.apply_bottom_up(expr, &mut stats),
296 RewriteStrategy::TopDown => self.apply_top_down(expr, &mut stats),
297 RewriteStrategy::FixpointPerNode => self.apply_fixpoint_per_node(expr, &mut stats),
298 RewriteStrategy::GlobalFixpoint => self.apply_global_fixpoint(expr, &mut stats),
299 };
300
301 let final_stats = ExprStats::compute(&result);
302 stats.final_size = final_stats.node_count;
303
304 if stats.steps < self.config.max_steps && !stats.cycle_detected {
306 stats.reached_fixpoint = true;
307 }
308
309 (result, stats)
310 }
311
312 fn try_apply_at_node(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> Option<TLExpr> {
314 for rule in &mut self.rules {
315 if let Some(result) = rule.apply(expr) {
316 stats.rule_applications += 1;
317 *stats.rule_counts.entry(rule.name.clone()).or_insert(0) += 1;
318
319 if self.config.trace {
320 eprintln!("Applied rule '{}' at step {}", rule.name, stats.steps);
321 }
322
323 return Some(result);
324 }
325 }
326 None
327 }
328
329 fn check_constraints(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> bool {
331 if self.config.detect_cycles {
333 let hash = expr_hash(expr);
334 if self.seen_hashes.contains(&hash) {
335 stats.cycle_detected = true;
336 return false;
337 }
338 self.seen_hashes.insert(hash);
339 }
340
341 if let Some(max_size) = self.config.max_expr_size {
343 let current_stats = ExprStats::compute(expr);
344 if current_stats.node_count > max_size {
345 stats.size_limit_exceeded = true;
346 return false;
347 }
348 }
349
350 if stats.steps >= self.config.max_steps {
352 return false;
353 }
354
355 true
356 }
357
358 fn apply_innermost(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> TLExpr {
360 let mut current = expr.clone();
361
362 while stats.steps < self.config.max_steps {
363 stats.steps += 1;
364
365 if !self.check_constraints(¤t, stats) {
366 break;
367 }
368
369 if let Some(rewritten) = self.rewrite_innermost(¤t, stats) {
371 current = rewritten;
372 } else {
373 break; }
375 }
376
377 current
378 }
379
380 fn rewrite_innermost(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> Option<TLExpr> {
382 let children_rewritten = self.rewrite_children(expr, stats);
384 if let Some(new_expr) = children_rewritten {
385 return Some(new_expr);
386 }
387
388 self.try_apply_at_node(expr, stats)
390 }
391
392 fn apply_outermost(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> TLExpr {
394 let mut current = expr.clone();
395
396 while stats.steps < self.config.max_steps {
397 stats.steps += 1;
398
399 if !self.check_constraints(¤t, stats) {
400 break;
401 }
402
403 if let Some(rewritten) = self.try_apply_at_node(¤t, stats) {
405 current = rewritten;
406 continue;
407 }
408
409 if let Some(rewritten) = self.rewrite_children(¤t, stats) {
411 current = rewritten;
412 } else {
413 break;
414 }
415 }
416
417 current
418 }
419
420 fn apply_bottom_up(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> TLExpr {
422 stats.steps += 1;
423
424 if !self.check_constraints(expr, stats) {
425 return expr.clone();
426 }
427
428 let with_transformed_children = self.transform_children_bottom_up(expr, stats);
430
431 if let Some(result) = self.try_apply_at_node(&with_transformed_children, stats) {
433 result
434 } else {
435 with_transformed_children
436 }
437 }
438
439 fn apply_top_down(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> TLExpr {
441 stats.steps += 1;
442
443 if !self.check_constraints(expr, stats) {
444 return expr.clone();
445 }
446
447 let current = if let Some(result) = self.try_apply_at_node(expr, stats) {
449 result
450 } else {
451 expr.clone()
452 };
453
454 self.transform_children_top_down(¤t, stats)
456 }
457
458 fn apply_fixpoint_per_node(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> TLExpr {
460 let mut current = expr.clone();
461
462 while let Some(rewritten) = self.try_apply_at_node(¤t, stats) {
464 current = rewritten;
465 stats.steps += 1;
466 if !self.check_constraints(¤t, stats) {
467 return current;
468 }
469 }
470
471 self.transform_children_fixpoint_per_node(¤t, stats)
473 }
474
475 fn apply_global_fixpoint(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> TLExpr {
477 let mut current = expr.clone();
478
479 loop {
480 stats.steps += 1;
481
482 if !self.check_constraints(¤t, stats) {
483 break;
484 }
485
486 let next = self.apply_bottom_up(¤t, stats);
487 if next == current {
488 break; }
490 current = next;
491 }
492
493 current
494 }
495
496 fn rewrite_children(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> Option<TLExpr> {
498 match expr {
499 TLExpr::And(l, r) => {
500 let l_new = self.rewrite_innermost(l, stats);
501 let r_new = self.rewrite_innermost(r, stats);
502 if l_new.is_some() || r_new.is_some() {
503 Some(TLExpr::and(
504 l_new.unwrap_or_else(|| (**l).clone()),
505 r_new.unwrap_or_else(|| (**r).clone()),
506 ))
507 } else {
508 None
509 }
510 }
511 TLExpr::Or(l, r) => {
512 let l_new = self.rewrite_innermost(l, stats);
513 let r_new = self.rewrite_innermost(r, stats);
514 if l_new.is_some() || r_new.is_some() {
515 Some(TLExpr::or(
516 l_new.unwrap_or_else(|| (**l).clone()),
517 r_new.unwrap_or_else(|| (**r).clone()),
518 ))
519 } else {
520 None
521 }
522 }
523 TLExpr::Not(e) => self.rewrite_innermost(e, stats).map(TLExpr::negate),
524 _ => None,
526 }
527 }
528
529 fn transform_children_bottom_up(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> TLExpr {
531 match expr {
532 TLExpr::And(l, r) => TLExpr::and(
533 self.apply_bottom_up(l, stats),
534 self.apply_bottom_up(r, stats),
535 ),
536 TLExpr::Or(l, r) => TLExpr::or(
537 self.apply_bottom_up(l, stats),
538 self.apply_bottom_up(r, stats),
539 ),
540 TLExpr::Not(e) => TLExpr::negate(self.apply_bottom_up(e, stats)),
541 TLExpr::Imply(l, r) => TLExpr::imply(
542 self.apply_bottom_up(l, stats),
543 self.apply_bottom_up(r, stats),
544 ),
545 _ => expr.clone(),
547 }
548 }
549
550 fn transform_children_top_down(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> TLExpr {
552 match expr {
553 TLExpr::And(l, r) => {
554 TLExpr::and(self.apply_top_down(l, stats), self.apply_top_down(r, stats))
555 }
556 TLExpr::Or(l, r) => {
557 TLExpr::or(self.apply_top_down(l, stats), self.apply_top_down(r, stats))
558 }
559 TLExpr::Not(e) => TLExpr::negate(self.apply_top_down(e, stats)),
560 TLExpr::Imply(l, r) => {
561 TLExpr::imply(self.apply_top_down(l, stats), self.apply_top_down(r, stats))
562 }
563 _ => expr.clone(),
565 }
566 }
567
568 fn transform_children_fixpoint_per_node(
570 &mut self,
571 expr: &TLExpr,
572 stats: &mut RewriteStats,
573 ) -> TLExpr {
574 match expr {
575 TLExpr::And(l, r) => TLExpr::and(
576 self.apply_fixpoint_per_node(l, stats),
577 self.apply_fixpoint_per_node(r, stats),
578 ),
579 TLExpr::Or(l, r) => TLExpr::or(
580 self.apply_fixpoint_per_node(l, stats),
581 self.apply_fixpoint_per_node(r, stats),
582 ),
583 TLExpr::Not(e) => TLExpr::negate(self.apply_fixpoint_per_node(e, stats)),
584 TLExpr::Imply(l, r) => TLExpr::imply(
585 self.apply_fixpoint_per_node(l, stats),
586 self.apply_fixpoint_per_node(r, stats),
587 ),
588 _ => expr.clone(),
590 }
591 }
592
593 pub fn rule_statistics(&self) -> Vec<(&str, usize)> {
595 self.rules
596 .iter()
597 .map(|r| (r.name.as_str(), r.application_count()))
598 .collect()
599 }
600
601 pub fn reset_statistics(&mut self) {
603 for rule in &mut self.rules {
604 rule.reset_counter();
605 }
606 }
607}
608
609impl Default for AdvancedRewriteSystem {
610 fn default() -> Self {
611 Self::new()
612 }
613}
614
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{TLExpr, Term};

    /// Transform shared by several tests: rewrites `Not(Not(e))` to `e`.
    fn strip_double_negation(expr: &TLExpr) -> Option<TLExpr> {
        match expr {
            TLExpr::Not(outer) => match &**outer {
                TLExpr::Not(inner) => Some((**inner).clone()),
                _ => None,
            },
            _ => None,
        }
    }

    /// The atom `P(x)` every test expression is built from.
    fn p_of_x() -> TLExpr {
        TLExpr::pred("P", vec![Term::var("x")])
    }

    /// Wraps `expr` in `depth` nested negations.
    fn negated(expr: TLExpr, depth: usize) -> TLExpr {
        (0..depth).fold(expr, |acc, _| TLExpr::negate(acc))
    }

    #[test]
    fn test_conditional_rule_basic() {
        let mut rule =
            ConditionalRule::new("remove_double_neg", strip_double_negation, |_| true);

        let result = rule.apply(&negated(p_of_x(), 2)).expect("unwrap");

        assert!(matches!(result, TLExpr::Pred { .. }));
        assert_eq!(rule.application_count(), 1);
    }

    #[test]
    fn test_priority_ordering() {
        let added = [
            ("low", RulePriority::Low),
            ("high", RulePriority::High),
            ("critical", RulePriority::Critical),
        ];
        let system = added
            .iter()
            .fold(AdvancedRewriteSystem::new(), |system, &(name, priority)| {
                system.add_rule(
                    ConditionalRule::new(name, |_| None, |_| true).with_priority(priority),
                )
            });

        let order: Vec<RulePriority> = system.rules.iter().map(|rule| rule.priority).collect();
        assert_eq!(
            order,
            vec![RulePriority::Critical, RulePriority::High, RulePriority::Low]
        );
    }

    #[test]
    fn test_bottom_up_strategy() {
        let config = RewriteConfig {
            strategy: RewriteStrategy::BottomUp,
            max_steps: 100,
            ..Default::default()
        };
        let mut system = AdvancedRewriteSystem::with_config(config)
            .add_rule(ConditionalRule::new("double_neg", strip_double_negation, |_| true));

        let (result, stats) = system.apply(&negated(p_of_x(), 4));

        assert!(matches!(result, TLExpr::Pred { .. }));
        assert_eq!(stats.rule_applications, 2);
    }

    #[test]
    fn test_cycle_detection() {
        let config = RewriteConfig {
            strategy: RewriteStrategy::GlobalFixpoint,
            detect_cycles: true,
            max_steps: 1000,
            ..Default::default()
        };
        let mut system = AdvancedRewriteSystem::with_config(config)
            .add_rule(ConditionalRule::new(
                "add_double_neg",
                |expr| match expr {
                    TLExpr::Pred { .. } => Some(TLExpr::negate(TLExpr::negate(expr.clone()))),
                    _ => None,
                },
                |_| true,
            ))
            .add_rule(ConditionalRule::new(
                "remove_double_neg",
                strip_double_negation,
                |_| true,
            ));

        let (_result, stats) = system.apply(&p_of_x());

        assert!(stats.cycle_detected || stats.steps >= 1000);
    }

    #[test]
    fn test_size_limit() {
        let config = RewriteConfig {
            strategy: RewriteStrategy::Innermost,
            max_expr_size: Some(10),
            detect_cycles: false,
            ..Default::default()
        };
        let mut system = AdvancedRewriteSystem::with_config(config).add_rule(ConditionalRule::new(
            "duplicate",
            |expr| match expr {
                TLExpr::Pred { .. } => Some(TLExpr::and(expr.clone(), expr.clone())),
                _ => None,
            },
            |_| true,
        ));

        let (_result, stats) = system.apply(&p_of_x());

        assert!(stats.size_limit_exceeded || stats.steps >= system.config.max_steps);
    }

    #[test]
    fn test_rewrite_stats() {
        let mut system = AdvancedRewriteSystem::new()
            .add_rule(ConditionalRule::new("test_rule", strip_double_negation, |_| true));

        let (_result, stats) = system.apply(&negated(p_of_x(), 2));

        assert!(stats.is_successful());
        assert!(stats.reduction_percentage() > 0.0);
        assert_eq!(stats.rule_applications, 1);
    }
}