vortex_array/arrays/chunked/compute/sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::PrimInt;
5use vortex_dtype::{NativePType, PType, match_each_native_ptype};
6use vortex_error::{VortexExpect, VortexResult, vortex_err};
7use vortex_scalar::{FromPrimitiveOrF16, Scalar};
8
9use crate::arrays::{ChunkedArray, ChunkedVTable};
10use crate::compute::{SumKernel, SumKernelAdapter, sum};
11use crate::stats::Stat;
12use crate::{ArrayRef, register_kernel};
13
14impl SumKernel for ChunkedVTable {
15    fn sum(&self, array: &ChunkedArray) -> VortexResult<Scalar> {
16        let sum_dtype = Stat::Sum
17            .dtype(array.dtype())
18            .ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
19        let sum_ptype = PType::try_from(&sum_dtype).vortex_expect("sum dtype must be primitive");
20
21        let scalar_value = match_each_native_ptype!(
22            sum_ptype,
23            unsigned: |T| { sum_int::<u64>(array.chunks())?.into() },
24            signed: |T| { sum_int::<i64>(array.chunks())?.into() },
25            floating: |T| { sum_float(array.chunks())?.into() }
26        );
27
28        Ok(Scalar::new(sum_dtype, scalar_value))
29    }
30}
31
32register_kernel!(SumKernelAdapter(ChunkedVTable).lift());
33
34fn sum_int<T: NativePType + PrimInt + FromPrimitiveOrF16>(
35    chunks: &[ArrayRef],
36) -> VortexResult<Option<T>> {
37    let mut result = T::zero();
38    for chunk in chunks {
39        let chunk_sum = sum(chunk)?;
40
41        let Some(chunk_sum) = chunk_sum.as_primitive().as_::<T>()? else {
42            // Bail out on overflow
43            return Ok(None);
44        };
45
46        let Some(chunk_result) = result.checked_add(&chunk_sum) else {
47            // Bail out on overflow
48            return Ok(None);
49        };
50
51        result = chunk_result;
52    }
53    Ok(Some(result))
54}
55
56fn sum_float(chunks: &[ArrayRef]) -> VortexResult<f64> {
57    let mut result = 0f64;
58    for chunk in chunks {
59        let chunk_sum = sum(chunk)?;
60        let chunk_sum = chunk_sum
61            .as_primitive()
62            .as_::<f64>()?
63            .vortex_expect("Float sum should never be null");
64        result += chunk_sum;
65    }
66    Ok(result)
67}