1use std::sync::{Arc, LazyLock};
2
3use arcref::ArcRef;
4use vortex_dtype::{DType, Nullability, StructDType};
5use vortex_error::{VortexResult, vortex_bail};
6use vortex_scalar::Scalar;
7
8use crate::Array;
9use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output, UnaryArgs};
10use crate::stats::{Precision, Stat, StatsProviderExt};
11use crate::vtable::VTable;
12
13pub fn min_max(array: &dyn Array) -> VortexResult<Option<MinMaxResult>> {
20    let scalar = MIN_MAX_FN
21        .invoke(&InvocationArgs {
22            inputs: &[array.into()],
23            options: &(),
24        })?
25        .unwrap_scalar()?;
26    MinMaxResult::from_scalar(scalar)
27}
28
29#[derive(Debug, Clone, PartialEq, Eq)]
30pub struct MinMaxResult {
31    pub min: Scalar,
32    pub max: Scalar,
33}
34
35impl MinMaxResult {
36    pub fn from_scalar(scalar: Scalar) -> VortexResult<Option<Self>> {
37        if scalar.is_null() {
38            Ok(None)
39        } else {
40            let min = scalar.as_struct().field_by_idx(0)?;
41            let max = scalar.as_struct().field_by_idx(1)?;
42            Ok(Some(MinMaxResult { min, max }))
43        }
44    }
45}
46
47pub struct MinMax;
48
49impl ComputeFnVTable for MinMax {
50    fn invoke(
51        &self,
52        args: &InvocationArgs,
53        kernels: &[ArcRef<dyn Kernel>],
54    ) -> VortexResult<Output> {
55        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
56
57        let return_dtype = self.return_dtype(args)?;
58
59        match min_max_impl(array, kernels)? {
60            None => Ok(Scalar::null(return_dtype).into()),
61            Some(MinMaxResult { min, max }) => {
62                assert!(
63                    min <= max,
64                    "min > max: min={} max={} encoding={}",
65                    min,
66                    max,
67                    array.encoding_id()
68                );
69
70                array
72                    .statistics()
73                    .set(Stat::Min, Precision::Exact(min.value().clone()));
74                array
75                    .statistics()
76                    .set(Stat::Max, Precision::Exact(max.value().clone()));
77
78                Ok(Scalar::struct_(return_dtype, vec![min, max]).into())
80            }
81        }
82    }
83
84    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
85        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
86
87        Ok(DType::Struct(
90            Arc::new(StructDType::new(
91                ["min".into(), "max".into()].into(),
92                vec![array.dtype().clone(), array.dtype().clone()],
93            )),
94            Nullability::Nullable,
95        ))
96    }
97
98    fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
99        Ok(1)
100    }
101
102    fn is_elementwise(&self) -> bool {
103        false
104    }
105}
106
107fn min_max_impl(
108    array: &dyn Array,
109    kernels: &[ArcRef<dyn Kernel>],
110) -> VortexResult<Option<MinMaxResult>> {
111    if array.is_empty() || array.valid_count()? == 0 {
112        return Ok(None);
113    }
114
115    let min = array
116        .statistics()
117        .get_scalar(Stat::Min, array.dtype())
118        .and_then(Precision::as_exact);
119    let max = array
120        .statistics()
121        .get_scalar(Stat::Max, array.dtype())
122        .and_then(Precision::as_exact);
123
124    if let Some((min, max)) = min.zip(max) {
125        return Ok(Some(MinMaxResult { min, max }));
126    }
127
128    let args = InvocationArgs {
129        inputs: &[array.into()],
130        options: &(),
131    };
132    for kernel in kernels {
133        if let Some(output) = kernel.invoke(&args)? {
134            return MinMaxResult::from_scalar(output.unwrap_scalar()?);
135        }
136    }
137    if let Some(output) = array.invoke(&MIN_MAX_FN, &args)? {
138        return MinMaxResult::from_scalar(output.unwrap_scalar()?);
139    }
140
141    if !array.is_canonical() {
142        let array = array.to_canonical()?;
143        return min_max(array.as_ref());
144    }
145
146    vortex_bail!(NotImplemented: "min_max", array.encoding_id());
147}
148
149pub trait MinMaxKernel: VTable {
152    fn min_max(&self, array: &Self::Array) -> VortexResult<Option<MinMaxResult>>;
153}
154
155pub struct MinMaxKernelRef(ArcRef<dyn Kernel>);
156inventory::collect!(MinMaxKernelRef);
157
158#[derive(Debug)]
159pub struct MinMaxKernelAdapter<V: VTable>(pub V);
160
161impl<V: VTable + MinMaxKernel> MinMaxKernelAdapter<V> {
162    pub const fn lift(&'static self) -> MinMaxKernelRef {
163        MinMaxKernelRef(ArcRef::new_ref(self))
164    }
165}
166
167impl<V: VTable + MinMaxKernel> Kernel for MinMaxKernelAdapter<V> {
168    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
169        let inputs = UnaryArgs::<()>::try_from(args)?;
170        let Some(array) = inputs.array.as_opt::<V>() else {
171            return Ok(None);
172        };
173        let dtype = DType::Struct(
174            Arc::new(StructDType::new(
175                ["min".into(), "max".into()].into(),
176                vec![array.dtype().clone(), array.dtype().clone()],
177            )),
178            Nullability::Nullable,
179        );
180        Ok(Some(match V::min_max(&self.0, array)? {
181            None => Scalar::null(dtype).into(),
182            Some(MinMaxResult { min, max }) => Scalar::struct_(dtype, vec![min, max]).into(),
183        }))
184    }
185}
186
187pub static MIN_MAX_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
188    let compute = ComputeFn::new("min_max".into(), ArcRef::new_ref(&MinMax));
189    for kernel in inventory::iter::<MinMaxKernelRef> {
190        compute.register_kernel(kernel.0.clone());
191    }
192    compute
193});
194
#[cfg(test)]
mod tests {
    use arrow_buffer::BooleanBuffer;
    use vortex_buffer::buffer;

    use crate::arrays::{BoolArray, NullArray, PrimitiveArray};
    use crate::compute::{MinMaxResult, min_max};
    use crate::validity::Validity;

    #[test]
    fn test_prim_max() {
        let array = PrimitiveArray::new(buffer![1, 2, 3], Validity::NonNullable);
        let expected = MinMaxResult {
            min: 1.into(),
            max: 3.into(),
        };
        assert_eq!(min_max(array.as_ref()).unwrap(), Some(expected));
    }

    #[test]
    fn test_bool_max() {
        // (values, expected min, expected max)
        let cases = [
            ([true, true, true], true, true),
            ([false, false, false], false, false),
            ([false, true, false], false, true),
        ];
        for (bits, lo, hi) in cases {
            let array = BoolArray::new(BooleanBuffer::from(bits.as_slice()), Validity::NonNullable);
            assert_eq!(
                min_max(array.as_ref()).unwrap(),
                Some(MinMaxResult {
                    min: lo.into(),
                    max: hi.into(),
                })
            );
        }
    }

    #[test]
    fn test_null() {
        let array = NullArray::new(1);
        assert_eq!(min_max(array.as_ref()).unwrap(), None);
    }
}