vortex_expr/transform/
field_mask.rs

1use vortex_dtype::{DType, Field, FieldPath};
2use vortex_error::{VortexResult, vortex_bail};
3use vortex_utils::aliases::hash_set::HashSet;
4
5use crate::traversal::{FoldUp, Folder, Node};
6use crate::{ExprRef, GetItem, Select, is_root};
7
8/// Returns the field mask for the given expression.
9///
10/// This defines a mask over the scope of the fields that are accessed by the expression.
11pub fn field_mask(expr: &ExprRef, scope_dtype: &DType) -> VortexResult<HashSet<FieldPath>> {
12    // I know it's unused now, but we will for sure need the scope DType for future expressions.
13    let DType::Struct(_scope_dtype, _) = scope_dtype else {
14        vortex_bail!("Mismatched dtype {} for struct layout", scope_dtype);
15    };
16
17    Ok(match expr.accept_with_context(&mut FieldMaskFolder, ())? {
18        FoldUp::Abort(out) => out,
19        FoldUp::Continue(out) => out,
20    })
21}
22
23struct FieldMaskFolder;
24
25impl<'a> Folder<'a> for FieldMaskFolder {
26    type NodeTy = ExprRef;
27    type Out = HashSet<FieldPath>;
28    type Context = ();
29
30    fn visit_up(
31        &mut self,
32        node: &'a Self::NodeTy,
33        _context: Self::Context,
34        children: Vec<Self::Out>,
35    ) -> VortexResult<FoldUp<Self::Out>> {
36        // The identity returns a field path covering the root.
37        if is_root(node) {
38            return Ok(FoldUp::Continue([FieldPath::root()].into()));
39        }
40
41        // GetItem pushes an element to each field path
42        if let Some(getitem) = node.as_any().downcast_ref::<GetItem>() {
43            let fields = children
44                .into_iter()
45                .flat_map(|field_mask| field_mask.into_iter())
46                .map(|field_path| field_path.push(Field::Name(getitem.field().clone())))
47                .collect();
48            return Ok(FoldUp::Continue(fields));
49        }
50
51        if node.as_any().is::<Select>() {
52            vortex_bail!("Expression must be simplified")
53        }
54
55        // Otherwise, return the field paths from the children
56        Ok(FoldUp::Continue(children.into_iter().flatten().collect()))
57    }
58}
59
#[cfg(test)]
mod test {
    use std::iter;

    use itertools::Itertools;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, FieldPath, PType, StructFields};

    use crate::transform::field_mask::field_mask;
    use crate::{ExprRef, get_item, root};

    /// Struct scope with three non-nullable `i32` fields named `A`, `B` and `C`.
    fn dtype() -> DType {
        let i32_dtype = DType::Primitive(PType::I32, NonNullable);
        DType::Struct(
            StructFields::new(
                ["A", "B", "C"].into(),
                iter::repeat_n(i32_dtype, 3).collect(),
            ),
            NonNullable,
        )
    }

    /// Evaluates `field_mask` for `expr` against the test scope and returns the paths as a `Vec`.
    fn mask_of(expr: &ExprRef) -> Vec<FieldPath> {
        field_mask(expr, &dtype())
            .unwrap()
            .into_iter()
            .collect_vec()
    }

    #[test]
    fn field_mask_ident() {
        // The root expression yields exactly the root path.
        assert_eq!(mask_of(&root()), vec![FieldPath::root()]);
    }

    #[test]
    fn field_mask_get_item() {
        // A single field access yields the path to that field.
        assert_eq!(
            mask_of(&get_item("A", root())),
            vec![FieldPath::from_name("A")]
        );
    }

    #[test]
    fn field_mask_get_item_nested() {
        // Nested accesses append field names from the innermost access outwards.
        assert_eq!(
            mask_of(&get_item("B", get_item("A", root()))),
            vec![FieldPath::from_name("A").push("B")]
        );
    }
}