Skip to main content

vortex_array/scalar_fn/fns/
case_when.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! SQL-style CASE WHEN: evaluates `(condition, value)` pairs in order and returns
5//! the value from the first matching condition (first-match-wins). NULL conditions
6//! are treated as false. If no ELSE clause is provided, unmatched rows produce NULL;
7//! otherwise they get the ELSE value.
8//!
9//! Unlike SQL which coerces all branches to a common supertype, all THEN/ELSE
10//! branches must share the same base dtype (ignoring nullability). The result
11//! nullability is the union of all branches (forced nullable if no ELSE).
12
13use std::fmt;
14use std::fmt::Formatter;
15use std::hash::Hash;
16use std::sync::Arc;
17
18use prost::Message;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21use vortex_mask::AllOr;
22use vortex_mask::Mask;
23use vortex_proto::expr as pb;
24use vortex_session::VortexSession;
25use vortex_session::registry::CachedId;
26
27use crate::ArrayRef;
28use crate::ExecutionCtx;
29use crate::IntoArray;
30use crate::arrays::BoolArray;
31use crate::arrays::ConstantArray;
32use crate::arrays::bool::BoolArrayExt;
33use crate::builders::ArrayBuilder;
34use crate::builders::builder_with_capacity;
35use crate::builtins::ArrayBuiltins;
36use crate::dtype::DType;
37use crate::expr::Expression;
38use crate::scalar::Scalar;
39use crate::scalar_fn::Arity;
40use crate::scalar_fn::ChildName;
41use crate::scalar_fn::ExecutionArgs;
42use crate::scalar_fn::ScalarFnId;
43use crate::scalar_fn::ScalarFnVTable;
44use crate::scalar_fn::SimplifyCtx;
45use crate::scalar_fn::fns::is_not_null::IsNotNull;
46use crate::scalar_fn::fns::is_null::IsNull;
47use crate::scalar_fn::fns::literal::Literal;
48use crate::scalar_fn::fns::zip::zip_impl;
49
/// Options for the n-ary CaseWhen expression.
///
/// The two fields together fix the number of children of the expression:
/// two per WHEN/THEN pair, plus one more when an ELSE clause is present.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CaseWhenOptions {
    /// Number of WHEN/THEN pairs.
    pub num_when_then_pairs: u32,
    /// Whether an ELSE clause is present.
    /// If false, unmatched rows return NULL.
    pub has_else: bool,
}
59
60impl CaseWhenOptions {
61    /// Total number of child expressions: 2 per WHEN/THEN pair, plus 1 if ELSE is present.
62    pub fn num_children(&self) -> usize {
63        self.num_when_then_pairs as usize * 2 + usize::from(self.has_else)
64    }
65}
66
67impl fmt::Display for CaseWhenOptions {
68    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
69        write!(
70            f,
71            "case_when(pairs={}, else={})",
72            self.num_when_then_pairs, self.has_else
73        )
74    }
75}
76
/// An n-ary CASE WHEN expression.
///
/// Children are in order: `[when_0, then_0, when_1, then_1, ..., else?]`.
///
/// The type itself carries no state; the number of pairs and the presence of an
/// ELSE clause are supplied through [`CaseWhenOptions`].
#[derive(Clone)]
pub struct CaseWhen;
82
83impl ScalarFnVTable for CaseWhen {
84    type Options = CaseWhenOptions;
85
    /// Returns the identifier of this scalar function, `vortex.case_when`.
    fn id(&self) -> ScalarFnId {
        // Held in a function-local static so every call goes through the same `CachedId`.
        static ID: CachedId = CachedId::new("vortex.case_when");
        *ID
    }
90
    /// Serialization is currently disabled: every call fails with "cannot serialize".
    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
        // Encoding that matches what `deserialize` decodes, kept for reference
        // (presumably to be re-enabled once the expression is stabilized):
        //   let num_children = options.num_when_then_pairs * 2 + u32::from(options.has_else);
        //   Ok(Some(pb::CaseWhenOpts { num_children }.encode_to_vec()))
        vortex_bail!("cannot serialize")
    }
97
98    fn deserialize(
99        &self,
100        metadata: &[u8],
101        _session: &VortexSession,
102    ) -> VortexResult<Self::Options> {
103        let opts = pb::CaseWhenOpts::decode(metadata)?;
104        if opts.num_children < 2 {
105            vortex_bail!(
106                "CaseWhen expects at least 2 children, got {}",
107                opts.num_children
108            );
109        }
110        Ok(CaseWhenOptions {
111            num_when_then_pairs: opts.num_children / 2,
112            has_else: opts.num_children % 2 == 1,
113        })
114    }
115
116    fn arity(&self, options: &Self::Options) -> Arity {
117        Arity::Exact(options.num_children())
118    }
119
120    fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName {
121        let num_pair_children = options.num_when_then_pairs as usize * 2;
122        if child_idx < num_pair_children {
123            let pair_idx = child_idx / 2;
124            if child_idx.is_multiple_of(2) {
125                ChildName::from(Arc::from(format!("when_{pair_idx}")))
126            } else {
127                ChildName::from(Arc::from(format!("then_{pair_idx}")))
128            }
129        } else if options.has_else && child_idx == num_pair_children {
130            ChildName::from("else")
131        } else {
132            unreachable!("Invalid child index {} for CaseWhen", child_idx)
133        }
134    }
135
136    fn fmt_sql(
137        &self,
138        options: &Self::Options,
139        expr: &Expression,
140        f: &mut Formatter<'_>,
141    ) -> fmt::Result {
142        write!(f, "CASE")?;
143        for i in 0..options.num_when_then_pairs as usize {
144            write!(
145                f,
146                " WHEN {} THEN {}",
147                expr.child(i * 2),
148                expr.child(i * 2 + 1)
149            )?;
150        }
151        if options.has_else {
152            let else_idx = options.num_when_then_pairs as usize * 2;
153            write!(f, " ELSE {}", expr.child(else_idx))?;
154        }
155        write!(f, " END")
156    }
157
158    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
159        if options.num_when_then_pairs == 0 {
160            vortex_bail!("CaseWhen must have at least one WHEN/THEN pair");
161        }
162
163        let expected_len = options.num_children();
164        if arg_dtypes.len() != expected_len {
165            vortex_bail!(
166                "CaseWhen expects {expected_len} argument dtypes, got {}",
167                arg_dtypes.len()
168            );
169        }
170
171        // Unlike SQL which coerces all branches to a common supertype, we require
172        // all THEN/ELSE branches to have the same base dtype (ignoring nullability).
173        // The result nullability is the union of all branches.
174        let first_then = &arg_dtypes[1];
175        let mut result_dtype = first_then.clone();
176
177        for i in 1..options.num_when_then_pairs as usize {
178            let then_i = &arg_dtypes[i * 2 + 1];
179            if !first_then.eq_ignore_nullability(then_i) {
180                vortex_bail!(
181                    "CaseWhen THEN dtypes must match (ignoring nullability), got {} and {}",
182                    first_then,
183                    then_i
184                );
185            }
186            result_dtype = result_dtype.union_nullability(then_i.nullability());
187        }
188
189        if options.has_else {
190            let else_dtype = &arg_dtypes[options.num_when_then_pairs as usize * 2];
191            if !result_dtype.eq_ignore_nullability(else_dtype) {
192                vortex_bail!(
193                    "CaseWhen THEN and ELSE dtypes must match (ignoring nullability), got {} and {}",
194                    first_then,
195                    else_dtype
196                );
197            }
198            result_dtype = result_dtype.union_nullability(else_dtype.nullability());
199        } else {
200            // No ELSE means unmatched rows are NULL
201            result_dtype = result_dtype.as_nullable();
202        }
203
204        Ok(result_dtype)
205    }
206
207    fn execute(
208        &self,
209        options: &Self::Options,
210        args: &dyn ExecutionArgs,
211        ctx: &mut ExecutionCtx,
212    ) -> VortexResult<ArrayRef> {
213        // Inspired by https://datafusion.apache.org/blog/2026/02/02/datafusion_case/
214        //
215        // TODO: shrink input to `remaining` rows between WHEN iterations (batch reduction).
216        // TODO: project to only referenced columns before batch reduction (column projection).
217        // TODO: evaluate THEN/ELSE on compact matching/non-matching rows and scatter-merge the results.
218        // TODO: for constant WHEN/THEN values, compile to a hash table for a single-pass lookup.
219        let row_count = args.row_count();
220        let num_pairs = options.num_when_then_pairs as usize;
221
222        let mut remaining = Mask::new_true(row_count);
223        let mut branches: Vec<(Mask, ArrayRef)> = Vec::with_capacity(num_pairs);
224
225        for i in 0..num_pairs {
226            if remaining.all_false() {
227                break;
228            }
229
230            let condition = args.get(i * 2)?;
231            let cond_bool = condition.execute::<BoolArray>(ctx)?;
232            let cond_mask = cond_bool.to_mask_fill_null_false(ctx);
233            let effective_mask = &remaining & &cond_mask;
234
235            if effective_mask.all_false() {
236                continue;
237            }
238
239            let then_value = args.get(i * 2 + 1)?;
240            remaining = remaining.bitand_not(&cond_mask);
241            branches.push((effective_mask, then_value));
242        }
243
244        let else_value: ArrayRef = if options.has_else {
245            args.get(num_pairs * 2)?
246        } else {
247            let then_dtype = args.get(1)?.dtype().as_nullable();
248            ConstantArray::new(Scalar::null(then_dtype), row_count).into_array()
249        };
250
251        if branches.is_empty() {
252            return Ok(else_value);
253        }
254
255        merge_case_branches(branches, else_value, ctx)
256    }
257
258    fn simplify(
259        &self,
260        options: &Self::Options,
261        expr: &Expression,
262        _ctx: &dyn SimplifyCtx,
263    ) -> VortexResult<Option<Expression>> {
264        // Rewrite the COALESCE-shaped CASE WHEN into `fill_null`, which references `x`
265        // once and lowers to a single fill kernel instead of a `zip`/merge that resolves
266        // `x` twice (once for the `is_null` predicate, once for the value branch).
267        //
268        //   CASE WHEN is_null(x)     THEN c ELSE x END  ==>  fill_null(x, c)
269        //   CASE WHEN is_not_null(x) THEN x ELSE c END  ==>  fill_null(x, c)
270        //
271        // The fill `c` must be a `Literal`: `fill_null`'s kernel reads the fill value via
272        // `as_constant()`, so a non-constant fill would produce an unexecutable expression.
273        if options.num_when_then_pairs != 1 || !options.has_else {
274            return Ok(None);
275        }
276
277        let when = expr.child(0);
278        let then = expr.child(1);
279        let els = expr.child(2);
280
281        // `is_null(x) ? c : x` — predicate operand and ELSE are the same `x`, fill is THEN.
282        let (x, fill) = if when.is::<IsNull>() && when.child(0) == els {
283            (els, then)
284        // `is_not_null(x) ? x : c` — predicate operand and THEN are the same `x`, fill is ELSE.
285        } else if when.is::<IsNotNull>() && when.child(0) == then {
286            (then, els)
287        } else {
288            return Ok(None);
289        };
290
291        let Some(scalar) = fill.as_opt::<Literal>() else {
292            return Ok(None);
293        };
294
295        if scalar.is_null() {
296            // Filling the nulls of `x` with NULL is a no-op
297            return Ok(Some(x.clone()));
298        }
299
300        Ok(Some(crate::expr::fill_null(x.clone(), fill.clone())))
301    }
302
    /// Always reports `true`.
    ///
    /// NULL inputs are not simply propagated here: a NULL condition is treated as
    /// false, and unmatched rows become NULL when there is no ELSE.
    fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
        true
    }
306
    /// Always reports `false`.
    ///
    /// NOTE(review): `execute` still propagates errors from its children via `?`;
    /// presumably "fallible" refers to data-dependent failures of this function
    /// itself — confirm against the `ScalarFnVTable` documentation.
    fn is_fallible(&self, _options: &Self::Options) -> bool {
        false
    }
310}
311
/// Average run length at which slicing + context-aware builder appends become cheaper than
/// per-row scalar extraction. Measured empirically via benchmarks.
///
/// `merge_case_branches` takes the row-by-row path when the number of spans exceeds
/// `row_count / SLICE_CROSSOVER_RUN_LEN`, and the run-by-run path otherwise.
const SLICE_CROSSOVER_RUN_LEN: usize = 4;
315
316/// Merges disjoint `(mask, then_value)` branch pairs with an `else_value` into a single array.
317///
318/// Branch masks are guaranteed disjoint by the remaining-row tracking in [`CaseWhen::execute`].
319fn merge_case_branches(
320    branches: Vec<(Mask, ArrayRef)>,
321    else_value: ArrayRef,
322    ctx: &mut ExecutionCtx,
323) -> VortexResult<ArrayRef> {
324    if branches.len() == 1 {
325        let (mask, then_value) = &branches[0];
326        return zip_impl(then_value, &else_value, mask, ctx);
327    }
328
329    let output_nullability = branches
330        .iter()
331        .fold(else_value.dtype().nullability(), |acc, (_, arr)| {
332            acc | arr.dtype().nullability()
333        });
334    let output_dtype = else_value.dtype().with_nullability(output_nullability);
335    let branch_arrays: Vec<&ArrayRef> = branches.iter().map(|(_, arr)| arr).collect();
336
337    let mut spans: Vec<(usize, usize, usize)> = Vec::new();
338    for (branch_idx, (mask, _)) in branches.iter().enumerate() {
339        match mask.slices() {
340            AllOr::All => return branch_arrays[branch_idx].cast(output_dtype),
341            AllOr::None => {}
342            AllOr::Some(slices) => {
343                for &(start, end) in slices {
344                    spans.push((start, end, branch_idx));
345                }
346            }
347        }
348    }
349    spans.sort_unstable_by_key(|&(start, ..)| start);
350
351    if spans.is_empty() {
352        return else_value.cast(output_dtype);
353    }
354
355    let builder = builder_with_capacity(&output_dtype, else_value.len());
356
357    let fragmented = spans.len() > else_value.len() / SLICE_CROSSOVER_RUN_LEN;
358    if fragmented {
359        merge_row_by_row(
360            &branch_arrays,
361            &else_value,
362            &spans,
363            &output_dtype,
364            builder,
365            ctx,
366        )
367    } else {
368        merge_run_by_run(
369            &branch_arrays,
370            &else_value,
371            &spans,
372            &output_dtype,
373            builder,
374            ctx,
375        )
376    }
377}
378
379/// Iterates spans directly, emitting one `scalar_at` per row.
380/// Zero per-run allocations; preferred for fragmented masks (avg run < [`SLICE_CROSSOVER_RUN_LEN`]).
381fn merge_row_by_row(
382    branch_arrays: &[&ArrayRef],
383    else_value: &ArrayRef,
384    spans: &[(usize, usize, usize)],
385    output_dtype: &DType,
386    mut builder: Box<dyn ArrayBuilder>,
387    ctx: &mut ExecutionCtx,
388) -> VortexResult<ArrayRef> {
389    let mut pos = 0;
390    for &(start, end, branch_idx) in spans {
391        for row in pos..start {
392            let scalar = else_value.execute_scalar(row, ctx)?;
393            builder.append_scalar(&scalar.cast(output_dtype)?)?;
394        }
395        for row in start..end {
396            let scalar = branch_arrays[branch_idx].execute_scalar(row, ctx)?;
397            builder.append_scalar(&scalar.cast(output_dtype)?)?;
398        }
399        pos = end;
400    }
401    for row in pos..else_value.len() {
402        let scalar = else_value.execute_scalar(row, ctx)?;
403        builder.append_scalar(&scalar.cast(output_dtype)?)?;
404    }
405
406    Ok(builder.finish())
407}
408
409/// Bulk-copies each span via `slice()` and context-aware builder appends.
410/// Preferred when runs are long enough that memcpy dominates over per-slice allocation cost.
411/// Lazy cast via `arr.cast(output_dtype)` is executed once per span as a block.
412fn merge_run_by_run(
413    branch_arrays: &[&ArrayRef],
414    else_value: &ArrayRef,
415    spans: &[(usize, usize, usize)],
416    output_dtype: &DType,
417    mut builder: Box<dyn ArrayBuilder>,
418    ctx: &mut ExecutionCtx,
419) -> VortexResult<ArrayRef> {
420    let else_value = else_value.cast(output_dtype.clone())?;
421    let len = else_value.len();
422    for (start, end, branch_idx) in spans {
423        if builder.len() < *start {
424            else_value
425                .slice(builder.len()..*start)?
426                .append_to_builder(builder.as_mut(), ctx)?;
427        }
428        branch_arrays[*branch_idx]
429            .cast(output_dtype.clone())?
430            .slice(*start..*end)?
431            .append_to_builder(builder.as_mut(), ctx)?;
432    }
433    if builder.len() < len {
434        else_value
435            .slice(builder.len()..len)?
436            .append_to_builder(builder.as_mut(), ctx)?;
437    }
438
439    Ok(builder.finish())
440}
441
442#[cfg(test)]
443mod tests {
444    use std::sync::LazyLock;
445
446    use vortex_buffer::buffer;
447    use vortex_error::VortexExpect as _;
448    use vortex_session::VortexSession;
449
450    use super::*;
451    use crate::Canonical;
452    use crate::IntoArray;
453    use crate::VortexSessionExecute;
454    use crate::arrays::BoolArray;
455    use crate::arrays::PrimitiveArray;
456    use crate::arrays::StructArray;
457    use crate::assert_arrays_eq;
458    use crate::dtype::DType;
459    use crate::dtype::Nullability;
460    use crate::dtype::PType;
461    use crate::dtype::StructFields;
462    use crate::expr::case_when;
463    use crate::expr::case_when_no_else;
464    use crate::expr::col;
465    use crate::expr::eq;
466    use crate::expr::get_item;
467    use crate::expr::gt;
468    use crate::expr::is_not_null;
469    use crate::expr::is_null;
470    use crate::expr::lit;
471    use crate::expr::nested_case_when;
472    use crate::expr::root;
473    use crate::expr::test_harness;
474    use crate::scalar::Scalar;
475
    /// Session shared by the tests in this module; built on first use via `crate::array_session`.
    static SESSION: LazyLock<VortexSession> = LazyLock::new(crate::array_session);
477
478    /// Helper to evaluate an expression using the apply+execute pattern
479    fn evaluate_expr(expr: &Expression, array: &ArrayRef) -> ArrayRef {
480        let mut ctx = SESSION.create_execution_ctx();
481        array
482            .clone()
483            .apply(expr)
484            .unwrap()
485            .execute::<Canonical>(&mut ctx)
486            .unwrap()
487            .into_array()
488    }
489
490    // ==================== Serialization Tests ====================
491
492    #[test]
493    #[should_panic(expected = "cannot serialize")]
494    fn test_serialization_roundtrip() {
495        let options = CaseWhenOptions {
496            num_when_then_pairs: 1,
497            has_else: true,
498        };
499        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
500        let deserialized = CaseWhen
501            .deserialize(&serialized, &VortexSession::empty())
502            .unwrap();
503        assert_eq!(options, deserialized);
504    }
505
506    #[test]
507    #[should_panic(expected = "cannot serialize")]
508    fn test_serialization_no_else() {
509        let options = CaseWhenOptions {
510            num_when_then_pairs: 1,
511            has_else: false,
512        };
513        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
514        let deserialized = CaseWhen
515            .deserialize(&serialized, &VortexSession::empty())
516            .unwrap();
517        assert_eq!(options, deserialized);
518    }
519
520    // ==================== Display Tests ====================
521
522    #[test]
523    fn test_display_with_else() {
524        let expr = case_when(gt(col("value"), lit(0i32)), lit(100i32), lit(0i32));
525        let display = format!("{}", expr);
526        assert!(display.contains("CASE"));
527        assert!(display.contains("WHEN"));
528        assert!(display.contains("THEN"));
529        assert!(display.contains("ELSE"));
530        assert!(display.contains("END"));
531    }
532
533    #[test]
534    fn test_display_no_else() {
535        let expr = case_when_no_else(gt(col("value"), lit(0i32)), lit(100i32));
536        let display = format!("{}", expr);
537        assert!(display.contains("CASE"));
538        assert!(display.contains("WHEN"));
539        assert!(display.contains("THEN"));
540        assert!(!display.contains("ELSE"));
541        assert!(display.contains("END"));
542    }
543
544    #[test]
545    fn test_display_nested_nary() {
546        // CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'medium' ELSE 'low' END
547        let expr = nested_case_when(
548            vec![
549                (gt(col("x"), lit(10i32)), lit("high")),
550                (gt(col("x"), lit(5i32)), lit("medium")),
551            ],
552            Some(lit("low")),
553        );
554        let display = format!("{}", expr);
555        assert_eq!(display.matches("CASE").count(), 1);
556        assert_eq!(display.matches("WHEN").count(), 2);
557        assert_eq!(display.matches("THEN").count(), 2);
558    }
559
560    // ==================== DType Tests ====================
561
562    #[test]
563    fn test_return_dtype_with_else() {
564        let expr = case_when(lit(true), lit(100i32), lit(0i32));
565        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
566        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
567        assert_eq!(
568            result_dtype,
569            DType::Primitive(PType::I32, Nullability::NonNullable)
570        );
571    }
572
573    #[test]
574    fn test_return_dtype_with_nullable_else() {
575        let expr = case_when(
576            lit(true),
577            lit(100i32),
578            lit(Scalar::null(DType::Primitive(
579                PType::I32,
580                Nullability::Nullable,
581            ))),
582        );
583        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
584        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
585        assert_eq!(
586            result_dtype,
587            DType::Primitive(PType::I32, Nullability::Nullable)
588        );
589    }
590
591    #[test]
592    fn test_return_dtype_without_else_is_nullable() {
593        let expr = case_when_no_else(lit(true), lit(100i32));
594        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
595        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
596        assert_eq!(
597            result_dtype,
598            DType::Primitive(PType::I32, Nullability::Nullable)
599        );
600    }
601
602    #[test]
603    fn test_return_dtype_with_struct_input() {
604        let dtype = test_harness::struct_dtype();
605        let expr = case_when(
606            gt(get_item("col1", root()), lit(10u16)),
607            lit(100i32),
608            lit(0i32),
609        );
610        let result_dtype = expr.return_dtype(&dtype).unwrap();
611        assert_eq!(
612            result_dtype,
613            DType::Primitive(PType::I32, Nullability::NonNullable)
614        );
615    }
616
617    #[test]
618    fn test_return_dtype_mismatched_then_else_errors() {
619        let expr = case_when(lit(true), lit(100i32), lit("zero"));
620        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
621        let err = expr.return_dtype(&input_dtype).unwrap_err();
622        assert!(
623            err.to_string()
624                .contains("THEN and ELSE dtypes must match (ignoring nullability)")
625        );
626    }
627
628    // ==================== Arity Tests ====================
629
630    #[test]
631    fn test_arity_with_else() {
632        let options = CaseWhenOptions {
633            num_when_then_pairs: 1,
634            has_else: true,
635        };
636        assert_eq!(CaseWhen.arity(&options), Arity::Exact(3));
637    }
638
639    #[test]
640    fn test_arity_without_else() {
641        let options = CaseWhenOptions {
642            num_when_then_pairs: 1,
643            has_else: false,
644        };
645        assert_eq!(CaseWhen.arity(&options), Arity::Exact(2));
646    }
647
648    // ==================== Child Name Tests ====================
649
650    #[test]
651    fn test_child_names() {
652        let options = CaseWhenOptions {
653            num_when_then_pairs: 1,
654            has_else: true,
655        };
656        assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0");
657        assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0");
658        assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "else");
659    }
660
661    // ==================== N-ary Serialization Tests ====================
662
663    #[test]
664    #[should_panic(expected = "cannot serialize")]
665    fn test_serialization_roundtrip_nary() {
666        let options = CaseWhenOptions {
667            num_when_then_pairs: 3,
668            has_else: true,
669        };
670        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
671        let deserialized = CaseWhen
672            .deserialize(&serialized, &VortexSession::empty())
673            .unwrap();
674        assert_eq!(options, deserialized);
675    }
676
677    #[test]
678    #[should_panic(expected = "cannot serialize")]
679    fn test_serialization_roundtrip_nary_no_else() {
680        let options = CaseWhenOptions {
681            num_when_then_pairs: 4,
682            has_else: false,
683        };
684        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
685        let deserialized = CaseWhen
686            .deserialize(&serialized, &VortexSession::empty())
687            .unwrap();
688        assert_eq!(options, deserialized);
689    }
690
691    // ==================== N-ary Arity Tests ====================
692
693    #[test]
694    fn test_arity_nary_with_else() {
695        let options = CaseWhenOptions {
696            num_when_then_pairs: 3,
697            has_else: true,
698        };
699        // 3 pairs * 2 children + 1 else = 7
700        assert_eq!(CaseWhen.arity(&options), Arity::Exact(7));
701    }
702
703    #[test]
704    fn test_arity_nary_without_else() {
705        let options = CaseWhenOptions {
706            num_when_then_pairs: 3,
707            has_else: false,
708        };
709        // 3 pairs * 2 children = 6
710        assert_eq!(CaseWhen.arity(&options), Arity::Exact(6));
711    }
712
713    // ==================== N-ary Child Name Tests ====================
714
715    #[test]
716    fn test_child_names_nary() {
717        let options = CaseWhenOptions {
718            num_when_then_pairs: 3,
719            has_else: true,
720        };
721        assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0");
722        assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0");
723        assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "when_1");
724        assert_eq!(CaseWhen.child_name(&options, 3).to_string(), "then_1");
725        assert_eq!(CaseWhen.child_name(&options, 4).to_string(), "when_2");
726        assert_eq!(CaseWhen.child_name(&options, 5).to_string(), "then_2");
727        assert_eq!(CaseWhen.child_name(&options, 6).to_string(), "else");
728    }
729
730    // ==================== N-ary DType Tests ====================
731
732    #[test]
733    fn test_return_dtype_nary_mismatched_then_types_errors() {
734        let expr = nested_case_when(
735            vec![(lit(true), lit(100i32)), (lit(false), lit("oops"))],
736            Some(lit(0i32)),
737        );
738        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
739        let err = expr.return_dtype(&input_dtype).unwrap_err();
740        assert!(err.to_string().contains("THEN dtypes must match"));
741    }
742
743    #[test]
744    fn test_return_dtype_nary_mixed_nullability() {
745        // When some THEN branches are nullable and others are not,
746        // the result should be nullable (union of nullabilities).
747        let non_null_then = lit(100i32);
748        let nullable_then = lit(Scalar::null(DType::Primitive(
749            PType::I32,
750            Nullability::Nullable,
751        )));
752        let expr = nested_case_when(
753            vec![(lit(true), non_null_then), (lit(false), nullable_then)],
754            Some(lit(0i32)),
755        );
756        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
757        let result = expr.return_dtype(&input_dtype).unwrap();
758        assert_eq!(result, DType::Primitive(PType::I32, Nullability::Nullable));
759    }
760
761    #[test]
762    fn test_return_dtype_nary_no_else_is_nullable() {
763        let expr = nested_case_when(
764            vec![(lit(true), lit(10i32)), (lit(false), lit(20i32))],
765            None,
766        );
767        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
768        let result = expr.return_dtype(&input_dtype).unwrap();
769        assert_eq!(result, DType::Primitive(PType::I32, Nullability::Nullable));
770    }
771
772    // ==================== Expression Manipulation Tests ====================
773
774    #[test]
775    fn test_replace_children() {
776        let expr = case_when(lit(true), lit(1i32), lit(0i32));
777        expr.with_children([lit(false), lit(2i32), lit(3i32)])
778            .vortex_expect("operation should succeed in test");
779    }
780
781    // ==================== Evaluate Tests ====================
782
783    #[test]
784    fn test_evaluate_simple_condition() {
785        let mut ctx = SESSION.create_execution_ctx();
786        let test_array =
787            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
788                .unwrap()
789                .into_array();
790
791        let expr = case_when(
792            gt(get_item("value", root()), lit(2i32)),
793            lit(100i32),
794            lit(0i32),
795        );
796
797        let result = evaluate_expr(&expr, &test_array);
798        assert_arrays_eq!(
799            result,
800            buffer![0i32, 0, 100, 100, 100].into_array(),
801            &mut ctx
802        );
803    }
804
805    #[test]
806    fn test_evaluate_nary_multiple_conditions() {
807        let mut ctx = SESSION.create_execution_ctx();
808        // Test n-ary via nested_case_when
809        let test_array =
810            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
811                .unwrap()
812                .into_array();
813
814        let expr = nested_case_when(
815            vec![
816                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
817                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
818            ],
819            Some(lit(0i32)),
820        );
821
822        let result = evaluate_expr(&expr, &test_array);
823        assert_arrays_eq!(result, buffer![10i32, 0, 30, 0, 0].into_array(), &mut ctx);
824    }
825
826    #[test]
827    fn test_evaluate_nary_first_match_wins() {
828        let mut ctx = SESSION.create_execution_ctx();
829        let test_array =
830            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
831                .unwrap()
832                .into_array();
833
834        // Both conditions match for values > 3, but first one wins
835        let expr = nested_case_when(
836            vec![
837                (gt(get_item("value", root()), lit(2i32)), lit(100i32)),
838                (gt(get_item("value", root()), lit(3i32)), lit(200i32)),
839            ],
840            Some(lit(0i32)),
841        );
842
843        let result = evaluate_expr(&expr, &test_array);
844        assert_arrays_eq!(
845            result,
846            buffer![0i32, 0, 100, 100, 100].into_array(),
847            &mut ctx
848        );
849    }
850
851    #[test]
852    fn test_evaluate_no_else_returns_null() {
853        let mut ctx = SESSION.create_execution_ctx();
854        let test_array =
855            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
856                .unwrap()
857                .into_array();
858
859        let expr = case_when_no_else(gt(get_item("value", root()), lit(3i32)), lit(100i32));
860
861        let result = evaluate_expr(&expr, &test_array);
862        assert!(result.dtype().is_nullable());
863        assert_arrays_eq!(
864            result,
865            PrimitiveArray::from_option_iter([None::<i32>, None, None, Some(100), Some(100)])
866                .into_array(),
867            &mut ctx
868        );
869    }
870
871    #[test]
872    fn test_evaluate_all_conditions_false() {
873        let mut ctx = SESSION.create_execution_ctx();
874        let test_array =
875            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
876                .unwrap()
877                .into_array();
878
879        let expr = case_when(
880            gt(get_item("value", root()), lit(100i32)),
881            lit(1i32),
882            lit(0i32),
883        );
884
885        let result = evaluate_expr(&expr, &test_array);
886        assert_arrays_eq!(result, buffer![0i32, 0, 0, 0, 0].into_array(), &mut ctx);
887    }
888
889    #[test]
890    fn test_evaluate_all_conditions_true() {
891        let mut ctx = SESSION.create_execution_ctx();
892        let test_array =
893            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
894                .unwrap()
895                .into_array();
896
897        let expr = case_when(
898            gt(get_item("value", root()), lit(0i32)),
899            lit(100i32),
900            lit(0i32),
901        );
902
903        let result = evaluate_expr(&expr, &test_array);
904        assert_arrays_eq!(
905            result,
906            buffer![100i32, 100, 100, 100, 100].into_array(),
907            &mut ctx
908        );
909    }
910
911    #[test]
912    fn test_evaluate_all_true_no_else_returns_correct_dtype() {
913        let mut ctx = SESSION.create_execution_ctx();
914        // CASE WHEN value > 0 THEN 100 END — condition is always true, no ELSE.
915        // Result must be Nullable because the implicit ELSE is NULL.
916        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
917            .unwrap()
918            .into_array();
919
920        let expr = case_when_no_else(gt(get_item("value", root()), lit(0i32)), lit(100i32));
921
922        let result = evaluate_expr(&expr, &test_array);
923        assert!(
924            result.dtype().is_nullable(),
925            "result dtype must be Nullable, got {:?}",
926            result.dtype()
927        );
928        assert_arrays_eq!(
929            result,
930            PrimitiveArray::from_option_iter([Some(100i32), Some(100), Some(100)]).into_array(),
931            &mut ctx
932        );
933    }
934
935    #[test]
936    fn test_merge_case_branches_widens_nullability_of_later_branch() -> VortexResult<()> {
937        let mut ctx = SESSION.create_execution_ctx();
938        // When a later THEN branch is Nullable and branches[0] and ELSE are NonNullable,
939        // the result dtype must still be Nullable.
940        //
941        // CASE WHEN value = 0 THEN 10          -- NonNullable
942        //      WHEN value = 1 THEN nullable(20) -- Nullable
943        //      ELSE 0                           -- NonNullable
944        // → result must be Nullable(i32)
945        let test_array =
946            StructArray::from_fields(&[("value", buffer![0i32, 1, 2].into_array())])?.into_array();
947
948        let nullable_20 =
949            Scalar::from(20i32).cast(&DType::Primitive(PType::I32, Nullability::Nullable))?;
950
951        let expr = nested_case_when(
952            vec![
953                (eq(get_item("value", root()), lit(0i32)), lit(10i32)),
954                (eq(get_item("value", root()), lit(1i32)), lit(nullable_20)),
955            ],
956            Some(lit(0i32)),
957        );
958
959        let result = evaluate_expr(&expr, &test_array);
960        assert!(
961            result.dtype().is_nullable(),
962            "result dtype must be Nullable, got {:?}",
963            result.dtype()
964        );
965        assert_arrays_eq!(
966            result,
967            PrimitiveArray::from_option_iter([Some(10), Some(20), Some(0)]).into_array(),
968            &mut ctx
969        );
970        Ok(())
971    }
972
973    #[test]
974    fn test_evaluate_with_literal_condition() {
975        let mut ctx = SESSION.create_execution_ctx();
976        let test_array = buffer![1i32, 2, 3].into_array();
977        let expr = case_when(lit(true), lit(100i32), lit(0i32));
978        let result = evaluate_expr(&expr, &test_array);
979
980        assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array(), &mut ctx);
981    }
982
983    #[test]
984    fn test_evaluate_with_bool_column_result() {
985        let mut ctx = SESSION.create_execution_ctx();
986        let test_array =
987            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
988                .unwrap()
989                .into_array();
990
991        let expr = case_when(
992            gt(get_item("value", root()), lit(2i32)),
993            lit(true),
994            lit(false),
995        );
996
997        let result = evaluate_expr(&expr, &test_array);
998        assert_arrays_eq!(
999            result,
1000            BoolArray::from_iter([false, false, true, true, true]).into_array(),
1001            &mut ctx
1002        );
1003    }
1004
1005    #[test]
1006    fn test_evaluate_with_nullable_condition() {
1007        let mut ctx = SESSION.create_execution_ctx();
1008        let test_array = StructArray::from_fields(&[(
1009            "cond",
1010            BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)]).into_array(),
1011        )])
1012        .unwrap()
1013        .into_array();
1014
1015        let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));
1016
1017        let result = evaluate_expr(&expr, &test_array);
1018        assert_arrays_eq!(result, buffer![100i32, 0, 0, 0, 100].into_array(), &mut ctx);
1019    }
1020
1021    #[test]
1022    fn test_evaluate_with_nullable_result_values() {
1023        let mut ctx = SESSION.create_execution_ctx();
1024        let test_array = StructArray::from_fields(&[
1025            ("value", buffer![1i32, 2, 3, 4, 5].into_array()),
1026            (
1027                "result",
1028                PrimitiveArray::from_option_iter([Some(10), None, Some(30), Some(40), Some(50)])
1029                    .into_array(),
1030            ),
1031        ])
1032        .unwrap()
1033        .into_array();
1034
1035        let expr = case_when(
1036            gt(get_item("value", root()), lit(2i32)),
1037            get_item("result", root()),
1038            lit(0i32),
1039        );
1040
1041        let result = evaluate_expr(&expr, &test_array);
1042        assert_arrays_eq!(
1043            result,
1044            PrimitiveArray::from_option_iter([Some(0i32), Some(0), Some(30), Some(40), Some(50)])
1045                .into_array(),
1046            &mut ctx
1047        );
1048    }
1049
1050    #[test]
1051    fn test_evaluate_with_all_null_condition() {
1052        let mut ctx = SESSION.create_execution_ctx();
1053        let test_array = StructArray::from_fields(&[(
1054            "cond",
1055            BoolArray::from_iter([None, None, None]).into_array(),
1056        )])
1057        .unwrap()
1058        .into_array();
1059
1060        let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));
1061
1062        let result = evaluate_expr(&expr, &test_array);
1063        assert_arrays_eq!(result, buffer![0i32, 0, 0].into_array(), &mut ctx);
1064    }
1065
1066    // ==================== N-ary Evaluate Tests ====================
1067
1068    #[test]
1069    fn test_evaluate_nary_no_else_returns_null() {
1070        let mut ctx = SESSION.create_execution_ctx();
1071        let test_array =
1072            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
1073                .unwrap()
1074                .into_array();
1075
1076        // Two conditions, no ELSE — unmatched rows should be NULL
1077        let expr = nested_case_when(
1078            vec![
1079                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
1080                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
1081            ],
1082            None,
1083        );
1084
1085        let result = evaluate_expr(&expr, &test_array);
1086        assert!(result.dtype().is_nullable());
1087        assert_arrays_eq!(
1088            result,
1089            PrimitiveArray::from_option_iter([Some(10i32), None, Some(30), None, None])
1090                .into_array(),
1091            &mut ctx
1092        );
1093    }
1094
1095    #[test]
1096    fn test_evaluate_nary_many_conditions() {
1097        let mut ctx = SESSION.create_execution_ctx();
1098        let test_array =
1099            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
1100                .unwrap()
1101                .into_array();
1102
1103        // 5 WHEN/THEN pairs: each value maps to its value * 10
1104        let expr = nested_case_when(
1105            vec![
1106                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
1107                (eq(get_item("value", root()), lit(2i32)), lit(20i32)),
1108                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
1109                (eq(get_item("value", root()), lit(4i32)), lit(40i32)),
1110                (eq(get_item("value", root()), lit(5i32)), lit(50i32)),
1111            ],
1112            Some(lit(0i32)),
1113        );
1114
1115        let result = evaluate_expr(&expr, &test_array);
1116        assert_arrays_eq!(
1117            result,
1118            buffer![10i32, 20, 30, 40, 50].into_array(),
1119            &mut ctx
1120        );
1121    }
1122
1123    #[test]
1124    fn test_evaluate_nary_all_false_no_else() {
1125        let mut ctx = SESSION.create_execution_ctx();
1126        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
1127            .unwrap()
1128            .into_array();
1129
1130        // All conditions are false, no ELSE — everything should be NULL
1131        let expr = nested_case_when(
1132            vec![
1133                (gt(get_item("value", root()), lit(100i32)), lit(10i32)),
1134                (gt(get_item("value", root()), lit(200i32)), lit(20i32)),
1135            ],
1136            None,
1137        );
1138
1139        let result = evaluate_expr(&expr, &test_array);
1140        assert!(result.dtype().is_nullable());
1141        assert_arrays_eq!(
1142            result,
1143            PrimitiveArray::from_option_iter([None::<i32>, None, None]).into_array(),
1144            &mut ctx
1145        );
1146    }
1147
1148    #[test]
1149    fn test_evaluate_nary_overlapping_conditions_first_wins() {
1150        let mut ctx = SESSION.create_execution_ctx();
1151        let test_array =
1152            StructArray::from_fields(&[("value", buffer![10i32, 20, 30].into_array())])
1153                .unwrap()
1154                .into_array();
1155
1156        // value=10: matches cond1 (>5) and cond2 (>0), first should win
1157        // value=20: matches all three, first should win
1158        // value=30: matches all three, first should win
1159        let expr = nested_case_when(
1160            vec![
1161                (gt(get_item("value", root()), lit(5i32)), lit(1i32)),
1162                (gt(get_item("value", root()), lit(0i32)), lit(2i32)),
1163                (gt(get_item("value", root()), lit(15i32)), lit(3i32)),
1164            ],
1165            Some(lit(0i32)),
1166        );
1167
1168        let result = evaluate_expr(&expr, &test_array);
1169        // First matching condition always wins
1170        assert_arrays_eq!(result, buffer![1i32, 1, 1].into_array(), &mut ctx);
1171    }
1172
1173    #[test]
1174    fn test_evaluate_nary_early_exit_when_remaining_empty() {
1175        let mut ctx = SESSION.create_execution_ctx();
1176        // After branch 0 claims all rows, remaining becomes all_false.
1177        // The loop breaks before evaluating branch 1's condition.
1178        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
1179            .unwrap()
1180            .into_array();
1181
1182        let expr = nested_case_when(
1183            vec![
1184                (gt(get_item("value", root()), lit(0i32)), lit(100i32)),
1185                // Never evaluated due to early exit; 999 must never appear in output.
1186                (gt(get_item("value", root()), lit(0i32)), lit(999i32)),
1187            ],
1188            Some(lit(0i32)),
1189        );
1190
1191        let result = evaluate_expr(&expr, &test_array);
1192        assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array(), &mut ctx);
1193    }
1194
1195    #[test]
1196    fn test_evaluate_nary_skips_branch_with_empty_effective_mask() {
1197        let mut ctx = SESSION.create_execution_ctx();
1198        // Branch 0 claims value=1. Branch 1 targets the same rows but they are already
1199        // matched → effective_mask is all_false → branch 1 is skipped (THEN not used).
1200        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
1201            .unwrap()
1202            .into_array();
1203
1204        let expr = nested_case_when(
1205            vec![
1206                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
1207                // Same condition as branch 0 — all matching rows already claimed → skipped.
1208                // 999 must never appear in output.
1209                (eq(get_item("value", root()), lit(1i32)), lit(999i32)),
1210                (eq(get_item("value", root()), lit(2i32)), lit(20i32)),
1211            ],
1212            Some(lit(0i32)),
1213        );
1214
1215        let result = evaluate_expr(&expr, &test_array);
1216        assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array(), &mut ctx);
1217    }
1218
1219    #[test]
1220    fn test_evaluate_nary_string_output() -> VortexResult<()> {
1221        // Exercises merge_case_branches with a non-primitive (Utf8) builder.
1222        let test_array =
1223            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4].into_array())])?
1224                .into_array();
1225
1226        // CASE WHEN value > 2 THEN 'high' WHEN value > 0 THEN 'low' ELSE 'none' END
1227        // value=1,2 → 'low' (branch 1 after branch 0 claims 3,4)
1228        // value=3,4 → 'high' (branch 0)
1229        let expr = nested_case_when(
1230            vec![
1231                (gt(get_item("value", root()), lit(2i32)), lit("high")),
1232                (gt(get_item("value", root()), lit(0i32)), lit("low")),
1233            ],
1234            Some(lit("none")),
1235        );
1236
1237        let result = evaluate_expr(&expr, &test_array);
1238        assert_eq!(
1239            result.execute_scalar(0, &mut SESSION.create_execution_ctx())?,
1240            Scalar::utf8("low", Nullability::NonNullable)
1241        );
1242        assert_eq!(
1243            result.execute_scalar(1, &mut SESSION.create_execution_ctx())?,
1244            Scalar::utf8("low", Nullability::NonNullable)
1245        );
1246        assert_eq!(
1247            result.execute_scalar(2, &mut SESSION.create_execution_ctx())?,
1248            Scalar::utf8("high", Nullability::NonNullable)
1249        );
1250        assert_eq!(
1251            result.execute_scalar(3, &mut SESSION.create_execution_ctx())?,
1252            Scalar::utf8("high", Nullability::NonNullable)
1253        );
1254        Ok(())
1255    }
1256
1257    #[test]
1258    fn test_evaluate_nary_with_nullable_conditions() {
1259        let mut ctx = SESSION.create_execution_ctx();
1260        let test_array = StructArray::from_fields(&[
1261            (
1262                "cond1",
1263                BoolArray::from_iter([Some(true), None, Some(false)]).into_array(),
1264            ),
1265            (
1266                "cond2",
1267                BoolArray::from_iter([Some(false), Some(true), None]).into_array(),
1268            ),
1269        ])
1270        .unwrap()
1271        .into_array();
1272
1273        let expr = nested_case_when(
1274            vec![
1275                (get_item("cond1", root()), lit(10i32)),
1276                (get_item("cond2", root()), lit(20i32)),
1277            ],
1278            Some(lit(0i32)),
1279        );
1280
1281        let result = evaluate_expr(&expr, &test_array);
1282        // row 0: cond1=true → 10
1283        // row 1: cond1=NULL(→false), cond2=true → 20
1284        // row 2: cond1=false, cond2=NULL(→false) → else=0
1285        assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array(), &mut ctx);
1286    }
1287
1288    // ==================== Simplify: COALESCE -> fill_null ====================
1289
1290    /// Builds a non-nullable struct scope whose named fields are all `Nullable(I64)`.
1291    fn nullable_i64_scope(fields: &[&str]) -> DType {
1292        DType::Struct(
1293            StructFields::new(
1294                fields.to_vec().into(),
1295                vec![DType::Primitive(PType::I64, Nullability::Nullable); fields.len()],
1296            ),
1297            Nullability::NonNullable,
1298        )
1299    }
1300
1301    #[test]
1302    fn test_simplify_coalesce_is_null_rewrites_to_fill_null() -> VortexResult<()> {
1303        // CASE WHEN is_null(x) THEN 0 ELSE x END  ==>  fill_null(x, 0)
1304        let expr = case_when(is_null(col("x")), lit(0i64), col("x"));
1305        let optimized = expr.optimize_recursive(&nullable_i64_scope(&["x"]))?;
1306        assert!(
1307            optimized.to_string().starts_with("vortex.fill_null"),
1308            "expected fill_null, got {optimized}"
1309        );
1310        Ok(())
1311    }
1312
1313    #[test]
1314    fn test_simplify_coalesce_is_not_null_rewrites_to_fill_null() -> VortexResult<()> {
1315        // CASE WHEN is_not_null(x) THEN x ELSE 0 END  ==>  fill_null(x, 0)
1316        let expr = case_when(is_not_null(col("x")), col("x"), lit(0i64));
1317        let optimized = expr.optimize_recursive(&nullable_i64_scope(&["x"]))?;
1318        assert!(
1319            optimized.to_string().starts_with("vortex.fill_null"),
1320            "expected fill_null, got {optimized}"
1321        );
1322        Ok(())
1323    }
1324
1325    #[test]
1326    fn test_simplify_does_not_fire_when_operands_differ() -> VortexResult<()> {
1327        // The is_null operand (x) and the ELSE (y) are different columns: not a COALESCE.
1328        let expr = case_when(is_null(col("x")), lit(0i64), col("y"));
1329        let optimized = expr.optimize_recursive(&nullable_i64_scope(&["x", "y"]))?;
1330        let s = optimized.to_string();
1331        assert!(s.contains("CASE"), "expected CASE WHEN to remain, got {s}");
1332        assert!(!s.contains("fill_null"), "must not rewrite, got {s}");
1333        Ok(())
1334    }
1335
1336    #[test]
1337    fn test_simplify_does_not_fire_for_non_constant_fill() -> VortexResult<()> {
1338        // COALESCE(x, c) with a *column* fill: fill_null cannot consume a non-constant
1339        // fill value, so the rewrite must not fire.
1340        let expr = case_when(is_null(col("x")), col("c"), col("x"));
1341        let optimized = expr.optimize_recursive(&nullable_i64_scope(&["x", "c"]))?;
1342        let s = optimized.to_string();
1343        assert!(s.contains("CASE"), "expected CASE WHEN to remain, got {s}");
1344        assert!(!s.contains("fill_null"), "must not rewrite, got {s}");
1345        Ok(())
1346    }
1347
1348    #[test]
1349    fn test_simplify_null_fill_collapses_to_input() -> VortexResult<()> {
1350        // Filling the nulls of x with NULL is a no-op, so both forms collapse to just `x`.
1351        //   CASE WHEN is_null(x)     THEN null ELSE x    END  ==>  x
1352        //   CASE WHEN is_not_null(x) THEN x    ELSE null END  ==>  x
1353        let null_fill = || {
1354            lit(Scalar::null(DType::Primitive(
1355                PType::I64,
1356                Nullability::Nullable,
1357            )))
1358        };
1359
1360        for expr in [
1361            case_when(is_null(col("x")), null_fill(), col("x")),
1362            case_when(is_not_null(col("x")), col("x"), null_fill()),
1363        ] {
1364            let optimized = expr.optimize_recursive(&nullable_i64_scope(&["x"]))?;
1365            assert_eq!(
1366                optimized.to_string(),
1367                "$.x",
1368                "expected collapse to input column, got {optimized}"
1369            );
1370        }
1371        Ok(())
1372    }
1373
1374    #[test]
1375    fn test_simplify_null_fill_semantic_equivalence() -> VortexResult<()> {
1376        let mut ctx = SESSION.create_execution_ctx();
1377        // The collapse-to-input rewrite must preserve values (and `x`'s nullability).
1378        let array = PrimitiveArray::from_option_iter([Some(1i64), None, Some(3)]).into_array();
1379        let scope = DType::Primitive(PType::I64, Nullability::Nullable);
1380        let null_fill = lit(Scalar::null(DType::Primitive(
1381            PType::I64,
1382            Nullability::Nullable,
1383        )));
1384
1385        let original = case_when(is_null(root()), null_fill, root());
1386        let optimized = original.optimize_recursive(&scope)?;
1387        assert_eq!(
1388            optimized.to_string(),
1389            "$",
1390            "expected collapse to root, got {optimized}"
1391        );
1392
1393        let expected = PrimitiveArray::from_option_iter([Some(1i64), None, Some(3)]).into_array();
1394        assert_arrays_eq!(evaluate_expr(&original, &array), expected, &mut ctx);
1395        assert_arrays_eq!(evaluate_expr(&optimized, &array), expected, &mut ctx);
1396        Ok(())
1397    }
1398
1399    #[test]
1400    fn test_simplify_does_not_fire_without_else() -> VortexResult<()> {
1401        let expr = case_when_no_else(is_null(col("x")), lit(0i64));
1402        let optimized = expr.optimize_recursive(&nullable_i64_scope(&["x"]))?;
1403        assert!(
1404            !optimized.to_string().contains("fill_null"),
1405            "must not rewrite a no-ELSE case_when, got {optimized}"
1406        );
1407        Ok(())
1408    }
1409
1410    #[test]
1411    fn test_simplify_does_not_fire_for_multi_pair() -> VortexResult<()> {
1412        let expr = nested_case_when(
1413            vec![
1414                (is_null(col("x")), lit(0i64)),
1415                (gt(col("x"), lit(5i64)), lit(1i64)),
1416            ],
1417            Some(col("x")),
1418        );
1419        let optimized = expr.optimize_recursive(&nullable_i64_scope(&["x"]))?;
1420        assert!(
1421            !optimized.to_string().contains("fill_null"),
1422            "must not rewrite a multi-pair case_when, got {optimized}"
1423        );
1424        Ok(())
1425    }
1426
1427    #[test]
1428    fn test_simplify_semantic_equivalence() -> VortexResult<()> {
1429        let mut ctx = SESSION.create_execution_ctx();
1430        // The optimized expression must produce the same values as the original CASE WHEN.
1431        let array = PrimitiveArray::from_option_iter([Some(1i64), None, Some(3)]).into_array();
1432        let scope = DType::Primitive(PType::I64, Nullability::Nullable);
1433
1434        let original = case_when(is_null(root()), lit(0i64), root());
1435        let optimized = original.optimize_recursive(&scope)?;
1436        assert!(
1437            optimized.to_string().starts_with("vortex.fill_null"),
1438            "expected fill_null, got {optimized}"
1439        );
1440
1441        // Original keeps CASE WHEN's nullable result dtype; the rewrite tightens it to
1442        // NonNullable because a non-null fill cannot leave any nulls behind. Values match.
1443        assert_arrays_eq!(
1444            evaluate_expr(&original, &array),
1445            PrimitiveArray::from_option_iter([Some(1i64), Some(0), Some(3)]).into_array(),
1446            &mut ctx
1447        );
1448        assert_arrays_eq!(
1449            evaluate_expr(&optimized, &array),
1450            buffer![1i64, 0, 3].into_array(),
1451            &mut ctx
1452        );
1453        Ok(())
1454    }
1455
1456    #[test]
1457    fn test_merge_case_branches_alternating_mask() -> VortexResult<()> {
1458        let mut ctx = SESSION.create_execution_ctx();
1459        // Exercises the scalar path: alternating rows produce one slice per row (no runs),
1460        // triggering the per-row cursor path in merge_case_branches.
1461        let n = 100usize;
1462
1463        // Branch 0: even rows → 0, Branch 1: odd rows → 1, Else: never reached.
1464        let branch0_mask = Mask::from_indices(n, (0..n).step_by(2));
1465        let branch1_mask = Mask::from_indices(n, (1..n).step_by(2));
1466
1467        let result = merge_case_branches(
1468            vec![
1469                (
1470                    branch0_mask,
1471                    PrimitiveArray::from_option_iter(vec![Some(0i32); n]).into_array(),
1472                ),
1473                (
1474                    branch1_mask,
1475                    PrimitiveArray::from_option_iter(vec![Some(1i32); n]).into_array(),
1476                ),
1477            ],
1478            PrimitiveArray::from_option_iter(vec![Some(99i32); n]).into_array(),
1479            &mut SESSION.create_execution_ctx(),
1480        )?;
1481
1482        // Even rows → 0, odd rows → 1.
1483        let expected: Vec<Option<i32>> = (0..n)
1484            .map(|v| if v % 2 == 0 { Some(0) } else { Some(1) })
1485            .collect();
1486        assert_arrays_eq!(
1487            result,
1488            PrimitiveArray::from_option_iter(expected).into_array(),
1489            &mut ctx
1490        );
1491        Ok(())
1492    }
1493}