1use tensorlogic_ir::TLExpr;
7
/// Counters describing what a single algebraic simplification run did.
///
/// All counters start at zero (`Default`) and are only ever incremented by the
/// simplifier while it walks an expression tree.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct AlgebraicSimplificationStats {
    /// Number of identity rewrites applied (e.g. `x + 0`, `x * 1`, `x / 1`,
    /// `x ^ 1`, `x ^ 0`).
    pub identities_eliminated: usize,
    /// Number of annihilation rewrites applied (e.g. `x * 0`, `0 / x`,
    /// `1 ^ x`, `0 ^ x`).
    pub annihilations_applied: usize,
    /// Number of idempotence rewrites applied (e.g. `min(x, x)`, `max(x, x)`,
    /// `abs(abs(x))`).
    pub idempotent_simplified: usize,
    /// Number of expression nodes visited, counted once per recursive call.
    pub total_processed: usize,
}
20
21pub fn simplify_algebraic(expr: &TLExpr) -> (TLExpr, AlgebraicSimplificationStats) {
44 let mut stats = AlgebraicSimplificationStats::default();
45 let result = simplify_algebraic_impl(expr, &mut stats);
46 (result, stats)
47}
48
49fn simplify_algebraic_impl(expr: &TLExpr, stats: &mut AlgebraicSimplificationStats) -> TLExpr {
50 stats.total_processed += 1;
51
52 match expr {
53 TLExpr::Add(left, right) => {
55 let left_simp = simplify_algebraic_impl(left, stats);
56 let right_simp = simplify_algebraic_impl(right, stats);
57
58 if is_zero(&right_simp) {
59 stats.identities_eliminated += 1;
60 left_simp
61 } else if is_zero(&left_simp) {
62 stats.identities_eliminated += 1;
63 right_simp
64 } else {
65 TLExpr::Add(Box::new(left_simp), Box::new(right_simp))
66 }
67 }
68
69 TLExpr::Sub(left, right) => {
71 let left_simp = simplify_algebraic_impl(left, stats);
72 let right_simp = simplify_algebraic_impl(right, stats);
73
74 if is_zero(&right_simp) {
75 stats.identities_eliminated += 1;
76 left_simp
77 } else {
78 TLExpr::Sub(Box::new(left_simp), Box::new(right_simp))
79 }
80 }
81
82 TLExpr::Mul(left, right) => {
84 let left_simp = simplify_algebraic_impl(left, stats);
85 let right_simp = simplify_algebraic_impl(right, stats);
86
87 if is_zero(&left_simp) || is_zero(&right_simp) {
88 stats.annihilations_applied += 1;
89 TLExpr::Constant(0.0)
90 } else if is_one(&right_simp) {
91 stats.identities_eliminated += 1;
92 left_simp
93 } else if is_one(&left_simp) {
94 stats.identities_eliminated += 1;
95 right_simp
96 } else {
97 TLExpr::Mul(Box::new(left_simp), Box::new(right_simp))
98 }
99 }
100
101 TLExpr::Div(left, right) => {
103 let left_simp = simplify_algebraic_impl(left, stats);
104 let right_simp = simplify_algebraic_impl(right, stats);
105
106 if is_one(&right_simp) {
107 stats.identities_eliminated += 1;
108 left_simp
109 } else if is_zero(&left_simp) {
110 stats.annihilations_applied += 1;
111 TLExpr::Constant(0.0)
112 } else {
113 TLExpr::Div(Box::new(left_simp), Box::new(right_simp))
114 }
115 }
116
117 TLExpr::Pow(base, exponent) => {
119 let base_simp = simplify_algebraic_impl(base, stats);
120 let exp_simp = simplify_algebraic_impl(exponent, stats);
121
122 if is_zero(&exp_simp) {
123 stats.identities_eliminated += 1;
124 TLExpr::Constant(1.0)
125 } else if is_one(&exp_simp) {
126 stats.identities_eliminated += 1;
127 base_simp
128 } else if is_one(&base_simp) {
129 stats.annihilations_applied += 1;
130 TLExpr::Constant(1.0)
131 } else if is_zero(&base_simp) {
132 stats.annihilations_applied += 1;
133 TLExpr::Constant(0.0)
134 } else {
135 TLExpr::Pow(Box::new(base_simp), Box::new(exp_simp))
136 }
137 }
138
139 TLExpr::Min(left, right) => {
141 let left_simp = simplify_algebraic_impl(left, stats);
142 let right_simp = simplify_algebraic_impl(right, stats);
143
144 if expressions_equal(&left_simp, &right_simp) {
145 stats.idempotent_simplified += 1;
146 left_simp
147 } else {
148 TLExpr::Min(Box::new(left_simp), Box::new(right_simp))
149 }
150 }
151 TLExpr::Max(left, right) => {
152 let left_simp = simplify_algebraic_impl(left, stats);
153 let right_simp = simplify_algebraic_impl(right, stats);
154
155 if expressions_equal(&left_simp, &right_simp) {
156 stats.idempotent_simplified += 1;
157 left_simp
158 } else {
159 TLExpr::Max(Box::new(left_simp), Box::new(right_simp))
160 }
161 }
162
163 TLExpr::Abs(inner) => {
165 let inner_simp = simplify_algebraic_impl(inner, stats);
166 if matches!(&inner_simp, TLExpr::Abs(_)) {
168 stats.idempotent_simplified += 1;
169 inner_simp
170 } else {
171 TLExpr::Abs(Box::new(inner_simp))
172 }
173 }
174
175 TLExpr::Floor(inner) => {
177 let inner_simp = simplify_algebraic_impl(inner, stats);
178 TLExpr::Floor(Box::new(inner_simp))
179 }
180 TLExpr::Ceil(inner) => {
181 let inner_simp = simplify_algebraic_impl(inner, stats);
182 TLExpr::Ceil(Box::new(inner_simp))
183 }
184 TLExpr::Round(inner) => {
185 let inner_simp = simplify_algebraic_impl(inner, stats);
186 TLExpr::Round(Box::new(inner_simp))
187 }
188 TLExpr::Sqrt(inner) => {
189 let inner_simp = simplify_algebraic_impl(inner, stats);
190 TLExpr::Sqrt(Box::new(inner_simp))
191 }
192 TLExpr::Exp(inner) => {
193 let inner_simp = simplify_algebraic_impl(inner, stats);
194 TLExpr::Exp(Box::new(inner_simp))
195 }
196 TLExpr::Log(inner) => {
197 let inner_simp = simplify_algebraic_impl(inner, stats);
198 TLExpr::Log(Box::new(inner_simp))
199 }
200 TLExpr::Sin(inner) => {
201 let inner_simp = simplify_algebraic_impl(inner, stats);
202 TLExpr::Sin(Box::new(inner_simp))
203 }
204 TLExpr::Cos(inner) => {
205 let inner_simp = simplify_algebraic_impl(inner, stats);
206 TLExpr::Cos(Box::new(inner_simp))
207 }
208 TLExpr::Tan(inner) => {
209 let inner_simp = simplify_algebraic_impl(inner, stats);
210 TLExpr::Tan(Box::new(inner_simp))
211 }
212
213 TLExpr::Mod(left, right) => {
215 let left_simp = simplify_algebraic_impl(left, stats);
216 let right_simp = simplify_algebraic_impl(right, stats);
217 TLExpr::Mod(Box::new(left_simp), Box::new(right_simp))
218 }
219
220 TLExpr::And(left, right) => {
222 let left_simp = simplify_algebraic_impl(left, stats);
223 let right_simp = simplify_algebraic_impl(right, stats);
224 TLExpr::And(Box::new(left_simp), Box::new(right_simp))
225 }
226 TLExpr::Or(left, right) => {
227 let left_simp = simplify_algebraic_impl(left, stats);
228 let right_simp = simplify_algebraic_impl(right, stats);
229 TLExpr::Or(Box::new(left_simp), Box::new(right_simp))
230 }
231 TLExpr::Not(inner) => {
232 let inner_simp = simplify_algebraic_impl(inner, stats);
233 TLExpr::Not(Box::new(inner_simp))
234 }
235 TLExpr::Imply(left, right) => {
236 let left_simp = simplify_algebraic_impl(left, stats);
237 let right_simp = simplify_algebraic_impl(right, stats);
238 TLExpr::Imply(Box::new(left_simp), Box::new(right_simp))
239 }
240
241 TLExpr::Eq(left, right) => {
243 let left_simp = simplify_algebraic_impl(left, stats);
244 let right_simp = simplify_algebraic_impl(right, stats);
245 TLExpr::Eq(Box::new(left_simp), Box::new(right_simp))
246 }
247 TLExpr::Lt(left, right) => {
248 let left_simp = simplify_algebraic_impl(left, stats);
249 let right_simp = simplify_algebraic_impl(right, stats);
250 TLExpr::Lt(Box::new(left_simp), Box::new(right_simp))
251 }
252 TLExpr::Gt(left, right) => {
253 let left_simp = simplify_algebraic_impl(left, stats);
254 let right_simp = simplify_algebraic_impl(right, stats);
255 TLExpr::Gt(Box::new(left_simp), Box::new(right_simp))
256 }
257 TLExpr::Lte(left, right) => {
258 let left_simp = simplify_algebraic_impl(left, stats);
259 let right_simp = simplify_algebraic_impl(right, stats);
260 TLExpr::Lte(Box::new(left_simp), Box::new(right_simp))
261 }
262 TLExpr::Gte(left, right) => {
263 let left_simp = simplify_algebraic_impl(left, stats);
264 let right_simp = simplify_algebraic_impl(right, stats);
265 TLExpr::Gte(Box::new(left_simp), Box::new(right_simp))
266 }
267
268 TLExpr::Exists { var, domain, body } => {
270 let body_simp = simplify_algebraic_impl(body, stats);
271 TLExpr::Exists {
272 var: var.clone(),
273 domain: domain.clone(),
274 body: Box::new(body_simp),
275 }
276 }
277 TLExpr::ForAll { var, domain, body } => {
278 let body_simp = simplify_algebraic_impl(body, stats);
279 TLExpr::ForAll {
280 var: var.clone(),
281 domain: domain.clone(),
282 body: Box::new(body_simp),
283 }
284 }
285 TLExpr::Aggregate {
286 op,
287 var,
288 domain,
289 body,
290 group_by,
291 } => {
292 let body_simp = simplify_algebraic_impl(body, stats);
293 TLExpr::Aggregate {
294 op: op.clone(),
295 var: var.clone(),
296 domain: domain.clone(),
297 body: Box::new(body_simp),
298 group_by: group_by.clone(),
299 }
300 }
301 TLExpr::IfThenElse {
302 condition,
303 then_branch,
304 else_branch,
305 } => {
306 let cond_simp = simplify_algebraic_impl(condition, stats);
307 let then_simp = simplify_algebraic_impl(then_branch, stats);
308 let else_simp = simplify_algebraic_impl(else_branch, stats);
309 TLExpr::IfThenElse {
310 condition: Box::new(cond_simp),
311 then_branch: Box::new(then_simp),
312 else_branch: Box::new(else_simp),
313 }
314 }
315 TLExpr::Let { var, value, body } => {
316 let value_simp = simplify_algebraic_impl(value, stats);
317 let body_simp = simplify_algebraic_impl(body, stats);
318 TLExpr::Let {
319 var: var.clone(),
320 value: Box::new(value_simp),
321 body: Box::new(body_simp),
322 }
323 }
324
325 TLExpr::TNorm { kind, left, right } => {
327 let left_simp = simplify_algebraic_impl(left, stats);
328 let right_simp = simplify_algebraic_impl(right, stats);
329 TLExpr::TNorm {
330 kind: *kind,
331 left: Box::new(left_simp),
332 right: Box::new(right_simp),
333 }
334 }
335 TLExpr::TCoNorm { kind, left, right } => {
336 let left_simp = simplify_algebraic_impl(left, stats);
337 let right_simp = simplify_algebraic_impl(right, stats);
338 TLExpr::TCoNorm {
339 kind: *kind,
340 left: Box::new(left_simp),
341 right: Box::new(right_simp),
342 }
343 }
344 TLExpr::FuzzyNot { kind, expr: inner } => {
345 let inner_simp = simplify_algebraic_impl(inner, stats);
346 TLExpr::FuzzyNot {
347 kind: *kind,
348 expr: Box::new(inner_simp),
349 }
350 }
351 TLExpr::FuzzyImplication {
352 kind,
353 premise,
354 conclusion,
355 } => {
356 let premise_simp = simplify_algebraic_impl(premise, stats);
357 let conclusion_simp = simplify_algebraic_impl(conclusion, stats);
358 TLExpr::FuzzyImplication {
359 kind: *kind,
360 premise: Box::new(premise_simp),
361 conclusion: Box::new(conclusion_simp),
362 }
363 }
364 TLExpr::SoftExists {
365 var,
366 domain,
367 body,
368 temperature,
369 } => {
370 let body_simp = simplify_algebraic_impl(body, stats);
371 TLExpr::SoftExists {
372 var: var.clone(),
373 domain: domain.clone(),
374 body: Box::new(body_simp),
375 temperature: *temperature,
376 }
377 }
378 TLExpr::SoftForAll {
379 var,
380 domain,
381 body,
382 temperature,
383 } => {
384 let body_simp = simplify_algebraic_impl(body, stats);
385 TLExpr::SoftForAll {
386 var: var.clone(),
387 domain: domain.clone(),
388 body: Box::new(body_simp),
389 temperature: *temperature,
390 }
391 }
392 TLExpr::WeightedRule { weight, rule } => {
393 let rule_simp = simplify_algebraic_impl(rule, stats);
394 TLExpr::WeightedRule {
395 weight: *weight,
396 rule: Box::new(rule_simp),
397 }
398 }
399 TLExpr::ProbabilisticChoice { alternatives } => {
400 let alts_simp: Vec<_> = alternatives
401 .iter()
402 .map(|(w, e)| (*w, simplify_algebraic_impl(e, stats)))
403 .collect();
404 TLExpr::ProbabilisticChoice {
405 alternatives: alts_simp,
406 }
407 }
408
409 TLExpr::Box(inner) => TLExpr::Box(Box::new(simplify_algebraic_impl(inner, stats))),
411 TLExpr::Diamond(inner) => TLExpr::Diamond(Box::new(simplify_algebraic_impl(inner, stats))),
412 TLExpr::Next(inner) => TLExpr::Next(Box::new(simplify_algebraic_impl(inner, stats))),
413 TLExpr::Eventually(inner) => {
414 TLExpr::Eventually(Box::new(simplify_algebraic_impl(inner, stats)))
415 }
416 TLExpr::Always(inner) => TLExpr::Always(Box::new(simplify_algebraic_impl(inner, stats))),
417 TLExpr::Until { before, after } => TLExpr::Until {
418 before: Box::new(simplify_algebraic_impl(before, stats)),
419 after: Box::new(simplify_algebraic_impl(after, stats)),
420 },
421 TLExpr::Release { released, releaser } => TLExpr::Release {
422 released: Box::new(simplify_algebraic_impl(released, stats)),
423 releaser: Box::new(simplify_algebraic_impl(releaser, stats)),
424 },
425 TLExpr::WeakUntil { before, after } => TLExpr::WeakUntil {
426 before: Box::new(simplify_algebraic_impl(before, stats)),
427 after: Box::new(simplify_algebraic_impl(after, stats)),
428 },
429 TLExpr::StrongRelease { released, releaser } => TLExpr::StrongRelease {
430 released: Box::new(simplify_algebraic_impl(released, stats)),
431 releaser: Box::new(simplify_algebraic_impl(releaser, stats)),
432 },
433
434 TLExpr::Pred { .. } | TLExpr::Constant(_) | TLExpr::Score(_) => expr.clone(),
436 _ => expr.clone(),
438 }
439}
440
441fn is_zero(expr: &TLExpr) -> bool {
443 matches!(expr, TLExpr::Constant(x) if x.abs() < f64::EPSILON)
444}
445
446fn is_one(expr: &TLExpr) -> bool {
448 matches!(expr, TLExpr::Constant(x) if (x - 1.0).abs() < f64::EPSILON)
449}
450
451fn expressions_equal(a: &TLExpr, b: &TLExpr) -> bool {
453 match (a, b) {
454 (TLExpr::Constant(x), TLExpr::Constant(y)) => (x - y).abs() < f64::EPSILON,
455 (TLExpr::Pred { name: n1, args: a1 }, TLExpr::Pred { name: n2, args: a2 }) => {
456 n1 == n2 && a1 == a2
457 }
458 _ => false, }
460}
461
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    /// Builds the predicate `x(i)` used as the opaque operand in every test.
    fn operand() -> Box<TLExpr> {
        Box::new(TLExpr::pred("x", vec![Term::var("i")]))
    }

    /// Builds a boxed numeric constant.
    fn constant(value: f64) -> Box<TLExpr> {
        Box::new(TLExpr::Constant(value))
    }

    #[test]
    fn test_addition_identity() {
        // x + 0 => x
        let sum = TLExpr::Add(operand(), constant(0.0));

        let (simplified, counters) = simplify_algebraic(&sum);
        assert!(matches!(simplified, TLExpr::Pred { .. }));
        assert_eq!(counters.identities_eliminated, 1);
    }

    #[test]
    fn test_multiplication_identity() {
        // x * 1 => x
        let product = TLExpr::Mul(operand(), constant(1.0));

        let (simplified, counters) = simplify_algebraic(&product);
        assert!(matches!(simplified, TLExpr::Pred { .. }));
        assert_eq!(counters.identities_eliminated, 1);
    }

    #[test]
    fn test_multiplication_annihilation() {
        // x * 0 => 0
        let product = TLExpr::Mul(operand(), constant(0.0));

        let (simplified, counters) = simplify_algebraic(&product);
        assert!(matches!(simplified, TLExpr::Constant(0.0)));
        assert_eq!(counters.annihilations_applied, 1);
    }

    #[test]
    fn test_power_identities() {
        // x ^ 0 => 1
        let to_the_zeroth = TLExpr::Pow(operand(), constant(0.0));
        let (simplified_zeroth, counters_zeroth) = simplify_algebraic(&to_the_zeroth);
        assert!(matches!(simplified_zeroth, TLExpr::Constant(1.0)));
        assert_eq!(counters_zeroth.identities_eliminated, 1);

        // x ^ 1 => x
        let to_the_first = TLExpr::Pow(operand(), constant(1.0));
        let (simplified_first, counters_first) = simplify_algebraic(&to_the_first);
        assert!(matches!(simplified_first, TLExpr::Pred { .. }));
        assert_eq!(counters_first.identities_eliminated, 1);
    }

    #[test]
    fn test_idempotent_min_max() {
        // min(x, x) => x
        let minimum = TLExpr::Min(operand(), operand());
        let (simplified_min, counters_min) = simplify_algebraic(&minimum);
        assert!(matches!(simplified_min, TLExpr::Pred { .. }));
        assert_eq!(counters_min.idempotent_simplified, 1);

        // max(x, x) => x
        let maximum = TLExpr::Max(operand(), operand());
        let (simplified_max, counters_max) = simplify_algebraic(&maximum);
        assert!(matches!(simplified_max, TLExpr::Pred { .. }));
        assert_eq!(counters_max.idempotent_simplified, 1);
    }

    #[test]
    fn test_nested_simplification() {
        // (x + 0) * 1 => x, via two identity eliminations
        let inner_sum = TLExpr::Add(operand(), constant(0.0));
        let outer_product = TLExpr::Mul(Box::new(inner_sum), constant(1.0));

        let (simplified, counters) = simplify_algebraic(&outer_product);
        assert!(matches!(simplified, TLExpr::Pred { .. }));
        assert_eq!(counters.identities_eliminated, 2);
    }
}