1use std::cmp::Ordering;
2use std::cmp::Ordering::{Equal, Greater, Less};
3use std::fmt::{Debug, Display, Formatter};
4use std::hint;
5
6use itertools::Itertools;
7use vortex_error::{VortexExpect, VortexResult, vortex_bail};
8use vortex_scalar::Scalar;
9
10use crate::Array;
11use crate::compute::scalar_at;
12use crate::encoding::Encoding;
13
/// Selects which boundary of a run of equal elements a sorted search reports.
///
/// With `Left` the search resolves to the first element equal to the target; with `Right` it
/// resolves to the position just past the last equal element.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum SearchSortedSide {
    /// Report the leftmost position among equal elements.
    Left,
    /// Report the position one past the rightmost equal element.
    Right,
}
19
20impl Display for SearchSortedSide {
21 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
22 match self {
23 SearchSortedSide::Left => write!(f, "left"),
24 SearchSortedSide::Right => write!(f, "right"),
25 }
26 }
27}
28
/// Outcome of a sorted search: either a match or the position where the target would be inserted.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum SearchResult {
    /// An element equal to the target exists. The index depends on the requested side: for a
    /// left search it is the first equal element, for a right search it is one past the last
    /// equal element (so it may equal the length of the searched sequence).
    Found(usize),

    /// No element compared equal to the target. The index is the position at which the target
    /// could be inserted while keeping the sequence sorted.
    NotFound(usize),
}
39
40impl SearchResult {
41 pub fn to_found(self) -> Option<usize> {
43 match self {
44 Self::Found(i) => Some(i),
45 Self::NotFound(_) => None,
46 }
47 }
48
49 pub fn to_index(self) -> usize {
51 match self {
52 Self::Found(i) => i,
53 Self::NotFound(i) => i,
54 }
55 }
56
57 pub fn to_offsets_index(self, len: usize) -> usize {
62 match self {
63 SearchResult::Found(i) => {
64 if i == len {
65 i - 1
66 } else {
67 i
68 }
69 }
70 SearchResult::NotFound(i) => i.saturating_sub(1),
71 }
72 }
73
74 pub fn to_ends_index(self, len: usize) -> usize {
80 let idx = self.to_index();
81 if idx == len { idx - 1 } else { idx }
82 }
83
84 #[inline]
86 pub fn map<F>(self, f: F) -> SearchResult
87 where
88 F: FnOnce(usize) -> usize,
89 {
90 match self {
91 SearchResult::Found(i) => SearchResult::Found(f(i)),
92 SearchResult::NotFound(i) => SearchResult::NotFound(f(i)),
93 }
94 }
95}
96
97impl Display for SearchResult {
98 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
99 match self {
100 SearchResult::Found(i) => write!(f, "Found({i})"),
101 SearchResult::NotFound(i) => write!(f, "NotFound({i})"),
102 }
103 }
104}
105
106pub trait SearchSortedFn<A: Copy> {
110 fn search_sorted(
111 &self,
112 array: A,
113 value: &Scalar,
114 side: SearchSortedSide,
115 ) -> VortexResult<SearchResult>;
116
117 fn search_sorted_many(
119 &self,
120 array: A,
121 values: &[Scalar],
122 side: SearchSortedSide,
123 ) -> VortexResult<Vec<SearchResult>> {
124 values
125 .iter()
126 .map(|value| self.search_sorted(array, value, side))
127 .try_collect()
128 }
129}
130
131pub trait SearchSortedUsizeFn<A: Copy> {
132 fn search_sorted_usize(
133 &self,
134 array: A,
135 value: usize,
136 side: SearchSortedSide,
137 ) -> VortexResult<SearchResult>;
138
139 fn search_sorted_usize_many(
140 &self,
141 array: A,
142 values: &[usize],
143 side: SearchSortedSide,
144 ) -> VortexResult<Vec<SearchResult>> {
145 values
146 .iter()
147 .map(|&value| self.search_sorted_usize(array, value, side))
148 .try_collect()
149 }
150}
151
152impl<E: Encoding> SearchSortedFn<&dyn Array> for E
153where
154 E: for<'a> SearchSortedFn<&'a E::Array>,
155{
156 fn search_sorted(
157 &self,
158 array: &dyn Array,
159 value: &Scalar,
160 side: SearchSortedSide,
161 ) -> VortexResult<SearchResult> {
162 let array_ref = array
163 .as_any()
164 .downcast_ref::<E::Array>()
165 .vortex_expect("Failed to downcast array");
166 SearchSortedFn::search_sorted(self, array_ref, value, side)
167 }
168
169 fn search_sorted_many(
170 &self,
171 array: &dyn Array,
172 values: &[Scalar],
173 side: SearchSortedSide,
174 ) -> VortexResult<Vec<SearchResult>> {
175 let array_ref = array
176 .as_any()
177 .downcast_ref::<E::Array>()
178 .vortex_expect("Failed to downcast array");
179 SearchSortedFn::search_sorted_many(self, array_ref, values, side)
180 }
181}
182
183impl<E: Encoding> SearchSortedUsizeFn<&dyn Array> for E
184where
185 E: for<'a> SearchSortedUsizeFn<&'a E::Array>,
186{
187 fn search_sorted_usize(
188 &self,
189 array: &dyn Array,
190 value: usize,
191 side: SearchSortedSide,
192 ) -> VortexResult<SearchResult> {
193 let array_ref = array
194 .as_any()
195 .downcast_ref::<E::Array>()
196 .vortex_expect("Failed to downcast array");
197 SearchSortedUsizeFn::search_sorted_usize(self, array_ref, value, side)
198 }
199
200 fn search_sorted_usize_many(
201 &self,
202 array: &dyn Array,
203 values: &[usize],
204 side: SearchSortedSide,
205 ) -> VortexResult<Vec<SearchResult>> {
206 let array_ref = array
207 .as_any()
208 .downcast_ref::<E::Array>()
209 .vortex_expect("Failed to downcast array");
210 SearchSortedUsizeFn::search_sorted_usize_many(self, array_ref, values, side)
211 }
212}
213
214pub fn search_sorted<T: Into<Scalar>>(
215 array: &dyn Array,
216 target: T,
217 side: SearchSortedSide,
218) -> VortexResult<SearchResult> {
219 let Ok(scalar) = target.into().cast(array.dtype()) else {
220 return Ok(SearchResult::NotFound(array.len()));
223 };
224
225 if scalar.is_null() {
226 vortex_bail!("Search sorted with null value is not supported");
227 }
228
229 if let Some(f) = array.vtable().search_sorted_fn() {
230 return f.search_sorted(array, &scalar, side);
231 }
232
233 if array.vtable().scalar_at_fn().is_some() {
235 return Ok(SearchSorted::search_sorted(array, &scalar, side));
236 }
237
238 vortex_bail!(
239 NotImplemented: "search_sorted",
240 array.encoding()
241 )
242}
243
244pub fn search_sorted_usize(
245 array: &dyn Array,
246 target: usize,
247 side: SearchSortedSide,
248) -> VortexResult<SearchResult> {
249 if let Some(f) = array.vtable().search_sorted_usize_fn() {
250 return f.search_sorted_usize(array, target, side);
251 }
252
253 let Ok(target) = Scalar::from(target).cast(array.dtype()) else {
255 return Ok(SearchResult::NotFound(array.len()));
256 };
257
258 if let Some(f) = array.vtable().search_sorted_fn() {
260 return f.search_sorted(array, &target, side);
261 }
262
263 if array.vtable().scalar_at_fn().is_some() {
265 let Ok(target) = target.cast(array.dtype()) else {
268 return Ok(SearchResult::NotFound(array.len()));
269 };
270 return Ok(SearchSorted::search_sorted(array, &target, side));
271 }
272
273 vortex_bail!(
274 NotImplemented: "search_sorted_usize",
275 array.encoding()
276 )
277}
278
279pub fn search_sorted_many<T: Into<Scalar> + Clone>(
281 array: &dyn Array,
282 targets: &[T],
283 side: SearchSortedSide,
284) -> VortexResult<Vec<SearchResult>> {
285 if let Some(f) = array.vtable().search_sorted_fn() {
286 let mut too_big_cast_idxs = Vec::new();
287 let values = targets
288 .iter()
289 .cloned()
290 .enumerate()
291 .filter_map(|(i, t)| {
292 let Ok(c) = t.into().cast(array.dtype()) else {
293 too_big_cast_idxs.push(i);
294 return None;
295 };
296 Some(c)
297 })
298 .collect::<Vec<_>>();
299
300 let mut results = f.search_sorted_many(array, &values, side)?;
301 for too_big_idx in too_big_cast_idxs {
302 results.insert(too_big_idx, SearchResult::NotFound(array.len()));
303 }
304 return Ok(results);
305 }
306
307 targets
309 .iter()
310 .map(|target| search_sorted(array, target.clone(), side))
311 .try_collect()
312}
313
314pub fn search_sorted_usize_many(
316 array: &dyn Array,
317 targets: &[usize],
318 side: SearchSortedSide,
319) -> VortexResult<Vec<SearchResult>> {
320 if let Some(f) = array.vtable().search_sorted_usize_fn() {
321 return f.search_sorted_usize_many(array, targets, side);
322 }
323
324 targets
326 .iter()
327 .map(|&target| search_sorted_usize(array, target, side))
328 .try_collect()
329}
330
/// Comparison of the element stored at an index against an external value of type `V`.
pub trait IndexOrd<V> {
    /// Orders the element at `idx` relative to `elem`, or returns `None` when the two cannot be
    /// compared.
    fn index_cmp(&self, idx: usize, elem: &V) -> Option<Ordering>;

    /// `true` only when the element at `idx` is strictly less than `elem`.
    fn index_lt(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem) == Some(Less)
    }

    /// `true` only when the element at `idx` is less than or equal to `elem`.
    fn index_le(&self, idx: usize, elem: &V) -> bool {
        match self.index_cmp(idx, elem) {
            Some(Less | Equal) => true,
            Some(Greater) | None => false,
        }
    }

    /// `true` only when the element at `idx` is strictly greater than `elem`.
    fn index_gt(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem) == Some(Greater)
    }

    /// `true` only when the element at `idx` is greater than or equal to `elem`.
    fn index_ge(&self, idx: usize, elem: &V) -> bool {
        match self.index_cmp(idx, elem) {
            Some(Greater | Equal) => true,
            Some(Less) | None => false,
        }
    }
}
352
/// Length accessor used as a bound by the blanket `SearchSorted` implementation, which needs
/// the number of elements to bound its binary search.
// The trait exposes only `len`, so the clippy lint asking for a matching `is_empty` is silenced.
#[allow(clippy::len_without_is_empty)]
pub trait Len {
    /// Returns the number of elements.
    fn len(&self) -> usize;
}
357
358pub trait SearchSorted<T> {
359 fn search_sorted(&self, value: &T, side: SearchSortedSide) -> SearchResult
360 where
361 Self: IndexOrd<T>,
362 {
363 match side {
364 SearchSortedSide::Left => self.search_sorted_by(
365 |idx| self.index_cmp(idx, value).unwrap_or(Less),
366 |idx| {
367 if self.index_lt(idx, value) {
368 Less
369 } else {
370 Greater
371 }
372 },
373 side,
374 ),
375 SearchSortedSide::Right => self.search_sorted_by(
376 |idx| self.index_cmp(idx, value).unwrap_or(Less),
377 |idx| {
378 if self.index_le(idx, value) {
379 Less
380 } else {
381 Greater
382 }
383 },
384 side,
385 ),
386 }
387 }
388
389 fn search_sorted_by<F: FnMut(usize) -> Ordering, N: FnMut(usize) -> Ordering>(
392 &self,
393 find: F,
394 side_find: N,
395 side: SearchSortedSide,
396 ) -> SearchResult;
397}
398
399impl<S, T> SearchSorted<T> for S
401where
402 S: IndexOrd<T> + Len + ?Sized,
403{
404 fn search_sorted_by<F: FnMut(usize) -> Ordering, N: FnMut(usize) -> Ordering>(
405 &self,
406 find: F,
407 side_find: N,
408 side: SearchSortedSide,
409 ) -> SearchResult {
410 match search_sorted_side_idx(find, 0, self.len()) {
411 SearchResult::Found(found) => {
412 let idx_search = match side {
413 SearchSortedSide::Left => search_sorted_side_idx(side_find, 0, found),
414 SearchSortedSide::Right => search_sorted_side_idx(side_find, found, self.len()),
415 };
416 match idx_search {
417 SearchResult::NotFound(i) => SearchResult::Found(i),
418 _ => unreachable!(
419 "searching amongst equal values should never return Found result"
420 ),
421 }
422 }
423 s => s,
424 }
425 }
426}
427
428fn search_sorted_side_idx<F: FnMut(usize) -> Ordering>(
430 mut find: F,
431 from: usize,
432 to: usize,
433) -> SearchResult {
434 let mut size = to - from;
435 if size == 0 {
436 return SearchResult::NotFound(0);
437 }
438 let mut base = from;
439
440 while size > 1 {
445 let half = size / 2;
446 let mid = base + half;
447
448 let cmp = find(mid);
452
453 base = if cmp == Greater { base } else { mid };
457
458 size -= half;
467 }
468
469 let cmp = find(base);
471 if cmp == Equal {
472 unsafe { hint::assert_unchecked(base < to) };
474 SearchResult::Found(base)
475 } else {
476 let result = base + (cmp == Less) as usize;
477 unsafe { hint::assert_unchecked(result <= to) };
480 SearchResult::NotFound(result)
481 }
482}
483
484impl IndexOrd<Scalar> for dyn Array + '_ {
485 fn index_cmp(&self, idx: usize, elem: &Scalar) -> Option<Ordering> {
486 let scalar_a = scalar_at(self, idx).ok()?;
487 scalar_a.partial_cmp(elem)
488 }
489}
490
491impl<T: PartialOrd> IndexOrd<T> for [T] {
492 fn index_cmp(&self, idx: usize, elem: &T) -> Option<Ordering> {
493 unsafe { self.get_unchecked(idx) }.partial_cmp(elem)
495 }
496}
497
/// Exposes the array's length through the `Len` trait.
impl Len for dyn Array + '_ {
    // This method has the same name as the `len` it forwards to, which is what the silenced
    // clippy lint would otherwise flag.
    #[allow(clippy::same_name_method)]
    fn len(&self) -> usize {
        // NOTE(review): `Self::len` is expected to resolve to the `len` already available on
        // `dyn Array` rather than to this trait method (which would recurse) — confirm if the
        // `Array` API changes.
        Self::len(self)
    }
}
504
505impl<T> Len for [T] {
506 fn len(&self) -> usize {
507 self.len()
508 }
509}
510
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::compute::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};
    use crate::compute::{search_sorted, search_sorted_many};

    // --- Runs of equal elements: `Left` reports the first, `Right` one past the last. ---

    #[test]
    fn left_side_equal() {
        let arr = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
        let res = arr.search_sorted(&2, SearchSortedSide::Left);
        assert_eq!(arr[res.to_index()], 2);
        assert_eq!(res, SearchResult::Found(2));
    }

    #[test]
    fn right_side_equal() {
        let arr = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
        let res = arr.search_sorted(&2, SearchSortedSide::Right);
        assert_eq!(arr[res.to_index() - 1], 2);
        assert_eq!(res, SearchResult::Found(6));
    }

    #[test]
    fn left_side_equal_beginning() {
        let arr = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let res = arr.search_sorted(&0, SearchSortedSide::Left);
        assert_eq!(arr[res.to_index()], 0);
        assert_eq!(res, SearchResult::Found(0));
    }

    #[test]
    fn right_side_equal_beginning() {
        let arr = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let res = arr.search_sorted(&0, SearchSortedSide::Right);
        assert_eq!(arr[res.to_index() - 1], 0);
        assert_eq!(res, SearchResult::Found(4));
    }

    #[test]
    fn left_side_equal_end() {
        let arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
        let res = arr.search_sorted(&9, SearchSortedSide::Left);
        assert_eq!(arr[res.to_index()], 9);
        assert_eq!(res, SearchResult::Found(9));
    }

    #[test]
    fn right_side_equal_end() {
        let arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
        let res = arr.search_sorted(&9, SearchSortedSide::Right);
        assert_eq!(arr[res.to_index() - 1], 9);
        assert_eq!(res, SearchResult::Found(13));
    }

    // --- Misses: the result carries the insertion point, independent of the side. ---

    #[test]
    fn not_found_between_elements() {
        let arr = [0, 1, 3, 4];
        assert_eq!(
            arr.search_sorted(&2, SearchSortedSide::Left),
            SearchResult::NotFound(2)
        );
        assert_eq!(
            arr.search_sorted(&2, SearchSortedSide::Right),
            SearchResult::NotFound(2)
        );
    }

    #[test]
    fn not_found_outside_range() {
        let arr = [0, 1, 3, 4];
        assert_eq!(
            arr.search_sorted(&-1, SearchSortedSide::Left),
            SearchResult::NotFound(0)
        );
        assert_eq!(
            arr.search_sorted(&10, SearchSortedSide::Right),
            SearchResult::NotFound(arr.len())
        );
    }

    #[test]
    fn empty_slice() {
        let arr: [i32; 0] = [];
        assert_eq!(
            arr.search_sorted(&1, SearchSortedSide::Left),
            SearchResult::NotFound(0)
        );
        assert_eq!(
            arr.search_sorted(&1, SearchSortedSide::Right),
            SearchResult::NotFound(0)
        );
    }

    // --- Targets that cannot be cast to the array's dtype sort after every element. ---

    #[test]
    fn failed_cast() {
        let arr = buffer![0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9].into_array();
        let res = search_sorted(&arr, 256, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::NotFound(arr.len()));
    }

    #[test]
    fn search_sorted_many_failed_cast() {
        let arr = buffer![0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9].into_array();
        let res = search_sorted_many(&arr, &[256], SearchSortedSide::Left).unwrap();
        assert_eq!(res, vec![SearchResult::NotFound(arr.len())]);
    }
}