vortex_array/arrays/dict/compute/
min_max.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::BitBuffer;
5use vortex_dtype::match_each_unsigned_integer_ptype;
6use vortex_error::VortexResult;
7use vortex_mask::Mask;
8
9use super::{DictArray, DictVTable};
10use crate::compute::{MinMaxKernel, MinMaxKernelAdapter, MinMaxResult, mask, min_max};
11use crate::{Array as _, ToCanonical, register_kernel};
12
13impl MinMaxKernel for DictVTable {
14    fn min_max(&self, array: &DictArray) -> VortexResult<Option<MinMaxResult>> {
15        let codes_validity = array.codes().validity_mask();
16        if codes_validity.all_false() {
17            return Ok(None);
18        }
19
20        let codes_primitive = array.codes().to_primitive();
21        let values_len = array.values().len();
22        match_each_unsigned_integer_ptype!(codes_primitive.ptype(), |P| {
23            codes_validity.iter_bools(|validity_iter| {
24                // mask() sets values to null where the mask is true, so we start
25                // with a fully-set bool buffer.
26                let mut unreferenced = vec![true; values_len];
27                #[allow(clippy::cast_possible_truncation)]
28                for (&code, is_valid) in codes_primitive.as_slice::<P>().iter().zip(validity_iter) {
29                    if is_valid {
30                        unreferenced[code as usize] = false;
31                    }
32                }
33
34                let unreferenced_mask =
35                    Mask::from_buffer(BitBuffer::collect_bool(values_len, |i| unreferenced[i]));
36                min_max(&mask(array.values(), &unreferenced_mask)?)
37            })
38        })
39    }
40}
41
42register_kernel!(MinMaxKernelAdapter(DictVTable).lift());
43
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use super::DictArray;
    use crate::arrays::PrimitiveArray;
    use crate::builders::dict::dict_encode;
    use crate::compute::min_max;
    use crate::{Array, IntoArray};

    /// Runs `min_max` on `array` and checks the outcome against `expected`.
    ///
    /// `Some((min, max))` requires both bounds to match after conversion to `i32`;
    /// `None` requires `min_max` to report no result. Any other combination panics
    /// with both sides printed.
    fn assert_min_max(array: &dyn Array, expected: Option<(i32, i32)>) {
        let actual = min_max(array).unwrap();
        match (actual, expected) {
            (None, None) => {}
            (Some(found), Some((want_min, want_max))) => {
                assert_eq!(i32::try_from(found.min).unwrap(), want_min);
                assert_eq!(i32::try_from(found.max).unwrap(), want_max);
            }
            (actual, expected) => {
                // Convert leniently so the failure message can still be rendered.
                let actual_pair = actual.as_ref().map(|found| {
                    (
                        i32::try_from(found.min.clone()).ok(),
                        i32::try_from(found.max.clone()).ok(),
                    )
                });
                panic!(
                    "min_max mismatch: expected {:?}, got {:?}",
                    expected, actual_pair
                )
            }
        }
    }

    /// Dictionaries for which a min/max pair is expected.
    #[rstest]
    // Covering: every dictionary value is referenced by some code.
    #[case::covering(
        DictArray::try_new(
            buffer![0u32, 1, 2, 3, 0, 1].into_array(),
            buffer![10i32, 20, 30, 40].into_array(),
        )
        .unwrap(),
        (10, 40)
    )]
    // Non-covering: only codes 1 and 3 appear, each more than once.
    #[case::non_covering_duplicates(
        DictArray::try_new(
            buffer![1u32, 1, 1, 3, 3].into_array(),
            buffer![1i32, 2, 3, 4, 5].into_array(),
        )
        .unwrap(),
        (2, 4)
    )]
    // Non-covering: codes with gaps.
    #[case::non_covering_gaps(
        DictArray::try_new(
            buffer![0u32, 2, 4].into_array(),
            buffer![1i32, 2, 3, 4, 5].into_array(),
        )
        .unwrap(),
        (1, 5)
    )]
    // A single-element input.
    #[case::single(dict_encode(&buffer![42i32].into_array()).unwrap(), (42, 42))]
    // One of the codes is null.
    #[case::nullable_codes(
        DictArray::try_new(
            PrimitiveArray::from_option_iter([Some(0u32), None, Some(1), Some(2)]).into_array(),
            buffer![10i32, 20, 30].into_array(),
        )
        .unwrap(),
        (10, 30)
    )]
    // The encoded source array itself contains nulls.
    #[case::nullable_values(
        dict_encode(
            PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).as_ref()
        )
        .unwrap(),
        (1, 2)
    )]
    fn test_min_max(#[case] dict: DictArray, #[case] expected: (i32, i32)) {
        assert_min_max(dict.as_ref(), Some(expected));
    }

    /// Slicing the encoded array to `1..3` is expected to yield the bounds of source
    /// elements 1 and 2, i.e. `(5, 10)`.
    #[test]
    fn test_sliced_dict() {
        let source = PrimitiveArray::from_iter([1, 5, 10, 50, 100]);
        let encoded = dict_encode(source.as_ref()).unwrap();
        let window = encoded.slice(1..3);
        assert_min_max(&window, Some((5, 10)));
    }

    /// Dictionaries for which no min/max is expected.
    #[rstest]
    // No codes at all.
    #[case::empty(
        DictArray::try_new(
            PrimitiveArray::from_iter(Vec::<u32>::new()).into_array(),
            buffer![10i32, 20, 30].into_array(),
        )
        .unwrap()
    )]
    // Codes are present but every one of them is null.
    #[case::all_null_codes(
        DictArray::try_new(
            PrimitiveArray::from_option_iter([Option::<u32>::None, None, None]).into_array(),
            buffer![10i32, 20, 30].into_array(),
        )
        .unwrap()
    )]
    fn test_min_max_none(#[case] dict: DictArray) {
        assert_min_max(dict.as_ref(), None);
    }
}