vortex_array/arrays/constant/compute/
sum.rs1use num_traits::AsPrimitive;
5use num_traits::CheckedAdd;
6use num_traits::CheckedMul;
7use vortex_dtype::DType;
8use vortex_dtype::DecimalDType;
9use vortex_dtype::NativePType;
10use vortex_dtype::Nullability;
11use vortex_dtype::i256;
12use vortex_dtype::match_each_native_ptype;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_error::vortex_err;
17use vortex_scalar::DecimalScalar;
18use vortex_scalar::DecimalValue;
19use vortex_scalar::PrimitiveScalar;
20use vortex_scalar::Scalar;
21use vortex_scalar::ScalarValue;
22
23use crate::arrays::ConstantArray;
24use crate::arrays::ConstantVTable;
25use crate::compute::SumKernel;
26use crate::compute::SumKernelAdapter;
27use crate::expr::stats::Stat;
28use crate::register_kernel;
29
30impl SumKernel for ConstantVTable {
31 fn sum(&self, array: &ConstantArray, accumulator: &Scalar) -> VortexResult<Scalar> {
32 let sum_dtype = Stat::Sum
34 .dtype(array.dtype())
35 .ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
36
37 let sum_value = sum_scalar(array.scalar(), array.len(), accumulator)?;
38 Ok(Scalar::new(sum_dtype, sum_value))
39 }
40}
41
42fn sum_scalar(scalar: &Scalar, len: usize, accumulator: &Scalar) -> VortexResult<ScalarValue> {
43 match scalar.dtype() {
44 DType::Bool(_) => {
45 let count = match scalar.as_bool().value() {
46 None => unreachable!("Handled before reaching this point"),
47 Some(false) => 0u64,
48 Some(true) => len as u64,
49 };
50 let accumulator = accumulator
51 .as_primitive()
52 .as_::<u64>()
53 .vortex_expect("cannot be null");
54 Ok(ScalarValue::from(accumulator.checked_add(count)))
55 }
56 DType::Primitive(ptype, _) => {
57 let result = match_each_native_ptype!(
58 ptype,
59 unsigned: |T| { sum_integral::<u64>(scalar.as_primitive(), len, accumulator)?.into() },
60 signed: |T| { sum_integral::<i64>(scalar.as_primitive(), len, accumulator)?.into() },
61 floating: |T| { sum_float(scalar.as_primitive(), len, accumulator)?.into() }
62 );
63 Ok(result)
64 }
65 DType::Decimal(decimal_dtype, _) => {
66 sum_decimal(scalar.as_decimal(), len, *decimal_dtype, accumulator)
67 }
68 DType::Extension(_) => sum_scalar(&scalar.as_extension().storage(), len, accumulator),
69 dtype => vortex_bail!("Unsupported dtype for sum: {}", dtype),
70 }
71}
72
73fn sum_decimal(
74 decimal_scalar: DecimalScalar,
75 array_len: usize,
76 decimal_dtype: DecimalDType,
77 accumulator: &Scalar,
78) -> VortexResult<ScalarValue> {
79 let result_dtype = Stat::Sum
80 .dtype(&DType::Decimal(decimal_dtype, Nullability::Nullable))
81 .vortex_expect("decimal supports sum");
82 let result_decimal_type = result_dtype
83 .as_decimal_opt()
84 .vortex_expect("must be decimal");
85
86 let Some(value) = decimal_scalar.decimal_value() else {
87 return Ok(ScalarValue::null());
89 };
90
91 let len_value = DecimalValue::I256(i256::from_i128(array_len as i128));
93
94 let array_sum = value.checked_mul(&len_value).and_then(|result| {
96 result
98 .fits_in_precision(*result_decimal_type)
99 .unwrap_or(false)
100 .then_some(result)
101 });
102
103 let initial_decimal = DecimalScalar::try_from(accumulator)?;
105 let initial_dec_value = initial_decimal
106 .decimal_value()
107 .unwrap_or(DecimalValue::I256(i256::ZERO));
108
109 match array_sum {
110 Some(array_sum_value) => {
111 let total = array_sum_value
112 .checked_add(&initial_dec_value)
113 .and_then(|result| {
114 result
115 .fits_in_precision(*result_decimal_type)
116 .unwrap_or(false)
117 .then_some(result)
118 });
119 match total {
120 Some(result_value) => Ok(ScalarValue::from(result_value)),
121 None => Ok(ScalarValue::null()), }
123 }
124 None => Ok(ScalarValue::null()), }
126}
127
128fn sum_integral<T>(
129 primitive_scalar: PrimitiveScalar<'_>,
130 array_len: usize,
131 accumulator: &Scalar,
132) -> VortexResult<Option<T>>
133where
134 T: NativePType + CheckedMul + CheckedAdd,
135 Scalar: From<Option<T>>,
136{
137 let v = primitive_scalar.as_::<T>();
138 let array_len =
139 T::from(array_len).ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
140 let Some(array_sum) = v.and_then(|v| v.checked_mul(&array_len)) else {
141 return Ok(None);
142 };
143
144 let initial = accumulator
145 .as_primitive()
146 .as_::<T>()
147 .vortex_expect("cannot be null");
148 Ok(initial.checked_add(&array_sum))
149}
150
151fn sum_float(
152 primitive_scalar: PrimitiveScalar<'_>,
153 array_len: usize,
154 accumulator: &Scalar,
155) -> VortexResult<Option<f64>> {
156 let initial = accumulator
157 .as_primitive()
158 .as_::<f64>()
159 .vortex_expect("cannot be null");
160 let v = primitive_scalar
161 .as_::<f64>()
162 .vortex_expect("cannot be null");
163 let len_f64: f64 = array_len.as_();
164
165 Ok(Some(initial + v * len_f64))
166}
167
// Registers the `SumKernel` implementation for `ConstantVTable`, wrapped in
// `SumKernelAdapter`, through the crate's `register_kernel!` macro.
register_kernel!(SumKernelAdapter(ConstantVTable).lift());
169
#[cfg(test)]
mod tests {
    use vortex_dtype::DType;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::Nullability;
    use vortex_dtype::Nullability::Nullable;
    use vortex_dtype::PType;
    use vortex_dtype::i256;
    use vortex_error::VortexUnwrap;
    use vortex_scalar::DecimalValue;
    use vortex_scalar::Scalar;

    use crate::Array;
    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::compute::sum;
    use crate::compute::sum_with_accumulator;
    use crate::expr::stats::Stat;

    /// Ten copies of an unsigned 5 sum to 50.
    #[test]
    fn test_sum_unsigned() {
        let constant = ConstantArray::new(5u64, 10).into_array();
        let expected: Scalar = 50u64.into();
        assert_eq!(sum(&constant).unwrap(), expected);
    }

    /// Ten copies of a signed -5 sum to -50.
    #[test]
    fn test_sum_signed() {
        let constant = ConstantArray::new(-5i64, 10).into_array();
        let expected: Scalar = (-50i64).into();
        assert_eq!(sum(&constant).unwrap(), expected);
    }

    /// A null `u32` constant sums to a nullable `u64` zero.
    #[test]
    fn test_sum_nullable_value() {
        let null_u32 = Scalar::null(DType::Primitive(PType::U32, Nullable));
        let constant = ConstantArray::new(null_u32, 10).into_array();
        let total = sum(&constant).unwrap();
        assert_eq!(total, Scalar::primitive(0u64, Nullable));
    }

    /// A constant `false` array sums to zero.
    #[test]
    fn test_sum_bool_false() {
        let constant = ConstantArray::new(false, 10).into_array();
        let expected: Scalar = 0u64.into();
        assert_eq!(sum(&constant).unwrap(), expected);
    }

    /// A constant `true` array sums to its length.
    #[test]
    fn test_sum_bool_true() {
        let constant = ConstantArray::new(true, 10).into_array();
        let expected: Scalar = 10u64.into();
        assert_eq!(sum(&constant).unwrap(), expected);
    }

    /// A null boolean constant sums to a nullable `u64` zero.
    #[test]
    fn test_sum_bool_null() {
        let null_bool = Scalar::null(DType::Bool(Nullable));
        let constant = ConstantArray::new(null_bool, 10).into_array();
        let total = sum(&constant).unwrap();
        assert_eq!(total, Scalar::primitive(0u64, Nullable));
    }

    /// Five copies of decimal 100 sum to 500, typed with the dtype `Stat::Sum` reports.
    #[test]
    fn test_sum_decimal() {
        let value = Scalar::decimal(
            DecimalValue::I64(100),
            DecimalDType::new(10, 2),
            Nullability::NonNullable,
        );
        let constant = ConstantArray::new(value, 5).into_array();

        let total = sum(&constant).unwrap();

        let expected_value = DecimalValue::I256(i256::from_i128(500));
        assert_eq!(total.as_decimal().decimal_value(), Some(expected_value));
        assert_eq!(total.dtype(), &Stat::Sum.dtype(constant.dtype()).unwrap());
    }

    /// A null decimal constant sums to a nullable zero at precision 20, scale 2.
    #[test]
    fn test_sum_decimal_null() {
        let null_decimal = Scalar::null(DType::Decimal(DecimalDType::new(10, 2), Nullable));
        let constant = ConstantArray::new(null_decimal, 10).into_array();

        let total = sum(&constant).unwrap();

        let expected = Scalar::decimal(
            DecimalValue::I256(i256::ZERO),
            DecimalDType::new(20, 2),
            Nullable,
        );
        assert_eq!(total, expected);
    }

    /// One hundred copies of 999_999_999 sum to 99_999_999_900.
    #[test]
    fn test_sum_decimal_large_value() {
        let value = Scalar::decimal(
            DecimalValue::I64(999_999_999),
            DecimalDType::new(10, 2),
            Nullability::NonNullable,
        );
        let constant = ConstantArray::new(value, 100).into_array();

        let total = sum(&constant).unwrap();

        let expected_value = DecimalValue::I256(i256::from_i128(99_999_999_900));
        assert_eq!(total.as_decimal().decimal_value(), Some(expected_value));
    }

    /// Summing a float constant into a large negative accumulator yields the expected total.
    #[test]
    fn test_sum_float_non_multiply() {
        let start = Scalar::primitive(-2048669276050936500000000000f64, Nullable);
        let constant = ConstantArray::new(6.1811675e16f64, 25);

        let total = sum_with_accumulator(constant.as_ref(), &start).vortex_unwrap();

        let actual = f64::try_from(total).vortex_unwrap();
        assert_eq!(actual, -2048669274505644600000000000f64);
    }
}