vortex_array/expr/traversal/
fold.rs1use vortex_error::VortexResult;
5
6use crate::expr::traversal::Node;
7
/// Decision returned by [`NodeFolder::visit_down`], made before a node's children are visited.
///
/// The extra derives are conditional on `R`, so they cost nothing for result types that do
/// not implement them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FoldDown<R> {
    /// Descend into the node's children and fold them as usual.
    Continue,
    /// Abort the whole traversal; `R` becomes the result of the fold.
    Stop(R),
    /// Do not descend into this node; `R` is used as this node's folded result
    /// (its children are not visited and `visit_up` is not invoked for it).
    Skip(R),
}
18
/// Decision returned by [`NodeFolderContext::visit_down`].
///
/// This mirrors [`FoldDown`], except that `Continue` also carries a context value of type `C`
/// for the traversal to use while folding the node's children.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FoldDownContext<C, R> {
    /// Descend into the node's children, carrying `C` as their context.
    /// NOTE(review): presumably handed to each child's `visit_down`; confirm against the driver.
    Continue(C),
    /// Abort the whole traversal; `R` becomes the result of the fold.
    Stop(R),
    /// Do not descend into this node; `R` is used as this node's folded result.
    Skip(R),
}
28
/// Value produced by a `visit_up` call, after all of a node's children have been folded.
///
/// Both variants carry the node's result; they differ only in whether the traversal is asked
/// to keep going.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FoldUp<R> {
    /// `R` is this node's result and the traversal carries on.
    Continue(R),
    /// `R` is this node's result and the traversal is asked to stop
    /// (presumably ending the fold early, mirroring [`FoldDown::Stop`]).
    Stop(R),
}
36
37impl<R> FoldUp<R> {
38 pub fn value(self) -> R {
39 match self {
40 Self::Continue(r) => r,
41 Self::Stop(r) => r,
42 }
43 }
44}
45
46pub trait NodeFolderContext {
59 type NodeTy: Node;
60 type Result;
61 type Context;
62
63 fn visit_down(
68 &mut self,
69 _ctx: &Self::Context,
70 _node: &Self::NodeTy,
71 ) -> VortexResult<FoldDownContext<Self::Context, Self::Result>>;
72
73 fn visit_up(
77 &mut self,
78 _node: Self::NodeTy,
79 _context: &Self::Context,
80 _children: Vec<Self::Result>,
81 ) -> VortexResult<FoldUp<Self::Result>>;
82}
83
84pub trait NodeFolder {
86 type NodeTy: Node;
87 type Result;
88
89 fn visit_down(&mut self, _node: &Self::NodeTy) -> VortexResult<FoldDown<Self::Result>> {
94 Ok(FoldDown::Continue)
95 }
96
97 fn visit_up(
101 &mut self,
102 _node: Self::NodeTy,
103 _children: Vec<Self::Result>,
104 ) -> VortexResult<FoldUp<Self::Result>>;
105}
106
/// Adapts a context-free [`NodeFolder`] to the [`NodeFolderContext`] interface, using `()`
/// as the context.
///
/// The `T: NodeFolder` bound lives on the trait `impl` rather than on the struct, following
/// the Rust API guidelines (C-STRUCT-BOUNDS). This keeps the type itself unconstrained.
pub(crate) struct NodeFolderContextWrapper<'a, T> {
    /// The wrapped folder; callbacks are forwarded to it.
    pub inner: &'a mut T,
}
113
114impl<T: NodeFolder> NodeFolderContext for NodeFolderContextWrapper<'_, T> {
115 type NodeTy = T::NodeTy;
116 type Result = T::Result;
117 type Context = ();
118
119 fn visit_down(
120 &mut self,
121 _ctx: &Self::Context,
122 _node: &Self::NodeTy,
123 ) -> VortexResult<FoldDownContext<Self::Context, Self::Result>> {
124 match self.inner.visit_down(_node)? {
125 FoldDown::Continue => Ok(FoldDownContext::Continue(())),
126 FoldDown::Stop(r) => Ok(FoldDownContext::Stop(r)),
127 FoldDown::Skip(r) => Ok(FoldDownContext::Skip(r)),
128 }
129 }
130
131 fn visit_up(
132 &mut self,
133 _node: Self::NodeTy,
134 _context: &Self::Context,
135 _children: Vec<Self::Result>,
136 ) -> VortexResult<FoldUp<Self::Result>> {
137 self.inner.visit_up(_node, _children)
138 }
139}
140
#[cfg(test)]
mod tests {
    use vortex_error::{VortexExpect, vortex_bail};

    use super::*;
    use crate::expr::Expression;
    use crate::expr::exprs::binary::{Binary, checked_add, gt};
    use crate::expr::exprs::literal::{Literal, lit};
    use crate::expr::exprs::operators::Operator;
    use crate::expr::traversal::NodeExt;

    /// Returns the `i32` held by `node` if it is a literal expression, `None` otherwise.
    ///
    /// Panics (via `vortex_expect`) if the literal is not an `i32` primitive.
    fn literal_i32(node: &Expression) -> Option<i32> {
        let literal = node.as_opt::<Literal>()?;
        let value = literal
            .data()
            .as_primitive()
            .typed_value::<i32>()
            .vortex_expect("i32");
        Some(value)
    }

    /// Test folder that sums `i32` literals joined by `+`.
    ///
    /// On the way down:
    /// * a literal `5` stops the whole fold with result `5`;
    /// * a `>` comparison is skipped with result `0`, so its operands are never visited.
    struct AddFold;

    impl NodeFolder for AddFold {
        type NodeTy = Expression;
        type Result = i32;

        fn visit_down(&mut self, node: &Self::NodeTy) -> VortexResult<FoldDown<Self::Result>> {
            if literal_i32(node) == Some(5) {
                return Ok(FoldDown::Stop(5));
            }

            let is_gt = node
                .as_opt::<Binary>()
                .is_some_and(|binary| binary.operator() == Operator::Gt);

            Ok(if is_gt {
                FoldDown::Skip(0)
            } else {
                FoldDown::Continue
            })
        }

        fn visit_up(
            &mut self,
            node: Self::NodeTy,
            children: Vec<Self::Result>,
        ) -> VortexResult<FoldUp<Self::Result>> {
            if let Some(value) = literal_i32(&node) {
                return Ok(FoldUp::Continue(value));
            }

            let Some(binary) = node.as_opt::<Binary>() else {
                vortex_bail!("not a valid type")
            };

            if binary.operator() != Operator::Add {
                vortex_bail!("not a valid operator")
            }

            Ok(FoldUp::Continue(children[0] + children[1]))
        }
    }

    /// Runs a fresh [`AddFold`] over `expr` and returns the final folded value.
    fn fold_with_add(expr: Expression) -> i32 {
        let mut folder = AddFold;
        expr.fold(&mut folder).unwrap().value()
    }

    #[test]
    fn test_fold() {
        // (1 + 2) + 3 with no early exits.
        let total = fold_with_add(checked_add(checked_add(lit(1), lit(2)), lit(3)));
        assert_eq!(total, 6);
    }

    #[test]
    fn test_stop_value() {
        // The literal 5 stops the fold, so its value becomes the overall result.
        let total = fold_with_add(checked_add(checked_add(lit(1), lit(5)), lit(3)));
        assert_eq!(total, 5);
    }

    #[test]
    fn test_skip_value() {
        // The `>` subtree contributes 0 without being descended into.
        let total = fold_with_add(checked_add(gt(lit(1), lit(2)), lit(3)));
        assert_eq!(total, 3);
    }

    #[test]
    fn test_control_flow_value() {
        // The 5 under the skipped `>` is never visited, so it cannot trigger a stop.
        let total = fold_with_add(checked_add(gt(lit(1), lit(5)), lit(3)));
        assert_eq!(total, 3);
    }
}