vortex_expr/transform/
field_mask.rs1use vortex_dtype::{DType, Field, FieldPath};
2use vortex_error::{VortexResult, vortex_bail};
3use vortex_utils::aliases::hash_set::HashSet;
4
5use crate::traversal::{FoldUp, Folder, Node};
6use crate::{ExprRef, GetItem, Select, is_root};
7
8pub fn field_mask(expr: &ExprRef, scope_dtype: &DType) -> VortexResult<HashSet<FieldPath>> {
12 let DType::Struct(_scope_dtype, _) = scope_dtype else {
14 vortex_bail!("Mismatched dtype {} for struct layout", scope_dtype);
15 };
16
17 Ok(match expr.accept_with_context(&mut FieldMaskFolder, ())? {
18 FoldUp::Abort(out) => out,
19 FoldUp::Continue(out) => out,
20 })
21}
22
23struct FieldMaskFolder;
24
25impl<'a> Folder<'a> for FieldMaskFolder {
26 type NodeTy = ExprRef;
27 type Out = HashSet<FieldPath>;
28 type Context = ();
29
30 fn visit_up(
31 &mut self,
32 node: &'a Self::NodeTy,
33 _context: Self::Context,
34 children: Vec<Self::Out>,
35 ) -> VortexResult<FoldUp<Self::Out>> {
36 if is_root(node) {
38 return Ok(FoldUp::Continue([FieldPath::root()].into()));
39 }
40
41 if let Some(getitem) = node.as_any().downcast_ref::<GetItem>() {
43 let fields = children
44 .into_iter()
45 .flat_map(|field_mask| field_mask.into_iter())
46 .map(|field_path| field_path.push(Field::Name(getitem.field().clone())))
47 .collect();
48 return Ok(FoldUp::Continue(fields));
49 }
50
51 if node.as_any().is::<Select>() {
52 vortex_bail!("Expression must be simplified")
53 }
54
55 Ok(FoldUp::Continue(children.into_iter().flatten().collect()))
57 }
58}
59
#[cfg(test)]
mod test {
    use std::iter;

    use itertools::Itertools;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, FieldPath, PType, StructFields};

    use crate::transform::field_mask::field_mask;
    use crate::{get_item, root};

    /// Struct scope `{A: i32, B: i32, C: i32}` shared by the tests below.
    fn dtype() -> DType {
        DType::Struct(
            StructFields::new(
                ["A", "B", "C"].into(),
                iter::repeat_n(DType::Primitive(PType::I32, NonNullable), 3).collect(),
            ),
            NonNullable,
        )
    }

    /// The bare root expression references the whole scope.
    #[test]
    fn field_mask_ident() {
        let mask = field_mask(&root(), &dtype())
            .unwrap()
            .into_iter()
            .collect_vec();
        assert_eq!(mask.as_slice(), &[FieldPath::root()]);
    }

    /// A single `get_item` on the root references exactly that field.
    #[test]
    fn field_mask_get_item() {
        let mask = field_mask(&get_item("A", root()), &dtype())
            .unwrap()
            .into_iter()
            .collect_vec();
        assert_eq!(mask.as_slice(), &[FieldPath::from_name("A")]);
    }

    /// Nested `get_item`s build a multi-segment path. `field_mask` does not consult the
    /// scope's field types, so it is irrelevant here that `A` is a primitive column.
    #[test]
    fn field_mask_get_item_nested() {
        let mask = field_mask(&get_item("B", get_item("A", root())), &dtype())
            .unwrap()
            .into_iter()
            .collect_vec();
        assert_eq!(mask.as_slice(), &[FieldPath::from_name("A").push("B")]);
    }

    /// `field_mask` only accepts struct scopes; any other dtype must be rejected.
    #[test]
    fn field_mask_non_struct_scope() {
        let scope = DType::Primitive(PType::I32, NonNullable);
        assert!(field_mask(&root(), &scope).is_err());
    }
}