Skip to main content

vortex_array/scalar_fn/fns/
case_when.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! SQL-style CASE WHEN: evaluates `(condition, value)` pairs in order and returns
5//! the value from the first matching condition (first-match-wins). NULL conditions
6//! are treated as false. If no ELSE clause is provided, unmatched rows produce NULL;
7//! otherwise they get the ELSE value.
8//!
//! Unlike SQL, which coerces all branches to a common supertype, all THEN/ELSE
//! branches here must share the same base dtype (ignoring nullability). The result
//! nullability is the union of all branches' nullabilities (forced nullable if there is no ELSE).
12
13use std::fmt;
14use std::fmt::Formatter;
15use std::hash::Hash;
16use std::sync::Arc;
17
18use prost::Message;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21use vortex_proto::expr as pb;
22use vortex_session::VortexSession;
23
24use crate::ArrayRef;
25use crate::ExecutionCtx;
26use crate::IntoArray;
27use crate::arrays::BoolArray;
28use crate::arrays::ConstantArray;
29use crate::dtype::DType;
30use crate::expr::Expression;
31use crate::scalar::Scalar;
32use crate::scalar_fn::Arity;
33use crate::scalar_fn::ChildName;
34use crate::scalar_fn::ExecutionArgs;
35use crate::scalar_fn::ScalarFnId;
36use crate::scalar_fn::ScalarFnVTable;
37use crate::scalar_fn::fns::zip::zip_impl;
38
/// Configuration for an n-ary `CaseWhen` expression.
///
/// The options describe only the shape of the children list (how many WHEN/THEN
/// pairs there are and whether an ELSE child follows them); they do not hold the
/// child expressions themselves.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CaseWhenOptions {
    /// How many `(WHEN, THEN)` child pairs the expression has.
    pub num_when_then_pairs: u32,
    /// Whether a trailing ELSE child is present. When it is absent, rows that match
    /// no WHEN condition evaluate to NULL.
    pub has_else: bool,
}

impl CaseWhenOptions {
    /// Returns how many child expressions these options describe: two for every
    /// WHEN/THEN pair, plus one more for the ELSE clause when it is present.
    pub fn num_children(&self) -> usize {
        let pair_children = 2 * self.num_when_then_pairs as usize;
        let else_children = usize::from(self.has_else);
        pair_children + else_children
    }
}

impl fmt::Display for CaseWhenOptions {
    /// Renders the options as `case_when(pairs=<n>, else=<bool>)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let CaseWhenOptions {
            num_when_then_pairs,
            has_else,
        } = *self;
        write!(f, "case_when(pairs={num_when_then_pairs}, else={has_else})")
    }
}
65
/// Scalar function implementing an n-ary SQL `CASE WHEN` expression.
///
/// The expression's children are laid out as WHEN/THEN pairs followed by an
/// optional ELSE child: `[when_0, then_0, when_1, then_1, ..., else?]`.
#[derive(Clone)]
pub struct CaseWhen;
71
72impl ScalarFnVTable for CaseWhen {
73    type Options = CaseWhenOptions;
74
75    fn id(&self) -> ScalarFnId {
76        ScalarFnId::from("vortex.case_when")
77    }
78
79    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
80        // let num_children = options.num_when_then_pairs * 2 + u32::from(options.has_else);
81        // Ok(Some(pb::CaseWhenOpts { num_children }.encode_to_vec()))
82        // stabilize the expr
83        vortex_bail!("cannot serialize")
84    }
85
86    fn deserialize(
87        &self,
88        metadata: &[u8],
89        _session: &VortexSession,
90    ) -> VortexResult<Self::Options> {
91        let opts = pb::CaseWhenOpts::decode(metadata)?;
92        if opts.num_children < 2 {
93            vortex_bail!(
94                "CaseWhen expects at least 2 children, got {}",
95                opts.num_children
96            );
97        }
98        Ok(CaseWhenOptions {
99            num_when_then_pairs: opts.num_children / 2,
100            has_else: opts.num_children % 2 == 1,
101        })
102    }
103
104    fn arity(&self, options: &Self::Options) -> Arity {
105        Arity::Exact(options.num_children())
106    }
107
108    fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName {
109        let num_pair_children = options.num_when_then_pairs as usize * 2;
110        if child_idx < num_pair_children {
111            let pair_idx = child_idx / 2;
112            if child_idx.is_multiple_of(2) {
113                ChildName::from(Arc::from(format!("when_{pair_idx}")))
114            } else {
115                ChildName::from(Arc::from(format!("then_{pair_idx}")))
116            }
117        } else if options.has_else && child_idx == num_pair_children {
118            ChildName::from("else")
119        } else {
120            unreachable!("Invalid child index {} for CaseWhen", child_idx)
121        }
122    }
123
124    fn fmt_sql(
125        &self,
126        options: &Self::Options,
127        expr: &Expression,
128        f: &mut Formatter<'_>,
129    ) -> fmt::Result {
130        write!(f, "CASE")?;
131        for i in 0..options.num_when_then_pairs as usize {
132            write!(
133                f,
134                " WHEN {} THEN {}",
135                expr.child(i * 2),
136                expr.child(i * 2 + 1)
137            )?;
138        }
139        if options.has_else {
140            let else_idx = options.num_when_then_pairs as usize * 2;
141            write!(f, " ELSE {}", expr.child(else_idx))?;
142        }
143        write!(f, " END")
144    }
145
146    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
147        if options.num_when_then_pairs == 0 {
148            vortex_bail!("CaseWhen must have at least one WHEN/THEN pair");
149        }
150
151        let expected_len = options.num_children();
152        if arg_dtypes.len() != expected_len {
153            vortex_bail!(
154                "CaseWhen expects {expected_len} argument dtypes, got {}",
155                arg_dtypes.len()
156            );
157        }
158
159        // Unlike SQL which coerces all branches to a common supertype, we require
160        // all THEN/ELSE branches to have the same base dtype (ignoring nullability).
161        // The result nullability is the union of all branches.
162        let first_then = &arg_dtypes[1];
163        let mut result_dtype = first_then.clone();
164
165        for i in 1..options.num_when_then_pairs as usize {
166            let then_i = &arg_dtypes[i * 2 + 1];
167            if !first_then.eq_ignore_nullability(then_i) {
168                vortex_bail!(
169                    "CaseWhen THEN dtypes must match (ignoring nullability), got {} and {}",
170                    first_then,
171                    then_i
172                );
173            }
174            result_dtype = result_dtype.union_nullability(then_i.nullability());
175        }
176
177        if options.has_else {
178            let else_dtype = &arg_dtypes[options.num_when_then_pairs as usize * 2];
179            if !result_dtype.eq_ignore_nullability(else_dtype) {
180                vortex_bail!(
181                    "CaseWhen THEN and ELSE dtypes must match (ignoring nullability), got {} and {}",
182                    first_then,
183                    else_dtype
184                );
185            }
186            result_dtype = result_dtype.union_nullability(else_dtype.nullability());
187        } else {
188            // No ELSE means unmatched rows are NULL
189            result_dtype = result_dtype.as_nullable();
190        }
191
192        Ok(result_dtype)
193    }
194
195    fn execute(
196        &self,
197        options: &Self::Options,
198        args: &dyn ExecutionArgs,
199        ctx: &mut ExecutionCtx,
200    ) -> VortexResult<ArrayRef> {
201        let row_count = args.row_count();
202        let num_pairs = options.num_when_then_pairs as usize;
203
204        let mut result: ArrayRef = if options.has_else {
205            args.get(num_pairs * 2)?
206        } else {
207            let then_dtype = args.get(1)?.dtype().as_nullable();
208            ConstantArray::new(Scalar::null(then_dtype), row_count).into_array()
209        };
210
211        // TODO(perf): this reverse-zip approach touches every row for every condition.
212        // A left-to-right filter approach could maintain an "unmatched" mask, narrow it
213        // as conditions match, and exit early once all rows are resolved.
214        for i in (0..num_pairs).rev() {
215            let condition = args.get(i * 2)?;
216            let then_value = args.get(i * 2 + 1)?;
217
218            let cond_bool = condition.execute::<BoolArray>(ctx)?;
219            let mask = cond_bool.to_mask_fill_null_false();
220
221            if mask.all_true() {
222                result = then_value;
223                continue;
224            }
225
226            if mask.all_false() {
227                continue;
228            }
229
230            result = zip_impl(&then_value, &result, &mask)?;
231        }
232
233        Ok(result)
234    }
235
236    fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
237        // CaseWhen is null-sensitive because NULL conditions are treated as false
238        true
239    }
240
241    fn is_fallible(&self, _options: &Self::Options) -> bool {
242        false
243    }
244}
245
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use vortex_buffer::buffer;
    use vortex_error::VortexExpect as _;
    use vortex_session::VortexSession;

    use super::*;
    use crate::Canonical;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::VortexSessionExecute as _;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::case_when;
    use crate::expr::case_when_no_else;
    use crate::expr::col;
    use crate::expr::eq;
    use crate::expr::get_item;
    use crate::expr::gt;
    use crate::expr::lit;
    use crate::expr::nested_case_when;
    use crate::expr::root;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;
    use crate::scalar_fn::fns::case_when::BoolArray;
    use crate::session::ArraySession;

    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    // ==================== Helpers ====================

    /// Applies `expr` to `array`, then executes the result into a canonical array.
    fn evaluate_expr(expr: &Expression, array: &ArrayRef) -> ArrayRef {
        let mut ctx = SESSION.create_execution_ctx();
        let applied = array.apply(expr).unwrap();
        let canonical = applied.execute::<Canonical>(&mut ctx).unwrap();
        canonical.into_array()
    }

    /// Evaluates `expr` against `array` and copies the primitive `i32` output into a `Vec`.
    fn evaluate_i32(expr: &Expression, array: &ArrayRef) -> Vec<i32> {
        let primitive = evaluate_expr(expr, array).to_primitive();
        primitive.as_slice::<i32>().to_vec()
    }

    /// Builds a struct array from `(field name, column)` pairs.
    fn struct_of(fields: &[(&str, ArrayRef)]) -> ArrayRef {
        StructArray::from_fields(fields).unwrap().into_array()
    }

    /// Builds a struct array whose only field is the `value` column `values`.
    fn value_struct(values: ArrayRef) -> ArrayRef {
        struct_of(&[("value", values)])
    }

    /// The `value` column `[1, 2, 3, 4, 5]` that most evaluation tests run against.
    fn one_to_five() -> ArrayRef {
        value_struct(buffer![1i32, 2, 3, 4, 5].into_array())
    }

    /// Expression reading the `value` field of the root struct.
    fn value_col() -> Expression {
        get_item("value", root())
    }

    /// Shorthand for an `i32` primitive dtype with the given nullability.
    fn i32_dtype(nullability: Nullability) -> DType {
        DType::Primitive(PType::I32, nullability)
    }

    /// Asserts that `array[idx]` is NULL (compared at the array's own dtype).
    fn assert_null_at(array: &ArrayRef, idx: usize) {
        assert_eq!(
            array.scalar_at(idx).unwrap(),
            Scalar::null(array.dtype().clone())
        );
    }

    /// Asserts that `array[idx]` equals `expected` cast to the array's dtype.
    fn assert_i32_at(array: &ArrayRef, idx: usize, expected: i32) {
        assert_eq!(
            array.scalar_at(idx).unwrap(),
            Scalar::from(expected).cast(array.dtype()).unwrap()
        );
    }

    /// Serializes `options` and asserts that deserializing the bytes restores them.
    fn assert_serialization_roundtrip(options: CaseWhenOptions) {
        let bytes = CaseWhen.serialize(&options).unwrap().unwrap();
        let decoded = CaseWhen
            .deserialize(&bytes, &VortexSession::empty())
            .unwrap();
        assert_eq!(options, decoded);
    }

    // ==================== Serialization Tests ====================
    //
    // Serialization is currently disabled, so every round trip panics in `serialize`.

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_roundtrip() {
        assert_serialization_roundtrip(CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: true,
        });
    }

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_no_else() {
        assert_serialization_roundtrip(CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: false,
        });
    }

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_roundtrip_nary() {
        assert_serialization_roundtrip(CaseWhenOptions {
            num_when_then_pairs: 3,
            has_else: true,
        });
    }

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_roundtrip_nary_no_else() {
        assert_serialization_roundtrip(CaseWhenOptions {
            num_when_then_pairs: 4,
            has_else: false,
        });
    }

    // ==================== Display Tests ====================

    #[test]
    fn test_display_with_else() {
        let expr = case_when(gt(col("value"), lit(0i32)), lit(100i32), lit(0i32));
        let display = expr.to_string();
        for keyword in ["CASE", "WHEN", "THEN", "ELSE", "END"] {
            assert!(display.contains(keyword));
        }
    }

    #[test]
    fn test_display_no_else() {
        let expr = case_when_no_else(gt(col("value"), lit(0i32)), lit(100i32));
        let display = expr.to_string();
        for keyword in ["CASE", "WHEN", "THEN", "END"] {
            assert!(display.contains(keyword));
        }
        assert!(!display.contains("ELSE"));
    }

    #[test]
    fn test_display_nested_nary() {
        // CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'medium' ELSE 'low' END
        let expr = nested_case_when(
            vec![
                (gt(col("x"), lit(10i32)), lit("high")),
                (gt(col("x"), lit(5i32)), lit("medium")),
            ],
            Some(lit("low")),
        );
        let display = expr.to_string();
        let occurrences = |needle: &str| display.matches(needle).count();
        assert_eq!(occurrences("CASE"), 1);
        assert_eq!(occurrences("WHEN"), 2);
        assert_eq!(occurrences("THEN"), 2);
    }

    // ==================== DType Tests ====================

    #[test]
    fn test_return_dtype_with_else() {
        let expr = case_when(lit(true), lit(100i32), lit(0i32));
        let dtype = expr
            .return_dtype(&i32_dtype(Nullability::NonNullable))
            .unwrap();
        assert_eq!(dtype, i32_dtype(Nullability::NonNullable));
    }

    #[test]
    fn test_return_dtype_with_nullable_else() {
        let null_else = lit(Scalar::null(i32_dtype(Nullability::Nullable)));
        let expr = case_when(lit(true), lit(100i32), null_else);
        let dtype = expr
            .return_dtype(&i32_dtype(Nullability::NonNullable))
            .unwrap();
        assert_eq!(dtype, i32_dtype(Nullability::Nullable));
    }

    #[test]
    fn test_return_dtype_without_else_is_nullable() {
        let expr = case_when_no_else(lit(true), lit(100i32));
        let dtype = expr
            .return_dtype(&i32_dtype(Nullability::NonNullable))
            .unwrap();
        assert_eq!(dtype, i32_dtype(Nullability::Nullable));
    }

    #[test]
    fn test_return_dtype_with_struct_input() {
        let input_dtype = test_harness::struct_dtype();
        let expr = case_when(
            gt(get_item("col1", root()), lit(10u16)),
            lit(100i32),
            lit(0i32),
        );
        let dtype = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(dtype, i32_dtype(Nullability::NonNullable));
    }

    #[test]
    fn test_return_dtype_mismatched_then_else_errors() {
        let expr = case_when(lit(true), lit(100i32), lit("zero"));
        let message = expr
            .return_dtype(&i32_dtype(Nullability::NonNullable))
            .unwrap_err()
            .to_string();
        assert!(message.contains("THEN and ELSE dtypes must match (ignoring nullability)"));
    }

    // ==================== Arity Tests ====================

    #[test]
    fn test_arity_with_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: true,
        };
        assert_eq!(CaseWhen.arity(&options), Arity::Exact(3));
    }

    #[test]
    fn test_arity_without_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: false,
        };
        assert_eq!(CaseWhen.arity(&options), Arity::Exact(2));
    }

    #[test]
    fn test_arity_nary_with_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 3,
            has_else: true,
        };
        // 3 pairs * 2 children + 1 else = 7
        assert_eq!(CaseWhen.arity(&options), Arity::Exact(7));
    }

    #[test]
    fn test_arity_nary_without_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 3,
            has_else: false,
        };
        // 3 pairs * 2 children = 6
        assert_eq!(CaseWhen.arity(&options), Arity::Exact(6));
    }

    // ==================== Child Name Tests ====================

    #[test]
    fn test_child_names() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: true,
        };
        let names: Vec<String> = (0..3)
            .map(|idx| CaseWhen.child_name(&options, idx).to_string())
            .collect();
        assert_eq!(names, ["when_0", "then_0", "else"]);
    }

    #[test]
    fn test_child_names_nary() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 3,
            has_else: true,
        };
        let names: Vec<String> = (0..7)
            .map(|idx| CaseWhen.child_name(&options, idx).to_string())
            .collect();
        assert_eq!(
            names,
            ["when_0", "then_0", "when_1", "then_1", "when_2", "then_2", "else"]
        );
    }

    // ==================== N-ary DType Tests ====================

    #[test]
    fn test_return_dtype_nary_mismatched_then_types_errors() {
        let expr = nested_case_when(
            vec![(lit(true), lit(100i32)), (lit(false), lit("oops"))],
            Some(lit(0i32)),
        );
        let message = expr
            .return_dtype(&i32_dtype(Nullability::NonNullable))
            .unwrap_err()
            .to_string();
        assert!(message.contains("THEN dtypes must match"));
    }

    #[test]
    fn test_return_dtype_nary_mixed_nullability() {
        // A mix of nullable and non-nullable THEN branches yields a nullable result
        // (the union of their nullabilities).
        let nullable_then = lit(Scalar::null(i32_dtype(Nullability::Nullable)));
        let expr = nested_case_when(
            vec![(lit(true), lit(100i32)), (lit(false), nullable_then)],
            Some(lit(0i32)),
        );
        let dtype = expr
            .return_dtype(&i32_dtype(Nullability::NonNullable))
            .unwrap();
        assert_eq!(dtype, i32_dtype(Nullability::Nullable));
    }

    #[test]
    fn test_return_dtype_nary_no_else_is_nullable() {
        let expr = nested_case_when(
            vec![(lit(true), lit(10i32)), (lit(false), lit(20i32))],
            None,
        );
        let dtype = expr
            .return_dtype(&i32_dtype(Nullability::NonNullable))
            .unwrap();
        assert_eq!(dtype, i32_dtype(Nullability::Nullable));
    }

    // ==================== Expression Manipulation Tests ====================

    #[test]
    fn test_replace_children() {
        let expr = case_when(lit(true), lit(1i32), lit(0i32));
        expr.with_children([lit(false), lit(2i32), lit(3i32)])
            .vortex_expect("operation should succeed in test");
    }

    // ==================== Evaluate Tests ====================

    #[test]
    fn test_evaluate_simple_condition() {
        let expr = case_when(gt(value_col(), lit(2i32)), lit(100i32), lit(0i32));
        assert_eq!(evaluate_i32(&expr, &one_to_five()), [0, 0, 100, 100, 100]);
    }

    #[test]
    fn test_evaluate_nary_multiple_conditions() {
        let expr = nested_case_when(
            vec![
                (eq(value_col(), lit(1i32)), lit(10i32)),
                (eq(value_col(), lit(3i32)), lit(30i32)),
            ],
            Some(lit(0i32)),
        );
        assert_eq!(evaluate_i32(&expr, &one_to_five()), [10, 0, 30, 0, 0]);
    }

    #[test]
    fn test_evaluate_nary_first_match_wins() {
        // Values above 3 satisfy both conditions; the first pair must take precedence.
        let expr = nested_case_when(
            vec![
                (gt(value_col(), lit(2i32)), lit(100i32)),
                (gt(value_col(), lit(3i32)), lit(200i32)),
            ],
            Some(lit(0i32)),
        );
        assert_eq!(evaluate_i32(&expr, &one_to_five()), [0, 0, 100, 100, 100]);
    }

    #[test]
    fn test_evaluate_no_else_returns_null() {
        let expr = case_when_no_else(gt(value_col(), lit(3i32)), lit(100i32));
        let result = evaluate_expr(&expr, &one_to_five());
        assert!(result.dtype().is_nullable());

        for idx in 0..3 {
            assert_null_at(&result, idx);
        }
        for idx in 3..5 {
            assert_i32_at(&result, idx, 100);
        }
    }

    #[test]
    fn test_evaluate_all_conditions_false() {
        let expr = case_when(gt(value_col(), lit(100i32)), lit(1i32), lit(0i32));
        assert_eq!(evaluate_i32(&expr, &one_to_five()), [0, 0, 0, 0, 0]);
    }

    #[test]
    fn test_evaluate_all_conditions_true() {
        let expr = case_when(gt(value_col(), lit(0i32)), lit(100i32), lit(0i32));
        assert_eq!(
            evaluate_i32(&expr, &one_to_five()),
            [100, 100, 100, 100, 100]
        );
    }

    #[test]
    fn test_evaluate_with_literal_condition() {
        let input = buffer![1i32, 2, 3].into_array();
        let expr = case_when(lit(true), lit(100i32), lit(0i32));
        let result = evaluate_expr(&expr, &input);

        match result.as_constant() {
            Some(constant) => assert_eq!(constant, Scalar::from(100i32)),
            None => {
                let primitive = result.to_primitive();
                assert_eq!(primitive.as_slice::<i32>(), &[100, 100, 100]);
            }
        }
    }

    #[test]
    fn test_evaluate_with_bool_column_result() {
        let expr = case_when(gt(value_col(), lit(2i32)), lit(true), lit(false));
        let bools = evaluate_expr(&expr, &one_to_five()).to_bool();
        let bits: Vec<bool> = bools.to_bit_buffer().iter().collect();
        assert_eq!(bits, [false, false, true, true, true]);
    }

    #[test]
    fn test_evaluate_with_nullable_condition() {
        let input = struct_of(&[(
            "cond",
            BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)]).into_array(),
        )]);
        let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));
        assert_eq!(evaluate_i32(&expr, &input), [100, 0, 0, 0, 100]);
    }

    #[test]
    fn test_evaluate_with_nullable_result_values() {
        let input = struct_of(&[
            ("value", buffer![1i32, 2, 3, 4, 5].into_array()),
            (
                "result",
                PrimitiveArray::from_option_iter([Some(10), None, Some(30), Some(40), Some(50)])
                    .into_array(),
            ),
        ]);
        let expr = case_when(
            gt(value_col(), lit(2i32)),
            get_item("result", root()),
            lit(0i32),
        );
        assert_eq!(evaluate_i32(&expr, &input), [0, 0, 30, 40, 50]);
    }

    #[test]
    fn test_evaluate_with_all_null_condition() {
        let input = struct_of(&[(
            "cond",
            BoolArray::from_iter([None, None, None]).into_array(),
        )]);
        let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));
        assert_eq!(evaluate_i32(&expr, &input), [0, 0, 0]);
    }

    // ==================== N-ary Evaluate Tests ====================

    #[test]
    fn test_evaluate_nary_no_else_returns_null() {
        // Two conditions and no ELSE: rows matching neither must come back NULL.
        let expr = nested_case_when(
            vec![
                (eq(value_col(), lit(1i32)), lit(10i32)),
                (eq(value_col(), lit(3i32)), lit(30i32)),
            ],
            None,
        );
        let result = evaluate_expr(&expr, &one_to_five());
        assert!(result.dtype().is_nullable());

        assert_i32_at(&result, 0, 10);
        assert_null_at(&result, 1);
        assert_i32_at(&result, 2, 30);
        assert_null_at(&result, 3);
        assert_null_at(&result, 4);
    }

    #[test]
    fn test_evaluate_nary_many_conditions() {
        // Five WHEN/THEN pairs mapping each value `v` to `v * 10`.
        let pairs: Vec<_> = (1..=5)
            .map(|v: i32| (eq(value_col(), lit(v)), lit(v * 10)))
            .collect();
        let expr = nested_case_when(pairs, Some(lit(0i32)));
        assert_eq!(evaluate_i32(&expr, &one_to_five()), [10, 20, 30, 40, 50]);
    }

    #[test]
    fn test_evaluate_nary_all_false_no_else() {
        // No condition ever matches and there is no ELSE, so every row is NULL.
        let input = value_struct(buffer![1i32, 2, 3].into_array());
        let expr = nested_case_when(
            vec![
                (gt(value_col(), lit(100i32)), lit(10i32)),
                (gt(value_col(), lit(200i32)), lit(20i32)),
            ],
            None,
        );
        let result = evaluate_expr(&expr, &input);
        assert!(result.dtype().is_nullable());
        for idx in 0..3 {
            assert_null_at(&result, idx);
        }
    }

    #[test]
    fn test_evaluate_nary_overlapping_conditions_first_wins() {
        // 10 matches the first two conditions; 20 and 30 match all three.
        // In every case the first pair's value must win.
        let input = value_struct(buffer![10i32, 20, 30].into_array());
        let expr = nested_case_when(
            vec![
                (gt(value_col(), lit(5i32)), lit(1i32)),
                (gt(value_col(), lit(0i32)), lit(2i32)),
                (gt(value_col(), lit(15i32)), lit(3i32)),
            ],
            Some(lit(0i32)),
        );
        assert_eq!(evaluate_i32(&expr, &input), [1, 1, 1]);
    }

    #[test]
    fn test_evaluate_nary_with_nullable_conditions() {
        let input = struct_of(&[
            (
                "cond1",
                BoolArray::from_iter([Some(true), None, Some(false)]).into_array(),
            ),
            (
                "cond2",
                BoolArray::from_iter([Some(false), Some(true), None]).into_array(),
            ),
        ]);
        let expr = nested_case_when(
            vec![
                (get_item("cond1", root()), lit(10i32)),
                (get_item("cond2", root()), lit(20i32)),
            ],
            Some(lit(0i32)),
        );
        // row 0: cond1 = true                        -> 10
        // row 1: cond1 = NULL (false), cond2 = true  -> 20
        // row 2: cond1 = false, cond2 = NULL (false) -> ELSE 0
        assert_eq!(evaluate_i32(&expr, &input), [10, 20, 0]);
    }
}