vortex_expr/traversal/
mod.rs

1mod references;
2mod visitor;
3
4use itertools::Itertools;
5pub use references::ReferenceCollector;
6pub use visitor::{pre_order_visit_down, pre_order_visit_up};
7use vortex_error::VortexResult;
8
9use crate::ExprRef;
10
/// Control signal returned by visitor callbacks to steer a traversal.
///
/// This module defines a DataFusion-inspired traversal pattern for visiting the nodes of a
/// `Node`, for now only `VortexExpr`.
///
/// The traversal is a pre-order traversal, and every callback answers with one of:
/// - `Skip`: skip visiting the children of the current node.
/// - `Stop`: stop visiting any more nodes in the traversal.
/// - `Continue`: continue with the traversal as expected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TraversalOrder {
    /// In a top-down traversal, skip visiting the children of the current node.
    /// In the bottom-up phase of the traversal this does nothing (for now).
    Skip,

    /// Stop visiting any more nodes in the traversal.
    Stop,

    /// Continue with the traversal as expected.
    Continue,
}
32
33#[derive(Debug, Clone)]
34pub struct TransformResult<T> {
35    pub result: T,
36    order: TraversalOrder,
37    changed: bool,
38}
39
40impl<T> TransformResult<T> {
41    pub fn yes(result: T) -> Self {
42        Self {
43            result,
44            order: TraversalOrder::Continue,
45            changed: true,
46        }
47    }
48
49    pub fn no(result: T) -> Self {
50        Self {
51            result,
52            order: TraversalOrder::Continue,
53            changed: false,
54        }
55    }
56}
57
58pub trait NodeVisitor<'a> {
59    type NodeTy: Node;
60
61    fn visit_down(&mut self, _node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
62        Ok(TraversalOrder::Continue)
63    }
64
65    fn visit_up(&mut self, _node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
66        Ok(TraversalOrder::Continue)
67    }
68}
69
70pub trait MutNodeVisitor {
71    type NodeTy: Node;
72
73    fn visit_down(&mut self, _node: &Self::NodeTy) -> VortexResult<TraversalOrder> {
74        Ok(TraversalOrder::Continue)
75    }
76
77    fn visit_up(&mut self, _node: Self::NodeTy) -> VortexResult<TransformResult<Self::NodeTy>>;
78}
79
/// Decision returned by the top-down phase of a fold.
///
/// `Out` is the fold's result type and `Context` is the value handed down to the children.
#[derive(Debug)]
pub enum FoldDown<Out, Context> {
    /// Abort the entire traversal and immediately return the result.
    Abort(Out),
    /// Skip visiting children of the current node and return the result to the parent's `fold_up`.
    SkipChildren(Out),
    /// Continue visiting the `fold_down` of the children of the current node.
    Continue(Context),
}
88
/// Decision returned by the bottom-up phase of a fold, carrying the folded value.
#[derive(Debug)]
pub enum FoldUp<Out> {
    /// Abort the entire traversal and immediately return the result.
    Abort(Out),
    /// Continue visiting the `fold_up` of the parent node.
    Continue(Out),
}

impl<Out> FoldUp<Out> {
    /// Extracts the folded value, regardless of whether the fold aborted or continued.
    pub fn result(self) -> Out {
        // Both variants carry exactly one `Out`, so a single irrefutable pattern covers them.
        let (FoldUp::Abort(out) | FoldUp::Continue(out)) = self;
        out
    }
}
105
106pub trait Folder<'a> {
107    type NodeTy: Node;
108    type Out;
109    type Context: Clone;
110
111    fn visit_down(
112        &mut self,
113        _node: &'a Self::NodeTy,
114        context: Self::Context,
115    ) -> VortexResult<FoldDown<Self::Out, Self::Context>> {
116        Ok(FoldDown::Continue(context))
117    }
118
119    fn visit_up(
120        &mut self,
121        node: &'a Self::NodeTy,
122        context: Self::Context,
123        children: Vec<Self::Out>,
124    ) -> VortexResult<FoldUp<Self::Out>>;
125}
126
127pub trait FolderMut {
128    type NodeTy: Node;
129    type Out;
130    type Context: Clone;
131
132    fn visit_down(
133        &mut self,
134        _node: &Self::NodeTy,
135        context: Self::Context,
136    ) -> VortexResult<FoldDown<Self::Out, Self::Context>> {
137        Ok(FoldDown::Continue(context))
138    }
139
140    fn visit_up(
141        &mut self,
142        node: Self::NodeTy,
143        context: Self::Context,
144        children: Vec<Self::Out>,
145    ) -> VortexResult<FoldUp<Self::Out>>;
146}
147
148pub trait Node: Sized {
149    fn accept<'a, V: NodeVisitor<'a, NodeTy = Self>>(
150        &'a self,
151        _visitor: &mut V,
152    ) -> VortexResult<TraversalOrder>;
153
154    fn accept_with_context<'a, V: Folder<'a, NodeTy = Self>>(
155        &'a self,
156        visitor: &mut V,
157        context: V::Context,
158    ) -> VortexResult<FoldUp<V::Out>>;
159
160    fn transform<V: MutNodeVisitor<NodeTy = Self>>(
161        self,
162        _visitor: &mut V,
163    ) -> VortexResult<TransformResult<Self>>;
164
165    fn transform_with_context<V: FolderMut<NodeTy = Self>>(
166        self,
167        _visitor: &mut V,
168        _context: V::Context,
169    ) -> VortexResult<FoldUp<V::Out>>;
170}
171
172impl Node for ExprRef {
173    // A pre-order traversal.
174    fn accept<'a, V: NodeVisitor<'a, NodeTy = ExprRef>>(
175        &'a self,
176        visitor: &mut V,
177    ) -> VortexResult<TraversalOrder> {
178        let mut ord = visitor.visit_down(self)?;
179        if ord == TraversalOrder::Stop {
180            return Ok(TraversalOrder::Stop);
181        }
182        if ord == TraversalOrder::Skip {
183            return Ok(TraversalOrder::Continue);
184        }
185        for child in self.children() {
186            if ord != TraversalOrder::Continue {
187                return Ok(ord);
188            }
189            ord = child.accept(visitor)?;
190        }
191        if ord == TraversalOrder::Stop {
192            return Ok(TraversalOrder::Stop);
193        }
194        visitor.visit_up(self)
195    }
196
197    fn accept_with_context<'a, V: Folder<'a, NodeTy = Self>>(
198        &'a self,
199        visitor: &mut V,
200        context: V::Context,
201    ) -> VortexResult<FoldUp<V::Out>> {
202        let children = match visitor.visit_down(self, context.clone())? {
203            FoldDown::Abort(out) => return Ok(FoldUp::Abort(out)),
204            FoldDown::SkipChildren(out) => return Ok(FoldUp::Continue(out)),
205            FoldDown::Continue(child_context) => {
206                let mut new_children = Vec::with_capacity(self.children().len());
207                for child in self.children() {
208                    match child.accept_with_context(visitor, child_context.clone())? {
209                        FoldUp::Abort(out) => return Ok(FoldUp::Abort(out)),
210                        FoldUp::Continue(out) => new_children.push(out),
211                    }
212                }
213                new_children
214            }
215        };
216
217        visitor.visit_up(self, context, children)
218    }
219
220    // A post-order transform, with an option to ignore sub-tress (using visit_down).
221    fn transform<V: MutNodeVisitor<NodeTy = Self>>(
222        self,
223        visitor: &mut V,
224    ) -> VortexResult<TransformResult<Self>> {
225        let mut ord = visitor.visit_down(&self)?;
226        if ord == TraversalOrder::Stop {
227            return Ok(TransformResult {
228                result: self,
229                order: TraversalOrder::Stop,
230                changed: false,
231            });
232        }
233        let (children, ord, changed) = if ord == TraversalOrder::Continue {
234            let mut new_children = Vec::with_capacity(self.children().len());
235            let mut changed = false;
236            for child in self.children() {
237                match ord {
238                    TraversalOrder::Continue | TraversalOrder::Skip => {
239                        let TransformResult {
240                            result: new_child,
241                            order: child_order,
242                            changed: child_changed,
243                        } = child.clone().transform(visitor)?;
244                        new_children.push(new_child);
245                        ord = child_order;
246                        changed |= child_changed;
247                    }
248                    TraversalOrder::Stop => new_children.push(child.clone()),
249                }
250            }
251            (new_children, ord, changed)
252        } else {
253            (
254                self.children().into_iter().cloned().collect_vec(),
255                ord,
256                false,
257            )
258        };
259
260        if ord == TraversalOrder::Continue {
261            let up = visitor.visit_up(self.replacing_children(children))?;
262            Ok(TransformResult::yes(up.result))
263        } else {
264            Ok(TransformResult {
265                result: self.replacing_children(children),
266                order: ord,
267                changed,
268            })
269        }
270    }
271
272    fn transform_with_context<V: FolderMut<NodeTy = Self>>(
273        self,
274        visitor: &mut V,
275        context: V::Context,
276    ) -> VortexResult<FoldUp<V::Out>> {
277        let children = match visitor.visit_down(&self, context.clone())? {
278            FoldDown::Abort(out) => return Ok(FoldUp::Abort(out)),
279            FoldDown::SkipChildren(out) => return Ok(FoldUp::Continue(out)),
280            FoldDown::Continue(child_context) => {
281                let mut new_children = Vec::with_capacity(self.children().len());
282                for child in self.children() {
283                    match child
284                        .clone()
285                        .transform_with_context(visitor, child_context.clone())?
286                    {
287                        FoldUp::Abort(out) => return Ok(FoldUp::Abort(out)),
288                        FoldUp::Continue(out) => new_children.push(out),
289                    }
290                }
291                new_children
292            }
293        };
294
295        visitor.visit_up(self, context, children)
296    }
297}
298
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::aliases::hash_set::HashSet;
    use vortex_error::VortexResult;

    use crate::traversal::visitor::pre_order_visit_down;
    use crate::traversal::{MutNodeVisitor, Node, NodeVisitor, TransformResult, TraversalOrder};
    use crate::{
        BinaryExpr, ExprRef, FieldName, GetItem, Literal, Operator, VortexExpr, VortexExprExt, col,
    };

    /// Read-only visitor that records every `Literal` node it meets on the way down.
    #[derive(Default)]
    pub struct ExprLitCollector<'a>(pub Vec<&'a ExprRef>);

    impl<'a> NodeVisitor<'a> for ExprLitCollector<'a> {
        type NodeTy = ExprRef;

        fn visit_down(&mut self, node: &'a ExprRef) -> VortexResult<TraversalOrder> {
            let is_literal = node.as_any().is::<Literal>();
            if is_literal {
                self.0.push(node);
            }
            Ok(TraversalOrder::Continue)
        }

        fn visit_up(&mut self, _node: &'a ExprRef) -> VortexResult<TraversalOrder> {
            Ok(TraversalOrder::Continue)
        }
    }

    /// Rewriting visitor that replaces every `GetItem` node with a literal holding a running
    /// counter.
    #[derive(Default)]
    pub struct ExprColToLit(i32);

    impl MutNodeVisitor for ExprColToLit {
        type NodeTy = ExprRef;

        fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<Self::NodeTy>> {
            if !node.as_any().is::<GetItem>() {
                return Ok(TransformResult::no(node));
            }
            let id = self.0;
            self.0 += 1;
            Ok(TransformResult::yes(Literal::new_expr(id)))
        }
    }

    /// Builds `(col1 == col2) AND (col3 != col4)`, the fixture shared by the skip and stop
    /// tests.
    fn eq_and_not_eq() -> ExprRef {
        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        let col3: Arc<dyn VortexExpr> = col("col3");
        let col4: Arc<dyn VortexExpr> = col("col4");
        let lhs = BinaryExpr::new_expr(col1, Operator::Eq, col2);
        let rhs = BinaryExpr::new_expr(col3, Operator::NotEq, col4);
        BinaryExpr::new_expr(lhs, Operator::And, rhs)
    }

    /// Unions the field references of all collected nodes into a single set.
    fn referenced_fields(nodes: Vec<&ExprRef>) -> HashSet<FieldName> {
        nodes
            .into_iter()
            .map(|x| x.references())
            .fold(HashSet::new(), |acc, x| acc.union(&x).cloned().collect())
    }

    #[test]
    fn expr_deep_visitor_test() {
        let col1: Arc<dyn VortexExpr> = col("col1");
        let comparison = BinaryExpr::new_expr(col1, Operator::Eq, Literal::new_expr(1));
        let expr = BinaryExpr::new_expr(comparison, Operator::And, Literal::new_expr(2));

        let mut collector = ExprLitCollector::default();
        expr.accept(&mut collector).unwrap();
        assert_eq!(collector.0.len(), 2);
    }

    #[test]
    fn expr_deep_mut_visitor_test() {
        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        let comparison = BinaryExpr::new_expr(col1, Operator::Eq, col2);
        let expr = BinaryExpr::new_expr(comparison, Operator::And, Literal::new_expr(2));

        let mut rewriter = ExprColToLit::default();
        let transformed = expr.transform(&mut rewriter).unwrap();
        assert!(transformed.changed);

        let rewritten = transformed.result;

        let mut collector = ExprLitCollector::default();
        rewritten.accept(&mut collector).unwrap();
        assert_eq!(collector.0.len(), 3);
    }

    #[test]
    fn expr_skip_test() {
        let expr = eq_and_not_eq();

        let mut nodes = Vec::new();
        expr.accept(&mut pre_order_visit_down(|node: &ExprRef| {
            if node.as_any().is::<GetItem>() {
                nodes.push(node)
            }
            if let Some(bin) = node.as_any().downcast_ref::<BinaryExpr>() {
                if bin.op() == Operator::Eq {
                    return Ok(TraversalOrder::Skip);
                }
            }
            Ok(TraversalOrder::Continue)
        }))
        .unwrap();

        // The `Eq` subtree is skipped, so only the columns under `NotEq` are collected.
        assert_eq!(
            referenced_fields(nodes),
            HashSet::from_iter(vec![FieldName::from("col3"), FieldName::from("col4")])
        );
    }

    #[test]
    fn expr_stop_test() {
        let expr = eq_and_not_eq();

        let mut nodes = Vec::new();
        expr.accept(&mut pre_order_visit_down(|node: &ExprRef| {
            if node.as_any().is::<GetItem>() {
                nodes.push(node)
            }
            if let Some(bin) = node.as_any().downcast_ref::<BinaryExpr>() {
                if bin.op() == Operator::Eq {
                    return Ok(TraversalOrder::Stop);
                }
            }
            Ok(TraversalOrder::Continue)
        }))
        .unwrap();

        // The traversal stops at the `Eq` node before any column has been visited.
        assert_eq!(referenced_fields(nodes), HashSet::new());
    }
}