vortex_array/arrays/constant/compute/
sum.rs1use num_traits::AsPrimitive;
5use num_traits::CheckedAdd;
6use num_traits::CheckedMul;
7use vortex_dtype::DType;
8use vortex_dtype::DecimalDType;
9use vortex_dtype::NativePType;
10use vortex_dtype::Nullability;
11use vortex_dtype::i256;
12use vortex_dtype::match_each_native_ptype;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_error::vortex_err;
17
18use crate::arrays::ConstantArray;
19use crate::arrays::ConstantVTable;
20use crate::compute::SumKernel;
21use crate::compute::SumKernelAdapter;
22use crate::expr::stats::Stat;
23use crate::register_kernel;
24use crate::scalar::DecimalScalar;
25use crate::scalar::DecimalValue;
26use crate::scalar::PrimitiveScalar;
27use crate::scalar::Scalar;
28use crate::scalar::ScalarValue;
29
30impl SumKernel for ConstantVTable {
31 fn sum(&self, array: &ConstantArray, accumulator: &Scalar) -> VortexResult<Scalar> {
32 let sum_dtype = Stat::Sum
34 .dtype(array.dtype())
35 .ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
36
37 let sum_value = sum_scalar(array.scalar(), array.len(), accumulator)?;
38 Scalar::try_new(sum_dtype, sum_value)
39 }
40}
41
42fn sum_scalar(
43 scalar: &Scalar,
44 len: usize,
45 accumulator: &Scalar,
46) -> VortexResult<Option<ScalarValue>> {
47 match scalar.dtype() {
48 DType::Bool(_) => {
49 let count = match scalar.as_bool().value() {
50 None => unreachable!("Handled before reaching this point"),
51 Some(false) => 0u64,
52 Some(true) => len as u64,
53 };
54 let accumulator = accumulator
55 .as_primitive()
56 .as_::<u64>()
57 .vortex_expect("cannot be null");
58 Ok(accumulator
59 .checked_add(count)
60 .map(|v| ScalarValue::Primitive(v.into())))
61 }
62 DType::Primitive(ptype, _) => {
63 let result = match_each_native_ptype!(
64 ptype,
65 unsigned: |T| { sum_integral::<u64>(scalar.as_primitive(), len, accumulator)?.map(|v| ScalarValue::Primitive(v.into())) },
66 signed: |T| { sum_integral::<i64>(scalar.as_primitive(), len, accumulator)?.map(|v| ScalarValue::Primitive(v.into())) },
67 floating: |T| { sum_float(scalar.as_primitive(), len, accumulator)?.map(|v| ScalarValue::Primitive(v.into())) }
68 );
69 Ok(result)
70 }
71 DType::Decimal(decimal_dtype, _) => {
72 sum_decimal(scalar.as_decimal(), len, *decimal_dtype, accumulator)
73 }
74 DType::Extension(_) => {
75 sum_scalar(&scalar.as_extension().to_storage_scalar(), len, accumulator)
76 }
77 dtype => vortex_bail!("Unsupported dtype for sum: {}", dtype),
78 }
79}
80
81fn sum_decimal(
82 decimal_scalar: DecimalScalar,
83 array_len: usize,
84 decimal_dtype: DecimalDType,
85 accumulator: &Scalar,
86) -> VortexResult<Option<ScalarValue>> {
87 let result_dtype = Stat::Sum
88 .dtype(&DType::Decimal(decimal_dtype, Nullability::Nullable))
89 .vortex_expect("decimal supports sum");
90 let result_decimal_type = result_dtype
91 .as_decimal_opt()
92 .vortex_expect("must be decimal");
93
94 let Some(value) = decimal_scalar.decimal_value() else {
95 return Ok(None);
97 };
98
99 let len_value = DecimalValue::I256(i256::from_i128(array_len as i128));
101
102 let Some(array_sum) = value
103 .checked_mul(&len_value)
104 .filter(|d| d.fits_in_precision(*result_decimal_type))
105 else {
106 return Ok(None);
107 };
108
109 let initial_decimal = accumulator.as_decimal();
111 let initial_dec_value = initial_decimal
112 .decimal_value()
113 .unwrap_or(DecimalValue::I256(i256::ZERO));
114
115 let total = array_sum
116 .checked_add(&initial_dec_value)
117 .and_then(|result| {
118 result
119 .fits_in_precision(*result_decimal_type)
120 .then_some(result)
121 });
122 match total {
123 Some(result_value) => Ok(Some(ScalarValue::from(result_value))),
124 None => Ok(None), }
126}
127
128fn sum_integral<T>(
129 primitive_scalar: PrimitiveScalar<'_>,
130 array_len: usize,
131 accumulator: &Scalar,
132) -> VortexResult<Option<T>>
133where
134 T: NativePType + CheckedMul + CheckedAdd,
135{
136 let v = primitive_scalar.as_::<T>();
137 let array_len =
138 T::from(array_len).ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
139 let Some(array_sum) = v.and_then(|v| v.checked_mul(&array_len)) else {
140 return Ok(None);
141 };
142
143 let initial = accumulator
144 .as_primitive()
145 .as_::<T>()
146 .vortex_expect("cannot be null");
147 Ok(initial.checked_add(&array_sum))
148}
149
150fn sum_float(
151 primitive_scalar: PrimitiveScalar<'_>,
152 array_len: usize,
153 accumulator: &Scalar,
154) -> VortexResult<Option<f64>> {
155 let initial = accumulator
156 .as_primitive()
157 .as_::<f64>()
158 .vortex_expect("cannot be null");
159 let v = primitive_scalar
160 .as_::<f64>()
161 .vortex_expect("cannot be null");
162 let len_f64: f64 = array_len.as_();
163
164 Ok(Some(initial + v * len_f64))
165}
166
// Register the `SumKernel` implementation above for `ConstantVTable`, wrapped in
// `SumKernelAdapter`, via the crate's `register_kernel!` macro.
register_kernel!(SumKernelAdapter(ConstantVTable).lift());
168
#[cfg(test)]
mod tests {
    use vortex_dtype::DType;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::Nullability;
    use vortex_dtype::Nullability::Nullable;
    use vortex_dtype::PType;
    use vortex_dtype::i256;
    use vortex_error::VortexExpect;

    use crate::Array;
    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::compute::sum;
    use crate::compute::sum_with_accumulator;
    use crate::expr::stats::Stat;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;

    #[test]
    fn test_sum_unsigned() {
        let array = ConstantArray::new(5u64, 10).into_array();
        let result = sum(&array).unwrap();
        assert_eq!(result, 50u64.into());
    }

    #[test]
    fn test_sum_signed() {
        let array = ConstantArray::new(-5i64, 10).into_array();
        let result = sum(&array).unwrap();
        assert_eq!(result, (-50i64).into());
    }

    /// The integer path must add the accumulator exactly once, not scale it.
    #[test]
    fn test_sum_unsigned_with_accumulator() {
        let array = ConstantArray::new(5u64, 10);
        let result = sum_with_accumulator(array.as_ref(), &Scalar::primitive(7u64, Nullable))
            .vortex_expect("operation should succeed in test");
        assert_eq!(
            u64::try_from(&result).vortex_expect("operation should succeed in test"),
            57
        );
    }

    /// `value * len` overflowing the sum type yields a null sum rather than wrapping.
    #[test]
    fn test_sum_unsigned_overflow_is_null() {
        let array = ConstantArray::new(u64::MAX, 2).into_array();
        let result = sum(&array).unwrap();
        assert!(result.is_null());
    }

    #[test]
    fn test_sum_nullable_value() {
        let array = ConstantArray::new(Scalar::null(DType::Primitive(PType::U32, Nullable)), 10)
            .into_array();
        let result = sum(&array).unwrap();
        assert_eq!(result, Scalar::primitive(0u64, Nullable));
    }

    #[test]
    fn test_sum_bool_false() {
        let array = ConstantArray::new(false, 10).into_array();
        let result = sum(&array).unwrap();
        assert_eq!(result, 0u64.into());
    }

    #[test]
    fn test_sum_bool_true() {
        let array = ConstantArray::new(true, 10).into_array();
        let result = sum(&array).unwrap();
        assert_eq!(result, 10u64.into());
    }

    #[test]
    fn test_sum_bool_null() {
        let array = ConstantArray::new(Scalar::null(DType::Bool(Nullable)), 10).into_array();
        let result = sum(&array).unwrap();
        assert_eq!(result, Scalar::primitive(0u64, Nullable));
    }

    #[test]
    fn test_sum_decimal() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = ConstantArray::new(
            Scalar::decimal(
                DecimalValue::I64(100),
                decimal_dtype,
                Nullability::NonNullable,
            ),
            5,
        )
        .into_array();

        let result = sum(&array).unwrap();

        assert_eq!(
            result.as_decimal().decimal_value(),
            Some(DecimalValue::I256(i256::from_i128(500)))
        );
        assert_eq!(result.dtype(), &Stat::Sum.dtype(array.dtype()).unwrap());
    }

    #[test]
    fn test_sum_decimal_null() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = ConstantArray::new(Scalar::null(DType::Decimal(decimal_dtype, Nullable)), 10)
            .into_array();

        let result = sum(&array).unwrap();
        assert_eq!(
            result,
            Scalar::decimal(
                DecimalValue::I256(i256::ZERO),
                DecimalDType::new(20, 2),
                Nullable
            )
        );
    }

    #[test]
    fn test_sum_decimal_large_value() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = ConstantArray::new(
            Scalar::decimal(
                DecimalValue::I64(999_999_999),
                decimal_dtype,
                Nullability::NonNullable,
            ),
            100,
        )
        .into_array();

        let result = sum(&array).unwrap();
        assert_eq!(
            result.as_decimal().decimal_value(),
            Some(DecimalValue::I256(i256::from_i128(99_999_999_900)))
        );
    }

    #[test]
    fn test_sum_float_non_multiply() {
        let acc = -2048669276050936500000000000f64;
        let array = ConstantArray::new(6.1811675e16f64, 25);
        let sum = sum_with_accumulator(array.as_ref(), &Scalar::primitive(acc, Nullable))
            .vortex_expect("operation should succeed in test");
        assert_eq!(
            f64::try_from(&sum).vortex_expect("operation should succeed in test"),
            -2048669274505644600000000000f64
        );
    }
}