Skip to main content

vortex_array/aggregate_fn/fns/max/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexExpect;
5use vortex_error::VortexResult;
6use vortex_session::VortexSession;
7use vortex_session::registry::CachedId;
8
9use crate::ArrayRef;
10use crate::Columnar;
11use crate::ExecutionCtx;
12use crate::IntoArray;
13use crate::aggregate_fn::AggregateFnId;
14use crate::aggregate_fn::AggregateFnRef;
15use crate::aggregate_fn::AggregateFnSatisfaction;
16use crate::aggregate_fn::AggregateFnVTable;
17use crate::aggregate_fn::NumericalAggregateOpts;
18use crate::aggregate_fn::fns::bounded_max::BoundedMax;
19use crate::aggregate_fn::fns::min_max::MinMax;
20use crate::aggregate_fn::fns::min_max::min_max;
21use crate::aggregate_fn::fns::min_max::nan_scalar;
22use crate::aggregate_fn::fns::min_max::scalar_is_nan;
23use crate::dtype::DType;
24use crate::expr::stats::Precision;
25use crate::expr::stats::Stat;
26use crate::expr::stats::StatsProvider;
27use crate::expr::stats::StatsProviderExt;
28use crate::partial_ord::partial_max;
29use crate::scalar::Scalar;
30
/// Aggregate function that yields the largest non-null value of an array.
///
/// For float inputs, [`NumericalAggregateOpts`] decides how NaN is treated: when `skip_nans` is
/// set (the default) NaN values do not take part in the result; otherwise a single NaN value
/// poisons the maximum, which then becomes NaN.
#[derive(Debug, Clone)]
pub struct Max;
37
38/// Partial accumulator state for the maximum aggregate.
39pub struct MaxPartial {
40    max: Option<Scalar>,
41    element_dtype: DType,
42    skip_nans: bool,
43}
44
45impl MaxPartial {
46    fn merge(&mut self, max: Scalar) {
47        if max.is_null() {
48            return;
49        }
50
51        // NaN scalars are incomparable under `partial_max`; they poison the maximum when NaNs
52        // participate, and are dropped when they are skipped.
53        if scalar_is_nan(&max) || self.is_poisoned() {
54            if !self.skip_nans {
55                self.poison();
56            }
57            return;
58        }
59
60        self.max = Some(match self.max.take() {
61            Some(current) => partial_max(max, current).vortex_expect("incomparable max scalars"),
62            None => max,
63        });
64    }
65
66    fn poison(&mut self) {
67        self.max = Some(nan_scalar(&self.element_dtype));
68    }
69
70    fn is_poisoned(&self) -> bool {
71        self.element_dtype.is_float() && self.max.as_ref().is_some_and(scalar_is_nan)
72    }
73}
74
75impl AggregateFnVTable for Max {
76    type Options = NumericalAggregateOpts;
77    type Partial = MaxPartial;
78
79    fn id(&self) -> AggregateFnId {
80        static ID: CachedId = CachedId::new("vortex.max");
81        *ID
82    }
83
84    fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
85        Ok(Some(options.serialize()))
86    }
87
88    fn deserialize(
89        &self,
90        metadata: &[u8],
91        _session: &VortexSession,
92    ) -> VortexResult<Self::Options> {
93        NumericalAggregateOpts::deserialize(metadata)
94    }
95
96    fn return_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
97        MinMax
98            .return_dtype(options, input_dtype)
99            .map(|_| input_dtype.as_nullable())
100    }
101
102    fn can_satisfy(
103        &self,
104        options: &Self::Options,
105        requested: &AggregateFnRef,
106    ) -> AggregateFnSatisfaction {
107        if requested
108            .as_opt::<Self>()
109            .is_some_and(|other| other == options)
110        {
111            AggregateFnSatisfaction::Exact
112        } else if requested.is::<BoundedMax>() && options.skip_nans {
113            // A NaN-including maximum may be NaN, which is not a usable upper bound.
114            AggregateFnSatisfaction::Approximate
115        } else {
116            AggregateFnSatisfaction::No
117        }
118    }
119
120    fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
121        self.return_dtype(options, input_dtype)
122    }
123
124    fn empty_partial(
125        &self,
126        options: &Self::Options,
127        input_dtype: &DType,
128    ) -> VortexResult<Self::Partial> {
129        Ok(MaxPartial {
130            max: None,
131            element_dtype: input_dtype.clone(),
132            skip_nans: options.skip_nans,
133        })
134    }
135
136    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
137        partial.merge(other);
138        Ok(())
139    }
140
141    fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
142        let dtype = partial.element_dtype.as_nullable();
143        match &partial.max {
144            Some(max) => max.cast(&dtype),
145            None => Ok(Scalar::null(dtype)),
146        }
147    }
148
149    fn reset(&self, partial: &mut Self::Partial) {
150        partial.max = None;
151    }
152
153    fn is_saturated(&self, partial: &Self::Partial) -> bool {
154        // A poisoned NaN-including maximum is fully determined.
155        partial.is_poisoned()
156    }
157
158    fn try_accumulate(
159        &self,
160        partial: &mut Self::Partial,
161        batch: &ArrayRef,
162        _ctx: &mut ExecutionCtx,
163    ) -> VortexResult<bool> {
164        // NaN-aware shortcircuits only apply to the NaN-including float maximum; everything else
165        // takes the default dispatch path.
166        if partial.skip_nans || !partial.element_dtype.is_float() {
167            return Ok(false);
168        }
169        match batch.statistics().get_as::<u64>(Stat::NaNCount) {
170            Precision::Exact(0) => {
171                // NaN-free batch: the cached NaN-skipping maximum (if any) is valid. `to_scalar`
172                // re-casts to the result dtype, so the cached scalar can merge as-is.
173                if let Some(max) = batch.statistics().get(Stat::Max).as_exact() {
174                    partial.merge(max);
175                    return Ok(true);
176                }
177                Ok(false)
178            }
179            Precision::Exact(_) => {
180                partial.poison();
181                Ok(true)
182            }
183            _ => Ok(false),
184        }
185    }
186
187    fn accumulate(
188        &self,
189        partial: &mut Self::Partial,
190        batch: &Columnar,
191        ctx: &mut ExecutionCtx,
192    ) -> VortexResult<()> {
193        // Delegate to the existing min_max implementation for now. A dedicated max aggregate
194        // would avoid computing min when only max is needed.
195        let array = match batch {
196            Columnar::Canonical(canonical) => canonical.clone().into_array(),
197            Columnar::Constant(constant) => constant.clone().into_array(),
198        };
199        let options = NumericalAggregateOpts {
200            skip_nans: partial.skip_nans,
201        };
202        if let Some(result) = min_max(&array, ctx, options)? {
203            partial.merge(result.max);
204        }
205        Ok(())
206    }
207
208    fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
209        Ok(partials)
210    }
211
212    fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
213        self.to_scalar(partial)
214    }
215}
216
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::IntoArray as _;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::NumericalAggregateOpts;
    use crate::aggregate_fn::fns::max::Max;
    use crate::array_session;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::stats::Precision;
    use crate::expr::stats::Stat;
    use crate::scalar::Scalar;
    use crate::scalar::ScalarValue;
    use crate::validity::Validity;

    /// Builds the non-nullable primitive dtype for `ptype`.
    fn non_nullable(ptype: PType) -> DType {
        DType::Primitive(ptype, Nullability::NonNullable)
    }

    /// Asserts that `scalar` holds an `f64` NaN value.
    fn assert_f64_nan(scalar: &Scalar) {
        let value = scalar.as_primitive().typed_value::<f64>();
        assert!(value.is_some_and(f64::is_nan));
    }

    /// The maximum is tracked across several accumulated batches.
    #[test]
    fn max_aggregate_fn() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let mut acc = Accumulator::try_new(
            Max,
            NumericalAggregateOpts::default(),
            non_nullable(PType::I32),
        )?;

        let batches = [
            PrimitiveArray::new(buffer![10i32, 20, 5], Validity::NonNullable).into_array(),
            PrimitiveArray::new(buffer![3i32, 25], Validity::NonNullable).into_array(),
        ];
        for batch in &batches {
            acc.accumulate(batch, &mut ctx)?;
        }

        let expected = Scalar::primitive(25i32, Nullability::Nullable);
        assert_eq!(acc.finish()?, expected);
        Ok(())
    }

    /// Finishing without accumulating anything yields a null of the nullable input dtype.
    #[test]
    fn max_empty_group_returns_null() -> VortexResult<()> {
        let mut acc = Accumulator::try_new(
            Max,
            NumericalAggregateOpts::default(),
            non_nullable(PType::I32),
        )?;

        let expected = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        assert_eq!(acc.finish()?, expected);
        Ok(())
    }

    /// With NaNs included, a NaN in the batch saturates the accumulator and the result is NaN.
    #[test]
    fn max_with_nan_not_skipping() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let mut acc = Accumulator::try_new(
            Max,
            NumericalAggregateOpts::include_nans(),
            non_nullable(PType::F64),
        )?;

        let values = buffer![1.0f64, f64::NAN, -5.0];
        let batch = PrimitiveArray::new(values, Validity::NonNullable).into_array();
        acc.accumulate(&batch, &mut ctx)?;
        assert!(acc.is_saturated());

        assert_f64_nan(&acc.finish()?);
        Ok(())
    }

    /// An exact, non-zero NaNCount statistic poisons the maximum.
    #[test]
    fn max_not_skipping_shortcircuits_on_exact_nan_count_stat() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        // The data itself is NaN-free; the planted exact NaNCount stat proves the poisoning
        // came from the statistic and not from scanning the values.
        let batch = PrimitiveArray::new(buffer![1.0f64, 2.0], Validity::NonNullable).into_array();
        let planted_nan_count = Precision::Exact(ScalarValue::from(1u64));
        batch.statistics().set(Stat::NaNCount, planted_nan_count);

        let mut acc = Accumulator::try_new(
            Max,
            NumericalAggregateOpts::include_nans(),
            batch.dtype().clone(),
        )?;
        acc.accumulate(&batch, &mut ctx)?;

        assert_f64_nan(&acc.finish()?);
        Ok(())
    }

    /// A cached Max statistic of a nullable float array is usable by the NaN-including path.
    #[test]
    fn max_nan_including_nullable_cached_stat() -> VortexResult<()> {
        // The cached Max stat of a nullable float array is reconstructed as a nullable scalar.
        // The NaN-including shortcircuit merges it as-is and `to_scalar` re-casts it to the
        // result dtype.
        let mut ctx = array_session().create_execution_ctx();
        let array =
            PrimitiveArray::from_option_iter([Some(1.0f64), Some(2.0), Some(3.0)]).into_array();
        let planted_nan_count = Precision::Exact(ScalarValue::from(0u64));
        array.statistics().set(Stat::NaNCount, planted_nan_count);
        let planted_max = Precision::Exact(ScalarValue::from(3.0f64));
        array.statistics().set(Stat::Max, planted_max);

        let mut acc = Accumulator::try_new(
            Max,
            NumericalAggregateOpts::include_nans(),
            array.dtype().clone(),
        )?;
        acc.accumulate(&array, &mut ctx)?;

        let expected = Scalar::primitive(3.0f64, Nullability::Nullable);
        assert_eq!(acc.finish()?, expected);
        Ok(())
    }

    /// A planted exact Max statistic on a non-nullable array surfaces as a nullable result.
    #[test]
    fn max_casts_nonnullable_legacy_stat_to_nullable_partial() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let batch = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
        // The planted value (25) differs from the data, so the expected result below shows that
        // the statistic was used.
        let planted_max = Precision::Exact(ScalarValue::from(25i32));
        batch.statistics().set(Stat::Max, planted_max);

        let mut acc = Accumulator::try_new(
            Max,
            NumericalAggregateOpts::default(),
            batch.dtype().clone(),
        )?;
        acc.accumulate(&batch, &mut ctx)?;

        let expected = Scalar::primitive(25i32, Nullability::Nullable);
        assert_eq!(acc.finish()?, expected);
        Ok(())
    }
}
363}