vortex_expr/traversal/
fold.rs1use vortex_error::VortexResult;
5
6use crate::traversal::Node;
7
/// Control-flow signal returned by [`NodeFolder::visit_down`] on the downward pass of a fold.
///
/// The driver that interprets these signals lives outside this file. The meanings below are
/// the ones exercised by this module's tests:
///
/// * `Continue` - carry on with the traversal; no value is produced yet.
/// * `Stop` - end the fold; the carried value becomes the fold's result.
/// * `Skip` - do not visit the current node's children; the carried value stands in for the
///   result of this subtree.
#[derive(Debug)]
pub enum FoldDown<R> {
    /// Keep traversing.
    Continue,
    /// Halt the fold with this result.
    Stop(R),
    /// Bypass the children of the current node, using this result in their place.
    Skip(R),
}
18
/// Control-flow signal returned by [`NodeFolderContext::visit_down`].
///
/// This has the same three variants as [`FoldDown`], except that `Continue` additionally
/// carries a context value of type `C`.
///
/// NOTE(review): presumably the context returned in `Continue` is the one handed to the
/// visits of the current node's children - confirm against the fold driver, which is not
/// part of this file.
#[derive(Debug)]
pub enum FoldDownContext<C, R> {
    /// Keep traversing, carrying this context.
    Continue(C),
    /// Halt the fold with this result.
    Stop(R),
    /// Bypass the children of the current node, using this result in their place.
    Skip(R),
}
28
/// Outcome of an upward visit (`visit_up`) during a fold.
///
/// Both variants carry a result of type `R`; the variant itself tells the driver whether the
/// fold should keep going (`Continue`) or end (`Stop`).
#[derive(Debug)]
pub enum FoldUp<R> {
    /// A result was produced and the fold may proceed.
    Continue(R),
    /// A result was produced and the fold should end.
    Stop(R),
}

impl<R> FoldUp<R> {
    /// Consumes the signal and returns the result it carries, regardless of which variant
    /// it was.
    pub fn value(self) -> R {
        // Every variant wraps exactly one `R`, so a single irrefutable or-pattern covers both.
        let (Self::Continue(payload) | Self::Stop(payload)) = self;
        payload
    }
}
45
46pub trait NodeFolderContext {
59 type NodeTy: Node;
60 type Result;
61 type Context;
62
63 fn visit_down(
68 &mut self,
69 _ctx: &Self::Context,
70 _node: &Self::NodeTy,
71 ) -> VortexResult<FoldDownContext<Self::Context, Self::Result>>;
72
73 fn visit_up(
77 &mut self,
78 _node: Self::NodeTy,
79 _context: &Self::Context,
80 _children: Vec<Self::Result>,
81 ) -> VortexResult<FoldUp<Self::Result>>;
82}
83
84pub trait NodeFolder {
86 type NodeTy: Node;
87 type Result;
88
89 fn visit_down(&mut self, _node: &Self::NodeTy) -> VortexResult<FoldDown<Self::Result>> {
94 Ok(FoldDown::Continue)
95 }
96
97 fn visit_up(
101 &mut self,
102 _node: Self::NodeTy,
103 _children: Vec<Self::Result>,
104 ) -> VortexResult<FoldUp<Self::Result>>;
105}
106
107pub(crate) struct NodeFolderContextWrapper<'a, T>
108where
109 T: NodeFolder,
110{
111 pub inner: &'a mut T,
112}
113
114impl<T: NodeFolder> NodeFolderContext for NodeFolderContextWrapper<'_, T> {
115 type NodeTy = T::NodeTy;
116 type Result = T::Result;
117 type Context = ();
118
119 fn visit_down(
120 &mut self,
121 _ctx: &Self::Context,
122 _node: &Self::NodeTy,
123 ) -> VortexResult<FoldDownContext<Self::Context, Self::Result>> {
124 match self.inner.visit_down(_node)? {
125 FoldDown::Continue => Ok(FoldDownContext::Continue(())),
126 FoldDown::Stop(r) => Ok(FoldDownContext::Stop(r)),
127 FoldDown::Skip(r) => Ok(FoldDownContext::Skip(r)),
128 }
129 }
130
131 fn visit_up(
132 &mut self,
133 _node: Self::NodeTy,
134 _context: &Self::Context,
135 _children: Vec<Self::Result>,
136 ) -> VortexResult<FoldUp<Self::Result>> {
137 self.inner.visit_up(_node, _children)
138 }
139}
140
#[cfg(test)]
mod tests {
    use vortex_error::VortexExpect;

    use super::*;
    use crate::traversal::NodeExt;
    use crate::{
        BinaryVTable, ExprRef, LiteralVTable, Operator, checked_add, gt, lit, vortex_bail,
    };

    /// Returns the `i32` held by `node` when it is a literal, or `None` for any other node.
    ///
    /// Panics (through `vortex_expect`) if the literal does not yield an `i32`.
    fn literal_i32(node: &ExprRef) -> Option<i32> {
        node.as_opt::<LiteralVTable>().map(|literal| {
            literal
                .value()
                .as_primitive()
                .typed_value::<i32>()
                .vortex_expect("i32")
        })
    }

    /// Test folder that sums `i32` literals joined by `+`.
    ///
    /// On the way down it stops the fold at a literal `5` and skips any `>` node with a
    /// result of `0`, which gives the tests one case for each [`FoldDown`] variant.
    struct AddFold;

    impl NodeFolder for AddFold {
        type NodeTy = ExprRef;
        type Result = i32;

        fn visit_down(&mut self, node: &Self::NodeTy) -> VortexResult<FoldDown<Self::Result>> {
            // A literal 5 ends the fold immediately with 5 as the result.
            if literal_i32(node) == Some(5) {
                return Ok(FoldDown::Stop(5));
            }

            // A `>` comparison is not summed; it is skipped and counted as 0.
            let is_greater_than = node
                .as_opt::<BinaryVTable>()
                .is_some_and(|binary| binary.op() == Operator::Gt);
            if is_greater_than {
                return Ok(FoldDown::Skip(0));
            }

            Ok(FoldDown::Continue)
        }

        fn visit_up(
            &mut self,
            node: Self::NodeTy,
            children: Vec<Self::Result>,
        ) -> VortexResult<FoldUp<Self::Result>> {
            // Literals evaluate to themselves.
            if let Some(value) = literal_i32(&node) {
                return Ok(FoldUp::Continue(value));
            }

            // Anything else must be a binary expression...
            let Some(binary) = node.as_opt::<BinaryVTable>() else {
                vortex_bail!("not a valid type")
            };

            // ...and the only binary operator this folder can combine is `+`.
            if binary.op() != Operator::Add {
                vortex_bail!("not a valid operator");
            }

            Ok(FoldUp::Continue(children[0] + children[1]))
        }
    }

    /// `(1 + 2) + 3` folds to 6 when no node triggers `Stop` or `Skip`.
    #[test]
    fn test_fold() {
        let folded = checked_add(checked_add(lit(1), lit(2)), lit(3))
            .fold(&mut AddFold)
            .unwrap()
            .value();
        assert_eq!(folded, 6);
    }

    /// Meeting the literal `5` stops the fold, and 5 is the overall result.
    #[test]
    fn test_stop_value() {
        let folded = checked_add(checked_add(lit(1), lit(5)), lit(3))
            .fold(&mut AddFold)
            .unwrap()
            .value();
        assert_eq!(folded, 5);
    }

    /// The `>` subtree is skipped with 0, so only the `3` contributes to the sum.
    #[test]
    fn test_skip_value() {
        let folded = checked_add(gt(lit(1), lit(2)), lit(3))
            .fold(&mut AddFold)
            .unwrap()
            .value();
        assert_eq!(folded, 3);
    }

    /// The literal `5` sits beneath a skipped `>` node; the expected total of 3 shows that the
    /// skip keeps the fold from reaching it and stopping.
    #[test]
    fn test_control_flow_value() {
        let folded = checked_add(gt(lit(1), lit(5)), lit(3))
            .fold(&mut AddFold)
            .unwrap()
            .value();
        assert_eq!(folded, 3);
    }
}