1use itertools::Itertools;
5use num_traits::AsPrimitive;
6use num_traits::CheckedAdd;
7use vortex_buffer::BitBuffer;
8use vortex_buffer::Buffer;
9use vortex_dtype::DecimalType;
10use vortex_dtype::Nullability::Nullable;
11use vortex_dtype::match_each_decimal_value_type;
12use vortex_error::VortexExpect;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_mask::Mask;
16use vortex_scalar::DecimalScalar;
17use vortex_scalar::DecimalValue;
18use vortex_scalar::Scalar;
19
20use crate::arrays::DecimalArray;
21use crate::arrays::DecimalVTable;
22use crate::compute::SumKernel;
23use crate::compute::SumKernelAdapter;
24use crate::expr::stats::Stat;
25use crate::register_kernel;
26
27impl SumKernel for DecimalVTable {
28 #[expect(
29 clippy::cognitive_complexity,
30 reason = "complexity from nested match_each_* macros"
31 )]
32 fn sum(&self, array: &DecimalArray, accumulator: &Scalar) -> VortexResult<Scalar> {
33 let return_dtype = Stat::Sum
34 .dtype(array.dtype())
35 .vortex_expect("sum for decimals exists");
36 let return_decimal_dtype = return_dtype
37 .as_decimal_opt()
38 .vortex_expect("must be decimal");
39
40 let initial_decimal = DecimalScalar::try_from(accumulator)
42 .vortex_expect("must be a decimal")
43 .decimal_value()
44 .vortex_expect("cannot be null");
45
46 match array.validity_mask() {
47 Mask::AllFalse(_) => {
48 vortex_bail!("invalid state, all-null array should be checked by top-level sum fn")
49 }
50 Mask::AllTrue(_) => {
51 let values_type = DecimalType::smallest_decimal_value_type(return_decimal_dtype);
52 match_each_decimal_value_type!(array.values_type(), |I| {
53 match_each_decimal_value_type!(values_type, |O| {
54 let initial_val: O = initial_decimal
55 .cast()
56 .vortex_expect("cannot fail to cast initial value");
57 if let Some(sum) = sum_decimal(array.buffer::<I>(), initial_val) {
58 Ok(Scalar::decimal(
59 DecimalValue::from(sum),
60 *return_decimal_dtype,
61 Nullable,
62 ))
63 } else {
64 Ok(Scalar::null(return_dtype))
65 }
66 })
67 })
68 }
69 Mask::Values(mask_values) => {
70 let values_type = DecimalType::smallest_decimal_value_type(return_decimal_dtype);
71 match_each_decimal_value_type!(array.values_type(), |I| {
72 match_each_decimal_value_type!(values_type, |O| {
73 let initial_val: O = initial_decimal
74 .cast()
75 .vortex_expect("cannot fail to cast initial value");
76
77 if let Some(sum) = sum_decimal_with_validity(
78 array.buffer::<I>(),
79 mask_values.bit_buffer(),
80 initial_val,
81 ) {
82 Ok(Scalar::decimal(
83 DecimalValue::from(sum),
84 *return_decimal_dtype,
85 Nullable,
86 ))
87 } else {
88 Ok(Scalar::null(return_dtype))
89 }
90 })
91 })
92 }
93 }
94 }
95}
96
97fn sum_decimal<T: AsPrimitive<I>, I: Copy + CheckedAdd + 'static>(
98 values: Buffer<T>,
99 initial: I,
100) -> Option<I> {
101 let mut sum = initial;
102 for v in values.iter() {
103 let v: I = v.as_();
104 sum = CheckedAdd::checked_add(&sum, &v)?;
105 }
106 Some(sum)
107}
108
109fn sum_decimal_with_validity<T: AsPrimitive<I>, I: Copy + CheckedAdd + 'static>(
110 values: Buffer<T>,
111 validity: &BitBuffer,
112 initial: I,
113) -> Option<I> {
114 let mut sum = initial;
115 for (v, valid) in values.iter().zip_eq(validity) {
116 if valid {
117 let v: I = v.as_();
118 sum = CheckedAdd::checked_add(&sum, &v)?;
119 }
120 }
121 Some(sum)
122}
123
// Registers the `SumKernel` implementation above (wrapped in `SumKernelAdapter` and lifted) for
// `DecimalVTable`. Presumably this is how `crate::compute::sum` discovers the kernel at dispatch
// time — confirm against the `register_kernel!` definition.
register_kernel!(SumKernelAdapter(DecimalVTable).lift());
125
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::Nullability;
    use vortex_error::VortexUnwrap;
    use vortex_scalar::DecimalValue;
    use vortex_scalar::Scalar;
    use vortex_scalar::ScalarValue;
    use vortex_scalar::i256;

    use crate::arrays::DecimalArray;
    use crate::compute::sum;
    use crate::validity::Validity;

    #[test]
    fn test_sum_basic() {
        let decimal = DecimalArray::new(
            buffer![100i32, 200i32, 300i32],
            DecimalDType::new(4, 2),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected = Scalar::new(
            DType::Decimal(DecimalDType::new(14, 2), Nullability::NonNullable),
            ScalarValue::from(DecimalValue::from(600i32)),
        );

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_with_nulls() {
        let decimal = DecimalArray::new(
            buffer![100i32, 200i32, 300i32, 400i32],
            DecimalDType::new(4, 2),
            Validity::from_iter([true, false, true, true]),
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected = Scalar::new(
            DType::Decimal(DecimalDType::new(14, 2), Nullability::Nullable),
            ScalarValue::from(DecimalValue::from(800i32)),
        );

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_negative_values() {
        let decimal = DecimalArray::new(
            buffer![100i32, -200i32, 300i32, -50i32],
            DecimalDType::new(4, 2),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected = Scalar::new(
            DType::Decimal(DecimalDType::new(14, 2), Nullability::NonNullable),
            ScalarValue::from(DecimalValue::from(150i32)),
        );

        assert_eq!(result, expected);
    }

    /// An i32-backed input whose sum exceeds `i32::MAX` is accumulated in a wider type.
    #[test]
    fn test_sum_near_i32_max() {
        let near_max = i32::MAX - 1000;
        let decimal = DecimalArray::new(
            buffer![near_max, 500i32, 400i32],
            DecimalDType::new(10, 2),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected_sum = near_max as i64 + 500 + 400;
        let expected = Scalar::new(
            DType::Decimal(DecimalDType::new(20, 2), Nullability::NonNullable),
            ScalarValue::from(DecimalValue::from(expected_sum)),
        );

        assert_eq!(result, expected);
    }

    /// An i64-backed input whose sum exceeds `i64::MAX` is accumulated in i128.
    #[test]
    fn test_sum_large_i64_values() {
        let large_val = i64::MAX / 4;
        let decimal = DecimalArray::new(
            buffer![large_val, large_val, large_val, large_val + 1],
            DecimalDType::new(19, 0),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected_sum = (large_val as i128) * 4 + 1;
        let expected = Scalar::new(
            DType::Decimal(DecimalDType::new(29, 0), Nullability::NonNullable),
            ScalarValue::from(DecimalValue::from(expected_sum)),
        );

        assert_eq!(result, expected);
    }

    /// i128 inputs whose total exceeds `i128::MAX` are summed in i256 (precision 48) rather
    /// than being reported as an overflow.
    #[test]
    fn test_sum_overflow_detection() {
        let max_val = i128::MAX / 2;
        let decimal = DecimalArray::new(
            buffer![max_val, max_val, max_val],
            DecimalDType::new(38, 0),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected_sum =
            i256::from_i128(max_val) + i256::from_i128(max_val) + i256::from_i128(max_val);
        let expected = Scalar::new(
            DType::Decimal(DecimalDType::new(48, 0), Nullability::NonNullable),
            ScalarValue::from(DecimalValue::from(expected_sum)),
        );

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_mixed_signs_near_overflow() {
        let large_pos = i64::MAX / 2;
        let large_neg = -(i64::MAX / 2);
        let decimal = DecimalArray::new(
            buffer![large_pos, large_neg, large_pos, 1000i64],
            DecimalDType::new(19, 3),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected_sum = (large_pos as i128) + (large_neg as i128) + (large_pos as i128) + 1000;
        let expected = Scalar::new(
            DType::Decimal(DecimalDType::new(29, 3), Nullability::NonNullable),
            ScalarValue::from(DecimalValue::from(expected_sum)),
        );

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_preserves_scale() {
        let decimal = DecimalArray::new(
            buffer![12345i32, 67890i32, 11111i32],
            DecimalDType::new(6, 4),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected = Scalar::new(
            DType::Decimal(DecimalDType::new(16, 4), Nullability::NonNullable),
            ScalarValue::from(DecimalValue::from(91346i32)),
        );

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_single_value() {
        let decimal =
            DecimalArray::new(buffer![42i32], DecimalDType::new(3, 1), Validity::AllValid);

        let result = sum(decimal.as_ref()).unwrap();

        let expected = Scalar::new(
            DType::Decimal(DecimalDType::new(13, 1), Nullability::NonNullable),
            ScalarValue::from(DecimalValue::from(42i32)),
        );

        assert_eq!(result, expected);
    }

    #[test]
    fn test_sum_with_all_nulls_except_one() {
        let decimal = DecimalArray::new(
            buffer![100i32, 200i32, 300i32, 400i32],
            DecimalDType::new(4, 2),
            Validity::from_iter([false, false, true, false]),
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected = Scalar::new(
            DType::Decimal(DecimalDType::new(14, 2), Nullability::Nullable),
            ScalarValue::from(DecimalValue::from(300i32)),
        );

        assert_eq!(result, expected);
    }

    /// Nine copies of `i128::MAX / 10` exceed `i128::MAX`, so the sum must be carried in i256.
    #[test]
    fn test_sum_i128_to_i256_boundary() {
        let large_i128 = i128::MAX / 10;
        let decimal = DecimalArray::new(
            buffer![
                large_i128, large_i128, large_i128, large_i128, large_i128, large_i128, large_i128,
                large_i128, large_i128
            ],
            DecimalDType::new(38, 0),
            Validity::AllValid,
        );

        let result = sum(decimal.as_ref()).unwrap();

        let expected_sum = i256::from_i128(large_i128) * i256::from_i128(9);
        let expected = Scalar::new(
            DType::Decimal(DecimalDType::new(48, 0), Nullability::NonNullable),
            ScalarValue::from(DecimalValue::from(expected_sum)),
        );

        assert_eq!(result, expected);
    }

    /// Overflowing the widest accumulator (i256) yields a null sum.
    #[test]
    fn test_i256_overflow() {
        let decimal_dtype = DecimalDType::new(76, 0);
        let decimal = DecimalArray::new(
            buffer![i256::MAX, i256::MAX, i256::MAX],
            decimal_dtype,
            Validity::AllValid,
        );

        assert_eq!(
            sum(decimal.as_ref()).vortex_unwrap(),
            Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable))
        );
    }

    /// Overflow is also detected on the masked path, where validity is a bitmap.
    #[test]
    fn test_i256_overflow_with_nulls() {
        let decimal_dtype = DecimalDType::new(76, 0);
        let decimal = DecimalArray::new(
            buffer![i256::MAX, i256::MAX, i256::from_i128(0)],
            decimal_dtype,
            Validity::from_iter([true, true, false]),
        );

        assert_eq!(
            sum(decimal.as_ref()).vortex_unwrap(),
            Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable))
        );
    }
}