vortex_array/arrays/struct_/compute/
zip.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::BitAnd;
5use std::ops::BitOr;
6use std::ops::Not;
7
8use vortex_error::VortexResult;
9use vortex_mask::Mask;
10
11use crate::Array;
12use crate::ArrayRef;
13use crate::arrays::StructArray;
14use crate::arrays::StructVTable;
15use crate::compute::ZipKernel;
16use crate::compute::ZipKernelAdapter;
17use crate::compute::zip;
18use crate::register_kernel;
19use crate::validity::Validity;
20use crate::vtable::ValidityHelper;
21
22impl ZipKernel for StructVTable {
23    fn zip(
24        &self,
25        if_true: &StructArray,
26        if_false: &dyn Array,
27        mask: &Mask,
28    ) -> VortexResult<Option<ArrayRef>> {
29        let Some(if_false) = if_false.as_opt::<StructVTable>() else {
30            return Ok(None);
31        };
32        assert_eq!(
33            if_true.names(),
34            if_false.names(),
35            "input arrays to zip must have the same field names",
36        );
37
38        let fields = if_true
39            .fields()
40            .iter()
41            .zip(if_false.fields().iter())
42            .map(|(t, f)| zip(t, f, mask))
43            .collect::<VortexResult<Vec<_>>>()?;
44
45        let validity = match (if_true.validity(), if_false.validity()) {
46            (&Validity::NonNullable, &Validity::NonNullable) => Validity::NonNullable,
47            (&Validity::AllValid, &Validity::AllValid) => Validity::AllValid,
48            (&Validity::AllInvalid, &Validity::AllInvalid) => Validity::AllInvalid,
49
50            (v1, v2) => {
51                let v1m = v1.to_mask(if_true.len());
52                let v2m = v2.to_mask(if_false.len());
53
54                let combined = (v1m.bitand(mask)).bitor(&v2m.bitand(&mask.not()));
55                Validity::from_mask(
56                    combined,
57                    if_true.dtype.nullability() | if_false.dtype.nullability(),
58                )
59            }
60        };
61
62        Ok(Some(
63            StructArray::try_new(if_true.names().clone(), fields, if_true.len(), validity)?
64                .to_array(),
65        ))
66    }
67}
68
// Register the struct `ZipKernel` implementation above, wrapped in `ZipKernelAdapter`.
// NOTE(review): presumably this is what lets the generic `zip` compute function dispatch
// to `StructVTable::zip` for struct inputs — confirm against `register_kernel!`.
register_kernel!(ZipKernelAdapter(StructVTable).lift());
70
#[cfg(test)]
mod tests {
    use vortex_dtype::FieldNames;
    use vortex_mask::Mask;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::compute::zip;
    use crate::validity::Validity;

    /// Builds a four-row struct array holding a single `i32` field named `name`,
    /// with the given struct-level `validity`.
    fn single_field_struct(name: &'static str, values: [i32; 4], validity: Validity) -> ArrayRef {
        StructArray::new(
            FieldNames::from_iter([name]),
            vec![PrimitiveArray::from_iter(values).into_array()],
            values.len(),
            validity,
        )
        .into_array()
    }

    #[test]
    fn test_validity_zip_both_validity_array() {
        // Both inputs carry an explicit per-row validity array; the mask selects
        // `if_true` only for row 2.
        let if_true = single_field_struct(
            "field",
            [1, 2, 3, 4],
            Validity::from_iter([true, false, true, false]),
        );
        let if_false = single_field_struct(
            "field",
            [10, 20, 30, 40],
            Validity::from_iter([false, true, false, true]),
        );
        let mask = Mask::from_iter([false, false, true, false]);

        let result = zip(&if_true, &if_false, &mask).unwrap();

        insta::assert_snapshot!(result.display_table(), @r"
        ┌───────┐
        │ field │
        ├───────┤
        │ null  │
        ├───────┤
        │ 20i32 │
        ├───────┤
        │ 3i32  │
        ├───────┤
        │ 40i32 │
        └───────┘
        ");
    }

    #[test]
    fn test_validity_zip_allvalid_and_array() {
        // `if_true` is fully valid while `if_false` has a validity array; the mask
        // selects `if_true` only for row 0.
        let if_true = single_field_struct("a", [1, 2, 3, 4], Validity::AllValid);
        let if_false = single_field_struct(
            "a",
            [10, 20, 30, 40],
            Validity::from_iter([false, false, true, true]),
        );
        let mask = Mask::from_iter([true, false, false, false]);

        let result = zip(&if_true, &if_false, &mask).unwrap();

        insta::assert_snapshot!(result.display_table(), @r"
        ┌───────┐
        │   a   │
        ├───────┤
        │ 1i32  │
        ├───────┤
        │ null  │
        ├───────┤
        │ 30i32 │
        ├───────┤
        │ 40i32 │
        └───────┘
        ");
    }
}