vortex_expr/transform/
field_mask.rs1use vortex_dtype::{DType, Field, FieldPath};
2use vortex_error::{VortexResult, vortex_bail};
3use vortex_utils::aliases::hash_set::HashSet;
4
5use crate::traversal::{FoldUp, Folder, Node};
6use crate::{ExprRef, GetItem, Select, is_root};
7
8pub fn field_mask(expr: &ExprRef, scope_dtype: &DType) -> VortexResult<HashSet<FieldPath>> {
12 let DType::Struct(_scope_dtype, _) = scope_dtype else {
14 vortex_bail!("Mismatched dtype {} for struct layout", scope_dtype);
15 };
16
17 Ok(match expr.accept_with_context(&mut FieldMaskFolder, ())? {
18 FoldUp::Abort(out) => out,
19 FoldUp::Continue(out) => out,
20 })
21}
22
23struct FieldMaskFolder;
24
25impl<'a> Folder<'a> for FieldMaskFolder {
26 type NodeTy = ExprRef;
27 type Out = HashSet<FieldPath>;
28 type Context = ();
29
30 fn visit_up(
31 &mut self,
32 node: &'a Self::NodeTy,
33 _context: Self::Context,
34 children: Vec<Self::Out>,
35 ) -> VortexResult<FoldUp<Self::Out>> {
36 if is_root(node) {
38 return Ok(FoldUp::Continue([FieldPath::root()].into()));
39 }
40
41 if let Some(getitem) = node.as_any().downcast_ref::<GetItem>() {
43 let fields = children
44 .into_iter()
45 .flat_map(|field_mask| field_mask.into_iter())
46 .map(|field_path| field_path.push(Field::Name(getitem.field().clone())))
47 .collect();
48 return Ok(FoldUp::Continue(fields));
49 }
50
51 if node.as_any().is::<Select>() {
52 vortex_bail!("Expression must be simplified")
53 }
54
55 Ok(FoldUp::Continue(children.into_iter().flatten().collect()))
57 }
58}
59
#[cfg(test)]
mod test {
    use std::iter;
    use std::sync::Arc;

    use itertools::Itertools;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, FieldPath, PType, StructFields};

    use super::field_mask;
    use crate::{ExprRef, get_item, root};

    /// Struct scope with three non-nullable `i32` fields named `A`, `B` and `C`.
    fn dtype() -> DType {
        let field_dtype = DType::Primitive(PType::I32, NonNullable);
        let fields = StructFields::new(
            ["A".into(), "B".into(), "C".into()].into(),
            iter::repeat_n(field_dtype, 3).collect(),
        );
        DType::Struct(Arc::new(fields), NonNullable)
    }

    /// Computes the field mask of `expr` against [`dtype`] and returns its paths as a `Vec`.
    fn mask_of(expr: &ExprRef) -> Vec<FieldPath> {
        field_mask(expr, &dtype())
            .unwrap()
            .into_iter()
            .collect_vec()
    }

    #[test]
    fn field_mask_ident() {
        let paths = mask_of(&root());
        assert_eq!(paths, vec![FieldPath::root()]);
    }

    #[test]
    fn field_mask_get_item() {
        let paths = mask_of(&get_item("A", root()));
        assert_eq!(paths, vec![FieldPath::from_name("A")]);
    }

    #[test]
    fn field_mask_get_item_nested() {
        let inner = get_item("A", root());
        let paths = mask_of(&get_item("B", inner));
        assert_eq!(paths, vec![FieldPath::from_name("A").push("B")]);
    }
}