Skip to main content

vortex_array/arrays/primitive/compute/
sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools;
5use num_traits::CheckedAdd;
6use num_traits::Float;
7use num_traits::ToPrimitive;
8use vortex_buffer::BitBuffer;
9use vortex_dtype::NativePType;
10use vortex_dtype::Nullability;
11use vortex_dtype::match_each_native_ptype;
12use vortex_error::VortexExpect;
13use vortex_error::VortexResult;
14use vortex_mask::AllOr;
15
16use crate::arrays::PrimitiveArray;
17use crate::arrays::PrimitiveVTable;
18use crate::compute::SumKernel;
19use crate::compute::SumKernelAdapter;
20use crate::register_kernel;
21use crate::scalar::Scalar;
22
23impl SumKernel for PrimitiveVTable {
24    fn sum(&self, array: &PrimitiveArray, accumulator: &Scalar) -> VortexResult<Scalar> {
25        let array_sum_scalar = match array.validity_mask()?.bit_buffer() {
26            AllOr::All => {
27                // All-valid
28                match_each_native_ptype!(
29                    array.ptype(),
30                    unsigned: |T| {
31                        Scalar::from(sum_integer::<_, u64>(
32                            array.as_slice::<T>(),
33                            accumulator.as_primitive().as_::<u64>().vortex_expect("cannot be null"),
34                        ))
35                    },
36                    signed: |T| {
37                        Scalar::from(sum_integer::<_, i64>(
38                            array.as_slice::<T>(),
39                            accumulator.as_primitive().as_::<i64>().vortex_expect("cannot be null"),
40                        ))
41                    },
42                    floating: |T| {
43                        Scalar::primitive(
44                            sum_float(
45                                array.as_slice::<T>(),
46                                accumulator.as_primitive().as_::<f64>().vortex_expect("cannot be null"),
47                            ),
48                            Nullability::Nullable,
49                        )
50                    }
51                )
52            }
53            AllOr::None => {
54                // All-invalid, return accumulator
55                return Ok(accumulator.clone());
56            }
57            AllOr::Some(validity_mask) => {
58                // Some-valid
59                match_each_native_ptype!(
60                    array.ptype(),
61                    unsigned: |T| {
62                        Scalar::from(sum_integer_with_validity::<_, u64>(
63                            array.as_slice::<T>(),
64                            validity_mask,
65                            accumulator.as_primitive().as_::<u64>().vortex_expect("cannot be null"),
66                        ))
67                    },
68                    signed: |T| {
69                        Scalar::from(sum_integer_with_validity::<_, i64>(
70                            array.as_slice::<T>(),
71                            validity_mask,
72                            accumulator.as_primitive().as_::<i64>().vortex_expect("cannot be null"),
73                        ))
74                    },
75                    floating: |T| {
76                        Scalar::primitive(
77                            sum_float_with_validity(
78                                array.as_slice::<T>(),
79                                validity_mask,
80                                accumulator.as_primitive().as_::<f64>().vortex_expect("cannot be null"),
81                            ),
82                            Nullability::Nullable,
83                        )
84                    }
85                )
86            }
87        };
88
89        Ok(array_sum_scalar)
90    }
91}
92
// Passes the `SumKernel` implementation for `PrimitiveVTable`, wrapped in `SumKernelAdapter`
// and lifted, to `register_kernel!`.
// NOTE(review): presumably this is what makes the kernel discoverable by the generic sum
// compute function — confirm against the `register_kernel!` macro definition.
register_kernel!(SumKernelAdapter(PrimitiveVTable).lift());
94
95fn sum_integer<T: NativePType + ToPrimitive, R: NativePType + CheckedAdd>(
96    values: &[T],
97    accumulator: R,
98) -> Option<R> {
99    let mut sum = accumulator;
100    for &x in values {
101        sum = sum.checked_add(&R::from(x)?)?;
102    }
103    Some(sum)
104}
105
106fn sum_integer_with_validity<T: NativePType + ToPrimitive, R: NativePType + CheckedAdd>(
107    values: &[T],
108    validity: &BitBuffer,
109    accumulator: R,
110) -> Option<R> {
111    let mut sum: R = accumulator;
112    for (&x, valid) in values.iter().zip_eq(validity.iter()) {
113        if valid {
114            sum = sum.checked_add(&R::from(x)?)?;
115        }
116    }
117    Some(sum)
118}
119
120fn sum_float<T: NativePType + Float>(values: &[T], accumulator: f64) -> f64 {
121    let mut sum = accumulator;
122    for &x in values {
123        sum += x.to_f64().vortex_expect("Failed to cast value to f64");
124    }
125    sum
126}
127
128fn sum_float_with_validity<T: NativePType + Float>(
129    array: &[T],
130    validity: &BitBuffer,
131    accumulator: f64,
132) -> f64 {
133    let mut sum = accumulator;
134    for (&x, valid) in array.iter().zip_eq(validity.iter()) {
135        if valid {
136            sum += x.to_f64().vortex_expect("Failed to cast value to f64");
137        }
138    }
139    sum
140}