1use itertools::Itertools;
5use num_traits::AsPrimitive;
6use num_traits::CheckedAdd;
7use num_traits::NumOps;
8use vortex_buffer::BitBuffer;
9use vortex_buffer::Buffer;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_mask::Mask;
14
15use crate::arrays::DecimalArray;
16use crate::arrays::DecimalVTable;
17use crate::compute::SumKernel;
18use crate::compute::SumKernelAdapter;
19use crate::dtype::DType;
20use crate::dtype::DecimalDType;
21use crate::dtype::DecimalType;
22use crate::dtype::Nullability::Nullable;
23use crate::expr::stats::Stat;
24use crate::match_each_decimal_value_type;
25use crate::register_kernel;
26use crate::scalar::DecimalValue;
27use crate::scalar::Scalar;
28
29impl SumKernel for DecimalVTable {
30 fn sum(&self, array: &DecimalArray, accumulator: &Scalar) -> VortexResult<Scalar> {
31 let return_dtype = Stat::Sum
32 .dtype(array.dtype())
33 .vortex_expect("sum for decimals exists");
34 let return_decimal_dtype = *return_dtype
35 .as_decimal_opt()
36 .vortex_expect("must be decimal");
37
38 let initial_decimal = accumulator
40 .as_decimal()
41 .decimal_value()
42 .vortex_expect("cannot be null");
43
44 let mask = array.validity_mask()?;
45 let validity = match &mask {
46 Mask::AllTrue(_) => None,
47 Mask::Values(mask_values) => Some(mask_values.bit_buffer()),
48 Mask::AllFalse(_) => {
49 vortex_bail!("invalid state, all-null array should be checked by top-level sum fn")
50 }
51 };
52
53 let values_type = DecimalType::smallest_decimal_value_type(&return_decimal_dtype);
54 match_each_decimal_value_type!(array.values_type(), |I| {
55 match_each_decimal_value_type!(values_type, |O| {
56 let initial_val: O = initial_decimal
57 .cast()
58 .vortex_expect("cannot fail to cast initial value");
59
60 Ok(sum_to_scalar(
61 array.buffer::<I>(),
62 validity,
63 initial_val,
64 return_decimal_dtype,
65 &return_dtype,
66 ))
67 })
68 })
69 }
70}
71
72fn sum_to_scalar<T, O>(
77 values: Buffer<T>,
78 validity: Option<&BitBuffer>,
79 initial: O,
80 return_decimal_dtype: DecimalDType,
81 return_dtype: &DType,
82) -> Scalar
83where
84 T: AsPrimitive<O>,
85 O: CheckedAdd + NumOps + Into<DecimalValue> + Copy + 'static,
86 bool: AsPrimitive<O>,
87{
88 let raw_sum = match validity {
89 Some(v) => sum_decimal_with_validity(values, v, initial),
90 None => sum_decimal(values, initial),
91 };
92
93 raw_sum
94 .map(Into::<DecimalValue>::into)
95 .filter(|v| v.fits_in_precision(return_decimal_dtype))
97 .map(|v| Scalar::decimal(v, return_decimal_dtype, Nullable))
98 .unwrap_or_else(|| Scalar::null(return_dtype.clone()))
100}
101
102fn sum_decimal<T: AsPrimitive<I>, I: Copy + CheckedAdd + 'static>(
103 values: Buffer<T>,
104 initial: I,
105) -> Option<I> {
106 let mut sum = initial;
107 for v in values.iter() {
108 let v: I = v.as_();
109 sum = CheckedAdd::checked_add(&sum, &v)?;
110 }
111 Some(sum)
112}
113
114fn sum_decimal_with_validity<T, I>(values: Buffer<T>, validity: &BitBuffer, initial: I) -> Option<I>
115where
116 T: AsPrimitive<I>,
117 I: NumOps + CheckedAdd + Copy + 'static,
118 bool: AsPrimitive<I>,
119{
120 let mut sum = initial;
121 for (v, valid) in values.iter().zip_eq(validity) {
122 let v: I = v.as_() * valid.as_();
123
124 sum = CheckedAdd::checked_add(&sum, &v)?;
125 }
126 Some(sum)
127}
128
// Registers the `SumKernel` impl for `DecimalVTable` above (wrapped in `SumKernelAdapter`)
// with the kernel registry — presumably so the generic `sum` compute function can
// dispatch decimal arrays to it; confirm against `register_kernel!`.
register_kernel!(SumKernelAdapter(DecimalVTable).lift());
130
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use crate::IntoArray;
    use crate::arrays::DecimalArray;
    use crate::compute::sum;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::i256;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::scalar::ScalarValue;
    use crate::validity::Validity;

    // Across these tests the sum dtype keeps the input scale and widens the precision by
    // 10 digits, capped at 76 (see the precision-76 tests at the end of this module).

    #[test]
    fn test_sum_basic() {
        let decimal = DecimalArray::new(
            buffer![100i32, 200i32, 300i32],
            DecimalDType::new(4, 2),
            Validity::AllValid,
        );

        let result = sum(&decimal.into_array()).unwrap();

        let expected = Scalar::try_new(
            DType::Decimal(DecimalDType::new(14, 2), Nullability::NonNullable),
            Some(ScalarValue::from(DecimalValue::from(600i32))),
        )
        .unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_with_nulls() {
        // The null 200 must not contribute: 100 + 300 + 400 = 800.
        let decimal = DecimalArray::new(
            buffer![100i32, 200i32, 300i32, 400i32],
            DecimalDType::new(4, 2),
            Validity::from_iter([true, false, true, true]),
        );

        let result = sum(&decimal.into_array()).unwrap();

        let expected = Scalar::try_new(
            DType::Decimal(DecimalDType::new(14, 2), Nullability::Nullable),
            Some(ScalarValue::from(DecimalValue::from(800i32))),
        )
        .unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_negative_values() {
        let decimal = DecimalArray::new(
            buffer![100i32, -200i32, 300i32, -50i32],
            DecimalDType::new(4, 2),
            Validity::AllValid,
        );

        let result = sum(&decimal.into_array()).unwrap();

        let expected = Scalar::try_new(
            DType::Decimal(DecimalDType::new(14, 2), Nullability::NonNullable),
            Some(ScalarValue::from(DecimalValue::from(150i32))),
        )
        .unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_near_i32_max() {
        // The exact sum exceeds i32::MAX, so it must be accumulated in a wider type.
        let near_max = i32::MAX - 1000;
        let decimal = DecimalArray::new(
            buffer![near_max, 500i32, 400i32],
            DecimalDType::new(10, 2),
            Validity::AllValid,
        );

        let result = sum(&decimal.into_array()).unwrap();

        let expected_sum = near_max as i64 + 500 + 400;
        let expected = Scalar::try_new(
            DType::Decimal(DecimalDType::new(20, 2), Nullability::NonNullable),
            Some(ScalarValue::from(DecimalValue::from(expected_sum))),
        )
        .unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_large_i64_values() {
        // Four values of roughly i64::MAX / 4 add up to a total just below i64::MAX.
        let large_val = i64::MAX / 4;
        let decimal = DecimalArray::new(
            buffer![large_val, large_val, large_val, large_val + 1],
            DecimalDType::new(19, 0),
            Validity::AllValid,
        );

        let result = sum(&decimal.into_array()).unwrap();

        let expected_sum = (large_val as i128) * 4 + 1;
        let expected = Scalar::try_new(
            DType::Decimal(DecimalDType::new(29, 0), Nullability::NonNullable),
            Some(ScalarValue::from(DecimalValue::from(expected_sum))),
        )
        .unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_overflow_detection() {
        // Three copies of i128::MAX / 2 overflow i128. The precision-48 sum dtype cannot be
        // held in i128, so the total is accumulated in i256 and stays exact.
        let max_val = i128::MAX / 2;
        let decimal = DecimalArray::new(
            buffer![max_val, max_val, max_val],
            DecimalDType::new(38, 0),
            Validity::AllValid,
        );

        let result = sum(&decimal.into_array()).unwrap();

        let expected_sum =
            i256::from_i128(max_val) + i256::from_i128(max_val) + i256::from_i128(max_val);
        let expected = Scalar::try_new(
            DType::Decimal(DecimalDType::new(48, 0), Nullability::NonNullable),
            Some(ScalarValue::from(DecimalValue::from(expected_sum))),
        )
        .unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_mixed_signs_near_overflow() {
        let large_pos = i64::MAX / 2;
        let large_neg = -(i64::MAX / 2);
        let decimal = DecimalArray::new(
            buffer![large_pos, large_neg, large_pos, 1000i64],
            DecimalDType::new(19, 3),
            Validity::AllValid,
        );

        let result = sum(&decimal.into_array()).unwrap();

        let expected_sum = (large_pos as i128) + (large_neg as i128) + (large_pos as i128) + 1000;
        let expected = Scalar::try_new(
            DType::Decimal(DecimalDType::new(29, 3), Nullability::NonNullable),
            Some(ScalarValue::from(DecimalValue::from(expected_sum))),
        )
        .unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_preserves_scale() {
        let decimal = DecimalArray::new(
            buffer![12345i32, 67890i32, 11111i32],
            DecimalDType::new(6, 4),
            Validity::AllValid,
        );

        let result = sum(&decimal.into_array()).unwrap();

        // Scale 4 is carried through unchanged; only the precision widens.
        let expected = Scalar::try_new(
            DType::Decimal(DecimalDType::new(16, 4), Nullability::NonNullable),
            Some(ScalarValue::from(DecimalValue::from(91346i32))),
        )
        .unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_single_value() {
        let decimal =
            DecimalArray::new(buffer![42i32], DecimalDType::new(3, 1), Validity::AllValid);

        let result = sum(&decimal.into_array()).unwrap();

        let expected = Scalar::try_new(
            DType::Decimal(DecimalDType::new(13, 1), Nullability::NonNullable),
            Some(ScalarValue::from(DecimalValue::from(42i32))),
        )
        .unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_with_all_nulls_except_one() {
        let decimal = DecimalArray::new(
            buffer![100i32, 200i32, 300i32, 400i32],
            DecimalDType::new(4, 2),
            Validity::from_iter([false, false, true, false]),
        );

        let result = sum(&decimal.into_array()).unwrap();

        let expected = Scalar::try_new(
            DType::Decimal(DecimalDType::new(14, 2), Nullability::Nullable),
            Some(ScalarValue::from(DecimalValue::from(300i32))),
        )
        .unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_i128_to_i256_boundary() {
        // Twelve copies of i128::MAX / 10 add up to about 1.2 * i128::MAX. The exact total
        // therefore only exists in the i256 accumulator, which exercises the boundary.
        let large_i128 = i128::MAX / 10;
        let decimal = DecimalArray::new(
            buffer![
                large_i128, large_i128, large_i128, large_i128, large_i128, large_i128,
                large_i128, large_i128, large_i128, large_i128, large_i128, large_i128
            ],
            DecimalDType::new(38, 0),
            Validity::AllValid,
        );

        let result = sum(&decimal.into_array()).unwrap();

        let expected_sum = i256::from_i128(large_i128) * i256::from_i128(12);
        let expected = Scalar::try_new(
            DType::Decimal(DecimalDType::new(48, 0), Nullability::NonNullable),
            Some(ScalarValue::from(DecimalValue::from(expected_sum))),
        )
        .unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_precision_overflow_without_i256_overflow() {
        // 6 * 10^75 fits in precision 76, but the doubled total (1.2 * 10^76) does not,
        // even though it is still far below i256::MAX. Precision is capped at 76, so the
        // sum must come back null rather than as an out-of-precision value.
        let ten_to_38 = i256::from_i128(10i128.pow(38));
        let ten_to_75 = ten_to_38 * i256::from_i128(10i128.pow(37));
        let val = ten_to_75 * i256::from_i128(6);

        let decimal_dtype = DecimalDType::new(76, 0);
        let decimal = DecimalArray::new(buffer![val, val], decimal_dtype, Validity::AllValid);

        let result = sum(&decimal.into_array()).unwrap();
        assert_eq!(
            result,
            Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable))
        );
    }

    #[test]
    fn test_i256_overflow() {
        // Adding i256::MAX to itself overflows the i256 accumulator; the sum becomes null.
        let decimal_dtype = DecimalDType::new(76, 0);
        let decimal = DecimalArray::new(
            buffer![i256::MAX, i256::MAX, i256::MAX],
            decimal_dtype,
            Validity::AllValid,
        );

        assert_eq!(
            sum(&decimal.into_array()).vortex_expect("operation should succeed in test"),
            Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable))
        );
    }
}