vortex_array/arrays/decimal/compute/
sum.rs1use itertools::Itertools;
2use vortex_error::{VortexResult, vortex_bail};
3use vortex_mask::Mask;
4use vortex_scalar::Scalar;
5
6use crate::arrays::{DecimalArray, DecimalEncoding};
7use crate::compute::{SumKernel, SumKernelAdapter};
8use crate::{Array, match_each_decimal_value_type, register_kernel};
9
/// Sums decimal storage values of type `$ty`, checking every addition for overflow.
///
/// * `sum_decimal!(T, values)` adds up every element produced by iterating `values`.
/// * `sum_decimal!(T, values, validity)` pairs `values.iter()` with `validity.iter()` using
///   `zip_eq` (so both sides must have the same length) and adds only the elements whose
///   validity bit is `true`.
///
/// The running total starts at `<T>::default()` and is accumulated in `T` itself, so the final
/// sum must fit in the storage type. If an addition overflows, the macro returns early from the
/// *enclosing* function with a `vortex_bail!` error instead of panicking. It may therefore only
/// be invoked inside functions that return `VortexResult<_>`.
macro_rules! sum_decimal {
    ($ty:ty, $values:expr) => {{
        let mut sum: $ty = <$ty>::default();
        for v in $values {
            sum = match num_traits::CheckedAdd::checked_add(&sum, &v) {
                Some(next) => next,
                None => {
                    vortex_bail!("decimal sum overflowed its {} accumulator", stringify!($ty))
                }
            };
        }
        sum
    }};
    ($ty:ty, $values:expr, $validity:expr) => {{
        let mut sum: $ty = <$ty>::default();
        for (v, valid) in $values.iter().zip_eq($validity.iter()) {
            // Null slots are skipped entirely; only valid values contribute to the total.
            if valid {
                sum = match num_traits::CheckedAdd::checked_add(&sum, &v) {
                    Some(next) => next,
                    None => {
                        vortex_bail!("decimal sum overflowed its {} accumulator", stringify!($ty))
                    }
                };
            }
        }
        sum
    }};
}
28
29impl SumKernel for DecimalEncoding {
30 fn sum(&self, array: &DecimalArray) -> VortexResult<Scalar> {
31 let decimal_dtype = array.decimal_dtype();
32 let nullability = array.dtype.nullability();
33
34 match array.validity_mask()? {
35 Mask::AllFalse(_) => {
36 vortex_bail!("invalid state, all-null array should be checked by top-level sum fn")
37 }
38 Mask::AllTrue(_) => {
39 match_each_decimal_value_type!(array.values_type(), |($D, $CTor)| {
40 Ok(Scalar::decimal(
41 $CTor(sum_decimal!($D, array.buffer::<$D>())),
42 decimal_dtype,
43 nullability,
44 ))
45 })
46 }
47 Mask::Values(mask_values) => {
48 match_each_decimal_value_type!(array.values_type(), |($D, $CTor)|{
49 Ok(Scalar::decimal(
50 $CTor(sum_decimal!(
51 $D,
52 array.buffer::<$D>(),
53 mask_values.boolean_buffer()
54 )),
55 decimal_dtype,
56 nullability,
57 ))
58 })
59 }
60 }
61 }
62}
63
64register_kernel!(SumKernelAdapter(DecimalEncoding).lift());