Skip to main content

vortex_array/scalar_fn/fns/
case_when.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! SQL-style CASE WHEN: evaluates `(condition, value)` pairs in order and returns
5//! the value from the first matching condition (first-match-wins). NULL conditions
6//! are treated as false. If no ELSE clause is provided, unmatched rows produce NULL;
7//! otherwise they get the ELSE value.
8//!
9//! Unlike SQL which coerces all branches to a common supertype, all THEN/ELSE
10//! branches must share the same base dtype (ignoring nullability). The result
11//! nullability is the union of all branches (forced nullable if no ELSE).
12
13use std::fmt;
14use std::fmt::Formatter;
15use std::hash::Hash;
16use std::sync::Arc;
17
18use prost::Message;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21use vortex_mask::AllOr;
22use vortex_mask::Mask;
23use vortex_proto::expr as pb;
24use vortex_session::VortexSession;
25
26use crate::ArrayRef;
27use crate::ExecutionCtx;
28use crate::IntoArray;
29use crate::arrays::BoolArray;
30use crate::arrays::ConstantArray;
31use crate::builders::ArrayBuilder;
32use crate::builders::builder_with_capacity;
33use crate::builtins::ArrayBuiltins;
34use crate::dtype::DType;
35use crate::expr::Expression;
36use crate::scalar::Scalar;
37use crate::scalar_fn::Arity;
38use crate::scalar_fn::ChildName;
39use crate::scalar_fn::ExecutionArgs;
40use crate::scalar_fn::ScalarFnId;
41use crate::scalar_fn::ScalarFnVTable;
42use crate::scalar_fn::fns::zip::zip_impl;
43
/// Options for the n-ary CaseWhen expression.
///
/// Together the two fields determine how many child expressions the function has;
/// see [`CaseWhenOptions::num_children`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CaseWhenOptions {
    /// Number of WHEN/THEN pairs.
    /// Each pair contributes two children: the condition and its value.
    pub num_when_then_pairs: u32,
    /// Whether an ELSE clause is present.
    /// If false, unmatched rows return NULL.
    pub has_else: bool,
}
53
54impl CaseWhenOptions {
55    /// Total number of child expressions: 2 per WHEN/THEN pair, plus 1 if ELSE is present.
56    pub fn num_children(&self) -> usize {
57        self.num_when_then_pairs as usize * 2 + usize::from(self.has_else)
58    }
59}
60
61impl fmt::Display for CaseWhenOptions {
62    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
63        write!(
64            f,
65            "case_when(pairs={}, else={})",
66            self.num_when_then_pairs, self.has_else
67        )
68    }
69}
70
/// An n-ary CASE WHEN expression.
///
/// Children are in order: `[when_0, then_0, when_1, then_1, ..., else?]`.
///
/// This is a stateless unit struct; the per-expression configuration lives in
/// [`CaseWhenOptions`], which is the `Options` type of its [`ScalarFnVTable`] impl.
#[derive(Clone)]
pub struct CaseWhen;
76
77impl ScalarFnVTable for CaseWhen {
78    type Options = CaseWhenOptions;
79
80    fn id(&self) -> ScalarFnId {
81        ScalarFnId::from("vortex.case_when")
82    }
83
    /// Serialization is currently disabled: this always returns an error.
    ///
    /// The commented-out lines below sketch the intended encoding: a single
    /// `num_children` count, which is the field `deserialize` reads back. They are kept
    /// as a reference, presumably to be re-enabled once the expression is stabilized.
    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
        // let num_children = options.num_when_then_pairs * 2 + u32::from(options.has_else);
        // Ok(Some(pb::CaseWhenOpts { num_children }.encode_to_vec()))
        // stabilize the expr
        vortex_bail!("cannot serialize")
    }
90
91    fn deserialize(
92        &self,
93        metadata: &[u8],
94        _session: &VortexSession,
95    ) -> VortexResult<Self::Options> {
96        let opts = pb::CaseWhenOpts::decode(metadata)?;
97        if opts.num_children < 2 {
98            vortex_bail!(
99                "CaseWhen expects at least 2 children, got {}",
100                opts.num_children
101            );
102        }
103        Ok(CaseWhenOptions {
104            num_when_then_pairs: opts.num_children / 2,
105            has_else: opts.num_children % 2 == 1,
106        })
107    }
108
109    fn arity(&self, options: &Self::Options) -> Arity {
110        Arity::Exact(options.num_children())
111    }
112
113    fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName {
114        let num_pair_children = options.num_when_then_pairs as usize * 2;
115        if child_idx < num_pair_children {
116            let pair_idx = child_idx / 2;
117            if child_idx.is_multiple_of(2) {
118                ChildName::from(Arc::from(format!("when_{pair_idx}")))
119            } else {
120                ChildName::from(Arc::from(format!("then_{pair_idx}")))
121            }
122        } else if options.has_else && child_idx == num_pair_children {
123            ChildName::from("else")
124        } else {
125            unreachable!("Invalid child index {} for CaseWhen", child_idx)
126        }
127    }
128
129    fn fmt_sql(
130        &self,
131        options: &Self::Options,
132        expr: &Expression,
133        f: &mut Formatter<'_>,
134    ) -> fmt::Result {
135        write!(f, "CASE")?;
136        for i in 0..options.num_when_then_pairs as usize {
137            write!(
138                f,
139                " WHEN {} THEN {}",
140                expr.child(i * 2),
141                expr.child(i * 2 + 1)
142            )?;
143        }
144        if options.has_else {
145            let else_idx = options.num_when_then_pairs as usize * 2;
146            write!(f, " ELSE {}", expr.child(else_idx))?;
147        }
148        write!(f, " END")
149    }
150
151    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
152        if options.num_when_then_pairs == 0 {
153            vortex_bail!("CaseWhen must have at least one WHEN/THEN pair");
154        }
155
156        let expected_len = options.num_children();
157        if arg_dtypes.len() != expected_len {
158            vortex_bail!(
159                "CaseWhen expects {expected_len} argument dtypes, got {}",
160                arg_dtypes.len()
161            );
162        }
163
164        // Unlike SQL which coerces all branches to a common supertype, we require
165        // all THEN/ELSE branches to have the same base dtype (ignoring nullability).
166        // The result nullability is the union of all branches.
167        let first_then = &arg_dtypes[1];
168        let mut result_dtype = first_then.clone();
169
170        for i in 1..options.num_when_then_pairs as usize {
171            let then_i = &arg_dtypes[i * 2 + 1];
172            if !first_then.eq_ignore_nullability(then_i) {
173                vortex_bail!(
174                    "CaseWhen THEN dtypes must match (ignoring nullability), got {} and {}",
175                    first_then,
176                    then_i
177                );
178            }
179            result_dtype = result_dtype.union_nullability(then_i.nullability());
180        }
181
182        if options.has_else {
183            let else_dtype = &arg_dtypes[options.num_when_then_pairs as usize * 2];
184            if !result_dtype.eq_ignore_nullability(else_dtype) {
185                vortex_bail!(
186                    "CaseWhen THEN and ELSE dtypes must match (ignoring nullability), got {} and {}",
187                    first_then,
188                    else_dtype
189                );
190            }
191            result_dtype = result_dtype.union_nullability(else_dtype.nullability());
192        } else {
193            // No ELSE means unmatched rows are NULL
194            result_dtype = result_dtype.as_nullable();
195        }
196
197        Ok(result_dtype)
198    }
199
200    fn execute(
201        &self,
202        options: &Self::Options,
203        args: &dyn ExecutionArgs,
204        ctx: &mut ExecutionCtx,
205    ) -> VortexResult<ArrayRef> {
206        // Inspired by https://datafusion.apache.org/blog/2026/02/02/datafusion_case/
207        //
208        // TODO: shrink input to `remaining` rows between WHEN iterations (batch reduction).
209        // TODO: project to only referenced columns before batch reduction (column projection).
210        // TODO: evaluate THEN/ELSE on compact matching/non-matching rows and scatter-merge the results.
211        // TODO: for constant WHEN/THEN values, compile to a hash table for a single-pass lookup.
212        let row_count = args.row_count();
213        let num_pairs = options.num_when_then_pairs as usize;
214
215        let mut remaining = Mask::new_true(row_count);
216        let mut branches: Vec<(Mask, ArrayRef)> = Vec::with_capacity(num_pairs);
217
218        for i in 0..num_pairs {
219            if remaining.all_false() {
220                break;
221            }
222
223            let condition = args.get(i * 2)?;
224            let cond_bool = condition.execute::<BoolArray>(ctx)?;
225            let cond_mask = cond_bool.to_mask_fill_null_false();
226            let effective_mask = &remaining & &cond_mask;
227
228            if effective_mask.all_false() {
229                continue;
230            }
231
232            let then_value = args.get(i * 2 + 1)?;
233            remaining = remaining.bitand_not(&cond_mask);
234            branches.push((effective_mask, then_value));
235        }
236
237        let else_value: ArrayRef = if options.has_else {
238            args.get(num_pairs * 2)?
239        } else {
240            let then_dtype = args.get(1)?.dtype().as_nullable();
241            ConstantArray::new(Scalar::null(then_dtype), row_count).into_array()
242        };
243
244        if branches.is_empty() {
245            return Ok(else_value);
246        }
247
248        merge_case_branches(branches, else_value)
249    }
250
    /// Always reports the function as null-sensitive.
    ///
    /// During execution NULL conditions are treated as false rather than propagated, and
    /// unmatched rows become NULL when there is no ELSE clause.
    fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
        true
    }
254
    /// Always reports the function as infallible.
    ///
    /// NOTE(review): `execute` still propagates errors from child execution and merging;
    /// confirm this matches what the trait means by "fallible".
    fn is_fallible(&self, _options: &Self::Options) -> bool {
        false
    }
258}
259
/// Average run length at which slicing + `extend_from_array` becomes cheaper than `scalar_at`.
/// Measured empirically via benchmarks.
///
/// `merge_case_branches` takes the row-by-row path when the number of spans exceeds the
/// output length divided by this value, and the run-by-run path otherwise.
const SLICE_CROSSOVER_RUN_LEN: usize = 4;
263
264/// Merges disjoint `(mask, then_value)` branch pairs with an `else_value` into a single array.
265///
266/// Branch masks are guaranteed disjoint by the remaining-row tracking in [`CaseWhen::execute`].
267fn merge_case_branches(
268    branches: Vec<(Mask, ArrayRef)>,
269    else_value: ArrayRef,
270) -> VortexResult<ArrayRef> {
271    if branches.len() == 1 {
272        let (mask, then_value) = &branches[0];
273        return zip_impl(then_value, &else_value, mask);
274    }
275
276    let output_nullability = branches
277        .iter()
278        .fold(else_value.dtype().nullability(), |acc, (_, arr)| {
279            acc | arr.dtype().nullability()
280        });
281    let output_dtype = else_value.dtype().with_nullability(output_nullability);
282    let branch_arrays: Vec<&ArrayRef> = branches.iter().map(|(_, arr)| arr).collect();
283
284    let mut spans: Vec<(usize, usize, usize)> = Vec::new();
285    for (branch_idx, (mask, _)) in branches.iter().enumerate() {
286        match mask.slices() {
287            AllOr::All => return branch_arrays[branch_idx].cast(output_dtype),
288            AllOr::None => {}
289            AllOr::Some(slices) => {
290                for &(start, end) in slices {
291                    spans.push((start, end, branch_idx));
292                }
293            }
294        }
295    }
296    spans.sort_unstable_by_key(|&(start, ..)| start);
297
298    if spans.is_empty() {
299        return else_value.cast(output_dtype);
300    }
301
302    let builder = builder_with_capacity(&output_dtype, else_value.len());
303
304    let fragmented = spans.len() > else_value.len() / SLICE_CROSSOVER_RUN_LEN;
305    if fragmented {
306        merge_row_by_row(&branch_arrays, &else_value, &spans, &output_dtype, builder)
307    } else {
308        merge_run_by_run(&branch_arrays, &else_value, &spans, &output_dtype, builder)
309    }
310}
311
312/// Iterates spans directly, emitting one `scalar_at` per row.
313/// Zero per-run allocations; preferred for fragmented masks (avg run < [`SLICE_CROSSOVER_RUN_LEN`]).
314fn merge_row_by_row(
315    branch_arrays: &[&ArrayRef],
316    else_value: &ArrayRef,
317    spans: &[(usize, usize, usize)],
318    output_dtype: &DType,
319    mut builder: Box<dyn ArrayBuilder>,
320) -> VortexResult<ArrayRef> {
321    let mut pos = 0;
322    for &(start, end, branch_idx) in spans {
323        for row in pos..start {
324            let scalar = else_value.scalar_at(row)?;
325            builder.append_scalar(&scalar.cast(output_dtype)?)?;
326        }
327        for row in start..end {
328            let scalar = branch_arrays[branch_idx].scalar_at(row)?;
329            builder.append_scalar(&scalar.cast(output_dtype)?)?;
330        }
331        pos = end;
332    }
333    for row in pos..else_value.len() {
334        let scalar = else_value.scalar_at(row)?;
335        builder.append_scalar(&scalar.cast(output_dtype)?)?;
336    }
337
338    Ok(builder.finish())
339}
340
341/// Bulk-copies each span via `slice()` + `extend_from_array`.
342/// Preferred when runs are long enough that memcpy dominates over per-slice allocation cost.
343/// Lazy cast via `arr.cast(output_dtype)` is executed once per span as a block.
344fn merge_run_by_run(
345    branch_arrays: &[&ArrayRef],
346    else_value: &ArrayRef,
347    spans: &[(usize, usize, usize)],
348    output_dtype: &DType,
349    mut builder: Box<dyn ArrayBuilder>,
350) -> VortexResult<ArrayRef> {
351    let else_value = else_value.cast(output_dtype.clone())?;
352    let len = else_value.len();
353    for (start, end, branch_idx) in spans {
354        if builder.len() < *start {
355            builder.extend_from_array(&else_value.slice(builder.len()..*start)?);
356        }
357        builder.extend_from_array(
358            &branch_arrays[*branch_idx]
359                .cast(output_dtype.clone())?
360                .slice(*start..*end)?,
361        );
362    }
363    if builder.len() < len {
364        builder.extend_from_array(&else_value.slice(builder.len()..len)?);
365    }
366
367    Ok(builder.finish())
368}
369
370#[cfg(test)]
371mod tests {
372    use std::sync::LazyLock;
373
374    use vortex_buffer::buffer;
375    use vortex_error::VortexExpect as _;
376    use vortex_session::VortexSession;
377
378    use super::*;
379    use crate::Canonical;
380    use crate::IntoArray;
381    use crate::VortexSessionExecute as _;
382    use crate::arrays::BoolArray;
383    use crate::arrays::PrimitiveArray;
384    use crate::arrays::StructArray;
385    use crate::assert_arrays_eq;
386    use crate::dtype::DType;
387    use crate::dtype::Nullability;
388    use crate::dtype::PType;
389    use crate::expr::case_when;
390    use crate::expr::case_when_no_else;
391    use crate::expr::col;
392    use crate::expr::eq;
393    use crate::expr::get_item;
394    use crate::expr::gt;
395    use crate::expr::lit;
396    use crate::expr::nested_case_when;
397    use crate::expr::root;
398    use crate::expr::test_harness;
399    use crate::scalar::Scalar;
400    use crate::session::ArraySession;
401
    /// Shared, lazily-initialized session used by the tests to create execution contexts.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());
404
405    /// Helper to evaluate an expression using the apply+execute pattern
406    fn evaluate_expr(expr: &Expression, array: &ArrayRef) -> ArrayRef {
407        let mut ctx = SESSION.create_execution_ctx();
408        array
409            .apply(expr)
410            .unwrap()
411            .execute::<Canonical>(&mut ctx)
412            .unwrap()
413            .into_array()
414    }
415
416    // ==================== Serialization Tests ====================
417
418    #[test]
419    #[should_panic(expected = "cannot serialize")]
420    fn test_serialization_roundtrip() {
421        let options = CaseWhenOptions {
422            num_when_then_pairs: 1,
423            has_else: true,
424        };
425        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
426        let deserialized = CaseWhen
427            .deserialize(&serialized, &VortexSession::empty())
428            .unwrap();
429        assert_eq!(options, deserialized);
430    }
431
432    #[test]
433    #[should_panic(expected = "cannot serialize")]
434    fn test_serialization_no_else() {
435        let options = CaseWhenOptions {
436            num_when_then_pairs: 1,
437            has_else: false,
438        };
439        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
440        let deserialized = CaseWhen
441            .deserialize(&serialized, &VortexSession::empty())
442            .unwrap();
443        assert_eq!(options, deserialized);
444    }
445
446    // ==================== Display Tests ====================
447
448    #[test]
449    fn test_display_with_else() {
450        let expr = case_when(gt(col("value"), lit(0i32)), lit(100i32), lit(0i32));
451        let display = format!("{}", expr);
452        assert!(display.contains("CASE"));
453        assert!(display.contains("WHEN"));
454        assert!(display.contains("THEN"));
455        assert!(display.contains("ELSE"));
456        assert!(display.contains("END"));
457    }
458
459    #[test]
460    fn test_display_no_else() {
461        let expr = case_when_no_else(gt(col("value"), lit(0i32)), lit(100i32));
462        let display = format!("{}", expr);
463        assert!(display.contains("CASE"));
464        assert!(display.contains("WHEN"));
465        assert!(display.contains("THEN"));
466        assert!(!display.contains("ELSE"));
467        assert!(display.contains("END"));
468    }
469
470    #[test]
471    fn test_display_nested_nary() {
472        // CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'medium' ELSE 'low' END
473        let expr = nested_case_when(
474            vec![
475                (gt(col("x"), lit(10i32)), lit("high")),
476                (gt(col("x"), lit(5i32)), lit("medium")),
477            ],
478            Some(lit("low")),
479        );
480        let display = format!("{}", expr);
481        assert_eq!(display.matches("CASE").count(), 1);
482        assert_eq!(display.matches("WHEN").count(), 2);
483        assert_eq!(display.matches("THEN").count(), 2);
484    }
485
486    // ==================== DType Tests ====================
487
488    #[test]
489    fn test_return_dtype_with_else() {
490        let expr = case_when(lit(true), lit(100i32), lit(0i32));
491        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
492        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
493        assert_eq!(
494            result_dtype,
495            DType::Primitive(PType::I32, Nullability::NonNullable)
496        );
497    }
498
499    #[test]
500    fn test_return_dtype_with_nullable_else() {
501        let expr = case_when(
502            lit(true),
503            lit(100i32),
504            lit(Scalar::null(DType::Primitive(
505                PType::I32,
506                Nullability::Nullable,
507            ))),
508        );
509        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
510        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
511        assert_eq!(
512            result_dtype,
513            DType::Primitive(PType::I32, Nullability::Nullable)
514        );
515    }
516
517    #[test]
518    fn test_return_dtype_without_else_is_nullable() {
519        let expr = case_when_no_else(lit(true), lit(100i32));
520        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
521        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
522        assert_eq!(
523            result_dtype,
524            DType::Primitive(PType::I32, Nullability::Nullable)
525        );
526    }
527
528    #[test]
529    fn test_return_dtype_with_struct_input() {
530        let dtype = test_harness::struct_dtype();
531        let expr = case_when(
532            gt(get_item("col1", root()), lit(10u16)),
533            lit(100i32),
534            lit(0i32),
535        );
536        let result_dtype = expr.return_dtype(&dtype).unwrap();
537        assert_eq!(
538            result_dtype,
539            DType::Primitive(PType::I32, Nullability::NonNullable)
540        );
541    }
542
543    #[test]
544    fn test_return_dtype_mismatched_then_else_errors() {
545        let expr = case_when(lit(true), lit(100i32), lit("zero"));
546        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
547        let err = expr.return_dtype(&input_dtype).unwrap_err();
548        assert!(
549            err.to_string()
550                .contains("THEN and ELSE dtypes must match (ignoring nullability)")
551        );
552    }
553
554    // ==================== Arity Tests ====================
555
556    #[test]
557    fn test_arity_with_else() {
558        let options = CaseWhenOptions {
559            num_when_then_pairs: 1,
560            has_else: true,
561        };
562        assert_eq!(CaseWhen.arity(&options), Arity::Exact(3));
563    }
564
565    #[test]
566    fn test_arity_without_else() {
567        let options = CaseWhenOptions {
568            num_when_then_pairs: 1,
569            has_else: false,
570        };
571        assert_eq!(CaseWhen.arity(&options), Arity::Exact(2));
572    }
573
574    // ==================== Child Name Tests ====================
575
576    #[test]
577    fn test_child_names() {
578        let options = CaseWhenOptions {
579            num_when_then_pairs: 1,
580            has_else: true,
581        };
582        assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0");
583        assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0");
584        assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "else");
585    }
586
587    // ==================== N-ary Serialization Tests ====================
588
589    #[test]
590    #[should_panic(expected = "cannot serialize")]
591    fn test_serialization_roundtrip_nary() {
592        let options = CaseWhenOptions {
593            num_when_then_pairs: 3,
594            has_else: true,
595        };
596        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
597        let deserialized = CaseWhen
598            .deserialize(&serialized, &VortexSession::empty())
599            .unwrap();
600        assert_eq!(options, deserialized);
601    }
602
603    #[test]
604    #[should_panic(expected = "cannot serialize")]
605    fn test_serialization_roundtrip_nary_no_else() {
606        let options = CaseWhenOptions {
607            num_when_then_pairs: 4,
608            has_else: false,
609        };
610        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
611        let deserialized = CaseWhen
612            .deserialize(&serialized, &VortexSession::empty())
613            .unwrap();
614        assert_eq!(options, deserialized);
615    }
616
617    // ==================== N-ary Arity Tests ====================
618
619    #[test]
620    fn test_arity_nary_with_else() {
621        let options = CaseWhenOptions {
622            num_when_then_pairs: 3,
623            has_else: true,
624        };
625        // 3 pairs * 2 children + 1 else = 7
626        assert_eq!(CaseWhen.arity(&options), Arity::Exact(7));
627    }
628
629    #[test]
630    fn test_arity_nary_without_else() {
631        let options = CaseWhenOptions {
632            num_when_then_pairs: 3,
633            has_else: false,
634        };
635        // 3 pairs * 2 children = 6
636        assert_eq!(CaseWhen.arity(&options), Arity::Exact(6));
637    }
638
639    // ==================== N-ary Child Name Tests ====================
640
641    #[test]
642    fn test_child_names_nary() {
643        let options = CaseWhenOptions {
644            num_when_then_pairs: 3,
645            has_else: true,
646        };
647        assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0");
648        assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0");
649        assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "when_1");
650        assert_eq!(CaseWhen.child_name(&options, 3).to_string(), "then_1");
651        assert_eq!(CaseWhen.child_name(&options, 4).to_string(), "when_2");
652        assert_eq!(CaseWhen.child_name(&options, 5).to_string(), "then_2");
653        assert_eq!(CaseWhen.child_name(&options, 6).to_string(), "else");
654    }
655
656    // ==================== N-ary DType Tests ====================
657
658    #[test]
659    fn test_return_dtype_nary_mismatched_then_types_errors() {
660        let expr = nested_case_when(
661            vec![(lit(true), lit(100i32)), (lit(false), lit("oops"))],
662            Some(lit(0i32)),
663        );
664        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
665        let err = expr.return_dtype(&input_dtype).unwrap_err();
666        assert!(err.to_string().contains("THEN dtypes must match"));
667    }
668
669    #[test]
670    fn test_return_dtype_nary_mixed_nullability() {
671        // When some THEN branches are nullable and others are not,
672        // the result should be nullable (union of nullabilities).
673        let non_null_then = lit(100i32);
674        let nullable_then = lit(Scalar::null(DType::Primitive(
675            PType::I32,
676            Nullability::Nullable,
677        )));
678        let expr = nested_case_when(
679            vec![(lit(true), non_null_then), (lit(false), nullable_then)],
680            Some(lit(0i32)),
681        );
682        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
683        let result = expr.return_dtype(&input_dtype).unwrap();
684        assert_eq!(result, DType::Primitive(PType::I32, Nullability::Nullable));
685    }
686
687    #[test]
688    fn test_return_dtype_nary_no_else_is_nullable() {
689        let expr = nested_case_when(
690            vec![(lit(true), lit(10i32)), (lit(false), lit(20i32))],
691            None,
692        );
693        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
694        let result = expr.return_dtype(&input_dtype).unwrap();
695        assert_eq!(result, DType::Primitive(PType::I32, Nullability::Nullable));
696    }
697
698    // ==================== Expression Manipulation Tests ====================
699
700    #[test]
701    fn test_replace_children() {
702        let expr = case_when(lit(true), lit(1i32), lit(0i32));
703        expr.with_children([lit(false), lit(2i32), lit(3i32)])
704            .vortex_expect("operation should succeed in test");
705    }
706
707    // ==================== Evaluate Tests ====================
708
709    #[test]
710    fn test_evaluate_simple_condition() {
711        let test_array =
712            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
713                .unwrap()
714                .into_array();
715
716        let expr = case_when(
717            gt(get_item("value", root()), lit(2i32)),
718            lit(100i32),
719            lit(0i32),
720        );
721
722        let result = evaluate_expr(&expr, &test_array);
723        assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array());
724    }
725
726    #[test]
727    fn test_evaluate_nary_multiple_conditions() {
728        // Test n-ary via nested_case_when
729        let test_array =
730            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
731                .unwrap()
732                .into_array();
733
734        let expr = nested_case_when(
735            vec![
736                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
737                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
738            ],
739            Some(lit(0i32)),
740        );
741
742        let result = evaluate_expr(&expr, &test_array);
743        assert_arrays_eq!(result, buffer![10i32, 0, 30, 0, 0].into_array());
744    }
745
746    #[test]
747    fn test_evaluate_nary_first_match_wins() {
748        let test_array =
749            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
750                .unwrap()
751                .into_array();
752
753        // Both conditions match for values > 3, but first one wins
754        let expr = nested_case_when(
755            vec![
756                (gt(get_item("value", root()), lit(2i32)), lit(100i32)),
757                (gt(get_item("value", root()), lit(3i32)), lit(200i32)),
758            ],
759            Some(lit(0i32)),
760        );
761
762        let result = evaluate_expr(&expr, &test_array);
763        assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array());
764    }
765
766    #[test]
767    fn test_evaluate_no_else_returns_null() {
768        let test_array =
769            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
770                .unwrap()
771                .into_array();
772
773        let expr = case_when_no_else(gt(get_item("value", root()), lit(3i32)), lit(100i32));
774
775        let result = evaluate_expr(&expr, &test_array);
776        assert!(result.dtype().is_nullable());
777        assert_arrays_eq!(
778            result,
779            PrimitiveArray::from_option_iter([None::<i32>, None, None, Some(100), Some(100)])
780                .into_array()
781        );
782    }
783
784    #[test]
785    fn test_evaluate_all_conditions_false() {
786        let test_array =
787            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
788                .unwrap()
789                .into_array();
790
791        let expr = case_when(
792            gt(get_item("value", root()), lit(100i32)),
793            lit(1i32),
794            lit(0i32),
795        );
796
797        let result = evaluate_expr(&expr, &test_array);
798        assert_arrays_eq!(result, buffer![0i32, 0, 0, 0, 0].into_array());
799    }
800
801    #[test]
802    fn test_evaluate_all_conditions_true() {
803        let test_array =
804            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
805                .unwrap()
806                .into_array();
807
808        let expr = case_when(
809            gt(get_item("value", root()), lit(0i32)),
810            lit(100i32),
811            lit(0i32),
812        );
813
814        let result = evaluate_expr(&expr, &test_array);
815        assert_arrays_eq!(result, buffer![100i32, 100, 100, 100, 100].into_array());
816    }
817
818    #[test]
819    fn test_evaluate_all_true_no_else_returns_correct_dtype() {
820        // CASE WHEN value > 0 THEN 100 END — condition is always true, no ELSE.
821        // Result must be Nullable because the implicit ELSE is NULL.
822        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
823            .unwrap()
824            .into_array();
825
826        let expr = case_when_no_else(gt(get_item("value", root()), lit(0i32)), lit(100i32));
827
828        let result = evaluate_expr(&expr, &test_array);
829        assert!(
830            result.dtype().is_nullable(),
831            "result dtype must be Nullable, got {:?}",
832            result.dtype()
833        );
834        assert_arrays_eq!(
835            result,
836            PrimitiveArray::from_option_iter([Some(100i32), Some(100), Some(100)]).into_array()
837        );
838    }
839
/// A nullable THEN in a later branch must make the whole result nullable, even
/// when the first THEN and the ELSE are both non-nullable.
#[test]
fn test_merge_case_branches_widens_nullability_of_later_branch() -> VortexResult<()> {
    // CASE WHEN value = 0 THEN 10           -- NonNullable
    //      WHEN value = 1 THEN nullable(20) -- Nullable
    //      ELSE 0                           -- NonNullable
    // => result dtype must be Nullable(i32)
    let values = buffer![0i32, 1, 2].into_array();
    let input = StructArray::from_fields(&[("value", values)])
        .unwrap()
        .into_array();

    // The literal 20, re-typed as a nullable i32 scalar.
    let nullable_i32 = DType::Primitive(PType::I32, Nullability::Nullable);
    let twenty = Scalar::from(20i32).cast(&nullable_i32)?;

    let value_is = |target: i32| eq(get_item("value", root()), lit(target));
    let expr = nested_case_when(
        vec![(value_is(0), lit(10i32)), (value_is(1), lit(twenty))],
        Some(lit(0i32)),
    );

    let output = evaluate_expr(&expr, &input);
    assert!(
        output.dtype().is_nullable(),
        "result dtype must be Nullable, got {:?}",
        output.dtype()
    );

    let expected = PrimitiveArray::from_option_iter([Some(10), Some(20), Some(0)]).into_array();
    assert_arrays_eq!(output, expected);
    Ok(())
}
876
/// A constant `true` condition selects the THEN value for every row.
#[test]
fn test_evaluate_with_literal_condition() {
    let input = buffer![1i32, 2, 3].into_array();

    // CASE WHEN true THEN 100 ELSE 0 END
    let always_then = case_when(lit(true), lit(100i32), lit(0i32));
    let output = evaluate_expr(&always_then, &input);

    let expected = buffer![100i32, 100, 100].into_array();
    assert_arrays_eq!(output, expected);
}
885
/// THEN/ELSE branches may be boolean literals, producing a boolean result array.
#[test]
fn test_evaluate_with_bool_column_result() {
    let values = buffer![1i32, 2, 3, 4, 5].into_array();
    let input = StructArray::from_fields(&[("value", values)])
        .unwrap()
        .into_array();

    // CASE WHEN value > 2 THEN true ELSE false END
    let above_two = gt(get_item("value", root()), lit(2i32));
    let expr = case_when(above_two, lit(true), lit(false));

    let output = evaluate_expr(&expr, &input);
    let expected = BoolArray::from_iter([false, false, true, true, true]).into_array();
    assert_arrays_eq!(output, expected);
}
905
/// NULL entries in a nullable condition column fall through to ELSE, like `false`.
#[test]
fn test_evaluate_with_nullable_condition() {
    let cond = BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)]);
    let input = StructArray::from_fields(&[("cond", cond.into_array())])
        .unwrap()
        .into_array();

    // CASE WHEN cond THEN 100 ELSE 0 END
    let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));
    let output = evaluate_expr(&expr, &input);

    // Only rows 0 and 4 hold `true`; NULL and `false` rows take the ELSE value.
    let expected = buffer![100i32, 0, 0, 0, 100].into_array();
    assert_arrays_eq!(output, expected);
}
920
/// The THEN branch reads from a nullable column while the ELSE is a plain literal.
#[test]
fn test_evaluate_with_nullable_result_values() {
    let values = buffer![1i32, 2, 3, 4, 5].into_array();
    let results =
        PrimitiveArray::from_option_iter([Some(10), None, Some(30), Some(40), Some(50)])
            .into_array();
    let input = StructArray::from_fields(&[("value", values), ("result", results)])
        .unwrap()
        .into_array();

    // CASE WHEN value > 2 THEN result ELSE 0 END
    let above_two = gt(get_item("value", root()), lit(2i32));
    let expr = case_when(above_two, get_item("result", root()), lit(0i32));

    let output = evaluate_expr(&expr, &input);

    // Rows 0 and 1 fail the condition and take the ELSE literal; the NULL sitting
    // at row 1 of `result` is therefore never selected.
    let expected =
        PrimitiveArray::from_option_iter([Some(0i32), Some(0), Some(30), Some(40), Some(50)])
            .into_array();
    assert_arrays_eq!(output, expected);
}
947
/// A condition column made up entirely of NULLs sends every row to ELSE.
#[test]
fn test_evaluate_with_all_null_condition() {
    let cond = BoolArray::from_iter([None, None, None]).into_array();
    let input = StructArray::from_fields(&[("cond", cond)])
        .unwrap()
        .into_array();

    // CASE WHEN cond THEN 100 ELSE 0 END
    let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));
    let output = evaluate_expr(&expr, &input);

    let expected = buffer![0i32, 0, 0].into_array();
    assert_arrays_eq!(output, expected);
}
962
963    // ==================== N-ary Evaluate Tests ====================
964
/// Two WHEN/THEN pairs without ELSE: rows matched by neither condition become NULL.
#[test]
fn test_evaluate_nary_no_else_returns_null() {
    let values = buffer![1i32, 2, 3, 4, 5].into_array();
    let input = StructArray::from_fields(&[("value", values)])
        .unwrap()
        .into_array();

    // CASE WHEN value = 1 THEN 10 WHEN value = 3 THEN 30 END
    let value_is = |target: i32| eq(get_item("value", root()), lit(target));
    let expr = nested_case_when(
        vec![(value_is(1), lit(10i32)), (value_is(3), lit(30i32))],
        None,
    );

    let output = evaluate_expr(&expr, &input);
    assert!(output.dtype().is_nullable());

    // Only values 1 and 3 are matched; every other row is NULL.
    let expected =
        PrimitiveArray::from_option_iter([Some(10i32), None, Some(30), None, None]).into_array();
    assert_arrays_eq!(output, expected);
}
989
/// Five WHEN/THEN pairs, each mapping one distinct input value to ten times itself.
#[test]
fn test_evaluate_nary_many_conditions() {
    let values = buffer![1i32, 2, 3, 4, 5].into_array();
    let input = StructArray::from_fields(&[("value", values)])
        .unwrap()
        .into_array();

    // WHEN value = k THEN k * 10, for k = 1..=5, in ascending order.
    let branches: Vec<_> = (1i32..=5)
        .map(|k| (eq(get_item("value", root()), lit(k)), lit(k * 10)))
        .collect();
    let expr = nested_case_when(branches, Some(lit(0i32)));

    let output = evaluate_expr(&expr, &input);

    // Every row is claimed by exactly one branch, so the ELSE value never appears.
    let expected = buffer![10i32, 20, 30, 40, 50].into_array();
    assert_arrays_eq!(output, expected);
}
1012
/// No condition matches any row and there is no ELSE: the output is all NULL.
#[test]
fn test_evaluate_nary_all_false_no_else() {
    let values = buffer![1i32, 2, 3].into_array();
    let input = StructArray::from_fields(&[("value", values)])
        .unwrap()
        .into_array();

    // Both thresholds lie far above every value in the column.
    let value_above = |threshold: i32| gt(get_item("value", root()), lit(threshold));
    let expr = nested_case_when(
        vec![
            (value_above(100), lit(10i32)),
            (value_above(200), lit(20i32)),
        ],
        None,
    );

    let output = evaluate_expr(&expr, &input);
    assert!(output.dtype().is_nullable());

    let expected = PrimitiveArray::from_option_iter([None::<i32>, None, None]).into_array();
    assert_arrays_eq!(output, expected);
}
1035
/// When several conditions match the same row, the earliest branch supplies the value.
#[test]
fn test_evaluate_nary_overlapping_conditions_first_wins() {
    let values = buffer![10i32, 20, 30].into_array();
    let input = StructArray::from_fields(&[("value", values)])
        .unwrap()
        .into_array();

    // value=10 satisfies `> 5` and `> 0`; value=20 and value=30 satisfy all three.
    // In every case the first branch (`> 5` => 1) is the one that must win.
    let value_above = |threshold: i32| gt(get_item("value", root()), lit(threshold));
    let expr = nested_case_when(
        vec![
            (value_above(5), lit(1i32)),
            (value_above(0), lit(2i32)),
            (value_above(15), lit(3i32)),
        ],
        Some(lit(0i32)),
    );

    let output = evaluate_expr(&expr, &input);
    let expected = buffer![1i32, 1, 1].into_array();
    assert_arrays_eq!(output, expected);
}
1059
/// Branch 0 claims every row, leaving nothing for branch 1.
///
/// Intended to cover the early-exit path, where no rows remain unmatched after
/// the first branch and later branches are not evaluated.
#[test]
fn test_evaluate_nary_early_exit_when_remaining_empty() {
    let values = buffer![1i32, 2, 3].into_array();
    let input = StructArray::from_fields(&[("value", values)])
        .unwrap()
        .into_array();

    let is_positive = || gt(get_item("value", root()), lit(0i32));
    let expr = nested_case_when(
        vec![
            (is_positive(), lit(100i32)),
            // Identical condition, but all rows are already taken: 999 must never
            // show up in the output.
            (is_positive(), lit(999i32)),
        ],
        Some(lit(0i32)),
    );

    let output = evaluate_expr(&expr, &input);
    let expected = buffer![100i32, 100, 100].into_array();
    assert_arrays_eq!(output, expected);
}
1080
/// A branch whose matching rows were all claimed earlier contributes nothing.
///
/// Branch 0 takes value=1; branch 1 repeats the same condition, so none of its
/// rows are still unmatched and its THEN value must not be used.
#[test]
fn test_evaluate_nary_skips_branch_with_empty_effective_mask() {
    let values = buffer![1i32, 2, 3].into_array();
    let input = StructArray::from_fields(&[("value", values)])
        .unwrap()
        .into_array();

    let value_is = |target: i32| eq(get_item("value", root()), lit(target));
    let expr = nested_case_when(
        vec![
            (value_is(1), lit(10i32)),
            // Duplicate of branch 0's condition: 999 must never appear in the output.
            (value_is(1), lit(999i32)),
            (value_is(2), lit(20i32)),
        ],
        Some(lit(0i32)),
    );

    let output = evaluate_expr(&expr, &input);

    // value=1 => 10 (branch 0), value=2 => 20 (branch 2), value=3 => ELSE.
    let expected = buffer![10i32, 20, 0].into_array();
    assert_arrays_eq!(output, expected);
}
1103
/// N-ary CASE WHEN producing Utf8 values rather than primitives.
#[test]
fn test_evaluate_nary_string_output() -> VortexResult<()> {
    let values = buffer![1i32, 2, 3, 4].into_array();
    let input = StructArray::from_fields(&[("value", values)])
        .unwrap()
        .into_array();

    // CASE WHEN value > 2 THEN 'high' WHEN value > 0 THEN 'low' ELSE 'none' END
    let value_above = |threshold: i32| gt(get_item("value", root()), lit(threshold));
    let expr = nested_case_when(
        vec![
            (value_above(2), lit("high")),
            (value_above(0), lit("low")),
        ],
        Some(lit("none")),
    );

    let output = evaluate_expr(&expr, &input);

    // Branch 0 claims values 3 and 4 ('high'); branch 1 then picks up the
    // remaining values 1 and 2 ('low'). Rows are checked in ascending order.
    let expected_labels = ["low", "low", "high", "high"];
    for (row, label) in expected_labels.iter().copied().enumerate() {
        assert_eq!(
            output.scalar_at(row)?,
            Scalar::utf8(label, Nullability::NonNullable)
        );
    }
    Ok(())
}
1142
/// Two nullable condition columns: a NULL condition never selects its branch.
#[test]
fn test_evaluate_nary_with_nullable_conditions() {
    let cond1 = BoolArray::from_iter([Some(true), None, Some(false)]).into_array();
    let cond2 = BoolArray::from_iter([Some(false), Some(true), None]).into_array();
    let input = StructArray::from_fields(&[("cond1", cond1), ("cond2", cond2)])
        .unwrap()
        .into_array();

    // CASE WHEN cond1 THEN 10 WHEN cond2 THEN 20 ELSE 0 END
    let expr = nested_case_when(
        vec![
            (get_item("cond1", root()), lit(10i32)),
            (get_item("cond2", root()), lit(20i32)),
        ],
        Some(lit(0i32)),
    );

    let output = evaluate_expr(&expr, &input);

    // row 0: cond1 = true                         => 10
    // row 1: cond1 = NULL (as false), cond2 = true => 20
    // row 2: cond1 = false, cond2 = NULL (as false) => ELSE = 0
    let expected = buffer![10i32, 20, 0].into_array();
    assert_arrays_eq!(output, expected);
}
1172
/// Calls `merge_case_branches` directly with two strictly alternating branch masks.
///
/// No two adjacent rows belong to the same branch, so neither mask contains a
/// multi-row run; this is intended to exercise the per-row merge path.
#[test]
fn test_merge_case_branches_alternating_mask() -> VortexResult<()> {
    let row_count = 100usize;

    // Builds a `row_count`-long array of valid entries all equal to `v`.
    let constant =
        |v: i32| PrimitiveArray::from_option_iter(vec![Some(v); row_count]).into_array();

    let even_rows = Mask::from_indices(row_count, (0..row_count).step_by(2).collect());
    let odd_rows = Mask::from_indices(row_count, (1..row_count).step_by(2).collect());

    // Branch 0 owns the even rows (value 0) and branch 1 the odd rows (value 1).
    // Together they cover every row, so the ELSE array of 99s is never selected.
    let merged = merge_case_branches(
        vec![(even_rows, constant(0)), (odd_rows, constant(1))],
        constant(99),
    )?;

    // 0 on even rows, 1 on odd rows.
    let expected: Vec<Option<i32>> = (0..row_count)
        .map(|row| Some(i32::from(row % 2 == 1)))
        .collect();
    assert_arrays_eq!(
        merged,
        PrimitiveArray::from_option_iter(expected).into_array()
    );
    Ok(())
}
1207}