vortex_expr/traversal/
fold.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::traversal::Node;
7
/// Control-flow signal returned from the downward (pre-order) visit of a fold.
///
/// `R` is the fold's result type.
#[derive(Debug)]
pub enum FoldDown<R> {
    /// Keep going: descend into the current node's children.
    Continue,
    /// Abort the whole fold, yielding the carried value as the fold's result.
    Stop(R),
    /// Do not descend into the current node's children; the carried value is used as
    /// this node's folded result.
    Skip(R),
}
18
/// Control-flow signal returned from the downward (pre-order) visit of a context-aware fold.
///
/// `C` is the context handed from a node to its children, `R` is the fold's result type.
/// `Stop` and `Skip` behave exactly as in `FoldDown`.
#[derive(Debug)]
pub enum FoldDownContext<C, R> {
    /// Descend into the children, giving each of them the carried context.
    Continue(C),
    /// Abort the whole fold, yielding the carried value as the fold's result.
    Stop(R),
    /// Do not descend into the children; the carried value is this node's folded result.
    Skip(R),
}
28
/// Control-flow signal returned from the upward (post-order) visit of a fold.
///
/// Both variants carry the folded result for the current node.
#[derive(Debug)]
pub enum FoldUp<R> {
    /// Keep folding; the value is handed up to the parent node.
    Continue(R),
    /// Halt the fold at the current position and return the carried value.
    Stop(R),
}

impl<R> FoldUp<R> {
    /// Extracts the carried result, whether the fold continued or stopped.
    pub fn value(self) -> R {
        let (Self::Continue(result) | Self::Stop(result)) = self;
        result
    }
}
45
46/// Use to implement the folding a tree like structure in a pre-order traversal.
47///
48/// At each point on the way down, the `visit_down` method is called. If it returns `Skip`,
49/// the children of the current node are skipped. If it returns `Stop`, the fold is stopped.
50/// If it returns `Continue`, the children of the current node are visited.
51///
52/// At each point on the way up, the `visit_up` method is called. If it returns `Stop`,
53/// the fold stops.
54///
55/// On the way up the folded children are passed to the `visit_up` method along with the current node.
56///
57/// Note: this trait is not safe to use for graphs with a cycle.
58pub trait NodeFolderContext {
59    type NodeTy: Node;
60    type Result;
61    type Context;
62
63    /// visit_down is called when a node is first encountered, in a pre-order traversal.
64    /// If the node's children are to be skipped, return Skip.
65    /// If the node should stop traversal, return Stop.
66    /// Otherwise, return Continue.
67    fn visit_down(
68        &mut self,
69        _ctx: &Self::Context,
70        _node: &Self::NodeTy,
71    ) -> VortexResult<FoldDownContext<Self::Context, Self::Result>>;
72
73    /// visit_up is called when a node is last encountered, in a pre-order traversal.
74    /// If the node should stop traversal, return Stop.
75    /// Otherwise, return Continue.
76    fn visit_up(
77        &mut self,
78        _node: Self::NodeTy,
79        _context: &Self::Context,
80        _children: Vec<Self::Result>,
81    ) -> VortexResult<FoldUp<Self::Result>>;
82}
83
84/// This trait is used to implemet a fold (see `NodeFolderContext`), but without a context.
85pub trait NodeFolder {
86    type NodeTy: Node;
87    type Result;
88
89    /// visit_down is called when a node is first encountered, in a pre-order traversal.
90    /// If the node's children are to be skipped, return Skip.
91    /// If the node should stop traversal, return Stop.
92    /// Otherwise, return Continue.
93    fn visit_down(&mut self, _node: &Self::NodeTy) -> VortexResult<FoldDown<Self::Result>> {
94        Ok(FoldDown::Continue)
95    }
96
97    /// visit_up is called when a node is last encountered, in a pre-order traversal.
98    /// If the node should stop traversal, return Stop.
99    /// Otherwise, return Continue.
100    fn visit_up(
101        &mut self,
102        _node: Self::NodeTy,
103        _children: Vec<Self::Result>,
104    ) -> VortexResult<FoldUp<Self::Result>>;
105}
106
/// Adapts a context-free [`NodeFolder`] to the [`NodeFolderContext`] interface, using `()`
/// as the context.
///
/// The `T: NodeFolder` requirement lives on the trait impl rather than on the struct, so
/// the bound is not repeated on every use of the type.
pub(crate) struct NodeFolderContextWrapper<'a, T> {
    /// The wrapped folder that receives the forwarded callbacks.
    pub inner: &'a mut T,
}
113
114impl<T: NodeFolder> NodeFolderContext for NodeFolderContextWrapper<'_, T> {
115    type NodeTy = T::NodeTy;
116    type Result = T::Result;
117    type Context = ();
118
119    fn visit_down(
120        &mut self,
121        _ctx: &Self::Context,
122        _node: &Self::NodeTy,
123    ) -> VortexResult<FoldDownContext<Self::Context, Self::Result>> {
124        match self.inner.visit_down(_node)? {
125            FoldDown::Continue => Ok(FoldDownContext::Continue(())),
126            FoldDown::Stop(r) => Ok(FoldDownContext::Stop(r)),
127            FoldDown::Skip(r) => Ok(FoldDownContext::Skip(r)),
128        }
129    }
130
131    fn visit_up(
132        &mut self,
133        _node: Self::NodeTy,
134        _context: &Self::Context,
135        _children: Vec<Self::Result>,
136    ) -> VortexResult<FoldUp<Self::Result>> {
137        self.inner.visit_up(_node, _children)
138    }
139}
140
#[cfg(test)]
mod tests {
    use vortex_error::VortexExpect;

    use super::*;
    use crate::traversal::NodeExt;
    use crate::{
        BinaryVTable, ExprRef, LiteralVTable, Operator, checked_add, gt, lit, vortex_bail,
    };

    /// Test folder that sums `i32` literals joined by `+`.
    ///
    /// * A literal `5` stops the whole fold with result `5`.
    /// * A `>` comparison is skipped (its children are never visited) and contributes `0`.
    /// * Any other expression kind or binary operator makes `visit_up` fail.
    struct AddFold;

    /// Returns the `i32` value of `node` when it is a literal, `None` otherwise.
    ///
    /// Panics if the literal does not hold an `i32`.
    fn literal_i32(node: &ExprRef) -> Option<i32> {
        node.as_opt::<LiteralVTable>().map(|literal| {
            literal
                .value()
                .as_primitive()
                .typed_value::<i32>()
                .vortex_expect("i32")
        })
    }

    impl NodeFolder for AddFold {
        type NodeTy = ExprRef;
        type Result = i32;

        fn visit_down(&mut self, node: &Self::NodeTy) -> VortexResult<FoldDown<Self::Result>> {
            if literal_i32(node) == Some(5) {
                return Ok(FoldDown::Stop(5));
            }

            let is_gt = node
                .as_opt::<BinaryVTable>()
                .is_some_and(|binary| binary.op() == Operator::Gt);

            Ok(if is_gt {
                FoldDown::Skip(0)
            } else {
                FoldDown::Continue
            })
        }

        fn visit_up(
            &mut self,
            node: Self::NodeTy,
            children: Vec<Self::Result>,
        ) -> VortexResult<FoldUp<Self::Result>> {
            if let Some(value) = literal_i32(&node) {
                return Ok(FoldUp::Continue(value));
            }

            let Some(binary) = node.as_opt::<BinaryVTable>() else {
                vortex_bail!("not a valid type")
            };
            if binary.op() != Operator::Add {
                vortex_bail!("not a valid operator")
            }

            Ok(FoldUp::Continue(children[0] + children[1]))
        }
    }

    /// Folds `expr` with a fresh `AddFold` and unwraps the final value.
    fn fold_sum(expr: ExprRef) -> i32 {
        let mut folder = AddFold;
        expr.fold(&mut folder).unwrap().value()
    }

    #[test]
    fn test_fold() {
        // (1 + 2) + 3 folds to the plain sum.
        let expr = checked_add(checked_add(lit(1), lit(2)), lit(3));
        assert_eq!(fold_sum(expr), 6);
    }

    #[test]
    fn test_stop_value() {
        // The literal 5 stops the fold, so the enclosing additions never contribute.
        let expr = checked_add(checked_add(lit(1), lit(5)), lit(3));
        assert_eq!(fold_sum(expr), 5);
    }

    #[test]
    fn test_skip_value() {
        // `1 > 2` is skipped and contributes 0, leaving only the 3.
        let expr = checked_add(gt(lit(1), lit(2)), lit(3));
        assert_eq!(fold_sum(expr), 3);
    }

    #[test]
    fn test_control_flow_value() {
        // The skipped `>` hides its literal 5, so the fold is never stopped.
        let expr = checked_add(gt(lit(1), lit(5)), lit(3));
        assert_eq!(fold_sum(expr), 3);
    }
}