1use vortex_dtype::DType;
5use vortex_error::VortexResult;
6use vortex_mask::AllOr;
7use vortex_mask::Mask;
8
9use crate::Array;
10use crate::ArrayRef;
11use crate::IntoArray;
12use crate::builders::ArrayBuilder;
13use crate::builders::builder_with_capacity;
14use crate::builtins::ArrayBuiltins;
15
16pub fn zip(if_true: &dyn Array, if_false: &dyn Array, mask: &Mask) -> VortexResult<ArrayRef> {
25 if_true
26 .to_array()
27 .zip(if_false.to_array(), mask.clone().into_array())
28}
29
30pub(crate) fn zip_return_dtype(if_true: &dyn Array, if_false: &dyn Array) -> DType {
31 if_true
32 .dtype()
33 .union_nullability(if_false.dtype().nullability())
34}
35
36pub(crate) fn zip_impl(
37 if_true: &dyn Array,
38 if_false: &dyn Array,
39 mask: &Mask,
40) -> VortexResult<ArrayRef> {
41 assert_eq!(
42 if_true.len(),
43 if_false.len(),
44 "zip requires arrays to have the same size"
45 );
46
47 let return_type = zip_return_dtype(if_true, if_false);
48 zip_impl_with_builder(
49 if_true,
50 if_false,
51 mask,
52 builder_with_capacity(&return_type, if_true.len()),
53 )
54}
55
56fn zip_impl_with_builder(
57 if_true: &dyn Array,
58 if_false: &dyn Array,
59 mask: &Mask,
60 mut builder: Box<dyn ArrayBuilder>,
61) -> VortexResult<ArrayRef> {
62 match mask.slices() {
63 AllOr::All => Ok(if_true.to_array()),
64 AllOr::None => Ok(if_false.to_array()),
65 AllOr::Some(slices) => {
66 for (start, end) in slices {
67 builder.extend_from_array(&if_false.slice(builder.len()..*start)?);
68 builder.extend_from_array(&if_true.slice(*start..*end)?);
69 }
70 if builder.len() < if_false.len() {
71 builder.extend_from_array(&if_false.slice(builder.len()..if_false.len())?);
72 }
73 Ok(builder.finish())
74 }
75 }
76}
77
// Unit tests for `zip`: value selection, result nullability, length validation, output
// buffer layout, and agreement with Arrow's `zip` kernel.
#[cfg(test)]
mod tests {
    use arrow_array::cast::AsArray;
    use arrow_select::zip::zip as arrow_zip;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;

    use crate::Array;
    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinViewVTable;
    use crate::arrow::IntoArrowArray;
    use crate::assert_arrays_eq;
    use crate::builders::ArrayBuilder;
    use crate::builders::BufferGrowthStrategy;
    use crate::builders::VarBinViewBuilder;
    use crate::compute::zip;
    use crate::scalar::Scalar;

    /// Mixed mask: set positions come from `if_true`, unset positions from `if_false`.
    #[test]
    fn test_zip_basic() {
        let mask = Mask::from_iter([true, false, false, true, false]);
        let if_true = buffer![10, 20, 30, 40, 50].into_array();
        let if_false = buffer![1, 2, 3, 4, 5].into_array();

        let result = zip(&if_true, &if_false, &mask).unwrap();
        let expected = buffer![10, 2, 3, 40, 5].into_array();

        assert_arrays_eq!(result, expected);
    }

    /// An all-true mask takes every value from the non-nullable `if_true`, yet the result's
    /// dtype must still match the nullable `if_false` (union of nullabilities).
    #[test]
    fn test_zip_all_true() {
        let mask = Mask::new_true(4);
        let if_true = buffer![10, 20, 30, 40].into_array();
        let if_false =
            PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array();

        let result = zip(&if_true, &if_false, &mask).unwrap();
        let expected =
            PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), Some(40)]).into_array();

        assert_arrays_eq!(result, expected);

        assert_eq!(result.dtype(), if_false.dtype())
    }

    /// Inputs of different lengths are rejected; here that surfaces as a panic (at the
    /// latest from `unwrap`).
    #[test]
    #[should_panic]
    fn test_invalid_lengths() {
        let mask = Mask::new_false(4);
        let if_true = buffer![10, 20, 30].into_array();
        let if_false = buffer![1, 2, 3, 4].into_array();

        zip(&if_true, &if_false, &mask).unwrap();
    }

    /// Zips two constant string arrays with an alternating mask and pins the output layout
    /// via snapshots: one data buffer per input string (29 B and 28 B, the lengths of the
    /// two literals) plus the views buffer. The same layout is checked for the string column
    /// when both inputs are wrapped in a struct.
    #[test]
    fn test_fragmentation() {
        let len = 100;

        let const1 = ConstantArray::new(
            Scalar::utf8("hello_this_is_a_longer_string", Nullability::Nullable),
            len,
        )
        .to_array();

        let const2 = ConstantArray::new(
            Scalar::utf8("world_this_is_another_string", Nullability::Nullable),
            len,
        )
        .to_array();

        // Select every even position from `const1`, every odd position from `const2`.
        let indices: Vec<usize> = (0..len).step_by(2).collect();
        let mask = Mask::from_indices(len, indices);

        let result = zip(&const1, &const2, &mask).unwrap();

        insta::assert_snapshot!(result.display_tree(), @r"
        root: vortex.varbinview(utf8?, len=100) nbytes=1.66 kB (100.00%) [all_valid]
          metadata: EmptyMetadata
          buffer: buffer_0 host 29 B (align=1) (1.75%)
          buffer: buffer_1 host 28 B (align=1) (1.69%)
          buffer: views host 1.60 kB (align=16) (96.56%)
        ");

        // Same inputs, each nested as the single field of a struct.
        let wrapped1 = StructArray::try_from_iter([("nested", const1)])
            .unwrap()
            .to_array();
        let wrapped2 = StructArray::try_from_iter([("nested", const2)])
            .unwrap()
            .to_array();

        let wrapped_result = zip(&wrapped1, &wrapped2, &mask).unwrap();
        insta::assert_snapshot!(wrapped_result.display_tree(), @r"
        root: vortex.struct({nested=utf8?}, len=100) nbytes=1.66 kB (100.00%)
          metadata: EmptyMetadata
          nested: vortex.varbinview(utf8?, len=100) nbytes=1.66 kB (100.00%) [all_valid]
            metadata: EmptyMetadata
            buffer: buffer_0 host 29 B (align=1) (1.75%)
            buffer: buffer_1 host 28 B (align=1) (1.69%)
            buffer: views host 1.60 kB (align=16) (96.56%)
        ");
    }

    /// Zips two `VarBinView` arrays mixing short and long strings. The result must reference
    /// exactly two data buffers and agree with Arrow's `zip` kernel.
    #[test]
    fn test_varbinview_zip() {
        let if_true = {
            let mut builder = VarBinViewBuilder::new(
                DType::Utf8(Nullability::NonNullable),
                10,
                Default::default(),
                BufferGrowthStrategy::fixed(64 * 1024),
                0.0,
            );
            for _ in 0..100 {
                builder.append_value("Hello");
                builder.append_value("Hello this is a long string that won't be inlined.");
            }
            builder.finish()
        };

        let if_false = {
            let mut builder = VarBinViewBuilder::new(
                DType::Utf8(Nullability::NonNullable),
                10,
                Default::default(),
                BufferGrowthStrategy::fixed(64 * 1024),
                0.0,
            );
            for _ in 0..100 {
                builder.append_value("Hello2");
                builder.append_value("Hello2 this is a long string that won't be inlined.");
            }
            builder.finish()
        };

        // Only positions in 0..100 that are not multiples of 3 come from `if_true`;
        // positions 100..200 all come from `if_false`.
        let mask = Mask::from_indices(200, (0..100).filter(|i| i % 3 != 0).collect());

        let zipped = zip(&if_true, &if_false, &mask).unwrap();
        let zipped = zipped.as_opt::<VarBinViewVTable>().unwrap();
        // NOTE(review): presumably one data buffer from each input — confirm.
        assert_eq!(zipped.nbuffers(), 2);

        // Cross-check the values against Arrow's `zip` on the same inputs.
        let expected = arrow_zip(
            mask.into_array()
                .into_arrow_preferred()
                .unwrap()
                .as_boolean(),
            &if_true.into_arrow_preferred().unwrap(),
            &if_false.into_arrow_preferred().unwrap(),
        )
        .unwrap();

        let actual = zipped.clone().into_array().into_arrow_preferred().unwrap();
        assert_eq!(actual.as_ref(), expected.as_ref());
    }
}