vortex_array/aggregate_fn/fns/mean/
mod.rs1use vortex_error::VortexResult;
5use vortex_error::vortex_bail;
6use vortex_session::registry::CachedId;
7
8use crate::ArrayRef;
9use crate::ExecutionCtx;
10use crate::aggregate_fn::Accumulator;
11use crate::aggregate_fn::AggregateFnId;
12use crate::aggregate_fn::AggregateFnVTable;
13use crate::aggregate_fn::DynAccumulator;
14use crate::aggregate_fn::NumericalAggregateOpts;
15use crate::aggregate_fn::combined::BinaryCombined;
16use crate::aggregate_fn::combined::Combined;
17use crate::aggregate_fn::combined::CombinedOptions;
18use crate::aggregate_fn::combined::PairOptions;
19use crate::aggregate_fn::fns::count::Count;
20use crate::aggregate_fn::fns::sum::Sum;
21use crate::builtins::ArrayBuiltins;
22use crate::dtype::DType;
23use crate::dtype::Nullability;
24use crate::dtype::PType;
25use crate::scalar::Scalar;
26use crate::scalar_fn::fns::operators::Operator;
27
28pub fn mean(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<Scalar> {
32 let mut acc = Accumulator::try_new(
33 Mean::combined(),
34 PairOptions(
35 NumericalAggregateOpts::default(),
36 NumericalAggregateOpts::default(),
37 ),
38 array.dtype().clone(),
39 )?;
40 acc.accumulate(array, ctx)?;
41 acc.finish()
42}
43
/// Marker type for the mean aggregate.
///
/// It carries no state; wrap it with [`Mean::combined`] to use it with an accumulator.
#[derive(Debug, Clone)]
pub struct Mean;
54
55impl Mean {
56 pub fn combined() -> Combined<Self> {
57 Combined(Mean)
58 }
59}
60
61impl BinaryCombined for Mean {
62 type Left = Sum;
63 type Right = Count;
64
65 fn id(&self) -> AggregateFnId {
66 static ID: CachedId = CachedId::new("vortex.mean");
67 *ID
68 }
69
70 fn left(&self) -> Sum {
71 Sum
72 }
73
74 fn right(&self) -> Count {
75 Count
76 }
77
78 fn left_name(&self) -> &'static str {
79 "sum"
80 }
81
82 fn right_name(&self) -> &'static str {
83 "count"
84 }
85
86 fn return_dtype(&self, input_dtype: &DType) -> Option<DType> {
87 Some(mean_output_dtype(input_dtype)?.with_nullability(Nullability::Nullable))
88 }
89
90 fn finalize(&self, sum: ArrayRef, count: ArrayRef) -> VortexResult<ArrayRef> {
91 let target = match sum.dtype() {
92 DType::Decimal(..) => sum.dtype().with_nullability(Nullability::Nullable),
93 _ => DType::Primitive(PType::F64, Nullability::Nullable),
94 };
95 let sum_cast = sum.cast(target.clone())?;
96 let count_cast = count.cast(target)?;
97 sum_cast.binary(count_cast, Operator::Div)
98 }
99
100 fn finalize_scalar(&self, left_scalar: Scalar, right_scalar: Scalar) -> VortexResult<Scalar> {
101 if let DType::Decimal(..) = left_scalar.dtype() {
102 vortex_bail!("mean::finalize_scalar not yet implemented for decimal inputs");
103 }
104
105 let target = DType::Primitive(PType::F64, Nullability::Nullable);
106 let sum_cast = left_scalar.cast(&target)?;
107 let count_cast = right_scalar.cast(&target)?;
108
109 let sum = sum_cast.as_primitive().typed_value::<f64>();
110 let count = count_cast.as_primitive().typed_value::<f64>();
111 let value = match (sum, count) {
112 (None, _) | (_, None) => return Ok(Scalar::null(target)), (Some(s), Some(c)) => s / c,
116 };
117 Ok(Scalar::primitive(value, Nullability::Nullable))
118 }
119
120 fn serialize(&self, _options: &CombinedOptions<Self>) -> VortexResult<Option<Vec<u8>>> {
121 unimplemented!("mean is not yet serializable");
122 }
123
124 fn coerce_args(
125 &self,
126 _options: &PairOptions<
127 <Sum as AggregateFnVTable>::Options,
128 <Count as AggregateFnVTable>::Options,
129 >,
130 input_dtype: &DType,
131 ) -> VortexResult<DType> {
132 Ok(coerced_input_dtype(input_dtype).unwrap_or_else(|| input_dtype.clone()))
135 }
136}
137
138fn coerced_input_dtype(input_dtype: &DType) -> Option<DType> {
144 match input_dtype {
145 DType::Bool(_) => Some(input_dtype.clone()),
146 DType::Primitive(_, n) => Some(DType::Primitive(PType::F64, *n)),
147 DType::Decimal(..) => {
148 unimplemented!("mean is not implemented for decimals yet")
149 }
150 _ => None,
151 }
152}
153
154fn mean_output_dtype(input_dtype: &DType) -> Option<DType> {
155 match input_dtype {
156 DType::Bool(_) | DType::Primitive(..) => {
157 Some(DType::Primitive(PType::F64, Nullability::Nullable))
158 }
159 DType::Decimal(..) => {
160 unimplemented!("mean for decimals is not yet implemented");
161 }
162 _ => None,
163 }
164}
165
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use super::*;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::BoolArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::validity::Validity;

    /// Runs [`mean`] over `array` with a fresh execution context and extracts the result as an
    /// optional `f64`.
    fn mean_as_f64(array: ArrayRef) -> VortexResult<Option<f64>> {
        let mut ctx = array_session().create_execution_ctx();
        let scalar = mean(&array, &mut ctx)?;
        let value = scalar.as_primitive().as_::<f64>();
        Ok(value)
    }

    /// A non-nullable float array averages over all of its values.
    #[test]
    fn mean_all_valid() -> VortexResult<()> {
        let values = buffer![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let input = PrimitiveArray::new(values, Validity::NonNullable).into_array();
        assert_eq!(mean_as_f64(input)?, Some(3.0));
        Ok(())
    }

    /// Null entries do not contribute to the mean.
    #[test]
    fn mean_with_nulls() -> VortexResult<()> {
        let input = PrimitiveArray::from_option_iter([Some(2.0f64), None, Some(4.0)]).into_array();
        assert_eq!(mean_as_f64(input)?, Some(3.0));
        Ok(())
    }

    /// Integer input produces an `f64` mean.
    #[test]
    fn mean_integers() -> VortexResult<()> {
        let values = buffer![10i32, 20, 30];
        let input = PrimitiveArray::new(values, Validity::NonNullable).into_array();
        assert_eq!(mean_as_f64(input)?, Some(20.0));
        Ok(())
    }

    /// Three `true` values out of four average to 0.75.
    #[test]
    fn mean_bool() -> VortexResult<()> {
        let flags: BoolArray = [true, false, true, true].into_iter().collect();
        assert_eq!(mean_as_f64(flags.into_array())?, Some(0.75));
        Ok(())
    }

    /// A constant array averages to its constant.
    #[test]
    fn mean_constant_non_null() -> VortexResult<()> {
        let input = ConstantArray::new(5.0f64, 4).into_array();
        assert_eq!(mean_as_f64(input)?, Some(5.0));
        Ok(())
    }

    /// The mean spans all chunks of a chunked array, still ignoring nulls.
    #[test]
    fn mean_chunked() -> VortexResult<()> {
        let first = PrimitiveArray::from_option_iter([Some(1.0f64), None, Some(3.0)]);
        let second = PrimitiveArray::from_option_iter([Some(5.0f64), None]);
        let chunk_dtype = first.dtype().clone();
        let chunks = vec![first.into_array(), second.into_array()];
        let chunked = ChunkedArray::try_new(chunks, chunk_dtype)?;
        assert_eq!(mean_as_f64(chunked.into_array())?, Some(3.0));
        Ok(())
    }

    /// With default options a NaN entry is left out of the mean.
    #[test]
    fn mean_skips_nans_by_default() -> VortexResult<()> {
        let values = buffer![1.0f64, f64::NAN, 3.0];
        let input = PrimitiveArray::new(values, Validity::NonNullable).into_array();
        assert_eq!(mean_as_f64(input)?, Some(2.0));
        Ok(())
    }

    /// With `include_nans` options on both halves, a NaN entry makes the mean NaN.
    #[test]
    fn mean_with_nan_not_skipping() -> VortexResult<()> {
        let values = buffer![1.0f64, f64::NAN, 3.0];
        let input = PrimitiveArray::new(values, Validity::NonNullable).into_array();
        let mut ctx = array_session().create_execution_ctx();

        let keep_nans = NumericalAggregateOpts::include_nans();
        let options = PairOptions(keep_nans, keep_nans);
        let mut accumulator = Accumulator::try_new(Mean::combined(), options, input.dtype().clone())?;
        accumulator.accumulate(&input, &mut ctx)?;

        let outcome = accumulator.finish()?;
        assert!(outcome.as_primitive().as_::<f64>().is_some_and(f64::is_nan));
        Ok(())
    }

    /// An all-null input yields NaN rather than a null result.
    #[test]
    fn mean_all_null_returns_nan() -> VortexResult<()> {
        let input = PrimitiveArray::from_option_iter::<f64, _>([None, None, None]).into_array();
        assert!(mean_as_f64(input)?.is_some_and(f64::is_nan));
        Ok(())
    }

    /// Feeding one accumulator several batches averages across all of them.
    #[test]
    fn mean_multi_batch() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let options = PairOptions(
            NumericalAggregateOpts::default(),
            NumericalAggregateOpts::default(),
        );
        let input_dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        let mut accumulator = Accumulator::try_new(Mean::combined(), options, input_dtype)?;

        let batches = [
            PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0], Validity::NonNullable).into_array(),
            PrimitiveArray::new(buffer![4.0f64, 5.0], Validity::NonNullable).into_array(),
        ];
        for batch in &batches {
            accumulator.accumulate(batch, &mut ctx)?;
        }

        let outcome = accumulator.finish()?;
        assert_eq!(outcome.as_primitive().as_::<f64>(), Some(3.0));
        Ok(())
    }
}