vortex_array/arrays/decimal/compute/
sum.rs1use vortex_error::{VortexResult, vortex_bail};
2use vortex_mask::Mask;
3use vortex_scalar::{DecimalValue, Scalar, match_each_decimal_value_type};
4
5use crate::arrays::{DecimalArray, DecimalVTable};
6use crate::compute::{SumKernel, SumKernelAdapter};
7use crate::register_kernel;
8
/// Adds up decimal storage values of type `$ty` using checked arithmetic.
///
/// Forms:
/// - `sum_decimal!(T, values)` sums every item yielded by iterating `values`.
/// - `sum_decimal!(T, values, validity)` sums only the items of `values.iter()`
///   whose paired `validity.iter()` entry is `true`. Both iterators must yield
///   the same number of items; `Itertools::zip_eq` panics otherwise.
///
/// The macro evaluates to the accumulated total as a `$ty`, starting from
/// `<$ty>::default()`.
///
/// # Errors
/// If any addition overflows `$ty`, the macro returns an error from the
/// *enclosing function* via `vortex_bail!` instead of panicking. It must
/// therefore only be expanded inside a function returning `VortexResult<_>`.
macro_rules! sum_decimal {
    ($ty:ty, $values:expr) => {{
        let mut sum: $ty = <$ty>::default();
        for v in $values {
            sum = match num_traits::CheckedAdd::checked_add(&sum, &v) {
                Some(next) => next,
                // Overflow is data-dependent (many values near the precision
                // limit), so report it as an error rather than aborting.
                None => vortex_bail!(
                    "decimal sum overflowed storage type {}",
                    std::any::type_name::<$ty>()
                ),
            };
        }
        sum
    }};
    ($ty:ty, $values:expr, $validity:expr) => {{
        use itertools::Itertools;

        let mut sum: $ty = <$ty>::default();
        for (v, valid) in $values.iter().zip_eq($validity.iter()) {
            // Null slots contribute nothing to the total.
            if valid {
                sum = match num_traits::CheckedAdd::checked_add(&sum, &v) {
                    Some(next) => next,
                    None => vortex_bail!(
                        "decimal sum overflowed storage type {}",
                        std::any::type_name::<$ty>()
                    ),
                };
            }
        }
        sum
    }};
}
29
30impl SumKernel for DecimalVTable {
31 fn sum(&self, array: &DecimalArray) -> VortexResult<Scalar> {
32 let decimal_dtype = array.decimal_dtype();
33 let nullability = array.dtype.nullability();
34
35 match array.validity_mask()? {
36 Mask::AllFalse(_) => {
37 vortex_bail!("invalid state, all-null array should be checked by top-level sum fn")
38 }
39 Mask::AllTrue(_) => {
40 match_each_decimal_value_type!(array.values_type(), |D| {
41 Ok(Scalar::decimal(
42 DecimalValue::from(sum_decimal!(D, array.buffer::<D>())),
43 decimal_dtype,
44 nullability,
45 ))
46 })
47 }
48 Mask::Values(mask_values) => {
49 match_each_decimal_value_type!(array.values_type(), |D| {
50 Ok(Scalar::decimal(
51 DecimalValue::from(sum_decimal!(
52 D,
53 array.buffer::<D>(),
54 mask_values.boolean_buffer()
55 )),
56 decimal_dtype,
57 nullability,
58 ))
59 })
60 }
61 }
62 }
63}
64
65register_kernel!(SumKernelAdapter(DecimalVTable).lift());