Skip to main content

vortex_array/arrays/decimal/compute/
sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools;
5use num_traits::AsPrimitive;
6use num_traits::CheckedAdd;
7use vortex_buffer::BitBuffer;
8use vortex_buffer::Buffer;
9use vortex_dtype::DType;
10use vortex_dtype::DecimalDType;
11use vortex_dtype::DecimalType;
12use vortex_dtype::Nullability::Nullable;
13use vortex_dtype::match_each_decimal_value_type;
14use vortex_error::VortexExpect;
15use vortex_error::VortexResult;
16use vortex_error::vortex_bail;
17use vortex_mask::Mask;
18
19use crate::arrays::DecimalArray;
20use crate::arrays::DecimalVTable;
21use crate::compute::SumKernel;
22use crate::compute::SumKernelAdapter;
23use crate::expr::stats::Stat;
24use crate::register_kernel;
25use crate::scalar::DecimalValue;
26use crate::scalar::Scalar;
27
28impl SumKernel for DecimalVTable {
29    fn sum(&self, array: &DecimalArray, accumulator: &Scalar) -> VortexResult<Scalar> {
30        let return_dtype = Stat::Sum
31            .dtype(array.dtype())
32            .vortex_expect("sum for decimals exists");
33        let return_decimal_dtype = *return_dtype
34            .as_decimal_opt()
35            .vortex_expect("must be decimal");
36
37        // Extract the initial value as a `DecimalValue`.
38        let initial_decimal = accumulator
39            .as_decimal()
40            .decimal_value()
41            .vortex_expect("cannot be null");
42
43        let mask = array.validity_mask()?;
44        let validity = match &mask {
45            Mask::AllTrue(_) => None,
46            Mask::Values(mask_values) => Some(mask_values.bit_buffer()),
47            Mask::AllFalse(_) => {
48                vortex_bail!("invalid state, all-null array should be checked by top-level sum fn")
49            }
50        };
51
52        let values_type = DecimalType::smallest_decimal_value_type(&return_decimal_dtype);
53        match_each_decimal_value_type!(array.values_type(), |I| {
54            match_each_decimal_value_type!(values_type, |O| {
55                let initial_val: O = initial_decimal
56                    .cast()
57                    .vortex_expect("cannot fail to cast initial value");
58
59                Ok(sum_to_scalar(
60                    array.buffer::<I>(),
61                    validity,
62                    initial_val,
63                    return_decimal_dtype,
64                    &return_dtype,
65                ))
66            })
67        })
68    }
69}
70
71/// Compute the checked sum and convert the result to a [`Scalar`].
72///
73/// Returns a null scalar if the sum overflows the underlying integer type or if the result
74/// exceeds the declared decimal precision.
75fn sum_to_scalar<T, O>(
76    values: Buffer<T>,
77    validity: Option<&BitBuffer>,
78    initial: O,
79    return_decimal_dtype: DecimalDType,
80    return_dtype: &DType,
81) -> Scalar
82where
83    T: AsPrimitive<O>,
84    O: Copy + CheckedAdd + Into<DecimalValue> + 'static,
85{
86    let raw_sum = match validity {
87        Some(v) => sum_decimal_with_validity(values, v, initial),
88        None => sum_decimal(values, initial),
89    };
90
91    raw_sum
92        .map(Into::<DecimalValue>::into)
93        // We have to make sure that the decimal value fits the precision of the decimal dtype.
94        .filter(|v| v.fits_in_precision(return_decimal_dtype))
95        .map(|v| Scalar::decimal(v, return_decimal_dtype, Nullable))
96        // If an overflow occurs during summation, or final value does not fit, then return a null.
97        .unwrap_or_else(|| Scalar::null(return_dtype.clone()))
98}
99
100fn sum_decimal<T: AsPrimitive<I>, I: Copy + CheckedAdd + 'static>(
101    values: Buffer<T>,
102    initial: I,
103) -> Option<I> {
104    let mut sum = initial;
105    for v in values.iter() {
106        let v: I = v.as_();
107        sum = CheckedAdd::checked_add(&sum, &v)?;
108    }
109    Some(sum)
110}
111
112fn sum_decimal_with_validity<T: AsPrimitive<I>, I: Copy + CheckedAdd + 'static>(
113    values: Buffer<T>,
114    validity: &BitBuffer,
115    initial: I,
116) -> Option<I> {
117    let mut sum = initial;
118    for (v, valid) in values.iter().zip_eq(validity) {
119        if valid {
120            let v: I = v.as_();
121            sum = CheckedAdd::checked_add(&sum, &v)?;
122        }
123    }
124    Some(sum)
125}
126
// Register the `SumKernel` implementation above for `DecimalVTable`, wrapped in
// `SumKernelAdapter`, via the crate's `register_kernel!` macro.
register_kernel!(SumKernelAdapter(DecimalVTable).lift());
128
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::Nullability;
    use vortex_dtype::i256;
    use vortex_error::VortexExpect;

    use crate::arrays::DecimalArray;
    use crate::compute::sum;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::scalar::ScalarValue;
    use crate::validity::Validity;

    /// Builds the non-null decimal [`Scalar`] a test expects `sum` to return.
    fn decimal_scalar(
        value: impl Into<DecimalValue>,
        decimal_dtype: DecimalDType,
        nullability: Nullability,
    ) -> Scalar {
        Scalar::try_new(
            DType::Decimal(decimal_dtype, nullability),
            Some(ScalarValue::from(value.into())),
        )
        .unwrap()
    }

    #[test]
    fn test_sum_basic() {
        let decimal = DecimalArray::new(
            buffer![100i32, 200i32, 300i32],
            DecimalDType::new(4, 2),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected = decimal_scalar(
            600i32,
            DecimalDType::new(14, 2),
            Nullability::NonNullable,
        );

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_with_nulls() {
        let decimal = DecimalArray::new(
            buffer![100i32, 200i32, 300i32, 400i32],
            DecimalDType::new(4, 2),
            Validity::from_iter([true, false, true, true]),
        );

        let result = sum(decimal.as_ref()).unwrap();

        // The null slot (200) must not contribute to the sum.
        let expected = decimal_scalar(800i32, DecimalDType::new(14, 2), Nullability::Nullable);

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_negative_values() {
        let decimal = DecimalArray::new(
            buffer![100i32, -200i32, 300i32, -50i32],
            DecimalDType::new(4, 2),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected = decimal_scalar(
            150i32,
            DecimalDType::new(14, 2),
            Nullability::NonNullable,
        );

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_near_i32_max() {
        // Test values close to i32::MAX to ensure proper handling
        let near_max = i32::MAX - 1000;
        let decimal = DecimalArray::new(
            buffer![near_max, 500i32, 400i32],
            DecimalDType::new(10, 2),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        // The total exceeds i32::MAX, so the expected value is computed in i64.
        let expected_sum = i64::from(near_max) + 500 + 400;
        let expected = decimal_scalar(
            expected_sum,
            DecimalDType::new(20, 2),
            Nullability::NonNullable,
        );

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_large_i64_values() {
        // The total of these i64 values exceeds i64::MAX, so it is computed in i128.
        let large_val = i64::MAX / 4;
        let decimal = DecimalArray::new(
            buffer![large_val, large_val, large_val, large_val + 1],
            DecimalDType::new(19, 0),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected_sum = i128::from(large_val) * 4 + 1;
        let expected = decimal_scalar(
            expected_sum,
            DecimalDType::new(29, 0),
            Nullability::NonNullable,
        );

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_overflow_detection() {
        // Three values of i128::MAX / 2 add up to more than i128::MAX, so the sum would
        // overflow an i128 accumulator. The expected result is still non-null: the widened
        // result dtype (precision 48) must be able to hold the full total.
        let max_val = i128::MAX / 2;
        let decimal = DecimalArray::new(
            buffer![max_val, max_val, max_val],
            DecimalDType::new(38, 0),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        // The expected total only fits in i256.
        let expected_sum =
            i256::from_i128(max_val) + i256::from_i128(max_val) + i256::from_i128(max_val);
        let expected = decimal_scalar(
            expected_sum,
            DecimalDType::new(48, 0),
            Nullability::NonNullable,
        );

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_mixed_signs_near_overflow() {
        // Test that mixed signs work correctly near overflow boundaries
        let large_pos = i64::MAX / 2;
        let large_neg = -(i64::MAX / 2);
        let decimal = DecimalArray::new(
            buffer![large_pos, large_neg, large_pos, 1000i64],
            DecimalDType::new(19, 3),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected_sum =
            i128::from(large_pos) + i128::from(large_neg) + i128::from(large_pos) + 1000;
        let expected = decimal_scalar(
            expected_sum,
            DecimalDType::new(29, 3),
            Nullability::NonNullable,
        );

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_preserves_scale() {
        let decimal = DecimalArray::new(
            buffer![12345i32, 67890i32, 11111i32],
            DecimalDType::new(6, 4),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        // Scale should be preserved, precision increased by 10
        let expected = decimal_scalar(
            91346i32,
            DecimalDType::new(16, 4),
            Nullability::NonNullable,
        );

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_single_value() {
        let decimal =
            DecimalArray::new(buffer![42i32], DecimalDType::new(3, 1), Validity::AllValid);

        let result = sum(decimal.as_ref()).unwrap();

        let expected = decimal_scalar(
            42i32,
            DecimalDType::new(13, 1),
            Nullability::NonNullable,
        );

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_with_all_nulls_except_one() {
        let decimal = DecimalArray::new(
            buffer![100i32, 200i32, 300i32, 400i32],
            DecimalDType::new(4, 2),
            Validity::from_iter([false, false, true, false]),
        );

        let result = sum(decimal.as_ref()).unwrap();

        // Only the single valid slot (300) contributes to the sum.
        let expected = decimal_scalar(300i32, DecimalDType::new(14, 2), Nullability::Nullable);

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_i128_to_i256_boundary() {
        // Test the boundary between i128 and i256 accumulation
        let large_i128 = i128::MAX / 10;
        let decimal = DecimalArray::new(
            buffer![
                large_i128, large_i128, large_i128, large_i128, large_i128, large_i128, large_i128,
                large_i128, large_i128
            ],
            DecimalDType::new(38, 0),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        // 9 * (i128::MAX / 10) still fits in i128, but the result dtype is widened to
        // precision 48, so the expected value is expressed as an i256.
        let expected_sum = i256::from_i128(large_i128) * i256::from_i128(9);
        let expected = decimal_scalar(
            expected_sum,
            DecimalDType::new(48, 0),
            Nullability::NonNullable,
        );

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_precision_overflow_without_i256_overflow() {
        // Construct values that individually fit in precision 76 but whose sum exceeds it,
        // while still fitting in `i256`. This ensures we return null for precision overflow
        // and not just for arithmetic overflow.
        let ten_to_38 = i256::from_i128(10i128.pow(38));
        let ten_to_75 = ten_to_38 * i256::from_i128(10i128.pow(37));
        // 6 * 10^75 is a 76-digit number, which fits in precision 76.
        let val = ten_to_75 * i256::from_i128(6);

        let decimal_dtype = DecimalDType::new(76, 0);
        let decimal = DecimalArray::new(buffer![val, val], decimal_dtype, Validity::AllValid);

        // Sum = 12 * 10^75 = 1.2 * 10^76, which exceeds precision 76 but fits in `i256`.
        let result = sum(decimal.as_ref()).unwrap();
        assert_eq!(
            result,
            Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable))
        );
    }

    #[test]
    fn test_i256_overflow() {
        // Adding i256::MAX to itself overflows the widest accumulator, so the sum is null.
        let decimal_dtype = DecimalDType::new(76, 0);
        let decimal = DecimalArray::new(
            buffer![i256::MAX, i256::MAX, i256::MAX],
            decimal_dtype,
            Validity::AllValid,
        );

        assert_eq!(
            sum(decimal.as_ref()).vortex_expect("operation should succeed in test"),
            Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable))
        );
    }
}