1use itertools::Itertools;
5use num_traits::AsPrimitive;
6use num_traits::CheckedAdd;
7use vortex_buffer::BitBuffer;
8use vortex_buffer::Buffer;
9use vortex_dtype::DType;
10use vortex_dtype::DecimalDType;
11use vortex_dtype::DecimalType;
12use vortex_dtype::Nullability::Nullable;
13use vortex_dtype::match_each_decimal_value_type;
14use vortex_error::VortexExpect;
15use vortex_error::VortexResult;
16use vortex_error::vortex_bail;
17use vortex_mask::Mask;
18
19use crate::arrays::DecimalArray;
20use crate::arrays::DecimalVTable;
21use crate::compute::SumKernel;
22use crate::compute::SumKernelAdapter;
23use crate::expr::stats::Stat;
24use crate::register_kernel;
25use crate::scalar::DecimalValue;
26use crate::scalar::Scalar;
27
28impl SumKernel for DecimalVTable {
29 fn sum(&self, array: &DecimalArray, accumulator: &Scalar) -> VortexResult<Scalar> {
30 let return_dtype = Stat::Sum
31 .dtype(array.dtype())
32 .vortex_expect("sum for decimals exists");
33 let return_decimal_dtype = *return_dtype
34 .as_decimal_opt()
35 .vortex_expect("must be decimal");
36
37 let initial_decimal = accumulator
39 .as_decimal()
40 .decimal_value()
41 .vortex_expect("cannot be null");
42
43 let mask = array.validity_mask()?;
44 let validity = match &mask {
45 Mask::AllTrue(_) => None,
46 Mask::Values(mask_values) => Some(mask_values.bit_buffer()),
47 Mask::AllFalse(_) => {
48 vortex_bail!("invalid state, all-null array should be checked by top-level sum fn")
49 }
50 };
51
52 let values_type = DecimalType::smallest_decimal_value_type(&return_decimal_dtype);
53 match_each_decimal_value_type!(array.values_type(), |I| {
54 match_each_decimal_value_type!(values_type, |O| {
55 let initial_val: O = initial_decimal
56 .cast()
57 .vortex_expect("cannot fail to cast initial value");
58
59 Ok(sum_to_scalar(
60 array.buffer::<I>(),
61 validity,
62 initial_val,
63 return_decimal_dtype,
64 &return_dtype,
65 ))
66 })
67 })
68 }
69}
70
71fn sum_to_scalar<T, O>(
76 values: Buffer<T>,
77 validity: Option<&BitBuffer>,
78 initial: O,
79 return_decimal_dtype: DecimalDType,
80 return_dtype: &DType,
81) -> Scalar
82where
83 T: AsPrimitive<O>,
84 O: Copy + CheckedAdd + Into<DecimalValue> + 'static,
85{
86 let raw_sum = match validity {
87 Some(v) => sum_decimal_with_validity(values, v, initial),
88 None => sum_decimal(values, initial),
89 };
90
91 raw_sum
92 .map(Into::<DecimalValue>::into)
93 .filter(|v| v.fits_in_precision(return_decimal_dtype))
95 .map(|v| Scalar::decimal(v, return_decimal_dtype, Nullable))
96 .unwrap_or_else(|| Scalar::null(return_dtype.clone()))
98}
99
100fn sum_decimal<T: AsPrimitive<I>, I: Copy + CheckedAdd + 'static>(
101 values: Buffer<T>,
102 initial: I,
103) -> Option<I> {
104 let mut sum = initial;
105 for v in values.iter() {
106 let v: I = v.as_();
107 sum = CheckedAdd::checked_add(&sum, &v)?;
108 }
109 Some(sum)
110}
111
112fn sum_decimal_with_validity<T: AsPrimitive<I>, I: Copy + CheckedAdd + 'static>(
113 values: Buffer<T>,
114 validity: &BitBuffer,
115 initial: I,
116) -> Option<I> {
117 let mut sum = initial;
118 for (v, valid) in values.iter().zip_eq(validity) {
119 if valid {
120 let v: I = v.as_();
121 sum = CheckedAdd::checked_add(&sum, &v)?;
122 }
123 }
124 Some(sum)
125}
126
127register_kernel!(SumKernelAdapter(DecimalVTable).lift());
128
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::Nullability;
    use vortex_dtype::i256;
    use vortex_error::VortexExpect;

    use crate::arrays::DecimalArray;
    use crate::compute::sum;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::scalar::ScalarValue;
    use crate::validity::Validity;

    #[test]
    fn test_sum_basic() {
        let decimal = DecimalArray::new(
            buffer![100i32, 200i32, 300i32],
            DecimalDType::new(4, 2),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected = Scalar::try_new(
            DType::Decimal(DecimalDType::new(14, 2), Nullability::NonNullable),
            Some(ScalarValue::from(DecimalValue::from(600i32))),
        )
        .unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_with_nulls() {
        let decimal = DecimalArray::new(
            buffer![100i32, 200i32, 300i32, 400i32],
            DecimalDType::new(4, 2),
            Validity::from_iter([true, false, true, true]),
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected = Scalar::try_new(
            DType::Decimal(DecimalDType::new(14, 2), Nullability::Nullable),
            Some(ScalarValue::from(DecimalValue::from(800i32))),
        )
        .unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_negative_values() {
        let decimal = DecimalArray::new(
            buffer![100i32, -200i32, 300i32, -50i32],
            DecimalDType::new(4, 2),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected = Scalar::try_new(
            DType::Decimal(DecimalDType::new(14, 2), Nullability::NonNullable),
            Some(ScalarValue::from(DecimalValue::from(150i32))),
        )
        .unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_near_i32_max() {
        // The i32 inputs would overflow i32, but the widened result type must absorb it.
        let near_max = i32::MAX - 1000;
        let decimal = DecimalArray::new(
            buffer![near_max, 500i32, 400i32],
            DecimalDType::new(10, 2),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected_sum = near_max as i64 + 500 + 400;
        let expected = Scalar::try_new(
            DType::Decimal(DecimalDType::new(20, 2), Nullability::NonNullable),
            Some(ScalarValue::from(DecimalValue::from(expected_sum))),
        )
        .unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_large_i64_values() {
        let large_val = i64::MAX / 4;
        let decimal = DecimalArray::new(
            buffer![large_val, large_val, large_val, large_val + 1],
            DecimalDType::new(19, 0),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected_sum = (large_val as i128) * 4 + 1;
        let expected = Scalar::try_new(
            DType::Decimal(DecimalDType::new(29, 0), Nullability::NonNullable),
            Some(ScalarValue::from(DecimalValue::from(expected_sum))),
        )
        .unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_overflow_detection() {
        // Three halves of i128::MAX overflow i128 but fit comfortably in the i256 result type.
        let max_val = i128::MAX / 2;
        let decimal = DecimalArray::new(
            buffer![max_val, max_val, max_val],
            DecimalDType::new(38, 0),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected_sum =
            i256::from_i128(max_val) + i256::from_i128(max_val) + i256::from_i128(max_val);
        let expected = Scalar::try_new(
            DType::Decimal(DecimalDType::new(48, 0), Nullability::NonNullable),
            Some(ScalarValue::from(DecimalValue::from(expected_sum))),
        )
        .unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_mixed_signs_near_overflow() {
        let large_pos = i64::MAX / 2;
        let large_neg = -(i64::MAX / 2);
        let decimal = DecimalArray::new(
            buffer![large_pos, large_neg, large_pos, 1000i64],
            DecimalDType::new(19, 3),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected_sum = (large_pos as i128) + (large_neg as i128) + (large_pos as i128) + 1000;
        let expected = Scalar::try_new(
            DType::Decimal(DecimalDType::new(29, 3), Nullability::NonNullable),
            Some(ScalarValue::from(DecimalValue::from(expected_sum))),
        )
        .unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_preserves_scale() {
        let decimal = DecimalArray::new(
            buffer![12345i32, 67890i32, 11111i32],
            DecimalDType::new(6, 4),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected = Scalar::try_new(
            DType::Decimal(DecimalDType::new(16, 4), Nullability::NonNullable),
            Some(ScalarValue::from(DecimalValue::from(91346i32))),
        )
        .unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_single_value() {
        let decimal =
            DecimalArray::new(buffer![42i32], DecimalDType::new(3, 1), Validity::AllValid);

        let result = sum(decimal.as_ref()).unwrap();

        let expected = Scalar::try_new(
            DType::Decimal(DecimalDType::new(13, 1), Nullability::NonNullable),
            Some(ScalarValue::from(DecimalValue::from(42i32))),
        )
        .unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_with_all_nulls_except_one() {
        let decimal = DecimalArray::new(
            buffer![100i32, 200i32, 300i32, 400i32],
            DecimalDType::new(4, 2),
            Validity::from_iter([false, false, true, false]),
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected = Scalar::try_new(
            DType::Decimal(DecimalDType::new(14, 2), Nullability::Nullable),
            Some(ScalarValue::from(DecimalValue::from(300i32))),
        )
        .unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_i128_to_i256_boundary() {
        let large_i128 = i128::MAX / 10;
        let decimal = DecimalArray::new(
            buffer![
                large_i128, large_i128, large_i128, large_i128, large_i128, large_i128, large_i128,
                large_i128, large_i128
            ],
            DecimalDType::new(38, 0),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        // Nine copies of the same value: compute the expected total directly in i256.
        let expected_sum = i256::from_i128(large_i128) * i256::from_i128(9);
        let expected = Scalar::try_new(
            DType::Decimal(DecimalDType::new(48, 0), Nullability::NonNullable),
            Some(ScalarValue::from(DecimalValue::from(expected_sum))),
        )
        .unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_precision_overflow_without_i256_overflow() {
        // 6 * 10^75 fits in precision 76, but twice that (1.2 * 10^76) does not, even though it
        // is still far below i256::MAX.
        let ten_to_38 = i256::from_i128(10i128.pow(38));
        let ten_to_75 = ten_to_38 * i256::from_i128(10i128.pow(37));
        let val = ten_to_75 * i256::from_i128(6);

        let decimal_dtype = DecimalDType::new(76, 0);
        let decimal = DecimalArray::new(buffer![val, val], decimal_dtype, Validity::AllValid);

        let result = sum(decimal.as_ref()).unwrap();
        assert_eq!(
            result,
            Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable))
        );
    }

    #[test]
    fn test_sum_nulls_skip_value_that_would_exceed_precision() {
        // Same values as the precision-overflow test, but the second one is null, so only a
        // single 6 * 10^75 contributes and the result must still fit precision 76.
        let ten_to_38 = i256::from_i128(10i128.pow(38));
        let ten_to_75 = ten_to_38 * i256::from_i128(10i128.pow(37));
        let val = ten_to_75 * i256::from_i128(6);

        let decimal_dtype = DecimalDType::new(76, 0);
        let decimal = DecimalArray::new(
            buffer![val, val],
            decimal_dtype,
            Validity::from_iter([true, false]),
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected = Scalar::try_new(
            DType::Decimal(decimal_dtype, Nullability::Nullable),
            Some(ScalarValue::from(DecimalValue::from(val))),
        )
        .unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_i256_overflow() {
        let decimal_dtype = DecimalDType::new(76, 0);
        let decimal = DecimalArray::new(
            buffer![i256::MAX, i256::MAX, i256::MAX],
            decimal_dtype,
            Validity::AllValid,
        );

        assert_eq!(
            sum(decimal.as_ref()).vortex_expect("operation should succeed in test"),
            Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable))
        );
    }

    #[test]
    fn test_i256_overflow_with_nulls() {
        // Exercises the validity-masked summation path: the two valid i256::MAX values alone
        // overflow the i256 accumulator, so the sum must come back as null.
        let decimal_dtype = DecimalDType::new(76, 0);
        let decimal = DecimalArray::new(
            buffer![i256::MAX, i256::MAX, i256::MAX],
            decimal_dtype,
            Validity::from_iter([true, false, true]),
        );

        assert_eq!(
            sum(decimal.as_ref()).vortex_expect("operation should succeed in test"),
            Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable))
        );
    }
}