vortex_array/arrays/constant/compute/
sum.rs1use num_traits::{CheckedMul, ToPrimitive};
5use vortex_dtype::{DType, NativePType, match_each_native_ptype};
6use vortex_error::{VortexResult, vortex_bail, vortex_err};
7use vortex_scalar::{FromPrimitiveOrF16, PrimitiveScalar, Scalar, ScalarValue};
8
9use crate::arrays::{ConstantArray, ConstantVTable};
10use crate::compute::{SumKernel, SumKernelAdapter};
11use crate::register_kernel;
12use crate::stats::Stat;
13
14impl SumKernel for ConstantVTable {
15 fn sum(&self, array: &ConstantArray) -> VortexResult<Scalar> {
16 let sum_dtype = Stat::Sum
18 .dtype(array.dtype())
19 .ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
20
21 let sum_value = sum_scalar(array.scalar(), array.len())?;
22 Ok(Scalar::new(sum_dtype, sum_value))
23 }
24}
25
26fn sum_scalar(scalar: &Scalar, len: usize) -> VortexResult<ScalarValue> {
27 match scalar.dtype() {
28 DType::Bool(_) => Ok(ScalarValue::from(match scalar.as_bool().value() {
29 None => unreachable!("Handled before reaching this point"),
30 Some(false) => 0u64,
31 Some(true) => len as u64,
32 })),
33 DType::Primitive(ptype, _) => Ok(match_each_native_ptype!(
34 ptype,
35 unsigned: |T| { sum_integral::<u64>(scalar.as_primitive(), len)?.into() },
36 signed: |T| { sum_integral::<i64>(scalar.as_primitive(), len)?.into() },
37 floating: |T| { sum_float(scalar.as_primitive(), len)?.into() }
38 )),
39 DType::Extension(_) => sum_scalar(&scalar.as_extension().storage(), len),
40 _ => vortex_bail!("Unsupported dtype for sum: {}", scalar.dtype()),
41 }
42}
43
44fn sum_integral<T>(
45 primitive_scalar: PrimitiveScalar<'_>,
46 array_len: usize,
47) -> VortexResult<Option<T>>
48where
49 T: FromPrimitiveOrF16 + NativePType + CheckedMul,
50 Scalar: From<Option<T>>,
51{
52 let v = primitive_scalar.as_::<T>();
53 let array_len =
54 T::from(array_len).ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
55 let sum = v.and_then(|v| v.checked_mul(&array_len));
56
57 Ok(sum)
58}
59
60fn sum_float(primitive_scalar: PrimitiveScalar<'_>, array_len: usize) -> VortexResult<Option<f64>> {
61 let v = primitive_scalar.as_::<f64>();
62 let array_len = array_len
63 .to_f64()
64 .ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
65
66 Ok(v.map(|v| v * array_len))
67}
68
// Registers the `SumKernel` implementation of `ConstantVTable`, wrapped in `SumKernelAdapter`.
// NOTE(review): presumably this is what lets the `sum` compute function dispatch to the
// kernel above — confirm against the `register_kernel!` definition.
register_kernel!(SumKernelAdapter(ConstantVTable).lift());
70
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::compute::sum;

    /// Ten repetitions of `5u64` sum to `50u64`.
    #[test]
    fn test_sum_unsigned() {
        let constant = ConstantArray::new(5u64, 10).into_array();
        let expected: Scalar = 50u64.into();

        let total = sum(&constant).unwrap();

        assert_eq!(total, expected);
    }

    /// Ten repetitions of `-5i64` sum to `-50i64`.
    #[test]
    fn test_sum_signed() {
        let constant = ConstantArray::new(-5i64, 10).into_array();
        let expected: Scalar = (-50i64).into();

        let total = sum(&constant).unwrap();

        assert_eq!(total, expected);
    }

    /// A constant array holding a null `u32` scalar sums to null.
    #[test]
    fn test_sum_nullable_value() {
        let null_u32 = Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable));
        let constant = ConstantArray::new(null_u32, 10).into_array();

        let total = sum(&constant).unwrap();

        assert!(total.is_null());
    }

    /// Ten repetitions of `false` sum to `0u64`.
    #[test]
    fn test_sum_bool_false() {
        let constant = ConstantArray::new(false, 10).into_array();
        let expected: Scalar = 0u64.into();

        let total = sum(&constant).unwrap();

        assert_eq!(total, expected);
    }

    /// Ten repetitions of `true` sum to `10u64`.
    #[test]
    fn test_sum_bool_true() {
        let constant = ConstantArray::new(true, 10).into_array();
        let expected: Scalar = 10u64.into();

        let total = sum(&constant).unwrap();

        assert_eq!(total, expected);
    }

    /// A constant array holding a null boolean scalar sums to null.
    #[test]
    fn test_sum_bool_null() {
        let null_bool = Scalar::null(DType::Bool(Nullability::Nullable));
        let constant = ConstantArray::new(null_bool, 10).into_array();

        let total = sum(&constant).unwrap();

        assert!(total.is_null());
    }
}