Skip to main content

tensorlogic_ir/expr/optimization/
mod.rs

//! Expression-level optimizations.
//!
//! This module provides various optimization passes for `TLExpr`:
//! - **Constant folding**: Evaluate constant expressions at compile time
//! - **Algebraic simplification**: Apply algebraic identities (e.g., x + 0 = x)
//! - **Constant propagation**: Substitute variables bound in Let expressions
//!
//! The main entry point is [`optimize_expr`], which applies these passes
//! iteratively until a fixed point is reached or an iteration cap is hit.

11mod algebraic;
12mod constant_folding;
13mod propagation;
14pub(crate) mod substitution;
15
16// Re-export public functions
17pub use algebraic::algebraic_simplify;
18pub use constant_folding::constant_fold;
19pub use propagation::propagate_constants;
20
21use crate::expr::TLExpr;
22
23/// Apply multiple optimization passes in sequence
24///
25/// This function applies constant propagation, constant folding, and algebraic
26/// simplification iteratively until no more changes occur or a maximum number
27/// of iterations is reached.
28///
29/// # Example
30///
31/// ```
32/// use tensorlogic_ir::TLExpr;
33/// use tensorlogic_ir::optimize_expr;
34///
35/// // (2 + 3) * 1 should become 5
36/// let expr = TLExpr::mul(
37///     TLExpr::add(TLExpr::constant(2.0), TLExpr::constant(3.0)),
38///     TLExpr::constant(1.0),
39/// );
40/// let optimized = optimize_expr(&expr);
41/// assert_eq!(optimized, TLExpr::Constant(5.0));
42/// ```
43pub fn optimize_expr(expr: &TLExpr) -> TLExpr {
44    // Apply optimizations iteratively until no more changes occur
45    // This handles nested Let bindings and cascading optimizations
46    let mut current = expr.clone();
47    let mut iterations = 0;
48    const MAX_ITERATIONS: usize = 10; // Prevent infinite loops
49
50    loop {
51        let propagated = propagate_constants(&current);
52        let folded = constant_fold(&propagated);
53        let simplified = algebraic_simplify(&folded);
54
55        // If no change occurred, we're done
56        if simplified == current || iterations >= MAX_ITERATIONS {
57            return simplified;
58        }
59
60        current = simplified;
61        iterations += 1;
62    }
63}
64
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_constant_fold_addition() {
        let expr = TLExpr::add(TLExpr::constant(2.0), TLExpr::constant(3.0));
        let folded = constant_fold(&expr);
        assert_eq!(folded, TLExpr::Constant(5.0));
    }

    #[test]
    fn test_constant_fold_multiplication() {
        let expr = TLExpr::mul(TLExpr::constant(4.0), TLExpr::constant(5.0));
        let folded = constant_fold(&expr);
        assert_eq!(folded, TLExpr::Constant(20.0));
    }

    #[test]
    fn test_constant_fold_nested() {
        // (2 + 3) * 4 = 20
        let expr = TLExpr::mul(
            TLExpr::add(TLExpr::constant(2.0), TLExpr::constant(3.0)),
            TLExpr::constant(4.0),
        );
        let folded = constant_fold(&expr);
        assert_eq!(folded, TLExpr::Constant(20.0));
    }

    #[test]
    fn test_algebraic_simplify_add_zero() {
        let expr = TLExpr::add(TLExpr::constant(5.0), TLExpr::constant(0.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(5.0));
    }

    #[test]
    fn test_algebraic_simplify_mul_one() {
        let expr = TLExpr::mul(TLExpr::constant(7.0), TLExpr::constant(1.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(7.0));
    }

    #[test]
    fn test_algebraic_simplify_mul_zero() {
        let expr = TLExpr::mul(TLExpr::constant(7.0), TLExpr::constant(0.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(0.0));
    }

    #[test]
    fn test_algebraic_simplify_double_negation() {
        let expr = TLExpr::negate(TLExpr::negate(TLExpr::constant(5.0)));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(5.0));
    }

    #[test]
    fn test_optimize_expr_combined() {
        // (2 + 3) * 1 should become 5
        let expr = TLExpr::mul(
            TLExpr::add(TLExpr::constant(2.0), TLExpr::constant(3.0)),
            TLExpr::constant(1.0),
        );
        let optimized = optimize_expr(&expr);
        assert_eq!(optimized, TLExpr::Constant(5.0));
    }

    #[test]
    fn test_propagate_constants_let_binding() {
        // let x = 5 in x + 3 should become 8
        let expr = TLExpr::let_binding(
            "x".to_string(),
            TLExpr::constant(5.0),
            TLExpr::add(TLExpr::pred("x", vec![]), TLExpr::constant(3.0)),
        );
        let propagated = propagate_constants(&expr);
        // After propagation: 5 + 3
        let folded = constant_fold(&propagated);
        assert_eq!(folded, TLExpr::Constant(8.0));
    }

    #[test]
    fn test_propagate_constants_nested_let() {
        // let x = 2 in let y = x + 1 in y * 3 should become 9
        let expr = TLExpr::let_binding(
            "x".to_string(),
            TLExpr::constant(2.0),
            TLExpr::let_binding(
                "y".to_string(),
                TLExpr::add(TLExpr::pred("x", vec![]), TLExpr::constant(1.0)),
                TLExpr::mul(TLExpr::pred("y", vec![]), TLExpr::constant(3.0)),
            ),
        );
        let optimized = optimize_expr(&expr);
        assert_eq!(optimized, TLExpr::Constant(9.0));
    }

    #[test]
    fn test_algebraic_simplify_and_true() {
        let expr = TLExpr::and(TLExpr::pred("P", vec![]), TLExpr::constant(1.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::pred("P", vec![]));
    }

    #[test]
    fn test_algebraic_simplify_and_false() {
        let expr = TLExpr::and(TLExpr::pred("P", vec![]), TLExpr::constant(0.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(0.0));
    }

    #[test]
    fn test_algebraic_simplify_or_false() {
        let expr = TLExpr::or(TLExpr::pred("P", vec![]), TLExpr::constant(0.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::pred("P", vec![]));
    }

    #[test]
    fn test_algebraic_simplify_or_true() {
        let expr = TLExpr::or(TLExpr::pred("P", vec![]), TLExpr::constant(1.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_algebraic_simplify_implies_true_antecedent() {
        let expr = TLExpr::imply(TLExpr::constant(1.0), TLExpr::pred("Q", vec![]));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::pred("Q", vec![]));
    }

    #[test]
    fn test_algebraic_simplify_implies_false_antecedent() {
        let expr = TLExpr::imply(TLExpr::constant(0.0), TLExpr::pred("Q", vec![]));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_algebraic_simplify_implies_true_consequent() {
        let expr = TLExpr::imply(TLExpr::pred("P", vec![]), TLExpr::constant(1.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_algebraic_simplify_implies_false_consequent() {
        let expr = TLExpr::imply(TLExpr::pred("P", vec![]), TLExpr::constant(0.0));
        let simplified = algebraic_simplify(&expr);
        // P → FALSE = ¬P. `matches!` only yields a bool, so it must be wrapped
        // in `assert!`; otherwise the result is discarded and the test is vacuous.
        assert!(
            matches!(simplified, TLExpr::Not(_)),
            "expected P → FALSE to simplify to a negation, got {simplified:?}"
        );
    }

    #[test]
    fn test_algebraic_simplify_same_comparison() {
        // x = x should become TRUE (1.0)
        let x = TLExpr::pred("x", vec![]);
        let expr = TLExpr::eq(x.clone(), x);
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_algebraic_simplify_comparison_lt_same() {
        // x < x should become FALSE (0.0)
        let x = TLExpr::pred("x", vec![]);
        let expr = TLExpr::lt(x.clone(), x);
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(0.0));
    }

    #[test]
    fn test_algebraic_simplify_comparison_lte_same() {
        // x <= x should become TRUE (1.0)
        let x = TLExpr::pred("x", vec![]);
        let expr = TLExpr::lte(x.clone(), x);
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_algebraic_simplify_division_same_constant() {
        // 5.0 / 5.0 should become 1.0
        let expr = TLExpr::div(TLExpr::constant(5.0), TLExpr::constant(5.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_modal_simplify_box_true() {
        let expr = TLExpr::modal_box(TLExpr::constant(1.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_modal_simplify_box_false() {
        let expr = TLExpr::modal_box(TLExpr::constant(0.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(0.0));
    }

    #[test]
    fn test_modal_simplify_diamond_true() {
        let expr = TLExpr::modal_diamond(TLExpr::constant(1.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_modal_simplify_diamond_false() {
        let expr = TLExpr::modal_diamond(TLExpr::constant(0.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(0.0));
    }

    #[test]
    fn test_temporal_simplify_next_true() {
        let expr = TLExpr::next(TLExpr::constant(1.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_temporal_simplify_eventually_true() {
        let expr = TLExpr::eventually(TLExpr::constant(1.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_temporal_simplify_always_true() {
        let expr = TLExpr::always(TLExpr::constant(1.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_temporal_simplify_eventually_idempotent() {
        // F(F(P)) = F(P)
        let p = TLExpr::pred("P", vec![]);
        let expr = TLExpr::eventually(TLExpr::eventually(p.clone()));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::eventually(p));
    }

    #[test]
    fn test_temporal_simplify_always_idempotent() {
        // G(G(P)) = G(P)
        let p = TLExpr::pred("P", vec![]);
        let expr = TLExpr::always(TLExpr::always(p.clone()));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::always(p));
    }

    #[test]
    fn test_temporal_simplify_until_true() {
        let expr = TLExpr::until(TLExpr::pred("P", vec![]), TLExpr::constant(1.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_temporal_simplify_until_false_left() {
        // FALSE U P = F(P)
        let p = TLExpr::pred("P", vec![]);
        let expr = TLExpr::until(TLExpr::constant(0.0), p.clone());
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::eventually(p));
    }

    #[test]
    fn test_algebraic_simplify_absorption_and_or() {
        // A ∧ (A ∨ B) = A
        let a = TLExpr::pred("A", vec![]);
        let b = TLExpr::pred("B", vec![]);
        let expr = TLExpr::and(a.clone(), TLExpr::or(a.clone(), b));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, a);
    }

    #[test]
    fn test_algebraic_simplify_absorption_or_and() {
        // A ∨ (A ∧ B) = A
        let a = TLExpr::pred("A", vec![]);
        let b = TLExpr::pred("B", vec![]);
        let expr = TLExpr::or(a.clone(), TLExpr::and(a.clone(), b));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, a);
    }

    #[test]
    fn test_algebraic_simplify_idempotence_and() {
        // A ∧ A = A
        let a = TLExpr::pred("A", vec![]);
        let expr = TLExpr::and(a.clone(), a.clone());
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, a);
    }

    #[test]
    fn test_algebraic_simplify_idempotence_or() {
        // A ∨ A = A
        let a = TLExpr::pred("A", vec![]);
        let expr = TLExpr::or(a.clone(), a.clone());
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, a);
    }

    #[test]
    fn test_algebraic_simplify_complement_and() {
        // A ∧ ¬A = FALSE
        let a = TLExpr::pred("A", vec![]);
        let expr = TLExpr::and(a.clone(), TLExpr::negate(a));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(0.0));
    }

    #[test]
    fn test_algebraic_simplify_complement_or() {
        // A ∨ ¬A = TRUE
        let a = TLExpr::pred("A", vec![]);
        let expr = TLExpr::or(a.clone(), TLExpr::negate(a));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(1.0));
    }
}