vortex_array/expr/analysis/
annotation.rs1use std::hash::Hash;
5
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_utils::aliases::hash_map::HashMap;
9use vortex_utils::aliases::hash_set::HashSet;
10
11use crate::expr::Expression;
12use crate::expr::traversal::NodeExt;
13use crate::expr::traversal::NodeVisitor;
14use crate::expr::traversal::TraversalOrder;
15
/// Marker trait for values that can be attached to an [`Expression`] node.
///
/// Annotations are collected into hash sets, so they must be hashable and comparable for
/// equality, and they are cloned whenever they are copied from one node to another.
///
/// This trait is implemented automatically, via the blanket impl below, for every type that
/// is `Clone + Hash + Eq`.
pub trait Annotation: Clone + Hash + Eq {}

impl<T: Clone + Hash + Eq> Annotation for T {}
19
20pub trait AnnotationFn: Fn(&Expression) -> Vec<Self::Annotation> {
21 type Annotation: Annotation;
22}
23
24impl<A, F> AnnotationFn for F
25where
26 A: Annotation,
27 F: Fn(&Expression) -> Vec<A>,
28{
29 type Annotation = A;
30}
31
32pub type Annotations<'a, A> = HashMap<&'a Expression, HashSet<A>>;
33
34pub fn descendent_annotations<A: AnnotationFn>(
39 expr: &Expression,
40 annotate: A,
41) -> Annotations<'_, A::Annotation> {
42 let mut visitor = AnnotationVisitor {
43 annotations: Default::default(),
44 annotate,
45 propagate_up: true,
46 };
47 expr.accept(&mut visitor).vortex_expect("Infallible");
48 visitor.annotations
49}
50
51pub fn direct_annotations<A: AnnotationFn>(
57 expr: &Expression,
58 annotate: A,
59) -> Annotations<'_, A::Annotation> {
60 let mut visitor = AnnotationVisitor {
61 annotations: Default::default(),
62 annotate,
63 propagate_up: false,
64 };
65 expr.accept(&mut visitor).vortex_expect("Infallible");
66 visitor.annotations
67}
68
69struct AnnotationVisitor<'a, A: AnnotationFn> {
70 annotations: Annotations<'a, A::Annotation>,
71 annotate: A,
72 propagate_up: bool,
73}
74
75impl<'a, A: AnnotationFn> NodeVisitor<'a> for AnnotationVisitor<'a, A> {
76 type NodeTy = Expression;
77
78 fn visit_down(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
79 let annotations = (self.annotate)(node);
80 if annotations.is_empty() {
81 Ok(TraversalOrder::Continue)
83 } else {
84 self.annotations
85 .entry(node)
86 .or_default()
87 .extend(annotations);
88 Ok(TraversalOrder::Skip)
89 }
90 }
91
92 fn visit_up(&mut self, node: &'a Expression) -> VortexResult<TraversalOrder> {
93 if !self.propagate_up {
94 return Ok(TraversalOrder::Continue);
95 }
96 let child_annotations = node
97 .children()
98 .iter()
99 .filter_map(|c| self.annotations.get(c).cloned())
100 .collect::<Vec<_>>();
101
102 let annotations = self.annotations.entry(node).or_default();
103 child_annotations
104 .into_iter()
105 .for_each(|ps| annotations.extend(ps.iter().cloned()));
106
107 Ok(TraversalOrder::Continue)
108 }
109}