// vortex_expr/transform/immediate_access.rs
use itertools::Itertools;
use vortex_array::aliases::hash_map::HashMap;
use vortex_array::aliases::hash_set::HashSet;
use vortex_dtype::{FieldName, StructDType};
use vortex_error::{VortexResult, vortex_err};

use crate::traversal::{Node, NodeVisitor, TraversalOrder};
use crate::{ExprRef, GetItem, Identity, Select};

10pub type FieldAccesses<'a> = HashMap<&'a ExprRef, HashSet<FieldName>>;
11
12pub fn immediate_scope_accesses<'a>(
20 expr: &'a ExprRef,
21 scope_dtype: &'a StructDType,
22) -> VortexResult<FieldAccesses<'a>> {
23 ImmediateScopeAccessesAnalysis::<'a>::analyze(expr, scope_dtype)
24}
25
26pub fn immediate_scope_access<'a>(
28 expr: &'a ExprRef,
29 scope_dtype: &'a StructDType,
30) -> VortexResult<HashSet<FieldName>> {
31 ImmediateScopeAccessesAnalysis::<'a>::analyze(expr, scope_dtype)?
32 .get(expr)
33 .ok_or_else(|| {
34 vortex_err!("Expression missing from scope accesses, this is a internal bug")
35 })
36 .cloned()
37}
38
39struct ImmediateScopeAccessesAnalysis<'a> {
40 sub_expressions: FieldAccesses<'a>,
41 scope_dtype: &'a StructDType,
42}
43
44impl<'a> ImmediateScopeAccessesAnalysis<'a> {
45 fn new(scope_dtype: &'a StructDType) -> Self {
46 Self {
47 sub_expressions: HashMap::new(),
48 scope_dtype,
49 }
50 }
51
52 fn analyze(expr: &'a ExprRef, scope_dtype: &'a StructDType) -> VortexResult<FieldAccesses<'a>> {
53 let mut analysis = Self::new(scope_dtype);
54 expr.accept(&mut analysis)?;
55 Ok(analysis.sub_expressions)
56 }
57}
58
59impl<'a> NodeVisitor<'a> for ImmediateScopeAccessesAnalysis<'a> {
60 type NodeTy = ExprRef;
61
62 fn visit_down(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
63 assert!(
64 !node.as_any().is::<Select>(),
65 "cannot analyse select, simply the expression"
66 );
67 if let Some(get_item) = node.as_any().downcast_ref::<GetItem>() {
68 if get_item
69 .child()
70 .as_any()
71 .downcast_ref::<Identity>()
72 .is_some()
73 {
74 self.sub_expressions
75 .insert(node, HashSet::from_iter(vec![get_item.field().clone()]));
76
77 return Ok(TraversalOrder::Skip);
78 }
79 } else if node.as_any().downcast_ref::<Identity>().is_some() {
80 let st_dtype = &self.scope_dtype;
81 self.sub_expressions
82 .insert(node, st_dtype.names().iter().cloned().collect());
83 }
84
85 Ok(TraversalOrder::Continue)
86 }
87
88 fn visit_up(&mut self, node: &'a ExprRef) -> VortexResult<TraversalOrder> {
89 let accesses = node
90 .children()
91 .iter()
92 .filter_map(|c| self.sub_expressions.get(c).cloned())
93 .collect_vec();
94
95 let node_accesses = self.sub_expressions.entry(node).or_default();
96 accesses
97 .into_iter()
98 .for_each(|fields| node_accesses.extend(fields.iter().cloned()));
99
100 Ok(TraversalOrder::Continue)
101 }
102}