vortex_array/compute/
sum.rs1use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use num_traits::CheckedAdd;
8use num_traits::CheckedSub;
9use vortex_error::VortexError;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_ensure;
13use vortex_error::vortex_err;
14use vortex_error::vortex_panic;
15
16use crate::Array;
17use crate::ArrayRef;
18use crate::IntoArray as _;
19use crate::compute::ComputeFn;
20use crate::compute::ComputeFnVTable;
21use crate::compute::InvocationArgs;
22use crate::compute::Kernel;
23use crate::compute::Output;
24use crate::dtype::DType;
25use crate::expr::stats::Precision;
26use crate::expr::stats::Stat;
27use crate::expr::stats::StatsProvider;
28use crate::scalar::NumericOperator;
29use crate::scalar::Scalar;
30use crate::vtable::VTable;
31
32static SUM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
33 let compute = ComputeFn::new("sum".into(), ArcRef::new_ref(&Sum));
34 for kernel in inventory::iter::<SumKernelRef> {
35 compute.register_kernel(kernel.0.clone());
36 }
37 compute
38});
39
40pub(crate) fn warm_up_vtable() -> usize {
41 SUM_FN.kernels().len()
42}
43
44pub(crate) fn sum_with_accumulator(array: &ArrayRef, accumulator: &Scalar) -> VortexResult<Scalar> {
51 SUM_FN
52 .invoke(&InvocationArgs {
53 inputs: &[array.into(), accumulator.into()],
54 options: &(),
55 })?
56 .unwrap_scalar()
57}
58
59pub fn sum(array: &ArrayRef) -> VortexResult<Scalar> {
65 let sum_dtype = Stat::Sum
66 .dtype(array.dtype())
67 .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))?;
68 let zero = Scalar::zero_value(&sum_dtype);
69 sum_with_accumulator(array, &zero)
70}
71
72pub struct SumArgs<'a> {
74 pub array: &'a dyn Array,
75 pub accumulator: &'a Scalar,
76}
77
78impl<'a> TryFrom<&InvocationArgs<'a>> for SumArgs<'a> {
79 type Error = VortexError;
80
81 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
82 if value.inputs.len() != 2 {
83 vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
84 }
85 let array = value.inputs[0]
86 .array()
87 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
88 let accumulator = value.inputs[1]
89 .scalar()
90 .ok_or_else(|| vortex_err!("Expected input 1 to be a scalar"))?;
91 Ok(SumArgs { array, accumulator })
92 }
93}
94
95struct Sum;
96
97impl ComputeFnVTable for Sum {
98 fn invoke(
99 &self,
100 args: &InvocationArgs,
101 kernels: &[ArcRef<dyn Kernel>],
102 ) -> VortexResult<Output> {
103 let SumArgs { array, accumulator } = args.try_into()?;
104 let array = array.to_array();
105
106 let sum_dtype = self.return_dtype(args)?;
108
109 vortex_ensure!(
110 &sum_dtype == accumulator.dtype(),
111 "sum_dtype {sum_dtype} must match accumulator dtype {}",
112 accumulator.dtype()
113 );
114
115 if let Some(Precision::Exact(sum_scalar)) = array.statistics().get(Stat::Sum) {
117 match &sum_dtype {
120 DType::Primitive(p, _) => {
121 if p.is_float() && accumulator.is_zero() == Some(true) {
122 return Ok(sum_scalar.into());
123 } else if p.is_int() {
124 let sum_from_stat = accumulator
125 .as_primitive()
126 .checked_add(&sum_scalar.as_primitive())
127 .map(Scalar::from);
128 return Ok(sum_from_stat
129 .unwrap_or_else(|| Scalar::null(sum_dtype))
130 .into());
131 }
132 }
133 DType::Decimal(..) => {
134 let sum_from_stat = accumulator
135 .as_decimal()
136 .checked_binary_numeric(&sum_scalar.as_decimal(), NumericOperator::Add)
137 .map(Scalar::from);
138 return Ok(sum_from_stat
139 .unwrap_or_else(|| Scalar::null(sum_dtype))
140 .into());
141 }
142 _ => unreachable!("Sum will always be a decimal or a primitive dtype"),
143 }
144 }
145
146 let sum_scalar = sum_impl(&array, accumulator, kernels)?;
147
148 match sum_dtype {
150 DType::Primitive(p, _) => {
151 if p.is_float()
152 && accumulator.is_zero() == Some(true)
153 && let Some(sum_value) = sum_scalar.value().cloned()
154 {
155 array
156 .statistics()
157 .set(Stat::Sum, Precision::Exact(sum_value));
158 } else if p.is_int()
159 && let Some(less_accumulator) = sum_scalar
160 .as_primitive()
161 .checked_sub(&accumulator.as_primitive())
162 && let Some(val) = Scalar::from(less_accumulator).into_value()
163 {
164 array.statistics().set(Stat::Sum, Precision::Exact(val));
165 }
166 }
167 DType::Decimal(..) => {
168 if let Some(less_accumulator) = sum_scalar
169 .as_decimal()
170 .checked_binary_numeric(&accumulator.as_decimal(), NumericOperator::Sub)
171 && let Some(val) = Scalar::from(less_accumulator).into_value()
172 {
173 array.statistics().set(Stat::Sum, Precision::Exact(val));
174 }
175 }
176 _ => unreachable!("Sum will always be a decimal or a primitive dtype"),
177 }
178
179 Ok(sum_scalar.into())
180 }
181
182 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
183 let SumArgs { array, .. } = args.try_into()?;
184 Stat::Sum
185 .dtype(array.dtype())
186 .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))
187 }
188
189 fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
190 Ok(1)
192 }
193
194 fn is_elementwise(&self) -> bool {
195 false
196 }
197}
198
199pub struct SumKernelRef(ArcRef<dyn Kernel>);
200inventory::collect!(SumKernelRef);
201
202pub trait SumKernel: VTable {
203 fn sum(&self, array: &Self::Array, accumulator: &Scalar) -> VortexResult<Scalar>;
209}
210
211#[derive(Debug)]
212pub struct SumKernelAdapter<V: VTable>(pub V);
213
214impl<V: VTable + SumKernel> SumKernelAdapter<V> {
215 pub const fn lift(&'static self) -> SumKernelRef {
216 SumKernelRef(ArcRef::new_ref(self))
217 }
218}
219
220impl<V: VTable + SumKernel> Kernel for SumKernelAdapter<V> {
221 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
222 let SumArgs { array, accumulator } = args.try_into()?;
223 let Some(array) = array.as_opt::<V>() else {
224 return Ok(None);
225 };
226 Ok(Some(V::sum(&self.0, array, accumulator)?.into()))
227 }
228}
229
230pub fn sum_impl(
236 array: &ArrayRef,
237 accumulator: &Scalar,
238 kernels: &[ArcRef<dyn Kernel>],
239) -> VortexResult<Scalar> {
240 if array.is_empty() || array.all_invalid()? || accumulator.is_null() {
241 return Ok(accumulator.clone());
242 }
243
244 let args = InvocationArgs {
246 inputs: &[array.into(), accumulator.into()],
247 options: &(),
248 };
249 for kernel in kernels {
250 if let Some(output) = kernel.invoke(&args)? {
251 return output.unwrap_scalar();
252 }
253 }
254
255 tracing::debug!("No sum implementation found for {}", array.encoding_id());
257 if array.is_canonical() {
258 vortex_panic!(
260 "No sum implementation found for canonical array: {}",
261 array.encoding_id()
262 );
263 }
264 let canonical = array.to_canonical()?.into_array();
265 sum_with_accumulator(&canonical, accumulator)
266}
267
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use crate::IntoArray as _;
    use crate::arrays::BoolArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::PrimitiveArray;
    use crate::compute::sum;
    use crate::compute::sum_with_accumulator;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;

    // An all-null i32 array sums to a nullable i64 zero.
    #[test]
    fn sum_all_invalid() {
        let nulls = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]).into_array();
        let expected = Scalar::primitive(0i64, Nullability::Nullable);
        assert_eq!(sum(&nulls).unwrap(), expected);
    }

    // An all-null f32 array sums to a nullable f64 zero.
    #[test]
    fn sum_all_invalid_float() {
        let nulls = PrimitiveArray::from_option_iter::<f32, _>([None, None, None]).into_array();
        let expected = Scalar::primitive(0f64, Nullability::Nullable);
        assert_eq!(sum(&nulls).unwrap(), expected);
    }

    #[test]
    fn sum_constant() {
        let ones = buffer![1, 1, 1, 1].into_array();
        let total = sum(&ones).unwrap();
        assert_eq!(total.as_primitive().as_::<i32>(), Some(4));
    }

    #[test]
    fn sum_constant_float() {
        let ones = buffer![1., 1., 1., 1.].into_array();
        let total = sum(&ones).unwrap();
        assert_eq!(total.as_primitive().as_::<f32>(), Some(4.));
    }

    // Summing booleans counts the `true` values.
    #[test]
    fn sum_boolean() {
        let flags = BoolArray::from_iter([true, false, false, true]).into_array();
        let total = sum(&flags).unwrap();
        assert_eq!(total.as_primitive().as_::<i32>(), Some(2));
    }

    // A sum computed with a non-zero accumulator must not leak that accumulator into a later
    // plain `sum` of the same array.
    #[test]
    fn sum_stats() {
        let chunks = vec![
            PrimitiveArray::from_iter([1, 1, 1]).into_array(),
            PrimitiveArray::from_iter([2, 2, 2]).into_array(),
        ];
        let chunked = ChunkedArray::try_new(
            chunks,
            DType::Primitive(PType::I32, Nullability::NonNullable),
        )
        .vortex_expect("operation should succeed in test")
        .into_array();

        let start = Scalar::primitive(2i64, Nullability::Nullable);
        sum_with_accumulator(&chunked, &start).unwrap();

        let expected = Scalar::primitive(9i64, Nullability::Nullable);
        assert_eq!(sum(&chunked).unwrap(), expected);
    }
}