tensorlogic_compiler/optimize/
pipeline.rs

1//! Multi-pass optimization pipeline for TLExpr expressions.
2//!
3//! This module provides a unified optimization pipeline that combines multiple
4//! optimization passes and applies them iteratively until a fixed point is reached.
5//!
6//! # Architecture
7//!
8//! The pipeline applies optimizations in this order:
9//! 1. **Negation optimization**: Push negations inward using De Morgan's laws
10//! 2. **Constant folding**: Evaluate constant expressions at compile time
11//! 3. **Algebraic simplification**: Apply mathematical identities
12//!
13//! This order is chosen because:
14//! - Negation optimization can expose more opportunities for other passes
15//! - Constant folding creates simpler expressions for algebraic simplification
16//! - Algebraic simplification can create new constants for the next iteration
17//!
18//! # Examples
19//!
20//! ```
21//! use tensorlogic_compiler::optimize::{OptimizationPipeline, PipelineConfig};
22//! use tensorlogic_ir::{TLExpr, Term};
23//!
24//! // Create a pipeline with default configuration
25//! let pipeline = OptimizationPipeline::new();
26//!
27//! // Optimize an expression: NOT(AND(x + 0, 2.0 * 3.0))
28//! let x = TLExpr::pred("x", vec![Term::var("i")]);
29//! let expr = TLExpr::negate(TLExpr::and(
30//!     TLExpr::add(x, TLExpr::Constant(0.0)),
31//!     TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0)),
32//! ));
33//!
34//! let (optimized, stats) = pipeline.optimize(&expr);
35//!
36//! // Pipeline applies multiple passes and reports statistics
37//! assert!(stats.total_iterations > 0);
38//! assert!(stats.constant_folding.binary_ops_folded > 0);
39//! assert!(stats.algebraic.identities_eliminated > 0);
40//! ```
41
42use super::{
43    algebraic::{simplify_algebraic, AlgebraicSimplificationStats},
44    constant_folding::{fold_constants, ConstantFoldingStats},
45    negation::{optimize_negations, NegationOptStats},
46};
47use tensorlogic_ir::TLExpr;
48
/// Configuration for the optimization pipeline.
///
/// The three `enable_*` flags switch individual passes on or off, while
/// `max_iterations` and `stop_on_fixed_point` control how often the enabled
/// passes are repeated. Two configurations compare equal when all five
/// settings match.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PipelineConfig {
    /// Enable negation optimization pass
    pub enable_negation_opt: bool,
    /// Enable constant folding pass
    pub enable_constant_folding: bool,
    /// Enable algebraic simplification pass
    pub enable_algebraic_simplification: bool,
    /// Maximum number of iterations before stopping
    pub max_iterations: usize,
    /// Stop early if an iteration makes no changes
    pub stop_on_fixed_point: bool,
}
63
64impl Default for PipelineConfig {
65    fn default() -> Self {
66        Self {
67            enable_negation_opt: true,
68            enable_constant_folding: true,
69            enable_algebraic_simplification: true,
70            max_iterations: 10,
71            stop_on_fixed_point: true,
72        }
73    }
74}
75
76impl PipelineConfig {
77    /// Create a configuration with all optimizations enabled.
78    pub fn all() -> Self {
79        Self::default()
80    }
81
82    /// Create a configuration with all optimizations disabled.
83    pub fn none() -> Self {
84        Self {
85            enable_negation_opt: false,
86            enable_constant_folding: false,
87            enable_algebraic_simplification: false,
88            max_iterations: 1,
89            stop_on_fixed_point: true,
90        }
91    }
92
93    /// Create a configuration with only constant folding enabled.
94    pub fn constant_folding_only() -> Self {
95        Self {
96            enable_negation_opt: false,
97            enable_constant_folding: true,
98            enable_algebraic_simplification: false,
99            max_iterations: 1,
100            stop_on_fixed_point: true,
101        }
102    }
103
104    /// Create a configuration with only algebraic simplification enabled.
105    pub fn algebraic_only() -> Self {
106        Self {
107            enable_negation_opt: false,
108            enable_constant_folding: false,
109            enable_algebraic_simplification: true,
110            max_iterations: 1,
111            stop_on_fixed_point: true,
112        }
113    }
114
115    /// Create a configuration for aggressive optimization (more iterations).
116    pub fn aggressive() -> Self {
117        Self {
118            enable_negation_opt: true,
119            enable_constant_folding: true,
120            enable_algebraic_simplification: true,
121            max_iterations: 20,
122            stop_on_fixed_point: true,
123        }
124    }
125
126    /// Builder method to enable/disable negation optimization.
127    pub fn with_negation_opt(mut self, enable: bool) -> Self {
128        self.enable_negation_opt = enable;
129        self
130    }
131
132    /// Builder method to enable/disable constant folding.
133    pub fn with_constant_folding(mut self, enable: bool) -> Self {
134        self.enable_constant_folding = enable;
135        self
136    }
137
138    /// Builder method to enable/disable algebraic simplification.
139    pub fn with_algebraic_simplification(mut self, enable: bool) -> Self {
140        self.enable_algebraic_simplification = enable;
141        self
142    }
143
144    /// Builder method to set maximum iterations.
145    pub fn with_max_iterations(mut self, max: usize) -> Self {
146        self.max_iterations = max;
147        self
148    }
149
150    /// Builder method to enable/disable fixed-point detection.
151    pub fn with_stop_on_fixed_point(mut self, stop: bool) -> Self {
152        self.stop_on_fixed_point = stop;
153        self
154    }
155}
156
157/// Statistics from a single pipeline iteration.
158#[derive(Debug, Clone, Default)]
159pub struct IterationStats {
160    /// Negation optimization statistics
161    pub negation: NegationOptStats,
162    /// Constant folding statistics
163    pub constant_folding: ConstantFoldingStats,
164    /// Algebraic simplification statistics
165    pub algebraic: AlgebraicSimplificationStats,
166}
167
168impl IterationStats {
169    /// Check if this iteration made any changes.
170    pub fn made_changes(&self) -> bool {
171        self.negation.double_negations_eliminated > 0
172            || self.negation.demorgans_applied > 0
173            || self.negation.quantifier_negations_pushed > 0
174            || self.constant_folding.binary_ops_folded > 0
175            || self.constant_folding.unary_ops_folded > 0
176            || self.algebraic.identities_eliminated > 0
177            || self.algebraic.annihilations_applied > 0
178            || self.algebraic.idempotent_simplified > 0
179    }
180
181    /// Get total number of optimizations applied in this iteration.
182    pub fn total_optimizations(&self) -> usize {
183        self.negation.double_negations_eliminated
184            + self.negation.demorgans_applied
185            + self.negation.quantifier_negations_pushed
186            + self.constant_folding.binary_ops_folded
187            + self.constant_folding.unary_ops_folded
188            + self.algebraic.identities_eliminated
189            + self.algebraic.annihilations_applied
190            + self.algebraic.idempotent_simplified
191    }
192}
193
194/// Cumulative statistics from all pipeline iterations.
195#[derive(Debug, Clone, Default)]
196pub struct PipelineStats {
197    /// Total number of iterations performed
198    pub total_iterations: usize,
199    /// Negation optimization statistics (accumulated)
200    pub negation: NegationOptStats,
201    /// Constant folding statistics (accumulated)
202    pub constant_folding: ConstantFoldingStats,
203    /// Algebraic simplification statistics (accumulated)
204    pub algebraic: AlgebraicSimplificationStats,
205    /// Statistics per iteration
206    pub iterations: Vec<IterationStats>,
207    /// Whether the pipeline reached a fixed point
208    pub reached_fixed_point: bool,
209    /// Whether the pipeline was stopped due to max iterations
210    pub stopped_at_max_iterations: bool,
211}
212
213impl PipelineStats {
214    /// Get total number of optimizations applied across all iterations.
215    pub fn total_optimizations(&self) -> usize {
216        self.negation.double_negations_eliminated
217            + self.negation.demorgans_applied
218            + self.negation.quantifier_negations_pushed
219            + self.constant_folding.binary_ops_folded
220            + self.constant_folding.unary_ops_folded
221            + self.algebraic.identities_eliminated
222            + self.algebraic.annihilations_applied
223            + self.algebraic.idempotent_simplified
224    }
225
226    /// Get the most productive iteration (one with most optimizations).
227    pub fn most_productive_iteration(&self) -> Option<(usize, &IterationStats)> {
228        self.iterations
229            .iter()
230            .enumerate()
231            .max_by_key(|(_, stats)| stats.total_optimizations())
232    }
233}
234
235impl std::fmt::Display for PipelineStats {
236    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
237        writeln!(f, "Pipeline Statistics:")?;
238        writeln!(f, "  Iterations: {}", self.total_iterations)?;
239        writeln!(f, "  Reached fixed point: {}", self.reached_fixed_point)?;
240        writeln!(f, "  Total optimizations: {}", self.total_optimizations())?;
241        writeln!(f, "\nNegation Optimization:")?;
242        writeln!(
243            f,
244            "  Double negations eliminated: {}",
245            self.negation.double_negations_eliminated
246        )?;
247        writeln!(
248            f,
249            "  De Morgan's laws applied: {}",
250            self.negation.demorgans_applied
251        )?;
252        writeln!(
253            f,
254            "  Quantifier negations pushed: {}",
255            self.negation.quantifier_negations_pushed
256        )?;
257        writeln!(f, "\nConstant Folding:")?;
258        writeln!(
259            f,
260            "  Binary ops folded: {}",
261            self.constant_folding.binary_ops_folded
262        )?;
263        writeln!(
264            f,
265            "  Unary ops folded: {}",
266            self.constant_folding.unary_ops_folded
267        )?;
268        writeln!(f, "\nAlgebraic Simplification:")?;
269        writeln!(
270            f,
271            "  Identities eliminated: {}",
272            self.algebraic.identities_eliminated
273        )?;
274        writeln!(
275            f,
276            "  Annihilations applied: {}",
277            self.algebraic.annihilations_applied
278        )?;
279        writeln!(
280            f,
281            "  Idempotent simplified: {}",
282            self.algebraic.idempotent_simplified
283        )?;
284        Ok(())
285    }
286}
287
288/// Multi-pass optimization pipeline for TLExpr expressions.
289///
290/// The pipeline applies multiple optimization passes in sequence, iterating
291/// until a fixed point is reached or the maximum number of iterations is hit.
292///
293/// # Pass Order
294///
295/// 1. **Negation optimization**: Applies De Morgan's laws and eliminates
296///    double negations. This exposes more opportunities for subsequent passes.
297///
298/// 2. **Constant folding**: Evaluates constant expressions at compile time.
299///    This creates simpler expressions with fewer operations.
300///
301/// 3. **Algebraic simplification**: Applies mathematical identities like
302///    x + 0 = x, x * 1 = x, etc. This can create new constants for folding.
303///
304/// # Fixed Point Detection
305///
306/// The pipeline tracks whether each pass makes changes. If an entire iteration
307/// produces no changes (i.e., the expression is unchanged), a fixed point has
308/// been reached and optimization stops early.
309///
310/// # Examples
311///
312/// ```
313/// use tensorlogic_compiler::optimize::{OptimizationPipeline, PipelineConfig};
314/// use tensorlogic_ir::{TLExpr, Term};
315///
316/// // Default pipeline
317/// let pipeline = OptimizationPipeline::new();
318///
319/// // Custom configuration
320/// let config = PipelineConfig::default()
321///     .with_max_iterations(5)
322///     .with_constant_folding(true);
323/// let pipeline = OptimizationPipeline::with_config(config);
324///
325/// // Optimize an expression
326/// let expr = TLExpr::add(
327///     TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0)),
328///     TLExpr::Constant(0.0)
329/// );
330/// let (optimized, stats) = pipeline.optimize(&expr);
331/// ```
332pub struct OptimizationPipeline {
333    config: PipelineConfig,
334}
335
336impl OptimizationPipeline {
337    /// Create a new optimization pipeline with default configuration.
338    pub fn new() -> Self {
339        Self {
340            config: PipelineConfig::default(),
341        }
342    }
343
344    /// Create a new optimization pipeline with custom configuration.
345    pub fn with_config(config: PipelineConfig) -> Self {
346        Self { config }
347    }
348
349    /// Optimize an expression using the configured pipeline.
350    ///
351    /// Returns the optimized expression and statistics about the optimizations applied.
352    pub fn optimize(&self, expr: &TLExpr) -> (TLExpr, PipelineStats) {
353        let mut current = expr.clone();
354        let mut stats = PipelineStats::default();
355
356        for iteration in 0..self.config.max_iterations {
357            let mut iter_stats = IterationStats::default();
358            let mut changed = false;
359
360            // Pass 1: Negation optimization
361            if self.config.enable_negation_opt {
362                let (optimized, neg_stats) = optimize_negations(&current);
363                iter_stats.negation = neg_stats;
364
365                if optimized != current {
366                    current = optimized;
367                    changed = true;
368                }
369            }
370
371            // Pass 2: Constant folding
372            if self.config.enable_constant_folding {
373                let (optimized, fold_stats) = fold_constants(&current);
374                iter_stats.constant_folding = fold_stats;
375
376                if optimized != current {
377                    current = optimized;
378                    changed = true;
379                }
380            }
381
382            // Pass 3: Algebraic simplification
383            if self.config.enable_algebraic_simplification {
384                let (optimized, alg_stats) = simplify_algebraic(&current);
385                iter_stats.algebraic = alg_stats;
386
387                if optimized != current {
388                    current = optimized;
389                    changed = true;
390                }
391            }
392
393            // Accumulate statistics
394            stats.total_iterations = iteration + 1;
395            stats.negation.double_negations_eliminated +=
396                iter_stats.negation.double_negations_eliminated;
397            stats.negation.demorgans_applied += iter_stats.negation.demorgans_applied;
398            stats.negation.quantifier_negations_pushed +=
399                iter_stats.negation.quantifier_negations_pushed;
400            stats.constant_folding.binary_ops_folded +=
401                iter_stats.constant_folding.binary_ops_folded;
402            stats.constant_folding.unary_ops_folded += iter_stats.constant_folding.unary_ops_folded;
403            stats.constant_folding.total_processed += iter_stats.constant_folding.total_processed;
404            stats.algebraic.identities_eliminated += iter_stats.algebraic.identities_eliminated;
405            stats.algebraic.annihilations_applied += iter_stats.algebraic.annihilations_applied;
406            stats.algebraic.idempotent_simplified += iter_stats.algebraic.idempotent_simplified;
407            stats.algebraic.total_processed += iter_stats.algebraic.total_processed;
408            stats.iterations.push(iter_stats);
409
410            // Check for fixed point
411            if self.config.stop_on_fixed_point && !changed {
412                stats.reached_fixed_point = true;
413                break;
414            }
415
416            // Check if we've hit max iterations
417            if iteration + 1 >= self.config.max_iterations {
418                stats.stopped_at_max_iterations = true;
419            }
420        }
421
422        (current, stats)
423    }
424
425    /// Get the current configuration.
426    pub fn config(&self) -> &PipelineConfig {
427        &self.config
428    }
429}
430
431impl Default for OptimizationPipeline {
432    fn default() -> Self {
433        Self::new()
434    }
435}
436
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    /// The predicate `x(i)`, used as the non-constant operand in most tests.
    fn pred_x() -> TLExpr {
        TLExpr::pred("x", vec![Term::var("i")])
    }

    /// Run `expr` through a pipeline built from `config`.
    fn run(config: PipelineConfig, expr: &TLExpr) -> (TLExpr, PipelineStats) {
        OptimizationPipeline::with_config(config).optimize(expr)
    }

    #[test]
    fn test_pipeline_with_all_passes() {
        // NOT(AND(x + 0, 2.0 * 3.0)) gives every pass something to do.
        let sum = TLExpr::add(pred_x(), TLExpr::Constant(0.0));
        let product = TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0));
        let expr = TLExpr::negate(TLExpr::and(sum, product));

        let (optimized, stats) = OptimizationPipeline::new().optimize(&expr);

        // Each pass must have contributed.
        assert!(stats.total_iterations > 0);
        assert!(stats.constant_folding.binary_ops_folded > 0);
        assert!(stats.algebraic.identities_eliminated > 0);
        assert!(stats.negation.demorgans_applied > 0);

        // The result must differ from the input.
        assert_ne!(optimized, expr);
    }

    #[test]
    fn test_constant_folding_only() {
        // 2.0 + (3.0 * 4.0)
        let product = TLExpr::mul(TLExpr::Constant(3.0), TLExpr::Constant(4.0));
        let expr = TLExpr::add(TLExpr::Constant(2.0), product);

        let (optimized, stats) = run(PipelineConfig::constant_folding_only(), &expr);

        // Should fold to 2.0 + 12.0 = 14.0
        assert!(matches!(optimized, TLExpr::Constant(_)));
        assert_eq!(stats.constant_folding.binary_ops_folded, 2);
        assert_eq!(stats.algebraic.identities_eliminated, 0);
        assert_eq!(stats.negation.demorgans_applied, 0);
    }

    #[test]
    fn test_algebraic_only() {
        // (x + 0) * 1
        let sum = TLExpr::add(pred_x(), TLExpr::Constant(0.0));
        let expr = TLExpr::mul(sum, TLExpr::Constant(1.0));

        let (_, stats) = run(PipelineConfig::algebraic_only(), &expr);

        // Should eliminate both identities: x + 0 = x, x * 1 = x
        assert_eq!(stats.algebraic.identities_eliminated, 2);
        assert_eq!(stats.constant_folding.binary_ops_folded, 0);
    }

    #[test]
    fn test_fixed_point_detection() {
        // A bare predicate is already optimal.
        let x = pred_x();

        let (optimized, stats) = run(PipelineConfig::default().with_max_iterations(10), &x);

        // The first iteration changes nothing, so the loop ends there.
        assert_eq!(stats.total_iterations, 1);
        assert!(stats.reached_fixed_point);
        assert!(!stats.stopped_at_max_iterations);
        assert_eq!(optimized, x);
    }

    #[test]
    fn test_max_iterations_limit() {
        // NOT(NOT(x + 0)) could benefit from more than one iteration.
        let sum = TLExpr::add(pred_x(), TLExpr::Constant(0.0));
        let expr = TLExpr::negate(TLExpr::negate(sum));

        let (_, stats) = run(PipelineConfig::default().with_max_iterations(1), &expr);

        assert_eq!(stats.total_iterations, 1);
        assert!(stats.stopped_at_max_iterations);
    }

    #[test]
    fn test_aggressive_optimization() {
        // NOT(AND(NOT(x + 0), NOT((2.0 * 3.0) * x))) + (1.0 * 1.0)
        let left = TLExpr::negate(TLExpr::add(pred_x(), TLExpr::Constant(0.0)));
        let scaled = TLExpr::mul(
            TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0)),
            pred_x(),
        );
        let right = TLExpr::negate(scaled);
        let expr = TLExpr::add(
            TLExpr::negate(TLExpr::and(left, right)),
            TLExpr::mul(TLExpr::Constant(1.0), TLExpr::Constant(1.0)),
        );

        let (_, stats) = run(PipelineConfig::aggressive(), &expr);

        // At least: De Morgan's law, double negations, constant folding, identity elimination
        let total = stats.total_optimizations();
        assert!(
            total >= 4,
            "Expected at least 4 optimizations, got {}",
            total
        );
        assert!(stats.total_iterations >= 1);
    }

    #[test]
    fn test_no_optimization() {
        let expr = TLExpr::add(pred_x(), TLExpr::Constant(1.0));

        let (optimized, stats) = run(PipelineConfig::none(), &expr);

        // With every pass disabled the expression must come back untouched.
        assert_eq!(optimized, expr);
        assert_eq!(stats.total_optimizations(), 0);
    }

    #[test]
    fn test_iteration_stats() {
        // (2.0 * 3.0) + 0.0
        let product = TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0));
        let expr = TLExpr::add(product, TLExpr::Constant(0.0));

        let (_, stats) = OptimizationPipeline::new().optimize(&expr);

        // The first iteration must be recorded and must have done work.
        assert!(!stats.iterations.is_empty());
        let first = &stats.iterations[0];
        assert!(first.made_changes());
        assert!(first.total_optimizations() > 0);
    }

    #[test]
    fn test_most_productive_iteration() {
        // NOT(NOT((2.0 * 3.0) + (x * 1.0)))
        let folded = TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0));
        let scaled = TLExpr::mul(pred_x(), TLExpr::Constant(1.0));
        let expr = TLExpr::negate(TLExpr::negate(TLExpr::add(folded, scaled)));

        let (_, stats) = OptimizationPipeline::new().optimize(&expr);

        // The reported iteration must exist and must have done work.
        let (iter_idx, iter_stats) = stats.most_productive_iteration().unwrap();
        assert!(iter_stats.total_optimizations() > 0);
        assert!(iter_idx < stats.total_iterations);
    }

    #[test]
    fn test_pipeline_display() {
        let expr = TLExpr::add(TLExpr::Constant(2.0), TLExpr::Constant(3.0));
        let (_, stats) = OptimizationPipeline::new().optimize(&expr);

        // Exercise the Display implementation.
        let output = stats.to_string();
        for needle in ["Pipeline Statistics", "Iterations:", "Total optimizations:"].iter() {
            assert!(output.contains(needle));
        }
    }

    #[test]
    fn test_builder_pattern() {
        let config = PipelineConfig::default()
            .with_negation_opt(false)
            .with_constant_folding(true)
            .with_algebraic_simplification(false)
            .with_max_iterations(5)
            .with_stop_on_fixed_point(false);

        assert!(!config.enable_negation_opt);
        assert!(config.enable_constant_folding);
        assert!(!config.enable_algebraic_simplification);
        assert_eq!(config.max_iterations, 5);
        assert!(!config.stop_on_fixed_point);
    }

    #[test]
    fn test_complex_real_world_expression() {
        // Softmax-like expression: exp((x - max) / 1.0) where temperature = 1.0
        let max_val = TLExpr::pred("max", vec![]);
        let centered = TLExpr::sub(pred_x(), max_val);
        let temperature = TLExpr::Constant(1.0);
        let expr = TLExpr::exp(TLExpr::div(centered, temperature));

        let (optimized, stats) = OptimizationPipeline::new().optimize(&expr);

        // Should eliminate division by 1.0
        assert!(stats.algebraic.identities_eliminated > 0);
        assert_ne!(optimized, expr);
    }
}