1use std::fmt;
14use std::fmt::Formatter;
15use std::hash::Hash;
16use std::sync::Arc;
17
18use prost::Message;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21use vortex_mask::AllOr;
22use vortex_mask::Mask;
23use vortex_proto::expr as pb;
24use vortex_session::VortexSession;
25use vortex_session::registry::CachedId;
26
27use crate::ArrayRef;
28use crate::ExecutionCtx;
29use crate::IntoArray;
30use crate::arrays::BoolArray;
31use crate::arrays::ConstantArray;
32use crate::arrays::bool::BoolArrayExt;
33use crate::builders::ArrayBuilder;
34use crate::builders::builder_with_capacity;
35use crate::builtins::ArrayBuiltins;
36use crate::dtype::DType;
37use crate::expr::Expression;
38use crate::scalar::Scalar;
39use crate::scalar_fn::Arity;
40use crate::scalar_fn::ChildName;
41use crate::scalar_fn::ExecutionArgs;
42use crate::scalar_fn::ScalarFnId;
43use crate::scalar_fn::ScalarFnVTable;
44use crate::scalar_fn::SimplifyCtx;
45use crate::scalar_fn::fns::is_not_null::IsNotNull;
46use crate::scalar_fn::fns::is_null::IsNull;
47use crate::scalar_fn::fns::literal::Literal;
48use crate::scalar_fn::fns::zip::zip_impl;
49
/// Options describing the shape of a `CaseWhen` call: how many WHEN/THEN pairs it has and
/// whether a trailing ELSE child is present.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CaseWhenOptions {
    /// Number of WHEN/THEN pairs; each pair contributes two children.
    pub num_when_then_pairs: u32,
    /// Whether an ELSE child follows the WHEN/THEN pairs.
    pub has_else: bool,
}

impl CaseWhenOptions {
    /// Total number of children: two per WHEN/THEN pair, plus one when an ELSE is present.
    pub fn num_children(&self) -> usize {
        let pair_children = 2 * (self.num_when_then_pairs as usize);
        if self.has_else {
            pair_children + 1
        } else {
            pair_children
        }
    }
}

impl fmt::Display for CaseWhenOptions {
    /// Renders the options as `case_when(pairs=<n>, else=<bool>)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self {
            num_when_then_pairs: pairs,
            has_else,
        } = *self;
        f.write_fmt(format_args!(
            "case_when(pairs={}, else={})",
            pairs, has_else
        ))
    }
}
76
/// Unit struct on which the `vortex.case_when` scalar function is implemented through
/// [`ScalarFnVTable`].
#[derive(Clone)]
pub struct CaseWhen;
82
83impl ScalarFnVTable for CaseWhen {
84 type Options = CaseWhenOptions;
85
86 fn id(&self) -> ScalarFnId {
87 static ID: CachedId = CachedId::new("vortex.case_when");
88 *ID
89 }
90
91 fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
92 vortex_bail!("cannot serialize")
96 }
97
98 fn deserialize(
99 &self,
100 metadata: &[u8],
101 _session: &VortexSession,
102 ) -> VortexResult<Self::Options> {
103 let opts = pb::CaseWhenOpts::decode(metadata)?;
104 if opts.num_children < 2 {
105 vortex_bail!(
106 "CaseWhen expects at least 2 children, got {}",
107 opts.num_children
108 );
109 }
110 Ok(CaseWhenOptions {
111 num_when_then_pairs: opts.num_children / 2,
112 has_else: opts.num_children % 2 == 1,
113 })
114 }
115
116 fn arity(&self, options: &Self::Options) -> Arity {
117 Arity::Exact(options.num_children())
118 }
119
120 fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName {
121 let num_pair_children = options.num_when_then_pairs as usize * 2;
122 if child_idx < num_pair_children {
123 let pair_idx = child_idx / 2;
124 if child_idx.is_multiple_of(2) {
125 ChildName::from(Arc::from(format!("when_{pair_idx}")))
126 } else {
127 ChildName::from(Arc::from(format!("then_{pair_idx}")))
128 }
129 } else if options.has_else && child_idx == num_pair_children {
130 ChildName::from("else")
131 } else {
132 unreachable!("Invalid child index {} for CaseWhen", child_idx)
133 }
134 }
135
136 fn fmt_sql(
137 &self,
138 options: &Self::Options,
139 expr: &Expression,
140 f: &mut Formatter<'_>,
141 ) -> fmt::Result {
142 write!(f, "CASE")?;
143 for i in 0..options.num_when_then_pairs as usize {
144 write!(
145 f,
146 " WHEN {} THEN {}",
147 expr.child(i * 2),
148 expr.child(i * 2 + 1)
149 )?;
150 }
151 if options.has_else {
152 let else_idx = options.num_when_then_pairs as usize * 2;
153 write!(f, " ELSE {}", expr.child(else_idx))?;
154 }
155 write!(f, " END")
156 }
157
158 fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
159 if options.num_when_then_pairs == 0 {
160 vortex_bail!("CaseWhen must have at least one WHEN/THEN pair");
161 }
162
163 let expected_len = options.num_children();
164 if arg_dtypes.len() != expected_len {
165 vortex_bail!(
166 "CaseWhen expects {expected_len} argument dtypes, got {}",
167 arg_dtypes.len()
168 );
169 }
170
171 let first_then = &arg_dtypes[1];
175 let mut result_dtype = first_then.clone();
176
177 for i in 1..options.num_when_then_pairs as usize {
178 let then_i = &arg_dtypes[i * 2 + 1];
179 if !first_then.eq_ignore_nullability(then_i) {
180 vortex_bail!(
181 "CaseWhen THEN dtypes must match (ignoring nullability), got {} and {}",
182 first_then,
183 then_i
184 );
185 }
186 result_dtype = result_dtype.union_nullability(then_i.nullability());
187 }
188
189 if options.has_else {
190 let else_dtype = &arg_dtypes[options.num_when_then_pairs as usize * 2];
191 if !result_dtype.eq_ignore_nullability(else_dtype) {
192 vortex_bail!(
193 "CaseWhen THEN and ELSE dtypes must match (ignoring nullability), got {} and {}",
194 first_then,
195 else_dtype
196 );
197 }
198 result_dtype = result_dtype.union_nullability(else_dtype.nullability());
199 } else {
200 result_dtype = result_dtype.as_nullable();
202 }
203
204 Ok(result_dtype)
205 }
206
207 fn execute(
208 &self,
209 options: &Self::Options,
210 args: &dyn ExecutionArgs,
211 ctx: &mut ExecutionCtx,
212 ) -> VortexResult<ArrayRef> {
213 let row_count = args.row_count();
220 let num_pairs = options.num_when_then_pairs as usize;
221
222 let mut remaining = Mask::new_true(row_count);
223 let mut branches: Vec<(Mask, ArrayRef)> = Vec::with_capacity(num_pairs);
224
225 for i in 0..num_pairs {
226 if remaining.all_false() {
227 break;
228 }
229
230 let condition = args.get(i * 2)?;
231 let cond_bool = condition.execute::<BoolArray>(ctx)?;
232 let cond_mask = cond_bool.to_mask_fill_null_false(ctx);
233 let effective_mask = &remaining & &cond_mask;
234
235 if effective_mask.all_false() {
236 continue;
237 }
238
239 let then_value = args.get(i * 2 + 1)?;
240 remaining = remaining.bitand_not(&cond_mask);
241 branches.push((effective_mask, then_value));
242 }
243
244 let else_value: ArrayRef = if options.has_else {
245 args.get(num_pairs * 2)?
246 } else {
247 let then_dtype = args.get(1)?.dtype().as_nullable();
248 ConstantArray::new(Scalar::null(then_dtype), row_count).into_array()
249 };
250
251 if branches.is_empty() {
252 return Ok(else_value);
253 }
254
255 merge_case_branches(branches, else_value, ctx)
256 }
257
258 fn simplify(
259 &self,
260 options: &Self::Options,
261 expr: &Expression,
262 _ctx: &dyn SimplifyCtx,
263 ) -> VortexResult<Option<Expression>> {
264 if options.num_when_then_pairs != 1 || !options.has_else {
274 return Ok(None);
275 }
276
277 let when = expr.child(0);
278 let then = expr.child(1);
279 let els = expr.child(2);
280
281 let (x, fill) = if when.is::<IsNull>() && when.child(0) == els {
283 (els, then)
284 } else if when.is::<IsNotNull>() && when.child(0) == then {
286 (then, els)
287 } else {
288 return Ok(None);
289 };
290
291 let Some(scalar) = fill.as_opt::<Literal>() else {
292 return Ok(None);
293 };
294
295 if scalar.is_null() {
296 return Ok(Some(x.clone()));
298 }
299
300 Ok(Some(crate::expr::fill_null(x.clone(), fill.clone())))
301 }
302
303 fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
304 true
305 }
306
307 fn is_fallible(&self, _options: &Self::Options) -> bool {
308 false
309 }
310}
311
/// Divisor used by `merge_case_branches` to pick a merge strategy: when the number of spans
/// exceeds `len / SLICE_CROSSOVER_RUN_LEN`, rows are merged one at a time instead of run by run.
const SLICE_CROSSOVER_RUN_LEN: usize = 4;
315
316fn merge_case_branches(
320 branches: Vec<(Mask, ArrayRef)>,
321 else_value: ArrayRef,
322 ctx: &mut ExecutionCtx,
323) -> VortexResult<ArrayRef> {
324 if branches.len() == 1 {
325 let (mask, then_value) = &branches[0];
326 return zip_impl(then_value, &else_value, mask, ctx);
327 }
328
329 let output_nullability = branches
330 .iter()
331 .fold(else_value.dtype().nullability(), |acc, (_, arr)| {
332 acc | arr.dtype().nullability()
333 });
334 let output_dtype = else_value.dtype().with_nullability(output_nullability);
335 let branch_arrays: Vec<&ArrayRef> = branches.iter().map(|(_, arr)| arr).collect();
336
337 let mut spans: Vec<(usize, usize, usize)> = Vec::new();
338 for (branch_idx, (mask, _)) in branches.iter().enumerate() {
339 match mask.slices() {
340 AllOr::All => return branch_arrays[branch_idx].cast(output_dtype),
341 AllOr::None => {}
342 AllOr::Some(slices) => {
343 for &(start, end) in slices {
344 spans.push((start, end, branch_idx));
345 }
346 }
347 }
348 }
349 spans.sort_unstable_by_key(|&(start, ..)| start);
350
351 if spans.is_empty() {
352 return else_value.cast(output_dtype);
353 }
354
355 let builder = builder_with_capacity(&output_dtype, else_value.len());
356
357 let fragmented = spans.len() > else_value.len() / SLICE_CROSSOVER_RUN_LEN;
358 if fragmented {
359 merge_row_by_row(
360 &branch_arrays,
361 &else_value,
362 &spans,
363 &output_dtype,
364 builder,
365 ctx,
366 )
367 } else {
368 merge_run_by_run(
369 &branch_arrays,
370 &else_value,
371 &spans,
372 &output_dtype,
373 builder,
374 ctx,
375 )
376 }
377}
378
379fn merge_row_by_row(
382 branch_arrays: &[&ArrayRef],
383 else_value: &ArrayRef,
384 spans: &[(usize, usize, usize)],
385 output_dtype: &DType,
386 mut builder: Box<dyn ArrayBuilder>,
387 ctx: &mut ExecutionCtx,
388) -> VortexResult<ArrayRef> {
389 let mut pos = 0;
390 for &(start, end, branch_idx) in spans {
391 for row in pos..start {
392 let scalar = else_value.execute_scalar(row, ctx)?;
393 builder.append_scalar(&scalar.cast(output_dtype)?)?;
394 }
395 for row in start..end {
396 let scalar = branch_arrays[branch_idx].execute_scalar(row, ctx)?;
397 builder.append_scalar(&scalar.cast(output_dtype)?)?;
398 }
399 pos = end;
400 }
401 for row in pos..else_value.len() {
402 let scalar = else_value.execute_scalar(row, ctx)?;
403 builder.append_scalar(&scalar.cast(output_dtype)?)?;
404 }
405
406 Ok(builder.finish())
407}
408
409fn merge_run_by_run(
413 branch_arrays: &[&ArrayRef],
414 else_value: &ArrayRef,
415 spans: &[(usize, usize, usize)],
416 output_dtype: &DType,
417 mut builder: Box<dyn ArrayBuilder>,
418 ctx: &mut ExecutionCtx,
419) -> VortexResult<ArrayRef> {
420 let else_value = else_value.cast(output_dtype.clone())?;
421 let len = else_value.len();
422 for (start, end, branch_idx) in spans {
423 if builder.len() < *start {
424 else_value
425 .slice(builder.len()..*start)?
426 .append_to_builder(builder.as_mut(), ctx)?;
427 }
428 branch_arrays[*branch_idx]
429 .cast(output_dtype.clone())?
430 .slice(*start..*end)?
431 .append_to_builder(builder.as_mut(), ctx)?;
432 }
433 if builder.len() < len {
434 else_value
435 .slice(builder.len()..len)?
436 .append_to_builder(builder.as_mut(), ctx)?;
437 }
438
439 Ok(builder.finish())
440}
441
442#[cfg(test)]
443mod tests {
444 use std::sync::LazyLock;
445
446 use vortex_buffer::buffer;
447 use vortex_error::VortexExpect as _;
448 use vortex_session::VortexSession;
449
450 use super::*;
451 use crate::Canonical;
452 use crate::IntoArray;
453 use crate::LEGACY_SESSION;
454 use crate::VortexSessionExecute;
455 use crate::arrays::BoolArray;
456 use crate::arrays::PrimitiveArray;
457 use crate::arrays::StructArray;
458 use crate::assert_arrays_eq;
459 use crate::dtype::DType;
460 use crate::dtype::Nullability;
461 use crate::dtype::PType;
462 use crate::dtype::StructFields;
463 use crate::expr::case_when;
464 use crate::expr::case_when_no_else;
465 use crate::expr::col;
466 use crate::expr::eq;
467 use crate::expr::get_item;
468 use crate::expr::gt;
469 use crate::expr::is_not_null;
470 use crate::expr::is_null;
471 use crate::expr::lit;
472 use crate::expr::nested_case_when;
473 use crate::expr::root;
474 use crate::expr::test_harness;
475 use crate::scalar::Scalar;
476 use crate::session::ArraySession;
477
    /// Lazily-initialized session shared by the tests in this module: an empty `VortexSession`
    /// with `ArraySession` added.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());
480
481 fn evaluate_expr(expr: &Expression, array: &ArrayRef) -> ArrayRef {
483 let mut ctx = SESSION.create_execution_ctx();
484 array
485 .clone()
486 .apply(expr)
487 .unwrap()
488 .execute::<Canonical>(&mut ctx)
489 .unwrap()
490 .into_array()
491 }
492
493 #[test]
496 #[should_panic(expected = "cannot serialize")]
497 fn test_serialization_roundtrip() {
498 let options = CaseWhenOptions {
499 num_when_then_pairs: 1,
500 has_else: true,
501 };
502 let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
503 let deserialized = CaseWhen
504 .deserialize(&serialized, &VortexSession::empty())
505 .unwrap();
506 assert_eq!(options, deserialized);
507 }
508
509 #[test]
510 #[should_panic(expected = "cannot serialize")]
511 fn test_serialization_no_else() {
512 let options = CaseWhenOptions {
513 num_when_then_pairs: 1,
514 has_else: false,
515 };
516 let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
517 let deserialized = CaseWhen
518 .deserialize(&serialized, &VortexSession::empty())
519 .unwrap();
520 assert_eq!(options, deserialized);
521 }
522
523 #[test]
526 fn test_display_with_else() {
527 let expr = case_when(gt(col("value"), lit(0i32)), lit(100i32), lit(0i32));
528 let display = format!("{}", expr);
529 assert!(display.contains("CASE"));
530 assert!(display.contains("WHEN"));
531 assert!(display.contains("THEN"));
532 assert!(display.contains("ELSE"));
533 assert!(display.contains("END"));
534 }
535
536 #[test]
537 fn test_display_no_else() {
538 let expr = case_when_no_else(gt(col("value"), lit(0i32)), lit(100i32));
539 let display = format!("{}", expr);
540 assert!(display.contains("CASE"));
541 assert!(display.contains("WHEN"));
542 assert!(display.contains("THEN"));
543 assert!(!display.contains("ELSE"));
544 assert!(display.contains("END"));
545 }
546
547 #[test]
548 fn test_display_nested_nary() {
549 let expr = nested_case_when(
551 vec![
552 (gt(col("x"), lit(10i32)), lit("high")),
553 (gt(col("x"), lit(5i32)), lit("medium")),
554 ],
555 Some(lit("low")),
556 );
557 let display = format!("{}", expr);
558 assert_eq!(display.matches("CASE").count(), 1);
559 assert_eq!(display.matches("WHEN").count(), 2);
560 assert_eq!(display.matches("THEN").count(), 2);
561 }
562
563 #[test]
566 fn test_return_dtype_with_else() {
567 let expr = case_when(lit(true), lit(100i32), lit(0i32));
568 let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
569 let result_dtype = expr.return_dtype(&input_dtype).unwrap();
570 assert_eq!(
571 result_dtype,
572 DType::Primitive(PType::I32, Nullability::NonNullable)
573 );
574 }
575
576 #[test]
577 fn test_return_dtype_with_nullable_else() {
578 let expr = case_when(
579 lit(true),
580 lit(100i32),
581 lit(Scalar::null(DType::Primitive(
582 PType::I32,
583 Nullability::Nullable,
584 ))),
585 );
586 let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
587 let result_dtype = expr.return_dtype(&input_dtype).unwrap();
588 assert_eq!(
589 result_dtype,
590 DType::Primitive(PType::I32, Nullability::Nullable)
591 );
592 }
593
594 #[test]
595 fn test_return_dtype_without_else_is_nullable() {
596 let expr = case_when_no_else(lit(true), lit(100i32));
597 let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
598 let result_dtype = expr.return_dtype(&input_dtype).unwrap();
599 assert_eq!(
600 result_dtype,
601 DType::Primitive(PType::I32, Nullability::Nullable)
602 );
603 }
604
605 #[test]
606 fn test_return_dtype_with_struct_input() {
607 let dtype = test_harness::struct_dtype();
608 let expr = case_when(
609 gt(get_item("col1", root()), lit(10u16)),
610 lit(100i32),
611 lit(0i32),
612 );
613 let result_dtype = expr.return_dtype(&dtype).unwrap();
614 assert_eq!(
615 result_dtype,
616 DType::Primitive(PType::I32, Nullability::NonNullable)
617 );
618 }
619
620 #[test]
621 fn test_return_dtype_mismatched_then_else_errors() {
622 let expr = case_when(lit(true), lit(100i32), lit("zero"));
623 let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
624 let err = expr.return_dtype(&input_dtype).unwrap_err();
625 assert!(
626 err.to_string()
627 .contains("THEN and ELSE dtypes must match (ignoring nullability)")
628 );
629 }
630
631 #[test]
634 fn test_arity_with_else() {
635 let options = CaseWhenOptions {
636 num_when_then_pairs: 1,
637 has_else: true,
638 };
639 assert_eq!(CaseWhen.arity(&options), Arity::Exact(3));
640 }
641
642 #[test]
643 fn test_arity_without_else() {
644 let options = CaseWhenOptions {
645 num_when_then_pairs: 1,
646 has_else: false,
647 };
648 assert_eq!(CaseWhen.arity(&options), Arity::Exact(2));
649 }
650
651 #[test]
654 fn test_child_names() {
655 let options = CaseWhenOptions {
656 num_when_then_pairs: 1,
657 has_else: true,
658 };
659 assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0");
660 assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0");
661 assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "else");
662 }
663
664 #[test]
667 #[should_panic(expected = "cannot serialize")]
668 fn test_serialization_roundtrip_nary() {
669 let options = CaseWhenOptions {
670 num_when_then_pairs: 3,
671 has_else: true,
672 };
673 let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
674 let deserialized = CaseWhen
675 .deserialize(&serialized, &VortexSession::empty())
676 .unwrap();
677 assert_eq!(options, deserialized);
678 }
679
680 #[test]
681 #[should_panic(expected = "cannot serialize")]
682 fn test_serialization_roundtrip_nary_no_else() {
683 let options = CaseWhenOptions {
684 num_when_then_pairs: 4,
685 has_else: false,
686 };
687 let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
688 let deserialized = CaseWhen
689 .deserialize(&serialized, &VortexSession::empty())
690 .unwrap();
691 assert_eq!(options, deserialized);
692 }
693
694 #[test]
697 fn test_arity_nary_with_else() {
698 let options = CaseWhenOptions {
699 num_when_then_pairs: 3,
700 has_else: true,
701 };
702 assert_eq!(CaseWhen.arity(&options), Arity::Exact(7));
704 }
705
706 #[test]
707 fn test_arity_nary_without_else() {
708 let options = CaseWhenOptions {
709 num_when_then_pairs: 3,
710 has_else: false,
711 };
712 assert_eq!(CaseWhen.arity(&options), Arity::Exact(6));
714 }
715
716 #[test]
719 fn test_child_names_nary() {
720 let options = CaseWhenOptions {
721 num_when_then_pairs: 3,
722 has_else: true,
723 };
724 assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0");
725 assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0");
726 assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "when_1");
727 assert_eq!(CaseWhen.child_name(&options, 3).to_string(), "then_1");
728 assert_eq!(CaseWhen.child_name(&options, 4).to_string(), "when_2");
729 assert_eq!(CaseWhen.child_name(&options, 5).to_string(), "then_2");
730 assert_eq!(CaseWhen.child_name(&options, 6).to_string(), "else");
731 }
732
733 #[test]
736 fn test_return_dtype_nary_mismatched_then_types_errors() {
737 let expr = nested_case_when(
738 vec![(lit(true), lit(100i32)), (lit(false), lit("oops"))],
739 Some(lit(0i32)),
740 );
741 let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
742 let err = expr.return_dtype(&input_dtype).unwrap_err();
743 assert!(err.to_string().contains("THEN dtypes must match"));
744 }
745
746 #[test]
747 fn test_return_dtype_nary_mixed_nullability() {
748 let non_null_then = lit(100i32);
751 let nullable_then = lit(Scalar::null(DType::Primitive(
752 PType::I32,
753 Nullability::Nullable,
754 )));
755 let expr = nested_case_when(
756 vec![(lit(true), non_null_then), (lit(false), nullable_then)],
757 Some(lit(0i32)),
758 );
759 let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
760 let result = expr.return_dtype(&input_dtype).unwrap();
761 assert_eq!(result, DType::Primitive(PType::I32, Nullability::Nullable));
762 }
763
764 #[test]
765 fn test_return_dtype_nary_no_else_is_nullable() {
766 let expr = nested_case_when(
767 vec![(lit(true), lit(10i32)), (lit(false), lit(20i32))],
768 None,
769 );
770 let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
771 let result = expr.return_dtype(&input_dtype).unwrap();
772 assert_eq!(result, DType::Primitive(PType::I32, Nullability::Nullable));
773 }
774
775 #[test]
778 fn test_replace_children() {
779 let expr = case_when(lit(true), lit(1i32), lit(0i32));
780 expr.with_children([lit(false), lit(2i32), lit(3i32)])
781 .vortex_expect("operation should succeed in test");
782 }
783
784 #[test]
787 fn test_evaluate_simple_condition() {
788 let test_array =
789 StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
790 .unwrap()
791 .into_array();
792
793 let expr = case_when(
794 gt(get_item("value", root()), lit(2i32)),
795 lit(100i32),
796 lit(0i32),
797 );
798
799 let result = evaluate_expr(&expr, &test_array);
800 assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array());
801 }
802
803 #[test]
804 fn test_evaluate_nary_multiple_conditions() {
805 let test_array =
807 StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
808 .unwrap()
809 .into_array();
810
811 let expr = nested_case_when(
812 vec![
813 (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
814 (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
815 ],
816 Some(lit(0i32)),
817 );
818
819 let result = evaluate_expr(&expr, &test_array);
820 assert_arrays_eq!(result, buffer![10i32, 0, 30, 0, 0].into_array());
821 }
822
823 #[test]
824 fn test_evaluate_nary_first_match_wins() {
825 let test_array =
826 StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
827 .unwrap()
828 .into_array();
829
830 let expr = nested_case_when(
832 vec![
833 (gt(get_item("value", root()), lit(2i32)), lit(100i32)),
834 (gt(get_item("value", root()), lit(3i32)), lit(200i32)),
835 ],
836 Some(lit(0i32)),
837 );
838
839 let result = evaluate_expr(&expr, &test_array);
840 assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array());
841 }
842
843 #[test]
844 fn test_evaluate_no_else_returns_null() {
845 let test_array =
846 StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
847 .unwrap()
848 .into_array();
849
850 let expr = case_when_no_else(gt(get_item("value", root()), lit(3i32)), lit(100i32));
851
852 let result = evaluate_expr(&expr, &test_array);
853 assert!(result.dtype().is_nullable());
854 assert_arrays_eq!(
855 result,
856 PrimitiveArray::from_option_iter([None::<i32>, None, None, Some(100), Some(100)])
857 .into_array()
858 );
859 }
860
861 #[test]
862 fn test_evaluate_all_conditions_false() {
863 let test_array =
864 StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
865 .unwrap()
866 .into_array();
867
868 let expr = case_when(
869 gt(get_item("value", root()), lit(100i32)),
870 lit(1i32),
871 lit(0i32),
872 );
873
874 let result = evaluate_expr(&expr, &test_array);
875 assert_arrays_eq!(result, buffer![0i32, 0, 0, 0, 0].into_array());
876 }
877
878 #[test]
879 fn test_evaluate_all_conditions_true() {
880 let test_array =
881 StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
882 .unwrap()
883 .into_array();
884
885 let expr = case_when(
886 gt(get_item("value", root()), lit(0i32)),
887 lit(100i32),
888 lit(0i32),
889 );
890
891 let result = evaluate_expr(&expr, &test_array);
892 assert_arrays_eq!(result, buffer![100i32, 100, 100, 100, 100].into_array());
893 }
894
895 #[test]
896 fn test_evaluate_all_true_no_else_returns_correct_dtype() {
897 let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
900 .unwrap()
901 .into_array();
902
903 let expr = case_when_no_else(gt(get_item("value", root()), lit(0i32)), lit(100i32));
904
905 let result = evaluate_expr(&expr, &test_array);
906 assert!(
907 result.dtype().is_nullable(),
908 "result dtype must be Nullable, got {:?}",
909 result.dtype()
910 );
911 assert_arrays_eq!(
912 result,
913 PrimitiveArray::from_option_iter([Some(100i32), Some(100), Some(100)]).into_array()
914 );
915 }
916
917 #[test]
918 fn test_merge_case_branches_widens_nullability_of_later_branch() -> VortexResult<()> {
919 let test_array =
927 StructArray::from_fields(&[("value", buffer![0i32, 1, 2].into_array())])?.into_array();
928
929 let nullable_20 =
930 Scalar::from(20i32).cast(&DType::Primitive(PType::I32, Nullability::Nullable))?;
931
932 let expr = nested_case_when(
933 vec![
934 (eq(get_item("value", root()), lit(0i32)), lit(10i32)),
935 (eq(get_item("value", root()), lit(1i32)), lit(nullable_20)),
936 ],
937 Some(lit(0i32)),
938 );
939
940 let result = evaluate_expr(&expr, &test_array);
941 assert!(
942 result.dtype().is_nullable(),
943 "result dtype must be Nullable, got {:?}",
944 result.dtype()
945 );
946 assert_arrays_eq!(
947 result,
948 PrimitiveArray::from_option_iter([Some(10), Some(20), Some(0)]).into_array()
949 );
950 Ok(())
951 }
952
953 #[test]
954 fn test_evaluate_with_literal_condition() {
955 let test_array = buffer![1i32, 2, 3].into_array();
956 let expr = case_when(lit(true), lit(100i32), lit(0i32));
957 let result = evaluate_expr(&expr, &test_array);
958
959 assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array());
960 }
961
962 #[test]
963 fn test_evaluate_with_bool_column_result() {
964 let test_array =
965 StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
966 .unwrap()
967 .into_array();
968
969 let expr = case_when(
970 gt(get_item("value", root()), lit(2i32)),
971 lit(true),
972 lit(false),
973 );
974
975 let result = evaluate_expr(&expr, &test_array);
976 assert_arrays_eq!(
977 result,
978 BoolArray::from_iter([false, false, true, true, true]).into_array()
979 );
980 }
981
982 #[test]
983 fn test_evaluate_with_nullable_condition() {
984 let test_array = StructArray::from_fields(&[(
985 "cond",
986 BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)]).into_array(),
987 )])
988 .unwrap()
989 .into_array();
990
991 let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));
992
993 let result = evaluate_expr(&expr, &test_array);
994 assert_arrays_eq!(result, buffer![100i32, 0, 0, 0, 100].into_array());
995 }
996
997 #[test]
998 fn test_evaluate_with_nullable_result_values() {
999 let test_array = StructArray::from_fields(&[
1000 ("value", buffer![1i32, 2, 3, 4, 5].into_array()),
1001 (
1002 "result",
1003 PrimitiveArray::from_option_iter([Some(10), None, Some(30), Some(40), Some(50)])
1004 .into_array(),
1005 ),
1006 ])
1007 .unwrap()
1008 .into_array();
1009
1010 let expr = case_when(
1011 gt(get_item("value", root()), lit(2i32)),
1012 get_item("result", root()),
1013 lit(0i32),
1014 );
1015
1016 let result = evaluate_expr(&expr, &test_array);
1017 assert_arrays_eq!(
1018 result,
1019 PrimitiveArray::from_option_iter([Some(0i32), Some(0), Some(30), Some(40), Some(50)])
1020 .into_array()
1021 );
1022 }
1023
1024 #[test]
1025 fn test_evaluate_with_all_null_condition() {
1026 let test_array = StructArray::from_fields(&[(
1027 "cond",
1028 BoolArray::from_iter([None, None, None]).into_array(),
1029 )])
1030 .unwrap()
1031 .into_array();
1032
1033 let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));
1034
1035 let result = evaluate_expr(&expr, &test_array);
1036 assert_arrays_eq!(result, buffer![0i32, 0, 0].into_array());
1037 }
1038
1039 #[test]
1042 fn test_evaluate_nary_no_else_returns_null() {
1043 let test_array =
1044 StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
1045 .unwrap()
1046 .into_array();
1047
1048 let expr = nested_case_when(
1050 vec![
1051 (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
1052 (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
1053 ],
1054 None,
1055 );
1056
1057 let result = evaluate_expr(&expr, &test_array);
1058 assert!(result.dtype().is_nullable());
1059 assert_arrays_eq!(
1060 result,
1061 PrimitiveArray::from_option_iter([Some(10i32), None, Some(30), None, None])
1062 .into_array()
1063 );
1064 }
1065
1066 #[test]
1067 fn test_evaluate_nary_many_conditions() {
1068 let test_array =
1069 StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
1070 .unwrap()
1071 .into_array();
1072
1073 let expr = nested_case_when(
1075 vec![
1076 (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
1077 (eq(get_item("value", root()), lit(2i32)), lit(20i32)),
1078 (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
1079 (eq(get_item("value", root()), lit(4i32)), lit(40i32)),
1080 (eq(get_item("value", root()), lit(5i32)), lit(50i32)),
1081 ],
1082 Some(lit(0i32)),
1083 );
1084
1085 let result = evaluate_expr(&expr, &test_array);
1086 assert_arrays_eq!(result, buffer![10i32, 20, 30, 40, 50].into_array());
1087 }
1088
1089 #[test]
1090 fn test_evaluate_nary_all_false_no_else() {
1091 let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
1092 .unwrap()
1093 .into_array();
1094
1095 let expr = nested_case_when(
1097 vec![
1098 (gt(get_item("value", root()), lit(100i32)), lit(10i32)),
1099 (gt(get_item("value", root()), lit(200i32)), lit(20i32)),
1100 ],
1101 None,
1102 );
1103
1104 let result = evaluate_expr(&expr, &test_array);
1105 assert!(result.dtype().is_nullable());
1106 assert_arrays_eq!(
1107 result,
1108 PrimitiveArray::from_option_iter([None::<i32>, None, None]).into_array()
1109 );
1110 }
1111
1112 #[test]
1113 fn test_evaluate_nary_overlapping_conditions_first_wins() {
1114 let test_array =
1115 StructArray::from_fields(&[("value", buffer![10i32, 20, 30].into_array())])
1116 .unwrap()
1117 .into_array();
1118
1119 let expr = nested_case_when(
1123 vec![
1124 (gt(get_item("value", root()), lit(5i32)), lit(1i32)),
1125 (gt(get_item("value", root()), lit(0i32)), lit(2i32)),
1126 (gt(get_item("value", root()), lit(15i32)), lit(3i32)),
1127 ],
1128 Some(lit(0i32)),
1129 );
1130
1131 let result = evaluate_expr(&expr, &test_array);
1132 assert_arrays_eq!(result, buffer![1i32, 1, 1].into_array());
1134 }
1135
1136 #[test]
1137 fn test_evaluate_nary_early_exit_when_remaining_empty() {
1138 let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
1141 .unwrap()
1142 .into_array();
1143
1144 let expr = nested_case_when(
1145 vec![
1146 (gt(get_item("value", root()), lit(0i32)), lit(100i32)),
1147 (gt(get_item("value", root()), lit(0i32)), lit(999i32)),
1149 ],
1150 Some(lit(0i32)),
1151 );
1152
1153 let result = evaluate_expr(&expr, &test_array);
1154 assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array());
1155 }
1156
1157 #[test]
1158 fn test_evaluate_nary_skips_branch_with_empty_effective_mask() {
1159 let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
1162 .unwrap()
1163 .into_array();
1164
1165 let expr = nested_case_when(
1166 vec![
1167 (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
1168 (eq(get_item("value", root()), lit(1i32)), lit(999i32)),
1171 (eq(get_item("value", root()), lit(2i32)), lit(20i32)),
1172 ],
1173 Some(lit(0i32)),
1174 );
1175
1176 let result = evaluate_expr(&expr, &test_array);
1177 assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array());
1178 }
1179
1180 #[test]
1181 fn test_evaluate_nary_string_output() -> VortexResult<()> {
1182 let test_array =
1184 StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4].into_array())])?
1185 .into_array();
1186
1187 let expr = nested_case_when(
1191 vec![
1192 (gt(get_item("value", root()), lit(2i32)), lit("high")),
1193 (gt(get_item("value", root()), lit(0i32)), lit("low")),
1194 ],
1195 Some(lit("none")),
1196 );
1197
1198 let result = evaluate_expr(&expr, &test_array);
1199 assert_eq!(
1200 result.execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())?,
1201 Scalar::utf8("low", Nullability::NonNullable)
1202 );
1203 assert_eq!(
1204 result.execute_scalar(1, &mut LEGACY_SESSION.create_execution_ctx())?,
1205 Scalar::utf8("low", Nullability::NonNullable)
1206 );
1207 assert_eq!(
1208 result.execute_scalar(2, &mut LEGACY_SESSION.create_execution_ctx())?,
1209 Scalar::utf8("high", Nullability::NonNullable)
1210 );
1211 assert_eq!(
1212 result.execute_scalar(3, &mut LEGACY_SESSION.create_execution_ctx())?,
1213 Scalar::utf8("high", Nullability::NonNullable)
1214 );
1215 Ok(())
1216 }
1217
/// A null entry in a WHEN condition must not select that branch: the row falls through
/// to later branches or to the ELSE value.
///
/// Returns `VortexResult<()>` so that a failure to build the input struct is propagated
/// with `?`, matching the other fallible tests in this module.
#[test]
fn test_evaluate_nary_with_nullable_conditions() -> VortexResult<()> {
    // Row 0: cond1 = true,  cond2 = false
    // Row 1: cond1 = null,  cond2 = true
    // Row 2: cond1 = false, cond2 = null
    let test_array = StructArray::from_fields(&[
        (
            "cond1",
            BoolArray::from_iter([Some(true), None, Some(false)]).into_array(),
        ),
        (
            "cond2",
            BoolArray::from_iter([Some(false), Some(true), None]).into_array(),
        ),
    ])?
    .into_array();

    let expr = nested_case_when(
        vec![
            (get_item("cond1", root()), lit(10i32)),
            (get_item("cond2", root()), lit(20i32)),
        ],
        Some(lit(0i32)),
    );

    let result = evaluate_expr(&expr, &test_array);
    // Row 0 takes the first branch, row 1 the second (its cond1 is null), and row 2
    // the ELSE value (cond1 is false, cond2 is null).
    assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array());
    Ok(())
}
1247
1248 fn nullable_i64_scope(fields: &[&str]) -> DType {
1252 DType::Struct(
1253 StructFields::new(
1254 fields.to_vec().into(),
1255 vec![DType::Primitive(PType::I64, Nullability::Nullable); fields.len()],
1256 ),
1257 Nullability::NonNullable,
1258 )
1259 }
1260
/// `case_when(is_null(x), 0, x)` is expected to be optimized into an expression whose
/// rendering starts with `vortex.fill_null`.
#[test]
fn test_simplify_coalesce_is_null_rewrites_to_fill_null() -> VortexResult<()> {
    let scope = nullable_i64_scope(&["x"]);
    let optimized =
        case_when(is_null(col("x")), lit(0i64), col("x")).optimize_recursive(&scope)?;

    let rewritten = optimized.to_string().starts_with("vortex.fill_null");
    assert!(rewritten, "expected fill_null, got {optimized}");
    Ok(())
}
1272
/// The mirrored form `case_when(is_not_null(x), x, 0)` is expected to be optimized into
/// an expression whose rendering starts with `vortex.fill_null`.
#[test]
fn test_simplify_coalesce_is_not_null_rewrites_to_fill_null() -> VortexResult<()> {
    let scope = nullable_i64_scope(&["x"]);
    let optimized =
        case_when(is_not_null(col("x")), col("x"), lit(0i64)).optimize_recursive(&scope)?;

    let rewritten = optimized.to_string().starts_with("vortex.fill_null");
    assert!(rewritten, "expected fill_null, got {optimized}");
    Ok(())
}
1284
/// When the null check is on `x` but the ELSE value is a different column (`y`), the
/// optimized expression must still render as a CASE and must not mention `fill_null`.
#[test]
fn test_simplify_does_not_fire_when_operands_differ() -> VortexResult<()> {
    let scope = nullable_i64_scope(&["x", "y"]);
    let s = case_when(is_null(col("x")), lit(0i64), col("y"))
        .optimize_recursive(&scope)?
        .to_string();

    assert!(s.contains("CASE"), "expected CASE WHEN to remain, got {s}");
    assert!(!s.contains("fill_null"), "must not rewrite, got {s}");
    Ok(())
}
1295
/// When the THEN value is a column (`c`) rather than a literal, the optimized expression
/// must still render as a CASE and must not mention `fill_null`.
#[test]
fn test_simplify_does_not_fire_for_non_constant_fill() -> VortexResult<()> {
    let scope = nullable_i64_scope(&["x", "c"]);
    let s = case_when(is_null(col("x")), col("c"), col("x"))
        .optimize_recursive(&scope)?
        .to_string();

    assert!(s.contains("CASE"), "expected CASE WHEN to remain, got {s}");
    assert!(!s.contains("fill_null"), "must not rewrite, got {s}");
    Ok(())
}
1307
/// With a null literal as the fill value, both the `is_null` and the `is_not_null`
/// spellings are expected to optimize down to the bare column reference `$.x`.
#[test]
fn test_simplify_null_fill_collapses_to_input() -> VortexResult<()> {
    let nullable_i64 = DType::Primitive(PType::I64, Nullability::Nullable);
    let scope = nullable_i64_scope(&["x"]);

    let via_is_null = case_when(
        is_null(col("x")),
        lit(Scalar::null(nullable_i64.clone())),
        col("x"),
    );
    let via_is_not_null = case_when(
        is_not_null(col("x")),
        col("x"),
        lit(Scalar::null(nullable_i64)),
    );

    for expr in [via_is_null, via_is_not_null] {
        let optimized = expr.optimize_recursive(&scope)?;
        assert_eq!(
            optimized.to_string(),
            "$.x",
            "expected collapse to input column, got {optimized}"
        );
    }
    Ok(())
}
1333
/// Checks that the null-fill CASE WHEN over the root optimizes to `$`, and that the
/// original and optimized expressions both evaluate to the unchanged input values.
#[test]
fn test_simplify_null_fill_semantic_equivalence() -> VortexResult<()> {
    let nullable_i64 = DType::Primitive(PType::I64, Nullability::Nullable);
    let sample = || PrimitiveArray::from_option_iter([Some(1i64), None, Some(3)]).into_array();
    let array = sample();

    let original = case_when(
        is_null(root()),
        lit(Scalar::null(nullable_i64.clone())),
        root(),
    );
    let optimized = original.optimize_recursive(&nullable_i64)?;
    assert_eq!(
        optimized.to_string(),
        "$",
        "expected collapse to root, got {optimized}"
    );

    // Both forms must hand back the same values as the input, null included.
    let expected = sample();
    assert_arrays_eq!(evaluate_expr(&original, &array), expected);
    assert_arrays_eq!(evaluate_expr(&optimized, &array), expected);
    Ok(())
}
1357
/// A CASE WHEN built without an ELSE branch must not be rewritten to `fill_null`.
#[test]
fn test_simplify_does_not_fire_without_else() -> VortexResult<()> {
    let scope = nullable_i64_scope(&["x"]);
    let optimized = case_when_no_else(is_null(col("x")), lit(0i64)).optimize_recursive(&scope)?;

    let rewritten = optimized.to_string().contains("fill_null");
    assert!(
        !rewritten,
        "must not rewrite a no-ELSE case_when, got {optimized}"
    );
    Ok(())
}
1368
/// A CASE WHEN with more than one WHEN/THEN pair must not be rewritten to `fill_null`,
/// even though its first pair and ELSE look like a coalesce.
#[test]
fn test_simplify_does_not_fire_for_multi_pair() -> VortexResult<()> {
    let null_pair = (is_null(col("x")), lit(0i64));
    let threshold_pair = (gt(col("x"), lit(5i64)), lit(1i64));
    let expr = nested_case_when(vec![null_pair, threshold_pair], Some(col("x")));

    let scope = nullable_i64_scope(&["x"]);
    let optimized = expr.optimize_recursive(&scope)?;

    let rewritten = optimized.to_string().contains("fill_null");
    assert!(
        !rewritten,
        "must not rewrite a multi-pair case_when, got {optimized}"
    );
    Ok(())
}
1385
/// Checks that `case_when(is_null($), 0, $)` optimizes to `fill_null` and that both the
/// original and the optimized expression replace the null row with 0.
#[test]
fn test_simplify_semantic_equivalence() -> VortexResult<()> {
    let array = PrimitiveArray::from_option_iter([Some(1i64), None, Some(3)]).into_array();
    let scope = DType::Primitive(PType::I64, Nullability::Nullable);

    let original = case_when(is_null(root()), lit(0i64), root());
    let optimized = original.optimize_recursive(&scope)?;
    let rewritten = optimized.to_string().starts_with("vortex.fill_null");
    assert!(rewritten, "expected fill_null, got {optimized}");

    // NOTE(review): the two expectations are built differently (option iterator vs.
    // plain buffer); presumably the outputs differ in nullability — confirm.
    let from_original = evaluate_expr(&original, &array);
    assert_arrays_eq!(
        from_original,
        PrimitiveArray::from_option_iter([Some(1i64), Some(0), Some(3)]).into_array()
    );
    let from_optimized = evaluate_expr(&optimized, &array);
    assert_arrays_eq!(from_optimized, buffer![1i64, 0, 3].into_array());
    Ok(())
}
1411
/// Merges two branches whose masks select the even and the odd rows respectively and
/// checks that the output alternates between the two branch values, never using the
/// else value.
#[test]
fn test_merge_case_branches_alternating_mask() -> VortexResult<()> {
    let n = 100usize;

    // Array of length `n` in which every row holds `value`.
    let filled = |value: i32| PrimitiveArray::from_option_iter(vec![Some(value); n]).into_array();

    let even_rows = Mask::from_indices(n, (0..n).step_by(2));
    let odd_rows = Mask::from_indices(n, (1..n).step_by(2));

    let branches = vec![(even_rows, filled(0)), (odd_rows, filled(1))];
    let result = merge_case_branches(branches, filled(99), &mut SESSION.create_execution_ctx())?;

    // Even rows come from the first branch (0), odd rows from the second (1).
    let expected: Vec<Option<i32>> = [Some(0), Some(1)].into_iter().cycle().take(n).collect();
    assert_arrays_eq!(
        result,
        PrimitiveArray::from_option_iter(expected).into_array()
    );
    Ok(())
}
1447}