vortex_array/scalar_fn/fns/zip/
mod.rs1mod kernel;
5
6use std::fmt::Formatter;
7
8pub use kernel::*;
9use vortex_error::VortexResult;
10use vortex_error::vortex_ensure;
11use vortex_error::vortex_err;
12use vortex_mask::AllOr;
13use vortex_mask::Mask;
14use vortex_session::VortexSession;
15
16use crate::Array;
17use crate::ArrayRef;
18use crate::IntoArray;
19use crate::builders::ArrayBuilder;
20use crate::builders::builder_with_capacity;
21use crate::builtins::ArrayBuiltins;
22use crate::dtype::DType;
23use crate::expr::Expression;
24use crate::scalar_fn::Arity;
25use crate::scalar_fn::ChildName;
26use crate::scalar_fn::EmptyOptions;
27use crate::scalar_fn::ExecutionArgs;
28use crate::scalar_fn::ScalarFnId;
29use crate::scalar_fn::ScalarFnVTable;
30use crate::scalar_fn::SimplifyCtx;
31use crate::scalar_fn::fns::literal::Literal;
32
/// Scalar function `vortex.zip`: an element-wise choice between two value children
/// driven by a boolean mask child.
///
/// The children are, in order, `if_true`, `if_false` and `mask`. Positions where the
/// mask is set take their value from `if_true`; all other positions take it from
/// `if_false`. The function carries no options.
#[derive(Clone)]
pub struct Zip;
42
43impl ScalarFnVTable for Zip {
44 type Options = EmptyOptions;
45
46 fn id(&self) -> ScalarFnId {
47 ScalarFnId::from("vortex.zip")
48 }
49
50 fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
51 Ok(Some(vec![]))
52 }
53
54 fn deserialize(
55 &self,
56 _metadata: &[u8],
57 _session: &VortexSession,
58 ) -> VortexResult<Self::Options> {
59 Ok(EmptyOptions)
60 }
61
62 fn arity(&self, _options: &Self::Options) -> Arity {
63 Arity::Exact(3)
64 }
65
66 fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
67 match child_idx {
68 0 => ChildName::from("if_true"),
69 1 => ChildName::from("if_false"),
70 2 => ChildName::from("mask"),
71 _ => unreachable!("Invalid child index {} for Zip expression", child_idx),
72 }
73 }
74
75 fn fmt_sql(
76 &self,
77 _options: &Self::Options,
78 expr: &Expression,
79 f: &mut Formatter<'_>,
80 ) -> std::fmt::Result {
81 write!(f, "zip(")?;
82 expr.child(0).fmt_sql(f)?;
83 write!(f, ", ")?;
84 expr.child(1).fmt_sql(f)?;
85 write!(f, ", ")?;
86 expr.child(2).fmt_sql(f)?;
87 write!(f, ")")
88 }
89
90 fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
91 vortex_ensure!(
92 arg_dtypes[0].eq_ignore_nullability(&arg_dtypes[1]),
93 "zip requires if_true and if_false to have the same base type, got {} and {}",
94 arg_dtypes[0],
95 arg_dtypes[1]
96 );
97 vortex_ensure!(
98 matches!(arg_dtypes[2], DType::Bool(_)),
99 "zip requires mask to be a boolean type, got {}",
100 arg_dtypes[2]
101 );
102 Ok(arg_dtypes[0]
103 .clone()
104 .union_nullability(arg_dtypes[1].nullability()))
105 }
106
107 fn execute(&self, _options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
108 let [if_true, if_false, mask_array]: [ArrayRef; _] = args
109 .inputs
110 .try_into()
111 .map_err(|_| vortex_err!("Wrong arg count"))?;
112
113 let mask = mask_array.try_to_mask_fill_null_false()?;
114
115 let return_dtype = if_true
116 .dtype()
117 .clone()
118 .union_nullability(if_false.dtype().nullability());
119
120 if mask.all_true() {
121 return if_true.cast(return_dtype)?.execute(args.ctx);
122 }
123
124 let return_dtype = if_true
125 .dtype()
126 .clone()
127 .union_nullability(if_false.dtype().nullability());
128
129 if mask.all_false() {
130 return if_false.cast(return_dtype)?.execute(args.ctx);
131 }
132
133 if !if_true.is_canonical() || !if_false.is_canonical() {
134 let if_true = if_true.execute::<ArrayRef>(args.ctx)?;
135 let if_false = if_false.execute::<ArrayRef>(args.ctx)?;
136 return if_true.zip(if_false, mask.into_array());
137 }
138
139 zip_impl(&if_true, &if_false, &mask)
140 }
141
142 fn simplify(
143 &self,
144 _options: &Self::Options,
145 expr: &Expression,
146 _ctx: &dyn SimplifyCtx,
147 ) -> VortexResult<Option<Expression>> {
148 let Some(mask_lit) = expr.child(2).as_opt::<Literal>() else {
149 return Ok(None);
150 };
151
152 if let Some(mask_val) = mask_lit.as_bool().value() {
153 if mask_val {
154 return Ok(Some(expr.child(0).clone()));
155 } else {
156 return Ok(Some(expr.child(1).clone()));
157 }
158 }
159
160 Ok(None)
161 }
162
163 fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
164 true
165 }
166
167 fn is_fallible(&self, _options: &Self::Options) -> bool {
168 false
169 }
170}
171
172pub(crate) fn zip_impl(
173 if_true: &dyn Array,
174 if_false: &dyn Array,
175 mask: &Mask,
176) -> VortexResult<ArrayRef> {
177 assert_eq!(
178 if_true.len(),
179 if_false.len(),
180 "zip requires arrays to have the same size"
181 );
182
183 let return_type = if_true
184 .dtype()
185 .clone()
186 .union_nullability(if_false.dtype().nullability());
187 zip_impl_with_builder(
188 if_true,
189 if_false,
190 mask,
191 builder_with_capacity(&return_type, if_true.len()),
192 )
193}
194
195fn zip_impl_with_builder(
196 if_true: &dyn Array,
197 if_false: &dyn Array,
198 mask: &Mask,
199 mut builder: Box<dyn ArrayBuilder>,
200) -> VortexResult<ArrayRef> {
201 match mask.slices() {
202 AllOr::All => Ok(if_true.to_array()),
203 AllOr::None => Ok(if_false.to_array()),
204 AllOr::Some(slices) => {
205 for (start, end) in slices {
206 builder.extend_from_array(&if_false.slice(builder.len()..*start)?);
207 builder.extend_from_array(&if_true.slice(*start..*end)?);
208 }
209 if builder.len() < if_false.len() {
210 builder.extend_from_array(&if_false.slice(builder.len()..if_false.len())?);
211 }
212 Ok(builder.finish())
213 }
214 }
215}
216
// Unit tests for the zip scalar function: dtype inference, display, basic selection,
// constant-mask handling, length validation and buffer layout of string results.
#[cfg(test)]
mod tests {
    use arrow_array::cast::AsArray;
    use arrow_select::zip::zip as arrow_zip;
    use vortex_buffer::buffer;
    use vortex_mask::Mask;

    use crate::Array;
    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinViewVTable;
    use crate::arrow::IntoArrowArray;
    use crate::assert_arrays_eq;
    use crate::builders::ArrayBuilder;
    use crate::builders::BufferGrowthStrategy;
    use crate::builders::VarBinViewBuilder;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::lit;
    use crate::expr::root;
    use crate::expr::zip_expr;
    use crate::scalar::Scalar;

    // Zipping a non-nullable i32 root with a non-nullable i32 literal stays non-nullable.
    #[test]
    fn dtype() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let expr = zip_expr(root(), lit(0i32), lit(true));
        let result_dtype = expr.return_dtype(&dtype).unwrap();
        assert_eq!(
            result_dtype,
            DType::Primitive(PType::I32, Nullability::NonNullable)
        );
    }

    // The expression renders as `zip(<if_true>, <if_false>, <mask>)`.
    #[test]
    fn test_display() {
        let expr = zip_expr(root(), lit(0i32), lit(true));
        assert_eq!(expr.to_string(), "zip($, 0i32, true)");
    }

    // Set mask positions (0 and 3) take values from `if_true`, the rest from `if_false`.
    #[test]
    fn test_zip_basic() {
        let mask = Mask::from_iter([true, false, false, true, false]);
        let if_true = buffer![10, 20, 30, 40, 50].into_array();
        let if_false = buffer![1, 2, 3, 4, 5].into_array();

        let result = if_true.zip(if_false, mask.into_array()).unwrap();
        let expected = buffer![10, 2, 3, 40, 5].into_array();

        assert_arrays_eq!(result, expected);
    }

    // With an all-true mask every value comes from `if_true`, but the result dtype
    // still matches the nullable `if_false` dtype.
    #[test]
    fn test_zip_all_true() {
        let mask = Mask::new_true(4);
        let if_true = buffer![10, 20, 30, 40].into_array();
        let if_false =
            PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array();

        let result = if_true.zip(if_false.clone(), mask.into_array()).unwrap();
        let expected =
            PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), Some(40)]).into_array();

        assert_arrays_eq!(result, expected);

        assert_eq!(result.dtype(), if_false.dtype())
    }

    // `if_true` has 3 elements while `if_false` and the mask have 4; zipping must panic.
    #[test]
    #[should_panic]
    fn test_invalid_lengths() {
        let mask = Mask::new_false(4);
        let if_true = buffer![10, 20, 30].into_array();
        let if_false = buffer![1, 2, 3, 4].into_array();

        let _result = if_true.zip(if_false, mask.into_array()).unwrap();
    }

    // Zips two constant utf8 arrays with an alternating mask (every even index set) and
    // snapshots the resulting tree, first directly and then wrapped in a struct field.
    // The snapshots record two small data buffers (29 B and 28 B) plus the views buffer.
    #[test]
    fn test_fragmentation() {
        let len = 100;

        let const1 = ConstantArray::new(
            Scalar::utf8("hello_this_is_a_longer_string", Nullability::Nullable),
            len,
        )
        .to_array();

        let const2 = ConstantArray::new(
            Scalar::utf8("world_this_is_another_string", Nullability::Nullable),
            len,
        )
        .to_array();

        let indices: Vec<usize> = (0..len).step_by(2).collect();
        let mask = Mask::from_indices(len, indices);
        let mask_array = mask.into_array();

        let result = const1.zip(const2.clone(), mask_array.clone()).unwrap();

        // NOTE(review): the snapshot text is kept exactly as found; confirm its relative
        // indentation still matches `display_tree` output.
        insta::assert_snapshot!(result.display_tree(), @r"
 root: vortex.varbinview(utf8?, len=100) nbytes=1.66 kB (100.00%) [all_valid]
 metadata: EmptyMetadata
 buffer: buffer_0 host 29 B (align=1) (1.75%)
 buffer: buffer_1 host 28 B (align=1) (1.69%)
 buffer: views host 1.60 kB (align=16) (96.56%)
 ");

        let wrapped1 = StructArray::try_from_iter([("nested", const1)])
            .unwrap()
            .to_array();
        let wrapped2 = StructArray::try_from_iter([("nested", const2)])
            .unwrap()
            .to_array();

        let wrapped_result = wrapped1.zip(wrapped2, mask_array).unwrap();
        // NOTE(review): same caveat as above regarding snapshot indentation.
        insta::assert_snapshot!(wrapped_result.display_tree(), @r"
 root: vortex.struct({nested=utf8?}, len=100) nbytes=1.66 kB (100.00%)
 metadata: EmptyMetadata
 nested: vortex.varbinview(utf8?, len=100) nbytes=1.66 kB (100.00%) [all_valid]
 metadata: EmptyMetadata
 buffer: buffer_0 host 29 B (align=1) (1.75%)
 buffer: buffer_1 host 28 B (align=1) (1.69%)
 buffer: views host 1.60 kB (align=16) (96.56%)
 ");
    }

    // Zips two 200-element utf8 arrays built with `VarBinViewBuilder`, checks the result
    // holds exactly two buffers, and compares the values against Arrow's `zip`.
    #[test]
    fn test_varbinview_zip() {
        let if_true = {
            let mut builder = VarBinViewBuilder::new(
                DType::Utf8(Nullability::NonNullable),
                10,
                Default::default(),
                BufferGrowthStrategy::fixed(64 * 1024),
                0.0,
            );
            // 100 iterations x 2 appends = 200 values, alternating a short and a long string.
            for _ in 0..100 {
                builder.append_value("Hello");
                builder.append_value("Hello this is a long string that won't be inlined.");
            }
            builder.finish()
        };

        let if_false = {
            let mut builder = VarBinViewBuilder::new(
                DType::Utf8(Nullability::NonNullable),
                10,
                Default::default(),
                BufferGrowthStrategy::fixed(64 * 1024),
                0.0,
            );
            // Same shape as `if_true`, with distinguishable contents.
            for _ in 0..100 {
                builder.append_value("Hello2");
                builder.append_value("Hello2 this is a long string that won't be inlined.");
            }
            builder.finish()
        };

        // Length-200 mask; set indices are those in 0..100 that are not multiples of 3.
        let mask = Mask::from_indices(200, (0..100).filter(|i| i % 3 != 0).collect());
        let mask_array = mask.clone().into_array();

        let zipped = if_true.zip(if_false.clone(), mask_array).unwrap();
        let zipped = zipped.as_opt::<VarBinViewVTable>().unwrap();
        assert_eq!(zipped.nbuffers(), 2);

        // Reference result computed by Arrow over the same mask and inputs.
        let expected = arrow_zip(
            mask.into_array()
                .into_arrow_preferred()
                .unwrap()
                .as_boolean(),
            &if_true.into_arrow_preferred().unwrap(),
            &if_false.into_arrow_preferred().unwrap(),
        )
        .unwrap();

        let actual = zipped.clone().into_array().into_arrow_preferred().unwrap();
        assert_eq!(actual.as_ref(), expected.as_ref());
    }
}