// vortex_array/arrays/primitive/compute/sum.rs

1use arrow_buffer::BooleanBuffer;
2use itertools::Itertools;
3use num_traits::{CheckedAdd, Float, ToPrimitive};
4use vortex_dtype::{NativePType, match_each_native_ptype};
5use vortex_error::{VortexExpect, VortexResult};
6use vortex_mask::AllOr;
7use vortex_scalar::Scalar;
8
9use crate::arrays::{PrimitiveArray, PrimitiveVTable};
10use crate::compute::{SumKernel, SumKernelAdapter};
11use crate::register_kernel;
12use crate::stats::Stat;
13
14impl SumKernel for PrimitiveVTable {
15    fn sum(&self, array: &PrimitiveArray) -> VortexResult<Scalar> {
16        Ok(match array.validity_mask()?.boolean_buffer() {
17            AllOr::All => {
18                // All-valid
19                match_each_native_ptype!(
20                    array.ptype(),
21                    unsigned: |T| { sum_integer::<_, u64>(array.as_slice::<T>()).into() },
22                    signed: |T| { sum_integer::<_, i64>(array.as_slice::<T>()).into() },
23                    floating: |T| { sum_float(array.as_slice::<T>()).into() }
24                )
25            }
26            AllOr::None => {
27                // All-invalid
28                return Ok(Scalar::null(
29                    Stat::Sum
30                        .dtype(array.dtype())
31                        .vortex_expect("Sum dtype must be defined for primitive type"),
32                ));
33            }
34            AllOr::Some(validity_mask) => {
35                // Some-valid
36                match_each_native_ptype!(
37                    array.ptype(),
38                    unsigned: |T| {
39                        sum_integer_with_validity::<_, u64>(array.as_slice::<T>(), validity_mask)
40                            .into()
41                    },
42                    signed: |T| {
43                        sum_integer_with_validity::<_, i64>(array.as_slice::<T>(), validity_mask)
44                            .into()
45                    },
46                    floating: |T| {
47                        sum_float_with_validity(array.as_slice::<T>(), validity_mask).into()
48                    }
49                )
50            }
51        })
52    }
53}
54
55register_kernel!(SumKernelAdapter(PrimitiveVTable).lift());
56
57fn sum_integer<T: NativePType + ToPrimitive, R: NativePType + CheckedAdd>(
58    values: &[T],
59) -> Option<R> {
60    let mut sum = R::zero();
61    for &x in values {
62        sum = sum.checked_add(&R::from(x)?)?;
63    }
64    Some(sum)
65}
66
67fn sum_integer_with_validity<T: NativePType + ToPrimitive, R: NativePType + CheckedAdd>(
68    values: &[T],
69    validity: &BooleanBuffer,
70) -> Option<R> {
71    let mut sum = R::zero();
72    for (&x, valid) in values.iter().zip_eq(validity.iter()) {
73        if valid {
74            sum = sum.checked_add(&R::from(x)?)?;
75        }
76    }
77    Some(sum)
78}
79
80fn sum_float<T: NativePType + Float>(values: &[T]) -> f64 {
81    let mut sum = 0.0;
82    for &x in values {
83        sum += x.to_f64().vortex_expect("Failed to cast value to f64");
84    }
85    sum
86}
87
88fn sum_float_with_validity<T: NativePType + Float>(array: &[T], validity: &BooleanBuffer) -> f64 {
89    let mut sum = 0.0;
90    for (&x, valid) in array.iter().zip_eq(validity.iter()) {
91        if valid {
92            sum += x.to_f64().vortex_expect("Failed to cast value to f64");
93        }
94    }
95    sum
96}