Skip to main content

tensorlogic_compiler/
pipeline.rs

1//! Full compiler optimization pipeline for TLExpr expressions.
2//!
3//! This module provides a top-level [`CompilerPipeline`] that chains **all** compiler
4//! passes in configurable order:
5//!
6//! `ConstProp → DeadCode → Inline → [Algebraic / OptimizationPipeline] → Rewrite`
7//!
8//! Unlike the algebraic-only [`crate::optimize::OptimizationPipeline`], the
9//! `CompilerPipeline` integrates the newer passes added in v0.1.11–v0.1.16
10//! (constant propagation, dead-code elimination, let-inlining, pattern-rewriting)
11//! and supports outer fixed-point iteration across all passes.
12//!
13//! # Quick Start
14//!
15//! ```rust
16//! use tensorlogic_compiler::pipeline::{CompilerPipeline, CompilerPipelineConfig};
17//! use tensorlogic_ir::TLExpr;
18//!
19//! let pipeline = CompilerPipeline::with_default();
20//! let expr = TLExpr::add(
21//!     TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0)),
22//!     TLExpr::Constant(0.0),
23//! );
24//! let result = pipeline.run(expr);
25//! println!("{}", result.stats.summary());
26//! ```
27
28use std::collections::HashMap;
29use std::fmt;
30use std::time::{Duration, Instant};
31
32use tensorlogic_ir::TLExpr;
33
34use crate::const_prop::{ConstPropConfig, ConstantPropagator};
35use crate::dead_code::{DceConfig, DeadCodeEliminator};
36use crate::inline::{InlineConfig, LetInliner};
37use crate::optimize::OptimizationPipeline;
38use crate::rewrite::RewriteEngine;
39
40// ────────────────────────────────────────────────────────────────────────────
41// CompilerPassId
42// ────────────────────────────────────────────────────────────────────────────
43
/// Identifies a single pass in the compiler pipeline.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CompilerPassId {
    /// Constant propagation — evaluates constant sub-expressions at compile time.
    ConstProp,
    /// Dead code elimination — removes unreachable branches and unused bindings.
    DeadCode,
    /// Let-inlining — substitutes `Let`-bound variables into their use sites.
    Inline,
    /// Algebraic optimization pipeline (negation, folding, strength-reduction, …).
    Algebraic,
    /// Pattern-rewriting engine — applies structural rewrite rules to fixed point.
    Rewrite,
}

impl CompilerPassId {
    /// The pass's display name; identical to the text produced by `Display`.
    pub fn as_str(&self) -> &'static str {
        match self {
            CompilerPassId::ConstProp => "ConstProp",
            CompilerPassId::DeadCode => "DeadCode",
            CompilerPassId::Inline => "Inline",
            CompilerPassId::Algebraic => "Algebraic",
            CompilerPassId::Rewrite => "Rewrite",
        }
    }
}

impl fmt::Display for CompilerPassId {
    /// Writes the pass name, honouring the caller's width/fill/alignment flags
    /// (so `format!("{:<12}", id)` pads without an intermediate `to_string()`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` applies formatter flags; `write!` would silently ignore them.
        f.pad(self.as_str())
    }
}
70
71// ────────────────────────────────────────────────────────────────────────────
72// CompilerPassOrder
73// ────────────────────────────────────────────────────────────────────────────
74
75/// Canonical ordering in which passes are applied during one outer iteration.
76#[derive(Debug, Clone)]
77pub enum CompilerPassOrder {
78    /// `ConstProp → DCE → Inline → Algebraic → Rewrite`
79    ///
80    /// The default, well-balanced ordering.  Constant folding and DCE simplify
81    /// the tree before inlining; algebraic passes follow; pattern rewrites last.
82    CanonicalOrder,
83
84    /// `Inline → ConstProp → DCE → Algebraic → Rewrite`
85    ///
86    /// Inlining first exposes more constant sub-expressions for subsequent
87    /// folding and elimination passes.
88    InlineFirst,
89
90    /// `ConstProp → ConstProp → DCE → Inline → ConstProp → DCE → Rewrite`
91    ///
92    /// Runs constant propagation twice before elimination, then once more after
93    /// inlining to catch any newly-exposed constants.  Useful for deep algebraic
94    /// expressions with many nested constants.
95    AggressiveFold,
96
97    /// User-supplied pass sequence.
98    Custom(Vec<CompilerPassId>),
99}
100
101impl CompilerPassOrder {
102    /// Convert this ordering variant to the concrete list of passes to execute.
103    pub fn to_pass_list(&self) -> Vec<CompilerPassId> {
104        match self {
105            CompilerPassOrder::CanonicalOrder => vec![
106                CompilerPassId::ConstProp,
107                CompilerPassId::DeadCode,
108                CompilerPassId::Inline,
109                CompilerPassId::Algebraic,
110                CompilerPassId::Rewrite,
111            ],
112            CompilerPassOrder::InlineFirst => vec![
113                CompilerPassId::Inline,
114                CompilerPassId::ConstProp,
115                CompilerPassId::DeadCode,
116                CompilerPassId::Algebraic,
117                CompilerPassId::Rewrite,
118            ],
119            CompilerPassOrder::AggressiveFold => vec![
120                CompilerPassId::ConstProp,
121                CompilerPassId::ConstProp,
122                CompilerPassId::DeadCode,
123                CompilerPassId::Inline,
124                CompilerPassId::ConstProp,
125                CompilerPassId::DeadCode,
126                CompilerPassId::Rewrite,
127            ],
128            CompilerPassOrder::Custom(order) => order.clone(),
129        }
130    }
131}
132
133// ────────────────────────────────────────────────────────────────────────────
134// CompilerPipelineConfig
135// ────────────────────────────────────────────────────────────────────────────
136
137/// Controls which passes are enabled and how the full pipeline is configured.
138#[derive(Debug, Clone)]
139pub struct CompilerPipelineConfig {
140    /// Enable the constant-propagation pass.
141    pub enable_const_prop: bool,
142    /// Enable the dead-code-elimination pass.
143    pub enable_dead_code: bool,
144    /// Enable the let-inlining pass.
145    pub enable_inline: bool,
146    /// Enable the algebraic optimization sub-pipeline.
147    pub enable_algebraic: bool,
148    /// Enable the pattern-rewrite engine.
149    pub enable_rewrite: bool,
150    /// Order in which passes are applied within a single outer iteration.
151    pub pass_order: CompilerPassOrder,
152    /// Maximum number of outer fixed-point iterations over the full pass sequence.
153    pub max_outer_iterations: u32,
154    /// Configuration forwarded to the constant-propagation pass.
155    pub const_prop_config: ConstPropConfig,
156    /// Configuration forwarded to the dead-code-elimination pass.
157    pub dce_config: DceConfig,
158    /// Configuration forwarded to the let-inlining pass.
159    pub inline_config: InlineConfig,
160}
161
162impl Default for CompilerPipelineConfig {
163    fn default() -> Self {
164        Self {
165            enable_const_prop: true,
166            enable_dead_code: true,
167            enable_inline: true,
168            enable_algebraic: true,
169            enable_rewrite: true,
170            pass_order: CompilerPassOrder::CanonicalOrder,
171            max_outer_iterations: 3,
172            const_prop_config: ConstPropConfig::default(),
173            dce_config: DceConfig::default(),
174            inline_config: InlineConfig::default(),
175        }
176    }
177}
178
179// ────────────────────────────────────────────────────────────────────────────
180// CompilerPassStats
181// ────────────────────────────────────────────────────────────────────────────
182
183/// Timing and reduction statistics for a single pass execution.
184#[derive(Debug, Clone)]
185pub struct CompilerPassStats {
186    /// Which pass produced these stats.
187    pub pass_id: CompilerPassId,
188    /// Wall-clock time spent in this pass.
189    pub wall_time: Duration,
190    /// Node count immediately before the pass ran.
191    pub nodes_before: u64,
192    /// Node count immediately after the pass completed.
193    pub nodes_after: u64,
194    /// Number of reductions (folds, eliminations, inlines, rewrites, …).
195    pub reductions: u64,
196}
197
198impl CompilerPassStats {
199    /// Fraction of nodes eliminated by this pass: `(before − after) / before * 100`.
200    ///
201    /// Returns `0.0` when `nodes_before == 0`.
202    pub fn reduction_pct(&self) -> f64 {
203        if self.nodes_before == 0 {
204            return 0.0;
205        }
206        let before = self.nodes_before as f64;
207        let after = self.nodes_after as f64;
208        ((before - after) / before * 100.0).max(0.0)
209    }
210
211    /// Human-readable one-line summary of this pass execution.
212    pub fn summary(&self) -> String {
213        format!(
214            "{:<12} {:>8.3}ms  nodes: {:>6} → {:>6} ({:>5.1}%)  reductions: {}",
215            self.pass_id.to_string(),
216            self.wall_time.as_secs_f64() * 1_000.0,
217            self.nodes_before,
218            self.nodes_after,
219            self.reduction_pct(),
220            self.reductions,
221        )
222    }
223}
224
225// ────────────────────────────────────────────────────────────────────────────
226// CompilerPipelineStats
227// ────────────────────────────────────────────────────────────────────────────
228
229/// Aggregate statistics for an entire pipeline run.
230#[derive(Debug, Clone)]
231pub struct CompilerPipelineStats {
232    /// Per-pass statistics, in execution order.
233    pub pass_stats: Vec<CompilerPassStats>,
234    /// Total wall-clock time for the full pipeline run.
235    pub total_wall_time: Duration,
236    /// Number of outer fixed-point iterations executed.
237    pub outer_iterations: u32,
238    /// Total node reduction across all passes: `initial − final` (may be negative
239    /// if a pass somehow increases node count, which should not happen in practice).
240    pub total_node_reduction: i64,
241    /// Node count before the very first pass.
242    pub initial_node_count: u64,
243    /// Node count after the very last pass.
244    pub final_node_count: u64,
245}
246
247impl CompilerPipelineStats {
248    /// Fraction of nodes eliminated across the entire pipeline run.
249    ///
250    /// Returns `0.0` when `initial_node_count == 0`.
251    pub fn overall_reduction_pct(&self) -> f64 {
252        if self.initial_node_count == 0 {
253            return 0.0;
254        }
255        let before = self.initial_node_count as f64;
256        let after = self.final_node_count as f64;
257        ((before - after) / before * 100.0).max(0.0)
258    }
259
260    /// Returns the pass that consumed the most wall-clock time, or `None` if no
261    /// passes were executed.
262    pub fn slowest_pass(&self) -> Option<&CompilerPassStats> {
263        self.pass_stats.iter().max_by_key(|s| s.wall_time)
264    }
265
266    /// Render a formatted table of per-pass timing and reduction statistics.
267    pub fn format_table(&self) -> String {
268        let mut out = String::new();
269        out.push_str("┌──────────────────────────────────────────────────────────────────┐\n");
270        out.push_str("│  Pass          Time(ms)   Nodes Before → After    Pct   Reductions│\n");
271        out.push_str("├──────────────────────────────────────────────────────────────────┤\n");
272        for s in &self.pass_stats {
273            out.push_str(&format!("│  {}\n", s.summary()));
274        }
275        out.push_str("├──────────────────────────────────────────────────────────────────┤\n");
276        out.push_str(&format!(
277            "│  TOTAL         {:>8.3}ms  {:>6} nodes → {:>6} ({:>5.1}% overall)      │\n",
278            self.total_wall_time.as_secs_f64() * 1_000.0,
279            self.initial_node_count,
280            self.final_node_count,
281            self.overall_reduction_pct(),
282        ));
283        out.push_str("└──────────────────────────────────────────────────────────────────┘\n");
284        out
285    }
286
287    /// Human-readable one-line summary of the full pipeline run.
288    pub fn summary(&self) -> String {
289        format!(
290            "Pipeline: {} outer iterations, {:.3}ms total, {} → {} nodes ({:.1}% reduction)",
291            self.outer_iterations,
292            self.total_wall_time.as_secs_f64() * 1_000.0,
293            self.initial_node_count,
294            self.final_node_count,
295            self.overall_reduction_pct(),
296        )
297    }
298}
299
300// ────────────────────────────────────────────────────────────────────────────
301// CompilerPipelineResult
302// ────────────────────────────────────────────────────────────────────────────
303
304/// Output of a full pipeline run: the transformed expression plus collected statistics.
305#[derive(Debug, Clone)]
306pub struct CompilerPipelineResult {
307    /// The optimized expression.
308    pub expr: TLExpr,
309    /// Statistics covering all passes and outer iterations.
310    pub stats: CompilerPipelineStats,
311}
312
313// ────────────────────────────────────────────────────────────────────────────
314// PassBenchmark
315// ────────────────────────────────────────────────────────────────────────────
316
317/// Benchmarking statistics across multiple repeated runs of the same pass.
318#[derive(Debug, Clone)]
319pub struct PassBenchmark {
320    /// Which pass was benchmarked.
321    pub pass_id: CompilerPassId,
322    /// Number of runs included in these statistics.
323    pub runs: usize,
324    /// Minimum observed wall-clock time in nanoseconds.
325    pub min_ns: u64,
326    /// Maximum observed wall-clock time in nanoseconds.
327    pub max_ns: u64,
328    /// Arithmetic mean wall-clock time in nanoseconds.
329    pub mean_ns: u64,
330    /// Sum of all `reductions` values across all runs.
331    pub total_reductions: u64,
332}
333
334impl PassBenchmark {
335    /// Human-readable one-line benchmark summary.
336    pub fn summary(&self) -> String {
337        format!(
338            "{:<12}  runs={:>4}  min={:.3}ms  mean={:.3}ms  max={:.3}ms  reductions={}",
339            self.pass_id.to_string(),
340            self.runs,
341            self.min_ns as f64 / 1_000_000.0,
342            self.mean_ns as f64 / 1_000_000.0,
343            self.max_ns as f64 / 1_000_000.0,
344            self.total_reductions,
345        )
346    }
347}
348
349// ────────────────────────────────────────────────────────────────────────────
350// CompilerPipeline
351// ────────────────────────────────────────────────────────────────────────────
352
353/// Full compiler optimization pipeline.
354///
355/// Chains all compiler passes in configurable order with per-pass timing and
356/// aggregate statistics.  The outer fixed-point loop repeats the entire pass
357/// sequence until no further nodes are reduced or `config.max_outer_iterations`
358/// is reached.
359pub struct CompilerPipeline {
360    config: CompilerPipelineConfig,
361}
362
363impl Default for CompilerPipeline {
364    fn default() -> Self {
365        Self::with_default()
366    }
367}
368
369impl CompilerPipeline {
370    /// Create a pipeline with the given configuration.
371    pub fn new(config: CompilerPipelineConfig) -> Self {
372        Self { config }
373    }
374
375    /// Create a pipeline with the default configuration (all passes enabled).
376    pub fn with_default() -> Self {
377        Self::new(CompilerPipelineConfig::default())
378    }
379
380    /// Alias for [`Self::with_default`].
381    pub fn all_passes() -> Self {
382        Self::with_default()
383    }
384
385    /// Create a pipeline with all passes disabled.
386    ///
387    /// Expressions passed through this pipeline are returned unchanged (other
388    /// than recording the initial node count in statistics).
389    pub fn no_passes() -> Self {
390        Self::new(CompilerPipelineConfig {
391            enable_const_prop: false,
392            enable_dead_code: false,
393            enable_inline: false,
394            enable_algebraic: false,
395            enable_rewrite: false,
396            ..CompilerPipelineConfig::default()
397        })
398    }
399
400    // ── Public API ───────────────────────────────────────────────────────────
401
402    /// Run the full pipeline on `expr` and return the result with statistics.
403    pub fn run(&self, expr: TLExpr) -> CompilerPipelineResult {
404        let pipeline_start = Instant::now();
405        let initial_node_count = Self::count_nodes(&expr);
406
407        let mut stats = CompilerPipelineStats {
408            pass_stats: Vec::new(),
409            total_wall_time: Duration::ZERO,
410            outer_iterations: 0,
411            total_node_reduction: 0,
412            initial_node_count,
413            final_node_count: initial_node_count,
414        };
415
416        let order = self.config.pass_order.to_pass_list();
417        let mut current = expr;
418        let max_iters = self.config.max_outer_iterations.max(1);
419
420        for _ in 0..max_iters {
421            let nodes_before_iter = Self::count_nodes(&current);
422            current = self.run_sequence(current, &order, &mut stats);
423            stats.outer_iterations += 1;
424
425            let nodes_after_iter = Self::count_nodes(&current);
426            // Stop early if no progress was made in this outer iteration.
427            if nodes_after_iter >= nodes_before_iter {
428                break;
429            }
430        }
431
432        let final_node_count = Self::count_nodes(&current);
433        stats.final_node_count = final_node_count;
434        stats.total_node_reduction = initial_node_count as i64 - final_node_count as i64;
435        stats.total_wall_time = pipeline_start.elapsed();
436
437        CompilerPipelineResult {
438            expr: current,
439            stats,
440        }
441    }
442
443    /// Run the pipeline `runs` times on (a clone of) `expr` and return per-pass
444    /// benchmarking statistics aggregated across all runs.
445    pub fn benchmark(&self, expr: TLExpr, runs: usize) -> Vec<PassBenchmark> {
446        // Accumulate timing per pass_id across all runs.
447        let mut timings: HashMap<String, (u64, u64, u64, u64, u64)> = HashMap::new(); // key → (count, min_ns, max_ns, sum_ns, total_reductions)
448
449        let effective_runs = runs.max(1);
450
451        for _ in 0..effective_runs {
452            let result = self.run(expr.clone());
453            for ps in &result.stats.pass_stats {
454                let ns = ps.wall_time.as_nanos() as u64;
455                let key = ps.pass_id.to_string();
456                let entry = timings.entry(key).or_insert((0, u64::MAX, 0, 0, 0));
457                entry.0 += 1;
458                entry.1 = entry.1.min(ns);
459                entry.2 = entry.2.max(ns);
460                entry.3 = entry.3.saturating_add(ns);
461                entry.4 = entry.4.saturating_add(ps.reductions);
462            }
463        }
464
465        // Build a PassBenchmark per distinct pass id in the order they appear in
466        // the pass list (deduplicating for Custom / AggressiveFold repeats).
467        let order = self.config.pass_order.to_pass_list();
468        let mut seen: Vec<String> = Vec::new();
469        let mut benchmarks: Vec<PassBenchmark> = Vec::new();
470
471        for pass_id in &order {
472            let key = pass_id.to_string();
473            if seen.contains(&key) {
474                continue;
475            }
476            seen.push(key.clone());
477            if let Some(&(count, min_ns, max_ns, sum_ns, total_reductions)) = timings.get(&key) {
478                let mean_ns = sum_ns.checked_div(count).unwrap_or(0);
479                benchmarks.push(PassBenchmark {
480                    pass_id: pass_id.clone(),
481                    runs: count as usize,
482                    min_ns,
483                    max_ns,
484                    mean_ns,
485                    total_reductions,
486                });
487            }
488        }
489
490        benchmarks
491    }
492
493    // ── Private helpers ──────────────────────────────────────────────────────
494
495    /// Execute the given pass sequence once, recording stats for each invocation.
496    fn run_sequence(
497        &self,
498        mut expr: TLExpr,
499        order: &[CompilerPassId],
500        stats: &mut CompilerPipelineStats,
501    ) -> TLExpr {
502        for pass_id in order {
503            expr = self.run_single_pass(pass_id, expr, stats);
504        }
505        expr
506    }
507
508    /// Execute a single pass, recording timing and reduction statistics.
509    fn run_single_pass(
510        &self,
511        pass_id: &CompilerPassId,
512        expr: TLExpr,
513        stats: &mut CompilerPipelineStats,
514    ) -> TLExpr {
515        let nodes_before = Self::count_nodes(&expr);
516        let t0 = Instant::now();
517
518        let (new_expr, reductions) = match pass_id {
519            CompilerPassId::ConstProp => {
520                if !self.config.enable_const_prop {
521                    return expr;
522                }
523                let propagator = ConstantPropagator::new(self.config.const_prop_config.clone());
524                let (out, s) = propagator.run(expr);
525                let r = s.total_folds();
526                (out, r)
527            }
528
529            CompilerPassId::DeadCode => {
530                if !self.config.enable_dead_code {
531                    return expr;
532                }
533                let eliminator = DeadCodeEliminator::new(self.config.dce_config.clone());
534                let (out, s) = eliminator.run(expr);
535                let r = s.total_eliminations();
536                (out, r)
537            }
538
539            CompilerPassId::Inline => {
540                if !self.config.enable_inline {
541                    return expr;
542                }
543                let inliner = LetInliner::new(self.config.inline_config.clone());
544                let (out, s) = inliner.run(expr);
545                let r = s.total();
546                (out, r)
547            }
548
549            CompilerPassId::Algebraic => {
550                if !self.config.enable_algebraic {
551                    return expr;
552                }
553                let alg_pipeline = OptimizationPipeline::new();
554                let (out, s) = alg_pipeline.optimize(&expr);
555                let r = s.total_optimizations() as u64;
556                (out, r)
557            }
558
559            CompilerPassId::Rewrite => {
560                if !self.config.enable_rewrite {
561                    return expr;
562                }
563                let engine = RewriteEngine::new().add_all_builtin_rules();
564                let (out, s) = engine.rewrite(expr);
565                let r = s.total_rewrites;
566                (out, r)
567            }
568        };
569
570        let wall_time = t0.elapsed();
571        let nodes_after = Self::count_nodes(&new_expr);
572
573        stats.pass_stats.push(CompilerPassStats {
574            pass_id: pass_id.clone(),
575            wall_time,
576            nodes_before,
577            nodes_after,
578            reductions,
579        });
580
581        new_expr
582    }
583
584    /// Count the total number of AST nodes in `expr` via the DCE helper.
585    fn count_nodes(expr: &TLExpr) -> u64 {
586        DeadCodeEliminator::count_nodes(expr)
587    }
588}
589
590// ────────────────────────────────────────────────────────────────────────────
591// Tests
592// ────────────────────────────────────────────────────────────────────────────
593
594#[cfg(test)]
595mod tests {
596    use super::*;
597    use tensorlogic_ir::{TLExpr, Term};
598
599    // ── Helpers ──────────────────────────────────────────────────────────────
600
601    fn simple_constant_expr() -> TLExpr {
602        // Add(Mul(2, 3), 4)  →  10
603        TLExpr::add(
604            TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0)),
605            TLExpr::Constant(4.0),
606        )
607    }
608
609    fn dead_branch_expr() -> TLExpr {
610        // And(True, p(x))  →  p(x)
611        TLExpr::and(
612            TLExpr::Constant(1.0),
613            TLExpr::pred("p", vec![Term::var("x")]),
614        )
615    }
616
617    fn let_binding_expr() -> TLExpr {
618        // Let y = 5.0 in y  →  5.0
619        TLExpr::let_binding("y", TLExpr::Constant(5.0), TLExpr::pred("y", vec![]))
620    }
621
622    fn non_trivial_expr() -> TLExpr {
623        // Not(Not(p(x))) wrapped in And(True, _)
624        TLExpr::and(
625            TLExpr::Constant(1.0),
626            TLExpr::negate(TLExpr::negate(TLExpr::pred("p", vec![Term::var("x")]))),
627        )
628    }
629
630    // ── Config tests ─────────────────────────────────────────────────────────
631
632    #[test]
633    fn test_compiler_pipeline_config_default() {
634        let cfg = CompilerPipelineConfig::default();
635        assert!(cfg.enable_const_prop);
636        assert!(cfg.enable_dead_code);
637        assert!(cfg.enable_inline);
638        assert!(cfg.enable_algebraic);
639        assert!(cfg.enable_rewrite);
640        assert_eq!(cfg.max_outer_iterations, 3);
641    }
642
643    // ── No-pass pipeline ─────────────────────────────────────────────────────
644
645    #[test]
646    fn test_compiler_pipeline_no_passes() {
647        let pipeline = CompilerPipeline::no_passes();
648        let expr = simple_constant_expr();
649        let result = pipeline.run(expr.clone());
650        // With all passes disabled, the expression should be structurally identical.
651        assert_eq!(format!("{:?}", result.expr), format!("{:?}", expr),);
652    }
653
654    // ── Single-pass tests ────────────────────────────────────────────────────
655
656    #[test]
657    fn test_compiler_pipeline_const_prop_only() {
658        let cfg = CompilerPipelineConfig {
659            enable_const_prop: true,
660            enable_dead_code: false,
661            enable_inline: false,
662            enable_algebraic: false,
663            enable_rewrite: false,
664            max_outer_iterations: 1,
665            ..CompilerPipelineConfig::default()
666        };
667        let pipeline = CompilerPipeline::new(cfg);
668        let expr = simple_constant_expr();
669        let result = pipeline.run(expr);
670        // The expression should have been folded to a constant.
671        assert!(matches!(result.expr, TLExpr::Constant(_)));
672    }
673
674    #[test]
675    fn test_compiler_pipeline_dead_code_only() {
676        let cfg = CompilerPipelineConfig {
677            enable_const_prop: false,
678            enable_dead_code: true,
679            enable_inline: false,
680            enable_algebraic: false,
681            enable_rewrite: false,
682            max_outer_iterations: 1,
683            ..CompilerPipelineConfig::default()
684        };
685        let pipeline = CompilerPipeline::new(cfg);
686        let expr = dead_branch_expr();
687        let result = pipeline.run(expr);
688        // And(True, p(x))  →  p(x)
689        assert!(matches!(result.expr, TLExpr::Pred { .. }));
690    }
691
692    #[test]
693    fn test_compiler_pipeline_inline_only() {
694        let cfg = CompilerPipelineConfig {
695            enable_const_prop: false,
696            enable_dead_code: false,
697            enable_inline: true,
698            enable_algebraic: false,
699            enable_rewrite: false,
700            max_outer_iterations: 1,
701            ..CompilerPipelineConfig::default()
702        };
703        let pipeline = CompilerPipeline::new(cfg);
704        let expr = let_binding_expr();
705        let result = pipeline.run(expr);
706        // Let y = 5.0 in y  →  5.0
707        assert!(matches!(result.expr, TLExpr::Constant(v) if (v - 5.0).abs() < 1e-12));
708    }
709
710    #[test]
711    fn test_compiler_pipeline_all_passes() {
712        let pipeline = CompilerPipeline::all_passes();
713        let expr = non_trivial_expr();
714        // Should run without panicking.
715        let result = pipeline.run(expr);
716        assert!(result.stats.outer_iterations > 0);
717    }
718
719    // ── Stats tests ──────────────────────────────────────────────────────────
720
721    #[test]
722    fn test_compiler_pipeline_result_has_stats() {
723        let pipeline = CompilerPipeline::with_default();
724        let expr = simple_constant_expr();
725        let result = pipeline.run(expr);
726        assert!(result.stats.initial_node_count > 0);
727        assert!(!result.stats.pass_stats.is_empty());
728    }
729
730    #[test]
731    fn test_pass_stats_reduction_pct() {
732        let s = CompilerPassStats {
733            pass_id: CompilerPassId::ConstProp,
734            wall_time: Duration::from_millis(1),
735            nodes_before: 100,
736            nodes_after: 80,
737            reductions: 5,
738        };
739        let pct = s.reduction_pct();
740        assert!((pct - 20.0).abs() < 1e-6, "expected 20%, got {pct}");
741    }
742
743    #[test]
744    fn test_pass_stats_reduction_pct_zero_before() {
745        let s = CompilerPassStats {
746            pass_id: CompilerPassId::DeadCode,
747            wall_time: Duration::ZERO,
748            nodes_before: 0,
749            nodes_after: 0,
750            reductions: 0,
751        };
752        assert_eq!(s.reduction_pct(), 0.0);
753    }
754
755    #[test]
756    fn test_pass_stats_summary_nonempty() {
757        let s = CompilerPassStats {
758            pass_id: CompilerPassId::Inline,
759            wall_time: Duration::from_micros(500),
760            nodes_before: 10,
761            nodes_after: 8,
762            reductions: 2,
763        };
764        let summary = s.summary();
765        assert!(!summary.is_empty());
766        assert!(summary.contains("Inline"));
767    }
768
769    #[test]
770    fn test_pipeline_stats_overall_reduction() {
771        let pipeline = CompilerPipeline::with_default();
772        let expr = simple_constant_expr();
773        let result = pipeline.run(expr);
774        let initial = result.stats.initial_node_count;
775        let final_count = result.stats.final_node_count;
776        assert!(
777            initial >= final_count,
778            "pipeline should not increase node count"
779        );
780        let pct = result.stats.overall_reduction_pct();
781        assert!(pct >= 0.0);
782    }
783
784    #[test]
785    fn test_pipeline_stats_format_table() {
786        let pipeline = CompilerPipeline::with_default();
787        let expr = simple_constant_expr();
788        let result = pipeline.run(expr);
789        let table = result.stats.format_table();
790        assert!(
791            table.contains("Pass") || table.contains("TOTAL"),
792            "table should contain headers, got: {table}"
793        );
794    }
795
796    #[test]
797    fn test_pipeline_stats_summary_nonempty() {
798        let pipeline = CompilerPipeline::with_default();
799        let expr = simple_constant_expr();
800        let result = pipeline.run(expr);
801        let summary = result.stats.summary();
802        assert!(!summary.is_empty());
803        assert!(summary.contains("Pipeline"));
804    }
805
806    #[test]
807    fn test_pipeline_stats_slowest_pass() {
808        let pipeline = CompilerPipeline::with_default();
809        let expr = simple_constant_expr();
810        let result = pipeline.run(expr);
811        // There should be at least one pass executed, so slowest_pass must be Some.
812        assert!(result.stats.slowest_pass().is_some());
813    }
814
815    // ── Order tests ──────────────────────────────────────────────────────────
816
817    #[test]
818    fn test_compiler_pipeline_canonical_order() {
819        let cfg = CompilerPipelineConfig {
820            pass_order: CompilerPassOrder::CanonicalOrder,
821            ..CompilerPipelineConfig::default()
822        };
823        let pipeline = CompilerPipeline::new(cfg);
824        let result = pipeline.run(non_trivial_expr());
825        assert!(result.stats.outer_iterations >= 1);
826    }
827
828    #[test]
829    fn test_compiler_pipeline_inline_first() {
830        let cfg = CompilerPipelineConfig {
831            pass_order: CompilerPassOrder::InlineFirst,
832            ..CompilerPipelineConfig::default()
833        };
834        let pipeline = CompilerPipeline::new(cfg);
835        let result = pipeline.run(let_binding_expr());
836        assert!(result.stats.outer_iterations >= 1);
837    }
838
839    #[test]
840    fn test_compiler_pipeline_custom_order() {
841        let cfg = CompilerPipelineConfig {
842            pass_order: CompilerPassOrder::Custom(vec![
843                CompilerPassId::ConstProp,
844                CompilerPassId::DeadCode,
845            ]),
846            max_outer_iterations: 1,
847            ..CompilerPipelineConfig::default()
848        };
849        let pipeline = CompilerPipeline::new(cfg);
850        let result = pipeline.run(simple_constant_expr());
851        // Custom order with 2 passes → exactly 2 pass_stats entries per outer iter.
852        assert_eq!(result.stats.pass_stats.len(), 2);
853    }
854
855    #[test]
856    fn test_compiler_pipeline_outer_iterations() {
857        let cfg = CompilerPipelineConfig {
858            max_outer_iterations: 5,
859            ..CompilerPipelineConfig::default()
860        };
861        let pipeline = CompilerPipeline::new(cfg);
862        let result = pipeline.run(simple_constant_expr());
863        // Must terminate before or at the max.
864        assert!(result.stats.outer_iterations <= 5);
865        assert!(result.stats.outer_iterations >= 1);
866    }
867
868    // ── Benchmark tests ──────────────────────────────────────────────────────
869
870    #[test]
871    fn test_benchmark_runs_n_times() {
872        let pipeline = CompilerPipeline::with_default();
873        let expr = simple_constant_expr();
874        let benchmarks = pipeline.benchmark(expr, 4);
875        // There should be one entry per distinct pass in the canonical order.
876        let order_len = CompilerPassOrder::CanonicalOrder.to_pass_list().len();
877        assert_eq!(benchmarks.len(), order_len);
878        // Each benchmark should have been executed at least `runs` times
879        // (outer iterations may cause more executions per pipeline run).
880        for b in &benchmarks {
881            assert!(
882                b.runs >= 4,
883                "expected >=4 runs for {}, got {}",
884                b.pass_id,
885                b.runs
886            );
887        }
888    }
889
890    #[test]
891    fn test_pass_benchmark_summary_nonempty() {
892        let pipeline = CompilerPipeline::with_default();
893        let benchmarks = pipeline.benchmark(simple_constant_expr(), 2);
894        for b in &benchmarks {
895            let summary = b.summary();
896            assert!(!summary.is_empty());
897        }
898    }
899
900    // ── Idempotency test ─────────────────────────────────────────────────────
901
902    #[test]
903    fn test_pipeline_idempotent() {
904        let pipeline = CompilerPipeline::with_default();
905        let expr = non_trivial_expr();
906        let first = pipeline.run(expr);
907        let second = pipeline.run(first.expr.clone());
908        // Running a second time on an already-optimised expression should yield
909        // the same (or smaller) node count.
910        assert!(
911            second.stats.final_node_count <= first.stats.final_node_count,
912            "second run produced more nodes than first"
913        );
914    }
915
916    // ── AggressiveFold order test ────────────────────────────────────────────
917
918    #[test]
919    fn test_compiler_pipeline_aggressive_fold() {
920        let cfg = CompilerPipelineConfig {
921            pass_order: CompilerPassOrder::AggressiveFold,
922            max_outer_iterations: 2,
923            ..CompilerPipelineConfig::default()
924        };
925        let pipeline = CompilerPipeline::new(cfg);
926        let result = pipeline.run(simple_constant_expr());
927        assert!(result.stats.outer_iterations >= 1);
928    }
929
930    // ── PassBenchmark min ≤ mean ≤ max invariant ────────────────────────────
931
932    #[test]
933    fn test_benchmark_timing_invariants() {
934        let pipeline = CompilerPipeline::with_default();
935        let benchmarks = pipeline.benchmark(simple_constant_expr(), 3);
936        for b in &benchmarks {
937            assert!(b.min_ns <= b.mean_ns, "min_ns > mean_ns for {}", b.pass_id);
938            assert!(b.mean_ns <= b.max_ns, "mean_ns > max_ns for {}", b.pass_id);
939        }
940    }
941}