vortex_array/arrays/decimal/compute/
sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools;
5use num_traits::AsPrimitive;
6use num_traits::CheckedAdd;
7use vortex_buffer::BitBuffer;
8use vortex_buffer::Buffer;
9use vortex_dtype::DecimalType;
10use vortex_dtype::Nullability::Nullable;
11use vortex_dtype::match_each_decimal_value_type;
12use vortex_error::VortexExpect;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_mask::Mask;
16use vortex_scalar::DecimalScalar;
17use vortex_scalar::DecimalValue;
18use vortex_scalar::Scalar;
19
20use crate::arrays::DecimalArray;
21use crate::arrays::DecimalVTable;
22use crate::compute::SumKernel;
23use crate::compute::SumKernelAdapter;
24use crate::expr::stats::Stat;
25use crate::register_kernel;
26
27impl SumKernel for DecimalVTable {
28    #[expect(
29        clippy::cognitive_complexity,
30        reason = "complexity from nested match_each_* macros"
31    )]
32    fn sum(&self, array: &DecimalArray, accumulator: &Scalar) -> VortexResult<Scalar> {
33        let return_dtype = Stat::Sum
34            .dtype(array.dtype())
35            .vortex_expect("sum for decimals exists");
36        let return_decimal_dtype = return_dtype
37            .as_decimal_opt()
38            .vortex_expect("must be decimal");
39
40        // Extract the initial value as a DecimalValue
41        let initial_decimal = DecimalScalar::try_from(accumulator)
42            .vortex_expect("must be a decimal")
43            .decimal_value()
44            .vortex_expect("cannot be null");
45
46        match array.validity_mask() {
47            Mask::AllFalse(_) => {
48                vortex_bail!("invalid state, all-null array should be checked by top-level sum fn")
49            }
50            Mask::AllTrue(_) => {
51                let values_type = DecimalType::smallest_decimal_value_type(return_decimal_dtype);
52                match_each_decimal_value_type!(array.values_type(), |I| {
53                    match_each_decimal_value_type!(values_type, |O| {
54                        let initial_val: O = initial_decimal
55                            .cast()
56                            .vortex_expect("cannot fail to cast initial value");
57                        if let Some(sum) = sum_decimal(array.buffer::<I>(), initial_val) {
58                            Ok(Scalar::decimal(
59                                DecimalValue::from(sum),
60                                *return_decimal_dtype,
61                                Nullable,
62                            ))
63                        } else {
64                            Ok(Scalar::null(return_dtype))
65                        }
66                    })
67                })
68            }
69            Mask::Values(mask_values) => {
70                let values_type = DecimalType::smallest_decimal_value_type(return_decimal_dtype);
71                match_each_decimal_value_type!(array.values_type(), |I| {
72                    match_each_decimal_value_type!(values_type, |O| {
73                        let initial_val: O = initial_decimal
74                            .cast()
75                            .vortex_expect("cannot fail to cast initial value");
76
77                        if let Some(sum) = sum_decimal_with_validity(
78                            array.buffer::<I>(),
79                            mask_values.bit_buffer(),
80                            initial_val,
81                        ) {
82                            Ok(Scalar::decimal(
83                                DecimalValue::from(sum),
84                                *return_decimal_dtype,
85                                Nullable,
86                            ))
87                        } else {
88                            Ok(Scalar::null(return_dtype))
89                        }
90                    })
91                })
92            }
93        }
94    }
95}
96
97fn sum_decimal<T: AsPrimitive<I>, I: Copy + CheckedAdd + 'static>(
98    values: Buffer<T>,
99    initial: I,
100) -> Option<I> {
101    let mut sum = initial;
102    for v in values.iter() {
103        let v: I = v.as_();
104        sum = CheckedAdd::checked_add(&sum, &v)?;
105    }
106    Some(sum)
107}
108
109fn sum_decimal_with_validity<T: AsPrimitive<I>, I: Copy + CheckedAdd + 'static>(
110    values: Buffer<T>,
111    validity: &BitBuffer,
112    initial: I,
113) -> Option<I> {
114    let mut sum = initial;
115    for (v, valid) in values.iter().zip_eq(validity) {
116        if valid {
117            let v: I = v.as_();
118            sum = CheckedAdd::checked_add(&sum, &v)?;
119        }
120    }
121    Some(sum)
122}
123
// Wraps `DecimalVTable` in a `SumKernelAdapter`, lifts it, and hands the result to
// `register_kernel!`. NOTE(review): presumably this is what lets the top-level `sum` compute
// function find the decimal implementation above; confirm against the macro's definition.
124register_kernel!(SumKernelAdapter(DecimalVTable).lift());
125
/// Tests for summing `DecimalArray`s through the top-level `sum` compute function.
///
/// Every expected scalar uses the input scale and the input precision plus ten, except where
/// the input precision is already 76.
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::Nullability;
    use vortex_error::VortexUnwrap;
    use vortex_scalar::DecimalValue;
    use vortex_scalar::Scalar;
    use vortex_scalar::ScalarValue;
    use vortex_scalar::i256;

    use crate::arrays::DecimalArray;
    use crate::compute::sum;
    use crate::validity::Validity;

    /// Three fully valid i32 values sum to their plain arithmetic total.
    #[test]
    fn test_sum_basic() {
        let decimal = DecimalArray::new(
            buffer![100i32, 200i32, 300i32],
            DecimalDType::new(4, 2),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected = Scalar::new(
            DType::Decimal(DecimalDType::new(14, 2), Nullability::NonNullable),
            ScalarValue::from(DecimalValue::from(600i32)),
        );

        assert_eq!(result, expected);
    }

    /// The invalid slot (200) is left out of the total.
    #[test]
    fn test_sum_with_nulls() {
        let decimal = DecimalArray::new(
            buffer![100i32, 200i32, 300i32, 400i32],
            DecimalDType::new(4, 2),
            Validity::from_iter([true, false, true, true]),
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected = Scalar::new(
            DType::Decimal(DecimalDType::new(14, 2), Nullability::Nullable),
            ScalarValue::from(DecimalValue::from(800i32)),
        );

        assert_eq!(result, expected);
    }

    /// Negative values reduce the total.
    #[test]
    fn test_sum_negative_values() {
        let decimal = DecimalArray::new(
            buffer![100i32, -200i32, 300i32, -50i32],
            DecimalDType::new(4, 2),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected = Scalar::new(
            DType::Decimal(DecimalDType::new(14, 2), Nullability::NonNullable),
            ScalarValue::from(DecimalValue::from(150i32)),
        );

        assert_eq!(result, expected);
    }

    /// Values close to `i32::MAX` are still summed exactly.
    #[test]
    fn test_sum_near_i32_max() {
        let near_max = i32::MAX - 1000;
        let decimal = DecimalArray::new(
            buffer![near_max, 500i32, 400i32],
            DecimalDType::new(10, 2),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        // The expected total is computed in i64, the width the wider result precision calls for.
        let expected_sum = near_max as i64 + 500 + 400;
        let expected = Scalar::new(
            DType::Decimal(DecimalDType::new(20, 2), Nullability::NonNullable),
            ScalarValue::from(DecimalValue::from(expected_sum)),
        );

        assert_eq!(result, expected);
    }

    /// Large i64 inputs whose total needs an i128 to represent.
    #[test]
    fn test_sum_large_i64_values() {
        let large_val = i64::MAX / 4;
        let decimal = DecimalArray::new(
            buffer![large_val, large_val, large_val, large_val + 1],
            DecimalDType::new(19, 0),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected_sum = (large_val as i128) * 4 + 1;
        let expected = Scalar::new(
            DType::Decimal(DecimalDType::new(29, 0), Nullability::NonNullable),
            ScalarValue::from(DecimalValue::from(expected_sum)),
        );

        assert_eq!(result, expected);
    }

    /// i128 inputs whose total exceeds `i128::MAX` still produce an exact, non-null result.
    #[test]
    fn test_sum_overflow_detection() {
        // Three copies of `i128::MAX / 2` add up to more than `i128::MAX`, so the expected total
        // has to be computed in i256.
        let max_val = i128::MAX / 2;
        let decimal = DecimalArray::new(
            buffer![max_val, max_val, max_val],
            DecimalDType::new(38, 0),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected_sum =
            i256::from_i128(max_val) + i256::from_i128(max_val) + i256::from_i128(max_val);
        let expected = Scalar::new(
            DType::Decimal(DecimalDType::new(48, 0), Nullability::NonNullable),
            ScalarValue::from(DecimalValue::from(expected_sum)),
        );

        assert_eq!(result, expected);
    }

    /// Large positive and negative i64 values partly cancel out.
    #[test]
    fn test_sum_mixed_signs_near_overflow() {
        let large_pos = i64::MAX / 2;
        let large_neg = -(i64::MAX / 2);
        let decimal = DecimalArray::new(
            buffer![large_pos, large_neg, large_pos, 1000i64],
            DecimalDType::new(19, 3),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected_sum = (large_pos as i128) + (large_neg as i128) + (large_pos as i128) + 1000;
        let expected = Scalar::new(
            DType::Decimal(DecimalDType::new(29, 3), Nullability::NonNullable),
            ScalarValue::from(DecimalValue::from(expected_sum)),
        );

        assert_eq!(result, expected);
    }

    /// The result keeps the input scale while its precision grows by ten.
    #[test]
    fn test_sum_preserves_scale() {
        let decimal = DecimalArray::new(
            buffer![12345i32, 67890i32, 11111i32],
            DecimalDType::new(6, 4),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected = Scalar::new(
            DType::Decimal(DecimalDType::new(16, 4), Nullability::NonNullable),
            ScalarValue::from(DecimalValue::from(91346i32)),
        );

        assert_eq!(result, expected);
    }

    /// A one-element array sums to that element.
    #[test]
    fn test_sum_single_value() {
        let decimal =
            DecimalArray::new(buffer![42i32], DecimalDType::new(3, 1), Validity::AllValid);

        let result = sum(decimal.as_ref()).unwrap();

        let expected = Scalar::new(
            DType::Decimal(DecimalDType::new(13, 1), Nullability::NonNullable),
            ScalarValue::from(DecimalValue::from(42i32)),
        );

        assert_eq!(result, expected);
    }

    /// With a single valid slot, the total is just that slot's value.
    #[test]
    fn test_sum_with_all_nulls_except_one() {
        let decimal = DecimalArray::new(
            buffer![100i32, 200i32, 300i32, 400i32],
            DecimalDType::new(4, 2),
            Validity::from_iter([false, false, true, false]),
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected = Scalar::new(
            DType::Decimal(DecimalDType::new(14, 2), Nullability::Nullable),
            ScalarValue::from(DecimalValue::from(300i32)),
        );

        assert_eq!(result, expected);
    }

    /// i128 inputs whose total would still fit in an i128, summed at the wider result precision.
    #[test]
    fn test_sum_i128_to_i256_boundary() {
        let large_i128 = i128::MAX / 10;
        let decimal = DecimalArray::new(
            buffer![
                large_i128, large_i128, large_i128, large_i128, large_i128, large_i128, large_i128,
                large_i128, large_i128
            ],
            DecimalDType::new(38, 0),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        // 9 * (i128::MAX / 10) fits in an i128, but the result precision of 48 means the expected
        // value is expressed as an i256.
        let expected_sum = i256::from_i128(large_i128) * i256::from_i128(9);
        let expected = Scalar::new(
            DType::Decimal(DecimalDType::new(48, 0), Nullability::NonNullable),
            ScalarValue::from(DecimalValue::from(expected_sum)),
        );

        assert_eq!(result, expected);
    }

    /// Summing values that overflow even i256 yields a null scalar rather than an error.
    #[test]
    fn test_i256_overflow() {
        let decimal_dtype = DecimalDType::new(76, 0);
        let decimal = DecimalArray::new(
            buffer![i256::MAX, i256::MAX, i256::MAX],
            decimal_dtype,
            Validity::AllValid,
        );

        assert_eq!(
            sum(decimal.as_ref()).vortex_unwrap(),
            Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable))
        );
    }
}