vortex_array/compute/
sum.rs1use vortex_dtype::PType;
2use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err, vortex_panic};
3use vortex_scalar::Scalar;
4
5use crate::Array;
6use crate::encoding::Encoding;
7use crate::stats::{Precision, Stat, StatsProvider};
8
9pub trait SumFn<A> {
10 fn sum(&self, array: A) -> VortexResult<Scalar>;
15}
16
17impl<E: Encoding> SumFn<&dyn Array> for E
18where
19 E: for<'a> SumFn<&'a E::Array>,
20{
21 fn sum(&self, array: &dyn Array) -> VortexResult<Scalar> {
22 let array_ref = array
23 .as_any()
24 .downcast_ref::<E::Array>()
25 .vortex_expect("Failed to downcast array");
26 SumFn::sum(self, array_ref)
27 }
28}
29
30pub fn sum(array: &dyn Array) -> VortexResult<Scalar> {
36 let sum_dtype = Stat::Sum
38 .dtype(array.dtype())
39 .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))?;
40
41 if let Some(Precision::Exact(sum)) = array.statistics().get(Stat::Sum) {
43 return Ok(Scalar::new(sum_dtype, sum));
44 }
45
46 if let Some(mut constant) = array.as_constant() {
48 if constant.is_null() {
49 return if PType::try_from(&sum_dtype)
51 .vortex_expect("must be primitive")
52 .is_float()
53 {
54 Ok(Scalar::new(sum_dtype, 0.0.into()))
55 } else {
56 Ok(Scalar::new(sum_dtype, 0.into()))
57 };
58 }
59
60 if let Some(extension) = constant.as_extension_opt() {
62 constant = extension.storage();
63 }
64
65 if let Some(bool) = constant.as_bool_opt() {
67 return if bool.value().vortex_expect("already checked for null value") {
68 Ok(Scalar::new(sum_dtype, array.len().into()))
70 } else {
71 Ok(Scalar::new(sum_dtype, 0.into()))
73 };
74 }
75
76 if let Some(primitive) = constant.as_primitive_opt() {
78 match primitive.ptype() {
79 PType::U8 | PType::U16 | PType::U32 | PType::U64 => {
80 let value = primitive
81 .pvalue()
82 .vortex_expect("already checked for null value")
83 .as_u64()
84 .vortex_expect("Failed to cast constant value to u64");
85
86 let sum = value.checked_mul(array.len() as u64);
88
89 return Ok(Scalar::new(sum_dtype, sum.into()));
90 }
91 PType::I8 | PType::I16 | PType::I32 | PType::I64 => {
92 let value = primitive
93 .pvalue()
94 .vortex_expect("already checked for null value")
95 .as_i64()
96 .vortex_expect("Failed to cast constant value to i64");
97
98 let sum = value.checked_mul(array.len() as i64);
100
101 return Ok(Scalar::new(sum_dtype, sum.into()));
102 }
103 PType::F16 | PType::F32 | PType::F64 => {
104 let value = primitive
105 .pvalue()
106 .vortex_expect("already checked for null value")
107 .as_f64()
108 .vortex_expect("Failed to cast constant value to f64");
109
110 let sum = value * (array.len() as f64);
111
112 return Ok(Scalar::new(sum_dtype, sum.into()));
113 }
114 }
115 }
116
117 unreachable!("Unsupported sum constant: {}", constant.dtype());
119 }
120
121 let sum = if let Some(f) = array.vtable().sum_fn() {
123 f.sum(array)?
124 } else {
125 log::debug!("No sum implementation found for {}", array.encoding());
127
128 let array = array.to_canonical()?;
129 if let Some(f) = array.as_ref().vtable().sum_fn() {
130 f.sum(array.as_ref())?
131 } else {
132 vortex_bail!(
133 "No sum function for canonical array: {}",
134 array.as_ref().encoding(),
135 )
136 }
137 };
138
139 if sum.dtype() != &sum_dtype {
140 vortex_panic!(
141 "Sum function of {} returned scalar with wrong dtype: {:?}",
142 array.encoding(),
143 sum.dtype()
144 );
145 }
146
147 array
149 .statistics()
150 .set(Stat::Sum, Precision::Exact(sum.value().clone()));
151
152 Ok(sum)
153}
154
#[cfg(test)]
mod test {
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::sum;

    #[test]
    fn sum_all_invalid() {
        let array = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]);
        let result = sum(&array).unwrap();
        assert_eq!(result.as_primitive().as_::<i32>().unwrap(), Some(0));
    }

    #[test]
    fn sum_all_invalid_float() {
        let array = PrimitiveArray::from_option_iter::<f32, _>([None, None, None]);
        let result = sum(&array).unwrap();
        assert_eq!(result.as_primitive().as_::<f32>().unwrap(), Some(0.0));
    }

    #[test]
    fn sum_constant() {
        let array = PrimitiveArray::from_iter([1, 1, 1, 1]);
        let result = sum(&array).unwrap();
        assert_eq!(result.as_primitive().as_::<i32>().unwrap(), Some(4));
    }

    #[test]
    fn sum_constant_float() {
        // The input is `f64`, so read the sum back as `f64` instead of narrowing to `f32`.
        let array = PrimitiveArray::from_iter([1f64, 1., 1., 1.]);
        let result = sum(&array).unwrap();
        assert_eq!(result.as_primitive().as_::<f64>().unwrap(), Some(4.));
    }

    #[test]
    fn sum_constant_unsigned() {
        let array = PrimitiveArray::from_iter([2u8, 2, 2]);
        let result = sum(&array).unwrap();
        assert_eq!(result.as_primitive().as_::<u64>().unwrap(), Some(6));
    }

    #[test]
    fn sum_constant_negative() {
        let array = PrimitiveArray::from_iter([-3i64, -3, -3]);
        let result = sum(&array).unwrap();
        assert_eq!(result.as_primitive().as_::<i64>().unwrap(), Some(-9));
    }

    #[test]
    fn sum_boolean() {
        let array = BoolArray::from_iter([true, false, false, true]);
        let result = sum(&array).unwrap();
        assert_eq!(result.as_primitive().as_::<i32>().unwrap(), Some(2));
    }

    #[test]
    fn sum_boolean_all_false() {
        let array = BoolArray::from_iter([false, false, false]);
        let result = sum(&array).unwrap();
        assert_eq!(result.as_primitive().as_::<i32>().unwrap(), Some(0));
    }
}