Skip to main content

vortex_array/arrays/struct_/compute/
zip.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::BitAnd;
5use std::ops::BitOr;
6use std::ops::Not;
7
8use vortex_error::VortexResult;
9
10use crate::ArrayRef;
11use crate::ExecutionCtx;
12use crate::IntoArray;
13use crate::arrays::StructArray;
14use crate::arrays::StructVTable;
15use crate::builtins::ArrayBuiltins;
16use crate::scalar_fn::fns::zip::ZipKernel;
17use crate::validity::Validity;
18use crate::vtable::ValidityHelper;
19
20impl ZipKernel for StructVTable {
21    fn zip(
22        if_true: &StructArray,
23        if_false: &ArrayRef,
24        mask: &ArrayRef,
25        ctx: &mut ExecutionCtx,
26    ) -> VortexResult<Option<ArrayRef>> {
27        let Some(if_false) = if_false.as_opt::<StructVTable>() else {
28            return Ok(None);
29        };
30        assert_eq!(
31            if_true.names(),
32            if_false.names(),
33            "input arrays to zip must have the same field names",
34        );
35
36        let fields = if_true
37            .unmasked_fields()
38            .iter()
39            .zip(if_false.unmasked_fields().iter())
40            .map(|(t, f)| ArrayBuiltins::zip(mask, t.clone(), f.clone()))
41            .collect::<VortexResult<Vec<_>>>()?;
42
43        let validity = match (if_true.validity(), if_false.validity()) {
44            (&Validity::NonNullable, &Validity::NonNullable) => Validity::NonNullable,
45            (&Validity::AllValid, &Validity::AllValid) => Validity::AllValid,
46            (&Validity::AllInvalid, &Validity::AllInvalid) => Validity::AllInvalid,
47
48            (v1, v2) => {
49                let mask_mask = mask.try_to_mask_fill_null_false(ctx)?;
50                let v1m = v1.to_mask(if_true.len());
51                let v2m = v2.to_mask(if_false.len());
52
53                let combined = (v1m.bitand(&mask_mask)).bitor(&v2m.bitand(&mask_mask.not()));
54                Validity::from_mask(
55                    combined,
56                    if_true.dtype.nullability() | if_false.dtype.nullability(),
57                )
58            }
59        };
60
61        Ok(Some(
62            StructArray::try_new(if_true.names().clone(), fields, if_true.len(), validity)?
63                .into_array(),
64        ))
65    }
66}
67
#[cfg(test)]
mod tests {
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::FieldNames;
    use crate::validity::Validity;

    /// Builds a struct array with one primitive `i32` column called `name`.
    fn single_field_struct(name: &'static str, values: [i32; 4], validity: Validity) -> StructArray {
        StructArray::new(
            FieldNames::from_iter([name]),
            vec![PrimitiveArray::from_iter(values).into_array()],
            values.len(),
            validity,
        )
    }

    #[test]
    fn test_validity_zip_both_validity_array() {
        // Both sides carry an explicit per-row validity array.
        let truthy = single_field_struct(
            "field",
            [1, 2, 3, 4],
            Validity::from_iter([true, false, true, false]),
        );
        let falsy = single_field_struct(
            "field",
            [10, 20, 30, 40],
            Validity::from_iter([false, true, false, true]),
        );
        let selector = Mask::from_iter([false, false, true, false]).into_array();

        let result = selector
            .zip(truthy.into_array(), falsy.into_array())
            .unwrap();

        insta::assert_snapshot!(result.display_table(), @r"
        ┌───────┐
        │ field │
        ├───────┤
        │ null  │
        ├───────┤
        │ 20i32 │
        ├───────┤
        │ 3i32  │
        ├───────┤
        │ 40i32 │
        └───────┘
        ");
    }

    #[test]
    fn test_validity_zip_allvalid_and_array() {
        // Left side is all-valid, right side has an explicit validity array.
        let truthy = single_field_struct("a", [1, 2, 3, 4], Validity::AllValid);
        let falsy = single_field_struct(
            "a",
            [10, 20, 30, 40],
            Validity::from_iter([false, false, true, true]),
        );
        let selector = Mask::from_iter([true, false, false, false]).into_array();

        let result = selector
            .zip(truthy.into_array(), falsy.into_array())
            .unwrap();

        insta::assert_snapshot!(result.display_table(), @r"
        ┌───────┐
        │   a   │
        ├───────┤
        │ 1i32  │
        ├───────┤
        │ null  │
        ├───────┤
        │ 30i32 │
        ├───────┤
        │ 40i32 │
        └───────┘
        ");
    }
}