Skip to main content

vortex_array/
search_sorted.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::Ordering;
5use std::cmp::Ordering::Equal;
6use std::cmp::Ordering::Greater;
7use std::cmp::Ordering::Less;
8use std::fmt::Debug;
9use std::fmt::Display;
10use std::fmt::Formatter;
11use std::hint;
12
13use vortex_error::VortexResult;
14
15use crate::ArrayRef;
16use crate::LEGACY_SESSION;
17use crate::VortexSessionExecute;
18use crate::scalar::Scalar;
19
/// Which boundary of a run of values equal to the searched value a sorted search reports.
///
/// For a sorted `array` and searched `value`, the reported index `i` satisfies
/// `array[i - 1] < value <= array[i]` for `Left` and `array[i - 1] <= value < array[i]` for
/// `Right`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum SearchSortedSide {
    /// Report the first index whose element is not less than the searched value.
    Left,
    /// Report the first index whose element is greater than the searched value.
    Right,
}

impl Display for SearchSortedSide {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            SearchSortedSide::Left => write!(f, "left"),
            SearchSortedSide::Right => write!(f, "right"),
        }
    }
}

/// Result of performing search_sorted on an Array.
///
/// Both variants carry the position at which the searched value could be inserted while keeping
/// the array sorted; they differ only in whether an equal element was present.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum SearchResult {
    /// An element equal to the searched value is present in the array. As produced by
    /// `search_sorted`, the index is the first equal element for [`SearchSortedSide::Left`] and
    /// one past the last equal element for [`SearchSortedSide::Right`].
    Found(usize),

    /// No equal element is present; the index is the position at which the value could be
    /// inserted to keep the array sorted.
    NotFound(usize),
}

impl SearchResult {
    /// Returns the index only if an equal element was found.
    pub fn to_found(self) -> Option<usize> {
        match self {
            Self::Found(i) => Some(i),
            Self::NotFound(_) => None,
        }
    }

    /// Returns the index regardless of whether an equal element was found.
    pub fn to_index(self) -> usize {
        match self {
            Self::Found(i) | Self::NotFound(i) => i,
        }
    }

    /// Converts the search result into an index suitable for indexing an array of offsets whose
    /// first element is 0.
    ///
    /// For example, for a ChunkedArray with chunk offsets array `[0, 3, 8, 10]` you can use this
    /// method to obtain an index into that offsets array after performing a search on it.
    /// `len` is the length of the searched offsets array and `side` the side used for the search.
    pub fn to_offsets_index(self, len: usize, side: SearchSortedSide) -> usize {
        match self {
            SearchResult::Found(i) => {
                if side == SearchSortedSide::Right || i == len {
                    i.saturating_sub(1)
                } else {
                    i
                }
            }
            SearchResult::NotFound(i) => i.saturating_sub(1),
        }
    }

    /// Converts the search result into an index suitable for indexing an array of end positions
    /// without a leading 0 offset, i.e. the first element implicitly covers `0..ends[0]`.
    ///
    /// For example, for a RunEndArray with ends array `[3, 8, 10]` you can use this method to
    /// obtain an index into that ends array after performing a search on it. A position equal to
    /// `len` (past the last end) is clamped to the last index. For an empty ends array
    /// (`len == 0`) this returns 0 instead of underflowing.
    pub fn to_ends_index(self, len: usize) -> usize {
        let idx = self.to_index();
        // `saturating_sub` keeps `idx == len == 0` from underflowing (panic in debug builds,
        // `usize::MAX` in release builds).
        if idx == len { idx.saturating_sub(1) } else { idx }
    }
}

impl Display for SearchResult {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            SearchResult::Found(i) => write!(f, "Found({i})"),
            SearchResult::NotFound(i) => write!(f, "NotFound({i})"),
        }
    }
}
100
101pub trait IndexOrd<V> {
102    /// PartialOrd of the value at index `idx` with `elem`.
103    /// For example, if self\[idx\] > elem, return Some(Greater).
104    fn index_cmp(&self, idx: usize, elem: &V) -> VortexResult<Option<Ordering>>;
105
106    fn index_lt(&self, idx: usize, elem: &V) -> VortexResult<bool> {
107        Ok(matches!(self.index_cmp(idx, elem)?, Some(Less)))
108    }
109
110    fn index_le(&self, idx: usize, elem: &V) -> VortexResult<bool> {
111        Ok(matches!(self.index_cmp(idx, elem)?, Some(Less | Equal)))
112    }
113
114    fn index_gt(&self, idx: usize, elem: &V) -> VortexResult<bool> {
115        Ok(matches!(self.index_cmp(idx, elem)?, Some(Greater)))
116    }
117
118    fn index_ge(&self, idx: usize, elem: &V) -> VortexResult<bool> {
119        Ok(matches!(self.index_cmp(idx, elem)?, Some(Greater | Equal)))
120    }
121
122    /// Get the length of the underlying ordered collection
123    fn index_len(&self) -> usize;
124}
125
126/// Searches for value assuming the array is sorted.
127///
128/// Returned indices satisfy following condition if the search for value was to be inserted into the array at found positions
129///
130/// |side |result satisfies|
131/// |-----|----------------|
132/// |left |array\[i-1\] < value <= array\[i\]|
133/// |right|array\[i-1\] <= value < array\[i\]|
134pub trait SearchSorted<T> {
135    fn search_sorted(&self, value: &T, side: SearchSortedSide) -> VortexResult<SearchResult>
136    where
137        Self: IndexOrd<T>,
138    {
139        match side {
140            SearchSortedSide::Left => self.search_sorted_by(
141                |idx| Ok(self.index_cmp(idx, value)?.unwrap_or(Less)),
142                |idx| {
143                    Ok(if self.index_lt(idx, value)? {
144                        Less
145                    } else {
146                        Greater
147                    })
148                },
149                side,
150            ),
151            SearchSortedSide::Right => self.search_sorted_by(
152                |idx| Ok(self.index_cmp(idx, value)?.unwrap_or(Less)),
153                |idx| {
154                    Ok(if self.index_le(idx, value)? {
155                        Less
156                    } else {
157                        Greater
158                    })
159                },
160                side,
161            ),
162        }
163    }
164
165    /// find function is used to find the element if it exists, if element exists side_find will be
166    /// used to find desired index amongst equal values
167    fn search_sorted_by<
168        F: FnMut(usize) -> VortexResult<Ordering>,
169        N: FnMut(usize) -> VortexResult<Ordering>,
170    >(
171        &self,
172        find: F,
173        side_find: N,
174        side: SearchSortedSide,
175    ) -> VortexResult<SearchResult>;
176}
177
178// Default implementation for types that implement IndexOrd.
179impl<S, T> SearchSorted<T> for S
180where
181    S: IndexOrd<T> + ?Sized,
182{
183    fn search_sorted_by<
184        F: FnMut(usize) -> VortexResult<Ordering>,
185        N: FnMut(usize) -> VortexResult<Ordering>,
186    >(
187        &self,
188        find: F,
189        side_find: N,
190        side: SearchSortedSide,
191    ) -> VortexResult<SearchResult> {
192        match search_sorted_side_idx(find, 0, self.index_len())? {
193            SearchResult::Found(found) => {
194                let idx_search = match side {
195                    SearchSortedSide::Left => search_sorted_side_idx(side_find, 0, found)?,
196                    SearchSortedSide::Right => {
197                        search_sorted_side_idx(side_find, found, self.index_len())?
198                    }
199                };
200                match idx_search {
201                    SearchResult::NotFound(i) => Ok(SearchResult::Found(i)),
202                    _ => unreachable!(
203                        "searching amongst equal values should never return Found result"
204                    ),
205                }
206            }
207            s => Ok(s),
208        }
209    }
210}
211
212// Code adapted from Rust standard library slice::binary_search_by
213fn search_sorted_side_idx<F: FnMut(usize) -> VortexResult<Ordering>>(
214    mut find: F,
215    from: usize,
216    to: usize,
217) -> VortexResult<SearchResult> {
218    let mut size = to - from;
219    if size == 0 {
220        return Ok(SearchResult::NotFound(0));
221    }
222    let mut base = from;
223
224    // This loop intentionally doesn't have an early exit if the comparison
225    // returns Equal. We want the number of loop iterations to depend *only*
226    // on the size of the input slice so that the CPU can reliably predict
227    // the loop count.
228    while size > 1 {
229        let half = size / 2;
230        let mid = base + half;
231
232        // SAFETY: the call is made safe by the following inconstants:
233        // - `mid >= 0`: by definition
234        // - `mid < size`: `mid = size / 2 + size / 4 + size / 8 ...`
235        let cmp = find(mid)?;
236
237        // Binary search interacts poorly with branch prediction, so force
238        // the compiler to use conditional moves if supported by the target
239        // architecture.
240        base = if cmp == Greater { base } else { mid };
241
242        // This is imprecise in the case where `size` is odd and the
243        // comparison returns Greater: the mid element still gets included
244        // by `size` even though it's known to be larger than the element
245        // being searched for.
246        // This is fine though: we gain more performance by keeping the
247        // loop iteration count invariant (and thus predictable) than we
248        // lose from considering one additional element.
249        size -= half;
250    }
251
252    // SAFETY: base is always in [0, size) because base <= mid.
253    let cmp = find(base)?;
254    if cmp == Equal {
255        // SAFETY: same as the call to `find` above.
256        unsafe { hint::assert_unchecked(base < to) };
257        Ok(SearchResult::Found(base))
258    } else {
259        let result = base + (cmp == Less) as usize;
260        // SAFETY: same as the call to `find` above.
261        // Note that this is `<=`, unlike the assert in the `Found` path.
262        unsafe { hint::assert_unchecked(result <= to) };
263        Ok(SearchResult::NotFound(result))
264    }
265}
266
267impl IndexOrd<Scalar> for ArrayRef {
268    fn index_cmp(&self, idx: usize, elem: &Scalar) -> VortexResult<Option<Ordering>> {
269        let scalar_a = self.execute_scalar(idx, &mut LEGACY_SESSION.create_execution_ctx())?;
270        Ok(scalar_a.partial_cmp(elem))
271    }
272
273    fn index_len(&self) -> usize {
274        Self::len(self)
275    }
276}
277
278impl<T: PartialOrd> IndexOrd<T> for [T] {
279    fn index_cmp(&self, idx: usize, elem: &T) -> VortexResult<Option<Ordering>> {
280        // SAFETY: Used in search_sorted_by same as the standard library. The search_sorted ensures idx is in bounds
281        Ok(unsafe { self.get_unchecked(idx) }.partial_cmp(elem))
282    }
283
284    fn index_len(&self) -> usize {
285        self.len()
286    }
287}
288
#[cfg(test)]
mod tests {
    use vortex_error::VortexResult;

    use crate::search_sorted::SearchResult;
    use crate::search_sorted::SearchSorted;
    use crate::search_sorted::SearchSortedSide;

    /// Run of duplicates in the middle of the array.
    const DUPES_MIDDLE: [i32; 13] = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
    /// Run of duplicates at the start of the array.
    const DUPES_START: [i32; 13] = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
    /// Run of duplicates at the end of the array.
    const DUPES_END: [i32; 13] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];

    /// Searches `arr` for `value` on `side`, then checks that the result lands on the run of
    /// equal values (its first element for `Left`, just past its last element for `Right`) and
    /// equals `expected`.
    fn assert_search(
        arr: &[i32],
        value: i32,
        side: SearchSortedSide,
        expected: SearchResult,
    ) -> VortexResult<()> {
        let res = arr.search_sorted(&value, side)?;
        let neighbour = match side {
            SearchSortedSide::Left => res.to_index(),
            SearchSortedSide::Right => res.to_index() - 1,
        };
        assert_eq!(arr[neighbour], value);
        assert_eq!(res, expected);
        Ok(())
    }

    #[test]
    fn left_side_equal() -> VortexResult<()> {
        assert_search(&DUPES_MIDDLE, 2, SearchSortedSide::Left, SearchResult::Found(2))
    }

    #[test]
    fn right_side_equal() -> VortexResult<()> {
        assert_search(&DUPES_MIDDLE, 2, SearchSortedSide::Right, SearchResult::Found(6))
    }

    #[test]
    fn left_side_equal_beginning() -> VortexResult<()> {
        assert_search(&DUPES_START, 0, SearchSortedSide::Left, SearchResult::Found(0))
    }

    #[test]
    fn right_side_equal_beginning() -> VortexResult<()> {
        assert_search(&DUPES_START, 0, SearchSortedSide::Right, SearchResult::Found(4))
    }

    #[test]
    fn left_side_equal_end() -> VortexResult<()> {
        assert_search(&DUPES_END, 9, SearchSortedSide::Left, SearchResult::Found(9))
    }

    #[test]
    fn right_side_equal_end() -> VortexResult<()> {
        assert_search(&DUPES_END, 9, SearchSortedSide::Right, SearchResult::Found(13))
    }
}