Skip to main content

vortex_array/scalar_fn/fns/
case_when.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! SQL-style CASE WHEN: evaluates `(condition, value)` pairs in order and returns
5//! the value from the first matching condition (first-match-wins). NULL conditions
6//! are treated as false. If no ELSE clause is provided, unmatched rows produce NULL;
7//! otherwise they get the ELSE value.
8//!
9//! Unlike SQL which coerces all branches to a common supertype, all THEN/ELSE
10//! branches must share the same base dtype (ignoring nullability). The result
11//! nullability is the union of all branches (forced nullable if no ELSE).
12
13use std::fmt;
14use std::fmt::Formatter;
15use std::hash::Hash;
16use std::sync::Arc;
17
18use prost::Message;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21use vortex_mask::AllOr;
22use vortex_mask::Mask;
23use vortex_proto::expr as pb;
24use vortex_session::VortexSession;
25
26use crate::ArrayRef;
27use crate::ExecutionCtx;
28use crate::IntoArray;
29use crate::arrays::BoolArray;
30use crate::arrays::ConstantArray;
31use crate::arrays::bool::BoolArrayExt;
32use crate::builders::ArrayBuilder;
33use crate::builders::builder_with_capacity;
34use crate::builtins::ArrayBuiltins;
35use crate::dtype::DType;
36use crate::expr::Expression;
37use crate::scalar::Scalar;
38use crate::scalar_fn::Arity;
39use crate::scalar_fn::ChildName;
40use crate::scalar_fn::ExecutionArgs;
41use crate::scalar_fn::ScalarFnId;
42use crate::scalar_fn::ScalarFnVTable;
43use crate::scalar_fn::fns::zip::zip_impl;
44
/// Shape of an n-ary CASE WHEN expression.
///
/// The options describe only how many children the expression has and how they
/// are grouped; the children themselves live on the expression.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct CaseWhenOptions {
    /// How many `(WHEN, THEN)` pairs the expression carries.
    pub num_when_then_pairs: u32,
    /// `true` when a trailing ELSE child is present.
    /// Without one, rows matched by no WHEN evaluate to NULL.
    pub has_else: bool,
}
54
55impl CaseWhenOptions {
56    /// Total number of child expressions: 2 per WHEN/THEN pair, plus 1 if ELSE is present.
57    pub fn num_children(&self) -> usize {
58        self.num_when_then_pairs as usize * 2 + usize::from(self.has_else)
59    }
60}
61
62impl fmt::Display for CaseWhenOptions {
63    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
64        write!(
65            f,
66            "case_when(pairs={}, else={})",
67            self.num_when_then_pairs, self.has_else
68        )
69    }
70}
71
/// The n-ary CASE WHEN scalar function.
///
/// Its children are laid out as `[when_0, then_0, when_1, then_1, ..., else?]`:
/// WHEN/THEN pairs in evaluation order, followed by the optional ELSE.
#[derive(Clone)]
pub struct CaseWhen;
77
78impl ScalarFnVTable for CaseWhen {
79    type Options = CaseWhenOptions;
80
81    fn id(&self) -> ScalarFnId {
82        ScalarFnId::new("vortex.case_when")
83    }
84
85    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
86        // let num_children = options.num_when_then_pairs * 2 + u32::from(options.has_else);
87        // Ok(Some(pb::CaseWhenOpts { num_children }.encode_to_vec()))
88        // stabilize the expr
89        vortex_bail!("cannot serialize")
90    }
91
92    fn deserialize(
93        &self,
94        metadata: &[u8],
95        _session: &VortexSession,
96    ) -> VortexResult<Self::Options> {
97        let opts = pb::CaseWhenOpts::decode(metadata)?;
98        if opts.num_children < 2 {
99            vortex_bail!(
100                "CaseWhen expects at least 2 children, got {}",
101                opts.num_children
102            );
103        }
104        Ok(CaseWhenOptions {
105            num_when_then_pairs: opts.num_children / 2,
106            has_else: opts.num_children % 2 == 1,
107        })
108    }
109
110    fn arity(&self, options: &Self::Options) -> Arity {
111        Arity::Exact(options.num_children())
112    }
113
114    fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName {
115        let num_pair_children = options.num_when_then_pairs as usize * 2;
116        if child_idx < num_pair_children {
117            let pair_idx = child_idx / 2;
118            if child_idx.is_multiple_of(2) {
119                ChildName::from(Arc::from(format!("when_{pair_idx}")))
120            } else {
121                ChildName::from(Arc::from(format!("then_{pair_idx}")))
122            }
123        } else if options.has_else && child_idx == num_pair_children {
124            ChildName::from("else")
125        } else {
126            unreachable!("Invalid child index {} for CaseWhen", child_idx)
127        }
128    }
129
130    fn fmt_sql(
131        &self,
132        options: &Self::Options,
133        expr: &Expression,
134        f: &mut Formatter<'_>,
135    ) -> fmt::Result {
136        write!(f, "CASE")?;
137        for i in 0..options.num_when_then_pairs as usize {
138            write!(
139                f,
140                " WHEN {} THEN {}",
141                expr.child(i * 2),
142                expr.child(i * 2 + 1)
143            )?;
144        }
145        if options.has_else {
146            let else_idx = options.num_when_then_pairs as usize * 2;
147            write!(f, " ELSE {}", expr.child(else_idx))?;
148        }
149        write!(f, " END")
150    }
151
152    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
153        if options.num_when_then_pairs == 0 {
154            vortex_bail!("CaseWhen must have at least one WHEN/THEN pair");
155        }
156
157        let expected_len = options.num_children();
158        if arg_dtypes.len() != expected_len {
159            vortex_bail!(
160                "CaseWhen expects {expected_len} argument dtypes, got {}",
161                arg_dtypes.len()
162            );
163        }
164
165        // Unlike SQL which coerces all branches to a common supertype, we require
166        // all THEN/ELSE branches to have the same base dtype (ignoring nullability).
167        // The result nullability is the union of all branches.
168        let first_then = &arg_dtypes[1];
169        let mut result_dtype = first_then.clone();
170
171        for i in 1..options.num_when_then_pairs as usize {
172            let then_i = &arg_dtypes[i * 2 + 1];
173            if !first_then.eq_ignore_nullability(then_i) {
174                vortex_bail!(
175                    "CaseWhen THEN dtypes must match (ignoring nullability), got {} and {}",
176                    first_then,
177                    then_i
178                );
179            }
180            result_dtype = result_dtype.union_nullability(then_i.nullability());
181        }
182
183        if options.has_else {
184            let else_dtype = &arg_dtypes[options.num_when_then_pairs as usize * 2];
185            if !result_dtype.eq_ignore_nullability(else_dtype) {
186                vortex_bail!(
187                    "CaseWhen THEN and ELSE dtypes must match (ignoring nullability), got {} and {}",
188                    first_then,
189                    else_dtype
190                );
191            }
192            result_dtype = result_dtype.union_nullability(else_dtype.nullability());
193        } else {
194            // No ELSE means unmatched rows are NULL
195            result_dtype = result_dtype.as_nullable();
196        }
197
198        Ok(result_dtype)
199    }
200
201    fn execute(
202        &self,
203        options: &Self::Options,
204        args: &dyn ExecutionArgs,
205        ctx: &mut ExecutionCtx,
206    ) -> VortexResult<ArrayRef> {
207        // Inspired by https://datafusion.apache.org/blog/2026/02/02/datafusion_case/
208        //
209        // TODO: shrink input to `remaining` rows between WHEN iterations (batch reduction).
210        // TODO: project to only referenced columns before batch reduction (column projection).
211        // TODO: evaluate THEN/ELSE on compact matching/non-matching rows and scatter-merge the results.
212        // TODO: for constant WHEN/THEN values, compile to a hash table for a single-pass lookup.
213        let row_count = args.row_count();
214        let num_pairs = options.num_when_then_pairs as usize;
215
216        let mut remaining = Mask::new_true(row_count);
217        let mut branches: Vec<(Mask, ArrayRef)> = Vec::with_capacity(num_pairs);
218
219        for i in 0..num_pairs {
220            if remaining.all_false() {
221                break;
222            }
223
224            let condition = args.get(i * 2)?;
225            let cond_bool = condition.execute::<BoolArray>(ctx)?;
226            let cond_mask = cond_bool.to_mask_fill_null_false(ctx);
227            let effective_mask = &remaining & &cond_mask;
228
229            if effective_mask.all_false() {
230                continue;
231            }
232
233            let then_value = args.get(i * 2 + 1)?;
234            remaining = remaining.bitand_not(&cond_mask);
235            branches.push((effective_mask, then_value));
236        }
237
238        let else_value: ArrayRef = if options.has_else {
239            args.get(num_pairs * 2)?
240        } else {
241            let then_dtype = args.get(1)?.dtype().as_nullable();
242            ConstantArray::new(Scalar::null(then_dtype), row_count).into_array()
243        };
244
245        if branches.is_empty() {
246            return Ok(else_value);
247        }
248
249        merge_case_branches(branches, else_value, ctx)
250    }
251
252    fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
253        true
254    }
255
256    fn is_fallible(&self, _options: &Self::Options) -> bool {
257        false
258    }
259}
260
/// Crossover point, expressed as an average run length, between the two merge strategies:
/// from this length on, slicing + `extend_from_array` is cheaper than per-row `scalar_at`.
/// The value was measured empirically via benchmarks.
const SLICE_CROSSOVER_RUN_LEN: usize = 4;
264
265/// Merges disjoint `(mask, then_value)` branch pairs with an `else_value` into a single array.
266///
267/// Branch masks are guaranteed disjoint by the remaining-row tracking in [`CaseWhen::execute`].
268fn merge_case_branches(
269    branches: Vec<(Mask, ArrayRef)>,
270    else_value: ArrayRef,
271    ctx: &mut ExecutionCtx,
272) -> VortexResult<ArrayRef> {
273    if branches.len() == 1 {
274        let (mask, then_value) = &branches[0];
275        return zip_impl(then_value, &else_value, mask);
276    }
277
278    let output_nullability = branches
279        .iter()
280        .fold(else_value.dtype().nullability(), |acc, (_, arr)| {
281            acc | arr.dtype().nullability()
282        });
283    let output_dtype = else_value.dtype().with_nullability(output_nullability);
284    let branch_arrays: Vec<&ArrayRef> = branches.iter().map(|(_, arr)| arr).collect();
285
286    let mut spans: Vec<(usize, usize, usize)> = Vec::new();
287    for (branch_idx, (mask, _)) in branches.iter().enumerate() {
288        match mask.slices() {
289            AllOr::All => return branch_arrays[branch_idx].cast(output_dtype),
290            AllOr::None => {}
291            AllOr::Some(slices) => {
292                for &(start, end) in slices {
293                    spans.push((start, end, branch_idx));
294                }
295            }
296        }
297    }
298    spans.sort_unstable_by_key(|&(start, ..)| start);
299
300    if spans.is_empty() {
301        return else_value.cast(output_dtype);
302    }
303
304    let builder = builder_with_capacity(&output_dtype, else_value.len());
305
306    let fragmented = spans.len() > else_value.len() / SLICE_CROSSOVER_RUN_LEN;
307    if fragmented {
308        merge_row_by_row(
309            &branch_arrays,
310            &else_value,
311            &spans,
312            &output_dtype,
313            builder,
314            ctx,
315        )
316    } else {
317        merge_run_by_run(&branch_arrays, &else_value, &spans, &output_dtype, builder)
318    }
319}
320
321/// Iterates spans directly, emitting one `scalar_at` per row.
322/// Zero per-run allocations; preferred for fragmented masks (avg run < [`SLICE_CROSSOVER_RUN_LEN`]).
323fn merge_row_by_row(
324    branch_arrays: &[&ArrayRef],
325    else_value: &ArrayRef,
326    spans: &[(usize, usize, usize)],
327    output_dtype: &DType,
328    mut builder: Box<dyn ArrayBuilder>,
329    ctx: &mut ExecutionCtx,
330) -> VortexResult<ArrayRef> {
331    let mut pos = 0;
332    for &(start, end, branch_idx) in spans {
333        for row in pos..start {
334            let scalar = else_value.execute_scalar(row, ctx)?;
335            builder.append_scalar(&scalar.cast(output_dtype)?)?;
336        }
337        for row in start..end {
338            let scalar = branch_arrays[branch_idx].execute_scalar(row, ctx)?;
339            builder.append_scalar(&scalar.cast(output_dtype)?)?;
340        }
341        pos = end;
342    }
343    for row in pos..else_value.len() {
344        let scalar = else_value.execute_scalar(row, ctx)?;
345        builder.append_scalar(&scalar.cast(output_dtype)?)?;
346    }
347
348    Ok(builder.finish())
349}
350
351/// Bulk-copies each span via `slice()` + `extend_from_array`.
352/// Preferred when runs are long enough that memcpy dominates over per-slice allocation cost.
353/// Lazy cast via `arr.cast(output_dtype)` is executed once per span as a block.
354fn merge_run_by_run(
355    branch_arrays: &[&ArrayRef],
356    else_value: &ArrayRef,
357    spans: &[(usize, usize, usize)],
358    output_dtype: &DType,
359    mut builder: Box<dyn ArrayBuilder>,
360) -> VortexResult<ArrayRef> {
361    let else_value = else_value.cast(output_dtype.clone())?;
362    let len = else_value.len();
363    for (start, end, branch_idx) in spans {
364        if builder.len() < *start {
365            builder.extend_from_array(&else_value.slice(builder.len()..*start)?);
366        }
367        builder.extend_from_array(
368            &branch_arrays[*branch_idx]
369                .cast(output_dtype.clone())?
370                .slice(*start..*end)?,
371        );
372    }
373    if builder.len() < len {
374        builder.extend_from_array(&else_value.slice(builder.len()..len)?);
375    }
376
377    Ok(builder.finish())
378}
379
380#[cfg(test)]
381mod tests {
382    use std::sync::LazyLock;
383
384    use vortex_buffer::buffer;
385    use vortex_error::VortexExpect as _;
386    use vortex_session::VortexSession;
387
388    use super::*;
389    use crate::Canonical;
390    use crate::IntoArray;
391    use crate::LEGACY_SESSION;
392    use crate::VortexSessionExecute;
393    use crate::arrays::BoolArray;
394    use crate::arrays::PrimitiveArray;
395    use crate::arrays::StructArray;
396    use crate::assert_arrays_eq;
397    use crate::dtype::DType;
398    use crate::dtype::Nullability;
399    use crate::dtype::PType;
400    use crate::expr::case_when;
401    use crate::expr::case_when_no_else;
402    use crate::expr::col;
403    use crate::expr::eq;
404    use crate::expr::get_item;
405    use crate::expr::gt;
406    use crate::expr::lit;
407    use crate::expr::nested_case_when;
408    use crate::expr::root;
409    use crate::expr::test_harness;
410    use crate::scalar::Scalar;
411    use crate::session::ArraySession;
412
413    static SESSION: LazyLock<VortexSession> =
414        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());
415
416    /// Helper to evaluate an expression using the apply+execute pattern
417    fn evaluate_expr(expr: &Expression, array: &ArrayRef) -> ArrayRef {
418        let mut ctx = SESSION.create_execution_ctx();
419        array
420            .clone()
421            .apply(expr)
422            .unwrap()
423            .execute::<Canonical>(&mut ctx)
424            .unwrap()
425            .into_array()
426    }
427
428    // ==================== Serialization Tests ====================
429
430    #[test]
431    #[should_panic(expected = "cannot serialize")]
432    fn test_serialization_roundtrip() {
433        let options = CaseWhenOptions {
434            num_when_then_pairs: 1,
435            has_else: true,
436        };
437        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
438        let deserialized = CaseWhen
439            .deserialize(&serialized, &VortexSession::empty())
440            .unwrap();
441        assert_eq!(options, deserialized);
442    }
443
444    #[test]
445    #[should_panic(expected = "cannot serialize")]
446    fn test_serialization_no_else() {
447        let options = CaseWhenOptions {
448            num_when_then_pairs: 1,
449            has_else: false,
450        };
451        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
452        let deserialized = CaseWhen
453            .deserialize(&serialized, &VortexSession::empty())
454            .unwrap();
455        assert_eq!(options, deserialized);
456    }
457
458    // ==================== Display Tests ====================
459
460    #[test]
461    fn test_display_with_else() {
462        let expr = case_when(gt(col("value"), lit(0i32)), lit(100i32), lit(0i32));
463        let display = format!("{}", expr);
464        assert!(display.contains("CASE"));
465        assert!(display.contains("WHEN"));
466        assert!(display.contains("THEN"));
467        assert!(display.contains("ELSE"));
468        assert!(display.contains("END"));
469    }
470
471    #[test]
472    fn test_display_no_else() {
473        let expr = case_when_no_else(gt(col("value"), lit(0i32)), lit(100i32));
474        let display = format!("{}", expr);
475        assert!(display.contains("CASE"));
476        assert!(display.contains("WHEN"));
477        assert!(display.contains("THEN"));
478        assert!(!display.contains("ELSE"));
479        assert!(display.contains("END"));
480    }
481
482    #[test]
483    fn test_display_nested_nary() {
484        // CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'medium' ELSE 'low' END
485        let expr = nested_case_when(
486            vec![
487                (gt(col("x"), lit(10i32)), lit("high")),
488                (gt(col("x"), lit(5i32)), lit("medium")),
489            ],
490            Some(lit("low")),
491        );
492        let display = format!("{}", expr);
493        assert_eq!(display.matches("CASE").count(), 1);
494        assert_eq!(display.matches("WHEN").count(), 2);
495        assert_eq!(display.matches("THEN").count(), 2);
496    }
497
498    // ==================== DType Tests ====================
499
500    #[test]
501    fn test_return_dtype_with_else() {
502        let expr = case_when(lit(true), lit(100i32), lit(0i32));
503        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
504        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
505        assert_eq!(
506            result_dtype,
507            DType::Primitive(PType::I32, Nullability::NonNullable)
508        );
509    }
510
511    #[test]
512    fn test_return_dtype_with_nullable_else() {
513        let expr = case_when(
514            lit(true),
515            lit(100i32),
516            lit(Scalar::null(DType::Primitive(
517                PType::I32,
518                Nullability::Nullable,
519            ))),
520        );
521        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
522        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
523        assert_eq!(
524            result_dtype,
525            DType::Primitive(PType::I32, Nullability::Nullable)
526        );
527    }
528
529    #[test]
530    fn test_return_dtype_without_else_is_nullable() {
531        let expr = case_when_no_else(lit(true), lit(100i32));
532        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
533        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
534        assert_eq!(
535            result_dtype,
536            DType::Primitive(PType::I32, Nullability::Nullable)
537        );
538    }
539
540    #[test]
541    fn test_return_dtype_with_struct_input() {
542        let dtype = test_harness::struct_dtype();
543        let expr = case_when(
544            gt(get_item("col1", root()), lit(10u16)),
545            lit(100i32),
546            lit(0i32),
547        );
548        let result_dtype = expr.return_dtype(&dtype).unwrap();
549        assert_eq!(
550            result_dtype,
551            DType::Primitive(PType::I32, Nullability::NonNullable)
552        );
553    }
554
555    #[test]
556    fn test_return_dtype_mismatched_then_else_errors() {
557        let expr = case_when(lit(true), lit(100i32), lit("zero"));
558        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
559        let err = expr.return_dtype(&input_dtype).unwrap_err();
560        assert!(
561            err.to_string()
562                .contains("THEN and ELSE dtypes must match (ignoring nullability)")
563        );
564    }
565
566    // ==================== Arity Tests ====================
567
568    #[test]
569    fn test_arity_with_else() {
570        let options = CaseWhenOptions {
571            num_when_then_pairs: 1,
572            has_else: true,
573        };
574        assert_eq!(CaseWhen.arity(&options), Arity::Exact(3));
575    }
576
577    #[test]
578    fn test_arity_without_else() {
579        let options = CaseWhenOptions {
580            num_when_then_pairs: 1,
581            has_else: false,
582        };
583        assert_eq!(CaseWhen.arity(&options), Arity::Exact(2));
584    }
585
586    // ==================== Child Name Tests ====================
587
588    #[test]
589    fn test_child_names() {
590        let options = CaseWhenOptions {
591            num_when_then_pairs: 1,
592            has_else: true,
593        };
594        assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0");
595        assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0");
596        assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "else");
597    }
598
599    // ==================== N-ary Serialization Tests ====================
600
601    #[test]
602    #[should_panic(expected = "cannot serialize")]
603    fn test_serialization_roundtrip_nary() {
604        let options = CaseWhenOptions {
605            num_when_then_pairs: 3,
606            has_else: true,
607        };
608        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
609        let deserialized = CaseWhen
610            .deserialize(&serialized, &VortexSession::empty())
611            .unwrap();
612        assert_eq!(options, deserialized);
613    }
614
615    #[test]
616    #[should_panic(expected = "cannot serialize")]
617    fn test_serialization_roundtrip_nary_no_else() {
618        let options = CaseWhenOptions {
619            num_when_then_pairs: 4,
620            has_else: false,
621        };
622        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
623        let deserialized = CaseWhen
624            .deserialize(&serialized, &VortexSession::empty())
625            .unwrap();
626        assert_eq!(options, deserialized);
627    }
628
629    // ==================== N-ary Arity Tests ====================
630
631    #[test]
632    fn test_arity_nary_with_else() {
633        let options = CaseWhenOptions {
634            num_when_then_pairs: 3,
635            has_else: true,
636        };
637        // 3 pairs * 2 children + 1 else = 7
638        assert_eq!(CaseWhen.arity(&options), Arity::Exact(7));
639    }
640
641    #[test]
642    fn test_arity_nary_without_else() {
643        let options = CaseWhenOptions {
644            num_when_then_pairs: 3,
645            has_else: false,
646        };
647        // 3 pairs * 2 children = 6
648        assert_eq!(CaseWhen.arity(&options), Arity::Exact(6));
649    }
650
651    // ==================== N-ary Child Name Tests ====================
652
653    #[test]
654    fn test_child_names_nary() {
655        let options = CaseWhenOptions {
656            num_when_then_pairs: 3,
657            has_else: true,
658        };
659        assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0");
660        assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0");
661        assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "when_1");
662        assert_eq!(CaseWhen.child_name(&options, 3).to_string(), "then_1");
663        assert_eq!(CaseWhen.child_name(&options, 4).to_string(), "when_2");
664        assert_eq!(CaseWhen.child_name(&options, 5).to_string(), "then_2");
665        assert_eq!(CaseWhen.child_name(&options, 6).to_string(), "else");
666    }
667
668    // ==================== N-ary DType Tests ====================
669
670    #[test]
671    fn test_return_dtype_nary_mismatched_then_types_errors() {
672        let expr = nested_case_when(
673            vec![(lit(true), lit(100i32)), (lit(false), lit("oops"))],
674            Some(lit(0i32)),
675        );
676        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
677        let err = expr.return_dtype(&input_dtype).unwrap_err();
678        assert!(err.to_string().contains("THEN dtypes must match"));
679    }
680
681    #[test]
682    fn test_return_dtype_nary_mixed_nullability() {
683        // When some THEN branches are nullable and others are not,
684        // the result should be nullable (union of nullabilities).
685        let non_null_then = lit(100i32);
686        let nullable_then = lit(Scalar::null(DType::Primitive(
687            PType::I32,
688            Nullability::Nullable,
689        )));
690        let expr = nested_case_when(
691            vec![(lit(true), non_null_then), (lit(false), nullable_then)],
692            Some(lit(0i32)),
693        );
694        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
695        let result = expr.return_dtype(&input_dtype).unwrap();
696        assert_eq!(result, DType::Primitive(PType::I32, Nullability::Nullable));
697    }
698
699    #[test]
700    fn test_return_dtype_nary_no_else_is_nullable() {
701        let expr = nested_case_when(
702            vec![(lit(true), lit(10i32)), (lit(false), lit(20i32))],
703            None,
704        );
705        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
706        let result = expr.return_dtype(&input_dtype).unwrap();
707        assert_eq!(result, DType::Primitive(PType::I32, Nullability::Nullable));
708    }
709
710    // ==================== Expression Manipulation Tests ====================
711
712    #[test]
713    fn test_replace_children() {
714        let expr = case_when(lit(true), lit(1i32), lit(0i32));
715        expr.with_children([lit(false), lit(2i32), lit(3i32)])
716            .vortex_expect("operation should succeed in test");
717    }
718
719    // ==================== Evaluate Tests ====================
720
721    #[test]
722    fn test_evaluate_simple_condition() {
723        let test_array =
724            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
725                .unwrap()
726                .into_array();
727
728        let expr = case_when(
729            gt(get_item("value", root()), lit(2i32)),
730            lit(100i32),
731            lit(0i32),
732        );
733
734        let result = evaluate_expr(&expr, &test_array);
735        assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array());
736    }
737
738    #[test]
739    fn test_evaluate_nary_multiple_conditions() {
740        // Test n-ary via nested_case_when
741        let test_array =
742            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
743                .unwrap()
744                .into_array();
745
746        let expr = nested_case_when(
747            vec![
748                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
749                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
750            ],
751            Some(lit(0i32)),
752        );
753
754        let result = evaluate_expr(&expr, &test_array);
755        assert_arrays_eq!(result, buffer![10i32, 0, 30, 0, 0].into_array());
756    }
757
758    #[test]
759    fn test_evaluate_nary_first_match_wins() {
760        let test_array =
761            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
762                .unwrap()
763                .into_array();
764
765        // Both conditions match for values > 3, but first one wins
766        let expr = nested_case_when(
767            vec![
768                (gt(get_item("value", root()), lit(2i32)), lit(100i32)),
769                (gt(get_item("value", root()), lit(3i32)), lit(200i32)),
770            ],
771            Some(lit(0i32)),
772        );
773
774        let result = evaluate_expr(&expr, &test_array);
775        assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array());
776    }
777
778    #[test]
779    fn test_evaluate_no_else_returns_null() {
780        let test_array =
781            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
782                .unwrap()
783                .into_array();
784
785        let expr = case_when_no_else(gt(get_item("value", root()), lit(3i32)), lit(100i32));
786
787        let result = evaluate_expr(&expr, &test_array);
788        assert!(result.dtype().is_nullable());
789        assert_arrays_eq!(
790            result,
791            PrimitiveArray::from_option_iter([None::<i32>, None, None, Some(100), Some(100)])
792                .into_array()
793        );
794    }
795
796    #[test]
797    fn test_evaluate_all_conditions_false() {
798        let test_array =
799            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
800                .unwrap()
801                .into_array();
802
803        let expr = case_when(
804            gt(get_item("value", root()), lit(100i32)),
805            lit(1i32),
806            lit(0i32),
807        );
808
809        let result = evaluate_expr(&expr, &test_array);
810        assert_arrays_eq!(result, buffer![0i32, 0, 0, 0, 0].into_array());
811    }
812
813    #[test]
814    fn test_evaluate_all_conditions_true() {
815        let test_array =
816            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
817                .unwrap()
818                .into_array();
819
820        let expr = case_when(
821            gt(get_item("value", root()), lit(0i32)),
822            lit(100i32),
823            lit(0i32),
824        );
825
826        let result = evaluate_expr(&expr, &test_array);
827        assert_arrays_eq!(result, buffer![100i32, 100, 100, 100, 100].into_array());
828    }
829
830    #[test]
831    fn test_evaluate_all_true_no_else_returns_correct_dtype() {
832        // CASE WHEN value > 0 THEN 100 END — condition is always true, no ELSE.
833        // Result must be Nullable because the implicit ELSE is NULL.
834        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
835            .unwrap()
836            .into_array();
837
838        let expr = case_when_no_else(gt(get_item("value", root()), lit(0i32)), lit(100i32));
839
840        let result = evaluate_expr(&expr, &test_array);
841        assert!(
842            result.dtype().is_nullable(),
843            "result dtype must be Nullable, got {:?}",
844            result.dtype()
845        );
846        assert_arrays_eq!(
847            result,
848            PrimitiveArray::from_option_iter([Some(100i32), Some(100), Some(100)]).into_array()
849        );
850    }
851
852    #[test]
853    fn test_merge_case_branches_widens_nullability_of_later_branch() -> VortexResult<()> {
854        // When a later THEN branch is Nullable and branches[0] and ELSE are NonNullable,
855        // the result dtype must still be Nullable.
856        //
857        // CASE WHEN value = 0 THEN 10          -- NonNullable
858        //      WHEN value = 1 THEN nullable(20) -- Nullable
859        //      ELSE 0                           -- NonNullable
860        // → result must be Nullable(i32)
861        let test_array = StructArray::from_fields(&[("value", buffer![0i32, 1, 2].into_array())])
862            .unwrap()
863            .into_array();
864
865        let nullable_20 =
866            Scalar::from(20i32).cast(&DType::Primitive(PType::I32, Nullability::Nullable))?;
867
868        let expr = nested_case_when(
869            vec![
870                (eq(get_item("value", root()), lit(0i32)), lit(10i32)),
871                (eq(get_item("value", root()), lit(1i32)), lit(nullable_20)),
872            ],
873            Some(lit(0i32)),
874        );
875
876        let result = evaluate_expr(&expr, &test_array);
877        assert!(
878            result.dtype().is_nullable(),
879            "result dtype must be Nullable, got {:?}",
880            result.dtype()
881        );
882        assert_arrays_eq!(
883            result,
884            PrimitiveArray::from_option_iter([Some(10), Some(20), Some(0)]).into_array()
885        );
886        Ok(())
887    }
888
889    #[test]
890    fn test_evaluate_with_literal_condition() {
891        let test_array = buffer![1i32, 2, 3].into_array();
892        let expr = case_when(lit(true), lit(100i32), lit(0i32));
893        let result = evaluate_expr(&expr, &test_array);
894
895        assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array());
896    }
897
898    #[test]
899    fn test_evaluate_with_bool_column_result() {
900        let test_array =
901            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
902                .unwrap()
903                .into_array();
904
905        let expr = case_when(
906            gt(get_item("value", root()), lit(2i32)),
907            lit(true),
908            lit(false),
909        );
910
911        let result = evaluate_expr(&expr, &test_array);
912        assert_arrays_eq!(
913            result,
914            BoolArray::from_iter([false, false, true, true, true]).into_array()
915        );
916    }
917
918    #[test]
919    fn test_evaluate_with_nullable_condition() {
920        let test_array = StructArray::from_fields(&[(
921            "cond",
922            BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)]).into_array(),
923        )])
924        .unwrap()
925        .into_array();
926
927        let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));
928
929        let result = evaluate_expr(&expr, &test_array);
930        assert_arrays_eq!(result, buffer![100i32, 0, 0, 0, 100].into_array());
931    }
932
933    #[test]
934    fn test_evaluate_with_nullable_result_values() {
935        let test_array = StructArray::from_fields(&[
936            ("value", buffer![1i32, 2, 3, 4, 5].into_array()),
937            (
938                "result",
939                PrimitiveArray::from_option_iter([Some(10), None, Some(30), Some(40), Some(50)])
940                    .into_array(),
941            ),
942        ])
943        .unwrap()
944        .into_array();
945
946        let expr = case_when(
947            gt(get_item("value", root()), lit(2i32)),
948            get_item("result", root()),
949            lit(0i32),
950        );
951
952        let result = evaluate_expr(&expr, &test_array);
953        assert_arrays_eq!(
954            result,
955            PrimitiveArray::from_option_iter([Some(0i32), Some(0), Some(30), Some(40), Some(50)])
956                .into_array()
957        );
958    }
959
960    #[test]
961    fn test_evaluate_with_all_null_condition() {
962        let test_array = StructArray::from_fields(&[(
963            "cond",
964            BoolArray::from_iter([None, None, None]).into_array(),
965        )])
966        .unwrap()
967        .into_array();
968
969        let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));
970
971        let result = evaluate_expr(&expr, &test_array);
972        assert_arrays_eq!(result, buffer![0i32, 0, 0].into_array());
973    }
974
975    // ==================== N-ary Evaluate Tests ====================
976
977    #[test]
978    fn test_evaluate_nary_no_else_returns_null() {
979        let test_array =
980            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
981                .unwrap()
982                .into_array();
983
984        // Two conditions, no ELSE — unmatched rows should be NULL
985        let expr = nested_case_when(
986            vec![
987                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
988                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
989            ],
990            None,
991        );
992
993        let result = evaluate_expr(&expr, &test_array);
994        assert!(result.dtype().is_nullable());
995        assert_arrays_eq!(
996            result,
997            PrimitiveArray::from_option_iter([Some(10i32), None, Some(30), None, None])
998                .into_array()
999        );
1000    }
1001
1002    #[test]
1003    fn test_evaluate_nary_many_conditions() {
1004        let test_array =
1005            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
1006                .unwrap()
1007                .into_array();
1008
1009        // 5 WHEN/THEN pairs: each value maps to its value * 10
1010        let expr = nested_case_when(
1011            vec![
1012                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
1013                (eq(get_item("value", root()), lit(2i32)), lit(20i32)),
1014                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
1015                (eq(get_item("value", root()), lit(4i32)), lit(40i32)),
1016                (eq(get_item("value", root()), lit(5i32)), lit(50i32)),
1017            ],
1018            Some(lit(0i32)),
1019        );
1020
1021        let result = evaluate_expr(&expr, &test_array);
1022        assert_arrays_eq!(result, buffer![10i32, 20, 30, 40, 50].into_array());
1023    }
1024
1025    #[test]
1026    fn test_evaluate_nary_all_false_no_else() {
1027        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
1028            .unwrap()
1029            .into_array();
1030
1031        // All conditions are false, no ELSE — everything should be NULL
1032        let expr = nested_case_when(
1033            vec![
1034                (gt(get_item("value", root()), lit(100i32)), lit(10i32)),
1035                (gt(get_item("value", root()), lit(200i32)), lit(20i32)),
1036            ],
1037            None,
1038        );
1039
1040        let result = evaluate_expr(&expr, &test_array);
1041        assert!(result.dtype().is_nullable());
1042        assert_arrays_eq!(
1043            result,
1044            PrimitiveArray::from_option_iter([None::<i32>, None, None]).into_array()
1045        );
1046    }
1047
1048    #[test]
1049    fn test_evaluate_nary_overlapping_conditions_first_wins() {
1050        let test_array =
1051            StructArray::from_fields(&[("value", buffer![10i32, 20, 30].into_array())])
1052                .unwrap()
1053                .into_array();
1054
1055        // value=10: matches cond1 (>5) and cond2 (>0), first should win
1056        // value=20: matches all three, first should win
1057        // value=30: matches all three, first should win
1058        let expr = nested_case_when(
1059            vec![
1060                (gt(get_item("value", root()), lit(5i32)), lit(1i32)),
1061                (gt(get_item("value", root()), lit(0i32)), lit(2i32)),
1062                (gt(get_item("value", root()), lit(15i32)), lit(3i32)),
1063            ],
1064            Some(lit(0i32)),
1065        );
1066
1067        let result = evaluate_expr(&expr, &test_array);
1068        // First matching condition always wins
1069        assert_arrays_eq!(result, buffer![1i32, 1, 1].into_array());
1070    }
1071
1072    #[test]
1073    fn test_evaluate_nary_early_exit_when_remaining_empty() {
1074        // After branch 0 claims all rows, remaining becomes all_false.
1075        // The loop breaks before evaluating branch 1's condition.
1076        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
1077            .unwrap()
1078            .into_array();
1079
1080        let expr = nested_case_when(
1081            vec![
1082                (gt(get_item("value", root()), lit(0i32)), lit(100i32)),
1083                // Never evaluated due to early exit; 999 must never appear in output.
1084                (gt(get_item("value", root()), lit(0i32)), lit(999i32)),
1085            ],
1086            Some(lit(0i32)),
1087        );
1088
1089        let result = evaluate_expr(&expr, &test_array);
1090        assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array());
1091    }
1092
1093    #[test]
1094    fn test_evaluate_nary_skips_branch_with_empty_effective_mask() {
1095        // Branch 0 claims value=1. Branch 1 targets the same rows but they are already
1096        // matched → effective_mask is all_false → branch 1 is skipped (THEN not used).
1097        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
1098            .unwrap()
1099            .into_array();
1100
1101        let expr = nested_case_when(
1102            vec![
1103                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
1104                // Same condition as branch 0 — all matching rows already claimed → skipped.
1105                // 999 must never appear in output.
1106                (eq(get_item("value", root()), lit(1i32)), lit(999i32)),
1107                (eq(get_item("value", root()), lit(2i32)), lit(20i32)),
1108            ],
1109            Some(lit(0i32)),
1110        );
1111
1112        let result = evaluate_expr(&expr, &test_array);
1113        assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array());
1114    }
1115
1116    #[test]
1117    fn test_evaluate_nary_string_output() -> VortexResult<()> {
1118        // Exercises merge_case_branches with a non-primitive (Utf8) builder.
1119        let test_array =
1120            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4].into_array())])
1121                .unwrap()
1122                .into_array();
1123
1124        // CASE WHEN value > 2 THEN 'high' WHEN value > 0 THEN 'low' ELSE 'none' END
1125        // value=1,2 → 'low' (branch 1 after branch 0 claims 3,4)
1126        // value=3,4 → 'high' (branch 0)
1127        let expr = nested_case_when(
1128            vec![
1129                (gt(get_item("value", root()), lit(2i32)), lit("high")),
1130                (gt(get_item("value", root()), lit(0i32)), lit("low")),
1131            ],
1132            Some(lit("none")),
1133        );
1134
1135        let result = evaluate_expr(&expr, &test_array);
1136        assert_eq!(
1137            result.execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())?,
1138            Scalar::utf8("low", Nullability::NonNullable)
1139        );
1140        assert_eq!(
1141            result.execute_scalar(1, &mut LEGACY_SESSION.create_execution_ctx())?,
1142            Scalar::utf8("low", Nullability::NonNullable)
1143        );
1144        assert_eq!(
1145            result.execute_scalar(2, &mut LEGACY_SESSION.create_execution_ctx())?,
1146            Scalar::utf8("high", Nullability::NonNullable)
1147        );
1148        assert_eq!(
1149            result.execute_scalar(3, &mut LEGACY_SESSION.create_execution_ctx())?,
1150            Scalar::utf8("high", Nullability::NonNullable)
1151        );
1152        Ok(())
1153    }
1154
1155    #[test]
1156    fn test_evaluate_nary_with_nullable_conditions() {
1157        let test_array = StructArray::from_fields(&[
1158            (
1159                "cond1",
1160                BoolArray::from_iter([Some(true), None, Some(false)]).into_array(),
1161            ),
1162            (
1163                "cond2",
1164                BoolArray::from_iter([Some(false), Some(true), None]).into_array(),
1165            ),
1166        ])
1167        .unwrap()
1168        .into_array();
1169
1170        let expr = nested_case_when(
1171            vec![
1172                (get_item("cond1", root()), lit(10i32)),
1173                (get_item("cond2", root()), lit(20i32)),
1174            ],
1175            Some(lit(0i32)),
1176        );
1177
1178        let result = evaluate_expr(&expr, &test_array);
1179        // row 0: cond1=true → 10
1180        // row 1: cond1=NULL(→false), cond2=true → 20
1181        // row 2: cond1=false, cond2=NULL(→false) → else=0
1182        assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array());
1183    }
1184
1185    #[test]
1186    fn test_merge_case_branches_alternating_mask() -> VortexResult<()> {
1187        // Exercises the scalar path: alternating rows produce one slice per row (no runs),
1188        // triggering the per-row cursor path in merge_case_branches.
1189        let n = 100usize;
1190
1191        // Branch 0: even rows → 0, Branch 1: odd rows → 1, Else: never reached.
1192        let branch0_mask = Mask::from_indices(n, (0..n).step_by(2).collect());
1193        let branch1_mask = Mask::from_indices(n, (1..n).step_by(2).collect());
1194
1195        let result = merge_case_branches(
1196            vec![
1197                (
1198                    branch0_mask,
1199                    PrimitiveArray::from_option_iter(vec![Some(0i32); n]).into_array(),
1200                ),
1201                (
1202                    branch1_mask,
1203                    PrimitiveArray::from_option_iter(vec![Some(1i32); n]).into_array(),
1204                ),
1205            ],
1206            PrimitiveArray::from_option_iter(vec![Some(99i32); n]).into_array(),
1207            &mut SESSION.create_execution_ctx(),
1208        )?;
1209
1210        // Even rows → 0, odd rows → 1.
1211        let expected: Vec<Option<i32>> = (0..n)
1212            .map(|v| if v % 2 == 0 { Some(0) } else { Some(1) })
1213            .collect();
1214        assert_arrays_eq!(
1215            result,
1216            PrimitiveArray::from_option_iter(expected).into_array()
1217        );
1218        Ok(())
1219    }
1220}