vortex_array/arrays/constant/compute/
sum.rs1use arrow_array::ArrowNativeTypeOp;
5use num_traits::{CheckedAdd, CheckedMul, ToPrimitive};
6use vortex_dtype::{DType, DecimalDType, NativePType, Nullability, i256, match_each_native_ptype};
7use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
8use vortex_scalar::{DecimalScalar, DecimalValue, PrimitiveScalar, Scalar, ScalarValue};
9
10use crate::arrays::{ConstantArray, ConstantVTable};
11use crate::compute::{SumKernel, SumKernelAdapter};
12use crate::register_kernel;
13use crate::stats::Stat;
14
15impl SumKernel for ConstantVTable {
16 fn sum(&self, array: &ConstantArray, accumulator: &Scalar) -> VortexResult<Scalar> {
17 let sum_dtype = Stat::Sum
19 .dtype(array.dtype())
20 .ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
21
22 let sum_value = sum_scalar(array.scalar(), array.len(), accumulator)?;
23 Ok(Scalar::new(sum_dtype, sum_value))
24 }
25}
26
27fn sum_scalar(scalar: &Scalar, len: usize, accumulator: &Scalar) -> VortexResult<ScalarValue> {
28 match scalar.dtype() {
29 DType::Bool(_) => {
30 let count = match scalar.as_bool().value() {
31 None => unreachable!("Handled before reaching this point"),
32 Some(false) => 0u64,
33 Some(true) => len as u64,
34 };
35 let accumulator = accumulator
36 .as_primitive()
37 .as_::<u64>()
38 .vortex_expect("cannot be null");
39 Ok(ScalarValue::from(accumulator.checked_add(count)))
40 }
41 DType::Primitive(ptype, _) => {
42 let result = match_each_native_ptype!(
43 ptype,
44 unsigned: |T| { sum_integral::<u64>(scalar.as_primitive(), len, accumulator)?.into() },
45 signed: |T| { sum_integral::<i64>(scalar.as_primitive(), len, accumulator)?.into() },
46 floating: |T| { sum_float(scalar.as_primitive(), len, accumulator)?.into() }
47 );
48 Ok(result)
49 }
50 DType::Decimal(decimal_dtype, _) => {
51 sum_decimal(scalar.as_decimal(), len, *decimal_dtype, accumulator)
52 }
53 DType::Extension(_) => sum_scalar(&scalar.as_extension().storage(), len, accumulator),
54 dtype => vortex_bail!("Unsupported dtype for sum: {}", dtype),
55 }
56}
57
58fn sum_decimal(
59 decimal_scalar: DecimalScalar,
60 array_len: usize,
61 decimal_dtype: DecimalDType,
62 accumulator: &Scalar,
63) -> VortexResult<ScalarValue> {
64 let result_dtype = Stat::Sum
65 .dtype(&DType::Decimal(decimal_dtype, Nullability::Nullable))
66 .vortex_expect("decimal supports sum");
67 let result_decimal_type = result_dtype
68 .as_decimal_opt()
69 .vortex_expect("must be decimal");
70
71 let Some(value) = decimal_scalar.decimal_value() else {
72 return Ok(ScalarValue::null());
74 };
75
76 let len_value = DecimalValue::I256(i256::from_i128(array_len as i128));
78
79 let array_sum = value.checked_mul(&len_value).and_then(|result| {
81 result
83 .fits_in_precision(*result_decimal_type)
84 .unwrap_or(false)
85 .then_some(result)
86 });
87
88 let initial_decimal = DecimalScalar::try_from(accumulator)?;
90 let initial_dec_value = initial_decimal
91 .decimal_value()
92 .unwrap_or(DecimalValue::I256(i256::ZERO));
93
94 match array_sum {
95 Some(array_sum_value) => {
96 let total = array_sum_value
97 .checked_add(&initial_dec_value)
98 .and_then(|result| {
99 result
100 .fits_in_precision(*result_decimal_type)
101 .unwrap_or(false)
102 .then_some(result)
103 });
104 match total {
105 Some(result_value) => Ok(ScalarValue::from(result_value)),
106 None => Ok(ScalarValue::null()), }
108 }
109 None => Ok(ScalarValue::null()), }
111}
112
113fn sum_integral<T>(
114 primitive_scalar: PrimitiveScalar<'_>,
115 array_len: usize,
116 accumulator: &Scalar,
117) -> VortexResult<Option<T>>
118where
119 T: NativePType + CheckedMul + CheckedAdd,
120 Scalar: From<Option<T>>,
121{
122 let v = primitive_scalar.as_::<T>();
123 let array_len =
124 T::from(array_len).ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
125 let Some(array_sum) = v.and_then(|v| v.checked_mul(&array_len)) else {
126 return Ok(None);
127 };
128
129 let initial = accumulator
130 .as_primitive()
131 .as_::<T>()
132 .vortex_expect("cannot be null");
133 Ok(initial.checked_add(&array_sum))
134}
135
136fn sum_float(
137 primitive_scalar: PrimitiveScalar<'_>,
138 array_len: usize,
139 accumulator: &Scalar,
140) -> VortexResult<Option<f64>> {
141 let v = primitive_scalar
142 .as_::<f64>()
143 .vortex_expect("cannot be null");
144 let array_len = array_len
145 .to_f64()
146 .ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
147
148 let Ok(array_sum) = v.mul_checked(array_len) else {
149 return Ok(None);
150 };
151 let initial = accumulator
152 .as_primitive()
153 .as_::<f64>()
154 .vortex_expect("cannot be null");
155 Ok(Some(initial + array_sum))
156}
157
// Register the `ConstantVTable` sum kernel (wrapped in `SumKernelAdapter` and lifted)
// so that `sum` can dispatch to it for constant arrays.
register_kernel!(SumKernelAdapter(ConstantVTable).lift());
159
#[cfg(test)]
mod tests {
    use vortex_dtype::Nullability::Nullable;
    use vortex_dtype::{DType, DecimalDType, Nullability, PType, i256};
    use vortex_scalar::{DecimalValue, Scalar};

    use crate::arrays::ConstantArray;
    use crate::compute::sum;
    use crate::stats::Stat;
    use crate::{Array, IntoArray};

    /// Builds a non-nullable decimal scalar with the given value and dtype.
    fn non_null_decimal(value: DecimalValue, dtype: DecimalDType) -> Scalar {
        Scalar::decimal(value, dtype, Nullability::NonNullable)
    }

    #[test]
    fn test_sum_unsigned() {
        let array = ConstantArray::new(5u64, 10).into_array();
        assert_eq!(sum(&array).unwrap(), Scalar::from(50u64));
    }

    #[test]
    fn test_sum_signed() {
        let array = ConstantArray::new(-5i64, 10).into_array();
        assert_eq!(sum(&array).unwrap(), Scalar::from(-50i64));
    }

    #[test]
    fn test_sum_nullable_value() {
        // An all-null constant sums to a nullable zero, not to null.
        let null_u32 = Scalar::null(DType::Primitive(PType::U32, Nullable));
        let array = ConstantArray::new(null_u32, 10).into_array();
        assert_eq!(sum(&array).unwrap(), Scalar::primitive(0u64, Nullable));
    }

    #[test]
    fn test_sum_bool_false() {
        let array = ConstantArray::new(false, 10).into_array();
        assert_eq!(sum(&array).unwrap(), Scalar::from(0u64));
    }

    #[test]
    fn test_sum_bool_true() {
        let array = ConstantArray::new(true, 10).into_array();
        assert_eq!(sum(&array).unwrap(), Scalar::from(10u64));
    }

    #[test]
    fn test_sum_bool_null() {
        let null_bool = Scalar::null(DType::Bool(Nullable));
        let array = ConstantArray::new(null_bool, 10).into_array();
        assert_eq!(sum(&array).unwrap(), Scalar::primitive(0u64, Nullable));
    }

    #[test]
    fn test_sum_decimal() {
        let constant = non_null_decimal(DecimalValue::I64(100), DecimalDType::new(10, 2));
        let array = ConstantArray::new(constant, 5).into_array();

        let result = sum(&array).unwrap();

        assert_eq!(
            result.as_decimal().decimal_value(),
            Some(DecimalValue::I256(i256::from_i128(500)))
        );
        let expected_dtype = Stat::Sum.dtype(array.dtype()).unwrap();
        assert_eq!(result.dtype(), &expected_dtype);
    }

    #[test]
    fn test_sum_decimal_null() {
        let null_decimal = Scalar::null(DType::Decimal(DecimalDType::new(10, 2), Nullable));
        let array = ConstantArray::new(null_decimal, 10).into_array();

        let expected = Scalar::decimal(
            DecimalValue::I256(i256::ZERO),
            DecimalDType::new(20, 2),
            Nullable,
        );
        assert_eq!(sum(&array).unwrap(), expected);
    }

    #[test]
    fn test_sum_decimal_large_value() {
        let constant = non_null_decimal(DecimalValue::I64(999_999_999), DecimalDType::new(10, 2));
        let array = ConstantArray::new(constant, 100).into_array();

        let result = sum(&array).unwrap();
        assert_eq!(
            result.as_decimal().decimal_value(),
            Some(DecimalValue::I256(i256::from_i128(99_999_999_900)))
        );
    }
}