// vortex_expr/transform/immediate_access.rs

1use itertools::Itertools;
2use vortex_array::aliases::hash_map::HashMap;
3use vortex_array::aliases::hash_set::HashSet;
4use vortex_dtype::{FieldName, StructDType};
5use vortex_error::{VortexResult, vortex_err};
6
7use crate::traversal::{Node, NodeVisitor, TraversalOrder};
8use crate::{ExprRef, GetItem, Identity, Select};
9
10pub type FieldAccesses<'a> = HashMap<&'a ExprRef, HashSet<FieldName>>;
11
12/// For all subexpressions in an expression, find the fields that are accessed directly from the
13/// scope, but not any fields in those fields
14/// e.g. scope = {a: {b: .., c: ..}, d: ..}, expr = ident().a.b + ident().d accesses {a,d} (not b).
15///
16/// Note: This is a very naive, but simple analysis to find the fields that are accessed directly on an
17/// identity node. This is combined to provide an over-approximation of the fields that are accessed
18/// by an expression.
19pub fn immediate_scope_accesses<'a>(
20    expr: &'a ExprRef,
21    scope_dtype: &'a StructDType,
22) -> VortexResult<FieldAccesses<'a>> {
23    ImmediateScopeAccessesAnalysis::<'a>::analyze(expr, scope_dtype)
24}
25
26/// This returns the immediate scope_access (as explained `immediate_scope_accesses`) for `expr`.
27pub fn immediate_scope_access<'a>(
28    expr: &'a ExprRef,
29    scope_dtype: &'a StructDType,
30) -> VortexResult<HashSet<FieldName>> {
31    ImmediateScopeAccessesAnalysis::<'a>::analyze(expr, scope_dtype)?
32        .get(expr)
33        .ok_or_else(|| {
34            vortex_err!("Expression missing from scope accesses, this is a internal bug")
35        })
36        .cloned()
37}
38
39struct ImmediateScopeAccessesAnalysis<'a> {
40    sub_expressions: FieldAccesses<'a>,
41    scope_dtype: &'a StructDType,
42}
43
44impl<'a> ImmediateScopeAccessesAnalysis<'a> {
45    fn new(scope_dtype: &'a StructDType) -> Self {
46        Self {
47            sub_expressions: HashMap::new(),
48            scope_dtype,
49        }
50    }
51
52    fn analyze(expr: &'a ExprRef, scope_dtype: &'a StructDType) -> VortexResult<FieldAccesses<'a>> {
53        let mut analysis = Self::new(scope_dtype);
54        expr.accept(&mut analysis)?;
55        Ok(analysis.sub_expressions)
56    }
57}
58
59impl<'a> NodeVisitor<'a> for ImmediateScopeAccessesAnalysis<'a> {
60    type NodeTy = ExprRef;
61
62    fn visit_down(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
63        assert!(
64            !node.as_any().is::<Select>(),
65            "cannot analyse select, simply the expression"
66        );
67        if let Some(get_item) = node.as_any().downcast_ref::<GetItem>() {
68            if get_item
69                .child()
70                .as_any()
71                .downcast_ref::<Identity>()
72                .is_some()
73            {
74                self.sub_expressions
75                    .insert(node, HashSet::from_iter(vec![get_item.field().clone()]));
76
77                return Ok(TraversalOrder::Skip);
78            }
79        } else if node.as_any().downcast_ref::<Identity>().is_some() {
80            let st_dtype = &self.scope_dtype;
81            self.sub_expressions
82                .insert(node, st_dtype.names().iter().cloned().collect());
83        }
84
85        Ok(TraversalOrder::Continue)
86    }
87
88    fn visit_up(&mut self, node: &'a ExprRef) -> VortexResult<TraversalOrder> {
89        let accesses = node
90            .children()
91            .iter()
92            .filter_map(|c| self.sub_expressions.get(c).cloned())
93            .collect_vec();
94
95        let node_accesses = self.sub_expressions.entry(node).or_default();
96        accesses
97            .into_iter()
98            .for_each(|fields| node_accesses.extend(fields.iter().cloned()));
99
100        Ok(TraversalOrder::Continue)
101    }
102}