vortex_expr/transform/
annotations.rs1use std::hash::Hash;
5
6use vortex_error::{VortexExpect, VortexResult};
7use vortex_utils::aliases::hash_map::HashMap;
8use vortex_utils::aliases::hash_set::HashSet;
9
10use crate::ExprRef;
11use crate::traversal::{Node, NodeVisitor, TraversalOrder};
12
/// Marker trait for values that can be attached to expression nodes.
///
/// Annotations are collected into hash sets, so they must be cloneable, hashable and
/// comparable for equality. Any type satisfying those bounds implements this trait
/// automatically via the blanket impl below.
pub trait Annotation: Clone + Hash + Eq {}

impl<T: Clone + Hash + Eq> Annotation for T {}
16
17pub trait AnnotationFn: Fn(&ExprRef) -> Vec<Self::Annotation> {
18 type Annotation: Annotation;
19}
20
21impl<A, F> AnnotationFn for F
22where
23 A: Annotation,
24 F: Fn(&ExprRef) -> Vec<A>,
25{
26 type Annotation = A;
27}
28
29pub type Annotations<'a, A> = HashMap<&'a ExprRef, HashSet<A>>;
30
31pub fn descendent_annotations<A: AnnotationFn>(
36 expr: &ExprRef,
37 annotate: A,
38) -> Annotations<'_, A::Annotation> {
39 let mut visitor = AnnotationVisitor {
40 annotations: Default::default(),
41 annotate,
42 };
43 expr.accept(&mut visitor).vortex_expect("Infallible");
44 visitor.annotations
45}
46
47struct AnnotationVisitor<'a, A: AnnotationFn> {
48 annotations: Annotations<'a, A::Annotation>,
49 annotate: A,
50}
51
52impl<'a, A: AnnotationFn> NodeVisitor<'a> for AnnotationVisitor<'a, A> {
53 type NodeTy = ExprRef;
54
55 fn visit_down(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
56 let annotations = (self.annotate)(node);
57 if annotations.is_empty() {
58 Ok(TraversalOrder::Continue)
60 } else {
61 self.annotations
62 .entry(node)
63 .or_default()
64 .extend(annotations);
65 Ok(TraversalOrder::Skip)
66 }
67 }
68
69 fn visit_up(&mut self, node: &'a ExprRef) -> VortexResult<TraversalOrder> {
70 let child_annotations = node
71 .children()
72 .iter()
73 .filter_map(|c| self.annotations.get(c).cloned())
74 .collect::<Vec<_>>();
75
76 let annotations = self.annotations.entry(node).or_default();
77 child_annotations
78 .into_iter()
79 .for_each(|ps| annotations.extend(ps.iter().cloned()));
80
81 Ok(TraversalOrder::Continue)
82 }
83}