vortex_expr/transform/
immediate_access.rs1use itertools::Itertools;
2use vortex_dtype::{FieldName, StructFields};
3use vortex_error::{VortexResult, vortex_err};
4use vortex_utils::aliases::hash_map::HashMap;
5use vortex_utils::aliases::hash_set::HashSet;
6
7use crate::traversal::{Node, NodeVisitor, TraversalOrder};
8use crate::{ExprRef, GetItem, Select, is_root};
9
10pub type FieldAccesses<'a> = HashMap<&'a ExprRef, HashSet<FieldName>>;
11
12pub fn immediate_scope_accesses<'a>(
20 expr: &'a ExprRef,
21 scope_dtype: &'a StructFields,
22) -> VortexResult<FieldAccesses<'a>> {
23 ImmediateScopeAccessesAnalysis::<'a>::analyze(expr, scope_dtype)
24}
25
26pub fn immediate_scope_access<'a>(
28 expr: &'a ExprRef,
29 scope_dtype: &'a StructFields,
30) -> VortexResult<HashSet<FieldName>> {
31 ImmediateScopeAccessesAnalysis::<'a>::analyze(expr, scope_dtype)?
32 .get(expr)
33 .ok_or_else(|| {
34 vortex_err!("Expression missing from scope accesses, this is a internal bug")
35 })
36 .cloned()
37}
38
39struct ImmediateScopeAccessesAnalysis<'a> {
40 sub_expressions: FieldAccesses<'a>,
41 scope_dtype: &'a StructFields,
42}
43
44impl<'a> ImmediateScopeAccessesAnalysis<'a> {
45 fn new(scope_dtype: &'a StructFields) -> Self {
46 Self {
47 sub_expressions: HashMap::new(),
48 scope_dtype,
49 }
50 }
51
52 fn analyze(
53 expr: &'a ExprRef,
54 scope_dtype: &'a StructFields,
55 ) -> VortexResult<FieldAccesses<'a>> {
56 let mut analysis = Self::new(scope_dtype);
57 expr.accept(&mut analysis)?;
58 Ok(analysis.sub_expressions)
59 }
60}
61
62impl<'a> NodeVisitor<'a> for ImmediateScopeAccessesAnalysis<'a> {
63 type NodeTy = ExprRef;
64
65 fn visit_down(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
66 assert!(
67 !node.as_any().is::<Select>(),
68 "cannot analyze select, simplify the expression"
69 );
70 if let Some(get_item) = node.as_any().downcast_ref::<GetItem>() {
71 if is_root(get_item.child()) {
72 self.sub_expressions
73 .insert(node, HashSet::from_iter(vec![get_item.field().clone()]));
74
75 return Ok(TraversalOrder::Skip);
76 }
77 } else if is_root(node) {
78 let st_dtype = &self.scope_dtype;
79 self.sub_expressions
80 .insert(node, st_dtype.names().iter().cloned().collect());
81 }
82
83 Ok(TraversalOrder::Continue)
84 }
85
86 fn visit_up(&mut self, node: &'a ExprRef) -> VortexResult<TraversalOrder> {
87 let accesses = node
88 .children()
89 .iter()
90 .filter_map(|c| self.sub_expressions.get(c).cloned())
91 .collect_vec();
92
93 let node_accesses = self.sub_expressions.entry(node).or_default();
94 accesses
95 .into_iter()
96 .for_each(|fields| node_accesses.extend(fields.iter().cloned()));
97
98 Ok(TraversalOrder::Continue)
99 }
100}