vortex_array/aggregate_fn/fns/mean/
mod.rs1use vortex_error::VortexResult;
5use vortex_error::vortex_bail;
6
7use crate::ArrayRef;
8use crate::ExecutionCtx;
9use crate::aggregate_fn::Accumulator;
10use crate::aggregate_fn::AggregateFnId;
11use crate::aggregate_fn::DynAccumulator;
12use crate::aggregate_fn::EmptyOptions;
13use crate::aggregate_fn::combined::BinaryCombined;
14use crate::aggregate_fn::combined::Combined;
15use crate::aggregate_fn::combined::CombinedOptions;
16use crate::aggregate_fn::combined::PairOptions;
17use crate::aggregate_fn::fns::count::Count;
18use crate::aggregate_fn::fns::sum::Sum;
19use crate::builtins::ArrayBuiltins;
20use crate::dtype::DType;
21use crate::dtype::Nullability;
22use crate::dtype::PType;
23use crate::scalar::Scalar;
24use crate::scalar_fn::fns::operators::Operator;
25
26pub fn mean(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<Scalar> {
30 let mut acc = Accumulator::try_new(
31 Mean::combined(),
32 PairOptions(EmptyOptions, EmptyOptions),
33 array.dtype().clone(),
34 )?;
35 acc.accumulate(array, ctx)?;
36 acc.finish()
37}
38
/// The mean (arithmetic average) aggregate function.
///
/// This is a field-less marker type. Wrap it with [`Mean::combined`] to obtain the runnable
/// aggregate.
#[derive(Debug, Clone)]
pub struct Mean;
47
48impl Mean {
49 pub fn combined() -> Combined<Self> {
50 Combined(Mean)
51 }
52}
53
54impl BinaryCombined for Mean {
55 type Left = Sum;
56 type Right = Count;
57
58 fn id(&self) -> AggregateFnId {
59 AggregateFnId::new("vortex.mean")
60 }
61
62 fn left(&self) -> Sum {
63 Sum
64 }
65
66 fn right(&self) -> Count {
67 Count
68 }
69
70 fn left_name(&self) -> &'static str {
71 "sum"
72 }
73
74 fn right_name(&self) -> &'static str {
75 "count"
76 }
77
78 fn return_dtype(&self, input_dtype: &DType) -> Option<DType> {
79 Some(mean_output_dtype(input_dtype)?.with_nullability(Nullability::Nullable))
80 }
81
82 fn finalize(&self, sum: ArrayRef, count: ArrayRef) -> VortexResult<ArrayRef> {
83 let target = match sum.dtype() {
84 DType::Decimal(..) => sum.dtype().with_nullability(Nullability::Nullable),
85 _ => DType::Primitive(PType::F64, Nullability::Nullable),
86 };
87 let sum_cast = sum.cast(target.clone())?;
88 let count_cast = count.cast(target)?;
89 sum_cast.binary(count_cast, Operator::Div)
90 }
91
92 fn finalize_scalar(&self, left_scalar: Scalar, right_scalar: Scalar) -> VortexResult<Scalar> {
93 if let DType::Decimal(..) = left_scalar.dtype() {
94 vortex_bail!("mean::finalize_scalar not yet implemented for decimal inputs");
95 }
96
97 let target = DType::Primitive(PType::F64, Nullability::Nullable);
98 let sum_cast = left_scalar.cast(&target)?;
99 let count_cast = right_scalar.cast(&target)?;
100
101 let sum = sum_cast.as_primitive().typed_value::<f64>();
102 let count = count_cast.as_primitive().typed_value::<f64>();
103 let value = match (sum, count) {
104 (None, _) | (_, None) | (_, Some(0.0)) => return Ok(Scalar::null(target)), (Some(s), Some(c)) => s / c,
106 };
107 Ok(Scalar::primitive(value, Nullability::Nullable))
108 }
109
110 fn serialize(&self, _options: &CombinedOptions<Self>) -> VortexResult<Option<Vec<u8>>> {
111 unimplemented!("mean is not yet serializable");
112 }
113}
114
115fn mean_output_dtype(input_dtype: &DType) -> Option<DType> {
116 match input_dtype {
117 DType::Bool(_) | DType::Primitive(..) => {
118 Some(DType::Primitive(PType::F64, Nullability::Nullable))
119 }
120 DType::Decimal(..) => {
121 unimplemented!("mean for decimals is not yet implemented");
122 }
123 _ => None,
124 }
125}
126
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use super::*;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::validity::Validity;

    /// Runs [`mean`] over `array` in a fresh execution context and reads the result as `f64`.
    fn mean_f64(array: &ArrayRef) -> VortexResult<Option<f64>> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let scalar = mean(array, &mut ctx)?;
        Ok(scalar.as_primitive().as_::<f64>())
    }

    #[test]
    fn mean_all_valid() -> VortexResult<()> {
        let values = buffer![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let array = PrimitiveArray::new(values, Validity::NonNullable).into_array();
        assert_eq!(mean_f64(&array)?, Some(3.0));
        Ok(())
    }

    #[test]
    fn mean_with_nulls() -> VortexResult<()> {
        let array = PrimitiveArray::from_option_iter([Some(2.0f64), None, Some(4.0)]).into_array();
        assert_eq!(mean_f64(&array)?, Some(3.0));
        Ok(())
    }

    #[test]
    fn mean_integers() -> VortexResult<()> {
        let values = buffer![10i32, 20, 30];
        let array = PrimitiveArray::new(values, Validity::NonNullable).into_array();
        assert_eq!(mean_f64(&array)?, Some(20.0));
        Ok(())
    }

    #[test]
    fn mean_bool() -> VortexResult<()> {
        let bools: BoolArray = [true, false, true, true].into_iter().collect();
        assert_eq!(mean_f64(&bools.into_array())?, Some(0.75));
        Ok(())
    }

    #[test]
    fn mean_constant_non_null() -> VortexResult<()> {
        let constant = ConstantArray::new(5.0f64, 4).into_array();
        assert_eq!(mean_f64(&constant)?, Some(5.0));
        Ok(())
    }

    #[test]
    fn mean_chunked() -> VortexResult<()> {
        let first = PrimitiveArray::from_option_iter([Some(1.0f64), None, Some(3.0)]);
        let second = PrimitiveArray::from_option_iter([Some(5.0f64), None]);
        let dtype = first.dtype().clone();
        let chunks = vec![first.into_array(), second.into_array()];
        let chunked = ChunkedArray::try_new(chunks, dtype)?.into_array();
        assert_eq!(mean_f64(&chunked)?, Some(3.0));
        Ok(())
    }

    #[test]
    fn mean_all_null_returns_null() -> VortexResult<()> {
        let array = PrimitiveArray::from_option_iter::<f64, _>([None, None, None]).into_array();
        assert_eq!(mean_f64(&array)?, None);
        Ok(())
    }

    #[test]
    fn mean_multi_batch() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let mut acc = Accumulator::try_new(
            Mean::combined(),
            PairOptions(EmptyOptions, EmptyOptions),
            DType::Primitive(PType::F64, Nullability::NonNullable),
        )?;

        // Feed the values in two separate batches; the running state must span both.
        let batches = [
            PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0], Validity::NonNullable).into_array(),
            PrimitiveArray::new(buffer![4.0f64, 5.0], Validity::NonNullable).into_array(),
        ];
        for batch in &batches {
            acc.accumulate(batch, &mut ctx)?;
        }

        let result = acc.finish()?;
        assert_eq!(result.as_primitive().as_::<f64>(), Some(3.0));
        Ok(())
    }
}