1use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use num_traits::CheckedAdd;
8use num_traits::CheckedSub;
9use vortex_dtype::DType;
10use vortex_error::VortexError;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_ensure;
14use vortex_error::vortex_err;
15use vortex_error::vortex_panic;
16use vortex_scalar::NumericOperator;
17use vortex_scalar::Scalar;
18
19use crate::Array;
20use crate::compute::ComputeFn;
21use crate::compute::ComputeFnVTable;
22use crate::compute::InvocationArgs;
23use crate::compute::Kernel;
24use crate::compute::Output;
25use crate::expr::stats::Precision;
26use crate::expr::stats::Stat;
27use crate::expr::stats::StatsProvider;
28use crate::vtable::VTable;
29
30static SUM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
31 let compute = ComputeFn::new("sum".into(), ArcRef::new_ref(&Sum));
32 for kernel in inventory::iter::<SumKernelRef> {
33 compute.register_kernel(kernel.0.clone());
34 }
35 compute
36});
37
38pub(crate) fn warm_up_vtable() -> usize {
39 SUM_FN.kernels().len()
40}
41
42pub(crate) fn sum_with_accumulator(
49 array: &dyn Array,
50 accumulator: &Scalar,
51) -> VortexResult<Scalar> {
52 SUM_FN
53 .invoke(&InvocationArgs {
54 inputs: &[array.into(), accumulator.into()],
55 options: &(),
56 })?
57 .unwrap_scalar()
58}
59
60pub fn sum(array: &dyn Array) -> VortexResult<Scalar> {
66 let sum_dtype = Stat::Sum
67 .dtype(array.dtype())
68 .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))?;
69 let zero = Scalar::zero_value(sum_dtype);
70 sum_with_accumulator(array, &zero)
71}
72
73pub struct SumArgs<'a> {
75 pub array: &'a dyn Array,
76 pub accumulator: &'a Scalar,
77}
78
79impl<'a> TryFrom<&InvocationArgs<'a>> for SumArgs<'a> {
80 type Error = VortexError;
81
82 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
83 if value.inputs.len() != 2 {
84 vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
85 }
86 let array = value.inputs[0]
87 .array()
88 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
89 let accumulator = value.inputs[1]
90 .scalar()
91 .ok_or_else(|| vortex_err!("Expected input 1 to be a scalar"))?;
92 Ok(SumArgs { array, accumulator })
93 }
94}
95
96struct Sum;
97
98impl ComputeFnVTable for Sum {
99 fn invoke(
100 &self,
101 args: &InvocationArgs,
102 kernels: &[ArcRef<dyn Kernel>],
103 ) -> VortexResult<Output> {
104 let SumArgs { array, accumulator } = args.try_into()?;
105
106 let sum_dtype = self.return_dtype(args)?;
108
109 vortex_ensure!(
110 &sum_dtype == accumulator.dtype(),
111 "sum_dtype {sum_dtype} must match accumulator dtype {}",
112 accumulator.dtype()
113 );
114
115 if let Some(Precision::Exact(sum)) = array.statistics().get(Stat::Sum) {
117 match sum_dtype {
119 DType::Primitive(p, _) => {
120 if p.is_float() && accumulator.is_zero() {
121 return Ok(sum.into());
122 } else if p.is_int() {
123 let sum_from_stat = accumulator
124 .as_primitive()
125 .checked_add(&sum.as_primitive())
126 .map(Scalar::from);
127 return Ok(sum_from_stat
128 .unwrap_or_else(|| Scalar::null(sum_dtype))
129 .into());
130 }
131 }
132 DType::Decimal(..) => {
133 let sum_from_stat = accumulator
134 .as_decimal()
135 .checked_binary_numeric(&sum.as_decimal(), NumericOperator::Add)
136 .map(Scalar::from);
137 return Ok(sum_from_stat
138 .unwrap_or_else(|| Scalar::null(sum_dtype))
139 .into());
140 }
141 _ => unreachable!("Sum will always be a decimal or a primitive dtype"),
142 }
143 }
144
145 let sum_scalar = sum_impl(array, accumulator, kernels)?;
146
147 match sum_dtype {
149 DType::Primitive(p, _) => {
150 if p.is_float() && accumulator.is_zero() {
151 array
152 .statistics()
153 .set(Stat::Sum, Precision::Exact(sum_scalar.value().clone()));
154 } else if p.is_int()
155 && let Some(less_accumulator) = sum_scalar
156 .as_primitive()
157 .checked_sub(&accumulator.as_primitive())
158 {
159 array.statistics().set(
160 Stat::Sum,
161 Precision::Exact(Scalar::from(less_accumulator).value().clone()),
162 );
163 }
164 }
165 DType::Decimal(..) => {
166 if let Some(less_accumulator) = sum_scalar
167 .as_decimal()
168 .checked_binary_numeric(&accumulator.as_decimal(), NumericOperator::Sub)
169 {
170 array.statistics().set(
171 Stat::Sum,
172 Precision::Exact(Scalar::from(less_accumulator).value().clone()),
173 )
174 }
175 }
176 _ => unreachable!("Sum will always be a decimal or a primitive dtype"),
177 }
178
179 Ok(sum_scalar.into())
180 }
181
182 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
183 let SumArgs { array, .. } = args.try_into()?;
184 Stat::Sum
185 .dtype(array.dtype())
186 .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))
187 }
188
189 fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
190 Ok(1)
192 }
193
194 fn is_elementwise(&self) -> bool {
195 false
196 }
197}
198
199pub struct SumKernelRef(ArcRef<dyn Kernel>);
200inventory::collect!(SumKernelRef);
201
202pub trait SumKernel: VTable {
203 fn sum(&self, array: &Self::Array, accumulator: &Scalar) -> VortexResult<Scalar>;
209}
210
211#[derive(Debug)]
212pub struct SumKernelAdapter<V: VTable>(pub V);
213
214impl<V: VTable + SumKernel> SumKernelAdapter<V> {
215 pub const fn lift(&'static self) -> SumKernelRef {
216 SumKernelRef(ArcRef::new_ref(self))
217 }
218}
219
220impl<V: VTable + SumKernel> Kernel for SumKernelAdapter<V> {
221 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
222 let SumArgs { array, accumulator } = args.try_into()?;
223 let Some(array) = array.as_opt::<V>() else {
224 return Ok(None);
225 };
226 Ok(Some(V::sum(&self.0, array, accumulator)?.into()))
227 }
228}
229
230pub fn sum_impl(
236 array: &dyn Array,
237 accumulator: &Scalar,
238 kernels: &[ArcRef<dyn Kernel>],
239) -> VortexResult<Scalar> {
240 if array.is_empty() || array.all_invalid() || accumulator.is_null() {
241 return Ok(accumulator.clone());
242 }
243
244 let args = InvocationArgs {
246 inputs: &[array.into(), accumulator.into()],
247 options: &(),
248 };
249 for kernel in kernels {
250 if let Some(output) = kernel.invoke(&args)? {
251 return output.unwrap_scalar();
252 }
253 }
254 if let Some(output) = array.invoke(&SUM_FN, &args)? {
255 return output.unwrap_scalar();
256 }
257
258 log::debug!("No sum implementation found for {}", array.encoding_id());
260 if array.is_canonical() {
261 vortex_panic!(
263 "No sum implementation found for canonical array: {}",
264 array.encoding_id()
265 );
266 }
267 sum_with_accumulator(array.to_canonical().as_ref(), accumulator)
268}
269
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_error::VortexUnwrap;
    use vortex_scalar::Scalar;

    use crate::IntoArray as _;
    use crate::arrays::BoolArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::PrimitiveArray;
    use crate::compute::sum;
    use crate::compute::sum_with_accumulator;

    /// An all-null integer array sums to the zero accumulator.
    #[test]
    fn sum_all_invalid() {
        let all_null = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]);
        let expected = Scalar::primitive(0i64, Nullability::Nullable);
        let total = sum(all_null.as_ref()).unwrap();
        assert_eq!(total, expected);
    }

    /// An all-null float array sums to the zero accumulator.
    #[test]
    fn sum_all_invalid_float() {
        let all_null = PrimitiveArray::from_option_iter::<f32, _>([None, None, None]);
        let expected = Scalar::primitive(0f64, Nullability::Nullable);
        let total = sum(all_null.as_ref()).unwrap();
        assert_eq!(total, expected);
    }

    /// Four integer ones sum to four.
    #[test]
    fn sum_constant() {
        let ones = buffer![1, 1, 1, 1].into_array();
        let total = sum(ones.as_ref()).unwrap();
        assert_eq!(total.as_primitive().as_::<i32>(), Some(4));
    }

    /// Four float ones sum to four.
    #[test]
    fn sum_constant_float() {
        let ones = buffer![1., 1., 1., 1.].into_array();
        let total = sum(ones.as_ref()).unwrap();
        assert_eq!(total.as_primitive().as_::<f32>(), Some(4.));
    }

    /// Summing booleans counts the `true` values.
    #[test]
    fn sum_boolean() {
        let flags = BoolArray::from_iter([true, false, false, true]);
        let total = sum(flags.as_ref()).unwrap();
        assert_eq!(total.as_primitive().as_::<i32>(), Some(2));
    }

    /// The cached sum statistic must not include the accumulator it was computed with.
    #[test]
    fn sum_stats() {
        let chunks = vec![
            PrimitiveArray::from_iter([1, 1, 1]).into_array(),
            PrimitiveArray::from_iter([2, 2, 2]).into_array(),
        ];
        let chunked = ChunkedArray::try_new(
            chunks,
            DType::Primitive(PType::I32, Nullability::NonNullable),
        )
        .vortex_unwrap();

        // First pass uses a non-zero accumulator; this is what populates the statistic.
        let start = Scalar::primitive(2i64, Nullability::Nullable);
        sum_with_accumulator(chunked.as_ref(), &start).unwrap();

        // Second pass starts from zero and must report the array's own sum.
        let sum_without_acc = sum(chunked.as_ref()).unwrap();
        assert_eq!(
            sum_without_acc,
            Scalar::primitive(9i64, Nullability::Nullable)
        );
    }
}