vortex_array/arrays/decimal/compute/
sum.rs

1use itertools::Itertools;
2use vortex_error::{VortexResult, vortex_bail};
3use vortex_mask::Mask;
4use vortex_scalar::Scalar;
5
6use crate::arrays::{DecimalArray, DecimalEncoding};
7use crate::compute::{SumKernel, SumKernelAdapter};
8use crate::{Array, match_each_decimal_value_type, register_kernel};
9
/// Sums decimal values of native type `$ty` using overflow-checked addition.
///
/// Two forms are accepted:
///
/// * `sum_decimal!(ty, values)` adds every element yielded by iterating `values`.
/// * `sum_decimal!(ty, values, validity)` pairs `values.iter()` with `validity.iter()` and adds
///   only the elements whose validity flag is `true`. The two sequences must have the same
///   length; `zip_eq` panics otherwise.
///
/// The accumulator starts at `<$ty>::default()` and the macro evaluates to the final `$ty` sum.
///
/// # Errors
///
/// If an addition overflows `$ty`, the macro expands to a `vortex_bail!`, i.e. it returns an
/// error from the *enclosing* function instead of panicking. It must therefore only be invoked
/// inside a function (or closure) whose return type is a `VortexResult`.
macro_rules! sum_decimal {
    ($ty:ty, $values:expr) => {{
        let mut sum: $ty = <$ty>::default();
        for v in $values {
            sum = match num_traits::CheckedAdd::checked_add(&sum, &v) {
                Some(next) => next,
                // Overflow depends on the data, not on a broken invariant: report it.
                None => vortex_bail!(
                    "overflow while summing {} decimal values",
                    stringify!($ty)
                ),
            };
        }
        sum
    }};
    ($ty:ty, $values:expr, $validity:expr) => {{
        let mut sum: $ty = <$ty>::default();
        for (v, valid) in $values.iter().zip_eq($validity.iter()) {
            // Entries flagged invalid are excluded from the sum.
            if valid {
                sum = match num_traits::CheckedAdd::checked_add(&sum, &v) {
                    Some(next) => next,
                    // Overflow depends on the data, not on a broken invariant: report it.
                    None => vortex_bail!(
                        "overflow while summing {} decimal values",
                        stringify!($ty)
                    ),
                };
            }
        }
        sum
    }};
}
28
29impl SumKernel for DecimalEncoding {
30    fn sum(&self, array: &DecimalArray) -> VortexResult<Scalar> {
31        let decimal_dtype = array.decimal_dtype();
32        let nullability = array.dtype.nullability();
33
34        match array.validity_mask()? {
35            Mask::AllFalse(_) => {
36                vortex_bail!("invalid state, all-null array should be checked by top-level sum fn")
37            }
38            Mask::AllTrue(_) => {
39                match_each_decimal_value_type!(array.values_type(), |($D, $CTor)| {
40                   Ok(Scalar::decimal(
41                    $CTor(sum_decimal!($D, array.buffer::<$D>())),
42                    decimal_dtype,
43                    nullability,
44                    ))
45                })
46            }
47            Mask::Values(mask_values) => {
48                match_each_decimal_value_type!(array.values_type(), |($D, $CTor)|{
49                    Ok(Scalar::decimal(
50                        $CTor(sum_decimal!(
51                            $D,
52                            array.buffer::<$D>(),
53                            mask_values.boolean_buffer()
54                        )),
55                        decimal_dtype,
56                        nullability,
57                    ))
58                })
59            }
60        }
61    }
62}
63
64register_kernel!(SumKernelAdapter(DecimalEncoding).lift());