tensorlogic_ir/expr/
optimization.rs

1//! Expression-level optimizations (constant folding, algebraic simplification).
2
3use super::TLExpr;
4
5/// Constant folding: evaluate constant expressions at compile time
6pub fn constant_fold(expr: &TLExpr) -> TLExpr {
7    match expr {
8        // Binary arithmetic operations on constants
9        TLExpr::Add(l, r) => {
10            let left = constant_fold(l);
11            let right = constant_fold(r);
12            if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
13                return TLExpr::Constant(lv + rv);
14            }
15            TLExpr::Add(Box::new(left), Box::new(right))
16        }
17        TLExpr::Sub(l, r) => {
18            let left = constant_fold(l);
19            let right = constant_fold(r);
20            if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
21                return TLExpr::Constant(lv - rv);
22            }
23            TLExpr::Sub(Box::new(left), Box::new(right))
24        }
25        TLExpr::Mul(l, r) => {
26            let left = constant_fold(l);
27            let right = constant_fold(r);
28            if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
29                return TLExpr::Constant(lv * rv);
30            }
31            TLExpr::Mul(Box::new(left), Box::new(right))
32        }
33        TLExpr::Div(l, r) => {
34            let left = constant_fold(l);
35            let right = constant_fold(r);
36            if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
37                if *rv != 0.0 {
38                    return TLExpr::Constant(lv / rv);
39                }
40            }
41            TLExpr::Div(Box::new(left), Box::new(right))
42        }
43        TLExpr::Pow(l, r) => {
44            let left = constant_fold(l);
45            let right = constant_fold(r);
46            if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
47                return TLExpr::Constant(lv.powf(*rv));
48            }
49            TLExpr::Pow(Box::new(left), Box::new(right))
50        }
51        TLExpr::Mod(l, r) => {
52            let left = constant_fold(l);
53            let right = constant_fold(r);
54            if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
55                return TLExpr::Constant(lv % rv);
56            }
57            TLExpr::Mod(Box::new(left), Box::new(right))
58        }
59        TLExpr::Min(l, r) => {
60            let left = constant_fold(l);
61            let right = constant_fold(r);
62            if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
63                return TLExpr::Constant(lv.min(*rv));
64            }
65            TLExpr::Min(Box::new(left), Box::new(right))
66        }
67        TLExpr::Max(l, r) => {
68            let left = constant_fold(l);
69            let right = constant_fold(r);
70            if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
71                return TLExpr::Constant(lv.max(*rv));
72            }
73            TLExpr::Max(Box::new(left), Box::new(right))
74        }
75
76        // Unary mathematical operations on constants
77        TLExpr::Abs(e) => {
78            let inner = constant_fold(e);
79            if let TLExpr::Constant(v) = &inner {
80                return TLExpr::Constant(v.abs());
81            }
82            TLExpr::Abs(Box::new(inner))
83        }
84        TLExpr::Floor(e) => {
85            let inner = constant_fold(e);
86            if let TLExpr::Constant(v) = &inner {
87                return TLExpr::Constant(v.floor());
88            }
89            TLExpr::Floor(Box::new(inner))
90        }
91        TLExpr::Ceil(e) => {
92            let inner = constant_fold(e);
93            if let TLExpr::Constant(v) = &inner {
94                return TLExpr::Constant(v.ceil());
95            }
96            TLExpr::Ceil(Box::new(inner))
97        }
98        TLExpr::Round(e) => {
99            let inner = constant_fold(e);
100            if let TLExpr::Constant(v) = &inner {
101                return TLExpr::Constant(v.round());
102            }
103            TLExpr::Round(Box::new(inner))
104        }
105        TLExpr::Sqrt(e) => {
106            let inner = constant_fold(e);
107            if let TLExpr::Constant(v) = &inner {
108                if *v >= 0.0 {
109                    return TLExpr::Constant(v.sqrt());
110                }
111            }
112            TLExpr::Sqrt(Box::new(inner))
113        }
114        TLExpr::Exp(e) => {
115            let inner = constant_fold(e);
116            if let TLExpr::Constant(v) = &inner {
117                return TLExpr::Constant(v.exp());
118            }
119            TLExpr::Exp(Box::new(inner))
120        }
121        TLExpr::Log(e) => {
122            let inner = constant_fold(e);
123            if let TLExpr::Constant(v) = &inner {
124                if *v > 0.0 {
125                    return TLExpr::Constant(v.ln());
126                }
127            }
128            TLExpr::Log(Box::new(inner))
129        }
130        TLExpr::Sin(e) => {
131            let inner = constant_fold(e);
132            if let TLExpr::Constant(v) = &inner {
133                return TLExpr::Constant(v.sin());
134            }
135            TLExpr::Sin(Box::new(inner))
136        }
137        TLExpr::Cos(e) => {
138            let inner = constant_fold(e);
139            if let TLExpr::Constant(v) = &inner {
140                return TLExpr::Constant(v.cos());
141            }
142            TLExpr::Cos(Box::new(inner))
143        }
144        TLExpr::Tan(e) => {
145            let inner = constant_fold(e);
146            if let TLExpr::Constant(v) = &inner {
147                return TLExpr::Constant(v.tan());
148            }
149            TLExpr::Tan(Box::new(inner))
150        }
151
152        TLExpr::Box(e) => TLExpr::Box(Box::new(constant_fold(e))),
153        TLExpr::Diamond(e) => TLExpr::Diamond(Box::new(constant_fold(e))),
154        TLExpr::Next(e) => TLExpr::Next(Box::new(constant_fold(e))),
155        TLExpr::Eventually(e) => TLExpr::Eventually(Box::new(constant_fold(e))),
156        TLExpr::Always(e) => TLExpr::Always(Box::new(constant_fold(e))),
157        TLExpr::Until { before, after } => TLExpr::Until {
158            before: Box::new(constant_fold(before)),
159            after: Box::new(constant_fold(after)),
160        },
161
162        // Fuzzy logic operators
163        TLExpr::TNorm { kind, left, right } => TLExpr::TNorm {
164            kind: *kind,
165            left: Box::new(constant_fold(left)),
166            right: Box::new(constant_fold(right)),
167        },
168        TLExpr::TCoNorm { kind, left, right } => TLExpr::TCoNorm {
169            kind: *kind,
170            left: Box::new(constant_fold(left)),
171            right: Box::new(constant_fold(right)),
172        },
173        TLExpr::FuzzyNot { kind, expr } => TLExpr::FuzzyNot {
174            kind: *kind,
175            expr: Box::new(constant_fold(expr)),
176        },
177        TLExpr::FuzzyImplication {
178            kind,
179            premise,
180            conclusion,
181        } => TLExpr::FuzzyImplication {
182            kind: *kind,
183            premise: Box::new(constant_fold(premise)),
184            conclusion: Box::new(constant_fold(conclusion)),
185        },
186
187        // Probabilistic operators
188        TLExpr::SoftExists {
189            var,
190            domain,
191            body,
192            temperature,
193        } => TLExpr::SoftExists {
194            var: var.clone(),
195            domain: domain.clone(),
196            body: Box::new(constant_fold(body)),
197            temperature: *temperature,
198        },
199        TLExpr::SoftForAll {
200            var,
201            domain,
202            body,
203            temperature,
204        } => TLExpr::SoftForAll {
205            var: var.clone(),
206            domain: domain.clone(),
207            body: Box::new(constant_fold(body)),
208            temperature: *temperature,
209        },
210        TLExpr::WeightedRule { weight, rule } => TLExpr::WeightedRule {
211            weight: *weight,
212            rule: Box::new(constant_fold(rule)),
213        },
214        TLExpr::ProbabilisticChoice { alternatives } => TLExpr::ProbabilisticChoice {
215            alternatives: alternatives
216                .iter()
217                .map(|(p, e)| (*p, constant_fold(e)))
218                .collect(),
219        },
220
221        // Extended temporal logic
222        TLExpr::Release { released, releaser } => TLExpr::Release {
223            released: Box::new(constant_fold(released)),
224            releaser: Box::new(constant_fold(releaser)),
225        },
226        TLExpr::WeakUntil { before, after } => TLExpr::WeakUntil {
227            before: Box::new(constant_fold(before)),
228            after: Box::new(constant_fold(after)),
229        },
230        TLExpr::StrongRelease { released, releaser } => TLExpr::StrongRelease {
231            released: Box::new(constant_fold(released)),
232            releaser: Box::new(constant_fold(releaser)),
233        },
234
235        // Comparison operations on constants
236        TLExpr::Eq(l, r) => {
237            let left = constant_fold(l);
238            let right = constant_fold(r);
239            if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
240                return TLExpr::Constant(if (lv - rv).abs() < f64::EPSILON {
241                    1.0
242                } else {
243                    0.0
244                });
245            }
246            TLExpr::Eq(Box::new(left), Box::new(right))
247        }
248        TLExpr::Lt(l, r) => {
249            let left = constant_fold(l);
250            let right = constant_fold(r);
251            if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
252                return TLExpr::Constant(if lv < rv { 1.0 } else { 0.0 });
253            }
254            TLExpr::Lt(Box::new(left), Box::new(right))
255        }
256        TLExpr::Gt(l, r) => {
257            let left = constant_fold(l);
258            let right = constant_fold(r);
259            if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
260                return TLExpr::Constant(if lv > rv { 1.0 } else { 0.0 });
261            }
262            TLExpr::Gt(Box::new(left), Box::new(right))
263        }
264        TLExpr::Lte(l, r) => {
265            let left = constant_fold(l);
266            let right = constant_fold(r);
267            if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
268                return TLExpr::Constant(if lv <= rv { 1.0 } else { 0.0 });
269            }
270            TLExpr::Lte(Box::new(left), Box::new(right))
271        }
272        TLExpr::Gte(l, r) => {
273            let left = constant_fold(l);
274            let right = constant_fold(r);
275            if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
276                return TLExpr::Constant(if lv >= rv { 1.0 } else { 0.0 });
277            }
278            TLExpr::Gte(Box::new(left), Box::new(right))
279        }
280
281        // Logical operations
282        TLExpr::Not(e) => {
283            let inner = constant_fold(e);
284            if let TLExpr::Constant(v) = &inner {
285                return TLExpr::Constant(1.0 - v);
286            }
287            TLExpr::Not(Box::new(inner))
288        }
289        TLExpr::And(l, r) => {
290            let left = constant_fold(l);
291            let right = constant_fold(r);
292            TLExpr::And(Box::new(left), Box::new(right))
293        }
294        TLExpr::Or(l, r) => {
295            let left = constant_fold(l);
296            let right = constant_fold(r);
297            TLExpr::Or(Box::new(left), Box::new(right))
298        }
299        TLExpr::Imply(l, r) => {
300            let left = constant_fold(l);
301            let right = constant_fold(r);
302            TLExpr::Imply(Box::new(left), Box::new(right))
303        }
304
305        // Recursive folding for other operations
306        TLExpr::Score(e) => TLExpr::Score(Box::new(constant_fold(e))),
307        TLExpr::Exists { var, domain, body } => TLExpr::Exists {
308            var: var.clone(),
309            domain: domain.clone(),
310            body: Box::new(constant_fold(body)),
311        },
312        TLExpr::ForAll { var, domain, body } => TLExpr::ForAll {
313            var: var.clone(),
314            domain: domain.clone(),
315            body: Box::new(constant_fold(body)),
316        },
317        TLExpr::Aggregate {
318            op,
319            var,
320            domain,
321            body,
322            group_by,
323        } => TLExpr::Aggregate {
324            op: op.clone(),
325            var: var.clone(),
326            domain: domain.clone(),
327            body: Box::new(constant_fold(body)),
328            group_by: group_by.clone(),
329        },
330        TLExpr::IfThenElse {
331            condition,
332            then_branch,
333            else_branch,
334        } => TLExpr::IfThenElse {
335            condition: Box::new(constant_fold(condition)),
336            then_branch: Box::new(constant_fold(then_branch)),
337            else_branch: Box::new(constant_fold(else_branch)),
338        },
339        TLExpr::Let { var, value, body } => TLExpr::Let {
340            var: var.clone(),
341            value: Box::new(constant_fold(value)),
342            body: Box::new(constant_fold(body)),
343        },
344
345        // Leaves - no folding needed
346        TLExpr::Pred { .. } | TLExpr::Constant(_) => expr.clone(),
347    }
348}
349
350/// Algebraic simplification: apply algebraic identities and simplification rules
351pub fn algebraic_simplify(expr: &TLExpr) -> TLExpr {
352    match expr {
353        // Addition identities
354        TLExpr::Add(l, r) => {
355            let left = algebraic_simplify(l);
356            let right = algebraic_simplify(r);
357
358            // x + 0 = x
359            if let TLExpr::Constant(0.0) = right {
360                return left;
361            }
362            // 0 + x = x
363            if let TLExpr::Constant(0.0) = left {
364                return right;
365            }
366
367            TLExpr::Add(Box::new(left), Box::new(right))
368        }
369
370        // Subtraction identities
371        TLExpr::Sub(l, r) => {
372            let left = algebraic_simplify(l);
373            let right = algebraic_simplify(r);
374
375            // x - 0 = x
376            if let TLExpr::Constant(0.0) = right {
377                return left;
378            }
379            // x - x = 0 (simplified form comparison)
380            if left == right {
381                return TLExpr::Constant(0.0);
382            }
383
384            TLExpr::Sub(Box::new(left), Box::new(right))
385        }
386
387        // Multiplication identities
388        TLExpr::Mul(l, r) => {
389            let left = algebraic_simplify(l);
390            let right = algebraic_simplify(r);
391
392            // x * 0 = 0
393            if let TLExpr::Constant(0.0) = right {
394                return TLExpr::Constant(0.0);
395            }
396            if let TLExpr::Constant(0.0) = left {
397                return TLExpr::Constant(0.0);
398            }
399
400            // x * 1 = x
401            if let TLExpr::Constant(1.0) = right {
402                return left;
403            }
404            // 1 * x = x
405            if let TLExpr::Constant(1.0) = left {
406                return right;
407            }
408
409            TLExpr::Mul(Box::new(left), Box::new(right))
410        }
411
412        // Division identities
413        TLExpr::Div(l, r) => {
414            let left = algebraic_simplify(l);
415            let right = algebraic_simplify(r);
416
417            // x / 1 = x
418            if let TLExpr::Constant(1.0) = right {
419                return left;
420            }
421
422            // 0 / x = 0 (assuming x != 0)
423            if let TLExpr::Constant(0.0) = left {
424                if let TLExpr::Constant(rv) = right {
425                    if rv != 0.0 {
426                        return TLExpr::Constant(0.0);
427                    }
428                }
429            }
430
431            // x / x = 1 (assuming x != 0)
432            // Only apply for constants to avoid division by zero issues
433            if left == right {
434                if let TLExpr::Constant(v) = left {
435                    if v != 0.0 {
436                        return TLExpr::Constant(1.0);
437                    }
438                }
439            }
440
441            TLExpr::Div(Box::new(left), Box::new(right))
442        }
443
444        // Power identities
445        TLExpr::Pow(l, r) => {
446            let left = algebraic_simplify(l);
447            let right = algebraic_simplify(r);
448
449            // x ^ 0 = 1
450            if let TLExpr::Constant(0.0) = right {
451                return TLExpr::Constant(1.0);
452            }
453            // x ^ 1 = x
454            if let TLExpr::Constant(1.0) = right {
455                return left;
456            }
457            // 0 ^ x = 0 (for x > 0)
458            if let TLExpr::Constant(0.0) = left {
459                if let TLExpr::Constant(rv) = right {
460                    if rv > 0.0 {
461                        return TLExpr::Constant(0.0);
462                    }
463                }
464            }
465            // 1 ^ x = 1
466            if let TLExpr::Constant(1.0) = left {
467                return TLExpr::Constant(1.0);
468            }
469
470            TLExpr::Pow(Box::new(left), Box::new(right))
471        }
472
473        // Double negation: NOT(NOT(x)) = x
474        TLExpr::Not(e) => {
475            let inner = algebraic_simplify(e);
476            if let TLExpr::Not(inner_inner) = &inner {
477                return *inner_inner.clone();
478            }
479            TLExpr::Not(Box::new(inner))
480        }
481
482        // Recursively simplify other operations
483        TLExpr::Mod(l, r) => {
484            let left = algebraic_simplify(l);
485            let right = algebraic_simplify(r);
486            TLExpr::Mod(Box::new(left), Box::new(right))
487        }
488        TLExpr::Min(l, r) => {
489            let left = algebraic_simplify(l);
490            let right = algebraic_simplify(r);
491            TLExpr::Min(Box::new(left), Box::new(right))
492        }
493        TLExpr::Max(l, r) => {
494            let left = algebraic_simplify(l);
495            let right = algebraic_simplify(r);
496            TLExpr::Max(Box::new(left), Box::new(right))
497        }
498        TLExpr::Abs(e) => TLExpr::Abs(Box::new(algebraic_simplify(e))),
499        TLExpr::Floor(e) => TLExpr::Floor(Box::new(algebraic_simplify(e))),
500        TLExpr::Ceil(e) => TLExpr::Ceil(Box::new(algebraic_simplify(e))),
501        TLExpr::Round(e) => TLExpr::Round(Box::new(algebraic_simplify(e))),
502        TLExpr::Sqrt(e) => TLExpr::Sqrt(Box::new(algebraic_simplify(e))),
503        // Modal logic simplifications
504        TLExpr::Box(e) => {
505            let inner = algebraic_simplify(e);
506
507            // □(TRUE) = TRUE, □(FALSE) = FALSE
508            if let TLExpr::Constant(v) = inner {
509                return TLExpr::Constant(v);
510            }
511
512            TLExpr::Box(Box::new(inner))
513        }
514        TLExpr::Diamond(e) => {
515            let inner = algebraic_simplify(e);
516
517            // ◇(TRUE) = TRUE, ◇(FALSE) = FALSE
518            if let TLExpr::Constant(v) = inner {
519                return TLExpr::Constant(v);
520            }
521
522            TLExpr::Diamond(Box::new(inner))
523        }
524
525        // Temporal logic simplifications
526        TLExpr::Next(e) => {
527            let inner = algebraic_simplify(e);
528
529            // X(TRUE) = TRUE, X(FALSE) = FALSE
530            if let TLExpr::Constant(v) = inner {
531                return TLExpr::Constant(v);
532            }
533
534            TLExpr::Next(Box::new(inner))
535        }
536        TLExpr::Eventually(e) => {
537            let inner = algebraic_simplify(e);
538
539            // F(TRUE) = TRUE, F(FALSE) = FALSE
540            if let TLExpr::Constant(v) = inner {
541                return TLExpr::Constant(v);
542            }
543
544            // Idempotence: F(F(P)) = F(P)
545            if let TLExpr::Eventually(inner_inner) = &inner {
546                return TLExpr::Eventually(inner_inner.clone());
547            }
548
549            TLExpr::Eventually(Box::new(inner))
550        }
551        TLExpr::Always(e) => {
552            let inner = algebraic_simplify(e);
553
554            // G(TRUE) = TRUE, G(FALSE) = FALSE
555            if let TLExpr::Constant(v) = inner {
556                return TLExpr::Constant(v);
557            }
558
559            // Idempotence: G(G(P)) = G(P)
560            if let TLExpr::Always(inner_inner) = &inner {
561                return TLExpr::Always(inner_inner.clone());
562            }
563
564            TLExpr::Always(Box::new(inner))
565        }
566        TLExpr::Until { before, after } => {
567            let before_simplified = algebraic_simplify(before);
568            let after_simplified = algebraic_simplify(after);
569
570            // P U TRUE = TRUE (after becomes immediately true)
571            if let TLExpr::Constant(1.0) = after_simplified {
572                return TLExpr::Constant(1.0);
573            }
574
575            // FALSE U P = F(P) (before is never true, so we just wait for after)
576            if let TLExpr::Constant(0.0) = before_simplified {
577                return TLExpr::Eventually(Box::new(after_simplified));
578            }
579
580            TLExpr::Until {
581                before: Box::new(before_simplified),
582                after: Box::new(after_simplified),
583            }
584        }
585
586        // Fuzzy logic operators - pass through with recursive simplification
587        TLExpr::TNorm { kind, left, right } => TLExpr::TNorm {
588            kind: *kind,
589            left: Box::new(algebraic_simplify(left)),
590            right: Box::new(algebraic_simplify(right)),
591        },
592        TLExpr::TCoNorm { kind, left, right } => TLExpr::TCoNorm {
593            kind: *kind,
594            left: Box::new(algebraic_simplify(left)),
595            right: Box::new(algebraic_simplify(right)),
596        },
597        TLExpr::FuzzyNot { kind, expr } => TLExpr::FuzzyNot {
598            kind: *kind,
599            expr: Box::new(algebraic_simplify(expr)),
600        },
601        TLExpr::FuzzyImplication {
602            kind,
603            premise,
604            conclusion,
605        } => TLExpr::FuzzyImplication {
606            kind: *kind,
607            premise: Box::new(algebraic_simplify(premise)),
608            conclusion: Box::new(algebraic_simplify(conclusion)),
609        },
610
611        // Probabilistic operators - pass through
612        TLExpr::SoftExists {
613            var,
614            domain,
615            body,
616            temperature,
617        } => TLExpr::SoftExists {
618            var: var.clone(),
619            domain: domain.clone(),
620            body: Box::new(algebraic_simplify(body)),
621            temperature: *temperature,
622        },
623        TLExpr::SoftForAll {
624            var,
625            domain,
626            body,
627            temperature,
628        } => TLExpr::SoftForAll {
629            var: var.clone(),
630            domain: domain.clone(),
631            body: Box::new(algebraic_simplify(body)),
632            temperature: *temperature,
633        },
634        TLExpr::WeightedRule { weight, rule } => TLExpr::WeightedRule {
635            weight: *weight,
636            rule: Box::new(algebraic_simplify(rule)),
637        },
638        TLExpr::ProbabilisticChoice { alternatives } => TLExpr::ProbabilisticChoice {
639            alternatives: alternatives
640                .iter()
641                .map(|(p, e)| (*p, algebraic_simplify(e)))
642                .collect(),
643        },
644
645        // Extended temporal logic - pass through
646        TLExpr::Release { released, releaser } => TLExpr::Release {
647            released: Box::new(algebraic_simplify(released)),
648            releaser: Box::new(algebraic_simplify(releaser)),
649        },
650        TLExpr::WeakUntil { before, after } => TLExpr::WeakUntil {
651            before: Box::new(algebraic_simplify(before)),
652            after: Box::new(algebraic_simplify(after)),
653        },
654        TLExpr::StrongRelease { released, releaser } => TLExpr::StrongRelease {
655            released: Box::new(algebraic_simplify(released)),
656            releaser: Box::new(algebraic_simplify(releaser)),
657        },
658
659        TLExpr::Exp(e) => TLExpr::Exp(Box::new(algebraic_simplify(e))),
660        TLExpr::Log(e) => TLExpr::Log(Box::new(algebraic_simplify(e))),
661        TLExpr::Sin(e) => TLExpr::Sin(Box::new(algebraic_simplify(e))),
662        TLExpr::Cos(e) => TLExpr::Cos(Box::new(algebraic_simplify(e))),
663        TLExpr::Tan(e) => TLExpr::Tan(Box::new(algebraic_simplify(e))),
664        // EQ simplifications
665        TLExpr::Eq(l, r) => {
666            let left = algebraic_simplify(l);
667            let right = algebraic_simplify(r);
668
669            // x = x → TRUE
670            if left == right {
671                return TLExpr::Constant(1.0);
672            }
673
674            TLExpr::Eq(Box::new(left), Box::new(right))
675        }
676
677        // LT simplifications
678        TLExpr::Lt(l, r) => {
679            let left = algebraic_simplify(l);
680            let right = algebraic_simplify(r);
681
682            // x < x → FALSE
683            if left == right {
684                return TLExpr::Constant(0.0);
685            }
686
687            TLExpr::Lt(Box::new(left), Box::new(right))
688        }
689
690        // GT simplifications
691        TLExpr::Gt(l, r) => {
692            let left = algebraic_simplify(l);
693            let right = algebraic_simplify(r);
694
695            // x > x → FALSE
696            if left == right {
697                return TLExpr::Constant(0.0);
698            }
699
700            TLExpr::Gt(Box::new(left), Box::new(right))
701        }
702
703        // LTE simplifications
704        TLExpr::Lte(l, r) => {
705            let left = algebraic_simplify(l);
706            let right = algebraic_simplify(r);
707
708            // x <= x → TRUE
709            if left == right {
710                return TLExpr::Constant(1.0);
711            }
712
713            TLExpr::Lte(Box::new(left), Box::new(right))
714        }
715
716        // GTE simplifications
717        TLExpr::Gte(l, r) => {
718            let left = algebraic_simplify(l);
719            let right = algebraic_simplify(r);
720
721            // x >= x → TRUE
722            if left == right {
723                return TLExpr::Constant(1.0);
724            }
725
726            TLExpr::Gte(Box::new(left), Box::new(right))
727        }
728        // AND logical laws
729        TLExpr::And(l, r) => {
730            let left = algebraic_simplify(l);
731            let right = algebraic_simplify(r);
732
733            // Idempotence: A ∧ A = A
734            if left == right {
735                return left;
736            }
737
738            // Identity: A ∧ TRUE = A, TRUE ∧ A = A
739            if let TLExpr::Constant(1.0) = right {
740                return left;
741            }
742            if let TLExpr::Constant(1.0) = left {
743                return right;
744            }
745
746            // Annihilation: A ∧ FALSE = FALSE, FALSE ∧ A = FALSE
747            if let TLExpr::Constant(0.0) = right {
748                return TLExpr::Constant(0.0);
749            }
750            if let TLExpr::Constant(0.0) = left {
751                return TLExpr::Constant(0.0);
752            }
753
754            // Complement: A ∧ ¬A = FALSE
755            if let TLExpr::Not(inner) = &right {
756                if **inner == left {
757                    return TLExpr::Constant(0.0);
758                }
759            }
760            if let TLExpr::Not(inner) = &left {
761                if **inner == right {
762                    return TLExpr::Constant(0.0);
763                }
764            }
765
766            // Absorption: A ∧ (A ∨ B) = A
767            if let TLExpr::Or(or_left, _or_right) = &right {
768                if **or_left == left {
769                    return left;
770                }
771            }
772            if let TLExpr::Or(or_left, _or_right) = &left {
773                if **or_left == right {
774                    return right;
775                }
776            }
777
778            TLExpr::And(Box::new(left), Box::new(right))
779        }
780
781        // OR logical laws
782        TLExpr::Or(l, r) => {
783            let left = algebraic_simplify(l);
784            let right = algebraic_simplify(r);
785
786            // Idempotence: A ∨ A = A
787            if left == right {
788                return left;
789            }
790
791            // Annihilation: A ∨ TRUE = TRUE, TRUE ∨ A = TRUE
792            if let TLExpr::Constant(1.0) = right {
793                return TLExpr::Constant(1.0);
794            }
795            if let TLExpr::Constant(1.0) = left {
796                return TLExpr::Constant(1.0);
797            }
798
799            // Identity: A ∨ FALSE = A, FALSE ∨ A = A
800            if let TLExpr::Constant(0.0) = right {
801                return left;
802            }
803            if let TLExpr::Constant(0.0) = left {
804                return right;
805            }
806
807            // Complement: A ∨ ¬A = TRUE
808            if let TLExpr::Not(inner) = &right {
809                if **inner == left {
810                    return TLExpr::Constant(1.0);
811                }
812            }
813            if let TLExpr::Not(inner) = &left {
814                if **inner == right {
815                    return TLExpr::Constant(1.0);
816                }
817            }
818
819            // Absorption: A ∨ (A ∧ B) = A
820            if let TLExpr::And(and_left, _and_right) = &right {
821                if **and_left == left {
822                    return left;
823                }
824            }
825            if let TLExpr::And(and_left, _and_right) = &left {
826                if **and_left == right {
827                    return right;
828                }
829            }
830
831            TLExpr::Or(Box::new(left), Box::new(right))
832        }
833
834        // IMPLY simplifications
835        TLExpr::Imply(l, r) => {
836            let left = algebraic_simplify(l);
837            let right = algebraic_simplify(r);
838
839            // TRUE → P = P
840            if let TLExpr::Constant(1.0) = left {
841                return right;
842            }
843
844            // FALSE → P = TRUE
845            if let TLExpr::Constant(0.0) = left {
846                return TLExpr::Constant(1.0);
847            }
848
849            // P → TRUE = TRUE
850            if let TLExpr::Constant(1.0) = right {
851                return TLExpr::Constant(1.0);
852            }
853
854            // P → FALSE = ¬P
855            if let TLExpr::Constant(0.0) = right {
856                return TLExpr::negate(left);
857            }
858
859            // P → P = TRUE
860            if left == right {
861                return TLExpr::Constant(1.0);
862            }
863
864            TLExpr::Imply(Box::new(left), Box::new(right))
865        }
866        TLExpr::Score(e) => TLExpr::Score(Box::new(algebraic_simplify(e))),
867        TLExpr::Exists { var, domain, body } => TLExpr::Exists {
868            var: var.clone(),
869            domain: domain.clone(),
870            body: Box::new(algebraic_simplify(body)),
871        },
872        TLExpr::ForAll { var, domain, body } => TLExpr::ForAll {
873            var: var.clone(),
874            domain: domain.clone(),
875            body: Box::new(algebraic_simplify(body)),
876        },
877        TLExpr::Aggregate {
878            op,
879            var,
880            domain,
881            body,
882            group_by,
883        } => TLExpr::Aggregate {
884            op: op.clone(),
885            var: var.clone(),
886            domain: domain.clone(),
887            body: Box::new(algebraic_simplify(body)),
888            group_by: group_by.clone(),
889        },
890        TLExpr::IfThenElse {
891            condition,
892            then_branch,
893            else_branch,
894        } => TLExpr::IfThenElse {
895            condition: Box::new(algebraic_simplify(condition)),
896            then_branch: Box::new(algebraic_simplify(then_branch)),
897            else_branch: Box::new(algebraic_simplify(else_branch)),
898        },
899        TLExpr::Let { var, value, body } => TLExpr::Let {
900            var: var.clone(),
901            value: Box::new(algebraic_simplify(value)),
902            body: Box::new(algebraic_simplify(body)),
903        },
904
905        // Leaves
906        TLExpr::Pred { .. } | TLExpr::Constant(_) => expr.clone(),
907    }
908}
909
910/// Substitute a variable with a value in an expression (for Let binding propagation)
911fn substitute(expr: &TLExpr, var: &str, value: &TLExpr) -> TLExpr {
912    match expr {
913        // If we find a predicate matching the variable name with no args, substitute
914        TLExpr::Pred { name, args } if name == var && args.is_empty() => value.clone(),
915
916        // For predicates with args or different names, keep them
917        TLExpr::Pred { .. } => expr.clone(),
918
919        // Recursively substitute in binary operations
920        TLExpr::And(l, r) => TLExpr::And(
921            Box::new(substitute(l, var, value)),
922            Box::new(substitute(r, var, value)),
923        ),
924        TLExpr::Or(l, r) => TLExpr::Or(
925            Box::new(substitute(l, var, value)),
926            Box::new(substitute(r, var, value)),
927        ),
928        TLExpr::Imply(l, r) => TLExpr::Imply(
929            Box::new(substitute(l, var, value)),
930            Box::new(substitute(r, var, value)),
931        ),
932        TLExpr::Add(l, r) => TLExpr::Add(
933            Box::new(substitute(l, var, value)),
934            Box::new(substitute(r, var, value)),
935        ),
936        TLExpr::Sub(l, r) => TLExpr::Sub(
937            Box::new(substitute(l, var, value)),
938            Box::new(substitute(r, var, value)),
939        ),
940        TLExpr::Mul(l, r) => TLExpr::Mul(
941            Box::new(substitute(l, var, value)),
942            Box::new(substitute(r, var, value)),
943        ),
944        TLExpr::Div(l, r) => TLExpr::Div(
945            Box::new(substitute(l, var, value)),
946            Box::new(substitute(r, var, value)),
947        ),
948        TLExpr::Pow(l, r) => TLExpr::Pow(
949            Box::new(substitute(l, var, value)),
950            Box::new(substitute(r, var, value)),
951        ),
952        TLExpr::Mod(l, r) => TLExpr::Mod(
953            Box::new(substitute(l, var, value)),
954            Box::new(substitute(r, var, value)),
955        ),
956        TLExpr::Min(l, r) => TLExpr::Min(
957            Box::new(substitute(l, var, value)),
958            Box::new(substitute(r, var, value)),
959        ),
960        TLExpr::Max(l, r) => TLExpr::Max(
961            Box::new(substitute(l, var, value)),
962            Box::new(substitute(r, var, value)),
963        ),
964        TLExpr::Eq(l, r) => TLExpr::Eq(
965            Box::new(substitute(l, var, value)),
966            Box::new(substitute(r, var, value)),
967        ),
968        TLExpr::Lt(l, r) => TLExpr::Lt(
969            Box::new(substitute(l, var, value)),
970            Box::new(substitute(r, var, value)),
971        ),
972        TLExpr::Gt(l, r) => TLExpr::Gt(
973            Box::new(substitute(l, var, value)),
974            Box::new(substitute(r, var, value)),
975        ),
976        TLExpr::Lte(l, r) => TLExpr::Lte(
977            Box::new(substitute(l, var, value)),
978            Box::new(substitute(r, var, value)),
979        ),
980        TLExpr::Gte(l, r) => TLExpr::Gte(
981            Box::new(substitute(l, var, value)),
982            Box::new(substitute(r, var, value)),
983        ),
984
985        // Recursively substitute in unary operations
986        TLExpr::Not(e) => TLExpr::Not(Box::new(substitute(e, var, value))),
987        TLExpr::Box(e) => TLExpr::Box(Box::new(substitute(e, var, value))),
988        TLExpr::Diamond(e) => TLExpr::Diamond(Box::new(substitute(e, var, value))),
989        TLExpr::Next(e) => TLExpr::Next(Box::new(substitute(e, var, value))),
990        TLExpr::Eventually(e) => TLExpr::Eventually(Box::new(substitute(e, var, value))),
991        TLExpr::Always(e) => TLExpr::Always(Box::new(substitute(e, var, value))),
992        TLExpr::Until { before, after } => TLExpr::Until {
993            before: Box::new(substitute(before, var, value)),
994            after: Box::new(substitute(after, var, value)),
995        },
996
997        // Fuzzy logic operators
998        TLExpr::TNorm { kind, left, right } => TLExpr::TNorm {
999            kind: *kind,
1000            left: Box::new(substitute(left, var, value)),
1001            right: Box::new(substitute(right, var, value)),
1002        },
1003        TLExpr::TCoNorm { kind, left, right } => TLExpr::TCoNorm {
1004            kind: *kind,
1005            left: Box::new(substitute(left, var, value)),
1006            right: Box::new(substitute(right, var, value)),
1007        },
1008        TLExpr::FuzzyNot { kind, expr } => TLExpr::FuzzyNot {
1009            kind: *kind,
1010            expr: Box::new(substitute(expr, var, value)),
1011        },
1012        TLExpr::FuzzyImplication {
1013            kind,
1014            premise,
1015            conclusion,
1016        } => TLExpr::FuzzyImplication {
1017            kind: *kind,
1018            premise: Box::new(substitute(premise, var, value)),
1019            conclusion: Box::new(substitute(conclusion, var, value)),
1020        },
1021
1022        // Probabilistic operators
1023        TLExpr::SoftExists {
1024            var: v,
1025            domain,
1026            body,
1027            temperature,
1028        } => TLExpr::SoftExists {
1029            var: v.clone(),
1030            domain: domain.clone(),
1031            body: Box::new(if v == var {
1032                (**body).clone()
1033            } else {
1034                substitute(body, var, value)
1035            }),
1036            temperature: *temperature,
1037        },
1038        TLExpr::SoftForAll {
1039            var: v,
1040            domain,
1041            body,
1042            temperature,
1043        } => TLExpr::SoftForAll {
1044            var: v.clone(),
1045            domain: domain.clone(),
1046            body: Box::new(if v == var {
1047                (**body).clone()
1048            } else {
1049                substitute(body, var, value)
1050            }),
1051            temperature: *temperature,
1052        },
1053        TLExpr::WeightedRule { weight, rule } => TLExpr::WeightedRule {
1054            weight: *weight,
1055            rule: Box::new(substitute(rule, var, value)),
1056        },
1057        TLExpr::ProbabilisticChoice { alternatives } => TLExpr::ProbabilisticChoice {
1058            alternatives: alternatives
1059                .iter()
1060                .map(|(p, e)| (*p, substitute(e, var, value)))
1061                .collect(),
1062        },
1063
1064        // Extended temporal logic
1065        TLExpr::Release { released, releaser } => TLExpr::Release {
1066            released: Box::new(substitute(released, var, value)),
1067            releaser: Box::new(substitute(releaser, var, value)),
1068        },
1069        TLExpr::WeakUntil { before, after } => TLExpr::WeakUntil {
1070            before: Box::new(substitute(before, var, value)),
1071            after: Box::new(substitute(after, var, value)),
1072        },
1073        TLExpr::StrongRelease { released, releaser } => TLExpr::StrongRelease {
1074            released: Box::new(substitute(released, var, value)),
1075            releaser: Box::new(substitute(releaser, var, value)),
1076        },
1077
1078        TLExpr::Score(e) => TLExpr::Score(Box::new(substitute(e, var, value))),
1079        TLExpr::Abs(e) => TLExpr::Abs(Box::new(substitute(e, var, value))),
1080        TLExpr::Floor(e) => TLExpr::Floor(Box::new(substitute(e, var, value))),
1081        TLExpr::Ceil(e) => TLExpr::Ceil(Box::new(substitute(e, var, value))),
1082        TLExpr::Round(e) => TLExpr::Round(Box::new(substitute(e, var, value))),
1083        TLExpr::Sqrt(e) => TLExpr::Sqrt(Box::new(substitute(e, var, value))),
1084        TLExpr::Exp(e) => TLExpr::Exp(Box::new(substitute(e, var, value))),
1085        TLExpr::Log(e) => TLExpr::Log(Box::new(substitute(e, var, value))),
1086        TLExpr::Sin(e) => TLExpr::Sin(Box::new(substitute(e, var, value))),
1087        TLExpr::Cos(e) => TLExpr::Cos(Box::new(substitute(e, var, value))),
1088        TLExpr::Tan(e) => TLExpr::Tan(Box::new(substitute(e, var, value))),
1089
1090        // For quantifiers and aggregates, don't substitute if the variable shadows
1091        TLExpr::Exists {
1092            var: qvar,
1093            domain,
1094            body,
1095        } => {
1096            if qvar == var {
1097                expr.clone() // Variable is shadowed, don't substitute
1098            } else {
1099                TLExpr::Exists {
1100                    var: qvar.clone(),
1101                    domain: domain.clone(),
1102                    body: Box::new(substitute(body, var, value)),
1103                }
1104            }
1105        }
1106        TLExpr::ForAll {
1107            var: qvar,
1108            domain,
1109            body,
1110        } => {
1111            if qvar == var {
1112                expr.clone() // Variable is shadowed, don't substitute
1113            } else {
1114                TLExpr::ForAll {
1115                    var: qvar.clone(),
1116                    domain: domain.clone(),
1117                    body: Box::new(substitute(body, var, value)),
1118                }
1119            }
1120        }
1121        TLExpr::Aggregate {
1122            op,
1123            var: avar,
1124            domain,
1125            body,
1126            group_by,
1127        } => {
1128            if avar == var {
1129                expr.clone() // Variable is shadowed, don't substitute
1130            } else {
1131                TLExpr::Aggregate {
1132                    op: op.clone(),
1133                    var: avar.clone(),
1134                    domain: domain.clone(),
1135                    body: Box::new(substitute(body, var, value)),
1136                    group_by: group_by.clone(),
1137                }
1138            }
1139        }
1140
1141        // For Let bindings, handle shadowing and substitute recursively
1142        TLExpr::Let {
1143            var: lvar,
1144            value: lvalue,
1145            body,
1146        } => {
1147            let new_value = substitute(lvalue, var, value);
1148            if lvar == var {
1149                // Variable is shadowed in body, don't substitute there
1150                TLExpr::Let {
1151                    var: lvar.clone(),
1152                    value: Box::new(new_value),
1153                    body: body.clone(),
1154                }
1155            } else {
1156                TLExpr::Let {
1157                    var: lvar.clone(),
1158                    value: Box::new(new_value),
1159                    body: Box::new(substitute(body, var, value)),
1160                }
1161            }
1162        }
1163
1164        // For if-then-else, substitute in all branches
1165        TLExpr::IfThenElse {
1166            condition,
1167            then_branch,
1168            else_branch,
1169        } => TLExpr::IfThenElse {
1170            condition: Box::new(substitute(condition, var, value)),
1171            then_branch: Box::new(substitute(then_branch, var, value)),
1172            else_branch: Box::new(substitute(else_branch, var, value)),
1173        },
1174
1175        // Constants remain unchanged
1176        TLExpr::Constant(_) => expr.clone(),
1177    }
1178}
1179
1180/// Propagate constants through Let bindings
1181pub fn propagate_constants(expr: &TLExpr) -> TLExpr {
1182    match expr {
1183        // If the Let binding value is a constant, substitute it into the body
1184        TLExpr::Let { var, value, body } => {
1185            let optimized_value = propagate_constants(value);
1186            let optimized_body = propagate_constants(body);
1187
1188            // If the value is constant, substitute it
1189            if matches!(optimized_value, TLExpr::Constant(_)) {
1190                substitute(&optimized_body, var, &optimized_value)
1191            } else {
1192                TLExpr::Let {
1193                    var: var.clone(),
1194                    value: Box::new(optimized_value),
1195                    body: Box::new(optimized_body),
1196                }
1197            }
1198        }
1199
1200        // Recursively propagate in other expressions
1201        TLExpr::And(l, r) => TLExpr::And(
1202            Box::new(propagate_constants(l)),
1203            Box::new(propagate_constants(r)),
1204        ),
1205        TLExpr::Or(l, r) => TLExpr::Or(
1206            Box::new(propagate_constants(l)),
1207            Box::new(propagate_constants(r)),
1208        ),
1209        TLExpr::Imply(l, r) => TLExpr::Imply(
1210            Box::new(propagate_constants(l)),
1211            Box::new(propagate_constants(r)),
1212        ),
1213        TLExpr::Add(l, r) => TLExpr::Add(
1214            Box::new(propagate_constants(l)),
1215            Box::new(propagate_constants(r)),
1216        ),
1217        TLExpr::Sub(l, r) => TLExpr::Sub(
1218            Box::new(propagate_constants(l)),
1219            Box::new(propagate_constants(r)),
1220        ),
1221        TLExpr::Mul(l, r) => TLExpr::Mul(
1222            Box::new(propagate_constants(l)),
1223            Box::new(propagate_constants(r)),
1224        ),
1225        TLExpr::Div(l, r) => TLExpr::Div(
1226            Box::new(propagate_constants(l)),
1227            Box::new(propagate_constants(r)),
1228        ),
1229        TLExpr::Pow(l, r) => TLExpr::Pow(
1230            Box::new(propagate_constants(l)),
1231            Box::new(propagate_constants(r)),
1232        ),
1233        TLExpr::Mod(l, r) => TLExpr::Mod(
1234            Box::new(propagate_constants(l)),
1235            Box::new(propagate_constants(r)),
1236        ),
1237        TLExpr::Min(l, r) => TLExpr::Min(
1238            Box::new(propagate_constants(l)),
1239            Box::new(propagate_constants(r)),
1240        ),
1241        TLExpr::Max(l, r) => TLExpr::Max(
1242            Box::new(propagate_constants(l)),
1243            Box::new(propagate_constants(r)),
1244        ),
1245        TLExpr::Eq(l, r) => TLExpr::Eq(
1246            Box::new(propagate_constants(l)),
1247            Box::new(propagate_constants(r)),
1248        ),
1249        TLExpr::Lt(l, r) => TLExpr::Lt(
1250            Box::new(propagate_constants(l)),
1251            Box::new(propagate_constants(r)),
1252        ),
1253        TLExpr::Gt(l, r) => TLExpr::Gt(
1254            Box::new(propagate_constants(l)),
1255            Box::new(propagate_constants(r)),
1256        ),
1257        TLExpr::Lte(l, r) => TLExpr::Lte(
1258            Box::new(propagate_constants(l)),
1259            Box::new(propagate_constants(r)),
1260        ),
1261        TLExpr::Gte(l, r) => TLExpr::Gte(
1262            Box::new(propagate_constants(l)),
1263            Box::new(propagate_constants(r)),
1264        ),
1265        TLExpr::Not(e) => TLExpr::Not(Box::new(propagate_constants(e))),
1266        TLExpr::Score(e) => TLExpr::Score(Box::new(propagate_constants(e))),
1267        TLExpr::Abs(e) => TLExpr::Abs(Box::new(propagate_constants(e))),
1268        TLExpr::Floor(e) => TLExpr::Floor(Box::new(propagate_constants(e))),
1269        TLExpr::Ceil(e) => TLExpr::Ceil(Box::new(propagate_constants(e))),
1270        TLExpr::Round(e) => TLExpr::Round(Box::new(propagate_constants(e))),
1271        TLExpr::Sqrt(e) => TLExpr::Sqrt(Box::new(propagate_constants(e))),
1272        TLExpr::Exp(e) => TLExpr::Exp(Box::new(propagate_constants(e))),
1273        TLExpr::Log(e) => TLExpr::Log(Box::new(propagate_constants(e))),
1274        TLExpr::Sin(e) => TLExpr::Sin(Box::new(propagate_constants(e))),
1275        TLExpr::Cos(e) => TLExpr::Cos(Box::new(propagate_constants(e))),
1276        TLExpr::Tan(e) => TLExpr::Tan(Box::new(propagate_constants(e))),
1277        TLExpr::Box(e) => TLExpr::Box(Box::new(propagate_constants(e))),
1278        TLExpr::Diamond(e) => TLExpr::Diamond(Box::new(propagate_constants(e))),
1279        TLExpr::Next(e) => TLExpr::Next(Box::new(propagate_constants(e))),
1280        TLExpr::Eventually(e) => TLExpr::Eventually(Box::new(propagate_constants(e))),
1281        TLExpr::Always(e) => TLExpr::Always(Box::new(propagate_constants(e))),
1282        TLExpr::Until { before, after } => TLExpr::Until {
1283            before: Box::new(propagate_constants(before)),
1284            after: Box::new(propagate_constants(after)),
1285        },
1286
1287        // Fuzzy logic operators
1288        TLExpr::TNorm { kind, left, right } => TLExpr::TNorm {
1289            kind: *kind,
1290            left: Box::new(propagate_constants(left)),
1291            right: Box::new(propagate_constants(right)),
1292        },
1293        TLExpr::TCoNorm { kind, left, right } => TLExpr::TCoNorm {
1294            kind: *kind,
1295            left: Box::new(propagate_constants(left)),
1296            right: Box::new(propagate_constants(right)),
1297        },
1298        TLExpr::FuzzyNot { kind, expr } => TLExpr::FuzzyNot {
1299            kind: *kind,
1300            expr: Box::new(propagate_constants(expr)),
1301        },
1302        TLExpr::FuzzyImplication {
1303            kind,
1304            premise,
1305            conclusion,
1306        } => TLExpr::FuzzyImplication {
1307            kind: *kind,
1308            premise: Box::new(propagate_constants(premise)),
1309            conclusion: Box::new(propagate_constants(conclusion)),
1310        },
1311
1312        // Probabilistic operators
1313        TLExpr::SoftExists {
1314            var,
1315            domain,
1316            body,
1317            temperature,
1318        } => TLExpr::SoftExists {
1319            var: var.clone(),
1320            domain: domain.clone(),
1321            body: Box::new(propagate_constants(body)),
1322            temperature: *temperature,
1323        },
1324        TLExpr::SoftForAll {
1325            var,
1326            domain,
1327            body,
1328            temperature,
1329        } => TLExpr::SoftForAll {
1330            var: var.clone(),
1331            domain: domain.clone(),
1332            body: Box::new(propagate_constants(body)),
1333            temperature: *temperature,
1334        },
1335        TLExpr::WeightedRule { weight, rule } => TLExpr::WeightedRule {
1336            weight: *weight,
1337            rule: Box::new(propagate_constants(rule)),
1338        },
1339        TLExpr::ProbabilisticChoice { alternatives } => TLExpr::ProbabilisticChoice {
1340            alternatives: alternatives
1341                .iter()
1342                .map(|(p, e)| (*p, propagate_constants(e)))
1343                .collect(),
1344        },
1345
1346        // Extended temporal logic
1347        TLExpr::Release { released, releaser } => TLExpr::Release {
1348            released: Box::new(propagate_constants(released)),
1349            releaser: Box::new(propagate_constants(releaser)),
1350        },
1351        TLExpr::WeakUntil { before, after } => TLExpr::WeakUntil {
1352            before: Box::new(propagate_constants(before)),
1353            after: Box::new(propagate_constants(after)),
1354        },
1355        TLExpr::StrongRelease { released, releaser } => TLExpr::StrongRelease {
1356            released: Box::new(propagate_constants(released)),
1357            releaser: Box::new(propagate_constants(releaser)),
1358        },
1359
1360        TLExpr::Exists { var, domain, body } => TLExpr::Exists {
1361            var: var.clone(),
1362            domain: domain.clone(),
1363            body: Box::new(propagate_constants(body)),
1364        },
1365        TLExpr::ForAll { var, domain, body } => TLExpr::ForAll {
1366            var: var.clone(),
1367            domain: domain.clone(),
1368            body: Box::new(propagate_constants(body)),
1369        },
1370        TLExpr::Aggregate {
1371            op,
1372            var,
1373            domain,
1374            body,
1375            group_by,
1376        } => TLExpr::Aggregate {
1377            op: op.clone(),
1378            var: var.clone(),
1379            domain: domain.clone(),
1380            body: Box::new(propagate_constants(body)),
1381            group_by: group_by.clone(),
1382        },
1383        TLExpr::IfThenElse {
1384            condition,
1385            then_branch,
1386            else_branch,
1387        } => TLExpr::IfThenElse {
1388            condition: Box::new(propagate_constants(condition)),
1389            then_branch: Box::new(propagate_constants(then_branch)),
1390            else_branch: Box::new(propagate_constants(else_branch)),
1391        },
1392        TLExpr::Pred { .. } | TLExpr::Constant(_) => expr.clone(),
1393    }
1394}
1395
1396/// Apply multiple optimization passes in sequence
1397pub fn optimize_expr(expr: &TLExpr) -> TLExpr {
1398    // Apply optimizations iteratively until no more changes occur
1399    // This handles nested Let bindings and cascading optimizations
1400    let mut current = expr.clone();
1401    let mut iterations = 0;
1402    const MAX_ITERATIONS: usize = 10; // Prevent infinite loops
1403
1404    loop {
1405        let propagated = propagate_constants(&current);
1406        let folded = constant_fold(&propagated);
1407        let simplified = algebraic_simplify(&folded);
1408
1409        // If no change occurred, we're done
1410        if simplified == current || iterations >= MAX_ITERATIONS {
1411            return simplified;
1412        }
1413
1414        current = simplified;
1415        iterations += 1;
1416    }
1417}
1418
1419#[cfg(test)]
1420mod tests {
1421    use super::*;
1422
1423    #[test]
1424    fn test_constant_fold_addition() {
1425        let expr = TLExpr::add(TLExpr::constant(2.0), TLExpr::constant(3.0));
1426        let folded = constant_fold(&expr);
1427        assert_eq!(folded, TLExpr::Constant(5.0));
1428    }
1429
1430    #[test]
1431    fn test_constant_fold_multiplication() {
1432        let expr = TLExpr::mul(TLExpr::constant(4.0), TLExpr::constant(5.0));
1433        let folded = constant_fold(&expr);
1434        assert_eq!(folded, TLExpr::Constant(20.0));
1435    }
1436
1437    #[test]
1438    fn test_constant_fold_nested() {
1439        // (2 + 3) * 4 = 20
1440        let expr = TLExpr::mul(
1441            TLExpr::add(TLExpr::constant(2.0), TLExpr::constant(3.0)),
1442            TLExpr::constant(4.0),
1443        );
1444        let folded = constant_fold(&expr);
1445        assert_eq!(folded, TLExpr::Constant(20.0));
1446    }
1447
1448    #[test]
1449    fn test_algebraic_simplify_add_zero() {
1450        let expr = TLExpr::add(TLExpr::constant(5.0), TLExpr::constant(0.0));
1451        let simplified = algebraic_simplify(&expr);
1452        assert_eq!(simplified, TLExpr::Constant(5.0));
1453    }
1454
1455    #[test]
1456    fn test_algebraic_simplify_mul_one() {
1457        let expr = TLExpr::mul(TLExpr::constant(7.0), TLExpr::constant(1.0));
1458        let simplified = algebraic_simplify(&expr);
1459        assert_eq!(simplified, TLExpr::Constant(7.0));
1460    }
1461
1462    #[test]
1463    fn test_algebraic_simplify_mul_zero() {
1464        let expr = TLExpr::mul(TLExpr::constant(7.0), TLExpr::constant(0.0));
1465        let simplified = algebraic_simplify(&expr);
1466        assert_eq!(simplified, TLExpr::Constant(0.0));
1467    }
1468
1469    #[test]
1470    fn test_algebraic_simplify_double_negation() {
1471        let expr = TLExpr::negate(TLExpr::negate(TLExpr::constant(5.0)));
1472        let simplified = algebraic_simplify(&expr);
1473        assert_eq!(simplified, TLExpr::Constant(5.0));
1474    }
1475
1476    #[test]
1477    fn test_optimize_expr_combined() {
1478        // (2 + 3) * 1 should become 5
1479        let expr = TLExpr::mul(
1480            TLExpr::add(TLExpr::constant(2.0), TLExpr::constant(3.0)),
1481            TLExpr::constant(1.0),
1482        );
1483        let optimized = optimize_expr(&expr);
1484        assert_eq!(optimized, TLExpr::Constant(5.0));
1485    }
1486
1487    #[test]
1488    fn test_constant_fold_trig() {
1489        let expr = TLExpr::sin(TLExpr::constant(0.0));
1490        let folded = constant_fold(&expr);
1491        assert_eq!(folded, TLExpr::Constant(0.0));
1492    }
1493
1494    #[test]
1495    fn test_constant_fold_sqrt() {
1496        let expr = TLExpr::sqrt(TLExpr::constant(4.0));
1497        let folded = constant_fold(&expr);
1498        assert_eq!(folded, TLExpr::Constant(2.0));
1499    }
1500
1501    #[test]
1502    fn test_algebraic_simplify_power_identities() {
1503        // x^0 = 1
1504        let expr = TLExpr::pow(TLExpr::constant(42.0), TLExpr::constant(0.0));
1505        let simplified = algebraic_simplify(&expr);
1506        assert_eq!(simplified, TLExpr::Constant(1.0));
1507
1508        // x^1 = x
1509        let expr2 = TLExpr::pow(TLExpr::constant(42.0), TLExpr::constant(1.0));
1510        let simplified2 = algebraic_simplify(&expr2);
1511        assert_eq!(simplified2, TLExpr::Constant(42.0));
1512    }
1513
1514    #[test]
1515    fn test_let_binding_constant_propagation() {
1516        // let x = 5 in x + x should become 10
1517        let expr = TLExpr::let_binding(
1518            "x",
1519            TLExpr::constant(5.0),
1520            TLExpr::add(TLExpr::pred("x", vec![]), TLExpr::pred("x", vec![])),
1521        );
1522        let optimized = optimize_expr(&expr);
1523        assert_eq!(optimized, TLExpr::Constant(10.0));
1524    }
1525
1526    #[test]
1527    fn test_let_binding_nested_propagation() {
1528        // let x = 3 in (let y = x + 2 in x * y) should become 15
1529        let expr = TLExpr::let_binding(
1530            "x",
1531            TLExpr::constant(3.0),
1532            TLExpr::let_binding(
1533                "y",
1534                TLExpr::add(TLExpr::pred("x", vec![]), TLExpr::constant(2.0)),
1535                TLExpr::mul(TLExpr::pred("x", vec![]), TLExpr::pred("y", vec![])),
1536            ),
1537        );
1538        let optimized = optimize_expr(&expr);
1539        assert_eq!(optimized, TLExpr::Constant(15.0));
1540    }
1541
1542    #[test]
1543    fn test_let_binding_shadowing() {
1544        // let x = 5 in (let x = 3 in x) should become 3
1545        let expr = TLExpr::let_binding(
1546            "x",
1547            TLExpr::constant(5.0),
1548            TLExpr::let_binding("x", TLExpr::constant(3.0), TLExpr::pred("x", vec![])),
1549        );
1550        let optimized = optimize_expr(&expr);
1551        assert_eq!(optimized, TLExpr::Constant(3.0));
1552    }
1553
1554    #[test]
1555    fn test_let_binding_no_propagation_for_expressions() {
1556        // let x = y + 1 in x * x should not be fully evaluated
1557        let expr = TLExpr::let_binding(
1558            "x",
1559            TLExpr::add(TLExpr::pred("y", vec![]), TLExpr::constant(1.0)),
1560            TLExpr::mul(TLExpr::pred("x", vec![]), TLExpr::pred("x", vec![])),
1561        );
1562        let optimized = optimize_expr(&expr);
1563        // Should keep the Let binding since x is not a constant
1564        assert!(matches!(optimized, TLExpr::Let { .. }));
1565    }
1566
1567    #[test]
1568    fn test_substitute_respects_shadowing_in_quantifiers() {
1569        // Substitution should not affect shadowed variables in quantifiers
1570        let expr = TLExpr::exists("x", "Domain", TLExpr::pred("x", vec![]));
1571        let substituted = substitute(&expr, "x", &TLExpr::constant(5.0));
1572        // x is shadowed, so it should remain unchanged
1573        assert_eq!(substituted, expr);
1574    }
1575
1576    #[test]
1577    fn test_propagate_constants_complex() {
1578        // let a = 2 in (let b = a * 3 in b + a) should become 8
1579        let expr = TLExpr::let_binding(
1580            "a",
1581            TLExpr::constant(2.0),
1582            TLExpr::let_binding(
1583                "b",
1584                TLExpr::mul(TLExpr::pred("a", vec![]), TLExpr::constant(3.0)),
1585                TLExpr::add(TLExpr::pred("b", vec![]), TLExpr::pred("a", vec![])),
1586            ),
1587        );
1588        let optimized = optimize_expr(&expr);
1589        assert_eq!(optimized, TLExpr::Constant(8.0));
1590    }
1591
1592    #[test]
1593    fn test_constant_fold_min_max() {
1594        let expr1 = TLExpr::min(TLExpr::constant(3.0), TLExpr::constant(7.0));
1595        let folded1 = constant_fold(&expr1);
1596        assert_eq!(folded1, TLExpr::Constant(3.0));
1597
1598        let expr2 = TLExpr::max(TLExpr::constant(3.0), TLExpr::constant(7.0));
1599        let folded2 = constant_fold(&expr2);
1600        assert_eq!(folded2, TLExpr::Constant(7.0));
1601    }
1602
1603    #[test]
1604    fn test_constant_fold_modulo() {
1605        let expr = TLExpr::modulo(TLExpr::constant(10.0), TLExpr::constant(3.0));
1606        let folded = constant_fold(&expr);
1607        assert_eq!(folded, TLExpr::Constant(1.0));
1608    }
1609
1610    // ===== Advanced Algebraic Simplification Tests =====
1611
1612    // Logical Laws - AND
1613    #[test]
1614    fn test_and_idempotence() {
1615        // A ∧ A = A
1616        let p = TLExpr::pred("P", vec![]);
1617        let expr = TLExpr::and(p.clone(), p.clone());
1618        let simplified = algebraic_simplify(&expr);
1619        assert_eq!(simplified, p);
1620    }
1621
1622    #[test]
1623    fn test_and_identity() {
1624        // A ∧ TRUE = A
1625        let p = TLExpr::pred("P", vec![]);
1626        let expr = TLExpr::and(p.clone(), TLExpr::constant(1.0));
1627        let simplified = algebraic_simplify(&expr);
1628        assert_eq!(simplified, p);
1629    }
1630
1631    #[test]
1632    fn test_and_annihilation() {
1633        // A ∧ FALSE = FALSE
1634        let p = TLExpr::pred("P", vec![]);
1635        let expr = TLExpr::and(p, TLExpr::constant(0.0));
1636        let simplified = algebraic_simplify(&expr);
1637        assert_eq!(simplified, TLExpr::Constant(0.0));
1638    }
1639
1640    #[test]
1641    fn test_and_complement() {
1642        // A ∧ ¬A = FALSE
1643        let p = TLExpr::pred("P", vec![]);
1644        let expr = TLExpr::and(p.clone(), TLExpr::negate(p));
1645        let simplified = algebraic_simplify(&expr);
1646        assert_eq!(simplified, TLExpr::Constant(0.0));
1647    }
1648
1649    #[test]
1650    fn test_and_absorption() {
1651        // A ∧ (A ∨ B) = A
1652        let p = TLExpr::pred("P", vec![]);
1653        let q = TLExpr::pred("Q", vec![]);
1654        let expr = TLExpr::and(p.clone(), TLExpr::or(p.clone(), q));
1655        let simplified = algebraic_simplify(&expr);
1656        assert_eq!(simplified, p);
1657    }
1658
1659    // Logical Laws - OR
1660    #[test]
1661    fn test_or_idempotence() {
1662        // A ∨ A = A
1663        let p = TLExpr::pred("P", vec![]);
1664        let expr = TLExpr::or(p.clone(), p.clone());
1665        let simplified = algebraic_simplify(&expr);
1666        assert_eq!(simplified, p);
1667    }
1668
1669    #[test]
1670    fn test_or_identity() {
1671        // A ∨ FALSE = A
1672        let p = TLExpr::pred("P", vec![]);
1673        let expr = TLExpr::or(p.clone(), TLExpr::constant(0.0));
1674        let simplified = algebraic_simplify(&expr);
1675        assert_eq!(simplified, p);
1676    }
1677
1678    #[test]
1679    fn test_or_annihilation() {
1680        // A ∨ TRUE = TRUE
1681        let p = TLExpr::pred("P", vec![]);
1682        let expr = TLExpr::or(p, TLExpr::constant(1.0));
1683        let simplified = algebraic_simplify(&expr);
1684        assert_eq!(simplified, TLExpr::Constant(1.0));
1685    }
1686
1687    #[test]
1688    fn test_or_complement() {
1689        // A ∨ ¬A = TRUE
1690        let p = TLExpr::pred("P", vec![]);
1691        let expr = TLExpr::or(p.clone(), TLExpr::negate(p));
1692        let simplified = algebraic_simplify(&expr);
1693        assert_eq!(simplified, TLExpr::Constant(1.0));
1694    }
1695
1696    #[test]
1697    fn test_or_absorption() {
1698        // A ∨ (A ∧ B) = A
1699        let p = TLExpr::pred("P", vec![]);
1700        let q = TLExpr::pred("Q", vec![]);
1701        let expr = TLExpr::or(p.clone(), TLExpr::and(p.clone(), q));
1702        let simplified = algebraic_simplify(&expr);
1703        assert_eq!(simplified, p);
1704    }
1705
1706    // Implication Laws
1707    #[test]
1708    fn test_imply_true_antecedent() {
1709        // TRUE → P = P
1710        let p = TLExpr::pred("P", vec![]);
1711        let expr = TLExpr::imply(TLExpr::constant(1.0), p.clone());
1712        let simplified = algebraic_simplify(&expr);
1713        assert_eq!(simplified, p);
1714    }
1715
1716    #[test]
1717    fn test_imply_false_antecedent() {
1718        // FALSE → P = TRUE
1719        let p = TLExpr::pred("P", vec![]);
1720        let expr = TLExpr::imply(TLExpr::constant(0.0), p);
1721        let simplified = algebraic_simplify(&expr);
1722        assert_eq!(simplified, TLExpr::Constant(1.0));
1723    }
1724
1725    #[test]
1726    fn test_imply_true_consequent() {
1727        // P → TRUE = TRUE
1728        let p = TLExpr::pred("P", vec![]);
1729        let expr = TLExpr::imply(p, TLExpr::constant(1.0));
1730        let simplified = algebraic_simplify(&expr);
1731        assert_eq!(simplified, TLExpr::Constant(1.0));
1732    }
1733
1734    #[test]
1735    fn test_imply_false_consequent() {
1736        // P → FALSE = ¬P
1737        let p = TLExpr::pred("P", vec![]);
1738        let expr = TLExpr::imply(p.clone(), TLExpr::constant(0.0));
1739        let simplified = algebraic_simplify(&expr);
1740        assert_eq!(simplified, TLExpr::negate(p));
1741    }
1742
1743    #[test]
1744    fn test_imply_reflexive() {
1745        // P → P = TRUE
1746        let p = TLExpr::pred("P", vec![]);
1747        let expr = TLExpr::imply(p.clone(), p);
1748        let simplified = algebraic_simplify(&expr);
1749        assert_eq!(simplified, TLExpr::Constant(1.0));
1750    }
1751
1752    // Comparison Simplifications
1753    #[test]
1754    fn test_eq_reflexive() {
1755        // x = x → TRUE
1756        let p = TLExpr::pred("P", vec![]);
1757        let expr = TLExpr::eq(p.clone(), p);
1758        let simplified = algebraic_simplify(&expr);
1759        assert_eq!(simplified, TLExpr::Constant(1.0));
1760    }
1761
1762    #[test]
1763    fn test_lt_irreflexive() {
1764        // x < x → FALSE
1765        let p = TLExpr::pred("P", vec![]);
1766        let expr = TLExpr::lt(p.clone(), p);
1767        let simplified = algebraic_simplify(&expr);
1768        assert_eq!(simplified, TLExpr::Constant(0.0));
1769    }
1770
1771    #[test]
1772    fn test_gt_irreflexive() {
1773        // x > x → FALSE
1774        let p = TLExpr::pred("P", vec![]);
1775        let expr = TLExpr::gt(p.clone(), p);
1776        let simplified = algebraic_simplify(&expr);
1777        assert_eq!(simplified, TLExpr::Constant(0.0));
1778    }
1779
1780    #[test]
1781    fn test_lte_reflexive() {
1782        // x <= x → TRUE
1783        let p = TLExpr::pred("P", vec![]);
1784        let expr = TLExpr::lte(p.clone(), p);
1785        let simplified = algebraic_simplify(&expr);
1786        assert_eq!(simplified, TLExpr::Constant(1.0));
1787    }
1788
1789    #[test]
1790    fn test_gte_reflexive() {
1791        // x >= x → TRUE
1792        let p = TLExpr::pred("P", vec![]);
1793        let expr = TLExpr::gte(p.clone(), p);
1794        let simplified = algebraic_simplify(&expr);
1795        assert_eq!(simplified, TLExpr::Constant(1.0));
1796    }
1797
1798    // Arithmetic Simplifications
1799    #[test]
1800    fn test_div_self() {
1801        // x / x = 1 (for constant x != 0)
1802        let expr = TLExpr::div(TLExpr::constant(5.0), TLExpr::constant(5.0));
1803        let simplified = algebraic_simplify(&expr);
1804        assert_eq!(simplified, TLExpr::Constant(1.0));
1805    }
1806
1807    // Modal Logic Simplifications
1808    #[test]
1809    fn test_box_constant() {
1810        // □(TRUE) = TRUE, □(FALSE) = FALSE
1811        let expr1 = TLExpr::modal_box(TLExpr::constant(1.0));
1812        let simplified1 = algebraic_simplify(&expr1);
1813        assert_eq!(simplified1, TLExpr::Constant(1.0));
1814
1815        let expr2 = TLExpr::modal_box(TLExpr::constant(0.0));
1816        let simplified2 = algebraic_simplify(&expr2);
1817        assert_eq!(simplified2, TLExpr::Constant(0.0));
1818    }
1819
1820    #[test]
1821    fn test_diamond_constant() {
1822        // ◇(TRUE) = TRUE, ◇(FALSE) = FALSE
1823        let expr1 = TLExpr::modal_diamond(TLExpr::constant(1.0));
1824        let simplified1 = algebraic_simplify(&expr1);
1825        assert_eq!(simplified1, TLExpr::Constant(1.0));
1826
1827        let expr2 = TLExpr::modal_diamond(TLExpr::constant(0.0));
1828        let simplified2 = algebraic_simplify(&expr2);
1829        assert_eq!(simplified2, TLExpr::Constant(0.0));
1830    }
1831
1832    // Temporal Logic Simplifications
1833    #[test]
1834    fn test_next_constant() {
1835        // X(TRUE) = TRUE, X(FALSE) = FALSE
1836        let expr1 = TLExpr::next(TLExpr::constant(1.0));
1837        let simplified1 = algebraic_simplify(&expr1);
1838        assert_eq!(simplified1, TLExpr::Constant(1.0));
1839
1840        let expr2 = TLExpr::next(TLExpr::constant(0.0));
1841        let simplified2 = algebraic_simplify(&expr2);
1842        assert_eq!(simplified2, TLExpr::Constant(0.0));
1843    }
1844
1845    #[test]
1846    fn test_eventually_constant() {
1847        // F(TRUE) = TRUE, F(FALSE) = FALSE
1848        let expr1 = TLExpr::eventually(TLExpr::constant(1.0));
1849        let simplified1 = algebraic_simplify(&expr1);
1850        assert_eq!(simplified1, TLExpr::Constant(1.0));
1851
1852        let expr2 = TLExpr::eventually(TLExpr::constant(0.0));
1853        let simplified2 = algebraic_simplify(&expr2);
1854        assert_eq!(simplified2, TLExpr::Constant(0.0));
1855    }
1856
1857    #[test]
1858    fn test_eventually_idempotence() {
1859        // F(F(P)) = F(P)
1860        let p = TLExpr::pred("P", vec![]);
1861        let expr = TLExpr::eventually(TLExpr::eventually(p.clone()));
1862        let simplified = algebraic_simplify(&expr);
1863        assert_eq!(simplified, TLExpr::eventually(p));
1864    }
1865
1866    #[test]
1867    fn test_always_constant() {
1868        // G(TRUE) = TRUE, G(FALSE) = FALSE
1869        let expr1 = TLExpr::always(TLExpr::constant(1.0));
1870        let simplified1 = algebraic_simplify(&expr1);
1871        assert_eq!(simplified1, TLExpr::Constant(1.0));
1872
1873        let expr2 = TLExpr::always(TLExpr::constant(0.0));
1874        let simplified2 = algebraic_simplify(&expr2);
1875        assert_eq!(simplified2, TLExpr::Constant(0.0));
1876    }
1877
1878    #[test]
1879    fn test_always_idempotence() {
1880        // G(G(P)) = G(P)
1881        let p = TLExpr::pred("P", vec![]);
1882        let expr = TLExpr::always(TLExpr::always(p.clone()));
1883        let simplified = algebraic_simplify(&expr);
1884        assert_eq!(simplified, TLExpr::always(p));
1885    }
1886
1887    #[test]
1888    fn test_until_true_consequent() {
1889        // P U TRUE = TRUE
1890        let p = TLExpr::pred("P", vec![]);
1891        let expr = TLExpr::until(p, TLExpr::constant(1.0));
1892        let simplified = algebraic_simplify(&expr);
1893        assert_eq!(simplified, TLExpr::Constant(1.0));
1894    }
1895
1896    #[test]
1897    fn test_until_false_antecedent() {
1898        // FALSE U P = F(P)
1899        let p = TLExpr::pred("P", vec![]);
1900        let expr = TLExpr::until(TLExpr::constant(0.0), p.clone());
1901        let simplified = algebraic_simplify(&expr);
1902        assert_eq!(simplified, TLExpr::eventually(p));
1903    }
1904
1905    // Combined Optimization Tests
1906    #[test]
1907    fn test_combined_logical_simplification() {
1908        // (P ∧ TRUE) ∨ FALSE should become P
1909        let p = TLExpr::pred("P", vec![]);
1910        let expr = TLExpr::or(
1911            TLExpr::and(p.clone(), TLExpr::constant(1.0)),
1912            TLExpr::constant(0.0),
1913        );
1914        let optimized = optimize_expr(&expr);
1915        assert_eq!(optimized, p);
1916    }
1917
1918    #[test]
1919    fn test_combined_implication_simplification() {
1920        // (P → TRUE) ∧ (FALSE → Q) should become TRUE
1921        let p = TLExpr::pred("P", vec![]);
1922        let q = TLExpr::pred("Q", vec![]);
1923        let expr = TLExpr::and(
1924            TLExpr::imply(p, TLExpr::constant(1.0)),
1925            TLExpr::imply(TLExpr::constant(0.0), q),
1926        );
1927        let optimized = optimize_expr(&expr);
1928        assert_eq!(optimized, TLExpr::Constant(1.0));
1929    }
1930}