Skip to main content

vortex_array/arrays/constant/compute/
sum.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::AsPrimitive;
5use num_traits::CheckedAdd;
6use num_traits::CheckedMul;
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_error::vortex_err;
11
12use crate::arrays::ConstantArray;
13use crate::arrays::ConstantVTable;
14use crate::compute::SumKernel;
15use crate::compute::SumKernelAdapter;
16use crate::dtype::DType;
17use crate::dtype::DecimalDType;
18use crate::dtype::NativePType;
19use crate::dtype::Nullability;
20use crate::dtype::i256;
21use crate::expr::stats::Stat;
22use crate::match_each_native_ptype;
23use crate::register_kernel;
24use crate::scalar::DecimalScalar;
25use crate::scalar::DecimalValue;
26use crate::scalar::PrimitiveScalar;
27use crate::scalar::Scalar;
28use crate::scalar::ScalarValue;
29
30impl SumKernel for ConstantVTable {
31    fn sum(&self, array: &ConstantArray, accumulator: &Scalar) -> VortexResult<Scalar> {
32        // Compute the expected dtype of the sum.
33        let sum_dtype = Stat::Sum
34            .dtype(array.dtype())
35            .ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
36
37        let sum_value = sum_scalar(array.scalar(), array.len(), accumulator)?;
38        Scalar::try_new(sum_dtype, sum_value)
39    }
40}
41
42fn sum_scalar(
43    scalar: &Scalar,
44    len: usize,
45    accumulator: &Scalar,
46) -> VortexResult<Option<ScalarValue>> {
47    match scalar.dtype() {
48        DType::Bool(_) => {
49            let count = match scalar.as_bool().value() {
50                None => unreachable!("Handled before reaching this point"),
51                Some(false) => 0u64,
52                Some(true) => len as u64,
53            };
54            let accumulator = accumulator
55                .as_primitive()
56                .as_::<u64>()
57                .vortex_expect("cannot be null");
58            Ok(accumulator
59                .checked_add(count)
60                .map(|v| ScalarValue::Primitive(v.into())))
61        }
62        DType::Primitive(ptype, _) => {
63            #[expect(dead_code, reason = "TODO(connor): good question")]
64            let result = match_each_native_ptype!(
65                ptype,
66                unsigned: |T| { sum_integral::<u64>(scalar.as_primitive(), len, accumulator)?.map(|v| ScalarValue::Primitive(v.into())) },
67                signed: |T| { sum_integral::<i64>(scalar.as_primitive(), len, accumulator)?.map(|v| ScalarValue::Primitive(v.into())) },
68                floating: |T| { sum_float(scalar.as_primitive(), len, accumulator)?.map(|v| ScalarValue::Primitive(v.into())) }
69            );
70            Ok(result)
71        }
72        DType::Decimal(decimal_dtype, _) => {
73            sum_decimal(scalar.as_decimal(), len, *decimal_dtype, accumulator)
74        }
75        DType::Extension(_) => {
76            sum_scalar(&scalar.as_extension().to_storage_scalar(), len, accumulator)
77        }
78        dtype => vortex_bail!("Unsupported dtype for sum: {}", dtype),
79    }
80}
81
82fn sum_decimal(
83    decimal_scalar: DecimalScalar,
84    array_len: usize,
85    decimal_dtype: DecimalDType,
86    accumulator: &Scalar,
87) -> VortexResult<Option<ScalarValue>> {
88    let result_dtype = Stat::Sum
89        .dtype(&DType::Decimal(decimal_dtype, Nullability::Nullable))
90        .vortex_expect("decimal supports sum");
91    let result_decimal_type = result_dtype
92        .as_decimal_opt()
93        .vortex_expect("must be decimal");
94
95    let Some(value) = decimal_scalar.decimal_value() else {
96        // Null value: return null
97        return Ok(None);
98    };
99
100    // Convert array_len to DecimalValue for multiplication.
101    let len_value = DecimalValue::I256(i256::from_i128(array_len as i128));
102
103    let Some(array_sum) = value
104        .checked_mul(&len_value)
105        .filter(|d| d.fits_in_precision(*result_decimal_type))
106    else {
107        return Ok(None);
108    };
109
110    // Add accumulator to array_sum.
111    let initial_decimal = accumulator.as_decimal();
112    let initial_dec_value = initial_decimal
113        .decimal_value()
114        .unwrap_or(DecimalValue::I256(i256::ZERO));
115
116    let total = array_sum
117        .checked_add(&initial_dec_value)
118        .and_then(|result| {
119            result
120                .fits_in_precision(*result_decimal_type)
121                .then_some(result)
122        });
123    match total {
124        Some(result_value) => Ok(Some(ScalarValue::from(result_value))),
125        None => Ok(None), // Overflow
126    }
127}
128
129fn sum_integral<T>(
130    primitive_scalar: PrimitiveScalar<'_>,
131    array_len: usize,
132    accumulator: &Scalar,
133) -> VortexResult<Option<T>>
134where
135    T: NativePType + CheckedMul + CheckedAdd,
136{
137    let v = primitive_scalar.as_::<T>();
138    let array_len =
139        T::from(array_len).ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
140    let Some(array_sum) = v.and_then(|v| v.checked_mul(&array_len)) else {
141        return Ok(None);
142    };
143
144    let initial = accumulator
145        .as_primitive()
146        .as_::<T>()
147        .vortex_expect("cannot be null");
148    Ok(initial.checked_add(&array_sum))
149}
150
151fn sum_float(
152    primitive_scalar: PrimitiveScalar<'_>,
153    array_len: usize,
154    accumulator: &Scalar,
155) -> VortexResult<Option<f64>> {
156    let initial = accumulator
157        .as_primitive()
158        .as_::<f64>()
159        .vortex_expect("cannot be null");
160    let v = primitive_scalar
161        .as_::<f64>()
162        .vortex_expect("cannot be null");
163    let len_f64: f64 = array_len.as_();
164
165    Ok(Some(initial + v * len_f64))
166}
167
168register_kernel!(SumKernelAdapter(ConstantVTable).lift());
169
#[cfg(test)]
mod tests {
    use vortex_error::VortexExpect;

    use crate::Array;
    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::compute::sum;
    use crate::compute::sum_with_accumulator;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::Nullability::Nullable;
    use crate::dtype::PType;
    use crate::dtype::i256;
    use crate::expr::stats::Stat;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;

    /// Builds a non-nullable decimal scalar with the given value and decimal dtype.
    fn non_null_decimal(value: DecimalValue, dtype: DecimalDType) -> Scalar {
        Scalar::decimal(value, dtype, Nullability::NonNullable)
    }

    #[test]
    fn test_sum_unsigned() {
        let constant = ConstantArray::new(5u64, 10).into_array();
        let expected: Scalar = 50u64.into();
        assert_eq!(sum(&constant).unwrap(), expected);
    }

    #[test]
    fn test_sum_signed() {
        let constant = ConstantArray::new(-5i64, 10).into_array();
        let expected: Scalar = (-50i64).into();
        assert_eq!(sum(&constant).unwrap(), expected);
    }

    #[test]
    fn test_sum_nullable_value() {
        let null_u32 = Scalar::null(DType::Primitive(PType::U32, Nullable));
        let constant = ConstantArray::new(null_u32, 10).into_array();
        assert_eq!(sum(&constant).unwrap(), Scalar::primitive(0u64, Nullable));
    }

    #[test]
    fn test_sum_bool_false() {
        let constant = ConstantArray::new(false, 10).into_array();
        let expected: Scalar = 0u64.into();
        assert_eq!(sum(&constant).unwrap(), expected);
    }

    #[test]
    fn test_sum_bool_true() {
        let constant = ConstantArray::new(true, 10).into_array();
        let expected: Scalar = 10u64.into();
        assert_eq!(sum(&constant).unwrap(), expected);
    }

    #[test]
    fn test_sum_bool_null() {
        let null_bool = Scalar::null(DType::Bool(Nullable));
        let constant = ConstantArray::new(null_bool, 10).into_array();
        assert_eq!(sum(&constant).unwrap(), Scalar::primitive(0u64, Nullable));
    }

    #[test]
    fn test_sum_decimal() {
        let element = non_null_decimal(DecimalValue::I64(100), DecimalDType::new(10, 2));
        let constant = ConstantArray::new(element, 5).into_array();

        let total = sum(&constant).unwrap();

        assert_eq!(
            total.as_decimal().decimal_value(),
            Some(DecimalValue::I256(i256::from_i128(500)))
        );
        let expected_dtype = Stat::Sum.dtype(constant.dtype()).unwrap();
        assert_eq!(total.dtype(), &expected_dtype);
    }

    #[test]
    fn test_sum_decimal_null() {
        let null_decimal = Scalar::null(DType::Decimal(DecimalDType::new(10, 2), Nullable));
        let constant = ConstantArray::new(null_decimal, 10).into_array();

        let expected = Scalar::decimal(
            DecimalValue::I256(i256::ZERO),
            DecimalDType::new(20, 2),
            Nullable,
        );
        assert_eq!(sum(&constant).unwrap(), expected);
    }

    #[test]
    fn test_sum_decimal_large_value() {
        let element = non_null_decimal(DecimalValue::I64(999_999_999), DecimalDType::new(10, 2));
        let constant = ConstantArray::new(element, 100).into_array();

        let total = sum(&constant).unwrap();

        assert_eq!(
            total.as_decimal().decimal_value(),
            Some(DecimalValue::I256(i256::from_i128(99_999_999_900)))
        );
    }

    #[test]
    fn test_sum_float_non_multiply() {
        let accumulator = Scalar::primitive(-2048669276050936500000000000f64, Nullable);
        let constant = ConstantArray::new(6.1811675e16f64, 25);

        let total = sum_with_accumulator(constant.as_ref(), &accumulator)
            .vortex_expect("operation should succeed in test");

        assert_eq!(
            f64::try_from(&total).vortex_expect("operation should succeed in test"),
            -2048669274505644600000000000f64
        );
    }
}