vortex_array/arrays/primitive/compute/
sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools;
5use num_traits::{CheckedAdd, Float, ToPrimitive};
6use vortex_buffer::BitBuffer;
7use vortex_dtype::{NativePType, match_each_native_ptype};
8use vortex_error::{VortexExpect, VortexResult};
9use vortex_mask::AllOr;
10use vortex_scalar::Scalar;
11
12use crate::arrays::{PrimitiveArray, PrimitiveVTable};
13use crate::compute::{SumKernel, SumKernelAdapter};
14use crate::register_kernel;
15
16impl SumKernel for PrimitiveVTable {
17    fn sum(&self, array: &PrimitiveArray, accumulator: &Scalar) -> VortexResult<Scalar> {
18        let array_sum_scalar = match array.validity_mask().bit_buffer() {
19            AllOr::All => {
20                // All-valid
21                match_each_native_ptype!(
22                    array.ptype(),
23                    unsigned: |T| { sum_integer::<_, u64>(array.as_slice::<T>(), accumulator.as_primitive().as_::<u64>().vortex_expect("cannot be null")).into() },
24                    signed: |T| { sum_integer::<_, i64>(array.as_slice::<T>(), accumulator.as_primitive().as_::<i64>().vortex_expect("cannot be null")).into() },
25                    floating: |T| { Some(sum_float(array.as_slice::<T>(), accumulator.as_primitive().as_::<f64>().vortex_expect("cannot be null"))).into() }
26                )
27            }
28            AllOr::None => {
29                // All-invalid, return accumulator
30                return Ok(accumulator.clone());
31            }
32            AllOr::Some(validity_mask) => {
33                // Some-valid
34                match_each_native_ptype!(
35                    array.ptype(),
36                    unsigned: |T| {
37                        sum_integer_with_validity::<_, u64>(array.as_slice::<T>(), validity_mask, accumulator.as_primitive().as_::<u64>().vortex_expect("cannot be null")).into()
38                    },
39                    signed: |T| {
40                        sum_integer_with_validity::<_, i64>(array.as_slice::<T>(), validity_mask, accumulator.as_primitive().as_::<i64>().vortex_expect("cannot be null")).into()
41                    },
42                    floating: |T| {
43                        Some(sum_float_with_validity(array.as_slice::<T>(), validity_mask, accumulator.as_primitive().as_::<f64>().vortex_expect("cannot be null"))).into()
44                    }
45                )
46            }
47        };
48
49        Ok(array_sum_scalar)
50    }
51}
52
53register_kernel!(SumKernelAdapter(PrimitiveVTable).lift());
54
55fn sum_integer<T: NativePType + ToPrimitive, R: NativePType + CheckedAdd>(
56    values: &[T],
57    accumulator: R,
58) -> Option<R> {
59    let mut sum = accumulator;
60    for &x in values {
61        sum = sum.checked_add(&R::from(x)?)?;
62    }
63    Some(sum)
64}
65
66fn sum_integer_with_validity<T: NativePType + ToPrimitive, R: NativePType + CheckedAdd>(
67    values: &[T],
68    validity: &BitBuffer,
69    accumulator: R,
70) -> Option<R> {
71    let mut sum: R = accumulator;
72    for (&x, valid) in values.iter().zip_eq(validity.iter()) {
73        if valid {
74            sum = sum.checked_add(&R::from(x)?)?;
75        }
76    }
77    Some(sum)
78}
79
80fn sum_float<T: NativePType + Float>(values: &[T], accumulator: f64) -> f64 {
81    let mut sum = accumulator;
82    for &x in values {
83        sum += x.to_f64().vortex_expect("Failed to cast value to f64");
84    }
85    sum
86}
87
88fn sum_float_with_validity<T: NativePType + Float>(
89    array: &[T],
90    validity: &BitBuffer,
91    accumulator: f64,
92) -> f64 {
93    let mut sum = accumulator;
94    for (&x, valid) in array.iter().zip_eq(validity.iter()) {
95        if valid {
96            sum += x.to_f64().vortex_expect("Failed to cast value to f64");
97        }
98    }
99    sum
100}