vortex_array/arrays/decimal/compute/sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::{VortexResult, vortex_bail};
5use vortex_mask::Mask;
6use vortex_scalar::{DecimalValue, Scalar, match_each_decimal_value_type};
7
8use crate::arrays::{DecimalArray, DecimalVTable};
9use crate::compute::{SumKernel, SumKernelAdapter};
10use crate::register_kernel;
11
/// Sums decimal storage values of type `$ty`, returning an error on overflow instead of panicking.
///
/// Forms:
/// - `sum_decimal!(T, values)`: adds every element yielded by iterating `values` by value.
/// - `sum_decimal!(T, values, validity)`: walks `values.iter()` and `validity.iter()` in
///   lock-step using `Itertools::zip_eq` (which panics if the two lengths differ) and adds only
///   the elements whose validity bit is `true`.
///
/// The accumulator starts at `<$ty>::default()` and is advanced with
/// `num_traits::CheckedAdd`. The accumulator has the same type as the storage values, so narrow
/// storage types (e.g. `i8`) can overflow after only a few additions.
///
/// # Errors
///
/// On overflow this expands to `vortex_bail!`, i.e. it *returns an `Err` from the enclosing
/// function*. It must therefore only be expanded inside a function returning a `VortexResult`,
/// with `vortex_bail!` in scope. Overflow depends on the data being summed, so it is reported
/// as an error rather than aborting the process with a panic.
macro_rules! sum_decimal {
    ($ty:ty, $values:expr) => {{
        let mut sum: $ty = <$ty>::default();
        for v in $values {
            let Some(next) = num_traits::CheckedAdd::checked_add(&sum, &v) else {
                vortex_bail!("overflow while summing decimal values")
            };
            sum = next;
        }
        sum
    }};
    ($ty:ty, $values:expr, $validity:expr) => {{
        use itertools::Itertools;

        let mut sum: $ty = <$ty>::default();
        for (v, valid) in $values.iter().zip_eq($validity.iter()) {
            // Null slots may hold arbitrary storage values; skip them entirely.
            if valid {
                let Some(next) = num_traits::CheckedAdd::checked_add(&sum, &v) else {
                    vortex_bail!("overflow while summing decimal values")
                };
                sum = next;
            }
        }
        sum
    }};
}
32
33impl SumKernel for DecimalVTable {
34    fn sum(&self, array: &DecimalArray) -> VortexResult<Scalar> {
35        let decimal_dtype = array.decimal_dtype();
36        let nullability = array.dtype.nullability();
37
38        match array.validity_mask()? {
39            Mask::AllFalse(_) => {
40                vortex_bail!("invalid state, all-null array should be checked by top-level sum fn")
41            }
42            Mask::AllTrue(_) => {
43                match_each_decimal_value_type!(array.values_type(), |D| {
44                    Ok(Scalar::decimal(
45                        DecimalValue::from(sum_decimal!(D, array.buffer::<D>())),
46                        decimal_dtype,
47                        nullability,
48                    ))
49                })
50            }
51            Mask::Values(mask_values) => {
52                match_each_decimal_value_type!(array.values_type(), |D| {
53                    Ok(Scalar::decimal(
54                        DecimalValue::from(sum_decimal!(
55                            D,
56                            array.buffer::<D>(),
57                            mask_values.boolean_buffer()
58                        )),
59                        decimal_dtype,
60                        nullability,
61                    ))
62                })
63            }
64        }
65    }
66}
67
// Register the decimal `SumKernel` implementation, wrapped in `SumKernelAdapter` and converted
// with `.lift()`. NOTE(review): presumably this lets the generic sum entry point dispatch decimal
// arrays to this kernel — confirm against `register_kernel!`.
68register_kernel!(SumKernelAdapter(DecimalVTable).lift());