1mod fold;
9mod references;
10mod visitor;
11
12use std::marker::PhantomData;
13use std::sync::Arc;
14
15pub use fold::{FoldDown, FoldDownContext, FoldUp, NodeFolder, NodeFolderContext};
16use itertools::Itertools;
17pub use references::ReferenceCollector;
18pub use visitor::{pre_order_visit_down, pre_order_visit_up};
19use vortex_error::VortexResult;
20
21use crate::ExprRef;
22use crate::traversal::fold::NodeFolderContextWrapper;
23
24#[derive(Debug, Clone, PartialEq, Eq)]
26pub enum TraversalOrder {
27 Skip,
31 Stop,
33 Continue,
35}
36
37impl TraversalOrder {
38 pub fn visit_children<F: FnOnce() -> VortexResult<TraversalOrder>>(
40 self,
41 f: F,
42 ) -> VortexResult<TraversalOrder> {
43 match self {
44 Self::Skip => Ok(TraversalOrder::Continue),
45 Self::Stop => Ok(self),
46 Self::Continue => f(),
47 }
48 }
49
50 pub fn visit_parent<F: FnOnce() -> VortexResult<TraversalOrder>>(
52 self,
53 f: F,
54 ) -> VortexResult<TraversalOrder> {
55 match self {
56 Self::Continue => f(),
57 Self::Skip | Self::Stop => Ok(self),
58 }
59 }
60}
61
62#[derive(Debug, Clone)]
63pub struct Transformed<T> {
64 pub value: T,
66 pub order: TraversalOrder,
68 pub changed: bool,
70}
71
72impl<T> Transformed<T> {
73 pub fn yes(value: T) -> Self {
74 Self {
75 value,
76 order: TraversalOrder::Continue,
77 changed: true,
78 }
79 }
80
81 pub fn no(value: T) -> Self {
82 Self {
83 value,
84 order: TraversalOrder::Continue,
85 changed: false,
86 }
87 }
88
89 pub fn into_inner(self) -> T {
90 self.value
91 }
92
93 pub fn map<O, F: FnOnce(T) -> O>(self, f: F) -> Transformed<O> {
95 Transformed {
96 value: f(self.value),
97 order: self.order,
98 changed: self.changed,
99 }
100 }
101}
102
103pub trait NodeVisitor<'a> {
104 type NodeTy: Node;
105
106 fn visit_down(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
107 _ = node;
108 Ok(TraversalOrder::Continue)
109 }
110
111 fn visit_up(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
112 _ = node;
113 Ok(TraversalOrder::Continue)
114 }
115}
116
117pub trait NodeRewriter: Sized {
118 type NodeTy: Node;
119
120 fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
121 Ok(Transformed::no(node))
122 }
123
124 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
125 Ok(Transformed::no(node))
126 }
127}
128
129pub trait Node: Sized + Clone {
130 fn apply_children<'a, F: FnMut(&'a Self) -> VortexResult<TraversalOrder>>(
134 &'a self,
135 f: F,
136 ) -> VortexResult<TraversalOrder>;
137
138 fn map_children<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
142 self,
143 f: F,
144 ) -> VortexResult<Transformed<Self>>;
145
146 fn iter_children<T>(&self, f: impl FnOnce(&mut dyn Iterator<Item = &Self>) -> T) -> T;
148
149 fn children_count(&self) -> usize;
151}
152
153pub trait NodeExt: Node {
154 fn rewrite<R: NodeRewriter<NodeTy = Self>>(
156 self,
157 rewriter: &mut R,
158 ) -> VortexResult<Transformed<Self>> {
159 let mut transformed = rewriter.visit_down(self)?;
160
161 let transformed = match transformed.order {
162 TraversalOrder::Stop => Ok(transformed),
163 TraversalOrder::Skip => {
164 transformed.order = TraversalOrder::Continue;
165 Ok(transformed)
166 }
167 TraversalOrder::Continue => transformed
168 .value
169 .map_children(|c| c.rewrite(rewriter))
170 .map(|mut t| {
171 t.changed |= transformed.changed;
172 t
173 }),
174 }?;
175
176 match transformed.order {
177 TraversalOrder::Stop | TraversalOrder::Skip => Ok(transformed),
178 TraversalOrder::Continue => {
179 let mut up_rewrite = rewriter.visit_up(transformed.value)?;
180 up_rewrite.changed |= transformed.changed;
181 Ok(up_rewrite)
182 }
183 }
184 }
185
186 fn accept<'a, V: NodeVisitor<'a, NodeTy = Self>>(
188 &'a self,
189 visitor: &mut V,
190 ) -> VortexResult<TraversalOrder> {
191 visitor
192 .visit_down(self)?
193 .visit_children(|| self.apply_children(|c| c.accept(visitor)))?
194 .visit_parent(|| visitor.visit_up(self))
195 }
196
197 fn transform_down<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
199 self,
200 f: F,
201 ) -> VortexResult<Transformed<Self>> {
202 let mut rewriter = FnRewriter {
203 f_down: Some(f),
204 f_up: None,
205 _data: PhantomData,
206 };
207
208 self.rewrite(&mut rewriter)
209 }
210
211 fn transform_up<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
213 self,
214 f: F,
215 ) -> VortexResult<Transformed<Self>> {
216 let mut rewriter = FnRewriter {
217 f_down: None,
218 f_up: Some(f),
219 _data: PhantomData,
220 };
221
222 self.rewrite(&mut rewriter)
223 }
224
225 fn fold_context<R, F: NodeFolderContext<NodeTy = Self, Result = R>>(
227 self,
228 ctx: &F::Context,
229 folder: &mut F,
230 ) -> VortexResult<FoldUp<R>> {
231 let transformed = folder.visit_down(ctx, &self)?;
232 let ctx = match transformed {
233 FoldDownContext::Continue(ctx) => ctx,
234 FoldDownContext::Skip(r) => return Ok(FoldUp::Continue(r)),
235 FoldDownContext::Stop(r) => return Ok(FoldUp::Stop(r)),
236 };
237
238 let mut children = Vec::with_capacity(self.children_count());
239 let mut stop_result = None;
240 self.iter_children(|children_iter| -> VortexResult<()> {
241 for c in children_iter {
242 let t = c.clone().fold_context(&ctx, folder)?;
243 match t {
244 FoldUp::Stop(r) => {
245 stop_result = Some(r);
246 return Ok(());
247 }
248 FoldUp::Continue(r) => {
249 children.push(r);
250 }
251 }
252 }
253 Ok(())
254 })?;
255
256 if let Some(result) = stop_result {
257 return Ok(FoldUp::Stop(result));
258 }
259
260 folder.visit_up(self, &ctx, children)
261 }
262
263 fn fold<R, F: NodeFolder<NodeTy = Self, Result = R>>(
265 self,
266 folder: &mut F,
267 ) -> VortexResult<FoldUp<R>> {
268 let mut folder = NodeFolderContextWrapper { inner: folder };
269 self.fold_context(&(), &mut folder)
270 }
271}
272
273impl<T: Node> NodeExt for T {}
274
275struct FnRewriter<F, T> {
276 f_down: Option<F>,
277 f_up: Option<F>,
278 _data: PhantomData<T>,
279}
280
281impl<F, T> NodeRewriter for FnRewriter<F, T>
282where
283 T: Node,
284 F: FnMut(T) -> VortexResult<Transformed<T>>,
285{
286 type NodeTy = T;
287
288 fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
289 if let Some(f) = self.f_down.as_mut() {
290 f(node)
291 } else {
292 Ok(Transformed::no(node))
293 }
294 }
295
296 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
297 if let Some(f) = self.f_up.as_mut() {
298 f(node)
299 } else {
300 Ok(Transformed::no(node))
301 }
302 }
303}
304
305pub trait NodeContainer<'a, T: 'a>: Sized {
310 fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
312 &'a self,
313 f: F,
314 ) -> VortexResult<TraversalOrder>;
315
316 fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
318 self,
319 f: F,
320 ) -> VortexResult<Transformed<Self>>;
321}
322
323pub trait NodeRefContainer<'a, T: 'a>: Sized {
324 fn apply_ref_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
325 &self,
326 f: F,
327 ) -> VortexResult<TraversalOrder>;
328}
329
330impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeRefContainer<'a, T> for Vec<&'a C> {
331 fn apply_ref_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
332 &self,
333 mut f: F,
334 ) -> VortexResult<TraversalOrder> {
335 let mut order = TraversalOrder::Continue;
336
337 for c in self {
338 order = c.apply_elements(&mut f)?;
339 match order {
340 TraversalOrder::Continue | TraversalOrder::Skip => {}
341 TraversalOrder::Stop => return Ok(TraversalOrder::Stop),
342 }
343 }
344
345 Ok(order)
346 }
347}
348
349impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeContainer<'a, T> for Box<C> {
350 fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
351 &'a self,
352 f: F,
353 ) -> VortexResult<TraversalOrder> {
354 self.as_ref().apply_elements(f)
355 }
356
357 fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
358 self,
359 f: F,
360 ) -> VortexResult<Transformed<Box<C>>> {
361 Ok((*self).map_elements(f)?.map(Box::new))
362 }
363}
364
365impl<'a, T, C> NodeContainer<'a, T> for Arc<C>
366where
367 T: 'a,
368 C: NodeContainer<'a, T> + Clone,
369{
370 fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
371 &'a self,
372 f: F,
373 ) -> VortexResult<TraversalOrder> {
374 self.as_ref().apply_elements(f)
375 }
376
377 fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
378 self,
379 f: F,
380 ) -> VortexResult<Transformed<Arc<C>>> {
381 Ok(Arc::unwrap_or_clone(self).map_elements(f)?.map(Arc::new))
382 }
383}
384
385impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeContainer<'a, T> for [C; 2] {
386 fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
387 &'a self,
388 mut f: F,
389 ) -> VortexResult<TraversalOrder> {
390 let [lhs, rhs] = self;
391 match lhs.apply_elements(&mut f)? {
392 TraversalOrder::Skip | TraversalOrder::Continue => rhs.apply_elements(&mut f),
393 TraversalOrder::Stop => Ok(TraversalOrder::Stop),
394 }
395 }
396
397 fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
398 self,
399 mut f: F,
400 ) -> VortexResult<Transformed<[C; 2]>> {
401 let [lhs, rhs] = self;
402 let transformed = lhs.map_elements(&mut f)?;
403 match transformed.order {
404 TraversalOrder::Skip | TraversalOrder::Continue => {
405 let mut t = rhs.map_elements(&mut f)?;
406 t.changed |= transformed.changed;
407 Ok(t.map(|new_lhs| [new_lhs, transformed.value]))
408 }
409 TraversalOrder::Stop => Ok(transformed.map(|new_lhs| [new_lhs, rhs])),
410 }
411 }
412}
413
414impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeContainer<'a, T> for Vec<C> {
415 fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
416 &'a self,
417 mut f: F,
418 ) -> VortexResult<TraversalOrder> {
419 let mut order = TraversalOrder::Continue;
420
421 for c in self {
422 order = c.apply_elements(&mut f)?;
423 match order {
424 TraversalOrder::Continue | TraversalOrder::Skip => {}
425 TraversalOrder::Stop => return Ok(TraversalOrder::Stop),
426 }
427 }
428
429 Ok(order)
430 }
431
432 fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
433 self,
434 mut f: F,
435 ) -> VortexResult<Transformed<Self>> {
436 let mut order = TraversalOrder::Continue;
437 let mut changed = false;
438
439 let value = self
440 .into_iter()
441 .map(|c| match order {
442 TraversalOrder::Continue | TraversalOrder::Skip => {
443 c.map_elements(&mut f).map(|result| {
444 order = result.order;
445 changed |= result.changed;
446 result.value
447 })
448 }
449 TraversalOrder::Stop => Ok(c),
450 })
451 .collect::<VortexResult<Vec<_>>>()?;
452
453 Ok(Transformed {
454 value,
455 order,
456 changed,
457 })
458 }
459}
460
461impl<'a> NodeContainer<'a, Self> for ExprRef {
462 fn apply_elements<F: FnMut(&'a Self) -> VortexResult<TraversalOrder>>(
463 &'a self,
464 mut f: F,
465 ) -> VortexResult<TraversalOrder> {
466 f(self)
467 }
468
469 fn map_elements<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
470 self,
471 mut f: F,
472 ) -> VortexResult<Transformed<Self>> {
473 f(self)
474 }
475}
476
477impl Node for ExprRef {
478 fn apply_children<'a, F: FnMut(&'a Self) -> VortexResult<TraversalOrder>>(
479 &'a self,
480 mut f: F,
481 ) -> VortexResult<TraversalOrder> {
482 self.children().apply_ref_elements(&mut f)
483 }
484
485 fn map_children<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
486 self,
487 f: F,
488 ) -> VortexResult<Transformed<Self>> {
489 let transformed = self
490 .children()
491 .into_iter()
492 .cloned()
493 .collect_vec()
494 .map_elements(f)?;
495
496 if transformed.changed {
497 Ok(Transformed {
498 value: self.with_children(transformed.value)?,
499 order: transformed.order,
500 changed: true,
501 })
502 } else {
503 Ok(Transformed::no(self))
504 }
505 }
506
507 fn iter_children<T>(&self, f: impl FnOnce(&mut dyn Iterator<Item = &Self>) -> T) -> T {
508 f(&mut self.children().into_iter())
509 }
510
511 fn children_count(&self) -> usize {
512 self.children().len()
513 }
514}
515
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_error::VortexResult;
    use vortex_utils::aliases::hash_set::HashSet;

    use crate::traversal::visitor::pre_order_visit_down;
    use crate::traversal::{NodeExt, NodeRewriter, NodeVisitor, Transformed, TraversalOrder};
    use crate::{
        BinaryExpr, BinaryVTable, ExprRef, GetItemVTable, IntoExpr, LiteralExpr, LiteralVTable,
        Operator, VortexExpr, col, is_root, root,
    };

    /// Visitor that records every literal expression it meets on the way down.
    #[derive(Default)]
    pub struct ExprLitCollector<'a>(pub Vec<&'a ExprRef>);

    impl<'a> NodeVisitor<'a> for ExprLitCollector<'a> {
        type NodeTy = ExprRef;

        fn visit_down(&mut self, node: &'a ExprRef) -> VortexResult<TraversalOrder> {
            let Self(literals) = self;
            if node.is::<LiteralVTable>() {
                literals.push(node);
            }
            Ok(TraversalOrder::Continue)
        }

        fn visit_up(&mut self, _node: &'a ExprRef) -> VortexResult<TraversalOrder> {
            Ok(TraversalOrder::Continue)
        }
    }

    /// Replaces every get-item expression with a literal numbered from the running `idx`.
    fn expr_col_to_lit_transform(
        node: ExprRef,
        idx: &mut i32,
    ) -> VortexResult<Transformed<ExprRef>> {
        if !node.is::<GetItemVTable>() {
            return Ok(Transformed::no(node));
        }

        let lit_id = *idx;
        *idx += 1;
        Ok(Transformed::yes(LiteralExpr::new_expr(lit_id)))
    }

    /// Rewriter that skips on the way down and replaces the node with `root()` on the way up.
    #[derive(Default)]
    pub struct SkipDownRewriter;

    impl NodeRewriter for SkipDownRewriter {
        type NodeTy = ExprRef;

        fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
            Ok(Transformed {
                order: TraversalOrder::Skip,
                ..Transformed::no(node)
            })
        }

        fn visit_up(&mut self, _node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
            Ok(Transformed::yes(root()))
        }
    }

    /// Builds `(col1 == col2) AND (col3 != col4)`.
    fn eq_and_not_eq_tree() -> ExprRef {
        let eq = BinaryExpr::new_expr(col("col1"), Operator::Eq, col("col2"));
        let not_eq = BinaryExpr::new_expr(col("col3"), Operator::NotEq, col("col4"));
        BinaryExpr::new_expr(eq, Operator::And, not_eq)
    }

    #[test]
    fn expr_deep_visitor_test() {
        let column: Arc<dyn VortexExpr> = col("col1");
        let one = LiteralExpr::new(1).into_expr();
        let two = LiteralExpr::new(2).into_expr();
        let comparison = BinaryExpr::new(column, Operator::Eq, one).into_expr();
        let tree = BinaryExpr::new(comparison, Operator::And, two).into_expr();

        let mut collector = ExprLitCollector::default();
        tree.accept(&mut collector).unwrap();
        assert_eq!(collector.0.len(), 2);
    }

    #[test]
    fn expr_deep_mut_visitor_test() {
        let left: Arc<dyn VortexExpr> = col("col1");
        let right: Arc<dyn VortexExpr> = col("col2");
        let comparison = BinaryExpr::new_expr(left, Operator::Eq, right);
        let tree = BinaryExpr::new_expr(comparison, Operator::And, LiteralExpr::new_expr(2));

        let mut next_id = 0_i32;
        let Transformed {
            value: rewritten,
            changed,
            ..
        } = tree
            .transform_up(|node| expr_col_to_lit_transform(node, &mut next_id))
            .unwrap();
        assert!(changed);

        let mut collector = ExprLitCollector::default();
        rewritten.accept(&mut collector).unwrap();
        assert_eq!(collector.0.len(), 3);
    }

    #[test]
    fn expr_skip_test() {
        let tree = eq_and_not_eq_tree();

        let mut visited = Vec::new();
        pre_order_visit_down(&tree, |node: &ExprRef| {
            if node.is::<GetItemVTable>() {
                visited.push(node);
            }
            let is_eq = node
                .as_opt::<BinaryVTable>()
                .is_some_and(|bin| bin.op() == Operator::Eq);
            Ok(if is_eq {
                TraversalOrder::Skip
            } else {
                TraversalOrder::Continue
            })
        })
        .unwrap();

        let visited: HashSet<ExprRef> = HashSet::from_iter(visited.into_iter().cloned());
        assert_eq!(visited, HashSet::from_iter([col("col3"), col("col4")]));
    }

    #[test]
    fn expr_stop_test() {
        let tree = eq_and_not_eq_tree();

        let mut visited = Vec::new();
        pre_order_visit_down(&tree, |node: &ExprRef| {
            if node.is::<GetItemVTable>() {
                visited.push(node);
            }
            let is_eq = node
                .as_opt::<BinaryVTable>()
                .is_some_and(|bin| bin.op() == Operator::Eq);
            Ok(if is_eq {
                TraversalOrder::Stop
            } else {
                TraversalOrder::Continue
            })
        })
        .unwrap();

        assert!(visited.is_empty());
    }

    #[test]
    fn expr_skip_down_visit_up() {
        let mut rewriter = SkipDownRewriter;
        let outcome = col("col").rewrite(&mut rewriter).unwrap();

        assert!(outcome.changed);
        assert!(is_root(&outcome.value));
    }
}