vortex_array/compute/
sum.rs

1use std::sync::LazyLock;
2
3use arcref::ArcRef;
4use vortex_dtype::{DType, PType};
5use vortex_error::{VortexExpect, VortexResult, vortex_err, vortex_panic};
6use vortex_scalar::Scalar;
7
8use crate::Array;
9use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output, UnaryArgs};
10use crate::stats::{Precision, Stat, StatsProvider};
11use crate::vtable::VTable;
12
13/// Sum an array.
14///
15/// If the sum overflows, a null scalar will be returned.
16/// If the sum is not supported for the array's dtype, an error will be raised.
17/// If the array is all-invalid, the sum will be zero.
18pub fn sum(array: &dyn Array) -> VortexResult<Scalar> {
19    SUM_FN
20        .invoke(&InvocationArgs {
21            inputs: &[array.into()],
22            options: &(),
23        })?
24        .unwrap_scalar()
25}
26
27struct Sum;
28
29impl ComputeFnVTable for Sum {
30    fn invoke(
31        &self,
32        args: &InvocationArgs,
33        kernels: &[ArcRef<dyn Kernel>],
34    ) -> VortexResult<Output> {
35        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
36
37        // Compute the expected dtype of the sum.
38        let sum_dtype = self.return_dtype(args)?;
39
40        // Short-circuit using array statistics.
41        if let Some(Precision::Exact(sum)) = array.statistics().get(Stat::Sum) {
42            return Ok(Scalar::new(sum_dtype, sum).into());
43        }
44
45        let sum_scalar = sum_impl(array, sum_dtype, kernels)?;
46
47        // Update the statistics with the computed sum.
48        array
49            .statistics()
50            .set(Stat::Sum, Precision::Exact(sum_scalar.value().clone()));
51
52        Ok(sum_scalar.into())
53    }
54
55    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
56        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
57        Stat::Sum
58            .dtype(array.dtype())
59            .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))
60    }
61
62    fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
63        // The sum function always returns a single scalar value.
64        Ok(1)
65    }
66
67    fn is_elementwise(&self) -> bool {
68        false
69    }
70}
71
72pub static SUM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
73    let compute = ComputeFn::new("sum".into(), ArcRef::new_ref(&Sum));
74    for kernel in inventory::iter::<SumKernelRef> {
75        compute.register_kernel(kernel.0.clone());
76    }
77    compute
78});
79
80pub struct SumKernelRef(ArcRef<dyn Kernel>);
81inventory::collect!(SumKernelRef);
82
83pub trait SumKernel: VTable {
84    /// # Preconditions
85    ///
86    /// * The array's DType is summable
87    /// * The array is not all-null
88    fn sum(&self, array: &Self::Array) -> VortexResult<Scalar>;
89}
90
91#[derive(Debug)]
92pub struct SumKernelAdapter<V: VTable>(pub V);
93
94impl<V: VTable + SumKernel> SumKernelAdapter<V> {
95    pub const fn lift(&'static self) -> SumKernelRef {
96        SumKernelRef(ArcRef::new_ref(self))
97    }
98}
99
100impl<V: VTable + SumKernel> Kernel for SumKernelAdapter<V> {
101    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
102        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
103        let Some(array) = array.as_opt::<V>() else {
104            return Ok(None);
105        };
106        Ok(Some(V::sum(&self.0, array)?.into()))
107    }
108}
109
110/// Sum an array.
111///
112/// If the sum overflows, a null scalar will be returned.
113/// If the sum is not supported for the array's dtype, an error will be raised.
114/// If the array is all-invalid, the sum will be zero.
115pub fn sum_impl(
116    array: &dyn Array,
117    sum_dtype: DType,
118    kernels: &[ArcRef<dyn Kernel>],
119) -> VortexResult<Scalar> {
120    if array.is_empty() {
121        return if sum_dtype.is_float() {
122            Ok(Scalar::new(sum_dtype, 0.0.into()))
123        } else {
124            Ok(Scalar::new(sum_dtype, 0.into()))
125        };
126    }
127
128    // If the array is constant, we can compute the sum directly.
129    if let Some(mut constant) = array.as_constant() {
130        if constant.is_null() {
131            // An all-null constant array has a sum of 0.
132            return if sum_dtype.is_float() {
133                Ok(Scalar::new(sum_dtype, 0.0.into()))
134            } else {
135                Ok(Scalar::new(sum_dtype, 0.into()))
136            };
137        }
138
139        // TODO(ngates): I think we should delegate these to kernels, rather than hard-code.
140
141        // If it's an extension array, then unwrap it into the storage scalar.
142        if let Some(extension) = constant.as_extension_opt() {
143            constant = extension.storage();
144        }
145
146        // If it's a boolean array, then the true count is the sum, which is the length.
147        if let Some(bool) = constant.as_bool_opt() {
148            return if bool.value().vortex_expect("already checked for null value") {
149                // Constant true
150                Ok(Scalar::new(sum_dtype, array.len().into()))
151            } else {
152                // Constant false
153                Ok(Scalar::new(sum_dtype, 0.into()))
154            };
155        }
156
157        // If it's a primitive array, then the sum is the constant value times the length.
158        if let Some(primitive) = constant.as_primitive_opt() {
159            match primitive.ptype() {
160                PType::U8 | PType::U16 | PType::U32 | PType::U64 => {
161                    let value = primitive
162                        .pvalue()
163                        .vortex_expect("already checked for null value")
164                        .as_u64()
165                        .vortex_expect("Failed to cast constant value to u64");
166
167                    // Overflow results in a null sum.
168                    let sum = value.checked_mul(array.len() as u64);
169
170                    return Ok(Scalar::new(sum_dtype, sum.into()));
171                }
172                PType::I8 | PType::I16 | PType::I32 | PType::I64 => {
173                    let value = primitive
174                        .pvalue()
175                        .vortex_expect("already checked for null value")
176                        .as_i64()
177                        .vortex_expect("Failed to cast constant value to i64");
178
179                    // Overflow results in a null sum.
180                    let sum = value.checked_mul(array.len() as i64);
181
182                    return Ok(Scalar::new(sum_dtype, sum.into()));
183                }
184                PType::F16 | PType::F32 | PType::F64 => {
185                    let value = primitive
186                        .pvalue()
187                        .vortex_expect("already checked for null value")
188                        .as_f64()
189                        .vortex_expect("Failed to cast constant value to f64");
190
191                    let sum = value * (array.len() as f64);
192
193                    return Ok(Scalar::new(sum_dtype, sum.into()));
194                }
195            }
196        }
197
198        // For the unsupported types, we should have exited earlier.
199        unreachable!("Unsupported sum constant: {}", constant.dtype());
200    }
201
202    // Try to find a sum kernel
203    let args = InvocationArgs {
204        inputs: &[array.into()],
205        options: &(),
206    };
207    for kernel in kernels {
208        if let Some(output) = kernel.invoke(&args)? {
209            return output.unwrap_scalar();
210        }
211    }
212    if let Some(output) = array.invoke(&SUM_FN, &args)? {
213        return output.unwrap_scalar();
214    }
215
216    // Otherwise, canonicalize and try again.
217    log::debug!("No sum implementation found for {}", array.encoding_id());
218    if array.is_canonical() {
219        // Panic to avoid recursion, but it should never be hit.
220        vortex_panic!(
221            "No sum implementation found for canonical array: {}",
222            array.encoding_id()
223        );
224    }
225    sum(array.to_canonical()?.as_ref())
226}
227
#[cfg(test)]
mod test {
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::sum;

    /// An integer array holding only nulls sums to zero.
    #[test]
    fn sum_all_invalid() {
        let nulls = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]);
        let total = sum(nulls.as_ref()).unwrap();
        assert_eq!(total.as_primitive().as_::<i32>().unwrap(), Some(0));
    }

    /// A float array holding only nulls sums to zero.
    #[test]
    fn sum_all_invalid_float() {
        let nulls = PrimitiveArray::from_option_iter::<f32, _>([None, None, None]);
        let total = sum(nulls.as_ref()).unwrap();
        assert_eq!(total.as_primitive().as_::<f32>().unwrap(), Some(0.0));
    }

    /// Four integer ones sum to four.
    #[test]
    fn sum_constant() {
        let ones = PrimitiveArray::from_iter([1, 1, 1, 1]);
        let total = sum(ones.as_ref()).unwrap();
        assert_eq!(total.as_primitive().as_::<i32>().unwrap(), Some(4));
    }

    /// Four float ones sum to four.
    #[test]
    fn sum_constant_float() {
        let ones = PrimitiveArray::from_iter([1., 1., 1., 1.]);
        let total = sum(ones.as_ref()).unwrap();
        assert_eq!(total.as_primitive().as_::<f32>().unwrap(), Some(4.));
    }

    /// The sum of a boolean array is the number of `true` values.
    #[test]
    fn sum_boolean() {
        let flags = BoolArray::from_iter([true, false, false, true]);
        let total = sum(flags.as_ref()).unwrap();
        assert_eq!(total.as_primitive().as_::<i32>().unwrap(), Some(2));
    }
}