1use tensorlogic_ir::TLExpr;
7
/// Counters collected during a single [`fold_constants`] pass.
///
/// All counters start at zero (`Default`) and are only ever incremented by the pass.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ConstantFoldingStats {
    /// Number of binary arithmetic nodes (`Add`, `Sub`, `Mul`, `Div`, `Pow`, `Mod`,
    /// `Min`, `Max`) that were replaced by a single constant.
    pub binary_ops_folded: usize,
    /// Number of unary math nodes (`Abs`, `Floor`, `Ceil`, `Round`, `Sqrt`, `Exp`,
    /// `Log`, `Sin`, `Cos`, `Tan`) that were replaced by a single constant.
    pub unary_ops_folded: usize,
    /// Number of expression nodes visited by the pass, whether or not they were folded.
    pub total_processed: usize,
}
18
19pub fn fold_constants(expr: &TLExpr) -> (TLExpr, ConstantFoldingStats) {
42 let mut stats = ConstantFoldingStats::default();
43 let result = fold_constants_impl(expr, &mut stats);
44 (result, stats)
45}
46
47fn fold_constants_impl(expr: &TLExpr, stats: &mut ConstantFoldingStats) -> TLExpr {
48 stats.total_processed += 1;
49
50 match expr {
51 #[allow(unreachable_patterns)] TLExpr::Add(left, right) => fold_binary_op(
53 left,
54 right,
55 stats,
56 |a, b| a + b,
57 |l, r| TLExpr::Add(Box::new(l), Box::new(r)),
58 ),
59 TLExpr::Sub(left, right) => fold_binary_op(
60 left,
61 right,
62 stats,
63 |a, b| a - b,
64 |l, r| TLExpr::Sub(Box::new(l), Box::new(r)),
65 ),
66 TLExpr::Mul(left, right) => fold_binary_op(
67 left,
68 right,
69 stats,
70 |a, b| a * b,
71 |l, r| TLExpr::Mul(Box::new(l), Box::new(r)),
72 ),
73 TLExpr::Div(left, right) => fold_binary_op(
74 left,
75 right,
76 stats,
77 |a, b| {
78 if b.abs() < f64::EPSILON {
79 f64::NAN } else {
81 a / b
82 }
83 },
84 |l, r| TLExpr::Div(Box::new(l), Box::new(r)),
85 ),
86 TLExpr::Pow(left, right) => fold_binary_op(
87 left,
88 right,
89 stats,
90 |a, b| a.powf(b),
91 |l, r| TLExpr::Pow(Box::new(l), Box::new(r)),
92 ),
93 TLExpr::Mod(left, right) => fold_binary_op(
94 left,
95 right,
96 stats,
97 |a, b| a % b,
98 |l, r| TLExpr::Mod(Box::new(l), Box::new(r)),
99 ),
100 TLExpr::Min(left, right) => fold_binary_op(
101 left,
102 right,
103 stats,
104 |a, b| a.min(b),
105 |l, r| TLExpr::Min(Box::new(l), Box::new(r)),
106 ),
107 TLExpr::Max(left, right) => fold_binary_op(
108 left,
109 right,
110 stats,
111 |a, b| a.max(b),
112 |l, r| TLExpr::Max(Box::new(l), Box::new(r)),
113 ),
114
115 TLExpr::Abs(inner) => {
117 fold_unary_op(inner, stats, |x| x.abs(), |i| TLExpr::Abs(Box::new(i)))
118 }
119 TLExpr::Floor(inner) => {
120 fold_unary_op(inner, stats, |x| x.floor(), |i| TLExpr::Floor(Box::new(i)))
121 }
122 TLExpr::Ceil(inner) => {
123 fold_unary_op(inner, stats, |x| x.ceil(), |i| TLExpr::Ceil(Box::new(i)))
124 }
125 TLExpr::Round(inner) => {
126 fold_unary_op(inner, stats, |x| x.round(), |i| TLExpr::Round(Box::new(i)))
127 }
128 TLExpr::Sqrt(inner) => {
129 fold_unary_op(inner, stats, |x| x.sqrt(), |i| TLExpr::Sqrt(Box::new(i)))
130 }
131 TLExpr::Exp(inner) => {
132 fold_unary_op(inner, stats, |x| x.exp(), |i| TLExpr::Exp(Box::new(i)))
133 }
134 TLExpr::Log(inner) => fold_unary_op(inner, stats, |x| x.ln(), |i| TLExpr::Log(Box::new(i))),
135 TLExpr::Sin(inner) => {
136 fold_unary_op(inner, stats, |x| x.sin(), |i| TLExpr::Sin(Box::new(i)))
137 }
138 TLExpr::Cos(inner) => {
139 fold_unary_op(inner, stats, |x| x.cos(), |i| TLExpr::Cos(Box::new(i)))
140 }
141 TLExpr::Tan(inner) => {
142 fold_unary_op(inner, stats, |x| x.tan(), |i| TLExpr::Tan(Box::new(i)))
143 }
144
145 TLExpr::And(left, right) => {
147 let left_opt = fold_constants_impl(left, stats);
148 let right_opt = fold_constants_impl(right, stats);
149 TLExpr::And(Box::new(left_opt), Box::new(right_opt))
150 }
151 TLExpr::Or(left, right) => {
152 let left_opt = fold_constants_impl(left, stats);
153 let right_opt = fold_constants_impl(right, stats);
154 TLExpr::Or(Box::new(left_opt), Box::new(right_opt))
155 }
156 TLExpr::Not(inner) => {
157 let inner_opt = fold_constants_impl(inner, stats);
158 TLExpr::Not(Box::new(inner_opt))
159 }
160 TLExpr::Imply(left, right) => {
161 let left_opt = fold_constants_impl(left, stats);
162 let right_opt = fold_constants_impl(right, stats);
163 TLExpr::Imply(Box::new(left_opt), Box::new(right_opt))
164 }
165
166 TLExpr::Eq(left, right) => {
168 let left_opt = fold_constants_impl(left, stats);
169 let right_opt = fold_constants_impl(right, stats);
170 TLExpr::Eq(Box::new(left_opt), Box::new(right_opt))
171 }
172 TLExpr::Lt(left, right) => {
173 let left_opt = fold_constants_impl(left, stats);
174 let right_opt = fold_constants_impl(right, stats);
175 TLExpr::Lt(Box::new(left_opt), Box::new(right_opt))
176 }
177 TLExpr::Gt(left, right) => {
178 let left_opt = fold_constants_impl(left, stats);
179 let right_opt = fold_constants_impl(right, stats);
180 TLExpr::Gt(Box::new(left_opt), Box::new(right_opt))
181 }
182 TLExpr::Lte(left, right) => {
183 let left_opt = fold_constants_impl(left, stats);
184 let right_opt = fold_constants_impl(right, stats);
185 TLExpr::Lte(Box::new(left_opt), Box::new(right_opt))
186 }
187 TLExpr::Gte(left, right) => {
188 let left_opt = fold_constants_impl(left, stats);
189 let right_opt = fold_constants_impl(right, stats);
190 TLExpr::Gte(Box::new(left_opt), Box::new(right_opt))
191 }
192
193 TLExpr::Exists { var, domain, body } => {
195 let body_opt = fold_constants_impl(body, stats);
196 TLExpr::Exists {
197 var: var.clone(),
198 domain: domain.clone(),
199 body: Box::new(body_opt),
200 }
201 }
202 TLExpr::ForAll { var, domain, body } => {
203 let body_opt = fold_constants_impl(body, stats);
204 TLExpr::ForAll {
205 var: var.clone(),
206 domain: domain.clone(),
207 body: Box::new(body_opt),
208 }
209 }
210 TLExpr::Aggregate {
211 op,
212 var,
213 domain,
214 body,
215 group_by,
216 } => {
217 let body_opt = fold_constants_impl(body, stats);
218 TLExpr::Aggregate {
219 op: op.clone(),
220 var: var.clone(),
221 domain: domain.clone(),
222 body: Box::new(body_opt),
223 group_by: group_by.clone(),
224 }
225 }
226 TLExpr::IfThenElse {
227 condition,
228 then_branch,
229 else_branch,
230 } => {
231 let cond_opt = fold_constants_impl(condition, stats);
232 let then_opt = fold_constants_impl(then_branch, stats);
233 let else_opt = fold_constants_impl(else_branch, stats);
234 TLExpr::IfThenElse {
235 condition: Box::new(cond_opt),
236 then_branch: Box::new(then_opt),
237 else_branch: Box::new(else_opt),
238 }
239 }
240 TLExpr::Let { var, value, body } => {
241 let value_opt = fold_constants_impl(value, stats);
242 let body_opt = fold_constants_impl(body, stats);
243 TLExpr::Let {
244 var: var.clone(),
245 value: Box::new(value_opt),
246 body: Box::new(body_opt),
247 }
248 }
249
250 TLExpr::TNorm { kind, left, right } => {
252 let left_opt = fold_constants_impl(left, stats);
253 let right_opt = fold_constants_impl(right, stats);
254 TLExpr::TNorm {
255 kind: *kind,
256 left: Box::new(left_opt),
257 right: Box::new(right_opt),
258 }
259 }
260 TLExpr::TCoNorm { kind, left, right } => {
261 let left_opt = fold_constants_impl(left, stats);
262 let right_opt = fold_constants_impl(right, stats);
263 TLExpr::TCoNorm {
264 kind: *kind,
265 left: Box::new(left_opt),
266 right: Box::new(right_opt),
267 }
268 }
269 TLExpr::FuzzyNot { kind, expr: inner } => {
270 let inner_opt = fold_constants_impl(inner, stats);
271 TLExpr::FuzzyNot {
272 kind: *kind,
273 expr: Box::new(inner_opt),
274 }
275 }
276 TLExpr::FuzzyImplication {
277 kind,
278 premise,
279 conclusion,
280 } => {
281 let premise_opt = fold_constants_impl(premise, stats);
282 let conclusion_opt = fold_constants_impl(conclusion, stats);
283 TLExpr::FuzzyImplication {
284 kind: *kind,
285 premise: Box::new(premise_opt),
286 conclusion: Box::new(conclusion_opt),
287 }
288 }
289 TLExpr::SoftExists {
290 var,
291 domain,
292 body,
293 temperature,
294 } => {
295 let body_opt = fold_constants_impl(body, stats);
296 TLExpr::SoftExists {
297 var: var.clone(),
298 domain: domain.clone(),
299 body: Box::new(body_opt),
300 temperature: *temperature,
301 }
302 }
303 TLExpr::SoftForAll {
304 var,
305 domain,
306 body,
307 temperature,
308 } => {
309 let body_opt = fold_constants_impl(body, stats);
310 TLExpr::SoftForAll {
311 var: var.clone(),
312 domain: domain.clone(),
313 body: Box::new(body_opt),
314 temperature: *temperature,
315 }
316 }
317 TLExpr::WeightedRule { weight, rule } => {
318 let rule_opt = fold_constants_impl(rule, stats);
319 TLExpr::WeightedRule {
320 weight: *weight,
321 rule: Box::new(rule_opt),
322 }
323 }
324 TLExpr::ProbabilisticChoice { alternatives } => {
325 let alts_opt: Vec<_> = alternatives
326 .iter()
327 .map(|(w, e)| (*w, fold_constants_impl(e, stats)))
328 .collect();
329 TLExpr::ProbabilisticChoice {
330 alternatives: alts_opt,
331 }
332 }
333
334 TLExpr::Box(inner) => TLExpr::Box(Box::new(fold_constants_impl(inner, stats))),
336 TLExpr::Diamond(inner) => TLExpr::Diamond(Box::new(fold_constants_impl(inner, stats))),
337 TLExpr::Next(inner) => TLExpr::Next(Box::new(fold_constants_impl(inner, stats))),
338 TLExpr::Eventually(inner) => {
339 TLExpr::Eventually(Box::new(fold_constants_impl(inner, stats)))
340 }
341 TLExpr::Always(inner) => TLExpr::Always(Box::new(fold_constants_impl(inner, stats))),
342 TLExpr::Until { before, after } => TLExpr::Until {
343 before: Box::new(fold_constants_impl(before, stats)),
344 after: Box::new(fold_constants_impl(after, stats)),
345 },
346 TLExpr::Release { released, releaser } => TLExpr::Release {
347 released: Box::new(fold_constants_impl(released, stats)),
348 releaser: Box::new(fold_constants_impl(releaser, stats)),
349 },
350 TLExpr::WeakUntil { before, after } => TLExpr::WeakUntil {
351 before: Box::new(fold_constants_impl(before, stats)),
352 after: Box::new(fold_constants_impl(after, stats)),
353 },
354 TLExpr::StrongRelease { released, releaser } => TLExpr::StrongRelease {
355 released: Box::new(fold_constants_impl(released, stats)),
356 releaser: Box::new(fold_constants_impl(releaser, stats)),
357 },
358
359 TLExpr::Pred { .. } | TLExpr::Constant(_) | TLExpr::Score(_) => expr.clone(),
361 }
362}
363
364fn fold_binary_op<F, C>(
366 left: &TLExpr,
367 right: &TLExpr,
368 stats: &mut ConstantFoldingStats,
369 op: F,
370 constructor: C,
371) -> TLExpr
372where
373 F: Fn(f64, f64) -> f64,
374 C: Fn(TLExpr, TLExpr) -> TLExpr,
375{
376 let left_opt = fold_constants_impl(left, stats);
377 let right_opt = fold_constants_impl(right, stats);
378
379 if let (TLExpr::Constant(a), TLExpr::Constant(b)) = (&left_opt, &right_opt) {
380 stats.binary_ops_folded += 1;
381 TLExpr::Constant(op(*a, *b))
382 } else {
383 constructor(left_opt, right_opt)
384 }
385}
386
387fn fold_unary_op<F, C>(
389 inner: &TLExpr,
390 stats: &mut ConstantFoldingStats,
391 op: F,
392 constructor: C,
393) -> TLExpr
394where
395 F: Fn(f64) -> f64,
396 C: Fn(TLExpr) -> TLExpr,
397{
398 let inner_opt = fold_constants_impl(inner, stats);
399
400 if let TLExpr::Constant(x) = inner_opt {
401 stats.unary_ops_folded += 1;
402 TLExpr::Constant(op(x))
403 } else {
404 constructor(inner_opt)
405 }
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `expr` is a `TLExpr::Constant` within `f64::EPSILON` of `expected`.
    fn is_constant_near(expr: &TLExpr, expected: f64) -> bool {
        matches!(expr, TLExpr::Constant(value) if (*value - expected).abs() < f64::EPSILON)
    }

    #[test]
    fn test_fold_binary_arithmetic() {
        // 2 + 3
        let sum = TLExpr::Add(
            Box::new(TLExpr::Constant(2.0)),
            Box::new(TLExpr::Constant(3.0)),
        );

        let (folded, report) = fold_constants(&sum);

        assert!(is_constant_near(&folded, 5.0));
        assert_eq!(report.binary_ops_folded, 1);
    }

    #[test]
    fn test_fold_nested_arithmetic() {
        // (2 + 3) * 4: the inner Add folds first, then the outer Mul.
        let inner_sum = TLExpr::Add(
            Box::new(TLExpr::Constant(2.0)),
            Box::new(TLExpr::Constant(3.0)),
        );
        let product = TLExpr::Mul(Box::new(inner_sum), Box::new(TLExpr::Constant(4.0)));

        let (folded, report) = fold_constants(&product);

        assert!(is_constant_near(&folded, 20.0));
        assert_eq!(report.binary_ops_folded, 2);
    }

    #[test]
    fn test_fold_unary_operations() {
        // sqrt(16)
        let root = TLExpr::Sqrt(Box::new(TLExpr::Constant(16.0)));

        let (folded, report) = fold_constants(&root);

        assert!(is_constant_near(&folded, 4.0));
        assert_eq!(report.unary_ops_folded, 1);
    }

    #[test]
    fn test_fold_trigonometry() {
        // sin(0)
        let sine = TLExpr::Sin(Box::new(TLExpr::Constant(0.0)));

        let (folded, report) = fold_constants(&sine);

        assert!(is_constant_near(&folded, 0.0));
        assert_eq!(report.unary_ops_folded, 1);
    }

    #[test]
    fn test_no_fold_with_variables() {
        use tensorlogic_ir::Term;

        // x(i) + 2 cannot be folded: the left operand is a predicate, not a constant.
        let symbolic_sum = TLExpr::Add(
            Box::new(TLExpr::pred("x", vec![Term::var("i")])),
            Box::new(TLExpr::Constant(2.0)),
        );

        let (folded, report) = fold_constants(&symbolic_sum);

        assert!(matches!(folded, TLExpr::Add(..)));
        assert_eq!(report.binary_ops_folded, 0);
    }
}
470}