vortex_array/expr/traversal/
fold.rs1use vortex_error::VortexResult;
5
6use crate::expr::traversal::Node;
7
/// Decision returned by [`NodeFolder::visit_down`] before a node's children are folded.
///
/// The common traits are derived so decisions can be compared and duplicated;
/// each derived impl only applies when `R` itself implements that trait.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FoldDown<R> {
    /// Descend into the node's children and fold them.
    Continue,
    /// Abort the whole fold immediately; the carried value becomes the final result.
    Stop(R),
    /// Do not visit the node's children (nor call `visit_up` for this node); the
    /// carried value is used as this node's folded result and the fold carries on
    /// with the rest of the tree.
    Skip(R),
}
18
/// Context-aware counterpart of [`FoldDown`], returned by [`NodeFolderContext::visit_down`].
///
/// `Stop` and `Skip` mirror the corresponding [`FoldDown`] variants
/// (`NodeFolderContextWrapper` maps one directly onto the other). The common
/// traits are derived conditionally on `C` and `R`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FoldDownContext<C, R> {
    /// Descend into the node's children, carrying the context value `C`.
    // NOTE(review): presumably `C` is what the children's `visit_down` receives —
    // confirm against the traversal driver, which is not visible here.
    Continue(C),
    /// Abort the whole fold; the carried value becomes the final result.
    Stop(R),
    /// Do not visit the node's children; the carried value is this node's result.
    Skip(R),
}
28
/// Outcome of a `visit_up` call: the folded value for a node, tagged with
/// whether folding should go on.
///
/// The common traits are derived conditionally on `R`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FoldUp<R> {
    /// The node's folded value; folding continues normally.
    Continue(R),
    /// The node's folded value, together with a request to stop folding.
    // NOTE(review): how ancestors react to `Stop` is decided by the traversal
    // driver, which is not visible here — confirm before relying on it.
    Stop(R),
}

impl<R> FoldUp<R> {
    /// Consumes the outcome and returns the carried value, whichever variant it is.
    pub fn value(self) -> R {
        // Both variants carry the same payload type, so one irrefutable
        // or-pattern extracts it without duplicated match arms.
        let (Self::Continue(value) | Self::Stop(value)) = self;
        value
    }
}
45
46pub trait NodeFolderContext {
59 type NodeTy: Node;
60 type Result;
61 type Context;
62
63 fn visit_down(
68 &mut self,
69 _ctx: &Self::Context,
70 _node: &Self::NodeTy,
71 ) -> VortexResult<FoldDownContext<Self::Context, Self::Result>>;
72
73 fn visit_up(
77 &mut self,
78 _node: Self::NodeTy,
79 _context: &Self::Context,
80 _children: Vec<Self::Result>,
81 ) -> VortexResult<FoldUp<Self::Result>>;
82}
83
84pub trait NodeFolder {
86 type NodeTy: Node;
87 type Result;
88
89 fn visit_down(&mut self, _node: &Self::NodeTy) -> VortexResult<FoldDown<Self::Result>> {
94 Ok(FoldDown::Continue)
95 }
96
97 fn visit_up(
101 &mut self,
102 _node: Self::NodeTy,
103 _children: Vec<Self::Result>,
104 ) -> VortexResult<FoldUp<Self::Result>>;
105}
106
/// Adapts a [`NodeFolder`] into a [`NodeFolderContext`] whose context type is `()`.
///
/// The `T: NodeFolder` bound lives on the `NodeFolderContext` impl rather than on
/// the struct itself. Bounds on a struct definition must be repeated by every
/// user of the type and add nothing here.
pub(crate) struct NodeFolderContextWrapper<'a, T> {
    /// The wrapped folder that every visit is forwarded to.
    pub inner: &'a mut T,
}
113
114impl<T: NodeFolder> NodeFolderContext for NodeFolderContextWrapper<'_, T> {
115 type NodeTy = T::NodeTy;
116 type Result = T::Result;
117 type Context = ();
118
119 fn visit_down(
120 &mut self,
121 _ctx: &Self::Context,
122 _node: &Self::NodeTy,
123 ) -> VortexResult<FoldDownContext<Self::Context, Self::Result>> {
124 match self.inner.visit_down(_node)? {
125 FoldDown::Continue => Ok(FoldDownContext::Continue(())),
126 FoldDown::Stop(r) => Ok(FoldDownContext::Stop(r)),
127 FoldDown::Skip(r) => Ok(FoldDownContext::Skip(r)),
128 }
129 }
130
131 fn visit_up(
132 &mut self,
133 _node: Self::NodeTy,
134 _context: &Self::Context,
135 _children: Vec<Self::Result>,
136 ) -> VortexResult<FoldUp<Self::Result>> {
137 self.inner.visit_up(_node, _children)
138 }
139}
140
#[cfg(test)]
mod tests {
    use vortex_error::VortexExpect;
    use vortex_error::vortex_bail;

    use super::*;
    use crate::expr::Expression;
    use crate::expr::exprs::binary::Binary;
    use crate::expr::exprs::binary::checked_add;
    use crate::expr::exprs::binary::gt;
    use crate::expr::exprs::literal::Literal;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::operators::Operator;
    use crate::expr::traversal::NodeExt;

    /// Returns the `i32` payload if `node` is a literal, or `None` for any other node.
    ///
    /// Panics (via `vortex_expect`) if a literal does not hold an `i32`.
    fn literal_i32(node: &Expression) -> Option<i32> {
        let literal = node.as_opt::<Literal>()?;
        let value = literal
            .data()
            .as_primitive()
            .typed_value::<i32>()
            .vortex_expect("i32");
        Some(value)
    }

    /// Test folder that sums `Add` trees of `i32` literals.
    ///
    /// On the way down, a literal `5` stops the whole fold with value 5, and any
    /// `Gt` node is skipped with a result of 0.
    struct AddFold;

    impl NodeFolder for AddFold {
        type NodeTy = Expression;
        type Result = i32;

        fn visit_down(&mut self, node: &Self::NodeTy) -> VortexResult<FoldDown<Self::Result>> {
            if literal_i32(node) == Some(5) {
                return Ok(FoldDown::Stop(5));
            }

            let is_gt = node
                .as_opt::<Binary>()
                .is_some_and(|binary| binary.operator() == Operator::Gt);

            Ok(if is_gt {
                FoldDown::Skip(0)
            } else {
                FoldDown::Continue
            })
        }

        fn visit_up(
            &mut self,
            node: Self::NodeTy,
            children: Vec<Self::Result>,
        ) -> VortexResult<FoldUp<Self::Result>> {
            if let Some(value) = literal_i32(&node) {
                return Ok(FoldUp::Continue(value));
            }

            let Some(binary) = node.as_opt::<Binary>() else {
                vortex_bail!("not a valid type")
            };

            if binary.operator() != Operator::Add {
                vortex_bail!("not a valid operator")
            }

            Ok(FoldUp::Continue(children[0] + children[1]))
        }
    }

    /// Folds `expr` with a fresh [`AddFold`] and returns the final value.
    fn fold_sum(expr: Expression) -> i32 {
        let mut folder = AddFold;
        expr.fold(&mut folder).unwrap().value()
    }

    #[test]
    fn test_fold() {
        // (1 + 2) + 3
        assert_eq!(fold_sum(checked_add(checked_add(lit(1), lit(2)), lit(3))), 6);
    }

    #[test]
    fn test_stop_value() {
        // Reaching the literal 5 aborts the fold with that value.
        assert_eq!(fold_sum(checked_add(checked_add(lit(1), lit(5)), lit(3))), 5);
    }

    #[test]
    fn test_skip_value() {
        // The `Gt` subtree contributes 0 without its children being visited.
        assert_eq!(fold_sum(checked_add(gt(lit(1), lit(2)), lit(3))), 3);
    }

    #[test]
    fn test_control_flow_value() {
        // The 5 sits under a skipped `Gt`, so it never triggers a stop.
        assert_eq!(fold_sum(checked_add(gt(lit(1), lit(5)), lit(3))), 3);
    }
}