1use std::fmt;
14use std::fmt::Formatter;
15use std::hash::Hash;
16use std::sync::Arc;
17
18use prost::Message;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21use vortex_mask::AllOr;
22use vortex_mask::Mask;
23use vortex_proto::expr as pb;
24use vortex_session::VortexSession;
25
26use crate::ArrayRef;
27use crate::ExecutionCtx;
28use crate::IntoArray;
29use crate::arrays::BoolArray;
30use crate::arrays::ConstantArray;
31use crate::arrays::bool::BoolArrayExt;
32use crate::builders::ArrayBuilder;
33use crate::builders::builder_with_capacity;
34use crate::builtins::ArrayBuiltins;
35use crate::dtype::DType;
36use crate::expr::Expression;
37use crate::scalar::Scalar;
38use crate::scalar_fn::Arity;
39use crate::scalar_fn::ChildName;
40use crate::scalar_fn::ExecutionArgs;
41use crate::scalar_fn::ScalarFnId;
42use crate::scalar_fn::ScalarFnVTable;
43use crate::scalar_fn::fns::zip::zip_impl;
44
/// Shape of a `CASE WHEN` expression: how many WHEN/THEN pairs it carries and
/// whether a trailing ELSE child is present.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct CaseWhenOptions {
    /// Number of WHEN/THEN pairs; every pair contributes two children.
    pub num_when_then_pairs: u32,
    /// `true` when an ELSE child follows the WHEN/THEN pairs.
    pub has_else: bool,
}
54
55impl CaseWhenOptions {
56 pub fn num_children(&self) -> usize {
58 self.num_when_then_pairs as usize * 2 + usize::from(self.has_else)
59 }
60}
61
62impl fmt::Display for CaseWhenOptions {
63 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
64 write!(
65 f,
66 "case_when(pairs={}, else={})",
67 self.num_when_then_pairs, self.has_else
68 )
69 }
70}
71
/// Scalar function for SQL-style `CASE WHEN ... THEN ... [ELSE ...] END`.
///
/// Children are laid out as `when_0, then_0, when_1, then_1, ...` followed by
/// an optional `else` child, as described by [`CaseWhenOptions`].
#[derive(Clone)]
pub struct CaseWhen;
77
78impl ScalarFnVTable for CaseWhen {
79 type Options = CaseWhenOptions;
80
81 fn id(&self) -> ScalarFnId {
82 ScalarFnId::new("vortex.case_when")
83 }
84
    /// Serialization is not supported: this always fails with the error
    /// message `cannot serialize`, regardless of the options passed.
    ///
    /// NOTE(review): `deserialize` still decodes `pb::CaseWhenOpts`, so the
    /// asymmetry looks deliberate (the tests expect this failure); confirm
    /// before re-enabling.
    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
        vortex_bail!("cannot serialize")
    }
91
92 fn deserialize(
93 &self,
94 metadata: &[u8],
95 _session: &VortexSession,
96 ) -> VortexResult<Self::Options> {
97 let opts = pb::CaseWhenOpts::decode(metadata)?;
98 if opts.num_children < 2 {
99 vortex_bail!(
100 "CaseWhen expects at least 2 children, got {}",
101 opts.num_children
102 );
103 }
104 Ok(CaseWhenOptions {
105 num_when_then_pairs: opts.num_children / 2,
106 has_else: opts.num_children % 2 == 1,
107 })
108 }
109
110 fn arity(&self, options: &Self::Options) -> Arity {
111 Arity::Exact(options.num_children())
112 }
113
114 fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName {
115 let num_pair_children = options.num_when_then_pairs as usize * 2;
116 if child_idx < num_pair_children {
117 let pair_idx = child_idx / 2;
118 if child_idx.is_multiple_of(2) {
119 ChildName::from(Arc::from(format!("when_{pair_idx}")))
120 } else {
121 ChildName::from(Arc::from(format!("then_{pair_idx}")))
122 }
123 } else if options.has_else && child_idx == num_pair_children {
124 ChildName::from("else")
125 } else {
126 unreachable!("Invalid child index {} for CaseWhen", child_idx)
127 }
128 }
129
130 fn fmt_sql(
131 &self,
132 options: &Self::Options,
133 expr: &Expression,
134 f: &mut Formatter<'_>,
135 ) -> fmt::Result {
136 write!(f, "CASE")?;
137 for i in 0..options.num_when_then_pairs as usize {
138 write!(
139 f,
140 " WHEN {} THEN {}",
141 expr.child(i * 2),
142 expr.child(i * 2 + 1)
143 )?;
144 }
145 if options.has_else {
146 let else_idx = options.num_when_then_pairs as usize * 2;
147 write!(f, " ELSE {}", expr.child(else_idx))?;
148 }
149 write!(f, " END")
150 }
151
152 fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
153 if options.num_when_then_pairs == 0 {
154 vortex_bail!("CaseWhen must have at least one WHEN/THEN pair");
155 }
156
157 let expected_len = options.num_children();
158 if arg_dtypes.len() != expected_len {
159 vortex_bail!(
160 "CaseWhen expects {expected_len} argument dtypes, got {}",
161 arg_dtypes.len()
162 );
163 }
164
165 let first_then = &arg_dtypes[1];
169 let mut result_dtype = first_then.clone();
170
171 for i in 1..options.num_when_then_pairs as usize {
172 let then_i = &arg_dtypes[i * 2 + 1];
173 if !first_then.eq_ignore_nullability(then_i) {
174 vortex_bail!(
175 "CaseWhen THEN dtypes must match (ignoring nullability), got {} and {}",
176 first_then,
177 then_i
178 );
179 }
180 result_dtype = result_dtype.union_nullability(then_i.nullability());
181 }
182
183 if options.has_else {
184 let else_dtype = &arg_dtypes[options.num_when_then_pairs as usize * 2];
185 if !result_dtype.eq_ignore_nullability(else_dtype) {
186 vortex_bail!(
187 "CaseWhen THEN and ELSE dtypes must match (ignoring nullability), got {} and {}",
188 first_then,
189 else_dtype
190 );
191 }
192 result_dtype = result_dtype.union_nullability(else_dtype.nullability());
193 } else {
194 result_dtype = result_dtype.as_nullable();
196 }
197
198 Ok(result_dtype)
199 }
200
201 fn execute(
202 &self,
203 options: &Self::Options,
204 args: &dyn ExecutionArgs,
205 ctx: &mut ExecutionCtx,
206 ) -> VortexResult<ArrayRef> {
207 let row_count = args.row_count();
214 let num_pairs = options.num_when_then_pairs as usize;
215
216 let mut remaining = Mask::new_true(row_count);
217 let mut branches: Vec<(Mask, ArrayRef)> = Vec::with_capacity(num_pairs);
218
219 for i in 0..num_pairs {
220 if remaining.all_false() {
221 break;
222 }
223
224 let condition = args.get(i * 2)?;
225 let cond_bool = condition.execute::<BoolArray>(ctx)?;
226 let cond_mask = cond_bool.to_mask_fill_null_false(ctx);
227 let effective_mask = &remaining & &cond_mask;
228
229 if effective_mask.all_false() {
230 continue;
231 }
232
233 let then_value = args.get(i * 2 + 1)?;
234 remaining = remaining.bitand_not(&cond_mask);
235 branches.push((effective_mask, then_value));
236 }
237
238 let else_value: ArrayRef = if options.has_else {
239 args.get(num_pairs * 2)?
240 } else {
241 let then_dtype = args.get(1)?.dtype().as_nullable();
242 ConstantArray::new(Scalar::null(then_dtype), row_count).into_array()
243 };
244
245 if branches.is_empty() {
246 return Ok(else_value);
247 }
248
249 merge_case_branches(branches, else_value, ctx)
250 }
251
    /// Always reports `true`.
    ///
    /// NOTE(review): presumably because null WHEN conditions are treated as
    /// false rather than propagated; confirm against the trait documentation.
    fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
        true
    }
255
    /// Always reports `false`, independent of the options.
    fn is_fallible(&self, _options: &Self::Options) -> bool {
        false
    }
259}
260
/// Crossover point between the two merge strategies in `merge_case_branches`:
/// when the number of spans exceeds `row_count / SLICE_CROSSOVER_RUN_LEN`, the
/// merge goes row by row; otherwise it copies whole runs.
const SLICE_CROSSOVER_RUN_LEN: usize = 4;
264
265fn merge_case_branches(
269 branches: Vec<(Mask, ArrayRef)>,
270 else_value: ArrayRef,
271 ctx: &mut ExecutionCtx,
272) -> VortexResult<ArrayRef> {
273 if branches.len() == 1 {
274 let (mask, then_value) = &branches[0];
275 return zip_impl(then_value, &else_value, mask);
276 }
277
278 let output_nullability = branches
279 .iter()
280 .fold(else_value.dtype().nullability(), |acc, (_, arr)| {
281 acc | arr.dtype().nullability()
282 });
283 let output_dtype = else_value.dtype().with_nullability(output_nullability);
284 let branch_arrays: Vec<&ArrayRef> = branches.iter().map(|(_, arr)| arr).collect();
285
286 let mut spans: Vec<(usize, usize, usize)> = Vec::new();
287 for (branch_idx, (mask, _)) in branches.iter().enumerate() {
288 match mask.slices() {
289 AllOr::All => return branch_arrays[branch_idx].cast(output_dtype),
290 AllOr::None => {}
291 AllOr::Some(slices) => {
292 for &(start, end) in slices {
293 spans.push((start, end, branch_idx));
294 }
295 }
296 }
297 }
298 spans.sort_unstable_by_key(|&(start, ..)| start);
299
300 if spans.is_empty() {
301 return else_value.cast(output_dtype);
302 }
303
304 let builder = builder_with_capacity(&output_dtype, else_value.len());
305
306 let fragmented = spans.len() > else_value.len() / SLICE_CROSSOVER_RUN_LEN;
307 if fragmented {
308 merge_row_by_row(
309 &branch_arrays,
310 &else_value,
311 &spans,
312 &output_dtype,
313 builder,
314 ctx,
315 )
316 } else {
317 merge_run_by_run(&branch_arrays, &else_value, &spans, &output_dtype, builder)
318 }
319}
320
321fn merge_row_by_row(
324 branch_arrays: &[&ArrayRef],
325 else_value: &ArrayRef,
326 spans: &[(usize, usize, usize)],
327 output_dtype: &DType,
328 mut builder: Box<dyn ArrayBuilder>,
329 ctx: &mut ExecutionCtx,
330) -> VortexResult<ArrayRef> {
331 let mut pos = 0;
332 for &(start, end, branch_idx) in spans {
333 for row in pos..start {
334 let scalar = else_value.execute_scalar(row, ctx)?;
335 builder.append_scalar(&scalar.cast(output_dtype)?)?;
336 }
337 for row in start..end {
338 let scalar = branch_arrays[branch_idx].execute_scalar(row, ctx)?;
339 builder.append_scalar(&scalar.cast(output_dtype)?)?;
340 }
341 pos = end;
342 }
343 for row in pos..else_value.len() {
344 let scalar = else_value.execute_scalar(row, ctx)?;
345 builder.append_scalar(&scalar.cast(output_dtype)?)?;
346 }
347
348 Ok(builder.finish())
349}
350
351fn merge_run_by_run(
355 branch_arrays: &[&ArrayRef],
356 else_value: &ArrayRef,
357 spans: &[(usize, usize, usize)],
358 output_dtype: &DType,
359 mut builder: Box<dyn ArrayBuilder>,
360) -> VortexResult<ArrayRef> {
361 let else_value = else_value.cast(output_dtype.clone())?;
362 let len = else_value.len();
363 for (start, end, branch_idx) in spans {
364 if builder.len() < *start {
365 builder.extend_from_array(&else_value.slice(builder.len()..*start)?);
366 }
367 builder.extend_from_array(
368 &branch_arrays[*branch_idx]
369 .cast(output_dtype.clone())?
370 .slice(*start..*end)?,
371 );
372 }
373 if builder.len() < len {
374 builder.extend_from_array(&else_value.slice(builder.len()..len)?);
375 }
376
377 Ok(builder.finish())
378}
379
380#[cfg(test)]
381mod tests {
382 use std::sync::LazyLock;
383
384 use vortex_buffer::buffer;
385 use vortex_error::VortexExpect as _;
386 use vortex_session::VortexSession;
387
388 use super::*;
389 use crate::Canonical;
390 use crate::IntoArray;
391 use crate::LEGACY_SESSION;
392 use crate::VortexSessionExecute;
393 use crate::arrays::BoolArray;
394 use crate::arrays::PrimitiveArray;
395 use crate::arrays::StructArray;
396 use crate::assert_arrays_eq;
397 use crate::dtype::DType;
398 use crate::dtype::Nullability;
399 use crate::dtype::PType;
400 use crate::expr::case_when;
401 use crate::expr::case_when_no_else;
402 use crate::expr::col;
403 use crate::expr::eq;
404 use crate::expr::get_item;
405 use crate::expr::gt;
406 use crate::expr::lit;
407 use crate::expr::nested_case_when;
408 use crate::expr::root;
409 use crate::expr::test_harness;
410 use crate::scalar::Scalar;
411 use crate::session::ArraySession;
412
    /// Lazily initialized session shared by the tests: an empty
    /// `VortexSession` extended with `ArraySession`.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());
415
416 fn evaluate_expr(expr: &Expression, array: &ArrayRef) -> ArrayRef {
418 let mut ctx = SESSION.create_execution_ctx();
419 array
420 .clone()
421 .apply(expr)
422 .unwrap()
423 .execute::<Canonical>(&mut ctx)
424 .unwrap()
425 .into_array()
426 }
427
428 #[test]
431 #[should_panic(expected = "cannot serialize")]
432 fn test_serialization_roundtrip() {
433 let options = CaseWhenOptions {
434 num_when_then_pairs: 1,
435 has_else: true,
436 };
437 let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
438 let deserialized = CaseWhen
439 .deserialize(&serialized, &VortexSession::empty())
440 .unwrap();
441 assert_eq!(options, deserialized);
442 }
443
444 #[test]
445 #[should_panic(expected = "cannot serialize")]
446 fn test_serialization_no_else() {
447 let options = CaseWhenOptions {
448 num_when_then_pairs: 1,
449 has_else: false,
450 };
451 let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
452 let deserialized = CaseWhen
453 .deserialize(&serialized, &VortexSession::empty())
454 .unwrap();
455 assert_eq!(options, deserialized);
456 }
457
458 #[test]
461 fn test_display_with_else() {
462 let expr = case_when(gt(col("value"), lit(0i32)), lit(100i32), lit(0i32));
463 let display = format!("{}", expr);
464 assert!(display.contains("CASE"));
465 assert!(display.contains("WHEN"));
466 assert!(display.contains("THEN"));
467 assert!(display.contains("ELSE"));
468 assert!(display.contains("END"));
469 }
470
471 #[test]
472 fn test_display_no_else() {
473 let expr = case_when_no_else(gt(col("value"), lit(0i32)), lit(100i32));
474 let display = format!("{}", expr);
475 assert!(display.contains("CASE"));
476 assert!(display.contains("WHEN"));
477 assert!(display.contains("THEN"));
478 assert!(!display.contains("ELSE"));
479 assert!(display.contains("END"));
480 }
481
482 #[test]
483 fn test_display_nested_nary() {
484 let expr = nested_case_when(
486 vec![
487 (gt(col("x"), lit(10i32)), lit("high")),
488 (gt(col("x"), lit(5i32)), lit("medium")),
489 ],
490 Some(lit("low")),
491 );
492 let display = format!("{}", expr);
493 assert_eq!(display.matches("CASE").count(), 1);
494 assert_eq!(display.matches("WHEN").count(), 2);
495 assert_eq!(display.matches("THEN").count(), 2);
496 }
497
498 #[test]
501 fn test_return_dtype_with_else() {
502 let expr = case_when(lit(true), lit(100i32), lit(0i32));
503 let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
504 let result_dtype = expr.return_dtype(&input_dtype).unwrap();
505 assert_eq!(
506 result_dtype,
507 DType::Primitive(PType::I32, Nullability::NonNullable)
508 );
509 }
510
511 #[test]
512 fn test_return_dtype_with_nullable_else() {
513 let expr = case_when(
514 lit(true),
515 lit(100i32),
516 lit(Scalar::null(DType::Primitive(
517 PType::I32,
518 Nullability::Nullable,
519 ))),
520 );
521 let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
522 let result_dtype = expr.return_dtype(&input_dtype).unwrap();
523 assert_eq!(
524 result_dtype,
525 DType::Primitive(PType::I32, Nullability::Nullable)
526 );
527 }
528
529 #[test]
530 fn test_return_dtype_without_else_is_nullable() {
531 let expr = case_when_no_else(lit(true), lit(100i32));
532 let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
533 let result_dtype = expr.return_dtype(&input_dtype).unwrap();
534 assert_eq!(
535 result_dtype,
536 DType::Primitive(PType::I32, Nullability::Nullable)
537 );
538 }
539
540 #[test]
541 fn test_return_dtype_with_struct_input() {
542 let dtype = test_harness::struct_dtype();
543 let expr = case_when(
544 gt(get_item("col1", root()), lit(10u16)),
545 lit(100i32),
546 lit(0i32),
547 );
548 let result_dtype = expr.return_dtype(&dtype).unwrap();
549 assert_eq!(
550 result_dtype,
551 DType::Primitive(PType::I32, Nullability::NonNullable)
552 );
553 }
554
555 #[test]
556 fn test_return_dtype_mismatched_then_else_errors() {
557 let expr = case_when(lit(true), lit(100i32), lit("zero"));
558 let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
559 let err = expr.return_dtype(&input_dtype).unwrap_err();
560 assert!(
561 err.to_string()
562 .contains("THEN and ELSE dtypes must match (ignoring nullability)")
563 );
564 }
565
566 #[test]
569 fn test_arity_with_else() {
570 let options = CaseWhenOptions {
571 num_when_then_pairs: 1,
572 has_else: true,
573 };
574 assert_eq!(CaseWhen.arity(&options), Arity::Exact(3));
575 }
576
577 #[test]
578 fn test_arity_without_else() {
579 let options = CaseWhenOptions {
580 num_when_then_pairs: 1,
581 has_else: false,
582 };
583 assert_eq!(CaseWhen.arity(&options), Arity::Exact(2));
584 }
585
586 #[test]
589 fn test_child_names() {
590 let options = CaseWhenOptions {
591 num_when_then_pairs: 1,
592 has_else: true,
593 };
594 assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0");
595 assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0");
596 assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "else");
597 }
598
599 #[test]
602 #[should_panic(expected = "cannot serialize")]
603 fn test_serialization_roundtrip_nary() {
604 let options = CaseWhenOptions {
605 num_when_then_pairs: 3,
606 has_else: true,
607 };
608 let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
609 let deserialized = CaseWhen
610 .deserialize(&serialized, &VortexSession::empty())
611 .unwrap();
612 assert_eq!(options, deserialized);
613 }
614
615 #[test]
616 #[should_panic(expected = "cannot serialize")]
617 fn test_serialization_roundtrip_nary_no_else() {
618 let options = CaseWhenOptions {
619 num_when_then_pairs: 4,
620 has_else: false,
621 };
622 let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
623 let deserialized = CaseWhen
624 .deserialize(&serialized, &VortexSession::empty())
625 .unwrap();
626 assert_eq!(options, deserialized);
627 }
628
629 #[test]
632 fn test_arity_nary_with_else() {
633 let options = CaseWhenOptions {
634 num_when_then_pairs: 3,
635 has_else: true,
636 };
637 assert_eq!(CaseWhen.arity(&options), Arity::Exact(7));
639 }
640
641 #[test]
642 fn test_arity_nary_without_else() {
643 let options = CaseWhenOptions {
644 num_when_then_pairs: 3,
645 has_else: false,
646 };
647 assert_eq!(CaseWhen.arity(&options), Arity::Exact(6));
649 }
650
651 #[test]
654 fn test_child_names_nary() {
655 let options = CaseWhenOptions {
656 num_when_then_pairs: 3,
657 has_else: true,
658 };
659 assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0");
660 assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0");
661 assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "when_1");
662 assert_eq!(CaseWhen.child_name(&options, 3).to_string(), "then_1");
663 assert_eq!(CaseWhen.child_name(&options, 4).to_string(), "when_2");
664 assert_eq!(CaseWhen.child_name(&options, 5).to_string(), "then_2");
665 assert_eq!(CaseWhen.child_name(&options, 6).to_string(), "else");
666 }
667
668 #[test]
671 fn test_return_dtype_nary_mismatched_then_types_errors() {
672 let expr = nested_case_when(
673 vec![(lit(true), lit(100i32)), (lit(false), lit("oops"))],
674 Some(lit(0i32)),
675 );
676 let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
677 let err = expr.return_dtype(&input_dtype).unwrap_err();
678 assert!(err.to_string().contains("THEN dtypes must match"));
679 }
680
681 #[test]
682 fn test_return_dtype_nary_mixed_nullability() {
683 let non_null_then = lit(100i32);
686 let nullable_then = lit(Scalar::null(DType::Primitive(
687 PType::I32,
688 Nullability::Nullable,
689 )));
690 let expr = nested_case_when(
691 vec![(lit(true), non_null_then), (lit(false), nullable_then)],
692 Some(lit(0i32)),
693 );
694 let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
695 let result = expr.return_dtype(&input_dtype).unwrap();
696 assert_eq!(result, DType::Primitive(PType::I32, Nullability::Nullable));
697 }
698
699 #[test]
700 fn test_return_dtype_nary_no_else_is_nullable() {
701 let expr = nested_case_when(
702 vec![(lit(true), lit(10i32)), (lit(false), lit(20i32))],
703 None,
704 );
705 let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
706 let result = expr.return_dtype(&input_dtype).unwrap();
707 assert_eq!(result, DType::Primitive(PType::I32, Nullability::Nullable));
708 }
709
710 #[test]
713 fn test_replace_children() {
714 let expr = case_when(lit(true), lit(1i32), lit(0i32));
715 expr.with_children([lit(false), lit(2i32), lit(3i32)])
716 .vortex_expect("operation should succeed in test");
717 }
718
719 #[test]
722 fn test_evaluate_simple_condition() {
723 let test_array =
724 StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
725 .unwrap()
726 .into_array();
727
728 let expr = case_when(
729 gt(get_item("value", root()), lit(2i32)),
730 lit(100i32),
731 lit(0i32),
732 );
733
734 let result = evaluate_expr(&expr, &test_array);
735 assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array());
736 }
737
738 #[test]
739 fn test_evaluate_nary_multiple_conditions() {
740 let test_array =
742 StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
743 .unwrap()
744 .into_array();
745
746 let expr = nested_case_when(
747 vec![
748 (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
749 (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
750 ],
751 Some(lit(0i32)),
752 );
753
754 let result = evaluate_expr(&expr, &test_array);
755 assert_arrays_eq!(result, buffer![10i32, 0, 30, 0, 0].into_array());
756 }
757
758 #[test]
759 fn test_evaluate_nary_first_match_wins() {
760 let test_array =
761 StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
762 .unwrap()
763 .into_array();
764
765 let expr = nested_case_when(
767 vec![
768 (gt(get_item("value", root()), lit(2i32)), lit(100i32)),
769 (gt(get_item("value", root()), lit(3i32)), lit(200i32)),
770 ],
771 Some(lit(0i32)),
772 );
773
774 let result = evaluate_expr(&expr, &test_array);
775 assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array());
776 }
777
778 #[test]
779 fn test_evaluate_no_else_returns_null() {
780 let test_array =
781 StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
782 .unwrap()
783 .into_array();
784
785 let expr = case_when_no_else(gt(get_item("value", root()), lit(3i32)), lit(100i32));
786
787 let result = evaluate_expr(&expr, &test_array);
788 assert!(result.dtype().is_nullable());
789 assert_arrays_eq!(
790 result,
791 PrimitiveArray::from_option_iter([None::<i32>, None, None, Some(100), Some(100)])
792 .into_array()
793 );
794 }
795
796 #[test]
797 fn test_evaluate_all_conditions_false() {
798 let test_array =
799 StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
800 .unwrap()
801 .into_array();
802
803 let expr = case_when(
804 gt(get_item("value", root()), lit(100i32)),
805 lit(1i32),
806 lit(0i32),
807 );
808
809 let result = evaluate_expr(&expr, &test_array);
810 assert_arrays_eq!(result, buffer![0i32, 0, 0, 0, 0].into_array());
811 }
812
813 #[test]
814 fn test_evaluate_all_conditions_true() {
815 let test_array =
816 StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
817 .unwrap()
818 .into_array();
819
820 let expr = case_when(
821 gt(get_item("value", root()), lit(0i32)),
822 lit(100i32),
823 lit(0i32),
824 );
825
826 let result = evaluate_expr(&expr, &test_array);
827 assert_arrays_eq!(result, buffer![100i32, 100, 100, 100, 100].into_array());
828 }
829
830 #[test]
831 fn test_evaluate_all_true_no_else_returns_correct_dtype() {
832 let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
835 .unwrap()
836 .into_array();
837
838 let expr = case_when_no_else(gt(get_item("value", root()), lit(0i32)), lit(100i32));
839
840 let result = evaluate_expr(&expr, &test_array);
841 assert!(
842 result.dtype().is_nullable(),
843 "result dtype must be Nullable, got {:?}",
844 result.dtype()
845 );
846 assert_arrays_eq!(
847 result,
848 PrimitiveArray::from_option_iter([Some(100i32), Some(100), Some(100)]).into_array()
849 );
850 }
851
852 #[test]
853 fn test_merge_case_branches_widens_nullability_of_later_branch() -> VortexResult<()> {
854 let test_array =
862 StructArray::from_fields(&[("value", buffer![0i32, 1, 2].into_array())])?.into_array();
863
864 let nullable_20 =
865 Scalar::from(20i32).cast(&DType::Primitive(PType::I32, Nullability::Nullable))?;
866
867 let expr = nested_case_when(
868 vec![
869 (eq(get_item("value", root()), lit(0i32)), lit(10i32)),
870 (eq(get_item("value", root()), lit(1i32)), lit(nullable_20)),
871 ],
872 Some(lit(0i32)),
873 );
874
875 let result = evaluate_expr(&expr, &test_array);
876 assert!(
877 result.dtype().is_nullable(),
878 "result dtype must be Nullable, got {:?}",
879 result.dtype()
880 );
881 assert_arrays_eq!(
882 result,
883 PrimitiveArray::from_option_iter([Some(10), Some(20), Some(0)]).into_array()
884 );
885 Ok(())
886 }
887
888 #[test]
889 fn test_evaluate_with_literal_condition() {
890 let test_array = buffer![1i32, 2, 3].into_array();
891 let expr = case_when(lit(true), lit(100i32), lit(0i32));
892 let result = evaluate_expr(&expr, &test_array);
893
894 assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array());
895 }
896
897 #[test]
898 fn test_evaluate_with_bool_column_result() {
899 let test_array =
900 StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
901 .unwrap()
902 .into_array();
903
904 let expr = case_when(
905 gt(get_item("value", root()), lit(2i32)),
906 lit(true),
907 lit(false),
908 );
909
910 let result = evaluate_expr(&expr, &test_array);
911 assert_arrays_eq!(
912 result,
913 BoolArray::from_iter([false, false, true, true, true]).into_array()
914 );
915 }
916
917 #[test]
918 fn test_evaluate_with_nullable_condition() {
919 let test_array = StructArray::from_fields(&[(
920 "cond",
921 BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)]).into_array(),
922 )])
923 .unwrap()
924 .into_array();
925
926 let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));
927
928 let result = evaluate_expr(&expr, &test_array);
929 assert_arrays_eq!(result, buffer![100i32, 0, 0, 0, 100].into_array());
930 }
931
932 #[test]
933 fn test_evaluate_with_nullable_result_values() {
934 let test_array = StructArray::from_fields(&[
935 ("value", buffer![1i32, 2, 3, 4, 5].into_array()),
936 (
937 "result",
938 PrimitiveArray::from_option_iter([Some(10), None, Some(30), Some(40), Some(50)])
939 .into_array(),
940 ),
941 ])
942 .unwrap()
943 .into_array();
944
945 let expr = case_when(
946 gt(get_item("value", root()), lit(2i32)),
947 get_item("result", root()),
948 lit(0i32),
949 );
950
951 let result = evaluate_expr(&expr, &test_array);
952 assert_arrays_eq!(
953 result,
954 PrimitiveArray::from_option_iter([Some(0i32), Some(0), Some(30), Some(40), Some(50)])
955 .into_array()
956 );
957 }
958
959 #[test]
960 fn test_evaluate_with_all_null_condition() {
961 let test_array = StructArray::from_fields(&[(
962 "cond",
963 BoolArray::from_iter([None, None, None]).into_array(),
964 )])
965 .unwrap()
966 .into_array();
967
968 let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));
969
970 let result = evaluate_expr(&expr, &test_array);
971 assert_arrays_eq!(result, buffer![0i32, 0, 0].into_array());
972 }
973
974 #[test]
977 fn test_evaluate_nary_no_else_returns_null() {
978 let test_array =
979 StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
980 .unwrap()
981 .into_array();
982
983 let expr = nested_case_when(
985 vec![
986 (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
987 (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
988 ],
989 None,
990 );
991
992 let result = evaluate_expr(&expr, &test_array);
993 assert!(result.dtype().is_nullable());
994 assert_arrays_eq!(
995 result,
996 PrimitiveArray::from_option_iter([Some(10i32), None, Some(30), None, None])
997 .into_array()
998 );
999 }
1000
1001 #[test]
1002 fn test_evaluate_nary_many_conditions() {
1003 let test_array =
1004 StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
1005 .unwrap()
1006 .into_array();
1007
1008 let expr = nested_case_when(
1010 vec![
1011 (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
1012 (eq(get_item("value", root()), lit(2i32)), lit(20i32)),
1013 (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
1014 (eq(get_item("value", root()), lit(4i32)), lit(40i32)),
1015 (eq(get_item("value", root()), lit(5i32)), lit(50i32)),
1016 ],
1017 Some(lit(0i32)),
1018 );
1019
1020 let result = evaluate_expr(&expr, &test_array);
1021 assert_arrays_eq!(result, buffer![10i32, 20, 30, 40, 50].into_array());
1022 }
1023
1024 #[test]
1025 fn test_evaluate_nary_all_false_no_else() {
1026 let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
1027 .unwrap()
1028 .into_array();
1029
1030 let expr = nested_case_when(
1032 vec![
1033 (gt(get_item("value", root()), lit(100i32)), lit(10i32)),
1034 (gt(get_item("value", root()), lit(200i32)), lit(20i32)),
1035 ],
1036 None,
1037 );
1038
1039 let result = evaluate_expr(&expr, &test_array);
1040 assert!(result.dtype().is_nullable());
1041 assert_arrays_eq!(
1042 result,
1043 PrimitiveArray::from_option_iter([None::<i32>, None, None]).into_array()
1044 );
1045 }
1046
1047 #[test]
1048 fn test_evaluate_nary_overlapping_conditions_first_wins() {
1049 let test_array =
1050 StructArray::from_fields(&[("value", buffer![10i32, 20, 30].into_array())])
1051 .unwrap()
1052 .into_array();
1053
1054 let expr = nested_case_when(
1058 vec![
1059 (gt(get_item("value", root()), lit(5i32)), lit(1i32)),
1060 (gt(get_item("value", root()), lit(0i32)), lit(2i32)),
1061 (gt(get_item("value", root()), lit(15i32)), lit(3i32)),
1062 ],
1063 Some(lit(0i32)),
1064 );
1065
1066 let result = evaluate_expr(&expr, &test_array);
1067 assert_arrays_eq!(result, buffer![1i32, 1, 1].into_array());
1069 }
1070
1071 #[test]
1072 fn test_evaluate_nary_early_exit_when_remaining_empty() {
1073 let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
1076 .unwrap()
1077 .into_array();
1078
1079 let expr = nested_case_when(
1080 vec![
1081 (gt(get_item("value", root()), lit(0i32)), lit(100i32)),
1082 (gt(get_item("value", root()), lit(0i32)), lit(999i32)),
1084 ],
1085 Some(lit(0i32)),
1086 );
1087
1088 let result = evaluate_expr(&expr, &test_array);
1089 assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array());
1090 }
1091
1092 #[test]
1093 fn test_evaluate_nary_skips_branch_with_empty_effective_mask() {
1094 let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
1097 .unwrap()
1098 .into_array();
1099
1100 let expr = nested_case_when(
1101 vec![
1102 (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
1103 (eq(get_item("value", root()), lit(1i32)), lit(999i32)),
1106 (eq(get_item("value", root()), lit(2i32)), lit(20i32)),
1107 ],
1108 Some(lit(0i32)),
1109 );
1110
1111 let result = evaluate_expr(&expr, &test_array);
1112 assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array());
1113 }
1114
1115 #[test]
1116 fn test_evaluate_nary_string_output() -> VortexResult<()> {
1117 let test_array =
1119 StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4].into_array())])?
1120 .into_array();
1121
1122 let expr = nested_case_when(
1126 vec![
1127 (gt(get_item("value", root()), lit(2i32)), lit("high")),
1128 (gt(get_item("value", root()), lit(0i32)), lit("low")),
1129 ],
1130 Some(lit("none")),
1131 );
1132
1133 let result = evaluate_expr(&expr, &test_array);
1134 assert_eq!(
1135 result.execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())?,
1136 Scalar::utf8("low", Nullability::NonNullable)
1137 );
1138 assert_eq!(
1139 result.execute_scalar(1, &mut LEGACY_SESSION.create_execution_ctx())?,
1140 Scalar::utf8("low", Nullability::NonNullable)
1141 );
1142 assert_eq!(
1143 result.execute_scalar(2, &mut LEGACY_SESSION.create_execution_ctx())?,
1144 Scalar::utf8("high", Nullability::NonNullable)
1145 );
1146 assert_eq!(
1147 result.execute_scalar(3, &mut LEGACY_SESSION.create_execution_ctx())?,
1148 Scalar::utf8("high", Nullability::NonNullable)
1149 );
1150 Ok(())
1151 }
1152
1153 #[test]
1154 fn test_evaluate_nary_with_nullable_conditions() {
1155 let test_array = StructArray::from_fields(&[
1156 (
1157 "cond1",
1158 BoolArray::from_iter([Some(true), None, Some(false)]).into_array(),
1159 ),
1160 (
1161 "cond2",
1162 BoolArray::from_iter([Some(false), Some(true), None]).into_array(),
1163 ),
1164 ])
1165 .unwrap()
1166 .into_array();
1167
1168 let expr = nested_case_when(
1169 vec![
1170 (get_item("cond1", root()), lit(10i32)),
1171 (get_item("cond2", root()), lit(20i32)),
1172 ],
1173 Some(lit(0i32)),
1174 );
1175
1176 let result = evaluate_expr(&expr, &test_array);
1177 assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array());
1181 }
1182
1183 #[test]
1184 fn test_merge_case_branches_alternating_mask() -> VortexResult<()> {
1185 let n = 100usize;
1188
1189 let branch0_mask = Mask::from_indices(n, (0..n).step_by(2).collect());
1191 let branch1_mask = Mask::from_indices(n, (1..n).step_by(2).collect());
1192
1193 let result = merge_case_branches(
1194 vec![
1195 (
1196 branch0_mask,
1197 PrimitiveArray::from_option_iter(vec![Some(0i32); n]).into_array(),
1198 ),
1199 (
1200 branch1_mask,
1201 PrimitiveArray::from_option_iter(vec![Some(1i32); n]).into_array(),
1202 ),
1203 ],
1204 PrimitiveArray::from_option_iter(vec![Some(99i32); n]).into_array(),
1205 &mut SESSION.create_execution_ctx(),
1206 )?;
1207
1208 let expected: Vec<Option<i32>> = (0..n)
1210 .map(|v| if v % 2 == 0 { Some(0) } else { Some(1) })
1211 .collect();
1212 assert_arrays_eq!(
1213 result,
1214 PrimitiveArray::from_option_iter(expected).into_array()
1215 );
1216 Ok(())
1217 }
1218}