1mod references;
2mod visitor;
3
4use itertools::Itertools;
5pub use references::ReferenceCollector;
6pub use visitor::{pre_order_visit_down, pre_order_visit_up};
7use vortex_error::VortexResult;
8
9use crate::ExprRef;
10
/// How a traversal should proceed after a visitor hook has looked at a node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TraversalOrder {
    /// Do not descend into the current node's children.
    ///
    /// When returned from `visit_down`, the children of the current node are not visited.
    Skip,

    /// Abort the traversal: no further nodes are visited.
    Stop,

    /// Carry on with the traversal.
    Continue,
}

/// The outcome of a rewriting step, e.g. of [`Node::transform`] or of a
/// [`MutNodeVisitor::visit_up`] call.
#[derive(Debug, Clone)]
pub struct TransformResult<T> {
    /// The (possibly rewritten) node.
    pub result: T,
    /// How the traversal should proceed after this node.
    order: TraversalOrder,
    /// Whether the producer reported `result` as rewritten.
    changed: bool,
}

impl<T> TransformResult<T> {
    /// Wraps a node that was rewritten; the traversal continues.
    pub fn yes(result: T) -> Self {
        Self {
            result,
            order: TraversalOrder::Continue,
            changed: true,
        }
    }

    /// Wraps a node that was left as it was; the traversal continues.
    pub fn no(result: T) -> Self {
        Self {
            result,
            order: TraversalOrder::Continue,
            changed: false,
        }
    }

    /// How the traversal should proceed after this node.
    ///
    /// Exposed because the field is private and callers outside this module otherwise
    /// cannot inspect it.
    pub fn order(&self) -> TraversalOrder {
        self.order
    }

    /// Whether `result` was reported as rewritten.
    ///
    /// Exposed so callers of [`Node::transform`] outside this module can tell whether the
    /// tree was modified.
    pub fn changed(&self) -> bool {
        self.changed
    }
}
57
58pub trait NodeVisitor<'a> {
59 type NodeTy: Node;
60
61 fn visit_down(&mut self, _node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
62 Ok(TraversalOrder::Continue)
63 }
64
65 fn visit_up(&mut self, _node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
66 Ok(TraversalOrder::Continue)
67 }
68}
69
70pub trait MutNodeVisitor {
71 type NodeTy: Node;
72
73 fn visit_down(&mut self, _node: &Self::NodeTy) -> VortexResult<TraversalOrder> {
74 Ok(TraversalOrder::Continue)
75 }
76
77 fn visit_up(&mut self, _node: Self::NodeTy) -> VortexResult<TransformResult<Self::NodeTy>>;
78}
79
/// The decision a folder's `visit_down` hook makes for the current node.
///
/// Derives `Debug` for consistency with [`FoldUp`], so fold decisions can be logged and
/// asserted on.
#[derive(Debug)]
pub enum FoldDown<Out, Context> {
    /// Abort the whole fold, yielding this value as the final result.
    Abort(Out),
    /// Do not visit the node's children; use this value as the node's result.
    SkipChildren(Out),
    /// Visit the node's children, handing each of them this context.
    Continue(Context),
}
88
/// The outcome of folding a node.
#[derive(Debug)]
pub enum FoldUp<Out> {
    /// The fold was aborted; the value is passed straight up to the root.
    Abort(Out),
    /// The fold carries on with this value as the node's output.
    Continue(Out),
}

impl<Out> FoldUp<Out> {
    /// Returns the carried value, whether or not the fold was aborted.
    pub fn result(self) -> Out {
        let (Self::Abort(value) | Self::Continue(value)) = self;
        value
    }
}
105
106pub trait Folder<'a> {
107 type NodeTy: Node;
108 type Out;
109 type Context: Clone;
110
111 fn visit_down(
112 &mut self,
113 _node: &'a Self::NodeTy,
114 context: Self::Context,
115 ) -> VortexResult<FoldDown<Self::Out, Self::Context>> {
116 Ok(FoldDown::Continue(context))
117 }
118
119 fn visit_up(
120 &mut self,
121 node: &'a Self::NodeTy,
122 context: Self::Context,
123 children: Vec<Self::Out>,
124 ) -> VortexResult<FoldUp<Self::Out>>;
125}
126
127pub trait FolderMut {
128 type NodeTy: Node;
129 type Out;
130 type Context: Clone;
131
132 fn visit_down(
133 &mut self,
134 _node: &Self::NodeTy,
135 context: Self::Context,
136 ) -> VortexResult<FoldDown<Self::Out, Self::Context>> {
137 Ok(FoldDown::Continue(context))
138 }
139
140 fn visit_up(
141 &mut self,
142 node: Self::NodeTy,
143 context: Self::Context,
144 children: Vec<Self::Out>,
145 ) -> VortexResult<FoldUp<Self::Out>>;
146}
147
148pub trait Node: Sized {
149 fn accept<'a, V: NodeVisitor<'a, NodeTy = Self>>(
150 &'a self,
151 _visitor: &mut V,
152 ) -> VortexResult<TraversalOrder>;
153
154 fn accept_with_context<'a, V: Folder<'a, NodeTy = Self>>(
155 &'a self,
156 visitor: &mut V,
157 context: V::Context,
158 ) -> VortexResult<FoldUp<V::Out>>;
159
160 fn transform<V: MutNodeVisitor<NodeTy = Self>>(
161 self,
162 _visitor: &mut V,
163 ) -> VortexResult<TransformResult<Self>>;
164
165 fn transform_with_context<V: FolderMut<NodeTy = Self>>(
166 self,
167 _visitor: &mut V,
168 _context: V::Context,
169 ) -> VortexResult<FoldUp<V::Out>>;
170}
171
172impl Node for ExprRef {
173 fn accept<'a, V: NodeVisitor<'a, NodeTy = ExprRef>>(
175 &'a self,
176 visitor: &mut V,
177 ) -> VortexResult<TraversalOrder> {
178 let mut ord = visitor.visit_down(self)?;
179 if ord == TraversalOrder::Stop {
180 return Ok(TraversalOrder::Stop);
181 }
182 if ord == TraversalOrder::Skip {
183 return Ok(TraversalOrder::Continue);
184 }
185 for child in self.children() {
186 if ord != TraversalOrder::Continue {
187 return Ok(ord);
188 }
189 ord = child.accept(visitor)?;
190 }
191 if ord == TraversalOrder::Stop {
192 return Ok(TraversalOrder::Stop);
193 }
194 visitor.visit_up(self)
195 }
196
197 fn accept_with_context<'a, V: Folder<'a, NodeTy = Self>>(
198 &'a self,
199 visitor: &mut V,
200 context: V::Context,
201 ) -> VortexResult<FoldUp<V::Out>> {
202 let children = match visitor.visit_down(self, context.clone())? {
203 FoldDown::Abort(out) => return Ok(FoldUp::Abort(out)),
204 FoldDown::SkipChildren(out) => return Ok(FoldUp::Continue(out)),
205 FoldDown::Continue(child_context) => {
206 let mut new_children = Vec::with_capacity(self.children().len());
207 for child in self.children() {
208 match child.accept_with_context(visitor, child_context.clone())? {
209 FoldUp::Abort(out) => return Ok(FoldUp::Abort(out)),
210 FoldUp::Continue(out) => new_children.push(out),
211 }
212 }
213 new_children
214 }
215 };
216
217 visitor.visit_up(self, context, children)
218 }
219
220 fn transform<V: MutNodeVisitor<NodeTy = Self>>(
222 self,
223 visitor: &mut V,
224 ) -> VortexResult<TransformResult<Self>> {
225 let mut ord = visitor.visit_down(&self)?;
226 if ord == TraversalOrder::Stop {
227 return Ok(TransformResult {
228 result: self,
229 order: TraversalOrder::Stop,
230 changed: false,
231 });
232 }
233 let (children, ord, changed) = if ord == TraversalOrder::Continue {
234 let mut new_children = Vec::with_capacity(self.children().len());
235 let mut changed = false;
236 for child in self.children() {
237 match ord {
238 TraversalOrder::Continue | TraversalOrder::Skip => {
239 let TransformResult {
240 result: new_child,
241 order: child_order,
242 changed: child_changed,
243 } = child.clone().transform(visitor)?;
244 new_children.push(new_child);
245 ord = child_order;
246 changed |= child_changed;
247 }
248 TraversalOrder::Stop => new_children.push(child.clone()),
249 }
250 }
251 (new_children, ord, changed)
252 } else {
253 (
254 self.children().into_iter().cloned().collect_vec(),
255 ord,
256 false,
257 )
258 };
259
260 if ord == TraversalOrder::Continue {
261 let up = visitor.visit_up(self.replacing_children(children))?;
262 Ok(TransformResult::yes(up.result))
263 } else {
264 Ok(TransformResult {
265 result: self.replacing_children(children),
266 order: ord,
267 changed,
268 })
269 }
270 }
271
272 fn transform_with_context<V: FolderMut<NodeTy = Self>>(
273 self,
274 visitor: &mut V,
275 context: V::Context,
276 ) -> VortexResult<FoldUp<V::Out>> {
277 let children = match visitor.visit_down(&self, context.clone())? {
278 FoldDown::Abort(out) => return Ok(FoldUp::Abort(out)),
279 FoldDown::SkipChildren(out) => return Ok(FoldUp::Continue(out)),
280 FoldDown::Continue(child_context) => {
281 let mut new_children = Vec::with_capacity(self.children().len());
282 for child in self.children() {
283 match child
284 .clone()
285 .transform_with_context(visitor, child_context.clone())?
286 {
287 FoldUp::Abort(out) => return Ok(FoldUp::Abort(out)),
288 FoldUp::Continue(out) => new_children.push(out),
289 }
290 }
291 new_children
292 }
293 };
294
295 visitor.visit_up(self, context, children)
296 }
297}
298
299#[cfg(test)]
300mod tests {
301 use std::sync::Arc;
302
303 use vortex_array::aliases::hash_set::HashSet;
304 use vortex_error::VortexResult;
305
306 use crate::traversal::visitor::pre_order_visit_down;
307 use crate::traversal::{MutNodeVisitor, Node, NodeVisitor, TransformResult, TraversalOrder};
308 use crate::{
309 BinaryExpr, ExprRef, FieldName, GetItem, Literal, Operator, VortexExpr, VortexExprExt, col,
310 };
311
312 #[derive(Default)]
313 pub struct ExprLitCollector<'a>(pub Vec<&'a ExprRef>);
314
315 impl<'a> NodeVisitor<'a> for ExprLitCollector<'a> {
316 type NodeTy = ExprRef;
317
318 fn visit_down(&mut self, node: &'a ExprRef) -> VortexResult<TraversalOrder> {
319 if node.as_any().is::<Literal>() {
320 self.0.push(node)
321 }
322 Ok(TraversalOrder::Continue)
323 }
324
325 fn visit_up(&mut self, _node: &'a ExprRef) -> VortexResult<TraversalOrder> {
326 Ok(TraversalOrder::Continue)
327 }
328 }
329
330 #[derive(Default)]
331 pub struct ExprColToLit(i32);
332
333 impl MutNodeVisitor for ExprColToLit {
334 type NodeTy = ExprRef;
335
336 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<Self::NodeTy>> {
337 let col = node.as_any().downcast_ref::<GetItem>();
338 if col.is_some() {
339 let id = self.0;
340 self.0 += 1;
341 Ok(TransformResult::yes(Literal::new_expr(id)))
342 } else {
343 Ok(TransformResult::no(node))
344 }
345 }
346 }
347
348 #[test]
349 fn expr_deep_visitor_test() {
350 let col1: Arc<dyn VortexExpr> = col("col1");
351 let lit1 = Literal::new_expr(1);
352 let expr = BinaryExpr::new_expr(col1.clone(), Operator::Eq, lit1.clone());
353 let lit2 = Literal::new_expr(2);
354 let expr = BinaryExpr::new_expr(expr, Operator::And, lit2);
355 let mut printer = ExprLitCollector::default();
356 expr.accept(&mut printer).unwrap();
357 assert_eq!(printer.0.len(), 2);
358 }
359
360 #[test]
361 fn expr_deep_mut_visitor_test() {
362 let col1: Arc<dyn VortexExpr> = col("col1");
363 let col2: Arc<dyn VortexExpr> = col("col2");
364 let expr = BinaryExpr::new_expr(col1.clone(), Operator::Eq, col2.clone());
365 let lit2 = Literal::new_expr(2);
366 let expr = BinaryExpr::new_expr(expr, Operator::And, lit2);
367 let mut printer = ExprColToLit::default();
368 let new = expr.transform(&mut printer).unwrap();
369 assert!(new.changed);
370
371 let expr = new.result;
372
373 let mut printer = ExprLitCollector::default();
374 expr.accept(&mut printer).unwrap();
375 assert_eq!(printer.0.len(), 3);
376 }
377
378 #[test]
379 fn expr_skip_test() {
380 let col1: Arc<dyn VortexExpr> = col("col1");
381 let col2: Arc<dyn VortexExpr> = col("col2");
382 let expr1 = BinaryExpr::new_expr(col1.clone(), Operator::Eq, col2.clone());
383 let col3: Arc<dyn VortexExpr> = col("col3");
384 let col4: Arc<dyn VortexExpr> = col("col4");
385 let expr2 = BinaryExpr::new_expr(col3.clone(), Operator::NotEq, col4.clone());
386 let expr = BinaryExpr::new_expr(expr1, Operator::And, expr2);
387
388 let mut nodes = Vec::new();
389 expr.accept(&mut pre_order_visit_down(|node: &ExprRef| {
390 if node.as_any().is::<GetItem>() {
391 nodes.push(node)
392 }
393 if let Some(bin) = node.as_any().downcast_ref::<BinaryExpr>() {
394 if bin.op() == Operator::Eq {
395 return Ok(TraversalOrder::Skip);
396 }
397 }
398 Ok(TraversalOrder::Continue)
399 }))
400 .unwrap();
401
402 assert_eq!(
403 nodes
404 .into_iter()
405 .map(|x| x.references())
406 .fold(HashSet::new(), |acc, x| acc.union(&x).cloned().collect()),
407 HashSet::from_iter(vec![FieldName::from("col3"), FieldName::from("col4")])
408 );
409 }
410
411 #[test]
412 fn expr_stop_test() {
413 let col1: Arc<dyn VortexExpr> = col("col1");
414 let col2: Arc<dyn VortexExpr> = col("col2");
415 let expr1 = BinaryExpr::new_expr(col1.clone(), Operator::Eq, col2.clone());
416 let col3: Arc<dyn VortexExpr> = col("col3");
417 let col4: Arc<dyn VortexExpr> = col("col4");
418 let expr2 = BinaryExpr::new_expr(col3.clone(), Operator::NotEq, col4.clone());
419 let expr = BinaryExpr::new_expr(expr1, Operator::And, expr2);
420
421 let mut nodes = Vec::new();
422 expr.accept(&mut pre_order_visit_down(|node: &ExprRef| {
423 if node.as_any().is::<GetItem>() {
424 nodes.push(node)
425 }
426 if let Some(bin) = node.as_any().downcast_ref::<BinaryExpr>() {
427 if bin.op() == Operator::Eq {
428 return Ok(TraversalOrder::Stop);
429 }
430 }
431 Ok(TraversalOrder::Continue)
432 }))
433 .unwrap();
434
435 assert_eq!(
436 nodes
437 .into_iter()
438 .map(|x| x.references())
439 .fold(HashSet::new(), |acc, x| acc.union(&x).cloned().collect()),
440 HashSet::from_iter(vec![])
441 );
442 }
443}