Skip to main content

tensorlogic_compiler/optimize/
strength_reduction.rs

//! Strength reduction optimization pass.
//!
//! This module provides optimizations that replace expensive operations with
//! cheaper equivalents. Examples include:
//!
//! - `x^2` → `x * x` (avoid power function overhead)
//! - `x^0` → `1` (eliminate unnecessary computation)
//! - `x^1` → `x` (eliminate identity operation)
//! - `exp(0)` → `1` (constant evaluation)
//! - `log(1)` → `0` (constant evaluation)
//! - `sqrt(x*x)` → `abs(x)` (eliminate redundant sqrt)
//!
//! Some rewrites, such as `exp(log(x))` → `x` and `log(a) + log(b)` → `log(a * b)`,
//! are exact only when the arguments of `log` are positive.
//!
//! # Examples
//!
//! ```
//! use tensorlogic_compiler::optimize::reduce_strength;
//! use tensorlogic_ir::{TLExpr, Term};
//!
//! // x^2 → x * x
//! let x = TLExpr::pred("x", vec![Term::var("i")]);
//! let expr = TLExpr::pow(x, TLExpr::Constant(2.0));
//! let (optimized, stats) = reduce_strength(&expr);
//! assert!(stats.power_reductions > 0);
//! assert!(matches!(optimized, TLExpr::Mul(_, _)));
//! ```
25
26use tensorlogic_ir::TLExpr;
27
/// Statistics from strength reduction optimization.
///
/// Each rewrite that fires increments exactly one of the three optimization
/// counters; [`StrengthReductionStats::total_optimizations`] sums them.
#[derive(Debug, Clone, Default)]
pub struct StrengthReductionStats {
    /// Number of operations replaced by a cheaper form: power operations
    /// (e.g. `x^2 → x * x`, `x^0.5 → sqrt(x)`) as well as divisions by a
    /// constant rewritten as multiplications (e.g. `x / 2 → x * 0.5`).
    pub power_reductions: usize,
    /// Number of operations removed outright
    /// (e.g. `x^0 → 1`, `x^1 → x`, `x / 1 → x`, `0 / c → 0`).
    pub operations_eliminated: usize,
    /// Number of rewrites involving `exp`, `log`, `sqrt` and `abs`, covering both
    /// constant evaluation (e.g. `exp(0) → 1`) and combining rules
    /// (e.g. `exp(log(x)) → x`, `log(a) + log(b) → log(a * b)`).
    pub special_function_optimizations: usize,
    /// Number of expression nodes visited by the pass, whether or not they were rewritten.
    pub total_processed: usize,
}

impl StrengthReductionStats {
    /// Get total number of optimizations applied.
    ///
    /// This is the sum of the three rewrite counters; `total_processed` is not included.
    pub fn total_optimizations(&self) -> usize {
        self.power_reductions + self.operations_eliminated + self.special_function_optimizations
    }
}
47
48/// Apply strength reduction optimization to an expression.
49///
50/// This pass replaces expensive operations with cheaper equivalents.
51///
52/// # Arguments
53///
54/// * `expr` - The expression to optimize
55///
56/// # Returns
57///
58/// A tuple of (optimized expression, statistics)
59pub fn reduce_strength(expr: &TLExpr) -> (TLExpr, StrengthReductionStats) {
60    let mut stats = StrengthReductionStats::default();
61    let result = reduce_strength_impl(expr, &mut stats);
62    (result, stats)
63}
64
65fn reduce_strength_impl(expr: &TLExpr, stats: &mut StrengthReductionStats) -> TLExpr {
66    stats.total_processed += 1;
67
68    match expr {
69        // Power optimizations
70        TLExpr::Pow(base, exp) => {
71            let base_opt = reduce_strength_impl(base, stats);
72            let exp_opt = reduce_strength_impl(exp, stats);
73
74            // Check for constant exponents
75            if let TLExpr::Constant(n) = &exp_opt {
76                // x^0 → 1
77                if *n == 0.0 {
78                    stats.operations_eliminated += 1;
79                    return TLExpr::Constant(1.0);
80                }
81                // x^1 → x
82                if *n == 1.0 {
83                    stats.operations_eliminated += 1;
84                    return base_opt;
85                }
86                // x^2 → x * x (avoid power function overhead)
87                if *n == 2.0 {
88                    stats.power_reductions += 1;
89                    return TLExpr::mul(base_opt.clone(), base_opt);
90                }
91                // x^3 → x * x * x
92                if *n == 3.0 {
93                    stats.power_reductions += 1;
94                    return TLExpr::mul(base_opt.clone(), TLExpr::mul(base_opt.clone(), base_opt));
95                }
96                // x^(-1) → 1 / x
97                if *n == -1.0 {
98                    stats.power_reductions += 1;
99                    return TLExpr::div(TLExpr::Constant(1.0), base_opt);
100                }
101                // x^0.5 → sqrt(x)
102                if *n == 0.5 {
103                    stats.power_reductions += 1;
104                    return TLExpr::sqrt(base_opt);
105                }
106            }
107
108            TLExpr::Pow(Box::new(base_opt), Box::new(exp_opt))
109        }
110
111        // Exponential optimizations
112        TLExpr::Exp(inner) => {
113            let inner_opt = reduce_strength_impl(inner, stats);
114
115            // exp(0) → 1
116            if let TLExpr::Constant(n) = &inner_opt {
117                if *n == 0.0 {
118                    stats.special_function_optimizations += 1;
119                    return TLExpr::Constant(1.0);
120                }
121                // exp(1) → e (approximate)
122                if *n == 1.0 {
123                    stats.special_function_optimizations += 1;
124                    return TLExpr::Constant(std::f64::consts::E);
125                }
126            }
127
128            // exp(log(x)) → x
129            if let TLExpr::Log(log_inner) = &inner_opt {
130                stats.special_function_optimizations += 1;
131                return (**log_inner).clone();
132            }
133
134            TLExpr::Exp(Box::new(inner_opt))
135        }
136
137        // Logarithm optimizations
138        TLExpr::Log(inner) => {
139            let inner_opt = reduce_strength_impl(inner, stats);
140
141            // log(1) → 0
142            if let TLExpr::Constant(n) = &inner_opt {
143                if *n == 1.0 {
144                    stats.special_function_optimizations += 1;
145                    return TLExpr::Constant(0.0);
146                }
147                // log(e) → 1
148                if (*n - std::f64::consts::E).abs() < 1e-10 {
149                    stats.special_function_optimizations += 1;
150                    return TLExpr::Constant(1.0);
151                }
152            }
153
154            // log(exp(x)) → x
155            if let TLExpr::Exp(exp_inner) = &inner_opt {
156                stats.special_function_optimizations += 1;
157                return (**exp_inner).clone();
158            }
159
160            // log(x^n) → n * log(x)
161            if let TLExpr::Pow(base, exp) = &inner_opt {
162                if let TLExpr::Constant(_) = exp.as_ref() {
163                    stats.special_function_optimizations += 1;
164                    return TLExpr::mul((**exp).clone(), TLExpr::log((**base).clone()));
165                }
166            }
167
168            TLExpr::Log(Box::new(inner_opt))
169        }
170
171        // Square root optimizations
172        TLExpr::Sqrt(inner) => {
173            let inner_opt = reduce_strength_impl(inner, stats);
174
175            // sqrt(0) → 0
176            if let TLExpr::Constant(n) = &inner_opt {
177                if *n == 0.0 {
178                    stats.special_function_optimizations += 1;
179                    return TLExpr::Constant(0.0);
180                }
181                // sqrt(1) → 1
182                if *n == 1.0 {
183                    stats.special_function_optimizations += 1;
184                    return TLExpr::Constant(1.0);
185                }
186                // sqrt(4) → 2
187                if *n == 4.0 {
188                    stats.special_function_optimizations += 1;
189                    return TLExpr::Constant(2.0);
190                }
191            }
192
193            // sqrt(x^2) → abs(x) (conceptually; we use x for now)
194            if let TLExpr::Pow(base, exp) = &inner_opt {
195                if let TLExpr::Constant(n) = exp.as_ref() {
196                    if *n == 2.0 {
197                        stats.special_function_optimizations += 1;
198                        return TLExpr::abs((**base).clone());
199                    }
200                }
201            }
202
203            // sqrt(x * x) → abs(x)
204            if let TLExpr::Mul(lhs, rhs) = &inner_opt {
205                if lhs == rhs {
206                    stats.special_function_optimizations += 1;
207                    return TLExpr::abs((**lhs).clone());
208                }
209            }
210
211            TLExpr::Sqrt(Box::new(inner_opt))
212        }
213
214        // Absolute value optimizations
215        TLExpr::Abs(inner) => {
216            let inner_opt = reduce_strength_impl(inner, stats);
217
218            // abs(constant) → |constant|
219            if let TLExpr::Constant(n) = &inner_opt {
220                stats.special_function_optimizations += 1;
221                return TLExpr::Constant(n.abs());
222            }
223
224            // abs(abs(x)) → abs(x)
225            if let TLExpr::Abs(_) = &inner_opt {
226                stats.special_function_optimizations += 1;
227                return inner_opt;
228            }
229
230            TLExpr::Abs(Box::new(inner_opt))
231        }
232
233        // Division optimizations
234        TLExpr::Div(lhs, rhs) => {
235            let lhs_opt = reduce_strength_impl(lhs, stats);
236            let rhs_opt = reduce_strength_impl(rhs, stats);
237
238            // x / 1 → x (already handled in algebraic, but good to have here)
239            if let TLExpr::Constant(n) = &rhs_opt {
240                if *n == 1.0 {
241                    stats.operations_eliminated += 1;
242                    return lhs_opt;
243                }
244                // 0 / x → 0
245                if let TLExpr::Constant(m) = &lhs_opt {
246                    if *m == 0.0 {
247                        stats.operations_eliminated += 1;
248                        return TLExpr::Constant(0.0);
249                    }
250                }
251                // x / 2 → x * 0.5 (multiplication is often faster)
252                if *n == 2.0 {
253                    stats.power_reductions += 1;
254                    return TLExpr::mul(lhs_opt, TLExpr::Constant(0.5));
255                }
256                // x / 4 → x * 0.25
257                if *n == 4.0 {
258                    stats.power_reductions += 1;
259                    return TLExpr::mul(lhs_opt, TLExpr::Constant(0.25));
260                }
261            }
262
263            TLExpr::Div(Box::new(lhs_opt), Box::new(rhs_opt))
264        }
265
266        // Multiplication optimizations for powers
267        TLExpr::Mul(lhs, rhs) => {
268            let lhs_opt = reduce_strength_impl(lhs, stats);
269            let rhs_opt = reduce_strength_impl(rhs, stats);
270
271            // exp(a) * exp(b) → exp(a + b)
272            if let (TLExpr::Exp(a), TLExpr::Exp(b)) = (&lhs_opt, &rhs_opt) {
273                stats.special_function_optimizations += 1;
274                return TLExpr::exp(TLExpr::add((**a).clone(), (**b).clone()));
275            }
276
277            TLExpr::Mul(Box::new(lhs_opt), Box::new(rhs_opt))
278        }
279
280        // Addition for exp/log patterns
281        TLExpr::Add(lhs, rhs) => {
282            let lhs_opt = reduce_strength_impl(lhs, stats);
283            let rhs_opt = reduce_strength_impl(rhs, stats);
284
285            // log(a) + log(b) → log(a * b)
286            if let (TLExpr::Log(a), TLExpr::Log(b)) = (&lhs_opt, &rhs_opt) {
287                stats.special_function_optimizations += 1;
288                return TLExpr::log(TLExpr::mul((**a).clone(), (**b).clone()));
289            }
290
291            TLExpr::Add(Box::new(lhs_opt), Box::new(rhs_opt))
292        }
293
294        // Subtraction
295        TLExpr::Sub(lhs, rhs) => {
296            let lhs_opt = reduce_strength_impl(lhs, stats);
297            let rhs_opt = reduce_strength_impl(rhs, stats);
298
299            // log(a) - log(b) → log(a / b)
300            if let (TLExpr::Log(a), TLExpr::Log(b)) = (&lhs_opt, &rhs_opt) {
301                stats.special_function_optimizations += 1;
302                return TLExpr::log(TLExpr::div((**a).clone(), (**b).clone()));
303            }
304
305            TLExpr::Sub(Box::new(lhs_opt), Box::new(rhs_opt))
306        }
307
308        // Recursive cases for compound expressions
309        TLExpr::And(lhs, rhs) => {
310            let lhs_opt = reduce_strength_impl(lhs, stats);
311            let rhs_opt = reduce_strength_impl(rhs, stats);
312            TLExpr::And(Box::new(lhs_opt), Box::new(rhs_opt))
313        }
314
315        TLExpr::Or(lhs, rhs) => {
316            let lhs_opt = reduce_strength_impl(lhs, stats);
317            let rhs_opt = reduce_strength_impl(rhs, stats);
318            TLExpr::Or(Box::new(lhs_opt), Box::new(rhs_opt))
319        }
320
321        TLExpr::Not(inner) => {
322            let inner_opt = reduce_strength_impl(inner, stats);
323            TLExpr::Not(Box::new(inner_opt))
324        }
325
326        TLExpr::Imply(lhs, rhs) => {
327            let lhs_opt = reduce_strength_impl(lhs, stats);
328            let rhs_opt = reduce_strength_impl(rhs, stats);
329            TLExpr::Imply(Box::new(lhs_opt), Box::new(rhs_opt))
330        }
331
332        TLExpr::Exists { var, domain, body } => {
333            let body_opt = reduce_strength_impl(body, stats);
334            TLExpr::Exists {
335                var: var.clone(),
336                domain: domain.clone(),
337                body: Box::new(body_opt),
338            }
339        }
340
341        TLExpr::ForAll { var, domain, body } => {
342            let body_opt = reduce_strength_impl(body, stats);
343            TLExpr::ForAll {
344                var: var.clone(),
345                domain: domain.clone(),
346                body: Box::new(body_opt),
347            }
348        }
349
350        TLExpr::Let { var, value, body } => {
351            let value_opt = reduce_strength_impl(value, stats);
352            let body_opt = reduce_strength_impl(body, stats);
353            TLExpr::Let {
354                var: var.clone(),
355                value: Box::new(value_opt),
356                body: Box::new(body_opt),
357            }
358        }
359
360        TLExpr::IfThenElse {
361            condition,
362            then_branch,
363            else_branch,
364        } => {
365            let cond_opt = reduce_strength_impl(condition, stats);
366            let then_opt = reduce_strength_impl(then_branch, stats);
367            let else_opt = reduce_strength_impl(else_branch, stats);
368            TLExpr::IfThenElse {
369                condition: Box::new(cond_opt),
370                then_branch: Box::new(then_opt),
371                else_branch: Box::new(else_opt),
372            }
373        }
374
375        // Comparison operators
376        TLExpr::Eq(lhs, rhs) => {
377            let lhs_opt = reduce_strength_impl(lhs, stats);
378            let rhs_opt = reduce_strength_impl(rhs, stats);
379            TLExpr::Eq(Box::new(lhs_opt), Box::new(rhs_opt))
380        }
381
382        TLExpr::Lt(lhs, rhs) => {
383            let lhs_opt = reduce_strength_impl(lhs, stats);
384            let rhs_opt = reduce_strength_impl(rhs, stats);
385            TLExpr::Lt(Box::new(lhs_opt), Box::new(rhs_opt))
386        }
387
388        TLExpr::Lte(lhs, rhs) => {
389            let lhs_opt = reduce_strength_impl(lhs, stats);
390            let rhs_opt = reduce_strength_impl(rhs, stats);
391            TLExpr::Lte(Box::new(lhs_opt), Box::new(rhs_opt))
392        }
393
394        TLExpr::Gt(lhs, rhs) => {
395            let lhs_opt = reduce_strength_impl(lhs, stats);
396            let rhs_opt = reduce_strength_impl(rhs, stats);
397            TLExpr::Gt(Box::new(lhs_opt), Box::new(rhs_opt))
398        }
399
400        TLExpr::Gte(lhs, rhs) => {
401            let lhs_opt = reduce_strength_impl(lhs, stats);
402            let rhs_opt = reduce_strength_impl(rhs, stats);
403            TLExpr::Gte(Box::new(lhs_opt), Box::new(rhs_opt))
404        }
405
406        // Min/Max
407        TLExpr::Min(lhs, rhs) => {
408            let lhs_opt = reduce_strength_impl(lhs, stats);
409            let rhs_opt = reduce_strength_impl(rhs, stats);
410            TLExpr::Min(Box::new(lhs_opt), Box::new(rhs_opt))
411        }
412
413        TLExpr::Max(lhs, rhs) => {
414            let lhs_opt = reduce_strength_impl(lhs, stats);
415            let rhs_opt = reduce_strength_impl(rhs, stats);
416            TLExpr::Max(Box::new(lhs_opt), Box::new(rhs_opt))
417        }
418
419        // Modal logic
420        TLExpr::Box(inner) => {
421            let inner_opt = reduce_strength_impl(inner, stats);
422            TLExpr::Box(Box::new(inner_opt))
423        }
424
425        TLExpr::Diamond(inner) => {
426            let inner_opt = reduce_strength_impl(inner, stats);
427            TLExpr::Diamond(Box::new(inner_opt))
428        }
429
430        // Temporal logic
431        TLExpr::Next(inner) => {
432            let inner_opt = reduce_strength_impl(inner, stats);
433            TLExpr::Next(Box::new(inner_opt))
434        }
435
436        TLExpr::Eventually(inner) => {
437            let inner_opt = reduce_strength_impl(inner, stats);
438            TLExpr::Eventually(Box::new(inner_opt))
439        }
440
441        TLExpr::Always(inner) => {
442            let inner_opt = reduce_strength_impl(inner, stats);
443            TLExpr::Always(Box::new(inner_opt))
444        }
445
446        TLExpr::Until { before, after } => {
447            let before_opt = reduce_strength_impl(before, stats);
448            let after_opt = reduce_strength_impl(after, stats);
449            TLExpr::Until {
450                before: Box::new(before_opt),
451                after: Box::new(after_opt),
452            }
453        }
454
455        // Leaves and other variants: no optimization needed
456        TLExpr::Pred { .. }
457        | TLExpr::Constant(_)
458        | TLExpr::Score(_)
459        | TLExpr::Mod(_, _)
460        | TLExpr::Floor(_)
461        | TLExpr::Ceil(_)
462        | TLExpr::Round(_)
463        | TLExpr::Sin(_)
464        | TLExpr::Cos(_)
465        | TLExpr::Tan(_)
466        | TLExpr::Aggregate { .. }
467        | TLExpr::TNorm { .. }
468        | TLExpr::TCoNorm { .. }
469        | TLExpr::FuzzyNot { .. }
470        | TLExpr::FuzzyImplication { .. }
471        | TLExpr::SoftExists { .. }
472        | TLExpr::SoftForAll { .. }
473        | TLExpr::WeightedRule { .. }
474        | TLExpr::ProbabilisticChoice { .. }
475        | TLExpr::Release { .. }
476        | TLExpr::WeakUntil { .. }
477        | TLExpr::StrongRelease { .. } => expr.clone(),
478
479        // All other expression types (enhancements)
480        _ => expr.clone(),
481    }
482}
483
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    /// Builds the non-constant operand `x(i)` used throughout these tests.
    fn pred_x() -> TLExpr {
        TLExpr::pred("x", vec![Term::var("i")])
    }

    #[test]
    fn test_power_reduction_x_squared() {
        let x = pred_x();
        let (optimized, stats) = reduce_strength(&TLExpr::pow(x.clone(), TLExpr::Constant(2.0)));

        assert_eq!(stats.power_reductions, 1);
        // The Pow node, its base and its exponent are each visited once.
        assert_eq!(stats.total_processed, 3);
        assert_eq!(optimized, TLExpr::mul(x.clone(), x));
    }

    #[test]
    fn test_power_reduction_x_cubed() {
        let x = pred_x();
        let (optimized, stats) = reduce_strength(&TLExpr::pow(x.clone(), TLExpr::Constant(3.0)));

        assert_eq!(stats.power_reductions, 1);
        assert_eq!(
            optimized,
            TLExpr::mul(x.clone(), TLExpr::mul(x.clone(), x))
        );
    }

    #[test]
    fn test_power_reduction_x_zero() {
        let (optimized, stats) = reduce_strength(&TLExpr::pow(pred_x(), TLExpr::Constant(0.0)));

        assert_eq!(stats.operations_eliminated, 1);
        assert_eq!(optimized, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_power_reduction_x_one() {
        let x = pred_x();
        let (optimized, stats) = reduce_strength(&TLExpr::pow(x.clone(), TLExpr::Constant(1.0)));

        assert_eq!(stats.operations_eliminated, 1);
        assert_eq!(optimized, x);
    }

    #[test]
    fn test_power_reduction_reciprocal() {
        let x = pred_x();
        let (optimized, stats) = reduce_strength(&TLExpr::pow(x.clone(), TLExpr::Constant(-1.0)));

        assert_eq!(stats.power_reductions, 1);
        assert_eq!(optimized, TLExpr::div(TLExpr::Constant(1.0), x));
    }

    #[test]
    fn test_power_reduction_x_half() {
        let x = pred_x();
        let (optimized, stats) = reduce_strength(&TLExpr::pow(x.clone(), TLExpr::Constant(0.5)));

        assert_eq!(stats.power_reductions, 1);
        assert_eq!(optimized, TLExpr::sqrt(x));
    }

    #[test]
    fn test_unreduced_power_is_kept() {
        let expr = TLExpr::pow(pred_x(), TLExpr::Constant(7.0));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.total_optimizations(), 0);
        assert_eq!(optimized, expr);
    }

    #[test]
    fn test_exp_zero() {
        let (optimized, stats) = reduce_strength(&TLExpr::exp(TLExpr::Constant(0.0)));

        assert_eq!(stats.special_function_optimizations, 1);
        assert_eq!(optimized, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_exp_one() {
        let (optimized, stats) = reduce_strength(&TLExpr::exp(TLExpr::Constant(1.0)));

        assert_eq!(stats.special_function_optimizations, 1);
        assert_eq!(optimized, TLExpr::Constant(std::f64::consts::E));
    }

    #[test]
    fn test_log_one() {
        let (optimized, stats) = reduce_strength(&TLExpr::log(TLExpr::Constant(1.0)));

        assert_eq!(stats.special_function_optimizations, 1);
        assert_eq!(optimized, TLExpr::Constant(0.0));
    }

    #[test]
    fn test_log_e() {
        let expr = TLExpr::log(TLExpr::Constant(std::f64::consts::E));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.special_function_optimizations, 1);
        assert_eq!(optimized, TLExpr::Constant(1.0));
    }

    #[test]
    fn test_exp_log_inverse() {
        let x = pred_x();
        let (optimized, stats) = reduce_strength(&TLExpr::exp(TLExpr::log(x.clone())));

        assert_eq!(stats.special_function_optimizations, 1);
        assert_eq!(optimized, x);
    }

    #[test]
    fn test_log_exp_inverse() {
        let x = pred_x();
        let (optimized, stats) = reduce_strength(&TLExpr::log(TLExpr::exp(x.clone())));

        assert_eq!(stats.special_function_optimizations, 1);
        assert_eq!(optimized, x);
    }

    #[test]
    fn test_log_odd_power() {
        let x = pred_x();
        let expr = TLExpr::log(TLExpr::pow(x.clone(), TLExpr::Constant(5.0)));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.special_function_optimizations, 1);
        assert_eq!(stats.power_reductions, 0);
        assert_eq!(
            optimized,
            TLExpr::mul(TLExpr::Constant(5.0), TLExpr::log(x))
        );
    }

    #[test]
    fn test_sqrt_constants() {
        for (input, expected) in [(0.0, 0.0), (1.0, 1.0), (4.0, 2.0)] {
            let (optimized, stats) = reduce_strength(&TLExpr::sqrt(TLExpr::Constant(input)));

            assert_eq!(stats.special_function_optimizations, 1);
            assert_eq!(optimized, TLExpr::Constant(expected));
        }

        // Other constants are not folded by this pass.
        let expr = TLExpr::sqrt(TLExpr::Constant(9.0));
        let (optimized, stats) = reduce_strength(&expr);
        assert_eq!(stats.total_optimizations(), 0);
        assert_eq!(optimized, expr);
    }

    #[test]
    fn test_sqrt_x_squared() {
        let x = pred_x();
        let expr = TLExpr::sqrt(TLExpr::pow(x.clone(), TLExpr::Constant(2.0)));
        let (optimized, stats) = reduce_strength(&expr);

        // x^2 is first rewritten to x * x, then sqrt(x * x) becomes abs(x).
        assert_eq!(stats.power_reductions, 1);
        assert_eq!(stats.special_function_optimizations, 1);
        assert_eq!(optimized, TLExpr::abs(x));
    }

    #[test]
    fn test_sqrt_x_times_x() {
        let x = pred_x();
        let (optimized, stats) = reduce_strength(&TLExpr::sqrt(TLExpr::mul(x.clone(), x.clone())));

        assert_eq!(stats.special_function_optimizations, 1);
        assert_eq!(optimized, TLExpr::abs(x));
    }

    #[test]
    fn test_abs_constant() {
        let (optimized, stats) = reduce_strength(&TLExpr::abs(TLExpr::Constant(-3.5)));

        assert_eq!(stats.special_function_optimizations, 1);
        assert_eq!(optimized, TLExpr::Constant(3.5));
    }

    #[test]
    fn test_abs_abs() {
        let x = pred_x();
        let (optimized, stats) = reduce_strength(&TLExpr::abs(TLExpr::abs(x.clone())));

        assert_eq!(stats.special_function_optimizations, 1);
        // Should be abs(x), not abs(abs(x))
        assert_eq!(optimized, TLExpr::abs(x));
    }

    #[test]
    fn test_division_by_one() {
        let x = pred_x();
        let (optimized, stats) = reduce_strength(&TLExpr::div(x.clone(), TLExpr::Constant(1.0)));

        assert_eq!(stats.operations_eliminated, 1);
        assert_eq!(optimized, x);
    }

    #[test]
    fn test_division_by_two() {
        let x = pred_x();
        let (optimized, stats) = reduce_strength(&TLExpr::div(x.clone(), TLExpr::Constant(2.0)));

        assert_eq!(stats.power_reductions, 1);
        assert_eq!(optimized, TLExpr::mul(x, TLExpr::Constant(0.5)));
    }

    #[test]
    fn test_division_by_four() {
        let x = pred_x();
        let (optimized, stats) = reduce_strength(&TLExpr::div(x.clone(), TLExpr::Constant(4.0)));

        assert_eq!(stats.power_reductions, 1);
        assert_eq!(optimized, TLExpr::mul(x, TLExpr::Constant(0.25)));
    }

    #[test]
    fn test_zero_divided_by_constant() {
        let expr = TLExpr::div(TLExpr::Constant(0.0), TLExpr::Constant(5.0));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.operations_eliminated, 1);
        assert_eq!(optimized, TLExpr::Constant(0.0));
    }

    #[test]
    fn test_exp_product() {
        let a = TLExpr::pred("a", vec![Term::var("i")]);
        let b = TLExpr::pred("b", vec![Term::var("j")]);
        let expr = TLExpr::mul(TLExpr::exp(a.clone()), TLExpr::exp(b.clone()));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.special_function_optimizations, 1);
        // Should be exp(a + b)
        assert_eq!(optimized, TLExpr::exp(TLExpr::add(a, b)));
    }

    #[test]
    fn test_log_sum() {
        let a = TLExpr::pred("a", vec![Term::var("i")]);
        let b = TLExpr::pred("b", vec![Term::var("j")]);
        let expr = TLExpr::add(TLExpr::log(a.clone()), TLExpr::log(b.clone()));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.special_function_optimizations, 1);
        // Should be log(a * b)
        assert_eq!(optimized, TLExpr::log(TLExpr::mul(a, b)));
    }

    #[test]
    fn test_log_difference() {
        let a = TLExpr::pred("a", vec![Term::var("i")]);
        let b = TLExpr::pred("b", vec![Term::var("j")]);
        let expr = TLExpr::sub(TLExpr::log(a.clone()), TLExpr::log(b.clone()));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.special_function_optimizations, 1);
        // Should be log(a / b)
        assert_eq!(optimized, TLExpr::log(TLExpr::div(a, b)));
    }

    #[test]
    fn test_nested_optimization() {
        // exp(log(x^2)): x^2 → x * x first, then exp(log(..)) cancels.
        let x = pred_x();
        let expr = TLExpr::exp(TLExpr::log(TLExpr::pow(x.clone(), TLExpr::Constant(2.0))));
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.power_reductions, 1);
        assert_eq!(stats.special_function_optimizations, 1);
        assert_eq!(optimized, TLExpr::mul(x.clone(), x), "got {:?}", optimized);
    }

    #[test]
    fn test_structural_recursion() {
        let x = pred_x();
        let expr = TLExpr::And(
            Box::new(TLExpr::pow(x.clone(), TLExpr::Constant(2.0))),
            Box::new(TLExpr::exp(TLExpr::Constant(0.0))),
        );
        let (optimized, stats) = reduce_strength(&expr);

        assert_eq!(stats.power_reductions, 1);
        assert_eq!(stats.special_function_optimizations, 1);
        assert_eq!(
            optimized,
            TLExpr::And(
                Box::new(TLExpr::mul(x.clone(), x)),
                Box::new(TLExpr::Constant(1.0)),
            )
        );
    }

    #[test]
    fn test_quantifier_body_optimization() {
        let x = TLExpr::pred("x", vec![Term::var("y")]);
        let body = TLExpr::pow(x, TLExpr::Constant(2.0));
        let (optimized, stats) = reduce_strength(&TLExpr::exists("y", "D", body));

        assert_eq!(stats.power_reductions, 1);
        match optimized {
            TLExpr::Exists { body, .. } => assert!(matches!(*body, TLExpr::Mul(_, _))),
            other => panic!("Expected Exists expression, got {:?}", other),
        }
    }

    #[test]
    fn test_stats_total_optimizations() {
        let stats = StrengthReductionStats {
            power_reductions: 3,
            operations_eliminated: 2,
            special_function_optimizations: 5,
            total_processed: 100,
        };
        assert_eq!(stats.total_optimizations(), 10);
        assert_eq!(StrengthReductionStats::default().total_optimizations(), 0);
    }
}