1use crate::frame::CoreFrame;
2use crate::types::Alt;
3use std::collections::HashMap;
4
/// A tree (or DAG) stored as a flat arena of frames.
///
/// Every frame refers to its children by index into `nodes`. The builders in
/// this module emit children before their parents, and `replace_subtree`
/// treats the last element as the root.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RecursiveTree<F> {
    /// The frames of the tree; child references are indices into this vector.
    pub nodes: Vec<F>,
}
10
11impl<F> RecursiveTree<F>
12where
13 F: MapLayer<usize, usize, Output = F> + Clone,
14{
15 pub fn extract_subtree(&self, idx: usize) -> Self {
17 let mut new_nodes = Vec::new();
18 let mut old_to_new = HashMap::new();
19
20 fn collect<F>(
21 idx: usize,
22 tree: &RecursiveTree<F>,
23 new_nodes: &mut Vec<F>,
24 old_to_new: &mut HashMap<usize, usize>,
25 ) -> usize
26 where
27 F: MapLayer<usize, usize, Output = F> + Clone,
28 {
29 if let Some(&new_idx) = old_to_new.get(&idx) {
30 return new_idx;
31 }
32
33 let frame = &tree.nodes[idx];
34 let mapped = frame
35 .clone()
36 .map_layer(|child| collect(child, tree, new_nodes, old_to_new));
37 let new_idx = new_nodes.len();
38 new_nodes.push(mapped);
39 old_to_new.insert(idx, new_idx);
40 new_idx
41 }
42
43 collect(idx, self, &mut new_nodes, &mut old_to_new);
44 RecursiveTree { nodes: new_nodes }
45 }
46}
47
48pub fn get_children(frame: &CoreFrame<usize>) -> Vec<usize> {
50 match frame {
51 CoreFrame::Var(_) | CoreFrame::Lit(_) => vec![],
52 CoreFrame::App { fun, arg } => vec![*fun, *arg],
53 CoreFrame::Lam { body, .. } => vec![*body],
54 CoreFrame::LetNonRec { rhs, body, .. } => vec![*rhs, *body],
55 CoreFrame::LetRec { bindings, body } => {
56 let mut c: Vec<usize> = bindings.iter().map(|(_, r)| *r).collect();
57 c.push(*body);
58 c
59 }
60 CoreFrame::Case {
61 scrutinee,
62 alts,
63 binder: _,
64 } => {
65 let mut c = vec![*scrutinee];
66 for alt in alts {
67 c.push(alt.body);
68 }
69 c
70 }
71 CoreFrame::Con { fields, .. } => fields.clone(),
72 CoreFrame::Join { rhs, body, .. } => vec![*rhs, *body],
73 CoreFrame::Jump { args, .. } => args.clone(),
74 CoreFrame::PrimOp { args, .. } => args.clone(),
75 }
76}
77
78pub fn replace_subtree(
80 expr: &RecursiveTree<CoreFrame<usize>>,
81 target_idx: usize,
82 replacement: &RecursiveTree<CoreFrame<usize>>,
83) -> RecursiveTree<CoreFrame<usize>> {
84 if expr.nodes.is_empty() {
85 return expr.clone();
86 }
87 if replacement.nodes.is_empty() {
88 return expr.clone();
90 }
91 assert!(
92 target_idx < expr.nodes.len(),
93 "target_idx {} out of bounds (len {})",
94 target_idx,
95 expr.nodes.len()
96 );
97
98 let mut new_nodes = Vec::new();
99 let mut old_to_new = HashMap::new();
100 rebuild(
101 expr,
102 expr.nodes.len() - 1,
103 target_idx,
104 replacement,
105 &mut new_nodes,
106 &mut old_to_new,
107 );
108 RecursiveTree { nodes: new_nodes }
109}
110
111fn rebuild(
112 expr: &RecursiveTree<CoreFrame<usize>>,
113 idx: usize,
114 target: usize,
115 replacement: &RecursiveTree<CoreFrame<usize>>,
116 new_nodes: &mut Vec<CoreFrame<usize>>,
117 old_to_new: &mut HashMap<usize, usize>,
118) -> usize {
119 if let Some(&ni) = old_to_new.get(&idx) {
120 return ni;
121 }
122 if idx == target {
123 let offset = new_nodes.len();
124 for node in &replacement.nodes {
125 new_nodes.push(node.clone().map_layer(|i| i + offset));
126 }
127 let root = new_nodes
128 .len()
129 .checked_sub(1)
130 .expect("replacement tree must not be empty");
131 old_to_new.insert(idx, root);
132 return root;
133 }
134 let mapped = expr.nodes[idx]
135 .clone()
136 .map_layer(|child| rebuild(expr, child, target, replacement, new_nodes, old_to_new));
137 let new_idx = new_nodes.len();
138 new_nodes.push(mapped);
139 old_to_new.insert(idx, new_idx);
140 new_idx
141}
142
/// One layer of a recursive structure whose child slots can be transformed
/// from `A` to `B` while leaving the rest of the layer intact.
pub trait MapLayer<A, B> {
    /// Consumes `self`, passes each child slot through `func`, and returns the
    /// rebuilt layer.
    fn map_layer(self, func: impl FnMut(A) -> B) -> Self::Output;

    /// The layer type produced once the children have been mapped to `B`.
    type Output;
}
148
149impl<A, B> MapLayer<A, B> for CoreFrame<A> {
150 type Output = CoreFrame<B>;
151 fn map_layer(self, mut f: impl FnMut(A) -> B) -> CoreFrame<B> {
152 match self {
153 CoreFrame::Var(v) => CoreFrame::Var(v),
154 CoreFrame::Lit(l) => CoreFrame::Lit(l),
155 CoreFrame::App { fun, arg } => CoreFrame::App {
156 fun: f(fun),
157 arg: f(arg),
158 },
159 CoreFrame::Lam { binder, body } => CoreFrame::Lam {
160 binder,
161 body: f(body),
162 },
163 CoreFrame::LetNonRec { binder, rhs, body } => CoreFrame::LetNonRec {
164 binder,
165 rhs: f(rhs),
166 body: f(body),
167 },
168 CoreFrame::LetRec { bindings, body } => CoreFrame::LetRec {
169 bindings: bindings.into_iter().map(|(id, rhs)| (id, f(rhs))).collect(),
170 body: f(body),
171 },
172 CoreFrame::Case {
173 scrutinee,
174 binder,
175 alts,
176 } => CoreFrame::Case {
177 scrutinee: f(scrutinee),
178 binder,
179 alts: alts
180 .into_iter()
181 .map(|alt| Alt {
182 con: alt.con,
183 binders: alt.binders,
184 body: f(alt.body),
185 })
186 .collect(),
187 },
188 CoreFrame::Con { tag, fields } => CoreFrame::Con {
189 tag,
190 fields: fields.into_iter().map(f).collect(),
191 },
192 CoreFrame::Join {
193 label,
194 params,
195 rhs,
196 body,
197 } => CoreFrame::Join {
198 label,
199 params,
200 rhs: f(rhs),
201 body: f(body),
202 },
203 CoreFrame::Jump { label, args } => CoreFrame::Jump {
204 label,
205 args: args.into_iter().map(f).collect(),
206 },
207 CoreFrame::PrimOp { op, args } => CoreFrame::PrimOp {
208 op,
209 args: args.into_iter().map(f).collect(),
210 },
211 }
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::*;

    /// One frame of each `CoreFrame` variant. Child indices are small
    /// placeholders; the frames do not form a consistent tree.
    fn sample_frames() -> Vec<CoreFrame<usize>> {
        vec![
            CoreFrame::Var(VarId(1)),
            CoreFrame::Lit(Literal::LitInt(42)),
            CoreFrame::App { fun: 0, arg: 1 },
            CoreFrame::Lam {
                binder: VarId(2),
                body: 0,
            },
            CoreFrame::LetNonRec {
                binder: VarId(3),
                rhs: 1,
                body: 2,
            },
            CoreFrame::LetRec {
                bindings: vec![(VarId(4), 0), (VarId(5), 1)],
                body: 2,
            },
            CoreFrame::Case {
                scrutinee: 0,
                binder: VarId(6),
                alts: vec![Alt {
                    con: AltCon::Default,
                    binders: vec![],
                    body: 1,
                }],
            },
            CoreFrame::Con {
                tag: DataConId(7),
                fields: vec![0, 1],
            },
            CoreFrame::Join {
                label: JoinId(8),
                params: vec![VarId(9)],
                rhs: 0,
                body: 1,
            },
            CoreFrame::Jump {
                label: JoinId(10),
                args: vec![0, 1],
            },
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![0, 1],
            },
        ]
    }

    #[test]
    fn test_get_children() {
        let frames = sample_frames();
        assert_eq!(get_children(&frames[0]), Vec::<usize>::new());
        assert_eq!(get_children(&frames[1]), Vec::<usize>::new());
        assert_eq!(get_children(&frames[2]), vec![0, 1]);
        assert_eq!(get_children(&frames[3]), vec![0]);
        assert_eq!(get_children(&frames[4]), vec![1, 2]);
        assert_eq!(get_children(&frames[5]), vec![0, 1, 2]);
        assert_eq!(get_children(&frames[6]), vec![0, 1]);
        assert_eq!(get_children(&frames[7]), vec![0, 1]);
        assert_eq!(get_children(&frames[8]), vec![0, 1]);
        assert_eq!(get_children(&frames[9]), vec![0, 1]);
        assert_eq!(get_children(&frames[10]), vec![0, 1]);
    }

    #[test]
    fn test_replace_subtree_root() {
        let nodes = vec![CoreFrame::Lit(Literal::LitInt(1))];
        let expr = RecursiveTree { nodes };
        let replacement = RecursiveTree {
            nodes: vec![CoreFrame::Lit(Literal::LitInt(2))],
        };
        let result = replace_subtree(&expr, 0, &replacement);
        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0], CoreFrame::Lit(Literal::LitInt(2)));
    }

    #[test]
    fn test_replace_subtree_nested() {
        let nodes = vec![
            CoreFrame::Var(VarId(1)),
            CoreFrame::Lit(Literal::LitInt(1)),
            CoreFrame::App { fun: 0, arg: 1 },
        ];
        let expr = RecursiveTree { nodes };

        let replacement = RecursiveTree {
            nodes: vec![CoreFrame::Lit(Literal::LitInt(2))],
        };
        let result = replace_subtree(&expr, 1, &replacement);

        let root_idx = result.nodes.len() - 1;
        if let CoreFrame::App { fun, arg } = &result.nodes[root_idx] {
            assert_eq!(result.nodes[*fun], CoreFrame::Var(VarId(1)));
            assert_eq!(result.nodes[*arg], CoreFrame::Lit(Literal::LitInt(2)));
        } else {
            panic!("Root should be App");
        }
    }

    #[test]
    fn test_replace_subtree_multi_node_replacement_is_reindexed() {
        let expr = RecursiveTree {
            nodes: vec![
                CoreFrame::Var(VarId(1)),
                CoreFrame::Lit(Literal::LitInt(1)),
                CoreFrame::App { fun: 0, arg: 1 },
            ],
        };
        let replacement = RecursiveTree {
            nodes: vec![
                CoreFrame::Lit(Literal::LitInt(2)),
                CoreFrame::Lit(Literal::LitInt(3)),
                CoreFrame::App { fun: 0, arg: 1 },
            ],
        };
        let result = replace_subtree(&expr, 1, &replacement);
        assert_eq!(
            result.nodes,
            vec![
                CoreFrame::Var(VarId(1)),
                CoreFrame::Lit(Literal::LitInt(2)),
                CoreFrame::Lit(Literal::LitInt(3)),
                CoreFrame::App { fun: 1, arg: 2 },
                CoreFrame::App { fun: 0, arg: 3 },
            ]
        );
    }

    #[test]
    fn test_replace_subtree_shared_target_spliced_once() {
        let expr = RecursiveTree {
            nodes: vec![
                CoreFrame::Lit(Literal::LitInt(1)),
                CoreFrame::App { fun: 0, arg: 0 },
            ],
        };
        let replacement = RecursiveTree {
            nodes: vec![CoreFrame::Lit(Literal::LitInt(2))],
        };
        let result = replace_subtree(&expr, 0, &replacement);
        assert_eq!(
            result.nodes,
            vec![
                CoreFrame::Lit(Literal::LitInt(2)),
                CoreFrame::App { fun: 0, arg: 0 },
            ]
        );
    }

    #[test]
    fn test_replace_subtree_empty_replacement_is_noop() {
        let expr = RecursiveTree {
            nodes: vec![CoreFrame::Lit(Literal::LitInt(1))],
        };
        let empty = RecursiveTree { nodes: Vec::new() };
        assert_eq!(replace_subtree(&expr, 0, &empty), expr);
    }

    #[test]
    fn test_extract_subtree_reindexes_children_first() {
        let tree = RecursiveTree {
            nodes: vec![
                CoreFrame::Lit(Literal::LitInt(1)),
                CoreFrame::Var(VarId(1)),
                CoreFrame::App { fun: 1, arg: 0 },
                CoreFrame::Lam {
                    binder: VarId(2),
                    body: 2,
                },
            ],
        };
        let sub = tree.extract_subtree(2);
        assert_eq!(
            sub.nodes,
            vec![
                CoreFrame::Var(VarId(1)),
                CoreFrame::Lit(Literal::LitInt(1)),
                CoreFrame::App { fun: 0, arg: 1 },
            ]
        );
    }

    #[test]
    fn test_extract_subtree_preserves_sharing() {
        let tree = RecursiveTree {
            nodes: vec![
                CoreFrame::Var(VarId(1)),
                CoreFrame::App { fun: 0, arg: 0 },
            ],
        };
        assert_eq!(tree.extract_subtree(1), tree);
        assert_eq!(tree.extract_subtree(0).nodes, vec![CoreFrame::Var(VarId(1))]);
    }

    #[test]
    fn test_map_layer_identity() {
        for frame in sample_frames() {
            let mapped = frame.clone().map_layer(|x| x);
            assert_eq!(frame, mapped);
        }
    }

    #[test]
    fn test_map_layer_composition() {
        let f = |x: usize| x + 10;
        let g = |x: usize| x * 2;

        for frame in sample_frames() {
            let direct = frame.clone().map_layer(|x| g(f(x)));
            let composed = frame.map_layer(f).map_layer(g);
            assert_eq!(direct, composed);
        }
    }

    #[test]
    fn test_recursive_tree_construction() {
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(42)),
            CoreFrame::Var(VarId(1)),
            CoreFrame::App { fun: 0, arg: 1 },
        ];
        let tree = RecursiveTree { nodes };

        assert_eq!(tree.nodes.len(), 3);
        if let CoreFrame::App { fun, arg } = &tree.nodes[2] {
            assert_eq!(*fun, 0);
            assert_eq!(*arg, 1);
        } else {
            panic!("Root should be an App");
        }
    }
}
363}