Skip to main content

tensorlogic_compiler/optimize/
distributivity.rs

//! Distributivity optimization pass.
//!
//! This module applies distributive laws in the factoring direction: when two
//! sibling terms share an operand, the shared operand is pulled out so that it
//! appears only once in the rewritten expression.
//!
//! # Operations
//!
//! - **Arithmetic factoring**: `a*b + a*c` → `a * (b + c)` and
//!   `a*b - a*c` → `a * (b - c)`
//! - **Logical factoring**: `(a AND b) OR (a AND c)` → `a AND (b OR c)` and
//!   `(a OR b) AND (a OR c)` → `a OR (b AND c)`
//! - **Expansion** (`a * (b + c)` → `a*b + a*c`) is not performed; products
//!   are left in factored form.
//!
//! # Cost Model
//!
//! The intended relative operation costs are:
//! - Addition/Subtraction: cost 1
//! - Multiplication: cost 2
//! - Division: cost 4
//! - Power: cost 8
//!
//! The current implementation does not consult these weights. Factoring is
//! applied whenever a shared operand is found, because it always removes one
//! inner operation: `a*b + a*c` needs two multiplications, while
//! `a * (b + c)` needs one.
21
22use tensorlogic_ir::TLExpr;
23
/// Counters describing what a distributivity pass did.
#[derive(Debug, Clone, Default)]
pub struct DistributivityStats {
    /// How many sums, differences, conjunctions or disjunctions were
    /// rewritten into factored form.
    pub expressions_factored: usize,
    /// How many products were expanded. `optimize_distributivity` never
    /// expands, so it leaves this at zero.
    pub expressions_expanded: usize,
    /// How many common subexpressions were extracted. Not incremented by
    /// `optimize_distributivity`.
    pub common_terms_extracted: usize,
    /// How many expression nodes (leaves included) the pass visited.
    pub total_processed: usize,
}

impl DistributivityStats {
    /// Sum of the rewrite counters: factored + expanded + extracted.
    ///
    /// `total_processed` counts visits rather than rewrites and is therefore
    /// not included.
    pub fn total_optimizations(&self) -> usize {
        [
            self.expressions_factored,
            self.expressions_expanded,
            self.common_terms_extracted,
        ]
        .iter()
        .sum()
    }
}
43
44/// Apply distributivity optimization to an expression.
45///
46/// This pass analyzes multiplication and addition patterns to either
47/// factor or expand based on computational cost.
48///
49/// # Arguments
50///
51/// * `expr` - The expression to optimize
52///
53/// # Returns
54///
55/// A tuple of (optimized expression, statistics)
56pub fn optimize_distributivity(expr: &TLExpr) -> (TLExpr, DistributivityStats) {
57    let mut stats = DistributivityStats::default();
58    let result = optimize_distributivity_impl(expr, &mut stats);
59    (result, stats)
60}
61
62fn optimize_distributivity_impl(expr: &TLExpr, stats: &mut DistributivityStats) -> TLExpr {
63    stats.total_processed += 1;
64
65    match expr {
66        // Look for factoring opportunities in addition
67        TLExpr::Add(lhs, rhs) => {
68            let lhs_opt = optimize_distributivity_impl(lhs, stats);
69            let rhs_opt = optimize_distributivity_impl(rhs, stats);
70
71            // Try to factor: a*b + a*c → a*(b+c)
72            if let Some(factored) = try_factor_add(&lhs_opt, &rhs_opt) {
73                stats.expressions_factored += 1;
74                return factored;
75            }
76
77            TLExpr::Add(Box::new(lhs_opt), Box::new(rhs_opt))
78        }
79
80        // Look for factoring in subtraction
81        TLExpr::Sub(lhs, rhs) => {
82            let lhs_opt = optimize_distributivity_impl(lhs, stats);
83            let rhs_opt = optimize_distributivity_impl(rhs, stats);
84
85            // Try to factor: a*b - a*c → a*(b-c)
86            if let Some(factored) = try_factor_sub(&lhs_opt, &rhs_opt) {
87                stats.expressions_factored += 1;
88                return factored;
89            }
90
91            TLExpr::Sub(Box::new(lhs_opt), Box::new(rhs_opt))
92        }
93
94        // Look for expansion opportunities in multiplication
95        TLExpr::Mul(lhs, rhs) => {
96            let lhs_opt = optimize_distributivity_impl(lhs, stats);
97            let rhs_opt = optimize_distributivity_impl(rhs, stats);
98
99            // Generally prefer factored form, so don't expand by default
100            // Only expand if specifically beneficial (e.g., for vectorization)
101            TLExpr::Mul(Box::new(lhs_opt), Box::new(rhs_opt))
102        }
103
104        // Logic distributivity: AND over OR
105        TLExpr::And(lhs, rhs) => {
106            let lhs_opt = optimize_distributivity_impl(lhs, stats);
107            let rhs_opt = optimize_distributivity_impl(rhs, stats);
108
109            // Try factoring: (a OR b) AND (a OR c) → a OR (b AND c)
110            if let Some(factored) = try_factor_and(&lhs_opt, &rhs_opt) {
111                stats.expressions_factored += 1;
112                return factored;
113            }
114
115            TLExpr::And(Box::new(lhs_opt), Box::new(rhs_opt))
116        }
117
118        // Logic distributivity: OR over AND
119        TLExpr::Or(lhs, rhs) => {
120            let lhs_opt = optimize_distributivity_impl(lhs, stats);
121            let rhs_opt = optimize_distributivity_impl(rhs, stats);
122
123            // Try factoring: (a AND b) OR (a AND c) → a AND (b OR c)
124            if let Some(factored) = try_factor_or(&lhs_opt, &rhs_opt) {
125                stats.expressions_factored += 1;
126                return factored;
127            }
128
129            TLExpr::Or(Box::new(lhs_opt), Box::new(rhs_opt))
130        }
131
132        // Recursive cases
133        TLExpr::Not(inner) => {
134            let inner_opt = optimize_distributivity_impl(inner, stats);
135            TLExpr::Not(Box::new(inner_opt))
136        }
137
138        TLExpr::Imply(lhs, rhs) => {
139            let lhs_opt = optimize_distributivity_impl(lhs, stats);
140            let rhs_opt = optimize_distributivity_impl(rhs, stats);
141            TLExpr::Imply(Box::new(lhs_opt), Box::new(rhs_opt))
142        }
143
144        TLExpr::Div(lhs, rhs) => {
145            let lhs_opt = optimize_distributivity_impl(lhs, stats);
146            let rhs_opt = optimize_distributivity_impl(rhs, stats);
147            TLExpr::Div(Box::new(lhs_opt), Box::new(rhs_opt))
148        }
149
150        TLExpr::Pow(base, exp) => {
151            let base_opt = optimize_distributivity_impl(base, stats);
152            let exp_opt = optimize_distributivity_impl(exp, stats);
153            TLExpr::Pow(Box::new(base_opt), Box::new(exp_opt))
154        }
155
156        TLExpr::Abs(inner) => {
157            let inner_opt = optimize_distributivity_impl(inner, stats);
158            TLExpr::Abs(Box::new(inner_opt))
159        }
160
161        TLExpr::Sqrt(inner) => {
162            let inner_opt = optimize_distributivity_impl(inner, stats);
163            TLExpr::Sqrt(Box::new(inner_opt))
164        }
165
166        TLExpr::Exp(inner) => {
167            let inner_opt = optimize_distributivity_impl(inner, stats);
168            TLExpr::Exp(Box::new(inner_opt))
169        }
170
171        TLExpr::Log(inner) => {
172            let inner_opt = optimize_distributivity_impl(inner, stats);
173            TLExpr::Log(Box::new(inner_opt))
174        }
175
176        TLExpr::Exists { var, domain, body } => {
177            let body_opt = optimize_distributivity_impl(body, stats);
178            TLExpr::Exists {
179                var: var.clone(),
180                domain: domain.clone(),
181                body: Box::new(body_opt),
182            }
183        }
184
185        TLExpr::ForAll { var, domain, body } => {
186            let body_opt = optimize_distributivity_impl(body, stats);
187            TLExpr::ForAll {
188                var: var.clone(),
189                domain: domain.clone(),
190                body: Box::new(body_opt),
191            }
192        }
193
194        TLExpr::Let { var, value, body } => {
195            let value_opt = optimize_distributivity_impl(value, stats);
196            let body_opt = optimize_distributivity_impl(body, stats);
197            TLExpr::Let {
198                var: var.clone(),
199                value: Box::new(value_opt),
200                body: Box::new(body_opt),
201            }
202        }
203
204        TLExpr::IfThenElse {
205            condition,
206            then_branch,
207            else_branch,
208        } => {
209            let cond_opt = optimize_distributivity_impl(condition, stats);
210            let then_opt = optimize_distributivity_impl(then_branch, stats);
211            let else_opt = optimize_distributivity_impl(else_branch, stats);
212            TLExpr::IfThenElse {
213                condition: Box::new(cond_opt),
214                then_branch: Box::new(then_opt),
215                else_branch: Box::new(else_opt),
216            }
217        }
218
219        // Comparison operators
220        TLExpr::Eq(lhs, rhs) => {
221            let lhs_opt = optimize_distributivity_impl(lhs, stats);
222            let rhs_opt = optimize_distributivity_impl(rhs, stats);
223            TLExpr::Eq(Box::new(lhs_opt), Box::new(rhs_opt))
224        }
225
226        TLExpr::Lt(lhs, rhs) => {
227            let lhs_opt = optimize_distributivity_impl(lhs, stats);
228            let rhs_opt = optimize_distributivity_impl(rhs, stats);
229            TLExpr::Lt(Box::new(lhs_opt), Box::new(rhs_opt))
230        }
231
232        TLExpr::Lte(lhs, rhs) => {
233            let lhs_opt = optimize_distributivity_impl(lhs, stats);
234            let rhs_opt = optimize_distributivity_impl(rhs, stats);
235            TLExpr::Lte(Box::new(lhs_opt), Box::new(rhs_opt))
236        }
237
238        TLExpr::Gt(lhs, rhs) => {
239            let lhs_opt = optimize_distributivity_impl(lhs, stats);
240            let rhs_opt = optimize_distributivity_impl(rhs, stats);
241            TLExpr::Gt(Box::new(lhs_opt), Box::new(rhs_opt))
242        }
243
244        TLExpr::Gte(lhs, rhs) => {
245            let lhs_opt = optimize_distributivity_impl(lhs, stats);
246            let rhs_opt = optimize_distributivity_impl(rhs, stats);
247            TLExpr::Gte(Box::new(lhs_opt), Box::new(rhs_opt))
248        }
249
250        // Min/Max
251        TLExpr::Min(lhs, rhs) => {
252            let lhs_opt = optimize_distributivity_impl(lhs, stats);
253            let rhs_opt = optimize_distributivity_impl(rhs, stats);
254            TLExpr::Min(Box::new(lhs_opt), Box::new(rhs_opt))
255        }
256
257        TLExpr::Max(lhs, rhs) => {
258            let lhs_opt = optimize_distributivity_impl(lhs, stats);
259            let rhs_opt = optimize_distributivity_impl(rhs, stats);
260            TLExpr::Max(Box::new(lhs_opt), Box::new(rhs_opt))
261        }
262
263        // Modal logic
264        TLExpr::Box(inner) => {
265            let inner_opt = optimize_distributivity_impl(inner, stats);
266            TLExpr::Box(Box::new(inner_opt))
267        }
268
269        TLExpr::Diamond(inner) => {
270            let inner_opt = optimize_distributivity_impl(inner, stats);
271            TLExpr::Diamond(Box::new(inner_opt))
272        }
273
274        // Temporal logic
275        TLExpr::Next(inner) => {
276            let inner_opt = optimize_distributivity_impl(inner, stats);
277            TLExpr::Next(Box::new(inner_opt))
278        }
279
280        TLExpr::Eventually(inner) => {
281            let inner_opt = optimize_distributivity_impl(inner, stats);
282            TLExpr::Eventually(Box::new(inner_opt))
283        }
284
285        TLExpr::Always(inner) => {
286            let inner_opt = optimize_distributivity_impl(inner, stats);
287            TLExpr::Always(Box::new(inner_opt))
288        }
289
290        TLExpr::Until { before, after } => {
291            let before_opt = optimize_distributivity_impl(before, stats);
292            let after_opt = optimize_distributivity_impl(after, stats);
293            TLExpr::Until {
294                before: Box::new(before_opt),
295                after: Box::new(after_opt),
296            }
297        }
298
299        // Leaves and other variants: no optimization needed
300        TLExpr::Pred { .. }
301        | TLExpr::Constant(_)
302        | TLExpr::Score(_)
303        | TLExpr::Mod(_, _)
304        | TLExpr::Floor(_)
305        | TLExpr::Ceil(_)
306        | TLExpr::Round(_)
307        | TLExpr::Sin(_)
308        | TLExpr::Cos(_)
309        | TLExpr::Tan(_)
310        | TLExpr::Aggregate { .. }
311        | TLExpr::TNorm { .. }
312        | TLExpr::TCoNorm { .. }
313        | TLExpr::FuzzyNot { .. }
314        | TLExpr::FuzzyImplication { .. }
315        | TLExpr::SoftExists { .. }
316        | TLExpr::SoftForAll { .. }
317        | TLExpr::WeightedRule { .. }
318        | TLExpr::ProbabilisticChoice { .. }
319        | TLExpr::Release { .. }
320        | TLExpr::WeakUntil { .. }
321        | TLExpr::StrongRelease { .. } => expr.clone(),
322
323        // All other expression types (enhancements)
324        _ => expr.clone(),
325    }
326}
327
328/// Try to factor a*b + a*c into a*(b+c)
329fn try_factor_add(lhs: &TLExpr, rhs: &TLExpr) -> Option<TLExpr> {
330    // Check if both sides are multiplications
331    if let (TLExpr::Mul(l1, l2), TLExpr::Mul(r1, r2)) = (lhs, rhs) {
332        // Check for common left factor: a*b + a*c → a*(b+c)
333        if l1 == r1 {
334            return Some(TLExpr::mul(
335                (**l1).clone(),
336                TLExpr::add((**l2).clone(), (**r2).clone()),
337            ));
338        }
339        // Check for common right factor: a*b + c*b → (a+c)*b
340        if l2 == r2 {
341            return Some(TLExpr::mul(
342                TLExpr::add((**l1).clone(), (**r1).clone()),
343                (**l2).clone(),
344            ));
345        }
346        // Cross check: a*b + b*c → b*(a+c)
347        if l1 == r2 {
348            return Some(TLExpr::mul(
349                (**l1).clone(),
350                TLExpr::add((**l2).clone(), (**r1).clone()),
351            ));
352        }
353        // Cross check: a*b + c*a → a*(b+c)
354        if l2 == r1 {
355            return Some(TLExpr::mul(
356                (**l2).clone(),
357                TLExpr::add((**l1).clone(), (**r2).clone()),
358            ));
359        }
360    }
361
362    // Check for constant factors: c*a + c*b → c*(a+b)
363    if let (TLExpr::Mul(l1, l2), TLExpr::Mul(r1, r2)) = (lhs, rhs) {
364        if let (TLExpr::Constant(c1), TLExpr::Constant(c2)) = (l1.as_ref(), r1.as_ref()) {
365            if c1 == c2 {
366                return Some(TLExpr::mul(
367                    TLExpr::Constant(*c1),
368                    TLExpr::add((**l2).clone(), (**r2).clone()),
369                ));
370            }
371        }
372    }
373
374    None
375}
376
377/// Try to factor a*b - a*c into a*(b-c)
378fn try_factor_sub(lhs: &TLExpr, rhs: &TLExpr) -> Option<TLExpr> {
379    // Check if both sides are multiplications
380    if let (TLExpr::Mul(l1, l2), TLExpr::Mul(r1, r2)) = (lhs, rhs) {
381        // Check for common left factor: a*b - a*c → a*(b-c)
382        if l1 == r1 {
383            return Some(TLExpr::mul(
384                (**l1).clone(),
385                TLExpr::sub((**l2).clone(), (**r2).clone()),
386            ));
387        }
388        // Check for common right factor: a*b - c*b → (a-c)*b
389        if l2 == r2 {
390            return Some(TLExpr::mul(
391                TLExpr::sub((**l1).clone(), (**r1).clone()),
392                (**l2).clone(),
393            ));
394        }
395    }
396
397    None
398}
399
400/// Try to factor (a OR b) AND (a OR c) into a OR (b AND c)
401fn try_factor_and(lhs: &TLExpr, rhs: &TLExpr) -> Option<TLExpr> {
402    if let (TLExpr::Or(l1, l2), TLExpr::Or(r1, r2)) = (lhs, rhs) {
403        // (a OR b) AND (a OR c) → a OR (b AND c)
404        if l1 == r1 {
405            return Some(TLExpr::or(
406                (**l1).clone(),
407                TLExpr::and((**l2).clone(), (**r2).clone()),
408            ));
409        }
410        // (a OR b) AND (c OR a) → a OR (b AND c)
411        if l1 == r2 {
412            return Some(TLExpr::or(
413                (**l1).clone(),
414                TLExpr::and((**l2).clone(), (**r1).clone()),
415            ));
416        }
417        // (b OR a) AND (a OR c) → a OR (b AND c)
418        if l2 == r1 {
419            return Some(TLExpr::or(
420                (**l2).clone(),
421                TLExpr::and((**l1).clone(), (**r2).clone()),
422            ));
423        }
424        // (b OR a) AND (c OR a) → a OR (b AND c)
425        if l2 == r2 {
426            return Some(TLExpr::or(
427                (**l2).clone(),
428                TLExpr::and((**l1).clone(), (**r1).clone()),
429            ));
430        }
431    }
432
433    None
434}
435
436/// Try to factor (a AND b) OR (a AND c) into a AND (b OR c)
437fn try_factor_or(lhs: &TLExpr, rhs: &TLExpr) -> Option<TLExpr> {
438    if let (TLExpr::And(l1, l2), TLExpr::And(r1, r2)) = (lhs, rhs) {
439        // (a AND b) OR (a AND c) → a AND (b OR c)
440        if l1 == r1 {
441            return Some(TLExpr::and(
442                (**l1).clone(),
443                TLExpr::or((**l2).clone(), (**r2).clone()),
444            ));
445        }
446        // (a AND b) OR (c AND a) → a AND (b OR c)
447        if l1 == r2 {
448            return Some(TLExpr::and(
449                (**l1).clone(),
450                TLExpr::or((**l2).clone(), (**r1).clone()),
451            ));
452        }
453        // (b AND a) OR (a AND c) → a AND (b OR c)
454        if l2 == r1 {
455            return Some(TLExpr::and(
456                (**l2).clone(),
457                TLExpr::or((**l1).clone(), (**r2).clone()),
458            ));
459        }
460        // (b AND a) OR (c AND a) → a AND (b OR c)
461        if l2 == r2 {
462            return Some(TLExpr::and(
463                (**l2).clone(),
464                TLExpr::or((**l1).clone(), (**r1).clone()),
465            ));
466        }
467    }
468
469    None
470}
471
#[cfg(test)]
mod tests {
    //! Behavior tests for the distributivity pass: each test builds a small
    //! expression from distinct predicates and pins the exact rewritten tree
    //! and statistics.

    use super::*;
    use tensorlogic_ir::Term;

    #[test]
    fn test_factor_add_common_left() {
        // a*b + a*c → a*(b+c)
        let a = TLExpr::pred("a", vec![Term::var("i")]);
        let b = TLExpr::pred("b", vec![Term::var("j")]);
        let c = TLExpr::pred("c", vec![Term::var("k")]);

        let expr = TLExpr::add(
            TLExpr::mul(a.clone(), b.clone()),
            TLExpr::mul(a.clone(), c.clone()),
        );

        let (optimized, stats) = optimize_distributivity(&expr);
        assert_eq!(stats.expressions_factored, 1);

        // Should be a * (b + c)
        if let TLExpr::Mul(lhs, rhs) = optimized {
            assert_eq!(*lhs, a);
            if let TLExpr::Add(add_lhs, add_rhs) = *rhs {
                assert_eq!(*add_lhs, b);
                assert_eq!(*add_rhs, c);
            } else {
                panic!("Expected Add on right side of Mul");
            }
        } else {
            panic!("Expected Mul expression");
        }
    }

    #[test]
    fn test_factor_add_common_right() {
        // a*b + c*b → (a+c)*b
        let a = TLExpr::pred("a", vec![Term::var("i")]);
        let b = TLExpr::pred("b", vec![Term::var("j")]);
        let c = TLExpr::pred("c", vec![Term::var("k")]);

        let expr = TLExpr::add(
            TLExpr::mul(a.clone(), b.clone()),
            TLExpr::mul(c.clone(), b.clone()),
        );

        let (optimized, stats) = optimize_distributivity(&expr);
        assert_eq!(stats.expressions_factored, 1);

        // Should be (a + c) * b
        if let TLExpr::Mul(lhs, rhs) = optimized {
            assert_eq!(*rhs, b);
            if let TLExpr::Add(add_lhs, add_rhs) = *lhs {
                assert_eq!(*add_lhs, a);
                assert_eq!(*add_rhs, c);
            } else {
                panic!("Expected Add on left side of Mul");
            }
        } else {
            panic!("Expected Mul expression");
        }
    }

    #[test]
    fn test_factor_add_cross_left_right() {
        // a*b + c*a → a*(b+c)
        let a = TLExpr::pred("a", vec![Term::var("i")]);
        let b = TLExpr::pred("b", vec![Term::var("j")]);
        let c = TLExpr::pred("c", vec![Term::var("k")]);

        let expr = TLExpr::add(
            TLExpr::mul(a.clone(), b.clone()),
            TLExpr::mul(c.clone(), a.clone()),
        );

        let (optimized, stats) = optimize_distributivity(&expr);
        assert_eq!(stats.expressions_factored, 1);
        assert_eq!(optimized, TLExpr::mul(a, TLExpr::add(b, c)));
    }

    #[test]
    fn test_factor_add_cross_right_left() {
        // a*b + b*c → b*(a+c)
        let a = TLExpr::pred("a", vec![Term::var("i")]);
        let b = TLExpr::pred("b", vec![Term::var("j")]);
        let c = TLExpr::pred("c", vec![Term::var("k")]);

        let expr = TLExpr::add(
            TLExpr::mul(a.clone(), b.clone()),
            TLExpr::mul(b.clone(), c.clone()),
        );

        let (optimized, stats) = optimize_distributivity(&expr);
        assert_eq!(stats.expressions_factored, 1);
        assert_eq!(optimized, TLExpr::mul(b, TLExpr::add(a, c)));
    }

    #[test]
    fn test_factor_sub() {
        // a*b - a*c → a*(b-c)
        let a = TLExpr::pred("a", vec![Term::var("i")]);
        let b = TLExpr::pred("b", vec![Term::var("j")]);
        let c = TLExpr::pred("c", vec![Term::var("k")]);

        let expr = TLExpr::sub(
            TLExpr::mul(a.clone(), b.clone()),
            TLExpr::mul(a.clone(), c.clone()),
        );

        let (optimized, stats) = optimize_distributivity(&expr);
        assert_eq!(stats.expressions_factored, 1);

        // Should be exactly a * (b - c); operand order of the difference matters.
        assert_eq!(optimized, TLExpr::mul(a, TLExpr::sub(b, c)));
    }

    #[test]
    fn test_no_factoring_sub() {
        // a*b - c*d → unchanged
        let a = TLExpr::pred("a", vec![Term::var("i")]);
        let b = TLExpr::pred("b", vec![Term::var("j")]);
        let c = TLExpr::pred("c", vec![Term::var("k")]);
        let d = TLExpr::pred("d", vec![Term::var("l")]);

        let expr = TLExpr::sub(TLExpr::mul(a, b), TLExpr::mul(c, d));

        let (optimized, stats) = optimize_distributivity(&expr);
        assert_eq!(stats.expressions_factored, 0);
        assert_eq!(optimized, expr);
    }

    #[test]
    fn test_factor_and_over_or() {
        // (a OR b) AND (a OR c) → a OR (b AND c)
        let a = TLExpr::pred("a", vec![Term::var("i")]);
        let b = TLExpr::pred("b", vec![Term::var("j")]);
        let c = TLExpr::pred("c", vec![Term::var("k")]);

        let expr = TLExpr::and(
            TLExpr::or(a.clone(), b.clone()),
            TLExpr::or(a.clone(), c.clone()),
        );

        let (optimized, stats) = optimize_distributivity(&expr);
        assert_eq!(stats.expressions_factored, 1);

        // Should be a OR (b AND c)
        if let TLExpr::Or(lhs, rhs) = optimized {
            assert_eq!(*lhs, a);
            if let TLExpr::And(and_lhs, and_rhs) = *rhs {
                assert_eq!(*and_lhs, b);
                assert_eq!(*and_rhs, c);
            } else {
                panic!("Expected And on right side of Or");
            }
        } else {
            panic!("Expected Or expression");
        }
    }

    #[test]
    fn test_factor_and_over_or_mixed_positions() {
        // (a OR b) AND (c OR a) → a OR (b AND c)
        let a = TLExpr::pred("a", vec![Term::var("i")]);
        let b = TLExpr::pred("b", vec![Term::var("j")]);
        let c = TLExpr::pred("c", vec![Term::var("k")]);

        let expr = TLExpr::and(
            TLExpr::or(a.clone(), b.clone()),
            TLExpr::or(c.clone(), a.clone()),
        );

        let (optimized, stats) = optimize_distributivity(&expr);
        assert_eq!(stats.expressions_factored, 1);
        assert_eq!(optimized, TLExpr::or(a, TLExpr::and(b, c)));
    }

    #[test]
    fn test_factor_or_over_and() {
        // (a AND b) OR (a AND c) → a AND (b OR c)
        let a = TLExpr::pred("a", vec![Term::var("i")]);
        let b = TLExpr::pred("b", vec![Term::var("j")]);
        let c = TLExpr::pred("c", vec![Term::var("k")]);

        let expr = TLExpr::or(
            TLExpr::and(a.clone(), b.clone()),
            TLExpr::and(a.clone(), c.clone()),
        );

        let (optimized, stats) = optimize_distributivity(&expr);
        assert_eq!(stats.expressions_factored, 1);

        // Should be a AND (b OR c)
        if let TLExpr::And(lhs, rhs) = optimized {
            assert_eq!(*lhs, a);
            if let TLExpr::Or(or_lhs, or_rhs) = *rhs {
                assert_eq!(*or_lhs, b);
                assert_eq!(*or_rhs, c);
            } else {
                panic!("Expected Or on right side of And");
            }
        } else {
            panic!("Expected And expression");
        }
    }

    #[test]
    fn test_factor_or_over_and_shared_right() {
        // (b AND a) OR (c AND a) → a AND (b OR c)
        let a = TLExpr::pred("a", vec![Term::var("i")]);
        let b = TLExpr::pred("b", vec![Term::var("j")]);
        let c = TLExpr::pred("c", vec![Term::var("k")]);

        let expr = TLExpr::or(
            TLExpr::and(b.clone(), a.clone()),
            TLExpr::and(c.clone(), a.clone()),
        );

        let (optimized, stats) = optimize_distributivity(&expr);
        assert_eq!(stats.expressions_factored, 1);
        assert_eq!(optimized, TLExpr::and(a, TLExpr::or(b, c)));
    }

    #[test]
    fn test_no_factoring_possible() {
        // a*b + c*d → no factoring
        let a = TLExpr::pred("a", vec![Term::var("i")]);
        let b = TLExpr::pred("b", vec![Term::var("j")]);
        let c = TLExpr::pred("c", vec![Term::var("k")]);
        let d = TLExpr::pred("d", vec![Term::var("l")]);

        let expr = TLExpr::add(TLExpr::mul(a, b), TLExpr::mul(c, d));

        let (optimized, stats) = optimize_distributivity(&expr);
        assert_eq!(stats.expressions_factored, 0);
        // Should remain unchanged
        assert_eq!(optimized, expr);
    }

    #[test]
    fn test_mul_is_not_expanded() {
        // a*(b+c) stays in factored form
        let a = TLExpr::pred("a", vec![Term::var("i")]);
        let b = TLExpr::pred("b", vec![Term::var("j")]);
        let c = TLExpr::pred("c", vec![Term::var("k")]);

        let expr = TLExpr::mul(a, TLExpr::add(b, c));

        let (optimized, stats) = optimize_distributivity(&expr);
        assert_eq!(optimized, expr);
        assert_eq!(stats.expressions_factored, 0);
        assert_eq!(stats.expressions_expanded, 0);
    }

    #[test]
    fn test_nested_factoring() {
        // (a*b + a*c) + a*d → a*((b+c)+d)
        let a = TLExpr::pred("a", vec![Term::var("i")]);
        let b = TLExpr::pred("b", vec![Term::var("j")]);
        let c = TLExpr::pred("c", vec![Term::var("k")]);
        let d = TLExpr::pred("d", vec![Term::var("l")]);

        let inner = TLExpr::add(
            TLExpr::mul(a.clone(), b.clone()),
            TLExpr::mul(a.clone(), c.clone()),
        );
        let expr = TLExpr::add(inner, TLExpr::mul(a.clone(), d.clone()));

        let (optimized, stats) = optimize_distributivity(&expr);
        // The inner sum factors to a*(b+c); the outer sum then shares `a` with a*d.
        assert_eq!(stats.expressions_factored, 2);
        assert_eq!(
            optimized,
            TLExpr::mul(a, TLExpr::add(TLExpr::add(b, c), d))
        );
    }

    #[test]
    fn test_recurses_into_not() {
        // NOT(a*b + a*c) → NOT(a*(b+c))
        let a = TLExpr::pred("a", vec![Term::var("i")]);
        let b = TLExpr::pred("b", vec![Term::var("j")]);
        let c = TLExpr::pred("c", vec![Term::var("k")]);

        let sum = TLExpr::add(
            TLExpr::mul(a.clone(), b.clone()),
            TLExpr::mul(a.clone(), c.clone()),
        );
        let expr = TLExpr::Not(Box::new(sum));

        let (optimized, stats) = optimize_distributivity(&expr);
        assert_eq!(stats.expressions_factored, 1);
        assert_eq!(
            optimized,
            TLExpr::Not(Box::new(TLExpr::mul(a, TLExpr::add(b, c))))
        );
    }

    #[test]
    fn test_total_processed_counts_every_node() {
        // Add, two Muls and four predicate leaves = 7 visited nodes.
        let a = TLExpr::pred("a", vec![Term::var("i")]);
        let b = TLExpr::pred("b", vec![Term::var("j")]);
        let c = TLExpr::pred("c", vec![Term::var("k")]);

        let expr = TLExpr::add(TLExpr::mul(a.clone(), b), TLExpr::mul(a, c));

        let (_, stats) = optimize_distributivity(&expr);
        assert_eq!(stats.total_processed, 7);
    }

    #[test]
    fn test_quantifier_body() {
        let a = TLExpr::pred("a", vec![Term::var("x"), Term::var("i")]);
        let b = TLExpr::pred("b", vec![Term::var("x"), Term::var("j")]);
        let c = TLExpr::pred("c", vec![Term::var("x"), Term::var("k")]);

        let body = TLExpr::add(
            TLExpr::mul(a.clone(), b.clone()),
            TLExpr::mul(a.clone(), c.clone()),
        );
        let expr = TLExpr::exists("x", "D", body);

        let (optimized, stats) = optimize_distributivity(&expr);
        assert_eq!(stats.expressions_factored, 1);

        if let TLExpr::Exists { body, .. } = optimized {
            // Body should be factored
            assert!(matches!(*body, TLExpr::Mul(_, _)));
        } else {
            panic!("Expected Exists expression");
        }
    }

    #[test]
    fn test_stats_total_optimizations() {
        let stats = DistributivityStats {
            expressions_factored: 3,
            expressions_expanded: 2,
            common_terms_extracted: 1,
            total_processed: 100,
        };
        assert_eq!(stats.total_optimizations(), 6);
    }
}