1use tensorlogic_ir::TLExpr;
27
/// Counters collected during a single [`reduce_strength`] run.
///
/// Every rewrite the pass applies increments exactly one of the three
/// category counters; `total_processed` counts every expression node visited,
/// whether or not it was rewritten.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StrengthReductionStats {
    /// Powers (and divisions by 2 or 4) replaced by cheaper arithmetic.
    pub power_reductions: usize,
    /// Operations removed entirely (e.g. `x^0`, `x^1`, `x / 1`, `0 / c`).
    pub operations_eliminated: usize,
    /// Rewrites involving `exp`, `log`, `sqrt` or `abs`.
    pub special_function_optimizations: usize,
    /// Number of expression nodes visited by the pass.
    pub total_processed: usize,
}

impl StrengthReductionStats {
    /// Sum of the three rewrite counters (`total_processed` is not included).
    #[must_use]
    pub fn total_optimizations(&self) -> usize {
        self.power_reductions + self.operations_eliminated + self.special_function_optimizations
    }
}
47
48pub fn reduce_strength(expr: &TLExpr) -> (TLExpr, StrengthReductionStats) {
60 let mut stats = StrengthReductionStats::default();
61 let result = reduce_strength_impl(expr, &mut stats);
62 (result, stats)
63}
64
65fn reduce_strength_impl(expr: &TLExpr, stats: &mut StrengthReductionStats) -> TLExpr {
66 stats.total_processed += 1;
67
68 match expr {
69 TLExpr::Pow(base, exp) => {
71 let base_opt = reduce_strength_impl(base, stats);
72 let exp_opt = reduce_strength_impl(exp, stats);
73
74 if let TLExpr::Constant(n) = &exp_opt {
76 if *n == 0.0 {
78 stats.operations_eliminated += 1;
79 return TLExpr::Constant(1.0);
80 }
81 if *n == 1.0 {
83 stats.operations_eliminated += 1;
84 return base_opt;
85 }
86 if *n == 2.0 {
88 stats.power_reductions += 1;
89 return TLExpr::mul(base_opt.clone(), base_opt);
90 }
91 if *n == 3.0 {
93 stats.power_reductions += 1;
94 return TLExpr::mul(base_opt.clone(), TLExpr::mul(base_opt.clone(), base_opt));
95 }
96 if *n == -1.0 {
98 stats.power_reductions += 1;
99 return TLExpr::div(TLExpr::Constant(1.0), base_opt);
100 }
101 if *n == 0.5 {
103 stats.power_reductions += 1;
104 return TLExpr::sqrt(base_opt);
105 }
106 }
107
108 TLExpr::Pow(Box::new(base_opt), Box::new(exp_opt))
109 }
110
111 TLExpr::Exp(inner) => {
113 let inner_opt = reduce_strength_impl(inner, stats);
114
115 if let TLExpr::Constant(n) = &inner_opt {
117 if *n == 0.0 {
118 stats.special_function_optimizations += 1;
119 return TLExpr::Constant(1.0);
120 }
121 if *n == 1.0 {
123 stats.special_function_optimizations += 1;
124 return TLExpr::Constant(std::f64::consts::E);
125 }
126 }
127
128 if let TLExpr::Log(log_inner) = &inner_opt {
130 stats.special_function_optimizations += 1;
131 return (**log_inner).clone();
132 }
133
134 TLExpr::Exp(Box::new(inner_opt))
135 }
136
137 TLExpr::Log(inner) => {
139 let inner_opt = reduce_strength_impl(inner, stats);
140
141 if let TLExpr::Constant(n) = &inner_opt {
143 if *n == 1.0 {
144 stats.special_function_optimizations += 1;
145 return TLExpr::Constant(0.0);
146 }
147 if (*n - std::f64::consts::E).abs() < 1e-10 {
149 stats.special_function_optimizations += 1;
150 return TLExpr::Constant(1.0);
151 }
152 }
153
154 if let TLExpr::Exp(exp_inner) = &inner_opt {
156 stats.special_function_optimizations += 1;
157 return (**exp_inner).clone();
158 }
159
160 if let TLExpr::Pow(base, exp) = &inner_opt {
162 if let TLExpr::Constant(_) = exp.as_ref() {
163 stats.special_function_optimizations += 1;
164 return TLExpr::mul((**exp).clone(), TLExpr::log((**base).clone()));
165 }
166 }
167
168 TLExpr::Log(Box::new(inner_opt))
169 }
170
171 TLExpr::Sqrt(inner) => {
173 let inner_opt = reduce_strength_impl(inner, stats);
174
175 if let TLExpr::Constant(n) = &inner_opt {
177 if *n == 0.0 {
178 stats.special_function_optimizations += 1;
179 return TLExpr::Constant(0.0);
180 }
181 if *n == 1.0 {
183 stats.special_function_optimizations += 1;
184 return TLExpr::Constant(1.0);
185 }
186 if *n == 4.0 {
188 stats.special_function_optimizations += 1;
189 return TLExpr::Constant(2.0);
190 }
191 }
192
193 if let TLExpr::Pow(base, exp) = &inner_opt {
195 if let TLExpr::Constant(n) = exp.as_ref() {
196 if *n == 2.0 {
197 stats.special_function_optimizations += 1;
198 return TLExpr::abs((**base).clone());
199 }
200 }
201 }
202
203 if let TLExpr::Mul(lhs, rhs) = &inner_opt {
205 if lhs == rhs {
206 stats.special_function_optimizations += 1;
207 return TLExpr::abs((**lhs).clone());
208 }
209 }
210
211 TLExpr::Sqrt(Box::new(inner_opt))
212 }
213
214 TLExpr::Abs(inner) => {
216 let inner_opt = reduce_strength_impl(inner, stats);
217
218 if let TLExpr::Constant(n) = &inner_opt {
220 stats.special_function_optimizations += 1;
221 return TLExpr::Constant(n.abs());
222 }
223
224 if let TLExpr::Abs(_) = &inner_opt {
226 stats.special_function_optimizations += 1;
227 return inner_opt;
228 }
229
230 TLExpr::Abs(Box::new(inner_opt))
231 }
232
233 TLExpr::Div(lhs, rhs) => {
235 let lhs_opt = reduce_strength_impl(lhs, stats);
236 let rhs_opt = reduce_strength_impl(rhs, stats);
237
238 if let TLExpr::Constant(n) = &rhs_opt {
240 if *n == 1.0 {
241 stats.operations_eliminated += 1;
242 return lhs_opt;
243 }
244 if let TLExpr::Constant(m) = &lhs_opt {
246 if *m == 0.0 {
247 stats.operations_eliminated += 1;
248 return TLExpr::Constant(0.0);
249 }
250 }
251 if *n == 2.0 {
253 stats.power_reductions += 1;
254 return TLExpr::mul(lhs_opt, TLExpr::Constant(0.5));
255 }
256 if *n == 4.0 {
258 stats.power_reductions += 1;
259 return TLExpr::mul(lhs_opt, TLExpr::Constant(0.25));
260 }
261 }
262
263 TLExpr::Div(Box::new(lhs_opt), Box::new(rhs_opt))
264 }
265
266 TLExpr::Mul(lhs, rhs) => {
268 let lhs_opt = reduce_strength_impl(lhs, stats);
269 let rhs_opt = reduce_strength_impl(rhs, stats);
270
271 if let (TLExpr::Exp(a), TLExpr::Exp(b)) = (&lhs_opt, &rhs_opt) {
273 stats.special_function_optimizations += 1;
274 return TLExpr::exp(TLExpr::add((**a).clone(), (**b).clone()));
275 }
276
277 TLExpr::Mul(Box::new(lhs_opt), Box::new(rhs_opt))
278 }
279
280 TLExpr::Add(lhs, rhs) => {
282 let lhs_opt = reduce_strength_impl(lhs, stats);
283 let rhs_opt = reduce_strength_impl(rhs, stats);
284
285 if let (TLExpr::Log(a), TLExpr::Log(b)) = (&lhs_opt, &rhs_opt) {
287 stats.special_function_optimizations += 1;
288 return TLExpr::log(TLExpr::mul((**a).clone(), (**b).clone()));
289 }
290
291 TLExpr::Add(Box::new(lhs_opt), Box::new(rhs_opt))
292 }
293
294 TLExpr::Sub(lhs, rhs) => {
296 let lhs_opt = reduce_strength_impl(lhs, stats);
297 let rhs_opt = reduce_strength_impl(rhs, stats);
298
299 if let (TLExpr::Log(a), TLExpr::Log(b)) = (&lhs_opt, &rhs_opt) {
301 stats.special_function_optimizations += 1;
302 return TLExpr::log(TLExpr::div((**a).clone(), (**b).clone()));
303 }
304
305 TLExpr::Sub(Box::new(lhs_opt), Box::new(rhs_opt))
306 }
307
308 TLExpr::And(lhs, rhs) => {
310 let lhs_opt = reduce_strength_impl(lhs, stats);
311 let rhs_opt = reduce_strength_impl(rhs, stats);
312 TLExpr::And(Box::new(lhs_opt), Box::new(rhs_opt))
313 }
314
315 TLExpr::Or(lhs, rhs) => {
316 let lhs_opt = reduce_strength_impl(lhs, stats);
317 let rhs_opt = reduce_strength_impl(rhs, stats);
318 TLExpr::Or(Box::new(lhs_opt), Box::new(rhs_opt))
319 }
320
321 TLExpr::Not(inner) => {
322 let inner_opt = reduce_strength_impl(inner, stats);
323 TLExpr::Not(Box::new(inner_opt))
324 }
325
326 TLExpr::Imply(lhs, rhs) => {
327 let lhs_opt = reduce_strength_impl(lhs, stats);
328 let rhs_opt = reduce_strength_impl(rhs, stats);
329 TLExpr::Imply(Box::new(lhs_opt), Box::new(rhs_opt))
330 }
331
332 TLExpr::Exists { var, domain, body } => {
333 let body_opt = reduce_strength_impl(body, stats);
334 TLExpr::Exists {
335 var: var.clone(),
336 domain: domain.clone(),
337 body: Box::new(body_opt),
338 }
339 }
340
341 TLExpr::ForAll { var, domain, body } => {
342 let body_opt = reduce_strength_impl(body, stats);
343 TLExpr::ForAll {
344 var: var.clone(),
345 domain: domain.clone(),
346 body: Box::new(body_opt),
347 }
348 }
349
350 TLExpr::Let { var, value, body } => {
351 let value_opt = reduce_strength_impl(value, stats);
352 let body_opt = reduce_strength_impl(body, stats);
353 TLExpr::Let {
354 var: var.clone(),
355 value: Box::new(value_opt),
356 body: Box::new(body_opt),
357 }
358 }
359
360 TLExpr::IfThenElse {
361 condition,
362 then_branch,
363 else_branch,
364 } => {
365 let cond_opt = reduce_strength_impl(condition, stats);
366 let then_opt = reduce_strength_impl(then_branch, stats);
367 let else_opt = reduce_strength_impl(else_branch, stats);
368 TLExpr::IfThenElse {
369 condition: Box::new(cond_opt),
370 then_branch: Box::new(then_opt),
371 else_branch: Box::new(else_opt),
372 }
373 }
374
375 TLExpr::Eq(lhs, rhs) => {
377 let lhs_opt = reduce_strength_impl(lhs, stats);
378 let rhs_opt = reduce_strength_impl(rhs, stats);
379 TLExpr::Eq(Box::new(lhs_opt), Box::new(rhs_opt))
380 }
381
382 TLExpr::Lt(lhs, rhs) => {
383 let lhs_opt = reduce_strength_impl(lhs, stats);
384 let rhs_opt = reduce_strength_impl(rhs, stats);
385 TLExpr::Lt(Box::new(lhs_opt), Box::new(rhs_opt))
386 }
387
388 TLExpr::Lte(lhs, rhs) => {
389 let lhs_opt = reduce_strength_impl(lhs, stats);
390 let rhs_opt = reduce_strength_impl(rhs, stats);
391 TLExpr::Lte(Box::new(lhs_opt), Box::new(rhs_opt))
392 }
393
394 TLExpr::Gt(lhs, rhs) => {
395 let lhs_opt = reduce_strength_impl(lhs, stats);
396 let rhs_opt = reduce_strength_impl(rhs, stats);
397 TLExpr::Gt(Box::new(lhs_opt), Box::new(rhs_opt))
398 }
399
400 TLExpr::Gte(lhs, rhs) => {
401 let lhs_opt = reduce_strength_impl(lhs, stats);
402 let rhs_opt = reduce_strength_impl(rhs, stats);
403 TLExpr::Gte(Box::new(lhs_opt), Box::new(rhs_opt))
404 }
405
406 TLExpr::Min(lhs, rhs) => {
408 let lhs_opt = reduce_strength_impl(lhs, stats);
409 let rhs_opt = reduce_strength_impl(rhs, stats);
410 TLExpr::Min(Box::new(lhs_opt), Box::new(rhs_opt))
411 }
412
413 TLExpr::Max(lhs, rhs) => {
414 let lhs_opt = reduce_strength_impl(lhs, stats);
415 let rhs_opt = reduce_strength_impl(rhs, stats);
416 TLExpr::Max(Box::new(lhs_opt), Box::new(rhs_opt))
417 }
418
419 TLExpr::Box(inner) => {
421 let inner_opt = reduce_strength_impl(inner, stats);
422 TLExpr::Box(Box::new(inner_opt))
423 }
424
425 TLExpr::Diamond(inner) => {
426 let inner_opt = reduce_strength_impl(inner, stats);
427 TLExpr::Diamond(Box::new(inner_opt))
428 }
429
430 TLExpr::Next(inner) => {
432 let inner_opt = reduce_strength_impl(inner, stats);
433 TLExpr::Next(Box::new(inner_opt))
434 }
435
436 TLExpr::Eventually(inner) => {
437 let inner_opt = reduce_strength_impl(inner, stats);
438 TLExpr::Eventually(Box::new(inner_opt))
439 }
440
441 TLExpr::Always(inner) => {
442 let inner_opt = reduce_strength_impl(inner, stats);
443 TLExpr::Always(Box::new(inner_opt))
444 }
445
446 TLExpr::Until { before, after } => {
447 let before_opt = reduce_strength_impl(before, stats);
448 let after_opt = reduce_strength_impl(after, stats);
449 TLExpr::Until {
450 before: Box::new(before_opt),
451 after: Box::new(after_opt),
452 }
453 }
454
455 TLExpr::Pred { .. }
457 | TLExpr::Constant(_)
458 | TLExpr::Score(_)
459 | TLExpr::Mod(_, _)
460 | TLExpr::Floor(_)
461 | TLExpr::Ceil(_)
462 | TLExpr::Round(_)
463 | TLExpr::Sin(_)
464 | TLExpr::Cos(_)
465 | TLExpr::Tan(_)
466 | TLExpr::Aggregate { .. }
467 | TLExpr::TNorm { .. }
468 | TLExpr::TCoNorm { .. }
469 | TLExpr::FuzzyNot { .. }
470 | TLExpr::FuzzyImplication { .. }
471 | TLExpr::SoftExists { .. }
472 | TLExpr::SoftForAll { .. }
473 | TLExpr::WeightedRule { .. }
474 | TLExpr::ProbabilisticChoice { .. }
475 | TLExpr::Release { .. }
476 | TLExpr::WeakUntil { .. }
477 | TLExpr::StrongRelease { .. } => expr.clone(),
478
479 _ => expr.clone(),
481 }
482}
483
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    #[test]
    fn test_power_reduction_x_squared() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::pow(x.clone(), TLExpr::Constant(2.0));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.power_reductions, 1);
        // x^2 -> x * x
        if let TLExpr::Mul(lhs, rhs) = optimized {
            assert_eq!(*lhs, x);
            assert_eq!(*rhs, x);
        } else {
            panic!("Expected Mul expression");
        }
    }

    #[test]
    fn test_power_reduction_x_cubed() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::pow(x.clone(), TLExpr::Constant(3.0));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.power_reductions, 1);
        // x^3 -> x * (x * x)
        assert_eq!(
            optimized,
            TLExpr::mul(x.clone(), TLExpr::mul(x.clone(), x))
        );
    }

    #[test]
    fn test_power_reduction_x_reciprocal() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::pow(x.clone(), TLExpr::Constant(-1.0));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.power_reductions, 1);
        // x^-1 -> 1 / x
        assert_eq!(optimized, TLExpr::div(TLExpr::Constant(1.0), x));
    }

    #[test]
    fn test_power_reduction_x_zero() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::pow(x, TLExpr::Constant(0.0));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.operations_eliminated, 1);
        assert_eq!(optimized, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_power_reduction_x_one() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::pow(x.clone(), TLExpr::Constant(1.0));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.operations_eliminated, 1);
        assert_eq!(optimized, x);
    }

    #[test]
    fn test_power_reduction_x_half() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::pow(x.clone(), TLExpr::Constant(0.5));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.power_reductions, 1);
        // x^0.5 -> sqrt(x)
        assert_eq!(optimized, TLExpr::sqrt(x));
    }

    #[test]
    fn test_power_generic_exponent_kept() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::pow(x, TLExpr::Constant(5.0));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.total_optimizations(), 0);
        // Pow node, its base, and its exponent.
        assert_eq!(stats.total_processed, 3);
        assert_eq!(optimized, expr);
    }

    #[test]
    fn test_exp_zero() {
        let expr = TLExpr::exp(TLExpr::Constant(0.0));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.special_function_optimizations, 1);
        assert_eq!(optimized, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_exp_one() {
        let expr = TLExpr::exp(TLExpr::Constant(1.0));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.special_function_optimizations, 1);
        assert_eq!(optimized, TLExpr::Constant(std::f64::consts::E));
    }

    #[test]
    fn test_log_one() {
        let expr = TLExpr::log(TLExpr::Constant(1.0));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.special_function_optimizations, 1);
        assert_eq!(optimized, TLExpr::Constant(0.0));
    }

    #[test]
    fn test_log_e() {
        let expr = TLExpr::log(TLExpr::Constant(std::f64::consts::E));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.special_function_optimizations, 1);
        assert_eq!(optimized, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_log_constant_power() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::log(TLExpr::pow(x.clone(), TLExpr::Constant(5.0)));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.special_function_optimizations, 1);
        // ln(x^5) -> 5 * ln(x)
        assert_eq!(
            optimized,
            TLExpr::mul(TLExpr::Constant(5.0), TLExpr::log(x))
        );
    }

    #[test]
    fn test_exp_log_inverse() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::exp(TLExpr::log(x.clone()));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.special_function_optimizations, 1);
        assert_eq!(optimized, x);
    }

    #[test]
    fn test_log_exp_inverse() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::log(TLExpr::exp(x.clone()));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.special_function_optimizations, 1);
        assert_eq!(optimized, x);
    }

    #[test]
    fn test_sqrt_constants() {
        for &(input, expected) in &[(0.0, 0.0), (1.0, 1.0), (4.0, 2.0)] {
            let (optimized, stats) = reduce_strength(&TLExpr::sqrt(TLExpr::Constant(input)));
            assert_eq!(stats.special_function_optimizations, 1);
            assert_eq!(optimized, TLExpr::Constant(expected));
        }

        // Other constants are left for later passes.
        let unfolded = TLExpr::sqrt(TLExpr::Constant(9.0));
        let (optimized, stats) = reduce_strength(&unfolded);
        assert_eq!(stats.total_optimizations(), 0);
        assert_eq!(optimized, unfolded);
    }

    #[test]
    fn test_sqrt_x_squared() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::sqrt(TLExpr::pow(x.clone(), TLExpr::Constant(2.0)));
        let (optimized, stats) = reduce_strength(&expr);

        assert!(stats.special_function_optimizations > 0 || stats.power_reductions > 0);
        // sqrt(x^2) -> |x|
        assert_eq!(optimized, TLExpr::abs(x));
    }

    #[test]
    fn test_sqrt_x_times_x() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::sqrt(TLExpr::mul(x.clone(), x.clone()));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.special_function_optimizations, 1);
        assert_eq!(optimized, TLExpr::abs(x));
    }

    #[test]
    fn test_abs_constant() {
        let expr = TLExpr::abs(TLExpr::Constant(-3.0));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.special_function_optimizations, 1);
        assert_eq!(optimized, TLExpr::Constant(3.0));
    }

    #[test]
    fn test_abs_abs() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::abs(TLExpr::abs(x.clone()));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.special_function_optimizations, 1);
        // |(|x|)| -> |x|
        if let TLExpr::Abs(inner) = optimized {
            assert_eq!(*inner, x);
        } else {
            panic!("Expected Abs expression");
        }
    }

    #[test]
    fn test_division_by_one() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::div(x.clone(), TLExpr::Constant(1.0));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.operations_eliminated, 1);
        assert_eq!(optimized, x);
    }

    #[test]
    fn test_division_by_two() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::div(x.clone(), TLExpr::Constant(2.0));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.power_reductions, 1);
        // x / 2 -> x * 0.5
        if let TLExpr::Mul(lhs, rhs) = optimized {
            assert_eq!(*lhs, x);
            assert_eq!(*rhs, TLExpr::Constant(0.5));
        } else {
            panic!("Expected Mul expression");
        }
    }

    #[test]
    fn test_division_by_four() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::div(x.clone(), TLExpr::Constant(4.0));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.power_reductions, 1);
        assert_eq!(optimized, TLExpr::mul(x, TLExpr::Constant(0.25)));
    }

    #[test]
    fn test_zero_numerator_folded() {
        let expr = TLExpr::div(TLExpr::Constant(0.0), TLExpr::Constant(3.0));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.operations_eliminated, 1);
        assert_eq!(optimized, TLExpr::Constant(0.0));
    }

    #[test]
    fn test_exp_product() {
        let a = TLExpr::pred("a", vec![Term::var("i")]);
        let b = TLExpr::pred("b", vec![Term::var("j")]);
        let expr = TLExpr::mul(TLExpr::exp(a.clone()), TLExpr::exp(b.clone()));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.special_function_optimizations, 1);
        // exp(a) * exp(b) -> exp(a + b)
        if let TLExpr::Exp(inner) = optimized {
            if let TLExpr::Add(lhs, rhs) = *inner {
                assert_eq!(*lhs, a);
                assert_eq!(*rhs, b);
            } else {
                panic!("Expected Add inside Exp");
            }
        } else {
            panic!("Expected Exp expression");
        }
    }

    #[test]
    fn test_log_sum() {
        let a = TLExpr::pred("a", vec![Term::var("i")]);
        let b = TLExpr::pred("b", vec![Term::var("j")]);
        let expr = TLExpr::add(TLExpr::log(a.clone()), TLExpr::log(b.clone()));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.special_function_optimizations, 1);
        // ln a + ln b -> ln(a * b)
        if let TLExpr::Log(inner) = optimized {
            if let TLExpr::Mul(lhs, rhs) = *inner {
                assert_eq!(*lhs, a);
                assert_eq!(*rhs, b);
            } else {
                panic!("Expected Mul inside Log");
            }
        } else {
            panic!("Expected Log expression");
        }
    }

    #[test]
    fn test_log_difference() {
        let a = TLExpr::pred("a", vec![Term::var("i")]);
        let b = TLExpr::pred("b", vec![Term::var("j")]);
        let expr = TLExpr::sub(TLExpr::log(a.clone()), TLExpr::log(b.clone()));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.special_function_optimizations, 1);
        // ln a - ln b -> ln(a / b)
        if let TLExpr::Log(inner) = optimized {
            if let TLExpr::Div(lhs, rhs) = *inner {
                assert_eq!(*lhs, a);
                assert_eq!(*rhs, b);
            } else {
                panic!("Expected Div inside Log");
            }
        } else {
            panic!("Expected Log expression");
        }
    }

    #[test]
    fn test_nested_optimization() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        // exp(ln(x^2)): x^2 -> x * x, then exp(ln(..)) cancels.
        let expr = TLExpr::exp(TLExpr::log(TLExpr::pow(x.clone(), TLExpr::Constant(2.0))));
        let (optimized, stats) = reduce_strength(&expr);

        assert!(stats.total_optimizations() >= 2);
        if let TLExpr::Mul(lhs, rhs) = optimized {
            assert_eq!(*lhs, x);
            assert_eq!(*rhs, x);
        } else {
            panic!("Expected Mul expression, got {:?}", optimized);
        }
    }

    #[test]
    fn test_connective_operands_optimized() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let y = TLExpr::pred("y", vec![Term::var("j")]);
        let expr = TLExpr::And(
            Box::new(TLExpr::pow(x.clone(), TLExpr::Constant(2.0))),
            Box::new(y.clone()),
        );
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.power_reductions, 1);
        assert_eq!(
            optimized,
            TLExpr::And(Box::new(TLExpr::mul(x.clone(), x)), Box::new(y))
        );
    }

    #[test]
    fn test_quantifier_body_optimization() {
        let x = TLExpr::pred("x", vec![Term::var("y")]);
        let body = TLExpr::pow(x.clone(), TLExpr::Constant(2.0));
        let expr = TLExpr::exists("y", "D", body);
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.power_reductions, 1);
        if let TLExpr::Exists { body, .. } = optimized {
            assert!(matches!(*body, TLExpr::Mul(_, _)));
        } else {
            panic!("Expected Exists expression");
        }
    }

    #[test]
    fn test_stats_total_optimizations() {
        let stats = StrengthReductionStats {
            power_reductions: 3,
            operations_eliminated: 2,
            special_function_optimizations: 5,
            total_processed: 100,
        };
        assert_eq!(stats.total_optimizations(), 10);
    }
}