vortex_array/arrays/primitive/compute/
sum.rs1use itertools::Itertools;
5use num_traits::CheckedAdd;
6use num_traits::Float;
7use num_traits::ToPrimitive;
8use vortex_buffer::BitBuffer;
9use vortex_dtype::NativePType;
10use vortex_dtype::Nullability;
11use vortex_dtype::match_each_native_ptype;
12use vortex_error::VortexExpect;
13use vortex_error::VortexResult;
14use vortex_mask::AllOr;
15
16use crate::arrays::PrimitiveArray;
17use crate::arrays::PrimitiveVTable;
18use crate::compute::SumKernel;
19use crate::compute::SumKernelAdapter;
20use crate::register_kernel;
21use crate::scalar::Scalar;
22
23impl SumKernel for PrimitiveVTable {
24 fn sum(&self, array: &PrimitiveArray, accumulator: &Scalar) -> VortexResult<Scalar> {
25 let array_sum_scalar = match array.validity_mask()?.bit_buffer() {
26 AllOr::All => {
27 match_each_native_ptype!(
29 array.ptype(),
30 unsigned: |T| {
31 Scalar::from(sum_integer::<_, u64>(
32 array.as_slice::<T>(),
33 accumulator.as_primitive().as_::<u64>().vortex_expect("cannot be null"),
34 ))
35 },
36 signed: |T| {
37 Scalar::from(sum_integer::<_, i64>(
38 array.as_slice::<T>(),
39 accumulator.as_primitive().as_::<i64>().vortex_expect("cannot be null"),
40 ))
41 },
42 floating: |T| {
43 Scalar::primitive(
44 sum_float(
45 array.as_slice::<T>(),
46 accumulator.as_primitive().as_::<f64>().vortex_expect("cannot be null"),
47 ),
48 Nullability::Nullable,
49 )
50 }
51 )
52 }
53 AllOr::None => {
54 return Ok(accumulator.clone());
56 }
57 AllOr::Some(validity_mask) => {
58 match_each_native_ptype!(
60 array.ptype(),
61 unsigned: |T| {
62 Scalar::from(sum_integer_with_validity::<_, u64>(
63 array.as_slice::<T>(),
64 validity_mask,
65 accumulator.as_primitive().as_::<u64>().vortex_expect("cannot be null"),
66 ))
67 },
68 signed: |T| {
69 Scalar::from(sum_integer_with_validity::<_, i64>(
70 array.as_slice::<T>(),
71 validity_mask,
72 accumulator.as_primitive().as_::<i64>().vortex_expect("cannot be null"),
73 ))
74 },
75 floating: |T| {
76 Scalar::primitive(
77 sum_float_with_validity(
78 array.as_slice::<T>(),
79 validity_mask,
80 accumulator.as_primitive().as_::<f64>().vortex_expect("cannot be null"),
81 ),
82 Nullability::Nullable,
83 )
84 }
85 )
86 }
87 };
88
89 Ok(array_sum_scalar)
90 }
91}
92
// Hands `PrimitiveVTable`'s `SumKernel` implementation (above) to `register_kernel!`. The
// implementation is wrapped in `SumKernelAdapter` and lifted with `.lift()`.
// NOTE(review): presumably this makes the kernel discoverable when summing primitive arrays.
// Confirm against the `register_kernel!` definition.
register_kernel!(SumKernelAdapter(PrimitiveVTable).lift());
94
95fn sum_integer<T: NativePType + ToPrimitive, R: NativePType + CheckedAdd>(
96 values: &[T],
97 accumulator: R,
98) -> Option<R> {
99 let mut sum = accumulator;
100 for &x in values {
101 sum = sum.checked_add(&R::from(x)?)?;
102 }
103 Some(sum)
104}
105
106fn sum_integer_with_validity<T: NativePType + ToPrimitive, R: NativePType + CheckedAdd>(
107 values: &[T],
108 validity: &BitBuffer,
109 accumulator: R,
110) -> Option<R> {
111 let mut sum: R = accumulator;
112 for (&x, valid) in values.iter().zip_eq(validity.iter()) {
113 if valid {
114 sum = sum.checked_add(&R::from(x)?)?;
115 }
116 }
117 Some(sum)
118}
119
120fn sum_float<T: NativePType + Float>(values: &[T], accumulator: f64) -> f64 {
121 let mut sum = accumulator;
122 for &x in values {
123 sum += x.to_f64().vortex_expect("Failed to cast value to f64");
124 }
125 sum
126}
127
128fn sum_float_with_validity<T: NativePType + Float>(
129 array: &[T],
130 validity: &BitBuffer,
131 accumulator: f64,
132) -> f64 {
133 let mut sum = accumulator;
134 for (&x, valid) in array.iter().zip_eq(validity.iter()) {
135 if valid {
136 sum += x.to_f64().vortex_expect("Failed to cast value to f64");
137 }
138 }
139 sum
140}