vortex_expr/transform/
immediate_access.rs

1use itertools::Itertools;
2use vortex_dtype::{FieldName, StructFields};
3use vortex_error::{VortexResult, vortex_err};
4use vortex_utils::aliases::hash_map::HashMap;
5use vortex_utils::aliases::hash_set::HashSet;
6
7use crate::traversal::{Node, NodeVisitor, TraversalOrder};
8use crate::{ExprRef, GetItem, Select, is_root};
9
10pub type FieldAccesses<'a> = HashMap<&'a ExprRef, HashSet<FieldName>>;
11
12/// For all subexpressions in an expression, find the fields that are accessed directly from the
13/// scope, but not any fields in those fields
14/// e.g. scope = {a: {b: .., c: ..}, d: ..}, expr = ident().a.b + ident().d accesses {a,d} (not b).
15///
16/// Note: This is a very naive, but simple analysis to find the fields that are accessed directly on an
17/// identity node. This is combined to provide an over-approximation of the fields that are accessed
18/// by an expression.
19pub fn immediate_scope_accesses<'a>(
20    expr: &'a ExprRef,
21    scope_dtype: &'a StructFields,
22) -> VortexResult<FieldAccesses<'a>> {
23    ImmediateScopeAccessesAnalysis::<'a>::analyze(expr, scope_dtype)
24}
25
26/// This returns the immediate scope_access (as explained `immediate_scope_accesses`) for `expr`.
27pub fn immediate_scope_access<'a>(
28    expr: &'a ExprRef,
29    scope_dtype: &'a StructFields,
30) -> VortexResult<HashSet<FieldName>> {
31    ImmediateScopeAccessesAnalysis::<'a>::analyze(expr, scope_dtype)?
32        .get(expr)
33        .ok_or_else(|| {
34            vortex_err!("Expression missing from scope accesses, this is a internal bug")
35        })
36        .cloned()
37}
38
39struct ImmediateScopeAccessesAnalysis<'a> {
40    sub_expressions: FieldAccesses<'a>,
41    scope_dtype: &'a StructFields,
42}
43
44impl<'a> ImmediateScopeAccessesAnalysis<'a> {
45    fn new(scope_dtype: &'a StructFields) -> Self {
46        Self {
47            sub_expressions: HashMap::new(),
48            scope_dtype,
49        }
50    }
51
52    fn analyze(
53        expr: &'a ExprRef,
54        scope_dtype: &'a StructFields,
55    ) -> VortexResult<FieldAccesses<'a>> {
56        let mut analysis = Self::new(scope_dtype);
57        expr.accept(&mut analysis)?;
58        Ok(analysis.sub_expressions)
59    }
60}
61
62impl<'a> NodeVisitor<'a> for ImmediateScopeAccessesAnalysis<'a> {
63    type NodeTy = ExprRef;
64
65    fn visit_down(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
66        assert!(
67            !node.as_any().is::<Select>(),
68            "cannot analyze select, simplify the expression"
69        );
70        if let Some(get_item) = node.as_any().downcast_ref::<GetItem>() {
71            if is_root(get_item.child()) {
72                self.sub_expressions
73                    .insert(node, HashSet::from_iter(vec![get_item.field().clone()]));
74
75                return Ok(TraversalOrder::Skip);
76            }
77        } else if is_root(node) {
78            let st_dtype = &self.scope_dtype;
79            self.sub_expressions
80                .insert(node, st_dtype.names().iter().cloned().collect());
81        }
82
83        Ok(TraversalOrder::Continue)
84    }
85
86    fn visit_up(&mut self, node: &'a ExprRef) -> VortexResult<TraversalOrder> {
87        let accesses = node
88            .children()
89            .iter()
90            .filter_map(|c| self.sub_expressions.get(c).cloned())
91            .collect_vec();
92
93        let node_accesses = self.sub_expressions.entry(node).or_default();
94        accesses
95            .into_iter()
96            .for_each(|fields| node_accesses.extend(fields.iter().cloned()));
97
98        Ok(TraversalOrder::Continue)
99    }
100}