vortex_array/expr/analysis/
annotation.rs1use std::hash::Hash;
5
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_utils::aliases::hash_map::HashMap;
9use vortex_utils::aliases::hash_set::HashSet;
10
11use crate::expr::Expression;
12use crate::expr::traversal::NodeExt;
13use crate::expr::traversal::NodeVisitor;
14use crate::expr::traversal::TraversalOrder;
15
/// Marker trait for values that can be attached to an expression as an annotation.
///
/// Annotations are stored in hash sets and copied between nodes, which is why the
/// trait requires [`Clone`], [`Hash`] and [`Eq`]. It declares no methods of its own.
pub trait Annotation: Clone + Hash + Eq {}

/// Blanket implementation: any type that is `Clone + Hash + Eq` is an [`Annotation`],
/// so callers never need to implement the trait by hand.
impl<T: Clone + Hash + Eq> Annotation for T {}
19
20pub trait AnnotationFn: Fn(&Expression) -> Vec<Self::Annotation> {
21 type Annotation: Annotation;
22}
23
24impl<A, F> AnnotationFn for F
25where
26 A: Annotation,
27 F: Fn(&Expression) -> Vec<A>,
28{
29 type Annotation = A;
30}
31
32pub type Annotations<'a, A> = HashMap<&'a Expression, HashSet<A>>;
33
34pub fn descendent_annotations<A: AnnotationFn>(
39 expr: &Expression,
40 annotate: A,
41) -> Annotations<'_, A::Annotation> {
42 let mut visitor = AnnotationVisitor {
43 annotations: Default::default(),
44 annotate,
45 };
46 expr.accept(&mut visitor).vortex_expect("Infallible");
47 visitor.annotations
48}
49
50struct AnnotationVisitor<'a, A: AnnotationFn> {
51 annotations: Annotations<'a, A::Annotation>,
52 annotate: A,
53}
54
55impl<'a, A: AnnotationFn> NodeVisitor<'a> for AnnotationVisitor<'a, A> {
56 type NodeTy = Expression;
57
58 fn visit_down(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
59 let annotations = (self.annotate)(node);
60 if annotations.is_empty() {
61 Ok(TraversalOrder::Continue)
63 } else {
64 self.annotations
65 .entry(node)
66 .or_default()
67 .extend(annotations);
68 Ok(TraversalOrder::Skip)
69 }
70 }
71
72 fn visit_up(&mut self, node: &'a Expression) -> VortexResult<TraversalOrder> {
73 let child_annotations = node
74 .children()
75 .iter()
76 .filter_map(|c| self.annotations.get(c).cloned())
77 .collect::<Vec<_>>();
78
79 let annotations = self.annotations.entry(node).or_default();
80 child_annotations
81 .into_iter()
82 .for_each(|ps| annotations.extend(ps.iter().cloned()));
83
84 Ok(TraversalOrder::Continue)
85 }
86}