tensorlogic_ir/expr/optimization/
mod.rs1mod algebraic;
12mod constant_folding;
13mod propagation;
14pub(crate) mod substitution;
15
16pub use algebraic::algebraic_simplify;
18pub use constant_folding::constant_fold;
19pub use propagation::propagate_constants;
20
21use crate::expr::TLExpr;
22
23pub fn optimize_expr(expr: &TLExpr) -> TLExpr {
44 let mut current = expr.clone();
47 let mut iterations = 0;
48 const MAX_ITERATIONS: usize = 10; loop {
51 let propagated = propagate_constants(¤t);
52 let folded = constant_fold(&propagated);
53 let simplified = algebraic_simplify(&folded);
54
55 if simplified == current || iterations >= MAX_ITERATIONS {
57 return simplified;
58 }
59
60 current = simplified;
61 iterations += 1;
62 }
63}
64
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_constant_fold_addition() {
        // 2 + 3 => 5
        let expr = TLExpr::add(TLExpr::constant(2.0), TLExpr::constant(3.0));
        let folded = constant_fold(&expr);
        assert_eq!(folded, TLExpr::Constant(5.0));
    }

    #[test]
    fn test_constant_fold_multiplication() {
        // 4 * 5 => 20
        let expr = TLExpr::mul(TLExpr::constant(4.0), TLExpr::constant(5.0));
        let folded = constant_fold(&expr);
        assert_eq!(folded, TLExpr::Constant(20.0));
    }

    #[test]
    fn test_constant_fold_nested() {
        // (2 + 3) * 4 => 20
        let expr = TLExpr::mul(
            TLExpr::add(TLExpr::constant(2.0), TLExpr::constant(3.0)),
            TLExpr::constant(4.0),
        );
        let folded = constant_fold(&expr);
        assert_eq!(folded, TLExpr::Constant(20.0));
    }

    #[test]
    fn test_algebraic_simplify_add_zero() {
        // x + 0 => x
        let expr = TLExpr::add(TLExpr::constant(5.0), TLExpr::constant(0.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(5.0));
    }

    #[test]
    fn test_algebraic_simplify_mul_one() {
        // x * 1 => x
        let expr = TLExpr::mul(TLExpr::constant(7.0), TLExpr::constant(1.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(7.0));
    }

    #[test]
    fn test_algebraic_simplify_mul_zero() {
        // x * 0 => 0
        let expr = TLExpr::mul(TLExpr::constant(7.0), TLExpr::constant(0.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(0.0));
    }

    #[test]
    fn test_algebraic_simplify_double_negation() {
        // NOT NOT x => x
        let expr = TLExpr::negate(TLExpr::negate(TLExpr::constant(5.0)));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(5.0));
    }

    #[test]
    fn test_optimize_expr_combined() {
        // (2 + 3) * 1 => 5, which needs both folding and simplification
        let expr = TLExpr::mul(
            TLExpr::add(TLExpr::constant(2.0), TLExpr::constant(3.0)),
            TLExpr::constant(1.0),
        );
        let optimized = optimize_expr(&expr);
        assert_eq!(optimized, TLExpr::Constant(5.0));
    }

    #[test]
    fn test_propagate_constants_let_binding() {
        // let x = 5 in x + 3 => 8
        let expr = TLExpr::let_binding(
            "x".to_string(),
            TLExpr::constant(5.0),
            TLExpr::add(TLExpr::pred("x", vec![]), TLExpr::constant(3.0)),
        );
        let propagated = propagate_constants(&expr);
        let folded = constant_fold(&propagated);
        assert_eq!(folded, TLExpr::Constant(8.0));
    }

    #[test]
    fn test_propagate_constants_nested_let() {
        // let x = 2 in (let y = x + 1 in y * 3) => 9
        let expr = TLExpr::let_binding(
            "x".to_string(),
            TLExpr::constant(2.0),
            TLExpr::let_binding(
                "y".to_string(),
                TLExpr::add(TLExpr::pred("x", vec![]), TLExpr::constant(1.0)),
                TLExpr::mul(TLExpr::pred("y", vec![]), TLExpr::constant(3.0)),
            ),
        );
        let optimized = optimize_expr(&expr);
        assert_eq!(optimized, TLExpr::Constant(9.0));
    }

    #[test]
    fn test_algebraic_simplify_and_true() {
        // P AND true => P
        let expr = TLExpr::and(TLExpr::pred("P", vec![]), TLExpr::constant(1.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::pred("P", vec![]));
    }

    #[test]
    fn test_algebraic_simplify_and_false() {
        // P AND false => false
        let expr = TLExpr::and(TLExpr::pred("P", vec![]), TLExpr::constant(0.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(0.0));
    }

    #[test]
    fn test_algebraic_simplify_or_false() {
        // P OR false => P
        let expr = TLExpr::or(TLExpr::pred("P", vec![]), TLExpr::constant(0.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::pred("P", vec![]));
    }

    #[test]
    fn test_algebraic_simplify_or_true() {
        // P OR true => true
        let expr = TLExpr::or(TLExpr::pred("P", vec![]), TLExpr::constant(1.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_algebraic_simplify_implies_true_antecedent() {
        // true -> Q => Q
        let expr = TLExpr::imply(TLExpr::constant(1.0), TLExpr::pred("Q", vec![]));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::pred("Q", vec![]));
    }

    #[test]
    fn test_algebraic_simplify_implies_false_antecedent() {
        // false -> Q => true
        let expr = TLExpr::imply(TLExpr::constant(0.0), TLExpr::pred("Q", vec![]));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_algebraic_simplify_implies_true_consequent() {
        // P -> true => true
        let expr = TLExpr::imply(TLExpr::pred("P", vec![]), TLExpr::constant(1.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_algebraic_simplify_implies_false_consequent() {
        // P -> false => NOT P
        let expr = TLExpr::imply(TLExpr::pred("P", vec![]), TLExpr::constant(0.0));
        let simplified = algebraic_simplify(&expr);
        // `matches!` only evaluates to a bool; without `assert!` this test could never fail.
        assert!(
            matches!(simplified, TLExpr::Not(_)),
            "expected a negation, got {:?}",
            simplified
        );
    }

    #[test]
    fn test_algebraic_simplify_same_comparison() {
        // x == x => true
        let x = TLExpr::pred("x", vec![]);
        let expr = TLExpr::eq(x.clone(), x);
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_algebraic_simplify_comparison_lt_same() {
        // x < x => false
        let x = TLExpr::pred("x", vec![]);
        let expr = TLExpr::lt(x.clone(), x);
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(0.0));
    }

    #[test]
    fn test_algebraic_simplify_comparison_lte_same() {
        // x <= x => true
        let x = TLExpr::pred("x", vec![]);
        let expr = TLExpr::lte(x.clone(), x);
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_algebraic_simplify_division_same_constant() {
        // 5 / 5 => 1
        let expr = TLExpr::div(TLExpr::constant(5.0), TLExpr::constant(5.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_modal_simplify_box_true() {
        // BOX true => true
        let expr = TLExpr::modal_box(TLExpr::constant(1.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_modal_simplify_box_false() {
        // BOX false => false
        let expr = TLExpr::modal_box(TLExpr::constant(0.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(0.0));
    }

    #[test]
    fn test_modal_simplify_diamond_true() {
        // DIAMOND true => true
        let expr = TLExpr::modal_diamond(TLExpr::constant(1.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_modal_simplify_diamond_false() {
        // DIAMOND false => false
        let expr = TLExpr::modal_diamond(TLExpr::constant(0.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(0.0));
    }

    #[test]
    fn test_temporal_simplify_next_true() {
        // NEXT true => true
        let expr = TLExpr::next(TLExpr::constant(1.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_temporal_simplify_eventually_true() {
        // EVENTUALLY true => true
        let expr = TLExpr::eventually(TLExpr::constant(1.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_temporal_simplify_always_true() {
        // ALWAYS true => true
        let expr = TLExpr::always(TLExpr::constant(1.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_temporal_simplify_eventually_idempotent() {
        // EVENTUALLY EVENTUALLY P => EVENTUALLY P
        let p = TLExpr::pred("P", vec![]);
        let expr = TLExpr::eventually(TLExpr::eventually(p.clone()));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::eventually(p));
    }

    #[test]
    fn test_temporal_simplify_always_idempotent() {
        // ALWAYS ALWAYS P => ALWAYS P
        let p = TLExpr::pred("P", vec![]);
        let expr = TLExpr::always(TLExpr::always(p.clone()));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::always(p));
    }

    #[test]
    fn test_temporal_simplify_until_true() {
        // P UNTIL true => true
        let expr = TLExpr::until(TLExpr::pred("P", vec![]), TLExpr::constant(1.0));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_temporal_simplify_until_false_left() {
        // false UNTIL P => EVENTUALLY P
        let p = TLExpr::pred("P", vec![]);
        let expr = TLExpr::until(TLExpr::constant(0.0), p.clone());
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::eventually(p));
    }

    #[test]
    fn test_algebraic_simplify_absorption_and_or() {
        // A AND (A OR B) => A
        let a = TLExpr::pred("A", vec![]);
        let b = TLExpr::pred("B", vec![]);
        let expr = TLExpr::and(a.clone(), TLExpr::or(a.clone(), b));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, a);
    }

    #[test]
    fn test_algebraic_simplify_absorption_or_and() {
        // A OR (A AND B) => A
        let a = TLExpr::pred("A", vec![]);
        let b = TLExpr::pred("B", vec![]);
        let expr = TLExpr::or(a.clone(), TLExpr::and(a.clone(), b));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, a);
    }

    #[test]
    fn test_algebraic_simplify_idempotence_and() {
        // A AND A => A
        let a = TLExpr::pred("A", vec![]);
        let expr = TLExpr::and(a.clone(), a.clone());
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, a);
    }

    #[test]
    fn test_algebraic_simplify_idempotence_or() {
        // A OR A => A
        let a = TLExpr::pred("A", vec![]);
        let expr = TLExpr::or(a.clone(), a.clone());
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, a);
    }

    #[test]
    fn test_algebraic_simplify_complement_and() {
        // A AND NOT A => false
        let a = TLExpr::pred("A", vec![]);
        let expr = TLExpr::and(a.clone(), TLExpr::negate(a));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(0.0));
    }

    #[test]
    fn test_algebraic_simplify_complement_or() {
        // A OR NOT A => true
        let a = TLExpr::pred("A", vec![]);
        let expr = TLExpr::or(a.clone(), TLExpr::negate(a));
        let simplified = algebraic_simplify(&expr);
        assert_eq!(simplified, TLExpr::Constant(1.0));
    }
}