vortex_array/arrays/constant/compute/
sum.rs1use num_traits::{CheckedMul, ToPrimitive};
5use vortex_dtype::{NativePType, PType, match_each_native_ptype};
6use vortex_error::{VortexExpect, VortexResult, vortex_err};
7use vortex_scalar::{FromPrimitiveOrF16, PrimitiveScalar, Scalar};
8
9use crate::arrays::{ConstantArray, ConstantVTable};
10use crate::compute::{SumKernel, SumKernelAdapter};
11use crate::register_kernel;
12use crate::stats::Stat;
13
14impl SumKernel for ConstantVTable {
15 fn sum(&self, array: &ConstantArray) -> VortexResult<Scalar> {
16 let sum_dtype = Stat::Sum
17 .dtype(array.dtype())
18 .ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
19 let sum_ptype = PType::try_from(&sum_dtype).vortex_expect("sum dtype must be primitive");
20
21 let scalar = array.scalar();
22
23 let scalar_value = match_each_native_ptype!(
24 sum_ptype,
25 unsigned: |T| { sum_integral::<u64>(scalar.as_primitive(), array.len())?.into() },
26 signed: |T| { sum_integral::<i64>(scalar.as_primitive(), array.len())?.into() },
27 floating: |T| { sum_float(scalar.as_primitive(), array.len())?.into() }
28 );
29
30 Ok(Scalar::new(sum_dtype, scalar_value))
31 }
32}
33
34fn sum_integral<T>(
35 primitive_scalar: PrimitiveScalar<'_>,
36 array_len: usize,
37) -> VortexResult<Option<T>>
38where
39 T: FromPrimitiveOrF16 + NativePType + CheckedMul,
40 Scalar: From<Option<T>>,
41{
42 let v = primitive_scalar.as_::<T>()?;
43 let array_len =
44 T::from(array_len).ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
45 let sum = v.and_then(|v| v.checked_mul(&array_len));
46
47 Ok(sum)
48}
49
50fn sum_float(primitive_scalar: PrimitiveScalar<'_>, array_len: usize) -> VortexResult<Option<f64>> {
51 let v = primitive_scalar.as_::<f64>()?;
52 let array_len = array_len
53 .to_f64()
54 .ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
55
56 Ok(v.map(|v| v * array_len))
57}
58
59register_kernel!(SumKernelAdapter(ConstantVTable).lift());