vortex_array/expr/analysis/
labeling.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexExpect;
5use vortex_error::VortexResult;
6use vortex_utils::aliases::hash_map::HashMap;
7
8use crate::expr::Expression;
9use crate::expr::traversal::NodeExt;
10use crate::expr::traversal::NodeVisitor;
11use crate::expr::traversal::TraversalOrder;
12
13/// Label each node in an expression tree using a bottom-up traversal.
14///
15/// This function separates tree labeling into two distinct steps:
16/// 1. **Label Self**: Compute a label for each node based only on the node itself
17/// 2. **Merge Child**: Fold/accumulate labels from children into the node's self-label
18///
19/// The labeling process:
20/// - First, `self_label` is called on the node to produce its self-label
21/// - Then, for each child, `merge_child` is called with `(self_label, child_label)`
22///   to fold the child label into the self_label
23/// - This produces the final label for the node
24///
25/// # Parameters
26///
27/// - `expr`: The root expression to label
28/// - `self_label`: Function that computes a label for a single node
29/// - `merge_child`: Mutable function that folds child labels into an accumulator.
30///   Takes `(self_label, child_label)` and returns the updated accumulator.
31///   Called once per child, with the initial accumulator being the node's self-label.
32///
33pub fn label_tree<L: Clone>(
34    expr: &Expression,
35    self_label: impl Fn(&Expression) -> L,
36    mut merge_child: impl FnMut(L, &L) -> L,
37) -> HashMap<&Expression, L> {
38    let mut visitor = LabelingVisitor {
39        labels: Default::default(),
40        self_label,
41        merge_child: &mut merge_child,
42    };
43    expr.accept(&mut visitor)
44        .vortex_expect("LabelingVisitor is infallible");
45    visitor.labels
46}
47
48struct LabelingVisitor<'a, 'b, L, F, G>
49where
50    F: Fn(&Expression) -> L,
51    G: FnMut(L, &L) -> L,
52{
53    labels: HashMap<&'a Expression, L>,
54    self_label: F,
55    merge_child: &'b mut G,
56}
57
58impl<'a, 'b, L: Clone, F, G> NodeVisitor<'a> for LabelingVisitor<'a, 'b, L, F, G>
59where
60    F: Fn(&Expression) -> L,
61    G: FnMut(L, &L) -> L,
62{
63    type NodeTy = Expression;
64
65    fn visit_down(&mut self, _node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
66        Ok(TraversalOrder::Continue)
67    }
68
69    fn visit_up(&mut self, node: &'a Expression) -> VortexResult<TraversalOrder> {
70        let self_label = (self.self_label)(node);
71
72        let final_label = node.children().iter().fold(self_label, |acc, child| {
73            let child_label = self
74                .labels
75                .get(child)
76                .vortex_expect("child must have label");
77            (self.merge_child)(acc, child_label)
78        });
79
80        self.labels.insert(node, final_label);
81
82        Ok(TraversalOrder::Continue)
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::exprs::binary::eq;
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::literal::lit;

    /// Labels every node with the height of its subtree and checks the root.
    #[test]
    fn test_tree_depth() {
        // Expression under test: $.col1 = 5
        // Shape: eq(get_item(root(), "col1"), lit(5))
        // Expected depths: root = 1, get_item = 2, lit = 1, eq = 3
        let expr = eq(col("col1"), lit(5));

        let depths = label_tree(
            &expr,
            // A node on its own has depth 1.
            |_node| 1,
            // A parent is at least one level deeper than each of its children.
            |own, child| own.max(*child + 1),
        );

        // The root (eq) sits on top of the longest chain: eq -> get_item -> root.
        let root_depth = depths.get(&expr).copied();
        assert_eq!(root_depth, Some(3));
    }

    /// Labels every node with the size of its subtree (itself included) and checks the root.
    #[test]
    fn test_node_count() {
        // Shape: eq(get_item(root(), "col1"), lit(5))
        // Nodes: eq, get_item, root, lit = 4
        let expr = eq(col("col1"), lit(5));

        let counts = label_tree(
            &expr,
            // Every node contributes exactly one to the count.
            |_node| 1,
            // Subtree size is the running total plus the child's subtree size.
            |total, child| total + *child,
        );

        // The root covers all four nodes (eq, get_item, root, lit).
        let root_count = counts.get(&expr).copied();
        assert_eq!(root_count, Some(4));
    }
}
124}