vortex_array/arrays/primitive/compute/
sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_buffer::BooleanBuffer;
5use itertools::Itertools;
6use num_traits::{CheckedAdd, Float, ToPrimitive};
7use vortex_dtype::{NativePType, match_each_native_ptype};
8use vortex_error::{VortexExpect, VortexResult};
9use vortex_mask::AllOr;
10use vortex_scalar::Scalar;
11
12use crate::arrays::{PrimitiveArray, PrimitiveVTable};
13use crate::compute::{SumKernel, SumKernelAdapter};
14use crate::register_kernel;
15use crate::stats::Stat;
16
17impl SumKernel for PrimitiveVTable {
18    fn sum(&self, array: &PrimitiveArray) -> VortexResult<Scalar> {
19        Ok(match array.validity_mask()?.boolean_buffer() {
20            AllOr::All => {
21                // All-valid
22                match_each_native_ptype!(
23                    array.ptype(),
24                    unsigned: |T| { sum_integer::<_, u64>(array.as_slice::<T>()).into() },
25                    signed: |T| { sum_integer::<_, i64>(array.as_slice::<T>()).into() },
26                    floating: |T| { sum_float(array.as_slice::<T>()).into() }
27                )
28            }
29            AllOr::None => {
30                // All-invalid
31                return Ok(Scalar::null(
32                    Stat::Sum
33                        .dtype(array.dtype())
34                        .vortex_expect("Sum dtype must be defined for primitive type"),
35                ));
36            }
37            AllOr::Some(validity_mask) => {
38                // Some-valid
39                match_each_native_ptype!(
40                    array.ptype(),
41                    unsigned: |T| {
42                        sum_integer_with_validity::<_, u64>(array.as_slice::<T>(), validity_mask)
43                            .into()
44                    },
45                    signed: |T| {
46                        sum_integer_with_validity::<_, i64>(array.as_slice::<T>(), validity_mask)
47                            .into()
48                    },
49                    floating: |T| {
50                        sum_float_with_validity(array.as_slice::<T>(), validity_mask).into()
51                    }
52                )
53            }
54        })
55    }
56}
57
58register_kernel!(SumKernelAdapter(PrimitiveVTable).lift());
59
60fn sum_integer<T: NativePType + ToPrimitive, R: NativePType + CheckedAdd>(
61    values: &[T],
62) -> Option<R> {
63    let mut sum = R::zero();
64    for &x in values {
65        sum = sum.checked_add(&R::from(x)?)?;
66    }
67    Some(sum)
68}
69
70fn sum_integer_with_validity<T: NativePType + ToPrimitive, R: NativePType + CheckedAdd>(
71    values: &[T],
72    validity: &BooleanBuffer,
73) -> Option<R> {
74    let mut sum = R::zero();
75    for (&x, valid) in values.iter().zip_eq(validity.iter()) {
76        if valid {
77            sum = sum.checked_add(&R::from(x)?)?;
78        }
79    }
80    Some(sum)
81}
82
83fn sum_float<T: NativePType + Float>(values: &[T]) -> f64 {
84    let mut sum = 0.0;
85    for &x in values {
86        sum += x.to_f64().vortex_expect("Failed to cast value to f64");
87    }
88    sum
89}
90
91fn sum_float_with_validity<T: NativePType + Float>(array: &[T], validity: &BooleanBuffer) -> f64 {
92    let mut sum = 0.0;
93    for (&x, valid) in array.iter().zip_eq(validity.iter()) {
94        if valid {
95            sum += x.to_f64().vortex_expect("Failed to cast value to f64");
96        }
97    }
98    sum
99}