Skip to main content

vortex_array/aggregate_fn/fns/bounded_max/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::num::NonZeroUsize;
7use std::sync::LazyLock;
8
9use vortex_buffer::BufferString;
10use vortex_buffer::ByteBuffer;
11use vortex_error::VortexExpect;
12use vortex_error::VortexResult;
13use vortex_error::vortex_bail;
14use vortex_error::vortex_ensure;
15use vortex_session::VortexSession;
16use vortex_session::registry::CachedId;
17
18use crate::ArrayRef;
19use crate::Columnar;
20use crate::ExecutionCtx;
21use crate::IntoArray;
22use crate::aggregate_fn::AggregateFnId;
23use crate::aggregate_fn::AggregateFnRef;
24use crate::aggregate_fn::AggregateFnSatisfaction;
25use crate::aggregate_fn::AggregateFnVTable;
26use crate::aggregate_fn::NumericalAggregateOpts;
27use crate::aggregate_fn::fns::max::Max;
28use crate::aggregate_fn::fns::min_max::MinMax;
29use crate::aggregate_fn::fns::min_max::min_max;
30use crate::builtins::ArrayBuiltins;
31use crate::dtype::DType;
32use crate::dtype::FieldNames;
33use crate::dtype::Nullability;
34use crate::dtype::StructFields;
35use crate::partial_ord::partial_max;
36use crate::scalar::Scalar;
37use crate::scalar::ScalarTruncation;
38use crate::scalar::upper_bound;
39
40/// Field name for the bounded maximum upper-bound value in the partial state.
41pub const BOUNDED_MAX_BOUND: &str = "bound";
42/// Field name for whether the partial state represents an unknown upper bound.
43pub const BOUNDED_MAX_UNKNOWN: &str = "unknown";
44
45static NAMES: LazyLock<FieldNames> =
46    LazyLock::new(|| FieldNames::from([BOUNDED_MAX_BOUND, BOUNDED_MAX_UNKNOWN]));
47
/// Configuration for the [`BoundedMax`] aggregate.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct BoundedMaxOptions {
    /// Upper limit, in bytes, on the length of a UTF8/Binary bound.
    pub max_bytes: NonZeroUsize,
}

impl Display for BoundedMaxOptions {
    /// Renders the options as the bare byte limit, e.g. `64`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let limit = self.max_bytes;
        write!(f, "{}", limit)
    }
}
60
/// Aggregate that yields a byte-bounded upper bound for the maximum non-null value.
///
/// UTF8/Binary maxima are reduced to an upper bound of at most `max_bytes` bytes. Other
/// supported dtypes keep their exact maximum.
#[derive(Clone, Debug)]
pub struct BoundedMax;
64
65enum BoundedMaxState {
66    Empty,
67    Value(Scalar),
68    Unknown,
69}
70
71/// Partial accumulator state for the bounded maximum aggregate.
72pub struct BoundedMaxPartial {
73    state: BoundedMaxState,
74    element_dtype: DType,
75    max_bytes: NonZeroUsize,
76}
77
78impl BoundedMaxPartial {
79    fn merge_bound(&mut self, max: Scalar) {
80        if max.is_null() {
81            return;
82        }
83
84        self.state = match std::mem::replace(&mut self.state, BoundedMaxState::Empty) {
85            BoundedMaxState::Empty => BoundedMaxState::Value(max),
86            BoundedMaxState::Value(current) => BoundedMaxState::Value(
87                partial_max(max, current).vortex_expect("incomparable bounded max scalars"),
88            ),
89            BoundedMaxState::Unknown => BoundedMaxState::Unknown,
90        };
91    }
92
93    fn unknown(&mut self) {
94        self.state = BoundedMaxState::Unknown;
95    }
96
97    fn final_scalar(&self) -> VortexResult<Scalar> {
98        let dtype = self.element_dtype.as_nullable();
99        match &self.state {
100            BoundedMaxState::Value(max) => max.cast(&dtype),
101            BoundedMaxState::Empty | BoundedMaxState::Unknown => Ok(Scalar::null(dtype)),
102        }
103    }
104}
105
106/// Return the serialized partial-state dtype for [`BoundedMax`].
107///
108/// A null struct means the partial is empty. A non-null struct with a null `bound` and
109/// `unknown = true` means the input has a non-null maximum but no finite upper bound could be
110/// represented within the configured byte limit.
111pub fn make_bounded_max_partial_dtype(element_dtype: &DType) -> DType {
112    DType::Struct(
113        StructFields::new(
114            NAMES.clone(),
115            vec![
116                element_dtype.as_nullable(),
117                DType::Bool(Nullability::NonNullable),
118            ],
119        ),
120        Nullability::Nullable,
121    )
122}
123
124impl AggregateFnVTable for BoundedMax {
125    type Options = BoundedMaxOptions;
126    type Partial = BoundedMaxPartial;
127
128    fn id(&self) -> AggregateFnId {
129        static ID: CachedId = CachedId::new("vortex.bounded_max");
130        *ID
131    }
132
133    fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
134        let max_bytes = u64::try_from(options.max_bytes.get())?;
135        Ok(Some(max_bytes.to_le_bytes().to_vec()))
136    }
137
138    fn deserialize(
139        &self,
140        metadata: &[u8],
141        _session: &VortexSession,
142    ) -> VortexResult<Self::Options> {
143        vortex_ensure!(
144            metadata.len() == size_of::<u64>(),
145            "BoundedMax options expected {} bytes, got {}",
146            size_of::<u64>(),
147            metadata.len()
148        );
149        let mut bytes = [0u8; size_of::<u64>()];
150        bytes.copy_from_slice(metadata);
151        let max_bytes = usize::try_from(u64::from_le_bytes(bytes))?;
152        vortex_ensure!(max_bytes > 0, "BoundedMax requires max_bytes > 0");
153        Ok(BoundedMaxOptions {
154            max_bytes: NonZeroUsize::new(max_bytes).vortex_expect("checked non-zero max_bytes"),
155        })
156    }
157
158    fn return_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
159        supported_dtype(options, input_dtype).map(DType::as_nullable)
160    }
161
162    fn can_satisfy(
163        &self,
164        options: &Self::Options,
165        requested: &AggregateFnRef,
166    ) -> AggregateFnSatisfaction {
167        if let Some(other) = requested.as_opt::<Self>() {
168            return if other == options {
169                AggregateFnSatisfaction::Exact
170            } else if options.max_bytes >= other.max_bytes {
171                AggregateFnSatisfaction::Approximate
172            } else {
173                AggregateFnSatisfaction::No
174            };
175        }
176
177        // The stored bound skips NaNs, so it cannot stand in for a NaN-including maximum.
178        if requested
179            .as_opt::<Max>()
180            .is_some_and(|options| options.skip_nans)
181        {
182            AggregateFnSatisfaction::Approximate
183        } else {
184            AggregateFnSatisfaction::No
185        }
186    }
187
188    fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
189        supported_dtype(options, input_dtype).map(make_bounded_max_partial_dtype)
190    }
191
192    fn empty_partial(
193        &self,
194        options: &Self::Options,
195        input_dtype: &DType,
196    ) -> VortexResult<Self::Partial> {
197        Ok(BoundedMaxPartial {
198            state: BoundedMaxState::Empty,
199            element_dtype: input_dtype.clone(),
200            max_bytes: options.max_bytes,
201        })
202    }
203
204    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
205        if other.is_null() {
206            return Ok(());
207        }
208
209        let Some(other) = other.as_struct_opt() else {
210            vortex_bail!("BoundedMax partial must be a struct, got {}", other.dtype());
211        };
212        let Some(bound) = other.field_by_idx(0) else {
213            vortex_bail!("BoundedMax partial is missing its bound field");
214        };
215        let Some(unknown) = other
216            .field_by_idx(1)
217            .and_then(|unknown| unknown.as_bool().value())
218        else {
219            vortex_bail!("BoundedMax partial is missing its non-null unknown field");
220        };
221
222        if unknown {
223            partial.unknown();
224        } else {
225            partial.merge_bound(bound);
226        }
227        Ok(())
228    }
229
230    fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
231        let dtype = make_bounded_max_partial_dtype(&partial.element_dtype);
232        let bound_dtype = partial.element_dtype.as_nullable();
233        match &partial.state {
234            BoundedMaxState::Empty => Ok(Scalar::null(dtype)),
235            BoundedMaxState::Value(max) => Ok(Scalar::struct_(
236                dtype,
237                vec![
238                    max.cast(&bound_dtype)?,
239                    Scalar::bool(false, Nullability::NonNullable),
240                ],
241            )),
242            BoundedMaxState::Unknown => Ok(Scalar::struct_(
243                dtype,
244                vec![
245                    Scalar::null(bound_dtype),
246                    Scalar::bool(true, Nullability::NonNullable),
247                ],
248            )),
249        }
250    }
251
252    fn reset(&self, partial: &mut Self::Partial) {
253        partial.state = BoundedMaxState::Empty;
254    }
255
256    fn is_saturated(&self, partial: &Self::Partial) -> bool {
257        matches!(partial.state, BoundedMaxState::Unknown)
258    }
259
260    fn accumulate(
261        &self,
262        partial: &mut Self::Partial,
263        batch: &Columnar,
264        ctx: &mut ExecutionCtx,
265    ) -> VortexResult<()> {
266        // Delegate to the existing min_max implementation for now. A dedicated bounded-max
267        // aggregate would avoid computing min when only max is needed.
268        let array = match batch {
269            Columnar::Canonical(canonical) => canonical.clone().into_array(),
270            Columnar::Constant(constant) => constant.clone().into_array(),
271        };
272        let Some(result) = min_max(&array, ctx, NumericalAggregateOpts::default())? else {
273            return Ok(());
274        };
275        match truncate_max(result.max, partial.max_bytes.get())? {
276            Some(bound) => partial.merge_bound(bound),
277            None => partial.unknown(),
278        }
279        Ok(())
280    }
281
282    fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
283        partials.get_item(BOUNDED_MAX_BOUND)
284    }
285
286    fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
287        partial.final_scalar()
288    }
289}
290
291fn supported_dtype<'a>(_options: &BoundedMaxOptions, input_dtype: &'a DType) -> Option<&'a DType> {
292    MinMax
293        .return_dtype(&NumericalAggregateOpts::default(), input_dtype)
294        .map(|_| input_dtype)
295}
296
297fn truncate_max(value: Scalar, max_bytes: usize) -> VortexResult<Option<Scalar>> {
298    let nullability = value.dtype().nullability();
299    match value.dtype() {
300        DType::Utf8(_) => {
301            Ok(
302                upper_bound(BufferString::from_scalar(value)?, max_bytes, nullability)
303                    .map(|(bound, _)| bound),
304            )
305        }
306        DType::Binary(_) => {
307            Ok(
308                upper_bound(ByteBuffer::from_scalar(value)?, max_bytes, nullability)
309                    .map(|(bound, _)| bound),
310            )
311        }
312        _ => Ok(Some(value)),
313    }
314}
315
#[cfg(test)]
mod tests {
    use std::num::NonZeroUsize;

    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;

    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::AggregateFnSatisfaction;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::AggregateFnVTableExt;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::NumericalAggregateOpts;
    use crate::aggregate_fn::fns::bounded_max::BoundedMax;
    use crate::aggregate_fn::fns::bounded_max::BoundedMaxOptions;
    use crate::aggregate_fn::fns::bounded_max::make_bounded_max_partial_dtype;
    use crate::aggregate_fn::fns::max::Max;
    use crate::aggregate_fn::fns::min::Min;
    use crate::array_session;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinViewArray;
    use crate::dtype::Nullability;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Builds a non-zero byte limit for test options.
    fn max_bytes(value: usize) -> NonZeroUsize {
        NonZeroUsize::new(value).vortex_expect("non-zero max_bytes")
    }

    #[test]
    fn bounded_max_truncates_utf8_to_upper_bound() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        // "char🪩" is 8 bytes. The 4-byte emoji cannot be split, so the 5-byte bound keeps
        // "char" and bumps its last character, giving "chas".
        let array = VarBinViewArray::from_iter_str(["aardvark", "char🪩"]).into_array();
        let mut acc = Accumulator::try_new(
            BoundedMax,
            BoundedMaxOptions {
                max_bytes: max_bytes(5),
            },
            array.dtype().clone(),
        )?;

        acc.accumulate(&array, &mut ctx)?;

        assert_eq!(acc.finish()?, Scalar::utf8("chas", Nullability::Nullable));
        Ok(())
    }

    #[test]
    fn bounded_max_unknown_upper_bound_returns_null() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let array = VarBinViewArray::from_iter_bin([&[255u8, 255, 255][..]]).into_array();
        let mut acc = Accumulator::try_new(
            BoundedMax,
            BoundedMaxOptions {
                max_bytes: max_bytes(2),
            },
            array.dtype().clone(),
        )?;

        acc.accumulate(&array, &mut ctx)?;

        assert_eq!(acc.finish()?, Scalar::null(array.dtype().as_nullable()));
        Ok(())
    }

    #[test]
    fn bounded_max_empty_does_not_poison_later_values() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let empty = VarBinViewArray::from_iter_bin(Vec::<&[u8]>::new()).into_array();
        let values = VarBinViewArray::from_iter_bin([&[1u8][..]]).into_array();
        let mut acc = Accumulator::try_new(
            BoundedMax,
            BoundedMaxOptions {
                max_bytes: max_bytes(2),
            },
            empty.dtype().clone(),
        )?;

        acc.accumulate(&empty, &mut ctx)?;
        acc.accumulate(&values, &mut ctx)?;

        assert_eq!(
            acc.finish()?,
            Scalar::binary(buffer![1u8], Nullability::Nullable)
        );
        Ok(())
    }

    #[test]
    fn bounded_max_unknown_poisons_later_values() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let unknown = VarBinViewArray::from_iter_bin([&[255u8, 255, 255][..]]).into_array();
        let values = VarBinViewArray::from_iter_bin([&[1u8][..]]).into_array();
        let mut acc = Accumulator::try_new(
            BoundedMax,
            BoundedMaxOptions {
                max_bytes: max_bytes(2),
            },
            unknown.dtype().clone(),
        )?;

        acc.accumulate(&unknown, &mut ctx)?;
        acc.accumulate(&values, &mut ctx)?;

        assert_eq!(acc.finish()?, Scalar::null(unknown.dtype().as_nullable()));
        Ok(())
    }

    #[test]
    fn bounded_max_empty_partial_does_not_poison_existing_bound() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let values = VarBinViewArray::from_iter_bin([&[1u8][..]]).into_array();
        let mut acc = Accumulator::try_new(
            BoundedMax,
            BoundedMaxOptions {
                max_bytes: max_bytes(2),
            },
            values.dtype().clone(),
        )?;

        acc.accumulate(&values, &mut ctx)?;
        acc.combine_partials(Scalar::null(make_bounded_max_partial_dtype(values.dtype())))?;

        assert_eq!(
            acc.finish()?,
            Scalar::binary(buffer![1u8], Nullability::Nullable)
        );
        Ok(())
    }

    #[test]
    fn bounded_max_unknown_partial_poisons_existing_bound() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let values = VarBinViewArray::from_iter_bin([&[1u8][..]]).into_array();
        let mut acc = Accumulator::try_new(
            BoundedMax,
            BoundedMaxOptions {
                max_bytes: max_bytes(2),
            },
            values.dtype().clone(),
        )?;

        let partial_dtype = make_bounded_max_partial_dtype(values.dtype());
        let unknown = Scalar::struct_(
            partial_dtype,
            vec![
                Scalar::null(values.dtype().as_nullable()),
                Scalar::bool(true, Nullability::NonNullable),
            ],
        );

        acc.accumulate(&values, &mut ctx)?;
        acc.combine_partials(unknown)?;

        assert_eq!(acc.finish()?, Scalar::null(values.dtype().as_nullable()));
        Ok(())
    }

    #[test]
    fn bounded_max_keeps_fixed_width_values_exact() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let array = PrimitiveArray::new(buffer![10i32, 20, 5], Validity::NonNullable).into_array();
        let mut acc = Accumulator::try_new(
            BoundedMax,
            BoundedMaxOptions {
                max_bytes: max_bytes(9),
            },
            array.dtype().clone(),
        )?;

        acc.accumulate(&array, &mut ctx)?;

        assert_eq!(
            acc.finish()?,
            Scalar::primitive(20i32, Nullability::Nullable)
        );
        Ok(())
    }

    #[test]
    fn bounded_max_satisfies_max_bounds() {
        let stored = BoundedMax.bind(BoundedMaxOptions {
            max_bytes: max_bytes(5),
        });
        let same = BoundedMax.bind(BoundedMaxOptions {
            max_bytes: max_bytes(5),
        });
        let looser_bounded = BoundedMax.bind(BoundedMaxOptions {
            max_bytes: max_bytes(4),
        });
        let tighter_bounded = BoundedMax.bind(BoundedMaxOptions {
            max_bytes: max_bytes(6),
        });

        assert_eq!(stored.can_satisfy(&same), AggregateFnSatisfaction::Exact);
        assert_eq!(
            stored.can_satisfy(&looser_bounded),
            AggregateFnSatisfaction::Approximate
        );
        assert_eq!(
            stored.can_satisfy(&tighter_bounded),
            AggregateFnSatisfaction::No
        );
        assert_eq!(
            stored.can_satisfy(&Max.bind(NumericalAggregateOpts::default())),
            AggregateFnSatisfaction::Approximate
        );
        assert_eq!(
            stored.can_satisfy(&Max.bind(NumericalAggregateOpts::include_nans())),
            AggregateFnSatisfaction::No
        );
        assert_eq!(
            Max.bind(NumericalAggregateOpts::include_nans())
                .can_satisfy(&stored),
            AggregateFnSatisfaction::No
        );
        assert_eq!(
            Max.bind(NumericalAggregateOpts::default())
                .can_satisfy(&stored),
            AggregateFnSatisfaction::Approximate
        );
        assert_eq!(
            stored.can_satisfy(&Min.bind(NumericalAggregateOpts::default())),
            AggregateFnSatisfaction::No
        );
    }

    #[test]
    fn bounded_max_options_round_trip() -> VortexResult<()> {
        let options = BoundedMaxOptions {
            max_bytes: max_bytes(64),
        };
        let metadata = BoundedMax
            .serialize(&options)?
            .expect("serializable options");
        let roundtrip = BoundedMax.deserialize(&metadata, &VortexSession::empty())?;

        assert_eq!(roundtrip, options);
        Ok(())
    }

    #[test]
    fn bounded_max_options_reject_zero_max_bytes() {
        let metadata = 0u64.to_le_bytes();
        assert!(
            BoundedMax
                .deserialize(&metadata, &VortexSession::empty())
                .is_err()
        );
    }

    #[test]
    fn bounded_max_options_reject_wrong_metadata_length() {
        // Only 4 bytes; the encoding requires exactly 8.
        let metadata = 5u32.to_le_bytes();
        assert!(
            BoundedMax
                .deserialize(&metadata, &VortexSession::empty())
                .is_err()
        );
    }
}