vortex_array/arrays/constant/compute/sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::AsPrimitive;
5use num_traits::CheckedAdd;
6use num_traits::CheckedMul;
7use vortex_dtype::DType;
8use vortex_dtype::DecimalDType;
9use vortex_dtype::NativePType;
10use vortex_dtype::Nullability;
11use vortex_dtype::i256;
12use vortex_dtype::match_each_native_ptype;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_error::vortex_err;
17use vortex_scalar::DecimalScalar;
18use vortex_scalar::DecimalValue;
19use vortex_scalar::PrimitiveScalar;
20use vortex_scalar::Scalar;
21use vortex_scalar::ScalarValue;
22
23use crate::arrays::ConstantArray;
24use crate::arrays::ConstantVTable;
25use crate::compute::SumKernel;
26use crate::compute::SumKernelAdapter;
27use crate::expr::stats::Stat;
28use crate::register_kernel;
29
30impl SumKernel for ConstantVTable {
31    fn sum(&self, array: &ConstantArray, accumulator: &Scalar) -> VortexResult<Scalar> {
32        // Compute the expected dtype of the sum.
33        let sum_dtype = Stat::Sum
34            .dtype(array.dtype())
35            .ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
36
37        let sum_value = sum_scalar(array.scalar(), array.len(), accumulator)?;
38        Ok(Scalar::new(sum_dtype, sum_value))
39    }
40}
41
42fn sum_scalar(scalar: &Scalar, len: usize, accumulator: &Scalar) -> VortexResult<ScalarValue> {
43    match scalar.dtype() {
44        DType::Bool(_) => {
45            let count = match scalar.as_bool().value() {
46                None => unreachable!("Handled before reaching this point"),
47                Some(false) => 0u64,
48                Some(true) => len as u64,
49            };
50            let accumulator = accumulator
51                .as_primitive()
52                .as_::<u64>()
53                .vortex_expect("cannot be null");
54            Ok(ScalarValue::from(accumulator.checked_add(count)))
55        }
56        DType::Primitive(ptype, _) => {
57            let result = match_each_native_ptype!(
58                ptype,
59                unsigned: |T| { sum_integral::<u64>(scalar.as_primitive(), len, accumulator)?.into() },
60                signed: |T| { sum_integral::<i64>(scalar.as_primitive(), len, accumulator)?.into() },
61                floating: |T| { sum_float(scalar.as_primitive(), len, accumulator)?.into() }
62            );
63            Ok(result)
64        }
65        DType::Decimal(decimal_dtype, _) => {
66            sum_decimal(scalar.as_decimal(), len, *decimal_dtype, accumulator)
67        }
68        DType::Extension(_) => sum_scalar(&scalar.as_extension().storage(), len, accumulator),
69        dtype => vortex_bail!("Unsupported dtype for sum: {}", dtype),
70    }
71}
72
73fn sum_decimal(
74    decimal_scalar: DecimalScalar,
75    array_len: usize,
76    decimal_dtype: DecimalDType,
77    accumulator: &Scalar,
78) -> VortexResult<ScalarValue> {
79    let result_dtype = Stat::Sum
80        .dtype(&DType::Decimal(decimal_dtype, Nullability::Nullable))
81        .vortex_expect("decimal supports sum");
82    let result_decimal_type = result_dtype
83        .as_decimal_opt()
84        .vortex_expect("must be decimal");
85
86    let Some(value) = decimal_scalar.decimal_value() else {
87        // Null value: return null
88        return Ok(ScalarValue::null());
89    };
90
91    // Convert array_len to DecimalValue for multiplication
92    let len_value = DecimalValue::I256(i256::from_i128(array_len as i128));
93
94    // Multiply value * len
95    let array_sum = value.checked_mul(&len_value).and_then(|result| {
96        // Check if result fits in the precision
97        result
98            .fits_in_precision(*result_decimal_type)
99            .unwrap_or(false)
100            .then_some(result)
101    });
102
103    // Add accumulator to array_sum
104    let initial_decimal = DecimalScalar::try_from(accumulator)?;
105    let initial_dec_value = initial_decimal
106        .decimal_value()
107        .unwrap_or(DecimalValue::I256(i256::ZERO));
108
109    match array_sum {
110        Some(array_sum_value) => {
111            let total = array_sum_value
112                .checked_add(&initial_dec_value)
113                .and_then(|result| {
114                    result
115                        .fits_in_precision(*result_decimal_type)
116                        .unwrap_or(false)
117                        .then_some(result)
118                });
119            match total {
120                Some(result_value) => Ok(ScalarValue::from(result_value)),
121                None => Ok(ScalarValue::null()), // Overflow
122            }
123        }
124        None => Ok(ScalarValue::null()), // Overflow
125    }
126}
127
128fn sum_integral<T>(
129    primitive_scalar: PrimitiveScalar<'_>,
130    array_len: usize,
131    accumulator: &Scalar,
132) -> VortexResult<Option<T>>
133where
134    T: NativePType + CheckedMul + CheckedAdd,
135    Scalar: From<Option<T>>,
136{
137    let v = primitive_scalar.as_::<T>();
138    let array_len =
139        T::from(array_len).ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
140    let Some(array_sum) = v.and_then(|v| v.checked_mul(&array_len)) else {
141        return Ok(None);
142    };
143
144    let initial = accumulator
145        .as_primitive()
146        .as_::<T>()
147        .vortex_expect("cannot be null");
148    Ok(initial.checked_add(&array_sum))
149}
150
151fn sum_float(
152    primitive_scalar: PrimitiveScalar<'_>,
153    array_len: usize,
154    accumulator: &Scalar,
155) -> VortexResult<Option<f64>> {
156    let initial = accumulator
157        .as_primitive()
158        .as_::<f64>()
159        .vortex_expect("cannot be null");
160    let v = primitive_scalar
161        .as_::<f64>()
162        .vortex_expect("cannot be null");
163    let len_f64: f64 = array_len.as_();
164
165    Ok(Some(initial + v * len_f64))
166}
167
// Registers the `SumKernel` implementation for `ConstantVTable`, wrapped in `SumKernelAdapter`
// and lifted, through `register_kernel!`.
// NOTE(review): how the compute `sum` entry point discovers and dispatches to registered
// kernels is defined by `register_kernel!`, not here. Confirm there if relying on it.
register_kernel!(SumKernelAdapter(ConstantVTable).lift());
169
#[cfg(test)]
mod tests {
    // Tests for summing constant arrays through the public `sum` / `sum_with_accumulator`
    // entry points.
    use vortex_dtype::DType;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::Nullability;
    use vortex_dtype::Nullability::Nullable;
    use vortex_dtype::PType;
    use vortex_dtype::i256;
    use vortex_error::VortexUnwrap;
    use vortex_scalar::DecimalValue;
    use vortex_scalar::Scalar;

    use crate::Array;
    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::compute::sum;
    use crate::compute::sum_with_accumulator;
    use crate::expr::stats::Stat;

    /// Builds a non-nullable `decimal(10, 2)` constant array holding `len` copies of `raw`.
    fn decimal_10_2_constant(raw: i64, len: usize) -> ConstantArray {
        let value = Scalar::decimal(
            DecimalValue::I64(raw),
            DecimalDType::new(10, 2),
            Nullability::NonNullable,
        );
        ConstantArray::new(value, len)
    }

    #[test]
    fn test_sum_unsigned() {
        let fives = ConstantArray::new(5u64, 10).into_array();
        assert_eq!(sum(&fives).unwrap(), 50u64.into());
    }

    #[test]
    fn test_sum_signed() {
        let negative_fives = ConstantArray::new(-5i64, 10).into_array();
        assert_eq!(sum(&negative_fives).unwrap(), (-50i64).into());
    }

    #[test]
    fn test_sum_nullable_value() {
        let null_u32 = Scalar::null(DType::Primitive(PType::U32, Nullable));
        let nulls = ConstantArray::new(null_u32, 10).into_array();
        assert_eq!(sum(&nulls).unwrap(), Scalar::primitive(0u64, Nullable));
    }

    #[test]
    fn test_sum_bool_false() {
        let all_false = ConstantArray::new(false, 10).into_array();
        assert_eq!(sum(&all_false).unwrap(), 0u64.into());
    }

    #[test]
    fn test_sum_bool_true() {
        let all_true = ConstantArray::new(true, 10).into_array();
        assert_eq!(sum(&all_true).unwrap(), 10u64.into());
    }

    #[test]
    fn test_sum_bool_null() {
        let nulls = ConstantArray::new(Scalar::null(DType::Bool(Nullable)), 10).into_array();
        assert_eq!(sum(&nulls).unwrap(), Scalar::primitive(0u64, Nullable));
    }

    #[test]
    fn test_sum_decimal() {
        let array = decimal_10_2_constant(100, 5).into_array();
        let expected_dtype = Stat::Sum.dtype(array.dtype()).unwrap();

        let total = sum(&array).unwrap();

        assert_eq!(
            total.as_decimal().decimal_value(),
            Some(DecimalValue::I256(i256::from_i128(500)))
        );
        assert_eq!(total.dtype(), &expected_dtype);
    }

    #[test]
    fn test_sum_decimal_null() {
        let null_decimal = Scalar::null(DType::Decimal(DecimalDType::new(10, 2), Nullable));
        let nulls = ConstantArray::new(null_decimal, 10).into_array();

        // The sum of a decimal(10, 2) array is expected to be typed decimal(20, 2).
        let expected = Scalar::decimal(
            DecimalValue::I256(i256::ZERO),
            DecimalDType::new(20, 2),
            Nullable,
        );
        assert_eq!(sum(&nulls).unwrap(), expected);
    }

    #[test]
    fn test_sum_decimal_large_value() {
        let array = decimal_10_2_constant(999_999_999, 100).into_array();

        let total = sum(&array).unwrap();

        assert_eq!(
            total.as_decimal().decimal_value(),
            Some(DecimalValue::I256(i256::from_i128(99_999_999_900)))
        );
    }

    /// Pins the exact `f64` result for a large-magnitude accumulator and constant.
    #[test]
    fn test_sum_float_non_multiply() {
        let accumulator = Scalar::primitive(-2048669276050936500000000000f64, Nullable);
        let array = ConstantArray::new(6.1811675e16f64, 25);

        let total = sum_with_accumulator(array.as_ref(), &accumulator).vortex_unwrap();

        assert_eq!(
            f64::try_from(total).vortex_unwrap(),
            -2048669274505644600000000000f64
        );
    }
}