Skip to main content

vortex_array/scalar_fn/fns/
case_when.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! SQL-style CASE WHEN: evaluates `(condition, value)` pairs in order and returns
5//! the value from the first matching condition (first-match-wins). NULL conditions
6//! are treated as false. If no ELSE clause is provided, unmatched rows produce NULL;
7//! otherwise they get the ELSE value.
8//!
9//! Unlike SQL which coerces all branches to a common supertype, all THEN/ELSE
10//! branches must share the same base dtype (ignoring nullability). The result
11//! nullability is the union of all branches (forced nullable if no ELSE).
12
13use std::fmt;
14use std::fmt::Formatter;
15use std::hash::Hash;
16use std::sync::Arc;
17
18use prost::Message;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21use vortex_mask::AllOr;
22use vortex_mask::Mask;
23use vortex_proto::expr as pb;
24use vortex_session::VortexSession;
25
26use crate::ArrayRef;
27use crate::ExecutionCtx;
28use crate::IntoArray;
29use crate::arrays::BoolArray;
30use crate::arrays::ConstantArray;
31use crate::arrays::bool::BoolArrayExt;
32use crate::builders::ArrayBuilder;
33use crate::builders::builder_with_capacity;
34use crate::builtins::ArrayBuiltins;
35use crate::dtype::DType;
36use crate::expr::Expression;
37use crate::scalar::Scalar;
38use crate::scalar_fn::Arity;
39use crate::scalar_fn::ChildName;
40use crate::scalar_fn::ExecutionArgs;
41use crate::scalar_fn::ScalarFnId;
42use crate::scalar_fn::ScalarFnVTable;
43use crate::scalar_fn::fns::zip::zip_impl;
44
/// Options for the n-ary CaseWhen expression.
///
/// Together these two fields fix the child layout of the expression:
/// `[when_0, then_0, ..., when_{n-1}, then_{n-1}, else?]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CaseWhenOptions {
    /// Number of WHEN/THEN pairs.
    pub num_when_then_pairs: u32,
    /// Whether an ELSE clause is present.
    /// If false, unmatched rows return NULL.
    pub has_else: bool,
}

impl CaseWhenOptions {
    /// Total number of child expressions: two per WHEN/THEN pair, plus one for ELSE if present.
    pub fn num_children(&self) -> usize {
        let else_children = usize::from(self.has_else);
        let pair_children = 2 * self.num_when_then_pairs as usize;
        pair_children + else_children
    }
}

impl fmt::Display for CaseWhenOptions {
    /// Renders as e.g. `case_when(pairs=2, else=true)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self {
            num_when_then_pairs,
            has_else,
        } = *self;
        write!(f, "case_when(pairs={num_when_then_pairs}, else={has_else})")
    }
}
71
72/// An n-ary CASE WHEN expression.
73///
74/// Children are in order: `[when_0, then_0, when_1, then_1, ..., else?]`.
75#[derive(Clone)]
76pub struct CaseWhen;
77
78impl ScalarFnVTable for CaseWhen {
79    type Options = CaseWhenOptions;
80
81    fn id(&self) -> ScalarFnId {
82        ScalarFnId::new("vortex.case_when")
83    }
84
85    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
86        // let num_children = options.num_when_then_pairs * 2 + u32::from(options.has_else);
87        // Ok(Some(pb::CaseWhenOpts { num_children }.encode_to_vec()))
88        // stabilize the expr
89        vortex_bail!("cannot serialize")
90    }
91
92    fn deserialize(
93        &self,
94        metadata: &[u8],
95        _session: &VortexSession,
96    ) -> VortexResult<Self::Options> {
97        let opts = pb::CaseWhenOpts::decode(metadata)?;
98        if opts.num_children < 2 {
99            vortex_bail!(
100                "CaseWhen expects at least 2 children, got {}",
101                opts.num_children
102            );
103        }
104        Ok(CaseWhenOptions {
105            num_when_then_pairs: opts.num_children / 2,
106            has_else: opts.num_children % 2 == 1,
107        })
108    }
109
110    fn arity(&self, options: &Self::Options) -> Arity {
111        Arity::Exact(options.num_children())
112    }
113
114    fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName {
115        let num_pair_children = options.num_when_then_pairs as usize * 2;
116        if child_idx < num_pair_children {
117            let pair_idx = child_idx / 2;
118            if child_idx.is_multiple_of(2) {
119                ChildName::from(Arc::from(format!("when_{pair_idx}")))
120            } else {
121                ChildName::from(Arc::from(format!("then_{pair_idx}")))
122            }
123        } else if options.has_else && child_idx == num_pair_children {
124            ChildName::from("else")
125        } else {
126            unreachable!("Invalid child index {} for CaseWhen", child_idx)
127        }
128    }
129
130    fn fmt_sql(
131        &self,
132        options: &Self::Options,
133        expr: &Expression,
134        f: &mut Formatter<'_>,
135    ) -> fmt::Result {
136        write!(f, "CASE")?;
137        for i in 0..options.num_when_then_pairs as usize {
138            write!(
139                f,
140                " WHEN {} THEN {}",
141                expr.child(i * 2),
142                expr.child(i * 2 + 1)
143            )?;
144        }
145        if options.has_else {
146            let else_idx = options.num_when_then_pairs as usize * 2;
147            write!(f, " ELSE {}", expr.child(else_idx))?;
148        }
149        write!(f, " END")
150    }
151
152    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
153        if options.num_when_then_pairs == 0 {
154            vortex_bail!("CaseWhen must have at least one WHEN/THEN pair");
155        }
156
157        let expected_len = options.num_children();
158        if arg_dtypes.len() != expected_len {
159            vortex_bail!(
160                "CaseWhen expects {expected_len} argument dtypes, got {}",
161                arg_dtypes.len()
162            );
163        }
164
165        // Unlike SQL which coerces all branches to a common supertype, we require
166        // all THEN/ELSE branches to have the same base dtype (ignoring nullability).
167        // The result nullability is the union of all branches.
168        let first_then = &arg_dtypes[1];
169        let mut result_dtype = first_then.clone();
170
171        for i in 1..options.num_when_then_pairs as usize {
172            let then_i = &arg_dtypes[i * 2 + 1];
173            if !first_then.eq_ignore_nullability(then_i) {
174                vortex_bail!(
175                    "CaseWhen THEN dtypes must match (ignoring nullability), got {} and {}",
176                    first_then,
177                    then_i
178                );
179            }
180            result_dtype = result_dtype.union_nullability(then_i.nullability());
181        }
182
183        if options.has_else {
184            let else_dtype = &arg_dtypes[options.num_when_then_pairs as usize * 2];
185            if !result_dtype.eq_ignore_nullability(else_dtype) {
186                vortex_bail!(
187                    "CaseWhen THEN and ELSE dtypes must match (ignoring nullability), got {} and {}",
188                    first_then,
189                    else_dtype
190                );
191            }
192            result_dtype = result_dtype.union_nullability(else_dtype.nullability());
193        } else {
194            // No ELSE means unmatched rows are NULL
195            result_dtype = result_dtype.as_nullable();
196        }
197
198        Ok(result_dtype)
199    }
200
201    fn execute(
202        &self,
203        options: &Self::Options,
204        args: &dyn ExecutionArgs,
205        ctx: &mut ExecutionCtx,
206    ) -> VortexResult<ArrayRef> {
207        // Inspired by https://datafusion.apache.org/blog/2026/02/02/datafusion_case/
208        //
209        // TODO: shrink input to `remaining` rows between WHEN iterations (batch reduction).
210        // TODO: project to only referenced columns before batch reduction (column projection).
211        // TODO: evaluate THEN/ELSE on compact matching/non-matching rows and scatter-merge the results.
212        // TODO: for constant WHEN/THEN values, compile to a hash table for a single-pass lookup.
213        let row_count = args.row_count();
214        let num_pairs = options.num_when_then_pairs as usize;
215
216        let mut remaining = Mask::new_true(row_count);
217        let mut branches: Vec<(Mask, ArrayRef)> = Vec::with_capacity(num_pairs);
218
219        for i in 0..num_pairs {
220            if remaining.all_false() {
221                break;
222            }
223
224            let condition = args.get(i * 2)?;
225            let cond_bool = condition.execute::<BoolArray>(ctx)?;
226            let cond_mask = cond_bool.to_mask_fill_null_false(ctx);
227            let effective_mask = &remaining & &cond_mask;
228
229            if effective_mask.all_false() {
230                continue;
231            }
232
233            let then_value = args.get(i * 2 + 1)?;
234            remaining = remaining.bitand_not(&cond_mask);
235            branches.push((effective_mask, then_value));
236        }
237
238        let else_value: ArrayRef = if options.has_else {
239            args.get(num_pairs * 2)?
240        } else {
241            let then_dtype = args.get(1)?.dtype().as_nullable();
242            ConstantArray::new(Scalar::null(then_dtype), row_count).into_array()
243        };
244
245        if branches.is_empty() {
246            return Ok(else_value);
247        }
248
249        merge_case_branches(branches, else_value, ctx)
250    }
251
252    fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
253        true
254    }
255
256    fn is_fallible(&self, _options: &Self::Options) -> bool {
257        false
258    }
259}
260
261/// Average run length at which slicing + context-aware builder appends become cheaper than `scalar_at`.
262/// Measured empirically via benchmarks.
263const SLICE_CROSSOVER_RUN_LEN: usize = 4;
264
265/// Merges disjoint `(mask, then_value)` branch pairs with an `else_value` into a single array.
266///
267/// Branch masks are guaranteed disjoint by the remaining-row tracking in [`CaseWhen::execute`].
268fn merge_case_branches(
269    branches: Vec<(Mask, ArrayRef)>,
270    else_value: ArrayRef,
271    ctx: &mut ExecutionCtx,
272) -> VortexResult<ArrayRef> {
273    if branches.len() == 1 {
274        let (mask, then_value) = &branches[0];
275        return zip_impl(then_value, &else_value, mask, ctx);
276    }
277
278    let output_nullability = branches
279        .iter()
280        .fold(else_value.dtype().nullability(), |acc, (_, arr)| {
281            acc | arr.dtype().nullability()
282        });
283    let output_dtype = else_value.dtype().with_nullability(output_nullability);
284    let branch_arrays: Vec<&ArrayRef> = branches.iter().map(|(_, arr)| arr).collect();
285
286    let mut spans: Vec<(usize, usize, usize)> = Vec::new();
287    for (branch_idx, (mask, _)) in branches.iter().enumerate() {
288        match mask.slices() {
289            AllOr::All => return branch_arrays[branch_idx].cast(output_dtype),
290            AllOr::None => {}
291            AllOr::Some(slices) => {
292                for &(start, end) in slices {
293                    spans.push((start, end, branch_idx));
294                }
295            }
296        }
297    }
298    spans.sort_unstable_by_key(|&(start, ..)| start);
299
300    if spans.is_empty() {
301        return else_value.cast(output_dtype);
302    }
303
304    let builder = builder_with_capacity(&output_dtype, else_value.len());
305
306    let fragmented = spans.len() > else_value.len() / SLICE_CROSSOVER_RUN_LEN;
307    if fragmented {
308        merge_row_by_row(
309            &branch_arrays,
310            &else_value,
311            &spans,
312            &output_dtype,
313            builder,
314            ctx,
315        )
316    } else {
317        merge_run_by_run(
318            &branch_arrays,
319            &else_value,
320            &spans,
321            &output_dtype,
322            builder,
323            ctx,
324        )
325    }
326}
327
328/// Iterates spans directly, emitting one `scalar_at` per row.
329/// Zero per-run allocations; preferred for fragmented masks (avg run < [`SLICE_CROSSOVER_RUN_LEN`]).
330fn merge_row_by_row(
331    branch_arrays: &[&ArrayRef],
332    else_value: &ArrayRef,
333    spans: &[(usize, usize, usize)],
334    output_dtype: &DType,
335    mut builder: Box<dyn ArrayBuilder>,
336    ctx: &mut ExecutionCtx,
337) -> VortexResult<ArrayRef> {
338    let mut pos = 0;
339    for &(start, end, branch_idx) in spans {
340        for row in pos..start {
341            let scalar = else_value.execute_scalar(row, ctx)?;
342            builder.append_scalar(&scalar.cast(output_dtype)?)?;
343        }
344        for row in start..end {
345            let scalar = branch_arrays[branch_idx].execute_scalar(row, ctx)?;
346            builder.append_scalar(&scalar.cast(output_dtype)?)?;
347        }
348        pos = end;
349    }
350    for row in pos..else_value.len() {
351        let scalar = else_value.execute_scalar(row, ctx)?;
352        builder.append_scalar(&scalar.cast(output_dtype)?)?;
353    }
354
355    Ok(builder.finish())
356}
357
358/// Bulk-copies each span via `slice()` and context-aware builder appends.
359/// Preferred when runs are long enough that memcpy dominates over per-slice allocation cost.
360/// Lazy cast via `arr.cast(output_dtype)` is executed once per span as a block.
361fn merge_run_by_run(
362    branch_arrays: &[&ArrayRef],
363    else_value: &ArrayRef,
364    spans: &[(usize, usize, usize)],
365    output_dtype: &DType,
366    mut builder: Box<dyn ArrayBuilder>,
367    ctx: &mut ExecutionCtx,
368) -> VortexResult<ArrayRef> {
369    let else_value = else_value.cast(output_dtype.clone())?;
370    let len = else_value.len();
371    for (start, end, branch_idx) in spans {
372        if builder.len() < *start {
373            else_value
374                .slice(builder.len()..*start)?
375                .append_to_builder(builder.as_mut(), ctx)?;
376        }
377        branch_arrays[*branch_idx]
378            .cast(output_dtype.clone())?
379            .slice(*start..*end)?
380            .append_to_builder(builder.as_mut(), ctx)?;
381    }
382    if builder.len() < len {
383        else_value
384            .slice(builder.len()..len)?
385            .append_to_builder(builder.as_mut(), ctx)?;
386    }
387
388    Ok(builder.finish())
389}
390
391#[cfg(test)]
392mod tests {
393    use std::sync::LazyLock;
394
395    use vortex_buffer::buffer;
396    use vortex_error::VortexExpect as _;
397    use vortex_session::VortexSession;
398
399    use super::*;
400    use crate::Canonical;
401    use crate::IntoArray;
402    use crate::LEGACY_SESSION;
403    use crate::VortexSessionExecute;
404    use crate::arrays::BoolArray;
405    use crate::arrays::PrimitiveArray;
406    use crate::arrays::StructArray;
407    use crate::assert_arrays_eq;
408    use crate::dtype::DType;
409    use crate::dtype::Nullability;
410    use crate::dtype::PType;
411    use crate::expr::case_when;
412    use crate::expr::case_when_no_else;
413    use crate::expr::col;
414    use crate::expr::eq;
415    use crate::expr::get_item;
416    use crate::expr::gt;
417    use crate::expr::lit;
418    use crate::expr::nested_case_when;
419    use crate::expr::root;
420    use crate::expr::test_harness;
421    use crate::scalar::Scalar;
422    use crate::session::ArraySession;
423
424    static SESSION: LazyLock<VortexSession> =
425        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());
426
427    /// Helper to evaluate an expression using the apply+execute pattern
428    fn evaluate_expr(expr: &Expression, array: &ArrayRef) -> ArrayRef {
429        let mut ctx = SESSION.create_execution_ctx();
430        array
431            .clone()
432            .apply(expr)
433            .unwrap()
434            .execute::<Canonical>(&mut ctx)
435            .unwrap()
436            .into_array()
437    }
438
439    // ==================== Serialization Tests ====================
440
441    #[test]
442    #[should_panic(expected = "cannot serialize")]
443    fn test_serialization_roundtrip() {
444        let options = CaseWhenOptions {
445            num_when_then_pairs: 1,
446            has_else: true,
447        };
448        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
449        let deserialized = CaseWhen
450            .deserialize(&serialized, &VortexSession::empty())
451            .unwrap();
452        assert_eq!(options, deserialized);
453    }
454
455    #[test]
456    #[should_panic(expected = "cannot serialize")]
457    fn test_serialization_no_else() {
458        let options = CaseWhenOptions {
459            num_when_then_pairs: 1,
460            has_else: false,
461        };
462        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
463        let deserialized = CaseWhen
464            .deserialize(&serialized, &VortexSession::empty())
465            .unwrap();
466        assert_eq!(options, deserialized);
467    }
468
469    // ==================== Display Tests ====================
470
471    #[test]
472    fn test_display_with_else() {
473        let expr = case_when(gt(col("value"), lit(0i32)), lit(100i32), lit(0i32));
474        let display = format!("{}", expr);
475        assert!(display.contains("CASE"));
476        assert!(display.contains("WHEN"));
477        assert!(display.contains("THEN"));
478        assert!(display.contains("ELSE"));
479        assert!(display.contains("END"));
480    }
481
482    #[test]
483    fn test_display_no_else() {
484        let expr = case_when_no_else(gt(col("value"), lit(0i32)), lit(100i32));
485        let display = format!("{}", expr);
486        assert!(display.contains("CASE"));
487        assert!(display.contains("WHEN"));
488        assert!(display.contains("THEN"));
489        assert!(!display.contains("ELSE"));
490        assert!(display.contains("END"));
491    }
492
493    #[test]
494    fn test_display_nested_nary() {
495        // CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'medium' ELSE 'low' END
496        let expr = nested_case_when(
497            vec![
498                (gt(col("x"), lit(10i32)), lit("high")),
499                (gt(col("x"), lit(5i32)), lit("medium")),
500            ],
501            Some(lit("low")),
502        );
503        let display = format!("{}", expr);
504        assert_eq!(display.matches("CASE").count(), 1);
505        assert_eq!(display.matches("WHEN").count(), 2);
506        assert_eq!(display.matches("THEN").count(), 2);
507    }
508
509    // ==================== DType Tests ====================
510
511    #[test]
512    fn test_return_dtype_with_else() {
513        let expr = case_when(lit(true), lit(100i32), lit(0i32));
514        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
515        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
516        assert_eq!(
517            result_dtype,
518            DType::Primitive(PType::I32, Nullability::NonNullable)
519        );
520    }
521
522    #[test]
523    fn test_return_dtype_with_nullable_else() {
524        let expr = case_when(
525            lit(true),
526            lit(100i32),
527            lit(Scalar::null(DType::Primitive(
528                PType::I32,
529                Nullability::Nullable,
530            ))),
531        );
532        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
533        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
534        assert_eq!(
535            result_dtype,
536            DType::Primitive(PType::I32, Nullability::Nullable)
537        );
538    }
539
540    #[test]
541    fn test_return_dtype_without_else_is_nullable() {
542        let expr = case_when_no_else(lit(true), lit(100i32));
543        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
544        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
545        assert_eq!(
546            result_dtype,
547            DType::Primitive(PType::I32, Nullability::Nullable)
548        );
549    }
550
551    #[test]
552    fn test_return_dtype_with_struct_input() {
553        let dtype = test_harness::struct_dtype();
554        let expr = case_when(
555            gt(get_item("col1", root()), lit(10u16)),
556            lit(100i32),
557            lit(0i32),
558        );
559        let result_dtype = expr.return_dtype(&dtype).unwrap();
560        assert_eq!(
561            result_dtype,
562            DType::Primitive(PType::I32, Nullability::NonNullable)
563        );
564    }
565
566    #[test]
567    fn test_return_dtype_mismatched_then_else_errors() {
568        let expr = case_when(lit(true), lit(100i32), lit("zero"));
569        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
570        let err = expr.return_dtype(&input_dtype).unwrap_err();
571        assert!(
572            err.to_string()
573                .contains("THEN and ELSE dtypes must match (ignoring nullability)")
574        );
575    }
576
577    // ==================== Arity Tests ====================
578
579    #[test]
580    fn test_arity_with_else() {
581        let options = CaseWhenOptions {
582            num_when_then_pairs: 1,
583            has_else: true,
584        };
585        assert_eq!(CaseWhen.arity(&options), Arity::Exact(3));
586    }
587
588    #[test]
589    fn test_arity_without_else() {
590        let options = CaseWhenOptions {
591            num_when_then_pairs: 1,
592            has_else: false,
593        };
594        assert_eq!(CaseWhen.arity(&options), Arity::Exact(2));
595    }
596
597    // ==================== Child Name Tests ====================
598
599    #[test]
600    fn test_child_names() {
601        let options = CaseWhenOptions {
602            num_when_then_pairs: 1,
603            has_else: true,
604        };
605        assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0");
606        assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0");
607        assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "else");
608    }
609
610    // ==================== N-ary Serialization Tests ====================
611
612    #[test]
613    #[should_panic(expected = "cannot serialize")]
614    fn test_serialization_roundtrip_nary() {
615        let options = CaseWhenOptions {
616            num_when_then_pairs: 3,
617            has_else: true,
618        };
619        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
620        let deserialized = CaseWhen
621            .deserialize(&serialized, &VortexSession::empty())
622            .unwrap();
623        assert_eq!(options, deserialized);
624    }
625
626    #[test]
627    #[should_panic(expected = "cannot serialize")]
628    fn test_serialization_roundtrip_nary_no_else() {
629        let options = CaseWhenOptions {
630            num_when_then_pairs: 4,
631            has_else: false,
632        };
633        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
634        let deserialized = CaseWhen
635            .deserialize(&serialized, &VortexSession::empty())
636            .unwrap();
637        assert_eq!(options, deserialized);
638    }
639
640    // ==================== N-ary Arity Tests ====================
641
642    #[test]
643    fn test_arity_nary_with_else() {
644        let options = CaseWhenOptions {
645            num_when_then_pairs: 3,
646            has_else: true,
647        };
648        // 3 pairs * 2 children + 1 else = 7
649        assert_eq!(CaseWhen.arity(&options), Arity::Exact(7));
650    }
651
652    #[test]
653    fn test_arity_nary_without_else() {
654        let options = CaseWhenOptions {
655            num_when_then_pairs: 3,
656            has_else: false,
657        };
658        // 3 pairs * 2 children = 6
659        assert_eq!(CaseWhen.arity(&options), Arity::Exact(6));
660    }
661
662    // ==================== N-ary Child Name Tests ====================
663
664    #[test]
665    fn test_child_names_nary() {
666        let options = CaseWhenOptions {
667            num_when_then_pairs: 3,
668            has_else: true,
669        };
670        assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0");
671        assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0");
672        assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "when_1");
673        assert_eq!(CaseWhen.child_name(&options, 3).to_string(), "then_1");
674        assert_eq!(CaseWhen.child_name(&options, 4).to_string(), "when_2");
675        assert_eq!(CaseWhen.child_name(&options, 5).to_string(), "then_2");
676        assert_eq!(CaseWhen.child_name(&options, 6).to_string(), "else");
677    }
678
679    // ==================== N-ary DType Tests ====================
680
681    #[test]
682    fn test_return_dtype_nary_mismatched_then_types_errors() {
683        let expr = nested_case_when(
684            vec![(lit(true), lit(100i32)), (lit(false), lit("oops"))],
685            Some(lit(0i32)),
686        );
687        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
688        let err = expr.return_dtype(&input_dtype).unwrap_err();
689        assert!(err.to_string().contains("THEN dtypes must match"));
690    }
691
692    #[test]
693    fn test_return_dtype_nary_mixed_nullability() {
694        // When some THEN branches are nullable and others are not,
695        // the result should be nullable (union of nullabilities).
696        let non_null_then = lit(100i32);
697        let nullable_then = lit(Scalar::null(DType::Primitive(
698            PType::I32,
699            Nullability::Nullable,
700        )));
701        let expr = nested_case_when(
702            vec![(lit(true), non_null_then), (lit(false), nullable_then)],
703            Some(lit(0i32)),
704        );
705        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
706        let result = expr.return_dtype(&input_dtype).unwrap();
707        assert_eq!(result, DType::Primitive(PType::I32, Nullability::Nullable));
708    }
709
710    #[test]
711    fn test_return_dtype_nary_no_else_is_nullable() {
712        let expr = nested_case_when(
713            vec![(lit(true), lit(10i32)), (lit(false), lit(20i32))],
714            None,
715        );
716        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
717        let result = expr.return_dtype(&input_dtype).unwrap();
718        assert_eq!(result, DType::Primitive(PType::I32, Nullability::Nullable));
719    }
720
721    // ==================== Expression Manipulation Tests ====================
722
723    #[test]
724    fn test_replace_children() {
725        let expr = case_when(lit(true), lit(1i32), lit(0i32));
726        expr.with_children([lit(false), lit(2i32), lit(3i32)])
727            .vortex_expect("operation should succeed in test");
728    }
729
730    // ==================== Evaluate Tests ====================
731
732    #[test]
733    fn test_evaluate_simple_condition() {
734        let test_array =
735            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
736                .unwrap()
737                .into_array();
738
739        let expr = case_when(
740            gt(get_item("value", root()), lit(2i32)),
741            lit(100i32),
742            lit(0i32),
743        );
744
745        let result = evaluate_expr(&expr, &test_array);
746        assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array());
747    }
748
749    #[test]
750    fn test_evaluate_nary_multiple_conditions() {
751        // Test n-ary via nested_case_when
752        let test_array =
753            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
754                .unwrap()
755                .into_array();
756
757        let expr = nested_case_when(
758            vec![
759                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
760                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
761            ],
762            Some(lit(0i32)),
763        );
764
765        let result = evaluate_expr(&expr, &test_array);
766        assert_arrays_eq!(result, buffer![10i32, 0, 30, 0, 0].into_array());
767    }
768
769    #[test]
770    fn test_evaluate_nary_first_match_wins() {
771        let test_array =
772            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
773                .unwrap()
774                .into_array();
775
776        // Both conditions match for values > 3, but first one wins
777        let expr = nested_case_when(
778            vec![
779                (gt(get_item("value", root()), lit(2i32)), lit(100i32)),
780                (gt(get_item("value", root()), lit(3i32)), lit(200i32)),
781            ],
782            Some(lit(0i32)),
783        );
784
785        let result = evaluate_expr(&expr, &test_array);
786        assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array());
787    }
788
789    #[test]
790    fn test_evaluate_no_else_returns_null() {
791        let test_array =
792            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
793                .unwrap()
794                .into_array();
795
796        let expr = case_when_no_else(gt(get_item("value", root()), lit(3i32)), lit(100i32));
797
798        let result = evaluate_expr(&expr, &test_array);
799        assert!(result.dtype().is_nullable());
800        assert_arrays_eq!(
801            result,
802            PrimitiveArray::from_option_iter([None::<i32>, None, None, Some(100), Some(100)])
803                .into_array()
804        );
805    }
806
807    #[test]
808    fn test_evaluate_all_conditions_false() {
809        let test_array =
810            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
811                .unwrap()
812                .into_array();
813
814        let expr = case_when(
815            gt(get_item("value", root()), lit(100i32)),
816            lit(1i32),
817            lit(0i32),
818        );
819
820        let result = evaluate_expr(&expr, &test_array);
821        assert_arrays_eq!(result, buffer![0i32, 0, 0, 0, 0].into_array());
822    }
823
824    #[test]
825    fn test_evaluate_all_conditions_true() {
826        let test_array =
827            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
828                .unwrap()
829                .into_array();
830
831        let expr = case_when(
832            gt(get_item("value", root()), lit(0i32)),
833            lit(100i32),
834            lit(0i32),
835        );
836
837        let result = evaluate_expr(&expr, &test_array);
838        assert_arrays_eq!(result, buffer![100i32, 100, 100, 100, 100].into_array());
839    }
840
841    #[test]
842    fn test_evaluate_all_true_no_else_returns_correct_dtype() {
843        // CASE WHEN value > 0 THEN 100 END — condition is always true, no ELSE.
844        // Result must be Nullable because the implicit ELSE is NULL.
845        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
846            .unwrap()
847            .into_array();
848
849        let expr = case_when_no_else(gt(get_item("value", root()), lit(0i32)), lit(100i32));
850
851        let result = evaluate_expr(&expr, &test_array);
852        assert!(
853            result.dtype().is_nullable(),
854            "result dtype must be Nullable, got {:?}",
855            result.dtype()
856        );
857        assert_arrays_eq!(
858            result,
859            PrimitiveArray::from_option_iter([Some(100i32), Some(100), Some(100)]).into_array()
860        );
861    }
862
863    #[test]
864    fn test_merge_case_branches_widens_nullability_of_later_branch() -> VortexResult<()> {
865        // When a later THEN branch is Nullable and branches[0] and ELSE are NonNullable,
866        // the result dtype must still be Nullable.
867        //
868        // CASE WHEN value = 0 THEN 10          -- NonNullable
869        //      WHEN value = 1 THEN nullable(20) -- Nullable
870        //      ELSE 0                           -- NonNullable
871        // → result must be Nullable(i32)
872        let test_array =
873            StructArray::from_fields(&[("value", buffer![0i32, 1, 2].into_array())])?.into_array();
874
875        let nullable_20 =
876            Scalar::from(20i32).cast(&DType::Primitive(PType::I32, Nullability::Nullable))?;
877
878        let expr = nested_case_when(
879            vec![
880                (eq(get_item("value", root()), lit(0i32)), lit(10i32)),
881                (eq(get_item("value", root()), lit(1i32)), lit(nullable_20)),
882            ],
883            Some(lit(0i32)),
884        );
885
886        let result = evaluate_expr(&expr, &test_array);
887        assert!(
888            result.dtype().is_nullable(),
889            "result dtype must be Nullable, got {:?}",
890            result.dtype()
891        );
892        assert_arrays_eq!(
893            result,
894            PrimitiveArray::from_option_iter([Some(10), Some(20), Some(0)]).into_array()
895        );
896        Ok(())
897    }
898
899    #[test]
900    fn test_evaluate_with_literal_condition() {
901        let test_array = buffer![1i32, 2, 3].into_array();
902        let expr = case_when(lit(true), lit(100i32), lit(0i32));
903        let result = evaluate_expr(&expr, &test_array);
904
905        assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array());
906    }
907
908    #[test]
909    fn test_evaluate_with_bool_column_result() {
910        let test_array =
911            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
912                .unwrap()
913                .into_array();
914
915        let expr = case_when(
916            gt(get_item("value", root()), lit(2i32)),
917            lit(true),
918            lit(false),
919        );
920
921        let result = evaluate_expr(&expr, &test_array);
922        assert_arrays_eq!(
923            result,
924            BoolArray::from_iter([false, false, true, true, true]).into_array()
925        );
926    }
927
928    #[test]
929    fn test_evaluate_with_nullable_condition() {
930        let test_array = StructArray::from_fields(&[(
931            "cond",
932            BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)]).into_array(),
933        )])
934        .unwrap()
935        .into_array();
936
937        let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));
938
939        let result = evaluate_expr(&expr, &test_array);
940        assert_arrays_eq!(result, buffer![100i32, 0, 0, 0, 100].into_array());
941    }
942
943    #[test]
944    fn test_evaluate_with_nullable_result_values() {
945        let test_array = StructArray::from_fields(&[
946            ("value", buffer![1i32, 2, 3, 4, 5].into_array()),
947            (
948                "result",
949                PrimitiveArray::from_option_iter([Some(10), None, Some(30), Some(40), Some(50)])
950                    .into_array(),
951            ),
952        ])
953        .unwrap()
954        .into_array();
955
956        let expr = case_when(
957            gt(get_item("value", root()), lit(2i32)),
958            get_item("result", root()),
959            lit(0i32),
960        );
961
962        let result = evaluate_expr(&expr, &test_array);
963        assert_arrays_eq!(
964            result,
965            PrimitiveArray::from_option_iter([Some(0i32), Some(0), Some(30), Some(40), Some(50)])
966                .into_array()
967        );
968    }
969
970    #[test]
971    fn test_evaluate_with_all_null_condition() {
972        let test_array = StructArray::from_fields(&[(
973            "cond",
974            BoolArray::from_iter([None, None, None]).into_array(),
975        )])
976        .unwrap()
977        .into_array();
978
979        let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));
980
981        let result = evaluate_expr(&expr, &test_array);
982        assert_arrays_eq!(result, buffer![0i32, 0, 0].into_array());
983    }
984
985    // ==================== N-ary Evaluate Tests ====================
986
987    #[test]
988    fn test_evaluate_nary_no_else_returns_null() {
989        let test_array =
990            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
991                .unwrap()
992                .into_array();
993
994        // Two conditions, no ELSE — unmatched rows should be NULL
995        let expr = nested_case_when(
996            vec![
997                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
998                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
999            ],
1000            None,
1001        );
1002
1003        let result = evaluate_expr(&expr, &test_array);
1004        assert!(result.dtype().is_nullable());
1005        assert_arrays_eq!(
1006            result,
1007            PrimitiveArray::from_option_iter([Some(10i32), None, Some(30), None, None])
1008                .into_array()
1009        );
1010    }
1011
1012    #[test]
1013    fn test_evaluate_nary_many_conditions() {
1014        let test_array =
1015            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
1016                .unwrap()
1017                .into_array();
1018
1019        // 5 WHEN/THEN pairs: each value maps to its value * 10
1020        let expr = nested_case_when(
1021            vec![
1022                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
1023                (eq(get_item("value", root()), lit(2i32)), lit(20i32)),
1024                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
1025                (eq(get_item("value", root()), lit(4i32)), lit(40i32)),
1026                (eq(get_item("value", root()), lit(5i32)), lit(50i32)),
1027            ],
1028            Some(lit(0i32)),
1029        );
1030
1031        let result = evaluate_expr(&expr, &test_array);
1032        assert_arrays_eq!(result, buffer![10i32, 20, 30, 40, 50].into_array());
1033    }
1034
1035    #[test]
1036    fn test_evaluate_nary_all_false_no_else() {
1037        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
1038            .unwrap()
1039            .into_array();
1040
1041        // All conditions are false, no ELSE — everything should be NULL
1042        let expr = nested_case_when(
1043            vec![
1044                (gt(get_item("value", root()), lit(100i32)), lit(10i32)),
1045                (gt(get_item("value", root()), lit(200i32)), lit(20i32)),
1046            ],
1047            None,
1048        );
1049
1050        let result = evaluate_expr(&expr, &test_array);
1051        assert!(result.dtype().is_nullable());
1052        assert_arrays_eq!(
1053            result,
1054            PrimitiveArray::from_option_iter([None::<i32>, None, None]).into_array()
1055        );
1056    }
1057
1058    #[test]
1059    fn test_evaluate_nary_overlapping_conditions_first_wins() {
1060        let test_array =
1061            StructArray::from_fields(&[("value", buffer![10i32, 20, 30].into_array())])
1062                .unwrap()
1063                .into_array();
1064
1065        // value=10: matches cond1 (>5) and cond2 (>0), first should win
1066        // value=20: matches all three, first should win
1067        // value=30: matches all three, first should win
1068        let expr = nested_case_when(
1069            vec![
1070                (gt(get_item("value", root()), lit(5i32)), lit(1i32)),
1071                (gt(get_item("value", root()), lit(0i32)), lit(2i32)),
1072                (gt(get_item("value", root()), lit(15i32)), lit(3i32)),
1073            ],
1074            Some(lit(0i32)),
1075        );
1076
1077        let result = evaluate_expr(&expr, &test_array);
1078        // First matching condition always wins
1079        assert_arrays_eq!(result, buffer![1i32, 1, 1].into_array());
1080    }
1081
1082    #[test]
1083    fn test_evaluate_nary_early_exit_when_remaining_empty() {
1084        // After branch 0 claims all rows, remaining becomes all_false.
1085        // The loop breaks before evaluating branch 1's condition.
1086        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
1087            .unwrap()
1088            .into_array();
1089
1090        let expr = nested_case_when(
1091            vec![
1092                (gt(get_item("value", root()), lit(0i32)), lit(100i32)),
1093                // Never evaluated due to early exit; 999 must never appear in output.
1094                (gt(get_item("value", root()), lit(0i32)), lit(999i32)),
1095            ],
1096            Some(lit(0i32)),
1097        );
1098
1099        let result = evaluate_expr(&expr, &test_array);
1100        assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array());
1101    }
1102
1103    #[test]
1104    fn test_evaluate_nary_skips_branch_with_empty_effective_mask() {
1105        // Branch 0 claims value=1. Branch 1 targets the same rows but they are already
1106        // matched → effective_mask is all_false → branch 1 is skipped (THEN not used).
1107        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
1108            .unwrap()
1109            .into_array();
1110
1111        let expr = nested_case_when(
1112            vec![
1113                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
1114                // Same condition as branch 0 — all matching rows already claimed → skipped.
1115                // 999 must never appear in output.
1116                (eq(get_item("value", root()), lit(1i32)), lit(999i32)),
1117                (eq(get_item("value", root()), lit(2i32)), lit(20i32)),
1118            ],
1119            Some(lit(0i32)),
1120        );
1121
1122        let result = evaluate_expr(&expr, &test_array);
1123        assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array());
1124    }
1125
1126    #[test]
1127    fn test_evaluate_nary_string_output() -> VortexResult<()> {
1128        // Exercises merge_case_branches with a non-primitive (Utf8) builder.
1129        let test_array =
1130            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4].into_array())])?
1131                .into_array();
1132
1133        // CASE WHEN value > 2 THEN 'high' WHEN value > 0 THEN 'low' ELSE 'none' END
1134        // value=1,2 → 'low' (branch 1 after branch 0 claims 3,4)
1135        // value=3,4 → 'high' (branch 0)
1136        let expr = nested_case_when(
1137            vec![
1138                (gt(get_item("value", root()), lit(2i32)), lit("high")),
1139                (gt(get_item("value", root()), lit(0i32)), lit("low")),
1140            ],
1141            Some(lit("none")),
1142        );
1143
1144        let result = evaluate_expr(&expr, &test_array);
1145        assert_eq!(
1146            result.execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())?,
1147            Scalar::utf8("low", Nullability::NonNullable)
1148        );
1149        assert_eq!(
1150            result.execute_scalar(1, &mut LEGACY_SESSION.create_execution_ctx())?,
1151            Scalar::utf8("low", Nullability::NonNullable)
1152        );
1153        assert_eq!(
1154            result.execute_scalar(2, &mut LEGACY_SESSION.create_execution_ctx())?,
1155            Scalar::utf8("high", Nullability::NonNullable)
1156        );
1157        assert_eq!(
1158            result.execute_scalar(3, &mut LEGACY_SESSION.create_execution_ctx())?,
1159            Scalar::utf8("high", Nullability::NonNullable)
1160        );
1161        Ok(())
1162    }
1163
1164    #[test]
1165    fn test_evaluate_nary_with_nullable_conditions() {
1166        let test_array = StructArray::from_fields(&[
1167            (
1168                "cond1",
1169                BoolArray::from_iter([Some(true), None, Some(false)]).into_array(),
1170            ),
1171            (
1172                "cond2",
1173                BoolArray::from_iter([Some(false), Some(true), None]).into_array(),
1174            ),
1175        ])
1176        .unwrap()
1177        .into_array();
1178
1179        let expr = nested_case_when(
1180            vec![
1181                (get_item("cond1", root()), lit(10i32)),
1182                (get_item("cond2", root()), lit(20i32)),
1183            ],
1184            Some(lit(0i32)),
1185        );
1186
1187        let result = evaluate_expr(&expr, &test_array);
1188        // row 0: cond1=true → 10
1189        // row 1: cond1=NULL(→false), cond2=true → 20
1190        // row 2: cond1=false, cond2=NULL(→false) → else=0
1191        assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array());
1192    }
1193
1194    #[test]
1195    fn test_merge_case_branches_alternating_mask() -> VortexResult<()> {
1196        // Exercises the scalar path: alternating rows produce one slice per row (no runs),
1197        // triggering the per-row cursor path in merge_case_branches.
1198        let n = 100usize;
1199
1200        // Branch 0: even rows → 0, Branch 1: odd rows → 1, Else: never reached.
1201        let branch0_mask = Mask::from_indices(n, (0..n).step_by(2));
1202        let branch1_mask = Mask::from_indices(n, (1..n).step_by(2));
1203
1204        let result = merge_case_branches(
1205            vec![
1206                (
1207                    branch0_mask,
1208                    PrimitiveArray::from_option_iter(vec![Some(0i32); n]).into_array(),
1209                ),
1210                (
1211                    branch1_mask,
1212                    PrimitiveArray::from_option_iter(vec![Some(1i32); n]).into_array(),
1213                ),
1214            ],
1215            PrimitiveArray::from_option_iter(vec![Some(99i32); n]).into_array(),
1216            &mut SESSION.create_execution_ctx(),
1217        )?;
1218
1219        // Even rows → 0, odd rows → 1.
1220        let expected: Vec<Option<i32>> = (0..n)
1221            .map(|v| if v % 2 == 0 { Some(0) } else { Some(1) })
1222            .collect();
1223        assert_arrays_eq!(
1224            result,
1225            PrimitiveArray::from_option_iter(expected).into_array()
1226        );
1227        Ok(())
1228    }
1229}