vortex_array/compute/
search_sorted.rs

1use std::cmp::Ordering;
2use std::cmp::Ordering::{Equal, Greater, Less};
3use std::fmt::{Debug, Display, Formatter};
4use std::hint;
5
6use itertools::Itertools;
7use vortex_error::{VortexExpect, VortexResult, vortex_bail};
8use vortex_scalar::Scalar;
9
10use crate::Array;
11use crate::compute::scalar_at;
12use crate::encoding::Encoding;
13
/// Selects which boundary of a run of equal values a sorted search should report.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SearchSortedSide {
    /// Report the position before any elements equal to the searched value.
    Left,
    /// Report the position after all elements equal to the searched value.
    Right,
}

impl Display for SearchSortedSide {
    /// Writes the lowercase name of the side (`left` / `right`).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Left => "left",
            Self::Right => "right",
        };
        f.write_str(name)
    }
}
28
/// Outcome of performing `search_sorted` on an array.
///
/// Both variants carry a position in the array; see the [`SearchSortedFn`] documentation for
/// how that position should be interpreted.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SearchResult {
    /// The value was found in the array; another copy of it could be inserted at the given
    /// position while keeping the sorted order.
    Found(usize),

    /// The value was not found in the array, but it could be inserted at the given position
    /// while keeping the sorted order.
    NotFound(usize),
}
42
43impl SearchResult {
44    /// Convert search result to an index only if the value have been found
45    pub fn to_found(self) -> Option<usize> {
46        match self {
47            Self::Found(i) => Some(i),
48            Self::NotFound(_) => None,
49        }
50    }
51
52    /// Extract index out of search result regardless of whether the value have been found or not
53    pub fn to_index(self) -> usize {
54        match self {
55            Self::Found(i) => i,
56            Self::NotFound(i) => i,
57        }
58    }
59
60    /// Convert search result into an index suitable for searching array of offset indices, i.e. first element starts at 0.
61    ///
62    /// For example for a ChunkedArray with chunk offsets array [0, 3, 8, 10] you can use this method to
63    /// obtain index suitable for indexing into it after performing a search
64    pub fn to_offsets_index(self, len: usize, side: SearchSortedSide) -> usize {
65        match self {
66            SearchResult::Found(i) => {
67                if side == SearchSortedSide::Right || i == len {
68                    i.saturating_sub(1)
69                } else {
70                    i
71                }
72            }
73            SearchResult::NotFound(i) => i.saturating_sub(1),
74        }
75    }
76
77    /// Convert search result into an index suitable for searching array of end indices without 0 offset,
78    /// i.e. first element implicitly covers 0..0th-element range.
79    ///
80    /// For example for a RunEndArray with ends array [3, 8, 10], you can use this method to obtain index suitable for
81    /// indexing into it after performing a search
82    pub fn to_ends_index(self, len: usize) -> usize {
83        let idx = self.to_index();
84        if idx == len { idx - 1 } else { idx }
85    }
86}
87
88impl Display for SearchResult {
89    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
90        match self {
91            SearchResult::Found(i) => write!(f, "Found({i})"),
92            SearchResult::NotFound(i) => write!(f, "NotFound({i})"),
93        }
94    }
95}
96
97/// Searches for value assuming the array is sorted.
98///
99/// Returned indices satisfy following condition if the search for value was to be inserted into the array at found positions
100///
101/// |side |result satisfies|
102/// |-----|----------------|
103/// |left |array\[i-1\] < value <= array\[i\]|
104/// |right|array\[i-1\] <= value < array\[i\]|
105pub trait SearchSortedFn<A: Copy> {
106    fn search_sorted(
107        &self,
108        array: A,
109        value: &Scalar,
110        side: SearchSortedSide,
111    ) -> VortexResult<SearchResult>;
112
113    /// Bulk search for many values.
114    fn search_sorted_many(
115        &self,
116        array: A,
117        values: &[Scalar],
118        side: SearchSortedSide,
119    ) -> VortexResult<Vec<SearchResult>> {
120        values
121            .iter()
122            .map(|value| self.search_sorted(array, value, side))
123            .try_collect()
124    }
125}
126
127pub trait SearchSortedUsizeFn<A: Copy> {
128    fn search_sorted_usize(
129        &self,
130        array: A,
131        value: usize,
132        side: SearchSortedSide,
133    ) -> VortexResult<SearchResult>;
134
135    fn search_sorted_usize_many(
136        &self,
137        array: A,
138        values: &[usize],
139        side: SearchSortedSide,
140    ) -> VortexResult<Vec<SearchResult>> {
141        values
142            .iter()
143            .map(|&value| self.search_sorted_usize(array, value, side))
144            .try_collect()
145    }
146}
147
148impl<E: Encoding> SearchSortedFn<&dyn Array> for E
149where
150    E: for<'a> SearchSortedFn<&'a E::Array>,
151{
152    fn search_sorted(
153        &self,
154        array: &dyn Array,
155        value: &Scalar,
156        side: SearchSortedSide,
157    ) -> VortexResult<SearchResult> {
158        let array_ref = array
159            .as_any()
160            .downcast_ref::<E::Array>()
161            .vortex_expect("Failed to downcast array");
162        SearchSortedFn::search_sorted(self, array_ref, value, side)
163    }
164
165    fn search_sorted_many(
166        &self,
167        array: &dyn Array,
168        values: &[Scalar],
169        side: SearchSortedSide,
170    ) -> VortexResult<Vec<SearchResult>> {
171        let array_ref = array
172            .as_any()
173            .downcast_ref::<E::Array>()
174            .vortex_expect("Failed to downcast array");
175        SearchSortedFn::search_sorted_many(self, array_ref, values, side)
176    }
177}
178
179impl<E: Encoding> SearchSortedUsizeFn<&dyn Array> for E
180where
181    E: for<'a> SearchSortedUsizeFn<&'a E::Array>,
182{
183    fn search_sorted_usize(
184        &self,
185        array: &dyn Array,
186        value: usize,
187        side: SearchSortedSide,
188    ) -> VortexResult<SearchResult> {
189        let array_ref = array
190            .as_any()
191            .downcast_ref::<E::Array>()
192            .vortex_expect("Failed to downcast array");
193        SearchSortedUsizeFn::search_sorted_usize(self, array_ref, value, side)
194    }
195
196    fn search_sorted_usize_many(
197        &self,
198        array: &dyn Array,
199        values: &[usize],
200        side: SearchSortedSide,
201    ) -> VortexResult<Vec<SearchResult>> {
202        let array_ref = array
203            .as_any()
204            .downcast_ref::<E::Array>()
205            .vortex_expect("Failed to downcast array");
206        SearchSortedUsizeFn::search_sorted_usize_many(self, array_ref, values, side)
207    }
208}
209
210pub fn search_sorted<T: Into<Scalar>>(
211    array: &dyn Array,
212    target: T,
213    side: SearchSortedSide,
214) -> VortexResult<SearchResult> {
215    let Ok(scalar) = target.into().cast(array.dtype()) else {
216        // Try to downcast the usize ot the array type, if the downcast fails, then we know the
217        // usize is too large and the value is greater than the highest value in the array.
218        return Ok(SearchResult::NotFound(array.len()));
219    };
220
221    if scalar.is_null() {
222        vortex_bail!("Search sorted with null value is not supported");
223    }
224
225    if let Some(f) = array.vtable().search_sorted_fn() {
226        return f.search_sorted(array, &scalar, side);
227    }
228
229    // Fallback to a generic search_sorted using scalar_at
230    if array.vtable().scalar_at_fn().is_some() {
231        return Ok(SearchSorted::search_sorted(array, &scalar, side));
232    }
233
234    vortex_bail!(
235        NotImplemented: "search_sorted",
236        array.encoding()
237    )
238}
239
240pub fn search_sorted_usize(
241    array: &dyn Array,
242    target: usize,
243    side: SearchSortedSide,
244) -> VortexResult<SearchResult> {
245    if let Some(f) = array.vtable().search_sorted_usize_fn() {
246        return f.search_sorted_usize(array, target, side);
247    }
248
249    // Otherwise, convert the target into a scalar to try the search_sorted_fn
250    let Ok(target) = Scalar::from(target).cast(array.dtype()) else {
251        return Ok(SearchResult::NotFound(array.len()));
252    };
253
254    // Try the non-usize search sorted
255    if let Some(f) = array.vtable().search_sorted_fn() {
256        return f.search_sorted(array, &target, side);
257    }
258
259    // Or fallback all the way to a generic search_sorted using scalar_at
260    if array.vtable().scalar_at_fn().is_some() {
261        // Try to downcast the usize to the array type, if the downcast fails, then we know the
262        // usize is too large and the value is greater than the highest value in the array.
263        let Ok(target) = target.cast(array.dtype()) else {
264            return Ok(SearchResult::NotFound(array.len()));
265        };
266        return Ok(SearchSorted::search_sorted(array, &target, side));
267    }
268
269    vortex_bail!(
270    NotImplemented: "search_sorted_usize",
271        array.encoding()
272    )
273}
274
275/// Search for many elements in the array.
276pub fn search_sorted_many<T: Into<Scalar> + Clone>(
277    array: &dyn Array,
278    targets: &[T],
279    side: SearchSortedSide,
280) -> VortexResult<Vec<SearchResult>> {
281    if let Some(f) = array.vtable().search_sorted_fn() {
282        let mut too_big_cast_idxs = Vec::new();
283        let values = targets
284            .iter()
285            .cloned()
286            .enumerate()
287            .filter_map(|(i, t)| {
288                let Ok(c) = t.into().cast(array.dtype()) else {
289                    too_big_cast_idxs.push(i);
290                    return None;
291                };
292                Some(c)
293            })
294            .collect::<Vec<_>>();
295
296        let mut results = f.search_sorted_many(array, &values, side)?;
297        for too_big_idx in too_big_cast_idxs {
298            results.insert(too_big_idx, SearchResult::NotFound(array.len()));
299        }
300        return Ok(results);
301    }
302
303    // Call in loop and collect
304    targets
305        .iter()
306        .map(|target| search_sorted(array, target.clone(), side))
307        .try_collect()
308}
309
310// Native functions for each of the values, cast up to u64 or down to something lower.
311pub fn search_sorted_usize_many(
312    array: &dyn Array,
313    targets: &[usize],
314    side: SearchSortedSide,
315) -> VortexResult<Vec<SearchResult>> {
316    if let Some(f) = array.vtable().search_sorted_usize_fn() {
317        return f.search_sorted_usize_many(array, targets, side);
318    }
319
320    // Call in loop and collect
321    targets
322        .iter()
323        .map(|&target| search_sorted_usize(array, target, side))
324        .try_collect()
325}
326
/// Ordered comparison between the element stored at an index and an external value.
#[allow(clippy::len_without_is_empty)]
pub trait IndexOrd<V> {
    /// PartialOrd of the value at index `idx` with `elem`.
    /// For example, if self\[idx\] > elem, return Some(Greater).
    fn index_cmp(&self, idx: usize, elem: &V) -> Option<Ordering>;

    /// True only when self\[idx\] compares strictly less than `elem`.
    fn index_lt(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem).is_some_and(Ordering::is_lt)
    }

    /// True only when self\[idx\] compares less than or equal to `elem`.
    fn index_le(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem).is_some_and(Ordering::is_le)
    }

    /// True only when self\[idx\] compares strictly greater than `elem`.
    fn index_gt(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem).is_some_and(Ordering::is_gt)
    }

    /// True only when self\[idx\] compares greater than or equal to `elem`.
    fn index_ge(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem).is_some_and(Ordering::is_ge)
    }

    /// Get the length of the underlying ordered collection
    fn len(&self) -> usize;
}
352
353pub trait SearchSorted<T> {
354    fn search_sorted(&self, value: &T, side: SearchSortedSide) -> SearchResult
355    where
356        Self: IndexOrd<T>,
357    {
358        match side {
359            SearchSortedSide::Left => self.search_sorted_by(
360                |idx| self.index_cmp(idx, value).unwrap_or(Less),
361                |idx| {
362                    if self.index_lt(idx, value) {
363                        Less
364                    } else {
365                        Greater
366                    }
367                },
368                side,
369            ),
370            SearchSortedSide::Right => self.search_sorted_by(
371                |idx| self.index_cmp(idx, value).unwrap_or(Less),
372                |idx| {
373                    if self.index_le(idx, value) {
374                        Less
375                    } else {
376                        Greater
377                    }
378                },
379                side,
380            ),
381        }
382    }
383
384    /// find function is used to find the element if it exists, if element exists side_find will be
385    /// used to find desired index amongst equal values
386    fn search_sorted_by<F: FnMut(usize) -> Ordering, N: FnMut(usize) -> Ordering>(
387        &self,
388        find: F,
389        side_find: N,
390        side: SearchSortedSide,
391    ) -> SearchResult;
392}
393
394// Default implementation for types that implement IndexOrd.
395impl<S, T> SearchSorted<T> for S
396where
397    S: IndexOrd<T> + ?Sized,
398{
399    fn search_sorted_by<F: FnMut(usize) -> Ordering, N: FnMut(usize) -> Ordering>(
400        &self,
401        find: F,
402        side_find: N,
403        side: SearchSortedSide,
404    ) -> SearchResult {
405        match search_sorted_side_idx(find, 0, self.len()) {
406            SearchResult::Found(found) => {
407                let idx_search = match side {
408                    SearchSortedSide::Left => search_sorted_side_idx(side_find, 0, found),
409                    SearchSortedSide::Right => search_sorted_side_idx(side_find, found, self.len()),
410                };
411                match idx_search {
412                    SearchResult::NotFound(i) => SearchResult::Found(i),
413                    _ => unreachable!(
414                        "searching amongst equal values should never return Found result"
415                    ),
416                }
417            }
418            s => s,
419        }
420    }
421}
422
423// Code adapted from Rust standard library slice::binary_search_by
424fn search_sorted_side_idx<F: FnMut(usize) -> Ordering>(
425    mut find: F,
426    from: usize,
427    to: usize,
428) -> SearchResult {
429    let mut size = to - from;
430    if size == 0 {
431        return SearchResult::NotFound(0);
432    }
433    let mut base = from;
434
435    // This loop intentionally doesn't have an early exit if the comparison
436    // returns Equal. We want the number of loop iterations to depend *only*
437    // on the size of the input slice so that the CPU can reliably predict
438    // the loop count.
439    while size > 1 {
440        let half = size / 2;
441        let mid = base + half;
442
443        // SAFETY: the call is made safe by the following inconstants:
444        // - `mid >= 0`: by definition
445        // - `mid < size`: `mid = size / 2 + size / 4 + size / 8 ...`
446        let cmp = find(mid);
447
448        // Binary search interacts poorly with branch prediction, so force
449        // the compiler to use conditional moves if supported by the target
450        // architecture.
451        base = if cmp == Greater { base } else { mid };
452
453        // This is imprecise in the case where `size` is odd and the
454        // comparison returns Greater: the mid element still gets included
455        // by `size` even though it's known to be larger than the element
456        // being searched for.
457        //
458        // This is fine though: we gain more performance by keeping the
459        // loop iteration count invariant (and thus predictable) than we
460        // lose from considering one additional element.
461        size -= half;
462    }
463
464    // SAFETY: base is always in [0, size) because base <= mid.
465    let cmp = find(base);
466    if cmp == Equal {
467        // SAFETY: same as the call to `find` above.
468        unsafe { hint::assert_unchecked(base < to) };
469        SearchResult::Found(base)
470    } else {
471        let result = base + (cmp == Less) as usize;
472        // SAFETY: same as the call to `find` above.
473        // Note that this is `<=`, unlike the assert in the `Found` path.
474        unsafe { hint::assert_unchecked(result <= to) };
475        SearchResult::NotFound(result)
476    }
477}
478
479impl IndexOrd<Scalar> for dyn Array + '_ {
480    fn index_cmp(&self, idx: usize, elem: &Scalar) -> Option<Ordering> {
481        let scalar_a = scalar_at(self, idx).ok()?;
482        scalar_a.partial_cmp(elem)
483    }
484
485    fn len(&self) -> usize {
486        Self::len(self)
487    }
488}
489
490impl<T: PartialOrd> IndexOrd<T> for [T] {
491    fn index_cmp(&self, idx: usize, elem: &T) -> Option<Ordering> {
492        // SAFETY: Used in search_sorted_by same as the standard library. The search_sorted ensures idx is in bounds
493        unsafe { self.get_unchecked(idx) }.partial_cmp(elem)
494    }
495
496    fn len(&self) -> usize {
497        self.len()
498    }
499}
500
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::compute::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};
    use crate::compute::{search_sorted, search_sorted_many};

    // Run of equal values in the middle of the slice.

    #[test]
    fn left_side_equal() {
        let sorted = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
        let hit = sorted.search_sorted(&2, SearchSortedSide::Left);
        assert_eq!(sorted[hit.to_index()], 2);
        assert_eq!(hit, SearchResult::Found(2));
    }

    #[test]
    fn right_side_equal() {
        let sorted = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
        let hit = sorted.search_sorted(&2, SearchSortedSide::Right);
        assert_eq!(sorted[hit.to_index() - 1], 2);
        assert_eq!(hit, SearchResult::Found(6));
    }

    // Run of equal values at the start of the slice.

    #[test]
    fn left_side_equal_beginning() {
        let sorted = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let hit = sorted.search_sorted(&0, SearchSortedSide::Left);
        assert_eq!(sorted[hit.to_index()], 0);
        assert_eq!(hit, SearchResult::Found(0));
    }

    #[test]
    fn right_side_equal_beginning() {
        let sorted = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let hit = sorted.search_sorted(&0, SearchSortedSide::Right);
        assert_eq!(sorted[hit.to_index() - 1], 0);
        assert_eq!(hit, SearchResult::Found(4));
    }

    // Run of equal values at the end of the slice.

    #[test]
    fn left_side_equal_end() {
        let sorted = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
        let hit = sorted.search_sorted(&9, SearchSortedSide::Left);
        assert_eq!(sorted[hit.to_index()], 9);
        assert_eq!(hit, SearchResult::Found(9));
    }

    #[test]
    fn right_side_equal_end() {
        let sorted = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
        let hit = sorted.search_sorted(&9, SearchSortedSide::Right);
        assert_eq!(sorted[hit.to_index() - 1], 9);
        assert_eq!(hit, SearchResult::Found(13));
    }

    // Targets that do not fit the array's dtype are reported as missing, past the end.

    #[test]
    fn failed_cast() {
        let bytes = buffer![0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9].into_array();
        let outcome = search_sorted(&bytes, 256, SearchSortedSide::Left).unwrap();
        assert_eq!(outcome, SearchResult::NotFound(bytes.len()));
    }

    #[test]
    fn search_sorted_many_failed_cast() {
        let bytes = buffer![0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9].into_array();
        let outcomes = search_sorted_many(&bytes, &[256], SearchSortedSide::Left).unwrap();
        assert_eq!(outcomes, vec![SearchResult::NotFound(bytes.len())]);
    }
}