Skip to main content

tensorlogic_compiler/
const_prop.rs

//! Constant Propagation pass for TLExpr trees.
//!
//! This pass performs **compile-time evaluation** of subexpressions whose operands
//! are all `TLExpr::Constant(f64)` values. It complements the [`crate::dead_code`]
//! pass (which handles structural boolean short-circuiting) by focusing on numeric
//! arithmetic, comparison folding, and unary math folding.
//!
//! # Boolean Constants Convention
//!
//! Consistent with the rest of the codebase:
//! - `TLExpr::Constant(1.0)` represents logical **True**
//! - `TLExpr::Constant(0.0)` represents logical **False**
//!
//! Comparison operations that evaluate to a definite truth value produce one of
//! these two constants. This pass also folds `And`, `Or`, `Imply`, and an
//! `IfThenElse` whose condition is constant. In those cases any non-zero
//! constant counts as True, while `Not(c)` folds to `Constant(1.0 - c)`.
//!
//! # Example
//!
//! ```rust
//! use tensorlogic_compiler::const_prop::{ConstantPropagator, ConstPropConfig};
//! use tensorlogic_ir::TLExpr;
//!
//! let propagator = ConstantPropagator::with_default();
//! // Add(Mul(2, 3), 4) → 10
//! let expr = TLExpr::add(
//!     TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0)),
//!     TLExpr::Constant(4.0),
//! );
//! let (result, stats) = propagator.run(expr);
//! assert!(matches!(result, TLExpr::Constant(v) if (v - 10.0).abs() < 1e-12));
//! assert!(stats.arithmetic_folds >= 2);
//! ```
33
34use tensorlogic_ir::TLExpr;
35
36// ────────────────────────────────────────────────────────────────
37// Statistics
38// ────────────────────────────────────────────────────────────────
39
/// Statistics collected during a constant propagation run.
#[derive(Debug, Clone, Default)]
pub struct ConstPropStats {
    /// Number of arithmetic binary folding operations
    /// (e.g. `Add(1,2) → 3`, `Mul(2,3) → 6`).
    pub arithmetic_folds: u64,
    /// Number of comparison folding operations
    /// (e.g. `Lt(1,2) → True`, `Eq(3,3) → True`).
    pub comparison_folds: u64,
    /// Number of boolean/unary constant folding operations
    /// (e.g. `Not(True) → False`, `Abs(-3) → 3`).
    pub boolean_folds: u64,
    /// Total expression nodes counted before the first pass.
    pub nodes_before: u64,
    /// Total expression nodes counted after the last pass.
    pub nodes_after: u64,
    /// Number of passes executed.
    pub passes: u32,
}

impl ConstPropStats {
    /// Total folds across all categories.
    ///
    /// Uses saturating addition, so the result is capped at `u64::MAX`
    /// rather than wrapping on overflow.
    pub fn total_folds(&self) -> u64 {
        self.arithmetic_folds
            .saturating_add(self.comparison_folds)
            .saturating_add(self.boolean_folds)
    }

    /// Percentage of nodes removed: `(before − after) / before × 100`.
    ///
    /// The value lies in `0.0..=100.0` when the tree shrank or kept its size.
    /// If the tree grew (`after > before`), the negative result is clamped to
    /// `0.0`. Returns `0.0` when `nodes_before == 0` to avoid dividing by zero.
    pub fn reduction_pct(&self) -> f64 {
        if self.nodes_before == 0 {
            return 0.0;
        }
        let before = self.nodes_before as f64;
        let after = self.nodes_after as f64;
        (((before - after) / before) * 100.0).max(0.0)
    }

    /// Human-readable one-line summary, for example
    /// `ConstProp: 2 passes, 4/10 nodes kept (60.0% reduction) — 3 arith folds, 1 cmp folds, 2 bool folds`.
    pub fn summary(&self) -> String {
        format!(
            "ConstProp: {} passes, {}/{} nodes kept ({:.1}% reduction) — \
             {} arith folds, {} cmp folds, {} bool folds",
            self.passes,
            self.nodes_after,
            self.nodes_before,
            self.reduction_pct(),
            self.arithmetic_folds,
            self.comparison_folds,
            self.boolean_folds,
        )
    }
}
95
96// ────────────────────────────────────────────────────────────────
97// Configuration
98// ────────────────────────────────────────────────────────────────
99
/// Tuning knobs for the constant propagation pass.
///
/// Each `fold_*` flag independently enables one category of folding. The
/// remaining fields bound the fixed-point iteration and control float equality.
#[derive(Debug, Clone)]
pub struct ConstPropConfig {
    /// When `true`, binary arithmetic nodes whose operands are both constants are folded.
    pub fold_arithmetic: bool,
    /// When `true`, comparison nodes whose operands are both constants are folded.
    pub fold_comparisons: bool,
    /// When `true`, unary operations (Abs, Floor, Ceil, Round, Sqrt, Exp, Log, Sin,
    /// Cos, Tan, Not) on constants are folded, along with the logical connectives
    /// `And`/`Or`/`Imply` and `IfThenElse` with a constant condition.
    pub fold_boolean: bool,
    /// Upper bound on the number of convergence passes.
    pub max_passes: u32,
    /// Absolute tolerance applied when testing floats for equality (`Eq` comparison).
    pub float_tolerance: f64,
}

impl Default for ConstPropConfig {
    /// Enables every folding category, allows at most 20 passes, and uses an
    /// absolute float tolerance of `1e-12`.
    fn default() -> ConstPropConfig {
        let enable_all = true;
        ConstPropConfig {
            max_passes: 20,
            float_tolerance: 1e-12,
            fold_arithmetic: enable_all,
            fold_comparisons: enable_all,
            fold_boolean: enable_all,
        }
    }
}
126
127// ────────────────────────────────────────────────────────────────
128// ConstantPropagator
129// ────────────────────────────────────────────────────────────────
130
131/// The constant propagation compiler pass.
132///
133/// Performs a bottom-up sweep over a [`TLExpr`] tree, evaluating subexpressions
134/// at compile time when all operands are `TLExpr::Constant(f64)` values.
135/// Runs to a fixed point (i.e. repeated until no further changes occur) or
136/// until `config.max_passes` is reached.
137pub struct ConstantPropagator {
138    config: ConstPropConfig,
139}
140
141impl ConstantPropagator {
142    /// Create a new propagator with the supplied configuration.
143    pub fn new(config: ConstPropConfig) -> Self {
144        Self { config }
145    }
146
147    /// Create a new propagator with default configuration.
148    pub fn with_default() -> Self {
149        Self::new(ConstPropConfig::default())
150    }
151
152    /// Run constant propagation to a fixed point.
153    ///
154    /// Returns `(simplified_expr, stats)`.
155    pub fn run(&self, expr: TLExpr) -> (TLExpr, ConstPropStats) {
156        let mut stats = ConstPropStats {
157            nodes_before: Self::count_nodes(&expr),
158            ..Default::default()
159        };
160
161        let mut current = expr;
162        let mut pass_count = 0u32;
163
164        loop {
165            if pass_count >= self.config.max_passes {
166                break;
167            }
168            let (next, changed) = self.run_pass(current, &mut stats);
169            pass_count = pass_count.saturating_add(1);
170            current = next;
171            if !changed {
172                break;
173            }
174        }
175
176        stats.passes = pass_count;
177        stats.nodes_after = Self::count_nodes(&current);
178        (current, stats)
179    }
180
181    /// Execute one propagation pass over the entire tree.
182    ///
183    /// Returns `(new_expr, changed)` where `changed` indicates whether any
184    /// fold occurred during this pass.
185    fn run_pass(&self, expr: TLExpr, stats: &mut ConstPropStats) -> (TLExpr, bool) {
186        self.propagate(expr, stats)
187    }
188
189    /// Recursive bottom-up propagation.
190    ///
191    /// First recurse into children; then attempt to fold the current node.
192    /// Returns `(new_expr, changed)`.
193    fn propagate(&self, expr: TLExpr, stats: &mut ConstPropStats) -> (TLExpr, bool) {
194        match expr {
195            // ── Leaf nodes — nothing to fold ──────────────────────────────
196            TLExpr::Constant(_)
197            | TLExpr::Pred { .. }
198            | TLExpr::EmptySet
199            | TLExpr::AllDifferent { .. }
200            | TLExpr::Nominal { .. }
201            | TLExpr::Abducible { .. } => (expr, false),
202
203            // ── Arithmetic binary ops ─────────────────────────────────────
204            TLExpr::Add(lhs, rhs) => self.fold_binary_arith("Add", *lhs, *rhs, stats, TLExpr::Add),
205            TLExpr::Sub(lhs, rhs) => self.fold_binary_arith("Sub", *lhs, *rhs, stats, TLExpr::Sub),
206            TLExpr::Mul(lhs, rhs) => self.fold_binary_arith("Mul", *lhs, *rhs, stats, TLExpr::Mul),
207            TLExpr::Div(lhs, rhs) => self.fold_binary_arith("Div", *lhs, *rhs, stats, TLExpr::Div),
208            TLExpr::Pow(lhs, rhs) => self.fold_binary_arith("Pow", *lhs, *rhs, stats, TLExpr::Pow),
209            TLExpr::Mod(lhs, rhs) => self.fold_binary_arith("Mod", *lhs, *rhs, stats, TLExpr::Mod),
210            TLExpr::Min(lhs, rhs) => self.fold_binary_arith("Min", *lhs, *rhs, stats, TLExpr::Min),
211            TLExpr::Max(lhs, rhs) => self.fold_binary_arith("Max", *lhs, *rhs, stats, TLExpr::Max),
212
213            // ── Comparison ops ────────────────────────────────────────────
214            TLExpr::Eq(lhs, rhs) => self.fold_binary_cmp("Eq", *lhs, *rhs, stats, TLExpr::Eq),
215            TLExpr::Lt(lhs, rhs) => self.fold_binary_cmp("Lt", *lhs, *rhs, stats, TLExpr::Lt),
216            TLExpr::Gt(lhs, rhs) => self.fold_binary_cmp("Gt", *lhs, *rhs, stats, TLExpr::Gt),
217            TLExpr::Lte(lhs, rhs) => self.fold_binary_cmp("Lte", *lhs, *rhs, stats, TLExpr::Lte),
218            TLExpr::Gte(lhs, rhs) => self.fold_binary_cmp("Gte", *lhs, *rhs, stats, TLExpr::Gte),
219
220            // ── Unary math ops ────────────────────────────────────────────
221            TLExpr::Abs(inner) => self.fold_unary_math("Abs", *inner, stats, TLExpr::Abs),
222            TLExpr::Floor(inner) => self.fold_unary_math("Floor", *inner, stats, TLExpr::Floor),
223            TLExpr::Ceil(inner) => self.fold_unary_math("Ceil", *inner, stats, TLExpr::Ceil),
224            TLExpr::Round(inner) => self.fold_unary_math("Round", *inner, stats, TLExpr::Round),
225            TLExpr::Sqrt(inner) => self.fold_unary_math("Sqrt", *inner, stats, TLExpr::Sqrt),
226            TLExpr::Exp(inner) => self.fold_unary_math("Exp", *inner, stats, TLExpr::Exp),
227            TLExpr::Log(inner) => self.fold_unary_math("Log", *inner, stats, TLExpr::Log),
228            TLExpr::Sin(inner) => self.fold_unary_math("Sin", *inner, stats, TLExpr::Sin),
229            TLExpr::Cos(inner) => self.fold_unary_math("Cos", *inner, stats, TLExpr::Cos),
230            TLExpr::Tan(inner) => self.fold_unary_math("Tan", *inner, stats, TLExpr::Tan),
231
232            // ── Boolean / logical unary ───────────────────────────────────
233            TLExpr::Not(inner) => {
234                let (new_inner, child_changed) = self.propagate(*inner, stats);
235                if self.config.fold_boolean {
236                    if let Some(v) = Self::as_constant(&new_inner) {
237                        // Not(True) → False, Not(False) → True
238                        // More generally, Not(x) → Constant(1 - x) when x is a constant
239                        let result = TLExpr::Constant(1.0 - v);
240                        stats.boolean_folds = stats.boolean_folds.saturating_add(1);
241                        return (result, true);
242                    }
243                }
244                (TLExpr::Not(Box::new(new_inner)), child_changed)
245            }
246
247            // ── Boolean binary ops ────────────────────────────────────────
248            TLExpr::And(lhs, rhs) => {
249                let (new_lhs, cl) = self.propagate(*lhs, stats);
250                let (new_rhs, cr) = self.propagate(*rhs, stats);
251                if self.config.fold_boolean {
252                    if let (Some(a), Some(b)) =
253                        (Self::as_constant(&new_lhs), Self::as_constant(&new_rhs))
254                    {
255                        // Treat both as booleans (non-zero = true)
256                        let result = if a != 0.0 && b != 0.0 { 1.0 } else { 0.0 };
257                        stats.boolean_folds = stats.boolean_folds.saturating_add(1);
258                        return (TLExpr::Constant(result), true);
259                    }
260                }
261                (TLExpr::And(Box::new(new_lhs), Box::new(new_rhs)), cl || cr)
262            }
263            TLExpr::Or(lhs, rhs) => {
264                let (new_lhs, cl) = self.propagate(*lhs, stats);
265                let (new_rhs, cr) = self.propagate(*rhs, stats);
266                if self.config.fold_boolean {
267                    if let (Some(a), Some(b)) =
268                        (Self::as_constant(&new_lhs), Self::as_constant(&new_rhs))
269                    {
270                        let result = if a != 0.0 || b != 0.0 { 1.0 } else { 0.0 };
271                        stats.boolean_folds = stats.boolean_folds.saturating_add(1);
272                        return (TLExpr::Constant(result), true);
273                    }
274                }
275                (TLExpr::Or(Box::new(new_lhs), Box::new(new_rhs)), cl || cr)
276            }
277            TLExpr::Imply(premise, conclusion) => {
278                let (new_p, cp) = self.propagate(*premise, stats);
279                let (new_c, cc) = self.propagate(*conclusion, stats);
280                if self.config.fold_boolean {
281                    if let (Some(a), Some(b)) =
282                        (Self::as_constant(&new_p), Self::as_constant(&new_c))
283                    {
284                        // a → b  ≡  ¬a ∨ b
285                        let result = if a == 0.0 || b != 0.0 { 1.0 } else { 0.0 };
286                        stats.boolean_folds = stats.boolean_folds.saturating_add(1);
287                        return (TLExpr::Constant(result), true);
288                    }
289                }
290                (TLExpr::Imply(Box::new(new_p), Box::new(new_c)), cp || cc)
291            }
292
293            // ── If-then-else ──────────────────────────────────────────────
294            TLExpr::IfThenElse {
295                condition,
296                then_branch,
297                else_branch,
298            } => {
299                let (new_cond, cc) = self.propagate(*condition, stats);
300                let (new_then, ct) = self.propagate(*then_branch, stats);
301                let (new_else, ce) = self.propagate(*else_branch, stats);
302                if self.config.fold_boolean {
303                    if let Some(v) = Self::as_constant(&new_cond) {
304                        if v != 0.0 {
305                            // condition is truthy → take then branch
306                            stats.boolean_folds = stats.boolean_folds.saturating_add(1);
307                            return (new_then, true);
308                        } else {
309                            // condition is falsy → take else branch
310                            stats.boolean_folds = stats.boolean_folds.saturating_add(1);
311                            return (new_else, true);
312                        }
313                    }
314                }
315                let changed = cc || ct || ce;
316                (
317                    TLExpr::IfThenElse {
318                        condition: Box::new(new_cond),
319                        then_branch: Box::new(new_then),
320                        else_branch: Box::new(new_else),
321                    },
322                    changed,
323                )
324            }
325
326            // ── Score (unary passthrough) ─────────────────────────────────
327            TLExpr::Score(inner) => {
328                let (new_inner, changed) = self.propagate(*inner, stats);
329                (TLExpr::Score(Box::new(new_inner)), changed)
330            }
331
332            // ── Quantifiers — recurse into body ──────────────────────────
333            TLExpr::Exists { var, domain, body } => {
334                let (new_body, changed) = self.propagate(*body, stats);
335                (
336                    TLExpr::Exists {
337                        var,
338                        domain,
339                        body: Box::new(new_body),
340                    },
341                    changed,
342                )
343            }
344            TLExpr::ForAll { var, domain, body } => {
345                let (new_body, changed) = self.propagate(*body, stats);
346                (
347                    TLExpr::ForAll {
348                        var,
349                        domain,
350                        body: Box::new(new_body),
351                    },
352                    changed,
353                )
354            }
355
356            // ── Let binding ───────────────────────────────────────────────
357            TLExpr::Let { var, value, body } => {
358                let (new_value, cv) = self.propagate(*value, stats);
359                let (new_body, cb) = self.propagate(*body, stats);
360                (
361                    TLExpr::Let {
362                        var,
363                        value: Box::new(new_value),
364                        body: Box::new(new_body),
365                    },
366                    cv || cb,
367                )
368            }
369
370            // ── Aggregate ─────────────────────────────────────────────────
371            TLExpr::Aggregate {
372                op,
373                var,
374                domain,
375                body,
376                group_by,
377            } => {
378                let (new_body, changed) = self.propagate(*body, stats);
379                (
380                    TLExpr::Aggregate {
381                        op,
382                        var,
383                        domain,
384                        body: Box::new(new_body),
385                        group_by,
386                    },
387                    changed,
388                )
389            }
390
391            // ── Modal / temporal / fuzzy — recurse ────────────────────────
392            TLExpr::Box(inner) => {
393                let (n, c) = self.propagate(*inner, stats);
394                (TLExpr::Box(Box::new(n)), c)
395            }
396            TLExpr::Diamond(inner) => {
397                let (n, c) = self.propagate(*inner, stats);
398                (TLExpr::Diamond(Box::new(n)), c)
399            }
400            TLExpr::Next(inner) => {
401                let (n, c) = self.propagate(*inner, stats);
402                (TLExpr::Next(Box::new(n)), c)
403            }
404            TLExpr::Eventually(inner) => {
405                let (n, c) = self.propagate(*inner, stats);
406                (TLExpr::Eventually(Box::new(n)), c)
407            }
408            TLExpr::Always(inner) => {
409                let (n, c) = self.propagate(*inner, stats);
410                (TLExpr::Always(Box::new(n)), c)
411            }
412            TLExpr::Until { before, after } => {
413                let (nb, cb) = self.propagate(*before, stats);
414                let (na, ca) = self.propagate(*after, stats);
415                (
416                    TLExpr::Until {
417                        before: Box::new(nb),
418                        after: Box::new(na),
419                    },
420                    cb || ca,
421                )
422            }
423            TLExpr::Release { released, releaser } => {
424                let (nr, cr) = self.propagate(*released, stats);
425                let (nl, cl) = self.propagate(*releaser, stats);
426                (
427                    TLExpr::Release {
428                        released: Box::new(nr),
429                        releaser: Box::new(nl),
430                    },
431                    cr || cl,
432                )
433            }
434            TLExpr::WeakUntil { before, after } => {
435                let (nb, cb) = self.propagate(*before, stats);
436                let (na, ca) = self.propagate(*after, stats);
437                (
438                    TLExpr::WeakUntil {
439                        before: Box::new(nb),
440                        after: Box::new(na),
441                    },
442                    cb || ca,
443                )
444            }
445            TLExpr::StrongRelease { released, releaser } => {
446                let (nr, cr) = self.propagate(*released, stats);
447                let (nl, cl) = self.propagate(*releaser, stats);
448                (
449                    TLExpr::StrongRelease {
450                        released: Box::new(nr),
451                        releaser: Box::new(nl),
452                    },
453                    cr || cl,
454                )
455            }
456
457            TLExpr::TNorm { kind, left, right } => {
458                let (nl, cl) = self.propagate(*left, stats);
459                let (nr, cr) = self.propagate(*right, stats);
460                (
461                    TLExpr::TNorm {
462                        kind,
463                        left: Box::new(nl),
464                        right: Box::new(nr),
465                    },
466                    cl || cr,
467                )
468            }
469            TLExpr::TCoNorm { kind, left, right } => {
470                let (nl, cl) = self.propagate(*left, stats);
471                let (nr, cr) = self.propagate(*right, stats);
472                (
473                    TLExpr::TCoNorm {
474                        kind,
475                        left: Box::new(nl),
476                        right: Box::new(nr),
477                    },
478                    cl || cr,
479                )
480            }
481            TLExpr::FuzzyNot { kind, expr: inner } => {
482                let (n, c) = self.propagate(*inner, stats);
483                (
484                    TLExpr::FuzzyNot {
485                        kind,
486                        expr: Box::new(n),
487                    },
488                    c,
489                )
490            }
491            TLExpr::FuzzyImplication {
492                kind,
493                premise,
494                conclusion,
495            } => {
496                let (np, cp) = self.propagate(*premise, stats);
497                let (nc, cc) = self.propagate(*conclusion, stats);
498                (
499                    TLExpr::FuzzyImplication {
500                        kind,
501                        premise: Box::new(np),
502                        conclusion: Box::new(nc),
503                    },
504                    cp || cc,
505                )
506            }
507
508            // ── Probabilistic ─────────────────────────────────────────────
509            TLExpr::SoftExists {
510                var,
511                domain,
512                body,
513                temperature,
514            } => {
515                let (nb, changed) = self.propagate(*body, stats);
516                (
517                    TLExpr::SoftExists {
518                        var,
519                        domain,
520                        body: Box::new(nb),
521                        temperature,
522                    },
523                    changed,
524                )
525            }
526            TLExpr::SoftForAll {
527                var,
528                domain,
529                body,
530                temperature,
531            } => {
532                let (nb, changed) = self.propagate(*body, stats);
533                (
534                    TLExpr::SoftForAll {
535                        var,
536                        domain,
537                        body: Box::new(nb),
538                        temperature,
539                    },
540                    changed,
541                )
542            }
543            TLExpr::WeightedRule { weight, rule } => {
544                let (nr, changed) = self.propagate(*rule, stats);
545                (
546                    TLExpr::WeightedRule {
547                        weight,
548                        rule: Box::new(nr),
549                    },
550                    changed,
551                )
552            }
553            TLExpr::ProbabilisticChoice { alternatives } => {
554                let mut changed = false;
555                let new_alts: Vec<(f64, TLExpr)> = alternatives
556                    .into_iter()
557                    .map(|(p, e)| {
558                        let (ne, c) = self.propagate(e, stats);
559                        if c {
560                            changed = true;
561                        }
562                        (p, ne)
563                    })
564                    .collect();
565                (
566                    TLExpr::ProbabilisticChoice {
567                        alternatives: new_alts,
568                    },
569                    changed,
570                )
571            }
572
573            // ── Higher-order ──────────────────────────────────────────────
574            TLExpr::Lambda {
575                var,
576                var_type,
577                body,
578            } => {
579                let (nb, changed) = self.propagate(*body, stats);
580                (
581                    TLExpr::Lambda {
582                        var,
583                        var_type,
584                        body: Box::new(nb),
585                    },
586                    changed,
587                )
588            }
589            TLExpr::Apply { function, argument } => {
590                let (nf, cf) = self.propagate(*function, stats);
591                let (na, ca) = self.propagate(*argument, stats);
592                (
593                    TLExpr::Apply {
594                        function: Box::new(nf),
595                        argument: Box::new(na),
596                    },
597                    cf || ca,
598                )
599            }
600
601            // ── Set operations ────────────────────────────────────────────
602            TLExpr::SetMembership { element, set } => {
603                let (ne, ce) = self.propagate(*element, stats);
604                let (ns, cs) = self.propagate(*set, stats);
605                (
606                    TLExpr::SetMembership {
607                        element: Box::new(ne),
608                        set: Box::new(ns),
609                    },
610                    ce || cs,
611                )
612            }
613            TLExpr::SetUnion { left, right } => {
614                let (nl, cl) = self.propagate(*left, stats);
615                let (nr, cr) = self.propagate(*right, stats);
616                (
617                    TLExpr::SetUnion {
618                        left: Box::new(nl),
619                        right: Box::new(nr),
620                    },
621                    cl || cr,
622                )
623            }
624            TLExpr::SetIntersection { left, right } => {
625                let (nl, cl) = self.propagate(*left, stats);
626                let (nr, cr) = self.propagate(*right, stats);
627                (
628                    TLExpr::SetIntersection {
629                        left: Box::new(nl),
630                        right: Box::new(nr),
631                    },
632                    cl || cr,
633                )
634            }
635            TLExpr::SetDifference { left, right } => {
636                let (nl, cl) = self.propagate(*left, stats);
637                let (nr, cr) = self.propagate(*right, stats);
638                (
639                    TLExpr::SetDifference {
640                        left: Box::new(nl),
641                        right: Box::new(nr),
642                    },
643                    cl || cr,
644                )
645            }
646            TLExpr::SetCardinality { set } => {
647                let (ns, changed) = self.propagate(*set, stats);
648                (TLExpr::SetCardinality { set: Box::new(ns) }, changed)
649            }
650            TLExpr::SetComprehension {
651                var,
652                domain,
653                condition,
654            } => {
655                let (nc, changed) = self.propagate(*condition, stats);
656                (
657                    TLExpr::SetComprehension {
658                        var,
659                        domain,
660                        condition: Box::new(nc),
661                    },
662                    changed,
663                )
664            }
665
666            // ── Counting quantifiers ──────────────────────────────────────
667            TLExpr::CountingExists {
668                var,
669                domain,
670                body,
671                min_count,
672            } => {
673                let (nb, changed) = self.propagate(*body, stats);
674                (
675                    TLExpr::CountingExists {
676                        var,
677                        domain,
678                        body: Box::new(nb),
679                        min_count,
680                    },
681                    changed,
682                )
683            }
684            TLExpr::CountingForAll {
685                var,
686                domain,
687                body,
688                min_count,
689            } => {
690                let (nb, changed) = self.propagate(*body, stats);
691                (
692                    TLExpr::CountingForAll {
693                        var,
694                        domain,
695                        body: Box::new(nb),
696                        min_count,
697                    },
698                    changed,
699                )
700            }
701            TLExpr::ExactCount {
702                var,
703                domain,
704                body,
705                count,
706            } => {
707                let (nb, changed) = self.propagate(*body, stats);
708                (
709                    TLExpr::ExactCount {
710                        var,
711                        domain,
712                        body: Box::new(nb),
713                        count,
714                    },
715                    changed,
716                )
717            }
718            TLExpr::Majority { var, domain, body } => {
719                let (nb, changed) = self.propagate(*body, stats);
720                (
721                    TLExpr::Majority {
722                        var,
723                        domain,
724                        body: Box::new(nb),
725                    },
726                    changed,
727                )
728            }
729
730            // ── Fixed-point ───────────────────────────────────────────────
731            TLExpr::LeastFixpoint { var, body } => {
732                let (nb, changed) = self.propagate(*body, stats);
733                (
734                    TLExpr::LeastFixpoint {
735                        var,
736                        body: Box::new(nb),
737                    },
738                    changed,
739                )
740            }
741            TLExpr::GreatestFixpoint { var, body } => {
742                let (nb, changed) = self.propagate(*body, stats);
743                (
744                    TLExpr::GreatestFixpoint {
745                        var,
746                        body: Box::new(nb),
747                    },
748                    changed,
749                )
750            }
751
752            // ── Hybrid logic ──────────────────────────────────────────────
753            TLExpr::At { nominal, formula } => {
754                let (nf, changed) = self.propagate(*formula, stats);
755                (
756                    TLExpr::At {
757                        nominal,
758                        formula: Box::new(nf),
759                    },
760                    changed,
761                )
762            }
763            TLExpr::Somewhere { formula } => {
764                let (nf, changed) = self.propagate(*formula, stats);
765                (
766                    TLExpr::Somewhere {
767                        formula: Box::new(nf),
768                    },
769                    changed,
770                )
771            }
772            TLExpr::Everywhere { formula } => {
773                let (nf, changed) = self.propagate(*formula, stats);
774                (
775                    TLExpr::Everywhere {
776                        formula: Box::new(nf),
777                    },
778                    changed,
779                )
780            }
781
782            // ── Constraint programming ────────────────────────────────────
783            TLExpr::GlobalCardinality {
784                variables,
785                values,
786                min_occurrences,
787                max_occurrences,
788            } => {
789                let mut changed = false;
790                let new_values: Vec<TLExpr> = values
791                    .into_iter()
792                    .map(|e| {
793                        let (ne, c) = self.propagate(e, stats);
794                        if c {
795                            changed = true;
796                        }
797                        ne
798                    })
799                    .collect();
800                (
801                    TLExpr::GlobalCardinality {
802                        variables,
803                        values: new_values,
804                        min_occurrences,
805                        max_occurrences,
806                    },
807                    changed,
808                )
809            }
810
811            // ── Abductive reasoning ───────────────────────────────────────
812            TLExpr::Explain { formula } => {
813                let (nf, changed) = self.propagate(*formula, stats);
814                (
815                    TLExpr::Explain {
816                        formula: Box::new(nf),
817                    },
818                    changed,
819                )
820            }
821
822            // ── Pattern matching ──────────────────────────────────────────
823            TLExpr::SymbolLiteral(_) => (expr, false),
824
825            TLExpr::Match { scrutinee, arms } => {
826                let (new_scrutinee, sc) = self.propagate(*scrutinee, stats);
827                let mut any_changed = sc;
828                let new_arms = arms
829                    .into_iter()
830                    .map(|(pat, body)| {
831                        let (new_body, bc) = self.propagate(*body, stats);
832                        if bc {
833                            any_changed = true;
834                        }
835                        (pat, Box::new(new_body))
836                    })
837                    .collect();
838                (
839                    TLExpr::Match {
840                        scrutinee: Box::new(new_scrutinee),
841                        arms: new_arms,
842                    },
843                    any_changed,
844                )
845            }
846        }
847    }
848
849    // ── Helpers ───────────────────────────────────────────────────────────
850
851    /// Extract the numeric value if `expr` is a `Constant`, otherwise `None`.
852    pub fn as_constant(expr: &TLExpr) -> Option<f64> {
853        if let TLExpr::Constant(v) = expr {
854            Some(*v)
855        } else {
856            None
857        }
858    }
859
860    /// Recurse into both children of a binary arithmetic op, then try to fold.
861    ///
862    /// `ctor` is used to reconstruct the node when folding is not possible.
863    fn fold_binary_arith(
864        &self,
865        op_name: &str,
866        lhs: TLExpr,
867        rhs: TLExpr,
868        stats: &mut ConstPropStats,
869        ctor: fn(Box<TLExpr>, Box<TLExpr>) -> TLExpr,
870    ) -> (TLExpr, bool) {
871        let (new_lhs, cl) = self.propagate(lhs, stats);
872        let (new_rhs, cr) = self.propagate(rhs, stats);
873        let child_changed = cl || cr;
874
875        if self.config.fold_arithmetic {
876            if let (Some(a), Some(b)) = (Self::as_constant(&new_lhs), Self::as_constant(&new_rhs)) {
877                if let Some(folded) = self.fold_arith_binary(op_name, a, b, stats) {
878                    return (folded, true);
879                }
880            }
881        }
882        (ctor(Box::new(new_lhs), Box::new(new_rhs)), child_changed)
883    }
884
885    /// Recurse into both children of a comparison op, then try to fold.
886    fn fold_binary_cmp(
887        &self,
888        op_name: &str,
889        lhs: TLExpr,
890        rhs: TLExpr,
891        stats: &mut ConstPropStats,
892        ctor: fn(Box<TLExpr>, Box<TLExpr>) -> TLExpr,
893    ) -> (TLExpr, bool) {
894        let (new_lhs, cl) = self.propagate(lhs, stats);
895        let (new_rhs, cr) = self.propagate(rhs, stats);
896        let child_changed = cl || cr;
897
898        if self.config.fold_comparisons {
899            if let (Some(a), Some(b)) = (Self::as_constant(&new_lhs), Self::as_constant(&new_rhs)) {
900                if let Some(folded) = self.fold_comparison(op_name, a, b, stats) {
901                    return (folded, true);
902                }
903            }
904        }
905        (ctor(Box::new(new_lhs), Box::new(new_rhs)), child_changed)
906    }
907
908    /// Recurse into the child of a unary math op, then try to fold.
909    fn fold_unary_math(
910        &self,
911        op_name: &str,
912        inner: TLExpr,
913        stats: &mut ConstPropStats,
914        ctor: fn(Box<TLExpr>) -> TLExpr,
915    ) -> (TLExpr, bool) {
916        let (new_inner, child_changed) = self.propagate(inner, stats);
917
918        if self.config.fold_boolean {
919            if let Some(v) = Self::as_constant(&new_inner) {
920                let maybe_result = Self::fold_unary_math_value(op_name, v);
921                if let Some(result) = maybe_result {
922                    stats.boolean_folds = stats.boolean_folds.saturating_add(1);
923                    return (TLExpr::Constant(result), true);
924                }
925            }
926        }
927        (ctor(Box::new(new_inner)), child_changed)
928    }
929
930    /// Evaluate a unary math function on a constant, returning `None` on error
931    /// (e.g. `Log` of a negative number, `Sqrt` of negative).
932    fn fold_unary_math_value(op_name: &str, v: f64) -> Option<f64> {
933        match op_name {
934            "Abs" => Some(v.abs()),
935            "Floor" => Some(v.floor()),
936            "Ceil" => Some(v.ceil()),
937            "Round" => Some(v.round()),
938            "Sqrt" => {
939                if v < 0.0 {
940                    None
941                } else {
942                    Some(v.sqrt())
943                }
944            }
945            "Exp" => Some(v.exp()),
946            "Log" => {
947                if v <= 0.0 {
948                    None
949                } else {
950                    Some(v.ln())
951                }
952            }
953            "Sin" => Some(v.sin()),
954            "Cos" => Some(v.cos()),
955            "Tan" => Some(v.tan()),
956            _ => None,
957        }
958    }
959
960    /// Try to evaluate an arithmetic binary operation on two constant values.
961    ///
962    /// Returns `None` for division-by-zero, Mod-by-zero, and unknown ops.
963    fn fold_arith_binary(
964        &self,
965        op_name: &str,
966        lhs: f64,
967        rhs: f64,
968        stats: &mut ConstPropStats,
969    ) -> Option<TLExpr> {
970        let result = match op_name {
971            "Add" => lhs + rhs,
972            "Sub" => lhs - rhs,
973            "Mul" => lhs * rhs,
974            "Div" => {
975                if rhs.abs() < f64::EPSILON {
976                    return None; // division by zero — don't fold
977                }
978                lhs / rhs
979            }
980            "Pow" => lhs.powf(rhs),
981            "Mod" => {
982                if rhs.abs() < f64::EPSILON {
983                    return None; // mod by zero — don't fold
984                }
985                lhs % rhs
986            }
987            "Min" => lhs.min(rhs),
988            "Max" => lhs.max(rhs),
989            _ => return None,
990        };
991
992        if result.is_finite() || result.is_infinite() {
993            // We allow infinite results (e.g. 1/0 for large divisors, Pow overflow)
994            // but we already guard against div/mod by zero above.
995            stats.arithmetic_folds = stats.arithmetic_folds.saturating_add(1);
996            Some(TLExpr::Constant(result))
997        } else {
998            // NaN — don't fold
999            None
1000        }
1001    }
1002
1003    /// Try to evaluate a comparison operation on two constants, returning a
1004    /// boolean constant (`Constant(1.0)` = True, `Constant(0.0)` = False).
1005    fn fold_comparison(
1006        &self,
1007        op_name: &str,
1008        lhs: f64,
1009        rhs: f64,
1010        stats: &mut ConstPropStats,
1011    ) -> Option<TLExpr> {
1012        let bool_result: bool = match op_name {
1013            "Eq" => (lhs - rhs).abs() <= self.config.float_tolerance,
1014            "Lt" => lhs < rhs,
1015            "Gt" => lhs > rhs,
1016            "Lte" => lhs <= rhs || (lhs - rhs).abs() <= self.config.float_tolerance,
1017            "Gte" => lhs >= rhs || (lhs - rhs).abs() <= self.config.float_tolerance,
1018            _ => return None,
1019        };
1020        stats.comparison_folds = stats.comparison_folds.saturating_add(1);
1021        Some(TLExpr::Constant(if bool_result { 1.0 } else { 0.0 }))
1022    }
1023
1024    /// Count the total number of nodes in an expression tree.
1025    pub fn count_nodes(expr: &TLExpr) -> u64 {
1026        match expr {
1027            // Leaf nodes
1028            TLExpr::Constant(_)
1029            | TLExpr::EmptySet
1030            | TLExpr::AllDifferent { .. }
1031            | TLExpr::Nominal { .. }
1032            | TLExpr::Abducible { .. }
1033            | TLExpr::Pred { .. } => 1,
1034
1035            // Unary
1036            TLExpr::Not(e)
1037            | TLExpr::Score(e)
1038            | TLExpr::Abs(e)
1039            | TLExpr::Floor(e)
1040            | TLExpr::Ceil(e)
1041            | TLExpr::Round(e)
1042            | TLExpr::Sqrt(e)
1043            | TLExpr::Exp(e)
1044            | TLExpr::Log(e)
1045            | TLExpr::Sin(e)
1046            | TLExpr::Cos(e)
1047            | TLExpr::Tan(e)
1048            | TLExpr::Box(e)
1049            | TLExpr::Diamond(e)
1050            | TLExpr::Next(e)
1051            | TLExpr::Eventually(e)
1052            | TLExpr::Always(e)
1053            | TLExpr::FuzzyNot { expr: e, .. }
1054            | TLExpr::Somewhere { formula: e }
1055            | TLExpr::Everywhere { formula: e }
1056            | TLExpr::SetCardinality { set: e }
1057            | TLExpr::Explain { formula: e }
1058            | TLExpr::WeightedRule { rule: e, .. } => 1 + Self::count_nodes(e),
1059
1060            // Body quantifiers
1061            TLExpr::Exists { body: e, .. }
1062            | TLExpr::ForAll { body: e, .. }
1063            | TLExpr::SoftExists { body: e, .. }
1064            | TLExpr::SoftForAll { body: e, .. }
1065            | TLExpr::Aggregate { body: e, .. }
1066            | TLExpr::CountingExists { body: e, .. }
1067            | TLExpr::CountingForAll { body: e, .. }
1068            | TLExpr::ExactCount { body: e, .. }
1069            | TLExpr::Majority { body: e, .. }
1070            | TLExpr::LeastFixpoint { body: e, .. }
1071            | TLExpr::GreatestFixpoint { body: e, .. }
1072            | TLExpr::Lambda { body: e, .. }
1073            | TLExpr::SetComprehension { condition: e, .. }
1074            | TLExpr::At { formula: e, .. } => 1 + Self::count_nodes(e),
1075
1076            // Binary
1077            TLExpr::And(l, r)
1078            | TLExpr::Or(l, r)
1079            | TLExpr::Imply(l, r)
1080            | TLExpr::Add(l, r)
1081            | TLExpr::Sub(l, r)
1082            | TLExpr::Mul(l, r)
1083            | TLExpr::Div(l, r)
1084            | TLExpr::Pow(l, r)
1085            | TLExpr::Mod(l, r)
1086            | TLExpr::Min(l, r)
1087            | TLExpr::Max(l, r)
1088            | TLExpr::Eq(l, r)
1089            | TLExpr::Lt(l, r)
1090            | TLExpr::Gt(l, r)
1091            | TLExpr::Lte(l, r)
1092            | TLExpr::Gte(l, r)
1093            | TLExpr::Until {
1094                before: l,
1095                after: r,
1096            }
1097            | TLExpr::Release {
1098                released: l,
1099                releaser: r,
1100            }
1101            | TLExpr::WeakUntil {
1102                before: l,
1103                after: r,
1104            }
1105            | TLExpr::StrongRelease {
1106                released: l,
1107                releaser: r,
1108            }
1109            | TLExpr::SetMembership { element: l, set: r }
1110            | TLExpr::SetUnion { left: l, right: r }
1111            | TLExpr::SetIntersection { left: l, right: r }
1112            | TLExpr::SetDifference { left: l, right: r }
1113            | TLExpr::Apply {
1114                function: l,
1115                argument: r,
1116            } => 1 + Self::count_nodes(l) + Self::count_nodes(r),
1117
1118            TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
1119                1 + Self::count_nodes(left) + Self::count_nodes(right)
1120            }
1121            TLExpr::FuzzyImplication {
1122                premise,
1123                conclusion,
1124                ..
1125            } => 1 + Self::count_nodes(premise) + Self::count_nodes(conclusion),
1126
1127            TLExpr::IfThenElse {
1128                condition,
1129                then_branch,
1130                else_branch,
1131            } => {
1132                1 + Self::count_nodes(condition)
1133                    + Self::count_nodes(then_branch)
1134                    + Self::count_nodes(else_branch)
1135            }
1136
1137            TLExpr::Let { value, body, .. } => {
1138                1 + Self::count_nodes(value) + Self::count_nodes(body)
1139            }
1140
1141            TLExpr::ProbabilisticChoice { alternatives } => {
1142                1 + alternatives
1143                    .iter()
1144                    .map(|(_, e)| Self::count_nodes(e))
1145                    .sum::<u64>()
1146            }
1147            TLExpr::GlobalCardinality { values, .. } => {
1148                1 + values.iter().map(Self::count_nodes).sum::<u64>()
1149            }
1150
1151            TLExpr::SymbolLiteral(_) => 1,
1152
1153            TLExpr::Match { scrutinee, arms } => {
1154                1 + Self::count_nodes(scrutinee)
1155                    + arms.iter().map(|(_, b)| Self::count_nodes(b)).sum::<u64>()
1156            }
1157        }
1158    }
1159}
1160
1161impl Default for ConstantPropagator {
1162    fn default() -> Self {
1163        Self::with_default()
1164    }
1165}
1166
1167// ────────────────────────────────────────────────────────────────
1168// Tests
1169// ────────────────────────────────────────────────────────────────
1170
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::TLExpr;

    fn propagator() -> ConstantPropagator {
        ConstantPropagator::with_default()
    }

    fn assert_constant(expr: &TLExpr, expected: f64) {
        match expr {
            TLExpr::Constant(v) => {
                let diff = (v - expected).abs();
                assert!(diff < 1e-9, "Expected constant {}, got {}", expected, v);
            }
            other => panic!("Expected Constant({}), got {:?}", expected, other),
        }
    }

    // ── 1. Constant returns itself ──────────────────────────────────────
    #[test]
    fn test_constant_returns_itself() {
        let (result, stats) = propagator().run(TLExpr::Constant(3.0));
        assert_constant(&result, 3.0);
        assert_eq!(stats.total_folds(), 0);
    }

    // ── 2. Add two constants ─────────────────────────────────────────────
    #[test]
    fn test_add_two_constants() {
        let expr = TLExpr::add(TLExpr::Constant(2.0), TLExpr::Constant(3.0));
        let (result, stats) = propagator().run(expr);
        assert_constant(&result, 5.0);
        assert!(stats.arithmetic_folds >= 1);
    }

    // ── 3. Sub two constants ─────────────────────────────────────────────
    #[test]
    fn test_sub_two_constants() {
        let expr = TLExpr::sub(TLExpr::Constant(5.0), TLExpr::Constant(3.0));
        let (result, _) = propagator().run(expr);
        assert_constant(&result, 2.0);
    }

    // ── 4. Mul two constants ─────────────────────────────────────────────
    #[test]
    fn test_mul_two_constants() {
        let expr = TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(4.0));
        let (result, _) = propagator().run(expr);
        assert_constant(&result, 8.0);
    }

    // ── 5. Div two constants ─────────────────────────────────────────────
    #[test]
    fn test_div_two_constants() {
        let expr = TLExpr::div(TLExpr::Constant(6.0), TLExpr::Constant(2.0));
        let (result, _) = propagator().run(expr);
        assert_constant(&result, 3.0);
    }

    // ── 6. Div by zero is not folded ─────────────────────────────────────
    #[test]
    fn test_div_by_zero_no_fold() {
        // Both operands are constants, so the only thing preventing a fold is
        // the zero-divisor guard. (A non-constant numerator would never fold
        // anyway and so would not exercise the guard.)
        let expr = TLExpr::div(TLExpr::Constant(1.0), TLExpr::Constant(0.0));
        let (result, stats) = propagator().run(expr);
        // The Div node must be preserved.
        assert!(matches!(result, TLExpr::Div(_, _)));
        assert_eq!(stats.arithmetic_folds, 0);
    }

    // ── 6b. Div with a non-constant numerator is not folded ──────────────
    #[test]
    fn test_div_non_constant_numerator_no_fold() {
        let x = TLExpr::pred("x", vec![]);
        let expr = TLExpr::div(x, TLExpr::Constant(0.0));
        let (result, stats) = propagator().run(expr);
        assert!(!matches!(result, TLExpr::Constant(_)));
        assert_eq!(stats.arithmetic_folds, 0);
    }

    // ── 6c. NaN-producing arithmetic is not folded ───────────────────────
    #[test]
    fn test_nan_arithmetic_no_fold() {
        // ∞ - ∞ = NaN, which must never be materialised as a constant.
        let expr = TLExpr::sub(
            TLExpr::Constant(f64::INFINITY),
            TLExpr::Constant(f64::INFINITY),
        );
        let (result, stats) = propagator().run(expr);
        assert!(matches!(result, TLExpr::Sub(_, _)));
        assert_eq!(stats.arithmetic_folds, 0);
    }

    // ── 7. Neg (Unary) constant — using Sub(0, x) pattern ─────────────────
    // TLExpr has no Neg variant; fold Sub(0, const) instead.
    #[test]
    fn test_neg_via_sub_constant() {
        let expr = TLExpr::sub(TLExpr::Constant(0.0), TLExpr::Constant(3.0));
        let (result, _) = propagator().run(expr);
        assert_constant(&result, -3.0);
    }

    // ── 7b. Abs constant ──────────────────────────────────────────────────
    #[test]
    fn test_abs_constant() {
        let expr = TLExpr::abs(TLExpr::Constant(-5.0));
        let (result, stats) = propagator().run(expr);
        assert_constant(&result, 5.0);
        assert!(stats.boolean_folds >= 1);
    }

    // ── 7c. Sqrt of a negative constant is not folded ────────────────────
    #[test]
    fn test_sqrt_negative_no_fold() {
        let expr = TLExpr::Sqrt(Box::new(TLExpr::Constant(-1.0)));
        let (result, stats) = propagator().run(expr);
        assert!(!matches!(result, TLExpr::Constant(_)));
        assert_eq!(stats.boolean_folds, 0);
    }

    // ── 8. Nested arithmetic — two passes required ───────────────────────
    #[test]
    fn test_nested_arithmetic() {
        // Add(Mul(2,3), 4) → Add(6, 4) → 10
        let expr = TLExpr::add(
            TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0)),
            TLExpr::Constant(4.0),
        );
        let (result, stats) = propagator().run(expr);
        assert_constant(&result, 10.0);
        assert!(stats.arithmetic_folds >= 2);
    }

    // ── 9. Comparison Lt → True ──────────────────────────────────────────
    #[test]
    fn test_comparison_lt_true() {
        let expr = TLExpr::Lt(
            Box::new(TLExpr::Constant(1.0)),
            Box::new(TLExpr::Constant(2.0)),
        );
        let (result, stats) = propagator().run(expr);
        assert_constant(&result, 1.0); // True = 1.0
        assert!(stats.comparison_folds >= 1);
    }

    // ── 10. Comparison Gt → False ────────────────────────────────────────
    #[test]
    fn test_comparison_gt_false() {
        let expr = TLExpr::Gt(
            Box::new(TLExpr::Constant(1.0)),
            Box::new(TLExpr::Constant(2.0)),
        );
        let (result, stats) = propagator().run(expr);
        assert_constant(&result, 0.0); // False = 0.0
        assert!(stats.comparison_folds >= 1);
    }

    // ── 10b. Lte / Gte on equal operands → True ──────────────────────────
    #[test]
    fn test_comparison_lte_gte_equal_operands() {
        let lte = TLExpr::Lte(
            Box::new(TLExpr::Constant(2.0)),
            Box::new(TLExpr::Constant(2.0)),
        );
        let (lte_result, _) = propagator().run(lte);
        assert_constant(&lte_result, 1.0);

        let gte = TLExpr::Gte(
            Box::new(TLExpr::Constant(2.0)),
            Box::new(TLExpr::Constant(2.0)),
        );
        let (gte_result, _) = propagator().run(gte);
        assert_constant(&gte_result, 1.0);
    }

    // ── 11. Stats arithmetic_folds > 0 ───────────────────────────────────
    #[test]
    fn test_const_prop_stats_counts() {
        let expr = TLExpr::add(TLExpr::Constant(1.0), TLExpr::Constant(1.0));
        let (_, stats) = propagator().run(expr);
        assert!(stats.arithmetic_folds > 0, "Expected arithmetic_folds > 0");
    }

    // ── 12. Stats summary non-empty ──────────────────────────────────────
    #[test]
    fn test_const_prop_stats_summary() {
        let expr = TLExpr::add(TLExpr::Constant(2.0), TLExpr::Constant(3.0));
        let (_, stats) = propagator().run(expr);
        let summary = stats.summary();
        assert!(!summary.is_empty(), "Expected non-empty summary");
        assert!(summary.contains("ConstProp"));
    }

    // ── 13. Config default max_passes == 20 ──────────────────────────────
    #[test]
    fn test_const_prop_config_default() {
        let config = ConstPropConfig::default();
        assert_eq!(config.max_passes, 20);
        assert!(config.fold_arithmetic);
        assert!(config.fold_comparisons);
        assert!(config.fold_boolean);
        assert!((config.float_tolerance - 1e-12).abs() < 1e-20);
    }

    // ── 14. Disabled fold — arithmetic off ───────────────────────────────
    #[test]
    fn test_disabled_fold() {
        let config = ConstPropConfig {
            fold_arithmetic: false,
            ..Default::default()
        };
        let prop = ConstantPropagator::new(config);
        let expr = TLExpr::add(TLExpr::Constant(2.0), TLExpr::Constant(3.0));
        let (result, stats) = prop.run(expr);
        // Should NOT be folded
        assert!(!matches!(result, TLExpr::Constant(_)));
        assert_eq!(stats.arithmetic_folds, 0);
    }

    // ── 15. Fixed point — idempotent after first pass ─────────────────────
    #[test]
    fn test_fixed_point() {
        let expr = TLExpr::add(TLExpr::Constant(2.0), TLExpr::Constant(3.0));
        let (result1, _) = propagator().run(expr);
        let (result2, stats2) = propagator().run(result1.clone());
        // Second run should produce same result with 0 additional folds
        assert_eq!(stats2.total_folds(), 0);
        if let TLExpr::Constant(v1) = result1 {
            if let TLExpr::Constant(v2) = result2 {
                assert!((v1 - v2).abs() < 1e-12);
            } else {
                panic!("Expected Constant in second run");
            }
        } else {
            panic!("Expected Constant in first run");
        }
    }

    // ── 16. Passes count >= 1 ────────────────────────────────────────────
    #[test]
    fn test_passes_count() {
        let expr = TLExpr::add(TLExpr::Constant(1.0), TLExpr::Constant(2.0));
        let (_, stats) = propagator().run(expr);
        assert!(stats.passes >= 1, "Expected at least 1 pass");
    }

    // ── 17. Reduction pct — nodes_after < nodes_before ───────────────────
    #[test]
    fn test_reduction_pct() {
        // Add(Mul(2,3), 4) has 5 nodes, result has 1
        let expr = TLExpr::add(
            TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0)),
            TLExpr::Constant(4.0),
        );
        let (_, stats) = propagator().run(expr);
        assert!(stats.nodes_before > stats.nodes_after);
        assert!(stats.reduction_pct() > 0.0);
    }

    // ── 18. Non-constant unchanged ───────────────────────────────────────
    #[test]
    fn test_non_constant_unchanged() {
        let expr = TLExpr::pred("x", vec![]);
        let (result, stats) = propagator().run(expr.clone());
        assert_eq!(stats.total_folds(), 0);
        // Result should still be a Pred
        assert!(matches!(result, TLExpr::Pred { .. }));
    }

    // ── 19. Mixed expr — can't fold ──────────────────────────────────────
    #[test]
    fn test_mixed_expr() {
        // Add(Constant(2), Pred("x")) — cannot fold because rhs is not a constant
        let expr = TLExpr::add(TLExpr::Constant(2.0), TLExpr::pred("x", vec![]));
        let (result, stats) = propagator().run(expr);
        assert!(matches!(result, TLExpr::Add(_, _)));
        assert_eq!(stats.arithmetic_folds, 0);
    }

    // ── 20. Compose with DCE ─────────────────────────────────────────────
    #[test]
    fn test_const_prop_with_dead_code() {
        use crate::dead_code::{DceConfig, DeadCodeEliminator};

        // And(True, Add(1, 2)):
        //   - const_prop folds Add(1,2) → 3, yielding And(Constant(1.0), Constant(3.0))
        //   - const_prop further folds the And (both operands are constants) → Constant(1.0)
        //     (AND of two non-zero constants = True = 1.0)
        // Running DCE on top should be a no-op but not fail.
        let inner = TLExpr::add(TLExpr::Constant(1.0), TLExpr::Constant(2.0));
        let expr = TLExpr::and(TLExpr::Constant(1.0), inner);

        let (after_cp, cp_stats) = propagator().run(expr);
        // Const prop should have folded at least the Add
        assert!(cp_stats.total_folds() >= 1);

        let dce = DeadCodeEliminator::new(DceConfig::default());
        let (after_dce, _dce_stats) = dce.run(after_cp);
        // The pipeline should converge to a constant (1.0 = True for And(True, 3))
        assert!(matches!(after_dce, TLExpr::Constant(_)));
    }

    // ── Pow folding ───────────────────────────────────────────────────────
    #[test]
    fn test_pow_two_constants() {
        let expr = TLExpr::pow(TLExpr::Constant(2.0), TLExpr::Constant(10.0));
        let (result, _) = propagator().run(expr);
        assert_constant(&result, 1024.0);
    }

    // ── Min / Max folding ─────────────────────────────────────────────────
    #[test]
    fn test_min_max_constants() {
        let min_expr = TLExpr::min(TLExpr::Constant(3.0), TLExpr::Constant(7.0));
        let (min_result, _) = propagator().run(min_expr);
        assert_constant(&min_result, 3.0);

        let max_expr = TLExpr::max(TLExpr::Constant(3.0), TLExpr::Constant(7.0));
        let (max_result, _) = propagator().run(max_expr);
        assert_constant(&max_result, 7.0);
    }

    // ── Eq comparison with tolerance ─────────────────────────────────────
    #[test]
    fn test_comparison_eq_true() {
        let a = 1.0_f64;
        let b = a + 1e-13; // within tolerance
        let expr = TLExpr::Eq(Box::new(TLExpr::Constant(a)), Box::new(TLExpr::Constant(b)));
        let (result, stats) = propagator().run(expr);
        assert_constant(&result, 1.0); // True
        assert!(stats.comparison_folds >= 1);
    }

    // ── count_nodes smoke test ────────────────────────────────────────────
    #[test]
    fn test_count_nodes() {
        assert_eq!(ConstantPropagator::count_nodes(&TLExpr::Constant(1.0)), 1);
        let binary = TLExpr::add(TLExpr::Constant(1.0), TLExpr::Constant(2.0));
        assert_eq!(ConstantPropagator::count_nodes(&binary), 3);
    }
}