1mod fold;
9mod references;
10mod visitor;
11
12use std::marker::PhantomData;
13use std::sync::Arc;
14
15use itertools::Itertools;
16pub use references::ReferenceCollector;
17pub use visitor::{pre_order_visit_down, pre_order_visit_up};
18use vortex_error::VortexResult;
19
20use crate::ExprRef;
21use crate::traversal::fold::{
22 FoldDownContext, FoldUp, NodeFolder, NodeFolderContext, NodeFolderContextWrapper,
23};
24
/// Controls how a traversal proceeds after a node (or container element) has been visited.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TraversalOrder {
    /// Do not descend into the children of the current node.
    ///
    /// When returned from a down-visit, the traversal otherwise carries on: the node's own
    /// up-visit still runs and its siblings are still visited.
    Skip,
    /// Halt the traversal: no further visits are made and `Stop` is returned to the caller.
    Stop,
    /// Proceed with the traversal as normal.
    Continue,
}
37
38impl TraversalOrder {
39 pub fn visit_children<F: FnOnce() -> VortexResult<TraversalOrder>>(
41 self,
42 f: F,
43 ) -> VortexResult<TraversalOrder> {
44 match self {
45 Self::Skip => Ok(TraversalOrder::Continue),
46 Self::Stop => Ok(self),
47 Self::Continue => f(),
48 }
49 }
50
51 pub fn visit_parent<F: FnOnce() -> VortexResult<TraversalOrder>>(
53 self,
54 f: F,
55 ) -> VortexResult<TraversalOrder> {
56 match self {
57 Self::Continue => f(),
58 Self::Skip | Self::Stop => Ok(self),
59 }
60 }
61}
62
/// The outcome of one rewrite step: the resulting value plus traversal bookkeeping.
#[derive(Debug, Clone)]
pub struct Transformed<T> {
    /// The value produced by the step (the input itself when nothing changed).
    pub value: T,
    /// How the traversal should proceed after this step.
    pub order: TraversalOrder,
    /// Whether this step, or any step whose result was folded into it, reported a change.
    pub changed: bool,
}
72
73impl<T> Transformed<T> {
74 pub fn yes(value: T) -> Self {
75 Self {
76 value,
77 order: TraversalOrder::Continue,
78 changed: true,
79 }
80 }
81
82 pub fn no(value: T) -> Self {
83 Self {
84 value,
85 order: TraversalOrder::Continue,
86 changed: false,
87 }
88 }
89
90 pub fn into_inner(self) -> T {
91 self.value
92 }
93
94 pub fn map<O, F: FnOnce(T) -> O>(self, f: F) -> Transformed<O> {
96 Transformed {
97 value: f(self.value),
98 order: self.order,
99 changed: self.changed,
100 }
101 }
102}
103
/// Read-only visitor over a tree of [`Node`]s, driven by [`NodeExt::accept`].
///
/// Both hooks default to doing nothing and returning [`TraversalOrder::Continue`].
pub trait NodeVisitor<'a> {
    /// The node type this visitor walks.
    type NodeTy: Node;

    /// Called on a node before its children are visited.
    fn visit_down(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
        // Discard the argument explicitly so the default body has no unused-parameter lint.
        _ = node;
        Ok(TraversalOrder::Continue)
    }

    /// Called on a node after its children have been visited.
    fn visit_up(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
        // Discard the argument explicitly so the default body has no unused-parameter lint.
        _ = node;
        Ok(TraversalOrder::Continue)
    }
}
117
/// Owning rewriter over a tree of [`Node`]s, driven by [`NodeExt::rewrite`].
///
/// Both hooks default to returning the node unchanged with [`TraversalOrder::Continue`].
pub trait NodeRewriter: Sized {
    /// The node type this rewriter transforms.
    type NodeTy: Node;

    /// Called with a node before its children are rewritten.
    fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
        Ok(Transformed::no(node))
    }

    /// Called with a node after its children have been rewritten.
    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
        Ok(Transformed::no(node))
    }
}
129
/// A tree node that exposes its direct children for traversal.
pub trait Node: Sized + Clone {
    /// Calls `f` on each direct child by reference and returns the resulting
    /// [`TraversalOrder`].
    fn apply_children<'a, F: FnMut(&'a Self) -> VortexResult<TraversalOrder>>(
        &'a self,
        f: F,
    ) -> VortexResult<TraversalOrder>;

    /// Consumes the node, passes each direct child through `f`, and returns the
    /// (possibly rebuilt) node.
    fn map_children<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
        self,
        f: F,
    ) -> VortexResult<Transformed<Self>>;

    /// Calls `f` once with an iterator over the direct children and returns its result.
    fn iter_children<T>(&self, f: impl FnOnce(&mut dyn Iterator<Item = &Self>) -> T) -> T;

    /// Number of direct children of this node.
    fn children_count(&self) -> usize;
}
153
154pub trait NodeExt: Node {
155 fn rewrite<R: NodeRewriter<NodeTy = Self>>(
157 self,
158 rewriter: &mut R,
159 ) -> VortexResult<Transformed<Self>> {
160 let mut transformed = rewriter.visit_down(self)?;
161
162 let transformed = match transformed.order {
163 TraversalOrder::Stop => Ok(transformed),
164 TraversalOrder::Skip => {
165 transformed.order = TraversalOrder::Continue;
166 Ok(transformed)
167 }
168 TraversalOrder::Continue => transformed
169 .value
170 .map_children(|c| c.rewrite(rewriter))
171 .map(|mut t| {
172 t.changed |= transformed.changed;
173 t
174 }),
175 }?;
176
177 match transformed.order {
178 TraversalOrder::Stop | TraversalOrder::Skip => Ok(transformed),
179 TraversalOrder::Continue => {
180 let mut up_rewrite = rewriter.visit_up(transformed.value)?;
181 up_rewrite.changed |= transformed.changed;
182 Ok(up_rewrite)
183 }
184 }
185 }
186
187 fn accept<'a, V: NodeVisitor<'a, NodeTy = Self>>(
189 &'a self,
190 visitor: &mut V,
191 ) -> VortexResult<TraversalOrder> {
192 visitor
193 .visit_down(self)?
194 .visit_children(|| self.apply_children(|c| c.accept(visitor)))?
195 .visit_parent(|| visitor.visit_up(self))
196 }
197
198 fn transform_down<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
200 self,
201 f: F,
202 ) -> VortexResult<Transformed<Self>> {
203 let mut rewriter = FnRewriter {
204 f_down: Some(f),
205 f_up: None,
206 _data: PhantomData,
207 };
208
209 self.rewrite(&mut rewriter)
210 }
211
212 fn transform_up<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
214 self,
215 f: F,
216 ) -> VortexResult<Transformed<Self>> {
217 let mut rewriter = FnRewriter {
218 f_down: None,
219 f_up: Some(f),
220 _data: PhantomData,
221 };
222
223 self.rewrite(&mut rewriter)
224 }
225
226 fn fold_context<R, F: NodeFolderContext<NodeTy = Self, Result = R>>(
228 self,
229 ctx: &F::Context,
230 folder: &mut F,
231 ) -> VortexResult<FoldUp<R>> {
232 let transformed = folder.visit_down(ctx, &self)?;
233 let ctx = match transformed {
234 FoldDownContext::Continue(ctx) => ctx,
235 FoldDownContext::Skip(r) => return Ok(FoldUp::Continue(r)),
236 FoldDownContext::Stop(r) => return Ok(FoldUp::Stop(r)),
237 };
238
239 let mut children = Vec::with_capacity(self.children_count());
240 let mut stop_result = None;
241 self.iter_children(|children_iter| -> VortexResult<()> {
242 for c in children_iter {
243 let t = c.clone().fold_context(&ctx, folder)?;
244 match t {
245 FoldUp::Stop(r) => {
246 stop_result = Some(r);
247 return Ok(());
248 }
249 FoldUp::Continue(r) => {
250 children.push(r);
251 }
252 }
253 }
254 Ok(())
255 })?;
256
257 if let Some(result) = stop_result {
258 return Ok(FoldUp::Stop(result));
259 }
260
261 folder.visit_up(self, &ctx, children)
262 }
263
264 fn fold<R, F: NodeFolder<NodeTy = Self, Result = R>>(
266 self,
267 folder: &mut F,
268 ) -> VortexResult<FoldUp<R>> {
269 let mut folder = NodeFolderContextWrapper { inner: folder };
270 self.fold_context(&(), &mut folder)
271 }
272}
273
274impl<T: Node> NodeExt for T {}
275
/// Adapts optional closures into a [`NodeRewriter`]; backs [`NodeExt::transform_down`] and
/// [`NodeExt::transform_up`].
struct FnRewriter<F, T> {
    /// Closure applied on the way down, if any.
    f_down: Option<F>,
    /// Closure applied on the way up, if any.
    f_up: Option<F>,
    /// Ties the node type `T` to the struct without storing a value of it.
    _data: PhantomData<T>,
}
281
282impl<F, T> NodeRewriter for FnRewriter<F, T>
283where
284 T: Node,
285 F: FnMut(T) -> VortexResult<Transformed<T>>,
286{
287 type NodeTy = T;
288
289 fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
290 if let Some(f) = self.f_down.as_mut() {
291 f(node)
292 } else {
293 Ok(Transformed::no(node))
294 }
295 }
296
297 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
298 if let Some(f) = self.f_up.as_mut() {
299 f(node)
300 } else {
301 Ok(Transformed::no(node))
302 }
303 }
304}
305
/// A container of `T` elements (for example a single node, a pair, a `Vec`, a `Box` or an
/// `Arc`) that can be visited or mapped element by element.
pub trait NodeContainer<'a, T: 'a>: Sized {
    /// Calls `f` on each element by reference and returns the resulting [`TraversalOrder`].
    fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
        &'a self,
        f: F,
    ) -> VortexResult<TraversalOrder>;

    /// Consumes the container, passes each element through `f`, and rebuilds the container.
    fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
        self,
        f: F,
    ) -> VortexResult<Transformed<Self>>;
}
323
/// Read-only counterpart of [`NodeContainer`] for collections that hold *references* to
/// containers (e.g. `Vec<&C>`) rather than the containers themselves.
pub trait NodeRefContainer<'a, T: 'a>: Sized {
    /// Calls `f` on every element of every referenced container and returns the resulting
    /// [`TraversalOrder`].
    fn apply_ref_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
        &self,
        f: F,
    ) -> VortexResult<TraversalOrder>;
}
330
331impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeRefContainer<'a, T> for Vec<&'a C> {
332 fn apply_ref_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
333 &self,
334 mut f: F,
335 ) -> VortexResult<TraversalOrder> {
336 let mut order = TraversalOrder::Continue;
337
338 for c in self {
339 order = c.apply_elements(&mut f)?;
340 match order {
341 TraversalOrder::Continue | TraversalOrder::Skip => {}
342 TraversalOrder::Stop => return Ok(TraversalOrder::Stop),
343 }
344 }
345
346 Ok(order)
347 }
348}
349
350impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeContainer<'a, T> for Box<C> {
351 fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
352 &'a self,
353 f: F,
354 ) -> VortexResult<TraversalOrder> {
355 self.as_ref().apply_elements(f)
356 }
357
358 fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
359 self,
360 f: F,
361 ) -> VortexResult<Transformed<Box<C>>> {
362 Ok((*self).map_elements(f)?.map(Box::new))
363 }
364}
365
366impl<'a, T, C> NodeContainer<'a, T> for Arc<C>
367where
368 T: 'a,
369 C: NodeContainer<'a, T> + Clone,
370{
371 fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
372 &'a self,
373 f: F,
374 ) -> VortexResult<TraversalOrder> {
375 self.as_ref().apply_elements(f)
376 }
377
378 fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
379 self,
380 f: F,
381 ) -> VortexResult<Transformed<Arc<C>>> {
382 Ok(Arc::unwrap_or_clone(self).map_elements(f)?.map(Arc::new))
383 }
384}
385
386impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeContainer<'a, T> for [C; 2] {
387 fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
388 &'a self,
389 mut f: F,
390 ) -> VortexResult<TraversalOrder> {
391 let [lhs, rhs] = self;
392 match lhs.apply_elements(&mut f)? {
393 TraversalOrder::Skip | TraversalOrder::Continue => rhs.apply_elements(&mut f),
394 TraversalOrder::Stop => Ok(TraversalOrder::Stop),
395 }
396 }
397
398 fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
399 self,
400 mut f: F,
401 ) -> VortexResult<Transformed<[C; 2]>> {
402 let [lhs, rhs] = self;
403 let transformed = lhs.map_elements(&mut f)?;
404 match transformed.order {
405 TraversalOrder::Skip | TraversalOrder::Continue => {
406 let mut t = rhs.map_elements(&mut f)?;
407 t.changed |= transformed.changed;
408 Ok(t.map(|new_lhs| [new_lhs, transformed.value]))
409 }
410 TraversalOrder::Stop => Ok(transformed.map(|new_lhs| [new_lhs, rhs])),
411 }
412 }
413}
414
415impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeContainer<'a, T> for Vec<C> {
416 fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
417 &'a self,
418 mut f: F,
419 ) -> VortexResult<TraversalOrder> {
420 let mut order = TraversalOrder::Continue;
421
422 for c in self {
423 order = c.apply_elements(&mut f)?;
424 match order {
425 TraversalOrder::Continue | TraversalOrder::Skip => {}
426 TraversalOrder::Stop => return Ok(TraversalOrder::Stop),
427 }
428 }
429
430 Ok(order)
431 }
432
433 fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
434 self,
435 mut f: F,
436 ) -> VortexResult<Transformed<Self>> {
437 let mut order = TraversalOrder::Continue;
438 let mut changed = false;
439
440 let value = self
441 .into_iter()
442 .map(|c| match order {
443 TraversalOrder::Continue | TraversalOrder::Skip => {
444 c.map_elements(&mut f).map(|result| {
445 order = result.order;
446 changed |= result.changed;
447 result.value
448 })
449 }
450 TraversalOrder::Stop => Ok(c),
451 })
452 .collect::<VortexResult<Vec<_>>>()?;
453
454 Ok(Transformed {
455 value,
456 order,
457 changed,
458 })
459 }
460}
461
/// A single expression is a container with exactly one element: the expression itself.
impl<'a> NodeContainer<'a, Self> for ExprRef {
    /// Calls `f` on this expression only; its children are not visited here.
    fn apply_elements<F: FnMut(&'a Self) -> VortexResult<TraversalOrder>>(
        &'a self,
        mut f: F,
    ) -> VortexResult<TraversalOrder> {
        f(self)
    }

    /// Passes this expression to `f` and returns `f`'s result unchanged.
    fn map_elements<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
        self,
        mut f: F,
    ) -> VortexResult<Transformed<Self>> {
        f(self)
    }
}
477
478impl Node for ExprRef {
479 fn apply_children<'a, F: FnMut(&'a Self) -> VortexResult<TraversalOrder>>(
480 &'a self,
481 mut f: F,
482 ) -> VortexResult<TraversalOrder> {
483 self.children().apply_ref_elements(&mut f)
484 }
485
486 fn map_children<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
487 self,
488 f: F,
489 ) -> VortexResult<Transformed<Self>> {
490 let transformed = self
491 .children()
492 .into_iter()
493 .cloned()
494 .collect_vec()
495 .map_elements(f)?;
496
497 if transformed.changed {
498 Ok(Transformed {
499 value: self.with_children(transformed.value)?,
500 order: transformed.order,
501 changed: true,
502 })
503 } else {
504 Ok(Transformed::no(self))
505 }
506 }
507
508 fn iter_children<T>(&self, f: impl FnOnce(&mut dyn Iterator<Item = &Self>) -> T) -> T {
509 f(&mut self.children().into_iter())
510 }
511
512 fn children_count(&self) -> usize {
513 self.children().len()
514 }
515}
516
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_error::VortexResult;
    use vortex_utils::aliases::hash_set::HashSet;

    use crate::traversal::visitor::pre_order_visit_down;
    use crate::traversal::{NodeExt, NodeRewriter, NodeVisitor, Transformed, TraversalOrder};
    use crate::{
        BinaryExpr, BinaryVTable, ExprRef, GetItemVTable, IntoExpr, LiteralExpr, LiteralVTable,
        Operator, VortexExpr, col, is_root, root,
    };

    /// Visitor that records every literal expression it meets on the way down.
    #[derive(Default)]
    pub struct ExprLitCollector<'a>(pub Vec<&'a ExprRef>);

    impl<'a> NodeVisitor<'a> for ExprLitCollector<'a> {
        type NodeTy = ExprRef;

        fn visit_down(&mut self, node: &'a ExprRef) -> VortexResult<TraversalOrder> {
            if node.is::<LiteralVTable>() {
                self.0.push(node);
            }
            Ok(TraversalOrder::Continue)
        }

        fn visit_up(&mut self, _node: &'a ExprRef) -> VortexResult<TraversalOrder> {
            Ok(TraversalOrder::Continue)
        }
    }

    /// Replaces every `GetItem` node with an integer literal, numbering them from `*idx`.
    fn expr_col_to_lit_transform(
        node: ExprRef,
        idx: &mut i32,
    ) -> VortexResult<Transformed<ExprRef>> {
        if !node.is::<GetItemVTable>() {
            return Ok(Transformed::no(node));
        }
        let lit_id = *idx;
        *idx += 1;
        Ok(Transformed::yes(LiteralExpr::new_expr(lit_id)))
    }

    /// Rewriter that skips every subtree on the way down and replaces the node with
    /// `root()` on the way up.
    #[derive(Default)]
    pub struct SkipDownRewriter;

    impl NodeRewriter for SkipDownRewriter {
        type NodeTy = ExprRef;

        fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
            Ok(Transformed {
                order: TraversalOrder::Skip,
                ..Transformed::no(node)
            })
        }

        fn visit_up(&mut self, _node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
            Ok(Transformed::yes(root()))
        }
    }

    /// Builds `(col1 = col2) AND (col3 != col4)`.
    fn eq_and_not_eq() -> ExprRef {
        let eq = BinaryExpr::new_expr(col("col1"), Operator::Eq, col("col2"));
        let not_eq = BinaryExpr::new_expr(col("col3"), Operator::NotEq, col("col4"));
        BinaryExpr::new_expr(eq, Operator::And, not_eq)
    }

    /// Collects the `GetItem` nodes of `expr` in pre-order, answering `on_eq` at every
    /// `Eq` binary node and `Continue` everywhere else.
    fn get_items_with_eq_order(expr: &ExprRef, on_eq: TraversalOrder) -> HashSet<ExprRef> {
        let mut found = Vec::new();
        pre_order_visit_down(expr, |node: &ExprRef| {
            if node.is::<GetItemVTable>() {
                found.push(node);
            }
            let is_eq = node
                .as_opt::<BinaryVTable>()
                .is_some_and(|bin| bin.op() == Operator::Eq);
            Ok(if is_eq {
                on_eq.clone()
            } else {
                TraversalOrder::Continue
            })
        })
        .unwrap();
        found.into_iter().cloned().collect()
    }

    #[test]
    fn expr_deep_visitor_test() {
        let col1: Arc<dyn VortexExpr> = col("col1");
        let lit1 = LiteralExpr::new(1).into_expr();
        let eq = BinaryExpr::new(col1, Operator::Eq, lit1).into_expr();
        let lit2 = LiteralExpr::new(2).into_expr();
        let expr = BinaryExpr::new(eq, Operator::And, lit2).into_expr();

        let mut collector = ExprLitCollector::default();
        expr.accept(&mut collector).unwrap();
        assert_eq!(collector.0.len(), 2);
    }

    #[test]
    fn expr_deep_mut_visitor_test() {
        let eq = BinaryExpr::new_expr(col("col1"), Operator::Eq, col("col2"));
        let expr = BinaryExpr::new_expr(eq, Operator::And, LiteralExpr::new_expr(2));

        let mut next_id = 0_i32;
        let rewritten = expr
            .transform_up(|node| expr_col_to_lit_transform(node, &mut next_id))
            .unwrap();
        assert!(rewritten.changed);

        let mut collector = ExprLitCollector::default();
        rewritten.value.accept(&mut collector).unwrap();
        assert_eq!(collector.0.len(), 3);
    }

    #[test]
    fn expr_skip_test() {
        let visited = get_items_with_eq_order(&eq_and_not_eq(), TraversalOrder::Skip);
        assert_eq!(visited, HashSet::from_iter([col("col3"), col("col4")]));
    }

    #[test]
    fn expr_stop_test() {
        let visited = get_items_with_eq_order(&eq_and_not_eq(), TraversalOrder::Stop);
        assert!(visited.is_empty());
    }

    #[test]
    fn expr_skip_down_visit_up() {
        let result = col("col").rewrite(&mut SkipDownRewriter).unwrap();

        assert!(result.changed);
        assert!(is_root(&result.value));
    }
}