Skip to main content

vortex_array/compute/
sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use num_traits::CheckedAdd;
8use num_traits::CheckedSub;
9use vortex_error::VortexError;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_ensure;
13use vortex_error::vortex_err;
14use vortex_error::vortex_panic;
15
16use crate::Array;
17use crate::ArrayRef;
18use crate::IntoArray as _;
19use crate::compute::ComputeFn;
20use crate::compute::ComputeFnVTable;
21use crate::compute::InvocationArgs;
22use crate::compute::Kernel;
23use crate::compute::Output;
24use crate::dtype::DType;
25use crate::expr::stats::Precision;
26use crate::expr::stats::Stat;
27use crate::expr::stats::StatsProvider;
28use crate::scalar::NumericOperator;
29use crate::scalar::Scalar;
30use crate::vtable::VTable;
31
32static SUM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
33    let compute = ComputeFn::new("sum".into(), ArcRef::new_ref(&Sum));
34    for kernel in inventory::iter::<SumKernelRef> {
35        compute.register_kernel(kernel.0.clone());
36    }
37    compute
38});
39
40pub(crate) fn warm_up_vtable() -> usize {
41    SUM_FN.kernels().len()
42}
43
44/// Sum an array with an initial value.
45///
46/// If the sum overflows, a null scalar will be returned.
47/// If the sum is not supported for the array's dtype, an error will be raised.
48/// If the array is all-invalid, the sum will be the accumulator.
49/// The accumulator must have a dtype compatible with the sum result dtype.
50pub(crate) fn sum_with_accumulator(array: &ArrayRef, accumulator: &Scalar) -> VortexResult<Scalar> {
51    SUM_FN
52        .invoke(&InvocationArgs {
53            inputs: &[array.into(), accumulator.into()],
54            options: &(),
55        })?
56        .unwrap_scalar()
57}
58
59/// Sum an array, starting from zero.
60///
61/// If the sum overflows, a null scalar will be returned.
62/// If the sum is not supported for the array's dtype, an error will be raised.
63/// If the array is all-invalid, the sum will be zero.
64pub fn sum(array: &ArrayRef) -> VortexResult<Scalar> {
65    let sum_dtype = Stat::Sum
66        .dtype(array.dtype())
67        .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))?;
68    let zero = Scalar::zero_value(&sum_dtype);
69    sum_with_accumulator(array, &zero)
70}
71
72/// For unary compute functions, it's useful to just have this short-cut.
73pub struct SumArgs<'a> {
74    pub array: &'a dyn Array,
75    pub accumulator: &'a Scalar,
76}
77
78impl<'a> TryFrom<&InvocationArgs<'a>> for SumArgs<'a> {
79    type Error = VortexError;
80
81    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
82        if value.inputs.len() != 2 {
83            vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
84        }
85        let array = value.inputs[0]
86            .array()
87            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
88        let accumulator = value.inputs[1]
89            .scalar()
90            .ok_or_else(|| vortex_err!("Expected input 1 to be a scalar"))?;
91        Ok(SumArgs { array, accumulator })
92    }
93}
94
95struct Sum;
96
97impl ComputeFnVTable for Sum {
98    fn invoke(
99        &self,
100        args: &InvocationArgs,
101        kernels: &[ArcRef<dyn Kernel>],
102    ) -> VortexResult<Output> {
103        let SumArgs { array, accumulator } = args.try_into()?;
104        let array = array.to_array();
105
106        // Compute the expected dtype of the sum.
107        let sum_dtype = self.return_dtype(args)?;
108
109        vortex_ensure!(
110            &sum_dtype == accumulator.dtype(),
111            "sum_dtype {sum_dtype} must match accumulator dtype {}",
112            accumulator.dtype()
113        );
114
115        // Short-circuit using array statistics.
116        if let Some(Precision::Exact(sum_scalar)) = array.statistics().get(Stat::Sum) {
117            // For floats only use stats if accumulator is zero. otherwise we might have numerical
118            // stability issues.
119            match &sum_dtype {
120                DType::Primitive(p, _) => {
121                    if p.is_float() && accumulator.is_zero() == Some(true) {
122                        return Ok(sum_scalar.into());
123                    } else if p.is_int() {
124                        let sum_from_stat = accumulator
125                            .as_primitive()
126                            .checked_add(&sum_scalar.as_primitive())
127                            .map(Scalar::from);
128                        return Ok(sum_from_stat
129                            .unwrap_or_else(|| Scalar::null(sum_dtype))
130                            .into());
131                    }
132                }
133                DType::Decimal(..) => {
134                    let sum_from_stat = accumulator
135                        .as_decimal()
136                        .checked_binary_numeric(&sum_scalar.as_decimal(), NumericOperator::Add)
137                        .map(Scalar::from);
138                    return Ok(sum_from_stat
139                        .unwrap_or_else(|| Scalar::null(sum_dtype))
140                        .into());
141                }
142                _ => unreachable!("Sum will always be a decimal or a primitive dtype"),
143            }
144        }
145
146        let sum_scalar = sum_impl(&array, accumulator, kernels)?;
147
148        // Update the statistics with the computed sum. Stored statistic shouldn't include the accumulator.
149        match sum_dtype {
150            DType::Primitive(p, _) => {
151                if p.is_float()
152                    && accumulator.is_zero() == Some(true)
153                    && let Some(sum_value) = sum_scalar.value().cloned()
154                {
155                    array
156                        .statistics()
157                        .set(Stat::Sum, Precision::Exact(sum_value));
158                } else if p.is_int()
159                    && let Some(less_accumulator) = sum_scalar
160                        .as_primitive()
161                        .checked_sub(&accumulator.as_primitive())
162                    && let Some(val) = Scalar::from(less_accumulator).into_value()
163                {
164                    array.statistics().set(Stat::Sum, Precision::Exact(val));
165                }
166            }
167            DType::Decimal(..) => {
168                if let Some(less_accumulator) = sum_scalar
169                    .as_decimal()
170                    .checked_binary_numeric(&accumulator.as_decimal(), NumericOperator::Sub)
171                    && let Some(val) = Scalar::from(less_accumulator).into_value()
172                {
173                    array.statistics().set(Stat::Sum, Precision::Exact(val));
174                }
175            }
176            _ => unreachable!("Sum will always be a decimal or a primitive dtype"),
177        }
178
179        Ok(sum_scalar.into())
180    }
181
182    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
183        let SumArgs { array, .. } = args.try_into()?;
184        Stat::Sum
185            .dtype(array.dtype())
186            .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))
187    }
188
189    fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
190        // The sum function always returns a single scalar value.
191        Ok(1)
192    }
193
194    fn is_elementwise(&self) -> bool {
195        false
196    }
197}
198
199pub struct SumKernelRef(ArcRef<dyn Kernel>);
200inventory::collect!(SumKernelRef);
201
202pub trait SumKernel: VTable {
203    /// # Preconditions
204    ///
205    /// * The array's DType is summable
206    /// * The array is not all-null
207    /// * The accumulator must have a dtype compatible with the sum result dtype
208    fn sum(&self, array: &Self::Array, accumulator: &Scalar) -> VortexResult<Scalar>;
209}
210
211#[derive(Debug)]
212pub struct SumKernelAdapter<V: VTable>(pub V);
213
214impl<V: VTable + SumKernel> SumKernelAdapter<V> {
215    pub const fn lift(&'static self) -> SumKernelRef {
216        SumKernelRef(ArcRef::new_ref(self))
217    }
218}
219
220impl<V: VTable + SumKernel> Kernel for SumKernelAdapter<V> {
221    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
222        let SumArgs { array, accumulator } = args.try_into()?;
223        let Some(array) = array.as_opt::<V>() else {
224            return Ok(None);
225        };
226        Ok(Some(V::sum(&self.0, array, accumulator)?.into()))
227    }
228}
229
230/// Sum an array.
231///
232/// If the sum overflows, a null scalar will be returned.
233/// If the sum is not supported for the array's dtype, an error will be raised.
234/// If the array is all-invalid, the sum will be the accumulator.
235pub fn sum_impl(
236    array: &ArrayRef,
237    accumulator: &Scalar,
238    kernels: &[ArcRef<dyn Kernel>],
239) -> VortexResult<Scalar> {
240    if array.is_empty() || array.all_invalid()? || accumulator.is_null() {
241        return Ok(accumulator.clone());
242    }
243
244    // Try to find a sum kernel
245    let args = InvocationArgs {
246        inputs: &[array.into(), accumulator.into()],
247        options: &(),
248    };
249    for kernel in kernels {
250        if let Some(output) = kernel.invoke(&args)? {
251            return output.unwrap_scalar();
252        }
253    }
254
255    // Otherwise, canonicalize and try again.
256    tracing::debug!("No sum implementation found for {}", array.encoding_id());
257    if array.is_canonical() {
258        // Panic to avoid recursion, but it should never be hit.
259        vortex_panic!(
260            "No sum implementation found for canonical array: {}",
261            array.encoding_id()
262        );
263    }
264    let canonical = array.to_canonical()?.into_array();
265    sum_with_accumulator(&canonical, accumulator)
266}
267
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use crate::IntoArray as _;
    use crate::arrays::BoolArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::PrimitiveArray;
    use crate::compute::sum;
    use crate::compute::sum_with_accumulator;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;

    /// An all-null `i32` array sums to a nullable `i64` zero.
    #[test]
    fn sum_all_invalid() {
        let all_null = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]).into_array();
        let expected = Scalar::primitive(0i64, Nullability::Nullable);
        assert_eq!(sum(&all_null).unwrap(), expected);
    }

    /// An all-null `f32` array sums to a nullable `f64` zero.
    #[test]
    fn sum_all_invalid_float() {
        let all_null = PrimitiveArray::from_option_iter::<f32, _>([None, None, None]).into_array();
        let expected = Scalar::primitive(0f64, Nullability::Nullable);
        assert_eq!(sum(&all_null).unwrap(), expected);
    }

    #[test]
    fn sum_constant() {
        let ones = buffer![1, 1, 1, 1].into_array();
        let total = sum(&ones).unwrap();
        assert_eq!(total.as_primitive().as_::<i32>(), Some(4));
    }

    #[test]
    fn sum_constant_float() {
        let ones = buffer![1., 1., 1., 1.].into_array();
        let total = sum(&ones).unwrap();
        assert_eq!(total.as_primitive().as_::<f32>(), Some(4.));
    }

    /// Booleans sum to the number of `true` values.
    #[test]
    fn sum_boolean() {
        let flags = BoolArray::from_iter([true, false, false, true]).into_array();
        let total = sum(&flags).unwrap();
        assert_eq!(total.as_primitive().as_::<i32>(), Some(2));
    }

    #[test]
    fn sum_stats() {
        let chunks = vec![
            PrimitiveArray::from_iter([1, 1, 1]).into_array(),
            PrimitiveArray::from_iter([2, 2, 2]).into_array(),
        ];
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let array = ChunkedArray::try_new(chunks, dtype)
            .vortex_expect("operation should succeed in test")
            .into_array();

        // Summing with a non-zero accumulator may populate the cached sum statistic.
        let accumulator = Scalar::primitive(2i64, Nullability::Nullable);
        sum_with_accumulator(&array, &accumulator).unwrap();

        // A later plain sum must not include that earlier accumulator, whether it is
        // recomputed or read back from the cached statistic.
        let sum_without_acc = sum(&array).unwrap();
        assert_eq!(
            sum_without_acc,
            Scalar::primitive(9i64, Nullability::Nullable)
        );
    }
}