vortex_array/arrays/constant/compute/
sum.rs1use num_traits::{CheckedMul, ToPrimitive};
2use vortex_dtype::{NativePType, PType, match_each_native_ptype};
3use vortex_error::{VortexExpect, VortexResult, vortex_err};
4use vortex_scalar::{FromPrimitiveOrF16, PrimitiveScalar, Scalar};
5
6use crate::arrays::{ConstantArray, ConstantVTable};
7use crate::compute::{SumKernel, SumKernelAdapter};
8use crate::register_kernel;
9use crate::stats::Stat;
10
11impl SumKernel for ConstantVTable {
12 fn sum(&self, array: &ConstantArray) -> VortexResult<Scalar> {
13 let sum_dtype = Stat::Sum
14 .dtype(array.dtype())
15 .ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
16 let sum_ptype = PType::try_from(&sum_dtype).vortex_expect("sum dtype must be primitive");
17
18 let scalar = array.scalar();
19
20 let scalar_value = match_each_native_ptype!(
21 sum_ptype,
22 unsigned: |T| { sum_integral::<u64>(scalar.as_primitive(), array.len())?.into() },
23 signed: |T| { sum_integral::<i64>(scalar.as_primitive(), array.len())?.into() },
24 floating: |T| { sum_float(scalar.as_primitive(), array.len())?.into() }
25 );
26
27 Ok(Scalar::new(sum_dtype, scalar_value))
28 }
29}
30
31fn sum_integral<T>(
32 primitive_scalar: PrimitiveScalar<'_>,
33 array_len: usize,
34) -> VortexResult<Option<T>>
35where
36 T: FromPrimitiveOrF16 + NativePType + CheckedMul,
37 Scalar: From<Option<T>>,
38{
39 let v = primitive_scalar.as_::<T>()?;
40 let array_len =
41 T::from(array_len).ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
42 let sum = v.and_then(|v| v.checked_mul(&array_len));
43
44 Ok(sum)
45}
46
47fn sum_float(primitive_scalar: PrimitiveScalar<'_>, array_len: usize) -> VortexResult<Option<f64>> {
48 let v = primitive_scalar.as_::<f64>()?;
49 let array_len = array_len
50 .to_f64()
51 .ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
52
53 Ok(v.map(|v| v * array_len))
54}
55
// Registers this constant-array `SumKernel` through `register_kernel!`. The registration
// presumably lets the generic sum compute function dispatch `ConstantArray` inputs to the
// kernel above; confirm against the `register_kernel!` definition.
register_kernel!(SumKernelAdapter(ConstantVTable).lift());