1use crate::occ::{get_occ, occ_analysis, Occ};
2use tidepool_eval::{Changed, Pass};
3use tidepool_repr::{get_children, replace_subtree, CoreExpr, CoreFrame};
4
/// Optimisation pass that inlines non-recursive `let` bindings whose binder
/// occurrence analysis reports as used exactly once ([`Occ::Once`]).
///
/// Each call to `Pass::run` performs at most one rewrite; run the pass
/// repeatedly to inline nested bindings.
#[derive(Debug, Clone, Copy, Default)]
pub struct Inline;
7
8impl Pass for Inline {
9 fn run(&self, expr: &mut CoreExpr) -> Changed {
10 if expr.nodes.is_empty() {
11 return false;
12 }
13 let occ_map = occ_analysis(expr);
14 match try_inline(expr, &occ_map) {
15 Some(new_expr) => {
16 *expr = new_expr;
17 true
18 }
19 None => false,
20 }
21 }
22
23 fn name(&self) -> &str {
24 "Inline"
25 }
26}
27
28fn try_inline(expr: &CoreExpr, occ_map: &crate::occ::OccMap) -> Option<CoreExpr> {
29 try_inline_at(expr, expr.nodes.len() - 1, occ_map)
30}
31
32fn try_inline_at(expr: &CoreExpr, idx: usize, occ_map: &crate::occ::OccMap) -> Option<CoreExpr> {
33 match &expr.nodes[idx] {
34 CoreFrame::LetNonRec { binder, rhs, body } => {
35 if get_occ(occ_map, *binder) == Occ::Once {
36 let body_tree = expr.extract_subtree(*body);
38 let rhs_tree = expr.extract_subtree(*rhs);
39 let inlined = tidepool_repr::subst::subst(&body_tree, *binder, &rhs_tree);
40 Some(replace_subtree(expr, idx, &inlined))
41 } else {
42 try_inline_at(expr, *rhs, occ_map).or_else(|| try_inline_at(expr, *body, occ_map))
44 }
45 }
46 _ => try_children(expr, idx, occ_map),
48 }
49}
50
51fn try_children(expr: &CoreExpr, idx: usize, occ_map: &crate::occ::OccMap) -> Option<CoreExpr> {
52 let children = get_children(&expr.nodes[idx]);
53 for child in children {
54 if let Some(result) = try_inline_at(expr, child, occ_map) {
55 return Some(result);
56 }
57 }
58 None
59}
60
#[cfg(test)]
mod tests {
    use super::*;
    use tidepool_eval::{eval, Env, VecHeap};
    use tidepool_repr::{Literal, PrimOpKind, VarId};

    /// Builds a `CoreExpr` from a flat node list. The pass treats the last
    /// node as the root.
    fn tree(nodes: Vec<CoreFrame<usize>>) -> CoreExpr {
        CoreExpr { nodes }
    }

    /// An empty expression has nothing to inline and must not panic.
    #[test]
    fn test_inline_empty_unchanged() {
        let mut expr = tree(vec![]);
        let pass = Inline;
        assert!(!pass.run(&mut expr));
        assert!(expr.nodes.is_empty());
    }

    /// `let x = 42 in x` collapses to `42`.
    #[test]
    fn test_inline_single_use() {
        let x = VarId(1);
        let mut expr = tree(vec![
            CoreFrame::Lit(Literal::LitInt(42)),
            CoreFrame::Var(x),
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 1,
            },
        ]);
        let pass = Inline;
        let changed = pass.run(&mut expr);
        assert!(changed);
        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(42)));
    }

    /// `let x = 42 in x + x` uses `x` twice, so it is left alone.
    #[test]
    fn test_inline_multi_use_preserved() {
        let x = VarId(1);
        let mut expr = tree(vec![
            CoreFrame::Lit(Literal::LitInt(42)),
            CoreFrame::Var(x),
            CoreFrame::Var(x),
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![1, 2],
            },
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 3,
            },
        ]);
        let pass = Inline;
        let changed = pass.run(&mut expr);
        assert!(!changed);
    }

    /// `let x = 42 in 0` never uses `x`; this pass does not remove dead lets.
    #[test]
    fn test_inline_dead_preserved() {
        let x = VarId(1);
        let mut expr = tree(vec![
            CoreFrame::Lit(Literal::LitInt(42)),
            CoreFrame::Lit(Literal::LitInt(0)),
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 1,
            },
        ]);
        let pass = Inline;
        let changed = pass.run(&mut expr);
        assert!(!changed);
    }

    /// `let x = 1 in let y = x in y` needs two runs: one rewrite per call.
    #[test]
    fn test_inline_nested() {
        let x = VarId(1);
        let y = VarId(2);
        let mut expr = tree(vec![
            CoreFrame::Lit(Literal::LitInt(1)),
            CoreFrame::Var(x),
            CoreFrame::Var(y),
            CoreFrame::LetNonRec {
                binder: y,
                rhs: 1,
                body: 2,
            },
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 3,
            },
        ]);
        let pass = Inline;

        assert!(pass.run(&mut expr));
        assert!(pass.run(&mut expr));
        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(1)));
    }

    /// Recursive bindings are never inlined by this pass.
    #[test]
    fn test_inline_letrec_not_inlined() {
        let f = VarId(1);
        let mut expr = tree(vec![
            CoreFrame::Var(f),
            CoreFrame::Var(f),
            CoreFrame::LetRec {
                bindings: vec![(f, 0)],
                body: 1,
            },
        ]);
        let pass = Inline;
        let changed = pass.run(&mut expr);
        assert!(!changed);
    }

    /// Checks that `let x = y in \y -> x` inlines without capture.
    ///
    /// The lambda binder must be renamed so the inlined `y` still refers to
    /// the outer, free `y`.
    #[test]
    fn test_inline_capture_avoiding() {
        let x = VarId(1);
        let y = VarId(2);
        let mut expr = tree(vec![
            CoreFrame::Var(y),
            CoreFrame::Var(x),
            CoreFrame::Lam { binder: y, body: 1 },
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 2,
            },
        ]);
        let pass = Inline;
        let changed = pass.run(&mut expr);
        assert!(changed);

        let root = expr.nodes.len() - 1;
        if let CoreFrame::Lam { binder, body } = &expr.nodes[root] {
            assert_ne!(*binder, y);
            if let CoreFrame::Var(v) = &expr.nodes[*body] {
                assert_eq!(*v, y);
            } else {
                panic!("Body should be Var(y)");
            }
        } else {
            panic!("Result should be Lam");
        }
    }

    /// Inlining must not change the value an expression evaluates to.
    ///
    /// A multi-use binding must be left exactly as it was.
    #[test]
    fn test_inline_preserves_eval() {
        let x = VarId(1);

        let expr_once = tree(vec![
            CoreFrame::Lit(Literal::LitInt(21)),
            CoreFrame::Var(x),
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 1,
            },
        ]);
        let mut expr_once_reduced = expr_once.clone();
        // The rewrite must actually happen. Otherwise the evaluation
        // comparison below compares an expression with itself and passes
        // trivially.
        assert!(Inline.run(&mut expr_once_reduced));

        let mut heap = VecHeap::new();
        let env = Env::new();
        let v1 = eval(&expr_once, &env, &mut heap).unwrap();
        let v2 = eval(&expr_once_reduced, &env, &mut heap).unwrap();
        match (v1, v2) {
            (tidepool_eval::Value::Lit(l1), tidepool_eval::Value::Lit(l2)) => assert_eq!(l1, l2),
            _ => panic!("Expected literals"),
        }

        let mut expr_many = tree(vec![
            CoreFrame::Lit(Literal::LitInt(21)),
            CoreFrame::Var(x),
            CoreFrame::Var(x),
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![1, 2],
            },
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 3,
            },
        ]);
        let expr_many_orig = expr_many.clone();
        // A multi-use binding must be reported as unchanged and left intact.
        assert!(!Inline.run(&mut expr_many));
        assert_eq!(expr_many, expr_many_orig);
    }
}