1use std::cmp::Ordering;
2use std::cmp::Ordering::{Equal, Greater, Less};
3use std::fmt::{Debug, Display, Formatter};
4use std::hint;
5
6use itertools::Itertools;
7use vortex_error::{VortexExpect, VortexResult, vortex_bail};
8use vortex_scalar::Scalar;
9
10use crate::Array;
11use crate::compute::scalar_at;
12use crate::encoding::Encoding;
13
/// Which boundary of a run of equal values a sorted search should report.
///
/// The searches in this module report the first index of the matching run for `Left`
/// and the index one past the end of the matching run for `Right`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SearchSortedSide {
    /// Report the start of the run of matching values.
    Left,
    /// Report the position immediately after the run of matching values.
    Right,
}
19
20impl Display for SearchSortedSide {
21 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
22 match self {
23 SearchSortedSide::Left => write!(f, "left"),
24 SearchSortedSide::Right => write!(f, "right"),
25 }
26 }
27}
28
/// Outcome of a sorted search: whether the value was located, and the index that goes with it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SearchResult {
    /// The value is present; the payload is the index chosen for the requested side.
    Found(usize),

    /// The value is absent; the payload is the index reported by the search
    /// (the position at which the value would be inserted to keep the order).
    NotFound(usize),
}
42
43impl SearchResult {
44 pub fn to_found(self) -> Option<usize> {
46 match self {
47 Self::Found(i) => Some(i),
48 Self::NotFound(_) => None,
49 }
50 }
51
52 pub fn to_index(self) -> usize {
54 match self {
55 Self::Found(i) => i,
56 Self::NotFound(i) => i,
57 }
58 }
59
60 pub fn to_offsets_index(self, len: usize, side: SearchSortedSide) -> usize {
65 match self {
66 SearchResult::Found(i) => {
67 if side == SearchSortedSide::Right || i == len {
68 i.saturating_sub(1)
69 } else {
70 i
71 }
72 }
73 SearchResult::NotFound(i) => i.saturating_sub(1),
74 }
75 }
76
77 pub fn to_ends_index(self, len: usize) -> usize {
83 let idx = self.to_index();
84 if idx == len { idx - 1 } else { idx }
85 }
86}
87
88impl Display for SearchResult {
89 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
90 match self {
91 SearchResult::Found(i) => write!(f, "Found({i})"),
92 SearchResult::NotFound(i) => write!(f, "NotFound({i})"),
93 }
94 }
95}
96
97pub trait SearchSortedFn<A: Copy> {
106 fn search_sorted(
107 &self,
108 array: A,
109 value: &Scalar,
110 side: SearchSortedSide,
111 ) -> VortexResult<SearchResult>;
112
113 fn search_sorted_many(
115 &self,
116 array: A,
117 values: &[Scalar],
118 side: SearchSortedSide,
119 ) -> VortexResult<Vec<SearchResult>> {
120 values
121 .iter()
122 .map(|value| self.search_sorted(array, value, side))
123 .try_collect()
124 }
125}
126
127pub trait SearchSortedUsizeFn<A: Copy> {
128 fn search_sorted_usize(
129 &self,
130 array: A,
131 value: usize,
132 side: SearchSortedSide,
133 ) -> VortexResult<SearchResult>;
134
135 fn search_sorted_usize_many(
136 &self,
137 array: A,
138 values: &[usize],
139 side: SearchSortedSide,
140 ) -> VortexResult<Vec<SearchResult>> {
141 values
142 .iter()
143 .map(|&value| self.search_sorted_usize(array, value, side))
144 .try_collect()
145 }
146}
147
148impl<E: Encoding> SearchSortedFn<&dyn Array> for E
149where
150 E: for<'a> SearchSortedFn<&'a E::Array>,
151{
152 fn search_sorted(
153 &self,
154 array: &dyn Array,
155 value: &Scalar,
156 side: SearchSortedSide,
157 ) -> VortexResult<SearchResult> {
158 let array_ref = array
159 .as_any()
160 .downcast_ref::<E::Array>()
161 .vortex_expect("Failed to downcast array");
162 SearchSortedFn::search_sorted(self, array_ref, value, side)
163 }
164
165 fn search_sorted_many(
166 &self,
167 array: &dyn Array,
168 values: &[Scalar],
169 side: SearchSortedSide,
170 ) -> VortexResult<Vec<SearchResult>> {
171 let array_ref = array
172 .as_any()
173 .downcast_ref::<E::Array>()
174 .vortex_expect("Failed to downcast array");
175 SearchSortedFn::search_sorted_many(self, array_ref, values, side)
176 }
177}
178
179impl<E: Encoding> SearchSortedUsizeFn<&dyn Array> for E
180where
181 E: for<'a> SearchSortedUsizeFn<&'a E::Array>,
182{
183 fn search_sorted_usize(
184 &self,
185 array: &dyn Array,
186 value: usize,
187 side: SearchSortedSide,
188 ) -> VortexResult<SearchResult> {
189 let array_ref = array
190 .as_any()
191 .downcast_ref::<E::Array>()
192 .vortex_expect("Failed to downcast array");
193 SearchSortedUsizeFn::search_sorted_usize(self, array_ref, value, side)
194 }
195
196 fn search_sorted_usize_many(
197 &self,
198 array: &dyn Array,
199 values: &[usize],
200 side: SearchSortedSide,
201 ) -> VortexResult<Vec<SearchResult>> {
202 let array_ref = array
203 .as_any()
204 .downcast_ref::<E::Array>()
205 .vortex_expect("Failed to downcast array");
206 SearchSortedUsizeFn::search_sorted_usize_many(self, array_ref, values, side)
207 }
208}
209
210pub fn search_sorted<T: Into<Scalar>>(
211 array: &dyn Array,
212 target: T,
213 side: SearchSortedSide,
214) -> VortexResult<SearchResult> {
215 let Ok(scalar) = target.into().cast(array.dtype()) else {
216 return Ok(SearchResult::NotFound(array.len()));
219 };
220
221 if scalar.is_null() {
222 vortex_bail!("Search sorted with null value is not supported");
223 }
224
225 if let Some(f) = array.vtable().search_sorted_fn() {
226 return f.search_sorted(array, &scalar, side);
227 }
228
229 if array.vtable().scalar_at_fn().is_some() {
231 return Ok(SearchSorted::search_sorted(array, &scalar, side));
232 }
233
234 vortex_bail!(
235 NotImplemented: "search_sorted",
236 array.encoding()
237 )
238}
239
240pub fn search_sorted_usize(
241 array: &dyn Array,
242 target: usize,
243 side: SearchSortedSide,
244) -> VortexResult<SearchResult> {
245 if let Some(f) = array.vtable().search_sorted_usize_fn() {
246 return f.search_sorted_usize(array, target, side);
247 }
248
249 let Ok(target) = Scalar::from(target).cast(array.dtype()) else {
251 return Ok(SearchResult::NotFound(array.len()));
252 };
253
254 if let Some(f) = array.vtable().search_sorted_fn() {
256 return f.search_sorted(array, &target, side);
257 }
258
259 if array.vtable().scalar_at_fn().is_some() {
261 let Ok(target) = target.cast(array.dtype()) else {
264 return Ok(SearchResult::NotFound(array.len()));
265 };
266 return Ok(SearchSorted::search_sorted(array, &target, side));
267 }
268
269 vortex_bail!(
270 NotImplemented: "search_sorted_usize",
271 array.encoding()
272 )
273}
274
275pub fn search_sorted_many<T: Into<Scalar> + Clone>(
277 array: &dyn Array,
278 targets: &[T],
279 side: SearchSortedSide,
280) -> VortexResult<Vec<SearchResult>> {
281 if let Some(f) = array.vtable().search_sorted_fn() {
282 let mut too_big_cast_idxs = Vec::new();
283 let values = targets
284 .iter()
285 .cloned()
286 .enumerate()
287 .filter_map(|(i, t)| {
288 let Ok(c) = t.into().cast(array.dtype()) else {
289 too_big_cast_idxs.push(i);
290 return None;
291 };
292 Some(c)
293 })
294 .collect::<Vec<_>>();
295
296 let mut results = f.search_sorted_many(array, &values, side)?;
297 for too_big_idx in too_big_cast_idxs {
298 results.insert(too_big_idx, SearchResult::NotFound(array.len()));
299 }
300 return Ok(results);
301 }
302
303 targets
305 .iter()
306 .map(|target| search_sorted(array, target.clone(), side))
307 .try_collect()
308}
309
310pub fn search_sorted_usize_many(
312 array: &dyn Array,
313 targets: &[usize],
314 side: SearchSortedSide,
315) -> VortexResult<Vec<SearchResult>> {
316 if let Some(f) = array.vtable().search_sorted_usize_fn() {
317 return f.search_sorted_usize_many(array, targets, side);
318 }
319
320 targets
322 .iter()
323 .map(|&target| search_sorted_usize(array, target, side))
324 .try_collect()
325}
326
/// An indexable, ordered collection whose elements can be compared against a value of type `V`.
#[allow(clippy::len_without_is_empty)]
pub trait IndexOrd<V> {
    /// Partially compares the element at `idx` with `elem`.
    ///
    /// For example, returns `Some(Greater)` when the element at `idx` is greater than `elem`,
    /// and `None` when the two cannot be ordered.
    fn index_cmp(&self, idx: usize, elem: &V) -> Option<Ordering>;

    /// Returns `true` only if the element at `idx` compares strictly less than `elem`.
    fn index_lt(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem) == Some(Less)
    }

    /// Returns `true` only if the element at `idx` compares less than or equal to `elem`.
    fn index_le(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem).is_some_and(Ordering::is_le)
    }

    /// Returns `true` only if the element at `idx` compares strictly greater than `elem`.
    fn index_gt(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem) == Some(Greater)
    }

    /// Returns `true` only if the element at `idx` compares greater than or equal to `elem`.
    fn index_ge(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem).is_some_and(Ordering::is_ge)
    }

    /// Number of elements in the underlying collection.
    fn len(&self) -> usize;
}
352
353pub trait SearchSorted<T> {
354 fn search_sorted(&self, value: &T, side: SearchSortedSide) -> SearchResult
355 where
356 Self: IndexOrd<T>,
357 {
358 match side {
359 SearchSortedSide::Left => self.search_sorted_by(
360 |idx| self.index_cmp(idx, value).unwrap_or(Less),
361 |idx| {
362 if self.index_lt(idx, value) {
363 Less
364 } else {
365 Greater
366 }
367 },
368 side,
369 ),
370 SearchSortedSide::Right => self.search_sorted_by(
371 |idx| self.index_cmp(idx, value).unwrap_or(Less),
372 |idx| {
373 if self.index_le(idx, value) {
374 Less
375 } else {
376 Greater
377 }
378 },
379 side,
380 ),
381 }
382 }
383
384 fn search_sorted_by<F: FnMut(usize) -> Ordering, N: FnMut(usize) -> Ordering>(
387 &self,
388 find: F,
389 side_find: N,
390 side: SearchSortedSide,
391 ) -> SearchResult;
392}
393
394impl<S, T> SearchSorted<T> for S
396where
397 S: IndexOrd<T> + ?Sized,
398{
399 fn search_sorted_by<F: FnMut(usize) -> Ordering, N: FnMut(usize) -> Ordering>(
400 &self,
401 find: F,
402 side_find: N,
403 side: SearchSortedSide,
404 ) -> SearchResult {
405 match search_sorted_side_idx(find, 0, self.len()) {
406 SearchResult::Found(found) => {
407 let idx_search = match side {
408 SearchSortedSide::Left => search_sorted_side_idx(side_find, 0, found),
409 SearchSortedSide::Right => search_sorted_side_idx(side_find, found, self.len()),
410 };
411 match idx_search {
412 SearchResult::NotFound(i) => SearchResult::Found(i),
413 _ => unreachable!(
414 "searching amongst equal values should never return Found result"
415 ),
416 }
417 }
418 s => s,
419 }
420 }
421}
422
423fn search_sorted_side_idx<F: FnMut(usize) -> Ordering>(
425 mut find: F,
426 from: usize,
427 to: usize,
428) -> SearchResult {
429 let mut size = to - from;
430 if size == 0 {
431 return SearchResult::NotFound(0);
432 }
433 let mut base = from;
434
435 while size > 1 {
440 let half = size / 2;
441 let mid = base + half;
442
443 let cmp = find(mid);
447
448 base = if cmp == Greater { base } else { mid };
452
453 size -= half;
462 }
463
464 let cmp = find(base);
466 if cmp == Equal {
467 unsafe { hint::assert_unchecked(base < to) };
469 SearchResult::Found(base)
470 } else {
471 let result = base + (cmp == Less) as usize;
472 unsafe { hint::assert_unchecked(result <= to) };
475 SearchResult::NotFound(result)
476 }
477}
478
479impl IndexOrd<Scalar> for dyn Array + '_ {
480 fn index_cmp(&self, idx: usize, elem: &Scalar) -> Option<Ordering> {
481 let scalar_a = scalar_at(self, idx).ok()?;
482 scalar_a.partial_cmp(elem)
483 }
484
485 fn len(&self) -> usize {
486 Self::len(self)
487 }
488}
489
490impl<T: PartialOrd> IndexOrd<T> for [T] {
491 fn index_cmp(&self, idx: usize, elem: &T) -> Option<Ordering> {
492 unsafe { self.get_unchecked(idx) }.partial_cmp(elem)
494 }
495
496 fn len(&self) -> usize {
497 self.len()
498 }
499}
500
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::compute::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};
    use crate::compute::{search_sorted, search_sorted_many};

    /// `Left` on a run in the middle reports the first element of the run.
    #[test]
    fn left_side_equal() {
        let sorted = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
        let outcome = sorted.search_sorted(&2, SearchSortedSide::Left);
        assert_eq!(sorted[outcome.to_index()], 2);
        assert_eq!(outcome, SearchResult::Found(2));
    }

    /// `Right` on a run in the middle reports one past the last element of the run.
    #[test]
    fn right_side_equal() {
        let sorted = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
        let outcome = sorted.search_sorted(&2, SearchSortedSide::Right);
        assert_eq!(sorted[outcome.to_index() - 1], 2);
        assert_eq!(outcome, SearchResult::Found(6));
    }

    /// `Left` on a run at the very start reports index 0.
    #[test]
    fn left_side_equal_beginning() {
        let sorted = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let outcome = sorted.search_sorted(&0, SearchSortedSide::Left);
        assert_eq!(sorted[outcome.to_index()], 0);
        assert_eq!(outcome, SearchResult::Found(0));
    }

    /// `Right` on a run at the very start reports one past the run.
    #[test]
    fn right_side_equal_beginning() {
        let sorted = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let outcome = sorted.search_sorted(&0, SearchSortedSide::Right);
        assert_eq!(sorted[outcome.to_index() - 1], 0);
        assert_eq!(outcome, SearchResult::Found(4));
    }

    /// `Left` on a run at the very end reports the first element of the run.
    #[test]
    fn left_side_equal_end() {
        let sorted = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
        let outcome = sorted.search_sorted(&9, SearchSortedSide::Left);
        assert_eq!(sorted[outcome.to_index()], 9);
        assert_eq!(outcome, SearchResult::Found(9));
    }

    /// `Right` on a run at the very end reports the length of the input.
    #[test]
    fn right_side_equal_end() {
        let sorted = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
        let outcome = sorted.search_sorted(&9, SearchSortedSide::Right);
        assert_eq!(sorted[outcome.to_index() - 1], 9);
        assert_eq!(outcome, SearchResult::Found(13));
    }

    /// A target that does not fit the array's dtype is reported as past the end.
    #[test]
    fn failed_cast() {
        let array = buffer![0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9].into_array();
        let outcome = search_sorted(&array, 256, SearchSortedSide::Left).unwrap();
        assert_eq!(outcome, SearchResult::NotFound(array.len()));
    }

    /// The batch entry point treats an uncastable target the same way.
    #[test]
    fn search_sorted_many_failed_cast() {
        let array = buffer![0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9].into_array();
        let outcomes = search_sorted_many(&array, &[256], SearchSortedSide::Left).unwrap();
        assert_eq!(outcomes, vec![SearchResult::NotFound(array.len())]);
    }
}
570}