1mod kernel;
5
6use std::fmt::Formatter;
7
8pub use kernel::*;
9use vortex_error::VortexExpect as _;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_mask::Mask;
13use vortex_mask::MaskValues;
14use vortex_session::VortexSession;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::IntoArray;
19use crate::arrays::BoolArray;
20use crate::arrays::bool::BoolArrayExt;
21use crate::builders::ArrayBuilder;
22use crate::builders::builder_with_capacity;
23use crate::builtins::ArrayBuiltins;
24use crate::dtype::DType;
25use crate::expr::Expression;
26use crate::scalar_fn::Arity;
27use crate::scalar_fn::ChildName;
28use crate::scalar_fn::EmptyOptions;
29use crate::scalar_fn::ExecutionArgs;
30use crate::scalar_fn::ScalarFnId;
31use crate::scalar_fn::ScalarFnVTable;
32use crate::scalar_fn::SimplifyCtx;
33use crate::scalar_fn::fns::literal::Literal;
34
/// Scalar function `vortex.zip`: chooses between two value children using a boolean mask.
///
/// The function takes exactly three children, named `if_true`, `if_false` and `mask`.
/// Positions where the mask is set take their value from `if_true`; all other positions
/// take it from `if_false`. Null mask entries are filled with `false` during execution
/// (see `to_mask_fill_null_false`), so they select `if_false`.
///
/// The result has the dtype of `if_true`, with its nullability unioned with that of
/// `if_false`.
35#[derive(Clone)]
43pub struct Zip;
44
45impl ScalarFnVTable for Zip {
46 type Options = EmptyOptions;
47
48 fn id(&self) -> ScalarFnId {
49 ScalarFnId::new("vortex.zip")
50 }
51
52 fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
53 Ok(Some(vec![]))
54 }
55
56 fn deserialize(
57 &self,
58 _metadata: &[u8],
59 _session: &VortexSession,
60 ) -> VortexResult<Self::Options> {
61 Ok(EmptyOptions)
62 }
63
64 fn arity(&self, _options: &Self::Options) -> Arity {
65 Arity::Exact(3)
66 }
67
68 fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
69 match child_idx {
70 0 => ChildName::from("if_true"),
71 1 => ChildName::from("if_false"),
72 2 => ChildName::from("mask"),
73 _ => unreachable!("Invalid child index {} for Zip expression", child_idx),
74 }
75 }
76
77 fn fmt_sql(
78 &self,
79 _options: &Self::Options,
80 expr: &Expression,
81 f: &mut Formatter<'_>,
82 ) -> std::fmt::Result {
83 write!(f, "zip(")?;
84 expr.child(0).fmt_sql(f)?;
85 write!(f, ", ")?;
86 expr.child(1).fmt_sql(f)?;
87 write!(f, ", ")?;
88 expr.child(2).fmt_sql(f)?;
89 write!(f, ")")
90 }
91
92 fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
93 vortex_ensure!(
94 arg_dtypes[0].eq_ignore_nullability(&arg_dtypes[1]),
95 "zip requires if_true and if_false to have the same base type, got {} and {}",
96 arg_dtypes[0],
97 arg_dtypes[1]
98 );
99 vortex_ensure!(
100 matches!(arg_dtypes[2], DType::Bool(_)),
101 "zip requires mask to be a boolean type, got {}",
102 arg_dtypes[2]
103 );
104 Ok(arg_dtypes[0]
105 .clone()
106 .union_nullability(arg_dtypes[1].nullability()))
107 }
108
109 fn execute(
110 &self,
111 _options: &Self::Options,
112 args: &dyn ExecutionArgs,
113 ctx: &mut ExecutionCtx,
114 ) -> VortexResult<ArrayRef> {
115 let if_true = args.get(0)?;
116 let if_false = args.get(1)?;
117 let mask_array = args.get(2)?;
118
119 let mask = mask_array
120 .execute::<BoolArray>(ctx)?
121 .to_mask_fill_null_false(ctx);
122
123 let return_dtype = if_true
124 .dtype()
125 .clone()
126 .union_nullability(if_false.dtype().nullability());
127
128 if mask.all_true() {
129 return if_true.cast(return_dtype)?.execute(ctx);
130 }
131
132 if mask.all_false() {
133 return if_false.cast(return_dtype)?.execute(ctx);
134 }
135
136 if !if_true.is_canonical() || !if_false.is_canonical() {
137 let if_true = if_true.execute::<ArrayRef>(ctx)?;
138 let if_false = if_false.execute::<ArrayRef>(ctx)?;
139 return mask.into_array().zip(if_true, if_false);
140 }
141
142 zip_impl(&if_true, &if_false, &mask, ctx)
143 }
144
145 fn simplify(
146 &self,
147 _options: &Self::Options,
148 expr: &Expression,
149 _ctx: &dyn SimplifyCtx,
150 ) -> VortexResult<Option<Expression>> {
151 let Some(mask_lit) = expr.child(2).as_opt::<Literal>() else {
152 return Ok(None);
153 };
154
155 if let Some(mask_val) = mask_lit.as_bool().value() {
156 if mask_val {
157 return Ok(Some(expr.child(0).clone()));
158 } else {
159 return Ok(Some(expr.child(1).clone()));
160 }
161 }
162
163 Ok(None)
164 }
165
166 fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
167 true
168 }
169
170 fn is_fallible(&self, _options: &Self::Options) -> bool {
171 false
172 }
173}
174
175pub(crate) fn zip_impl(
176 if_true: &ArrayRef,
177 if_false: &ArrayRef,
178 mask: &Mask,
179 ctx: &mut ExecutionCtx,
180) -> VortexResult<ArrayRef> {
181 assert_eq!(
182 if_true.len(),
183 if_false.len(),
184 "zip requires arrays to have the same size"
185 );
186
187 let return_type = if_true
188 .dtype()
189 .clone()
190 .union_nullability(if_false.dtype().nullability());
191
192 if mask.all_true() {
193 return if_true.cast(return_type);
194 }
195 if mask.all_false() {
196 return if_false.cast(return_type);
197 }
198
199 let if_true = if_true.cast(return_type.clone())?;
202 let if_false = if_false.cast(return_type.clone())?;
203
204 zip_impl_with_builder(
205 &if_true,
206 &if_false,
207 mask.values()
208 .vortex_expect("zip_impl_with_builder: mask is not all-true or all-false"),
209 builder_with_capacity(&return_type, if_true.len()),
210 ctx,
211 )
212}
213
214fn zip_impl_with_builder(
215 if_true: &ArrayRef,
216 if_false: &ArrayRef,
217 mask: &MaskValues,
218 mut builder: Box<dyn ArrayBuilder>,
219 ctx: &mut ExecutionCtx,
220) -> VortexResult<ArrayRef> {
221 for (start, end) in mask.slices() {
222 if builder.len() < *start {
223 if_false
224 .slice(builder.len()..*start)?
225 .append_to_builder(builder.as_mut(), ctx)?;
226 }
227 if_true
228 .slice(*start..*end)?
229 .append_to_builder(builder.as_mut(), ctx)?;
230 }
231 if builder.len() < if_false.len() {
232 if_false
233 .slice(builder.len()..if_false.len())?
234 .append_to_builder(builder.as_mut(), ctx)?;
235 }
236 Ok(builder.finish())
237}
238
// Unit tests for the `zip` scalar function: dtype inference, SQL display, the uniform-mask
// fast paths, nullability widening, length validation and buffer layout of string results.
239#[cfg(test)]
240mod tests {
241 use arrow_array::cast::AsArray;
242 use arrow_select::zip::zip as arrow_zip;
243 use vortex_buffer::buffer;
244 use vortex_error::VortexResult;
245 use vortex_mask::Mask;
246
247 use super::zip_impl;
248 use crate::ArrayRef;
249 use crate::IntoArray;
250 use crate::LEGACY_SESSION;
251 use crate::VortexSessionExecute;
252 use crate::arrays::ConstantArray;
253 use crate::arrays::PrimitiveArray;
254 use crate::arrays::Struct;
255 use crate::arrays::StructArray;
256 use crate::arrays::VarBinView;
257 use crate::arrow::ArrowSessionExt;
258 use crate::assert_arrays_eq;
259 use crate::builders::ArrayBuilder;
260 use crate::builders::BufferGrowthStrategy;
261 use crate::builders::VarBinViewBuilder;
262 use crate::builtins::ArrayBuiltins;
263 use crate::columnar::Columnar;
264 use crate::dtype::DType;
265 use crate::dtype::Nullability;
266 use crate::dtype::PType;
267 use crate::expr::lit;
268 use crate::expr::root;
269 use crate::expr::zip_expr;
270 use crate::scalar::Scalar;
271
    // Zipping the non-nullable i32 root with a non-nullable i32 literal stays non-nullable.
272 #[test]
273 fn dtype() {
274 let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
275 let expr = zip_expr(lit(true), root(), lit(0i32));
276 let result_dtype = expr.return_dtype(&dtype).unwrap();
277 assert_eq!(
278 result_dtype,
279 DType::Primitive(PType::I32, Nullability::NonNullable)
280 );
281 }
282
    // The rendered form lists the children as (if_true, if_false, mask). Note that
    // `zip_expr` is called here with the mask literal first.
283 #[test]
284 fn test_display() {
285 let expr = zip_expr(lit(true), root(), lit(0i32));
286 assert_eq!(expr.to_string(), "zip($, 0i32, true)");
287 }
288
    // Mixed mask: set positions take `if_true`, unset positions take `if_false`.
289 #[test]
290 fn test_zip_basic() {
291 let mask = Mask::from_iter([true, false, false, true, false]);
292 let if_true = buffer![10, 20, 30, 40, 50].into_array();
293 let if_false = buffer![1, 2, 3, 4, 5].into_array();
294
295 let result = mask.into_array().zip(if_true, if_false).unwrap();
296 let expected = buffer![10, 2, 3, 40, 5].into_array();
297
298 assert_arrays_eq!(result, expected);
299 }
300
    // All-true mask: every value comes from `if_true`, but the result dtype still matches
    // the nullable `if_false` dtype.
301 #[test]
302 fn test_zip_all_true() {
303 let mask = Mask::new_true(4);
304 let if_true = buffer![10, 20, 30, 40].into_array();
305 let if_false =
306 PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array();
307
308 let result = mask.into_array().zip(if_true, if_false.clone()).unwrap();
309 let expected =
310 PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), Some(40)]).into_array();
311
312 assert_arrays_eq!(result, expected);
313 assert_eq!(result.dtype(), if_false.dtype())
314 }
315
    // All-false mask: every value comes from `if_false`, but the result dtype still matches
    // the nullable `if_true` dtype.
316 #[test]
317 fn test_zip_all_false_widens_nullability() {
318 let mask = Mask::new_false(4);
319 let if_true =
320 PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), None]).into_array();
321 let if_false = buffer![1i32, 2, 3, 4].into_array();
322
323 let result = mask.into_array().zip(if_true.clone(), if_false).unwrap();
324 let expected =
325 PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), Some(4)]).into_array();
326
327 assert_arrays_eq!(result, expected);
328 assert_eq!(result.dtype(), if_true.dtype());
329 }
330
    // Same as the all-true case above, but calling `zip_impl` directly to cover its own
    // fast path.
331 #[test]
332 fn test_zip_impl_all_true_widens_nullability() -> VortexResult<()> {
333 let mask = Mask::new_true(4);
334 let if_true = buffer![10i32, 20, 30, 40].into_array();
335 let if_false =
336 PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array();
337
338 let mut ctx = LEGACY_SESSION.create_execution_ctx();
339 let result = zip_impl(&if_true, &if_false, &mask, &mut ctx)?;
340 assert_arrays_eq!(
341 result,
342 PrimitiveArray::from_option_iter([Some(10i32), Some(20), Some(30), Some(40)])
343 .into_array()
344 );
345 assert_eq!(result.dtype(), if_false.dtype());
346 Ok(())
347 }
348
    // Same as the all-false case above, but calling `zip_impl` directly to cover its own
    // fast path.
349 #[test]
350 fn test_zip_impl_all_false_widens_nullability() -> VortexResult<()> {
351 let mask = Mask::new_false(4);
352 let if_true =
353 PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), None]).into_array();
354 let if_false = buffer![1i32, 2, 3, 4].into_array();
355
356 let mut ctx = LEGACY_SESSION.create_execution_ctx();
357 let result = zip_impl(&if_true, &if_false, &mask, &mut ctx)?;
358 assert_arrays_eq!(
359 result,
360 PrimitiveArray::from_option_iter([Some(1i32), Some(2), Some(3), Some(4)]).into_array()
361 );
362 assert_eq!(result.dtype(), if_true.dtype());
363 Ok(())
364 }
365
    // `if_true` has 3 elements while `if_false` and the mask have 4; building the zip and
    // unwrapping the result is expected to panic.
366 #[test]
367 #[should_panic]
368 fn test_invalid_lengths() {
369 let mask = Mask::new_false(4);
370 let if_true = buffer![10, 20, 30].into_array();
371 let if_false = buffer![1, 2, 3, 4].into_array();
372
373 let _result = mask.into_array().zip(if_true, if_false).unwrap();
374 }
375
    // Zips two constant utf8 arrays with an alternating mask (every even index set).
376 #[test]
377 fn test_fragmentation() -> VortexResult<()> {
378 let len = 100;
379
380 let const1 = ConstantArray::new(
381 Scalar::utf8("hello_this_is_a_longer_string", Nullability::Nullable),
382 len,
383 )
384 .into_array();
385
386 let const2 = ConstantArray::new(
387 Scalar::utf8("world_this_is_another_string", Nullability::Nullable),
388 len,
389 )
390 .into_array();
391
392 let indices: Vec<usize> = (0..len).step_by(2).collect();
393 let mask = Mask::from_indices(len, indices);
394 let mask_array = mask.into_array();
395
396 let mut ctx = LEGACY_SESSION.create_execution_ctx();
397 let result = mask_array
398 .zip(const1.clone(), const2.clone())?
399 .execute::<Columnar>(&mut ctx)?
400 .into_array();
401
    // The snapshot pins the layout: one 29-byte and one 28-byte data buffer (the byte
    // lengths of the two constant strings) plus the views buffer.
402 insta::assert_snapshot!(result.display_tree(), @r"
403 root: vortex.varbinview(utf8?, len=100) nbytes=1.66 kB (100.00%) [all_valid]
404 metadata:
405 buffer: buffer_0 host 29 B (align=1) (1.75%)
406 buffer: buffer_1 host 28 B (align=1) (1.69%)
407 buffer: views host 1.60 kB (align=16) (96.56%)
408 ");
409
    // The same inputs wrapped in single-field structs must still produce a struct array.
410 let wrapped1 = StructArray::try_from_iter([("nested", const1)])?.into_array();
411 let wrapped2 = StructArray::try_from_iter([("nested", const2)])?.into_array();
412
413 let wrapped_result = mask_array
414 .zip(wrapped1, wrapped2)?
415 .execute::<ArrayRef>(&mut ctx)?;
416 assert!(wrapped_result.is::<Struct>());
417
418 Ok(())
419 }
420
    // Zips two 200-element utf8 view arrays that mix short and long strings, then checks
    // the result has two data buffers and equals Arrow's `zip` on the same inputs.
421 #[test]
422 fn test_varbinview_zip() {
423 let if_true = {
424 let mut builder = VarBinViewBuilder::new(
425 DType::Utf8(Nullability::NonNullable),
426 10,
427 Default::default(),
428 BufferGrowthStrategy::fixed(64 * 1024),
429 0.0,
430 );
431 for _ in 0..100 {
432 builder.append_value("Hello");
433 builder.append_value("Hello this is a long string that won't be inlined.");
434 }
435 builder.finish()
436 };
437
438 let if_false = {
439 let mut builder = VarBinViewBuilder::new(
440 DType::Utf8(Nullability::NonNullable),
441 10,
442 Default::default(),
443 BufferGrowthStrategy::fixed(64 * 1024),
444 0.0,
445 );
446 for _ in 0..100 {
447 builder.append_value("Hello2");
448 builder.append_value("Hello2 this is a long string that won't be inlined.");
449 }
450 builder.finish()
451 };
452
    // Only indices below 100 that are not multiples of 3 are set; the upper half of the
    // mask is entirely unset.
453 let mask = Mask::from_indices(200, (0..100).filter(|i| i % 3 != 0));
454 let mask_array = mask.clone().into_array();
455
456 let mut ctx = LEGACY_SESSION.create_execution_ctx();
457 let zipped = mask_array
458 .zip(if_true.clone(), if_false.clone())
459 .unwrap()
460 .execute::<ArrayRef>(&mut ctx)
461 .unwrap();
462 let zipped = zipped.as_opt::<VarBinView>().unwrap();
463 assert_eq!(zipped.data_buffers().len(), 2);
464
    // Reference result: convert all three inputs to Arrow and zip them there.
465 let mut arrow_ctx = LEGACY_SESSION.create_execution_ctx();
466 let expected = arrow_zip(
467 LEGACY_SESSION
468 .arrow()
469 .execute_arrow(mask.into_array(), None, &mut arrow_ctx)
470 .unwrap()
471 .as_boolean(),
472 &LEGACY_SESSION
473 .arrow()
474 .execute_arrow(if_true, None, &mut arrow_ctx)
475 .unwrap(),
476 &LEGACY_SESSION
477 .arrow()
478 .execute_arrow(if_false, None, &mut arrow_ctx)
479 .unwrap(),
480 )
481 .unwrap();
482
483 let actual = LEGACY_SESSION
484 .arrow()
485 .execute_arrow(zipped.array().clone(), None, &mut arrow_ctx)
486 .unwrap();
487 assert_eq!(actual.as_ref(), expected.as_ref());
488 }
489}