1mod references;
2mod visitor;
3
4pub use references::ReferenceCollector;
5pub use visitor::{pre_order_visit_down, pre_order_visit_up};
6use vortex_error::VortexResult;
7
8use crate::ExprRef;
9
/// Tells a traversal how to proceed after a visitor callback returns.
///
/// This is a plain fieldless tag, so it is `Copy`: callers can compare and pass it around
/// without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TraversalOrder {
    /// When returned from a top-down (`visit_down`) callback, do not descend into the
    /// current node's children. Whether the node's own `visit_up` still runs depends on the
    /// traversal: `Node::accept` skips it, `Node::transform` still calls it.
    Skip,

    /// Stop visiting any further nodes in the traversal.
    Stop,

    /// Continue the traversal as normal.
    Continue,
}
31
32#[derive(Debug, Clone)]
33pub struct TransformResult<T> {
34 pub result: T,
35 order: TraversalOrder,
36 changed: bool,
37}
38
39impl<T> TransformResult<T> {
40 pub fn yes(result: T) -> Self {
41 Self {
42 result,
43 order: TraversalOrder::Continue,
44 changed: true,
45 }
46 }
47
48 pub fn no(result: T) -> Self {
49 Self {
50 result,
51 order: TraversalOrder::Continue,
52 changed: false,
53 }
54 }
55}
56
57pub trait NodeVisitor<'a> {
58 type NodeTy: Node;
59
60 fn visit_down(&mut self, _node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
61 Ok(TraversalOrder::Continue)
62 }
63
64 fn visit_up(&mut self, _node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
65 Ok(TraversalOrder::Continue)
66 }
67}
68
69pub trait MutNodeVisitor {
70 type NodeTy: Node;
71
72 fn visit_down(&mut self, _node: &Self::NodeTy) -> VortexResult<TraversalOrder> {
73 Ok(TraversalOrder::Continue)
74 }
75
76 fn visit_up(&mut self, _node: Self::NodeTy) -> VortexResult<TransformResult<Self::NodeTy>>;
77}
78
/// Decision returned by the top-down (`visit_down`) hook of a [`Folder`] / [`FolderMut`].
///
/// Derives `Debug` (when both parameters do) for parity with [`FoldUp`].
#[derive(Debug)]
pub enum FoldDown<Out, Context> {
    /// Abort the whole fold; `Out` becomes the final result reported to every ancestor.
    Abort(Out),
    /// Do not visit this node's children (nor call its `visit_up`); `Out` is used as this
    /// node's result.
    SkipChildren(Out),
    /// Visit the children, handing each of them (a clone of) this context.
    Continue(Context),
}
87
/// Value produced for a node once a fold has finished with it.
#[derive(Debug)]
pub enum FoldUp<Out> {
    /// The fold was aborted; this value is propagated straight to the root.
    Abort(Out),
    /// The fold proceeds normally with this value.
    Continue(Out),
}

impl<Out> FoldUp<Out> {
    /// Unwraps the carried value, regardless of whether the fold was aborted.
    pub fn result(self) -> Out {
        match self {
            FoldUp::Abort(value) | FoldUp::Continue(value) => value,
        }
    }
}
104
105pub trait Folder<'a> {
106 type NodeTy: Node;
107 type Out;
108 type Context: Clone;
109
110 fn visit_down(
111 &mut self,
112 _node: &'a Self::NodeTy,
113 context: Self::Context,
114 ) -> VortexResult<FoldDown<Self::Out, Self::Context>> {
115 Ok(FoldDown::Continue(context))
116 }
117
118 fn visit_up(
119 &mut self,
120 node: &'a Self::NodeTy,
121 context: Self::Context,
122 children: Vec<Self::Out>,
123 ) -> VortexResult<FoldUp<Self::Out>>;
124}
125
126pub trait FolderMut {
127 type NodeTy: Node;
128 type Out;
129 type Context: Clone;
130
131 fn visit_down(
132 &mut self,
133 _node: &Self::NodeTy,
134 context: Self::Context,
135 ) -> VortexResult<FoldDown<Self::Out, Self::Context>> {
136 Ok(FoldDown::Continue(context))
137 }
138
139 fn visit_up(
140 &mut self,
141 node: Self::NodeTy,
142 context: Self::Context,
143 children: Vec<Self::Out>,
144 ) -> VortexResult<FoldUp<Self::Out>>;
145}
146
147pub trait Node: Sized {
148 fn accept<'a, V: NodeVisitor<'a, NodeTy = Self>>(
149 &'a self,
150 _visitor: &mut V,
151 ) -> VortexResult<TraversalOrder>;
152
153 fn accept_with_context<'a, V: Folder<'a, NodeTy = Self>>(
154 &'a self,
155 visitor: &mut V,
156 context: V::Context,
157 ) -> VortexResult<FoldUp<V::Out>>;
158
159 fn transform<V: MutNodeVisitor<NodeTy = Self>>(
160 self,
161 _visitor: &mut V,
162 ) -> VortexResult<TransformResult<Self>>;
163
164 fn transform_with_context<V: FolderMut<NodeTy = Self>>(
165 self,
166 _visitor: &mut V,
167 _context: V::Context,
168 ) -> VortexResult<FoldUp<V::Out>>;
169}
170
171impl Node for ExprRef {
172 fn accept<'a, V: NodeVisitor<'a, NodeTy = ExprRef>>(
174 &'a self,
175 visitor: &mut V,
176 ) -> VortexResult<TraversalOrder> {
177 let mut ord = visitor.visit_down(self)?;
178 if ord == TraversalOrder::Stop {
179 return Ok(TraversalOrder::Stop);
180 }
181 if ord == TraversalOrder::Skip {
182 return Ok(TraversalOrder::Continue);
183 }
184 for child in self.children() {
185 if ord != TraversalOrder::Continue {
186 return Ok(ord);
187 }
188 ord = child.accept(visitor)?;
189 }
190 if ord == TraversalOrder::Stop {
191 return Ok(TraversalOrder::Stop);
192 }
193 visitor.visit_up(self)
194 }
195
196 fn accept_with_context<'a, V: Folder<'a, NodeTy = Self>>(
197 &'a self,
198 visitor: &mut V,
199 context: V::Context,
200 ) -> VortexResult<FoldUp<V::Out>> {
201 let children = match visitor.visit_down(self, context.clone())? {
202 FoldDown::Abort(out) => return Ok(FoldUp::Abort(out)),
203 FoldDown::SkipChildren(out) => return Ok(FoldUp::Continue(out)),
204 FoldDown::Continue(child_context) => {
205 let mut new_children = Vec::with_capacity(self.children().len());
206 for child in self.children() {
207 match child.accept_with_context(visitor, child_context.clone())? {
208 FoldUp::Abort(out) => return Ok(FoldUp::Abort(out)),
209 FoldUp::Continue(out) => new_children.push(out),
210 }
211 }
212 new_children
213 }
214 };
215
216 visitor.visit_up(self, context, children)
217 }
218
219 fn transform<V: MutNodeVisitor<NodeTy = Self>>(
221 self,
222 visitor: &mut V,
223 ) -> VortexResult<TransformResult<Self>> {
224 match visitor.visit_down(&self)? {
225 TraversalOrder::Stop => Ok(TransformResult {
226 result: self,
227 order: TraversalOrder::Stop,
228 changed: false,
229 }),
230 TraversalOrder::Skip => visitor.visit_up(self),
231 TraversalOrder::Continue => {
232 let mut new_children = Vec::with_capacity(self.children().len());
233 let mut changed = false;
234 let mut stopped = false;
235 for child in self.children() {
236 if stopped {
237 new_children.push(child.clone());
238 continue;
239 }
240 let TransformResult {
241 result: new_child,
242 order: child_order,
243 changed: child_changed,
244 } = child.clone().transform(visitor)?;
245 new_children.push(new_child);
246 changed |= child_changed;
247 stopped |= child_order == TraversalOrder::Stop;
248 }
249
250 if changed {
251 let up = visitor.visit_up(self.replacing_children(new_children))?;
252 Ok(TransformResult::yes(up.result))
253 } else {
254 visitor.visit_up(self)
255 }
256 }
257 }
258 }
259
260 fn transform_with_context<V: FolderMut<NodeTy = Self>>(
261 self,
262 visitor: &mut V,
263 context: V::Context,
264 ) -> VortexResult<FoldUp<V::Out>> {
265 let children = match visitor.visit_down(&self, context.clone())? {
266 FoldDown::Abort(out) => return Ok(FoldUp::Abort(out)),
267 FoldDown::SkipChildren(out) => return Ok(FoldUp::Continue(out)),
268 FoldDown::Continue(child_context) => {
269 let mut new_children = Vec::with_capacity(self.children().len());
270 for child in self.children() {
271 match child
272 .clone()
273 .transform_with_context(visitor, child_context.clone())?
274 {
275 FoldUp::Abort(out) => return Ok(FoldUp::Abort(out)),
276 FoldUp::Continue(out) => new_children.push(out),
277 }
278 }
279 new_children
280 }
281 };
282
283 visitor.visit_up(self, context, children)
284 }
285}
286
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::aliases::hash_set::HashSet;
    use vortex_error::VortexResult;

    use crate::traversal::visitor::pre_order_visit_down;
    use crate::traversal::{MutNodeVisitor, Node, NodeVisitor, TransformResult, TraversalOrder};
    use crate::{
        BinaryExpr, ExprRef, FieldName, GetItem, Identity, Literal, Operator, VortexExpr,
        VortexExprExt, col,
    };

    /// Collects a reference to every `Literal` node seen on the way down.
    #[derive(Default)]
    pub struct ExprLitCollector<'a>(pub Vec<&'a ExprRef>);

    impl<'a> NodeVisitor<'a> for ExprLitCollector<'a> {
        type NodeTy = ExprRef;

        fn visit_down(&mut self, node: &'a ExprRef) -> VortexResult<TraversalOrder> {
            if node.as_any().is::<Literal>() {
                self.0.push(node);
            }
            Ok(TraversalOrder::Continue)
        }

        fn visit_up(&mut self, _node: &'a ExprRef) -> VortexResult<TraversalOrder> {
            Ok(TraversalOrder::Continue)
        }
    }

    /// Replaces every `GetItem` with a numbered literal (0, 1, 2, ...) on the way up.
    #[derive(Default)]
    pub struct ExprColToLit(i32);

    impl MutNodeVisitor for ExprColToLit {
        type NodeTy = ExprRef;

        fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<Self::NodeTy>> {
            if !node.as_any().is::<GetItem>() {
                return Ok(TransformResult::no(node));
            }
            let next_id = self.0;
            self.0 += 1;
            Ok(TransformResult::yes(Literal::new_expr(next_id)))
        }
    }

    /// Refuses to descend anywhere, then swaps the visited node for `Identity`.
    #[derive(Default)]
    pub struct SkipDownVisitor;

    impl MutNodeVisitor for SkipDownVisitor {
        type NodeTy = ExprRef;

        fn visit_down(&mut self, _node: &Self::NodeTy) -> VortexResult<TraversalOrder> {
            Ok(TraversalOrder::Skip)
        }

        fn visit_up(&mut self, _node: Self::NodeTy) -> VortexResult<TransformResult<Self::NodeTy>> {
            Ok(TransformResult::yes(Identity::new_expr()))
        }
    }

    /// Builds `(col1 = col2) AND (col3 != col4)`, shared by the skip/stop tests.
    fn eq_and_not_eq() -> ExprRef {
        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        let col3: Arc<dyn VortexExpr> = col("col3");
        let col4: Arc<dyn VortexExpr> = col("col4");
        let eq = BinaryExpr::new_expr(col1, Operator::Eq, col2);
        let not_eq = BinaryExpr::new_expr(col3, Operator::NotEq, col4);
        BinaryExpr::new_expr(eq, Operator::And, not_eq)
    }

    #[test]
    fn expr_deep_visitor_test() {
        let col1: Arc<dyn VortexExpr> = col("col1");
        let eq = BinaryExpr::new_expr(col1, Operator::Eq, Literal::new_expr(1));
        let expr = BinaryExpr::new_expr(eq, Operator::And, Literal::new_expr(2));

        let mut collector = ExprLitCollector::default();
        expr.accept(&mut collector).unwrap();
        assert_eq!(collector.0.len(), 2);
    }

    #[test]
    fn expr_deep_mut_visitor_test() {
        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        let eq = BinaryExpr::new_expr(col1, Operator::Eq, col2);
        let expr = BinaryExpr::new_expr(eq, Operator::And, Literal::new_expr(2));

        let mut rewriter = ExprColToLit::default();
        let transformed = expr.transform(&mut rewriter).unwrap();
        assert!(transformed.changed);

        let rewritten = transformed.result;
        let mut collector = ExprLitCollector::default();
        rewritten.accept(&mut collector).unwrap();
        assert_eq!(collector.0.len(), 3);
    }

    #[test]
    fn expr_skip_test() {
        let expr = eq_and_not_eq();

        let mut nodes = Vec::new();
        expr.accept(&mut pre_order_visit_down(|node: &ExprRef| {
            if node.as_any().is::<GetItem>() {
                nodes.push(node)
            }
            let skip_here = node
                .as_any()
                .downcast_ref::<BinaryExpr>()
                .is_some_and(|bin| bin.op() == Operator::Eq);
            if skip_here {
                return Ok(TraversalOrder::Skip);
            }
            Ok(TraversalOrder::Continue)
        }))
        .unwrap();

        assert_eq!(
            nodes
                .into_iter()
                .map(|x| x.references())
                .fold(HashSet::new(), |acc, x| acc.union(&x).cloned().collect()),
            HashSet::from_iter(vec![FieldName::from("col3"), FieldName::from("col4")])
        );
    }

    #[test]
    fn expr_stop_test() {
        let expr = eq_and_not_eq();

        let mut nodes = Vec::new();
        expr.accept(&mut pre_order_visit_down(|node: &ExprRef| {
            if node.as_any().is::<GetItem>() {
                nodes.push(node)
            }
            let stop_here = node
                .as_any()
                .downcast_ref::<BinaryExpr>()
                .is_some_and(|bin| bin.op() == Operator::Eq);
            if stop_here {
                return Ok(TraversalOrder::Stop);
            }
            Ok(TraversalOrder::Continue)
        }))
        .unwrap();

        assert_eq!(
            nodes
                .into_iter()
                .map(|x| x.references())
                .fold(HashSet::new(), |acc, x| acc.union(&x).cloned().collect()),
            HashSet::from_iter(vec![])
        );
    }

    #[test]
    fn expr_skip_down_visit_up() {
        let column = col("col");

        let mut visitor = SkipDownVisitor;
        let transformed = column.transform(&mut visitor).unwrap();

        assert!(transformed.changed);
        assert!(transformed.result.as_any().downcast_ref::<Identity>().is_some());
    }
}