Skip to main content

tidepool_optimize/
pipeline.rs

1use crate::beta::BetaReduce;
2use crate::case_reduce::CaseReduce;
3use crate::dce::Dce;
4use crate::inline::Inline;
5use crate::partial::PartialEval;
6use tidepool_eval::pass::{Changed, Pass};
7use tidepool_repr::CoreExpr;
8
/// Upper bound on fixed-point iterations used by the pipeline helpers in this module.
/// Hitting it is treated as a likely non-terminating pass set and causes a panic.
pub const MAX_PIPELINE_ITERATIONS: usize = 1_000;
11
/// Statistics collected by a single pipeline run.
///
/// Derives `PartialEq`/`Eq` so complete stats values can be compared directly
/// (for example with `assert_eq!` in tests).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PipelineStats {
    /// Number of pipeline iterations performed, counting the final iteration in
    /// which no pass reported a change. Zero when the pipeline had no passes.
    pub iterations: usize,
    /// `(pass name, invocation count)` for every pass, in the order the passes
    /// were supplied.
    pub pass_invocations: Vec<(String, usize)>,
}
21
22/// Run a sequence of passes to fixed point.
23/// Keeps iterating until no pass reports a change or MAX_PIPELINE_ITERATIONS is reached.
24/// Returns stats about how many iterations and per-pass invocations.
25///
26/// # Panics
27/// Panics if the number of iterations exceeds MAX_PIPELINE_ITERATIONS.
28pub fn run_pipeline(passes: &[Box<dyn Pass>], expr: &mut CoreExpr) -> PipelineStats {
29    let mut stats = PipelineStats {
30        iterations: 0,
31        pass_invocations: passes.iter().map(|p| (p.name().to_string(), 0)).collect(),
32    };
33
34    if passes.is_empty() {
35        return stats;
36    }
37
38    loop {
39        stats.iterations += 1;
40        if stats.iterations > MAX_PIPELINE_ITERATIONS {
41            panic!(
42                "Optimization pipeline exceeded maximum iterations ({}). Potential infinite loop in passes: {:?}",
43                MAX_PIPELINE_ITERATIONS,
44                passes.iter().map(|p| p.name()).collect::<Vec<_>>()
45            );
46        }
47
48        let mut changed: Changed = false;
49        for (i, pass) in passes.iter().enumerate() {
50            if pass.run(expr) {
51                changed = true;
52            }
53            stats.pass_invocations[i].1 += 1;
54        }
55
56        if !changed {
57            break;
58        }
59    }
60
61    stats
62}
63
64/// Returns the default optimization pass sequence.
65/// Order: BetaReduce → Inline → CaseReduce → Dce → PartialEval.
66pub fn default_passes() -> Vec<Box<dyn Pass>> {
67    vec![
68        Box::new(BetaReduce),
69        Box::new(Inline),
70        Box::new(CaseReduce),
71        Box::new(Dce),
72        Box::new(PartialEval),
73    ]
74}
75
76/// Run the default optimization pipeline to fixed point.
77pub fn optimize(expr: &mut CoreExpr) -> PipelineStats {
78    run_pipeline(&default_passes(), expr)
79}
80
81/// Run a single pass to fixed point (convenience).
82/// Returns the number of times the pass reported a change.
83pub fn run_pass_to_fixpoint(pass: &dyn Pass, expr: &mut CoreExpr) -> usize {
84    let mut changes = 0;
85    loop {
86        if !pass.run(expr) {
87            break;
88        }
89        changes += 1;
90        if changes >= MAX_PIPELINE_ITERATIONS {
91            panic!(
92                "Pass '{}' exceeded maximum iterations ({}) in run_pass_to_fixpoint.",
93                pass.name(),
94                MAX_PIPELINE_ITERATIONS
95            );
96        }
97    }
98    changes
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use tidepool_repr::{CoreFrame, RecursiveTree, VarId};
    use std::cell::Cell;

    /// Test pass that reports a change `changes_remaining` more times, then
    /// reports no change forever after. It never touches the expression.
    struct TestPass {
        name: String,
        changes_remaining: Cell<usize>,
    }

    impl Pass for TestPass {
        fn run(&self, _expr: &mut CoreExpr) -> Changed {
            let rem = self.changes_remaining.get();
            if rem > 0 {
                self.changes_remaining.set(rem - 1);
                true
            } else {
                false
            }
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    /// Builds a `TestPass` with the given name and number of changes to report.
    fn test_pass(name: &str, changes: usize) -> TestPass {
        TestPass {
            name: name.to_string(),
            changes_remaining: Cell::new(changes),
        }
    }

    fn dummy_expr() -> CoreExpr {
        RecursiveTree {
            nodes: vec![CoreFrame::Var(VarId(0))],
        }
    }

    #[test]
    fn test_empty_pipeline() {
        let mut expr = dummy_expr();
        let stats = run_pipeline(&[], &mut expr);
        assert_eq!(stats.iterations, 0);
        assert!(stats.pass_invocations.is_empty());
    }

    #[test]
    fn test_single_noop_pass() {
        let mut expr = dummy_expr();
        let pass = Box::new(test_pass("NoOp", 0));
        let stats = run_pipeline(&[pass], &mut expr);
        assert_eq!(stats.iterations, 1);
        assert_eq!(stats.pass_invocations[0], ("NoOp".to_string(), 1));
    }

    #[test]
    fn test_single_changing_pass() {
        let mut expr = dummy_expr();
        let pass = Box::new(test_pass("Changing", 1));
        let stats = run_pipeline(&[pass], &mut expr);
        assert_eq!(stats.iterations, 2);
        assert_eq!(stats.pass_invocations[0], ("Changing".to_string(), 2));
    }

    #[test]
    fn test_fixed_point_terminates() {
        let mut expr = dummy_expr();
        let n = 5;
        let pass = Box::new(test_pass("N-Times", n));
        let stats = run_pipeline(&[pass], &mut expr);
        assert_eq!(stats.iterations, n + 1);
        assert_eq!(stats.pass_invocations[0], ("N-Times".to_string(), n + 1));
    }

    #[test]
    fn test_pipeline_stats() {
        let mut expr = dummy_expr();
        let pass1 = Box::new(test_pass("P1", 2));
        let pass2 = Box::new(test_pass("P2", 1));
        let stats = run_pipeline(&[pass1, pass2], &mut expr);
        // Iteration 1: P1 changes (2->1), P2 changes (1->0). Changed = true.
        // Iteration 2: P1 changes (1->0), P2 no change. Changed = true.
        // Iteration 3: P1 no change, P2 no change. Changed = false. Break.
        assert_eq!(stats.iterations, 3);
        assert_eq!(stats.pass_invocations[0], ("P1".to_string(), 3));
        assert_eq!(stats.pass_invocations[1], ("P2".to_string(), 3));
    }

    #[test]
    fn test_pipeline_converges_at_iteration_limit() {
        // MAX - 1 changing iterations plus the final change-free iteration uses
        // exactly MAX iterations, which must not trip the panic.
        let mut expr = dummy_expr();
        let pass = Box::new(test_pass("Bounded", MAX_PIPELINE_ITERATIONS - 1));
        let stats = run_pipeline(&[pass], &mut expr);
        assert_eq!(stats.iterations, MAX_PIPELINE_ITERATIONS);
        assert_eq!(stats.pass_invocations[0].1, MAX_PIPELINE_ITERATIONS);
    }

    #[test]
    fn test_run_pass_to_fixpoint() {
        let mut expr = dummy_expr();
        let n = 3;
        let pass = test_pass("N-Times", n);
        let changes = run_pass_to_fixpoint(&pass, &mut expr);
        assert_eq!(changes, n);
    }

    #[test]
    fn test_run_pass_to_fixpoint_noop() {
        let mut expr = dummy_expr();
        let pass = test_pass("NoOp", 0);
        assert_eq!(run_pass_to_fixpoint(&pass, &mut expr), 0);
    }

    #[test]
    fn test_run_pass_to_fixpoint_below_limit() {
        // MAX - 1 is the largest change count that returns normally.
        let mut expr = dummy_expr();
        let pass = test_pass("Bounded", MAX_PIPELINE_ITERATIONS - 1);
        assert_eq!(
            run_pass_to_fixpoint(&pass, &mut expr),
            MAX_PIPELINE_ITERATIONS - 1
        );
    }

    #[test]
    #[should_panic(expected = "Pass 'Bounded' exceeded maximum iterations")]
    fn test_run_pass_to_fixpoint_panics_at_limit() {
        let mut expr = dummy_expr();
        let pass = test_pass("Bounded", MAX_PIPELINE_ITERATIONS);
        run_pass_to_fixpoint(&pass, &mut expr);
    }

    #[test]
    #[should_panic(expected = "Optimization pipeline exceeded maximum iterations")]
    fn test_infinite_loop_panic() {
        struct InfinitePass;
        impl Pass for InfinitePass {
            fn run(&self, _expr: &mut CoreExpr) -> Changed {
                true
            }
            fn name(&self) -> &str {
                "Infinite"
            }
        }
        let mut expr = dummy_expr();
        run_pipeline(&[Box::new(InfinitePass)], &mut expr);
    }
}