vortex_array/arrays/chunked/compute/sum.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

4use vortex_error::VortexResult;
5use vortex_scalar::Scalar;
6
7use crate::arrays::ChunkedArray;
8use crate::arrays::ChunkedVTable;
9use crate::compute::SumKernel;
10use crate::compute::SumKernelAdapter;
11use crate::compute::sum_with_accumulator;
12use crate::register_kernel;
13
14impl SumKernel for ChunkedVTable {
15    fn sum(&self, array: &ChunkedArray, accumulator: &Scalar) -> VortexResult<Scalar> {
16        array
17            .chunks
18            .iter()
19            .try_fold(accumulator.clone(), |result, chunk| {
20                sum_with_accumulator(chunk, &result)
21            })
22    }
23}
24
25register_kernel!(SumKernelAdapter(ChunkedVTable).lift());
26
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::Nullability;
    use vortex_scalar::DecimalValue;
    use vortex_scalar::Scalar;
    use vortex_scalar::i256;

    use crate::array::IntoArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::DecimalArray;
    use crate::arrays::PrimitiveArray;
    use crate::compute::sum;
    use crate::validity::Validity;

    #[test]
    fn test_sum_chunked_floats_with_nulls() {
        // Create chunks with floats including nulls
        let chunk1 =
            PrimitiveArray::from_option_iter(vec![Some(1.5f64), None, Some(3.2), Some(4.8)]);

        let chunk2 = PrimitiveArray::from_option_iter(vec![Some(2.1f64), Some(5.7), None]);

        let chunk3 = PrimitiveArray::from_option_iter(vec![None, Some(1.0f64), Some(2.5), None]);

        // Create chunked array from the chunks
        let dtype = chunk1.dtype().clone();
        let chunked = ChunkedArray::try_new(
            vec![
                chunk1.into_array(),
                chunk2.into_array(),
                chunk3.into_array(),
            ],
            dtype,
        )
        .unwrap();

        // Compute sum
        let result = sum(chunked.as_ref()).unwrap();

        // Expected sum: 1.5 + 3.2 + 4.8 + 2.1 + 5.7 + 1.0 + 2.5 = 20.8
        assert_eq!(result.as_primitive().as_::<f64>(), Some(20.8));
    }

    #[test]
    fn test_sum_chunked_floats_all_nulls_is_zero() {
        // Create chunks with all nulls
        let chunk1 = PrimitiveArray::from_option_iter::<f32, _>(vec![None, None, None]);
        let chunk2 = PrimitiveArray::from_option_iter::<f32, _>(vec![None, None]);

        let dtype = chunk1.dtype().clone();
        let chunked =
            ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype).unwrap();
        // Nulls are skipped, so an all-null input sums to zero rather than null.
        // The f32 inputs are summed into a nullable f64 result.
        let result = sum(chunked.as_ref()).unwrap();
        assert_eq!(result, Scalar::primitive(0f64, Nullability::Nullable));
    }

    #[test]
    fn test_sum_chunked_floats_empty_chunks() {
        // Test with some empty chunks mixed with non-empty
        let chunk1 = PrimitiveArray::from_option_iter(vec![Some(10.5f64), Some(20.3)]);
        let chunk2 = ConstantArray::new(Scalar::primitive(0f64, Nullability::Nullable), 0);
        let chunk3 = PrimitiveArray::from_option_iter(vec![Some(5.2f64)]);

        let dtype = chunk1.dtype().clone();
        let chunked = ChunkedArray::try_new(
            vec![
                chunk1.into_array(),
                chunk2.into_array(),
                chunk3.into_array(),
            ],
            dtype,
        )
        .unwrap();

        // Compute sum: 10.5 + 20.3 + 5.2 = 36.0
        let result = sum(chunked.as_ref()).unwrap();
        assert_eq!(result.as_primitive().as_::<f64>(), Some(36.0));
    }

    #[test]
    fn test_sum_chunked_int_almost_all_null_chunks() {
        let chunk1 = PrimitiveArray::from_option_iter::<u32, _>(vec![Some(1)]);
        let chunk2 = PrimitiveArray::from_option_iter::<u32, _>(vec![None]);

        let dtype = chunk1.dtype().clone();
        let chunked =
            ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype).unwrap();

        // The u32 inputs are read back as a u64 sum; the all-null chunk contributes nothing.
        let result = sum(chunked.as_ref()).unwrap();
        assert_eq!(result.as_primitive().as_::<u64>(), Some(1));
    }

    #[test]
    fn test_sum_chunked_decimals() {
        // Create decimal chunks with precision=10, scale=2
        let decimal_dtype = DecimalDType::new(10, 2);
        let chunk1 = DecimalArray::new(
            buffer![100i32, 100i32, 100i32, 100i32, 100i32],
            decimal_dtype,
            Validity::AllValid,
        );
        let chunk2 = DecimalArray::new(
            buffer![200i32, 200i32, 200i32],
            decimal_dtype,
            Validity::AllValid,
        );
        let chunk3 = DecimalArray::new(buffer![300i32, 300i32], decimal_dtype, Validity::AllValid);

        let dtype = chunk1.dtype().clone();
        let chunked = ChunkedArray::try_new(
            vec![
                chunk1.into_array(),
                chunk2.into_array(),
                chunk3.into_array(),
            ],
            dtype,
        )
        .unwrap();

        // Compute sum: 5*100 + 3*200 + 2*300 = 500 + 600 + 600 = 1700 (represents 17.00)
        let result = sum(chunked.as_ref()).unwrap();
        let decimal_result = result.as_decimal();
        assert_eq!(
            decimal_result.decimal_value(),
            Some(DecimalValue::I256(i256::from_i128(1700)))
        );
    }

    #[test]
    fn test_sum_chunked_decimals_with_nulls() {
        let decimal_dtype = DecimalDType::new(10, 2);

        // chunk2 is entirely null. Every chunk must share the chunked array's dtype,
        // including its nullability.
        let chunk1 = DecimalArray::new(
            buffer![100i32, 100i32, 100i32],
            decimal_dtype,
            Validity::AllValid,
        );
        let chunk2 = DecimalArray::new(
            buffer![0i32, 0i32],
            decimal_dtype,
            Validity::from_iter([false, false]),
        );
        let chunk3 = DecimalArray::new(buffer![200i32, 200i32], decimal_dtype, Validity::AllValid);

        let dtype = chunk1.dtype().clone();
        let chunked = ChunkedArray::try_new(
            vec![
                chunk1.into_array(),
                chunk2.into_array(),
                chunk3.into_array(),
            ],
            dtype,
        )
        .unwrap();

        // Compute sum: 3*100 + 2*200 = 300 + 400 = 700 (nulls ignored)
        let result = sum(chunked.as_ref()).unwrap();
        let decimal_result = result.as_decimal();
        assert_eq!(
            decimal_result.decimal_value(),
            Some(DecimalValue::I256(i256::from_i128(700)))
        );
    }

    #[test]
    fn test_sum_chunked_decimals_large() {
        // Create decimals with precision 3 (max value 999)
        // Sum will be 500 + 600 = 1100, which fits in result precision 13 (3+10)
        let decimal_dtype = DecimalDType::new(3, 0);
        let chunk1 = ConstantArray::new(
            Scalar::decimal(
                DecimalValue::I16(500),
                decimal_dtype,
                Nullability::NonNullable,
            ),
            1,
        );
        let chunk2 = ConstantArray::new(
            Scalar::decimal(
                DecimalValue::I16(600),
                decimal_dtype,
                Nullability::NonNullable,
            ),
            1,
        );

        let dtype = chunk1.dtype().clone();
        let chunked =
            ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype).unwrap();

        // Compute sum: 500 + 600 = 1100
        // Result should have precision 13 (3+10), scale 0, and is nullable even though
        // the inputs are non-nullable.
        let result = sum(chunked.as_ref()).unwrap();
        let decimal_result = result.as_decimal();
        assert_eq!(
            decimal_result.decimal_value(),
            Some(DecimalValue::I256(i256::from_i128(1100)))
        );
        assert_eq!(
            result.dtype(),
            &DType::Decimal(DecimalDType::new(13, 0), Nullability::Nullable)
        );
    }
}