1include!(concat!(env!("OUT_DIR"), "/simd_lanes.rs"));
20
21use std::collections::HashSet;
22use std::hash::Hash;
23use std::marker::PhantomData;
24#[cfg(feature = "simd")]
25use std::simd::{Mask, Simd, cmp::SimdPartialEq, cmp::SimdPartialOrd, num::SimdFloat};
26
27use minarrow::kernels::arithmetic::string::MAX_DICT_CHECK;
28use minarrow::traits::type_unions::Float;
29use minarrow::{
30 Array, Bitmask, BooleanAVT, BooleanArray, CategoricalAVT, Integer, MaskedArray, Numeric,
31 NumericArray, StringAVT, TextArray, Vec64,
32};
33
34#[cfg(not(feature = "simd"))]
35use crate::kernels::bitmask::dispatch::{and_masks, or_masks, xor_masks};
36use crate::operators::LogicalOperator;
37use minarrow::enums::error::KernelError;
38#[cfg(feature = "simd")]
39use minarrow::kernels::bitmask::simd::{and_masks_simd, or_masks_simd, xor_masks_simd};
40use minarrow::utils::confirm_mask_capacity;
41
42#[cfg(feature = "simd")]
43use minarrow::utils::is_simd_aligned;
44use std::any::TypeId;
45
46#[inline(always)]
49fn new_bool_buffer(len: usize) -> Bitmask {
50 Bitmask::new_set_all(len, false)
51}
52
/// Generates an inclusive `BETWEEN` kernel pair for one numeric type:
/// `$name_to` (writes into a caller-supplied bitmask) and `$name` (allocates the output).
///
/// Parameters: `$ty` element type, `$mask_elem` SIMD mask element type matching `$ty`'s
/// width, `$lanes` SIMD lane count.
macro_rules! impl_between_numeric {
    ($name:ident, $name_to:ident, $ty:ty, $mask_elem:ty, $lanes:expr) => {
        /// Inclusive range test `min <= lhs[i] <= max`, written into `output`.
        ///
        /// `rhs` is either a single `[min, max]` pair applied to every row, or
        /// `2 * lhs.len()` interleaved per-row bounds `[min0, max0, min1, max1, ...]`.
        /// Only bits that test true are set; `output` is never cleared, so callers pass an
        /// all-false buffer. When `has_nulls` is true, rows whose `mask` bit is unset stay
        /// false; a `None` mask is treated as all-valid.
        ///
        /// # Errors
        /// `KernelError::InvalidArguments` when `rhs` has the wrong length, or when `mask` or
        /// `output` has a capacity smaller than `lhs.len()`.
        #[inline(always)]
        pub fn $name_to(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool,
            output: &mut Bitmask,
        ) -> Result<(), KernelError> {
            let len = lhs.len();
            if rhs.len() != 2 && rhs.len() != 2 * len {
                return Err(KernelError::InvalidArguments(format!(
                    "between: RHS must have len 2 or 2×LHS (got lhs: {}, rhs: {})",
                    len,
                    rhs.len()
                )));
            }

            if let Some(m) = mask {
                if m.capacity() < len {
                    return Err(KernelError::InvalidArguments(format!(
                        "between: mask (Bitmask) capacity must be ≥ len (got capacity: {}, len: {})",
                        m.capacity(),
                        len
                    )));
                }
            }
            // Report an undersized output buffer as an error, like the argument checks
            // above, rather than panicking.
            if output.capacity() < len {
                return Err(KernelError::InvalidArguments(format!(
                    "{}: output capacity too small (got capacity: {}, len: {})",
                    stringify!($name_to),
                    output.capacity(),
                    len
                )));
            }

            #[cfg(feature = "simd")]
            {
                // Vectorised fast path: aligned inputs, no nulls, one shared [min, max] pair.
                if is_simd_aligned(lhs) && is_simd_aligned(rhs) {
                    const N: usize = $lanes;
                    type V = Simd<$ty, N>;
                    type M = Mask<$mask_elem, N>;

                    if !has_nulls && rhs.len() == 2 {
                        let min_v = V::splat(rhs[0]);
                        let max_v = V::splat(rhs[1]);

                        let mut i = 0usize;
                        while i + N <= len {
                            let x = V::from_slice(&lhs[i..i + N]);
                            let m: M = x.simd_ge(min_v) & x.simd_le(max_v);
                            let bm = m.to_bitmask();

                            for l in 0..N {
                                if ((bm >> l) & 1) == 1 {
                                    output.set(i + l, true);
                                }
                            }
                            i += N;
                        }
                        // Scalar tail for the last `len % N` elements.
                        for j in i..len {
                            if lhs[j] >= rhs[0] && lhs[j] <= rhs[1] {
                                output.set(j, true);
                            }
                        }

                        return Ok(());
                    }
                }
            }

            if rhs.len() == 2 {
                // One [min, max] pair shared by every row.
                let (min, max) = (rhs[0], rhs[1]);
                for i in 0..len {
                    if (!has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(i) }))
                        && lhs[i] >= min
                        && lhs[i] <= max
                    {
                        output.set(i, true);
                    }
                }
            } else {
                // Interleaved per-row bounds: rhs[2i] = min_i, rhs[2i + 1] = max_i.
                for i in 0..len {
                    let min = rhs[i * 2];
                    let max = rhs[i * 2 + 1];
                    if (!has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(i) }))
                        && lhs[i] >= min
                        && lhs[i] <= max
                    {
                        output.set(i, true);
                    }
                }
            }

            Ok(())
        }

        /// Allocating wrapper around the matching `*_to` kernel.
        ///
        /// The result carries a clone of `mask` as its null mask.
        ///
        /// # Errors
        /// Same as the `*_to` kernel.
        #[inline(always)]
        pub fn $name(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool,
        ) -> Result<BooleanArray<()>, KernelError> {
            let len = lhs.len();
            let mut out_data = new_bool_buffer(len);
            $name_to(lhs, rhs, mask, has_nulls, &mut out_data)?;
            Ok(BooleanArray {
                data: out_data.into(),
                null_mask: mask.cloned(),
                len,
                _phantom: PhantomData,
            })
        }
    };
}
172
173#[inline(always)]
176fn between_generic<T: Numeric + Copy + std::cmp::PartialOrd>(
177 lhs: &[T],
178 rhs: &[T],
179 mask: Option<&Bitmask>,
180 has_nulls: bool,
181) -> Result<BooleanArray<()>, KernelError> {
182 let len = lhs.len();
183 let mut out = new_bool_buffer(len);
184 let _ = confirm_mask_capacity(len, mask)?;
185 if rhs.len() == 2 {
186 let (min, max) = (rhs[0], rhs[1]);
187 for i in 0..len {
188 if (!has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(i) }))
189 && lhs[i] >= min
190 && lhs[i] <= max
191 {
192 out.set(i, true);
193 }
194 }
195 } else {
196 for i in 0..len {
197 let min = rhs[i * 2];
198 let max = rhs[i * 2 + 1];
199 if (!has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(i) }))
200 && lhs[i] >= min
201 && lhs[i] <= max
202 {
203 out.set(i, true);
204 }
205 }
206 }
207
208 Ok(BooleanArray {
209 data: out.into(),
210 null_mask: mask.cloned(),
211 len,
212 _phantom: PhantomData,
213 })
214}
215
/// Generates an integer set-membership (`IN`) kernel pair for one type:
/// `$name_to` (writes into a caller-supplied bitmask) and `$name` (allocates the output).
///
/// Parameters: `$ty` element type, `$lanes` SIMD lane count, `$mask_elem` SIMD mask
/// element type matching `$ty`'s width.
macro_rules! impl_in_int {
    ($name:ident, $name_to:ident, $ty:ty, $lanes:expr, $mask_elem:ty) => {
        /// Set-membership test `lhs[i] ∈ rhs`, written into `output`.
        ///
        /// Only bits that test true are set; `output` is never cleared, so callers pass an
        /// all-false buffer. When `has_nulls` is true, rows whose `mask` bit is unset stay
        /// false. A `None` mask is treated as all-valid on every code path.
        ///
        /// With the `simd` feature, aligned inputs with at most 16 RHS values are compared
        /// lane-wise against each splatted RHS value. Otherwise a `HashSet` of the RHS is
        /// probed per row.
        ///
        /// # Errors
        /// Any error from `confirm_mask_capacity` for an undersized `mask`, or
        /// `KernelError::InvalidArguments` when `output` capacity is smaller than `lhs.len()`.
        #[inline(always)]
        pub fn $name_to(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool,
            output: &mut Bitmask,
        ) -> Result<(), KernelError> {
            let len = lhs.len();
            let _ = confirm_mask_capacity(len, mask)?;
            if output.capacity() < len {
                return Err(KernelError::InvalidArguments(format!(
                    "{}: output capacity too small (got capacity: {}, len: {})",
                    stringify!($name_to),
                    output.capacity(),
                    len
                )));
            }

            #[cfg(feature = "simd")]
            {
                if is_simd_aligned(lhs) && is_simd_aligned(rhs) {
                    use crate::utils::bitmask_to_simd_mask;
                    use core::simd::{Mask, Simd};

                    if rhs.len() <= 16 {
                        let mut i = 0;
                        // `has_nulls` with no mask means "all rows valid", exactly as in the
                        // scalar fallback below. This path used to `expect` a mask and panic.
                        match mask.filter(|_| has_nulls) {
                            None => {
                                while i + $lanes <= len {
                                    let x = Simd::<$ty, $lanes>::from_slice(&lhs[i..i + $lanes]);
                                    let mut m = Mask::<$mask_elem, $lanes>::splat(false);
                                    for &v in rhs {
                                        m |= x.simd_eq(Simd::<$ty, $lanes>::splat(v));
                                    }
                                    let bm = m.to_bitmask();
                                    for l in 0..$lanes {
                                        if ((bm >> l) & 1) == 1 {
                                            output.set(i + l, true);
                                        }
                                    }
                                    i += $lanes;
                                }
                                // Scalar tail for the last `len % $lanes` elements.
                                for j in i..len {
                                    if rhs.contains(&lhs[j]) {
                                        output.set(j, true);
                                    }
                                }
                            }
                            Some(mb) => {
                                let mask_bytes = mb.as_bytes();
                                while i + $lanes <= len {
                                    let x = Simd::<$ty, $lanes>::from_slice(&lhs[i..i + $lanes]);
                                    let lane_mask =
                                        bitmask_to_simd_mask::<$lanes, $mask_elem>(mask_bytes, i, len);
                                    let mut in_mask = Mask::<$mask_elem, $lanes>::splat(false);
                                    for &v in rhs {
                                        in_mask |= x.simd_eq(Simd::<$ty, $lanes>::splat(v));
                                    }
                                    // Only valid (non-null) lanes may report a match.
                                    let valid_in = lane_mask & in_mask;
                                    let bm = valid_in.to_bitmask();
                                    for l in 0..$lanes {
                                        if ((bm >> l) & 1) == 1 {
                                            output.set(i + l, true);
                                        }
                                    }
                                    i += $lanes;
                                }
                                for j in i..len {
                                    if unsafe { mb.get_unchecked(j) } && rhs.contains(&lhs[j]) {
                                        output.set(j, true);
                                    }
                                }
                            }
                        }
                        return Ok(());
                    }
                }
            }

            // Scalar fallback: hash the RHS once, then probe per row.
            let set: std::collections::HashSet<$ty> = rhs.iter().copied().collect();
            for i in 0..len {
                if (!has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(i) }))
                    && set.contains(&lhs[i])
                {
                    output.set(i, true);
                }
            }
            Ok(())
        }

        /// Allocating wrapper around the matching `*_to` kernel.
        ///
        /// The result carries a clone of `mask` as its null mask.
        ///
        /// # Errors
        /// Same as the `*_to` kernel.
        #[inline(always)]
        pub fn $name(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool,
        ) -> Result<BooleanArray<()>, KernelError> {
            let len = lhs.len();
            let mut out = new_bool_buffer(len);
            $name_to(lhs, rhs, mask, has_nulls, &mut out)?;
            Ok(BooleanArray {
                data: out.into(),
                null_mask: mask.cloned(),
                len,
                _phantom: PhantomData,
            })
        }
    };
}
337
/// Generates a float set-membership (`IN`) kernel pair for one type:
/// `$fn_name_to` (writes into a caller-supplied bitmask) and `$fn_name` (allocates the
/// output). NaN is treated as equal to NaN, so a NaN in `lhs` matches a NaN in `rhs`.
macro_rules! impl_in_float {
    (
        $fn_name:ident, $fn_name_to:ident, $ty:ty, $lanes:expr, $mask_elem:ty
    ) => {
        /// Set-membership test `lhs[i] ∈ rhs` (NaN matches NaN), written into `output`.
        ///
        /// Only bits that test true are set; `output` is never cleared, so callers pass an
        /// all-false buffer. When `has_nulls` is true, rows whose `mask` bit is unset stay
        /// false. A `None` mask is treated as all-valid on every code path.
        ///
        /// # Errors
        /// Any error from `confirm_mask_capacity` for an undersized `mask`, or
        /// `KernelError::InvalidArguments` when `output` capacity is smaller than `lhs.len()`.
        #[inline(always)]
        pub fn $fn_name_to(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool,
            output: &mut Bitmask,
        ) -> Result<(), KernelError> {
            let len = lhs.len();
            let _ = confirm_mask_capacity(len, mask)?;
            if output.capacity() < len {
                return Err(KernelError::InvalidArguments(format!(
                    "{}: output capacity too small (got capacity: {}, len: {})",
                    stringify!($fn_name_to),
                    output.capacity(),
                    len
                )));
            }

            #[cfg(feature = "simd")]
            {
                if is_simd_aligned(lhs) && is_simd_aligned(rhs) {
                    use crate::utils::bitmask_to_simd_mask;
                    use core::simd::{Mask, Simd};
                    if rhs.len() <= 16 {
                        let mut i = 0;
                        // `has_nulls` with no mask means "all rows valid", exactly as in the
                        // scalar fallback below. This path used to `expect` a mask and panic.
                        match mask.filter(|_| has_nulls) {
                            None => {
                                while i + $lanes <= len {
                                    let x = Simd::<$ty, $lanes>::from_slice(&lhs[i..i + $lanes]);
                                    let mut m = Mask::<$mask_elem, $lanes>::splat(false);
                                    for &v in rhs {
                                        let vmask = x.simd_eq(Simd::<$ty, $lanes>::splat(v))
                                            | (x.is_nan() & Simd::<$ty, $lanes>::splat(v).is_nan());
                                        m |= vmask;
                                    }
                                    let bm = m.to_bitmask();
                                    for l in 0..$lanes {
                                        if ((bm >> l) & 1) == 1 {
                                            output.set(i + l, true);
                                        }
                                    }
                                    i += $lanes;
                                }
                                // Scalar tail for the last `len % $lanes` elements.
                                for j in i..len {
                                    let x = lhs[j];
                                    if rhs.iter().any(|&v| x == v || (x.is_nan() && v.is_nan())) {
                                        output.set(j, true);
                                    }
                                }
                            }
                            Some(mb) => {
                                let mask_bytes = mb.as_bytes();
                                while i + $lanes <= len {
                                    let x = Simd::<$ty, $lanes>::from_slice(&lhs[i..i + $lanes]);
                                    let lane_mask =
                                        bitmask_to_simd_mask::<$lanes, $mask_elem>(mask_bytes, i, len);
                                    let mut m = Mask::<$mask_elem, $lanes>::splat(false);
                                    for &v in rhs {
                                        let vmask = x.simd_eq(Simd::<$ty, $lanes>::splat(v))
                                            | (x.is_nan() & Simd::<$ty, $lanes>::splat(v).is_nan());
                                        m |= vmask;
                                    }
                                    // Only valid (non-null) lanes may report a match.
                                    let m = m & lane_mask;
                                    let bm = m.to_bitmask();
                                    for l in 0..$lanes {
                                        if ((bm >> l) & 1) == 1 {
                                            output.set(i + l, true);
                                        }
                                    }
                                    i += $lanes;
                                }
                                for j in i..len {
                                    if unsafe { mb.get_unchecked(j) } {
                                        let x = lhs[j];
                                        if rhs.iter().any(|&v| x == v || (x.is_nan() && v.is_nan())) {
                                            output.set(j, true);
                                        }
                                    }
                                }
                            }
                        }
                        return Ok(());
                    }
                }
            }

            // Scalar fallback: linear scan of the RHS per row (NaN-aware equality).
            for i in 0..len {
                if has_nulls && !mask.map_or(true, |m| unsafe { m.get_unchecked(i) }) {
                    continue;
                }
                let x = lhs[i];
                if rhs.iter().any(|&v| x == v || (x.is_nan() && v.is_nan())) {
                    output.set(i, true);
                }
            }
            Ok(())
        }

        /// Allocating wrapper around the matching `*_to` kernel.
        ///
        /// The result carries a clone of `mask` as its null mask.
        ///
        /// # Errors
        /// Same as the `*_to` kernel.
        #[inline(always)]
        pub fn $fn_name(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool,
        ) -> Result<BooleanArray<()>, KernelError> {
            let len = lhs.len();
            let mut out = new_bool_buffer(len);
            $fn_name_to(lhs, rhs, mask, has_nulls, &mut out)?;
            Ok(BooleanArray {
                data: out.into(),
                null_mask: mask.cloned(),
                len,
                _phantom: PhantomData,
            })
        }
    };
}
464
465#[cfg(feature = "extended_numeric_types")]
467impl_in_int!(in_i8, in_i8_to, i8, W8, i8);
468#[cfg(feature = "extended_numeric_types")]
469impl_in_int!(in_u8, in_u8_to, u8, W8, i8);
470#[cfg(feature = "extended_numeric_types")]
471impl_in_int!(in_i16, in_i16_to, i16, W16, i16);
472#[cfg(feature = "extended_numeric_types")]
473impl_in_int!(in_u16, in_u16_to, u16, W16, i16);
474impl_in_int!(in_i32, in_i32_to, i32, W32, i32);
475impl_in_int!(in_u32, in_u32_to, u32, W32, i32);
476impl_in_int!(in_i64, in_i64_to, i64, W64, i64);
477impl_in_int!(in_u64, in_u64_to, u64, W64, i64);
478impl_in_float!(in_f32, in_f32_to, f32, W32, i32);
479impl_in_float!(in_f64, in_f64_to, f64, W64, i64);
480
481#[cfg(feature = "extended_numeric_types")]
482impl_between_numeric!(between_i8, between_i8_to, i8, i8, W8);
483#[cfg(feature = "extended_numeric_types")]
484impl_between_numeric!(between_u8, between_u8_to, u8, i8, W8);
485#[cfg(feature = "extended_numeric_types")]
486impl_between_numeric!(between_i16, between_i16_to, i16, i16, W16);
487#[cfg(feature = "extended_numeric_types")]
488impl_between_numeric!(between_u16, between_u16_to, u16, i16, W16);
489
490impl_between_numeric!(between_i32, between_i32_to, i32, i32, W32);
491impl_between_numeric!(between_u32, between_u32_to, u32, i32, W32);
492impl_between_numeric!(between_i64, between_i64_to, i64, i64, W64);
493impl_between_numeric!(between_u64, between_u64_to, u64, i64, W64);
494impl_between_numeric!(between_f32, between_f32_to, f32, i32, W32);
495impl_between_numeric!(between_f64, between_f64_to, f64, i64, W64);
496
497#[inline(always)]
501pub fn cmp_str_between<'a, T: Integer>(
502 lhs: StringAVT<'a, T>,
503 rhs: StringAVT<'a, T>,
504) -> Result<BooleanArray<()>, KernelError> {
505 let (larr, loff, llen) = lhs;
506 let (rarr, roff, rlen) = rhs;
507
508 if rlen < 2 {
509 return Err(KernelError::InvalidArguments(format!(
510 "str_between: RHS must contain at least two values (got {})",
511 rlen
512 )));
513 }
514 let min = rarr.get(roff).unwrap_or("");
515 let max = rarr.get(roff + 1).unwrap_or("");
516 let mask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
517 let _ = confirm_mask_capacity(llen, mask.as_ref())?;
518
519 let mut out = new_bool_buffer(llen);
520
521 for i in 0..llen {
522 if mask
523 .as_ref()
524 .map_or(true, |m| unsafe { m.get_unchecked(i) })
525 {
526 let s = unsafe { larr.get_str_unchecked(loff + i) };
527 if s >= min && s <= max {
528 unsafe { out.set_unchecked(i, true) };
529 }
530 }
531 }
532
533 Ok(BooleanArray {
534 data: out.into(),
535 null_mask: mask,
536 len: llen,
537 _phantom: PhantomData,
538 })
539}
540
541#[inline(always)]
542pub fn cmp_str_in<'a, T: Integer>(
544 lhs: StringAVT<'a, T>,
545 rhs: StringAVT<'a, T>,
546) -> Result<BooleanArray<()>, KernelError> {
547 let (larr, loff, llen) = lhs;
548 let (rarr, roff, rlen) = rhs;
549
550 let set: HashSet<&str> = (0..rlen)
551 .map(|i| unsafe { rarr.get_str_unchecked(roff + i) })
552 .collect();
553
554 let mask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
555 let _ = confirm_mask_capacity(llen, mask.as_ref())?;
556
557 let mut out = new_bool_buffer(llen);
558
559 for i in 0..llen {
560 if mask
561 .as_ref()
562 .map_or(true, |m| unsafe { m.get_unchecked(i) })
563 {
564 let s = unsafe { larr.get_str_unchecked(loff + i) };
565 if set.contains(s) {
566 unsafe { out.set_unchecked(i, true) };
567 }
568 }
569 }
570 Ok(BooleanArray {
571 data: out.into(),
572 null_mask: mask,
573 len: llen,
574 _phantom: PhantomData,
575 })
576}
577
578pub fn cmp_between<T: PartialOrd + Copy + Numeric>(
582 lhs: &[T],
583 rhs: &[T],
584) -> Result<BooleanArray<()>, KernelError> {
585 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i32>() {
586 return between_i32(
587 unsafe { std::mem::transmute(lhs) },
588 unsafe { std::mem::transmute(rhs) },
589 None,
590 false,
591 );
592 }
593 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u32>() {
594 return between_u32(
595 unsafe { std::mem::transmute(lhs) },
596 unsafe { std::mem::transmute(rhs) },
597 None,
598 false,
599 );
600 }
601 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i64>() {
602 return between_i64(
603 unsafe { std::mem::transmute(lhs) },
604 unsafe { std::mem::transmute(rhs) },
605 None,
606 false,
607 );
608 }
609 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u64>() {
610 return between_u64(
611 unsafe { std::mem::transmute(lhs) },
612 unsafe { std::mem::transmute(rhs) },
613 None,
614 false,
615 );
616 }
617 between_generic(lhs, rhs, None, false)
619}
620
621#[inline(always)]
623pub fn cmp_between_mask<T: PartialOrd + Copy + Numeric>(
624 lhs: &[T],
625 rhs: &[T],
626 mask: Option<&Bitmask>,
627) -> Result<BooleanArray<()>, KernelError> {
628 let has_nulls = mask.is_some();
629 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i32>() {
630 return between_i32(
631 unsafe { std::mem::transmute(lhs) },
632 unsafe { std::mem::transmute(rhs) },
633 mask,
634 has_nulls,
635 );
636 }
637 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u32>() {
638 return between_u32(
639 unsafe { std::mem::transmute(lhs) },
640 unsafe { std::mem::transmute(rhs) },
641 mask,
642 has_nulls,
643 );
644 }
645 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i64>() {
646 return between_i64(
647 unsafe { std::mem::transmute(lhs) },
648 unsafe { std::mem::transmute(rhs) },
649 mask,
650 has_nulls,
651 );
652 }
653 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u64>() {
654 return between_u64(
655 unsafe { std::mem::transmute(lhs) },
656 unsafe { std::mem::transmute(rhs) },
657 mask,
658 has_nulls,
659 );
660 }
661 between_generic(lhs, rhs, mask, has_nulls)
662}
663
664pub fn cmp_in<T: Eq + Hash + Copy + 'static>(
668 lhs: &[T],
669 rhs: &[T],
670) -> Result<BooleanArray<()>, KernelError> {
671 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i32>() {
673 return in_i32(
674 unsafe { std::mem::transmute(lhs) },
675 unsafe { std::mem::transmute(rhs) },
676 None,
677 false,
678 );
679 }
680 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u32>() {
682 return in_u32(
683 unsafe { std::mem::transmute(lhs) },
684 unsafe { std::mem::transmute(rhs) },
685 None,
686 false,
687 );
688 }
689 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i64>() {
691 return in_i64(
692 unsafe { std::mem::transmute(lhs) },
693 unsafe { std::mem::transmute(rhs) },
694 None,
695 false,
696 );
697 }
698 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u64>() {
700 return in_u64(
701 unsafe { std::mem::transmute(lhs) },
702 unsafe { std::mem::transmute(rhs) },
703 None,
704 false,
705 );
706 }
707 #[cfg(feature = "extended_numeric_types")]
709 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i16>() {
710 return in_i16(
711 unsafe { std::mem::transmute(lhs) },
712 unsafe { std::mem::transmute(rhs) },
713 None,
714 false,
715 );
716 }
717 #[cfg(feature = "extended_numeric_types")]
719 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u16>() {
720 return in_u16(
721 unsafe { std::mem::transmute(lhs) },
722 unsafe { std::mem::transmute(rhs) },
723 None,
724 false,
725 );
726 }
727 #[cfg(feature = "extended_numeric_types")]
729 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i8>() {
730 return in_i8(
731 unsafe { std::mem::transmute(lhs) },
732 unsafe { std::mem::transmute(rhs) },
733 None,
734 false,
735 );
736 }
737 #[cfg(feature = "extended_numeric_types")]
739 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u8>() {
740 return in_u8(
741 unsafe { std::mem::transmute(lhs) },
742 unsafe { std::mem::transmute(rhs) },
743 None,
744 false,
745 );
746 }
747 return Err(KernelError::UnsupportedType(
748 "cmp_in: unsupported type for SIMD in".into(),
749 ));
750}
751
752#[inline(always)]
754pub fn cmp_in_mask<T: Eq + Hash + Copy + 'static>(
755 lhs: &[T],
756 rhs: &[T],
757 mask: Option<&Bitmask>,
758) -> Result<BooleanArray<()>, KernelError> {
759 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i32>() {
760 return in_i32(
761 unsafe { std::mem::transmute(lhs) },
762 unsafe { std::mem::transmute(rhs) },
763 mask,
764 mask.is_some(),
765 );
766 }
767 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u32>() {
768 return in_u32(
769 unsafe { std::mem::transmute(lhs) },
770 unsafe { std::mem::transmute(rhs) },
771 mask,
772 mask.is_some(),
773 );
774 }
775 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i64>() {
776 return in_i64(
777 unsafe { std::mem::transmute(lhs) },
778 unsafe { std::mem::transmute(rhs) },
779 mask,
780 mask.is_some(),
781 );
782 }
783 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u64>() {
784 return in_u64(
785 unsafe { std::mem::transmute(lhs) },
786 unsafe { std::mem::transmute(rhs) },
787 mask,
788 mask.is_some(),
789 );
790 }
791 return Err(KernelError::UnsupportedType(
792 "cmp_in_mask: unsupported type (expected integer type)".into(),
793 ));
794}
795
796#[inline(always)]
798pub fn cmp_in_f_mask<T: Float + Copy>(
799 lhs: &[T],
800 rhs: &[T],
801 mask: Option<&Bitmask>,
802) -> Result<BooleanArray<()>, KernelError> {
803 if TypeId::of::<T>() == TypeId::of::<f32>() {
804 let lhs = unsafe { &*(lhs as *const [T] as *const [f32]) };
805 let rhs = unsafe { &*(rhs as *const [T] as *const [f32]) };
806 in_f32(lhs, rhs, mask, mask.is_some())
807 } else if TypeId::of::<T>() == TypeId::of::<f64>() {
808 let lhs = unsafe { &*(lhs as *const [T] as *const [f64]) };
809 let rhs = unsafe { &*(rhs as *const [T] as *const [f64]) };
810 in_f64(lhs, rhs, mask, mask.is_some())
811 } else {
812 unreachable!("cmp_in_f_mask: Only f32/f64 supported for Float kernels")
813 }
814}
815
816#[inline(always)]
817pub fn cmp_in_f<T: Float + Copy>(lhs: &[T], rhs: &[T]) -> Result<BooleanArray<()>, KernelError> {
819 if TypeId::of::<T>() == TypeId::of::<f32>() {
820 let lhs = unsafe { &*(lhs as *const [T] as *const [f32]) };
821 let rhs = unsafe { &*(rhs as *const [T] as *const [f32]) };
822 in_f32(lhs, rhs, None, false)
823 } else if TypeId::of::<T>() == TypeId::of::<f64>() {
824 let lhs = unsafe { &*(lhs as *const [T] as *const [f64]) };
825 let rhs = unsafe { &*(rhs as *const [T] as *const [f64]) };
826 in_f64(lhs, rhs, None, false)
827 } else {
828 unreachable!("cmp_in_f: Only f32/f64 supported for Float kernels")
829 }
830}
831
832pub fn cmp_between_f<T: PartialOrd + Copy + Float + Numeric>(
836 lhs: &[T],
837 rhs: &[T],
838) -> Result<BooleanArray<()>, KernelError> {
839 between_generic(lhs, rhs, None, false)
840}
841
842pub fn cmp_dict_between<'a, T: Integer>(
844 lhs: CategoricalAVT<'a, T>,
845 rhs: CategoricalAVT<'a, T>,
846) -> Result<BooleanArray<()>, KernelError> {
847 let (larr, loff, llen) = lhs;
848 let (rarr, roff, _rlen) = rhs;
849
850 let min = rarr.get(roff).unwrap_or("");
851 let max = rarr.get(roff + 1).unwrap_or("");
852 let mask = larr.null_mask.as_ref();
853 let _ = confirm_mask_capacity(larr.data.len(), mask)?;
854 let has_nulls = mask.is_some();
855
856 let mut out = new_bool_buffer(llen);
857 for i in 0..llen {
858 let li = loff + i;
859 if !has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(li) }) {
860 let s = unsafe { larr.get_str_unchecked(li) };
861 if s > min && s <= max {
862 unsafe { out.set_unchecked(i, true) };
863 }
864 }
865 }
866 Ok(BooleanArray {
867 data: out.into(),
868 null_mask: mask.cloned(),
869 len: llen,
870 _phantom: PhantomData,
871 })
872}
873
874pub fn cmp_dict_in<'a, T: Integer + Hash>(
879 lhs: CategoricalAVT<'a, T>,
880 rhs: CategoricalAVT<'a, T>,
881) -> Result<BooleanArray<()>, KernelError> {
882 let (larr, loff, llen) = lhs;
883 let (rarr, roff, rlen) = rhs;
884 let mask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
885 let _ = confirm_mask_capacity(llen, mask.as_ref())?;
886
887 let mut out = Bitmask::new_set_all(llen, false);
888
889 if (larr.unique_values.len() == rarr.unique_values.len())
890 && (larr.unique_values.len() <= MAX_DICT_CHECK)
891 {
892 let mut same_dict = true;
893 for (a, b) in larr.unique_values.iter().zip(rarr.unique_values.iter()) {
894 if a != b {
895 same_dict = false;
896 break;
897 }
898 }
899
900 if same_dict {
901 let rhs_codes: HashSet<T> = rarr.data[roff..roff + rlen].iter().copied().collect();
902 for i in 0..llen {
903 if mask
904 .as_ref()
905 .map_or(true, |m| unsafe { m.get_unchecked(i) })
906 {
907 let code = larr.data[loff + i];
908 if rhs_codes.contains(&code) {
909 unsafe { out.set_unchecked(i, true) };
910 }
911 }
912 }
913 return Ok(BooleanArray {
914 data: out.into(),
915 null_mask: mask,
916 len: llen,
917 _phantom: PhantomData,
918 });
919 }
920 }
921
922 let rhs_strings: HashSet<&str> = (0..rlen)
923 .filter(|&i| {
924 rarr.null_mask
925 .as_ref()
926 .map_or(true, |m| unsafe { m.get_unchecked(roff + i) })
927 })
928 .map(|i| unsafe { rarr.get_str_unchecked(roff + i) })
929 .collect();
930
931 for i in 0..llen {
932 if mask
933 .as_ref()
934 .map_or(true, |m| unsafe { m.get_unchecked(i) })
935 {
936 let s = unsafe { larr.get_str_unchecked(loff + i) };
937 if rhs_strings.contains(s) {
938 unsafe { out.set_unchecked(i, true) };
939 }
940 }
941 }
942
943 Ok(BooleanArray {
944 data: out.into(),
945 null_mask: mask,
946 len: llen,
947 _phantom: PhantomData,
948 })
949}
950
951pub fn is_null_array(arr: &Array) -> Result<BooleanArray<()>, KernelError> {
955 let not_null = is_not_null_array(arr)?;
956 Ok(!not_null)
957}
958pub fn is_not_null_array(arr: &Array) -> Result<BooleanArray<()>, KernelError> {
960 let len = arr.len();
961 let mut data = Bitmask::new_set_all(len, false);
962
963 if let Some(mask) = arr.null_mask() {
964 data = mask.clone();
965 } else {
966 data.fill(true);
967 }
968 Ok(BooleanArray {
969 data,
970 null_mask: None,
971 len,
972 _phantom: PhantomData,
973 })
974}
975
976pub fn in_array(input: &Array, values: &Array) -> Result<BooleanArray<()>, KernelError> {
979 match (input, values) {
980 (
981 Array::NumericArray(NumericArray::Int32(a)),
982 Array::NumericArray(NumericArray::Int32(b)),
983 ) => cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref()),
984 (
985 Array::NumericArray(NumericArray::Int64(a)),
986 Array::NumericArray(NumericArray::Int64(b)),
987 ) => cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref()),
988 (
989 Array::NumericArray(NumericArray::UInt32(a)),
990 Array::NumericArray(NumericArray::UInt32(b)),
991 ) => cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref()),
992 (
993 Array::NumericArray(NumericArray::UInt64(a)),
994 Array::NumericArray(NumericArray::UInt64(b)),
995 ) => cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref()),
996 (
997 Array::NumericArray(NumericArray::Float32(a)),
998 Array::NumericArray(NumericArray::Float32(b)),
999 ) => cmp_in_f_mask(&a.data, &b.data, a.null_mask.as_ref()),
1000 (
1001 Array::NumericArray(NumericArray::Float64(a)),
1002 Array::NumericArray(NumericArray::Float64(b)),
1003 ) => cmp_in_f_mask(&a.data, &b.data, a.null_mask.as_ref()),
1004 (Array::TextArray(TextArray::String32(a)), Array::TextArray(TextArray::String32(b))) => {
1005 cmp_str_in((**a).tuple_ref(0, a.len()), (**b).tuple_ref(0, b.len()))
1006 }
1007 (Array::BooleanArray(a), Array::BooleanArray(b)) => {
1008 cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref())
1009 }
1010 (
1011 Array::TextArray(TextArray::Categorical32(a)),
1012 Array::TextArray(TextArray::Categorical32(b)),
1013 ) => cmp_dict_in((**a).tuple_ref(0, a.len()), (**b).tuple_ref(0, b.len())),
1014 _ => unimplemented!(),
1015 }
1016}
1017
1018#[inline(always)]
1019pub fn not_in_array(input: &Array, values: &Array) -> Result<BooleanArray<()>, KernelError> {
1021 let result = in_array(input, values)?;
1022 Ok(!result)
1023}
1024
1025pub fn between_array(input: &Array, min: &Array, max: &Array) -> Result<Array, KernelError> {
1027 macro_rules! between_case {
1028 ($variant:ident, $cmp:ident) => {{
1029 let arr = match input {
1030 Array::NumericArray(NumericArray::$variant(arr)) => arr,
1031 _ => unreachable!(),
1032 };
1033 let mins = match min {
1034 Array::NumericArray(NumericArray::$variant(arr)) => arr,
1035 _ => unreachable!(),
1036 };
1037 let maxs = match max {
1038 Array::NumericArray(NumericArray::$variant(arr)) => arr,
1039 _ => unreachable!(),
1040 };
1041 let rhs: Vec64<_> = mins
1042 .data
1043 .iter()
1044 .zip(&maxs.data)
1045 .flat_map(|(&lo, &hi)| [lo, hi])
1046 .collect();
1047 Ok(Array::BooleanArray(
1048 $cmp(
1049 &arr.data,
1050 &rhs,
1051 arr.null_mask.as_ref(),
1052 arr.null_mask.is_some(),
1053 )?
1054 .into(),
1055 ))
1056 }};
1057 }
1058
1059 match (input, min, max) {
1060 (
1061 Array::NumericArray(NumericArray::Int32(..)),
1062 Array::NumericArray(NumericArray::Int32(..)),
1063 Array::NumericArray(NumericArray::Int32(..)),
1064 ) => between_case!(Int32, between_i32),
1065 (
1066 Array::NumericArray(NumericArray::Int64(..)),
1067 Array::NumericArray(NumericArray::Int64(..)),
1068 Array::NumericArray(NumericArray::Int64(..)),
1069 ) => between_case!(Int64, between_i64),
1070 (
1071 Array::NumericArray(NumericArray::UInt32(..)),
1072 Array::NumericArray(NumericArray::UInt32(..)),
1073 Array::NumericArray(NumericArray::UInt32(..)),
1074 ) => between_case!(UInt32, between_u32),
1075 (
1076 Array::NumericArray(NumericArray::UInt64(..)),
1077 Array::NumericArray(NumericArray::UInt64(..)),
1078 Array::NumericArray(NumericArray::UInt64(..)),
1079 ) => between_case!(UInt64, between_u64),
1080 (
1081 Array::NumericArray(NumericArray::Float32(..)),
1082 Array::NumericArray(NumericArray::Float32(..)),
1083 Array::NumericArray(NumericArray::Float32(..)),
1084 ) => between_case!(Float32, between_generic),
1085 (
1086 Array::NumericArray(NumericArray::Float64(..)),
1087 Array::NumericArray(NumericArray::Float64(..)),
1088 Array::NumericArray(NumericArray::Float64(..)),
1089 ) => between_case!(Float64, between_generic),
1090 _ => Err(KernelError::UnsupportedType(
1091 "Unsupported Type Error.".to_string(),
1092 )),
1093 }
1094}
1095
1096#[inline]
1100pub fn not_bool<const LANES: usize>(
1101 src: BooleanAVT<'_, ()>,
1102) -> Result<BooleanArray<()>, KernelError>
1103where
1104{
1105 let (arr, offset, len) = src;
1106
1107 if offset % 64 != 0 {
1108 return Err(KernelError::InvalidArguments(format!(
1109 "not_bool: offset must be 64-bit aligned (got offset={})",
1110 offset
1111 )));
1112 }
1113
1114 let null_mask = arr.null_mask.as_ref().map(|nm| nm.slice_clone(offset, len));
1115
1116 let data = if null_mask.is_none() {
1117 #[cfg(feature = "simd")]
1118 {
1119 minarrow::kernels::bitmask::simd::not_mask_simd::<LANES>((&arr.data, offset, len))
1120 }
1121 #[cfg(not(feature = "simd"))]
1122 {
1123 minarrow::kernels::bitmask::std::not_mask((&arr.data, offset, len))
1124 }
1125 } else {
1126 arr.data.slice_clone(offset, len)
1128 };
1129
1130 Ok(BooleanArray {
1131 data,
1132 null_mask,
1133 len,
1134 _phantom: core::marker::PhantomData,
1135 })
1136}
1137
1138pub fn apply_logical_bool<const LANES: usize>(
1142 lhs: BooleanAVT<'_, ()>,
1143 rhs: BooleanAVT<'_, ()>,
1144 op: LogicalOperator,
1145) -> Result<BooleanArray<()>, KernelError>
1146where
1147{
1148 let (lhs_arr, lhs_off, len) = lhs;
1149 let (rhs_arr, rhs_off, rlen) = rhs;
1150
1151 if len != rlen {
1152 return Err(KernelError::LengthMismatch(format!(
1153 "logical_bool: window length mismatch (lhs: {}, rhs: {})",
1154 len, rlen
1155 )));
1156 }
1157 if lhs_off % 64 != 0 || rhs_off % 64 != 0 {
1158 return Err(KernelError::InvalidArguments(format!(
1159 "logical_bool: offsets must be 64-bit aligned (lhs: {}, rhs: {})",
1160 lhs_off, rhs_off
1161 )));
1162 }
1163
1164 #[cfg(feature = "simd")]
1167 let data = match op {
1168 LogicalOperator::And => {
1169 and_masks_simd::<LANES>((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1170 }
1171 LogicalOperator::Or => {
1172 or_masks_simd::<LANES>((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1173 }
1174 LogicalOperator::Xor => {
1175 xor_masks_simd::<LANES>((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1176 }
1177 };
1178
1179 #[cfg(feature = "simd")]
1181 let null_mask = match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
1182 (None, None) => None,
1183 (Some(a), None) | (None, Some(a)) => Some(a.slice_clone(lhs_off, len)),
1184 (Some(a), Some(b)) => Some(and_masks_simd::<LANES>(
1185 (a, lhs_off, len),
1186 (b, rhs_off, len),
1187 )),
1188 };
1189
1190 #[cfg(not(feature = "simd"))]
1191 let data = match op {
1192 LogicalOperator::And => {
1193 and_masks((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1194 }
1195 LogicalOperator::Or => {
1196 or_masks((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1197 }
1198 LogicalOperator::Xor => {
1199 xor_masks((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1200 }
1201 };
1202
1203 #[cfg(not(feature = "simd"))]
1204 let null_mask = match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
1205 (None, None) => None,
1206 (Some(a), None) | (None, Some(a)) => Some(a.slice_clone(lhs_off, len)),
1207 (Some(a), Some(b)) => Some(and_masks((a, lhs_off, len), (b, rhs_off, len))),
1208 };
1209
1210 Ok(BooleanArray {
1211 data,
1212 null_mask,
1213 len,
1214 _phantom: PhantomData,
1215 })
1216}
1217
#[cfg(test)]
mod tests {
    use minarrow::structs::variants::categorical::CategoricalArray;
    use minarrow::structs::variants::float::FloatArray;
    use minarrow::structs::variants::integer::IntegerArray;
    use minarrow::structs::variants::string::StringArray;
    use minarrow::{Array, Bitmask, BooleanArray, vec64};

    use super::*;

    /// Builds a `Bitmask` whose bit `i` equals `bits[i]`.
    fn bm(bits: &[bool]) -> Bitmask {
        let mut m = Bitmask::new_set_all(bits.len(), false);
        for (i, &b) in bits.iter().enumerate() {
            m.set(i, b);
        }
        m
    }

    /// Asserts that `arr` holds exactly the values in `expect`.
    ///
    /// When `expect_mask` is `Some`, the array must carry a null mask whose bits
    /// match it (`true` = valid, `false` = null). When it is `None`, the array
    /// may either have no null mask or a mask in which every row is valid.
    fn assert_bool(arr: &BooleanArray<()>, expect: &[bool], expect_mask: Option<&[bool]>) {
        assert_eq!(arr.len, expect.len(), "length mismatch");
        for (i, &want) in expect.iter().enumerate() {
            assert_eq!(arr.data.get(i), want, "val @ {i}");
        }
        match (expect_mask, &arr.null_mask) {
            (None, None) => {}
            (Some(exp), Some(mask)) => {
                assert_eq!(exp.len(), arr.len, "expected mask length mismatch");
                for (i, &b) in exp.iter().enumerate() {
                    assert_eq!(mask.get(i), b, "mask @ {i}");
                }
            }
            (None, Some(mask)) => {
                // An all-valid mask is equivalent to having no mask at all.
                for i in 0..arr.len {
                    assert!(mask.get(i), "unexpected false mask @ {i}");
                }
            }
            (Some(_), None) => panic!("expected null mask"),
        }
    }

    /// Dense `i32` array fixture.
    fn i32_arr(data: &[i32]) -> IntegerArray<i32> {
        IntegerArray::from_slice(data)
    }

    /// Dense `f32` array fixture.
    fn f32_arr(data: &[f32]) -> FloatArray<f32> {
        FloatArray::from_slice(data)
    }

    /// String array fixture with offset type `T`.
    fn str_arr<T: Integer>(vals: &[&str]) -> StringArray<T> {
        StringArray::<T>::from_slice(vals)
    }

    /// Categorical (dictionary-encoded) array fixture with index type `T`.
    fn dict_arr<T: Integer>(vals: &[&str]) -> CategoricalArray<T> {
        CategoricalArray::<T>::from_values(vals.to_vec())
    }

    // ----- between: numeric -----

    #[test]
    fn between_i32_scalar_rhs() {
        let lhs = vec64![1, 3, 5, 7];
        let rhs = vec64![2, 6];
        let out = between_i32(&lhs, &rhs, None, false).unwrap();
        assert_bool(&out, &[false, true, true, false], None);
    }

    #[test]
    fn between_i32_per_row_rhs() {
        let lhs = vec64![5, 9, 2, 8];
        // Per-row bounds laid out as [min0, max0, min1, max1, ...].
        let rhs = vec64![0, 10, 0, 4, 2, 2, 8, 9];
        let out = between_i32(&lhs, &rhs, None, false).unwrap();
        assert_bool(&out, &[true, false, true, true], None);
    }

    #[test]
    fn between_i32_nulls_propagate() {
        let lhs = vec64![5, 9, 2, 8];
        let rhs = vec64![0, 10, 0, 4, 2, 2, 8, 9];
        let mask = bm(&[true, false, true, true]);
        let out = between_i32(&lhs, &rhs, Some(&mask), true).unwrap();
        assert_bool(
            &out,
            &[true, false, true, true],
            Some(&[true, false, true, true]),
        );
    }

    #[cfg(feature = "extended_numeric_types")]
    #[test]
    fn between_i16_works() {
        let lhs = vec64![10i16, 12, 99];
        // Inclusive bounds [10, 12].
        let rhs = vec64![10i16, 12];
        let out = between_i16(&lhs, &rhs, None, false).unwrap();
        assert_bool(&out, &[true, true, false], None);
    }

    #[test]
    fn between_f64_scalar_and_nulls() {
        let lhs = vec64![1.0, 5.0, 8.0, 20.0];
        let rhs = vec64![4.0, 10.0];
        let mask = bm(&[true, false, true, true]);
        let out = between_f64(&lhs, &rhs, Some(&mask), true).unwrap();
        // Row 1 (5.0) is in range but null, so its value bit is cleared.
        assert_bool(
            &out,
            &[false, false, true, false],
            Some(&[true, false, true, true]),
        );
    }

    #[test]
    fn between_f32_generic_dispatch() {
        let lhs = vec64![0.1f32, 0.5, 1.2, -1.0];
        let rhs = vec64![0.0, 1.0];
        let out = cmp_between(&lhs, &rhs).unwrap();
        assert_bool(&out, &[true, true, false, false], None);
    }

    #[test]
    fn between_masked_dispatch() {
        let lhs = vec64![1i32, 2, 3];
        let rhs = vec64![0, 2];
        let mask = bm(&[true, false, true]);
        let out = cmp_between_mask(&lhs, &rhs, Some(&mask)).unwrap();
        assert_bool(&out, &[true, false, false], Some(&[true, false, true]));
    }

    // ----- in: numeric -----

    #[test]
    fn in_i32_small_rhs() {
        let lhs = vec64![1, 2, 3, 4, 5];
        let rhs = vec64![2, 4];
        let out = in_i32(&lhs, &rhs, None, false).unwrap();
        assert_bool(&out, &[false, true, false, true, false], None);
    }

    #[test]
    fn in_i32_with_nulls() {
        let lhs = vec64![7, 8, 9];
        let rhs = vec64![8];
        let mask = bm(&[true, false, true]);
        let out = in_i32(&lhs, &rhs, Some(&mask), true).unwrap();
        assert_bool(&out, &[false, false, false], Some(&[true, false, true]));
    }

    #[test]
    fn in_i64_large_rhs() {
        let lhs = vec64![1i64, 2, 3, 7, 8, 15];
        let rhs: Vec<i64> = (2..10).collect();
        let out = in_i64(&lhs, &rhs, None, false).unwrap();
        assert_bool(&out, &[false, true, true, true, true, false], None);
    }

    #[cfg(feature = "extended_numeric_types")]
    #[test]
    fn in_u8_small_rhs() {
        let lhs = vec64![1u8, 2, 3, 4];
        let rhs = vec64![2u8, 3];
        let out = in_u8(&lhs, &rhs, None, false).unwrap();
        assert_bool(&out, &[false, true, true, false], None);
    }

    #[test]
    fn in_float_nan_and_normal() {
        let lhs = vec64![1.0f32, f32::NAN, 7.0];
        let rhs = vec64![f32::NAN, 7.0];
        let out = in_f32(&lhs, &rhs, None, false).unwrap();
        // NaN in the lhs is expected to match a NaN in the rhs set.
        assert_bool(&out, &[false, true, true], None);
    }

    // ----- string / dictionary kernels -----
    //
    // Slices are `(array, offset, len)` windows over the backing array.

    #[test]
    fn string_between() {
        let lhs = str_arr::<u32>(&["aa", "bb", "zz"]);
        let rhs = str_arr::<u32>(&["b", "y"]);
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let out = cmp_str_between(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    #[test]
    fn string_between_chunk() {
        let lhs = str_arr::<u32>(&["0", "aa", "bb", "zz", "9"]);
        let rhs = str_arr::<u32>(&["a", "b", "y", "z"]);
        let lhs_slice = (&lhs, 1, 3);
        let rhs_slice = (&rhs, 1, 2);
        let out = cmp_str_between(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    #[test]
    fn string_in_basic() {
        let lhs = str_arr::<u32>(&["x", "y", "z"]);
        let rhs = str_arr::<u32>(&["y", "a"]);
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let out = cmp_str_in(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    #[test]
    fn string_in_basic_chunk() {
        let lhs = str_arr::<u32>(&["0", "x", "y", "z", "9"]);
        let rhs = str_arr::<u32>(&["b", "y", "a", "c"]);
        let lhs_slice = (&lhs, 1, 3);
        let rhs_slice = (&rhs, 1, 2);
        let out = cmp_str_in(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    #[test]
    fn dict_between() {
        let lhs = dict_arr::<u32>(&["cat", "dog", "emu"]);
        let rhs = dict_arr::<u32>(&["cobra", "dove"]);
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let out = cmp_dict_between(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    #[test]
    fn dict_between_chunk() {
        let lhs = dict_arr::<u32>(&["a", "cat", "dog", "emu", "z"]);
        let rhs = dict_arr::<u32>(&["a", "cobra", "dove", "zz"]);
        let lhs_slice = (&lhs, 1, 3);
        let rhs_slice = (&rhs, 1, 2);
        let out = cmp_dict_between(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    #[test]
    fn dict_in_membership() {
        let lhs = dict_arr::<u32>(&["aa", "bb", "cc"]);
        let rhs = dict_arr::<u32>(&["bb", "dd"]);
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    #[test]
    fn dict_in_membership_chunk() {
        let lhs = dict_arr::<u32>(&["0", "aa", "bb", "cc", "9"]);
        let rhs = dict_arr::<u32>(&["a", "bb", "dd", "zz"]);
        let lhs_slice = (&lhs, 1, 3);
        let rhs_slice = (&rhs, 1, 2);
        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    #[test]
    fn string_between_nulls() {
        let mut lhs = str_arr::<u32>(&["foo", "bar", "baz"]);
        lhs.null_mask = Some(bm(&[true, false, true]));
        let rhs = str_arr::<u32>(&["a", "zzz"]);
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let out = cmp_str_between(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[true, false, true], Some(&[true, false, true]));
    }

    #[test]
    fn string_between_nulls_chunk() {
        let mut lhs = str_arr::<u32>(&["0", "foo", "bar", "baz", "z"]);
        lhs.null_mask = Some(bm(&[true, true, false, true, true]));
        let rhs = str_arr::<u32>(&["0", "a", "zzz", "9"]);
        let lhs_slice = (&lhs, 1, 3);
        let rhs_slice = (&rhs, 1, 2);
        let out = cmp_str_between(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[true, false, true], Some(&[true, false, true]));
    }

    #[test]
    fn dict_in_nulls() {
        let mut lhs = dict_arr::<u32>(&["one", "two", "three"]);
        lhs.null_mask = Some(bm(&[false, true, true]));
        let rhs = dict_arr::<u32>(&["two", "four"]);
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], Some(&[false, true, true]));
    }

    #[test]
    fn dict_in_nulls_chunk() {
        let mut lhs = dict_arr::<u32>(&["x", "one", "two", "three", "z"]);
        lhs.null_mask = Some(bm(&[true, false, true, true, true]));
        let rhs = dict_arr::<u32>(&["a", "two", "four", "b"]);
        let lhs_slice = (&lhs, 1, 3);
        let rhs_slice = (&rhs, 1, 2);
        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], Some(&[false, true, true]));
    }

    // ----- null predicates -----

    #[test]
    fn is_null_and_is_not_null() {
        let mut arr = i32_arr(&[1, 2, 0]);
        arr.null_mask = Some(bm(&[true, false, true]));
        let array = Array::from_int32(arr.clone());

        let not_null = is_not_null_array(&array).unwrap();
        let is_null = is_null_array(&array).unwrap();

        assert_bool(&not_null, &[true, false, true], None);
        assert_bool(&is_null, &[false, true, false], None);
    }

    #[test]
    fn is_null_not_null_dense() {
        let arr = i32_arr(&[1, 2, 3]);
        let array = Array::from_int32(arr.clone());
        let is_null = is_null_array(&array).unwrap();
        assert_bool(&is_null, &[false, false, false], None);
        let not_null = is_not_null_array(&array).unwrap();
        assert_bool(&not_null, &[true, true, true], None);
    }

    // ----- Array-level dispatch -----

    #[test]
    fn in_array_int32_dispatch() {
        let inp = Array::from_int32(i32_arr(&[10, 20, 30]));
        let vals = Array::from_int32(i32_arr(&[20, 40]));
        let out = in_array(&inp, &vals).unwrap();
        assert_bool(&out, &[false, true, false], None);

        let out_not = not_in_array(&inp, &vals).unwrap();
        assert_bool(&out_not, &[true, false, true], None);
    }

    #[test]
    fn in_array_f32_dispatch() {
        let inp = Array::from_float32(f32_arr(&[1.0, f32::NAN, 7.0]));
        let vals = Array::from_float32(f32_arr(&[f32::NAN, 7.0]));
        let out = in_array(&inp, &vals).unwrap();
        assert_bool(&out, &[false, true, true], None);
    }

    #[test]
    fn in_array_string_dispatch() {
        let inp = Array::from_string32(str_arr::<u32>(&["a", "b", "c"]));
        let vals = Array::from_string32(str_arr::<u32>(&["b", "d"]));
        let out = in_array(&inp, &vals).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    #[test]
    fn in_array_dictionary_dispatch() {
        let inp = Array::from_categorical32(dict_arr::<u32>(&["aa", "bb", "cc"]));
        let vals = Array::from_categorical32(dict_arr::<u32>(&["bb", "cc"]));
        let out = in_array(&inp, &vals).unwrap();
        assert_bool(&out, &[false, true, true], None);
    }

    #[test]
    fn between_array_int32_rows() {
        let inp = Array::from_int32(i32_arr(&[5, 15, 25]));
        let min = Array::from_int32(i32_arr(&[0, 10, 20]));
        let max = Array::from_int32(i32_arr(&[10, 20, 30]));

        let out = between_array(&inp, &min, &max).unwrap();
        match out {
            Array::BooleanArray(b) => assert_bool(&b, &[true, true, true], None),
            _ => panic!("expected Bool array"),
        }
    }

    #[test]
    fn between_array_float_generic() {
        let inp = Array::from_float32(f32_arr(&[0.5, 1.5, 2.5]));
        let min = Array::from_float32(f32_arr(&[0.0, 1.0, 2.0]));
        let max = Array::from_float32(f32_arr(&[1.0, 2.0, 3.0]));

        let out = between_array(&inp, &min, &max).unwrap();
        match out {
            Array::BooleanArray(b) => assert_bool(&b, &[true, true, true], None),
            _ => panic!("expected Bool"),
        }
    }

    #[test]
    fn between_array_type_mismatch() {
        let inp = Array::from_int32(i32_arr(&[1, 2, 3]));
        let min = Array::from_float32(f32_arr(&[0.0, 0.0, 0.0]));
        let max = Array::from_float32(f32_arr(&[5.0, 5.0, 5.0]));
        let err = between_array(&inp, &min, &max).unwrap_err();
        match err {
            KernelError::UnsupportedType(_) => {}
            _ => panic!("Expected UnsupportedType error"),
        }
    }

    #[test]
    fn in_integers_various_types() {
        #[cfg(feature = "extended_numeric_types")]
        {
            let u8_lhs = vec64![1u8, 2, 3, 5];
            let u8_rhs = vec64![3u8, 5, 8];
            let out = in_u8(&u8_lhs, &u8_rhs, None, false).unwrap();
            assert_bool(&out, &[false, false, true, true], None);

            let u16_lhs = vec64![100u16, 200, 300];
            let u16_rhs = vec64![200u16, 500];
            let out = in_u16(&u16_lhs, &u16_rhs, None, false).unwrap();
            assert_bool(&out, &[false, true, false], None);

            let i16_lhs = vec64![10i16, 15, 42];
            let i16_rhs = vec64![15i16, 42, 77];
            let out = in_i16(&i16_lhs, &i16_rhs, None, false).unwrap();
            assert_bool(&out, &[false, true, true], None);
        }

        let u32_lhs = vec64![0u32, 1, 2, 9];
        let u32_rhs = vec64![9u32, 1];
        let out = in_u32(&u32_lhs, &u32_rhs, None, false).unwrap();
        assert_bool(&out, &[false, true, false, true], None);

        let i64_lhs = vec64![1i64, 9, 10];
        let i64_rhs = vec64![2i64, 10, 20];
        let out = in_i64(&i64_lhs, &i64_rhs, None, false).unwrap();
        assert_bool(&out, &[false, false, true], None);

        let u64_lhs = vec64![1u64, 2, 3, 4];
        let u64_rhs = vec64![2u64, 4, 8];
        let out = in_u64(&u64_lhs, &u64_rhs, None, false).unwrap();
        assert_bool(&out, &[false, true, false, true], None);
    }

    // ----- edge cases -----

    #[test]
    fn between_and_in_empty_inputs() {
        let lhs: [i32; 0] = [];
        let rhs = vec64![0, 1];
        let out = between_i32(&lhs, &rhs, None, false).unwrap();
        assert_eq!(out.len, 0);

        let lhs: [i32; 0] = [];
        let rhs = vec64![1, 2, 3];
        let out = in_i32(&lhs, &rhs, None, false).unwrap();
        assert_eq!(out.len, 0);

        let lhs = str_arr::<u32>(&[]);
        let rhs = str_arr::<u32>(&["a", "b"]);
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let out = cmp_str_in(lhs_slice, rhs_slice).unwrap();
        assert_eq!(out.len, 0);
    }

    #[test]
    fn between_and_in_empty_inputs_chunk() {
        let lhs = str_arr::<u32>(&["x", "y"]);
        let rhs = str_arr::<u32>(&["a", "b", "c"]);
        // Zero-length window on the lhs.
        let lhs_slice = (&lhs, 1, 0);
        let rhs_slice = (&rhs, 1, 2);
        let out = cmp_str_in(lhs_slice, rhs_slice).unwrap();
        assert_eq!(out.len, 0);
    }

    #[test]
    fn between_per_row_bounds_on_last_row() {
        let lhs = vec64![0i32, 10, 20, 30];
        let rhs = vec64![0, 5, 5, 15, 15, 25, 25, 35];
        let out = between_i32(&lhs, &rhs, None, false).unwrap();
        assert_bool(&out, &[true, true, true, true], None);
    }

    #[test]
    fn test_cmp_dict_in_force_fallback() {
        let mut lhs = dict_arr::<u32>(&["a", "b", "c", "a"]);
        lhs.unique_values = vec64!["a".to_string(), "b".to_string(), "c".to_string()];
        let mut rhs = dict_arr::<u32>(&["b", "x", "y", "z"]);
        rhs.unique_values = vec64![
            "b".to_string(),
            "x".to_string(),
            "y".to_string(),
            "z".to_string()
        ];
        lhs.null_mask = Some(bm(&[true, true, true, true]));
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
        assert_bool(
            &out,
            &[false, true, false, false],
            Some(&[true, true, true, true]),
        );
    }

    #[test]
    fn test_cmp_dict_in_force_fallback_chunk() {
        let mut lhs = dict_arr::<u32>(&["z", "a", "b", "c", "a", "q"]);
        lhs.unique_values = vec64![
            "z".to_string(),
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
            "q".to_string()
        ];
        let mut rhs = dict_arr::<u32>(&["x", "b", "x", "y", "z"]);
        rhs.unique_values = vec64![
            "x".to_string(),
            "b".to_string(),
            "y".to_string(),
            "z".to_string()
        ];
        lhs.null_mask = Some(bm(&[true, true, true, true, true, true]));
        // Windows: lhs rows 1..5 = [a, b, c, a]; rhs rows 1..5 = [b, x, y, z].
        let lhs_slice = (&lhs, 1, 4);
        let rhs_slice = (&rhs, 1, 4);
        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
        assert_bool(
            &out,
            &[false, true, false, false],
            Some(&[true, true, true, true]),
        );
    }

    #[test]
    fn test_in_array_empty_rhs() {
        let arr = Array::from_int32(i32_arr(&[1, 2, 3]));
        let empty = Array::from_int32(i32_arr(&[]));
        let out = in_array(&arr, &empty).unwrap();
        assert_bool(&out, &[false, false, false], None);
    }
}