1use std::cmp::Ordering;
5use std::cmp::Ordering::Equal;
6use std::cmp::Ordering::Greater;
7use std::cmp::Ordering::Less;
8use std::fmt::Debug;
9use std::fmt::Display;
10use std::fmt::Formatter;
11use std::hint;
12
13use vortex_error::VortexResult;
14
15use crate::Array;
16use crate::scalar::Scalar;
17
/// Which edge of a run of equal elements a sorted search reports.
///
/// With `Left` the search reports the first position holding a matching value; with `Right`
/// it reports the position just after the last matching value.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SearchSortedSide {
    /// Report the first matching position.
    Left,
    /// Report the position just past the last matching position.
    Right,
}
23
24impl Display for SearchSortedSide {
25 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
26 match self {
27 SearchSortedSide::Left => write!(f, "left"),
28 SearchSortedSide::Right => write!(f, "right"),
29 }
30 }
31}
32
/// Outcome of a sorted search.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SearchResult {
    /// The target is present. The payload is the index reported for it, which depends on the
    /// requested search side; for the right side it may be one past the last match (and so
    /// equal to the length of the searched sequence).
    Found(usize),
    /// The target is absent. The payload is the position at which it would be inserted to keep
    /// the sequence ordered.
    NotFound(usize),
}
44
45impl SearchResult {
46 pub fn to_found(self) -> Option<usize> {
48 match self {
49 Self::Found(i) => Some(i),
50 Self::NotFound(_) => None,
51 }
52 }
53
54 pub fn to_index(self) -> usize {
56 match self {
57 Self::Found(i) => i,
58 Self::NotFound(i) => i,
59 }
60 }
61
62 pub fn to_offsets_index(self, len: usize, side: SearchSortedSide) -> usize {
67 match self {
68 SearchResult::Found(i) => {
69 if side == SearchSortedSide::Right || i == len {
70 i.saturating_sub(1)
71 } else {
72 i
73 }
74 }
75 SearchResult::NotFound(i) => i.saturating_sub(1),
76 }
77 }
78
79 pub fn to_ends_index(self, len: usize) -> usize {
85 let idx = self.to_index();
86 if idx == len { idx - 1 } else { idx }
87 }
88}
89
90impl Display for SearchResult {
91 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
92 match self {
93 SearchResult::Found(i) => write!(f, "Found({i})"),
94 SearchResult::NotFound(i) => write!(f, "NotFound({i})"),
95 }
96 }
97}
98
99pub trait IndexOrd<V> {
100 fn index_cmp(&self, idx: usize, elem: &V) -> VortexResult<Option<Ordering>>;
103
104 fn index_lt(&self, idx: usize, elem: &V) -> VortexResult<bool> {
105 Ok(matches!(self.index_cmp(idx, elem)?, Some(Less)))
106 }
107
108 fn index_le(&self, idx: usize, elem: &V) -> VortexResult<bool> {
109 Ok(matches!(self.index_cmp(idx, elem)?, Some(Less | Equal)))
110 }
111
112 fn index_gt(&self, idx: usize, elem: &V) -> VortexResult<bool> {
113 Ok(matches!(self.index_cmp(idx, elem)?, Some(Greater)))
114 }
115
116 fn index_ge(&self, idx: usize, elem: &V) -> VortexResult<bool> {
117 Ok(matches!(self.index_cmp(idx, elem)?, Some(Greater | Equal)))
118 }
119
120 fn index_len(&self) -> usize;
122}
123
124pub trait SearchSorted<T> {
133 fn search_sorted(&self, value: &T, side: SearchSortedSide) -> VortexResult<SearchResult>
134 where
135 Self: IndexOrd<T>,
136 {
137 match side {
138 SearchSortedSide::Left => self.search_sorted_by(
139 |idx| Ok(self.index_cmp(idx, value)?.unwrap_or(Less)),
140 |idx| {
141 Ok(if self.index_lt(idx, value)? {
142 Less
143 } else {
144 Greater
145 })
146 },
147 side,
148 ),
149 SearchSortedSide::Right => self.search_sorted_by(
150 |idx| Ok(self.index_cmp(idx, value)?.unwrap_or(Less)),
151 |idx| {
152 Ok(if self.index_le(idx, value)? {
153 Less
154 } else {
155 Greater
156 })
157 },
158 side,
159 ),
160 }
161 }
162
163 fn search_sorted_by<
166 F: FnMut(usize) -> VortexResult<Ordering>,
167 N: FnMut(usize) -> VortexResult<Ordering>,
168 >(
169 &self,
170 find: F,
171 side_find: N,
172 side: SearchSortedSide,
173 ) -> VortexResult<SearchResult>;
174}
175
176impl<S, T> SearchSorted<T> for S
178where
179 S: IndexOrd<T> + ?Sized,
180{
181 fn search_sorted_by<
182 F: FnMut(usize) -> VortexResult<Ordering>,
183 N: FnMut(usize) -> VortexResult<Ordering>,
184 >(
185 &self,
186 find: F,
187 side_find: N,
188 side: SearchSortedSide,
189 ) -> VortexResult<SearchResult> {
190 match search_sorted_side_idx(find, 0, self.index_len())? {
191 SearchResult::Found(found) => {
192 let idx_search = match side {
193 SearchSortedSide::Left => search_sorted_side_idx(side_find, 0, found)?,
194 SearchSortedSide::Right => {
195 search_sorted_side_idx(side_find, found, self.index_len())?
196 }
197 };
198 match idx_search {
199 SearchResult::NotFound(i) => Ok(SearchResult::Found(i)),
200 _ => unreachable!(
201 "searching amongst equal values should never return Found result"
202 ),
203 }
204 }
205 s => Ok(s),
206 }
207 }
208}
209
210fn search_sorted_side_idx<F: FnMut(usize) -> VortexResult<Ordering>>(
212 mut find: F,
213 from: usize,
214 to: usize,
215) -> VortexResult<SearchResult> {
216 let mut size = to - from;
217 if size == 0 {
218 return Ok(SearchResult::NotFound(0));
219 }
220 let mut base = from;
221
222 while size > 1 {
227 let half = size / 2;
228 let mid = base + half;
229
230 let cmp = find(mid)?;
234
235 base = if cmp == Greater { base } else { mid };
239
240 size -= half;
248 }
249
250 let cmp = find(base)?;
252 if cmp == Equal {
253 unsafe { hint::assert_unchecked(base < to) };
255 Ok(SearchResult::Found(base))
256 } else {
257 let result = base + (cmp == Less) as usize;
258 unsafe { hint::assert_unchecked(result <= to) };
261 Ok(SearchResult::NotFound(result))
262 }
263}
264
265impl IndexOrd<Scalar> for dyn Array + '_ {
266 fn index_cmp(&self, idx: usize, elem: &Scalar) -> VortexResult<Option<Ordering>> {
267 let scalar_a = self.scalar_at(idx)?;
268 Ok(scalar_a.partial_cmp(elem))
269 }
270
271 fn index_len(&self) -> usize {
272 Self::len(self)
273 }
274}
275
276impl<T: PartialOrd> IndexOrd<T> for [T] {
277 fn index_cmp(&self, idx: usize, elem: &T) -> VortexResult<Option<Ordering>> {
278 Ok(unsafe { self.get_unchecked(idx) }.partial_cmp(elem))
280 }
281
282 fn index_len(&self) -> usize {
283 self.len()
284 }
285}
286
#[cfg(test)]
mod tests {
    use vortex_error::VortexResult;

    use crate::search_sorted::SearchResult;
    use crate::search_sorted::SearchSorted;
    use crate::search_sorted::SearchSortedSide;

    #[test]
    fn left_side_equal() -> VortexResult<()> {
        let arr = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
        let res = arr.search_sorted(&2, SearchSortedSide::Left)?;
        assert_eq!(arr[res.to_index()], 2);
        assert_eq!(res, SearchResult::Found(2));
        Ok(())
    }

    #[test]
    fn right_side_equal() -> VortexResult<()> {
        let arr = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
        let res = arr.search_sorted(&2, SearchSortedSide::Right)?;
        assert_eq!(arr[res.to_index() - 1], 2);
        assert_eq!(res, SearchResult::Found(6));
        Ok(())
    }

    #[test]
    fn left_side_equal_beginning() -> VortexResult<()> {
        let arr = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let res = arr.search_sorted(&0, SearchSortedSide::Left)?;
        assert_eq!(arr[res.to_index()], 0);
        assert_eq!(res, SearchResult::Found(0));
        Ok(())
    }

    #[test]
    fn right_side_equal_beginning() -> VortexResult<()> {
        let arr = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let res = arr.search_sorted(&0, SearchSortedSide::Right)?;
        assert_eq!(arr[res.to_index() - 1], 0);
        assert_eq!(res, SearchResult::Found(4));
        Ok(())
    }

    #[test]
    fn left_side_equal_end() -> VortexResult<()> {
        let arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
        let res = arr.search_sorted(&9, SearchSortedSide::Left)?;
        assert_eq!(arr[res.to_index()], 9);
        assert_eq!(res, SearchResult::Found(9));
        Ok(())
    }

    #[test]
    fn right_side_equal_end() -> VortexResult<()> {
        let arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
        let res = arr.search_sorted(&9, SearchSortedSide::Right)?;
        assert_eq!(arr[res.to_index() - 1], 9);
        assert_eq!(res, SearchResult::Found(13));
        Ok(())
    }

    #[test]
    fn not_found_between_elements() -> VortexResult<()> {
        let arr = [1, 3, 5, 7];
        assert_eq!(
            arr.search_sorted(&4, SearchSortedSide::Left)?,
            SearchResult::NotFound(2)
        );
        assert_eq!(
            arr.search_sorted(&4, SearchSortedSide::Right)?,
            SearchResult::NotFound(2)
        );
        Ok(())
    }

    #[test]
    fn not_found_before_first_and_after_last() -> VortexResult<()> {
        let arr = [1, 3, 5, 7];
        assert_eq!(
            arr.search_sorted(&0, SearchSortedSide::Left)?,
            SearchResult::NotFound(0)
        );
        assert_eq!(
            arr.search_sorted(&0, SearchSortedSide::Right)?,
            SearchResult::NotFound(0)
        );
        assert_eq!(
            arr.search_sorted(&8, SearchSortedSide::Left)?,
            SearchResult::NotFound(4)
        );
        assert_eq!(
            arr.search_sorted(&8, SearchSortedSide::Right)?,
            SearchResult::NotFound(4)
        );
        Ok(())
    }

    #[test]
    fn empty_slice() -> VortexResult<()> {
        let arr: [i32; 0] = [];
        assert_eq!(
            arr.search_sorted(&1, SearchSortedSide::Left)?,
            SearchResult::NotFound(0)
        );
        assert_eq!(
            arr.search_sorted(&1, SearchSortedSide::Right)?,
            SearchResult::NotFound(0)
        );
        Ok(())
    }

    #[test]
    fn single_element() -> VortexResult<()> {
        let arr = [5];
        assert_eq!(
            arr.search_sorted(&5, SearchSortedSide::Left)?,
            SearchResult::Found(0)
        );
        assert_eq!(
            arr.search_sorted(&5, SearchSortedSide::Right)?,
            SearchResult::Found(1)
        );
        Ok(())
    }

    #[test]
    fn all_equal() -> VortexResult<()> {
        let arr = [2, 2, 2, 2];
        assert_eq!(
            arr.search_sorted(&2, SearchSortedSide::Left)?,
            SearchResult::Found(0)
        );
        assert_eq!(
            arr.search_sorted(&2, SearchSortedSide::Right)?,
            SearchResult::Found(4)
        );
        Ok(())
    }

    #[test]
    fn result_index_helpers() {
        let len = 4;
        assert_eq!(
            SearchResult::Found(2).to_offsets_index(len, SearchSortedSide::Left),
            2
        );
        assert_eq!(
            SearchResult::Found(2).to_offsets_index(len, SearchSortedSide::Right),
            1
        );
        assert_eq!(
            SearchResult::Found(4).to_offsets_index(len, SearchSortedSide::Left),
            3
        );
        assert_eq!(
            SearchResult::NotFound(0).to_offsets_index(len, SearchSortedSide::Left),
            0
        );
        assert_eq!(SearchResult::NotFound(3).to_ends_index(len), 3);
        assert_eq!(SearchResult::NotFound(4).to_ends_index(len), 3);
        assert_eq!(SearchResult::Found(2).to_found(), Some(2));
        assert_eq!(SearchResult::NotFound(2).to_found(), None);
    }
}