Skip to main content

vortex_array/aggregate_fn/fns/max/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexExpect;
5use vortex_error::VortexResult;
6
7use crate::ArrayRef;
8use crate::Columnar;
9use crate::ExecutionCtx;
10use crate::IntoArray;
11use crate::aggregate_fn::AggregateFnId;
12use crate::aggregate_fn::AggregateFnRef;
13use crate::aggregate_fn::AggregateFnSatisfaction;
14use crate::aggregate_fn::AggregateFnVTable;
15use crate::aggregate_fn::EmptyOptions;
16use crate::aggregate_fn::fns::bounded_max::BoundedMax;
17use crate::aggregate_fn::fns::min_max::MinMax;
18use crate::aggregate_fn::fns::min_max::min_max;
19use crate::dtype::DType;
20use crate::partial_ord::partial_max;
21use crate::scalar::Scalar;
22
/// Aggregate function that yields the largest non-null value of its input.
///
/// Registered under the id `vortex.max`; the result is always reported with a nullable dtype.
#[derive(Debug, Clone)]
pub struct Max;
26
27/// Partial accumulator state for the maximum aggregate.
28pub struct MaxPartial {
29    max: Option<Scalar>,
30    element_dtype: DType,
31}
32
33impl MaxPartial {
34    fn merge(&mut self, max: Scalar) {
35        if max.is_null() {
36            return;
37        }
38
39        self.max = Some(match self.max.take() {
40            Some(current) => partial_max(max, current).vortex_expect("incomparable max scalars"),
41            None => max,
42        });
43    }
44}
45
46impl AggregateFnVTable for Max {
47    type Options = EmptyOptions;
48    type Partial = MaxPartial;
49
50    fn id(&self) -> AggregateFnId {
51        AggregateFnId::new("vortex.max")
52    }
53
54    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
55        Ok(None)
56    }
57
58    fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
59        MinMax
60            .return_dtype(&EmptyOptions, input_dtype)
61            .map(|_| input_dtype.as_nullable())
62    }
63
64    fn can_satisfy(
65        &self,
66        _options: &Self::Options,
67        requested: &AggregateFnRef,
68    ) -> AggregateFnSatisfaction {
69        if requested.is::<Self>() {
70            AggregateFnSatisfaction::Exact
71        } else if requested.is::<BoundedMax>() {
72            AggregateFnSatisfaction::Approximate
73        } else {
74            AggregateFnSatisfaction::No
75        }
76    }
77
78    fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
79        self.return_dtype(options, input_dtype)
80    }
81
82    fn empty_partial(
83        &self,
84        _options: &Self::Options,
85        input_dtype: &DType,
86    ) -> VortexResult<Self::Partial> {
87        Ok(MaxPartial {
88            max: None,
89            element_dtype: input_dtype.clone(),
90        })
91    }
92
93    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
94        partial.merge(other);
95        Ok(())
96    }
97
98    fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
99        let dtype = partial.element_dtype.as_nullable();
100        match &partial.max {
101            Some(max) => max.cast(&dtype),
102            None => Ok(Scalar::null(dtype)),
103        }
104    }
105
106    fn reset(&self, partial: &mut Self::Partial) {
107        partial.max = None;
108    }
109
110    fn is_saturated(&self, _partial: &Self::Partial) -> bool {
111        false
112    }
113
114    fn accumulate(
115        &self,
116        partial: &mut Self::Partial,
117        batch: &Columnar,
118        ctx: &mut ExecutionCtx,
119    ) -> VortexResult<()> {
120        // Delegate to the existing min_max implementation for now. A dedicated max aggregate
121        // would avoid computing min when only max is needed.
122        let array = match batch {
123            Columnar::Canonical(canonical) => canonical.clone().into_array(),
124            Columnar::Constant(constant) => constant.clone().into_array(),
125        };
126        if let Some(result) = min_max(&array, ctx)? {
127            partial.merge(result.max);
128        }
129        Ok(())
130    }
131
132    fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
133        Ok(partials)
134    }
135
136    fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
137        self.to_scalar(partial)
138    }
139}
140
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::IntoArray as _;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::fns::max::Max;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::stats::Precision;
    use crate::expr::stats::Stat;
    use crate::scalar::Scalar;
    use crate::scalar::ScalarValue;
    use crate::validity::Validity;

    /// The maximum is tracked across several accumulated batches and reported as nullable.
    #[test]
    fn max_aggregate_fn() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Max, EmptyOptions, input_dtype)?;

        let first = PrimitiveArray::new(buffer![10i32, 20, 5], Validity::NonNullable).into_array();
        let second = PrimitiveArray::new(buffer![3i32, 25], Validity::NonNullable).into_array();
        for batch in [first, second] {
            acc.accumulate(&batch, &mut ctx)?;
        }

        let expected = Scalar::primitive(25i32, Nullability::Nullable);
        assert_eq!(acc.finish()?, expected);
        Ok(())
    }

    /// Finishing without accumulating anything yields a null of the nullable input dtype.
    #[test]
    fn max_empty_group_returns_null() -> VortexResult<()> {
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Max, EmptyOptions, input_dtype)?;

        let expected = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        assert_eq!(acc.finish()?, expected);
        Ok(())
    }

    /// A max statistic set on a non-nullable array is reported as a nullable scalar.
    #[test]
    fn max_casts_nonnullable_legacy_stat_to_nullable_partial() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let array = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
        // The stat (25) deliberately exceeds the array's contents; the expected result is 25.
        let stat_max = Precision::Exact(ScalarValue::from(25i32));
        array.statistics().set(Stat::Max, stat_max);
        let mut acc = Accumulator::try_new(Max, EmptyOptions, array.dtype().clone())?;

        acc.accumulate(&array, &mut ctx)?;

        let expected = Scalar::primitive(25i32, Nullability::Nullable);
        assert_eq!(acc.finish()?, expected);
        Ok(())
    }
}