1use std::fmt;
14use std::fmt::Formatter;
15use std::hash::Hash;
16use std::sync::Arc;
17
18use prost::Message;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21use vortex_mask::AllOr;
22use vortex_mask::Mask;
23use vortex_proto::expr as pb;
24use vortex_session::VortexSession;
25
26use crate::ArrayRef;
27use crate::ExecutionCtx;
28use crate::IntoArray;
29use crate::arrays::BoolArray;
30use crate::arrays::ConstantArray;
31use crate::arrays::bool::BoolArrayExt;
32use crate::builders::ArrayBuilder;
33use crate::builders::builder_with_capacity;
34use crate::builtins::ArrayBuiltins;
35use crate::dtype::DType;
36use crate::expr::Expression;
37use crate::scalar::Scalar;
38use crate::scalar_fn::Arity;
39use crate::scalar_fn::ChildName;
40use crate::scalar_fn::ExecutionArgs;
41use crate::scalar_fn::ScalarFnId;
42use crate::scalar_fn::ScalarFnVTable;
43use crate::scalar_fn::fns::zip::zip_impl;
44
/// Options describing the shape of a [`CaseWhen`] expression's child list.
///
/// Children are laid out as `num_when_then_pairs` consecutive `(when, then)` pairs,
/// optionally followed by one trailing `else` child.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CaseWhenOptions {
    /// Number of WHEN/THEN pairs in the expression.
    pub num_when_then_pairs: u32,
    /// Whether an ELSE child follows the WHEN/THEN pairs.
    pub has_else: bool,
}

impl CaseWhenOptions {
    /// Total number of children: two per WHEN/THEN pair, plus one when an ELSE is present.
    pub fn num_children(&self) -> usize {
        let pair_children = 2 * self.num_when_then_pairs as usize;
        if self.has_else {
            pair_children + 1
        } else {
            pair_children
        }
    }
}

impl fmt::Display for CaseWhenOptions {
    /// Formats the options as `case_when(pairs=<n>, else=<bool>)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self {
            num_when_then_pairs,
            has_else,
        } = self;
        write!(
            f,
            "case_when(pairs={}, else={})",
            num_when_then_pairs, has_else
        )
    }
}
71
/// Scalar function type for `vortex.case_when`.
///
/// The type itself is stateless; the shape of each expression (number of WHEN/THEN pairs and
/// whether an ELSE is present) is carried by [`CaseWhenOptions`].
#[derive(Clone)]
pub struct CaseWhen;
77
78impl ScalarFnVTable for CaseWhen {
79 type Options = CaseWhenOptions;
80
81 fn id(&self) -> ScalarFnId {
82 ScalarFnId::from("vortex.case_when")
83 }
84
85 fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
86 vortex_bail!("cannot serialize")
90 }
91
92 fn deserialize(
93 &self,
94 metadata: &[u8],
95 _session: &VortexSession,
96 ) -> VortexResult<Self::Options> {
97 let opts = pb::CaseWhenOpts::decode(metadata)?;
98 if opts.num_children < 2 {
99 vortex_bail!(
100 "CaseWhen expects at least 2 children, got {}",
101 opts.num_children
102 );
103 }
104 Ok(CaseWhenOptions {
105 num_when_then_pairs: opts.num_children / 2,
106 has_else: opts.num_children % 2 == 1,
107 })
108 }
109
110 fn arity(&self, options: &Self::Options) -> Arity {
111 Arity::Exact(options.num_children())
112 }
113
114 fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName {
115 let num_pair_children = options.num_when_then_pairs as usize * 2;
116 if child_idx < num_pair_children {
117 let pair_idx = child_idx / 2;
118 if child_idx.is_multiple_of(2) {
119 ChildName::from(Arc::from(format!("when_{pair_idx}")))
120 } else {
121 ChildName::from(Arc::from(format!("then_{pair_idx}")))
122 }
123 } else if options.has_else && child_idx == num_pair_children {
124 ChildName::from("else")
125 } else {
126 unreachable!("Invalid child index {} for CaseWhen", child_idx)
127 }
128 }
129
130 fn fmt_sql(
131 &self,
132 options: &Self::Options,
133 expr: &Expression,
134 f: &mut Formatter<'_>,
135 ) -> fmt::Result {
136 write!(f, "CASE")?;
137 for i in 0..options.num_when_then_pairs as usize {
138 write!(
139 f,
140 " WHEN {} THEN {}",
141 expr.child(i * 2),
142 expr.child(i * 2 + 1)
143 )?;
144 }
145 if options.has_else {
146 let else_idx = options.num_when_then_pairs as usize * 2;
147 write!(f, " ELSE {}", expr.child(else_idx))?;
148 }
149 write!(f, " END")
150 }
151
152 fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
153 if options.num_when_then_pairs == 0 {
154 vortex_bail!("CaseWhen must have at least one WHEN/THEN pair");
155 }
156
157 let expected_len = options.num_children();
158 if arg_dtypes.len() != expected_len {
159 vortex_bail!(
160 "CaseWhen expects {expected_len} argument dtypes, got {}",
161 arg_dtypes.len()
162 );
163 }
164
165 let first_then = &arg_dtypes[1];
169 let mut result_dtype = first_then.clone();
170
171 for i in 1..options.num_when_then_pairs as usize {
172 let then_i = &arg_dtypes[i * 2 + 1];
173 if !first_then.eq_ignore_nullability(then_i) {
174 vortex_bail!(
175 "CaseWhen THEN dtypes must match (ignoring nullability), got {} and {}",
176 first_then,
177 then_i
178 );
179 }
180 result_dtype = result_dtype.union_nullability(then_i.nullability());
181 }
182
183 if options.has_else {
184 let else_dtype = &arg_dtypes[options.num_when_then_pairs as usize * 2];
185 if !result_dtype.eq_ignore_nullability(else_dtype) {
186 vortex_bail!(
187 "CaseWhen THEN and ELSE dtypes must match (ignoring nullability), got {} and {}",
188 first_then,
189 else_dtype
190 );
191 }
192 result_dtype = result_dtype.union_nullability(else_dtype.nullability());
193 } else {
194 result_dtype = result_dtype.as_nullable();
196 }
197
198 Ok(result_dtype)
199 }
200
201 fn execute(
202 &self,
203 options: &Self::Options,
204 args: &dyn ExecutionArgs,
205 ctx: &mut ExecutionCtx,
206 ) -> VortexResult<ArrayRef> {
207 let row_count = args.row_count();
214 let num_pairs = options.num_when_then_pairs as usize;
215
216 let mut remaining = Mask::new_true(row_count);
217 let mut branches: Vec<(Mask, ArrayRef)> = Vec::with_capacity(num_pairs);
218
219 for i in 0..num_pairs {
220 if remaining.all_false() {
221 break;
222 }
223
224 let condition = args.get(i * 2)?;
225 let cond_bool = condition.execute::<BoolArray>(ctx)?;
226 let cond_mask = cond_bool.to_mask_fill_null_false();
227 let effective_mask = &remaining & &cond_mask;
228
229 if effective_mask.all_false() {
230 continue;
231 }
232
233 let then_value = args.get(i * 2 + 1)?;
234 remaining = remaining.bitand_not(&cond_mask);
235 branches.push((effective_mask, then_value));
236 }
237
238 let else_value: ArrayRef = if options.has_else {
239 args.get(num_pairs * 2)?
240 } else {
241 let then_dtype = args.get(1)?.dtype().as_nullable();
242 ConstantArray::new(Scalar::null(then_dtype), row_count).into_array()
243 };
244
245 if branches.is_empty() {
246 return Ok(else_value);
247 }
248
249 merge_case_branches(branches, else_value)
250 }
251
252 fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
253 true
254 }
255
256 fn is_fallible(&self, _options: &Self::Options) -> bool {
257 false
258 }
259}
260
/// Divisor used by `merge_case_branches` to pick a merge strategy: when the number of spans
/// exceeds `len / SLICE_CROSSOVER_RUN_LEN`, rows are appended one scalar at a time instead of
/// slice by slice.
// NOTE(review): the value 4 looks like an empirically tuned threshold — confirm its origin.
const SLICE_CROSSOVER_RUN_LEN: usize = 4;
264
265fn merge_case_branches(
269 branches: Vec<(Mask, ArrayRef)>,
270 else_value: ArrayRef,
271) -> VortexResult<ArrayRef> {
272 if branches.len() == 1 {
273 let (mask, then_value) = &branches[0];
274 return zip_impl(then_value, &else_value, mask);
275 }
276
277 let output_nullability = branches
278 .iter()
279 .fold(else_value.dtype().nullability(), |acc, (_, arr)| {
280 acc | arr.dtype().nullability()
281 });
282 let output_dtype = else_value.dtype().with_nullability(output_nullability);
283 let branch_arrays: Vec<&ArrayRef> = branches.iter().map(|(_, arr)| arr).collect();
284
285 let mut spans: Vec<(usize, usize, usize)> = Vec::new();
286 for (branch_idx, (mask, _)) in branches.iter().enumerate() {
287 match mask.slices() {
288 AllOr::All => return branch_arrays[branch_idx].cast(output_dtype),
289 AllOr::None => {}
290 AllOr::Some(slices) => {
291 for &(start, end) in slices {
292 spans.push((start, end, branch_idx));
293 }
294 }
295 }
296 }
297 spans.sort_unstable_by_key(|&(start, ..)| start);
298
299 if spans.is_empty() {
300 return else_value.cast(output_dtype);
301 }
302
303 let builder = builder_with_capacity(&output_dtype, else_value.len());
304
305 let fragmented = spans.len() > else_value.len() / SLICE_CROSSOVER_RUN_LEN;
306 if fragmented {
307 merge_row_by_row(&branch_arrays, &else_value, &spans, &output_dtype, builder)
308 } else {
309 merge_run_by_run(&branch_arrays, &else_value, &spans, &output_dtype, builder)
310 }
311}
312
313fn merge_row_by_row(
316 branch_arrays: &[&ArrayRef],
317 else_value: &ArrayRef,
318 spans: &[(usize, usize, usize)],
319 output_dtype: &DType,
320 mut builder: Box<dyn ArrayBuilder>,
321) -> VortexResult<ArrayRef> {
322 let mut pos = 0;
323 for &(start, end, branch_idx) in spans {
324 for row in pos..start {
325 let scalar = else_value.scalar_at(row)?;
326 builder.append_scalar(&scalar.cast(output_dtype)?)?;
327 }
328 for row in start..end {
329 let scalar = branch_arrays[branch_idx].scalar_at(row)?;
330 builder.append_scalar(&scalar.cast(output_dtype)?)?;
331 }
332 pos = end;
333 }
334 for row in pos..else_value.len() {
335 let scalar = else_value.scalar_at(row)?;
336 builder.append_scalar(&scalar.cast(output_dtype)?)?;
337 }
338
339 Ok(builder.finish())
340}
341
342fn merge_run_by_run(
346 branch_arrays: &[&ArrayRef],
347 else_value: &ArrayRef,
348 spans: &[(usize, usize, usize)],
349 output_dtype: &DType,
350 mut builder: Box<dyn ArrayBuilder>,
351) -> VortexResult<ArrayRef> {
352 let else_value = else_value.cast(output_dtype.clone())?;
353 let len = else_value.len();
354 for (start, end, branch_idx) in spans {
355 if builder.len() < *start {
356 builder.extend_from_array(&else_value.slice(builder.len()..*start)?);
357 }
358 builder.extend_from_array(
359 &branch_arrays[*branch_idx]
360 .cast(output_dtype.clone())?
361 .slice(*start..*end)?,
362 );
363 }
364 if builder.len() < len {
365 builder.extend_from_array(&else_value.slice(builder.len()..len)?);
366 }
367
368 Ok(builder.finish())
369}
370
#[cfg(test)]
mod tests {
    //! Unit tests for [`CaseWhen`]: options/metadata, display, dtype inference, and execution.

    use std::sync::LazyLock;

    use vortex_buffer::buffer;
    use vortex_error::VortexExpect as _;
    use vortex_session::VortexSession;

    use super::*;
    use crate::Canonical;
    use crate::IntoArray;
    use crate::VortexSessionExecute as _;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::case_when;
    use crate::expr::case_when_no_else;
    use crate::expr::col;
    use crate::expr::eq;
    use crate::expr::get_item;
    use crate::expr::gt;
    use crate::expr::lit;
    use crate::expr::nested_case_when;
    use crate::expr::root;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;
    use crate::session::ArraySession;

    // Session shared by all tests in this module, created lazily on first use.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// Applies `expr` to `array` and executes the result to canonical form.
    ///
    /// Panics if applying or executing the expression fails.
    fn evaluate_expr(expr: &Expression, array: &ArrayRef) -> ArrayRef {
        let mut ctx = SESSION.create_execution_ctx();
        array
            .clone()
            .apply(expr)
            .unwrap()
            .execute::<Canonical>(&mut ctx)
            .unwrap()
            .into_array()
    }

    // ---- Serialization ----
    //
    // `CaseWhen::serialize` currently returns an error, so the first `unwrap` in each of the
    // serialization tests panics with "cannot serialize" before the round trip is attempted.

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_roundtrip() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: true,
        };
        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
        let deserialized = CaseWhen
            .deserialize(&serialized, &VortexSession::empty())
            .unwrap();
        assert_eq!(options, deserialized);
    }

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_no_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: false,
        };
        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
        let deserialized = CaseWhen
            .deserialize(&serialized, &VortexSession::empty())
            .unwrap();
        assert_eq!(options, deserialized);
    }

    // ---- Display ----

    #[test]
    fn test_display_with_else() {
        let expr = case_when(gt(col("value"), lit(0i32)), lit(100i32), lit(0i32));
        let display = format!("{}", expr);
        assert!(display.contains("CASE"));
        assert!(display.contains("WHEN"));
        assert!(display.contains("THEN"));
        assert!(display.contains("ELSE"));
        assert!(display.contains("END"));
    }

    #[test]
    fn test_display_no_else() {
        let expr = case_when_no_else(gt(col("value"), lit(0i32)), lit(100i32));
        let display = format!("{}", expr);
        assert!(display.contains("CASE"));
        assert!(display.contains("WHEN"));
        assert!(display.contains("THEN"));
        assert!(!display.contains("ELSE"));
        assert!(display.contains("END"));
    }

    #[test]
    fn test_display_nested_nary() {
        let expr = nested_case_when(
            vec![
                (gt(col("x"), lit(10i32)), lit("high")),
                (gt(col("x"), lit(5i32)), lit("medium")),
            ],
            Some(lit("low")),
        );
        let display = format!("{}", expr);
        // Two pairs render as one CASE with two WHEN/THEN clauses, not as nested CASEs.
        assert_eq!(display.matches("CASE").count(), 1);
        assert_eq!(display.matches("WHEN").count(), 2);
        assert_eq!(display.matches("THEN").count(), 2);
    }

    // ---- Return dtype ----

    #[test]
    fn test_return_dtype_with_else() {
        let expr = case_when(lit(true), lit(100i32), lit(0i32));
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(
            result_dtype,
            DType::Primitive(PType::I32, Nullability::NonNullable)
        );
    }

    #[test]
    fn test_return_dtype_with_nullable_else() {
        let expr = case_when(
            lit(true),
            lit(100i32),
            lit(Scalar::null(DType::Primitive(
                PType::I32,
                Nullability::Nullable,
            ))),
        );
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
        // A nullable ELSE makes the whole result nullable.
        assert_eq!(
            result_dtype,
            DType::Primitive(PType::I32, Nullability::Nullable)
        );
    }

    #[test]
    fn test_return_dtype_without_else_is_nullable() {
        let expr = case_when_no_else(lit(true), lit(100i32));
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(
            result_dtype,
            DType::Primitive(PType::I32, Nullability::Nullable)
        );
    }

    #[test]
    fn test_return_dtype_with_struct_input() {
        let dtype = test_harness::struct_dtype();
        let expr = case_when(
            gt(get_item("col1", root()), lit(10u16)),
            lit(100i32),
            lit(0i32),
        );
        let result_dtype = expr.return_dtype(&dtype).unwrap();
        assert_eq!(
            result_dtype,
            DType::Primitive(PType::I32, Nullability::NonNullable)
        );
    }

    #[test]
    fn test_return_dtype_mismatched_then_else_errors() {
        let expr = case_when(lit(true), lit(100i32), lit("zero"));
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let err = expr.return_dtype(&input_dtype).unwrap_err();
        assert!(
            err.to_string()
                .contains("THEN and ELSE dtypes must match (ignoring nullability)")
        );
    }

    // ---- Arity ----

    #[test]
    fn test_arity_with_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: true,
        };
        assert_eq!(CaseWhen.arity(&options), Arity::Exact(3));
    }

    #[test]
    fn test_arity_without_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: false,
        };
        assert_eq!(CaseWhen.arity(&options), Arity::Exact(2));
    }

    // ---- Child names ----

    #[test]
    fn test_child_names() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: true,
        };
        assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0");
        assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0");
        assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "else");
    }

    // ---- N-ary serialization (see the note on the serialization tests above) ----

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_roundtrip_nary() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 3,
            has_else: true,
        };
        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
        let deserialized = CaseWhen
            .deserialize(&serialized, &VortexSession::empty())
            .unwrap();
        assert_eq!(options, deserialized);
    }

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_roundtrip_nary_no_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 4,
            has_else: false,
        };
        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
        let deserialized = CaseWhen
            .deserialize(&serialized, &VortexSession::empty())
            .unwrap();
        assert_eq!(options, deserialized);
    }

    // ---- N-ary arity ----

    #[test]
    fn test_arity_nary_with_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 3,
            has_else: true,
        };
        // 3 pairs * 2 children + 1 ELSE.
        assert_eq!(CaseWhen.arity(&options), Arity::Exact(7));
    }

    #[test]
    fn test_arity_nary_without_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 3,
            has_else: false,
        };
        // 3 pairs * 2 children, no ELSE.
        assert_eq!(CaseWhen.arity(&options), Arity::Exact(6));
    }

    // ---- N-ary child names ----

    #[test]
    fn test_child_names_nary() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 3,
            has_else: true,
        };
        assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0");
        assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0");
        assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "when_1");
        assert_eq!(CaseWhen.child_name(&options, 3).to_string(), "then_1");
        assert_eq!(CaseWhen.child_name(&options, 4).to_string(), "when_2");
        assert_eq!(CaseWhen.child_name(&options, 5).to_string(), "then_2");
        assert_eq!(CaseWhen.child_name(&options, 6).to_string(), "else");
    }

    // ---- N-ary return dtype ----

    #[test]
    fn test_return_dtype_nary_mismatched_then_types_errors() {
        let expr = nested_case_when(
            vec![(lit(true), lit(100i32)), (lit(false), lit("oops"))],
            Some(lit(0i32)),
        );
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let err = expr.return_dtype(&input_dtype).unwrap_err();
        assert!(err.to_string().contains("THEN dtypes must match"));
    }

    #[test]
    fn test_return_dtype_nary_mixed_nullability() {
        let non_null_then = lit(100i32);
        // One nullable THEN is enough to make the result nullable.
        let nullable_then = lit(Scalar::null(DType::Primitive(
            PType::I32,
            Nullability::Nullable,
        )));
        let expr = nested_case_when(
            vec![(lit(true), non_null_then), (lit(false), nullable_then)],
            Some(lit(0i32)),
        );
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let result = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(result, DType::Primitive(PType::I32, Nullability::Nullable));
    }

    #[test]
    fn test_return_dtype_nary_no_else_is_nullable() {
        let expr = nested_case_when(
            vec![(lit(true), lit(10i32)), (lit(false), lit(20i32))],
            None,
        );
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let result = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(result, DType::Primitive(PType::I32, Nullability::Nullable));
    }

    // ---- Expression manipulation ----

    #[test]
    fn test_replace_children() {
        let expr = case_when(lit(true), lit(1i32), lit(0i32));
        expr.with_children([lit(false), lit(2i32), lit(3i32)])
            .vortex_expect("operation should succeed in test");
    }

    // ---- Evaluation ----

    #[test]
    fn test_evaluate_simple_condition() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        let expr = case_when(
            gt(get_item("value", root()), lit(2i32)),
            lit(100i32),
            lit(0i32),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array());
    }

    #[test]
    fn test_evaluate_nary_multiple_conditions() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        let expr = nested_case_when(
            vec![
                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![10i32, 0, 30, 0, 0].into_array());
    }

    #[test]
    fn test_evaluate_nary_first_match_wins() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        // Rows 4 and 5 satisfy both conditions; the first branch must win.
        let expr = nested_case_when(
            vec![
                (gt(get_item("value", root()), lit(2i32)), lit(100i32)),
                (gt(get_item("value", root()), lit(3i32)), lit(200i32)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array());
    }

    #[test]
    fn test_evaluate_no_else_returns_null() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        let expr = case_when_no_else(gt(get_item("value", root()), lit(3i32)), lit(100i32));

        let result = evaluate_expr(&expr, &test_array);
        assert!(result.dtype().is_nullable());
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([None::<i32>, None, None, Some(100), Some(100)])
                .into_array()
        );
    }

    #[test]
    fn test_evaluate_all_conditions_false() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        let expr = case_when(
            gt(get_item("value", root()), lit(100i32)),
            lit(1i32),
            lit(0i32),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![0i32, 0, 0, 0, 0].into_array());
    }

    #[test]
    fn test_evaluate_all_conditions_true() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        let expr = case_when(
            gt(get_item("value", root()), lit(0i32)),
            lit(100i32),
            lit(0i32),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![100i32, 100, 100, 100, 100].into_array());
    }

    #[test]
    fn test_evaluate_all_true_no_else_returns_correct_dtype() {
        // Even when every row matches, a CASE without ELSE must still produce a nullable
        // result.
        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
            .unwrap()
            .into_array();

        let expr = case_when_no_else(gt(get_item("value", root()), lit(0i32)), lit(100i32));

        let result = evaluate_expr(&expr, &test_array);
        assert!(
            result.dtype().is_nullable(),
            "result dtype must be Nullable, got {:?}",
            result.dtype()
        );
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(100i32), Some(100), Some(100)]).into_array()
        );
    }

    #[test]
    fn test_merge_case_branches_widens_nullability_of_later_branch() -> VortexResult<()> {
        // The first THEN and the ELSE are non-nullable; only the second THEN is nullable.
        // The merged result must still be nullable.
        let test_array = StructArray::from_fields(&[("value", buffer![0i32, 1, 2].into_array())])
            .unwrap()
            .into_array();

        let nullable_20 =
            Scalar::from(20i32).cast(&DType::Primitive(PType::I32, Nullability::Nullable))?;

        let expr = nested_case_when(
            vec![
                (eq(get_item("value", root()), lit(0i32)), lit(10i32)),
                (eq(get_item("value", root()), lit(1i32)), lit(nullable_20)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert!(
            result.dtype().is_nullable(),
            "result dtype must be Nullable, got {:?}",
            result.dtype()
        );
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(10), Some(20), Some(0)]).into_array()
        );
        Ok(())
    }

    #[test]
    fn test_evaluate_with_literal_condition() {
        let test_array = buffer![1i32, 2, 3].into_array();
        let expr = case_when(lit(true), lit(100i32), lit(0i32));
        let result = evaluate_expr(&expr, &test_array);

        assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array());
    }

    #[test]
    fn test_evaluate_with_bool_column_result() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        let expr = case_when(
            gt(get_item("value", root()), lit(2i32)),
            lit(true),
            lit(false),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(
            result,
            BoolArray::from_iter([false, false, true, true, true]).into_array()
        );
    }

    #[test]
    fn test_evaluate_with_nullable_condition() {
        // Null conditions are treated as false, so those rows take the ELSE value.
        let test_array = StructArray::from_fields(&[(
            "cond",
            BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)]).into_array(),
        )])
        .unwrap()
        .into_array();

        let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![100i32, 0, 0, 0, 100].into_array());
    }

    #[test]
    fn test_evaluate_with_nullable_result_values() {
        // The null in `result` sits on a row whose condition is false, so it never reaches
        // the output.
        let test_array = StructArray::from_fields(&[
            ("value", buffer![1i32, 2, 3, 4, 5].into_array()),
            (
                "result",
                PrimitiveArray::from_option_iter([Some(10), None, Some(30), Some(40), Some(50)])
                    .into_array(),
            ),
        ])
        .unwrap()
        .into_array();

        let expr = case_when(
            gt(get_item("value", root()), lit(2i32)),
            get_item("result", root()),
            lit(0i32),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(0i32), Some(0), Some(30), Some(40), Some(50)])
                .into_array()
        );
    }

    #[test]
    fn test_evaluate_with_all_null_condition() {
        let test_array = StructArray::from_fields(&[(
            "cond",
            BoolArray::from_iter([None, None, None]).into_array(),
        )])
        .unwrap()
        .into_array();

        let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![0i32, 0, 0].into_array());
    }

    // ---- N-ary evaluation ----

    #[test]
    fn test_evaluate_nary_no_else_returns_null() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        // Rows matched by neither condition become null because there is no ELSE.
        let expr = nested_case_when(
            vec![
                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
            ],
            None,
        );

        let result = evaluate_expr(&expr, &test_array);
        assert!(result.dtype().is_nullable());
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(30), None, None])
                .into_array()
        );
    }

    #[test]
    fn test_evaluate_nary_many_conditions() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        // Every row is claimed by exactly one of the five branches; ELSE is never used.
        let expr = nested_case_when(
            vec![
                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
                (eq(get_item("value", root()), lit(2i32)), lit(20i32)),
                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
                (eq(get_item("value", root()), lit(4i32)), lit(40i32)),
                (eq(get_item("value", root()), lit(5i32)), lit(50i32)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![10i32, 20, 30, 40, 50].into_array());
    }

    #[test]
    fn test_evaluate_nary_all_false_no_else() {
        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
            .unwrap()
            .into_array();

        // No condition matches any row and there is no ELSE: the result is all null.
        let expr = nested_case_when(
            vec![
                (gt(get_item("value", root()), lit(100i32)), lit(10i32)),
                (gt(get_item("value", root()), lit(200i32)), lit(20i32)),
            ],
            None,
        );

        let result = evaluate_expr(&expr, &test_array);
        assert!(result.dtype().is_nullable());
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([None::<i32>, None, None]).into_array()
        );
    }

    #[test]
    fn test_evaluate_nary_overlapping_conditions_first_wins() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![10i32, 20, 30].into_array())])
                .unwrap()
                .into_array();

        // All three conditions overlap; the first one is true for every row, so later
        // branches must not contribute anything.
        let expr = nested_case_when(
            vec![
                (gt(get_item("value", root()), lit(5i32)), lit(1i32)),
                (gt(get_item("value", root()), lit(0i32)), lit(2i32)),
                (gt(get_item("value", root()), lit(15i32)), lit(3i32)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![1i32, 1, 1].into_array());
    }

    #[test]
    fn test_evaluate_nary_early_exit_when_remaining_empty() {
        // The first branch claims every row, leaving nothing for the second branch.
        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
            .unwrap()
            .into_array();

        let expr = nested_case_when(
            vec![
                (gt(get_item("value", root()), lit(0i32)), lit(100i32)),
                (gt(get_item("value", root()), lit(0i32)), lit(999i32)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array());
    }

    #[test]
    fn test_evaluate_nary_skips_branch_with_empty_effective_mask() {
        // The second branch repeats the first condition, so all of its rows are already
        // claimed and it contributes nothing; the third branch must still be applied.
        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
            .unwrap()
            .into_array();

        let expr = nested_case_when(
            vec![
                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
                (eq(get_item("value", root()), lit(1i32)), lit(999i32)),
                (eq(get_item("value", root()), lit(2i32)), lit(20i32)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array());
    }

    #[test]
    fn test_evaluate_nary_string_output() -> VortexResult<()> {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4].into_array())])
                .unwrap()
                .into_array();

        // Rows with value > 2 are "high"; the remaining rows with value > 0 are "low".
        let expr = nested_case_when(
            vec![
                (gt(get_item("value", root()), lit(2i32)), lit("high")),
                (gt(get_item("value", root()), lit(0i32)), lit("low")),
            ],
            Some(lit("none")),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_eq!(
            result.scalar_at(0)?,
            Scalar::utf8("low", Nullability::NonNullable)
        );
        assert_eq!(
            result.scalar_at(1)?,
            Scalar::utf8("low", Nullability::NonNullable)
        );
        assert_eq!(
            result.scalar_at(2)?,
            Scalar::utf8("high", Nullability::NonNullable)
        );
        assert_eq!(
            result.scalar_at(3)?,
            Scalar::utf8("high", Nullability::NonNullable)
        );
        Ok(())
    }

    #[test]
    fn test_evaluate_nary_with_nullable_conditions() {
        let test_array = StructArray::from_fields(&[
            (
                "cond1",
                BoolArray::from_iter([Some(true), None, Some(false)]).into_array(),
            ),
            (
                "cond2",
                BoolArray::from_iter([Some(false), Some(true), None]).into_array(),
            ),
        ])
        .unwrap()
        .into_array();

        let expr = nested_case_when(
            vec![
                (get_item("cond1", root()), lit(10i32)),
                (get_item("cond2", root()), lit(20i32)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &test_array);
        // Row 0: cond1 is true. Row 1: cond1 is null, cond2 is true. Row 2: cond1 is false
        // and cond2 is null, so ELSE applies.
        assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array());
    }

    // ---- Direct merge tests ----

    #[test]
    fn test_merge_case_branches_alternating_mask() -> VortexResult<()> {
        let n = 100usize;

        // Even rows belong to branch 0 and odd rows to branch 1, so every span has length 1.
        let branch0_mask = Mask::from_indices(n, (0..n).step_by(2).collect());
        let branch1_mask = Mask::from_indices(n, (1..n).step_by(2).collect());

        let result = merge_case_branches(
            vec![
                (
                    branch0_mask,
                    PrimitiveArray::from_option_iter(vec![Some(0i32); n]).into_array(),
                ),
                (
                    branch1_mask,
                    PrimitiveArray::from_option_iter(vec![Some(1i32); n]).into_array(),
                ),
            ],
            PrimitiveArray::from_option_iter(vec![Some(99i32); n]).into_array(),
        )?;

        // The ELSE value (99) must never appear: every row is covered by one of the branches.
        let expected: Vec<Option<i32>> = (0..n)
            .map(|v| if v % 2 == 0 { Some(0) } else { Some(1) })
            .collect();
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter(expected).into_array()
        );
        Ok(())
    }
}