Skip to main content

vortex_array/arrays/struct_/compute/
zip.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::BitAnd;
5use std::ops::BitOr;
6use std::ops::Not;
7
8use vortex_error::VortexResult;
9
10use crate::ArrayRef;
11use crate::ExecutionCtx;
12use crate::IntoArray;
13use crate::array::ArrayView;
14use crate::arrays::Struct;
15use crate::arrays::StructArray;
16use crate::arrays::struct_::StructArrayExt;
17use crate::builtins::ArrayBuiltins;
18use crate::scalar_fn::fns::zip::ZipKernel;
19use crate::validity::Validity;
20
21impl ZipKernel for Struct {
22    fn zip(
23        if_true: ArrayView<'_, Struct>,
24        if_false: &ArrayRef,
25        mask: &ArrayRef,
26        ctx: &mut ExecutionCtx,
27    ) -> VortexResult<Option<ArrayRef>> {
28        let Some(if_false) = if_false.as_opt::<Struct>() else {
29            return Ok(None);
30        };
31        assert_eq!(
32            if_true.names(),
33            if_false.names(),
34            "input arrays to zip must have the same field names",
35        );
36
37        let fields = if_true
38            .iter_unmasked_fields()
39            .zip(if_false.iter_unmasked_fields())
40            .map(|(t, f)| ArrayBuiltins::zip(mask, t.clone(), f.clone()))
41            .collect::<VortexResult<Vec<_>>>()?;
42
43        let v1 = if_true.validity()?;
44        let v2 = if_false.validity()?;
45        let validity = match (&v1, &v2) {
46            (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
47            (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
48            (Validity::AllInvalid, Validity::AllInvalid) => Validity::AllInvalid,
49
50            (v1, v2) => {
51                let mask_mask = mask.try_to_mask_fill_null_false(ctx)?;
52                let v1m = v1.execute_mask(if_true.len(), ctx)?;
53                let v2m = v2.execute_mask(if_false.len(), ctx)?;
54
55                let combined = (v1m.bitand(&mask_mask)).bitor(&v2m.bitand(&mask_mask.not()));
56                Validity::from_mask(
57                    combined,
58                    if_true.dtype().nullability() | if_false.dtype().nullability(),
59                )
60            }
61        };
62
63        Ok(Some(
64            StructArray::try_new(if_true.names().clone(), fields, if_true.len(), validity)?
65                .into_array(),
66        ))
67    }
68}
69
#[cfg(test)]
mod tests {
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::FieldNames;
    use crate::validity::Validity;

    /// Builds a four-row struct array holding a single `i32` field.
    fn single_field_struct(
        name: &'static str,
        values: [i32; 4],
        validity: Validity,
    ) -> StructArray {
        StructArray::new(
            FieldNames::from_iter([name]),
            vec![PrimitiveArray::from_iter(values).into_array()],
            4,
            validity,
        )
    }

    #[test]
    fn test_validity_zip_both_validity_array() {
        // Both sides carry an explicit per-row validity.
        let when_true = single_field_struct(
            "field",
            [1, 2, 3, 4],
            Validity::from_iter([true, false, true, false]),
        )
        .into_array();

        let when_false = single_field_struct(
            "field",
            [10, 20, 30, 40],
            Validity::from_iter([false, true, false, true]),
        )
        .into_array();

        // Only row 2 is taken from `when_true`; rows 0, 1 and 3 come from
        // `when_false`, whose row 0 is null.
        let selector = Mask::from_iter([false, false, true, false]).into_array();

        let zipped = selector.zip(when_true, when_false).unwrap();

        insta::assert_snapshot!(zipped.display_table(), @r"
        ┌───────┐
        │ field │
        ├───────┤
        │ null  │
        ├───────┤
        │ 20i32 │
        ├───────┤
        │ 3i32  │
        ├───────┤
        │ 40i32 │
        └───────┘
        ");
    }

    #[test]
    fn test_validity_zip_allvalid_and_array() {
        // The true side is entirely valid; the false side has per-row validity.
        let when_true =
            single_field_struct("a", [1, 2, 3, 4], Validity::AllValid).into_array();

        let when_false = single_field_struct(
            "a",
            [10, 20, 30, 40],
            Validity::from_iter([false, false, true, true]),
        )
        .into_array();

        // Row 0 is taken from `when_true`; rows 1..=3 come from `when_false`,
        // whose row 1 is null.
        let selector = Mask::from_iter([true, false, false, false]).into_array();

        let zipped = selector.zip(when_true, when_false).unwrap();

        insta::assert_snapshot!(zipped.display_table(), @r"
        ┌───────┐
        │   a   │
        ├───────┤
        │ 1i32  │
        ├───────┤
        │ null  │
        ├───────┤
        │ 30i32 │
        ├───────┤
        │ 40i32 │
        └───────┘
        ");
    }
}
155}