1mod kernel;
5
6use std::fmt::Formatter;
7
8pub use kernel::*;
9use vortex_error::VortexExpect as _;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_mask::Mask;
13use vortex_mask::MaskValues;
14use vortex_session::VortexSession;
15use vortex_session::registry::CachedId;
16
17use crate::ArrayRef;
18use crate::ExecutionCtx;
19use crate::IntoArray;
20use crate::arrays::BoolArray;
21use crate::arrays::bool::BoolArrayExt;
22use crate::builders::ArrayBuilder;
23use crate::builders::builder_with_capacity;
24use crate::builtins::ArrayBuiltins;
25use crate::dtype::DType;
26use crate::expr::Expression;
27use crate::scalar_fn::Arity;
28use crate::scalar_fn::ChildName;
29use crate::scalar_fn::EmptyOptions;
30use crate::scalar_fn::ExecutionArgs;
31use crate::scalar_fn::ScalarFnId;
32use crate::scalar_fn::ScalarFnVTable;
33use crate::scalar_fn::SimplifyCtx;
34use crate::scalar_fn::fns::literal::Literal;
35use crate::validity::Validity;
36
37#[derive(Clone)]
45pub struct Zip;
46
47impl ScalarFnVTable for Zip {
48 type Options = EmptyOptions;
49
50 fn id(&self) -> ScalarFnId {
51 static ID: CachedId = CachedId::new("vortex.zip");
52 *ID
53 }
54
55 fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
56 Ok(Some(vec![]))
57 }
58
59 fn deserialize(
60 &self,
61 _metadata: &[u8],
62 _session: &VortexSession,
63 ) -> VortexResult<Self::Options> {
64 Ok(EmptyOptions)
65 }
66
67 fn arity(&self, _options: &Self::Options) -> Arity {
68 Arity::Exact(3)
69 }
70
71 fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
72 match child_idx {
73 0 => ChildName::from("if_true"),
74 1 => ChildName::from("if_false"),
75 2 => ChildName::from("mask"),
76 _ => unreachable!("Invalid child index {} for Zip expression", child_idx),
77 }
78 }
79
80 fn fmt_sql(
81 &self,
82 _options: &Self::Options,
83 expr: &Expression,
84 f: &mut Formatter<'_>,
85 ) -> std::fmt::Result {
86 write!(f, "zip(")?;
87 expr.child(0).fmt_sql(f)?;
88 write!(f, ", ")?;
89 expr.child(1).fmt_sql(f)?;
90 write!(f, ", ")?;
91 expr.child(2).fmt_sql(f)?;
92 write!(f, ")")
93 }
94
95 fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
96 vortex_ensure!(
97 matches!(arg_dtypes[2], DType::Bool(_)),
98 "zip requires mask to be a boolean type, got {}",
99 arg_dtypes[2]
100 );
101 zip_return_dtype(&arg_dtypes[0], &arg_dtypes[1])
102 }
103
104 fn execute(
105 &self,
106 _options: &Self::Options,
107 args: &dyn ExecutionArgs,
108 ctx: &mut ExecutionCtx,
109 ) -> VortexResult<ArrayRef> {
110 let if_true = args.get(0)?;
111 let if_false = args.get(1)?;
112 let mask_array = args.get(2)?;
113
114 let mask = mask_array
115 .execute::<BoolArray>(ctx)?
116 .to_mask_fill_null_false(ctx);
117
118 let return_dtype = zip_return_dtype(if_true.dtype(), if_false.dtype())?;
119
120 if mask.all_true() {
121 return if_true.cast(return_dtype)?.execute(ctx);
122 }
123
124 if mask.all_false() {
125 return if_false.cast(return_dtype)?.execute(ctx);
126 }
127
128 if !if_true.is_canonical() || !if_false.is_canonical() {
129 let if_true = if_true.execute::<ArrayRef>(ctx)?;
130 let if_false = if_false.execute::<ArrayRef>(ctx)?;
131 return mask.into_array().zip(if_true, if_false);
132 }
133
134 zip_impl(&if_true, &if_false, &mask, ctx)
135 }
136
137 fn simplify(
138 &self,
139 _options: &Self::Options,
140 expr: &Expression,
141 _ctx: &dyn SimplifyCtx,
142 ) -> VortexResult<Option<Expression>> {
143 let Some(mask_lit) = expr.child(2).as_opt::<Literal>() else {
144 return Ok(None);
145 };
146
147 if let Some(mask_val) = mask_lit.as_bool().value() {
148 if mask_val {
149 return Ok(Some(expr.child(0).clone()));
150 } else {
151 return Ok(Some(expr.child(1).clone()));
152 }
153 }
154
155 Ok(None)
156 }
157
158 fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
159 true
160 }
161
162 fn is_fallible(&self, _options: &Self::Options) -> bool {
163 false
164 }
165}
166
167pub(crate) fn zip_impl(
168 if_true: &ArrayRef,
169 if_false: &ArrayRef,
170 mask: &Mask,
171 ctx: &mut ExecutionCtx,
172) -> VortexResult<ArrayRef> {
173 assert_eq!(
174 if_true.len(),
175 if_false.len(),
176 "zip requires arrays to have the same size"
177 );
178
179 let return_type = zip_return_dtype(if_true.dtype(), if_false.dtype())?;
180
181 if mask.all_true() {
182 return if_true.cast(return_type);
183 }
184 if mask.all_false() {
185 return if_false.cast(return_type);
186 }
187
188 let if_true = if_true.cast(return_type.clone())?;
191 let if_false = if_false.cast(return_type.clone())?;
192
193 zip_impl_with_builder(
194 &if_true,
195 &if_false,
196 mask.values()
197 .vortex_expect("zip_impl_with_builder: mask is not all-true or all-false"),
198 builder_with_capacity(&return_type, if_true.len()),
199 ctx,
200 )
201}
202
203fn zip_return_dtype(if_true: &DType, if_false: &DType) -> VortexResult<DType> {
204 vortex_ensure!(
205 if_true.eq_ignore_nullability(if_false),
206 "zip requires if_true and if_false to have the same base type, got {} and {}",
207 if_true,
208 if_false
209 );
210 Ok(if_true
211 .least_supertype(if_false)
212 .vortex_expect("zip inputs with the same base type must have a common dtype"))
213}
214
215fn zip_impl_with_builder(
216 if_true: &ArrayRef,
217 if_false: &ArrayRef,
218 mask: &MaskValues,
219 mut builder: Box<dyn ArrayBuilder>,
220 ctx: &mut ExecutionCtx,
221) -> VortexResult<ArrayRef> {
222 for (start, end) in mask.slices() {
223 if builder.len() < *start {
224 if_false
225 .slice(builder.len()..*start)?
226 .append_to_builder(builder.as_mut(), ctx)?;
227 }
228 if_true
229 .slice(*start..*end)?
230 .append_to_builder(builder.as_mut(), ctx)?;
231 }
232 if builder.len() < if_false.len() {
233 if_false
234 .slice(builder.len()..if_false.len())?
235 .append_to_builder(builder.as_mut(), ctx)?;
236 }
237 Ok(builder.finish())
238}
239
240pub(crate) fn zip_validity(
249 if_true: Validity,
250 if_false: Validity,
251 mask: &Mask,
252) -> VortexResult<Validity> {
253 match (&if_true, &if_false) {
254 (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
255 (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
256 (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
257 _ => {}
258 }
259
260 let len = mask.len();
261 let validity = mask
262 .clone()
263 .into_array()
264 .zip(if_true.to_array(len), if_false.to_array(len))?;
265 Ok(Validity::Array(validity))
266}
267
268#[cfg(test)]
269mod tests {
270 use arrow_array::cast::AsArray;
271 use arrow_select::zip::zip as arrow_zip;
272 use vortex_buffer::buffer;
273 use vortex_error::VortexResult;
274 use vortex_mask::Mask;
275
276 use super::zip_impl;
277 use crate::ArrayRef;
278 use crate::IntoArray;
279 use crate::LEGACY_SESSION;
280 use crate::VortexSessionExecute;
281 use crate::arrays::ConstantArray;
282 use crate::arrays::PrimitiveArray;
283 use crate::arrays::Struct;
284 use crate::arrays::StructArray;
285 use crate::arrays::VarBinView;
286 use crate::arrow::ArrowSessionExt;
287 use crate::assert_arrays_eq;
288 use crate::builders::ArrayBuilder;
289 use crate::builders::BufferGrowthStrategy;
290 use crate::builders::VarBinViewBuilder;
291 use crate::builtins::ArrayBuiltins;
292 use crate::columnar::Columnar;
293 use crate::dtype::DType;
294 use crate::dtype::Nullability;
295 use crate::dtype::PType;
296 use crate::expr::lit;
297 use crate::expr::root;
298 use crate::expr::zip_expr;
299 use crate::scalar::Scalar;
300
301 #[test]
302 fn dtype() {
303 let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
304 let expr = zip_expr(lit(true), root(), lit(0i32));
305 let result_dtype = expr.return_dtype(&dtype).unwrap();
306 assert_eq!(
307 result_dtype,
308 DType::Primitive(PType::I32, Nullability::NonNullable)
309 );
310 }
311
312 #[test]
313 fn test_display() {
314 let expr = zip_expr(lit(true), root(), lit(0i32));
315 assert_eq!(expr.to_string(), "zip($, 0i32, true)");
316 }
317
318 #[test]
319 fn test_zip_basic() {
320 let mask = Mask::from_iter([true, false, false, true, false]);
321 let if_true = buffer![10, 20, 30, 40, 50].into_array();
322 let if_false = buffer![1, 2, 3, 4, 5].into_array();
323
324 let result = mask.into_array().zip(if_true, if_false).unwrap();
325 let expected = buffer![10, 2, 3, 40, 5].into_array();
326
327 assert_arrays_eq!(result, expected);
328 }
329
330 #[test]
331 fn test_zip_all_true() {
332 let mask = Mask::new_true(4);
333 let if_true = buffer![10, 20, 30, 40].into_array();
334 let if_false =
335 PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array();
336
337 let result = mask.into_array().zip(if_true, if_false.clone()).unwrap();
338 let expected =
339 PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), Some(40)]).into_array();
340
341 assert_arrays_eq!(result, expected);
342 assert_eq!(result.dtype(), if_false.dtype())
343 }
344
345 #[test]
346 fn test_zip_all_false_widens_nullability() {
347 let mask = Mask::new_false(4);
348 let if_true =
349 PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), None]).into_array();
350 let if_false = buffer![1i32, 2, 3, 4].into_array();
351
352 let result = mask.into_array().zip(if_true.clone(), if_false).unwrap();
353 let expected =
354 PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), Some(4)]).into_array();
355
356 assert_arrays_eq!(result, expected);
357 assert_eq!(result.dtype(), if_true.dtype());
358 }
359
360 #[test]
361 fn test_zip_impl_all_true_widens_nullability() -> VortexResult<()> {
362 let mask = Mask::new_true(4);
363 let if_true = buffer![10i32, 20, 30, 40].into_array();
364 let if_false =
365 PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array();
366
367 let mut ctx = LEGACY_SESSION.create_execution_ctx();
368 let result = zip_impl(&if_true, &if_false, &mask, &mut ctx)?;
369 assert_arrays_eq!(
370 result,
371 PrimitiveArray::from_option_iter([Some(10i32), Some(20), Some(30), Some(40)])
372 .into_array()
373 );
374 assert_eq!(result.dtype(), if_false.dtype());
375 Ok(())
376 }
377
378 #[test]
379 fn test_zip_impl_all_false_widens_nullability() -> VortexResult<()> {
380 let mask = Mask::new_false(4);
381 let if_true =
382 PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), None]).into_array();
383 let if_false = buffer![1i32, 2, 3, 4].into_array();
384
385 let mut ctx = LEGACY_SESSION.create_execution_ctx();
386 let result = zip_impl(&if_true, &if_false, &mask, &mut ctx)?;
387 assert_arrays_eq!(
388 result,
389 PrimitiveArray::from_option_iter([Some(1i32), Some(2), Some(3), Some(4)]).into_array()
390 );
391 assert_eq!(result.dtype(), if_true.dtype());
392 Ok(())
393 }
394
395 #[test]
396 #[should_panic]
397 fn test_invalid_lengths() {
398 let mask = Mask::new_false(4);
399 let if_true = buffer![10, 20, 30].into_array();
400 let if_false = buffer![1, 2, 3, 4].into_array();
401
402 let _result = mask.into_array().zip(if_true, if_false).unwrap();
403 }
404
405 #[test]
406 fn test_fragmentation() -> VortexResult<()> {
407 let len = 100;
408
409 let const1 = ConstantArray::new(
410 Scalar::utf8("hello_this_is_a_longer_string", Nullability::Nullable),
411 len,
412 )
413 .into_array();
414
415 let const2 = ConstantArray::new(
416 Scalar::utf8("world_this_is_another_string", Nullability::Nullable),
417 len,
418 )
419 .into_array();
420
421 let indices: Vec<usize> = (0..len).step_by(2).collect();
422 let mask = Mask::from_indices(len, indices);
423 let mask_array = mask.into_array();
424
425 let mut ctx = LEGACY_SESSION.create_execution_ctx();
426 let result = mask_array
427 .zip(const1.clone(), const2.clone())?
428 .execute::<Columnar>(&mut ctx)?
429 .into_array();
430
431 insta::assert_snapshot!(result.display_tree(), @r"
432 root: vortex.varbinview(utf8?, len=100) nbytes=1.66 kB (100.00%) [all_valid]
433 metadata:
434 buffer: buffer_0 host 29 B (align=1) (1.75%)
435 buffer: buffer_1 host 28 B (align=1) (1.69%)
436 buffer: views host 1.60 kB (align=16) (96.56%)
437 ");
438
439 let wrapped1 = StructArray::try_from_iter([("nested", const1)])?.into_array();
440 let wrapped2 = StructArray::try_from_iter([("nested", const2)])?.into_array();
441
442 let wrapped_result = mask_array
443 .zip(wrapped1, wrapped2)?
444 .execute::<ArrayRef>(&mut ctx)?;
445 assert!(wrapped_result.is::<Struct>());
446
447 Ok(())
448 }
449
450 #[test]
451 fn test_varbinview_zip() {
452 let if_true = {
453 let mut builder = VarBinViewBuilder::new(
454 DType::Utf8(Nullability::NonNullable),
455 10,
456 Default::default(),
457 BufferGrowthStrategy::fixed(64 * 1024),
458 0.0,
459 );
460 for _ in 0..100 {
461 builder.append_value("Hello");
462 builder.append_value("Hello this is a long string that won't be inlined.");
463 }
464 builder.finish()
465 };
466
467 let if_false = {
468 let mut builder = VarBinViewBuilder::new(
469 DType::Utf8(Nullability::NonNullable),
470 10,
471 Default::default(),
472 BufferGrowthStrategy::fixed(64 * 1024),
473 0.0,
474 );
475 for _ in 0..100 {
476 builder.append_value("Hello2");
477 builder.append_value("Hello2 this is a long string that won't be inlined.");
478 }
479 builder.finish()
480 };
481
482 let mask = Mask::from_indices(200, (0..100).filter(|i| i % 3 != 0));
483 let mask_array = mask.clone().into_array();
484
485 let mut ctx = LEGACY_SESSION.create_execution_ctx();
486 let zipped = mask_array
487 .zip(if_true.clone(), if_false.clone())
488 .unwrap()
489 .execute::<ArrayRef>(&mut ctx)
490 .unwrap();
491 let zipped = zipped.as_opt::<VarBinView>().unwrap();
492 assert_eq!(zipped.data_buffers().len(), 2);
493
494 let mut arrow_ctx = LEGACY_SESSION.create_execution_ctx();
495 let expected = arrow_zip(
496 LEGACY_SESSION
497 .arrow()
498 .execute_arrow(mask.into_array(), None, &mut arrow_ctx)
499 .unwrap()
500 .as_boolean(),
501 &LEGACY_SESSION
502 .arrow()
503 .execute_arrow(if_true, None, &mut arrow_ctx)
504 .unwrap(),
505 &LEGACY_SESSION
506 .arrow()
507 .execute_arrow(if_false, None, &mut arrow_ctx)
508 .unwrap(),
509 )
510 .unwrap();
511
512 let actual = LEGACY_SESSION
513 .arrow()
514 .execute_arrow(zipped.array().clone(), None, &mut arrow_ctx)
515 .unwrap();
516 assert_eq!(actual.as_ref(), expected.as_ref());
517 }
518}