Skip to main content

tensorlogic_ir/expr/
optimization_pipeline.rs

1//! Optimization pipeline orchestration for expressions.
2//!
3//! This module provides a high-level optimization pipeline that automatically
4//! orders and applies multiple optimization passes to expressions, tracking
5//! metrics and ensuring convergence.
6//!
7//! # Architecture
8//!
9//! The pipeline consists of:
10//! - **Optimization passes**: Individual transformation functions
11//! - **Pass ordering**: Automatic determination of pass order
12//! - **Convergence detection**: Stopping when no more changes occur
13//! - **Metrics tracking**: Recording improvements and performance
14//!
15//! # Example
16//!
17//! ```rust
18//! use tensorlogic_ir::{TLExpr, Term, OptimizationPipeline, OptimizationLevel};
19//!
20//! let expr = TLExpr::and(
21//!     TLExpr::constant(1.0),
22//!     TLExpr::pred("P", vec![Term::var("x")])
23//! );
24//!
25//! let pipeline = OptimizationPipeline::default();
26//! let (optimized, metrics) = pipeline.optimize(expr);
27//! println!("Applied {} passes", metrics.passes_applied);
28//! ```
29
30use std::collections::HashMap;
31
32use super::{
33    distributive_laws::{apply_distributive_laws, DistributiveStrategy},
34    modal_equivalences::apply_modal_equivalences,
35    normal_forms::to_nnf,
36    optimization::{algebraic_simplify, constant_fold, propagate_constants},
37    temporal_equivalences::apply_temporal_equivalences,
38    TLExpr,
39};
40
/// Optimization level controlling aggressiveness of optimizations.
///
/// Levels are totally ordered (`None < Basic < Standard < Aggressive`) and the
/// default is [`OptimizationLevel::Standard`]. The exact passes enabled at each
/// level are returned by [`OptimizationPass::for_level`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub enum OptimizationLevel {
    /// No optimizations (O0).
    None,
    /// Basic optimizations (O1): constant folding, constant propagation and
    /// algebraic simplification.
    Basic,
    /// Standard optimizations (O2): the basic passes plus negation normal form
    /// and modal/temporal equivalences.
    #[default]
    Standard,
    /// Aggressive optimizations (O3): the standard passes plus distributive laws
    /// for quantifiers, modal operators and AND over OR.
    Aggressive,
}
54
/// A single optimization pass.
///
/// Passes can be run directly with [`OptimizationPass::apply`]; the pipeline
/// schedules them in ascending [`OptimizationPass::priority`] order.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum OptimizationPass {
    /// Constant folding
    ConstantFolding,
    /// Constant propagation
    ConstantPropagation,
    /// Algebraic simplification
    AlgebraicSimplification,
    /// Convert to negation normal form
    NegationNormalForm,
    /// Apply modal logic equivalences
    ModalEquivalences,
    /// Apply temporal logic equivalences
    TemporalEquivalences,
    /// Apply distributive laws (AND over OR)
    DistributiveAndOverOr,
    /// Apply distributive laws (OR over AND)
    DistributiveOrOverAnd,
    /// Apply distributive laws for quantifiers
    DistributiveQuantifiers,
    /// Apply distributive laws for modal operators
    DistributiveModal,
}
79
80impl OptimizationPass {
81    /// Get the name of this pass.
82    pub fn name(&self) -> &'static str {
83        match self {
84            OptimizationPass::ConstantFolding => "constant_folding",
85            OptimizationPass::ConstantPropagation => "constant_propagation",
86            OptimizationPass::AlgebraicSimplification => "algebraic_simplification",
87            OptimizationPass::NegationNormalForm => "negation_normal_form",
88            OptimizationPass::ModalEquivalences => "modal_equivalences",
89            OptimizationPass::TemporalEquivalences => "temporal_equivalences",
90            OptimizationPass::DistributiveAndOverOr => "distributive_and_over_or",
91            OptimizationPass::DistributiveOrOverAnd => "distributive_or_over_and",
92            OptimizationPass::DistributiveQuantifiers => "distributive_quantifiers",
93            OptimizationPass::DistributiveModal => "distributive_modal",
94        }
95    }
96
97    /// Apply this pass to an expression.
98    pub fn apply(&self, expr: TLExpr) -> TLExpr {
99        match self {
100            OptimizationPass::ConstantFolding => constant_fold(&expr),
101            OptimizationPass::ConstantPropagation => propagate_constants(&expr),
102            OptimizationPass::AlgebraicSimplification => algebraic_simplify(&expr),
103            OptimizationPass::NegationNormalForm => to_nnf(&expr),
104            OptimizationPass::ModalEquivalences => apply_modal_equivalences(&expr),
105            OptimizationPass::TemporalEquivalences => apply_temporal_equivalences(&expr),
106            OptimizationPass::DistributiveAndOverOr => {
107                apply_distributive_laws(&expr, DistributiveStrategy::AndOverOr)
108            }
109            OptimizationPass::DistributiveOrOverAnd => {
110                apply_distributive_laws(&expr, DistributiveStrategy::OrOverAnd)
111            }
112            OptimizationPass::DistributiveQuantifiers => {
113                apply_distributive_laws(&expr, DistributiveStrategy::Quantifiers)
114            }
115            OptimizationPass::DistributiveModal => {
116                apply_distributive_laws(&expr, DistributiveStrategy::Modal)
117            }
118        }
119    }
120
121    /// Get the priority of this pass (lower = earlier in pipeline).
122    pub fn priority(&self) -> u32 {
123        match self {
124            // Early passes: normalize and fold constants
125            OptimizationPass::ConstantFolding => 10,
126            OptimizationPass::ConstantPropagation => 20,
127            OptimizationPass::NegationNormalForm => 30,
128            // Middle passes: apply equivalences and simplifications
129            OptimizationPass::AlgebraicSimplification => 40,
130            OptimizationPass::ModalEquivalences => 50,
131            OptimizationPass::TemporalEquivalences => 60,
132            // Late passes: distributive laws (can expand expressions)
133            OptimizationPass::DistributiveQuantifiers => 70,
134            OptimizationPass::DistributiveModal => 80,
135            OptimizationPass::DistributiveAndOverOr => 90,
136            OptimizationPass::DistributiveOrOverAnd => 100,
137        }
138    }
139
140    /// Get all passes for a given optimization level.
141    pub fn for_level(level: OptimizationLevel) -> Vec<OptimizationPass> {
142        match level {
143            OptimizationLevel::None => vec![],
144            OptimizationLevel::Basic => vec![
145                OptimizationPass::ConstantFolding,
146                OptimizationPass::ConstantPropagation,
147                OptimizationPass::AlgebraicSimplification,
148            ],
149            OptimizationLevel::Standard => vec![
150                OptimizationPass::ConstantFolding,
151                OptimizationPass::ConstantPropagation,
152                OptimizationPass::NegationNormalForm,
153                OptimizationPass::AlgebraicSimplification,
154                OptimizationPass::ModalEquivalences,
155                OptimizationPass::TemporalEquivalences,
156            ],
157            OptimizationLevel::Aggressive => vec![
158                OptimizationPass::ConstantFolding,
159                OptimizationPass::ConstantPropagation,
160                OptimizationPass::NegationNormalForm,
161                OptimizationPass::AlgebraicSimplification,
162                OptimizationPass::ModalEquivalences,
163                OptimizationPass::TemporalEquivalences,
164                OptimizationPass::DistributiveQuantifiers,
165                OptimizationPass::DistributiveModal,
166                OptimizationPass::DistributiveAndOverOr,
167            ],
168        }
169    }
170}
171
/// Metrics collected during optimization.
///
/// Filled in by `OptimizationPipeline::optimize`. Pass counters are updated
/// through [`OptimizationMetrics::record_pass`], and the size fields through
/// [`OptimizationMetrics::finalize`].
#[derive(Clone, Debug, Default, PartialEq)]
pub struct OptimizationMetrics {
    /// Number of pass applications recorded via `record_pass`.
    ///
    /// The pipeline records an application only when the pass actually changed
    /// the expression, so passes that were a no-op are not counted.
    pub passes_applied: usize,
    /// Number of iterations executed; each iteration runs every scheduled pass once.
    pub iterations: usize,
    /// Whether a full iteration left the expression unchanged.
    ///
    /// This is only ever set when convergence detection is enabled in the config.
    pub converged: bool,
    /// Per-pass application counts, keyed by `OptimizationPass::name()`.
    pub pass_counts: HashMap<String, usize>,
    /// Initial expression size (node count)
    pub initial_size: usize,
    /// Final expression size (node count)
    pub final_size: usize,
    /// Size reduction ratio, `1 - final_size / initial_size`.
    ///
    /// It is `0.0` when `initial_size` is 0. It is negative when the expression
    /// grew, which can happen with the distributive passes.
    pub reduction_ratio: f64,
}
190
191impl OptimizationMetrics {
192    /// Create new empty metrics.
193    pub fn new() -> Self {
194        Self::default()
195    }
196
197    /// Record that a pass was applied.
198    pub fn record_pass(&mut self, pass: OptimizationPass) {
199        self.passes_applied += 1;
200        *self.pass_counts.entry(pass.name().to_string()).or_insert(0) += 1;
201    }
202
203    /// Compute final metrics.
204    pub fn finalize(&mut self, initial_size: usize, final_size: usize) {
205        self.initial_size = initial_size;
206        self.final_size = final_size;
207        self.reduction_ratio = if initial_size > 0 {
208            1.0 - (final_size as f64 / initial_size as f64)
209        } else {
210            0.0
211        };
212    }
213}
214
215/// Configuration for the optimization pipeline.
216#[derive(Clone, Debug)]
217pub struct PipelineConfig {
218    /// Optimization level
219    pub level: OptimizationLevel,
220    /// Maximum number of iterations
221    pub max_iterations: usize,
222    /// Custom pass ordering (if None, uses default ordering by priority)
223    pub custom_passes: Option<Vec<OptimizationPass>>,
224    /// Whether to enable convergence detection
225    pub enable_convergence: bool,
226}
227
228impl Default for PipelineConfig {
229    fn default() -> Self {
230        Self {
231            level: OptimizationLevel::Standard,
232            max_iterations: 10,
233            custom_passes: None,
234            enable_convergence: true,
235        }
236    }
237}
238
239impl PipelineConfig {
240    /// Create a new configuration with the given optimization level.
241    pub fn with_level(level: OptimizationLevel) -> Self {
242        Self {
243            level,
244            ..Default::default()
245        }
246    }
247
248    /// Set custom passes.
249    pub fn with_custom_passes(mut self, passes: Vec<OptimizationPass>) -> Self {
250        self.custom_passes = Some(passes);
251        self
252    }
253
254    /// Set maximum iterations.
255    pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
256        self.max_iterations = max_iterations;
257        self
258    }
259
260    /// Disable convergence detection.
261    pub fn without_convergence(mut self) -> Self {
262        self.enable_convergence = false;
263        self
264    }
265}
266
267/// Optimization pipeline that orchestrates multiple passes.
268#[derive(Default)]
269pub struct OptimizationPipeline {
270    config: PipelineConfig,
271}
272
273impl OptimizationPipeline {
274    /// Create a new pipeline with the given configuration.
275    pub fn new(config: PipelineConfig) -> Self {
276        Self { config }
277    }
278
279    /// Create a pipeline with a specific optimization level.
280    pub fn with_level(level: OptimizationLevel) -> Self {
281        Self::new(PipelineConfig::with_level(level))
282    }
283
284    /// Optimize an expression using the configured pipeline.
285    ///
286    /// Returns the optimized expression and metrics about the optimization process.
287    pub fn optimize(&self, expr: TLExpr) -> (TLExpr, OptimizationMetrics) {
288        let mut current = expr;
289        let mut metrics = OptimizationMetrics::new();
290        let initial_size = count_nodes(&current);
291
292        // Get passes to apply (either custom or default for level)
293        let passes = self
294            .config
295            .custom_passes
296            .clone()
297            .unwrap_or_else(|| OptimizationPass::for_level(self.config.level));
298
299        // Sort passes by priority
300        let mut sorted_passes = passes.clone();
301        sorted_passes.sort_by_key(|p| p.priority());
302
303        // Apply passes iteratively until convergence or max iterations
304        for iteration in 0..self.config.max_iterations {
305            metrics.iterations = iteration + 1;
306            let previous = current.clone();
307
308            // Apply each pass in order
309            for pass in &sorted_passes {
310                let before = current.clone();
311                current = pass.apply(current);
312
313                // Record if the pass made changes
314                if before != current {
315                    metrics.record_pass(*pass);
316                }
317            }
318
319            // Check for convergence
320            if self.config.enable_convergence && current == previous {
321                metrics.converged = true;
322                break;
323            }
324        }
325
326        let final_size = count_nodes(&current);
327        metrics.finalize(initial_size, final_size);
328
329        (current, metrics)
330    }
331
332    /// Apply a single pass to an expression.
333    pub fn apply_pass(&self, expr: TLExpr, pass: OptimizationPass) -> TLExpr {
334        pass.apply(expr)
335    }
336
337    /// Get the configuration of this pipeline.
338    pub fn config(&self) -> &PipelineConfig {
339        &self.config
340    }
341}
342
/// Count the number of nodes in an expression (for metrics).
///
/// Every `TLExpr` variant counts as one node, plus the node counts of the
/// sub-expressions bound in its match arm below. Arms that bind no
/// sub-expression (e.g. `Pred`, `Constant`, `EmptySet`, `Nominal`,
/// `AllDifferent`, `Abducible`) count as a single node.
///
/// Fields elided with `..` are not visited. Presumably they hold no `TLExpr`
/// children. NOTE(review): confirm against the `TLExpr` definition whenever a
/// variant gains a field.
///
/// Used by `OptimizationPipeline::optimize` to fill in
/// `OptimizationMetrics::initial_size` and `final_size`.
fn count_nodes(expr: &TLExpr) -> usize {
    match expr {
        // Leaves: no sub-expressions are visited.
        TLExpr::Pred { .. } | TLExpr::Constant(_) => 1,
        // Binary operators: this node plus both operands.
        TLExpr::And(l, r)
        | TLExpr::Or(l, r)
        | TLExpr::Imply(l, r)
        | TLExpr::Add(l, r)
        | TLExpr::Sub(l, r)
        | TLExpr::Mul(l, r)
        | TLExpr::Div(l, r)
        | TLExpr::Pow(l, r)
        | TLExpr::Mod(l, r)
        | TLExpr::Min(l, r)
        | TLExpr::Max(l, r)
        | TLExpr::Eq(l, r)
        | TLExpr::Lt(l, r)
        | TLExpr::Gt(l, r)
        | TLExpr::Lte(l, r)
        | TLExpr::Gte(l, r) => 1 + count_nodes(l) + count_nodes(r),
        // Unary operators: this node plus its single operand.
        TLExpr::Not(e)
        | TLExpr::Score(e)
        | TLExpr::Abs(e)
        | TLExpr::Floor(e)
        | TLExpr::Ceil(e)
        | TLExpr::Round(e)
        | TLExpr::Sqrt(e)
        | TLExpr::Exp(e)
        | TLExpr::Log(e)
        | TLExpr::Sin(e)
        | TLExpr::Cos(e)
        | TLExpr::Tan(e)
        | TLExpr::Box(e)
        | TLExpr::Diamond(e)
        | TLExpr::Next(e)
        | TLExpr::Eventually(e)
        | TLExpr::Always(e) => 1 + count_nodes(e),
        // Two-operand temporal variants. The `Release`/`StrongRelease` fields are
        // rebound to `before`/`after` so all four variants share one arm.
        TLExpr::Until { before, after }
        | TLExpr::Release {
            released: before,
            releaser: after,
        }
        | TLExpr::WeakUntil { before, after }
        | TLExpr::StrongRelease {
            released: before,
            releaser: after,
        } => 1 + count_nodes(before) + count_nodes(after),
        // Variants with a single body expression. `WeightedRule::rule` and
        // `FuzzyNot::expr` are rebound to `body` to share this arm.
        TLExpr::Exists { body, .. }
        | TLExpr::ForAll { body, .. }
        | TLExpr::SoftExists { body, .. }
        | TLExpr::SoftForAll { body, .. }
        | TLExpr::Aggregate { body, .. }
        | TLExpr::WeightedRule { rule: body, .. }
        | TLExpr::FuzzyNot { expr: body, .. } => 1 + count_nodes(body),
        // Two-operand fuzzy variants. `FuzzyImplication`'s premise/conclusion are
        // rebound to `left`/`right`.
        TLExpr::TNorm { left, right, .. }
        | TLExpr::TCoNorm { left, right, .. }
        | TLExpr::FuzzyImplication {
            premise: left,
            conclusion: right,
            ..
        } => 1 + count_nodes(left) + count_nodes(right),
        // Only the expression half of each alternative is counted; the first
        // tuple component is ignored.
        TLExpr::ProbabilisticChoice { alternatives } => {
            1 + alternatives
                .iter()
                .map(|(_, e)| count_nodes(e))
                .sum::<usize>()
        }
        TLExpr::IfThenElse {
            condition,
            then_branch,
            else_branch,
        } => 1 + count_nodes(condition) + count_nodes(then_branch) + count_nodes(else_branch),
        TLExpr::Let { value, body, .. } => 1 + count_nodes(value) + count_nodes(body),

        // Beta.1 enhancements
        TLExpr::Lambda { body, .. } => 1 + count_nodes(body),
        TLExpr::Apply { function, argument } => 1 + count_nodes(function) + count_nodes(argument),
        // Binary set variants. The `left`/`right` fields are rebound to
        // `element`/`set` so they share an arm with `SetMembership`.
        TLExpr::SetMembership { element, set }
        | TLExpr::SetUnion {
            left: element,
            right: set,
        }
        | TLExpr::SetIntersection {
            left: element,
            right: set,
        }
        | TLExpr::SetDifference {
            left: element,
            right: set,
        } => 1 + count_nodes(element) + count_nodes(set),
        TLExpr::SetCardinality { set } => 1 + count_nodes(set),
        TLExpr::EmptySet => 1,
        TLExpr::SetComprehension { condition, .. } => 1 + count_nodes(condition),
        TLExpr::CountingExists { body, .. }
        | TLExpr::CountingForAll { body, .. }
        | TLExpr::ExactCount { body, .. }
        | TLExpr::Majority { body, .. } => 1 + count_nodes(body),
        TLExpr::LeastFixpoint { body, .. } | TLExpr::GreatestFixpoint { body, .. } => {
            1 + count_nodes(body)
        }
        TLExpr::Nominal { .. } => 1,
        TLExpr::At { formula, .. } => 1 + count_nodes(formula),
        TLExpr::Somewhere { formula } | TLExpr::Everywhere { formula } => 1 + count_nodes(formula),
        TLExpr::AllDifferent { .. } => 1,
        TLExpr::GlobalCardinality { values, .. } => {
            1 + values.iter().map(count_nodes).sum::<usize>()
        }
        TLExpr::Abducible { .. } => 1,
        TLExpr::Explain { formula } => 1 + count_nodes(formula),
    }
}
454
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Term;

    #[test]
    fn test_optimization_level_ordering() {
        assert!(OptimizationLevel::None < OptimizationLevel::Basic);
        assert!(OptimizationLevel::Basic < OptimizationLevel::Standard);
        assert!(OptimizationLevel::Standard < OptimizationLevel::Aggressive);
    }

    #[test]
    fn test_pass_priority_ordering() {
        // Every level must list its passes in strictly ascending priority order.
        let levels = [
            OptimizationLevel::None,
            OptimizationLevel::Basic,
            OptimizationLevel::Standard,
            OptimizationLevel::Aggressive,
        ];
        for &level in &levels {
            let priorities: Vec<u32> = OptimizationPass::for_level(level)
                .iter()
                .map(OptimizationPass::priority)
                .collect();
            assert!(
                priorities.windows(2).all(|w| w[0] < w[1]),
                "passes for {:?} are not in priority order: {:?}",
                level,
                priorities
            );
        }

        // Verify that constant folding comes first
        let passes = OptimizationPass::for_level(OptimizationLevel::Aggressive);
        assert_eq!(passes[0], OptimizationPass::ConstantFolding);
        assert_eq!(passes[0].priority(), 10);
    }

    #[test]
    fn test_pipeline_basic_optimization() {
        // (1.0 AND P(x)) should simplify to P(x)
        let expr = TLExpr::and(
            TLExpr::constant(1.0),
            TLExpr::pred("P", vec![Term::var("x")]),
        );

        let pipeline = OptimizationPipeline::with_level(OptimizationLevel::Basic);
        let (_optimized, metrics) = pipeline.optimize(expr);

        // Should have simplified
        assert!(metrics.passes_applied > 0);
        assert!(metrics.reduction_ratio > 0.0);
        assert!(metrics.converged);
    }

    #[test]
    fn test_pipeline_no_optimization() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);

        let pipeline = OptimizationPipeline::with_level(OptimizationLevel::None);
        let (optimized, metrics) = pipeline.optimize(expr.clone());

        // Should not change
        assert_eq!(optimized, expr);
        assert_eq!(metrics.passes_applied, 0);
    }

    #[test]
    fn test_pipeline_convergence() {
        // Expression that requires multiple passes
        let expr = TLExpr::and(
            TLExpr::or(TLExpr::constant(1.0), TLExpr::constant(0.0)),
            TLExpr::pred("P", vec![Term::var("x")]),
        );

        let pipeline = OptimizationPipeline::with_level(OptimizationLevel::Standard);
        let (_, metrics) = pipeline.optimize(expr);

        // Should converge
        assert!(metrics.converged);
        assert!(metrics.iterations > 0);
    }

    #[test]
    fn test_pipeline_max_iterations() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);

        let config = PipelineConfig::default().with_max_iterations(5);
        let pipeline = OptimizationPipeline::new(config);
        let (_, metrics) = pipeline.optimize(expr);

        // Should not exceed max iterations
        assert!(metrics.iterations <= 5);
    }

    #[test]
    fn test_pipeline_zero_iterations_is_identity() {
        let expr = TLExpr::and(
            TLExpr::constant(1.0),
            TLExpr::pred("P", vec![Term::var("x")]),
        );

        let config = PipelineConfig::default().with_max_iterations(0);
        let (optimized, metrics) = OptimizationPipeline::new(config).optimize(expr.clone());

        assert_eq!(optimized, expr);
        assert_eq!(metrics.iterations, 0);
        assert_eq!(metrics.passes_applied, 0);
        assert!(!metrics.converged);
    }

    #[test]
    fn test_pipeline_without_convergence_runs_all_iterations() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);

        let config = PipelineConfig::default()
            .with_max_iterations(3)
            .without_convergence();
        let (_, metrics) = OptimizationPipeline::new(config).optimize(expr);

        assert_eq!(metrics.iterations, 3);
        assert!(!metrics.converged);
    }

    #[test]
    fn test_custom_passes() {
        let expr = TLExpr::constant(42.0);

        let custom_passes = vec![
            OptimizationPass::ConstantFolding,
            OptimizationPass::AlgebraicSimplification,
        ];

        let config = PipelineConfig::default().with_custom_passes(custom_passes);
        let pipeline = OptimizationPipeline::new(config);
        let (_, metrics) = pipeline.optimize(expr);

        // Verify only specified passes were used
        assert!(metrics.pass_counts.len() <= 2);
    }

    #[test]
    fn test_metrics_tracking() {
        let expr = TLExpr::and(
            TLExpr::constant(1.0),
            TLExpr::pred("P", vec![Term::var("x")]),
        );

        let pipeline = OptimizationPipeline::with_level(OptimizationLevel::Standard);
        let (_, metrics) = pipeline.optimize(expr);

        assert!(metrics.initial_size > metrics.final_size);
        assert!(metrics.reduction_ratio > 0.0);
        assert!(metrics.reduction_ratio <= 1.0);
    }

    #[test]
    fn test_metrics_record_and_finalize() {
        let mut metrics = OptimizationMetrics::new();
        metrics.record_pass(OptimizationPass::ConstantFolding);
        metrics.record_pass(OptimizationPass::ConstantFolding);
        metrics.record_pass(OptimizationPass::NegationNormalForm);

        assert_eq!(metrics.passes_applied, 3);
        assert_eq!(metrics.pass_counts.get("constant_folding"), Some(&2));
        assert_eq!(metrics.pass_counts.get("negation_normal_form"), Some(&1));

        // Growth yields a negative ratio; an empty initial size yields 0.
        metrics.finalize(2, 4);
        assert_eq!(metrics.reduction_ratio, -1.0);
        metrics.finalize(0, 3);
        assert_eq!(metrics.reduction_ratio, 0.0);
    }

    #[test]
    fn test_count_nodes_simple() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);
        assert_eq!(count_nodes(&expr), 1);
    }

    #[test]
    fn test_count_nodes_complex() {
        let expr = TLExpr::and(
            TLExpr::pred("P", vec![Term::var("x")]),
            TLExpr::or(
                TLExpr::pred("Q", vec![Term::var("y")]),
                TLExpr::pred("R", vec![Term::var("z")]),
            ),
        );
        // 1 (AND) + 1 (P) + 1 (OR) + 1 (Q) + 1 (R) = 5
        assert_eq!(count_nodes(&expr), 5);
    }

    #[test]
    fn test_pipeline_aggressive_level() {
        let expr = TLExpr::and(
            TLExpr::or(
                TLExpr::pred("P", vec![Term::var("x")]),
                TLExpr::pred("Q", vec![Term::var("x")]),
            ),
            TLExpr::pred("R", vec![Term::var("x")]),
        );

        let pipeline = OptimizationPipeline::with_level(OptimizationLevel::Aggressive);
        let (_, metrics) = pipeline.optimize(expr);

        // Aggressive level should apply more passes
        assert!(metrics.passes_applied > 0);
    }

    #[test]
    fn test_pass_application() {
        let expr = TLExpr::constant(1.0);
        let pipeline = OptimizationPipeline::default();

        let result = pipeline.apply_pass(expr.clone(), OptimizationPass::ConstantFolding);
        assert_eq!(result, expr); // Constants don't change
    }
}