simd_kernels/kernels/arithmetic/
string.rs

1// Copyright Peter Bower 2025. All Rights Reserved.
2// Licensed under Mozilla Public License (MPL) 2.0.
3
4//! # **String Arithmetic Module** - *String Operations with Numeric Interactions*
5//!
6//! String-specific arithmetic operations including string multiplication, concatenation, and manipulation.
//! This lets strings take part in typical numeric-style workloads, e.g., "hello" + "there" = "hellothere".
8//! These are opt-in via the "str_arithmetic" feature.
9//!
10//! ## Overview
11//! - **String multiplication**: Repeat strings by numeric factors with configurable limits
12//! - **String-numeric conversions**: Format numbers into string representations  
13//! - **Categorical operations**: Efficient string deduplication and categorical array generation
14//! - **Null-aware processing**: Full Arrow-compatible null propagation
15//!
16//! ## Features
17//! - **Memory efficiency**: Uses string interning and categorical encoding to reduce allocation overhead
18//! - **Safety limits**: Configurable multiplication limits prevent excessive memory usage
19//! - **Optional dependencies**: String-numeric arithmetic gated behind `str_arithmetic` feature
20
21include!(concat!(env!("OUT_DIR"), "/simd_lanes.rs"));
22
23#[cfg(feature = "fast_hash")]
24use ahash::AHashMap;
25#[cfg(feature = "str_arithmetic")]
26use core::ptr::copy_nonoverlapping;
27#[cfg(not(feature = "fast_hash"))]
28use std::collections::HashMap;
29
30#[cfg(feature = "str_arithmetic")]
31use memchr::memmem::Finder;
32use minarrow::structs::variants::categorical::CategoricalArray;
33
34use minarrow::structs::variants::string::StringArray;
35use minarrow::traits::type_unions::Integer;
36use minarrow::{Bitmask, Vec64};
37#[cfg(feature = "str_arithmetic")]
38use num_traits::ToPrimitive;
39
40use crate::config::STRING_MULTIPLICATION_LIMIT;
41use crate::errors::{KernelError, log_length_mismatch};
42#[cfg(feature = "str_arithmetic")]
43use crate::kernels::string::string_predicate_masks;
44use crate::operators::ArithmeticOperator::{self};
45#[cfg(feature = "str_arithmetic")]
46use crate::utils::format_finite;
47use crate::utils::merge_bitmasks_to_new;
48#[cfg(feature = "str_arithmetic")]
49use crate::utils::{
50    confirm_mask_capacity, estimate_categorical_cardinality, estimate_string_cardinality,
51};
52use minarrow::{CategoricalAVT, StringAVTExt};
53#[cfg(feature = "str_arithmetic")]
54use minarrow::{MaskedArray, StringAVT};
55
56/// String-numeric arithmetic operation dispatcher.
57/// Supports string multiplication (repeat string N times) with safety limits.
58/// Other operations pass through unchanged. Handles null propagation correctly.
59///
60/// # Type Parameters
61/// - `T`: Offset type for the input array (e.g., `u32`, `u64`)
62/// - `N`: Numeric type convertible to `usize`
63/// - `O`: Offset type for the output array
64pub fn apply_str_num<T, N, O>(
65    lhs: StringAVTExt<T>,
66    rhs: &[N],
67    op: ArithmeticOperator,
68) -> Result<StringArray<O>, KernelError>
69where
70    T: Integer,
71    N: num_traits::ToPrimitive + Copy,
72    O: Integer + num_traits::NumCast,
73{
74    let (array, offset, logical_len, physical_bytes_len) = lhs;
75
76    if logical_len != rhs.len() {
77        return Err(KernelError::LengthMismatch(log_length_mismatch(
78            "apply_str_num".to_string(),
79            logical_len,
80            rhs.len(),
81        )));
82    }
83
84    // Preallocate offsets and estimate capacity
85    let lhs_mask = array.null_mask.as_ref();
86    let mut out_mask = lhs_mask.map(|_| minarrow::Bitmask::new_set_all(logical_len, true));
87
88    let mut offsets = Vec64::<O>::with_capacity(logical_len + 1);
89    offsets.push(O::zero()); // initial offset 0
90
91    let estimated_bytes = physical_bytes_len.min(STRING_MULTIPLICATION_LIMIT * logical_len);
92    let mut data = Vec64::with_capacity(estimated_bytes);
93
94    for (out_idx, i) in (offset..offset + logical_len).enumerate() {
95        let valid = lhs_mask.map_or(true, |mask| unsafe { mask.get_unchecked(i) });
96
97        if let Some(mask) = &mut out_mask {
98            unsafe { mask.set_unchecked(out_idx, valid) };
99        }
100
101        if valid {
102            let s = unsafe { array.get_str_unchecked(i) };
103            let n = rhs[out_idx].to_usize().unwrap_or(0);
104
105            match op {
106                ArithmeticOperator::Multiply => {
107                    let count = n.min(STRING_MULTIPLICATION_LIMIT);
108                    for _ in 0..count {
109                        data.extend_from_slice(s.as_bytes());
110                    }
111                }
112                _ => {
113                    data.extend_from_slice(s.as_bytes());
114                }
115            }
116        }
117
118        // Push offset regardless of validity to keep offsets aligned
119        // This ensures we can still slice [a..b] intuitively.
120        let new_offset = O::from(data.len()).expect("offset conversion overflow");
121        offsets.push(new_offset);
122    }
123
124    Ok(StringArray {
125        offsets: offsets.into(),
126        data: data.into(),
127        null_mask: out_mask,
128    })
129}
130
/// Applies an element-wise binary operation between a string window and a slice of
/// floating-point values, producing a new `StringArray<T>`. Each float is rendered to its
/// textual form via `format_finite` and then used as a string operand.
///
/// Supported operations:
///
/// - `Add`: Appends the stringified float to the string from `lhs`.
/// - `Subtract`: Removes the first occurrence of the stringified float from `lhs`, if present.
///               If not found, `lhs` is returned unchanged.
/// - `Multiply`: Repeats the `lhs` string `N` times, where
///               `N = abs(round(rhs)) % (STRING_MULTIPLICATION_LIMIT + 1)`.
/// - `Divide`: Splits the `lhs` string by the stringified float, and joins the segments using `'|'`.
///             If the float pattern is not found, the original string is returned unchanged.
///
/// Null handling:
/// - If the `lhs` value is null at an index, the result is null at that index and contributes
///   no bytes (its end offset equals the previous one).
/// - The float operand cannot be null; the caller must guarantee its presence.
/// - An output mask is produced only when the input array carries one.
///
/// Output:
/// - A new `StringArray<T>` with the same number of rows as the `lhs` window.
/// - The byte buffer is allocated once, sized by a prepass that computes an upper bound on
///   the bytes required; its final length is exactly the number of bytes written.
///
/// Errors:
/// - Returns `KernelError::LengthMismatch` if `lhs` and `rhs` lengths differ.
/// - Returns `KernelError::UnsupportedType` if the operator is not one of Add, Subtract,
///   Multiply, Divide (detected on the first non-null row).
///
/// # Features
/// This function is only available when the `str_arithmetic` feature is enabled.
///
/// This kernel is optional as it pulls in external dependencies, and fits more niche use cases. It's a
/// good fit for flex-typing scenarios where users are concatenating strings and numbers, or working
/// with semi-structured web content, string formatting pipelines etc.
///
/// # Panics
/// Panics if the accumulated byte length does not fit in the offset type `T`.
///
/// # Safety
/// - Uses unchecked mask access and raw pointer copies for performance. The window
///   `[offset, offset + len)` must lie within the array and its null mask.
/// - Assumes `rhs` contains only finite floating-point values.
#[cfg(feature = "str_arithmetic")]
pub fn apply_str_float<T, F>(
    lhs: StringAVT<T>,
    rhs: &[F],
    op: ArithmeticOperator,
) -> Result<StringArray<T>, KernelError>
where
    T: Integer,
    F: Into<f64> + Copy + ryu::Float,
{
    use std::mem::MaybeUninit;

    // Destructure the string slice: array, offset, and logical length
    let (array, offset, logical_len) = lhs;

    // Validate inputs
    if rhs.len() != logical_len {
        return Err(KernelError::LengthMismatch(log_length_mismatch(
            "apply_str_float".into(),
            logical_len,
            rhs.len(),
        )));
    }
    let lhs_mask = &array.null_mask;
    let _ = confirm_mask_capacity(array.len(), lhs_mask.as_ref())?;

    // Scratch buffer the float formatter writes into; reused for every row.
    let mut fmt_buf: [MaybeUninit<u8>; 24] = [MaybeUninit::uninit(); 24];

    // Repetition count used by `Multiply`, shared by both passes so they always agree.
    let repeat_count = |value: F| -> usize {
        let as_f64: f64 = value.into();
        as_f64.round().abs() as usize % (STRING_MULTIPLICATION_LIMIT + 1)
    };

    // 1st pass: compute an upper bound on the number of output bytes.
    let mut total_bytes = 0usize;
    for (out_idx, i) in (offset..offset + logical_len).enumerate() {
        if !lhs_mask
            .as_ref()
            .map_or(true, |m| unsafe { m.get_unchecked(i) })
        {
            continue;
        }

        let src_len = {
            let a = array.offsets[i].to_usize();
            let b = array.offsets[i + 1].to_usize();
            b - a
        };
        total_bytes += match op {
            // Exact.
            ArithmeticOperator::Add => src_len + format_finite(&mut fmt_buf, rhs[out_idx]).len(),
            // Upper bound: a match removes bytes, never adds any.
            ArithmeticOperator::Subtract => src_len,
            // Exact.
            ArithmeticOperator::Multiply => src_len * repeat_count(rhs[out_idx]),
            // Upper bound: each match drops a non-empty pattern and inserts a single '|',
            // so the result can never be longer than the source.
            ArithmeticOperator::Divide => src_len,
            _ => {
                return Err(KernelError::UnsupportedType(format!(
                    "Unsupported {:?}",
                    op
                )));
            }
        };
    }

    // Allocate outputs once. `data` keeps length 0 while we write into its spare capacity;
    // the length is set to the number of bytes actually written once the build pass is done,
    // so no uninitialised tail is ever exposed to the caller.
    let mut offsets = Vec64::<T>::with_capacity(logical_len + 1);
    let mut data = Vec64::<u8>::with_capacity(total_bytes);
    let dst = data.as_mut_ptr();

    let mut out_mask = lhs_mask
        .as_ref()
        .map(|_| Bitmask::new_set_all(logical_len, false));

    let mut cursor = 0usize;
    offsets.push(T::zero());

    // 2nd pass: copy / build strings
    for (out_idx, i) in (offset..offset + logical_len).enumerate() {
        let valid = lhs_mask
            .as_ref()
            .map_or(true, |m| unsafe { m.get_unchecked(i) });
        if let Some(mask) = &mut out_mask {
            unsafe { mask.set_unchecked(out_idx, valid) };
        }

        if valid {
            let start = array.offsets[i].to_usize();
            let end = array.offsets[i + 1].to_usize();
            let src = &array.data[start..end];
            let n_s = format_finite(&mut fmt_buf, rhs[out_idx]);
            let pat = n_s.as_bytes();

            let mut write = |bytes: &[u8]| {
                debug_assert!(cursor + bytes.len() <= total_bytes);
                // SAFETY: the sizing pass reserved at least `total_bytes` bytes and the
                // per-operator bounds above guarantee `cursor + bytes.len()` stays within
                // that capacity. `bytes` borrows from the input array, the formatter
                // scratch buffer, or a literal, none of which alias the output buffer.
                unsafe {
                    copy_nonoverlapping(bytes.as_ptr(), dst.add(cursor), bytes.len());
                }
                cursor += bytes.len();
            };

            match op {
                ArithmeticOperator::Add => {
                    write(src);
                    write(pat);
                }
                ArithmeticOperator::Subtract => {
                    if let Some(idx) = Finder::new(pat).find(src) {
                        write(&src[..idx]);
                        write(&src[(idx + pat.len())..]);
                    } else {
                        write(src);
                    }
                }
                ArithmeticOperator::Multiply => {
                    for _ in 0..repeat_count(rhs[out_idx]) {
                        write(src);
                    }
                }
                ArithmeticOperator::Divide => {
                    // Emit "seg|seg|...|tail": every matched pattern is replaced by '|'.
                    let finder = Finder::new(pat);
                    let mut rest = src;
                    while let Some(idx) = finder.find(rest) {
                        write(&rest[..idx]);
                        write(b"|");
                        rest = &rest[idx + pat.len()..];
                    }
                    write(rest);
                }
                _ => unreachable!("unsupported operators are rejected during the sizing pass"),
            }
        }

        // Null rows repeat the previous offset, i.e. an empty slot.
        offsets.push(T::from(cursor).expect("offset conversion overflow"));
    }

    // SAFETY: exactly `cursor` bytes starting at `dst` were initialised by `write`, and
    // `cursor <= total_bytes <= data.capacity()`.
    unsafe {
        data.set_len(cursor);
    }

    Ok(StringArray {
        offsets: offsets.into(),
        data: data.into(),
        null_mask: out_mask,
    })
}
339
/// Looks up `s` in the interning dictionary (fast-hash variant) and returns its code.
///
/// When `s` has not been seen before, it is appended to `uniq`, registered in `dict`
/// under the index it received in `uniq`, and that new index is returned.
#[cfg(feature = "fast_hash")]
#[inline(always)]
fn intern(s: &str, dict: &mut AHashMap<String, u32>, uniq: &mut Vec64<String>) -> u32 {
    if let Some(existing) = dict.get(s).copied() {
        return existing;
    }

    // First sighting: the next free slot in `uniq` becomes the code.
    let code = uniq.len() as u32;
    uniq.push(s.to_owned());
    dict.insert(s.to_owned(), code);
    code
}
354
355/// String interning helper for categorical array generation with standard hashing.
356/// Deduplicates strings and assigns numeric codes for memory efficiency.
357#[cfg(not(feature = "fast_hash"))]
358#[inline(always)]
359fn intern(s: &str, dict: &mut HashMap<String, u32>, uniq: &mut Vec64<String>) -> u32 {
360    if let Some(&code) = dict.get(s) {
361        code
362    } else {
363        let idx = uniq.len() as u32;
364        uniq.push(s.to_owned());
365        dict.insert(s.to_owned(), idx);
366        idx
367    }
368}
369
370/// Applies an element-wise binary operation between two `CategoricalArray<u32>` arrays,
371/// producing a new `CategoricalArray<u32>`. The result reuses or extends the unified dictionary
372/// from both input arrays and ensures deterministic interned value codes.
373///
374/// Supported operations:
375///
376/// - `Add`: Concatenates strings from `lhs` and `rhs`. Result is interned into the output dictionary.
377/// - `Subtract`: Removes the first occurrence of `rhs` from `lhs`. If `rhs` is empty or not found,
378///               returns `lhs` unchanged.
379/// - `Multiply`: Returns `lhs` unchanged. No actual repetition occurs—identity operation.
380/// - `Divide`: Splits `lhs` by occurrences of `rhs`, and each resulting segment is interned separately.
381///             If `rhs` is empty, `lhs` is returned unchanged.
382///
383/// Null handling:
384/// - If either side is null at a given index, the result is marked null and the empty string code is emitted.
385/// - Null mask is propagated accordingly.
386///
387/// Output:
388/// - The resulting `CategoricalArray<u32>` may have a different length than the input
389///   if `Divide` produces multiple segments per row.
390/// - The dictionary (`unique_values`) is the union of all unique values observed in inputs and results,
391///   with stable interned codes.
392///
393/// Errors:
394/// - Returns `KernelError::LengthMismatch` if `lhs` and `rhs` differ in length.
395/// - Returns `KernelError::UnsupportedType` for any operator other than Add, Subtract, Multiply, Divide.
396///
397/// # Panics
398/// - Panics if internal memory allocation fails or if invariants are violated in unsafe regions.
399pub fn apply_dict32_dict32(
400    lhs: CategoricalAVT<u32>,
401    rhs: CategoricalAVT<u32>,
402    op: ArithmeticOperator,
403) -> Result<CategoricalArray<u32>, KernelError> {
404    // Destructure slice tuples for offset/length-local processing
405    let (lhs_array, lhs_offset, lhs_logical_len) = lhs;
406    let (rhs_array, rhs_offset, rhs_logical_len) = rhs;
407
408    if lhs_logical_len != rhs_logical_len {
409        return Err(KernelError::LengthMismatch(log_length_mismatch(
410            "apply_dict32_dict32".into(),
411            lhs_logical_len,
412            rhs_logical_len,
413        )));
414    }
415
416    // Input mask: merge only over local window
417    let in_mask = merge_bitmasks_to_new(
418        lhs_array.null_mask.as_ref(),
419        rhs_array.null_mask.as_ref(),
420        lhs_logical_len,
421    );
422
423    // Build unique dictionary for the output, initially union of both inputs
424    let mut uniq: Vec64<String> = Vec64::with_capacity(
425        lhs_array.unique_values.len() + rhs_array.unique_values.len() + lhs_logical_len,
426    );
427
428    #[cfg(feature = "fast_hash")]
429    let mut dict: AHashMap<String, u32> = AHashMap::with_capacity(uniq.capacity());
430
431    #[cfg(not(feature = "fast_hash"))]
432    let mut dict: HashMap<String, u32> = HashMap::with_capacity(uniq.capacity());
433
434    for v in lhs_array
435        .unique_values
436        .iter()
437        .chain(rhs_array.unique_values.iter())
438    {
439        if !dict.contains_key(v) {
440            let idx = uniq.len() as u32;
441            uniq.push(v.clone());
442            dict.insert(uniq.last().unwrap().clone(), idx);
443        }
444    }
445
446    // Ensure "" is present and get its code
447    let empty_code = *dict.entry("".to_owned()).or_insert_with(|| {
448        let idx = uniq.len() as u32;
449        uniq.push("".to_owned());
450        idx
451    });
452
453    // 1st pass: Count output rows for precise allocation
454    let mut total_out = 0usize;
455    for local_idx in 0..lhs_logical_len {
456        let i = lhs_offset + local_idx;
457        let j = rhs_offset + local_idx;
458        let valid = in_mask
459            .as_ref()
460            .map_or(true, |m| unsafe { m.get_unchecked(local_idx) });
461        if !valid {
462            total_out += 1;
463        } else if let ArithmeticOperator::Divide = op {
464            let l = unsafe { lhs_array.get_str_unchecked(i) };
465            let r = unsafe { rhs_array.get_str_unchecked(j) };
466            if r.is_empty() {
467                total_out += 1;
468            } else {
469                let mut parts = 0;
470                let mut start = 0;
471                while let Some(pos) = l[start..].find(r) {
472                    parts += 1;
473                    start += pos + r.len();
474                }
475                total_out += parts + 1;
476            }
477        } else {
478            total_out += 1;
479        }
480    }
481
482    // Preallocate output buffer for window only
483    let mut out_data = Vec64::with_capacity(total_out);
484    unsafe {
485        out_data.set_len(total_out);
486    }
487    let mut out_mask = Bitmask::new_set_all(total_out, false);
488
489    // 2nd pass: Populate output for this slice only
490    let mut write_ptr = 0;
491    for local_idx in 0..lhs_logical_len {
492        let i = lhs_offset + local_idx;
493        let j = rhs_offset + local_idx;
494        let valid = in_mask
495            .as_ref()
496            .map_or(true, |m| unsafe { m.get_unchecked(local_idx) });
497
498        if !valid {
499            out_data.push(empty_code);
500            unsafe { out_mask.set_unchecked(write_ptr, false) };
501            write_ptr += 1;
502            continue;
503        }
504
505        let l = unsafe { lhs_array.get_str_unchecked(i) };
506        let r = unsafe { rhs_array.get_str_unchecked(j) };
507
508        match op {
509            ArithmeticOperator::Add => {
510                let mut tmp = String::with_capacity(l.len() + r.len());
511                tmp.push_str(l);
512                tmp.push_str(r);
513                let code = intern(&tmp, &mut dict, &mut uniq);
514                unsafe {
515                    *out_data.get_unchecked_mut(write_ptr) = code;
516                }
517                out_mask.set(write_ptr, true);
518                write_ptr += 1;
519            }
520            ArithmeticOperator::Subtract => {
521                let result = if r.is_empty() {
522                    l.to_owned()
523                } else if let Some(pos) = l.find(r) {
524                    let mut tmp = String::with_capacity(l.len() - r.len());
525                    tmp.push_str(&l[..pos]);
526                    tmp.push_str(&l[pos + r.len()..]);
527                    tmp
528                } else {
529                    l.to_owned()
530                };
531                let code = intern(&result, &mut dict, &mut uniq);
532                unsafe {
533                    *out_data.get_unchecked_mut(write_ptr) = code;
534                }
535                out_mask.set(write_ptr, true);
536                write_ptr += 1;
537            }
538            ArithmeticOperator::Multiply => {
539                let code = intern(l, &mut dict, &mut uniq);
540                unsafe {
541                    *out_data.get_unchecked_mut(write_ptr) = code;
542                }
543                out_mask.set(write_ptr, true);
544                write_ptr += 1;
545            }
546            ArithmeticOperator::Divide => {
547                if r.is_empty() {
548                    let code = intern(l, &mut dict, &mut uniq);
549                    unsafe {
550                        *out_data.get_unchecked_mut(write_ptr) = code;
551                    }
552                    out_mask.set(write_ptr, true);
553                    write_ptr += 1;
554                } else {
555                    let mut start = 0;
556                    while let Some(pos) = l[start..].find(r) {
557                        let part = &l[start..start + pos];
558                        let code = intern(part, &mut dict, &mut uniq);
559                        unsafe {
560                            *out_data.get_unchecked_mut(write_ptr) = code;
561                        }
562                        out_mask.set(write_ptr, true);
563                        write_ptr += 1;
564                        start += pos + r.len();
565                    }
566                    let tail = &l[start..];
567                    let code = intern(tail, &mut dict, &mut uniq);
568                    unsafe {
569                        *out_data.get_unchecked_mut(write_ptr) = code;
570                    }
571                    out_mask.set(write_ptr, true);
572                    write_ptr += 1;
573                }
574            }
575            _ => {
576                return Err(KernelError::UnsupportedType(format!(
577                    "Unsupported apply_dict32_dict32 op={:?}",
578                    op
579                )));
580            }
581        }
582    }
583
584    debug_assert_eq!(write_ptr, total_out);
585
586    Ok(CategoricalArray {
587        data: out_data.into(),
588        unique_values: uniq,
589        null_mask: Some(out_mask),
590    })
591}
592
/// Combines two equally long string windows row by row with an arithmetic operator and
/// returns the result as a new `StringArray<T>`.
///
/// Operator semantics (with `a` from `lhs` and `b` from `rhs`):
///
/// - `Add`: `a` followed by `b`.
/// - `Subtract`: `a` with the first occurrence of `b` removed; `a` unchanged when `b` is
///               empty or absent.
/// - `Multiply`: `a` repeated `min(b.len(), STRING_MULTIPLICATION_LIMIT)` times.
/// - `Divide`: `a` split on every occurrence of `b`, with the pieces rejoined using `'|'`;
///             `a` unchanged when `b` is empty.
///
/// Null handling:
/// - A row is null in the output when either input is null at that row; such rows add no
///   bytes and repeat the previous offset.
///
/// Errors:
/// - `KernelError::LengthMismatch` when the two windows differ in length.
/// - `KernelError::UnsupportedType` for any other operator (raised on the first valid row).
///
/// # Features
/// Available only when the `str_arithmetic` feature is enabled.
#[cfg(feature = "str_arithmetic")]
pub fn apply_str_str<T, U>(
    lhs: StringAVT<T>,
    rhs: StringAVT<U>,
    op: ArithmeticOperator,
) -> Result<StringArray<T>, KernelError>
where
    T: Integer,
    U: Integer,
{
    let (larr, loff, llen) = lhs;
    let (rarr, roff, rlen) = rhs;

    if llen != rlen {
        let msg = log_length_mismatch("apply_str_str".to_string(), llen, rlen);
        return Err(KernelError::LengthMismatch(msg));
    }

    // Re-base a full-array mask onto the window [base .. base + llen).
    let rebase = |source: &Bitmask, base: usize| -> Bitmask {
        let mut local = Bitmask::new_set_all(llen, true);
        for row in 0..llen {
            unsafe {
                local.set_unchecked(row, source.get_unchecked(base + row));
            }
        }
        local
    };
    let lmask_slice = larr.null_mask.as_ref().map(|m| rebase(m, loff));
    let rmask_slice = rarr.null_mask.as_ref().map(|m| rebase(m, roff));

    // Per-position validity masks plus the output mask to fill in.
    let (lmask, rmask, mut out_mask) =
        string_predicate_masks(lmask_slice.as_ref(), rmask_slice.as_ref(), llen);
    let _ = confirm_mask_capacity(llen, lmask)?;
    let _ = confirm_mask_capacity(llen, rmask)?;

    let both_valid = |row: usize| -> bool {
        lmask.map_or(true, |m| unsafe { m.get_unchecked(row) })
            && rmask.map_or(true, |m| unsafe { m.get_unchecked(row) })
    };

    // Sizing pass: capacity hint for the byte buffer.
    let mut total_bytes = 0usize;
    for row in (0..llen).filter(|&row| both_valid(row)) {
        let a = unsafe { larr.get_str_unchecked(loff + row) };
        let b = unsafe { rarr.get_str_unchecked(roff + row) };
        let needed = match op {
            ArithmeticOperator::Add => a.len() + b.len(),
            ArithmeticOperator::Subtract => a.len(),
            ArithmeticOperator::Multiply => a.len() * b.len().min(STRING_MULTIPLICATION_LIMIT),
            ArithmeticOperator::Divide if b.is_empty() => a.len(),
            ArithmeticOperator::Divide => a.len() + a.matches(b).count().saturating_sub(1),
            _ => {
                return Err(KernelError::UnsupportedType(format!(
                    "Unsupported {:?}",
                    op
                )));
            }
        };
        total_bytes += needed;
    }

    // Allocation
    let mut offsets = Vec64::<T>::with_capacity(llen + 1);
    let mut data = Vec64::<u8>::with_capacity(total_bytes);
    offsets.push(T::zero());

    // Build pass
    for row in 0..llen {
        if both_valid(row) {
            let a_bytes = unsafe { larr.get_str_unchecked(loff + row) }.as_bytes();
            let b_bytes = unsafe { rarr.get_str_unchecked(roff + row) }.as_bytes();

            match op {
                ArithmeticOperator::Add => {
                    data.extend_from_slice(a_bytes);
                    data.extend_from_slice(b_bytes);
                }
                ArithmeticOperator::Subtract => {
                    let hit = if b_bytes.is_empty() {
                        None
                    } else {
                        memchr::memmem::Finder::new(b_bytes).find(a_bytes)
                    };
                    match hit {
                        Some(at) => {
                            data.extend_from_slice(&a_bytes[..at]);
                            data.extend_from_slice(&a_bytes[at + b_bytes.len()..]);
                        }
                        None => data.extend_from_slice(a_bytes),
                    }
                }
                ArithmeticOperator::Multiply => {
                    let repeats = b_bytes.len().min(STRING_MULTIPLICATION_LIMIT);
                    for _ in 0..repeats {
                        data.extend_from_slice(a_bytes);
                    }
                }
                ArithmeticOperator::Divide if b_bytes.is_empty() => {
                    data.extend_from_slice(a_bytes);
                }
                ArithmeticOperator::Divide => {
                    let finder = memchr::memmem::Finder::new(b_bytes);
                    let mut consumed = 0usize;
                    let mut emitted_any = false;
                    loop {
                        let hit = finder.find(&a_bytes[consumed..]);
                        // Every piece after the first is preceded by the separator.
                        if emitted_any {
                            data.push(b'|');
                        }
                        match hit {
                            Some(rel) => {
                                data.extend_from_slice(&a_bytes[consumed..consumed + rel]);
                                consumed += rel + b_bytes.len();
                                emitted_any = true;
                            }
                            None => {
                                data.extend_from_slice(&a_bytes[consumed..]);
                                break;
                            }
                        }
                    }
                }
                _ => unreachable!(),
            }
            unsafe { out_mask.set_unchecked(row, true) };
        }
        offsets.push(T::from_usize(data.len()));
    }

    Ok(StringArray {
        offsets: offsets.into(),
        data: data.into(),
        null_mask: Some(out_mask),
    })
}
765
/// Applies element-wise binary arithmetic ops between a `CategoricalArray<u32>` and a `StringArray<T>`.
///
/// Both inputs are `(array, offset, len)` windows and must have the same `len`.
///
/// The larger of the two sampled cardinality estimates selects the strategy:
/// - above `CARDINALITY_THRESHOLD`, the categorical side is materialised to strings, the work
///   is delegated to `apply_str_str`, and the result is re-categorised;
/// - otherwise results are interned row by row into a dictionary seeded with the
///   left-hand side's unique values.
///
/// Operator behaviour on the interned path:
/// - `Add`: concatenates `lhs` and `rhs`.
/// - `Subtract`: removes the first occurrence of `rhs` from `lhs` (unchanged when `rhs` is
///   empty or absent).
/// - `Multiply`: keeps `lhs` unchanged.
/// - `Divide`: splits `lhs` on `rhs` and emits one output row per segment, so the output can
///   hold more rows than the input; an empty `rhs` keeps `lhs` as a single row.
///
/// Null rows are emitted as the code of the empty string with a cleared validity bit.
///
/// # Errors
/// - `KernelError::LengthMismatch` when the two windows differ in length.
/// - `KernelError::UnsupportedType` when a valid row meets any other operator on the
///   interned path.
#[cfg(feature = "str_arithmetic")]
pub fn apply_dict32_str<T>(
    lhs: CategoricalAVT<u32>,
    rhs: StringAVT<T>,
    op: ArithmeticOperator,
) -> Result<CategoricalArray<u32>, KernelError>
where
    T: Integer,
{
    const SAMPLE_SIZE: usize = 256;
    const CARDINALITY_THRESHOLD: f64 = 0.75;

    // Destructure slice for local scope
    let (larr, loff, llen) = lhs;
    let (rarr, roff, rlen) = rhs;

    if llen != rlen {
        return Err(KernelError::LengthMismatch(log_length_mismatch(
            "apply_dict32_str".to_string(),
            llen,
            rlen,
        )));
    }

    // --- Estimate string diversity, pick path ---
    let cat_ratio = estimate_categorical_cardinality(larr, SAMPLE_SIZE);
    let str_ratio = estimate_string_cardinality(rarr, SAMPLE_SIZE);
    let max_ratio = cat_ratio.max(str_ratio);

    if max_ratio > CARDINALITY_THRESHOLD {
        // High cardinality: materialise, do flat string ops, then re-categorise.
        let lhs_str = larr.to_string_array();
        let str_result = apply_str_str((&lhs_str, loff, llen), (rarr, roff, rlen), op)?;
        return Ok(str_result.to_categorical_array());
    }

    // --- Low-cardinality: interned path ---
    // NOTE(review): the merged mask is indexed by the local row below; confirm that
    // `merge_bitmasks_to_new` accounts for `loff` / `roff` when the windows are offset.
    let out_mask = merge_bitmasks_to_new(larr.null_mask.as_ref(), rarr.null_mask.as_ref(), llen);

    // First pass: pre-count total number of output rows (only Divide can emit more than one
    // row per input row).
    let mut total_out = 0usize;
    for local_idx in 0..llen {
        let valid = out_mask
            .as_ref()
            .map_or(true, |m| unsafe { m.get_unchecked(local_idx) });
        if !valid {
            total_out += 1;
        } else if let ArithmeticOperator::Divide = op {
            let i = loff + local_idx;
            let j = roff + local_idx;
            let l_val = unsafe { larr.get_str_unchecked(i) };
            let r_val = unsafe { rarr.get_str_unchecked(j) };
            if r_val.is_empty() {
                total_out += 1;
            } else {
                let mut start = 0;
                while let Some(pos) = l_val[start..].find(r_val) {
                    total_out += 1;
                    start += pos + r_val.len();
                }
                total_out += 1; // final segment
            }
        } else {
            total_out += 1;
        }
    }

    // Pre-allocate output buffers. Codes are appended strictly in row order, so the buffer is
    // grown with `push` rather than pre-sized with `set_len`; this keeps every slot initialised
    // and keeps `out_data.len()` equal to the number of rows written.
    let mut out_data = Vec64::<u32>::with_capacity(total_out);
    let mut out_null = Bitmask::new_set_all(total_out, false);

    // Prepare dictionary and unique values (for this slice)
    let mut uniq: Vec64<String> = Vec64::with_capacity(larr.unique_values.len() + llen);
    uniq.extend(larr.unique_values.iter().cloned());

    #[cfg(feature = "fast_hash")]
    let mut dict: AHashMap<String, u32> = AHashMap::with_capacity(uniq.len());

    #[cfg(not(feature = "fast_hash"))]
    let mut dict: HashMap<String, u32> = HashMap::with_capacity(uniq.len());

    for (i, s) in uniq.iter().enumerate() {
        dict.insert(s.clone(), i as u32);
    }
    // Ensure "" is interned: null rows point at it.
    let empty_code = *dict.entry("".to_string()).or_insert_with(|| {
        let idx = uniq.len() as u32;
        uniq.push(String::new());
        idx
    });

    // Second pass: fill output buffers for the slice
    let mut write_ptr = 0usize;
    for local_idx in 0..llen {
        let valid = out_mask
            .as_ref()
            .map_or(true, |m| unsafe { m.get_unchecked(local_idx) });
        if !valid {
            out_data.push(empty_code);
            out_null.set(write_ptr, false);
            write_ptr += 1;
            continue;
        }
        let i = loff + local_idx;
        let j = roff + local_idx;
        let l_val = unsafe { larr.get_str_unchecked(i) };
        let r_val = unsafe { rarr.get_str_unchecked(j) };
        match op {
            ArithmeticOperator::Add => {
                let mut s = String::with_capacity(l_val.len() + r_val.len());
                s.push_str(l_val);
                s.push_str(r_val);
                let code = intern(&s, &mut dict, &mut uniq);
                out_data.push(code);
                out_null.set(write_ptr, true);
                write_ptr += 1;
            }
            ArithmeticOperator::Subtract => {
                let result = if r_val.is_empty() {
                    l_val.to_string()
                } else if let Some(pos) = l_val.find(r_val) {
                    // Drop only the first occurrence of the right-hand value.
                    let mut s = l_val[..pos].to_owned();
                    s.push_str(&l_val[pos + r_val.len()..]);
                    s
                } else {
                    l_val.to_string()
                };
                let code = intern(&result, &mut dict, &mut uniq);
                out_data.push(code);
                out_null.set(write_ptr, true);
                write_ptr += 1;
            }
            ArithmeticOperator::Multiply => {
                let code = intern(l_val, &mut dict, &mut uniq);
                out_data.push(code);
                out_null.set(write_ptr, true);
                write_ptr += 1;
            }
            ArithmeticOperator::Divide => {
                if r_val.is_empty() {
                    let code = intern(l_val, &mut dict, &mut uniq);
                    out_data.push(code);
                    out_null.set(write_ptr, true);
                    write_ptr += 1;
                } else {
                    // One output row per segment; mirrors the counting loop of the first pass.
                    let mut start = 0;
                    loop {
                        match l_val[start..].find(r_val) {
                            Some(pos) => {
                                let part = &l_val[start..start + pos];
                                let code = intern(part, &mut dict, &mut uniq);
                                out_data.push(code);
                                out_null.set(write_ptr, true);
                                write_ptr += 1;
                                start += pos + r_val.len();
                            }
                            None => {
                                let tail = &l_val[start..];
                                let code = intern(tail, &mut dict, &mut uniq);
                                out_data.push(code);
                                out_null.set(write_ptr, true);
                                write_ptr += 1;
                                break;
                            }
                        }
                    }
                }
            }
            _ => {
                return Err(KernelError::UnsupportedType(
                    "Unsupported Type Error.".to_string(),
                ));
            }
        }
    }

    debug_assert_eq!(write_ptr, total_out);

    Ok(CategoricalArray {
        data: out_data.into(),
        unique_values: uniq,
        null_mask: Some(out_null),
    })
}
954
/// Applies element-wise binary arithmetic ops between `StringArray<T>` and `CategoricalArray<u32>`.
///
/// Both inputs are `(array, offset, len)` windows and must have the same `len`.
///
/// Operator behaviour:
/// - `Add`: concatenates `lhs` and `rhs`.
/// - `Subtract`: removes the first occurrence of `rhs` from `lhs` (unchanged when `rhs` is
///   empty or absent).
/// - `Multiply`: copies `lhs` unchanged.
/// - `Divide`: splits `lhs` on `rhs` and emits one output row per segment, so the output can
///   hold more rows than the input. An empty `rhs` keeps `lhs` as a single row, matching
///   `apply_dict32_str`.
///
/// Null rows are emitted as zero-length entries. When the merged input mask exists, the
/// returned mask is rebuilt with one bit per *output* row so that it stays aligned with the
/// offsets even when `Divide` expands rows; when no input mask exists, `None` is returned.
///
/// # Errors
/// - `KernelError::LengthMismatch` when the two windows differ in length.
/// - `KernelError::UnsupportedType` when a valid row meets any other operator.
#[cfg(feature = "str_arithmetic")]
pub fn apply_str_dict32<T>(
    lhs: StringAVT<T>,
    rhs: CategoricalAVT<u32>,
    op: ArithmeticOperator,
) -> Result<StringArray<T>, KernelError>
where
    T: Integer,
{
    // Destructure input slices
    let (larr, loff, llen) = lhs;
    let (rarr, roff, rlen) = rhs;

    if llen != rlen {
        return Err(KernelError::LengthMismatch(log_length_mismatch(
            "apply_str_dict32".to_string(),
            llen,
            rlen,
        )));
    }

    // NOTE(review): the merged mask is indexed by the local row below; confirm that
    // `merge_bitmasks_to_new` accounts for `loff` / `roff` when the windows are offset.
    let out_mask = merge_bitmasks_to_new(larr.null_mask.as_ref(), rarr.null_mask.as_ref(), llen);

    // --- Pre-count output rows and an upper bound on the byte size ---
    let mut total_rows = 0usize;
    let mut total_bytes = 0usize;

    for local_idx in 0..llen {
        let valid = out_mask
            .as_ref()
            .map_or(true, |m| unsafe { m.get_unchecked(local_idx) });
        if !valid {
            total_rows += 1;
            continue;
        }

        let i = loff + local_idx;
        let j = roff + local_idx;

        let l = unsafe { larr.get_str_unchecked(i) };
        let r = unsafe { rarr.get_str_unchecked(j) };

        match op {
            ArithmeticOperator::Divide => {
                // An empty divisor leaves the value whole instead of going through
                // `str::split("")`, which would yield per-character pieces plus empty ends.
                total_rows += if r.is_empty() { 1 } else { l.split(r).count() };
                total_bytes += l.len(); // upper bound: separators are dropped
            }
            ArithmeticOperator::Add => {
                total_rows += 1;
                total_bytes += l.len() + r.len();
            }
            ArithmeticOperator::Subtract => {
                total_rows += 1;
                total_bytes += l.len(); // upper bound: a match shrinks the row
            }
            ArithmeticOperator::Multiply => {
                total_rows += 1;
                total_bytes += l.len();
            }
            _ => {
                return Err(KernelError::UnsupportedType(
                    "Unsupported Type Error.".to_string(),
                ));
            }
        }
    }

    // Allocate output buffers for local window. Offsets are appended in row order, so
    // `offsets.len() - 1` is always the index of the row about to be written and
    // `data.len()` is always the running byte cursor.
    let mut offsets = Vec64::<T>::with_capacity(total_rows + 1);
    let mut data = Vec64::<u8>::with_capacity(total_bytes);
    offsets.push(T::zero());

    // Validity aligned to output rows: starts all-valid, null rows are cleared as written.
    let mut row_mask = out_mask
        .as_ref()
        .map(|_| Bitmask::new_set_all(total_rows, true));

    for local_idx in 0..llen {
        let valid = out_mask
            .as_ref()
            .map_or(true, |m| unsafe { m.get_unchecked(local_idx) });
        if !valid {
            if let Some(mask) = row_mask.as_mut() {
                mask.set(offsets.len() - 1, false);
            }
            offsets.push(T::from_usize(data.len()));
            continue;
        }

        let i = loff + local_idx;
        let j = roff + local_idx;

        let l = unsafe { larr.get_str_unchecked(i) };
        let r = unsafe { rarr.get_str_unchecked(j) };

        match op {
            ArithmeticOperator::Divide => {
                if r.is_empty() {
                    data.extend_from_slice(l.as_bytes());
                    offsets.push(T::from_usize(data.len()));
                } else {
                    for part in l.split(r) {
                        data.extend_from_slice(part.as_bytes());
                        offsets.push(T::from_usize(data.len()));
                    }
                }
            }
            ArithmeticOperator::Add => {
                data.extend_from_slice(l.as_bytes());
                data.extend_from_slice(r.as_bytes());
                offsets.push(T::from_usize(data.len()));
            }
            ArithmeticOperator::Subtract => {
                if r.is_empty() {
                    data.extend_from_slice(l.as_bytes());
                } else if let Some(pos) = l.find(r) {
                    // Drop only the first occurrence of the right-hand value.
                    data.extend_from_slice(&l.as_bytes()[..pos]);
                    data.extend_from_slice(&l.as_bytes()[pos + r.len()..]);
                } else {
                    data.extend_from_slice(l.as_bytes());
                }
                offsets.push(T::from_usize(data.len()));
            }
            ArithmeticOperator::Multiply => {
                data.extend_from_slice(l.as_bytes());
                offsets.push(T::from_usize(data.len()));
            }
            // The counting pass already returned an error for any other operator on a valid row.
            _ => unreachable!(),
        }
    }

    debug_assert_eq!(offsets.len(), total_rows + 1);

    Ok(StringArray {
        offsets: offsets.into(),
        data: data.into(),
        null_mask: row_mask,
    })
}
1100
/// Applies an element-wise binary arithmetic op between a `CategoricalArray<u32>` window
/// and a numeric slice.
///
/// `lhs` is an `(array, offset, len)` window and `rhs` must hold exactly `len` values.
///
/// - `Multiply`: repeats the category string `rhs[i]` times. Values that do not convert to
///   `usize` count as zero, and the repeat count is capped at 1,000,000.
/// - Any other operator: the category string is carried over unchanged.
///
/// The output dictionary is rebuilt from the produced strings in first-seen order. Null rows
/// (per the left-hand mask, read at `offset + i`) get code `0` and a cleared validity bit;
/// the output has a mask only when the input array has one.
///
/// # Errors
/// `KernelError::LengthMismatch` when `len` and `rhs.len()` differ.
#[cfg(feature = "str_arithmetic")]
pub fn apply_dict32_num<T>(
    lhs: CategoricalAVT<u32>,
    rhs: &[T],
    op: ArithmeticOperator,
) -> Result<CategoricalArray<u32>, KernelError>
where
    T: ToPrimitive + Copy,
{
    let (larr, loff, llen) = lhs;

    if llen != rhs.len() {
        return Err(KernelError::LengthMismatch(log_length_mismatch(
            "apply_dict32_num".to_string(),
            llen,
            rhs.len(),
        )));
    }

    // Output validity starts all-valid; null rows are cleared as they are written.
    let mut out_mask = larr
        .null_mask
        .as_ref()
        .map(|_| Bitmask::new_set_all(llen, true));

    // Codes are appended in row order, so no pre-sizing of uninitialised memory is needed.
    let mut data = Vec64::<u32>::with_capacity(llen);
    let mut unique_values = Vec64::<String>::with_capacity(llen);

    // Use the file-level map types (same selection as `apply_dict32_str`) so the function
    // builds with and without the `fast_hash` feature.
    #[cfg(feature = "fast_hash")]
    let mut seen: AHashMap<String, u32> = AHashMap::with_capacity(llen);

    #[cfg(not(feature = "fast_hash"))]
    let mut seen: HashMap<String, u32> = HashMap::with_capacity(llen);

    for local_idx in 0..llen {
        let i = loff + local_idx;
        let valid = larr
            .null_mask
            .as_ref()
            .map_or(true, |m| unsafe { m.get_unchecked(i) });

        if !valid {
            data.push(0);
            if let Some(mask) = out_mask.as_mut() {
                mask.set(local_idx, false);
            }
            continue;
        }

        let l_val = unsafe { larr.get_str_unchecked(i) };

        let cat = match op {
            ArithmeticOperator::Multiply => {
                let n = rhs[local_idx].to_usize().unwrap_or(0);
                // NOTE(review): hard-coded cap; the file imports `STRING_MULTIPLICATION_LIMIT`
                // from `crate::config` — confirm whether that constant should be used here.
                let count = n.min(1_000_000);
                l_val.repeat(count)
            }
            _ => l_val.to_owned(),
        };

        // Intern the produced string; new entries take the next dictionary slot.
        let idx = if let Some(&ix) = seen.get(&cat) {
            ix
        } else {
            let ix = unique_values.len() as u32;
            seen.insert(cat.clone(), ix);
            unique_values.push(cat);
            ix
        };

        data.push(idx);
    }

    Ok(CategoricalArray {
        data: data.into(),
        unique_values,
        null_mask: out_mask,
    })
}
1198
1199#[cfg(test)]
1200mod tests {
1201    use minarrow::MaskedArray;
1202    use minarrow::structs::variants::string::StringArray;
1203    #[cfg(feature = "str_arithmetic")]
1204    use minarrow::{Bitmask, CategoricalArray};
1205
1206    use super::*;
1207    use crate::operators::ArithmeticOperator;
1208    use minarrow::vec64;
1209
1210    // Helpers
1211
1212    /// Assert that a `StringArray<T>` matches the supplied `Vec<&str>` and nullity.
1213    fn assert_str<T>(arr: &StringArray<T>, expect: &[&str], valid: Option<&[bool]>)
1214    where
1215        T: minarrow::traits::type_unions::Integer + std::fmt::Debug,
1216    {
1217        assert_eq!(arr.len(), expect.len());
1218        for (i, exp) in expect.iter().enumerate() {
1219            assert_eq!(unsafe { arr.get_str_unchecked(i) }, *exp);
1220        }
1221        match (valid, &arr.null_mask) {
1222            (None, None) => {}
1223            (Some(expected), Some(mask)) => {
1224                for (i, bit) in expected.iter().enumerate() {
1225                    assert_eq!(unsafe { mask.get_unchecked(i) }, *bit);
1226                }
1227            }
1228            (None, Some(mask)) => {
1229                assert!(mask.all_true());
1230            }
1231            (Some(_), None) => panic!("expected mask missing"),
1232        }
1233    }
1234
1235
1236    // String - Numeric Kernels
1237
1238    #[test]
1239    fn str_num_multiply() {
1240        let input = StringArray::<u32>::from_slice(&["hi", "bye", "x"]);
1241        let nums: &[i32] = &[3, 2, 0];
1242        let input_slice = (&input, 0, input.len(), input.data.len());
1243        let out =
1244            super::apply_str_num::<u32, i32, u32>(input_slice, nums, ArithmeticOperator::Multiply)
1245                .unwrap();
1246        assert_str(&out, &["hihihi", "byebye", ""], None);
1247    }
1248
1249    #[test]
1250    fn str_num_multiply_chunk() {
1251        let base = StringArray::<u32>::from_slice(&["pad", "hi", "bye", "x", "pad2"]);
1252        let nums: &[i32] = &[3, 2, 0];
1253        // Window: ["hi", "bye", "x"]
1254        let input_slice = (&base, 1, 3, base.data.len());
1255        let out =
1256            super::apply_str_num::<u32, i32, u32>(input_slice, nums, ArithmeticOperator::Multiply)
1257                .unwrap();
1258        assert_str(&out, &["hihihi", "byebye", ""], None);
1259    }
1260
1261    #[test]
1262    fn str_num_len_mismatch() {
1263        let input = StringArray::<u32>::from_slice(&["a"]);
1264        let nums: &[i32] = &[1, 2];
1265        let input_slice = (&input, 0, input.len(), input.data.len());
1266        let err = super::apply_str_num::<u32, i32, u32>(input_slice, nums, ArithmeticOperator::Add)
1267            .unwrap_err();
1268        match err {
1269            KernelError::LengthMismatch(_) => {}
1270            _ => panic!("wrong error variant"),
1271        }
1272    }
1273
1274    #[test]
1275    fn str_num_len_mismatch_chunk() {
1276        let base = StringArray::<u32>::from_slice(&["pad", "a", "pad2"]);
1277        let nums: &[i32] = &[1, 2];
1278        // Window: ["a"]
1279        let input_slice = (&base, 1, 1, base.data.len());
1280        let err = super::apply_str_num::<u32, i32, u32>(input_slice, nums, ArithmeticOperator::Add)
1281            .unwrap_err();
1282        match err {
1283            KernelError::LengthMismatch(_) => {}
1284            _ => panic!("wrong error variant"),
1285        }
1286    }
1287
1288    #[cfg(feature = "str_arithmetic")]
1289    #[test]
1290    fn str_float_all_ops() {
1291        let input = StringArray::<u32>::from_slice(&["foo", "bar1", "baz"]);
1292        let nums: &[f64] = &[1.0, 1.0, 2.0];
1293        let input_slice = (&input, 0, input.len());
1294        // Add
1295        let add = super::apply_str_float(input_slice, nums, ArithmeticOperator::Add).unwrap();
1296        assert_str(&add, &["foo1", "bar11", "baz2"], None);
1297        // Subtract
1298        let sub = super::apply_str_float(input_slice, nums, ArithmeticOperator::Subtract).unwrap();
1299        assert_str(&sub, &["foo", "bar", "baz"], None);
1300        // Multiply
1301        let mul = super::apply_str_float(input_slice, nums, ArithmeticOperator::Multiply).unwrap();
1302        assert_str(&mul, &["foo", "bar1", "bazbaz"], None);
1303        // Divide
1304        let div = super::apply_str_float(input_slice, nums, ArithmeticOperator::Divide).unwrap();
1305        assert_str(&div, &["foo", "bar|", "baz"], None);
1306    }
1307
1308    #[cfg(feature = "str_arithmetic")]
1309    #[test]
1310    fn str_float_all_ops_chunk() {
1311        let base = StringArray::<u32>::from_slice(&["pad", "foo", "bar1", "baz", "pad2"]);
1312        let nums: &[f64] = &[1.0, 1.0, 2.0];
1313        // Window: ["foo", "bar1", "baz"]
1314        let input_slice = (&base, 1, 3);
1315        // Add
1316        let add = super::apply_str_float(input_slice, nums, ArithmeticOperator::Add).unwrap();
1317        assert_str(&add, &["foo1", "bar11", "baz2"], None);
1318        // Subtract
1319        let sub = super::apply_str_float(input_slice, nums, ArithmeticOperator::Subtract).unwrap();
1320        assert_str(&sub, &["foo", "bar", "baz"], None);
1321        // Multiply
1322        let mul = super::apply_str_float(input_slice, nums, ArithmeticOperator::Multiply).unwrap();
1323        assert_str(&mul, &["foo", "bar1", "bazbaz"], None);
1324        // Divide
1325        let div = super::apply_str_float(input_slice, nums, ArithmeticOperator::Divide).unwrap();
1326        assert_str(&div, &["foo", "bar|", "baz"], None);
1327    }
1328
1329
1330    // Dictionary Kernels
1331
1332    #[cfg(feature = "str_arithmetic")]
1333    fn cat(values: &[&str]) -> CategoricalArray<u32> {
1334        CategoricalArray::<u32>::from_values(values.iter().copied())
1335    }
1336
1337    #[cfg(feature = "str_arithmetic")]
1338    #[test]
1339    fn dict32_dict32_add() {
1340
1341        let lhs = cat(&["A", "B", ""]);
1342        let rhs = cat(&["1", "2", "3"]);
1343        let lhs_slice = (&lhs, 0, lhs.data.len());
1344        let rhs_slice = (&rhs, 0, rhs.data.len());
1345        let out =
1346            super::apply_dict32_dict32(lhs_slice, rhs_slice, ArithmeticOperator::Add).unwrap();
1347        let expected = vec64!["A1", "B2", "3"];
1348        for (i, exp) in expected.iter().enumerate() {
1349            assert_eq!(out.get(i).unwrap_or(""), *exp);
1350        }
1351    }
1352
1353    #[cfg(feature = "str_arithmetic")]
1354    #[test]
1355    fn dict32_dict32_add_chunk() {
1356        let lhs = cat(&["pad", "A", "B", "", "pad2"]);
1357        let rhs = cat(&["padx", "1", "2", "3", "pady"]);
1358        let lhs_slice = (&lhs, 1, 3); // "A", "B", ""
1359        let rhs_slice = (&rhs, 1, 3); // "1", "2", "3"
1360        let out =
1361            super::apply_dict32_dict32(lhs_slice, rhs_slice, ArithmeticOperator::Add).unwrap();
1362        let expected = vec64!["A1", "B2", "3"];
1363        for (i, exp) in expected.iter().enumerate() {
1364            assert_eq!(out.get(i).unwrap_or(""), *exp);
1365        }
1366    }
1367
1368    #[cfg(feature = "str_arithmetic")]
1369    #[test]
1370    fn dict32_str_subtract() {
1371        let lhs = cat(&["hello", "yellow"]);
1372        let rhs = StringArray::<u32>::from_slice(&["l", "el"]);
1373        let lhs_slice = (&lhs, 0, lhs.data.len());
1374        let rhs_slice = (&rhs, 0, rhs.len());
1375        let out =
1376            super::apply_dict32_str(lhs_slice, rhs_slice, ArithmeticOperator::Subtract).unwrap();
1377        assert_eq!(out.get(0).unwrap(), "helo");
1378        assert_eq!(out.get(1).unwrap(), "ylow");
1379    }
1380
1381    #[cfg(feature = "str_arithmetic")]
1382    #[test]
1383    fn dict32_str_subtract_chunk() {
1384        let lhs = cat(&["pad", "hello", "yellow", "pad2"]);
1385        let rhs = StringArray::<u32>::from_slice(&["pad", "l", "el", "pad2"]);
1386        let lhs_slice = (&lhs, 1, 2); // "hello", "yellow"
1387        let rhs_slice = (&rhs, 1, 2); // "l", "el"
1388        let out =
1389            super::apply_dict32_str(lhs_slice, rhs_slice, ArithmeticOperator::Subtract).unwrap();
1390        assert_eq!(out.get(0).unwrap(), "helo");
1391        assert_eq!(out.get(1).unwrap(), "ylow");
1392    }
1393
1394    #[cfg(feature = "str_arithmetic")]
1395    #[test]
1396    fn str_dict32_divide() {
1397        let lhs = StringArray::<u32>::from_slice(&["a:b:c"]);
1398        let rhs = cat(&[":"]);
1399        let lhs_slice = (&lhs, 0, lhs.len());
1400        let rhs_slice = (&rhs, 0, rhs.data.len());
1401        let out =
1402            super::apply_str_dict32(lhs_slice, rhs_slice, ArithmeticOperator::Divide).unwrap();
1403        assert_str(&out, &["a", "b", "c"], None);
1404    }
1405
1406    #[cfg(feature = "str_arithmetic")]
1407    #[test]
1408    fn str_dict32_divide_chunk() {
1409        // Extended arrays for windowing
1410        let lhs = StringArray::<u32>::from_slice(&["pad", "a:b:c", "pad2"]);
1411        let rhs = cat(&["pad", ":", "pad2"]);
1412        let lhs_slice = (&lhs, 1, 1); // only "a:b:c"
1413        let rhs_slice = (&rhs, 1, 1); // only ":"
1414        let out =
1415            super::apply_str_dict32(lhs_slice, rhs_slice, ArithmeticOperator::Divide).unwrap();
1416        assert_str(&out, &["a", "b", "c"], None);
1417    }
1418
1419    #[cfg(feature = "str_arithmetic")]
1420    #[test]
1421    fn dict32_num_multiply() {
1422        let lhs = cat(&["x", "y"]);
1423        let nums: &[u32] = &[3, 1];
1424        let lhs_slice = (&lhs, 0, lhs.data.len());
1425        let nums_window = &nums[0..lhs.data.len()];
1426        let out =
1427            super::apply_dict32_num(lhs_slice, nums_window, ArithmeticOperator::Multiply).unwrap();
1428        assert_eq!(out.get(0).unwrap(), "xxx");
1429        assert_eq!(out.get(1).unwrap(), "y");
1430    }
1431
1432    #[cfg(feature = "str_arithmetic")]
1433    #[test]
1434    fn dict32_num_multiply_chunk() {
1435        let lhs = cat(&["pad", "x", "y", "pad2"]);
1436        let nums: &[u32] = &[0, 3, 1, 0];
1437        let lhs_slice = (&lhs, 1, 2); // only "x", "y"
1438        let nums_window = &nums[1..3];
1439        let out =
1440            super::apply_dict32_num(lhs_slice, nums_window, ArithmeticOperator::Multiply).unwrap();
1441        assert_eq!(out.get(0).unwrap(), "xxx");
1442        assert_eq!(out.get(1).unwrap(), "y");
1443    }
1444
1445    #[cfg(feature = "str_arithmetic")]
1446    fn cat32_str_arr(strings: &[&str]) -> (CategoricalArray<u32>, StringArray<u32>) {
1447        let str_arr = StringArray::from_vec(strings.to_vec(), None);
1448        let cat_arr = str_arr.to_categorical_array();
1449        (cat_arr, str_arr)
1450    }
1451
1452    #[cfg(feature = "str_arithmetic")]
1453    #[test]
1454    fn test_apply_dict32_str_add_and_divide() {
1455        let (lhs_cat, rhs_str) = cat32_str_arr(&["foo", "bar|baz", ""]);
1456        // Add: Use slices
1457        let lhs_cat_slice = (&lhs_cat, 0, lhs_cat.data.len());
1458        let rhs_str_slice = (&rhs_str, 0, rhs_str.len());
1459        let added =
1460            apply_dict32_str(lhs_cat_slice, rhs_str_slice, ArithmeticOperator::Add).unwrap();
1461        let expected_cat = apply_str_str(
1462            (&lhs_cat.to_string_array(), 0, lhs_cat.len()),
1463            rhs_str_slice,
1464            ArithmeticOperator::Add,
1465        )
1466        .unwrap()
1467        .to_categorical_array();
1468        assert_eq!(added.unique_values, expected_cat.unique_values);
1469        assert_eq!(added.data, expected_cat.data);
1470
1471        // Divide: Use slices
1472        let divided =
1473            apply_dict32_str(lhs_cat_slice, rhs_str_slice, ArithmeticOperator::Divide).unwrap();
1474        let expected_div = apply_str_str(
1475            (&lhs_cat.to_string_array(), 0, lhs_cat.len()),
1476            rhs_str_slice,
1477            ArithmeticOperator::Divide,
1478        )
1479        .unwrap()
1480        .to_categorical_array();
1481        assert_eq!(divided.unique_values, expected_div.unique_values);
1482        assert_eq!(divided.data, expected_div.data);
1483    }
1484
1485    #[cfg(feature = "str_arithmetic")]
1486    #[test]
1487    fn test_apply_dict32_str_add_and_divide_chunk() {
1488        let (lhs_cat, rhs_str) = cat32_str_arr(&["pad", "foo", "bar|baz", "", "pad2"]);
1489        let lhs_cat_slice = (&lhs_cat, 1, 3); // only "foo", "bar|baz", ""
1490        let rhs_str_slice = (&rhs_str, 1, 3);
1491
1492        // Add
1493        let added =
1494            apply_dict32_str(lhs_cat_slice, rhs_str_slice, ArithmeticOperator::Add).unwrap();
1495        let expected_cat = apply_str_str(
1496            (&lhs_cat.to_string_array(), 1, 3),
1497            rhs_str_slice,
1498            ArithmeticOperator::Add,
1499        )
1500        .unwrap()
1501        .to_categorical_array();
1502        assert_eq!(added.unique_values, expected_cat.unique_values);
1503        assert_eq!(added.data, expected_cat.data);
1504
1505        // Divide
1506        let divided =
1507            apply_dict32_str(lhs_cat_slice, rhs_str_slice, ArithmeticOperator::Divide).unwrap();
1508        let expected_div = apply_str_str(
1509            (&lhs_cat.to_string_array(), 1, 3),
1510            rhs_str_slice,
1511            ArithmeticOperator::Divide,
1512        )
1513        .unwrap()
1514        .to_categorical_array();
1515        assert_eq!(divided.unique_values, expected_div.unique_values);
1516        assert_eq!(divided.data, expected_div.data);
1517    }
1518
1519
1520    // String arithmetic
1521
1522    #[cfg(feature = "str_arithmetic")]
1523    fn string_array<T: Integer>(data: &[&str], nulls: Option<&[bool]>) -> StringArray<T> {
1524        let array = StringArray::from_vec(data.to_vec(), nulls.map(Bitmask::from_bools));
1525        assert_eq!(array.len(), data.len());
1526        array
1527    }
1528
1529    #[cfg(feature = "str_arithmetic")]
1530    #[test]
1531    fn test_add_str() {
1532        let lhs = string_array::<u32>(&["a", "b", "c"], None);
1533        let rhs = string_array::<u32>(&["x", "y", "z"], None);
1534        let lhs_slice = (&lhs, 0, lhs.len());
1535        let rhs_slice = (&rhs, 0, rhs.len());
1536        let result = apply_str_str(lhs_slice, rhs_slice, ArithmeticOperator::Add).unwrap();
1537
1538        assert_eq!(result.get(0), Some("ax"));
1539        assert_eq!(result.get(1), Some("by"));
1540        assert_eq!(result.get(2), Some("cz"));
1541    }
1542
1543    #[cfg(feature = "str_arithmetic")]
1544    #[test]
1545    fn test_add_str_chunk() {
1546        let lhs = string_array::<u32>(&["pad", "a", "b", "c", "pad2"], None);
1547        let rhs = string_array::<u32>(&["pad", "x", "y", "z", "pad2"], None);
1548        // window: indices 1..4
1549        let lhs_slice = (&lhs, 1, 3);
1550        let rhs_slice = (&rhs, 1, 3);
1551        let result = apply_str_str(lhs_slice, rhs_slice, ArithmeticOperator::Add).unwrap();
1552
1553        assert_eq!(result.get(0), Some("ax"));
1554        assert_eq!(result.get(1), Some("by"));
1555        assert_eq!(result.get(2), Some("cz"));
1556    }
1557
1558    #[cfg(feature = "str_arithmetic")]
1559    #[test]
1560    fn test_subtract_str() {
1561        let lhs = string_array::<u32>(&["hello", "goodbye", "test"], None);
1562        let rhs = string_array::<u32>(&["l", "bye", "xyz"], None);
1563        let lhs_slice = (&lhs, 0, lhs.len());
1564        let rhs_slice = (&rhs, 0, rhs.len());
1565        let result = apply_str_str(lhs_slice, rhs_slice, ArithmeticOperator::Subtract).unwrap();
1566
1567        assert_eq!(result.get(0), Some("helo"));
1568        assert_eq!(result.get(1), Some("good"));
1569        assert_eq!(result.get(2), Some("test")); // no match
1570    }
1571
1572    #[cfg(feature = "str_arithmetic")]
1573    #[test]
1574    fn test_subtract_str_chunk() {
1575        let lhs = string_array::<u32>(&["pad", "hello", "goodbye", "test", "pad2"], None);
1576        let rhs = string_array::<u32>(&["pad", "l", "bye", "xyz", "pad2"], None);
1577        // window: indices 1..4
1578        let lhs_slice = (&lhs, 1, 3);
1579        let rhs_slice = (&rhs, 1, 3);
1580        let result = apply_str_str(lhs_slice, rhs_slice, ArithmeticOperator::Subtract).unwrap();
1581
1582        assert_eq!(result.get(0), Some("helo"));
1583        assert_eq!(result.get(1), Some("good"));
1584        assert_eq!(result.get(2), Some("test")); // no match
1585    }
1586
1587    #[cfg(feature = "str_arithmetic")]
1588    #[test]
1589    fn test_multiply_str() {
1590        let lhs = string_array::<u32>(&["x", "ab", "c"], None);
1591        let rhs = string_array::<u32>(&["123", "12", "long_string"], None);
1592        let lhs_slice = (&lhs, 0, lhs.len());
1593        let rhs_slice = (&rhs, 0, rhs.len());
1594        let result = apply_str_str(lhs_slice, rhs_slice, ArithmeticOperator::Multiply).unwrap();
1595
1596        assert_eq!(result.get(0), Some("xxx"));
1597        assert_eq!(result.get(1), Some("abab"));
1598        assert_eq!(
1599            result.get(2),
1600            Some("c".repeat("long_string".len()).as_str())
1601        );
1602    }
1603
1604    #[cfg(feature = "str_arithmetic")]
1605    #[test]
1606    fn test_multiply_str_chunk() {
1607        let lhs = string_array::<u32>(&["pad", "x", "ab", "c", "pad2"], None);
1608        let rhs = string_array::<u32>(&["pad", "123", "12", "long_string", "pad2"], None);
1609        // window: indices 1..4
1610        let lhs_slice = (&lhs, 1, 3);
1611        let rhs_slice = (&rhs, 1, 3);
1612        let result = apply_str_str(lhs_slice, rhs_slice, ArithmeticOperator::Multiply).unwrap();
1613
1614        assert_eq!(result.get(0), Some("xxx"));
1615        assert_eq!(result.get(1), Some("abab"));
1616        assert_eq!(
1617            result.get(2),
1618            Some("c".repeat("long_string".len()).as_str())
1619        );
1620    }
1621
1622    #[cfg(feature = "str_arithmetic")]
1623    #[test]
1624    fn test_divide_str() {
1625        let lhs = string_array::<u32>(&["a,b,c", "a--b--c", "abc"], None);
1626        let rhs = string_array::<u32>(&[",", "--", ""], None);
1627        let lhs_slice = (&lhs, 0, lhs.len());
1628        let rhs_slice = (&rhs, 0, rhs.len());
1629        let result = apply_str_str(lhs_slice, rhs_slice, ArithmeticOperator::Divide).unwrap();
1630
1631        assert_eq!(result.get(0), Some("a|b|c"));
1632        assert_eq!(result.get(1), Some("a|b|c"));
1633        assert_eq!(result.get(2), Some("abc"));
1634    }
1635
1636    #[cfg(feature = "str_arithmetic")]
1637    #[test]
1638    fn test_divide_str_chunk() {
1639        let lhs = string_array::<u32>(&["xxx", "a,b,c", "a--b--c", "abc", "yyy"], None);
1640        let rhs = string_array::<u32>(&["", ",", "--", "", ""], None);
1641        // operate only on the window: indices 1,2,3
1642        let lhs_slice = (&lhs, 1, 3);
1643        let rhs_slice = (&rhs, 1, 3);
1644        let result = apply_str_str(lhs_slice, rhs_slice, ArithmeticOperator::Divide).unwrap();
1645
1646        assert_eq!(result.get(0), Some("a|b|c"));
1647        assert_eq!(result.get(1), Some("a|b|c"));
1648        assert_eq!(result.get(2), Some("abc"));
1649    }
1650
1651    #[cfg(feature = "str_arithmetic")]
1652    #[test]
1653    fn test_nulls_str() {
1654        let lhs = string_array::<u32>(&["a", "b", "c"], Some(&[true, false, true]));
1655        let rhs = string_array::<u32>(&["x", "y", "z"], Some(&[true, true, false]));
1656        let lhs_slice = (&lhs, 0, lhs.len());
1657        let rhs_slice = (&rhs, 0, rhs.len());
1658        let result = apply_str_str(lhs_slice, rhs_slice, ArithmeticOperator::Add).unwrap();
1659
1660        assert_eq!(result.get(0), Some("ax"));
1661        assert_eq!(result.get(1), None);
1662        assert_eq!(result.get(2), None);
1663    }
1664
1665    #[cfg(feature = "str_arithmetic")]
1666    #[test]
1667    fn test_nulls_str_chunk() {
1668        let lhs = string_array::<u32>(
1669            &["0", "a", "b", "c", "9"],
1670            Some(&[false, true, false, true, false]),
1671        );
1672        let rhs = string_array::<u32>(
1673            &["y", "x", "y", "z", "w"],
1674            Some(&[true, true, true, false, false]),
1675        );
1676        // window covering indices 1..4
1677        let lhs_slice = (&lhs, 1, 3);
1678        let rhs_slice = (&rhs, 1, 3);
1679        let result = apply_str_str(lhs_slice, rhs_slice, ArithmeticOperator::Add).unwrap();
1680
1681        assert_eq!(result.get(0), Some("ax"));
1682        assert_eq!(result.get(1), None);
1683        assert_eq!(result.get(2), None);
1684    }
1685
1686    #[cfg(feature = "str_arithmetic")]
1687    #[test]
1688    fn test_mismatched_length_str() {
1689        let lhs = string_array::<u32>(&["a", "b"], None);
1690        let rhs = string_array::<u32>(&["x"], None);
1691        let lhs_slice = (&lhs, 0, lhs.len());
1692        let rhs_slice = (&rhs, 0, rhs.len());
1693        let result = apply_str_str(lhs_slice, rhs_slice, ArithmeticOperator::Add);
1694        assert!(matches!(result, Err(KernelError::LengthMismatch(_))));
1695    }
1696
1697    #[cfg(feature = "str_arithmetic")]
1698    #[test]
1699    fn test_mismatched_length_str_chunk() {
1700        let lhs = string_array::<u32>(&["a", "b", "c"], None);
1701        let rhs = string_array::<u32>(&["x"], None);
1702        // windowed call: indices 1..3, lhs has 2 elements, rhs has 1, so mismatch
1703        let lhs_slice = (&lhs, 1, 2);
1704        let rhs_slice = (&rhs, 0, 1);
1705        let result = apply_str_str(lhs_slice, rhs_slice, ArithmeticOperator::Add);
1706        assert!(matches!(result, Err(KernelError::LengthMismatch(_))));
1707    }
1708}