vortex_expr/transform/annotations.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5
6use vortex_error::{VortexExpect, VortexResult};
7use vortex_utils::aliases::hash_map::HashMap;
8use vortex_utils::aliases::hash_set::HashSet;
9
10use crate::ExprRef;
11use crate::traversal::{Node, NodeVisitor, TraversalOrder};
12
/// Marker trait for values that can be attached to an expression as an annotation.
///
/// Annotations are collected into hash sets, so they must be hashable and comparable for
/// equality; they must also be cloneable because a child's annotations are copied onto its
/// parent.
pub trait Annotation: Clone + Eq + Hash {}

/// Any type that satisfies the required bounds is automatically an [`Annotation`].
impl<T: Clone + Eq + Hash> Annotation for T {}
16
17pub trait AnnotationFn: Fn(&ExprRef) -> Vec<Self::Annotation> {
18    type Annotation: Annotation;
19}
20
21impl<A, F> AnnotationFn for F
22where
23    A: Annotation,
24    F: Fn(&ExprRef) -> Vec<A>,
25{
26    type Annotation = A;
27}
28
29pub type Annotations<'a, A> = HashMap<&'a ExprRef, HashSet<A>>;
30
31/// Walk the expression tree and annotate each expression with zero or more annotations.
32///
33/// Returns a map of each expression to all annotations that any of its descendent (child)
34/// expressions are annotated with.
35pub fn descendent_annotations<A: AnnotationFn>(
36    expr: &ExprRef,
37    annotate: A,
38) -> Annotations<'_, A::Annotation> {
39    let mut visitor = AnnotationVisitor {
40        annotations: Default::default(),
41        annotate,
42    };
43    expr.accept(&mut visitor).vortex_expect("Infallible");
44    visitor.annotations
45}
46
47struct AnnotationVisitor<'a, A: AnnotationFn> {
48    annotations: Annotations<'a, A::Annotation>,
49    annotate: A,
50}
51
52impl<'a, A: AnnotationFn> NodeVisitor<'a> for AnnotationVisitor<'a, A> {
53    type NodeTy = ExprRef;
54
55    fn visit_down(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
56        let annotations = (self.annotate)(node);
57        if annotations.is_empty() {
58            // If the annotate fn returns empty, we do not annotate this node.
59            Ok(TraversalOrder::Continue)
60        } else {
61            self.annotations
62                .entry(node)
63                .or_default()
64                .extend(annotations);
65            Ok(TraversalOrder::Skip)
66        }
67    }
68
69    fn visit_up(&mut self, node: &'a ExprRef) -> VortexResult<TraversalOrder> {
70        let child_annotations = node
71            .children()
72            .iter()
73            .filter_map(|c| self.annotations.get(c).cloned())
74            .collect::<Vec<_>>();
75
76        let annotations = self.annotations.entry(node).or_default();
77        child_annotations
78            .into_iter()
79            .for_each(|ps| annotations.extend(ps.iter().cloned()));
80
81        Ok(TraversalOrder::Continue)
82    }
83}