Skip to main content

tensorlogic_ir/expr/optimization/
constant_folding.rs

//! Constant folding: evaluate constant expressions at compile time.
//!
//! This module implements constant folding optimizations that evaluate
//! expressions with constant operands at compile time, reducing runtime overhead.
5
6use crate::expr::TLExpr;
7
8/// Constant folding: evaluate constant expressions at compile time
9pub fn constant_fold(expr: &TLExpr) -> TLExpr {
10    match expr {
11        // Binary arithmetic operations on constants
12        TLExpr::Add(l, r) => {
13            let left = constant_fold(l);
14            let right = constant_fold(r);
15            if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
16                return TLExpr::Constant(lv + rv);
17            }
18            TLExpr::Add(Box::new(left), Box::new(right))
19        }
20        TLExpr::Sub(l, r) => {
21            let left = constant_fold(l);
22            let right = constant_fold(r);
23            if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
24                return TLExpr::Constant(lv - rv);
25            }
26            TLExpr::Sub(Box::new(left), Box::new(right))
27        }
28        TLExpr::Mul(l, r) => {
29            let left = constant_fold(l);
30            let right = constant_fold(r);
31            if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
32                return TLExpr::Constant(lv * rv);
33            }
34            TLExpr::Mul(Box::new(left), Box::new(right))
35        }
36        TLExpr::Div(l, r) => {
37            let left = constant_fold(l);
38            let right = constant_fold(r);
39            if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
40                if *rv != 0.0 {
41                    return TLExpr::Constant(lv / rv);
42                }
43            }
44            TLExpr::Div(Box::new(left), Box::new(right))
45        }
46        TLExpr::Pow(l, r) => {
47            let left = constant_fold(l);
48            let right = constant_fold(r);
49            if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
50                return TLExpr::Constant(lv.powf(*rv));
51            }
52            TLExpr::Pow(Box::new(left), Box::new(right))
53        }
54        TLExpr::Mod(l, r) => {
55            let left = constant_fold(l);
56            let right = constant_fold(r);
57            if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
58                return TLExpr::Constant(lv % rv);
59            }
60            TLExpr::Mod(Box::new(left), Box::new(right))
61        }
62        TLExpr::Min(l, r) => {
63            let left = constant_fold(l);
64            let right = constant_fold(r);
65            if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
66                return TLExpr::Constant(lv.min(*rv));
67            }
68            TLExpr::Min(Box::new(left), Box::new(right))
69        }
70        TLExpr::Max(l, r) => {
71            let left = constant_fold(l);
72            let right = constant_fold(r);
73            if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
74                return TLExpr::Constant(lv.max(*rv));
75            }
76            TLExpr::Max(Box::new(left), Box::new(right))
77        }
78
79        // Unary mathematical operations on constants
80        TLExpr::Abs(e) => {
81            let inner = constant_fold(e);
82            if let TLExpr::Constant(v) = &inner {
83                return TLExpr::Constant(v.abs());
84            }
85            TLExpr::Abs(Box::new(inner))
86        }
87        TLExpr::Floor(e) => {
88            let inner = constant_fold(e);
89            if let TLExpr::Constant(v) = &inner {
90                return TLExpr::Constant(v.floor());
91            }
92            TLExpr::Floor(Box::new(inner))
93        }
94        TLExpr::Ceil(e) => {
95            let inner = constant_fold(e);
96            if let TLExpr::Constant(v) = &inner {
97                return TLExpr::Constant(v.ceil());
98            }
99            TLExpr::Ceil(Box::new(inner))
100        }
101        TLExpr::Round(e) => {
102            let inner = constant_fold(e);
103            if let TLExpr::Constant(v) = &inner {
104                return TLExpr::Constant(v.round());
105            }
106            TLExpr::Round(Box::new(inner))
107        }
108        TLExpr::Sqrt(e) => {
109            let inner = constant_fold(e);
110            if let TLExpr::Constant(v) = &inner {
111                if *v >= 0.0 {
112                    return TLExpr::Constant(v.sqrt());
113                }
114            }
115            TLExpr::Sqrt(Box::new(inner))
116        }
117        TLExpr::Exp(e) => {
118            let inner = constant_fold(e);
119            if let TLExpr::Constant(v) = &inner {
120                return TLExpr::Constant(v.exp());
121            }
122            TLExpr::Exp(Box::new(inner))
123        }
124        TLExpr::Log(e) => {
125            let inner = constant_fold(e);
126            if let TLExpr::Constant(v) = &inner {
127                if *v > 0.0 {
128                    return TLExpr::Constant(v.ln());
129                }
130            }
131            TLExpr::Log(Box::new(inner))
132        }
133        TLExpr::Sin(e) => {
134            let inner = constant_fold(e);
135            if let TLExpr::Constant(v) = &inner {
136                return TLExpr::Constant(v.sin());
137            }
138            TLExpr::Sin(Box::new(inner))
139        }
140        TLExpr::Cos(e) => {
141            let inner = constant_fold(e);
142            if let TLExpr::Constant(v) = &inner {
143                return TLExpr::Constant(v.cos());
144            }
145            TLExpr::Cos(Box::new(inner))
146        }
147        TLExpr::Tan(e) => {
148            let inner = constant_fold(e);
149            if let TLExpr::Constant(v) = &inner {
150                return TLExpr::Constant(v.tan());
151            }
152            TLExpr::Tan(Box::new(inner))
153        }
154
155        // Comparison operations on constants
156        TLExpr::Eq(l, r) => {
157            let left = constant_fold(l);
158            let right = constant_fold(r);
159            if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
160                return TLExpr::Constant(if (lv - rv).abs() < f64::EPSILON {
161                    1.0
162                } else {
163                    0.0
164                });
165            }
166            TLExpr::Eq(Box::new(left), Box::new(right))
167        }
168        TLExpr::Lt(l, r) => {
169            let left = constant_fold(l);
170            let right = constant_fold(r);
171            if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
172                return TLExpr::Constant(if lv < rv { 1.0 } else { 0.0 });
173            }
174            TLExpr::Lt(Box::new(left), Box::new(right))
175        }
176        TLExpr::Gt(l, r) => {
177            let left = constant_fold(l);
178            let right = constant_fold(r);
179            if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
180                return TLExpr::Constant(if lv > rv { 1.0 } else { 0.0 });
181            }
182            TLExpr::Gt(Box::new(left), Box::new(right))
183        }
184        TLExpr::Lte(l, r) => {
185            let left = constant_fold(l);
186            let right = constant_fold(r);
187            if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
188                return TLExpr::Constant(if lv <= rv { 1.0 } else { 0.0 });
189            }
190            TLExpr::Lte(Box::new(left), Box::new(right))
191        }
192        TLExpr::Gte(l, r) => {
193            let left = constant_fold(l);
194            let right = constant_fold(r);
195            if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
196                return TLExpr::Constant(if lv >= rv { 1.0 } else { 0.0 });
197            }
198            TLExpr::Gte(Box::new(left), Box::new(right))
199        }
200
201        // Logical connectives - recursively fold subexpressions
202        TLExpr::And(l, r) => TLExpr::And(Box::new(constant_fold(l)), Box::new(constant_fold(r))),
203        TLExpr::Or(l, r) => TLExpr::Or(Box::new(constant_fold(l)), Box::new(constant_fold(r))),
204        TLExpr::Not(e) => TLExpr::Not(Box::new(constant_fold(e))),
205        TLExpr::Imply(l, r) => {
206            TLExpr::Imply(Box::new(constant_fold(l)), Box::new(constant_fold(r)))
207        }
208
209        // Quantifiers - fold the body
210        TLExpr::Exists { var, domain, body } => TLExpr::Exists {
211            var: var.clone(),
212            domain: domain.clone(),
213            body: Box::new(constant_fold(body)),
214        },
215        TLExpr::ForAll { var, domain, body } => TLExpr::ForAll {
216            var: var.clone(),
217            domain: domain.clone(),
218            body: Box::new(constant_fold(body)),
219        },
220
221        // Score operator
222        TLExpr::Score(e) => TLExpr::Score(Box::new(constant_fold(e))),
223
224        // Aggregation
225        TLExpr::Aggregate {
226            op,
227            var,
228            domain,
229            body,
230            group_by,
231        } => TLExpr::Aggregate {
232            op: op.clone(),
233            var: var.clone(),
234            domain: domain.clone(),
235            body: Box::new(constant_fold(body)),
236            group_by: group_by.clone(),
237        },
238
239        // Modal logic operators
240        TLExpr::Box(e) => TLExpr::Box(Box::new(constant_fold(e))),
241        TLExpr::Diamond(e) => TLExpr::Diamond(Box::new(constant_fold(e))),
242
243        // Temporal logic operators
244        TLExpr::Next(e) => TLExpr::Next(Box::new(constant_fold(e))),
245        TLExpr::Eventually(e) => TLExpr::Eventually(Box::new(constant_fold(e))),
246        TLExpr::Always(e) => TLExpr::Always(Box::new(constant_fold(e))),
247        TLExpr::Until { before, after } => TLExpr::Until {
248            before: Box::new(constant_fold(before)),
249            after: Box::new(constant_fold(after)),
250        },
251
252        // Fuzzy logic operators
253        TLExpr::TNorm { kind, left, right } => TLExpr::TNorm {
254            kind: *kind,
255            left: Box::new(constant_fold(left)),
256            right: Box::new(constant_fold(right)),
257        },
258        TLExpr::TCoNorm { kind, left, right } => TLExpr::TCoNorm {
259            kind: *kind,
260            left: Box::new(constant_fold(left)),
261            right: Box::new(constant_fold(right)),
262        },
263        TLExpr::FuzzyNot { kind, expr } => TLExpr::FuzzyNot {
264            kind: *kind,
265            expr: Box::new(constant_fold(expr)),
266        },
267        TLExpr::FuzzyImplication {
268            kind,
269            premise,
270            conclusion,
271        } => TLExpr::FuzzyImplication {
272            kind: *kind,
273            premise: Box::new(constant_fold(premise)),
274            conclusion: Box::new(constant_fold(conclusion)),
275        },
276
277        // Probabilistic operators
278        TLExpr::SoftExists {
279            var,
280            domain,
281            body,
282            temperature,
283        } => TLExpr::SoftExists {
284            var: var.clone(),
285            domain: domain.clone(),
286            body: Box::new(constant_fold(body)),
287            temperature: *temperature,
288        },
289        TLExpr::SoftForAll {
290            var,
291            domain,
292            body,
293            temperature,
294        } => TLExpr::SoftForAll {
295            var: var.clone(),
296            domain: domain.clone(),
297            body: Box::new(constant_fold(body)),
298            temperature: *temperature,
299        },
300        TLExpr::WeightedRule { weight, rule } => TLExpr::WeightedRule {
301            weight: *weight,
302            rule: Box::new(constant_fold(rule)),
303        },
304        TLExpr::ProbabilisticChoice { alternatives } => TLExpr::ProbabilisticChoice {
305            alternatives: alternatives
306                .iter()
307                .map(|(p, e)| (*p, constant_fold(e)))
308                .collect(),
309        },
310
311        // Extended temporal logic
312        TLExpr::Release { released, releaser } => TLExpr::Release {
313            released: Box::new(constant_fold(released)),
314            releaser: Box::new(constant_fold(releaser)),
315        },
316        TLExpr::WeakUntil { before, after } => TLExpr::WeakUntil {
317            before: Box::new(constant_fold(before)),
318            after: Box::new(constant_fold(after)),
319        },
320        TLExpr::StrongRelease { released, releaser } => TLExpr::StrongRelease {
321            released: Box::new(constant_fold(released)),
322            releaser: Box::new(constant_fold(releaser)),
323        },
324
325        // Conditional expressions
326        TLExpr::IfThenElse {
327            condition,
328            then_branch,
329            else_branch,
330        } => TLExpr::IfThenElse {
331            condition: Box::new(constant_fold(condition)),
332            then_branch: Box::new(constant_fold(then_branch)),
333            else_branch: Box::new(constant_fold(else_branch)),
334        },
335        TLExpr::Let { var, value, body } => TLExpr::Let {
336            var: var.clone(),
337            value: Box::new(constant_fold(value)),
338            body: Box::new(constant_fold(body)),
339        },
340
341        // Alpha.3 enhancements: recurse into subexpressions (minimal optimization for now)
342        TLExpr::Lambda {
343            var,
344            var_type,
345            body,
346        } => TLExpr::lambda(var.clone(), var_type.clone(), constant_fold(body)),
347        TLExpr::Apply { function, argument } => {
348            TLExpr::apply(constant_fold(function), constant_fold(argument))
349        }
350        TLExpr::SetMembership { element, set } => {
351            TLExpr::set_membership(constant_fold(element), constant_fold(set))
352        }
353        TLExpr::SetUnion { left, right } => {
354            TLExpr::set_union(constant_fold(left), constant_fold(right))
355        }
356        TLExpr::SetIntersection { left, right } => {
357            TLExpr::set_intersection(constant_fold(left), constant_fold(right))
358        }
359        TLExpr::SetDifference { left, right } => {
360            TLExpr::set_difference(constant_fold(left), constant_fold(right))
361        }
362        TLExpr::SetCardinality { set } => TLExpr::set_cardinality(constant_fold(set)),
363        TLExpr::EmptySet => expr.clone(),
364        TLExpr::SetComprehension {
365            var,
366            domain,
367            condition,
368        } => TLExpr::set_comprehension(var.clone(), domain.clone(), constant_fold(condition)),
369        TLExpr::CountingExists {
370            var,
371            domain,
372            body,
373            min_count,
374        } => TLExpr::counting_exists(var.clone(), domain.clone(), constant_fold(body), *min_count),
375        TLExpr::CountingForAll {
376            var,
377            domain,
378            body,
379            min_count,
380        } => TLExpr::counting_forall(var.clone(), domain.clone(), constant_fold(body), *min_count),
381        TLExpr::ExactCount {
382            var,
383            domain,
384            body,
385            count,
386        } => TLExpr::exact_count(var.clone(), domain.clone(), constant_fold(body), *count),
387        TLExpr::Majority { var, domain, body } => {
388            TLExpr::majority(var.clone(), domain.clone(), constant_fold(body))
389        }
390        TLExpr::LeastFixpoint { var, body } => {
391            TLExpr::least_fixpoint(var.clone(), constant_fold(body))
392        }
393        TLExpr::GreatestFixpoint { var, body } => {
394            TLExpr::greatest_fixpoint(var.clone(), constant_fold(body))
395        }
396        TLExpr::Nominal { .. } => expr.clone(),
397        TLExpr::At { nominal, formula } => TLExpr::at(nominal.clone(), constant_fold(formula)),
398        TLExpr::Somewhere { formula } => TLExpr::somewhere(constant_fold(formula)),
399        TLExpr::Everywhere { formula } => TLExpr::everywhere(constant_fold(formula)),
400        TLExpr::AllDifferent { .. } => expr.clone(),
401        TLExpr::GlobalCardinality {
402            variables,
403            values,
404            min_occurrences,
405            max_occurrences,
406        } => TLExpr::global_cardinality(
407            variables.clone(),
408            values.iter().map(constant_fold).collect(),
409            min_occurrences.clone(),
410            max_occurrences.clone(),
411        ),
412        TLExpr::Abducible { .. } => expr.clone(),
413        TLExpr::Explain { formula } => TLExpr::explain(constant_fold(formula)),
414
415        // Leaves - no folding needed
416        TLExpr::Pred { .. } | TLExpr::Constant(_) => expr.clone(),
417    }
418}
419
#[cfg(test)]
mod tests {
    //! Unit tests for [`constant_fold`]: successful folding of constant
    //! arithmetic and the cases that must be left unfolded.

    use super::*;

    /// Two constant addends collapse into their sum.
    #[test]
    fn test_constant_fold_addition() {
        let expr = TLExpr::Add(
            Box::new(TLExpr::Constant(2.0)),
            Box::new(TLExpr::Constant(3.0)),
        );
        let folded = constant_fold(&expr);
        assert_eq!(folded, TLExpr::Constant(5.0));
    }

    /// Two constant factors collapse into their product.
    #[test]
    fn test_constant_fold_multiplication() {
        let expr = TLExpr::Mul(
            Box::new(TLExpr::Constant(4.0)),
            Box::new(TLExpr::Constant(5.0)),
        );
        let folded = constant_fold(&expr);
        assert_eq!(folded, TLExpr::Constant(20.0));
    }

    /// Folding works bottom-up through nested constant expressions.
    #[test]
    fn test_constant_fold_nested() {
        // (2 + 3) * 4 = 20
        let expr = TLExpr::Mul(
            Box::new(TLExpr::Add(
                Box::new(TLExpr::Constant(2.0)),
                Box::new(TLExpr::Constant(3.0)),
            )),
            Box::new(TLExpr::Constant(4.0)),
        );
        let folded = constant_fold(&expr);
        assert_eq!(folded, TLExpr::Constant(20.0));
    }

    /// Division by a constant zero stays a `Div` node.
    #[test]
    fn test_constant_fold_division_zero() {
        let expr = TLExpr::Div(
            Box::new(TLExpr::Constant(5.0)),
            Box::new(TLExpr::Constant(0.0)),
        );
        let folded = constant_fold(&expr);
        // Should not fold division by zero. The `matches!` result must be
        // asserted; as a bare statement it is discarded and checks nothing.
        assert!(
            matches!(folded, TLExpr::Div(_, _)),
            "division by zero must not be folded, got {:?}",
            folded
        );
    }

    /// Square root of a negative constant stays a `Sqrt` node.
    #[test]
    fn test_constant_fold_sqrt_negative() {
        let expr = TLExpr::Sqrt(Box::new(TLExpr::Constant(-4.0)));
        let folded = constant_fold(&expr);
        // Should not fold sqrt of negative. The `matches!` result must be
        // asserted; as a bare statement it is discarded and checks nothing.
        assert!(
            matches!(folded, TLExpr::Sqrt(_)),
            "sqrt of a negative constant must not be folded, got {:?}",
            folded
        );
    }
}
476}