vortex_array/arrays/primitive/compute/
min_max.rs1use itertools::Itertools;
2use vortex_dtype::{DType, NativePType, match_each_native_ptype};
3use vortex_error::VortexResult;
4use vortex_mask::Mask;
5use vortex_scalar::{Scalar, ScalarValue};
6
7use crate::arrays::{PrimitiveArray, PrimitiveVTable};
8use crate::compute::{MinMaxKernel, MinMaxKernelAdapter, MinMaxResult};
9use crate::register_kernel;
10
11impl MinMaxKernel for PrimitiveVTable {
12 fn min_max(&self, array: &PrimitiveArray) -> VortexResult<Option<MinMaxResult>> {
13 match_each_native_ptype!(array.ptype(), |T| {
14 compute_min_max_with_validity::<T>(array)
15 })
16 }
17}
18
// Register the `PrimitiveVTable` min/max kernel above, wrapped in `MinMaxKernelAdapter`.
// NOTE(review): presumably this lets the generic `crate::compute::min_max` entry point
// dispatch to this kernel for primitive arrays. The dispatch lives outside this file, so confirm.
register_kernel!(MinMaxKernelAdapter(PrimitiveVTable).lift());
20
21#[inline]
22fn compute_min_max_with_validity<T>(array: &PrimitiveArray) -> VortexResult<Option<MinMaxResult>>
23where
24 T: Into<ScalarValue> + NativePType,
25{
26 Ok(match array.validity_mask()? {
27 Mask::AllTrue(_) => compute_min_max(array.as_slice::<T>().iter(), array.dtype()),
28 Mask::AllFalse(_) => None,
29 Mask::Values(v) => compute_min_max(
30 array
31 .as_slice::<T>()
32 .iter()
33 .zip(v.boolean_buffer().iter())
34 .filter_map(|(v, m)| m.then_some(v)),
35 array.dtype(),
36 ),
37 })
38}
39
40fn compute_min_max<'a, T>(iter: impl Iterator<Item = &'a T>, dtype: &DType) -> Option<MinMaxResult>
41where
42 T: Into<ScalarValue> + NativePType,
43{
44 match iter
47 .filter(|v| !v.is_nan())
48 .minmax_by(|a, b| a.total_compare(**b))
49 {
50 itertools::MinMaxResult::NoElements => None,
51 itertools::MinMaxResult::OneElement(&x) => {
52 let scalar = Scalar::new(dtype.clone(), x.into());
53 Some(MinMaxResult {
54 min: scalar.clone(),
55 max: scalar,
56 })
57 }
58 itertools::MinMaxResult::MinMax(&min, &max) => Some(MinMaxResult {
59 min: Scalar::new(dtype.clone(), min.into()),
60 max: Scalar::new(dtype.clone(), max.into()),
61 }),
62 }
63}
64
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::arrays::PrimitiveArray;
    use crate::compute::min_max;
    use crate::validity::Validity;

    /// Runs `min_max` over `array` and checks both reported bounds against the expected values.
    fn assert_f32_bounds(array: PrimitiveArray, expected_min: f32, expected_max: f32) {
        let bounds = min_max(array.as_ref()).unwrap().unwrap();
        assert_eq!(f32::try_from(bounds.min).unwrap(), expected_min);
        assert_eq!(f32::try_from(bounds.max).unwrap(), expected_max);
    }

    #[test]
    fn min_max_nan() {
        // Both NaN and negated NaN must be ignored, leaving only the finite values as bounds.
        let array = PrimitiveArray::new(
            buffer![f32::NAN, -f32::NAN, -1.0, 1.0],
            Validity::NonNullable,
        );
        assert_f32_bounds(array, -1.0, 1.0);
    }

    #[test]
    fn min_max_inf() {
        // Infinities are ordinary ordered values and should become the bounds.
        let array = PrimitiveArray::new(
            buffer![f32::INFINITY, f32::NEG_INFINITY, -1.0, 1.0],
            Validity::NonNullable,
        );
        assert_f32_bounds(array, f32::NEG_INFINITY, f32::INFINITY);
    }
}