vortex_array/compute/
search_sorted.rs

1use std::cmp::Ordering;
2use std::cmp::Ordering::{Equal, Greater, Less};
3use std::fmt::{Debug, Display, Formatter};
4use std::hint;
5
6use itertools::Itertools;
7use vortex_error::{VortexExpect, VortexResult, vortex_bail};
8use vortex_scalar::Scalar;
9
10use crate::Array;
11use crate::compute::scalar_at;
12use crate::encoding::Encoding;
13
/// Which edge of a run of equal values a sorted search reports.
///
/// `Left` reports the first index whose value is not less than the target; `Right` reports the
/// first index whose value is greater than the target (one past the last match).
#[derive(Clone, Copy, Debug)]
pub enum SearchSortedSide {
    /// Report the leftmost matching position.
    Left,
    /// Report the position just past the rightmost match.
    Right,
}
19
20impl Display for SearchSortedSide {
21    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
22        match self {
23            SearchSortedSide::Left => write!(f, "left"),
24            SearchSortedSide::Right => write!(f, "right"),
25        }
26    }
27}
28
/// Result of performing search_sorted on an Array.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum SearchResult {
    /// The searched-for value is present. The payload is the side-dependent boundary of the run
    /// of equal values: the first match for a left-side search, or one past the last match for a
    /// right-side search.
    Found(usize),

    /// The value is absent. The payload is the position at which it could be inserted while
    /// keeping the array sorted.
    NotFound(usize),
}
39
40impl SearchResult {
41    /// Convert search result to an index only if the value have been found
42    pub fn to_found(self) -> Option<usize> {
43        match self {
44            Self::Found(i) => Some(i),
45            Self::NotFound(_) => None,
46        }
47    }
48
49    /// Extract index out of search result regardless of whether the value have been found or not
50    pub fn to_index(self) -> usize {
51        match self {
52            Self::Found(i) => i,
53            Self::NotFound(i) => i,
54        }
55    }
56
57    /// Convert search result into an index suitable for searching array of offset indices, i.e. first element starts at 0.
58    ///
59    /// For example for a ChunkedArray with chunk offsets array [0, 3, 8, 10] you can use this method to
60    /// obtain index suitable for indexing into it after performing a search
61    pub fn to_offsets_index(self, len: usize) -> usize {
62        match self {
63            SearchResult::Found(i) => {
64                if i == len {
65                    i - 1
66                } else {
67                    i
68                }
69            }
70            SearchResult::NotFound(i) => i.saturating_sub(1),
71        }
72    }
73
74    /// Convert search result into an index suitable for searching array of end indices without 0 offset,
75    /// i.e. first element implicitly covers 0..0th-element range.
76    ///
77    /// For example for a RunEndArray with ends array [3, 8, 10], you can use this method to obtain index suitable for
78    /// indexing into it after performing a search
79    pub fn to_ends_index(self, len: usize) -> usize {
80        let idx = self.to_index();
81        if idx == len { idx - 1 } else { idx }
82    }
83
84    /// Apply a transformation to the Found or NotFound index.
85    #[inline]
86    pub fn map<F>(self, f: F) -> SearchResult
87    where
88        F: FnOnce(usize) -> usize,
89    {
90        match self {
91            SearchResult::Found(i) => SearchResult::Found(f(i)),
92            SearchResult::NotFound(i) => SearchResult::NotFound(f(i)),
93        }
94    }
95}
96
97impl Display for SearchResult {
98    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
99        match self {
100            SearchResult::Found(i) => write!(f, "Found({i})"),
101            SearchResult::NotFound(i) => write!(f, "NotFound({i})"),
102        }
103    }
104}
105
106/// Searches for value assuming the array is sorted.
107///
108/// For nullable arrays we assume that the nulls are sorted last, i.e. they're the greatest value
109pub trait SearchSortedFn<A: Copy> {
110    fn search_sorted(
111        &self,
112        array: A,
113        value: &Scalar,
114        side: SearchSortedSide,
115    ) -> VortexResult<SearchResult>;
116
117    /// Bulk search for many values.
118    fn search_sorted_many(
119        &self,
120        array: A,
121        values: &[Scalar],
122        side: SearchSortedSide,
123    ) -> VortexResult<Vec<SearchResult>> {
124        values
125            .iter()
126            .map(|value| self.search_sorted(array, value, side))
127            .try_collect()
128    }
129}
130
131pub trait SearchSortedUsizeFn<A: Copy> {
132    fn search_sorted_usize(
133        &self,
134        array: A,
135        value: usize,
136        side: SearchSortedSide,
137    ) -> VortexResult<SearchResult>;
138
139    fn search_sorted_usize_many(
140        &self,
141        array: A,
142        values: &[usize],
143        side: SearchSortedSide,
144    ) -> VortexResult<Vec<SearchResult>> {
145        values
146            .iter()
147            .map(|&value| self.search_sorted_usize(array, value, side))
148            .try_collect()
149    }
150}
151
152impl<E: Encoding> SearchSortedFn<&dyn Array> for E
153where
154    E: for<'a> SearchSortedFn<&'a E::Array>,
155{
156    fn search_sorted(
157        &self,
158        array: &dyn Array,
159        value: &Scalar,
160        side: SearchSortedSide,
161    ) -> VortexResult<SearchResult> {
162        let array_ref = array
163            .as_any()
164            .downcast_ref::<E::Array>()
165            .vortex_expect("Failed to downcast array");
166        SearchSortedFn::search_sorted(self, array_ref, value, side)
167    }
168
169    fn search_sorted_many(
170        &self,
171        array: &dyn Array,
172        values: &[Scalar],
173        side: SearchSortedSide,
174    ) -> VortexResult<Vec<SearchResult>> {
175        let array_ref = array
176            .as_any()
177            .downcast_ref::<E::Array>()
178            .vortex_expect("Failed to downcast array");
179        SearchSortedFn::search_sorted_many(self, array_ref, values, side)
180    }
181}
182
183impl<E: Encoding> SearchSortedUsizeFn<&dyn Array> for E
184where
185    E: for<'a> SearchSortedUsizeFn<&'a E::Array>,
186{
187    fn search_sorted_usize(
188        &self,
189        array: &dyn Array,
190        value: usize,
191        side: SearchSortedSide,
192    ) -> VortexResult<SearchResult> {
193        let array_ref = array
194            .as_any()
195            .downcast_ref::<E::Array>()
196            .vortex_expect("Failed to downcast array");
197        SearchSortedUsizeFn::search_sorted_usize(self, array_ref, value, side)
198    }
199
200    fn search_sorted_usize_many(
201        &self,
202        array: &dyn Array,
203        values: &[usize],
204        side: SearchSortedSide,
205    ) -> VortexResult<Vec<SearchResult>> {
206        let array_ref = array
207            .as_any()
208            .downcast_ref::<E::Array>()
209            .vortex_expect("Failed to downcast array");
210        SearchSortedUsizeFn::search_sorted_usize_many(self, array_ref, values, side)
211    }
212}
213
214pub fn search_sorted<T: Into<Scalar>>(
215    array: &dyn Array,
216    target: T,
217    side: SearchSortedSide,
218) -> VortexResult<SearchResult> {
219    let Ok(scalar) = target.into().cast(array.dtype()) else {
220        // Try to downcast the usize ot the array type, if the downcast fails, then we know the
221        // usize is too large and the value is greater than the highest value in the array.
222        return Ok(SearchResult::NotFound(array.len()));
223    };
224
225    if scalar.is_null() {
226        vortex_bail!("Search sorted with null value is not supported");
227    }
228
229    if let Some(f) = array.vtable().search_sorted_fn() {
230        return f.search_sorted(array, &scalar, side);
231    }
232
233    // Fallback to a generic search_sorted using scalar_at
234    if array.vtable().scalar_at_fn().is_some() {
235        return Ok(SearchSorted::search_sorted(array, &scalar, side));
236    }
237
238    vortex_bail!(
239        NotImplemented: "search_sorted",
240        array.encoding()
241    )
242}
243
244pub fn search_sorted_usize(
245    array: &dyn Array,
246    target: usize,
247    side: SearchSortedSide,
248) -> VortexResult<SearchResult> {
249    if let Some(f) = array.vtable().search_sorted_usize_fn() {
250        return f.search_sorted_usize(array, target, side);
251    }
252
253    // Otherwise, convert the target into a scalar to try the search_sorted_fn
254    let Ok(target) = Scalar::from(target).cast(array.dtype()) else {
255        return Ok(SearchResult::NotFound(array.len()));
256    };
257
258    // Try the non-usize search sorted
259    if let Some(f) = array.vtable().search_sorted_fn() {
260        return f.search_sorted(array, &target, side);
261    }
262
263    // Or fallback all the way to a generic search_sorted using scalar_at
264    if array.vtable().scalar_at_fn().is_some() {
265        // Try to downcast the usize to the array type, if the downcast fails, then we know the
266        // usize is too large and the value is greater than the highest value in the array.
267        let Ok(target) = target.cast(array.dtype()) else {
268            return Ok(SearchResult::NotFound(array.len()));
269        };
270        return Ok(SearchSorted::search_sorted(array, &target, side));
271    }
272
273    vortex_bail!(
274    NotImplemented: "search_sorted_usize",
275        array.encoding()
276    )
277}
278
279/// Search for many elements in the array.
280pub fn search_sorted_many<T: Into<Scalar> + Clone>(
281    array: &dyn Array,
282    targets: &[T],
283    side: SearchSortedSide,
284) -> VortexResult<Vec<SearchResult>> {
285    if let Some(f) = array.vtable().search_sorted_fn() {
286        let mut too_big_cast_idxs = Vec::new();
287        let values = targets
288            .iter()
289            .cloned()
290            .enumerate()
291            .filter_map(|(i, t)| {
292                let Ok(c) = t.into().cast(array.dtype()) else {
293                    too_big_cast_idxs.push(i);
294                    return None;
295                };
296                Some(c)
297            })
298            .collect::<Vec<_>>();
299
300        let mut results = f.search_sorted_many(array, &values, side)?;
301        for too_big_idx in too_big_cast_idxs {
302            results.insert(too_big_idx, SearchResult::NotFound(array.len()));
303        }
304        return Ok(results);
305    }
306
307    // Call in loop and collect
308    targets
309        .iter()
310        .map(|target| search_sorted(array, target.clone(), side))
311        .try_collect()
312}
313
314// Native functions for each of the values, cast up to u64 or down to something lower.
315pub fn search_sorted_usize_many(
316    array: &dyn Array,
317    targets: &[usize],
318    side: SearchSortedSide,
319) -> VortexResult<Vec<SearchResult>> {
320    if let Some(f) = array.vtable().search_sorted_usize_fn() {
321        return f.search_sorted_usize_many(array, targets, side);
322    }
323
324    // Call in loop and collect
325    targets
326        .iter()
327        .map(|&target| search_sorted_usize(array, target, side))
328        .try_collect()
329}
330
/// Comparison of the element stored at an index against a standalone value of type `V`.
pub trait IndexOrd<V> {
    /// PartialOrd of the value at index `idx` with `elem`.
    /// For example, if `self[idx] > elem`, return `Some(Greater)`.
    fn index_cmp(&self, idx: usize, elem: &V) -> Option<Ordering>;

    /// `true` when `self[idx] < elem`; an incomparable pair (`None`) yields `false`.
    fn index_lt(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem).is_some_and(Ordering::is_lt)
    }

    /// `true` when `self[idx] <= elem`; an incomparable pair (`None`) yields `false`.
    fn index_le(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem).is_some_and(Ordering::is_le)
    }

    /// `true` when `self[idx] > elem`; an incomparable pair (`None`) yields `false`.
    fn index_gt(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem).is_some_and(Ordering::is_gt)
    }

    /// `true` when `self[idx] >= elem`; an incomparable pair (`None`) yields `false`.
    fn index_ge(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem).is_some_and(Ordering::is_ge)
    }
}
352
/// Number of addressable elements; bounds the index range searched by `SearchSorted`.
// The search code never needs `is_empty`, so clippy's usual `len`/`is_empty` pairing is waived.
#[allow(clippy::len_without_is_empty)]
pub trait Len {
    /// Returns the number of elements.
    fn len(&self) -> usize;
}
357
358pub trait SearchSorted<T> {
359    fn search_sorted(&self, value: &T, side: SearchSortedSide) -> SearchResult
360    where
361        Self: IndexOrd<T>,
362    {
363        match side {
364            SearchSortedSide::Left => self.search_sorted_by(
365                |idx| self.index_cmp(idx, value).unwrap_or(Less),
366                |idx| {
367                    if self.index_lt(idx, value) {
368                        Less
369                    } else {
370                        Greater
371                    }
372                },
373                side,
374            ),
375            SearchSortedSide::Right => self.search_sorted_by(
376                |idx| self.index_cmp(idx, value).unwrap_or(Less),
377                |idx| {
378                    if self.index_le(idx, value) {
379                        Less
380                    } else {
381                        Greater
382                    }
383                },
384                side,
385            ),
386        }
387    }
388
389    /// find function is used to find the element if it exists, if element exists side_find will be
390    /// used to find desired index amongst equal values
391    fn search_sorted_by<F: FnMut(usize) -> Ordering, N: FnMut(usize) -> Ordering>(
392        &self,
393        find: F,
394        side_find: N,
395        side: SearchSortedSide,
396    ) -> SearchResult;
397}
398
399// Default implementation for types that implement IndexOrd.
400impl<S, T> SearchSorted<T> for S
401where
402    S: IndexOrd<T> + Len + ?Sized,
403{
404    fn search_sorted_by<F: FnMut(usize) -> Ordering, N: FnMut(usize) -> Ordering>(
405        &self,
406        find: F,
407        side_find: N,
408        side: SearchSortedSide,
409    ) -> SearchResult {
410        match search_sorted_side_idx(find, 0, self.len()) {
411            SearchResult::Found(found) => {
412                let idx_search = match side {
413                    SearchSortedSide::Left => search_sorted_side_idx(side_find, 0, found),
414                    SearchSortedSide::Right => search_sorted_side_idx(side_find, found, self.len()),
415                };
416                match idx_search {
417                    SearchResult::NotFound(i) => SearchResult::Found(i),
418                    _ => unreachable!(
419                        "searching amongst equal values should never return Found result"
420                    ),
421                }
422            }
423            s => s,
424        }
425    }
426}
427
428// Code adapted from Rust standard library slice::binary_search_by
429fn search_sorted_side_idx<F: FnMut(usize) -> Ordering>(
430    mut find: F,
431    from: usize,
432    to: usize,
433) -> SearchResult {
434    let mut size = to - from;
435    if size == 0 {
436        return SearchResult::NotFound(0);
437    }
438    let mut base = from;
439
440    // This loop intentionally doesn't have an early exit if the comparison
441    // returns Equal. We want the number of loop iterations to depend *only*
442    // on the size of the input slice so that the CPU can reliably predict
443    // the loop count.
444    while size > 1 {
445        let half = size / 2;
446        let mid = base + half;
447
448        // SAFETY: the call is made safe by the following inconstants:
449        // - `mid >= 0`: by definition
450        // - `mid < size`: `mid = size / 2 + size / 4 + size / 8 ...`
451        let cmp = find(mid);
452
453        // Binary search interacts poorly with branch prediction, so force
454        // the compiler to use conditional moves if supported by the target
455        // architecture.
456        base = if cmp == Greater { base } else { mid };
457
458        // This is imprecise in the case where `size` is odd and the
459        // comparison returns Greater: the mid element still gets included
460        // by `size` even though it's known to be larger than the element
461        // being searched for.
462        //
463        // This is fine though: we gain more performance by keeping the
464        // loop iteration count invariant (and thus predictable) than we
465        // lose from considering one additional element.
466        size -= half;
467    }
468
469    // SAFETY: base is always in [0, size) because base <= mid.
470    let cmp = find(base);
471    if cmp == Equal {
472        // SAFETY: same as the call to `find` above.
473        unsafe { hint::assert_unchecked(base < to) };
474        SearchResult::Found(base)
475    } else {
476        let result = base + (cmp == Less) as usize;
477        // SAFETY: same as the call to `find` above.
478        // Note that this is `<=`, unlike the assert in the `Found` path.
479        unsafe { hint::assert_unchecked(result <= to) };
480        SearchResult::NotFound(result)
481    }
482}
483
484impl IndexOrd<Scalar> for dyn Array + '_ {
485    fn index_cmp(&self, idx: usize, elem: &Scalar) -> Option<Ordering> {
486        let scalar_a = scalar_at(self, idx).ok()?;
487        scalar_a.partial_cmp(elem)
488    }
489}
490
491impl<T: PartialOrd> IndexOrd<T> for [T] {
492    fn index_cmp(&self, idx: usize, elem: &T) -> Option<Ordering> {
493        // SAFETY: Used in search_sorted_by same as the standard library. The search_sorted ensures idx is in bounds
494        unsafe { self.get_unchecked(idx) }.partial_cmp(elem)
495    }
496}
497
498impl Len for dyn Array + '_ {
499    #[allow(clippy::same_name_method)]
500    fn len(&self) -> usize {
501        Self::len(self)
502    }
503}
504
505impl<T> Len for [T] {
506    fn len(&self) -> usize {
507        self.len()
508    }
509}
510
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::compute::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};
    use crate::compute::{search_sorted, search_sorted_many};

    #[test]
    fn left_side_equal() {
        let arr = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
        let res = arr.search_sorted(&2, SearchSortedSide::Left);
        assert_eq!(arr[res.to_index()], 2);
        assert_eq!(res, SearchResult::Found(2));
    }

    #[test]
    fn right_side_equal() {
        let arr = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
        let res = arr.search_sorted(&2, SearchSortedSide::Right);
        assert_eq!(arr[res.to_index() - 1], 2);
        assert_eq!(res, SearchResult::Found(6));
    }

    #[test]
    fn left_side_equal_beginning() {
        let arr = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let res = arr.search_sorted(&0, SearchSortedSide::Left);
        assert_eq!(arr[res.to_index()], 0);
        assert_eq!(res, SearchResult::Found(0));
    }

    #[test]
    fn right_side_equal_beginning() {
        let arr = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let res = arr.search_sorted(&0, SearchSortedSide::Right);
        assert_eq!(arr[res.to_index() - 1], 0);
        assert_eq!(res, SearchResult::Found(4));
    }

    #[test]
    fn left_side_equal_end() {
        let arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
        let res = arr.search_sorted(&9, SearchSortedSide::Left);
        assert_eq!(arr[res.to_index()], 9);
        assert_eq!(res, SearchResult::Found(9));
    }

    #[test]
    fn right_side_equal_end() {
        let arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
        let res = arr.search_sorted(&9, SearchSortedSide::Right);
        assert_eq!(arr[res.to_index() - 1], 9);
        assert_eq!(res, SearchResult::Found(13));
    }

    #[test]
    fn empty_slice() {
        let arr: [i32; 0] = [];
        assert_eq!(
            arr.search_sorted(&1, SearchSortedSide::Left),
            SearchResult::NotFound(0)
        );
        assert_eq!(
            arr.search_sorted(&1, SearchSortedSide::Right),
            SearchResult::NotFound(0)
        );
    }

    #[test]
    fn single_element() {
        let arr = [7];
        assert_eq!(
            arr.search_sorted(&7, SearchSortedSide::Left),
            SearchResult::Found(0)
        );
        assert_eq!(
            arr.search_sorted(&7, SearchSortedSide::Right),
            SearchResult::Found(1)
        );
    }

    #[test]
    fn not_found_reports_insertion_point() {
        let arr = [0, 2, 4];
        assert_eq!(
            arr.search_sorted(&-1, SearchSortedSide::Left),
            SearchResult::NotFound(0)
        );
        assert_eq!(
            arr.search_sorted(&3, SearchSortedSide::Left),
            SearchResult::NotFound(2)
        );
        assert_eq!(
            arr.search_sorted(&3, SearchSortedSide::Right),
            SearchResult::NotFound(2)
        );
        assert_eq!(
            arr.search_sorted(&5, SearchSortedSide::Right),
            SearchResult::NotFound(3)
        );
    }

    #[test]
    fn failed_cast() {
        let arr = buffer![0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9].into_array();
        let res = search_sorted(&arr, 256, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::NotFound(arr.len()));
    }

    #[test]
    fn search_sorted_many_failed_cast() {
        let arr = buffer![0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9].into_array();
        let res = search_sorted_many(&arr, &[256], SearchSortedSide::Left).unwrap();
        assert_eq!(res, vec![SearchResult::NotFound(arr.len())]);
    }

    #[test]
    fn search_sorted_many_mixed_failed_casts_keep_order() {
        let arr = buffer![0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9].into_array();
        let res =
            search_sorted_many(&arr, &[256, 1, 300, 3], SearchSortedSide::Left).unwrap();
        assert_eq!(
            res,
            vec![
                SearchResult::NotFound(arr.len()),
                SearchResult::Found(1),
                SearchResult::NotFound(arr.len()),
                SearchResult::Found(3),
            ]
        );
    }
}