1include!(concat!(env!("OUT_DIR"), "/simd_lanes.rs"));
27
28use std::hash::Hash;
29use std::marker::PhantomData;
30
31use minarrow::traits::type_unions::Float;
32use minarrow::utils::confirm_equal_len;
33use minarrow::{Bitmask, BooleanAVT, BooleanArray, CategoricalAVT, Integer, Numeric, StringAVT};
34
35#[cfg(not(feature = "simd"))]
36use crate::kernels::bitmask::std::{and_masks, not_mask};
37#[cfg(not(feature = "simd"))]
38use crate::kernels::comparison::cmp_bitmask_std;
39use crate::kernels::comparison::{
40 cmp_dict_dict, cmp_dict_str, cmp_numeric, cmp_str_dict, cmp_str_str,
41};
42use crate::kernels::logical::{
43 cmp_between, cmp_dict_between, cmp_dict_in, cmp_in, cmp_in_f, cmp_str_between, cmp_str_in,
44};
45use crate::operators::ComparisonOperator;
46use minarrow::enums::error::KernelError;
47
48#[inline(always)]
50fn new_bool_bitmask(len_bits: usize) -> Bitmask {
51 Bitmask::new_set_all(len_bits, false)
52}
53
54#[inline(always)]
56fn full_bool_bitmask(len_bits: usize) -> Bitmask {
57 Bitmask::new_set_all(len_bits, true)
58}
59
60#[inline]
63fn merge_bitmasks_to_new(
64 lhs: Option<&Bitmask>,
65 rhs: Option<&Bitmask>,
66 len: usize,
67) -> Option<Bitmask> {
68 if let Some(m) = lhs {
69 debug_assert!(
70 m.capacity() >= len,
71 "lhs null mask too small: capacity {} < len {}",
72 m.capacity(),
73 len
74 );
75 }
76 if let Some(m) = rhs {
77 debug_assert!(
78 m.capacity() >= len,
79 "rhs null mask too small: capacity {} < len {}",
80 m.capacity(),
81 len
82 );
83 }
84
85 match (lhs, rhs) {
86 (None, None) => None,
87
88 (Some(l), None) | (None, Some(l)) => {
89 let mut out = Bitmask::new_set_all(len, true);
90 for i in 0..len {
91 unsafe {
92 out.set_unchecked(i, l.get_unchecked(i));
93 }
94 }
95 Some(out)
96 }
97
98 (Some(l), Some(r)) => {
99 let mut out = Bitmask::new_set_all(len, true);
100 for i in 0..len {
101 unsafe {
102 out.set_unchecked(i, l.get_unchecked(i) && r.get_unchecked(i));
103 }
104 }
105 Some(out)
106 }
107 }
108}
109
110pub fn apply_cmp<T>(
146 lhs: &[T],
147 rhs: &[T],
148 mask: Option<&Bitmask>,
149 op: ComparisonOperator,
150) -> Result<BooleanArray<()>, KernelError>
151where
152 T: Numeric + Copy + Hash + Eq + PartialOrd + 'static,
153{
154 let len = lhs.len();
155 match op {
156 ComparisonOperator::Between => {
157 let mut out = cmp_between(lhs, rhs)?;
158 out.null_mask = mask.cloned();
159 Ok(out)
160 }
161 ComparisonOperator::In => {
162 let mut out = cmp_in(lhs, rhs)?;
163 out.null_mask = mask.cloned();
164 Ok(out)
165 }
166 ComparisonOperator::NotIn => {
167 let mut out = cmp_in(lhs, rhs)?;
168 for i in 0..len {
169 unsafe { out.data.set_unchecked(i, !out.data.get_unchecked(i)) };
170 }
171 out.null_mask = mask.cloned();
172 Ok(out)
173 }
174 ComparisonOperator::IsNull => Ok(BooleanArray {
175 data: new_bool_bitmask(len),
176 null_mask: mask.cloned(),
177 len,
178 _phantom: std::marker::PhantomData,
179 }),
180 ComparisonOperator::IsNotNull => Ok(BooleanArray {
181 data: full_bool_bitmask(len),
182 null_mask: mask.cloned(),
183 len,
184 _phantom: std::marker::PhantomData,
185 }),
186 _ => {
187 let mut out = cmp_numeric(lhs, rhs, mask, op)?;
188 out.null_mask = mask.cloned();
189 Ok(out)
190 }
191 }
192}
193
194pub fn apply_cmp_f<T>(
212 lhs: &[T],
213 rhs: &[T],
214 mask: Option<&Bitmask>,
215 op: ComparisonOperator,
216) -> Result<BooleanArray<()>, KernelError>
217where
218 T: Float + Numeric + Copy + 'static,
219{
220 let len = lhs.len();
221 match op {
222 ComparisonOperator::Between => {
223 let mut out = cmp_between(lhs, rhs)?;
224 out.null_mask = mask.cloned();
225 Ok(out)
226 }
227 ComparisonOperator::In => {
228 let mut out = cmp_in_f(lhs, rhs)?;
229 out.null_mask = mask.cloned();
230 Ok(out)
231 }
232 ComparisonOperator::NotIn => {
233 let mut out = cmp_in_f(lhs, rhs)?;
234 for i in 0..len {
235 unsafe { out.data.set_unchecked(i, !out.data.get_unchecked(i)) };
236 }
237 out.null_mask = mask.cloned();
238 Ok(out)
239 }
240 ComparisonOperator::IsNull => Ok(BooleanArray {
241 data: new_bool_bitmask(len),
242 null_mask: mask.cloned(),
243 len,
244 _phantom: std::marker::PhantomData,
245 }),
246 ComparisonOperator::IsNotNull => Ok(BooleanArray {
247 data: full_bool_bitmask(len),
248 null_mask: mask.cloned(),
249 len,
250 _phantom: std::marker::PhantomData,
251 }),
252 ComparisonOperator::Equals | ComparisonOperator::NotEquals => {
253 let mut out = cmp_numeric(lhs, rhs, mask, op)?;
254 for i in 0..len {
256 let is_valid = mask.map_or(true, |m| unsafe { m.get_unchecked(i) });
257 if is_valid && lhs[i].is_nan() && rhs[i].is_nan() {
258 match op {
259 ComparisonOperator::Equals => unsafe { out.data.set_unchecked(i, true) },
260 ComparisonOperator::NotEquals => unsafe {
261 out.data.set_unchecked(i, false)
262 },
263 _ => {}
264 }
265 }
266 }
267 out.null_mask = mask.cloned();
268 Ok(out)
269 }
270 _ => {
271 let mut out = cmp_numeric(lhs, rhs, mask, op)?;
272 out.null_mask = mask.cloned();
273 Ok(out)
274 }
275 }
276}
277
278#[inline(always)]
285pub fn apply_cmp_bool(
286 lhs: BooleanAVT<'_, ()>,
287 rhs: BooleanAVT<'_, ()>,
288 op: ComparisonOperator,
289) -> Result<BooleanArray<()>, KernelError> {
290 let (lhs_arr, lhs_off, len) = lhs;
291 let (rhs_arr, rhs_off, rlen) = rhs;
292 confirm_equal_len("apply_cmp_bool_windowed: window length mismatch", len, rlen)?;
293
294 #[cfg(feature = "simd")]
296 let merged_null_mask: Option<Bitmask> =
297 match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
298 (None, None) => None,
299 (Some(m), None) | (None, Some(m)) => Some(m.slice_clone(lhs_off, len)),
300 (Some(a), Some(b)) => {
301 use minarrow::kernels::bitmask::simd::and_masks_simd;
302 let am = (a, lhs_off, len);
303 let bm = (b, rhs_off, len);
304 Some(and_masks_simd::<W8>(am, bm))
305 }
306 };
307
308 #[cfg(not(feature = "simd"))]
309 let merged_null_mask: Option<Bitmask> =
310 match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
311 (None, None) => None,
312 (Some(m), None) | (None, Some(m)) => Some(m.slice_clone(lhs_off, len)),
313 (Some(a), Some(b)) => {
314 let am = (a, lhs_off, len);
315 let bm = (b, rhs_off, len);
316 Some(and_masks(am, bm))
317 }
318 };
319
320 let mask_slice = merged_null_mask.as_ref().map(|m| (m, 0, len));
321
322 #[cfg(feature = "simd")]
323 let data = match op {
324 ComparisonOperator::Equals
325 | ComparisonOperator::NotEquals
326 | ComparisonOperator::LessThan
327 | ComparisonOperator::LessThanOrEqualTo
328 | ComparisonOperator::GreaterThan
329 | ComparisonOperator::GreaterThanOrEqualTo
330 | ComparisonOperator::In
331 | ComparisonOperator::NotIn => crate::kernels::comparison::cmp_bitmask_simd::<W8>(
332 (&lhs_arr.data, lhs_off, len),
333 (&rhs_arr.data, rhs_off, len),
334 mask_slice,
335 op,
336 )?,
337 ComparisonOperator::IsNull => {
338 let data = match merged_null_mask.as_ref() {
339 Some(mask) => minarrow::kernels::bitmask::simd::not_mask_simd::<W8>((mask, 0, len)),
340 None => Bitmask::new_set_all(len, false),
341 };
342 return Ok(BooleanArray {
343 data,
344 null_mask: None,
345 len,
346 _phantom: PhantomData,
347 });
348 }
349 ComparisonOperator::IsNotNull => {
350 let data = match merged_null_mask.as_ref() {
351 Some(mask) => mask.slice_clone(0, len),
352 None => Bitmask::new_set_all(len, true),
353 };
354 return Ok(BooleanArray {
355 data,
356 null_mask: None,
357 len,
358 _phantom: PhantomData,
359 });
360 }
361 ComparisonOperator::Between => {
362 return Err(KernelError::InvalidArguments(
363 "Set operations are not defined for Bool arrays".to_owned(),
364 ));
365 }
366 };
367
368 #[cfg(not(feature = "simd"))]
369 let data = match op {
370 ComparisonOperator::Equals
371 | ComparisonOperator::NotEquals
372 | ComparisonOperator::LessThan
373 | ComparisonOperator::LessThanOrEqualTo
374 | ComparisonOperator::GreaterThan
375 | ComparisonOperator::GreaterThanOrEqualTo
376 | ComparisonOperator::In
377 | ComparisonOperator::NotIn => cmp_bitmask_std(
378 (&lhs_arr.data, lhs_off, len),
379 (&rhs_arr.data, rhs_off, len),
380 mask_slice,
381 op,
382 )?,
383 ComparisonOperator::IsNull => {
384 let data = match merged_null_mask.as_ref() {
385 Some(mask) => not_mask((mask, 0, len)),
386 None => Bitmask::new_set_all(len, false),
387 };
388 return Ok(BooleanArray {
389 data,
390 null_mask: None,
391 len,
392 _phantom: PhantomData,
393 });
394 }
395 ComparisonOperator::IsNotNull => {
396 let data = match merged_null_mask.as_ref() {
397 Some(mask) => mask.slice_clone(0, len),
398 None => Bitmask::new_set_all(len, true),
399 };
400 return Ok(BooleanArray {
401 data,
402 null_mask: None,
403 len,
404 _phantom: PhantomData,
405 });
406 }
407 ComparisonOperator::Between => {
408 return Err(KernelError::InvalidArguments(
409 "Set operations are not defined for Bool arrays".to_owned(),
410 ));
411 }
412 };
413
414 Ok(BooleanArray {
415 data,
416 null_mask: merged_null_mask,
417 len,
418 _phantom: PhantomData,
419 })
420}
421
422pub fn apply_cmp_str<T: Integer>(
447 lhs: StringAVT<T>,
448 rhs: StringAVT<T>,
449 op: ComparisonOperator,
450) -> Result<BooleanArray<()>, KernelError> {
451 let (larr, loff, llen) = lhs;
453 let (rarr, roff, rlen) = rhs;
454
455 assert_eq!(llen, rlen, "apply_cmp_str: slice lengths must match");
456
457 let null_mask = merge_bitmasks_to_new(larr.null_mask.as_ref(), rarr.null_mask.as_ref(), llen);
458
459 let mut out = match op {
460 ComparisonOperator::Between => cmp_str_between((larr, loff, llen), (rarr, roff, rlen)),
461 ComparisonOperator::In => cmp_str_in((larr, loff, llen), (rarr, roff, rlen)),
462 ComparisonOperator::NotIn => {
463 let mut b = cmp_str_in((larr, loff, llen), (rarr, roff, rlen))?;
464 debug_assert!(
465 b.data.capacity() >= llen,
466 "bitmask capacity {} < needed len {}",
467 b.data.capacity(),
468 llen
469 );
470 for i in 0..llen {
471 unsafe { b.data.set_unchecked(i, !b.data.get_unchecked(i)) };
472 }
473 Ok(b)
474 }
475 ComparisonOperator::IsNull => Ok(BooleanArray {
476 data: new_bool_bitmask(llen),
477 null_mask: null_mask.clone(),
478 len: llen,
479 _phantom: std::marker::PhantomData,
480 }),
481 ComparisonOperator::IsNotNull => Ok(BooleanArray {
482 data: full_bool_bitmask(llen),
483 null_mask: null_mask.clone(),
484 len: llen,
485 _phantom: std::marker::PhantomData,
486 }),
487 _ => cmp_str_str((larr, loff, llen), (rarr, roff, rlen), op),
488 }?;
489 out.null_mask = null_mask;
490 out.len = llen;
491 Ok(out)
492}
493
494pub fn apply_cmp_str_dict<T: Integer, U: Integer>(
514 lhs: StringAVT<T>,
515 rhs: CategoricalAVT<U>,
516 op: ComparisonOperator,
517) -> Result<BooleanArray<()>, KernelError> {
518 let (larr, loff, llen) = lhs;
519 let (rarr, roff, rlen) = rhs;
520 assert_eq!(llen, rlen, "apply_cmp_str_dict: slice lengths must match");
521
522 let lmask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
524 let rmask = rarr.null_mask.as_ref().map(|m| m.slice_clone(roff, rlen));
525 let null_mask = merge_bitmasks_to_new(lmask.as_ref(), rmask.as_ref(), llen);
526
527 let mut out = cmp_str_dict((larr, loff, llen), (rarr, roff, rlen), op)?;
528 out.null_mask = null_mask;
529 out.len = llen;
530 Ok(out)
531}
532
533pub fn apply_cmp_dict_str<T: Integer, U: Integer>(
553 lhs: CategoricalAVT<T>,
554 rhs: StringAVT<U>,
555 op: ComparisonOperator,
556) -> Result<BooleanArray<()>, KernelError> {
557 let (larr, loff, llen) = lhs;
558 let (rarr, roff, rlen) = rhs;
559 assert_eq!(llen, rlen, "apply_cmp_dict_str: slice lengths must match");
560
561 let lmask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
563 let rmask = rarr.null_mask.as_ref().map(|m| m.slice_clone(roff, rlen));
564 let null_mask = merge_bitmasks_to_new(lmask.as_ref(), rmask.as_ref(), llen);
565
566 let mut out = cmp_dict_str((larr, loff, llen), (rarr, roff, rlen), op)?;
567 out.null_mask = null_mask;
568 out.len = llen;
569 Ok(out)
570}
571
572pub fn apply_cmp_dict<T: Integer + Hash>(
593 lhs: CategoricalAVT<T>,
594 rhs: CategoricalAVT<T>,
595 op: ComparisonOperator,
596) -> Result<BooleanArray<()>, KernelError> {
597 let (larr, loff, llen) = lhs;
598 let (rarr, roff, rlen) = rhs;
599 assert_eq!(llen, rlen, "apply_cmp_dict: slice lengths must match");
600 let null_mask = merge_bitmasks_to_new(larr.null_mask.as_ref(), rarr.null_mask.as_ref(), llen);
601 let mut out = match op {
602 ComparisonOperator::Between => cmp_dict_between((larr, loff, llen), (rarr, roff, rlen)),
603 ComparisonOperator::In => cmp_dict_in((larr, loff, llen), (rarr, roff, rlen)),
604 ComparisonOperator::NotIn => {
605 let mut b = cmp_dict_in((larr, loff, llen), (rarr, roff, rlen))?;
606 for i in 0..llen {
607 unsafe {
608 b.data.set_unchecked(i, !b.data.get_unchecked(i));
609 }
610 }
611 Ok(b)
612 }
613 ComparisonOperator::IsNull => Ok(BooleanArray {
614 data: new_bool_bitmask(llen),
615 null_mask: null_mask.clone(),
616 len: llen,
617 _phantom: std::marker::PhantomData,
618 }),
619 ComparisonOperator::IsNotNull => Ok(BooleanArray {
620 data: full_bool_bitmask(llen),
621 null_mask: null_mask.clone(),
622 len: llen,
623 _phantom: std::marker::PhantomData,
624 }),
625 _ => cmp_dict_dict((larr, loff, llen), (rarr, roff, rlen), op),
626 }?;
627 out.null_mask = null_mask;
628 out.len = llen;
629 Ok(out)
630}
631
#[cfg(test)]
mod tests {
    use minarrow::structs::variants::categorical::CategoricalArray;
    use minarrow::structs::variants::string::StringArray;
    use minarrow::{Bitmask, BooleanArray, MaskedArray, vec64};

    use super::*;

    /// Builds a bitmask from a slice of booleans.
    fn bm(bools: &[bool]) -> Bitmask {
        Bitmask::from_bools(bools)
    }

    /// Builds a boolean array (no null mask) from a slice of booleans.
    fn bool_arr(bools: &[bool]) -> BooleanArray<()> {
        BooleanArray::from_slice(bools)
    }

    #[test]
    fn test_apply_cmp_numeric_all_ops() {
        let a = vec64![1, 2, 3, 4, 5, 6];
        let b = vec64![3, 2, 1, 4, 5, 0];
        let mask = bm(&[true, false, true, true, true, true]);

        for &op in &[
            ComparisonOperator::Equals,
            ComparisonOperator::NotEquals,
            ComparisonOperator::LessThan,
            ComparisonOperator::LessThanOrEqualTo,
            ComparisonOperator::GreaterThan,
            ComparisonOperator::GreaterThanOrEqualTo,
        ] {
            let arr = apply_cmp(&a, &b, Some(&mask), op).unwrap();
            for i in 0..a.len() {
                let expect = match op {
                    ComparisonOperator::Equals => a[i] == b[i],
                    ComparisonOperator::NotEquals => a[i] != b[i],
                    ComparisonOperator::LessThan => a[i] < b[i],
                    ComparisonOperator::LessThanOrEqualTo => a[i] <= b[i],
                    ComparisonOperator::GreaterThan => a[i] > b[i],
                    ComparisonOperator::GreaterThanOrEqualTo => a[i] >= b[i],
                    _ => unreachable!(),
                };
                if mask.get(i) {
                    assert_eq!(arr.data.get(i), expect);
                } else {
                    assert_eq!(arr.get(i), None);
                }
            }
            assert_eq!(arr.null_mask, Some(mask.clone()));
        }
    }

    #[test]
    fn test_apply_cmp_numeric_between_in_notin() {
        let a = vec64![4, 2, 3, 5];
        let mask = bm(&[true, true, false, true]);

        let rhs = vec64![2, 4];
        let arr = apply_cmp(&a, &rhs, Some(&mask), ComparisonOperator::Between).unwrap();
        assert_eq!(arr.data.get(0), true);
        assert_eq!(arr.data.get(1), true);
        assert_eq!(arr.get(2), None);
        assert_eq!(arr.data.get(3), false);

        let rhs = vec64![2, 3, 4];
        let arr = apply_cmp(&a, &rhs, Some(&mask), ComparisonOperator::In).unwrap();
        assert_eq!(arr.data.get(0), true);
        assert_eq!(arr.data.get(1), true);
        assert_eq!(arr.get(2), None);
        assert_eq!(arr.data.get(3), false);

        let arr = apply_cmp(&a, &rhs, Some(&mask), ComparisonOperator::NotIn).unwrap();
        assert_eq!(arr.data.get(0), false);
        assert_eq!(arr.data.get(1), false);
        assert_eq!(arr.get(2), None);
        assert_eq!(arr.data.get(3), true);
    }

    #[test]
    fn test_apply_cmp_numeric_isnull_isnotnull() {
        let a = vec64![1, 2, 3];
        let mask = bm(&[true, false, true]);
        let arr = apply_cmp(&a, &a, Some(&mask), ComparisonOperator::IsNull).unwrap();
        assert_eq!(arr.data.get(0), false);
        assert_eq!(arr.data.get(1), false);
        assert_eq!(arr.data.get(2), false);
        assert_eq!(arr.null_mask, Some(mask.clone()));
        let arr = apply_cmp(&a, &a, Some(&mask), ComparisonOperator::IsNotNull).unwrap();
        assert_eq!(arr.data.get(0), true);
        assert_eq!(arr.data.get(1), true);
        assert_eq!(arr.data.get(2), true);
        assert_eq!(arr.null_mask, Some(mask.clone()));
    }

    #[test]
    fn test_apply_cmp_numeric_edge_cases() {
        let a: [i32; 0] = [];
        let arr = apply_cmp(&a, &a, None, ComparisonOperator::Equals).unwrap();
        assert_eq!(arr.len, 0);
        assert!(arr.null_mask.is_none());

        let a = vec64![7];
        let arr = apply_cmp(&a, &a, None, ComparisonOperator::Equals).unwrap();
        assert_eq!(arr.data.get(0), true);
        assert!(arr.null_mask.is_none());
    }

    #[test]
    fn test_apply_cmp_f_all_ops_nan_patch() {
        let a = vec64![1.0, 2.0, f32::NAN, f32::NAN];
        let b = vec64![1.0, 3.0, f32::NAN, 0.0];
        let mask = bm(&[true, true, true, false]);
        for &op in &[ComparisonOperator::Equals, ComparisonOperator::NotEquals] {
            let arr = apply_cmp_f(&a, &b, Some(&mask), op).unwrap();
            assert_eq!(arr.data.get(2), matches!(op, ComparisonOperator::Equals));
        }
        let arr = apply_cmp_f(&a, &b, Some(&mask), ComparisonOperator::In).unwrap();
        assert_eq!(arr.data.get(0), true);
        assert_eq!(arr.data.get(1), false);
    }

    #[test]
    fn test_cmp_bool_w8() {
        let a = bool_arr(&[true, false, true]);
        let b = bool_arr(&[false, false, true]);
        let op = ComparisonOperator::Equals;
        let arr = apply_cmp_bool((&a, 0, a.len()), (&b, 0, b.len()), op).unwrap();
        assert!(!arr.data.get(0));
        assert!(arr.data.get(1));
        assert!(arr.data.get(2));

        let arr = apply_cmp_bool(
            (&a, 0, a.len()),
            (&b, 0, b.len()),
            ComparisonOperator::NotEquals,
        )
        .unwrap();
        assert!(arr.data.get(0));
        assert!(!arr.data.get(1));
        assert!(!arr.data.get(2));

        let arr = apply_cmp_bool(
            (&a, 0, a.len()),
            (&b, 0, b.len()),
            ComparisonOperator::LessThan,
        )
        .unwrap();
        assert!(!arr.data.get(0));
        assert!(!arr.data.get(1));
        assert!(!arr.data.get(2));

        let mut a = bool_arr(&[true, false]);
        a.null_mask = Some(bm(&[true, false]));
        let mut b = bool_arr(&[true, false]);
        b.null_mask = Some(bm(&[true, true]));
        let arr = apply_cmp_bool(
            (&a, 0, a.len()),
            (&b, 0, b.len()),
            ComparisonOperator::Equals,
        )
        .unwrap();
        assert!(arr.null_mask.as_ref().unwrap().get(0));
        assert!(!arr.null_mask.as_ref().unwrap().get(1));
    }

    #[test]
    fn test_bool_is_null() {
        let a = bool_arr(&[true, false]);
        let b = bool_arr(&[false, true]);
        let arr = apply_cmp_bool(
            (&a, 0, a.len()),
            (&b, 0, b.len()),
            ComparisonOperator::IsNull,
        )
        .unwrap();
        assert!(!arr.data.get(0));
        assert!(!arr.data.get(1));
        let arr = apply_cmp_bool(
            (&a, 0, a.len()),
            (&b, 0, b.len()),
            ComparisonOperator::IsNotNull,
        )
        .unwrap();
        assert!(arr.data.get(0));
        assert!(arr.data.get(1));
    }

    #[test]
    fn test_apply_cmp_str_all_ops() {
        let a = StringArray::<u32>::from_slice(&["foo", "bar", "baz", "qux"]);
        let b = StringArray::<u32>::from_slice(&["foo", "baz", "baz", "quux"]);
        let mut a2 = a.clone();
        a2.set_null(2);
        let a_slice = (&a, 0, a.len());
        let b_slice = (&b, 0, b.len());
        let a2_slice = (&a2, 0, a2.len());

        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::Equals).unwrap();
        assert_eq!(arr.data.get(0), true);
        assert_eq!(arr.data.get(1), false);
        assert_eq!(arr.data.get(2), true);
        assert_eq!(arr.data.get(3), false);

        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::NotEquals).unwrap();
        assert_eq!(arr.data.get(0), false);
        assert_eq!(arr.data.get(1), true);
        assert_eq!(arr.data.get(2), false);
        assert_eq!(arr.data.get(3), true);

        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::LessThan).unwrap();
        assert_eq!(arr.data.get(0), false);
        assert_eq!(arr.data.get(1), true);
        assert_eq!(arr.data.get(2), false);
        assert_eq!(arr.data.get(3), false);

        let mut b2 = b.clone();
        b2.set_null(1);
        let b2_slice = (&b2, 0, b2.len());
        let arr = apply_cmp_str(a2_slice, b2_slice, ComparisonOperator::Equals).unwrap();
        assert!(!arr.null_mask.as_ref().unwrap().get(2));
        assert!(!arr.null_mask.as_ref().unwrap().get(1));
        assert!(arr.null_mask.as_ref().unwrap().get(0));
        assert!(arr.null_mask.as_ref().unwrap().get(3));
    }

    #[test]
    fn test_apply_cmp_str_all_ops_chunk() {
        let a = StringArray::<u32>::from_slice(&["x", "foo", "bar", "baz", "qux", "y"]);
        let b = StringArray::<u32>::from_slice(&["q", "foo", "baz", "baz", "quux", "z"]);
        let a_slice = (&a, 1, 4);
        let b_slice = (&b, 1, 4);
        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::Equals).unwrap();
        assert_eq!(arr.data.get(0), true);
        assert_eq!(arr.data.get(1), false);
        assert_eq!(arr.data.get(2), true);
        assert_eq!(arr.data.get(3), false);
    }

    #[test]
    fn test_apply_cmp_str_set_ops() {
        let a = StringArray::<u32>::from_slice(&["foo", "bar", "baz"]);
        let b = StringArray::<u32>::from_slice(&["foo", "qux", "baz"]);
        let a_slice = (&a, 0, a.len());
        let b_slice = (&b, 0, b.len());
        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::Between).unwrap();
        assert_eq!(arr.len, 3);
        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::In).unwrap();
        let arr2 = apply_cmp_str(a_slice, b_slice, ComparisonOperator::NotIn).unwrap();
        for i in 0..a.len() {
            assert_eq!(arr.data.get(i), !arr2.data.get(i));
        }
    }

    #[test]
    fn test_apply_cmp_str_set_ops_chunk() {
        let a = StringArray::<u32>::from_slice(&["foo", "bar", "baz", "w"]);
        let b = StringArray::<u32>::from_slice(&["foo", "qux", "baz", "w"]);
        let a_slice = (&a, 1, 2);
        let b_slice = (&b, 1, 2);
        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::Between).unwrap();
        assert_eq!(arr.len, 2);
        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::In).unwrap();
        let arr2 = apply_cmp_str(a_slice, b_slice, ComparisonOperator::NotIn).unwrap();
        for i in 0..2 {
            assert_eq!(arr.data.get(i), !arr2.data.get(i));
        }
    }

    #[test]
    fn test_apply_cmp_str_isnull_isnotnull() {
        let a = StringArray::<u32>::from_slice(&["foo"]);
        let b = StringArray::<u32>::from_slice(&["bar"]);
        let a_slice = (&a, 0, a.len());
        let b_slice = (&b, 0, b.len());
        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::IsNull).unwrap();
        assert_eq!(arr.data.get(0), false);
        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::IsNotNull).unwrap();
        assert_eq!(arr.data.get(0), true);
    }

    #[test]
    fn test_apply_cmp_str_isnull_isnotnull_chunk() {
        let a = StringArray::<u32>::from_slice(&["pad", "foo"]);
        let b = StringArray::<u32>::from_slice(&["pad", "bar"]);
        let a_slice = (&a, 1, 1);
        let b_slice = (&b, 1, 1);
        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::IsNull).unwrap();
        assert_eq!(arr.data.get(0), false);
        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::IsNotNull).unwrap();
        assert_eq!(arr.data.get(0), true);
    }

    #[test]
    fn test_apply_cmp_str_dict() {
        let s = StringArray::<u32>::from_slice(&["a", "b", "c"]);
        let dict = CategoricalArray::<u32>::from_slices(&[0, 1, 0], &["a".into(), "b".into()]);

        let s_slice = (&s, 0, s.len());
        let dict_slice = (&dict, 0, dict.data.len());
        let arr = apply_cmp_str_dict(s_slice, dict_slice, ComparisonOperator::Equals).unwrap();
        assert_eq!(arr.len, 3);

        let mut s2 = s.clone();
        s2.set_null(0);
        let mut d2 = dict.clone();
        d2.set_null(1);

        let s2_slice = (&s2, 0, s2.len());
        let d2_slice = (&d2, 0, d2.data.len());
        let arr = apply_cmp_str_dict(s2_slice, d2_slice, ComparisonOperator::Equals).unwrap();

        let mask = arr.null_mask.as_ref().unwrap();
        assert!(!mask.get(0));
        assert!(!mask.get(1));
        assert!(mask.get(2));
    }

    #[test]
    fn test_apply_cmp_str_dict_chunk() {
        let s = StringArray::<u32>::from_slice(&["pad", "a", "b", "c", "pad2"]);
        let dict = CategoricalArray::<u32>::from_slices(
            &[2, 0, 1, 0, 2],
            &["z".into(), "a".into(), "b".into()],
        );
        let s_slice = (&s, 1, 3);
        let dict_slice = (&dict, 1, 3);
        let arr = apply_cmp_str_dict(s_slice, dict_slice, ComparisonOperator::Equals).unwrap();
        assert_eq!(arr.len, 3);
    }

    #[test]
    fn test_apply_cmp_dict_str() {
        let dict = CategoricalArray::<u32>::from_slices(&[0, 1, 0], &["a".into(), "b".into()]);
        let s = StringArray::<u32>::from_slice(&["a", "b", "c"]);
        let dict_slice = (&dict, 0, dict.data.len());
        let s_slice = (&s, 0, s.len());
        let arr = apply_cmp_dict_str(dict_slice, s_slice, ComparisonOperator::Equals).unwrap();
        assert_eq!(arr.len, 3);
    }

    #[test]
    fn test_apply_cmp_dict_str_chunk() {
        let dict = CategoricalArray::<u32>::from_slices(
            &[2, 0, 1, 0, 2],
            &["z".into(), "a".into(), "b".into()],
        );
        let s = StringArray::<u32>::from_slice(&["pad", "a", "b", "c", "pad2"]);
        let dict_slice = (&dict, 1, 3);
        let s_slice = (&s, 1, 3);
        let arr = apply_cmp_dict_str(dict_slice, s_slice, ComparisonOperator::Equals).unwrap();
        assert_eq!(arr.len, 3);
    }

    #[test]
    fn test_apply_cmp_dict_all_ops() {
        let a = CategoricalArray::<u32>::from_slices(
            &[0, 1, 2],
            &["dog".into(), "cat".into(), "fish".into()],
        );
        let b = CategoricalArray::<u32>::from_slices(
            &[2, 1, 0],
            &["fish".into(), "cat".into(), "dog".into()],
        );

        let a_slice = (&a, 0, a.data.len());
        let b_slice = (&b, 0, b.data.len());

        for &op in &[
            ComparisonOperator::Equals,
            ComparisonOperator::NotEquals,
            ComparisonOperator::LessThan,
            ComparisonOperator::LessThanOrEqualTo,
            ComparisonOperator::GreaterThan,
            ComparisonOperator::GreaterThanOrEqualTo,
        ] {
            let arr = apply_cmp_dict(a_slice, b_slice, op).unwrap();
            assert_eq!(arr.len, 3);
        }
        let arr = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::Between).unwrap();
        assert_eq!(arr.len, 3);
        let arr2 = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::In).unwrap();
        let arr3 = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::NotIn).unwrap();
        for i in 0..3 {
            assert_eq!(arr2.data.get(i), !arr3.data.get(i));
        }
    }

    #[test]
    fn test_apply_cmp_dict_all_ops_chunk() {
        let a = CategoricalArray::<u32>::from_slices(
            &[0, 1, 2, 3, 1],
            &["pad".into(), "dog".into(), "cat".into(), "fish".into()],
        );
        let b = CategoricalArray::<u32>::from_slices(
            &[3, 2, 1, 0, 2],
            &["foo".into(), "fish".into(), "cat".into(), "dog".into()],
        );
        let a_slice = (&a, 1, 3);
        let b_slice = (&b, 1, 3);

        for &op in &[
            ComparisonOperator::Equals,
            ComparisonOperator::NotEquals,
            ComparisonOperator::LessThan,
            ComparisonOperator::LessThanOrEqualTo,
            ComparisonOperator::GreaterThan,
            ComparisonOperator::GreaterThanOrEqualTo,
        ] {
            let arr = apply_cmp_dict(a_slice, b_slice, op).unwrap();
            assert_eq!(arr.len, 3);
        }
        let arr = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::Between).unwrap();
        assert_eq!(arr.len, 3);
        let arr2 = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::In).unwrap();
        let arr3 = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::NotIn).unwrap();
        for i in 0..3 {
            assert_eq!(arr2.data.get(i), !arr3.data.get(i));
        }
    }

    #[test]
    fn test_apply_cmp_dict_isnull_isnotnull() {
        let a = CategoricalArray::<u32>::from_slices(&[0, 1], &["x".into(), "y".into()]);
        let b = CategoricalArray::<u32>::from_slices(&[1, 0], &["y".into(), "x".into()]);
        let a_slice = (&a, 0, a.data.len());
        let b_slice = (&b, 0, b.data.len());
        let arr = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::IsNull).unwrap();
        assert_eq!(arr.data.get(0), false);
        let arr = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::IsNotNull).unwrap();
        assert_eq!(arr.data.get(0), true);
    }

    #[test]
    fn test_apply_cmp_dict_isnull_isnotnull_chunk() {
        let a = CategoricalArray::<u32>::from_slices(
            &[2, 0, 1, 2],
            &["z".into(), "x".into(), "y".into()],
        );
        let b = CategoricalArray::<u32>::from_slices(
            &[2, 1, 0, 1],
            &["w".into(), "y".into(), "x".into(), "z".into()],
        );
        let a_slice = (&a, 1, 2);
        let b_slice = (&b, 1, 2);
        let arr = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::IsNull).unwrap();
        assert_eq!(arr.data.get(0), false);
        let arr = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::IsNotNull).unwrap();
        assert_eq!(arr.data.get(0), true);
    }

    #[test]
    #[should_panic(expected = "All indices must be valid for unique_values")]
    fn test_apply_cmp_dict_isnull_isnotnull_chunk_invalid_indices() {
        let a = CategoricalArray::<u32>::from_slices(
            &[9, 0, 1, 9],
            &["z".into(), "x".into(), "y".into()],
        );
        let b = CategoricalArray::<u32>::from_slices(
            &[2, 1, 0, 3],
            &["w".into(), "y".into(), "x".into(), "z".into()],
        );
        let a_slice = (&a, 1, 2);
        let b_slice = (&b, 1, 2);
        let _ = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::IsNull).unwrap();
        let _ = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::IsNotNull).unwrap();
    }

    #[test]
    fn test_merge_bitmasks_to_new_none_none() {
        assert!(merge_bitmasks_to_new(None, None, 5).is_none());
    }

    #[test]
    fn test_merge_bitmasks_to_new_some_none() {
        let m = bm(&[true, false, true]);
        let out = merge_bitmasks_to_new(Some(&m), None, 3).unwrap();
        for i in 0..3 {
            assert_eq!(out.get(i), m.get(i));
        }
        let out2 = merge_bitmasks_to_new(None, Some(&m), 3).unwrap();
        for i in 0..3 {
            assert_eq!(out2.get(i), m.get(i));
        }
    }

    #[test]
    fn test_merge_bitmasks_to_new_both_some_and() {
        let a = bm(&[true, false, true, true]);
        let b = bm(&[true, true, false, true]);
        let out = merge_bitmasks_to_new(Some(&a), Some(&b), 4).unwrap();
        assert_eq!(out.get(0), true);
        assert_eq!(out.get(1), false);
        assert_eq!(out.get(2), false);
        assert_eq!(out.get(3), true);
    }
}