vortex_array/compute/
sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::{VortexResult, vortex_err, vortex_panic};
9use vortex_scalar::Scalar;
10
11use crate::Array;
12use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output, UnaryArgs};
13use crate::stats::{Precision, Stat, StatsProvider};
14use crate::vtable::VTable;
15
16static SUM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
17    let compute = ComputeFn::new("sum".into(), ArcRef::new_ref(&Sum));
18    for kernel in inventory::iter::<SumKernelRef> {
19        compute.register_kernel(kernel.0.clone());
20    }
21    compute
22});
23
24/// Sum an array.
25///
26/// If the sum overflows, a null scalar will be returned.
27/// If the sum is not supported for the array's dtype, an error will be raised.
28/// If the array is all-invalid, the sum will be zero.
29pub fn sum(array: &dyn Array) -> VortexResult<Scalar> {
30    SUM_FN
31        .invoke(&InvocationArgs {
32            inputs: &[array.into()],
33            options: &(),
34        })?
35        .unwrap_scalar()
36}
37
38struct Sum;
39
40impl ComputeFnVTable for Sum {
41    fn invoke(
42        &self,
43        args: &InvocationArgs,
44        kernels: &[ArcRef<dyn Kernel>],
45    ) -> VortexResult<Output> {
46        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
47
48        // Compute the expected dtype of the sum.
49        let sum_dtype = self.return_dtype(args)?;
50
51        // Short-circuit using array statistics.
52        if let Some(Precision::Exact(sum)) = array.statistics().get(Stat::Sum) {
53            return Ok(sum.into());
54        }
55
56        let sum_scalar = sum_impl(array, sum_dtype, kernels)?;
57
58        // Update the statistics with the computed sum.
59        array
60            .statistics()
61            .set(Stat::Sum, Precision::Exact(sum_scalar.value().clone()));
62
63        Ok(sum_scalar.into())
64    }
65
66    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
67        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
68        Stat::Sum
69            .dtype(array.dtype())
70            .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))
71    }
72
73    fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
74        // The sum function always returns a single scalar value.
75        Ok(1)
76    }
77
78    fn is_elementwise(&self) -> bool {
79        false
80    }
81}
82
/// A type-erased sum kernel.
///
/// Values of this type are collected by `inventory`, and each one is registered with the `sum`
/// compute function when that function is first initialized. Build one with
/// [`SumKernelAdapter::lift`].
pub struct SumKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(SumKernelRef);
85
86pub trait SumKernel: VTable {
87    /// # Preconditions
88    ///
89    /// * The array's DType is summable
90    /// * The array is not all-null
91    fn sum(&self, array: &Self::Array) -> VortexResult<Scalar>;
92}
93
/// Adapts a vtable implementing [`SumKernel`] into a type-erased compute [`Kernel`].
///
/// The adapter only handles arrays that downcast to `V` (via `as_opt::<V>()`); for any other
/// array its `Kernel::invoke` returns `Ok(None)` so other kernels can be tried.
#[derive(Debug)]
pub struct SumKernelAdapter<V: VTable>(pub V);
96
97impl<V: VTable + SumKernel> SumKernelAdapter<V> {
98    pub const fn lift(&'static self) -> SumKernelRef {
99        SumKernelRef(ArcRef::new_ref(self))
100    }
101}
102
103impl<V: VTable + SumKernel> Kernel for SumKernelAdapter<V> {
104    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
105        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
106        let Some(array) = array.as_opt::<V>() else {
107            return Ok(None);
108        };
109        Ok(Some(V::sum(&self.0, array)?.into()))
110    }
111}
112
113/// Sum an array.
114///
115/// If the sum overflows, a null scalar will be returned.
116/// If the sum is not supported for the array's dtype, an error will be raised.
117/// If the array is all-invalid, the sum will be zero.
118pub fn sum_impl(
119    array: &dyn Array,
120    sum_dtype: DType,
121    kernels: &[ArcRef<dyn Kernel>],
122) -> VortexResult<Scalar> {
123    if array.is_empty() {
124        return if sum_dtype.is_float() {
125            Ok(Scalar::new(sum_dtype, 0.0.into()))
126        } else {
127            Ok(Scalar::new(sum_dtype, 0.into()))
128        };
129    }
130
131    // Sum of all null is null.
132    if array.all_invalid()? {
133        return Ok(Scalar::null(sum_dtype));
134    }
135
136    // Try to find a sum kernel
137    let args = InvocationArgs {
138        inputs: &[array.into()],
139        options: &(),
140    };
141    for kernel in kernels {
142        if let Some(output) = kernel.invoke(&args)? {
143            return output.unwrap_scalar();
144        }
145    }
146    if let Some(output) = array.invoke(&SUM_FN, &args)? {
147        return output.unwrap_scalar();
148    }
149
150    // Otherwise, canonicalize and try again.
151    log::debug!("No sum implementation found for {}", array.encoding_id());
152    if array.is_canonical() {
153        // Panic to avoid recursion, but it should never be hit.
154        vortex_panic!(
155            "No sum implementation found for canonical array: {}",
156            array.encoding_id()
157        );
158    }
159    sum(array.to_canonical()?.as_ref())
160}
161
#[cfg(test)]
mod test {
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::sum;

    #[test]
    fn sum_all_invalid() {
        let array = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]);
        let result = sum(array.as_ref()).unwrap();
        assert_eq!(
            result,
            Scalar::null(DType::Primitive(PType::I64, Nullability::Nullable))
        );
    }

    #[test]
    fn sum_all_invalid_float() {
        let array = PrimitiveArray::from_option_iter::<f32, _>([None, None, None]);
        let result = sum(array.as_ref()).unwrap();
        assert_eq!(
            result,
            Scalar::null(DType::Primitive(PType::F64, Nullability::Nullable))
        );
    }

    // An empty array sums to zero (not null), unlike an all-invalid array.
    #[test]
    fn sum_empty() {
        let array = PrimitiveArray::from_iter(std::iter::empty::<i32>());
        let result = sum(array.as_ref()).unwrap();
        assert_eq!(result.as_primitive().as_::<i64>(), Some(0));
    }

    #[test]
    fn sum_empty_float() {
        let array = PrimitiveArray::from_iter(std::iter::empty::<f64>());
        let result = sum(array.as_ref()).unwrap();
        assert_eq!(result.as_primitive().as_::<f64>(), Some(0.0));
    }

    #[test]
    fn sum_constant() {
        let array = PrimitiveArray::from_iter([1, 1, 1, 1]);
        let result = sum(array.as_ref()).unwrap();
        assert_eq!(result.as_primitive().as_::<i32>(), Some(4));
    }

    #[test]
    fn sum_constant_float() {
        let array = PrimitiveArray::from_iter([1., 1., 1., 1.]);
        let result = sum(array.as_ref()).unwrap();
        assert_eq!(result.as_primitive().as_::<f32>(), Some(4.));
    }

    #[test]
    fn sum_boolean() {
        let array = BoolArray::from_iter([true, false, false, true]);
        let result = sum(array.as_ref()).unwrap();
        assert_eq!(result.as_primitive().as_::<i32>(), Some(2));
    }
}