vortex_array/arrays/decimal/compute/min_max.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools;
5use vortex_dtype::Nullability::NonNullable;
6use vortex_dtype::{DecimalDType, NativeDecimalType, match_each_decimal_value_type};
7use vortex_error::VortexResult;
8use vortex_mask::Mask;
9use vortex_scalar::{DecimalValue, Scalar};
10
11use crate::arrays::{DecimalArray, DecimalVTable};
12use crate::compute::{MinMaxKernel, MinMaxKernelAdapter, MinMaxResult};
13use crate::register_kernel;
14
15impl MinMaxKernel for DecimalVTable {
16    fn min_max(&self, array: &DecimalArray) -> VortexResult<Option<MinMaxResult>> {
17        match_each_decimal_value_type!(array.values_type(), |T| {
18            compute_min_max_with_validity::<T>(array)
19        })
20    }
21}
22
23register_kernel!(MinMaxKernelAdapter(DecimalVTable).lift());
24
25#[inline]
26fn compute_min_max_with_validity<D>(array: &DecimalArray) -> VortexResult<Option<MinMaxResult>>
27where
28    D: Into<DecimalValue> + NativeDecimalType,
29{
30    Ok(match array.validity_mask() {
31        Mask::AllTrue(_) => compute_min_max(array.buffer::<D>().iter(), array.decimal_dtype()),
32        Mask::AllFalse(_) => None,
33        Mask::Values(v) => compute_min_max(
34            array
35                .buffer::<D>()
36                .iter()
37                .zip(v.bit_buffer().iter())
38                .filter_map(|(v, m)| m.then_some(v)),
39            array.decimal_dtype(),
40        ),
41    })
42}
43
44fn compute_min_max<'a, T>(
45    iter: impl Iterator<Item = &'a T>,
46    decimal_dtype: DecimalDType,
47) -> Option<MinMaxResult>
48where
49    T: Into<DecimalValue> + NativeDecimalType + Ord + Copy + 'a,
50{
51    match iter.minmax_by(|a, b| a.cmp(b)) {
52        itertools::MinMaxResult::NoElements => None,
53        itertools::MinMaxResult::OneElement(&x) => {
54            let scalar = Scalar::decimal(x.into(), decimal_dtype, NonNullable);
55            Some(MinMaxResult {
56                min: scalar.clone(),
57                max: scalar,
58            })
59        }
60        itertools::MinMaxResult::MinMax(&min, &max) => Some(MinMaxResult {
61            min: Scalar::decimal(min.into(), decimal_dtype, NonNullable),
62            max: Scalar::decimal(max.into(), decimal_dtype, NonNullable),
63        }),
64    }
65}
66
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DecimalDType;
    use vortex_scalar::{DecimalValue, Scalar, ScalarValue};

    use crate::arrays::DecimalArray;
    use crate::compute::{MinMaxResult, min_max};
    use crate::validity::Validity;

    /// Builds the expected min/max result for `array`: both scalars use the array's
    /// dtype made non-nullable and wrap the given `i32` decimal values.
    fn expected_i32(array: &DecimalArray, min: i32, max: i32) -> MinMaxResult {
        let non_nullable_dtype = array.dtype().as_nonnullable();
        MinMaxResult {
            min: Scalar::new(
                non_nullable_dtype.clone(),
                ScalarValue::from(DecimalValue::from(min)),
            ),
            max: Scalar::new(
                non_nullable_dtype,
                ScalarValue::from(DecimalValue::from(max)),
            ),
        }
    }

    #[test]
    fn min_max_test() {
        // The null 2000 is the largest stored value and must be ignored.
        let decimal = DecimalArray::new(
            buffer![100i32, 2000i32, 200i32],
            DecimalDType::new(4, 2),
            Validity::from_iter([true, false, true]),
        );

        let min_max = min_max(decimal.as_ref()).unwrap();

        assert_eq!(Some(expected_i32(&decimal, 100, 200)), min_max)
    }

    #[test]
    fn min_max_all_valid() {
        let decimal = DecimalArray::new(
            buffer![300i32, -500i32, 700i32],
            DecimalDType::new(4, 2),
            Validity::AllValid,
        );

        let min_max = min_max(decimal.as_ref()).unwrap();

        assert_eq!(Some(expected_i32(&decimal, -500, 700)), min_max)
    }

    #[test]
    fn min_max_single_value() {
        // A single value is both the minimum and the maximum.
        let decimal = DecimalArray::new(
            buffer![42i32],
            DecimalDType::new(4, 2),
            Validity::NonNullable,
        );

        let min_max = min_max(decimal.as_ref()).unwrap();

        assert_eq!(Some(expected_i32(&decimal, 42, 42)), min_max)
    }

    #[test]
    fn min_max_all_null() {
        // With no valid values there is no min/max.
        let decimal = DecimalArray::new(
            buffer![100i32, 200i32],
            DecimalDType::new(4, 2),
            Validity::AllInvalid,
        );

        let min_max = min_max(decimal.as_ref()).unwrap();

        assert_eq!(None, min_max)
    }
}