// vortex_expr/traversal/mod.rs

1mod references;
2mod visitor;
3
4pub use references::ReferenceCollector;
5pub use visitor::{pre_order_visit_down, pre_order_visit_up};
6use vortex_error::VortexResult;
7
8use crate::ExprRef;
9
/// Control-flow signal returned by visitor hooks to steer a traversal over a [`Node`] tree.
///
/// The traversal pattern is inspired by DataFusion and is, for now, only used for
/// expressions. Hooks are invoked on the way down (`visit_down`, pre-order) and again on
/// the way back up (`visit_up`). The variants are:
/// - `Skip`: skip visiting the children of the current node.
/// - `Stop`: stop visiting any more nodes in the traversal.
/// - `Continue`: continue with the traversal as expected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TraversalOrder {
    /// In a top-down traversal, skip visiting the children of the current node.
    /// In the bottom-up phase of the traversal this does nothing (for now).
    Skip,

    /// Stop visiting any more nodes in the traversal.
    Stop,

    /// Continue with the traversal as expected.
    Continue,
}
30
31#[derive(Debug, Clone)]
32pub struct TransformResult<T> {
33    result: T,
34    order: TraversalOrder,
35    changed: bool,
36}
37
38impl<T> TransformResult<T> {
39    pub fn yes(result: T) -> Self {
40        Self {
41            result,
42            order: TraversalOrder::Continue,
43            changed: true,
44        }
45    }
46
47    pub fn no(result: T) -> Self {
48        Self {
49            result,
50            order: TraversalOrder::Continue,
51            changed: false,
52        }
53    }
54
55    pub fn into_inner(self) -> T {
56        self.result
57    }
58}
59
60pub trait NodeVisitor<'a> {
61    type NodeTy: Node;
62
63    fn visit_down(&mut self, _node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
64        Ok(TraversalOrder::Continue)
65    }
66
67    fn visit_up(&mut self, _node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
68        Ok(TraversalOrder::Continue)
69    }
70}
71
72pub trait MutNodeVisitor {
73    type NodeTy: Node;
74
75    fn visit_down(&mut self, _node: &Self::NodeTy) -> VortexResult<TraversalOrder> {
76        Ok(TraversalOrder::Continue)
77    }
78
79    fn visit_up(&mut self, _node: Self::NodeTy) -> VortexResult<TransformResult<Self::NodeTy>>;
80}
81
/// Outcome of a fold's `visit_down` hook: decides whether, and with what context, the
/// traversal descends into the current node's children.
///
/// Derives `Debug` for parity with [`FoldUp`].
#[derive(Debug)]
pub enum FoldDown<Out, Context> {
    /// Abort the entire traversal and immediately return the result.
    Abort(Out),
    /// Skip the children of the current node. The value becomes this node's output and is
    /// handed to the parent's `visit_up`; this node's own `visit_up` is not called.
    SkipChildren(Out),
    /// Descend into the children, passing this context to each child's `visit_down`.
    Continue(Context),
}
90
/// Outcome of a fold's `visit_up` hook: carries the node's output and says whether the
/// traversal should keep going.
#[derive(Debug)]
pub enum FoldUp<Out> {
    /// Abort the entire traversal and immediately return the result.
    Abort(Out),
    /// Hand the value to the parent's `visit_up` and keep going.
    Continue(Out),
}

impl<Out> FoldUp<Out> {
    /// Extracts the carried value, whether or not the traversal was aborted.
    pub fn result(self) -> Out {
        match self {
            FoldUp::Abort(value) | FoldUp::Continue(value) => value,
        }
    }
}
107
108pub trait Folder<'a> {
109    type NodeTy: Node;
110    type Out;
111    type Context: Clone;
112
113    fn visit_down(
114        &mut self,
115        _node: &'a Self::NodeTy,
116        context: Self::Context,
117    ) -> VortexResult<FoldDown<Self::Out, Self::Context>> {
118        Ok(FoldDown::Continue(context))
119    }
120
121    fn visit_up(
122        &mut self,
123        node: &'a Self::NodeTy,
124        context: Self::Context,
125        children: Vec<Self::Out>,
126    ) -> VortexResult<FoldUp<Self::Out>>;
127}
128
129pub trait FolderMut {
130    type NodeTy: Node;
131    type Out;
132    type Context: Clone;
133
134    fn visit_down(
135        &mut self,
136        _node: &Self::NodeTy,
137        context: Self::Context,
138    ) -> VortexResult<FoldDown<Self::Out, Self::Context>> {
139        Ok(FoldDown::Continue(context))
140    }
141
142    fn visit_up(
143        &mut self,
144        node: Self::NodeTy,
145        context: Self::Context,
146        children: Vec<Self::Out>,
147    ) -> VortexResult<FoldUp<Self::Out>>;
148}
149
150pub trait Node: Sized {
151    fn accept<'a, V: NodeVisitor<'a, NodeTy = Self>>(
152        &'a self,
153        _visitor: &mut V,
154    ) -> VortexResult<TraversalOrder>;
155
156    fn accept_with_context<'a, V: Folder<'a, NodeTy = Self>>(
157        &'a self,
158        visitor: &mut V,
159        context: V::Context,
160    ) -> VortexResult<FoldUp<V::Out>>;
161
162    fn transform<V: MutNodeVisitor<NodeTy = Self>>(
163        self,
164        _visitor: &mut V,
165    ) -> VortexResult<TransformResult<Self>>;
166
167    fn transform_with_context<V: FolderMut<NodeTy = Self>>(
168        self,
169        _visitor: &mut V,
170        _context: V::Context,
171    ) -> VortexResult<FoldUp<V::Out>>;
172}
173
174impl Node for ExprRef {
175    /// A pre-order traversal.
176    fn accept<'a, V: NodeVisitor<'a, NodeTy = ExprRef>>(
177        &'a self,
178        visitor: &mut V,
179    ) -> VortexResult<TraversalOrder> {
180        let mut ord = visitor.visit_down(self)?;
181        if ord == TraversalOrder::Stop {
182            return Ok(TraversalOrder::Stop);
183        }
184        if ord == TraversalOrder::Skip {
185            return Ok(TraversalOrder::Continue);
186        }
187        for child in self.children() {
188            if ord != TraversalOrder::Continue {
189                return Ok(ord);
190            }
191            ord = child.accept(visitor)?;
192        }
193        if ord == TraversalOrder::Stop {
194            return Ok(TraversalOrder::Stop);
195        }
196        visitor.visit_up(self)
197    }
198
199    fn accept_with_context<'a, V: Folder<'a, NodeTy = Self>>(
200        &'a self,
201        visitor: &mut V,
202        context: V::Context,
203    ) -> VortexResult<FoldUp<V::Out>> {
204        let children = match visitor.visit_down(self, context.clone())? {
205            FoldDown::Abort(out) => return Ok(FoldUp::Abort(out)),
206            FoldDown::SkipChildren(out) => return Ok(FoldUp::Continue(out)),
207            FoldDown::Continue(child_context) => {
208                let mut new_children = Vec::with_capacity(self.children().len());
209                for child in self.children() {
210                    match child.accept_with_context(visitor, child_context.clone())? {
211                        FoldUp::Abort(out) => return Ok(FoldUp::Abort(out)),
212                        FoldUp::Continue(out) => new_children.push(out),
213                    }
214                }
215                new_children
216            }
217        };
218
219        visitor.visit_up(self, context, children)
220    }
221
222    /// A post-order transform, with an option to ignore sub-tress (using visit_down).
223    fn transform<V: MutNodeVisitor<NodeTy = Self>>(
224        self,
225        visitor: &mut V,
226    ) -> VortexResult<TransformResult<Self>> {
227        match visitor.visit_down(&self)? {
228            TraversalOrder::Stop => Ok(TransformResult {
229                result: self,
230                order: TraversalOrder::Stop,
231                changed: false,
232            }),
233            TraversalOrder::Skip => visitor.visit_up(self),
234            TraversalOrder::Continue => {
235                let mut new_children = Vec::with_capacity(self.children().len());
236                let mut changed = false;
237                let mut stopped = false;
238                for child in self.children() {
239                    if stopped {
240                        new_children.push(child.clone());
241                        continue;
242                    }
243                    let TransformResult {
244                        result: new_child,
245                        order: child_order,
246                        changed: child_changed,
247                    } = child.clone().transform(visitor)?;
248                    new_children.push(new_child);
249                    changed |= child_changed;
250                    stopped |= child_order == TraversalOrder::Stop;
251                }
252
253                if changed {
254                    let up = visitor.visit_up(self.replacing_children(new_children))?;
255                    Ok(TransformResult::yes(up.result))
256                } else {
257                    visitor.visit_up(self)
258                }
259            }
260        }
261    }
262
263    fn transform_with_context<V: FolderMut<NodeTy = Self>>(
264        self,
265        visitor: &mut V,
266        context: V::Context,
267    ) -> VortexResult<FoldUp<V::Out>> {
268        let children = match visitor.visit_down(&self, context.clone())? {
269            FoldDown::Abort(out) => return Ok(FoldUp::Abort(out)),
270            FoldDown::SkipChildren(out) => return Ok(FoldUp::Continue(out)),
271            FoldDown::Continue(child_context) => {
272                let mut new_children = Vec::with_capacity(self.children().len());
273                for child in self.children() {
274                    match child
275                        .clone()
276                        .transform_with_context(visitor, child_context.clone())?
277                    {
278                        FoldUp::Abort(out) => return Ok(FoldUp::Abort(out)),
279                        FoldUp::Continue(out) => new_children.push(out),
280                    }
281                }
282                new_children
283            }
284        };
285
286        visitor.visit_up(self, context, children)
287    }
288}
289
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_error::VortexResult;
    use vortex_utils::aliases::hash_set::HashSet;

    use crate::traversal::visitor::pre_order_visit_down;
    use crate::traversal::{MutNodeVisitor, Node, NodeVisitor, TransformResult, TraversalOrder};
    use crate::{
        BinaryExpr, ExprRef, FieldName, GetItem, Literal, Operator, VortexExpr, VortexExprExt, col,
        is_root, root,
    };

    /// Records every `Literal` node reached on the way down.
    #[derive(Default)]
    pub struct ExprLitCollector<'a>(pub Vec<&'a ExprRef>);

    impl<'a> NodeVisitor<'a> for ExprLitCollector<'a> {
        type NodeTy = ExprRef;

        fn visit_down(&mut self, node: &'a ExprRef) -> VortexResult<TraversalOrder> {
            if node.as_any().is::<Literal>() {
                self.0.push(node);
            }
            Ok(TraversalOrder::Continue)
        }
        // `visit_up` keeps the trait default, which continues the traversal.
    }

    /// Replaces every `GetItem` with a literal counter, numbered in `visit_up` order.
    #[derive(Default)]
    pub struct ExprColToLit(i32);

    impl MutNodeVisitor for ExprColToLit {
        type NodeTy = ExprRef;

        fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<Self::NodeTy>> {
            if !node.as_any().is::<GetItem>() {
                return Ok(TransformResult::no(node));
            }
            let next = self.0;
            self.0 += 1;
            Ok(TransformResult::yes(Literal::new_expr(next)))
        }
    }

    /// Never descends, and replaces whatever node it reaches with `root()`.
    #[derive(Default)]
    pub struct SkipDownVisitor;

    impl MutNodeVisitor for SkipDownVisitor {
        type NodeTy = ExprRef;

        fn visit_down(&mut self, _node: &Self::NodeTy) -> VortexResult<TraversalOrder> {
            Ok(TraversalOrder::Skip)
        }

        fn visit_up(&mut self, _node: Self::NodeTy) -> VortexResult<TransformResult<Self::NodeTy>> {
            Ok(TransformResult::yes(root()))
        }
    }

    /// Builds `(col1 == col2) AND (col3 != col4)`.
    fn two_comparisons() -> ExprRef {
        let first: Arc<dyn VortexExpr> = col("col1");
        let second: Arc<dyn VortexExpr> = col("col2");
        let third: Arc<dyn VortexExpr> = col("col3");
        let fourth: Arc<dyn VortexExpr> = col("col4");
        let lhs = BinaryExpr::new_expr(first, Operator::Eq, second);
        let rhs = BinaryExpr::new_expr(third, Operator::NotEq, fourth);
        BinaryExpr::new_expr(lhs, Operator::And, rhs)
    }

    /// Unions the field references of every collected node.
    fn referenced_fields(nodes: Vec<&ExprRef>) -> HashSet<FieldName> {
        nodes
            .into_iter()
            .map(|x| x.references())
            .fold(HashSet::new(), |acc, x| acc.union(&x).cloned().collect())
    }

    #[test]
    fn expr_deep_visitor_test() {
        let column: Arc<dyn VortexExpr> = col("col1");
        let eq = BinaryExpr::new_expr(column, Operator::Eq, Literal::new_expr(1));
        let expr = BinaryExpr::new_expr(eq, Operator::And, Literal::new_expr(2));

        let mut collector = ExprLitCollector::default();
        expr.accept(&mut collector).unwrap();
        assert_eq!(collector.0.len(), 2);
    }

    #[test]
    fn expr_deep_mut_visitor_test() {
        let left: Arc<dyn VortexExpr> = col("col1");
        let right: Arc<dyn VortexExpr> = col("col2");
        let eq = BinaryExpr::new_expr(left, Operator::Eq, right);
        let expr = BinaryExpr::new_expr(eq, Operator::And, Literal::new_expr(2));

        let transformed = expr.transform(&mut ExprColToLit::default()).unwrap();
        assert!(transformed.changed);

        let rewritten = transformed.into_inner();
        let mut collector = ExprLitCollector::default();
        rewritten.accept(&mut collector).unwrap();
        assert_eq!(collector.0.len(), 3);
    }

    #[test]
    fn expr_skip_test() {
        let expr = two_comparisons();

        let mut getters = Vec::new();
        expr.accept(&mut pre_order_visit_down(|node: &ExprRef| {
            if node.as_any().is::<GetItem>() {
                getters.push(node)
            }
            if let Some(bin) = node.as_any().downcast_ref::<BinaryExpr>() {
                if bin.op() == Operator::Eq {
                    return Ok(TraversalOrder::Skip);
                }
            }
            Ok(TraversalOrder::Continue)
        }))
        .unwrap();

        // Only the `!=` branch is descended into.
        assert_eq!(
            referenced_fields(getters),
            HashSet::from_iter(vec![FieldName::from("col3"), FieldName::from("col4")])
        );
    }

    #[test]
    fn expr_stop_test() {
        let expr = two_comparisons();

        let mut getters = Vec::new();
        expr.accept(&mut pre_order_visit_down(|node: &ExprRef| {
            if node.as_any().is::<GetItem>() {
                getters.push(node)
            }
            if let Some(bin) = node.as_any().downcast_ref::<BinaryExpr>() {
                if bin.op() == Operator::Eq {
                    return Ok(TraversalOrder::Stop);
                }
            }
            Ok(TraversalOrder::Continue)
        }))
        .unwrap();

        // The traversal halts at the `==` node, before reaching any column.
        assert!(referenced_fields(getters).is_empty());
    }

    #[test]
    fn expr_skip_down_visit_up() {
        let column = col("col");

        let outcome = column.transform(&mut SkipDownVisitor).unwrap();

        assert!(outcome.changed);
        assert!(is_root(&outcome.result));
    }
}