vortex_array/expr/traversal/fold.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::expr::traversal::Node;
7
/// Used to indicate the control flow of the fold on the downwards pass.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FoldDown<R> {
    /// Continue the fold by visiting the children of the current node.
    Continue,
    /// Stop the whole fold; the carried value becomes the result of the fold.
    Stop(R),
    /// Skip the children of the current node; the carried value is used as the result for
    /// this node.
    Skip(R),
}
18
/// Used to indicate the control flow of a context-carrying fold on the downwards pass.
///
/// This mirrors [`FoldDown`], except that `Continue` carries the context that is passed on to
/// the children of the current node.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FoldDownContext<C, R> {
    /// Continue the fold, visiting the children of the current node with the given context.
    Continue(C),
    /// Stop the whole fold; the carried value becomes the result of the fold.
    Stop(R),
    /// Skip the children of the current node; the carried value is used as the result for
    /// this node.
    Skip(R),
}
28
/// Used to indicate the control flow of the fold on the upwards pass.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FoldUp<R> {
    /// Continue the fold; the carried value is the result for the current node.
    Continue(R),
    /// Stop the fold at the current position and return the carried value as the result.
    Stop(R),
}
36
37impl<R> FoldUp<R> {
38    pub fn value(self) -> R {
39        match self {
40            Self::Continue(r) => r,
41            Self::Stop(r) => r,
42        }
43    }
44}
45
46/// Use to implement the folding a tree like structure in a pre-order traversal.
47///
48/// At each point on the way down, the `visit_down` method is called. If it returns `Skip`,
49/// the children of the current node are skipped. If it returns `Stop`, the fold is stopped.
50/// If it returns `Continue`, the children of the current node are visited.
51///
52/// At each point on the way up, the `visit_up` method is called. If it returns `Stop`,
53/// the fold stops.
54///
55/// On the way up the folded children are passed to the `visit_up` method along with the current node.
56///
57/// Note: this trait is not safe to use for graphs with a cycle.
58pub trait NodeFolderContext {
59    type NodeTy: Node;
60    type Result;
61    type Context;
62
63    /// visit_down is called when a node is first encountered, in a pre-order traversal.
64    /// If the node's children are to be skipped, return Skip.
65    /// If the node should stop traversal, return Stop.
66    /// Otherwise, return Continue.
67    fn visit_down(
68        &mut self,
69        _ctx: &Self::Context,
70        _node: &Self::NodeTy,
71    ) -> VortexResult<FoldDownContext<Self::Context, Self::Result>>;
72
73    /// visit_up is called when a node is last encountered, in a pre-order traversal.
74    /// If the node should stop traversal, return Stop.
75    /// Otherwise, return Continue.
76    fn visit_up(
77        &mut self,
78        _node: Self::NodeTy,
79        _context: &Self::Context,
80        _children: Vec<Self::Result>,
81    ) -> VortexResult<FoldUp<Self::Result>>;
82}
83
84/// This trait is used to implement a fold (see `NodeFolderContext`), but without a context.
85pub trait NodeFolder {
86    type NodeTy: Node;
87    type Result;
88
89    /// visit_down is called when a node is first encountered, in a pre-order traversal.
90    /// If the node's children are to be skipped, return Skip.
91    /// If the node should stop traversal, return Stop.
92    /// Otherwise, return Continue.
93    fn visit_down(&mut self, _node: &Self::NodeTy) -> VortexResult<FoldDown<Self::Result>> {
94        Ok(FoldDown::Continue)
95    }
96
97    /// visit_up is called when a node is last encountered, in a pre-order traversal.
98    /// If the node should stop traversal, return Stop.
99    /// Otherwise, return Continue.
100    fn visit_up(
101        &mut self,
102        _node: Self::NodeTy,
103        _children: Vec<Self::Result>,
104    ) -> VortexResult<FoldUp<Self::Result>>;
105}
106
/// Adapts a [`NodeFolder`] so it can be driven as a [`NodeFolderContext`] with a `()` context.
///
/// No trait bound is placed on the struct itself; the bound is only required by the trait
/// implementation that needs it.
pub(crate) struct NodeFolderContextWrapper<'a, T> {
    /// The wrapped folder that visits are delegated to.
    pub inner: &'a mut T,
}
113
114impl<T: NodeFolder> NodeFolderContext for NodeFolderContextWrapper<'_, T> {
115    type NodeTy = T::NodeTy;
116    type Result = T::Result;
117    type Context = ();
118
119    fn visit_down(
120        &mut self,
121        _ctx: &Self::Context,
122        _node: &Self::NodeTy,
123    ) -> VortexResult<FoldDownContext<Self::Context, Self::Result>> {
124        match self.inner.visit_down(_node)? {
125            FoldDown::Continue => Ok(FoldDownContext::Continue(())),
126            FoldDown::Stop(r) => Ok(FoldDownContext::Stop(r)),
127            FoldDown::Skip(r) => Ok(FoldDownContext::Skip(r)),
128        }
129    }
130
131    fn visit_up(
132        &mut self,
133        _node: Self::NodeTy,
134        _context: &Self::Context,
135        _children: Vec<Self::Result>,
136    ) -> VortexResult<FoldUp<Self::Result>> {
137        self.inner.visit_up(_node, _children)
138    }
139}
140
#[cfg(test)]
mod tests {
    use vortex_error::{VortexExpect, vortex_bail};

    use super::*;
    use crate::expr::Expression;
    use crate::expr::exprs::binary::{Binary, checked_add, gt};
    use crate::expr::exprs::literal::{Literal, lit};
    use crate::expr::exprs::operators::Operator;
    use crate::expr::traversal::NodeExt;

    /// Returns the `i32` value of `node` if it is a literal expression, `None` otherwise.
    ///
    /// Panics via `vortex_expect` if the literal does not hold an `i32`.
    fn literal_i32(node: &Expression) -> Option<i32> {
        let literal = node.as_opt::<Literal>()?;
        let value = literal
            .data()
            .as_primitive()
            .typed_value::<i32>()
            .vortex_expect("i32");
        Some(value)
    }

    /// A folder that sums `i32` literals combined with `Add`.
    ///
    /// It stops the whole fold as soon as it meets the literal `5`. It skips every `Gt`
    /// sub-expression, using `0` as that sub-expression's value.
    struct AddFold;

    impl NodeFolder for AddFold {
        type NodeTy = Expression;
        type Result = i32;

        fn visit_down(&mut self, node: &Self::NodeTy) -> VortexResult<FoldDown<Self::Result>> {
            if literal_i32(node) == Some(5) {
                return Ok(FoldDown::Stop(5));
            }

            let is_gt = node
                .as_opt::<Binary>()
                .is_some_and(|binary| binary.operator() == Operator::Gt);

            Ok(if is_gt {
                FoldDown::Skip(0)
            } else {
                FoldDown::Continue
            })
        }

        fn visit_up(
            &mut self,
            node: Self::NodeTy,
            children: Vec<Self::Result>,
        ) -> VortexResult<FoldUp<Self::Result>> {
            if let Some(value) = literal_i32(&node) {
                return Ok(FoldUp::Continue(value));
            }

            let Some(binary) = node.as_opt::<Binary>() else {
                vortex_bail!("not a valid type");
            };

            if binary.operator() != Operator::Add {
                vortex_bail!("not a valid operator");
            }

            Ok(FoldUp::Continue(children[0] + children[1]))
        }
    }

    /// Folds `expr` with a fresh [`AddFold`] and returns the final value.
    fn fold_sum(expr: Expression) -> i32 {
        let mut folder = AddFold;
        expr.fold(&mut folder).unwrap().value()
    }

    /// Plain sum: (1 + 2) + 3.
    #[test]
    fn test_fold() {
        assert_eq!(fold_sum(checked_add(checked_add(lit(1), lit(2)), lit(3))), 6);
    }

    /// Meeting the literal 5 on the way down stops the fold with 5.
    #[test]
    fn test_stop_value() {
        assert_eq!(fold_sum(checked_add(checked_add(lit(1), lit(5)), lit(3))), 5);
    }

    /// The skipped `gt` contributes 0, so the result is 0 + 3.
    #[test]
    fn test_skip_value() {
        assert_eq!(fold_sum(checked_add(gt(lit(1), lit(2)), lit(3))), 3);
    }

    /// The 5 sits under a skipped `gt`, so it is never visited and does not stop the fold.
    #[test]
    fn test_control_flow_value() {
        assert_eq!(fold_sum(checked_add(gt(lit(1), lit(5)), lit(3))), 3);
    }
}