vortex_array/arrays/decimal/compute/
sum.rs

1use vortex_error::{VortexResult, vortex_bail};
2use vortex_mask::Mask;
3use vortex_scalar::{DecimalValue, Scalar, match_each_decimal_value_type};
4
5use crate::arrays::{DecimalArray, DecimalVTable};
6use crate::compute::{SumKernel, SumKernelAdapter};
7use crate::register_kernel;
8
/// Sums decimal values of type `$ty` using overflow-checked addition.
///
/// Both arms start from `<$ty>::default()` and evaluate to the accumulated total as a `$ty`.
///
/// * `sum_decimal!(ty, values)` adds every element produced by iterating `values`.
/// * `sum_decimal!(ty, values, validity)` walks `values.iter()` in lock-step with
///   `validity.iter()` and adds only the elements whose validity flag is `true`.
///   `zip_eq` panics if the two sequences have different lengths.
///
/// # Errors
///
/// When an addition overflows `$ty`, the macro invokes `vortex_bail!`, which leaves the
/// enclosing function early with an error instead of panicking. The macro must therefore be
/// expanded inside a function whose return type is a `VortexResult`.
macro_rules! sum_decimal {
    ($ty:ty, $values:expr) => {{
        let mut total: $ty = <$ty>::default();
        for value in $values {
            // Overflow depends on the data being summed, so report it as an error rather
            // than treating it as a broken invariant.
            let Some(next) = <$ty as num_traits::CheckedAdd>::checked_add(&total, &value) else {
                vortex_bail!("overflow while summing decimal values");
            };
            total = next;
        }
        total
    }};
    ($ty:ty, $values:expr, $validity:expr) => {{
        use itertools::Itertools;

        let mut total: $ty = <$ty>::default();
        for (value, is_valid) in $values.iter().zip_eq($validity.iter()) {
            // Null slots do not contribute to the sum.
            if !is_valid {
                continue;
            }
            let Some(next) = <$ty as num_traits::CheckedAdd>::checked_add(&total, &value) else {
                vortex_bail!("overflow while summing decimal values");
            };
            total = next;
        }
        total
    }};
}
29
30impl SumKernel for DecimalVTable {
31    fn sum(&self, array: &DecimalArray) -> VortexResult<Scalar> {
32        let decimal_dtype = array.decimal_dtype();
33        let nullability = array.dtype.nullability();
34
35        match array.validity_mask()? {
36            Mask::AllFalse(_) => {
37                vortex_bail!("invalid state, all-null array should be checked by top-level sum fn")
38            }
39            Mask::AllTrue(_) => {
40                match_each_decimal_value_type!(array.values_type(), |D| {
41                    Ok(Scalar::decimal(
42                        DecimalValue::from(sum_decimal!(D, array.buffer::<D>())),
43                        decimal_dtype,
44                        nullability,
45                    ))
46                })
47            }
48            Mask::Values(mask_values) => {
49                match_each_decimal_value_type!(array.values_type(), |D| {
50                    Ok(Scalar::decimal(
51                        DecimalValue::from(sum_decimal!(
52                            D,
53                            array.buffer::<D>(),
54                            mask_values.boolean_buffer()
55                        )),
56                        decimal_dtype,
57                        nullability,
58                    ))
59                })
60            }
61        }
62    }
63}
64
65register_kernel!(SumKernelAdapter(DecimalVTable).lift());