vortex_array/arrays/chunked/compute/
sum.rs1use num_traits::PrimInt;
5use vortex_dtype::{NativePType, PType, match_each_native_ptype};
6use vortex_error::{VortexExpect, VortexResult, vortex_err};
7use vortex_scalar::{FromPrimitiveOrF16, Scalar};
8
9use crate::arrays::{ChunkedArray, ChunkedVTable};
10use crate::compute::{SumKernel, SumKernelAdapter, sum};
11use crate::stats::Stat;
12use crate::{ArrayRef, register_kernel};
13
14impl SumKernel for ChunkedVTable {
15 fn sum(&self, array: &ChunkedArray) -> VortexResult<Scalar> {
16 let sum_dtype = Stat::Sum
17 .dtype(array.dtype())
18 .ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
19 let sum_ptype = PType::try_from(&sum_dtype).vortex_expect("sum dtype must be primitive");
20
21 let scalar_value = match_each_native_ptype!(
22 sum_ptype,
23 unsigned: |T| { sum_int::<u64>(array.chunks())?.into() },
24 signed: |T| { sum_int::<i64>(array.chunks())?.into() },
25 floating: |T| { sum_float(array.chunks())?.into() }
26 );
27
28 Ok(Scalar::new(sum_dtype, scalar_value))
29 }
30}
31
32register_kernel!(SumKernelAdapter(ChunkedVTable).lift());
33
34fn sum_int<T: NativePType + PrimInt + FromPrimitiveOrF16>(
35 chunks: &[ArrayRef],
36) -> VortexResult<Option<T>> {
37 let mut result = T::zero();
38 for chunk in chunks {
39 let chunk_sum = sum(chunk)?;
40
41 let Some(chunk_sum) = chunk_sum.as_primitive().as_::<T>() else {
42 return Ok(None);
44 };
45
46 let Some(chunk_result) = result.checked_add(&chunk_sum) else {
47 return Ok(None);
49 };
50
51 result = chunk_result;
52 }
53 Ok(Some(result))
54}
55
56fn sum_float(chunks: &[ArrayRef]) -> VortexResult<f64> {
57 let mut result = 0f64;
58 for chunk in chunks {
59 if let Some(chunk_sum) = sum(chunk)?.as_primitive().as_::<f64>() {
60 result += chunk_sum;
61 };
62 }
63 Ok(result)
64}
65
#[cfg(test)]
mod tests {
    use vortex_dtype::Nullability;
    use vortex_scalar::Scalar;

    use crate::array::IntoArray;
    use crate::arrays::{ChunkedArray, ConstantArray, PrimitiveArray};
    use crate::compute::sum;

    /// Nulls inside chunks are ignored; valid values across all chunks are added up.
    #[test]
    fn test_sum_chunked_floats_with_nulls() {
        let first =
            PrimitiveArray::from_option_iter(vec![Some(1.5f64), None, Some(3.2), Some(4.8)]);
        let second = PrimitiveArray::from_option_iter(vec![Some(2.1f64), Some(5.7), None]);
        let third = PrimitiveArray::from_option_iter(vec![None, Some(1.0f64), Some(2.5), None]);

        let dtype = first.dtype().clone();
        let parts = vec![
            first.into_array(),
            second.into_array(),
            third.into_array(),
        ];
        let chunked = ChunkedArray::try_new(parts, dtype).unwrap();

        let total = sum(chunked.as_ref()).unwrap();
        assert_eq!(total.as_primitive().as_::<f64>(), Some(20.8));
    }

    /// An array made only of nulls sums to a null scalar.
    #[test]
    fn test_sum_chunked_floats_all_nulls() {
        let first = PrimitiveArray::from_option_iter::<f32, _>(vec![None, None, None]);
        let second = PrimitiveArray::from_option_iter::<f32, _>(vec![None, None]);

        let dtype = first.dtype().clone();
        let parts = vec![first.into_array(), second.into_array()];
        let chunked = ChunkedArray::try_new(parts, dtype).unwrap();

        let total = sum(chunked.as_ref()).unwrap();
        assert!(total.as_primitive().as_::<f64>().is_none());
    }

    /// A zero-length chunk in the middle does not affect the sum.
    #[test]
    fn test_sum_chunked_floats_empty_chunks() {
        let first = PrimitiveArray::from_option_iter(vec![Some(10.5f64), Some(20.3)]);
        let empty = ConstantArray::new(Scalar::primitive(0f64, Nullability::Nullable), 0);
        let last = PrimitiveArray::from_option_iter(vec![Some(5.2f64)]);

        let dtype = first.dtype().clone();
        let parts = vec![
            first.into_array(),
            empty.into_array(),
            last.into_array(),
        ];
        let chunked = ChunkedArray::try_new(parts, dtype).unwrap();

        let total = sum(chunked.as_ref()).unwrap();
        assert_eq!(total.as_primitive().as_::<f64>(), Some(36.0));
    }
}