1use std::collections::HashMap;
7
8use super::TLExpr;
9
/// A structural pattern over `TLExpr` trees.
///
/// Patterns are matched with `Pattern::matches`, which returns the
/// variable bindings collected during the match. They form the left-hand
/// side of a `RewriteRule`.
#[derive(Clone, Debug, PartialEq)]
pub enum Pattern {
    /// Pattern variable. Binds the matched sub-expression under this name.
    /// If the name is already bound, it requires an equal (`==`) sub-expression.
    Var(String),

    /// Numeric constant, compared against `TLExpr::Constant` with an
    /// absolute tolerance of `f64::EPSILON`.
    Constant(f64),

    /// Predicate with the given name. Only the name and the number of
    /// arguments are compared; the argument patterns are not matched.
    Pred {
        name: String,
        args: Vec<Pattern>,
    },

    /// Logical conjunction (`TLExpr::And`).
    And(Box<Pattern>, Box<Pattern>),

    /// Logical disjunction (`TLExpr::Or`).
    Or(Box<Pattern>, Box<Pattern>),

    /// Logical negation (`TLExpr::Not`).
    Not(Box<Pattern>),

    /// Logical implication (`TLExpr::Imply`).
    Imply(Box<Pattern>, Box<Pattern>),

    /// Wildcard: matches any expression and binds nothing.
    Any,

    /// Arithmetic addition (`TLExpr::Add`).
    Add(Box<Pattern>, Box<Pattern>),

    /// Arithmetic subtraction (`TLExpr::Sub`).
    Sub(Box<Pattern>, Box<Pattern>),

    /// Arithmetic multiplication (`TLExpr::Mul`).
    Mul(Box<Pattern>, Box<Pattern>),

    /// Arithmetic division (`TLExpr::Div`).
    Div(Box<Pattern>, Box<Pattern>),

    /// Exponentiation (`TLExpr::Pow`).
    Pow(Box<Pattern>, Box<Pattern>),

    /// Arithmetic negation, matched as `0 - inner` (a `TLExpr::Sub` whose
    /// left operand is a constant within `1e-15` of zero).
    Neg(Box<Pattern>),

    /// Exponential function (`TLExpr::Exp`).
    Exp(Box<Pattern>),

    /// Logarithm (`TLExpr::Log`).
    Log(Box<Pattern>),

    /// Sine (`TLExpr::Sin`).
    Sin(Box<Pattern>),

    /// Cosine (`TLExpr::Cos`).
    Cos(Box<Pattern>),

    /// Tangent (`TLExpr::Tan`).
    Tan(Box<Pattern>),
}
52
53impl Pattern {
54 pub fn var(name: impl Into<String>) -> Self {
56 Pattern::Var(name.into())
57 }
58
59 pub fn constant(value: f64) -> Self {
61 Pattern::Constant(value)
62 }
63
64 pub fn any() -> Self {
66 Pattern::Any
67 }
68
69 pub fn pred(name: impl Into<String>, args: Vec<Pattern>) -> Self {
71 Pattern::Pred {
72 name: name.into(),
73 args,
74 }
75 }
76
77 pub fn and(left: Pattern, right: Pattern) -> Self {
79 Pattern::And(Box::new(left), Box::new(right))
80 }
81
82 pub fn or(left: Pattern, right: Pattern) -> Self {
84 Pattern::Or(Box::new(left), Box::new(right))
85 }
86
87 pub fn negation(pattern: Pattern) -> Self {
89 Pattern::Not(Box::new(pattern))
90 }
91
92 pub fn imply(left: Pattern, right: Pattern) -> Self {
94 Pattern::Imply(Box::new(left), Box::new(right))
95 }
96
97 #[allow(clippy::should_implement_trait)]
99 pub fn add(left: Pattern, right: Pattern) -> Self {
100 Pattern::Add(Box::new(left), Box::new(right))
101 }
102
103 #[allow(clippy::should_implement_trait)]
105 pub fn sub(left: Pattern, right: Pattern) -> Self {
106 Pattern::Sub(Box::new(left), Box::new(right))
107 }
108
109 #[allow(clippy::should_implement_trait)]
111 pub fn mul(left: Pattern, right: Pattern) -> Self {
112 Pattern::Mul(Box::new(left), Box::new(right))
113 }
114
115 #[allow(clippy::should_implement_trait)]
117 pub fn div(left: Pattern, right: Pattern) -> Self {
118 Pattern::Div(Box::new(left), Box::new(right))
119 }
120
121 pub fn pow(left: Pattern, right: Pattern) -> Self {
123 Pattern::Pow(Box::new(left), Box::new(right))
124 }
125
126 #[allow(clippy::should_implement_trait)]
128 pub fn neg(inner: Pattern) -> Self {
129 Pattern::Neg(Box::new(inner))
130 }
131
132 pub fn exp(inner: Pattern) -> Self {
134 Pattern::Exp(Box::new(inner))
135 }
136
137 pub fn log(inner: Pattern) -> Self {
139 Pattern::Log(Box::new(inner))
140 }
141
142 pub fn sin(inner: Pattern) -> Self {
144 Pattern::Sin(Box::new(inner))
145 }
146
147 pub fn cos(inner: Pattern) -> Self {
149 Pattern::Cos(Box::new(inner))
150 }
151
152 pub fn tan(inner: Pattern) -> Self {
154 Pattern::Tan(Box::new(inner))
155 }
156
157 pub fn matches(&self, expr: &TLExpr) -> Option<HashMap<String, TLExpr>> {
159 let mut bindings = HashMap::new();
160 if self.matches_recursive(expr, &mut bindings) {
161 Some(bindings)
162 } else {
163 None
164 }
165 }
166
167 fn matches_recursive(&self, expr: &TLExpr, bindings: &mut HashMap<String, TLExpr>) -> bool {
168 match (self, expr) {
169 (Pattern::Any, _) => true,
171
172 (Pattern::Var(var_name), _) => {
174 if let Some(bound_expr) = bindings.get(var_name) {
175 bound_expr == expr
176 } else {
177 bindings.insert(var_name.clone(), expr.clone());
178 true
179 }
180 }
181
182 (Pattern::Constant(pv), TLExpr::Constant(ev)) => (pv - ev).abs() < f64::EPSILON,
184
185 (
187 Pattern::Pred {
188 name: pname,
189 args: pargs,
190 },
191 TLExpr::Pred {
192 name: ename,
193 args: eargs,
194 },
195 ) => {
196 if pname != ename || pargs.len() != eargs.len() {
197 return false;
198 }
199 pargs.len() == eargs.len()
202 }
203
204 (Pattern::And(pl, pr), TLExpr::And(el, er))
206 | (Pattern::Or(pl, pr), TLExpr::Or(el, er))
207 | (Pattern::Imply(pl, pr), TLExpr::Imply(el, er)) => {
208 pl.matches_recursive(el, bindings) && pr.matches_recursive(er, bindings)
209 }
210
211 (Pattern::Add(pl, pr), TLExpr::Add(el, er))
213 | (Pattern::Sub(pl, pr), TLExpr::Sub(el, er))
214 | (Pattern::Mul(pl, pr), TLExpr::Mul(el, er))
215 | (Pattern::Div(pl, pr), TLExpr::Div(el, er))
216 | (Pattern::Pow(pl, pr), TLExpr::Pow(el, er)) => {
217 pl.matches_recursive(el, bindings) && pr.matches_recursive(er, bindings)
218 }
219
220 (Pattern::Not(p), TLExpr::Not(e)) => p.matches_recursive(e, bindings),
222
223 (Pattern::Neg(p), TLExpr::Sub(zero_expr, e)) => {
225 if let TLExpr::Constant(v) = zero_expr.as_ref() {
226 v.abs() < 1e-15 && p.matches_recursive(e, bindings)
227 } else {
228 false
229 }
230 }
231
232 (Pattern::Exp(p), TLExpr::Exp(e))
234 | (Pattern::Log(p), TLExpr::Log(e))
235 | (Pattern::Sin(p), TLExpr::Sin(e))
236 | (Pattern::Cos(p), TLExpr::Cos(e))
237 | (Pattern::Tan(p), TLExpr::Tan(e)) => p.matches_recursive(e, bindings),
238
239 _ => false,
240 }
241 }
242}
243
244#[derive(Clone, Debug)]
246pub struct RewriteRule {
247 pub pattern: Pattern,
249 pub template: fn(&HashMap<String, TLExpr>) -> TLExpr,
251 pub name: Option<String>,
253}
254
255impl RewriteRule {
256 pub fn new(pattern: Pattern, template: fn(&HashMap<String, TLExpr>) -> TLExpr) -> Self {
258 Self {
259 pattern,
260 template,
261 name: None,
262 }
263 }
264
265 pub fn named(
267 name: impl Into<String>,
268 pattern: Pattern,
269 template: fn(&HashMap<String, TLExpr>) -> TLExpr,
270 ) -> Self {
271 Self {
272 pattern,
273 template,
274 name: Some(name.into()),
275 }
276 }
277
278 pub fn apply(&self, expr: &TLExpr) -> Option<TLExpr> {
280 self.pattern
281 .matches(expr)
282 .map(|bindings| (self.template)(&bindings))
283 }
284}
285
286#[derive(Clone, Debug, Default)]
288pub struct RewriteSystem {
289 rules: Vec<RewriteRule>,
290}
291
292impl RewriteSystem {
293 pub fn new() -> Self {
295 Self::default()
296 }
297
298 pub fn add_rule(mut self, rule: RewriteRule) -> Self {
300 self.rules.push(rule);
301 self
302 }
303
304 pub fn with_logic_equivalences() -> Self {
306 let mut system = Self::new();
307
308 system = system.add_rule(RewriteRule::named(
310 "double_negation",
311 Pattern::negation(Pattern::negation(Pattern::var("A"))),
312 |bindings| {
313 bindings
314 .get("A")
315 .expect("binding 'A' must exist when pattern matched")
316 .clone()
317 },
318 ));
319
320 system = system.add_rule(RewriteRule::named(
322 "demorgan_and",
323 Pattern::negation(Pattern::and(Pattern::var("A"), Pattern::var("B"))),
324 |bindings| {
325 TLExpr::or(
326 TLExpr::negate(
327 bindings
328 .get("A")
329 .expect("binding 'A' must exist when pattern matched")
330 .clone(),
331 ),
332 TLExpr::negate(
333 bindings
334 .get("B")
335 .expect("binding 'B' must exist when pattern matched")
336 .clone(),
337 ),
338 )
339 },
340 ));
341
342 system = system.add_rule(RewriteRule::named(
344 "demorgan_or",
345 Pattern::negation(Pattern::or(Pattern::var("A"), Pattern::var("B"))),
346 |bindings| {
347 TLExpr::and(
348 TLExpr::negate(
349 bindings
350 .get("A")
351 .expect("binding 'A' must exist when pattern matched")
352 .clone(),
353 ),
354 TLExpr::negate(
355 bindings
356 .get("B")
357 .expect("binding 'B' must exist when pattern matched")
358 .clone(),
359 ),
360 )
361 },
362 ));
363
364 system = system.add_rule(RewriteRule::named(
366 "implication_expansion",
367 Pattern::imply(Pattern::var("A"), Pattern::var("B")),
368 |bindings| {
369 TLExpr::or(
370 TLExpr::negate(
371 bindings
372 .get("A")
373 .expect("binding 'A' must exist when pattern matched")
374 .clone(),
375 ),
376 bindings
377 .get("B")
378 .expect("binding 'B' must exist when pattern matched")
379 .clone(),
380 )
381 },
382 ));
383
384 system
385 }
386
387 pub fn apply_once(&self, expr: &TLExpr) -> Option<TLExpr> {
389 for rule in &self.rules {
390 if let Some(result) = rule.apply(expr) {
391 return Some(result);
392 }
393 }
394 None
395 }
396
397 pub fn apply_recursive(&self, expr: &TLExpr) -> TLExpr {
399 if let Some(rewritten) = self.apply_once(expr) {
401 return self.apply_recursive(&rewritten);
402 }
403
404 match expr {
406 TLExpr::And(l, r) => TLExpr::and(self.apply_recursive(l), self.apply_recursive(r)),
407 TLExpr::Or(l, r) => TLExpr::or(self.apply_recursive(l), self.apply_recursive(r)),
408 TLExpr::Not(e) => TLExpr::negate(self.apply_recursive(e)),
409 TLExpr::Imply(l, r) => TLExpr::imply(self.apply_recursive(l), self.apply_recursive(r)),
410 TLExpr::Score(e) => TLExpr::score(self.apply_recursive(e)),
411
412 TLExpr::Add(l, r) => TLExpr::add(self.apply_recursive(l), self.apply_recursive(r)),
414 TLExpr::Sub(l, r) => TLExpr::sub(self.apply_recursive(l), self.apply_recursive(r)),
415 TLExpr::Mul(l, r) => TLExpr::mul(self.apply_recursive(l), self.apply_recursive(r)),
416 TLExpr::Div(l, r) => TLExpr::div(self.apply_recursive(l), self.apply_recursive(r)),
417 TLExpr::Pow(l, r) => TLExpr::pow(self.apply_recursive(l), self.apply_recursive(r)),
418 TLExpr::Mod(l, r) => TLExpr::modulo(self.apply_recursive(l), self.apply_recursive(r)),
419 TLExpr::Min(l, r) => TLExpr::min(self.apply_recursive(l), self.apply_recursive(r)),
420 TLExpr::Max(l, r) => TLExpr::max(self.apply_recursive(l), self.apply_recursive(r)),
421
422 TLExpr::Eq(l, r) => TLExpr::eq(self.apply_recursive(l), self.apply_recursive(r)),
424 TLExpr::Lt(l, r) => TLExpr::lt(self.apply_recursive(l), self.apply_recursive(r)),
425 TLExpr::Gt(l, r) => TLExpr::gt(self.apply_recursive(l), self.apply_recursive(r)),
426 TLExpr::Lte(l, r) => TLExpr::lte(self.apply_recursive(l), self.apply_recursive(r)),
427 TLExpr::Gte(l, r) => TLExpr::gte(self.apply_recursive(l), self.apply_recursive(r)),
428
429 TLExpr::Abs(e) => TLExpr::abs(self.apply_recursive(e)),
431 TLExpr::Floor(e) => TLExpr::floor(self.apply_recursive(e)),
432 TLExpr::Ceil(e) => TLExpr::ceil(self.apply_recursive(e)),
433 TLExpr::Round(e) => TLExpr::round(self.apply_recursive(e)),
434 TLExpr::Sqrt(e) => TLExpr::sqrt(self.apply_recursive(e)),
435 TLExpr::Exp(e) => TLExpr::exp(self.apply_recursive(e)),
436 TLExpr::Log(e) => TLExpr::log(self.apply_recursive(e)),
437 TLExpr::Sin(e) => TLExpr::sin(self.apply_recursive(e)),
438 TLExpr::Cos(e) => TLExpr::cos(self.apply_recursive(e)),
439 TLExpr::Tan(e) => TLExpr::tan(self.apply_recursive(e)),
440
441 TLExpr::Box(e) => TLExpr::modal_box(self.apply_recursive(e)),
443 TLExpr::Diamond(e) => TLExpr::modal_diamond(self.apply_recursive(e)),
444 TLExpr::Next(e) => TLExpr::next(self.apply_recursive(e)),
445 TLExpr::Eventually(e) => TLExpr::eventually(self.apply_recursive(e)),
446 TLExpr::Always(e) => TLExpr::always(self.apply_recursive(e)),
447 TLExpr::Until { before, after } => {
448 TLExpr::until(self.apply_recursive(before), self.apply_recursive(after))
449 }
450 TLExpr::Release { released, releaser } => TLExpr::release(
451 self.apply_recursive(released),
452 self.apply_recursive(releaser),
453 ),
454 TLExpr::WeakUntil { before, after } => {
455 TLExpr::weak_until(self.apply_recursive(before), self.apply_recursive(after))
456 }
457 TLExpr::StrongRelease { released, releaser } => TLExpr::strong_release(
458 self.apply_recursive(released),
459 self.apply_recursive(releaser),
460 ),
461
462 TLExpr::Exists { var, domain, body } => {
464 TLExpr::exists(var.clone(), domain.clone(), self.apply_recursive(body))
465 }
466 TLExpr::ForAll { var, domain, body } => {
467 TLExpr::forall(var.clone(), domain.clone(), self.apply_recursive(body))
468 }
469 TLExpr::SoftExists {
470 var,
471 domain,
472 body,
473 temperature,
474 } => TLExpr::soft_exists(
475 var.clone(),
476 domain.clone(),
477 self.apply_recursive(body),
478 *temperature,
479 ),
480 TLExpr::SoftForAll {
481 var,
482 domain,
483 body,
484 temperature,
485 } => TLExpr::soft_forall(
486 var.clone(),
487 domain.clone(),
488 self.apply_recursive(body),
489 *temperature,
490 ),
491
492 TLExpr::Aggregate {
494 op,
495 var,
496 domain,
497 body,
498 group_by,
499 } => {
500 if let Some(group_vars) = group_by {
501 TLExpr::aggregate_with_group_by(
502 op.clone(),
503 var.clone(),
504 domain.clone(),
505 self.apply_recursive(body),
506 group_vars.clone(),
507 )
508 } else {
509 TLExpr::aggregate(
510 op.clone(),
511 var.clone(),
512 domain.clone(),
513 self.apply_recursive(body),
514 )
515 }
516 }
517
518 TLExpr::IfThenElse {
520 condition,
521 then_branch,
522 else_branch,
523 } => TLExpr::if_then_else(
524 self.apply_recursive(condition),
525 self.apply_recursive(then_branch),
526 self.apply_recursive(else_branch),
527 ),
528 TLExpr::Let { var, value, body } => TLExpr::let_binding(
529 var.clone(),
530 self.apply_recursive(value),
531 self.apply_recursive(body),
532 ),
533
534 TLExpr::TNorm { kind, left, right } => TLExpr::tnorm(
536 *kind,
537 self.apply_recursive(left),
538 self.apply_recursive(right),
539 ),
540 TLExpr::TCoNorm { kind, left, right } => TLExpr::tconorm(
541 *kind,
542 self.apply_recursive(left),
543 self.apply_recursive(right),
544 ),
545 TLExpr::FuzzyNot { kind, expr } => TLExpr::fuzzy_not(*kind, self.apply_recursive(expr)),
546 TLExpr::FuzzyImplication {
547 kind,
548 premise,
549 conclusion,
550 } => TLExpr::fuzzy_imply(
551 *kind,
552 self.apply_recursive(premise),
553 self.apply_recursive(conclusion),
554 ),
555
556 TLExpr::WeightedRule { weight, rule } => {
558 TLExpr::weighted_rule(*weight, self.apply_recursive(rule))
559 }
560 TLExpr::ProbabilisticChoice { alternatives } => TLExpr::probabilistic_choice(
561 alternatives
562 .iter()
563 .map(|(p, e)| (*p, self.apply_recursive(e)))
564 .collect(),
565 ),
566
567 TLExpr::Lambda {
569 var,
570 var_type,
571 body,
572 } => TLExpr::lambda(var.clone(), var_type.clone(), self.apply_recursive(body)),
573 TLExpr::Apply { function, argument } => TLExpr::apply(
574 self.apply_recursive(function),
575 self.apply_recursive(argument),
576 ),
577 TLExpr::SetMembership { element, set } => {
578 TLExpr::set_membership(self.apply_recursive(element), self.apply_recursive(set))
579 }
580 TLExpr::SetUnion { left, right } => {
581 TLExpr::set_union(self.apply_recursive(left), self.apply_recursive(right))
582 }
583 TLExpr::SetIntersection { left, right } => {
584 TLExpr::set_intersection(self.apply_recursive(left), self.apply_recursive(right))
585 }
586 TLExpr::SetDifference { left, right } => {
587 TLExpr::set_difference(self.apply_recursive(left), self.apply_recursive(right))
588 }
589 TLExpr::SetCardinality { set } => TLExpr::set_cardinality(self.apply_recursive(set)),
590 TLExpr::EmptySet => expr.clone(),
591 TLExpr::SetComprehension {
592 var,
593 domain,
594 condition,
595 } => TLExpr::set_comprehension(
596 var.clone(),
597 domain.clone(),
598 self.apply_recursive(condition),
599 ),
600 TLExpr::CountingExists {
601 var,
602 domain,
603 body,
604 min_count,
605 } => TLExpr::counting_exists(
606 var.clone(),
607 domain.clone(),
608 self.apply_recursive(body),
609 *min_count,
610 ),
611 TLExpr::CountingForAll {
612 var,
613 domain,
614 body,
615 min_count,
616 } => TLExpr::counting_forall(
617 var.clone(),
618 domain.clone(),
619 self.apply_recursive(body),
620 *min_count,
621 ),
622 TLExpr::ExactCount {
623 var,
624 domain,
625 body,
626 count,
627 } => TLExpr::exact_count(
628 var.clone(),
629 domain.clone(),
630 self.apply_recursive(body),
631 *count,
632 ),
633 TLExpr::Majority { var, domain, body } => {
634 TLExpr::majority(var.clone(), domain.clone(), self.apply_recursive(body))
635 }
636 TLExpr::LeastFixpoint { var, body } => {
637 TLExpr::least_fixpoint(var.clone(), self.apply_recursive(body))
638 }
639 TLExpr::GreatestFixpoint { var, body } => {
640 TLExpr::greatest_fixpoint(var.clone(), self.apply_recursive(body))
641 }
642 TLExpr::Nominal { .. } => expr.clone(),
643 TLExpr::At { nominal, formula } => {
644 TLExpr::at(nominal.clone(), self.apply_recursive(formula))
645 }
646 TLExpr::Somewhere { formula } => TLExpr::somewhere(self.apply_recursive(formula)),
647 TLExpr::Everywhere { formula } => TLExpr::everywhere(self.apply_recursive(formula)),
648 TLExpr::AllDifferent { .. } => expr.clone(),
649 TLExpr::GlobalCardinality {
650 variables,
651 values,
652 min_occurrences,
653 max_occurrences,
654 } => TLExpr::global_cardinality(
655 variables.clone(),
656 values.iter().map(|v| self.apply_recursive(v)).collect(),
657 min_occurrences.clone(),
658 max_occurrences.clone(),
659 ),
660 TLExpr::Abducible { .. } => expr.clone(),
661 TLExpr::Explain { formula } => TLExpr::explain(self.apply_recursive(formula)),
662 TLExpr::SymbolLiteral(_) => expr.clone(),
663 TLExpr::Match { scrutinee, arms } => TLExpr::Match {
664 scrutinee: Box::new(self.apply_recursive(scrutinee)),
665 arms: arms
666 .iter()
667 .map(|(p, b)| (p.clone(), Box::new(self.apply_recursive(b))))
668 .collect(),
669 },
670
671 TLExpr::Pred { .. } | TLExpr::Constant(_) => expr.clone(),
673 }
674 }
675
676 pub fn apply_until_fixpoint(&self, expr: &TLExpr) -> TLExpr {
678 let mut current = expr.clone();
679 loop {
680 let next = self.apply_recursive(¤t);
681 if next == current {
682 return current;
683 }
684 current = next;
685 }
686 }
687}
688
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Term;

    #[test]
    fn test_pattern_var_match() {
        let pattern = Pattern::var("x");
        let expr = TLExpr::pred("P", vec![Term::var("a")]);

        let bindings = pattern.matches(&expr).expect("unwrap");
        assert_eq!(bindings.get("x"), Some(&expr));
    }

    #[test]
    fn test_pattern_repeated_var_requires_equal_bindings() {
        let pattern = Pattern::and(Pattern::var("A"), Pattern::var("A"));
        let p = TLExpr::pred("P", vec![Term::var("x")]);
        let q = TLExpr::pred("Q", vec![Term::var("x")]);

        let bindings = pattern
            .matches(&TLExpr::and(p.clone(), p.clone()))
            .expect("identical operands should match a repeated variable");
        assert_eq!(bindings.get("A"), Some(&p));

        assert!(pattern.matches(&TLExpr::and(p, q)).is_none());
    }

    #[test]
    fn test_pattern_constant_match() {
        let pattern = Pattern::constant(42.0);

        assert!(pattern.matches(&TLExpr::constant(42.0)).is_some());
        assert!(pattern.matches(&TLExpr::constant(43.0)).is_none());
    }

    #[test]
    fn test_pattern_pred_checks_name_and_arity() {
        let pattern = Pattern::pred("P", vec![Pattern::any()]);

        assert!(pattern
            .matches(&TLExpr::pred("P", vec![Term::var("x")]))
            .is_some());
        assert!(pattern
            .matches(&TLExpr::pred("Q", vec![Term::var("x")]))
            .is_none());
        assert!(pattern
            .matches(&TLExpr::pred("P", vec![Term::var("x"), Term::var("y")]))
            .is_none());
    }

    #[test]
    fn test_pattern_and_match() {
        let pattern = Pattern::and(Pattern::var("A"), Pattern::var("B"));
        let p = TLExpr::pred("P", vec![Term::var("x")]);
        let q = TLExpr::pred("Q", vec![Term::var("y")]);
        let expr = TLExpr::and(p.clone(), q.clone());

        let bindings = pattern.matches(&expr).expect("unwrap");
        assert_eq!(bindings.get("A"), Some(&p));
        assert_eq!(bindings.get("B"), Some(&q));
    }

    #[test]
    fn test_pattern_not_match() {
        let pattern = Pattern::negation(Pattern::var("A"));
        let p = TLExpr::pred("P", vec![Term::var("x")]);
        let expr = TLExpr::negate(p.clone());

        let bindings = pattern.matches(&expr).expect("unwrap");
        assert_eq!(bindings.get("A"), Some(&p));
    }

    #[test]
    fn test_double_negation_rule() {
        let rule = RewriteRule::new(
            Pattern::negation(Pattern::negation(Pattern::var("A"))),
            |bindings| {
                bindings
                    .get("A")
                    .expect("binding 'A' must exist when pattern matched")
                    .clone()
            },
        );

        let expr = TLExpr::negate(TLExpr::negate(TLExpr::pred("P", vec![Term::var("x")])));
        let result = rule.apply(&expr).expect("unwrap");

        assert_eq!(result, TLExpr::pred("P", vec![Term::var("x")]));
    }

    #[test]
    fn test_named_rule_keeps_name() {
        let named = RewriteRule::named("identity", Pattern::var("A"), |bindings| {
            bindings
                .get("A")
                .expect("binding 'A' must exist when pattern matched")
                .clone()
        });
        assert_eq!(named.name.as_deref(), Some("identity"));

        let unnamed = RewriteRule::new(Pattern::any(), |_| TLExpr::constant(0.0));
        assert!(unnamed.name.is_none());
    }

    #[test]
    fn test_rewrite_system_double_negation() {
        let system = RewriteSystem::new().add_rule(RewriteRule::new(
            Pattern::negation(Pattern::negation(Pattern::var("A"))),
            |bindings| {
                bindings
                    .get("A")
                    .expect("binding 'A' must exist when pattern matched")
                    .clone()
            },
        ));

        let expr = TLExpr::negate(TLExpr::negate(TLExpr::pred("P", vec![Term::var("x")])));
        let result = system.apply_recursive(&expr);

        assert_eq!(result, TLExpr::pred("P", vec![Term::var("x")]));
    }

    #[test]
    fn test_apply_once_prefers_earlier_rules() {
        let system = RewriteSystem::new()
            .add_rule(RewriteRule::new(Pattern::negation(Pattern::any()), |_| {
                TLExpr::constant(1.0)
            }))
            .add_rule(RewriteRule::new(Pattern::negation(Pattern::any()), |_| {
                TLExpr::constant(2.0)
            }));

        let negated = TLExpr::negate(TLExpr::pred("P", vec![Term::var("x")]));
        assert_eq!(system.apply_once(&negated), Some(TLExpr::constant(1.0)));

        let plain = TLExpr::pred("P", vec![Term::var("x")]);
        assert_eq!(system.apply_once(&plain), None);
    }

    #[test]
    fn test_empty_system_is_identity() {
        let system = RewriteSystem::new();
        let expr = TLExpr::and(
            TLExpr::negate(TLExpr::pred("P", vec![Term::var("x")])),
            TLExpr::pred("Q", vec![Term::var("y")]),
        );

        assert_eq!(system.apply_recursive(&expr), expr);
        assert_eq!(system.apply_until_fixpoint(&expr), expr);
    }

    #[test]
    fn test_logic_equivalences_system() {
        let system = RewriteSystem::with_logic_equivalences();

        let expr = TLExpr::negate(TLExpr::negate(TLExpr::pred("P", vec![Term::var("x")])));
        let result = system.apply_recursive(&expr);
        assert_eq!(result, TLExpr::pred("P", vec![Term::var("x")]));

        let p = TLExpr::pred("P", vec![Term::var("x")]);
        let q = TLExpr::pred("Q", vec![Term::var("y")]);
        let expr = TLExpr::negate(TLExpr::and(p.clone(), q.clone()));
        let result = system.apply_recursive(&expr);
        assert_eq!(result, TLExpr::or(TLExpr::negate(p), TLExpr::negate(q)));
    }

    #[test]
    fn test_demorgan_or() {
        let system = RewriteSystem::with_logic_equivalences();
        let p = TLExpr::pred("P", vec![Term::var("x")]);
        let q = TLExpr::pred("Q", vec![Term::var("y")]);

        let expr = TLExpr::negate(TLExpr::or(p.clone(), q.clone()));
        let result = system.apply_recursive(&expr);

        assert_eq!(result, TLExpr::and(TLExpr::negate(p), TLExpr::negate(q)));
    }

    #[test]
    fn test_nested_rewriting() {
        let system = RewriteSystem::with_logic_equivalences();
        let p = TLExpr::pred("P", vec![Term::var("x")]);
        let q = TLExpr::pred("Q", vec![Term::var("y")]);

        // not (not not P and Q)  =>  not P or not Q
        let expr = TLExpr::negate(TLExpr::and(
            TLExpr::negate(TLExpr::negate(p.clone())),
            q.clone(),
        ));

        let result = system.apply_until_fixpoint(&expr);
        assert_eq!(result, TLExpr::or(TLExpr::negate(p), TLExpr::negate(q)));
    }

    #[test]
    fn test_implication_expansion() {
        let system = RewriteSystem::with_logic_equivalences();
        let p = TLExpr::pred("P", vec![Term::var("x")]);
        let q = TLExpr::pred("Q", vec![Term::var("y")]);

        let expr = TLExpr::imply(p.clone(), q.clone());

        let result = system.apply_recursive(&expr);
        assert_eq!(result, TLExpr::or(TLExpr::negate(p), q));
    }

    #[test]
    fn test_pattern_add_match() {
        let pattern = Pattern::add(Pattern::var("x"), Pattern::var("y"));
        let expr = TLExpr::add(TLExpr::constant(1.0), TLExpr::constant(2.0));

        let bindings = pattern
            .matches(&expr)
            .expect("Pattern::Add should match TLExpr::Add");
        assert_eq!(bindings.get("x"), Some(&TLExpr::constant(1.0)));
        assert_eq!(bindings.get("y"), Some(&TLExpr::constant(2.0)));
    }

    #[test]
    fn test_pattern_exp_match() {
        let pattern = Pattern::exp(Pattern::var("x"));
        let expr = TLExpr::exp(TLExpr::constant(1.0));

        let bindings = pattern
            .matches(&expr)
            .expect("Pattern::Exp should match TLExpr::Exp");
        assert_eq!(bindings.get("x"), Some(&TLExpr::constant(1.0)));
    }

    #[test]
    fn test_pattern_neg_match() {
        let pattern = Pattern::neg(Pattern::var("x"));
        // Arithmetic negation is represented as `0 - x`.
        let expr = TLExpr::sub(TLExpr::constant(0.0), TLExpr::constant(5.0));

        let bindings = pattern
            .matches(&expr)
            .expect("Pattern::Neg should match TLExpr::Sub(0, x)");
        assert_eq!(bindings.get("x"), Some(&TLExpr::constant(5.0)));
    }

    #[test]
    fn test_pattern_add_does_not_match_mul() {
        let pattern = Pattern::add(Pattern::var("x"), Pattern::var("y"));
        let expr = TLExpr::mul(TLExpr::constant(1.0), TLExpr::constant(2.0));

        assert!(pattern.matches(&expr).is_none());
    }

    #[test]
    fn test_pattern_sin_cos_tan_match() {
        let sin_pat = Pattern::sin(Pattern::var("a"));
        let cos_pat = Pattern::cos(Pattern::var("a"));
        let tan_pat = Pattern::tan(Pattern::var("a"));

        let sin_expr = TLExpr::sin(TLExpr::constant(0.5));
        let cos_expr = TLExpr::cos(TLExpr::constant(0.5));
        let tan_expr = TLExpr::tan(TLExpr::constant(0.5));

        assert!(sin_pat.matches(&sin_expr).is_some());
        assert!(cos_pat.matches(&cos_expr).is_some());
        assert!(tan_pat.matches(&tan_expr).is_some());

        // A function pattern must not match a different function.
        assert!(sin_pat.matches(&cos_expr).is_none());
        assert!(cos_pat.matches(&tan_expr).is_none());
    }

    #[test]
    fn test_pattern_neg_nonzero_constant_no_match() {
        let pattern = Pattern::neg(Pattern::var("x"));
        // `1 - 5` is a subtraction, not a negation.
        let expr = TLExpr::sub(TLExpr::constant(1.0), TLExpr::constant(5.0));

        assert!(pattern.matches(&expr).is_none());
    }
}