1include!(concat!(env!("OUT_DIR"), "/simd_lanes.rs"));
27
28use std::hash::Hash;
29use std::marker::PhantomData;
30
31use minarrow::traits::type_unions::Float;
32use minarrow::{Bitmask, BooleanAVT, BooleanArray, CategoricalAVT, Integer, Numeric, StringAVT};
33
34use crate::errors::KernelError;
35#[cfg(not(feature = "simd"))]
36use crate::kernels::bitmask::std::{and_masks, not_mask};
37#[cfg(not(feature = "simd"))]
38use crate::kernels::comparison::cmp_bitmask_std;
39use crate::kernels::comparison::{
40 cmp_dict_dict, cmp_dict_str, cmp_numeric, cmp_str_dict, cmp_str_str,
41};
42use crate::kernels::logical::{
43 cmp_between, cmp_dict_between, cmp_dict_in, cmp_in, cmp_in_f, cmp_str_between, cmp_str_in,
44};
45use crate::operators::ComparisonOperator;
46use crate::utils::confirm_equal_len;
47
48#[inline(always)]
50fn new_bool_bitmask(len_bits: usize) -> Bitmask {
51 Bitmask::new_set_all(len_bits, false)
52}
53
54#[inline(always)]
56fn full_bool_bitmask(len_bits: usize) -> Bitmask {
57 Bitmask::new_set_all(len_bits, true)
58}
59
60
61#[inline]
64fn merge_bitmasks_to_new(
65 lhs: Option<&Bitmask>,
66 rhs: Option<&Bitmask>,
67 len: usize,
68) -> Option<Bitmask> {
69 if let Some(m) = lhs {
70 debug_assert!(
71 m.capacity() >= len,
72 "lhs null mask too small: capacity {} < len {}",
73 m.capacity(),
74 len
75 );
76 }
77 if let Some(m) = rhs {
78 debug_assert!(
79 m.capacity() >= len,
80 "rhs null mask too small: capacity {} < len {}",
81 m.capacity(),
82 len
83 );
84 }
85
86 match (lhs, rhs) {
87 (None, None) => None,
88
89 (Some(l), None) | (None, Some(l)) => {
90 let mut out = Bitmask::new_set_all(len, true);
91 for i in 0..len {
92 unsafe {
93 out.set_unchecked(i, l.get_unchecked(i));
94 }
95 }
96 Some(out)
97 }
98
99 (Some(l), Some(r)) => {
100 let mut out = Bitmask::new_set_all(len, true);
101 for i in 0..len {
102 unsafe {
103 out.set_unchecked(i, l.get_unchecked(i) && r.get_unchecked(i));
104 }
105 }
106 Some(out)
107 }
108 }
109}
110
111pub fn apply_cmp<T>(
147 lhs: &[T],
148 rhs: &[T],
149 mask: Option<&Bitmask>,
150 op: ComparisonOperator,
151) -> Result<BooleanArray<()>, KernelError>
152where
153 T: Numeric + Copy + Hash + Eq + PartialOrd + 'static,
154{
155 let len = lhs.len();
156 match op {
157 ComparisonOperator::Between => {
158 let mut out = cmp_between(lhs, rhs)?;
159 out.null_mask = mask.cloned();
160 Ok(out)
161 }
162 ComparisonOperator::In => {
163 let mut out = cmp_in(lhs, rhs)?;
164 out.null_mask = mask.cloned();
165 Ok(out)
166 }
167 ComparisonOperator::NotIn => {
168 let mut out = cmp_in(lhs, rhs)?;
169 for i in 0..len {
170 unsafe { out.data.set_unchecked(i, !out.data.get_unchecked(i)) };
171 }
172 out.null_mask = mask.cloned();
173 Ok(out)
174 }
175 ComparisonOperator::IsNull => Ok(BooleanArray {
176 data: new_bool_bitmask(len),
177 null_mask: mask.cloned(),
178 len,
179 _phantom: std::marker::PhantomData,
180 }),
181 ComparisonOperator::IsNotNull => Ok(BooleanArray {
182 data: full_bool_bitmask(len),
183 null_mask: mask.cloned(),
184 len,
185 _phantom: std::marker::PhantomData,
186 }),
187 _ => {
188 let mut out = cmp_numeric(lhs, rhs, mask, op)?;
189 out.null_mask = mask.cloned();
190 Ok(out)
191 }
192 }
193}
194
195pub fn apply_cmp_f<T>(
213 lhs: &[T],
214 rhs: &[T],
215 mask: Option<&Bitmask>,
216 op: ComparisonOperator,
217) -> Result<BooleanArray<()>, KernelError>
218where
219 T: Float + Numeric + Copy + 'static,
220{
221 let len = lhs.len();
222 match op {
223 ComparisonOperator::Between => {
224 let mut out = cmp_between(lhs, rhs)?;
225 out.null_mask = mask.cloned();
226 Ok(out)
227 }
228 ComparisonOperator::In => {
229 let mut out = cmp_in_f(lhs, rhs)?;
230 out.null_mask = mask.cloned();
231 Ok(out)
232 }
233 ComparisonOperator::NotIn => {
234 let mut out = cmp_in_f(lhs, rhs)?;
235 for i in 0..len {
236 unsafe { out.data.set_unchecked(i, !out.data.get_unchecked(i)) };
237 }
238 out.null_mask = mask.cloned();
239 Ok(out)
240 }
241 ComparisonOperator::IsNull => Ok(BooleanArray {
242 data: new_bool_bitmask(len),
243 null_mask: mask.cloned(),
244 len,
245 _phantom: std::marker::PhantomData,
246 }),
247 ComparisonOperator::IsNotNull => Ok(BooleanArray {
248 data: full_bool_bitmask(len),
249 null_mask: mask.cloned(),
250 len,
251 _phantom: std::marker::PhantomData,
252 }),
253 ComparisonOperator::Equals | ComparisonOperator::NotEquals => {
254 let mut out = cmp_numeric(lhs, rhs, mask, op)?;
255 for i in 0..len {
257 let is_valid = mask.map_or(true, |m| unsafe { m.get_unchecked(i) });
258 if is_valid && lhs[i].is_nan() && rhs[i].is_nan() {
259 match op {
260 ComparisonOperator::Equals => unsafe { out.data.set_unchecked(i, true) },
261 ComparisonOperator::NotEquals => unsafe {
262 out.data.set_unchecked(i, false)
263 },
264 _ => {}
265 }
266 }
267 }
268 out.null_mask = mask.cloned();
269 Ok(out)
270 }
271 _ => {
272 let mut out = cmp_numeric(lhs, rhs, mask, op)?;
273 out.null_mask = mask.cloned();
274 Ok(out)
275 }
276 }
277}
278
279#[inline(always)]
286pub fn apply_cmp_bool(
287 lhs: BooleanAVT<'_, ()>,
288 rhs: BooleanAVT<'_, ()>,
289 op: ComparisonOperator,
290) -> Result<BooleanArray<()>, KernelError> {
291 let (lhs_arr, lhs_off, len) = lhs;
292 let (rhs_arr, rhs_off, rlen) = rhs;
293 confirm_equal_len("apply_cmp_bool_windowed: window length mismatch", len, rlen)?;
294
295 #[cfg(feature = "simd")]
297 let merged_null_mask: Option<Bitmask> =
298 match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
299 (None, None) => None,
300 (Some(m), None) | (None, Some(m)) => Some(m.slice_clone(lhs_off, len)),
301 (Some(a), Some(b)) => {
302 use crate::kernels::bitmask::simd::and_masks_simd;
303 let am = (a, lhs_off, len);
304 let bm = (b, rhs_off, len);
305 Some(and_masks_simd::<W8>(am, bm))
306 }
307 };
308
309 #[cfg(not(feature = "simd"))]
310 let merged_null_mask: Option<Bitmask> =
311 match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
312 (None, None) => None,
313 (Some(m), None) | (None, Some(m)) => Some(m.slice_clone(lhs_off, len)),
314 (Some(a), Some(b)) => {
315 let am = (a, lhs_off, len);
316 let bm = (b, rhs_off, len);
317 Some(and_masks(am, bm))
318 }
319 };
320
321 let mask_slice = merged_null_mask.as_ref().map(|m| (m, 0, len));
322
323 #[cfg(feature = "simd")]
324 let data = match op {
325 ComparisonOperator::Equals
326 | ComparisonOperator::NotEquals
327 | ComparisonOperator::LessThan
328 | ComparisonOperator::LessThanOrEqualTo
329 | ComparisonOperator::GreaterThan
330 | ComparisonOperator::GreaterThanOrEqualTo
331 | ComparisonOperator::In
332 | ComparisonOperator::NotIn => crate::kernels::comparison::cmp_bitmask_simd::<W8>(
333 (&lhs_arr.data, lhs_off, len),
334 (&rhs_arr.data, rhs_off, len),
335 mask_slice,
336 op,
337 )?,
338 ComparisonOperator::IsNull => {
339 let data = match merged_null_mask.as_ref() {
340 Some(mask) => crate::kernels::bitmask::simd::not_mask_simd::<W8>((mask, 0, len)),
341 None => Bitmask::new_set_all(len, false),
342 };
343 return Ok(BooleanArray {
344 data,
345 null_mask: None,
346 len,
347 _phantom: PhantomData,
348 });
349 }
350 ComparisonOperator::IsNotNull => {
351 let data = match merged_null_mask.as_ref() {
352 Some(mask) => mask.slice_clone(0, len),
353 None => Bitmask::new_set_all(len, true),
354 };
355 return Ok(BooleanArray {
356 data,
357 null_mask: None,
358 len,
359 _phantom: PhantomData,
360 });
361 }
362 ComparisonOperator::Between => {
363 return Err(KernelError::InvalidArguments(
364 "Set operations are not defined for Bool arrays".to_owned(),
365 ));
366 }
367 };
368
369 #[cfg(not(feature = "simd"))]
370 let data = match op {
371 ComparisonOperator::Equals
372 | ComparisonOperator::NotEquals
373 | ComparisonOperator::LessThan
374 | ComparisonOperator::LessThanOrEqualTo
375 | ComparisonOperator::GreaterThan
376 | ComparisonOperator::GreaterThanOrEqualTo
377 | ComparisonOperator::In
378 | ComparisonOperator::NotIn => cmp_bitmask_std(
379 (&lhs_arr.data, lhs_off, len),
380 (&rhs_arr.data, rhs_off, len),
381 mask_slice,
382 op,
383 )?,
384 ComparisonOperator::IsNull => {
385 let data = match merged_null_mask.as_ref() {
386 Some(mask) => not_mask((mask, 0, len)),
387 None => Bitmask::new_set_all(len, false),
388 };
389 return Ok(BooleanArray {
390 data,
391 null_mask: None,
392 len,
393 _phantom: PhantomData,
394 });
395 }
396 ComparisonOperator::IsNotNull => {
397 let data = match merged_null_mask.as_ref() {
398 Some(mask) => mask.slice_clone(0, len),
399 None => Bitmask::new_set_all(len, true),
400 };
401 return Ok(BooleanArray {
402 data,
403 null_mask: None,
404 len,
405 _phantom: PhantomData,
406 });
407 }
408 ComparisonOperator::Between => {
409 return Err(KernelError::InvalidArguments(
410 "Set operations are not defined for Bool arrays".to_owned(),
411 ));
412 }
413 };
414
415 Ok(BooleanArray {
416 data,
417 null_mask: merged_null_mask,
418 len,
419 _phantom: PhantomData,
420 })
421}
422
423pub fn apply_cmp_str<T: Integer>(
448 lhs: StringAVT<T>,
449 rhs: StringAVT<T>,
450 op: ComparisonOperator,
451) -> Result<BooleanArray<()>, KernelError> {
452 let (larr, loff, llen) = lhs;
454 let (rarr, roff, rlen) = rhs;
455
456 assert_eq!(llen, rlen, "apply_cmp_str: slice lengths must match");
457
458 let null_mask = merge_bitmasks_to_new(larr.null_mask.as_ref(), rarr.null_mask.as_ref(), llen);
459
460 let mut out = match op {
461 ComparisonOperator::Between => cmp_str_between((larr, loff, llen), (rarr, roff, rlen)),
462 ComparisonOperator::In => cmp_str_in((larr, loff, llen), (rarr, roff, rlen)),
463 ComparisonOperator::NotIn => {
464 let mut b = cmp_str_in((larr, loff, llen), (rarr, roff, rlen))?;
465 debug_assert!(
466 b.data.capacity() >= llen,
467 "bitmask capacity {} < needed len {}",
468 b.data.capacity(),
469 llen
470 );
471 for i in 0..llen {
472 unsafe { b.data.set_unchecked(i, !b.data.get_unchecked(i)) };
473 }
474 Ok(b)
475 }
476 ComparisonOperator::IsNull => Ok(BooleanArray {
477 data: new_bool_bitmask(llen),
478 null_mask: null_mask.clone(),
479 len: llen,
480 _phantom: std::marker::PhantomData,
481 }),
482 ComparisonOperator::IsNotNull => Ok(BooleanArray {
483 data: full_bool_bitmask(llen),
484 null_mask: null_mask.clone(),
485 len: llen,
486 _phantom: std::marker::PhantomData,
487 }),
488 _ => cmp_str_str((larr, loff, llen), (rarr, roff, rlen), op),
489 }?;
490 out.null_mask = null_mask;
491 out.len = llen;
492 Ok(out)
493}
494
495pub fn apply_cmp_str_dict<T: Integer, U: Integer>(
515 lhs: StringAVT<T>,
516 rhs: CategoricalAVT<U>,
517 op: ComparisonOperator,
518) -> Result<BooleanArray<()>, KernelError> {
519 let (larr, loff, llen) = lhs;
520 let (rarr, roff, rlen) = rhs;
521 assert_eq!(llen, rlen, "apply_cmp_str_dict: slice lengths must match");
522
523 let lmask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
525 let rmask = rarr.null_mask.as_ref().map(|m| m.slice_clone(roff, rlen));
526 let null_mask = merge_bitmasks_to_new(lmask.as_ref(), rmask.as_ref(), llen);
527
528 let mut out = cmp_str_dict((larr, loff, llen), (rarr, roff, rlen), op)?;
529 out.null_mask = null_mask;
530 out.len = llen;
531 Ok(out)
532}
533
534pub fn apply_cmp_dict_str<T: Integer, U: Integer>(
554 lhs: CategoricalAVT<T>,
555 rhs: StringAVT<U>,
556 op: ComparisonOperator,
557) -> Result<BooleanArray<()>, KernelError> {
558 let (larr, loff, llen) = lhs;
559 let (rarr, roff, rlen) = rhs;
560 assert_eq!(llen, rlen, "apply_cmp_dict_str: slice lengths must match");
561
562 let lmask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
564 let rmask = rarr.null_mask.as_ref().map(|m| m.slice_clone(roff, rlen));
565 let null_mask = merge_bitmasks_to_new(lmask.as_ref(), rmask.as_ref(), llen);
566
567 let mut out = cmp_dict_str((larr, loff, llen), (rarr, roff, rlen), op)?;
568 out.null_mask = null_mask;
569 out.len = llen;
570 Ok(out)
571}
572
573pub fn apply_cmp_dict<T: Integer + Hash>(
594 lhs: CategoricalAVT<T>,
595 rhs: CategoricalAVT<T>,
596 op: ComparisonOperator,
597) -> Result<BooleanArray<()>, KernelError> {
598 let (larr, loff, llen) = lhs;
599 let (rarr, roff, rlen) = rhs;
600 assert_eq!(llen, rlen, "apply_cmp_dict: slice lengths must match");
601 let null_mask = merge_bitmasks_to_new(larr.null_mask.as_ref(), rarr.null_mask.as_ref(), llen);
602 let mut out = match op {
603 ComparisonOperator::Between => cmp_dict_between((larr, loff, llen), (rarr, roff, rlen)),
604 ComparisonOperator::In => cmp_dict_in((larr, loff, llen), (rarr, roff, rlen)),
605 ComparisonOperator::NotIn => {
606 let mut b = cmp_dict_in((larr, loff, llen), (rarr, roff, rlen))?;
607 for i in 0..llen {
608 unsafe {
609 b.data.set_unchecked(i, !b.data.get_unchecked(i));
610 }
611 }
612 Ok(b)
613 }
614 ComparisonOperator::IsNull => Ok(BooleanArray {
615 data: new_bool_bitmask(llen),
616 null_mask: null_mask.clone(),
617 len: llen,
618 _phantom: std::marker::PhantomData,
619 }),
620 ComparisonOperator::IsNotNull => Ok(BooleanArray {
621 data: full_bool_bitmask(llen),
622 null_mask: null_mask.clone(),
623 len: llen,
624 _phantom: std::marker::PhantomData,
625 }),
626 _ => cmp_dict_dict((larr, loff, llen), (rarr, roff, rlen), op),
627 }?;
628 out.null_mask = null_mask;
629 out.len = llen;
630 Ok(out)
631}
632
#[cfg(test)]
mod tests {
    //! Unit tests for the comparison dispatch functions in this module.

    use minarrow::structs::variants::categorical::CategoricalArray;
    use minarrow::structs::variants::string::StringArray;
    use minarrow::{Bitmask, BooleanArray, MaskedArray, vec64};

    use super::*;

    /// Builds a `Bitmask` from a slice of booleans.
    fn bm(bools: &[bool]) -> Bitmask {
        Bitmask::from_bools(bools)
    }

    /// Builds a `BooleanArray` from a slice of booleans.
    fn bool_arr(bools: &[bool]) -> BooleanArray<()> {
        BooleanArray::from_slice(bools)
    }

    #[test]
    fn test_apply_cmp_numeric_all_ops() {
        let a = vec64![1, 2, 3, 4, 5, 6];
        let b = vec64![3, 2, 1, 4, 5, 0];
        let mask = bm(&[true, false, true, true, true, true]);

        for &op in &[
            ComparisonOperator::Equals,
            ComparisonOperator::NotEquals,
            ComparisonOperator::LessThan,
            ComparisonOperator::LessThanOrEqualTo,
            ComparisonOperator::GreaterThan,
            ComparisonOperator::GreaterThanOrEqualTo,
        ] {
            let arr = apply_cmp(&a, &b, Some(&mask), op).unwrap();
            for i in 0..a.len() {
                let expect = match op {
                    ComparisonOperator::Equals => a[i] == b[i],
                    ComparisonOperator::NotEquals => a[i] != b[i],
                    ComparisonOperator::LessThan => a[i] < b[i],
                    ComparisonOperator::LessThanOrEqualTo => a[i] <= b[i],
                    ComparisonOperator::GreaterThan => a[i] > b[i],
                    ComparisonOperator::GreaterThanOrEqualTo => a[i] >= b[i],
                    _ => unreachable!(),
                };
                if mask.get(i) {
                    assert_eq!(arr.data.get(i), expect);
                } else {
                    assert_eq!(arr.get(i), None);
                }
            }
            assert_eq!(arr.null_mask, Some(mask.clone()));
        }
    }

    #[test]
    fn test_apply_cmp_numeric_between_in_notin() {
        let a = vec64![4, 2, 3, 5];
        let mask = bm(&[true, true, false, true]);
        let rhs = vec64![2, 4];
        let arr = apply_cmp(&a, &rhs, Some(&mask), ComparisonOperator::Between).unwrap();
        assert_eq!(arr.data.get(0), true);
        assert_eq!(arr.data.get(1), true);
        assert_eq!(arr.get(2), None);
        assert_eq!(arr.data.get(3), false);

        let rhs = vec64![2, 3, 4];
        let arr = apply_cmp(&a, &rhs, Some(&mask), ComparisonOperator::In).unwrap();
        assert_eq!(arr.data.get(0), true);
        assert_eq!(arr.data.get(1), true);
        assert_eq!(arr.get(2), None);
        assert_eq!(arr.data.get(3), false);

        let arr = apply_cmp(&a, &rhs, Some(&mask), ComparisonOperator::NotIn).unwrap();
        assert_eq!(arr.data.get(0), false);
        assert_eq!(arr.data.get(1), false);
        assert_eq!(arr.get(2), None);
        assert_eq!(arr.data.get(3), true);
    }

    #[test]
    fn test_apply_cmp_numeric_isnull_isnotnull() {
        let a = vec64![1, 2, 3];
        let mask = bm(&[true, false, true]);
        let arr = apply_cmp(&a, &a, Some(&mask), ComparisonOperator::IsNull).unwrap();
        assert_eq!(arr.data.get(0), false);
        assert_eq!(arr.data.get(1), false);
        assert_eq!(arr.data.get(2), false);
        assert_eq!(arr.null_mask, Some(mask.clone()));
        let arr = apply_cmp(&a, &a, Some(&mask), ComparisonOperator::IsNotNull).unwrap();
        assert_eq!(arr.data.get(0), true);
        assert_eq!(arr.data.get(1), true);
        assert_eq!(arr.data.get(2), true);
        assert_eq!(arr.null_mask, Some(mask.clone()));
    }

    #[test]
    fn test_apply_cmp_numeric_edge_cases() {
        let a: [i32; 0] = [];
        let arr = apply_cmp(&a, &a, None, ComparisonOperator::Equals).unwrap();
        assert_eq!(arr.len, 0);
        assert!(arr.null_mask.is_none());
        let a = vec64![7];
        let arr = apply_cmp(&a, &a, None, ComparisonOperator::Equals).unwrap();
        assert_eq!(arr.data.get(0), true);
        assert!(arr.null_mask.is_none());
    }

    #[test]
    fn test_apply_cmp_f_all_ops_nan_patch() {
        let a = vec64![1.0, 2.0, f32::NAN, f32::NAN];
        let b = vec64![1.0, 3.0, f32::NAN, 0.0];
        let mask = bm(&[true, true, true, false]);
        for &op in &[ComparisonOperator::Equals, ComparisonOperator::NotEquals] {
            let arr = apply_cmp_f(&a, &b, Some(&mask), op).unwrap();
            assert_eq!(arr.data.get(2), matches!(op, ComparisonOperator::Equals));
        }
        let arr = apply_cmp_f(&a, &b, Some(&mask), ComparisonOperator::In).unwrap();
        assert_eq!(arr.data.get(0), true);
        assert_eq!(arr.data.get(1), false);
    }

    #[test]
    fn test_cmp_bool_w8() {
        let a = bool_arr(&[true, false, true]);
        let b = bool_arr(&[false, false, true]);
        let op = ComparisonOperator::Equals;
        let arr = apply_cmp_bool((&a, 0, a.len()), (&b, 0, b.len()), op).unwrap();
        assert!(!arr.data.get(0));
        assert!(arr.data.get(1));
        assert!(arr.data.get(2));

        let arr = apply_cmp_bool(
            (&a, 0, a.len()),
            (&b, 0, b.len()),
            ComparisonOperator::NotEquals,
        )
        .unwrap();
        assert!(arr.data.get(0));
        assert!(!arr.data.get(1));
        assert!(!arr.data.get(2));

        let arr = apply_cmp_bool(
            (&a, 0, a.len()),
            (&b, 0, b.len()),
            ComparisonOperator::LessThan,
        )
        .unwrap();
        assert!(!arr.data.get(0));
        assert!(!arr.data.get(1));
        assert!(!arr.data.get(2));

        let mut a = bool_arr(&[true, false]);
        a.null_mask = Some(bm(&[true, false]));
        let mut b = bool_arr(&[true, false]);
        b.null_mask = Some(bm(&[true, true]));
        let arr = apply_cmp_bool(
            (&a, 0, a.len()),
            (&b, 0, b.len()),
            ComparisonOperator::Equals,
        )
        .unwrap();
        assert!(arr.null_mask.as_ref().unwrap().get(0));
        assert!(!arr.null_mask.as_ref().unwrap().get(1));
    }

    #[test]
    fn test_bool_is_null() {
        let a = bool_arr(&[true, false]);
        let b = bool_arr(&[false, true]);
        let arr = apply_cmp_bool(
            (&a, 0, a.len()),
            (&b, 0, b.len()),
            ComparisonOperator::IsNull,
        )
        .unwrap();
        assert!(!arr.data.get(0));
        assert!(!arr.data.get(1));
        let arr = apply_cmp_bool(
            (&a, 0, a.len()),
            (&b, 0, b.len()),
            ComparisonOperator::IsNotNull,
        )
        .unwrap();
        assert!(arr.data.get(0));
        assert!(arr.data.get(1));
    }

    #[test]
    fn test_apply_cmp_str_all_ops() {
        let a = StringArray::<u32>::from_slice(&["foo", "bar", "baz", "qux"]);
        let b = StringArray::<u32>::from_slice(&["foo", "baz", "baz", "quux"]);
        let mut a2 = a.clone();
        a2.set_null(2);
        let a_slice = (&a, 0, a.len());
        let b_slice = (&b, 0, b.len());
        let a2_slice = (&a2, 0, a2.len());

        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::Equals).unwrap();
        assert_eq!(arr.data.get(0), true);
        assert_eq!(arr.data.get(1), false);
        assert_eq!(arr.data.get(2), true);
        assert_eq!(arr.data.get(3), false);

        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::NotEquals).unwrap();
        assert_eq!(arr.data.get(0), false);
        assert_eq!(arr.data.get(1), true);
        assert_eq!(arr.data.get(2), false);
        assert_eq!(arr.data.get(3), true);

        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::LessThan).unwrap();
        assert_eq!(arr.data.get(0), false);
        assert_eq!(arr.data.get(1), true);
        assert_eq!(arr.data.get(2), false);
        assert_eq!(arr.data.get(3), false);

        let mut b2 = b.clone();
        b2.set_null(1);
        let b2_slice = (&b2, 0, b2.len());
        let arr = apply_cmp_str(a2_slice, b2_slice, ComparisonOperator::Equals).unwrap();
        assert!(!arr.null_mask.as_ref().unwrap().get(2));
        assert!(!arr.null_mask.as_ref().unwrap().get(1));
        assert!(arr.null_mask.as_ref().unwrap().get(0));
        assert!(arr.null_mask.as_ref().unwrap().get(3));
    }

    #[test]
    fn test_apply_cmp_str_all_ops_chunk() {
        let a = StringArray::<u32>::from_slice(&["x", "foo", "bar", "baz", "qux", "y"]);
        let b = StringArray::<u32>::from_slice(&["q", "foo", "baz", "baz", "quux", "z"]);
        let a_slice = (&a, 1, 4);
        let b_slice = (&b, 1, 4);
        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::Equals).unwrap();
        assert_eq!(arr.data.get(0), true);
        assert_eq!(arr.data.get(1), false);
        assert_eq!(arr.data.get(2), true);
        assert_eq!(arr.data.get(3), false);
    }

    #[test]
    fn test_apply_cmp_str_set_ops() {
        let a = StringArray::<u32>::from_slice(&["foo", "bar", "baz"]);
        let b = StringArray::<u32>::from_slice(&["foo", "qux", "baz"]);
        let a_slice = (&a, 0, a.len());
        let b_slice = (&b, 0, b.len());
        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::Between).unwrap();
        assert_eq!(arr.len, 3);
        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::In).unwrap();
        let arr2 = apply_cmp_str(a_slice, b_slice, ComparisonOperator::NotIn).unwrap();
        for i in 0..a.len() {
            assert_eq!(arr.data.get(i), !arr2.data.get(i));
        }
    }

    #[test]
    fn test_apply_cmp_str_set_ops_chunk() {
        let a = StringArray::<u32>::from_slice(&["foo", "bar", "baz", "w"]);
        let b = StringArray::<u32>::from_slice(&["foo", "qux", "baz", "w"]);
        let a_slice = (&a, 1, 2);
        let b_slice = (&b, 1, 2);
        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::Between).unwrap();
        assert_eq!(arr.len, 2);
        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::In).unwrap();
        let arr2 = apply_cmp_str(a_slice, b_slice, ComparisonOperator::NotIn).unwrap();
        for i in 0..2 {
            assert_eq!(arr.data.get(i), !arr2.data.get(i));
        }
    }

    #[test]
    fn test_apply_cmp_str_isnull_isnotnull() {
        let a = StringArray::<u32>::from_slice(&["foo"]);
        let b = StringArray::<u32>::from_slice(&["bar"]);
        let a_slice = (&a, 0, a.len());
        let b_slice = (&b, 0, b.len());
        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::IsNull).unwrap();
        assert_eq!(arr.data.get(0), false);
        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::IsNotNull).unwrap();
        assert_eq!(arr.data.get(0), true);
    }

    #[test]
    fn test_apply_cmp_str_isnull_isnotnull_chunk() {
        let a = StringArray::<u32>::from_slice(&["pad", "foo"]);
        let b = StringArray::<u32>::from_slice(&["pad", "bar"]);
        let a_slice = (&a, 1, 1);
        let b_slice = (&b, 1, 1);
        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::IsNull).unwrap();
        assert_eq!(arr.data.get(0), false);
        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::IsNotNull).unwrap();
        assert_eq!(arr.data.get(0), true);
    }

    #[test]
    fn test_apply_cmp_str_dict() {
        let s = StringArray::<u32>::from_slice(&["a", "b", "c"]);
        let dict = CategoricalArray::<u32>::from_slices(&[0, 1, 0], &["a".into(), "b".into()]);

        let s_slice = (&s, 0, s.len());
        let dict_slice = (&dict, 0, dict.data.len());
        let arr = apply_cmp_str_dict(s_slice, dict_slice, ComparisonOperator::Equals).unwrap();
        assert_eq!(arr.len, 3);

        let mut s2 = s.clone();
        s2.set_null(0);
        let mut d2 = dict.clone();
        d2.set_null(1);

        let s2_slice = (&s2, 0, s2.len());
        let d2_slice = (&d2, 0, d2.data.len());
        let arr = apply_cmp_str_dict(s2_slice, d2_slice, ComparisonOperator::Equals).unwrap();

        let mask = arr.null_mask.as_ref().unwrap();
        assert!(!mask.get(0));
        assert!(!mask.get(1));
        assert!(mask.get(2));
    }

    #[test]
    fn test_apply_cmp_str_dict_chunk() {
        let s = StringArray::<u32>::from_slice(&["pad", "a", "b", "c", "pad2"]);
        let dict = CategoricalArray::<u32>::from_slices(
            &[2, 0, 1, 0, 2],
            &["z".into(), "a".into(), "b".into()],
        );
        let s_slice = (&s, 1, 3);
        let dict_slice = (&dict, 1, 3);
        let arr = apply_cmp_str_dict(s_slice, dict_slice, ComparisonOperator::Equals).unwrap();
        assert_eq!(arr.len, 3);
    }

    #[test]
    fn test_apply_cmp_dict_str() {
        let dict = CategoricalArray::<u32>::from_slices(&[0, 1, 0], &["a".into(), "b".into()]);
        let s = StringArray::<u32>::from_slice(&["a", "b", "c"]);
        let dict_slice = (&dict, 0, dict.data.len());
        let s_slice = (&s, 0, s.len());
        let arr = apply_cmp_dict_str(dict_slice, s_slice, ComparisonOperator::Equals).unwrap();
        assert_eq!(arr.len, 3);
    }

    #[test]
    fn test_apply_cmp_dict_str_chunk() {
        let dict = CategoricalArray::<u32>::from_slices(
            &[2, 0, 1, 0, 2],
            &["z".into(), "a".into(), "b".into()],
        );
        let s = StringArray::<u32>::from_slice(&["pad", "a", "b", "c", "pad2"]);
        let dict_slice = (&dict, 1, 3);
        let s_slice = (&s, 1, 3);
        let arr = apply_cmp_dict_str(dict_slice, s_slice, ComparisonOperator::Equals).unwrap();
        assert_eq!(arr.len, 3);
    }

    #[test]
    fn test_apply_cmp_dict_all_ops() {
        let a = CategoricalArray::<u32>::from_slices(
            &[0, 1, 2],
            &["dog".into(), "cat".into(), "fish".into()],
        );
        let b = CategoricalArray::<u32>::from_slices(
            &[2, 1, 0],
            &["fish".into(), "cat".into(), "dog".into()],
        );

        let a_slice = (&a, 0, a.data.len());
        let b_slice = (&b, 0, b.data.len());

        for &op in &[
            ComparisonOperator::Equals,
            ComparisonOperator::NotEquals,
            ComparisonOperator::LessThan,
            ComparisonOperator::LessThanOrEqualTo,
            ComparisonOperator::GreaterThan,
            ComparisonOperator::GreaterThanOrEqualTo,
        ] {
            let arr = apply_cmp_dict(a_slice, b_slice, op).unwrap();
            assert_eq!(arr.len, 3);
        }
        let arr = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::Between).unwrap();
        assert_eq!(arr.len, 3);
        let arr2 = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::In).unwrap();
        let arr3 = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::NotIn).unwrap();
        for i in 0..3 {
            assert_eq!(arr2.data.get(i), !arr3.data.get(i));
        }
    }

    #[test]
    fn test_apply_cmp_dict_all_ops_chunk() {
        let a = CategoricalArray::<u32>::from_slices(
            &[0, 1, 2, 3, 1],
            &["pad".into(), "dog".into(), "cat".into(), "fish".into()],
        );
        let b = CategoricalArray::<u32>::from_slices(
            &[3, 2, 1, 0, 2],
            &["foo".into(), "fish".into(), "cat".into(), "dog".into()],
        );
        let a_slice = (&a, 1, 3);
        let b_slice = (&b, 1, 3);

        for &op in &[
            ComparisonOperator::Equals,
            ComparisonOperator::NotEquals,
            ComparisonOperator::LessThan,
            ComparisonOperator::LessThanOrEqualTo,
            ComparisonOperator::GreaterThan,
            ComparisonOperator::GreaterThanOrEqualTo,
        ] {
            let arr = apply_cmp_dict(a_slice, b_slice, op).unwrap();
            assert_eq!(arr.len, 3);
        }
        let arr = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::Between).unwrap();
        assert_eq!(arr.len, 3);
        let arr2 = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::In).unwrap();
        let arr3 = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::NotIn).unwrap();
        for i in 0..3 {
            assert_eq!(arr2.data.get(i), !arr3.data.get(i));
        }
    }

    #[test]
    fn test_apply_cmp_dict_isnull_isnotnull() {
        let a = CategoricalArray::<u32>::from_slices(&[0, 1], &["x".into(), "y".into()]);
        let b = CategoricalArray::<u32>::from_slices(&[1, 0], &["y".into(), "x".into()]);
        let a_slice = (&a, 0, a.data.len());
        let b_slice = (&b, 0, b.data.len());
        let arr = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::IsNull).unwrap();
        assert_eq!(arr.data.get(0), false);
        let arr = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::IsNotNull).unwrap();
        assert_eq!(arr.data.get(0), true);
    }

    #[test]
    fn test_apply_cmp_dict_isnull_isnotnull_chunk() {
        let a = CategoricalArray::<u32>::from_slices(
            &[2, 0, 1, 2],
            &["z".into(), "x".into(), "y".into()],
        );
        let b = CategoricalArray::<u32>::from_slices(
            &[2, 1, 0, 1],
            &["w".into(), "y".into(), "x".into(), "z".into()],
        );
        let a_slice = (&a, 1, 2);
        let b_slice = (&b, 1, 2);
        let arr = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::IsNull).unwrap();
        assert_eq!(arr.data.get(0), false);
        let arr = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::IsNotNull).unwrap();
        assert_eq!(arr.data.get(0), true);
    }

    #[test]
    #[should_panic(expected = "All indices must be valid for unique_values")]
    fn test_apply_cmp_dict_isnull_isnotnull_chunk_invalid_indices() {
        let a = CategoricalArray::<u32>::from_slices(
            &[9, 0, 1, 9],
            &["z".into(), "x".into(), "y".into()],
        );
        let b = CategoricalArray::<u32>::from_slices(
            &[2, 1, 0, 3],
            &["w".into(), "y".into(), "x".into(), "z".into()],
        );
        let a_slice = (&a, 1, 2);
        let b_slice = (&b, 1, 2);
        let _ = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::IsNull).unwrap();
        let _ = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::IsNotNull).unwrap();
    }

    #[test]
    fn test_merge_bitmasks_to_new_none_none() {
        assert!(merge_bitmasks_to_new(None, None, 5).is_none());
    }

    #[test]
    fn test_merge_bitmasks_to_new_some_none() {
        let m = bm(&[true, false, true]);
        let out = merge_bitmasks_to_new(Some(&m), None, 3).unwrap();
        for i in 0..3 {
            assert_eq!(out.get(i), m.get(i));
        }
        let out2 = merge_bitmasks_to_new(None, Some(&m), 3).unwrap();
        for i in 0..3 {
            assert_eq!(out2.get(i), m.get(i));
        }
    }

    #[test]
    fn test_merge_bitmasks_to_new_both_some_and() {
        let a = bm(&[true, false, true, true]);
        let b = bm(&[true, true, false, true]);
        let out = merge_bitmasks_to_new(Some(&a), Some(&b), 4).unwrap();
        assert_eq!(out.get(0), true);
        assert_eq!(out.get(1), false);
        assert_eq!(out.get(2), false);
        assert_eq!(out.get(3), true);
    }
}