vortex_expr/traversal/
mod.rs

1mod references;
2mod visitor;
3
4pub use references::ReferenceCollector;
5pub use visitor::{pre_order_visit_down, pre_order_visit_up};
6use vortex_error::VortexResult;
7
8use crate::ExprRef;
9
/// Control signal returned by visitors to steer a traversal.
///
/// This module defines a DataFusion-inspired traversal pattern for visiting the nodes of a
/// `Node`; for now the only implementation is for `VortexExpr` trees.
///
/// The traversal is a pre-order traversal, steered by the following variants:
/// - `Skip`: Skip visiting the children of the current node.
/// - `Stop`: Stop visiting any more nodes in the traversal.
/// - `Continue`: Continue with the traversal as expected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TraversalOrder {
    /// In a top-down traversal, skip visiting the children of the current node.
    /// In the bottom-up phase of the traversal this does nothing (for now).
    Skip,

    /// Stop visiting any more nodes in the traversal.
    Stop,

    /// Continue with the traversal as expected.
    Continue,
}
31
32#[derive(Debug, Clone)]
33pub struct TransformResult<T> {
34    pub result: T,
35    order: TraversalOrder,
36    changed: bool,
37}
38
39impl<T> TransformResult<T> {
40    pub fn yes(result: T) -> Self {
41        Self {
42            result,
43            order: TraversalOrder::Continue,
44            changed: true,
45        }
46    }
47
48    pub fn no(result: T) -> Self {
49        Self {
50            result,
51            order: TraversalOrder::Continue,
52            changed: false,
53        }
54    }
55}
56
57pub trait NodeVisitor<'a> {
58    type NodeTy: Node;
59
60    fn visit_down(&mut self, _node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
61        Ok(TraversalOrder::Continue)
62    }
63
64    fn visit_up(&mut self, _node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
65        Ok(TraversalOrder::Continue)
66    }
67}
68
69pub trait MutNodeVisitor {
70    type NodeTy: Node;
71
72    fn visit_down(&mut self, _node: &Self::NodeTy) -> VortexResult<TraversalOrder> {
73        Ok(TraversalOrder::Continue)
74    }
75
76    fn visit_up(&mut self, _node: Self::NodeTy) -> VortexResult<TransformResult<Self::NodeTy>>;
77}
78
/// Decision returned by the top-down phase of a fold.
///
/// `Out` is the fold's result type; `Context` is the value handed down to child nodes.
#[derive(Debug)]
pub enum FoldDown<Out, Context> {
    /// Abort the entire traversal and immediately return the result.
    Abort(Out),
    /// Skip visiting children of the current node and return the result to the parent's `fold_up`.
    SkipChildren(Out),
    /// Continue visiting the `fold_down` of the children of the current node.
    Continue(Context),
}
87
/// Decision returned by the bottom-up phase of a fold.
#[derive(Debug)]
pub enum FoldUp<Out> {
    /// Abort the entire traversal and immediately return the result.
    Abort(Out),
    /// Continue visiting the `fold_up` of the parent node.
    Continue(Out),
}

impl<Out> FoldUp<Out> {
    /// Unwraps the carried value, regardless of whether the fold aborted or continued.
    pub fn result(self) -> Out {
        // Both variants carry exactly one `Out`, so the or-pattern is irrefutable.
        let (FoldUp::Abort(value) | FoldUp::Continue(value)) = self;
        value
    }
}
104
105pub trait Folder<'a> {
106    type NodeTy: Node;
107    type Out;
108    type Context: Clone;
109
110    fn visit_down(
111        &mut self,
112        _node: &'a Self::NodeTy,
113        context: Self::Context,
114    ) -> VortexResult<FoldDown<Self::Out, Self::Context>> {
115        Ok(FoldDown::Continue(context))
116    }
117
118    fn visit_up(
119        &mut self,
120        node: &'a Self::NodeTy,
121        context: Self::Context,
122        children: Vec<Self::Out>,
123    ) -> VortexResult<FoldUp<Self::Out>>;
124}
125
126pub trait FolderMut {
127    type NodeTy: Node;
128    type Out;
129    type Context: Clone;
130
131    fn visit_down(
132        &mut self,
133        _node: &Self::NodeTy,
134        context: Self::Context,
135    ) -> VortexResult<FoldDown<Self::Out, Self::Context>> {
136        Ok(FoldDown::Continue(context))
137    }
138
139    fn visit_up(
140        &mut self,
141        node: Self::NodeTy,
142        context: Self::Context,
143        children: Vec<Self::Out>,
144    ) -> VortexResult<FoldUp<Self::Out>>;
145}
146
147pub trait Node: Sized {
148    fn accept<'a, V: NodeVisitor<'a, NodeTy = Self>>(
149        &'a self,
150        _visitor: &mut V,
151    ) -> VortexResult<TraversalOrder>;
152
153    fn accept_with_context<'a, V: Folder<'a, NodeTy = Self>>(
154        &'a self,
155        visitor: &mut V,
156        context: V::Context,
157    ) -> VortexResult<FoldUp<V::Out>>;
158
159    fn transform<V: MutNodeVisitor<NodeTy = Self>>(
160        self,
161        _visitor: &mut V,
162    ) -> VortexResult<TransformResult<Self>>;
163
164    fn transform_with_context<V: FolderMut<NodeTy = Self>>(
165        self,
166        _visitor: &mut V,
167        _context: V::Context,
168    ) -> VortexResult<FoldUp<V::Out>>;
169}
170
171impl Node for ExprRef {
172    // A pre-order traversal.
173    fn accept<'a, V: NodeVisitor<'a, NodeTy = ExprRef>>(
174        &'a self,
175        visitor: &mut V,
176    ) -> VortexResult<TraversalOrder> {
177        let mut ord = visitor.visit_down(self)?;
178        if ord == TraversalOrder::Stop {
179            return Ok(TraversalOrder::Stop);
180        }
181        if ord == TraversalOrder::Skip {
182            return Ok(TraversalOrder::Continue);
183        }
184        for child in self.children() {
185            if ord != TraversalOrder::Continue {
186                return Ok(ord);
187            }
188            ord = child.accept(visitor)?;
189        }
190        if ord == TraversalOrder::Stop {
191            return Ok(TraversalOrder::Stop);
192        }
193        visitor.visit_up(self)
194    }
195
196    fn accept_with_context<'a, V: Folder<'a, NodeTy = Self>>(
197        &'a self,
198        visitor: &mut V,
199        context: V::Context,
200    ) -> VortexResult<FoldUp<V::Out>> {
201        let children = match visitor.visit_down(self, context.clone())? {
202            FoldDown::Abort(out) => return Ok(FoldUp::Abort(out)),
203            FoldDown::SkipChildren(out) => return Ok(FoldUp::Continue(out)),
204            FoldDown::Continue(child_context) => {
205                let mut new_children = Vec::with_capacity(self.children().len());
206                for child in self.children() {
207                    match child.accept_with_context(visitor, child_context.clone())? {
208                        FoldUp::Abort(out) => return Ok(FoldUp::Abort(out)),
209                        FoldUp::Continue(out) => new_children.push(out),
210                    }
211                }
212                new_children
213            }
214        };
215
216        visitor.visit_up(self, context, children)
217    }
218
219    // A post-order transform, with an option to ignore sub-tress (using visit_down).
220    fn transform<V: MutNodeVisitor<NodeTy = Self>>(
221        self,
222        visitor: &mut V,
223    ) -> VortexResult<TransformResult<Self>> {
224        match visitor.visit_down(&self)? {
225            TraversalOrder::Stop => Ok(TransformResult {
226                result: self,
227                order: TraversalOrder::Stop,
228                changed: false,
229            }),
230            TraversalOrder::Skip => visitor.visit_up(self),
231            TraversalOrder::Continue => {
232                let mut new_children = Vec::with_capacity(self.children().len());
233                let mut changed = false;
234                let mut stopped = false;
235                for child in self.children() {
236                    if stopped {
237                        new_children.push(child.clone());
238                        continue;
239                    }
240                    let TransformResult {
241                        result: new_child,
242                        order: child_order,
243                        changed: child_changed,
244                    } = child.clone().transform(visitor)?;
245                    new_children.push(new_child);
246                    changed |= child_changed;
247                    stopped |= child_order == TraversalOrder::Stop;
248                }
249
250                if changed {
251                    let up = visitor.visit_up(self.replacing_children(new_children))?;
252                    Ok(TransformResult::yes(up.result))
253                } else {
254                    visitor.visit_up(self)
255                }
256            }
257        }
258    }
259
260    fn transform_with_context<V: FolderMut<NodeTy = Self>>(
261        self,
262        visitor: &mut V,
263        context: V::Context,
264    ) -> VortexResult<FoldUp<V::Out>> {
265        let children = match visitor.visit_down(&self, context.clone())? {
266            FoldDown::Abort(out) => return Ok(FoldUp::Abort(out)),
267            FoldDown::SkipChildren(out) => return Ok(FoldUp::Continue(out)),
268            FoldDown::Continue(child_context) => {
269                let mut new_children = Vec::with_capacity(self.children().len());
270                for child in self.children() {
271                    match child
272                        .clone()
273                        .transform_with_context(visitor, child_context.clone())?
274                    {
275                        FoldUp::Abort(out) => return Ok(FoldUp::Abort(out)),
276                        FoldUp::Continue(out) => new_children.push(out),
277                    }
278                }
279                new_children
280            }
281        };
282
283        visitor.visit_up(self, context, children)
284    }
285}
286
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::aliases::hash_set::HashSet;
    use vortex_error::VortexResult;

    use crate::traversal::visitor::pre_order_visit_down;
    use crate::traversal::{MutNodeVisitor, Node, NodeVisitor, TransformResult, TraversalOrder};
    use crate::{
        BinaryExpr, ExprRef, FieldName, GetItem, Identity, Literal, Operator, VortexExpr,
        VortexExprExt, col,
    };

    /// Collects a reference to every `Literal` node seen on the way down.
    #[derive(Default)]
    pub struct ExprLitCollector<'a>(pub Vec<&'a ExprRef>);

    impl<'a> NodeVisitor<'a> for ExprLitCollector<'a> {
        type NodeTy = ExprRef;

        fn visit_down(&mut self, node: &'a ExprRef) -> VortexResult<TraversalOrder> {
            if node.as_any().is::<Literal>() {
                self.0.push(node);
            }
            Ok(TraversalOrder::Continue)
        }

        // `visit_up` keeps the trait's default, which also returns `Continue`.
    }

    /// Replaces every `GetItem` node with a literal carrying a running counter.
    #[derive(Default)]
    pub struct ExprColToLit(i32);

    impl MutNodeVisitor for ExprColToLit {
        type NodeTy = ExprRef;

        fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<Self::NodeTy>> {
            if !node.as_any().is::<GetItem>() {
                return Ok(TransformResult::no(node));
            }
            let id = self.0;
            self.0 += 1;
            Ok(TransformResult::yes(Literal::new_expr(id)))
        }
    }

    /// Skips every node on the way down and rewrites it to `Identity` on the way up.
    #[derive(Default)]
    pub struct SkipDownVisitor;

    impl MutNodeVisitor for SkipDownVisitor {
        type NodeTy = ExprRef;

        fn visit_down(&mut self, _node: &Self::NodeTy) -> VortexResult<TraversalOrder> {
            Ok(TraversalOrder::Skip)
        }

        fn visit_up(&mut self, _node: Self::NodeTy) -> VortexResult<TransformResult<Self::NodeTy>> {
            Ok(TransformResult::yes(Identity::new_expr()))
        }
    }

    /// Builds `(col1 == col2) AND (col3 != col4)`.
    fn eq_and_not_eq() -> ExprRef {
        let eq: Arc<dyn VortexExpr> =
            BinaryExpr::new_expr(col("col1"), Operator::Eq, col("col2"));
        let not_eq: Arc<dyn VortexExpr> =
            BinaryExpr::new_expr(col("col3"), Operator::NotEq, col("col4"));
        BinaryExpr::new_expr(eq, Operator::And, not_eq)
    }

    #[test]
    fn expr_deep_visitor_test() {
        // (col1 == 1) AND 2 contains exactly two literals.
        let inner = BinaryExpr::new_expr(col("col1"), Operator::Eq, Literal::new_expr(1));
        let expr = BinaryExpr::new_expr(inner, Operator::And, Literal::new_expr(2));

        let mut collector = ExprLitCollector::default();
        expr.accept(&mut collector).unwrap();
        assert_eq!(collector.0.len(), 2);
    }

    #[test]
    fn expr_deep_mut_visitor_test() {
        // (col1 == col2) AND 2: both columns become literals, joining the existing one.
        let inner = BinaryExpr::new_expr(col("col1"), Operator::Eq, col("col2"));
        let expr = BinaryExpr::new_expr(inner, Operator::And, Literal::new_expr(2));

        let mut rewriter = ExprColToLit::default();
        let transformed = expr.transform(&mut rewriter).unwrap();
        assert!(transformed.changed);

        let rewritten = transformed.result;

        let mut collector = ExprLitCollector::default();
        rewritten.accept(&mut collector).unwrap();
        assert_eq!(collector.0.len(), 3);
    }

    #[test]
    fn expr_skip_test() {
        let expr = eq_and_not_eq();

        let mut visited = Vec::new();
        expr.accept(&mut pre_order_visit_down(|node: &ExprRef| {
            if node.as_any().is::<GetItem>() {
                visited.push(node);
            }
            // Prune the `==` subtree; everything else is walked normally.
            let is_eq = node
                .as_any()
                .downcast_ref::<BinaryExpr>()
                .is_some_and(|bin| bin.op() == Operator::Eq);
            Ok(if is_eq {
                TraversalOrder::Skip
            } else {
                TraversalOrder::Continue
            })
        }))
        .unwrap();

        assert_eq!(
            visited
                .into_iter()
                .map(|node| node.references())
                .fold(HashSet::new(), |acc, refs| acc.union(&refs).cloned().collect()),
            HashSet::from_iter([FieldName::from("col3"), FieldName::from("col4")])
        );
    }

    #[test]
    fn expr_stop_test() {
        let expr = eq_and_not_eq();

        let mut visited = Vec::new();
        expr.accept(&mut pre_order_visit_down(|node: &ExprRef| {
            if node.as_any().is::<GetItem>() {
                visited.push(node);
            }
            // Stopping at the `==` node ends the walk before any column is reached.
            let is_eq = node
                .as_any()
                .downcast_ref::<BinaryExpr>()
                .is_some_and(|bin| bin.op() == Operator::Eq);
            Ok(if is_eq {
                TraversalOrder::Stop
            } else {
                TraversalOrder::Continue
            })
        }))
        .unwrap();

        assert_eq!(
            visited
                .into_iter()
                .map(|node| node.references())
                .fold(HashSet::new(), |acc, refs| acc.union(&refs).cloned().collect()),
            HashSet::from_iter(vec![])
        );
    }

    #[test]
    fn expr_skip_down_visit_up() {
        let column = col("col");

        let mut visitor = SkipDownVisitor;
        let outcome = column.transform(&mut visitor).unwrap();

        assert!(outcome.changed);
        assert!(outcome.result.as_any().is::<Identity>());
    }
}