Skip to main content

vortex_array/expr/analysis/
labeling.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexExpect;
5use vortex_error::VortexResult;
6use vortex_utils::aliases::hash_map::HashMap;
7
8use crate::expr::Expression;
9use crate::expr::traversal::NodeExt;
10use crate::expr::traversal::NodeVisitor;
11use crate::expr::traversal::TraversalOrder;
12
13/// Label each node in an expression tree using a bottom-up traversal.
14///
15/// This function separates tree labeling into two distinct steps:
16/// 1. **Label Self**: Compute a label for each node based only on the node itself
17/// 2. **Merge Child**: Fold/accumulate labels from children into the node's self-label
18///
19/// The labeling process:
20/// - First, `self_label` is called on the node to produce its self-label
21/// - Then, for each child, `merge_child` is called with `(self_label, child_label)`
22///   to fold the child label into the self_label
23/// - This produces the final label for the node
24///
25/// # Parameters
26///
27/// - `expr`: The root expression to label
28/// - `self_label`: Function that computes a label for a single node
29/// - `merge_child`: Mutable function that folds child labels into an accumulator.
30///   Takes `(self_label, child_label)` and returns the updated accumulator.
31///   Called once per child, with the initial accumulator being the node's self-label.
32pub fn label_tree<L: Clone>(
33    expr: &Expression,
34    self_label: impl Fn(&Expression) -> L,
35    mut merge_child: impl FnMut(L, &L) -> L,
36) -> HashMap<&Expression, L> {
37    let mut visitor = LabelingVisitor {
38        labels: Default::default(),
39        self_label,
40        merge_child: &mut merge_child,
41    };
42    expr.accept(&mut visitor)
43        .vortex_expect("LabelingVisitor is infallible");
44    visitor.labels
45}
46
47struct LabelingVisitor<'a, 'b, L, F, G>
48where
49    F: Fn(&Expression) -> L,
50    G: FnMut(L, &L) -> L,
51{
52    labels: HashMap<&'a Expression, L>,
53    self_label: F,
54    merge_child: &'b mut G,
55}
56
57impl<'a, 'b, L: Clone, F, G> NodeVisitor<'a> for LabelingVisitor<'a, 'b, L, F, G>
58where
59    F: Fn(&Expression) -> L,
60    G: FnMut(L, &L) -> L,
61{
62    type NodeTy = Expression;
63
64    fn visit_down(&mut self, _node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
65        Ok(TraversalOrder::Continue)
66    }
67
68    fn visit_up(&mut self, node: &'a Expression) -> VortexResult<TraversalOrder> {
69        let self_label = (self.self_label)(node);
70
71        let final_label = node.children().iter().fold(self_label, |acc, child| {
72            let child_label = self
73                .labels
74                .get(child)
75                .vortex_expect("child must have label");
76            (self.merge_child)(acc, child_label)
77        });
78
79        self.labels.insert(node, final_label);
80
81        Ok(TraversalOrder::Continue)
82    }
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::col;
    use crate::expr::eq;
    use crate::expr::lit;

    /// Labels each node with the depth of the subtree rooted at it.
    #[test]
    fn test_tree_depth() {
        // `$.col1 = 5` is the tree eq(get_item(root(), "col1"), lit(5)).
        let expr = eq(col("col1"), lit(5));

        // A lone node has depth 1; every child lifts its parent to at least child + 1.
        let depths = label_tree(&expr, |_| 1, |acc, child| acc.max(*child + 1));

        // root = 1, get_item = 2, lit = 1, so the eq node at the top has depth 3.
        assert_eq!(depths.get(&expr).copied(), Some(3));
    }

    /// Labels each node with the number of nodes in its subtree, itself included.
    #[test]
    fn test_node_count() {
        // Same tree as above: eq(get_item(root(), "col1"), lit(5)).
        let expr = eq(col("col1"), lit(5));

        // Every node contributes 1, and a parent sums the counts of its children.
        let counts = label_tree(&expr, |_| 1, |acc, child| acc + *child);

        // eq + get_item + root + lit = 4 nodes in total.
        assert_eq!(counts.get(&expr).copied(), Some(4));
    }
}