1use std::fmt;
14use std::fmt::Formatter;
15use std::hash::Hash;
16use std::sync::Arc;
17
18use prost::Message;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21use vortex_proto::expr as pb;
22use vortex_session::VortexSession;
23
24use crate::ArrayRef;
25use crate::ExecutionCtx;
26use crate::IntoArray;
27use crate::arrays::BoolArray;
28use crate::arrays::ConstantArray;
29use crate::dtype::DType;
30use crate::expr::Expression;
31use crate::scalar::Scalar;
32use crate::scalar_fn::Arity;
33use crate::scalar_fn::ChildName;
34use crate::scalar_fn::ExecutionArgs;
35use crate::scalar_fn::ScalarFnId;
36use crate::scalar_fn::ScalarFnVTable;
37use crate::scalar_fn::fns::zip::zip_impl;
38
/// Shape of a [`CaseWhen`] invocation.
///
/// The children are ordered `[when_0, then_0, when_1, then_1, ...]`, optionally followed by a
/// single trailing ELSE child.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CaseWhenOptions {
    /// Number of WHEN/THEN pairs; each pair contributes two children.
    pub num_when_then_pairs: u32,
    /// Whether one ELSE child follows the WHEN/THEN pairs.
    pub has_else: bool,
}

impl CaseWhenOptions {
    /// Total number of children: two per WHEN/THEN pair, plus one if there is an ELSE.
    pub fn num_children(&self) -> usize {
        let pair_children = 2 * self.num_when_then_pairs as usize;
        let else_children = usize::from(self.has_else);
        pair_children + else_children
    }
}

impl fmt::Display for CaseWhenOptions {
    /// Renders as `case_when(pairs=<n>, else=<bool>)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self {
            num_when_then_pairs,
            has_else,
        } = self;
        write!(f, "case_when(pairs={num_when_then_pairs}, else={has_else})")
    }
}
65
66#[derive(Clone)]
70pub struct CaseWhen;
71
72impl ScalarFnVTable for CaseWhen {
73 type Options = CaseWhenOptions;
74
75 fn id(&self) -> ScalarFnId {
76 ScalarFnId::from("vortex.case_when")
77 }
78
79 fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
80 vortex_bail!("cannot serialize")
84 }
85
86 fn deserialize(
87 &self,
88 metadata: &[u8],
89 _session: &VortexSession,
90 ) -> VortexResult<Self::Options> {
91 let opts = pb::CaseWhenOpts::decode(metadata)?;
92 if opts.num_children < 2 {
93 vortex_bail!(
94 "CaseWhen expects at least 2 children, got {}",
95 opts.num_children
96 );
97 }
98 Ok(CaseWhenOptions {
99 num_when_then_pairs: opts.num_children / 2,
100 has_else: opts.num_children % 2 == 1,
101 })
102 }
103
104 fn arity(&self, options: &Self::Options) -> Arity {
105 Arity::Exact(options.num_children())
106 }
107
108 fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName {
109 let num_pair_children = options.num_when_then_pairs as usize * 2;
110 if child_idx < num_pair_children {
111 let pair_idx = child_idx / 2;
112 if child_idx.is_multiple_of(2) {
113 ChildName::from(Arc::from(format!("when_{pair_idx}")))
114 } else {
115 ChildName::from(Arc::from(format!("then_{pair_idx}")))
116 }
117 } else if options.has_else && child_idx == num_pair_children {
118 ChildName::from("else")
119 } else {
120 unreachable!("Invalid child index {} for CaseWhen", child_idx)
121 }
122 }
123
124 fn fmt_sql(
125 &self,
126 options: &Self::Options,
127 expr: &Expression,
128 f: &mut Formatter<'_>,
129 ) -> fmt::Result {
130 write!(f, "CASE")?;
131 for i in 0..options.num_when_then_pairs as usize {
132 write!(
133 f,
134 " WHEN {} THEN {}",
135 expr.child(i * 2),
136 expr.child(i * 2 + 1)
137 )?;
138 }
139 if options.has_else {
140 let else_idx = options.num_when_then_pairs as usize * 2;
141 write!(f, " ELSE {}", expr.child(else_idx))?;
142 }
143 write!(f, " END")
144 }
145
146 fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
147 if options.num_when_then_pairs == 0 {
148 vortex_bail!("CaseWhen must have at least one WHEN/THEN pair");
149 }
150
151 let expected_len = options.num_children();
152 if arg_dtypes.len() != expected_len {
153 vortex_bail!(
154 "CaseWhen expects {expected_len} argument dtypes, got {}",
155 arg_dtypes.len()
156 );
157 }
158
159 let first_then = &arg_dtypes[1];
163 let mut result_dtype = first_then.clone();
164
165 for i in 1..options.num_when_then_pairs as usize {
166 let then_i = &arg_dtypes[i * 2 + 1];
167 if !first_then.eq_ignore_nullability(then_i) {
168 vortex_bail!(
169 "CaseWhen THEN dtypes must match (ignoring nullability), got {} and {}",
170 first_then,
171 then_i
172 );
173 }
174 result_dtype = result_dtype.union_nullability(then_i.nullability());
175 }
176
177 if options.has_else {
178 let else_dtype = &arg_dtypes[options.num_when_then_pairs as usize * 2];
179 if !result_dtype.eq_ignore_nullability(else_dtype) {
180 vortex_bail!(
181 "CaseWhen THEN and ELSE dtypes must match (ignoring nullability), got {} and {}",
182 first_then,
183 else_dtype
184 );
185 }
186 result_dtype = result_dtype.union_nullability(else_dtype.nullability());
187 } else {
188 result_dtype = result_dtype.as_nullable();
190 }
191
192 Ok(result_dtype)
193 }
194
195 fn execute(
196 &self,
197 options: &Self::Options,
198 args: &dyn ExecutionArgs,
199 ctx: &mut ExecutionCtx,
200 ) -> VortexResult<ArrayRef> {
201 let row_count = args.row_count();
202 let num_pairs = options.num_when_then_pairs as usize;
203
204 let mut result: ArrayRef = if options.has_else {
205 args.get(num_pairs * 2)?
206 } else {
207 let then_dtype = args.get(1)?.dtype().as_nullable();
208 ConstantArray::new(Scalar::null(then_dtype), row_count).into_array()
209 };
210
211 for i in (0..num_pairs).rev() {
215 let condition = args.get(i * 2)?;
216 let then_value = args.get(i * 2 + 1)?;
217
218 let cond_bool = condition.execute::<BoolArray>(ctx)?;
219 let mask = cond_bool.to_mask_fill_null_false();
220
221 if mask.all_true() {
222 result = then_value;
223 continue;
224 }
225
226 if mask.all_false() {
227 continue;
228 }
229
230 result = zip_impl(&then_value, &result, &mask)?;
231 }
232
233 Ok(result)
234 }
235
236 fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
237 true
239 }
240
241 fn is_fallible(&self, _options: &Self::Options) -> bool {
242 false
243 }
244}
245
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use vortex_buffer::buffer;
    use vortex_error::VortexExpect as _;

    // `super::*` already provides ArrayRef, BoolArray, DType, Expression, IntoArray, Scalar,
    // VortexSession, `pb` and `prost::Message`.
    use super::*;
    use crate::Canonical;
    use crate::ToCanonical;
    use crate::VortexSessionExecute as _;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::case_when;
    use crate::expr::case_when_no_else;
    use crate::expr::col;
    use crate::expr::eq;
    use crate::expr::get_item;
    use crate::expr::gt;
    use crate::expr::lit;
    use crate::expr::nested_case_when;
    use crate::expr::root;
    use crate::expr::test_harness;
    use crate::session::ArraySession;

    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// Applies `expr` to `array` and executes the result into a canonical array.
    fn evaluate_expr(expr: &Expression, array: &ArrayRef) -> ArrayRef {
        let mut ctx = SESSION.create_execution_ctx();
        array
            .apply(expr)
            .unwrap()
            .execute::<Canonical>(&mut ctx)
            .unwrap()
            .into_array()
    }

    /// Struct array with a single non-nullable `value` column holding `[1, 2, 3, 4, 5]`.
    fn values_one_to_five() -> ArrayRef {
        StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
            .unwrap()
            .into_array()
    }

    /// Shorthand for an `i32` primitive dtype with the given nullability.
    fn i32_dtype(nullability: Nullability) -> DType {
        DType::Primitive(PType::I32, nullability)
    }

    /// Asserts that `result` is nullable and that row `i` equals `expected[i]` (`None` = null).
    fn assert_nullable_i32_rows(result: &ArrayRef, expected: &[Option<i32>]) {
        assert!(result.dtype().is_nullable());
        for (row, value) in expected.iter().enumerate() {
            let want = match value {
                Some(v) => Scalar::from(*v).cast(result.dtype()).unwrap(),
                None => Scalar::null(result.dtype().clone()),
            };
            assert_eq!(
                result.scalar_at(row).unwrap(),
                want,
                "mismatch at row {row}"
            );
        }
    }

    /// Serializes `options`, decodes them again, and asserts the result is unchanged.
    fn assert_serialization_roundtrip(options: CaseWhenOptions) {
        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
        let deserialized = CaseWhen
            .deserialize(&serialized, &VortexSession::empty())
            .unwrap();
        assert_eq!(options, deserialized);
    }

    // `CaseWhen::serialize` currently always bails with "cannot serialize", and these tests pin
    // that behaviour. They still go through the full roundtrip, so once serialization is
    // implemented, dropping `should_panic` turns them into real roundtrip checks.

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_roundtrip() {
        assert_serialization_roundtrip(CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: true,
        });
    }

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_no_else() {
        assert_serialization_roundtrip(CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: false,
        });
    }

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_roundtrip_nary() {
        assert_serialization_roundtrip(CaseWhenOptions {
            num_when_then_pairs: 3,
            has_else: true,
        });
    }

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_roundtrip_nary_no_else() {
        assert_serialization_roundtrip(CaseWhenOptions {
            num_when_then_pairs: 4,
            has_else: false,
        });
    }

    #[test]
    fn test_deserialize_child_counts() {
        let decode = |num_children: u32| {
            let bytes = pb::CaseWhenOpts { num_children }.encode_to_vec();
            CaseWhen.deserialize(&bytes, &VortexSession::empty())
        };

        // Even child count: WHEN/THEN pairs only.
        assert_eq!(
            decode(2).unwrap(),
            CaseWhenOptions {
                num_when_then_pairs: 1,
                has_else: false,
            }
        );
        // Odd child count: the last child is ELSE.
        assert_eq!(
            decode(7).unwrap(),
            CaseWhenOptions {
                num_when_then_pairs: 3,
                has_else: true,
            }
        );
        // Fewer than one full pair is rejected.
        let err = decode(1).unwrap_err();
        assert!(err.to_string().contains("at least 2 children"));
    }

    #[test]
    fn test_display_with_else() {
        let expr = case_when(gt(col("value"), lit(0i32)), lit(100i32), lit(0i32));
        let display = format!("{}", expr);
        assert!(display.contains("CASE"));
        assert!(display.contains("WHEN"));
        assert!(display.contains("THEN"));
        assert!(display.contains("ELSE"));
        assert!(display.contains("END"));
    }

    #[test]
    fn test_display_no_else() {
        let expr = case_when_no_else(gt(col("value"), lit(0i32)), lit(100i32));
        let display = format!("{}", expr);
        assert!(display.contains("CASE"));
        assert!(display.contains("WHEN"));
        assert!(display.contains("THEN"));
        assert!(!display.contains("ELSE"));
        assert!(display.contains("END"));
    }

    #[test]
    fn test_display_nested_nary() {
        let expr = nested_case_when(
            vec![
                (gt(col("x"), lit(10i32)), lit("high")),
                (gt(col("x"), lit(5i32)), lit("medium")),
            ],
            Some(lit("low")),
        );
        let display = format!("{}", expr);
        assert_eq!(display.matches("CASE").count(), 1);
        assert_eq!(display.matches("WHEN").count(), 2);
        assert_eq!(display.matches("THEN").count(), 2);
    }

    #[test]
    fn test_return_dtype_with_else() {
        let expr = case_when(lit(true), lit(100i32), lit(0i32));
        let input_dtype = i32_dtype(Nullability::NonNullable);
        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(result_dtype, i32_dtype(Nullability::NonNullable));
    }

    #[test]
    fn test_return_dtype_with_nullable_else() {
        let expr = case_when(
            lit(true),
            lit(100i32),
            lit(Scalar::null(i32_dtype(Nullability::Nullable))),
        );
        let input_dtype = i32_dtype(Nullability::NonNullable);
        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(result_dtype, i32_dtype(Nullability::Nullable));
    }

    #[test]
    fn test_return_dtype_without_else_is_nullable() {
        let expr = case_when_no_else(lit(true), lit(100i32));
        let input_dtype = i32_dtype(Nullability::NonNullable);
        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(result_dtype, i32_dtype(Nullability::Nullable));
    }

    #[test]
    fn test_return_dtype_with_struct_input() {
        let dtype = test_harness::struct_dtype();
        let expr = case_when(
            gt(get_item("col1", root()), lit(10u16)),
            lit(100i32),
            lit(0i32),
        );
        let result_dtype = expr.return_dtype(&dtype).unwrap();
        assert_eq!(result_dtype, i32_dtype(Nullability::NonNullable));
    }

    #[test]
    fn test_return_dtype_mismatched_then_else_errors() {
        let expr = case_when(lit(true), lit(100i32), lit("zero"));
        let input_dtype = i32_dtype(Nullability::NonNullable);
        let err = expr.return_dtype(&input_dtype).unwrap_err();
        assert!(
            err.to_string()
                .contains("THEN and ELSE dtypes must match (ignoring nullability)")
        );
    }

    #[test]
    fn test_return_dtype_rejects_invalid_shapes() {
        let i32_nn = i32_dtype(Nullability::NonNullable);

        let no_pairs = CaseWhenOptions {
            num_when_then_pairs: 0,
            has_else: true,
        };
        let err = CaseWhen
            .return_dtype(&no_pairs, &[i32_nn.clone()])
            .unwrap_err();
        assert!(err.to_string().contains("at least one WHEN/THEN pair"));

        let one_pair = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: false,
        };
        let err = CaseWhen
            .return_dtype(&one_pair, &[i32_nn.clone(), i32_nn.clone(), i32_nn])
            .unwrap_err();
        assert!(err.to_string().contains("expects 2 argument dtypes, got 3"));
    }

    #[test]
    fn test_return_dtype_nary_mismatched_then_types_errors() {
        let expr = nested_case_when(
            vec![(lit(true), lit(100i32)), (lit(false), lit("oops"))],
            Some(lit(0i32)),
        );
        let input_dtype = i32_dtype(Nullability::NonNullable);
        let err = expr.return_dtype(&input_dtype).unwrap_err();
        assert!(err.to_string().contains("THEN dtypes must match"));
    }

    #[test]
    fn test_return_dtype_nary_mixed_nullability() {
        let non_null_then = lit(100i32);
        let nullable_then = lit(Scalar::null(i32_dtype(Nullability::Nullable)));
        let expr = nested_case_when(
            vec![(lit(true), non_null_then), (lit(false), nullable_then)],
            Some(lit(0i32)),
        );
        let input_dtype = i32_dtype(Nullability::NonNullable);
        let result = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(result, i32_dtype(Nullability::Nullable));
    }

    #[test]
    fn test_return_dtype_nary_no_else_is_nullable() {
        let expr = nested_case_when(
            vec![(lit(true), lit(10i32)), (lit(false), lit(20i32))],
            None,
        );
        let input_dtype = i32_dtype(Nullability::NonNullable);
        let result = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(result, i32_dtype(Nullability::Nullable));
    }

    #[test]
    fn test_arity_with_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: true,
        };
        assert_eq!(CaseWhen.arity(&options), Arity::Exact(3));
    }

    #[test]
    fn test_arity_without_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: false,
        };
        assert_eq!(CaseWhen.arity(&options), Arity::Exact(2));
    }

    #[test]
    fn test_arity_nary_with_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 3,
            has_else: true,
        };
        assert_eq!(CaseWhen.arity(&options), Arity::Exact(7));
    }

    #[test]
    fn test_arity_nary_without_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 3,
            has_else: false,
        };
        assert_eq!(CaseWhen.arity(&options), Arity::Exact(6));
    }

    #[test]
    fn test_child_names() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: true,
        };
        assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0");
        assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0");
        assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "else");
    }

    #[test]
    fn test_child_names_nary() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 3,
            has_else: true,
        };
        assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0");
        assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0");
        assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "when_1");
        assert_eq!(CaseWhen.child_name(&options, 3).to_string(), "then_1");
        assert_eq!(CaseWhen.child_name(&options, 4).to_string(), "when_2");
        assert_eq!(CaseWhen.child_name(&options, 5).to_string(), "then_2");
        assert_eq!(CaseWhen.child_name(&options, 6).to_string(), "else");
    }

    #[test]
    fn test_replace_children() {
        let expr = case_when(lit(true), lit(1i32), lit(0i32));
        let replaced = expr
            .with_children([lit(false), lit(2i32), lit(3i32)])
            .vortex_expect("operation should succeed in test");
        // The replacement must be exactly the expression built from the new children.
        assert_eq!(
            replaced.to_string(),
            case_when(lit(false), lit(2i32), lit(3i32)).to_string()
        );
    }

    #[test]
    fn test_evaluate_simple_condition() {
        let expr = case_when(
            gt(get_item("value", root()), lit(2i32)),
            lit(100i32),
            lit(0i32),
        );

        let result = evaluate_expr(&expr, &values_one_to_five()).to_primitive();
        assert_eq!(result.as_slice::<i32>(), &[0, 0, 100, 100, 100]);
    }

    #[test]
    fn test_evaluate_nary_multiple_conditions() {
        let expr = nested_case_when(
            vec![
                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &values_one_to_five()).to_primitive();
        assert_eq!(result.as_slice::<i32>(), &[10, 0, 30, 0, 0]);
    }

    #[test]
    fn test_evaluate_nary_first_match_wins() {
        let expr = nested_case_when(
            vec![
                (gt(get_item("value", root()), lit(2i32)), lit(100i32)),
                (gt(get_item("value", root()), lit(3i32)), lit(200i32)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &values_one_to_five()).to_primitive();
        assert_eq!(result.as_slice::<i32>(), &[0, 0, 100, 100, 100]);
    }

    #[test]
    fn test_evaluate_no_else_returns_null() {
        let expr = case_when_no_else(gt(get_item("value", root()), lit(3i32)), lit(100i32));

        let result = evaluate_expr(&expr, &values_one_to_five());
        assert_nullable_i32_rows(&result, &[None, None, None, Some(100), Some(100)]);
    }

    #[test]
    fn test_evaluate_all_conditions_false() {
        let expr = case_when(
            gt(get_item("value", root()), lit(100i32)),
            lit(1i32),
            lit(0i32),
        );

        let result = evaluate_expr(&expr, &values_one_to_five()).to_primitive();
        assert_eq!(result.as_slice::<i32>(), &[0, 0, 0, 0, 0]);
    }

    #[test]
    fn test_evaluate_all_conditions_true() {
        let expr = case_when(
            gt(get_item("value", root()), lit(0i32)),
            lit(100i32),
            lit(0i32),
        );

        let result = evaluate_expr(&expr, &values_one_to_five()).to_primitive();
        assert_eq!(result.as_slice::<i32>(), &[100, 100, 100, 100, 100]);
    }

    #[test]
    fn test_evaluate_with_literal_condition() {
        let test_array = buffer![1i32, 2, 3].into_array();
        let expr = case_when(lit(true), lit(100i32), lit(0i32));
        let result = evaluate_expr(&expr, &test_array);

        if let Some(constant) = result.as_constant() {
            assert_eq!(constant, Scalar::from(100i32));
        } else {
            let prim = result.to_primitive();
            assert_eq!(prim.as_slice::<i32>(), &[100, 100, 100]);
        }
    }

    #[test]
    fn test_evaluate_with_bool_column_result() {
        let expr = case_when(
            gt(get_item("value", root()), lit(2i32)),
            lit(true),
            lit(false),
        );

        let result = evaluate_expr(&expr, &values_one_to_five()).to_bool();
        assert_eq!(
            result.to_bit_buffer().iter().collect::<Vec<_>>(),
            vec![false, false, true, true, true]
        );
    }

    #[test]
    fn test_evaluate_with_nullable_condition() {
        let test_array = StructArray::from_fields(&[(
            "cond",
            BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)]).into_array(),
        )])
        .unwrap()
        .into_array();

        let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));

        // Null conditions are treated as false and fall through to ELSE.
        let result = evaluate_expr(&expr, &test_array).to_primitive();
        assert_eq!(result.as_slice::<i32>(), &[100, 0, 0, 0, 100]);
    }

    #[test]
    fn test_evaluate_with_nullable_result_values() {
        let test_array = StructArray::from_fields(&[
            ("value", buffer![1i32, 2, 3, 4, 5].into_array()),
            (
                "result",
                PrimitiveArray::from_option_iter([Some(10), None, Some(30), Some(40), Some(50)])
                    .into_array(),
            ),
        ])
        .unwrap()
        .into_array();

        let expr = case_when(
            gt(get_item("value", root()), lit(2i32)),
            get_item("result", root()),
            lit(0i32),
        );

        let result = evaluate_expr(&expr, &test_array);
        let prim = result.to_primitive();
        assert_eq!(prim.as_slice::<i32>(), &[0, 0, 30, 40, 50]);
    }

    #[test]
    fn test_evaluate_with_all_null_condition() {
        let test_array = StructArray::from_fields(&[(
            "cond",
            BoolArray::from_iter([None, None, None]).into_array(),
        )])
        .unwrap()
        .into_array();

        let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));

        let result = evaluate_expr(&expr, &test_array).to_primitive();
        assert_eq!(result.as_slice::<i32>(), &[0, 0, 0]);
    }

    #[test]
    fn test_evaluate_nary_no_else_returns_null() {
        let expr = nested_case_when(
            vec![
                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
            ],
            None,
        );

        let result = evaluate_expr(&expr, &values_one_to_five());
        assert_nullable_i32_rows(&result, &[Some(10), None, Some(30), None, None]);
    }

    #[test]
    fn test_evaluate_nary_many_conditions() {
        let expr = nested_case_when(
            vec![
                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
                (eq(get_item("value", root()), lit(2i32)), lit(20i32)),
                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
                (eq(get_item("value", root()), lit(4i32)), lit(40i32)),
                (eq(get_item("value", root()), lit(5i32)), lit(50i32)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &values_one_to_five()).to_primitive();
        assert_eq!(result.as_slice::<i32>(), &[10, 20, 30, 40, 50]);
    }

    #[test]
    fn test_evaluate_nary_all_false_no_else() {
        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
            .unwrap()
            .into_array();

        let expr = nested_case_when(
            vec![
                (gt(get_item("value", root()), lit(100i32)), lit(10i32)),
                (gt(get_item("value", root()), lit(200i32)), lit(20i32)),
            ],
            None,
        );

        let result = evaluate_expr(&expr, &test_array);
        assert_nullable_i32_rows(&result, &[None, None, None]);
    }

    #[test]
    fn test_evaluate_nary_overlapping_conditions_first_wins() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![10i32, 20, 30].into_array())])
                .unwrap()
                .into_array();

        // Every row satisfies all three conditions; the first pair must win everywhere.
        let expr = nested_case_when(
            vec![
                (gt(get_item("value", root()), lit(5i32)), lit(1i32)),
                (gt(get_item("value", root()), lit(0i32)), lit(2i32)),
                (gt(get_item("value", root()), lit(15i32)), lit(3i32)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &test_array).to_primitive();
        assert_eq!(result.as_slice::<i32>(), &[1, 1, 1]);
    }

    #[test]
    fn test_evaluate_nary_with_nullable_conditions() {
        let test_array = StructArray::from_fields(&[
            (
                "cond1",
                BoolArray::from_iter([Some(true), None, Some(false)]).into_array(),
            ),
            (
                "cond2",
                BoolArray::from_iter([Some(false), Some(true), None]).into_array(),
            ),
        ])
        .unwrap()
        .into_array();

        let expr = nested_case_when(
            vec![
                (get_item("cond1", root()), lit(10i32)),
                (get_item("cond2", root()), lit(20i32)),
            ],
            Some(lit(0i32)),
        );

        let result = evaluate_expr(&expr, &test_array).to_primitive();
        assert_eq!(result.as_slice::<i32>(), &[10, 20, 0]);
    }
}