vortex_array/compute/
sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::{
9    VortexError, VortexResult, vortex_bail, vortex_ensure, vortex_err, vortex_panic,
10};
11use vortex_scalar::Scalar;
12
13use crate::Array;
14use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output};
15use crate::stats::{Precision, Stat, StatsProvider};
16use crate::vtable::VTable;
17
18static SUM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
19    let compute = ComputeFn::new("sum".into(), ArcRef::new_ref(&Sum));
20    for kernel in inventory::iter::<SumKernelRef> {
21        compute.register_kernel(kernel.0.clone());
22    }
23    compute
24});
25
26pub(crate) fn warm_up_vtable() -> usize {
27    SUM_FN.kernels().len()
28}
29
30/// Sum an array with an initial value.
31///
32/// If the sum overflows, a null scalar will be returned.
33/// If the sum is not supported for the array's dtype, an error will be raised.
34/// If the array is all-invalid, the sum will be the accumulator.
35/// The accumulator must have a dtype compatible with the sum result dtype.
36pub(crate) fn sum_with_accumulator(
37    array: &dyn Array,
38    accumulator: &Scalar,
39) -> VortexResult<Scalar> {
40    SUM_FN
41        .invoke(&InvocationArgs {
42            inputs: &[array.into(), accumulator.into()],
43            options: &(),
44        })?
45        .unwrap_scalar()
46}
47
48/// Sum an array, starting from zero.
49///
50/// If the sum overflows, a null scalar will be returned.
51/// If the sum is not supported for the array's dtype, an error will be raised.
52/// If the array is all-invalid, the sum will be zero.
53pub fn sum(array: &dyn Array) -> VortexResult<Scalar> {
54    let sum_dtype = Stat::Sum
55        .dtype(array.dtype())
56        .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))?;
57    let zero = Scalar::zero_value(sum_dtype);
58    sum_with_accumulator(array, &zero)
59}
60
61/// For unary compute functions, it's useful to just have this short-cut.
62pub struct SumArgs<'a> {
63    pub array: &'a dyn Array,
64    pub accumulator: &'a Scalar,
65}
66
67impl<'a> TryFrom<&InvocationArgs<'a>> for SumArgs<'a> {
68    type Error = VortexError;
69
70    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
71        if value.inputs.len() != 2 {
72            vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
73        }
74        let array = value.inputs[0]
75            .array()
76            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
77        let accumulator = value.inputs[1]
78            .scalar()
79            .ok_or_else(|| vortex_err!("Expected input 1 to be a scalar"))?;
80        Ok(SumArgs { array, accumulator })
81    }
82}
83
84struct Sum;
85
86impl ComputeFnVTable for Sum {
87    fn invoke(
88        &self,
89        args: &InvocationArgs,
90        kernels: &[ArcRef<dyn Kernel>],
91    ) -> VortexResult<Output> {
92        let SumArgs { array, accumulator } = args.try_into()?;
93
94        // Compute the expected dtype of the sum.
95        let sum_dtype = self.return_dtype(args)?;
96
97        vortex_ensure!(
98            &sum_dtype == accumulator.dtype(),
99            "sum_dtype {sum_dtype} must match accumulator dtype {}",
100            accumulator.dtype()
101        );
102
103        // Short-circuit using array statistics.
104        if let Some(Precision::Exact(sum)) = array.statistics().get(Stat::Sum) {
105            return Ok(sum.into());
106        }
107
108        let sum_scalar = sum_impl(array, accumulator, kernels)?;
109
110        // Update the statistics with the computed sum.
111        array
112            .statistics()
113            .set(Stat::Sum, Precision::Exact(sum_scalar.value().clone()));
114
115        Ok(sum_scalar.into())
116    }
117
118    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
119        let SumArgs { array, .. } = args.try_into()?;
120        Stat::Sum
121            .dtype(array.dtype())
122            .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))
123    }
124
125    fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
126        // The sum function always returns a single scalar value.
127        Ok(1)
128    }
129
130    fn is_elementwise(&self) -> bool {
131        false
132    }
133}
134
/// A type-erased sum kernel registration.
///
/// Values are collected through `inventory`, and each one is registered on the `sum` compute
/// function when that function is first initialised.
pub struct SumKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(SumKernelRef);
137
/// An encoding-specific implementation of `sum`.
///
/// Wrap an implementation in a [`SumKernelAdapter`] to expose it as a [`Kernel`]. The adapter
/// only calls [`SumKernel::sum`] when the input array downcasts to `V`'s array type.
pub trait SumKernel: VTable {
    /// Returns `accumulator` plus the sum of `array`. Per this module's `sum` contract, an
    /// overflow should produce a null scalar.
    ///
    /// # Preconditions
    ///
    /// * The array's DType is summable
    /// * The array is not all-null
    /// * The accumulator must have a dtype compatible with the sum result dtype
    ///
    /// When dispatched through `sum_impl` in this module, the array is also non-empty and the
    /// accumulator is non-null. Direct callers are presumably expected to uphold the same.
    fn sum(&self, array: &Self::Array, accumulator: &Scalar) -> VortexResult<Scalar>;
}
146
/// Adapts the [`SumKernel`] implementation of encoding `V` into a type-erased [`Kernel`].
#[derive(Debug)]
pub struct SumKernelAdapter<V: VTable>(pub V);

impl<V: VTable + SumKernel> SumKernelAdapter<V> {
    /// Wraps a `'static` adapter into a [`SumKernelRef`].
    ///
    /// Because this is a `const fn`, it can be evaluated in constant contexts.
    pub const fn lift(&'static self) -> SumKernelRef {
        SumKernelRef(ArcRef::new_ref(self))
    }
}
155
156impl<V: VTable + SumKernel> Kernel for SumKernelAdapter<V> {
157    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
158        let SumArgs { array, accumulator } = args.try_into()?;
159        let Some(array) = array.as_opt::<V>() else {
160            return Ok(None);
161        };
162        Ok(Some(V::sum(&self.0, array, accumulator)?.into()))
163    }
164}
165
166/// Sum an array.
167///
168/// If the sum overflows, a null scalar will be returned.
169/// If the sum is not supported for the array's dtype, an error will be raised.
170/// If the array is all-invalid, the sum will be the accumulator.
171pub fn sum_impl(
172    array: &dyn Array,
173    accumulator: &Scalar,
174    kernels: &[ArcRef<dyn Kernel>],
175) -> VortexResult<Scalar> {
176    if array.is_empty() || array.all_invalid() || accumulator.is_null() {
177        return Ok(accumulator.clone());
178    }
179
180    // Try to find a sum kernel
181    let args = InvocationArgs {
182        inputs: &[array.into(), accumulator.into()],
183        options: &(),
184    };
185    for kernel in kernels {
186        if let Some(output) = kernel.invoke(&args)? {
187            return output.unwrap_scalar();
188        }
189    }
190    if let Some(output) = array.invoke(&SUM_FN, &args)? {
191        return output.unwrap_scalar();
192    }
193
194    // Otherwise, canonicalize and try again.
195    log::debug!("No sum implementation found for {}", array.encoding_id());
196    if array.is_canonical() {
197        // Panic to avoid recursion, but it should never be hit.
198        vortex_panic!(
199            "No sum implementation found for canonical array: {}",
200            array.encoding_id()
201        );
202    }
203    sum_with_accumulator(array.to_canonical().as_ref(), accumulator)
204}
205
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_scalar::Scalar;

    use crate::IntoArray as _;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::sum;

    /// An all-null i32 array sums to a nullable i64 zero.
    #[test]
    fn sum_all_invalid() {
        let nulls = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]);
        let expected = Scalar::primitive(0i64, Nullability::Nullable);
        assert_eq!(sum(nulls.as_ref()).unwrap(), expected);
    }

    /// An all-null f32 array sums to a nullable f64 zero.
    #[test]
    fn sum_all_invalid_float() {
        let nulls = PrimitiveArray::from_option_iter::<f32, _>([None, None, None]);
        let expected = Scalar::primitive(0f64, Nullability::Nullable);
        assert_eq!(sum(nulls.as_ref()).unwrap(), expected);
    }

    #[test]
    fn sum_constant() {
        let ones = buffer![1, 1, 1, 1].into_array();
        let total = sum(ones.as_ref()).unwrap();
        assert_eq!(total.as_primitive().as_::<i32>(), Some(4));
    }

    #[test]
    fn sum_constant_float() {
        let ones = buffer![1., 1., 1., 1.].into_array();
        let total = sum(ones.as_ref()).unwrap();
        assert_eq!(total.as_primitive().as_::<f32>(), Some(4.));
    }

    /// Booleans sum as the count of `true` values.
    #[test]
    fn sum_boolean() {
        let flags = BoolArray::from_iter([true, false, false, true]);
        let total = sum(flags.as_ref()).unwrap();
        assert_eq!(total.as_primitive().as_::<i32>(), Some(2));
    }
}