vortex_array/expr/analysis/
annotation.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_utils::aliases::hash_map::HashMap;
9use vortex_utils::aliases::hash_set::HashSet;
10
11use crate::expr::Expression;
12use crate::expr::traversal::NodeExt;
13use crate::expr::traversal::NodeVisitor;
14use crate::expr::traversal::TraversalOrder;
15
/// Marker trait for values that can be attached to an expression as an annotation.
///
/// It has no methods of its own; it only bundles the `Clone + Hash + Eq` bounds that are
/// needed to store annotations in a hash set and copy them between sets.
pub trait Annotation: Clone + Hash + Eq {}

/// Every type that is `Clone + Hash + Eq` is automatically an [`Annotation`].
impl<T: Clone + Hash + Eq> Annotation for T {}
19
/// A callable that maps an [`Expression`] to zero or more annotations.
///
/// The supertrait bound means any implementor can be invoked as
/// `f(&expr) -> Vec<Self::Annotation>`.
pub trait AnnotationFn: Fn(&Expression) -> Vec<Self::Annotation> {
    /// The annotation type produced by this function.
    type Annotation: Annotation;
}
23
24impl<A, F> AnnotationFn for F
25where
26    A: Annotation,
27    F: Fn(&Expression) -> Vec<A>,
28{
29    type Annotation = A;
30}
31
/// Map from an expression (borrowed for `'a`) to the set of annotations of type `A`
/// associated with it.
pub type Annotations<'a, A> = HashMap<&'a Expression, HashSet<A>>;
33
34/// Walk the expression tree and annotate each expression with zero or more annotations.
35///
36/// Returns a map of each expression to all annotations that any of its descendent (child)
37/// expressions are annotated with.
38pub fn descendent_annotations<A: AnnotationFn>(
39    expr: &Expression,
40    annotate: A,
41) -> Annotations<'_, A::Annotation> {
42    let mut visitor = AnnotationVisitor {
43        annotations: Default::default(),
44        annotate,
45    };
46    expr.accept(&mut visitor).vortex_expect("Infallible");
47    visitor.annotations
48}
49
50struct AnnotationVisitor<'a, A: AnnotationFn> {
51    annotations: Annotations<'a, A::Annotation>,
52    annotate: A,
53}
54
55impl<'a, A: AnnotationFn> NodeVisitor<'a> for AnnotationVisitor<'a, A> {
56    type NodeTy = Expression;
57
58    fn visit_down(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
59        let annotations = (self.annotate)(node);
60        if annotations.is_empty() {
61            // If the annotate fn returns empty, we do not annotate this node.
62            Ok(TraversalOrder::Continue)
63        } else {
64            self.annotations
65                .entry(node)
66                .or_default()
67                .extend(annotations);
68            Ok(TraversalOrder::Skip)
69        }
70    }
71
72    fn visit_up(&mut self, node: &'a Expression) -> VortexResult<TraversalOrder> {
73        let child_annotations = node
74            .children()
75            .iter()
76            .filter_map(|c| self.annotations.get(c).cloned())
77            .collect::<Vec<_>>();
78
79        let annotations = self.annotations.entry(node).or_default();
80        child_annotations
81            .into_iter()
82            .for_each(|ps| annotations.extend(ps.iter().cloned()));
83
84        Ok(TraversalOrder::Continue)
85    }
86}