vortex_array/arrays/constant/compute/
sum.rs

1use num_traits::{CheckedMul, ToPrimitive};
2use vortex_dtype::{NativePType, PType, match_each_native_ptype};
3use vortex_error::{VortexExpect, VortexResult, vortex_err};
4use vortex_scalar::{FromPrimitiveOrF16, PrimitiveScalar, Scalar};
5
6use crate::arrays::{ConstantArray, ConstantVTable};
7use crate::compute::{SumKernel, SumKernelAdapter};
8use crate::register_kernel;
9use crate::stats::Stat;
10
11impl SumKernel for ConstantVTable {
12    fn sum(&self, array: &ConstantArray) -> VortexResult<Scalar> {
13        let sum_dtype = Stat::Sum
14            .dtype(array.dtype())
15            .ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
16        let sum_ptype = PType::try_from(&sum_dtype).vortex_expect("sum dtype must be primitive");
17
18        let scalar = array.scalar();
19
20        let scalar_value = match_each_native_ptype!(
21            sum_ptype,
22            unsigned: |T| { sum_integral::<u64>(scalar.as_primitive(), array.len())?.into() },
23            signed: |T| { sum_integral::<i64>(scalar.as_primitive(), array.len())?.into() },
24            floating: |T| { sum_float(scalar.as_primitive(), array.len())?.into() }
25        );
26
27        Ok(Scalar::new(sum_dtype, scalar_value))
28    }
29}
30
31fn sum_integral<T>(
32    primitive_scalar: PrimitiveScalar<'_>,
33    array_len: usize,
34) -> VortexResult<Option<T>>
35where
36    T: FromPrimitiveOrF16 + NativePType + CheckedMul,
37    Scalar: From<Option<T>>,
38{
39    let v = primitive_scalar.as_::<T>()?;
40    let array_len =
41        T::from(array_len).ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
42    let sum = v.and_then(|v| v.checked_mul(&array_len));
43
44    Ok(sum)
45}
46
47fn sum_float(primitive_scalar: PrimitiveScalar<'_>, array_len: usize) -> VortexResult<Option<f64>> {
48    let v = primitive_scalar.as_::<f64>()?;
49    let array_len = array_len
50        .to_f64()
51        .ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
52
53    Ok(v.map(|v| v * array_len))
54}
55
// Passes the `ConstantVTable` sum kernel, wrapped in `SumKernelAdapter` and lifted, to
// `register_kernel!`. NOTE(review): presumably this adds it to the crate's kernel
// registry so sum dispatch can find it; the macro is defined elsewhere, so confirm.
register_kernel!(SumKernelAdapter(ConstantVTable).lift());