vortex_array/scalar_fn/fns/zip/
mod.rs1mod kernel;
5
6use std::fmt::Formatter;
7
8pub use kernel::*;
9use vortex_error::VortexResult;
10use vortex_error::vortex_ensure;
11use vortex_mask::AllOr;
12use vortex_mask::Mask;
13use vortex_session::VortexSession;
14
15use crate::Array;
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::IntoArray;
19use crate::arrays::BoolArray;
20use crate::builders::ArrayBuilder;
21use crate::builders::builder_with_capacity;
22use crate::builtins::ArrayBuiltins;
23use crate::dtype::DType;
24use crate::expr::Expression;
25use crate::scalar_fn::Arity;
26use crate::scalar_fn::ChildName;
27use crate::scalar_fn::EmptyOptions;
28use crate::scalar_fn::ExecutionArgs;
29use crate::scalar_fn::ScalarFnId;
30use crate::scalar_fn::ScalarFnVTable;
31use crate::scalar_fn::SimplifyCtx;
32use crate::scalar_fn::fns::literal::Literal;
33
/// Scalar function that selects, element-wise, between two equally-typed inputs.
///
/// Children, in order: `if_true`, `if_false`, and a boolean `mask`. Positions where the mask
/// is set are taken from `if_true`, all others from `if_false`.
#[derive(Clone, Debug)]
pub struct Zip;
43
44impl ScalarFnVTable for Zip {
45 type Options = EmptyOptions;
46
47 fn id(&self) -> ScalarFnId {
48 ScalarFnId::from("vortex.zip")
49 }
50
51 fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
52 Ok(Some(vec![]))
53 }
54
55 fn deserialize(
56 &self,
57 _metadata: &[u8],
58 _session: &VortexSession,
59 ) -> VortexResult<Self::Options> {
60 Ok(EmptyOptions)
61 }
62
63 fn arity(&self, _options: &Self::Options) -> Arity {
64 Arity::Exact(3)
65 }
66
67 fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
68 match child_idx {
69 0 => ChildName::from("if_true"),
70 1 => ChildName::from("if_false"),
71 2 => ChildName::from("mask"),
72 _ => unreachable!("Invalid child index {} for Zip expression", child_idx),
73 }
74 }
75
76 fn fmt_sql(
77 &self,
78 _options: &Self::Options,
79 expr: &Expression,
80 f: &mut Formatter<'_>,
81 ) -> std::fmt::Result {
82 write!(f, "zip(")?;
83 expr.child(0).fmt_sql(f)?;
84 write!(f, ", ")?;
85 expr.child(1).fmt_sql(f)?;
86 write!(f, ", ")?;
87 expr.child(2).fmt_sql(f)?;
88 write!(f, ")")
89 }
90
91 fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
92 vortex_ensure!(
93 arg_dtypes[0].eq_ignore_nullability(&arg_dtypes[1]),
94 "zip requires if_true and if_false to have the same base type, got {} and {}",
95 arg_dtypes[0],
96 arg_dtypes[1]
97 );
98 vortex_ensure!(
99 matches!(arg_dtypes[2], DType::Bool(_)),
100 "zip requires mask to be a boolean type, got {}",
101 arg_dtypes[2]
102 );
103 Ok(arg_dtypes[0]
104 .clone()
105 .union_nullability(arg_dtypes[1].nullability()))
106 }
107
108 fn execute(
109 &self,
110 _options: &Self::Options,
111 args: &dyn ExecutionArgs,
112 ctx: &mut ExecutionCtx,
113 ) -> VortexResult<ArrayRef> {
114 let if_true = args.get(0)?;
115 let if_false = args.get(1)?;
116 let mask_array = args.get(2)?;
117
118 let mask = mask_array.execute::<BoolArray>(ctx)?.to_mask();
119
120 let return_dtype = if_true
121 .dtype()
122 .clone()
123 .union_nullability(if_false.dtype().nullability());
124
125 if mask.all_true() {
126 return if_true.cast(return_dtype)?.execute(ctx);
127 }
128
129 let return_dtype = if_true
130 .dtype()
131 .clone()
132 .union_nullability(if_false.dtype().nullability());
133
134 if mask.all_false() {
135 return if_false.cast(return_dtype)?.execute(ctx);
136 }
137
138 if !if_true.is_canonical() || !if_false.is_canonical() {
139 let if_true = if_true.execute::<ArrayRef>(ctx)?;
140 let if_false = if_false.execute::<ArrayRef>(ctx)?;
141 return mask.into_array().zip(if_true, if_false);
142 }
143
144 zip_impl(&if_true, &if_false, &mask)
145 }
146
147 fn simplify(
148 &self,
149 _options: &Self::Options,
150 expr: &Expression,
151 _ctx: &dyn SimplifyCtx,
152 ) -> VortexResult<Option<Expression>> {
153 let Some(mask_lit) = expr.child(2).as_opt::<Literal>() else {
154 return Ok(None);
155 };
156
157 if let Some(mask_val) = mask_lit.as_bool().value() {
158 if mask_val {
159 return Ok(Some(expr.child(0).clone()));
160 } else {
161 return Ok(Some(expr.child(1).clone()));
162 }
163 }
164
165 Ok(None)
166 }
167
168 fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
169 true
170 }
171
172 fn is_fallible(&self, _options: &Self::Options) -> bool {
173 false
174 }
175}
176
177pub(crate) fn zip_impl(
178 if_true: &ArrayRef,
179 if_false: &ArrayRef,
180 mask: &Mask,
181) -> VortexResult<ArrayRef> {
182 assert_eq!(
183 if_true.len(),
184 if_false.len(),
185 "zip requires arrays to have the same size"
186 );
187
188 let return_type = if_true
189 .dtype()
190 .clone()
191 .union_nullability(if_false.dtype().nullability());
192 zip_impl_with_builder(
193 if_true,
194 if_false,
195 mask,
196 builder_with_capacity(&return_type, if_true.len()),
197 )
198}
199
200fn zip_impl_with_builder(
201 if_true: &ArrayRef,
202 if_false: &ArrayRef,
203 mask: &Mask,
204 mut builder: Box<dyn ArrayBuilder>,
205) -> VortexResult<ArrayRef> {
206 match mask.slices() {
207 AllOr::All => Ok(if_true.to_array()),
208 AllOr::None => Ok(if_false.to_array()),
209 AllOr::Some(slices) => {
210 for (start, end) in slices {
211 builder.extend_from_array(&if_false.slice(builder.len()..*start)?);
212 builder.extend_from_array(&if_true.slice(*start..*end)?);
213 }
214 if builder.len() < if_false.len() {
215 builder.extend_from_array(&if_false.slice(builder.len()..if_false.len())?);
216 }
217 Ok(builder.finish())
218 }
219 }
220}
221
#[cfg(test)]
mod tests {
    use arrow_array::cast::AsArray;
    use arrow_select::zip::zip as arrow_zip;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_mask::Mask;

    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::StructVTable;
    use crate::arrays::VarBinViewArray;
    use crate::arrow::IntoArrowArray;
    use crate::assert_arrays_eq;
    use crate::builders::ArrayBuilder;
    use crate::builders::BufferGrowthStrategy;
    use crate::builders::VarBinViewBuilder;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::lit;
    use crate::expr::root;
    use crate::expr::zip_expr;
    use crate::scalar::Scalar;

    // Two non-nullable i32 branches (root and a literal) produce a non-nullable i32.
    // `zip_expr` takes the mask first; see `test_display` for the resulting child order.
    #[test]
    fn dtype() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let expr = zip_expr(lit(true), root(), lit(0i32));
        let result_dtype = expr.return_dtype(&dtype).unwrap();
        assert_eq!(
            result_dtype,
            DType::Primitive(PType::I32, Nullability::NonNullable)
        );
    }

    // `zip_expr(mask, if_true, if_false)` prints its children as (if_true, if_false, mask).
    #[test]
    fn test_display() {
        let expr = zip_expr(lit(true), root(), lit(0i32));
        assert_eq!(expr.to_string(), "zip($, 0i32, true)");
    }

    // Mixed mask: set positions come from `if_true`, unset positions from `if_false`.
    #[test]
    fn test_zip_basic() {
        let mask = Mask::from_iter([true, false, false, true, false]);
        let if_true = buffer![10, 20, 30, 40, 50].into_array();
        let if_false = buffer![1, 2, 3, 4, 5].into_array();

        let result = mask.into_array().zip(if_true, if_false).unwrap();
        let expected = buffer![10, 2, 3, 40, 5].into_array();

        assert_arrays_eq!(result, expected);
    }

    // An all-true mask returns `if_true`'s values, but the result dtype still takes on the
    // nullability of the (nullable) `if_false` branch.
    #[test]
    fn test_zip_all_true() {
        let mask = Mask::new_true(4);
        let if_true = buffer![10, 20, 30, 40].into_array();
        let if_false =
            PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array();

        let result = mask.into_array().zip(if_true, if_false.clone()).unwrap();
        let expected =
            PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), Some(40)]).into_array();

        assert_arrays_eq!(result, expected);

        assert_eq!(result.dtype(), if_false.dtype())
    }

    // Branches of different lengths (3 vs 4) must be rejected.
    #[test]
    #[should_panic]
    fn test_invalid_lengths() {
        let mask = Mask::new_false(4);
        let if_true = buffer![10, 20, 30].into_array();
        let if_false = buffer![1, 2, 3, 4].into_array();

        let _result = mask.into_array().zip(if_true, if_false).unwrap();
    }

    // Zipping two constant utf8 arrays under an alternating mask. The snapshot checks that
    // the result holds each distinct long string once (buffers of 29 B and 28 B) rather than
    // one copy per element. The same zip over struct-wrapped inputs must stay a struct array.
    #[test]
    fn test_fragmentation() -> VortexResult<()> {
        let len = 100;

        let const1 = ConstantArray::new(
            Scalar::utf8("hello_this_is_a_longer_string", Nullability::Nullable),
            len,
        )
        .into_array();

        let const2 = ConstantArray::new(
            Scalar::utf8("world_this_is_another_string", Nullability::Nullable),
            len,
        )
        .into_array();

        // Every even index is set.
        let indices: Vec<usize> = (0..len).step_by(2).collect();
        let mask = Mask::from_indices(len, indices);
        let mask_array = mask.into_array();

        let result = mask_array
            .clone()
            .zip(const1.clone(), const2.clone())?
            .execute::<VarBinViewArray>(&mut LEGACY_SESSION.create_execution_ctx())?;

        insta::assert_snapshot!(result.display_tree(), @r"
        root: vortex.varbinview(utf8?, len=100) nbytes=1.66 kB (100.00%) [all_valid]
          metadata: EmptyMetadata
          buffer: buffer_0 host 29 B (align=1) (1.75%)
          buffer: buffer_1 host 28 B (align=1) (1.69%)
          buffer: views host 1.60 kB (align=16) (96.56%)
        ");

        let wrapped1 = StructArray::try_from_iter([("nested", const1)])?.into_array();
        let wrapped2 = StructArray::try_from_iter([("nested", const2)])?.into_array();

        let wrapped_result = mask_array
            .zip(wrapped1, wrapped2)?
            .execute::<ArrayRef>(&mut LEGACY_SESSION.create_execution_ctx())?;
        assert!(wrapped_result.is::<StructVTable>());

        Ok(())
    }

    // Zips two VarBinView arrays, each with one data buffer holding the non-inlined strings.
    // Checks that the result has exactly two buffers and matches arrow's `zip` on the same
    // inputs.
    #[test]
    fn test_varbinview_zip() {
        let if_true = {
            let mut builder = VarBinViewBuilder::new(
                DType::Utf8(Nullability::NonNullable),
                10,
                Default::default(),
                BufferGrowthStrategy::fixed(64 * 1024),
                0.0,
            );
            for _ in 0..100 {
                builder.append_value("Hello");
                builder.append_value("Hello this is a long string that won't be inlined.");
            }
            builder.finish()
        };

        let if_false = {
            let mut builder = VarBinViewBuilder::new(
                DType::Utf8(Nullability::NonNullable),
                10,
                Default::default(),
                BufferGrowthStrategy::fixed(64 * 1024),
                0.0,
            );
            for _ in 0..100 {
                builder.append_value("Hello2");
                builder.append_value("Hello2 this is a long string that won't be inlined.");
            }
            builder.finish()
        };

        // Only indices in 0..100 that are not multiples of 3 are set; 100..200 all come
        // from `if_false`.
        let mask = Mask::from_indices(200, (0..100).filter(|i| i % 3 != 0).collect());
        let mask_array = mask.clone().into_array();

        let zipped = mask_array
            .zip(if_true.clone(), if_false.clone())
            .unwrap()
            .execute::<VarBinViewArray>(&mut LEGACY_SESSION.create_execution_ctx())
            .unwrap();
        assert_eq!(zipped.nbuffers(), 2);

        // Reference result computed by arrow's `zip` kernel on the same inputs.
        let expected = arrow_zip(
            mask.into_array()
                .into_arrow_preferred()
                .unwrap()
                .as_boolean(),
            &if_true.into_arrow_preferred().unwrap(),
            &if_false.into_arrow_preferred().unwrap(),
        )
        .unwrap();

        let actual = zipped.into_array().into_arrow_preferred().unwrap();
        assert_eq!(actual.as_ref(), expected.as_ref());
    }
}