1mod references;
9mod visitor;
10
11use std::marker::PhantomData;
12use std::sync::Arc;
13
14use itertools::Itertools;
15pub use references::ReferenceCollector;
16pub use visitor::{pre_order_visit_down, pre_order_visit_up};
17use vortex_error::VortexResult;
18
19use crate::ExprRef;
20
/// Tells a traversal how to proceed after a node has been visited.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TraversalOrder {
    /// When returned on the way down: do not descend into this node's children, but still
    /// perform this node's "up" visit and carry on with its siblings.
    Skip,

    /// Halt the traversal. Containers stop visiting their remaining elements as soon as an
    /// element reports this value.
    Stop,

    /// Proceed normally.
    Continue,
}
33
34impl TraversalOrder {
35 pub fn visit_children<F: FnOnce() -> VortexResult<TraversalOrder>>(
37 self,
38 f: F,
39 ) -> VortexResult<TraversalOrder> {
40 match self {
41 Self::Skip => Ok(TraversalOrder::Continue),
42 Self::Stop => Ok(self),
43 Self::Continue => f(),
44 }
45 }
46
47 pub fn visit_parent<F: FnOnce() -> VortexResult<TraversalOrder>>(
49 self,
50 f: F,
51 ) -> VortexResult<TraversalOrder> {
52 match self {
53 Self::Continue => f(),
54 Self::Skip | Self::Stop => Ok(self),
55 }
56 }
57}
58
59#[derive(Debug, Clone)]
60pub struct Transformed<T> {
61 pub value: T,
63 pub order: TraversalOrder,
65 pub changed: bool,
67}
68
69impl<T> Transformed<T> {
70 pub fn yes(value: T) -> Self {
71 Self {
72 value,
73 order: TraversalOrder::Continue,
74 changed: true,
75 }
76 }
77
78 pub fn no(value: T) -> Self {
79 Self {
80 value,
81 order: TraversalOrder::Continue,
82 changed: false,
83 }
84 }
85
86 pub fn into_inner(self) -> T {
87 self.value
88 }
89
90 pub fn map<O, F: FnOnce(T) -> O>(self, f: F) -> Transformed<O> {
92 Transformed {
93 value: f(self.value),
94 order: self.order,
95 changed: self.changed,
96 }
97 }
98}
99
100pub trait NodeVisitor<'a> {
101 type NodeTy: Node;
102
103 fn visit_down(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
104 _ = node;
105 Ok(TraversalOrder::Continue)
106 }
107
108 fn visit_up(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
109 _ = node;
110 Ok(TraversalOrder::Continue)
111 }
112}
113
114pub trait NodeRewriter: Sized {
115 type NodeTy: Node;
116
117 fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
118 Ok(Transformed::no(node))
119 }
120
121 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
122 Ok(Transformed::no(node))
123 }
124}
125
126pub trait Node: Sized + Clone {
127 fn apply_children<'a, F: FnMut(&'a Self) -> VortexResult<TraversalOrder>>(
131 &'a self,
132 f: F,
133 ) -> VortexResult<TraversalOrder>;
134
135 fn map_children<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
139 self,
140 f: F,
141 ) -> VortexResult<Transformed<Self>>;
142
143 fn rewrite<R: NodeRewriter<NodeTy = Self>>(
145 self,
146 rewriter: &mut R,
147 ) -> VortexResult<Transformed<Self>> {
148 let mut transformed = rewriter.visit_down(self)?;
149
150 let transformed = match transformed.order {
151 TraversalOrder::Stop => Ok(transformed),
152 TraversalOrder::Skip => {
153 transformed.order = TraversalOrder::Continue;
154 Ok(transformed)
155 }
156 TraversalOrder::Continue => transformed
157 .value
158 .map_children(|c| c.rewrite(rewriter))
159 .map(|mut t| {
160 t.changed |= transformed.changed;
161 t
162 }),
163 }?;
164
165 match transformed.order {
166 TraversalOrder::Stop | TraversalOrder::Skip => Ok(transformed),
167 TraversalOrder::Continue => {
168 let mut up_rewrite = rewriter.visit_up(transformed.value)?;
169 up_rewrite.changed |= transformed.changed;
170 Ok(up_rewrite)
171 }
172 }
173 }
174
175 fn accept<'a, V: NodeVisitor<'a, NodeTy = Self>>(
177 &'a self,
178 visitor: &mut V,
179 ) -> VortexResult<TraversalOrder> {
180 visitor
181 .visit_down(self)?
182 .visit_children(|| self.apply_children(|c| c.accept(visitor)))?
183 .visit_parent(|| visitor.visit_up(self))
184 }
185
186 fn transform_down<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
188 self,
189 f: F,
190 ) -> VortexResult<Transformed<Self>> {
191 let mut rewriter = FnRewriter {
192 f_down: Some(f),
193 f_up: None,
194 _data: PhantomData,
195 };
196
197 self.rewrite(&mut rewriter)
198 }
199
200 fn transform_up<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
202 self,
203 f: F,
204 ) -> VortexResult<Transformed<Self>> {
205 let mut rewriter = FnRewriter {
206 f_down: None,
207 f_up: Some(f),
208 _data: PhantomData,
209 };
210
211 self.rewrite(&mut rewriter)
212 }
213}
214
/// Adapts plain closures into a `NodeRewriter`.
///
/// `Node::transform_down` fills only `f_down` and `Node::transform_up` fills only `f_up`.
struct FnRewriter<F, T> {
    /// Applied by `visit_down` when present.
    f_down: Option<F>,
    /// Applied by `visit_up` when present.
    f_up: Option<F>,
    /// Ties the node type `T` to the rewriter without storing a `T`.
    _data: PhantomData<T>,
}
220
221impl<F, T> NodeRewriter for FnRewriter<F, T>
222where
223 T: Node,
224 F: FnMut(T) -> VortexResult<Transformed<T>>,
225{
226 type NodeTy = T;
227
228 fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
229 if let Some(f) = self.f_down.as_mut() {
230 f(node)
231 } else {
232 Ok(Transformed::no(node))
233 }
234 }
235
236 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
237 if let Some(f) = self.f_up.as_mut() {
238 f(node)
239 } else {
240 Ok(Transformed::no(node))
241 }
242 }
243}
244
245pub trait NodeContainer<'a, T: 'a>: Sized {
250 fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
252 &'a self,
253 f: F,
254 ) -> VortexResult<TraversalOrder>;
255
256 fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
258 self,
259 f: F,
260 ) -> VortexResult<Transformed<Self>>;
261}
262
263pub trait NodeRefContainer<'a, T: 'a>: Sized {
264 fn apply_ref_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
265 &self,
266 f: F,
267 ) -> VortexResult<TraversalOrder>;
268}
269
270impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeRefContainer<'a, T> for Vec<&'a C> {
271 fn apply_ref_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
272 &self,
273 mut f: F,
274 ) -> VortexResult<TraversalOrder> {
275 let mut order = TraversalOrder::Continue;
276
277 for c in self {
278 order = c.apply_elements(&mut f)?;
279 match order {
280 TraversalOrder::Continue | TraversalOrder::Skip => {}
281 TraversalOrder::Stop => return Ok(TraversalOrder::Stop),
282 }
283 }
284
285 Ok(order)
286 }
287}
288
289impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeContainer<'a, T> for Box<C> {
290 fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
291 &'a self,
292 f: F,
293 ) -> VortexResult<TraversalOrder> {
294 self.as_ref().apply_elements(f)
295 }
296
297 fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
298 self,
299 f: F,
300 ) -> VortexResult<Transformed<Box<C>>> {
301 Ok((*self).map_elements(f)?.map(Box::new))
302 }
303}
304
305impl<'a, T, C> NodeContainer<'a, T> for Arc<C>
306where
307 T: 'a,
308 C: NodeContainer<'a, T> + Clone,
309{
310 fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
311 &'a self,
312 f: F,
313 ) -> VortexResult<TraversalOrder> {
314 self.as_ref().apply_elements(f)
315 }
316
317 fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
318 self,
319 f: F,
320 ) -> VortexResult<Transformed<Arc<C>>> {
321 Ok(Arc::unwrap_or_clone(self).map_elements(f)?.map(Arc::new))
322 }
323}
324
325impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeContainer<'a, T> for [C; 2] {
326 fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
327 &'a self,
328 mut f: F,
329 ) -> VortexResult<TraversalOrder> {
330 let [lhs, rhs] = self;
331 match lhs.apply_elements(&mut f)? {
332 TraversalOrder::Skip | TraversalOrder::Continue => rhs.apply_elements(&mut f),
333 TraversalOrder::Stop => Ok(TraversalOrder::Stop),
334 }
335 }
336
337 fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
338 self,
339 mut f: F,
340 ) -> VortexResult<Transformed<[C; 2]>> {
341 let [lhs, rhs] = self;
342 let transformed = lhs.map_elements(&mut f)?;
343 match transformed.order {
344 TraversalOrder::Skip | TraversalOrder::Continue => {
345 let mut t = rhs.map_elements(&mut f)?;
346 t.changed |= transformed.changed;
347 Ok(t.map(|new_lhs| [new_lhs, transformed.value]))
348 }
349 TraversalOrder::Stop => Ok(transformed.map(|new_lhs| [new_lhs, rhs])),
350 }
351 }
352}
353
354impl<'a, T: 'a, C: NodeContainer<'a, T>> NodeContainer<'a, T> for Vec<C> {
355 fn apply_elements<F: FnMut(&'a T) -> VortexResult<TraversalOrder>>(
356 &'a self,
357 mut f: F,
358 ) -> VortexResult<TraversalOrder> {
359 let mut order = TraversalOrder::Continue;
360
361 for c in self {
362 order = c.apply_elements(&mut f)?;
363 match order {
364 TraversalOrder::Continue | TraversalOrder::Skip => {}
365 TraversalOrder::Stop => return Ok(TraversalOrder::Stop),
366 }
367 }
368
369 Ok(order)
370 }
371
372 fn map_elements<F: FnMut(T) -> VortexResult<Transformed<T>>>(
373 self,
374 mut f: F,
375 ) -> VortexResult<Transformed<Self>> {
376 let mut order = TraversalOrder::Continue;
377 let mut changed = false;
378
379 let value = self
380 .into_iter()
381 .map(|c| match order {
382 TraversalOrder::Continue | TraversalOrder::Skip => {
383 c.map_elements(&mut f).map(|result| {
384 order = result.order;
385 changed |= result.changed;
386 result.value
387 })
388 }
389 TraversalOrder::Stop => Ok(c),
390 })
391 .collect::<VortexResult<Vec<_>>>()?;
392
393 Ok(Transformed {
394 value,
395 order,
396 changed,
397 })
398 }
399}
400
401impl<'a> NodeContainer<'a, Self> for ExprRef {
402 fn apply_elements<F: FnMut(&'a Self) -> VortexResult<TraversalOrder>>(
403 &'a self,
404 mut f: F,
405 ) -> VortexResult<TraversalOrder> {
406 f(self)
407 }
408
409 fn map_elements<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
410 self,
411 mut f: F,
412 ) -> VortexResult<Transformed<Self>> {
413 f(self)
414 }
415}
416
417impl Node for ExprRef {
418 fn apply_children<'a, F: FnMut(&'a Self) -> VortexResult<TraversalOrder>>(
419 &'a self,
420 mut f: F,
421 ) -> VortexResult<TraversalOrder> {
422 self.children().apply_ref_elements(&mut f)
423 }
424
425 fn map_children<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
426 self,
427 f: F,
428 ) -> VortexResult<Transformed<Self>> {
429 let transformed = self
430 .children()
431 .into_iter()
432 .cloned()
433 .collect_vec()
434 .map_elements(f)?;
435
436 if transformed.changed {
437 Ok(Transformed {
438 value: self.with_children(transformed.value)?,
439 order: transformed.order,
440 changed: true,
441 })
442 } else {
443 Ok(Transformed::no(self))
444 }
445 }
446}
447
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_error::VortexResult;
    use vortex_utils::aliases::hash_set::HashSet;

    use crate::traversal::visitor::pre_order_visit_down;
    use crate::traversal::{Node, NodeRewriter, NodeVisitor, Transformed, TraversalOrder};
    use crate::{
        BinaryExpr, BinaryVTable, ExprRef, GetItemVTable, IntoExpr, LiteralExpr, LiteralVTable,
        Operator, VortexExpr, col, is_root, root,
    };

    /// Visitor that records every literal expression it meets on the way down.
    #[derive(Default)]
    pub struct ExprLitCollector<'a>(pub Vec<&'a ExprRef>);

    impl<'a> NodeVisitor<'a> for ExprLitCollector<'a> {
        type NodeTy = ExprRef;

        fn visit_down(&mut self, node: &'a ExprRef) -> VortexResult<TraversalOrder> {
            if node.is::<LiteralVTable>() {
                self.0.push(node)
            }
            Ok(TraversalOrder::Continue)
        }

        fn visit_up(&mut self, _node: &'a ExprRef) -> VortexResult<TraversalOrder> {
            Ok(TraversalOrder::Continue)
        }
    }

    /// Replaces every `GetItem` expression with a literal holding the current value of `*idx`,
    /// then increments the counter. All other nodes are left untouched.
    fn expr_col_to_lit_transform(
        node: ExprRef,
        idx: &mut i32,
    ) -> VortexResult<Transformed<ExprRef>> {
        if node.is::<GetItemVTable>() {
            let lit_id = *idx;
            *idx += 1;
            Ok(Transformed::yes(LiteralExpr::new_expr(lit_id)))
        } else {
            Ok(Transformed::no(node))
        }
    }

    /// Rewriter that returns `Skip` on the way down and replaces the node with `root()` on the
    /// way up. Used to check that `Skip` still triggers the up-visit.
    #[derive(Default)]
    pub struct SkipDownRewriter;

    impl NodeRewriter for SkipDownRewriter {
        type NodeTy = ExprRef;

        fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
            Ok(Transformed {
                value: node,
                order: TraversalOrder::Skip,
                changed: false,
            })
        }

        fn visit_up(&mut self, _node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
            Ok(Transformed::yes(root()))
        }
    }

    /// `(col1 == 1) AND 2` contains exactly two literals.
    #[test]
    fn expr_deep_visitor_test() {
        let col1: Arc<dyn VortexExpr> = col("col1");
        let lit1 = LiteralExpr::new(1).into_expr();
        let expr = BinaryExpr::new(col1.clone(), Operator::Eq, lit1.clone()).into_expr();
        let lit2 = LiteralExpr::new(2).into_expr();
        let expr = BinaryExpr::new(expr, Operator::And, lit2).into_expr();
        let mut printer = ExprLitCollector::default();
        expr.accept(&mut printer).unwrap();
        assert_eq!(printer.0.len(), 2);
    }

    /// Turning both columns of `(col1 == col2) AND 2` into literals bottom-up leaves three
    /// literals and reports a change.
    #[test]
    fn expr_deep_mut_visitor_test() {
        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        let expr = BinaryExpr::new_expr(col1.clone(), Operator::Eq, col2.clone());
        let lit2 = LiteralExpr::new_expr(2);
        let expr = BinaryExpr::new_expr(expr, Operator::And, lit2);

        let mut idx = 0_i32;
        let new = expr
            .transform_up(|node| expr_col_to_lit_transform(node, &mut idx))
            .unwrap();
        assert!(new.changed);

        let expr = new.value;

        let mut printer = ExprLitCollector::default();
        expr.accept(&mut printer).unwrap();
        assert_eq!(printer.0.len(), 3);
    }

    /// Returning `Skip` at the `Eq` node hides its columns. Only the columns under the `NotEq`
    /// node are collected.
    #[test]
    fn expr_skip_test() {
        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        let expr1 = BinaryExpr::new_expr(col1.clone(), Operator::Eq, col2.clone());
        let col3: Arc<dyn VortexExpr> = col("col3");
        let col4: Arc<dyn VortexExpr> = col("col4");
        let expr2 = BinaryExpr::new_expr(col3.clone(), Operator::NotEq, col4.clone());
        let expr = BinaryExpr::new_expr(expr1, Operator::And, expr2);

        let mut nodes = Vec::new();
        pre_order_visit_down(&expr, |node: &ExprRef| {
            if node.is::<GetItemVTable>() {
                nodes.push(node)
            }
            if let Some(bin) = node.as_opt::<BinaryVTable>() {
                if bin.op() == Operator::Eq {
                    return Ok(TraversalOrder::Skip);
                }
            }
            Ok(TraversalOrder::Continue)
        })
        .unwrap();

        let nodes: HashSet<ExprRef> = HashSet::from_iter(nodes.into_iter().cloned());
        assert_eq!(nodes, HashSet::from_iter([col("col3"), col("col4")]));
    }

    /// Pre-order traversal reaches the `Eq` node before any column. Returning `Stop` there
    /// halts everything, so no column is collected.
    #[test]
    fn expr_stop_test() {
        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        let expr1 = BinaryExpr::new_expr(col1.clone(), Operator::Eq, col2.clone());
        let col3: Arc<dyn VortexExpr> = col("col3");
        let col4: Arc<dyn VortexExpr> = col("col4");
        let expr2 = BinaryExpr::new_expr(col3.clone(), Operator::NotEq, col4.clone());
        let expr = BinaryExpr::new_expr(expr1, Operator::And, expr2);

        let mut nodes = Vec::new();
        pre_order_visit_down(&expr, |node: &ExprRef| {
            if node.is::<GetItemVTable>() {
                nodes.push(node)
            }
            if let Some(bin) = node.as_opt::<BinaryVTable>() {
                if bin.op() == Operator::Eq {
                    return Ok(TraversalOrder::Stop);
                }
            }
            Ok(TraversalOrder::Continue)
        })
        .unwrap();

        assert!(nodes.is_empty());
    }

    /// `Skip` on the way down must still run `visit_up`, whose replacement (`root()`) becomes
    /// the result.
    #[test]
    fn expr_skip_down_visit_up() {
        let col = col("col");

        let mut visitor = SkipDownRewriter;
        let result = col.rewrite(&mut visitor).unwrap();

        assert!(result.changed);
        assert!(is_root(&result.value));
    }
}