vortex_array/arrays/decimal/compute/
sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_schema::DECIMAL256_MAX_PRECISION;
5use num_traits::AsPrimitive;
6use vortex_dtype::Nullability::Nullable;
7use vortex_dtype::{DecimalDType, DecimalType, match_each_decimal_value_type};
8use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
9use vortex_mask::Mask;
10use vortex_scalar::{DecimalScalar, DecimalValue, Scalar};
11
12use crate::arrays::{DecimalArray, DecimalVTable};
13use crate::compute::{SumKernel, SumKernelAdapter};
14use crate::register_kernel;
15
// Folds decimal values into an accumulator of type `$ty`, starting from `$initial`.
//
// Every element is converted to `$ty` with `AsPrimitive` (an unchecked, `as`-style
// conversion) and added with `CheckedAdd`; the first overflowing addition is turned
// into an error and propagated out of the *enclosing function* via `?`.
//
// It's only sound to use `AsPrimitive` here when `$ty` can represent every summed value.
// NOTE(review): the call sites pick `$ty` from the widened result precision, so a buffer
// stored in a physically wider type than `$ty` would be narrowed here — confirm that the
// values are always bounded by the array's declared precision.
//
// Arms:
// - `($ty, $values, $initial)`: adds every element of `$values`.
// - `($ty, $values, $validity, $initial)`: adds only the elements whose paired validity
//   flag is `true`. `zip_eq` panics if the two sequences differ in length.
macro_rules! sum_decimal {
    ($ty:ty, $values:expr, $initial:expr) => {{
        let start: $ty = $initial;
        let total: $ty = $values.iter().try_fold(start, |sum: $ty, v| {
            let v: $ty = (*v).as_();
            num_traits::CheckedAdd::checked_add(&sum, &v)
                .ok_or_else(|| vortex_err!("Overflow when summing decimal {sum:?} + {v:?}"))
        })?;
        total
    }};
    ($ty:ty, $values:expr, $validity:expr, $initial:expr) => {{
        use itertools::Itertools;

        let start: $ty = $initial;
        let total: $ty = $values
            .iter()
            .zip_eq($validity)
            .try_fold(start, |sum: $ty, (v, valid)| {
                if valid {
                    let v: $ty = (*v).as_();
                    num_traits::CheckedAdd::checked_add(&sum, &v)
                        .ok_or_else(|| vortex_err!("Overflow when summing decimal {sum:?} + {v:?}"))
                } else {
                    // Null slots leave the running total untouched.
                    Ok(sum)
                }
            })?;
        total
    }};
}
41
42impl SumKernel for DecimalVTable {
43    #[allow(clippy::cognitive_complexity)]
44    fn sum(&self, array: &DecimalArray, accumulator: &Scalar) -> VortexResult<Scalar> {
45        let decimal_dtype = array.decimal_dtype();
46
47        // Both Spark and DataFusion use this heuristic.
48        // - https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
49        // - https://github.com/apache/datafusion/blob/4153adf2c0f6e317ef476febfdc834208bd46622/datafusion/functions-aggregate/src/sum.rs#L188
50        let new_precision = u8::min(DECIMAL256_MAX_PRECISION, decimal_dtype.precision() + 10);
51        let new_scale = decimal_dtype.scale();
52        let return_dtype = DecimalDType::new(new_precision, new_scale);
53
54        // Extract the initial value as a DecimalValue
55        let initial_decimal = DecimalScalar::try_from(accumulator)
56            .vortex_expect("must be a decimal")
57            .decimal_value()
58            .vortex_expect("cannot be null");
59
60        match array.validity_mask() {
61            Mask::AllFalse(_) => {
62                vortex_bail!("invalid state, all-null array should be checked by top-level sum fn")
63            }
64            Mask::AllTrue(_) => {
65                let values_type = DecimalType::smallest_decimal_value_type(&return_dtype);
66                match_each_decimal_value_type!(array.values_type(), |I| {
67                    match_each_decimal_value_type!(values_type, |O| {
68                        let initial_val: O = initial_decimal
69                            .cast()
70                            .vortex_expect("cannot fail to cast initial value");
71                        Ok(Scalar::decimal(
72                            DecimalValue::from(sum_decimal!(O, array.buffer::<I>(), initial_val)),
73                            return_dtype,
74                            Nullable,
75                        ))
76                    })
77                })
78            }
79            Mask::Values(mask_values) => {
80                let values_type = DecimalType::smallest_decimal_value_type(&return_dtype);
81                match_each_decimal_value_type!(array.values_type(), |I| {
82                    match_each_decimal_value_type!(values_type, |O| {
83                        let initial_val: O = initial_decimal
84                            .cast()
85                            .vortex_expect("cannot fail to cast initial value");
86                        Ok(Scalar::decimal(
87                            DecimalValue::from(sum_decimal!(
88                                O,
89                                array.buffer::<I>(),
90                                mask_values.bit_buffer(),
91                                initial_val
92                            )),
93                            return_dtype,
94                            Nullable,
95                        ))
96                    })
97                })
98            }
99        }
100    }
101}
102
103register_kernel!(SumKernelAdapter(DecimalVTable).lift());
104
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, DecimalDType, Nullability};
    use vortex_scalar::{DecimalValue, Scalar, ScalarValue};

    use crate::arrays::DecimalArray;
    use crate::compute::sum;
    use crate::validity::Validity;

    /// Builds the decimal scalar a test expects `sum` to return.
    fn decimal_scalar(
        dtype: DecimalDType,
        nullability: Nullability,
        value: impl Into<DecimalValue>,
    ) -> Scalar {
        Scalar::new(
            DType::Decimal(dtype, nullability),
            ScalarValue::from(value.into()),
        )
    }

    #[test]
    fn test_sum_basic() {
        let array = DecimalArray::new(
            buffer![100i32, 200i32, 300i32],
            DecimalDType::new(4, 2),
            Validity::AllValid,
        );

        let total = sum(array.as_ref()).unwrap();

        assert_eq!(
            total,
            decimal_scalar(DecimalDType::new(14, 2), Nullability::NonNullable, 600i32)
        );
    }

    #[test]
    fn test_sum_with_nulls() {
        // The second slot is null, so only 100 + 300 + 400 is summed.
        let array = DecimalArray::new(
            buffer![100i32, 200i32, 300i32, 400i32],
            DecimalDType::new(4, 2),
            Validity::from_iter([true, false, true, true]),
        );

        let total = sum(array.as_ref()).unwrap();

        assert_eq!(
            total,
            decimal_scalar(DecimalDType::new(14, 2), Nullability::Nullable, 800i32)
        );
    }

    #[test]
    fn test_sum_negative_values() {
        let array = DecimalArray::new(
            buffer![100i32, -200i32, 300i32, -50i32],
            DecimalDType::new(4, 2),
            Validity::AllValid,
        );

        let total = sum(array.as_ref()).unwrap();

        assert_eq!(
            total,
            decimal_scalar(DecimalDType::new(14, 2), Nullability::NonNullable, 150i32)
        );
    }

    #[test]
    fn test_sum_near_i32_max() {
        // Values close to i32::MAX: the total no longer fits in an i32.
        let near_max = i32::MAX - 1000;
        let array = DecimalArray::new(
            buffer![near_max, 500i32, 400i32],
            DecimalDType::new(10, 2),
            Validity::AllValid,
        );

        let total = sum(array.as_ref()).unwrap();

        // The expected total is computed in i64 because the result precision is widened.
        let expected_sum = i64::from(near_max) + 500 + 400;
        assert_eq!(
            total,
            decimal_scalar(
                DecimalDType::new(20, 2),
                Nullability::NonNullable,
                expected_sum
            )
        );
    }

    #[test]
    fn test_sum_large_i64_values() {
        // Large i64 values whose total is one past i64::MAX, so it needs a wider type.
        let large_val = i64::MAX / 4;
        let array = DecimalArray::new(
            buffer![large_val, large_val, large_val, large_val + 1],
            DecimalDType::new(19, 0),
            Validity::AllValid,
        );

        let total = sum(array.as_ref()).unwrap();

        let expected_sum = i128::from(large_val) * 4 + 1;
        assert_eq!(
            total,
            decimal_scalar(
                DecimalDType::new(29, 0),
                Nullability::NonNullable,
                expected_sum
            )
        );
    }

    #[test]
    fn test_sum_overflow_detection() {
        use vortex_scalar::i256;

        // Three copies of i128::MAX / 2 add up to more than i128::MAX, so this sum can only
        // succeed when it is accumulated in a type wider than i128.
        let max_val = i128::MAX / 2;
        let array = DecimalArray::new(
            buffer![max_val, max_val, max_val],
            DecimalDType::new(38, 0),
            Validity::AllValid,
        );

        let total = sum(array.as_ref()).unwrap();

        // The expected total is therefore built as an i256.
        let expected_sum =
            i256::from_i128(max_val) + i256::from_i128(max_val) + i256::from_i128(max_val);
        assert_eq!(
            total,
            decimal_scalar(
                DecimalDType::new(48, 0),
                Nullability::NonNullable,
                expected_sum
            )
        );
    }

    #[test]
    fn test_sum_mixed_signs_near_overflow() {
        // Mixed signs must cancel correctly near the i64 boundaries.
        let large_pos = i64::MAX / 2;
        let large_neg = -(i64::MAX / 2);
        let array = DecimalArray::new(
            buffer![large_pos, large_neg, large_pos, 1000i64],
            DecimalDType::new(19, 3),
            Validity::AllValid,
        );

        let total = sum(array.as_ref()).unwrap();

        let expected_sum =
            i128::from(large_pos) + i128::from(large_neg) + i128::from(large_pos) + 1000;
        assert_eq!(
            total,
            decimal_scalar(
                DecimalDType::new(29, 3),
                Nullability::NonNullable,
                expected_sum
            )
        );
    }

    #[test]
    fn test_sum_preserves_scale() {
        let array = DecimalArray::new(
            buffer![12345i32, 67890i32, 11111i32],
            DecimalDType::new(6, 4),
            Validity::AllValid,
        );

        let total = sum(array.as_ref()).unwrap();

        // The scale stays at 4 while the precision grows by 10.
        assert_eq!(
            total,
            decimal_scalar(DecimalDType::new(16, 4), Nullability::NonNullable, 91346i32)
        );
    }

    #[test]
    fn test_sum_single_value() {
        let array = DecimalArray::new(buffer![42i32], DecimalDType::new(3, 1), Validity::AllValid);

        let total = sum(array.as_ref()).unwrap();

        assert_eq!(
            total,
            decimal_scalar(DecimalDType::new(13, 1), Nullability::NonNullable, 42i32)
        );
    }

    #[test]
    fn test_sum_with_all_nulls_except_one() {
        // Only the third slot is valid, so the total is just that value.
        let array = DecimalArray::new(
            buffer![100i32, 200i32, 300i32, 400i32],
            DecimalDType::new(4, 2),
            Validity::from_iter([false, false, true, false]),
        );

        let total = sum(array.as_ref()).unwrap();

        assert_eq!(
            total,
            decimal_scalar(DecimalDType::new(14, 2), Nullability::Nullable, 300i32)
        );
    }

    #[test]
    fn test_sum_i128_to_i256_boundary() {
        use vortex_scalar::i256;

        // Nine copies of i128::MAX / 10: the total still fits in an i128, but the widened
        // result precision means the expected value is expressed as an i256.
        let large_i128 = i128::MAX / 10;
        let array = DecimalArray::new(
            buffer![
                large_i128, large_i128, large_i128, large_i128, large_i128, large_i128, large_i128,
                large_i128, large_i128
            ],
            DecimalDType::new(38, 0),
            Validity::AllValid,
        );

        let total = sum(array.as_ref()).unwrap();

        let expected_sum = i256::from_i128(large_i128) * i256::from_i128(9);
        assert_eq!(
            total,
            decimal_scalar(
                DecimalDType::new(48, 0),
                Nullability::NonNullable,
                expected_sum
            )
        );
    }
}