vortex_array/compute/
sum.rs1use std::sync::LazyLock;
2
3use arcref::ArcRef;
4use vortex_dtype::{DType, PType};
5use vortex_error::{VortexExpect, VortexResult, vortex_err, vortex_panic};
6use vortex_scalar::Scalar;
7
8use crate::Array;
9use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output, UnaryArgs};
10use crate::stats::{Precision, Stat, StatsProvider};
11use crate::vtable::VTable;
12
13pub fn sum(array: &dyn Array) -> VortexResult<Scalar> {
19 SUM_FN
20 .invoke(&InvocationArgs {
21 inputs: &[array.into()],
22 options: &(),
23 })?
24 .unwrap_scalar()
25}
26
27struct Sum;
28
29impl ComputeFnVTable for Sum {
30 fn invoke(
31 &self,
32 args: &InvocationArgs,
33 kernels: &[ArcRef<dyn Kernel>],
34 ) -> VortexResult<Output> {
35 let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
36
37 let sum_dtype = self.return_dtype(args)?;
39
40 if let Some(Precision::Exact(sum)) = array.statistics().get(Stat::Sum) {
42 return Ok(Scalar::new(sum_dtype, sum).into());
43 }
44
45 let sum_scalar = sum_impl(array, sum_dtype, kernels)?;
46
47 array
49 .statistics()
50 .set(Stat::Sum, Precision::Exact(sum_scalar.value().clone()));
51
52 Ok(sum_scalar.into())
53 }
54
55 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
56 let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
57 Stat::Sum
58 .dtype(array.dtype())
59 .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))
60 }
61
62 fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
63 Ok(1)
65 }
66
67 fn is_elementwise(&self) -> bool {
68 false
69 }
70}
71
72pub static SUM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
73 let compute = ComputeFn::new("sum".into(), ArcRef::new_ref(&Sum));
74 for kernel in inventory::iter::<SumKernelRef> {
75 compute.register_kernel(kernel.0.clone());
76 }
77 compute
78});
79
/// A type-erased sum kernel registration, typically produced by [`SumKernelAdapter::lift`].
///
/// Instances are collected via `inventory` and registered on [`SUM_FN`] when it is first
/// initialized.
pub struct SumKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(SumKernelRef);
82
/// An encoding-specific implementation of [`sum`] for arrays of the vtable `Self`.
///
/// Wrap an implementation in [`SumKernelAdapter`] and [`SumKernelAdapter::lift`] it to register
/// it with [`SUM_FN`].
pub trait SumKernel: VTable {
    /// Computes the sum of `array`.
    ///
    /// NOTE(review): the returned scalar is handed straight back to callers of [`sum`] and its
    /// value is cached as the array's `Stat::Sum`. It should therefore presumably have the dtype
    /// `Stat::Sum.dtype(array.dtype())`; confirm against the existing kernels.
    fn sum(&self, array: &Self::Array) -> VortexResult<Scalar>;
}
90
91#[derive(Debug)]
92pub struct SumKernelAdapter<V: VTable>(pub V);
93
94impl<V: VTable + SumKernel> SumKernelAdapter<V> {
95 pub const fn lift(&'static self) -> SumKernelRef {
96 SumKernelRef(ArcRef::new_ref(self))
97 }
98}
99
100impl<V: VTable + SumKernel> Kernel for SumKernelAdapter<V> {
101 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
102 let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
103 let Some(array) = array.as_opt::<V>() else {
104 return Ok(None);
105 };
106 Ok(Some(V::sum(&self.0, array)?.into()))
107 }
108}
109
110pub fn sum_impl(
116 array: &dyn Array,
117 sum_dtype: DType,
118 kernels: &[ArcRef<dyn Kernel>],
119) -> VortexResult<Scalar> {
120 if array.is_empty() {
121 return if sum_dtype.is_float() {
122 Ok(Scalar::new(sum_dtype, 0.0.into()))
123 } else {
124 Ok(Scalar::new(sum_dtype, 0.into()))
125 };
126 }
127
128 if let Some(mut constant) = array.as_constant() {
130 if constant.is_null() {
131 return if sum_dtype.is_float() {
133 Ok(Scalar::new(sum_dtype, 0.0.into()))
134 } else {
135 Ok(Scalar::new(sum_dtype, 0.into()))
136 };
137 }
138
139 if let Some(extension) = constant.as_extension_opt() {
143 constant = extension.storage();
144 }
145
146 if let Some(bool) = constant.as_bool_opt() {
148 return if bool.value().vortex_expect("already checked for null value") {
149 Ok(Scalar::new(sum_dtype, array.len().into()))
151 } else {
152 Ok(Scalar::new(sum_dtype, 0.into()))
154 };
155 }
156
157 if let Some(primitive) = constant.as_primitive_opt() {
159 match primitive.ptype() {
160 PType::U8 | PType::U16 | PType::U32 | PType::U64 => {
161 let value = primitive
162 .pvalue()
163 .vortex_expect("already checked for null value")
164 .as_u64()
165 .vortex_expect("Failed to cast constant value to u64");
166
167 let sum = value.checked_mul(array.len() as u64);
169
170 return Ok(Scalar::new(sum_dtype, sum.into()));
171 }
172 PType::I8 | PType::I16 | PType::I32 | PType::I64 => {
173 let value = primitive
174 .pvalue()
175 .vortex_expect("already checked for null value")
176 .as_i64()
177 .vortex_expect("Failed to cast constant value to i64");
178
179 let sum = value.checked_mul(array.len() as i64);
181
182 return Ok(Scalar::new(sum_dtype, sum.into()));
183 }
184 PType::F16 | PType::F32 | PType::F64 => {
185 let value = primitive
186 .pvalue()
187 .vortex_expect("already checked for null value")
188 .as_f64()
189 .vortex_expect("Failed to cast constant value to f64");
190
191 let sum = value * (array.len() as f64);
192
193 return Ok(Scalar::new(sum_dtype, sum.into()));
194 }
195 }
196 }
197
198 unreachable!("Unsupported sum constant: {}", constant.dtype());
200 }
201
202 let args = InvocationArgs {
204 inputs: &[array.into()],
205 options: &(),
206 };
207 for kernel in kernels {
208 if let Some(output) = kernel.invoke(&args)? {
209 return output.unwrap_scalar();
210 }
211 }
212 if let Some(output) = array.invoke(&SUM_FN, &args)? {
213 return output.unwrap_scalar();
214 }
215
216 log::debug!("No sum implementation found for {}", array.encoding_id());
218 if array.is_canonical() {
219 vortex_panic!(
221 "No sum implementation found for canonical array: {}",
222 array.encoding_id()
223 );
224 }
225 sum(array.to_canonical()?.as_ref())
226}
227
#[cfg(test)]
mod test {
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::sum;

    // An empty array takes the early-return path and must sum to zero.
    #[test]
    fn sum_empty() {
        let array = PrimitiveArray::from_iter(Vec::<i32>::new());
        let result = sum(array.as_ref()).unwrap();
        assert_eq!(result.as_primitive().as_::<i32>().unwrap(), Some(0));
    }

    #[test]
    fn sum_all_invalid() {
        let array = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]);
        let result = sum(array.as_ref()).unwrap();
        assert_eq!(result.as_primitive().as_::<i32>().unwrap(), Some(0));
    }

    #[test]
    fn sum_all_invalid_float() {
        let array = PrimitiveArray::from_option_iter::<f32, _>([None, None, None]);
        let result = sum(array.as_ref()).unwrap();
        assert_eq!(result.as_primitive().as_::<f32>().unwrap(), Some(0.0));
    }

    #[test]
    fn sum_constant() {
        let array = PrimitiveArray::from_iter([1, 1, 1, 1]);
        let result = sum(array.as_ref()).unwrap();
        assert_eq!(result.as_primitive().as_::<i32>().unwrap(), Some(4));
    }

    #[test]
    fn sum_constant_float() {
        let array = PrimitiveArray::from_iter([1., 1., 1., 1.]);
        let result = sum(array.as_ref()).unwrap();
        assert_eq!(result.as_primitive().as_::<f32>().unwrap(), Some(4.));
    }

    #[test]
    fn sum_boolean() {
        let array = BoolArray::from_iter([true, false, false, true]);
        let result = sum(array.as_ref()).unwrap();
        assert_eq!(result.as_primitive().as_::<i32>().unwrap(), Some(2));
    }

    // An all-false boolean array counts no `true` values.
    #[test]
    fn sum_boolean_all_false() {
        let array = BoolArray::from_iter([false, false, false]);
        let result = sum(array.as_ref()).unwrap();
        assert_eq!(result.as_primitive().as_::<i32>().unwrap(), Some(0));
    }
}