Skip to main content

simd_kernels/kernels/
conditional.rs

1// Copyright (c) 2025 Peter Bower
2// SPDX-License-Identifier: AGPL-3.0-or-later
3// Commercial licensing available. See LICENSE and LICENSING.md.
4
5//! # **Conditional Logic Kernels Module** - *High-Performance Conditional Operations and Data Selection*
6//!
7//! Advanced conditional logic kernels providing efficient data selection, filtering, and transformation
8//! operations with comprehensive null handling and SIMD acceleration. Essential infrastructure
9//! for implementing complex analytical workflows and query execution.
10//!
11//! ## Core Operations
12//! - **Conditional selection**: IF-THEN-ELSE operations with three-valued logic support
13//! - **Array filtering**: Efficient boolean mask-based filtering with zero-copy optimisation
14//! - **Coalescing operations**: Null-aware value selection with fallback hierarchies
15//! - **Case-when logic**: Multi-condition branching with optimised evaluation strategies
16//! - **Null propagation**: Comprehensive null handling following Apache Arrow semantics
17//! - **Type preservation**: Maintains input data types through conditional transformations
18
19include!(concat!(env!("OUT_DIR"), "/simd_lanes.rs"));
20
21use std::marker::PhantomData;
22
23#[cfg(feature = "fast_hash")]
24use ahash::AHashMap;
25#[cfg(not(feature = "fast_hash"))]
26use std::collections::HashMap;
27
28use minarrow::{
29    Bitmask, BooleanAVT, BooleanArray, CategoricalAVT, CategoricalArray, FloatArray, Integer,
30    IntegerArray, StringAVT, StringArray, Vec64, enums::error::KernelError,
31    utils::confirm_equal_len,
32};
33
34#[cfg(feature = "simd")]
35use core::simd::{Mask, Select, Simd};
36
37#[cfg(feature = "simd")]
38use minarrow::utils::is_simd_aligned;
39
40use minarrow::utils::confirm_capacity;
41#[cfg(feature = "datetime")]
42use minarrow::{DatetimeArray, TimeUnit};
43
44#[inline(always)]
45fn prealloc_vec<T: Copy>(len: usize) -> Vec64<T> {
46    let mut v = Vec64::<T>::with_capacity(len);
47    // SAFETY: every slot is written before any read
48    unsafe { v.set_len(len) };
49    v
50}
51
52// Numeric Int float
macro_rules! impl_conditional_copy_numeric {
    ($fn_name:ident, $fn_name_to:ident, $ty:ty, $mask_elem:ty, $lanes:expr, $array_ty:ident) => {
        /// Zero-allocation variant: writes directly to caller's output buffer.
        ///
        /// Conditional copy operation: for every row `i < mask.len`, writes `then_data[i]`
        /// to `output[i]` when the mask bit is set and `else_data[i]` otherwise. The mask's
        /// validity bitmap is not read here; the allocating variant propagates it.
        ///
        /// # Panics
        /// Panics if `output.len() != mask.len`, or if `then_data` or `else_data` holds
        /// fewer than `mask.len` elements.
        #[inline(always)]
        pub fn $fn_name_to(
            mask: &BooleanArray<()>,
            then_data: &[$ty],
            else_data: &[$ty],
            output: &mut [$ty],
        ) {
            let len = mask.len;
            assert_eq!(
                len,
                output.len(),
                concat!(stringify!($fn_name_to), ": output length mismatch")
            );
            // Validate both sources once, so short input fails the same way on the SIMD
            // path (which loads whole chunks) and the scalar path (which only indexes the
            // source it selects). Re-slicing to `len` also lets the compiler hoist the
            // per-element bounds checks out of the loops below.
            assert!(
                then_data.len() >= len,
                concat!(stringify!($fn_name_to), ": then_data shorter than mask")
            );
            assert!(
                else_data.len() >= len,
                concat!(stringify!($fn_name_to), ": else_data shorter than mask")
            );
            let then_data = &then_data[..len];
            let else_data = &else_data[..len];
            let mask_data = &mask.data;

            #[cfg(feature = "simd")]
            {
                // Vector path only when both sources pass the crate's SIMD alignment
                // check; otherwise fall through to the scalar loop.
                if is_simd_aligned(then_data) && is_simd_aligned(else_data) {
                    const N: usize = $lanes;
                    let mut i = 0;
                    while i + N <= len {
                        // Expand N condition bits into a lane mask.
                        let mut bits = [false; N];
                        for l in 0..N {
                            // SAFETY: i + l < len == mask.len; assumes `mask.data` holds
                            // at least `mask.len` bits (not re-checked here).
                            bits[l] = unsafe { mask_data.get_unchecked(i + l) };
                        }
                        let cond = Mask::<$mask_elem, N>::from_array(bits);
                        let t = Simd::<$ty, N>::from_slice(&then_data[i..i + N]);
                        let e = Simd::<$ty, N>::from_slice(&else_data[i..i + N]);
                        cond.select(t, e).copy_to_slice(&mut output[i..i + N]);
                        i += N;
                    }
                    // Remainder when `len % N != 0`: scalar tail.
                    for j in i..len {
                        // SAFETY: j < len == mask.len (see above).
                        output[j] = if unsafe { mask_data.get_unchecked(j) } {
                            then_data[j]
                        } else {
                            else_data[j]
                        };
                    }
                    return;
                }
            }

            // Scalar path: used without the `simd` feature or when the alignment check fails.
            for i in 0..len {
                // SAFETY: i < len == mask.len (see above).
                output[i] = if unsafe { mask_data.get_unchecked(i) } {
                    then_data[i]
                } else {
                    else_data[i]
                };
            }
        }

        /// Conditional copy operation: select elements from `then_data` or `else_data` based on boolean mask.
        ///
        /// Allocates a `mask.len`-element buffer, fills it through the `_to` variant and
        /// clones the mask's validity bitmap as the result's null mask, so rows whose
        /// condition is null are null in the output.
        ///
        /// # Panics
        /// Panics if `then_data` or `else_data` holds fewer than `mask.len` elements.
        #[inline(always)]
        pub fn $fn_name(
            mask: &BooleanArray<()>,
            then_data: &[$ty],
            else_data: &[$ty],
        ) -> $array_ty<$ty> {
            let len = mask.len;
            let mut data = prealloc_vec::<$ty>(len);
            $fn_name_to(mask, then_data, else_data, &mut data);
            $array_ty {
                data: data.into(),
                null_mask: mask.null_mask.clone(),
            }
        }
    };
}
131
132// Conditional datetime
#[cfg(feature = "datetime")]
macro_rules! impl_conditional_copy_datetime {
    ($fn_name:ident, $fn_name_to:ident, $ty:ty, $mask_elem:ty, $lanes:expr) => {
        /// Zero-allocation variant: writes directly to caller's output buffer.
        ///
        /// For every row `i < mask.len`, writes `then_data[i]` to `output[i]` when the
        /// mask bit is set and `else_data[i]` otherwise. The mask's validity bitmap is not
        /// read here; the allocating variant propagates it.
        ///
        /// # Panics
        /// Panics if `output.len() != mask.len`, or if `then_data` or `else_data` holds
        /// fewer than `mask.len` elements.
        #[inline(always)]
        pub fn $fn_name_to(
            mask: &BooleanArray<()>,
            then_data: &[$ty],
            else_data: &[$ty],
            output: &mut [$ty],
        ) {
            let len = mask.len;
            assert_eq!(
                len,
                output.len(),
                concat!(stringify!($fn_name_to), ": output length mismatch")
            );
            // Validate both sources once, so short input fails the same way on the SIMD
            // path (which loads whole chunks) and the scalar path (which only indexes the
            // source it selects). Re-slicing to `len` also lets the compiler hoist the
            // per-element bounds checks out of the loops below.
            assert!(
                then_data.len() >= len,
                concat!(stringify!($fn_name_to), ": then_data shorter than mask")
            );
            assert!(
                else_data.len() >= len,
                concat!(stringify!($fn_name_to), ": else_data shorter than mask")
            );
            let then_data = &then_data[..len];
            let else_data = &else_data[..len];
            let mask_data = &mask.data;

            #[cfg(feature = "simd")]
            {
                // Vector path only when both sources pass the crate's SIMD alignment
                // check; otherwise fall through to the scalar loop.
                if is_simd_aligned(then_data) && is_simd_aligned(else_data) {
                    use core::simd::{Mask, Select, Simd};

                    const N: usize = $lanes;
                    let mut i = 0;
                    while i + N <= len {
                        // Expand N condition bits into a lane mask.
                        let mut bits = [false; N];
                        for l in 0..N {
                            // SAFETY: i + l < len == mask.len; assumes `mask.data` holds
                            // at least `mask.len` bits (not re-checked here).
                            bits[l] = unsafe { mask_data.get_unchecked(i + l) };
                        }
                        let cond = Mask::<$mask_elem, N>::from_array(bits);
                        let t = Simd::<$ty, N>::from_slice(&then_data[i..i + N]);
                        let e = Simd::<$ty, N>::from_slice(&else_data[i..i + N]);
                        cond.select(t, e).copy_to_slice(&mut output[i..i + N]);
                        i += N;
                    }
                    // Remainder when `len % N != 0`: scalar tail.
                    for j in i..len {
                        // SAFETY: j < len == mask.len (see above).
                        output[j] = if unsafe { mask_data.get_unchecked(j) } {
                            then_data[j]
                        } else {
                            else_data[j]
                        };
                    }
                    return;
                }
            }

            // Scalar path: used without the `simd` feature or when the alignment check fails.
            for i in 0..len {
                // SAFETY: i < len == mask.len (see above).
                output[i] = if unsafe { mask_data.get_unchecked(i) } {
                    then_data[i]
                } else {
                    else_data[i]
                };
            }
        }

        /// Conditional copy for datetime values: selects from `then_data` or `else_data`
        /// per mask bit and returns a `DatetimeArray` tagged with `time_unit`.
        ///
        /// The mask's validity bitmap is cloned as the result's null mask, so rows whose
        /// condition is null are null in the output.
        ///
        /// # Panics
        /// Panics if `then_data` or `else_data` holds fewer than `mask.len` elements.
        #[inline(always)]
        pub fn $fn_name(
            mask: &BooleanArray<()>,
            then_data: &[$ty],
            else_data: &[$ty],
            time_unit: TimeUnit,
        ) -> DatetimeArray<$ty> {
            let len = mask.len;
            let mut data = prealloc_vec::<$ty>(len);
            $fn_name_to(mask, then_data, else_data, &mut data);
            DatetimeArray {
                data: data.into(),
                null_mask: mask.null_mask.clone(),
                time_unit,
            }
        }
    };
}
216
217#[cfg(feature = "extended_numeric_types")]
218impl_conditional_copy_numeric!(
219    conditional_copy_i8,
220    conditional_copy_i8_to,
221    i8,
222    i8,
223    W8,
224    IntegerArray
225);
226#[cfg(feature = "extended_numeric_types")]
227impl_conditional_copy_numeric!(
228    conditional_copy_u8,
229    conditional_copy_u8_to,
230    u8,
231    i8,
232    W8,
233    IntegerArray
234);
235#[cfg(feature = "extended_numeric_types")]
236impl_conditional_copy_numeric!(
237    conditional_copy_i16,
238    conditional_copy_i16_to,
239    i16,
240    i16,
241    W16,
242    IntegerArray
243);
244#[cfg(feature = "extended_numeric_types")]
245impl_conditional_copy_numeric!(
246    conditional_copy_u16,
247    conditional_copy_u16_to,
248    u16,
249    i16,
250    W16,
251    IntegerArray
252);
253impl_conditional_copy_numeric!(
254    conditional_copy_i32,
255    conditional_copy_i32_to,
256    i32,
257    i32,
258    W32,
259    IntegerArray
260);
261impl_conditional_copy_numeric!(
262    conditional_copy_u32,
263    conditional_copy_u32_to,
264    u32,
265    i32,
266    W32,
267    IntegerArray
268);
269impl_conditional_copy_numeric!(
270    conditional_copy_i64,
271    conditional_copy_i64_to,
272    i64,
273    i64,
274    W64,
275    IntegerArray
276);
277impl_conditional_copy_numeric!(
278    conditional_copy_u64,
279    conditional_copy_u64_to,
280    u64,
281    i64,
282    W64,
283    IntegerArray
284);
285impl_conditional_copy_numeric!(
286    conditional_copy_f32,
287    conditional_copy_f32_to,
288    f32,
289    i32,
290    W32,
291    FloatArray
292);
293impl_conditional_copy_numeric!(
294    conditional_copy_f64,
295    conditional_copy_f64_to,
296    f64,
297    i64,
298    W64,
299    FloatArray
300);
301
302#[cfg(feature = "datetime")]
303impl_conditional_copy_datetime!(
304    conditional_copy_datetime32,
305    conditional_copy_datetime32_to,
306    i32,
307    i32,
308    W32
309);
310#[cfg(feature = "datetime")]
311impl_conditional_copy_datetime!(
312    conditional_copy_datetime64,
313    conditional_copy_datetime64_to,
314    i64,
315    i64,
316    W64
317);
318
319/// Zero-allocation variant: writes directly to caller's output buffer.
320///
321/// Conditional copy for floating-point arrays with runtime type dispatch.
322/// Panics if output.len() != mask.len.
323#[inline(always)]
324pub fn conditional_copy_float_to<T: Copy + 'static>(
325    mask: &BooleanArray<()>,
326    then_data: &[T],
327    else_data: &[T],
328    output: &mut [T],
329) {
330    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
331        unsafe {
332            conditional_copy_f32_to(
333                mask,
334                std::mem::transmute(then_data),
335                std::mem::transmute(else_data),
336                std::mem::transmute(output),
337            )
338        };
339        return;
340    }
341    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
342        unsafe {
343            conditional_copy_f64_to(
344                mask,
345                std::mem::transmute(then_data),
346                std::mem::transmute(else_data),
347                std::mem::transmute(output),
348            )
349        };
350        return;
351    }
352    unreachable!("unsupported float type")
353}
354
355/// Conditional copy for floating-point arrays with runtime type dispatch.
356#[inline(always)]
357pub fn conditional_copy_float<T: Copy + 'static>(
358    mask: &BooleanArray<()>,
359    then_data: &[T],
360    else_data: &[T],
361) -> FloatArray<T> {
362    let len = mask.len;
363    let mut data = prealloc_vec::<T>(len);
364    conditional_copy_float_to(mask, then_data, else_data, &mut data);
365    FloatArray {
366        data: data.into(),
367        null_mask: mask.null_mask.clone(),
368    }
369}
370
371// Bit-packed Boolean
372/// Zero-allocation variant: writes directly to caller's output buffer.
373///
374/// Conditional copy operation for boolean bitmask arrays.
375/// The output Bitmask must have capacity >= mask.len.
376pub fn conditional_copy_bool_to(
377    mask: &BooleanArray<()>,
378    then_data: &Bitmask,
379    else_data: &Bitmask,
380    output: &mut Bitmask,
381) -> Result<(), KernelError> {
382    let len_bits = mask.len;
383    confirm_capacity("if_then_else: then_data", then_data.capacity(), len_bits)?;
384    confirm_capacity("if_then_else: else_data", else_data.capacity(), len_bits)?;
385    confirm_capacity("if_then_else: output", output.capacity(), len_bits)?;
386
387    // Dense fast-path
388    if mask.null_mask.is_none() {
389        for i in 0..len_bits {
390            let m = unsafe { mask.data.get_unchecked(i) };
391            let v = if m {
392                unsafe { then_data.get_unchecked(i) }
393            } else {
394                unsafe { else_data.get_unchecked(i) }
395            };
396            output.set(i, v);
397        }
398        return Ok(());
399    }
400
401    // Null-aware path
402    for i in 0..len_bits {
403        if !mask
404            .null_mask
405            .as_ref()
406            .map_or(true, |m| unsafe { m.get_unchecked(i) })
407        {
408            continue;
409        }
410        let choose_then = unsafe { mask.data.get_unchecked(i) };
411        let v = if choose_then {
412            unsafe { then_data.get_unchecked(i) }
413        } else {
414            unsafe { else_data.get_unchecked(i) }
415        };
416        output.set(i, v);
417    }
418
419    Ok(())
420}
421
422/// Conditional copy operation for boolean bitmask arrays.
423pub fn conditional_copy_bool(
424    mask: &BooleanArray<()>,
425    then_data: &Bitmask,
426    else_data: &Bitmask,
427) -> Result<BooleanArray<()>, KernelError> {
428    let len_bits = mask.len;
429    let mut out = Bitmask::with_capacity(len_bits);
430    conditional_copy_bool_to(mask, then_data, else_data, &mut out)?;
431
432    // Dense fast-path
433    if mask.null_mask.is_none() {
434        return Ok(BooleanArray {
435            data: out.into(),
436            null_mask: None,
437            len: len_bits,
438            _phantom: PhantomData,
439        });
440    }
441
442    // Null-aware path: build output mask
443    let mut out_mask = Bitmask::new_set_all(len_bits, false);
444    for i in 0..len_bits {
445        if mask
446            .null_mask
447            .as_ref()
448            .map_or(true, |m| unsafe { m.get_unchecked(i) })
449        {
450            out_mask.set(i, true);
451        }
452    }
453
454    Ok(BooleanArray {
455        data: out.into(),
456        null_mask: Some(out_mask),
457        len: len_bits,
458        _phantom: PhantomData,
459    })
460}
461
462// Strings
463#[inline(always)]
464/// Conditional copy operation for UTF-8 string arrays.
465pub fn conditional_copy_str<'a, T: Integer>(
466    mask: BooleanAVT<'a, ()>,
467    then_arr: StringAVT<'a, T>,
468    else_arr: StringAVT<'a, T>,
469) -> Result<StringArray<T>, KernelError> {
470    let (mask_arr, mask_off, mask_len) = mask;
471    let (then_arr, then_off, then_len) = then_arr;
472    let (else_arr, else_off, else_len) = else_arr;
473
474    confirm_equal_len(
475        "if_then_else: then_arr.len() != mask_len",
476        then_len,
477        mask_len,
478    )?;
479    confirm_equal_len(
480        "if_then_else: else_arr.len() != mask_len",
481        else_len,
482        mask_len,
483    )?;
484
485    // First pass: compute total bytes required
486    let mut total_bytes = 0;
487    for i in 0..mask_len {
488        let idx = mask_off + i;
489        let valid = mask_arr
490            .null_mask
491            .as_ref()
492            .map_or(true, |m| unsafe { m.get_unchecked(idx) });
493        if valid {
494            let use_then = unsafe { mask_arr.data.get_unchecked(idx) };
495            let s = unsafe {
496                if use_then {
497                    then_arr.get_str_unchecked(then_off + i)
498                } else {
499                    else_arr.get_str_unchecked(else_off + i)
500                }
501            };
502            total_bytes += s.len();
503        }
504    }
505
506    // Allocate output
507    let mut offsets = Vec64::<T>::with_capacity(mask_len + 1);
508    let mut values = Vec64::<u8>::with_capacity(total_bytes);
509    let mut out_mask = Bitmask::new_set_all(mask_len, false);
510    unsafe {
511        offsets.set_len(mask_len + 1);
512    }
513
514    // Fill
515    offsets[0] = T::zero();
516    let mut cur = 0;
517
518    for i in 0..mask_len {
519        let idx = mask_off + i;
520        let mask_valid = mask_arr
521            .null_mask
522            .as_ref()
523            .map_or(true, |m| unsafe { m.get_unchecked(idx) });
524        if !mask_valid {
525            offsets[i + 1] = T::from_usize(cur);
526            continue;
527        }
528
529        let use_then = unsafe { mask_arr.data.get_unchecked(idx) };
530        let s = unsafe {
531            if use_then {
532                then_arr.get_str_unchecked(then_off + i)
533            } else {
534                else_arr.get_str_unchecked(else_off + i)
535            }
536        }
537        .as_bytes();
538
539        values.extend_from_slice(s);
540        cur += s.len();
541        offsets[i + 1] = T::from_usize(cur);
542        unsafe {
543            out_mask.set_unchecked(i, true);
544        }
545    }
546
547    Ok(StringArray {
548        offsets: offsets.into(),
549        data: values.into(),
550        null_mask: Some(out_mask),
551    })
552}
553
554// Dictionary
555
556/// Conditional copy operation for dictionary/categorical arrays.
557pub fn conditional_copy_dict32<'a, T: Integer>(
558    mask: BooleanAVT<'a, ()>,
559    then_arr: CategoricalAVT<'a, T>,
560    else_arr: CategoricalAVT<'a, T>,
561) -> Result<CategoricalArray<T>, KernelError> {
562    let (mask_arr, mask_off, mask_len) = mask;
563    let (then_arr, then_off, then_len) = then_arr;
564    let (else_arr, else_off, else_len) = else_arr;
565
566    confirm_equal_len(
567        "if_then_else: then_arr.len() != mask_len",
568        then_len,
569        mask_len,
570    )?;
571    confirm_equal_len(
572        "if_then_else: else_arr.len() != mask_len",
573        else_len,
574        mask_len,
575    )?;
576
577    if mask_len == 0 {
578        return Ok(CategoricalArray {
579            data: Vec64::new().into(),
580            unique_values: Vec64::new(),
581            null_mask: Some(Bitmask::new_set_all(0, false)),
582        });
583    }
584
585    // Merge unique values
586    let mut uniques = then_arr.unique_values.clone();
587    for v in &else_arr.unique_values {
588        if !uniques.contains(v) {
589            uniques.push(v.clone());
590        }
591    }
592
593    #[cfg(feature = "fast_hash")]
594    let lookup: AHashMap<&str, T> = uniques
595        .iter()
596        .enumerate()
597        .map(|(i, v)| (v.as_str(), T::from_usize(i)))
598        .collect();
599    #[cfg(not(feature = "fast_hash"))]
600    let lookup: HashMap<&str, T> = uniques
601        .iter()
602        .enumerate()
603        .map(|(i, v)| (v.as_str(), T::from_usize(i)))
604        .collect();
605
606    let mut data = Vec64::<T>::with_capacity(mask_len);
607    unsafe {
608        data.set_len(mask_len);
609    }
610
611    let mut out_mask = Bitmask::new_set_all(mask_len, false);
612
613    for i in 0..mask_len {
614        let mask_idx = mask_off + i;
615        let then_idx = then_off + i;
616        let else_idx = else_off + i;
617
618        let mask_valid = mask_arr
619            .null_mask
620            .as_ref()
621            .map_or(true, |m| unsafe { m.get_unchecked(mask_idx) });
622        if !mask_valid {
623            data[i] = T::zero();
624            continue;
625        }
626
627        let choose_then = unsafe { mask_arr.data.get_unchecked(mask_idx) };
628        let (src_arr, src_idx, valid_mask) = if choose_then {
629            (then_arr, then_idx, then_arr.null_mask.as_ref())
630        } else {
631            (else_arr, else_idx, else_arr.null_mask.as_ref())
632        };
633
634        if valid_mask.map_or(true, |m| unsafe { m.get_unchecked(src_idx) }) {
635            let idx = unsafe { *src_arr.data.get_unchecked(src_idx) }.to_usize();
636            let val = unsafe { src_arr.unique_values.get_unchecked(idx) };
637            data[i] = *lookup.get(val.as_str()).unwrap();
638            unsafe {
639                out_mask.set_unchecked(i, true);
640            }
641        } else {
642            data[i] = T::zero();
643        }
644    }
645
646    Ok(CategoricalArray {
647        data: data.into(),
648        unique_values: uniques,
649        null_mask: Some(out_mask),
650    })
651}
652
653#[cfg(test)]
654mod tests {
655    use super::*;
656    use minarrow::{
657        Bitmask, BooleanArray, MaskedArray,
658        structs::variants::{categorical::CategoricalArray, string::StringArray},
659    };
660
661    fn bm(bools: &[bool]) -> Bitmask {
662        Bitmask::from_bools(bools)
663    }
664
665    fn bool_arr(bools: &[bool]) -> BooleanArray<()> {
666        BooleanArray::from_slice(bools)
667    }
668
669    #[test]
670    fn test_conditional_copy_numeric_no_null() {
671        // i32
672        let mask = bool_arr(&[true, false, true, false, false, true]);
673        let then = vec![10, 20, 30, 40, 50, 60];
674        let els = vec![1, 2, 3, 4, 5, 6];
675        let arr = conditional_copy_i32(&mask, &then, &els);
676        assert_eq!(&arr.data[..], &[10, 2, 30, 4, 5, 60]);
677        assert!(arr.null_mask.is_none());
678
679        // f64
680        let mask = bool_arr(&[true, false, false]);
681        let then = vec![2.0, 4.0, 6.0];
682        let els = vec![1.0, 3.0, 5.0];
683        let arr = conditional_copy_f64(&mask, &then, &els);
684        assert_eq!(&arr.data[..], &[2.0, 3.0, 5.0]);
685        assert!(arr.null_mask.is_none());
686    }
687
688    #[test]
689    fn test_conditional_copy_numeric_with_null() {
690        let mut mask = bool_arr(&[true, false, true]);
691        mask.null_mask = Some(bm(&[true, false, true]));
692        let then = vec![10i64, 20, 30];
693        let els = vec![1i64, 2, 3];
694        let arr = conditional_copy_i64(&mask, &then, &els);
695        assert_eq!(&arr.data[..], &[10, 2, 30]);
696        let null_mask = arr.null_mask.as_ref().unwrap();
697        assert_eq!(null_mask.get(0), true);
698        assert_eq!(null_mask.get(1), false);
699        assert_eq!(null_mask.get(2), true);
700    }
701
702    #[cfg(feature = "extended_numeric_types")]
703    #[test]
704    fn test_conditional_copy_numeric_edge_cases() {
705        // Empty
706        let mask = bool_arr(&[]);
707        let then: Vec<u32> = vec![];
708        let els: Vec<u32> = vec![];
709        let arr = conditional_copy_u32(&mask, &then, &els);
710        assert_eq!(arr.data.len(), 0);
711        assert!(arr.null_mask.is_none());
712        // 1-element, mask true/false/null
713        let mask = bool_arr(&[true]);
714        let then = vec![42];
715        let els = vec![1];
716        let arr = conditional_copy_u8(&mask, &then, &els);
717        assert_eq!(&arr.data[..], &[42]);
718        let mask = bool_arr(&[false]);
719        let arr = conditional_copy_u8(&mask, &then, &els);
720        assert_eq!(&arr.data[..], &[1]);
721        let mut mask = bool_arr(&[true]);
722        mask.null_mask = Some(bm(&[false]));
723        let arr = conditional_copy_u8(&mask, &then, &els);
724        assert_eq!(&arr.data[..], &[42]);
725        assert_eq!(arr.null_mask.as_ref().unwrap().get(0), false);
726    }
727
728    #[test]
729    fn test_conditional_copy_bool_no_null() {
730        let mask = bool_arr(&[true, false, true, false]);
731        let then = bm(&[true, true, false, false]);
732        let els = bm(&[false, true, true, false]);
733        let out = conditional_copy_bool(&mask, &then, &els).unwrap();
734        assert_eq!(out.data.get(0), true);
735        assert_eq!(out.data.get(1), true);
736        assert_eq!(out.data.get(2), false);
737        assert_eq!(out.data.get(3), false);
738        assert!(out.null_mask.is_none());
739        assert_eq!(out.len, 4);
740    }
741
742    #[test]
743    fn test_conditional_copy_bool_with_null() {
744        let mut mask = bool_arr(&[true, false, true, false, true]);
745        mask.null_mask = Some(bm(&[true, false, true, true, false]));
746        let then = bm(&[false, true, false, true, true]);
747        let els = bm(&[true, false, true, false, false]);
748        let out = conditional_copy_bool(&mask, &then, &els).unwrap();
749        // Only positions 0,2,3 should be valid (others null)
750        let null_mask = out.null_mask.as_ref().unwrap();
751        assert_eq!(null_mask.get(0), true);
752        assert_eq!(null_mask.get(1), false);
753        assert_eq!(null_mask.get(2), true);
754        assert_eq!(null_mask.get(3), true);
755        assert_eq!(null_mask.get(4), false);
756        assert_eq!(out.data.get(0), false);
757        assert_eq!(out.data.get(2), false);
758        assert_eq!(out.data.get(3), false);
759    }
760
761    #[test]
762    fn test_conditional_copy_bool_edge_cases() {
763        // Empty
764        let mask = bool_arr(&[]);
765        let then = bm(&[]);
766        let els = bm(&[]);
767        let out = conditional_copy_bool(&mask, &then, &els).unwrap();
768        assert_eq!(out.len, 0);
769        assert!(out.data.is_empty());
770        // All nulls in mask
771        let mut mask = bool_arr(&[true, false, true]);
772        mask.null_mask = Some(bm(&[false, false, false]));
773        let then = bm(&[true, true, true]);
774        let els = bm(&[false, false, false]);
775        let out = conditional_copy_bool(&mask, &then, &els).unwrap();
776        assert_eq!(out.len, 3);
777        let null_mask = out.null_mask.as_ref().unwrap();
778        assert!(!null_mask.get(0));
779        assert!(!null_mask.get(1));
780        assert!(!null_mask.get(2));
781    }
782
783    #[test]
784    fn test_conditional_copy_float_type_dispatch() {
785        // f32
786        let mask = bool_arr(&[true, false]);
787        let then: Vec<f32> = vec![1.0, 2.0];
788        let els: Vec<f32> = vec![3.0, 4.0];
789        let arr = conditional_copy_float(&mask, &then, &els);
790        assert_eq!(&arr.data[..], &[1.0, 4.0]);
791        // f64
792        let mask = bool_arr(&[false, true]);
793        let then: Vec<f64> = vec![7.0, 8.0];
794        let els: Vec<f64> = vec![9.0, 10.0];
795        let arr = conditional_copy_float(&mask, &then, &els);
796        assert_eq!(&arr.data[..], &[9.0, 8.0]);
797    }
798
799    #[test]
800    fn test_conditional_copy_str_basic() {
801        // mask of length 4
802        let mask = bool_arr(&[true, false, false, true]);
803
804        // then_arr and else_arr must also be length 4
805        let a = StringArray::<u32>::from_slice(&["foo", "bar", "baz", "qux"]);
806        let b = StringArray::<u32>::from_slice(&["AAA", "Y", "Z", "BBB"]);
807
808        // Wrap as slices
809        let mask_slice = (&mask, 0, mask.len());
810        let a_slice = (&a, 0, a.len());
811        let b_slice = (&b, 0, b.len());
812
813        let arr = conditional_copy_str(mask_slice, a_slice, b_slice).unwrap();
814
815        // mask picks a[0], b[1], b[2], a[3]
816        assert_eq!(arr.get(0), Some("foo"));
817        assert_eq!(arr.get(1), Some("Y"));
818        assert_eq!(arr.get(2), Some("Z"));
819        assert_eq!(arr.get(3), Some("qux"));
820
821        assert_eq!(arr.len(), 4);
822        let nm = arr.null_mask.as_ref().unwrap();
823        assert!(nm.all_set());
824    }
825
826    #[test]
827    fn test_conditional_copy_str_with_null() {
828        let mut mask = bool_arr(&[true, false, true]);
829        mask.null_mask = Some(bm(&[true, false, true]));
830        let mut a = StringArray::<u32>::from_slice(&["one", "two", "three"]);
831        let mut b = StringArray::<u32>::from_slice(&["uno", "dos", "tres"]);
832        a.set_null(2);
833        b.set_null(0);
834        let mask_slice = (&mask, 0, mask.len());
835        let a_slice = (&a, 0, a.len());
836        let b_slice = (&b, 0, b.len());
837
838        let arr = conditional_copy_str(mask_slice, a_slice, b_slice).unwrap();
839
840        assert_eq!(arr.get(0), Some("one"));
841        assert!(arr.get(1).is_none());
842        assert_eq!(arr.get(2), Some(""));
843        let nm = arr.null_mask.as_ref().unwrap();
844        assert!(nm.get(0));
845        assert!(!nm.get(1));
846        assert!(nm.get(2));
847    }
848
849    #[test]
850    fn test_conditional_copy_str_with_null_chunk() {
851        // pad to allow offset
852        let mut mask = bool_arr(&[false, true, false, true, false]);
853        mask.null_mask = Some(bm(&[false, true, false, true, false]));
854        let mut a = StringArray::<u32>::from_slice(&["", "one", "two", "three", ""]);
855        let mut b = StringArray::<u32>::from_slice(&["", "uno", "dos", "tres", ""]);
856        a.set_null(3);
857        b.set_null(1);
858        let mask_slice = (&mask, 1, 3);
859        let a_slice = (&a, 1, 3);
860        let b_slice = (&b, 1, 3);
861
862        let arr = conditional_copy_str(mask_slice, a_slice, b_slice).unwrap();
863
864        assert_eq!(arr.get(0), Some("one"));
865        assert!(arr.get(1).is_none());
866        assert_eq!(arr.get(2), Some(""));
867        let nm = arr.null_mask.as_ref().unwrap();
868        assert!(nm.get(0));
869        assert!(!nm.get(1));
870        assert!(nm.get(2));
871    }
872
873    #[test]
874    fn test_conditional_copy_str_edge_cases() {
875        let mut mask = bool_arr(&[true, false]);
876        mask.null_mask = Some(bm(&[false, false]));
877        let a = StringArray::<u32>::from_slice(&["foo", "bar"]);
878        let b = StringArray::<u32>::from_slice(&["baz", "qux"]);
879        let mask_slice = (&mask, 0, mask.len());
880        let a_slice = (&a, 0, a.len());
881        let b_slice = (&b, 0, b.len());
882        let arr = conditional_copy_str(mask_slice, a_slice, b_slice).unwrap();
883        assert_eq!(arr.len(), 2);
884        assert!(!arr.null_mask.as_ref().unwrap().get(0));
885        assert!(!arr.null_mask.as_ref().unwrap().get(1));
886        // Empty arrays
887        let mask = bool_arr(&[]);
888        let a = StringArray::<u32>::from_slice(&[]);
889        let b = StringArray::<u32>::from_slice(&[]);
890        let mask_slice = (&mask, 0, 0);
891        let a_slice = (&a, 0, 0);
892        let b_slice = (&b, 0, 0);
893        let arr = conditional_copy_str(mask_slice, a_slice, b_slice).unwrap();
894        assert_eq!(arr.len(), 0);
895    }
896
897    #[test]
898    fn test_conditional_copy_str_edge_cases_chunk() {
899        // chunked window: use 0-length window
900        let mask = bool_arr(&[false, true, false]);
901        let a = StringArray::<u32>::from_slice(&["foo", "bar", "baz"]);
902        let b = StringArray::<u32>::from_slice(&["qux", "quux", "quuz"]);
903        let mask_slice = (&mask, 1, 0);
904        let a_slice = (&a, 1, 0);
905        let b_slice = (&b, 1, 0);
906        let arr = conditional_copy_str(mask_slice, a_slice, b_slice).unwrap();
907        assert_eq!(arr.len(), 0);
908    }
909
910    #[test]
911    fn test_conditional_copy_dict32_basic() {
912        let mask = bool_arr(&[true, false, false, true]);
913        let a = CategoricalArray::<u32>::from_slices(
914            &[0, 1, 1, 0],
915            &["dog".to_string(), "cat".to_string()],
916        );
917        let b = CategoricalArray::<u32>::from_slices(
918            &[0, 0, 1, 1],
919            &["fish".to_string(), "cat".to_string()],
920        );
921        let mask_slice = (&mask, 0, mask.len());
922        let a_slice = (&a, 0, a.data.len());
923        let b_slice = (&b, 0, b.data.len());
924
925        let arr = conditional_copy_dict32(mask_slice, a_slice, b_slice).unwrap();
926
927        let vals: Vec<Option<&str>> = (0..4).map(|i| arr.get(i)).collect();
928        assert_eq!(vals[0], Some("dog"));
929        assert_eq!(vals[1], Some("fish"));
930        assert_eq!(vals[2], Some("cat"));
931        assert_eq!(vals[3], Some("dog"));
932
933        let mut all = arr
934            .unique_values
935            .iter()
936            .map(|s| s.as_str())
937            .collect::<Vec<_>>();
938        all.sort();
939        let mut ref_all = vec!["cat", "dog", "fish"];
940        ref_all.sort();
941        assert_eq!(all, ref_all);
942
943        assert!(arr.null_mask.as_ref().unwrap().all_set());
944    }
945
946    #[test]
947    fn test_conditional_copy_dict32_basic_chunk() {
948        let mask = bool_arr(&[false, true, false, true, false]);
949        let a = CategoricalArray::<u32>::from_slices(
950            &[0, 1, 1, 0, 0],
951            &["dog".to_string(), "cat".to_string()],
952        );
953        let b = CategoricalArray::<u32>::from_slices(
954            &[0, 0, 1, 1, 0],
955            &["fish".to_string(), "cat".to_string()],
956        );
957        let mask_slice = (&mask, 1, 3);
958        let a_slice = (&a, 1, 3);
959        let b_slice = (&b, 1, 3);
960
961        let arr = conditional_copy_dict32(mask_slice, a_slice, b_slice).unwrap();
962
963        let vals: Vec<Option<&str>> = (0..3).map(|i| arr.get(i)).collect();
964        assert_eq!(vals[0], Some("cat"));
965        assert_eq!(vals[1], Some("cat"));
966        assert_eq!(vals[2], Some("dog"));
967
968        let mut all = arr
969            .unique_values
970            .iter()
971            .map(|s| s.as_str())
972            .collect::<Vec<_>>();
973        all.sort();
974        let mut ref_all = vec!["cat", "dog", "fish"];
975        ref_all.sort();
976        assert_eq!(all, ref_all);
977
978        assert!(arr.null_mask.as_ref().unwrap().all_set());
979    }
980
981    #[test]
982    fn test_conditional_copy_dict32_with_null() {
983        let mut mask = bool_arr(&[true, false, true]);
984        mask.null_mask = Some(bm(&[true, false, true]));
985        let mut a = CategoricalArray::<u32>::from_slices(
986            &[0, 1, 0],
987            &["dog".to_string(), "cat".to_string()],
988        );
989        let mut b = CategoricalArray::<u32>::from_slices(
990            &[1, 0, 1],
991            &["cat".to_string(), "fish".to_string()],
992        );
993        a.set_null(2);
994        b.set_null(0);
995        let mask_slice = (&mask, 0, mask.len());
996        let a_slice = (&a, 0, a.data.len());
997        let b_slice = (&b, 0, b.data.len());
998        let arr = conditional_copy_dict32(mask_slice, a_slice, b_slice).unwrap();
999        assert_eq!(arr.get(0), Some("dog"));
1000        assert!(arr.get(1).is_none());
1001        assert!(arr.get(2).is_none());
1002        assert!(arr.null_mask.as_ref().unwrap().get(0));
1003        assert!(!arr.null_mask.as_ref().unwrap().get(1));
1004        assert!(!arr.null_mask.as_ref().unwrap().get(2));
1005    }
1006
1007    #[test]
1008    fn test_conditional_copy_dict32_with_null_chunk() {
1009        let mut mask = bool_arr(&[false, true, false, true]);
1010        mask.null_mask = Some(bm(&[false, true, false, true]));
1011        let mut a = CategoricalArray::<u32>::from_slices(
1012            &[0, 1, 0, 1],
1013            &["dog".to_string(), "cat".to_string()],
1014        );
1015        let mut b = CategoricalArray::<u32>::from_slices(
1016            &[1, 0, 1, 0],
1017            &["cat".to_string(), "fish".to_string()],
1018        );
1019        a.set_null(3);
1020        b.set_null(1);
1021        let mask_slice = (&mask, 1, 2);
1022        let a_slice = (&a, 1, 2);
1023        let b_slice = (&b, 1, 2);
1024        let arr = conditional_copy_dict32(mask_slice, a_slice, b_slice).unwrap();
1025        assert_eq!(arr.get(0), Some("cat"));
1026        assert!(arr.get(1).is_none());
1027        assert!(arr.null_mask.as_ref().unwrap().get(0));
1028        assert!(!arr.null_mask.as_ref().unwrap().get(1));
1029    }
1030
1031    #[test]
1032    fn test_conditional_copy_dict32_edge_cases() {
1033        let mask = bool_arr(&[]);
1034        let a = CategoricalArray::<u32>::from_slices(&[], &[]);
1035        let b = CategoricalArray::<u32>::from_slices(&[], &[]);
1036        let mask_slice = (&mask, 0, 0);
1037        let a_slice = (&a, 0, 0);
1038        let b_slice = (&b, 0, 0);
1039        let arr = conditional_copy_dict32(mask_slice, a_slice, b_slice).unwrap();
1040        assert_eq!(arr.data.len(), 0);
1041        assert_eq!(arr.unique_values.len(), 0);
1042
1043        let mut mask = bool_arr(&[true, false]);
1044        mask.null_mask = Some(bm(&[false, false]));
1045        let a = CategoricalArray::<u32>::from_slices(&[0, 0], &["foo".to_string()]);
1046        let b = CategoricalArray::<u32>::from_slices(&[0, 0], &["bar".to_string()]);
1047        let mask_slice = (&mask, 0, mask.len());
1048        let a_slice = (&a, 0, a.data.len());
1049        let b_slice = (&b, 0, b.data.len());
1050        let arr = conditional_copy_dict32(mask_slice, a_slice, b_slice).unwrap();
1051        assert!(!arr.null_mask.as_ref().unwrap().get(0));
1052        assert!(!arr.null_mask.as_ref().unwrap().get(1));
1053    }
1054
1055    #[test]
1056    fn test_conditional_copy_dict32_edge_cases_chunk() {
1057        let mask = bool_arr(&[false, false, false]);
1058        let a = CategoricalArray::<u32>::from_slices(&[0, 0, 0], &["foo".to_string()]);
1059        let b = CategoricalArray::<u32>::from_slices(&[0, 0, 0], &["bar".to_string()]);
1060        let mask_slice = (&mask, 1, 0);
1061        let a_slice = (&a, 1, 0);
1062        let b_slice = (&b, 1, 0);
1063        let arr = conditional_copy_dict32(mask_slice, a_slice, b_slice).unwrap();
1064        assert_eq!(arr.data.len(), 0);
1065        assert_eq!(arr.unique_values.len(), 0);
1066    }
1067}