1use std::collections::{HashMap, HashSet};
38use std::hash::{Hash, Hasher};
39
40use super::TLExpr;
41use crate::util::ExprStats;
42
/// Priority of a rewrite rule; a rewrite system tries higher-priority rules first.
///
/// Variants carry explicit discriminants and the derived ordering compares those values,
/// so `Critical > High > Normal > Low > Minimal`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub enum RulePriority {
    /// Highest priority (100).
    Critical = 100,
    /// Above-normal priority (75).
    High = 75,
    /// Priority given to rules that never call `with_priority` (50).
    #[default]
    Normal = 50,
    /// Below-normal priority (25).
    Low = 25,
    /// Lowest priority (0).
    Minimal = 0,
}
61
62pub type GuardPredicate = fn(&HashMap<String, TLExpr>) -> bool;
66
67pub type TransformFn = fn(&TLExpr) -> Option<TLExpr>;
71
72#[derive(Clone)]
79pub struct ConditionalRule {
80 pub name: String,
82 pub transform: TransformFn,
84 pub guard: GuardPredicate,
86 pub priority: RulePriority,
88 pub description: Option<String>,
90 applications: usize,
92}
93
94impl ConditionalRule {
95 pub fn new(name: impl Into<String>, transform: TransformFn, guard: GuardPredicate) -> Self {
97 Self {
98 name: name.into(),
99 transform,
100 guard,
101 priority: RulePriority::default(),
102 description: None,
103 applications: 0,
104 }
105 }
106
107 pub fn with_priority(mut self, priority: RulePriority) -> Self {
109 self.priority = priority;
110 self
111 }
112
113 pub fn with_description(mut self, description: impl Into<String>) -> Self {
115 self.description = Some(description.into());
116 self
117 }
118
119 pub fn apply(&mut self, expr: &TLExpr) -> Option<TLExpr> {
123 let bindings = HashMap::new(); if (self.guard)(&bindings) {
125 if let Some(result) = (self.transform)(expr) {
126 self.applications += 1;
127 return Some(result);
128 }
129 }
130 None
131 }
132
133 pub fn application_count(&self) -> usize {
135 self.applications
136 }
137
138 pub fn reset_counter(&mut self) {
140 self.applications = 0;
141 }
142}
143
144impl std::fmt::Debug for ConditionalRule {
145 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146 f.debug_struct("ConditionalRule")
147 .field("name", &self.name)
148 .field("priority", &self.priority)
149 .field("description", &self.description)
150 .field("applications", &self.applications)
151 .finish()
152 }
153}
154
/// Order in which an [`AdvancedRewriteSystem`] visits nodes and applies rules.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RewriteStrategy {
    /// Repeats whole-expression rewrite steps until nothing changes or a safeguard stops it.
    /// Each step prefers rewrites inside sub-expressions over a rewrite at the enclosing node.
    Innermost,
    /// Repeats whole-expression rewrite steps until nothing changes or a safeguard stops it.
    /// Each step tries the root first and looks inside sub-expressions only when no rule
    /// applies there.
    Outermost,
    /// Single pass: rewrite the children first, then try one rule at the rebuilt node.
    #[default]
    BottomUp,
    /// Single pass: try one rule at the node, then descend into the children of the result.
    TopDown,
    /// Rewrite each node until no rule applies to it, then descend into its children.
    FixpointPerNode,
    /// Repeat bottom-up passes until the expression stops changing, or a limit or cycle
    /// check ends the run.
    GlobalFixpoint,
}
172
173#[derive(Debug, Clone)]
175pub struct RewriteConfig {
176 pub max_steps: usize,
178 pub strategy: RewriteStrategy,
180 pub detect_cycles: bool,
182 pub trace: bool,
184 pub max_expr_size: Option<usize>,
186}
187
188impl Default for RewriteConfig {
189 fn default() -> Self {
190 Self {
191 max_steps: 10000,
192 strategy: RewriteStrategy::default(),
193 detect_cycles: true,
194 trace: false,
195 max_expr_size: Some(100000), }
197 }
198}
199
/// Statistics collected during a single rewrite run.
#[derive(Debug, Clone, Default)]
pub struct RewriteStats {
    /// Steps taken, as counted by the active strategy.
    pub steps: usize,
    /// Total number of successful rule applications.
    pub rule_applications: usize,
    /// Successful applications keyed by rule name.
    pub rule_counts: HashMap<String, usize>,
    /// Whether the engine reported the run as finished (criterion set by the engine).
    pub reached_fixpoint: bool,
    /// Whether a repeated expression was detected.
    pub cycle_detected: bool,
    /// Whether an expression exceeded the configured size cap.
    pub size_limit_exceeded: bool,
    /// Node count of the input expression.
    pub initial_size: usize,
    /// Node count of the output expression.
    pub final_size: usize,
}

impl RewriteStats {
    /// Percentage by which the expression shrank, relative to its initial size.
    ///
    /// Returns `0.0` for an empty input. The result is negative when the expression grew.
    pub fn reduction_percentage(&self) -> f64 {
        match self.initial_size {
            0 => 0.0,
            initial => {
                let ratio = self.final_size as f64 / initial as f64;
                100.0 * (1.0 - ratio)
            }
        }
    }

    /// True when the run reached a fixpoint with neither a cycle nor a size-cap violation.
    pub fn is_successful(&self) -> bool {
        self.reached_fixpoint && !(self.cycle_detected || self.size_limit_exceeded)
    }
}
235
236fn expr_hash(expr: &TLExpr) -> u64 {
238 let mut hasher = std::collections::hash_map::DefaultHasher::new();
239 format!("{:?}", expr).hash(&mut hasher);
241 hasher.finish()
242}
243
244pub struct AdvancedRewriteSystem {
246 rules: Vec<ConditionalRule>,
248 config: RewriteConfig,
250 seen_hashes: HashSet<u64>,
252}
253
254impl AdvancedRewriteSystem {
255 pub fn new() -> Self {
257 Self {
258 rules: Vec::new(),
259 config: RewriteConfig::default(),
260 seen_hashes: HashSet::new(),
261 }
262 }
263
264 pub fn with_config(config: RewriteConfig) -> Self {
266 Self {
267 rules: Vec::new(),
268 config,
269 seen_hashes: HashSet::new(),
270 }
271 }
272
273 pub fn add_rule(mut self, rule: ConditionalRule) -> Self {
275 self.rules.push(rule);
276 self.rules.sort_by(|a, b| b.priority.cmp(&a.priority));
278 self
279 }
280
281 pub fn apply(&mut self, expr: &TLExpr) -> (TLExpr, RewriteStats) {
283 let initial_stats = ExprStats::compute(expr);
284 let mut stats = RewriteStats {
285 initial_size: initial_stats.node_count,
286 ..Default::default()
287 };
288
289 self.seen_hashes.clear();
290
291 let result = match self.config.strategy {
292 RewriteStrategy::Innermost => self.apply_innermost(expr, &mut stats),
293 RewriteStrategy::Outermost => self.apply_outermost(expr, &mut stats),
294 RewriteStrategy::BottomUp => self.apply_bottom_up(expr, &mut stats),
295 RewriteStrategy::TopDown => self.apply_top_down(expr, &mut stats),
296 RewriteStrategy::FixpointPerNode => self.apply_fixpoint_per_node(expr, &mut stats),
297 RewriteStrategy::GlobalFixpoint => self.apply_global_fixpoint(expr, &mut stats),
298 };
299
300 let final_stats = ExprStats::compute(&result);
301 stats.final_size = final_stats.node_count;
302
303 if stats.steps < self.config.max_steps && !stats.cycle_detected {
305 stats.reached_fixpoint = true;
306 }
307
308 (result, stats)
309 }
310
311 fn try_apply_at_node(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> Option<TLExpr> {
313 for rule in &mut self.rules {
314 if let Some(result) = rule.apply(expr) {
315 stats.rule_applications += 1;
316 *stats.rule_counts.entry(rule.name.clone()).or_insert(0) += 1;
317
318 if self.config.trace {
319 eprintln!("Applied rule '{}' at step {}", rule.name, stats.steps);
320 }
321
322 return Some(result);
323 }
324 }
325 None
326 }
327
328 fn check_constraints(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> bool {
330 if self.config.detect_cycles {
332 let hash = expr_hash(expr);
333 if self.seen_hashes.contains(&hash) {
334 stats.cycle_detected = true;
335 return false;
336 }
337 self.seen_hashes.insert(hash);
338 }
339
340 if let Some(max_size) = self.config.max_expr_size {
342 let current_stats = ExprStats::compute(expr);
343 if current_stats.node_count > max_size {
344 stats.size_limit_exceeded = true;
345 return false;
346 }
347 }
348
349 if stats.steps >= self.config.max_steps {
351 return false;
352 }
353
354 true
355 }
356
357 fn apply_innermost(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> TLExpr {
359 let mut current = expr.clone();
360
361 while stats.steps < self.config.max_steps {
362 stats.steps += 1;
363
364 if !self.check_constraints(¤t, stats) {
365 break;
366 }
367
368 if let Some(rewritten) = self.rewrite_innermost(¤t, stats) {
370 current = rewritten;
371 } else {
372 break; }
374 }
375
376 current
377 }
378
379 fn rewrite_innermost(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> Option<TLExpr> {
381 let children_rewritten = self.rewrite_children(expr, stats);
383 if let Some(new_expr) = children_rewritten {
384 return Some(new_expr);
385 }
386
387 self.try_apply_at_node(expr, stats)
389 }
390
391 fn apply_outermost(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> TLExpr {
393 let mut current = expr.clone();
394
395 while stats.steps < self.config.max_steps {
396 stats.steps += 1;
397
398 if !self.check_constraints(¤t, stats) {
399 break;
400 }
401
402 if let Some(rewritten) = self.try_apply_at_node(¤t, stats) {
404 current = rewritten;
405 continue;
406 }
407
408 if let Some(rewritten) = self.rewrite_children(¤t, stats) {
410 current = rewritten;
411 } else {
412 break;
413 }
414 }
415
416 current
417 }
418
419 fn apply_bottom_up(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> TLExpr {
421 stats.steps += 1;
422
423 if !self.check_constraints(expr, stats) {
424 return expr.clone();
425 }
426
427 let with_transformed_children = self.transform_children_bottom_up(expr, stats);
429
430 if let Some(result) = self.try_apply_at_node(&with_transformed_children, stats) {
432 result
433 } else {
434 with_transformed_children
435 }
436 }
437
438 fn apply_top_down(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> TLExpr {
440 stats.steps += 1;
441
442 if !self.check_constraints(expr, stats) {
443 return expr.clone();
444 }
445
446 let current = if let Some(result) = self.try_apply_at_node(expr, stats) {
448 result
449 } else {
450 expr.clone()
451 };
452
453 self.transform_children_top_down(¤t, stats)
455 }
456
457 fn apply_fixpoint_per_node(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> TLExpr {
459 let mut current = expr.clone();
460
461 while let Some(rewritten) = self.try_apply_at_node(¤t, stats) {
463 current = rewritten;
464 stats.steps += 1;
465 if !self.check_constraints(¤t, stats) {
466 return current;
467 }
468 }
469
470 self.transform_children_fixpoint_per_node(¤t, stats)
472 }
473
474 fn apply_global_fixpoint(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> TLExpr {
476 let mut current = expr.clone();
477
478 loop {
479 stats.steps += 1;
480
481 if !self.check_constraints(¤t, stats) {
482 break;
483 }
484
485 let next = self.apply_bottom_up(¤t, stats);
486 if next == current {
487 break; }
489 current = next;
490 }
491
492 current
493 }
494
495 fn rewrite_children(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> Option<TLExpr> {
497 match expr {
498 TLExpr::And(l, r) => {
499 let l_new = self.rewrite_innermost(l, stats);
500 let r_new = self.rewrite_innermost(r, stats);
501 if l_new.is_some() || r_new.is_some() {
502 Some(TLExpr::and(
503 l_new.unwrap_or_else(|| (**l).clone()),
504 r_new.unwrap_or_else(|| (**r).clone()),
505 ))
506 } else {
507 None
508 }
509 }
510 TLExpr::Or(l, r) => {
511 let l_new = self.rewrite_innermost(l, stats);
512 let r_new = self.rewrite_innermost(r, stats);
513 if l_new.is_some() || r_new.is_some() {
514 Some(TLExpr::or(
515 l_new.unwrap_or_else(|| (**l).clone()),
516 r_new.unwrap_or_else(|| (**r).clone()),
517 ))
518 } else {
519 None
520 }
521 }
522 TLExpr::Not(e) => self.rewrite_innermost(e, stats).map(TLExpr::negate),
523 _ => None,
525 }
526 }
527
528 fn transform_children_bottom_up(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> TLExpr {
530 match expr {
531 TLExpr::And(l, r) => TLExpr::and(
532 self.apply_bottom_up(l, stats),
533 self.apply_bottom_up(r, stats),
534 ),
535 TLExpr::Or(l, r) => TLExpr::or(
536 self.apply_bottom_up(l, stats),
537 self.apply_bottom_up(r, stats),
538 ),
539 TLExpr::Not(e) => TLExpr::negate(self.apply_bottom_up(e, stats)),
540 TLExpr::Imply(l, r) => TLExpr::imply(
541 self.apply_bottom_up(l, stats),
542 self.apply_bottom_up(r, stats),
543 ),
544 _ => expr.clone(),
546 }
547 }
548
549 fn transform_children_top_down(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> TLExpr {
551 match expr {
552 TLExpr::And(l, r) => {
553 TLExpr::and(self.apply_top_down(l, stats), self.apply_top_down(r, stats))
554 }
555 TLExpr::Or(l, r) => {
556 TLExpr::or(self.apply_top_down(l, stats), self.apply_top_down(r, stats))
557 }
558 TLExpr::Not(e) => TLExpr::negate(self.apply_top_down(e, stats)),
559 TLExpr::Imply(l, r) => {
560 TLExpr::imply(self.apply_top_down(l, stats), self.apply_top_down(r, stats))
561 }
562 _ => expr.clone(),
564 }
565 }
566
567 fn transform_children_fixpoint_per_node(
569 &mut self,
570 expr: &TLExpr,
571 stats: &mut RewriteStats,
572 ) -> TLExpr {
573 match expr {
574 TLExpr::And(l, r) => TLExpr::and(
575 self.apply_fixpoint_per_node(l, stats),
576 self.apply_fixpoint_per_node(r, stats),
577 ),
578 TLExpr::Or(l, r) => TLExpr::or(
579 self.apply_fixpoint_per_node(l, stats),
580 self.apply_fixpoint_per_node(r, stats),
581 ),
582 TLExpr::Not(e) => TLExpr::negate(self.apply_fixpoint_per_node(e, stats)),
583 TLExpr::Imply(l, r) => TLExpr::imply(
584 self.apply_fixpoint_per_node(l, stats),
585 self.apply_fixpoint_per_node(r, stats),
586 ),
587 _ => expr.clone(),
589 }
590 }
591
592 pub fn rule_statistics(&self) -> Vec<(&str, usize)> {
594 self.rules
595 .iter()
596 .map(|r| (r.name.as_str(), r.application_count()))
597 .collect()
598 }
599
600 pub fn reset_statistics(&mut self) {
602 for rule in &mut self.rules {
603 rule.reset_counter();
604 }
605 }
606}
607
608impl Default for AdvancedRewriteSystem {
609 fn default() -> Self {
610 Self::new()
611 }
612}
613
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{TLExpr, Term};

    /// Shared transform: rewrites `Not(Not(e))` to `e`.
    fn strip_double_negation(expr: &TLExpr) -> Option<TLExpr> {
        match expr {
            TLExpr::Not(inner) => match &**inner {
                TLExpr::Not(inner_inner) => Some((**inner_inner).clone()),
                _ => None,
            },
            _ => None,
        }
    }

    /// Guard that accepts every binding set.
    fn always(_: &std::collections::HashMap<String, TLExpr>) -> bool {
        true
    }

    /// The atom `P(x)` that most tests start from.
    fn p_of_x() -> TLExpr {
        TLExpr::pred("P", vec![Term::var("x")])
    }

    #[test]
    fn test_conditional_rule_basic() {
        let mut rule = ConditionalRule::new("remove_double_neg", strip_double_negation, always);

        let expr = TLExpr::negate(TLExpr::negate(p_of_x()));
        let result = rule.apply(&expr).unwrap();

        assert!(matches!(result, TLExpr::Pred { .. }));
        assert_eq!(rule.application_count(), 1);
    }

    #[test]
    fn test_priority_ordering() {
        let mut system = AdvancedRewriteSystem::new();
        for (name, priority) in [
            ("low", RulePriority::Low),
            ("high", RulePriority::High),
            ("critical", RulePriority::Critical),
        ] {
            system = system
                .add_rule(ConditionalRule::new(name, |_| None, always).with_priority(priority));
        }

        let priorities: Vec<RulePriority> = system.rules.iter().map(|rule| rule.priority).collect();
        assert_eq!(
            priorities,
            [RulePriority::Critical, RulePriority::High, RulePriority::Low]
        );
    }

    #[test]
    fn test_bottom_up_strategy() {
        let config = RewriteConfig {
            strategy: RewriteStrategy::BottomUp,
            max_steps: 100,
            ..Default::default()
        };
        let mut system = AdvancedRewriteSystem::with_config(config)
            .add_rule(ConditionalRule::new("double_neg", strip_double_negation, always));

        // Four stacked negations around P(x).
        let expr = (0..4).fold(p_of_x(), |acc, _| TLExpr::negate(acc));
        let (result, stats) = system.apply(&expr);

        assert!(matches!(result, TLExpr::Pred { .. }));
        // Two double-negation removals collapse the four negations.
        assert_eq!(stats.rule_applications, 2);
    }

    #[test]
    fn test_cycle_detection() {
        fn wrap_in_double_negation(expr: &TLExpr) -> Option<TLExpr> {
            matches!(expr, TLExpr::Pred { .. })
                .then(|| TLExpr::negate(TLExpr::negate(expr.clone())))
        }

        let config = RewriteConfig {
            strategy: RewriteStrategy::GlobalFixpoint,
            detect_cycles: true,
            max_steps: 1000,
            ..Default::default()
        };
        // Same priority for both rules, so insertion order decides which is tried first.
        let mut system = AdvancedRewriteSystem::with_config(config)
            .add_rule(ConditionalRule::new(
                "add_double_neg",
                wrap_in_double_negation,
                always,
            ))
            .add_rule(ConditionalRule::new(
                "remove_double_neg",
                strip_double_negation,
                always,
            ));

        let (_result, stats) = system.apply(&p_of_x());

        assert!(stats.cycle_detected || stats.steps >= 1000);
    }

    #[test]
    fn test_size_limit() {
        fn duplicate(expr: &TLExpr) -> Option<TLExpr> {
            matches!(expr, TLExpr::Pred { .. }).then(|| TLExpr::and(expr.clone(), expr.clone()))
        }

        let config = RewriteConfig {
            strategy: RewriteStrategy::Innermost,
            max_expr_size: Some(10),
            detect_cycles: false,
            ..Default::default()
        };
        let mut system = AdvancedRewriteSystem::with_config(config)
            .add_rule(ConditionalRule::new("duplicate", duplicate, always));

        let (_result, stats) = system.apply(&p_of_x());

        assert!(stats.size_limit_exceeded || stats.steps >= system.config.max_steps);
    }

    #[test]
    fn test_rewrite_stats() {
        let mut system = AdvancedRewriteSystem::new()
            .add_rule(ConditionalRule::new("test_rule", strip_double_negation, always));

        let expr = TLExpr::negate(TLExpr::negate(p_of_x()));
        let (_result, stats) = system.apply(&expr);

        assert!(stats.is_successful());
        assert!(stats.reduction_percentage() > 0.0);
        assert_eq!(stats.rule_applications, 1);
    }
}