vortex_array/compute/
sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::{VortexResult, vortex_err, vortex_panic};
9use vortex_scalar::Scalar;
10
11use crate::Array;
12use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output, UnaryArgs};
13use crate::stats::{Precision, Stat, StatsProvider};
14use crate::vtable::VTable;
15
16static SUM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
17    let compute = ComputeFn::new("sum".into(), ArcRef::new_ref(&Sum));
18    for kernel in inventory::iter::<SumKernelRef> {
19        compute.register_kernel(kernel.0.clone());
20    }
21    compute
22});
23
24pub(crate) fn warm_up_vtable() -> usize {
25    SUM_FN.kernels().len()
26}
27
28/// Sum an array.
29///
30/// If the sum overflows, a null scalar will be returned.
31/// If the sum is not supported for the array's dtype, an error will be raised.
32/// If the array is all-invalid, the sum will be zero.
33pub fn sum(array: &dyn Array) -> VortexResult<Scalar> {
34    SUM_FN
35        .invoke(&InvocationArgs {
36            inputs: &[array.into()],
37            options: &(),
38        })?
39        .unwrap_scalar()
40}
41
42struct Sum;
43
44impl ComputeFnVTable for Sum {
45    fn invoke(
46        &self,
47        args: &InvocationArgs,
48        kernels: &[ArcRef<dyn Kernel>],
49    ) -> VortexResult<Output> {
50        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
51
52        // Compute the expected dtype of the sum.
53        let sum_dtype = self.return_dtype(args)?;
54
55        // Short-circuit using array statistics.
56        if let Some(Precision::Exact(sum)) = array.statistics().get(Stat::Sum) {
57            return Ok(sum.into());
58        }
59
60        let sum_scalar = sum_impl(array, sum_dtype, kernels)?;
61
62        // Update the statistics with the computed sum.
63        array
64            .statistics()
65            .set(Stat::Sum, Precision::Exact(sum_scalar.value().clone()));
66
67        Ok(sum_scalar.into())
68    }
69
70    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
71        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
72        Stat::Sum
73            .dtype(array.dtype())
74            .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))
75    }
76
77    fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
78        // The sum function always returns a single scalar value.
79        Ok(1)
80    }
81
82    fn is_elementwise(&self) -> bool {
83        false
84    }
85}
86
/// A type-erased sum kernel, wrapped so that it can be collected through `inventory`.
///
/// Every collected `SumKernelRef` is registered with the `sum` compute function when that
/// function is first initialized.
pub struct SumKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(SumKernelRef);
89
/// A vtable-specific implementation of `sum` that operates directly on the vtable's own
/// array type.
pub trait SumKernel: VTable {
    /// Computes the sum of `array` and returns it as a scalar.
    ///
    /// # Preconditions
    ///
    /// * The array's DType is summable
    /// * The array is not all-null
    ///
    /// `sum_impl` in this module returns early for empty and all-invalid arrays before it
    /// tries any kernel.
    fn sum(&self, array: &Self::Array) -> VortexResult<Scalar>;
}
97
/// Adapter that exposes a [`SumKernel`] implementation as a dynamically dispatched [`Kernel`].
#[derive(Debug)]
pub struct SumKernelAdapter<V: VTable>(pub V);

impl<V: VTable + SumKernel> SumKernelAdapter<V> {
    /// Wraps a `'static` reference to this adapter in a [`SumKernelRef`], the type that is
    /// collected through `inventory` for the `sum` compute function.
    pub const fn lift(&'static self) -> SumKernelRef {
        SumKernelRef(ArcRef::new_ref(self))
    }
}
106
107impl<V: VTable + SumKernel> Kernel for SumKernelAdapter<V> {
108    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
109        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
110        let Some(array) = array.as_opt::<V>() else {
111            return Ok(None);
112        };
113        Ok(Some(V::sum(&self.0, array)?.into()))
114    }
115}
116
117/// Sum an array.
118///
119/// If the sum overflows, a null scalar will be returned.
120/// If the sum is not supported for the array's dtype, an error will be raised.
121/// If the array is all-invalid, the sum will be zero.
122pub fn sum_impl(
123    array: &dyn Array,
124    sum_dtype: DType,
125    kernels: &[ArcRef<dyn Kernel>],
126) -> VortexResult<Scalar> {
127    if array.is_empty() {
128        return if sum_dtype.is_float() {
129            Ok(Scalar::new(sum_dtype, 0.0.into()))
130        } else {
131            Ok(Scalar::new(sum_dtype, 0.into()))
132        };
133    }
134
135    // Sum of all null is null.
136    if array.all_invalid() {
137        return Ok(Scalar::null(sum_dtype));
138    }
139
140    // Try to find a sum kernel
141    let args = InvocationArgs {
142        inputs: &[array.into()],
143        options: &(),
144    };
145    for kernel in kernels {
146        if let Some(output) = kernel.invoke(&args)? {
147            return output.unwrap_scalar();
148        }
149    }
150    if let Some(output) = array.invoke(&SUM_FN, &args)? {
151        return output.unwrap_scalar();
152    }
153
154    // Otherwise, canonicalize and try again.
155    log::debug!("No sum implementation found for {}", array.encoding_id());
156    if array.is_canonical() {
157        // Panic to avoid recursion, but it should never be hit.
158        vortex_panic!(
159            "No sum implementation found for canonical array: {}",
160            array.encoding_id()
161        );
162    }
163    sum(array.to_canonical().as_ref())
164}
165
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::IntoArray as _;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::sum;

    /// An all-null integer array sums to a null scalar of the widened integer dtype.
    #[test]
    fn sum_all_invalid() {
        let nulls = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]);
        let total = sum(nulls.as_ref()).unwrap();
        let expected = Scalar::null(DType::Primitive(PType::I64, Nullability::Nullable));
        assert_eq!(total, expected);
    }

    /// An all-null float array sums to a null scalar of the widened float dtype.
    #[test]
    fn sum_all_invalid_float() {
        let nulls = PrimitiveArray::from_option_iter::<f32, _>([None, None, None]);
        let total = sum(nulls.as_ref()).unwrap();
        let expected = Scalar::null(DType::Primitive(PType::F64, Nullability::Nullable));
        assert_eq!(total, expected);
    }

    /// Four integer ones sum to four.
    #[test]
    fn sum_constant() {
        let ones = buffer![1, 1, 1, 1].into_array();
        let total = sum(ones.as_ref()).unwrap();
        assert_eq!(total.as_primitive().as_::<i32>(), Some(4));
    }

    /// Four float ones sum to four.
    #[test]
    fn sum_constant_float() {
        let ones = buffer![1., 1., 1., 1.].into_array();
        let total = sum(ones.as_ref()).unwrap();
        assert_eq!(total.as_primitive().as_::<f32>(), Some(4.));
    }

    /// Summing booleans counts the `true` values.
    #[test]
    fn sum_boolean() {
        let flags = BoolArray::from_iter([true, false, false, true]);
        let total = sum(flags.as_ref()).unwrap();
        assert_eq!(total.as_primitive().as_::<i32>(), Some(2));
    }
}