vortex_array/arrays/primitive/compute/
sum.rs1use arrow_buffer::BooleanBuffer;
2use itertools::Itertools;
3use num_traits::{CheckedAdd, Float, ToPrimitive};
4use vortex_dtype::{NativePType, match_each_native_ptype};
5use vortex_error::{VortexExpect, VortexResult};
6use vortex_mask::AllOr;
7use vortex_scalar::Scalar;
8
9use crate::arrays::{PrimitiveArray, PrimitiveEncoding};
10use crate::compute::{SumKernel, SumKernelAdapter};
11use crate::stats::Stat;
12use crate::variants::PrimitiveArrayTrait;
13use crate::{Array, register_kernel};
14
15impl SumKernel for PrimitiveEncoding {
16 fn sum(&self, array: &PrimitiveArray) -> VortexResult<Scalar> {
17 Ok(match array.validity_mask()?.boolean_buffer() {
18 AllOr::All => {
19 match_each_native_ptype!(
21 array.ptype(),
22 unsigned: |$T| { sum_integer::<_, u64>(array.as_slice::<$T>()).into() }
23 signed: |$T| { sum_integer::<_, i64>(array.as_slice::<$T>()).into() }
24 floating: |$T| { sum_float(array.as_slice::<$T>()).into() }
25 )
26 }
27 AllOr::None => {
28 return Ok(Scalar::null(
30 Stat::Sum
31 .dtype(array.dtype())
32 .vortex_expect("Sum dtype must be defined for primitive type"),
33 ));
34 }
35 AllOr::Some(validity_mask) => {
36 match_each_native_ptype!(
38 array.ptype(),
39 unsigned: |$T| {
40 sum_integer_with_validity::<_, u64>(array.as_slice::<$T>(), validity_mask)
41 .into()
42 }
43 signed: |$T| {
44 sum_integer_with_validity::<_, i64>(array.as_slice::<$T>(), validity_mask)
45 .into()
46 }
47 floating: |$T| {
48 sum_float_with_validity(array.as_slice::<$T>(), validity_mask).into()
49 }
50 )
51 }
52 })
53 }
54}
55
56register_kernel!(SumKernelAdapter(PrimitiveEncoding).lift());
57
58fn sum_integer<T: NativePType + ToPrimitive, R: NativePType + CheckedAdd>(
59 values: &[T],
60) -> Option<R> {
61 let mut sum = R::zero();
62 for &x in values {
63 sum = sum.checked_add(&R::from(x)?)?;
64 }
65 Some(sum)
66}
67
68fn sum_integer_with_validity<T: NativePType + ToPrimitive, R: NativePType + CheckedAdd>(
69 values: &[T],
70 validity: &BooleanBuffer,
71) -> Option<R> {
72 let mut sum = R::zero();
73 for (&x, valid) in values.iter().zip_eq(validity.iter()) {
74 if valid {
75 sum = sum.checked_add(&R::from(x)?)?;
76 }
77 }
78 Some(sum)
79}
80
81fn sum_float<T: NativePType + Float>(values: &[T]) -> f64 {
82 let mut sum = 0.0;
83 for &x in values {
84 sum += x.to_f64().vortex_expect("Failed to cast value to f64");
85 }
86 sum
87}
88
89fn sum_float_with_validity<T: NativePType + Float>(array: &[T], validity: &BooleanBuffer) -> f64 {
90 let mut sum = 0.0;
91 for (&x, valid) in array.iter().zip_eq(validity.iter()) {
92 if valid {
93 sum += x.to_f64().vortex_expect("Failed to cast value to f64");
94 }
95 }
96 sum
97}