1include!(concat!(env!("OUT_DIR"), "/simd_lanes.rs"));
19
20use core::simd::{LaneCount, SupportedLaneCount};
22use std::marker::PhantomData;
23#[cfg(feature = "simd")]
24use std::simd::{Mask, Simd};
25
26use minarrow::{Bitmask, BooleanArray, Integer, Numeric};
27
28#[cfg(not(feature = "simd"))]
29use crate::kernels::bitmask::std::{and_masks, in_mask, not_in_mask, not_mask};
30use crate::operators::ComparisonOperator;
31use minarrow::enums::error::KernelError;
32#[cfg(feature = "simd")]
33use minarrow::kernels::bitmask::simd::{
34 and_masks_simd, in_mask_simd, not_in_mask_simd, not_mask_simd,
35};
36use minarrow::utils::confirm_equal_len;
37#[cfg(feature = "simd")]
38use minarrow::utils::is_simd_aligned;
39use minarrow::{BitmaskVT, BooleanAVT, CategoricalAVT, StringAVT};
40
41#[inline(always)]
43fn new_bool_bitmask(len: usize) -> Bitmask {
44 Bitmask::new_set_all(len, false)
45}
46
47fn merge_bitmasks_to_new(a: Option<&Bitmask>, b: Option<&Bitmask>, len: usize) -> Option<Bitmask> {
49 match (a, b) {
50 (None, None) => None,
51 (Some(x), None) | (None, Some(x)) => Some(x.slice_clone(0, len)),
52 (Some(x), Some(y)) => {
53 let mut out = Bitmask::new_set_all(len, true);
54 for i in 0..len {
55 unsafe { out.set_unchecked(i, x.get_unchecked(i) && y.get_unchecked(i)) };
56 }
57 Some(out)
58 }
59 }
60}
61
/// Generates a monomorphic element-wise comparison kernel.
///
/// Arguments: the function name, the element type, the SIMD lane count used
/// with the `simd` feature, and the element type of the SIMD comparison mask.
macro_rules! impl_cmp_numeric {
    ($fn_name:ident, $ty:ty, $lanes:expr, $mask_elem:ty) => {
        /// Compares two equal-length numeric slices element-wise with `op`.
        ///
        /// Bit `i` of the result's data is set when `lhs[i] op rhs[i]` holds.
        /// Operators other than the six equality/ordering comparisons (for
        /// example `In` or `IsNull`) produce an all-false result.
        ///
        /// `mask`, when present, is a validity bitmask. Positions whose bit is
        /// clear are skipped, so their data bit stays false. The mask is copied
        /// into the result's `null_mask`. Without a mask the result has no null
        /// mask.
        ///
        /// With the `simd` feature, unmasked inputs that pass `is_simd_aligned`
        /// are compared a full SIMD vector at a time. The remainder is handled
        /// by a scalar loop.
        ///
        /// # Errors
        /// - `lhs` and `rhs` have different lengths.
        /// - `mask` holds fewer bits than `lhs.len()`.
        #[inline(always)]
        pub fn $fn_name(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            op: ComparisonOperator,
        ) -> Result<BooleanArray<()>, KernelError> {
            let len = lhs.len();
            confirm_equal_len("compare numeric length mismatch", len, rhs.len())?;

            // The scalar loop reads the mask with `get_unchecked`, so a short
            // mask must be rejected here rather than read out of bounds.
            if let Some(m) = mask {
                if m.len() < len {
                    return Err(KernelError::InvalidArguments(format!(
                        "compare numeric: null mask too short (expected ≥ {}, got {})",
                        len,
                        m.len()
                    )));
                }
            }

            let mut out = new_bool_bitmask(len);

            // Scalar predicate shared by the SIMD remainder and the fallback loop.
            let scalar = |a: $ty, b: $ty| -> bool {
                match op {
                    ComparisonOperator::Equals => a == b,
                    ComparisonOperator::NotEquals => a != b,
                    ComparisonOperator::LessThan => a < b,
                    ComparisonOperator::LessThanOrEqualTo => a <= b,
                    ComparisonOperator::GreaterThan => a > b,
                    ComparisonOperator::GreaterThanOrEqualTo => a >= b,
                    _ => false,
                }
            };

            #[cfg(feature = "simd")]
            {
                if mask.is_none() && is_simd_aligned(lhs) && is_simd_aligned(rhs) {
                    use std::simd::cmp::{SimdPartialEq, SimdPartialOrd};
                    const N: usize = $lanes;
                    let mut i = 0;
                    while i + N <= len {
                        let a = Simd::<$ty, N>::from_slice(&lhs[i..i + N]);
                        let b = Simd::<$ty, N>::from_slice(&rhs[i..i + N]);
                        let m: Mask<$mask_elem, N> = match op {
                            ComparisonOperator::Equals => a.simd_eq(b),
                            ComparisonOperator::NotEquals => a.simd_ne(b),
                            ComparisonOperator::LessThan => a.simd_lt(b),
                            ComparisonOperator::LessThanOrEqualTo => a.simd_le(b),
                            ComparisonOperator::GreaterThan => a.simd_gt(b),
                            ComparisonOperator::GreaterThanOrEqualTo => a.simd_ge(b),
                            _ => Mask::splat(false),
                        };
                        let bits = m.to_bitmask();
                        for l in 0..N {
                            if ((bits >> l) & 1) == 1 {
                                // SAFETY: i + l < len, the bit length of `out`.
                                unsafe { out.set_unchecked(i + l, true) };
                            }
                        }
                        i += N;
                    }
                    for j in i..len {
                        if scalar(lhs[j], rhs[j]) {
                            // SAFETY: j < len, the bit length of `out`.
                            unsafe { out.set_unchecked(j, true) };
                        }
                    }

                    return Ok(BooleanArray {
                        data: out.into(),
                        null_mask: None,
                        len,
                        _phantom: PhantomData,
                    });
                }
            }

            for i in 0..len {
                if let Some(m) = mask {
                    // SAFETY: the mask length was validated to be >= len above.
                    if !unsafe { m.get_unchecked(i) } {
                        continue;
                    }
                }
                if scalar(lhs[i], rhs[i]) {
                    // SAFETY: i < len, the bit length of `out`.
                    unsafe { out.set_unchecked(i, true) };
                }
            }
            Ok(BooleanArray {
                data: out.into(),
                null_mask: mask.cloned(),
                len,
                _phantom: PhantomData,
            })
        }
    };
}
179
180#[inline(always)]
223pub fn cmp_numeric<T: Numeric + Copy + 'static>(
224 lhs: &[T],
225 rhs: &[T],
226 mask: Option<&Bitmask>,
227 op: ComparisonOperator,
228) -> Result<BooleanArray<()>, KernelError> {
229 macro_rules! dispatch {
230 ($t:ty, $f:ident) => {
231 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<$t>() {
232 return $f(
233 unsafe { std::mem::transmute(lhs) },
234 unsafe { std::mem::transmute(rhs) },
235 mask,
236 op,
237 );
238 }
239 };
240 }
241 dispatch!(i32, cmp_i32);
242 dispatch!(i64, cmp_i64);
243 dispatch!(u32, cmp_u32);
244 dispatch!(u64, cmp_u64);
245 dispatch!(f32, cmp_f32);
246 dispatch!(f64, cmp_f64);
247
248 unreachable!("Unsupported numeric type for compare_numeric");
249}
250
/// Compares two boolean bitmask windows word-by-word using `LANES`-wide SIMD.
///
/// `lhs`, `rhs` and the optional `mask` are `(bitmask, bit offset, length)`
/// windows.
///
/// - For the six equality/ordering operators, booleans are ordered
///   `false < true`.
/// - `In` / `NotIn` are delegated to `in_mask_simd` / `not_in_mask_simd`.
/// - Any other operator yields all-false bits.
///
/// When `mask` is given, result bits whose mask bit is clear are forced to
/// false.
///
/// # Errors
/// - `lhs` and `rhs` have different lengths.
/// - For the word-wise operators, any of the lhs, rhs or mask offsets is not a
///   multiple of 64.
#[cfg(feature = "simd")]
pub fn cmp_bitmask_simd<const LANES: usize>(
    lhs: BitmaskVT<'_>,
    rhs: BitmaskVT<'_>,
    mask: Option<BitmaskVT<'_>>,
    op: ComparisonOperator,
) -> Result<Bitmask, KernelError>
where
    LaneCount<LANES>: SupportedLaneCount,
{
    confirm_equal_len("compare bool length mismatch", lhs.2, rhs.2)?;
    let (lhs_mask, lhs_offset, len) = lhs;
    let (rhs_mask, rhs_offset, _) = rhs;

    // Set-membership operators have dedicated kernels that take the windows as-is.
    if matches!(op, ComparisonOperator::In | ComparisonOperator::NotIn) {
        let mut out = match op {
            ComparisonOperator::In => in_mask_simd::<LANES>(lhs, rhs),
            ComparisonOperator::NotIn => not_in_mask_simd::<LANES>(lhs, rhs),
            _ => unreachable!(),
        };
        if let Some(mask_slice) = mask {
            out = and_masks_simd::<LANES>((&out, 0, out.len), mask_slice);
        }
        return Ok(out);
    }

    // The word-wise path reads whole u64 words, so every window must start on
    // a word boundary.
    if lhs_offset % 64 != 0
        || rhs_offset % 64 != 0
        || mask.as_ref().map_or(false, |(_, mo, _)| mo % 64 != 0)
    {
        return Err(KernelError::InvalidArguments(format!(
            "cmp_bitmask: all offsets must be 64-bit aligned (lhs: {}, rhs: {}, mask offset: {:?})",
            lhs_offset,
            rhs_offset,
            mask.as_ref().map(|(_, mo, _)| mo)
        )));
    }

    // All three offsets are bit offsets; convert each to the index of its first word.
    let lhs_word_start = lhs_offset / 64;
    let rhs_word_start = rhs_offset / 64;
    let mask_words: Option<(&Bitmask, usize)> = mask.map(|(m, mo, _)| (m, mo / 64));

    let n_words = (len + 63) / 64;
    // Words handled by full SIMD vectors; the rest are processed one at a time.
    let simd_words = n_words - n_words % LANES;

    let mut out = Bitmask::new_set_all(len, false);

    for base in (0..simd_words).step_by(LANES) {
        let mut lhs_arr = [0u64; LANES];
        let mut rhs_arr = [0u64; LANES];
        // Without a mask every bit counts as valid.
        let mut mask_arr = [!0u64; LANES];

        for lane in 0..LANES {
            // SAFETY: assumes each window lies inside its bitmask (not checked
            // here); then word indices below start + n_words are in bounds.
            unsafe {
                lhs_arr[lane] = lhs_mask.word_unchecked(lhs_word_start + base + lane);
                rhs_arr[lane] = rhs_mask.word_unchecked(rhs_word_start + base + lane);
                if let Some((m, mask_word_start)) = mask_words {
                    mask_arr[lane] = m.word_unchecked(mask_word_start + base + lane);
                }
            }
        }

        let lhs_v = Simd::<u64, LANES>::from_array(lhs_arr);
        let rhs_v = Simd::<u64, LANES>::from_array(rhs_arr);
        let mask_v = Simd::<u64, LANES>::from_array(mask_arr);

        let cmp_v = match op {
            ComparisonOperator::Equals => !(lhs_v ^ rhs_v),
            ComparisonOperator::NotEquals => lhs_v ^ rhs_v,
            ComparisonOperator::GreaterThan => lhs_v & (!rhs_v),
            ComparisonOperator::LessThan => (!lhs_v) & rhs_v,
            ComparisonOperator::GreaterThanOrEqualTo => lhs_v | (!rhs_v),
            ComparisonOperator::LessThanOrEqualTo => (!lhs_v) | rhs_v,
            _ => Simd::splat(0),
        };
        let result_v = cmp_v & mask_v;

        for lane in 0..LANES {
            // SAFETY: base + lane < simd_words <= n_words, the word count of `out`.
            unsafe { out.set_word_unchecked(base + lane, result_v[lane]) };
        }
    }

    for w in simd_words..n_words {
        // SAFETY: same window-in-bounds assumption as the SIMD loop.
        let a = unsafe { lhs_mask.word_unchecked(lhs_word_start + w) };
        let b = unsafe { rhs_mask.word_unchecked(rhs_word_start + w) };
        let valid = mask_words.map_or(!0u64, |(m, mask_word_start)| unsafe {
            m.word_unchecked(mask_word_start + w)
        });
        let cmp = match op {
            ComparisonOperator::Equals => !(a ^ b),
            ComparisonOperator::NotEquals => a ^ b,
            ComparisonOperator::GreaterThan => a & (!b),
            ComparisonOperator::LessThan => (!a) & b,
            ComparisonOperator::GreaterThanOrEqualTo => a | (!b),
            ComparisonOperator::LessThanOrEqualTo => (!a) | b,
            _ => 0,
        };
        // SAFETY: w < n_words, the word count of `out`.
        unsafe { out.set_word_unchecked(w, cmp & valid) };
    }

    // Equality-style ops set bits past `len` in the last word; clear them.
    out.mask_trailing_bits();
    Ok(out)
}
396
397pub fn cmp_bool<const LANES: usize>(
414 lhs: BooleanAVT<'_, ()>,
415 rhs: BooleanAVT<'_, ()>,
416 op: ComparisonOperator,
417) -> Result<BooleanArray<()>, KernelError>
418where
419 LaneCount<LANES>: SupportedLaneCount,
420{
421 let (lhs_arr, lhs_off, len) = lhs;
422 let (rhs_arr, rhs_off, rlen) = rhs;
423 debug_assert_eq!(len, rlen, "cmp_bool: window length mismatch");
424
425 #[cfg(feature = "simd")]
426 let merged_null_mask: Option<Bitmask> =
427 match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
428 (None, None) => None,
429 (Some(m), None) | (None, Some(m)) => Some(m.slice_clone(lhs_off, len)),
430 (Some(a), Some(b)) => {
431 let am = (a, lhs_off, len);
432 let bm = (b, rhs_off, len);
433 Some(and_masks_simd::<LANES>(am, bm))
434 }
435 };
436
437 #[cfg(not(feature = "simd"))]
438 let merged_null_mask: Option<Bitmask> =
439 match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
440 (None, None) => None,
441 (Some(m), None) | (None, Some(m)) => Some(m.slice_clone(lhs_off, len)),
442 (Some(a), Some(b)) => {
443 let am = (a, lhs_off, len);
444 let bm = (b, rhs_off, len);
445 Some(and_masks(am, bm))
446 }
447 };
448
449 let mask_slice = merged_null_mask.as_ref().map(|m| (m, 0, len));
450
451 let data = match op {
452 ComparisonOperator::Equals
453 | ComparisonOperator::NotEquals
454 | ComparisonOperator::LessThan
455 | ComparisonOperator::LessThanOrEqualTo
456 | ComparisonOperator::GreaterThan
457 | ComparisonOperator::GreaterThanOrEqualTo
458 | ComparisonOperator::In
459 | ComparisonOperator::NotIn => {
460 #[cfg(feature = "simd")]
461 let res = cmp_bitmask_simd::<LANES>(
462 (&lhs_arr.data, lhs_off, len),
463 (&rhs_arr.data, rhs_off, len),
464 mask_slice,
465 op,
466 )?;
467 #[cfg(not(feature = "simd"))]
468 let res = cmp_bitmask_std(
469 (&lhs_arr.data, lhs_off, len),
470 (&rhs_arr.data, rhs_off, len),
471 mask_slice,
472 op,
473 )?;
474 res
475 }
476 ComparisonOperator::IsNull => {
477 #[cfg(feature = "simd")]
478 let data = match merged_null_mask.as_ref() {
479 Some(mask) => not_mask_simd::<LANES>((mask, 0, len)),
480 None => Bitmask::new_set_all(len, false),
481 };
482 #[cfg(not(feature = "simd"))]
483 let data = match merged_null_mask.as_ref() {
484 Some(mask) => not_mask((mask, 0, len)),
485 None => Bitmask::new_set_all(len, false),
486 };
487 return Ok(BooleanArray {
488 data,
489 null_mask: None,
490 len,
491 _phantom: PhantomData,
492 });
493 }
494 ComparisonOperator::IsNotNull => {
495 let data = match merged_null_mask.as_ref() {
496 Some(mask) => mask.slice_clone(0, len),
497 None => Bitmask::new_set_all(len, true),
498 };
499 return Ok(BooleanArray {
500 data,
501 null_mask: None,
502 len,
503 _phantom: PhantomData,
504 });
505 }
506 ComparisonOperator::Between => {
507 return Err(KernelError::InvalidArguments(
508 "Set operations are not defined for Bool arrays".to_owned(),
509 ));
510 }
511 };
512
513 Ok(BooleanArray {
514 data,
515 null_mask: merged_null_mask,
516 len,
517 _phantom: PhantomData,
518 })
519}
520
521#[cfg(not(feature = "simd"))]
529pub fn cmp_bitmask_std(
530 lhs: BitmaskVT<'_>,
531 rhs: BitmaskVT<'_>,
532 mask: Option<BitmaskVT<'_>>,
533 op: ComparisonOperator,
534) -> Result<Bitmask, KernelError> {
535 confirm_equal_len("compare bool length mismatch", lhs.2, rhs.2)?;
541 let (lhs_mask, lhs_offset, len) = lhs;
542 let (rhs_mask, rhs_offset, _) = rhs;
543
544 if matches!(op, ComparisonOperator::In | ComparisonOperator::NotIn) {
547 let mut out = match op {
548 ComparisonOperator::In => in_mask(lhs, rhs),
549 ComparisonOperator::NotIn => not_in_mask(lhs, rhs),
550 _ => unreachable!(),
551 };
552 if let Some(mask_slice) = mask {
553 out = and_masks((&out, 0, out.len), mask_slice);
554 }
555 return Ok(out);
556 }
557
558 if lhs_offset % 64 != 0
560 || rhs_offset % 64 != 0
561 || mask.as_ref().map_or(false, |(_, mo, _)| mo % 64 != 0)
562 {
563 return Err(KernelError::InvalidArguments(format!(
564 "cmp_bitmask: all offsets must be 64-bit aligned (lhs: {}, rhs: {}, mask offset: {:?})",
565 lhs_offset,
566 rhs_offset,
567 mask.as_ref().map(|(_, mo, _)| mo)
568 )));
569 }
570
571 let lhs_word_start = lhs_offset / 64;
573 let rhs_word_start = rhs_offset / 64;
574 let n_words = (len + 63) / 64;
575
576 let mut out = Bitmask::new_set_all(len, false);
578
579 let words = n_words;
580 let tail = len % 64;
581 let mask_mask_opt = mask;
582
583 for w in 0..words {
585 let a = unsafe { lhs_mask.word_unchecked(lhs_word_start + w) };
586 let b = unsafe { rhs_mask.word_unchecked(rhs_word_start + w) };
587 let valid_bits =
588 mask_mask_opt
589 .as_ref()
590 .map_or(!0u64, |(mask_mask, mask_word_start, _)| unsafe {
591 mask_mask.word_unchecked(mask_word_start + w)
592 });
593 let word_cmp = match op {
594 ComparisonOperator::Equals => !(a ^ b),
595 ComparisonOperator::NotEquals => a ^ b,
596 ComparisonOperator::GreaterThan => a & (!b),
597 ComparisonOperator::LessThan => (!a) & b,
598 ComparisonOperator::GreaterThanOrEqualTo => a | (!b),
599 ComparisonOperator::LessThanOrEqualTo => (!a) | b,
600 _ => 0,
601 };
602 let final_bits = word_cmp & valid_bits;
603 unsafe {
604 out.set_word_unchecked(w, final_bits);
605 }
606 }
607
608 let base = words * 64;
611 for i in 0..tail {
612 let idx_lhs = lhs_offset + base + i;
613 let idx_rhs = rhs_offset + base + i;
614 let mask_valid =
615 mask_mask_opt
616 .as_ref()
617 .map_or(true, |(mask_mask, mask_word_start, mask_len)| unsafe {
618 let mask_idx = mask_word_start * 64 + base + i;
619 if mask_idx < *mask_len {
620 mask_mask.get_unchecked(mask_idx)
621 } else {
622 false
623 }
624 });
625 if !mask_valid {
626 continue;
627 }
628 if idx_lhs >= lhs_mask.len() || idx_rhs >= rhs_mask.len() {
629 continue;
630 }
631 let a = unsafe { lhs_mask.get_unchecked(idx_lhs) };
632 let b = unsafe { rhs_mask.get_unchecked(idx_rhs) };
633 let res = match op {
634 ComparisonOperator::Equals => a == b,
635 ComparisonOperator::NotEquals => a != b,
636 ComparisonOperator::GreaterThan => a & !b,
637 ComparisonOperator::LessThan => !a & b,
638 ComparisonOperator::GreaterThanOrEqualTo => a | !b,
639 ComparisonOperator::LessThanOrEqualTo => !a | b,
640 _ => false,
641 };
642 if res {
643 out.set(base + i, true)
644 }
645 }
646 out.mask_trailing_bits();
647 Ok(out)
648}
649
/// Generates a string/categorical comparison kernel over two array windows.
///
/// Arguments: the function name, the lhs and rhs window types, and the generic
/// parameter list in brackets.
macro_rules! impl_cmp_utf8_slice {
    ($fn_name:ident, $lhs_slice:ty, $rhs_slice:ty, [$($gen:tt)+]) => {
        /// Compares two equal-length string-valued windows element-wise.
        ///
        /// Each window is `(array, offset, length)`. Strings are compared
        /// lexicographically via `str` ordering. Operators other than the six
        /// equality/ordering comparisons yield false.
        ///
        /// Rows that are null on either side keep a false data bit. The result's
        /// null mask is the combination of both windows' validity masks, or
        /// `None` when neither side has a mask.
        ///
        /// # Errors
        /// - The windows have different lengths.
        /// - A null mask's capacity is smaller than its window's `offset + length`.
        #[inline(always)]
        pub fn $fn_name<$($gen)+>(
            lhs: $lhs_slice,
            rhs: $rhs_slice,
            op: ComparisonOperator,
        ) -> Result<BooleanArray<()>, KernelError> {
            let (larr, loff, llen) = lhs;
            let (rarr, roff, rlen) = rhs;
            confirm_equal_len("compare string/dict length mismatch (slice contract)", llen, rlen)?;

            // Validate mask capacity *before* slicing, so an out-of-range window
            // reports an error instead of failing inside `slice_clone`.
            if let Some(m) = larr.null_mask.as_ref() {
                if m.capacity() < loff + llen {
                    return Err(KernelError::InvalidArguments(
                        format!(
                            "lhs mask capacity too small (expected ≥ {}, got {})",
                            loff + llen,
                            m.capacity()
                        ),
                    ));
                }
            }
            if let Some(m) = rarr.null_mask.as_ref() {
                if m.capacity() < roff + rlen {
                    return Err(KernelError::InvalidArguments(
                        format!(
                            "rhs mask capacity too small (expected ≥ {}, got {})",
                            roff + rlen,
                            m.capacity()
                        ),
                    ));
                }
            }

            let lhs_mask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
            let rhs_mask = rarr.null_mask.as_ref().map(|m| m.slice_clone(roff, rlen));

            let mut out = new_bool_bitmask(llen);
            for i in 0..llen {
                // SAFETY: both masks were sliced to exactly `llen` bits above.
                let valid = lhs_mask.as_ref().map_or(true, |m| unsafe { m.get_unchecked(i) })
                    && rhs_mask.as_ref().map_or(true, |m| unsafe { m.get_unchecked(i) });
                if !valid {
                    continue;
                }
                // SAFETY: assumes offset + length lies within each array (not
                // checked here).
                let l = unsafe { larr.get_str_unchecked(loff + i) };
                let r = unsafe { rarr.get_str_unchecked(roff + i) };
                let res = match op {
                    ComparisonOperator::Equals => l == r,
                    ComparisonOperator::NotEquals => l != r,
                    ComparisonOperator::GreaterThan => l > r,
                    ComparisonOperator::LessThan => l < r,
                    ComparisonOperator::GreaterThanOrEqualTo => l >= r,
                    ComparisonOperator::LessThanOrEqualTo => l <= r,
                    _ => false,
                };
                if res {
                    out.set(i, true);
                }
            }
            let null_mask = merge_bitmasks_to_new(lhs_mask.as_ref(), rhs_mask.as_ref(), llen);
            Ok(BooleanArray { data: out.into(), null_mask, len: llen, _phantom: PhantomData })
        }
    };
}
720
721impl_cmp_numeric!(cmp_i32, i32, W32, i32);
722impl_cmp_numeric!(cmp_u32, u32, W32, i32);
723impl_cmp_numeric!(cmp_i64, i64, W64, i64);
724impl_cmp_numeric!(cmp_u64, u64, W64, i64);
725impl_cmp_numeric!(cmp_f32, f32, W32, i32);
726impl_cmp_numeric!(cmp_f64, f64, W64, i64);
727impl_cmp_utf8_slice!(cmp_str_str, StringAVT<'a, T>, StringAVT<'a, T>, [ 'a, T: Integer ]);
728impl_cmp_utf8_slice!(cmp_str_dict, StringAVT<'a, T>, CategoricalAVT<'a, U>, [ 'a, T: Integer, U: Integer ]);
729impl_cmp_utf8_slice!(cmp_dict_str, CategoricalAVT<'a, T>, StringAVT<'a, U>, [ 'a, T: Integer, U: Integer ]);
730impl_cmp_utf8_slice!(cmp_dict_dict, CategoricalAVT<'a, T>, CategoricalAVT<'a, T>, [ 'a, T: Integer ]);
731
#[cfg(test)]
mod tests {
    use minarrow::{Bitmask, BooleanArray, CategoricalArray, Integer, StringArray, vec64};

    use crate::kernels::comparison::{
        cmp_dict_dict, cmp_dict_str, cmp_i32, cmp_numeric, cmp_str_dict,
    };

    #[cfg(feature = "simd")]
    use crate::kernels::comparison::{W64, cmp_bitmask_simd};

    use crate::operators::ComparisonOperator;

    /// Builds a bitmask whose bit `i` equals `bits[i]`.
    fn bm(bits: &[bool]) -> Bitmask {
        let mut mask = Bitmask::new_set_all(bits.len(), false);
        bits.iter()
            .copied()
            .enumerate()
            .for_each(|(idx, bit)| mask.set(idx, bit));
        mask
    }

    /// Asserts the array's length and data bits, and its null mask if expected.
    fn assert_bool(arr: &BooleanArray<()>, expect: &[bool], expect_mask: Option<&[bool]>) {
        assert_eq!(arr.len, expect.len());
        for (i, &want) in expect.iter().enumerate() {
            assert_eq!(arr.data.get(i), want, "value bit {i}");
        }
        match (arr.null_mask.as_ref(), expect_mask) {
            (None, None) => {}
            (Some(actual), Some(wanted)) => {
                for (i, &b) in wanted.iter().enumerate() {
                    assert_eq!(actual.get(i), b, "null-bit {i}");
                }
            }
            _ => panic!("mask mismatch"),
        }
    }

    /// Builds a string array from string slices.
    fn str_arr<T: Integer>(v: &[&str]) -> StringArray<T> {
        StringArray::<T>::from_slice(v)
    }

    /// Builds a categorical (dictionary) array from string values.
    fn dict_arr<T: Integer>(vals: &[&str]) -> CategoricalArray<T> {
        let values: Vec<&str> = vals.to_vec();
        CategoricalArray::<T>::from_values(values)
    }

    #[test]
    fn numeric_compare_no_nulls() {
        let a = vec64![1i32, 2, 3, 4];
        let b = vec64![1i32, 1, 4, 4];
        let run = |op| cmp_i32(&a, &b, None, op).unwrap();

        let cases = [
            (ComparisonOperator::Equals, [true, false, false, true]),
            (ComparisonOperator::NotEquals, [false, true, true, false]),
            (ComparisonOperator::LessThan, [false, false, true, false]),
            (ComparisonOperator::LessThanOrEqualTo, [true, false, true, true]),
            (ComparisonOperator::GreaterThan, [false, true, false, false]),
            (ComparisonOperator::GreaterThanOrEqualTo, [true, true, false, true]),
        ];
        for (op, expected) in cases {
            assert_bool(&run(op), &expected, None);
        }
    }

    #[test]
    fn numeric_compare_with_nulls_generic_dispatch() {
        let a = vec64![1u64, 5, 9, 10];
        let b = vec64![0u64, 5, 8, 11];
        let validity = bm(&[true, true, true, false]);

        let result = cmp_numeric(&a, &b, Some(&validity), ComparisonOperator::GreaterThan).unwrap();
        assert_bool(
            &result,
            &[true, false, true, false],
            Some(&[true, true, true, false]),
        );
    }

    #[cfg(feature = "simd")]
    #[test]
    fn bool_compare_all_ops() {
        let a = bm(&[true, false, true, false]);
        let b = bm(&[true, true, false, false]);
        let run = |op| {
            let bits = cmp_bitmask_simd::<W64>((&a, 0, a.len()), (&b, 0, b.len()), None, op)
                .unwrap();
            BooleanArray::<()>::from_bitmask(bits, None)
        };

        assert_bool(&run(ComparisonOperator::Equals), &[true, false, false, true], None);
        assert_bool(&run(ComparisonOperator::LessThan), &[false, true, false, false], None);
        assert_bool(&run(ComparisonOperator::GreaterThan), &[false, false, true, false], None);
    }

    #[test]
    fn string_vs_dict_compare_with_nulls() {
        let mut lhs = str_arr::<u32>(&["x", "y", "z"]);
        lhs.null_mask = Some(bm(&[true, false, true]));
        let rhs = dict_arr::<u32>(&["x", "w", "zz"]);

        let res = cmp_str_dict(
            (&lhs, 0, lhs.len()),
            (&rhs, 0, rhs.data.len()),
            ComparisonOperator::Equals,
        )
        .unwrap();
        assert_bool(&res, &[true, false, false], Some(&[true, false, true]));
    }

    #[test]
    fn string_vs_dict_compare_with_nulls_chunk() {
        let mut lhs = str_arr::<u32>(&["pad", "x", "y", "z", "pad"]);
        lhs.null_mask = Some(bm(&[true, true, false, true, true]));
        let rhs = dict_arr::<u32>(&["pad", "x", "w", "zz", "pad"]);

        let res = cmp_str_dict((&lhs, 1, 3), (&rhs, 1, 3), ComparisonOperator::Equals).unwrap();
        assert_bool(&res, &[true, false, false], Some(&[true, false, true]));
    }

    #[test]
    fn dict_vs_dict_compare_gt() {
        let lhs = dict_arr::<u32>(&["apple", "pear", "banana"]);
        let rhs = dict_arr::<u32>(&["ant", "pear", "apricot"]);

        let res = cmp_dict_dict(
            (&lhs, 0, lhs.data.len()),
            (&rhs, 0, rhs.data.len()),
            ComparisonOperator::GreaterThan,
        )
        .unwrap();
        assert_bool(&res, &[true, false, true], None);
    }

    #[test]
    fn dict_vs_dict_compare_gt_chunk() {
        let lhs = dict_arr::<u32>(&["pad", "apple", "pear", "banana", "pad"]);
        let rhs = dict_arr::<u32>(&["pad", "ant", "pear", "apricot", "pad"]);

        let res =
            cmp_dict_dict((&lhs, 1, 3), (&rhs, 1, 3), ComparisonOperator::GreaterThan).unwrap();
        assert_bool(&res, &[true, false, true], None);
    }

    #[test]
    fn dict_vs_string_compare_le() {
        let lhs = dict_arr::<u32>(&["a", "b", "c"]);
        let rhs = str_arr::<u32>(&["b", "b", "d"]);

        let res = cmp_dict_str(
            (&lhs, 0, lhs.data.len()),
            (&rhs, 0, rhs.len()),
            ComparisonOperator::LessThanOrEqualTo,
        )
        .unwrap();
        assert_bool(&res, &[true, true, true], None);
    }

    #[test]
    fn dict_vs_string_compare_le_chunk() {
        let lhs = dict_arr::<u32>(&["pad", "a", "b", "c", "pad"]);
        let rhs = str_arr::<u32>(&["pad", "b", "b", "d", "pad"]);

        let res = cmp_dict_str((&lhs, 1, 3), (&rhs, 1, 3), ComparisonOperator::LessThanOrEqualTo)
            .unwrap();
        assert_bool(&res, &[true, true, true], None);
    }
}