Skip to main content

vortex_array/arrays/struct_/compute/
zip.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use std::ops::BitAnd;
use std::ops::BitOr;
use std::ops::Not;

use vortex_error::VortexResult;
use vortex_error::vortex_bail;

use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::arrays::StructArray;
use crate::arrays::StructVTable;
use crate::builtins::ArrayBuiltins;
use crate::scalar_fn::fns::zip::ZipKernel;
use crate::validity::Validity;
use crate::vtable::ValidityHelper;
18
19impl ZipKernel for StructVTable {
20    fn zip(
21        if_true: &StructArray,
22        if_false: &ArrayRef,
23        mask: &ArrayRef,
24        ctx: &mut ExecutionCtx,
25    ) -> VortexResult<Option<ArrayRef>> {
26        let Some(if_false) = if_false.as_opt::<StructVTable>() else {
27            return Ok(None);
28        };
29        assert_eq!(
30            if_true.names(),
31            if_false.names(),
32            "input arrays to zip must have the same field names",
33        );
34
35        let fields = if_true
36            .unmasked_fields()
37            .iter()
38            .zip(if_false.unmasked_fields().iter())
39            .map(|(t, f)| ArrayBuiltins::zip(mask, t.clone(), f.clone()))
40            .collect::<VortexResult<Vec<_>>>()?;
41
42        let validity = match (if_true.validity(), if_false.validity()) {
43            (&Validity::NonNullable, &Validity::NonNullable) => Validity::NonNullable,
44            (&Validity::AllValid, &Validity::AllValid) => Validity::AllValid,
45            (&Validity::AllInvalid, &Validity::AllInvalid) => Validity::AllInvalid,
46
47            (v1, v2) => {
48                let mask_mask = mask.try_to_mask_fill_null_false(ctx)?;
49                let v1m = v1.to_mask(if_true.len());
50                let v2m = v2.to_mask(if_false.len());
51
52                let combined = (v1m.bitand(&mask_mask)).bitor(&v2m.bitand(&mask_mask.not()));
53                Validity::from_mask(
54                    combined,
55                    if_true.dtype.nullability() | if_false.dtype.nullability(),
56                )
57            }
58        };
59
60        Ok(Some(
61            StructArray::try_new(if_true.names().clone(), fields, if_true.len(), validity)?
62                .to_array(),
63        ))
64    }
65}
66
#[cfg(test)]
mod tests {
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::FieldNames;
    use crate::validity::Validity;

    #[test]
    fn test_validity_zip_both_validity_array() {
        // Four-row struct with a single `i32` column named "field".
        let single_field = |values: [i32; 4], validity: Validity| {
            StructArray::new(
                FieldNames::from_iter(["field"]),
                vec![PrimitiveArray::from_iter(values).into_array()],
                values.len(),
                validity,
            )
            .into_array()
        };

        // Both sides carry an explicit per-row validity array, so the struct validity has to
        // be merged row by row.
        let when_true = single_field([1, 2, 3, 4], Validity::from_iter([true, false, true, false]));
        let when_false =
            single_field([10, 20, 30, 40], Validity::from_iter([false, true, false, true]));

        // Only row 2 is taken from `when_true`; every other row comes from `when_false`.
        let selector = Mask::from_iter([false, false, true, false]).into_array();
        let zipped = selector.zip(when_true, when_false).unwrap();

        insta::assert_snapshot!(zipped.display_table(), @r"
        ┌───────┐
        │ field │
        ├───────┤
        │ null  │
        ├───────┤
        │ 20i32 │
        ├───────┤
        │ 3i32  │
        ├───────┤
        │ 40i32 │
        └───────┘
        ");
    }

    #[test]
    fn test_validity_zip_allvalid_and_array() {
        // Four-row struct with a single `i32` column named "a".
        let single_field = |values: [i32; 4], validity: Validity| {
            StructArray::new(
                FieldNames::from_iter(["a"]),
                vec![PrimitiveArray::from_iter(values).into_array()],
                values.len(),
                validity,
            )
            .into_array()
        };

        // One side is entirely valid, the other has an explicit validity array.
        let when_true = single_field([1, 2, 3, 4], Validity::AllValid);
        let when_false =
            single_field([10, 20, 30, 40], Validity::from_iter([false, false, true, true]));

        // Only row 0 is taken from `when_true`; every other row comes from `when_false`.
        let selector = Mask::from_iter([true, false, false, false]).into_array();
        let zipped = selector.zip(when_true, when_false).unwrap();

        insta::assert_snapshot!(zipped.display_table(), @r"
        ┌───────┐
        │   a   │
        ├───────┤
        │ 1i32  │
        ├───────┤
        │ null  │
        ├───────┤
        │ 30i32 │
        ├───────┤
        │ 40i32 │
        └───────┘
        ");
    }
}