vortex_array/arrays/decimal/compute/
sum.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::{VortexResult, vortex_bail, vortex_err};
5use vortex_mask::Mask;
6use vortex_scalar::{DecimalValue, Scalar, match_each_decimal_value_type};
7
8use crate::arrays::{DecimalArray, DecimalVTable};
9use crate::compute::{SumKernel, SumKernelAdapter};
10use crate::register_kernel;
11
/// Adds up decimal storage values of type `$ty` using checked addition.
///
/// Two forms are accepted:
/// * `sum_decimal!(ty, values)` folds every element yielded by `values.iter()`.
/// * `sum_decimal!(ty, values, validity)` pairs each value with the flag yielded
///   by `validity.iter()` and folds only the values whose flag is `true`.
///
/// The accumulator starts at `<$ty>::default()`. If an addition overflows, the
/// macro applies `?` to a `vortex_err!` error, so it must be expanded inside a
/// function whose return type can carry that error.
macro_rules! sum_decimal {
    ($ty:ty, $values:expr) => {{
        $values.iter().try_fold(<$ty>::default(), |sum: $ty, v| {
            num_traits::CheckedAdd::checked_add(&sum, v)
                .ok_or_else(|| vortex_err!("Overflow when summing decimal {sum:?} + {v:?}"))
        })?
    }};
    ($ty:ty, $values:expr, $validity:expr) => {{
        use itertools::Itertools;

        $values
            .iter()
            // `zip_eq` panics if the two iterators do not have the same length.
            .zip_eq($validity.iter())
            // Drop the entries flagged as invalid before accumulating.
            .filter_map(|(v, valid)| valid.then_some(v))
            .try_fold(<$ty>::default(), |sum: $ty, v| {
                num_traits::CheckedAdd::checked_add(&sum, v)
                    .ok_or_else(|| vortex_err!("Overflow when summing decimal {sum:?} + {v:?}"))
            })?
    }};
}
34
35impl SumKernel for DecimalVTable {
36    fn sum(&self, array: &DecimalArray) -> VortexResult<Scalar> {
37        let decimal_dtype = array.decimal_dtype();
38        let nullability = array.dtype().nullability();
39
40        match array.validity_mask() {
41            Mask::AllFalse(_) => {
42                vortex_bail!("invalid state, all-null array should be checked by top-level sum fn")
43            }
44            Mask::AllTrue(_) => {
45                match_each_decimal_value_type!(array.values_type(), |D| {
46                    Ok(Scalar::decimal(
47                        DecimalValue::from(sum_decimal!(D, array.buffer::<D>())),
48                        decimal_dtype,
49                        nullability,
50                    ))
51                })
52            }
53            Mask::Values(mask_values) => {
54                match_each_decimal_value_type!(array.values_type(), |D| {
55                    Ok(Scalar::decimal(
56                        DecimalValue::from(sum_decimal!(
57                            D,
58                            array.buffer::<D>(),
59                            mask_values.boolean_buffer()
60                        )),
61                        decimal_dtype,
62                        nullability,
63                    ))
64                })
65            }
66        }
67    }
68}
69
// Hands the `SumKernel` implementation for `DecimalVTable`, wrapped in
// `SumKernelAdapter` and lifted, to the crate's `register_kernel!` macro.
// NOTE(review): presumably this is what makes the decimal sum discoverable by
// the generic sum compute function — confirm against `register_kernel!`.
70register_kernel!(SumKernelAdapter(DecimalVTable).lift());