simd_kernels/kernels/
conditional.rs

1// Copyright Peter Bower 2025. All Rights Reserved.
2// Licensed under Mozilla Public License (MPL) 2.0.
3
4//! # **Conditional Logic Kernels Module** - *High-Performance Conditional Operations and Data Selection*
5//!
6//! Advanced conditional logic kernels providing efficient data selection, filtering, and transformation
7//! operations with comprehensive null handling and SIMD acceleration. Essential infrastructure
8//! for implementing complex analytical workflows and query execution.
9//!
10//! ## Core Operations
11//! - **Conditional selection**: IF-THEN-ELSE operations with three-valued logic support
12//! - **Array filtering**: Efficient boolean mask-based filtering with zero-copy optimisation
13//! - **Coalescing operations**: Null-aware value selection with fallback hierarchies
14//! - **Case-when logic**: Multi-condition branching with optimised evaluation strategies
15//! - **Null propagation**: Comprehensive null handling following Apache Arrow semantics
16//! - **Type preservation**: Maintains input data types through conditional transformations
17
18include!(concat!(env!("OUT_DIR"), "/simd_lanes.rs"));
19
20use std::marker::PhantomData;
21
22#[cfg(feature = "fast_hash")]
23use ahash::AHashMap;
24#[cfg(not(feature = "fast_hash"))]
25use std::collections::HashMap;
26
27use minarrow::{
28    Bitmask, BooleanAVT, BooleanArray, CategoricalAVT, CategoricalArray, FloatArray, Integer,
29    IntegerArray, StringAVT, StringArray, Vec64, enums::error::KernelError,
30    utils::confirm_equal_len,
31};
32
33#[cfg(feature = "simd")]
34use core::simd::{Mask, Simd};
35
36#[cfg(feature = "simd")]
37use minarrow::utils::is_simd_aligned;
38
39use minarrow::utils::confirm_capacity;
40#[cfg(feature = "datetime")]
41use minarrow::{DatetimeArray, TimeUnit};
42
43#[inline(always)]
44fn prealloc_vec<T: Copy>(len: usize) -> Vec64<T> {
45    let mut v = Vec64::<T>::with_capacity(len);
46    // SAFETY: every slot is written before any read
47    unsafe { v.set_len(len) };
48    v
49}
50
51// Numeric Int float
macro_rules! impl_conditional_copy_numeric {
    ($fn_name:ident, $ty:ty, $mask_elem:ty, $lanes:expr, $array_ty:ident) => {
        /// Element-wise select: for each row `i` in `0..mask.len`, the output takes
        /// `then_data[i]` when the mask bit is set and `else_data[i]` otherwise.
        ///
        /// The output's null mask is a clone of `mask.null_mask`.
        ///
        /// # Panics
        /// Panics if a selected element (or, on the SIMD path, a full lane chunk)
        /// lies beyond the end of `then_data` or `else_data`.
        #[inline(always)]
        pub fn $fn_name(
            mask: &BooleanArray<()>,
            then_data: &[$ty],
            else_data: &[$ty],
        ) -> $array_ty<$ty> {
            let len = mask.len;
            let selector = &mask.data;
            let mut out = prealloc_vec::<$ty>(len);

            // Vectorised prefix: only taken when both inputs pass the SIMD
            // alignment check. Yields the index at which scalar work resumes.
            #[cfg(feature = "simd")]
            let scalar_start = {
                const N: usize = $lanes;
                let mut pos = 0usize;
                if is_simd_aligned(then_data) && is_simd_aligned(else_data) {
                    while pos + N <= len {
                        // Unpack N mask bits into a lane-wise boolean array.
                        let lane_bits: [bool; N] = core::array::from_fn(|lane| unsafe {
                            selector.get_unchecked(pos + lane)
                        });
                        let chooser = Mask::<$mask_elem, N>::from_array(lane_bits);
                        let then_v = Simd::<$ty, N>::from_slice(&then_data[pos..pos + N]);
                        let else_v = Simd::<$ty, N>::from_slice(&else_data[pos..pos + N]);
                        chooser
                            .select(then_v, else_v)
                            .copy_to_slice(&mut out[pos..pos + N]);
                        pos += N;
                    }
                }
                pos
            };
            #[cfg(not(feature = "simd"))]
            let scalar_start = 0usize;

            // Scalar path: the SIMD remainder (`len % N`), or the whole input
            // when the vectorised prefix was not taken.
            for idx in scalar_start..len {
                let take_then = unsafe { selector.get_unchecked(idx) };
                out[idx] = if take_then {
                    then_data[idx]
                } else {
                    else_data[idx]
                };
            }

            $array_ty {
                data: out.into(),
                null_mask: mask.null_mask.clone(),
            }
        }
    };
}
114
115// Conditional datetime
#[cfg(feature = "datetime")]
macro_rules! impl_conditional_copy_datetime {
    ($fn_name:ident, $ty:ty, $mask_elem:ty, $lanes:expr) => {
        /// Element-wise select over raw datetime values: row `i` of the output is
        /// `then_data[i]` when the mask bit is set and `else_data[i]` otherwise.
        ///
        /// The result carries the supplied `time_unit` and a clone of
        /// `mask.null_mask`.
        ///
        /// # Panics
        /// Panics if a selected element (or, on the SIMD path, a full lane chunk)
        /// lies beyond the end of `then_data` or `else_data`.
        #[inline(always)]
        pub fn $fn_name(
            mask: &BooleanArray<()>,
            then_data: &[$ty],
            else_data: &[$ty],
            time_unit: TimeUnit,
        ) -> DatetimeArray<$ty> {
            let len = mask.len;
            let selector = &mask.data;
            let mut out = prealloc_vec::<$ty>(len);

            // Vectorised prefix: only taken when both inputs pass the SIMD
            // alignment check. Yields the index at which scalar work resumes.
            #[cfg(feature = "simd")]
            let scalar_start = {
                use core::simd::{Mask, Simd};

                const N: usize = $lanes;
                let mut pos = 0usize;
                if is_simd_aligned(then_data) && is_simd_aligned(else_data) {
                    while pos + N <= len {
                        // Unpack N mask bits into a lane-wise boolean array.
                        let lane_bits: [bool; N] = core::array::from_fn(|lane| unsafe {
                            selector.get_unchecked(pos + lane)
                        });
                        let chooser = Mask::<$mask_elem, N>::from_array(lane_bits);
                        let then_v = Simd::<$ty, N>::from_slice(&then_data[pos..pos + N]);
                        let else_v = Simd::<$ty, N>::from_slice(&else_data[pos..pos + N]);
                        chooser
                            .select(then_v, else_v)
                            .copy_to_slice(&mut out[pos..pos + N]);
                        pos += N;
                    }
                }
                pos
            };
            #[cfg(not(feature = "simd"))]
            let scalar_start = 0usize;

            // Scalar path: the SIMD remainder, or the whole input when the
            // vectorised prefix was not taken.
            for idx in scalar_start..len {
                let take_then = unsafe { selector.get_unchecked(idx) };
                out[idx] = if take_then {
                    then_data[idx]
                } else {
                    else_data[idx]
                };
            }

            DatetimeArray {
                data: out.into(),
                null_mask: mask.null_mask.clone(),
                time_unit,
            }
        }
    };
}
185
186#[cfg(feature = "extended_numeric_types")]
187impl_conditional_copy_numeric!(conditional_copy_i8, i8, i8, W8, IntegerArray);
188#[cfg(feature = "extended_numeric_types")]
189impl_conditional_copy_numeric!(conditional_copy_u8, u8, i8, W8, IntegerArray);
190#[cfg(feature = "extended_numeric_types")]
191impl_conditional_copy_numeric!(conditional_copy_i16, i16, i16, W16, IntegerArray);
192#[cfg(feature = "extended_numeric_types")]
193impl_conditional_copy_numeric!(conditional_copy_u16, u16, i16, W16, IntegerArray);
194impl_conditional_copy_numeric!(conditional_copy_i32, i32, i32, W32, IntegerArray);
195impl_conditional_copy_numeric!(conditional_copy_u32, u32, i32, W32, IntegerArray);
196impl_conditional_copy_numeric!(conditional_copy_i64, i64, i64, W64, IntegerArray);
197impl_conditional_copy_numeric!(conditional_copy_u64, u64, i64, W64, IntegerArray);
198impl_conditional_copy_numeric!(conditional_copy_f32, f32, i32, W32, FloatArray);
199impl_conditional_copy_numeric!(conditional_copy_f64, f64, i64, W64, FloatArray);
200
201#[cfg(feature = "datetime")]
202impl_conditional_copy_datetime!(conditional_copy_datetime32, i32, i32, W32);
203#[cfg(feature = "datetime")]
204impl_conditional_copy_datetime!(conditional_copy_datetime64, i64, i64, W64);
205
206/// Conditional copy for floating-point arrays with runtime type dispatch.
207#[inline(always)]
208pub fn conditional_copy_float<T: Copy + 'static>(
209    mask: &BooleanArray<()>,
210    then_data: &[T],
211    else_data: &[T],
212) -> FloatArray<T> {
213    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
214        return unsafe {
215            std::mem::transmute(conditional_copy_f32(
216                mask,
217                std::mem::transmute(then_data),
218                std::mem::transmute(else_data),
219            ))
220        };
221    }
222    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
223        return unsafe {
224            std::mem::transmute(conditional_copy_f64(
225                mask,
226                std::mem::transmute(then_data),
227                std::mem::transmute(else_data),
228            ))
229        };
230    }
231    unreachable!("unsupported float type")
232}
233
234// Bit-packed Boolean
235/// Conditional copy operation for boolean bitmask arrays.
236pub fn conditional_copy_bool(
237    mask: &BooleanArray<()>,
238    then_data: &Bitmask,
239    else_data: &Bitmask,
240) -> Result<BooleanArray<()>, KernelError> {
241    let len_bits = mask.len;
242    confirm_capacity("if_then_else: then_data", then_data.capacity(), len_bits)?;
243    confirm_capacity("if_then_else: else_data", else_data.capacity(), len_bits)?;
244
245    // Dense fast-path
246    if mask.null_mask.is_none() {
247        let mut out = Bitmask::with_capacity(len_bits);
248        for i in 0..len_bits {
249            let m = unsafe { mask.data.get_unchecked(i) };
250            let v = if m {
251                unsafe { then_data.get_unchecked(i) }
252            } else {
253                unsafe { else_data.get_unchecked(i) }
254            };
255            out.set(i, v);
256        }
257        return Ok(BooleanArray {
258            data: out.into(),
259            null_mask: None,
260            len: len_bits,
261            _phantom: PhantomData,
262        });
263    }
264
265    // Null-aware path
266    let mut out = Bitmask::with_capacity(len_bits);
267    let mut out_mask = Bitmask::new_set_all(len_bits, false);
268    for i in 0..len_bits {
269        if !mask
270            .null_mask
271            .as_ref()
272            .map_or(true, |m| unsafe { m.get_unchecked(i) })
273        {
274            continue;
275        }
276        let choose_then = unsafe { mask.data.get_unchecked(i) };
277        let v = if choose_then {
278            unsafe { then_data.get_unchecked(i) }
279        } else {
280            unsafe { else_data.get_unchecked(i) }
281        };
282        out.set(i, v);
283        out_mask.set(i, true);
284    }
285
286    Ok(BooleanArray {
287        data: out.into(),
288        null_mask: Some(out_mask),
289        len: len_bits,
290        _phantom: PhantomData,
291    })
292}
293
294// Strings
295#[inline(always)]
296/// Conditional copy operation for UTF-8 string arrays.
297pub fn conditional_copy_str<'a, T: Integer>(
298    mask: BooleanAVT<'a, ()>,
299    then_arr: StringAVT<'a, T>,
300    else_arr: StringAVT<'a, T>,
301) -> Result<StringArray<T>, KernelError> {
302    let (mask_arr, mask_off, mask_len) = mask;
303    let (then_arr, then_off, then_len) = then_arr;
304    let (else_arr, else_off, else_len) = else_arr;
305
306    confirm_equal_len(
307        "if_then_else: then_arr.len() != mask_len",
308        then_len,
309        mask_len,
310    )?;
311    confirm_equal_len(
312        "if_then_else: else_arr.len() != mask_len",
313        else_len,
314        mask_len,
315    )?;
316
317    // First pass: compute total bytes required
318    let mut total_bytes = 0;
319    for i in 0..mask_len {
320        let idx = mask_off + i;
321        let valid = mask_arr
322            .null_mask
323            .as_ref()
324            .map_or(true, |m| unsafe { m.get_unchecked(idx) });
325        if valid {
326            let use_then = unsafe { mask_arr.data.get_unchecked(idx) };
327            let s = unsafe {
328                if use_then {
329                    then_arr.get_str_unchecked(then_off + i)
330                } else {
331                    else_arr.get_str_unchecked(else_off + i)
332                }
333            };
334            total_bytes += s.len();
335        }
336    }
337
338    // Allocate output
339    let mut offsets = Vec64::<T>::with_capacity(mask_len + 1);
340    let mut values = Vec64::<u8>::with_capacity(total_bytes);
341    let mut out_mask = Bitmask::new_set_all(mask_len, false);
342    unsafe {
343        offsets.set_len(mask_len + 1);
344    }
345
346    // Fill
347    offsets[0] = T::zero();
348    let mut cur = 0;
349
350    for i in 0..mask_len {
351        let idx = mask_off + i;
352        let mask_valid = mask_arr
353            .null_mask
354            .as_ref()
355            .map_or(true, |m| unsafe { m.get_unchecked(idx) });
356        if !mask_valid {
357            offsets[i + 1] = T::from_usize(cur);
358            continue;
359        }
360
361        let use_then = unsafe { mask_arr.data.get_unchecked(idx) };
362        let s = unsafe {
363            if use_then {
364                then_arr.get_str_unchecked(then_off + i)
365            } else {
366                else_arr.get_str_unchecked(else_off + i)
367            }
368        }
369        .as_bytes();
370
371        values.extend_from_slice(s);
372        cur += s.len();
373        offsets[i + 1] = T::from_usize(cur);
374        unsafe {
375            out_mask.set_unchecked(i, true);
376        }
377    }
378
379    Ok(StringArray {
380        offsets: offsets.into(),
381        data: values.into(),
382        null_mask: Some(out_mask),
383    })
384}
385
386// Dictionary
387
388/// Conditional copy operation for dictionary/categorical arrays.
389pub fn conditional_copy_dict32<'a, T: Integer>(
390    mask: BooleanAVT<'a, ()>,
391    then_arr: CategoricalAVT<'a, T>,
392    else_arr: CategoricalAVT<'a, T>,
393) -> Result<CategoricalArray<T>, KernelError> {
394    let (mask_arr, mask_off, mask_len) = mask;
395    let (then_arr, then_off, then_len) = then_arr;
396    let (else_arr, else_off, else_len) = else_arr;
397
398    confirm_equal_len(
399        "if_then_else: then_arr.len() != mask_len",
400        then_len,
401        mask_len,
402    )?;
403    confirm_equal_len(
404        "if_then_else: else_arr.len() != mask_len",
405        else_len,
406        mask_len,
407    )?;
408
409    if mask_len == 0 {
410        return Ok(CategoricalArray {
411            data: Vec64::new().into(),
412            unique_values: Vec64::new(),
413            null_mask: Some(Bitmask::new_set_all(0, false)),
414        });
415    }
416
417    // Merge unique values
418    let mut uniques = then_arr.unique_values.clone();
419    for v in &else_arr.unique_values {
420        if !uniques.contains(v) {
421            uniques.push(v.clone());
422        }
423    }
424
425    #[cfg(feature = "fast_hash")]
426    let lookup: AHashMap<&str, T> = uniques
427        .iter()
428        .enumerate()
429        .map(|(i, v)| (v.as_str(), T::from_usize(i)))
430        .collect();
431    #[cfg(not(feature = "fast_hash"))]
432    let lookup: HashMap<&str, T> = uniques
433        .iter()
434        .enumerate()
435        .map(|(i, v)| (v.as_str(), T::from_usize(i)))
436        .collect();
437
438    let mut data = Vec64::<T>::with_capacity(mask_len);
439    unsafe {
440        data.set_len(mask_len);
441    }
442
443    let mut out_mask = Bitmask::new_set_all(mask_len, false);
444
445    for i in 0..mask_len {
446        let mask_idx = mask_off + i;
447        let then_idx = then_off + i;
448        let else_idx = else_off + i;
449
450        let mask_valid = mask_arr
451            .null_mask
452            .as_ref()
453            .map_or(true, |m| unsafe { m.get_unchecked(mask_idx) });
454        if !mask_valid {
455            data[i] = T::zero();
456            continue;
457        }
458
459        let choose_then = unsafe { mask_arr.data.get_unchecked(mask_idx) };
460        let (src_arr, src_idx, valid_mask) = if choose_then {
461            (then_arr, then_idx, then_arr.null_mask.as_ref())
462        } else {
463            (else_arr, else_idx, else_arr.null_mask.as_ref())
464        };
465
466        if valid_mask.map_or(true, |m| unsafe { m.get_unchecked(src_idx) }) {
467            let idx = unsafe { *src_arr.data.get_unchecked(src_idx) }.to_usize();
468            let val = unsafe { src_arr.unique_values.get_unchecked(idx) };
469            data[i] = *lookup.get(val.as_str()).unwrap();
470            unsafe {
471                out_mask.set_unchecked(i, true);
472            }
473        } else {
474            data[i] = T::zero();
475        }
476    }
477
478    Ok(CategoricalArray {
479        data: data.into(),
480        unique_values: uniques,
481        null_mask: Some(out_mask),
482    })
483}
484
#[cfg(test)]
mod tests {
    use super::*;
    use minarrow::{
        Bitmask, BooleanArray, MaskedArray,
        structs::variants::{categorical::CategoricalArray, string::StringArray},
    };

    /// Builds a `Bitmask` from a slice of bools.
    fn bm(bools: &[bool]) -> Bitmask {
        Bitmask::from_bools(bools)
    }

    /// Builds a `BooleanArray` from a slice of bools.
    fn bool_arr(bools: &[bool]) -> BooleanArray<()> {
        BooleanArray::from_slice(bools)
    }

    // Numeric kernels without a null mask: values follow the mask, no validity.
    #[test]
    fn test_conditional_copy_numeric_no_null() {
        // i32
        let mask = bool_arr(&[true, false, true, false, false, true]);
        let then = vec![10, 20, 30, 40, 50, 60];
        let els = vec![1, 2, 3, 4, 5, 6];
        let arr = conditional_copy_i32(&mask, &then, &els);
        assert_eq!(&arr.data[..], &[10, 2, 30, 4, 5, 60]);
        assert!(arr.null_mask.is_none());

        // f64
        let mask = bool_arr(&[true, false, false]);
        let then = vec![2.0, 4.0, 6.0];
        let els = vec![1.0, 3.0, 5.0];
        let arr = conditional_copy_f64(&mask, &then, &els);
        assert_eq!(&arr.data[..], &[2.0, 3.0, 5.0]);
        assert!(arr.null_mask.is_none());
    }

    // Numeric kernels propagate the mask's null mask unchanged.
    #[test]
    fn test_conditional_copy_numeric_with_null() {
        let mut mask = bool_arr(&[true, false, true]);
        mask.null_mask = Some(bm(&[true, false, true]));
        let then = vec![10i64, 20, 30];
        let els = vec![1i64, 2, 3];
        let arr = conditional_copy_i64(&mask, &then, &els);
        assert_eq!(&arr.data[..], &[10, 2, 30]);
        let null_mask = arr.null_mask.as_ref().unwrap();
        assert!(null_mask.get(0));
        assert!(!null_mask.get(1));
        assert!(null_mask.get(2));
    }

    // Empty input. Uses `u32`, so it needs no optional feature.
    #[test]
    fn test_conditional_copy_numeric_empty() {
        let mask = bool_arr(&[]);
        let then: Vec<u32> = vec![];
        let els: Vec<u32> = vec![];
        let arr = conditional_copy_u32(&mask, &then, &els);
        assert_eq!(arr.data.len(), 0);
        assert!(arr.null_mask.is_none());
    }

    // Single-element inputs with the mask true / false / null.
    #[cfg(feature = "extended_numeric_types")]
    #[test]
    fn test_conditional_copy_numeric_edge_cases() {
        let mask = bool_arr(&[true]);
        let then = vec![42];
        let els = vec![1];
        let arr = conditional_copy_u8(&mask, &then, &els);
        assert_eq!(&arr.data[..], &[42]);
        let mask = bool_arr(&[false]);
        let arr = conditional_copy_u8(&mask, &then, &els);
        assert_eq!(&arr.data[..], &[1]);
        let mut mask = bool_arr(&[true]);
        mask.null_mask = Some(bm(&[false]));
        let arr = conditional_copy_u8(&mask, &then, &els);
        assert_eq!(&arr.data[..], &[42]);
        assert!(!arr.null_mask.as_ref().unwrap().get(0));
    }

    #[test]
    fn test_conditional_copy_bool_no_null() {
        let mask = bool_arr(&[true, false, true, false]);
        let then = bm(&[true, true, false, false]);
        let els = bm(&[false, true, true, false]);
        let out = conditional_copy_bool(&mask, &then, &els).unwrap();
        assert!(out.data.get(0));
        assert!(out.data.get(1));
        assert!(!out.data.get(2));
        assert!(!out.data.get(3));
        assert!(out.null_mask.is_none());
        assert_eq!(out.len, 4);
    }

    #[test]
    fn test_conditional_copy_bool_with_null() {
        let mut mask = bool_arr(&[true, false, true, false, true]);
        mask.null_mask = Some(bm(&[true, false, true, true, false]));
        let then = bm(&[false, true, false, true, true]);
        let els = bm(&[true, false, true, false, false]);
        let out = conditional_copy_bool(&mask, &then, &els).unwrap();
        // Only positions 0,2,3 should be valid (others null)
        let null_mask = out.null_mask.as_ref().unwrap();
        assert!(null_mask.get(0));
        assert!(!null_mask.get(1));
        assert!(null_mask.get(2));
        assert!(null_mask.get(3));
        assert!(!null_mask.get(4));
        assert!(!out.data.get(0));
        assert!(!out.data.get(2));
        assert!(!out.data.get(3));
    }

    #[test]
    fn test_conditional_copy_bool_edge_cases() {
        // Empty
        let mask = bool_arr(&[]);
        let then = bm(&[]);
        let els = bm(&[]);
        let out = conditional_copy_bool(&mask, &then, &els).unwrap();
        assert_eq!(out.len, 0);
        assert!(out.data.is_empty());
        // All nulls in mask
        let mut mask = bool_arr(&[true, false, true]);
        mask.null_mask = Some(bm(&[false, false, false]));
        let then = bm(&[true, true, true]);
        let els = bm(&[false, false, false]);
        let out = conditional_copy_bool(&mask, &then, &els).unwrap();
        assert_eq!(out.len, 3);
        let null_mask = out.null_mask.as_ref().unwrap();
        assert!(!null_mask.get(0));
        assert!(!null_mask.get(1));
        assert!(!null_mask.get(2));
    }

    #[test]
    fn test_conditional_copy_float_type_dispatch() {
        // f32
        let mask = bool_arr(&[true, false]);
        let then: Vec<f32> = vec![1.0, 2.0];
        let els: Vec<f32> = vec![3.0, 4.0];
        let arr = conditional_copy_float(&mask, &then, &els);
        assert_eq!(&arr.data[..], &[1.0, 4.0]);
        // f64
        let mask = bool_arr(&[false, true]);
        let then: Vec<f64> = vec![7.0, 8.0];
        let els: Vec<f64> = vec![9.0, 10.0];
        let arr = conditional_copy_float(&mask, &then, &els);
        assert_eq!(&arr.data[..], &[9.0, 8.0]);
    }

    #[test]
    fn test_conditional_copy_str_basic() {
        // mask of length 4
        let mask = bool_arr(&[true, false, false, true]);

        // then_arr and else_arr must also be length 4
        let a = StringArray::<u32>::from_slice(&["foo", "bar", "baz", "qux"]);
        let b = StringArray::<u32>::from_slice(&["AAA", "Y", "Z", "BBB"]);

        // Wrap as slices
        let mask_slice = (&mask, 0, mask.len());
        let a_slice = (&a, 0, a.len());
        let b_slice = (&b, 0, b.len());

        let arr = conditional_copy_str(mask_slice, a_slice, b_slice).unwrap();

        // mask picks a[0], b[1], b[2], a[3]
        assert_eq!(arr.get(0), Some("foo"));
        assert_eq!(arr.get(1), Some("Y"));
        assert_eq!(arr.get(2), Some("Z"));
        assert_eq!(arr.get(3), Some("qux"));

        assert_eq!(arr.len(), 4);
        let nm = arr.null_mask.as_ref().unwrap();
        assert!(nm.all_set());
    }

    #[test]
    fn test_conditional_copy_str_with_null() {
        let mut mask = bool_arr(&[true, false, true]);
        mask.null_mask = Some(bm(&[true, false, true]));
        let mut a = StringArray::<u32>::from_slice(&["one", "two", "three"]);
        let mut b = StringArray::<u32>::from_slice(&["uno", "dos", "tres"]);
        a.set_null(2);
        b.set_null(0);
        let mask_slice = (&mask, 0, mask.len());
        let a_slice = (&a, 0, a.len());
        let b_slice = (&b, 0, b.len());

        let arr = conditional_copy_str(mask_slice, a_slice, b_slice).unwrap();

        assert_eq!(arr.get(0), Some("one"));
        assert!(arr.get(1).is_none());
        assert_eq!(arr.get(2), Some(""));
        let nm = arr.null_mask.as_ref().unwrap();
        assert!(nm.get(0));
        assert!(!nm.get(1));
        assert!(nm.get(2));
    }

    #[test]
    fn test_conditional_copy_str_with_null_chunk() {
        // pad to allow offset
        let mut mask = bool_arr(&[false, true, false, true, false]);
        mask.null_mask = Some(bm(&[false, true, false, true, false]));
        let mut a = StringArray::<u32>::from_slice(&["", "one", "two", "three", ""]);
        let mut b = StringArray::<u32>::from_slice(&["", "uno", "dos", "tres", ""]);
        a.set_null(3);
        b.set_null(1);
        let mask_slice = (&mask, 1, 3);
        let a_slice = (&a, 1, 3);
        let b_slice = (&b, 1, 3);

        let arr = conditional_copy_str(mask_slice, a_slice, b_slice).unwrap();

        assert_eq!(arr.get(0), Some("one"));
        assert!(arr.get(1).is_none());
        assert_eq!(arr.get(2), Some(""));
        let nm = arr.null_mask.as_ref().unwrap();
        assert!(nm.get(0));
        assert!(!nm.get(1));
        assert!(nm.get(2));
    }

    #[test]
    fn test_conditional_copy_str_edge_cases() {
        let mut mask = bool_arr(&[true, false]);
        mask.null_mask = Some(bm(&[false, false]));
        let a = StringArray::<u32>::from_slice(&["foo", "bar"]);
        let b = StringArray::<u32>::from_slice(&["baz", "qux"]);
        let mask_slice = (&mask, 0, mask.len());
        let a_slice = (&a, 0, a.len());
        let b_slice = (&b, 0, b.len());
        let arr = conditional_copy_str(mask_slice, a_slice, b_slice).unwrap();
        assert_eq!(arr.len(), 2);
        assert!(!arr.null_mask.as_ref().unwrap().get(0));
        assert!(!arr.null_mask.as_ref().unwrap().get(1));
        // Empty arrays
        let mask = bool_arr(&[]);
        let a = StringArray::<u32>::from_slice(&[]);
        let b = StringArray::<u32>::from_slice(&[]);
        let mask_slice = (&mask, 0, 0);
        let a_slice = (&a, 0, 0);
        let b_slice = (&b, 0, 0);
        let arr = conditional_copy_str(mask_slice, a_slice, b_slice).unwrap();
        assert_eq!(arr.len(), 0);
    }

    #[test]
    fn test_conditional_copy_str_edge_cases_chunk() {
        // chunked window: use 0-length window
        let mask = bool_arr(&[false, true, false]);
        let a = StringArray::<u32>::from_slice(&["foo", "bar", "baz"]);
        let b = StringArray::<u32>::from_slice(&["qux", "quux", "quuz"]);
        let mask_slice = (&mask, 1, 0);
        let a_slice = (&a, 1, 0);
        let b_slice = (&b, 1, 0);
        let arr = conditional_copy_str(mask_slice, a_slice, b_slice).unwrap();
        assert_eq!(arr.len(), 0);
    }

    #[test]
    fn test_conditional_copy_dict32_basic() {
        let mask = bool_arr(&[true, false, false, true]);
        let a = CategoricalArray::<u32>::from_slices(
            &[0, 1, 1, 0],
            &["dog".to_string(), "cat".to_string()],
        );
        let b = CategoricalArray::<u32>::from_slices(
            &[0, 0, 1, 1],
            &["fish".to_string(), "cat".to_string()],
        );
        let mask_slice = (&mask, 0, mask.len());
        let a_slice = (&a, 0, a.data.len());
        let b_slice = (&b, 0, b.data.len());

        let arr = conditional_copy_dict32(mask_slice, a_slice, b_slice).unwrap();

        let vals: Vec<Option<&str>> = (0..4).map(|i| arr.get(i)).collect();
        assert_eq!(vals[0], Some("dog"));
        assert_eq!(vals[1], Some("fish"));
        assert_eq!(vals[2], Some("cat"));
        assert_eq!(vals[3], Some("dog"));

        let mut all = arr
            .unique_values
            .iter()
            .map(|s| s.as_str())
            .collect::<Vec<_>>();
        all.sort();
        let mut ref_all = vec!["cat", "dog", "fish"];
        ref_all.sort();
        assert_eq!(all, ref_all);

        assert!(arr.null_mask.as_ref().unwrap().all_set());
    }

    #[test]
    fn test_conditional_copy_dict32_basic_chunk() {
        let mask = bool_arr(&[false, true, false, true, false]);
        let a = CategoricalArray::<u32>::from_slices(
            &[0, 1, 1, 0, 0],
            &["dog".to_string(), "cat".to_string()],
        );
        let b = CategoricalArray::<u32>::from_slices(
            &[0, 0, 1, 1, 0],
            &["fish".to_string(), "cat".to_string()],
        );
        let mask_slice = (&mask, 1, 3);
        let a_slice = (&a, 1, 3);
        let b_slice = (&b, 1, 3);

        let arr = conditional_copy_dict32(mask_slice, a_slice, b_slice).unwrap();

        let vals: Vec<Option<&str>> = (0..3).map(|i| arr.get(i)).collect();
        assert_eq!(vals[0], Some("cat"));
        assert_eq!(vals[1], Some("cat"));
        assert_eq!(vals[2], Some("dog"));

        let mut all = arr
            .unique_values
            .iter()
            .map(|s| s.as_str())
            .collect::<Vec<_>>();
        all.sort();
        let mut ref_all = vec!["cat", "dog", "fish"];
        ref_all.sort();
        assert_eq!(all, ref_all);

        assert!(arr.null_mask.as_ref().unwrap().all_set());
    }

    #[test]
    fn test_conditional_copy_dict32_with_null() {
        let mut mask = bool_arr(&[true, false, true]);
        mask.null_mask = Some(bm(&[true, false, true]));
        let mut a = CategoricalArray::<u32>::from_slices(
            &[0, 1, 0],
            &["dog".to_string(), "cat".to_string()],
        );
        let mut b = CategoricalArray::<u32>::from_slices(
            &[1, 0, 1],
            &["cat".to_string(), "fish".to_string()],
        );
        a.set_null(2);
        b.set_null(0);
        let mask_slice = (&mask, 0, mask.len());
        let a_slice = (&a, 0, a.data.len());
        let b_slice = (&b, 0, b.data.len());
        let arr = conditional_copy_dict32(mask_slice, a_slice, b_slice).unwrap();
        assert_eq!(arr.get(0), Some("dog"));
        assert!(arr.get(1).is_none());
        assert!(arr.get(2).is_none());
        assert!(arr.null_mask.as_ref().unwrap().get(0));
        assert!(!arr.null_mask.as_ref().unwrap().get(1));
        assert!(!arr.null_mask.as_ref().unwrap().get(2));
    }

    #[test]
    fn test_conditional_copy_dict32_with_null_chunk() {
        let mut mask = bool_arr(&[false, true, false, true]);
        mask.null_mask = Some(bm(&[false, true, false, true]));
        let mut a = CategoricalArray::<u32>::from_slices(
            &[0, 1, 0, 1],
            &["dog".to_string(), "cat".to_string()],
        );
        let mut b = CategoricalArray::<u32>::from_slices(
            &[1, 0, 1, 0],
            &["cat".to_string(), "fish".to_string()],
        );
        a.set_null(3);
        b.set_null(1);
        let mask_slice = (&mask, 1, 2);
        let a_slice = (&a, 1, 2);
        let b_slice = (&b, 1, 2);
        let arr = conditional_copy_dict32(mask_slice, a_slice, b_slice).unwrap();
        assert_eq!(arr.get(0), Some("cat"));
        assert!(arr.get(1).is_none());
        assert!(arr.null_mask.as_ref().unwrap().get(0));
        assert!(!arr.null_mask.as_ref().unwrap().get(1));
    }

    #[test]
    fn test_conditional_copy_dict32_edge_cases() {
        let mask = bool_arr(&[]);
        let a = CategoricalArray::<u32>::from_slices(&[], &[]);
        let b = CategoricalArray::<u32>::from_slices(&[], &[]);
        let mask_slice = (&mask, 0, 0);
        let a_slice = (&a, 0, 0);
        let b_slice = (&b, 0, 0);
        let arr = conditional_copy_dict32(mask_slice, a_slice, b_slice).unwrap();
        assert_eq!(arr.data.len(), 0);
        assert_eq!(arr.unique_values.len(), 0);

        let mut mask = bool_arr(&[true, false]);
        mask.null_mask = Some(bm(&[false, false]));
        let a = CategoricalArray::<u32>::from_slices(&[0, 0], &["foo".to_string()]);
        let b = CategoricalArray::<u32>::from_slices(&[0, 0], &["bar".to_string()]);
        let mask_slice = (&mask, 0, mask.len());
        let a_slice = (&a, 0, a.data.len());
        let b_slice = (&b, 0, b.data.len());
        let arr = conditional_copy_dict32(mask_slice, a_slice, b_slice).unwrap();
        assert!(!arr.null_mask.as_ref().unwrap().get(0));
        assert!(!arr.null_mask.as_ref().unwrap().get(1));
    }

    #[test]
    fn test_conditional_copy_dict32_edge_cases_chunk() {
        let mask = bool_arr(&[false, false, false]);
        let a = CategoricalArray::<u32>::from_slices(&[0, 0, 0], &["foo".to_string()]);
        let b = CategoricalArray::<u32>::from_slices(&[0, 0, 0], &["bar".to_string()]);
        let mask_slice = (&mask, 1, 0);
        let a_slice = (&a, 1, 0);
        let b_slice = (&b, 1, 0);
        let arr = conditional_copy_dict32(mask_slice, a_slice, b_slice).unwrap();
        assert_eq!(arr.data.len(), 0);
        assert_eq!(arr.unique_values.len(), 0);
    }
}