1pub mod dispatch;
23#[cfg(feature = "simd")]
24pub mod simd;
25pub mod std;
26pub mod string;
27
28#[cfg(test)]
31mod tests {
32 use minarrow::structs::variants::float::FloatArray;
33 use minarrow::structs::variants::integer::IntegerArray;
34 use minarrow::{Bitmask, MaskedArray, vec64};
35
36 use crate::kernels::arithmetic::dispatch::{
37 apply_float_f32, apply_float_f64, apply_fma_f32, apply_fma_f64, apply_int_i32,
38 apply_int_i64, apply_int_u32, apply_int_u64,
39 };
40 #[cfg(feature = "extended_numeric_types")]
41 use crate::kernels::arithmetic::dispatch::{
42 apply_int_i8, apply_int_i16, apply_int_u8, apply_int_u16,
43 };
44 #[cfg(feature = "simd")]
45 use crate::kernels::arithmetic::simd::int_dense_body_simd;
46 use crate::operators::ArithmeticOperator;
47
48 fn assert_int<T>(arr: &IntegerArray<T>, values: &[T], valid: Option<&[bool]>)
49 where
50 T: num_traits::PrimInt + std::fmt::Debug,
51 {
52 assert_eq!(arr.data.as_slice(), values);
53 match (valid, &arr.null_mask) {
54 (None, None) => {}
55 (Some(expected), Some(mask)) => {
56 for (i, bit) in expected.iter().enumerate() {
57 assert_eq!(
58 unsafe { mask.get_unchecked(i) },
59 *bit,
60 "mask mismatch at {i}"
61 );
62 }
63 }
64 (None, Some(mask)) => {
65 assert!(mask.all_true(), "mask unexpectedly present");
66 }
67 (Some(_), None) => panic!("expected mask missing"),
68 }
69 }
70
71 fn assert_float<T>(arr: &FloatArray<T>, values: &[T], valid: Option<&[bool]>)
72 where
73 T: num_traits::Float + std::fmt::Debug,
74 {
75 assert_eq!(arr.data.as_slice(), values);
76 match (valid, &arr.null_mask) {
77 (None, None) => {}
78 (Some(expected), Some(mask)) => {
79 for (i, bit) in expected.iter().enumerate() {
80 assert_eq!(
81 unsafe { mask.get_unchecked(i) },
82 *bit,
83 "mask mismatch at {i}"
84 );
85 }
86 }
87 (None, Some(mask)) => {
88 assert!(mask.all_true(), "mask unexpectedly present");
89 }
90 (Some(_), None) => panic!("expected mask missing"),
91 }
92 }
93
94 fn bitmask(bits: &[bool]) -> Bitmask {
95 let mut m = Bitmask::new_set_all(bits.len(), false);
96 for (i, b) in bits.iter().enumerate() {
97 unsafe { m.set_unchecked(i, *b) };
98 }
99 m
100 }
101
// Generates three tests for one integer dispatch function:
//   * `$fn_dense`  - every operator without a mask, plus divide/remainder by zero;
//   * `$fn_masked` - divide/remainder with a validity mask;
//   * `$fn_empty`  - empty operands.
// `$ty` is the element type and `$apply_fn` the dispatch function under test.
macro_rules! int_kernel_suite {
    ($fn_dense:ident, $fn_masked:ident, $fn_empty:ident, $ty:ty, $apply_fn:ident) => {
        #[test]
        fn $fn_dense() {
            let lhs = vec64![1, 4, 9, 16];
            let rhs = vec64![1, 2, 3, 4];

            // Operators whose expected output can be written down literally.
            let literal_cases: [(ArithmeticOperator, [$ty; 4]); 5] = [
                (ArithmeticOperator::Add, [2, 6, 12, 20]),
                (ArithmeticOperator::Subtract, [0, 2, 6, 12]),
                (ArithmeticOperator::Multiply, [1, 8, 27, 64]),
                (ArithmeticOperator::Divide, [1, 2, 3, 4]),
                (ArithmeticOperator::Remainder, [0, 0, 0, 0]),
            ];
            for (op, want) in literal_cases {
                let got = $apply_fn(&lhs, &rhs, op, None).unwrap();
                assert_int(&got, &IntegerArray::<$ty>::from_slice(&want), None);
            }

            // Power is checked against repeated wrapping multiplication, since
            // 16^4 does not fit the narrower element types.
            let power_want: Vec<$ty> = lhs
                .iter()
                .zip(rhs.iter())
                .map(|(&base, &exp)| {
                    (0..(exp as u32)).fold(<$ty as num_traits::One>::one(), |acc, _| {
                        acc.wrapping_mul(base)
                    })
                })
                .collect();
            let got = $apply_fn(&lhs, &rhs, ArithmeticOperator::Power, None).unwrap();
            assert_int(&got, &IntegerArray::<$ty>::from_slice(&power_want), None);

            // Without a mask, a zero divisor is expected to panic.
            let zeros: &[$ty] = &[0, 0, 0, 0];
            let must_panic = [
                (
                    ArithmeticOperator::Divide,
                    "Dense integer kernel division by zero must panic",
                ),
                (
                    ArithmeticOperator::Remainder,
                    "Dense integer kernel remainder by zero must panic",
                ),
            ];
            for (op, why) in must_panic {
                let outcome =
                    std::panic::catch_unwind(|| $apply_fn(&lhs, zeros, op, None).unwrap());
                assert!(outcome.is_err(), "{why}");
            }
        }

        #[test]
        fn $fn_masked() {
            let lhs = vec64![10, 20, 30, 40];
            let rhs = vec64![2, 0, 3, 5];
            let validity = [true, false, true, false];
            let mask = bitmask(&validity);

            // Null positions are expected to come back as zero with the mask kept.
            let quotient =
                $apply_fn(&lhs, &rhs, ArithmeticOperator::Divide, Some(&mask)).unwrap();
            assert_int(
                &quotient,
                &IntegerArray::<$ty>::from_slice(&[5, 0, 10, 0]),
                Some(&validity[..]),
            );

            let remainder =
                $apply_fn(&lhs, &rhs, ArithmeticOperator::Remainder, Some(&mask)).unwrap();
            assert_int(
                &remainder,
                &IntegerArray::<$ty>::from_slice(&[0, 0, 0, 0]),
                Some(&validity[..]),
            );

            // With an all-valid input mask, zero divisors are expected to be
            // reported as nulls in the output instead of panicking.
            let all_valid = bitmask(&[true; 4]);
            let dividend: &[$ty] = &[100; 4];
            let divisor: &[$ty] = &[1, 0, 2, 0];
            let nonzero_divisor = [true, false, true, false];

            let quotient = $apply_fn(
                dividend,
                divisor,
                ArithmeticOperator::Divide,
                Some(&all_valid),
            )
            .unwrap();
            assert_int(
                &quotient,
                &IntegerArray::<$ty>::from_slice(&[100, 0, 50, 0]),
                Some(&nonzero_divisor[..]),
            );
        }

        #[test]
        fn $fn_empty() {
            let (lhs, rhs) = (vec64![], vec64![]);
            let sum = $apply_fn(&lhs, &rhs, ArithmeticOperator::Add, None).unwrap();
            assert!(sum.is_empty());
        }
    };
}
216
217 #[cfg(feature = "extended_numeric_types")]
218 int_kernel_suite!(
219 apply_int_i8_dense,
220 apply_int_i8_masked,
221 apply_int_i8_empty,
222 i8,
223 apply_int_i8
224 );
225 #[cfg(feature = "extended_numeric_types")]
226 int_kernel_suite!(
227 apply_int_u8_dense,
228 apply_int_u8_masked,
229 apply_int_u8_empty,
230 u8,
231 apply_int_u8
232 );
233 #[cfg(feature = "extended_numeric_types")]
234 int_kernel_suite!(
235 apply_int_i16_dense,
236 apply_int_i16_masked,
237 apply_int_i16_empty,
238 i16,
239 apply_int_i16
240 );
241 #[cfg(feature = "extended_numeric_types")]
242 int_kernel_suite!(
243 apply_int_u16_dense,
244 apply_int_u16_masked,
245 apply_int_u16_empty,
246 u16,
247 apply_int_u16
248 );
249 int_kernel_suite!(
250 apply_int_i32_dense,
251 apply_int_i32_masked,
252 apply_int_i32_empty,
253 i32,
254 apply_int_i32
255 );
256 int_kernel_suite!(
257 apply_int_u32_dense,
258 apply_int_u32_masked,
259 apply_int_u32_empty,
260 u32,
261 apply_int_u32
262 );
263 int_kernel_suite!(
264 apply_int_i64_dense,
265 apply_int_i64_masked,
266 apply_int_i64_empty,
267 i64,
268 apply_int_i64
269 );
270 int_kernel_suite!(
271 apply_int_u64_dense,
272 apply_int_u64_masked,
273 apply_int_u64_empty,
274 u64,
275 apply_int_u64
276 );
277
// Generates one test, `$test_fn`, that drives the float dispatch function
// `$apply_fn` for element type `$ty` through every operator, division and
// remainder by zero, a masked multiply and empty operands. `$eps` is the
// absolute tolerance used for the remainder and power comparisons.
macro_rules! float_kernel_suite {
    ($test_fn:ident, $ty:ty, $apply_fn:ident, $eps:expr) => {
        #[test]
        fn $test_fn() {
            let lhs_buf = vec64![1.0, 4.0, 9.0, 16.0];
            let rhs_buf = vec64![0.5, 2.0, 3.0, 4.0];
            let lhs: &[$ty] = lhs_buf.as_slice();
            let rhs: &[$ty] = rhs_buf.as_slice();

            // Operators compared with exact equality.
            let exact_cases: [(ArithmeticOperator, [$ty; 4]); 4] = [
                (ArithmeticOperator::Add, [1.5, 6.0, 12.0, 20.0]),
                (ArithmeticOperator::Subtract, [0.5, 2.0, 6.0, 12.0]),
                (ArithmeticOperator::Multiply, [0.5, 8.0, 27.0, 64.0]),
                (ArithmeticOperator::Divide, [2.0, 2.0, 3.0, 4.0]),
            ];
            for (op, want) in exact_cases {
                let got = $apply_fn(lhs, rhs, op, None).unwrap();
                assert_eq!(got.data.as_slice(), &want);
            }

            // Remainder is compared within `$eps`.
            let got = $apply_fn(lhs, rhs, ArithmeticOperator::Remainder, None).unwrap();
            let rem_want = [1.0 % 0.5, 4.0 % 2.0, 9.0 % 3.0, 16.0 % 4.0];
            let rem_close = got
                .data
                .as_slice()
                .iter()
                .zip(rem_want.iter().map(|&x| x as $ty))
                .all(|(seen, want)| (*seen - want).abs() < $eps);
            assert!(rem_close);

            // Power is compared within `$eps` against exp(b * ln(a)).
            let got = $apply_fn(lhs, rhs, ArithmeticOperator::Power, None).unwrap();
            let pow_want: Vec<$ty> = lhs
                .iter()
                .zip(rhs.iter())
                .map(|(&base, &exp)| (exp * base.ln()).exp())
                .collect();
            let pow_close = got
                .data
                .as_slice()
                .iter()
                .zip(pow_want.iter())
                .all(|(seen, want)| (*seen - *want).abs() < $eps);
            assert!(pow_close);

            // A zero divisor must not panic for floats: Inf for divide, NaN for remainder.
            let zeros: &[$ty] = &[0.0, 0.0, 0.0, 0.0];
            let got = $apply_fn(lhs, zeros, ArithmeticOperator::Divide, None).unwrap();
            assert!(
                got.data.iter().all(|&x| x.is_infinite()),
                "Float division by zero should yield Inf"
            );

            let got = $apply_fn(lhs, zeros, ArithmeticOperator::Remainder, None).unwrap();
            assert!(
                got.data.iter().all(|&x| x.is_nan()),
                "Float remainder by zero should yield NaN"
            );

            // Masked multiply: null slots are expected as 0.0 with a four-entry mask.
            let validity = bitmask(&[true, false, true, false]);
            let got = $apply_fn(lhs, rhs, ArithmeticOperator::Multiply, Some(&validity)).unwrap();
            assert_eq!(got.data.as_slice(), &[0.5 as $ty, 0.0, 27.0, 0.0]);
            assert_eq!(got.null_mask.as_ref().unwrap().len(), 4);

            // Empty operands give an empty result.
            let got = $apply_fn(&[], &[], ArithmeticOperator::Add, None).unwrap();
            assert!(got.is_empty());
        }
    };
}
353
354 float_kernel_suite!(apply_float_f32_dense, f32, apply_float_f32, 1e-6f32);
355 float_kernel_suite!(apply_float_f64_dense, f64, apply_float_f64, 1e-12f64);
356
/// Checks `apply_fma_f32` (expected result `lhs * rhs + acc`) without a mask,
/// with a mask, and on empty operands.
#[test]
fn fma_f32() {
    let multiplicand = vec64![1.0f32, 2.0, 3.0];
    let multiplier = vec64![4.0f32, 5.0, 6.0];
    let addend = vec64![0.5f32, 0.5, 0.5];

    let dense = apply_fma_f32(&multiplicand, &multiplier, &addend, None).unwrap();
    assert_float(&dense, &[4.5, 10.5, 18.5], None);

    // The null slot is expected as 0.0 with the mask carried through.
    let validity = [true, false, true];
    let mask = bitmask(&validity);
    let masked = apply_fma_f32(&multiplicand, &multiplier, &addend, Some(&mask)).unwrap();
    assert_float(&masked, &[4.5, 0.0, 18.5], Some(&validity[..]));

    let empty = apply_fma_f32(&[], &[], &[], None).unwrap();
    assert!(empty.is_empty());
}
372
/// Checks `apply_fma_f64` (expected result `lhs * rhs + acc`) without a mask,
/// with a mask, and on empty operands.
#[test]
fn fma_f64() {
    let lhs = vec64![1.0f64, 2.0, 3.0];
    let rhs = vec64![4.0f64, 5.0, 6.0];
    let acc = vec64![0.5f64, 0.5, 0.5];
    let out = apply_fma_f64(&lhs, &rhs, &acc, None).unwrap();
    assert_float(&out, &[4.5, 10.5, 18.5], None);

    // The null slot is expected as 0.0 with the mask carried through.
    let mask = bitmask(&[true, false, true]);
    let out = apply_fma_f64(&lhs, &rhs, &acc, Some(&mask)).unwrap();
    assert_float(&out, &[4.5, 0.0, 18.5], Some(&[true, false, true]));

    // Empty operands must give an empty result, matching the f32 test.
    let out = apply_fma_f64(&[], &[], &[], None).unwrap();
    assert!(out.is_empty());
}
385
/// Merging two masks is expected to keep a bit set only where both inputs set it.
#[test]
fn merge_masks_correctness() {
    let left = bitmask(&[true, false, true, true]);
    let right = bitmask(&[true, true, false, true]);
    let merged = crate::utils::merge_bitmasks_to_new(Some(&left), Some(&right), 4).unwrap();

    let mut observed: Vec<bool> = Vec::with_capacity(4);
    for idx in 0..4 {
        observed.push(merged.get(idx));
    }
    assert_eq!(observed, vec![true, false, false, true]);
}
395
396 #[cfg(feature = "datetime")]
400 use minarrow::structs::variants::datetime::DatetimeArray;
401
402 #[cfg(feature = "datetime")]
403 use crate::kernels::arithmetic::dispatch::apply_datetime_i64;
404
/// Adding two unmasked datetime arrays gives element-wise sums and no mask.
#[cfg(feature = "datetime")]
#[test]
fn datetime_add() {
    let base = DatetimeArray::<i64>::from_slice(&[1_000i64, 2_000, 3_000], None);
    let delta = DatetimeArray::<i64>::from_slice(&[10, 20, 30], None);

    // Each operand is passed as an (array, offset, length) window over the full array.
    let sum = apply_datetime_i64(
        (&base, 0, base.len()),
        (&delta, 0, delta.len()),
        ArithmeticOperator::Add,
    )
    .unwrap();

    assert_eq!(sum.data.as_slice(), &[1_010, 2_020, 3_030]);
    assert!(sum.null_mask.is_none());
}
416
/// Runs every arithmetic operator through `apply_datetime_i64` on unmasked input.
#[cfg(feature = "datetime")]
#[test]
fn datetime_all_ops() {
    let lhs = DatetimeArray::<i64>::from_slice(&[10, 20, 30, 40], None);
    let rhs = DatetimeArray::<i64>::from_slice(&[1, 2, 3, 4], None);
    let lhs_window = (&lhs, 0, lhs.len());
    let rhs_window = (&rhs, 0, rhs.len());

    let cases: [(ArithmeticOperator, [i64; 4]); 6] = [
        (ArithmeticOperator::Add, [11, 22, 33, 44]),
        (ArithmeticOperator::Subtract, [9, 18, 27, 36]),
        (ArithmeticOperator::Multiply, [10, 40, 90, 160]),
        (ArithmeticOperator::Divide, [10, 10, 10, 10]),
        (ArithmeticOperator::Remainder, [0, 0, 0, 0]),
        (
            ArithmeticOperator::Power,
            [10_i64.pow(1), 20_i64.pow(2), 30_i64.pow(3), 40_i64.pow(4)],
        ),
    ];
    for (op, want) in cases {
        let got = apply_datetime_i64(lhs_window, rhs_window, op).unwrap();
        assert_eq!(got.data.as_slice(), &want);
    }
}
446
/// Covers datetime addition with no mask, with a mask on the left operand,
/// and with empty operands.
#[cfg(feature = "datetime")]
#[test]
fn datetime_masked_and_empty() {
    // Reads the first four bits of a mask into a plain vector for comparison.
    let read_bits = |m: &Bitmask| (0..4).map(|i| m.get(i)).collect::<Vec<bool>>();

    let lhs = DatetimeArray::<i64>::from_slice(&[10, 20, 30, 40], None);
    let rhs = DatetimeArray::<i64>::from_slice(&[1, 2, 3, 4], None);
    let mask = bitmask(&[true, false, true, true]);
    let rhs_window = (&rhs, 0, rhs.len());

    // Baseline: no mask on either side.
    let plain = apply_datetime_i64((&lhs, 0, lhs.len()), rhs_window, ArithmeticOperator::Add)
        .unwrap();
    assert_eq!(plain.data.as_slice(), &[11, 22, 33, 44]);

    // Same data with a mask on the left operand: the null slot is expected
    // as 0 and the output mask is expected to equal the input mask.
    let mut lhs_nullable = lhs.clone();
    lhs_nullable.null_mask = Some(mask.clone());
    let masked = apply_datetime_i64(
        (&lhs_nullable, 0, lhs_nullable.len()),
        rhs_window,
        ArithmeticOperator::Add,
    )
    .unwrap();
    assert_eq!(masked.data.as_slice(), &[11, 0, 33, 44]);
    assert_eq!(
        masked.null_mask.as_ref().map(|m| read_bits(m)),
        Some(read_bits(&mask))
    );

    // Empty operands give an empty result.
    let no_lhs = DatetimeArray::<i64>::from_slice(&[], None);
    let no_rhs = DatetimeArray::<i64>::from_slice(&[], None);
    let empty = apply_datetime_i64(
        (&no_lhs, 0, no_lhs.len()),
        (&no_rhs, 0, no_rhs.len()),
        ArithmeticOperator::Add,
    )
    .unwrap();
    assert!(empty.is_empty());
}
482
/// Operands of different lengths are expected to panic with the length-mismatch message.
#[cfg(feature = "datetime")]
#[test]
#[should_panic(expected = "apply_datetime: length mismatch")]
fn datetime_len_mismatch_panics() {
    let two = DatetimeArray::<i64>::from_slice(&[1_000i64, 2_000], None);
    let one = DatetimeArray::<i64>::from_slice(&[10], None);
    let _ = apply_datetime_i64(
        (&two, 0, two.len()),
        (&one, 0, one.len()),
        ArithmeticOperator::Add,
    )
    .unwrap();
}
493
/// The SIMD power kernel must compute 2^10 = 1024 for every element, for both
/// a 16-element and a 128-element input.
#[cfg(feature = "simd")]
#[test]
fn test_int_dense_power_short_vs_long_input_simd() {
    // 16-element input.
    let base_short = vec64![2u32; 16];
    let exp_short = vec64![10u32; 16];
    let mut out_short = vec64![0u32; 16];
    int_dense_body_simd::<u32, 4>(
        ArithmeticOperator::Power,
        &base_short,
        &exp_short,
        &mut out_short,
    );

    // 128-element input.
    let base_long = vec64![2u32; 128];
    let exp_long = vec64![10u32; 128];
    let mut out_long = vec64![0u32; 128];
    int_dense_body_simd::<u32, 4>(
        ArithmeticOperator::Power,
        &base_long,
        &exp_long,
        &mut out_long,
    );

    for out in [&out_short, &out_long] {
        for &value in out.iter() {
            assert_eq!(value, 1024);
        }
    }
}
525}