Skip to main content

vortex_array/scalar_fn/fns/
case_when.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! SQL-style CASE WHEN: evaluates `(condition, value)` pairs in order and returns
5//! the value from the first matching condition (first-match-wins). NULL conditions
6//! are treated as false. If no ELSE clause is provided, unmatched rows produce NULL;
7//! otherwise they get the ELSE value.
8//!
//! Unlike SQL, which coerces all branches to a common supertype, all THEN/ELSE
//! branches must share the same base dtype (ignoring nullability). The result's
//! nullability is the union of the nullabilities of all branches (always nullable
//! when there is no ELSE).
12
13use std::fmt;
14use std::fmt::Formatter;
15use std::hash::Hash;
16use std::sync::Arc;
17
18use prost::Message;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21use vortex_mask::AllOr;
22use vortex_mask::Mask;
23use vortex_proto::expr as pb;
24use vortex_session::VortexSession;
25use vortex_session::registry::CachedId;
26
27use crate::ArrayRef;
28use crate::ExecutionCtx;
29use crate::IntoArray;
30use crate::arrays::BoolArray;
31use crate::arrays::ConstantArray;
32use crate::arrays::bool::BoolArrayExt;
33use crate::builders::ArrayBuilder;
34use crate::builders::builder_with_capacity;
35use crate::builtins::ArrayBuiltins;
36use crate::dtype::DType;
37use crate::expr::Expression;
38use crate::scalar::Scalar;
39use crate::scalar_fn::Arity;
40use crate::scalar_fn::ChildName;
41use crate::scalar_fn::ExecutionArgs;
42use crate::scalar_fn::ScalarFnId;
43use crate::scalar_fn::ScalarFnVTable;
44use crate::scalar_fn::SimplifyCtx;
45use crate::scalar_fn::fns::is_not_null::IsNotNull;
46use crate::scalar_fn::fns::is_null::IsNull;
47use crate::scalar_fn::fns::literal::Literal;
48use crate::scalar_fn::fns::zip::zip_impl;
49
/// Options for the n-ary CaseWhen expression.
///
/// These options fix the child layout `[when_0, then_0, ..., when_{n-1}, then_{n-1}, else?]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CaseWhenOptions {
    /// Number of WHEN/THEN pairs.
    pub num_when_then_pairs: u32,
    /// Whether an ELSE clause is present.
    /// If false, unmatched rows return NULL.
    pub has_else: bool,
}

impl CaseWhenOptions {
    /// Returns the total number of child expressions these options describe.
    ///
    /// Each WHEN/THEN pair contributes two children (the condition and its value);
    /// an ELSE clause, when present, contributes one more.
    pub fn num_children(&self) -> usize {
        let pair_children = 2 * self.num_when_then_pairs as usize;
        if self.has_else {
            pair_children + 1
        } else {
            pair_children
        }
    }
}

impl fmt::Display for CaseWhenOptions {
    /// Renders the options as `case_when(pairs=<n>, else=<bool>)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self {
            num_when_then_pairs,
            has_else,
        } = self;
        write!(f, "case_when(pairs={num_when_then_pairs}, else={has_else})")
    }
}
76
/// An n-ary CASE WHEN expression.
///
/// Children are in order: `[when_0, then_0, when_1, then_1, ..., else?]`.
/// The number of WHEN/THEN pairs and whether an ELSE child is present are carried
/// by [`CaseWhenOptions`], which is this function's `Options` type.
#[derive(Clone)]
pub struct CaseWhen;
82
83impl ScalarFnVTable for CaseWhen {
84    type Options = CaseWhenOptions;
85
86    fn id(&self) -> ScalarFnId {
87        static ID: CachedId = CachedId::new("vortex.case_when");
88        *ID
89    }
90
91    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
92        // let num_children = options.num_when_then_pairs * 2 + u32::from(options.has_else);
93        // Ok(Some(pb::CaseWhenOpts { num_children }.encode_to_vec()))
94        // stabilize the expr
95        vortex_bail!("cannot serialize")
96    }
97
98    fn deserialize(
99        &self,
100        metadata: &[u8],
101        _session: &VortexSession,
102    ) -> VortexResult<Self::Options> {
103        let opts = pb::CaseWhenOpts::decode(metadata)?;
104        if opts.num_children < 2 {
105            vortex_bail!(
106                "CaseWhen expects at least 2 children, got {}",
107                opts.num_children
108            );
109        }
110        Ok(CaseWhenOptions {
111            num_when_then_pairs: opts.num_children / 2,
112            has_else: opts.num_children % 2 == 1,
113        })
114    }
115
116    fn arity(&self, options: &Self::Options) -> Arity {
117        Arity::Exact(options.num_children())
118    }
119
120    fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName {
121        let num_pair_children = options.num_when_then_pairs as usize * 2;
122        if child_idx < num_pair_children {
123            let pair_idx = child_idx / 2;
124            if child_idx.is_multiple_of(2) {
125                ChildName::from(Arc::from(format!("when_{pair_idx}")))
126            } else {
127                ChildName::from(Arc::from(format!("then_{pair_idx}")))
128            }
129        } else if options.has_else && child_idx == num_pair_children {
130            ChildName::from("else")
131        } else {
132            unreachable!("Invalid child index {} for CaseWhen", child_idx)
133        }
134    }
135
136    fn fmt_sql(
137        &self,
138        options: &Self::Options,
139        expr: &Expression,
140        f: &mut Formatter<'_>,
141    ) -> fmt::Result {
142        write!(f, "CASE")?;
143        for i in 0..options.num_when_then_pairs as usize {
144            write!(
145                f,
146                " WHEN {} THEN {}",
147                expr.child(i * 2),
148                expr.child(i * 2 + 1)
149            )?;
150        }
151        if options.has_else {
152            let else_idx = options.num_when_then_pairs as usize * 2;
153            write!(f, " ELSE {}", expr.child(else_idx))?;
154        }
155        write!(f, " END")
156    }
157
158    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
159        if options.num_when_then_pairs == 0 {
160            vortex_bail!("CaseWhen must have at least one WHEN/THEN pair");
161        }
162
163        let expected_len = options.num_children();
164        if arg_dtypes.len() != expected_len {
165            vortex_bail!(
166                "CaseWhen expects {expected_len} argument dtypes, got {}",
167                arg_dtypes.len()
168            );
169        }
170
171        // Unlike SQL which coerces all branches to a common supertype, we require
172        // all THEN/ELSE branches to have the same base dtype (ignoring nullability).
173        // The result nullability is the union of all branches.
174        let first_then = &arg_dtypes[1];
175        let mut result_dtype = first_then.clone();
176
177        for i in 1..options.num_when_then_pairs as usize {
178            let then_i = &arg_dtypes[i * 2 + 1];
179            if !first_then.eq_ignore_nullability(then_i) {
180                vortex_bail!(
181                    "CaseWhen THEN dtypes must match (ignoring nullability), got {} and {}",
182                    first_then,
183                    then_i
184                );
185            }
186            result_dtype = result_dtype.union_nullability(then_i.nullability());
187        }
188
189        if options.has_else {
190            let else_dtype = &arg_dtypes[options.num_when_then_pairs as usize * 2];
191            if !result_dtype.eq_ignore_nullability(else_dtype) {
192                vortex_bail!(
193                    "CaseWhen THEN and ELSE dtypes must match (ignoring nullability), got {} and {}",
194                    first_then,
195                    else_dtype
196                );
197            }
198            result_dtype = result_dtype.union_nullability(else_dtype.nullability());
199        } else {
200            // No ELSE means unmatched rows are NULL
201            result_dtype = result_dtype.as_nullable();
202        }
203
204        Ok(result_dtype)
205    }
206
207    fn execute(
208        &self,
209        options: &Self::Options,
210        args: &dyn ExecutionArgs,
211        ctx: &mut ExecutionCtx,
212    ) -> VortexResult<ArrayRef> {
213        // Inspired by https://datafusion.apache.org/blog/2026/02/02/datafusion_case/
214        //
215        // TODO: shrink input to `remaining` rows between WHEN iterations (batch reduction).
216        // TODO: project to only referenced columns before batch reduction (column projection).
217        // TODO: evaluate THEN/ELSE on compact matching/non-matching rows and scatter-merge the results.
218        // TODO: for constant WHEN/THEN values, compile to a hash table for a single-pass lookup.
219        let row_count = args.row_count();
220        let num_pairs = options.num_when_then_pairs as usize;
221
222        let mut remaining = Mask::new_true(row_count);
223        let mut branches: Vec<(Mask, ArrayRef)> = Vec::with_capacity(num_pairs);
224
225        for i in 0..num_pairs {
226            if remaining.all_false() {
227                break;
228            }
229
230            let condition = args.get(i * 2)?;
231            let cond_bool = condition.execute::<BoolArray>(ctx)?;
232            let cond_mask = cond_bool.to_mask_fill_null_false(ctx);
233            let effective_mask = &remaining & &cond_mask;
234
235            if effective_mask.all_false() {
236                continue;
237            }
238
239            let then_value = args.get(i * 2 + 1)?;
240            remaining = remaining.bitand_not(&cond_mask);
241            branches.push((effective_mask, then_value));
242        }
243
244        let else_value: ArrayRef = if options.has_else {
245            args.get(num_pairs * 2)?
246        } else {
247            let then_dtype = args.get(1)?.dtype().as_nullable();
248            ConstantArray::new(Scalar::null(then_dtype), row_count).into_array()
249        };
250
251        if branches.is_empty() {
252            return Ok(else_value);
253        }
254
255        merge_case_branches(branches, else_value, ctx)
256    }
257
258    fn simplify(
259        &self,
260        options: &Self::Options,
261        expr: &Expression,
262        _ctx: &dyn SimplifyCtx,
263    ) -> VortexResult<Option<Expression>> {
264        // Rewrite the COALESCE-shaped CASE WHEN into `fill_null`, which references `x`
265        // once and lowers to a single fill kernel instead of a `zip`/merge that resolves
266        // `x` twice (once for the `is_null` predicate, once for the value branch).
267        //
268        //   CASE WHEN is_null(x)     THEN c ELSE x END  ==>  fill_null(x, c)
269        //   CASE WHEN is_not_null(x) THEN x ELSE c END  ==>  fill_null(x, c)
270        //
271        // The fill `c` must be a `Literal`: `fill_null`'s kernel reads the fill value via
272        // `as_constant()`, so a non-constant fill would produce an unexecutable expression.
273        if options.num_when_then_pairs != 1 || !options.has_else {
274            return Ok(None);
275        }
276
277        let when = expr.child(0);
278        let then = expr.child(1);
279        let els = expr.child(2);
280
281        // `is_null(x) ? c : x` — predicate operand and ELSE are the same `x`, fill is THEN.
282        let (x, fill) = if when.is::<IsNull>() && when.child(0) == els {
283            (els, then)
284        // `is_not_null(x) ? x : c` — predicate operand and THEN are the same `x`, fill is ELSE.
285        } else if when.is::<IsNotNull>() && when.child(0) == then {
286            (then, els)
287        } else {
288            return Ok(None);
289        };
290
291        let Some(scalar) = fill.as_opt::<Literal>() else {
292            return Ok(None);
293        };
294
295        if scalar.is_null() {
296            // Filling the nulls of `x` with NULL is a no-op
297            return Ok(Some(x.clone()));
298        }
299
300        Ok(Some(crate::expr::fill_null(x.clone(), fill.clone())))
301    }
302
303    fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
304        true
305    }
306
307    fn is_fallible(&self, _options: &Self::Options) -> bool {
308        false
309    }
310}
311
/// Average run length at which slicing + context-aware builder appends become cheaper than
/// per-row `execute_scalar` calls.
/// Measured empirically via benchmarks.
///
/// `merge_case_branches` takes the row-by-row path when the number of branch spans exceeds
/// `len / SLICE_CROSSOVER_RUN_LEN`, and the run-by-run (slicing) path otherwise.
const SLICE_CROSSOVER_RUN_LEN: usize = 4;
315
316/// Merges disjoint `(mask, then_value)` branch pairs with an `else_value` into a single array.
317///
318/// Branch masks are guaranteed disjoint by the remaining-row tracking in [`CaseWhen::execute`].
319fn merge_case_branches(
320    branches: Vec<(Mask, ArrayRef)>,
321    else_value: ArrayRef,
322    ctx: &mut ExecutionCtx,
323) -> VortexResult<ArrayRef> {
324    if branches.len() == 1 {
325        let (mask, then_value) = &branches[0];
326        return zip_impl(then_value, &else_value, mask, ctx);
327    }
328
329    let output_nullability = branches
330        .iter()
331        .fold(else_value.dtype().nullability(), |acc, (_, arr)| {
332            acc | arr.dtype().nullability()
333        });
334    let output_dtype = else_value.dtype().with_nullability(output_nullability);
335    let branch_arrays: Vec<&ArrayRef> = branches.iter().map(|(_, arr)| arr).collect();
336
337    let mut spans: Vec<(usize, usize, usize)> = Vec::new();
338    for (branch_idx, (mask, _)) in branches.iter().enumerate() {
339        match mask.slices() {
340            AllOr::All => return branch_arrays[branch_idx].cast(output_dtype),
341            AllOr::None => {}
342            AllOr::Some(slices) => {
343                for &(start, end) in slices {
344                    spans.push((start, end, branch_idx));
345                }
346            }
347        }
348    }
349    spans.sort_unstable_by_key(|&(start, ..)| start);
350
351    if spans.is_empty() {
352        return else_value.cast(output_dtype);
353    }
354
355    let builder = builder_with_capacity(&output_dtype, else_value.len());
356
357    let fragmented = spans.len() > else_value.len() / SLICE_CROSSOVER_RUN_LEN;
358    if fragmented {
359        merge_row_by_row(
360            &branch_arrays,
361            &else_value,
362            &spans,
363            &output_dtype,
364            builder,
365            ctx,
366        )
367    } else {
368        merge_run_by_run(
369            &branch_arrays,
370            &else_value,
371            &spans,
372            &output_dtype,
373            builder,
374            ctx,
375        )
376    }
377}
378
379/// Iterates spans directly, emitting one `scalar_at` per row.
380/// Zero per-run allocations; preferred for fragmented masks (avg run < [`SLICE_CROSSOVER_RUN_LEN`]).
381fn merge_row_by_row(
382    branch_arrays: &[&ArrayRef],
383    else_value: &ArrayRef,
384    spans: &[(usize, usize, usize)],
385    output_dtype: &DType,
386    mut builder: Box<dyn ArrayBuilder>,
387    ctx: &mut ExecutionCtx,
388) -> VortexResult<ArrayRef> {
389    let mut pos = 0;
390    for &(start, end, branch_idx) in spans {
391        for row in pos..start {
392            let scalar = else_value.execute_scalar(row, ctx)?;
393            builder.append_scalar(&scalar.cast(output_dtype)?)?;
394        }
395        for row in start..end {
396            let scalar = branch_arrays[branch_idx].execute_scalar(row, ctx)?;
397            builder.append_scalar(&scalar.cast(output_dtype)?)?;
398        }
399        pos = end;
400    }
401    for row in pos..else_value.len() {
402        let scalar = else_value.execute_scalar(row, ctx)?;
403        builder.append_scalar(&scalar.cast(output_dtype)?)?;
404    }
405
406    Ok(builder.finish())
407}
408
409/// Bulk-copies each span via `slice()` and context-aware builder appends.
410/// Preferred when runs are long enough that memcpy dominates over per-slice allocation cost.
411/// Lazy cast via `arr.cast(output_dtype)` is executed once per span as a block.
412fn merge_run_by_run(
413    branch_arrays: &[&ArrayRef],
414    else_value: &ArrayRef,
415    spans: &[(usize, usize, usize)],
416    output_dtype: &DType,
417    mut builder: Box<dyn ArrayBuilder>,
418    ctx: &mut ExecutionCtx,
419) -> VortexResult<ArrayRef> {
420    let else_value = else_value.cast(output_dtype.clone())?;
421    let len = else_value.len();
422    for (start, end, branch_idx) in spans {
423        if builder.len() < *start {
424            else_value
425                .slice(builder.len()..*start)?
426                .append_to_builder(builder.as_mut(), ctx)?;
427        }
428        branch_arrays[*branch_idx]
429            .cast(output_dtype.clone())?
430            .slice(*start..*end)?
431            .append_to_builder(builder.as_mut(), ctx)?;
432    }
433    if builder.len() < len {
434        else_value
435            .slice(builder.len()..len)?
436            .append_to_builder(builder.as_mut(), ctx)?;
437    }
438
439    Ok(builder.finish())
440}
441
442#[cfg(test)]
443mod tests {
444    use std::sync::LazyLock;
445
446    use vortex_buffer::buffer;
447    use vortex_error::VortexExpect as _;
448    use vortex_session::VortexSession;
449
450    use super::*;
451    use crate::Canonical;
452    use crate::IntoArray;
453    use crate::LEGACY_SESSION;
454    use crate::VortexSessionExecute;
455    use crate::arrays::BoolArray;
456    use crate::arrays::PrimitiveArray;
457    use crate::arrays::StructArray;
458    use crate::assert_arrays_eq;
459    use crate::dtype::DType;
460    use crate::dtype::Nullability;
461    use crate::dtype::PType;
462    use crate::dtype::StructFields;
463    use crate::expr::case_when;
464    use crate::expr::case_when_no_else;
465    use crate::expr::col;
466    use crate::expr::eq;
467    use crate::expr::get_item;
468    use crate::expr::gt;
469    use crate::expr::is_not_null;
470    use crate::expr::is_null;
471    use crate::expr::lit;
472    use crate::expr::nested_case_when;
473    use crate::expr::root;
474    use crate::expr::test_harness;
475    use crate::scalar::Scalar;
476    use crate::session::ArraySession;
477
478    static SESSION: LazyLock<VortexSession> =
479        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());
480
481    /// Helper to evaluate an expression using the apply+execute pattern
482    fn evaluate_expr(expr: &Expression, array: &ArrayRef) -> ArrayRef {
483        let mut ctx = SESSION.create_execution_ctx();
484        array
485            .clone()
486            .apply(expr)
487            .unwrap()
488            .execute::<Canonical>(&mut ctx)
489            .unwrap()
490            .into_array()
491    }
492
493    // ==================== Serialization Tests ====================
494
495    #[test]
496    #[should_panic(expected = "cannot serialize")]
497    fn test_serialization_roundtrip() {
498        let options = CaseWhenOptions {
499            num_when_then_pairs: 1,
500            has_else: true,
501        };
502        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
503        let deserialized = CaseWhen
504            .deserialize(&serialized, &VortexSession::empty())
505            .unwrap();
506        assert_eq!(options, deserialized);
507    }
508
509    #[test]
510    #[should_panic(expected = "cannot serialize")]
511    fn test_serialization_no_else() {
512        let options = CaseWhenOptions {
513            num_when_then_pairs: 1,
514            has_else: false,
515        };
516        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
517        let deserialized = CaseWhen
518            .deserialize(&serialized, &VortexSession::empty())
519            .unwrap();
520        assert_eq!(options, deserialized);
521    }
522
523    // ==================== Display Tests ====================
524
525    #[test]
526    fn test_display_with_else() {
527        let expr = case_when(gt(col("value"), lit(0i32)), lit(100i32), lit(0i32));
528        let display = format!("{}", expr);
529        assert!(display.contains("CASE"));
530        assert!(display.contains("WHEN"));
531        assert!(display.contains("THEN"));
532        assert!(display.contains("ELSE"));
533        assert!(display.contains("END"));
534    }
535
536    #[test]
537    fn test_display_no_else() {
538        let expr = case_when_no_else(gt(col("value"), lit(0i32)), lit(100i32));
539        let display = format!("{}", expr);
540        assert!(display.contains("CASE"));
541        assert!(display.contains("WHEN"));
542        assert!(display.contains("THEN"));
543        assert!(!display.contains("ELSE"));
544        assert!(display.contains("END"));
545    }
546
547    #[test]
548    fn test_display_nested_nary() {
549        // CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'medium' ELSE 'low' END
550        let expr = nested_case_when(
551            vec![
552                (gt(col("x"), lit(10i32)), lit("high")),
553                (gt(col("x"), lit(5i32)), lit("medium")),
554            ],
555            Some(lit("low")),
556        );
557        let display = format!("{}", expr);
558        assert_eq!(display.matches("CASE").count(), 1);
559        assert_eq!(display.matches("WHEN").count(), 2);
560        assert_eq!(display.matches("THEN").count(), 2);
561    }
562
563    // ==================== DType Tests ====================
564
565    #[test]
566    fn test_return_dtype_with_else() {
567        let expr = case_when(lit(true), lit(100i32), lit(0i32));
568        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
569        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
570        assert_eq!(
571            result_dtype,
572            DType::Primitive(PType::I32, Nullability::NonNullable)
573        );
574    }
575
576    #[test]
577    fn test_return_dtype_with_nullable_else() {
578        let expr = case_when(
579            lit(true),
580            lit(100i32),
581            lit(Scalar::null(DType::Primitive(
582                PType::I32,
583                Nullability::Nullable,
584            ))),
585        );
586        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
587        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
588        assert_eq!(
589            result_dtype,
590            DType::Primitive(PType::I32, Nullability::Nullable)
591        );
592    }
593
594    #[test]
595    fn test_return_dtype_without_else_is_nullable() {
596        let expr = case_when_no_else(lit(true), lit(100i32));
597        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
598        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
599        assert_eq!(
600            result_dtype,
601            DType::Primitive(PType::I32, Nullability::Nullable)
602        );
603    }
604
605    #[test]
606    fn test_return_dtype_with_struct_input() {
607        let dtype = test_harness::struct_dtype();
608        let expr = case_when(
609            gt(get_item("col1", root()), lit(10u16)),
610            lit(100i32),
611            lit(0i32),
612        );
613        let result_dtype = expr.return_dtype(&dtype).unwrap();
614        assert_eq!(
615            result_dtype,
616            DType::Primitive(PType::I32, Nullability::NonNullable)
617        );
618    }
619
620    #[test]
621    fn test_return_dtype_mismatched_then_else_errors() {
622        let expr = case_when(lit(true), lit(100i32), lit("zero"));
623        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
624        let err = expr.return_dtype(&input_dtype).unwrap_err();
625        assert!(
626            err.to_string()
627                .contains("THEN and ELSE dtypes must match (ignoring nullability)")
628        );
629    }
630
631    // ==================== Arity Tests ====================
632
633    #[test]
634    fn test_arity_with_else() {
635        let options = CaseWhenOptions {
636            num_when_then_pairs: 1,
637            has_else: true,
638        };
639        assert_eq!(CaseWhen.arity(&options), Arity::Exact(3));
640    }
641
642    #[test]
643    fn test_arity_without_else() {
644        let options = CaseWhenOptions {
645            num_when_then_pairs: 1,
646            has_else: false,
647        };
648        assert_eq!(CaseWhen.arity(&options), Arity::Exact(2));
649    }
650
651    // ==================== Child Name Tests ====================
652
653    #[test]
654    fn test_child_names() {
655        let options = CaseWhenOptions {
656            num_when_then_pairs: 1,
657            has_else: true,
658        };
659        assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0");
660        assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0");
661        assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "else");
662    }
663
664    // ==================== N-ary Serialization Tests ====================
665
666    #[test]
667    #[should_panic(expected = "cannot serialize")]
668    fn test_serialization_roundtrip_nary() {
669        let options = CaseWhenOptions {
670            num_when_then_pairs: 3,
671            has_else: true,
672        };
673        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
674        let deserialized = CaseWhen
675            .deserialize(&serialized, &VortexSession::empty())
676            .unwrap();
677        assert_eq!(options, deserialized);
678    }
679
680    #[test]
681    #[should_panic(expected = "cannot serialize")]
682    fn test_serialization_roundtrip_nary_no_else() {
683        let options = CaseWhenOptions {
684            num_when_then_pairs: 4,
685            has_else: false,
686        };
687        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
688        let deserialized = CaseWhen
689            .deserialize(&serialized, &VortexSession::empty())
690            .unwrap();
691        assert_eq!(options, deserialized);
692    }
693
694    // ==================== N-ary Arity Tests ====================
695
696    #[test]
697    fn test_arity_nary_with_else() {
698        let options = CaseWhenOptions {
699            num_when_then_pairs: 3,
700            has_else: true,
701        };
702        // 3 pairs * 2 children + 1 else = 7
703        assert_eq!(CaseWhen.arity(&options), Arity::Exact(7));
704    }
705
706    #[test]
707    fn test_arity_nary_without_else() {
708        let options = CaseWhenOptions {
709            num_when_then_pairs: 3,
710            has_else: false,
711        };
712        // 3 pairs * 2 children = 6
713        assert_eq!(CaseWhen.arity(&options), Arity::Exact(6));
714    }
715
716    // ==================== N-ary Child Name Tests ====================
717
718    #[test]
719    fn test_child_names_nary() {
720        let options = CaseWhenOptions {
721            num_when_then_pairs: 3,
722            has_else: true,
723        };
724        assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0");
725        assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0");
726        assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "when_1");
727        assert_eq!(CaseWhen.child_name(&options, 3).to_string(), "then_1");
728        assert_eq!(CaseWhen.child_name(&options, 4).to_string(), "when_2");
729        assert_eq!(CaseWhen.child_name(&options, 5).to_string(), "then_2");
730        assert_eq!(CaseWhen.child_name(&options, 6).to_string(), "else");
731    }
732
733    // ==================== N-ary DType Tests ====================
734
735    #[test]
736    fn test_return_dtype_nary_mismatched_then_types_errors() {
737        let expr = nested_case_when(
738            vec![(lit(true), lit(100i32)), (lit(false), lit("oops"))],
739            Some(lit(0i32)),
740        );
741        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
742        let err = expr.return_dtype(&input_dtype).unwrap_err();
743        assert!(err.to_string().contains("THEN dtypes must match"));
744    }
745
746    #[test]
747    fn test_return_dtype_nary_mixed_nullability() {
748        // When some THEN branches are nullable and others are not,
749        // the result should be nullable (union of nullabilities).
750        let non_null_then = lit(100i32);
751        let nullable_then = lit(Scalar::null(DType::Primitive(
752            PType::I32,
753            Nullability::Nullable,
754        )));
755        let expr = nested_case_when(
756            vec![(lit(true), non_null_then), (lit(false), nullable_then)],
757            Some(lit(0i32)),
758        );
759        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
760        let result = expr.return_dtype(&input_dtype).unwrap();
761        assert_eq!(result, DType::Primitive(PType::I32, Nullability::Nullable));
762    }
763
764    #[test]
765    fn test_return_dtype_nary_no_else_is_nullable() {
766        let expr = nested_case_when(
767            vec![(lit(true), lit(10i32)), (lit(false), lit(20i32))],
768            None,
769        );
770        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
771        let result = expr.return_dtype(&input_dtype).unwrap();
772        assert_eq!(result, DType::Primitive(PType::I32, Nullability::Nullable));
773    }
774
775    // ==================== Expression Manipulation Tests ====================
776
777    #[test]
778    fn test_replace_children() {
779        let expr = case_when(lit(true), lit(1i32), lit(0i32));
780        expr.with_children([lit(false), lit(2i32), lit(3i32)])
781            .vortex_expect("operation should succeed in test");
782    }
783
784    // ==================== Evaluate Tests ====================
785
786    #[test]
787    fn test_evaluate_simple_condition() {
788        let test_array =
789            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
790                .unwrap()
791                .into_array();
792
793        let expr = case_when(
794            gt(get_item("value", root()), lit(2i32)),
795            lit(100i32),
796            lit(0i32),
797        );
798
799        let result = evaluate_expr(&expr, &test_array);
800        assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array());
801    }
802
803    #[test]
804    fn test_evaluate_nary_multiple_conditions() {
805        // Test n-ary via nested_case_when
806        let test_array =
807            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
808                .unwrap()
809                .into_array();
810
811        let expr = nested_case_when(
812            vec![
813                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
814                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
815            ],
816            Some(lit(0i32)),
817        );
818
819        let result = evaluate_expr(&expr, &test_array);
820        assert_arrays_eq!(result, buffer![10i32, 0, 30, 0, 0].into_array());
821    }
822
823    #[test]
824    fn test_evaluate_nary_first_match_wins() {
825        let test_array =
826            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
827                .unwrap()
828                .into_array();
829
830        // Both conditions match for values > 3, but first one wins
831        let expr = nested_case_when(
832            vec![
833                (gt(get_item("value", root()), lit(2i32)), lit(100i32)),
834                (gt(get_item("value", root()), lit(3i32)), lit(200i32)),
835            ],
836            Some(lit(0i32)),
837        );
838
839        let result = evaluate_expr(&expr, &test_array);
840        assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array());
841    }
842
843    #[test]
844    fn test_evaluate_no_else_returns_null() {
845        let test_array =
846            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
847                .unwrap()
848                .into_array();
849
850        let expr = case_when_no_else(gt(get_item("value", root()), lit(3i32)), lit(100i32));
851
852        let result = evaluate_expr(&expr, &test_array);
853        assert!(result.dtype().is_nullable());
854        assert_arrays_eq!(
855            result,
856            PrimitiveArray::from_option_iter([None::<i32>, None, None, Some(100), Some(100)])
857                .into_array()
858        );
859    }
860
861    #[test]
862    fn test_evaluate_all_conditions_false() {
863        let test_array =
864            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
865                .unwrap()
866                .into_array();
867
868        let expr = case_when(
869            gt(get_item("value", root()), lit(100i32)),
870            lit(1i32),
871            lit(0i32),
872        );
873
874        let result = evaluate_expr(&expr, &test_array);
875        assert_arrays_eq!(result, buffer![0i32, 0, 0, 0, 0].into_array());
876    }
877
878    #[test]
879    fn test_evaluate_all_conditions_true() {
880        let test_array =
881            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
882                .unwrap()
883                .into_array();
884
885        let expr = case_when(
886            gt(get_item("value", root()), lit(0i32)),
887            lit(100i32),
888            lit(0i32),
889        );
890
891        let result = evaluate_expr(&expr, &test_array);
892        assert_arrays_eq!(result, buffer![100i32, 100, 100, 100, 100].into_array());
893    }
894
895    #[test]
896    fn test_evaluate_all_true_no_else_returns_correct_dtype() {
897        // CASE WHEN value > 0 THEN 100 END — condition is always true, no ELSE.
898        // Result must be Nullable because the implicit ELSE is NULL.
899        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
900            .unwrap()
901            .into_array();
902
903        let expr = case_when_no_else(gt(get_item("value", root()), lit(0i32)), lit(100i32));
904
905        let result = evaluate_expr(&expr, &test_array);
906        assert!(
907            result.dtype().is_nullable(),
908            "result dtype must be Nullable, got {:?}",
909            result.dtype()
910        );
911        assert_arrays_eq!(
912            result,
913            PrimitiveArray::from_option_iter([Some(100i32), Some(100), Some(100)]).into_array()
914        );
915    }
916
917    #[test]
918    fn test_merge_case_branches_widens_nullability_of_later_branch() -> VortexResult<()> {
919        // When a later THEN branch is Nullable and branches[0] and ELSE are NonNullable,
920        // the result dtype must still be Nullable.
921        //
922        // CASE WHEN value = 0 THEN 10          -- NonNullable
923        //      WHEN value = 1 THEN nullable(20) -- Nullable
924        //      ELSE 0                           -- NonNullable
925        // → result must be Nullable(i32)
926        let test_array =
927            StructArray::from_fields(&[("value", buffer![0i32, 1, 2].into_array())])?.into_array();
928
929        let nullable_20 =
930            Scalar::from(20i32).cast(&DType::Primitive(PType::I32, Nullability::Nullable))?;
931
932        let expr = nested_case_when(
933            vec![
934                (eq(get_item("value", root()), lit(0i32)), lit(10i32)),
935                (eq(get_item("value", root()), lit(1i32)), lit(nullable_20)),
936            ],
937            Some(lit(0i32)),
938        );
939
940        let result = evaluate_expr(&expr, &test_array);
941        assert!(
942            result.dtype().is_nullable(),
943            "result dtype must be Nullable, got {:?}",
944            result.dtype()
945        );
946        assert_arrays_eq!(
947            result,
948            PrimitiveArray::from_option_iter([Some(10), Some(20), Some(0)]).into_array()
949        );
950        Ok(())
951    }
952
953    #[test]
954    fn test_evaluate_with_literal_condition() {
955        let test_array = buffer![1i32, 2, 3].into_array();
956        let expr = case_when(lit(true), lit(100i32), lit(0i32));
957        let result = evaluate_expr(&expr, &test_array);
958
959        assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array());
960    }
961
962    #[test]
963    fn test_evaluate_with_bool_column_result() {
964        let test_array =
965            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
966                .unwrap()
967                .into_array();
968
969        let expr = case_when(
970            gt(get_item("value", root()), lit(2i32)),
971            lit(true),
972            lit(false),
973        );
974
975        let result = evaluate_expr(&expr, &test_array);
976        assert_arrays_eq!(
977            result,
978            BoolArray::from_iter([false, false, true, true, true]).into_array()
979        );
980    }
981
982    #[test]
983    fn test_evaluate_with_nullable_condition() {
984        let test_array = StructArray::from_fields(&[(
985            "cond",
986            BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)]).into_array(),
987        )])
988        .unwrap()
989        .into_array();
990
991        let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));
992
993        let result = evaluate_expr(&expr, &test_array);
994        assert_arrays_eq!(result, buffer![100i32, 0, 0, 0, 100].into_array());
995    }
996
997    #[test]
998    fn test_evaluate_with_nullable_result_values() {
999        let test_array = StructArray::from_fields(&[
1000            ("value", buffer![1i32, 2, 3, 4, 5].into_array()),
1001            (
1002                "result",
1003                PrimitiveArray::from_option_iter([Some(10), None, Some(30), Some(40), Some(50)])
1004                    .into_array(),
1005            ),
1006        ])
1007        .unwrap()
1008        .into_array();
1009
1010        let expr = case_when(
1011            gt(get_item("value", root()), lit(2i32)),
1012            get_item("result", root()),
1013            lit(0i32),
1014        );
1015
1016        let result = evaluate_expr(&expr, &test_array);
1017        assert_arrays_eq!(
1018            result,
1019            PrimitiveArray::from_option_iter([Some(0i32), Some(0), Some(30), Some(40), Some(50)])
1020                .into_array()
1021        );
1022    }
1023
1024    #[test]
1025    fn test_evaluate_with_all_null_condition() {
1026        let test_array = StructArray::from_fields(&[(
1027            "cond",
1028            BoolArray::from_iter([None, None, None]).into_array(),
1029        )])
1030        .unwrap()
1031        .into_array();
1032
1033        let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));
1034
1035        let result = evaluate_expr(&expr, &test_array);
1036        assert_arrays_eq!(result, buffer![0i32, 0, 0].into_array());
1037    }
1038
1039    // ==================== N-ary Evaluate Tests ====================
1040
1041    #[test]
1042    fn test_evaluate_nary_no_else_returns_null() {
1043        let test_array =
1044            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
1045                .unwrap()
1046                .into_array();
1047
1048        // Two conditions, no ELSE — unmatched rows should be NULL
1049        let expr = nested_case_when(
1050            vec![
1051                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
1052                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
1053            ],
1054            None,
1055        );
1056
1057        let result = evaluate_expr(&expr, &test_array);
1058        assert!(result.dtype().is_nullable());
1059        assert_arrays_eq!(
1060            result,
1061            PrimitiveArray::from_option_iter([Some(10i32), None, Some(30), None, None])
1062                .into_array()
1063        );
1064    }
1065
1066    #[test]
1067    fn test_evaluate_nary_many_conditions() {
1068        let test_array =
1069            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
1070                .unwrap()
1071                .into_array();
1072
1073        // 5 WHEN/THEN pairs: each value maps to its value * 10
1074        let expr = nested_case_when(
1075            vec![
1076                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
1077                (eq(get_item("value", root()), lit(2i32)), lit(20i32)),
1078                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
1079                (eq(get_item("value", root()), lit(4i32)), lit(40i32)),
1080                (eq(get_item("value", root()), lit(5i32)), lit(50i32)),
1081            ],
1082            Some(lit(0i32)),
1083        );
1084
1085        let result = evaluate_expr(&expr, &test_array);
1086        assert_arrays_eq!(result, buffer![10i32, 20, 30, 40, 50].into_array());
1087    }
1088
1089    #[test]
1090    fn test_evaluate_nary_all_false_no_else() {
1091        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
1092            .unwrap()
1093            .into_array();
1094
1095        // All conditions are false, no ELSE — everything should be NULL
1096        let expr = nested_case_when(
1097            vec![
1098                (gt(get_item("value", root()), lit(100i32)), lit(10i32)),
1099                (gt(get_item("value", root()), lit(200i32)), lit(20i32)),
1100            ],
1101            None,
1102        );
1103
1104        let result = evaluate_expr(&expr, &test_array);
1105        assert!(result.dtype().is_nullable());
1106        assert_arrays_eq!(
1107            result,
1108            PrimitiveArray::from_option_iter([None::<i32>, None, None]).into_array()
1109        );
1110    }
1111
1112    #[test]
1113    fn test_evaluate_nary_overlapping_conditions_first_wins() {
1114        let test_array =
1115            StructArray::from_fields(&[("value", buffer![10i32, 20, 30].into_array())])
1116                .unwrap()
1117                .into_array();
1118
1119        // value=10: matches cond1 (>5) and cond2 (>0), first should win
1120        // value=20: matches all three, first should win
1121        // value=30: matches all three, first should win
1122        let expr = nested_case_when(
1123            vec![
1124                (gt(get_item("value", root()), lit(5i32)), lit(1i32)),
1125                (gt(get_item("value", root()), lit(0i32)), lit(2i32)),
1126                (gt(get_item("value", root()), lit(15i32)), lit(3i32)),
1127            ],
1128            Some(lit(0i32)),
1129        );
1130
1131        let result = evaluate_expr(&expr, &test_array);
1132        // First matching condition always wins
1133        assert_arrays_eq!(result, buffer![1i32, 1, 1].into_array());
1134    }
1135
1136    #[test]
1137    fn test_evaluate_nary_early_exit_when_remaining_empty() {
1138        // After branch 0 claims all rows, remaining becomes all_false.
1139        // The loop breaks before evaluating branch 1's condition.
1140        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
1141            .unwrap()
1142            .into_array();
1143
1144        let expr = nested_case_when(
1145            vec![
1146                (gt(get_item("value", root()), lit(0i32)), lit(100i32)),
1147                // Never evaluated due to early exit; 999 must never appear in output.
1148                (gt(get_item("value", root()), lit(0i32)), lit(999i32)),
1149            ],
1150            Some(lit(0i32)),
1151        );
1152
1153        let result = evaluate_expr(&expr, &test_array);
1154        assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array());
1155    }
1156
1157    #[test]
1158    fn test_evaluate_nary_skips_branch_with_empty_effective_mask() {
1159        // Branch 0 claims value=1. Branch 1 targets the same rows but they are already
1160        // matched → effective_mask is all_false → branch 1 is skipped (THEN not used).
1161        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
1162            .unwrap()
1163            .into_array();
1164
1165        let expr = nested_case_when(
1166            vec![
1167                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
1168                // Same condition as branch 0 — all matching rows already claimed → skipped.
1169                // 999 must never appear in output.
1170                (eq(get_item("value", root()), lit(1i32)), lit(999i32)),
1171                (eq(get_item("value", root()), lit(2i32)), lit(20i32)),
1172            ],
1173            Some(lit(0i32)),
1174        );
1175
1176        let result = evaluate_expr(&expr, &test_array);
1177        assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array());
1178    }
1179
1180    #[test]
1181    fn test_evaluate_nary_string_output() -> VortexResult<()> {
1182        // Exercises merge_case_branches with a non-primitive (Utf8) builder.
1183        let test_array =
1184            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4].into_array())])?
1185                .into_array();
1186
1187        // CASE WHEN value > 2 THEN 'high' WHEN value > 0 THEN 'low' ELSE 'none' END
1188        // value=1,2 → 'low' (branch 1 after branch 0 claims 3,4)
1189        // value=3,4 → 'high' (branch 0)
1190        let expr = nested_case_when(
1191            vec![
1192                (gt(get_item("value", root()), lit(2i32)), lit("high")),
1193                (gt(get_item("value", root()), lit(0i32)), lit("low")),
1194            ],
1195            Some(lit("none")),
1196        );
1197
1198        let result = evaluate_expr(&expr, &test_array);
1199        assert_eq!(
1200            result.execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())?,
1201            Scalar::utf8("low", Nullability::NonNullable)
1202        );
1203        assert_eq!(
1204            result.execute_scalar(1, &mut LEGACY_SESSION.create_execution_ctx())?,
1205            Scalar::utf8("low", Nullability::NonNullable)
1206        );
1207        assert_eq!(
1208            result.execute_scalar(2, &mut LEGACY_SESSION.create_execution_ctx())?,
1209            Scalar::utf8("high", Nullability::NonNullable)
1210        );
1211        assert_eq!(
1212            result.execute_scalar(3, &mut LEGACY_SESSION.create_execution_ctx())?,
1213            Scalar::utf8("high", Nullability::NonNullable)
1214        );
1215        Ok(())
1216    }
1217
1218    #[test]
1219    fn test_evaluate_nary_with_nullable_conditions() {
1220        let test_array = StructArray::from_fields(&[
1221            (
1222                "cond1",
1223                BoolArray::from_iter([Some(true), None, Some(false)]).into_array(),
1224            ),
1225            (
1226                "cond2",
1227                BoolArray::from_iter([Some(false), Some(true), None]).into_array(),
1228            ),
1229        ])
1230        .unwrap()
1231        .into_array();
1232
1233        let expr = nested_case_when(
1234            vec![
1235                (get_item("cond1", root()), lit(10i32)),
1236                (get_item("cond2", root()), lit(20i32)),
1237            ],
1238            Some(lit(0i32)),
1239        );
1240
1241        let result = evaluate_expr(&expr, &test_array);
1242        // row 0: cond1=true → 10
1243        // row 1: cond1=NULL(→false), cond2=true → 20
1244        // row 2: cond1=false, cond2=NULL(→false) → else=0
1245        assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array());
1246    }
1247
1248    // ==================== Simplify: COALESCE -> fill_null ====================
1249
1250    /// Builds a non-nullable struct scope whose named fields are all `Nullable(I64)`.
1251    fn nullable_i64_scope(fields: &[&str]) -> DType {
1252        DType::Struct(
1253            StructFields::new(
1254                fields.to_vec().into(),
1255                vec![DType::Primitive(PType::I64, Nullability::Nullable); fields.len()],
1256            ),
1257            Nullability::NonNullable,
1258        )
1259    }
1260
1261    #[test]
1262    fn test_simplify_coalesce_is_null_rewrites_to_fill_null() -> VortexResult<()> {
1263        // CASE WHEN is_null(x) THEN 0 ELSE x END  ==>  fill_null(x, 0)
1264        let expr = case_when(is_null(col("x")), lit(0i64), col("x"));
1265        let optimized = expr.optimize_recursive(&nullable_i64_scope(&["x"]))?;
1266        assert!(
1267            optimized.to_string().starts_with("vortex.fill_null"),
1268            "expected fill_null, got {optimized}"
1269        );
1270        Ok(())
1271    }
1272
1273    #[test]
1274    fn test_simplify_coalesce_is_not_null_rewrites_to_fill_null() -> VortexResult<()> {
1275        // CASE WHEN is_not_null(x) THEN x ELSE 0 END  ==>  fill_null(x, 0)
1276        let expr = case_when(is_not_null(col("x")), col("x"), lit(0i64));
1277        let optimized = expr.optimize_recursive(&nullable_i64_scope(&["x"]))?;
1278        assert!(
1279            optimized.to_string().starts_with("vortex.fill_null"),
1280            "expected fill_null, got {optimized}"
1281        );
1282        Ok(())
1283    }
1284
1285    #[test]
1286    fn test_simplify_does_not_fire_when_operands_differ() -> VortexResult<()> {
1287        // The is_null operand (x) and the ELSE (y) are different columns: not a COALESCE.
1288        let expr = case_when(is_null(col("x")), lit(0i64), col("y"));
1289        let optimized = expr.optimize_recursive(&nullable_i64_scope(&["x", "y"]))?;
1290        let s = optimized.to_string();
1291        assert!(s.contains("CASE"), "expected CASE WHEN to remain, got {s}");
1292        assert!(!s.contains("fill_null"), "must not rewrite, got {s}");
1293        Ok(())
1294    }
1295
1296    #[test]
1297    fn test_simplify_does_not_fire_for_non_constant_fill() -> VortexResult<()> {
1298        // COALESCE(x, c) with a *column* fill: fill_null cannot consume a non-constant
1299        // fill value, so the rewrite must not fire.
1300        let expr = case_when(is_null(col("x")), col("c"), col("x"));
1301        let optimized = expr.optimize_recursive(&nullable_i64_scope(&["x", "c"]))?;
1302        let s = optimized.to_string();
1303        assert!(s.contains("CASE"), "expected CASE WHEN to remain, got {s}");
1304        assert!(!s.contains("fill_null"), "must not rewrite, got {s}");
1305        Ok(())
1306    }
1307
1308    #[test]
1309    fn test_simplify_null_fill_collapses_to_input() -> VortexResult<()> {
1310        // Filling the nulls of x with NULL is a no-op, so both forms collapse to just `x`.
1311        //   CASE WHEN is_null(x)     THEN null ELSE x    END  ==>  x
1312        //   CASE WHEN is_not_null(x) THEN x    ELSE null END  ==>  x
1313        let null_fill = || {
1314            lit(Scalar::null(DType::Primitive(
1315                PType::I64,
1316                Nullability::Nullable,
1317            )))
1318        };
1319
1320        for expr in [
1321            case_when(is_null(col("x")), null_fill(), col("x")),
1322            case_when(is_not_null(col("x")), col("x"), null_fill()),
1323        ] {
1324            let optimized = expr.optimize_recursive(&nullable_i64_scope(&["x"]))?;
1325            assert_eq!(
1326                optimized.to_string(),
1327                "$.x",
1328                "expected collapse to input column, got {optimized}"
1329            );
1330        }
1331        Ok(())
1332    }
1333
1334    #[test]
1335    fn test_simplify_null_fill_semantic_equivalence() -> VortexResult<()> {
1336        // The collapse-to-input rewrite must preserve values (and `x`'s nullability).
1337        let array = PrimitiveArray::from_option_iter([Some(1i64), None, Some(3)]).into_array();
1338        let scope = DType::Primitive(PType::I64, Nullability::Nullable);
1339        let null_fill = lit(Scalar::null(DType::Primitive(
1340            PType::I64,
1341            Nullability::Nullable,
1342        )));
1343
1344        let original = case_when(is_null(root()), null_fill, root());
1345        let optimized = original.optimize_recursive(&scope)?;
1346        assert_eq!(
1347            optimized.to_string(),
1348            "$",
1349            "expected collapse to root, got {optimized}"
1350        );
1351
1352        let expected = PrimitiveArray::from_option_iter([Some(1i64), None, Some(3)]).into_array();
1353        assert_arrays_eq!(evaluate_expr(&original, &array), expected);
1354        assert_arrays_eq!(evaluate_expr(&optimized, &array), expected);
1355        Ok(())
1356    }
1357
1358    #[test]
1359    fn test_simplify_does_not_fire_without_else() -> VortexResult<()> {
1360        let expr = case_when_no_else(is_null(col("x")), lit(0i64));
1361        let optimized = expr.optimize_recursive(&nullable_i64_scope(&["x"]))?;
1362        assert!(
1363            !optimized.to_string().contains("fill_null"),
1364            "must not rewrite a no-ELSE case_when, got {optimized}"
1365        );
1366        Ok(())
1367    }
1368
1369    #[test]
1370    fn test_simplify_does_not_fire_for_multi_pair() -> VortexResult<()> {
1371        let expr = nested_case_when(
1372            vec![
1373                (is_null(col("x")), lit(0i64)),
1374                (gt(col("x"), lit(5i64)), lit(1i64)),
1375            ],
1376            Some(col("x")),
1377        );
1378        let optimized = expr.optimize_recursive(&nullable_i64_scope(&["x"]))?;
1379        assert!(
1380            !optimized.to_string().contains("fill_null"),
1381            "must not rewrite a multi-pair case_when, got {optimized}"
1382        );
1383        Ok(())
1384    }
1385
1386    #[test]
1387    fn test_simplify_semantic_equivalence() -> VortexResult<()> {
1388        // The optimized expression must produce the same values as the original CASE WHEN.
1389        let array = PrimitiveArray::from_option_iter([Some(1i64), None, Some(3)]).into_array();
1390        let scope = DType::Primitive(PType::I64, Nullability::Nullable);
1391
1392        let original = case_when(is_null(root()), lit(0i64), root());
1393        let optimized = original.optimize_recursive(&scope)?;
1394        assert!(
1395            optimized.to_string().starts_with("vortex.fill_null"),
1396            "expected fill_null, got {optimized}"
1397        );
1398
1399        // Original keeps CASE WHEN's nullable result dtype; the rewrite tightens it to
1400        // NonNullable because a non-null fill cannot leave any nulls behind. Values match.
1401        assert_arrays_eq!(
1402            evaluate_expr(&original, &array),
1403            PrimitiveArray::from_option_iter([Some(1i64), Some(0), Some(3)]).into_array()
1404        );
1405        assert_arrays_eq!(
1406            evaluate_expr(&optimized, &array),
1407            buffer![1i64, 0, 3].into_array()
1408        );
1409        Ok(())
1410    }
1411
1412    #[test]
1413    fn test_merge_case_branches_alternating_mask() -> VortexResult<()> {
1414        // Exercises the scalar path: alternating rows produce one slice per row (no runs),
1415        // triggering the per-row cursor path in merge_case_branches.
1416        let n = 100usize;
1417
1418        // Branch 0: even rows → 0, Branch 1: odd rows → 1, Else: never reached.
1419        let branch0_mask = Mask::from_indices(n, (0..n).step_by(2));
1420        let branch1_mask = Mask::from_indices(n, (1..n).step_by(2));
1421
1422        let result = merge_case_branches(
1423            vec![
1424                (
1425                    branch0_mask,
1426                    PrimitiveArray::from_option_iter(vec![Some(0i32); n]).into_array(),
1427                ),
1428                (
1429                    branch1_mask,
1430                    PrimitiveArray::from_option_iter(vec![Some(1i32); n]).into_array(),
1431                ),
1432            ],
1433            PrimitiveArray::from_option_iter(vec![Some(99i32); n]).into_array(),
1434            &mut SESSION.create_execution_ctx(),
1435        )?;
1436
1437        // Even rows → 0, odd rows → 1.
1438        let expected: Vec<Option<i32>> = (0..n)
1439            .map(|v| if v % 2 == 0 { Some(0) } else { Some(1) })
1440            .collect();
1441        assert_arrays_eq!(
1442            result,
1443            PrimitiveArray::from_option_iter(expected).into_array()
1444        );
1445        Ok(())
1446    }
1447}