1mod references;
2mod vars;
3mod visitor;
4
5pub use references::ReferenceCollector;
6pub use vars::VarsCollector;
7pub use visitor::{pre_order_visit_down, pre_order_visit_up};
8use vortex_error::VortexResult;
9
10use crate::ExprRef;
11
/// Tells a traversal how to proceed after a visitor callback returns.
///
/// Returned by the `visit_down` / `visit_up` hooks of [`NodeVisitor`] and the
/// `visit_down` hook of [`MutNodeVisitor`], and interpreted by the [`Node`]
/// implementations in this module.
// A fieldless enum that is compared by value everywhere: `Copy` and `Hash` are free
// and let callers pass it around without `.clone()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TraversalOrder {
    /// When returned from `visit_down`, the current node's children are not visited.
    Skip,

    /// Stop visiting any further nodes in the traversal.
    Stop,

    /// Continue the traversal as normal.
    Continue,
}
32
33#[derive(Debug, Clone)]
34pub struct TransformResult<T> {
35 result: T,
36 order: TraversalOrder,
37 changed: bool,
38}
39
40impl<T> TransformResult<T> {
41 pub fn yes(result: T) -> Self {
42 Self {
43 result,
44 order: TraversalOrder::Continue,
45 changed: true,
46 }
47 }
48
49 pub fn no(result: T) -> Self {
50 Self {
51 result,
52 order: TraversalOrder::Continue,
53 changed: false,
54 }
55 }
56
57 pub fn into_inner(self) -> T {
58 self.result
59 }
60}
61
62pub trait NodeVisitor<'a> {
63 type NodeTy: Node;
64
65 fn visit_down(&mut self, _node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
66 Ok(TraversalOrder::Continue)
67 }
68
69 fn visit_up(&mut self, _node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
70 Ok(TraversalOrder::Continue)
71 }
72}
73
74pub trait MutNodeVisitor {
75 type NodeTy: Node;
76
77 fn visit_down(&mut self, _node: &Self::NodeTy) -> VortexResult<TraversalOrder> {
78 Ok(TraversalOrder::Continue)
79 }
80
81 fn visit_up(&mut self, _node: Self::NodeTy) -> VortexResult<TransformResult<Self::NodeTy>>;
82}
83
/// The outcome of a fold's `visit_down` step.
// `Debug` mirrors `FoldUp`; the derive only requires `Out: Debug, Context: Debug`
// on the generated impl, so existing uses are unaffected.
#[derive(Debug)]
pub enum FoldDown<Out, Context> {
    /// Abort the whole fold; the value is handed back up as the final result.
    Abort(Out),
    /// Do not fold this node's children and do not call its `visit_up`; the value
    /// becomes this node's result.
    SkipChildren(Out),
    /// Fold the children, handing each one a clone of this context.
    Continue(Context),
}
92
/// The outcome of a fold's `visit_up` step.
#[derive(Debug)]
pub enum FoldUp<Out> {
    /// Stop the fold; this value is handed back up as the final result.
    Abort(Out),
    /// Keep folding; this value is this node's contribution to its parent.
    Continue(Out),
}

impl<Out> FoldUp<Out> {
    /// Unwraps the carried value, whether or not the fold was aborted.
    pub fn result(self) -> Out {
        match self {
            Self::Abort(value) | Self::Continue(value) => value,
        }
    }
}
109
110pub trait Folder<'a> {
111 type NodeTy: Node;
112 type Out;
113 type Context: Clone;
114
115 fn visit_down(
116 &mut self,
117 _node: &'a Self::NodeTy,
118 context: Self::Context,
119 ) -> VortexResult<FoldDown<Self::Out, Self::Context>> {
120 Ok(FoldDown::Continue(context))
121 }
122
123 fn visit_up(
124 &mut self,
125 node: &'a Self::NodeTy,
126 context: Self::Context,
127 children: Vec<Self::Out>,
128 ) -> VortexResult<FoldUp<Self::Out>>;
129}
130
131pub trait FolderMut {
132 type NodeTy: Node;
133 type Out;
134 type Context: Clone;
135
136 fn visit_down(
137 &mut self,
138 _node: &Self::NodeTy,
139 context: Self::Context,
140 ) -> VortexResult<FoldDown<Self::Out, Self::Context>> {
141 Ok(FoldDown::Continue(context))
142 }
143
144 fn visit_up(
145 &mut self,
146 node: Self::NodeTy,
147 context: Self::Context,
148 children: Vec<Self::Out>,
149 ) -> VortexResult<FoldUp<Self::Out>>;
150}
151
152pub trait Node: Sized {
153 fn accept<'a, V: NodeVisitor<'a, NodeTy = Self>>(
154 &'a self,
155 _visitor: &mut V,
156 ) -> VortexResult<TraversalOrder>;
157
158 fn accept_with_context<'a, V: Folder<'a, NodeTy = Self>>(
159 &'a self,
160 visitor: &mut V,
161 context: V::Context,
162 ) -> VortexResult<FoldUp<V::Out>>;
163
164 fn transform<V: MutNodeVisitor<NodeTy = Self>>(
165 self,
166 _visitor: &mut V,
167 ) -> VortexResult<TransformResult<Self>>;
168
169 fn transform_with_context<V: FolderMut<NodeTy = Self>>(
170 self,
171 _visitor: &mut V,
172 _context: V::Context,
173 ) -> VortexResult<FoldUp<V::Out>>;
174}
175
176impl Node for ExprRef {
177 fn accept<'a, V: NodeVisitor<'a, NodeTy = ExprRef>>(
179 &'a self,
180 visitor: &mut V,
181 ) -> VortexResult<TraversalOrder> {
182 let mut ord = visitor.visit_down(self)?;
183 if ord == TraversalOrder::Stop {
184 return Ok(TraversalOrder::Stop);
185 }
186 if ord == TraversalOrder::Skip {
187 return Ok(TraversalOrder::Continue);
188 }
189 for child in self.children() {
190 if ord != TraversalOrder::Continue {
191 return Ok(ord);
192 }
193 ord = child.accept(visitor)?;
194 }
195 if ord == TraversalOrder::Stop {
196 return Ok(TraversalOrder::Stop);
197 }
198 visitor.visit_up(self)
199 }
200
201 fn accept_with_context<'a, V: Folder<'a, NodeTy = Self>>(
202 &'a self,
203 visitor: &mut V,
204 context: V::Context,
205 ) -> VortexResult<FoldUp<V::Out>> {
206 let children = match visitor.visit_down(self, context.clone())? {
207 FoldDown::Abort(out) => return Ok(FoldUp::Abort(out)),
208 FoldDown::SkipChildren(out) => return Ok(FoldUp::Continue(out)),
209 FoldDown::Continue(child_context) => {
210 let mut new_children = Vec::with_capacity(self.children().len());
211 for child in self.children() {
212 match child.accept_with_context(visitor, child_context.clone())? {
213 FoldUp::Abort(out) => return Ok(FoldUp::Abort(out)),
214 FoldUp::Continue(out) => new_children.push(out),
215 }
216 }
217 new_children
218 }
219 };
220
221 visitor.visit_up(self, context, children)
222 }
223
224 fn transform<V: MutNodeVisitor<NodeTy = Self>>(
226 self,
227 visitor: &mut V,
228 ) -> VortexResult<TransformResult<Self>> {
229 match visitor.visit_down(&self)? {
230 TraversalOrder::Stop => Ok(TransformResult {
231 result: self,
232 order: TraversalOrder::Stop,
233 changed: false,
234 }),
235 TraversalOrder::Skip => visitor.visit_up(self),
236 TraversalOrder::Continue => {
237 let mut new_children = Vec::with_capacity(self.children().len());
238 let mut changed = false;
239 let mut stopped = false;
240 for child in self.children() {
241 if stopped {
242 new_children.push(child.clone());
243 continue;
244 }
245 let TransformResult {
246 result: new_child,
247 order: child_order,
248 changed: child_changed,
249 } = child.clone().transform(visitor)?;
250 new_children.push(new_child);
251 changed |= child_changed;
252 stopped |= child_order == TraversalOrder::Stop;
253 }
254
255 if changed {
256 let up = visitor.visit_up(self.replacing_children(new_children))?;
257 Ok(TransformResult::yes(up.result))
258 } else {
259 visitor.visit_up(self)
260 }
261 }
262 }
263 }
264
265 fn transform_with_context<V: FolderMut<NodeTy = Self>>(
266 self,
267 visitor: &mut V,
268 context: V::Context,
269 ) -> VortexResult<FoldUp<V::Out>> {
270 let children = match visitor.visit_down(&self, context.clone())? {
271 FoldDown::Abort(out) => return Ok(FoldUp::Abort(out)),
272 FoldDown::SkipChildren(out) => return Ok(FoldUp::Continue(out)),
273 FoldDown::Continue(child_context) => {
274 let mut new_children = Vec::with_capacity(self.children().len());
275 for child in self.children() {
276 match child
277 .clone()
278 .transform_with_context(visitor, child_context.clone())?
279 {
280 FoldUp::Abort(out) => return Ok(FoldUp::Abort(out)),
281 FoldUp::Continue(out) => new_children.push(out),
282 }
283 }
284 new_children
285 }
286 };
287
288 visitor.visit_up(self, context, children)
289 }
290}
291
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_error::VortexResult;
    use vortex_utils::aliases::hash_set::HashSet;

    use crate::traversal::visitor::pre_order_visit_down;
    use crate::traversal::{MutNodeVisitor, Node, NodeVisitor, TransformResult, TraversalOrder};
    use crate::{
        BinaryExpr, ExprRef, GetItem, Literal, Operator, VortexExpr, col, get_item_scope, is_root,
        root,
    };

    /// Collects every literal expression seen on the way down.
    #[derive(Default)]
    pub struct ExprLitCollector<'a>(pub Vec<&'a ExprRef>);

    impl<'a> NodeVisitor<'a> for ExprLitCollector<'a> {
        type NodeTy = ExprRef;

        fn visit_down(&mut self, node: &'a ExprRef) -> VortexResult<TraversalOrder> {
            if node.as_any().is::<Literal>() {
                self.0.push(node);
            }
            Ok(TraversalOrder::Continue)
        }

        fn visit_up(&mut self, _node: &'a ExprRef) -> VortexResult<TraversalOrder> {
            Ok(TraversalOrder::Continue)
        }
    }

    /// Replaces each `GetItem` with a literal carrying a running counter (0, 1, ...).
    #[derive(Default)]
    pub struct ExprColToLit(i32);

    impl MutNodeVisitor for ExprColToLit {
        type NodeTy = ExprRef;

        fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<Self::NodeTy>> {
            if !node.as_any().is::<GetItem>() {
                return Ok(TransformResult::no(node));
            }
            let next_id = self.0;
            self.0 += 1;
            Ok(TransformResult::yes(Literal::new_expr(next_id)))
        }
    }

    /// Skips every subtree on the way down and replaces each visited node with `root()`.
    #[derive(Default)]
    pub struct SkipDownVisitor;

    impl MutNodeVisitor for SkipDownVisitor {
        type NodeTy = ExprRef;

        fn visit_down(&mut self, _node: &Self::NodeTy) -> VortexResult<TraversalOrder> {
            Ok(TraversalOrder::Skip)
        }

        fn visit_up(&mut self, _node: Self::NodeTy) -> VortexResult<TransformResult<Self::NodeTy>> {
            Ok(TransformResult::yes(root()))
        }
    }

    /// Builds `(col1 == col2) AND (col3 != col4)`, shared by the skip/stop tests.
    fn eq_and_not_eq() -> ExprRef {
        let c1: Arc<dyn VortexExpr> = col("col1");
        let c2: Arc<dyn VortexExpr> = col("col2");
        let c3: Arc<dyn VortexExpr> = col("col3");
        let c4: Arc<dyn VortexExpr> = col("col4");
        BinaryExpr::new_expr(
            BinaryExpr::new_expr(c1, Operator::Eq, c2),
            Operator::And,
            BinaryExpr::new_expr(c3, Operator::NotEq, c4),
        )
    }

    #[test]
    fn expr_deep_visitor_test() {
        let lhs: Arc<dyn VortexExpr> = col("col1");
        let eq = BinaryExpr::new_expr(lhs, Operator::Eq, Literal::new_expr(1));
        let conj = BinaryExpr::new_expr(eq, Operator::And, Literal::new_expr(2));

        let mut collector = ExprLitCollector::default();
        conj.accept(&mut collector).unwrap();
        assert_eq!(collector.0.len(), 2);
    }

    #[test]
    fn expr_deep_mut_visitor_test() {
        let lhs: Arc<dyn VortexExpr> = col("col1");
        let rhs: Arc<dyn VortexExpr> = col("col2");
        let conj = BinaryExpr::new_expr(
            BinaryExpr::new_expr(lhs, Operator::Eq, rhs),
            Operator::And,
            Literal::new_expr(2),
        );

        let rewritten = conj.transform(&mut ExprColToLit::default()).unwrap();
        assert!(rewritten.changed);

        let mut collector = ExprLitCollector::default();
        rewritten.result.accept(&mut collector).unwrap();
        assert_eq!(collector.0.len(), 3);
    }

    #[test]
    fn expr_skip_test() {
        let expr = eq_and_not_eq();

        let mut seen = Vec::new();
        expr.accept(&mut pre_order_visit_down(|node: &ExprRef| {
            if node.as_any().is::<GetItem>() {
                seen.push(node);
            }
            let is_eq = matches!(
                node.as_any().downcast_ref::<BinaryExpr>(),
                Some(bin) if bin.op() == Operator::Eq
            );
            Ok(if is_eq {
                TraversalOrder::Skip
            } else {
                TraversalOrder::Continue
            })
        }))
        .unwrap();

        let got: HashSet<ExprRef> = seen.into_iter().cloned().collect();
        let want: HashSet<ExprRef> =
            HashSet::from_iter([get_item_scope("col3"), get_item_scope("col4")]);
        assert_eq!(got, want);
    }

    #[test]
    fn expr_stop_test() {
        let expr = eq_and_not_eq();

        let mut seen = Vec::new();
        expr.accept(&mut pre_order_visit_down(|node: &ExprRef| {
            if node.as_any().is::<GetItem>() {
                seen.push(node);
            }
            let is_eq = matches!(
                node.as_any().downcast_ref::<BinaryExpr>(),
                Some(bin) if bin.op() == Operator::Eq
            );
            Ok(if is_eq {
                TraversalOrder::Stop
            } else {
                TraversalOrder::Continue
            })
        }))
        .unwrap();

        assert!(seen.is_empty());
    }

    #[test]
    fn expr_skip_down_visit_up() {
        let column = col("col");

        let outcome = column.transform(&mut SkipDownVisitor).unwrap();

        assert!(outcome.changed);
        assert!(is_root(&outcome.result));
    }
}