Skip to main content

tensorlogic_ir/expr/
rewriting.rs

1//! Expression rewriting system with pattern matching.
2//!
3//! This module provides a powerful rewriting system that allows users to define
4//! custom transformation rules for TensorLogic expressions.
5
6use std::collections::HashMap;
7
8use super::TLExpr;
9
/// A structural pattern used to recognise shapes of expressions.
///
/// Patterns mirror a subset of the expression constructors: logical
/// connectives, arithmetic operators and a few unary math functions. Two
/// special forms, [`Pattern::Var`] and [`Pattern::Any`], stand for arbitrary
/// sub-expressions.
#[derive(Clone, Debug, PartialEq)]
pub enum Pattern {
    /// Matches any expression and binds it to the given variable name.
    Var(String),
    /// Matches a constant with the given numeric value.
    Constant(f64),
    /// Matches a predicate with the given name.
    Pred { name: String, args: Vec<Pattern> },
    /// Matches a conjunction (AND) of two sub-patterns.
    And(Box<Pattern>, Box<Pattern>),
    /// Matches a disjunction (OR) of two sub-patterns.
    Or(Box<Pattern>, Box<Pattern>),
    /// Matches a logical NOT of a sub-pattern.
    Not(Box<Pattern>),
    /// Matches an implication between two sub-patterns.
    Imply(Box<Pattern>, Box<Pattern>),
    /// Wildcard: matches any expression without binding it.
    Any,
    /// Matches an addition.
    Add(Box<Pattern>, Box<Pattern>),
    /// Matches a subtraction.
    Sub(Box<Pattern>, Box<Pattern>),
    /// Matches a multiplication.
    Mul(Box<Pattern>, Box<Pattern>),
    /// Matches a division.
    Div(Box<Pattern>, Box<Pattern>),
    /// Matches a power (base, exponent).
    Pow(Box<Pattern>, Box<Pattern>),
    /// Matches an arithmetic negation (encoded as `Sub(0, x)` in `TLExpr`).
    Neg(Box<Pattern>),
    /// Matches an exponential.
    Exp(Box<Pattern>),
    /// Matches a logarithm.
    Log(Box<Pattern>),
    /// Matches a sine.
    Sin(Box<Pattern>),
    /// Matches a cosine.
    Cos(Box<Pattern>),
    /// Matches a tangent.
    Tan(Box<Pattern>),
}
52
53impl Pattern {
54    /// Create a variable pattern.
55    pub fn var(name: impl Into<String>) -> Self {
56        Pattern::Var(name.into())
57    }
58
59    /// Create a constant pattern.
60    pub fn constant(value: f64) -> Self {
61        Pattern::Constant(value)
62    }
63
64    /// Create a wildcard pattern.
65    pub fn any() -> Self {
66        Pattern::Any
67    }
68
69    /// Create a predicate pattern.
70    pub fn pred(name: impl Into<String>, args: Vec<Pattern>) -> Self {
71        Pattern::Pred {
72            name: name.into(),
73            args,
74        }
75    }
76
77    /// Create an AND pattern.
78    pub fn and(left: Pattern, right: Pattern) -> Self {
79        Pattern::And(Box::new(left), Box::new(right))
80    }
81
82    /// Create an OR pattern.
83    pub fn or(left: Pattern, right: Pattern) -> Self {
84        Pattern::Or(Box::new(left), Box::new(right))
85    }
86
87    /// Create a NOT pattern.
88    pub fn negation(pattern: Pattern) -> Self {
89        Pattern::Not(Box::new(pattern))
90    }
91
92    /// Create an implication pattern.
93    pub fn imply(left: Pattern, right: Pattern) -> Self {
94        Pattern::Imply(Box::new(left), Box::new(right))
95    }
96
97    /// Create an addition pattern.
98    #[allow(clippy::should_implement_trait)]
99    pub fn add(left: Pattern, right: Pattern) -> Self {
100        Pattern::Add(Box::new(left), Box::new(right))
101    }
102
103    /// Create a subtraction pattern.
104    #[allow(clippy::should_implement_trait)]
105    pub fn sub(left: Pattern, right: Pattern) -> Self {
106        Pattern::Sub(Box::new(left), Box::new(right))
107    }
108
109    /// Create a multiplication pattern.
110    #[allow(clippy::should_implement_trait)]
111    pub fn mul(left: Pattern, right: Pattern) -> Self {
112        Pattern::Mul(Box::new(left), Box::new(right))
113    }
114
115    /// Create a division pattern.
116    #[allow(clippy::should_implement_trait)]
117    pub fn div(left: Pattern, right: Pattern) -> Self {
118        Pattern::Div(Box::new(left), Box::new(right))
119    }
120
121    /// Create a power pattern.
122    pub fn pow(left: Pattern, right: Pattern) -> Self {
123        Pattern::Pow(Box::new(left), Box::new(right))
124    }
125
126    /// Create a negation pattern (matches Sub(0, x)).
127    #[allow(clippy::should_implement_trait)]
128    pub fn neg(inner: Pattern) -> Self {
129        Pattern::Neg(Box::new(inner))
130    }
131
132    /// Create an exponential pattern.
133    pub fn exp(inner: Pattern) -> Self {
134        Pattern::Exp(Box::new(inner))
135    }
136
137    /// Create a logarithm pattern.
138    pub fn log(inner: Pattern) -> Self {
139        Pattern::Log(Box::new(inner))
140    }
141
142    /// Create a sine pattern.
143    pub fn sin(inner: Pattern) -> Self {
144        Pattern::Sin(Box::new(inner))
145    }
146
147    /// Create a cosine pattern.
148    pub fn cos(inner: Pattern) -> Self {
149        Pattern::Cos(Box::new(inner))
150    }
151
152    /// Create a tangent pattern.
153    pub fn tan(inner: Pattern) -> Self {
154        Pattern::Tan(Box::new(inner))
155    }
156
157    /// Try to match this pattern against an expression, returning bindings if successful.
158    pub fn matches(&self, expr: &TLExpr) -> Option<HashMap<String, TLExpr>> {
159        let mut bindings = HashMap::new();
160        if self.matches_recursive(expr, &mut bindings) {
161            Some(bindings)
162        } else {
163            None
164        }
165    }
166
167    fn matches_recursive(&self, expr: &TLExpr, bindings: &mut HashMap<String, TLExpr>) -> bool {
168        match (self, expr) {
169            // Wildcard matches anything
170            (Pattern::Any, _) => true,
171
172            // Variable pattern: bind if not already bound, or check if bound value matches
173            (Pattern::Var(var_name), _) => {
174                if let Some(bound_expr) = bindings.get(var_name) {
175                    bound_expr == expr
176                } else {
177                    bindings.insert(var_name.clone(), expr.clone());
178                    true
179                }
180            }
181
182            // Constant pattern
183            (Pattern::Constant(pv), TLExpr::Constant(ev)) => (pv - ev).abs() < f64::EPSILON,
184
185            // Predicate pattern
186            (
187                Pattern::Pred {
188                    name: pname,
189                    args: pargs,
190                },
191                TLExpr::Pred {
192                    name: ename,
193                    args: eargs,
194                },
195            ) => {
196                if pname != ename || pargs.len() != eargs.len() {
197                    return false;
198                }
199                // Note: We're matching predicate arguments structurally here
200                // For a more sophisticated system, we'd need term patterns
201                pargs.len() == eargs.len()
202            }
203
204            // Binary logical operators
205            (Pattern::And(pl, pr), TLExpr::And(el, er))
206            | (Pattern::Or(pl, pr), TLExpr::Or(el, er))
207            | (Pattern::Imply(pl, pr), TLExpr::Imply(el, er)) => {
208                pl.matches_recursive(el, bindings) && pr.matches_recursive(er, bindings)
209            }
210
211            // Binary arithmetic operators
212            (Pattern::Add(pl, pr), TLExpr::Add(el, er))
213            | (Pattern::Sub(pl, pr), TLExpr::Sub(el, er))
214            | (Pattern::Mul(pl, pr), TLExpr::Mul(el, er))
215            | (Pattern::Div(pl, pr), TLExpr::Div(el, er))
216            | (Pattern::Pow(pl, pr), TLExpr::Pow(el, er)) => {
217                pl.matches_recursive(el, bindings) && pr.matches_recursive(er, bindings)
218            }
219
220            // Unary logical operators
221            (Pattern::Not(p), TLExpr::Not(e)) => p.matches_recursive(e, bindings),
222
223            // Negation: Pattern::Neg(a) matches TLExpr::Sub(Constant(~0.0), ea)
224            (Pattern::Neg(p), TLExpr::Sub(zero_expr, e)) => {
225                if let TLExpr::Constant(v) = zero_expr.as_ref() {
226                    v.abs() < 1e-15 && p.matches_recursive(e, bindings)
227                } else {
228                    false
229                }
230            }
231
232            // Unary transcendental/math operators
233            (Pattern::Exp(p), TLExpr::Exp(e))
234            | (Pattern::Log(p), TLExpr::Log(e))
235            | (Pattern::Sin(p), TLExpr::Sin(e))
236            | (Pattern::Cos(p), TLExpr::Cos(e))
237            | (Pattern::Tan(p), TLExpr::Tan(e)) => p.matches_recursive(e, bindings),
238
239            _ => false,
240        }
241    }
242}
243
244/// A rewrite rule that transforms expressions matching a pattern into a template.
245#[derive(Clone, Debug)]
246pub struct RewriteRule {
247    /// The pattern to match
248    pub pattern: Pattern,
249    /// Function to generate the replacement expression from bindings
250    pub template: fn(&HashMap<String, TLExpr>) -> TLExpr,
251    /// Optional name for debugging
252    pub name: Option<String>,
253}
254
255impl RewriteRule {
256    /// Create a new rewrite rule.
257    pub fn new(pattern: Pattern, template: fn(&HashMap<String, TLExpr>) -> TLExpr) -> Self {
258        Self {
259            pattern,
260            template,
261            name: None,
262        }
263    }
264
265    /// Create a named rewrite rule.
266    pub fn named(
267        name: impl Into<String>,
268        pattern: Pattern,
269        template: fn(&HashMap<String, TLExpr>) -> TLExpr,
270    ) -> Self {
271        Self {
272            pattern,
273            template,
274            name: Some(name.into()),
275        }
276    }
277
278    /// Try to apply this rule to an expression.
279    pub fn apply(&self, expr: &TLExpr) -> Option<TLExpr> {
280        self.pattern
281            .matches(expr)
282            .map(|bindings| (self.template)(&bindings))
283    }
284}
285
286/// A collection of rewrite rules that can be applied to expressions.
287#[derive(Clone, Debug, Default)]
288pub struct RewriteSystem {
289    rules: Vec<RewriteRule>,
290}
291
292impl RewriteSystem {
293    /// Create a new empty rewrite system.
294    pub fn new() -> Self {
295        Self::default()
296    }
297
298    /// Add a rule to the system.
299    pub fn add_rule(mut self, rule: RewriteRule) -> Self {
300        self.rules.push(rule);
301        self
302    }
303
304    /// Create a system with common logical equivalences.
305    pub fn with_logic_equivalences() -> Self {
306        let mut system = Self::new();
307
308        // Double negation elimination: ¬¬A → A
309        system = system.add_rule(RewriteRule::named(
310            "double_negation",
311            Pattern::negation(Pattern::negation(Pattern::var("A"))),
312            |bindings| {
313                bindings
314                    .get("A")
315                    .expect("binding 'A' must exist when pattern matched")
316                    .clone()
317            },
318        ));
319
320        // De Morgan's laws: ¬(A ∧ B) → ¬A ∨ ¬B
321        system = system.add_rule(RewriteRule::named(
322            "demorgan_and",
323            Pattern::negation(Pattern::and(Pattern::var("A"), Pattern::var("B"))),
324            |bindings| {
325                TLExpr::or(
326                    TLExpr::negate(
327                        bindings
328                            .get("A")
329                            .expect("binding 'A' must exist when pattern matched")
330                            .clone(),
331                    ),
332                    TLExpr::negate(
333                        bindings
334                            .get("B")
335                            .expect("binding 'B' must exist when pattern matched")
336                            .clone(),
337                    ),
338                )
339            },
340        ));
341
342        // De Morgan's laws: ¬(A ∨ B) → ¬A ∧ ¬B
343        system = system.add_rule(RewriteRule::named(
344            "demorgan_or",
345            Pattern::negation(Pattern::or(Pattern::var("A"), Pattern::var("B"))),
346            |bindings| {
347                TLExpr::and(
348                    TLExpr::negate(
349                        bindings
350                            .get("A")
351                            .expect("binding 'A' must exist when pattern matched")
352                            .clone(),
353                    ),
354                    TLExpr::negate(
355                        bindings
356                            .get("B")
357                            .expect("binding 'B' must exist when pattern matched")
358                            .clone(),
359                    ),
360                )
361            },
362        ));
363
364        // Implication expansion: A → B ≡ ¬A ∨ B
365        system = system.add_rule(RewriteRule::named(
366            "implication_expansion",
367            Pattern::imply(Pattern::var("A"), Pattern::var("B")),
368            |bindings| {
369                TLExpr::or(
370                    TLExpr::negate(
371                        bindings
372                            .get("A")
373                            .expect("binding 'A' must exist when pattern matched")
374                            .clone(),
375                    ),
376                    bindings
377                        .get("B")
378                        .expect("binding 'B' must exist when pattern matched")
379                        .clone(),
380                )
381            },
382        ));
383
384        system
385    }
386
387    /// Try to apply the first matching rule to an expression.
388    pub fn apply_once(&self, expr: &TLExpr) -> Option<TLExpr> {
389        for rule in &self.rules {
390            if let Some(result) = rule.apply(expr) {
391                return Some(result);
392            }
393        }
394        None
395    }
396
397    /// Apply rules recursively to an expression and all its subexpressions.
398    pub fn apply_recursive(&self, expr: &TLExpr) -> TLExpr {
399        // First, try to apply a rule at the top level
400        if let Some(rewritten) = self.apply_once(expr) {
401            return self.apply_recursive(&rewritten);
402        }
403
404        // If no rule applies, recurse into subexpressions
405        match expr {
406            TLExpr::And(l, r) => TLExpr::and(self.apply_recursive(l), self.apply_recursive(r)),
407            TLExpr::Or(l, r) => TLExpr::or(self.apply_recursive(l), self.apply_recursive(r)),
408            TLExpr::Not(e) => TLExpr::negate(self.apply_recursive(e)),
409            TLExpr::Imply(l, r) => TLExpr::imply(self.apply_recursive(l), self.apply_recursive(r)),
410            TLExpr::Score(e) => TLExpr::score(self.apply_recursive(e)),
411
412            // Arithmetic
413            TLExpr::Add(l, r) => TLExpr::add(self.apply_recursive(l), self.apply_recursive(r)),
414            TLExpr::Sub(l, r) => TLExpr::sub(self.apply_recursive(l), self.apply_recursive(r)),
415            TLExpr::Mul(l, r) => TLExpr::mul(self.apply_recursive(l), self.apply_recursive(r)),
416            TLExpr::Div(l, r) => TLExpr::div(self.apply_recursive(l), self.apply_recursive(r)),
417            TLExpr::Pow(l, r) => TLExpr::pow(self.apply_recursive(l), self.apply_recursive(r)),
418            TLExpr::Mod(l, r) => TLExpr::modulo(self.apply_recursive(l), self.apply_recursive(r)),
419            TLExpr::Min(l, r) => TLExpr::min(self.apply_recursive(l), self.apply_recursive(r)),
420            TLExpr::Max(l, r) => TLExpr::max(self.apply_recursive(l), self.apply_recursive(r)),
421
422            // Comparison
423            TLExpr::Eq(l, r) => TLExpr::eq(self.apply_recursive(l), self.apply_recursive(r)),
424            TLExpr::Lt(l, r) => TLExpr::lt(self.apply_recursive(l), self.apply_recursive(r)),
425            TLExpr::Gt(l, r) => TLExpr::gt(self.apply_recursive(l), self.apply_recursive(r)),
426            TLExpr::Lte(l, r) => TLExpr::lte(self.apply_recursive(l), self.apply_recursive(r)),
427            TLExpr::Gte(l, r) => TLExpr::gte(self.apply_recursive(l), self.apply_recursive(r)),
428
429            // Mathematical functions
430            TLExpr::Abs(e) => TLExpr::abs(self.apply_recursive(e)),
431            TLExpr::Floor(e) => TLExpr::floor(self.apply_recursive(e)),
432            TLExpr::Ceil(e) => TLExpr::ceil(self.apply_recursive(e)),
433            TLExpr::Round(e) => TLExpr::round(self.apply_recursive(e)),
434            TLExpr::Sqrt(e) => TLExpr::sqrt(self.apply_recursive(e)),
435            TLExpr::Exp(e) => TLExpr::exp(self.apply_recursive(e)),
436            TLExpr::Log(e) => TLExpr::log(self.apply_recursive(e)),
437            TLExpr::Sin(e) => TLExpr::sin(self.apply_recursive(e)),
438            TLExpr::Cos(e) => TLExpr::cos(self.apply_recursive(e)),
439            TLExpr::Tan(e) => TLExpr::tan(self.apply_recursive(e)),
440
441            // Modal/Temporal
442            TLExpr::Box(e) => TLExpr::modal_box(self.apply_recursive(e)),
443            TLExpr::Diamond(e) => TLExpr::modal_diamond(self.apply_recursive(e)),
444            TLExpr::Next(e) => TLExpr::next(self.apply_recursive(e)),
445            TLExpr::Eventually(e) => TLExpr::eventually(self.apply_recursive(e)),
446            TLExpr::Always(e) => TLExpr::always(self.apply_recursive(e)),
447            TLExpr::Until { before, after } => {
448                TLExpr::until(self.apply_recursive(before), self.apply_recursive(after))
449            }
450            TLExpr::Release { released, releaser } => TLExpr::release(
451                self.apply_recursive(released),
452                self.apply_recursive(releaser),
453            ),
454            TLExpr::WeakUntil { before, after } => {
455                TLExpr::weak_until(self.apply_recursive(before), self.apply_recursive(after))
456            }
457            TLExpr::StrongRelease { released, releaser } => TLExpr::strong_release(
458                self.apply_recursive(released),
459                self.apply_recursive(releaser),
460            ),
461
462            // Quantifiers
463            TLExpr::Exists { var, domain, body } => {
464                TLExpr::exists(var.clone(), domain.clone(), self.apply_recursive(body))
465            }
466            TLExpr::ForAll { var, domain, body } => {
467                TLExpr::forall(var.clone(), domain.clone(), self.apply_recursive(body))
468            }
469            TLExpr::SoftExists {
470                var,
471                domain,
472                body,
473                temperature,
474            } => TLExpr::soft_exists(
475                var.clone(),
476                domain.clone(),
477                self.apply_recursive(body),
478                *temperature,
479            ),
480            TLExpr::SoftForAll {
481                var,
482                domain,
483                body,
484                temperature,
485            } => TLExpr::soft_forall(
486                var.clone(),
487                domain.clone(),
488                self.apply_recursive(body),
489                *temperature,
490            ),
491
492            // Aggregation
493            TLExpr::Aggregate {
494                op,
495                var,
496                domain,
497                body,
498                group_by,
499            } => {
500                if let Some(group_vars) = group_by {
501                    TLExpr::aggregate_with_group_by(
502                        op.clone(),
503                        var.clone(),
504                        domain.clone(),
505                        self.apply_recursive(body),
506                        group_vars.clone(),
507                    )
508                } else {
509                    TLExpr::aggregate(
510                        op.clone(),
511                        var.clone(),
512                        domain.clone(),
513                        self.apply_recursive(body),
514                    )
515                }
516            }
517
518            // Control flow
519            TLExpr::IfThenElse {
520                condition,
521                then_branch,
522                else_branch,
523            } => TLExpr::if_then_else(
524                self.apply_recursive(condition),
525                self.apply_recursive(then_branch),
526                self.apply_recursive(else_branch),
527            ),
528            TLExpr::Let { var, value, body } => TLExpr::let_binding(
529                var.clone(),
530                self.apply_recursive(value),
531                self.apply_recursive(body),
532            ),
533
534            // Fuzzy logic
535            TLExpr::TNorm { kind, left, right } => TLExpr::tnorm(
536                *kind,
537                self.apply_recursive(left),
538                self.apply_recursive(right),
539            ),
540            TLExpr::TCoNorm { kind, left, right } => TLExpr::tconorm(
541                *kind,
542                self.apply_recursive(left),
543                self.apply_recursive(right),
544            ),
545            TLExpr::FuzzyNot { kind, expr } => TLExpr::fuzzy_not(*kind, self.apply_recursive(expr)),
546            TLExpr::FuzzyImplication {
547                kind,
548                premise,
549                conclusion,
550            } => TLExpr::fuzzy_imply(
551                *kind,
552                self.apply_recursive(premise),
553                self.apply_recursive(conclusion),
554            ),
555
556            // Probabilistic
557            TLExpr::WeightedRule { weight, rule } => {
558                TLExpr::weighted_rule(*weight, self.apply_recursive(rule))
559            }
560            TLExpr::ProbabilisticChoice { alternatives } => TLExpr::probabilistic_choice(
561                alternatives
562                    .iter()
563                    .map(|(p, e)| (*p, self.apply_recursive(e)))
564                    .collect(),
565            ),
566
567            // Beta.1 enhancements: recurse into subexpressions
568            TLExpr::Lambda {
569                var,
570                var_type,
571                body,
572            } => TLExpr::lambda(var.clone(), var_type.clone(), self.apply_recursive(body)),
573            TLExpr::Apply { function, argument } => TLExpr::apply(
574                self.apply_recursive(function),
575                self.apply_recursive(argument),
576            ),
577            TLExpr::SetMembership { element, set } => {
578                TLExpr::set_membership(self.apply_recursive(element), self.apply_recursive(set))
579            }
580            TLExpr::SetUnion { left, right } => {
581                TLExpr::set_union(self.apply_recursive(left), self.apply_recursive(right))
582            }
583            TLExpr::SetIntersection { left, right } => {
584                TLExpr::set_intersection(self.apply_recursive(left), self.apply_recursive(right))
585            }
586            TLExpr::SetDifference { left, right } => {
587                TLExpr::set_difference(self.apply_recursive(left), self.apply_recursive(right))
588            }
589            TLExpr::SetCardinality { set } => TLExpr::set_cardinality(self.apply_recursive(set)),
590            TLExpr::EmptySet => expr.clone(),
591            TLExpr::SetComprehension {
592                var,
593                domain,
594                condition,
595            } => TLExpr::set_comprehension(
596                var.clone(),
597                domain.clone(),
598                self.apply_recursive(condition),
599            ),
600            TLExpr::CountingExists {
601                var,
602                domain,
603                body,
604                min_count,
605            } => TLExpr::counting_exists(
606                var.clone(),
607                domain.clone(),
608                self.apply_recursive(body),
609                *min_count,
610            ),
611            TLExpr::CountingForAll {
612                var,
613                domain,
614                body,
615                min_count,
616            } => TLExpr::counting_forall(
617                var.clone(),
618                domain.clone(),
619                self.apply_recursive(body),
620                *min_count,
621            ),
622            TLExpr::ExactCount {
623                var,
624                domain,
625                body,
626                count,
627            } => TLExpr::exact_count(
628                var.clone(),
629                domain.clone(),
630                self.apply_recursive(body),
631                *count,
632            ),
633            TLExpr::Majority { var, domain, body } => {
634                TLExpr::majority(var.clone(), domain.clone(), self.apply_recursive(body))
635            }
636            TLExpr::LeastFixpoint { var, body } => {
637                TLExpr::least_fixpoint(var.clone(), self.apply_recursive(body))
638            }
639            TLExpr::GreatestFixpoint { var, body } => {
640                TLExpr::greatest_fixpoint(var.clone(), self.apply_recursive(body))
641            }
642            TLExpr::Nominal { .. } => expr.clone(),
643            TLExpr::At { nominal, formula } => {
644                TLExpr::at(nominal.clone(), self.apply_recursive(formula))
645            }
646            TLExpr::Somewhere { formula } => TLExpr::somewhere(self.apply_recursive(formula)),
647            TLExpr::Everywhere { formula } => TLExpr::everywhere(self.apply_recursive(formula)),
648            TLExpr::AllDifferent { .. } => expr.clone(),
649            TLExpr::GlobalCardinality {
650                variables,
651                values,
652                min_occurrences,
653                max_occurrences,
654            } => TLExpr::global_cardinality(
655                variables.clone(),
656                values.iter().map(|v| self.apply_recursive(v)).collect(),
657                min_occurrences.clone(),
658                max_occurrences.clone(),
659            ),
660            TLExpr::Abducible { .. } => expr.clone(),
661            TLExpr::Explain { formula } => TLExpr::explain(self.apply_recursive(formula)),
662            TLExpr::SymbolLiteral(_) => expr.clone(),
663            TLExpr::Match { scrutinee, arms } => TLExpr::Match {
664                scrutinee: Box::new(self.apply_recursive(scrutinee)),
665                arms: arms
666                    .iter()
667                    .map(|(p, b)| (p.clone(), Box::new(self.apply_recursive(b))))
668                    .collect(),
669            },
670
671            // Leaves - no recursion needed
672            TLExpr::Pred { .. } | TLExpr::Constant(_) => expr.clone(),
673        }
674    }
675
676    /// Apply rules until no more changes occur (fixed point).
677    pub fn apply_until_fixpoint(&self, expr: &TLExpr) -> TLExpr {
678        let mut current = expr.clone();
679        loop {
680            let next = self.apply_recursive(&current);
681            if next == current {
682                return current;
683            }
684            current = next;
685        }
686    }
687}
688
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Term;

    /// Fixture: the predicate `P(x)`.
    fn p_of_x() -> TLExpr {
        TLExpr::pred("P", vec![Term::var("x")])
    }

    /// Fixture: the predicate `Q(y)`.
    fn q_of_y() -> TLExpr {
        TLExpr::pred("Q", vec![Term::var("y")])
    }

    /// Fixture: the pattern `¬¬A`.
    fn double_negation_pattern() -> Pattern {
        Pattern::negation(Pattern::negation(Pattern::var("A")))
    }

    /// Fixture: template returning whatever was bound to `A`.
    fn take_a(bindings: &HashMap<String, TLExpr>) -> TLExpr {
        bindings
            .get("A")
            .expect("binding 'A' must exist when pattern matched")
            .clone()
    }

    #[test]
    fn test_pattern_var_match() {
        let expr = TLExpr::pred("P", vec![Term::var("a")]);

        let bindings = Pattern::var("x").matches(&expr).expect("unwrap");
        assert_eq!(bindings.get("x"), Some(&expr));
    }

    #[test]
    fn test_pattern_constant_match() {
        let expr = TLExpr::constant(42.0);

        assert!(Pattern::constant(42.0).matches(&expr).is_some());
    }

    #[test]
    fn test_pattern_and_match() {
        let pattern = Pattern::and(Pattern::var("A"), Pattern::var("B"));
        let expr = TLExpr::and(p_of_x(), q_of_y());

        let bindings = pattern.matches(&expr).expect("unwrap");
        assert!(bindings.contains_key("A"));
        assert!(bindings.contains_key("B"));
    }

    #[test]
    fn test_pattern_not_match() {
        let pattern = Pattern::negation(Pattern::var("A"));
        let expr = TLExpr::negate(p_of_x());

        let bindings = pattern.matches(&expr).expect("unwrap");
        assert!(bindings.contains_key("A"));
    }

    #[test]
    fn test_double_negation_rule() {
        let rule = RewriteRule::new(double_negation_pattern(), take_a);

        let expr = TLExpr::negate(TLExpr::negate(p_of_x()));
        let result = rule.apply(&expr).expect("unwrap");

        assert!(matches!(result, TLExpr::Pred { .. }));
    }

    #[test]
    fn test_rewrite_system_double_negation() {
        let system =
            RewriteSystem::new().add_rule(RewriteRule::new(double_negation_pattern(), take_a));

        let expr = TLExpr::negate(TLExpr::negate(p_of_x()));
        let result = system.apply_recursive(&expr);

        assert!(matches!(result, TLExpr::Pred { .. }));
    }

    #[test]
    fn test_logic_equivalences_system() {
        let system = RewriteSystem::with_logic_equivalences();

        // Double negation: ¬¬P → P
        let double_negated = TLExpr::negate(TLExpr::negate(p_of_x()));
        let result = system.apply_recursive(&double_negated);
        assert!(matches!(result, TLExpr::Pred { .. }));

        // De Morgan's law: ¬(A ∧ B) → ¬A ∨ ¬B
        let negated_and = TLExpr::negate(TLExpr::and(p_of_x(), q_of_y()));
        let result = system.apply_recursive(&negated_and);
        assert!(matches!(result, TLExpr::Or(_, _)));
    }

    #[test]
    fn test_nested_rewriting() {
        let system = RewriteSystem::with_logic_equivalences();

        // ¬(¬¬P ∧ Q) should be rewritten to ¬(P ∧ Q) then to ¬P ∨ ¬Q
        let expr = TLExpr::negate(TLExpr::and(
            TLExpr::negate(TLExpr::negate(p_of_x())),
            q_of_y(),
        ));

        let result = system.apply_until_fixpoint(&expr);
        // Should be ¬P ∨ ¬Q after full rewriting
        assert!(matches!(result, TLExpr::Or(_, _)));
    }

    #[test]
    fn test_implication_expansion() {
        let system = RewriteSystem::with_logic_equivalences();

        // P → Q should expand to ¬P ∨ Q
        let expr = TLExpr::imply(p_of_x(), q_of_y());

        let result = system.apply_recursive(&expr);
        assert!(matches!(result, TLExpr::Or(_, _)));
    }

    #[test]
    fn test_pattern_add_match() {
        let pattern = Pattern::add(Pattern::var("x"), Pattern::var("y"));
        let expr = TLExpr::add(TLExpr::constant(1.0), TLExpr::constant(2.0));

        let bindings = pattern
            .matches(&expr)
            .expect("Pattern::Add should match TLExpr::Add");
        assert_eq!(bindings.get("x"), Some(&TLExpr::constant(1.0)));
        assert_eq!(bindings.get("y"), Some(&TLExpr::constant(2.0)));
    }

    #[test]
    fn test_pattern_exp_match() {
        let pattern = Pattern::exp(Pattern::var("x"));
        let expr = TLExpr::exp(TLExpr::constant(1.0));

        let bindings = pattern
            .matches(&expr)
            .expect("Pattern::Exp should match TLExpr::Exp");
        assert_eq!(bindings.get("x"), Some(&TLExpr::constant(1.0)));
    }

    #[test]
    fn test_pattern_neg_match() {
        let pattern = Pattern::neg(Pattern::var("x"));
        // Negation in TLExpr is Sub(0, x)
        let expr = TLExpr::sub(TLExpr::constant(0.0), TLExpr::constant(5.0));

        let bindings = pattern
            .matches(&expr)
            .expect("Pattern::Neg should match TLExpr::Sub(0, x)");
        assert_eq!(bindings.get("x"), Some(&TLExpr::constant(5.0)));
    }

    #[test]
    fn test_pattern_add_does_not_match_mul() {
        let pattern = Pattern::add(Pattern::var("x"), Pattern::var("y"));
        let expr = TLExpr::mul(TLExpr::constant(1.0), TLExpr::constant(2.0));

        assert!(pattern.matches(&expr).is_none());
    }

    #[test]
    fn test_pattern_sin_cos_tan_match() {
        let sin_pat = Pattern::sin(Pattern::var("a"));
        let cos_pat = Pattern::cos(Pattern::var("a"));
        let tan_pat = Pattern::tan(Pattern::var("a"));

        let sin_expr = TLExpr::sin(TLExpr::constant(0.5));
        let cos_expr = TLExpr::cos(TLExpr::constant(0.5));
        let tan_expr = TLExpr::tan(TLExpr::constant(0.5));

        // Each pattern matches its own function...
        assert!(sin_pat.matches(&sin_expr).is_some());
        assert!(cos_pat.matches(&cos_expr).is_some());
        assert!(tan_pat.matches(&tan_expr).is_some());

        // ...and rejects the others.
        assert!(sin_pat.matches(&cos_expr).is_none());
        assert!(cos_pat.matches(&tan_expr).is_none());
    }

    #[test]
    fn test_pattern_neg_nonzero_constant_no_match() {
        let pattern = Pattern::neg(Pattern::var("x"));
        // Sub(1.0, x) should NOT match Neg pattern (needs ~0.0 as first arg)
        let expr = TLExpr::sub(TLExpr::constant(1.0), TLExpr::constant(5.0));

        assert!(pattern.matches(&expr).is_none());
    }
}