vortex_array/arrays/constant/compute/
sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::{CheckedMul, ToPrimitive};
5use vortex_dtype::{DType, NativePType, match_each_native_ptype};
6use vortex_error::{VortexResult, vortex_bail, vortex_err};
7use vortex_scalar::{FromPrimitiveOrF16, PrimitiveScalar, Scalar, ScalarValue};
8
9use crate::arrays::{ConstantArray, ConstantVTable};
10use crate::compute::{SumKernel, SumKernelAdapter};
11use crate::register_kernel;
12use crate::stats::Stat;
13
14impl SumKernel for ConstantVTable {
15    fn sum(&self, array: &ConstantArray) -> VortexResult<Scalar> {
16        // Compute the expected dtype of the sum.
17        let sum_dtype = Stat::Sum
18            .dtype(array.dtype())
19            .ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
20
21        let sum_value = sum_scalar(array.scalar(), array.len())?;
22        Ok(Scalar::new(sum_dtype, sum_value))
23    }
24}
25
26fn sum_scalar(scalar: &Scalar, len: usize) -> VortexResult<ScalarValue> {
27    match scalar.dtype() {
28        DType::Bool(_) => Ok(ScalarValue::from(match scalar.as_bool().value() {
29            None => unreachable!("Handled before reaching this point"),
30            Some(false) => 0u64,
31            Some(true) => len as u64,
32        })),
33        DType::Primitive(ptype, _) => Ok(match_each_native_ptype!(
34            ptype,
35            unsigned: |T| { sum_integral::<u64>(scalar.as_primitive(), len)?.into() },
36            signed: |T| { sum_integral::<i64>(scalar.as_primitive(), len)?.into() },
37            floating: |T| { sum_float(scalar.as_primitive(), len)?.into() }
38        )),
39        DType::Extension(_) => sum_scalar(&scalar.as_extension().storage(), len),
40        _ => vortex_bail!("Unsupported dtype for sum: {}", scalar.dtype()),
41    }
42}
43
44fn sum_integral<T>(
45    primitive_scalar: PrimitiveScalar<'_>,
46    array_len: usize,
47) -> VortexResult<Option<T>>
48where
49    T: FromPrimitiveOrF16 + NativePType + CheckedMul,
50    Scalar: From<Option<T>>,
51{
52    let v = primitive_scalar.as_::<T>();
53    let array_len =
54        T::from(array_len).ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
55    let sum = v.and_then(|v| v.checked_mul(&array_len));
56
57    Ok(sum)
58}
59
60fn sum_float(primitive_scalar: PrimitiveScalar<'_>, array_len: usize) -> VortexResult<Option<f64>> {
61    let v = primitive_scalar.as_::<f64>();
62    let array_len = array_len
63        .to_f64()
64        .ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
65
66    Ok(v.map(|v| v * array_len))
67}
68
// Registers the `SumKernel` implementation for `ConstantVTable`, wrapped in `SumKernelAdapter`
// and lifted via `.lift()`.
// NOTE(review): presumably this is what lets the generic `sum` compute function dispatch to
// the constant-array kernel above; confirm against `register_kernel!`.
69register_kernel!(SumKernelAdapter(ConstantVTable).lift());
70
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::compute::sum;

    /// Builds a constant array of `len` copies of `value` and returns its sum.
    ///
    /// Panics if the sum computation reports an error.
    fn constant_sum(value: impl Into<Scalar>, len: usize) -> Scalar {
        let constant = ConstantArray::new(value, len).into_array();
        sum(&constant).unwrap()
    }

    #[test]
    fn test_sum_unsigned() {
        assert_eq!(constant_sum(5u64, 10), Scalar::from(50u64));
    }

    #[test]
    fn test_sum_signed() {
        assert_eq!(constant_sum(-5i64, 10), Scalar::from(-50i64));
    }

    #[test]
    fn test_sum_nullable_value() {
        let null_u32 = Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable));
        assert!(constant_sum(null_u32, 10).is_null());
    }

    #[test]
    fn test_sum_bool_false() {
        assert_eq!(constant_sum(false, 10), Scalar::from(0u64));
    }

    #[test]
    fn test_sum_bool_true() {
        assert_eq!(constant_sum(true, 10), Scalar::from(10u64));
    }

    #[test]
    fn test_sum_bool_null() {
        let null_bool = Scalar::null(DType::Bool(Nullability::Nullable));
        assert!(constant_sum(null_bool, 10).is_null());
    }
}