1include!(concat!(env!("OUT_DIR"), "/simd_lanes.rs"));
20
21use std::collections::HashSet;
22use std::hash::Hash;
23use std::marker::PhantomData;
24#[cfg(feature = "simd")]
25use std::simd::{Mask, Simd, cmp::SimdPartialEq, cmp::SimdPartialOrd, num::SimdFloat};
26
27use minarrow::kernels::arithmetic::string::MAX_DICT_CHECK;
28use minarrow::traits::type_unions::Float;
29use minarrow::{
30 Array, Bitmask, BooleanAVT, BooleanArray, CategoricalAVT, Integer, MaskedArray, Numeric,
31 NumericArray, StringAVT, TextArray, Vec64,
32};
33
34#[cfg(not(feature = "simd"))]
35use crate::kernels::bitmask::dispatch::{and_masks, or_masks, xor_masks};
36use crate::operators::LogicalOperator;
37use minarrow::enums::error::KernelError;
38#[cfg(feature = "simd")]
39use minarrow::kernels::bitmask::simd::{and_masks_simd, or_masks_simd, xor_masks_simd};
40use minarrow::utils::confirm_mask_capacity;
41
42#[cfg(feature = "simd")]
43use minarrow::utils::is_simd_aligned;
44use std::any::TypeId;
45
46#[inline(always)]
49fn new_bool_buffer(len: usize) -> Bitmask {
50 Bitmask::new_set_all(len, false)
51}
52
/// Generates a pair of inclusive `BETWEEN` kernels for one numeric element type.
///
/// * `$name_to` writes matches into a caller-supplied [`Bitmask`].
/// * `$name` allocates the output and wraps it in a [`BooleanArray`].
/// * `$ty` is the element type, `$lanes` the SIMD width and `$mask_elem` the lane-mask
///   element type used when the `simd` feature is enabled.
macro_rules! impl_between_numeric {
    ($name:ident, $name_to:ident, $ty:ty, $mask_elem:ty, $lanes:expr) => {
        /// Sets `output[i]` to `true` wherever `min <= lhs[i] <= max`.
        ///
        /// `rhs` is either `[min, max]` (one pair shared by every row) or a flattened
        /// list of per-row pairs `[min0, max0, min1, max1, ...]` of length `2 * lhs.len()`.
        /// When `has_nulls` is `true`, rows whose `mask` bit is unset are skipped.
        /// Positions that do not match are left untouched.
        ///
        /// # Errors
        /// Returns `KernelError::InvalidArguments` when `rhs` has an unsupported length or
        /// when `mask` has less capacity than `lhs.len()`.
        ///
        /// # Panics
        /// Panics if `output.capacity()` is smaller than `lhs.len()`.
        #[inline(always)]
        pub fn $name_to(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool,
            output: &mut Bitmask,
        ) -> Result<(), KernelError> {
            let len = lhs.len();
            let shared_bounds = rhs.len() == 2;
            if !shared_bounds && rhs.len() != 2 * len {
                return Err(KernelError::InvalidArguments(format!(
                    "between: RHS must have len 2 or 2×LHS (got lhs: {}, rhs: {})",
                    len,
                    rhs.len()
                )));
            }

            match mask {
                Some(m) if m.capacity() < len => {
                    return Err(KernelError::InvalidArguments(format!(
                        "between: mask (Bitmask) capacity must be ≥ len (got capacity: {}, len: {})",
                        m.capacity(),
                        len
                    )));
                }
                _ => {}
            }
            assert!(
                output.capacity() >= len,
                concat!(stringify!($name_to), ": output capacity too small")
            );

            // Vectorised path: only for null-free input with a single shared pair of bounds.
            #[cfg(feature = "simd")]
            {
                if is_simd_aligned(lhs) && is_simd_aligned(rhs) && !has_nulls && shared_bounds {
                    const N: usize = $lanes;
                    type V = Simd<$ty, N>;
                    type M = Mask<$mask_elem, N>;

                    let (lo, hi) = (rhs[0], rhs[1]);
                    let lo_v = V::splat(lo);
                    let hi_v = V::splat(hi);

                    // Index one past the last full SIMD chunk.
                    let simd_end = len - len % N;
                    for base in (0..simd_end).step_by(N) {
                        let chunk = V::from_slice(&lhs[base..base + N]);
                        let hits: M = chunk.simd_ge(lo_v) & chunk.simd_le(hi_v);
                        let bits = hits.to_bitmask();
                        for lane in 0..N {
                            if (bits >> lane) & 1 == 1 {
                                output.set(base + lane, true);
                            }
                        }
                    }
                    // Scalar tail for the elements that do not fill a whole chunk.
                    for idx in simd_end..len {
                        if lhs[idx] >= lo && lhs[idx] <= hi {
                            output.set(idx, true);
                        }
                    }

                    return Ok(());
                }
            }

            // SAFETY (closure body): `mask.capacity() >= len` was checked above and the
            // closure is only called with indices below `len`.
            let row_is_valid =
                |i: usize| !has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(i) });

            if shared_bounds {
                let (lo, hi) = (rhs[0], rhs[1]);
                for (i, &value) in lhs.iter().enumerate() {
                    if row_is_valid(i) && value >= lo && value <= hi {
                        output.set(i, true);
                    }
                }
            } else {
                // `rhs.len() == 2 * len` here, so each row pairs with exactly one chunk.
                for (i, (&value, bounds)) in lhs.iter().zip(rhs.chunks_exact(2)).enumerate() {
                    if row_is_valid(i) && value >= bounds[0] && value <= bounds[1] {
                        output.set(i, true);
                    }
                }
            }

            Ok(())
        }

        /// Allocating wrapper around the `_to` kernel.
        ///
        /// Returns a [`BooleanArray`] of `lhs.len()` values whose null mask is a clone of
        /// `mask`. Errors and panics are those of the `_to` kernel.
        #[inline(always)]
        pub fn $name(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool,
        ) -> Result<BooleanArray<()>, KernelError> {
            let mut bits = new_bool_buffer(lhs.len());
            $name_to(lhs, rhs, mask, has_nulls, &mut bits)?;
            Ok(BooleanArray {
                data: bits.into(),
                null_mask: mask.cloned(),
                len: lhs.len(),
                _phantom: PhantomData,
            })
        }
    };
}
172
173#[inline(always)]
176fn between_generic<T: Numeric + Copy + std::cmp::PartialOrd>(
177 lhs: &[T],
178 rhs: &[T],
179 mask: Option<&Bitmask>,
180 has_nulls: bool,
181) -> Result<BooleanArray<()>, KernelError> {
182 let len = lhs.len();
183 let mut out = new_bool_buffer(len);
184 let _ = confirm_mask_capacity(len, mask)?;
185 if rhs.len() == 2 {
186 let (min, max) = (rhs[0], rhs[1]);
187 for i in 0..len {
188 if (!has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(i) }))
189 && lhs[i] >= min
190 && lhs[i] <= max
191 {
192 out.set(i, true);
193 }
194 }
195 } else {
196 for i in 0..len {
197 let min = rhs[i * 2];
198 let max = rhs[i * 2 + 1];
199 if (!has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(i) }))
200 && lhs[i] >= min
201 && lhs[i] <= max
202 {
203 out.set(i, true);
204 }
205 }
206 }
207
208 Ok(BooleanArray {
209 data: out.into(),
210 null_mask: mask.cloned(),
211 len,
212 _phantom: PhantomData,
213 })
214}
215
/// Generates `IN` (set-membership) kernels for one integer element type.
///
/// * `$name_to` writes matches into a caller-supplied [`Bitmask`].
/// * `$name` allocates the output and wraps it in a [`BooleanArray`].
/// * `$lanes` is the SIMD width and `$mask_elem` the lane-mask element type used when
///   the `simd` feature is enabled.
macro_rules! impl_in_int {
    ($name:ident, $name_to:ident, $ty:ty, $lanes:expr, $mask_elem:ty) => {
        /// Sets `output[i]` to `true` wherever `lhs[i]` occurs in `rhs`.
        ///
        /// When `has_nulls` is `true`, rows whose `mask` bit is unset are skipped.
        /// Positions that do not match are left untouched.
        ///
        /// # Errors
        /// Propagates the error from `confirm_mask_capacity`.
        ///
        /// # Panics
        /// Panics if `output.capacity()` is smaller than `lhs.len()`. With the `simd`
        /// feature, also panics on the small-set path (aligned input, at most 16
        /// values in `rhs`) when `has_nulls` is `true` but `mask` is `None`.
        #[inline(always)]
        pub fn $name_to(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool,
            output: &mut Bitmask,
        ) -> Result<(), KernelError> {
            let len = lhs.len();
            let _ = confirm_mask_capacity(len, mask)?;
            assert!(
                output.capacity() >= len,
                concat!(stringify!($name_to), ": output capacity too small")
            );

            // Small candidate sets: compare every chunk against each candidate with SIMD.
            #[cfg(feature = "simd")]
            {
                use crate::utils::bitmask_to_simd_mask;
                use core::simd::{Mask, Simd};

                if is_simd_aligned(lhs) && is_simd_aligned(rhs) && rhs.len() <= 16 {
                    let validity = if has_nulls {
                        Some(mask.expect("Bitmask must be Some if has_nulls is set"))
                    } else {
                        None
                    };
                    let validity_bytes = validity.map(|mb| mb.as_bytes());

                    // Index one past the last full SIMD chunk.
                    let simd_end = len - len % $lanes;
                    for base in (0..simd_end).step_by($lanes) {
                        let chunk = Simd::<$ty, $lanes>::from_slice(&lhs[base..base + $lanes]);
                        let mut hits = Mask::<$mask_elem, $lanes>::splat(false);
                        for &candidate in rhs {
                            hits |= chunk.simd_eq(Simd::<$ty, $lanes>::splat(candidate));
                        }
                        if let Some(bytes) = validity_bytes {
                            // Drop lanes whose validity bit is unset.
                            let lane_valid =
                                bitmask_to_simd_mask::<$lanes, $mask_elem>(bytes, base, len);
                            hits = lane_valid & hits;
                        }
                        let bits = hits.to_bitmask();
                        for lane in 0..$lanes {
                            if (bits >> lane) & 1 == 1 {
                                output.set(base + lane, true);
                            }
                        }
                    }
                    // Scalar tail for the elements that do not fill a whole chunk.
                    for idx in simd_end..len {
                        // SAFETY: mask capacity was confirmed for `len`; `idx < len`.
                        let valid = validity.map_or(true, |mb| unsafe { mb.get_unchecked(idx) });
                        if valid && rhs.contains(&lhs[idx]) {
                            output.set(idx, true);
                        }
                    }
                    return Ok(());
                }
            }

            // General path: hash the candidates once, then probe per row.
            let candidates: std::collections::HashSet<$ty> = rhs.iter().copied().collect();
            for (i, value) in lhs.iter().enumerate() {
                // SAFETY: mask capacity was confirmed for `len`; `i < len`.
                let valid = !has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(i) });
                if valid && candidates.contains(value) {
                    output.set(i, true);
                }
            }
            Ok(())
        }

        /// Allocating wrapper around the `_to` kernel.
        ///
        /// Returns a [`BooleanArray`] of `lhs.len()` values whose null mask is a clone of
        /// `mask`. Errors and panics are those of the `_to` kernel.
        #[inline(always)]
        pub fn $name(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool,
        ) -> Result<BooleanArray<()>, KernelError> {
            let mut bits = new_bool_buffer(lhs.len());
            $name_to(lhs, rhs, mask, has_nulls, &mut bits)?;
            Ok(BooleanArray {
                data: bits.into(),
                null_mask: mask.cloned(),
                len: lhs.len(),
                _phantom: PhantomData,
            })
        }
    };
}
337
/// Generates `IN` (set-membership) kernels for one floating-point element type.
///
/// Membership treats `NaN` as equal to `NaN`: a `NaN` row matches when `rhs` contains
/// any `NaN`. `$fn_name_to` writes into a caller-supplied [`Bitmask`]; `$fn_name`
/// allocates the output and wraps it in a [`BooleanArray`].
macro_rules! impl_in_float {
    (
        $fn_name:ident, $fn_name_to:ident, $ty:ty, $lanes:expr, $mask_elem:ty
    ) => {
        /// Sets `output[i]` to `true` wherever `lhs[i]` occurs in `rhs` (NaN matches NaN).
        ///
        /// When `has_nulls` is `true`, rows whose `mask` bit is unset are skipped.
        /// Positions that do not match are left untouched.
        ///
        /// # Errors
        /// Propagates the error from `confirm_mask_capacity`.
        ///
        /// # Panics
        /// Panics if `output.capacity()` is smaller than `lhs.len()`. With the `simd`
        /// feature, also panics on the small-set path (aligned input, at most 16
        /// values in `rhs`) when `has_nulls` is `true` but `mask` is `None`.
        #[inline(always)]
        pub fn $fn_name_to(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool,
            output: &mut Bitmask,
        ) -> Result<(), KernelError> {
            let len = lhs.len();
            let _ = confirm_mask_capacity(len, mask)?;
            assert!(
                output.capacity() >= len,
                concat!(stringify!($fn_name_to), ": output capacity too small")
            );

            // Scalar membership test shared by the SIMD tail and the general path.
            let is_member =
                |x: $ty| rhs.iter().any(|&v| x == v || (x.is_nan() && v.is_nan()));

            #[cfg(feature = "simd")]
            {
                use crate::utils::bitmask_to_simd_mask;
                use core::simd::{Mask, Simd};

                if is_simd_aligned(lhs) && is_simd_aligned(rhs) && rhs.len() <= 16 {
                    let validity = if has_nulls {
                        Some(mask.expect("Bitmask must be Some if nulls are present"))
                    } else {
                        None
                    };
                    let validity_bytes = validity.map(|mb| mb.as_bytes());

                    // Index one past the last full SIMD chunk.
                    let simd_end = len - len % $lanes;
                    for base in (0..simd_end).step_by($lanes) {
                        let chunk = Simd::<$ty, $lanes>::from_slice(&lhs[base..base + $lanes]);
                        let mut hits = Mask::<$mask_elem, $lanes>::splat(false);
                        for &candidate in rhs {
                            let candidate_v = Simd::<$ty, $lanes>::splat(candidate);
                            // Plain equality never matches NaN, so add an explicit NaN/NaN test.
                            hits |= chunk.simd_eq(candidate_v)
                                | (chunk.is_nan() & candidate_v.is_nan());
                        }
                        if let Some(bytes) = validity_bytes {
                            // Drop lanes whose validity bit is unset.
                            let lane_valid =
                                bitmask_to_simd_mask::<$lanes, $mask_elem>(bytes, base, len);
                            hits = hits & lane_valid;
                        }
                        let bits = hits.to_bitmask();
                        for lane in 0..$lanes {
                            if (bits >> lane) & 1 == 1 {
                                output.set(base + lane, true);
                            }
                        }
                    }
                    // Scalar tail for the elements that do not fill a whole chunk.
                    for idx in simd_end..len {
                        // SAFETY: mask capacity was confirmed for `len`; `idx < len`.
                        let valid = validity.map_or(true, |mb| unsafe { mb.get_unchecked(idx) });
                        if valid && is_member(lhs[idx]) {
                            output.set(idx, true);
                        }
                    }
                    return Ok(());
                }
            }

            // General path: linear scan of `rhs` per row.
            for (i, &value) in lhs.iter().enumerate() {
                // SAFETY: mask capacity was confirmed for `len`; `i < len`.
                let is_null = has_nulls && !mask.map_or(true, |m| unsafe { m.get_unchecked(i) });
                if !is_null && is_member(value) {
                    output.set(i, true);
                }
            }
            Ok(())
        }

        /// Allocating wrapper around the `_to` kernel.
        ///
        /// Returns a [`BooleanArray`] of `lhs.len()` values whose null mask is a clone of
        /// `mask`. Errors and panics are those of the `_to` kernel.
        #[inline(always)]
        pub fn $fn_name(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool,
        ) -> Result<BooleanArray<()>, KernelError> {
            let mut bits = new_bool_buffer(lhs.len());
            $fn_name_to(lhs, rhs, mask, has_nulls, &mut bits)?;
            Ok(BooleanArray {
                data: bits.into(),
                null_mask: mask.cloned(),
                len: lhs.len(),
                _phantom: PhantomData,
            })
        }
    };
}
464
465#[cfg(feature = "extended_numeric_types")]
467impl_in_int!(in_i8, in_i8_to, i8, W8, i8);
468#[cfg(feature = "extended_numeric_types")]
469impl_in_int!(in_u8, in_u8_to, u8, W8, i8);
470#[cfg(feature = "extended_numeric_types")]
471impl_in_int!(in_i16, in_i16_to, i16, W16, i16);
472#[cfg(feature = "extended_numeric_types")]
473impl_in_int!(in_u16, in_u16_to, u16, W16, i16);
474impl_in_int!(in_i32, in_i32_to, i32, W32, i32);
475impl_in_int!(in_u32, in_u32_to, u32, W32, i32);
476impl_in_int!(in_i64, in_i64_to, i64, W64, i64);
477impl_in_int!(in_u64, in_u64_to, u64, W64, i64);
478impl_in_float!(in_f32, in_f32_to, f32, W32, i32);
479impl_in_float!(in_f64, in_f64_to, f64, W64, i64);
480
481#[cfg(feature = "extended_numeric_types")]
482impl_between_numeric!(between_i8, between_i8_to, i8, i8, W8);
483#[cfg(feature = "extended_numeric_types")]
484impl_between_numeric!(between_u8, between_u8_to, u8, i8, W8);
485#[cfg(feature = "extended_numeric_types")]
486impl_between_numeric!(between_i16, between_i16_to, i16, i16, W16);
487#[cfg(feature = "extended_numeric_types")]
488impl_between_numeric!(between_u16, between_u16_to, u16, i16, W16);
489
490impl_between_numeric!(between_i32, between_i32_to, i32, i32, W32);
491impl_between_numeric!(between_u32, between_u32_to, u32, i32, W32);
492impl_between_numeric!(between_i64, between_i64_to, i64, i64, W64);
493impl_between_numeric!(between_u64, between_u64_to, u64, i64, W64);
494impl_between_numeric!(between_f32, between_f32_to, f32, i32, W32);
495impl_between_numeric!(between_f64, between_f64_to, f64, i64, W64);
496
497#[inline(always)]
501pub fn cmp_str_between<'a, T: Integer>(
502 lhs: StringAVT<'a, T>,
503 rhs: StringAVT<'a, T>,
504) -> Result<BooleanArray<()>, KernelError> {
505 let (larr, loff, llen) = lhs;
506 let (rarr, roff, rlen) = rhs;
507
508 if rlen < 2 {
509 return Err(KernelError::InvalidArguments(format!(
510 "str_between: RHS must contain at least two values (got {})",
511 rlen
512 )));
513 }
514 let min = rarr.get(roff).unwrap_or("");
515 let max = rarr.get(roff + 1).unwrap_or("");
516 let mask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
517 let _ = confirm_mask_capacity(llen, mask.as_ref())?;
518
519 let mut out = new_bool_buffer(llen);
520
521 for i in 0..llen {
522 if mask
523 .as_ref()
524 .map_or(true, |m| unsafe { m.get_unchecked(i) })
525 {
526 let s = unsafe { larr.get_str_unchecked(loff + i) };
527 if s >= min && s <= max {
528 unsafe { out.set_unchecked(i, true) };
529 }
530 }
531 }
532
533 Ok(BooleanArray {
534 data: out.into(),
535 null_mask: mask,
536 len: llen,
537 _phantom: PhantomData,
538 })
539}
540
541#[inline(always)]
542pub fn cmp_str_in<'a, T: Integer>(
544 lhs: StringAVT<'a, T>,
545 rhs: StringAVT<'a, T>,
546) -> Result<BooleanArray<()>, KernelError> {
547 let (larr, loff, llen) = lhs;
548 let (rarr, roff, rlen) = rhs;
549
550 let set: HashSet<&str> = (0..rlen)
551 .map(|i| unsafe { rarr.get_str_unchecked(roff + i) })
552 .collect();
553
554 let mask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
555 let _ = confirm_mask_capacity(llen, mask.as_ref())?;
556
557 let mut out = new_bool_buffer(llen);
558
559 for i in 0..llen {
560 if mask
561 .as_ref()
562 .map_or(true, |m| unsafe { m.get_unchecked(i) })
563 {
564 let s = unsafe { larr.get_str_unchecked(loff + i) };
565 if set.contains(s) {
566 unsafe { out.set_unchecked(i, true) };
567 }
568 }
569 }
570 Ok(BooleanArray {
571 data: out.into(),
572 null_mask: mask,
573 len: llen,
574 _phantom: PhantomData,
575 })
576}
577
578pub fn cmp_between<T: PartialOrd + Copy + Numeric>(
582 lhs: &[T],
583 rhs: &[T],
584) -> Result<BooleanArray<()>, KernelError> {
585 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i32>() {
586 return between_i32(
587 unsafe { std::mem::transmute(lhs) },
588 unsafe { std::mem::transmute(rhs) },
589 None,
590 false,
591 );
592 }
593 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u32>() {
594 return between_u32(
595 unsafe { std::mem::transmute(lhs) },
596 unsafe { std::mem::transmute(rhs) },
597 None,
598 false,
599 );
600 }
601 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i64>() {
602 return between_i64(
603 unsafe { std::mem::transmute(lhs) },
604 unsafe { std::mem::transmute(rhs) },
605 None,
606 false,
607 );
608 }
609 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u64>() {
610 return between_u64(
611 unsafe { std::mem::transmute(lhs) },
612 unsafe { std::mem::transmute(rhs) },
613 None,
614 false,
615 );
616 }
617 between_generic(lhs, rhs, None, false)
619}
620
621#[inline(always)]
623pub fn cmp_between_mask<T: PartialOrd + Copy + Numeric>(
624 lhs: &[T],
625 rhs: &[T],
626 mask: Option<&Bitmask>,
627) -> Result<BooleanArray<()>, KernelError> {
628 let has_nulls = mask.is_some();
629 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i32>() {
630 return between_i32(
631 unsafe { std::mem::transmute(lhs) },
632 unsafe { std::mem::transmute(rhs) },
633 mask,
634 has_nulls,
635 );
636 }
637 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u32>() {
638 return between_u32(
639 unsafe { std::mem::transmute(lhs) },
640 unsafe { std::mem::transmute(rhs) },
641 mask,
642 has_nulls,
643 );
644 }
645 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i64>() {
646 return between_i64(
647 unsafe { std::mem::transmute(lhs) },
648 unsafe { std::mem::transmute(rhs) },
649 mask,
650 has_nulls,
651 );
652 }
653 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u64>() {
654 return between_u64(
655 unsafe { std::mem::transmute(lhs) },
656 unsafe { std::mem::transmute(rhs) },
657 mask,
658 has_nulls,
659 );
660 }
661 between_generic(lhs, rhs, mask, has_nulls)
662}
663
664pub fn cmp_in<T: Eq + Hash + Copy + 'static>(
668 lhs: &[T],
669 rhs: &[T],
670) -> Result<BooleanArray<()>, KernelError> {
671 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i32>() {
673 return in_i32(
674 unsafe { std::mem::transmute(lhs) },
675 unsafe { std::mem::transmute(rhs) },
676 None,
677 false,
678 );
679 }
680 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u32>() {
682 return in_u32(
683 unsafe { std::mem::transmute(lhs) },
684 unsafe { std::mem::transmute(rhs) },
685 None,
686 false,
687 );
688 }
689 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i64>() {
691 return in_i64(
692 unsafe { std::mem::transmute(lhs) },
693 unsafe { std::mem::transmute(rhs) },
694 None,
695 false,
696 );
697 }
698 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u64>() {
700 return in_u64(
701 unsafe { std::mem::transmute(lhs) },
702 unsafe { std::mem::transmute(rhs) },
703 None,
704 false,
705 );
706 }
707 #[cfg(feature = "extended_numeric_types")]
709 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i16>() {
710 return in_i16(
711 unsafe { std::mem::transmute(lhs) },
712 unsafe { std::mem::transmute(rhs) },
713 None,
714 false,
715 );
716 }
717 #[cfg(feature = "extended_numeric_types")]
719 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u16>() {
720 return in_u16(
721 unsafe { std::mem::transmute(lhs) },
722 unsafe { std::mem::transmute(rhs) },
723 None,
724 false,
725 );
726 }
727 #[cfg(feature = "extended_numeric_types")]
729 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i8>() {
730 return in_i8(
731 unsafe { std::mem::transmute(lhs) },
732 unsafe { std::mem::transmute(rhs) },
733 None,
734 false,
735 );
736 }
737 #[cfg(feature = "extended_numeric_types")]
739 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u8>() {
740 return in_u8(
741 unsafe { std::mem::transmute(lhs) },
742 unsafe { std::mem::transmute(rhs) },
743 None,
744 false,
745 );
746 }
747 return Err(KernelError::UnsupportedType(
748 "cmp_in: unsupported type for SIMD in".into(),
749 ));
750}
751
752#[inline(always)]
754pub fn cmp_in_mask<T: Eq + Hash + Copy + 'static>(
755 lhs: &[T],
756 rhs: &[T],
757 mask: Option<&Bitmask>,
758) -> Result<BooleanArray<()>, KernelError> {
759 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i32>() {
760 return in_i32(
761 unsafe { std::mem::transmute(lhs) },
762 unsafe { std::mem::transmute(rhs) },
763 mask,
764 mask.is_some(),
765 );
766 }
767 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u32>() {
768 return in_u32(
769 unsafe { std::mem::transmute(lhs) },
770 unsafe { std::mem::transmute(rhs) },
771 mask,
772 mask.is_some(),
773 );
774 }
775 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i64>() {
776 return in_i64(
777 unsafe { std::mem::transmute(lhs) },
778 unsafe { std::mem::transmute(rhs) },
779 mask,
780 mask.is_some(),
781 );
782 }
783 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u64>() {
784 return in_u64(
785 unsafe { std::mem::transmute(lhs) },
786 unsafe { std::mem::transmute(rhs) },
787 mask,
788 mask.is_some(),
789 );
790 }
791 return Err(KernelError::UnsupportedType(
792 "cmp_in_mask: unsupported type (expected integer type)".into(),
793 ));
794}
795
796#[inline(always)]
798pub fn cmp_in_f_mask<T: Float + Copy>(
799 lhs: &[T],
800 rhs: &[T],
801 mask: Option<&Bitmask>,
802) -> Result<BooleanArray<()>, KernelError> {
803 if TypeId::of::<T>() == TypeId::of::<f32>() {
804 let lhs = unsafe { &*(lhs as *const [T] as *const [f32]) };
805 let rhs = unsafe { &*(rhs as *const [T] as *const [f32]) };
806 in_f32(lhs, rhs, mask, mask.is_some())
807 } else if TypeId::of::<T>() == TypeId::of::<f64>() {
808 let lhs = unsafe { &*(lhs as *const [T] as *const [f64]) };
809 let rhs = unsafe { &*(rhs as *const [T] as *const [f64]) };
810 in_f64(lhs, rhs, mask, mask.is_some())
811 } else {
812 unreachable!("cmp_in_f_mask: Only f32/f64 supported for Float kernels")
813 }
814}
815
816#[inline(always)]
817pub fn cmp_in_f<T: Float + Copy>(lhs: &[T], rhs: &[T]) -> Result<BooleanArray<()>, KernelError> {
819 if TypeId::of::<T>() == TypeId::of::<f32>() {
820 let lhs = unsafe { &*(lhs as *const [T] as *const [f32]) };
821 let rhs = unsafe { &*(rhs as *const [T] as *const [f32]) };
822 in_f32(lhs, rhs, None, false)
823 } else if TypeId::of::<T>() == TypeId::of::<f64>() {
824 let lhs = unsafe { &*(lhs as *const [T] as *const [f64]) };
825 let rhs = unsafe { &*(rhs as *const [T] as *const [f64]) };
826 in_f64(lhs, rhs, None, false)
827 } else {
828 unreachable!("cmp_in_f: Only f32/f64 supported for Float kernels")
829 }
830}
831
832pub fn cmp_between_f<T: PartialOrd + Copy + Float + Numeric>(
836 lhs: &[T],
837 rhs: &[T],
838) -> Result<BooleanArray<()>, KernelError> {
839 between_generic(lhs, rhs, None, false)
840}
841
842pub fn cmp_dict_between<'a, T: Integer>(
844 lhs: CategoricalAVT<'a, T>,
845 rhs: CategoricalAVT<'a, T>,
846) -> Result<BooleanArray<()>, KernelError> {
847 let (larr, loff, llen) = lhs;
848 let (rarr, roff, _rlen) = rhs;
849
850 let min = rarr.get(roff).unwrap_or("");
851 let max = rarr.get(roff + 1).unwrap_or("");
852 let mask = larr.null_mask.as_ref();
853 let _ = confirm_mask_capacity(larr.data.len(), mask)?;
854 let has_nulls = mask.is_some();
855
856 let mut out = new_bool_buffer(llen);
857 for i in 0..llen {
858 let li = loff + i;
859 if !has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(li) }) {
860 let s = unsafe { larr.get_str_unchecked(li) };
861 if s >= min && s <= max {
862 unsafe { out.set_unchecked(i, true) };
863 }
864 }
865 }
866 Ok(BooleanArray {
867 data: out.into(),
868 null_mask: mask.map(|m| m.slice_clone(loff, llen)),
869 len: llen,
870 _phantom: PhantomData,
871 })
872}
873
874pub fn cmp_dict_in<'a, T: Integer + Hash>(
879 lhs: CategoricalAVT<'a, T>,
880 rhs: CategoricalAVT<'a, T>,
881) -> Result<BooleanArray<()>, KernelError> {
882 let (larr, loff, llen) = lhs;
883 let (rarr, roff, rlen) = rhs;
884 let mask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
885 let _ = confirm_mask_capacity(llen, mask.as_ref())?;
886
887 let mut out = Bitmask::new_set_all(llen, false);
888
889 if (larr.unique_values.len() == rarr.unique_values.len())
890 && (larr.unique_values.len() <= MAX_DICT_CHECK)
891 {
892 let mut same_dict = true;
893 for (a, b) in larr.unique_values.iter().zip(rarr.unique_values.iter()) {
894 if a != b {
895 same_dict = false;
896 break;
897 }
898 }
899
900 if same_dict {
901 let rhs_codes: HashSet<T> = rarr.data[roff..roff + rlen].iter().copied().collect();
902 for i in 0..llen {
903 if mask
904 .as_ref()
905 .map_or(true, |m| unsafe { m.get_unchecked(i) })
906 {
907 let code = larr.data[loff + i];
908 if rhs_codes.contains(&code) {
909 unsafe { out.set_unchecked(i, true) };
910 }
911 }
912 }
913 return Ok(BooleanArray {
914 data: out.into(),
915 null_mask: mask,
916 len: llen,
917 _phantom: PhantomData,
918 });
919 }
920 }
921
922 let rhs_strings: HashSet<&str> = (0..rlen)
923 .filter(|&i| {
924 rarr.null_mask
925 .as_ref()
926 .map_or(true, |m| unsafe { m.get_unchecked(roff + i) })
927 })
928 .map(|i| unsafe { rarr.get_str_unchecked(roff + i) })
929 .collect();
930
931 for i in 0..llen {
932 if mask
933 .as_ref()
934 .map_or(true, |m| unsafe { m.get_unchecked(i) })
935 {
936 let s = unsafe { larr.get_str_unchecked(loff + i) };
937 if rhs_strings.contains(s) {
938 unsafe { out.set_unchecked(i, true) };
939 }
940 }
941 }
942
943 Ok(BooleanArray {
944 data: out.into(),
945 null_mask: mask,
946 len: llen,
947 _phantom: PhantomData,
948 })
949}
950
951pub fn is_null_array(arr: &Array) -> Result<BooleanArray<()>, KernelError> {
955 let not_null = is_not_null_array(arr)?;
956 Ok(!not_null)
957}
958pub fn is_not_null_array(arr: &Array) -> Result<BooleanArray<()>, KernelError> {
960 let len = arr.len();
961 let mut data = Bitmask::new_set_all(len, false);
962
963 if let Some(mask) = arr.null_mask() {
964 data = mask.clone();
965 } else {
966 data.fill(true);
967 }
968 Ok(BooleanArray {
969 data,
970 null_mask: None,
971 len,
972 _phantom: PhantomData,
973 })
974}
975
976pub fn in_array(input: &Array, values: &Array) -> Result<BooleanArray<()>, KernelError> {
979 match (input, values) {
980 (
981 Array::NumericArray(NumericArray::Int32(a)),
982 Array::NumericArray(NumericArray::Int32(b)),
983 ) => cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref()),
984 (
985 Array::NumericArray(NumericArray::Int64(a)),
986 Array::NumericArray(NumericArray::Int64(b)),
987 ) => cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref()),
988 (
989 Array::NumericArray(NumericArray::UInt32(a)),
990 Array::NumericArray(NumericArray::UInt32(b)),
991 ) => cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref()),
992 (
993 Array::NumericArray(NumericArray::UInt64(a)),
994 Array::NumericArray(NumericArray::UInt64(b)),
995 ) => cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref()),
996 (
997 Array::NumericArray(NumericArray::Float32(a)),
998 Array::NumericArray(NumericArray::Float32(b)),
999 ) => cmp_in_f_mask(&a.data, &b.data, a.null_mask.as_ref()),
1000 (
1001 Array::NumericArray(NumericArray::Float64(a)),
1002 Array::NumericArray(NumericArray::Float64(b)),
1003 ) => cmp_in_f_mask(&a.data, &b.data, a.null_mask.as_ref()),
1004 (Array::TextArray(TextArray::String32(a)), Array::TextArray(TextArray::String32(b))) => {
1005 cmp_str_in((**a).tuple_ref(0, a.len()), (**b).tuple_ref(0, b.len()))
1006 }
1007 (Array::BooleanArray(a), Array::BooleanArray(b)) => {
1008 cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref())
1009 }
1010 (
1011 Array::TextArray(TextArray::Categorical32(a)),
1012 Array::TextArray(TextArray::Categorical32(b)),
1013 ) => cmp_dict_in((**a).tuple_ref(0, a.len()), (**b).tuple_ref(0, b.len())),
1014 _ => unimplemented!(),
1015 }
1016}
1017
1018#[inline(always)]
1019pub fn not_in_array(input: &Array, values: &Array) -> Result<BooleanArray<()>, KernelError> {
1021 let result = in_array(input, values)?;
1022 Ok(!result)
1023}
1024
1025pub fn between_array(input: &Array, min: &Array, max: &Array) -> Result<Array, KernelError> {
1027 macro_rules! between_case {
1028 ($variant:ident, $cmp:ident) => {{
1029 let arr = match input {
1030 Array::NumericArray(NumericArray::$variant(arr)) => arr,
1031 _ => unreachable!(),
1032 };
1033 let mins = match min {
1034 Array::NumericArray(NumericArray::$variant(arr)) => arr,
1035 _ => unreachable!(),
1036 };
1037 let maxs = match max {
1038 Array::NumericArray(NumericArray::$variant(arr)) => arr,
1039 _ => unreachable!(),
1040 };
1041 let rhs: Vec64<_> = mins
1042 .data
1043 .iter()
1044 .zip(&maxs.data)
1045 .flat_map(|(&lo, &hi)| [lo, hi])
1046 .collect();
1047 Ok(Array::BooleanArray(
1048 $cmp(
1049 &arr.data,
1050 &rhs,
1051 arr.null_mask.as_ref(),
1052 arr.null_mask.is_some(),
1053 )?
1054 .into(),
1055 ))
1056 }};
1057 }
1058
1059 match (input, min, max) {
1060 (
1061 Array::NumericArray(NumericArray::Int32(..)),
1062 Array::NumericArray(NumericArray::Int32(..)),
1063 Array::NumericArray(NumericArray::Int32(..)),
1064 ) => between_case!(Int32, between_i32),
1065 (
1066 Array::NumericArray(NumericArray::Int64(..)),
1067 Array::NumericArray(NumericArray::Int64(..)),
1068 Array::NumericArray(NumericArray::Int64(..)),
1069 ) => between_case!(Int64, between_i64),
1070 (
1071 Array::NumericArray(NumericArray::UInt32(..)),
1072 Array::NumericArray(NumericArray::UInt32(..)),
1073 Array::NumericArray(NumericArray::UInt32(..)),
1074 ) => between_case!(UInt32, between_u32),
1075 (
1076 Array::NumericArray(NumericArray::UInt64(..)),
1077 Array::NumericArray(NumericArray::UInt64(..)),
1078 Array::NumericArray(NumericArray::UInt64(..)),
1079 ) => between_case!(UInt64, between_u64),
1080 (
1081 Array::NumericArray(NumericArray::Float32(..)),
1082 Array::NumericArray(NumericArray::Float32(..)),
1083 Array::NumericArray(NumericArray::Float32(..)),
1084 ) => between_case!(Float32, between_generic),
1085 (
1086 Array::NumericArray(NumericArray::Float64(..)),
1087 Array::NumericArray(NumericArray::Float64(..)),
1088 Array::NumericArray(NumericArray::Float64(..)),
1089 ) => between_case!(Float64, between_generic),
1090 _ => Err(KernelError::UnsupportedType(
1091 "Unsupported Type Error.".to_string(),
1092 )),
1093 }
1094}
1095
1096#[inline]
1100pub fn not_bool<const LANES: usize>(
1101 src: BooleanAVT<'_, ()>,
1102) -> Result<BooleanArray<()>, KernelError>
1103where
1104{
1105 let (arr, offset, len) = src;
1106
1107 if offset % 64 != 0 {
1108 return Err(KernelError::InvalidArguments(format!(
1109 "not_bool: offset must be 64-bit aligned (got offset={})",
1110 offset
1111 )));
1112 }
1113
1114 let null_mask = arr.null_mask.as_ref().map(|nm| nm.slice_clone(offset, len));
1115
1116 let data = {
1117 #[cfg(feature = "simd")]
1118 {
1119 minarrow::kernels::bitmask::simd::not_mask_simd::<LANES>((&arr.data, offset, len))
1120 }
1121 #[cfg(not(feature = "simd"))]
1122 {
1123 minarrow::kernels::bitmask::std::not_mask((&arr.data, offset, len))
1124 }
1125 };
1126
1127 Ok(BooleanArray {
1128 data,
1129 null_mask,
1130 len,
1131 _phantom: core::marker::PhantomData,
1132 })
1133}
1134
1135pub fn apply_logical_bool<const LANES: usize>(
1139 lhs: BooleanAVT<'_, ()>,
1140 rhs: BooleanAVT<'_, ()>,
1141 op: LogicalOperator,
1142) -> Result<BooleanArray<()>, KernelError>
1143where
1144{
1145 let (lhs_arr, lhs_off, len) = lhs;
1146 let (rhs_arr, rhs_off, rlen) = rhs;
1147
1148 if len != rlen {
1149 return Err(KernelError::LengthMismatch(format!(
1150 "logical_bool: window length mismatch (lhs: {}, rhs: {})",
1151 len, rlen
1152 )));
1153 }
1154 if lhs_off % 64 != 0 || rhs_off % 64 != 0 {
1155 return Err(KernelError::InvalidArguments(format!(
1156 "logical_bool: offsets must be 64-bit aligned (lhs: {}, rhs: {})",
1157 lhs_off, rhs_off
1158 )));
1159 }
1160
1161 #[cfg(feature = "simd")]
1164 let data = match op {
1165 LogicalOperator::And => {
1166 and_masks_simd::<LANES>((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1167 }
1168 LogicalOperator::Or => {
1169 or_masks_simd::<LANES>((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1170 }
1171 LogicalOperator::Xor => {
1172 xor_masks_simd::<LANES>((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1173 }
1174 };
1175
1176 #[cfg(feature = "simd")]
1178 let null_mask = match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
1179 (None, None) => None,
1180 (Some(a), None) | (None, Some(a)) => Some(a.slice_clone(lhs_off, len)),
1181 (Some(a), Some(b)) => Some(and_masks_simd::<LANES>(
1182 (a, lhs_off, len),
1183 (b, rhs_off, len),
1184 )),
1185 };
1186
1187 #[cfg(not(feature = "simd"))]
1188 let data = match op {
1189 LogicalOperator::And => {
1190 and_masks((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1191 }
1192 LogicalOperator::Or => {
1193 or_masks((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1194 }
1195 LogicalOperator::Xor => {
1196 xor_masks((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1197 }
1198 };
1199
1200 #[cfg(not(feature = "simd"))]
1201 let null_mask = match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
1202 (None, None) => None,
1203 (Some(a), None) | (None, Some(a)) => Some(a.slice_clone(lhs_off, len)),
1204 (Some(a), Some(b)) => Some(and_masks((a, lhs_off, len), (b, rhs_off, len))),
1205 };
1206
1207 Ok(BooleanArray {
1208 data,
1209 null_mask,
1210 len,
1211 _phantom: PhantomData,
1212 })
1213}
1214
#[cfg(test)]
mod tests {
    //! Unit tests for the comparison / membership / null-check kernels in the
    //! parent module. Throughout these tests a `true` bit in a null mask marks
    //! a valid row and a `false` bit marks a null row, and sliced inputs are
    //! passed as `(array, offset, len)` tuples.

    use minarrow::structs::variants::categorical::CategoricalArray;
    use minarrow::structs::variants::float::FloatArray;
    use minarrow::structs::variants::integer::IntegerArray;
    use minarrow::structs::variants::string::StringArray;
    use minarrow::{Array, Bitmask, BooleanArray, vec64};

    use super::*;

    // ---------------------------------------------------------------------
    // Helpers
    // ---------------------------------------------------------------------

    /// Builds a `Bitmask` whose bit `i` equals `bits[i]`.
    fn bm(bits: &[bool]) -> Bitmask {
        let mut m = Bitmask::new_set_all(bits.len(), false);
        for (i, &b) in bits.iter().enumerate() {
            m.set(i, b);
        }
        m
    }

    /// Asserts that `arr` holds exactly the values in `expect` and, depending on
    /// `expect_mask`, the expected null mask.
    ///
    /// * `expect_mask == None` and no mask on `arr`: nothing further is checked.
    /// * `expect_mask == None` but `arr` carries a mask: every bit in
    ///   `0..arr.len` must be set.
    /// * `expect_mask == Some(..)`: `arr` must carry a mask that matches bit for
    ///   bit; a missing mask is a failure.
    fn assert_bool(arr: &BooleanArray<()>, expect: &[bool], expect_mask: Option<&[bool]>) {
        assert_eq!(arr.len, expect.len(), "length mismatch");
        for (i, &want) in expect.iter().enumerate() {
            assert_eq!(arr.data.get(i), want, "val @ {i}");
        }
        match (expect_mask, &arr.null_mask) {
            (None, None) => {}
            (Some(exp), Some(mask)) => {
                for (i, &b) in exp.iter().enumerate() {
                    assert_eq!(mask.get(i), b, "mask @ {i}");
                }
            }
            (None, Some(mask)) => {
                // A mask is tolerated when none was expected, as long as it
                // marks every row as valid.
                for i in 0..arr.len {
                    assert!(mask.get(i), "unexpected false mask @ {i}");
                }
            }
            (Some(_), None) => panic!("expected null mask"),
        }
    }

    /// Builds an `IntegerArray<i32>` from a slice.
    fn i32_arr(data: &[i32]) -> IntegerArray<i32> {
        IntegerArray::from_slice(data)
    }

    /// Builds a `FloatArray<f32>` from a slice.
    fn f32_arr(data: &[f32]) -> FloatArray<f32> {
        FloatArray::from_slice(data)
    }

    /// Builds a `StringArray<T>` from string slices.
    fn str_arr<T: Integer>(vals: &[&str]) -> StringArray<T> {
        StringArray::<T>::from_slice(vals)
    }

    /// Builds a `CategoricalArray<T>` from string slices.
    fn dict_arr<T: Integer>(vals: &[&str]) -> CategoricalArray<T> {
        let owned: Vec<&str> = vals.to_vec();
        CategoricalArray::<T>::from_values(owned)
    }

    // ---------------------------------------------------------------------
    // Numeric BETWEEN
    // ---------------------------------------------------------------------

    /// A two-element RHS is applied as one range to every LHS row.
    #[test]
    fn between_i32_scalar_rhs() {
        let lhs = vec64![1, 3, 5, 7];
        let rhs = vec64![2, 6];
        let out = between_i32(&lhs, &rhs, None, false).unwrap();
        assert_bool(&out, &[false, true, true, false], None);
    }

    /// An RHS of length `2 * lhs.len()` supplies one range per row; the expected
    /// values are consistent with interleaved pairs: (0,10), (0,4), (2,2), (8,9).
    #[test]
    fn between_i32_per_row_rhs() {
        let lhs = vec64![5, 9, 2, 8];
        let rhs = vec64![0, 10, 0, 4, 2, 2, 8, 9];
        let out = between_i32(&lhs, &rhs, None, false).unwrap();
        assert_bool(&out, &[true, false, true, true], None);
    }

    /// The input mask is carried through to the output's null mask.
    #[test]
    fn between_i32_nulls_propagate() {
        let lhs = vec64![5, 9, 2, 8];
        let rhs = vec64![0, 10, 0, 4, 2, 2, 8, 9];
        let mask = bm(&[true, false, true, true]);
        let out = between_i32(&lhs, &rhs, Some(&mask), true).unwrap();
        assert_bool(
            &out,
            &[true, false, true, true],
            Some(&[true, false, true, true]),
        );
    }

    // NOTE(review): despite its name this test calls `in_i16`, not a
    // `between_*` kernel. The expected output is the same one an inclusive
    // [10, 12] range would produce — confirm whether `between_i16` was intended.
    #[cfg(feature = "extended_numeric_types")]
    #[test]
    fn between_i16_works() {
        let lhs = vec64![10i16, 12, 99];
        let rhs = vec64![10i16, 12];
        let out = in_i16(&lhs, &rhs, None, false).unwrap();
        assert_bool(&out, &[true, true, false], None);
    }

    /// Scalar range on `f64` with a null row: the null row yields `false` even
    /// though its value (5.0) lies inside the range.
    #[test]
    fn between_f64_scalar_and_nulls() {
        let lhs = vec64![1.0, 5.0, 8.0, 20.0];
        let rhs = vec64![4.0, 10.0];
        let mask = bm(&[true, false, true, true]);
        let out = between_f64(&lhs, &rhs, Some(&mask), true).unwrap();
        assert_bool(
            &out,
            &[false, false, true, false],
            Some(&[true, false, true, true]),
        );
    }

    /// The generic `cmp_between` entry point handles `f32` input.
    #[test]
    fn between_f32_generic_dispatch() {
        let lhs = vec64![0.1f32, 0.5, 1.2, -1.0];
        let rhs = vec64![0.0, 1.0];
        let out = cmp_between(&lhs, &rhs).unwrap();
        assert_bool(&out, &[true, true, false, false], None);
    }

    /// The generic masked entry point: the null row (value 2) yields `false`
    /// and the mask is propagated.
    #[test]
    fn between_masked_dispatch() {
        let lhs = vec64![1i32, 2, 3];
        let rhs = vec64![0, 2];
        let mask = bm(&[true, false, true]);
        let out = cmp_between_mask(&lhs, &rhs, Some(&mask)).unwrap();
        assert_bool(&out, &[true, false, false], Some(&[true, false, true]));
    }

    // ---------------------------------------------------------------------
    // Numeric IN
    // ---------------------------------------------------------------------

    /// Membership against a two-element RHS.
    #[test]
    fn in_i32_small_rhs() {
        let lhs = vec64![1, 2, 3, 4, 5];
        let rhs = vec64![2, 4];
        let out = in_i32(&lhs, &rhs, None, false).unwrap();
        assert_bool(&out, &[false, true, false, true, false], None);
    }

    /// A null LHS row yields `false` even when its value (8) is in the RHS.
    #[test]
    fn in_i32_with_nulls() {
        let lhs = vec64![7, 8, 9];
        let rhs = vec64![8];
        let mask = bm(&[true, false, true]);
        let out = in_i32(&lhs, &rhs, Some(&mask), true).unwrap();
        assert_bool(&out, &[false, false, false], Some(&[true, false, true]));
    }

    /// Membership against an eight-element RHS (`2..10`).
    #[test]
    fn in_i64_large_rhs() {
        let lhs = vec64![1i64, 2, 3, 7, 8, 15];
        let rhs: Vec<i64> = (2..10).collect();
        let out = in_i64(&lhs, &rhs, None, false).unwrap();
        assert_bool(&out, &[false, true, true, true, true, false], None);
    }

    /// `u8` membership; only compiled with `extended_numeric_types`.
    #[cfg(feature = "extended_numeric_types")]
    #[test]
    fn in_u8_small_rhs() {
        let lhs = vec64![1u8, 2, 3, 4];
        let rhs = vec64![2u8, 3];
        let out = in_u8(&lhs, &rhs, None, false).unwrap();
        assert_bool(&out, &[false, true, true, false], None);
    }

    /// A NaN on the LHS is expected to match a NaN in the RHS set.
    #[test]
    fn in_float_nan_and_normal() {
        let lhs = vec64![1.0f32, f32::NAN, 7.0];
        let rhs = vec64![f32::NAN, 7.0];
        let out = in_f32(&lhs, &rhs, None, false).unwrap();
        assert_bool(&out, &[false, true, true], None);
    }

    // ---------------------------------------------------------------------
    // String / categorical BETWEEN and IN
    // ---------------------------------------------------------------------

    /// String BETWEEN over full arrays with bounds "b" and "y".
    #[test]
    fn string_between() {
        let lhs = str_arr::<u32>(&["aa", "bb", "zz"]);
        let rhs = str_arr::<u32>(&["b", "y"]);
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let out = cmp_str_between(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    /// Same as `string_between`, but both operands are windows into larger
    /// arrays: LHS rows 1..4 and RHS rows 1..3.
    #[test]
    fn string_between_chunk() {
        let lhs = str_arr::<u32>(&["0", "aa", "bb", "zz", "9"]);
        let rhs = str_arr::<u32>(&["a", "b", "y", "z"]);
        let lhs_slice = (&lhs, 1, 3);
        let rhs_slice = (&rhs, 1, 2);
        let out = cmp_str_between(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    /// String IN over full arrays.
    #[test]
    fn string_in_basic() {
        let lhs = str_arr::<u32>(&["x", "y", "z"]);
        let rhs = str_arr::<u32>(&["y", "a"]);
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let out = cmp_str_in(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    /// String IN where both operands are windows into larger arrays.
    #[test]
    fn string_in_basic_chunk() {
        let lhs = str_arr::<u32>(&["0", "x", "y", "z", "9"]);
        let rhs = str_arr::<u32>(&["b", "y", "a", "c"]);
        let lhs_slice = (&lhs, 1, 3);
        let rhs_slice = (&rhs, 1, 2);
        let out = cmp_str_in(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    /// Categorical BETWEEN over full arrays with bounds "cobra" and "dove".
    #[test]
    fn dict_between() {
        let lhs = dict_arr::<u32>(&["cat", "dog", "emu"]);
        let rhs = dict_arr::<u32>(&["cobra", "dove"]);
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let out = cmp_dict_between(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    /// Categorical BETWEEN where both operands are windows into larger arrays.
    #[test]
    fn dict_between_chunk() {
        let lhs = dict_arr::<u32>(&["a", "cat", "dog", "emu", "z"]);
        let rhs = dict_arr::<u32>(&["a", "cobra", "dove", "zz"]);
        let lhs_slice = (&lhs, 1, 3);
        let rhs_slice = (&rhs, 1, 2);
        let out = cmp_dict_between(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    /// Categorical IN over full arrays.
    #[test]
    fn dict_in_membership() {
        let lhs = dict_arr::<u32>(&["aa", "bb", "cc"]);
        let rhs = dict_arr::<u32>(&["bb", "dd"]);
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    /// Categorical IN where both operands are windows into larger arrays.
    #[test]
    fn dict_in_membership_chunk() {
        let lhs = dict_arr::<u32>(&["0", "aa", "bb", "cc", "9"]);
        let rhs = dict_arr::<u32>(&["a", "bb", "dd", "zz"]);
        let lhs_slice = (&lhs, 1, 3);
        let rhs_slice = (&rhs, 1, 2);
        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    /// String BETWEEN with a null LHS row: the null row yields `false` and the
    /// LHS mask appears on the output.
    #[test]
    fn string_between_nulls() {
        let mut lhs = str_arr::<u32>(&["foo", "bar", "baz"]);
        lhs.null_mask = Some(bm(&[true, false, true]));
        let rhs = str_arr::<u32>(&["a", "zzz"]);
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let out = cmp_str_between(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[true, false, true], Some(&[true, false, true]));
    }

    /// Windowed variant of `string_between_nulls`; the expected output mask is
    /// the LHS mask restricted to rows 1..4.
    #[test]
    fn string_between_nulls_chunk() {
        let mut lhs = str_arr::<u32>(&["0", "foo", "bar", "baz", "z"]);
        lhs.null_mask = Some(bm(&[true, true, false, true, true]));
        let rhs = str_arr::<u32>(&["0", "a", "zzz", "9"]);
        let lhs_slice = (&lhs, 1, 3);
        let rhs_slice = (&rhs, 1, 2);
        let out = cmp_str_between(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[true, false, true], Some(&[true, false, true]));
    }

    /// Categorical IN with a null LHS row; the LHS mask appears on the output.
    #[test]
    fn dict_in_nulls() {
        let mut lhs = dict_arr::<u32>(&["one", "two", "three"]);
        lhs.null_mask = Some(bm(&[false, true, true]));
        let rhs = dict_arr::<u32>(&["two", "four"]);
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], Some(&[false, true, true]));
    }

    /// Windowed variant of `dict_in_nulls`; the expected output mask is the LHS
    /// mask restricted to rows 1..4.
    #[test]
    fn dict_in_nulls_chunk() {
        let mut lhs = dict_arr::<u32>(&["x", "one", "two", "three", "z"]);
        lhs.null_mask = Some(bm(&[true, false, true, true, true]));
        let rhs = dict_arr::<u32>(&["a", "two", "four", "b"]);
        let lhs_slice = (&lhs, 1, 3);
        let rhs_slice = (&rhs, 1, 2);
        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], Some(&[false, true, true]));
    }

    // ---------------------------------------------------------------------
    // IS NULL / IS NOT NULL
    // ---------------------------------------------------------------------

    /// With a null mask present, `is_null_array` flags exactly the rows whose
    /// mask bit is `false`, and `is_not_null_array` flags the complement.
    #[test]
    fn is_null_and_is_not_null() {
        let mut arr = i32_arr(&[1, 2, 0]);
        arr.null_mask = Some(bm(&[true, false, true]));
        let array = Array::from_int32(arr);

        let not_null = is_not_null_array(&array).unwrap();
        let is_null = is_null_array(&array).unwrap();

        assert_bool(&not_null, &[true, false, true], None);
        assert_bool(&is_null, &[false, true, false], None);
    }

    /// Without a null mask every row is reported as non-null.
    #[test]
    fn is_null_not_null_dense() {
        let arr = i32_arr(&[1, 2, 3]);
        let array = Array::from_int32(arr);
        let is_null = is_null_array(&array).unwrap();
        assert_bool(&is_null, &[false, false, false], None);
        let not_null = is_not_null_array(&array).unwrap();
        assert_bool(&not_null, &[true, true, true], None);
    }

    // ---------------------------------------------------------------------
    // Array-level dispatch
    // ---------------------------------------------------------------------

    /// `in_array` / `not_in_array` on `Int32` inputs give complementary results.
    #[test]
    fn in_array_int32_dispatch() {
        let inp = Array::from_int32(i32_arr(&[10, 20, 30]));
        let vals = Array::from_int32(i32_arr(&[20, 40]));
        let out = in_array(&inp, &vals).unwrap();
        assert_bool(&out, &[false, true, false], None);

        let out_not = not_in_array(&inp, &vals).unwrap();
        assert_bool(&out_not, &[true, false, true], None);
    }

    /// `in_array` on `Float32` inputs, including a NaN that is expected to match.
    #[test]
    fn in_array_f32_dispatch() {
        let inp = Array::from_float32(f32_arr(&[1.0, f32::NAN, 7.0]));
        let vals = Array::from_float32(f32_arr(&[f32::NAN, 7.0]));
        let out = in_array(&inp, &vals).unwrap();
        assert_bool(&out, &[false, true, true], None);
    }

    /// `in_array` on `String32` inputs.
    #[test]
    fn in_array_string_dispatch() {
        let inp = Array::from_string32(str_arr::<u32>(&["a", "b", "c"]));
        let vals = Array::from_string32(str_arr::<u32>(&["b", "d"]));
        let out = in_array(&inp, &vals).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    /// `in_array` on `Categorical32` inputs.
    #[test]
    fn in_array_dictionary_dispatch() {
        let inp = Array::from_categorical32(dict_arr::<u32>(&["aa", "bb", "cc"]));
        let vals = Array::from_categorical32(dict_arr::<u32>(&["bb", "cc"]));
        let out = in_array(&inp, &vals).unwrap();
        assert_bool(&out, &[false, true, true], None);
    }

    /// `between_array` with per-row `min` / `max` arrays of `Int32`.
    #[test]
    fn between_array_int32_rows() {
        let inp = Array::from_int32(i32_arr(&[5, 15, 25]));
        let min = Array::from_int32(i32_arr(&[0, 10, 20]));
        let max = Array::from_int32(i32_arr(&[10, 20, 30]));

        let out = between_array(&inp, &min, &max).unwrap();
        match out {
            Array::BooleanArray(b) => assert_bool(&b, &[true, true, true], None),
            _ => panic!("expected Bool array"),
        }
    }

    /// `between_array` with per-row `min` / `max` arrays of `Float32`.
    #[test]
    fn between_array_float_generic() {
        let inp = Array::from_float32(f32_arr(&[0.5, 1.5, 2.5]));
        let min = Array::from_float32(f32_arr(&[0.0, 1.0, 2.0]));
        let max = Array::from_float32(f32_arr(&[1.0, 2.0, 3.0]));

        let out = between_array(&inp, &min, &max).unwrap();
        match out {
            Array::BooleanArray(b) => assert_bool(&b, &[true, true, true], None),
            _ => panic!("expected Bool"),
        }
    }

    /// An `Int32` input with `Float32` bounds must be rejected with
    /// `KernelError::UnsupportedType`.
    #[test]
    fn between_array_type_mismatch() {
        let inp = Array::from_int32(i32_arr(&[1, 2, 3]));
        let min = Array::from_float32(f32_arr(&[0.0, 0.0, 0.0]));
        let max = Array::from_float32(f32_arr(&[5.0, 5.0, 5.0]));
        let err = between_array(&inp, &min, &max).unwrap_err();
        match err {
            KernelError::UnsupportedType(_) => {}
            _ => panic!("Expected UnsupportedType error"),
        }
    }

    // ---------------------------------------------------------------------
    // Integer-width coverage and edge cases
    // ---------------------------------------------------------------------

    /// Exercises the typed `in_*` kernels across integer widths; the 8/16-bit
    /// cases are only compiled with `extended_numeric_types`.
    #[test]
    fn in_integers_various_types() {
        #[cfg(feature = "extended_numeric_types")]
        {
            let u8_lhs = vec64![1u8, 2, 3, 5];
            let u8_rhs = vec64![3u8, 5, 8];
            let out = in_u8(&u8_lhs, &u8_rhs, None, false).unwrap();
            assert_bool(&out, &[false, false, true, true], None);

            let u16_lhs = vec64![100u16, 200, 300];
            let u16_rhs = vec64![200u16, 500];
            let out = in_u16(&u16_lhs, &u16_rhs, None, false).unwrap();
            assert_bool(&out, &[false, true, false], None);

            let i16_lhs = vec64![10i16, 15, 42];
            let i16_rhs = vec64![15i16, 42, 77];
            let out = in_i16(&i16_lhs, &i16_rhs, None, false).unwrap();
            assert_bool(&out, &[false, true, true], None);
        }

        let u32_lhs = vec64![0u32, 1, 2, 9];
        let u32_rhs = vec64![9u32, 1];
        let out = in_u32(&u32_lhs, &u32_rhs, None, false).unwrap();
        assert_bool(&out, &[false, true, false, true], None);

        let i64_lhs = vec64![1i64, 9, 10];
        let i64_rhs = vec64![2i64, 10, 20];
        let out = in_i64(&i64_lhs, &i64_rhs, None, false).unwrap();
        assert_bool(&out, &[false, false, true], None);

        let u64_lhs = vec64![1u64, 2, 3, 4];
        let u64_rhs = vec64![2u64, 4, 8];
        let out = in_u64(&u64_lhs, &u64_rhs, None, false).unwrap();
        assert_bool(&out, &[false, true, false, true], None);
    }

    /// An empty LHS produces an empty result for numeric BETWEEN, numeric IN
    /// and string IN.
    #[test]
    fn between_and_in_empty_inputs() {
        // Numeric BETWEEN.
        let lhs: [i32; 0] = [];
        let rhs = vec64![0, 1];
        let out = between_i32(&lhs, &rhs, None, false).unwrap();
        assert_eq!(out.len, 0);

        // Numeric IN.
        let lhs: [i32; 0] = [];
        let rhs = vec64![1, 2, 3];
        let out = in_i32(&lhs, &rhs, None, false).unwrap();
        assert_eq!(out.len, 0);

        // String IN.
        let lhs = str_arr::<u32>(&[]);
        let rhs = str_arr::<u32>(&["a", "b"]);
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let out = cmp_str_in(lhs_slice, rhs_slice).unwrap();
        assert_eq!(out.len, 0);
    }

    /// A zero-length LHS window into a non-empty array also produces an empty
    /// result.
    #[test]
    fn between_and_in_empty_inputs_chunk() {
        let lhs = str_arr::<u32>(&["x", "y"]);
        let rhs = str_arr::<u32>(&["a", "b", "c"]);
        let lhs_slice = (&lhs, 1, 0);
        let rhs_slice = (&rhs, 1, 2);
        let out = cmp_str_in(lhs_slice, rhs_slice).unwrap();
        assert_eq!(out.len, 0);
    }

    /// Per-row bounds where every row, including the last, falls inside its own
    /// range: (0,5), (5,15), (15,25), (25,35).
    #[test]
    fn between_per_row_bounds_on_last_row() {
        let lhs = vec64![0i32, 10, 20, 30];
        let rhs = vec64![0, 5, 5, 15, 15, 25, 25, 35];
        let out = between_i32(&lhs, &rhs, None, false).unwrap();
        assert_bool(&out, &[true, true, true, true], None);
    }

    /// Categorical IN with `unique_values` overwritten by hand on both sides and
    /// an all-valid LHS mask.
    // NOTE(review): the name suggests this is meant to steer `cmp_dict_in` down
    // a fallback path; which path is taken is not visible from this test.
    #[test]
    fn test_cmp_dict_in_force_fallback() {
        let mut lhs = dict_arr::<u32>(&["a", "b", "c", "a"]);
        lhs.unique_values = vec64!["a".to_string(), "b".to_string(), "c".to_string()];
        let mut rhs = dict_arr::<u32>(&["b", "x", "y", "z"]);
        rhs.unique_values = vec64![
            "b".to_string(),
            "x".to_string(),
            "y".to_string(),
            "z".to_string()
        ];
        lhs.null_mask = Some(bm(&[true, true, true, true]));
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
        assert_bool(
            &out,
            &[false, true, false, false],
            Some(&[true, true, true, true]),
        );
    }

    /// Windowed variant of `test_cmp_dict_in_force_fallback`: LHS rows 1..5 are
    /// checked against RHS rows 1..5.
    #[test]
    fn test_cmp_dict_in_force_fallback_chunk() {
        let mut lhs = dict_arr::<u32>(&["z", "a", "b", "c", "a", "q"]);
        lhs.unique_values = vec64![
            "z".to_string(),
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
            "q".to_string()
        ];
        let mut rhs = dict_arr::<u32>(&["x", "b", "x", "y", "z"]);
        rhs.unique_values = vec64![
            "x".to_string(),
            "b".to_string(),
            "y".to_string(),
            "z".to_string()
        ];
        lhs.null_mask = Some(bm(&[true, true, true, true, true, true]));
        let lhs_slice = (&lhs, 1, 4);
        let rhs_slice = (&rhs, 1, 4);
        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
        assert_bool(
            &out,
            &[false, true, false, false],
            Some(&[true, true, true, true]),
        );
    }

    /// `in_array` against an empty value set reports no row as a member.
    #[test]
    fn test_in_array_empty_rhs() {
        let arr = Array::from_int32(i32_arr(&[1, 2, 3]));
        let empty = Array::from_int32(i32_arr(&[]));
        let out = in_array(&arr, &empty).unwrap();
        assert_bool(&out, &[false, false, false], None);
    }
}