1use std::fmt;
14use std::fmt::Formatter;
15use std::hash::Hash;
16use std::sync::Arc;
17
18use prost::Message;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21use vortex_mask::AllOr;
22use vortex_mask::Mask;
23use vortex_proto::expr as pb;
24use vortex_session::VortexSession;
25
26use crate::ArrayRef;
27use crate::ExecutionCtx;
28use crate::IntoArray;
29use crate::arrays::BoolArray;
30use crate::arrays::ConstantArray;
31use crate::arrays::bool::BoolArrayExt;
32use crate::builders::ArrayBuilder;
33use crate::builders::builder_with_capacity;
34use crate::builtins::ArrayBuiltins;
35use crate::dtype::DType;
36use crate::expr::Expression;
37use crate::scalar::Scalar;
38use crate::scalar_fn::Arity;
39use crate::scalar_fn::ChildName;
40use crate::scalar_fn::ExecutionArgs;
41use crate::scalar_fn::ScalarFnId;
42use crate::scalar_fn::ScalarFnVTable;
43use crate::scalar_fn::fns::zip::zip_impl;
44
/// Options for the `CaseWhen` scalar function.
///
/// Children are laid out as `num_when_then_pairs` consecutive `(WHEN, THEN)` pairs,
/// optionally followed by a single trailing ELSE child.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CaseWhenOptions {
    /// Number of `(WHEN, THEN)` child pairs.
    pub num_when_then_pairs: u32,
    /// Whether a trailing ELSE child is present.
    pub has_else: bool,
}

impl CaseWhenOptions {
    /// Total child count: two per WHEN/THEN pair, plus one when an ELSE is present.
    pub fn num_children(&self) -> usize {
        let pair_children = self.num_when_then_pairs as usize * 2;
        if self.has_else {
            pair_children + 1
        } else {
            pair_children
        }
    }
}

impl fmt::Display for CaseWhenOptions {
    /// Renders as `case_when(pairs=<count>, else=<bool>)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self {
            num_when_then_pairs: pairs,
            has_else,
        } = *self;
        write!(f, "case_when(pairs={pairs}, else={has_else})")
    }
}
71
/// The `CASE WHEN ... THEN ... [ELSE ...] END` scalar function.
///
/// WHEN conditions are evaluated in order and each row takes the THEN value of the
/// first condition that is true for it (null conditions count as false). Rows that
/// match no condition take the ELSE value, or null when there is no ELSE.
#[derive(Clone)]
pub struct CaseWhen;
77
78impl ScalarFnVTable for CaseWhen {
79 type Options = CaseWhenOptions;
80
81 fn id(&self) -> ScalarFnId {
82 ScalarFnId::new("vortex.case_when")
83 }
84
85 fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
86 vortex_bail!("cannot serialize")
90 }
91
92 fn deserialize(
93 &self,
94 metadata: &[u8],
95 _session: &VortexSession,
96 ) -> VortexResult<Self::Options> {
97 let opts = pb::CaseWhenOpts::decode(metadata)?;
98 if opts.num_children < 2 {
99 vortex_bail!(
100 "CaseWhen expects at least 2 children, got {}",
101 opts.num_children
102 );
103 }
104 Ok(CaseWhenOptions {
105 num_when_then_pairs: opts.num_children / 2,
106 has_else: opts.num_children % 2 == 1,
107 })
108 }
109
110 fn arity(&self, options: &Self::Options) -> Arity {
111 Arity::Exact(options.num_children())
112 }
113
114 fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName {
115 let num_pair_children = options.num_when_then_pairs as usize * 2;
116 if child_idx < num_pair_children {
117 let pair_idx = child_idx / 2;
118 if child_idx.is_multiple_of(2) {
119 ChildName::from(Arc::from(format!("when_{pair_idx}")))
120 } else {
121 ChildName::from(Arc::from(format!("then_{pair_idx}")))
122 }
123 } else if options.has_else && child_idx == num_pair_children {
124 ChildName::from("else")
125 } else {
126 unreachable!("Invalid child index {} for CaseWhen", child_idx)
127 }
128 }
129
130 fn fmt_sql(
131 &self,
132 options: &Self::Options,
133 expr: &Expression,
134 f: &mut Formatter<'_>,
135 ) -> fmt::Result {
136 write!(f, "CASE")?;
137 for i in 0..options.num_when_then_pairs as usize {
138 write!(
139 f,
140 " WHEN {} THEN {}",
141 expr.child(i * 2),
142 expr.child(i * 2 + 1)
143 )?;
144 }
145 if options.has_else {
146 let else_idx = options.num_when_then_pairs as usize * 2;
147 write!(f, " ELSE {}", expr.child(else_idx))?;
148 }
149 write!(f, " END")
150 }
151
152 fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
153 if options.num_when_then_pairs == 0 {
154 vortex_bail!("CaseWhen must have at least one WHEN/THEN pair");
155 }
156
157 let expected_len = options.num_children();
158 if arg_dtypes.len() != expected_len {
159 vortex_bail!(
160 "CaseWhen expects {expected_len} argument dtypes, got {}",
161 arg_dtypes.len()
162 );
163 }
164
165 let first_then = &arg_dtypes[1];
169 let mut result_dtype = first_then.clone();
170
171 for i in 1..options.num_when_then_pairs as usize {
172 let then_i = &arg_dtypes[i * 2 + 1];
173 if !first_then.eq_ignore_nullability(then_i) {
174 vortex_bail!(
175 "CaseWhen THEN dtypes must match (ignoring nullability), got {} and {}",
176 first_then,
177 then_i
178 );
179 }
180 result_dtype = result_dtype.union_nullability(then_i.nullability());
181 }
182
183 if options.has_else {
184 let else_dtype = &arg_dtypes[options.num_when_then_pairs as usize * 2];
185 if !result_dtype.eq_ignore_nullability(else_dtype) {
186 vortex_bail!(
187 "CaseWhen THEN and ELSE dtypes must match (ignoring nullability), got {} and {}",
188 first_then,
189 else_dtype
190 );
191 }
192 result_dtype = result_dtype.union_nullability(else_dtype.nullability());
193 } else {
194 result_dtype = result_dtype.as_nullable();
196 }
197
198 Ok(result_dtype)
199 }
200
201 fn execute(
202 &self,
203 options: &Self::Options,
204 args: &dyn ExecutionArgs,
205 ctx: &mut ExecutionCtx,
206 ) -> VortexResult<ArrayRef> {
207 let row_count = args.row_count();
214 let num_pairs = options.num_when_then_pairs as usize;
215
216 let mut remaining = Mask::new_true(row_count);
217 let mut branches: Vec<(Mask, ArrayRef)> = Vec::with_capacity(num_pairs);
218
219 for i in 0..num_pairs {
220 if remaining.all_false() {
221 break;
222 }
223
224 let condition = args.get(i * 2)?;
225 let cond_bool = condition.execute::<BoolArray>(ctx)?;
226 let cond_mask = cond_bool.to_mask_fill_null_false(ctx);
227 let effective_mask = &remaining & &cond_mask;
228
229 if effective_mask.all_false() {
230 continue;
231 }
232
233 let then_value = args.get(i * 2 + 1)?;
234 remaining = remaining.bitand_not(&cond_mask);
235 branches.push((effective_mask, then_value));
236 }
237
238 let else_value: ArrayRef = if options.has_else {
239 args.get(num_pairs * 2)?
240 } else {
241 let then_dtype = args.get(1)?.dtype().as_nullable();
242 ConstantArray::new(Scalar::null(then_dtype), row_count).into_array()
243 };
244
245 if branches.is_empty() {
246 return Ok(else_value);
247 }
248
249 merge_case_branches(branches, else_value, ctx)
250 }
251
252 fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
253 true
254 }
255
256 fn is_fallible(&self, _options: &Self::Options) -> bool {
257 false
258 }
259}
260
/// Crossover used by `merge_case_branches` to choose a merge strategy.
///
/// When there is more than one span per `SLICE_CROSSOVER_RUN_LEN` output rows
/// (`spans.len() > len / SLICE_CROSSOVER_RUN_LEN`), the output is assembled row by row.
/// Otherwise it is assembled by appending whole slices.
const SLICE_CROSSOVER_RUN_LEN: usize = 4;
264
265fn merge_case_branches(
269 branches: Vec<(Mask, ArrayRef)>,
270 else_value: ArrayRef,
271 ctx: &mut ExecutionCtx,
272) -> VortexResult<ArrayRef> {
273 if branches.len() == 1 {
274 let (mask, then_value) = &branches[0];
275 return zip_impl(then_value, &else_value, mask, ctx);
276 }
277
278 let output_nullability = branches
279 .iter()
280 .fold(else_value.dtype().nullability(), |acc, (_, arr)| {
281 acc | arr.dtype().nullability()
282 });
283 let output_dtype = else_value.dtype().with_nullability(output_nullability);
284 let branch_arrays: Vec<&ArrayRef> = branches.iter().map(|(_, arr)| arr).collect();
285
286 let mut spans: Vec<(usize, usize, usize)> = Vec::new();
287 for (branch_idx, (mask, _)) in branches.iter().enumerate() {
288 match mask.slices() {
289 AllOr::All => return branch_arrays[branch_idx].cast(output_dtype),
290 AllOr::None => {}
291 AllOr::Some(slices) => {
292 for &(start, end) in slices {
293 spans.push((start, end, branch_idx));
294 }
295 }
296 }
297 }
298 spans.sort_unstable_by_key(|&(start, ..)| start);
299
300 if spans.is_empty() {
301 return else_value.cast(output_dtype);
302 }
303
304 let builder = builder_with_capacity(&output_dtype, else_value.len());
305
306 let fragmented = spans.len() > else_value.len() / SLICE_CROSSOVER_RUN_LEN;
307 if fragmented {
308 merge_row_by_row(
309 &branch_arrays,
310 &else_value,
311 &spans,
312 &output_dtype,
313 builder,
314 ctx,
315 )
316 } else {
317 merge_run_by_run(
318 &branch_arrays,
319 &else_value,
320 &spans,
321 &output_dtype,
322 builder,
323 ctx,
324 )
325 }
326}
327
328fn merge_row_by_row(
331 branch_arrays: &[&ArrayRef],
332 else_value: &ArrayRef,
333 spans: &[(usize, usize, usize)],
334 output_dtype: &DType,
335 mut builder: Box<dyn ArrayBuilder>,
336 ctx: &mut ExecutionCtx,
337) -> VortexResult<ArrayRef> {
338 let mut pos = 0;
339 for &(start, end, branch_idx) in spans {
340 for row in pos..start {
341 let scalar = else_value.execute_scalar(row, ctx)?;
342 builder.append_scalar(&scalar.cast(output_dtype)?)?;
343 }
344 for row in start..end {
345 let scalar = branch_arrays[branch_idx].execute_scalar(row, ctx)?;
346 builder.append_scalar(&scalar.cast(output_dtype)?)?;
347 }
348 pos = end;
349 }
350 for row in pos..else_value.len() {
351 let scalar = else_value.execute_scalar(row, ctx)?;
352 builder.append_scalar(&scalar.cast(output_dtype)?)?;
353 }
354
355 Ok(builder.finish())
356}
357
358fn merge_run_by_run(
362 branch_arrays: &[&ArrayRef],
363 else_value: &ArrayRef,
364 spans: &[(usize, usize, usize)],
365 output_dtype: &DType,
366 mut builder: Box<dyn ArrayBuilder>,
367 ctx: &mut ExecutionCtx,
368) -> VortexResult<ArrayRef> {
369 let else_value = else_value.cast(output_dtype.clone())?;
370 let len = else_value.len();
371 for (start, end, branch_idx) in spans {
372 if builder.len() < *start {
373 else_value
374 .slice(builder.len()..*start)?
375 .append_to_builder(builder.as_mut(), ctx)?;
376 }
377 branch_arrays[*branch_idx]
378 .cast(output_dtype.clone())?
379 .slice(*start..*end)?
380 .append_to_builder(builder.as_mut(), ctx)?;
381 }
382 if builder.len() < len {
383 else_value
384 .slice(builder.len()..len)?
385 .append_to_builder(builder.as_mut(), ctx)?;
386 }
387
388 Ok(builder.finish())
389}
390
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use vortex_buffer::buffer;
    use vortex_error::VortexExpect as _;
    use vortex_session::VortexSession;

    use super::*;
    use crate::Canonical;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::case_when;
    use crate::expr::case_when_no_else;
    use crate::expr::col;
    use crate::expr::eq;
    use crate::expr::get_item;
    use crate::expr::gt;
    use crate::expr::lit;
    use crate::expr::nested_case_when;
    use crate::expr::root;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;
    use crate::session::ArraySession;

    // Session shared by the tests that build execution contexts.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    // Applies `expr` to `array` and executes the result into a canonical array.
    fn evaluate_expr(expr: &Expression, array: &ArrayRef) -> ArrayRef {
        let mut ctx = SESSION.create_execution_ctx();
        array
            .clone()
            .apply(expr)
            .unwrap()
            .execute::<Canonical>(&mut ctx)
            .unwrap()
            .into_array()
    }

    // ---- Serialization: `serialize` currently always bails with "cannot serialize". ----

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_roundtrip() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: true,
        };
        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
        let deserialized = CaseWhen
            .deserialize(&serialized, &VortexSession::empty())
            .unwrap();
        assert_eq!(options, deserialized);
    }

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_no_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: false,
        };
        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
        let deserialized = CaseWhen
            .deserialize(&serialized, &VortexSession::empty())
            .unwrap();
        assert_eq!(options, deserialized);
    }

    // ---- SQL display ----

    #[test]
    fn test_display_with_else() {
        let expr = case_when(gt(col("value"), lit(0i32)), lit(100i32), lit(0i32));
        let display = format!("{}", expr);
        assert!(display.contains("CASE"));
        assert!(display.contains("WHEN"));
        assert!(display.contains("THEN"));
        assert!(display.contains("ELSE"));
        assert!(display.contains("END"));
    }

    #[test]
    fn test_display_no_else() {
        let expr = case_when_no_else(gt(col("value"), lit(0i32)), lit(100i32));
        let display = format!("{}", expr);
        assert!(display.contains("CASE"));
        assert!(display.contains("WHEN"));
        assert!(display.contains("THEN"));
        assert!(!display.contains("ELSE"));
        assert!(display.contains("END"));
    }

    // A multi-pair expression renders as one flat CASE, not nested CASEs.
    #[test]
    fn test_display_nested_nary() {
        let expr = nested_case_when(
            vec![
                (gt(col("x"), lit(10i32)), lit("high")),
                (gt(col("x"), lit(5i32)), lit("medium")),
            ],
            Some(lit("low")),
        );
        let display = format!("{}", expr);
        assert_eq!(display.matches("CASE").count(), 1);
        assert_eq!(display.matches("WHEN").count(), 2);
        assert_eq!(display.matches("THEN").count(), 2);
    }

    // ---- Return dtype ----

    #[test]
    fn test_return_dtype_with_else() {
        let expr = case_when(lit(true), lit(100i32), lit(0i32));
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(
            result_dtype,
            DType::Primitive(PType::I32, Nullability::NonNullable)
        );
    }

    // A nullable ELSE widens the result to nullable.
    #[test]
    fn test_return_dtype_with_nullable_else() {
        let expr = case_when(
            lit(true),
            lit(100i32),
            lit(Scalar::null(DType::Primitive(
                PType::I32,
                Nullability::Nullable,
            ))),
        );
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(
            result_dtype,
            DType::Primitive(PType::I32, Nullability::Nullable)
        );
    }

    // Without ELSE, unmatched rows are null, so the result is always nullable.
    #[test]
    fn test_return_dtype_without_else_is_nullable() {
        let expr = case_when_no_else(lit(true), lit(100i32));
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(
            result_dtype,
            DType::Primitive(PType::I32, Nullability::Nullable)
        );
    }

    #[test]
    fn test_return_dtype_with_struct_input() {
        let dtype = test_harness::struct_dtype();
        let expr = case_when(
            gt(get_item("col1", root()), lit(10u16)),
            lit(100i32),
            lit(0i32),
        );
        let result_dtype = expr.return_dtype(&dtype).unwrap();
        assert_eq!(
            result_dtype,
            DType::Primitive(PType::I32, Nullability::NonNullable)
        );
    }

    #[test]
    fn test_return_dtype_mismatched_then_else_errors() {
        let expr = case_when(lit(true), lit(100i32), lit("zero"));
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let err = expr.return_dtype(&input_dtype).unwrap_err();
        assert!(
            err.to_string()
                .contains("THEN and ELSE dtypes must match (ignoring nullability)")
        );
    }

    // ---- Arity and child naming ----

    #[test]
    fn test_arity_with_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: true,
        };
        assert_eq!(CaseWhen.arity(&options), Arity::Exact(3));
    }

    #[test]
    fn test_arity_without_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: false,
        };
        assert_eq!(CaseWhen.arity(&options), Arity::Exact(2));
    }

    #[test]
    fn test_child_names() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: true,
        };
        assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0");
        assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0");
        assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "else");
    }

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_roundtrip_nary() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 3,
            has_else: true,
        };
        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
        let deserialized = CaseWhen
            .deserialize(&serialized, &VortexSession::empty())
            .unwrap();
        assert_eq!(options, deserialized);
    }

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_roundtrip_nary_no_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 4,
            has_else: false,
        };
        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
        let deserialized = CaseWhen
            .deserialize(&serialized, &VortexSession::empty())
            .unwrap();
        assert_eq!(options, deserialized);
    }

    #[test]
    fn test_arity_nary_with_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 3,
            has_else: true,
        };
        assert_eq!(CaseWhen.arity(&options), Arity::Exact(7));
    }

    #[test]
    fn test_arity_nary_without_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 3,
            has_else: false,
        };
        assert_eq!(CaseWhen.arity(&options), Arity::Exact(6));
    }

    #[test]
    fn test_child_names_nary() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 3,
            has_else: true,
        };
        assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0");
        assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0");
        assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "when_1");
        assert_eq!(CaseWhen.child_name(&options, 3).to_string(), "then_1");
        assert_eq!(CaseWhen.child_name(&options, 4).to_string(), "when_2");
        assert_eq!(CaseWhen.child_name(&options, 5).to_string(), "then_2");
        assert_eq!(CaseWhen.child_name(&options, 6).to_string(), "else");
    }

    // THEN dtypes must agree with each other (ignoring nullability).
    #[test]
    fn test_return_dtype_nary_mismatched_then_types_errors() {
        let expr = nested_case_when(
            vec![(lit(true), lit(100i32)), (lit(false), lit("oops"))],
            Some(lit(0i32)),
        );
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let err = expr.return_dtype(&input_dtype).unwrap_err();
        assert!(err.to_string().contains("THEN dtypes must match"));
    }

    // Any nullable THEN makes the result nullable.
    #[test]
    fn test_return_dtype_nary_mixed_nullability() {
        let non_null_then = lit(100i32);
        let nullable_then = lit(Scalar::null(DType::Primitive(
            PType::I32,
            Nullability::Nullable,
        )));
        let expr = nested_case_when(
            vec![(lit(true), non_null_then), (lit(false), nullable_then)],
            Some(lit(0i32)),
        );
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let result = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(result, DType::Primitive(PType::I32, Nullability::Nullable));
    }

    #[test]
    fn test_return_dtype_nary_no_else_is_nullable() {
        let expr = nested_case_when(
            vec![(lit(true), lit(10i32)), (lit(false), lit(20i32))],
            None,
        );
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let result = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(result, DType::Primitive(PType::I32, Nullability::Nullable));
    }

    #[test]
    fn test_replace_children() {
        let expr = case_when(lit(true), lit(1i32), lit(0i32));
        expr.with_children([lit(false), lit(2i32), lit(3i32)])
            .vortex_expect("operation should succeed in test");
    }

    // ---- Evaluation ----

    #[test]
    fn test_evaluate_simple_condition() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        let expr = case_when(
            gt(get_item("value", root()), lit(2i32)),
            lit(100i32),
            lit(0i32),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array());
    }

    #[test]
    fn test_evaluate_nary_multiple_conditions() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        let expr = nested_case_when(
            vec![
                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![10i32, 0, 30, 0, 0].into_array());
    }

    // Rows 4 and 5 satisfy both conditions; the first WHEN (value > 2) wins.
    #[test]
    fn test_evaluate_nary_first_match_wins() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        let expr = nested_case_when(
            vec![
                (gt(get_item("value", root()), lit(2i32)), lit(100i32)),
                (gt(get_item("value", root()), lit(3i32)), lit(200i32)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array());
    }

    #[test]
    fn test_evaluate_no_else_returns_null() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        let expr = case_when_no_else(gt(get_item("value", root()), lit(3i32)), lit(100i32));

        let result = evaluate_expr(&expr, &test_array);
        assert!(result.dtype().is_nullable());
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([None::<i32>, None, None, Some(100), Some(100)])
                .into_array()
        );
    }

    #[test]
    fn test_evaluate_all_conditions_false() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        let expr = case_when(
            gt(get_item("value", root()), lit(100i32)),
            lit(1i32),
            lit(0i32),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![0i32, 0, 0, 0, 0].into_array());
    }

    #[test]
    fn test_evaluate_all_conditions_true() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        let expr = case_when(
            gt(get_item("value", root()), lit(0i32)),
            lit(100i32),
            lit(0i32),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![100i32, 100, 100, 100, 100].into_array());
    }

    // Even when every row matches, a no-ELSE CASE must still produce a nullable result.
    #[test]
    fn test_evaluate_all_true_no_else_returns_correct_dtype() {
        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
            .unwrap()
            .into_array();

        let expr = case_when_no_else(gt(get_item("value", root()), lit(0i32)), lit(100i32));

        let result = evaluate_expr(&expr, &test_array);
        assert!(
            result.dtype().is_nullable(),
            "result dtype must be Nullable, got {:?}",
            result.dtype()
        );
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(100i32), Some(100), Some(100)]).into_array()
        );
    }

    // Only the second THEN is nullable; the merged output must still be nullable.
    #[test]
    fn test_merge_case_branches_widens_nullability_of_later_branch() -> VortexResult<()> {
        let test_array =
            StructArray::from_fields(&[("value", buffer![0i32, 1, 2].into_array())])?.into_array();

        let nullable_20 =
            Scalar::from(20i32).cast(&DType::Primitive(PType::I32, Nullability::Nullable))?;

        let expr = nested_case_when(
            vec![
                (eq(get_item("value", root()), lit(0i32)), lit(10i32)),
                (eq(get_item("value", root()), lit(1i32)), lit(nullable_20)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert!(
            result.dtype().is_nullable(),
            "result dtype must be Nullable, got {:?}",
            result.dtype()
        );
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(10), Some(20), Some(0)]).into_array()
        );
        Ok(())
    }

    #[test]
    fn test_evaluate_with_literal_condition() {
        let test_array = buffer![1i32, 2, 3].into_array();
        let expr = case_when(lit(true), lit(100i32), lit(0i32));
        let result = evaluate_expr(&expr, &test_array);

        assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array());
    }

    #[test]
    fn test_evaluate_with_bool_column_result() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        let expr = case_when(
            gt(get_item("value", root()), lit(2i32)),
            lit(true),
            lit(false),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(
            result,
            BoolArray::from_iter([false, false, true, true, true]).into_array()
        );
    }

    // Null conditions are treated as false and fall through to ELSE.
    #[test]
    fn test_evaluate_with_nullable_condition() {
        let test_array = StructArray::from_fields(&[(
            "cond",
            BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)]).into_array(),
        )])
        .unwrap()
        .into_array();

        let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![100i32, 0, 0, 0, 100].into_array());
    }

    // The null in the THEN column sits in an unmatched row, so it never reaches the output.
    #[test]
    fn test_evaluate_with_nullable_result_values() {
        let test_array = StructArray::from_fields(&[
            ("value", buffer![1i32, 2, 3, 4, 5].into_array()),
            (
                "result",
                PrimitiveArray::from_option_iter([Some(10), None, Some(30), Some(40), Some(50)])
                    .into_array(),
            ),
        ])
        .unwrap()
        .into_array();

        let expr = case_when(
            gt(get_item("value", root()), lit(2i32)),
            get_item("result", root()),
            lit(0i32),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(0i32), Some(0), Some(30), Some(40), Some(50)])
                .into_array()
        );
    }

    #[test]
    fn test_evaluate_with_all_null_condition() {
        let test_array = StructArray::from_fields(&[(
            "cond",
            BoolArray::from_iter([None, None, None]).into_array(),
        )])
        .unwrap()
        .into_array();

        let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![0i32, 0, 0].into_array());
    }

    #[test]
    fn test_evaluate_nary_no_else_returns_null() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        let expr = nested_case_when(
            vec![
                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
            ],
            None,
        );

        let result = evaluate_expr(&expr, &test_array);
        assert!(result.dtype().is_nullable());
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(30), None, None])
                .into_array()
        );
    }

    #[test]
    fn test_evaluate_nary_many_conditions() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        let expr = nested_case_when(
            vec![
                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
                (eq(get_item("value", root()), lit(2i32)), lit(20i32)),
                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
                (eq(get_item("value", root()), lit(4i32)), lit(40i32)),
                (eq(get_item("value", root()), lit(5i32)), lit(50i32)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![10i32, 20, 30, 40, 50].into_array());
    }

    #[test]
    fn test_evaluate_nary_all_false_no_else() {
        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
            .unwrap()
            .into_array();

        let expr = nested_case_when(
            vec![
                (gt(get_item("value", root()), lit(100i32)), lit(10i32)),
                (gt(get_item("value", root()), lit(200i32)), lit(20i32)),
            ],
            None,
        );

        let result = evaluate_expr(&expr, &test_array);
        assert!(result.dtype().is_nullable());
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([None::<i32>, None, None]).into_array()
        );
    }

    // Every row satisfies the first condition, so the later branches never apply.
    #[test]
    fn test_evaluate_nary_overlapping_conditions_first_wins() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![10i32, 20, 30].into_array())])
                .unwrap()
                .into_array();

        let expr = nested_case_when(
            vec![
                (gt(get_item("value", root()), lit(5i32)), lit(1i32)),
                (gt(get_item("value", root()), lit(0i32)), lit(2i32)),
                (gt(get_item("value", root()), lit(15i32)), lit(3i32)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![1i32, 1, 1].into_array());
    }

    // The first branch claims every row, so evaluation stops before the second branch.
    #[test]
    fn test_evaluate_nary_early_exit_when_remaining_empty() {
        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
            .unwrap()
            .into_array();

        let expr = nested_case_when(
            vec![
                (gt(get_item("value", root()), lit(0i32)), lit(100i32)),
                (gt(get_item("value", root()), lit(0i32)), lit(999i32)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array());
    }

    // The second WHEN matches only row 0, which the first WHEN already claimed, so its
    // effective mask is empty and the branch is skipped.
    #[test]
    fn test_evaluate_nary_skips_branch_with_empty_effective_mask() {
        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
            .unwrap()
            .into_array();

        let expr = nested_case_when(
            vec![
                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
                (eq(get_item("value", root()), lit(1i32)), lit(999i32)),
                (eq(get_item("value", root()), lit(2i32)), lit(20i32)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array());
    }

    #[test]
    fn test_evaluate_nary_string_output() -> VortexResult<()> {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4].into_array())])?
                .into_array();

        let expr = nested_case_when(
            vec![
                (gt(get_item("value", root()), lit(2i32)), lit("high")),
                (gt(get_item("value", root()), lit(0i32)), lit("low")),
            ],
            Some(lit("none")),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_eq!(
            result.execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())?,
            Scalar::utf8("low", Nullability::NonNullable)
        );
        assert_eq!(
            result.execute_scalar(1, &mut LEGACY_SESSION.create_execution_ctx())?,
            Scalar::utf8("low", Nullability::NonNullable)
        );
        assert_eq!(
            result.execute_scalar(2, &mut LEGACY_SESSION.create_execution_ctx())?,
            Scalar::utf8("high", Nullability::NonNullable)
        );
        assert_eq!(
            result.execute_scalar(3, &mut LEGACY_SESSION.create_execution_ctx())?,
            Scalar::utf8("high", Nullability::NonNullable)
        );
        Ok(())
    }

    // Null conditions in either WHEN fall through to the next branch or to ELSE.
    #[test]
    fn test_evaluate_nary_with_nullable_conditions() {
        let test_array = StructArray::from_fields(&[
            (
                "cond1",
                BoolArray::from_iter([Some(true), None, Some(false)]).into_array(),
            ),
            (
                "cond2",
                BoolArray::from_iter([Some(false), Some(true), None]).into_array(),
            ),
        ])
        .unwrap()
        .into_array();

        let expr = nested_case_when(
            vec![
                (get_item("cond1", root()), lit(10i32)),
                (get_item("cond2", root()), lit(20i32)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array());
    }

    // Alternating one-row masks yield one span per row (100 spans for 100 rows), which
    // exceeds `len / SLICE_CROSSOVER_RUN_LEN` and so exercises the row-by-row merge path.
    #[test]
    fn test_merge_case_branches_alternating_mask() -> VortexResult<()> {
        let n = 100usize;

        let branch0_mask = Mask::from_indices(n, (0..n).step_by(2));
        let branch1_mask = Mask::from_indices(n, (1..n).step_by(2));

        let result = merge_case_branches(
            vec![
                (
                    branch0_mask,
                    PrimitiveArray::from_option_iter(vec![Some(0i32); n]).into_array(),
                ),
                (
                    branch1_mask,
                    PrimitiveArray::from_option_iter(vec![Some(1i32); n]).into_array(),
                ),
            ],
            PrimitiveArray::from_option_iter(vec![Some(99i32); n]).into_array(),
            &mut SESSION.create_execution_ctx(),
        )?;

        let expected: Vec<Option<i32>> = (0..n)
            .map(|v| if v % 2 == 0 { Some(0) } else { Some(1) })
            .collect();
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter(expected).into_array()
        );
        Ok(())
    }
}