vortex_array/compute/
sum.rs1use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::{DType, PType};
8use vortex_error::{VortexExpect, VortexResult, vortex_err, vortex_panic};
9use vortex_scalar::Scalar;
10
11use crate::Array;
12use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output, UnaryArgs};
13use crate::stats::{Precision, Stat, StatsProvider};
14use crate::vtable::VTable;
15
16pub fn sum(array: &dyn Array) -> VortexResult<Scalar> {
22 SUM_FN
23 .invoke(&InvocationArgs {
24 inputs: &[array.into()],
25 options: &(),
26 })?
27 .unwrap_scalar()
28}
29
30struct Sum;
31
32impl ComputeFnVTable for Sum {
33 fn invoke(
34 &self,
35 args: &InvocationArgs,
36 kernels: &[ArcRef<dyn Kernel>],
37 ) -> VortexResult<Output> {
38 let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
39
40 let sum_dtype = self.return_dtype(args)?;
42
43 if let Some(Precision::Exact(sum)) = array.statistics().get(Stat::Sum) {
45 return Ok(Scalar::new(sum_dtype, sum).into());
46 }
47
48 let sum_scalar = sum_impl(array, sum_dtype, kernels)?;
49
50 array
52 .statistics()
53 .set(Stat::Sum, Precision::Exact(sum_scalar.value().clone()));
54
55 Ok(sum_scalar.into())
56 }
57
58 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
59 let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
60 Stat::Sum
61 .dtype(array.dtype())
62 .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))
63 }
64
65 fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
66 Ok(1)
68 }
69
70 fn is_elementwise(&self) -> bool {
71 false
72 }
73}
74
75pub static SUM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
76 let compute = ComputeFn::new("sum".into(), ArcRef::new_ref(&Sum));
77 for kernel in inventory::iter::<SumKernelRef> {
78 compute.register_kernel(kernel.0.clone());
79 }
80 compute
81});
82
/// A type-erased sum kernel registered with `inventory`.
///
/// Every `SumKernelRef` collected below is registered on [`SUM_FN`] when that static is first
/// initialized. Instances are typically produced by [`SumKernelAdapter::lift`].
pub struct SumKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(SumKernelRef);
85
/// Encoding-specific implementation of `sum`.
///
/// Implemented on an encoding's [`VTable`]. [`SumKernelAdapter`] invokes it when the input array
/// downcasts to `Self::Array`.
pub trait SumKernel: VTable {
    /// Computes the sum of `array`.
    ///
    /// NOTE(review): the returned scalar presumably needs to use the `Stat::Sum` dtype of the
    /// array's dtype. [`Sum`]'s invoke caches the value of the kernel result under that dtype.
    fn sum(&self, array: &Self::Array) -> VortexResult<Scalar>;
}
93
/// Adapts a [`SumKernel`] implementation on vtable `V` into a type-erased [`Kernel`].
#[derive(Debug)]
pub struct SumKernelAdapter<V: VTable>(pub V);

impl<V: VTable + SumKernel> SumKernelAdapter<V> {
    /// Wraps a `'static` adapter into a [`SumKernelRef`].
    ///
    /// This is a `const fn`, so it can be used in static contexts. Presumably that means
    /// `inventory::submit!` registration; confirm at call sites.
    pub const fn lift(&'static self) -> SumKernelRef {
        SumKernelRef(ArcRef::new_ref(self))
    }
}
102
103impl<V: VTable + SumKernel> Kernel for SumKernelAdapter<V> {
104 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
105 let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
106 let Some(array) = array.as_opt::<V>() else {
107 return Ok(None);
108 };
109 Ok(Some(V::sum(&self.0, array)?.into()))
110 }
111}
112
113pub fn sum_impl(
119 array: &dyn Array,
120 sum_dtype: DType,
121 kernels: &[ArcRef<dyn Kernel>],
122) -> VortexResult<Scalar> {
123 if array.is_empty() {
124 return if sum_dtype.is_float() {
125 Ok(Scalar::new(sum_dtype, 0.0.into()))
126 } else {
127 Ok(Scalar::new(sum_dtype, 0.into()))
128 };
129 }
130
131 if let Some(mut constant) = array.as_constant() {
133 if constant.is_null() {
134 return if sum_dtype.is_float() {
136 Ok(Scalar::new(sum_dtype, 0.0.into()))
137 } else {
138 Ok(Scalar::new(sum_dtype, 0.into()))
139 };
140 }
141
142 if let Some(extension) = constant.as_extension_opt() {
146 constant = extension.storage();
147 }
148
149 if let Some(bool) = constant.as_bool_opt() {
151 return if bool.value().vortex_expect("already checked for null value") {
152 Ok(Scalar::new(sum_dtype, array.len().into()))
154 } else {
155 Ok(Scalar::new(sum_dtype, 0.into()))
157 };
158 }
159
160 if let Some(primitive) = constant.as_primitive_opt() {
162 match primitive.ptype() {
163 PType::U8 | PType::U16 | PType::U32 | PType::U64 => {
164 let value = primitive
165 .pvalue()
166 .vortex_expect("already checked for null value")
167 .as_u64()
168 .vortex_expect("Failed to cast constant value to u64");
169
170 let sum = value.checked_mul(array.len() as u64);
172
173 return Ok(Scalar::new(sum_dtype, sum.into()));
174 }
175 PType::I8 | PType::I16 | PType::I32 | PType::I64 => {
176 let value = primitive
177 .pvalue()
178 .vortex_expect("already checked for null value")
179 .as_i64()
180 .vortex_expect("Failed to cast constant value to i64");
181
182 let sum = value.checked_mul(array.len() as i64);
184
185 return Ok(Scalar::new(sum_dtype, sum.into()));
186 }
187 PType::F16 | PType::F32 | PType::F64 => {
188 let value = primitive
189 .pvalue()
190 .vortex_expect("already checked for null value")
191 .as_f64()
192 .vortex_expect("Failed to cast constant value to f64");
193
194 let sum = value * (array.len() as f64);
195
196 return Ok(Scalar::new(sum_dtype, sum.into()));
197 }
198 }
199 }
200
201 unreachable!("Unsupported sum constant: {}", constant.dtype());
203 }
204
205 let args = InvocationArgs {
207 inputs: &[array.into()],
208 options: &(),
209 };
210 for kernel in kernels {
211 if let Some(output) = kernel.invoke(&args)? {
212 return output.unwrap_scalar();
213 }
214 }
215 if let Some(output) = array.invoke(&SUM_FN, &args)? {
216 return output.unwrap_scalar();
217 }
218
219 log::debug!("No sum implementation found for {}", array.encoding_id());
221 if array.is_canonical() {
222 vortex_panic!(
224 "No sum implementation found for canonical array: {}",
225 array.encoding_id()
226 );
227 }
228 sum(array.to_canonical()?.as_ref())
229}
230
#[cfg(test)]
mod test {
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::sum;

    #[test]
    fn sum_all_invalid() {
        let array = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]);
        let result = sum(array.as_ref()).unwrap();
        assert_eq!(result.as_primitive().as_::<i32>().unwrap(), Some(0));
    }

    #[test]
    fn sum_all_invalid_float() {
        let array = PrimitiveArray::from_option_iter::<f32, _>([None, None, None]);
        let result = sum(array.as_ref()).unwrap();
        assert_eq!(result.as_primitive().as_::<f32>().unwrap(), Some(0.0));
    }

    #[test]
    fn sum_constant() {
        let array = PrimitiveArray::from_iter([1, 1, 1, 1]);
        let result = sum(array.as_ref()).unwrap();
        assert_eq!(result.as_primitive().as_::<i32>().unwrap(), Some(4));
    }

    #[test]
    fn sum_constant_float() {
        let array = PrimitiveArray::from_iter([1., 1., 1., 1.]);
        let result = sum(array.as_ref()).unwrap();
        assert_eq!(result.as_primitive().as_::<f32>().unwrap(), Some(4.));
    }

    #[test]
    fn sum_boolean() {
        let array = BoolArray::from_iter([true, false, false, true]);
        let result = sum(array.as_ref()).unwrap();
        assert_eq!(result.as_primitive().as_::<i32>().unwrap(), Some(2));
    }

    /// An empty array takes the early-return path and sums to zero.
    #[test]
    fn sum_empty() {
        let array = PrimitiveArray::from_iter(Vec::<i32>::new());
        let result = sum(array.as_ref()).unwrap();
        assert_eq!(result.as_primitive().as_::<i32>().unwrap(), Some(0));
    }

    /// Integer overflow is reported as a null sum rather than wrapping or panicking.
    #[test]
    fn sum_constant_overflow_is_null() {
        let array = PrimitiveArray::from_iter([u64::MAX, u64::MAX]);
        let result = sum(array.as_ref()).unwrap();
        assert!(result.is_null());
    }
}