vortex_array/compute/
sum.rs1use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::{VortexResult, vortex_err, vortex_panic};
9use vortex_scalar::Scalar;
10
11use crate::Array;
12use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output, UnaryArgs};
13use crate::stats::{Precision, Stat, StatsProvider};
14use crate::vtable::VTable;
15
16static SUM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
17 let compute = ComputeFn::new("sum".into(), ArcRef::new_ref(&Sum));
18 for kernel in inventory::iter::<SumKernelRef> {
19 compute.register_kernel(kernel.0.clone());
20 }
21 compute
22});
23
24pub(crate) fn warm_up_vtable() -> usize {
25 SUM_FN.kernels().len()
26}
27
28pub fn sum(array: &dyn Array) -> VortexResult<Scalar> {
34 SUM_FN
35 .invoke(&InvocationArgs {
36 inputs: &[array.into()],
37 options: &(),
38 })?
39 .unwrap_scalar()
40}
41
42struct Sum;
43
44impl ComputeFnVTable for Sum {
45 fn invoke(
46 &self,
47 args: &InvocationArgs,
48 kernels: &[ArcRef<dyn Kernel>],
49 ) -> VortexResult<Output> {
50 let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
51
52 let sum_dtype = self.return_dtype(args)?;
54
55 if let Some(Precision::Exact(sum)) = array.statistics().get(Stat::Sum) {
57 return Ok(sum.into());
58 }
59
60 let sum_scalar = sum_impl(array, sum_dtype, kernels)?;
61
62 array
64 .statistics()
65 .set(Stat::Sum, Precision::Exact(sum_scalar.value().clone()));
66
67 Ok(sum_scalar.into())
68 }
69
70 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
71 let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
72 Stat::Sum
73 .dtype(array.dtype())
74 .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))
75 }
76
77 fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
78 Ok(1)
80 }
81
82 fn is_elementwise(&self) -> bool {
83 false
84 }
85}
86
87pub struct SumKernelRef(ArcRef<dyn Kernel>);
88inventory::collect!(SumKernelRef);
89
90pub trait SumKernel: VTable {
91 fn sum(&self, array: &Self::Array) -> VortexResult<Scalar>;
96}
97
98#[derive(Debug)]
99pub struct SumKernelAdapter<V: VTable>(pub V);
100
101impl<V: VTable + SumKernel> SumKernelAdapter<V> {
102 pub const fn lift(&'static self) -> SumKernelRef {
103 SumKernelRef(ArcRef::new_ref(self))
104 }
105}
106
107impl<V: VTable + SumKernel> Kernel for SumKernelAdapter<V> {
108 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
109 let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
110 let Some(array) = array.as_opt::<V>() else {
111 return Ok(None);
112 };
113 Ok(Some(V::sum(&self.0, array)?.into()))
114 }
115}
116
117pub fn sum_impl(
123 array: &dyn Array,
124 sum_dtype: DType,
125 kernels: &[ArcRef<dyn Kernel>],
126) -> VortexResult<Scalar> {
127 if array.is_empty() {
128 return if sum_dtype.is_float() {
129 Ok(Scalar::new(sum_dtype, 0.0.into()))
130 } else {
131 Ok(Scalar::new(sum_dtype, 0.into()))
132 };
133 }
134
135 if array.all_invalid() {
137 return Ok(Scalar::null(sum_dtype));
138 }
139
140 let args = InvocationArgs {
142 inputs: &[array.into()],
143 options: &(),
144 };
145 for kernel in kernels {
146 if let Some(output) = kernel.invoke(&args)? {
147 return output.unwrap_scalar();
148 }
149 }
150 if let Some(output) = array.invoke(&SUM_FN, &args)? {
151 return output.unwrap_scalar();
152 }
153
154 log::debug!("No sum implementation found for {}", array.encoding_id());
156 if array.is_canonical() {
157 vortex_panic!(
159 "No sum implementation found for canonical array: {}",
160 array.encoding_id()
161 );
162 }
163 sum(array.to_canonical().as_ref())
164}
165
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::IntoArray as _;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::sum;

    /// A nullable null scalar of primitive type `ptype`: the expected sum of an all-null array.
    fn null_of(ptype: PType) -> Scalar {
        Scalar::null(DType::Primitive(ptype, Nullability::Nullable))
    }

    #[test]
    fn sum_all_invalid() {
        let nulls = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]);
        // An i32 input is expected to sum into an i64.
        assert_eq!(sum(nulls.as_ref()).unwrap(), null_of(PType::I64));
    }

    #[test]
    fn sum_all_invalid_float() {
        let nulls = PrimitiveArray::from_option_iter::<f32, _>([None, None, None]);
        // An f32 input is expected to sum into an f64.
        assert_eq!(sum(nulls.as_ref()).unwrap(), null_of(PType::F64));
    }

    #[test]
    fn sum_constant() {
        let ones = buffer![1, 1, 1, 1].into_array();
        let total = sum(ones.as_ref()).unwrap();
        assert_eq!(total.as_primitive().as_::<i32>(), Some(4));
    }

    #[test]
    fn sum_constant_float() {
        let ones = buffer![1., 1., 1., 1.].into_array();
        let total = sum(ones.as_ref()).unwrap();
        assert_eq!(total.as_primitive().as_::<f32>(), Some(4.));
    }

    #[test]
    fn sum_boolean() {
        // Each `true` counts as one.
        let flags = BoolArray::from_iter([true, false, false, true]);
        let total = sum(flags.as_ref()).unwrap();
        assert_eq!(total.as_primitive().as_::<i32>(), Some(2));
    }
}