1use crate::occ::{get_occ, occ_analysis, Occ};
2use tidepool_eval::{Changed, Pass};
3use tidepool_repr::{CoreExpr, CoreFrame, MapLayer};
4use std::collections::HashMap;
5
6pub struct Inline;
8
9impl Pass for Inline {
10 fn run(&self, expr: &mut CoreExpr) -> Changed {
11 if expr.nodes.is_empty() {
12 return false;
13 }
14 let occ_map = occ_analysis(expr);
15 match try_inline(expr, &occ_map) {
16 Some(new_expr) => {
17 *expr = new_expr;
18 true
19 }
20 None => false,
21 }
22 }
23
24 fn name(&self) -> &str {
25 "Inline"
26 }
27}
28
29fn try_inline(expr: &CoreExpr, occ_map: &crate::occ::OccMap) -> Option<CoreExpr> {
30 try_inline_at(expr, expr.nodes.len() - 1, occ_map)
31}
32
33fn try_inline_at(expr: &CoreExpr, idx: usize, occ_map: &crate::occ::OccMap) -> Option<CoreExpr> {
34 match &expr.nodes[idx] {
35 CoreFrame::LetNonRec { binder, rhs, body } => {
36 if get_occ(occ_map, *binder) == Occ::Once {
37 let body_tree = expr.extract_subtree(*body);
39 let rhs_tree = expr.extract_subtree(*rhs);
40 let inlined = tidepool_repr::subst::subst(&body_tree, *binder, &rhs_tree);
41 Some(replace_subtree(expr, idx, &inlined))
42 } else {
43 try_inline_at(expr, *rhs, occ_map).or_else(|| try_inline_at(expr, *body, occ_map))
45 }
46 }
47 _ => try_children(expr, idx, occ_map),
49 }
50}
51
52fn try_children(expr: &CoreExpr, idx: usize, occ_map: &crate::occ::OccMap) -> Option<CoreExpr> {
53 let children = get_children(&expr.nodes[idx]);
54 for child in children {
55 if let Some(result) = try_inline_at(expr, child, occ_map) {
56 return Some(result);
57 }
58 }
59 None
60}
61
62fn get_children(frame: &CoreFrame<usize>) -> Vec<usize> {
63 match frame {
64 CoreFrame::Var(_) | CoreFrame::Lit(_) => vec![],
65 CoreFrame::App { fun, arg } => vec![*fun, *arg],
66 CoreFrame::Lam { body, .. } => vec![*body],
67 CoreFrame::LetNonRec { rhs, body, .. } => vec![*rhs, *body],
68 CoreFrame::LetRec { bindings, body, .. } => {
69 let mut c: Vec<usize> = bindings.iter().map(|(_, r)| *r).collect();
70 c.push(*body);
71 c
72 }
73 CoreFrame::Case {
74 scrutinee, alts, ..
75 } => {
76 let mut c = vec![*scrutinee];
77 for alt in alts {
78 c.push(alt.body);
79 }
80 c
81 }
82 CoreFrame::Con { fields, .. } => fields.clone(),
83 CoreFrame::Join { rhs, body, .. } => vec![*rhs, *body],
84 CoreFrame::Jump { args, .. } => args.clone(),
85 CoreFrame::PrimOp { args, .. } => args.clone(),
86 }
87}
88
89fn replace_subtree(expr: &CoreExpr, target_idx: usize, replacement: &CoreExpr) -> CoreExpr {
90 let mut new_nodes = Vec::new();
91 let mut old_to_new = HashMap::new();
92 rebuild(
93 expr,
94 expr.nodes.len() - 1,
95 target_idx,
96 replacement,
97 &mut new_nodes,
98 &mut old_to_new,
99 );
100 CoreExpr { nodes: new_nodes }
101}
102
103fn rebuild(
104 expr: &CoreExpr,
105 idx: usize,
106 target: usize,
107 replacement: &CoreExpr,
108 new_nodes: &mut Vec<CoreFrame<usize>>,
109 old_to_new: &mut HashMap<usize, usize>,
110) -> usize {
111 if let Some(&ni) = old_to_new.get(&idx) {
112 return ni;
113 }
114 if idx == target {
115 let offset = new_nodes.len();
116 for node in &replacement.nodes {
117 new_nodes.push(node.clone().map_layer(|i| i + offset));
118 }
119 let root = new_nodes.len() - 1;
120 old_to_new.insert(idx, root);
121 return root;
122 }
123 let mapped = expr.nodes[idx]
124 .clone()
125 .map_layer(|child| rebuild(expr, child, target, replacement, new_nodes, old_to_new));
126 let new_idx = new_nodes.len();
127 new_nodes.push(mapped);
128 old_to_new.insert(idx, new_idx);
129 new_idx
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use tidepool_eval::{eval, Env, VecHeap};
    use tidepool_repr::{Literal, PrimOpKind, VarId};

    /// Wraps a flat node list (root last) into a `CoreExpr`.
    fn tree(nodes: Vec<CoreFrame<usize>>) -> CoreExpr {
        CoreExpr { nodes }
    }

    /// `let x = 42 in x` collapses to the literal `42`.
    #[test]
    fn test_inline_single_use() {
        let x = VarId(1);
        let mut expr = tree(vec![
            CoreFrame::Lit(Literal::LitInt(42)),
            CoreFrame::Var(x),
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 1,
            },
        ]);
        let pass = Inline;
        let changed = pass.run(&mut expr);
        assert!(changed);
        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(42)));
    }

    /// A binder used twice (`let x = 42 in x + x`) is left alone.
    #[test]
    fn test_inline_multi_use_preserved() {
        let x = VarId(1);
        let mut expr = tree(vec![
            CoreFrame::Lit(Literal::LitInt(42)),
            CoreFrame::Var(x),
            CoreFrame::Var(x),
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![1, 2],
            },
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 3,
            },
        ]);
        let pass = Inline;
        let changed = pass.run(&mut expr);
        assert!(!changed);
    }

    /// An unused binder (`let x = 42 in 0`) is not this pass's job.
    #[test]
    fn test_inline_dead_preserved() {
        let x = VarId(1);
        let mut expr = tree(vec![
            CoreFrame::Lit(Literal::LitInt(42)),
            CoreFrame::Lit(Literal::LitInt(0)),
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 1,
            },
        ]);
        let pass = Inline;
        let changed = pass.run(&mut expr);
        assert!(!changed);
    }

    /// `let x = 1 in let y = x in y` needs two runs (one inlining per run)
    /// and ends as the literal `1`.
    #[test]
    fn test_inline_nested() {
        let x = VarId(1);
        let y = VarId(2);
        let mut expr = tree(vec![
            CoreFrame::Lit(Literal::LitInt(1)),
            CoreFrame::Var(x),
            CoreFrame::Var(y),
            CoreFrame::LetNonRec {
                binder: y,
                rhs: 1,
                body: 2,
            },
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 3,
            },
        ]);
        let pass = Inline;

        assert!(pass.run(&mut expr));
        assert!(pass.run(&mut expr));
        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(1)));
    }

    /// Recursive bindings are never inlined.
    #[test]
    fn test_inline_letrec_not_inlined() {
        let f = VarId(1);
        let mut expr = tree(vec![
            CoreFrame::Var(f),
            CoreFrame::Var(f),
            CoreFrame::LetRec {
                bindings: vec![(f, 0)],
                body: 1,
            },
        ]);
        let pass = Inline;
        let changed = pass.run(&mut expr);
        assert!(!changed);
    }

    /// `let x = y in \y. x` must rename the lambda binder so the inlined `y`
    /// still refers to the outer `y`.
    #[test]
    fn test_inline_capture_avoiding() {
        let x = VarId(1);
        let y = VarId(2);
        let mut expr = tree(vec![
            CoreFrame::Var(y),
            CoreFrame::Var(x),
            CoreFrame::Lam { binder: y, body: 1 },
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 2,
            },
        ]);
        let pass = Inline;
        let changed = pass.run(&mut expr);
        assert!(changed);

        let root = expr.nodes.len() - 1;
        if let CoreFrame::Lam { binder, body } = &expr.nodes[root] {
            assert_ne!(*binder, y);
            if let CoreFrame::Var(v) = &expr.nodes[*body] {
                assert_eq!(*v, y);
            } else {
                panic!("Body should be Var(y)");
            }
        } else {
            panic!("Result should be Lam");
        }
    }

    /// Inlining must not change the evaluated result, and the pass must report
    /// accurately whether it rewrote anything.
    #[test]
    fn test_inline_preserves_eval() {
        let x = VarId(1);

        let expr_once = tree(vec![
            CoreFrame::Lit(Literal::LitInt(21)),
            CoreFrame::Var(x),
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 1,
            },
        ]);
        let mut expr_once_reduced = expr_once.clone();
        // The single-use let is an inlining opportunity, so `run` must report
        // a change.
        assert!(Inline.run(&mut expr_once_reduced));

        let mut heap = VecHeap::new();
        let env = Env::new();
        let v1 = eval(&expr_once, &env, &mut heap).unwrap();
        let v2 = eval(&expr_once_reduced, &env, &mut heap).unwrap();
        match (v1, v2) {
            (tidepool_eval::Value::Lit(l1), tidepool_eval::Value::Lit(l2)) => assert_eq!(l1, l2),
            _ => panic!("Expected literals"),
        }

        let mut expr_many = tree(vec![
            CoreFrame::Lit(Literal::LitInt(21)),
            CoreFrame::Var(x),
            CoreFrame::Var(x),
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![1, 2],
            },
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 3,
            },
        ]);
        let expr_many_orig = expr_many.clone();
        // Multi-use binder: no rewrite, and `run` must say so.
        assert!(!Inline.run(&mut expr_many));
        assert_eq!(expr_many, expr_many_orig);
    }

    /// An empty expression is reported unchanged and left untouched.
    #[test]
    fn test_inline_empty_expr_unchanged() {
        let mut expr = tree(vec![]);
        assert!(!Inline.run(&mut expr));
        assert!(expr.nodes.is_empty());
    }
}