vortex_array/arrays/primitive/compute/sum.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

4use itertools::Itertools;
5use num_traits::CheckedAdd;
6use num_traits::Float;
7use num_traits::ToPrimitive;
8use vortex_buffer::BitBuffer;
9use vortex_dtype::NativePType;
10use vortex_dtype::match_each_native_ptype;
11use vortex_error::VortexExpect;
12use vortex_error::VortexResult;
13use vortex_mask::AllOr;
14use vortex_scalar::Scalar;
15
16use crate::arrays::PrimitiveArray;
17use crate::arrays::PrimitiveVTable;
18use crate::compute::SumKernel;
19use crate::compute::SumKernelAdapter;
20use crate::register_kernel;
21
22impl SumKernel for PrimitiveVTable {
23    fn sum(&self, array: &PrimitiveArray, accumulator: &Scalar) -> VortexResult<Scalar> {
24        let array_sum_scalar = match array.validity_mask().bit_buffer() {
25            AllOr::All => {
26                // All-valid
27                match_each_native_ptype!(
28                    array.ptype(),
29                    unsigned: |T| { sum_integer::<_, u64>(array.as_slice::<T>(), accumulator.as_primitive().as_::<u64>().vortex_expect("cannot be null")).into() },
30                    signed: |T| { sum_integer::<_, i64>(array.as_slice::<T>(), accumulator.as_primitive().as_::<i64>().vortex_expect("cannot be null")).into() },
31                    floating: |T| { Some(sum_float(array.as_slice::<T>(), accumulator.as_primitive().as_::<f64>().vortex_expect("cannot be null"))).into() }
32                )
33            }
34            AllOr::None => {
35                // All-invalid, return accumulator
36                return Ok(accumulator.clone());
37            }
38            AllOr::Some(validity_mask) => {
39                // Some-valid
40                match_each_native_ptype!(
41                    array.ptype(),
42                    unsigned: |T| {
43                        sum_integer_with_validity::<_, u64>(array.as_slice::<T>(), validity_mask, accumulator.as_primitive().as_::<u64>().vortex_expect("cannot be null")).into()
44                    },
45                    signed: |T| {
46                        sum_integer_with_validity::<_, i64>(array.as_slice::<T>(), validity_mask, accumulator.as_primitive().as_::<i64>().vortex_expect("cannot be null")).into()
47                    },
48                    floating: |T| {
49                        Some(sum_float_with_validity(array.as_slice::<T>(), validity_mask, accumulator.as_primitive().as_::<f64>().vortex_expect("cannot be null"))).into()
50                    }
51                )
52            }
53        };
54
55        Ok(array_sum_scalar)
56    }
57}
58
59register_kernel!(SumKernelAdapter(PrimitiveVTable).lift());
60
61fn sum_integer<T: NativePType + ToPrimitive, R: NativePType + CheckedAdd>(
62    values: &[T],
63    accumulator: R,
64) -> Option<R> {
65    let mut sum = accumulator;
66    for &x in values {
67        sum = sum.checked_add(&R::from(x)?)?;
68    }
69    Some(sum)
70}
71
72fn sum_integer_with_validity<T: NativePType + ToPrimitive, R: NativePType + CheckedAdd>(
73    values: &[T],
74    validity: &BitBuffer,
75    accumulator: R,
76) -> Option<R> {
77    let mut sum: R = accumulator;
78    for (&x, valid) in values.iter().zip_eq(validity.iter()) {
79        if valid {
80            sum = sum.checked_add(&R::from(x)?)?;
81        }
82    }
83    Some(sum)
84}
85
86fn sum_float<T: NativePType + Float>(values: &[T], accumulator: f64) -> f64 {
87    let mut sum = accumulator;
88    for &x in values {
89        sum += x.to_f64().vortex_expect("Failed to cast value to f64");
90    }
91    sum
92}
93
94fn sum_float_with_validity<T: NativePType + Float>(
95    array: &[T],
96    validity: &BitBuffer,
97    accumulator: f64,
98) -> f64 {
99    let mut sum = accumulator;
100    for (&x, valid) in array.iter().zip_eq(validity.iter()) {
101        if valid {
102            sum += x.to_f64().vortex_expect("Failed to cast value to f64");
103        }
104    }
105    sum
106}