vortex_array/arrays/decimal/compute/
sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_schema::DECIMAL256_MAX_PRECISION;
5use num_traits::AsPrimitive;
6use vortex_dtype::DecimalDType;
7use vortex_error::{VortexResult, vortex_bail};
8use vortex_mask::Mask;
9use vortex_scalar::{DecimalValue, Scalar, match_each_decimal_value_type};
10
11use crate::arrays::{DecimalArray, DecimalVTable, smallest_decimal_value_type};
12use crate::compute::{SumKernel, SumKernelAdapter};
13use crate::register_kernel;
14
15// Its safe to use `AsPrimitive` here because we always cast up.
16macro_rules! sum_decimal {
17    ($ty:ty, $values:expr) => {{
18        let mut sum: $ty = <$ty>::default();
19        for v in $values.iter() {
20            let v: $ty = (*v).as_();
21            sum += v;
22        }
23        sum
24    }};
25    ($ty:ty, $values:expr, $validity:expr) => {{
26        use itertools::Itertools;
27
28        let mut sum: $ty = <$ty>::default();
29        for (v, valid) in $values.iter().zip_eq($validity.iter()) {
30            if valid {
31                let v: $ty = (*v).as_();
32                sum += v;
33            }
34        }
35        sum
36    }};
37}
38
39impl SumKernel for DecimalVTable {
40    #[allow(clippy::cognitive_complexity)]
41    fn sum(&self, array: &DecimalArray) -> VortexResult<Scalar> {
42        let decimal_dtype = array.decimal_dtype();
43        let nullability = array.dtype().nullability();
44
45        // Both Spark and DataFusion use this heuristic.
46        // - https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
47        // - https://github.com/apache/datafusion/blob/4153adf2c0f6e317ef476febfdc834208bd46622/datafusion/functions-aggregate/src/sum.rs#L188
48        let new_precision = u8::min(DECIMAL256_MAX_PRECISION, decimal_dtype.precision() + 10);
49        let new_scale = decimal_dtype.scale();
50        let return_dtype = DecimalDType::new(new_precision, new_scale);
51
52        match array.validity_mask() {
53            Mask::AllFalse(_) => {
54                vortex_bail!("invalid state, all-null array should be checked by top-level sum fn")
55            }
56            Mask::AllTrue(_) => {
57                let values_type = smallest_decimal_value_type(&return_dtype);
58                match_each_decimal_value_type!(array.values_type(), |I| {
59                    match_each_decimal_value_type!(values_type, |O| {
60                        Ok(Scalar::decimal(
61                            DecimalValue::from(sum_decimal!(O, array.buffer::<I>())),
62                            return_dtype,
63                            nullability,
64                        ))
65                    })
66                })
67            }
68            Mask::Values(mask_values) => {
69                let values_type = smallest_decimal_value_type(&return_dtype);
70                match_each_decimal_value_type!(array.values_type(), |I| {
71                    match_each_decimal_value_type!(values_type, |O| {
72                        Ok(Scalar::decimal(
73                            DecimalValue::from(sum_decimal!(
74                                O,
75                                array.buffer::<I>(),
76                                mask_values.boolean_buffer()
77                            )),
78                            return_dtype,
79                            nullability,
80                        ))
81                    })
82                })
83            }
84        }
85    }
86}
87
// Hand the `SumKernel` implementation for `DecimalVTable`, wrapped in `SumKernelAdapter` and
// lifted, to the crate's `register_kernel!` macro.
register_kernel!(SumKernelAdapter(DecimalVTable).lift());
89
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, DecimalDType, Nullability};
    use vortex_scalar::{DecimalValue, Scalar, ScalarValue, i256};

    use crate::arrays::DecimalArray;
    use crate::compute::sum;
    use crate::validity::Validity;

    /// Builds the decimal scalar that a `sum` call is expected to return.
    fn expected_scalar(
        dtype: DecimalDType,
        nullability: Nullability,
        value: impl Into<DecimalValue>,
    ) -> Scalar {
        Scalar::new(
            DType::Decimal(dtype, nullability),
            ScalarValue::from(value.into()),
        )
    }

    #[test]
    fn test_sum_basic() {
        let array = DecimalArray::new(
            buffer![100i32, 200i32, 300i32],
            DecimalDType::new(4, 2),
            Validity::AllValid,
        );

        let actual = sum(array.as_ref()).unwrap();

        assert_eq!(
            actual,
            expected_scalar(DecimalDType::new(14, 2), Nullability::NonNullable, 600i32)
        );
    }

    #[test]
    fn test_sum_with_nulls() {
        // The second element is null and must not contribute to the total.
        let array = DecimalArray::new(
            buffer![100i32, 200i32, 300i32, 400i32],
            DecimalDType::new(4, 2),
            Validity::from_iter([true, false, true, true]),
        );

        let actual = sum(array.as_ref()).unwrap();

        assert_eq!(
            actual,
            expected_scalar(DecimalDType::new(14, 2), Nullability::Nullable, 800i32)
        );
    }

    #[test]
    fn test_sum_negative_values() {
        let array = DecimalArray::new(
            buffer![100i32, -200i32, 300i32, -50i32],
            DecimalDType::new(4, 2),
            Validity::AllValid,
        );

        let actual = sum(array.as_ref()).unwrap();

        assert_eq!(
            actual,
            expected_scalar(DecimalDType::new(14, 2), Nullability::NonNullable, 150i32)
        );
    }

    #[test]
    fn test_sum_near_i32_max() {
        // One value sits just below i32::MAX; the expected total is computed in i64.
        let near_max = i32::MAX - 1000;
        let array = DecimalArray::new(
            buffer![near_max, 500i32, 400i32],
            DecimalDType::new(10, 2),
            Validity::AllValid,
        );

        let actual = sum(array.as_ref()).unwrap();

        let total = near_max as i64 + 500 + 400;
        assert_eq!(
            actual,
            expected_scalar(DecimalDType::new(20, 2), Nullability::NonNullable, total)
        );
    }

    #[test]
    fn test_sum_large_i64_values() {
        // Four values of roughly i64::MAX / 4 each; the expected total is computed in i128.
        let large_val = i64::MAX / 4;
        let array = DecimalArray::new(
            buffer![large_val, large_val, large_val, large_val + 1],
            DecimalDType::new(19, 0),
            Validity::AllValid,
        );

        let actual = sum(array.as_ref()).unwrap();

        let total = (large_val as i128) * 4 + 1;
        assert_eq!(
            actual,
            expected_scalar(DecimalDType::new(29, 0), Nullability::NonNullable, total)
        );
    }

    #[test]
    fn test_sum_overflow_detection() {
        // Three values of i128::MAX / 2 add up to more than i128::MAX, so the expected total
        // has to be expressed as an i256.
        let max_val = i128::MAX / 2;
        let array = DecimalArray::new(
            buffer![max_val, max_val, max_val],
            DecimalDType::new(38, 0),
            Validity::AllValid,
        );

        let actual = sum(array.as_ref()).unwrap();

        let wide = i256::from_i128(max_val);
        let total = wide + wide + wide;
        assert_eq!(
            actual,
            expected_scalar(DecimalDType::new(48, 0), Nullability::NonNullable, total)
        );
    }

    #[test]
    fn test_sum_mixed_signs_near_overflow() {
        // Large positive and negative values that mostly cancel each other out.
        let large_pos = i64::MAX / 2;
        let large_neg = -(i64::MAX / 2);
        let array = DecimalArray::new(
            buffer![large_pos, large_neg, large_pos, 1000i64],
            DecimalDType::new(19, 3),
            Validity::AllValid,
        );

        let actual = sum(array.as_ref()).unwrap();

        let total = (large_pos as i128) + (large_neg as i128) + (large_pos as i128) + 1000;
        assert_eq!(
            actual,
            expected_scalar(DecimalDType::new(29, 3), Nullability::NonNullable, total)
        );
    }

    #[test]
    fn test_sum_preserves_scale() {
        let array = DecimalArray::new(
            buffer![12345i32, 67890i32, 11111i32],
            DecimalDType::new(6, 4),
            Validity::AllValid,
        );

        let actual = sum(array.as_ref()).unwrap();

        // The scale stays at 4 while the precision grows by 10.
        assert_eq!(
            actual,
            expected_scalar(DecimalDType::new(16, 4), Nullability::NonNullable, 91346i32)
        );
    }

    #[test]
    fn test_sum_single_value() {
        let array =
            DecimalArray::new(buffer![42i32], DecimalDType::new(3, 1), Validity::AllValid);

        let actual = sum(array.as_ref()).unwrap();

        assert_eq!(
            actual,
            expected_scalar(DecimalDType::new(13, 1), Nullability::NonNullable, 42i32)
        );
    }

    #[test]
    fn test_sum_with_all_nulls_except_one() {
        // Only the third element is valid, so it alone makes up the total.
        let array = DecimalArray::new(
            buffer![100i32, 200i32, 300i32, 400i32],
            DecimalDType::new(4, 2),
            Validity::from_iter([false, false, true, false]),
        );

        let actual = sum(array.as_ref()).unwrap();

        assert_eq!(
            actual,
            expected_scalar(DecimalDType::new(14, 2), Nullability::Nullable, 300i32)
        );
    }

    #[test]
    fn test_sum_i128_to_i256_boundary() {
        // Nine values of i128::MAX / 10: the total still fits in an i128, but the expected
        // value is built as an i256 to go with the widened result precision (48).
        let large_i128 = i128::MAX / 10;
        let array = DecimalArray::new(
            buffer![
                large_i128, large_i128, large_i128, large_i128, large_i128, large_i128, large_i128,
                large_i128, large_i128
            ],
            DecimalDType::new(38, 0),
            Validity::AllValid,
        );

        let actual = sum(array.as_ref()).unwrap();

        let total = i256::from_i128(large_i128) * i256::from_i128(9);
        assert_eq!(
            actual,
            expected_scalar(DecimalDType::new(48, 0), Nullability::NonNullable, total)
        );
    }
}
323}