Skip to main content

tensorlogic_compiler/optimize/
constant_folding.rs

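//! Constant folding optimization pass.
//!
//! This module implements compile-time evaluation of constant expressions,
//! reducing them to single constant values where possible. Arithmetic and
//! unary math nodes whose operands fold to `TLExpr::Constant` are replaced by
//! their computed value; other recognised nodes are rebuilt with their
//! children folded, and unrecognised nodes are passed through unchanged.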
5
6use tensorlogic_ir::TLExpr;
7
/// Statistics from constant folding optimization.
///
/// All counters start at zero (via [`Default`]) and are accumulated over a
/// single run of the folding pass.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ConstantFoldingStats {
    /// Number of binary operations folded into a constant
    pub binary_ops_folded: usize,
    /// Number of unary operations folded into a constant
    pub unary_ops_folded: usize,
    /// Number of expression nodes visited by the pass, folded or not
    pub total_processed: usize,
}
18
19/// Optimize an expression by folding constant subexpressions.
20///
21/// This pass evaluates constant expressions at compile time, replacing them
22/// with their computed values. This can significantly reduce runtime computation
23/// for expressions involving constants.
24///
25/// # Examples
26///
27/// ```
28/// use tensorlogic_compiler::optimize::fold_constants;
29/// use tensorlogic_ir::TLExpr;
30///
31/// // 2.0 + 3.0 => 5.0
32/// let expr = TLExpr::Add(
33///     Box::new(TLExpr::Constant(2.0)),
34///     Box::new(TLExpr::Constant(3.0)),
35/// );
36///
37/// let (optimized, stats) = fold_constants(&expr);
38/// assert!(matches!(optimized, TLExpr::Constant(5.0)));
39/// assert_eq!(stats.binary_ops_folded, 1);
40/// ```
41pub fn fold_constants(expr: &TLExpr) -> (TLExpr, ConstantFoldingStats) {
42    let mut stats = ConstantFoldingStats::default();
43    let result = fold_constants_impl(expr, &mut stats);
44    (result, stats)
45}
46
47fn fold_constants_impl(expr: &TLExpr, stats: &mut ConstantFoldingStats) -> TLExpr {
48    stats.total_processed += 1;
49
50    match expr {
51        #[allow(unreachable_patterns)] // Binary arithmetic operations
52        TLExpr::Add(left, right) => fold_binary_op(
53            left,
54            right,
55            stats,
56            |a, b| a + b,
57            |l, r| TLExpr::Add(Box::new(l), Box::new(r)),
58        ),
59        TLExpr::Sub(left, right) => fold_binary_op(
60            left,
61            right,
62            stats,
63            |a, b| a - b,
64            |l, r| TLExpr::Sub(Box::new(l), Box::new(r)),
65        ),
66        TLExpr::Mul(left, right) => fold_binary_op(
67            left,
68            right,
69            stats,
70            |a, b| a * b,
71            |l, r| TLExpr::Mul(Box::new(l), Box::new(r)),
72        ),
73        TLExpr::Div(left, right) => fold_binary_op(
74            left,
75            right,
76            stats,
77            |a, b| {
78                if b.abs() < f64::EPSILON {
79                    f64::NAN // Division by zero
80                } else {
81                    a / b
82                }
83            },
84            |l, r| TLExpr::Div(Box::new(l), Box::new(r)),
85        ),
86        TLExpr::Pow(left, right) => fold_binary_op(
87            left,
88            right,
89            stats,
90            |a, b| a.powf(b),
91            |l, r| TLExpr::Pow(Box::new(l), Box::new(r)),
92        ),
93        TLExpr::Mod(left, right) => fold_binary_op(
94            left,
95            right,
96            stats,
97            |a, b| a % b,
98            |l, r| TLExpr::Mod(Box::new(l), Box::new(r)),
99        ),
100        TLExpr::Min(left, right) => fold_binary_op(
101            left,
102            right,
103            stats,
104            |a, b| a.min(b),
105            |l, r| TLExpr::Min(Box::new(l), Box::new(r)),
106        ),
107        TLExpr::Max(left, right) => fold_binary_op(
108            left,
109            right,
110            stats,
111            |a, b| a.max(b),
112            |l, r| TLExpr::Max(Box::new(l), Box::new(r)),
113        ),
114
115        // Unary mathematical operations
116        TLExpr::Abs(inner) => {
117            fold_unary_op(inner, stats, |x| x.abs(), |i| TLExpr::Abs(Box::new(i)))
118        }
119        TLExpr::Floor(inner) => {
120            fold_unary_op(inner, stats, |x| x.floor(), |i| TLExpr::Floor(Box::new(i)))
121        }
122        TLExpr::Ceil(inner) => {
123            fold_unary_op(inner, stats, |x| x.ceil(), |i| TLExpr::Ceil(Box::new(i)))
124        }
125        TLExpr::Round(inner) => {
126            fold_unary_op(inner, stats, |x| x.round(), |i| TLExpr::Round(Box::new(i)))
127        }
128        TLExpr::Sqrt(inner) => {
129            fold_unary_op(inner, stats, |x| x.sqrt(), |i| TLExpr::Sqrt(Box::new(i)))
130        }
131        TLExpr::Exp(inner) => {
132            fold_unary_op(inner, stats, |x| x.exp(), |i| TLExpr::Exp(Box::new(i)))
133        }
134        TLExpr::Log(inner) => fold_unary_op(inner, stats, |x| x.ln(), |i| TLExpr::Log(Box::new(i))),
135        TLExpr::Sin(inner) => {
136            fold_unary_op(inner, stats, |x| x.sin(), |i| TLExpr::Sin(Box::new(i)))
137        }
138        TLExpr::Cos(inner) => {
139            fold_unary_op(inner, stats, |x| x.cos(), |i| TLExpr::Cos(Box::new(i)))
140        }
141        TLExpr::Tan(inner) => {
142            fold_unary_op(inner, stats, |x| x.tan(), |i| TLExpr::Tan(Box::new(i)))
143        }
144
145        // Logical operations (can't fold without knowing tensor values)
146        TLExpr::And(left, right) => {
147            let left_opt = fold_constants_impl(left, stats);
148            let right_opt = fold_constants_impl(right, stats);
149            TLExpr::And(Box::new(left_opt), Box::new(right_opt))
150        }
151        TLExpr::Or(left, right) => {
152            let left_opt = fold_constants_impl(left, stats);
153            let right_opt = fold_constants_impl(right, stats);
154            TLExpr::Or(Box::new(left_opt), Box::new(right_opt))
155        }
156        TLExpr::Not(inner) => {
157            let inner_opt = fold_constants_impl(inner, stats);
158            TLExpr::Not(Box::new(inner_opt))
159        }
160        TLExpr::Imply(left, right) => {
161            let left_opt = fold_constants_impl(left, stats);
162            let right_opt = fold_constants_impl(right, stats);
163            TLExpr::Imply(Box::new(left_opt), Box::new(right_opt))
164        }
165
166        // Comparison operations
167        TLExpr::Eq(left, right) => {
168            let left_opt = fold_constants_impl(left, stats);
169            let right_opt = fold_constants_impl(right, stats);
170            TLExpr::Eq(Box::new(left_opt), Box::new(right_opt))
171        }
172        TLExpr::Lt(left, right) => {
173            let left_opt = fold_constants_impl(left, stats);
174            let right_opt = fold_constants_impl(right, stats);
175            TLExpr::Lt(Box::new(left_opt), Box::new(right_opt))
176        }
177        TLExpr::Gt(left, right) => {
178            let left_opt = fold_constants_impl(left, stats);
179            let right_opt = fold_constants_impl(right, stats);
180            TLExpr::Gt(Box::new(left_opt), Box::new(right_opt))
181        }
182        TLExpr::Lte(left, right) => {
183            let left_opt = fold_constants_impl(left, stats);
184            let right_opt = fold_constants_impl(right, stats);
185            TLExpr::Lte(Box::new(left_opt), Box::new(right_opt))
186        }
187        TLExpr::Gte(left, right) => {
188            let left_opt = fold_constants_impl(left, stats);
189            let right_opt = fold_constants_impl(right, stats);
190            TLExpr::Gte(Box::new(left_opt), Box::new(right_opt))
191        }
192
193        // Quantifiers and other constructs
194        TLExpr::Exists { var, domain, body } => {
195            let body_opt = fold_constants_impl(body, stats);
196            TLExpr::Exists {
197                var: var.clone(),
198                domain: domain.clone(),
199                body: Box::new(body_opt),
200            }
201        }
202        TLExpr::ForAll { var, domain, body } => {
203            let body_opt = fold_constants_impl(body, stats);
204            TLExpr::ForAll {
205                var: var.clone(),
206                domain: domain.clone(),
207                body: Box::new(body_opt),
208            }
209        }
210        TLExpr::Aggregate {
211            op,
212            var,
213            domain,
214            body,
215            group_by,
216        } => {
217            let body_opt = fold_constants_impl(body, stats);
218            TLExpr::Aggregate {
219                op: op.clone(),
220                var: var.clone(),
221                domain: domain.clone(),
222                body: Box::new(body_opt),
223                group_by: group_by.clone(),
224            }
225        }
226        TLExpr::IfThenElse {
227            condition,
228            then_branch,
229            else_branch,
230        } => {
231            let cond_opt = fold_constants_impl(condition, stats);
232            let then_opt = fold_constants_impl(then_branch, stats);
233            let else_opt = fold_constants_impl(else_branch, stats);
234            TLExpr::IfThenElse {
235                condition: Box::new(cond_opt),
236                then_branch: Box::new(then_opt),
237                else_branch: Box::new(else_opt),
238            }
239        }
240        TLExpr::Let { var, value, body } => {
241            let value_opt = fold_constants_impl(value, stats);
242            let body_opt = fold_constants_impl(body, stats);
243            TLExpr::Let {
244                var: var.clone(),
245                value: Box::new(value_opt),
246                body: Box::new(body_opt),
247            }
248        }
249
250        // Fuzzy logic operators
251        TLExpr::TNorm { kind, left, right } => {
252            let left_opt = fold_constants_impl(left, stats);
253            let right_opt = fold_constants_impl(right, stats);
254            TLExpr::TNorm {
255                kind: *kind,
256                left: Box::new(left_opt),
257                right: Box::new(right_opt),
258            }
259        }
260        TLExpr::TCoNorm { kind, left, right } => {
261            let left_opt = fold_constants_impl(left, stats);
262            let right_opt = fold_constants_impl(right, stats);
263            TLExpr::TCoNorm {
264                kind: *kind,
265                left: Box::new(left_opt),
266                right: Box::new(right_opt),
267            }
268        }
269        TLExpr::FuzzyNot { kind, expr: inner } => {
270            let inner_opt = fold_constants_impl(inner, stats);
271            TLExpr::FuzzyNot {
272                kind: *kind,
273                expr: Box::new(inner_opt),
274            }
275        }
276        TLExpr::FuzzyImplication {
277            kind,
278            premise,
279            conclusion,
280        } => {
281            let premise_opt = fold_constants_impl(premise, stats);
282            let conclusion_opt = fold_constants_impl(conclusion, stats);
283            TLExpr::FuzzyImplication {
284                kind: *kind,
285                premise: Box::new(premise_opt),
286                conclusion: Box::new(conclusion_opt),
287            }
288        }
289        TLExpr::SoftExists {
290            var,
291            domain,
292            body,
293            temperature,
294        } => {
295            let body_opt = fold_constants_impl(body, stats);
296            TLExpr::SoftExists {
297                var: var.clone(),
298                domain: domain.clone(),
299                body: Box::new(body_opt),
300                temperature: *temperature,
301            }
302        }
303        TLExpr::SoftForAll {
304            var,
305            domain,
306            body,
307            temperature,
308        } => {
309            let body_opt = fold_constants_impl(body, stats);
310            TLExpr::SoftForAll {
311                var: var.clone(),
312                domain: domain.clone(),
313                body: Box::new(body_opt),
314                temperature: *temperature,
315            }
316        }
317        TLExpr::WeightedRule { weight, rule } => {
318            let rule_opt = fold_constants_impl(rule, stats);
319            TLExpr::WeightedRule {
320                weight: *weight,
321                rule: Box::new(rule_opt),
322            }
323        }
324        TLExpr::ProbabilisticChoice { alternatives } => {
325            let alts_opt: Vec<_> = alternatives
326                .iter()
327                .map(|(w, e)| (*w, fold_constants_impl(e, stats)))
328                .collect();
329            TLExpr::ProbabilisticChoice {
330                alternatives: alts_opt,
331            }
332        }
333
334        // Modal/temporal logic operators - not yet implemented, pass through with recursion
335        TLExpr::Box(inner) => TLExpr::Box(Box::new(fold_constants_impl(inner, stats))),
336        TLExpr::Diamond(inner) => TLExpr::Diamond(Box::new(fold_constants_impl(inner, stats))),
337        TLExpr::Next(inner) => TLExpr::Next(Box::new(fold_constants_impl(inner, stats))),
338        TLExpr::Eventually(inner) => {
339            TLExpr::Eventually(Box::new(fold_constants_impl(inner, stats)))
340        }
341        TLExpr::Always(inner) => TLExpr::Always(Box::new(fold_constants_impl(inner, stats))),
342        TLExpr::Until { before, after } => TLExpr::Until {
343            before: Box::new(fold_constants_impl(before, stats)),
344            after: Box::new(fold_constants_impl(after, stats)),
345        },
346        TLExpr::Release { released, releaser } => TLExpr::Release {
347            released: Box::new(fold_constants_impl(released, stats)),
348            releaser: Box::new(fold_constants_impl(releaser, stats)),
349        },
350        TLExpr::WeakUntil { before, after } => TLExpr::WeakUntil {
351            before: Box::new(fold_constants_impl(before, stats)),
352            after: Box::new(fold_constants_impl(after, stats)),
353        },
354        TLExpr::StrongRelease { released, releaser } => TLExpr::StrongRelease {
355            released: Box::new(fold_constants_impl(released, stats)),
356            releaser: Box::new(fold_constants_impl(releaser, stats)),
357        },
358
359        // Base cases
360        TLExpr::Pred { .. } | TLExpr::Constant(_) | TLExpr::Score(_) => expr.clone(),
361        // All other expression types (enhancements) - no constant folding
362        _ => expr.clone(),
363    }
364}
365
366/// Helper function to fold binary operations on constants
367fn fold_binary_op<F, C>(
368    left: &TLExpr,
369    right: &TLExpr,
370    stats: &mut ConstantFoldingStats,
371    op: F,
372    constructor: C,
373) -> TLExpr
374where
375    F: Fn(f64, f64) -> f64,
376    C: Fn(TLExpr, TLExpr) -> TLExpr,
377{
378    let left_opt = fold_constants_impl(left, stats);
379    let right_opt = fold_constants_impl(right, stats);
380
381    if let (TLExpr::Constant(a), TLExpr::Constant(b)) = (&left_opt, &right_opt) {
382        stats.binary_ops_folded += 1;
383        TLExpr::Constant(op(*a, *b))
384    } else {
385        constructor(left_opt, right_opt)
386    }
387}
388
389/// Helper function to fold unary operations on constants
390fn fold_unary_op<F, C>(
391    inner: &TLExpr,
392    stats: &mut ConstantFoldingStats,
393    op: F,
394    constructor: C,
395) -> TLExpr
396where
397    F: Fn(f64) -> f64,
398    C: Fn(TLExpr) -> TLExpr,
399{
400    let inner_opt = fold_constants_impl(inner, stats);
401
402    if let TLExpr::Constant(x) = inner_opt {
403        stats.unary_ops_folded += 1;
404        TLExpr::Constant(op(x))
405    } else {
406        constructor(inner_opt)
407    }
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    /// Build a predicate `name(i)`, which the folder must treat as non-constant.
    fn var_pred(name: &str) -> TLExpr {
        TLExpr::pred(name, vec![Term::var("i")])
    }

    #[test]
    fn test_fold_binary_arithmetic() {
        // 2.0 + 3.0 = 5.0
        let expr = TLExpr::Add(
            Box::new(TLExpr::Constant(2.0)),
            Box::new(TLExpr::Constant(3.0)),
        );
        let (result, stats) = fold_constants(&expr);
        assert!(matches!(result, TLExpr::Constant(x) if (x - 5.0).abs() < f64::EPSILON));
        assert_eq!(stats.binary_ops_folded, 1);
    }

    #[test]
    fn test_fold_nested_arithmetic() {
        // (2.0 + 3.0) * 4.0 = 5.0 * 4.0 = 20.0
        let expr = TLExpr::Mul(
            Box::new(TLExpr::Add(
                Box::new(TLExpr::Constant(2.0)),
                Box::new(TLExpr::Constant(3.0)),
            )),
            Box::new(TLExpr::Constant(4.0)),
        );
        let (result, stats) = fold_constants(&expr);
        assert!(matches!(result, TLExpr::Constant(x) if (x - 20.0).abs() < f64::EPSILON));
        assert_eq!(stats.binary_ops_folded, 2);
    }

    #[test]
    fn test_fold_division() {
        // 9.0 / 3.0 = 3.0
        let expr = TLExpr::Div(
            Box::new(TLExpr::Constant(9.0)),
            Box::new(TLExpr::Constant(3.0)),
        );
        let (result, stats) = fold_constants(&expr);
        assert!(matches!(result, TLExpr::Constant(x) if (x - 3.0).abs() < f64::EPSILON));
        assert_eq!(stats.binary_ops_folded, 1);
    }

    #[test]
    fn test_fold_division_by_zero_is_nan() {
        // 1.0 / 0.0 folds to NaN by this pass's policy
        let expr = TLExpr::Div(
            Box::new(TLExpr::Constant(1.0)),
            Box::new(TLExpr::Constant(0.0)),
        );
        let (result, stats) = fold_constants(&expr);
        assert!(matches!(result, TLExpr::Constant(x) if x.is_nan()));
        assert_eq!(stats.binary_ops_folded, 1);
    }

    #[test]
    fn test_fold_unary_operations() {
        // sqrt(16.0) = 4.0
        let expr = TLExpr::Sqrt(Box::new(TLExpr::Constant(16.0)));
        let (result, stats) = fold_constants(&expr);
        assert!(matches!(result, TLExpr::Constant(x) if (x - 4.0).abs() < f64::EPSILON));
        assert_eq!(stats.unary_ops_folded, 1);
    }

    #[test]
    fn test_fold_trigonometry() {
        // sin(0.0) = 0.0
        let expr = TLExpr::Sin(Box::new(TLExpr::Constant(0.0)));
        let (result, stats) = fold_constants(&expr);
        assert!(matches!(result, TLExpr::Constant(x) if x.abs() < f64::EPSILON));
        assert_eq!(stats.unary_ops_folded, 1);
    }

    #[test]
    fn test_no_fold_with_variables() {
        // x + 2.0 (cannot fold because of variable)
        let expr = TLExpr::Add(Box::new(var_pred("x")), Box::new(TLExpr::Constant(2.0)));
        let (result, stats) = fold_constants(&expr);
        assert!(matches!(result, TLExpr::Add(..)));
        assert_eq!(stats.binary_ops_folded, 0);
    }

    #[test]
    fn test_no_fold_unary_with_variables() {
        // sqrt(x) stays symbolic
        let expr = TLExpr::Sqrt(Box::new(var_pred("x")));
        let (result, stats) = fold_constants(&expr);
        assert!(matches!(result, TLExpr::Sqrt(..)));
        assert_eq!(stats.unary_ops_folded, 0);
    }

    #[test]
    fn test_fold_inside_logical_operator() {
        // AND(x, 1.0 + 2.0) => AND(x, 3.0): the logical node survives, its child folds
        let expr = TLExpr::And(
            Box::new(var_pred("x")),
            Box::new(TLExpr::Add(
                Box::new(TLExpr::Constant(1.0)),
                Box::new(TLExpr::Constant(2.0)),
            )),
        );
        let (result, stats) = fold_constants(&expr);
        match result {
            TLExpr::And(_, right) => {
                assert!(matches!(*right, TLExpr::Constant(x) if (x - 3.0).abs() < f64::EPSILON));
            }
            _ => panic!("expected the And node to be preserved"),
        }
        assert_eq!(stats.binary_ops_folded, 1);
    }

    #[test]
    fn test_total_processed_counts_every_node() {
        // The Add node plus its two constant leaves: 3 nodes visited
        let expr = TLExpr::Add(
            Box::new(TLExpr::Constant(2.0)),
            Box::new(TLExpr::Constant(3.0)),
        );
        let (_, stats) = fold_constants(&expr);
        assert_eq!(stats.total_processed, 3);
    }
}