Skip to main content

tensorlogic_compiler/optimize/
quantifier_opt.rs

//! Quantifier optimization pass.
//!
//! This module provides optimizations for quantified expressions:
//!
//! - **Loop-invariant code motion**: Move subexpressions that don't depend on
//!   the quantified variable outside the quantifier.
//! - **Quantifier reordering**: Reorder nested quantifiers for better performance.
//! - **Quantifier fusion**: Merge adjacent same-type quantifiers when possible.
//!
//! Note: only loop-invariant code motion is performed by the code visible in
//! this file. Reordering and fusion have statistics counters reserved for them
//! (`quantifiers_reordered`, `quantifiers_fused`), but nothing here increments
//! those counters.
//!
//! # Examples
//!
//! ```
//! use tensorlogic_compiler::optimize::optimize_quantifiers;
//! use tensorlogic_ir::{TLExpr, Term};
//!
//! // Loop invariant: EXISTS x. (a + p(x)) where a doesn't depend on x
//! // Can be optimized to: a + EXISTS x. p(x)
//! let a = TLExpr::pred("a", vec![Term::var("i")]);
//! let px = TLExpr::pred("p", vec![Term::var("x")]);
//! let expr = TLExpr::exists("x", "D", TLExpr::add(a, px));
//!
//! let (optimized, stats) = optimize_quantifiers(&expr);
//! assert!(stats.invariants_hoisted > 0);
//! ```
25
26use std::collections::HashSet;
27use tensorlogic_ir::TLExpr;
28
/// Counters describing what a quantifier optimization run did.
///
/// All counters start at zero (see the derived [`Default`]).
#[derive(Clone, Debug, Default)]
pub struct QuantifierOptStats {
    /// How many loop-invariant subexpressions were moved out of a quantifier.
    pub invariants_hoisted: usize,
    /// How many pairs of nested quantifiers were reordered.
    pub quantifiers_reordered: usize,
    /// How many quantifiers were fused together.
    pub quantifiers_fused: usize,
    /// How many expression nodes the pass visited.
    pub total_processed: usize,
}

impl QuantifierOptStats {
    /// Sum of all applied optimizations (hoists + reorders + fusions).
    ///
    /// `total_processed` is a visit count, not an optimization, and is not
    /// included.
    pub fn total_optimizations(&self) -> usize {
        [
            self.invariants_hoisted,
            self.quantifiers_reordered,
            self.quantifiers_fused,
        ]
        .iter()
        .sum()
    }
}
48
49/// Apply quantifier optimizations to an expression.
50///
51/// This pass hoists loop-invariant expressions and optimizes quantifier
52/// nesting when beneficial.
53///
54/// # Arguments
55///
56/// * `expr` - The expression to optimize
57///
58/// # Returns
59///
60/// A tuple of (optimized expression, statistics)
61pub fn optimize_quantifiers(expr: &TLExpr) -> (TLExpr, QuantifierOptStats) {
62    let mut stats = QuantifierOptStats::default();
63    let result = optimize_quantifiers_impl(expr, &mut stats);
64    (result, stats)
65}
66
67fn optimize_quantifiers_impl(expr: &TLExpr, stats: &mut QuantifierOptStats) -> TLExpr {
68    stats.total_processed += 1;
69
70    match expr {
71        TLExpr::Exists { var, domain, body } => {
72            let body_opt = optimize_quantifiers_impl(body, stats);
73
74            // Try to hoist loop-invariant expressions
75            if let Some(hoisted) = try_hoist_invariant(var, domain, &body_opt) {
76                stats.invariants_hoisted += 1;
77                return optimize_quantifiers_impl(&hoisted, stats);
78            }
79
80            TLExpr::Exists {
81                var: var.clone(),
82                domain: domain.clone(),
83                body: Box::new(body_opt),
84            }
85        }
86
87        TLExpr::ForAll { var, domain, body } => {
88            let body_opt = optimize_quantifiers_impl(body, stats);
89
90            // Try to hoist loop-invariant expressions
91            if let Some(hoisted) = try_hoist_invariant_forall(var, domain, &body_opt) {
92                stats.invariants_hoisted += 1;
93                return optimize_quantifiers_impl(&hoisted, stats);
94            }
95
96            TLExpr::ForAll {
97                var: var.clone(),
98                domain: domain.clone(),
99                body: Box::new(body_opt),
100            }
101        }
102
103        // Recursive cases
104        TLExpr::Add(lhs, rhs) => {
105            let lhs_opt = optimize_quantifiers_impl(lhs, stats);
106            let rhs_opt = optimize_quantifiers_impl(rhs, stats);
107            TLExpr::Add(Box::new(lhs_opt), Box::new(rhs_opt))
108        }
109
110        TLExpr::Sub(lhs, rhs) => {
111            let lhs_opt = optimize_quantifiers_impl(lhs, stats);
112            let rhs_opt = optimize_quantifiers_impl(rhs, stats);
113            TLExpr::Sub(Box::new(lhs_opt), Box::new(rhs_opt))
114        }
115
116        TLExpr::Mul(lhs, rhs) => {
117            let lhs_opt = optimize_quantifiers_impl(lhs, stats);
118            let rhs_opt = optimize_quantifiers_impl(rhs, stats);
119            TLExpr::Mul(Box::new(lhs_opt), Box::new(rhs_opt))
120        }
121
122        TLExpr::Div(lhs, rhs) => {
123            let lhs_opt = optimize_quantifiers_impl(lhs, stats);
124            let rhs_opt = optimize_quantifiers_impl(rhs, stats);
125            TLExpr::Div(Box::new(lhs_opt), Box::new(rhs_opt))
126        }
127
128        TLExpr::And(lhs, rhs) => {
129            let lhs_opt = optimize_quantifiers_impl(lhs, stats);
130            let rhs_opt = optimize_quantifiers_impl(rhs, stats);
131            TLExpr::And(Box::new(lhs_opt), Box::new(rhs_opt))
132        }
133
134        TLExpr::Or(lhs, rhs) => {
135            let lhs_opt = optimize_quantifiers_impl(lhs, stats);
136            let rhs_opt = optimize_quantifiers_impl(rhs, stats);
137            TLExpr::Or(Box::new(lhs_opt), Box::new(rhs_opt))
138        }
139
140        TLExpr::Not(inner) => {
141            let inner_opt = optimize_quantifiers_impl(inner, stats);
142            TLExpr::Not(Box::new(inner_opt))
143        }
144
145        TLExpr::Imply(lhs, rhs) => {
146            let lhs_opt = optimize_quantifiers_impl(lhs, stats);
147            let rhs_opt = optimize_quantifiers_impl(rhs, stats);
148            TLExpr::Imply(Box::new(lhs_opt), Box::new(rhs_opt))
149        }
150
151        TLExpr::Pow(base, exp) => {
152            let base_opt = optimize_quantifiers_impl(base, stats);
153            let exp_opt = optimize_quantifiers_impl(exp, stats);
154            TLExpr::Pow(Box::new(base_opt), Box::new(exp_opt))
155        }
156
157        TLExpr::Abs(inner) => {
158            let inner_opt = optimize_quantifiers_impl(inner, stats);
159            TLExpr::Abs(Box::new(inner_opt))
160        }
161
162        TLExpr::Sqrt(inner) => {
163            let inner_opt = optimize_quantifiers_impl(inner, stats);
164            TLExpr::Sqrt(Box::new(inner_opt))
165        }
166
167        TLExpr::Exp(inner) => {
168            let inner_opt = optimize_quantifiers_impl(inner, stats);
169            TLExpr::Exp(Box::new(inner_opt))
170        }
171
172        TLExpr::Log(inner) => {
173            let inner_opt = optimize_quantifiers_impl(inner, stats);
174            TLExpr::Log(Box::new(inner_opt))
175        }
176
177        TLExpr::Let { var, value, body } => {
178            let value_opt = optimize_quantifiers_impl(value, stats);
179            let body_opt = optimize_quantifiers_impl(body, stats);
180            TLExpr::Let {
181                var: var.clone(),
182                value: Box::new(value_opt),
183                body: Box::new(body_opt),
184            }
185        }
186
187        TLExpr::IfThenElse {
188            condition,
189            then_branch,
190            else_branch,
191        } => {
192            let cond_opt = optimize_quantifiers_impl(condition, stats);
193            let then_opt = optimize_quantifiers_impl(then_branch, stats);
194            let else_opt = optimize_quantifiers_impl(else_branch, stats);
195            TLExpr::IfThenElse {
196                condition: Box::new(cond_opt),
197                then_branch: Box::new(then_opt),
198                else_branch: Box::new(else_opt),
199            }
200        }
201
202        // Comparison operators
203        TLExpr::Eq(lhs, rhs) => {
204            let lhs_opt = optimize_quantifiers_impl(lhs, stats);
205            let rhs_opt = optimize_quantifiers_impl(rhs, stats);
206            TLExpr::Eq(Box::new(lhs_opt), Box::new(rhs_opt))
207        }
208
209        TLExpr::Lt(lhs, rhs) => {
210            let lhs_opt = optimize_quantifiers_impl(lhs, stats);
211            let rhs_opt = optimize_quantifiers_impl(rhs, stats);
212            TLExpr::Lt(Box::new(lhs_opt), Box::new(rhs_opt))
213        }
214
215        TLExpr::Lte(lhs, rhs) => {
216            let lhs_opt = optimize_quantifiers_impl(lhs, stats);
217            let rhs_opt = optimize_quantifiers_impl(rhs, stats);
218            TLExpr::Lte(Box::new(lhs_opt), Box::new(rhs_opt))
219        }
220
221        TLExpr::Gt(lhs, rhs) => {
222            let lhs_opt = optimize_quantifiers_impl(lhs, stats);
223            let rhs_opt = optimize_quantifiers_impl(rhs, stats);
224            TLExpr::Gt(Box::new(lhs_opt), Box::new(rhs_opt))
225        }
226
227        TLExpr::Gte(lhs, rhs) => {
228            let lhs_opt = optimize_quantifiers_impl(lhs, stats);
229            let rhs_opt = optimize_quantifiers_impl(rhs, stats);
230            TLExpr::Gte(Box::new(lhs_opt), Box::new(rhs_opt))
231        }
232
233        // Min/Max
234        TLExpr::Min(lhs, rhs) => {
235            let lhs_opt = optimize_quantifiers_impl(lhs, stats);
236            let rhs_opt = optimize_quantifiers_impl(rhs, stats);
237            TLExpr::Min(Box::new(lhs_opt), Box::new(rhs_opt))
238        }
239
240        TLExpr::Max(lhs, rhs) => {
241            let lhs_opt = optimize_quantifiers_impl(lhs, stats);
242            let rhs_opt = optimize_quantifiers_impl(rhs, stats);
243            TLExpr::Max(Box::new(lhs_opt), Box::new(rhs_opt))
244        }
245
246        // Modal logic
247        TLExpr::Box(inner) => {
248            let inner_opt = optimize_quantifiers_impl(inner, stats);
249            TLExpr::Box(Box::new(inner_opt))
250        }
251
252        TLExpr::Diamond(inner) => {
253            let inner_opt = optimize_quantifiers_impl(inner, stats);
254            TLExpr::Diamond(Box::new(inner_opt))
255        }
256
257        // Temporal logic
258        TLExpr::Next(inner) => {
259            let inner_opt = optimize_quantifiers_impl(inner, stats);
260            TLExpr::Next(Box::new(inner_opt))
261        }
262
263        TLExpr::Eventually(inner) => {
264            let inner_opt = optimize_quantifiers_impl(inner, stats);
265            TLExpr::Eventually(Box::new(inner_opt))
266        }
267
268        TLExpr::Always(inner) => {
269            let inner_opt = optimize_quantifiers_impl(inner, stats);
270            TLExpr::Always(Box::new(inner_opt))
271        }
272
273        TLExpr::Until { before, after } => {
274            let before_opt = optimize_quantifiers_impl(before, stats);
275            let after_opt = optimize_quantifiers_impl(after, stats);
276            TLExpr::Until {
277                before: Box::new(before_opt),
278                after: Box::new(after_opt),
279            }
280        }
281
282        // Leaves and other variants
283        TLExpr::Pred { .. }
284        | TLExpr::Constant(_)
285        | TLExpr::Score(_)
286        | TLExpr::Mod(_, _)
287        | TLExpr::Floor(_)
288        | TLExpr::Ceil(_)
289        | TLExpr::Round(_)
290        | TLExpr::Sin(_)
291        | TLExpr::Cos(_)
292        | TLExpr::Tan(_)
293        | TLExpr::Aggregate { .. }
294        | TLExpr::TNorm { .. }
295        | TLExpr::TCoNorm { .. }
296        | TLExpr::FuzzyNot { .. }
297        | TLExpr::FuzzyImplication { .. }
298        | TLExpr::SoftExists { .. }
299        | TLExpr::SoftForAll { .. }
300        | TLExpr::WeightedRule { .. }
301        | TLExpr::ProbabilisticChoice { .. }
302        | TLExpr::Release { .. }
303        | TLExpr::WeakUntil { .. }
304        | TLExpr::StrongRelease { .. } => expr.clone(),
305
306        // All other expression types (enhancements)
307        _ => expr.clone(),
308    }
309}
310
311/// Collect all free variables in an expression.
312fn free_vars(expr: &TLExpr) -> HashSet<String> {
313    let mut vars = HashSet::new();
314    collect_free_vars(expr, &mut HashSet::new(), &mut vars);
315    vars
316}
317
318fn collect_free_vars(expr: &TLExpr, bound: &mut HashSet<String>, free: &mut HashSet<String>) {
319    match expr {
320        TLExpr::Pred { args, .. } => {
321            for arg in args {
322                if let tensorlogic_ir::Term::Var(v) = arg {
323                    if !bound.contains(v) {
324                        free.insert(v.clone());
325                    }
326                }
327            }
328        }
329
330        TLExpr::Exists { var, body, .. } | TLExpr::ForAll { var, body, .. } => {
331            bound.insert(var.clone());
332            collect_free_vars(body, bound, free);
333            bound.remove(var);
334        }
335
336        TLExpr::Let { var, value, body } => {
337            collect_free_vars(value, bound, free);
338            bound.insert(var.clone());
339            collect_free_vars(body, bound, free);
340            bound.remove(var);
341        }
342
343        TLExpr::Add(lhs, rhs)
344        | TLExpr::Sub(lhs, rhs)
345        | TLExpr::Mul(lhs, rhs)
346        | TLExpr::Div(lhs, rhs)
347        | TLExpr::And(lhs, rhs)
348        | TLExpr::Or(lhs, rhs)
349        | TLExpr::Imply(lhs, rhs)
350        | TLExpr::Eq(lhs, rhs)
351        | TLExpr::Lt(lhs, rhs)
352        | TLExpr::Lte(lhs, rhs)
353        | TLExpr::Gt(lhs, rhs)
354        | TLExpr::Gte(lhs, rhs)
355        | TLExpr::Min(lhs, rhs)
356        | TLExpr::Max(lhs, rhs)
357        | TLExpr::Mod(lhs, rhs) => {
358            collect_free_vars(lhs, bound, free);
359            collect_free_vars(rhs, bound, free);
360        }
361
362        TLExpr::Until { before, after } | TLExpr::WeakUntil { before, after } => {
363            collect_free_vars(before, bound, free);
364            collect_free_vars(after, bound, free);
365        }
366
367        TLExpr::Release { released, releaser } | TLExpr::StrongRelease { released, releaser } => {
368            collect_free_vars(released, bound, free);
369            collect_free_vars(releaser, bound, free);
370        }
371
372        TLExpr::Pow(base, exp) => {
373            collect_free_vars(base, bound, free);
374            collect_free_vars(exp, bound, free);
375        }
376
377        TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
378            collect_free_vars(left, bound, free);
379            collect_free_vars(right, bound, free);
380        }
381
382        TLExpr::FuzzyImplication {
383            premise,
384            conclusion,
385            ..
386        } => {
387            collect_free_vars(premise, bound, free);
388            collect_free_vars(conclusion, bound, free);
389        }
390
391        TLExpr::Not(inner)
392        | TLExpr::Abs(inner)
393        | TLExpr::Sqrt(inner)
394        | TLExpr::Exp(inner)
395        | TLExpr::Log(inner)
396        | TLExpr::Box(inner)
397        | TLExpr::Diamond(inner)
398        | TLExpr::Next(inner)
399        | TLExpr::Eventually(inner)
400        | TLExpr::Always(inner)
401        | TLExpr::Score(inner)
402        | TLExpr::Floor(inner)
403        | TLExpr::Ceil(inner)
404        | TLExpr::Round(inner)
405        | TLExpr::Sin(inner)
406        | TLExpr::Cos(inner)
407        | TLExpr::Tan(inner)
408        | TLExpr::FuzzyNot { expr: inner, .. }
409        | TLExpr::WeightedRule { rule: inner, .. } => {
410            collect_free_vars(inner, bound, free);
411        }
412
413        TLExpr::IfThenElse {
414            condition,
415            then_branch,
416            else_branch,
417        } => {
418            collect_free_vars(condition, bound, free);
419            collect_free_vars(then_branch, bound, free);
420            collect_free_vars(else_branch, bound, free);
421        }
422
423        TLExpr::Aggregate { var, body, .. }
424        | TLExpr::SoftExists { var, body, .. }
425        | TLExpr::SoftForAll { var, body, .. } => {
426            bound.insert(var.clone());
427            collect_free_vars(body, bound, free);
428            bound.remove(var);
429        }
430
431        TLExpr::ProbabilisticChoice { alternatives } => {
432            for (_, expr) in alternatives {
433                collect_free_vars(expr, bound, free);
434            }
435        }
436
437        TLExpr::Constant(_) => {}
438
439        // All other expression types (enhancements)
440        _ => {}
441    }
442}
443
444/// Try to hoist loop-invariant expressions from EXISTS.
445fn try_hoist_invariant(var: &str, domain: &str, body: &TLExpr) -> Option<TLExpr> {
446    match body {
447        // EXISTS x. (a + b) where a doesn't depend on x → a + EXISTS x. b
448        TLExpr::Add(lhs, rhs) => {
449            let lhs_vars = free_vars(lhs);
450            let rhs_vars = free_vars(rhs);
451
452            if !lhs_vars.contains(var) {
453                // LHS is loop-invariant
454                return Some(TLExpr::add(
455                    (**lhs).clone(),
456                    TLExpr::exists(var, domain, (**rhs).clone()),
457                ));
458            }
459            if !rhs_vars.contains(var) {
460                // RHS is loop-invariant
461                return Some(TLExpr::add(
462                    TLExpr::exists(var, domain, (**lhs).clone()),
463                    (**rhs).clone(),
464                ));
465            }
466            None
467        }
468
469        // EXISTS x. (a * b) where a doesn't depend on x → a * EXISTS x. b
470        TLExpr::Mul(lhs, rhs) => {
471            let lhs_vars = free_vars(lhs);
472            let rhs_vars = free_vars(rhs);
473
474            if !lhs_vars.contains(var) {
475                return Some(TLExpr::mul(
476                    (**lhs).clone(),
477                    TLExpr::exists(var, domain, (**rhs).clone()),
478                ));
479            }
480            if !rhs_vars.contains(var) {
481                return Some(TLExpr::mul(
482                    TLExpr::exists(var, domain, (**lhs).clone()),
483                    (**rhs).clone(),
484                ));
485            }
486            None
487        }
488
489        // EXISTS x. (a AND b) where a doesn't depend on x → a AND EXISTS x. b
490        TLExpr::And(lhs, rhs) => {
491            let lhs_vars = free_vars(lhs);
492            let rhs_vars = free_vars(rhs);
493
494            if !lhs_vars.contains(var) {
495                return Some(TLExpr::and(
496                    (**lhs).clone(),
497                    TLExpr::exists(var, domain, (**rhs).clone()),
498                ));
499            }
500            if !rhs_vars.contains(var) {
501                return Some(TLExpr::and(
502                    TLExpr::exists(var, domain, (**lhs).clone()),
503                    (**rhs).clone(),
504                ));
505            }
506            None
507        }
508
509        _ => None,
510    }
511}
512
513/// Try to hoist loop-invariant expressions from FORALL.
514fn try_hoist_invariant_forall(var: &str, domain: &str, body: &TLExpr) -> Option<TLExpr> {
515    match body {
516        // FORALL x. (a AND b) where a doesn't depend on x → a AND FORALL x. b
517        TLExpr::And(lhs, rhs) => {
518            let lhs_vars = free_vars(lhs);
519            let rhs_vars = free_vars(rhs);
520
521            if !lhs_vars.contains(var) {
522                return Some(TLExpr::and(
523                    (**lhs).clone(),
524                    TLExpr::forall(var, domain, (**rhs).clone()),
525                ));
526            }
527            if !rhs_vars.contains(var) {
528                return Some(TLExpr::and(
529                    TLExpr::forall(var, domain, (**lhs).clone()),
530                    (**rhs).clone(),
531                ));
532            }
533            None
534        }
535
536        // FORALL x. (a OR b) where a doesn't depend on x → a OR FORALL x. b
537        TLExpr::Or(lhs, rhs) => {
538            let lhs_vars = free_vars(lhs);
539            let rhs_vars = free_vars(rhs);
540
541            if !lhs_vars.contains(var) {
542                return Some(TLExpr::or(
543                    (**lhs).clone(),
544                    TLExpr::forall(var, domain, (**rhs).clone()),
545                ));
546            }
547            if !rhs_vars.contains(var) {
548                return Some(TLExpr::or(
549                    TLExpr::forall(var, domain, (**lhs).clone()),
550                    (**rhs).clone(),
551                ));
552            }
553            None
554        }
555
556        // FORALL x. (a * b) where a doesn't depend on x → a * FORALL x. b
557        TLExpr::Mul(lhs, rhs) => {
558            let lhs_vars = free_vars(lhs);
559            let rhs_vars = free_vars(rhs);
560
561            if !lhs_vars.contains(var) {
562                return Some(TLExpr::mul(
563                    (**lhs).clone(),
564                    TLExpr::forall(var, domain, (**rhs).clone()),
565                ));
566            }
567            if !rhs_vars.contains(var) {
568                return Some(TLExpr::mul(
569                    TLExpr::forall(var, domain, (**lhs).clone()),
570                    (**rhs).clone(),
571                ));
572            }
573            None
574        }
575
576        _ => None,
577    }
578}
579
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    /// Build the single-argument predicate `name(var)`.
    fn unary(name: &str, var: &str) -> TLExpr {
        TLExpr::pred(name, vec![Term::var(var)])
    }

    #[test]
    fn test_hoist_add_lhs() {
        // EXISTS x. (a + p(x)) → a + EXISTS x. p(x)
        let a = unary("a", "i");
        let expr = TLExpr::exists("x", "D", TLExpr::add(a.clone(), unary("p", "x")));

        let (optimized, stats) = optimize_quantifiers(&expr);
        assert_eq!(stats.invariants_hoisted, 1);

        // Invariant on the left, quantifier on the right.
        match optimized {
            TLExpr::Add(lhs, rhs) => {
                assert_eq!(*lhs, a);
                assert!(matches!(*rhs, TLExpr::Exists { .. }));
            }
            _ => panic!("Expected Add expression"),
        }
    }

    #[test]
    fn test_hoist_add_rhs() {
        // EXISTS x. (p(x) + a) → EXISTS x. p(x) + a
        let a = unary("a", "i");
        let expr = TLExpr::exists("x", "D", TLExpr::add(unary("p", "x"), a.clone()));

        let (optimized, stats) = optimize_quantifiers(&expr);
        assert_eq!(stats.invariants_hoisted, 1);

        // Quantifier on the left, invariant on the right.
        match optimized {
            TLExpr::Add(lhs, rhs) => {
                assert!(matches!(*lhs, TLExpr::Exists { .. }));
                assert_eq!(*rhs, a);
            }
            _ => panic!("Expected Add expression"),
        }
    }

    #[test]
    fn test_hoist_mul() {
        // EXISTS x. (a * p(x)) → a * EXISTS x. p(x)
        let a = unary("a", "i");
        let expr = TLExpr::exists("x", "D", TLExpr::mul(a.clone(), unary("p", "x")));

        let (optimized, stats) = optimize_quantifiers(&expr);
        assert_eq!(stats.invariants_hoisted, 1);

        match optimized {
            TLExpr::Mul(lhs, rhs) => {
                assert_eq!(*lhs, a);
                assert!(matches!(*rhs, TLExpr::Exists { .. }));
            }
            _ => panic!("Expected Mul expression"),
        }
    }

    #[test]
    fn test_hoist_and() {
        // EXISTS x. (a AND p(x)) → a AND EXISTS x. p(x)
        let a = unary("a", "i");
        let expr = TLExpr::exists("x", "D", TLExpr::and(a.clone(), unary("p", "x")));

        let (optimized, stats) = optimize_quantifiers(&expr);
        assert_eq!(stats.invariants_hoisted, 1);

        match optimized {
            TLExpr::And(lhs, rhs) => {
                assert_eq!(*lhs, a);
                assert!(matches!(*rhs, TLExpr::Exists { .. }));
            }
            _ => panic!("Expected And expression"),
        }
    }

    #[test]
    fn test_no_hoist_when_dependent() {
        // EXISTS x. (p(x) + q(x)): both operands use x, so nothing can move.
        let expr = TLExpr::exists("x", "D", TLExpr::add(unary("p", "x"), unary("q", "x")));

        let (_, stats) = optimize_quantifiers(&expr);
        assert_eq!(stats.invariants_hoisted, 0);
    }

    #[test]
    fn test_forall_hoist_and() {
        // FORALL x. (a AND p(x)) → a AND FORALL x. p(x)
        let a = unary("a", "i");
        let expr = TLExpr::forall("x", "D", TLExpr::and(a.clone(), unary("p", "x")));

        let (optimized, stats) = optimize_quantifiers(&expr);
        assert_eq!(stats.invariants_hoisted, 1);

        match optimized {
            TLExpr::And(lhs, rhs) => {
                assert_eq!(*lhs, a);
                assert!(matches!(*rhs, TLExpr::ForAll { .. }));
            }
            _ => panic!("Expected And expression"),
        }
    }

    #[test]
    fn test_forall_hoist_or() {
        // FORALL x. (a OR p(x)) → a OR FORALL x. p(x)
        let a = unary("a", "i");
        let expr = TLExpr::forall("x", "D", TLExpr::or(a.clone(), unary("p", "x")));

        let (optimized, stats) = optimize_quantifiers(&expr);
        assert_eq!(stats.invariants_hoisted, 1);

        match optimized {
            TLExpr::Or(lhs, rhs) => {
                assert_eq!(*lhs, a);
                assert!(matches!(*rhs, TLExpr::ForAll { .. }));
            }
            _ => panic!("Expected Or expression"),
        }
    }

    #[test]
    fn test_nested_hoisting() {
        // EXISTS x. (a + (b * p(x))) where both a and b are invariant.
        // Expected shape after hoisting: a + b * EXISTS x. p(x)
        let a = unary("a", "i");
        let b = unary("b", "j");
        let inner = TLExpr::mul(b, unary("p", "x"));
        let expr = TLExpr::exists("x", "D", TLExpr::add(a.clone(), inner));

        let (optimized, stats) = optimize_quantifiers(&expr);
        // At least the outer invariant must have been hoisted.
        assert!(stats.invariants_hoisted >= 1);

        // The quantifier must no longer be the root of the expression.
        match optimized {
            TLExpr::Add(lhs, _) => assert_eq!(*lhs, a),
            _ => panic!("Expected Add at top level"),
        }
    }

    #[test]
    fn test_free_vars() {
        // EXISTS x. (p(x, y) + q(z)): y and z are free, x is bound.
        let p_xy = TLExpr::pred("p", vec![Term::var("x"), Term::var("y")]);
        let q_z = unary("q", "z");
        let expr = TLExpr::exists("x", "D", TLExpr::add(p_xy, q_z));

        let vars = free_vars(&expr);
        assert!(vars.contains("y"));
        assert!(vars.contains("z"));
        assert!(!vars.contains("x")); // x is bound
    }

    #[test]
    fn test_stats_total_optimizations() {
        let stats = QuantifierOptStats {
            invariants_hoisted: 3,
            quantifiers_reordered: 2,
            quantifiers_fused: 1,
            total_processed: 100,
        };
        // 3 + 2 + 1; total_processed is not part of the sum.
        assert_eq!(stats.total_optimizations(), 6);
    }
}