vortex_array/compute/
sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::{DType, PType};
8use vortex_error::{VortexExpect, VortexResult, vortex_err, vortex_panic};
9use vortex_scalar::Scalar;
10
11use crate::Array;
12use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output, UnaryArgs};
13use crate::stats::{Precision, Stat, StatsProvider};
14use crate::vtable::VTable;
15
16/// Sum an array.
17///
18/// If the sum overflows, a null scalar will be returned.
19/// If the sum is not supported for the array's dtype, an error will be raised.
20/// If the array is all-invalid, the sum will be zero.
21pub fn sum(array: &dyn Array) -> VortexResult<Scalar> {
22    SUM_FN
23        .invoke(&InvocationArgs {
24            inputs: &[array.into()],
25            options: &(),
26        })?
27        .unwrap_scalar()
28}
29
30struct Sum;
31
32impl ComputeFnVTable for Sum {
33    fn invoke(
34        &self,
35        args: &InvocationArgs,
36        kernels: &[ArcRef<dyn Kernel>],
37    ) -> VortexResult<Output> {
38        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
39
40        // Compute the expected dtype of the sum.
41        let sum_dtype = self.return_dtype(args)?;
42
43        // Short-circuit using array statistics.
44        if let Some(Precision::Exact(sum)) = array.statistics().get(Stat::Sum) {
45            return Ok(Scalar::new(sum_dtype, sum).into());
46        }
47
48        let sum_scalar = sum_impl(array, sum_dtype, kernels)?;
49
50        // Update the statistics with the computed sum.
51        array
52            .statistics()
53            .set(Stat::Sum, Precision::Exact(sum_scalar.value().clone()));
54
55        Ok(sum_scalar.into())
56    }
57
58    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
59        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
60        Stat::Sum
61            .dtype(array.dtype())
62            .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))
63    }
64
65    fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
66        // The sum function always returns a single scalar value.
67        Ok(1)
68    }
69
70    fn is_elementwise(&self) -> bool {
71        false
72    }
73}
74
75pub static SUM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
76    let compute = ComputeFn::new("sum".into(), ArcRef::new_ref(&Sum));
77    for kernel in inventory::iter::<SumKernelRef> {
78        compute.register_kernel(kernel.0.clone());
79    }
80    compute
81});
82
/// A type-erased sum kernel, wrapped in a dedicated type so it can be collected by `inventory`.
///
/// Every collected `SumKernelRef` is registered on [`SUM_FN`] when that static is initialised.
pub struct SumKernelRef(ArcRef<dyn Kernel>);
// Declare `SumKernelRef` as an `inventory` collection so instances can be iterated with
// `inventory::iter::<SumKernelRef>`.
inventory::collect!(SumKernelRef);
85
/// A sum implementation provided by a [`VTable`] for its own array type (`Self::Array`).
pub trait SumKernel: VTable {
    /// Sums `array` and returns the result as a scalar.
    ///
    /// # Preconditions
    ///
    /// * The array's DType is summable
    /// * The array is not all-null
    fn sum(&self, array: &Self::Array) -> VortexResult<Scalar>;
}
93
/// Adapts a [`SumKernel`] implementation `V` to the type-erased [`Kernel`] interface.
#[derive(Debug)]
pub struct SumKernelAdapter<V: VTable>(pub V);

impl<V: VTable + SumKernel> SumKernelAdapter<V> {
    /// Wraps a `'static` reference to this adapter in a [`SumKernelRef`].
    ///
    /// The reference is stored via `ArcRef::new_ref`; being a `const fn`, this can be evaluated
    /// in constant contexts.
    pub const fn lift(&'static self) -> SumKernelRef {
        SumKernelRef(ArcRef::new_ref(self))
    }
}
102
103impl<V: VTable + SumKernel> Kernel for SumKernelAdapter<V> {
104    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
105        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
106        let Some(array) = array.as_opt::<V>() else {
107            return Ok(None);
108        };
109        Ok(Some(V::sum(&self.0, array)?.into()))
110    }
111}
112
113/// Sum an array.
114///
115/// If the sum overflows, a null scalar will be returned.
116/// If the sum is not supported for the array's dtype, an error will be raised.
117/// If the array is all-invalid, the sum will be zero.
118pub fn sum_impl(
119    array: &dyn Array,
120    sum_dtype: DType,
121    kernels: &[ArcRef<dyn Kernel>],
122) -> VortexResult<Scalar> {
123    if array.is_empty() {
124        return if sum_dtype.is_float() {
125            Ok(Scalar::new(sum_dtype, 0.0.into()))
126        } else {
127            Ok(Scalar::new(sum_dtype, 0.into()))
128        };
129    }
130
131    // If the array is constant, we can compute the sum directly.
132    if let Some(mut constant) = array.as_constant() {
133        if constant.is_null() {
134            // An all-null constant array has a sum of 0.
135            return if sum_dtype.is_float() {
136                Ok(Scalar::new(sum_dtype, 0.0.into()))
137            } else {
138                Ok(Scalar::new(sum_dtype, 0.into()))
139            };
140        }
141
142        // TODO(ngates): I think we should delegate these to kernels, rather than hard-code.
143
144        // If it's an extension array, then unwrap it into the storage scalar.
145        if let Some(extension) = constant.as_extension_opt() {
146            constant = extension.storage();
147        }
148
149        // If it's a boolean array, then the true count is the sum, which is the length.
150        if let Some(bool) = constant.as_bool_opt() {
151            return if bool.value().vortex_expect("already checked for null value") {
152                // Constant true
153                Ok(Scalar::new(sum_dtype, array.len().into()))
154            } else {
155                // Constant false
156                Ok(Scalar::new(sum_dtype, 0.into()))
157            };
158        }
159
160        // If it's a primitive array, then the sum is the constant value times the length.
161        if let Some(primitive) = constant.as_primitive_opt() {
162            match primitive.ptype() {
163                PType::U8 | PType::U16 | PType::U32 | PType::U64 => {
164                    let value = primitive
165                        .pvalue()
166                        .vortex_expect("already checked for null value")
167                        .as_u64()
168                        .vortex_expect("Failed to cast constant value to u64");
169
170                    // Overflow results in a null sum.
171                    let sum = value.checked_mul(array.len() as u64);
172
173                    return Ok(Scalar::new(sum_dtype, sum.into()));
174                }
175                PType::I8 | PType::I16 | PType::I32 | PType::I64 => {
176                    let value = primitive
177                        .pvalue()
178                        .vortex_expect("already checked for null value")
179                        .as_i64()
180                        .vortex_expect("Failed to cast constant value to i64");
181
182                    // Overflow results in a null sum.
183                    let sum = value.checked_mul(array.len() as i64);
184
185                    return Ok(Scalar::new(sum_dtype, sum.into()));
186                }
187                PType::F16 | PType::F32 | PType::F64 => {
188                    let value = primitive
189                        .pvalue()
190                        .vortex_expect("already checked for null value")
191                        .as_f64()
192                        .vortex_expect("Failed to cast constant value to f64");
193
194                    let sum = value * (array.len() as f64);
195
196                    return Ok(Scalar::new(sum_dtype, sum.into()));
197                }
198            }
199        }
200
201        // For the unsupported types, we should have exited earlier.
202        unreachable!("Unsupported sum constant: {}", constant.dtype());
203    }
204
205    // Try to find a sum kernel
206    let args = InvocationArgs {
207        inputs: &[array.into()],
208        options: &(),
209    };
210    for kernel in kernels {
211        if let Some(output) = kernel.invoke(&args)? {
212            return output.unwrap_scalar();
213        }
214    }
215    if let Some(output) = array.invoke(&SUM_FN, &args)? {
216        return output.unwrap_scalar();
217    }
218
219    // Otherwise, canonicalize and try again.
220    log::debug!("No sum implementation found for {}", array.encoding_id());
221    if array.is_canonical() {
222        // Panic to avoid recursion, but it should never be hit.
223        vortex_panic!(
224            "No sum implementation found for canonical array: {}",
225            array.encoding_id()
226        );
227    }
228    sum(array.to_canonical()?.as_ref())
229}
230
#[cfg(test)]
mod test {
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::sum;

    /// An integer array with no valid entries sums to zero.
    #[test]
    fn sum_all_invalid() {
        let nulls = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]);
        let total = sum(nulls.as_ref()).unwrap();
        let value = total.as_primitive().as_::<i32>().unwrap();
        assert_eq!(value, Some(0));
    }

    /// A float array with no valid entries sums to zero.
    #[test]
    fn sum_all_invalid_float() {
        let nulls = PrimitiveArray::from_option_iter::<f32, _>([None, None, None]);
        let total = sum(nulls.as_ref()).unwrap();
        let value = total.as_primitive().as_::<f32>().unwrap();
        assert_eq!(value, Some(0.0));
    }

    /// Four equal integers sum to four times their value.
    #[test]
    fn sum_constant() {
        let ones = PrimitiveArray::from_iter([1, 1, 1, 1]);
        let total = sum(ones.as_ref()).unwrap();
        let value = total.as_primitive().as_::<i32>().unwrap();
        assert_eq!(value, Some(4));
    }

    /// Four equal floats sum to four times their value.
    #[test]
    fn sum_constant_float() {
        let ones = PrimitiveArray::from_iter([1., 1., 1., 1.]);
        let total = sum(ones.as_ref()).unwrap();
        let value = total.as_primitive().as_::<f32>().unwrap();
        assert_eq!(value, Some(4.));
    }

    /// The sum of a boolean array is its number of `true` values.
    #[test]
    fn sum_boolean() {
        let flags = BoolArray::from_iter([true, false, false, true]);
        let total = sum(flags.as_ref()).unwrap();
        let value = total.as_primitive().as_::<i32>().unwrap();
        assert_eq!(value, Some(2));
    }
}