1use std::collections::HashMap;
7
8use super::TLExpr;
9
/// A structural pattern over `TLExpr` trees, used as the left-hand side of a
/// [`RewriteRule`].
#[derive(Debug, Clone, PartialEq)]
pub enum Pattern {
    /// Captures the matched sub-expression under this name. Once a name is bound, later
    /// occurrences of the same name only match an equal expression.
    Var(String),
    /// Matches a `TLExpr::Constant` whose value is within `f64::EPSILON` of this one.
    Constant(f64),
    /// Matches a `TLExpr::Pred` with the same name and the same number of arguments.
    Pred { name: String, args: Vec<Pattern> },
    /// Matches a conjunction whose two operands match the two sub-patterns.
    And(Box<Pattern>, Box<Pattern>),
    /// Matches a disjunction whose two operands match the two sub-patterns.
    Or(Box<Pattern>, Box<Pattern>),
    /// Matches a negation whose operand matches the sub-pattern.
    Not(Box<Pattern>),
    /// Matches an implication whose premise and conclusion match the two sub-patterns.
    Imply(Box<Pattern>, Box<Pattern>),
    /// Wildcard: matches any expression without binding anything.
    Any,
}
30
31impl Pattern {
32 pub fn var(name: impl Into<String>) -> Self {
34 Pattern::Var(name.into())
35 }
36
37 pub fn constant(value: f64) -> Self {
39 Pattern::Constant(value)
40 }
41
42 pub fn any() -> Self {
44 Pattern::Any
45 }
46
47 pub fn pred(name: impl Into<String>, args: Vec<Pattern>) -> Self {
49 Pattern::Pred {
50 name: name.into(),
51 args,
52 }
53 }
54
55 pub fn and(left: Pattern, right: Pattern) -> Self {
57 Pattern::And(Box::new(left), Box::new(right))
58 }
59
60 pub fn or(left: Pattern, right: Pattern) -> Self {
62 Pattern::Or(Box::new(left), Box::new(right))
63 }
64
65 pub fn negation(pattern: Pattern) -> Self {
67 Pattern::Not(Box::new(pattern))
68 }
69
70 pub fn imply(left: Pattern, right: Pattern) -> Self {
72 Pattern::Imply(Box::new(left), Box::new(right))
73 }
74
75 pub fn matches(&self, expr: &TLExpr) -> Option<HashMap<String, TLExpr>> {
77 let mut bindings = HashMap::new();
78 if self.matches_recursive(expr, &mut bindings) {
79 Some(bindings)
80 } else {
81 None
82 }
83 }
84
85 fn matches_recursive(&self, expr: &TLExpr, bindings: &mut HashMap<String, TLExpr>) -> bool {
86 match (self, expr) {
87 (Pattern::Any, _) => true,
89
90 (Pattern::Var(var_name), _) => {
92 if let Some(bound_expr) = bindings.get(var_name) {
93 bound_expr == expr
94 } else {
95 bindings.insert(var_name.clone(), expr.clone());
96 true
97 }
98 }
99
100 (Pattern::Constant(pv), TLExpr::Constant(ev)) => (pv - ev).abs() < f64::EPSILON,
102
103 (
105 Pattern::Pred {
106 name: pname,
107 args: pargs,
108 },
109 TLExpr::Pred {
110 name: ename,
111 args: eargs,
112 },
113 ) => {
114 if pname != ename || pargs.len() != eargs.len() {
115 return false;
116 }
117 pargs.len() == eargs.len()
120 }
121
122 (Pattern::And(pl, pr), TLExpr::And(el, er))
124 | (Pattern::Or(pl, pr), TLExpr::Or(el, er))
125 | (Pattern::Imply(pl, pr), TLExpr::Imply(el, er)) => {
126 pl.matches_recursive(el, bindings) && pr.matches_recursive(er, bindings)
127 }
128
129 (Pattern::Not(p), TLExpr::Not(e)) => p.matches_recursive(e, bindings),
131
132 _ => false,
133 }
134 }
135}
136
137#[derive(Clone, Debug)]
139pub struct RewriteRule {
140 pub pattern: Pattern,
142 pub template: fn(&HashMap<String, TLExpr>) -> TLExpr,
144 pub name: Option<String>,
146}
147
148impl RewriteRule {
149 pub fn new(pattern: Pattern, template: fn(&HashMap<String, TLExpr>) -> TLExpr) -> Self {
151 Self {
152 pattern,
153 template,
154 name: None,
155 }
156 }
157
158 pub fn named(
160 name: impl Into<String>,
161 pattern: Pattern,
162 template: fn(&HashMap<String, TLExpr>) -> TLExpr,
163 ) -> Self {
164 Self {
165 pattern,
166 template,
167 name: Some(name.into()),
168 }
169 }
170
171 pub fn apply(&self, expr: &TLExpr) -> Option<TLExpr> {
173 self.pattern
174 .matches(expr)
175 .map(|bindings| (self.template)(&bindings))
176 }
177}
178
179#[derive(Clone, Debug, Default)]
181pub struct RewriteSystem {
182 rules: Vec<RewriteRule>,
183}
184
185impl RewriteSystem {
186 pub fn new() -> Self {
188 Self::default()
189 }
190
191 pub fn add_rule(mut self, rule: RewriteRule) -> Self {
193 self.rules.push(rule);
194 self
195 }
196
197 pub fn with_logic_equivalences() -> Self {
199 let mut system = Self::new();
200
201 system = system.add_rule(RewriteRule::named(
203 "double_negation",
204 Pattern::negation(Pattern::negation(Pattern::var("A"))),
205 |bindings| bindings.get("A").unwrap().clone(),
206 ));
207
208 system = system.add_rule(RewriteRule::named(
210 "demorgan_and",
211 Pattern::negation(Pattern::and(Pattern::var("A"), Pattern::var("B"))),
212 |bindings| {
213 TLExpr::or(
214 TLExpr::negate(bindings.get("A").unwrap().clone()),
215 TLExpr::negate(bindings.get("B").unwrap().clone()),
216 )
217 },
218 ));
219
220 system = system.add_rule(RewriteRule::named(
222 "demorgan_or",
223 Pattern::negation(Pattern::or(Pattern::var("A"), Pattern::var("B"))),
224 |bindings| {
225 TLExpr::and(
226 TLExpr::negate(bindings.get("A").unwrap().clone()),
227 TLExpr::negate(bindings.get("B").unwrap().clone()),
228 )
229 },
230 ));
231
232 system = system.add_rule(RewriteRule::named(
234 "implication_expansion",
235 Pattern::imply(Pattern::var("A"), Pattern::var("B")),
236 |bindings| {
237 TLExpr::or(
238 TLExpr::negate(bindings.get("A").unwrap().clone()),
239 bindings.get("B").unwrap().clone(),
240 )
241 },
242 ));
243
244 system
245 }
246
247 pub fn apply_once(&self, expr: &TLExpr) -> Option<TLExpr> {
249 for rule in &self.rules {
250 if let Some(result) = rule.apply(expr) {
251 return Some(result);
252 }
253 }
254 None
255 }
256
257 pub fn apply_recursive(&self, expr: &TLExpr) -> TLExpr {
259 if let Some(rewritten) = self.apply_once(expr) {
261 return self.apply_recursive(&rewritten);
262 }
263
264 match expr {
266 TLExpr::And(l, r) => TLExpr::and(self.apply_recursive(l), self.apply_recursive(r)),
267 TLExpr::Or(l, r) => TLExpr::or(self.apply_recursive(l), self.apply_recursive(r)),
268 TLExpr::Not(e) => TLExpr::negate(self.apply_recursive(e)),
269 TLExpr::Imply(l, r) => TLExpr::imply(self.apply_recursive(l), self.apply_recursive(r)),
270 TLExpr::Score(e) => TLExpr::score(self.apply_recursive(e)),
271
272 TLExpr::Add(l, r) => TLExpr::add(self.apply_recursive(l), self.apply_recursive(r)),
274 TLExpr::Sub(l, r) => TLExpr::sub(self.apply_recursive(l), self.apply_recursive(r)),
275 TLExpr::Mul(l, r) => TLExpr::mul(self.apply_recursive(l), self.apply_recursive(r)),
276 TLExpr::Div(l, r) => TLExpr::div(self.apply_recursive(l), self.apply_recursive(r)),
277 TLExpr::Pow(l, r) => TLExpr::pow(self.apply_recursive(l), self.apply_recursive(r)),
278 TLExpr::Mod(l, r) => TLExpr::modulo(self.apply_recursive(l), self.apply_recursive(r)),
279 TLExpr::Min(l, r) => TLExpr::min(self.apply_recursive(l), self.apply_recursive(r)),
280 TLExpr::Max(l, r) => TLExpr::max(self.apply_recursive(l), self.apply_recursive(r)),
281
282 TLExpr::Eq(l, r) => TLExpr::eq(self.apply_recursive(l), self.apply_recursive(r)),
284 TLExpr::Lt(l, r) => TLExpr::lt(self.apply_recursive(l), self.apply_recursive(r)),
285 TLExpr::Gt(l, r) => TLExpr::gt(self.apply_recursive(l), self.apply_recursive(r)),
286 TLExpr::Lte(l, r) => TLExpr::lte(self.apply_recursive(l), self.apply_recursive(r)),
287 TLExpr::Gte(l, r) => TLExpr::gte(self.apply_recursive(l), self.apply_recursive(r)),
288
289 TLExpr::Abs(e) => TLExpr::abs(self.apply_recursive(e)),
291 TLExpr::Floor(e) => TLExpr::floor(self.apply_recursive(e)),
292 TLExpr::Ceil(e) => TLExpr::ceil(self.apply_recursive(e)),
293 TLExpr::Round(e) => TLExpr::round(self.apply_recursive(e)),
294 TLExpr::Sqrt(e) => TLExpr::sqrt(self.apply_recursive(e)),
295 TLExpr::Exp(e) => TLExpr::exp(self.apply_recursive(e)),
296 TLExpr::Log(e) => TLExpr::log(self.apply_recursive(e)),
297 TLExpr::Sin(e) => TLExpr::sin(self.apply_recursive(e)),
298 TLExpr::Cos(e) => TLExpr::cos(self.apply_recursive(e)),
299 TLExpr::Tan(e) => TLExpr::tan(self.apply_recursive(e)),
300
301 TLExpr::Box(e) => TLExpr::modal_box(self.apply_recursive(e)),
303 TLExpr::Diamond(e) => TLExpr::modal_diamond(self.apply_recursive(e)),
304 TLExpr::Next(e) => TLExpr::next(self.apply_recursive(e)),
305 TLExpr::Eventually(e) => TLExpr::eventually(self.apply_recursive(e)),
306 TLExpr::Always(e) => TLExpr::always(self.apply_recursive(e)),
307 TLExpr::Until { before, after } => {
308 TLExpr::until(self.apply_recursive(before), self.apply_recursive(after))
309 }
310 TLExpr::Release { released, releaser } => TLExpr::release(
311 self.apply_recursive(released),
312 self.apply_recursive(releaser),
313 ),
314 TLExpr::WeakUntil { before, after } => {
315 TLExpr::weak_until(self.apply_recursive(before), self.apply_recursive(after))
316 }
317 TLExpr::StrongRelease { released, releaser } => TLExpr::strong_release(
318 self.apply_recursive(released),
319 self.apply_recursive(releaser),
320 ),
321
322 TLExpr::Exists { var, domain, body } => {
324 TLExpr::exists(var.clone(), domain.clone(), self.apply_recursive(body))
325 }
326 TLExpr::ForAll { var, domain, body } => {
327 TLExpr::forall(var.clone(), domain.clone(), self.apply_recursive(body))
328 }
329 TLExpr::SoftExists {
330 var,
331 domain,
332 body,
333 temperature,
334 } => TLExpr::soft_exists(
335 var.clone(),
336 domain.clone(),
337 self.apply_recursive(body),
338 *temperature,
339 ),
340 TLExpr::SoftForAll {
341 var,
342 domain,
343 body,
344 temperature,
345 } => TLExpr::soft_forall(
346 var.clone(),
347 domain.clone(),
348 self.apply_recursive(body),
349 *temperature,
350 ),
351
352 TLExpr::Aggregate {
354 op,
355 var,
356 domain,
357 body,
358 group_by,
359 } => {
360 if let Some(group_vars) = group_by {
361 TLExpr::aggregate_with_group_by(
362 op.clone(),
363 var.clone(),
364 domain.clone(),
365 self.apply_recursive(body),
366 group_vars.clone(),
367 )
368 } else {
369 TLExpr::aggregate(
370 op.clone(),
371 var.clone(),
372 domain.clone(),
373 self.apply_recursive(body),
374 )
375 }
376 }
377
378 TLExpr::IfThenElse {
380 condition,
381 then_branch,
382 else_branch,
383 } => TLExpr::if_then_else(
384 self.apply_recursive(condition),
385 self.apply_recursive(then_branch),
386 self.apply_recursive(else_branch),
387 ),
388 TLExpr::Let { var, value, body } => TLExpr::let_binding(
389 var.clone(),
390 self.apply_recursive(value),
391 self.apply_recursive(body),
392 ),
393
394 TLExpr::TNorm { kind, left, right } => TLExpr::tnorm(
396 *kind,
397 self.apply_recursive(left),
398 self.apply_recursive(right),
399 ),
400 TLExpr::TCoNorm { kind, left, right } => TLExpr::tconorm(
401 *kind,
402 self.apply_recursive(left),
403 self.apply_recursive(right),
404 ),
405 TLExpr::FuzzyNot { kind, expr } => TLExpr::fuzzy_not(*kind, self.apply_recursive(expr)),
406 TLExpr::FuzzyImplication {
407 kind,
408 premise,
409 conclusion,
410 } => TLExpr::fuzzy_imply(
411 *kind,
412 self.apply_recursive(premise),
413 self.apply_recursive(conclusion),
414 ),
415
416 TLExpr::WeightedRule { weight, rule } => {
418 TLExpr::weighted_rule(*weight, self.apply_recursive(rule))
419 }
420 TLExpr::ProbabilisticChoice { alternatives } => TLExpr::probabilistic_choice(
421 alternatives
422 .iter()
423 .map(|(p, e)| (*p, self.apply_recursive(e)))
424 .collect(),
425 ),
426
427 TLExpr::Lambda {
429 var,
430 var_type,
431 body,
432 } => TLExpr::lambda(var.clone(), var_type.clone(), self.apply_recursive(body)),
433 TLExpr::Apply { function, argument } => TLExpr::apply(
434 self.apply_recursive(function),
435 self.apply_recursive(argument),
436 ),
437 TLExpr::SetMembership { element, set } => {
438 TLExpr::set_membership(self.apply_recursive(element), self.apply_recursive(set))
439 }
440 TLExpr::SetUnion { left, right } => {
441 TLExpr::set_union(self.apply_recursive(left), self.apply_recursive(right))
442 }
443 TLExpr::SetIntersection { left, right } => {
444 TLExpr::set_intersection(self.apply_recursive(left), self.apply_recursive(right))
445 }
446 TLExpr::SetDifference { left, right } => {
447 TLExpr::set_difference(self.apply_recursive(left), self.apply_recursive(right))
448 }
449 TLExpr::SetCardinality { set } => TLExpr::set_cardinality(self.apply_recursive(set)),
450 TLExpr::EmptySet => expr.clone(),
451 TLExpr::SetComprehension {
452 var,
453 domain,
454 condition,
455 } => TLExpr::set_comprehension(
456 var.clone(),
457 domain.clone(),
458 self.apply_recursive(condition),
459 ),
460 TLExpr::CountingExists {
461 var,
462 domain,
463 body,
464 min_count,
465 } => TLExpr::counting_exists(
466 var.clone(),
467 domain.clone(),
468 self.apply_recursive(body),
469 *min_count,
470 ),
471 TLExpr::CountingForAll {
472 var,
473 domain,
474 body,
475 min_count,
476 } => TLExpr::counting_forall(
477 var.clone(),
478 domain.clone(),
479 self.apply_recursive(body),
480 *min_count,
481 ),
482 TLExpr::ExactCount {
483 var,
484 domain,
485 body,
486 count,
487 } => TLExpr::exact_count(
488 var.clone(),
489 domain.clone(),
490 self.apply_recursive(body),
491 *count,
492 ),
493 TLExpr::Majority { var, domain, body } => {
494 TLExpr::majority(var.clone(), domain.clone(), self.apply_recursive(body))
495 }
496 TLExpr::LeastFixpoint { var, body } => {
497 TLExpr::least_fixpoint(var.clone(), self.apply_recursive(body))
498 }
499 TLExpr::GreatestFixpoint { var, body } => {
500 TLExpr::greatest_fixpoint(var.clone(), self.apply_recursive(body))
501 }
502 TLExpr::Nominal { .. } => expr.clone(),
503 TLExpr::At { nominal, formula } => {
504 TLExpr::at(nominal.clone(), self.apply_recursive(formula))
505 }
506 TLExpr::Somewhere { formula } => TLExpr::somewhere(self.apply_recursive(formula)),
507 TLExpr::Everywhere { formula } => TLExpr::everywhere(self.apply_recursive(formula)),
508 TLExpr::AllDifferent { .. } => expr.clone(),
509 TLExpr::GlobalCardinality {
510 variables,
511 values,
512 min_occurrences,
513 max_occurrences,
514 } => TLExpr::global_cardinality(
515 variables.clone(),
516 values.iter().map(|v| self.apply_recursive(v)).collect(),
517 min_occurrences.clone(),
518 max_occurrences.clone(),
519 ),
520 TLExpr::Abducible { .. } => expr.clone(),
521 TLExpr::Explain { formula } => TLExpr::explain(self.apply_recursive(formula)),
522
523 TLExpr::Pred { .. } | TLExpr::Constant(_) => expr.clone(),
525 }
526 }
527
528 pub fn apply_until_fixpoint(&self, expr: &TLExpr) -> TLExpr {
530 let mut current = expr.clone();
531 loop {
532 let next = self.apply_recursive(¤t);
533 if next == current {
534 return current;
535 }
536 current = next;
537 }
538 }
539}
540
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Term;

    /// `P(x)` — a small atomic expression shared by the tests.
    fn p_x() -> TLExpr {
        TLExpr::pred("P", vec![Term::var("x")])
    }

    /// `Q(y)` — a second, distinct atomic expression.
    fn q_y() -> TLExpr {
        TLExpr::pred("Q", vec![Term::var("y")])
    }

    #[test]
    fn test_pattern_var_match() {
        let pattern = Pattern::var("x");
        let expr = p_x();

        let bindings = pattern.matches(&expr).unwrap();
        assert_eq!(bindings.get("x"), Some(&expr));
    }

    #[test]
    fn test_pattern_constant_match() {
        let pattern = Pattern::constant(42.0);

        assert!(pattern.matches(&TLExpr::constant(42.0)).is_some());
        assert!(pattern.matches(&TLExpr::constant(43.0)).is_none());
    }

    #[test]
    fn test_pattern_and_match() {
        let pattern = Pattern::and(Pattern::var("A"), Pattern::var("B"));
        let expr = TLExpr::and(p_x(), q_y());

        let bindings = pattern.matches(&expr).unwrap();
        assert_eq!(bindings.get("A"), Some(&p_x()));
        assert_eq!(bindings.get("B"), Some(&q_y()));
    }

    #[test]
    fn test_pattern_not_match() {
        let pattern = Pattern::negation(Pattern::var("A"));
        let expr = TLExpr::negate(p_x());

        let bindings = pattern.matches(&expr).unwrap();
        assert_eq!(bindings.get("A"), Some(&p_x()));

        // A negation pattern must not match a non-negated expression.
        assert!(pattern.matches(&p_x()).is_none());
    }

    #[test]
    fn test_repeated_var_must_bind_consistently() {
        let pattern = Pattern::and(Pattern::var("A"), Pattern::var("A"));

        let bindings = pattern.matches(&TLExpr::and(p_x(), p_x())).unwrap();
        assert_eq!(bindings.get("A"), Some(&p_x()));
        assert!(pattern.matches(&TLExpr::and(p_x(), q_y())).is_none());
    }

    #[test]
    fn test_pred_pattern_compares_name_and_arity() {
        let pattern = Pattern::pred("P", vec![Pattern::any()]);

        assert!(pattern.matches(&p_x()).is_some());
        assert!(pattern.matches(&q_y()).is_none());
        assert!(Pattern::pred("P", vec![]).matches(&p_x()).is_none());
    }

    #[test]
    fn test_double_negation_rule() {
        let rule = RewriteRule::new(
            Pattern::negation(Pattern::negation(Pattern::var("A"))),
            |bindings| bindings.get("A").unwrap().clone(),
        );

        let expr = TLExpr::negate(TLExpr::negate(p_x()));
        assert_eq!(rule.apply(&expr), Some(p_x()));
        assert_eq!(rule.apply(&p_x()), None);
    }

    #[test]
    fn test_rewrite_system_double_negation() {
        let system = RewriteSystem::new().add_rule(RewriteRule::new(
            Pattern::negation(Pattern::negation(Pattern::var("A"))),
            |bindings| bindings.get("A").unwrap().clone(),
        ));

        let expr = TLExpr::negate(TLExpr::negate(p_x()));
        assert_eq!(system.apply_recursive(&expr), p_x());
    }

    #[test]
    fn test_logic_equivalences_system() {
        let system = RewriteSystem::with_logic_equivalences();

        // Double negation.
        let expr = TLExpr::negate(TLExpr::negate(p_x()));
        assert_eq!(system.apply_recursive(&expr), p_x());

        // De Morgan over conjunction.
        let expr = TLExpr::negate(TLExpr::and(p_x(), q_y()));
        assert_eq!(
            system.apply_recursive(&expr),
            TLExpr::or(TLExpr::negate(p_x()), TLExpr::negate(q_y()))
        );
    }

    #[test]
    fn test_demorgan_or() {
        let system = RewriteSystem::with_logic_equivalences();

        let expr = TLExpr::negate(TLExpr::or(p_x(), q_y()));
        assert_eq!(
            system.apply_recursive(&expr),
            TLExpr::and(TLExpr::negate(p_x()), TLExpr::negate(q_y()))
        );
    }

    #[test]
    fn test_nested_rewriting() {
        let system = RewriteSystem::with_logic_equivalences();

        // not(not(not(P)) and Q)  =>  not(not(not(P))) or not(Q)  =>  not(P) or not(Q)
        let expr = TLExpr::negate(TLExpr::and(TLExpr::negate(TLExpr::negate(p_x())), q_y()));

        let result = system.apply_until_fixpoint(&expr);
        assert_eq!(
            result,
            TLExpr::or(TLExpr::negate(p_x()), TLExpr::negate(q_y()))
        );
    }

    #[test]
    fn test_implication_expansion() {
        let system = RewriteSystem::with_logic_equivalences();

        let expr = TLExpr::imply(p_x(), q_y());
        assert_eq!(
            system.apply_recursive(&expr),
            TLExpr::or(TLExpr::negate(p_x()), q_y())
        );
    }

    #[test]
    fn test_apply_once_without_match() {
        let system = RewriteSystem::with_logic_equivalences();

        assert!(system.apply_once(&p_x()).is_none());
        assert_eq!(system.apply_recursive(&p_x()), p_x());
        assert_eq!(system.apply_until_fixpoint(&p_x()), p_x());
    }
}