1include!(concat!(env!("OUT_DIR"), "/simd_lanes.rs"));
19
20use std::collections::HashSet;
21use std::hash::Hash;
22use std::marker::PhantomData;
23use std::simd::{LaneCount, SupportedLaneCount};
24#[cfg(feature = "simd")]
25use std::simd::{Mask, Simd, cmp::SimdPartialEq, cmp::SimdPartialOrd, num::SimdFloat};
26
27use minarrow::kernels::arithmetic::string::MAX_DICT_CHECK;
28use minarrow::traits::type_unions::Float;
29use minarrow::{
30 Array, Bitmask, BooleanAVT, BooleanArray, CategoricalAVT, Integer, MaskedArray, Numeric,
31 NumericArray, StringAVT, TextArray, Vec64,
32};
33
34#[cfg(not(feature = "simd"))]
35use crate::kernels::bitmask::dispatch::{and_masks, or_masks, xor_masks};
36use crate::operators::LogicalOperator;
37use minarrow::enums::error::KernelError;
38#[cfg(feature = "simd")]
39use minarrow::kernels::bitmask::simd::{and_masks_simd, or_masks_simd, xor_masks_simd};
40use minarrow::utils::confirm_mask_capacity;
41
42#[cfg(feature = "simd")]
43use minarrow::utils::is_simd_aligned;
44use std::any::TypeId;
45
46#[inline(always)]
49fn new_bool_buffer(len: usize) -> Bitmask {
50 Bitmask::new_set_all(len, false)
51}
52
/// Generates a `BETWEEN` kernel for a concrete numeric element type.
///
/// Macro arguments:
/// - `$name`: name of the generated public function.
/// - `$ty`: element type of `lhs` / `rhs`.
/// - `$mask_elem`: element type of the `Mask` used in the SIMD path.
/// - `$lanes`: SIMD lane count (used as `const N` for `Simd<$ty, N>`).
///
/// The generated function tests `min <= lhs[i] <= max` (inclusive on both
/// ends). `rhs` is either a single `[min, max]` pair applied to every row
/// (`rhs.len() == 2`) or one interleaved pair per row
/// (`rhs.len() == 2 * lhs.len()`, laid out as `[min0, max0, min1, max1, ...]`).
///
/// Errors with `KernelError::InvalidArguments` for any other `rhs` length or
/// when `mask` has capacity smaller than `lhs.len()`. When `has_nulls` is set,
/// rows whose `mask` bit is cleared produce `false`; the input `mask` is cloned
/// unchanged into the result's `null_mask`.
macro_rules! impl_between_numeric {
    ($name:ident, $ty:ty, $mask_elem:ty, $lanes:expr) => {
        #[inline(always)]
        pub fn $name(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>, has_nulls: bool
        ) -> Result<BooleanArray<()>, KernelError> {

            let len = lhs.len();
            // Only the scalar-pair and per-row-pair layouts are accepted.
            if rhs.len() != 2 && rhs.len() != 2 * len {
                return Err(KernelError::InvalidArguments(
                    format!("between: RHS must have len 2 or 2×LHS (got lhs: {}, rhs: {})", len, rhs.len())
                ));
            }

            // The unchecked mask reads below rely on this capacity check.
            if let Some(m) = mask {
                if m.capacity() < len {
                    return Err(KernelError::InvalidArguments(
                        format!("between: mask (Bitmask) capacity must be ≥ len (got capacity: {}, len: {})", m.capacity(), len)
                    ));
                }
            }
            let mut out_data = new_bool_buffer(len);

            #[cfg(feature = "simd")]
            {
                if is_simd_aligned(lhs) && is_simd_aligned(rhs) {
                    const N: usize = $lanes;
                    type V = Simd<$ty, N>;
                    type M = Mask<$mask_elem, N>;

                    // SIMD fast path: only for null-free input with a single
                    // [min, max] pair that can be broadcast to every lane.
                    if !has_nulls && rhs.len() == 2 {
                        let min_v = V::splat(rhs[0]);
                        let max_v = V::splat(rhs[1]);

                        let mut i = 0usize;
                        while i + N <= len {
                            let x = V::from_slice(&lhs[i..i + N]);
                            let m: M = x.simd_ge(min_v) & x.simd_le(max_v);
                            let bm = m.to_bitmask();

                            // Copy each set lane bit into the output bitmask.
                            for l in 0..N {
                                if ((bm >> l) & 1) == 1 {
                                    out_data.set(i + l, true);
                                }
                            }
                            i += N;
                        }
                        // Scalar tail for the rows that do not fill a full vector.
                        for j in i..len {
                            if lhs[j] >= rhs[0] && lhs[j] <= rhs[1] {
                                out_data.set(j, true);
                            }
                        }

                        return Ok(BooleanArray {
                            data: out_data.into(),
                            null_mask: mask.cloned(),
                            len,
                            _phantom: PhantomData
                        });
                    }
                }
            }

            // Scalar path: one broadcast pair for all rows.
            if rhs.len() == 2 {
                let (min, max) = (rhs[0], rhs[1]);
                for i in 0..len {
                    if (!has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(i) }))
                        && lhs[i] >= min
                        && lhs[i] <= max
                    {
                        out_data.set(i, true);
                    }
                }
            } else {
                // Scalar path: interleaved (min, max) pair per row.
                for i in 0..len {
                    let min = rhs[i * 2];
                    let max = rhs[i * 2 + 1];
                    if (!has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(i) }))
                        && lhs[i] >= min
                        && lhs[i] <= max
                    {
                        out_data.set(i, true);
                    }
                }
            }

            Ok(BooleanArray {
                data: out_data.into(),
                null_mask: mask.cloned(),
                len,
                _phantom: PhantomData
            })
        }
    };
}
160
161#[inline(always)]
164fn between_generic<T: Numeric + Copy + std::cmp::PartialOrd>(
165 lhs: &[T],
166 rhs: &[T],
167 mask: Option<&Bitmask>,
168 has_nulls: bool,
169) -> Result<BooleanArray<()>, KernelError> {
170 let len = lhs.len();
171 let mut out = new_bool_buffer(len);
172 let _ = confirm_mask_capacity(len, mask)?;
173 if rhs.len() == 2 {
174 let (min, max) = (rhs[0], rhs[1]);
175 for i in 0..len {
176 if (!has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(i) }))
177 && lhs[i] >= min
178 && lhs[i] <= max
179 {
180 out.set(i, true);
181 }
182 }
183 } else {
184 for i in 0..len {
185 let min = rhs[i * 2];
186 let max = rhs[i * 2 + 1];
187 if (!has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(i) }))
188 && lhs[i] >= min
189 && lhs[i] <= max
190 {
191 out.set(i, true);
192 }
193 }
194 }
195
196 Ok(BooleanArray {
197 data: out.into(),
198 null_mask: mask.cloned(),
199 len,
200 _phantom: PhantomData,
201 })
202}
203
/// Generates an `IN` membership kernel for a concrete integer element type.
///
/// Macro arguments: `$name` (generated function), `$ty` (element type),
/// `$lanes` (SIMD lane count), `$mask_elem` (element type of the matching
/// SIMD `Mask`).
///
/// The generated function sets bit `i` of the result when `lhs[i]` equals any
/// element of `rhs` and the row is valid. A row is treated as null (result
/// `false`) only when `has_nulls` is set and `mask` is `Some` with bit `i`
/// cleared; `has_nulls` with a `None` mask means "all rows valid" on every
/// code path. The input `mask` is cloned into the result's `null_mask`.
///
/// With the `simd` feature, aligned inputs with `rhs.len() <= 16` use a
/// broadcast-compare SIMD loop. Otherwise `rhs` is hashed once and probed per
/// row.
///
/// Errors are whatever `confirm_mask_capacity` reports for an undersized mask.
macro_rules! impl_in_int {
    ($name:ident, $ty:ty, $lanes:expr, $mask_elem:ty) => {
        #[inline(always)]
        pub fn $name(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool,
        ) -> Result<BooleanArray<()>, KernelError> {
            let len = lhs.len();
            let mut out = new_bool_buffer(len);
            let _ = confirm_mask_capacity(len, mask)?;

            #[cfg(feature = "simd")]
            {
                if is_simd_aligned(lhs) && is_simd_aligned(rhs) {
                    use crate::utils::bitmask_to_simd_mask;
                    use core::simd::{Mask, Simd};

                    // Short value lists: broadcast each candidate and OR the
                    // lane-wise equality masks together.
                    if rhs.len() <= 16 {
                        let mut i = 0;
                        // Consult the validity mask only when nulls are flagged
                        // AND a mask is present, exactly like the scalar path.
                        let validity = if has_nulls { mask } else { None };
                        match validity {
                            None => {
                                while i + $lanes <= len {
                                    let x = Simd::<$ty, $lanes>::from_slice(&lhs[i..i + $lanes]);
                                    let mut m = Mask::<$mask_elem, $lanes>::splat(false);
                                    for &v in rhs {
                                        m |= x.simd_eq(Simd::<$ty, $lanes>::splat(v));
                                    }
                                    let bm = m.to_bitmask();
                                    for l in 0..$lanes {
                                        if ((bm >> l) & 1) == 1 {
                                            out.set(i + l, true);
                                        }
                                    }
                                    i += $lanes;
                                }
                                // Scalar tail for rows that do not fill a vector.
                                for j in i..len {
                                    if rhs.contains(&lhs[j]) {
                                        out.set(j, true);
                                    }
                                }
                            }
                            Some(mb) => {
                                let mask_bytes = mb.as_bytes();
                                while i + $lanes <= len {
                                    let x = Simd::<$ty, $lanes>::from_slice(&lhs[i..i + $lanes]);
                                    // Validity bits for rows i..i + $lanes as a lane mask.
                                    let lane_mask =
                                        bitmask_to_simd_mask::<$lanes, $mask_elem>(mask_bytes, i, len);
                                    let mut in_mask = Mask::<$mask_elem, $lanes>::splat(false);
                                    for &v in rhs {
                                        in_mask |= x.simd_eq(Simd::<$ty, $lanes>::splat(v));
                                    }
                                    let bm = (lane_mask & in_mask).to_bitmask();
                                    for l in 0..$lanes {
                                        if ((bm >> l) & 1) == 1 {
                                            out.set(i + l, true);
                                        }
                                    }
                                    i += $lanes;
                                }
                                for j in i..len {
                                    // SAFETY: j < len <= mask capacity (confirm_mask_capacity above).
                                    if unsafe { mb.get_unchecked(j) } && rhs.contains(&lhs[j]) {
                                        out.set(j, true);
                                    }
                                }
                            }
                        }
                        return Ok(BooleanArray {
                            data: out.into(),
                            null_mask: mask.cloned(),
                            len,
                            _phantom: PhantomData,
                        });
                    }
                }
            }

            // Portable path (also the SIMD fallback for long value lists or
            // unaligned inputs): hash the candidate values once.
            let set: std::collections::HashSet<$ty> = rhs.iter().copied().collect();
            for i in 0..len {
                // SAFETY: i < len <= mask capacity (confirm_mask_capacity above).
                if (!has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(i) }))
                    && set.contains(&lhs[i])
                {
                    out.set(i, true);
                }
            }
            Ok(BooleanArray {
                data: out.into(),
                null_mask: mask.cloned(),
                len,
                _phantom: PhantomData,
            })
        }
    };
}
314
/// Generates an `IN` membership kernel for a concrete float element type.
///
/// Macro arguments: `$fn_name` (generated function), `$ty` (float type),
/// `$lanes` (SIMD lane count), `$mask_elem` (element type of the matching
/// SIMD `Mask`).
///
/// Bit `i` of the result is set when the row is valid and `lhs[i] == v` for
/// some `v` in `rhs`. A NaN in `lhs` also matches a NaN in `rhs`, unlike plain
/// IEEE equality. A row is treated as null only when `has_nulls` is set and
/// `mask` is `Some` with bit `i` cleared; `has_nulls` with a `None` mask means
/// "all rows valid" on every code path. The input `mask` is cloned into the
/// result's `null_mask`.
///
/// With the `simd` feature, aligned inputs with `rhs.len() <= 16` use a
/// broadcast-compare SIMD loop. Otherwise each row is scanned against `rhs`.
///
/// Errors are whatever `confirm_mask_capacity` reports for an undersized mask.
macro_rules! impl_in_float {
    (
        $fn_name:ident, $ty:ty, $lanes:expr, $mask_elem:ty
    ) => {
        #[inline(always)]
        pub fn $fn_name(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool,
        ) -> Result<BooleanArray<()>, KernelError> {
            let len = lhs.len();
            let mut out = new_bool_buffer(len);
            let _ = confirm_mask_capacity(len, mask)?;

            #[cfg(feature = "simd")]
            {
                if is_simd_aligned(lhs) && is_simd_aligned(rhs) {
                    use crate::utils::bitmask_to_simd_mask;
                    use core::simd::{Mask, Simd};
                    if rhs.len() <= 16 {
                        let mut i = 0;
                        // Consult the validity mask only when nulls are flagged
                        // AND a mask is present, exactly like the scalar path.
                        let validity = if has_nulls { mask } else { None };
                        match validity {
                            None => {
                                while i + $lanes <= len {
                                    let x = Simd::<$ty, $lanes>::from_slice(&lhs[i..i + $lanes]);
                                    let x_nan = x.is_nan();
                                    let mut m = Mask::<$mask_elem, $lanes>::splat(false);
                                    for &v in rhs {
                                        let vs = Simd::<$ty, $lanes>::splat(v);
                                        // Equal, or both NaN.
                                        m |= x.simd_eq(vs) | (x_nan & vs.is_nan());
                                    }
                                    let bm = m.to_bitmask();
                                    for l in 0..$lanes {
                                        if ((bm >> l) & 1) == 1 {
                                            out.set(i + l, true);
                                        }
                                    }
                                    i += $lanes;
                                }
                                // Scalar tail for rows that do not fill a vector.
                                for j in i..len {
                                    let x = lhs[j];
                                    if rhs.iter().any(|&v| x == v || (x.is_nan() && v.is_nan())) {
                                        out.set(j, true);
                                    }
                                }
                            }
                            Some(mb) => {
                                let mask_bytes = mb.as_bytes();
                                while i + $lanes <= len {
                                    let x = Simd::<$ty, $lanes>::from_slice(&lhs[i..i + $lanes]);
                                    let x_nan = x.is_nan();
                                    // Validity bits for rows i..i + $lanes as a lane mask.
                                    let lane_mask =
                                        bitmask_to_simd_mask::<$lanes, $mask_elem>(mask_bytes, i, len);
                                    let mut m = Mask::<$mask_elem, $lanes>::splat(false);
                                    for &v in rhs {
                                        let vs = Simd::<$ty, $lanes>::splat(v);
                                        m |= x.simd_eq(vs) | (x_nan & vs.is_nan());
                                    }
                                    let bm = (m & lane_mask).to_bitmask();
                                    for l in 0..$lanes {
                                        if ((bm >> l) & 1) == 1 {
                                            out.set(i + l, true);
                                        }
                                    }
                                    i += $lanes;
                                }
                                for j in i..len {
                                    // SAFETY: j < len <= mask capacity (confirm_mask_capacity above).
                                    if unsafe { mb.get_unchecked(j) } {
                                        let x = lhs[j];
                                        if rhs.iter().any(|&v| x == v || (x.is_nan() && v.is_nan())) {
                                            out.set(j, true);
                                        }
                                    }
                                }
                            }
                        }
                        return Ok(BooleanArray {
                            data: out.into(),
                            null_mask: mask.cloned(),
                            len,
                            _phantom: PhantomData,
                        });
                    }
                }
            }

            // Portable path: floats are not `Hash`, so scan `rhs` per row.
            for i in 0..len {
                // SAFETY: i < len <= mask capacity (confirm_mask_capacity above).
                if has_nulls && !mask.map_or(true, |m| unsafe { m.get_unchecked(i) }) {
                    continue;
                }
                let x = lhs[i];
                if rhs.iter().any(|&v| x == v || (x.is_nan() && v.is_nan())) {
                    out.set(i, true);
                }
            }
            Ok(BooleanArray {
                data: out.into(),
                null_mask: mask.cloned(),
                len,
                _phantom: PhantomData,
            })
        }
    };
}
430
// Concrete `IN` kernels.
// `impl_in_int!` / `impl_in_float!` argument order:
// (function name, element type, SIMD lane count, SIMD mask element type).
// The 8- and 16-bit variants exist only with `extended_numeric_types`.
#[cfg(feature = "extended_numeric_types")]
impl_in_int!(in_i8, i8, W8, i8);
#[cfg(feature = "extended_numeric_types")]
impl_in_int!(in_u8, u8, W8, i8);
#[cfg(feature = "extended_numeric_types")]
impl_in_int!(in_i16, i16, W16, i16);
#[cfg(feature = "extended_numeric_types")]
impl_in_int!(in_u16, u16, W16, i16);
impl_in_int!(in_i32, i32, W32, i32);
impl_in_int!(in_u32, u32, W32, i32);
impl_in_int!(in_i64, i64, W64, i64);
impl_in_int!(in_u64, u64, W64, i64);
impl_in_float!(in_f32, f32, W32, i32);
impl_in_float!(in_f64, f64, W64, i64);

// Concrete `BETWEEN` kernels.
// NOTE: `impl_between_numeric!` takes a different argument order:
// (function name, element type, SIMD mask element type, SIMD lane count).
#[cfg(feature = "extended_numeric_types")]
impl_between_numeric!(between_i8, i8, i8, W8);
#[cfg(feature = "extended_numeric_types")]
impl_between_numeric!(between_u8, u8, i8, W8);
#[cfg(feature = "extended_numeric_types")]
impl_between_numeric!(between_i16, i16, i16, W16);
#[cfg(feature = "extended_numeric_types")]
impl_between_numeric!(between_u16, u16, i16, W16);

impl_between_numeric!(between_i32, i32, i32, W32);
impl_between_numeric!(between_u32, u32, i32, W32);
impl_between_numeric!(between_i64, i64, i64, W64);
impl_between_numeric!(between_u64, u64, i64, W64);
impl_between_numeric!(between_f32, f32, i32, W32);
impl_between_numeric!(between_f64, f64, i64, W64);
462
463#[inline(always)]
467pub fn cmp_str_between<'a, T: Integer>(
468 lhs: StringAVT<'a, T>,
469 rhs: StringAVT<'a, T>,
470) -> Result<BooleanArray<()>, KernelError> {
471 let (larr, loff, llen) = lhs;
472 let (rarr, roff, rlen) = rhs;
473
474 if rlen < 2 {
475 return Err(KernelError::InvalidArguments(format!(
476 "str_between: RHS must contain at least two values (got {})",
477 rlen
478 )));
479 }
480 let min = rarr.get(roff).unwrap_or("");
481 let max = rarr.get(roff + 1).unwrap_or("");
482 let mask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
483 let _ = confirm_mask_capacity(llen, mask.as_ref())?;
484
485 let mut out = new_bool_buffer(llen);
486
487 for i in 0..llen {
488 if mask
489 .as_ref()
490 .map_or(true, |m| unsafe { m.get_unchecked(i) })
491 {
492 let s = unsafe { larr.get_str_unchecked(loff + i) };
493 if s >= min && s <= max {
494 unsafe { out.set_unchecked(i, true) };
495 }
496 }
497 }
498
499 Ok(BooleanArray {
500 data: out.into(),
501 null_mask: mask,
502 len: llen,
503 _phantom: PhantomData,
504 })
505}
506
507#[inline(always)]
508pub fn cmp_str_in<'a, T: Integer>(
510 lhs: StringAVT<'a, T>,
511 rhs: StringAVT<'a, T>,
512) -> Result<BooleanArray<()>, KernelError> {
513 let (larr, loff, llen) = lhs;
514 let (rarr, roff, rlen) = rhs;
515
516 let set: HashSet<&str> = (0..rlen)
517 .map(|i| unsafe { rarr.get_str_unchecked(roff + i) })
518 .collect();
519
520 let mask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
521 let _ = confirm_mask_capacity(llen, mask.as_ref())?;
522
523 let mut out = new_bool_buffer(llen);
524
525 for i in 0..llen {
526 if mask
527 .as_ref()
528 .map_or(true, |m| unsafe { m.get_unchecked(i) })
529 {
530 let s = unsafe { larr.get_str_unchecked(loff + i) };
531 if set.contains(s) {
532 unsafe { out.set_unchecked(i, true) };
533 }
534 }
535 }
536 Ok(BooleanArray {
537 data: out.into(),
538 null_mask: mask,
539 len: llen,
540 _phantom: PhantomData,
541 })
542}
543
544pub fn cmp_between<T: PartialOrd + Copy + Numeric>(
548 lhs: &[T],
549 rhs: &[T],
550) -> Result<BooleanArray<()>, KernelError> {
551 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i32>() {
552 return between_i32(
553 unsafe { std::mem::transmute(lhs) },
554 unsafe { std::mem::transmute(rhs) },
555 None,
556 false,
557 );
558 }
559 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u32>() {
560 return between_u32(
561 unsafe { std::mem::transmute(lhs) },
562 unsafe { std::mem::transmute(rhs) },
563 None,
564 false,
565 );
566 }
567 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i64>() {
568 return between_i64(
569 unsafe { std::mem::transmute(lhs) },
570 unsafe { std::mem::transmute(rhs) },
571 None,
572 false,
573 );
574 }
575 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u64>() {
576 return between_u64(
577 unsafe { std::mem::transmute(lhs) },
578 unsafe { std::mem::transmute(rhs) },
579 None,
580 false,
581 );
582 }
583 between_generic(lhs, rhs, None, false)
585}
586
587#[inline(always)]
589pub fn cmp_between_mask<T: PartialOrd + Copy + Numeric>(
590 lhs: &[T],
591 rhs: &[T],
592 mask: Option<&Bitmask>,
593) -> Result<BooleanArray<()>, KernelError> {
594 let has_nulls = mask.is_some();
595 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i32>() {
596 return between_i32(
597 unsafe { std::mem::transmute(lhs) },
598 unsafe { std::mem::transmute(rhs) },
599 mask,
600 has_nulls,
601 );
602 }
603 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u32>() {
604 return between_u32(
605 unsafe { std::mem::transmute(lhs) },
606 unsafe { std::mem::transmute(rhs) },
607 mask,
608 has_nulls,
609 );
610 }
611 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i64>() {
612 return between_i64(
613 unsafe { std::mem::transmute(lhs) },
614 unsafe { std::mem::transmute(rhs) },
615 mask,
616 has_nulls,
617 );
618 }
619 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u64>() {
620 return between_u64(
621 unsafe { std::mem::transmute(lhs) },
622 unsafe { std::mem::transmute(rhs) },
623 mask,
624 has_nulls,
625 );
626 }
627 between_generic(lhs, rhs, mask, has_nulls)
628}
629
630pub fn cmp_in<T: Eq + Hash + Copy + 'static>(
634 lhs: &[T],
635 rhs: &[T],
636) -> Result<BooleanArray<()>, KernelError> {
637 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i32>() {
639 return in_i32(
640 unsafe { std::mem::transmute(lhs) },
641 unsafe { std::mem::transmute(rhs) },
642 None,
643 false,
644 );
645 }
646 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u32>() {
648 return in_u32(
649 unsafe { std::mem::transmute(lhs) },
650 unsafe { std::mem::transmute(rhs) },
651 None,
652 false,
653 );
654 }
655 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i64>() {
657 return in_i64(
658 unsafe { std::mem::transmute(lhs) },
659 unsafe { std::mem::transmute(rhs) },
660 None,
661 false,
662 );
663 }
664 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u64>() {
666 return in_u64(
667 unsafe { std::mem::transmute(lhs) },
668 unsafe { std::mem::transmute(rhs) },
669 None,
670 false,
671 );
672 }
673 #[cfg(feature = "extended_numeric_types")]
675 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i16>() {
676 return in_i16(
677 unsafe { std::mem::transmute(lhs) },
678 unsafe { std::mem::transmute(rhs) },
679 None,
680 false,
681 );
682 }
683 #[cfg(feature = "extended_numeric_types")]
685 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u16>() {
686 return in_u16(
687 unsafe { std::mem::transmute(lhs) },
688 unsafe { std::mem::transmute(rhs) },
689 None,
690 false,
691 );
692 }
693 #[cfg(feature = "extended_numeric_types")]
695 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i8>() {
696 return in_i8(
697 unsafe { std::mem::transmute(lhs) },
698 unsafe { std::mem::transmute(rhs) },
699 None,
700 false,
701 );
702 }
703 #[cfg(feature = "extended_numeric_types")]
705 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u8>() {
706 return in_u8(
707 unsafe { std::mem::transmute(lhs) },
708 unsafe { std::mem::transmute(rhs) },
709 None,
710 false,
711 );
712 }
713 return Err(KernelError::UnsupportedType(
714 "cmp_in: unsupported type for SIMD in".into(),
715 ));
716}
717
718#[inline(always)]
720pub fn cmp_in_mask<T: Eq + Hash + Copy + 'static>(
721 lhs: &[T],
722 rhs: &[T],
723 mask: Option<&Bitmask>,
724) -> Result<BooleanArray<()>, KernelError> {
725 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i32>() {
726 return in_i32(
727 unsafe { std::mem::transmute(lhs) },
728 unsafe { std::mem::transmute(rhs) },
729 mask,
730 mask.is_some(),
731 );
732 }
733 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u32>() {
734 return in_u32(
735 unsafe { std::mem::transmute(lhs) },
736 unsafe { std::mem::transmute(rhs) },
737 mask,
738 mask.is_some(),
739 );
740 }
741 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i64>() {
742 return in_i64(
743 unsafe { std::mem::transmute(lhs) },
744 unsafe { std::mem::transmute(rhs) },
745 mask,
746 mask.is_some(),
747 );
748 }
749 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u64>() {
750 return in_u64(
751 unsafe { std::mem::transmute(lhs) },
752 unsafe { std::mem::transmute(rhs) },
753 mask,
754 mask.is_some(),
755 );
756 }
757 return Err(KernelError::UnsupportedType(
758 "cmp_in_mask: unsupported type (expected integer type)".into(),
759 ));
760}
761
762#[inline(always)]
764pub fn cmp_in_f_mask<T: Float + Copy>(
765 lhs: &[T],
766 rhs: &[T],
767 mask: Option<&Bitmask>,
768) -> Result<BooleanArray<()>, KernelError> {
769 if TypeId::of::<T>() == TypeId::of::<f32>() {
770 let lhs = unsafe { &*(lhs as *const [T] as *const [f32]) };
771 let rhs = unsafe { &*(rhs as *const [T] as *const [f32]) };
772 in_f32(lhs, rhs, mask, mask.is_some())
773 } else if TypeId::of::<T>() == TypeId::of::<f64>() {
774 let lhs = unsafe { &*(lhs as *const [T] as *const [f64]) };
775 let rhs = unsafe { &*(rhs as *const [T] as *const [f64]) };
776 in_f64(lhs, rhs, mask, mask.is_some())
777 } else {
778 unreachable!("cmp_in_f_mask: Only f32/f64 supported for Float kernels")
779 }
780}
781
782#[inline(always)]
783pub fn cmp_in_f<T: Float + Copy>(lhs: &[T], rhs: &[T]) -> Result<BooleanArray<()>, KernelError> {
785 if TypeId::of::<T>() == TypeId::of::<f32>() {
786 let lhs = unsafe { &*(lhs as *const [T] as *const [f32]) };
787 let rhs = unsafe { &*(rhs as *const [T] as *const [f32]) };
788 in_f32(lhs, rhs, None, false)
789 } else if TypeId::of::<T>() == TypeId::of::<f64>() {
790 let lhs = unsafe { &*(lhs as *const [T] as *const [f64]) };
791 let rhs = unsafe { &*(rhs as *const [T] as *const [f64]) };
792 in_f64(lhs, rhs, None, false)
793 } else {
794 unreachable!("cmp_in_f: Only f32/f64 supported for Float kernels")
795 }
796}
797
798pub fn cmp_between_f<T: PartialOrd + Copy + Float + Numeric>(
802 lhs: &[T],
803 rhs: &[T],
804) -> Result<BooleanArray<()>, KernelError> {
805 between_generic(lhs, rhs, None, false)
806}
807
808pub fn cmp_dict_between<'a, T: Integer>(
810 lhs: CategoricalAVT<'a, T>,
811 rhs: CategoricalAVT<'a, T>,
812) -> Result<BooleanArray<()>, KernelError> {
813 let (larr, loff, llen) = lhs;
814 let (rarr, roff, _rlen) = rhs;
815
816 let min = rarr.get(roff).unwrap_or("");
817 let max = rarr.get(roff + 1).unwrap_or("");
818 let mask = larr.null_mask.as_ref();
819 let _ = confirm_mask_capacity(larr.data.len(), mask)?;
820 let has_nulls = mask.is_some();
821
822 let mut out = new_bool_buffer(llen);
823 for i in 0..llen {
824 let li = loff + i;
825 if !has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(li) }) {
826 let s = unsafe { larr.get_str_unchecked(li) };
827 if s > min && s <= max {
828 unsafe { out.set_unchecked(i, true) };
829 }
830 }
831 }
832 Ok(BooleanArray {
833 data: out.into(),
834 null_mask: mask.cloned(),
835 len: llen,
836 _phantom: PhantomData,
837 })
838}
839
840pub fn cmp_dict_in<'a, T: Integer + Hash>(
845 lhs: CategoricalAVT<'a, T>,
846 rhs: CategoricalAVT<'a, T>,
847) -> Result<BooleanArray<()>, KernelError> {
848 let (larr, loff, llen) = lhs;
849 let (rarr, roff, rlen) = rhs;
850 let mask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
851 let _ = confirm_mask_capacity(llen, mask.as_ref())?;
852
853 let mut out = Bitmask::new_set_all(llen, false);
854
855 if (larr.unique_values.len() == rarr.unique_values.len())
856 && (larr.unique_values.len() <= MAX_DICT_CHECK)
857 {
858 let mut same_dict = true;
859 for (a, b) in larr.unique_values.iter().zip(rarr.unique_values.iter()) {
860 if a != b {
861 same_dict = false;
862 break;
863 }
864 }
865
866 if same_dict {
867 let rhs_codes: HashSet<T> = rarr.data[roff..roff + rlen].iter().copied().collect();
868 for i in 0..llen {
869 if mask
870 .as_ref()
871 .map_or(true, |m| unsafe { m.get_unchecked(i) })
872 {
873 let code = larr.data[loff + i];
874 if rhs_codes.contains(&code) {
875 unsafe { out.set_unchecked(i, true) };
876 }
877 }
878 }
879 return Ok(BooleanArray {
880 data: out.into(),
881 null_mask: mask,
882 len: llen,
883 _phantom: PhantomData,
884 });
885 }
886 }
887
888 let rhs_strings: HashSet<&str> = (0..rlen)
889 .filter(|&i| {
890 rarr.null_mask
891 .as_ref()
892 .map_or(true, |m| unsafe { m.get_unchecked(roff + i) })
893 })
894 .map(|i| unsafe { rarr.get_str_unchecked(roff + i) })
895 .collect();
896
897 for i in 0..llen {
898 if mask
899 .as_ref()
900 .map_or(true, |m| unsafe { m.get_unchecked(i) })
901 {
902 let s = unsafe { larr.get_str_unchecked(loff + i) };
903 if rhs_strings.contains(s) {
904 unsafe { out.set_unchecked(i, true) };
905 }
906 }
907 }
908
909 Ok(BooleanArray {
910 data: out.into(),
911 null_mask: mask,
912 len: llen,
913 _phantom: PhantomData,
914 })
915}
916
917pub fn is_null_array(arr: &Array) -> Result<BooleanArray<()>, KernelError> {
921 let not_null = is_not_null_array(arr)?;
922 Ok(!not_null)
923}
924pub fn is_not_null_array(arr: &Array) -> Result<BooleanArray<()>, KernelError> {
926 let len = arr.len();
927 let mut data = Bitmask::new_set_all(len, false);
928
929 if let Some(mask) = arr.null_mask() {
930 data = mask.clone();
931 } else {
932 data.fill(true);
933 }
934 Ok(BooleanArray {
935 data,
936 null_mask: None,
937 len,
938 _phantom: PhantomData,
939 })
940}
941
942pub fn in_array(input: &Array, values: &Array) -> Result<BooleanArray<()>, KernelError> {
945 match (input, values) {
946 (
947 Array::NumericArray(NumericArray::Int32(a)),
948 Array::NumericArray(NumericArray::Int32(b)),
949 ) => cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref()),
950 (
951 Array::NumericArray(NumericArray::Int64(a)),
952 Array::NumericArray(NumericArray::Int64(b)),
953 ) => cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref()),
954 (
955 Array::NumericArray(NumericArray::UInt32(a)),
956 Array::NumericArray(NumericArray::UInt32(b)),
957 ) => cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref()),
958 (
959 Array::NumericArray(NumericArray::UInt64(a)),
960 Array::NumericArray(NumericArray::UInt64(b)),
961 ) => cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref()),
962 (
963 Array::NumericArray(NumericArray::Float32(a)),
964 Array::NumericArray(NumericArray::Float32(b)),
965 ) => cmp_in_f_mask(&a.data, &b.data, a.null_mask.as_ref()),
966 (
967 Array::NumericArray(NumericArray::Float64(a)),
968 Array::NumericArray(NumericArray::Float64(b)),
969 ) => cmp_in_f_mask(&a.data, &b.data, a.null_mask.as_ref()),
970 (Array::TextArray(TextArray::String32(a)), Array::TextArray(TextArray::String32(b))) => {
971 cmp_str_in((**a).tuple_ref(0, a.len()), (**b).tuple_ref(0, b.len()))
972 }
973 (Array::BooleanArray(a), Array::BooleanArray(b)) => {
974 cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref())
975 }
976 (
977 Array::TextArray(TextArray::Categorical32(a)),
978 Array::TextArray(TextArray::Categorical32(b)),
979 ) => cmp_dict_in((**a).tuple_ref(0, a.len()), (**b).tuple_ref(0, b.len())),
980 _ => unimplemented!(),
981 }
982}
983
984#[inline(always)]
985pub fn not_in_array(input: &Array, values: &Array) -> Result<BooleanArray<()>, KernelError> {
987 let result = in_array(input, values)?;
988 Ok(!result)
989}
990
991pub fn between_array(input: &Array, min: &Array, max: &Array) -> Result<Array, KernelError> {
993 macro_rules! between_case {
994 ($variant:ident, $cmp:ident) => {{
995 let arr = match input {
996 Array::NumericArray(NumericArray::$variant(arr)) => arr,
997 _ => unreachable!(),
998 };
999 let mins = match min {
1000 Array::NumericArray(NumericArray::$variant(arr)) => arr,
1001 _ => unreachable!(),
1002 };
1003 let maxs = match max {
1004 Array::NumericArray(NumericArray::$variant(arr)) => arr,
1005 _ => unreachable!(),
1006 };
1007 let rhs: Vec64<_> = mins
1008 .data
1009 .iter()
1010 .zip(&maxs.data)
1011 .flat_map(|(&lo, &hi)| [lo, hi])
1012 .collect();
1013 Ok(Array::BooleanArray(
1014 $cmp(
1015 &arr.data,
1016 &rhs,
1017 arr.null_mask.as_ref(),
1018 arr.null_mask.is_some(),
1019 )?
1020 .into(),
1021 ))
1022 }};
1023 }
1024
1025 match (input, min, max) {
1026 (
1027 Array::NumericArray(NumericArray::Int32(..)),
1028 Array::NumericArray(NumericArray::Int32(..)),
1029 Array::NumericArray(NumericArray::Int32(..)),
1030 ) => between_case!(Int32, between_i32),
1031 (
1032 Array::NumericArray(NumericArray::Int64(..)),
1033 Array::NumericArray(NumericArray::Int64(..)),
1034 Array::NumericArray(NumericArray::Int64(..)),
1035 ) => between_case!(Int64, between_i64),
1036 (
1037 Array::NumericArray(NumericArray::UInt32(..)),
1038 Array::NumericArray(NumericArray::UInt32(..)),
1039 Array::NumericArray(NumericArray::UInt32(..)),
1040 ) => between_case!(UInt32, between_u32),
1041 (
1042 Array::NumericArray(NumericArray::UInt64(..)),
1043 Array::NumericArray(NumericArray::UInt64(..)),
1044 Array::NumericArray(NumericArray::UInt64(..)),
1045 ) => between_case!(UInt64, between_u64),
1046 (
1047 Array::NumericArray(NumericArray::Float32(..)),
1048 Array::NumericArray(NumericArray::Float32(..)),
1049 Array::NumericArray(NumericArray::Float32(..)),
1050 ) => between_case!(Float32, between_generic),
1051 (
1052 Array::NumericArray(NumericArray::Float64(..)),
1053 Array::NumericArray(NumericArray::Float64(..)),
1054 Array::NumericArray(NumericArray::Float64(..)),
1055 ) => between_case!(Float64, between_generic),
1056 _ => Err(KernelError::UnsupportedType(
1057 "Unsupported Type Error.".to_string(),
1058 )),
1059 }
1060}
1061
1062#[inline]
1066pub fn not_bool<const LANES: usize>(
1067 src: BooleanAVT<'_, ()>,
1068) -> Result<BooleanArray<()>, KernelError>
1069where
1070 LaneCount<LANES>: SupportedLaneCount,
1071{
1072 let (arr, offset, len) = src;
1073
1074 if offset % 64 != 0 {
1075 return Err(KernelError::InvalidArguments(format!(
1076 "not_bool: offset must be 64-bit aligned (got offset={})",
1077 offset
1078 )));
1079 }
1080
1081 let null_mask = arr.null_mask.as_ref().map(|nm| nm.slice_clone(offset, len));
1082
1083 let data = if null_mask.is_none() {
1084 #[cfg(feature = "simd")]
1085 {
1086 minarrow::kernels::bitmask::simd::not_mask_simd::<LANES>((&arr.data, offset, len))
1087 }
1088 #[cfg(not(feature = "simd"))]
1089 {
1090 minarrow::kernels::bitmask::std::not_mask((&arr.data, offset, len))
1091 }
1092 } else {
1093 arr.data.slice_clone(offset, len)
1095 };
1096
1097 Ok(BooleanArray {
1098 data,
1099 null_mask,
1100 len,
1101 _phantom: core::marker::PhantomData,
1102 })
1103}
1104
1105pub fn apply_logical_bool<const LANES: usize>(
1109 lhs: BooleanAVT<'_, ()>,
1110 rhs: BooleanAVT<'_, ()>,
1111 op: LogicalOperator,
1112) -> Result<BooleanArray<()>, KernelError>
1113where
1114 LaneCount<LANES>: SupportedLaneCount,
1115{
1116 let (lhs_arr, lhs_off, len) = lhs;
1117 let (rhs_arr, rhs_off, rlen) = rhs;
1118
1119 if len != rlen {
1120 return Err(KernelError::LengthMismatch(format!(
1121 "logical_bool: window length mismatch (lhs: {}, rhs: {})",
1122 len, rlen
1123 )));
1124 }
1125 if lhs_off % 64 != 0 || rhs_off % 64 != 0 {
1126 return Err(KernelError::InvalidArguments(format!(
1127 "logical_bool: offsets must be 64-bit aligned (lhs: {}, rhs: {})",
1128 lhs_off, rhs_off
1129 )));
1130 }
1131
1132 #[cfg(feature = "simd")]
1135 let data = match op {
1136 LogicalOperator::And => {
1137 and_masks_simd::<LANES>((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1138 }
1139 LogicalOperator::Or => {
1140 or_masks_simd::<LANES>((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1141 }
1142 LogicalOperator::Xor => {
1143 xor_masks_simd::<LANES>((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1144 }
1145 };
1146
1147 #[cfg(feature = "simd")]
1149 let null_mask = match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
1150 (None, None) => None,
1151 (Some(a), None) | (None, Some(a)) => Some(a.slice_clone(lhs_off, len)),
1152 (Some(a), Some(b)) => Some(and_masks_simd::<LANES>(
1153 (a, lhs_off, len),
1154 (b, rhs_off, len),
1155 )),
1156 };
1157
1158 #[cfg(not(feature = "simd"))]
1159 let data = match op {
1160 LogicalOperator::And => {
1161 and_masks((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1162 }
1163 LogicalOperator::Or => {
1164 or_masks((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1165 }
1166 LogicalOperator::Xor => {
1167 xor_masks((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1168 }
1169 };
1170
1171 #[cfg(not(feature = "simd"))]
1172 let null_mask = match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
1173 (None, None) => None,
1174 (Some(a), None) | (None, Some(a)) => Some(a.slice_clone(lhs_off, len)),
1175 (Some(a), Some(b)) => Some(and_masks((a, lhs_off, len), (b, rhs_off, len))),
1176 };
1177
1178 Ok(BooleanArray {
1179 data,
1180 null_mask,
1181 len,
1182 _phantom: PhantomData,
1183 })
1184}
1185
#[cfg(test)]
mod tests {
    //! Unit tests for the `between`, `in`, and null-check kernels.
    //!
    //! Covers integer, float, string and categorical inputs, windowed
    //! `(array, offset, len)` slices, null-mask propagation, and the
    //! `Array`-level dispatch helpers.

    use minarrow::structs::variants::categorical::CategoricalArray;
    use minarrow::structs::variants::float::FloatArray;
    use minarrow::structs::variants::integer::IntegerArray;
    use minarrow::structs::variants::string::StringArray;
    use minarrow::{Array, Bitmask, BooleanArray, vec64};

    use super::*;

    /// Builds a `Bitmask` of `bits.len()` bits where bit `i` equals `bits[i]`.
    fn bm(bits: &[bool]) -> Bitmask {
        let mut m = Bitmask::new_set_all(bits.len(), false);
        for (i, &b) in bits.iter().enumerate() {
            m.set(i, b);
        }
        m
    }

    /// Asserts that `arr` holds exactly the values in `expect` and that its
    /// null mask matches `expect_mask`.
    ///
    /// When `expect_mask` is `None`, the array may either carry no null mask
    /// or a mask whose first `arr.len` bits are all set (every row valid).
    fn assert_bool(arr: &BooleanArray<()>, expect: &[bool], expect_mask: Option<&[bool]>) {
        assert_eq!(arr.len, expect.len(), "length mismatch");
        for (i, &v) in expect.iter().enumerate() {
            assert_eq!(arr.data.get(i), v, "val @ {i}");
        }
        match (expect_mask, &arr.null_mask) {
            (None, None) => {}
            (Some(exp), Some(mask)) => {
                // Guard against a too-short expectation silently skipping rows.
                assert_eq!(exp.len(), arr.len, "expected mask length mismatch");
                for (i, &b) in exp.iter().enumerate() {
                    assert_eq!(mask.get(i), b, "mask @ {i}");
                }
            }
            (None, Some(mask)) => {
                // A present-but-all-valid mask is equivalent to no mask.
                for i in 0..arr.len {
                    assert!(mask.get(i), "unexpected false mask @ {i}");
                }
            }
            (Some(_), None) => panic!("expected null mask"),
        }
    }

    /// Dense `i32` integer array from a slice.
    fn i32_arr(data: &[i32]) -> IntegerArray<i32> {
        IntegerArray::from_slice(data)
    }

    /// Dense `f32` float array from a slice.
    fn f32_arr(data: &[f32]) -> FloatArray<f32> {
        FloatArray::from_slice(data)
    }

    /// String array with offset type `T` from string slices.
    fn str_arr<T: Integer>(vals: &[&str]) -> StringArray<T> {
        StringArray::<T>::from_slice(vals)
    }

    /// Categorical (dictionary-encoded) array with index type `T`.
    fn dict_arr<T: Integer>(vals: &[&str]) -> CategoricalArray<T> {
        CategoricalArray::<T>::from_values(vals.to_vec())
    }

    #[test]
    fn between_i32_scalar_rhs() {
        let lhs = vec64![1, 3, 5, 7];
        let rhs = vec64![2, 6];
        let out = between_i32(&lhs, &rhs, None, false).unwrap();
        assert_bool(&out, &[false, true, true, false], None);
    }

    #[test]
    fn between_i32_per_row_rhs() {
        let lhs = vec64![5, 9, 2, 8];
        // Per-row (min, max) pairs: [0,10], [0,4], [2,2], [8,9].
        let rhs = vec64![0, 10, 0, 4, 2, 2, 8, 9];
        let out = between_i32(&lhs, &rhs, None, false).unwrap();
        assert_bool(&out, &[true, false, true, true], None);
    }

    #[test]
    fn between_i32_nulls_propagate() {
        let lhs = vec64![5, 9, 2, 8];
        let rhs = vec64![0, 10, 0, 4, 2, 2, 8, 9];
        let mask = bm(&[true, false, true, true]);
        let out = between_i32(&lhs, &rhs, Some(&mask), true).unwrap();
        assert_bool(
            &out,
            &[true, false, true, true],
            Some(&[true, false, true, true]),
        );
    }

    // NOTE(review): despite its name this test exercises `in_i16`, not a
    // `between` kernel. With rhs = [10, 12], a between-check would give the
    // same result, so confirm which kernel was intended.
    #[cfg(feature = "extended_numeric_types")]
    #[test]
    fn between_i16_works() {
        let lhs = vec64![10i16, 12, 99];
        let rhs = vec64![10i16, 12];
        let out = in_i16(&lhs, &rhs, None, false).unwrap();
        assert_bool(&out, &[true, true, false], None);
    }

    #[test]
    fn between_f64_scalar_and_nulls() {
        let lhs = vec64![1.0, 5.0, 8.0, 20.0];
        let rhs = vec64![4.0, 10.0];
        let mask = bm(&[true, false, true, true]);
        let out = between_f64(&lhs, &rhs, Some(&mask), true).unwrap();
        // Row 1 (5.0) is in range but null, so its value bit is expected clear.
        assert_bool(
            &out,
            &[false, false, true, false],
            Some(&[true, false, true, true]),
        );
    }

    #[test]
    fn between_f32_generic_dispatch() {
        let lhs = vec64![0.1f32, 0.5, 1.2, -1.0];
        let rhs = vec64![0.0, 1.0];
        let out = cmp_between(&lhs, &rhs).unwrap();
        assert_bool(&out, &[true, true, false, false], None);
    }

    #[test]
    fn between_masked_dispatch() {
        let lhs = vec64![1i32, 2, 3];
        let rhs = vec64![0, 2];
        let mask = bm(&[true, false, true]);
        let out = cmp_between_mask(&lhs, &rhs, Some(&mask)).unwrap();
        assert_bool(&out, &[true, false, false], Some(&[true, false, true]));
    }

    #[test]
    fn in_i32_small_rhs() {
        let lhs = vec64![1, 2, 3, 4, 5];
        let rhs = vec64![2, 4];
        let out = in_i32(&lhs, &rhs, None, false).unwrap();
        assert_bool(&out, &[false, true, false, true, false], None);
    }

    #[test]
    fn in_i32_with_nulls() {
        let lhs = vec64![7, 8, 9];
        let rhs = vec64![8];
        let mask = bm(&[true, false, true]);
        let out = in_i32(&lhs, &rhs, Some(&mask), true).unwrap();
        assert_bool(&out, &[false, false, false], Some(&[true, false, true]));
    }

    #[test]
    fn in_i64_large_rhs() {
        let lhs = vec64![1i64, 2, 3, 7, 8, 15];
        // "Large" presumably relative to a small-set fast path; TODO confirm threshold.
        let rhs: Vec<i64> = (2..10).collect();
        let out = in_i64(&lhs, &rhs, None, false).unwrap();
        assert_bool(&out, &[false, true, true, true, true, false], None);
    }

    #[cfg(feature = "extended_numeric_types")]
    #[test]
    fn in_u8_small_rhs() {
        let lhs = vec64![1u8, 2, 3, 4];
        let rhs = vec64![2u8, 3];
        let out = in_u8(&lhs, &rhs, None, false).unwrap();
        assert_bool(&out, &[false, true, true, false], None);
    }

    #[test]
    fn in_float_nan_and_normal() {
        let lhs = vec64![1.0f32, f32::NAN, 7.0];
        let rhs = vec64![f32::NAN, 7.0];
        let out = in_f32(&lhs, &rhs, None, false).unwrap();
        // NaN is expected to match NaN for set membership.
        assert_bool(&out, &[false, true, true], None);
    }

    #[test]
    fn string_between() {
        let lhs = str_arr::<u32>(&["aa", "bb", "zz"]);
        let rhs = str_arr::<u32>(&["b", "y"]);
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let out = cmp_str_between(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    #[test]
    fn string_between_chunk() {
        let lhs = str_arr::<u32>(&["0", "aa", "bb", "zz", "9"]);
        let rhs = str_arr::<u32>(&["a", "b", "y", "z"]);
        // Windows select ["aa", "bb", "zz"] and bounds ["b", "y"].
        let lhs_slice = (&lhs, 1, 3);
        let rhs_slice = (&rhs, 1, 2);
        let out = cmp_str_between(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    #[test]
    fn string_in_basic() {
        let lhs = str_arr::<u32>(&["x", "y", "z"]);
        let rhs = str_arr::<u32>(&["y", "a"]);
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let out = cmp_str_in(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    #[test]
    fn string_in_basic_chunk() {
        let lhs = str_arr::<u32>(&["0", "x", "y", "z", "9"]);
        let rhs = str_arr::<u32>(&["b", "y", "a", "c"]);
        // Windows select ["x", "y", "z"] and set ["y", "a"].
        let lhs_slice = (&lhs, 1, 3);
        let rhs_slice = (&rhs, 1, 2);
        let out = cmp_str_in(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    #[test]
    fn dict_between() {
        let lhs = dict_arr::<u32>(&["cat", "dog", "emu"]);
        let rhs = dict_arr::<u32>(&["cobra", "dove"]);
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let out = cmp_dict_between(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    #[test]
    fn dict_between_chunk() {
        let lhs = dict_arr::<u32>(&["a", "cat", "dog", "emu", "z"]);
        let rhs = dict_arr::<u32>(&["a", "cobra", "dove", "zz"]);
        // Windows select ["cat", "dog", "emu"] and bounds ["cobra", "dove"].
        let lhs_slice = (&lhs, 1, 3);
        let rhs_slice = (&rhs, 1, 2);
        let out = cmp_dict_between(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    #[test]
    fn dict_in_membership() {
        let lhs = dict_arr::<u32>(&["aa", "bb", "cc"]);
        let rhs = dict_arr::<u32>(&["bb", "dd"]);
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    #[test]
    fn dict_in_membership_chunk() {
        let lhs = dict_arr::<u32>(&["0", "aa", "bb", "cc", "9"]);
        let rhs = dict_arr::<u32>(&["a", "bb", "dd", "zz"]);
        // Windows select ["aa", "bb", "cc"] and set ["bb", "dd"].
        let lhs_slice = (&lhs, 1, 3);
        let rhs_slice = (&rhs, 1, 2);
        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    #[test]
    fn string_between_nulls() {
        let mut lhs = str_arr::<u32>(&["foo", "bar", "baz"]);
        lhs.null_mask = Some(bm(&[true, false, true]));
        let rhs = str_arr::<u32>(&["a", "zzz"]);
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let out = cmp_str_between(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[true, false, true], Some(&[true, false, true]));
    }

    #[test]
    fn string_between_nulls_chunk() {
        let mut lhs = str_arr::<u32>(&["0", "foo", "bar", "baz", "z"]);
        lhs.null_mask = Some(bm(&[true, true, false, true, true]));
        let rhs = str_arr::<u32>(&["0", "a", "zzz", "9"]);
        // The lhs window starts at 1, so the null at index 2 lands on output row 1.
        let lhs_slice = (&lhs, 1, 3);
        let rhs_slice = (&rhs, 1, 2);
        let out = cmp_str_between(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[true, false, true], Some(&[true, false, true]));
    }

    #[test]
    fn dict_in_nulls() {
        let mut lhs = dict_arr::<u32>(&["one", "two", "three"]);
        lhs.null_mask = Some(bm(&[false, true, true]));
        let rhs = dict_arr::<u32>(&["two", "four"]);
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], Some(&[false, true, true]));
    }

    #[test]
    fn dict_in_nulls_chunk() {
        let mut lhs = dict_arr::<u32>(&["x", "one", "two", "three", "z"]);
        lhs.null_mask = Some(bm(&[true, false, true, true, true]));
        let rhs = dict_arr::<u32>(&["a", "two", "four", "b"]);
        // The lhs window starts at 1, so the null at index 1 lands on output row 0.
        let lhs_slice = (&lhs, 1, 3);
        let rhs_slice = (&rhs, 1, 2);
        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
        assert_bool(&out, &[false, true, false], Some(&[false, true, true]));
    }

    #[test]
    fn is_null_and_is_not_null() {
        let mut arr = i32_arr(&[1, 2, 0]);
        arr.null_mask = Some(bm(&[true, false, true]));
        let array = Array::from_int32(arr.clone());

        let not_null = is_not_null_array(&array).unwrap();
        let is_null = is_null_array(&array).unwrap();

        assert_bool(&not_null, &[true, false, true], None);
        assert_bool(&is_null, &[false, true, false], None);
    }

    #[test]
    fn is_null_not_null_dense() {
        let arr = i32_arr(&[1, 2, 3]);
        let array = Array::from_int32(arr.clone());
        let is_null = is_null_array(&array).unwrap();
        assert_bool(&is_null, &[false, false, false], None);
        let not_null = is_not_null_array(&array).unwrap();
        assert_bool(&not_null, &[true, true, true], None);
    }

    #[test]
    fn in_array_int32_dispatch() {
        let inp = Array::from_int32(i32_arr(&[10, 20, 30]));
        let vals = Array::from_int32(i32_arr(&[20, 40]));
        let out = in_array(&inp, &vals).unwrap();
        assert_bool(&out, &[false, true, false], None);

        let out_not = not_in_array(&inp, &vals).unwrap();
        assert_bool(&out_not, &[true, false, true], None);
    }

    #[test]
    fn in_array_f32_dispatch() {
        let inp = Array::from_float32(f32_arr(&[1.0, f32::NAN, 7.0]));
        let vals = Array::from_float32(f32_arr(&[f32::NAN, 7.0]));
        let out = in_array(&inp, &vals).unwrap();
        assert_bool(&out, &[false, true, true], None);
    }

    #[test]
    fn in_array_string_dispatch() {
        let inp = Array::from_string32(str_arr::<u32>(&["a", "b", "c"]));
        let vals = Array::from_string32(str_arr::<u32>(&["b", "d"]));
        let out = in_array(&inp, &vals).unwrap();
        assert_bool(&out, &[false, true, false], None);
    }

    #[test]
    fn in_array_dictionary_dispatch() {
        let inp = Array::from_categorical32(dict_arr::<u32>(&["aa", "bb", "cc"]));
        let vals = Array::from_categorical32(dict_arr::<u32>(&["bb", "cc"]));
        let out = in_array(&inp, &vals).unwrap();
        assert_bool(&out, &[false, true, true], None);
    }

    #[test]
    fn between_array_int32_rows() {
        let inp = Array::from_int32(i32_arr(&[5, 15, 25]));
        let min = Array::from_int32(i32_arr(&[0, 10, 20]));
        let max = Array::from_int32(i32_arr(&[10, 20, 30]));

        let out = between_array(&inp, &min, &max).unwrap();
        match out {
            Array::BooleanArray(b) => assert_bool(&b, &[true, true, true], None),
            _ => panic!("expected Bool array"),
        }
    }

    #[test]
    fn between_array_float_generic() {
        let inp = Array::from_float32(f32_arr(&[0.5, 1.5, 2.5]));
        let min = Array::from_float32(f32_arr(&[0.0, 1.0, 2.0]));
        let max = Array::from_float32(f32_arr(&[1.0, 2.0, 3.0]));

        let out = between_array(&inp, &min, &max).unwrap();
        match out {
            Array::BooleanArray(b) => assert_bool(&b, &[true, true, true], None),
            _ => panic!("expected Bool"),
        }
    }

    #[test]
    fn between_array_type_mismatch() {
        let inp = Array::from_int32(i32_arr(&[1, 2, 3]));
        let min = Array::from_float32(f32_arr(&[0.0, 0.0, 0.0]));
        let max = Array::from_float32(f32_arr(&[5.0, 5.0, 5.0]));
        let err = between_array(&inp, &min, &max).unwrap_err();
        match err {
            KernelError::UnsupportedType(_) => {}
            _ => panic!("Expected UnsupportedType error"),
        }
    }

    #[test]
    fn in_integers_various_types() {
        #[cfg(feature = "extended_numeric_types")]
        {
            let u8_lhs = vec64![1u8, 2, 3, 5];
            let u8_rhs = vec64![3u8, 5, 8];
            let out = in_u8(&u8_lhs, &u8_rhs, None, false).unwrap();
            assert_bool(&out, &[false, false, true, true], None);

            let u16_lhs = vec64![100u16, 200, 300];
            let u16_rhs = vec64![200u16, 500];
            let out = in_u16(&u16_lhs, &u16_rhs, None, false).unwrap();
            assert_bool(&out, &[false, true, false], None);

            let i16_lhs = vec64![10i16, 15, 42];
            let i16_rhs = vec64![15i16, 42, 77];
            let out = in_i16(&i16_lhs, &i16_rhs, None, false).unwrap();
            assert_bool(&out, &[false, true, true], None);
        }

        let u32_lhs = vec64![0u32, 1, 2, 9];
        let u32_rhs = vec64![9u32, 1];
        let out = in_u32(&u32_lhs, &u32_rhs, None, false).unwrap();
        assert_bool(&out, &[false, true, false, true], None);

        let i64_lhs = vec64![1i64, 9, 10];
        let i64_rhs = vec64![2i64, 10, 20];
        let out = in_i64(&i64_lhs, &i64_rhs, None, false).unwrap();
        assert_bool(&out, &[false, false, true], None);

        let u64_lhs = vec64![1u64, 2, 3, 4];
        let u64_rhs = vec64![2u64, 4, 8];
        let out = in_u64(&u64_lhs, &u64_rhs, None, false).unwrap();
        assert_bool(&out, &[false, true, false, true], None);
    }

    #[test]
    fn between_and_in_empty_inputs() {
        // Empty numeric lhs with scalar bounds.
        let lhs: [i32; 0] = [];
        let rhs = vec64![0, 1];
        let out = between_i32(&lhs, &rhs, None, false).unwrap();
        assert_eq!(out.len, 0);

        // Empty numeric lhs with a non-empty membership set.
        let lhs: [i32; 0] = [];
        let rhs = vec64![1, 2, 3];
        let out = in_i32(&lhs, &rhs, None, false).unwrap();
        assert_eq!(out.len, 0);

        // Empty string lhs.
        let lhs = str_arr::<u32>(&[]);
        let rhs = str_arr::<u32>(&["a", "b"]);
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let out = cmp_str_in(lhs_slice, rhs_slice).unwrap();
        assert_eq!(out.len, 0);
    }

    #[test]
    fn between_and_in_empty_inputs_chunk() {
        let lhs = str_arr::<u32>(&["x", "y"]);
        let rhs = str_arr::<u32>(&["a", "b", "c"]);
        // A zero-length window at a non-zero offset.
        let lhs_slice = (&lhs, 1, 0);
        let rhs_slice = (&rhs, 1, 2);
        let out = cmp_str_in(lhs_slice, rhs_slice).unwrap();
        assert_eq!(out.len, 0);
    }

    #[test]
    fn between_per_row_bounds_on_last_row() {
        let lhs = vec64![0i32, 10, 20, 30];
        // Per-row (min, max) pairs: [0,5], [5,15], [15,25], [25,35].
        let rhs = vec64![0, 5, 5, 15, 15, 25, 25, 35];
        let out = between_i32(&lhs, &rhs, None, false).unwrap();
        assert_bool(&out, &[true, true, true, true], None);
    }

    #[test]
    fn test_cmp_dict_in_force_fallback() {
        // Dictionaries are set explicitly on both sides. Presumably this makes
        // lhs and rhs not share a dictionary, so the kernel takes its fallback
        // path (per the test name) — TODO confirm against cmp_dict_in.
        let mut lhs = dict_arr::<u32>(&["a", "b", "c", "a"]);
        lhs.unique_values = vec64!["a".to_string(), "b".to_string(), "c".to_string()];
        let mut rhs = dict_arr::<u32>(&["b", "x", "y", "z"]);
        rhs.unique_values = vec64![
            "b".to_string(),
            "x".to_string(),
            "y".to_string(),
            "z".to_string()
        ];
        lhs.null_mask = Some(bm(&[true, true, true, true]));
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
        assert_bool(
            &out,
            &[false, true, false, false],
            Some(&[true, true, true, true]),
        );
    }

    #[test]
    fn test_cmp_dict_in_force_fallback_chunk() {
        let mut lhs = dict_arr::<u32>(&["z", "a", "b", "c", "a", "q"]);
        lhs.unique_values = vec64![
            "z".to_string(),
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
            "q".to_string()
        ];
        let mut rhs = dict_arr::<u32>(&["x", "b", "x", "y", "z"]);
        rhs.unique_values = vec64![
            "x".to_string(),
            "b".to_string(),
            "y".to_string(),
            "z".to_string()
        ];
        lhs.null_mask = Some(bm(&[true, true, true, true, true, true]));
        // Windows select lhs ["a", "b", "c", "a"] and rhs ["b", "x", "y", "z"].
        let lhs_slice = (&lhs, 1, 4);
        let rhs_slice = (&rhs, 1, 4);
        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
        assert_bool(
            &out,
            &[false, true, false, false],
            Some(&[true, true, true, true]),
        );
    }

    #[test]
    fn test_in_array_empty_rhs() {
        let arr = Array::from_int32(i32_arr(&[1, 2, 3]));
        let empty = Array::from_int32(i32_arr(&[]));
        let out = in_array(&arr, &empty).unwrap();
        // Nothing is a member of the empty set.
        assert_bool(&out, &[false, false, false], None);
    }
}