1use crate::expr::TLExpr;
7
8pub fn constant_fold(expr: &TLExpr) -> TLExpr {
10 match expr {
11 TLExpr::Add(l, r) => {
13 let left = constant_fold(l);
14 let right = constant_fold(r);
15 if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
16 return TLExpr::Constant(lv + rv);
17 }
18 TLExpr::Add(Box::new(left), Box::new(right))
19 }
20 TLExpr::Sub(l, r) => {
21 let left = constant_fold(l);
22 let right = constant_fold(r);
23 if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
24 return TLExpr::Constant(lv - rv);
25 }
26 TLExpr::Sub(Box::new(left), Box::new(right))
27 }
28 TLExpr::Mul(l, r) => {
29 let left = constant_fold(l);
30 let right = constant_fold(r);
31 if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
32 return TLExpr::Constant(lv * rv);
33 }
34 TLExpr::Mul(Box::new(left), Box::new(right))
35 }
36 TLExpr::Div(l, r) => {
37 let left = constant_fold(l);
38 let right = constant_fold(r);
39 if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
40 if *rv != 0.0 {
41 return TLExpr::Constant(lv / rv);
42 }
43 }
44 TLExpr::Div(Box::new(left), Box::new(right))
45 }
46 TLExpr::Pow(l, r) => {
47 let left = constant_fold(l);
48 let right = constant_fold(r);
49 if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
50 return TLExpr::Constant(lv.powf(*rv));
51 }
52 TLExpr::Pow(Box::new(left), Box::new(right))
53 }
54 TLExpr::Mod(l, r) => {
55 let left = constant_fold(l);
56 let right = constant_fold(r);
57 if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
58 return TLExpr::Constant(lv % rv);
59 }
60 TLExpr::Mod(Box::new(left), Box::new(right))
61 }
62 TLExpr::Min(l, r) => {
63 let left = constant_fold(l);
64 let right = constant_fold(r);
65 if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
66 return TLExpr::Constant(lv.min(*rv));
67 }
68 TLExpr::Min(Box::new(left), Box::new(right))
69 }
70 TLExpr::Max(l, r) => {
71 let left = constant_fold(l);
72 let right = constant_fold(r);
73 if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
74 return TLExpr::Constant(lv.max(*rv));
75 }
76 TLExpr::Max(Box::new(left), Box::new(right))
77 }
78
79 TLExpr::Abs(e) => {
81 let inner = constant_fold(e);
82 if let TLExpr::Constant(v) = &inner {
83 return TLExpr::Constant(v.abs());
84 }
85 TLExpr::Abs(Box::new(inner))
86 }
87 TLExpr::Floor(e) => {
88 let inner = constant_fold(e);
89 if let TLExpr::Constant(v) = &inner {
90 return TLExpr::Constant(v.floor());
91 }
92 TLExpr::Floor(Box::new(inner))
93 }
94 TLExpr::Ceil(e) => {
95 let inner = constant_fold(e);
96 if let TLExpr::Constant(v) = &inner {
97 return TLExpr::Constant(v.ceil());
98 }
99 TLExpr::Ceil(Box::new(inner))
100 }
101 TLExpr::Round(e) => {
102 let inner = constant_fold(e);
103 if let TLExpr::Constant(v) = &inner {
104 return TLExpr::Constant(v.round());
105 }
106 TLExpr::Round(Box::new(inner))
107 }
108 TLExpr::Sqrt(e) => {
109 let inner = constant_fold(e);
110 if let TLExpr::Constant(v) = &inner {
111 if *v >= 0.0 {
112 return TLExpr::Constant(v.sqrt());
113 }
114 }
115 TLExpr::Sqrt(Box::new(inner))
116 }
117 TLExpr::Exp(e) => {
118 let inner = constant_fold(e);
119 if let TLExpr::Constant(v) = &inner {
120 return TLExpr::Constant(v.exp());
121 }
122 TLExpr::Exp(Box::new(inner))
123 }
124 TLExpr::Log(e) => {
125 let inner = constant_fold(e);
126 if let TLExpr::Constant(v) = &inner {
127 if *v > 0.0 {
128 return TLExpr::Constant(v.ln());
129 }
130 }
131 TLExpr::Log(Box::new(inner))
132 }
133 TLExpr::Sin(e) => {
134 let inner = constant_fold(e);
135 if let TLExpr::Constant(v) = &inner {
136 return TLExpr::Constant(v.sin());
137 }
138 TLExpr::Sin(Box::new(inner))
139 }
140 TLExpr::Cos(e) => {
141 let inner = constant_fold(e);
142 if let TLExpr::Constant(v) = &inner {
143 return TLExpr::Constant(v.cos());
144 }
145 TLExpr::Cos(Box::new(inner))
146 }
147 TLExpr::Tan(e) => {
148 let inner = constant_fold(e);
149 if let TLExpr::Constant(v) = &inner {
150 return TLExpr::Constant(v.tan());
151 }
152 TLExpr::Tan(Box::new(inner))
153 }
154
155 TLExpr::Eq(l, r) => {
157 let left = constant_fold(l);
158 let right = constant_fold(r);
159 if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
160 return TLExpr::Constant(if (lv - rv).abs() < f64::EPSILON {
161 1.0
162 } else {
163 0.0
164 });
165 }
166 TLExpr::Eq(Box::new(left), Box::new(right))
167 }
168 TLExpr::Lt(l, r) => {
169 let left = constant_fold(l);
170 let right = constant_fold(r);
171 if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
172 return TLExpr::Constant(if lv < rv { 1.0 } else { 0.0 });
173 }
174 TLExpr::Lt(Box::new(left), Box::new(right))
175 }
176 TLExpr::Gt(l, r) => {
177 let left = constant_fold(l);
178 let right = constant_fold(r);
179 if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
180 return TLExpr::Constant(if lv > rv { 1.0 } else { 0.0 });
181 }
182 TLExpr::Gt(Box::new(left), Box::new(right))
183 }
184 TLExpr::Lte(l, r) => {
185 let left = constant_fold(l);
186 let right = constant_fold(r);
187 if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
188 return TLExpr::Constant(if lv <= rv { 1.0 } else { 0.0 });
189 }
190 TLExpr::Lte(Box::new(left), Box::new(right))
191 }
192 TLExpr::Gte(l, r) => {
193 let left = constant_fold(l);
194 let right = constant_fold(r);
195 if let (TLExpr::Constant(lv), TLExpr::Constant(rv)) = (&left, &right) {
196 return TLExpr::Constant(if lv >= rv { 1.0 } else { 0.0 });
197 }
198 TLExpr::Gte(Box::new(left), Box::new(right))
199 }
200
201 TLExpr::And(l, r) => TLExpr::And(Box::new(constant_fold(l)), Box::new(constant_fold(r))),
203 TLExpr::Or(l, r) => TLExpr::Or(Box::new(constant_fold(l)), Box::new(constant_fold(r))),
204 TLExpr::Not(e) => TLExpr::Not(Box::new(constant_fold(e))),
205 TLExpr::Imply(l, r) => {
206 TLExpr::Imply(Box::new(constant_fold(l)), Box::new(constant_fold(r)))
207 }
208
209 TLExpr::Exists { var, domain, body } => TLExpr::Exists {
211 var: var.clone(),
212 domain: domain.clone(),
213 body: Box::new(constant_fold(body)),
214 },
215 TLExpr::ForAll { var, domain, body } => TLExpr::ForAll {
216 var: var.clone(),
217 domain: domain.clone(),
218 body: Box::new(constant_fold(body)),
219 },
220
221 TLExpr::Score(e) => TLExpr::Score(Box::new(constant_fold(e))),
223
224 TLExpr::Aggregate {
226 op,
227 var,
228 domain,
229 body,
230 group_by,
231 } => TLExpr::Aggregate {
232 op: op.clone(),
233 var: var.clone(),
234 domain: domain.clone(),
235 body: Box::new(constant_fold(body)),
236 group_by: group_by.clone(),
237 },
238
239 TLExpr::Box(e) => TLExpr::Box(Box::new(constant_fold(e))),
241 TLExpr::Diamond(e) => TLExpr::Diamond(Box::new(constant_fold(e))),
242
243 TLExpr::Next(e) => TLExpr::Next(Box::new(constant_fold(e))),
245 TLExpr::Eventually(e) => TLExpr::Eventually(Box::new(constant_fold(e))),
246 TLExpr::Always(e) => TLExpr::Always(Box::new(constant_fold(e))),
247 TLExpr::Until { before, after } => TLExpr::Until {
248 before: Box::new(constant_fold(before)),
249 after: Box::new(constant_fold(after)),
250 },
251
252 TLExpr::TNorm { kind, left, right } => TLExpr::TNorm {
254 kind: *kind,
255 left: Box::new(constant_fold(left)),
256 right: Box::new(constant_fold(right)),
257 },
258 TLExpr::TCoNorm { kind, left, right } => TLExpr::TCoNorm {
259 kind: *kind,
260 left: Box::new(constant_fold(left)),
261 right: Box::new(constant_fold(right)),
262 },
263 TLExpr::FuzzyNot { kind, expr } => TLExpr::FuzzyNot {
264 kind: *kind,
265 expr: Box::new(constant_fold(expr)),
266 },
267 TLExpr::FuzzyImplication {
268 kind,
269 premise,
270 conclusion,
271 } => TLExpr::FuzzyImplication {
272 kind: *kind,
273 premise: Box::new(constant_fold(premise)),
274 conclusion: Box::new(constant_fold(conclusion)),
275 },
276
277 TLExpr::SoftExists {
279 var,
280 domain,
281 body,
282 temperature,
283 } => TLExpr::SoftExists {
284 var: var.clone(),
285 domain: domain.clone(),
286 body: Box::new(constant_fold(body)),
287 temperature: *temperature,
288 },
289 TLExpr::SoftForAll {
290 var,
291 domain,
292 body,
293 temperature,
294 } => TLExpr::SoftForAll {
295 var: var.clone(),
296 domain: domain.clone(),
297 body: Box::new(constant_fold(body)),
298 temperature: *temperature,
299 },
300 TLExpr::WeightedRule { weight, rule } => TLExpr::WeightedRule {
301 weight: *weight,
302 rule: Box::new(constant_fold(rule)),
303 },
304 TLExpr::ProbabilisticChoice { alternatives } => TLExpr::ProbabilisticChoice {
305 alternatives: alternatives
306 .iter()
307 .map(|(p, e)| (*p, constant_fold(e)))
308 .collect(),
309 },
310
311 TLExpr::Release { released, releaser } => TLExpr::Release {
313 released: Box::new(constant_fold(released)),
314 releaser: Box::new(constant_fold(releaser)),
315 },
316 TLExpr::WeakUntil { before, after } => TLExpr::WeakUntil {
317 before: Box::new(constant_fold(before)),
318 after: Box::new(constant_fold(after)),
319 },
320 TLExpr::StrongRelease { released, releaser } => TLExpr::StrongRelease {
321 released: Box::new(constant_fold(released)),
322 releaser: Box::new(constant_fold(releaser)),
323 },
324
325 TLExpr::IfThenElse {
327 condition,
328 then_branch,
329 else_branch,
330 } => TLExpr::IfThenElse {
331 condition: Box::new(constant_fold(condition)),
332 then_branch: Box::new(constant_fold(then_branch)),
333 else_branch: Box::new(constant_fold(else_branch)),
334 },
335 TLExpr::Let { var, value, body } => TLExpr::Let {
336 var: var.clone(),
337 value: Box::new(constant_fold(value)),
338 body: Box::new(constant_fold(body)),
339 },
340
341 TLExpr::Lambda {
343 var,
344 var_type,
345 body,
346 } => TLExpr::lambda(var.clone(), var_type.clone(), constant_fold(body)),
347 TLExpr::Apply { function, argument } => {
348 TLExpr::apply(constant_fold(function), constant_fold(argument))
349 }
350 TLExpr::SetMembership { element, set } => {
351 TLExpr::set_membership(constant_fold(element), constant_fold(set))
352 }
353 TLExpr::SetUnion { left, right } => {
354 TLExpr::set_union(constant_fold(left), constant_fold(right))
355 }
356 TLExpr::SetIntersection { left, right } => {
357 TLExpr::set_intersection(constant_fold(left), constant_fold(right))
358 }
359 TLExpr::SetDifference { left, right } => {
360 TLExpr::set_difference(constant_fold(left), constant_fold(right))
361 }
362 TLExpr::SetCardinality { set } => TLExpr::set_cardinality(constant_fold(set)),
363 TLExpr::EmptySet => expr.clone(),
364 TLExpr::SetComprehension {
365 var,
366 domain,
367 condition,
368 } => TLExpr::set_comprehension(var.clone(), domain.clone(), constant_fold(condition)),
369 TLExpr::CountingExists {
370 var,
371 domain,
372 body,
373 min_count,
374 } => TLExpr::counting_exists(var.clone(), domain.clone(), constant_fold(body), *min_count),
375 TLExpr::CountingForAll {
376 var,
377 domain,
378 body,
379 min_count,
380 } => TLExpr::counting_forall(var.clone(), domain.clone(), constant_fold(body), *min_count),
381 TLExpr::ExactCount {
382 var,
383 domain,
384 body,
385 count,
386 } => TLExpr::exact_count(var.clone(), domain.clone(), constant_fold(body), *count),
387 TLExpr::Majority { var, domain, body } => {
388 TLExpr::majority(var.clone(), domain.clone(), constant_fold(body))
389 }
390 TLExpr::LeastFixpoint { var, body } => {
391 TLExpr::least_fixpoint(var.clone(), constant_fold(body))
392 }
393 TLExpr::GreatestFixpoint { var, body } => {
394 TLExpr::greatest_fixpoint(var.clone(), constant_fold(body))
395 }
396 TLExpr::Nominal { .. } => expr.clone(),
397 TLExpr::At { nominal, formula } => TLExpr::at(nominal.clone(), constant_fold(formula)),
398 TLExpr::Somewhere { formula } => TLExpr::somewhere(constant_fold(formula)),
399 TLExpr::Everywhere { formula } => TLExpr::everywhere(constant_fold(formula)),
400 TLExpr::AllDifferent { .. } => expr.clone(),
401 TLExpr::GlobalCardinality {
402 variables,
403 values,
404 min_occurrences,
405 max_occurrences,
406 } => TLExpr::global_cardinality(
407 variables.clone(),
408 values.iter().map(constant_fold).collect(),
409 min_occurrences.clone(),
410 max_occurrences.clone(),
411 ),
412 TLExpr::Abducible { .. } => expr.clone(),
413 TLExpr::Explain { formula } => TLExpr::explain(constant_fold(formula)),
414
415 TLExpr::Pred { .. } | TLExpr::Constant(_) => expr.clone(),
417 }
418}
419
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_constant_fold_addition() {
        let expr = TLExpr::Add(
            Box::new(TLExpr::Constant(2.0)),
            Box::new(TLExpr::Constant(3.0)),
        );
        let folded = constant_fold(&expr);
        assert_eq!(folded, TLExpr::Constant(5.0));
    }

    #[test]
    fn test_constant_fold_multiplication() {
        let expr = TLExpr::Mul(
            Box::new(TLExpr::Constant(4.0)),
            Box::new(TLExpr::Constant(5.0)),
        );
        let folded = constant_fold(&expr);
        assert_eq!(folded, TLExpr::Constant(20.0));
    }

    #[test]
    fn test_constant_fold_nested() {
        // (2 + 3) * 4: the inner sum folds first, enabling the outer product.
        let expr = TLExpr::Mul(
            Box::new(TLExpr::Add(
                Box::new(TLExpr::Constant(2.0)),
                Box::new(TLExpr::Constant(3.0)),
            )),
            Box::new(TLExpr::Constant(4.0)),
        );
        let folded = constant_fold(&expr);
        assert_eq!(folded, TLExpr::Constant(20.0));
    }

    #[test]
    fn test_constant_fold_division_zero() {
        // Division by zero must be left unfolded rather than folded to inf.
        let expr = TLExpr::Div(
            Box::new(TLExpr::Constant(5.0)),
            Box::new(TLExpr::Constant(0.0)),
        );
        let folded = constant_fold(&expr);
        assert!(matches!(folded, TLExpr::Div(_, _)));
        assert_eq!(folded, expr);
    }

    #[test]
    fn test_constant_fold_sqrt_negative() {
        // sqrt of a negative constant has no real value and stays unfolded.
        let expr = TLExpr::Sqrt(Box::new(TLExpr::Constant(-4.0)));
        let folded = constant_fold(&expr);
        assert!(matches!(folded, TLExpr::Sqrt(_)));
        assert_eq!(folded, expr);
    }

    #[test]
    fn test_constant_fold_comparison() {
        // Comparisons fold to 1.0 (true) / 0.0 (false).
        let lt = TLExpr::Lt(
            Box::new(TLExpr::Constant(1.0)),
            Box::new(TLExpr::Constant(2.0)),
        );
        assert_eq!(constant_fold(&lt), TLExpr::Constant(1.0));

        let gte = TLExpr::Gte(
            Box::new(TLExpr::Constant(1.0)),
            Box::new(TLExpr::Constant(2.0)),
        );
        assert_eq!(constant_fold(&gte), TLExpr::Constant(0.0));
    }
}
476}