vortex_array/arrays/chunked/compute/
sum.rs1use num_traits::PrimInt;
2use vortex_dtype::{NativePType, PType, match_each_native_ptype};
3use vortex_error::{VortexExpect, VortexResult, vortex_err};
4use vortex_scalar::{FromPrimitiveOrF16, Scalar};
5
6use crate::arrays::{ChunkedArray, ChunkedVTable};
7use crate::compute::{SumKernel, SumKernelAdapter, sum};
8use crate::stats::Stat;
9use crate::{ArrayRef, register_kernel};
10
11impl SumKernel for ChunkedVTable {
12    fn sum(&self, array: &ChunkedArray) -> VortexResult<Scalar> {
13        let sum_dtype = Stat::Sum
14            .dtype(array.dtype())
15            .ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
16        let sum_ptype = PType::try_from(&sum_dtype).vortex_expect("sum dtype must be primitive");
17
18        let scalar_value = match_each_native_ptype!(
19            sum_ptype,
20            unsigned: |$T| { sum_int::<u64>(array.chunks())?.into() }
21            signed: |$T| { sum_int::<i64>(array.chunks())?.into() }
22            floating: |$T| { sum_float(array.chunks())?.into() }
23        );
24
25        Ok(Scalar::new(sum_dtype, scalar_value))
26    }
27}
28
29register_kernel!(SumKernelAdapter(ChunkedVTable).lift());
30
31fn sum_int<T: NativePType + PrimInt + FromPrimitiveOrF16>(
32    chunks: &[ArrayRef],
33) -> VortexResult<Option<T>> {
34    let mut result = T::zero();
35    for chunk in chunks {
36        let chunk_sum = sum(chunk)?;
37
38        let Some(chunk_sum) = chunk_sum.as_primitive().as_::<T>()? else {
39            return Ok(None);
41        };
42
43        let Some(chunk_result) = result.checked_add(&chunk_sum) else {
44            return Ok(None);
46        };
47
48        result = chunk_result;
49    }
50    Ok(Some(result))
51}
52
53fn sum_float(chunks: &[ArrayRef]) -> VortexResult<f64> {
54    let mut result = 0f64;
55    for chunk in chunks {
56        let chunk_sum = sum(chunk)?;
57        let chunk_sum = chunk_sum
58            .as_primitive()
59            .as_::<f64>()?
60            .vortex_expect("Float sum should never be null");
61        result += chunk_sum;
62    }
63    Ok(result)
64}