vortex_array/arrays/chunked/compute/
sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::PrimInt;
5use vortex_dtype::{NativePType, PType, match_each_native_ptype};
6use vortex_error::{VortexExpect, VortexResult, vortex_err};
7use vortex_scalar::{FromPrimitiveOrF16, Scalar};
8
9use crate::arrays::{ChunkedArray, ChunkedVTable};
10use crate::compute::{SumKernel, SumKernelAdapter, sum};
11use crate::stats::Stat;
12use crate::{ArrayRef, register_kernel};
13
14impl SumKernel for ChunkedVTable {
15    fn sum(&self, array: &ChunkedArray) -> VortexResult<Scalar> {
16        let sum_dtype = Stat::Sum
17            .dtype(array.dtype())
18            .ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
19        let sum_ptype = PType::try_from(&sum_dtype).vortex_expect("sum dtype must be primitive");
20
21        let scalar_value = match_each_native_ptype!(
22            sum_ptype,
23            unsigned: |T| { sum_int::<u64>(array.chunks())?.into() },
24            signed: |T| { sum_int::<i64>(array.chunks())?.into() },
25            floating: |T| { sum_float(array.chunks())?.into() }
26        );
27
28        Ok(Scalar::new(sum_dtype, scalar_value))
29    }
30}
31
// Registers the `SumKernel` implementation for `ChunkedVTable`, wrapped in `SumKernelAdapter`.
32register_kernel!(SumKernelAdapter(ChunkedVTable).lift());
33
34fn sum_int<T: NativePType + PrimInt + FromPrimitiveOrF16>(
35    chunks: &[ArrayRef],
36) -> VortexResult<Option<T>> {
37    let mut result = T::zero();
38    for chunk in chunks {
39        let chunk_sum = sum(chunk)?;
40
41        let Some(chunk_sum) = chunk_sum.as_primitive().as_::<T>() else {
42            // Bail out on overflow
43            return Ok(None);
44        };
45
46        let Some(chunk_result) = result.checked_add(&chunk_sum) else {
47            // Bail out on overflow
48            return Ok(None);
49        };
50
51        result = chunk_result;
52    }
53    Ok(Some(result))
54}
55
56fn sum_float(chunks: &[ArrayRef]) -> VortexResult<f64> {
57    let mut result = 0f64;
58    for chunk in chunks {
59        if let Some(chunk_sum) = sum(chunk)?.as_primitive().as_::<f64>() {
60            result += chunk_sum;
61        };
62    }
63    Ok(result)
64}
65
#[cfg(test)]
mod tests {
    use vortex_dtype::Nullability;
    use vortex_scalar::Scalar;

    use crate::array::IntoArray;
    use crate::arrays::{ChunkedArray, ConstantArray, PrimitiveArray};
    use crate::compute::sum;

    /// Null entries scattered across several float chunks are left out of the total.
    #[test]
    fn test_sum_chunked_floats_with_nulls() {
        let first =
            PrimitiveArray::from_option_iter(vec![Some(1.5f64), None, Some(3.2), Some(4.8)]);
        let second = PrimitiveArray::from_option_iter(vec![Some(2.1f64), Some(5.7), None]);
        let third = PrimitiveArray::from_option_iter(vec![None, Some(1.0f64), Some(2.5), None]);

        // All chunks share the dtype of the first one.
        let dtype = first.dtype().clone();
        let chunks = vec![
            first.into_array(),
            second.into_array(),
            third.into_array(),
        ];
        let chunked = ChunkedArray::try_new(chunks, dtype).unwrap();

        let total = sum(chunked.as_ref()).unwrap();

        // 1.5 + 3.2 + 4.8 + 2.1 + 5.7 + 1.0 + 2.5 = 20.8
        assert_eq!(total.as_primitive().as_::<f64>(), Some(20.8));
    }

    /// A chunked array made up solely of nulls sums to a null scalar.
    #[test]
    fn test_sum_chunked_floats_all_nulls() {
        let first = PrimitiveArray::from_option_iter::<f32, _>(vec![None, None, None]);
        let second = PrimitiveArray::from_option_iter::<f32, _>(vec![None, None]);

        let dtype = first.dtype().clone();
        let chunks = vec![first.into_array(), second.into_array()];
        let chunked = ChunkedArray::try_new(chunks, dtype).unwrap();

        let total = sum(chunked.as_ref()).unwrap();

        // No valid values anywhere, so no value is expected back.
        assert_eq!(total.as_primitive().as_::<f64>(), None);
    }

    /// A zero-length chunk sitting between non-empty chunks does not affect the total.
    #[test]
    fn test_sum_chunked_floats_empty_chunks() {
        let first = PrimitiveArray::from_option_iter(vec![Some(10.5f64), Some(20.3)]);
        let empty = ConstantArray::new(Scalar::primitive(0f64, Nullability::Nullable), 0);
        let last = PrimitiveArray::from_option_iter(vec![Some(5.2f64)]);

        let dtype = first.dtype().clone();
        let chunks = vec![first.into_array(), empty.into_array(), last.into_array()];
        let chunked = ChunkedArray::try_new(chunks, dtype).unwrap();

        let total = sum(chunked.as_ref()).unwrap();

        // 10.5 + 20.3 + 5.2 = 36.0
        assert_eq!(total.as_primitive().as_::<f64>(), Some(36.0));
    }
}
141}