vortex_array/
search_sorted.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::Ordering;
5use std::cmp::Ordering::{Equal, Greater, Less};
6use std::fmt::{Debug, Display, Formatter};
7use std::hint;
8
9use vortex_scalar::Scalar;
10
11use crate::Array;
12
/// Which end of a run of equal elements a sorted search reports.
///
/// With [`SearchSortedSide::Left`] the search yields the insertion point *before* any elements
/// equal to the searched value; with [`SearchSortedSide::Right`] it yields the insertion point
/// *after* them.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum SearchSortedSide {
    /// Insertion point before any elements equal to the searched value.
    Left,
    /// Insertion point after any elements equal to the searched value.
    Right,
}

impl Display for SearchSortedSide {
    /// Renders the side as the lowercase word `left` or `right`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Left => "left",
            Self::Right => "right",
        };
        f.write_str(label)
    }
}
27
28/// Result of performing search_sorted on an Array
29#[derive(Debug, Copy, Clone, PartialEq, Eq)]
30pub enum SearchResult {
31    /// Result for a found element was found in the array and another one could be inserted at the given position
32    /// in the sorted order
33    Found(usize),
34
35    /// Result for an element not found, but that could be inserted at the given position
36    /// in the sorted order.
37    NotFound(usize),
38}
39
40impl SearchResult {
41    /// Convert search result to an index only if the value have been found
42    pub fn to_found(self) -> Option<usize> {
43        match self {
44            Self::Found(i) => Some(i),
45            Self::NotFound(_) => None,
46        }
47    }
48
49    /// Extract index out of search result regardless of whether the value have been found or not
50    pub fn to_index(self) -> usize {
51        match self {
52            Self::Found(i) => i,
53            Self::NotFound(i) => i,
54        }
55    }
56
57    /// Convert search result into an index suitable for searching array of offset indices, i.e. first element starts at 0.
58    ///
59    /// For example for a ChunkedArray with chunk offsets array [0, 3, 8, 10] you can use this method to
60    /// obtain index suitable for indexing into it after performing a search
61    pub fn to_offsets_index(self, len: usize, side: SearchSortedSide) -> usize {
62        match self {
63            SearchResult::Found(i) => {
64                if side == SearchSortedSide::Right || i == len {
65                    i.saturating_sub(1)
66                } else {
67                    i
68                }
69            }
70            SearchResult::NotFound(i) => i.saturating_sub(1),
71        }
72    }
73
74    /// Convert search result into an index suitable for searching array of end indices without 0 offset,
75    /// i.e. first element implicitly covers 0..0th-element range.
76    ///
77    /// For example for a RunEndArray with ends array [3, 8, 10], you can use this method to obtain index suitable for
78    /// indexing into it after performing a search
79    pub fn to_ends_index(self, len: usize) -> usize {
80        let idx = self.to_index();
81        if idx == len { idx - 1 } else { idx }
82    }
83}
84
85impl Display for SearchResult {
86    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
87        match self {
88            SearchResult::Found(i) => write!(f, "Found({i})"),
89            SearchResult::NotFound(i) => write!(f, "NotFound({i})"),
90        }
91    }
92}
93
/// Positional comparison against an ordered collection.
///
/// Implementors supply [`IndexOrd::index_cmp`] and [`IndexOrd::index_len`]; the remaining
/// predicates are derived from `index_cmp`. Every derived predicate is `false` when the element
/// and `elem` are incomparable (`index_cmp` returns `None`).
pub trait IndexOrd<V> {
    /// Partially compares the element at position `idx` with `elem`.
    ///
    /// For example, returns `Some(Greater)` when `self[idx] > elem`.
    fn index_cmp(&self, idx: usize, elem: &V) -> Option<Ordering>;

    /// Returns `true` if `self[idx] < elem`.
    fn index_lt(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem).is_some_and(Ordering::is_lt)
    }

    /// Returns `true` if `self[idx] <= elem`.
    fn index_le(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem).is_some_and(Ordering::is_le)
    }

    /// Returns `true` if `self[idx] > elem`.
    fn index_gt(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem).is_some_and(Ordering::is_gt)
    }

    /// Returns `true` if `self[idx] >= elem`.
    fn index_ge(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem).is_some_and(Ordering::is_ge)
    }

    /// Returns the number of elements in the underlying ordered collection.
    fn index_len(&self) -> usize;
}
118
119/// Searches for value assuming the array is sorted.
120///
121/// Returned indices satisfy following condition if the search for value was to be inserted into the array at found positions
122///
123/// |side |result satisfies|
124/// |-----|----------------|
125/// |left |array\[i-1\] < value <= array\[i\]|
126/// |right|array\[i-1\] <= value < array\[i\]|
127pub trait SearchSorted<T> {
128    fn search_sorted_many<I: IntoIterator<Item = T>>(
129        &self,
130        values: I,
131        side: SearchSortedSide,
132    ) -> impl Iterator<Item = SearchResult>
133    where
134        Self: IndexOrd<T>,
135    {
136        values
137            .into_iter()
138            .map(move |value| self.search_sorted(&value, side))
139    }
140
141    fn search_sorted(&self, value: &T, side: SearchSortedSide) -> SearchResult
142    where
143        Self: IndexOrd<T>,
144    {
145        match side {
146            SearchSortedSide::Left => self.search_sorted_by(
147                |idx| self.index_cmp(idx, value).unwrap_or(Less),
148                |idx| {
149                    if self.index_lt(idx, value) {
150                        Less
151                    } else {
152                        Greater
153                    }
154                },
155                side,
156            ),
157            SearchSortedSide::Right => self.search_sorted_by(
158                |idx| self.index_cmp(idx, value).unwrap_or(Less),
159                |idx| {
160                    if self.index_le(idx, value) {
161                        Less
162                    } else {
163                        Greater
164                    }
165                },
166                side,
167            ),
168        }
169    }
170
171    /// find function is used to find the element if it exists, if element exists side_find will be
172    /// used to find desired index amongst equal values
173    fn search_sorted_by<F: FnMut(usize) -> Ordering, N: FnMut(usize) -> Ordering>(
174        &self,
175        find: F,
176        side_find: N,
177        side: SearchSortedSide,
178    ) -> SearchResult;
179}
180
181// Default implementation for types that implement IndexOrd.
182impl<S, T> SearchSorted<T> for S
183where
184    S: IndexOrd<T> + ?Sized,
185{
186    fn search_sorted_by<F: FnMut(usize) -> Ordering, N: FnMut(usize) -> Ordering>(
187        &self,
188        find: F,
189        side_find: N,
190        side: SearchSortedSide,
191    ) -> SearchResult {
192        match search_sorted_side_idx(find, 0, self.index_len()) {
193            SearchResult::Found(found) => {
194                let idx_search = match side {
195                    SearchSortedSide::Left => search_sorted_side_idx(side_find, 0, found),
196                    SearchSortedSide::Right => {
197                        search_sorted_side_idx(side_find, found, self.index_len())
198                    }
199                };
200                match idx_search {
201                    SearchResult::NotFound(i) => SearchResult::Found(i),
202                    _ => unreachable!(
203                        "searching amongst equal values should never return Found result"
204                    ),
205                }
206            }
207            s => s,
208        }
209    }
210}
211
212// Code adapted from Rust standard library slice::binary_search_by
213fn search_sorted_side_idx<F: FnMut(usize) -> Ordering>(
214    mut find: F,
215    from: usize,
216    to: usize,
217) -> SearchResult {
218    let mut size = to - from;
219    if size == 0 {
220        return SearchResult::NotFound(0);
221    }
222    let mut base = from;
223
224    // This loop intentionally doesn't have an early exit if the comparison
225    // returns Equal. We want the number of loop iterations to depend *only*
226    // on the size of the input slice so that the CPU can reliably predict
227    // the loop count.
228    while size > 1 {
229        let half = size / 2;
230        let mid = base + half;
231
232        // SAFETY: the call is made safe by the following inconstants:
233        // - `mid >= 0`: by definition
234        // - `mid < size`: `mid = size / 2 + size / 4 + size / 8 ...`
235        let cmp = find(mid);
236
237        // Binary search interacts poorly with branch prediction, so force
238        // the compiler to use conditional moves if supported by the target
239        // architecture.
240        base = if cmp == Greater { base } else { mid };
241
242        // This is imprecise in the case where `size` is odd and the
243        // comparison returns Greater: the mid element still gets included
244        // by `size` even though it's known to be larger than the element
245        // being searched for.
246        // This is fine though: we gain more performance by keeping the
247        // loop iteration count invariant (and thus predictable) than we
248        // lose from considering one additional element.
249        size -= half;
250    }
251
252    // SAFETY: base is always in [0, size) because base <= mid.
253    let cmp = find(base);
254    if cmp == Equal {
255        // SAFETY: same as the call to `find` above.
256        unsafe { hint::assert_unchecked(base < to) };
257        SearchResult::Found(base)
258    } else {
259        let result = base + (cmp == Less) as usize;
260        // SAFETY: same as the call to `find` above.
261        // Note that this is `<=`, unlike the assert in the `Found` path.
262        unsafe { hint::assert_unchecked(result <= to) };
263        SearchResult::NotFound(result)
264    }
265}
266
267impl IndexOrd<Scalar> for dyn Array + '_ {
268    fn index_cmp(&self, idx: usize, elem: &Scalar) -> Option<Ordering> {
269        let scalar_a = self.scalar_at(idx);
270        scalar_a.partial_cmp(elem)
271    }
272
273    fn index_len(&self) -> usize {
274        Self::len(self)
275    }
276}
277
278impl<T: PartialOrd> IndexOrd<T> for [T] {
279    fn index_cmp(&self, idx: usize, elem: &T) -> Option<Ordering> {
280        // SAFETY: Used in search_sorted_by same as the standard library. The search_sorted ensures idx is in bounds
281        unsafe { self.get_unchecked(idx) }.partial_cmp(elem)
282    }
283
284    fn index_len(&self) -> usize {
285        self.len()
286    }
287}
288
#[cfg(test)]
mod test {
    use crate::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};

    #[test]
    fn left_side_equal() {
        let arr = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
        let res = arr.search_sorted(&2, SearchSortedSide::Left);
        assert_eq!(arr[res.to_index()], 2);
        assert_eq!(res, SearchResult::Found(2));
    }

    #[test]
    fn right_side_equal() {
        let arr = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
        let res = arr.search_sorted(&2, SearchSortedSide::Right);
        assert_eq!(arr[res.to_index() - 1], 2);
        assert_eq!(res, SearchResult::Found(6));
    }

    #[test]
    fn left_side_equal_beginning() {
        let arr = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let res = arr.search_sorted(&0, SearchSortedSide::Left);
        assert_eq!(arr[res.to_index()], 0);
        assert_eq!(res, SearchResult::Found(0));
    }

    #[test]
    fn right_side_equal_beginning() {
        let arr = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let res = arr.search_sorted(&0, SearchSortedSide::Right);
        assert_eq!(arr[res.to_index() - 1], 0);
        assert_eq!(res, SearchResult::Found(4));
    }

    #[test]
    fn left_side_equal_end() {
        let arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
        let res = arr.search_sorted(&9, SearchSortedSide::Left);
        assert_eq!(arr[res.to_index()], 9);
        assert_eq!(res, SearchResult::Found(9));
    }

    #[test]
    fn right_side_equal_end() {
        let arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
        let res = arr.search_sorted(&9, SearchSortedSide::Right);
        assert_eq!(arr[res.to_index() - 1], 9);
        assert_eq!(res, SearchResult::Found(13));
    }

    /// A missing value between two elements reports the same insertion point for both sides.
    #[test]
    fn not_found_between_elements() {
        let arr = [1, 3, 5, 7];
        assert_eq!(
            arr.search_sorted(&4, SearchSortedSide::Left),
            SearchResult::NotFound(2)
        );
        assert_eq!(
            arr.search_sorted(&4, SearchSortedSide::Right),
            SearchResult::NotFound(2)
        );
    }

    /// Values below the first or above the last element insert at the array boundaries.
    #[test]
    fn not_found_outside_range() {
        let arr = [1, 3, 5, 7];
        assert_eq!(
            arr.search_sorted(&0, SearchSortedSide::Left),
            SearchResult::NotFound(0)
        );
        assert_eq!(
            arr.search_sorted(&8, SearchSortedSide::Right),
            SearchResult::NotFound(4)
        );
    }

    /// Searching an empty array always yields `NotFound(0)`.
    #[test]
    fn empty_slice() {
        let arr: [i32; 0] = [];
        assert_eq!(
            arr.search_sorted(&1, SearchSortedSide::Left),
            SearchResult::NotFound(0)
        );
        assert_eq!(
            arr.search_sorted(&1, SearchSortedSide::Right),
            SearchResult::NotFound(0)
        );
    }

    /// `search_sorted_many` yields one result per input value, in input order.
    #[test]
    fn many_values_in_order() {
        let arr = [1, 2, 2, 3];
        let results: Vec<_> = arr
            .search_sorted_many([2, 4, 0], SearchSortedSide::Right)
            .collect();
        assert_eq!(
            results,
            vec![
                SearchResult::Found(3),
                SearchResult::NotFound(4),
                SearchResult::NotFound(0),
            ]
        );
    }
}