vortex_array/compute/
sum.rs1use std::sync::LazyLock;
2
3use vortex_dtype::{DType, PType};
4use vortex_error::{
5 VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err, vortex_panic,
6};
7use vortex_scalar::Scalar;
8
9use crate::Array;
10use crate::arcref::ArcRef;
11use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output};
12use crate::encoding::Encoding;
13use crate::stats::{Precision, Stat, StatsProvider};
14
15pub fn sum(array: &dyn Array) -> VortexResult<Scalar> {
21 SUM_FN
22 .invoke(&InvocationArgs {
23 inputs: &[array.into()],
24 options: &(),
25 })?
26 .unwrap_scalar()
27}
28
29struct Sum;
30
31impl ComputeFnVTable for Sum {
32 fn invoke(
33 &self,
34 args: &InvocationArgs,
35 kernels: &[ArcRef<dyn Kernel>],
36 ) -> VortexResult<Output> {
37 let SumArgs { array } = SumArgs::try_from(args)?;
38
39 let sum_dtype = self.return_dtype(args)?;
41
42 if let Some(Precision::Exact(sum)) = array.statistics().get(Stat::Sum) {
44 return Ok(Scalar::new(sum_dtype, sum).into());
45 }
46
47 let sum_scalar = sum_impl(array, sum_dtype, kernels)?;
48
49 array
51 .statistics()
52 .set(Stat::Sum, Precision::Exact(sum_scalar.value().clone()));
53
54 Ok(sum_scalar.into())
55 }
56
57 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
58 let SumArgs { array } = SumArgs::try_from(args)?;
59 Stat::Sum
60 .dtype(array.dtype())
61 .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))
62 }
63
64 fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
65 Ok(1)
67 }
68
69 fn is_elementwise(&self) -> bool {
70 false
71 }
72}
73
74struct SumArgs<'a> {
75 array: &'a dyn Array,
76}
77
78impl<'a> TryFrom<&InvocationArgs<'a>> for SumArgs<'a> {
79 type Error = VortexError;
80
81 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
82 if value.inputs.len() != 1 {
83 vortex_bail!(
84 "Sum function requires exactly one argument, got {}",
85 value.inputs.len()
86 );
87 }
88 let array = value.inputs[0]
89 .array()
90 .ok_or_else(|| vortex_err!("Invalid argument type for sum function"))?;
91
92 Ok(SumArgs { array })
93 }
94}
95
96pub static SUM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
97 let compute = ComputeFn::new("sum".into(), ArcRef::new_ref(&Sum));
98 for kernel in inventory::iter::<SumKernelRef> {
99 compute.register_kernel(kernel.0.clone());
100 }
101 compute
102});
103
104pub struct SumKernelRef(ArcRef<dyn Kernel>);
105inventory::collect!(SumKernelRef);
106
107pub trait SumKernel: Encoding {
108 fn sum(&self, array: &Self::Array) -> VortexResult<Scalar>;
113}
114
115#[derive(Debug)]
116pub struct SumKernelAdapter<E: Encoding>(pub E);
117
118impl<E: Encoding + SumKernel> SumKernelAdapter<E> {
119 pub const fn lift(&'static self) -> SumKernelRef {
120 SumKernelRef(ArcRef::new_ref(self))
121 }
122}
123
124impl<E: Encoding + SumKernel> Kernel for SumKernelAdapter<E> {
125 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
126 let SumArgs { array } = SumArgs::try_from(args)?;
127 let Some(array) = array.as_any().downcast_ref::<E::Array>() else {
128 return Ok(None);
129 };
130 Ok(Some(E::sum(&self.0, array)?.into()))
131 }
132}
133
134pub fn sum_impl(
140 array: &dyn Array,
141 sum_dtype: DType,
142 kernels: &[ArcRef<dyn Kernel>],
143) -> VortexResult<Scalar> {
144 if array.is_empty() {
145 return if sum_dtype.is_float() {
146 Ok(Scalar::new(sum_dtype, 0.0.into()))
147 } else {
148 Ok(Scalar::new(sum_dtype, 0.into()))
149 };
150 }
151
152 if let Some(mut constant) = array.as_constant() {
154 if constant.is_null() {
155 return if sum_dtype.is_float() {
157 Ok(Scalar::new(sum_dtype, 0.0.into()))
158 } else {
159 Ok(Scalar::new(sum_dtype, 0.into()))
160 };
161 }
162
163 if let Some(extension) = constant.as_extension_opt() {
167 constant = extension.storage();
168 }
169
170 if let Some(bool) = constant.as_bool_opt() {
172 return if bool.value().vortex_expect("already checked for null value") {
173 Ok(Scalar::new(sum_dtype, array.len().into()))
175 } else {
176 Ok(Scalar::new(sum_dtype, 0.into()))
178 };
179 }
180
181 if let Some(primitive) = constant.as_primitive_opt() {
183 match primitive.ptype() {
184 PType::U8 | PType::U16 | PType::U32 | PType::U64 => {
185 let value = primitive
186 .pvalue()
187 .vortex_expect("already checked for null value")
188 .as_u64()
189 .vortex_expect("Failed to cast constant value to u64");
190
191 let sum = value.checked_mul(array.len() as u64);
193
194 return Ok(Scalar::new(sum_dtype, sum.into()));
195 }
196 PType::I8 | PType::I16 | PType::I32 | PType::I64 => {
197 let value = primitive
198 .pvalue()
199 .vortex_expect("already checked for null value")
200 .as_i64()
201 .vortex_expect("Failed to cast constant value to i64");
202
203 let sum = value.checked_mul(array.len() as i64);
205
206 return Ok(Scalar::new(sum_dtype, sum.into()));
207 }
208 PType::F16 | PType::F32 | PType::F64 => {
209 let value = primitive
210 .pvalue()
211 .vortex_expect("already checked for null value")
212 .as_f64()
213 .vortex_expect("Failed to cast constant value to f64");
214
215 let sum = value * (array.len() as f64);
216
217 return Ok(Scalar::new(sum_dtype, sum.into()));
218 }
219 }
220 }
221
222 unreachable!("Unsupported sum constant: {}", constant.dtype());
224 }
225
226 let args = InvocationArgs {
228 inputs: &[array.into()],
229 options: &(),
230 };
231 for kernel in kernels {
232 if let Some(output) = kernel.invoke(&args)? {
233 return output.unwrap_scalar();
234 }
235 }
236 if let Some(output) = array.invoke(&SUM_FN, &args)? {
237 return output.unwrap_scalar();
238 }
239
240 log::debug!("No sum implementation found for {}", array.encoding());
242 if array.is_canonical() {
243 vortex_panic!(
245 "No sum implementation found for canonical array: {}",
246 array.encoding()
247 );
248 }
249 sum(array.to_canonical()?.as_ref())
250}
251
#[cfg(test)]
mod test {
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::sum;

    /// An all-null integer array sums to zero.
    #[test]
    fn sum_all_invalid() {
        let nulls = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]);
        let total = sum(&nulls).unwrap();
        assert_eq!(total.as_primitive().as_::<i32>().unwrap(), Some(0));
    }

    /// An all-null float array sums to zero.
    #[test]
    fn sum_all_invalid_float() {
        let nulls = PrimitiveArray::from_option_iter::<f32, _>([None, None, None]);
        let total = sum(&nulls).unwrap();
        assert_eq!(total.as_primitive().as_::<f32>().unwrap(), Some(0.0));
    }

    /// Four integer ones sum to four.
    #[test]
    fn sum_constant() {
        let ones = PrimitiveArray::from_iter([1, 1, 1, 1]);
        let total = sum(&ones).unwrap();
        assert_eq!(total.as_primitive().as_::<i32>().unwrap(), Some(4));
    }

    /// Four float ones sum to four.
    #[test]
    fn sum_constant_float() {
        let ones = PrimitiveArray::from_iter([1., 1., 1., 1.]);
        let total = sum(&ones).unwrap();
        assert_eq!(total.as_primitive().as_::<f32>().unwrap(), Some(4.));
    }

    /// A boolean array sums to its number of `true` values.
    #[test]
    fn sum_boolean() {
        let flags = BoolArray::from_iter([true, false, false, true]);
        let total = sum(&flags).unwrap();
        assert_eq!(total.as_primitive().as_::<i32>().unwrap(), Some(2));
    }
}