vortex_array/compute/
sum.rs1use vortex_dtype::PType;
2use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err, vortex_panic};
3use vortex_scalar::Scalar;
4
5use crate::Array;
6use crate::encoding::Encoding;
7use crate::stats::{Precision, Stat, StatsProvider};
8
9pub trait SumFn<A> {
10 fn sum(&self, array: A) -> VortexResult<Scalar>;
15}
16
17impl<E: Encoding> SumFn<&dyn Array> for E
18where
19 E: for<'a> SumFn<&'a E::Array>,
20{
21 fn sum(&self, array: &dyn Array) -> VortexResult<Scalar> {
22 let array_ref = array
23 .as_any()
24 .downcast_ref::<E::Array>()
25 .vortex_expect("Failed to downcast array");
26 SumFn::sum(self, array_ref)
27 }
28}
29
30pub fn sum(array: &dyn Array) -> VortexResult<Scalar> {
36 let sum_dtype = Stat::Sum
38 .dtype(array.dtype())
39 .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))?;
40
41 if let Some(Precision::Exact(sum)) = array.statistics().get(Stat::Sum) {
43 return Ok(Scalar::new(sum_dtype, sum));
44 }
45
46 if array.is_empty() {
47 return if sum_dtype.is_float() {
48 Ok(Scalar::new(sum_dtype, 0.0.into()))
49 } else {
50 Ok(Scalar::new(sum_dtype, 0.into()))
51 };
52 }
53
54 if let Some(mut constant) = array.as_constant() {
56 if constant.is_null() {
57 return if sum_dtype.is_float() {
59 Ok(Scalar::new(sum_dtype, 0.0.into()))
60 } else {
61 Ok(Scalar::new(sum_dtype, 0.into()))
62 };
63 }
64
65 if let Some(extension) = constant.as_extension_opt() {
67 constant = extension.storage();
68 }
69
70 if let Some(bool) = constant.as_bool_opt() {
72 return if bool.value().vortex_expect("already checked for null value") {
73 Ok(Scalar::new(sum_dtype, array.len().into()))
75 } else {
76 Ok(Scalar::new(sum_dtype, 0.into()))
78 };
79 }
80
81 if let Some(primitive) = constant.as_primitive_opt() {
83 match primitive.ptype() {
84 PType::U8 | PType::U16 | PType::U32 | PType::U64 => {
85 let value = primitive
86 .pvalue()
87 .vortex_expect("already checked for null value")
88 .as_u64()
89 .vortex_expect("Failed to cast constant value to u64");
90
91 let sum = value.checked_mul(array.len() as u64);
93
94 return Ok(Scalar::new(sum_dtype, sum.into()));
95 }
96 PType::I8 | PType::I16 | PType::I32 | PType::I64 => {
97 let value = primitive
98 .pvalue()
99 .vortex_expect("already checked for null value")
100 .as_i64()
101 .vortex_expect("Failed to cast constant value to i64");
102
103 let sum = value.checked_mul(array.len() as i64);
105
106 return Ok(Scalar::new(sum_dtype, sum.into()));
107 }
108 PType::F16 | PType::F32 | PType::F64 => {
109 let value = primitive
110 .pvalue()
111 .vortex_expect("already checked for null value")
112 .as_f64()
113 .vortex_expect("Failed to cast constant value to f64");
114
115 let sum = value * (array.len() as f64);
116
117 return Ok(Scalar::new(sum_dtype, sum.into()));
118 }
119 }
120 }
121
122 unreachable!("Unsupported sum constant: {}", constant.dtype());
124 }
125
126 let sum = if let Some(f) = array.vtable().sum_fn() {
128 f.sum(array)?
129 } else {
130 log::debug!("No sum implementation found for {}", array.encoding());
132
133 let array = array.to_canonical()?;
134 if let Some(f) = array.as_ref().vtable().sum_fn() {
135 f.sum(array.as_ref())?
136 } else {
137 vortex_bail!(
138 "No sum function for canonical array: {}",
139 array.as_ref().encoding(),
140 )
141 }
142 };
143
144 if sum.dtype() != &sum_dtype {
145 vortex_panic!(
146 "Sum function of {} returned scalar with wrong dtype: {:?}",
147 array.encoding(),
148 sum.dtype()
149 );
150 }
151
152 array
154 .statistics()
155 .set(Stat::Sum, Precision::Exact(sum.value().clone()));
156
157 Ok(sum)
158}
159
#[cfg(test)]
mod test {
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::sum;

    #[test]
    fn sum_all_invalid() {
        let array = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]);
        let result = sum(&array).unwrap();
        assert_eq!(result.as_primitive().as_::<i32>().unwrap(), Some(0));
    }

    #[test]
    fn sum_all_invalid_float() {
        let array = PrimitiveArray::from_option_iter::<f32, _>([None, None, None]);
        let result = sum(&array).unwrap();
        assert_eq!(result.as_primitive().as_::<f32>().unwrap(), Some(0.0));
    }

    #[test]
    fn sum_empty() {
        let array = PrimitiveArray::from_iter(Vec::<i32>::new());
        let result = sum(&array).unwrap();
        assert_eq!(result.as_primitive().as_::<i32>().unwrap(), Some(0));
    }

    #[test]
    fn sum_constant() {
        let array = PrimitiveArray::from_iter([1, 1, 1, 1]);
        let result = sum(&array).unwrap();
        assert_eq!(result.as_primitive().as_::<i32>().unwrap(), Some(4));
    }

    #[test]
    fn sum_constant_float() {
        let array = PrimitiveArray::from_iter([1., 1., 1., 1.]);
        let result = sum(&array).unwrap();
        assert_eq!(result.as_primitive().as_::<f32>().unwrap(), Some(4.));
    }

    #[test]
    fn sum_constant_overflow_is_null() {
        // i64::MAX * 2 overflows the i64 sum type, which is reported as a null sum.
        let array = PrimitiveArray::from_iter([i64::MAX, i64::MAX]);
        let result = sum(&array).unwrap();
        assert!(result.is_null());
    }

    #[test]
    fn sum_boolean() {
        let array = BoolArray::from_iter([true, false, false, true]);
        let result = sum(&array).unwrap();
        assert_eq!(result.as_primitive().as_::<i32>().unwrap(), Some(2));
    }

    #[test]
    fn sum_constant_boolean_all_false() {
        let array = BoolArray::from_iter([false, false, false]);
        let result = sum(&array).unwrap();
        assert_eq!(result.as_primitive().as_::<i32>().unwrap(), Some(0));
    }
}