// vortex_expr/transform/field_mask.rs

use vortex_array::aliases::hash_set::HashSet;
use vortex_dtype::{DType, Field, FieldPath};
use vortex_error::{vortex_bail, VortexResult};

use crate::traversal::{FoldUp, Folder, Node};
use crate::{ExprRef, GetItem, Identity, Select};

/// Returns the field mask for the given expression.
///
/// This defines a mask over the scope of the fields that are accessed by the expression.
pub fn field_mask(expr: &ExprRef, scope_dtype: &DType) -> VortexResult<HashSet<FieldPath>> {
    // I know it's unused now, but we will for sure need the scope DType for future expressions.
    let DType::Struct(_scope_dtype, _) = scope_dtype else {
        vortex_bail!("Mismatched dtype {} for struct layout", scope_dtype);
    };

    Ok(match expr.accept_with_context(&mut FieldMaskFolder, ())? {
        FoldUp::Abort(out) => out,
        FoldUp::Continue(out) => out,
    })
}

/// Expression folder whose output is the set of [`FieldPath`]s an expression accesses.
///
/// It carries no state of its own; the paths are built up through the fold's return values.
struct FieldMaskFolder;

impl<'a> Folder<'a> for FieldMaskFolder {
    type NodeTy = ExprRef;
    type Out = HashSet<FieldPath>;
    type Context = ();

    fn visit_up(
        &mut self,
        node: &'a Self::NodeTy,
        _context: Self::Context,
        children: Vec<Self::Out>,
    ) -> VortexResult<FoldUp<Self::Out>> {
        // The identity returns a field path covering the root.
        if node.as_any().is::<Identity>() {
            return Ok(FoldUp::Continue([FieldPath::root()].into()));
        }

        // GetItem pushes an element to each field path
        if let Some(getitem) = node.as_any().downcast_ref::<GetItem>() {
            let fields = children
                .into_iter()
                .flat_map(|field_mask| field_mask.into_iter())
                .map(|field_path| field_path.push(Field::Name(getitem.field().clone())))
                .collect();
            return Ok(FoldUp::Continue(fields));
        }

        if node.as_any().is::<Select>() {
            vortex_bail!("Expression must be simplified")
        }

        // Otherwise, return the field paths from the children
        Ok(FoldUp::Continue(children.into_iter().flatten().collect()))
    }
}

#[cfg(test)]
mod test {
    use std::iter;

    use itertools::Itertools;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, FieldPath, PType, StructDType};

    use crate::transform::field_mask::field_mask;
    use crate::{get_item, ident, ExprRef};

    /// Struct scope with three non-nullable `i32` fields named `A`, `B` and `C`.
    fn dtype() -> DType {
        let element = DType::Primitive(PType::I32, NonNullable);
        DType::Struct(
            StructDType::new(
                ["A".into(), "B".into(), "C".into()].into(),
                iter::repeat(element).take(3).collect(),
            ),
            NonNullable,
        )
    }

    /// Computes the field mask of `expr` against the test scope and lists its paths.
    fn mask_paths(expr: &ExprRef) -> Vec<FieldPath> {
        field_mask(expr, &dtype())
            .unwrap()
            .into_iter()
            .collect_vec()
    }

    // The identity expression touches the whole scope.
    #[test]
    fn field_mask_ident() {
        assert_eq!(mask_paths(&ident()), vec![FieldPath::root()]);
    }

    // A single field access yields a one-element path.
    #[test]
    fn field_mask_get_item() {
        assert_eq!(
            mask_paths(&get_item("A", ident())),
            vec![FieldPath::from_name("A")]
        );
    }

    // Nested accesses are recorded outermost-scope first: `A`, then `B`.
    #[test]
    fn field_mask_get_item_nested() {
        assert_eq!(
            mask_paths(&get_item("B", get_item("A", ident()))),
            vec![FieldPath::from_name("A").push("B")]
        );
    }
}