Skip to main content

vortex_array/
search_sorted.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::Ordering;
5use std::cmp::Ordering::Equal;
6use std::cmp::Ordering::Greater;
7use std::cmp::Ordering::Less;
8use std::fmt::Debug;
9use std::fmt::Display;
10use std::fmt::Formatter;
11use std::hint;
12
13use vortex_error::VortexResult;
14
15use crate::ArrayRef;
16use crate::LEGACY_SESSION;
17use crate::VortexSessionExecute;
18use crate::scalar::Scalar;
19
/// Which end of a run of equal values a sorted search reports.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum SearchSortedSide {
    /// Report the first position whose element is not less than the searched value.
    Left,
    /// Report the position just past the last element that is not greater than the searched value.
    Right,
}

impl Display for SearchSortedSide {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            SearchSortedSide::Left => write!(f, "left"),
            SearchSortedSide::Right => write!(f, "right"),
        }
    }
}

/// Result of performing `search_sorted` on an array.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum SearchResult {
    /// The searched value is present in the array. The index is the position selected by the
    /// search side (the first equal element for [`SearchSortedSide::Left`], one past the last
    /// equal element for [`SearchSortedSide::Right`]), at which another equal value could be
    /// inserted while keeping the array sorted.
    Found(usize),

    /// The searched value is not present in the array. The index is the position at which it
    /// could be inserted while keeping the array sorted.
    NotFound(usize),
}

impl SearchResult {
    /// Returns the index only if the value was found, `None` otherwise.
    pub fn to_found(self) -> Option<usize> {
        match self {
            Self::Found(i) => Some(i),
            Self::NotFound(_) => None,
        }
    }

    /// Returns the index carried by the result, whether or not the value was found.
    pub fn to_index(self) -> usize {
        match self {
            Self::Found(i) | Self::NotFound(i) => i,
        }
    }

    /// Converts the result into an index suitable for an array of start offsets, i.e. one whose
    /// first element is 0.
    ///
    /// For example, for a ChunkedArray with chunk offsets `[0, 3, 8, 10]` you can use this method
    /// to obtain an index suitable for indexing into it after performing a search.
    ///
    /// `len` is the length of the searched offsets array and `side` is the side the search was
    /// performed with. The result never underflows: an index of 0 maps to 0.
    pub fn to_offsets_index(self, len: usize, side: SearchSortedSide) -> usize {
        match self {
            // A left-side hit inside the array already points at the matching offset.
            Self::Found(i) if side == SearchSortedSide::Left && i != len => i,
            // Every other result points one past the offset covering the searched value.
            Self::Found(i) | Self::NotFound(i) => i.saturating_sub(1),
        }
    }

    /// Converts the result into an index suitable for an array of end indices without a leading
    /// 0, i.e. one whose first element implicitly covers the range `0..ends[0]`.
    ///
    /// For example, for a RunEndArray with ends `[3, 8, 10]` you can use this method to obtain an
    /// index suitable for indexing into it after performing a search.
    ///
    /// A result equal to `len` (past the last end) is clamped to the last element. For an empty
    /// ends array (`len == 0`) the result is 0 rather than an arithmetic underflow.
    pub fn to_ends_index(self, len: usize) -> usize {
        let idx = self.to_index();
        if idx == len {
            idx.saturating_sub(1)
        } else {
            idx
        }
    }
}

impl Display for SearchResult {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            SearchResult::Found(i) => write!(f, "Found({i})"),
            SearchResult::NotFound(i) => write!(f, "NotFound({i})"),
        }
    }
}
100
101pub trait IndexOrd<V> {
102    /// PartialOrd of the value at index `idx` with `elem`.
103    /// For example, if self\[idx\] > elem, return Some(Greater).
104    fn index_cmp(&self, idx: usize, elem: &V) -> VortexResult<Option<Ordering>>;
105
106    #[inline]
107    fn index_lt(&self, idx: usize, elem: &V) -> VortexResult<bool> {
108        Ok(matches!(self.index_cmp(idx, elem)?, Some(Less)))
109    }
110
111    #[inline]
112    fn index_le(&self, idx: usize, elem: &V) -> VortexResult<bool> {
113        Ok(matches!(self.index_cmp(idx, elem)?, Some(Less | Equal)))
114    }
115
116    fn index_gt(&self, idx: usize, elem: &V) -> VortexResult<bool> {
117        Ok(matches!(self.index_cmp(idx, elem)?, Some(Greater)))
118    }
119
120    fn index_ge(&self, idx: usize, elem: &V) -> VortexResult<bool> {
121        Ok(matches!(self.index_cmp(idx, elem)?, Some(Greater | Equal)))
122    }
123
124    /// Get the length of the underlying ordered collection
125    fn index_len(&self) -> usize;
126}
127
128/// Searches for value assuming the array is sorted.
129///
130/// Returned indices satisfy following condition if the search for value was to be inserted into the array at found positions
131///
132/// |side |result satisfies|
133/// |-----|----------------|
134/// |left |array\[i-1\] < value <= array\[i\]|
135/// |right|array\[i-1\] <= value < array\[i\]|
136pub trait SearchSorted<T> {
137    #[inline]
138    fn search_sorted(&self, value: &T, side: SearchSortedSide) -> VortexResult<SearchResult>
139    where
140        Self: IndexOrd<T>,
141    {
142        match side {
143            SearchSortedSide::Left => self.search_sorted_by(
144                |idx| Ok(self.index_cmp(idx, value)?.unwrap_or(Less)),
145                |idx| {
146                    Ok(if self.index_lt(idx, value)? {
147                        Less
148                    } else {
149                        Greater
150                    })
151                },
152                side,
153            ),
154            SearchSortedSide::Right => self.search_sorted_by(
155                |idx| Ok(self.index_cmp(idx, value)?.unwrap_or(Less)),
156                |idx| {
157                    Ok(if self.index_le(idx, value)? {
158                        Less
159                    } else {
160                        Greater
161                    })
162                },
163                side,
164            ),
165        }
166    }
167
168    /// find function is used to find the element if it exists, if element exists side_find will be
169    /// used to find desired index amongst equal values
170    fn search_sorted_by<
171        F: FnMut(usize) -> VortexResult<Ordering>,
172        N: FnMut(usize) -> VortexResult<Ordering>,
173    >(
174        &self,
175        find: F,
176        side_find: N,
177        side: SearchSortedSide,
178    ) -> VortexResult<SearchResult>;
179}
180
181// Default implementation for types that implement IndexOrd.
182impl<S, T> SearchSorted<T> for S
183where
184    S: IndexOrd<T> + ?Sized,
185{
186    #[inline]
187    fn search_sorted_by<
188        F: FnMut(usize) -> VortexResult<Ordering>,
189        N: FnMut(usize) -> VortexResult<Ordering>,
190    >(
191        &self,
192        find: F,
193        side_find: N,
194        side: SearchSortedSide,
195    ) -> VortexResult<SearchResult> {
196        match search_sorted_side_idx(find, 0, self.index_len())? {
197            SearchResult::Found(found) => {
198                let idx_search = match side {
199                    SearchSortedSide::Left => search_sorted_side_idx(side_find, 0, found)?,
200                    SearchSortedSide::Right => {
201                        search_sorted_side_idx(side_find, found, self.index_len())?
202                    }
203                };
204                match idx_search {
205                    SearchResult::NotFound(i) => Ok(SearchResult::Found(i)),
206                    _ => unreachable!(
207                        "searching amongst equal values should never return Found result"
208                    ),
209                }
210            }
211            s => Ok(s),
212        }
213    }
214}
215
216// Code adapted from Rust standard library slice::binary_search_by
217#[inline]
218fn search_sorted_side_idx<F: FnMut(usize) -> VortexResult<Ordering>>(
219    mut find: F,
220    from: usize,
221    to: usize,
222) -> VortexResult<SearchResult> {
223    let mut size = to - from;
224    if size == 0 {
225        return Ok(SearchResult::NotFound(0));
226    }
227    let mut base = from;
228
229    // This loop intentionally doesn't have an early exit if the comparison
230    // returns Equal. We want the number of loop iterations to depend *only*
231    // on the size of the input slice so that the CPU can reliably predict
232    // the loop count.
233    while size > 1 {
234        let half = size / 2;
235        let mid = base + half;
236
237        // SAFETY: the call is made safe by the following inconstants:
238        // - `mid >= 0`: by definition
239        // - `mid < size`: `mid = size / 2 + size / 4 + size / 8 ...`
240        let cmp = find(mid)?;
241
242        // Binary search interacts poorly with branch prediction, so force
243        // the compiler to use conditional moves if supported by the target
244        // architecture.
245        base = hint::select_unpredictable(cmp == Greater, base, mid);
246
247        // This is imprecise in the case where `size` is odd and the
248        // comparison returns Greater: the mid element still gets included
249        // by `size` even though it's known to be larger than the element
250        // being searched for.
251        // This is fine though: we gain more performance by keeping the
252        // loop iteration count invariant (and thus predictable) than we
253        // lose from considering one additional element.
254        size -= half;
255    }
256
257    // SAFETY: base is always in [0, size) because base <= mid.
258    let cmp = find(base)?;
259    if cmp == Equal {
260        // SAFETY: same as the call to `find` above.
261        unsafe { hint::assert_unchecked(base < to) };
262        Ok(SearchResult::Found(base))
263    } else {
264        let result = base + (cmp == Less) as usize;
265        // SAFETY: same as the call to `find` above.
266        // Note that this is `<=`, unlike the assert in the `Found` path.
267        unsafe { hint::assert_unchecked(result <= to) };
268        Ok(SearchResult::NotFound(result))
269    }
270}
271
272impl IndexOrd<Scalar> for ArrayRef {
273    fn index_cmp(&self, idx: usize, elem: &Scalar) -> VortexResult<Option<Ordering>> {
274        let scalar_a = self.execute_scalar(idx, &mut LEGACY_SESSION.create_execution_ctx())?;
275        Ok(scalar_a.partial_cmp(elem))
276    }
277
278    fn index_len(&self) -> usize {
279        Self::len(self)
280    }
281}
282
283impl<T: PartialOrd> IndexOrd<T> for [T] {
284    #[inline]
285    fn index_cmp(&self, idx: usize, elem: &T) -> VortexResult<Option<Ordering>> {
286        // SAFETY: Used in search_sorted_by same as the standard library. The search_sorted ensures idx is in bounds
287        Ok(unsafe { self.get_unchecked(idx) }.partial_cmp(elem))
288    }
289
290    #[inline]
291    fn index_len(&self) -> usize {
292        self.len()
293    }
294}
295
#[cfg(test)]
mod tests {
    use vortex_error::VortexResult;

    use crate::search_sorted::SearchResult;
    use crate::search_sorted::SearchSorted;
    use crate::search_sorted::SearchSortedSide;

    #[test]
    fn left_side_equal() -> VortexResult<()> {
        let arr = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
        let res = arr.search_sorted(&2, SearchSortedSide::Left)?;
        assert_eq!(arr[res.to_index()], 2);
        assert_eq!(res, SearchResult::Found(2));
        Ok(())
    }

    #[test]
    fn right_side_equal() -> VortexResult<()> {
        let arr = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
        let res = arr.search_sorted(&2, SearchSortedSide::Right)?;
        assert_eq!(arr[res.to_index() - 1], 2);
        assert_eq!(res, SearchResult::Found(6));
        Ok(())
    }

    #[test]
    fn left_side_equal_beginning() -> VortexResult<()> {
        let arr = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let res = arr.search_sorted(&0, SearchSortedSide::Left)?;
        assert_eq!(arr[res.to_index()], 0);
        assert_eq!(res, SearchResult::Found(0));
        Ok(())
    }

    #[test]
    fn right_side_equal_beginning() -> VortexResult<()> {
        let arr = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let res = arr.search_sorted(&0, SearchSortedSide::Right)?;
        assert_eq!(arr[res.to_index() - 1], 0);
        assert_eq!(res, SearchResult::Found(4));
        Ok(())
    }

    #[test]
    fn left_side_equal_end() -> VortexResult<()> {
        let arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
        let res = arr.search_sorted(&9, SearchSortedSide::Left)?;
        assert_eq!(arr[res.to_index()], 9);
        assert_eq!(res, SearchResult::Found(9));
        Ok(())
    }

    #[test]
    fn right_side_equal_end() -> VortexResult<()> {
        let arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
        let res = arr.search_sorted(&9, SearchSortedSide::Right)?;
        assert_eq!(arr[res.to_index() - 1], 9);
        assert_eq!(res, SearchResult::Found(13));
        Ok(())
    }

    #[test]
    fn not_found_between_elements() -> VortexResult<()> {
        let arr = [1, 3, 5, 7];
        assert_eq!(arr.search_sorted(&4, SearchSortedSide::Left)?, SearchResult::NotFound(2));
        assert_eq!(arr.search_sorted(&4, SearchSortedSide::Right)?, SearchResult::NotFound(2));
        Ok(())
    }

    #[test]
    fn not_found_outside_range() -> VortexResult<()> {
        let arr = [1, 2, 3];
        assert_eq!(arr.search_sorted(&0, SearchSortedSide::Left)?, SearchResult::NotFound(0));
        assert_eq!(arr.search_sorted(&10, SearchSortedSide::Right)?, SearchResult::NotFound(3));
        Ok(())
    }

    #[test]
    fn empty_slice() -> VortexResult<()> {
        let arr: [i32; 0] = [];
        assert_eq!(arr.search_sorted(&1, SearchSortedSide::Left)?, SearchResult::NotFound(0));
        assert_eq!(arr.search_sorted(&1, SearchSortedSide::Right)?, SearchResult::NotFound(0));
        Ok(())
    }

    #[test]
    fn matches_partition_point() -> VortexResult<()> {
        let arr = [1, 1, 2, 4, 4, 4, 7, 9];
        for value in 0..=10 {
            let left = arr.search_sorted(&value, SearchSortedSide::Left)?;
            let right = arr.search_sorted(&value, SearchSortedSide::Right)?;
            assert_eq!(left.to_index(), arr.partition_point(|&x| x < value));
            assert_eq!(right.to_index(), arr.partition_point(|&x| x <= value));
            let present = arr.contains(&value);
            assert_eq!(left.to_found().is_some(), present);
            assert_eq!(right.to_found().is_some(), present);
        }
        Ok(())
    }

    #[test]
    fn offsets_and_ends_indices() -> VortexResult<()> {
        let offsets = [0, 3, 8, 10];
        let res = offsets.search_sorted(&5, SearchSortedSide::Right)?;
        assert_eq!(res.to_offsets_index(offsets.len(), SearchSortedSide::Right), 1);
        let res = offsets.search_sorted(&3, SearchSortedSide::Right)?;
        assert_eq!(res.to_offsets_index(offsets.len(), SearchSortedSide::Right), 1);

        let ends = [3, 8, 10];
        let res = ends.search_sorted(&5, SearchSortedSide::Right)?;
        assert_eq!(res.to_ends_index(ends.len()), 1);
        let res = ends.search_sorted(&3, SearchSortedSide::Right)?;
        assert_eq!(res.to_ends_index(ends.len()), 1);
        let res = ends.search_sorted(&9, SearchSortedSide::Right)?;
        assert_eq!(res.to_ends_index(ends.len()), 2);
        Ok(())
    }
}