vortex_array/arrays/primitive/compute/sum.rs

1use arrow_buffer::BooleanBuffer;
2use itertools::Itertools;
3use num_traits::{CheckedAdd, Float, ToPrimitive};
4use vortex_dtype::{NativePType, match_each_native_ptype};
5use vortex_error::{VortexExpect, VortexResult};
6use vortex_mask::AllOr;
7use vortex_scalar::Scalar;
8
9use crate::Array;
10use crate::arrays::{PrimitiveArray, PrimitiveEncoding};
11use crate::compute::SumFn;
12use crate::stats::Stat;
13use crate::variants::PrimitiveArrayTrait;
14
15impl SumFn<&PrimitiveArray> for PrimitiveEncoding {
16    fn sum(&self, array: &PrimitiveArray) -> VortexResult<Scalar> {
17        Ok(match array.validity_mask()?.boolean_buffer() {
18            AllOr::All => {
19                // All-valid
20                match_each_native_ptype!(
21                    array.ptype(),
22                    unsigned: |$T| { sum_integer::<_, u64>(array.as_slice::<$T>()).into() }
23                    signed: |$T| { sum_integer::<_, i64>(array.as_slice::<$T>()).into() }
24                    floating: |$T| { sum_float(array.as_slice::<$T>()).into() }
25                )
26            }
27            AllOr::None => {
28                // All-invalid
29                return Ok(Scalar::null(
30                    Stat::Sum
31                        .dtype(array.dtype())
32                        .vortex_expect("Sum dtype must be defined for primitive type"),
33                ));
34            }
35            AllOr::Some(validity_mask) => {
36                // Some-valid
37                match_each_native_ptype!(
38                    array.ptype(),
39                    unsigned: |$T| {
40                        sum_integer_with_validity::<_, u64>(array.as_slice::<$T>(), validity_mask)
41                            .into()
42                    }
43                    signed: |$T| {
44                        sum_integer_with_validity::<_, i64>(array.as_slice::<$T>(), validity_mask)
45                            .into()
46                    }
47                    floating: |$T| {
48                        sum_float_with_validity(array.as_slice::<$T>(), validity_mask).into()
49                    }
50                )
51            }
52        })
53    }
54}
55
56fn sum_integer<T: NativePType + ToPrimitive, R: NativePType + CheckedAdd>(
57    values: &[T],
58) -> Option<R> {
59    let mut sum = R::zero();
60    for &x in values {
61        sum = sum.checked_add(&R::from(x)?)?;
62    }
63    Some(sum)
64}
65
66fn sum_integer_with_validity<T: NativePType + ToPrimitive, R: NativePType + CheckedAdd>(
67    values: &[T],
68    validity: &BooleanBuffer,
69) -> Option<R> {
70    let mut sum = R::zero();
71    for (&x, valid) in values.iter().zip_eq(validity.iter()) {
72        if valid {
73            sum = sum.checked_add(&R::from(x)?)?;
74        }
75    }
76    Some(sum)
77}
78
79fn sum_float<T: NativePType + Float>(values: &[T]) -> f64 {
80    let mut sum = 0.0;
81    for &x in values {
82        sum += x.to_f64().vortex_expect("Failed to cast value to f64");
83    }
84    sum
85}
86
87fn sum_float_with_validity<T: NativePType + Float>(array: &[T], validity: &BooleanBuffer) -> f64 {
88    let mut sum = 0.0;
89    for (&x, valid) in array.iter().zip_eq(validity.iter()) {
90        if valid {
91            sum += x.to_f64().vortex_expect("Failed to cast value to f64");
92        }
93    }
94    sum
95}