1use std::fmt;
14use std::fmt::Formatter;
15use std::hash::Hash;
16use std::sync::Arc;
17
18use prost::Message;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21use vortex_mask::AllOr;
22use vortex_mask::Mask;
23use vortex_proto::expr as pb;
24use vortex_session::VortexSession;
25use vortex_session::registry::CachedId;
26
27use crate::ArrayRef;
28use crate::ExecutionCtx;
29use crate::IntoArray;
30use crate::arrays::BoolArray;
31use crate::arrays::ConstantArray;
32use crate::arrays::bool::BoolArrayExt;
33use crate::builders::ArrayBuilder;
34use crate::builders::builder_with_capacity;
35use crate::builtins::ArrayBuiltins;
36use crate::dtype::DType;
37use crate::expr::Expression;
38use crate::scalar::Scalar;
39use crate::scalar_fn::Arity;
40use crate::scalar_fn::ChildName;
41use crate::scalar_fn::ExecutionArgs;
42use crate::scalar_fn::ScalarFnId;
43use crate::scalar_fn::ScalarFnVTable;
44use crate::scalar_fn::SimplifyCtx;
45use crate::scalar_fn::fns::is_not_null::IsNotNull;
46use crate::scalar_fn::fns::is_null::IsNull;
47use crate::scalar_fn::fns::literal::Literal;
48use crate::scalar_fn::fns::zip::zip_impl;
49
/// Options describing how many children a `CASE WHEN` expression has.
///
/// Children are counted as two per WHEN/THEN pair, plus one more when an
/// ELSE branch is present (see [`CaseWhenOptions::num_children`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CaseWhenOptions {
    /// Number of WHEN/THEN pairs.
    pub num_when_then_pairs: u32,
    /// Whether a trailing ELSE child is present.
    pub has_else: bool,
}

impl CaseWhenOptions {
    /// Total number of child expressions: two per WHEN/THEN pair, plus one
    /// if there is an ELSE branch.
    pub fn num_children(&self) -> usize {
        let pair_children = self.num_when_then_pairs as usize * 2;
        let else_children = usize::from(self.has_else);
        pair_children + else_children
    }
}

impl fmt::Display for CaseWhenOptions {
    /// Formats the options as `case_when(pairs=<n>, else=<bool>)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self {
            num_when_then_pairs: pairs,
            has_else,
        } = *self;
        write!(f, "case_when(pairs={}, else={})", pairs, has_else)
    }
}
76
/// Unit type that implements [`ScalarFnVTable`] for the `CASE WHEN` scalar
/// function (id `vortex.case_when`), configured by [`CaseWhenOptions`].
#[derive(Clone)]
pub struct CaseWhen;
82
83impl ScalarFnVTable for CaseWhen {
84 type Options = CaseWhenOptions;
85
86 fn id(&self) -> ScalarFnId {
87 static ID: CachedId = CachedId::new("vortex.case_when");
88 *ID
89 }
90
91 fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
92 vortex_bail!("cannot serialize")
96 }
97
98 fn deserialize(
99 &self,
100 metadata: &[u8],
101 _session: &VortexSession,
102 ) -> VortexResult<Self::Options> {
103 let opts = pb::CaseWhenOpts::decode(metadata)?;
104 if opts.num_children < 2 {
105 vortex_bail!(
106 "CaseWhen expects at least 2 children, got {}",
107 opts.num_children
108 );
109 }
110 Ok(CaseWhenOptions {
111 num_when_then_pairs: opts.num_children / 2,
112 has_else: opts.num_children % 2 == 1,
113 })
114 }
115
116 fn arity(&self, options: &Self::Options) -> Arity {
117 Arity::Exact(options.num_children())
118 }
119
120 fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName {
121 let num_pair_children = options.num_when_then_pairs as usize * 2;
122 if child_idx < num_pair_children {
123 let pair_idx = child_idx / 2;
124 if child_idx.is_multiple_of(2) {
125 ChildName::from(Arc::from(format!("when_{pair_idx}")))
126 } else {
127 ChildName::from(Arc::from(format!("then_{pair_idx}")))
128 }
129 } else if options.has_else && child_idx == num_pair_children {
130 ChildName::from("else")
131 } else {
132 unreachable!("Invalid child index {} for CaseWhen", child_idx)
133 }
134 }
135
136 fn fmt_sql(
137 &self,
138 options: &Self::Options,
139 expr: &Expression,
140 f: &mut Formatter<'_>,
141 ) -> fmt::Result {
142 write!(f, "CASE")?;
143 for i in 0..options.num_when_then_pairs as usize {
144 write!(
145 f,
146 " WHEN {} THEN {}",
147 expr.child(i * 2),
148 expr.child(i * 2 + 1)
149 )?;
150 }
151 if options.has_else {
152 let else_idx = options.num_when_then_pairs as usize * 2;
153 write!(f, " ELSE {}", expr.child(else_idx))?;
154 }
155 write!(f, " END")
156 }
157
158 fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
159 if options.num_when_then_pairs == 0 {
160 vortex_bail!("CaseWhen must have at least one WHEN/THEN pair");
161 }
162
163 let expected_len = options.num_children();
164 if arg_dtypes.len() != expected_len {
165 vortex_bail!(
166 "CaseWhen expects {expected_len} argument dtypes, got {}",
167 arg_dtypes.len()
168 );
169 }
170
171 let first_then = &arg_dtypes[1];
175 let mut result_dtype = first_then.clone();
176
177 for i in 1..options.num_when_then_pairs as usize {
178 let then_i = &arg_dtypes[i * 2 + 1];
179 if !first_then.eq_ignore_nullability(then_i) {
180 vortex_bail!(
181 "CaseWhen THEN dtypes must match (ignoring nullability), got {} and {}",
182 first_then,
183 then_i
184 );
185 }
186 result_dtype = result_dtype.union_nullability(then_i.nullability());
187 }
188
189 if options.has_else {
190 let else_dtype = &arg_dtypes[options.num_when_then_pairs as usize * 2];
191 if !result_dtype.eq_ignore_nullability(else_dtype) {
192 vortex_bail!(
193 "CaseWhen THEN and ELSE dtypes must match (ignoring nullability), got {} and {}",
194 first_then,
195 else_dtype
196 );
197 }
198 result_dtype = result_dtype.union_nullability(else_dtype.nullability());
199 } else {
200 result_dtype = result_dtype.as_nullable();
202 }
203
204 Ok(result_dtype)
205 }
206
207 fn execute(
208 &self,
209 options: &Self::Options,
210 args: &dyn ExecutionArgs,
211 ctx: &mut ExecutionCtx,
212 ) -> VortexResult<ArrayRef> {
213 let row_count = args.row_count();
220 let num_pairs = options.num_when_then_pairs as usize;
221
222 let mut remaining = Mask::new_true(row_count);
223 let mut branches: Vec<(Mask, ArrayRef)> = Vec::with_capacity(num_pairs);
224
225 for i in 0..num_pairs {
226 if remaining.all_false() {
227 break;
228 }
229
230 let condition = args.get(i * 2)?;
231 let cond_bool = condition.execute::<BoolArray>(ctx)?;
232 let cond_mask = cond_bool.to_mask_fill_null_false(ctx);
233 let effective_mask = &remaining & &cond_mask;
234
235 if effective_mask.all_false() {
236 continue;
237 }
238
239 let then_value = args.get(i * 2 + 1)?;
240 remaining = remaining.bitand_not(&cond_mask);
241 branches.push((effective_mask, then_value));
242 }
243
244 let else_value: ArrayRef = if options.has_else {
245 args.get(num_pairs * 2)?
246 } else {
247 let then_dtype = args.get(1)?.dtype().as_nullable();
248 ConstantArray::new(Scalar::null(then_dtype), row_count).into_array()
249 };
250
251 if branches.is_empty() {
252 return Ok(else_value);
253 }
254
255 merge_case_branches(branches, else_value, ctx)
256 }
257
258 fn simplify(
259 &self,
260 options: &Self::Options,
261 expr: &Expression,
262 _ctx: &dyn SimplifyCtx,
263 ) -> VortexResult<Option<Expression>> {
264 if options.num_when_then_pairs != 1 || !options.has_else {
274 return Ok(None);
275 }
276
277 let when = expr.child(0);
278 let then = expr.child(1);
279 let els = expr.child(2);
280
281 let (x, fill) = if when.is::<IsNull>() && when.child(0) == els {
283 (els, then)
284 } else if when.is::<IsNotNull>() && when.child(0) == then {
286 (then, els)
287 } else {
288 return Ok(None);
289 };
290
291 let Some(scalar) = fill.as_opt::<Literal>() else {
292 return Ok(None);
293 };
294
295 if scalar.is_null() {
296 return Ok(Some(x.clone()));
298 }
299
300 Ok(Some(crate::expr::fill_null(x.clone(), fill.clone())))
301 }
302
303 fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
304 true
305 }
306
307 fn is_fallible(&self, _options: &Self::Options) -> bool {
308 false
309 }
310}
311
/// Crossover divisor used by `merge_case_branches`: when the number of spans
/// exceeds `len / SLICE_CROSSOVER_RUN_LEN`, branches are merged row by row
/// instead of run by run.
const SLICE_CROSSOVER_RUN_LEN: usize = 4;
315
316fn merge_case_branches(
320 branches: Vec<(Mask, ArrayRef)>,
321 else_value: ArrayRef,
322 ctx: &mut ExecutionCtx,
323) -> VortexResult<ArrayRef> {
324 if branches.len() == 1 {
325 let (mask, then_value) = &branches[0];
326 return zip_impl(then_value, &else_value, mask, ctx);
327 }
328
329 let output_nullability = branches
330 .iter()
331 .fold(else_value.dtype().nullability(), |acc, (_, arr)| {
332 acc | arr.dtype().nullability()
333 });
334 let output_dtype = else_value.dtype().with_nullability(output_nullability);
335 let branch_arrays: Vec<&ArrayRef> = branches.iter().map(|(_, arr)| arr).collect();
336
337 let mut spans: Vec<(usize, usize, usize)> = Vec::new();
338 for (branch_idx, (mask, _)) in branches.iter().enumerate() {
339 match mask.slices() {
340 AllOr::All => return branch_arrays[branch_idx].cast(output_dtype),
341 AllOr::None => {}
342 AllOr::Some(slices) => {
343 for &(start, end) in slices {
344 spans.push((start, end, branch_idx));
345 }
346 }
347 }
348 }
349 spans.sort_unstable_by_key(|&(start, ..)| start);
350
351 if spans.is_empty() {
352 return else_value.cast(output_dtype);
353 }
354
355 let builder = builder_with_capacity(&output_dtype, else_value.len());
356
357 let fragmented = spans.len() > else_value.len() / SLICE_CROSSOVER_RUN_LEN;
358 if fragmented {
359 merge_row_by_row(
360 &branch_arrays,
361 &else_value,
362 &spans,
363 &output_dtype,
364 builder,
365 ctx,
366 )
367 } else {
368 merge_run_by_run(
369 &branch_arrays,
370 &else_value,
371 &spans,
372 &output_dtype,
373 builder,
374 ctx,
375 )
376 }
377}
378
379fn merge_row_by_row(
382 branch_arrays: &[&ArrayRef],
383 else_value: &ArrayRef,
384 spans: &[(usize, usize, usize)],
385 output_dtype: &DType,
386 mut builder: Box<dyn ArrayBuilder>,
387 ctx: &mut ExecutionCtx,
388) -> VortexResult<ArrayRef> {
389 let mut pos = 0;
390 for &(start, end, branch_idx) in spans {
391 for row in pos..start {
392 let scalar = else_value.execute_scalar(row, ctx)?;
393 builder.append_scalar(&scalar.cast(output_dtype)?)?;
394 }
395 for row in start..end {
396 let scalar = branch_arrays[branch_idx].execute_scalar(row, ctx)?;
397 builder.append_scalar(&scalar.cast(output_dtype)?)?;
398 }
399 pos = end;
400 }
401 for row in pos..else_value.len() {
402 let scalar = else_value.execute_scalar(row, ctx)?;
403 builder.append_scalar(&scalar.cast(output_dtype)?)?;
404 }
405
406 Ok(builder.finish())
407}
408
409fn merge_run_by_run(
413 branch_arrays: &[&ArrayRef],
414 else_value: &ArrayRef,
415 spans: &[(usize, usize, usize)],
416 output_dtype: &DType,
417 mut builder: Box<dyn ArrayBuilder>,
418 ctx: &mut ExecutionCtx,
419) -> VortexResult<ArrayRef> {
420 let else_value = else_value.cast(output_dtype.clone())?;
421 let len = else_value.len();
422 for (start, end, branch_idx) in spans {
423 if builder.len() < *start {
424 else_value
425 .slice(builder.len()..*start)?
426 .append_to_builder(builder.as_mut(), ctx)?;
427 }
428 branch_arrays[*branch_idx]
429 .cast(output_dtype.clone())?
430 .slice(*start..*end)?
431 .append_to_builder(builder.as_mut(), ctx)?;
432 }
433 if builder.len() < len {
434 else_value
435 .slice(builder.len()..len)?
436 .append_to_builder(builder.as_mut(), ctx)?;
437 }
438
439 Ok(builder.finish())
440}
441
442#[cfg(test)]
443mod tests {
444 use std::sync::LazyLock;
445
446 use vortex_buffer::buffer;
447 use vortex_error::VortexExpect as _;
448 use vortex_session::VortexSession;
449
450 use super::*;
451 use crate::Canonical;
452 use crate::IntoArray;
453 use crate::VortexSessionExecute;
454 use crate::arrays::BoolArray;
455 use crate::arrays::PrimitiveArray;
456 use crate::arrays::StructArray;
457 use crate::assert_arrays_eq;
458 use crate::dtype::DType;
459 use crate::dtype::Nullability;
460 use crate::dtype::PType;
461 use crate::dtype::StructFields;
462 use crate::expr::case_when;
463 use crate::expr::case_when_no_else;
464 use crate::expr::col;
465 use crate::expr::eq;
466 use crate::expr::get_item;
467 use crate::expr::gt;
468 use crate::expr::is_not_null;
469 use crate::expr::is_null;
470 use crate::expr::lit;
471 use crate::expr::nested_case_when;
472 use crate::expr::root;
473 use crate::expr::test_harness;
474 use crate::scalar::Scalar;
475
// Session shared by the tests in this module, lazily built by `crate::array_session`.
static SESSION: LazyLock<VortexSession> = LazyLock::new(crate::array_session);
477
478 fn evaluate_expr(expr: &Expression, array: &ArrayRef) -> ArrayRef {
480 let mut ctx = SESSION.create_execution_ctx();
481 array
482 .clone()
483 .apply(expr)
484 .unwrap()
485 .execute::<Canonical>(&mut ctx)
486 .unwrap()
487 .into_array()
488 }
489
#[test]
#[should_panic(expected = "cannot serialize")]
fn test_serialization_roundtrip() {
    // `serialize` currently bails, so unwrapping its result is expected to panic.
    let original = CaseWhenOptions {
        has_else: true,
        num_when_then_pairs: 1,
    };
    let bytes = CaseWhen.serialize(&original).unwrap().unwrap();
    let session = VortexSession::empty();
    let roundtripped = CaseWhen.deserialize(&bytes, &session).unwrap();
    assert_eq!(original, roundtripped);
}
505
#[test]
#[should_panic(expected = "cannot serialize")]
fn test_serialization_no_else() {
    // Same as the with-ELSE case: `serialize` bails, so the unwrap panics.
    let original = CaseWhenOptions {
        has_else: false,
        num_when_then_pairs: 1,
    };
    let bytes = CaseWhen.serialize(&original).unwrap().unwrap();
    let session = VortexSession::empty();
    let roundtripped = CaseWhen.deserialize(&bytes, &session).unwrap();
    assert_eq!(original, roundtripped);
}
519
#[test]
fn test_display_with_else() {
    // The rendered form must mention every CASE keyword, including ELSE.
    let rendered = case_when(gt(col("value"), lit(0i32)), lit(100i32), lit(0i32)).to_string();
    for keyword in ["CASE", "WHEN", "THEN", "ELSE", "END"] {
        assert!(rendered.contains(keyword));
    }
}
532
#[test]
fn test_display_no_else() {
    // Without an ELSE branch the ELSE keyword must not be rendered.
    let rendered = case_when_no_else(gt(col("value"), lit(0i32)), lit(100i32)).to_string();
    for keyword in ["CASE", "WHEN", "THEN", "END"] {
        assert!(rendered.contains(keyword));
    }
    assert!(!rendered.contains("ELSE"));
}
543
#[test]
fn test_display_nested_nary() {
    // Two WHEN/THEN pairs render as a single CASE with two WHEN and two THEN keywords.
    let pairs = vec![
        (gt(col("x"), lit(10i32)), lit("high")),
        (gt(col("x"), lit(5i32)), lit("medium")),
    ];
    let rendered = nested_case_when(pairs, Some(lit("low"))).to_string();

    let occurrences = |needle: &str| rendered.matches(needle).count();
    assert_eq!(occurrences("CASE"), 1);
    assert_eq!(occurrences("WHEN"), 2);
    assert_eq!(occurrences("THEN"), 2);
}
559
#[test]
fn test_return_dtype_with_else() {
    // Non-nullable THEN and ELSE produce a non-nullable result.
    let scope = DType::Primitive(PType::I32, Nullability::NonNullable);
    let case_expr = case_when(lit(true), lit(100i32), lit(0i32));

    let actual = case_expr.return_dtype(&scope).unwrap();

    let expected = DType::Primitive(PType::I32, Nullability::NonNullable);
    assert_eq!(actual, expected);
}
572
#[test]
fn test_return_dtype_with_nullable_else() {
    // A nullable ELSE widens the result to nullable.
    let null_else = lit(Scalar::null(DType::Primitive(
        PType::I32,
        Nullability::Nullable,
    )));
    let case_expr = case_when(lit(true), lit(100i32), null_else);
    let scope = DType::Primitive(PType::I32, Nullability::NonNullable);

    let actual = case_expr.return_dtype(&scope).unwrap();

    let expected = DType::Primitive(PType::I32, Nullability::Nullable);
    assert_eq!(actual, expected);
}
590
#[test]
fn test_return_dtype_without_else_is_nullable() {
    // With no ELSE branch the result dtype is nullable.
    let scope = DType::Primitive(PType::I32, Nullability::NonNullable);
    let case_expr = case_when_no_else(lit(true), lit(100i32));

    let actual = case_expr.return_dtype(&scope).unwrap();

    let expected = DType::Primitive(PType::I32, Nullability::Nullable);
    assert_eq!(actual, expected);
}
601
#[test]
fn test_return_dtype_with_struct_input() {
    // The condition reads a struct field; the result dtype still follows THEN/ELSE.
    let scope = test_harness::struct_dtype();
    let condition = gt(get_item("col1", root()), lit(10u16));
    let case_expr = case_when(condition, lit(100i32), lit(0i32));

    let actual = case_expr.return_dtype(&scope).unwrap();

    let expected = DType::Primitive(PType::I32, Nullability::NonNullable);
    assert_eq!(actual, expected);
}
616
#[test]
fn test_return_dtype_mismatched_then_else_errors() {
    // An i32 THEN with a string ELSE must be rejected.
    let scope = DType::Primitive(PType::I32, Nullability::NonNullable);
    let case_expr = case_when(lit(true), lit(100i32), lit("zero"));

    let message = case_expr.return_dtype(&scope).unwrap_err().to_string();

    assert!(message.contains("THEN and ELSE dtypes must match (ignoring nullability)"));
}
627
#[test]
fn test_arity_with_else() {
    // One pair plus ELSE: three children.
    let arity = CaseWhen.arity(&CaseWhenOptions {
        has_else: true,
        num_when_then_pairs: 1,
    });
    assert_eq!(arity, Arity::Exact(3));
}
638
#[test]
fn test_arity_without_else() {
    // One pair and no ELSE: two children.
    let arity = CaseWhen.arity(&CaseWhenOptions {
        has_else: false,
        num_when_then_pairs: 1,
    });
    assert_eq!(arity, Arity::Exact(2));
}
647
#[test]
fn test_child_names() {
    let options = CaseWhenOptions {
        has_else: true,
        num_when_then_pairs: 1,
    };
    // Expected child names, in child-index order.
    let expected = ["when_0", "then_0", "else"];
    for (child_idx, name) in expected.into_iter().enumerate() {
        assert_eq!(CaseWhen.child_name(&options, child_idx).to_string(), name);
    }
}
660
#[test]
#[should_panic(expected = "cannot serialize")]
fn test_serialization_roundtrip_nary() {
    // `serialize` bails for n-ary options too, so the unwrap panics.
    let original = CaseWhenOptions {
        has_else: true,
        num_when_then_pairs: 3,
    };
    let bytes = CaseWhen.serialize(&original).unwrap().unwrap();
    let session = VortexSession::empty();
    let roundtripped = CaseWhen.deserialize(&bytes, &session).unwrap();
    assert_eq!(original, roundtripped);
}
676
#[test]
#[should_panic(expected = "cannot serialize")]
fn test_serialization_roundtrip_nary_no_else() {
    // `serialize` bails regardless of the options, so the unwrap panics.
    let original = CaseWhenOptions {
        has_else: false,
        num_when_then_pairs: 4,
    };
    let bytes = CaseWhen.serialize(&original).unwrap().unwrap();
    let session = VortexSession::empty();
    let roundtripped = CaseWhen.deserialize(&bytes, &session).unwrap();
    assert_eq!(original, roundtripped);
}
690
#[test]
fn test_arity_nary_with_else() {
    // Three pairs plus ELSE: seven children.
    let arity = CaseWhen.arity(&CaseWhenOptions {
        has_else: true,
        num_when_then_pairs: 3,
    });
    assert_eq!(arity, Arity::Exact(7));
}
702
#[test]
fn test_arity_nary_without_else() {
    // Three pairs and no ELSE: six children.
    let arity = CaseWhen.arity(&CaseWhenOptions {
        has_else: false,
        num_when_then_pairs: 3,
    });
    assert_eq!(arity, Arity::Exact(6));
}
712
#[test]
fn test_child_names_nary() {
    let options = CaseWhenOptions {
        has_else: true,
        num_when_then_pairs: 3,
    };
    // Expected child names, in child-index order.
    let expected = [
        "when_0", "then_0", "when_1", "then_1", "when_2", "then_2", "else",
    ];
    for (child_idx, name) in expected.into_iter().enumerate() {
        assert_eq!(CaseWhen.child_name(&options, child_idx).to_string(), name);
    }
}
729
#[test]
fn test_return_dtype_nary_mismatched_then_types_errors() {
    // The second THEN is a string while the first is an i32.
    let pairs = vec![(lit(true), lit(100i32)), (lit(false), lit("oops"))];
    let case_expr = nested_case_when(pairs, Some(lit(0i32)));
    let scope = DType::Primitive(PType::I32, Nullability::NonNullable);

    let message = case_expr.return_dtype(&scope).unwrap_err().to_string();

    assert!(message.contains("THEN dtypes must match"));
}
742
#[test]
fn test_return_dtype_nary_mixed_nullability() {
    // One nullable THEN among non-nullable values makes the result nullable.
    let plain_then = lit(100i32);
    let null_then = lit(Scalar::null(DType::Primitive(
        PType::I32,
        Nullability::Nullable,
    )));
    let pairs = vec![(lit(true), plain_then), (lit(false), null_then)];
    let case_expr = nested_case_when(pairs, Some(lit(0i32)));
    let scope = DType::Primitive(PType::I32, Nullability::NonNullable);

    let actual = case_expr.return_dtype(&scope).unwrap();

    assert_eq!(actual, DType::Primitive(PType::I32, Nullability::Nullable));
}
760
#[test]
fn test_return_dtype_nary_no_else_is_nullable() {
    // No ELSE: the result is nullable even though every THEN is non-nullable.
    let pairs = vec![(lit(true), lit(10i32)), (lit(false), lit(20i32))];
    let case_expr = nested_case_when(pairs, None);
    let scope = DType::Primitive(PType::I32, Nullability::NonNullable);

    let actual = case_expr.return_dtype(&scope).unwrap();

    assert_eq!(actual, DType::Primitive(PType::I32, Nullability::Nullable));
}
771
#[test]
fn test_replace_children() {
    // Swapping in three new children must succeed.
    let replacement = [lit(false), lit(2i32), lit(3i32)];
    case_when(lit(true), lit(1i32), lit(0i32))
        .with_children(replacement)
        .vortex_expect("operation should succeed in test");
}
780
#[test]
fn test_evaluate_simple_condition() {
    // Rows with value > 2 take the THEN literal, the others the ELSE literal.
    let values = buffer![1i32, 2, 3, 4, 5].into_array();
    let input = StructArray::from_fields(&[("value", values)])
        .unwrap()
        .into_array();
    let condition = gt(get_item("value", root()), lit(2i32));
    let case_expr = case_when(condition, lit(100i32), lit(0i32));

    let actual = evaluate_expr(&case_expr, &input);

    let expected = buffer![0i32, 0, 100, 100, 100].into_array();
    let mut ctx = SESSION.create_execution_ctx();
    assert_arrays_eq!(actual, expected, &mut ctx);
}
804
#[test]
fn test_evaluate_nary_multiple_conditions() {
    // value == 1 -> 10, value == 3 -> 30, everything else -> 0.
    let values = buffer![1i32, 2, 3, 4, 5].into_array();
    let input = StructArray::from_fields(&[("value", values)])
        .unwrap()
        .into_array();
    let pairs = vec![
        (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
        (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
    ];
    let case_expr = nested_case_when(pairs, Some(lit(0i32)));

    let actual = evaluate_expr(&case_expr, &input);

    let expected = buffer![10i32, 0, 30, 0, 0].into_array();
    let mut ctx = SESSION.create_execution_ctx();
    assert_arrays_eq!(actual, expected, &mut ctx);
}
825
#[test]
fn test_evaluate_nary_first_match_wins() {
    // Rows 4 and 5 satisfy both conditions; the first branch (100) must win.
    let values = buffer![1i32, 2, 3, 4, 5].into_array();
    let input = StructArray::from_fields(&[("value", values)])
        .unwrap()
        .into_array();
    let pairs = vec![
        (gt(get_item("value", root()), lit(2i32)), lit(100i32)),
        (gt(get_item("value", root()), lit(3i32)), lit(200i32)),
    ];
    let case_expr = nested_case_when(pairs, Some(lit(0i32)));

    let actual = evaluate_expr(&case_expr, &input);

    let expected = buffer![0i32, 0, 100, 100, 100].into_array();
    let mut ctx = SESSION.create_execution_ctx();
    assert_arrays_eq!(actual, expected, &mut ctx);
}
850
#[test]
fn test_evaluate_no_else_returns_null() {
    // Without an ELSE, rows that fail the condition come out as null.
    let values = buffer![1i32, 2, 3, 4, 5].into_array();
    let input = StructArray::from_fields(&[("value", values)])
        .unwrap()
        .into_array();
    let condition = gt(get_item("value", root()), lit(3i32));
    let case_expr = case_when_no_else(condition, lit(100i32));

    let actual = evaluate_expr(&case_expr, &input);

    assert!(actual.dtype().is_nullable());
    let expected =
        PrimitiveArray::from_option_iter([None::<i32>, None, None, Some(100), Some(100)])
            .into_array();
    let mut ctx = SESSION.create_execution_ctx();
    assert_arrays_eq!(actual, expected, &mut ctx);
}
870
#[test]
fn test_evaluate_all_conditions_false() {
    // No value exceeds 100, so every row falls through to the ELSE literal.
    let values = buffer![1i32, 2, 3, 4, 5].into_array();
    let input = StructArray::from_fields(&[("value", values)])
        .unwrap()
        .into_array();
    let condition = gt(get_item("value", root()), lit(100i32));
    let case_expr = case_when(condition, lit(1i32), lit(0i32));

    let actual = evaluate_expr(&case_expr, &input);

    let expected = buffer![0i32, 0, 0, 0, 0].into_array();
    let mut ctx = SESSION.create_execution_ctx();
    assert_arrays_eq!(actual, expected, &mut ctx);
}
888
#[test]
fn test_evaluate_all_conditions_true() {
    // Every value exceeds 0, so every row takes the THEN literal.
    let values = buffer![1i32, 2, 3, 4, 5].into_array();
    let input = StructArray::from_fields(&[("value", values)])
        .unwrap()
        .into_array();
    let condition = gt(get_item("value", root()), lit(0i32));
    let case_expr = case_when(condition, lit(100i32), lit(0i32));

    let actual = evaluate_expr(&case_expr, &input);

    let expected = buffer![100i32, 100, 100, 100, 100].into_array();
    let mut ctx = SESSION.create_execution_ctx();
    assert_arrays_eq!(actual, expected, &mut ctx);
}
910
#[test]
fn test_evaluate_all_true_no_else_returns_correct_dtype() {
    // Even when every row matches, a CASE without ELSE must yield a nullable dtype.
    let values = buffer![1i32, 2, 3].into_array();
    let input = StructArray::from_fields(&[("value", values)])
        .unwrap()
        .into_array();
    let condition = gt(get_item("value", root()), lit(0i32));
    let case_expr = case_when_no_else(condition, lit(100i32));

    let actual = evaluate_expr(&case_expr, &input);

    assert!(
        actual.dtype().is_nullable(),
        "result dtype must be Nullable, got {:?}",
        actual.dtype()
    );
    let expected =
        PrimitiveArray::from_option_iter([Some(100i32), Some(100), Some(100)]).into_array();
    let mut ctx = SESSION.create_execution_ctx();
    assert_arrays_eq!(actual, expected, &mut ctx);
}
934
#[test]
fn test_merge_case_branches_widens_nullability_of_later_branch() -> VortexResult<()> {
    // The second branch's THEN literal is nullable while the first THEN and
    // the ELSE are not; the merged output must still be nullable.
    let values = buffer![0i32, 1, 2].into_array();
    let input = StructArray::from_fields(&[("value", values)])?.into_array();

    let nullable_i32 = DType::Primitive(PType::I32, Nullability::Nullable);
    let nullable_20 = Scalar::from(20i32).cast(&nullable_i32)?;

    let pairs = vec![
        (eq(get_item("value", root()), lit(0i32)), lit(10i32)),
        (eq(get_item("value", root()), lit(1i32)), lit(nullable_20)),
    ];
    let case_expr = nested_case_when(pairs, Some(lit(0i32)));

    let actual = evaluate_expr(&case_expr, &input);

    assert!(
        actual.dtype().is_nullable(),
        "result dtype must be Nullable, got {:?}",
        actual.dtype()
    );
    let expected = PrimitiveArray::from_option_iter([Some(10), Some(20), Some(0)]).into_array();
    let mut ctx = SESSION.create_execution_ctx();
    assert_arrays_eq!(actual, expected, &mut ctx);
    Ok(())
}
972
#[test]
fn test_evaluate_with_literal_condition() {
    // A literal `true` condition selects the THEN literal for every row.
    let input = buffer![1i32, 2, 3].into_array();
    let case_expr = case_when(lit(true), lit(100i32), lit(0i32));

    let actual = evaluate_expr(&case_expr, &input);

    let expected = buffer![100i32, 100, 100].into_array();
    let mut ctx = SESSION.create_execution_ctx();
    assert_arrays_eq!(actual, expected, &mut ctx);
}
982
#[test]
fn test_evaluate_with_bool_column_result() {
    // THEN/ELSE are boolean literals, so the output is a boolean array.
    let values = buffer![1i32, 2, 3, 4, 5].into_array();
    let input = StructArray::from_fields(&[("value", values)])
        .unwrap()
        .into_array();
    let condition = gt(get_item("value", root()), lit(2i32));
    let case_expr = case_when(condition, lit(true), lit(false));

    let actual = evaluate_expr(&case_expr, &input);

    let expected = BoolArray::from_iter([false, false, true, true, true]).into_array();
    let mut ctx = SESSION.create_execution_ctx();
    assert_arrays_eq!(actual, expected, &mut ctx);
}
1004
#[test]
fn test_evaluate_with_nullable_condition() {
    // Null condition rows behave like false and take the ELSE literal.
    let conditions =
        BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)]).into_array();
    let input = StructArray::from_fields(&[("cond", conditions)])
        .unwrap()
        .into_array();
    let case_expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));

    let actual = evaluate_expr(&case_expr, &input);

    let expected = buffer![100i32, 0, 0, 0, 100].into_array();
    let mut ctx = SESSION.create_execution_ctx();
    assert_arrays_eq!(actual, expected, &mut ctx);
}
1020
#[test]
fn test_evaluate_with_nullable_result_values() {
    // THEN reads a nullable column; its null sits on a row that takes ELSE,
    // so the output holds no nulls but is compared against a nullable array.
    let values = buffer![1i32, 2, 3, 4, 5].into_array();
    let results =
        PrimitiveArray::from_option_iter([Some(10), None, Some(30), Some(40), Some(50)])
            .into_array();
    let input = StructArray::from_fields(&[("value", values), ("result", results)])
        .unwrap()
        .into_array();

    let condition = gt(get_item("value", root()), lit(2i32));
    let case_expr = case_when(condition, get_item("result", root()), lit(0i32));

    let actual = evaluate_expr(&case_expr, &input);

    let expected =
        PrimitiveArray::from_option_iter([Some(0i32), Some(0), Some(30), Some(40), Some(50)])
            .into_array();
    let mut ctx = SESSION.create_execution_ctx();
    assert_arrays_eq!(actual, expected, &mut ctx);
}
1049
#[test]
fn test_evaluate_with_all_null_condition() {
    // An all-null condition matches nothing; every row takes the ELSE literal.
    let conditions = BoolArray::from_iter([None, None, None]).into_array();
    let input = StructArray::from_fields(&[("cond", conditions)])
        .unwrap()
        .into_array();
    let case_expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));

    let actual = evaluate_expr(&case_expr, &input);

    let expected = buffer![0i32, 0, 0].into_array();
    let mut ctx = SESSION.create_execution_ctx();
    assert_arrays_eq!(actual, expected, &mut ctx);
}
1065
#[test]
fn test_evaluate_nary_no_else_returns_null() {
    // Rows matched by neither branch are null when there is no ELSE.
    let values = buffer![1i32, 2, 3, 4, 5].into_array();
    let input = StructArray::from_fields(&[("value", values)])
        .unwrap()
        .into_array();
    let pairs = vec![
        (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
        (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
    ];
    let case_expr = nested_case_when(pairs, None);

    let actual = evaluate_expr(&case_expr, &input);

    assert!(actual.dtype().is_nullable());
    let expected =
        PrimitiveArray::from_option_iter([Some(10i32), None, Some(30), None, None]).into_array();
    let mut ctx = SESSION.create_execution_ctx();
    assert_arrays_eq!(actual, expected, &mut ctx);
}
1094
#[test]
fn test_evaluate_nary_many_conditions() {
    // Five single-row branches: each row is claimed by its own branch.
    let values = buffer![1i32, 2, 3, 4, 5].into_array();
    let input = StructArray::from_fields(&[("value", values)])
        .unwrap()
        .into_array();
    let pairs = vec![
        (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
        (eq(get_item("value", root()), lit(2i32)), lit(20i32)),
        (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
        (eq(get_item("value", root()), lit(4i32)), lit(40i32)),
        (eq(get_item("value", root()), lit(5i32)), lit(50i32)),
    ];
    let case_expr = nested_case_when(pairs, Some(lit(0i32)));

    let actual = evaluate_expr(&case_expr, &input);

    let expected = buffer![10i32, 20, 30, 40, 50].into_array();
    let mut ctx = SESSION.create_execution_ctx();
    assert_arrays_eq!(actual, expected, &mut ctx);
}
1122
#[test]
fn test_evaluate_nary_all_false_no_else() {
    // No branch matches and there is no ELSE: the output is all null.
    let values = buffer![1i32, 2, 3].into_array();
    let input = StructArray::from_fields(&[("value", values)])
        .unwrap()
        .into_array();
    let pairs = vec![
        (gt(get_item("value", root()), lit(100i32)), lit(10i32)),
        (gt(get_item("value", root()), lit(200i32)), lit(20i32)),
    ];
    let case_expr = nested_case_when(pairs, None);

    let actual = evaluate_expr(&case_expr, &input);

    assert!(actual.dtype().is_nullable());
    let expected = PrimitiveArray::from_option_iter([None::<i32>, None, None]).into_array();
    let mut ctx = SESSION.create_execution_ctx();
    assert_arrays_eq!(actual, expected, &mut ctx);
}
1147
#[test]
fn test_evaluate_nary_overlapping_conditions_first_wins() {
    // Every row satisfies the first condition, so later overlapping branches
    // never contribute.
    let values = buffer![10i32, 20, 30].into_array();
    let input = StructArray::from_fields(&[("value", values)])
        .unwrap()
        .into_array();
    let pairs = vec![
        (gt(get_item("value", root()), lit(5i32)), lit(1i32)),
        (gt(get_item("value", root()), lit(0i32)), lit(2i32)),
        (gt(get_item("value", root()), lit(15i32)), lit(3i32)),
    ];
    let case_expr = nested_case_when(pairs, Some(lit(0i32)));

    let actual = evaluate_expr(&case_expr, &input);

    let expected = buffer![1i32, 1, 1].into_array();
    let mut ctx = SESSION.create_execution_ctx();
    assert_arrays_eq!(actual, expected, &mut ctx);
}
1172
#[test]
fn test_evaluate_nary_early_exit_when_remaining_empty() {
    // The first branch claims every row, so the second branch's 999 never appears.
    let values = buffer![1i32, 2, 3].into_array();
    let input = StructArray::from_fields(&[("value", values)])
        .unwrap()
        .into_array();
    let pairs = vec![
        (gt(get_item("value", root()), lit(0i32)), lit(100i32)),
        (gt(get_item("value", root()), lit(0i32)), lit(999i32)),
    ];
    let case_expr = nested_case_when(pairs, Some(lit(0i32)));

    let actual = evaluate_expr(&case_expr, &input);

    let expected = buffer![100i32, 100, 100].into_array();
    let mut ctx = SESSION.create_execution_ctx();
    assert_arrays_eq!(actual, expected, &mut ctx);
}
1194
#[test]
fn test_evaluate_nary_skips_branch_with_empty_effective_mask() {
    // The second branch repeats the first condition, so it claims no new rows
    // and its 999 never appears; the third branch still applies.
    let values = buffer![1i32, 2, 3].into_array();
    let input = StructArray::from_fields(&[("value", values)])
        .unwrap()
        .into_array();
    let pairs = vec![
        (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
        (eq(get_item("value", root()), lit(1i32)), lit(999i32)),
        (eq(get_item("value", root()), lit(2i32)), lit(20i32)),
    ];
    let case_expr = nested_case_when(pairs, Some(lit(0i32)));

    let actual = evaluate_expr(&case_expr, &input);

    let expected = buffer![10i32, 20, 0].into_array();
    let mut ctx = SESSION.create_execution_ctx();
    assert_arrays_eq!(actual, expected, &mut ctx);
}
1218
/// Evaluates a two-branch CASE with an ELSE whose THEN/ELSE values are UTF-8 literals and
/// checks each output row individually.
#[test]
fn test_evaluate_nary_string_output() -> VortexResult<()> {
    let column = buffer![1i32, 2, 3, 4].into_array();
    let test_array = StructArray::from_fields(&[("value", column)])?.into_array();

    let value = || get_item("value", root());
    let branches = vec![
        (gt(value(), lit(2i32)), lit("high")),
        (gt(value(), lit(0i32)), lit("low")),
    ];
    let expr = nested_case_when(branches, Some(lit("none")));

    let result = evaluate_expr(&expr, &test_array);

    // (row index, expected label): values 1 and 2 only satisfy `> 0`, values 3 and 4
    // already satisfy `> 2`.
    let expectations = [(0, "low"), (1, "low"), (2, "high"), (3, "high")];
    for (row, label) in expectations {
        // A fresh execution context is created for every scalar lookup.
        assert_eq!(
            result.execute_scalar(row, &mut SESSION.create_execution_ctx())?,
            Scalar::utf8(label, Nullability::NonNullable)
        );
    }
    Ok(())
}
1256
/// Uses nullable boolean columns as WHEN conditions and checks that a null condition does
/// not select its branch: such rows fall through to the next branch or to ELSE.
#[test]
fn test_evaluate_nary_with_nullable_conditions() {
    let mut ctx = SESSION.create_execution_ctx();

    let cond1 = BoolArray::from_iter([Some(true), None, Some(false)]).into_array();
    let cond2 = BoolArray::from_iter([Some(false), Some(true), None]).into_array();
    let test_array = StructArray::from_fields(&[("cond1", cond1), ("cond2", cond2)])
        .unwrap()
        .into_array();

    let branches = vec![
        (get_item("cond1", root()), lit(10i32)),
        (get_item("cond2", root()), lit(20i32)),
    ];
    let expr = nested_case_when(branches, Some(lit(0i32)));

    // Row 0: cond1 is true -> 10. Row 1: cond1 is null, cond2 is true -> 20.
    // Row 2: cond1 is false, cond2 is null -> ELSE (0).
    let expected = buffer![10i32, 20, 0].into_array();
    assert_arrays_eq!(evaluate_expr(&expr, &test_array), expected, &mut ctx);
}
1287
1288 fn nullable_i64_scope(fields: &[&str]) -> DType {
1292 DType::Struct(
1293 StructFields::new(
1294 fields.to_vec().into(),
1295 vec![DType::Primitive(PType::I64, Nullability::Nullable); fields.len()],
1296 ),
1297 Nullability::NonNullable,
1298 )
1299 }
1300
/// `CASE WHEN is_null(x) THEN 0 ELSE x` is expected to be rewritten by the optimizer
/// into an expression whose rendering starts with `vortex.fill_null`.
#[test]
fn test_simplify_coalesce_is_null_rewrites_to_fill_null() -> VortexResult<()> {
    let scope = nullable_i64_scope(&["x"]);
    let expr = case_when(is_null(col("x")), lit(0i64), col("x"));

    let optimized = expr.optimize_recursive(&scope)?;
    let rendered = optimized.to_string();
    assert!(
        rendered.starts_with("vortex.fill_null"),
        "expected fill_null, got {optimized}"
    );
    Ok(())
}
1312
/// Mirror image of the `is_null` case: `CASE WHEN is_not_null(x) THEN x ELSE 0` is
/// expected to be rewritten into an expression rendered as `vortex.fill_null...`.
#[test]
fn test_simplify_coalesce_is_not_null_rewrites_to_fill_null() -> VortexResult<()> {
    let scope = nullable_i64_scope(&["x"]);
    let expr = case_when(is_not_null(col("x")), col("x"), lit(0i64));

    let optimized = expr.optimize_recursive(&scope)?;
    let rendered = optimized.to_string();
    assert!(
        rendered.starts_with("vortex.fill_null"),
        "expected fill_null, got {optimized}"
    );
    Ok(())
}
1324
/// The null check is on `x` but the ELSE branch returns `y`, so the expression is not a
/// coalesce of a single column and must stay a CASE WHEN after optimization.
#[test]
fn test_simplify_does_not_fire_when_operands_differ() -> VortexResult<()> {
    let scope = nullable_i64_scope(&["x", "y"]);
    let expr = case_when(is_null(col("x")), lit(0i64), col("y"));

    let s = expr.optimize_recursive(&scope)?.to_string();
    assert!(s.contains("CASE"), "expected CASE WHEN to remain, got {s}");
    assert!(!s.contains("fill_null"), "must not rewrite, got {s}");
    Ok(())
}
1335
/// The replacement value is another column (`c`) rather than a literal, so the optimizer
/// must leave the CASE WHEN in place instead of producing a `fill_null`.
#[test]
fn test_simplify_does_not_fire_for_non_constant_fill() -> VortexResult<()> {
    let scope = nullable_i64_scope(&["x", "c"]);
    let expr = case_when(is_null(col("x")), col("c"), col("x"));

    let s = expr.optimize_recursive(&scope)?.to_string();
    assert!(s.contains("CASE"), "expected CASE WHEN to remain, got {s}");
    assert!(!s.contains("fill_null"), "must not rewrite, got {s}");
    Ok(())
}
1347
/// Filling nulls with a null literal is a no-op, so both the `is_null` and the
/// `is_not_null` spellings are expected to optimize down to the bare column `$.x`.
#[test]
fn test_simplify_null_fill_collapses_to_input() -> VortexResult<()> {
    let scope = nullable_i64_scope(&["x"]);
    let null_dtype = DType::Primitive(PType::I64, Nullability::Nullable);
    // Produces a fresh nullable-i64 null literal on each call.
    let null_fill = || lit(Scalar::null(null_dtype.clone()));

    let is_null_form = case_when(is_null(col("x")), null_fill(), col("x"));
    let is_not_null_form = case_when(is_not_null(col("x")), col("x"), null_fill());

    for expr in [is_null_form, is_not_null_form] {
        let optimized = expr.optimize_recursive(&scope)?;
        assert_eq!(
            optimized.to_string(),
            "$.x",
            "expected collapse to input column, got {optimized}"
        );
    }
    Ok(())
}
1373
/// Confirms that collapsing a null-fill CASE WHEN to its input does not change results:
/// the original and the optimized expression must both reproduce the input array.
#[test]
fn test_simplify_null_fill_semantic_equivalence() -> VortexResult<()> {
    let mut ctx = SESSION.create_execution_ctx();
    let values = [Some(1i64), None, Some(3)];
    let array = PrimitiveArray::from_option_iter(values).into_array();

    // The expression is rooted directly at the nullable i64 array.
    let scope = DType::Primitive(PType::I64, Nullability::Nullable);
    let null_dtype = DType::Primitive(PType::I64, Nullability::Nullable);
    let null_fill = lit(Scalar::null(null_dtype));

    let original = case_when(is_null(root()), null_fill, root());
    let optimized = original.optimize_recursive(&scope)?;
    assert_eq!(
        optimized.to_string(),
        "$",
        "expected collapse to root, got {optimized}"
    );

    // Nulls stay null in both forms, so the expected output equals the input.
    let expected = PrimitiveArray::from_option_iter(values).into_array();
    assert_arrays_eq!(evaluate_expr(&original, &array), expected, &mut ctx);
    assert_arrays_eq!(evaluate_expr(&optimized, &array), expected, &mut ctx);
    Ok(())
}
1398
/// A CASE WHEN without an ELSE branch must not be turned into a `fill_null`.
#[test]
fn test_simplify_does_not_fire_without_else() -> VortexResult<()> {
    let scope = nullable_i64_scope(&["x"]);
    let expr = case_when_no_else(is_null(col("x")), lit(0i64));

    let optimized = expr.optimize_recursive(&scope)?;
    let rewritten = optimized.to_string().contains("fill_null");
    assert!(
        !rewritten,
        "must not rewrite a no-ELSE case_when, got {optimized}"
    );
    Ok(())
}
1409
/// A CASE with more than one WHEN/THEN pair must not be turned into a `fill_null`, even
/// when its first pair on its own looks like a coalesce.
#[test]
fn test_simplify_does_not_fire_for_multi_pair() -> VortexResult<()> {
    let scope = nullable_i64_scope(&["x"]);
    let branches = vec![
        (is_null(col("x")), lit(0i64)),
        (gt(col("x"), lit(5i64)), lit(1i64)),
    ];
    let expr = nested_case_when(branches, Some(col("x")));

    let optimized = expr.optimize_recursive(&scope)?;
    let rewritten = optimized.to_string().contains("fill_null");
    assert!(
        !rewritten,
        "must not rewrite a multi-pair case_when, got {optimized}"
    );
    Ok(())
}
1426
/// Confirms that the `fill_null` rewrite preserves values: the original CASE WHEN and the
/// optimized expression must both replace the null row with 0.
#[test]
fn test_simplify_semantic_equivalence() -> VortexResult<()> {
    let mut ctx = SESSION.create_execution_ctx();
    let array = PrimitiveArray::from_option_iter([Some(1i64), None, Some(3)]).into_array();
    // The expression is rooted directly at the nullable i64 array.
    let scope = DType::Primitive(PType::I64, Nullability::Nullable);

    let original = case_when(is_null(root()), lit(0i64), root());
    let optimized = original.optimize_recursive(&scope)?;
    let rendered = optimized.to_string();
    assert!(
        rendered.starts_with("vortex.fill_null"),
        "expected fill_null, got {optimized}"
    );

    // NOTE(review): the two expected arrays are built differently on purpose
    // (option-iter vs. plain buffer) — presumably because the original keeps a nullable
    // dtype while fill_null with a non-null constant does not; confirm against the
    // dtype rules of case_when and fill_null.
    let expected_original =
        PrimitiveArray::from_option_iter([Some(1i64), Some(0), Some(3)]).into_array();
    assert_arrays_eq!(evaluate_expr(&original, &array), expected_original, &mut ctx);

    let expected_optimized = buffer![1i64, 0, 3].into_array();
    assert_arrays_eq!(
        evaluate_expr(&optimized, &array),
        expected_optimized,
        &mut ctx
    );
    Ok(())
}
1455
/// Calls `merge_case_branches` directly with two branches whose masks alternate row by
/// row (even rows -> branch 0, odd rows -> branch 1) and checks the interleaved result.
#[test]
fn test_merge_case_branches_alternating_mask() -> VortexResult<()> {
    let mut ctx = SESSION.create_execution_ctx();
    let n = 100usize;

    // An `n`-row array holding the same non-null value in every position.
    let constant = |v: i32| PrimitiveArray::from_option_iter(vec![Some(v); n]).into_array();

    let even_rows = Mask::from_indices(n, (0..n).step_by(2));
    let odd_rows = Mask::from_indices(n, (1..n).step_by(2));

    let branches = vec![(even_rows, constant(0)), (odd_rows, constant(1))];
    // The two masks cover every row, so the ELSE value (99) must never show up.
    let else_values = constant(99);
    let result = merge_case_branches(
        branches,
        else_values,
        &mut SESSION.create_execution_ctx(),
    )?;

    // Even rows take branch 0's value, odd rows take branch 1's value.
    let expected: Vec<Option<i32>> = (0..n)
        .map(|row| Some(i32::from(row % 2 == 1)))
        .collect();
    assert_arrays_eq!(
        result,
        PrimitiveArray::from_option_iter(expected).into_array(),
        &mut ctx
    );
    Ok(())
}
1493}