vortex_array/arrays/struct_/compute/
zip.rs1use std::ops::BitAnd;
5use std::ops::BitOr;
6use std::ops::Not;
7
8use vortex_error::VortexResult;
9use vortex_mask::Mask;
10
11use crate::Array;
12use crate::ArrayRef;
13use crate::arrays::StructArray;
14use crate::arrays::StructVTable;
15use crate::compute::ZipKernel;
16use crate::compute::ZipKernelAdapter;
17use crate::compute::zip;
18use crate::register_kernel;
19use crate::validity::Validity;
20use crate::vtable::ValidityHelper;
21
22impl ZipKernel for StructVTable {
23 fn zip(
24 &self,
25 if_true: &StructArray,
26 if_false: &dyn Array,
27 mask: &Mask,
28 ) -> VortexResult<Option<ArrayRef>> {
29 let Some(if_false) = if_false.as_opt::<StructVTable>() else {
30 return Ok(None);
31 };
32 assert_eq!(
33 if_true.names(),
34 if_false.names(),
35 "input arrays to zip must have the same field names",
36 );
37
38 let fields = if_true
39 .fields()
40 .iter()
41 .zip(if_false.fields().iter())
42 .map(|(t, f)| zip(t, f, mask))
43 .collect::<VortexResult<Vec<_>>>()?;
44
45 let validity = match (if_true.validity(), if_false.validity()) {
46 (&Validity::NonNullable, &Validity::NonNullable) => Validity::NonNullable,
47 (&Validity::AllValid, &Validity::AllValid) => Validity::AllValid,
48 (&Validity::AllInvalid, &Validity::AllInvalid) => Validity::AllInvalid,
49
50 (v1, v2) => {
51 let v1m = v1.to_mask(if_true.len());
52 let v2m = v2.to_mask(if_false.len());
53
54 let combined = (v1m.bitand(mask)).bitor(&v2m.bitand(&mask.not()));
55 Validity::from_mask(
56 combined,
57 if_true.dtype.nullability() | if_false.dtype.nullability(),
58 )
59 }
60 };
61
62 Ok(Some(
63 StructArray::try_new(if_true.names().clone(), fields, if_true.len(), validity)?
64 .to_array(),
65 ))
66 }
67}
68
// Register the struct zip kernel (wrapped in `ZipKernelAdapter`) with the compute kernel
// registry via `register_kernel!`. Presumably this is what lets the `zip` compute function
// reach `StructVTable::zip` for struct inputs — NOTE(review): confirm against the macro.
register_kernel!(ZipKernelAdapter(StructVTable).lift());
70
#[cfg(test)]
mod tests {
    use vortex_dtype::FieldNames;
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::compute::zip;
    use crate::validity::Validity;

    /// Builds a four-row struct array with a single `i32` column named `column`.
    fn single_column(column: &'static str, values: [i32; 4], validity: Validity) -> StructArray {
        StructArray::new(
            FieldNames::from_iter([column]),
            vec![PrimitiveArray::from_iter(values).into_array()],
            values.len(),
            validity,
        )
    }

    /// Both sides carry an explicit validity array; every output row takes its validity
    /// from whichever side the mask selects.
    #[test]
    fn test_validity_zip_both_validity_array() {
        let truthy = single_column(
            "field",
            [1, 2, 3, 4],
            Validity::from_iter([true, false, true, false]),
        )
        .into_array();
        let falsy = single_column(
            "field",
            [10, 20, 30, 40],
            Validity::from_iter([false, true, false, true]),
        )
        .into_array();
        let selector = Mask::from_iter([false, false, true, false]);

        let zipped = zip(&truthy, &falsy, &selector).unwrap();

        insta::assert_snapshot!(zipped.display_table(), @r"
        ┌───────┐
        │ field │
        ├───────┤
        │ null │
        ├───────┤
        │ 20i32 │
        ├───────┤
        │ 3i32 │
        ├───────┤
        │ 40i32 │
        └───────┘
        ");
    }

    /// An all-valid side zipped with a validity-array side mixes both kinds of validity.
    #[test]
    fn test_validity_zip_allvalid_and_array() {
        let truthy = single_column("a", [1, 2, 3, 4], Validity::AllValid).into_array();
        let falsy = single_column(
            "a",
            [10, 20, 30, 40],
            Validity::from_iter([false, false, true, true]),
        )
        .into_array();
        let selector = Mask::from_iter([true, false, false, false]);

        let zipped = zip(&truthy, &falsy, &selector).unwrap();

        insta::assert_snapshot!(zipped.display_table(), @r"
        ┌───────┐
        │ a │
        ├───────┤
        │ 1i32 │
        ├───────┤
        │ null │
        ├───────┤
        │ 30i32 │
        ├───────┤
        │ 40i32 │
        └───────┘
        ");
    }
}