1include!(concat!(env!("OUT_DIR"), "/simd_lanes.rs"));
22
23#[cfg(feature = "fast_hash")]
24use ahash::AHashMap;
25#[cfg(feature = "str_arithmetic")]
26use core::ptr::copy_nonoverlapping;
27#[cfg(not(feature = "fast_hash"))]
28use std::collections::HashMap;
29
30#[cfg(feature = "str_arithmetic")]
31use memchr::memmem::Finder;
32use minarrow::structs::variants::categorical::CategoricalArray;
33
34use minarrow::structs::variants::string::StringArray;
35use minarrow::traits::type_unions::Integer;
36use minarrow::{Bitmask, Vec64};
37#[cfg(feature = "str_arithmetic")]
38use num_traits::ToPrimitive;
39
40use crate::config::STRING_MULTIPLICATION_LIMIT;
41use crate::errors::{KernelError, log_length_mismatch};
42#[cfg(feature = "str_arithmetic")]
43use crate::kernels::string::string_predicate_masks;
44use crate::operators::ArithmeticOperator::{self};
45#[cfg(feature = "str_arithmetic")]
46use crate::utils::format_finite;
47use crate::utils::merge_bitmasks_to_new;
48#[cfg(feature = "str_arithmetic")]
49use crate::utils::{
50 confirm_mask_capacity, estimate_categorical_cardinality, estimate_string_cardinality,
51};
52use minarrow::{CategoricalAVT, StringAVTExt};
53#[cfg(feature = "str_arithmetic")]
54use minarrow::{MaskedArray, StringAVT};
55
56pub fn apply_str_num<T, N, O>(
65 lhs: StringAVTExt<T>,
66 rhs: &[N],
67 op: ArithmeticOperator,
68) -> Result<StringArray<O>, KernelError>
69where
70 T: Integer,
71 N: num_traits::ToPrimitive + Copy,
72 O: Integer + num_traits::NumCast,
73{
74 let (array, offset, logical_len, physical_bytes_len) = lhs;
75
76 if logical_len != rhs.len() {
77 return Err(KernelError::LengthMismatch(log_length_mismatch(
78 "apply_str_num".to_string(),
79 logical_len,
80 rhs.len(),
81 )));
82 }
83
84 let lhs_mask = array.null_mask.as_ref();
86 let mut out_mask = lhs_mask.map(|_| minarrow::Bitmask::new_set_all(logical_len, true));
87
88 let mut offsets = Vec64::<O>::with_capacity(logical_len + 1);
89 offsets.push(O::zero()); let estimated_bytes = physical_bytes_len.min(STRING_MULTIPLICATION_LIMIT * logical_len);
92 let mut data = Vec64::with_capacity(estimated_bytes);
93
94 for (out_idx, i) in (offset..offset + logical_len).enumerate() {
95 let valid = lhs_mask.map_or(true, |mask| unsafe { mask.get_unchecked(i) });
96
97 if let Some(mask) = &mut out_mask {
98 unsafe { mask.set_unchecked(out_idx, valid) };
99 }
100
101 if valid {
102 let s = unsafe { array.get_str_unchecked(i) };
103 let n = rhs[out_idx].to_usize().unwrap_or(0);
104
105 match op {
106 ArithmeticOperator::Multiply => {
107 let count = n.min(STRING_MULTIPLICATION_LIMIT);
108 for _ in 0..count {
109 data.extend_from_slice(s.as_bytes());
110 }
111 }
112 _ => {
113 data.extend_from_slice(s.as_bytes());
114 }
115 }
116 }
117
118 let new_offset = O::from(data.len()).expect("offset conversion overflow");
121 offsets.push(new_offset);
122 }
123
124 Ok(StringArray {
125 offsets: offsets.into(),
126 data: data.into(),
127 null_mask: out_mask,
128 })
129}
130
/// Applies an arithmetic operator between each string of a windowed string array and the
/// textual form of the float at the same position in `rhs`.
///
/// `lhs` is destructured as `(array, offset, logical_len)`. For every valid row the float
/// is rendered with `format_finite` into a pattern `pat`, then:
///
/// * `Add` - appends `pat` to the string.
/// * `Subtract` - removes the first occurrence of `pat`, if any.
/// * `Multiply` - repeats the string `round(|x|) % (STRING_MULTIPLICATION_LIMIT + 1)` times.
/// * `Divide` - removes every occurrence of `pat` and joins the remaining pieces with `|`.
///
/// Null rows produce an empty string and stay null. An output mask is only produced when
/// the input array carries one.
///
/// # Errors
/// * `KernelError::LengthMismatch` when `rhs.len()` differs from the window length.
/// * `KernelError::UnsupportedType` when a valid row is encountered with any other operator.
/// * Whatever `confirm_mask_capacity` reports for the input mask.
///
/// # Panics
/// Panics when a running byte offset cannot be represented in `T`.
#[cfg(feature = "str_arithmetic")]
pub fn apply_str_float<T, F>(
    lhs: StringAVT<T>,
    rhs: &[F],
    op: ArithmeticOperator,
) -> Result<StringArray<T>, KernelError>
where
    T: Integer,
    F: Into<f64> + Copy + ryu::Float,
{
    use std::mem::MaybeUninit;

    /// Separator written between the pieces produced by `Divide`.
    const SEPARATOR: &[u8] = b"|";

    let (array, offset, logical_len) = lhs;

    if rhs.len() != logical_len {
        return Err(KernelError::LengthMismatch(log_length_mismatch(
            "apply_str_float".into(),
            logical_len,
            rhs.len(),
        )));
    }
    let lhs_mask = &array.null_mask;
    let _ = confirm_mask_capacity(array.len(), lhs_mask.as_ref())?;

    // Scratch space reused for every float-to-text conversion.
    let mut fmt_buf = [MaybeUninit::<u8>::uninit(); 24];

    // NOTE(review): the repeat count wraps (modulo) instead of clamping like the integer
    // kernels do; kept as-is because both passes below must agree on it.
    let repeat_count = |value: F| -> usize {
        let as_f64: f64 = value.into();
        as_f64.round().abs() as usize % (STRING_MULTIPLICATION_LIMIT + 1)
    };

    // Pass 1: upper bound on the number of output bytes, so the buffer is allocated once.
    let mut total_bytes = 0usize;
    for (out_idx, i) in (offset..offset + logical_len).enumerate() {
        // SAFETY: NOTE(review) assumes the window lies inside the array and its mask.
        let valid = lhs_mask
            .as_ref()
            .map_or(true, |m| unsafe { m.get_unchecked(i) });
        if !valid {
            continue;
        }

        let src_len = array.offsets[i + 1].to_usize() - array.offsets[i].to_usize();
        let n_s = format_finite(&mut fmt_buf, rhs[out_idx]);
        total_bytes += match op {
            ArithmeticOperator::Add => src_len + n_s.len(),
            // Upper bound: nothing is removed when the pattern is absent.
            ArithmeticOperator::Subtract => src_len,
            ArithmeticOperator::Multiply => src_len * repeat_count(rhs[out_idx]),
            ArithmeticOperator::Divide => {
                let pat_len = n_s.len();
                let splits = (src_len + pat_len).saturating_sub(1) / pat_len;
                src_len + splits
            }
            _ => {
                return Err(KernelError::UnsupportedType(format!(
                    "Unsupported {:?}",
                    op
                )));
            }
        };
    }

    let mut offsets = Vec64::<T>::with_capacity(logical_len + 1);
    offsets.push(T::zero());

    let mut data = Vec64::<u8>::with_capacity(total_bytes);
    // SAFETY: capacity is `total_bytes` and `u8` has no drop glue. Bytes are written
    // through a raw pointer below and the length is shrunk to the written prefix
    // before the buffer is handed out.
    unsafe {
        data.set_len(total_bytes);
    }

    let mut out_mask = lhs_mask
        .as_ref()
        .map(|_| Bitmask::new_set_all(logical_len, false));

    // Pass 2: produce the bytes and one end offset per row.
    let mut cursor = 0usize;
    for (out_idx, i) in (offset..offset + logical_len).enumerate() {
        // SAFETY: same window assumption as in pass 1.
        let valid = lhs_mask
            .as_ref()
            .map_or(true, |m| unsafe { m.get_unchecked(i) });
        if let Some(mask) = &mut out_mask {
            // SAFETY: `out_idx < logical_len`, the length the mask was created with.
            unsafe { mask.set_unchecked(out_idx, valid) };
        }

        if !valid {
            // Null rows contribute no bytes.
            offsets.push(T::from(cursor).unwrap());
            continue;
        }

        let start = array.offsets[i].to_usize();
        let end = array.offsets[i + 1].to_usize();
        let src = &array.data[start..end];
        let n_s = format_finite(&mut fmt_buf, rhs[out_idx]);
        let pat = n_s.as_bytes();

        let mut write = |bytes: &[u8]| {
            debug_assert!(cursor + bytes.len() <= data.len());
            // SAFETY: pass 1 reserved at least as many bytes as this pass writes, and
            // `bytes` never aliases the freshly allocated output buffer.
            unsafe {
                copy_nonoverlapping(bytes.as_ptr(), data.as_mut_ptr().add(cursor), bytes.len());
            }
            cursor += bytes.len();
        };

        match op {
            ArithmeticOperator::Add => {
                write(src);
                write(pat);
            }
            ArithmeticOperator::Subtract => match Finder::new(pat).find(src) {
                Some(idx) => {
                    write(&src[..idx]);
                    write(&src[idx + pat.len()..]);
                }
                None => write(src),
            },
            ArithmeticOperator::Multiply => {
                for _ in 0..repeat_count(rhs[out_idx]) {
                    write(src);
                }
            }
            ArithmeticOperator::Divide => {
                let finder = Finder::new(pat);
                let mut rest = src;
                let mut matched = false;
                while let Some(idx) = finder.find(rest) {
                    if matched {
                        write(SEPARATOR);
                    }
                    write(&rest[..idx]);
                    rest = &rest[idx + pat.len()..];
                    matched = true;
                }
                // A separator also precedes the tail once at least one match was removed.
                if matched {
                    write(SEPARATOR);
                }
                write(rest);
            }
            // Pass 1 already rejected every other operator for valid rows.
            _ => unreachable!(),
        }
        offsets.push(T::from(cursor).unwrap());
    }

    debug_assert!(cursor <= total_bytes);
    // Pass 1 computes an upper bound (Subtract/Divide may write fewer bytes), so drop the
    // unwritten tail instead of exposing uninitialised bytes in the result.
    // SAFETY: `cursor <= total_bytes` and the first `cursor` bytes were initialised above.
    unsafe {
        data.set_len(cursor);
    }

    Ok(StringArray {
        offsets: offsets.into(),
        data: data.into(),
        null_mask: out_mask,
    })
}
339
/// Looks up `s` in `dict` and returns its code, registering it first when it is new.
///
/// A new string gets the next free position in `uniq` as its code and is stored in both
/// `uniq` and `dict`.
#[cfg(feature = "fast_hash")]
#[inline(always)]
fn intern(s: &str, dict: &mut AHashMap<String, u32>, uniq: &mut Vec64<String>) -> u32 {
    match dict.get(s) {
        Some(&known) => known,
        None => {
            let fresh = uniq.len() as u32;
            let owned = s.to_owned();
            uniq.push(owned.clone());
            dict.insert(owned, fresh);
            fresh
        }
    }
}
354
355#[cfg(not(feature = "fast_hash"))]
358#[inline(always)]
359fn intern(s: &str, dict: &mut HashMap<String, u32>, uniq: &mut Vec64<String>) -> u32 {
360 if let Some(&code) = dict.get(s) {
361 code
362 } else {
363 let idx = uniq.len() as u32;
364 uniq.push(s.to_owned());
365 dict.insert(s.to_owned(), idx);
366 idx
367 }
368}
369
370pub fn apply_dict32_dict32(
400 lhs: CategoricalAVT<u32>,
401 rhs: CategoricalAVT<u32>,
402 op: ArithmeticOperator,
403) -> Result<CategoricalArray<u32>, KernelError> {
404 let (lhs_array, lhs_offset, lhs_logical_len) = lhs;
406 let (rhs_array, rhs_offset, rhs_logical_len) = rhs;
407
408 if lhs_logical_len != rhs_logical_len {
409 return Err(KernelError::LengthMismatch(log_length_mismatch(
410 "apply_dict32_dict32".into(),
411 lhs_logical_len,
412 rhs_logical_len,
413 )));
414 }
415
416 let in_mask = merge_bitmasks_to_new(
418 lhs_array.null_mask.as_ref(),
419 rhs_array.null_mask.as_ref(),
420 lhs_logical_len,
421 );
422
423 let mut uniq: Vec64<String> = Vec64::with_capacity(
425 lhs_array.unique_values.len() + rhs_array.unique_values.len() + lhs_logical_len,
426 );
427
428 #[cfg(feature = "fast_hash")]
429 let mut dict: AHashMap<String, u32> = AHashMap::with_capacity(uniq.capacity());
430
431 #[cfg(not(feature = "fast_hash"))]
432 let mut dict: HashMap<String, u32> = HashMap::with_capacity(uniq.capacity());
433
434 for v in lhs_array
435 .unique_values
436 .iter()
437 .chain(rhs_array.unique_values.iter())
438 {
439 if !dict.contains_key(v) {
440 let idx = uniq.len() as u32;
441 uniq.push(v.clone());
442 dict.insert(uniq.last().unwrap().clone(), idx);
443 }
444 }
445
446 let empty_code = *dict.entry("".to_owned()).or_insert_with(|| {
448 let idx = uniq.len() as u32;
449 uniq.push("".to_owned());
450 idx
451 });
452
453 let mut total_out = 0usize;
455 for local_idx in 0..lhs_logical_len {
456 let i = lhs_offset + local_idx;
457 let j = rhs_offset + local_idx;
458 let valid = in_mask
459 .as_ref()
460 .map_or(true, |m| unsafe { m.get_unchecked(local_idx) });
461 if !valid {
462 total_out += 1;
463 } else if let ArithmeticOperator::Divide = op {
464 let l = unsafe { lhs_array.get_str_unchecked(i) };
465 let r = unsafe { rhs_array.get_str_unchecked(j) };
466 if r.is_empty() {
467 total_out += 1;
468 } else {
469 let mut parts = 0;
470 let mut start = 0;
471 while let Some(pos) = l[start..].find(r) {
472 parts += 1;
473 start += pos + r.len();
474 }
475 total_out += parts + 1;
476 }
477 } else {
478 total_out += 1;
479 }
480 }
481
482 let mut out_data = Vec64::with_capacity(total_out);
484 unsafe {
485 out_data.set_len(total_out);
486 }
487 let mut out_mask = Bitmask::new_set_all(total_out, false);
488
489 let mut write_ptr = 0;
491 for local_idx in 0..lhs_logical_len {
492 let i = lhs_offset + local_idx;
493 let j = rhs_offset + local_idx;
494 let valid = in_mask
495 .as_ref()
496 .map_or(true, |m| unsafe { m.get_unchecked(local_idx) });
497
498 if !valid {
499 out_data.push(empty_code);
500 unsafe { out_mask.set_unchecked(write_ptr, false) };
501 write_ptr += 1;
502 continue;
503 }
504
505 let l = unsafe { lhs_array.get_str_unchecked(i) };
506 let r = unsafe { rhs_array.get_str_unchecked(j) };
507
508 match op {
509 ArithmeticOperator::Add => {
510 let mut tmp = String::with_capacity(l.len() + r.len());
511 tmp.push_str(l);
512 tmp.push_str(r);
513 let code = intern(&tmp, &mut dict, &mut uniq);
514 unsafe {
515 *out_data.get_unchecked_mut(write_ptr) = code;
516 }
517 out_mask.set(write_ptr, true);
518 write_ptr += 1;
519 }
520 ArithmeticOperator::Subtract => {
521 let result = if r.is_empty() {
522 l.to_owned()
523 } else if let Some(pos) = l.find(r) {
524 let mut tmp = String::with_capacity(l.len() - r.len());
525 tmp.push_str(&l[..pos]);
526 tmp.push_str(&l[pos + r.len()..]);
527 tmp
528 } else {
529 l.to_owned()
530 };
531 let code = intern(&result, &mut dict, &mut uniq);
532 unsafe {
533 *out_data.get_unchecked_mut(write_ptr) = code;
534 }
535 out_mask.set(write_ptr, true);
536 write_ptr += 1;
537 }
538 ArithmeticOperator::Multiply => {
539 let code = intern(l, &mut dict, &mut uniq);
540 unsafe {
541 *out_data.get_unchecked_mut(write_ptr) = code;
542 }
543 out_mask.set(write_ptr, true);
544 write_ptr += 1;
545 }
546 ArithmeticOperator::Divide => {
547 if r.is_empty() {
548 let code = intern(l, &mut dict, &mut uniq);
549 unsafe {
550 *out_data.get_unchecked_mut(write_ptr) = code;
551 }
552 out_mask.set(write_ptr, true);
553 write_ptr += 1;
554 } else {
555 let mut start = 0;
556 while let Some(pos) = l[start..].find(r) {
557 let part = &l[start..start + pos];
558 let code = intern(part, &mut dict, &mut uniq);
559 unsafe {
560 *out_data.get_unchecked_mut(write_ptr) = code;
561 }
562 out_mask.set(write_ptr, true);
563 write_ptr += 1;
564 start += pos + r.len();
565 }
566 let tail = &l[start..];
567 let code = intern(tail, &mut dict, &mut uniq);
568 unsafe {
569 *out_data.get_unchecked_mut(write_ptr) = code;
570 }
571 out_mask.set(write_ptr, true);
572 write_ptr += 1;
573 }
574 }
575 _ => {
576 return Err(KernelError::UnsupportedType(format!(
577 "Unsupported apply_dict32_dict32 op={:?}",
578 op
579 )));
580 }
581 }
582 }
583
584 debug_assert_eq!(write_ptr, total_out);
585
586 Ok(CategoricalArray {
587 data: out_data.into(),
588 unique_values: uniq,
589 null_mask: Some(out_mask),
590 })
591}
592
/// Applies a string arithmetic operator row-wise to two windowed string arrays.
///
/// Each operand is destructured as `(array, offset, len)`. For every row that is valid on
/// both sides, with `a` and `b` the left and right strings:
///
/// * `Add` - `a` followed by `b`.
/// * `Subtract` - `a` with the first occurrence of `b` removed (unchanged when `b` is
///   empty or absent).
/// * `Multiply` - `a` repeated `b.len()` times, capped at `STRING_MULTIPLICATION_LIMIT`.
/// * `Divide` - every occurrence of `b` removed from `a` and the remaining pieces joined
///   with `|` (unchanged when `b` is empty).
///
/// Rows that are not valid on both sides contribute no bytes.
///
/// # Errors
/// * `KernelError::LengthMismatch` when the two windows differ in length.
/// * `KernelError::UnsupportedType` when a valid row is encountered with any other operator.
/// * Whatever `confirm_mask_capacity` reports for the windowed masks.
#[cfg(feature = "str_arithmetic")]
pub fn apply_str_str<T, U>(
    lhs: StringAVT<T>,
    rhs: StringAVT<U>,
    op: ArithmeticOperator,
) -> Result<StringArray<T>, KernelError>
where
    T: Integer,
    U: Integer,
{
    let (larr, loff, llen) = lhs;
    let (rarr, roff, rlen) = rhs;

    if llen != rlen {
        return Err(KernelError::LengthMismatch(log_length_mismatch(
            "apply_str_str".to_string(),
            llen,
            rlen,
        )));
    }

    // Copies the `llen` bits starting at `start` into a mask indexed from zero.
    let rebase = |mask: &Bitmask, start: usize| -> Bitmask {
        let mut windowed = Bitmask::new_set_all(llen, true);
        for k in 0..llen {
            unsafe {
                windowed.set_unchecked(k, mask.get_unchecked(start + k));
            }
        }
        windowed
    };
    let lmask_slice = larr.null_mask.as_ref().map(|m| rebase(m, loff));
    let rmask_slice = rarr.null_mask.as_ref().map(|m| rebase(m, roff));

    let (lmask, rmask, mut out_mask) =
        string_predicate_masks(lmask_slice.as_ref(), rmask_slice.as_ref(), llen);
    let _ = confirm_mask_capacity(llen, lmask)?;
    let _ = confirm_mask_capacity(llen, rmask)?;

    // A row takes part only when it is valid on both sides.
    let row_is_valid = |idx: usize| {
        lmask.map_or(true, |m| unsafe { m.get_unchecked(idx) })
            && rmask.map_or(true, |m| unsafe { m.get_unchecked(idx) })
    };

    // Pass 1: size the output buffer.
    let mut total_bytes = 0;
    for idx in 0..llen {
        if !row_is_valid(idx) {
            continue;
        }
        let a = unsafe { larr.get_str_unchecked(loff + idx) };
        let b = unsafe { rarr.get_str_unchecked(roff + idx) };
        total_bytes += match op {
            ArithmeticOperator::Add => a.len() + b.len(),
            ArithmeticOperator::Subtract => a.len(),
            ArithmeticOperator::Multiply => a.len() * b.len().min(STRING_MULTIPLICATION_LIMIT),
            ArithmeticOperator::Divide => {
                if b.is_empty() {
                    a.len()
                } else {
                    a.len() + a.matches(b).count().saturating_sub(1)
                }
            }
            _ => {
                return Err(KernelError::UnsupportedType(format!(
                    "Unsupported {:?}",
                    op
                )));
            }
        };
    }

    let mut offsets = Vec64::<T>::with_capacity(llen + 1);
    let mut data = Vec64::<u8>::with_capacity(total_bytes);
    offsets.push(T::zero());

    // Pass 2: write the bytes and one end offset per row.
    for idx in 0..llen {
        if row_is_valid(idx) {
            let a = unsafe { larr.get_str_unchecked(loff + idx) };
            let b = unsafe { rarr.get_str_unchecked(roff + idx) };
            let (left, right) = (a.as_bytes(), b.as_bytes());
            match op {
                ArithmeticOperator::Add => {
                    data.extend_from_slice(left);
                    data.extend_from_slice(right);
                }
                ArithmeticOperator::Subtract => {
                    let hit = if right.is_empty() {
                        None
                    } else {
                        Finder::new(right).find(left)
                    };
                    match hit {
                        Some(p) => {
                            data.extend_from_slice(&left[..p]);
                            data.extend_from_slice(&left[p + right.len()..]);
                        }
                        None => data.extend_from_slice(left),
                    }
                }
                ArithmeticOperator::Multiply => {
                    let times = right.len().min(STRING_MULTIPLICATION_LIMIT);
                    for _ in 0..times {
                        data.extend_from_slice(left);
                    }
                }
                ArithmeticOperator::Divide => {
                    if right.is_empty() {
                        data.extend_from_slice(left);
                    } else {
                        let finder = Finder::new(right);
                        let mut rest = left;
                        let mut seen_match = false;
                        while let Some(p) = finder.find(rest) {
                            if seen_match {
                                data.push(b'|');
                            }
                            data.extend_from_slice(&rest[..p]);
                            rest = &rest[p + right.len()..];
                            seen_match = true;
                        }
                        // The tail is preceded by a separator once anything was removed.
                        if seen_match {
                            data.push(b'|');
                        }
                        data.extend_from_slice(rest);
                    }
                }
                // Pass 1 already rejected every other operator for valid rows.
                _ => unreachable!(),
            }
            unsafe { out_mask.set_unchecked(idx, true) };
        }
        offsets.push(T::from_usize(data.len()));
    }

    Ok(StringArray {
        offsets: offsets.into(),
        data: data.into(),
        null_mask: Some(out_mask),
    })
}
765
/// Applies a string arithmetic operator row-wise between a windowed categorical array
/// (left) and a windowed string array (right), returning a categorical array.
///
/// Both operands are destructured as `(array, offset, len)`. Cardinality of each side is
/// estimated from `SAMPLE_SIZE` samples; when the larger ratio exceeds
/// `CARDINALITY_THRESHOLD` the left side is converted to a string array, the work is
/// delegated to `apply_str_str`, and the result is converted back to categorical.
///
/// Otherwise, for every valid row with `l` and `r` the left and right strings:
///
/// * `Add` - `l` followed by `r`.
/// * `Subtract` - `l` with the first occurrence of `r` removed (unchanged when `r` is
///   empty or absent).
/// * `Multiply` - `l` unchanged.
/// * `Divide` - `l` split on `r`; every piece becomes its own output row. An empty `r`
///   yields `l` as a single row.
///
/// A row that is null in the merged input mask yields one null output row whose code
/// refers to the empty string.
///
/// # Errors
/// * `KernelError::LengthMismatch` when the two windows differ in length.
/// * `KernelError::UnsupportedType` when a valid row is encountered with any other operator.
/// * Any error returned by `apply_str_str` on the high-cardinality path.
#[cfg(feature = "str_arithmetic")]
pub fn apply_dict32_str<T>(
    lhs: CategoricalAVT<u32>,
    rhs: StringAVT<T>,
    op: ArithmeticOperator,
) -> Result<CategoricalArray<u32>, KernelError>
where
    T: Integer,
{
    const SAMPLE_SIZE: usize = 256;
    const CARDINALITY_THRESHOLD: f64 = 0.75;

    let (larr, loff, llen) = lhs;
    let (rarr, roff, rlen) = rhs;

    if llen != rlen {
        return Err(KernelError::LengthMismatch(log_length_mismatch(
            "apply_dict32_str".to_string(),
            llen,
            rlen,
        )));
    }

    let cat_ratio = estimate_categorical_cardinality(larr, SAMPLE_SIZE);
    let str_ratio = estimate_string_cardinality(rarr, SAMPLE_SIZE);
    let max_ratio = cat_ratio.max(str_ratio);

    if max_ratio > CARDINALITY_THRESHOLD {
        // Mostly distinct values: interning brings little, so go through the string kernel.
        let lhs_str = larr.to_string_array();
        let str_result = apply_str_str((&lhs_str, loff, llen), (rarr, roff, rlen), op)?;
        return Ok(str_result.to_categorical_array());
    }

    // NOTE(review): the merged mask is read with window-local indices while the operand
    // masks are handed over unsliced - confirm `merge_bitmasks_to_new` is correct for
    // non-zero offsets.
    let in_mask = merge_bitmasks_to_new(larr.null_mask.as_ref(), rarr.null_mask.as_ref(), llen);
    let is_valid = |local_idx: usize| {
        // SAFETY: `local_idx < llen`, the length the mask was merged for.
        in_mask
            .as_ref()
            .map_or(true, |m| unsafe { m.get_unchecked(local_idx) })
    };

    // Pass 1: count output rows. Only `Divide` can emit more than one row per input row.
    let is_divide = matches!(op, ArithmeticOperator::Divide);
    let mut total_out = 0usize;
    for local_idx in 0..llen {
        total_out += if is_divide && is_valid(local_idx) {
            // SAFETY: NOTE(review) assumes both windows lie inside their arrays.
            let l_val = unsafe { larr.get_str_unchecked(loff + local_idx) };
            let r_val = unsafe { rarr.get_str_unchecked(roff + local_idx) };
            if r_val.is_empty() {
                1
            } else {
                l_val.split(r_val).count()
            }
        } else {
            1
        };
    }

    // Dictionary: start from the left operand's values, keyed by their position.
    let mut uniq: Vec64<String> = Vec64::with_capacity(larr.unique_values.len() + llen);
    uniq.extend(larr.unique_values.iter().cloned());

    #[cfg(feature = "fast_hash")]
    let mut dict: AHashMap<String, u32> = AHashMap::with_capacity(uniq.len());

    #[cfg(not(feature = "fast_hash"))]
    let mut dict: HashMap<String, u32> = HashMap::with_capacity(uniq.len());

    for (i, s) in uniq.iter().enumerate() {
        dict.insert(s.clone(), i as u32);
    }
    // Null rows point at the empty string.
    let empty_code = intern("", &mut dict, &mut uniq);

    // Pass 2: emit codes. Rows are appended with `push`, so the buffer never exposes
    // uninitialised slots and null rows land at their own position.
    let mut out_data = Vec64::<u32>::with_capacity(total_out);
    let mut out_null = Bitmask::new_set_all(total_out, false);
    let mut emit_valid = |code: u32| {
        out_null.set(out_data.len(), true);
        out_data.push(code);
    };

    for local_idx in 0..llen {
        if !is_valid(local_idx) {
            // The mask bit stays `false`, as initialised.
            out_data.push(empty_code);
            continue;
        }

        // SAFETY: same window assumption as in pass 1.
        let l_val = unsafe { larr.get_str_unchecked(loff + local_idx) };
        let r_val = unsafe { rarr.get_str_unchecked(roff + local_idx) };

        match op {
            ArithmeticOperator::Add => {
                let mut s = String::with_capacity(l_val.len() + r_val.len());
                s.push_str(l_val);
                s.push_str(r_val);
                emit_valid(intern(&s, &mut dict, &mut uniq));
            }
            ArithmeticOperator::Subtract => {
                let code = match l_val.find(r_val) {
                    Some(pos) if !r_val.is_empty() => {
                        let mut s = l_val[..pos].to_owned();
                        s.push_str(&l_val[pos + r_val.len()..]);
                        intern(&s, &mut dict, &mut uniq)
                    }
                    _ => intern(l_val, &mut dict, &mut uniq),
                };
                emit_valid(code);
            }
            ArithmeticOperator::Multiply => {
                emit_valid(intern(l_val, &mut dict, &mut uniq));
            }
            ArithmeticOperator::Divide => {
                if r_val.is_empty() {
                    emit_valid(intern(l_val, &mut dict, &mut uniq));
                } else {
                    for part in l_val.split(r_val) {
                        emit_valid(intern(part, &mut dict, &mut uniq));
                    }
                }
            }
            _ => {
                return Err(KernelError::UnsupportedType(
                    "Unsupported Type Error.".to_string(),
                ));
            }
        }
    }

    debug_assert_eq!(out_data.len(), total_out);

    Ok(CategoricalArray {
        data: out_data.into(),
        unique_values: uniq,
        null_mask: Some(out_null),
    })
}
954
/// Applies a string arithmetic operator row-wise between a windowed string array (left)
/// and a windowed categorical array (right), returning a string array.
///
/// Both operands are destructured as `(array, offset, len)`. For every valid row, with
/// `l` and `r` the left and right strings:
///
/// * `Add` - `l` followed by `r`.
/// * `Subtract` - `l` with the first occurrence of `r` removed (unchanged when `r` is
///   empty or absent).
/// * `Multiply` - `l` unchanged.
/// * `Divide` - `l` split on `r`; every piece becomes its own output row, so the result
///   can hold more rows than the inputs. An empty `r` yields `l` as a single row, which
///   matches the other kernels in this file.
///
/// A row that is null in the merged input mask yields one empty, null output row. The
/// returned mask (present only when the merged input mask is) has one bit per *output*
/// row, so it stays aligned with the offsets when `Divide` expands rows.
///
/// # Errors
/// * `KernelError::LengthMismatch` when the two windows differ in length.
/// * `KernelError::UnsupportedType` when a valid row is encountered with any other operator.
#[cfg(feature = "str_arithmetic")]
pub fn apply_str_dict32<T>(
    lhs: StringAVT<T>,
    rhs: CategoricalAVT<u32>,
    op: ArithmeticOperator,
) -> Result<StringArray<T>, KernelError>
where
    T: Integer,
{
    let (larr, loff, llen) = lhs;
    let (rarr, roff, rlen) = rhs;

    if llen != rlen {
        return Err(KernelError::LengthMismatch(log_length_mismatch(
            "apply_str_dict32".to_string(),
            llen,
            rlen,
        )));
    }

    // NOTE(review): the merged mask is read with window-local indices while the operand
    // masks are handed over unsliced - confirm `merge_bitmasks_to_new` is correct for
    // non-zero offsets.
    let in_mask = merge_bitmasks_to_new(larr.null_mask.as_ref(), rarr.null_mask.as_ref(), llen);
    let is_valid = |local_idx: usize| {
        // SAFETY: `local_idx < llen`, the length the mask was merged for.
        in_mask
            .as_ref()
            .map_or(true, |m| unsafe { m.get_unchecked(local_idx) })
    };

    // Pass 1: count output rows and an upper bound on output bytes.
    let mut total_rows = 0usize;
    let mut total_bytes = 0usize;

    for local_idx in 0..llen {
        if !is_valid(local_idx) {
            total_rows += 1;
            continue;
        }

        // SAFETY: NOTE(review) assumes both windows lie inside their arrays.
        let l = unsafe { larr.get_str_unchecked(loff + local_idx) };
        let r = unsafe { rarr.get_str_unchecked(roff + local_idx) };

        match op {
            ArithmeticOperator::Divide => {
                total_rows += if r.is_empty() { 1 } else { l.split(r).count() };
                total_bytes += l.len();
            }
            ArithmeticOperator::Add => {
                total_rows += 1;
                total_bytes += l.len() + r.len();
            }
            ArithmeticOperator::Subtract | ArithmeticOperator::Multiply => {
                total_rows += 1;
                total_bytes += l.len();
            }
            _ => {
                return Err(KernelError::UnsupportedType(
                    "Unsupported Type Error.".to_string(),
                ));
            }
        }
    }

    let mut offsets = Vec64::<T>::with_capacity(total_rows + 1);
    let mut data = Vec64::<u8>::with_capacity(total_bytes);
    offsets.push(T::zero());

    // One validity bit per output row; null rows are cleared as they are emitted.
    let mut out_mask = in_mask
        .as_ref()
        .map(|_| Bitmask::new_set_all(total_rows, true));

    // Pass 2: write the bytes and one end offset per output row.
    for local_idx in 0..llen {
        if !is_valid(local_idx) {
            if let Some(mask) = &mut out_mask {
                // `offsets.len() - 1` is the index of the row about to be emitted.
                mask.set(offsets.len() - 1, false);
            }
            offsets.push(T::from_usize(data.len()));
            continue;
        }

        // SAFETY: same window assumption as in pass 1.
        let l = unsafe { larr.get_str_unchecked(loff + local_idx) };
        let r = unsafe { rarr.get_str_unchecked(roff + local_idx) };

        match op {
            ArithmeticOperator::Divide if !r.is_empty() => {
                for part in l.split(r) {
                    data.extend_from_slice(part.as_bytes());
                    offsets.push(T::from_usize(data.len()));
                }
            }
            // Dividing by the empty string, like multiplying, leaves `l` untouched.
            ArithmeticOperator::Divide | ArithmeticOperator::Multiply => {
                data.extend_from_slice(l.as_bytes());
                offsets.push(T::from_usize(data.len()));
            }
            ArithmeticOperator::Add => {
                data.extend_from_slice(l.as_bytes());
                data.extend_from_slice(r.as_bytes());
                offsets.push(T::from_usize(data.len()));
            }
            ArithmeticOperator::Subtract => {
                match l.find(r) {
                    Some(pos) if !r.is_empty() => {
                        data.extend_from_slice(&l.as_bytes()[..pos]);
                        data.extend_from_slice(&l.as_bytes()[pos + r.len()..]);
                    }
                    _ => data.extend_from_slice(l.as_bytes()),
                }
                offsets.push(T::from_usize(data.len()));
            }
            // Pass 1 already rejected every other operator for valid rows.
            _ => unreachable!(),
        }
    }

    debug_assert_eq!(offsets.len(), total_rows + 1);

    Ok(StringArray {
        offsets: offsets.into(),
        data: data.into(),
        null_mask: out_mask,
    })
}
1100
/// Combines every string of a windowed categorical array with the number at the same
/// position and returns a categorical array with a freshly built dictionary.
///
/// `lhs` is destructured as `(array, offset, len)`.
///
/// * `ArithmeticOperator::Multiply` repeats the string `n` times, where `n` is the number
///   converted with `to_usize` (values that do not convert count as `0`), capped at
///   `1_000_000`.
/// * Every other operator keeps the string unchanged.
///
/// Null rows receive code `0` and stay null. An output mask is only produced when the
/// input array carries one.
///
/// # Errors
/// Returns `KernelError::LengthMismatch` when `rhs.len()` differs from the window length.
#[cfg(feature = "str_arithmetic")]
pub fn apply_dict32_num<T>(
    lhs: CategoricalAVT<u32>,
    rhs: &[T],
    op: ArithmeticOperator,
) -> Result<CategoricalArray<u32>, KernelError>
where
    T: ToPrimitive + Copy,
{
    #[cfg(feature = "fast_hash")]
    use ahash::{HashMap, HashMapExt};

    #[cfg(not(feature = "fast_hash"))]
    use std::collections::HashMap;

    let (larr, loff, llen) = lhs;

    if llen != rhs.len() {
        return Err(KernelError::LengthMismatch(log_length_mismatch(
            "apply_dict32_num".to_string(),
            llen,
            rhs.len(),
        )));
    }

    let in_mask = larr.null_mask.as_ref();
    let mut out_mask = in_mask.map(|_| Bitmask::new_set_all(llen, true));

    let mut data = Vec64::<u32>::with_capacity(llen);
    let mut unique_values = Vec64::<String>::with_capacity(llen);
    let mut seen: HashMap<String, u32> = HashMap::with_capacity(llen);
    let mut next_code = 0;

    for (local_idx, num) in rhs.iter().enumerate() {
        let src_idx = loff + local_idx;
        let valid = in_mask.map_or(true, |m| unsafe { m.get_unchecked(src_idx) });

        if !valid {
            // Null rows carry a placeholder code and a cleared validity bit.
            data.push(0);
            if let Some(mask) = &mut out_mask {
                unsafe { mask.set_unchecked(local_idx, false) };
            }
            continue;
        }

        let l_val = unsafe { larr.get_str_unchecked(src_idx) };
        let n = num.to_usize().unwrap_or(0);

        // NOTE(review): this cap is a literal rather than `STRING_MULTIPLICATION_LIMIT`;
        // confirm whether the two are meant to be the same value.
        let cat = if let ArithmeticOperator::Multiply = op {
            l_val.repeat(n.min(1_000_000))
        } else {
            l_val.to_owned()
        };

        // Reuse the code of an already-seen result, otherwise register a new one.
        let code = match seen.get(&cat) {
            Some(&existing) => existing,
            None => {
                let fresh = next_code as u32;
                seen.insert(cat.clone(), fresh);
                unique_values.push(cat);
                next_code += 1;
                fresh
            }
        };

        data.push(code);
        if let Some(mask) = &mut out_mask {
            unsafe { mask.set_unchecked(local_idx, true) };
        }
    }

    Ok(CategoricalArray {
        data: data.into(),
        unique_values,
        null_mask: out_mask,
    })
}
1198
1199#[cfg(test)]
1200mod tests {
1201 use minarrow::MaskedArray;
1202 use minarrow::structs::variants::string::StringArray;
1203 #[cfg(feature = "str_arithmetic")]
1204 use minarrow::{Bitmask, CategoricalArray};
1205
1206 use super::*;
1207 use crate::operators::ArithmeticOperator;
1208 use minarrow::vec64;
1209
1210 fn assert_str<T>(arr: &StringArray<T>, expect: &[&str], valid: Option<&[bool]>)
1214 where
1215 T: minarrow::traits::type_unions::Integer + std::fmt::Debug,
1216 {
1217 assert_eq!(arr.len(), expect.len());
1218 for (i, exp) in expect.iter().enumerate() {
1219 assert_eq!(unsafe { arr.get_str_unchecked(i) }, *exp);
1220 }
1221 match (valid, &arr.null_mask) {
1222 (None, None) => {}
1223 (Some(expected), Some(mask)) => {
1224 for (i, bit) in expected.iter().enumerate() {
1225 assert_eq!(unsafe { mask.get_unchecked(i) }, *bit);
1226 }
1227 }
1228 (None, Some(mask)) => {
1229 assert!(mask.all_true());
1230 }
1231 (Some(_), None) => panic!("expected mask missing"),
1232 }
1233 }
1234
1235
1236 #[test]
1239 fn str_num_multiply() {
1240 let input = StringArray::<u32>::from_slice(&["hi", "bye", "x"]);
1241 let nums: &[i32] = &[3, 2, 0];
1242 let input_slice = (&input, 0, input.len(), input.data.len());
1243 let out =
1244 super::apply_str_num::<u32, i32, u32>(input_slice, nums, ArithmeticOperator::Multiply)
1245 .unwrap();
1246 assert_str(&out, &["hihihi", "byebye", ""], None);
1247 }
1248
1249 #[test]
1250 fn str_num_multiply_chunk() {
1251 let base = StringArray::<u32>::from_slice(&["pad", "hi", "bye", "x", "pad2"]);
1252 let nums: &[i32] = &[3, 2, 0];
1253 let input_slice = (&base, 1, 3, base.data.len());
1255 let out =
1256 super::apply_str_num::<u32, i32, u32>(input_slice, nums, ArithmeticOperator::Multiply)
1257 .unwrap();
1258 assert_str(&out, &["hihihi", "byebye", ""], None);
1259 }
1260
1261 #[test]
1262 fn str_num_len_mismatch() {
1263 let input = StringArray::<u32>::from_slice(&["a"]);
1264 let nums: &[i32] = &[1, 2];
1265 let input_slice = (&input, 0, input.len(), input.data.len());
1266 let err = super::apply_str_num::<u32, i32, u32>(input_slice, nums, ArithmeticOperator::Add)
1267 .unwrap_err();
1268 match err {
1269 KernelError::LengthMismatch(_) => {}
1270 _ => panic!("wrong error variant"),
1271 }
1272 }
1273
1274 #[test]
1275 fn str_num_len_mismatch_chunk() {
1276 let base = StringArray::<u32>::from_slice(&["pad", "a", "pad2"]);
1277 let nums: &[i32] = &[1, 2];
1278 let input_slice = (&base, 1, 1, base.data.len());
1280 let err = super::apply_str_num::<u32, i32, u32>(input_slice, nums, ArithmeticOperator::Add)
1281 .unwrap_err();
1282 match err {
1283 KernelError::LengthMismatch(_) => {}
1284 _ => panic!("wrong error variant"),
1285 }
1286 }
1287
1288 #[cfg(feature = "str_arithmetic")]
1289 #[test]
1290 fn str_float_all_ops() {
1291 let input = StringArray::<u32>::from_slice(&["foo", "bar1", "baz"]);
1292 let nums: &[f64] = &[1.0, 1.0, 2.0];
1293 let input_slice = (&input, 0, input.len());
1294 let add = super::apply_str_float(input_slice, nums, ArithmeticOperator::Add).unwrap();
1296 assert_str(&add, &["foo1", "bar11", "baz2"], None);
1297 let sub = super::apply_str_float(input_slice, nums, ArithmeticOperator::Subtract).unwrap();
1299 assert_str(&sub, &["foo", "bar", "baz"], None);
1300 let mul = super::apply_str_float(input_slice, nums, ArithmeticOperator::Multiply).unwrap();
1302 assert_str(&mul, &["foo", "bar1", "bazbaz"], None);
1303 let div = super::apply_str_float(input_slice, nums, ArithmeticOperator::Divide).unwrap();
1305 assert_str(&div, &["foo", "bar|", "baz"], None);
1306 }
1307
/// Windowed variant of the string/float operator sweep: the three middle rows
/// of a padded array are processed and must yield the same expected strings
/// as the unpadded test.
#[cfg(feature = "str_arithmetic")]
#[test]
fn str_float_all_ops_chunk() {
    let padded = StringArray::<u32>::from_slice(&["pad", "foo", "bar1", "baz", "pad2"]);
    let floats: &[f64] = &[1.0, 1.0, 2.0];
    // Offset 1, length 3: the padding rows on both ends are excluded.
    let window = (&padded, 1, 3);

    // Add
    let want_add = ["foo1", "bar11", "baz2"];
    let got_add = super::apply_str_float(window, floats, ArithmeticOperator::Add).unwrap();
    assert_str(&got_add, &want_add, None);

    // Subtract
    let want_sub = ["foo", "bar", "baz"];
    let got_sub = super::apply_str_float(window, floats, ArithmeticOperator::Subtract).unwrap();
    assert_str(&got_sub, &want_sub, None);

    // Multiply
    let want_mul = ["foo", "bar1", "bazbaz"];
    let got_mul = super::apply_str_float(window, floats, ArithmeticOperator::Multiply).unwrap();
    assert_str(&got_mul, &want_mul, None);

    // Divide
    let want_div = ["foo", "bar|", "baz"];
    let got_div = super::apply_str_float(window, floats, ArithmeticOperator::Divide).unwrap();
    assert_str(&got_div, &want_div, None);
}
1328
1329
/// Test helper: builds a `CategoricalArray<u32>` from a slice of string values.
#[cfg(feature = "str_arithmetic")]
fn cat(values: &[&str]) -> CategoricalArray<u32> {
    let labels = values.iter().copied();
    CategoricalArray::<u32>::from_values(labels)
}
1336
/// Adds two categorical arrays row by row and checks the concatenated text of
/// every output row (a missing output value is compared as the empty string).
#[cfg(feature = "str_arithmetic")]
#[test]
fn dict32_dict32_add() {
    let left = cat(&["A", "B", ""]);
    let right = cat(&["1", "2", "3"]);
    // Full-array windows: offset 0, every code in `data`.
    let left_window = (&left, 0, left.data.len());
    let right_window = (&right, 0, right.data.len());

    let combined =
        super::apply_dict32_dict32(left_window, right_window, ArithmeticOperator::Add).unwrap();

    let wanted = vec64!["A1", "B2", "3"];
    wanted.iter().enumerate().for_each(|(row, text)| {
        assert_eq!(combined.get(row).unwrap_or(""), *text);
    });
}
1352
/// Windowed variant of the categorical + categorical test: only the three
/// rows after the leading pad row of each array take part.
#[cfg(feature = "str_arithmetic")]
#[test]
fn dict32_dict32_add_chunk() {
    let left = cat(&["pad", "A", "B", "", "pad2"]);
    let right = cat(&["padx", "1", "2", "3", "pady"]);
    // Offset 1, length 3 on both sides.
    let left_window = (&left, 1, 3);
    let right_window = (&right, 1, 3);

    let combined =
        super::apply_dict32_dict32(left_window, right_window, ArithmeticOperator::Add).unwrap();

    let wanted = vec64!["A1", "B2", "3"];
    wanted.iter().enumerate().for_each(|(row, text)| {
        assert_eq!(combined.get(row).unwrap_or(""), *text);
    });
}
1367
/// Subtracts a string array from a categorical array and checks both rows:
/// "hello" - "l" is expected to give "helo", "yellow" - "el" to give "ylow".
#[cfg(feature = "str_arithmetic")]
#[test]
fn dict32_str_subtract() {
    let words = cat(&["hello", "yellow"]);
    let needles = StringArray::<u32>::from_slice(&["l", "el"]);
    let words_window = (&words, 0, words.data.len());
    let needles_window = (&needles, 0, needles.len());

    let trimmed =
        super::apply_dict32_str(words_window, needles_window, ArithmeticOperator::Subtract)
            .unwrap();

    let wanted = ["helo", "ylow"];
    for (row, text) in wanted.iter().enumerate() {
        assert_eq!(trimmed.get(row).unwrap(), *text);
    }
}
1380
/// Windowed variant of the categorical - string test: the two middle rows of
/// each padded input are processed and must match the unpadded expectations.
#[cfg(feature = "str_arithmetic")]
#[test]
fn dict32_str_subtract_chunk() {
    let words = cat(&["pad", "hello", "yellow", "pad2"]);
    let needles = StringArray::<u32>::from_slice(&["pad", "l", "el", "pad2"]);
    // Offset 1, length 2 on both sides.
    let words_window = (&words, 1, 2);
    let needles_window = (&needles, 1, 2);

    let trimmed =
        super::apply_dict32_str(words_window, needles_window, ArithmeticOperator::Subtract)
            .unwrap();

    let wanted = ["helo", "ylow"];
    for (row, text) in wanted.iter().enumerate() {
        assert_eq!(trimmed.get(row).unwrap(), *text);
    }
}
1393
/// Divides a string array by a categorical array holding the separator ":"
/// and validates the output with `assert_str`.
#[cfg(feature = "str_arithmetic")]
#[test]
fn str_dict32_divide() {
    let joined = StringArray::<u32>::from_slice(&["a:b:c"]);
    let separator = cat(&[":"]);
    let joined_window = (&joined, 0, joined.len());
    let separator_window = (&separator, 0, separator.data.len());

    let pieces =
        super::apply_str_dict32(joined_window, separator_window, ArithmeticOperator::Divide)
            .unwrap();

    assert_str(&pieces, &["a", "b", "c"], None);
}
1405
/// Windowed variant of the string / categorical test: only the single middle
/// row of each padded input is processed.
#[cfg(feature = "str_arithmetic")]
#[test]
fn str_dict32_divide_chunk() {
    let joined = StringArray::<u32>::from_slice(&["pad", "a:b:c", "pad2"]);
    let separator = cat(&["pad", ":", "pad2"]);
    // Offset 1, length 1 on both sides.
    let joined_window = (&joined, 1, 1);
    let separator_window = (&separator, 1, 1);

    let pieces =
        super::apply_str_dict32(joined_window, separator_window, ArithmeticOperator::Divide)
            .unwrap();

    assert_str(&pieces, &["a", "b", "c"], None);
}
1418
/// Multiplies a categorical array by per-row counts: "x" with 3 is expected
/// to become "xxx", and "y" with 1 to stay "y".
#[cfg(feature = "str_arithmetic")]
#[test]
fn dict32_num_multiply() {
    let letters = cat(&["x", "y"]);
    let counts: &[u32] = &[3, 1];
    let row_count = letters.data.len();
    let letters_window = (&letters, 0, row_count);
    let counts_window = &counts[..row_count];

    let repeated =
        super::apply_dict32_num(letters_window, counts_window, ArithmeticOperator::Multiply)
            .unwrap();

    let wanted = ["xxx", "y"];
    for (row, text) in wanted.iter().enumerate() {
        assert_eq!(repeated.get(row).unwrap(), *text);
    }
}
1431
/// Windowed variant of the categorical * number test: rows 1 and 2 of the
/// padded categorical array are paired with the matching slice of counts.
#[cfg(feature = "str_arithmetic")]
#[test]
fn dict32_num_multiply_chunk() {
    let letters = cat(&["pad", "x", "y", "pad2"]);
    let counts: &[u32] = &[0, 3, 1, 0];
    // Offset 1, length 2; the numeric side is sliced to the same rows.
    let letters_window = (&letters, 1, 2);
    let counts_window = &counts[1..3];

    let repeated =
        super::apply_dict32_num(letters_window, counts_window, ArithmeticOperator::Multiply)
            .unwrap();

    let wanted = ["xxx", "y"];
    for (row, text) in wanted.iter().enumerate() {
        assert_eq!(repeated.get(row).unwrap(), *text);
    }
}
1444
/// Test helper: builds a string array (no null mask) from `strings` and
/// returns it together with its categorical conversion, categorical first.
#[cfg(feature = "str_arithmetic")]
fn cat32_str_arr(strings: &[&str]) -> (CategoricalArray<u32>, StringArray<u32>) {
    let as_strings: StringArray<u32> = StringArray::from_vec(strings.to_vec(), None);
    let as_categories = as_strings.to_categorical_array();
    (as_categories, as_strings)
}
1451
/// Cross-checks `apply_dict32_str` against `apply_str_str` for Add and Divide.
///
/// The categorical left-hand side is converted to a string array, run through
/// `apply_str_str`, and the result converted back to categorical form; both
/// paths must agree on `unique_values` and `data`.
#[cfg(feature = "str_arithmetic")]
#[test]
fn test_apply_dict32_str_add_and_divide() {
    let (dict_lhs, str_rhs) = cat32_str_arr(&["foo", "bar|baz", ""]);
    let dict_window = (&dict_lhs, 0, dict_lhs.data.len());
    let str_window = (&str_rhs, 0, str_rhs.len());
    // String form of the left-hand side, used for both reference computations.
    let lhs_strings = dict_lhs.to_string_array();
    let reference_window = (&lhs_strings, 0, dict_lhs.len());

    // Add
    let got_add = apply_dict32_str(dict_window, str_window, ArithmeticOperator::Add).unwrap();
    let want_add = apply_str_str(reference_window, str_window, ArithmeticOperator::Add)
        .unwrap()
        .to_categorical_array();
    assert_eq!(got_add.unique_values, want_add.unique_values);
    assert_eq!(got_add.data, want_add.data);

    // Divide
    let got_div = apply_dict32_str(dict_window, str_window, ArithmeticOperator::Divide).unwrap();
    let want_div = apply_str_str(reference_window, str_window, ArithmeticOperator::Divide)
        .unwrap()
        .to_categorical_array();
    assert_eq!(got_div.unique_values, want_div.unique_values);
    assert_eq!(got_div.data, want_div.data);
}
1484
/// Windowed variant of the `apply_dict32_str` / `apply_str_str` cross-check:
/// both kernels see the three rows starting at offset 1 of the padded inputs
/// and must agree on `unique_values` and `data` for Add and Divide.
#[cfg(feature = "str_arithmetic")]
#[test]
fn test_apply_dict32_str_add_and_divide_chunk() {
    let (dict_lhs, str_rhs) = cat32_str_arr(&["pad", "foo", "bar|baz", "", "pad2"]);
    let dict_window = (&dict_lhs, 1, 3);
    let str_window = (&str_rhs, 1, 3);
    // String form of the left-hand side, windowed identically for the reference path.
    let lhs_strings = dict_lhs.to_string_array();
    let reference_window = (&lhs_strings, 1, 3);

    // Add
    let got_add = apply_dict32_str(dict_window, str_window, ArithmeticOperator::Add).unwrap();
    let want_add = apply_str_str(reference_window, str_window, ArithmeticOperator::Add)
        .unwrap()
        .to_categorical_array();
    assert_eq!(got_add.unique_values, want_add.unique_values);
    assert_eq!(got_add.data, want_add.data);

    // Divide
    let got_div = apply_dict32_str(dict_window, str_window, ArithmeticOperator::Divide).unwrap();
    let want_div = apply_str_str(reference_window, str_window, ArithmeticOperator::Divide)
        .unwrap()
        .to_categorical_array();
    assert_eq!(got_div.unique_values, want_div.unique_values);
    assert_eq!(got_div.data, want_div.data);
}
1518
1519
/// Test helper: builds a `StringArray<T>` from `data`, attaching a null mask
/// derived from `nulls` when one is supplied, and asserts that the array
/// length equals the number of input strings.
#[cfg(feature = "str_arithmetic")]
fn string_array<T: Integer>(data: &[&str], nulls: Option<&[bool]>) -> StringArray<T> {
    let mask = nulls.map(Bitmask::from_bools);
    let built: StringArray<T> = StringArray::from_vec(data.to_vec(), mask);
    assert_eq!(built.len(), data.len());
    built
}
1528
/// Adds two string arrays and expects row-wise concatenation
/// ("a" + "x" gives "ax", and so on).
#[cfg(feature = "str_arithmetic")]
#[test]
fn test_add_str() {
    let left = string_array::<u32>(&["a", "b", "c"], None);
    let right = string_array::<u32>(&["x", "y", "z"], None);
    let left_window = (&left, 0, left.len());
    let right_window = (&right, 0, right.len());

    let joined = apply_str_str(left_window, right_window, ArithmeticOperator::Add).unwrap();

    let wanted = ["ax", "by", "cz"];
    for (row, text) in wanted.iter().enumerate() {
        assert_eq!(joined.get(row), Some(*text));
    }
}
1542
/// Windowed variant of the string + string test: the three middle rows of
/// each padded array are concatenated row by row.
#[cfg(feature = "str_arithmetic")]
#[test]
fn test_add_str_chunk() {
    let left = string_array::<u32>(&["pad", "a", "b", "c", "pad2"], None);
    let right = string_array::<u32>(&["pad", "x", "y", "z", "pad2"], None);
    // Offset 1, length 3 on both sides.
    let left_window = (&left, 1, 3);
    let right_window = (&right, 1, 3);

    let joined = apply_str_str(left_window, right_window, ArithmeticOperator::Add).unwrap();

    let wanted = ["ax", "by", "cz"];
    for (row, text) in wanted.iter().enumerate() {
        assert_eq!(joined.get(row), Some(*text));
    }
}
1557
/// Subtracts one string array from another. Expected rows: "hello" - "l" is
/// "helo", "goodbye" - "bye" is "good", and "test" - "xyz" stays "test"
/// because the right-hand text does not occur in it.
#[cfg(feature = "str_arithmetic")]
#[test]
fn test_subtract_str() {
    let left = string_array::<u32>(&["hello", "goodbye", "test"], None);
    let right = string_array::<u32>(&["l", "bye", "xyz"], None);
    let left_window = (&left, 0, left.len());
    let right_window = (&right, 0, right.len());

    let trimmed = apply_str_str(left_window, right_window, ArithmeticOperator::Subtract).unwrap();

    let wanted = ["helo", "good", "test"];
    for (row, text) in wanted.iter().enumerate() {
        assert_eq!(trimmed.get(row), Some(*text));
    }
}
1571
/// Windowed variant of the string - string test: the three middle rows of
/// each padded array must produce the same results as the unpadded test.
#[cfg(feature = "str_arithmetic")]
#[test]
fn test_subtract_str_chunk() {
    let left = string_array::<u32>(&["pad", "hello", "goodbye", "test", "pad2"], None);
    let right = string_array::<u32>(&["pad", "l", "bye", "xyz", "pad2"], None);
    // Offset 1, length 3 on both sides.
    let left_window = (&left, 1, 3);
    let right_window = (&right, 1, 3);

    let trimmed = apply_str_str(left_window, right_window, ArithmeticOperator::Subtract).unwrap();

    let wanted = ["helo", "good", "test"];
    for (row, text) in wanted.iter().enumerate() {
        assert_eq!(trimmed.get(row), Some(*text));
    }
}
1586
/// Multiplies two string arrays. Each expected row is the left string
/// repeated as many times as the right string is long ("x" with "123" is
/// expected to give "xxx").
#[cfg(feature = "str_arithmetic")]
#[test]
fn test_multiply_str() {
    let left = string_array::<u32>(&["x", "ab", "c"], None);
    let right = string_array::<u32>(&["123", "12", "long_string"], None);
    let left_window = (&left, 0, left.len());
    let right_window = (&right, 0, right.len());

    let repeated = apply_str_str(left_window, right_window, ArithmeticOperator::Multiply).unwrap();

    // The last expectation is derived from the length of the right-hand string.
    let long_row = "c".repeat("long_string".len());
    let wanted = ["xxx", "abab", long_row.as_str()];
    for (row, text) in wanted.iter().enumerate() {
        assert_eq!(repeated.get(row), Some(*text));
    }
}
1603
/// Windowed variant of the string * string test: the three middle rows of
/// each padded array must produce the same results as the unpadded test.
#[cfg(feature = "str_arithmetic")]
#[test]
fn test_multiply_str_chunk() {
    let left = string_array::<u32>(&["pad", "x", "ab", "c", "pad2"], None);
    let right = string_array::<u32>(&["pad", "123", "12", "long_string", "pad2"], None);
    // Offset 1, length 3 on both sides.
    let left_window = (&left, 1, 3);
    let right_window = (&right, 1, 3);

    let repeated = apply_str_str(left_window, right_window, ArithmeticOperator::Multiply).unwrap();

    // The last expectation is derived from the length of the right-hand string.
    let long_row = "c".repeat("long_string".len());
    let wanted = ["xxx", "abab", long_row.as_str()];
    for (row, text) in wanted.iter().enumerate() {
        assert_eq!(repeated.get(row), Some(*text));
    }
}
1621
/// Divides one string array by another. The expected rows have every
/// occurrence of the right-hand text replaced by "|", and a row with an empty
/// right-hand string is expected to come back unchanged.
#[cfg(feature = "str_arithmetic")]
#[test]
fn test_divide_str() {
    let left = string_array::<u32>(&["a,b,c", "a--b--c", "abc"], None);
    let right = string_array::<u32>(&[",", "--", ""], None);
    let left_window = (&left, 0, left.len());
    let right_window = (&right, 0, right.len());

    let split = apply_str_str(left_window, right_window, ArithmeticOperator::Divide).unwrap();

    let wanted = ["a|b|c", "a|b|c", "abc"];
    for (row, text) in wanted.iter().enumerate() {
        assert_eq!(split.get(row), Some(*text));
    }
}
1635
/// Windowed variant of the string / string test: the three middle rows of
/// each padded array must produce the same results as the unpadded test.
#[cfg(feature = "str_arithmetic")]
#[test]
fn test_divide_str_chunk() {
    let left = string_array::<u32>(&["xxx", "a,b,c", "a--b--c", "abc", "yyy"], None);
    let right = string_array::<u32>(&["", ",", "--", "", ""], None);
    // Offset 1, length 3 on both sides.
    let left_window = (&left, 1, 3);
    let right_window = (&right, 1, 3);

    let split = apply_str_str(left_window, right_window, ArithmeticOperator::Divide).unwrap();

    let wanted = ["a|b|c", "a|b|c", "abc"];
    for (row, text) in wanted.iter().enumerate() {
        assert_eq!(split.get(row), Some(*text));
    }
}
1650
/// Adds two string arrays that both carry null masks. The only row expected
/// to hold a value is the one where both mask entries are `true`; the other
/// rows are expected to read back as `None`.
#[cfg(feature = "str_arithmetic")]
#[test]
fn test_nulls_str() {
    let left = string_array::<u32>(&["a", "b", "c"], Some(&[true, false, true]));
    let right = string_array::<u32>(&["x", "y", "z"], Some(&[true, true, false]));
    let left_window = (&left, 0, left.len());
    let right_window = (&right, 0, right.len());

    let joined = apply_str_str(left_window, right_window, ArithmeticOperator::Add).unwrap();

    let wanted: [Option<&str>; 3] = [Some("ax"), None, None];
    for (row, cell) in wanted.iter().enumerate() {
        assert_eq!(joined.get(row), *cell);
    }
}
1664
/// Windowed variant of the null-mask test: the three middle rows of each
/// padded, masked array are added. Only the first windowed row has `true` in
/// both masks, so only it is expected to hold a value.
#[cfg(feature = "str_arithmetic")]
#[test]
fn test_nulls_str_chunk() {
    let left_mask = [false, true, false, true, false];
    let right_mask = [true, true, true, false, false];
    let left = string_array::<u32>(&["0", "a", "b", "c", "9"], Some(&left_mask));
    let right = string_array::<u32>(&["y", "x", "y", "z", "w"], Some(&right_mask));
    // Offset 1, length 3 on both sides.
    let left_window = (&left, 1, 3);
    let right_window = (&right, 1, 3);

    let joined = apply_str_str(left_window, right_window, ArithmeticOperator::Add).unwrap();

    let wanted: [Option<&str>; 3] = [Some("ax"), None, None];
    for (row, cell) in wanted.iter().enumerate() {
        assert_eq!(joined.get(row), *cell);
    }
}
1685
/// Adding a two-row string array to a one-row string array must fail with
/// `KernelError::LengthMismatch`.
#[cfg(feature = "str_arithmetic")]
#[test]
fn test_mismatched_length_str() {
    let left = string_array::<u32>(&["a", "b"], None);
    let right = string_array::<u32>(&["x"], None);
    let left_window = (&left, 0, left.len());
    let right_window = (&right, 0, right.len());

    let outcome = apply_str_str(left_window, right_window, ArithmeticOperator::Add);

    let is_length_mismatch = matches!(outcome, Err(KernelError::LengthMismatch(_)));
    assert!(is_length_mismatch);
}
1696
/// Windowed variant of the length-mismatch test: a two-row window on the left
/// against a one-row window on the right must fail with
/// `KernelError::LengthMismatch`.
#[cfg(feature = "str_arithmetic")]
#[test]
fn test_mismatched_length_str_chunk() {
    let left = string_array::<u32>(&["a", "b", "c"], None);
    let right = string_array::<u32>(&["x"], None);
    // Left: offset 1, length 2. Right: offset 0, length 1.
    let left_window = (&left, 1, 2);
    let right_window = (&right, 0, 1);

    let outcome = apply_str_str(left_window, right_window, ArithmeticOperator::Add);

    let is_length_mismatch = matches!(outcome, Err(KernelError::LengthMismatch(_)));
    assert!(is_length_mismatch);
}
1708}