Skip to main content

vortex_array/arrays/decimal/compute/
sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools;
5use num_traits::AsPrimitive;
6use num_traits::CheckedAdd;
7use num_traits::NumOps;
8use vortex_buffer::BitBuffer;
9use vortex_buffer::Buffer;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_mask::Mask;
14
15use crate::arrays::DecimalArray;
16use crate::arrays::DecimalVTable;
17use crate::compute::SumKernel;
18use crate::compute::SumKernelAdapter;
19use crate::dtype::DType;
20use crate::dtype::DecimalDType;
21use crate::dtype::DecimalType;
22use crate::dtype::Nullability::Nullable;
23use crate::expr::stats::Stat;
24use crate::match_each_decimal_value_type;
25use crate::register_kernel;
26use crate::scalar::DecimalValue;
27use crate::scalar::Scalar;
28
29impl SumKernel for DecimalVTable {
30    fn sum(&self, array: &DecimalArray, accumulator: &Scalar) -> VortexResult<Scalar> {
31        let return_dtype = Stat::Sum
32            .dtype(array.dtype())
33            .vortex_expect("sum for decimals exists");
34        let return_decimal_dtype = *return_dtype
35            .as_decimal_opt()
36            .vortex_expect("must be decimal");
37
38        // Extract the initial value as a `DecimalValue`.
39        let initial_decimal = accumulator
40            .as_decimal()
41            .decimal_value()
42            .vortex_expect("cannot be null");
43
44        let mask = array.validity_mask()?;
45        let validity = match &mask {
46            Mask::AllTrue(_) => None,
47            Mask::Values(mask_values) => Some(mask_values.bit_buffer()),
48            Mask::AllFalse(_) => {
49                vortex_bail!("invalid state, all-null array should be checked by top-level sum fn")
50            }
51        };
52
53        let values_type = DecimalType::smallest_decimal_value_type(&return_decimal_dtype);
54        match_each_decimal_value_type!(array.values_type(), |I| {
55            match_each_decimal_value_type!(values_type, |O| {
56                let initial_val: O = initial_decimal
57                    .cast()
58                    .vortex_expect("cannot fail to cast initial value");
59
60                Ok(sum_to_scalar(
61                    array.buffer::<I>(),
62                    validity,
63                    initial_val,
64                    return_decimal_dtype,
65                    &return_dtype,
66                ))
67            })
68        })
69    }
70}
71
72/// Compute the checked sum and convert the result to a [`Scalar`].
73///
74/// Returns a null scalar if the sum overflows the underlying integer type or if the result
75/// exceeds the declared decimal precision.
76fn sum_to_scalar<T, O>(
77    values: Buffer<T>,
78    validity: Option<&BitBuffer>,
79    initial: O,
80    return_decimal_dtype: DecimalDType,
81    return_dtype: &DType,
82) -> Scalar
83where
84    T: AsPrimitive<O>,
85    O: CheckedAdd + NumOps + Into<DecimalValue> + Copy + 'static,
86    bool: AsPrimitive<O>,
87{
88    let raw_sum = match validity {
89        Some(v) => sum_decimal_with_validity(values, v, initial),
90        None => sum_decimal(values, initial),
91    };
92
93    raw_sum
94        .map(Into::<DecimalValue>::into)
95        // We have to make sure that the decimal value fits the precision of the decimal dtype.
96        .filter(|v| v.fits_in_precision(return_decimal_dtype))
97        .map(|v| Scalar::decimal(v, return_decimal_dtype, Nullable))
98        // If an overflow occurs during summation, or final value does not fit, then return a null.
99        .unwrap_or_else(|| Scalar::null(return_dtype.clone()))
100}
101
102fn sum_decimal<T: AsPrimitive<I>, I: Copy + CheckedAdd + 'static>(
103    values: Buffer<T>,
104    initial: I,
105) -> Option<I> {
106    let mut sum = initial;
107    for v in values.iter() {
108        let v: I = v.as_();
109        sum = CheckedAdd::checked_add(&sum, &v)?;
110    }
111    Some(sum)
112}
113
114fn sum_decimal_with_validity<T, I>(values: Buffer<T>, validity: &BitBuffer, initial: I) -> Option<I>
115where
116    T: AsPrimitive<I>,
117    I: NumOps + CheckedAdd + Copy + 'static,
118    bool: AsPrimitive<I>,
119{
120    let mut sum = initial;
121    for (v, valid) in values.iter().zip_eq(validity) {
122        let v: I = v.as_() * valid.as_();
123
124        sum = CheckedAdd::checked_add(&sum, &v)?;
125    }
126    Some(sum)
127}
128
// Hand the decimal sum kernel to `register_kernel!`: `DecimalVTable` is wrapped in
// `SumKernelAdapter` and `.lift()`ed first.
// NOTE(review): presumably this is what makes the kernel reachable from the top-level
// `sum` compute function — confirm against the `register_kernel!` definition.
129register_kernel!(SumKernelAdapter(DecimalVTable).lift());
130
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use crate::IntoArray;
    use crate::arrays::DecimalArray;
    use crate::compute::sum;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::i256;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::scalar::ScalarValue;
    use crate::validity::Validity;

    /// Run the top-level `sum` over `decimal`, panicking if it returns an error.
    fn sum_of(decimal: DecimalArray) -> Scalar {
        sum(&decimal.into_array()).unwrap()
    }

    /// Build the non-null decimal scalar a test expects to get back from `sum`.
    fn decimal_scalar(
        value: impl Into<DecimalValue>,
        decimal_dtype: DecimalDType,
        nullability: Nullability,
    ) -> Scalar {
        let value: DecimalValue = value.into();
        Scalar::try_new(
            DType::Decimal(decimal_dtype, nullability),
            Some(ScalarValue::from(value)),
        )
        .unwrap()
    }

    #[test]
    fn test_sum_basic() {
        let input = DecimalArray::new(
            buffer![100i32, 200i32, 300i32],
            DecimalDType::new(4, 2),
            Validity::AllValid,
        );

        assert_eq!(
            sum_of(input),
            decimal_scalar(600i32, DecimalDType::new(14, 2), Nullability::NonNullable)
        );
    }

    #[test]
    fn test_sum_with_nulls() {
        // The second slot (200) is null and must not contribute.
        let input = DecimalArray::new(
            buffer![100i32, 200i32, 300i32, 400i32],
            DecimalDType::new(4, 2),
            Validity::from_iter([true, false, true, true]),
        );

        assert_eq!(
            sum_of(input),
            decimal_scalar(800i32, DecimalDType::new(14, 2), Nullability::Nullable)
        );
    }

    #[test]
    fn test_sum_negative_values() {
        let input = DecimalArray::new(
            buffer![100i32, -200i32, 300i32, -50i32],
            DecimalDType::new(4, 2),
            Validity::AllValid,
        );

        assert_eq!(
            sum_of(input),
            decimal_scalar(150i32, DecimalDType::new(14, 2), Nullability::NonNullable)
        );
    }

    #[test]
    fn test_sum_near_i32_max() {
        // Inputs sit just below `i32::MAX`, so the total no longer fits in an `i32`.
        let near_max = i32::MAX - 1000;
        let input = DecimalArray::new(
            buffer![near_max, 500i32, 400i32],
            DecimalDType::new(10, 2),
            Validity::AllValid,
        );

        // The expected total is computed in `i64`, matching the widened result precision.
        let expected_sum = near_max as i64 + 500 + 400;

        assert_eq!(
            sum_of(input),
            decimal_scalar(
                expected_sum,
                DecimalDType::new(20, 2),
                Nullability::NonNullable
            )
        );
    }

    #[test]
    fn test_sum_large_i64_values() {
        // Four values of roughly `i64::MAX / 4` add up to more than `i64::MAX`.
        let large_val = i64::MAX / 4;
        let input = DecimalArray::new(
            buffer![large_val, large_val, large_val, large_val + 1],
            DecimalDType::new(19, 0),
            Validity::AllValid,
        );

        let expected_sum = (large_val as i128) * 4 + 1;

        assert_eq!(
            sum_of(input),
            decimal_scalar(
                expected_sum,
                DecimalDType::new(29, 0),
                Nullability::NonNullable
            )
        );
    }

    #[test]
    fn test_sum_overflow_detection() {
        // Three values of `i128::MAX / 2` exceed `i128::MAX` when added together, so the
        // expected total is expressed as an `i256`.
        let max_val = i128::MAX / 2;
        let input = DecimalArray::new(
            buffer![max_val, max_val, max_val],
            DecimalDType::new(38, 0),
            Validity::AllValid,
        );

        let expected_sum =
            i256::from_i128(max_val) + i256::from_i128(max_val) + i256::from_i128(max_val);

        assert_eq!(
            sum_of(input),
            decimal_scalar(
                expected_sum,
                DecimalDType::new(48, 0),
                Nullability::NonNullable
            )
        );
    }

    #[test]
    fn test_sum_mixed_signs_near_overflow() {
        // Large positive and negative values partly cancel near the `i64` boundary.
        let large_pos = i64::MAX / 2;
        let large_neg = -(i64::MAX / 2);
        let input = DecimalArray::new(
            buffer![large_pos, large_neg, large_pos, 1000i64],
            DecimalDType::new(19, 3),
            Validity::AllValid,
        );

        let expected_sum = (large_pos as i128) + (large_neg as i128) + (large_pos as i128) + 1000;

        assert_eq!(
            sum_of(input),
            decimal_scalar(
                expected_sum,
                DecimalDType::new(29, 3),
                Nullability::NonNullable
            )
        );
    }

    #[test]
    fn test_sum_preserves_scale() {
        let input = DecimalArray::new(
            buffer![12345i32, 67890i32, 11111i32],
            DecimalDType::new(6, 4),
            Validity::AllValid,
        );

        // Same scale as the input; precision is ten digits wider.
        assert_eq!(
            sum_of(input),
            decimal_scalar(91346i32, DecimalDType::new(16, 4), Nullability::NonNullable)
        );
    }

    #[test]
    fn test_sum_single_value() {
        let input = DecimalArray::new(buffer![42i32], DecimalDType::new(3, 1), Validity::AllValid);

        assert_eq!(
            sum_of(input),
            decimal_scalar(42i32, DecimalDType::new(13, 1), Nullability::NonNullable)
        );
    }

    #[test]
    fn test_sum_with_all_nulls_except_one() {
        // Only the third slot (300) is valid.
        let input = DecimalArray::new(
            buffer![100i32, 200i32, 300i32, 400i32],
            DecimalDType::new(4, 2),
            Validity::from_iter([false, false, true, false]),
        );

        assert_eq!(
            sum_of(input),
            decimal_scalar(300i32, DecimalDType::new(14, 2), Nullability::Nullable)
        );
    }

    #[test]
    fn test_sum_i128_to_i256_boundary() {
        // 9 * (i128::MAX / 10) still fits in an `i128`, but the result precision (48) is
        // beyond what an `i128` can represent, so the expected value is built as an `i256`.
        let large_i128 = i128::MAX / 10;
        let input = DecimalArray::new(
            buffer![
                large_i128, large_i128, large_i128, large_i128, large_i128, large_i128, large_i128,
                large_i128, large_i128
            ],
            DecimalDType::new(38, 0),
            Validity::AllValid,
        );

        let expected_sum = i256::from_i128(large_i128) * i256::from_i128(9);

        assert_eq!(
            sum_of(input),
            decimal_scalar(
                expected_sum,
                DecimalDType::new(48, 0),
                Nullability::NonNullable
            )
        );
    }

    #[test]
    fn test_sum_precision_overflow_without_i256_overflow() {
        // Each value fits precision 76 on its own, but the total does not, even though it
        // still fits in an `i256`. The result must be null because of the precision check,
        // not because of arithmetic overflow.
        let ten_to_38 = i256::from_i128(10i128.pow(38));
        let ten_to_75 = ten_to_38 * i256::from_i128(10i128.pow(37));
        // 6 * 10^75 has 76 digits, so it fits precision 76.
        let val = ten_to_75 * i256::from_i128(6);

        let decimal_dtype = DecimalDType::new(76, 0);
        let input = DecimalArray::new(buffer![val, val], decimal_dtype, Validity::AllValid);

        // 12 * 10^75 = 1.2 * 10^76 needs 77 digits.
        assert_eq!(
            sum_of(input),
            Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable))
        );
    }

    #[test]
    fn test_i256_overflow() {
        // Adding `i256::MAX` to itself overflows the accumulator; the sum must be null.
        let decimal_dtype = DecimalDType::new(76, 0);
        let input = DecimalArray::new(
            buffer![i256::MAX, i256::MAX, i256::MAX],
            decimal_dtype,
            Validity::AllValid,
        );

        assert_eq!(
            sum(&input.into_array()).vortex_expect("operation should succeed in test"),
            Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable))
        );
    }
}