vortex_array/arrays/chunked/compute/
sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_scalar::Scalar;
6
7use crate::arrays::{ChunkedArray, ChunkedVTable};
8use crate::compute::{SumKernel, SumKernelAdapter, sum_with_accumulator};
9use crate::register_kernel;
10
11impl SumKernel for ChunkedVTable {
12    fn sum(&self, array: &ChunkedArray, accumulator: &Scalar) -> VortexResult<Scalar> {
13        array
14            .chunks
15            .iter()
16            .try_fold(accumulator.clone(), |result, chunk| {
17                sum_with_accumulator(chunk, &result)
18            })
19    }
20}
21
// Registers the chunked sum kernel, lifted through `SumKernelAdapter`, with the kernel
// registry. Presumably this is what lets `sum` dispatch to the implementation above for
// `ChunkedVTable` arrays — confirm against the `register_kernel!` definition.
register_kernel!(SumKernelAdapter(ChunkedVTable).lift());
23
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, DecimalDType, Nullability};
    use vortex_scalar::{DecimalValue, Scalar, i256};

    use crate::array::IntoArray;
    use crate::arrays::{ChunkedArray, ConstantArray, DecimalArray, PrimitiveArray};
    use crate::compute::sum;
    use crate::validity::Validity;

    /// Nulls scattered across several float chunks are skipped, and the non-null values of
    /// every chunk are accumulated into one total.
    #[test]
    fn test_sum_chunked_floats_with_nulls() {
        // Create chunks with floats including nulls
        let chunk1 =
            PrimitiveArray::from_option_iter(vec![Some(1.5f64), None, Some(3.2), Some(4.8)]);

        let chunk2 = PrimitiveArray::from_option_iter(vec![Some(2.1f64), Some(5.7), None]);

        let chunk3 = PrimitiveArray::from_option_iter(vec![None, Some(1.0f64), Some(2.5), None]);

        // Create chunked array from the chunks
        let dtype = chunk1.dtype().clone();
        let chunked = ChunkedArray::try_new(
            vec![
                chunk1.into_array(),
                chunk2.into_array(),
                chunk3.into_array(),
            ],
            dtype,
        )
        .unwrap();

        // Compute sum
        let result = sum(chunked.as_ref()).unwrap();

        // Expected sum: 1.5 + 3.2 + 4.8 + 2.1 + 5.7 + 1.0 + 2.5 = 20.8
        assert_eq!(result.as_primitive().as_::<f64>(), Some(20.8));
    }

    /// A chunked array whose every value is null sums to zero, not to a null scalar.
    #[test]
    fn test_sum_chunked_floats_all_nulls_is_zero() {
        // Create chunks with all nulls
        let chunk1 = PrimitiveArray::from_option_iter::<f32, _>(vec![None, None, None]);
        let chunk2 = PrimitiveArray::from_option_iter::<f32, _>(vec![None, None]);

        let dtype = chunk1.dtype().clone();
        let chunked =
            ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype).unwrap();
        // Every value is null, so nothing contributes to the total. The sum is a (nullable)
        // zero rather than null, and the f32 input is summed into an f64 result.
        let result = sum(chunked.as_ref()).unwrap();
        assert_eq!(result, Scalar::primitive(0f64, Nullability::Nullable));
    }

    /// Zero-length chunks mixed in with non-empty ones do not affect the total.
    #[test]
    fn test_sum_chunked_floats_empty_chunks() {
        // Test with some empty chunks mixed with non-empty
        let chunk1 = PrimitiveArray::from_option_iter(vec![Some(10.5f64), Some(20.3)]);
        // Zero-length constant chunk with the same (nullable f64) dtype as its neighbours.
        let chunk2 = ConstantArray::new(Scalar::primitive(0f64, Nullability::Nullable), 0);
        let chunk3 = PrimitiveArray::from_option_iter(vec![Some(5.2f64)]);

        let dtype = chunk1.dtype().clone();
        let chunked = ChunkedArray::try_new(
            vec![
                chunk1.into_array(),
                chunk2.into_array(),
                chunk3.into_array(),
            ],
            dtype,
        )
        .unwrap();

        // Compute sum: 10.5 + 20.3 + 5.2 = 36.0
        let result = sum(chunked.as_ref()).unwrap();
        assert_eq!(result.as_primitive().as_::<f64>(), Some(36.0));
    }

    /// A trailing all-null chunk leaves the running total from earlier chunks intact. The
    /// u32 input is read back as a u64 sum.
    #[test]
    fn test_sum_chunked_int_almost_all_null_chunks() {
        let chunk1 = PrimitiveArray::from_option_iter::<u32, _>(vec![Some(1)]);
        let chunk2 = PrimitiveArray::from_option_iter::<u32, _>(vec![None]);

        let dtype = chunk1.dtype().clone();
        let chunked =
            ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype).unwrap();

        let result = sum(chunked.as_ref()).unwrap();
        assert_eq!(result.as_primitive().as_::<u64>(), Some(1));
    }

    /// Decimal chunks are summed on their unscaled integer values, and the result is
    /// reported as an i256 decimal value.
    #[test]
    fn test_sum_chunked_decimals() {
        // Create decimal chunks with precision=10, scale=2
        let decimal_dtype = DecimalDType::new(10, 2);
        let chunk1 = DecimalArray::new(
            buffer![100i32, 100i32, 100i32, 100i32, 100i32],
            decimal_dtype,
            Validity::AllValid,
        );
        let chunk2 = DecimalArray::new(
            buffer![200i32, 200i32, 200i32],
            decimal_dtype,
            Validity::AllValid,
        );
        let chunk3 = DecimalArray::new(buffer![300i32, 300i32], decimal_dtype, Validity::AllValid);

        let dtype = chunk1.dtype().clone();
        let chunked = ChunkedArray::try_new(
            vec![
                chunk1.into_array(),
                chunk2.into_array(),
                chunk3.into_array(),
            ],
            dtype,
        )
        .unwrap();

        // Compute sum: 5*100 + 3*200 + 2*300 = 500 + 600 + 600 = 1700 (represents 17.00)
        let result = sum(chunked.as_ref()).unwrap();
        let decimal_result = result.as_decimal();
        assert_eq!(
            decimal_result.decimal_value(),
            Some(DecimalValue::I256(i256::from_i128(1700)))
        );
    }

    /// A decimal chunk whose entries are all null contributes nothing to the total.
    #[test]
    fn test_sum_chunked_decimals_with_nulls() {
        let decimal_dtype = DecimalDType::new(10, 2);

        // Create chunks with some nulls - all must have same nullability
        let chunk1 = DecimalArray::new(
            buffer![100i32, 100i32, 100i32],
            decimal_dtype,
            Validity::AllValid,
        );
        // Both entries are marked invalid, so their stored zeros must be ignored.
        let chunk2 = DecimalArray::new(
            buffer![0i32, 0i32],
            decimal_dtype,
            Validity::from_iter([false, false]),
        );
        let chunk3 = DecimalArray::new(buffer![200i32, 200i32], decimal_dtype, Validity::AllValid);

        let dtype = chunk1.dtype().clone();
        let chunked = ChunkedArray::try_new(
            vec![
                chunk1.into_array(),
                chunk2.into_array(),
                chunk3.into_array(),
            ],
            dtype,
        )
        .unwrap();

        // Compute sum: 3*100 + 2*200 = 300 + 400 = 700 (nulls ignored)
        let result = sum(chunked.as_ref()).unwrap();
        let decimal_result = result.as_decimal();
        assert_eq!(
            decimal_result.decimal_value(),
            Some(DecimalValue::I256(i256::from_i128(700)))
        );
    }

    /// A sum that exceeds the input precision is returned in a widened, nullable decimal
    /// dtype (input precision + 10, same scale).
    #[test]
    fn test_sum_chunked_decimals_large() {
        // Create decimals with precision 3 (max value 999)
        // Sum will be 500 + 600 = 1100, which fits in result precision 13 (3+10)
        let decimal_dtype = DecimalDType::new(3, 0);
        let chunk1 = ConstantArray::new(
            Scalar::decimal(
                DecimalValue::I16(500),
                decimal_dtype,
                Nullability::NonNullable,
            ),
            1,
        );
        let chunk2 = ConstantArray::new(
            Scalar::decimal(
                DecimalValue::I16(600),
                decimal_dtype,
                Nullability::NonNullable,
            ),
            1,
        );

        let dtype = chunk1.dtype().clone();
        let chunked =
            ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype).unwrap();

        // Compute sum: 500 + 600 = 1100
        // Result should have precision 13 (3+10), scale 0
        let result = sum(chunked.as_ref()).unwrap();
        let decimal_result = result.as_decimal();
        assert_eq!(
            decimal_result.decimal_value(),
            Some(DecimalValue::I256(i256::from_i128(1100)))
        );
        assert_eq!(
            result.dtype(),
            &DType::Decimal(DecimalDType::new(13, 0), Nullability::Nullable)
        );
    }
}