Skip to main content

vortex_array/scalar_fn/fns/
case_when.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! SQL-style CASE WHEN: evaluates `(condition, value)` pairs in order and returns
5//! the value from the first matching condition (first-match-wins). NULL conditions
6//! are treated as false. If no ELSE clause is provided, unmatched rows produce NULL;
7//! otherwise they get the ELSE value.
8//!
9//! Unlike SQL which coerces all branches to a common supertype, all THEN/ELSE
10//! branches must share the same base dtype (ignoring nullability). The result
11//! nullability is the union of all branches (forced nullable if no ELSE).
12
13use std::fmt;
14use std::fmt::Formatter;
15use std::hash::Hash;
16use std::sync::Arc;
17
18use prost::Message;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21use vortex_mask::AllOr;
22use vortex_mask::Mask;
23use vortex_proto::expr as pb;
24use vortex_session::VortexSession;
25
26use crate::ArrayRef;
27use crate::ExecutionCtx;
28use crate::IntoArray;
29use crate::arrays::BoolArray;
30use crate::arrays::ConstantArray;
31use crate::arrays::bool::BoolArrayExt;
32use crate::builders::ArrayBuilder;
33use crate::builders::builder_with_capacity;
34use crate::builtins::ArrayBuiltins;
35use crate::dtype::DType;
36use crate::expr::Expression;
37use crate::scalar::Scalar;
38use crate::scalar_fn::Arity;
39use crate::scalar_fn::ChildName;
40use crate::scalar_fn::ExecutionArgs;
41use crate::scalar_fn::ScalarFnId;
42use crate::scalar_fn::ScalarFnVTable;
43use crate::scalar_fn::fns::zip::zip_impl;
44
/// Configuration describing the shape of an n-ary `CaseWhen` expression.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CaseWhenOptions {
    /// How many WHEN/THEN pairs the expression carries.
    pub num_when_then_pairs: u32,
    /// `true` when a trailing ELSE child exists.
    /// When `false`, rows matched by no WHEN condition evaluate to NULL.
    pub has_else: bool,
}

impl CaseWhenOptions {
    /// Returns the total child count: two children for every WHEN/THEN pair, and one
    /// more when an ELSE clause is present.
    pub fn num_children(&self) -> usize {
        let pair_children = self.num_when_then_pairs as usize * 2;
        if self.has_else {
            pair_children + 1
        } else {
            pair_children
        }
    }
}

impl fmt::Display for CaseWhenOptions {
    /// Renders the options as `case_when(pairs=<count>, else=<bool>)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self {
            num_when_then_pairs,
            has_else,
        } = self;
        write!(
            f,
            "case_when(pairs={}, else={})",
            num_when_then_pairs, has_else
        )
    }
}
71
/// The n-ary CASE WHEN scalar function.
///
/// Its children are laid out as `[when_0, then_0, when_1, then_1, ..., else?]`, where the
/// trailing `else` child is optional.
#[derive(Clone)]
pub struct CaseWhen;
77
78impl ScalarFnVTable for CaseWhen {
79    type Options = CaseWhenOptions;
80
81    fn id(&self) -> ScalarFnId {
82        ScalarFnId::from("vortex.case_when")
83    }
84
85    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
86        // let num_children = options.num_when_then_pairs * 2 + u32::from(options.has_else);
87        // Ok(Some(pb::CaseWhenOpts { num_children }.encode_to_vec()))
88        // stabilize the expr
89        vortex_bail!("cannot serialize")
90    }
91
92    fn deserialize(
93        &self,
94        metadata: &[u8],
95        _session: &VortexSession,
96    ) -> VortexResult<Self::Options> {
97        let opts = pb::CaseWhenOpts::decode(metadata)?;
98        if opts.num_children < 2 {
99            vortex_bail!(
100                "CaseWhen expects at least 2 children, got {}",
101                opts.num_children
102            );
103        }
104        Ok(CaseWhenOptions {
105            num_when_then_pairs: opts.num_children / 2,
106            has_else: opts.num_children % 2 == 1,
107        })
108    }
109
110    fn arity(&self, options: &Self::Options) -> Arity {
111        Arity::Exact(options.num_children())
112    }
113
114    fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName {
115        let num_pair_children = options.num_when_then_pairs as usize * 2;
116        if child_idx < num_pair_children {
117            let pair_idx = child_idx / 2;
118            if child_idx.is_multiple_of(2) {
119                ChildName::from(Arc::from(format!("when_{pair_idx}")))
120            } else {
121                ChildName::from(Arc::from(format!("then_{pair_idx}")))
122            }
123        } else if options.has_else && child_idx == num_pair_children {
124            ChildName::from("else")
125        } else {
126            unreachable!("Invalid child index {} for CaseWhen", child_idx)
127        }
128    }
129
130    fn fmt_sql(
131        &self,
132        options: &Self::Options,
133        expr: &Expression,
134        f: &mut Formatter<'_>,
135    ) -> fmt::Result {
136        write!(f, "CASE")?;
137        for i in 0..options.num_when_then_pairs as usize {
138            write!(
139                f,
140                " WHEN {} THEN {}",
141                expr.child(i * 2),
142                expr.child(i * 2 + 1)
143            )?;
144        }
145        if options.has_else {
146            let else_idx = options.num_when_then_pairs as usize * 2;
147            write!(f, " ELSE {}", expr.child(else_idx))?;
148        }
149        write!(f, " END")
150    }
151
152    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
153        if options.num_when_then_pairs == 0 {
154            vortex_bail!("CaseWhen must have at least one WHEN/THEN pair");
155        }
156
157        let expected_len = options.num_children();
158        if arg_dtypes.len() != expected_len {
159            vortex_bail!(
160                "CaseWhen expects {expected_len} argument dtypes, got {}",
161                arg_dtypes.len()
162            );
163        }
164
165        // Unlike SQL which coerces all branches to a common supertype, we require
166        // all THEN/ELSE branches to have the same base dtype (ignoring nullability).
167        // The result nullability is the union of all branches.
168        let first_then = &arg_dtypes[1];
169        let mut result_dtype = first_then.clone();
170
171        for i in 1..options.num_when_then_pairs as usize {
172            let then_i = &arg_dtypes[i * 2 + 1];
173            if !first_then.eq_ignore_nullability(then_i) {
174                vortex_bail!(
175                    "CaseWhen THEN dtypes must match (ignoring nullability), got {} and {}",
176                    first_then,
177                    then_i
178                );
179            }
180            result_dtype = result_dtype.union_nullability(then_i.nullability());
181        }
182
183        if options.has_else {
184            let else_dtype = &arg_dtypes[options.num_when_then_pairs as usize * 2];
185            if !result_dtype.eq_ignore_nullability(else_dtype) {
186                vortex_bail!(
187                    "CaseWhen THEN and ELSE dtypes must match (ignoring nullability), got {} and {}",
188                    first_then,
189                    else_dtype
190                );
191            }
192            result_dtype = result_dtype.union_nullability(else_dtype.nullability());
193        } else {
194            // No ELSE means unmatched rows are NULL
195            result_dtype = result_dtype.as_nullable();
196        }
197
198        Ok(result_dtype)
199    }
200
201    fn execute(
202        &self,
203        options: &Self::Options,
204        args: &dyn ExecutionArgs,
205        ctx: &mut ExecutionCtx,
206    ) -> VortexResult<ArrayRef> {
207        // Inspired by https://datafusion.apache.org/blog/2026/02/02/datafusion_case/
208        //
209        // TODO: shrink input to `remaining` rows between WHEN iterations (batch reduction).
210        // TODO: project to only referenced columns before batch reduction (column projection).
211        // TODO: evaluate THEN/ELSE on compact matching/non-matching rows and scatter-merge the results.
212        // TODO: for constant WHEN/THEN values, compile to a hash table for a single-pass lookup.
213        let row_count = args.row_count();
214        let num_pairs = options.num_when_then_pairs as usize;
215
216        let mut remaining = Mask::new_true(row_count);
217        let mut branches: Vec<(Mask, ArrayRef)> = Vec::with_capacity(num_pairs);
218
219        for i in 0..num_pairs {
220            if remaining.all_false() {
221                break;
222            }
223
224            let condition = args.get(i * 2)?;
225            let cond_bool = condition.execute::<BoolArray>(ctx)?;
226            let cond_mask = cond_bool.to_mask_fill_null_false();
227            let effective_mask = &remaining & &cond_mask;
228
229            if effective_mask.all_false() {
230                continue;
231            }
232
233            let then_value = args.get(i * 2 + 1)?;
234            remaining = remaining.bitand_not(&cond_mask);
235            branches.push((effective_mask, then_value));
236        }
237
238        let else_value: ArrayRef = if options.has_else {
239            args.get(num_pairs * 2)?
240        } else {
241            let then_dtype = args.get(1)?.dtype().as_nullable();
242            ConstantArray::new(Scalar::null(then_dtype), row_count).into_array()
243        };
244
245        if branches.is_empty() {
246            return Ok(else_value);
247        }
248
249        merge_case_branches(branches, else_value)
250    }
251
252    fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
253        true
254    }
255
256    fn is_fallible(&self, _options: &Self::Options) -> bool {
257        false
258    }
259}
260
/// Crossover point between the two merge strategies, expressed as an average run length:
/// at or above it, slicing + `extend_from_array` is cheaper than per-row `scalar_at`.
/// The value was determined empirically via benchmarks.
const SLICE_CROSSOVER_RUN_LEN: usize = 4;
264
265/// Merges disjoint `(mask, then_value)` branch pairs with an `else_value` into a single array.
266///
267/// Branch masks are guaranteed disjoint by the remaining-row tracking in [`CaseWhen::execute`].
268fn merge_case_branches(
269    branches: Vec<(Mask, ArrayRef)>,
270    else_value: ArrayRef,
271) -> VortexResult<ArrayRef> {
272    if branches.len() == 1 {
273        let (mask, then_value) = &branches[0];
274        return zip_impl(then_value, &else_value, mask);
275    }
276
277    let output_nullability = branches
278        .iter()
279        .fold(else_value.dtype().nullability(), |acc, (_, arr)| {
280            acc | arr.dtype().nullability()
281        });
282    let output_dtype = else_value.dtype().with_nullability(output_nullability);
283    let branch_arrays: Vec<&ArrayRef> = branches.iter().map(|(_, arr)| arr).collect();
284
285    let mut spans: Vec<(usize, usize, usize)> = Vec::new();
286    for (branch_idx, (mask, _)) in branches.iter().enumerate() {
287        match mask.slices() {
288            AllOr::All => return branch_arrays[branch_idx].cast(output_dtype),
289            AllOr::None => {}
290            AllOr::Some(slices) => {
291                for &(start, end) in slices {
292                    spans.push((start, end, branch_idx));
293                }
294            }
295        }
296    }
297    spans.sort_unstable_by_key(|&(start, ..)| start);
298
299    if spans.is_empty() {
300        return else_value.cast(output_dtype);
301    }
302
303    let builder = builder_with_capacity(&output_dtype, else_value.len());
304
305    let fragmented = spans.len() > else_value.len() / SLICE_CROSSOVER_RUN_LEN;
306    if fragmented {
307        merge_row_by_row(&branch_arrays, &else_value, &spans, &output_dtype, builder)
308    } else {
309        merge_run_by_run(&branch_arrays, &else_value, &spans, &output_dtype, builder)
310    }
311}
312
313/// Iterates spans directly, emitting one `scalar_at` per row.
314/// Zero per-run allocations; preferred for fragmented masks (avg run < [`SLICE_CROSSOVER_RUN_LEN`]).
315fn merge_row_by_row(
316    branch_arrays: &[&ArrayRef],
317    else_value: &ArrayRef,
318    spans: &[(usize, usize, usize)],
319    output_dtype: &DType,
320    mut builder: Box<dyn ArrayBuilder>,
321) -> VortexResult<ArrayRef> {
322    let mut pos = 0;
323    for &(start, end, branch_idx) in spans {
324        for row in pos..start {
325            let scalar = else_value.scalar_at(row)?;
326            builder.append_scalar(&scalar.cast(output_dtype)?)?;
327        }
328        for row in start..end {
329            let scalar = branch_arrays[branch_idx].scalar_at(row)?;
330            builder.append_scalar(&scalar.cast(output_dtype)?)?;
331        }
332        pos = end;
333    }
334    for row in pos..else_value.len() {
335        let scalar = else_value.scalar_at(row)?;
336        builder.append_scalar(&scalar.cast(output_dtype)?)?;
337    }
338
339    Ok(builder.finish())
340}
341
342/// Bulk-copies each span via `slice()` + `extend_from_array`.
343/// Preferred when runs are long enough that memcpy dominates over per-slice allocation cost.
344/// Lazy cast via `arr.cast(output_dtype)` is executed once per span as a block.
345fn merge_run_by_run(
346    branch_arrays: &[&ArrayRef],
347    else_value: &ArrayRef,
348    spans: &[(usize, usize, usize)],
349    output_dtype: &DType,
350    mut builder: Box<dyn ArrayBuilder>,
351) -> VortexResult<ArrayRef> {
352    let else_value = else_value.cast(output_dtype.clone())?;
353    let len = else_value.len();
354    for (start, end, branch_idx) in spans {
355        if builder.len() < *start {
356            builder.extend_from_array(&else_value.slice(builder.len()..*start)?);
357        }
358        builder.extend_from_array(
359            &branch_arrays[*branch_idx]
360                .cast(output_dtype.clone())?
361                .slice(*start..*end)?,
362        );
363    }
364    if builder.len() < len {
365        builder.extend_from_array(&else_value.slice(builder.len()..len)?);
366    }
367
368    Ok(builder.finish())
369}
370
371#[cfg(test)]
372mod tests {
373    use std::sync::LazyLock;
374
375    use vortex_buffer::buffer;
376    use vortex_error::VortexExpect as _;
377    use vortex_session::VortexSession;
378
379    use super::*;
380    use crate::Canonical;
381    use crate::IntoArray;
382    use crate::VortexSessionExecute as _;
383    use crate::arrays::BoolArray;
384    use crate::arrays::PrimitiveArray;
385    use crate::arrays::StructArray;
386    use crate::assert_arrays_eq;
387    use crate::dtype::DType;
388    use crate::dtype::Nullability;
389    use crate::dtype::PType;
390    use crate::expr::case_when;
391    use crate::expr::case_when_no_else;
392    use crate::expr::col;
393    use crate::expr::eq;
394    use crate::expr::get_item;
395    use crate::expr::gt;
396    use crate::expr::lit;
397    use crate::expr::nested_case_when;
398    use crate::expr::root;
399    use crate::expr::test_harness;
400    use crate::scalar::Scalar;
401    use crate::session::ArraySession;
402
403    static SESSION: LazyLock<VortexSession> =
404        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());
405
406    /// Helper to evaluate an expression using the apply+execute pattern
407    fn evaluate_expr(expr: &Expression, array: &ArrayRef) -> ArrayRef {
408        let mut ctx = SESSION.create_execution_ctx();
409        array
410            .clone()
411            .apply(expr)
412            .unwrap()
413            .execute::<Canonical>(&mut ctx)
414            .unwrap()
415            .into_array()
416    }
417
418    // ==================== Serialization Tests ====================
419
420    #[test]
421    #[should_panic(expected = "cannot serialize")]
422    fn test_serialization_roundtrip() {
423        let options = CaseWhenOptions {
424            num_when_then_pairs: 1,
425            has_else: true,
426        };
427        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
428        let deserialized = CaseWhen
429            .deserialize(&serialized, &VortexSession::empty())
430            .unwrap();
431        assert_eq!(options, deserialized);
432    }
433
434    #[test]
435    #[should_panic(expected = "cannot serialize")]
436    fn test_serialization_no_else() {
437        let options = CaseWhenOptions {
438            num_when_then_pairs: 1,
439            has_else: false,
440        };
441        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
442        let deserialized = CaseWhen
443            .deserialize(&serialized, &VortexSession::empty())
444            .unwrap();
445        assert_eq!(options, deserialized);
446    }
447
448    // ==================== Display Tests ====================
449
450    #[test]
451    fn test_display_with_else() {
452        let expr = case_when(gt(col("value"), lit(0i32)), lit(100i32), lit(0i32));
453        let display = format!("{}", expr);
454        assert!(display.contains("CASE"));
455        assert!(display.contains("WHEN"));
456        assert!(display.contains("THEN"));
457        assert!(display.contains("ELSE"));
458        assert!(display.contains("END"));
459    }
460
461    #[test]
462    fn test_display_no_else() {
463        let expr = case_when_no_else(gt(col("value"), lit(0i32)), lit(100i32));
464        let display = format!("{}", expr);
465        assert!(display.contains("CASE"));
466        assert!(display.contains("WHEN"));
467        assert!(display.contains("THEN"));
468        assert!(!display.contains("ELSE"));
469        assert!(display.contains("END"));
470    }
471
472    #[test]
473    fn test_display_nested_nary() {
474        // CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'medium' ELSE 'low' END
475        let expr = nested_case_when(
476            vec![
477                (gt(col("x"), lit(10i32)), lit("high")),
478                (gt(col("x"), lit(5i32)), lit("medium")),
479            ],
480            Some(lit("low")),
481        );
482        let display = format!("{}", expr);
483        assert_eq!(display.matches("CASE").count(), 1);
484        assert_eq!(display.matches("WHEN").count(), 2);
485        assert_eq!(display.matches("THEN").count(), 2);
486    }
487
488    // ==================== DType Tests ====================
489
490    #[test]
491    fn test_return_dtype_with_else() {
492        let expr = case_when(lit(true), lit(100i32), lit(0i32));
493        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
494        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
495        assert_eq!(
496            result_dtype,
497            DType::Primitive(PType::I32, Nullability::NonNullable)
498        );
499    }
500
501    #[test]
502    fn test_return_dtype_with_nullable_else() {
503        let expr = case_when(
504            lit(true),
505            lit(100i32),
506            lit(Scalar::null(DType::Primitive(
507                PType::I32,
508                Nullability::Nullable,
509            ))),
510        );
511        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
512        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
513        assert_eq!(
514            result_dtype,
515            DType::Primitive(PType::I32, Nullability::Nullable)
516        );
517    }
518
519    #[test]
520    fn test_return_dtype_without_else_is_nullable() {
521        let expr = case_when_no_else(lit(true), lit(100i32));
522        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
523        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
524        assert_eq!(
525            result_dtype,
526            DType::Primitive(PType::I32, Nullability::Nullable)
527        );
528    }
529
530    #[test]
531    fn test_return_dtype_with_struct_input() {
532        let dtype = test_harness::struct_dtype();
533        let expr = case_when(
534            gt(get_item("col1", root()), lit(10u16)),
535            lit(100i32),
536            lit(0i32),
537        );
538        let result_dtype = expr.return_dtype(&dtype).unwrap();
539        assert_eq!(
540            result_dtype,
541            DType::Primitive(PType::I32, Nullability::NonNullable)
542        );
543    }
544
545    #[test]
546    fn test_return_dtype_mismatched_then_else_errors() {
547        let expr = case_when(lit(true), lit(100i32), lit("zero"));
548        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
549        let err = expr.return_dtype(&input_dtype).unwrap_err();
550        assert!(
551            err.to_string()
552                .contains("THEN and ELSE dtypes must match (ignoring nullability)")
553        );
554    }
555
556    // ==================== Arity Tests ====================
557
558    #[test]
559    fn test_arity_with_else() {
560        let options = CaseWhenOptions {
561            num_when_then_pairs: 1,
562            has_else: true,
563        };
564        assert_eq!(CaseWhen.arity(&options), Arity::Exact(3));
565    }
566
567    #[test]
568    fn test_arity_without_else() {
569        let options = CaseWhenOptions {
570            num_when_then_pairs: 1,
571            has_else: false,
572        };
573        assert_eq!(CaseWhen.arity(&options), Arity::Exact(2));
574    }
575
576    // ==================== Child Name Tests ====================
577
578    #[test]
579    fn test_child_names() {
580        let options = CaseWhenOptions {
581            num_when_then_pairs: 1,
582            has_else: true,
583        };
584        assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0");
585        assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0");
586        assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "else");
587    }
588
589    // ==================== N-ary Serialization Tests ====================
590
591    #[test]
592    #[should_panic(expected = "cannot serialize")]
593    fn test_serialization_roundtrip_nary() {
594        let options = CaseWhenOptions {
595            num_when_then_pairs: 3,
596            has_else: true,
597        };
598        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
599        let deserialized = CaseWhen
600            .deserialize(&serialized, &VortexSession::empty())
601            .unwrap();
602        assert_eq!(options, deserialized);
603    }
604
605    #[test]
606    #[should_panic(expected = "cannot serialize")]
607    fn test_serialization_roundtrip_nary_no_else() {
608        let options = CaseWhenOptions {
609            num_when_then_pairs: 4,
610            has_else: false,
611        };
612        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
613        let deserialized = CaseWhen
614            .deserialize(&serialized, &VortexSession::empty())
615            .unwrap();
616        assert_eq!(options, deserialized);
617    }
618
619    // ==================== N-ary Arity Tests ====================
620
621    #[test]
622    fn test_arity_nary_with_else() {
623        let options = CaseWhenOptions {
624            num_when_then_pairs: 3,
625            has_else: true,
626        };
627        // 3 pairs * 2 children + 1 else = 7
628        assert_eq!(CaseWhen.arity(&options), Arity::Exact(7));
629    }
630
631    #[test]
632    fn test_arity_nary_without_else() {
633        let options = CaseWhenOptions {
634            num_when_then_pairs: 3,
635            has_else: false,
636        };
637        // 3 pairs * 2 children = 6
638        assert_eq!(CaseWhen.arity(&options), Arity::Exact(6));
639    }
640
641    // ==================== N-ary Child Name Tests ====================
642
643    #[test]
644    fn test_child_names_nary() {
645        let options = CaseWhenOptions {
646            num_when_then_pairs: 3,
647            has_else: true,
648        };
649        assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0");
650        assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0");
651        assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "when_1");
652        assert_eq!(CaseWhen.child_name(&options, 3).to_string(), "then_1");
653        assert_eq!(CaseWhen.child_name(&options, 4).to_string(), "when_2");
654        assert_eq!(CaseWhen.child_name(&options, 5).to_string(), "then_2");
655        assert_eq!(CaseWhen.child_name(&options, 6).to_string(), "else");
656    }
657
658    // ==================== N-ary DType Tests ====================
659
660    #[test]
661    fn test_return_dtype_nary_mismatched_then_types_errors() {
662        let expr = nested_case_when(
663            vec![(lit(true), lit(100i32)), (lit(false), lit("oops"))],
664            Some(lit(0i32)),
665        );
666        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
667        let err = expr.return_dtype(&input_dtype).unwrap_err();
668        assert!(err.to_string().contains("THEN dtypes must match"));
669    }
670
671    #[test]
672    fn test_return_dtype_nary_mixed_nullability() {
673        // When some THEN branches are nullable and others are not,
674        // the result should be nullable (union of nullabilities).
675        let non_null_then = lit(100i32);
676        let nullable_then = lit(Scalar::null(DType::Primitive(
677            PType::I32,
678            Nullability::Nullable,
679        )));
680        let expr = nested_case_when(
681            vec![(lit(true), non_null_then), (lit(false), nullable_then)],
682            Some(lit(0i32)),
683        );
684        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
685        let result = expr.return_dtype(&input_dtype).unwrap();
686        assert_eq!(result, DType::Primitive(PType::I32, Nullability::Nullable));
687    }
688
689    #[test]
690    fn test_return_dtype_nary_no_else_is_nullable() {
691        let expr = nested_case_when(
692            vec![(lit(true), lit(10i32)), (lit(false), lit(20i32))],
693            None,
694        );
695        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
696        let result = expr.return_dtype(&input_dtype).unwrap();
697        assert_eq!(result, DType::Primitive(PType::I32, Nullability::Nullable));
698    }
699
700    // ==================== Expression Manipulation Tests ====================
701
702    #[test]
703    fn test_replace_children() {
704        let expr = case_when(lit(true), lit(1i32), lit(0i32));
705        expr.with_children([lit(false), lit(2i32), lit(3i32)])
706            .vortex_expect("operation should succeed in test");
707    }
708
709    // ==================== Evaluate Tests ====================
710
711    #[test]
712    fn test_evaluate_simple_condition() {
713        let test_array =
714            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
715                .unwrap()
716                .into_array();
717
718        let expr = case_when(
719            gt(get_item("value", root()), lit(2i32)),
720            lit(100i32),
721            lit(0i32),
722        );
723
724        let result = evaluate_expr(&expr, &test_array);
725        assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array());
726    }
727
728    #[test]
729    fn test_evaluate_nary_multiple_conditions() {
730        // Test n-ary via nested_case_when
731        let test_array =
732            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
733                .unwrap()
734                .into_array();
735
736        let expr = nested_case_when(
737            vec![
738                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
739                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
740            ],
741            Some(lit(0i32)),
742        );
743
744        let result = evaluate_expr(&expr, &test_array);
745        assert_arrays_eq!(result, buffer![10i32, 0, 30, 0, 0].into_array());
746    }
747
748    #[test]
749    fn test_evaluate_nary_first_match_wins() {
750        let test_array =
751            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
752                .unwrap()
753                .into_array();
754
755        // Both conditions match for values > 3, but first one wins
756        let expr = nested_case_when(
757            vec![
758                (gt(get_item("value", root()), lit(2i32)), lit(100i32)),
759                (gt(get_item("value", root()), lit(3i32)), lit(200i32)),
760            ],
761            Some(lit(0i32)),
762        );
763
764        let result = evaluate_expr(&expr, &test_array);
765        assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array());
766    }
767
768    #[test]
769    fn test_evaluate_no_else_returns_null() {
770        let test_array =
771            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
772                .unwrap()
773                .into_array();
774
775        let expr = case_when_no_else(gt(get_item("value", root()), lit(3i32)), lit(100i32));
776
777        let result = evaluate_expr(&expr, &test_array);
778        assert!(result.dtype().is_nullable());
779        assert_arrays_eq!(
780            result,
781            PrimitiveArray::from_option_iter([None::<i32>, None, None, Some(100), Some(100)])
782                .into_array()
783        );
784    }
785
786    #[test]
787    fn test_evaluate_all_conditions_false() {
788        let test_array =
789            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
790                .unwrap()
791                .into_array();
792
793        let expr = case_when(
794            gt(get_item("value", root()), lit(100i32)),
795            lit(1i32),
796            lit(0i32),
797        );
798
799        let result = evaluate_expr(&expr, &test_array);
800        assert_arrays_eq!(result, buffer![0i32, 0, 0, 0, 0].into_array());
801    }
802
803    #[test]
804    fn test_evaluate_all_conditions_true() {
805        let test_array =
806            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
807                .unwrap()
808                .into_array();
809
810        let expr = case_when(
811            gt(get_item("value", root()), lit(0i32)),
812            lit(100i32),
813            lit(0i32),
814        );
815
816        let result = evaluate_expr(&expr, &test_array);
817        assert_arrays_eq!(result, buffer![100i32, 100, 100, 100, 100].into_array());
818    }
819
820    #[test]
821    fn test_evaluate_all_true_no_else_returns_correct_dtype() {
822        // CASE WHEN value > 0 THEN 100 END — condition is always true, no ELSE.
823        // Result must be Nullable because the implicit ELSE is NULL.
824        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
825            .unwrap()
826            .into_array();
827
828        let expr = case_when_no_else(gt(get_item("value", root()), lit(0i32)), lit(100i32));
829
830        let result = evaluate_expr(&expr, &test_array);
831        assert!(
832            result.dtype().is_nullable(),
833            "result dtype must be Nullable, got {:?}",
834            result.dtype()
835        );
836        assert_arrays_eq!(
837            result,
838            PrimitiveArray::from_option_iter([Some(100i32), Some(100), Some(100)]).into_array()
839        );
840    }
841
842    #[test]
843    fn test_merge_case_branches_widens_nullability_of_later_branch() -> VortexResult<()> {
844        // When a later THEN branch is Nullable and branches[0] and ELSE are NonNullable,
845        // the result dtype must still be Nullable.
846        //
847        // CASE WHEN value = 0 THEN 10          -- NonNullable
848        //      WHEN value = 1 THEN nullable(20) -- Nullable
849        //      ELSE 0                           -- NonNullable
850        // → result must be Nullable(i32)
851        let test_array = StructArray::from_fields(&[("value", buffer![0i32, 1, 2].into_array())])
852            .unwrap()
853            .into_array();
854
855        let nullable_20 =
856            Scalar::from(20i32).cast(&DType::Primitive(PType::I32, Nullability::Nullable))?;
857
858        let expr = nested_case_when(
859            vec![
860                (eq(get_item("value", root()), lit(0i32)), lit(10i32)),
861                (eq(get_item("value", root()), lit(1i32)), lit(nullable_20)),
862            ],
863            Some(lit(0i32)),
864        );
865
866        let result = evaluate_expr(&expr, &test_array);
867        assert!(
868            result.dtype().is_nullable(),
869            "result dtype must be Nullable, got {:?}",
870            result.dtype()
871        );
872        assert_arrays_eq!(
873            result,
874            PrimitiveArray::from_option_iter([Some(10), Some(20), Some(0)]).into_array()
875        );
876        Ok(())
877    }
878
879    #[test]
880    fn test_evaluate_with_literal_condition() {
881        let test_array = buffer![1i32, 2, 3].into_array();
882        let expr = case_when(lit(true), lit(100i32), lit(0i32));
883        let result = evaluate_expr(&expr, &test_array);
884
885        assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array());
886    }
887
888    #[test]
889    fn test_evaluate_with_bool_column_result() {
890        let test_array =
891            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
892                .unwrap()
893                .into_array();
894
895        let expr = case_when(
896            gt(get_item("value", root()), lit(2i32)),
897            lit(true),
898            lit(false),
899        );
900
901        let result = evaluate_expr(&expr, &test_array);
902        assert_arrays_eq!(
903            result,
904            BoolArray::from_iter([false, false, true, true, true]).into_array()
905        );
906    }
907
908    #[test]
909    fn test_evaluate_with_nullable_condition() {
910        let test_array = StructArray::from_fields(&[(
911            "cond",
912            BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)]).into_array(),
913        )])
914        .unwrap()
915        .into_array();
916
917        let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));
918
919        let result = evaluate_expr(&expr, &test_array);
920        assert_arrays_eq!(result, buffer![100i32, 0, 0, 0, 100].into_array());
921    }
922
923    #[test]
924    fn test_evaluate_with_nullable_result_values() {
925        let test_array = StructArray::from_fields(&[
926            ("value", buffer![1i32, 2, 3, 4, 5].into_array()),
927            (
928                "result",
929                PrimitiveArray::from_option_iter([Some(10), None, Some(30), Some(40), Some(50)])
930                    .into_array(),
931            ),
932        ])
933        .unwrap()
934        .into_array();
935
936        let expr = case_when(
937            gt(get_item("value", root()), lit(2i32)),
938            get_item("result", root()),
939            lit(0i32),
940        );
941
942        let result = evaluate_expr(&expr, &test_array);
943        assert_arrays_eq!(
944            result,
945            PrimitiveArray::from_option_iter([Some(0i32), Some(0), Some(30), Some(40), Some(50)])
946                .into_array()
947        );
948    }
949
950    #[test]
951    fn test_evaluate_with_all_null_condition() {
952        let test_array = StructArray::from_fields(&[(
953            "cond",
954            BoolArray::from_iter([None, None, None]).into_array(),
955        )])
956        .unwrap()
957        .into_array();
958
959        let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));
960
961        let result = evaluate_expr(&expr, &test_array);
962        assert_arrays_eq!(result, buffer![0i32, 0, 0].into_array());
963    }
964
965    // ==================== N-ary Evaluate Tests ====================
966
967    #[test]
968    fn test_evaluate_nary_no_else_returns_null() {
969        let test_array =
970            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
971                .unwrap()
972                .into_array();
973
974        // Two conditions, no ELSE — unmatched rows should be NULL
975        let expr = nested_case_when(
976            vec![
977                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
978                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
979            ],
980            None,
981        );
982
983        let result = evaluate_expr(&expr, &test_array);
984        assert!(result.dtype().is_nullable());
985        assert_arrays_eq!(
986            result,
987            PrimitiveArray::from_option_iter([Some(10i32), None, Some(30), None, None])
988                .into_array()
989        );
990    }
991
992    #[test]
993    fn test_evaluate_nary_many_conditions() {
994        let test_array =
995            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
996                .unwrap()
997                .into_array();
998
999        // 5 WHEN/THEN pairs: each value maps to its value * 10
1000        let expr = nested_case_when(
1001            vec![
1002                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
1003                (eq(get_item("value", root()), lit(2i32)), lit(20i32)),
1004                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
1005                (eq(get_item("value", root()), lit(4i32)), lit(40i32)),
1006                (eq(get_item("value", root()), lit(5i32)), lit(50i32)),
1007            ],
1008            Some(lit(0i32)),
1009        );
1010
1011        let result = evaluate_expr(&expr, &test_array);
1012        assert_arrays_eq!(result, buffer![10i32, 20, 30, 40, 50].into_array());
1013    }
1014
1015    #[test]
1016    fn test_evaluate_nary_all_false_no_else() {
1017        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
1018            .unwrap()
1019            .into_array();
1020
1021        // All conditions are false, no ELSE — everything should be NULL
1022        let expr = nested_case_when(
1023            vec![
1024                (gt(get_item("value", root()), lit(100i32)), lit(10i32)),
1025                (gt(get_item("value", root()), lit(200i32)), lit(20i32)),
1026            ],
1027            None,
1028        );
1029
1030        let result = evaluate_expr(&expr, &test_array);
1031        assert!(result.dtype().is_nullable());
1032        assert_arrays_eq!(
1033            result,
1034            PrimitiveArray::from_option_iter([None::<i32>, None, None]).into_array()
1035        );
1036    }
1037
1038    #[test]
1039    fn test_evaluate_nary_overlapping_conditions_first_wins() {
1040        let test_array =
1041            StructArray::from_fields(&[("value", buffer![10i32, 20, 30].into_array())])
1042                .unwrap()
1043                .into_array();
1044
1045        // value=10: matches cond1 (>5) and cond2 (>0), first should win
1046        // value=20: matches all three, first should win
1047        // value=30: matches all three, first should win
1048        let expr = nested_case_when(
1049            vec![
1050                (gt(get_item("value", root()), lit(5i32)), lit(1i32)),
1051                (gt(get_item("value", root()), lit(0i32)), lit(2i32)),
1052                (gt(get_item("value", root()), lit(15i32)), lit(3i32)),
1053            ],
1054            Some(lit(0i32)),
1055        );
1056
1057        let result = evaluate_expr(&expr, &test_array);
1058        // First matching condition always wins
1059        assert_arrays_eq!(result, buffer![1i32, 1, 1].into_array());
1060    }
1061
1062    #[test]
1063    fn test_evaluate_nary_early_exit_when_remaining_empty() {
1064        // After branch 0 claims all rows, remaining becomes all_false.
1065        // The loop breaks before evaluating branch 1's condition.
1066        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
1067            .unwrap()
1068            .into_array();
1069
1070        let expr = nested_case_when(
1071            vec![
1072                (gt(get_item("value", root()), lit(0i32)), lit(100i32)),
1073                // Never evaluated due to early exit; 999 must never appear in output.
1074                (gt(get_item("value", root()), lit(0i32)), lit(999i32)),
1075            ],
1076            Some(lit(0i32)),
1077        );
1078
1079        let result = evaluate_expr(&expr, &test_array);
1080        assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array());
1081    }
1082
1083    #[test]
1084    fn test_evaluate_nary_skips_branch_with_empty_effective_mask() {
1085        // Branch 0 claims value=1. Branch 1 targets the same rows but they are already
1086        // matched → effective_mask is all_false → branch 1 is skipped (THEN not used).
1087        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
1088            .unwrap()
1089            .into_array();
1090
1091        let expr = nested_case_when(
1092            vec![
1093                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
1094                // Same condition as branch 0 — all matching rows already claimed → skipped.
1095                // 999 must never appear in output.
1096                (eq(get_item("value", root()), lit(1i32)), lit(999i32)),
1097                (eq(get_item("value", root()), lit(2i32)), lit(20i32)),
1098            ],
1099            Some(lit(0i32)),
1100        );
1101
1102        let result = evaluate_expr(&expr, &test_array);
1103        assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array());
1104    }
1105
1106    #[test]
1107    fn test_evaluate_nary_string_output() -> VortexResult<()> {
1108        // Exercises merge_case_branches with a non-primitive (Utf8) builder.
1109        let test_array =
1110            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4].into_array())])
1111                .unwrap()
1112                .into_array();
1113
1114        // CASE WHEN value > 2 THEN 'high' WHEN value > 0 THEN 'low' ELSE 'none' END
1115        // value=1,2 → 'low' (branch 1 after branch 0 claims 3,4)
1116        // value=3,4 → 'high' (branch 0)
1117        let expr = nested_case_when(
1118            vec![
1119                (gt(get_item("value", root()), lit(2i32)), lit("high")),
1120                (gt(get_item("value", root()), lit(0i32)), lit("low")),
1121            ],
1122            Some(lit("none")),
1123        );
1124
1125        let result = evaluate_expr(&expr, &test_array);
1126        assert_eq!(
1127            result.scalar_at(0)?,
1128            Scalar::utf8("low", Nullability::NonNullable)
1129        );
1130        assert_eq!(
1131            result.scalar_at(1)?,
1132            Scalar::utf8("low", Nullability::NonNullable)
1133        );
1134        assert_eq!(
1135            result.scalar_at(2)?,
1136            Scalar::utf8("high", Nullability::NonNullable)
1137        );
1138        assert_eq!(
1139            result.scalar_at(3)?,
1140            Scalar::utf8("high", Nullability::NonNullable)
1141        );
1142        Ok(())
1143    }
1144
1145    #[test]
1146    fn test_evaluate_nary_with_nullable_conditions() {
1147        let test_array = StructArray::from_fields(&[
1148            (
1149                "cond1",
1150                BoolArray::from_iter([Some(true), None, Some(false)]).into_array(),
1151            ),
1152            (
1153                "cond2",
1154                BoolArray::from_iter([Some(false), Some(true), None]).into_array(),
1155            ),
1156        ])
1157        .unwrap()
1158        .into_array();
1159
1160        let expr = nested_case_when(
1161            vec![
1162                (get_item("cond1", root()), lit(10i32)),
1163                (get_item("cond2", root()), lit(20i32)),
1164            ],
1165            Some(lit(0i32)),
1166        );
1167
1168        let result = evaluate_expr(&expr, &test_array);
1169        // row 0: cond1=true → 10
1170        // row 1: cond1=NULL(→false), cond2=true → 20
1171        // row 2: cond1=false, cond2=NULL(→false) → else=0
1172        assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array());
1173    }
1174
1175    #[test]
1176    fn test_merge_case_branches_alternating_mask() -> VortexResult<()> {
1177        // Exercises the scalar path: alternating rows produce one slice per row (no runs),
1178        // triggering the per-row cursor path in merge_case_branches.
1179        let n = 100usize;
1180
1181        // Branch 0: even rows → 0, Branch 1: odd rows → 1, Else: never reached.
1182        let branch0_mask = Mask::from_indices(n, (0..n).step_by(2).collect());
1183        let branch1_mask = Mask::from_indices(n, (1..n).step_by(2).collect());
1184
1185        let result = merge_case_branches(
1186            vec![
1187                (
1188                    branch0_mask,
1189                    PrimitiveArray::from_option_iter(vec![Some(0i32); n]).into_array(),
1190                ),
1191                (
1192                    branch1_mask,
1193                    PrimitiveArray::from_option_iter(vec![Some(1i32); n]).into_array(),
1194                ),
1195            ],
1196            PrimitiveArray::from_option_iter(vec![Some(99i32); n]).into_array(),
1197        )?;
1198
1199        // Even rows → 0, odd rows → 1.
1200        let expected: Vec<Option<i32>> = (0..n)
1201            .map(|v| if v % 2 == 0 { Some(0) } else { Some(1) })
1202            .collect();
1203        assert_arrays_eq!(
1204            result,
1205            PrimitiveArray::from_option_iter(expected).into_array()
1206        );
1207        Ok(())
1208    }
1209}