vortex_array/expr/analysis/
labeling.rs1use vortex_error::VortexExpect;
5use vortex_error::VortexResult;
6use vortex_utils::aliases::hash_map::HashMap;
7
8use crate::expr::Expression;
9use crate::expr::traversal::NodeExt;
10use crate::expr::traversal::NodeVisitor;
11use crate::expr::traversal::TraversalOrder;
12
13pub fn label_tree<L: Clone>(
33 expr: &Expression,
34 self_label: impl Fn(&Expression) -> L,
35 mut merge_child: impl FnMut(L, &L) -> L,
36) -> HashMap<&Expression, L> {
37 let mut visitor = LabelingVisitor {
38 labels: Default::default(),
39 self_label,
40 merge_child: &mut merge_child,
41 };
42 expr.accept(&mut visitor)
43 .vortex_expect("LabelingVisitor is infallible");
44 visitor.labels
45}
46
47struct LabelingVisitor<'a, 'b, L, F, G>
48where
49 F: Fn(&Expression) -> L,
50 G: FnMut(L, &L) -> L,
51{
52 labels: HashMap<&'a Expression, L>,
53 self_label: F,
54 merge_child: &'b mut G,
55}
56
57impl<'a, 'b, L: Clone, F, G> NodeVisitor<'a> for LabelingVisitor<'a, 'b, L, F, G>
58where
59 F: Fn(&Expression) -> L,
60 G: FnMut(L, &L) -> L,
61{
62 type NodeTy = Expression;
63
64 fn visit_down(&mut self, _node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
65 Ok(TraversalOrder::Continue)
66 }
67
68 fn visit_up(&mut self, node: &'a Expression) -> VortexResult<TraversalOrder> {
69 let self_label = (self.self_label)(node);
70
71 let final_label = node.children().iter().fold(self_label, |acc, child| {
72 let child_label = self
73 .labels
74 .get(child)
75 .vortex_expect("child must have label");
76 (self.merge_child)(acc, child_label)
77 });
78
79 self.labels.insert(node, final_label);
80
81 Ok(TraversalOrder::Continue)
82 }
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::col;
    use crate::expr::eq;
    use crate::expr::lit;

    // Both tests label the same expression, `eq(col("col1"), lit(5))`. The expected node count
    // of 4 implies the `col(..)` operand contributes two nodes (presumably a field access over a
    // root node — confirm against the definition of `col`).

    #[test]
    fn test_tree_depth() {
        let expr = eq(col("col1"), lit(5));

        // A node on its own has depth 1; a parent sits one level above its deepest child.
        let depth_by_node = label_tree(
            &expr,
            |_node| 1,
            |depth_so_far, child_depth| depth_so_far.max(*child_depth + 1),
        );

        assert_eq!(depth_by_node.get(&expr).copied(), Some(3));
    }

    #[test]
    fn test_node_count() {
        let expr = eq(col("col1"), lit(5));

        // Each node counts itself once, plus the sizes of all its child subtrees.
        let size_by_node = label_tree(
            &expr,
            |_node| 1,
            |size_so_far, child_size| size_so_far + *child_size,
        );

        assert_eq!(size_by_node.get(&expr).copied(), Some(4));
    }
}