Skip to main content

tensorlogic_compiler/
rewrite.rs

1//! Pattern-matching rewrite rules engine for TLExpr transformation.
2//!
3//! This module provides a composable, extensible rule engine that applies
4//! structural rewrites to [`tensorlogic_ir::TLExpr`] trees until a fixed point
5//! is reached (no further rules apply).
6//!
7//! # Design
8//!
9//! - Rules implement the [`RewriteRule`] trait.
10//! - The [`RewriteEngine`] holds a collection of rules and drives iteration.
11//! - Rewrites are applied **bottom-up**: children are transformed first, then
12//!   the root node is offered to each rule in order.
//! - Iteration continues until no rule fires in a full pass (fixed point),
//!   `max_iterations` passes have run, or the `max_nodes_per_pass`
//!   node-visit budget is exceeded.
15//!
16//! # Built-in rules
17//!
18//! | Rule struct | Transformation |
19//! |---|---|
20//! | [`EliminateDoubleNeg`] | `Not(Not(x))` → `x` |
21//! | [`FlattenNestedAnd`] | `And(And(a,b),c)` → `And(a,And(b,c))` |
22//! | [`FlattenNestedOr`] | `Or(Or(a,b),c)` → `Or(a,Or(b,c))` |
23//! | [`EliminateAndTrue`] | `And(True,x)` / `And(x,True)` → `x` |
24//! | [`EliminateOrFalse`] | `Or(False,x)` / `Or(x,False)` → `x` |
25//!
26//! # Example
27//!
28//! ```rust
29//! use tensorlogic_compiler::rewrite::{RewriteEngine, EliminateDoubleNeg};
30//! use tensorlogic_ir::{TLExpr, Term};
31//!
32//! let expr = TLExpr::negate(TLExpr::negate(TLExpr::pred("p", vec![Term::var("x")])));
33//!
34//! let (result, stats) = RewriteEngine::new()
35//!     .add_rule(Box::new(EliminateDoubleNeg))
36//!     .rewrite(expr);
37//!
38//! // Not(Not(p(x))) → p(x)
39//! assert_eq!(stats.total_rewrites, 1);
40//! println!("{}", stats.summary());
41//! ```
42
43use std::collections::HashMap;
44use std::fmt;
45
46use tensorlogic_ir::TLExpr;
47
48// ---------------------------------------------------------------------------
49// Helper macros for DRY binary / unary child rewriting
50// (must be declared before use)
51// ---------------------------------------------------------------------------
52
53/// Rewrite both children of a binary node and reconstruct with the same variant.
54macro_rules! rewrite_binary {
55    ($self:expr, $stats:expr, $ctor:path, $left:expr, $right:expr) => {{
56        let (nl, cl) = $self.rewrite_expr(*$left, $stats);
57        let (nr, cr) = $self.rewrite_expr(*$right, $stats);
58        ($ctor(Box::new(nl), Box::new(nr)), cl || cr)
59    }};
60}
61
62/// Rewrite the single child of a unary node and reconstruct.
63macro_rules! rewrite_unary {
64    ($self:expr, $stats:expr, $ctor:path, $inner:expr) => {{
65        let (ni, changed) = $self.rewrite_expr(*$inner, $stats);
66        ($ctor(Box::new(ni)), changed)
67    }};
68}
69
70// ---------------------------------------------------------------------------
71// RewriteRule trait
72// ---------------------------------------------------------------------------
73
74/// A single pattern-matching rewrite rule.
75///
76/// Implementors inspect an expression and, if the rule matches, return a
77/// transformed replacement expression.  If the rule does not apply, `None`
78/// is returned and the engine tries the next rule.
79pub trait RewriteRule: Send + Sync {
80    /// Unique human-readable name used in statistics output.
81    fn name(&self) -> &'static str;
82
83    /// Try to apply this rule to `expr`.
84    ///
85    /// Returns `Some(new_expr)` when the rule fires, `None` otherwise.
86    fn apply(&self, expr: &TLExpr) -> Option<TLExpr>;
87
88    /// Whether the engine should recurse into children of `expr` before
89    /// trying this rule.  Defaults to `true` (standard bottom-up traversal).
90    fn is_recursive(&self) -> bool {
91        true
92    }
93}
94
95// ---------------------------------------------------------------------------
96// RewriteStats
97// ---------------------------------------------------------------------------
98
99/// Accumulated statistics for a rewrite pass (or the full fixed-point loop).
100#[derive(Debug, Clone, Default)]
101pub struct RewriteStats {
102    /// How many times each named rule fired.
103    pub rules_applied: HashMap<String, u64>,
104    /// Total number of individual rewrites across all rules and iterations.
105    pub total_rewrites: u64,
106    /// Number of full-tree passes performed.
107    pub iterations: u32,
108    /// Total expression nodes visited (across all passes).
109    pub nodes_visited: u64,
110    /// `true` when the engine stopped because no rule fired (fixed point),
111    /// `false` when it stopped because `max_iterations` was reached.
112    pub fixed_point_reached: bool,
113}
114
115impl RewriteStats {
116    /// Record one application of `rule_name`, incrementing all counters.
117    pub fn record_rule(&mut self, rule_name: &str) {
118        *self.rules_applied.entry(rule_name.to_owned()).or_insert(0) += 1;
119        self.total_rewrites += 1;
120    }
121
122    /// Human-readable one-line summary suitable for logging.
123    pub fn summary(&self) -> String {
124        if self.rules_applied.is_empty() {
125            return format!(
126                "RewriteStats: 0 rewrites, {} iteration(s), fixed_point={}",
127                self.iterations, self.fixed_point_reached
128            );
129        }
130
131        let mut rule_parts: Vec<String> = self
132            .rules_applied
133            .iter()
134            .map(|(name, count)| format!("{}×{}", name, count))
135            .collect();
136        rule_parts.sort(); // deterministic output
137
138        format!(
139            "RewriteStats: {} rewrite(s) in {} iteration(s), fixed_point={} [{}]",
140            self.total_rewrites,
141            self.iterations,
142            self.fixed_point_reached,
143            rule_parts.join(", ")
144        )
145    }
146}
147
148impl fmt::Display for RewriteStats {
149    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
150        write!(f, "{}", self.summary())
151    }
152}
153
154// ---------------------------------------------------------------------------
155// RewriteEngine
156// ---------------------------------------------------------------------------
157
158/// Applies a set of [`RewriteRule`]s to a [`TLExpr`] tree until a fixed point.
159///
160/// Rules are attempted in the order they were added.  Within a single node
161/// visit the first firing rule wins and the engine moves on; no other rules
162/// are tried for the same node in the same pass.
163///
164/// The engine iterates until either:
165/// - a full pass completes with zero rule firings (fixed point), or
166/// - [`max_iterations`][RewriteEngine::max_iterations] is reached.
167pub struct RewriteEngine {
168    rules: Vec<Box<dyn RewriteRule>>,
169    /// Maximum number of full-tree passes before stopping (default: 64).
170    pub max_iterations: u32,
171    /// Soft limit on node visits per pass (default: 1_000_000).  If exceeded,
172    /// the engine finishes the current pass but does not start a new one.
173    pub max_nodes_per_pass: u64,
174}
175
176impl RewriteEngine {
177    /// Create a new engine with no rules and default limits.
178    pub fn new() -> Self {
179        Self {
180            rules: Vec::new(),
181            max_iterations: 64,
182            max_nodes_per_pass: 1_000_000,
183        }
184    }
185
186    /// Set the maximum number of fixed-point iterations (builder pattern).
187    pub fn with_max_iterations(mut self, n: u32) -> Self {
188        self.max_iterations = n;
189        self
190    }
191
192    /// Add one rule to the engine (builder pattern).
193    pub fn add_rule(mut self, rule: Box<dyn RewriteRule>) -> Self {
194        self.rules.push(rule);
195        self
196    }
197
198    /// Register all five built-in rules (builder pattern).
199    pub fn add_all_builtin_rules(self) -> Self {
200        builtin_rules()
201            .into_iter()
202            .fold(self, |engine, rule| engine.add_rule(rule))
203    }
204
205    /// Apply all rules to `expr` until a fixed point is reached.
206    ///
207    /// Returns the final expression together with accumulated [`RewriteStats`].
208    pub fn rewrite(&self, expr: TLExpr) -> (TLExpr, RewriteStats) {
209        let mut stats = RewriteStats::default();
210        let mut current = expr;
211
212        for _iteration in 0..self.max_iterations {
213            stats.iterations += 1;
214            let (next, changed) = self.rewrite_once(current, &mut stats);
215            current = next;
216
217            if !changed {
218                stats.fixed_point_reached = true;
219                break;
220            }
221
222            if stats.nodes_visited >= self.max_nodes_per_pass {
223                break;
224            }
225        }
226
227        (current, stats)
228    }
229
230    /// Perform exactly one full pass over the tree.
231    ///
232    /// Returns the (possibly modified) expression and whether any rule fired.
233    fn rewrite_once(&self, expr: TLExpr, stats: &mut RewriteStats) -> (TLExpr, bool) {
234        self.rewrite_expr(expr, stats)
235    }
236
237    /// Bottom-up recursive traversal: rewrite children first, then try rules
238    /// at the current node.
239    fn rewrite_expr(&self, expr: TLExpr, stats: &mut RewriteStats) -> (TLExpr, bool) {
240        stats.nodes_visited += 1;
241
242        // --- Step 1: descend into children to get a rewritten child tree ---
243        let (expr_after_children, children_changed) = self.rewrite_children(expr, stats);
244
245        // --- Step 2: try rules at the current (post-children) node ---
246        for rule in &self.rules {
247            if let Some(replacement) = rule.apply(&expr_after_children) {
248                stats.record_rule(rule.name());
249                // Do NOT recurse again here; the outer fixed-point loop handles it.
250                return (replacement, true);
251            }
252        }
253
254        (expr_after_children, children_changed)
255    }
256
257    /// Reconstruct the node after recursively rewriting children.
258    ///
259    /// Only the "interesting" structural variants (`Not`, `And`, `Or`, and a
260    /// selection of other compound forms) are descended into; leaf nodes are
261    /// returned unchanged.
262    fn rewrite_children(&self, expr: TLExpr, stats: &mut RewriteStats) -> (TLExpr, bool) {
263        match expr {
264            // ---- Structural connectives (primary targets) ----
265            TLExpr::Not(inner) => {
266                let (new_inner, changed) = self.rewrite_expr(*inner, stats);
267                (TLExpr::Not(Box::new(new_inner)), changed)
268            }
269            TLExpr::And(left, right) => {
270                let (new_left, cl) = self.rewrite_expr(*left, stats);
271                let (new_right, cr) = self.rewrite_expr(*right, stats);
272                (
273                    TLExpr::And(Box::new(new_left), Box::new(new_right)),
274                    cl || cr,
275                )
276            }
277            TLExpr::Or(left, right) => {
278                let (new_left, cl) = self.rewrite_expr(*left, stats);
279                let (new_right, cr) = self.rewrite_expr(*right, stats);
280                (
281                    TLExpr::Or(Box::new(new_left), Box::new(new_right)),
282                    cl || cr,
283                )
284            }
285            TLExpr::Imply(ante, cons) => {
286                let (new_ante, ca) = self.rewrite_expr(*ante, stats);
287                let (new_cons, cc) = self.rewrite_expr(*cons, stats);
288                (
289                    TLExpr::Imply(Box::new(new_ante), Box::new(new_cons)),
290                    ca || cc,
291                )
292            }
293            TLExpr::Score(inner) => {
294                let (new_inner, changed) = self.rewrite_expr(*inner, stats);
295                (TLExpr::Score(Box::new(new_inner)), changed)
296            }
297
298            // ---- Arithmetic binary ----
299            TLExpr::Add(l, r) => rewrite_binary!(self, stats, TLExpr::Add, l, r),
300            TLExpr::Sub(l, r) => rewrite_binary!(self, stats, TLExpr::Sub, l, r),
301            TLExpr::Mul(l, r) => rewrite_binary!(self, stats, TLExpr::Mul, l, r),
302            TLExpr::Div(l, r) => rewrite_binary!(self, stats, TLExpr::Div, l, r),
303            TLExpr::Pow(l, r) => rewrite_binary!(self, stats, TLExpr::Pow, l, r),
304            TLExpr::Mod(l, r) => rewrite_binary!(self, stats, TLExpr::Mod, l, r),
305            TLExpr::Min(l, r) => rewrite_binary!(self, stats, TLExpr::Min, l, r),
306            TLExpr::Max(l, r) => rewrite_binary!(self, stats, TLExpr::Max, l, r),
307
308            // ---- Arithmetic unary ----
309            TLExpr::Abs(inner) => rewrite_unary!(self, stats, TLExpr::Abs, inner),
310            TLExpr::Floor(inner) => rewrite_unary!(self, stats, TLExpr::Floor, inner),
311            TLExpr::Ceil(inner) => rewrite_unary!(self, stats, TLExpr::Ceil, inner),
312            TLExpr::Round(inner) => rewrite_unary!(self, stats, TLExpr::Round, inner),
313            TLExpr::Sqrt(inner) => rewrite_unary!(self, stats, TLExpr::Sqrt, inner),
314            TLExpr::Exp(inner) => rewrite_unary!(self, stats, TLExpr::Exp, inner),
315            TLExpr::Log(inner) => rewrite_unary!(self, stats, TLExpr::Log, inner),
316            TLExpr::Sin(inner) => rewrite_unary!(self, stats, TLExpr::Sin, inner),
317            TLExpr::Cos(inner) => rewrite_unary!(self, stats, TLExpr::Cos, inner),
318            TLExpr::Tan(inner) => rewrite_unary!(self, stats, TLExpr::Tan, inner),
319
320            // ---- Comparison binary ----
321            TLExpr::Eq(l, r) => rewrite_binary!(self, stats, TLExpr::Eq, l, r),
322            TLExpr::Lt(l, r) => rewrite_binary!(self, stats, TLExpr::Lt, l, r),
323            TLExpr::Gt(l, r) => rewrite_binary!(self, stats, TLExpr::Gt, l, r),
324            TLExpr::Lte(l, r) => rewrite_binary!(self, stats, TLExpr::Lte, l, r),
325            TLExpr::Gte(l, r) => rewrite_binary!(self, stats, TLExpr::Gte, l, r),
326
327            // ---- Conditional ----
328            TLExpr::IfThenElse {
329                condition,
330                then_branch,
331                else_branch,
332            } => {
333                let (new_cond, cc) = self.rewrite_expr(*condition, stats);
334                let (new_then, ct) = self.rewrite_expr(*then_branch, stats);
335                let (new_else, ce) = self.rewrite_expr(*else_branch, stats);
336                (
337                    TLExpr::IfThenElse {
338                        condition: Box::new(new_cond),
339                        then_branch: Box::new(new_then),
340                        else_branch: Box::new(new_else),
341                    },
342                    cc || ct || ce,
343                )
344            }
345
346            // ---- Quantifiers ----
347            TLExpr::Exists { var, domain, body } => {
348                let (new_body, changed) = self.rewrite_expr(*body, stats);
349                (
350                    TLExpr::Exists {
351                        var,
352                        domain,
353                        body: Box::new(new_body),
354                    },
355                    changed,
356                )
357            }
358            TLExpr::ForAll { var, domain, body } => {
359                let (new_body, changed) = self.rewrite_expr(*body, stats);
360                (
361                    TLExpr::ForAll {
362                        var,
363                        domain,
364                        body: Box::new(new_body),
365                    },
366                    changed,
367                )
368            }
369
370            // ---- Modal logic ----
371            TLExpr::Box(inner) => rewrite_unary!(self, stats, TLExpr::Box, inner),
372            TLExpr::Diamond(inner) => rewrite_unary!(self, stats, TLExpr::Diamond, inner),
373
374            // ---- Temporal logic ----
375            TLExpr::Next(inner) => rewrite_unary!(self, stats, TLExpr::Next, inner),
376            TLExpr::Eventually(inner) => {
377                rewrite_unary!(self, stats, TLExpr::Eventually, inner)
378            }
379            TLExpr::Always(inner) => rewrite_unary!(self, stats, TLExpr::Always, inner),
380            TLExpr::Until { before, after } => {
381                let (nb, cb) = self.rewrite_expr(*before, stats);
382                let (na, ca) = self.rewrite_expr(*after, stats);
383                (
384                    TLExpr::Until {
385                        before: Box::new(nb),
386                        after: Box::new(na),
387                    },
388                    cb || ca,
389                )
390            }
391            TLExpr::Release { released, releaser } => {
392                let (nr, cr) = self.rewrite_expr(*released, stats);
393                let (nl, cl) = self.rewrite_expr(*releaser, stats);
394                (
395                    TLExpr::Release {
396                        released: Box::new(nr),
397                        releaser: Box::new(nl),
398                    },
399                    cr || cl,
400                )
401            }
402            TLExpr::WeakUntil { before, after } => {
403                let (nb, cb) = self.rewrite_expr(*before, stats);
404                let (na, ca) = self.rewrite_expr(*after, stats);
405                (
406                    TLExpr::WeakUntil {
407                        before: Box::new(nb),
408                        after: Box::new(na),
409                    },
410                    cb || ca,
411                )
412            }
413            TLExpr::StrongRelease { released, releaser } => {
414                let (nr, cr) = self.rewrite_expr(*released, stats);
415                let (nl, cl) = self.rewrite_expr(*releaser, stats);
416                (
417                    TLExpr::StrongRelease {
418                        released: Box::new(nr),
419                        releaser: Box::new(nl),
420                    },
421                    cr || cl,
422                )
423            }
424
425            // ---- Fuzzy logic ----
426            TLExpr::TNorm { kind, left, right } => {
427                let (nl, cl) = self.rewrite_expr(*left, stats);
428                let (nr, cr) = self.rewrite_expr(*right, stats);
429                (
430                    TLExpr::TNorm {
431                        kind,
432                        left: Box::new(nl),
433                        right: Box::new(nr),
434                    },
435                    cl || cr,
436                )
437            }
438            TLExpr::TCoNorm { kind, left, right } => {
439                let (nl, cl) = self.rewrite_expr(*left, stats);
440                let (nr, cr) = self.rewrite_expr(*right, stats);
441                (
442                    TLExpr::TCoNorm {
443                        kind,
444                        left: Box::new(nl),
445                        right: Box::new(nr),
446                    },
447                    cl || cr,
448                )
449            }
450            TLExpr::FuzzyNot { kind, expr: inner } => {
451                let (ni, changed) = self.rewrite_expr(*inner, stats);
452                (
453                    TLExpr::FuzzyNot {
454                        kind,
455                        expr: Box::new(ni),
456                    },
457                    changed,
458                )
459            }
460            TLExpr::FuzzyImplication {
461                kind,
462                premise,
463                conclusion,
464            } => {
465                let (np, cp) = self.rewrite_expr(*premise, stats);
466                let (nc, cc) = self.rewrite_expr(*conclusion, stats);
467                (
468                    TLExpr::FuzzyImplication {
469                        kind,
470                        premise: Box::new(np),
471                        conclusion: Box::new(nc),
472                    },
473                    cp || cc,
474                )
475            }
476
477            // ---- Probabilistic ----
478            TLExpr::SoftExists {
479                var,
480                domain,
481                body,
482                temperature,
483            } => {
484                let (nb, changed) = self.rewrite_expr(*body, stats);
485                (
486                    TLExpr::SoftExists {
487                        var,
488                        domain,
489                        body: Box::new(nb),
490                        temperature,
491                    },
492                    changed,
493                )
494            }
495            TLExpr::SoftForAll {
496                var,
497                domain,
498                body,
499                temperature,
500            } => {
501                let (nb, changed) = self.rewrite_expr(*body, stats);
502                (
503                    TLExpr::SoftForAll {
504                        var,
505                        domain,
506                        body: Box::new(nb),
507                        temperature,
508                    },
509                    changed,
510                )
511            }
512            TLExpr::WeightedRule { weight, rule } => {
513                let (nr, changed) = self.rewrite_expr(*rule, stats);
514                (
515                    TLExpr::WeightedRule {
516                        weight,
517                        rule: Box::new(nr),
518                    },
519                    changed,
520                )
521            }
522            TLExpr::ProbabilisticChoice { alternatives } => {
523                let mut changed = false;
524                let new_alts = alternatives
525                    .into_iter()
526                    .map(|(prob, alt_expr)| {
527                        let (ne, c) = self.rewrite_expr(alt_expr, stats);
528                        changed |= c;
529                        (prob, ne)
530                    })
531                    .collect();
532                (
533                    TLExpr::ProbabilisticChoice {
534                        alternatives: new_alts,
535                    },
536                    changed,
537                )
538            }
539
540            // ---- Higher-order ----
541            TLExpr::Lambda {
542                var,
543                var_type,
544                body,
545            } => {
546                let (nb, changed) = self.rewrite_expr(*body, stats);
547                (
548                    TLExpr::Lambda {
549                        var,
550                        var_type,
551                        body: Box::new(nb),
552                    },
553                    changed,
554                )
555            }
556            TLExpr::Apply { function, argument } => {
557                let (nf, cf) = self.rewrite_expr(*function, stats);
558                let (na, ca) = self.rewrite_expr(*argument, stats);
559                (
560                    TLExpr::Apply {
561                        function: Box::new(nf),
562                        argument: Box::new(na),
563                    },
564                    cf || ca,
565                )
566            }
567
568            // ---- Set operations ----
569            TLExpr::SetMembership { element, set } => {
570                let (ne, ce) = self.rewrite_expr(*element, stats);
571                let (ns, cs) = self.rewrite_expr(*set, stats);
572                (
573                    TLExpr::SetMembership {
574                        element: Box::new(ne),
575                        set: Box::new(ns),
576                    },
577                    ce || cs,
578                )
579            }
580            TLExpr::SetUnion { left, right } => {
581                let (nl, cl) = self.rewrite_expr(*left, stats);
582                let (nr, cr) = self.rewrite_expr(*right, stats);
583                (
584                    TLExpr::SetUnion {
585                        left: Box::new(nl),
586                        right: Box::new(nr),
587                    },
588                    cl || cr,
589                )
590            }
591            TLExpr::SetIntersection { left, right } => {
592                let (nl, cl) = self.rewrite_expr(*left, stats);
593                let (nr, cr) = self.rewrite_expr(*right, stats);
594                (
595                    TLExpr::SetIntersection {
596                        left: Box::new(nl),
597                        right: Box::new(nr),
598                    },
599                    cl || cr,
600                )
601            }
602            TLExpr::SetDifference { left, right } => {
603                let (nl, cl) = self.rewrite_expr(*left, stats);
604                let (nr, cr) = self.rewrite_expr(*right, stats);
605                (
606                    TLExpr::SetDifference {
607                        left: Box::new(nl),
608                        right: Box::new(nr),
609                    },
610                    cl || cr,
611                )
612            }
613            TLExpr::SetCardinality { set } => {
614                let (ns, changed) = self.rewrite_expr(*set, stats);
615                (TLExpr::SetCardinality { set: Box::new(ns) }, changed)
616            }
617            TLExpr::SetComprehension {
618                var,
619                domain,
620                condition,
621            } => {
622                let (nc, changed) = self.rewrite_expr(*condition, stats);
623                (
624                    TLExpr::SetComprehension {
625                        var,
626                        domain,
627                        condition: Box::new(nc),
628                    },
629                    changed,
630                )
631            }
632
633            // ---- Let binding ----
634            TLExpr::Let { var, value, body } => {
635                let (nv, cv) = self.rewrite_expr(*value, stats);
636                let (nb, cb) = self.rewrite_expr(*body, stats);
637                (
638                    TLExpr::Let {
639                        var,
640                        value: Box::new(nv),
641                        body: Box::new(nb),
642                    },
643                    cv || cb,
644                )
645            }
646
647            // ---- Aggregate ----
648            TLExpr::Aggregate {
649                op,
650                var,
651                domain,
652                body,
653                group_by,
654            } => {
655                let (nb, changed) = self.rewrite_expr(*body, stats);
656                (
657                    TLExpr::Aggregate {
658                        op,
659                        var,
660                        domain,
661                        body: Box::new(nb),
662                        group_by,
663                    },
664                    changed,
665                )
666            }
667
668            // ---- Leaf nodes (no children to recurse into) ----
669            leaf => (leaf, false),
670        }
671    }
672}
673
674impl Default for RewriteEngine {
675    fn default() -> Self {
676        Self::new()
677    }
678}
679
680// ---------------------------------------------------------------------------
681// Built-in rules
682// ---------------------------------------------------------------------------
683
684// ---- EliminateDoubleNeg ----
685
686/// Eliminate double negation: `Not(Not(x))` → `x`.
687#[derive(Debug, Clone, Default)]
688pub struct EliminateDoubleNeg;
689
690impl RewriteRule for EliminateDoubleNeg {
691    fn name(&self) -> &'static str {
692        "eliminate_double_neg"
693    }
694
695    fn apply(&self, expr: &TLExpr) -> Option<TLExpr> {
696        if let TLExpr::Not(inner) = expr {
697            if let TLExpr::Not(inner_inner) = inner.as_ref() {
698                return Some(*inner_inner.clone());
699            }
700        }
701        None
702    }
703}
704
705// ---- FlattenNestedAnd ----
706
707/// Right-associate a left-nested `And`: `And(And(a,b),c)` → `And(a,And(b,c))`.
708#[derive(Debug, Clone, Default)]
709pub struct FlattenNestedAnd;
710
711impl RewriteRule for FlattenNestedAnd {
712    fn name(&self) -> &'static str {
713        "flatten_nested_and"
714    }
715
716    fn apply(&self, expr: &TLExpr) -> Option<TLExpr> {
717        if let TLExpr::And(left, right) = expr {
718            if let TLExpr::And(a, b) = left.as_ref() {
719                // And(And(a,b), c) → And(a, And(b,c))
720                let new_right = TLExpr::And(b.clone(), right.clone());
721                return Some(TLExpr::And(a.clone(), Box::new(new_right)));
722            }
723        }
724        None
725    }
726}
727
728// ---- FlattenNestedOr ----
729
730/// Right-associate a left-nested `Or`: `Or(Or(a,b),c)` → `Or(a,Or(b,c))`.
731#[derive(Debug, Clone, Default)]
732pub struct FlattenNestedOr;
733
734impl RewriteRule for FlattenNestedOr {
735    fn name(&self) -> &'static str {
736        "flatten_nested_or"
737    }
738
739    fn apply(&self, expr: &TLExpr) -> Option<TLExpr> {
740        if let TLExpr::Or(left, right) = expr {
741            if let TLExpr::Or(a, b) = left.as_ref() {
742                // Or(Or(a,b), c) → Or(a, Or(b,c))
743                let new_right = TLExpr::Or(b.clone(), right.clone());
744                return Some(TLExpr::Or(a.clone(), Box::new(new_right)));
745            }
746        }
747        None
748    }
749}
750
751// ---- EliminateAndTrue ----
752
753/// Identity for conjunction: `And(True, x)` or `And(x, True)` → `x`.
754///
755/// "True" is represented as `TLExpr::Constant(c)` where `c ≈ 1.0`.
756#[derive(Debug, Clone, Default)]
757pub struct EliminateAndTrue;
758
759impl RewriteRule for EliminateAndTrue {
760    fn name(&self) -> &'static str {
761        "eliminate_and_true"
762    }
763
764    fn apply(&self, expr: &TLExpr) -> Option<TLExpr> {
765        if let TLExpr::And(left, right) = expr {
766            if is_true_constant(left) {
767                return Some(*right.clone());
768            }
769            if is_true_constant(right) {
770                return Some(*left.clone());
771            }
772        }
773        None
774    }
775}
776
777// ---- EliminateOrFalse ----
778
779/// Identity for disjunction: `Or(False, x)` or `Or(x, False)` → `x`.
780///
781/// "False" is represented as `TLExpr::Constant(c)` where `c ≈ 0.0`.
782#[derive(Debug, Clone, Default)]
783pub struct EliminateOrFalse;
784
785impl RewriteRule for EliminateOrFalse {
786    fn name(&self) -> &'static str {
787        "eliminate_or_false"
788    }
789
790    fn apply(&self, expr: &TLExpr) -> Option<TLExpr> {
791        if let TLExpr::Or(left, right) = expr {
792            if is_false_constant(left) {
793                return Some(*right.clone());
794            }
795            if is_false_constant(right) {
796                return Some(*left.clone());
797            }
798        }
799        None
800    }
801}
802
803// ---------------------------------------------------------------------------
804// Helpers for constant detection
805// ---------------------------------------------------------------------------
806
807/// Returns `true` if `expr` is a numeric constant close to `1.0` (logical True).
808#[inline]
809fn is_true_constant(expr: &TLExpr) -> bool {
810    if let TLExpr::Constant(v) = expr {
811        (v - 1.0_f64).abs() < f64::EPSILON
812    } else {
813        false
814    }
815}
816
817/// Returns `true` if `expr` is a numeric constant close to `0.0` (logical False).
818#[inline]
819fn is_false_constant(expr: &TLExpr) -> bool {
820    if let TLExpr::Constant(v) = expr {
821        v.abs() < f64::EPSILON
822    } else {
823        false
824    }
825}
826
827// ---------------------------------------------------------------------------
828// Convenience constructor
829// ---------------------------------------------------------------------------
830
831/// Return one boxed instance of each built-in rule.
832///
833/// The order matches the recommended application order:
834/// 1. [`EliminateDoubleNeg`]
835/// 2. [`FlattenNestedAnd`]
836/// 3. [`FlattenNestedOr`]
837/// 4. [`EliminateAndTrue`]
838/// 5. [`EliminateOrFalse`]
839pub fn builtin_rules() -> Vec<Box<dyn RewriteRule>> {
840    vec![
841        Box::new(EliminateDoubleNeg) as Box<dyn RewriteRule>,
842        Box::new(FlattenNestedAnd),
843        Box::new(FlattenNestedOr),
844        Box::new(EliminateAndTrue),
845        Box::new(EliminateOrFalse),
846    ]
847}
848
849// ---------------------------------------------------------------------------
850// Tests
851// ---------------------------------------------------------------------------
852
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{TLExpr, Term};

    // ------------------------------------------------------------------
    // Helper constructors
    // ------------------------------------------------------------------

    fn pred(name: &str) -> TLExpr {
        TLExpr::pred(name, vec![Term::var("x")])
    }

    fn tru() -> TLExpr {
        TLExpr::Constant(1.0)
    }

    fn fal() -> TLExpr {
        TLExpr::Constant(0.0)
    }

    // ------------------------------------------------------------------
    // EliminateDoubleNeg
    // ------------------------------------------------------------------

    #[test]
    fn test_eliminate_double_neg_fires() {
        let inner = pred("p");
        let expr = TLExpr::Not(Box::new(TLExpr::Not(Box::new(inner.clone()))));
        let result = EliminateDoubleNeg.apply(&expr);
        assert_eq!(result, Some(inner));
    }

    #[test]
    fn test_eliminate_double_neg_no_fire() {
        let expr = TLExpr::Not(Box::new(pred("p")));
        assert_eq!(EliminateDoubleNeg.apply(&expr), None);
    }

    #[test]
    fn test_eliminate_double_neg_nested() {
        // Not(Not(Not(x))) — the rule fires at the outer level and removes two
        // layers, leaving Not(x). (The engine applies one rule per node per pass.)
        let x = pred("p");
        let not_x = TLExpr::Not(Box::new(x.clone()));
        let not_not_x = TLExpr::Not(Box::new(not_x.clone()));
        let not_not_not_x = TLExpr::Not(Box::new(not_not_x));
        let result = EliminateDoubleNeg.apply(&not_not_not_x);
        assert_eq!(result, Some(not_x));
    }

    // ------------------------------------------------------------------
    // FlattenNestedAnd
    // ------------------------------------------------------------------

    #[test]
    fn test_flatten_nested_and_fires() {
        let a = pred("a");
        let b = pred("b");
        let c = pred("c");
        let and_ab = TLExpr::And(Box::new(a.clone()), Box::new(b.clone()));
        let and_and_ab_c = TLExpr::And(Box::new(and_ab), Box::new(c.clone()));

        let result = FlattenNestedAnd.apply(&and_and_ab_c);
        let expected = TLExpr::And(Box::new(a), Box::new(TLExpr::And(Box::new(b), Box::new(c))));
        assert_eq!(result, Some(expected));
    }

    #[test]
    fn test_flatten_nested_and_no_fire() {
        let expr = TLExpr::And(Box::new(pred("a")), Box::new(pred("b")));
        assert_eq!(FlattenNestedAnd.apply(&expr), None);
    }

    // ------------------------------------------------------------------
    // FlattenNestedOr
    // ------------------------------------------------------------------

    #[test]
    fn test_flatten_nested_or_fires() {
        let a = pred("a");
        let b = pred("b");
        let c = pred("c");
        let or_ab = TLExpr::Or(Box::new(a.clone()), Box::new(b.clone()));
        let or_or_ab_c = TLExpr::Or(Box::new(or_ab), Box::new(c.clone()));

        let result = FlattenNestedOr.apply(&or_or_ab_c);
        let expected = TLExpr::Or(Box::new(a), Box::new(TLExpr::Or(Box::new(b), Box::new(c))));
        assert_eq!(result, Some(expected));
    }

    #[test]
    fn test_flatten_nested_or_no_fire() {
        // A flat disjunction has nothing to re-associate.
        let expr = TLExpr::Or(Box::new(pred("a")), Box::new(pred("b")));
        assert_eq!(FlattenNestedOr.apply(&expr), None);
    }

    // ------------------------------------------------------------------
    // EliminateAndTrue
    // ------------------------------------------------------------------

    #[test]
    fn test_eliminate_and_true_left() {
        let x = pred("x");
        let expr = TLExpr::And(Box::new(tru()), Box::new(x.clone()));
        let result = EliminateAndTrue.apply(&expr);
        assert_eq!(result, Some(x));
    }

    #[test]
    fn test_eliminate_and_true_right() {
        let x = pred("x");
        let expr = TLExpr::And(Box::new(x.clone()), Box::new(tru()));
        let result = EliminateAndTrue.apply(&expr);
        assert_eq!(result, Some(x));
    }

    #[test]
    fn test_eliminate_and_true_no_fire() {
        let expr = TLExpr::And(Box::new(pred("a")), Box::new(pred("b")));
        assert_eq!(EliminateAndTrue.apply(&expr), None);
    }

    #[test]
    fn test_eliminate_and_true_ignores_non_unit_constant() {
        // 0.5 is not True, so the conjunction must be left untouched.
        let expr = TLExpr::And(Box::new(TLExpr::Constant(0.5)), Box::new(pred("x")));
        assert_eq!(EliminateAndTrue.apply(&expr), None);
    }

    #[test]
    fn test_eliminate_and_true_ignores_or() {
        // True is only the identity of And; Or(True, x) must not be rewritten here.
        let expr = TLExpr::Or(Box::new(tru()), Box::new(pred("x")));
        assert_eq!(EliminateAndTrue.apply(&expr), None);
    }

    // ------------------------------------------------------------------
    // EliminateOrFalse
    // ------------------------------------------------------------------

    #[test]
    fn test_eliminate_or_false_left() {
        let x = pred("x");
        let expr = TLExpr::Or(Box::new(fal()), Box::new(x.clone()));
        let result = EliminateOrFalse.apply(&expr);
        assert_eq!(result, Some(x));
    }

    #[test]
    fn test_eliminate_or_false_right() {
        let x = pred("x");
        let expr = TLExpr::Or(Box::new(x.clone()), Box::new(fal()));
        let result = EliminateOrFalse.apply(&expr);
        assert_eq!(result, Some(x));
    }

    #[test]
    fn test_eliminate_or_false_no_fire() {
        let expr = TLExpr::Or(Box::new(pred("a")), Box::new(pred("b")));
        assert_eq!(EliminateOrFalse.apply(&expr), None);
    }

    #[test]
    fn test_eliminate_or_false_ignores_non_zero_constant() {
        let expr = TLExpr::Or(Box::new(TLExpr::Constant(0.5)), Box::new(pred("x")));
        assert_eq!(EliminateOrFalse.apply(&expr), None);
    }

    #[test]
    fn test_eliminate_or_false_ignores_and() {
        // False is only the identity of Or; And(False, x) must not be rewritten here.
        let expr = TLExpr::And(Box::new(fal()), Box::new(pred("x")));
        assert_eq!(EliminateOrFalse.apply(&expr), None);
    }

    #[test]
    fn test_eliminate_or_false_name() {
        assert_eq!(EliminateOrFalse.name(), "eliminate_or_false");
    }

    // ------------------------------------------------------------------
    // Constant detection helpers
    // ------------------------------------------------------------------

    #[test]
    fn test_is_true_constant() {
        assert!(is_true_constant(&tru()));
        assert!(!is_true_constant(&fal()));
        assert!(!is_true_constant(&TLExpr::Constant(0.5)));
        assert!(!is_true_constant(&TLExpr::Constant(f64::NAN)));
        assert!(!is_true_constant(&pred("p")));
    }

    #[test]
    fn test_is_false_constant() {
        assert!(is_false_constant(&fal()));
        assert!(is_false_constant(&TLExpr::Constant(-0.0)));
        assert!(!is_false_constant(&tru()));
        assert!(!is_false_constant(&TLExpr::Constant(0.5)));
        assert!(!is_false_constant(&TLExpr::Constant(f64::NAN)));
        assert!(!is_false_constant(&pred("p")));
    }

    // ------------------------------------------------------------------
    // RewriteEngine — basic behaviour
    // ------------------------------------------------------------------

    #[test]
    fn test_rewrite_engine_empty_rules() {
        let expr = TLExpr::Not(Box::new(pred("p")));
        let engine = RewriteEngine::new();
        let (result, stats) = engine.rewrite(expr.clone());
        assert_eq!(result, expr);
        assert_eq!(stats.total_rewrites, 0);
    }

    #[test]
    fn test_rewrite_engine_fixed_point() {
        // Not(Not(x)) should reach fixed point after one application
        let x = pred("p");
        let expr = TLExpr::Not(Box::new(TLExpr::Not(Box::new(x.clone()))));
        let engine = RewriteEngine::new().add_rule(Box::new(EliminateDoubleNeg));
        let (result, stats) = engine.rewrite(expr);
        assert_eq!(result, x);
        assert!(stats.fixed_point_reached);
    }

    #[test]
    fn test_rewrite_engine_stats_record() {
        let expr = TLExpr::Not(Box::new(TLExpr::Not(Box::new(pred("p")))));
        let engine = RewriteEngine::new().add_rule(Box::new(EliminateDoubleNeg));
        let (_result, stats) = engine.rewrite(expr);
        assert!(stats.total_rewrites > 0);
    }

    #[test]
    fn test_rewrite_engine_stats_summary_nonempty() {
        let expr = TLExpr::Not(Box::new(TLExpr::Not(Box::new(pred("p")))));
        let engine = RewriteEngine::new().add_rule(Box::new(EliminateDoubleNeg));
        let (_result, stats) = engine.rewrite(expr);
        let summary = stats.summary();
        assert!(!summary.is_empty());
    }

    #[test]
    fn test_rewrite_engine_or_false() {
        // Or(False, x) → x through the full built-in rule set.
        let x = pred("p");
        let expr = TLExpr::Or(Box::new(fal()), Box::new(x.clone()));
        let engine = RewriteEngine::new().add_all_builtin_rules();
        let (result, stats) = engine.rewrite(expr);
        assert_eq!(result, x);
        assert_eq!(stats.total_rewrites, 1);
        assert!(stats.fixed_point_reached);
    }

    // ------------------------------------------------------------------
    // builtin_rules / add_all_builtin_rules
    // ------------------------------------------------------------------

    #[test]
    fn test_builtin_rules_count() {
        assert_eq!(builtin_rules().len(), 5);
    }

    #[test]
    fn test_add_all_builtin_rules() {
        let engine = RewriteEngine::new().add_all_builtin_rules();
        assert_eq!(engine.rules.len(), 5);
    }

    // ------------------------------------------------------------------
    // Edge cases / limits
    // ------------------------------------------------------------------

    #[test]
    fn test_rewrite_engine_iterations_limit() {
        // With max_iterations=1, the engine must not loop infinitely.
        let expr = TLExpr::Not(Box::new(TLExpr::Not(Box::new(pred("p")))));
        let engine = RewriteEngine::new()
            .with_max_iterations(1)
            .add_rule(Box::new(EliminateDoubleNeg));
        let (_result, stats) = engine.rewrite(expr);
        assert!(stats.iterations <= 1);
    }

    // ------------------------------------------------------------------
    // RewriteStats record / tracking
    // ------------------------------------------------------------------

    #[test]
    fn test_rewrite_stats_record_rule() {
        let mut stats = RewriteStats::default();
        stats.record_rule("my_rule");
        assert_eq!(*stats.rules_applied.get("my_rule").unwrap_or(&0), 1);
        assert_eq!(stats.total_rewrites, 1);
    }

    #[test]
    fn test_rewrite_stats_multiple_rules() {
        let mut stats = RewriteStats::default();
        stats.record_rule("rule_a");
        stats.record_rule("rule_b");
        stats.record_rule("rule_a");
        assert_eq!(*stats.rules_applied.get("rule_a").unwrap_or(&0), 2);
        assert_eq!(*stats.rules_applied.get("rule_b").unwrap_or(&0), 1);
        assert_eq!(stats.total_rewrites, 3);
    }

    // ------------------------------------------------------------------
    // Complex cooperation between rules
    // ------------------------------------------------------------------

    #[test]
    fn test_complex_rewrite() {
        // Not(Not(And(x, True))) → x
        // Step 1 (bottom-up): And(x, True) → x    (EliminateAndTrue)
        // Step 2 (outer):     Not(Not(x))  → x    (EliminateDoubleNeg)
        let x = pred("p");
        let and_x_true = TLExpr::And(Box::new(x.clone()), Box::new(tru()));
        let expr = TLExpr::Not(Box::new(TLExpr::Not(Box::new(and_x_true))));

        let engine = RewriteEngine::new().add_all_builtin_rules();
        let (result, stats) = engine.rewrite(expr);
        assert_eq!(result, x);
        assert!(stats.total_rewrites >= 2, "expected at least 2 rewrites");
    }
}