vortex_array/arrays/constant/compute/sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_array::ArrowNativeTypeOp;
5use num_traits::{CheckedAdd, CheckedMul, ToPrimitive};
6use vortex_dtype::{DType, DecimalDType, NativePType, Nullability, i256, match_each_native_ptype};
7use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
8use vortex_scalar::{DecimalScalar, DecimalValue, PrimitiveScalar, Scalar, ScalarValue};
9
10use crate::arrays::{ConstantArray, ConstantVTable};
11use crate::compute::{SumKernel, SumKernelAdapter};
12use crate::register_kernel;
13use crate::stats::Stat;
14
15impl SumKernel for ConstantVTable {
16    fn sum(&self, array: &ConstantArray, accumulator: &Scalar) -> VortexResult<Scalar> {
17        // Compute the expected dtype of the sum.
18        let sum_dtype = Stat::Sum
19            .dtype(array.dtype())
20            .ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
21
22        let sum_value = sum_scalar(array.scalar(), array.len(), accumulator)?;
23        Ok(Scalar::new(sum_dtype, sum_value))
24    }
25}
26
27fn sum_scalar(scalar: &Scalar, len: usize, accumulator: &Scalar) -> VortexResult<ScalarValue> {
28    match scalar.dtype() {
29        DType::Bool(_) => {
30            let count = match scalar.as_bool().value() {
31                None => unreachable!("Handled before reaching this point"),
32                Some(false) => 0u64,
33                Some(true) => len as u64,
34            };
35            let accumulator = accumulator
36                .as_primitive()
37                .as_::<u64>()
38                .vortex_expect("cannot be null");
39            Ok(ScalarValue::from(accumulator.checked_add(count)))
40        }
41        DType::Primitive(ptype, _) => {
42            let result = match_each_native_ptype!(
43                ptype,
44                unsigned: |T| { sum_integral::<u64>(scalar.as_primitive(), len, accumulator)?.into() },
45                signed: |T| { sum_integral::<i64>(scalar.as_primitive(), len, accumulator)?.into() },
46                floating: |T| { sum_float(scalar.as_primitive(), len, accumulator)?.into() }
47            );
48            Ok(result)
49        }
50        DType::Decimal(decimal_dtype, _) => {
51            sum_decimal(scalar.as_decimal(), len, *decimal_dtype, accumulator)
52        }
53        DType::Extension(_) => sum_scalar(&scalar.as_extension().storage(), len, accumulator),
54        dtype => vortex_bail!("Unsupported dtype for sum: {}", dtype),
55    }
56}
57
58fn sum_decimal(
59    decimal_scalar: DecimalScalar,
60    array_len: usize,
61    decimal_dtype: DecimalDType,
62    accumulator: &Scalar,
63) -> VortexResult<ScalarValue> {
64    let result_dtype = Stat::Sum
65        .dtype(&DType::Decimal(decimal_dtype, Nullability::Nullable))
66        .vortex_expect("decimal supports sum");
67    let result_decimal_type = result_dtype
68        .as_decimal_opt()
69        .vortex_expect("must be decimal");
70
71    let Some(value) = decimal_scalar.decimal_value() else {
72        // Null value: return null
73        return Ok(ScalarValue::null());
74    };
75
76    // Convert array_len to DecimalValue for multiplication
77    let len_value = DecimalValue::I256(i256::from_i128(array_len as i128));
78
79    // Multiply value * len
80    let array_sum = value.checked_mul(&len_value).and_then(|result| {
81        // Check if result fits in the precision
82        result
83            .fits_in_precision(*result_decimal_type)
84            .unwrap_or(false)
85            .then_some(result)
86    });
87
88    // Add accumulator to array_sum
89    let initial_decimal = DecimalScalar::try_from(accumulator)?;
90    let initial_dec_value = initial_decimal
91        .decimal_value()
92        .unwrap_or(DecimalValue::I256(i256::ZERO));
93
94    match array_sum {
95        Some(array_sum_value) => {
96            let total = array_sum_value
97                .checked_add(&initial_dec_value)
98                .and_then(|result| {
99                    result
100                        .fits_in_precision(*result_decimal_type)
101                        .unwrap_or(false)
102                        .then_some(result)
103                });
104            match total {
105                Some(result_value) => Ok(ScalarValue::from(result_value)),
106                None => Ok(ScalarValue::null()), // Overflow
107            }
108        }
109        None => Ok(ScalarValue::null()), // Overflow
110    }
111}
112
113fn sum_integral<T>(
114    primitive_scalar: PrimitiveScalar<'_>,
115    array_len: usize,
116    accumulator: &Scalar,
117) -> VortexResult<Option<T>>
118where
119    T: NativePType + CheckedMul + CheckedAdd,
120    Scalar: From<Option<T>>,
121{
122    let v = primitive_scalar.as_::<T>();
123    let array_len =
124        T::from(array_len).ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
125    let Some(array_sum) = v.and_then(|v| v.checked_mul(&array_len)) else {
126        return Ok(None);
127    };
128
129    let initial = accumulator
130        .as_primitive()
131        .as_::<T>()
132        .vortex_expect("cannot be null");
133    Ok(initial.checked_add(&array_sum))
134}
135
136fn sum_float(
137    primitive_scalar: PrimitiveScalar<'_>,
138    array_len: usize,
139    accumulator: &Scalar,
140) -> VortexResult<Option<f64>> {
141    let v = primitive_scalar
142        .as_::<f64>()
143        .vortex_expect("cannot be null");
144    let array_len = array_len
145        .to_f64()
146        .ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
147
148    let Ok(array_sum) = v.mul_checked(array_len) else {
149        return Ok(None);
150    };
151    let initial = accumulator
152        .as_primitive()
153        .as_::<f64>()
154        .vortex_expect("cannot be null");
155    Ok(Some(initial + array_sum))
156}
157
// Registers the `SumKernel` implementation for `ConstantVTable`, wrapped in `SumKernelAdapter`,
// through `register_kernel!`. Presumably this lets the generic sum dispatch find this kernel for
// constant arrays; confirm against the `register_kernel!` definition.
register_kernel!(SumKernelAdapter(ConstantVTable).lift());
159
#[cfg(test)]
mod tests {
    // Behavioral tests for the constant-array sum kernel, exercised through `compute::sum`.
    use vortex_dtype::Nullability::Nullable;
    use vortex_dtype::{DType, DecimalDType, Nullability, PType, i256};
    use vortex_scalar::{DecimalValue, Scalar};

    use crate::arrays::ConstantArray;
    use crate::compute::sum;
    use crate::stats::Stat;
    use crate::{Array, IntoArray};

    #[test]
    fn test_sum_unsigned() {
        let array = ConstantArray::new(5u64, 10).into_array();
        let result = sum(&array).unwrap();
        assert_eq!(result, 50u64.into());
    }

    #[test]
    fn test_sum_signed() {
        let array = ConstantArray::new(-5i64, 10).into_array();
        let result = sum(&array).unwrap();
        assert_eq!(result, (-50i64).into());
    }

    /// Narrow signed inputs are widened to `i64` before summing.
    #[test]
    fn test_sum_signed_widens_to_i64() {
        let array = ConstantArray::new(-3i32, 4).into_array();
        let result = sum(&array).unwrap();
        assert_eq!(result, (-12i64).into());
    }

    /// Floating-point constants are summed as `f64`.
    #[test]
    fn test_sum_float() {
        let array = ConstantArray::new(2.5f64, 4).into_array();
        let result = sum(&array).unwrap();
        assert_eq!(result, 10.0f64.into());
    }

    /// Integer overflow while scaling the constant by the length produces a null sum.
    #[test]
    fn test_sum_unsigned_overflow_is_null() {
        let array = ConstantArray::new(u64::MAX, 2).into_array();
        let result = sum(&array).unwrap();
        assert!(result.is_null());
    }

    #[test]
    fn test_sum_nullable_value() {
        let array = ConstantArray::new(Scalar::null(DType::Primitive(PType::U32, Nullable)), 10)
            .into_array();
        let result = sum(&array).unwrap();
        assert_eq!(result, Scalar::primitive(0u64, Nullable));
    }

    #[test]
    fn test_sum_bool_false() {
        let array = ConstantArray::new(false, 10).into_array();
        let result = sum(&array).unwrap();
        assert_eq!(result, 0u64.into());
    }

    #[test]
    fn test_sum_bool_true() {
        let array = ConstantArray::new(true, 10).into_array();
        let result = sum(&array).unwrap();
        assert_eq!(result, 10u64.into());
    }

    #[test]
    fn test_sum_bool_null() {
        let array = ConstantArray::new(Scalar::null(DType::Bool(Nullable)), 10).into_array();
        let result = sum(&array).unwrap();
        assert_eq!(result, Scalar::primitive(0u64, Nullable));
    }

    #[test]
    fn test_sum_decimal() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = ConstantArray::new(
            Scalar::decimal(
                DecimalValue::I64(100),
                decimal_dtype,
                Nullability::NonNullable,
            ),
            5,
        )
        .into_array();

        let result = sum(&array).unwrap();

        assert_eq!(
            result.as_decimal().decimal_value(),
            Some(DecimalValue::I256(i256::from_i128(500)))
        );
        assert_eq!(result.dtype(), &Stat::Sum.dtype(array.dtype()).unwrap());
    }

    #[test]
    fn test_sum_decimal_null() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = ConstantArray::new(Scalar::null(DType::Decimal(decimal_dtype, Nullable)), 10)
            .into_array();

        let result = sum(&array).unwrap();
        assert_eq!(
            result,
            Scalar::decimal(
                DecimalValue::I256(i256::ZERO),
                DecimalDType::new(20, 2),
                Nullable
            )
        );
    }

    #[test]
    fn test_sum_decimal_large_value() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = ConstantArray::new(
            Scalar::decimal(
                DecimalValue::I64(999_999_999),
                decimal_dtype,
                Nullability::NonNullable,
            ),
            100,
        )
        .into_array();

        let result = sum(&array).unwrap();
        assert_eq!(
            result.as_decimal().decimal_value(),
            Some(DecimalValue::I256(i256::from_i128(99_999_999_900)))
        );
    }
}