vortex_array/arrays/struct_/compute/zip.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::{BitAnd, BitOr, Not};
5
6use vortex_error::VortexResult;
7use vortex_mask::Mask;
8
9use crate::arrays::{StructArray, StructVTable};
10use crate::compute::{ZipKernel, ZipKernelAdapter, zip};
11use crate::validity::Validity;
12use crate::vtable::ValidityHelper;
13use crate::{Array, ArrayRef, register_kernel};
14
15impl ZipKernel for StructVTable {
16    fn zip(
17        &self,
18        if_true: &StructArray,
19        if_false: &dyn Array,
20        mask: &Mask,
21    ) -> VortexResult<Option<ArrayRef>> {
22        let Some(if_false) = if_false.as_opt::<StructVTable>() else {
23            return Ok(None);
24        };
25        assert_eq!(
26            if_true.len(),
27            if_false.len(),
28            "ComputeFn::invoke checks that arrays have the same size"
29        );
30        assert_eq!(
31            if_true.names(),
32            if_false.names(),
33            "Zip checks that arrays type"
34        );
35
36        let fields = if_true
37            .fields()
38            .iter()
39            .zip(if_false.fields().iter())
40            .map(|(t, f)| zip(t, f, mask))
41            .collect::<VortexResult<Vec<_>>>()?;
42
43        let validity = match (if_true.validity(), if_false.validity()) {
44            (&Validity::NonNullable, &Validity::NonNullable) => Validity::NonNullable,
45            (&Validity::AllValid, &Validity::AllValid) => Validity::AllValid,
46            (&Validity::AllInvalid, &Validity::AllInvalid) => Validity::AllInvalid,
47
48            (v1, v2) => {
49                let v1m = v1.to_mask(if_true.len());
50                let v2m = v2.to_mask(if_false.len());
51
52                let combined = (v1m.bitand(mask)).bitor(&v2m.bitand(&mask.not()));
53                Validity::from_mask(
54                    combined,
55                    if_true.dtype.nullability() | if_false.dtype.nullability(),
56                )
57            }
58        };
59
60        Ok(Some(
61            StructArray::try_new(if_true.names().clone(), fields, if_true.len(), validity)?
62                .to_array(),
63        ))
64    }
65}
66
67register_kernel!(ZipKernelAdapter(StructVTable).lift());
68
#[cfg(test)]
mod tests {
    use vortex_dtype::FieldNames;
    use vortex_mask::Mask;

    use crate::arrays::{PrimitiveArray, StructArray};
    use crate::compute::zip;
    use crate::validity::Validity;
    use crate::{ArrayRef, IntoArray};

    /// Builds a four-row struct array with a single `i32` column named `column_name`.
    fn one_column_struct(
        column_name: &'static str,
        values: [i32; 4],
        validity: Validity,
    ) -> ArrayRef {
        let column = PrimitiveArray::from_iter(values).into_array();
        StructArray::new(
            FieldNames::from_iter([column_name]),
            vec![column],
            4,
            validity,
        )
        .into_array()
    }

    #[test]
    fn test_validity_zip_both_validity_array() {
        // Each side carries an explicit, per-row validity bitmap.
        let truthy = one_column_struct(
            "field",
            [1, 2, 3, 4],
            Validity::from_iter([true, false, true, false]),
        );
        let falsy = one_column_struct(
            "field",
            [10, 20, 30, 40],
            Validity::from_iter([false, true, false, true]),
        );
        let selector = Mask::from_iter([false, false, true, false]);

        let zipped = zip(&truthy, &falsy, &selector).unwrap();

        insta::assert_snapshot!(zipped.display_table(), @r"
        ┌───────┐
        │ field │
        ├───────┤
        │ null  │
        ├───────┤
        │ 20i32 │
        ├───────┤
        │ 3i32  │
        ├───────┤
        │ 40i32 │
        └───────┘
        ");
    }

    #[test]
    fn test_validity_zip_allvalid_and_array() {
        // One side is entirely valid; the other has a per-row validity bitmap.
        let truthy = one_column_struct("a", [1, 2, 3, 4], Validity::AllValid);
        let falsy = one_column_struct(
            "a",
            [10, 20, 30, 40],
            Validity::from_iter([false, false, true, true]),
        );
        let selector = Mask::from_iter([true, false, false, false]);

        let zipped = zip(&truthy, &falsy, &selector).unwrap();

        insta::assert_snapshot!(zipped.display_table(), @r"
        ┌───────┐
        │   a   │
        ├───────┤
        │ 1i32  │
        ├───────┤
        │ null  │
        ├───────┤
        │ 30i32 │
        ├───────┤
        │ 40i32 │
        └───────┘
        ");
    }
}