Skip to main content

vortex_array/aggregate_fn/fns/mean/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_error::vortex_bail;
6use vortex_session::registry::CachedId;
7
8use crate::ArrayRef;
9use crate::ExecutionCtx;
10use crate::aggregate_fn::Accumulator;
11use crate::aggregate_fn::AggregateFnId;
12use crate::aggregate_fn::AggregateFnVTable;
13use crate::aggregate_fn::DynAccumulator;
14use crate::aggregate_fn::NumericalAggregateOpts;
15use crate::aggregate_fn::combined::BinaryCombined;
16use crate::aggregate_fn::combined::Combined;
17use crate::aggregate_fn::combined::CombinedOptions;
18use crate::aggregate_fn::combined::PairOptions;
19use crate::aggregate_fn::fns::count::Count;
20use crate::aggregate_fn::fns::sum::Sum;
21use crate::builtins::ArrayBuiltins;
22use crate::dtype::DType;
23use crate::dtype::Nullability;
24use crate::dtype::PType;
25use crate::scalar::Scalar;
26use crate::scalar_fn::fns::operators::Operator;
27
28/// Compute the arithmetic mean of an array.
29///
30/// See [`Mean`] for details.
31pub fn mean(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<Scalar> {
32    let mut acc = Accumulator::try_new(
33        Mean::combined(),
34        PairOptions(
35            NumericalAggregateOpts::default(),
36            NumericalAggregateOpts::default(),
37        ),
38        array.dtype().clone(),
39    )?;
40    acc.accumulate(array, ctx)?;
41    acc.finish()
42}
43
44/// Compute the arithmetic mean of an array.
45///
46/// Implemented as `Sum / Count` via [`BinaryCombined`].
47///
48/// Coercion / return type:
49/// - Booleans and primitive numeric types are coerced to `f64` and the result
50///   is a nullable `f64`.
51/// - Decimals are kept as decimals but not implemented currently
52#[derive(Clone, Debug)]
53pub struct Mean;
54
55impl Mean {
56    pub fn combined() -> Combined<Self> {
57        Combined(Mean)
58    }
59}
60
61impl BinaryCombined for Mean {
62    type Left = Sum;
63    type Right = Count;
64
65    fn id(&self) -> AggregateFnId {
66        static ID: CachedId = CachedId::new("vortex.mean");
67        *ID
68    }
69
70    fn left(&self) -> Sum {
71        Sum
72    }
73
74    fn right(&self) -> Count {
75        Count
76    }
77
78    fn left_name(&self) -> &'static str {
79        "sum"
80    }
81
82    fn right_name(&self) -> &'static str {
83        "count"
84    }
85
86    fn return_dtype(&self, input_dtype: &DType) -> Option<DType> {
87        Some(mean_output_dtype(input_dtype)?.with_nullability(Nullability::Nullable))
88    }
89
90    fn finalize(&self, sum: ArrayRef, count: ArrayRef) -> VortexResult<ArrayRef> {
91        let target = match sum.dtype() {
92            DType::Decimal(..) => sum.dtype().with_nullability(Nullability::Nullable),
93            _ => DType::Primitive(PType::F64, Nullability::Nullable),
94        };
95        let sum_cast = sum.cast(target.clone())?;
96        let count_cast = count.cast(target)?;
97        sum_cast.binary(count_cast, Operator::Div)
98    }
99
100    fn finalize_scalar(&self, left_scalar: Scalar, right_scalar: Scalar) -> VortexResult<Scalar> {
101        if let DType::Decimal(..) = left_scalar.dtype() {
102            vortex_bail!("mean::finalize_scalar not yet implemented for decimal inputs");
103        }
104
105        let target = DType::Primitive(PType::F64, Nullability::Nullable);
106        let sum_cast = left_scalar.cast(&target)?;
107        let count_cast = right_scalar.cast(&target)?;
108
109        let sum = sum_cast.as_primitive().typed_value::<f64>();
110        let count = count_cast.as_primitive().typed_value::<f64>();
111        let value = match (sum, count) {
112            (None, _) | (_, None) => return Ok(Scalar::null(target)), // Sum overflowed
113            // A count of zero yields 0/0 = NaN, matching the array `finalize` path: nulls are
114            // skipped during accumulation, so an all-null input is an empty mean, not null.
115            (Some(s), Some(c)) => s / c,
116        };
117        Ok(Scalar::primitive(value, Nullability::Nullable))
118    }
119
120    fn serialize(&self, _options: &CombinedOptions<Self>) -> VortexResult<Option<Vec<u8>>> {
121        unimplemented!("mean is not yet serializable");
122    }
123
124    fn coerce_args(
125        &self,
126        _options: &PairOptions<
127            <Sum as AggregateFnVTable>::Options,
128            <Count as AggregateFnVTable>::Options,
129        >,
130        input_dtype: &DType,
131    ) -> VortexResult<DType> {
132        // Advisory hint for query planners: where possible, cast input to the
133        // type we're going to compute the mean in.
134        Ok(coerced_input_dtype(input_dtype).unwrap_or_else(|| input_dtype.clone()))
135    }
136}
137
138/// Hint for callers: what to cast the input to before accumulation.
139///
140/// - Bool stays as bool — `Sum` has a native bool path and bool → f64 isn't
141///   currently a direct cast in vortex.
142/// - Primitive numerics → `f64` so the sum and finalize work without overflow.
143fn coerced_input_dtype(input_dtype: &DType) -> Option<DType> {
144    match input_dtype {
145        DType::Bool(_) => Some(input_dtype.clone()),
146        DType::Primitive(_, n) => Some(DType::Primitive(PType::F64, *n)),
147        DType::Decimal(..) => {
148            unimplemented!("mean is not implemented for decimals yet")
149        }
150        _ => None,
151    }
152}
153
154fn mean_output_dtype(input_dtype: &DType) -> Option<DType> {
155    match input_dtype {
156        DType::Bool(_) | DType::Primitive(..) => {
157            Some(DType::Primitive(PType::F64, Nullability::Nullable))
158        }
159        DType::Decimal(..) => {
160            unimplemented!("mean for decimals is not yet implemented");
161        }
162        _ => None,
163    }
164}
165
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use super::*;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::BoolArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::validity::Validity;

    /// Runs [`mean`] over `input` in a fresh execution context and extracts the
    /// resulting scalar as an optional `f64`.
    fn mean_as_f64(input: &ArrayRef) -> VortexResult<Option<f64>> {
        let mut ctx = array_session().create_execution_ctx();
        let scalar = mean(input, &mut ctx)?;
        let value = scalar.as_primitive().as_::<f64>();
        Ok(value)
    }

    #[test]
    fn mean_all_valid() -> VortexResult<()> {
        let values = buffer![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let input = PrimitiveArray::new(values, Validity::NonNullable).into_array();
        assert_eq!(mean_as_f64(&input)?, Some(3.0));
        Ok(())
    }

    #[test]
    fn mean_with_nulls() -> VortexResult<()> {
        let input = PrimitiveArray::from_option_iter([Some(2.0f64), None, Some(4.0)]).into_array();
        assert_eq!(mean_as_f64(&input)?, Some(3.0));
        Ok(())
    }

    #[test]
    fn mean_integers() -> VortexResult<()> {
        let values = buffer![10i32, 20, 30];
        let input = PrimitiveArray::new(values, Validity::NonNullable).into_array();
        assert_eq!(mean_as_f64(&input)?, Some(20.0));
        Ok(())
    }

    #[test]
    fn mean_bool() -> VortexResult<()> {
        let flags: BoolArray = [true, false, true, true].into_iter().collect();
        assert_eq!(mean_as_f64(&flags.into_array())?, Some(0.75));
        Ok(())
    }

    #[test]
    fn mean_constant_non_null() -> VortexResult<()> {
        let constant = ConstantArray::new(5.0f64, 4);
        assert_eq!(mean_as_f64(&constant.into_array())?, Some(5.0));
        Ok(())
    }

    #[test]
    fn mean_chunked() -> VortexResult<()> {
        let first = PrimitiveArray::from_option_iter([Some(1.0f64), None, Some(3.0)]);
        let second = PrimitiveArray::from_option_iter([Some(5.0f64), None]);
        let chunk_dtype = first.dtype().clone();
        let chunks = vec![first.into_array(), second.into_array()];
        let chunked = ChunkedArray::try_new(chunks, chunk_dtype)?;
        assert_eq!(mean_as_f64(&chunked.into_array())?, Some(3.0));
        Ok(())
    }

    #[test]
    fn mean_skips_nans_by_default() -> VortexResult<()> {
        // With default options a NaN contributes to neither the sum nor the count.
        let values = buffer![1.0f64, f64::NAN, 3.0];
        let input = PrimitiveArray::new(values, Validity::NonNullable).into_array();
        assert_eq!(mean_as_f64(&input)?, Some(2.0));
        Ok(())
    }

    #[test]
    fn mean_with_nan_not_skipping() -> VortexResult<()> {
        let values = buffer![1.0f64, f64::NAN, 3.0];
        let input = PrimitiveArray::new(values, Validity::NonNullable).into_array();
        let mut ctx = array_session().create_execution_ctx();

        // Opt both the sum and the count into keeping NaNs.
        let include_nans = NumericalAggregateOpts::include_nans();
        let options = PairOptions(include_nans, include_nans);
        let mut accumulator =
            Accumulator::try_new(Mean::combined(), options, input.dtype().clone())?;
        accumulator.accumulate(&input, &mut ctx)?;

        let outcome = accumulator.finish()?;
        assert!(outcome.as_primitive().as_::<f64>().is_some_and(f64::is_nan));
        Ok(())
    }

    #[test]
    fn mean_all_null_returns_nan() -> VortexResult<()> {
        let input = PrimitiveArray::from_option_iter::<f64, _>([None, None, None]).into_array();
        let outcome = mean_as_f64(&input)?;
        assert!(outcome.is_some_and(f64::is_nan));
        Ok(())
    }

    #[test]
    fn mean_multi_batch() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let options = PairOptions(
            NumericalAggregateOpts::default(),
            NumericalAggregateOpts::default(),
        );
        let input_dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        let mut accumulator = Accumulator::try_new(Mean::combined(), options, input_dtype)?;

        // Feed two separate batches into the same accumulator.
        let first =
            PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0], Validity::NonNullable).into_array();
        accumulator.accumulate(&first, &mut ctx)?;

        let second = PrimitiveArray::new(buffer![4.0f64, 5.0], Validity::NonNullable).into_array();
        accumulator.accumulate(&second, &mut ctx)?;

        let outcome = accumulator.finish()?;
        assert_eq!(outcome.as_primitive().as_::<f64>(), Some(3.0));
        Ok(())
    }
}