simd_kernels/kernels/
conditional.rs

1// Copyright Peter Bower 2025. All Rights Reserved.
2// Licensed under Mozilla Public License (MPL) 2.0.
3
//! # **Conditional Logic Kernels Module** - *High-Performance Conditional Operations and Data Selection*
//!
//! Conditional logic kernels for efficient data selection, filtering, and transformation,
//! with null handling and optional SIMD acceleration. These kernels are core infrastructure
//! for analytical workflows and query execution.
//!
//! ## Core Operations
//! - **Conditional selection**: IF-THEN-ELSE operations with three-valued logic support
//! - **Array filtering**: Efficient boolean mask-based filtering with zero-copy optimisation
//! - **Coalescing operations**: Null-aware value selection with fallback hierarchies
//! - **Case-when logic**: Multi-condition branching with optimised evaluation strategies
//! - **Null propagation**: Comprehensive null handling following Apache Arrow semantics
//! - **Type preservation**: Maintains input data types through conditional transformations
17
18include!(concat!(env!("OUT_DIR"), "/simd_lanes.rs"));
19
20use std::marker::PhantomData;
21
22#[cfg(feature = "fast_hash")]
23use ahash::AHashMap;
24#[cfg(not(feature = "fast_hash"))]
25use std::collections::HashMap;
26
27use minarrow::{
28    Bitmask, BooleanAVT, BooleanArray, CategoricalAVT, CategoricalArray, FloatArray, Integer,
29    IntegerArray, StringAVT, StringArray, Vec64,
30};
31
32#[cfg(feature = "simd")]
33use core::simd::{Mask, Simd};
34
35#[cfg(feature = "simd")]
36use crate::utils::is_simd_aligned;
37
38use crate::{
39    errors::KernelError,
40    utils::{confirm_capacity, confirm_equal_len},
41};
42#[cfg(feature = "datetime")]
43use minarrow::{DatetimeArray, TimeUnit};
44
45#[inline(always)]
46fn prealloc_vec<T: Copy>(len: usize) -> Vec64<T> {
47    let mut v = Vec64::<T>::with_capacity(len);
48    // SAFETY: every slot is written before any read
49    unsafe { v.set_len(len) };
50    v
51}
52
// Numeric (integer and float) kernels
macro_rules! impl_conditional_copy_numeric {
    ($fn_name:ident, $ty:ty, $mask_elem:ty, $lanes:expr, $array_ty:ident) => {
        /// Conditional copy operation: select elements from `then_data` or `else_data` based on boolean mask.
        ///
        /// For each `i < mask.len`, the output holds `then_data[i]` when mask bit `i` is set and
        /// `else_data[i]` otherwise. The output null mask is a clone of `mask.null_mask`. Rows
        /// where the mask is null still receive the value chosen by the underlying mask bit.
        /// Elements of `then_data` / `else_data` beyond `mask.len` are ignored.
        ///
        /// # Panics
        /// Panics if `then_data` or `else_data` is shorter than `mask.len`.
        #[inline(always)]
        pub fn $fn_name(
            mask: &BooleanArray<()>,
            then_data: &[$ty],
            else_data: &[$ty],
        ) -> $array_ty<$ty> {
            let len = mask.len;
            assert!(
                then_data.len() >= len && else_data.len() >= len,
                concat!(stringify!($fn_name), ": then_data/else_data shorter than mask")
            );
            // Re-slice to exactly `len`: fails fast before any work is done, and lets the
            // loops below index without per-element bounds checks.
            let then_data = &then_data[..len];
            let else_data = &else_data[..len];
            let mask_data = &mask.data;
            let mut data = prealloc_vec::<$ty>(len);

            // Number of leading elements already written by the SIMD loop.
            #[cfg(feature = "simd")]
            let done = if is_simd_aligned(then_data) && is_simd_aligned(else_data) {
                const N: usize = $lanes;
                let mut i = 0;
                while i + N <= len {
                    let mut bits = [false; N];
                    for (l, bit) in bits.iter_mut().enumerate() {
                        // SAFETY: i + l < len == mask.len.
                        *bit = unsafe { mask_data.get_unchecked(i + l) };
                    }
                    let cond = Mask::<$mask_elem, N>::from_array(bits);
                    let t = Simd::<$ty, N>::from_slice(&then_data[i..i + N]);
                    let e = Simd::<$ty, N>::from_slice(&else_data[i..i + N]);
                    cond.select(t, e).copy_to_slice(&mut data[i..i + N]);
                    i += N;
                }
                i
            } else {
                // Alignment check failed: the scalar loop handles everything.
                0
            };
            #[cfg(not(feature = "simd"))]
            let done = 0;

            // Scalar path: the SIMD tail (`len % N` elements) or the whole input.
            for j in done..len {
                // SAFETY: j < len == mask.len.
                data[j] = if unsafe { mask_data.get_unchecked(j) } {
                    then_data[j]
                } else {
                    else_data[j]
                };
            }

            $array_ty {
                data: data.into(),
                null_mask: mask.null_mask.clone(),
            }
        }
    };
}
116
// Datetime kernels
#[cfg(feature = "datetime")]
macro_rules! impl_conditional_copy_datetime {
    ($fn_name:ident, $ty:ty, $mask_elem:ty, $lanes:expr) => {
        /// Conditional copy for datetime values: select from `then_data` or `else_data` per mask bit.
        ///
        /// For each `i < mask.len`, the output holds `then_data[i]` when mask bit `i` is set and
        /// `else_data[i]` otherwise. The output null mask is a clone of `mask.null_mask`, and the
        /// result is tagged with `time_unit`. Elements beyond `mask.len` are ignored.
        ///
        /// # Panics
        /// Panics if `then_data` or `else_data` is shorter than `mask.len`.
        #[inline(always)]
        pub fn $fn_name(
            mask: &BooleanArray<()>,
            then_data: &[$ty],
            else_data: &[$ty],
            time_unit: TimeUnit,
        ) -> DatetimeArray<$ty> {
            let len = mask.len;
            assert!(
                then_data.len() >= len && else_data.len() >= len,
                concat!(stringify!($fn_name), ": then_data/else_data shorter than mask")
            );
            // Re-slice to exactly `len` so the loops below index without bounds checks.
            let then_data = &then_data[..len];
            let else_data = &else_data[..len];
            let mask_data = &mask.data;
            let mut data = prealloc_vec::<$ty>(len);

            // Number of leading elements already written by the SIMD loop.
            #[cfg(feature = "simd")]
            let done = if is_simd_aligned(then_data) && is_simd_aligned(else_data) {
                const N: usize = $lanes;
                let mut i = 0;
                while i + N <= len {
                    let mut bits = [false; N];
                    for (l, bit) in bits.iter_mut().enumerate() {
                        // SAFETY: i + l < len == mask.len.
                        *bit = unsafe { mask_data.get_unchecked(i + l) };
                    }
                    let cond = Mask::<$mask_elem, N>::from_array(bits);
                    let t = Simd::<$ty, N>::from_slice(&then_data[i..i + N]);
                    let e = Simd::<$ty, N>::from_slice(&else_data[i..i + N]);
                    cond.select(t, e).copy_to_slice(&mut data[i..i + N]);
                    i += N;
                }
                i
            } else {
                // Alignment check failed: the scalar loop handles everything.
                0
            };
            #[cfg(not(feature = "simd"))]
            let done = 0;

            // Scalar path: the SIMD tail (`len % N` elements) or the whole input.
            for j in done..len {
                // SAFETY: j < len == mask.len.
                data[j] = if unsafe { mask_data.get_unchecked(j) } {
                    then_data[j]
                } else {
                    else_data[j]
                };
            }

            DatetimeArray {
                data: data.into(),
                null_mask: mask.null_mask.clone(),
                time_unit,
            }
        }
    };
}
187
188#[cfg(feature = "extended_numeric_types")]
189impl_conditional_copy_numeric!(conditional_copy_i8, i8, i8, W8, IntegerArray);
190#[cfg(feature = "extended_numeric_types")]
191impl_conditional_copy_numeric!(conditional_copy_u8, u8, i8, W8, IntegerArray);
192#[cfg(feature = "extended_numeric_types")]
193impl_conditional_copy_numeric!(conditional_copy_i16, i16, i16, W16, IntegerArray);
194#[cfg(feature = "extended_numeric_types")]
195impl_conditional_copy_numeric!(conditional_copy_u16, u16, i16, W16, IntegerArray);
196impl_conditional_copy_numeric!(conditional_copy_i32, i32, i32, W32, IntegerArray);
197impl_conditional_copy_numeric!(conditional_copy_u32, u32, i32, W32, IntegerArray);
198impl_conditional_copy_numeric!(conditional_copy_i64, i64, i64, W64, IntegerArray);
199impl_conditional_copy_numeric!(conditional_copy_u64, u64, i64, W64, IntegerArray);
200impl_conditional_copy_numeric!(conditional_copy_f32, f32, i32, W32, FloatArray);
201impl_conditional_copy_numeric!(conditional_copy_f64, f64, i64, W64, FloatArray);
202
203#[cfg(feature = "datetime")]
204impl_conditional_copy_datetime!(conditional_copy_datetime32, i32, i32, W32);
205#[cfg(feature = "datetime")]
206impl_conditional_copy_datetime!(conditional_copy_datetime64, i64, i64, W64);
207
208/// Conditional copy for floating-point arrays with runtime type dispatch.
209#[inline(always)]
210pub fn conditional_copy_float<T: Copy + 'static>(
211    mask: &BooleanArray<()>,
212    then_data: &[T],
213    else_data: &[T],
214) -> FloatArray<T> {
215    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
216        return unsafe {
217            std::mem::transmute(conditional_copy_f32(
218                mask,
219                std::mem::transmute(then_data),
220                std::mem::transmute(else_data),
221            ))
222        };
223    }
224    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
225        return unsafe {
226            std::mem::transmute(conditional_copy_f64(
227                mask,
228                std::mem::transmute(then_data),
229                std::mem::transmute(else_data),
230            ))
231        };
232    }
233    unreachable!("unsupported float type")
234}
235
236// Bit-packed Boolean
237/// Conditional copy operation for boolean bitmask arrays.
238pub fn conditional_copy_bool(
239    mask: &BooleanArray<()>,
240    then_data: &Bitmask,
241    else_data: &Bitmask,
242) -> Result<BooleanArray<()>, KernelError> {
243    let len_bits = mask.len;
244    confirm_capacity("if_then_else: then_data", then_data.capacity(), len_bits)?;
245    confirm_capacity("if_then_else: else_data", else_data.capacity(), len_bits)?;
246
247    // Dense fast-path
248    if mask.null_mask.is_none() {
249        let mut out = Bitmask::with_capacity(len_bits);
250        for i in 0..len_bits {
251            let m = unsafe { mask.data.get_unchecked(i) };
252            let v = if m {
253                unsafe { then_data.get_unchecked(i) }
254            } else {
255                unsafe { else_data.get_unchecked(i) }
256            };
257            out.set(i, v);
258        }
259        return Ok(BooleanArray {
260            data: out.into(),
261            null_mask: None,
262            len: len_bits,
263            _phantom: PhantomData,
264        });
265    }
266
267    // Null-aware path
268    let mut out = Bitmask::with_capacity(len_bits);
269    let mut out_mask = Bitmask::new_set_all(len_bits, false);
270    for i in 0..len_bits {
271        if !mask
272            .null_mask
273            .as_ref()
274            .map_or(true, |m| unsafe { m.get_unchecked(i) })
275        {
276            continue;
277        }
278        let choose_then = unsafe { mask.data.get_unchecked(i) };
279        let v = if choose_then {
280            unsafe { then_data.get_unchecked(i) }
281        } else {
282            unsafe { else_data.get_unchecked(i) }
283        };
284        out.set(i, v);
285        out_mask.set(i, true);
286    }
287
288    Ok(BooleanArray {
289        data: out.into(),
290        null_mask: Some(out_mask),
291        len: len_bits,
292        _phantom: PhantomData,
293    })
294}
295
296// Strings
297#[inline(always)]
298/// Conditional copy operation for UTF-8 string arrays.
299pub fn conditional_copy_str<'a, T: Integer>(
300    mask: BooleanAVT<'a, ()>,
301    then_arr: StringAVT<'a, T>,
302    else_arr: StringAVT<'a, T>,
303) -> Result<StringArray<T>, KernelError> {
304    let (mask_arr, mask_off, mask_len) = mask;
305    let (then_arr, then_off, then_len) = then_arr;
306    let (else_arr, else_off, else_len) = else_arr;
307
308    confirm_equal_len(
309        "if_then_else: then_arr.len() != mask_len",
310        then_len,
311        mask_len,
312    )?;
313    confirm_equal_len(
314        "if_then_else: else_arr.len() != mask_len",
315        else_len,
316        mask_len,
317    )?;
318
319    // First pass: compute total bytes required
320    let mut total_bytes = 0;
321    for i in 0..mask_len {
322        let idx = mask_off + i;
323        let valid = mask_arr
324            .null_mask
325            .as_ref()
326            .map_or(true, |m| unsafe { m.get_unchecked(idx) });
327        if valid {
328            let use_then = unsafe { mask_arr.data.get_unchecked(idx) };
329            let s = unsafe {
330                if use_then {
331                    then_arr.get_str_unchecked(then_off + i)
332                } else {
333                    else_arr.get_str_unchecked(else_off + i)
334                }
335            };
336            total_bytes += s.len();
337        }
338    }
339
340    // Allocate output
341    let mut offsets = Vec64::<T>::with_capacity(mask_len + 1);
342    let mut values = Vec64::<u8>::with_capacity(total_bytes);
343    let mut out_mask = Bitmask::new_set_all(mask_len, false);
344    unsafe {
345        offsets.set_len(mask_len + 1);
346    }
347
348    // Fill
349    offsets[0] = T::zero();
350    let mut cur = 0;
351
352    for i in 0..mask_len {
353        let idx = mask_off + i;
354        let mask_valid = mask_arr
355            .null_mask
356            .as_ref()
357            .map_or(true, |m| unsafe { m.get_unchecked(idx) });
358        if !mask_valid {
359            offsets[i + 1] = T::from_usize(cur);
360            continue;
361        }
362
363        let use_then = unsafe { mask_arr.data.get_unchecked(idx) };
364        let s = unsafe {
365            if use_then {
366                then_arr.get_str_unchecked(then_off + i)
367            } else {
368                else_arr.get_str_unchecked(else_off + i)
369            }
370        }
371        .as_bytes();
372
373        values.extend_from_slice(s);
374        cur += s.len();
375        offsets[i + 1] = T::from_usize(cur);
376        unsafe {
377            out_mask.set_unchecked(i, true);
378        }
379    }
380
381    Ok(StringArray {
382        offsets: offsets.into(),
383        data: values.into(),
384        null_mask: Some(out_mask),
385    })
386}
387
388// Dictionary
389
390/// Conditional copy operation for dictionary/categorical arrays.
391pub fn conditional_copy_dict32<'a, T: Integer>(
392    mask: BooleanAVT<'a, ()>,
393    then_arr: CategoricalAVT<'a, T>,
394    else_arr: CategoricalAVT<'a, T>,
395) -> Result<CategoricalArray<T>, KernelError> {
396    let (mask_arr, mask_off, mask_len) = mask;
397    let (then_arr, then_off, then_len) = then_arr;
398    let (else_arr, else_off, else_len) = else_arr;
399
400    confirm_equal_len(
401        "if_then_else: then_arr.len() != mask_len",
402        then_len,
403        mask_len,
404    )?;
405    confirm_equal_len(
406        "if_then_else: else_arr.len() != mask_len",
407        else_len,
408        mask_len,
409    )?;
410
411    if mask_len == 0 {
412        return Ok(CategoricalArray {
413            data: Vec64::new().into(),
414            unique_values: Vec64::new(),
415            null_mask: Some(Bitmask::new_set_all(0, false)),
416        });
417    }
418
419    // Merge unique values
420    let mut uniques = then_arr.unique_values.clone();
421    for v in &else_arr.unique_values {
422        if !uniques.contains(v) {
423            uniques.push(v.clone());
424        }
425    }
426
427    #[cfg(feature = "fast_hash")]
428    let lookup: AHashMap<&str, T> = uniques
429        .iter()
430        .enumerate()
431        .map(|(i, v)| (v.as_str(), T::from_usize(i)))
432        .collect();
433    #[cfg(not(feature = "fast_hash"))]
434    let lookup: HashMap<&str, T> = uniques
435        .iter()
436        .enumerate()
437        .map(|(i, v)| (v.as_str(), T::from_usize(i)))
438        .collect();
439
440    let mut data = Vec64::<T>::with_capacity(mask_len);
441    unsafe {
442        data.set_len(mask_len);
443    }
444
445    let mut out_mask = Bitmask::new_set_all(mask_len, false);
446
447    for i in 0..mask_len {
448        let mask_idx = mask_off + i;
449        let then_idx = then_off + i;
450        let else_idx = else_off + i;
451
452        let mask_valid = mask_arr
453            .null_mask
454            .as_ref()
455            .map_or(true, |m| unsafe { m.get_unchecked(mask_idx) });
456        if !mask_valid {
457            data[i] = T::zero();
458            continue;
459        }
460
461        let choose_then = unsafe { mask_arr.data.get_unchecked(mask_idx) };
462        let (src_arr, src_idx, valid_mask) = if choose_then {
463            (then_arr, then_idx, then_arr.null_mask.as_ref())
464        } else {
465            (else_arr, else_idx, else_arr.null_mask.as_ref())
466        };
467
468        if valid_mask.map_or(true, |m| unsafe { m.get_unchecked(src_idx) }) {
469            let idx = unsafe { *src_arr.data.get_unchecked(src_idx) }.to_usize();
470            let val = unsafe { src_arr.unique_values.get_unchecked(idx) };
471            data[i] = *lookup.get(val.as_str()).unwrap();
472            unsafe {
473                out_mask.set_unchecked(i, true);
474            }
475        } else {
476            data[i] = T::zero();
477        }
478    }
479
480    Ok(CategoricalArray {
481        data: data.into(),
482        unique_values: uniques,
483        null_mask: Some(out_mask),
484    })
485}
486
#[cfg(test)]
mod tests {
    use super::*;
    use minarrow::{
        Bitmask, BooleanArray, MaskedArray, Vec64,
        structs::variants::{categorical::CategoricalArray, string::StringArray},
    };

    fn bm(bools: &[bool]) -> Bitmask {
        Bitmask::from_bools(bools)
    }

    fn bool_arr(bools: &[bool]) -> BooleanArray<()> {
        BooleanArray::from_slice(bools)
    }

    #[test]
    fn test_conditional_copy_numeric_no_null() {
        // i32
        let mask = bool_arr(&[true, false, true, false, false, true]);
        let then = vec![10, 20, 30, 40, 50, 60];
        let els = vec![1, 2, 3, 4, 5, 6];
        let arr = conditional_copy_i32(&mask, &then, &els);
        assert_eq!(&arr.data[..], &[10, 2, 30, 4, 5, 60]);
        assert!(arr.null_mask.is_none());

        // f64
        let mask = bool_arr(&[true, false, false]);
        let then = vec![2.0, 4.0, 6.0];
        let els = vec![1.0, 3.0, 5.0];
        let arr = conditional_copy_f64(&mask, &then, &els);
        assert_eq!(&arr.data[..], &[2.0, 3.0, 5.0]);
        assert!(arr.null_mask.is_none());
    }

    #[test]
    fn test_conditional_copy_numeric_long_input() {
        // Long enough to span several SIMD chunks plus a scalar tail when the `simd`
        // feature is on; Vec64 inputs are 64-byte aligned so the SIMD path can be taken.
        let n = 67usize;
        let bools: Vec<bool> = (0..n).map(|i| i % 3 == 0).collect();
        let mask = bool_arr(&bools);

        let mut then = Vec64::<i32>::with_capacity(n);
        let mut els = Vec64::<i32>::with_capacity(n);
        for i in 0..n {
            then.push(i as i32);
            els.push(-(i as i32));
        }
        let expected: Vec<i32> = (0..n)
            .map(|i| if bools[i] { i as i32 } else { -(i as i32) })
            .collect();
        let arr = conditional_copy_i32(&mask, &then[..], &els[..]);
        assert_eq!(&arr.data[..], &expected[..]);
        assert!(arr.null_mask.is_none());

        let mut then = Vec64::<f64>::with_capacity(n);
        let mut els = Vec64::<f64>::with_capacity(n);
        for i in 0..n {
            then.push(i as f64 * 0.5);
            els.push(-(i as f64));
        }
        let expected: Vec<f64> = (0..n)
            .map(|i| if bools[i] { i as f64 * 0.5 } else { -(i as f64) })
            .collect();
        let arr = conditional_copy_f64(&mask, &then[..], &els[..]);
        assert_eq!(&arr.data[..], &expected[..]);
        assert!(arr.null_mask.is_none());
    }

    #[test]
    fn test_conditional_copy_numeric_with_null() {
        let mut mask = bool_arr(&[true, false, true]);
        mask.null_mask = Some(bm(&[true, false, true]));
        let then = vec![10i64, 20, 30];
        let els = vec![1i64, 2, 3];
        let arr = conditional_copy_i64(&mask, &then, &els);
        assert_eq!(&arr.data[..], &[10, 2, 30]);
        let null_mask = arr.null_mask.as_ref().unwrap();
        assert_eq!(null_mask.get(0), true);
        assert_eq!(null_mask.get(1), false);
        assert_eq!(null_mask.get(2), true);
    }

    #[cfg(feature = "extended_numeric_types")]
    #[test]
    fn test_conditional_copy_numeric_edge_cases() {
        // Empty
        let mask = bool_arr(&[]);
        let then: Vec<u32> = vec![];
        let els: Vec<u32> = vec![];
        let arr = conditional_copy_u32(&mask, &then, &els);
        assert_eq!(arr.data.len(), 0);
        assert!(arr.null_mask.is_none());
        // 1-element, mask true/false/null
        let mask = bool_arr(&[true]);
        let then = vec![42];
        let els = vec![1];
        let arr = conditional_copy_u8(&mask, &then, &els);
        assert_eq!(&arr.data[..], &[42]);
        let mask = bool_arr(&[false]);
        let arr = conditional_copy_u8(&mask, &then, &els);
        assert_eq!(&arr.data[..], &[1]);
        let mut mask = bool_arr(&[true]);
        mask.null_mask = Some(bm(&[false]));
        let arr = conditional_copy_u8(&mask, &then, &els);
        assert_eq!(&arr.data[..], &[42]);
        assert_eq!(arr.null_mask.as_ref().unwrap().get(0), false);
    }

    #[test]
    fn test_conditional_copy_bool_no_null() {
        let mask = bool_arr(&[true, false, true, false]);
        let then = bm(&[true, true, false, false]);
        let els = bm(&[false, true, true, false]);
        let out = conditional_copy_bool(&mask, &then, &els).unwrap();
        assert_eq!(out.data.get(0), true);
        assert_eq!(out.data.get(1), true);
        assert_eq!(out.data.get(2), false);
        assert_eq!(out.data.get(3), false);
        assert!(out.null_mask.is_none());
        assert_eq!(out.len, 4);
    }

    #[test]
    fn test_conditional_copy_bool_with_null() {
        let mut mask = bool_arr(&[true, false, true, false, true]);
        mask.null_mask = Some(bm(&[true, false, true, true, false]));
        let then = bm(&[false, true, false, true, true]);
        let els = bm(&[true, false, true, false, false]);
        let out = conditional_copy_bool(&mask, &then, &els).unwrap();
        // Only positions 0,2,3 should be valid (others null)
        let null_mask = out.null_mask.as_ref().unwrap();
        assert_eq!(null_mask.get(0), true);
        assert_eq!(null_mask.get(1), false);
        assert_eq!(null_mask.get(2), true);
        assert_eq!(null_mask.get(3), true);
        assert_eq!(null_mask.get(4), false);
        assert_eq!(out.data.get(0), false);
        assert_eq!(out.data.get(2), false);
        assert_eq!(out.data.get(3), false);
    }

    #[test]
    fn test_conditional_copy_bool_edge_cases() {
        // Empty
        let mask = bool_arr(&[]);
        let then = bm(&[]);
        let els = bm(&[]);
        let out = conditional_copy_bool(&mask, &then, &els).unwrap();
        assert_eq!(out.len, 0);
        assert!(out.data.is_empty());
        // All nulls in mask
        let mut mask = bool_arr(&[true, false, true]);
        mask.null_mask = Some(bm(&[false, false, false]));
        let then = bm(&[true, true, true]);
        let els = bm(&[false, false, false]);
        let out = conditional_copy_bool(&mask, &then, &els).unwrap();
        assert_eq!(out.len, 3);
        let null_mask = out.null_mask.as_ref().unwrap();
        assert!(!null_mask.get(0));
        assert!(!null_mask.get(1));
        assert!(!null_mask.get(2));
    }

    #[test]
    fn test_conditional_copy_float_type_dispatch() {
        // f32
        let mask = bool_arr(&[true, false]);
        let then: Vec<f32> = vec![1.0, 2.0];
        let els: Vec<f32> = vec![3.0, 4.0];
        let arr = conditional_copy_float(&mask, &then, &els);
        assert_eq!(&arr.data[..], &[1.0, 4.0]);
        // f64
        let mask = bool_arr(&[false, true]);
        let then: Vec<f64> = vec![7.0, 8.0];
        let els: Vec<f64> = vec![9.0, 10.0];
        let arr = conditional_copy_float(&mask, &then, &els);
        assert_eq!(&arr.data[..], &[9.0, 8.0]);
    }

    #[test]
    fn test_conditional_copy_str_basic() {
        // mask of length 4
        let mask = bool_arr(&[true, false, false, true]);

        // then_arr and else_arr must also be length 4
        let a = StringArray::<u32>::from_slice(&["foo", "bar", "baz", "qux"]);
        let b = StringArray::<u32>::from_slice(&["AAA", "Y", "Z", "BBB"]);

        // Wrap as slices
        let mask_slice = (&mask, 0, mask.len());
        let a_slice = (&a, 0, a.len());
        let b_slice = (&b, 0, b.len());

        let arr = conditional_copy_str(mask_slice, a_slice, b_slice).unwrap();

        // mask picks a[0], b[1], b[2], a[3]
        assert_eq!(arr.get(0), Some("foo"));
        assert_eq!(arr.get(1), Some("Y"));
        assert_eq!(arr.get(2), Some("Z"));
        assert_eq!(arr.get(3), Some("qux"));

        assert_eq!(arr.len(), 4);
        let nm = arr.null_mask.as_ref().unwrap();
        assert!(nm.all_set());
    }

    #[test]
    fn test_conditional_copy_str_with_null() {
        let mut mask = bool_arr(&[true, false, true]);
        mask.null_mask = Some(bm(&[true, false, true]));
        let mut a = StringArray::<u32>::from_slice(&["one", "two", "three"]);
        let mut b = StringArray::<u32>::from_slice(&["uno", "dos", "tres"]);
        a.set_null(2);
        b.set_null(0);
        let mask_slice = (&mask, 0, mask.len());
        let a_slice = (&a, 0, a.len());
        let b_slice = (&b, 0, b.len());

        let arr = conditional_copy_str(mask_slice, a_slice, b_slice).unwrap();

        assert_eq!(arr.get(0), Some("one"));
        assert!(arr.get(1).is_none());
        assert_eq!(arr.get(2), Some(""));
        let nm = arr.null_mask.as_ref().unwrap();
        assert!(nm.get(0));
        assert!(!nm.get(1));
        assert!(nm.get(2));
    }

    #[test]
    fn test_conditional_copy_str_with_null_chunk() {
        // pad to allow offset
        let mut mask = bool_arr(&[false, true, false, true, false]);
        mask.null_mask = Some(bm(&[false, true, false, true, false]));
        let mut a = StringArray::<u32>::from_slice(&["", "one", "two", "three", ""]);
        let mut b = StringArray::<u32>::from_slice(&["", "uno", "dos", "tres", ""]);
        a.set_null(3);
        b.set_null(1);
        let mask_slice = (&mask, 1, 3);
        let a_slice = (&a, 1, 3);
        let b_slice = (&b, 1, 3);

        let arr = conditional_copy_str(mask_slice, a_slice, b_slice).unwrap();

        assert_eq!(arr.get(0), Some("one"));
        assert!(arr.get(1).is_none());
        assert_eq!(arr.get(2), Some(""));
        let nm = arr.null_mask.as_ref().unwrap();
        assert!(nm.get(0));
        assert!(!nm.get(1));
        assert!(nm.get(2));
    }

    #[test]
    fn test_conditional_copy_str_edge_cases() {
        let mut mask = bool_arr(&[true, false]);
        mask.null_mask = Some(bm(&[false, false]));
        let a = StringArray::<u32>::from_slice(&["foo", "bar"]);
        let b = StringArray::<u32>::from_slice(&["baz", "qux"]);
        let mask_slice = (&mask, 0, mask.len());
        let a_slice = (&a, 0, a.len());
        let b_slice = (&b, 0, b.len());
        let arr = conditional_copy_str(mask_slice, a_slice, b_slice).unwrap();
        assert_eq!(arr.len(), 2);
        assert!(!arr.null_mask.as_ref().unwrap().get(0));
        assert!(!arr.null_mask.as_ref().unwrap().get(1));
        // Empty arrays
        let mask = bool_arr(&[]);
        let a = StringArray::<u32>::from_slice(&[]);
        let b = StringArray::<u32>::from_slice(&[]);
        let mask_slice = (&mask, 0, 0);
        let a_slice = (&a, 0, 0);
        let b_slice = (&b, 0, 0);
        let arr = conditional_copy_str(mask_slice, a_slice, b_slice).unwrap();
        assert_eq!(arr.len(), 0);
    }

    #[test]
    fn test_conditional_copy_str_edge_cases_chunk() {
        // chunked window: use 0-length window
        let mask = bool_arr(&[false, true, false]);
        let a = StringArray::<u32>::from_slice(&["foo", "bar", "baz"]);
        let b = StringArray::<u32>::from_slice(&["qux", "quux", "quuz"]);
        let mask_slice = (&mask, 1, 0);
        let a_slice = (&a, 1, 0);
        let b_slice = (&b, 1, 0);
        let arr = conditional_copy_str(mask_slice, a_slice, b_slice).unwrap();
        assert_eq!(arr.len(), 0);
    }

    #[test]
    fn test_conditional_copy_dict32_basic() {
        let mask = bool_arr(&[true, false, false, true]);
        let a = CategoricalArray::<u32>::from_slices(
            &[0, 1, 1, 0],
            &["dog".to_string(), "cat".to_string()],
        );
        let b = CategoricalArray::<u32>::from_slices(
            &[0, 0, 1, 1],
            &["fish".to_string(), "cat".to_string()],
        );
        let mask_slice = (&mask, 0, mask.len());
        let a_slice = (&a, 0, a.data.len());
        let b_slice = (&b, 0, b.data.len());

        let arr = conditional_copy_dict32(mask_slice, a_slice, b_slice).unwrap();

        let vals: Vec<Option<&str>> = (0..4).map(|i| arr.get(i)).collect();
        assert_eq!(vals[0], Some("dog"));
        assert_eq!(vals[1], Some("fish"));
        assert_eq!(vals[2], Some("cat"));
        assert_eq!(vals[3], Some("dog"));

        let mut all = arr
            .unique_values
            .iter()
            .map(|s| s.as_str())
            .collect::<Vec<_>>();
        all.sort();
        let mut ref_all = vec!["cat", "dog", "fish"];
        ref_all.sort();
        assert_eq!(all, ref_all);

        assert!(arr.null_mask.as_ref().unwrap().all_set());
    }

    #[test]
    fn test_conditional_copy_dict32_overlapping_dictionaries() {
        // Both dictionaries contain "x" and "y" at different codes: the merged dictionary
        // must hold each value once, and rows from either side must decode correctly.
        let mask = bool_arr(&[true, false, false, true]);
        let a = CategoricalArray::<u32>::from_slices(
            &[0, 1, 0, 1],
            &["x".to_string(), "y".to_string()],
        );
        let b = CategoricalArray::<u32>::from_slices(
            &[0, 1, 2, 0],
            &["y".to_string(), "z".to_string(), "x".to_string()],
        );
        let arr = conditional_copy_dict32((&mask, 0, 4), (&a, 0, 4), (&b, 0, 4)).unwrap();

        // Picks a[0]="x", b[1]="z", b[2]="x", a[3]="y".
        assert_eq!(arr.get(0), Some("x"));
        assert_eq!(arr.get(1), Some("z"));
        assert_eq!(arr.get(2), Some("x"));
        assert_eq!(arr.get(3), Some("y"));
        assert_eq!(arr.unique_values.len(), 3);
        assert!(arr.null_mask.as_ref().unwrap().all_set());
    }

    #[test]
    fn test_conditional_copy_dict32_basic_chunk() {
        let mask = bool_arr(&[false, true, false, true, false]);
        let a = CategoricalArray::<u32>::from_slices(
            &[0, 1, 1, 0, 0],
            &["dog".to_string(), "cat".to_string()],
        );
        let b = CategoricalArray::<u32>::from_slices(
            &[0, 0, 1, 1, 0],
            &["fish".to_string(), "cat".to_string()],
        );
        let mask_slice = (&mask, 1, 3);
        let a_slice = (&a, 1, 3);
        let b_slice = (&b, 1, 3);

        let arr = conditional_copy_dict32(mask_slice, a_slice, b_slice).unwrap();

        let vals: Vec<Option<&str>> = (0..3).map(|i| arr.get(i)).collect();
        assert_eq!(vals[0], Some("cat"));
        assert_eq!(vals[1], Some("cat"));
        assert_eq!(vals[2], Some("dog"));

        let mut all = arr
            .unique_values
            .iter()
            .map(|s| s.as_str())
            .collect::<Vec<_>>();
        all.sort();
        let mut ref_all = vec!["cat", "dog", "fish"];
        ref_all.sort();
        assert_eq!(all, ref_all);

        assert!(arr.null_mask.as_ref().unwrap().all_set());
    }

    #[test]
    fn test_conditional_copy_dict32_with_null() {
        let mut mask = bool_arr(&[true, false, true]);
        mask.null_mask = Some(bm(&[true, false, true]));
        let mut a = CategoricalArray::<u32>::from_slices(
            &[0, 1, 0],
            &["dog".to_string(), "cat".to_string()],
        );
        let mut b = CategoricalArray::<u32>::from_slices(
            &[1, 0, 1],
            &["cat".to_string(), "fish".to_string()],
        );
        a.set_null(2);
        b.set_null(0);
        let mask_slice = (&mask, 0, mask.len());
        let a_slice = (&a, 0, a.data.len());
        let b_slice = (&b, 0, b.data.len());
        let arr = conditional_copy_dict32(mask_slice, a_slice, b_slice).unwrap();
        assert_eq!(arr.get(0), Some("dog"));
        assert!(arr.get(1).is_none());
        assert!(arr.get(2).is_none());
        assert!(arr.null_mask.as_ref().unwrap().get(0));
        assert!(!arr.null_mask.as_ref().unwrap().get(1));
        assert!(!arr.null_mask.as_ref().unwrap().get(2));
    }

    #[test]
    fn test_conditional_copy_dict32_with_null_chunk() {
        let mut mask = bool_arr(&[false, true, false, true]);
        mask.null_mask = Some(bm(&[false, true, false, true]));
        let mut a = CategoricalArray::<u32>::from_slices(
            &[0, 1, 0, 1],
            &["dog".to_string(), "cat".to_string()],
        );
        let mut b = CategoricalArray::<u32>::from_slices(
            &[1, 0, 1, 0],
            &["cat".to_string(), "fish".to_string()],
        );
        a.set_null(3);
        b.set_null(1);
        let mask_slice = (&mask, 1, 2);
        let a_slice = (&a, 1, 2);
        let b_slice = (&b, 1, 2);
        let arr = conditional_copy_dict32(mask_slice, a_slice, b_slice).unwrap();
        assert_eq!(arr.get(0), Some("cat"));
        assert!(arr.get(1).is_none());
        assert!(arr.null_mask.as_ref().unwrap().get(0));
        assert!(!arr.null_mask.as_ref().unwrap().get(1));
    }

    #[test]
    fn test_conditional_copy_dict32_edge_cases() {
        let mask = bool_arr(&[]);
        let a = CategoricalArray::<u32>::from_slices(&[], &[]);
        let b = CategoricalArray::<u32>::from_slices(&[], &[]);
        let mask_slice = (&mask, 0, 0);
        let a_slice = (&a, 0, 0);
        let b_slice = (&b, 0, 0);
        let arr = conditional_copy_dict32(mask_slice, a_slice, b_slice).unwrap();
        assert_eq!(arr.data.len(), 0);
        assert_eq!(arr.unique_values.len(), 0);

        let mut mask = bool_arr(&[true, false]);
        mask.null_mask = Some(bm(&[false, false]));
        let a = CategoricalArray::<u32>::from_slices(&[0, 0], &["foo".to_string()]);
        let b = CategoricalArray::<u32>::from_slices(&[0, 0], &["bar".to_string()]);
        let mask_slice = (&mask, 0, mask.len());
        let a_slice = (&a, 0, a.data.len());
        let b_slice = (&b, 0, b.data.len());
        let arr = conditional_copy_dict32(mask_slice, a_slice, b_slice).unwrap();
        assert!(!arr.null_mask.as_ref().unwrap().get(0));
        assert!(!arr.null_mask.as_ref().unwrap().get(1));
    }

    #[test]
    fn test_conditional_copy_dict32_edge_cases_chunk() {
        let mask = bool_arr(&[false, false, false]);
        let a = CategoricalArray::<u32>::from_slices(&[0, 0, 0], &["foo".to_string()]);
        let b = CategoricalArray::<u32>::from_slices(&[0, 0, 0], &["bar".to_string()]);
        let mask_slice = (&mask, 1, 0);
        let a_slice = (&a, 1, 0);
        let b_slice = (&b, 1, 0);
        let arr = conditional_copy_dict32(mask_slice, a_slice, b_slice).unwrap();
        assert_eq!(arr.data.len(), 0);
        assert_eq!(arr.unique_values.len(), 0);
    }
}