vortex_array/
search_sorted.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::Ordering;
5use std::cmp::Ordering::Equal;
6use std::cmp::Ordering::Greater;
7use std::cmp::Ordering::Less;
8use std::fmt::Debug;
9use std::fmt::Display;
10use std::fmt::Formatter;
11use std::hint;
12
13use vortex_scalar::Scalar;
14
15use crate::Array;
16
/// Which end of a run of equal elements a sorted search resolves to.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum SearchSortedSide {
    /// The first position at which the value could be inserted, i.e. before any equal elements
    /// (`array[i-1] < value <= array[i]`).
    Left,
    /// The last position at which the value could be inserted, i.e. after any equal elements
    /// (`array[i-1] <= value < array[i]`).
    Right,
}

impl Display for SearchSortedSide {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            SearchSortedSide::Left => write!(f, "left"),
            SearchSortedSide::Right => write!(f, "right"),
        }
    }
}

/// Result of performing search_sorted on an Array
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum SearchResult {
    /// The value was found in the array. The index is the position at which another equal value
    /// could be inserted while keeping the array sorted; which end of the run of equal values it
    /// refers to depends on the [`SearchSortedSide`] used for the search.
    Found(usize),

    /// The value was not found. The index is the position at which it could be inserted while
    /// keeping the array sorted.
    NotFound(usize),
}

impl SearchResult {
    /// Returns the index only if the value has been found, `None` otherwise.
    pub fn to_found(self) -> Option<usize> {
        match self {
            Self::Found(i) => Some(i),
            Self::NotFound(_) => None,
        }
    }

    /// Returns the index regardless of whether the value has been found or not.
    pub fn to_index(self) -> usize {
        match self {
            Self::Found(i) | Self::NotFound(i) => i,
        }
    }

    /// Converts the search result into an index suitable for searching an array of offset
    /// indices, i.e. one whose first element starts at 0.
    ///
    /// For example, for a ChunkedArray with chunk offsets array [0, 3, 8, 10] you can use this
    /// method to obtain an index suitable for indexing into it after performing a search.
    pub fn to_offsets_index(self, len: usize, side: SearchSortedSide) -> usize {
        match self {
            SearchResult::Found(i) => {
                if side == SearchSortedSide::Right || i == len {
                    i.saturating_sub(1)
                } else {
                    i
                }
            }
            SearchResult::NotFound(i) => i.saturating_sub(1),
        }
    }

    /// Converts the search result into an index suitable for searching an array of end indices
    /// without a 0 offset, i.e. one whose first element implicitly covers the `0..ends[0]` range.
    ///
    /// For example, for a RunEndArray with ends array [3, 8, 10], you can use this method to
    /// obtain an index suitable for indexing into it after performing a search.
    ///
    /// An index equal to `len` (past the last end) is clamped to the last element. For an empty
    /// ends array (`len == 0`) this returns 0 rather than underflowing.
    pub fn to_ends_index(self, len: usize) -> usize {
        let idx = self.to_index();
        // `saturating_sub` keeps `len == 0` from panicking (debug) or wrapping to usize::MAX (release).
        if idx == len { idx.saturating_sub(1) } else { idx }
    }
}

impl Display for SearchResult {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            SearchResult::Found(i) => write!(f, "Found({i})"),
            SearchResult::NotFound(i) => write!(f, "NotFound({i})"),
        }
    }
}
97
/// Random-access ordering comparisons against the elements of an ordered collection.
///
/// Implementors supply [`IndexOrd::index_cmp`] and [`IndexOrd::index_len`]. The comparison
/// helpers are all derived from `index_cmp` and return `false` whenever the two values are
/// incomparable (`index_cmp` returns `None`).
pub trait IndexOrd<V> {
    /// Partially compares the element at index `idx` with `elem`.
    ///
    /// For example, if self\[idx\] > elem, this returns `Some(Greater)`.
    fn index_cmp(&self, idx: usize, elem: &V) -> Option<Ordering>;

    /// `true` when self\[idx\] < elem.
    fn index_lt(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem) == Some(Less)
    }

    /// `true` when self\[idx\] <= elem.
    fn index_le(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem).is_some_and(Ordering::is_le)
    }

    /// `true` when self\[idx\] > elem.
    fn index_gt(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem) == Some(Greater)
    }

    /// `true` when self\[idx\] >= elem.
    fn index_ge(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem).is_some_and(Ordering::is_ge)
    }

    /// Number of elements in the underlying ordered collection.
    fn index_len(&self) -> usize;
}
122
123/// Searches for value assuming the array is sorted.
124///
125/// Returned indices satisfy following condition if the search for value was to be inserted into the array at found positions
126///
127/// |side |result satisfies|
128/// |-----|----------------|
129/// |left |array\[i-1\] < value <= array\[i\]|
130/// |right|array\[i-1\] <= value < array\[i\]|
131pub trait SearchSorted<T> {
132    fn search_sorted_many<I: IntoIterator<Item = T>>(
133        &self,
134        values: I,
135        side: SearchSortedSide,
136    ) -> impl Iterator<Item = SearchResult>
137    where
138        Self: IndexOrd<T>,
139    {
140        values
141            .into_iter()
142            .map(move |value| self.search_sorted(&value, side))
143    }
144
145    fn search_sorted(&self, value: &T, side: SearchSortedSide) -> SearchResult
146    where
147        Self: IndexOrd<T>,
148    {
149        match side {
150            SearchSortedSide::Left => self.search_sorted_by(
151                |idx| self.index_cmp(idx, value).unwrap_or(Less),
152                |idx| {
153                    if self.index_lt(idx, value) {
154                        Less
155                    } else {
156                        Greater
157                    }
158                },
159                side,
160            ),
161            SearchSortedSide::Right => self.search_sorted_by(
162                |idx| self.index_cmp(idx, value).unwrap_or(Less),
163                |idx| {
164                    if self.index_le(idx, value) {
165                        Less
166                    } else {
167                        Greater
168                    }
169                },
170                side,
171            ),
172        }
173    }
174
175    /// find function is used to find the element if it exists, if element exists side_find will be
176    /// used to find desired index amongst equal values
177    fn search_sorted_by<F: FnMut(usize) -> Ordering, N: FnMut(usize) -> Ordering>(
178        &self,
179        find: F,
180        side_find: N,
181        side: SearchSortedSide,
182    ) -> SearchResult;
183}
184
185// Default implementation for types that implement IndexOrd.
186impl<S, T> SearchSorted<T> for S
187where
188    S: IndexOrd<T> + ?Sized,
189{
190    fn search_sorted_by<F: FnMut(usize) -> Ordering, N: FnMut(usize) -> Ordering>(
191        &self,
192        find: F,
193        side_find: N,
194        side: SearchSortedSide,
195    ) -> SearchResult {
196        match search_sorted_side_idx(find, 0, self.index_len()) {
197            SearchResult::Found(found) => {
198                let idx_search = match side {
199                    SearchSortedSide::Left => search_sorted_side_idx(side_find, 0, found),
200                    SearchSortedSide::Right => {
201                        search_sorted_side_idx(side_find, found, self.index_len())
202                    }
203                };
204                match idx_search {
205                    SearchResult::NotFound(i) => SearchResult::Found(i),
206                    _ => unreachable!(
207                        "searching amongst equal values should never return Found result"
208                    ),
209                }
210            }
211            s => s,
212        }
213    }
214}
215
216// Code adapted from Rust standard library slice::binary_search_by
217fn search_sorted_side_idx<F: FnMut(usize) -> Ordering>(
218    mut find: F,
219    from: usize,
220    to: usize,
221) -> SearchResult {
222    let mut size = to - from;
223    if size == 0 {
224        return SearchResult::NotFound(0);
225    }
226    let mut base = from;
227
228    // This loop intentionally doesn't have an early exit if the comparison
229    // returns Equal. We want the number of loop iterations to depend *only*
230    // on the size of the input slice so that the CPU can reliably predict
231    // the loop count.
232    while size > 1 {
233        let half = size / 2;
234        let mid = base + half;
235
236        // SAFETY: the call is made safe by the following inconstants:
237        // - `mid >= 0`: by definition
238        // - `mid < size`: `mid = size / 2 + size / 4 + size / 8 ...`
239        let cmp = find(mid);
240
241        // Binary search interacts poorly with branch prediction, so force
242        // the compiler to use conditional moves if supported by the target
243        // architecture.
244        base = if cmp == Greater { base } else { mid };
245
246        // This is imprecise in the case where `size` is odd and the
247        // comparison returns Greater: the mid element still gets included
248        // by `size` even though it's known to be larger than the element
249        // being searched for.
250        // This is fine though: we gain more performance by keeping the
251        // loop iteration count invariant (and thus predictable) than we
252        // lose from considering one additional element.
253        size -= half;
254    }
255
256    // SAFETY: base is always in [0, size) because base <= mid.
257    let cmp = find(base);
258    if cmp == Equal {
259        // SAFETY: same as the call to `find` above.
260        unsafe { hint::assert_unchecked(base < to) };
261        SearchResult::Found(base)
262    } else {
263        let result = base + (cmp == Less) as usize;
264        // SAFETY: same as the call to `find` above.
265        // Note that this is `<=`, unlike the assert in the `Found` path.
266        unsafe { hint::assert_unchecked(result <= to) };
267        SearchResult::NotFound(result)
268    }
269}
270
271impl IndexOrd<Scalar> for dyn Array + '_ {
272    fn index_cmp(&self, idx: usize, elem: &Scalar) -> Option<Ordering> {
273        let scalar_a = self.scalar_at(idx);
274        scalar_a.partial_cmp(elem)
275    }
276
277    fn index_len(&self) -> usize {
278        Self::len(self)
279    }
280}
281
282impl<T: PartialOrd> IndexOrd<T> for [T] {
283    fn index_cmp(&self, idx: usize, elem: &T) -> Option<Ordering> {
284        // SAFETY: Used in search_sorted_by same as the standard library. The search_sorted ensures idx is in bounds
285        unsafe { self.get_unchecked(idx) }.partial_cmp(elem)
286    }
287
288    fn index_len(&self) -> usize {
289        self.len()
290    }
291}
292
#[cfg(test)]
mod test {
    use crate::search_sorted::SearchResult;
    use crate::search_sorted::SearchSorted;
    use crate::search_sorted::SearchSortedSide;

    #[test]
    fn left_side_equal() {
        let arr = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
        let res = arr.search_sorted(&2, SearchSortedSide::Left);
        assert_eq!(arr[res.to_index()], 2);
        assert_eq!(res, SearchResult::Found(2));
    }

    #[test]
    fn right_side_equal() {
        let arr = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
        let res = arr.search_sorted(&2, SearchSortedSide::Right);
        assert_eq!(arr[res.to_index() - 1], 2);
        assert_eq!(res, SearchResult::Found(6));
    }

    #[test]
    fn left_side_equal_beginning() {
        let arr = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let res = arr.search_sorted(&0, SearchSortedSide::Left);
        assert_eq!(arr[res.to_index()], 0);
        assert_eq!(res, SearchResult::Found(0));
    }

    #[test]
    fn right_side_equal_beginning() {
        let arr = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let res = arr.search_sorted(&0, SearchSortedSide::Right);
        assert_eq!(arr[res.to_index() - 1], 0);
        assert_eq!(res, SearchResult::Found(4));
    }

    #[test]
    fn left_side_equal_end() {
        let arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
        let res = arr.search_sorted(&9, SearchSortedSide::Left);
        assert_eq!(arr[res.to_index()], 9);
        assert_eq!(res, SearchResult::Found(9));
    }

    #[test]
    fn right_side_equal_end() {
        let arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
        let res = arr.search_sorted(&9, SearchSortedSide::Right);
        assert_eq!(arr[res.to_index() - 1], 9);
        assert_eq!(res, SearchResult::Found(13));
    }

    #[test]
    fn not_found_middle() {
        let arr = [1, 3, 5, 7];
        assert_eq!(
            arr.search_sorted(&4, SearchSortedSide::Left),
            SearchResult::NotFound(2)
        );
        assert_eq!(
            arr.search_sorted(&4, SearchSortedSide::Right),
            SearchResult::NotFound(2)
        );
    }

    #[test]
    fn not_found_outside_range() {
        let arr = [1, 3, 5, 7];
        assert_eq!(
            arr.search_sorted(&0, SearchSortedSide::Left),
            SearchResult::NotFound(0)
        );
        assert_eq!(
            arr.search_sorted(&8, SearchSortedSide::Right),
            SearchResult::NotFound(4)
        );
    }

    #[test]
    fn empty() {
        let arr: [i32; 0] = [];
        assert_eq!(
            arr.search_sorted(&1, SearchSortedSide::Left),
            SearchResult::NotFound(0)
        );
        assert_eq!(
            arr.search_sorted(&1, SearchSortedSide::Right),
            SearchResult::NotFound(0)
        );
    }

    #[test]
    fn search_many() {
        let arr = [1, 3, 3, 5];
        let results: Vec<SearchResult> = arr
            .search_sorted_many([0, 3, 6], SearchSortedSide::Left)
            .collect();
        assert_eq!(
            results,
            vec![
                SearchResult::NotFound(0),
                SearchResult::Found(1),
                SearchResult::NotFound(4)
            ]
        );
    }

    #[test]
    fn offsets_index_chunk_lookup() {
        let offsets = [0usize, 3, 8, 10];
        let chunk_of = |row: usize| -> usize {
            offsets
                .search_sorted(&row, SearchSortedSide::Right)
                .to_offsets_index(offsets.len(), SearchSortedSide::Right)
        };
        assert_eq!(chunk_of(0), 0);
        assert_eq!(chunk_of(3), 1);
        assert_eq!(chunk_of(5), 1);
        assert_eq!(chunk_of(9), 2);
    }
}