vortex_array/compute/
sum.rs1use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::{VortexResult, vortex_err, vortex_panic};
9use vortex_scalar::Scalar;
10
11use crate::Array;
12use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output, UnaryArgs};
13use crate::stats::{Precision, Stat, StatsProvider};
14use crate::vtable::VTable;
15
16static SUM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
17 let compute = ComputeFn::new("sum".into(), ArcRef::new_ref(&Sum));
18 for kernel in inventory::iter::<SumKernelRef> {
19 compute.register_kernel(kernel.0.clone());
20 }
21 compute
22});
23
24pub fn sum(array: &dyn Array) -> VortexResult<Scalar> {
30 SUM_FN
31 .invoke(&InvocationArgs {
32 inputs: &[array.into()],
33 options: &(),
34 })?
35 .unwrap_scalar()
36}
37
38struct Sum;
39
40impl ComputeFnVTable for Sum {
41 fn invoke(
42 &self,
43 args: &InvocationArgs,
44 kernels: &[ArcRef<dyn Kernel>],
45 ) -> VortexResult<Output> {
46 let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
47
48 let sum_dtype = self.return_dtype(args)?;
50
51 if let Some(Precision::Exact(sum)) = array.statistics().get(Stat::Sum) {
53 return Ok(sum.into());
54 }
55
56 let sum_scalar = sum_impl(array, sum_dtype, kernels)?;
57
58 array
60 .statistics()
61 .set(Stat::Sum, Precision::Exact(sum_scalar.value().clone()));
62
63 Ok(sum_scalar.into())
64 }
65
66 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
67 let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
68 Stat::Sum
69 .dtype(array.dtype())
70 .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))
71 }
72
73 fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
74 Ok(1)
76 }
77
78 fn is_elementwise(&self) -> bool {
79 false
80 }
81}
82
83pub struct SumKernelRef(ArcRef<dyn Kernel>);
84inventory::collect!(SumKernelRef);
85
86pub trait SumKernel: VTable {
87 fn sum(&self, array: &Self::Array) -> VortexResult<Scalar>;
92}
93
94#[derive(Debug)]
95pub struct SumKernelAdapter<V: VTable>(pub V);
96
97impl<V: VTable + SumKernel> SumKernelAdapter<V> {
98 pub const fn lift(&'static self) -> SumKernelRef {
99 SumKernelRef(ArcRef::new_ref(self))
100 }
101}
102
103impl<V: VTable + SumKernel> Kernel for SumKernelAdapter<V> {
104 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
105 let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
106 let Some(array) = array.as_opt::<V>() else {
107 return Ok(None);
108 };
109 Ok(Some(V::sum(&self.0, array)?.into()))
110 }
111}
112
113pub fn sum_impl(
119 array: &dyn Array,
120 sum_dtype: DType,
121 kernels: &[ArcRef<dyn Kernel>],
122) -> VortexResult<Scalar> {
123 if array.is_empty() {
124 return if sum_dtype.is_float() {
125 Ok(Scalar::new(sum_dtype, 0.0.into()))
126 } else {
127 Ok(Scalar::new(sum_dtype, 0.into()))
128 };
129 }
130
131 if array.all_invalid()? {
133 return Ok(Scalar::null(sum_dtype));
134 }
135
136 let args = InvocationArgs {
138 inputs: &[array.into()],
139 options: &(),
140 };
141 for kernel in kernels {
142 if let Some(output) = kernel.invoke(&args)? {
143 return output.unwrap_scalar();
144 }
145 }
146 if let Some(output) = array.invoke(&SUM_FN, &args)? {
147 return output.unwrap_scalar();
148 }
149
150 log::debug!("No sum implementation found for {}", array.encoding_id());
152 if array.is_canonical() {
153 vortex_panic!(
155 "No sum implementation found for canonical array: {}",
156 array.encoding_id()
157 );
158 }
159 sum(array.to_canonical()?.as_ref())
160}
161
#[cfg(test)]
mod test {
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::sum;

    /// Builds the null scalar expected for a nullable primitive sum of the given type.
    fn null_of(ptype: PType) -> Scalar {
        Scalar::null(DType::Primitive(ptype, Nullability::Nullable))
    }

    /// An all-null `i32` array sums to a null `i64`.
    #[test]
    fn sum_all_invalid() {
        let input = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]);
        let total = sum(input.as_ref()).unwrap();
        assert_eq!(total, null_of(PType::I64));
    }

    /// An all-null `f32` array sums to a null `f64`.
    #[test]
    fn sum_all_invalid_float() {
        let input = PrimitiveArray::from_option_iter::<f32, _>([None, None, None]);
        let total = sum(input.as_ref()).unwrap();
        assert_eq!(total, null_of(PType::F64));
    }

    /// Four integer ones sum to four.
    #[test]
    fn sum_constant() {
        let input = PrimitiveArray::from_iter([1, 1, 1, 1]);
        let total = sum(input.as_ref()).unwrap();
        assert_eq!(total.as_primitive().as_::<i32>(), Some(4));
    }

    /// Four float ones sum to four.
    #[test]
    fn sum_constant_float() {
        let input = PrimitiveArray::from_iter([1., 1., 1., 1.]);
        let total = sum(input.as_ref()).unwrap();
        assert_eq!(total.as_primitive().as_::<f32>(), Some(4.));
    }

    /// Summing booleans counts the `true` values.
    #[test]
    fn sum_boolean() {
        let input = BoolArray::from_iter([true, false, false, true]);
        let total = sum(input.as_ref()).unwrap();
        assert_eq!(total.as_primitive().as_::<i32>(), Some(2));
    }
}