1use tensorlogic_ir::TLExpr;
7
/// Counters collected while running [`fold_constants`].
#[derive(Clone, Debug, Default)]
pub struct ConstantFoldingStats {
    /// Number of binary arithmetic nodes collapsed into a single constant.
    pub binary_ops_folded: usize,
    /// Number of unary arithmetic nodes collapsed into a single constant.
    pub unary_ops_folded: usize,
    /// Number of expression nodes visited, whether or not they were folded.
    pub total_processed: usize,
}
18
19pub fn fold_constants(expr: &TLExpr) -> (TLExpr, ConstantFoldingStats) {
42 let mut stats = ConstantFoldingStats::default();
43 let result = fold_constants_impl(expr, &mut stats);
44 (result, stats)
45}
46
47fn fold_constants_impl(expr: &TLExpr, stats: &mut ConstantFoldingStats) -> TLExpr {
48 stats.total_processed += 1;
49
50 match expr {
51 #[allow(unreachable_patterns)] TLExpr::Add(left, right) => fold_binary_op(
53 left,
54 right,
55 stats,
56 |a, b| a + b,
57 |l, r| TLExpr::Add(Box::new(l), Box::new(r)),
58 ),
59 TLExpr::Sub(left, right) => fold_binary_op(
60 left,
61 right,
62 stats,
63 |a, b| a - b,
64 |l, r| TLExpr::Sub(Box::new(l), Box::new(r)),
65 ),
66 TLExpr::Mul(left, right) => fold_binary_op(
67 left,
68 right,
69 stats,
70 |a, b| a * b,
71 |l, r| TLExpr::Mul(Box::new(l), Box::new(r)),
72 ),
73 TLExpr::Div(left, right) => fold_binary_op(
74 left,
75 right,
76 stats,
77 |a, b| {
78 if b.abs() < f64::EPSILON {
79 f64::NAN } else {
81 a / b
82 }
83 },
84 |l, r| TLExpr::Div(Box::new(l), Box::new(r)),
85 ),
86 TLExpr::Pow(left, right) => fold_binary_op(
87 left,
88 right,
89 stats,
90 |a, b| a.powf(b),
91 |l, r| TLExpr::Pow(Box::new(l), Box::new(r)),
92 ),
93 TLExpr::Mod(left, right) => fold_binary_op(
94 left,
95 right,
96 stats,
97 |a, b| a % b,
98 |l, r| TLExpr::Mod(Box::new(l), Box::new(r)),
99 ),
100 TLExpr::Min(left, right) => fold_binary_op(
101 left,
102 right,
103 stats,
104 |a, b| a.min(b),
105 |l, r| TLExpr::Min(Box::new(l), Box::new(r)),
106 ),
107 TLExpr::Max(left, right) => fold_binary_op(
108 left,
109 right,
110 stats,
111 |a, b| a.max(b),
112 |l, r| TLExpr::Max(Box::new(l), Box::new(r)),
113 ),
114
115 TLExpr::Abs(inner) => {
117 fold_unary_op(inner, stats, |x| x.abs(), |i| TLExpr::Abs(Box::new(i)))
118 }
119 TLExpr::Floor(inner) => {
120 fold_unary_op(inner, stats, |x| x.floor(), |i| TLExpr::Floor(Box::new(i)))
121 }
122 TLExpr::Ceil(inner) => {
123 fold_unary_op(inner, stats, |x| x.ceil(), |i| TLExpr::Ceil(Box::new(i)))
124 }
125 TLExpr::Round(inner) => {
126 fold_unary_op(inner, stats, |x| x.round(), |i| TLExpr::Round(Box::new(i)))
127 }
128 TLExpr::Sqrt(inner) => {
129 fold_unary_op(inner, stats, |x| x.sqrt(), |i| TLExpr::Sqrt(Box::new(i)))
130 }
131 TLExpr::Exp(inner) => {
132 fold_unary_op(inner, stats, |x| x.exp(), |i| TLExpr::Exp(Box::new(i)))
133 }
134 TLExpr::Log(inner) => fold_unary_op(inner, stats, |x| x.ln(), |i| TLExpr::Log(Box::new(i))),
135 TLExpr::Sin(inner) => {
136 fold_unary_op(inner, stats, |x| x.sin(), |i| TLExpr::Sin(Box::new(i)))
137 }
138 TLExpr::Cos(inner) => {
139 fold_unary_op(inner, stats, |x| x.cos(), |i| TLExpr::Cos(Box::new(i)))
140 }
141 TLExpr::Tan(inner) => {
142 fold_unary_op(inner, stats, |x| x.tan(), |i| TLExpr::Tan(Box::new(i)))
143 }
144
145 TLExpr::And(left, right) => {
147 let left_opt = fold_constants_impl(left, stats);
148 let right_opt = fold_constants_impl(right, stats);
149 TLExpr::And(Box::new(left_opt), Box::new(right_opt))
150 }
151 TLExpr::Or(left, right) => {
152 let left_opt = fold_constants_impl(left, stats);
153 let right_opt = fold_constants_impl(right, stats);
154 TLExpr::Or(Box::new(left_opt), Box::new(right_opt))
155 }
156 TLExpr::Not(inner) => {
157 let inner_opt = fold_constants_impl(inner, stats);
158 TLExpr::Not(Box::new(inner_opt))
159 }
160 TLExpr::Imply(left, right) => {
161 let left_opt = fold_constants_impl(left, stats);
162 let right_opt = fold_constants_impl(right, stats);
163 TLExpr::Imply(Box::new(left_opt), Box::new(right_opt))
164 }
165
166 TLExpr::Eq(left, right) => {
168 let left_opt = fold_constants_impl(left, stats);
169 let right_opt = fold_constants_impl(right, stats);
170 TLExpr::Eq(Box::new(left_opt), Box::new(right_opt))
171 }
172 TLExpr::Lt(left, right) => {
173 let left_opt = fold_constants_impl(left, stats);
174 let right_opt = fold_constants_impl(right, stats);
175 TLExpr::Lt(Box::new(left_opt), Box::new(right_opt))
176 }
177 TLExpr::Gt(left, right) => {
178 let left_opt = fold_constants_impl(left, stats);
179 let right_opt = fold_constants_impl(right, stats);
180 TLExpr::Gt(Box::new(left_opt), Box::new(right_opt))
181 }
182 TLExpr::Lte(left, right) => {
183 let left_opt = fold_constants_impl(left, stats);
184 let right_opt = fold_constants_impl(right, stats);
185 TLExpr::Lte(Box::new(left_opt), Box::new(right_opt))
186 }
187 TLExpr::Gte(left, right) => {
188 let left_opt = fold_constants_impl(left, stats);
189 let right_opt = fold_constants_impl(right, stats);
190 TLExpr::Gte(Box::new(left_opt), Box::new(right_opt))
191 }
192
193 TLExpr::Exists { var, domain, body } => {
195 let body_opt = fold_constants_impl(body, stats);
196 TLExpr::Exists {
197 var: var.clone(),
198 domain: domain.clone(),
199 body: Box::new(body_opt),
200 }
201 }
202 TLExpr::ForAll { var, domain, body } => {
203 let body_opt = fold_constants_impl(body, stats);
204 TLExpr::ForAll {
205 var: var.clone(),
206 domain: domain.clone(),
207 body: Box::new(body_opt),
208 }
209 }
210 TLExpr::Aggregate {
211 op,
212 var,
213 domain,
214 body,
215 group_by,
216 } => {
217 let body_opt = fold_constants_impl(body, stats);
218 TLExpr::Aggregate {
219 op: op.clone(),
220 var: var.clone(),
221 domain: domain.clone(),
222 body: Box::new(body_opt),
223 group_by: group_by.clone(),
224 }
225 }
226 TLExpr::IfThenElse {
227 condition,
228 then_branch,
229 else_branch,
230 } => {
231 let cond_opt = fold_constants_impl(condition, stats);
232 let then_opt = fold_constants_impl(then_branch, stats);
233 let else_opt = fold_constants_impl(else_branch, stats);
234 TLExpr::IfThenElse {
235 condition: Box::new(cond_opt),
236 then_branch: Box::new(then_opt),
237 else_branch: Box::new(else_opt),
238 }
239 }
240 TLExpr::Let { var, value, body } => {
241 let value_opt = fold_constants_impl(value, stats);
242 let body_opt = fold_constants_impl(body, stats);
243 TLExpr::Let {
244 var: var.clone(),
245 value: Box::new(value_opt),
246 body: Box::new(body_opt),
247 }
248 }
249
250 TLExpr::TNorm { kind, left, right } => {
252 let left_opt = fold_constants_impl(left, stats);
253 let right_opt = fold_constants_impl(right, stats);
254 TLExpr::TNorm {
255 kind: *kind,
256 left: Box::new(left_opt),
257 right: Box::new(right_opt),
258 }
259 }
260 TLExpr::TCoNorm { kind, left, right } => {
261 let left_opt = fold_constants_impl(left, stats);
262 let right_opt = fold_constants_impl(right, stats);
263 TLExpr::TCoNorm {
264 kind: *kind,
265 left: Box::new(left_opt),
266 right: Box::new(right_opt),
267 }
268 }
269 TLExpr::FuzzyNot { kind, expr: inner } => {
270 let inner_opt = fold_constants_impl(inner, stats);
271 TLExpr::FuzzyNot {
272 kind: *kind,
273 expr: Box::new(inner_opt),
274 }
275 }
276 TLExpr::FuzzyImplication {
277 kind,
278 premise,
279 conclusion,
280 } => {
281 let premise_opt = fold_constants_impl(premise, stats);
282 let conclusion_opt = fold_constants_impl(conclusion, stats);
283 TLExpr::FuzzyImplication {
284 kind: *kind,
285 premise: Box::new(premise_opt),
286 conclusion: Box::new(conclusion_opt),
287 }
288 }
289 TLExpr::SoftExists {
290 var,
291 domain,
292 body,
293 temperature,
294 } => {
295 let body_opt = fold_constants_impl(body, stats);
296 TLExpr::SoftExists {
297 var: var.clone(),
298 domain: domain.clone(),
299 body: Box::new(body_opt),
300 temperature: *temperature,
301 }
302 }
303 TLExpr::SoftForAll {
304 var,
305 domain,
306 body,
307 temperature,
308 } => {
309 let body_opt = fold_constants_impl(body, stats);
310 TLExpr::SoftForAll {
311 var: var.clone(),
312 domain: domain.clone(),
313 body: Box::new(body_opt),
314 temperature: *temperature,
315 }
316 }
317 TLExpr::WeightedRule { weight, rule } => {
318 let rule_opt = fold_constants_impl(rule, stats);
319 TLExpr::WeightedRule {
320 weight: *weight,
321 rule: Box::new(rule_opt),
322 }
323 }
324 TLExpr::ProbabilisticChoice { alternatives } => {
325 let alts_opt: Vec<_> = alternatives
326 .iter()
327 .map(|(w, e)| (*w, fold_constants_impl(e, stats)))
328 .collect();
329 TLExpr::ProbabilisticChoice {
330 alternatives: alts_opt,
331 }
332 }
333
334 TLExpr::Box(inner) => TLExpr::Box(Box::new(fold_constants_impl(inner, stats))),
336 TLExpr::Diamond(inner) => TLExpr::Diamond(Box::new(fold_constants_impl(inner, stats))),
337 TLExpr::Next(inner) => TLExpr::Next(Box::new(fold_constants_impl(inner, stats))),
338 TLExpr::Eventually(inner) => {
339 TLExpr::Eventually(Box::new(fold_constants_impl(inner, stats)))
340 }
341 TLExpr::Always(inner) => TLExpr::Always(Box::new(fold_constants_impl(inner, stats))),
342 TLExpr::Until { before, after } => TLExpr::Until {
343 before: Box::new(fold_constants_impl(before, stats)),
344 after: Box::new(fold_constants_impl(after, stats)),
345 },
346 TLExpr::Release { released, releaser } => TLExpr::Release {
347 released: Box::new(fold_constants_impl(released, stats)),
348 releaser: Box::new(fold_constants_impl(releaser, stats)),
349 },
350 TLExpr::WeakUntil { before, after } => TLExpr::WeakUntil {
351 before: Box::new(fold_constants_impl(before, stats)),
352 after: Box::new(fold_constants_impl(after, stats)),
353 },
354 TLExpr::StrongRelease { released, releaser } => TLExpr::StrongRelease {
355 released: Box::new(fold_constants_impl(released, stats)),
356 releaser: Box::new(fold_constants_impl(releaser, stats)),
357 },
358
359 TLExpr::Pred { .. } | TLExpr::Constant(_) | TLExpr::Score(_) => expr.clone(),
361 _ => expr.clone(),
363 }
364}
365
366fn fold_binary_op<F, C>(
368 left: &TLExpr,
369 right: &TLExpr,
370 stats: &mut ConstantFoldingStats,
371 op: F,
372 constructor: C,
373) -> TLExpr
374where
375 F: Fn(f64, f64) -> f64,
376 C: Fn(TLExpr, TLExpr) -> TLExpr,
377{
378 let left_opt = fold_constants_impl(left, stats);
379 let right_opt = fold_constants_impl(right, stats);
380
381 if let (TLExpr::Constant(a), TLExpr::Constant(b)) = (&left_opt, &right_opt) {
382 stats.binary_ops_folded += 1;
383 TLExpr::Constant(op(*a, *b))
384 } else {
385 constructor(left_opt, right_opt)
386 }
387}
388
389fn fold_unary_op<F, C>(
391 inner: &TLExpr,
392 stats: &mut ConstantFoldingStats,
393 op: F,
394 constructor: C,
395) -> TLExpr
396where
397 F: Fn(f64) -> f64,
398 C: Fn(TLExpr) -> TLExpr,
399{
400 let inner_opt = fold_constants_impl(inner, stats);
401
402 if let TLExpr::Constant(x) = inner_opt {
403 stats.unary_ops_folded += 1;
404 TLExpr::Constant(op(x))
405 } else {
406 constructor(inner_opt)
407 }
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    /// Wraps a numeric literal as a boxed operand.
    fn lit(value: f64) -> Box<TLExpr> {
        Box::new(TLExpr::Constant(value))
    }

    /// True when `expr` is a constant within `f64::EPSILON` of `expected`.
    fn is_const(expr: &TLExpr, expected: f64) -> bool {
        matches!(expr, TLExpr::Constant(x) if (*x - expected).abs() < f64::EPSILON)
    }

    #[test]
    fn test_fold_binary_arithmetic() {
        // 2 + 3
        let (folded, stats) = fold_constants(&TLExpr::Add(lit(2.0), lit(3.0)));
        assert!(is_const(&folded, 5.0));
        assert_eq!(stats.binary_ops_folded, 1);
    }

    #[test]
    fn test_fold_nested_arithmetic() {
        // (2 + 3) * 4: the inner Add folds first, which then lets the outer Mul fold.
        let sum = TLExpr::Add(lit(2.0), lit(3.0));
        let product = TLExpr::Mul(Box::new(sum), lit(4.0));
        let (folded, stats) = fold_constants(&product);
        assert!(is_const(&folded, 20.0));
        assert_eq!(stats.binary_ops_folded, 2);
    }

    #[test]
    fn test_fold_unary_operations() {
        let (folded, stats) = fold_constants(&TLExpr::Sqrt(lit(16.0)));
        assert!(is_const(&folded, 4.0));
        assert_eq!(stats.unary_ops_folded, 1);
    }

    #[test]
    fn test_fold_trigonometry() {
        let (folded, stats) = fold_constants(&TLExpr::Sin(lit(0.0)));
        assert!(is_const(&folded, 0.0));
        assert_eq!(stats.unary_ops_folded, 1);
    }

    #[test]
    fn test_no_fold_with_variables() {
        // A predicate operand is not a constant, so the Add node must survive.
        let predicate = TLExpr::pred("x", vec![Term::var("i")]);
        let (folded, stats) = fold_constants(&TLExpr::Add(Box::new(predicate), lit(2.0)));
        assert!(matches!(folded, TLExpr::Add(..)));
        assert_eq!(stats.binary_ops_folded, 0);
    }
}