1use arrow_schema::DECIMAL256_MAX_PRECISION;
5use num_traits::AsPrimitive;
6use vortex_dtype::Nullability::Nullable;
7use vortex_dtype::{DecimalDType, DecimalType, match_each_decimal_value_type};
8use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
9use vortex_mask::Mask;
10use vortex_scalar::{DecimalScalar, DecimalValue, Scalar};
11
12use crate::arrays::{DecimalArray, DecimalVTable};
13use crate::compute::{SumKernel, SumKernelAdapter};
14use crate::register_kernel;
15
// Sums decimal values into an accumulator of type `$ty` using checked addition.
//
// Forms:
// - `sum_decimal!(ty, values, initial)` adds every element of `values`.
// - `sum_decimal!(ty, values, validity, initial)` adds only the elements whose paired
//   validity flag is `true`; values and flags are paired with itertools' `zip_eq`.
//
// Each element is converted to `$ty` with `AsPrimitive::as_` before it is added.
// NOTE(review): this is only lossless if `$ty` can represent every input value; confirm
// that call sites always pick an accumulator at least as wide as the input type.
//
// On overflow the error built by `vortex_err!` is propagated with `?`, so the macro must be
// expanded in a context where `?` is usable.
macro_rules! sum_decimal {
    ($ty:ty, $values:expr, $initial:expr) => {{
        let initial: $ty = $initial;
        $values.iter().try_fold(initial, |sum: $ty, v| {
            let v: $ty = (*v).as_();
            num_traits::CheckedAdd::checked_add(&sum, &v)
                .ok_or_else(|| vortex_err!("Overflow when summing decimal {sum:?} + {v:?}"))
        })?
    }};
    ($ty:ty, $values:expr, $validity:expr, $initial:expr) => {{
        use itertools::Itertools;

        let initial: $ty = $initial;
        $values
            .iter()
            .zip_eq($validity)
            .try_fold(initial, |sum: $ty, (v, valid)| {
                if valid {
                    let v: $ty = (*v).as_();
                    num_traits::CheckedAdd::checked_add(&sum, &v).ok_or_else(|| {
                        vortex_err!("Overflow when summing decimal {sum:?} + {v:?}")
                    })
                } else {
                    // Invalid (null) slots leave the running total untouched.
                    Ok(sum)
                }
            })?
    }};
}
41
42impl SumKernel for DecimalVTable {
43 #[allow(clippy::cognitive_complexity)]
44 fn sum(&self, array: &DecimalArray, accumulator: &Scalar) -> VortexResult<Scalar> {
45 let decimal_dtype = array.decimal_dtype();
46
47 let new_precision = u8::min(DECIMAL256_MAX_PRECISION, decimal_dtype.precision() + 10);
51 let new_scale = decimal_dtype.scale();
52 let return_dtype = DecimalDType::new(new_precision, new_scale);
53
54 let initial_decimal = DecimalScalar::try_from(accumulator)
56 .vortex_expect("must be a decimal")
57 .decimal_value()
58 .vortex_expect("cannot be null");
59
60 match array.validity_mask() {
61 Mask::AllFalse(_) => {
62 vortex_bail!("invalid state, all-null array should be checked by top-level sum fn")
63 }
64 Mask::AllTrue(_) => {
65 let values_type = DecimalType::smallest_decimal_value_type(&return_dtype);
66 match_each_decimal_value_type!(array.values_type(), |I| {
67 match_each_decimal_value_type!(values_type, |O| {
68 let initial_val: O = initial_decimal
69 .cast()
70 .vortex_expect("cannot fail to cast initial value");
71 Ok(Scalar::decimal(
72 DecimalValue::from(sum_decimal!(O, array.buffer::<I>(), initial_val)),
73 return_dtype,
74 Nullable,
75 ))
76 })
77 })
78 }
79 Mask::Values(mask_values) => {
80 let values_type = DecimalType::smallest_decimal_value_type(&return_dtype);
81 match_each_decimal_value_type!(array.values_type(), |I| {
82 match_each_decimal_value_type!(values_type, |O| {
83 let initial_val: O = initial_decimal
84 .cast()
85 .vortex_expect("cannot fail to cast initial value");
86 Ok(Scalar::decimal(
87 DecimalValue::from(sum_decimal!(
88 O,
89 array.buffer::<I>(),
90 mask_values.bit_buffer(),
91 initial_val
92 )),
93 return_dtype,
94 Nullable,
95 ))
96 })
97 })
98 }
99 }
100 }
101}
102
// Hand the decimal sum kernel, wrapped in `SumKernelAdapter` and lifted, to
// `register_kernel!`.
register_kernel!(SumKernelAdapter(DecimalVTable).lift());
104
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, DecimalDType, Nullability};
    use vortex_scalar::{DecimalValue, Scalar, ScalarValue};

    use crate::arrays::DecimalArray;
    use crate::compute::sum;
    use crate::validity::Validity;

    /// Runs the top-level `sum` compute function on `array`, panicking on error.
    fn sum_of(array: &DecimalArray) -> Scalar {
        sum(array.as_ref()).unwrap()
    }

    /// Builds the decimal scalar a test expects `sum` to return.
    fn decimal_scalar(
        dtype: DecimalDType,
        nullability: Nullability,
        value: impl Into<DecimalValue>,
    ) -> Scalar {
        let value: DecimalValue = value.into();
        Scalar::new(DType::Decimal(dtype, nullability), ScalarValue::from(value))
    }

    /// Three small positive values; the result precision grows from 4 to 14.
    #[test]
    fn test_sum_basic() {
        let array = DecimalArray::new(
            buffer![100i32, 200i32, 300i32],
            DecimalDType::new(4, 2),
            Validity::AllValid,
        );

        assert_eq!(
            sum_of(&array),
            decimal_scalar(DecimalDType::new(14, 2), Nullability::NonNullable, 600i32),
        );
    }

    /// The null slot (200) must not contribute to the total.
    #[test]
    fn test_sum_with_nulls() {
        let array = DecimalArray::new(
            buffer![100i32, 200i32, 300i32, 400i32],
            DecimalDType::new(4, 2),
            Validity::from_iter([true, false, true, true]),
        );

        assert_eq!(
            sum_of(&array),
            decimal_scalar(DecimalDType::new(14, 2), Nullability::Nullable, 800i32),
        );
    }

    /// Positive and negative values cancel as expected.
    #[test]
    fn test_sum_negative_values() {
        let array = DecimalArray::new(
            buffer![100i32, -200i32, 300i32, -50i32],
            DecimalDType::new(4, 2),
            Validity::AllValid,
        );

        assert_eq!(
            sum_of(&array),
            decimal_scalar(DecimalDType::new(14, 2), Nullability::NonNullable, 150i32),
        );
    }

    /// The total exceeds `i32::MAX - 1000 + 900`, so it is compared against an `i64` value.
    #[test]
    fn test_sum_near_i32_max() {
        let near_max = i32::MAX - 1000;
        let array = DecimalArray::new(
            buffer![near_max, 500i32, 400i32],
            DecimalDType::new(10, 2),
            Validity::AllValid,
        );

        let expected_sum = i64::from(near_max) + 500 + 400;
        assert_eq!(
            sum_of(&array),
            decimal_scalar(
                DecimalDType::new(20, 2),
                Nullability::NonNullable,
                expected_sum,
            ),
        );
    }

    /// Four values around `i64::MAX / 4`; the expected total is computed in `i128`.
    #[test]
    fn test_sum_large_i64_values() {
        let large_val = i64::MAX / 4;
        let array = DecimalArray::new(
            buffer![large_val, large_val, large_val, large_val + 1],
            DecimalDType::new(19, 0),
            Validity::AllValid,
        );

        let expected_sum = i128::from(large_val) * 4 + 1;
        assert_eq!(
            sum_of(&array),
            decimal_scalar(
                DecimalDType::new(29, 0),
                Nullability::NonNullable,
                expected_sum,
            ),
        );
    }

    /// Three values of `i128::MAX / 2`; the expected total is computed in `i256`.
    #[test]
    fn test_sum_overflow_detection() {
        use vortex_scalar::i256;

        let max_val = i128::MAX / 2;
        let array = DecimalArray::new(
            buffer![max_val, max_val, max_val],
            DecimalDType::new(38, 0),
            Validity::AllValid,
        );

        let expected_sum =
            i256::from_i128(max_val) + i256::from_i128(max_val) + i256::from_i128(max_val);
        assert_eq!(
            sum_of(&array),
            decimal_scalar(
                DecimalDType::new(48, 0),
                Nullability::NonNullable,
                expected_sum,
            ),
        );
    }

    /// Large values of opposite sign followed by more positive values.
    #[test]
    fn test_sum_mixed_signs_near_overflow() {
        let large_pos = i64::MAX / 2;
        let large_neg = -(i64::MAX / 2);
        let array = DecimalArray::new(
            buffer![large_pos, large_neg, large_pos, 1000i64],
            DecimalDType::new(19, 3),
            Validity::AllValid,
        );

        let expected_sum =
            i128::from(large_pos) + i128::from(large_neg) + i128::from(large_pos) + 1000;
        assert_eq!(
            sum_of(&array),
            decimal_scalar(
                DecimalDType::new(29, 3),
                Nullability::NonNullable,
                expected_sum,
            ),
        );
    }

    /// The scale (4) of the input must be carried over to the result dtype.
    #[test]
    fn test_sum_preserves_scale() {
        let array = DecimalArray::new(
            buffer![12345i32, 67890i32, 11111i32],
            DecimalDType::new(6, 4),
            Validity::AllValid,
        );

        assert_eq!(
            sum_of(&array),
            decimal_scalar(DecimalDType::new(16, 4), Nullability::NonNullable, 91346i32),
        );
    }

    /// A single-element array sums to that element.
    #[test]
    fn test_sum_single_value() {
        let array = DecimalArray::new(buffer![42i32], DecimalDType::new(3, 1), Validity::AllValid);

        assert_eq!(
            sum_of(&array),
            decimal_scalar(DecimalDType::new(13, 1), Nullability::NonNullable, 42i32),
        );
    }

    /// Only the single valid slot (300) contributes to the total.
    #[test]
    fn test_sum_with_all_nulls_except_one() {
        let array = DecimalArray::new(
            buffer![100i32, 200i32, 300i32, 400i32],
            DecimalDType::new(4, 2),
            Validity::from_iter([false, false, true, false]),
        );

        assert_eq!(
            sum_of(&array),
            decimal_scalar(DecimalDType::new(14, 2), Nullability::Nullable, 300i32),
        );
    }

    /// Nine copies of `i128::MAX / 10`; the expected total is computed in `i256`.
    #[test]
    fn test_sum_i128_to_i256_boundary() {
        use vortex_scalar::i256;

        let large_i128 = i128::MAX / 10;
        let array = DecimalArray::new(
            buffer![
                large_i128, large_i128, large_i128, large_i128, large_i128, large_i128, large_i128,
                large_i128, large_i128
            ],
            DecimalDType::new(38, 0),
            Validity::AllValid,
        );

        let expected_sum = i256::from_i128(large_i128).wrapping_pow(1) * i256::from_i128(9);
        assert_eq!(
            sum_of(&array),
            decimal_scalar(
                DecimalDType::new(48, 0),
                Nullability::NonNullable,
                expected_sum,
            ),
        );
    }
}