Skip to main content

vortex_array/aggregate_fn/fns/mean/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_error::vortex_bail;
6
7use crate::ArrayRef;
8use crate::ExecutionCtx;
9use crate::aggregate_fn::Accumulator;
10use crate::aggregate_fn::AggregateFnId;
11use crate::aggregate_fn::DynAccumulator;
12use crate::aggregate_fn::EmptyOptions;
13use crate::aggregate_fn::combined::BinaryCombined;
14use crate::aggregate_fn::combined::Combined;
15use crate::aggregate_fn::combined::CombinedOptions;
16use crate::aggregate_fn::combined::PairOptions;
17use crate::aggregate_fn::fns::count::Count;
18use crate::aggregate_fn::fns::sum::Sum;
19use crate::builtins::ArrayBuiltins;
20use crate::dtype::DType;
21use crate::dtype::Nullability;
22use crate::dtype::PType;
23use crate::scalar::Scalar;
24use crate::scalar_fn::fns::operators::Operator;
25
26/// Compute the arithmetic mean of an array.
27///
28/// See [`Mean`] for details.
29pub fn mean(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<Scalar> {
30    let mut acc = Accumulator::try_new(
31        Mean::combined(),
32        PairOptions(EmptyOptions, EmptyOptions),
33        array.dtype().clone(),
34    )?;
35    acc.accumulate(array, ctx)?;
36    acc.finish()
37}
38
/// Aggregate function that yields the arithmetic mean of its input.
///
/// The mean is derived as `Sum / Count` through the [`BinaryCombined`] machinery.
///
/// Boolean and primitive numeric inputs yield a nullable `f64`. Decimal inputs are meant to
/// stay decimal, but that path is not implemented yet.
#[derive(Debug, Clone)]
pub struct Mean;
47
48impl Mean {
49    pub fn combined() -> Combined<Self> {
50        Combined(Mean)
51    }
52}
53
54impl BinaryCombined for Mean {
55    type Left = Sum;
56    type Right = Count;
57
58    fn id(&self) -> AggregateFnId {
59        AggregateFnId::new("vortex.mean")
60    }
61
62    fn left(&self) -> Sum {
63        Sum
64    }
65
66    fn right(&self) -> Count {
67        Count
68    }
69
70    fn left_name(&self) -> &'static str {
71        "sum"
72    }
73
74    fn right_name(&self) -> &'static str {
75        "count"
76    }
77
78    fn return_dtype(&self, input_dtype: &DType) -> Option<DType> {
79        Some(mean_output_dtype(input_dtype)?.with_nullability(Nullability::Nullable))
80    }
81
82    fn finalize(&self, sum: ArrayRef, count: ArrayRef) -> VortexResult<ArrayRef> {
83        let target = match sum.dtype() {
84            DType::Decimal(..) => sum.dtype().with_nullability(Nullability::Nullable),
85            _ => DType::Primitive(PType::F64, Nullability::Nullable),
86        };
87        let sum_cast = sum.cast(target.clone())?;
88        let count_cast = count.cast(target)?;
89        sum_cast.binary(count_cast, Operator::Div)
90    }
91
92    fn finalize_scalar(&self, left_scalar: Scalar, right_scalar: Scalar) -> VortexResult<Scalar> {
93        if let DType::Decimal(..) = left_scalar.dtype() {
94            vortex_bail!("mean::finalize_scalar not yet implemented for decimal inputs");
95        }
96
97        let target = DType::Primitive(PType::F64, Nullability::Nullable);
98        let sum_cast = left_scalar.cast(&target)?;
99        let count_cast = right_scalar.cast(&target)?;
100
101        let sum = sum_cast.as_primitive().typed_value::<f64>();
102        let count = count_cast.as_primitive().typed_value::<f64>();
103        let value = match (sum, count) {
104            (None, _) | (_, None) | (_, Some(0.0)) => return Ok(Scalar::null(target)), // Sum overflowed
105            (Some(s), Some(c)) => s / c,
106        };
107        Ok(Scalar::primitive(value, Nullability::Nullable))
108    }
109
110    fn serialize(&self, _options: &CombinedOptions<Self>) -> VortexResult<Option<Vec<u8>>> {
111        unimplemented!("mean is not yet serializable");
112    }
113}
114
115fn mean_output_dtype(input_dtype: &DType) -> Option<DType> {
116    match input_dtype {
117        DType::Bool(_) | DType::Primitive(..) => {
118            Some(DType::Primitive(PType::F64, Nullability::Nullable))
119        }
120        DType::Decimal(..) => {
121            unimplemented!("mean for decimals is not yet implemented");
122        }
123        _ => None,
124    }
125}
126
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use super::*;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::validity::Validity;

    /// Runs [`mean`] on `array` with a fresh execution context and extracts the `f64` result.
    fn mean_of(array: ArrayRef) -> VortexResult<Option<f64>> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let scalar = mean(&array, &mut ctx)?;
        // Bind the value first so the borrowed primitive view is dropped before `scalar`.
        let value = scalar.as_primitive().as_::<f64>();
        Ok(value)
    }

    /// Non-nullable floats: plain average.
    #[test]
    fn mean_all_valid() -> VortexResult<()> {
        let values = buffer![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let array = PrimitiveArray::new(values, Validity::NonNullable).into_array();
        assert_eq!(mean_of(array)?, Some(3.0));
        Ok(())
    }

    /// Null entries contribute to neither the sum nor the count.
    #[test]
    fn mean_with_nulls() -> VortexResult<()> {
        let array = PrimitiveArray::from_option_iter([Some(2.0f64), None, Some(4.0)]).into_array();
        assert_eq!(mean_of(array)?, Some(3.0));
        Ok(())
    }

    /// Integer input still yields an `f64` mean.
    #[test]
    fn mean_integers() -> VortexResult<()> {
        let array = PrimitiveArray::new(buffer![10i32, 20, 30], Validity::NonNullable).into_array();
        assert_eq!(mean_of(array)?, Some(20.0));
        Ok(())
    }

    /// Booleans average to the fraction of `true` values.
    #[test]
    fn mean_bool() -> VortexResult<()> {
        let array: BoolArray = [true, false, true, true].into_iter().collect();
        assert_eq!(mean_of(array.into_array())?, Some(0.75));
        Ok(())
    }

    /// A constant array averages to its constant.
    #[test]
    fn mean_constant_non_null() -> VortexResult<()> {
        let array = ConstantArray::new(5.0f64, 4);
        assert_eq!(mean_of(array.into_array())?, Some(5.0));
        Ok(())
    }

    /// Chunked input is aggregated across all chunks, skipping nulls.
    #[test]
    fn mean_chunked() -> VortexResult<()> {
        let first = PrimitiveArray::from_option_iter([Some(1.0f64), None, Some(3.0)]);
        let second = PrimitiveArray::from_option_iter([Some(5.0f64), None]);
        let dtype = first.dtype().clone();
        let chunks = vec![first.into_array(), second.into_array()];
        let chunked = ChunkedArray::try_new(chunks, dtype)?;
        assert_eq!(mean_of(chunked.into_array())?, Some(3.0));
        Ok(())
    }

    /// With no valid values the mean is null.
    #[test]
    fn mean_all_null_returns_null() -> VortexResult<()> {
        let array = PrimitiveArray::from_option_iter::<f64, _>([None, None, None]).into_array();
        assert_eq!(mean_of(array)?, None);
        Ok(())
    }

    /// Feeding several batches into one accumulator averages over all of them.
    #[test]
    fn mean_multi_batch() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let options = PairOptions(EmptyOptions, EmptyOptions);
        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        let mut accumulator = Accumulator::try_new(Mean::combined(), options, dtype)?;

        let first = PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0], Validity::NonNullable);
        accumulator.accumulate(&first.into_array(), &mut ctx)?;

        let second = PrimitiveArray::new(buffer![4.0f64, 5.0], Validity::NonNullable);
        accumulator.accumulate(&second.into_array(), &mut ctx)?;

        let result = accumulator.finish()?;
        assert_eq!(result.as_primitive().as_::<f64>(), Some(3.0));
        Ok(())
    }
}