vortex_array/expr/analysis/
labeling.rs1use vortex_error::VortexExpect;
5use vortex_error::VortexResult;
6use vortex_utils::aliases::hash_map::HashMap;
7
8use crate::expr::Expression;
9use crate::expr::traversal::NodeExt;
10use crate::expr::traversal::NodeVisitor;
11use crate::expr::traversal::TraversalOrder;
12
13pub fn label_tree<L: Clone>(
34 expr: &Expression,
35 self_label: impl Fn(&Expression) -> L,
36 mut merge_child: impl FnMut(L, &L) -> L,
37) -> HashMap<&Expression, L> {
38 let mut visitor = LabelingVisitor {
39 labels: Default::default(),
40 self_label,
41 merge_child: &mut merge_child,
42 };
43 expr.accept(&mut visitor)
44 .vortex_expect("LabelingVisitor is infallible");
45 visitor.labels
46}
47
48struct LabelingVisitor<'a, 'b, L, F, G>
49where
50 F: Fn(&Expression) -> L,
51 G: FnMut(L, &L) -> L,
52{
53 labels: HashMap<&'a Expression, L>,
54 self_label: F,
55 merge_child: &'b mut G,
56}
57
58impl<'a, 'b, L: Clone, F, G> NodeVisitor<'a> for LabelingVisitor<'a, 'b, L, F, G>
59where
60 F: Fn(&Expression) -> L,
61 G: FnMut(L, &L) -> L,
62{
63 type NodeTy = Expression;
64
65 fn visit_down(&mut self, _node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
66 Ok(TraversalOrder::Continue)
67 }
68
69 fn visit_up(&mut self, node: &'a Expression) -> VortexResult<TraversalOrder> {
70 let self_label = (self.self_label)(node);
71
72 let final_label = node.children().iter().fold(self_label, |acc, child| {
73 let child_label = self
74 .labels
75 .get(child)
76 .vortex_expect("child must have label");
77 (self.merge_child)(acc, child_label)
78 });
79
80 self.labels.insert(node, final_label);
81
82 Ok(TraversalOrder::Continue)
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::exprs::binary::eq;
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::literal::lit;

    #[test]
    fn test_tree_depth() {
        // `eq(col("col1"), lit(5))` is three levels deep. `col(..)` contributes two
        // nodes (presumably a get_item over the root).
        let expr = eq(col("col1"), lit(5));

        // Every node starts at depth 1; a parent is one deeper than its deepest child.
        // The label type is spelled out so that `.max` is not called on an ambiguous
        // `{integer}`.
        let depths = label_tree(
            &expr,
            |_node| 1usize,
            |self_depth, child_depth| self_depth.max(*child_depth + 1),
        );

        assert_eq!(depths.get(&expr), Some(&3));
    }

    #[test]
    fn test_node_count() {
        let expr = eq(col("col1"), lit(5));

        // Each node counts itself plus the counts of all of its children.
        let counts = label_tree(
            &expr,
            |_node| 1usize,
            |self_count, child_count| self_count + *child_count,
        );

        assert_eq!(counts.get(&expr), Some(&4));
        // All four (structurally distinct) nodes receive their own label.
        assert_eq!(counts.len(), 4);
    }

    #[test]
    fn test_leaf_keeps_self_label() {
        // A node without children is never merged with anything, so its final label is
        // exactly what `self_label` produced.
        let expr = lit(5);
        let labels = label_tree(&expr, |_node| 7usize, |acc, child| acc + *child);

        assert_eq!(labels.get(&expr), Some(&7));
        assert_eq!(labels.len(), 1);
    }
}