vortex_array/expr/traversal/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
//! DataFusion-inspired tree traversal logic.
//!
//! To make a type traversable, implement [`Node`] for it and, where useful, [`NodeContainer`].
7
8mod fold;
9mod references;
10mod visitor;
11
12use std::marker::PhantomData;
13use std::sync::Arc;
14
15pub use fold::FoldDown;
16pub use fold::FoldDownContext;
17pub use fold::FoldUp;
18pub use fold::NodeFolder;
19pub use fold::NodeFolderContext;
20use itertools::Itertools;
21pub use references::ReferenceCollector;
22pub use visitor::pre_order_visit_down;
23pub use visitor::pre_order_visit_up;
24use vortex_error::VortexResult;
25
26use crate::expr::Expression;
27use crate::expr::traversal::fold::NodeFolderContextWrapper;
28
/// Signal to control a traversal's flow.
///
/// Visitor and rewriter callbacks return this value; how it is interpreted at each phase is
/// implemented by [`TraversalOrder::visit_children`] and [`TraversalOrder::visit_parent`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TraversalOrder {
    /// In a top-down traversal, skip visiting the children of the current node.
    /// In the bottom-up phase of the traversal, skip the next step. Either skipping the children of the node,
    /// moving to its next sibling, or skipping its parent once the children are traversed.
    Skip,
    /// Stop visiting any more nodes in the traversal.
    Stop,
    /// Continue with the traversal as expected.
    Continue,
}
41
42impl TraversalOrder {
43    /// If directed to, continue to visit nodes by running `f`, which should apply on the node's children.
44    pub fn visit_children<F: FnOnce() -> VortexResult<TraversalOrder>>(
45        self,
46        f: F,
47    ) -> VortexResult<TraversalOrder> {
48        match self {
49            Self::Skip => Ok(TraversalOrder::Continue),
50            Self::Stop => Ok(self),
51            Self::Continue => f(),
52        }
53    }
54
55    /// If directed to, continue to visit nodes by running `f`, which should apply on the node's parent.
56    pub fn visit_parent<F: FnOnce() -> VortexResult<TraversalOrder>>(
57        self,
58        f: F,
59    ) -> VortexResult<TraversalOrder> {
60        match self {
61            Self::Continue => f(),
62            Self::Skip | Self::Stop => Ok(self),
63        }
64    }
65}
66
/// Result of a rewriting step: the (possibly rewritten) value together with the traversal
/// signal and a flag recording whether anything was changed.
#[derive(Debug, Clone)]
pub struct Transformed<T> {
    /// Value that was being rewritten.
    pub value: T,
    /// Controls the flow of rewriting, see [`TraversalOrder`] for more details.
    pub order: TraversalOrder,
    /// Was the value changed during rewriting.
    pub changed: bool,
}
76
77impl<T> Transformed<T> {
78    pub fn yes(value: T) -> Self {
79        Self {
80            value,
81            order: TraversalOrder::Continue,
82            changed: true,
83        }
84    }
85
86    pub fn no(value: T) -> Self {
87        Self {
88            value,
89            order: TraversalOrder::Continue,
90            changed: false,
91        }
92    }
93
94    pub fn into_inner(self) -> T {
95        self.value
96    }
97
98    /// Apply a function to `value`, changing it without changing the `changed` field.
99    pub fn map<O, F: FnOnce(T) -> O>(self, f: F) -> Transformed<O> {
100        Transformed {
101            value: f(self.value),
102            order: self.order,
103            changed: self.changed,
104        }
105    }
106}
107
108pub trait NodeVisitor<'a> {
109    type NodeTy: Node;
110
111    fn visit_down(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
112        _ = node;
113        Ok(TraversalOrder::Continue)
114    }
115
116    fn visit_up(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
117        _ = node;
118        Ok(TraversalOrder::Continue)
119    }
120}
121
122pub trait NodeRewriter: Sized {
123    type NodeTy: Node;
124
125    fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
126        Ok(Transformed::no(node))
127    }
128
129    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
130        Ok(Transformed::no(node))
131    }
132}
133
/// A tree node whose children have the same type as the node itself.
///
/// These are the primitives that the higher-level helpers on [`NodeExt`] are built from.
pub trait Node: Sized + Clone {
    /// Walk the node's children by applying `f` to them.
    ///
    /// This is a lower level API that other functions rely on for their implementation.
    fn apply_children<'a, F: FnMut(&'a Self) -> VortexResult<TraversalOrder>>(
        &'a self,
        f: F,
    ) -> VortexResult<TraversalOrder>;

    /// Rewrite the node's children by applying `f` to them.
    ///
    /// This is a lower level API that other functions rely on for their implementation.
    fn map_children<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
        self,
        f: F,
    ) -> VortexResult<Transformed<Self>>;

    /// Hand an iterator over the node's children to `f` and return whatever `f` returns.
    ///
    /// This is a lower level API that other functions rely on for their implementation.
    fn iter_children<T>(&self, f: impl FnOnce(&mut dyn Iterator<Item = &Self>) -> T) -> T;

    /// Number of children of this node.
    ///
    /// This is a lower level API that other functions rely on for their implementation.
    fn children_count(&self) -> usize;
}
157
158pub trait NodeExt: Node {
159    /// Walk the tree in pre-order (top-down) way, rewriting it as it goes.
160    fn rewrite<R: NodeRewriter<NodeTy = Self>>(
161        self,
162        rewriter: &mut R,
163    ) -> VortexResult<Transformed<Self>> {
164        let mut transformed = rewriter.visit_down(self)?;
165
166        let transformed = match transformed.order {
167            TraversalOrder::Stop => Ok(transformed),
168            TraversalOrder::Skip => {
169                transformed.order = TraversalOrder::Continue;
170                Ok(transformed)
171            }
172            TraversalOrder::Continue => transformed
173                .value
174                .map_children(|c| c.rewrite(rewriter))
175                .map(|mut t| {
176                    t.changed |= transformed.changed;
177                    t
178                }),
179        }?;
180
181        match transformed.order {
182            TraversalOrder::Stop | TraversalOrder::Skip => Ok(transformed),
183            TraversalOrder::Continue => {
184                let mut up_rewrite = rewriter.visit_up(transformed.value)?;
185                up_rewrite.changed |= transformed.changed;
186                Ok(up_rewrite)
187            }
188        }
189    }
190
191    /// A pre-order (top-down) traversal.
192    fn accept<'a, V: NodeVisitor<'a, NodeTy = Self>>(
193        &'a self,
194        visitor: &mut V,
195    ) -> VortexResult<TraversalOrder> {
196        visitor
197            .visit_down(self)?
198            .visit_children(|| self.apply_children(|c| c.accept(visitor)))?
199            .visit_parent(|| visitor.visit_up(self))
200    }
201
202    /// A pre-order transformation
203    fn transform_down<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
204        self,
205        f: F,
206    ) -> VortexResult<Transformed<Self>> {
207        let mut rewriter = FnRewriter::<F, F, _> {
208            f_down: Some(f),
209            f_up: None,
210            _data: PhantomData,
211        };
212
213        self.rewrite(&mut rewriter)
214    }
215
216    fn transform<F, G>(self, down: F, up: G) -> VortexResult<Transformed<Self>>
217    where
218        F: FnMut(Self) -> VortexResult<Transformed<Self>>,
219        G: FnMut(Self) -> VortexResult<Transformed<Self>>,
220    {
221        let mut rewriter = FnRewriter {
222            f_down: Some(down),
223            f_up: Some(up),
224            _data: PhantomData,
225        };
226
227        self.rewrite(&mut rewriter)
228    }
229
230    /// A post-order transform
231    fn transform_up<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
232        self,
233        f: F,
234    ) -> VortexResult<Transformed<Self>> {
235        let mut rewriter = FnRewriter::<F, F, _> {
236            f_down: None,
237            f_up: Some(f),
238            _data: PhantomData,
239        };
240
241        self.rewrite(&mut rewriter)
242    }
243
244    /// applies the `NodeFolderContext` to the Node tree, with an initial `Context`.
245    fn fold_context<R, F: NodeFolderContext<NodeTy = Self, Result = R>>(
246        self,
247        ctx: &F::Context,
248        folder: &mut F,
249    ) -> VortexResult<FoldUp<R>> {
250        let transformed = folder.visit_down(ctx, &self)?;
251        let ctx = match transformed {
252            FoldDownContext::Continue(ctx) => ctx,
253            FoldDownContext::Skip(r) => return Ok(FoldUp::Continue(r)),
254            FoldDownContext::Stop(r) => return Ok(FoldUp::Stop(r)),
255        };
256
257        let mut children = Vec::with_capacity(self.children_count());
258        let mut stop_result = None;
259        self.iter_children(|children_iter| -> VortexResult<()> {
260            for c in children_iter {
261                let t = c.clone().fold_context(&ctx, folder)?;
262                match t {
263                    FoldUp::Stop(r) => {
264                        stop_result = Some(r);
265                        return Ok(());
266                    }
267                    FoldUp::Continue(r) => {
268                        children.push(r);
269                    }
270                }
271            }
272            Ok(())
273        })?;
274
275        if let Some(result) = stop_result {
276            return Ok(FoldUp::Stop(result));
277        }
278
279        folder.visit_up(self, &ctx, children)
280    }
281
282    /// applies the `NodeFolder` to the Node tree
283    fn fold<R, F: NodeFolder<NodeTy = Self, Result = R>>(
284        self,
285        folder: &mut F,
286    ) -> VortexResult<FoldUp<R>> {
287        let mut folder = NodeFolderContextWrapper { inner: folder };
288        self.fold_context(&(), &mut folder)
289    }
290}
291
// Blanket implementation: every `Node` gets the `NodeExt` helpers through their default bodies.
impl<T: Node> NodeExt for T {}
293
/// Adapter that turns a pair of optional closures into a [`NodeRewriter`].
struct FnRewriter<F, G, T> {
    /// Closure run in the top-down phase, if any.
    f_down: Option<F>,
    /// Closure run in the bottom-up phase, if any.
    f_up: Option<G>,
    /// Ties the rewriter to the node type `T` without storing a node.
    _data: PhantomData<T>,
}
299
300impl<F, G, T> NodeRewriter for FnRewriter<F, G, T>
301where
302    T: Node,
303    F: FnMut(T) -> VortexResult<Transformed<T>>,
304    G: FnMut(T) -> VortexResult<Transformed<T>>,
305{
306    type NodeTy = T;
307
308    fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
309        if let Some(f) = self.f_down.as_mut() {
310            f(node)
311        } else {
312            Ok(Transformed::no(node))
313        }
314    }
315
316    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
317        if let Some(f) = self.f_up.as_mut() {
318            f(node)
319        } else {
320            Ok(Transformed::no(node))
321        }
322    }
323}
324
/// A container holding a [`Node`]'s children, which a function can be applied (or mapped) to.
///
/// The trait is also implemented for container types (such as `Vec`, `Box`, `Arc` and
/// two-element arrays) in order to make implementing [`Node::map_children`] and
/// [`Node::apply_children`] easier.
pub trait NodeContainer<'a, T: 'a>: Sized {
    /// Applies `f` to all elements of the container, accepting them by reference.
    fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
        &'a self,
        f: F,
    ) -> VortexResult<TraversalOrder>;

    /// Consumes the container, replacing its elements with the result of applying `f` to them.
    fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
        self,
        f: F,
    ) -> VortexResult<Transformed<Self>>;
}
342
/// Counterpart of [`NodeContainer`] for containers that are themselves references.
///
/// The elements are handed to `f` with the lifetime `'a`, independently of how long the
/// `&self` borrow of the container lasts.
pub trait NodeRefContainer<'a, T: 'a>: Sized {
    /// Applies `f` to the elements of the container, accepting them by reference.
    fn apply_ref_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
        &self,
        f: F,
    ) -> VortexResult<TraversalOrder>;
}
349
350impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeRefContainer<'a, T> for &'a [C] {
351    fn apply_ref_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
352        &self,
353        mut f: F,
354    ) -> VortexResult<TraversalOrder> {
355        let mut order = TraversalOrder::Continue;
356
357        for c in *self {
358            order = c.apply_elements(&mut f)?;
359            match order {
360                TraversalOrder::Continue | TraversalOrder::Skip => {}
361                TraversalOrder::Stop => return Ok(TraversalOrder::Stop),
362            }
363        }
364
365        Ok(order)
366    }
367}
368
369impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeContainer<'a, T> for Box<C> {
370    fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
371        &'a self,
372        f: F,
373    ) -> VortexResult<TraversalOrder> {
374        self.as_ref().apply_elements(f)
375    }
376
377    fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
378        self,
379        f: F,
380    ) -> VortexResult<Transformed<Box<C>>> {
381        Ok((*self).map_elements(f)?.map(Box::new))
382    }
383}
384
385impl<'a, T, C> NodeContainer<'a, T> for Arc<C>
386where
387    T: 'a,
388    C: NodeContainer<'a, T> + Clone,
389{
390    fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
391        &'a self,
392        f: F,
393    ) -> VortexResult<TraversalOrder> {
394        self.as_ref().apply_elements(f)
395    }
396
397    fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
398        self,
399        f: F,
400    ) -> VortexResult<Transformed<Arc<C>>> {
401        Ok(Arc::unwrap_or_clone(self).map_elements(f)?.map(Arc::new))
402    }
403}
404
405impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeContainer<'a, T> for [C; 2] {
406    fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
407        &'a self,
408        mut f: F,
409    ) -> VortexResult<TraversalOrder> {
410        let [lhs, rhs] = self;
411        match lhs.apply_elements(&mut f)? {
412            TraversalOrder::Skip | TraversalOrder::Continue => rhs.apply_elements(&mut f),
413            TraversalOrder::Stop => Ok(TraversalOrder::Stop),
414        }
415    }
416
417    fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
418        self,
419        mut f: F,
420    ) -> VortexResult<Transformed<[C; 2]>> {
421        let [lhs, rhs] = self;
422        let transformed = lhs.map_elements(&mut f)?;
423        match transformed.order {
424            TraversalOrder::Skip | TraversalOrder::Continue => {
425                let mut t = rhs.map_elements(&mut f)?;
426                t.changed |= transformed.changed;
427                Ok(t.map(|new_lhs| [new_lhs, transformed.value]))
428            }
429            TraversalOrder::Stop => Ok(transformed.map(|new_lhs| [new_lhs, rhs])),
430        }
431    }
432}
433
434impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeContainer<'a, T> for Vec<C> {
435    fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
436        &'a self,
437        mut f: F,
438    ) -> VortexResult<TraversalOrder> {
439        let mut order = TraversalOrder::Continue;
440
441        for c in self {
442            order = c.apply_elements(&mut f)?;
443            match order {
444                TraversalOrder::Continue | TraversalOrder::Skip => {}
445                TraversalOrder::Stop => return Ok(TraversalOrder::Stop),
446            }
447        }
448
449        Ok(order)
450    }
451
452    fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
453        self,
454        mut f: F,
455    ) -> VortexResult<Transformed<Self>> {
456        let mut order = TraversalOrder::Continue;
457        let mut changed = false;
458
459        let value = self
460            .into_iter()
461            .map(|c| match order {
462                TraversalOrder::Continue | TraversalOrder::Skip => {
463                    c.map_elements(&mut f).map(|result| {
464                        order = result.order;
465                        changed |= result.changed;
466                        result.value
467                    })
468                }
469                TraversalOrder::Stop => Ok(c),
470            })
471            .collect::<VortexResult<Vec<_>>>()?;
472
473        Ok(Transformed {
474            value,
475            order,
476            changed,
477        })
478    }
479}
480
// An `Expression` acts as a container of exactly one element: itself. This is the base case
// that lets the generic container impls (`Vec`, `Box`, `Arc`, arrays, slices) hold expressions.
impl<'a> NodeContainer<'a, Self> for Expression {
    /// Applies `f` to this expression.
    fn apply_elements<F: FnMut(&'a Self) -> VortexResult<TraversalOrder>>(
        &'a self,
        mut f: F,
    ) -> VortexResult<TraversalOrder> {
        f(self)
    }

    /// Replaces this expression with the result of `f`.
    fn map_elements<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
        self,
        mut f: F,
    ) -> VortexResult<Transformed<Self>> {
        f(self)
    }
}
496
497impl Node for Expression {
498    fn apply_children<'a, F: FnMut(&'a Self) -> VortexResult<TraversalOrder>>(
499        &'a self,
500        mut f: F,
501    ) -> VortexResult<TraversalOrder> {
502        self.children().as_ref().apply_ref_elements(&mut f)
503    }
504
505    fn map_children<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
506        self,
507        f: F,
508    ) -> VortexResult<Transformed<Self>> {
509        let transformed = self
510            .children()
511            .iter()
512            .cloned()
513            .collect_vec()
514            .map_elements(f)?;
515
516        if transformed.changed {
517            Ok(Transformed {
518                value: self.with_children(transformed.value)?,
519                order: transformed.order,
520                changed: true,
521            })
522        } else {
523            Ok(Transformed::no(self))
524        }
525    }
526
527    fn iter_children<T>(&self, f: impl FnOnce(&mut dyn Iterator<Item = &Self>) -> T) -> T {
528        f(&mut self.children().iter())
529    }
530
531    fn children_count(&self) -> usize {
532        self.children().len()
533    }
534}
535
536#[cfg(test)]
537mod tests {
538    use vortex_error::VortexResult;
539    use vortex_utils::aliases::hash_set::HashSet;
540
541    use super::NodeExt;
542    use super::NodeRewriter;
543    use super::NodeVisitor;
544    use super::Transformed;
545    use super::TraversalOrder;
546    use super::visitor::pre_order_visit_down;
547    use crate::expr::Expression;
548    use crate::expr::exprs::binary::Binary;
549    use crate::expr::exprs::binary::and;
550    use crate::expr::exprs::binary::eq;
551    use crate::expr::exprs::binary::not_eq;
552    use crate::expr::exprs::get_item::GetItem;
553    use crate::expr::exprs::get_item::col;
554    use crate::expr::exprs::literal::Literal;
555    use crate::expr::exprs::literal::lit;
556    use crate::expr::exprs::operators::Operator;
557    use crate::expr::exprs::root::is_root;
558    use crate::expr::exprs::root::root;
559
560    #[derive(Default)]
561    pub struct ExprLitCollector<'a>(pub Vec<&'a Expression>);
562
563    impl<'a> NodeVisitor<'a> for ExprLitCollector<'a> {
564        type NodeTy = Expression;
565
566        fn visit_down(&mut self, node: &'a Expression) -> VortexResult<TraversalOrder> {
567            if node.is::<Literal>() {
568                self.0.push(node)
569            }
570            Ok(TraversalOrder::Continue)
571        }
572
573        fn visit_up(&mut self, _node: &'a Expression) -> VortexResult<TraversalOrder> {
574            Ok(TraversalOrder::Continue)
575        }
576    }
577
578    fn expr_col_to_lit_transform(
579        node: Expression,
580        idx: &mut i32,
581    ) -> VortexResult<Transformed<Expression>> {
582        if node.is::<GetItem>() {
583            let lit_id = *idx;
584            *idx += 1;
585            Ok(Transformed::yes(lit(lit_id)))
586        } else {
587            Ok(Transformed::no(node))
588        }
589    }
590
591    #[derive(Default)]
592    pub struct SkipDownRewriter;
593
594    impl NodeRewriter for SkipDownRewriter {
595        type NodeTy = Expression;
596
597        fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
598            Ok(Transformed {
599                value: node,
600                order: TraversalOrder::Skip,
601                changed: false,
602            })
603        }
604
605        fn visit_up(&mut self, _node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
606            Ok(Transformed::yes(root()))
607        }
608    }
609
610    #[test]
611    fn expr_deep_visitor_test() {
612        let col1: Expression = col("col1");
613        let lit1 = lit(1);
614        let expr = eq(col1, lit1);
615        let lit2 = lit(2);
616        let expr = and(expr, lit2);
617        let mut printer = ExprLitCollector::default();
618        expr.accept(&mut printer).unwrap();
619        assert_eq!(printer.0.len(), 2);
620    }
621
622    #[test]
623    fn expr_deep_mut_visitor_test() {
624        let col1: Expression = col("col1");
625        let col2: Expression = col("col2");
626        let expr = eq(col1, col2);
627        let lit2 = lit(2);
628        let expr = and(expr, lit2);
629
630        let mut idx = 0_i32;
631        let new = expr
632            .transform_up(|node| expr_col_to_lit_transform(node, &mut idx))
633            .unwrap();
634        assert!(new.changed);
635
636        let expr = new.value;
637
638        let mut printer = ExprLitCollector::default();
639        expr.accept(&mut printer).unwrap();
640        assert_eq!(printer.0.len(), 3);
641    }
642
643    #[test]
644    fn expr_skip_test() {
645        let col1: Expression = col("col1");
646        let col2: Expression = col("col2");
647        let expr1 = eq(col1, col2);
648        let col3: Expression = col("col3");
649        let col4: Expression = col("col4");
650        let expr2 = not_eq(col3, col4);
651        let expr = and(expr1, expr2);
652
653        let mut nodes = Vec::new();
654        pre_order_visit_down(&expr, |node: &Expression| {
655            if node.is::<GetItem>() {
656                nodes.push(node)
657            }
658            if let Some(bin) = node.as_opt::<Binary>()
659                && bin.operator() == Operator::Eq
660            {
661                return Ok(TraversalOrder::Skip);
662            }
663            Ok(TraversalOrder::Continue)
664        })
665        .unwrap();
666
667        let nodes: HashSet<Expression> = HashSet::from_iter(nodes.into_iter().cloned());
668        assert_eq!(nodes, HashSet::from_iter([col("col3"), col("col4")]));
669    }
670
671    #[test]
672    fn expr_stop_test() {
673        let col1: Expression = col("col1");
674        let col2: Expression = col("col2");
675        let expr1 = eq(col1, col2);
676        let col3: Expression = col("col3");
677        let col4: Expression = col("col4");
678        let expr2 = not_eq(col3, col4);
679        let expr = and(expr1, expr2);
680
681        let mut nodes = Vec::new();
682        pre_order_visit_down(&expr, |node: &Expression| {
683            if node.is::<GetItem>() {
684                nodes.push(node)
685            }
686            if let Some(bin) = node.as_opt::<Binary>()
687                && bin.operator() == Operator::Eq
688            {
689                return Ok(TraversalOrder::Stop);
690            }
691            Ok(TraversalOrder::Continue)
692        })
693        .unwrap();
694
695        assert!(nodes.is_empty());
696    }
697
698    #[test]
699    fn expr_skip_down_visit_up() {
700        let col = col("col");
701
702        let mut visitor = SkipDownRewriter;
703        let result = col.rewrite(&mut visitor).unwrap();
704
705        assert!(result.changed);
706        assert!(is_root(&result.value));
707    }
708}