vortex_array/compute/
is_sorted.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::sync::LazyLock;
6
7use arcref::ArcRef;
8use vortex_dtype::{DType, Nullability};
9use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
10use vortex_scalar::Scalar;
11
12use crate::Array;
13use crate::arrays::{ConstantVTable, NullVTable};
14use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output};
15use crate::stats::{Precision, Stat, StatsProviderExt};
16use crate::vtable::VTable;
17
18static IS_SORTED_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
19    let compute = ComputeFn::new("is_sorted".into(), ArcRef::new_ref(&IsSorted));
20    for kernel in inventory::iter::<IsSortedKernelRef> {
21        compute.register_kernel(kernel.0.clone());
22    }
23    compute
24});
25
26pub(crate) fn warm_up_vtable() -> usize {
27    IS_SORTED_FN.kernels().len()
28}
29
30pub fn is_sorted(array: &dyn Array) -> VortexResult<Option<bool>> {
31    is_sorted_opts(array, false)
32}
33
34pub fn is_strict_sorted(array: &dyn Array) -> VortexResult<Option<bool>> {
35    is_sorted_opts(array, true)
36}
37
38pub fn is_sorted_opts(array: &dyn Array, strict: bool) -> VortexResult<Option<bool>> {
39    Ok(IS_SORTED_FN
40        .invoke(&InvocationArgs {
41            inputs: &[array.into()],
42            options: &IsSortedOptions { strict },
43        })?
44        .unwrap_scalar()?
45        .as_bool()
46        .value())
47}
48
49struct IsSorted;
50impl ComputeFnVTable for IsSorted {
51    fn invoke(
52        &self,
53        args: &InvocationArgs,
54        kernels: &[ArcRef<dyn Kernel>],
55    ) -> VortexResult<Output> {
56        let IsSortedArgs { array, strict } = IsSortedArgs::try_from(args)?;
57
58        // We currently don't support sorting struct arrays.
59        if array.dtype().is_struct() {
60            return Ok(Scalar::from(Some(false)).into());
61        }
62
63        let is_sorted = if strict {
64            if let Some(Precision::Exact(value)) =
65                array.statistics().get_as::<bool>(Stat::IsStrictSorted)
66            {
67                return Ok(Scalar::from(Some(value)).into());
68            }
69
70            let is_strict_sorted = is_sorted_impl(array, kernels, true)?;
71            let array_stats = array.statistics();
72
73            if is_strict_sorted.is_some() {
74                if is_strict_sorted.unwrap_or(false) {
75                    array_stats.set(Stat::IsSorted, Precision::Exact(true.into()));
76                    array_stats.set(Stat::IsStrictSorted, Precision::Exact(true.into()));
77                } else {
78                    array_stats.set(Stat::IsStrictSorted, Precision::Exact(false.into()));
79                }
80            }
81
82            is_strict_sorted
83        } else {
84            if let Some(Precision::Exact(value)) = array.statistics().get_as::<bool>(Stat::IsSorted)
85            {
86                return Ok(Scalar::from(Some(value)).into());
87            }
88
89            let is_sorted = is_sorted_impl(array, kernels, false)?;
90            let array_stats = array.statistics();
91
92            if is_sorted.is_some() {
93                if is_sorted.unwrap_or(false) {
94                    array_stats.set(Stat::IsSorted, Precision::Exact(true.into()));
95                } else {
96                    array_stats.set(Stat::IsSorted, Precision::Exact(false.into()));
97                    array_stats.set(Stat::IsStrictSorted, Precision::Exact(false.into()));
98                }
99            }
100
101            is_sorted
102        };
103
104        Ok(Scalar::from(is_sorted).into())
105    }
106
107    fn return_dtype(&self, _args: &InvocationArgs) -> VortexResult<DType> {
108        // We always return a nullable boolean where `null` indicates we couldn't determine
109        // whether the array is constant.
110        Ok(DType::Bool(Nullability::Nullable))
111    }
112
113    fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
114        Ok(1)
115    }
116
117    fn is_elementwise(&self) -> bool {
118        true
119    }
120}
121
122struct IsSortedArgs<'a> {
123    array: &'a dyn Array,
124    strict: bool,
125}
126
127impl<'a> TryFrom<&InvocationArgs<'a>> for IsSortedArgs<'a> {
128    type Error = VortexError;
129
130    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
131        if value.inputs.len() != 1 {
132            vortex_bail!(
133                "IsSorted function requires exactly one argument, got {}",
134                value.inputs.len()
135            );
136        }
137        let array = value.inputs[0]
138            .array()
139            .ok_or_else(|| vortex_err!("Invalid argument type for is sorted function"))?;
140        let options = *value
141            .options
142            .as_any()
143            .downcast_ref::<IsSortedOptions>()
144            .ok_or_else(|| vortex_err!("Invalid options type for is sorted function"))?;
145
146        Ok(IsSortedArgs {
147            array,
148            strict: options.strict,
149        })
150    }
151}
152
153#[derive(Clone, Copy)]
154struct IsSortedOptions {
155    strict: bool,
156}
157
158impl Options for IsSortedOptions {
159    fn as_any(&self) -> &dyn Any {
160        self
161    }
162}
163
164pub struct IsSortedKernelRef(ArcRef<dyn Kernel>);
165inventory::collect!(IsSortedKernelRef);
166
167#[derive(Debug)]
168pub struct IsSortedKernelAdapter<V: VTable>(pub V);
169
170impl<V: VTable + IsSortedKernel> IsSortedKernelAdapter<V> {
171    pub const fn lift(&'static self) -> IsSortedKernelRef {
172        IsSortedKernelRef(ArcRef::new_ref(self))
173    }
174}
175
176impl<V: VTable + IsSortedKernel> Kernel for IsSortedKernelAdapter<V> {
177    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
178        let IsSortedArgs { array, strict } = IsSortedArgs::try_from(args)?;
179        let Some(array) = array.as_opt::<V>() else {
180            return Ok(None);
181        };
182
183        let is_sorted = if strict {
184            V::is_strict_sorted(&self.0, array)?
185        } else {
186            V::is_sorted(&self.0, array)?
187        };
188
189        Ok(Some(Scalar::from(is_sorted).into()))
190    }
191}
192
193pub trait IsSortedKernel: VTable {
194    /// # Preconditions
195    /// - The array's length is > 1.
196    /// - The array is not encoded as `NullArray` or `ConstantArray`.
197    /// - If doing a `strict` check, if the array is nullable, it'll have at most 1 null element
198    ///   as the first item in the array.
199    fn is_sorted(&self, array: &Self::Array) -> VortexResult<Option<bool>>;
200
201    fn is_strict_sorted(&self, array: &Self::Array) -> VortexResult<Option<bool>>;
202}
203
/// Helper methods to check sortedness with strictness
// `is_strict_sorted` consumes `self` despite its `is_` prefix, like `Iterator::is_sorted`.
#[allow(clippy::wrong_self_convention)]
pub trait IsSortedIteratorExt: Iterator
where
    <Self as Iterator>::Item: PartialOrd,
{
    /// Returns `true` when every item compares strictly less than its successor.
    ///
    /// Iterators yielding zero or one item are strictly sorted. A pair of items for which
    /// `<` is false (equal, descending or incomparable) makes the result `false`.
    fn is_strict_sorted(self) -> bool
    where
        Self: Sized,
    {
        self.is_sorted_by(<Self::Item as PartialOrd>::lt)
    }
}

/// Makes the helper available on every iterator whose items are `PartialOrd`.
impl<I> IsSortedIteratorExt for I
where
    I: Iterator + ?Sized,
    I::Item: PartialOrd,
{
}
225
226fn is_sorted_impl(
227    array: &dyn Array,
228    kernels: &[ArcRef<dyn Kernel>],
229    strict: bool,
230) -> VortexResult<Option<bool>> {
231    // Arrays with 0 or 1 elements are strict sorted.
232    if array.len() <= 1 {
233        return Ok(Some(true));
234    }
235
236    // Constant and null arrays are always sorted, but not strict sorted.
237    if array.is::<ConstantVTable>() || array.is::<NullVTable>() {
238        return Ok(Some(!strict));
239    }
240
241    // Enforce strictness before we even try to check if the array is sorted.
242    if strict {
243        let invalid_count = array.invalid_count();
244        match invalid_count {
245            // We can keep going
246            0 => {}
247            // If we have a potential null value - it has to be the first one.
248            1 => {
249                if !array.is_invalid(0) {
250                    return Ok(Some(false));
251                }
252            }
253            _ => return Ok(Some(false)),
254        }
255    }
256
257    let args = InvocationArgs {
258        inputs: &[array.into()],
259        options: &IsSortedOptions { strict },
260    };
261
262    for kernel in kernels {
263        if let Some(output) = kernel.invoke(&args)? {
264            return Ok(output.unwrap_scalar()?.as_bool().value());
265        }
266    }
267    if let Some(output) = array.invoke(&IS_SORTED_FN, &args)? {
268        return Ok(output.unwrap_scalar()?.as_bool().value());
269    }
270
271    if !array.is_canonical() {
272        log::debug!(
273            "No is_sorted implementation found for {}",
274            array.encoding_id()
275        );
276
277        // Recurse to canonical implementation
278        let array = array.to_canonical();
279
280        return if strict {
281            is_strict_sorted(array.as_ref())
282        } else {
283            is_sorted(array.as_ref())
284        };
285    }
286
287    vortex_bail!(
288        "No is_sorted function for canonical array: {}",
289        array.encoding_id(),
290    )
291}
292
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::{is_sorted, is_strict_sorted};
    use crate::validity::Validity;

    /// Validity in which only the first of four elements is null.
    fn null_first() -> Validity {
        Validity::Array(BoolArray::from_iter([false, true, true, true]).into_array())
    }

    /// Validity in which only the second of four elements is null.
    fn null_second() -> Validity {
        Validity::Array(BoolArray::from_iter([true, false, true, true]).into_array())
    }

    /// Ascending values `0, 1, 2, 3` with the given validity.
    fn ascending(validity: Validity) -> PrimitiveArray {
        PrimitiveArray::new(buffer!(0, 1, 2, 3), validity)
    }

    /// Values `0, 1, 3, 2` (last two out of order) with the given validity.
    fn swapped_tail(validity: Validity) -> PrimitiveArray {
        PrimitiveArray::new(buffer!(0, 1, 3, 2), validity)
    }

    /// Result of the non-strict check, which must succeed and be conclusive.
    fn sorted(array: PrimitiveArray) -> bool {
        is_sorted(array.as_ref()).unwrap().unwrap()
    }

    /// Result of the strict check, which must succeed and be conclusive.
    fn strict_sorted(array: PrimitiveArray) -> bool {
        is_strict_sorted(array.as_ref()).unwrap().unwrap()
    }

    #[test]
    fn test_is_sorted() {
        assert!(sorted(ascending(Validity::AllValid)));
        assert!(sorted(ascending(null_first())));
        assert!(!sorted(ascending(null_second())));

        assert!(!sorted(swapped_tail(Validity::AllValid)));
        assert!(!sorted(swapped_tail(null_first())));
    }

    #[test]
    fn test_is_strict_sorted() {
        assert!(strict_sorted(ascending(Validity::AllValid)));
        assert!(strict_sorted(ascending(null_first())));
        assert!(!strict_sorted(ascending(null_second())));

        assert!(!strict_sorted(swapped_tail(null_first())));
    }
}