vortex_array/arrays/decimal/compute/
sum.rs1use vortex_error::{VortexResult, vortex_bail};
5use vortex_mask::Mask;
6use vortex_scalar::{DecimalValue, Scalar, match_each_decimal_value_type};
7
8use crate::arrays::{DecimalArray, DecimalVTable};
9use crate::compute::{SumKernel, SumKernelAdapter};
10use crate::register_kernel;
11
// Sums decimal storage values of type `$ty`, starting from `<$ty>::default()`.
//
// Forms:
// - `sum_decimal!(T, values)`: adds every element produced by iterating `values` by value.
// - `sum_decimal!(T, values, validity)`: walks `values.iter()` in lock-step with
//   `validity.iter()` and adds only the elements whose validity bit is `true`. `zip_eq`
//   panics if the two sequences have different lengths (an internal invariant violation).
//
// Overflow of `$ty` is reported through `vortex_bail!` (an early error return) instead of
// panicking, so this macro must be expanded inside code that returns a `VortexResult`.
//
// NOTE(review): only overflow of the storage type `$ty` is detected here; a total that still
// fits in `$ty` but exceeds the decimal precision is not checked by this macro.
macro_rules! sum_decimal {
    ($ty:ty, $values:expr) => {{
        let mut sum: $ty = <$ty>::default();
        for v in $values {
            let Some(next) = num_traits::CheckedAdd::checked_add(&sum, &v) else {
                vortex_bail!("overflow while summing decimal values");
            };
            sum = next;
        }
        sum
    }};
    ($ty:ty, $values:expr, $validity:expr) => {{
        use itertools::Itertools;

        let mut sum: $ty = <$ty>::default();
        for (v, valid) in $values.iter().zip_eq($validity.iter()) {
            // Null slots contribute nothing to the sum.
            if !valid {
                continue;
            }
            let Some(next) = num_traits::CheckedAdd::checked_add(&sum, &v) else {
                vortex_bail!("overflow while summing decimal values");
            };
            sum = next;
        }
        sum
    }};
}
32
33impl SumKernel for DecimalVTable {
34 fn sum(&self, array: &DecimalArray) -> VortexResult<Scalar> {
35 let decimal_dtype = array.decimal_dtype();
36 let nullability = array.dtype.nullability();
37
38 match array.validity_mask()? {
39 Mask::AllFalse(_) => {
40 vortex_bail!("invalid state, all-null array should be checked by top-level sum fn")
41 }
42 Mask::AllTrue(_) => {
43 match_each_decimal_value_type!(array.values_type(), |D| {
44 Ok(Scalar::decimal(
45 DecimalValue::from(sum_decimal!(D, array.buffer::<D>())),
46 decimal_dtype,
47 nullability,
48 ))
49 })
50 }
51 Mask::Values(mask_values) => {
52 match_each_decimal_value_type!(array.values_type(), |D| {
53 Ok(Scalar::decimal(
54 DecimalValue::from(sum_decimal!(
55 D,
56 array.buffer::<D>(),
57 mask_values.boolean_buffer()
58 )),
59 decimal_dtype,
60 nullability,
61 ))
62 })
63 }
64 }
65 }
66}
67
// Wraps `DecimalVTable` in `SumKernelAdapter`, lifts it, and passes it to `register_kernel!`.
// NOTE(review): presumably this makes the decimal kernel above discoverable by the generic
// sum compute function — confirm against the `register_kernel!` definition.
register_kernel!(SumKernelAdapter(DecimalVTable).lift());