vortex_expr/traversal/
mod.rs

1mod references;
2mod vars;
3mod visitor;
4
5pub use references::ReferenceCollector;
6pub use vars::VarsCollector;
7pub use visitor::{pre_order_visit_down, pre_order_visit_up};
8use vortex_error::VortexResult;
9
10use crate::ExprRef;
11
/// Control signal returned by visitors to steer a DataFusion-inspired tree traversal
/// over a `Node` (for now only `VortexExpr`).
///
/// The traversal is pre-order. A visitor callback answers with one of:
/// - `Skip`: do not visit the children of the current node.
/// - `Stop`: do not visit any more nodes in the traversal.
/// - `Continue`: carry on with the traversal as expected.
///
/// The enum is a plain fieldless tag, so it is `Copy` and `Hash` in addition to being
/// comparable; callers can pass it around by value and use it as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TraversalOrder {
    /// In a top-down traversal, skip visiting the children of the current node.
    /// In the bottom-up phase of the traversal this does nothing (for now).
    Skip,

    /// Stop visiting any more nodes in the traversal.
    Stop,

    /// Continue with the traversal as expected.
    Continue,
}
32
33#[derive(Debug, Clone)]
34pub struct TransformResult<T> {
35    result: T,
36    order: TraversalOrder,
37    changed: bool,
38}
39
40impl<T> TransformResult<T> {
41    pub fn yes(result: T) -> Self {
42        Self {
43            result,
44            order: TraversalOrder::Continue,
45            changed: true,
46        }
47    }
48
49    pub fn no(result: T) -> Self {
50        Self {
51            result,
52            order: TraversalOrder::Continue,
53            changed: false,
54        }
55    }
56
57    pub fn into_inner(self) -> T {
58        self.result
59    }
60}
61
62pub trait NodeVisitor<'a> {
63    type NodeTy: Node;
64
65    fn visit_down(&mut self, _node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
66        Ok(TraversalOrder::Continue)
67    }
68
69    fn visit_up(&mut self, _node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
70        Ok(TraversalOrder::Continue)
71    }
72}
73
74pub trait MutNodeVisitor {
75    type NodeTy: Node;
76
77    fn visit_down(&mut self, _node: &Self::NodeTy) -> VortexResult<TraversalOrder> {
78        Ok(TraversalOrder::Continue)
79    }
80
81    fn visit_up(&mut self, _node: Self::NodeTy) -> VortexResult<TransformResult<Self::NodeTy>>;
82}
83
/// Decision returned by the downward phase of a fold.
///
/// `Out` is the fold's result type and `Context` is the value threaded down to the
/// children. `Debug` is derived (when both parameters are `Debug`) so the type can be
/// logged and used in assertions, matching its sibling `FoldUp`.
#[derive(Debug)]
pub enum FoldDown<Out, Context> {
    /// Abort the entire traversal and immediately return the result.
    Abort(Out),
    /// Skip visiting children of the current node and return the result to the parent's `fold_up`.
    SkipChildren(Out),
    /// Continue visiting the `fold_down` of the children of the current node.
    Continue(Context),
}
92
/// Decision returned by the upward phase of a fold, carrying the folded value.
#[derive(Debug)]
pub enum FoldUp<Out> {
    /// Abort the entire traversal and immediately return the result.
    Abort(Out),
    /// Continue visiting the `fold_up` of the parent node.
    Continue(Out),
}

impl<Out> FoldUp<Out> {
    /// Unwraps the carried value, regardless of which variant holds it.
    pub fn result(self) -> Out {
        // Both variants wrap exactly one `Out`, so the or-pattern is irrefutable.
        let (FoldUp::Abort(out) | FoldUp::Continue(out)) = self;
        out
    }
}
109
110pub trait Folder<'a> {
111    type NodeTy: Node;
112    type Out;
113    type Context: Clone;
114
115    fn visit_down(
116        &mut self,
117        _node: &'a Self::NodeTy,
118        context: Self::Context,
119    ) -> VortexResult<FoldDown<Self::Out, Self::Context>> {
120        Ok(FoldDown::Continue(context))
121    }
122
123    fn visit_up(
124        &mut self,
125        node: &'a Self::NodeTy,
126        context: Self::Context,
127        children: Vec<Self::Out>,
128    ) -> VortexResult<FoldUp<Self::Out>>;
129}
130
131pub trait FolderMut {
132    type NodeTy: Node;
133    type Out;
134    type Context: Clone;
135
136    fn visit_down(
137        &mut self,
138        _node: &Self::NodeTy,
139        context: Self::Context,
140    ) -> VortexResult<FoldDown<Self::Out, Self::Context>> {
141        Ok(FoldDown::Continue(context))
142    }
143
144    fn visit_up(
145        &mut self,
146        node: Self::NodeTy,
147        context: Self::Context,
148        children: Vec<Self::Out>,
149    ) -> VortexResult<FoldUp<Self::Out>>;
150}
151
152pub trait Node: Sized {
153    fn accept<'a, V: NodeVisitor<'a, NodeTy = Self>>(
154        &'a self,
155        _visitor: &mut V,
156    ) -> VortexResult<TraversalOrder>;
157
158    fn accept_with_context<'a, V: Folder<'a, NodeTy = Self>>(
159        &'a self,
160        visitor: &mut V,
161        context: V::Context,
162    ) -> VortexResult<FoldUp<V::Out>>;
163
164    fn transform<V: MutNodeVisitor<NodeTy = Self>>(
165        self,
166        _visitor: &mut V,
167    ) -> VortexResult<TransformResult<Self>>;
168
169    fn transform_with_context<V: FolderMut<NodeTy = Self>>(
170        self,
171        _visitor: &mut V,
172        _context: V::Context,
173    ) -> VortexResult<FoldUp<V::Out>>;
174}
175
176impl Node for ExprRef {
177    /// A pre-order traversal.
178    fn accept<'a, V: NodeVisitor<'a, NodeTy = ExprRef>>(
179        &'a self,
180        visitor: &mut V,
181    ) -> VortexResult<TraversalOrder> {
182        let mut ord = visitor.visit_down(self)?;
183        if ord == TraversalOrder::Stop {
184            return Ok(TraversalOrder::Stop);
185        }
186        if ord == TraversalOrder::Skip {
187            return Ok(TraversalOrder::Continue);
188        }
189        for child in self.children() {
190            if ord != TraversalOrder::Continue {
191                return Ok(ord);
192            }
193            ord = child.accept(visitor)?;
194        }
195        if ord == TraversalOrder::Stop {
196            return Ok(TraversalOrder::Stop);
197        }
198        visitor.visit_up(self)
199    }
200
201    fn accept_with_context<'a, V: Folder<'a, NodeTy = Self>>(
202        &'a self,
203        visitor: &mut V,
204        context: V::Context,
205    ) -> VortexResult<FoldUp<V::Out>> {
206        let children = match visitor.visit_down(self, context.clone())? {
207            FoldDown::Abort(out) => return Ok(FoldUp::Abort(out)),
208            FoldDown::SkipChildren(out) => return Ok(FoldUp::Continue(out)),
209            FoldDown::Continue(child_context) => {
210                let mut new_children = Vec::with_capacity(self.children().len());
211                for child in self.children() {
212                    match child.accept_with_context(visitor, child_context.clone())? {
213                        FoldUp::Abort(out) => return Ok(FoldUp::Abort(out)),
214                        FoldUp::Continue(out) => new_children.push(out),
215                    }
216                }
217                new_children
218            }
219        };
220
221        visitor.visit_up(self, context, children)
222    }
223
224    /// A post-order transform, with an option to ignore sub-tress (using visit_down).
225    fn transform<V: MutNodeVisitor<NodeTy = Self>>(
226        self,
227        visitor: &mut V,
228    ) -> VortexResult<TransformResult<Self>> {
229        match visitor.visit_down(&self)? {
230            TraversalOrder::Stop => Ok(TransformResult {
231                result: self,
232                order: TraversalOrder::Stop,
233                changed: false,
234            }),
235            TraversalOrder::Skip => visitor.visit_up(self),
236            TraversalOrder::Continue => {
237                let mut new_children = Vec::with_capacity(self.children().len());
238                let mut changed = false;
239                let mut stopped = false;
240                for child in self.children() {
241                    if stopped {
242                        new_children.push(child.clone());
243                        continue;
244                    }
245                    let TransformResult {
246                        result: new_child,
247                        order: child_order,
248                        changed: child_changed,
249                    } = child.clone().transform(visitor)?;
250                    new_children.push(new_child);
251                    changed |= child_changed;
252                    stopped |= child_order == TraversalOrder::Stop;
253                }
254
255                if changed {
256                    let up = visitor.visit_up(self.replacing_children(new_children))?;
257                    Ok(TransformResult::yes(up.result))
258                } else {
259                    visitor.visit_up(self)
260                }
261            }
262        }
263    }
264
265    fn transform_with_context<V: FolderMut<NodeTy = Self>>(
266        self,
267        visitor: &mut V,
268        context: V::Context,
269    ) -> VortexResult<FoldUp<V::Out>> {
270        let children = match visitor.visit_down(&self, context.clone())? {
271            FoldDown::Abort(out) => return Ok(FoldUp::Abort(out)),
272            FoldDown::SkipChildren(out) => return Ok(FoldUp::Continue(out)),
273            FoldDown::Continue(child_context) => {
274                let mut new_children = Vec::with_capacity(self.children().len());
275                for child in self.children() {
276                    match child
277                        .clone()
278                        .transform_with_context(visitor, child_context.clone())?
279                    {
280                        FoldUp::Abort(out) => return Ok(FoldUp::Abort(out)),
281                        FoldUp::Continue(out) => new_children.push(out),
282                    }
283                }
284                new_children
285            }
286        };
287
288        visitor.visit_up(self, context, children)
289    }
290}
291
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_error::VortexResult;
    use vortex_utils::aliases::hash_set::HashSet;

    use crate::traversal::visitor::pre_order_visit_down;
    use crate::traversal::{MutNodeVisitor, Node, NodeVisitor, TransformResult, TraversalOrder};
    use crate::{
        BinaryExpr, ExprRef, GetItem, Literal, Operator, VortexExpr, col, get_item_scope, is_root,
        root,
    };

    /// Records a reference to every `Literal` node met on the way down.
    #[derive(Default)]
    pub struct ExprLitCollector<'a>(pub Vec<&'a ExprRef>);

    impl<'a> NodeVisitor<'a> for ExprLitCollector<'a> {
        type NodeTy = ExprRef;

        fn visit_down(&mut self, node: &'a ExprRef) -> VortexResult<TraversalOrder> {
            if node.as_any().is::<Literal>() {
                self.0.push(node);
            }
            Ok(TraversalOrder::Continue)
        }

        // `visit_up` is left to the trait default, which also just continues.
    }

    /// Replaces every `GetItem` node with a literal holding a running counter.
    #[derive(Default)]
    pub struct ExprColToLit(i32);

    impl MutNodeVisitor for ExprColToLit {
        type NodeTy = ExprRef;

        fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<Self::NodeTy>> {
            if !node.as_any().is::<GetItem>() {
                return Ok(TransformResult::no(node));
            }
            let id = self.0;
            self.0 += 1;
            Ok(TransformResult::yes(Literal::new_expr(id)))
        }
    }

    /// Always skips on the way down and replaces the node with `root()` on the way up.
    #[derive(Default)]
    pub struct SkipDownVisitor;

    impl MutNodeVisitor for SkipDownVisitor {
        type NodeTy = ExprRef;

        fn visit_down(&mut self, _node: &Self::NodeTy) -> VortexResult<TraversalOrder> {
            Ok(TraversalOrder::Skip)
        }

        fn visit_up(&mut self, _node: Self::NodeTy) -> VortexResult<TransformResult<Self::NodeTy>> {
            Ok(TransformResult::yes(root()))
        }
    }

    /// Builds `(col1 == col2) AND (col3 != col4)`, the tree shared by the skip/stop tests.
    fn eq_and_not_eq() -> ExprRef {
        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        let col3: Arc<dyn VortexExpr> = col("col3");
        let col4: Arc<dyn VortexExpr> = col("col4");
        let lhs = BinaryExpr::new_expr(col1, Operator::Eq, col2);
        let rhs = BinaryExpr::new_expr(col3, Operator::NotEq, col4);
        BinaryExpr::new_expr(lhs, Operator::And, rhs)
    }

    #[test]
    fn expr_deep_visitor_test() {
        // (col1 == 1) AND 2 contains exactly two literals.
        let col1: Arc<dyn VortexExpr> = col("col1");
        let inner = BinaryExpr::new_expr(col1, Operator::Eq, Literal::new_expr(1));
        let expr = BinaryExpr::new_expr(inner, Operator::And, Literal::new_expr(2));

        let mut collector = ExprLitCollector::default();
        expr.accept(&mut collector).unwrap();
        assert_eq!(collector.0.len(), 2);
    }

    #[test]
    fn expr_deep_mut_visitor_test() {
        // (col1 == col2) AND 2: both columns become literals, giving three in total.
        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        let inner = BinaryExpr::new_expr(col1, Operator::Eq, col2);
        let expr = BinaryExpr::new_expr(inner, Operator::And, Literal::new_expr(2));

        let mut rewriter = ExprColToLit::default();
        let transformed = expr.transform(&mut rewriter).unwrap();
        assert!(transformed.changed);

        let expr = transformed.result;

        let mut collector = ExprLitCollector::default();
        expr.accept(&mut collector).unwrap();
        assert_eq!(collector.0.len(), 3);
    }

    #[test]
    fn expr_skip_test() {
        let expr = eq_and_not_eq();

        let mut nodes = Vec::new();
        expr.accept(&mut pre_order_visit_down(|node: &ExprRef| {
            if node.as_any().is::<GetItem>() {
                nodes.push(node)
            }
            let is_eq = node
                .as_any()
                .downcast_ref::<BinaryExpr>()
                .is_some_and(|bin| bin.op() == Operator::Eq);
            if is_eq {
                return Ok(TraversalOrder::Skip);
            }
            Ok(TraversalOrder::Continue)
        }))
        .unwrap();

        // The `==` subtree is skipped, so only the columns under `!=` are collected.
        let nodes: HashSet<ExprRef> = nodes.into_iter().cloned().collect();
        assert_eq!(
            nodes,
            HashSet::from_iter([get_item_scope("col3"), get_item_scope("col4")])
        );
    }

    #[test]
    fn expr_stop_test() {
        let expr = eq_and_not_eq();

        let mut nodes = Vec::new();
        expr.accept(&mut pre_order_visit_down(|node: &ExprRef| {
            if node.as_any().is::<GetItem>() {
                nodes.push(node)
            }
            let is_eq = node
                .as_any()
                .downcast_ref::<BinaryExpr>()
                .is_some_and(|bin| bin.op() == Operator::Eq);
            if is_eq {
                return Ok(TraversalOrder::Stop);
            }
            Ok(TraversalOrder::Continue)
        }))
        .unwrap();

        // The traversal stops at the first `==`, before any column is reached.
        assert!(nodes.is_empty());
    }

    #[test]
    fn expr_skip_down_visit_up() {
        let column = col("col");

        let mut visitor = SkipDownVisitor;
        let result = column.transform(&mut visitor).unwrap();

        assert!(result.changed);
        assert!(is_root(&result.result));
    }
}