Skip to main content

vortex_array/arrays/constant/compute/
sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::AsPrimitive;
5use num_traits::CheckedAdd;
6use num_traits::CheckedMul;
7use vortex_dtype::DType;
8use vortex_dtype::DecimalDType;
9use vortex_dtype::NativePType;
10use vortex_dtype::Nullability;
11use vortex_dtype::i256;
12use vortex_dtype::match_each_native_ptype;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_error::vortex_err;
17
18use crate::arrays::ConstantArray;
19use crate::arrays::ConstantVTable;
20use crate::compute::SumKernel;
21use crate::compute::SumKernelAdapter;
22use crate::expr::stats::Stat;
23use crate::register_kernel;
24use crate::scalar::DecimalScalar;
25use crate::scalar::DecimalValue;
26use crate::scalar::PrimitiveScalar;
27use crate::scalar::Scalar;
28use crate::scalar::ScalarValue;
29
30impl SumKernel for ConstantVTable {
31    fn sum(&self, array: &ConstantArray, accumulator: &Scalar) -> VortexResult<Scalar> {
32        // Compute the expected dtype of the sum.
33        let sum_dtype = Stat::Sum
34            .dtype(array.dtype())
35            .ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
36
37        let sum_value = sum_scalar(array.scalar(), array.len(), accumulator)?;
38        Scalar::try_new(sum_dtype, sum_value)
39    }
40}
41
42fn sum_scalar(
43    scalar: &Scalar,
44    len: usize,
45    accumulator: &Scalar,
46) -> VortexResult<Option<ScalarValue>> {
47    match scalar.dtype() {
48        DType::Bool(_) => {
49            let count = match scalar.as_bool().value() {
50                None => unreachable!("Handled before reaching this point"),
51                Some(false) => 0u64,
52                Some(true) => len as u64,
53            };
54            let accumulator = accumulator
55                .as_primitive()
56                .as_::<u64>()
57                .vortex_expect("cannot be null");
58            Ok(accumulator
59                .checked_add(count)
60                .map(|v| ScalarValue::Primitive(v.into())))
61        }
62        DType::Primitive(ptype, _) => {
63            let result = match_each_native_ptype!(
64                ptype,
65                unsigned: |T| { sum_integral::<u64>(scalar.as_primitive(), len, accumulator)?.map(|v| ScalarValue::Primitive(v.into())) },
66                signed: |T| { sum_integral::<i64>(scalar.as_primitive(), len, accumulator)?.map(|v| ScalarValue::Primitive(v.into())) },
67                floating: |T| { sum_float(scalar.as_primitive(), len, accumulator)?.map(|v| ScalarValue::Primitive(v.into())) }
68            );
69            Ok(result)
70        }
71        DType::Decimal(decimal_dtype, _) => {
72            sum_decimal(scalar.as_decimal(), len, *decimal_dtype, accumulator)
73        }
74        DType::Extension(_) => {
75            sum_scalar(&scalar.as_extension().to_storage_scalar(), len, accumulator)
76        }
77        dtype => vortex_bail!("Unsupported dtype for sum: {}", dtype),
78    }
79}
80
81fn sum_decimal(
82    decimal_scalar: DecimalScalar,
83    array_len: usize,
84    decimal_dtype: DecimalDType,
85    accumulator: &Scalar,
86) -> VortexResult<Option<ScalarValue>> {
87    let result_dtype = Stat::Sum
88        .dtype(&DType::Decimal(decimal_dtype, Nullability::Nullable))
89        .vortex_expect("decimal supports sum");
90    let result_decimal_type = result_dtype
91        .as_decimal_opt()
92        .vortex_expect("must be decimal");
93
94    let Some(value) = decimal_scalar.decimal_value() else {
95        // Null value: return null
96        return Ok(None);
97    };
98
99    // Convert array_len to DecimalValue for multiplication.
100    let len_value = DecimalValue::I256(i256::from_i128(array_len as i128));
101
102    let Some(array_sum) = value
103        .checked_mul(&len_value)
104        .filter(|d| d.fits_in_precision(*result_decimal_type))
105    else {
106        return Ok(None);
107    };
108
109    // Add accumulator to array_sum.
110    let initial_decimal = accumulator.as_decimal();
111    let initial_dec_value = initial_decimal
112        .decimal_value()
113        .unwrap_or(DecimalValue::I256(i256::ZERO));
114
115    let total = array_sum
116        .checked_add(&initial_dec_value)
117        .and_then(|result| {
118            result
119                .fits_in_precision(*result_decimal_type)
120                .then_some(result)
121        });
122    match total {
123        Some(result_value) => Ok(Some(ScalarValue::from(result_value))),
124        None => Ok(None), // Overflow
125    }
126}
127
128fn sum_integral<T>(
129    primitive_scalar: PrimitiveScalar<'_>,
130    array_len: usize,
131    accumulator: &Scalar,
132) -> VortexResult<Option<T>>
133where
134    T: NativePType + CheckedMul + CheckedAdd,
135{
136    let v = primitive_scalar.as_::<T>();
137    let array_len =
138        T::from(array_len).ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
139    let Some(array_sum) = v.and_then(|v| v.checked_mul(&array_len)) else {
140        return Ok(None);
141    };
142
143    let initial = accumulator
144        .as_primitive()
145        .as_::<T>()
146        .vortex_expect("cannot be null");
147    Ok(initial.checked_add(&array_sum))
148}
149
150fn sum_float(
151    primitive_scalar: PrimitiveScalar<'_>,
152    array_len: usize,
153    accumulator: &Scalar,
154) -> VortexResult<Option<f64>> {
155    let initial = accumulator
156        .as_primitive()
157        .as_::<f64>()
158        .vortex_expect("cannot be null");
159    let v = primitive_scalar
160        .as_::<f64>()
161        .vortex_expect("cannot be null");
162    let len_f64: f64 = array_len.as_();
163
164    Ok(Some(initial + v * len_f64))
165}
166
// Passes the `ConstantVTable` sum kernel, wrapped in `SumKernelAdapter` and lifted, to
// `register_kernel!`.
// NOTE(review): the registration mechanism itself lives in the macro, which is not visible
// from this file — confirm there before relying on dispatch details.
register_kernel!(SumKernelAdapter(ConstantVTable).lift());
168
#[cfg(test)]
mod tests {
    use vortex_dtype::DType;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::Nullability;
    use vortex_dtype::Nullability::Nullable;
    use vortex_dtype::PType;
    use vortex_dtype::i256;
    use vortex_error::VortexExpect;

    use crate::Array;
    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::compute::sum;
    use crate::compute::sum_with_accumulator;
    use crate::expr::stats::Stat;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;

    /// Ten copies of an unsigned 5 sum to 50.
    #[test]
    fn test_sum_unsigned() {
        let constant = ConstantArray::new(5u64, 10).into_array();
        let total = sum(&constant).unwrap();
        assert_eq!(total, 50u64.into());
    }

    /// Ten copies of a signed -5 sum to -50.
    #[test]
    fn test_sum_signed() {
        let constant = ConstantArray::new(-5i64, 10).into_array();
        let total = sum(&constant).unwrap();
        assert_eq!(total, (-50i64).into());
    }

    /// A null `u32` constant sums to a nullable `u64` zero.
    #[test]
    fn test_sum_nullable_value() {
        let null_scalar = Scalar::null(DType::Primitive(PType::U32, Nullable));
        let constant = ConstantArray::new(null_scalar, 10).into_array();
        let total = sum(&constant).unwrap();
        assert_eq!(total, Scalar::primitive(0u64, Nullable));
    }

    /// A constant `false` array contributes nothing.
    #[test]
    fn test_sum_bool_false() {
        let constant = ConstantArray::new(false, 10).into_array();
        let total = sum(&constant).unwrap();
        assert_eq!(total, 0u64.into());
    }

    /// A constant `true` array sums to its length.
    #[test]
    fn test_sum_bool_true() {
        let constant = ConstantArray::new(true, 10).into_array();
        let total = sum(&constant).unwrap();
        assert_eq!(total, 10u64.into());
    }

    /// A null boolean constant sums to a nullable `u64` zero.
    #[test]
    fn test_sum_bool_null() {
        let null_scalar = Scalar::null(DType::Bool(Nullable));
        let constant = ConstantArray::new(null_scalar, 10).into_array();
        let total = sum(&constant).unwrap();
        assert_eq!(total, Scalar::primitive(0u64, Nullable));
    }

    /// Five copies of decimal 100 sum to 500, typed as the `Stat::Sum` dtype of the input.
    #[test]
    fn test_sum_decimal() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let element = Scalar::decimal(
            DecimalValue::I64(100),
            decimal_dtype,
            Nullability::NonNullable,
        );
        let constant = ConstantArray::new(element, 5).into_array();

        let total = sum(&constant).unwrap();

        let expected_value = DecimalValue::I256(i256::from_i128(500));
        assert_eq!(total.as_decimal().decimal_value(), Some(expected_value));
        assert_eq!(total.dtype(), &Stat::Sum.dtype(constant.dtype()).unwrap());
    }

    /// A null decimal(10, 2) constant sums to a nullable decimal(20, 2) zero.
    #[test]
    fn test_sum_decimal_null() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let null_scalar = Scalar::null(DType::Decimal(decimal_dtype, Nullable));
        let constant = ConstantArray::new(null_scalar, 10).into_array();

        let total = sum(&constant).unwrap();

        let expected = Scalar::decimal(
            DecimalValue::I256(i256::ZERO),
            DecimalDType::new(20, 2),
            Nullable,
        );
        assert_eq!(total, expected);
    }

    /// A hundred copies of 999_999_999 sum to 99_999_999_900 as a decimal.
    #[test]
    fn test_sum_decimal_large_value() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let element = Scalar::decimal(
            DecimalValue::I64(999_999_999),
            decimal_dtype,
            Nullability::NonNullable,
        );
        let constant = ConstantArray::new(element, 100).into_array();

        let total = sum(&constant).unwrap();

        let expected_value = DecimalValue::I256(i256::from_i128(99_999_999_900));
        assert_eq!(total.as_decimal().decimal_value(), Some(expected_value));
    }

    /// Summing floats into a large-magnitude accumulator yields the pinned `f64` result.
    #[test]
    fn test_sum_float_non_multiply() {
        let initial = -2048669276050936500000000000f64;
        let constant = ConstantArray::new(6.1811675e16f64, 25);

        let total = sum_with_accumulator(constant.as_ref(), &Scalar::primitive(initial, Nullable))
            .vortex_expect("operation should succeed in test");

        let observed = f64::try_from(&total).vortex_expect("operation should succeed in test");
        assert_eq!(observed, -2048669274505644600000000000f64);
    }
}
301}