Skip to main content

vortex_array/
search_sorted.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::Ordering;
5use std::cmp::Ordering::Equal;
6use std::cmp::Ordering::Greater;
7use std::cmp::Ordering::Less;
8use std::fmt::Debug;
9use std::fmt::Display;
10use std::fmt::Formatter;
11use std::hint;
12
13use vortex_error::VortexResult;
14
15use crate::Array;
16use crate::scalar::Scalar;
17
/// Selects which insertion point a sorted search reports when the searched value is equal to
/// elements already in the array.
///
/// With `i` being the reported index:
///
/// |side |result satisfies|
/// |-----|----------------|
/// |left |array\[i-1\] < value <= array\[i\]|
/// |right|array\[i-1\] <= value < array\[i\]|
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum SearchSortedSide {
    /// Report the position before the elements equal to the searched value.
    Left,
    /// Report the position after the elements equal to the searched value.
    Right,
}
23
24impl Display for SearchSortedSide {
25    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
26        match self {
27            SearchSortedSide::Left => write!(f, "left"),
28            SearchSortedSide::Right => write!(f, "right"),
29        }
30    }
31}
32
/// Result of performing search_sorted on an Array.
///
/// Both variants carry a position in the sorted order; they differ only in whether the searched
/// value was present.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum SearchResult {
    /// The element was found in the array, and another one could be inserted at the given
    /// position in the sorted order.
    Found(usize),

    /// The element was not found, but it could be inserted at the given position
    /// in the sorted order.
    NotFound(usize),
}
44
45impl SearchResult {
46    /// Convert search result to an index only if the value have been found
47    pub fn to_found(self) -> Option<usize> {
48        match self {
49            Self::Found(i) => Some(i),
50            Self::NotFound(_) => None,
51        }
52    }
53
54    /// Extract index out of search result regardless of whether the value have been found or not
55    pub fn to_index(self) -> usize {
56        match self {
57            Self::Found(i) => i,
58            Self::NotFound(i) => i,
59        }
60    }
61
62    /// Convert search result into an index suitable for searching array of offset indices, i.e. first element starts at 0.
63    ///
64    /// For example for a ChunkedArray with chunk offsets array [0, 3, 8, 10] you can use this method to
65    /// obtain index suitable for indexing into it after performing a search
66    pub fn to_offsets_index(self, len: usize, side: SearchSortedSide) -> usize {
67        match self {
68            SearchResult::Found(i) => {
69                if side == SearchSortedSide::Right || i == len {
70                    i.saturating_sub(1)
71                } else {
72                    i
73                }
74            }
75            SearchResult::NotFound(i) => i.saturating_sub(1),
76        }
77    }
78
79    /// Convert search result into an index suitable for searching array of end indices without 0 offset,
80    /// i.e. first element implicitly covers 0..0th-element range.
81    ///
82    /// For example for a RunEndArray with ends array [3, 8, 10], you can use this method to obtain index suitable for
83    /// indexing into it after performing a search
84    pub fn to_ends_index(self, len: usize) -> usize {
85        let idx = self.to_index();
86        if idx == len { idx - 1 } else { idx }
87    }
88}
89
90impl Display for SearchResult {
91    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
92        match self {
93            SearchResult::Found(i) => write!(f, "Found({i})"),
94            SearchResult::NotFound(i) => write!(f, "NotFound({i})"),
95        }
96    }
97}
98
99pub trait IndexOrd<V> {
100    /// PartialOrd of the value at index `idx` with `elem`.
101    /// For example, if self\[idx\] > elem, return Some(Greater).
102    fn index_cmp(&self, idx: usize, elem: &V) -> VortexResult<Option<Ordering>>;
103
104    fn index_lt(&self, idx: usize, elem: &V) -> VortexResult<bool> {
105        Ok(matches!(self.index_cmp(idx, elem)?, Some(Less)))
106    }
107
108    fn index_le(&self, idx: usize, elem: &V) -> VortexResult<bool> {
109        Ok(matches!(self.index_cmp(idx, elem)?, Some(Less | Equal)))
110    }
111
112    fn index_gt(&self, idx: usize, elem: &V) -> VortexResult<bool> {
113        Ok(matches!(self.index_cmp(idx, elem)?, Some(Greater)))
114    }
115
116    fn index_ge(&self, idx: usize, elem: &V) -> VortexResult<bool> {
117        Ok(matches!(self.index_cmp(idx, elem)?, Some(Greater | Equal)))
118    }
119
120    /// Get the length of the underlying ordered collection
121    fn index_len(&self) -> usize;
122}
123
124/// Searches for value assuming the array is sorted.
125///
126/// Returned indices satisfy following condition if the search for value was to be inserted into the array at found positions
127///
128/// |side |result satisfies|
129/// |-----|----------------|
130/// |left |array\[i-1\] < value <= array\[i\]|
131/// |right|array\[i-1\] <= value < array\[i\]|
132pub trait SearchSorted<T> {
133    fn search_sorted(&self, value: &T, side: SearchSortedSide) -> VortexResult<SearchResult>
134    where
135        Self: IndexOrd<T>,
136    {
137        match side {
138            SearchSortedSide::Left => self.search_sorted_by(
139                |idx| Ok(self.index_cmp(idx, value)?.unwrap_or(Less)),
140                |idx| {
141                    Ok(if self.index_lt(idx, value)? {
142                        Less
143                    } else {
144                        Greater
145                    })
146                },
147                side,
148            ),
149            SearchSortedSide::Right => self.search_sorted_by(
150                |idx| Ok(self.index_cmp(idx, value)?.unwrap_or(Less)),
151                |idx| {
152                    Ok(if self.index_le(idx, value)? {
153                        Less
154                    } else {
155                        Greater
156                    })
157                },
158                side,
159            ),
160        }
161    }
162
163    /// find function is used to find the element if it exists, if element exists side_find will be
164    /// used to find desired index amongst equal values
165    fn search_sorted_by<
166        F: FnMut(usize) -> VortexResult<Ordering>,
167        N: FnMut(usize) -> VortexResult<Ordering>,
168    >(
169        &self,
170        find: F,
171        side_find: N,
172        side: SearchSortedSide,
173    ) -> VortexResult<SearchResult>;
174}
175
176// Default implementation for types that implement IndexOrd.
177impl<S, T> SearchSorted<T> for S
178where
179    S: IndexOrd<T> + ?Sized,
180{
181    fn search_sorted_by<
182        F: FnMut(usize) -> VortexResult<Ordering>,
183        N: FnMut(usize) -> VortexResult<Ordering>,
184    >(
185        &self,
186        find: F,
187        side_find: N,
188        side: SearchSortedSide,
189    ) -> VortexResult<SearchResult> {
190        match search_sorted_side_idx(find, 0, self.index_len())? {
191            SearchResult::Found(found) => {
192                let idx_search = match side {
193                    SearchSortedSide::Left => search_sorted_side_idx(side_find, 0, found)?,
194                    SearchSortedSide::Right => {
195                        search_sorted_side_idx(side_find, found, self.index_len())?
196                    }
197                };
198                match idx_search {
199                    SearchResult::NotFound(i) => Ok(SearchResult::Found(i)),
200                    _ => unreachable!(
201                        "searching amongst equal values should never return Found result"
202                    ),
203                }
204            }
205            s => Ok(s),
206        }
207    }
208}
209
210// Code adapted from Rust standard library slice::binary_search_by
211fn search_sorted_side_idx<F: FnMut(usize) -> VortexResult<Ordering>>(
212    mut find: F,
213    from: usize,
214    to: usize,
215) -> VortexResult<SearchResult> {
216    let mut size = to - from;
217    if size == 0 {
218        return Ok(SearchResult::NotFound(0));
219    }
220    let mut base = from;
221
222    // This loop intentionally doesn't have an early exit if the comparison
223    // returns Equal. We want the number of loop iterations to depend *only*
224    // on the size of the input slice so that the CPU can reliably predict
225    // the loop count.
226    while size > 1 {
227        let half = size / 2;
228        let mid = base + half;
229
230        // SAFETY: the call is made safe by the following inconstants:
231        // - `mid >= 0`: by definition
232        // - `mid < size`: `mid = size / 2 + size / 4 + size / 8 ...`
233        let cmp = find(mid)?;
234
235        // Binary search interacts poorly with branch prediction, so force
236        // the compiler to use conditional moves if supported by the target
237        // architecture.
238        base = if cmp == Greater { base } else { mid };
239
240        // This is imprecise in the case where `size` is odd and the
241        // comparison returns Greater: the mid element still gets included
242        // by `size` even though it's known to be larger than the element
243        // being searched for.
244        // This is fine though: we gain more performance by keeping the
245        // loop iteration count invariant (and thus predictable) than we
246        // lose from considering one additional element.
247        size -= half;
248    }
249
250    // SAFETY: base is always in [0, size) because base <= mid.
251    let cmp = find(base)?;
252    if cmp == Equal {
253        // SAFETY: same as the call to `find` above.
254        unsafe { hint::assert_unchecked(base < to) };
255        Ok(SearchResult::Found(base))
256    } else {
257        let result = base + (cmp == Less) as usize;
258        // SAFETY: same as the call to `find` above.
259        // Note that this is `<=`, unlike the assert in the `Found` path.
260        unsafe { hint::assert_unchecked(result <= to) };
261        Ok(SearchResult::NotFound(result))
262    }
263}
264
265impl IndexOrd<Scalar> for dyn Array + '_ {
266    fn index_cmp(&self, idx: usize, elem: &Scalar) -> VortexResult<Option<Ordering>> {
267        let scalar_a = self.scalar_at(idx)?;
268        Ok(scalar_a.partial_cmp(elem))
269    }
270
271    fn index_len(&self) -> usize {
272        Self::len(self)
273    }
274}
275
276impl<T: PartialOrd> IndexOrd<T> for [T] {
277    fn index_cmp(&self, idx: usize, elem: &T) -> VortexResult<Option<Ordering>> {
278        // SAFETY: Used in search_sorted_by same as the standard library. The search_sorted ensures idx is in bounds
279        Ok(unsafe { self.get_unchecked(idx) }.partial_cmp(elem))
280    }
281
282    fn index_len(&self) -> usize {
283        self.len()
284    }
285}
286
#[cfg(test)]
mod tests {
    use vortex_error::VortexResult;

    use crate::search_sorted::SearchResult;
    use crate::search_sorted::SearchSorted;
    use crate::search_sorted::SearchSortedSide;

    /// Left search over a run of equal values in the middle of the array.
    #[test]
    fn left_side_equal() -> VortexResult<()> {
        let haystack = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
        let hit = haystack.search_sorted(&2, SearchSortedSide::Left)?;
        assert_eq!(haystack[hit.to_index()], 2);
        assert_eq!(hit, SearchResult::Found(2));
        Ok(())
    }

    /// Right search over a run of equal values in the middle of the array.
    #[test]
    fn right_side_equal() -> VortexResult<()> {
        let haystack = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
        let hit = haystack.search_sorted(&2, SearchSortedSide::Right)?;
        assert_eq!(haystack[hit.to_index() - 1], 2);
        assert_eq!(hit, SearchResult::Found(6));
        Ok(())
    }

    /// Left search when the run of equal values starts the array.
    #[test]
    fn left_side_equal_beginning() -> VortexResult<()> {
        let haystack = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let hit = haystack.search_sorted(&0, SearchSortedSide::Left)?;
        assert_eq!(haystack[hit.to_index()], 0);
        assert_eq!(hit, SearchResult::Found(0));
        Ok(())
    }

    /// Right search when the run of equal values starts the array.
    #[test]
    fn right_side_equal_beginning() -> VortexResult<()> {
        let haystack = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let hit = haystack.search_sorted(&0, SearchSortedSide::Right)?;
        assert_eq!(haystack[hit.to_index() - 1], 0);
        assert_eq!(hit, SearchResult::Found(4));
        Ok(())
    }

    /// Left search when the run of equal values ends the array.
    #[test]
    fn left_side_equal_end() -> VortexResult<()> {
        let haystack = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
        let hit = haystack.search_sorted(&9, SearchSortedSide::Left)?;
        assert_eq!(haystack[hit.to_index()], 9);
        assert_eq!(hit, SearchResult::Found(9));
        Ok(())
    }

    /// Right search when the run of equal values ends the array: the result is one past the
    /// last element.
    #[test]
    fn right_side_equal_end() -> VortexResult<()> {
        let haystack = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
        let hit = haystack.search_sorted(&9, SearchSortedSide::Right)?;
        assert_eq!(haystack[hit.to_index() - 1], 9);
        assert_eq!(hit, SearchResult::Found(13));
        Ok(())
    }
}
348}