1include!(concat!(env!("OUT_DIR"), "/simd_lanes.rs"));
20
21use std::marker::PhantomData;
23#[cfg(feature = "simd")]
24use std::simd::{Mask, Simd};
25
26use minarrow::{Bitmask, BooleanArray, Integer, Numeric};
27
28#[cfg(not(feature = "simd"))]
29use crate::kernels::bitmask::std::{and_masks, in_mask, not_in_mask, not_mask};
30use crate::operators::ComparisonOperator;
31use minarrow::enums::error::KernelError;
32#[cfg(feature = "simd")]
33use minarrow::kernels::bitmask::simd::{
34 and_masks_simd, in_mask_simd, not_in_mask_simd, not_mask_simd,
35};
36use minarrow::utils::confirm_equal_len;
37#[cfg(feature = "simd")]
38use minarrow::utils::is_simd_aligned;
39use minarrow::{BitmaskVT, BooleanAVT, CategoricalAVT, StringAVT};
40
41#[inline(always)]
43fn new_bool_bitmask(len: usize) -> Bitmask {
44 Bitmask::new_set_all(len, false)
45}
46
47fn merge_bitmasks_to_new(a: Option<&Bitmask>, b: Option<&Bitmask>, len: usize) -> Option<Bitmask> {
49 match (a, b) {
50 (None, None) => None,
51 (Some(x), None) | (None, Some(x)) => Some(x.slice_clone(0, len)),
52 (Some(x), Some(y)) => {
53 let mut out = Bitmask::new_set_all(len, true);
54 for i in 0..len {
55 unsafe { out.set_unchecked(i, x.get_unchecked(i) && y.get_unchecked(i)) };
56 }
57 Some(out)
58 }
59 }
60}
61
/// Generates a pair of element-wise comparison kernels for one primitive type.
///
/// Arguments: the allocating kernel name, the `_to` (write-into) kernel name, the
/// element type `$ty`, the SIMD lane count `$lanes`, and `$mask_elem`, the element
/// type of the `std::simd::Mask` that comparisons on `Simd<$ty, _>` produce.
macro_rules! impl_cmp_numeric {
    ($fn_name:ident, $fn_name_to:ident, $ty:ty, $lanes:expr, $mask_elem:ty) => {
        /// Element-wise comparison of two equally long slices, written into `output`.
        ///
        /// Bit `i` of `output` is switched on when `lhs[i] op rhs[i]` holds and slot
        /// `i` is set in `mask` (or `mask` is `None`). Bits are only ever set, never
        /// cleared, so `output` is expected to arrive all-`false`. Operators other
        /// than the six equality/ordering operators never match.
        ///
        /// # Errors
        /// Returns an error when `lhs.len() != rhs.len()`.
        ///
        /// # Panics
        /// Panics when `output.capacity()` is smaller than `lhs.len()`.
        #[inline(always)]
        pub fn $fn_name_to(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            op: ComparisonOperator,
            output: &mut Bitmask,
        ) -> Result<(), KernelError> {
            let n = lhs.len();
            confirm_equal_len("compare numeric length mismatch", n, rhs.len())?;
            assert!(
                output.capacity() >= n,
                concat!(stringify!($fn_name_to), ": output capacity too small")
            );

            // Scalar form of `op`, shared by the SIMD remainder and the fallback loop.
            let holds = |x: $ty, y: $ty| -> bool {
                match op {
                    ComparisonOperator::Equals => x == y,
                    ComparisonOperator::NotEquals => x != y,
                    ComparisonOperator::LessThan => x < y,
                    ComparisonOperator::LessThanOrEqualTo => x <= y,
                    ComparisonOperator::GreaterThan => x > y,
                    ComparisonOperator::GreaterThanOrEqualTo => x >= y,
                    _ => false,
                }
            };

            #[cfg(feature = "simd")]
            {
                // Vector path: only for aligned buffers with no null mask.
                if is_simd_aligned(lhs) && is_simd_aligned(rhs) && mask.is_none() {
                    use std::simd::cmp::{SimdPartialEq, SimdPartialOrd};
                    const N: usize = $lanes;
                    let vector_end = n - n % N;

                    let pairs = lhs.chunks_exact(N).zip(rhs.chunks_exact(N));
                    for (chunk_no, (lc, rc)) in pairs.enumerate() {
                        let a = Simd::<$ty, N>::from_slice(lc);
                        let b = Simd::<$ty, N>::from_slice(rc);
                        let hits: Mask<$mask_elem, N> = match op {
                            ComparisonOperator::Equals => a.simd_eq(b),
                            ComparisonOperator::NotEquals => a.simd_ne(b),
                            ComparisonOperator::LessThan => a.simd_lt(b),
                            ComparisonOperator::LessThanOrEqualTo => a.simd_le(b),
                            ComparisonOperator::GreaterThan => a.simd_gt(b),
                            ComparisonOperator::GreaterThanOrEqualTo => a.simd_ge(b),
                            _ => Mask::splat(false),
                        };
                        let lane_bits = hits.to_bitmask();
                        let base = chunk_no * N;
                        for lane in 0..N {
                            if ((lane_bits >> lane) & 1) == 1 {
                                // SAFETY: base + lane < vector_end <= n <= output.capacity().
                                unsafe { output.set_unchecked(base + lane, true) };
                            }
                        }
                    }

                    // Elements that do not fill a whole vector.
                    for j in vector_end..n {
                        if holds(lhs[j], rhs[j]) {
                            // SAFETY: j < n <= output.capacity().
                            unsafe { output.set_unchecked(j, true) };
                        }
                    }
                    return Ok(());
                }
            }

            // Scalar fallback; null slots (unset in `mask`) are skipped.
            for (i, (&x, &y)) in lhs.iter().zip(rhs).enumerate() {
                let valid = mask.map_or(true, |m| unsafe { m.get_unchecked(i) });
                if valid && holds(x, y) {
                    // SAFETY: i < n <= output.capacity().
                    unsafe { output.set_unchecked(i, true) };
                }
            }
            Ok(())
        }

        /// Compares `lhs` and `rhs` element-wise with `op` and returns a fresh
        /// `BooleanArray` whose null mask is a copy of `mask`.
        ///
        /// # Errors
        /// Propagates the length-mismatch error of the `_to` kernel.
        #[inline(always)]
        pub fn $fn_name(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            op: ComparisonOperator,
        ) -> Result<BooleanArray<()>, KernelError> {
            let len = lhs.len();
            let mut bits = new_bool_bitmask(len);
            $fn_name_to(lhs, rhs, mask, op, &mut bits)?;
            let null_mask = mask.cloned();
            Ok(BooleanArray {
                data: bits.into(),
                null_mask,
                len,
                _phantom: PhantomData,
            })
        }
    };
}
195
196#[inline(always)]
201pub fn cmp_numeric_to<T: Numeric + Copy + 'static>(
202 lhs: &[T],
203 rhs: &[T],
204 mask: Option<&Bitmask>,
205 op: ComparisonOperator,
206 output: &mut Bitmask,
207) -> Result<(), KernelError> {
208 macro_rules! dispatch {
209 ($t:ty, $f:ident) => {
210 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<$t>() {
211 return $f(
212 unsafe { std::mem::transmute(lhs) },
213 unsafe { std::mem::transmute(rhs) },
214 mask,
215 op,
216 output,
217 );
218 }
219 };
220 }
221 dispatch!(i32, cmp_i32_to);
222 dispatch!(i64, cmp_i64_to);
223 dispatch!(u32, cmp_u32_to);
224 dispatch!(u64, cmp_u64_to);
225 dispatch!(f32, cmp_f32_to);
226 dispatch!(f64, cmp_f64_to);
227
228 unreachable!("Unsupported numeric type for compare_numeric");
229}
230
231#[inline(always)]
274pub fn cmp_numeric<T: Numeric + Copy + 'static>(
275 lhs: &[T],
276 rhs: &[T],
277 mask: Option<&Bitmask>,
278 op: ComparisonOperator,
279) -> Result<BooleanArray<()>, KernelError> {
280 let len = lhs.len();
281 let mut out = new_bool_bitmask(len);
282 cmp_numeric_to(lhs, rhs, mask, op, &mut out)?;
283 Ok(BooleanArray {
284 data: out.into(),
285 null_mask: mask.cloned(),
286 len,
287 _phantom: PhantomData,
288 })
289}
290
/// Compares two boolean bitmask windows with `op`, processing `LANES` 64-bit
/// words per SIMD step.
///
/// Every window is a `(bitmask, bit_offset, len)` triple. Booleans are ordered
/// `false < true` (e.g. `GreaterThan` is `lhs & !rhs`). When `mask` is given, the
/// result is ANDed with it, so slots unset in `mask` come out `false`.
/// `In` / `NotIn` are delegated to `in_mask_simd` / `not_in_mask_simd` (then ANDed
/// with `mask`); any other operator yields all-`false` bits.
///
/// # Errors
/// * the `lhs` and `rhs` windows differ in length;
/// * on the word-wise path, any offset (lhs, rhs or mask) is not a multiple of 64.
#[cfg(feature = "simd")]
pub fn cmp_bitmask_simd<const LANES: usize>(
    lhs: BitmaskVT<'_>,
    rhs: BitmaskVT<'_>,
    mask: Option<BitmaskVT<'_>>,
    op: ComparisonOperator,
) -> Result<Bitmask, KernelError> {
    confirm_equal_len("compare bool length mismatch", lhs.2, rhs.2)?;
    let (lhs_mask, lhs_offset, len) = lhs;
    let (rhs_mask, rhs_offset, _) = rhs;

    if matches!(op, ComparisonOperator::In | ComparisonOperator::NotIn) {
        let mut out = match op {
            ComparisonOperator::In => in_mask_simd::<LANES>(lhs, rhs),
            ComparisonOperator::NotIn => not_in_mask_simd::<LANES>(lhs, rhs),
            _ => unreachable!(),
        };
        if let Some(mask_slice) = mask {
            out = and_masks_simd::<LANES>((&out, 0, out.len), mask_slice);
        }
        return Ok(out);
    }

    if lhs_offset % 64 != 0
        || rhs_offset % 64 != 0
        || mask.as_ref().map_or(false, |(_, mo, _)| mo % 64 != 0)
    {
        return Err(KernelError::InvalidArguments(format!(
            "cmp_bitmask: all offsets must be 64-bit aligned (lhs: {}, rhs: {}, mask offset: {:?})",
            lhs_offset,
            rhs_offset,
            mask.as_ref().map(|(_, mo, _)| mo)
        )));
    }

    // All offsets are *bit* offsets; after the alignment check each maps to a
    // whole-word starting index. (The mask offset previously was used as a word
    // index directly, which read the wrong words for any non-zero mask offset.)
    let lhs_word_start = lhs_offset / 64;
    let rhs_word_start = rhs_offset / 64;
    let mask_words: Option<(&Bitmask, usize)> = mask.map(|(m, mo, _)| (m, mo / 64));
    let n_words = (len + 63) / 64;
    let simd_chunks = n_words / LANES;

    let mut out = Bitmask::new_set_all(len, false);

    for chunk in 0..simd_chunks {
        let base = chunk * LANES;
        let mut lhs_arr = [0u64; LANES];
        let mut rhs_arr = [0u64; LANES];
        let mut mask_arr = [!0u64; LANES];

        for lane in 0..LANES {
            // SAFETY: assumes every backing bitmask holds at least
            // `word_start + n_words` words; this is not checked here.
            lhs_arr[lane] = unsafe { lhs_mask.word_unchecked(lhs_word_start + base + lane) };
            rhs_arr[lane] = unsafe { rhs_mask.word_unchecked(rhs_word_start + base + lane) };
            if let Some((m, mask_word_start)) = mask_words {
                mask_arr[lane] = unsafe { m.word_unchecked(mask_word_start + base + lane) };
            }
        }

        let lhs_v = Simd::<u64, LANES>::from_array(lhs_arr);
        let rhs_v = Simd::<u64, LANES>::from_array(rhs_arr);
        let mask_v = Simd::<u64, LANES>::from_array(mask_arr);

        let cmp_v = match op {
            ComparisonOperator::Equals => !(lhs_v ^ rhs_v),
            ComparisonOperator::NotEquals => lhs_v ^ rhs_v,
            ComparisonOperator::GreaterThan => lhs_v & (!rhs_v),
            ComparisonOperator::LessThan => (!lhs_v) & rhs_v,
            ComparisonOperator::GreaterThanOrEqualTo => lhs_v | (!rhs_v),
            ComparisonOperator::LessThanOrEqualTo => (!lhs_v) | rhs_v,
            _ => Simd::splat(0),
        };
        let result_v = cmp_v & mask_v;

        for lane in 0..LANES {
            // SAFETY: base + lane < simd_chunks * LANES <= n_words, the word count of `out`.
            unsafe { out.set_word_unchecked(base + lane, result_v[lane]) };
        }
    }

    // Remaining words that do not fill a whole vector.
    for w in simd_chunks * LANES..n_words {
        // SAFETY: same coverage assumption as the vector loop above.
        let a = unsafe { lhs_mask.word_unchecked(lhs_word_start + w) };
        let b = unsafe { rhs_mask.word_unchecked(rhs_word_start + w) };
        let valid = mask_words.map_or(!0u64, |(m, start)| unsafe { m.word_unchecked(start + w) });
        let cmp = match op {
            ComparisonOperator::Equals => !(a ^ b),
            ComparisonOperator::NotEquals => a ^ b,
            ComparisonOperator::GreaterThan => a & (!b),
            ComparisonOperator::LessThan => (!a) & b,
            ComparisonOperator::GreaterThanOrEqualTo => a | (!b),
            ComparisonOperator::LessThanOrEqualTo => (!a) | b,
            _ => 0,
        };
        unsafe { out.set_word_unchecked(w, cmp & valid) };
    }

    // Clear bits of the final partial word that lie past `len`.
    out.mask_trailing_bits();
    Ok(out)
}
435
436pub fn cmp_bool<const LANES: usize>(
453 lhs: BooleanAVT<'_, ()>,
454 rhs: BooleanAVT<'_, ()>,
455 op: ComparisonOperator,
456) -> Result<BooleanArray<()>, KernelError>
457where
458{
459 let (lhs_arr, lhs_off, len) = lhs;
460 let (rhs_arr, rhs_off, rlen) = rhs;
461 debug_assert_eq!(len, rlen, "cmp_bool: window length mismatch");
462
463 #[cfg(feature = "simd")]
464 let merged_null_mask: Option<Bitmask> =
465 match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
466 (None, None) => None,
467 (Some(m), None) => Some(m.slice_clone(lhs_off, len)),
468 (None, Some(m)) => Some(m.slice_clone(rhs_off, len)),
469 (Some(a), Some(b)) => {
470 let am = (a, lhs_off, len);
471 let bm = (b, rhs_off, len);
472 Some(and_masks_simd::<LANES>(am, bm))
473 }
474 };
475
476 #[cfg(not(feature = "simd"))]
477 let merged_null_mask: Option<Bitmask> =
478 match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
479 (None, None) => None,
480 (Some(m), None) => Some(m.slice_clone(lhs_off, len)),
481 (None, Some(m)) => Some(m.slice_clone(rhs_off, len)),
482 (Some(a), Some(b)) => {
483 let am = (a, lhs_off, len);
484 let bm = (b, rhs_off, len);
485 Some(and_masks(am, bm))
486 }
487 };
488
489 let mask_slice = merged_null_mask.as_ref().map(|m| (m, 0, len));
490
491 let data = match op {
492 ComparisonOperator::Equals
493 | ComparisonOperator::NotEquals
494 | ComparisonOperator::LessThan
495 | ComparisonOperator::LessThanOrEqualTo
496 | ComparisonOperator::GreaterThan
497 | ComparisonOperator::GreaterThanOrEqualTo
498 | ComparisonOperator::In
499 | ComparisonOperator::NotIn => {
500 #[cfg(feature = "simd")]
501 let res = cmp_bitmask_simd::<LANES>(
502 (&lhs_arr.data, lhs_off, len),
503 (&rhs_arr.data, rhs_off, len),
504 mask_slice,
505 op,
506 )?;
507 #[cfg(not(feature = "simd"))]
508 let res = cmp_bitmask_std(
509 (&lhs_arr.data, lhs_off, len),
510 (&rhs_arr.data, rhs_off, len),
511 mask_slice,
512 op,
513 )?;
514 res
515 }
516 ComparisonOperator::IsNull => {
517 #[cfg(feature = "simd")]
518 let data = match merged_null_mask.as_ref() {
519 Some(mask) => not_mask_simd::<LANES>((mask, 0, len)),
520 None => Bitmask::new_set_all(len, false),
521 };
522 #[cfg(not(feature = "simd"))]
523 let data = match merged_null_mask.as_ref() {
524 Some(mask) => not_mask((mask, 0, len)),
525 None => Bitmask::new_set_all(len, false),
526 };
527 return Ok(BooleanArray {
528 data,
529 null_mask: None,
530 len,
531 _phantom: PhantomData,
532 });
533 }
534 ComparisonOperator::IsNotNull => {
535 let data = match merged_null_mask.as_ref() {
536 Some(mask) => mask.slice_clone(0, len),
537 None => Bitmask::new_set_all(len, true),
538 };
539 return Ok(BooleanArray {
540 data,
541 null_mask: None,
542 len,
543 _phantom: PhantomData,
544 });
545 }
546 ComparisonOperator::Between => {
547 return Err(KernelError::InvalidArguments(
548 "Set operations are not defined for Bool arrays".to_owned(),
549 ));
550 }
551 };
552
553 Ok(BooleanArray {
554 data,
555 null_mask: merged_null_mask,
556 len,
557 _phantom: PhantomData,
558 })
559}
560
561#[cfg(not(feature = "simd"))]
569pub fn cmp_bitmask_std(
570 lhs: BitmaskVT<'_>,
571 rhs: BitmaskVT<'_>,
572 mask: Option<BitmaskVT<'_>>,
573 op: ComparisonOperator,
574) -> Result<Bitmask, KernelError> {
575 confirm_equal_len("compare bool length mismatch", lhs.2, rhs.2)?;
581 let (lhs_mask, lhs_offset, len) = lhs;
582 let (rhs_mask, rhs_offset, _) = rhs;
583
584 if matches!(op, ComparisonOperator::In | ComparisonOperator::NotIn) {
587 let mut out = match op {
588 ComparisonOperator::In => in_mask(lhs, rhs),
589 ComparisonOperator::NotIn => not_in_mask(lhs, rhs),
590 _ => unreachable!(),
591 };
592 if let Some(mask_slice) = mask {
593 out = and_masks((&out, 0, out.len), mask_slice);
594 }
595 return Ok(out);
596 }
597
598 if lhs_offset % 64 != 0
600 || rhs_offset % 64 != 0
601 || mask.as_ref().map_or(false, |(_, mo, _)| mo % 64 != 0)
602 {
603 return Err(KernelError::InvalidArguments(format!(
604 "cmp_bitmask: all offsets must be 64-bit aligned (lhs: {}, rhs: {}, mask offset: {:?})",
605 lhs_offset,
606 rhs_offset,
607 mask.as_ref().map(|(_, mo, _)| mo)
608 )));
609 }
610
611 let lhs_word_start = lhs_offset / 64;
613 let rhs_word_start = rhs_offset / 64;
614 let n_words = (len + 63) / 64;
615
616 let mut out = Bitmask::new_set_all(len, false);
618
619 let words = n_words;
620 let tail = len % 64;
621 let mask_mask_opt = mask;
622
623 for w in 0..words {
625 let a = unsafe { lhs_mask.word_unchecked(lhs_word_start + w) };
626 let b = unsafe { rhs_mask.word_unchecked(rhs_word_start + w) };
627 let valid_bits =
628 mask_mask_opt
629 .as_ref()
630 .map_or(!0u64, |(mask_mask, mask_word_start, _)| unsafe {
631 mask_mask.word_unchecked(mask_word_start + w)
632 });
633 let word_cmp = match op {
634 ComparisonOperator::Equals => !(a ^ b),
635 ComparisonOperator::NotEquals => a ^ b,
636 ComparisonOperator::GreaterThan => a & (!b),
637 ComparisonOperator::LessThan => (!a) & b,
638 ComparisonOperator::GreaterThanOrEqualTo => a | (!b),
639 ComparisonOperator::LessThanOrEqualTo => (!a) | b,
640 _ => 0,
641 };
642 let final_bits = word_cmp & valid_bits;
643 unsafe {
644 out.set_word_unchecked(w, final_bits);
645 }
646 }
647
648 let base = words * 64;
651 for i in 0..tail {
652 let idx_lhs = lhs_offset + base + i;
653 let idx_rhs = rhs_offset + base + i;
654 let mask_valid =
655 mask_mask_opt
656 .as_ref()
657 .map_or(true, |(mask_mask, mask_word_start, mask_len)| unsafe {
658 let mask_idx = mask_word_start * 64 + base + i;
659 if mask_idx < *mask_len {
660 mask_mask.get_unchecked(mask_idx)
661 } else {
662 false
663 }
664 });
665 if !mask_valid {
666 continue;
667 }
668 if idx_lhs >= lhs_mask.len() || idx_rhs >= rhs_mask.len() {
669 continue;
670 }
671 let a = unsafe { lhs_mask.get_unchecked(idx_lhs) };
672 let b = unsafe { rhs_mask.get_unchecked(idx_rhs) };
673 let res = match op {
674 ComparisonOperator::Equals => a == b,
675 ComparisonOperator::NotEquals => a != b,
676 ComparisonOperator::GreaterThan => a & !b,
677 ComparisonOperator::LessThan => !a & b,
678 ComparisonOperator::GreaterThanOrEqualTo => a | !b,
679 ComparisonOperator::LessThanOrEqualTo => !a | b,
680 _ => false,
681 };
682 if res {
683 out.set(base + i, true)
684 }
685 }
686 out.mask_trailing_bits();
687 Ok(out)
688}
689
/// Generates a pair of comparison kernels for string-valued windows
/// (`StringAVT` / `CategoricalAVT` in the combinations instantiated below).
///
/// * `$fn_name_to` writes results into a caller-supplied bitmask.
/// * `$fn_name` allocates the output and attaches the merged null mask.
macro_rules! impl_cmp_utf8_slice {
    ($fn_name:ident, $fn_name_to:ident, $lhs_slice:ty, $rhs_slice:ty, [$($gen:tt)+]) => {
        /// Compares two string-valued windows element-wise, writing into `output`.
        ///
        /// Each window is `(array, offset, len)`. Bit `i` of `output` is set when
        /// both sides are valid at `i` and the values returned by
        /// `get_str_unchecked(offset + i)` satisfy `op`. Other bits are left
        /// untouched, so `output` should start all-`false`. Operators other than
        /// the six equality/ordering operators never match.
        ///
        /// # Errors
        /// * the windows differ in length;
        /// * either side's null mask has capacity below `offset + len`.
        ///
        /// # Panics
        /// Panics when `output.capacity()` is smaller than the window length.
        #[inline(always)]
        pub fn $fn_name_to<$($gen)+>(
            lhs: $lhs_slice,
            rhs: $rhs_slice,
            op: ComparisonOperator,
            output: &mut Bitmask,
        ) -> Result<(), KernelError> {
            let (larr, loff, llen) = lhs;
            let (rarr, roff, rlen) = rhs;
            confirm_equal_len("compare string/dict length mismatch (slice contract)", llen, rlen)?;
            assert!(output.capacity() >= llen, concat!(stringify!($fn_name_to), ": output capacity too small"));

            // Validate mask coverage *before* slicing, so an undersized mask is
            // reported as an error instead of `slice_clone` being asked to read
            // past it.
            if let Some(m) = larr.null_mask.as_ref() {
                if m.capacity() < loff + llen {
                    return Err(KernelError::InvalidArguments(
                        format!(
                            "lhs mask capacity too small (expected ≥ {}, got {})",
                            loff + llen,
                            m.capacity()
                        ),
                    ));
                }
            }
            if let Some(m) = rarr.null_mask.as_ref() {
                if m.capacity() < roff + rlen {
                    return Err(KernelError::InvalidArguments(
                        format!(
                            "rhs mask capacity too small (expected ≥ {}, got {})",
                            roff + rlen,
                            m.capacity()
                        ),
                    ));
                }
            }

            let lhs_mask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
            let rhs_mask = rarr.null_mask.as_ref().map(|m| m.slice_clone(roff, rlen));

            let has_nulls = lhs_mask.is_some() || rhs_mask.is_some();
            for i in 0..llen {
                if has_nulls
                    && !(lhs_mask.as_ref().map_or(true, |m| unsafe { m.get_unchecked(i) })
                        && rhs_mask.as_ref().map_or(true, |m| unsafe { m.get_unchecked(i) }))
                {
                    continue;
                }
                let l = unsafe { larr.get_str_unchecked(loff + i) };
                let r = unsafe { rarr.get_str_unchecked(roff + i) };
                let res = match op {
                    ComparisonOperator::Equals => l == r,
                    ComparisonOperator::NotEquals => l != r,
                    ComparisonOperator::GreaterThan => l > r,
                    ComparisonOperator::LessThan => l < r,
                    ComparisonOperator::GreaterThanOrEqualTo => l >= r,
                    ComparisonOperator::LessThanOrEqualTo => l <= r,
                    _ => false,
                };
                if res {
                    output.set(i, true);
                }
            }
            Ok(())
        }

        /// Compares two string-valued windows element-wise and returns a new
        /// `BooleanArray` whose null mask is the AND of both windows' null masks.
        ///
        /// # Errors
        /// Same as the `_to` kernel: mismatched window lengths or undersized
        /// null masks.
        #[inline(always)]
        pub fn $fn_name<$($gen)+>(
            lhs: $lhs_slice,
            rhs: $rhs_slice,
            op: ComparisonOperator,
        ) -> Result<BooleanArray<()>, KernelError> {
            let (larr, loff, llen) = lhs;
            let (rarr, roff, rlen) = rhs;
            let mut out = new_bool_bitmask(llen);
            // Pass the rhs window through with its *own* length so the `_to` kernel
            // rejects mismatches. Forcing `llen` here previously let a shorter rhs
            // be read past its end.
            $fn_name_to((larr, loff, llen), (rarr, roff, rlen), op, &mut out)?;

            // Lengths and mask capacities are validated at this point.
            let lhs_mask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
            let rhs_mask = rarr.null_mask.as_ref().map(|m| m.slice_clone(roff, rlen));
            let null_mask = merge_bitmasks_to_new(lhs_mask.as_ref(), rhs_mask.as_ref(), llen);
            Ok(BooleanArray { data: out.into(), null_mask, len: llen, _phantom: PhantomData })
        }
    };
}
780
781impl_cmp_numeric!(cmp_i32, cmp_i32_to, i32, W32, i32);
782impl_cmp_numeric!(cmp_u32, cmp_u32_to, u32, W32, i32);
783impl_cmp_numeric!(cmp_i64, cmp_i64_to, i64, W64, i64);
784impl_cmp_numeric!(cmp_u64, cmp_u64_to, u64, W64, i64);
785impl_cmp_numeric!(cmp_f32, cmp_f32_to, f32, W32, i32);
786impl_cmp_numeric!(cmp_f64, cmp_f64_to, f64, W64, i64);
787impl_cmp_utf8_slice!(cmp_str_str, cmp_str_str_to, StringAVT<'a, T>, StringAVT<'a, T>, [ 'a, T: Integer ]);
788impl_cmp_utf8_slice!(cmp_str_dict, cmp_str_dict_to, StringAVT<'a, T>, CategoricalAVT<'a, U>, [ 'a, T: Integer, U: Integer ]);
789impl_cmp_utf8_slice!(cmp_dict_str, cmp_dict_str_to, CategoricalAVT<'a, T>, StringAVT<'a, U>, [ 'a, T: Integer, U: Integer ]);
790impl_cmp_utf8_slice!(cmp_dict_dict, cmp_dict_dict_to, CategoricalAVT<'a, T>, CategoricalAVT<'a, T>, [ 'a, T: Integer ]);
791
#[cfg(test)]
mod tests {
    use minarrow::{Bitmask, BooleanArray, CategoricalArray, Integer, StringArray, vec64};

    use crate::kernels::comparison::{
        cmp_dict_dict, cmp_dict_str, cmp_i32, cmp_numeric, cmp_str_dict,
    };

    #[cfg(feature = "simd")]
    use crate::kernels::comparison::{W64, cmp_bitmask_simd};

    use crate::operators::ComparisonOperator;

    /// Builds a bitmask whose bit `i` equals `bits[i]`.
    fn bm(bits: &[bool]) -> Bitmask {
        let mut mask = Bitmask::new_set_all(bits.len(), false);
        for (idx, bit) in bits.iter().copied().enumerate() {
            mask.set(idx, bit);
        }
        mask
    }

    /// Checks the length, every value bit and, when expected, every null-mask bit.
    fn assert_bool(arr: &BooleanArray<()>, expect: &[bool], expect_mask: Option<&[bool]>) {
        assert_eq!(arr.len, expect.len());
        for (i, want) in expect.iter().copied().enumerate() {
            assert_eq!(arr.data.get(i), want, "value bit {i}");
        }
        match (arr.null_mask.as_ref(), expect_mask) {
            (None, None) => {}
            (Some(mask), Some(want_bits)) => {
                for (i, want) in want_bits.iter().copied().enumerate() {
                    assert_eq!(mask.get(i), want, "null-bit {i}");
                }
            }
            _ => panic!("mask mismatch"),
        }
    }

    /// Builds a `StringArray` from string literals.
    fn str_arr<T: Integer>(v: &[&str]) -> StringArray<T> {
        StringArray::<T>::from_slice(v)
    }

    /// Builds a `CategoricalArray` from string literals.
    fn dict_arr<T: Integer>(vals: &[&str]) -> CategoricalArray<T> {
        let values: Vec<&str> = vals.iter().copied().collect();
        CategoricalArray::<T>::from_values(values)
    }

    #[test]
    fn numeric_compare_no_nulls() {
        let a = vec64![1i32, 2, 3, 4];
        let b = vec64![1i32, 1, 4, 4];
        let expectations = [
            (ComparisonOperator::Equals, [true, false, false, true]),
            (ComparisonOperator::NotEquals, [false, true, true, false]),
            (ComparisonOperator::LessThan, [false, false, true, false]),
            (ComparisonOperator::LessThanOrEqualTo, [true, false, true, true]),
            (ComparisonOperator::GreaterThan, [false, true, false, false]),
            (ComparisonOperator::GreaterThanOrEqualTo, [true, true, false, true]),
        ];
        for (op, expected) in expectations {
            let got = cmp_i32(&a, &b, None, op).unwrap();
            assert_bool(&got, &expected, None);
        }
    }

    #[test]
    fn numeric_compare_with_nulls_generic_dispatch() {
        let lhs = vec64![1u64, 5, 9, 10];
        let rhs = vec64![0u64, 5, 8, 11];
        let validity = bm(&[true, true, true, false]);

        let out = cmp_numeric(&lhs, &rhs, Some(&validity), ComparisonOperator::GreaterThan)
            .unwrap();
        // The last slot is null, so its value bit stays false.
        assert_bool(
            &out,
            &[true, false, true, false],
            Some(&[true, true, true, false]),
        );
    }

    #[cfg(feature = "simd")]
    #[test]
    fn bool_compare_all_ops() {
        let a = bm(&[true, false, true, false]);
        let b = bm(&[true, true, false, false]);
        let cases = [
            (ComparisonOperator::Equals, [true, false, false, true]),
            (ComparisonOperator::LessThan, [false, true, false, false]),
            (ComparisonOperator::GreaterThan, [false, false, true, false]),
        ];
        for (op, expected) in cases {
            let bits = cmp_bitmask_simd::<W64>((&a, 0, a.len()), (&b, 0, b.len()), None, op)
                .unwrap();
            assert_bool(&BooleanArray::from_bitmask(bits, None), &expected, None);
        }
    }

    #[test]
    fn string_vs_dict_compare_with_nulls() {
        let mut lhs = str_arr::<u32>(&["x", "y", "z"]);
        lhs.null_mask = Some(bm(&[true, false, true]));
        let rhs = dict_arr::<u32>(&["x", "w", "zz"]);
        let res = cmp_str_dict(
            (&lhs, 0, lhs.len()),
            (&rhs, 0, rhs.data.len()),
            ComparisonOperator::Equals,
        )
        .unwrap();
        assert_bool(&res, &[true, false, false], Some(&[true, false, true]));
    }

    #[test]
    fn string_vs_dict_compare_with_nulls_chunk() {
        let mut lhs = str_arr::<u32>(&["pad", "x", "y", "z", "pad"]);
        lhs.null_mask = Some(bm(&[true, true, false, true, true]));
        let rhs = dict_arr::<u32>(&["pad", "x", "w", "zz", "pad"]);
        let res = cmp_str_dict((&lhs, 1, 3), (&rhs, 1, 3), ComparisonOperator::Equals).unwrap();
        assert_bool(&res, &[true, false, false], Some(&[true, false, true]));
    }

    #[test]
    fn dict_vs_dict_compare_gt() {
        let lhs = dict_arr::<u32>(&["apple", "pear", "banana"]);
        let rhs = dict_arr::<u32>(&["ant", "pear", "apricot"]);
        let res = cmp_dict_dict(
            (&lhs, 0, lhs.data.len()),
            (&rhs, 0, rhs.data.len()),
            ComparisonOperator::GreaterThan,
        )
        .unwrap();
        assert_bool(&res, &[true, false, true], None);
    }

    #[test]
    fn dict_vs_dict_compare_gt_chunk() {
        let lhs = dict_arr::<u32>(&["pad", "apple", "pear", "banana", "pad"]);
        let rhs = dict_arr::<u32>(&["pad", "ant", "pear", "apricot", "pad"]);
        let res =
            cmp_dict_dict((&lhs, 1, 3), (&rhs, 1, 3), ComparisonOperator::GreaterThan).unwrap();
        assert_bool(&res, &[true, false, true], None);
    }

    #[test]
    fn dict_vs_string_compare_le() {
        let lhs = dict_arr::<u32>(&["a", "b", "c"]);
        let rhs = str_arr::<u32>(&["b", "b", "d"]);
        let res = cmp_dict_str(
            (&lhs, 0, lhs.data.len()),
            (&rhs, 0, rhs.len()),
            ComparisonOperator::LessThanOrEqualTo,
        )
        .unwrap();
        assert_bool(&res, &[true, true, true], None);
    }

    #[test]
    fn dict_vs_string_compare_le_chunk() {
        let lhs = dict_arr::<u32>(&["pad", "a", "b", "c", "pad"]);
        let rhs = str_arr::<u32>(&["pad", "b", "b", "d", "pad"]);
        let res = cmp_dict_str((&lhs, 1, 3), (&rhs, 1, 3), ComparisonOperator::LessThanOrEqualTo)
            .unwrap();
        assert_bool(&res, &[true, true, true], None);
    }
}