Skip to main content

tensorlogic_ir/expr/
rewriting.rs

1//! Expression rewriting system with pattern matching.
2//!
3//! This module provides a powerful rewriting system that allows users to define
4//! custom transformation rules for TensorLogic expressions.
5
6use std::collections::HashMap;
7
8use super::TLExpr;
9
10/// A pattern that can match against expressions.
11#[derive(Clone, Debug, PartialEq)]
12pub enum Pattern {
13    /// Match any expression and bind it to a variable
14    Var(String),
15    /// Match a specific constant value
16    Constant(f64),
17    /// Match a predicate with a specific name
18    Pred { name: String, args: Vec<Pattern> },
19    /// Match an AND expression
20    And(Box<Pattern>, Box<Pattern>),
21    /// Match an OR expression
22    Or(Box<Pattern>, Box<Pattern>),
23    /// Match a NOT expression
24    Not(Box<Pattern>),
25    /// Match an implication
26    Imply(Box<Pattern>, Box<Pattern>),
27    /// Match any expression (wildcard)
28    Any,
29}
30
31impl Pattern {
32    /// Create a variable pattern.
33    pub fn var(name: impl Into<String>) -> Self {
34        Pattern::Var(name.into())
35    }
36
37    /// Create a constant pattern.
38    pub fn constant(value: f64) -> Self {
39        Pattern::Constant(value)
40    }
41
42    /// Create a wildcard pattern.
43    pub fn any() -> Self {
44        Pattern::Any
45    }
46
47    /// Create a predicate pattern.
48    pub fn pred(name: impl Into<String>, args: Vec<Pattern>) -> Self {
49        Pattern::Pred {
50            name: name.into(),
51            args,
52        }
53    }
54
55    /// Create an AND pattern.
56    pub fn and(left: Pattern, right: Pattern) -> Self {
57        Pattern::And(Box::new(left), Box::new(right))
58    }
59
60    /// Create an OR pattern.
61    pub fn or(left: Pattern, right: Pattern) -> Self {
62        Pattern::Or(Box::new(left), Box::new(right))
63    }
64
65    /// Create a NOT pattern.
66    pub fn negation(pattern: Pattern) -> Self {
67        Pattern::Not(Box::new(pattern))
68    }
69
70    /// Create an implication pattern.
71    pub fn imply(left: Pattern, right: Pattern) -> Self {
72        Pattern::Imply(Box::new(left), Box::new(right))
73    }
74
75    /// Try to match this pattern against an expression, returning bindings if successful.
76    pub fn matches(&self, expr: &TLExpr) -> Option<HashMap<String, TLExpr>> {
77        let mut bindings = HashMap::new();
78        if self.matches_recursive(expr, &mut bindings) {
79            Some(bindings)
80        } else {
81            None
82        }
83    }
84
85    fn matches_recursive(&self, expr: &TLExpr, bindings: &mut HashMap<String, TLExpr>) -> bool {
86        match (self, expr) {
87            // Wildcard matches anything
88            (Pattern::Any, _) => true,
89
90            // Variable pattern: bind if not already bound, or check if bound value matches
91            (Pattern::Var(var_name), _) => {
92                if let Some(bound_expr) = bindings.get(var_name) {
93                    bound_expr == expr
94                } else {
95                    bindings.insert(var_name.clone(), expr.clone());
96                    true
97                }
98            }
99
100            // Constant pattern
101            (Pattern::Constant(pv), TLExpr::Constant(ev)) => (pv - ev).abs() < f64::EPSILON,
102
103            // Predicate pattern
104            (
105                Pattern::Pred {
106                    name: pname,
107                    args: pargs,
108                },
109                TLExpr::Pred {
110                    name: ename,
111                    args: eargs,
112                },
113            ) => {
114                if pname != ename || pargs.len() != eargs.len() {
115                    return false;
116                }
117                // Note: We're matching predicate arguments structurally here
118                // For a more sophisticated system, we'd need term patterns
119                pargs.len() == eargs.len()
120            }
121
122            // Binary operators
123            (Pattern::And(pl, pr), TLExpr::And(el, er))
124            | (Pattern::Or(pl, pr), TLExpr::Or(el, er))
125            | (Pattern::Imply(pl, pr), TLExpr::Imply(el, er)) => {
126                pl.matches_recursive(el, bindings) && pr.matches_recursive(er, bindings)
127            }
128
129            // Unary operators
130            (Pattern::Not(p), TLExpr::Not(e)) => p.matches_recursive(e, bindings),
131
132            _ => false,
133        }
134    }
135}
136
137/// A rewrite rule that transforms expressions matching a pattern into a template.
138#[derive(Clone, Debug)]
139pub struct RewriteRule {
140    /// The pattern to match
141    pub pattern: Pattern,
142    /// Function to generate the replacement expression from bindings
143    pub template: fn(&HashMap<String, TLExpr>) -> TLExpr,
144    /// Optional name for debugging
145    pub name: Option<String>,
146}
147
148impl RewriteRule {
149    /// Create a new rewrite rule.
150    pub fn new(pattern: Pattern, template: fn(&HashMap<String, TLExpr>) -> TLExpr) -> Self {
151        Self {
152            pattern,
153            template,
154            name: None,
155        }
156    }
157
158    /// Create a named rewrite rule.
159    pub fn named(
160        name: impl Into<String>,
161        pattern: Pattern,
162        template: fn(&HashMap<String, TLExpr>) -> TLExpr,
163    ) -> Self {
164        Self {
165            pattern,
166            template,
167            name: Some(name.into()),
168        }
169    }
170
171    /// Try to apply this rule to an expression.
172    pub fn apply(&self, expr: &TLExpr) -> Option<TLExpr> {
173        self.pattern
174            .matches(expr)
175            .map(|bindings| (self.template)(&bindings))
176    }
177}
178
/// An ordered collection of rewrite rules that can be applied to expressions.
///
/// Rules are consulted in the order they were added, and the first matching
/// rule wins (see [`RewriteSystem::apply_once`]).
#[derive(Clone, Debug, Default)]
pub struct RewriteSystem {
    // Tried front-to-back; `add_rule` appends to the end.
    rules: Vec<RewriteRule>,
}
184
185impl RewriteSystem {
186    /// Create a new empty rewrite system.
187    pub fn new() -> Self {
188        Self::default()
189    }
190
191    /// Add a rule to the system.
192    pub fn add_rule(mut self, rule: RewriteRule) -> Self {
193        self.rules.push(rule);
194        self
195    }
196
197    /// Create a system with common logical equivalences.
198    pub fn with_logic_equivalences() -> Self {
199        let mut system = Self::new();
200
201        // Double negation elimination: ¬¬A → A
202        system = system.add_rule(RewriteRule::named(
203            "double_negation",
204            Pattern::negation(Pattern::negation(Pattern::var("A"))),
205            |bindings| bindings.get("A").unwrap().clone(),
206        ));
207
208        // De Morgan's laws: ¬(A ∧ B) → ¬A ∨ ¬B
209        system = system.add_rule(RewriteRule::named(
210            "demorgan_and",
211            Pattern::negation(Pattern::and(Pattern::var("A"), Pattern::var("B"))),
212            |bindings| {
213                TLExpr::or(
214                    TLExpr::negate(bindings.get("A").unwrap().clone()),
215                    TLExpr::negate(bindings.get("B").unwrap().clone()),
216                )
217            },
218        ));
219
220        // De Morgan's laws: ¬(A ∨ B) → ¬A ∧ ¬B
221        system = system.add_rule(RewriteRule::named(
222            "demorgan_or",
223            Pattern::negation(Pattern::or(Pattern::var("A"), Pattern::var("B"))),
224            |bindings| {
225                TLExpr::and(
226                    TLExpr::negate(bindings.get("A").unwrap().clone()),
227                    TLExpr::negate(bindings.get("B").unwrap().clone()),
228                )
229            },
230        ));
231
232        // Implication expansion: A → B ≡ ¬A ∨ B
233        system = system.add_rule(RewriteRule::named(
234            "implication_expansion",
235            Pattern::imply(Pattern::var("A"), Pattern::var("B")),
236            |bindings| {
237                TLExpr::or(
238                    TLExpr::negate(bindings.get("A").unwrap().clone()),
239                    bindings.get("B").unwrap().clone(),
240                )
241            },
242        ));
243
244        system
245    }
246
247    /// Try to apply the first matching rule to an expression.
248    pub fn apply_once(&self, expr: &TLExpr) -> Option<TLExpr> {
249        for rule in &self.rules {
250            if let Some(result) = rule.apply(expr) {
251                return Some(result);
252            }
253        }
254        None
255    }
256
    /// Apply rules recursively to an expression and all its subexpressions.
    ///
    /// One call performs a top-down pass:
    /// 1. If a rule matches `expr` itself (first match wins, see
    ///    [`RewriteSystem::apply_once`]), the rewritten expression is processed
    ///    again from step 1.
    /// 2. Otherwise the node is rebuilt with each child expression processed
    ///    recursively. Non-expression fields (variable names, domains, kinds,
    ///    weights, temperatures, counts, ...) are copied unchanged.
    ///
    /// After its children have been rewritten, a node is not re-checked against
    /// the rules. A rewrite that only becomes possible because a child changed
    /// therefore waits for another pass; [`RewriteSystem::apply_until_fixpoint`]
    /// repeats passes until the result stops changing.
    ///
    /// Step 1 has no depth or iteration limit. If a rule's output can be matched
    /// again at the root (e.g. a commutativity rule `A ∧ B → B ∧ A`), this
    /// recurses without bound.
    pub fn apply_recursive(&self, expr: &TLExpr) -> TLExpr {
        // First, try to apply a rule at the top level
        if let Some(rewritten) = self.apply_once(expr) {
            // Keep rewriting the new root until no rule matches it.
            return self.apply_recursive(&rewritten);
        }

        // If no rule applies, recurse into subexpressions.
        // The match is exhaustive (no `_` arm), so adding a `TLExpr` variant
        // forces this traversal to be updated.
        match expr {
            TLExpr::And(l, r) => TLExpr::and(self.apply_recursive(l), self.apply_recursive(r)),
            TLExpr::Or(l, r) => TLExpr::or(self.apply_recursive(l), self.apply_recursive(r)),
            TLExpr::Not(e) => TLExpr::negate(self.apply_recursive(e)),
            TLExpr::Imply(l, r) => TLExpr::imply(self.apply_recursive(l), self.apply_recursive(r)),
            TLExpr::Score(e) => TLExpr::score(self.apply_recursive(e)),

            // Arithmetic
            TLExpr::Add(l, r) => TLExpr::add(self.apply_recursive(l), self.apply_recursive(r)),
            TLExpr::Sub(l, r) => TLExpr::sub(self.apply_recursive(l), self.apply_recursive(r)),
            TLExpr::Mul(l, r) => TLExpr::mul(self.apply_recursive(l), self.apply_recursive(r)),
            TLExpr::Div(l, r) => TLExpr::div(self.apply_recursive(l), self.apply_recursive(r)),
            TLExpr::Pow(l, r) => TLExpr::pow(self.apply_recursive(l), self.apply_recursive(r)),
            TLExpr::Mod(l, r) => TLExpr::modulo(self.apply_recursive(l), self.apply_recursive(r)),
            TLExpr::Min(l, r) => TLExpr::min(self.apply_recursive(l), self.apply_recursive(r)),
            TLExpr::Max(l, r) => TLExpr::max(self.apply_recursive(l), self.apply_recursive(r)),

            // Comparison
            TLExpr::Eq(l, r) => TLExpr::eq(self.apply_recursive(l), self.apply_recursive(r)),
            TLExpr::Lt(l, r) => TLExpr::lt(self.apply_recursive(l), self.apply_recursive(r)),
            TLExpr::Gt(l, r) => TLExpr::gt(self.apply_recursive(l), self.apply_recursive(r)),
            TLExpr::Lte(l, r) => TLExpr::lte(self.apply_recursive(l), self.apply_recursive(r)),
            TLExpr::Gte(l, r) => TLExpr::gte(self.apply_recursive(l), self.apply_recursive(r)),

            // Mathematical functions
            TLExpr::Abs(e) => TLExpr::abs(self.apply_recursive(e)),
            TLExpr::Floor(e) => TLExpr::floor(self.apply_recursive(e)),
            TLExpr::Ceil(e) => TLExpr::ceil(self.apply_recursive(e)),
            TLExpr::Round(e) => TLExpr::round(self.apply_recursive(e)),
            TLExpr::Sqrt(e) => TLExpr::sqrt(self.apply_recursive(e)),
            TLExpr::Exp(e) => TLExpr::exp(self.apply_recursive(e)),
            TLExpr::Log(e) => TLExpr::log(self.apply_recursive(e)),
            TLExpr::Sin(e) => TLExpr::sin(self.apply_recursive(e)),
            TLExpr::Cos(e) => TLExpr::cos(self.apply_recursive(e)),
            TLExpr::Tan(e) => TLExpr::tan(self.apply_recursive(e)),

            // Modal/Temporal
            TLExpr::Box(e) => TLExpr::modal_box(self.apply_recursive(e)),
            TLExpr::Diamond(e) => TLExpr::modal_diamond(self.apply_recursive(e)),
            TLExpr::Next(e) => TLExpr::next(self.apply_recursive(e)),
            TLExpr::Eventually(e) => TLExpr::eventually(self.apply_recursive(e)),
            TLExpr::Always(e) => TLExpr::always(self.apply_recursive(e)),
            TLExpr::Until { before, after } => {
                TLExpr::until(self.apply_recursive(before), self.apply_recursive(after))
            }
            TLExpr::Release { released, releaser } => TLExpr::release(
                self.apply_recursive(released),
                self.apply_recursive(releaser),
            ),
            TLExpr::WeakUntil { before, after } => {
                TLExpr::weak_until(self.apply_recursive(before), self.apply_recursive(after))
            }
            TLExpr::StrongRelease { released, releaser } => TLExpr::strong_release(
                self.apply_recursive(released),
                self.apply_recursive(releaser),
            ),

            // Quantifiers: only the body is rewritten; the bound variable and
            // its domain are cloned as-is.
            TLExpr::Exists { var, domain, body } => {
                TLExpr::exists(var.clone(), domain.clone(), self.apply_recursive(body))
            }
            TLExpr::ForAll { var, domain, body } => {
                TLExpr::forall(var.clone(), domain.clone(), self.apply_recursive(body))
            }
            TLExpr::SoftExists {
                var,
                domain,
                body,
                temperature,
            } => TLExpr::soft_exists(
                var.clone(),
                domain.clone(),
                self.apply_recursive(body),
                *temperature,
            ),
            TLExpr::SoftForAll {
                var,
                domain,
                body,
                temperature,
            } => TLExpr::soft_forall(
                var.clone(),
                domain.clone(),
                self.apply_recursive(body),
                *temperature,
            ),

            // Aggregation: rebuild with the matching constructor so that an
            // aggregate without `group_by` stays ungrouped.
            TLExpr::Aggregate {
                op,
                var,
                domain,
                body,
                group_by,
            } => {
                if let Some(group_vars) = group_by {
                    TLExpr::aggregate_with_group_by(
                        op.clone(),
                        var.clone(),
                        domain.clone(),
                        self.apply_recursive(body),
                        group_vars.clone(),
                    )
                } else {
                    TLExpr::aggregate(
                        op.clone(),
                        var.clone(),
                        domain.clone(),
                        self.apply_recursive(body),
                    )
                }
            }

            // Control flow
            TLExpr::IfThenElse {
                condition,
                then_branch,
                else_branch,
            } => TLExpr::if_then_else(
                self.apply_recursive(condition),
                self.apply_recursive(then_branch),
                self.apply_recursive(else_branch),
            ),
            TLExpr::Let { var, value, body } => TLExpr::let_binding(
                var.clone(),
                self.apply_recursive(value),
                self.apply_recursive(body),
            ),

            // Fuzzy logic
            TLExpr::TNorm { kind, left, right } => TLExpr::tnorm(
                *kind,
                self.apply_recursive(left),
                self.apply_recursive(right),
            ),
            TLExpr::TCoNorm { kind, left, right } => TLExpr::tconorm(
                *kind,
                self.apply_recursive(left),
                self.apply_recursive(right),
            ),
            TLExpr::FuzzyNot { kind, expr } => TLExpr::fuzzy_not(*kind, self.apply_recursive(expr)),
            TLExpr::FuzzyImplication {
                kind,
                premise,
                conclusion,
            } => TLExpr::fuzzy_imply(
                *kind,
                self.apply_recursive(premise),
                self.apply_recursive(conclusion),
            ),

            // Probabilistic
            TLExpr::WeightedRule { weight, rule } => {
                TLExpr::weighted_rule(*weight, self.apply_recursive(rule))
            }
            TLExpr::ProbabilisticChoice { alternatives } => TLExpr::probabilistic_choice(
                alternatives
                    .iter()
                    .map(|(p, e)| (*p, self.apply_recursive(e)))
                    .collect(),
            ),

            // Beta.1 enhancements: recurse into subexpressions
            TLExpr::Lambda {
                var,
                var_type,
                body,
            } => TLExpr::lambda(var.clone(), var_type.clone(), self.apply_recursive(body)),
            TLExpr::Apply { function, argument } => TLExpr::apply(
                self.apply_recursive(function),
                self.apply_recursive(argument),
            ),
            TLExpr::SetMembership { element, set } => {
                TLExpr::set_membership(self.apply_recursive(element), self.apply_recursive(set))
            }
            TLExpr::SetUnion { left, right } => {
                TLExpr::set_union(self.apply_recursive(left), self.apply_recursive(right))
            }
            TLExpr::SetIntersection { left, right } => {
                TLExpr::set_intersection(self.apply_recursive(left), self.apply_recursive(right))
            }
            TLExpr::SetDifference { left, right } => {
                TLExpr::set_difference(self.apply_recursive(left), self.apply_recursive(right))
            }
            TLExpr::SetCardinality { set } => TLExpr::set_cardinality(self.apply_recursive(set)),
            TLExpr::EmptySet => expr.clone(),
            TLExpr::SetComprehension {
                var,
                domain,
                condition,
            } => TLExpr::set_comprehension(
                var.clone(),
                domain.clone(),
                self.apply_recursive(condition),
            ),
            TLExpr::CountingExists {
                var,
                domain,
                body,
                min_count,
            } => TLExpr::counting_exists(
                var.clone(),
                domain.clone(),
                self.apply_recursive(body),
                *min_count,
            ),
            TLExpr::CountingForAll {
                var,
                domain,
                body,
                min_count,
            } => TLExpr::counting_forall(
                var.clone(),
                domain.clone(),
                self.apply_recursive(body),
                *min_count,
            ),
            TLExpr::ExactCount {
                var,
                domain,
                body,
                count,
            } => TLExpr::exact_count(
                var.clone(),
                domain.clone(),
                self.apply_recursive(body),
                *count,
            ),
            TLExpr::Majority { var, domain, body } => {
                TLExpr::majority(var.clone(), domain.clone(), self.apply_recursive(body))
            }
            TLExpr::LeastFixpoint { var, body } => {
                TLExpr::least_fixpoint(var.clone(), self.apply_recursive(body))
            }
            TLExpr::GreatestFixpoint { var, body } => {
                TLExpr::greatest_fixpoint(var.clone(), self.apply_recursive(body))
            }
            // Returned unchanged: this traversal does not look inside it.
            TLExpr::Nominal { .. } => expr.clone(),
            TLExpr::At { nominal, formula } => {
                TLExpr::at(nominal.clone(), self.apply_recursive(formula))
            }
            TLExpr::Somewhere { formula } => TLExpr::somewhere(self.apply_recursive(formula)),
            TLExpr::Everywhere { formula } => TLExpr::everywhere(self.apply_recursive(formula)),
            // Returned unchanged: this traversal does not look inside it.
            TLExpr::AllDifferent { .. } => expr.clone(),
            // Only `values` is rewritten; the other fields are cloned as-is.
            TLExpr::GlobalCardinality {
                variables,
                values,
                min_occurrences,
                max_occurrences,
            } => TLExpr::global_cardinality(
                variables.clone(),
                values.iter().map(|v| self.apply_recursive(v)).collect(),
                min_occurrences.clone(),
                max_occurrences.clone(),
            ),
            // Returned unchanged: this traversal does not look inside it.
            TLExpr::Abducible { .. } => expr.clone(),
            TLExpr::Explain { formula } => TLExpr::explain(self.apply_recursive(formula)),

            // Leaves - no recursion needed
            TLExpr::Pred { .. } | TLExpr::Constant(_) => expr.clone(),
        }
    }
527
528    /// Apply rules until no more changes occur (fixed point).
529    pub fn apply_until_fixpoint(&self, expr: &TLExpr) -> TLExpr {
530        let mut current = expr.clone();
531        loop {
532            let next = self.apply_recursive(&current);
533            if next == current {
534                return current;
535            }
536            current = next;
537        }
538    }
539}
540
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Term;

    /// Fixture `P(x)`.
    fn pred_p() -> TLExpr {
        TLExpr::pred("P", vec![Term::var("x")])
    }

    /// Fixture `Q(y)`.
    fn pred_q() -> TLExpr {
        TLExpr::pred("Q", vec![Term::var("y")])
    }

    #[test]
    fn test_pattern_var_match() {
        let pattern = Pattern::var("x");
        let expr = TLExpr::pred("P", vec![Term::var("a")]);

        let bindings = pattern.matches(&expr).unwrap();
        assert_eq!(bindings.get("x"), Some(&expr));
    }

    #[test]
    fn test_pattern_constant_match() {
        assert!(Pattern::constant(42.0)
            .matches(&TLExpr::constant(42.0))
            .is_some());
        assert!(Pattern::constant(1.0)
            .matches(&TLExpr::constant(2.0))
            .is_none());
    }

    #[test]
    fn test_pattern_and_match() {
        let pattern = Pattern::and(Pattern::var("A"), Pattern::var("B"));
        let expr = TLExpr::and(pred_p(), pred_q());

        let bindings = pattern.matches(&expr).unwrap();
        assert_eq!(bindings.get("A"), Some(&pred_p()));
        assert_eq!(bindings.get("B"), Some(&pred_q()));
    }

    #[test]
    fn test_pattern_not_match() {
        let pattern = Pattern::negation(Pattern::var("A"));
        let expr = TLExpr::negate(pred_p());

        let bindings = pattern.matches(&expr).unwrap();
        assert_eq!(bindings.get("A"), Some(&pred_p()));
    }

    #[test]
    fn test_pattern_repeated_variable_requires_equal_subexpressions() {
        let pattern = Pattern::and(Pattern::var("A"), Pattern::var("A"));
        assert!(pattern.matches(&TLExpr::and(pred_p(), pred_p())).is_some());
        assert!(pattern.matches(&TLExpr::and(pred_p(), pred_q())).is_none());
    }

    #[test]
    fn test_pattern_pred_checks_name_and_arity() {
        let pattern = Pattern::pred("P", vec![Pattern::any()]);
        assert!(pattern.matches(&pred_p()).is_some());
        assert!(pattern.matches(&pred_q()).is_none());

        let binary = TLExpr::pred("P", vec![Term::var("x"), Term::var("y")]);
        assert!(pattern.matches(&binary).is_none());
    }

    #[test]
    fn test_pattern_connective_mismatch() {
        let pattern = Pattern::and(Pattern::any(), Pattern::any());
        assert!(pattern.matches(&TLExpr::or(pred_p(), pred_q())).is_none());
    }

    #[test]
    fn test_double_negation_rule() {
        let rule = RewriteRule::new(
            Pattern::negation(Pattern::negation(Pattern::var("A"))),
            |bindings| bindings.get("A").unwrap().clone(),
        );

        let expr = TLExpr::negate(TLExpr::negate(pred_p()));
        assert_eq!(rule.apply(&expr), Some(pred_p()));
        // A single negation does not match the double-negation pattern.
        assert_eq!(rule.apply(&TLExpr::negate(pred_p())), None);
    }

    #[test]
    fn test_named_rule_keeps_name() {
        let rule = RewriteRule::named("identity", Pattern::any(), |_| TLExpr::constant(0.0));
        assert_eq!(rule.name.as_deref(), Some("identity"));
    }

    #[test]
    fn test_rewrite_system_double_negation() {
        let system = RewriteSystem::new().add_rule(RewriteRule::new(
            Pattern::negation(Pattern::negation(Pattern::var("A"))),
            |bindings| bindings.get("A").unwrap().clone(),
        ));

        let expr = TLExpr::negate(TLExpr::negate(pred_p()));
        assert_eq!(system.apply_recursive(&expr), pred_p());
    }

    #[test]
    fn test_apply_once_without_match() {
        let system = RewriteSystem::with_logic_equivalences();
        assert_eq!(system.apply_once(&pred_p()), None);
    }

    #[test]
    fn test_logic_equivalences_system() {
        let system = RewriteSystem::with_logic_equivalences();

        // Double negation: ¬¬P → P
        let expr = TLExpr::negate(TLExpr::negate(pred_p()));
        assert_eq!(system.apply_recursive(&expr), pred_p());

        // De Morgan's law: ¬(P ∧ Q) → ¬P ∨ ¬Q
        let expr = TLExpr::negate(TLExpr::and(pred_p(), pred_q()));
        assert_eq!(
            system.apply_recursive(&expr),
            TLExpr::or(TLExpr::negate(pred_p()), TLExpr::negate(pred_q()))
        );

        // De Morgan's law: ¬(P ∨ Q) → ¬P ∧ ¬Q
        let expr = TLExpr::negate(TLExpr::or(pred_p(), pred_q()));
        assert_eq!(
            system.apply_recursive(&expr),
            TLExpr::and(TLExpr::negate(pred_p()), TLExpr::negate(pred_q()))
        );
    }

    #[test]
    fn test_nested_rewriting() {
        let system = RewriteSystem::with_logic_equivalences();

        // ¬(¬¬P ∧ Q): De Morgan fires first at the root, giving ¬¬¬P ∨ ¬Q;
        // double negation then reduces ¬¬¬P to ¬P.
        let expr = TLExpr::negate(TLExpr::and(
            TLExpr::negate(TLExpr::negate(pred_p())),
            pred_q(),
        ));

        let result = system.apply_until_fixpoint(&expr);
        assert_eq!(
            result,
            TLExpr::or(TLExpr::negate(pred_p()), TLExpr::negate(pred_q()))
        );
    }

    #[test]
    fn test_implication_expansion() {
        let system = RewriteSystem::with_logic_equivalences();

        // P → Q should expand to ¬P ∨ Q
        let expr = TLExpr::imply(pred_p(), pred_q());
        assert_eq!(
            system.apply_recursive(&expr),
            TLExpr::or(TLExpr::negate(pred_p()), pred_q())
        );
    }

    #[test]
    fn test_fixpoint_applies_rewrites_enabled_by_children() {
        let system = RewriteSystem::with_logic_equivalences();

        // ¬(P → Q): no rule matches the root, so the first pass only expands
        // the implication, giving ¬(¬P ∨ Q). The second pass applies De Morgan
        // and then double negation, giving P ∧ ¬Q.
        let expr = TLExpr::negate(TLExpr::imply(pred_p(), pred_q()));
        assert_eq!(
            system.apply_until_fixpoint(&expr),
            TLExpr::and(pred_p(), TLExpr::negate(pred_q()))
        );
    }
}