Skip to main content

vortex_array/arrays/struct_/compute/
zip.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::BitAnd;
5use std::ops::BitOr;
6use std::ops::Not;
7
8use vortex_error::VortexResult;
9use vortex_mask::Mask;
10
11use crate::Array;
12use crate::ArrayRef;
13use crate::ExecutionCtx;
14use crate::IntoArray;
15use crate::arrays::StructArray;
16use crate::arrays::StructVTable;
17use crate::builtins::ArrayBuiltins;
18use crate::scalar_fn::fns::zip::ZipKernel;
19use crate::validity::Validity;
20use crate::vtable::ValidityHelper;
21
22impl ZipKernel for StructVTable {
23    fn zip(
24        if_true: &StructArray,
25        if_false: &dyn Array,
26        mask: &Mask,
27        _ctx: &mut ExecutionCtx,
28    ) -> VortexResult<Option<ArrayRef>> {
29        let Some(if_false) = if_false.as_opt::<StructVTable>() else {
30            return Ok(None);
31        };
32        assert_eq!(
33            if_true.names(),
34            if_false.names(),
35            "input arrays to zip must have the same field names",
36        );
37
38        let fields = if_true
39            .unmasked_fields()
40            .iter()
41            .zip(if_false.unmasked_fields().iter())
42            .map(|(t, f)| ArrayBuiltins::zip(t, f.clone(), mask.clone().into_array()))
43            .collect::<VortexResult<Vec<_>>>()?;
44
45        let validity = match (if_true.validity(), if_false.validity()) {
46            (&Validity::NonNullable, &Validity::NonNullable) => Validity::NonNullable,
47            (&Validity::AllValid, &Validity::AllValid) => Validity::AllValid,
48            (&Validity::AllInvalid, &Validity::AllInvalid) => Validity::AllInvalid,
49
50            (v1, v2) => {
51                let v1m = v1.to_mask(if_true.len());
52                let v2m = v2.to_mask(if_false.len());
53
54                let combined = (v1m.bitand(mask)).bitor(&v2m.bitand(&mask.not()));
55                Validity::from_mask(
56                    combined,
57                    if_true.dtype.nullability() | if_false.dtype.nullability(),
58                )
59            }
60        };
61
62        Ok(Some(
63            StructArray::try_new(if_true.names().clone(), fields, if_true.len(), validity)?
64                .to_array(),
65        ))
66    }
67}
68
#[cfg(test)]
mod tests {
    use vortex_mask::Mask;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    #[expect(deprecated)]
    use crate::compute::zip;
    use crate::dtype::FieldNames;
    use crate::validity::Validity;

    /// Builds a four-row struct array with one `i32` column named `name`.
    ///
    /// The column holds `values`, and the struct itself carries `validity`.
    fn one_column_struct(name: &'static str, values: [i32; 4], validity: Validity) -> ArrayRef {
        let column = PrimitiveArray::from_iter(values).into_array();
        StructArray::new(
            FieldNames::from_iter([name]),
            vec![column],
            values.len(),
            validity,
        )
        .into_array()
    }

    #[test]
    fn test_validity_zip_both_validity_array() {
        // Both inputs carry an explicit per-row validity array.
        let if_true = one_column_struct(
            "field",
            [1, 2, 3, 4],
            Validity::from_iter([true, false, true, false]),
        );
        let if_false = one_column_struct(
            "field",
            [10, 20, 30, 40],
            Validity::from_iter([false, true, false, true]),
        );
        let mask = Mask::from_iter([false, false, true, false]);

        #[expect(deprecated)]
        let result = zip(&if_true, &if_false, &mask).unwrap();

        insta::assert_snapshot!(result.display_table(), @r"
        ┌───────┐
        │ field │
        ├───────┤
        │ null  │
        ├───────┤
        │ 20i32 │
        ├───────┤
        │ 3i32  │
        ├───────┤
        │ 40i32 │
        └───────┘
        ");
    }

    #[test]
    fn test_validity_zip_allvalid_and_array() {
        // One input is uniformly valid; the other has a per-row validity array.
        let if_true = one_column_struct("a", [1, 2, 3, 4], Validity::AllValid);
        let if_false = one_column_struct(
            "a",
            [10, 20, 30, 40],
            Validity::from_iter([false, false, true, true]),
        );
        let mask = Mask::from_iter([true, false, false, false]);

        #[expect(deprecated)]
        let result = zip(&if_true, &if_false, &mask).unwrap();

        insta::assert_snapshot!(result.display_table(), @r"
        ┌───────┐
        │   a   │
        ├───────┤
        │ 1i32  │
        ├───────┤
        │ null  │
        ├───────┤
        │ 30i32 │
        ├───────┤
        │ 40i32 │
        └───────┘
        ");
    }
}