Skip to main content

tensorlogic_compiler/optimize/
pipeline.rs

1//! Multi-pass optimization pipeline for TLExpr expressions.
2//!
3//! This module provides a unified optimization pipeline that combines multiple
4//! optimization passes and applies them iteratively until a fixed point is reached.
5//!
6//! # Architecture
7//!
8//! The pipeline applies 7 optimization passes in this order:
9//! 1. **Negation optimization**: Push negations inward using De Morgan's laws
10//! 2. **Constant folding**: Evaluate constant expressions at compile time
11//! 3. **Algebraic simplification**: Apply mathematical identities (x+0=x, x*1=x, etc.)
12//! 4. **Strength reduction**: Replace expensive operations with cheaper equivalents (x^2→x*x)
13//! 5. **Distributivity**: Factor common subexpressions (a*b + a*c → a*(b+c))
14//! 6. **Quantifier optimization**: Loop-invariant code motion (∃x.(a+p(x)) → a + ∃x.p(x))
15//! 7. **Dead code elimination**: Remove unreachable code and constant branches
16//!
17//! This order is chosen because:
18//! - Negation optimization can expose more opportunities for other passes
19//! - Constant folding creates simpler expressions for subsequent passes
20//! - Algebraic simplification can create new constants and identity patterns
21//! - Strength reduction makes operations more efficient
22//! - Distributivity reduces redundant computation
23//! - Quantifier optimization hoists loop-invariant code
24//! - Dead code elimination removes unreachable branches created by earlier passes
25//!
26//! # Examples
27//!
28//! ```
29//! use tensorlogic_compiler::optimize::{OptimizationPipeline, PipelineConfig};
30//! use tensorlogic_ir::{TLExpr, Term};
31//!
32//! // Create a pipeline with default configuration
33//! let pipeline = OptimizationPipeline::new();
34//!
35//! // Optimize an expression: NOT(AND(x + 0, 2.0 * 3.0))
36//! let x = TLExpr::pred("x", vec![Term::var("i")]);
37//! let expr = TLExpr::negate(TLExpr::and(
38//!     TLExpr::add(x, TLExpr::Constant(0.0)),
39//!     TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0)),
40//! ));
41//!
42//! let (optimized, stats) = pipeline.optimize(&expr);
43//!
44//! // Pipeline applies multiple passes and reports statistics
45//! assert!(stats.total_iterations > 0);
46//! assert!(stats.constant_folding.binary_ops_folded > 0);
47//! assert!(stats.algebraic.identities_eliminated > 0);
48//! ```
49
50use super::{
51    algebraic::{simplify_algebraic, AlgebraicSimplificationStats},
52    constant_folding::{fold_constants, ConstantFoldingStats},
53    dead_code::{eliminate_dead_code, DeadCodeStats},
54    distributivity::{optimize_distributivity, DistributivityStats},
55    negation::{optimize_negations, NegationOptStats},
56    quantifier_opt::{optimize_quantifiers, QuantifierOptStats},
57    strength_reduction::{reduce_strength, StrengthReductionStats},
58};
59use tensorlogic_ir::TLExpr;
60
61/// Configuration for the optimization pipeline.
62#[derive(Debug, Clone)]
63pub struct PipelineConfig {
64    /// Enable negation optimization pass
65    pub enable_negation_opt: bool,
66    /// Enable constant folding pass
67    pub enable_constant_folding: bool,
68    /// Enable algebraic simplification pass
69    pub enable_algebraic_simplification: bool,
70    /// Enable strength reduction pass
71    pub enable_strength_reduction: bool,
72    /// Enable distributivity optimization pass
73    pub enable_distributivity: bool,
74    /// Enable quantifier optimization pass
75    pub enable_quantifier_opt: bool,
76    /// Enable dead code elimination pass
77    pub enable_dead_code_elimination: bool,
78    /// Maximum number of iterations before stopping
79    pub max_iterations: usize,
80    /// Stop early if an iteration makes no changes
81    pub stop_on_fixed_point: bool,
82}
83
84impl Default for PipelineConfig {
85    fn default() -> Self {
86        Self {
87            enable_negation_opt: true,
88            enable_constant_folding: true,
89            enable_algebraic_simplification: true,
90            enable_strength_reduction: true,
91            enable_distributivity: true,
92            enable_quantifier_opt: true,
93            enable_dead_code_elimination: true,
94            max_iterations: 10,
95            stop_on_fixed_point: true,
96        }
97    }
98}
99
100impl PipelineConfig {
101    /// Create a configuration with all optimizations enabled.
102    pub fn all() -> Self {
103        Self::default()
104    }
105
106    /// Create a configuration with all optimizations disabled.
107    pub fn none() -> Self {
108        Self {
109            enable_negation_opt: false,
110            enable_constant_folding: false,
111            enable_algebraic_simplification: false,
112            enable_strength_reduction: false,
113            enable_distributivity: false,
114            enable_quantifier_opt: false,
115            enable_dead_code_elimination: false,
116            max_iterations: 1,
117            stop_on_fixed_point: true,
118        }
119    }
120
121    /// Create a configuration with only constant folding enabled.
122    pub fn constant_folding_only() -> Self {
123        Self {
124            enable_negation_opt: false,
125            enable_constant_folding: true,
126            enable_algebraic_simplification: false,
127            enable_strength_reduction: false,
128            enable_distributivity: false,
129            enable_quantifier_opt: false,
130            enable_dead_code_elimination: false,
131            max_iterations: 1,
132            stop_on_fixed_point: true,
133        }
134    }
135
136    /// Create a configuration with only algebraic simplification enabled.
137    pub fn algebraic_only() -> Self {
138        Self {
139            enable_negation_opt: false,
140            enable_constant_folding: false,
141            enable_algebraic_simplification: true,
142            enable_strength_reduction: false,
143            enable_distributivity: false,
144            enable_quantifier_opt: false,
145            enable_dead_code_elimination: false,
146            max_iterations: 1,
147            stop_on_fixed_point: true,
148        }
149    }
150
151    /// Create a configuration for aggressive optimization (more iterations).
152    pub fn aggressive() -> Self {
153        Self {
154            enable_negation_opt: true,
155            enable_constant_folding: true,
156            enable_algebraic_simplification: true,
157            enable_strength_reduction: true,
158            enable_distributivity: true,
159            enable_quantifier_opt: true,
160            enable_dead_code_elimination: true,
161            max_iterations: 20,
162            stop_on_fixed_point: true,
163        }
164    }
165
166    /// Builder method to enable/disable negation optimization.
167    pub fn with_negation_opt(mut self, enable: bool) -> Self {
168        self.enable_negation_opt = enable;
169        self
170    }
171
172    /// Builder method to enable/disable constant folding.
173    pub fn with_constant_folding(mut self, enable: bool) -> Self {
174        self.enable_constant_folding = enable;
175        self
176    }
177
178    /// Builder method to enable/disable algebraic simplification.
179    pub fn with_algebraic_simplification(mut self, enable: bool) -> Self {
180        self.enable_algebraic_simplification = enable;
181        self
182    }
183
184    /// Builder method to set maximum iterations.
185    pub fn with_max_iterations(mut self, max: usize) -> Self {
186        self.max_iterations = max;
187        self
188    }
189
190    /// Builder method to enable/disable fixed-point detection.
191    pub fn with_stop_on_fixed_point(mut self, stop: bool) -> Self {
192        self.stop_on_fixed_point = stop;
193        self
194    }
195
196    /// Builder method to enable/disable strength reduction.
197    pub fn with_strength_reduction(mut self, enable: bool) -> Self {
198        self.enable_strength_reduction = enable;
199        self
200    }
201
202    /// Builder method to enable/disable distributivity optimization.
203    pub fn with_distributivity(mut self, enable: bool) -> Self {
204        self.enable_distributivity = enable;
205        self
206    }
207
208    /// Builder method to enable/disable quantifier optimization.
209    pub fn with_quantifier_opt(mut self, enable: bool) -> Self {
210        self.enable_quantifier_opt = enable;
211        self
212    }
213
214    /// Builder method to enable/disable dead code elimination.
215    pub fn with_dead_code_elimination(mut self, enable: bool) -> Self {
216        self.enable_dead_code_elimination = enable;
217        self
218    }
219}
220
221/// Statistics from a single pipeline iteration.
222#[derive(Debug, Clone, Default)]
223pub struct IterationStats {
224    /// Negation optimization statistics
225    pub negation: NegationOptStats,
226    /// Constant folding statistics
227    pub constant_folding: ConstantFoldingStats,
228    /// Algebraic simplification statistics
229    pub algebraic: AlgebraicSimplificationStats,
230    /// Strength reduction statistics
231    pub strength_reduction: StrengthReductionStats,
232    /// Distributivity optimization statistics
233    pub distributivity: DistributivityStats,
234    /// Quantifier optimization statistics
235    pub quantifier_opt: QuantifierOptStats,
236    /// Dead code elimination statistics
237    pub dead_code: DeadCodeStats,
238}
239
240impl IterationStats {
241    /// Check if this iteration made any changes.
242    pub fn made_changes(&self) -> bool {
243        self.negation.double_negations_eliminated > 0
244            || self.negation.demorgans_applied > 0
245            || self.negation.quantifier_negations_pushed > 0
246            || self.constant_folding.binary_ops_folded > 0
247            || self.constant_folding.unary_ops_folded > 0
248            || self.algebraic.identities_eliminated > 0
249            || self.algebraic.annihilations_applied > 0
250            || self.algebraic.idempotent_simplified > 0
251            || self.strength_reduction.total_optimizations() > 0
252            || self.distributivity.total_optimizations() > 0
253            || self.quantifier_opt.total_optimizations() > 0
254            || self.dead_code.total_optimizations() > 0
255    }
256
257    /// Get total number of optimizations applied in this iteration.
258    pub fn total_optimizations(&self) -> usize {
259        self.negation.double_negations_eliminated
260            + self.negation.demorgans_applied
261            + self.negation.quantifier_negations_pushed
262            + self.constant_folding.binary_ops_folded
263            + self.constant_folding.unary_ops_folded
264            + self.algebraic.identities_eliminated
265            + self.algebraic.annihilations_applied
266            + self.algebraic.idempotent_simplified
267            + self.strength_reduction.total_optimizations()
268            + self.distributivity.total_optimizations()
269            + self.quantifier_opt.total_optimizations()
270            + self.dead_code.total_optimizations()
271    }
272}
273
274/// Cumulative statistics from all pipeline iterations.
275#[derive(Debug, Clone, Default)]
276pub struct PipelineStats {
277    /// Total number of iterations performed
278    pub total_iterations: usize,
279    /// Negation optimization statistics (accumulated)
280    pub negation: NegationOptStats,
281    /// Constant folding statistics (accumulated)
282    pub constant_folding: ConstantFoldingStats,
283    /// Algebraic simplification statistics (accumulated)
284    pub algebraic: AlgebraicSimplificationStats,
285    /// Strength reduction statistics (accumulated)
286    pub strength_reduction: StrengthReductionStats,
287    /// Distributivity optimization statistics (accumulated)
288    pub distributivity: DistributivityStats,
289    /// Quantifier optimization statistics (accumulated)
290    pub quantifier_opt: QuantifierOptStats,
291    /// Dead code elimination statistics (accumulated)
292    pub dead_code: DeadCodeStats,
293    /// Statistics per iteration
294    pub iterations: Vec<IterationStats>,
295    /// Whether the pipeline reached a fixed point
296    pub reached_fixed_point: bool,
297    /// Whether the pipeline was stopped due to max iterations
298    pub stopped_at_max_iterations: bool,
299}
300
301impl PipelineStats {
302    /// Get total number of optimizations applied across all iterations.
303    pub fn total_optimizations(&self) -> usize {
304        self.negation.double_negations_eliminated
305            + self.negation.demorgans_applied
306            + self.negation.quantifier_negations_pushed
307            + self.constant_folding.binary_ops_folded
308            + self.constant_folding.unary_ops_folded
309            + self.algebraic.identities_eliminated
310            + self.algebraic.annihilations_applied
311            + self.algebraic.idempotent_simplified
312            + self.strength_reduction.total_optimizations()
313            + self.distributivity.total_optimizations()
314            + self.quantifier_opt.total_optimizations()
315            + self.dead_code.total_optimizations()
316    }
317
318    /// Get the most productive iteration (one with most optimizations).
319    pub fn most_productive_iteration(&self) -> Option<(usize, &IterationStats)> {
320        self.iterations
321            .iter()
322            .enumerate()
323            .max_by_key(|(_, stats)| stats.total_optimizations())
324    }
325}
326
327impl std::fmt::Display for PipelineStats {
328    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
329        writeln!(f, "Pipeline Statistics:")?;
330        writeln!(f, "  Iterations: {}", self.total_iterations)?;
331        writeln!(f, "  Reached fixed point: {}", self.reached_fixed_point)?;
332        writeln!(f, "  Total optimizations: {}", self.total_optimizations())?;
333        writeln!(f, "\nNegation Optimization:")?;
334        writeln!(
335            f,
336            "  Double negations eliminated: {}",
337            self.negation.double_negations_eliminated
338        )?;
339        writeln!(
340            f,
341            "  De Morgan's laws applied: {}",
342            self.negation.demorgans_applied
343        )?;
344        writeln!(
345            f,
346            "  Quantifier negations pushed: {}",
347            self.negation.quantifier_negations_pushed
348        )?;
349        writeln!(f, "\nConstant Folding:")?;
350        writeln!(
351            f,
352            "  Binary ops folded: {}",
353            self.constant_folding.binary_ops_folded
354        )?;
355        writeln!(
356            f,
357            "  Unary ops folded: {}",
358            self.constant_folding.unary_ops_folded
359        )?;
360        writeln!(f, "\nAlgebraic Simplification:")?;
361        writeln!(
362            f,
363            "  Identities eliminated: {}",
364            self.algebraic.identities_eliminated
365        )?;
366        writeln!(
367            f,
368            "  Annihilations applied: {}",
369            self.algebraic.annihilations_applied
370        )?;
371        writeln!(
372            f,
373            "  Idempotent simplified: {}",
374            self.algebraic.idempotent_simplified
375        )?;
376        writeln!(f, "\nStrength Reduction:")?;
377        writeln!(
378            f,
379            "  Power reductions: {}",
380            self.strength_reduction.power_reductions
381        )?;
382        writeln!(
383            f,
384            "  Operations eliminated: {}",
385            self.strength_reduction.operations_eliminated
386        )?;
387        writeln!(
388            f,
389            "  Special function optimizations: {}",
390            self.strength_reduction.special_function_optimizations
391        )?;
392        writeln!(f, "\nDistributivity:")?;
393        writeln!(
394            f,
395            "  Expressions factored: {}",
396            self.distributivity.expressions_factored
397        )?;
398        writeln!(
399            f,
400            "  Expressions expanded: {}",
401            self.distributivity.expressions_expanded
402        )?;
403        writeln!(f, "\nQuantifier Optimization:")?;
404        writeln!(
405            f,
406            "  Invariants hoisted: {}",
407            self.quantifier_opt.invariants_hoisted
408        )?;
409        writeln!(
410            f,
411            "  Quantifiers reordered: {}",
412            self.quantifier_opt.quantifiers_reordered
413        )?;
414        writeln!(f, "\nDead Code Elimination:")?;
415        writeln!(
416            f,
417            "  Branches eliminated: {}",
418            self.dead_code.branches_eliminated
419        )?;
420        writeln!(f, "  Short circuits: {}", self.dead_code.short_circuits)?;
421        writeln!(
422            f,
423            "  Unused quantifiers removed: {}",
424            self.dead_code.unused_quantifiers_removed
425        )?;
426        Ok(())
427    }
428}
429
430/// Multi-pass optimization pipeline for TLExpr expressions.
431///
432/// The pipeline applies 7 optimization passes in sequence, iterating
433/// until a fixed point is reached or the maximum number of iterations is hit.
434///
435/// # Pass Order
436///
437/// 1. **Negation optimization**: Applies De Morgan's laws and eliminates
438///    double negations. This exposes more opportunities for subsequent passes.
439///
440/// 2. **Constant folding**: Evaluates constant expressions at compile time.
441///    This creates simpler expressions with fewer operations.
442///
443/// 3. **Algebraic simplification**: Applies mathematical identities like
444///    x + 0 = x, x * 1 = x, etc. This can create new constants for folding.
445///
446/// 4. **Strength reduction**: Replaces expensive operations with cheaper
447///    equivalents (e.g., x^2 → x*x, exp(log(x)) → x).
448///
449/// 5. **Distributivity**: Factors common subexpressions to reduce redundant
450///    computation (e.g., a*b + a*c → a*(b+c)).
451///
452/// 6. **Quantifier optimization**: Hoists loop-invariant expressions out of
453///    quantifiers (e.g., ∃x.(a + p(x)) → a + ∃x.p(x)).
454///
455/// 7. **Dead code elimination**: Removes unreachable code and eliminates
456///    branches with constant conditions (e.g., if true then A else B → A).
457///
458/// # Fixed Point Detection
459///
460/// The pipeline tracks whether each pass makes changes. If an entire iteration
461/// produces no changes (i.e., the expression is unchanged), a fixed point has
462/// been reached and optimization stops early.
463///
464/// # Examples
465///
466/// ```
467/// use tensorlogic_compiler::optimize::{OptimizationPipeline, PipelineConfig};
468/// use tensorlogic_ir::{TLExpr, Term};
469///
470/// // Default pipeline
471/// let pipeline = OptimizationPipeline::new();
472///
473/// // Custom configuration
474/// let config = PipelineConfig::default()
475///     .with_max_iterations(5)
476///     .with_constant_folding(true);
477/// let pipeline = OptimizationPipeline::with_config(config);
478///
479/// // Optimize an expression
480/// let expr = TLExpr::add(
481///     TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0)),
482///     TLExpr::Constant(0.0)
483/// );
484/// let (optimized, stats) = pipeline.optimize(&expr);
485/// ```
486pub struct OptimizationPipeline {
487    config: PipelineConfig,
488}
489
490impl OptimizationPipeline {
491    /// Create a new optimization pipeline with default configuration.
492    pub fn new() -> Self {
493        Self {
494            config: PipelineConfig::default(),
495        }
496    }
497
498    /// Create a new optimization pipeline with custom configuration.
499    pub fn with_config(config: PipelineConfig) -> Self {
500        Self { config }
501    }
502
503    /// Optimize an expression using the configured pipeline.
504    ///
505    /// Returns the optimized expression and statistics about the optimizations applied.
506    pub fn optimize(&self, expr: &TLExpr) -> (TLExpr, PipelineStats) {
507        let mut current = expr.clone();
508        let mut stats = PipelineStats::default();
509
510        for iteration in 0..self.config.max_iterations {
511            let mut iter_stats = IterationStats::default();
512            let mut changed = false;
513
514            // Pass 1: Negation optimization
515            if self.config.enable_negation_opt {
516                let (optimized, neg_stats) = optimize_negations(&current);
517                iter_stats.negation = neg_stats;
518
519                if optimized != current {
520                    current = optimized;
521                    changed = true;
522                }
523            }
524
525            // Pass 2: Constant folding
526            if self.config.enable_constant_folding {
527                let (optimized, fold_stats) = fold_constants(&current);
528                iter_stats.constant_folding = fold_stats;
529
530                if optimized != current {
531                    current = optimized;
532                    changed = true;
533                }
534            }
535
536            // Pass 3: Algebraic simplification
537            if self.config.enable_algebraic_simplification {
538                let (optimized, alg_stats) = simplify_algebraic(&current);
539                iter_stats.algebraic = alg_stats;
540
541                if optimized != current {
542                    current = optimized;
543                    changed = true;
544                }
545            }
546
547            // Pass 4: Strength reduction
548            if self.config.enable_strength_reduction {
549                let (optimized, sr_stats) = reduce_strength(&current);
550                iter_stats.strength_reduction = sr_stats;
551
552                if optimized != current {
553                    current = optimized;
554                    changed = true;
555                }
556            }
557
558            // Pass 5: Distributivity optimization
559            if self.config.enable_distributivity {
560                let (optimized, dist_stats) = optimize_distributivity(&current);
561                iter_stats.distributivity = dist_stats;
562
563                if optimized != current {
564                    current = optimized;
565                    changed = true;
566                }
567            }
568
569            // Pass 6: Quantifier optimization
570            if self.config.enable_quantifier_opt {
571                let (optimized, quant_stats) = optimize_quantifiers(&current);
572                iter_stats.quantifier_opt = quant_stats;
573
574                if optimized != current {
575                    current = optimized;
576                    changed = true;
577                }
578            }
579
580            // Pass 7: Dead code elimination
581            if self.config.enable_dead_code_elimination {
582                let (optimized, dead_stats) = eliminate_dead_code(&current);
583                iter_stats.dead_code = dead_stats;
584
585                if optimized != current {
586                    current = optimized;
587                    changed = true;
588                }
589            }
590
591            // Accumulate statistics
592            stats.total_iterations = iteration + 1;
593            stats.negation.double_negations_eliminated +=
594                iter_stats.negation.double_negations_eliminated;
595            stats.negation.demorgans_applied += iter_stats.negation.demorgans_applied;
596            stats.negation.quantifier_negations_pushed +=
597                iter_stats.negation.quantifier_negations_pushed;
598            stats.constant_folding.binary_ops_folded +=
599                iter_stats.constant_folding.binary_ops_folded;
600            stats.constant_folding.unary_ops_folded += iter_stats.constant_folding.unary_ops_folded;
601            stats.constant_folding.total_processed += iter_stats.constant_folding.total_processed;
602            stats.algebraic.identities_eliminated += iter_stats.algebraic.identities_eliminated;
603            stats.algebraic.annihilations_applied += iter_stats.algebraic.annihilations_applied;
604            stats.algebraic.idempotent_simplified += iter_stats.algebraic.idempotent_simplified;
605            stats.algebraic.total_processed += iter_stats.algebraic.total_processed;
606            stats.strength_reduction.power_reductions +=
607                iter_stats.strength_reduction.power_reductions;
608            stats.strength_reduction.operations_eliminated +=
609                iter_stats.strength_reduction.operations_eliminated;
610            stats.strength_reduction.special_function_optimizations +=
611                iter_stats.strength_reduction.special_function_optimizations;
612            stats.strength_reduction.total_processed +=
613                iter_stats.strength_reduction.total_processed;
614            stats.distributivity.expressions_factored +=
615                iter_stats.distributivity.expressions_factored;
616            stats.distributivity.expressions_expanded +=
617                iter_stats.distributivity.expressions_expanded;
618            stats.distributivity.common_terms_extracted +=
619                iter_stats.distributivity.common_terms_extracted;
620            stats.distributivity.total_processed += iter_stats.distributivity.total_processed;
621            stats.quantifier_opt.invariants_hoisted += iter_stats.quantifier_opt.invariants_hoisted;
622            stats.quantifier_opt.quantifiers_reordered +=
623                iter_stats.quantifier_opt.quantifiers_reordered;
624            stats.quantifier_opt.quantifiers_fused += iter_stats.quantifier_opt.quantifiers_fused;
625            stats.quantifier_opt.total_processed += iter_stats.quantifier_opt.total_processed;
626            stats.dead_code.branches_eliminated += iter_stats.dead_code.branches_eliminated;
627            stats.dead_code.short_circuits += iter_stats.dead_code.short_circuits;
628            stats.dead_code.unused_quantifiers_removed +=
629                iter_stats.dead_code.unused_quantifiers_removed;
630            stats.dead_code.identity_simplifications +=
631                iter_stats.dead_code.identity_simplifications;
632            stats.dead_code.total_processed += iter_stats.dead_code.total_processed;
633            stats.iterations.push(iter_stats);
634
635            // Check for fixed point
636            if self.config.stop_on_fixed_point && !changed {
637                stats.reached_fixed_point = true;
638                break;
639            }
640
641            // Check if we've hit max iterations
642            if iteration + 1 >= self.config.max_iterations {
643                stats.stopped_at_max_iterations = true;
644            }
645        }
646
647        (current, stats)
648    }
649
650    /// Get the current configuration.
651    pub fn config(&self) -> &PipelineConfig {
652        &self.config
653    }
654}
655
656impl Default for OptimizationPipeline {
657    fn default() -> Self {
658        Self::new()
659    }
660}
661
662#[cfg(test)]
663mod tests {
664    use super::*;
665    use tensorlogic_ir::Term;
666
667    #[test]
668    fn test_pipeline_with_all_passes() {
669        // Expression: NOT(AND(x + 0, 2.0 * 3.0))
670        let x = TLExpr::pred("x", vec![Term::var("i")]);
671        let expr = TLExpr::negate(TLExpr::and(
672            TLExpr::add(x, TLExpr::Constant(0.0)),
673            TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0)),
674        ));
675
676        let pipeline = OptimizationPipeline::new();
677        let (optimized, stats) = pipeline.optimize(&expr);
678
679        // Should apply multiple optimizations
680        assert!(stats.total_iterations > 0);
681        assert!(stats.constant_folding.binary_ops_folded > 0);
682        assert!(stats.algebraic.identities_eliminated > 0);
683        assert!(stats.negation.demorgans_applied > 0);
684
685        // Should not be the same as original
686        assert!(optimized != expr);
687    }
688
689    #[test]
690    fn test_constant_folding_only() {
691        let expr = TLExpr::add(
692            TLExpr::Constant(2.0),
693            TLExpr::mul(TLExpr::Constant(3.0), TLExpr::Constant(4.0)),
694        );
695
696        let config = PipelineConfig::constant_folding_only();
697        let pipeline = OptimizationPipeline::with_config(config);
698        let (optimized, stats) = pipeline.optimize(&expr);
699
700        // Should fold to 2.0 + 12.0 = 14.0
701        assert!(matches!(optimized, TLExpr::Constant(_)));
702        assert_eq!(stats.constant_folding.binary_ops_folded, 2);
703        assert_eq!(stats.algebraic.identities_eliminated, 0);
704        assert_eq!(stats.negation.demorgans_applied, 0);
705    }
706
707    #[test]
708    fn test_algebraic_only() {
709        let x = TLExpr::pred("x", vec![Term::var("i")]);
710        let expr = TLExpr::mul(TLExpr::add(x, TLExpr::Constant(0.0)), TLExpr::Constant(1.0));
711
712        let config = PipelineConfig::algebraic_only();
713        let pipeline = OptimizationPipeline::with_config(config);
714        let (_optimized, stats) = pipeline.optimize(&expr);
715
716        // Should eliminate both identities: x + 0 = x, x * 1 = x
717        assert_eq!(stats.algebraic.identities_eliminated, 2);
718        assert_eq!(stats.constant_folding.binary_ops_folded, 0);
719    }
720
721    #[test]
722    fn test_fixed_point_detection() {
723        // Expression that's already optimal
724        let x = TLExpr::pred("x", vec![Term::var("i")]);
725
726        let config = PipelineConfig::default().with_max_iterations(10);
727        let pipeline = OptimizationPipeline::with_config(config);
728        let (optimized, stats) = pipeline.optimize(&x);
729
730        // Should stop after first iteration (no changes)
731        assert_eq!(stats.total_iterations, 1);
732        assert!(stats.reached_fixed_point);
733        assert!(!stats.stopped_at_max_iterations);
734        assert_eq!(optimized, x);
735    }
736
737    #[test]
738    fn test_max_iterations_limit() {
739        // Create an expression that could benefit from more iterations
740        let x = TLExpr::pred("x", vec![Term::var("i")]);
741        let expr = TLExpr::negate(TLExpr::negate(TLExpr::add(x, TLExpr::Constant(0.0))));
742
743        let config = PipelineConfig::default().with_max_iterations(1);
744        let pipeline = OptimizationPipeline::with_config(config);
745        let (_, stats) = pipeline.optimize(&expr);
746
747        assert_eq!(stats.total_iterations, 1);
748        assert!(stats.stopped_at_max_iterations);
749    }
750
751    #[test]
752    fn test_aggressive_optimization() {
753        // Complex nested expression that requires multiple optimization passes
754        let x = TLExpr::pred("x", vec![Term::var("i")]);
755        // Expression: NOT(AND(NOT(x + 0), NOT((2.0 * 3.0) * x))) + (1.0 * 1.0)
756        let expr = TLExpr::add(
757            TLExpr::negate(TLExpr::and(
758                TLExpr::negate(TLExpr::add(x.clone(), TLExpr::Constant(0.0))),
759                TLExpr::negate(TLExpr::mul(
760                    TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0)),
761                    x,
762                )),
763            )),
764            TLExpr::mul(TLExpr::Constant(1.0), TLExpr::Constant(1.0)),
765        );
766
767        let config = PipelineConfig::aggressive();
768        let pipeline = OptimizationPipeline::with_config(config);
769        let (_, stats) = pipeline.optimize(&expr);
770
771        // Should apply multiple optimizations (negation, folding, algebraic)
772        // At least: De Morgan's law, double negations, constant folding, identity elimination
773        assert!(
774            stats.total_optimizations() >= 4,
775            "Expected at least 4 optimizations, got {}",
776            stats.total_optimizations()
777        );
778        assert!(stats.total_iterations >= 1);
779    }
780
781    #[test]
782    fn test_no_optimization() {
783        let x = TLExpr::pred("x", vec![Term::var("i")]);
784        let expr = TLExpr::add(x.clone(), TLExpr::Constant(1.0));
785
786        let config = PipelineConfig::none();
787        let pipeline = OptimizationPipeline::with_config(config);
788        let (optimized, stats) = pipeline.optimize(&expr);
789
790        // Should make no changes
791        assert_eq!(optimized, expr);
792        assert_eq!(stats.total_optimizations(), 0);
793    }
794
795    #[test]
796    fn test_iteration_stats() {
797        let expr = TLExpr::add(
798            TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0)),
799            TLExpr::Constant(0.0),
800        );
801
802        let pipeline = OptimizationPipeline::new();
803        let (_, stats) = pipeline.optimize(&expr);
804
805        // Check per-iteration statistics
806        assert!(!stats.iterations.is_empty());
807        assert!(stats.iterations[0].made_changes());
808        assert!(stats.iterations[0].total_optimizations() > 0);
809    }
810
811    #[test]
812    fn test_most_productive_iteration() {
813        let x = TLExpr::pred("x", vec![Term::var("i")]);
814        let expr = TLExpr::negate(TLExpr::negate(TLExpr::add(
815            TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0)),
816            TLExpr::mul(x, TLExpr::Constant(1.0)),
817        )));
818
819        let pipeline = OptimizationPipeline::new();
820        let (_, stats) = pipeline.optimize(&expr);
821
822        // Should identify the most productive iteration
823        let (iter_idx, iter_stats) = stats.most_productive_iteration().unwrap();
824        assert!(iter_stats.total_optimizations() > 0);
825        assert!(iter_idx < stats.total_iterations);
826    }
827
828    #[test]
829    fn test_pipeline_display() {
830        let expr = TLExpr::add(TLExpr::Constant(2.0), TLExpr::Constant(3.0));
831        let pipeline = OptimizationPipeline::new();
832        let (_, stats) = pipeline.optimize(&expr);
833
834        // Test Display implementation
835        let output = format!("{}", stats);
836        assert!(output.contains("Pipeline Statistics"));
837        assert!(output.contains("Iterations:"));
838        assert!(output.contains("Total optimizations:"));
839    }
840
841    #[test]
842    fn test_builder_pattern() {
843        let config = PipelineConfig::default()
844            .with_negation_opt(false)
845            .with_constant_folding(true)
846            .with_algebraic_simplification(false)
847            .with_max_iterations(5)
848            .with_stop_on_fixed_point(false);
849
850        assert!(!config.enable_negation_opt);
851        assert!(config.enable_constant_folding);
852        assert!(!config.enable_algebraic_simplification);
853        assert_eq!(config.max_iterations, 5);
854        assert!(!config.stop_on_fixed_point);
855    }
856
857    #[test]
858    fn test_complex_real_world_expression() {
859        // Softmax-like expression: exp((x - max) / 1.0) where temperature = 1.0
860        let x = TLExpr::pred("x", vec![Term::var("i")]);
861        let max_val = TLExpr::pred("max", vec![]);
862        let temp = TLExpr::Constant(1.0);
863
864        let expr = TLExpr::exp(TLExpr::div(TLExpr::sub(x, max_val), temp));
865
866        let pipeline = OptimizationPipeline::new();
867        let (optimized, stats) = pipeline.optimize(&expr);
868
869        // Should eliminate division by 1.0
870        assert!(stats.algebraic.identities_eliminated > 0);
871        assert!(optimized != expr);
872    }
873
874    #[test]
875    fn test_dead_code_elimination_integration() {
876        // Expression with dead branches: if true then A else B → A
877        let a = TLExpr::pred("a", vec![Term::var("i")]);
878        let b = TLExpr::pred("b", vec![Term::var("i")]);
879        let expr = TLExpr::IfThenElse {
880            condition: Box::new(TLExpr::Constant(1.0)), // Always true
881            then_branch: Box::new(a.clone()),
882            else_branch: Box::new(b),
883        };
884
885        let pipeline = OptimizationPipeline::new();
886        let (optimized, stats) = pipeline.optimize(&expr);
887
888        // Should eliminate the dead else branch
889        assert!(stats.dead_code.branches_eliminated > 0);
890        // Should be simplified to just 'a'
891        assert!(matches!(optimized, TLExpr::Pred { .. }));
892    }
893
894    #[test]
895    fn test_all_passes_together() {
896        // Complex expression benefiting from all passes
897        // if true then (NOT(NOT(x^2 + 0)) AND (a*b + a*c)) else FALSE
898        let x = TLExpr::pred("x", vec![Term::var("i")]);
899        let a = TLExpr::pred("a", vec![Term::var("i")]);
900        let b = TLExpr::pred("b", vec![Term::var("i")]);
901        let c = TLExpr::pred("c", vec![Term::var("i")]);
902
903        let expr = TLExpr::IfThenElse {
904            condition: Box::new(TLExpr::Constant(1.0)),
905            then_branch: Box::new(TLExpr::and(
906                TLExpr::negate(TLExpr::negate(TLExpr::add(
907                    TLExpr::pow(x, TLExpr::Constant(2.0)),
908                    TLExpr::Constant(0.0),
909                ))),
910                TLExpr::add(
911                    TLExpr::mul(a.clone(), b.clone()),
912                    TLExpr::mul(a.clone(), c.clone()),
913                ),
914            )),
915            else_branch: Box::new(TLExpr::Constant(0.0)),
916        };
917
918        let pipeline = OptimizationPipeline::new();
919        let (_, stats) = pipeline.optimize(&expr);
920
921        // Should apply multiple passes:
922        // - Dead code elimination (remove else branch)
923        // - Negation optimization (double negation)
924        // - Algebraic simplification (x + 0 → x)
925        // - Strength reduction (x^2 → x*x)
926        // - Distributivity (a*b + a*c → a*(b+c))
927        assert!(
928            stats.dead_code.branches_eliminated > 0,
929            "Dead code elimination should apply"
930        );
931        assert!(
932            stats.negation.double_negations_eliminated > 0,
933            "Negation optimization should apply"
934        );
935        assert!(
936            stats.algebraic.identities_eliminated > 0,
937            "Algebraic simplification should apply"
938        );
939        assert!(
940            stats.strength_reduction.power_reductions > 0,
941            "Strength reduction should apply"
942        );
943        assert!(
944            stats.distributivity.expressions_factored > 0,
945            "Distributivity should apply"
946        );
947
948        // Total should be significant
949        assert!(
950            stats.total_optimizations() >= 5,
951            "Should apply at least 5 optimizations"
952        );
953    }
954}