vortex_expr/traversal/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
//! DataFusion-inspired tree traversal logic.
//!
//! Users will typically implement [`Node`], and optionally [`NodeContainer`].
7
8mod references;
9mod visitor;
10
11use std::marker::PhantomData;
12use std::sync::Arc;
13
14use itertools::Itertools;
15pub use references::ReferenceCollector;
16pub use visitor::{pre_order_visit_down, pre_order_visit_up};
17use vortex_error::VortexResult;
18
19use crate::ExprRef;
20
21/// Signal to control a traversal's flow
22#[derive(Debug, Clone, PartialEq, Eq)]
23pub enum TraversalOrder {
24    /// In a top-down traversal, skip visiting the children of the current node.
25    /// In the bottom-up phase of the traversal, skip the next step. Either skipping the children of the node,
26    /// moving to its next sibling, or skipping its parent once the children are traversed.
27    Skip,
28    /// Stop visiting any more nodes in the traversal.
29    Stop,
30    /// Continue with the traversal as expected.
31    Continue,
32}
33
34impl TraversalOrder {
35    /// If directed to, continue to visit nodes by running `f`, which should apply on the node's children.
36    pub fn visit_children<F: FnOnce() -> VortexResult<TraversalOrder>>(
37        self,
38        f: F,
39    ) -> VortexResult<TraversalOrder> {
40        match self {
41            Self::Skip => Ok(TraversalOrder::Continue),
42            Self::Stop => Ok(self),
43            Self::Continue => f(),
44        }
45    }
46
47    /// If directed to, continue to visit nodes by running `f`, which should apply on the node's parent.
48    pub fn visit_parent<F: FnOnce() -> VortexResult<TraversalOrder>>(
49        self,
50        f: F,
51    ) -> VortexResult<TraversalOrder> {
52        match self {
53            Self::Continue => f(),
54            Self::Skip | Self::Stop => Ok(self),
55        }
56    }
57}
58
59#[derive(Debug, Clone)]
60pub struct Transformed<T> {
61    /// Value that was being rewritten.
62    pub value: T,
63    /// Controls the flow of rewriting, see [`TraversalOrder`] for more details.
64    pub order: TraversalOrder,
65    /// Was the value changed during rewriting.
66    pub changed: bool,
67}
68
69impl<T> Transformed<T> {
70    pub fn yes(value: T) -> Self {
71        Self {
72            value,
73            order: TraversalOrder::Continue,
74            changed: true,
75        }
76    }
77
78    pub fn no(value: T) -> Self {
79        Self {
80            value,
81            order: TraversalOrder::Continue,
82            changed: false,
83        }
84    }
85
86    pub fn into_inner(self) -> T {
87        self.value
88    }
89
90    /// Apply a function to `value`, changing it without changing the `changed` field.
91    pub fn map<O, F: FnOnce(T) -> O>(self, f: F) -> Transformed<O> {
92        Transformed {
93            value: f(self.value),
94            order: self.order,
95            changed: self.changed,
96        }
97    }
98}
99
100pub trait NodeVisitor<'a> {
101    type NodeTy: Node;
102
103    fn visit_down(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
104        _ = node;
105        Ok(TraversalOrder::Continue)
106    }
107
108    fn visit_up(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
109        _ = node;
110        Ok(TraversalOrder::Continue)
111    }
112}
113
114pub trait NodeRewriter: Sized {
115    type NodeTy: Node;
116
117    fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
118        Ok(Transformed::no(node))
119    }
120
121    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
122        Ok(Transformed::no(node))
123    }
124}
125
126pub trait Node: Sized + Clone {
127    /// Walk the node's children by applying `f` to them.
128    ///
129    /// This is a lower level API that other functions rely on for correctness.
130    fn apply_children<'a, F: FnMut(&'a Self) -> VortexResult<TraversalOrder>>(
131        &'a self,
132        f: F,
133    ) -> VortexResult<TraversalOrder>;
134
135    /// Rewrite the node's children by applying `f` to them.
136    ///
137    /// This is a lower level API that other functions rely on for correctness.
138    fn map_children<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
139        self,
140        f: F,
141    ) -> VortexResult<Transformed<Self>>;
142
143    /// Walk the tree in pre-order (top-down) way, rewriting it as it goes.
144    fn rewrite<R: NodeRewriter<NodeTy = Self>>(
145        self,
146        rewriter: &mut R,
147    ) -> VortexResult<Transformed<Self>> {
148        let mut transformed = rewriter.visit_down(self)?;
149
150        let transformed = match transformed.order {
151            TraversalOrder::Stop => Ok(transformed),
152            TraversalOrder::Skip => {
153                transformed.order = TraversalOrder::Continue;
154                Ok(transformed)
155            }
156            TraversalOrder::Continue => transformed
157                .value
158                .map_children(|c| c.rewrite(rewriter))
159                .map(|mut t| {
160                    t.changed |= transformed.changed;
161                    t
162                }),
163        }?;
164
165        match transformed.order {
166            TraversalOrder::Stop | TraversalOrder::Skip => Ok(transformed),
167            TraversalOrder::Continue => {
168                let mut up_rewrite = rewriter.visit_up(transformed.value)?;
169                up_rewrite.changed |= transformed.changed;
170                Ok(up_rewrite)
171            }
172        }
173    }
174
175    /// A pre-order (top-down) traversal.
176    fn accept<'a, V: NodeVisitor<'a, NodeTy = Self>>(
177        &'a self,
178        visitor: &mut V,
179    ) -> VortexResult<TraversalOrder> {
180        visitor
181            .visit_down(self)?
182            .visit_children(|| self.apply_children(|c| c.accept(visitor)))?
183            .visit_parent(|| visitor.visit_up(self))
184    }
185
186    /// A pre-order transformation
187    fn transform_down<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
188        self,
189        f: F,
190    ) -> VortexResult<Transformed<Self>> {
191        let mut rewriter = FnRewriter {
192            f_down: Some(f),
193            f_up: None,
194            _data: PhantomData,
195        };
196
197        self.rewrite(&mut rewriter)
198    }
199
200    /// A post-order transform
201    fn transform_up<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
202        self,
203        f: F,
204    ) -> VortexResult<Transformed<Self>> {
205        let mut rewriter = FnRewriter {
206            f_down: None,
207            f_up: Some(f),
208            _data: PhantomData,
209        };
210
211        self.rewrite(&mut rewriter)
212    }
213}
214
215struct FnRewriter<F, T> {
216    f_down: Option<F>,
217    f_up: Option<F>,
218    _data: PhantomData<T>,
219}
220
221impl<F, T> NodeRewriter for FnRewriter<F, T>
222where
223    T: Node,
224    F: FnMut(T) -> VortexResult<Transformed<T>>,
225{
226    type NodeTy = T;
227
228    fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
229        if let Some(f) = self.f_down.as_mut() {
230            f(node)
231        } else {
232            Ok(Transformed::no(node))
233        }
234    }
235
236    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
237        if let Some(f) = self.f_up.as_mut() {
238            f(node)
239        } else {
240            Ok(Transformed::no(node))
241        }
242    }
243}
244
/// A container holding a [`Node`]'s children, which a function can be applied (or mapped) to.
///
/// The trait is also implemented for container types (see the impls in this module) in order
/// to make implementing [`Node::map_children`] and [`Node::apply_children`] easier.
pub trait NodeContainer<'a, T: 'a>: Sized {
    /// Applies `f` to all elements of the container, accepting them by reference.
    ///
    /// The references handed to `f` live for `'a`, the lifetime of the borrow of the container.
    fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
        &'a self,
        f: F,
    ) -> VortexResult<TraversalOrder>;

    /// Consumes all the elements of the container, replacing them with the result of `f`.
    ///
    /// The rebuilt container is returned wrapped in a [`Transformed`].
    fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
        self,
        f: F,
    ) -> VortexResult<Transformed<Self>>;
}
262
/// By-reference counterpart of [`NodeContainer`] for containers of borrowed elements.
///
/// The container itself may be borrowed for any lifetime, while the element references
/// passed to `f` live for `'a`.
pub trait NodeRefContainer<'a, T: 'a>: Sized {
    /// Applies `f` to the elements reachable through the container, by reference.
    fn apply_ref_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
        &self,
        f: F,
    ) -> VortexResult<TraversalOrder>;
}
269
270impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeRefContainer<'a, T> for Vec<&'a C> {
271    fn apply_ref_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
272        &self,
273        mut f: F,
274    ) -> VortexResult<TraversalOrder> {
275        let mut order = TraversalOrder::Continue;
276
277        for c in self {
278            order = c.apply_elements(&mut f)?;
279            match order {
280                TraversalOrder::Continue | TraversalOrder::Skip => {}
281                TraversalOrder::Stop => return Ok(TraversalOrder::Stop),
282            }
283        }
284
285        Ok(order)
286    }
287}
288
289impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeContainer<'a, T> for Box<C> {
290    fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
291        &'a self,
292        f: F,
293    ) -> VortexResult<TraversalOrder> {
294        self.as_ref().apply_elements(f)
295    }
296
297    fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
298        self,
299        f: F,
300    ) -> VortexResult<Transformed<Box<C>>> {
301        Ok((*self).map_elements(f)?.map(Box::new))
302    }
303}
304
305impl<'a, T, C> NodeContainer<'a, T> for Arc<C>
306where
307    T: 'a,
308    C: NodeContainer<'a, T> + Clone,
309{
310    fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
311        &'a self,
312        f: F,
313    ) -> VortexResult<TraversalOrder> {
314        self.as_ref().apply_elements(f)
315    }
316
317    fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
318        self,
319        f: F,
320    ) -> VortexResult<Transformed<Arc<C>>> {
321        Ok(Arc::unwrap_or_clone(self).map_elements(f)?.map(Arc::new))
322    }
323}
324
325impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeContainer<'a, T> for [C; 2] {
326    fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
327        &'a self,
328        mut f: F,
329    ) -> VortexResult<TraversalOrder> {
330        let [lhs, rhs] = self;
331        match lhs.apply_elements(&mut f)? {
332            TraversalOrder::Skip | TraversalOrder::Continue => rhs.apply_elements(&mut f),
333            TraversalOrder::Stop => Ok(TraversalOrder::Stop),
334        }
335    }
336
337    fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
338        self,
339        mut f: F,
340    ) -> VortexResult<Transformed<[C; 2]>> {
341        let [lhs, rhs] = self;
342        let transformed = lhs.map_elements(&mut f)?;
343        match transformed.order {
344            TraversalOrder::Skip | TraversalOrder::Continue => {
345                let mut t = rhs.map_elements(&mut f)?;
346                t.changed |= transformed.changed;
347                Ok(t.map(|new_lhs| [new_lhs, transformed.value]))
348            }
349            TraversalOrder::Stop => Ok(transformed.map(|new_lhs| [new_lhs, rhs])),
350        }
351    }
352}
353
354impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeContainer<'a, T> for Vec<C> {
355    fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
356        &'a self,
357        mut f: F,
358    ) -> VortexResult<TraversalOrder> {
359        let mut order = TraversalOrder::Continue;
360
361        for c in self {
362            order = c.apply_elements(&mut f)?;
363            match order {
364                TraversalOrder::Continue | TraversalOrder::Skip => {}
365                TraversalOrder::Stop => return Ok(TraversalOrder::Stop),
366            }
367        }
368
369        Ok(order)
370    }
371
372    fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
373        self,
374        mut f: F,
375    ) -> VortexResult<Transformed<Self>> {
376        let mut order = TraversalOrder::Continue;
377        let mut changed = false;
378
379        let value = self
380            .into_iter()
381            .map(|c| match order {
382                TraversalOrder::Continue | TraversalOrder::Skip => {
383                    c.map_elements(&mut f).map(|result| {
384                        order = result.order;
385                        changed |= result.changed;
386                        result.value
387                    })
388                }
389                TraversalOrder::Stop => Ok(c),
390            })
391            .collect::<VortexResult<Vec<_>>>()?;
392
393        Ok(Transformed {
394            value,
395            order,
396            changed,
397        })
398    }
399}
400
401impl<'a> NodeContainer<'a, Self> for ExprRef {
402    fn apply_elements<F: FnMut(&'a Self) -> VortexResult<TraversalOrder>>(
403        &'a self,
404        mut f: F,
405    ) -> VortexResult<TraversalOrder> {
406        f(self)
407    }
408
409    fn map_elements<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
410        self,
411        mut f: F,
412    ) -> VortexResult<Transformed<Self>> {
413        f(self)
414    }
415}
416
417impl Node for ExprRef {
418    fn apply_children<'a, F: FnMut(&'a Self) -> VortexResult<TraversalOrder>>(
419        &'a self,
420        mut f: F,
421    ) -> VortexResult<TraversalOrder> {
422        self.children().apply_ref_elements(&mut f)
423    }
424
425    fn map_children<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
426        self,
427        f: F,
428    ) -> VortexResult<Transformed<Self>> {
429        let transformed = self
430            .children()
431            .into_iter()
432            .cloned()
433            .collect_vec()
434            .map_elements(f)?;
435
436        if transformed.changed {
437            Ok(Transformed {
438                value: self.with_children(transformed.value)?,
439                order: transformed.order,
440                changed: true,
441            })
442        } else {
443            Ok(Transformed::no(self))
444        }
445    }
446}
447
// Unit tests for the traversal machinery, exercised on expression trees.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_error::VortexResult;
    use vortex_utils::aliases::hash_set::HashSet;

    use crate::traversal::visitor::pre_order_visit_down;
    use crate::traversal::{Node, NodeRewriter, NodeVisitor, Transformed, TraversalOrder};
    use crate::{
        BinaryExpr, BinaryVTable, ExprRef, GetItemVTable, IntoExpr, LiteralExpr, LiteralVTable,
        Operator, VortexExpr, col, is_root, root,
    };

    /// Visitor that collects, during the downward pass, every node for which
    /// `is::<LiteralVTable>()` holds.
    #[derive(Default)]
    pub struct ExprLitCollector<'a>(pub Vec<&'a ExprRef>);

    impl<'a> NodeVisitor<'a> for ExprLitCollector<'a> {
        type NodeTy = ExprRef;

        fn visit_down(&mut self, node: &'a ExprRef) -> VortexResult<TraversalOrder> {
            if node.is::<LiteralVTable>() {
                self.0.push(node)
            }
            Ok(TraversalOrder::Continue)
        }

        fn visit_up(&mut self, _node: &'a ExprRef) -> VortexResult<TraversalOrder> {
            Ok(TraversalOrder::Continue)
        }
    }

    /// Replaces a node for which `is::<GetItemVTable>()` holds with a literal carrying the
    /// current value of `idx`, then increments `idx`. Any other node is returned unchanged.
    fn expr_col_to_lit_transform(
        node: ExprRef,
        idx: &mut i32,
    ) -> VortexResult<Transformed<ExprRef>> {
        if node.is::<GetItemVTable>() {
            let lit_id = *idx;
            *idx += 1;
            Ok(Transformed::yes(LiteralExpr::new_expr(lit_id)))
        } else {
            Ok(Transformed::no(node))
        }
    }

    /// Rewriter that always asks to skip on the way down (without changing the node) and
    /// replaces whatever it sees on the way up with `root()`.
    #[derive(Default)]
    pub struct SkipDownRewriter;

    impl NodeRewriter for SkipDownRewriter {
        type NodeTy = ExprRef;

        fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
            Ok(Transformed {
                value: node,
                order: TraversalOrder::Skip,
                changed: false,
            })
        }

        fn visit_up(&mut self, _node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
            Ok(Transformed::yes(root()))
        }
    }

    // `accept` must reach literals at different depths: `(col1 == 1) AND 2` holds two of them.
    #[test]
    fn expr_deep_visitor_test() {
        let col1: Arc<dyn VortexExpr> = col("col1");
        let lit1 = LiteralExpr::new(1).into_expr();
        let expr = BinaryExpr::new(col1.clone(), Operator::Eq, lit1.clone()).into_expr();
        let lit2 = LiteralExpr::new(2).into_expr();
        let expr = BinaryExpr::new(expr, Operator::And, lit2).into_expr();
        let mut printer = ExprLitCollector::default();
        expr.accept(&mut printer).unwrap();
        assert_eq!(printer.0.len(), 2);
    }

    // `transform_up` over `(col1 == col2) AND 2`: after the transform the tree is expected to
    // hold three literals, i.e. the original one plus one per replaced column.
    #[test]
    fn expr_deep_mut_visitor_test() {
        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        let expr = BinaryExpr::new_expr(col1.clone(), Operator::Eq, col2.clone());
        let lit2 = LiteralExpr::new_expr(2);
        let expr = BinaryExpr::new_expr(expr, Operator::And, lit2);

        let mut idx = 0_i32;
        let new = expr
            .transform_up(|node| expr_col_to_lit_transform(node, &mut idx))
            .unwrap();
        assert!(new.changed);

        let expr = new.value;

        let mut printer = ExprLitCollector::default();
        expr.accept(&mut printer).unwrap();
        assert_eq!(printer.0.len(), 3);
    }

    // Returning `Skip` at the `Eq` node must hide its children (col1, col2) from the callback
    // while the sibling `NotEq` subtree (col3, col4) is still visited.
    #[test]
    fn expr_skip_test() {
        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        let expr1 = BinaryExpr::new_expr(col1.clone(), Operator::Eq, col2.clone());
        let col3: Arc<dyn VortexExpr> = col("col3");
        let col4: Arc<dyn VortexExpr> = col("col4");
        let expr2 = BinaryExpr::new_expr(col3.clone(), Operator::NotEq, col4.clone());
        let expr = BinaryExpr::new_expr(expr1, Operator::And, expr2);

        let mut nodes = Vec::new();
        pre_order_visit_down(&expr, |node: &ExprRef| {
            if node.is::<GetItemVTable>() {
                nodes.push(node)
            }
            if let Some(bin) = node.as_opt::<BinaryVTable>() {
                if bin.op() == Operator::Eq {
                    return Ok(TraversalOrder::Skip);
                }
            }
            Ok(TraversalOrder::Continue)
        })
        .unwrap();

        let nodes: HashSet<ExprRef> = HashSet::from_iter(nodes.into_iter().cloned());
        assert_eq!(nodes, HashSet::from_iter([col("col3"), col("col4")]));
    }

    // Returning `Stop` at the `Eq` node must end the whole traversal: no column is collected,
    // not even from the sibling `NotEq` subtree.
    #[test]
    fn expr_stop_test() {
        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        let expr1 = BinaryExpr::new_expr(col1.clone(), Operator::Eq, col2.clone());
        let col3: Arc<dyn VortexExpr> = col("col3");
        let col4: Arc<dyn VortexExpr> = col("col4");
        let expr2 = BinaryExpr::new_expr(col3.clone(), Operator::NotEq, col4.clone());
        let expr = BinaryExpr::new_expr(expr1, Operator::And, expr2);

        let mut nodes = Vec::new();
        pre_order_visit_down(&expr, |node: &ExprRef| {
            if node.is::<GetItemVTable>() {
                nodes.push(node)
            }
            if let Some(bin) = node.as_opt::<BinaryVTable>() {
                if bin.op() == Operator::Eq {
                    return Ok(TraversalOrder::Stop);
                }
            }
            Ok(TraversalOrder::Continue)
        })
        .unwrap();

        assert!(nodes.is_empty());
    }

    // A `Skip` returned from `visit_down` must not prevent `visit_up` from running on the
    // same node: the rewriter's `visit_up` replacement (`root()`) has to come out.
    #[test]
    fn expr_skip_down_visit_up() {
        let col = col("col");

        let mut visitor = SkipDownRewriter;
        let result = col.rewrite(&mut visitor).unwrap();

        assert!(result.changed);
        assert!(is_root(&result.value));
    }
}