vortex_expr/transform/
field_mask.rs

1use vortex_array::aliases::hash_set::HashSet;
2use vortex_dtype::{DType, Field, FieldPath};
3use vortex_error::{VortexResult, vortex_bail};
4
5use crate::traversal::{FoldUp, Folder, Node};
6use crate::{ExprRef, GetItem, Identity, Select};
7
8/// Returns the field mask for the given expression.
9///
10/// This defines a mask over the scope of the fields that are accessed by the expression.
11pub fn field_mask(expr: &ExprRef, scope_dtype: &DType) -> VortexResult<HashSet<FieldPath>> {
12    // I know it's unused now, but we will for sure need the scope DType for future expressions.
13    let DType::Struct(_scope_dtype, _) = scope_dtype else {
14        vortex_bail!("Mismatched dtype {} for struct layout", scope_dtype);
15    };
16
17    Ok(match expr.accept_with_context(&mut FieldMaskFolder, ())? {
18        FoldUp::Abort(out) => out,
19        FoldUp::Continue(out) => out,
20    })
21}
22
23struct FieldMaskFolder;
24
25impl<'a> Folder<'a> for FieldMaskFolder {
26    type NodeTy = ExprRef;
27    type Out = HashSet<FieldPath>;
28    type Context = ();
29
30    fn visit_up(
31        &mut self,
32        node: &'a Self::NodeTy,
33        _context: Self::Context,
34        children: Vec<Self::Out>,
35    ) -> VortexResult<FoldUp<Self::Out>> {
36        // The identity returns a field path covering the root.
37        if node.as_any().is::<Identity>() {
38            return Ok(FoldUp::Continue([FieldPath::root()].into()));
39        }
40
41        // GetItem pushes an element to each field path
42        if let Some(getitem) = node.as_any().downcast_ref::<GetItem>() {
43            let fields = children
44                .into_iter()
45                .flat_map(|field_mask| field_mask.into_iter())
46                .map(|field_path| field_path.push(Field::Name(getitem.field().clone())))
47                .collect();
48            return Ok(FoldUp::Continue(fields));
49        }
50
51        if node.as_any().is::<Select>() {
52            vortex_bail!("Expression must be simplified")
53        }
54
55        // Otherwise, return the field paths from the children
56        Ok(FoldUp::Continue(children.into_iter().flatten().collect()))
57    }
58}
59
#[cfg(test)]
mod test {
    use std::iter;
    use std::sync::Arc;

    use itertools::Itertools;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, FieldPath, PType, StructDType};

    use crate::transform::field_mask::field_mask;
    use crate::{ExprRef, get_item, ident};

    /// Scope shared by all tests: a non-nullable struct with three `i32` fields `A`, `B`, `C`.
    fn dtype() -> DType {
        let i32_column = DType::Primitive(PType::I32, NonNullable);
        let fields = StructDType::new(
            ["A".into(), "B".into(), "C".into()].into(),
            iter::repeat_n(i32_column, 3).collect(),
        );
        DType::Struct(Arc::new(fields), NonNullable)
    }

    /// Evaluates the field mask of `expr` against [`dtype`] and flattens it into a `Vec`.
    fn mask_of(expr: &ExprRef) -> Vec<FieldPath> {
        field_mask(expr, &dtype()).unwrap().into_iter().collect_vec()
    }

    #[test]
    fn field_mask_ident() {
        assert_eq!(mask_of(&ident()), vec![FieldPath::root()]);
    }

    #[test]
    fn field_mask_get_item() {
        let expr = get_item("A", ident());
        assert_eq!(mask_of(&expr), vec![FieldPath::from_name("A")]);
    }

    #[test]
    fn field_mask_get_item_nested() {
        let expr = get_item("B", get_item("A", ident()));
        assert_eq!(mask_of(&expr), vec![FieldPath::from_name("A").push("B")]);
    }
}