1mod fold;
9mod references;
10mod visitor;
11
12use std::marker::PhantomData;
13use std::sync::Arc;
14
15pub use fold::FoldDown;
16pub use fold::FoldDownContext;
17pub use fold::FoldUp;
18pub use fold::NodeFolder;
19pub use fold::NodeFolderContext;
20use itertools::Itertools;
21pub use references::ReferenceCollector;
22pub use visitor::pre_order_visit_down;
23pub use visitor::pre_order_visit_up;
24use vortex_error::VortexResult;
25
26use crate::expr::Expression;
27use crate::expr::traversal::fold::NodeFolderContextWrapper;
28
/// Signal returned by visitor and rewriter callbacks to steer a tree traversal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TraversalOrder {
    /// Do not descend into the children of the current node.
    ///
    /// [`TraversalOrder::visit_children`] consumes this signal: it skips its callback and
    /// resumes with `Continue`. [`TraversalOrder::visit_parent`] passes it through unchanged
    /// without running its callback.
    Skip,
    /// Halt the traversal: neither `visit_children` nor `visit_parent` runs its callback and
    /// both return `Stop`.
    Stop,
    /// Keep traversing normally.
    Continue,
}
41
42impl TraversalOrder {
43 pub fn visit_children<F: FnOnce() -> VortexResult<TraversalOrder>>(
45 self,
46 f: F,
47 ) -> VortexResult<TraversalOrder> {
48 match self {
49 Self::Skip => Ok(TraversalOrder::Continue),
50 Self::Stop => Ok(self),
51 Self::Continue => f(),
52 }
53 }
54
55 pub fn visit_parent<F: FnOnce() -> VortexResult<TraversalOrder>>(
57 self,
58 f: F,
59 ) -> VortexResult<TraversalOrder> {
60 match self {
61 Self::Continue => f(),
62 Self::Skip | Self::Stop => Ok(self),
63 }
64 }
65}
66
/// The result of one rewrite step: a value plus traversal bookkeeping.
#[derive(Debug, Clone)]
pub struct Transformed<T> {
    /// The (possibly rewritten) value.
    pub value: T,
    /// How the traversal should proceed after this step.
    pub order: TraversalOrder,
    /// Whether `value` was modified by the step that produced it.
    pub changed: bool,
}
76
77impl<T> Transformed<T> {
78 pub fn yes(value: T) -> Self {
79 Self {
80 value,
81 order: TraversalOrder::Continue,
82 changed: true,
83 }
84 }
85
86 pub fn no(value: T) -> Self {
87 Self {
88 value,
89 order: TraversalOrder::Continue,
90 changed: false,
91 }
92 }
93
94 pub fn into_inner(self) -> T {
95 self.value
96 }
97
98 pub fn map<O, F: FnOnce(T) -> O>(self, f: F) -> Transformed<O> {
100 Transformed {
101 value: f(self.value),
102 order: self.order,
103 changed: self.changed,
104 }
105 }
106}
107
/// Read-only visitor over a tree of [`Node`]s, driven by [`NodeExt::accept`].
///
/// Both hooks default to doing nothing and returning [`TraversalOrder::Continue`], so an
/// implementor only needs to override the direction it cares about.
pub trait NodeVisitor<'a> {
    /// The node type this visitor walks.
    type NodeTy: Node;

    /// Called on a node before its children are visited.
    fn visit_down(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
        // Explicitly discard the argument so the parameter can keep a meaningful name.
        _ = node;
        Ok(TraversalOrder::Continue)
    }

    /// Called on a node after its children have been visited.
    fn visit_up(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
        // Explicitly discard the argument so the parameter can keep a meaningful name.
        _ = node;
        Ok(TraversalOrder::Continue)
    }
}
121
/// Owning rewriter over a tree of [`Node`]s, driven by [`NodeExt::rewrite`].
///
/// Both hooks default to returning the node unchanged via [`Transformed::no`].
pub trait NodeRewriter: Sized {
    /// The node type this rewriter transforms.
    type NodeTy: Node;

    /// Called on a node before its children are rewritten.
    fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
        Ok(Transformed::no(node))
    }

    /// Called on a node after its children have been rewritten.
    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
        Ok(Transformed::no(node))
    }
}
133
/// A tree node whose children have the same type as the node itself.
///
/// These are the low-level primitives that the traversal helpers in [`NodeExt`] build on.
pub trait Node: Sized + Clone {
    /// Walks the node's children by reference, applying `f` to them.
    ///
    /// Returns the [`TraversalOrder`] resulting from the walk.
    fn apply_children<'a, F: FnMut(&'a Self) -> VortexResult<TraversalOrder>>(
        &'a self,
        f: F,
    ) -> VortexResult<TraversalOrder>;

    /// Rewrites the node's children by applying `f` to them, returning the resulting node.
    fn map_children<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
        self,
        f: F,
    ) -> VortexResult<Transformed<Self>>;

    /// Hands an iterator over the node's children to `f` and returns whatever `f` returns.
    fn iter_children<T>(&self, f: impl FnOnce(&mut dyn Iterator<Item = &Self>) -> T) -> T;

    /// Returns the number of children of this node.
    fn children_count(&self) -> usize;
}
157
/// Traversal helpers available on every [`Node`] (see the blanket impl in this file).
pub trait NodeExt: Node {
    /// Rewrites the tree rooted at `self`: `visit_down` on the node, then a recursive rewrite
    /// of its children, then `visit_up` on the node.
    ///
    /// The returned `changed` flag is the OR of every step that ran for this node.
    fn rewrite<R: NodeRewriter<NodeTy = Self>>(
        self,
        rewriter: &mut R,
    ) -> VortexResult<Transformed<Self>> {
        let mut transformed = rewriter.visit_down(self)?;

        let transformed = match transformed.order {
            // Stop: do not touch the children; the match below returns immediately as well.
            TraversalOrder::Stop => Ok(transformed),
            // Skip only suppresses the descent into the children. It is reset to Continue so
            // that `visit_up` still runs for this node.
            TraversalOrder::Skip => {
                transformed.order = TraversalOrder::Continue;
                Ok(transformed)
            }
            TraversalOrder::Continue => transformed
                .value
                .map_children(|c| c.rewrite(rewriter))
                .map(|mut t| {
                    // Keep the change made by `visit_down` even if no child changed.
                    t.changed |= transformed.changed;
                    t
                }),
        }?;

        match transformed.order {
            // The downward visit or the children asked to stop/skip: no `visit_up` here.
            TraversalOrder::Stop | TraversalOrder::Skip => Ok(transformed),
            TraversalOrder::Continue => {
                let mut up_rewrite = rewriter.visit_up(transformed.value)?;
                up_rewrite.changed |= transformed.changed;
                Ok(up_rewrite)
            }
        }
    }

    /// Walks the tree rooted at `self` by reference: `visit_down` on the node, then its
    /// children (unless skipped or stopped), then `visit_up` (unless skipped or stopped).
    fn accept<'a, V: NodeVisitor<'a, NodeTy = Self>>(
        &'a self,
        visitor: &mut V,
    ) -> VortexResult<TraversalOrder> {
        visitor
            .visit_down(self)?
            .visit_children(|| self.apply_children(|c| c.accept(visitor)))?
            .visit_parent(|| visitor.visit_up(self))
    }

    /// Rewrites the tree applying `f` on the way down only.
    fn transform_down<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
        self,
        f: F,
    ) -> VortexResult<Transformed<Self>> {
        // The up slot is empty; `F` is reused as its type only so the struct is fully typed.
        let mut rewriter = FnRewriter::<F, F, _> {
            f_down: Some(f),
            f_up: None,
            _data: PhantomData,
        };

        self.rewrite(&mut rewriter)
    }

    /// Rewrites the tree applying `down` before the children and `up` after them.
    fn transform<F, G>(self, down: F, up: G) -> VortexResult<Transformed<Self>>
    where
        F: FnMut(Self) -> VortexResult<Transformed<Self>>,
        G: FnMut(Self) -> VortexResult<Transformed<Self>>,
    {
        let mut rewriter = FnRewriter {
            f_down: Some(down),
            f_up: Some(up),
            _data: PhantomData,
        };

        self.rewrite(&mut rewriter)
    }

    /// Rewrites the tree applying `f` on the way up only.
    fn transform_up<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
        self,
        f: F,
    ) -> VortexResult<Transformed<Self>> {
        // The down slot is empty; `F` is reused as its type only so the struct is fully typed.
        let mut rewriter = FnRewriter::<F, F, _> {
            f_down: None,
            f_up: Some(f),
            _data: PhantomData,
        };

        self.rewrite(&mut rewriter)
    }

    /// Folds the tree into a single result, passing a context from each node to its children.
    ///
    /// `visit_down` either yields the context for the children or short-circuits with a
    /// result (`Skip` continues the fold, `Stop` ends it). Child results are collected in
    /// order and passed to `visit_up` together with the node and the context.
    fn fold_context<R, F: NodeFolderContext<NodeTy = Self, Result = R>>(
        self,
        ctx: &F::Context,
        folder: &mut F,
    ) -> VortexResult<FoldUp<R>> {
        let transformed = folder.visit_down(ctx, &self)?;
        // From here on `ctx` is the context produced for this node's children.
        let ctx = match transformed {
            FoldDownContext::Continue(ctx) => ctx,
            FoldDownContext::Skip(r) => return Ok(FoldUp::Continue(r)),
            FoldDownContext::Stop(r) => return Ok(FoldUp::Stop(r)),
        };

        let mut children = Vec::with_capacity(self.children_count());
        // Set when a child asks to stop; the closure cannot return the fold result directly.
        let mut stop_result = None;
        self.iter_children(|children_iter| -> VortexResult<()> {
            for c in children_iter {
                // `fold_context` consumes the node, so each borrowed child is cloned first.
                let t = c.clone().fold_context(&ctx, folder)?;
                match t {
                    FoldUp::Stop(r) => {
                        stop_result = Some(r);
                        return Ok(());
                    }
                    FoldUp::Continue(r) => {
                        children.push(r);
                    }
                }
            }
            Ok(())
        })?;

        // A stopping child ends the fold without calling `visit_up` on this node.
        if let Some(result) = stop_result {
            return Ok(FoldUp::Stop(result));
        }

        folder.visit_up(self, &ctx, children)
    }

    /// Folds the tree with a context-free folder by running [`NodeExt::fold_context`] with a
    /// `()` context.
    fn fold<R, F: NodeFolder<NodeTy = Self, Result = R>>(
        self,
        folder: &mut F,
    ) -> VortexResult<FoldUp<R>> {
        let mut folder = NodeFolderContextWrapper { inner: folder };
        self.fold_context(&(), &mut folder)
    }
}
291
// Blanket impl: every `Node` automatically gets the `NodeExt` traversal helpers.
impl<T: Node> NodeExt for T {}
293
/// Adapter that turns a pair of optional closures into a [`NodeRewriter`] over `T`.
struct FnRewriter<F, G, T> {
    /// Closure applied on the way down, if any.
    f_down: Option<F>,
    /// Closure applied on the way up, if any.
    f_up: Option<G>,
    /// Ties the node type `T` to the struct without storing a value of it.
    _data: PhantomData<T>,
}
299
300impl<F, G, T> NodeRewriter for FnRewriter<F, G, T>
301where
302 T: Node,
303 F: FnMut(T) -> VortexResult<Transformed<T>>,
304 G: FnMut(T) -> VortexResult<Transformed<T>>,
305{
306 type NodeTy = T;
307
308 fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
309 if let Some(f) = self.f_down.as_mut() {
310 f(node)
311 } else {
312 Ok(Transformed::no(node))
313 }
314 }
315
316 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
317 if let Some(f) = self.f_up.as_mut() {
318 f(node)
319 } else {
320 Ok(Transformed::no(node))
321 }
322 }
323}
324
/// A container of nodes of type `T` (for example a `Vec`, a `Box`, or a single node itself)
/// that can forward visiting and rewriting closures to every node it holds.
pub trait NodeContainer<'a, T: 'a>: Sized {
    /// Applies `f` to the contained nodes by reference and returns the resulting
    /// [`TraversalOrder`].
    fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
        &'a self,
        f: F,
    ) -> VortexResult<TraversalOrder>;

    /// Consumes the container, maps the contained nodes through `f`, and returns the rebuilt
    /// container wrapped in a [`Transformed`].
    fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
        self,
        f: F,
    ) -> VortexResult<Transformed<Self>>;
}
342
/// Visiting-only counterpart of [`NodeContainer`] for containers that are themselves
/// references (such as `&'a [C]`): the nodes live for `'a`, independent of the borrow of
/// `self`.
pub trait NodeRefContainer<'a, T: 'a>: Sized {
    /// Applies `f` to the referenced nodes and returns the resulting [`TraversalOrder`].
    fn apply_ref_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
        &self,
        f: F,
    ) -> VortexResult<TraversalOrder>;
}
349
350impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeRefContainer<'a, T> for &'a [C] {
351 fn apply_ref_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
352 &self,
353 mut f: F,
354 ) -> VortexResult<TraversalOrder> {
355 let mut order = TraversalOrder::Continue;
356
357 for c in *self {
358 order = c.apply_elements(&mut f)?;
359 match order {
360 TraversalOrder::Continue | TraversalOrder::Skip => {}
361 TraversalOrder::Stop => return Ok(TraversalOrder::Stop),
362 }
363 }
364
365 Ok(order)
366 }
367}
368
369impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeContainer<'a, T> for Box<C> {
370 fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
371 &'a self,
372 f: F,
373 ) -> VortexResult<TraversalOrder> {
374 self.as_ref().apply_elements(f)
375 }
376
377 fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
378 self,
379 f: F,
380 ) -> VortexResult<Transformed<Box<C>>> {
381 Ok((*self).map_elements(f)?.map(Box::new))
382 }
383}
384
385impl<'a, T, C> NodeContainer<'a, T> for Arc<C>
386where
387 T: 'a,
388 C: NodeContainer<'a, T> + Clone,
389{
390 fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
391 &'a self,
392 f: F,
393 ) -> VortexResult<TraversalOrder> {
394 self.as_ref().apply_elements(f)
395 }
396
397 fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
398 self,
399 f: F,
400 ) -> VortexResult<Transformed<Arc<C>>> {
401 Ok(Arc::unwrap_or_clone(self).map_elements(f)?.map(Arc::new))
402 }
403}
404
405impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeContainer<'a, T> for [C; 2] {
406 fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
407 &'a self,
408 mut f: F,
409 ) -> VortexResult<TraversalOrder> {
410 let [lhs, rhs] = self;
411 match lhs.apply_elements(&mut f)? {
412 TraversalOrder::Skip | TraversalOrder::Continue => rhs.apply_elements(&mut f),
413 TraversalOrder::Stop => Ok(TraversalOrder::Stop),
414 }
415 }
416
417 fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
418 self,
419 mut f: F,
420 ) -> VortexResult<Transformed<[C; 2]>> {
421 let [lhs, rhs] = self;
422 let transformed = lhs.map_elements(&mut f)?;
423 match transformed.order {
424 TraversalOrder::Skip | TraversalOrder::Continue => {
425 let mut t = rhs.map_elements(&mut f)?;
426 t.changed |= transformed.changed;
427 Ok(t.map(|new_lhs| [new_lhs, transformed.value]))
428 }
429 TraversalOrder::Stop => Ok(transformed.map(|new_lhs| [new_lhs, rhs])),
430 }
431 }
432}
433
434impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeContainer<'a, T> for Vec<C> {
435 fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
436 &'a self,
437 mut f: F,
438 ) -> VortexResult<TraversalOrder> {
439 let mut order = TraversalOrder::Continue;
440
441 for c in self {
442 order = c.apply_elements(&mut f)?;
443 match order {
444 TraversalOrder::Continue | TraversalOrder::Skip => {}
445 TraversalOrder::Stop => return Ok(TraversalOrder::Stop),
446 }
447 }
448
449 Ok(order)
450 }
451
452 fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
453 self,
454 mut f: F,
455 ) -> VortexResult<Transformed<Self>> {
456 let mut order = TraversalOrder::Continue;
457 let mut changed = false;
458
459 let value = self
460 .into_iter()
461 .map(|c| match order {
462 TraversalOrder::Continue | TraversalOrder::Skip => {
463 c.map_elements(&mut f).map(|result| {
464 order = result.order;
465 changed |= result.changed;
466 result.value
467 })
468 }
469 TraversalOrder::Stop => Ok(c),
470 })
471 .collect::<VortexResult<Vec<_>>>()?;
472
473 Ok(Transformed {
474 value,
475 order,
476 changed,
477 })
478 }
479}
480
/// An [`Expression`] is the base-case container: it holds exactly one node, itself.
impl<'a> NodeContainer<'a, Self> for Expression {
    /// Applies `f` to this expression.
    fn apply_elements<F: FnMut(&'a Self) -> VortexResult<TraversalOrder>>(
        &'a self,
        mut f: F,
    ) -> VortexResult<TraversalOrder> {
        f(self)
    }

    /// Maps this expression through `f`.
    fn map_elements<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
        self,
        mut f: F,
    ) -> VortexResult<Transformed<Self>> {
        f(self)
    }
}
496
497impl Node for Expression {
498 fn apply_children<'a, F: FnMut(&'a Self) -> VortexResult<TraversalOrder>>(
499 &'a self,
500 mut f: F,
501 ) -> VortexResult<TraversalOrder> {
502 self.children().as_ref().apply_ref_elements(&mut f)
503 }
504
505 fn map_children<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
506 self,
507 f: F,
508 ) -> VortexResult<Transformed<Self>> {
509 let transformed = self
510 .children()
511 .iter()
512 .cloned()
513 .collect_vec()
514 .map_elements(f)?;
515
516 if transformed.changed {
517 Ok(Transformed {
518 value: self.with_children(transformed.value)?,
519 order: transformed.order,
520 changed: true,
521 })
522 } else {
523 Ok(Transformed::no(self))
524 }
525 }
526
527 fn iter_children<T>(&self, f: impl FnOnce(&mut dyn Iterator<Item = &Self>) -> T) -> T {
528 f(&mut self.children().iter())
529 }
530
531 fn children_count(&self) -> usize {
532 self.children().len()
533 }
534}
535
// Tests for the visitor, rewriter and pre-order helpers on `Expression` trees.
#[cfg(test)]
mod tests {
    use vortex_error::VortexResult;
    use vortex_utils::aliases::hash_set::HashSet;

    use super::NodeExt;
    use super::NodeRewriter;
    use super::NodeVisitor;
    use super::Transformed;
    use super::TraversalOrder;
    use super::visitor::pre_order_visit_down;
    use crate::expr::Expression;
    use crate::expr::exprs::binary::Binary;
    use crate::expr::exprs::binary::and;
    use crate::expr::exprs::binary::eq;
    use crate::expr::exprs::binary::not_eq;
    use crate::expr::exprs::get_item::GetItem;
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::literal::Literal;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::operators::Operator;
    use crate::expr::exprs::root::is_root;
    use crate::expr::exprs::root::root;

    /// Visitor that records every `Literal` expression it meets on the way down.
    #[derive(Default)]
    pub struct ExprLitCollector<'a>(pub Vec<&'a Expression>);

    impl<'a> NodeVisitor<'a> for ExprLitCollector<'a> {
        type NodeTy = Expression;

        fn visit_down(&mut self, node: &'a Expression) -> VortexResult<TraversalOrder> {
            if node.is::<Literal>() {
                self.0.push(node)
            }
            Ok(TraversalOrder::Continue)
        }

        fn visit_up(&mut self, _node: &'a Expression) -> VortexResult<TraversalOrder> {
            Ok(TraversalOrder::Continue)
        }
    }

    /// Replaces every `GetItem` expression with a literal holding the running counter `idx`;
    /// every other expression is returned unchanged.
    fn expr_col_to_lit_transform(
        node: Expression,
        idx: &mut i32,
    ) -> VortexResult<Transformed<Expression>> {
        if node.is::<GetItem>() {
            let lit_id = *idx;
            *idx += 1;
            Ok(Transformed::yes(lit(lit_id)))
        } else {
            Ok(Transformed::no(node))
        }
    }

    /// Rewriter that returns `Skip` on the way down and replaces the node with `root()` on
    /// the way up.
    #[derive(Default)]
    pub struct SkipDownRewriter;

    impl NodeRewriter for SkipDownRewriter {
        type NodeTy = Expression;

        fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
            Ok(Transformed {
                value: node,
                order: TraversalOrder::Skip,
                changed: false,
            })
        }

        fn visit_up(&mut self, _node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
            Ok(Transformed::yes(root()))
        }
    }

    // `accept` must reach literals at different depths of the tree.
    #[test]
    fn expr_deep_visitor_test() {
        let col1: Expression = col("col1");
        let lit1 = lit(1);
        let expr = eq(col1, lit1);
        let lit2 = lit(2);
        let expr = and(expr, lit2);
        let mut printer = ExprLitCollector::default();
        expr.accept(&mut printer).unwrap();
        assert_eq!(printer.0.len(), 2);
    }

    // `transform_up` must rewrite both columns into literals, giving three literals in total.
    #[test]
    fn expr_deep_mut_visitor_test() {
        let col1: Expression = col("col1");
        let col2: Expression = col("col2");
        let expr = eq(col1, col2);
        let lit2 = lit(2);
        let expr = and(expr, lit2);

        let mut idx = 0_i32;
        let new = expr
            .transform_up(|node| expr_col_to_lit_transform(node, &mut idx))
            .unwrap();
        assert!(new.changed);

        let expr = new.value;

        let mut printer = ExprLitCollector::default();
        expr.accept(&mut printer).unwrap();
        assert_eq!(printer.0.len(), 3);
    }

    // Returning `Skip` at the `Eq` node must hide its columns but still visit the sibling.
    #[test]
    fn expr_skip_test() {
        let col1: Expression = col("col1");
        let col2: Expression = col("col2");
        let expr1 = eq(col1, col2);
        let col3: Expression = col("col3");
        let col4: Expression = col("col4");
        let expr2 = not_eq(col3, col4);
        let expr = and(expr1, expr2);

        let mut nodes = Vec::new();
        pre_order_visit_down(&expr, |node: &Expression| {
            if node.is::<GetItem>() {
                nodes.push(node)
            }
            if let Some(bin) = node.as_opt::<Binary>()
                && bin.operator() == Operator::Eq
            {
                return Ok(TraversalOrder::Skip);
            }
            Ok(TraversalOrder::Continue)
        })
        .unwrap();

        let nodes: HashSet<Expression> = HashSet::from_iter(nodes.into_iter().cloned());
        assert_eq!(nodes, HashSet::from_iter([col("col3"), col("col4")]));
    }

    // Returning `Stop` at the `Eq` node must end the traversal before any column is seen.
    #[test]
    fn expr_stop_test() {
        let col1: Expression = col("col1");
        let col2: Expression = col("col2");
        let expr1 = eq(col1, col2);
        let col3: Expression = col("col3");
        let col4: Expression = col("col4");
        let expr2 = not_eq(col3, col4);
        let expr = and(expr1, expr2);

        let mut nodes = Vec::new();
        pre_order_visit_down(&expr, |node: &Expression| {
            if node.is::<GetItem>() {
                nodes.push(node)
            }
            if let Some(bin) = node.as_opt::<Binary>()
                && bin.operator() == Operator::Eq
            {
                return Ok(TraversalOrder::Stop);
            }
            Ok(TraversalOrder::Continue)
        })
        .unwrap();

        assert!(nodes.is_empty());
    }

    // A `Skip` from `visit_down` must not prevent `visit_up` from running on the same node.
    #[test]
    fn expr_skip_down_visit_up() {
        let col = col("col");

        let mut visitor = SkipDownRewriter;
        let result = col.rewrite(&mut visitor).unwrap();

        assert!(result.changed);
        assert!(is_root(&result.value));
    }
}