vortex_array/arrays/struct_/compute/
zip.rs1use std::ops::{BitAnd, BitOr, Not};
5
6use vortex_error::VortexResult;
7use vortex_mask::Mask;
8
9use crate::arrays::{StructArray, StructVTable};
10use crate::compute::{ZipKernel, ZipKernelAdapter, zip};
11use crate::validity::Validity;
12use crate::vtable::ValidityHelper;
13use crate::{Array, ArrayRef, register_kernel};
14
15impl ZipKernel for StructVTable {
16 fn zip(
17 &self,
18 if_true: &StructArray,
19 if_false: &dyn Array,
20 mask: &Mask,
21 ) -> VortexResult<Option<ArrayRef>> {
22 let Some(if_false) = if_false.as_opt::<StructVTable>() else {
23 return Ok(None);
24 };
25 assert_eq!(
26 if_true.len(),
27 if_false.len(),
28 "ComputeFn::invoke checks that arrays have the same size"
29 );
30 assert_eq!(
31 if_true.names(),
32 if_false.names(),
33 "Zip checks that arrays type"
34 );
35
36 let fields = if_true
37 .fields()
38 .iter()
39 .zip(if_false.fields().iter())
40 .map(|(t, f)| zip(t, f, mask))
41 .collect::<VortexResult<Vec<_>>>()?;
42
43 let validity = match (if_true.validity(), if_false.validity()) {
44 (&Validity::NonNullable, &Validity::NonNullable) => Validity::NonNullable,
45 (&Validity::AllValid, &Validity::AllValid) => Validity::AllValid,
46 (&Validity::AllInvalid, &Validity::AllInvalid) => Validity::AllInvalid,
47
48 (v1, v2) => {
49 let v1m = v1.to_mask(if_true.len());
50 let v2m = v2.to_mask(if_false.len());
51
52 let combined = (v1m.bitand(mask)).bitor(&v2m.bitand(&mask.not()));
53 Validity::from_mask(
54 combined,
55 if_true.dtype.nullability() | if_false.dtype.nullability(),
56 )
57 }
58 };
59
60 Ok(Some(
61 StructArray::try_new(if_true.names().clone(), fields, if_true.len(), validity)?
62 .to_array(),
63 ))
64 }
65}
66
67register_kernel!(ZipKernelAdapter(StructVTable).lift());
68
#[cfg(test)]
mod tests {
    use vortex_dtype::FieldNames;
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::arrays::{PrimitiveArray, StructArray};
    use crate::compute::zip;
    use crate::validity::Validity;

    /// Builds a struct array with a single `i32` child named `field` holding `values`,
    /// using `validity` as the struct-level validity.
    fn one_field_struct(field: &'static str, values: [i32; 4], validity: Validity) -> StructArray {
        StructArray::new(
            FieldNames::from_iter([field]),
            vec![PrimitiveArray::from_iter(values).into_array()],
            values.len(),
            validity,
        )
    }

    /// Both sides carry an explicit validity array; every output row must take its
    /// validity from whichever side the mask selected.
    #[test]
    fn test_validity_zip_both_validity_array() {
        let selector = Mask::from_iter([false, false, true, false]);
        let when_true = one_field_struct(
            "field",
            [1, 2, 3, 4],
            Validity::from_iter([true, false, true, false]),
        )
        .into_array();
        let when_false = one_field_struct(
            "field",
            [10, 20, 30, 40],
            Validity::from_iter([false, true, false, true]),
        )
        .into_array();

        let zipped = zip(&when_true, &when_false, &selector).unwrap();

        insta::assert_snapshot!(zipped.display_table(), @r"
        ┌───────┐
        │ field │
        ├───────┤
        │ null  │
        ├───────┤
        │ 20i32 │
        ├───────┤
        │ 3i32  │
        ├───────┤
        │ 40i32 │
        └───────┘
        ");
    }

    /// One side is uniformly valid while the other has a validity array, which exercises
    /// the general mask-combination path of the kernel.
    #[test]
    fn test_validity_zip_allvalid_and_array() {
        let selector = Mask::from_iter([true, false, false, false]);
        let when_true = one_field_struct("a", [1, 2, 3, 4], Validity::AllValid).into_array();
        let when_false = one_field_struct(
            "a",
            [10, 20, 30, 40],
            Validity::from_iter([false, false, true, true]),
        )
        .into_array();

        let zipped = zip(&when_true, &when_false, &selector).unwrap();

        insta::assert_snapshot!(zipped.display_table(), @r"
        ┌───────┐
        │   a   │
        ├───────┤
        │ 1i32  │
        ├───────┤
        │ null  │
        ├───────┤
        │ 30i32 │
        ├───────┤
        │ 40i32 │
        └───────┘
        ");
    }
}