vortex_array/aggregate_fn/fns/mean/
mod.rs1use vortex_error::VortexResult;
5use vortex_error::vortex_bail;
6
7use crate::ArrayRef;
8use crate::ExecutionCtx;
9use crate::aggregate_fn::Accumulator;
10use crate::aggregate_fn::AggregateFnId;
11use crate::aggregate_fn::AggregateFnVTable;
12use crate::aggregate_fn::DynAccumulator;
13use crate::aggregate_fn::EmptyOptions;
14use crate::aggregate_fn::combined::BinaryCombined;
15use crate::aggregate_fn::combined::Combined;
16use crate::aggregate_fn::combined::CombinedOptions;
17use crate::aggregate_fn::combined::PairOptions;
18use crate::aggregate_fn::fns::count::Count;
19use crate::aggregate_fn::fns::sum::Sum;
20use crate::builtins::ArrayBuiltins;
21use crate::dtype::DType;
22use crate::dtype::Nullability;
23use crate::dtype::PType;
24use crate::scalar::Scalar;
25use crate::scalar_fn::fns::operators::Operator;
26
27pub fn mean(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<Scalar> {
31 let mut acc = Accumulator::try_new(
32 Mean::combined(),
33 PairOptions(EmptyOptions, EmptyOptions),
34 array.dtype().clone(),
35 )?;
36 acc.accumulate(array, ctx)?;
37 acc.finish()
38}
39
/// Aggregate function that computes the arithmetic mean of its input.
///
/// Implemented as a [`BinaryCombined`] of [`Sum`] (left) and [`Count`] (right). The two
/// partial results are divided when the aggregate is finalized.
#[derive(Debug, Clone)]
pub struct Mean;
50
51impl Mean {
52 pub fn combined() -> Combined<Self> {
53 Combined(Mean)
54 }
55}
56
57impl BinaryCombined for Mean {
58 type Left = Sum;
59 type Right = Count;
60
61 fn id(&self) -> AggregateFnId {
62 AggregateFnId::new("vortex.mean")
63 }
64
65 fn left(&self) -> Sum {
66 Sum
67 }
68
69 fn right(&self) -> Count {
70 Count
71 }
72
73 fn left_name(&self) -> &'static str {
74 "sum"
75 }
76
77 fn right_name(&self) -> &'static str {
78 "count"
79 }
80
81 fn return_dtype(&self, input_dtype: &DType) -> Option<DType> {
82 Some(mean_output_dtype(input_dtype)?.with_nullability(Nullability::Nullable))
83 }
84
85 fn finalize(&self, sum: ArrayRef, count: ArrayRef) -> VortexResult<ArrayRef> {
86 let target = match sum.dtype() {
87 DType::Decimal(..) => sum.dtype().with_nullability(Nullability::Nullable),
88 _ => DType::Primitive(PType::F64, Nullability::Nullable),
89 };
90 let sum_cast = sum.cast(target.clone())?;
91 let count_cast = count.cast(target)?;
92 sum_cast.binary(count_cast, Operator::Div)
93 }
94
95 fn finalize_scalar(&self, left_scalar: Scalar, right_scalar: Scalar) -> VortexResult<Scalar> {
96 if let DType::Decimal(..) = left_scalar.dtype() {
97 vortex_bail!("mean::finalize_scalar not yet implemented for decimal inputs");
98 }
99
100 let target = DType::Primitive(PType::F64, Nullability::Nullable);
101 let sum_cast = left_scalar.cast(&target)?;
102 let count_cast = right_scalar.cast(&target)?;
103
104 let sum = sum_cast.as_primitive().typed_value::<f64>();
105 let count = count_cast.as_primitive().typed_value::<f64>();
106 let value = match (sum, count) {
107 (None, _) | (_, None) => return Ok(Scalar::null(target)), (Some(s), Some(c)) => s / c,
111 };
112 Ok(Scalar::primitive(value, Nullability::Nullable))
113 }
114
115 fn serialize(&self, _options: &CombinedOptions<Self>) -> VortexResult<Option<Vec<u8>>> {
116 unimplemented!("mean is not yet serializable");
117 }
118
119 fn coerce_args(
120 &self,
121 _options: &PairOptions<
122 <Sum as AggregateFnVTable>::Options,
123 <Count as AggregateFnVTable>::Options,
124 >,
125 input_dtype: &DType,
126 ) -> VortexResult<DType> {
127 Ok(coerced_input_dtype(input_dtype).unwrap_or_else(|| input_dtype.clone()))
130 }
131}
132
133fn coerced_input_dtype(input_dtype: &DType) -> Option<DType> {
139 match input_dtype {
140 DType::Bool(_) => Some(input_dtype.clone()),
141 DType::Primitive(_, n) => Some(DType::Primitive(PType::F64, *n)),
142 DType::Decimal(..) => {
143 unimplemented!("mean is not implemented for decimals yet")
144 }
145 _ => None,
146 }
147}
148
149fn mean_output_dtype(input_dtype: &DType) -> Option<DType> {
150 match input_dtype {
151 DType::Bool(_) | DType::Primitive(..) => {
152 Some(DType::Primitive(PType::F64, Nullability::Nullable))
153 }
154 DType::Decimal(..) => {
155 unimplemented!("mean for decimals is not yet implemented");
156 }
157 _ => None,
158 }
159}
160
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use super::*;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::validity::Validity;

    /// Runs [`mean`] over `array` in a fresh execution context and reads the result as `f64`.
    fn mean_f64(array: &ArrayRef) -> VortexResult<Option<f64>> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let result = mean(array, &mut ctx)?;
        let value = result.as_primitive().as_::<f64>();
        Ok(value)
    }

    #[test]
    fn mean_all_valid() -> VortexResult<()> {
        let array = PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0, 4.0, 5.0], Validity::NonNullable)
            .into_array();
        assert_eq!(mean_f64(&array)?, Some(3.0));
        Ok(())
    }

    #[test]
    fn mean_with_nulls() -> VortexResult<()> {
        // The null is skipped: (2 + 4) / 2.
        let array = PrimitiveArray::from_option_iter([Some(2.0f64), None, Some(4.0)]).into_array();
        assert_eq!(mean_f64(&array)?, Some(3.0));
        Ok(())
    }

    #[test]
    fn mean_integers() -> VortexResult<()> {
        let array = PrimitiveArray::new(buffer![10i32, 20, 30], Validity::NonNullable).into_array();
        assert_eq!(mean_f64(&array)?, Some(20.0));
        Ok(())
    }

    #[test]
    fn mean_bool() -> VortexResult<()> {
        // Three of the four values are `true`.
        let bools: BoolArray = [true, false, true, true].into_iter().collect();
        assert_eq!(mean_f64(&bools.into_array())?, Some(0.75));
        Ok(())
    }

    #[test]
    fn mean_constant_non_null() -> VortexResult<()> {
        let constant = ConstantArray::new(5.0f64, 4).into_array();
        assert_eq!(mean_f64(&constant)?, Some(5.0));
        Ok(())
    }

    #[test]
    fn mean_chunked() -> VortexResult<()> {
        let first = PrimitiveArray::from_option_iter([Some(1.0f64), None, Some(3.0)]);
        let second = PrimitiveArray::from_option_iter([Some(5.0f64), None]);
        let dtype = first.dtype().clone();
        let chunked = ChunkedArray::try_new(vec![first.into_array(), second.into_array()], dtype)?;
        // Valid values across both chunks: (1 + 3 + 5) / 3.
        assert_eq!(mean_f64(&chunked.into_array())?, Some(3.0));
        Ok(())
    }

    #[test]
    fn mean_all_null_returns_nan() -> VortexResult<()> {
        let array = PrimitiveArray::from_option_iter::<f64, _>([None, None, None]).into_array();
        assert!(mean_f64(&array)?.is_some_and(f64::is_nan));
        Ok(())
    }

    #[test]
    fn mean_multi_batch() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let mut acc = Accumulator::try_new(
            Mean::combined(),
            PairOptions(EmptyOptions, EmptyOptions),
            DType::Primitive(PType::F64, Nullability::NonNullable),
        )?;

        let batches = [
            PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0], Validity::NonNullable).into_array(),
            PrimitiveArray::new(buffer![4.0f64, 5.0], Validity::NonNullable).into_array(),
        ];
        for batch in &batches {
            acc.accumulate(batch, &mut ctx)?;
        }

        // (1 + 2 + 3 + 4 + 5) / 5 across both batches.
        let result = acc.finish()?;
        assert_eq!(result.as_primitive().as_::<f64>(), Some(3.0));
        Ok(())
    }
}