vortex_array/compute/
sum.rs

1use std::sync::LazyLock;
2
3use vortex_dtype::{DType, PType};
4use vortex_error::{
5    VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err, vortex_panic,
6};
7use vortex_scalar::Scalar;
8
9use crate::Array;
10use crate::arcref::ArcRef;
11use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output};
12use crate::encoding::Encoding;
13use crate::stats::{Precision, Stat, StatsProvider};
14
15/// Sum an array.
16///
17/// If the sum overflows, a null scalar will be returned.
18/// If the sum is not supported for the array's dtype, an error will be raised.
19/// If the array is all-invalid, the sum will be zero.
20pub fn sum(array: &dyn Array) -> VortexResult<Scalar> {
21    SUM_FN
22        .invoke(&InvocationArgs {
23            inputs: &[array.into()],
24            options: &(),
25        })?
26        .unwrap_scalar()
27}
28
29struct Sum;
30
31impl ComputeFnVTable for Sum {
32    fn invoke(
33        &self,
34        args: &InvocationArgs,
35        kernels: &[ArcRef<dyn Kernel>],
36    ) -> VortexResult<Output> {
37        let SumArgs { array } = SumArgs::try_from(args)?;
38
39        // Compute the expected dtype of the sum.
40        let sum_dtype = self.return_dtype(args)?;
41
42        // Short-circuit using array statistics.
43        if let Some(Precision::Exact(sum)) = array.statistics().get(Stat::Sum) {
44            return Ok(Scalar::new(sum_dtype, sum).into());
45        }
46
47        let sum_scalar = sum_impl(array, sum_dtype, kernels)?;
48
49        // Update the statistics with the computed sum.
50        array
51            .statistics()
52            .set(Stat::Sum, Precision::Exact(sum_scalar.value().clone()));
53
54        Ok(sum_scalar.into())
55    }
56
57    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
58        let SumArgs { array } = SumArgs::try_from(args)?;
59        Stat::Sum
60            .dtype(array.dtype())
61            .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))
62    }
63
64    fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
65        // The sum function always returns a single scalar value.
66        Ok(1)
67    }
68
69    fn is_elementwise(&self) -> bool {
70        false
71    }
72}
73
74struct SumArgs<'a> {
75    array: &'a dyn Array,
76}
77
78impl<'a> TryFrom<&InvocationArgs<'a>> for SumArgs<'a> {
79    type Error = VortexError;
80
81    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
82        if value.inputs.len() != 1 {
83            vortex_bail!(
84                "Sum function requires exactly one argument, got {}",
85                value.inputs.len()
86            );
87        }
88        let array = value.inputs[0]
89            .array()
90            .ok_or_else(|| vortex_err!("Invalid argument type for sum function"))?;
91
92        Ok(SumArgs { array })
93    }
94}
95
96pub static SUM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
97    let compute = ComputeFn::new("sum".into(), ArcRef::new_ref(&Sum));
98    for kernel in inventory::iter::<SumKernelRef> {
99        compute.register_kernel(kernel.0.clone());
100    }
101    compute
102});
103
104pub struct SumKernelRef(ArcRef<dyn Kernel>);
105inventory::collect!(SumKernelRef);
106
107pub trait SumKernel: Encoding {
108    /// # Preconditions
109    ///
110    /// * The array's DType is summable
111    /// * The array is not all-null
112    fn sum(&self, array: &Self::Array) -> VortexResult<Scalar>;
113}
114
115#[derive(Debug)]
116pub struct SumKernelAdapter<E: Encoding>(pub E);
117
118impl<E: Encoding + SumKernel> SumKernelAdapter<E> {
119    pub const fn lift(&'static self) -> SumKernelRef {
120        SumKernelRef(ArcRef::new_ref(self))
121    }
122}
123
124impl<E: Encoding + SumKernel> Kernel for SumKernelAdapter<E> {
125    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
126        let SumArgs { array } = SumArgs::try_from(args)?;
127        let Some(array) = array.as_any().downcast_ref::<E::Array>() else {
128            return Ok(None);
129        };
130        Ok(Some(E::sum(&self.0, array)?.into()))
131    }
132}
133
134/// Sum an array.
135///
136/// If the sum overflows, a null scalar will be returned.
137/// If the sum is not supported for the array's dtype, an error will be raised.
138/// If the array is all-invalid, the sum will be zero.
139pub fn sum_impl(
140    array: &dyn Array,
141    sum_dtype: DType,
142    kernels: &[ArcRef<dyn Kernel>],
143) -> VortexResult<Scalar> {
144    if array.is_empty() {
145        return if sum_dtype.is_float() {
146            Ok(Scalar::new(sum_dtype, 0.0.into()))
147        } else {
148            Ok(Scalar::new(sum_dtype, 0.into()))
149        };
150    }
151
152    // If the array is constant, we can compute the sum directly.
153    if let Some(mut constant) = array.as_constant() {
154        if constant.is_null() {
155            // An all-null constant array has a sum of 0.
156            return if sum_dtype.is_float() {
157                Ok(Scalar::new(sum_dtype, 0.0.into()))
158            } else {
159                Ok(Scalar::new(sum_dtype, 0.into()))
160            };
161        }
162
163        // TODO(ngates): I think we should delegate these to kernels, rather than hard-code.
164
165        // If it's an extension array, then unwrap it into the storage scalar.
166        if let Some(extension) = constant.as_extension_opt() {
167            constant = extension.storage();
168        }
169
170        // If it's a boolean array, then the true count is the sum, which is the length.
171        if let Some(bool) = constant.as_bool_opt() {
172            return if bool.value().vortex_expect("already checked for null value") {
173                // Constant true
174                Ok(Scalar::new(sum_dtype, array.len().into()))
175            } else {
176                // Constant false
177                Ok(Scalar::new(sum_dtype, 0.into()))
178            };
179        }
180
181        // If it's a primitive array, then the sum is the constant value times the length.
182        if let Some(primitive) = constant.as_primitive_opt() {
183            match primitive.ptype() {
184                PType::U8 | PType::U16 | PType::U32 | PType::U64 => {
185                    let value = primitive
186                        .pvalue()
187                        .vortex_expect("already checked for null value")
188                        .as_u64()
189                        .vortex_expect("Failed to cast constant value to u64");
190
191                    // Overflow results in a null sum.
192                    let sum = value.checked_mul(array.len() as u64);
193
194                    return Ok(Scalar::new(sum_dtype, sum.into()));
195                }
196                PType::I8 | PType::I16 | PType::I32 | PType::I64 => {
197                    let value = primitive
198                        .pvalue()
199                        .vortex_expect("already checked for null value")
200                        .as_i64()
201                        .vortex_expect("Failed to cast constant value to i64");
202
203                    // Overflow results in a null sum.
204                    let sum = value.checked_mul(array.len() as i64);
205
206                    return Ok(Scalar::new(sum_dtype, sum.into()));
207                }
208                PType::F16 | PType::F32 | PType::F64 => {
209                    let value = primitive
210                        .pvalue()
211                        .vortex_expect("already checked for null value")
212                        .as_f64()
213                        .vortex_expect("Failed to cast constant value to f64");
214
215                    let sum = value * (array.len() as f64);
216
217                    return Ok(Scalar::new(sum_dtype, sum.into()));
218                }
219            }
220        }
221
222        // For the unsupported types, we should have exited earlier.
223        unreachable!("Unsupported sum constant: {}", constant.dtype());
224    }
225
226    // Try to find a sum kernel
227    let args = InvocationArgs {
228        inputs: &[array.into()],
229        options: &(),
230    };
231    for kernel in kernels {
232        if let Some(output) = kernel.invoke(&args)? {
233            return output.unwrap_scalar();
234        }
235    }
236    if let Some(output) = array.invoke(&SUM_FN, &args)? {
237        return output.unwrap_scalar();
238    }
239
240    // Otherwise, canonicalize and try again.
241    log::debug!("No sum implementation found for {}", array.encoding());
242    if array.is_canonical() {
243        // Panic to avoid recursion, but it should never be hit.
244        vortex_panic!(
245            "No sum implementation found for canonical array: {}",
246            array.encoding()
247        );
248    }
249    sum(array.to_canonical()?.as_ref())
250}
251
#[cfg(test)]
mod test {
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::sum;

    #[test]
    fn sum_all_invalid() {
        let array = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]);
        let result = sum(&array).unwrap();
        assert_eq!(result.as_primitive().as_::<i32>().unwrap(), Some(0));
    }

    #[test]
    fn sum_all_invalid_float() {
        let array = PrimitiveArray::from_option_iter::<f32, _>([None, None, None]);
        let result = sum(&array).unwrap();
        assert_eq!(result.as_primitive().as_::<f32>().unwrap(), Some(0.0));
    }

    #[test]
    fn sum_empty() {
        let array = PrimitiveArray::from_iter(Vec::<i32>::new());
        let result = sum(&array).unwrap();
        assert_eq!(result.as_primitive().as_::<i32>().unwrap(), Some(0));
    }

    #[test]
    fn sum_constant() {
        let array = PrimitiveArray::from_iter([1, 1, 1, 1]);
        let result = sum(&array).unwrap();
        assert_eq!(result.as_primitive().as_::<i32>().unwrap(), Some(4));
    }

    #[test]
    fn sum_constant_float() {
        // Float literals default to f64, so the array (and its sum) are f64.
        let array = PrimitiveArray::from_iter([1.0f64, 1.0, 1.0, 1.0]);
        let result = sum(&array).unwrap();
        assert_eq!(result.as_primitive().as_::<f64>().unwrap(), Some(4.0));
    }

    #[test]
    fn sum_overflow_is_null() {
        // Documented contract: an overflowing sum yields a null scalar.
        let array = PrimitiveArray::from_iter([i64::MAX, i64::MAX]);
        let result = sum(&array).unwrap();
        assert!(result.is_null());
    }

    #[test]
    fn sum_boolean() {
        let array = BoolArray::from_iter([true, false, false, true]);
        let result = sum(&array).unwrap();
        assert_eq!(result.as_primitive().as_::<i32>().unwrap(), Some(2));
    }
}