Skip to main content

vortex_array/aggregate_fn/fns/mean/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_error::vortex_bail;
6
7use crate::ArrayRef;
8use crate::ExecutionCtx;
9use crate::aggregate_fn::Accumulator;
10use crate::aggregate_fn::AggregateFnId;
11use crate::aggregate_fn::AggregateFnVTable;
12use crate::aggregate_fn::DynAccumulator;
13use crate::aggregate_fn::EmptyOptions;
14use crate::aggregate_fn::combined::BinaryCombined;
15use crate::aggregate_fn::combined::Combined;
16use crate::aggregate_fn::combined::CombinedOptions;
17use crate::aggregate_fn::combined::PairOptions;
18use crate::aggregate_fn::fns::count::Count;
19use crate::aggregate_fn::fns::sum::Sum;
20use crate::builtins::ArrayBuiltins;
21use crate::dtype::DType;
22use crate::dtype::Nullability;
23use crate::dtype::PType;
24use crate::scalar::Scalar;
25use crate::scalar_fn::fns::operators::Operator;
26
27/// Compute the arithmetic mean of an array.
28///
29/// See [`Mean`] for details.
30pub fn mean(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<Scalar> {
31    let mut acc = Accumulator::try_new(
32        Mean::combined(),
33        PairOptions(EmptyOptions, EmptyOptions),
34        array.dtype().clone(),
35    )?;
36    acc.accumulate(array, ctx)?;
37    acc.finish()
38}
39
40/// Compute the arithmetic mean of an array.
41///
42/// Implemented as `Sum / Count` via [`BinaryCombined`].
43///
44/// Coercion / return type:
45/// - Booleans and primitive numeric types are coerced to `f64` and the result
46///   is a nullable `f64`.
47/// - Decimals are kept as decimals but not implemented currently
48#[derive(Clone, Debug)]
49pub struct Mean;
50
51impl Mean {
52    pub fn combined() -> Combined<Self> {
53        Combined(Mean)
54    }
55}
56
57impl BinaryCombined for Mean {
58    type Left = Sum;
59    type Right = Count;
60
61    fn id(&self) -> AggregateFnId {
62        AggregateFnId::new("vortex.mean")
63    }
64
65    fn left(&self) -> Sum {
66        Sum
67    }
68
69    fn right(&self) -> Count {
70        Count
71    }
72
73    fn left_name(&self) -> &'static str {
74        "sum"
75    }
76
77    fn right_name(&self) -> &'static str {
78        "count"
79    }
80
81    fn return_dtype(&self, input_dtype: &DType) -> Option<DType> {
82        Some(mean_output_dtype(input_dtype)?.with_nullability(Nullability::Nullable))
83    }
84
85    fn finalize(&self, sum: ArrayRef, count: ArrayRef) -> VortexResult<ArrayRef> {
86        let target = match sum.dtype() {
87            DType::Decimal(..) => sum.dtype().with_nullability(Nullability::Nullable),
88            _ => DType::Primitive(PType::F64, Nullability::Nullable),
89        };
90        let sum_cast = sum.cast(target.clone())?;
91        let count_cast = count.cast(target)?;
92        sum_cast.binary(count_cast, Operator::Div)
93    }
94
95    fn finalize_scalar(&self, left_scalar: Scalar, right_scalar: Scalar) -> VortexResult<Scalar> {
96        if let DType::Decimal(..) = left_scalar.dtype() {
97            vortex_bail!("mean::finalize_scalar not yet implemented for decimal inputs");
98        }
99
100        let target = DType::Primitive(PType::F64, Nullability::Nullable);
101        let sum_cast = left_scalar.cast(&target)?;
102        let count_cast = right_scalar.cast(&target)?;
103
104        let sum = sum_cast.as_primitive().typed_value::<f64>();
105        let count = count_cast.as_primitive().typed_value::<f64>();
106        let value = match (sum, count) {
107            (None, _) | (_, None) => return Ok(Scalar::null(target)), // Sum overflowed
108            // A count of zero yields 0/0 = NaN, matching the array `finalize` path: nulls are
109            // skipped during accumulation, so an all-null input is an empty mean, not null.
110            (Some(s), Some(c)) => s / c,
111        };
112        Ok(Scalar::primitive(value, Nullability::Nullable))
113    }
114
115    fn serialize(&self, _options: &CombinedOptions<Self>) -> VortexResult<Option<Vec<u8>>> {
116        unimplemented!("mean is not yet serializable");
117    }
118
119    fn coerce_args(
120        &self,
121        _options: &PairOptions<
122            <Sum as AggregateFnVTable>::Options,
123            <Count as AggregateFnVTable>::Options,
124        >,
125        input_dtype: &DType,
126    ) -> VortexResult<DType> {
127        // Advisory hint for query planners: where possible, cast input to the
128        // type we're going to compute the mean in.
129        Ok(coerced_input_dtype(input_dtype).unwrap_or_else(|| input_dtype.clone()))
130    }
131}
132
133/// Hint for callers: what to cast the input to before accumulation.
134///
135/// - Bool stays as bool — `Sum` has a native bool path and bool → f64 isn't
136///   currently a direct cast in vortex.
137/// - Primitive numerics → `f64` so the sum and finalize work without overflow.
138fn coerced_input_dtype(input_dtype: &DType) -> Option<DType> {
139    match input_dtype {
140        DType::Bool(_) => Some(input_dtype.clone()),
141        DType::Primitive(_, n) => Some(DType::Primitive(PType::F64, *n)),
142        DType::Decimal(..) => {
143            unimplemented!("mean is not implemented for decimals yet")
144        }
145        _ => None,
146    }
147}
148
149fn mean_output_dtype(input_dtype: &DType) -> Option<DType> {
150    match input_dtype {
151        DType::Bool(_) | DType::Primitive(..) => {
152            Some(DType::Primitive(PType::F64, Nullability::Nullable))
153        }
154        DType::Decimal(..) => {
155            unimplemented!("mean for decimals is not yet implemented");
156        }
157        _ => None,
158    }
159}
160
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use super::*;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::validity::Validity;

    /// Runs [`mean`] over `array` with a fresh execution context and reads
    /// the resulting scalar back as an `f64`.
    fn mean_as_f64(array: &ArrayRef) -> VortexResult<Option<f64>> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let scalar = mean(array, &mut ctx)?;
        Ok(scalar.as_primitive().as_::<f64>())
    }

    /// Non-nullable floats: plain average.
    #[test]
    fn mean_all_valid() -> VortexResult<()> {
        let values = buffer![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let array = PrimitiveArray::new(values, Validity::NonNullable).into_array();
        assert_eq!(mean_as_f64(&array)?, Some(3.0));
        Ok(())
    }

    /// Null entries are excluded from both the sum and the count.
    #[test]
    fn mean_with_nulls() -> VortexResult<()> {
        let array = PrimitiveArray::from_option_iter([Some(2.0f64), None, Some(4.0)]).into_array();
        assert_eq!(mean_as_f64(&array)?, Some(3.0));
        Ok(())
    }

    /// Integer input still produces an `f64` mean.
    #[test]
    fn mean_integers() -> VortexResult<()> {
        let array = PrimitiveArray::new(buffer![10i32, 20, 30], Validity::NonNullable).into_array();
        assert_eq!(mean_as_f64(&array)?, Some(20.0));
        Ok(())
    }

    /// Boolean input: the mean is the fraction of `true` values.
    #[test]
    fn mean_bool() -> VortexResult<()> {
        let array: BoolArray = [true, false, true, true].into_iter().collect();
        assert_eq!(mean_as_f64(&array.into_array())?, Some(0.75));
        Ok(())
    }

    /// A constant array averages to its constant.
    #[test]
    fn mean_constant_non_null() -> VortexResult<()> {
        let array = ConstantArray::new(5.0f64, 4);
        assert_eq!(mean_as_f64(&array.into_array())?, Some(5.0));
        Ok(())
    }

    /// Chunked input with nulls spread over several chunks.
    #[test]
    fn mean_chunked() -> VortexResult<()> {
        let chunk1 = PrimitiveArray::from_option_iter([Some(1.0f64), None, Some(3.0)]);
        let chunk2 = PrimitiveArray::from_option_iter([Some(5.0f64), None]);
        let dtype = chunk1.dtype().clone();
        let chunks = vec![chunk1.into_array(), chunk2.into_array()];
        let chunked = ChunkedArray::try_new(chunks, dtype)?;
        assert_eq!(mean_as_f64(&chunked.into_array())?, Some(3.0));
        Ok(())
    }

    /// An all-null input is an empty mean (0/0), which is NaN rather than null.
    #[test]
    fn mean_all_null_returns_nan() -> VortexResult<()> {
        let array = PrimitiveArray::from_option_iter::<f64, _>([None, None, None]).into_array();
        let value = mean_as_f64(&array)?;
        assert!(value.is_some_and(f64::is_nan));
        Ok(())
    }

    /// Feeding one accumulator several batches gives the mean over all of them.
    #[test]
    fn mean_multi_batch() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        let options = PairOptions(EmptyOptions, EmptyOptions);
        let mut acc = Accumulator::try_new(Mean::combined(), options, dtype)?;

        let first = PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0], Validity::NonNullable);
        acc.accumulate(&first.into_array(), &mut ctx)?;

        let second = PrimitiveArray::new(buffer![4.0f64, 5.0], Validity::NonNullable);
        acc.accumulate(&second.into_array(), &mut ctx)?;

        let result = acc.finish()?;
        assert_eq!(result.as_primitive().as_::<f64>(), Some(3.0));
        Ok(())
    }
}
264}