1use arrow_schema::DECIMAL256_MAX_PRECISION;
5use num_traits::AsPrimitive;
6use vortex_dtype::DecimalDType;
7use vortex_error::{VortexResult, vortex_bail};
8use vortex_mask::Mask;
9use vortex_scalar::{DecimalValue, Scalar, match_each_decimal_value_type};
10
11use crate::arrays::{DecimalArray, DecimalVTable, smallest_decimal_value_type};
12use crate::compute::{SumKernel, SumKernelAdapter};
13use crate::register_kernel;
14
/// Sums a buffer of decimal values, widening each element to the accumulator type `$ty`.
///
/// Two forms are supported:
/// - `sum_decimal!(Acc, values)` adds every element of `values`.
/// - `sum_decimal!(Acc, values, validity)` adds only the elements whose matching entry in
///   `validity` is `true`. `values` and `validity` must have the same length; `zip_eq`
///   panics otherwise.
///
/// Each element is converted to `$ty` with `AsPrimitive::as_`, and the running total is
/// accumulated with `num_traits::CheckedAdd`. A plain `+=` would panic on overflow in debug
/// builds and silently wrap in release builds. Instead, if the total no longer fits in `$ty`,
/// the macro returns an error from the *enclosing function* via `vortex_bail!`. It may
/// therefore only be invoked inside a function returning `VortexResult`.
macro_rules! sum_decimal {
    ($ty:ty, $values:expr) => {{
        let mut sum: $ty = <$ty>::default();
        for v in $values.iter() {
            let v: $ty = (*v).as_();
            sum = match num_traits::CheckedAdd::checked_add(&sum, &v) {
                Some(next) => next,
                None => {
                    vortex_bail!("Decimal sum overflowed its accumulator type")
                }
            };
        }
        sum
    }};
    ($ty:ty, $values:expr, $validity:expr) => {{
        use itertools::Itertools;

        let mut sum: $ty = <$ty>::default();
        for (v, valid) in $values.iter().zip_eq($validity.iter()) {
            // Null slots are skipped; their physical values are never read into the total.
            if valid {
                let v: $ty = (*v).as_();
                sum = match num_traits::CheckedAdd::checked_add(&sum, &v) {
                    Some(next) => next,
                    None => {
                        vortex_bail!("Decimal sum overflowed its accumulator type")
                    }
                };
            }
        }
        sum
    }};
}
38
39impl SumKernel for DecimalVTable {
40 #[allow(clippy::cognitive_complexity)]
41 fn sum(&self, array: &DecimalArray) -> VortexResult<Scalar> {
42 let decimal_dtype = array.decimal_dtype();
43 let nullability = array.dtype().nullability();
44
45 let new_precision = u8::min(DECIMAL256_MAX_PRECISION, decimal_dtype.precision() + 10);
49 let new_scale = decimal_dtype.scale();
50 let return_dtype = DecimalDType::new(new_precision, new_scale);
51
52 match array.validity_mask() {
53 Mask::AllFalse(_) => {
54 vortex_bail!("invalid state, all-null array should be checked by top-level sum fn")
55 }
56 Mask::AllTrue(_) => {
57 let values_type = smallest_decimal_value_type(&return_dtype);
58 match_each_decimal_value_type!(array.values_type(), |I| {
59 match_each_decimal_value_type!(values_type, |O| {
60 Ok(Scalar::decimal(
61 DecimalValue::from(sum_decimal!(O, array.buffer::<I>())),
62 return_dtype,
63 nullability,
64 ))
65 })
66 })
67 }
68 Mask::Values(mask_values) => {
69 let values_type = smallest_decimal_value_type(&return_dtype);
70 match_each_decimal_value_type!(array.values_type(), |I| {
71 match_each_decimal_value_type!(values_type, |O| {
72 Ok(Scalar::decimal(
73 DecimalValue::from(sum_decimal!(
74 O,
75 array.buffer::<I>(),
76 mask_values.boolean_buffer()
77 )),
78 return_dtype,
79 nullability,
80 ))
81 })
82 })
83 }
84 }
85 }
86}
87
88register_kernel!(SumKernelAdapter(DecimalVTable).lift());
89
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, DecimalDType, Nullability};
    use vortex_scalar::{DecimalValue, Scalar, ScalarValue, i256};

    use crate::arrays::DecimalArray;
    use crate::compute::sum;
    use crate::validity::Validity;

    /// Runs the `sum` compute function over `array` and asserts that it yields a decimal
    /// scalar of type `(result_dtype, nullability)` holding exactly `expected`.
    fn assert_decimal_sum(
        array: DecimalArray,
        result_dtype: DecimalDType,
        nullability: Nullability,
        expected: DecimalValue,
    ) {
        let actual = sum(array.as_ref()).unwrap();
        let expected = Scalar::new(
            DType::Decimal(result_dtype, nullability),
            ScalarValue::from(expected),
        );
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_sum_basic() {
        let array = DecimalArray::new(
            buffer![100i32, 200i32, 300i32],
            DecimalDType::new(4, 2),
            Validity::AllValid,
        );
        assert_decimal_sum(
            array,
            DecimalDType::new(14, 2),
            Nullability::NonNullable,
            DecimalValue::from(600i32),
        );
    }

    #[test]
    fn test_sum_with_nulls() {
        // The null slot (200) must not contribute; the result dtype becomes nullable.
        let array = DecimalArray::new(
            buffer![100i32, 200i32, 300i32, 400i32],
            DecimalDType::new(4, 2),
            Validity::from_iter([true, false, true, true]),
        );
        assert_decimal_sum(
            array,
            DecimalDType::new(14, 2),
            Nullability::Nullable,
            DecimalValue::from(800i32),
        );
    }

    #[test]
    fn test_sum_negative_values() {
        let array = DecimalArray::new(
            buffer![100i32, -200i32, 300i32, -50i32],
            DecimalDType::new(4, 2),
            Validity::AllValid,
        );
        assert_decimal_sum(
            array,
            DecimalDType::new(14, 2),
            Nullability::NonNullable,
            DecimalValue::from(150i32),
        );
    }

    #[test]
    fn test_sum_near_i32_max() {
        // The total exceeds i32::MAX, so it must be accumulated in a wider type.
        let near_max = i32::MAX - 1000;
        let array = DecimalArray::new(
            buffer![near_max, 500i32, 400i32],
            DecimalDType::new(10, 2),
            Validity::AllValid,
        );
        assert_decimal_sum(
            array,
            DecimalDType::new(20, 2),
            Nullability::NonNullable,
            DecimalValue::from(i64::from(near_max) + 500 + 400),
        );
    }

    #[test]
    fn test_sum_large_i64_values() {
        let large_val = i64::MAX / 4;
        let array = DecimalArray::new(
            buffer![large_val, large_val, large_val, large_val + 1],
            DecimalDType::new(19, 0),
            Validity::AllValid,
        );
        assert_decimal_sum(
            array,
            DecimalDType::new(29, 0),
            Nullability::NonNullable,
            DecimalValue::from(i128::from(large_val) * 4 + 1),
        );
    }

    #[test]
    fn test_sum_overflow_detection() {
        // Three halves of i128::MAX do not fit in i128; the widened dtype uses i256.
        let max_val = i128::MAX / 2;
        let array = DecimalArray::new(
            buffer![max_val, max_val, max_val],
            DecimalDType::new(38, 0),
            Validity::AllValid,
        );
        assert_decimal_sum(
            array,
            DecimalDType::new(48, 0),
            Nullability::NonNullable,
            DecimalValue::from(i256::from_i128(max_val) * i256::from_i128(3)),
        );
    }

    #[test]
    fn test_sum_mixed_signs_near_overflow() {
        let large_pos = i64::MAX / 2;
        let large_neg = -(i64::MAX / 2);
        let array = DecimalArray::new(
            buffer![large_pos, large_neg, large_pos, 1000i64],
            DecimalDType::new(19, 3),
            Validity::AllValid,
        );
        let expected =
            i128::from(large_pos) + i128::from(large_neg) + i128::from(large_pos) + 1000;
        assert_decimal_sum(
            array,
            DecimalDType::new(29, 3),
            Nullability::NonNullable,
            DecimalValue::from(expected),
        );
    }

    #[test]
    fn test_sum_preserves_scale() {
        // Scale 4 must carry through unchanged; only precision widens (6 -> 16).
        let array = DecimalArray::new(
            buffer![12345i32, 67890i32, 11111i32],
            DecimalDType::new(6, 4),
            Validity::AllValid,
        );
        assert_decimal_sum(
            array,
            DecimalDType::new(16, 4),
            Nullability::NonNullable,
            DecimalValue::from(91346i32),
        );
    }

    #[test]
    fn test_sum_single_value() {
        let array = DecimalArray::new(buffer![42i32], DecimalDType::new(3, 1), Validity::AllValid);
        assert_decimal_sum(
            array,
            DecimalDType::new(13, 1),
            Nullability::NonNullable,
            DecimalValue::from(42i32),
        );
    }

    #[test]
    fn test_sum_with_all_nulls_except_one() {
        let array = DecimalArray::new(
            buffer![100i32, 200i32, 300i32, 400i32],
            DecimalDType::new(4, 2),
            Validity::from_iter([false, false, true, false]),
        );
        assert_decimal_sum(
            array,
            DecimalDType::new(14, 2),
            Nullability::Nullable,
            DecimalValue::from(300i32),
        );
    }

    #[test]
    fn test_sum_i128_to_i256_boundary() {
        let large_i128 = i128::MAX / 10;
        let array = DecimalArray::new(
            buffer![
                large_i128, large_i128, large_i128, large_i128, large_i128, large_i128, large_i128,
                large_i128, large_i128
            ],
            DecimalDType::new(38, 0),
            Validity::AllValid,
        );
        assert_decimal_sum(
            array,
            DecimalDType::new(48, 0),
            Nullability::NonNullable,
            DecimalValue::from(i256::from_i128(large_i128) * i256::from_i128(9)),
        );
    }
}