vortex_array/expr/traversal/
fold.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::expr::traversal::Node;
7
/// Control-flow decision returned by the downward (pre-order) visit of a fold.
#[derive(Debug)]
pub enum FoldDown<R> {
    /// Keep going: visit the children of the current node.
    Continue,
    /// Abort the entire fold, producing `R` as its result.
    Stop(R),
    /// Do not visit the children of the current node; `R` stands in as the node's result.
    Skip(R),
}
18
/// Control-flow decision returned by the downward (pre-order) visit of a context-aware fold.
///
/// Mirrors [`FoldDown`], except that `Continue` carries the context handed to the children.
#[derive(Debug)]
pub enum FoldDownContext<C, R> {
    /// Keep going: visit the children of the current node, passing them context `C`.
    Continue(C),
    /// Abort the entire fold, producing `R` as its result.
    Stop(R),
    /// Do not visit the children of the current node; `R` stands in as the node's result.
    Skip(R),
}
28
/// Control-flow decision returned by the upward (post-order) visit of a fold.
#[derive(Debug)]
pub enum FoldUp<R> {
    /// Keep folding, using `R` as the folded value of the current node.
    Continue(R),
    /// Stop the fold at the current position and return `R` as the result.
    Stop(R),
}

impl<R> FoldUp<R> {
    /// Extracts the carried value, whether the fold continued or stopped.
    pub fn value(self) -> R {
        // Both variants carry the same payload type, so this pattern is irrefutable.
        let (Self::Continue(inner) | Self::Stop(inner)) = self;
        inner
    }
}
45
46/// Use to implement the folding a tree like structure in a pre-order traversal.
47///
48/// At each point on the way down, the `visit_down` method is called. If it returns `Skip`,
49/// the children of the current node are skipped. If it returns `Stop`, the fold is stopped.
50/// If it returns `Continue`, the children of the current node are visited.
51///
52/// At each point on the way up, the `visit_up` method is called. If it returns `Stop`,
53/// the fold stops.
54///
55/// On the way up the folded children are passed to the `visit_up` method along with the current node.
56///
57/// Note: this trait is not safe to use for graphs with a cycle.
58pub trait NodeFolderContext {
59    type NodeTy: Node;
60    type Result;
61    type Context;
62
63    /// visit_down is called when a node is first encountered, in a pre-order traversal.
64    /// If the node's children are to be skipped, return Skip.
65    /// If the node should stop traversal, return Stop.
66    /// Otherwise, return Continue.
67    fn visit_down(
68        &mut self,
69        _ctx: &Self::Context,
70        _node: &Self::NodeTy,
71    ) -> VortexResult<FoldDownContext<Self::Context, Self::Result>>;
72
73    /// visit_up is called when a node is last encountered, in a pre-order traversal.
74    /// If the node should stop traversal, return Stop.
75    /// Otherwise, return Continue.
76    fn visit_up(
77        &mut self,
78        _node: Self::NodeTy,
79        _context: &Self::Context,
80        _children: Vec<Self::Result>,
81    ) -> VortexResult<FoldUp<Self::Result>>;
82}
83
84/// This trait is used to implement a fold (see `NodeFolderContext`), but without a context.
85pub trait NodeFolder {
86    type NodeTy: Node;
87    type Result;
88
89    /// visit_down is called when a node is first encountered, in a pre-order traversal.
90    /// If the node's children are to be skipped, return Skip.
91    /// If the node should stop traversal, return Stop.
92    /// Otherwise, return Continue.
93    fn visit_down(&mut self, _node: &Self::NodeTy) -> VortexResult<FoldDown<Self::Result>> {
94        Ok(FoldDown::Continue)
95    }
96
97    /// visit_up is called when a node is last encountered, in a pre-order traversal.
98    /// If the node should stop traversal, return Stop.
99    /// Otherwise, return Continue.
100    fn visit_up(
101        &mut self,
102        _node: Self::NodeTy,
103        _children: Vec<Self::Result>,
104    ) -> VortexResult<FoldUp<Self::Result>>;
105}
106
/// Adapts a context-free [`NodeFolder`] into a [`NodeFolderContext`] whose context is `()`.
///
/// The `T: NodeFolder` bound lives on the trait `impl` rather than on the struct, following the
/// Rust API guidelines. The wrapper itself places no requirements on `T`.
pub(crate) struct NodeFolderContextWrapper<'a, T> {
    /// The wrapped folder that calls are forwarded to.
    pub inner: &'a mut T,
}
113
114impl<T: NodeFolder> NodeFolderContext for NodeFolderContextWrapper<'_, T> {
115    type NodeTy = T::NodeTy;
116    type Result = T::Result;
117    type Context = ();
118
119    fn visit_down(
120        &mut self,
121        _ctx: &Self::Context,
122        _node: &Self::NodeTy,
123    ) -> VortexResult<FoldDownContext<Self::Context, Self::Result>> {
124        match self.inner.visit_down(_node)? {
125            FoldDown::Continue => Ok(FoldDownContext::Continue(())),
126            FoldDown::Stop(r) => Ok(FoldDownContext::Stop(r)),
127            FoldDown::Skip(r) => Ok(FoldDownContext::Skip(r)),
128        }
129    }
130
131    fn visit_up(
132        &mut self,
133        _node: Self::NodeTy,
134        _context: &Self::Context,
135        _children: Vec<Self::Result>,
136    ) -> VortexResult<FoldUp<Self::Result>> {
137        self.inner.visit_up(_node, _children)
138    }
139}
140
#[cfg(test)]
mod tests {
    use vortex_error::VortexExpect;
    use vortex_error::vortex_bail;

    use super::*;
    use crate::expr::Expression;
    use crate::expr::exprs::binary::Binary;
    use crate::expr::exprs::binary::checked_add;
    use crate::expr::exprs::binary::gt;
    use crate::expr::exprs::literal::Literal;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::operators::Operator;
    use crate::expr::traversal::NodeExt;

    /// Test folder that sums `i32` literals joined by addition.
    ///
    /// It also exercises the control-flow variants:
    /// * a literal `5` stops the whole fold with the value `5`;
    /// * a `>` comparison is skipped (its children are never visited) with the value `0`.
    struct AddFold;
    impl NodeFolder for AddFold {
        type NodeTy = Expression;
        type Result = i32;

        fn visit_down(&mut self, node: &'_ Self::NodeTy) -> VortexResult<FoldDown<Self::Result>> {
            // Stop the entire fold as soon as the literal `5` is reached.
            if let Some(lit) = node.as_opt::<Literal>() {
                let v = lit
                    .data()
                    .as_primitive()
                    .typed_value::<i32>()
                    .vortex_expect("i32");

                if v == 5 {
                    return Ok(FoldDown::Stop(5));
                }
            }

            // Skip `>` nodes so their children are never folded; the subtree contributes 0.
            if let Some(binary) = node.as_opt::<Binary>()
                && binary.operator() == Operator::Gt
            {
                return Ok(FoldDown::Skip(0));
            }

            Ok(FoldDown::Continue)
        }

        fn visit_up(
            &mut self,
            node: Self::NodeTy,
            children: Vec<Self::Result>,
        ) -> VortexResult<FoldUp<Self::Result>> {
            // Literals fold to their own value. Additions fold to the sum of their two children.
            // Anything else (including `>`) is rejected, so a skipped `>` node must never reach
            // `visit_up`.
            if let Some(lit) = node.as_opt::<Literal>() {
                let v = lit
                    .data()
                    .as_primitive()
                    .typed_value::<i32>()
                    .vortex_expect("i32");
                Ok(FoldUp::Continue(v))
            } else if let Some(binary) = node.as_opt::<Binary>() {
                if binary.operator() == Operator::Add {
                    // Assumes a binary add always has exactly two children.
                    Ok(FoldUp::Continue(children[0] + children[1]))
                } else {
                    vortex_bail!("not a valid operator")
                }
            } else {
                vortex_bail!("not a valid type")
            }
        }
    }

    // (1 + 2) + 3 folds to 6 with no early exit.
    #[test]
    fn test_fold() {
        let expr = checked_add(checked_add(lit(1), lit(2)), lit(3));

        let mut folder = AddFold;
        let result = expr.fold(&mut folder).unwrap().value();
        assert_eq!(result, 6);
    }

    // Reaching the literal `5` stops the fold, so its value is the overall result.
    #[test]
    fn test_stop_value() {
        let expr = checked_add(checked_add(lit(1), lit(5)), lit(3));

        let mut folder = AddFold;
        let result = expr.fold(&mut folder).unwrap().value();
        assert_eq!(result, 5);
    }

    // The `>` subtree is skipped with value 0, so the result is 0 + 3.
    #[test]
    fn test_skip_value() {
        let expr = checked_add(gt(lit(1), lit(2)), lit(3));

        let mut folder = AddFold;
        let result = expr.fold(&mut folder).unwrap().value();
        assert_eq!(result, 3);
    }

    // Skipping the `>` node means its child literal `5` is never visited, so no stop occurs.
    #[test]
    fn test_control_flow_value() {
        let expr = checked_add(gt(lit(1), lit(5)), lit(3));

        let mut folder = AddFold;
        let result = expr.fold(&mut folder).unwrap().value();
        assert_eq!(result, 3);
    }
}