1use std::fmt;
14use std::fmt::Formatter;
15use std::hash::Hash;
16use std::sync::Arc;
17
18use prost::Message;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21use vortex_mask::AllOr;
22use vortex_mask::Mask;
23use vortex_proto::expr as pb;
24use vortex_session::VortexSession;
25
26use crate::ArrayRef;
27use crate::ExecutionCtx;
28use crate::IntoArray;
29use crate::arrays::BoolArray;
30use crate::arrays::ConstantArray;
31use crate::arrays::bool::BoolArrayExt;
32use crate::builders::ArrayBuilder;
33use crate::builders::builder_with_capacity;
34use crate::builtins::ArrayBuiltins;
35use crate::dtype::DType;
36use crate::expr::Expression;
37use crate::scalar::Scalar;
38use crate::scalar_fn::Arity;
39use crate::scalar_fn::ChildName;
40use crate::scalar_fn::ExecutionArgs;
41use crate::scalar_fn::ScalarFnId;
42use crate::scalar_fn::ScalarFnVTable;
43use crate::scalar_fn::fns::zip::zip_impl;
44
/// Options for the [`CaseWhen`] scalar function.
///
/// The function's children are laid out as `num_when_then_pairs` consecutive
/// `(WHEN, THEN)` pairs, optionally followed by a single trailing `ELSE` child.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CaseWhenOptions {
    /// Number of `WHEN`/`THEN` child pairs.
    pub num_when_then_pairs: u32,
    /// Whether a trailing `ELSE` child is present.
    pub has_else: bool,
}
54
55impl CaseWhenOptions {
56 pub fn num_children(&self) -> usize {
58 self.num_when_then_pairs as usize * 2 + usize::from(self.has_else)
59 }
60}
61
62impl fmt::Display for CaseWhenOptions {
63 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
64 write!(
65 f,
66 "case_when(pairs={}, else={})",
67 self.num_when_then_pairs, self.has_else
68 )
69 }
70}
71
/// N-ary SQL-style `CASE WHEN ... THEN ... [ELSE ...] END` scalar function.
///
/// Conditions are evaluated in order and each row takes the value of the first
/// `THEN` whose `WHEN` is true; null conditions are treated as false. Rows that
/// match no condition take the `ELSE` value, or null when there is no `ELSE`.
#[derive(Clone)]
pub struct CaseWhen;
77
78impl ScalarFnVTable for CaseWhen {
79 type Options = CaseWhenOptions;
80
81 fn id(&self) -> ScalarFnId {
82 ScalarFnId::new("vortex.case_when")
83 }
84
85 fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
86 vortex_bail!("cannot serialize")
90 }
91
92 fn deserialize(
93 &self,
94 metadata: &[u8],
95 _session: &VortexSession,
96 ) -> VortexResult<Self::Options> {
97 let opts = pb::CaseWhenOpts::decode(metadata)?;
98 if opts.num_children < 2 {
99 vortex_bail!(
100 "CaseWhen expects at least 2 children, got {}",
101 opts.num_children
102 );
103 }
104 Ok(CaseWhenOptions {
105 num_when_then_pairs: opts.num_children / 2,
106 has_else: opts.num_children % 2 == 1,
107 })
108 }
109
110 fn arity(&self, options: &Self::Options) -> Arity {
111 Arity::Exact(options.num_children())
112 }
113
114 fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName {
115 let num_pair_children = options.num_when_then_pairs as usize * 2;
116 if child_idx < num_pair_children {
117 let pair_idx = child_idx / 2;
118 if child_idx.is_multiple_of(2) {
119 ChildName::from(Arc::from(format!("when_{pair_idx}")))
120 } else {
121 ChildName::from(Arc::from(format!("then_{pair_idx}")))
122 }
123 } else if options.has_else && child_idx == num_pair_children {
124 ChildName::from("else")
125 } else {
126 unreachable!("Invalid child index {} for CaseWhen", child_idx)
127 }
128 }
129
130 fn fmt_sql(
131 &self,
132 options: &Self::Options,
133 expr: &Expression,
134 f: &mut Formatter<'_>,
135 ) -> fmt::Result {
136 write!(f, "CASE")?;
137 for i in 0..options.num_when_then_pairs as usize {
138 write!(
139 f,
140 " WHEN {} THEN {}",
141 expr.child(i * 2),
142 expr.child(i * 2 + 1)
143 )?;
144 }
145 if options.has_else {
146 let else_idx = options.num_when_then_pairs as usize * 2;
147 write!(f, " ELSE {}", expr.child(else_idx))?;
148 }
149 write!(f, " END")
150 }
151
152 fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
153 if options.num_when_then_pairs == 0 {
154 vortex_bail!("CaseWhen must have at least one WHEN/THEN pair");
155 }
156
157 let expected_len = options.num_children();
158 if arg_dtypes.len() != expected_len {
159 vortex_bail!(
160 "CaseWhen expects {expected_len} argument dtypes, got {}",
161 arg_dtypes.len()
162 );
163 }
164
165 let first_then = &arg_dtypes[1];
169 let mut result_dtype = first_then.clone();
170
171 for i in 1..options.num_when_then_pairs as usize {
172 let then_i = &arg_dtypes[i * 2 + 1];
173 if !first_then.eq_ignore_nullability(then_i) {
174 vortex_bail!(
175 "CaseWhen THEN dtypes must match (ignoring nullability), got {} and {}",
176 first_then,
177 then_i
178 );
179 }
180 result_dtype = result_dtype.union_nullability(then_i.nullability());
181 }
182
183 if options.has_else {
184 let else_dtype = &arg_dtypes[options.num_when_then_pairs as usize * 2];
185 if !result_dtype.eq_ignore_nullability(else_dtype) {
186 vortex_bail!(
187 "CaseWhen THEN and ELSE dtypes must match (ignoring nullability), got {} and {}",
188 first_then,
189 else_dtype
190 );
191 }
192 result_dtype = result_dtype.union_nullability(else_dtype.nullability());
193 } else {
194 result_dtype = result_dtype.as_nullable();
196 }
197
198 Ok(result_dtype)
199 }
200
201 fn execute(
202 &self,
203 options: &Self::Options,
204 args: &dyn ExecutionArgs,
205 ctx: &mut ExecutionCtx,
206 ) -> VortexResult<ArrayRef> {
207 let row_count = args.row_count();
214 let num_pairs = options.num_when_then_pairs as usize;
215
216 let mut remaining = Mask::new_true(row_count);
217 let mut branches: Vec<(Mask, ArrayRef)> = Vec::with_capacity(num_pairs);
218
219 for i in 0..num_pairs {
220 if remaining.all_false() {
221 break;
222 }
223
224 let condition = args.get(i * 2)?;
225 let cond_bool = condition.execute::<BoolArray>(ctx)?;
226 let cond_mask = cond_bool.to_mask_fill_null_false(ctx);
227 let effective_mask = &remaining & &cond_mask;
228
229 if effective_mask.all_false() {
230 continue;
231 }
232
233 let then_value = args.get(i * 2 + 1)?;
234 remaining = remaining.bitand_not(&cond_mask);
235 branches.push((effective_mask, then_value));
236 }
237
238 let else_value: ArrayRef = if options.has_else {
239 args.get(num_pairs * 2)?
240 } else {
241 let then_dtype = args.get(1)?.dtype().as_nullable();
242 ConstantArray::new(Scalar::null(then_dtype), row_count).into_array()
243 };
244
245 if branches.is_empty() {
246 return Ok(else_value);
247 }
248
249 merge_case_branches(branches, else_value, ctx)
250 }
251
252 fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
253 true
254 }
255
256 fn is_fallible(&self, _options: &Self::Options) -> bool {
257 false
258 }
259}
260
/// Crossover between the two strategies used by `merge_case_branches`.
///
/// When the number of matched spans exceeds `len / SLICE_CROSSOVER_RUN_LEN`,
/// the merge is treated as fragmented and rows are copied one at a time.
/// Otherwise whole slices are copied.
const SLICE_CROSSOVER_RUN_LEN: usize = 4;
264
265fn merge_case_branches(
269 branches: Vec<(Mask, ArrayRef)>,
270 else_value: ArrayRef,
271 ctx: &mut ExecutionCtx,
272) -> VortexResult<ArrayRef> {
273 if branches.len() == 1 {
274 let (mask, then_value) = &branches[0];
275 return zip_impl(then_value, &else_value, mask);
276 }
277
278 let output_nullability = branches
279 .iter()
280 .fold(else_value.dtype().nullability(), |acc, (_, arr)| {
281 acc | arr.dtype().nullability()
282 });
283 let output_dtype = else_value.dtype().with_nullability(output_nullability);
284 let branch_arrays: Vec<&ArrayRef> = branches.iter().map(|(_, arr)| arr).collect();
285
286 let mut spans: Vec<(usize, usize, usize)> = Vec::new();
287 for (branch_idx, (mask, _)) in branches.iter().enumerate() {
288 match mask.slices() {
289 AllOr::All => return branch_arrays[branch_idx].cast(output_dtype),
290 AllOr::None => {}
291 AllOr::Some(slices) => {
292 for &(start, end) in slices {
293 spans.push((start, end, branch_idx));
294 }
295 }
296 }
297 }
298 spans.sort_unstable_by_key(|&(start, ..)| start);
299
300 if spans.is_empty() {
301 return else_value.cast(output_dtype);
302 }
303
304 let builder = builder_with_capacity(&output_dtype, else_value.len());
305
306 let fragmented = spans.len() > else_value.len() / SLICE_CROSSOVER_RUN_LEN;
307 if fragmented {
308 merge_row_by_row(
309 &branch_arrays,
310 &else_value,
311 &spans,
312 &output_dtype,
313 builder,
314 ctx,
315 )
316 } else {
317 merge_run_by_run(&branch_arrays, &else_value, &spans, &output_dtype, builder)
318 }
319}
320
321fn merge_row_by_row(
324 branch_arrays: &[&ArrayRef],
325 else_value: &ArrayRef,
326 spans: &[(usize, usize, usize)],
327 output_dtype: &DType,
328 mut builder: Box<dyn ArrayBuilder>,
329 ctx: &mut ExecutionCtx,
330) -> VortexResult<ArrayRef> {
331 let mut pos = 0;
332 for &(start, end, branch_idx) in spans {
333 for row in pos..start {
334 let scalar = else_value.execute_scalar(row, ctx)?;
335 builder.append_scalar(&scalar.cast(output_dtype)?)?;
336 }
337 for row in start..end {
338 let scalar = branch_arrays[branch_idx].execute_scalar(row, ctx)?;
339 builder.append_scalar(&scalar.cast(output_dtype)?)?;
340 }
341 pos = end;
342 }
343 for row in pos..else_value.len() {
344 let scalar = else_value.execute_scalar(row, ctx)?;
345 builder.append_scalar(&scalar.cast(output_dtype)?)?;
346 }
347
348 Ok(builder.finish())
349}
350
351fn merge_run_by_run(
355 branch_arrays: &[&ArrayRef],
356 else_value: &ArrayRef,
357 spans: &[(usize, usize, usize)],
358 output_dtype: &DType,
359 mut builder: Box<dyn ArrayBuilder>,
360) -> VortexResult<ArrayRef> {
361 let else_value = else_value.cast(output_dtype.clone())?;
362 let len = else_value.len();
363 for (start, end, branch_idx) in spans {
364 if builder.len() < *start {
365 builder.extend_from_array(&else_value.slice(builder.len()..*start)?);
366 }
367 builder.extend_from_array(
368 &branch_arrays[*branch_idx]
369 .cast(output_dtype.clone())?
370 .slice(*start..*end)?,
371 );
372 }
373 if builder.len() < len {
374 builder.extend_from_array(&else_value.slice(builder.len()..len)?);
375 }
376
377 Ok(builder.finish())
378}
379
/// Unit tests for `CaseWhen`: options (de)serialization, display, dtype
/// inference, arity/child naming, and end-to-end evaluation including the
/// branch-merging helpers.
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use vortex_buffer::buffer;
    use vortex_error::VortexExpect as _;
    use vortex_session::VortexSession;

    use super::*;
    use crate::Canonical;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::case_when;
    use crate::expr::case_when_no_else;
    use crate::expr::col;
    use crate::expr::eq;
    use crate::expr::get_item;
    use crate::expr::gt;
    use crate::expr::lit;
    use crate::expr::nested_case_when;
    use crate::expr::root;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;
    use crate::session::ArraySession;

    /// Shared session with the array plugins registered, used to execute expressions.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// Applies `expr` to `array` and executes the result to a canonical array.
    fn evaluate_expr(expr: &Expression, array: &ArrayRef) -> ArrayRef {
        let mut ctx = SESSION.create_execution_ctx();
        array
            .clone()
            .apply(expr)
            .unwrap()
            .execute::<Canonical>(&mut ctx)
            .unwrap()
            .into_array()
    }

    // ---- Serialization ----
    // `CaseWhen::serialize` always bails with "cannot serialize", so these
    // roundtrip tests pin that the `unwrap()` on it panics with that message.

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_roundtrip() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: true,
        };
        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
        let deserialized = CaseWhen
            .deserialize(&serialized, &VortexSession::empty())
            .unwrap();
        assert_eq!(options, deserialized);
    }

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_no_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: false,
        };
        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
        let deserialized = CaseWhen
            .deserialize(&serialized, &VortexSession::empty())
            .unwrap();
        assert_eq!(options, deserialized);
    }

    // ---- SQL display ----

    #[test]
    fn test_display_with_else() {
        let expr = case_when(gt(col("value"), lit(0i32)), lit(100i32), lit(0i32));
        let display = format!("{}", expr);
        assert!(display.contains("CASE"));
        assert!(display.contains("WHEN"));
        assert!(display.contains("THEN"));
        assert!(display.contains("ELSE"));
        assert!(display.contains("END"));
    }

    #[test]
    fn test_display_no_else() {
        let expr = case_when_no_else(gt(col("value"), lit(0i32)), lit(100i32));
        let display = format!("{}", expr);
        assert!(display.contains("CASE"));
        assert!(display.contains("WHEN"));
        assert!(display.contains("THEN"));
        assert!(!display.contains("ELSE"));
        assert!(display.contains("END"));
    }

    #[test]
    fn test_display_nested_nary() {
        // Multiple pairs render as a single flat CASE, not nested CASE expressions.
        let expr = nested_case_when(
            vec![
                (gt(col("x"), lit(10i32)), lit("high")),
                (gt(col("x"), lit(5i32)), lit("medium")),
            ],
            Some(lit("low")),
        );
        let display = format!("{}", expr);
        assert_eq!(display.matches("CASE").count(), 1);
        assert_eq!(display.matches("WHEN").count(), 2);
        assert_eq!(display.matches("THEN").count(), 2);
    }

    // ---- Return dtype inference ----

    #[test]
    fn test_return_dtype_with_else() {
        let expr = case_when(lit(true), lit(100i32), lit(0i32));
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(
            result_dtype,
            DType::Primitive(PType::I32, Nullability::NonNullable)
        );
    }

    #[test]
    fn test_return_dtype_with_nullable_else() {
        let expr = case_when(
            lit(true),
            lit(100i32),
            lit(Scalar::null(DType::Primitive(
                PType::I32,
                Nullability::Nullable,
            ))),
        );
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(
            result_dtype,
            DType::Primitive(PType::I32, Nullability::Nullable)
        );
    }

    #[test]
    fn test_return_dtype_without_else_is_nullable() {
        let expr = case_when_no_else(lit(true), lit(100i32));
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(
            result_dtype,
            DType::Primitive(PType::I32, Nullability::Nullable)
        );
    }

    #[test]
    fn test_return_dtype_with_struct_input() {
        let dtype = test_harness::struct_dtype();
        let expr = case_when(
            gt(get_item("col1", root()), lit(10u16)),
            lit(100i32),
            lit(0i32),
        );
        let result_dtype = expr.return_dtype(&dtype).unwrap();
        assert_eq!(
            result_dtype,
            DType::Primitive(PType::I32, Nullability::NonNullable)
        );
    }

    #[test]
    fn test_return_dtype_mismatched_then_else_errors() {
        let expr = case_when(lit(true), lit(100i32), lit("zero"));
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let err = expr.return_dtype(&input_dtype).unwrap_err();
        assert!(
            err.to_string()
                .contains("THEN and ELSE dtypes must match (ignoring nullability)")
        );
    }

    // ---- Arity and child naming ----

    #[test]
    fn test_arity_with_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: true,
        };
        assert_eq!(CaseWhen.arity(&options), Arity::Exact(3));
    }

    #[test]
    fn test_arity_without_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: false,
        };
        assert_eq!(CaseWhen.arity(&options), Arity::Exact(2));
    }

    #[test]
    fn test_child_names() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: true,
        };
        assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0");
        assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0");
        assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "else");
    }

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_roundtrip_nary() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 3,
            has_else: true,
        };
        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
        let deserialized = CaseWhen
            .deserialize(&serialized, &VortexSession::empty())
            .unwrap();
        assert_eq!(options, deserialized);
    }

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_roundtrip_nary_no_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 4,
            has_else: false,
        };
        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
        let deserialized = CaseWhen
            .deserialize(&serialized, &VortexSession::empty())
            .unwrap();
        assert_eq!(options, deserialized);
    }

    #[test]
    fn test_arity_nary_with_else() {
        // 3 pairs * 2 + 1 ELSE = 7 children.
        let options = CaseWhenOptions {
            num_when_then_pairs: 3,
            has_else: true,
        };
        assert_eq!(CaseWhen.arity(&options), Arity::Exact(7));
    }

    #[test]
    fn test_arity_nary_without_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 3,
            has_else: false,
        };
        assert_eq!(CaseWhen.arity(&options), Arity::Exact(6));
    }

    #[test]
    fn test_child_names_nary() {
        // Children alternate when_i / then_i, with the ELSE child last.
        let options = CaseWhenOptions {
            num_when_then_pairs: 3,
            has_else: true,
        };
        assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0");
        assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0");
        assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "when_1");
        assert_eq!(CaseWhen.child_name(&options, 3).to_string(), "then_1");
        assert_eq!(CaseWhen.child_name(&options, 4).to_string(), "when_2");
        assert_eq!(CaseWhen.child_name(&options, 5).to_string(), "then_2");
        assert_eq!(CaseWhen.child_name(&options, 6).to_string(), "else");
    }

    #[test]
    fn test_return_dtype_nary_mismatched_then_types_errors() {
        let expr = nested_case_when(
            vec![(lit(true), lit(100i32)), (lit(false), lit("oops"))],
            Some(lit(0i32)),
        );
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let err = expr.return_dtype(&input_dtype).unwrap_err();
        assert!(err.to_string().contains("THEN dtypes must match"));
    }

    #[test]
    fn test_return_dtype_nary_mixed_nullability() {
        // A nullable THEN in any pair makes the whole result nullable.
        let non_null_then = lit(100i32);
        let nullable_then = lit(Scalar::null(DType::Primitive(
            PType::I32,
            Nullability::Nullable,
        )));
        let expr = nested_case_when(
            vec![(lit(true), non_null_then), (lit(false), nullable_then)],
            Some(lit(0i32)),
        );
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let result = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(result, DType::Primitive(PType::I32, Nullability::Nullable));
    }

    #[test]
    fn test_return_dtype_nary_no_else_is_nullable() {
        let expr = nested_case_when(
            vec![(lit(true), lit(10i32)), (lit(false), lit(20i32))],
            None,
        );
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let result = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(result, DType::Primitive(PType::I32, Nullability::Nullable));
    }

    #[test]
    fn test_replace_children() {
        let expr = case_when(lit(true), lit(1i32), lit(0i32));
        expr.with_children([lit(false), lit(2i32), lit(3i32)])
            .vortex_expect("operation should succeed in test");
    }

    // ---- Evaluation ----

    #[test]
    fn test_evaluate_simple_condition() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        let expr = case_when(
            gt(get_item("value", root()), lit(2i32)),
            lit(100i32),
            lit(0i32),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array());
    }

    #[test]
    fn test_evaluate_nary_multiple_conditions() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        let expr = nested_case_when(
            vec![
                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![10i32, 0, 30, 0, 0].into_array());
    }

    #[test]
    fn test_evaluate_nary_first_match_wins() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        // Rows 4 and 5 satisfy both conditions; the first pair takes precedence.
        let expr = nested_case_when(
            vec![
                (gt(get_item("value", root()), lit(2i32)), lit(100i32)),
                (gt(get_item("value", root()), lit(3i32)), lit(200i32)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array());
    }

    #[test]
    fn test_evaluate_no_else_returns_null() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        let expr = case_when_no_else(gt(get_item("value", root()), lit(3i32)), lit(100i32));

        let result = evaluate_expr(&expr, &test_array);
        assert!(result.dtype().is_nullable());
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([None::<i32>, None, None, Some(100), Some(100)])
                .into_array()
        );
    }

    #[test]
    fn test_evaluate_all_conditions_false() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        let expr = case_when(
            gt(get_item("value", root()), lit(100i32)),
            lit(1i32),
            lit(0i32),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![0i32, 0, 0, 0, 0].into_array());
    }

    #[test]
    fn test_evaluate_all_conditions_true() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        let expr = case_when(
            gt(get_item("value", root()), lit(0i32)),
            lit(100i32),
            lit(0i32),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![100i32, 100, 100, 100, 100].into_array());
    }

    #[test]
    fn test_evaluate_all_true_no_else_returns_correct_dtype() {
        // Even when every row matches, a missing ELSE still makes the output nullable.
        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
            .unwrap()
            .into_array();

        let expr = case_when_no_else(gt(get_item("value", root()), lit(0i32)), lit(100i32));

        let result = evaluate_expr(&expr, &test_array);
        assert!(
            result.dtype().is_nullable(),
            "result dtype must be Nullable, got {:?}",
            result.dtype()
        );
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(100i32), Some(100), Some(100)]).into_array()
        );
    }

    #[test]
    fn test_merge_case_branches_widens_nullability_of_later_branch() -> VortexResult<()> {
        // Only the second branch is nullable; the merged output must still be nullable.
        let test_array = StructArray::from_fields(&[("value", buffer![0i32, 1, 2].into_array())])
            .unwrap()
            .into_array();

        let nullable_20 =
            Scalar::from(20i32).cast(&DType::Primitive(PType::I32, Nullability::Nullable))?;

        let expr = nested_case_when(
            vec![
                (eq(get_item("value", root()), lit(0i32)), lit(10i32)),
                (eq(get_item("value", root()), lit(1i32)), lit(nullable_20)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert!(
            result.dtype().is_nullable(),
            "result dtype must be Nullable, got {:?}",
            result.dtype()
        );
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(10), Some(20), Some(0)]).into_array()
        );
        Ok(())
    }

    #[test]
    fn test_evaluate_with_literal_condition() {
        let test_array = buffer![1i32, 2, 3].into_array();
        let expr = case_when(lit(true), lit(100i32), lit(0i32));
        let result = evaluate_expr(&expr, &test_array);

        assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array());
    }

    #[test]
    fn test_evaluate_with_bool_column_result() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        let expr = case_when(
            gt(get_item("value", root()), lit(2i32)),
            lit(true),
            lit(false),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(
            result,
            BoolArray::from_iter([false, false, true, true, true]).into_array()
        );
    }

    #[test]
    fn test_evaluate_with_nullable_condition() {
        // Null conditions are treated as false, so those rows fall through to ELSE.
        let test_array = StructArray::from_fields(&[(
            "cond",
            BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)]).into_array(),
        )])
        .unwrap()
        .into_array();

        let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![100i32, 0, 0, 0, 100].into_array());
    }

    #[test]
    fn test_evaluate_with_nullable_result_values() {
        // The null at row 1 is never selected (value 2 is not > 2), so the output
        // contains no nulls, but its dtype is nullable because the THEN column is.
        let test_array = StructArray::from_fields(&[
            ("value", buffer![1i32, 2, 3, 4, 5].into_array()),
            (
                "result",
                PrimitiveArray::from_option_iter([Some(10), None, Some(30), Some(40), Some(50)])
                    .into_array(),
            ),
        ])
        .unwrap()
        .into_array();

        let expr = case_when(
            gt(get_item("value", root()), lit(2i32)),
            get_item("result", root()),
            lit(0i32),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(0i32), Some(0), Some(30), Some(40), Some(50)])
                .into_array()
        );
    }

    #[test]
    fn test_evaluate_with_all_null_condition() {
        let test_array = StructArray::from_fields(&[(
            "cond",
            BoolArray::from_iter([None, None, None]).into_array(),
        )])
        .unwrap()
        .into_array();

        let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![0i32, 0, 0].into_array());
    }

    #[test]
    fn test_evaluate_nary_no_else_returns_null() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        let expr = nested_case_when(
            vec![
                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
            ],
            None,
        );

        let result = evaluate_expr(&expr, &test_array);
        assert!(result.dtype().is_nullable());
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(30), None, None])
                .into_array()
        );
    }

    #[test]
    fn test_evaluate_nary_many_conditions() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();

        let expr = nested_case_when(
            vec![
                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
                (eq(get_item("value", root()), lit(2i32)), lit(20i32)),
                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
                (eq(get_item("value", root()), lit(4i32)), lit(40i32)),
                (eq(get_item("value", root()), lit(5i32)), lit(50i32)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![10i32, 20, 30, 40, 50].into_array());
    }

    #[test]
    fn test_evaluate_nary_all_false_no_else() {
        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
            .unwrap()
            .into_array();

        let expr = nested_case_when(
            vec![
                (gt(get_item("value", root()), lit(100i32)), lit(10i32)),
                (gt(get_item("value", root()), lit(200i32)), lit(20i32)),
            ],
            None,
        );

        let result = evaluate_expr(&expr, &test_array);
        assert!(result.dtype().is_nullable());
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([None::<i32>, None, None]).into_array()
        );
    }

    #[test]
    fn test_evaluate_nary_overlapping_conditions_first_wins() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![10i32, 20, 30].into_array())])
                .unwrap()
                .into_array();

        // Every value satisfies the first condition, so later pairs never apply.
        let expr = nested_case_when(
            vec![
                (gt(get_item("value", root()), lit(5i32)), lit(1i32)),
                (gt(get_item("value", root()), lit(0i32)), lit(2i32)),
                (gt(get_item("value", root()), lit(15i32)), lit(3i32)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![1i32, 1, 1].into_array());
    }

    #[test]
    fn test_evaluate_nary_early_exit_when_remaining_empty() {
        // The first pair claims every row, so evaluation stops before the second.
        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
            .unwrap()
            .into_array();

        let expr = nested_case_when(
            vec![
                (gt(get_item("value", root()), lit(0i32)), lit(100i32)),
                (gt(get_item("value", root()), lit(0i32)), lit(999i32)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array());
    }

    #[test]
    fn test_evaluate_nary_skips_branch_with_empty_effective_mask() {
        // The second pair only matches row 0, which the first pair already claimed,
        // so its effective mask is empty and it is skipped.
        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
            .unwrap()
            .into_array();

        let expr = nested_case_when(
            vec![
                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
                (eq(get_item("value", root()), lit(1i32)), lit(999i32)),
                (eq(get_item("value", root()), lit(2i32)), lit(20i32)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array());
    }

    #[test]
    fn test_evaluate_nary_string_output() -> VortexResult<()> {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4].into_array())])
                .unwrap()
                .into_array();

        let expr = nested_case_when(
            vec![
                (gt(get_item("value", root()), lit(2i32)), lit("high")),
                (gt(get_item("value", root()), lit(0i32)), lit("low")),
            ],
            Some(lit("none")),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_eq!(
            result.execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())?,
            Scalar::utf8("low", Nullability::NonNullable)
        );
        assert_eq!(
            result.execute_scalar(1, &mut LEGACY_SESSION.create_execution_ctx())?,
            Scalar::utf8("low", Nullability::NonNullable)
        );
        assert_eq!(
            result.execute_scalar(2, &mut LEGACY_SESSION.create_execution_ctx())?,
            Scalar::utf8("high", Nullability::NonNullable)
        );
        assert_eq!(
            result.execute_scalar(3, &mut LEGACY_SESSION.create_execution_ctx())?,
            Scalar::utf8("high", Nullability::NonNullable)
        );
        Ok(())
    }

    #[test]
    fn test_evaluate_nary_with_nullable_conditions() {
        // Row 1: cond1 is null (false) so cond2 applies. Row 2: cond1 is false and
        // cond2 is null, so the ELSE value applies.
        let test_array = StructArray::from_fields(&[
            (
                "cond1",
                BoolArray::from_iter([Some(true), None, Some(false)]).into_array(),
            ),
            (
                "cond2",
                BoolArray::from_iter([Some(false), Some(true), None]).into_array(),
            ),
        ])
        .unwrap()
        .into_array();

        let expr = nested_case_when(
            vec![
                (get_item("cond1", root()), lit(10i32)),
                (get_item("cond2", root()), lit(20i32)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array());
    }

    #[test]
    fn test_merge_case_branches_alternating_mask() -> VortexResult<()> {
        // 100 one-row spans exceed 100 / SLICE_CROSSOVER_RUN_LEN, so this drives
        // the row-by-row merge path.
        let n = 100usize;

        let branch0_mask = Mask::from_indices(n, (0..n).step_by(2).collect());
        let branch1_mask = Mask::from_indices(n, (1..n).step_by(2).collect());

        let result = merge_case_branches(
            vec![
                (
                    branch0_mask,
                    PrimitiveArray::from_option_iter(vec![Some(0i32); n]).into_array(),
                ),
                (
                    branch1_mask,
                    PrimitiveArray::from_option_iter(vec![Some(1i32); n]).into_array(),
                ),
            ],
            PrimitiveArray::from_option_iter(vec![Some(99i32); n]).into_array(),
            &mut SESSION.create_execution_ctx(),
        )?;

        let expected: Vec<Option<i32>> = (0..n)
            .map(|v| if v % 2 == 0 { Some(0) } else { Some(1) })
            .collect();
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter(expected).into_array()
        );
        Ok(())
    }
}