vortex_array/arrays/decimal/compute/min_max.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools;
5use vortex_dtype::DecimalDType;
6use vortex_dtype::NativeDecimalType;
7use vortex_dtype::Nullability::NonNullable;
8use vortex_dtype::match_each_decimal_value_type;
9use vortex_error::VortexResult;
10use vortex_mask::Mask;
11use vortex_scalar::DecimalValue;
12use vortex_scalar::Scalar;
13
14use crate::arrays::DecimalArray;
15use crate::arrays::DecimalVTable;
16use crate::compute::MinMaxKernel;
17use crate::compute::MinMaxKernelAdapter;
18use crate::compute::MinMaxResult;
19use crate::register_kernel;
20
21impl MinMaxKernel for DecimalVTable {
22    fn min_max(&self, array: &DecimalArray) -> VortexResult<Option<MinMaxResult>> {
23        match_each_decimal_value_type!(array.values_type(), |T| {
24            compute_min_max_with_validity::<T>(array)
25        })
26    }
27}
28
29register_kernel!(MinMaxKernelAdapter(DecimalVTable).lift());
30
31#[inline]
32fn compute_min_max_with_validity<D>(array: &DecimalArray) -> VortexResult<Option<MinMaxResult>>
33where
34    D: Into<DecimalValue> + NativeDecimalType,
35{
36    Ok(match array.validity_mask() {
37        Mask::AllTrue(_) => compute_min_max(array.buffer::<D>().iter(), array.decimal_dtype()),
38        Mask::AllFalse(_) => None,
39        Mask::Values(v) => compute_min_max(
40            array
41                .buffer::<D>()
42                .iter()
43                .zip(v.bit_buffer().iter())
44                .filter_map(|(v, m)| m.then_some(v)),
45            array.decimal_dtype(),
46        ),
47    })
48}
49
50fn compute_min_max<'a, T>(
51    iter: impl Iterator<Item = &'a T>,
52    decimal_dtype: DecimalDType,
53) -> Option<MinMaxResult>
54where
55    T: Into<DecimalValue> + NativeDecimalType + Ord + Copy + 'a,
56{
57    match iter.minmax_by(|a, b| a.cmp(b)) {
58        itertools::MinMaxResult::NoElements => None,
59        itertools::MinMaxResult::OneElement(&x) => {
60            let scalar = Scalar::decimal(x.into(), decimal_dtype, NonNullable);
61            Some(MinMaxResult {
62                min: scalar.clone(),
63                max: scalar,
64            })
65        }
66        itertools::MinMaxResult::MinMax(&min, &max) => Some(MinMaxResult {
67            min: Scalar::decimal(min.into(), decimal_dtype, NonNullable),
68            max: Scalar::decimal(max.into(), decimal_dtype, NonNullable),
69        }),
70    }
71}
72
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DecimalDType;
    use vortex_scalar::DecimalValue;
    use vortex_scalar::Scalar;
    use vortex_scalar::ScalarValue;

    use crate::arrays::DecimalArray;
    use crate::compute::MinMaxResult;
    use crate::compute::min_max;
    use crate::validity::Validity;

    /// Mixed validity: the null slot (holding the largest raw value) must be ignored.
    #[test]
    fn min_max_test() {
        let decimal = DecimalArray::new(
            buffer![100i32, 2000i32, 200i32],
            DecimalDType::new(4, 2),
            Validity::from_iter([true, false, true]),
        );

        let min_max = min_max(decimal.as_ref()).unwrap();

        let non_nullable_dtype = decimal.dtype().as_nonnullable();
        let expected = MinMaxResult {
            min: Scalar::new(
                non_nullable_dtype.clone(),
                ScalarValue::from(DecimalValue::from(100i32)),
            ),
            max: Scalar::new(
                non_nullable_dtype,
                ScalarValue::from(DecimalValue::from(200i32)),
            ),
        };

        assert_eq!(Some(expected), min_max)
    }

    /// No nulls: every value participates, including negatives.
    #[test]
    fn min_max_non_nullable() {
        let decimal = DecimalArray::new(
            buffer![-500i32, 30i32, 7i32],
            DecimalDType::new(4, 2),
            Validity::NonNullable,
        );

        let min_max = min_max(decimal.as_ref()).unwrap();

        let non_nullable_dtype = decimal.dtype().as_nonnullable();
        let expected = MinMaxResult {
            min: Scalar::new(
                non_nullable_dtype.clone(),
                ScalarValue::from(DecimalValue::from(-500i32)),
            ),
            max: Scalar::new(
                non_nullable_dtype,
                ScalarValue::from(DecimalValue::from(30i32)),
            ),
        };

        assert_eq!(Some(expected), min_max)
    }

    /// Every slot null: there is no minimum or maximum.
    #[test]
    fn min_max_all_null() {
        let decimal = DecimalArray::new(
            buffer![1i32, 2i32, 3i32],
            DecimalDType::new(4, 2),
            Validity::AllInvalid,
        );

        assert_eq!(None, min_max(decimal.as_ref()).unwrap());
    }
}
110}