vortex_expr/traversal/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
//! DataFusion-inspired tree traversal logic.
//!
//! Users will typically implement [`Node`] and, where helpful, [`NodeContainer`].
7
8mod fold;
9mod references;
10mod visitor;
11
12use std::marker::PhantomData;
13use std::sync::Arc;
14
15use itertools::Itertools;
16pub use references::ReferenceCollector;
17pub use visitor::{pre_order_visit_down, pre_order_visit_up};
18use vortex_error::VortexResult;
19
20use crate::ExprRef;
21use crate::traversal::fold::{
22    FoldDownContext, FoldUp, NodeFolder, NodeFolderContext, NodeFolderContextWrapper,
23};
24
/// Signal returned by visitors and rewriters to steer the remainder of a traversal.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TraversalOrder {
    /// Top-down phase: do not descend into the children of the current node.
    ///
    /// Bottom-up phase: skip the next step, i.e. the parent is not visited once its children
    /// have been traversed. Siblings that have not been visited yet are still visited.
    Skip,
    /// Abort the traversal: no further node is visited.
    Stop,
    /// Carry on with the traversal as usual.
    Continue,
}
37
38impl TraversalOrder {
39    /// If directed to, continue to visit nodes by running `f`, which should apply on the node's children.
40    pub fn visit_children<F: FnOnce() -> VortexResult<TraversalOrder>>(
41        self,
42        f: F,
43    ) -> VortexResult<TraversalOrder> {
44        match self {
45            Self::Skip => Ok(TraversalOrder::Continue),
46            Self::Stop => Ok(self),
47            Self::Continue => f(),
48        }
49    }
50
51    /// If directed to, continue to visit nodes by running `f`, which should apply on the node's parent.
52    pub fn visit_parent<F: FnOnce() -> VortexResult<TraversalOrder>>(
53        self,
54        f: F,
55    ) -> VortexResult<TraversalOrder> {
56        match self {
57            Self::Continue => f(),
58            Self::Skip | Self::Stop => Ok(self),
59        }
60    }
61}
62
63#[derive(Debug, Clone)]
64pub struct Transformed<T> {
65    /// Value that was being rewritten.
66    pub value: T,
67    /// Controls the flow of rewriting, see [`TraversalOrder`] for more details.
68    pub order: TraversalOrder,
69    /// Was the value changed during rewriting.
70    pub changed: bool,
71}
72
73impl<T> Transformed<T> {
74    pub fn yes(value: T) -> Self {
75        Self {
76            value,
77            order: TraversalOrder::Continue,
78            changed: true,
79        }
80    }
81
82    pub fn no(value: T) -> Self {
83        Self {
84            value,
85            order: TraversalOrder::Continue,
86            changed: false,
87        }
88    }
89
90    pub fn into_inner(self) -> T {
91        self.value
92    }
93
94    /// Apply a function to `value`, changing it without changing the `changed` field.
95    pub fn map<O, F: FnOnce(T) -> O>(self, f: F) -> Transformed<O> {
96        Transformed {
97            value: f(self.value),
98            order: self.order,
99            changed: self.changed,
100        }
101    }
102}
103
104pub trait NodeVisitor<'a> {
105    type NodeTy: Node;
106
107    fn visit_down(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
108        _ = node;
109        Ok(TraversalOrder::Continue)
110    }
111
112    fn visit_up(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
113        _ = node;
114        Ok(TraversalOrder::Continue)
115    }
116}
117
118pub trait NodeRewriter: Sized {
119    type NodeTy: Node;
120
121    fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
122        Ok(Transformed::no(node))
123    }
124
125    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
126        Ok(Transformed::no(node))
127    }
128}
129
130pub trait Node: Sized + Clone {
131    /// Walk the node's children by applying `f` to them.
132    ///
133    /// This is a lower level API that other functions rely on for their implementation.
134    fn apply_children<'a, F: FnMut(&'a Self) -> VortexResult<TraversalOrder>>(
135        &'a self,
136        f: F,
137    ) -> VortexResult<TraversalOrder>;
138
139    /// Rewrite the node's children by applying `f` to them.
140    ///
141    /// This is a lower level API that other functions rely on for their implementation.
142    fn map_children<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
143        self,
144        f: F,
145    ) -> VortexResult<Transformed<Self>>;
146
147    /// This is a lower level API that other functions rely on for their implementation.
148    fn iter_children<T>(&self, f: impl FnOnce(&mut dyn Iterator<Item = &Self>) -> T) -> T;
149
150    /// This is a lower level API that other functions rely on for their implementation.
151    fn children_count(&self) -> usize;
152}
153
154pub trait NodeExt: Node {
155    /// Walk the tree in pre-order (top-down) way, rewriting it as it goes.
156    fn rewrite<R: NodeRewriter<NodeTy = Self>>(
157        self,
158        rewriter: &mut R,
159    ) -> VortexResult<Transformed<Self>> {
160        let mut transformed = rewriter.visit_down(self)?;
161
162        let transformed = match transformed.order {
163            TraversalOrder::Stop => Ok(transformed),
164            TraversalOrder::Skip => {
165                transformed.order = TraversalOrder::Continue;
166                Ok(transformed)
167            }
168            TraversalOrder::Continue => transformed
169                .value
170                .map_children(|c| c.rewrite(rewriter))
171                .map(|mut t| {
172                    t.changed |= transformed.changed;
173                    t
174                }),
175        }?;
176
177        match transformed.order {
178            TraversalOrder::Stop | TraversalOrder::Skip => Ok(transformed),
179            TraversalOrder::Continue => {
180                let mut up_rewrite = rewriter.visit_up(transformed.value)?;
181                up_rewrite.changed |= transformed.changed;
182                Ok(up_rewrite)
183            }
184        }
185    }
186
187    /// A pre-order (top-down) traversal.
188    fn accept<'a, V: NodeVisitor<'a, NodeTy = Self>>(
189        &'a self,
190        visitor: &mut V,
191    ) -> VortexResult<TraversalOrder> {
192        visitor
193            .visit_down(self)?
194            .visit_children(|| self.apply_children(|c| c.accept(visitor)))?
195            .visit_parent(|| visitor.visit_up(self))
196    }
197
198    /// A pre-order transformation
199    fn transform_down<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
200        self,
201        f: F,
202    ) -> VortexResult<Transformed<Self>> {
203        let mut rewriter = FnRewriter {
204            f_down: Some(f),
205            f_up: None,
206            _data: PhantomData,
207        };
208
209        self.rewrite(&mut rewriter)
210    }
211
212    /// A post-order transform
213    fn transform_up<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
214        self,
215        f: F,
216    ) -> VortexResult<Transformed<Self>> {
217        let mut rewriter = FnRewriter {
218            f_down: None,
219            f_up: Some(f),
220            _data: PhantomData,
221        };
222
223        self.rewrite(&mut rewriter)
224    }
225
226    /// applies the `NodeFolderContext` to the Node tree, with an initial `Context`.
227    fn fold_context<R, F: NodeFolderContext<NodeTy = Self, Result = R>>(
228        self,
229        ctx: &F::Context,
230        folder: &mut F,
231    ) -> VortexResult<FoldUp<R>> {
232        let transformed = folder.visit_down(ctx, &self)?;
233        let ctx = match transformed {
234            FoldDownContext::Continue(ctx) => ctx,
235            FoldDownContext::Skip(r) => return Ok(FoldUp::Continue(r)),
236            FoldDownContext::Stop(r) => return Ok(FoldUp::Stop(r)),
237        };
238
239        let mut children = Vec::with_capacity(self.children_count());
240        let mut stop_result = None;
241        self.iter_children(|children_iter| -> VortexResult<()> {
242            for c in children_iter {
243                let t = c.clone().fold_context(&ctx, folder)?;
244                match t {
245                    FoldUp::Stop(r) => {
246                        stop_result = Some(r);
247                        return Ok(());
248                    }
249                    FoldUp::Continue(r) => {
250                        children.push(r);
251                    }
252                }
253            }
254            Ok(())
255        })?;
256
257        if let Some(result) = stop_result {
258            return Ok(FoldUp::Stop(result));
259        }
260
261        folder.visit_up(self, &ctx, children)
262    }
263
264    /// applies the `NodeFolder` to the Node tree
265    fn fold<R, F: NodeFolder<NodeTy = Self, Result = R>>(
266        self,
267        folder: &mut F,
268    ) -> VortexResult<FoldUp<R>> {
269        let mut folder = NodeFolderContextWrapper { inner: folder };
270        self.fold_context(&(), &mut folder)
271    }
272}
273
274impl<T: Node> NodeExt for T {}
275
/// Adapter that turns closures into a [`NodeRewriter`].
///
/// A hook left as `None` behaves like the trait default and returns the node unchanged.
struct FnRewriter<F, T> {
    /// Closure run on the way down, if any.
    f_down: Option<F>,
    /// Closure run on the way up, if any.
    f_up: Option<F>,
    /// Ties the adapter to the node type `T` without storing one.
    _data: PhantomData<T>,
}
281
282impl<F, T> NodeRewriter for FnRewriter<F, T>
283where
284    T: Node,
285    F: FnMut(T) -> VortexResult<Transformed<T>>,
286{
287    type NodeTy = T;
288
289    fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
290        if let Some(f) = self.f_down.as_mut() {
291            f(node)
292        } else {
293            Ok(Transformed::no(node))
294        }
295    }
296
297    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
298        if let Some(f) = self.f_up.as_mut() {
299            f(node)
300        } else {
301            Ok(Transformed::no(node))
302        }
303    }
304}
305
306/// A container holding a [`Node`]'s children, which a function can be applied (or mapped) to.
307///
308/// The trait is also implemented to container types in order to make implementing [`Node::map_children`]
309/// and [`Node::apply_children`] easier.
310pub trait NodeContainer<'a, T: 'a>: Sized {
311    /// Applies `f` to all elements of the container, accepting them by reference
312    fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
313        &'a self,
314        f: F,
315    ) -> VortexResult<TraversalOrder>;
316
317    /// Consumes all the children of the node, replacing them with the result of `f`.
318    fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
319        self,
320        f: F,
321    ) -> VortexResult<Transformed<Self>>;
322}
323
324pub trait NodeRefContainer<'a, T: 'a>: Sized {
325    fn apply_ref_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
326        &self,
327        f: F,
328    ) -> VortexResult<TraversalOrder>;
329}
330
331impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeRefContainer<'a, T> for Vec<&'a C> {
332    fn apply_ref_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
333        &self,
334        mut f: F,
335    ) -> VortexResult<TraversalOrder> {
336        let mut order = TraversalOrder::Continue;
337
338        for c in self {
339            order = c.apply_elements(&mut f)?;
340            match order {
341                TraversalOrder::Continue | TraversalOrder::Skip => {}
342                TraversalOrder::Stop => return Ok(TraversalOrder::Stop),
343            }
344        }
345
346        Ok(order)
347    }
348}
349
350impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeContainer<'a, T> for Box<C> {
351    fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
352        &'a self,
353        f: F,
354    ) -> VortexResult<TraversalOrder> {
355        self.as_ref().apply_elements(f)
356    }
357
358    fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
359        self,
360        f: F,
361    ) -> VortexResult<Transformed<Box<C>>> {
362        Ok((*self).map_elements(f)?.map(Box::new))
363    }
364}
365
366impl<'a, T, C> NodeContainer<'a, T> for Arc<C>
367where
368    T: 'a,
369    C: NodeContainer<'a, T> + Clone,
370{
371    fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
372        &'a self,
373        f: F,
374    ) -> VortexResult<TraversalOrder> {
375        self.as_ref().apply_elements(f)
376    }
377
378    fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
379        self,
380        f: F,
381    ) -> VortexResult<Transformed<Arc<C>>> {
382        Ok(Arc::unwrap_or_clone(self).map_elements(f)?.map(Arc::new))
383    }
384}
385
386impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeContainer<'a, T> for [C; 2] {
387    fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
388        &'a self,
389        mut f: F,
390    ) -> VortexResult<TraversalOrder> {
391        let [lhs, rhs] = self;
392        match lhs.apply_elements(&mut f)? {
393            TraversalOrder::Skip | TraversalOrder::Continue => rhs.apply_elements(&mut f),
394            TraversalOrder::Stop => Ok(TraversalOrder::Stop),
395        }
396    }
397
398    fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
399        self,
400        mut f: F,
401    ) -> VortexResult<Transformed<[C; 2]>> {
402        let [lhs, rhs] = self;
403        let transformed = lhs.map_elements(&mut f)?;
404        match transformed.order {
405            TraversalOrder::Skip | TraversalOrder::Continue => {
406                let mut t = rhs.map_elements(&mut f)?;
407                t.changed |= transformed.changed;
408                Ok(t.map(|new_lhs| [new_lhs, transformed.value]))
409            }
410            TraversalOrder::Stop => Ok(transformed.map(|new_lhs| [new_lhs, rhs])),
411        }
412    }
413}
414
415impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeContainer<'a, T> for Vec<C> {
416    fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
417        &'a self,
418        mut f: F,
419    ) -> VortexResult<TraversalOrder> {
420        let mut order = TraversalOrder::Continue;
421
422        for c in self {
423            order = c.apply_elements(&mut f)?;
424            match order {
425                TraversalOrder::Continue | TraversalOrder::Skip => {}
426                TraversalOrder::Stop => return Ok(TraversalOrder::Stop),
427            }
428        }
429
430        Ok(order)
431    }
432
433    fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
434        self,
435        mut f: F,
436    ) -> VortexResult<Transformed<Self>> {
437        let mut order = TraversalOrder::Continue;
438        let mut changed = false;
439
440        let value = self
441            .into_iter()
442            .map(|c| match order {
443                TraversalOrder::Continue | TraversalOrder::Skip => {
444                    c.map_elements(&mut f).map(|result| {
445                        order = result.order;
446                        changed |= result.changed;
447                        result.value
448                    })
449                }
450                TraversalOrder::Stop => Ok(c),
451            })
452            .collect::<VortexResult<Vec<_>>>()?;
453
454        Ok(Transformed {
455            value,
456            order,
457            changed,
458        })
459    }
460}
461
462impl<'a> NodeContainer<'a, Self> for ExprRef {
463    fn apply_elements<F: FnMut(&'a Self) -> VortexResult<TraversalOrder>>(
464        &'a self,
465        mut f: F,
466    ) -> VortexResult<TraversalOrder> {
467        f(self)
468    }
469
470    fn map_elements<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
471        self,
472        mut f: F,
473    ) -> VortexResult<Transformed<Self>> {
474        f(self)
475    }
476}
477
478impl Node for ExprRef {
479    fn apply_children<'a, F: FnMut(&'a Self) -> VortexResult<TraversalOrder>>(
480        &'a self,
481        mut f: F,
482    ) -> VortexResult<TraversalOrder> {
483        self.children().apply_ref_elements(&mut f)
484    }
485
486    fn map_children<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
487        self,
488        f: F,
489    ) -> VortexResult<Transformed<Self>> {
490        let transformed = self
491            .children()
492            .into_iter()
493            .cloned()
494            .collect_vec()
495            .map_elements(f)?;
496
497        if transformed.changed {
498            Ok(Transformed {
499                value: self.with_children(transformed.value)?,
500                order: transformed.order,
501                changed: true,
502            })
503        } else {
504            Ok(Transformed::no(self))
505        }
506    }
507
508    fn iter_children<T>(&self, f: impl FnOnce(&mut dyn Iterator<Item = &Self>) -> T) -> T {
509        f(&mut self.children().into_iter())
510    }
511
512    fn children_count(&self) -> usize {
513        self.children().len()
514    }
515}
516
// Tests for the traversal helpers, exercised on expression trees.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_error::VortexResult;
    use vortex_utils::aliases::hash_set::HashSet;

    use crate::traversal::visitor::pre_order_visit_down;
    use crate::traversal::{NodeExt, NodeRewriter, NodeVisitor, Transformed, TraversalOrder};
    use crate::{
        BinaryExpr, BinaryVTable, ExprRef, GetItemVTable, IntoExpr, LiteralExpr, LiteralVTable,
        Operator, VortexExpr, col, is_root, root,
    };

    /// Visitor that records every node matching `LiteralVTable` on the way down.
    #[derive(Default)]
    pub struct ExprLitCollector<'a>(pub Vec<&'a ExprRef>);

    impl<'a> NodeVisitor<'a> for ExprLitCollector<'a> {
        type NodeTy = ExprRef;

        fn visit_down(&mut self, node: &'a ExprRef) -> VortexResult<TraversalOrder> {
            if node.is::<LiteralVTable>() {
                self.0.push(node)
            }
            Ok(TraversalOrder::Continue)
        }

        fn visit_up(&mut self, _node: &'a ExprRef) -> VortexResult<TraversalOrder> {
            Ok(TraversalOrder::Continue)
        }
    }

    /// Replaces a node matching `GetItemVTable` with a literal holding the current value of
    /// `idx` (which is then incremented); any other node is returned unchanged.
    fn expr_col_to_lit_transform(
        node: ExprRef,
        idx: &mut i32,
    ) -> VortexResult<Transformed<ExprRef>> {
        if node.is::<GetItemVTable>() {
            let lit_id = *idx;
            *idx += 1;
            Ok(Transformed::yes(LiteralExpr::new_expr(lit_id)))
        } else {
            Ok(Transformed::no(node))
        }
    }

    /// Rewriter that asks to skip the children on the way down and replaces the node with
    /// `root()` on the way up.
    #[derive(Default)]
    pub struct SkipDownRewriter;

    impl NodeRewriter for SkipDownRewriter {
        type NodeTy = ExprRef;

        fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
            Ok(Transformed {
                value: node,
                order: TraversalOrder::Skip,
                changed: false,
            })
        }

        fn visit_up(&mut self, _node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
            Ok(Transformed::yes(root()))
        }
    }

    /// `accept` reaches nested nodes: both literals of `(col1 == 1) AND 2` are collected.
    #[test]
    fn expr_deep_visitor_test() {
        let col1: Arc<dyn VortexExpr> = col("col1");
        let lit1 = LiteralExpr::new(1).into_expr();
        let expr = BinaryExpr::new(col1.clone(), Operator::Eq, lit1.clone()).into_expr();
        let lit2 = LiteralExpr::new(2).into_expr();
        let expr = BinaryExpr::new(expr, Operator::And, lit2).into_expr();
        let mut printer = ExprLitCollector::default();
        expr.accept(&mut printer).unwrap();
        assert_eq!(printer.0.len(), 2);
    }

    /// `transform_up` rewrites nested nodes and reports the change.
    #[test]
    fn expr_deep_mut_visitor_test() {
        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        let expr = BinaryExpr::new_expr(col1.clone(), Operator::Eq, col2.clone());
        let lit2 = LiteralExpr::new_expr(2);
        let expr = BinaryExpr::new_expr(expr, Operator::And, lit2);

        let mut idx = 0_i32;
        let new = expr
            .transform_up(|node| expr_col_to_lit_transform(node, &mut idx))
            .unwrap();
        assert!(new.changed);

        let expr = new.value;

        // Presumably each `col(..)` contributes exactly one GetItem node, so the rewritten
        // tree holds the two replacement literals plus the pre-existing literal `2`.
        let mut printer = ExprLitCollector::default();
        expr.accept(&mut printer).unwrap();
        assert_eq!(printer.0.len(), 3);
    }

    /// Returning `Skip` at the `Eq` node hides its children; the `NotEq` side is still visited.
    #[test]
    fn expr_skip_test() {
        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        let expr1 = BinaryExpr::new_expr(col1.clone(), Operator::Eq, col2.clone());
        let col3: Arc<dyn VortexExpr> = col("col3");
        let col4: Arc<dyn VortexExpr> = col("col4");
        let expr2 = BinaryExpr::new_expr(col3.clone(), Operator::NotEq, col4.clone());
        let expr = BinaryExpr::new_expr(expr1, Operator::And, expr2);

        let mut nodes = Vec::new();
        pre_order_visit_down(&expr, |node: &ExprRef| {
            if node.is::<GetItemVTable>() {
                nodes.push(node)
            }
            if let Some(bin) = node.as_opt::<BinaryVTable>()
                && bin.op() == Operator::Eq
            {
                return Ok(TraversalOrder::Skip);
            }
            Ok(TraversalOrder::Continue)
        })
        .unwrap();

        let nodes: HashSet<ExprRef> = HashSet::from_iter(nodes.into_iter().cloned());
        assert_eq!(nodes, HashSet::from_iter([col("col3"), col("col4")]));
    }

    /// Returning `Stop` at the `Eq` node ends the traversal before any GetItem node is recorded.
    #[test]
    fn expr_stop_test() {
        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        let expr1 = BinaryExpr::new_expr(col1.clone(), Operator::Eq, col2.clone());
        let col3: Arc<dyn VortexExpr> = col("col3");
        let col4: Arc<dyn VortexExpr> = col("col4");
        let expr2 = BinaryExpr::new_expr(col3.clone(), Operator::NotEq, col4.clone());
        let expr = BinaryExpr::new_expr(expr1, Operator::And, expr2);

        let mut nodes = Vec::new();
        pre_order_visit_down(&expr, |node: &ExprRef| {
            if node.is::<GetItemVTable>() {
                nodes.push(node)
            }
            if let Some(bin) = node.as_opt::<BinaryVTable>()
                && bin.op() == Operator::Eq
            {
                return Ok(TraversalOrder::Stop);
            }
            Ok(TraversalOrder::Continue)
        })
        .unwrap();

        assert!(nodes.is_empty());
    }

    /// A `Skip` returned by `visit_down` must not prevent `visit_up` from running on that node.
    #[test]
    fn expr_skip_down_visit_up() {
        let col = col("col");

        let mut visitor = SkipDownRewriter;
        let result = col.rewrite(&mut visitor).unwrap();

        assert!(result.changed);
        assert!(is_root(&result.value));
    }
}