Skip to main content

vortex_array/expr/analysis/
annotation.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_utils::aliases::hash_map::HashMap;
9use vortex_utils::aliases::hash_set::HashSet;
10
11use crate::expr::Expression;
12use crate::expr::traversal::NodeExt;
13use crate::expr::traversal::NodeVisitor;
14use crate::expr::traversal::TraversalOrder;
15
/// Marker trait for values that can be attached to expressions as annotations.
///
/// Annotations are collected into hash sets, which is why `Hash + Eq` is required;
/// `Clone` lets an annotation be copied into the sets of other expressions.
pub trait Annotation: Clone + Hash + Eq {}

// Any cloneable, hashable, equality-comparable type is automatically an annotation.
impl<T: Clone + Hash + Eq> Annotation for T {}
19
20pub trait AnnotationFn: Fn(&Expression) -> Vec<Self::Annotation> {
21    type Annotation: Annotation;
22}
23
24impl<A, F> AnnotationFn for F
25where
26    A: Annotation,
27    F: Fn(&Expression) -> Vec<A>,
28{
29    type Annotation = A;
30}
31
/// Map from an expression in the analysed tree (borrowed for `'a`) to the set of
/// annotations recorded for it.
pub type Annotations<'a, A> = HashMap<&'a Expression, HashSet<A>>;
33
34/// Walk the expression tree and annotate each expression with zero or more annotations.
35///
36/// Returns a map of each expression to all annotations that any of its descendent (child)
37/// expressions are annotated with.
38pub fn descendent_annotations<A: AnnotationFn>(
39    expr: &Expression,
40    annotate: A,
41) -> Annotations<'_, A::Annotation> {
42    let mut visitor = AnnotationVisitor {
43        annotations: Default::default(),
44        annotate,
45        propagate_up: true,
46    };
47    expr.accept(&mut visitor).vortex_expect("Infallible");
48    visitor.annotations
49}
50
51/// Walk the expression tree and annotate each expression with zero or more
52/// annotations.
53///
54/// Returns a map of each expression to all annotations. Annotations of
55/// children are not propagated to parents.
56pub fn direct_annotations<A: AnnotationFn>(
57    expr: &Expression,
58    annotate: A,
59) -> Annotations<'_, A::Annotation> {
60    let mut visitor = AnnotationVisitor {
61        annotations: Default::default(),
62        annotate,
63        propagate_up: false,
64    };
65    expr.accept(&mut visitor).vortex_expect("Infallible");
66    visitor.annotations
67}
68
69struct AnnotationVisitor<'a, A: AnnotationFn> {
70    annotations: Annotations<'a, A::Annotation>,
71    annotate: A,
72    propagate_up: bool,
73}
74
75impl<'a, A: AnnotationFn> NodeVisitor<'a> for AnnotationVisitor<'a, A> {
76    type NodeTy = Expression;
77
78    fn visit_down(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
79        let annotations = (self.annotate)(node);
80        if annotations.is_empty() {
81            // If the annotate fn returns empty, we do not annotate this node.
82            Ok(TraversalOrder::Continue)
83        } else {
84            self.annotations
85                .entry(node)
86                .or_default()
87                .extend(annotations);
88            Ok(TraversalOrder::Skip)
89        }
90    }
91
92    fn visit_up(&mut self, node: &'a Expression) -> VortexResult<TraversalOrder> {
93        if !self.propagate_up {
94            return Ok(TraversalOrder::Continue);
95        }
96        let child_annotations = node
97            .children()
98            .iter()
99            .filter_map(|c| self.annotations.get(c).cloned())
100            .collect::<Vec<_>>();
101
102        let annotations = self.annotations.entry(node).or_default();
103        child_annotations
104            .into_iter()
105            .for_each(|ps| annotations.extend(ps.iter().cloned()));
106
107        Ok(TraversalOrder::Continue)
108    }
109}