1use std::fmt;
14use std::fmt::Formatter;
15use std::hash::Hash;
16use std::sync::Arc;
17
18use prost::Message;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21use vortex_mask::AllOr;
22use vortex_mask::Mask;
23use vortex_proto::expr as pb;
24use vortex_session::VortexSession;
25
26use crate::ArrayRef;
27use crate::ExecutionCtx;
28use crate::IntoArray;
29use crate::arrays::BoolArray;
30use crate::arrays::ConstantArray;
31use crate::builders::ArrayBuilder;
32use crate::builders::builder_with_capacity;
33use crate::builtins::ArrayBuiltins;
34use crate::dtype::DType;
35use crate::expr::Expression;
36use crate::scalar::Scalar;
37use crate::scalar_fn::Arity;
38use crate::scalar_fn::ChildName;
39use crate::scalar_fn::ExecutionArgs;
40use crate::scalar_fn::ScalarFnId;
41use crate::scalar_fn::ScalarFnVTable;
42use crate::scalar_fn::fns::zip::zip_impl;
43
/// Shape of a `CASE WHEN` expression: how many WHEN/THEN pairs it has and
/// whether a trailing ELSE child is present.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CaseWhenOptions {
    /// Number of WHEN/THEN pairs; each pair contributes two children.
    pub num_when_then_pairs: u32,
    /// Whether one extra child (the ELSE value) follows the pairs.
    pub has_else: bool,
}

impl CaseWhenOptions {
    /// Total number of children: two per WHEN/THEN pair, plus one if `has_else`.
    pub fn num_children(&self) -> usize {
        let pair_children = 2 * self.num_when_then_pairs as usize;
        if self.has_else {
            pair_children + 1
        } else {
            pair_children
        }
    }
}

impl fmt::Display for CaseWhenOptions {
    /// Renders as `case_when(pairs=<n>, else=<bool>)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self {
            num_when_then_pairs,
            has_else,
        } = self;
        write!(
            f,
            "case_when(pairs={}, else={})",
            num_when_then_pairs, has_else
        )
    }
}
70
/// Stateless marker type whose `ScalarFnVTable` implementation provides the
/// `vortex.case_when` scalar function; all per-expression configuration lives
/// in `CaseWhenOptions`.
#[derive(Clone)]
pub struct CaseWhen;
76
77impl ScalarFnVTable for CaseWhen {
78 type Options = CaseWhenOptions;
79
80 fn id(&self) -> ScalarFnId {
81 ScalarFnId::from("vortex.case_when")
82 }
83
84 fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
85 vortex_bail!("cannot serialize")
89 }
90
91 fn deserialize(
92 &self,
93 metadata: &[u8],
94 _session: &VortexSession,
95 ) -> VortexResult<Self::Options> {
96 let opts = pb::CaseWhenOpts::decode(metadata)?;
97 if opts.num_children < 2 {
98 vortex_bail!(
99 "CaseWhen expects at least 2 children, got {}",
100 opts.num_children
101 );
102 }
103 Ok(CaseWhenOptions {
104 num_when_then_pairs: opts.num_children / 2,
105 has_else: opts.num_children % 2 == 1,
106 })
107 }
108
109 fn arity(&self, options: &Self::Options) -> Arity {
110 Arity::Exact(options.num_children())
111 }
112
113 fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName {
114 let num_pair_children = options.num_when_then_pairs as usize * 2;
115 if child_idx < num_pair_children {
116 let pair_idx = child_idx / 2;
117 if child_idx.is_multiple_of(2) {
118 ChildName::from(Arc::from(format!("when_{pair_idx}")))
119 } else {
120 ChildName::from(Arc::from(format!("then_{pair_idx}")))
121 }
122 } else if options.has_else && child_idx == num_pair_children {
123 ChildName::from("else")
124 } else {
125 unreachable!("Invalid child index {} for CaseWhen", child_idx)
126 }
127 }
128
129 fn fmt_sql(
130 &self,
131 options: &Self::Options,
132 expr: &Expression,
133 f: &mut Formatter<'_>,
134 ) -> fmt::Result {
135 write!(f, "CASE")?;
136 for i in 0..options.num_when_then_pairs as usize {
137 write!(
138 f,
139 " WHEN {} THEN {}",
140 expr.child(i * 2),
141 expr.child(i * 2 + 1)
142 )?;
143 }
144 if options.has_else {
145 let else_idx = options.num_when_then_pairs as usize * 2;
146 write!(f, " ELSE {}", expr.child(else_idx))?;
147 }
148 write!(f, " END")
149 }
150
151 fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
152 if options.num_when_then_pairs == 0 {
153 vortex_bail!("CaseWhen must have at least one WHEN/THEN pair");
154 }
155
156 let expected_len = options.num_children();
157 if arg_dtypes.len() != expected_len {
158 vortex_bail!(
159 "CaseWhen expects {expected_len} argument dtypes, got {}",
160 arg_dtypes.len()
161 );
162 }
163
164 let first_then = &arg_dtypes[1];
168 let mut result_dtype = first_then.clone();
169
170 for i in 1..options.num_when_then_pairs as usize {
171 let then_i = &arg_dtypes[i * 2 + 1];
172 if !first_then.eq_ignore_nullability(then_i) {
173 vortex_bail!(
174 "CaseWhen THEN dtypes must match (ignoring nullability), got {} and {}",
175 first_then,
176 then_i
177 );
178 }
179 result_dtype = result_dtype.union_nullability(then_i.nullability());
180 }
181
182 if options.has_else {
183 let else_dtype = &arg_dtypes[options.num_when_then_pairs as usize * 2];
184 if !result_dtype.eq_ignore_nullability(else_dtype) {
185 vortex_bail!(
186 "CaseWhen THEN and ELSE dtypes must match (ignoring nullability), got {} and {}",
187 first_then,
188 else_dtype
189 );
190 }
191 result_dtype = result_dtype.union_nullability(else_dtype.nullability());
192 } else {
193 result_dtype = result_dtype.as_nullable();
195 }
196
197 Ok(result_dtype)
198 }
199
200 fn execute(
201 &self,
202 options: &Self::Options,
203 args: &dyn ExecutionArgs,
204 ctx: &mut ExecutionCtx,
205 ) -> VortexResult<ArrayRef> {
206 let row_count = args.row_count();
213 let num_pairs = options.num_when_then_pairs as usize;
214
215 let mut remaining = Mask::new_true(row_count);
216 let mut branches: Vec<(Mask, ArrayRef)> = Vec::with_capacity(num_pairs);
217
218 for i in 0..num_pairs {
219 if remaining.all_false() {
220 break;
221 }
222
223 let condition = args.get(i * 2)?;
224 let cond_bool = condition.execute::<BoolArray>(ctx)?;
225 let cond_mask = cond_bool.to_mask_fill_null_false();
226 let effective_mask = &remaining & &cond_mask;
227
228 if effective_mask.all_false() {
229 continue;
230 }
231
232 let then_value = args.get(i * 2 + 1)?;
233 remaining = remaining.bitand_not(&cond_mask);
234 branches.push((effective_mask, then_value));
235 }
236
237 let else_value: ArrayRef = if options.has_else {
238 args.get(num_pairs * 2)?
239 } else {
240 let then_dtype = args.get(1)?.dtype().as_nullable();
241 ConstantArray::new(Scalar::null(then_dtype), row_count).into_array()
242 };
243
244 if branches.is_empty() {
245 return Ok(else_value);
246 }
247
248 merge_case_branches(branches, else_value)
249 }
250
251 fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
252 true
253 }
254
255 fn is_fallible(&self, _options: &Self::Options) -> bool {
256 false
257 }
258}
259
/// Divisor used by `merge_case_branches` to choose a merge strategy: when the
/// number of spans exceeds `len / SLICE_CROSSOVER_RUN_LEN`, rows are appended
/// one scalar at a time; otherwise whole runs are sliced and appended.
const SLICE_CROSSOVER_RUN_LEN: usize = 4;
263
264fn merge_case_branches(
268 branches: Vec<(Mask, ArrayRef)>,
269 else_value: ArrayRef,
270) -> VortexResult<ArrayRef> {
271 if branches.len() == 1 {
272 let (mask, then_value) = &branches[0];
273 return zip_impl(then_value, &else_value, mask);
274 }
275
276 let output_nullability = branches
277 .iter()
278 .fold(else_value.dtype().nullability(), |acc, (_, arr)| {
279 acc | arr.dtype().nullability()
280 });
281 let output_dtype = else_value.dtype().with_nullability(output_nullability);
282 let branch_arrays: Vec<&ArrayRef> = branches.iter().map(|(_, arr)| arr).collect();
283
284 let mut spans: Vec<(usize, usize, usize)> = Vec::new();
285 for (branch_idx, (mask, _)) in branches.iter().enumerate() {
286 match mask.slices() {
287 AllOr::All => return branch_arrays[branch_idx].cast(output_dtype),
288 AllOr::None => {}
289 AllOr::Some(slices) => {
290 for &(start, end) in slices {
291 spans.push((start, end, branch_idx));
292 }
293 }
294 }
295 }
296 spans.sort_unstable_by_key(|&(start, ..)| start);
297
298 if spans.is_empty() {
299 return else_value.cast(output_dtype);
300 }
301
302 let builder = builder_with_capacity(&output_dtype, else_value.len());
303
304 let fragmented = spans.len() > else_value.len() / SLICE_CROSSOVER_RUN_LEN;
305 if fragmented {
306 merge_row_by_row(&branch_arrays, &else_value, &spans, &output_dtype, builder)
307 } else {
308 merge_run_by_run(&branch_arrays, &else_value, &spans, &output_dtype, builder)
309 }
310}
311
312fn merge_row_by_row(
315 branch_arrays: &[&ArrayRef],
316 else_value: &ArrayRef,
317 spans: &[(usize, usize, usize)],
318 output_dtype: &DType,
319 mut builder: Box<dyn ArrayBuilder>,
320) -> VortexResult<ArrayRef> {
321 let mut pos = 0;
322 for &(start, end, branch_idx) in spans {
323 for row in pos..start {
324 let scalar = else_value.scalar_at(row)?;
325 builder.append_scalar(&scalar.cast(output_dtype)?)?;
326 }
327 for row in start..end {
328 let scalar = branch_arrays[branch_idx].scalar_at(row)?;
329 builder.append_scalar(&scalar.cast(output_dtype)?)?;
330 }
331 pos = end;
332 }
333 for row in pos..else_value.len() {
334 let scalar = else_value.scalar_at(row)?;
335 builder.append_scalar(&scalar.cast(output_dtype)?)?;
336 }
337
338 Ok(builder.finish())
339}
340
341fn merge_run_by_run(
345 branch_arrays: &[&ArrayRef],
346 else_value: &ArrayRef,
347 spans: &[(usize, usize, usize)],
348 output_dtype: &DType,
349 mut builder: Box<dyn ArrayBuilder>,
350) -> VortexResult<ArrayRef> {
351 let else_value = else_value.cast(output_dtype.clone())?;
352 let len = else_value.len();
353 for (start, end, branch_idx) in spans {
354 if builder.len() < *start {
355 builder.extend_from_array(&else_value.slice(builder.len()..*start)?);
356 }
357 builder.extend_from_array(
358 &branch_arrays[*branch_idx]
359 .cast(output_dtype.clone())?
360 .slice(*start..*end)?,
361 );
362 }
363 if builder.len() < len {
364 builder.extend_from_array(&else_value.slice(builder.len()..len)?);
365 }
366
367 Ok(builder.finish())
368}
369
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use vortex_buffer::buffer;
    use vortex_error::VortexExpect as _;
    use vortex_session::VortexSession;

    use super::*;
    use crate::Canonical;
    use crate::IntoArray;
    use crate::VortexSessionExecute as _;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::case_when;
    use crate::expr::case_when_no_else;
    use crate::expr::col;
    use crate::expr::eq;
    use crate::expr::get_item;
    use crate::expr::gt;
    use crate::expr::lit;
    use crate::expr::nested_case_when;
    use crate::expr::root;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;
    use crate::session::ArraySession;

    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    // ---------------------------------------------------------------------
    // Shared fixtures
    // ---------------------------------------------------------------------

    /// Applies `expr` to `array` and executes the result to canonical form.
    fn evaluate_expr(expr: &Expression, array: &ArrayRef) -> ArrayRef {
        let mut ctx = SESSION.create_execution_ctx();
        let applied = array.apply(expr).unwrap();
        applied.execute::<Canonical>(&mut ctx).unwrap().into_array()
    }

    /// Wraps `column` in a struct array with a single field called `name`.
    fn single_field(name: &str, column: ArrayRef) -> ArrayRef {
        StructArray::from_fields(&[(name, column)])
            .unwrap()
            .into_array()
    }

    /// Struct array `{ value: [1, 2, 3, 4, 5] }`.
    fn values_1_to_5() -> ArrayRef {
        single_field("value", buffer![1i32, 2, 3, 4, 5].into_array())
    }

    /// Struct array `{ value: [1, 2, 3] }`.
    fn values_1_to_3() -> ArrayRef {
        single_field("value", buffer![1i32, 2, 3].into_array())
    }

    /// Expression reading the `value` field of the root struct.
    fn value() -> Expression {
        get_item("value", root())
    }

    /// `I32` primitive dtype with the given nullability.
    fn i32_dtype(nullability: Nullability) -> DType {
        DType::Primitive(PType::I32, nullability)
    }

    /// Shorthand constructor for `CaseWhenOptions`.
    fn options(num_when_then_pairs: u32, has_else: bool) -> CaseWhenOptions {
        CaseWhenOptions {
            num_when_then_pairs,
            has_else,
        }
    }

    /// Serializes then deserializes `opts` and checks they survive unchanged.
    fn assert_roundtrip(opts: CaseWhenOptions) {
        let serialized = CaseWhen.serialize(&opts).unwrap().unwrap();
        let deserialized = CaseWhen
            .deserialize(&serialized, &VortexSession::empty())
            .unwrap();
        assert_eq!(opts, deserialized);
    }

    // ---------------------------------------------------------------------
    // Serialization (currently expected to fail with "cannot serialize")
    // ---------------------------------------------------------------------

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_roundtrip() {
        assert_roundtrip(options(1, true));
    }

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_no_else() {
        assert_roundtrip(options(1, false));
    }

    // ---------------------------------------------------------------------
    // Display
    // ---------------------------------------------------------------------

    #[test]
    fn test_display_with_else() {
        let expr = case_when(gt(col("value"), lit(0i32)), lit(100i32), lit(0i32));
        let display = expr.to_string();
        for keyword in ["CASE", "WHEN", "THEN", "ELSE", "END"] {
            assert!(display.contains(keyword));
        }
    }

    #[test]
    fn test_display_no_else() {
        let expr = case_when_no_else(gt(col("value"), lit(0i32)), lit(100i32));
        let display = expr.to_string();
        for keyword in ["CASE", "WHEN", "THEN"] {
            assert!(display.contains(keyword));
        }
        assert!(!display.contains("ELSE"));
        assert!(display.contains("END"));
    }

    #[test]
    fn test_display_nested_nary() {
        let branches = vec![
            (gt(col("x"), lit(10i32)), lit("high")),
            (gt(col("x"), lit(5i32)), lit("medium")),
        ];
        let expr = nested_case_when(branches, Some(lit("low")));
        let display = expr.to_string();
        assert_eq!(display.matches("CASE").count(), 1);
        assert_eq!(display.matches("WHEN").count(), 2);
        assert_eq!(display.matches("THEN").count(), 2);
    }

    // ---------------------------------------------------------------------
    // Return dtype
    // ---------------------------------------------------------------------

    #[test]
    fn test_return_dtype_with_else() {
        let expr = case_when(lit(true), lit(100i32), lit(0i32));
        let input_dtype = i32_dtype(Nullability::NonNullable);
        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(result_dtype, i32_dtype(Nullability::NonNullable));
    }

    #[test]
    fn test_return_dtype_with_nullable_else() {
        let null_else = lit(Scalar::null(i32_dtype(Nullability::Nullable)));
        let expr = case_when(lit(true), lit(100i32), null_else);
        let input_dtype = i32_dtype(Nullability::NonNullable);
        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(result_dtype, i32_dtype(Nullability::Nullable));
    }

    #[test]
    fn test_return_dtype_without_else_is_nullable() {
        let expr = case_when_no_else(lit(true), lit(100i32));
        let input_dtype = i32_dtype(Nullability::NonNullable);
        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(result_dtype, i32_dtype(Nullability::Nullable));
    }

    #[test]
    fn test_return_dtype_with_struct_input() {
        let dtype = test_harness::struct_dtype();
        let condition = gt(get_item("col1", root()), lit(10u16));
        let expr = case_when(condition, lit(100i32), lit(0i32));
        let result_dtype = expr.return_dtype(&dtype).unwrap();
        assert_eq!(result_dtype, i32_dtype(Nullability::NonNullable));
    }

    #[test]
    fn test_return_dtype_mismatched_then_else_errors() {
        let expr = case_when(lit(true), lit(100i32), lit("zero"));
        let input_dtype = i32_dtype(Nullability::NonNullable);
        let err = expr.return_dtype(&input_dtype).unwrap_err();
        assert!(
            err.to_string()
                .contains("THEN and ELSE dtypes must match (ignoring nullability)")
        );
    }

    // ---------------------------------------------------------------------
    // Arity and child names
    // ---------------------------------------------------------------------

    #[test]
    fn test_arity_with_else() {
        assert_eq!(CaseWhen.arity(&options(1, true)), Arity::Exact(3));
    }

    #[test]
    fn test_arity_without_else() {
        assert_eq!(CaseWhen.arity(&options(1, false)), Arity::Exact(2));
    }

    #[test]
    fn test_child_names() {
        let opts = options(1, true);
        for (idx, expected) in ["when_0", "then_0", "else"].into_iter().enumerate() {
            assert_eq!(CaseWhen.child_name(&opts, idx).to_string(), expected);
        }
    }

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_roundtrip_nary() {
        assert_roundtrip(options(3, true));
    }

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_roundtrip_nary_no_else() {
        assert_roundtrip(options(4, false));
    }

    #[test]
    fn test_arity_nary_with_else() {
        assert_eq!(CaseWhen.arity(&options(3, true)), Arity::Exact(7));
    }

    #[test]
    fn test_arity_nary_without_else() {
        assert_eq!(CaseWhen.arity(&options(3, false)), Arity::Exact(6));
    }

    #[test]
    fn test_child_names_nary() {
        let opts = options(3, true);
        let expected_names = [
            "when_0", "then_0", "when_1", "then_1", "when_2", "then_2", "else",
        ];
        for (idx, expected) in expected_names.into_iter().enumerate() {
            assert_eq!(CaseWhen.child_name(&opts, idx).to_string(), expected);
        }
    }

    #[test]
    fn test_return_dtype_nary_mismatched_then_types_errors() {
        let branches = vec![(lit(true), lit(100i32)), (lit(false), lit("oops"))];
        let expr = nested_case_when(branches, Some(lit(0i32)));
        let input_dtype = i32_dtype(Nullability::NonNullable);
        let err = expr.return_dtype(&input_dtype).unwrap_err();
        assert!(err.to_string().contains("THEN dtypes must match"));
    }

    #[test]
    fn test_return_dtype_nary_mixed_nullability() {
        let non_null_then = lit(100i32);
        let nullable_then = lit(Scalar::null(i32_dtype(Nullability::Nullable)));
        let branches = vec![(lit(true), non_null_then), (lit(false), nullable_then)];
        let expr = nested_case_when(branches, Some(lit(0i32)));
        let input_dtype = i32_dtype(Nullability::NonNullable);
        let result = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(result, i32_dtype(Nullability::Nullable));
    }

    #[test]
    fn test_return_dtype_nary_no_else_is_nullable() {
        let branches = vec![(lit(true), lit(10i32)), (lit(false), lit(20i32))];
        let expr = nested_case_when(branches, None);
        let input_dtype = i32_dtype(Nullability::NonNullable);
        let result = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(result, i32_dtype(Nullability::Nullable));
    }

    #[test]
    fn test_replace_children() {
        let expr = case_when(lit(true), lit(1i32), lit(0i32));
        expr.with_children([lit(false), lit(2i32), lit(3i32)])
            .vortex_expect("operation should succeed in test");
    }

    // ---------------------------------------------------------------------
    // Evaluation
    // ---------------------------------------------------------------------

    #[test]
    fn test_evaluate_simple_condition() {
        let test_array = values_1_to_5();
        let expr = case_when(gt(value(), lit(2i32)), lit(100i32), lit(0i32));

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array());
    }

    #[test]
    fn test_evaluate_nary_multiple_conditions() {
        let test_array = values_1_to_5();
        let branches = vec![
            (eq(value(), lit(1i32)), lit(10i32)),
            (eq(value(), lit(3i32)), lit(30i32)),
        ];
        let expr = nested_case_when(branches, Some(lit(0i32)));

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![10i32, 0, 30, 0, 0].into_array());
    }

    #[test]
    fn test_evaluate_nary_first_match_wins() {
        let test_array = values_1_to_5();
        let branches = vec![
            (gt(value(), lit(2i32)), lit(100i32)),
            (gt(value(), lit(3i32)), lit(200i32)),
        ];
        let expr = nested_case_when(branches, Some(lit(0i32)));

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array());
    }

    #[test]
    fn test_evaluate_no_else_returns_null() {
        let test_array = values_1_to_5();
        let expr = case_when_no_else(gt(value(), lit(3i32)), lit(100i32));

        let result = evaluate_expr(&expr, &test_array);
        assert!(result.dtype().is_nullable());
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([None::<i32>, None, None, Some(100), Some(100)])
                .into_array()
        );
    }

    #[test]
    fn test_evaluate_all_conditions_false() {
        let test_array = values_1_to_5();
        let expr = case_when(gt(value(), lit(100i32)), lit(1i32), lit(0i32));

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![0i32, 0, 0, 0, 0].into_array());
    }

    #[test]
    fn test_evaluate_all_conditions_true() {
        let test_array = values_1_to_5();
        let expr = case_when(gt(value(), lit(0i32)), lit(100i32), lit(0i32));

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![100i32, 100, 100, 100, 100].into_array());
    }

    #[test]
    fn test_evaluate_all_true_no_else_returns_correct_dtype() {
        let test_array = values_1_to_3();
        let expr = case_when_no_else(gt(value(), lit(0i32)), lit(100i32));

        let result = evaluate_expr(&expr, &test_array);
        assert!(
            result.dtype().is_nullable(),
            "result dtype must be Nullable, got {:?}",
            result.dtype()
        );
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(100i32), Some(100), Some(100)]).into_array()
        );
    }

    #[test]
    fn test_merge_case_branches_widens_nullability_of_later_branch() -> VortexResult<()> {
        let test_array = single_field("value", buffer![0i32, 1, 2].into_array());

        let nullable_20 = Scalar::from(20i32).cast(&i32_dtype(Nullability::Nullable))?;

        let branches = vec![
            (eq(value(), lit(0i32)), lit(10i32)),
            (eq(value(), lit(1i32)), lit(nullable_20)),
        ];
        let expr = nested_case_when(branches, Some(lit(0i32)));

        let result = evaluate_expr(&expr, &test_array);
        assert!(
            result.dtype().is_nullable(),
            "result dtype must be Nullable, got {:?}",
            result.dtype()
        );
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(10), Some(20), Some(0)]).into_array()
        );
        Ok(())
    }

    #[test]
    fn test_evaluate_with_literal_condition() {
        let test_array = buffer![1i32, 2, 3].into_array();
        let expr = case_when(lit(true), lit(100i32), lit(0i32));

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array());
    }

    #[test]
    fn test_evaluate_with_bool_column_result() {
        let test_array = values_1_to_5();
        let expr = case_when(gt(value(), lit(2i32)), lit(true), lit(false));

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(
            result,
            BoolArray::from_iter([false, false, true, true, true]).into_array()
        );
    }

    #[test]
    fn test_evaluate_with_nullable_condition() {
        let test_array = single_field(
            "cond",
            BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)]).into_array(),
        );
        let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![100i32, 0, 0, 0, 100].into_array());
    }

    #[test]
    fn test_evaluate_with_nullable_result_values() {
        let test_array = StructArray::from_fields(&[
            ("value", buffer![1i32, 2, 3, 4, 5].into_array()),
            (
                "result",
                PrimitiveArray::from_option_iter([Some(10), None, Some(30), Some(40), Some(50)])
                    .into_array(),
            ),
        ])
        .unwrap()
        .into_array();

        let expr = case_when(
            gt(value(), lit(2i32)),
            get_item("result", root()),
            lit(0i32),
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(0i32), Some(0), Some(30), Some(40), Some(50)])
                .into_array()
        );
    }

    #[test]
    fn test_evaluate_with_all_null_condition() {
        let test_array = single_field(
            "cond",
            BoolArray::from_iter([None, None, None]).into_array(),
        );
        let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![0i32, 0, 0].into_array());
    }

    #[test]
    fn test_evaluate_nary_no_else_returns_null() {
        let test_array = values_1_to_5();
        let branches = vec![
            (eq(value(), lit(1i32)), lit(10i32)),
            (eq(value(), lit(3i32)), lit(30i32)),
        ];
        let expr = nested_case_when(branches, None);

        let result = evaluate_expr(&expr, &test_array);
        assert!(result.dtype().is_nullable());
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(30), None, None])
                .into_array()
        );
    }

    #[test]
    fn test_evaluate_nary_many_conditions() {
        let test_array = values_1_to_5();
        // One branch per input value k in 1..=5, mapping k to 10 * k.
        let branches = (1..=5i32)
            .map(|k| (eq(value(), lit(k)), lit(k * 10)))
            .collect();
        let expr = nested_case_when(branches, Some(lit(0i32)));

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![10i32, 20, 30, 40, 50].into_array());
    }

    #[test]
    fn test_evaluate_nary_all_false_no_else() {
        let test_array = values_1_to_3();
        let branches = vec![
            (gt(value(), lit(100i32)), lit(10i32)),
            (gt(value(), lit(200i32)), lit(20i32)),
        ];
        let expr = nested_case_when(branches, None);

        let result = evaluate_expr(&expr, &test_array);
        assert!(result.dtype().is_nullable());
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([None::<i32>, None, None]).into_array()
        );
    }

    #[test]
    fn test_evaluate_nary_overlapping_conditions_first_wins() {
        let test_array = single_field("value", buffer![10i32, 20, 30].into_array());
        let branches = vec![
            (gt(value(), lit(5i32)), lit(1i32)),
            (gt(value(), lit(0i32)), lit(2i32)),
            (gt(value(), lit(15i32)), lit(3i32)),
        ];
        let expr = nested_case_when(branches, Some(lit(0i32)));

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![1i32, 1, 1].into_array());
    }

    #[test]
    fn test_evaluate_nary_early_exit_when_remaining_empty() {
        let test_array = values_1_to_3();
        let branches = vec![
            (gt(value(), lit(0i32)), lit(100i32)),
            (gt(value(), lit(0i32)), lit(999i32)),
        ];
        let expr = nested_case_when(branches, Some(lit(0i32)));

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array());
    }

    #[test]
    fn test_evaluate_nary_skips_branch_with_empty_effective_mask() {
        let test_array = values_1_to_3();
        let branches = vec![
            (eq(value(), lit(1i32)), lit(10i32)),
            (eq(value(), lit(1i32)), lit(999i32)),
            (eq(value(), lit(2i32)), lit(20i32)),
        ];
        let expr = nested_case_when(branches, Some(lit(0i32)));

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array());
    }

    #[test]
    fn test_evaluate_nary_string_output() -> VortexResult<()> {
        let test_array = single_field("value", buffer![1i32, 2, 3, 4].into_array());
        let branches = vec![
            (gt(value(), lit(2i32)), lit("high")),
            (gt(value(), lit(0i32)), lit("low")),
        ];
        let expr = nested_case_when(branches, Some(lit("none")));

        let result = evaluate_expr(&expr, &test_array);
        for (row, expected) in ["low", "low", "high", "high"].into_iter().enumerate() {
            assert_eq!(
                result.scalar_at(row)?,
                Scalar::utf8(expected, Nullability::NonNullable)
            );
        }
        Ok(())
    }

    #[test]
    fn test_evaluate_nary_with_nullable_conditions() {
        let test_array = StructArray::from_fields(&[
            (
                "cond1",
                BoolArray::from_iter([Some(true), None, Some(false)]).into_array(),
            ),
            (
                "cond2",
                BoolArray::from_iter([Some(false), Some(true), None]).into_array(),
            ),
        ])
        .unwrap()
        .into_array();

        let branches = vec![
            (get_item("cond1", root()), lit(10i32)),
            (get_item("cond2", root()), lit(20i32)),
        ];
        let expr = nested_case_when(branches, Some(lit(0i32)));

        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array());
    }

    #[test]
    fn test_merge_case_branches_alternating_mask() -> VortexResult<()> {
        let n = 100usize;

        // Nullable constant column of length `n`.
        let constant = |v: i32| PrimitiveArray::from_option_iter(vec![Some(v); n]).into_array();

        // Even rows go to branch 0, odd rows to branch 1; no row reaches ELSE.
        let even_rows = Mask::from_indices(n, (0..n).step_by(2).collect());
        let odd_rows = Mask::from_indices(n, (1..n).step_by(2).collect());

        let result = merge_case_branches(
            vec![(even_rows, constant(0)), (odd_rows, constant(1))],
            constant(99),
        )?;

        let expected: Vec<Option<i32>> = (0..n).map(|row| Some((row % 2) as i32)).collect();
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter(expected).into_array()
        );
        Ok(())
    }
}
1207}