vortex_expr/traversal/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
//! DataFusion-inspired tree traversal logic.
//!
//! Users will typically implement [`Node`] and, where it helps, [`NodeContainer`].
7
8mod fold;
9mod references;
10mod visitor;
11
12use std::marker::PhantomData;
13use std::sync::Arc;
14
15pub use fold::{FoldDown, FoldDownContext, FoldUp, NodeFolder, NodeFolderContext};
16use itertools::Itertools;
17pub use references::ReferenceCollector;
18pub use visitor::{pre_order_visit_down, pre_order_visit_up};
19use vortex_error::VortexResult;
20
21use crate::ExprRef;
22use crate::traversal::fold::NodeFolderContextWrapper;
23
/// Signal returned by visitors and rewriters to control how a traversal proceeds.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TraversalOrder {
    /// Prune part of the traversal.
    ///
    /// Returned on the way down (top-down phase), the children of the current node are not
    /// visited. Returned on the way up (bottom-up phase), the next step is skipped: depending on
    /// where the traversal is, that means skipping a node's children, moving on to the next
    /// sibling, or skipping the parent's upward visit once its children have been traversed.
    Skip,
    /// Abort the traversal: no further nodes are visited.
    Stop,
    /// Carry on with the traversal as usual.
    Continue,
}
36
37impl TraversalOrder {
38    /// If directed to, continue to visit nodes by running `f`, which should apply on the node's children.
39    pub fn visit_children<F: FnOnce() -> VortexResult<TraversalOrder>>(
40        self,
41        f: F,
42    ) -> VortexResult<TraversalOrder> {
43        match self {
44            Self::Skip => Ok(TraversalOrder::Continue),
45            Self::Stop => Ok(self),
46            Self::Continue => f(),
47        }
48    }
49
50    /// If directed to, continue to visit nodes by running `f`, which should apply on the node's parent.
51    pub fn visit_parent<F: FnOnce() -> VortexResult<TraversalOrder>>(
52        self,
53        f: F,
54    ) -> VortexResult<TraversalOrder> {
55        match self {
56            Self::Continue => f(),
57            Self::Skip | Self::Stop => Ok(self),
58        }
59    }
60}
61
62#[derive(Debug, Clone)]
63pub struct Transformed<T> {
64    /// Value that was being rewritten.
65    pub value: T,
66    /// Controls the flow of rewriting, see [`TraversalOrder`] for more details.
67    pub order: TraversalOrder,
68    /// Was the value changed during rewriting.
69    pub changed: bool,
70}
71
72impl<T> Transformed<T> {
73    pub fn yes(value: T) -> Self {
74        Self {
75            value,
76            order: TraversalOrder::Continue,
77            changed: true,
78        }
79    }
80
81    pub fn no(value: T) -> Self {
82        Self {
83            value,
84            order: TraversalOrder::Continue,
85            changed: false,
86        }
87    }
88
89    pub fn into_inner(self) -> T {
90        self.value
91    }
92
93    /// Apply a function to `value`, changing it without changing the `changed` field.
94    pub fn map<O, F: FnOnce(T) -> O>(self, f: F) -> Transformed<O> {
95        Transformed {
96            value: f(self.value),
97            order: self.order,
98            changed: self.changed,
99        }
100    }
101}
102
103pub trait NodeVisitor<'a> {
104    type NodeTy: Node;
105
106    fn visit_down(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
107        _ = node;
108        Ok(TraversalOrder::Continue)
109    }
110
111    fn visit_up(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
112        _ = node;
113        Ok(TraversalOrder::Continue)
114    }
115}
116
117pub trait NodeRewriter: Sized {
118    type NodeTy: Node;
119
120    fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
121        Ok(Transformed::no(node))
122    }
123
124    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
125        Ok(Transformed::no(node))
126    }
127}
128
129pub trait Node: Sized + Clone {
130    /// Walk the node's children by applying `f` to them.
131    ///
132    /// This is a lower level API that other functions rely on for their implementation.
133    fn apply_children<'a, F: FnMut(&'a Self) -> VortexResult<TraversalOrder>>(
134        &'a self,
135        f: F,
136    ) -> VortexResult<TraversalOrder>;
137
138    /// Rewrite the node's children by applying `f` to them.
139    ///
140    /// This is a lower level API that other functions rely on for their implementation.
141    fn map_children<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
142        self,
143        f: F,
144    ) -> VortexResult<Transformed<Self>>;
145
146    /// This is a lower level API that other functions rely on for their implementation.
147    fn iter_children<T>(&self, f: impl FnOnce(&mut dyn Iterator<Item = &Self>) -> T) -> T;
148
149    /// This is a lower level API that other functions rely on for their implementation.
150    fn children_count(&self) -> usize;
151}
152
153pub trait NodeExt: Node {
154    /// Walk the tree in pre-order (top-down) way, rewriting it as it goes.
155    fn rewrite<R: NodeRewriter<NodeTy = Self>>(
156        self,
157        rewriter: &mut R,
158    ) -> VortexResult<Transformed<Self>> {
159        let mut transformed = rewriter.visit_down(self)?;
160
161        let transformed = match transformed.order {
162            TraversalOrder::Stop => Ok(transformed),
163            TraversalOrder::Skip => {
164                transformed.order = TraversalOrder::Continue;
165                Ok(transformed)
166            }
167            TraversalOrder::Continue => transformed
168                .value
169                .map_children(|c| c.rewrite(rewriter))
170                .map(|mut t| {
171                    t.changed |= transformed.changed;
172                    t
173                }),
174        }?;
175
176        match transformed.order {
177            TraversalOrder::Stop | TraversalOrder::Skip => Ok(transformed),
178            TraversalOrder::Continue => {
179                let mut up_rewrite = rewriter.visit_up(transformed.value)?;
180                up_rewrite.changed |= transformed.changed;
181                Ok(up_rewrite)
182            }
183        }
184    }
185
186    /// A pre-order (top-down) traversal.
187    fn accept<'a, V: NodeVisitor<'a, NodeTy = Self>>(
188        &'a self,
189        visitor: &mut V,
190    ) -> VortexResult<TraversalOrder> {
191        visitor
192            .visit_down(self)?
193            .visit_children(|| self.apply_children(|c| c.accept(visitor)))?
194            .visit_parent(|| visitor.visit_up(self))
195    }
196
197    /// A pre-order transformation
198    fn transform_down<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
199        self,
200        f: F,
201    ) -> VortexResult<Transformed<Self>> {
202        let mut rewriter = FnRewriter {
203            f_down: Some(f),
204            f_up: None,
205            _data: PhantomData,
206        };
207
208        self.rewrite(&mut rewriter)
209    }
210
211    /// A post-order transform
212    fn transform_up<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
213        self,
214        f: F,
215    ) -> VortexResult<Transformed<Self>> {
216        let mut rewriter = FnRewriter {
217            f_down: None,
218            f_up: Some(f),
219            _data: PhantomData,
220        };
221
222        self.rewrite(&mut rewriter)
223    }
224
225    /// applies the `NodeFolderContext` to the Node tree, with an initial `Context`.
226    fn fold_context<R, F: NodeFolderContext<NodeTy = Self, Result = R>>(
227        self,
228        ctx: &F::Context,
229        folder: &mut F,
230    ) -> VortexResult<FoldUp<R>> {
231        let transformed = folder.visit_down(ctx, &self)?;
232        let ctx = match transformed {
233            FoldDownContext::Continue(ctx) => ctx,
234            FoldDownContext::Skip(r) => return Ok(FoldUp::Continue(r)),
235            FoldDownContext::Stop(r) => return Ok(FoldUp::Stop(r)),
236        };
237
238        let mut children = Vec::with_capacity(self.children_count());
239        let mut stop_result = None;
240        self.iter_children(|children_iter| -> VortexResult<()> {
241            for c in children_iter {
242                let t = c.clone().fold_context(&ctx, folder)?;
243                match t {
244                    FoldUp::Stop(r) => {
245                        stop_result = Some(r);
246                        return Ok(());
247                    }
248                    FoldUp::Continue(r) => {
249                        children.push(r);
250                    }
251                }
252            }
253            Ok(())
254        })?;
255
256        if let Some(result) = stop_result {
257            return Ok(FoldUp::Stop(result));
258        }
259
260        folder.visit_up(self, &ctx, children)
261    }
262
263    /// applies the `NodeFolder` to the Node tree
264    fn fold<R, F: NodeFolder<NodeTy = Self, Result = R>>(
265        self,
266        folder: &mut F,
267    ) -> VortexResult<FoldUp<R>> {
268        let mut folder = NodeFolderContextWrapper { inner: folder };
269        self.fold_context(&(), &mut folder)
270    }
271}
272
273impl<T: Node> NodeExt for T {}
274
/// Adapts plain closures into a [`NodeRewriter`]: `f_down` runs before a node's children are
/// rewritten and `f_up` afterwards. A missing closure leaves nodes unchanged.
struct FnRewriter<F, T> {
    /// Closure applied on the way down, if any.
    f_down: Option<F>,
    /// Closure applied on the way up, if any.
    f_up: Option<F>,
    /// Ties the node type `T` to the rewriter without storing a value of it.
    _data: PhantomData<T>,
}
280
281impl<F, T> NodeRewriter for FnRewriter<F, T>
282where
283    T: Node,
284    F: FnMut(T) -> VortexResult<Transformed<T>>,
285{
286    type NodeTy = T;
287
288    fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
289        if let Some(f) = self.f_down.as_mut() {
290            f(node)
291        } else {
292            Ok(Transformed::no(node))
293        }
294    }
295
296    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
297        if let Some(f) = self.f_up.as_mut() {
298            f(node)
299        } else {
300            Ok(Transformed::no(node))
301        }
302    }
303}
304
305/// A container holding a [`Node`]'s children, which a function can be applied (or mapped) to.
306///
307/// The trait is also implemented to container types in order to make implementing [`Node::map_children`]
308/// and [`Node::apply_children`] easier.
309pub trait NodeContainer<'a, T: 'a>: Sized {
310    /// Applies `f` to all elements of the container, accepting them by reference
311    fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
312        &'a self,
313        f: F,
314    ) -> VortexResult<TraversalOrder>;
315
316    /// Consumes all the children of the node, replacing them with the result of `f`.
317    fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
318        self,
319        f: F,
320    ) -> VortexResult<Transformed<Self>>;
321}
322
323pub trait NodeRefContainer<'a, T: 'a>: Sized {
324    fn apply_ref_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
325        &self,
326        f: F,
327    ) -> VortexResult<TraversalOrder>;
328}
329
330impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeRefContainer<'a, T> for Vec<&'a C> {
331    fn apply_ref_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
332        &self,
333        mut f: F,
334    ) -> VortexResult<TraversalOrder> {
335        let mut order = TraversalOrder::Continue;
336
337        for c in self {
338            order = c.apply_elements(&mut f)?;
339            match order {
340                TraversalOrder::Continue | TraversalOrder::Skip => {}
341                TraversalOrder::Stop => return Ok(TraversalOrder::Stop),
342            }
343        }
344
345        Ok(order)
346    }
347}
348
349impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeContainer<'a, T> for Box<C> {
350    fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
351        &'a self,
352        f: F,
353    ) -> VortexResult<TraversalOrder> {
354        self.as_ref().apply_elements(f)
355    }
356
357    fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
358        self,
359        f: F,
360    ) -> VortexResult<Transformed<Box<C>>> {
361        Ok((*self).map_elements(f)?.map(Box::new))
362    }
363}
364
365impl<'a, T, C> NodeContainer<'a, T> for Arc<C>
366where
367    T: 'a,
368    C: NodeContainer<'a, T> + Clone,
369{
370    fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
371        &'a self,
372        f: F,
373    ) -> VortexResult<TraversalOrder> {
374        self.as_ref().apply_elements(f)
375    }
376
377    fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
378        self,
379        f: F,
380    ) -> VortexResult<Transformed<Arc<C>>> {
381        Ok(Arc::unwrap_or_clone(self).map_elements(f)?.map(Arc::new))
382    }
383}
384
385impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeContainer<'a, T> for [C; 2] {
386    fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
387        &'a self,
388        mut f: F,
389    ) -> VortexResult<TraversalOrder> {
390        let [lhs, rhs] = self;
391        match lhs.apply_elements(&mut f)? {
392            TraversalOrder::Skip | TraversalOrder::Continue => rhs.apply_elements(&mut f),
393            TraversalOrder::Stop => Ok(TraversalOrder::Stop),
394        }
395    }
396
397    fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
398        self,
399        mut f: F,
400    ) -> VortexResult<Transformed<[C; 2]>> {
401        let [lhs, rhs] = self;
402        let transformed = lhs.map_elements(&mut f)?;
403        match transformed.order {
404            TraversalOrder::Skip | TraversalOrder::Continue => {
405                let mut t = rhs.map_elements(&mut f)?;
406                t.changed |= transformed.changed;
407                Ok(t.map(|new_lhs| [new_lhs, transformed.value]))
408            }
409            TraversalOrder::Stop => Ok(transformed.map(|new_lhs| [new_lhs, rhs])),
410        }
411    }
412}
413
414impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeContainer<'a, T> for Vec<C> {
415    fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
416        &'a self,
417        mut f: F,
418    ) -> VortexResult<TraversalOrder> {
419        let mut order = TraversalOrder::Continue;
420
421        for c in self {
422            order = c.apply_elements(&mut f)?;
423            match order {
424                TraversalOrder::Continue | TraversalOrder::Skip => {}
425                TraversalOrder::Stop => return Ok(TraversalOrder::Stop),
426            }
427        }
428
429        Ok(order)
430    }
431
432    fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
433        self,
434        mut f: F,
435    ) -> VortexResult<Transformed<Self>> {
436        let mut order = TraversalOrder::Continue;
437        let mut changed = false;
438
439        let value = self
440            .into_iter()
441            .map(|c| match order {
442                TraversalOrder::Continue | TraversalOrder::Skip => {
443                    c.map_elements(&mut f).map(|result| {
444                        order = result.order;
445                        changed |= result.changed;
446                        result.value
447                    })
448                }
449                TraversalOrder::Stop => Ok(c),
450            })
451            .collect::<VortexResult<Vec<_>>>()?;
452
453        Ok(Transformed {
454            value,
455            order,
456            changed,
457        })
458    }
459}
460
461impl<'a> NodeContainer<'a, Self> for ExprRef {
462    fn apply_elements<F: FnMut(&'a Self) -> VortexResult<TraversalOrder>>(
463        &'a self,
464        mut f: F,
465    ) -> VortexResult<TraversalOrder> {
466        f(self)
467    }
468
469    fn map_elements<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
470        self,
471        mut f: F,
472    ) -> VortexResult<Transformed<Self>> {
473        f(self)
474    }
475}
476
477impl Node for ExprRef {
478    fn apply_children<'a, F: FnMut(&'a Self) -> VortexResult<TraversalOrder>>(
479        &'a self,
480        mut f: F,
481    ) -> VortexResult<TraversalOrder> {
482        self.children().apply_ref_elements(&mut f)
483    }
484
485    fn map_children<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
486        self,
487        f: F,
488    ) -> VortexResult<Transformed<Self>> {
489        let transformed = self
490            .children()
491            .into_iter()
492            .cloned()
493            .collect_vec()
494            .map_elements(f)?;
495
496        if transformed.changed {
497            Ok(Transformed {
498                value: self.with_children(transformed.value)?,
499                order: transformed.order,
500                changed: true,
501            })
502        } else {
503            Ok(Transformed::no(self))
504        }
505    }
506
507    fn iter_children<T>(&self, f: impl FnOnce(&mut dyn Iterator<Item = &Self>) -> T) -> T {
508        f(&mut self.children().into_iter())
509    }
510
511    fn children_count(&self) -> usize {
512        self.children().len()
513    }
514}
515
516#[cfg(test)]
517mod tests {
518    use std::sync::Arc;
519
520    use vortex_error::VortexResult;
521    use vortex_utils::aliases::hash_set::HashSet;
522
523    use crate::traversal::visitor::pre_order_visit_down;
524    use crate::traversal::{NodeExt, NodeRewriter, NodeVisitor, Transformed, TraversalOrder};
525    use crate::{
526        BinaryExpr, BinaryVTable, ExprRef, GetItemVTable, IntoExpr, LiteralExpr, LiteralVTable,
527        Operator, VortexExpr, col, is_root, root,
528    };
529
530    #[derive(Default)]
531    pub struct ExprLitCollector<'a>(pub Vec<&'a ExprRef>);
532
533    impl<'a> NodeVisitor<'a> for ExprLitCollector<'a> {
534        type NodeTy = ExprRef;
535
536        fn visit_down(&mut self, node: &'a ExprRef) -> VortexResult<TraversalOrder> {
537            if node.is::<LiteralVTable>() {
538                self.0.push(node)
539            }
540            Ok(TraversalOrder::Continue)
541        }
542
543        fn visit_up(&mut self, _node: &'a ExprRef) -> VortexResult<TraversalOrder> {
544            Ok(TraversalOrder::Continue)
545        }
546    }
547
548    fn expr_col_to_lit_transform(
549        node: ExprRef,
550        idx: &mut i32,
551    ) -> VortexResult<Transformed<ExprRef>> {
552        if node.is::<GetItemVTable>() {
553            let lit_id = *idx;
554            *idx += 1;
555            Ok(Transformed::yes(LiteralExpr::new_expr(lit_id)))
556        } else {
557            Ok(Transformed::no(node))
558        }
559    }
560
561    #[derive(Default)]
562    pub struct SkipDownRewriter;
563
564    impl NodeRewriter for SkipDownRewriter {
565        type NodeTy = ExprRef;
566
567        fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
568            Ok(Transformed {
569                value: node,
570                order: TraversalOrder::Skip,
571                changed: false,
572            })
573        }
574
575        fn visit_up(&mut self, _node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
576            Ok(Transformed::yes(root()))
577        }
578    }
579
580    #[test]
581    fn expr_deep_visitor_test() {
582        let col1: Arc<dyn VortexExpr> = col("col1");
583        let lit1 = LiteralExpr::new(1).into_expr();
584        let expr = BinaryExpr::new(col1.clone(), Operator::Eq, lit1.clone()).into_expr();
585        let lit2 = LiteralExpr::new(2).into_expr();
586        let expr = BinaryExpr::new(expr, Operator::And, lit2).into_expr();
587        let mut printer = ExprLitCollector::default();
588        expr.accept(&mut printer).unwrap();
589        assert_eq!(printer.0.len(), 2);
590    }
591
592    #[test]
593    fn expr_deep_mut_visitor_test() {
594        let col1: Arc<dyn VortexExpr> = col("col1");
595        let col2: Arc<dyn VortexExpr> = col("col2");
596        let expr = BinaryExpr::new_expr(col1.clone(), Operator::Eq, col2.clone());
597        let lit2 = LiteralExpr::new_expr(2);
598        let expr = BinaryExpr::new_expr(expr, Operator::And, lit2);
599
600        let mut idx = 0_i32;
601        let new = expr
602            .transform_up(|node| expr_col_to_lit_transform(node, &mut idx))
603            .unwrap();
604        assert!(new.changed);
605
606        let expr = new.value;
607
608        let mut printer = ExprLitCollector::default();
609        expr.accept(&mut printer).unwrap();
610        assert_eq!(printer.0.len(), 3);
611    }
612
613    #[test]
614    fn expr_skip_test() {
615        let col1: Arc<dyn VortexExpr> = col("col1");
616        let col2: Arc<dyn VortexExpr> = col("col2");
617        let expr1 = BinaryExpr::new_expr(col1.clone(), Operator::Eq, col2.clone());
618        let col3: Arc<dyn VortexExpr> = col("col3");
619        let col4: Arc<dyn VortexExpr> = col("col4");
620        let expr2 = BinaryExpr::new_expr(col3.clone(), Operator::NotEq, col4.clone());
621        let expr = BinaryExpr::new_expr(expr1, Operator::And, expr2);
622
623        let mut nodes = Vec::new();
624        pre_order_visit_down(&expr, |node: &ExprRef| {
625            if node.is::<GetItemVTable>() {
626                nodes.push(node)
627            }
628            if let Some(bin) = node.as_opt::<BinaryVTable>()
629                && bin.op() == Operator::Eq
630            {
631                return Ok(TraversalOrder::Skip);
632            }
633            Ok(TraversalOrder::Continue)
634        })
635        .unwrap();
636
637        let nodes: HashSet<ExprRef> = HashSet::from_iter(nodes.into_iter().cloned());
638        assert_eq!(nodes, HashSet::from_iter([col("col3"), col("col4")]));
639    }
640
641    #[test]
642    fn expr_stop_test() {
643        let col1: Arc<dyn VortexExpr> = col("col1");
644        let col2: Arc<dyn VortexExpr> = col("col2");
645        let expr1 = BinaryExpr::new_expr(col1.clone(), Operator::Eq, col2.clone());
646        let col3: Arc<dyn VortexExpr> = col("col3");
647        let col4: Arc<dyn VortexExpr> = col("col4");
648        let expr2 = BinaryExpr::new_expr(col3.clone(), Operator::NotEq, col4.clone());
649        let expr = BinaryExpr::new_expr(expr1, Operator::And, expr2);
650
651        let mut nodes = Vec::new();
652        pre_order_visit_down(&expr, |node: &ExprRef| {
653            if node.is::<GetItemVTable>() {
654                nodes.push(node)
655            }
656            if let Some(bin) = node.as_opt::<BinaryVTable>()
657                && bin.op() == Operator::Eq
658            {
659                return Ok(TraversalOrder::Stop);
660            }
661            Ok(TraversalOrder::Continue)
662        })
663        .unwrap();
664
665        assert!(nodes.is_empty());
666    }
667
668    #[test]
669    fn expr_skip_down_visit_up() {
670        let col = col("col");
671
672        let mut visitor = SkipDownRewriter;
673        let result = col.rewrite(&mut visitor).unwrap();
674
675        assert!(result.changed);
676        assert!(is_root(&result.value));
677    }
678}