vortex_array/arrays/primitive/compute/
sum.rs1use itertools::Itertools;
5use num_traits::CheckedAdd;
6use num_traits::Float;
7use num_traits::ToPrimitive;
8use vortex_buffer::BitBuffer;
9use vortex_dtype::NativePType;
10use vortex_dtype::match_each_native_ptype;
11use vortex_error::VortexExpect;
12use vortex_error::VortexResult;
13use vortex_mask::AllOr;
14use vortex_scalar::Scalar;
15
16use crate::arrays::PrimitiveArray;
17use crate::arrays::PrimitiveVTable;
18use crate::compute::SumKernel;
19use crate::compute::SumKernelAdapter;
20use crate::register_kernel;
21
22impl SumKernel for PrimitiveVTable {
23 fn sum(&self, array: &PrimitiveArray, accumulator: &Scalar) -> VortexResult<Scalar> {
24 let array_sum_scalar = match array.validity_mask().bit_buffer() {
25 AllOr::All => {
26 match_each_native_ptype!(
28 array.ptype(),
29 unsigned: |T| { sum_integer::<_, u64>(array.as_slice::<T>(), accumulator.as_primitive().as_::<u64>().vortex_expect("cannot be null")).into() },
30 signed: |T| { sum_integer::<_, i64>(array.as_slice::<T>(), accumulator.as_primitive().as_::<i64>().vortex_expect("cannot be null")).into() },
31 floating: |T| { Some(sum_float(array.as_slice::<T>(), accumulator.as_primitive().as_::<f64>().vortex_expect("cannot be null"))).into() }
32 )
33 }
34 AllOr::None => {
35 return Ok(accumulator.clone());
37 }
38 AllOr::Some(validity_mask) => {
39 match_each_native_ptype!(
41 array.ptype(),
42 unsigned: |T| {
43 sum_integer_with_validity::<_, u64>(array.as_slice::<T>(), validity_mask, accumulator.as_primitive().as_::<u64>().vortex_expect("cannot be null")).into()
44 },
45 signed: |T| {
46 sum_integer_with_validity::<_, i64>(array.as_slice::<T>(), validity_mask, accumulator.as_primitive().as_::<i64>().vortex_expect("cannot be null")).into()
47 },
48 floating: |T| {
49 Some(sum_float_with_validity(array.as_slice::<T>(), validity_mask, accumulator.as_primitive().as_::<f64>().vortex_expect("cannot be null"))).into()
50 }
51 )
52 }
53 };
54
55 Ok(array_sum_scalar)
56 }
57}
58
// Hands the `SumKernel` implementation for `PrimitiveVTable`, wrapped in
// `SumKernelAdapter` and lifted via `.lift()`, to the `register_kernel!` macro.
// NOTE(review): presumably this is what makes the kernel discoverable by the
// generic sum compute function — confirm against `register_kernel!`.
59register_kernel!(SumKernelAdapter(PrimitiveVTable).lift());
60
61fn sum_integer<T: NativePType + ToPrimitive, R: NativePType + CheckedAdd>(
62 values: &[T],
63 accumulator: R,
64) -> Option<R> {
65 let mut sum = accumulator;
66 for &x in values {
67 sum = sum.checked_add(&R::from(x)?)?;
68 }
69 Some(sum)
70}
71
72fn sum_integer_with_validity<T: NativePType + ToPrimitive, R: NativePType + CheckedAdd>(
73 values: &[T],
74 validity: &BitBuffer,
75 accumulator: R,
76) -> Option<R> {
77 let mut sum: R = accumulator;
78 for (&x, valid) in values.iter().zip_eq(validity.iter()) {
79 if valid {
80 sum = sum.checked_add(&R::from(x)?)?;
81 }
82 }
83 Some(sum)
84}
85
86fn sum_float<T: NativePType + Float>(values: &[T], accumulator: f64) -> f64 {
87 let mut sum = accumulator;
88 for &x in values {
89 sum += x.to_f64().vortex_expect("Failed to cast value to f64");
90 }
91 sum
92}
93
94fn sum_float_with_validity<T: NativePType + Float>(
95 array: &[T],
96 validity: &BitBuffer,
97 accumulator: f64,
98) -> f64 {
99 let mut sum = accumulator;
100 for (&x, valid) in array.iter().zip_eq(validity.iter()) {
101 if valid {
102 sum += x.to_f64().vortex_expect("Failed to cast value to f64");
103 }
104 }
105 sum
106}