vortex_array/arrays/decimal/compute/
sum.rs1use vortex_error::{VortexResult, vortex_bail, vortex_err};
5use vortex_mask::Mask;
6use vortex_scalar::{DecimalValue, Scalar, match_each_decimal_value_type};
7
8use crate::arrays::{DecimalArray, DecimalVTable};
9use crate::compute::{SumKernel, SumKernelAdapter};
10use crate::register_kernel;
11
// Overflow-checked summation over a buffer of decimal storage values.
//
// `sum_decimal!(T, values)` adds up every element of `values`, while
// `sum_decimal!(T, values, validity)` adds up only the elements whose paired
// validity flag is `true`. Both forms start from `T::default()` and evaluate to a `T`.
//
// Each addition goes through `num_traits::CheckedAdd`; on overflow the macro
// leaves the *enclosing function* early via `?` with a `vortex_err!` error, so
// it may only be expanded inside a function returning a compatible `Result`.
macro_rules! sum_decimal {
    // Every element is valid: fold the whole buffer.
    ($ty:ty, $values:expr) => {{
        let total: $ty = $values
            .iter()
            .try_fold(<$ty>::default(), |sum: $ty, v| {
                num_traits::CheckedAdd::checked_add(&sum, v)
                    .ok_or_else(|| vortex_err!("Overflow when summing decimal {sum:?} + {v:?}"))
            })?;
        total
    }};
    // Some elements are null: pair each value with its validity flag and fold
    // only the valid ones.
    ($ty:ty, $values:expr, $validity:expr) => {{
        use itertools::Itertools;

        let total: $ty = $values
            .iter()
            // `zip_eq` panics if values and validity differ in length.
            .zip_eq($validity.iter())
            .filter_map(|(v, valid)| valid.then_some(v))
            .try_fold(<$ty>::default(), |sum: $ty, v| {
                num_traits::CheckedAdd::checked_add(&sum, v)
                    .ok_or_else(|| vortex_err!("Overflow when summing decimal {sum:?} + {v:?}"))
            })?;
        total
    }};
}
34
35impl SumKernel for DecimalVTable {
36 fn sum(&self, array: &DecimalArray) -> VortexResult<Scalar> {
37 let decimal_dtype = array.decimal_dtype();
38 let nullability = array.dtype().nullability();
39
40 match array.validity_mask() {
41 Mask::AllFalse(_) => {
42 vortex_bail!("invalid state, all-null array should be checked by top-level sum fn")
43 }
44 Mask::AllTrue(_) => {
45 match_each_decimal_value_type!(array.values_type(), |D| {
46 Ok(Scalar::decimal(
47 DecimalValue::from(sum_decimal!(D, array.buffer::<D>())),
48 decimal_dtype,
49 nullability,
50 ))
51 })
52 }
53 Mask::Values(mask_values) => {
54 match_each_decimal_value_type!(array.values_type(), |D| {
55 Ok(Scalar::decimal(
56 DecimalValue::from(sum_decimal!(
57 D,
58 array.buffer::<D>(),
59 mask_values.boolean_buffer()
60 )),
61 decimal_dtype,
62 nullability,
63 ))
64 })
65 }
66 }
67 }
68}
69
70register_kernel!(SumKernelAdapter(DecimalVTable).lift());