1mod references;
2mod visitor;
3
4pub use references::ReferenceCollector;
5pub use visitor::{pre_order_visit_down, pre_order_visit_up};
6use vortex_error::VortexResult;
7
8use crate::ExprRef;
9
/// Control signal returned by visitor callbacks to steer a tree traversal.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TraversalOrder {
    /// Returned from a downward visit: do not descend into the children of the current node.
    Skip,

    /// Abort the traversal; no further nodes are visited.
    Stop,

    /// Carry on with the traversal as normal.
    Continue,
}
30
31#[derive(Debug, Clone)]
32pub struct TransformResult<T> {
33 result: T,
34 order: TraversalOrder,
35 changed: bool,
36}
37
38impl<T> TransformResult<T> {
39 pub fn yes(result: T) -> Self {
40 Self {
41 result,
42 order: TraversalOrder::Continue,
43 changed: true,
44 }
45 }
46
47 pub fn no(result: T) -> Self {
48 Self {
49 result,
50 order: TraversalOrder::Continue,
51 changed: false,
52 }
53 }
54
55 pub fn into_inner(self) -> T {
56 self.result
57 }
58}
59
60pub trait NodeVisitor<'a> {
61 type NodeTy: Node;
62
63 fn visit_down(&mut self, _node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
64 Ok(TraversalOrder::Continue)
65 }
66
67 fn visit_up(&mut self, _node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
68 Ok(TraversalOrder::Continue)
69 }
70}
71
72pub trait MutNodeVisitor {
73 type NodeTy: Node;
74
75 fn visit_down(&mut self, _node: &Self::NodeTy) -> VortexResult<TraversalOrder> {
76 Ok(TraversalOrder::Continue)
77 }
78
79 fn visit_up(&mut self, _node: Self::NodeTy) -> VortexResult<TransformResult<Self::NodeTy>>;
80}
81
/// Decision returned by the downward step of a fold.
///
/// Derives `Debug` (as `FoldUp` does) so fold decisions can be logged and used in assertions.
#[derive(Debug)]
pub enum FoldDown<Out, Context> {
    /// Abort the whole fold, yielding this output.
    Abort(Out),
    /// Do not descend into the children; use this output for the current node.
    SkipChildren(Out),
    /// Descend into the children, passing them this context.
    Continue(Context),
}
90
/// Decision returned by the upward step of a fold.
#[derive(Debug)]
pub enum FoldUp<Out> {
    /// Abort the whole fold, yielding this output.
    Abort(Out),
    /// Keep folding, contributing this output to the parent.
    Continue(Out),
}

impl<Out> FoldUp<Out> {
    /// Extracts the carried output, regardless of which variant holds it.
    pub fn result(self) -> Out {
        match self {
            Self::Abort(value) | Self::Continue(value) => value,
        }
    }
}
107
108pub trait Folder<'a> {
109 type NodeTy: Node;
110 type Out;
111 type Context: Clone;
112
113 fn visit_down(
114 &mut self,
115 _node: &'a Self::NodeTy,
116 context: Self::Context,
117 ) -> VortexResult<FoldDown<Self::Out, Self::Context>> {
118 Ok(FoldDown::Continue(context))
119 }
120
121 fn visit_up(
122 &mut self,
123 node: &'a Self::NodeTy,
124 context: Self::Context,
125 children: Vec<Self::Out>,
126 ) -> VortexResult<FoldUp<Self::Out>>;
127}
128
129pub trait FolderMut {
130 type NodeTy: Node;
131 type Out;
132 type Context: Clone;
133
134 fn visit_down(
135 &mut self,
136 _node: &Self::NodeTy,
137 context: Self::Context,
138 ) -> VortexResult<FoldDown<Self::Out, Self::Context>> {
139 Ok(FoldDown::Continue(context))
140 }
141
142 fn visit_up(
143 &mut self,
144 node: Self::NodeTy,
145 context: Self::Context,
146 children: Vec<Self::Out>,
147 ) -> VortexResult<FoldUp<Self::Out>>;
148}
149
150pub trait Node: Sized {
151 fn accept<'a, V: NodeVisitor<'a, NodeTy = Self>>(
152 &'a self,
153 _visitor: &mut V,
154 ) -> VortexResult<TraversalOrder>;
155
156 fn accept_with_context<'a, V: Folder<'a, NodeTy = Self>>(
157 &'a self,
158 visitor: &mut V,
159 context: V::Context,
160 ) -> VortexResult<FoldUp<V::Out>>;
161
162 fn transform<V: MutNodeVisitor<NodeTy = Self>>(
163 self,
164 _visitor: &mut V,
165 ) -> VortexResult<TransformResult<Self>>;
166
167 fn transform_with_context<V: FolderMut<NodeTy = Self>>(
168 self,
169 _visitor: &mut V,
170 _context: V::Context,
171 ) -> VortexResult<FoldUp<V::Out>>;
172}
173
174impl Node for ExprRef {
175 fn accept<'a, V: NodeVisitor<'a, NodeTy = ExprRef>>(
177 &'a self,
178 visitor: &mut V,
179 ) -> VortexResult<TraversalOrder> {
180 let mut ord = visitor.visit_down(self)?;
181 if ord == TraversalOrder::Stop {
182 return Ok(TraversalOrder::Stop);
183 }
184 if ord == TraversalOrder::Skip {
185 return Ok(TraversalOrder::Continue);
186 }
187 for child in self.children() {
188 if ord != TraversalOrder::Continue {
189 return Ok(ord);
190 }
191 ord = child.accept(visitor)?;
192 }
193 if ord == TraversalOrder::Stop {
194 return Ok(TraversalOrder::Stop);
195 }
196 visitor.visit_up(self)
197 }
198
199 fn accept_with_context<'a, V: Folder<'a, NodeTy = Self>>(
200 &'a self,
201 visitor: &mut V,
202 context: V::Context,
203 ) -> VortexResult<FoldUp<V::Out>> {
204 let children = match visitor.visit_down(self, context.clone())? {
205 FoldDown::Abort(out) => return Ok(FoldUp::Abort(out)),
206 FoldDown::SkipChildren(out) => return Ok(FoldUp::Continue(out)),
207 FoldDown::Continue(child_context) => {
208 let mut new_children = Vec::with_capacity(self.children().len());
209 for child in self.children() {
210 match child.accept_with_context(visitor, child_context.clone())? {
211 FoldUp::Abort(out) => return Ok(FoldUp::Abort(out)),
212 FoldUp::Continue(out) => new_children.push(out),
213 }
214 }
215 new_children
216 }
217 };
218
219 visitor.visit_up(self, context, children)
220 }
221
222 fn transform<V: MutNodeVisitor<NodeTy = Self>>(
224 self,
225 visitor: &mut V,
226 ) -> VortexResult<TransformResult<Self>> {
227 match visitor.visit_down(&self)? {
228 TraversalOrder::Stop => Ok(TransformResult {
229 result: self,
230 order: TraversalOrder::Stop,
231 changed: false,
232 }),
233 TraversalOrder::Skip => visitor.visit_up(self),
234 TraversalOrder::Continue => {
235 let mut new_children = Vec::with_capacity(self.children().len());
236 let mut changed = false;
237 let mut stopped = false;
238 for child in self.children() {
239 if stopped {
240 new_children.push(child.clone());
241 continue;
242 }
243 let TransformResult {
244 result: new_child,
245 order: child_order,
246 changed: child_changed,
247 } = child.clone().transform(visitor)?;
248 new_children.push(new_child);
249 changed |= child_changed;
250 stopped |= child_order == TraversalOrder::Stop;
251 }
252
253 if changed {
254 let up = visitor.visit_up(self.replacing_children(new_children))?;
255 Ok(TransformResult::yes(up.result))
256 } else {
257 visitor.visit_up(self)
258 }
259 }
260 }
261 }
262
263 fn transform_with_context<V: FolderMut<NodeTy = Self>>(
264 self,
265 visitor: &mut V,
266 context: V::Context,
267 ) -> VortexResult<FoldUp<V::Out>> {
268 let children = match visitor.visit_down(&self, context.clone())? {
269 FoldDown::Abort(out) => return Ok(FoldUp::Abort(out)),
270 FoldDown::SkipChildren(out) => return Ok(FoldUp::Continue(out)),
271 FoldDown::Continue(child_context) => {
272 let mut new_children = Vec::with_capacity(self.children().len());
273 for child in self.children() {
274 match child
275 .clone()
276 .transform_with_context(visitor, child_context.clone())?
277 {
278 FoldUp::Abort(out) => return Ok(FoldUp::Abort(out)),
279 FoldUp::Continue(out) => new_children.push(out),
280 }
281 }
282 new_children
283 }
284 };
285
286 visitor.visit_up(self, context, children)
287 }
288}
289
290#[cfg(test)]
291mod tests {
292 use std::sync::Arc;
293
294 use vortex_error::VortexResult;
295 use vortex_utils::aliases::hash_set::HashSet;
296
297 use crate::traversal::visitor::pre_order_visit_down;
298 use crate::traversal::{MutNodeVisitor, Node, NodeVisitor, TransformResult, TraversalOrder};
299 use crate::{
300 BinaryExpr, ExprRef, FieldName, GetItem, Literal, Operator, VortexExpr, VortexExprExt, col,
301 is_root, root,
302 };
303
304 #[derive(Default)]
305 pub struct ExprLitCollector<'a>(pub Vec<&'a ExprRef>);
306
307 impl<'a> NodeVisitor<'a> for ExprLitCollector<'a> {
308 type NodeTy = ExprRef;
309
310 fn visit_down(&mut self, node: &'a ExprRef) -> VortexResult<TraversalOrder> {
311 if node.as_any().is::<Literal>() {
312 self.0.push(node)
313 }
314 Ok(TraversalOrder::Continue)
315 }
316
317 fn visit_up(&mut self, _node: &'a ExprRef) -> VortexResult<TraversalOrder> {
318 Ok(TraversalOrder::Continue)
319 }
320 }
321
322 #[derive(Default)]
323 pub struct ExprColToLit(i32);
324
325 impl MutNodeVisitor for ExprColToLit {
326 type NodeTy = ExprRef;
327
328 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<Self::NodeTy>> {
329 let col = node.as_any().downcast_ref::<GetItem>();
330 if col.is_some() {
331 let id = self.0;
332 self.0 += 1;
333 Ok(TransformResult::yes(Literal::new_expr(id)))
334 } else {
335 Ok(TransformResult::no(node))
336 }
337 }
338 }
339
340 #[derive(Default)]
341 pub struct SkipDownVisitor;
342
343 impl MutNodeVisitor for SkipDownVisitor {
344 type NodeTy = ExprRef;
345
346 fn visit_down(&mut self, _node: &Self::NodeTy) -> VortexResult<TraversalOrder> {
347 Ok(TraversalOrder::Skip)
348 }
349
350 fn visit_up(&mut self, _node: Self::NodeTy) -> VortexResult<TransformResult<Self::NodeTy>> {
351 Ok(TransformResult::yes(root()))
352 }
353 }
354
355 #[test]
356 fn expr_deep_visitor_test() {
357 let col1: Arc<dyn VortexExpr> = col("col1");
358 let lit1 = Literal::new_expr(1);
359 let expr = BinaryExpr::new_expr(col1.clone(), Operator::Eq, lit1.clone());
360 let lit2 = Literal::new_expr(2);
361 let expr = BinaryExpr::new_expr(expr, Operator::And, lit2);
362 let mut printer = ExprLitCollector::default();
363 expr.accept(&mut printer).unwrap();
364 assert_eq!(printer.0.len(), 2);
365 }
366
367 #[test]
368 fn expr_deep_mut_visitor_test() {
369 let col1: Arc<dyn VortexExpr> = col("col1");
370 let col2: Arc<dyn VortexExpr> = col("col2");
371 let expr = BinaryExpr::new_expr(col1.clone(), Operator::Eq, col2.clone());
372 let lit2 = Literal::new_expr(2);
373 let expr = BinaryExpr::new_expr(expr, Operator::And, lit2);
374 let mut printer = ExprColToLit::default();
375 let new = expr.transform(&mut printer).unwrap();
376 assert!(new.changed);
377
378 let expr = new.result;
379
380 let mut printer = ExprLitCollector::default();
381 expr.accept(&mut printer).unwrap();
382 assert_eq!(printer.0.len(), 3);
383 }
384
385 #[test]
386 fn expr_skip_test() {
387 let col1: Arc<dyn VortexExpr> = col("col1");
388 let col2: Arc<dyn VortexExpr> = col("col2");
389 let expr1 = BinaryExpr::new_expr(col1.clone(), Operator::Eq, col2.clone());
390 let col3: Arc<dyn VortexExpr> = col("col3");
391 let col4: Arc<dyn VortexExpr> = col("col4");
392 let expr2 = BinaryExpr::new_expr(col3.clone(), Operator::NotEq, col4.clone());
393 let expr = BinaryExpr::new_expr(expr1, Operator::And, expr2);
394
395 let mut nodes = Vec::new();
396 expr.accept(&mut pre_order_visit_down(|node: &ExprRef| {
397 if node.as_any().is::<GetItem>() {
398 nodes.push(node)
399 }
400 if let Some(bin) = node.as_any().downcast_ref::<BinaryExpr>() {
401 if bin.op() == Operator::Eq {
402 return Ok(TraversalOrder::Skip);
403 }
404 }
405 Ok(TraversalOrder::Continue)
406 }))
407 .unwrap();
408
409 assert_eq!(
410 nodes
411 .into_iter()
412 .map(|x| x.references())
413 .fold(HashSet::new(), |acc, x| acc.union(&x).cloned().collect()),
414 HashSet::from_iter(vec![FieldName::from("col3"), FieldName::from("col4")])
415 );
416 }
417
418 #[test]
419 fn expr_stop_test() {
420 let col1: Arc<dyn VortexExpr> = col("col1");
421 let col2: Arc<dyn VortexExpr> = col("col2");
422 let expr1 = BinaryExpr::new_expr(col1.clone(), Operator::Eq, col2.clone());
423 let col3: Arc<dyn VortexExpr> = col("col3");
424 let col4: Arc<dyn VortexExpr> = col("col4");
425 let expr2 = BinaryExpr::new_expr(col3.clone(), Operator::NotEq, col4.clone());
426 let expr = BinaryExpr::new_expr(expr1, Operator::And, expr2);
427
428 let mut nodes = Vec::new();
429 expr.accept(&mut pre_order_visit_down(|node: &ExprRef| {
430 if node.as_any().is::<GetItem>() {
431 nodes.push(node)
432 }
433 if let Some(bin) = node.as_any().downcast_ref::<BinaryExpr>() {
434 if bin.op() == Operator::Eq {
435 return Ok(TraversalOrder::Stop);
436 }
437 }
438 Ok(TraversalOrder::Continue)
439 }))
440 .unwrap();
441
442 assert_eq!(
443 nodes
444 .into_iter()
445 .map(|x| x.references())
446 .fold(HashSet::new(), |acc, x| acc.union(&x).cloned().collect()),
447 HashSet::from_iter(vec![])
448 );
449 }
450
451 #[test]
452 fn expr_skip_down_visit_up() {
453 let col = col("col");
454
455 let mut visitor = SkipDownVisitor;
456 let result = col.transform(&mut visitor).unwrap();
457
458 assert!(result.changed);
459 assert!(is_root(&result.result));
460 }
461}