1use std::cmp::Ordering;
2use std::cmp::Ordering::{Equal, Greater, Less};
3use std::fmt::{Debug, Display, Formatter};
4use std::hint;
5
6use itertools::Itertools;
7use vortex_error::{VortexExpect, VortexResult, vortex_bail};
8use vortex_scalar::Scalar;
9
10use crate::Array;
11use crate::compute::scalar_at;
12use crate::encoding::Encoding;
13
/// Selects which boundary of a run of equal values a sorted search reports.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SearchSortedSide {
    /// Report the position of the first element equal to the searched value.
    Left,
    /// Report the position one past the last element equal to the searched value.
    Right,
}
19
20impl Display for SearchSortedSide {
21    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
22        match self {
23            SearchSortedSide::Left => write!(f, "left"),
24            SearchSortedSide::Right => write!(f, "right"),
25        }
26    }
27}
28
/// Outcome of searching a sorted sequence for a value.
///
/// Both variants carry an index; which index is carried for a found value depends on the
/// [`SearchSortedSide`] that was requested.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SearchResult {
    /// The value is present. The payload is the boundary index selected by the requested side.
    Found(usize),

    /// The value is absent. The payload is the index at which it could be inserted while
    /// keeping the sequence sorted.
    NotFound(usize),
}
42
43impl SearchResult {
44    pub fn to_found(self) -> Option<usize> {
46        match self {
47            Self::Found(i) => Some(i),
48            Self::NotFound(_) => None,
49        }
50    }
51
52    pub fn to_index(self) -> usize {
54        match self {
55            Self::Found(i) => i,
56            Self::NotFound(i) => i,
57        }
58    }
59
60    pub fn to_offsets_index(self, len: usize, side: SearchSortedSide) -> usize {
65        match self {
66            SearchResult::Found(i) => {
67                if side == SearchSortedSide::Right || i == len {
68                    i.saturating_sub(1)
69                } else {
70                    i
71                }
72            }
73            SearchResult::NotFound(i) => i.saturating_sub(1),
74        }
75    }
76
77    pub fn to_ends_index(self, len: usize) -> usize {
83        let idx = self.to_index();
84        if idx == len { idx - 1 } else { idx }
85    }
86}
87
88impl Display for SearchResult {
89    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
90        match self {
91            SearchResult::Found(i) => write!(f, "Found({i})"),
92            SearchResult::NotFound(i) => write!(f, "NotFound({i})"),
93        }
94    }
95}
96
97pub trait SearchSortedFn<A: Copy> {
106    fn search_sorted(
107        &self,
108        array: A,
109        value: &Scalar,
110        side: SearchSortedSide,
111    ) -> VortexResult<SearchResult>;
112
113    fn search_sorted_many(
115        &self,
116        array: A,
117        values: &[Scalar],
118        side: SearchSortedSide,
119    ) -> VortexResult<Vec<SearchResult>> {
120        values
121            .iter()
122            .map(|value| self.search_sorted(array, value, side))
123            .try_collect()
124    }
125}
126
127pub trait SearchSortedUsizeFn<A: Copy> {
128    fn search_sorted_usize(
129        &self,
130        array: A,
131        value: usize,
132        side: SearchSortedSide,
133    ) -> VortexResult<SearchResult>;
134
135    fn search_sorted_usize_many(
136        &self,
137        array: A,
138        values: &[usize],
139        side: SearchSortedSide,
140    ) -> VortexResult<Vec<SearchResult>> {
141        values
142            .iter()
143            .map(|&value| self.search_sorted_usize(array, value, side))
144            .try_collect()
145    }
146}
147
148impl<E: Encoding> SearchSortedFn<&dyn Array> for E
149where
150    E: for<'a> SearchSortedFn<&'a E::Array>,
151{
152    fn search_sorted(
153        &self,
154        array: &dyn Array,
155        value: &Scalar,
156        side: SearchSortedSide,
157    ) -> VortexResult<SearchResult> {
158        let array_ref = array
159            .as_any()
160            .downcast_ref::<E::Array>()
161            .vortex_expect("Failed to downcast array");
162        SearchSortedFn::search_sorted(self, array_ref, value, side)
163    }
164
165    fn search_sorted_many(
166        &self,
167        array: &dyn Array,
168        values: &[Scalar],
169        side: SearchSortedSide,
170    ) -> VortexResult<Vec<SearchResult>> {
171        let array_ref = array
172            .as_any()
173            .downcast_ref::<E::Array>()
174            .vortex_expect("Failed to downcast array");
175        SearchSortedFn::search_sorted_many(self, array_ref, values, side)
176    }
177}
178
179impl<E: Encoding> SearchSortedUsizeFn<&dyn Array> for E
180where
181    E: for<'a> SearchSortedUsizeFn<&'a E::Array>,
182{
183    fn search_sorted_usize(
184        &self,
185        array: &dyn Array,
186        value: usize,
187        side: SearchSortedSide,
188    ) -> VortexResult<SearchResult> {
189        let array_ref = array
190            .as_any()
191            .downcast_ref::<E::Array>()
192            .vortex_expect("Failed to downcast array");
193        SearchSortedUsizeFn::search_sorted_usize(self, array_ref, value, side)
194    }
195
196    fn search_sorted_usize_many(
197        &self,
198        array: &dyn Array,
199        values: &[usize],
200        side: SearchSortedSide,
201    ) -> VortexResult<Vec<SearchResult>> {
202        let array_ref = array
203            .as_any()
204            .downcast_ref::<E::Array>()
205            .vortex_expect("Failed to downcast array");
206        SearchSortedUsizeFn::search_sorted_usize_many(self, array_ref, values, side)
207    }
208}
209
210pub fn search_sorted<T: Into<Scalar>>(
211    array: &dyn Array,
212    target: T,
213    side: SearchSortedSide,
214) -> VortexResult<SearchResult> {
215    let Ok(scalar) = target.into().cast(array.dtype()) else {
216        return Ok(SearchResult::NotFound(array.len()));
219    };
220
221    if scalar.is_null() {
222        vortex_bail!("Search sorted with null value is not supported");
223    }
224
225    if let Some(f) = array.vtable().search_sorted_fn() {
226        return f.search_sorted(array, &scalar, side);
227    }
228
229    if array.vtable().scalar_at_fn().is_some() {
231        return Ok(SearchSorted::search_sorted(array, &scalar, side));
232    }
233
234    vortex_bail!(
235        NotImplemented: "search_sorted",
236        array.encoding()
237    )
238}
239
240pub fn search_sorted_usize(
241    array: &dyn Array,
242    target: usize,
243    side: SearchSortedSide,
244) -> VortexResult<SearchResult> {
245    if let Some(f) = array.vtable().search_sorted_usize_fn() {
246        return f.search_sorted_usize(array, target, side);
247    }
248
249    let Ok(target) = Scalar::from(target).cast(array.dtype()) else {
251        return Ok(SearchResult::NotFound(array.len()));
252    };
253
254    if let Some(f) = array.vtable().search_sorted_fn() {
256        return f.search_sorted(array, &target, side);
257    }
258
259    if array.vtable().scalar_at_fn().is_some() {
261        let Ok(target) = target.cast(array.dtype()) else {
264            return Ok(SearchResult::NotFound(array.len()));
265        };
266        return Ok(SearchSorted::search_sorted(array, &target, side));
267    }
268
269    vortex_bail!(
270    NotImplemented: "search_sorted_usize",
271        array.encoding()
272    )
273}
274
275pub fn search_sorted_many<T: Into<Scalar> + Clone>(
277    array: &dyn Array,
278    targets: &[T],
279    side: SearchSortedSide,
280) -> VortexResult<Vec<SearchResult>> {
281    if let Some(f) = array.vtable().search_sorted_fn() {
282        let mut too_big_cast_idxs = Vec::new();
283        let values = targets
284            .iter()
285            .cloned()
286            .enumerate()
287            .filter_map(|(i, t)| {
288                let Ok(c) = t.into().cast(array.dtype()) else {
289                    too_big_cast_idxs.push(i);
290                    return None;
291                };
292                Some(c)
293            })
294            .collect::<Vec<_>>();
295
296        let mut results = f.search_sorted_many(array, &values, side)?;
297        for too_big_idx in too_big_cast_idxs {
298            results.insert(too_big_idx, SearchResult::NotFound(array.len()));
299        }
300        return Ok(results);
301    }
302
303    targets
305        .iter()
306        .map(|target| search_sorted(array, target.clone(), side))
307        .try_collect()
308}
309
310pub fn search_sorted_usize_many(
312    array: &dyn Array,
313    targets: &[usize],
314    side: SearchSortedSide,
315) -> VortexResult<Vec<SearchResult>> {
316    if let Some(f) = array.vtable().search_sorted_usize_fn() {
317        return f.search_sorted_usize_many(array, targets, side);
318    }
319
320    targets
322        .iter()
323        .map(|&target| search_sorted_usize(array, target, side))
324        .try_collect()
325}
326
/// Comparison of the element stored at an index against an external value.
///
/// The predicate helpers all derive from [`IndexOrd::index_cmp`] and return `false` whenever it
/// yields `None` (the two values are not comparable).
#[allow(clippy::len_without_is_empty)]
pub trait IndexOrd<V> {
    /// Compares the element at `idx` with `elem`; `None` means no ordering could be determined.
    fn index_cmp(&self, idx: usize, elem: &V) -> Option<Ordering>;

    /// `true` when the element at `idx` is strictly less than `elem`.
    fn index_lt(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem) == Some(Less)
    }

    /// `true` when the element at `idx` is less than or equal to `elem`.
    fn index_le(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem).is_some_and(Ordering::is_le)
    }

    /// `true` when the element at `idx` is strictly greater than `elem`.
    fn index_gt(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem) == Some(Greater)
    }

    /// `true` when the element at `idx` is greater than or equal to `elem`.
    fn index_ge(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem).is_some_and(Ordering::is_ge)
    }

    /// Number of indexable elements.
    fn len(&self) -> usize;
}
352
353pub trait SearchSorted<T> {
354    fn search_sorted(&self, value: &T, side: SearchSortedSide) -> SearchResult
355    where
356        Self: IndexOrd<T>,
357    {
358        match side {
359            SearchSortedSide::Left => self.search_sorted_by(
360                |idx| self.index_cmp(idx, value).unwrap_or(Less),
361                |idx| {
362                    if self.index_lt(idx, value) {
363                        Less
364                    } else {
365                        Greater
366                    }
367                },
368                side,
369            ),
370            SearchSortedSide::Right => self.search_sorted_by(
371                |idx| self.index_cmp(idx, value).unwrap_or(Less),
372                |idx| {
373                    if self.index_le(idx, value) {
374                        Less
375                    } else {
376                        Greater
377                    }
378                },
379                side,
380            ),
381        }
382    }
383
384    fn search_sorted_by<F: FnMut(usize) -> Ordering, N: FnMut(usize) -> Ordering>(
387        &self,
388        find: F,
389        side_find: N,
390        side: SearchSortedSide,
391    ) -> SearchResult;
392}
393
394impl<S, T> SearchSorted<T> for S
396where
397    S: IndexOrd<T> + ?Sized,
398{
399    fn search_sorted_by<F: FnMut(usize) -> Ordering, N: FnMut(usize) -> Ordering>(
400        &self,
401        find: F,
402        side_find: N,
403        side: SearchSortedSide,
404    ) -> SearchResult {
405        match search_sorted_side_idx(find, 0, self.len()) {
406            SearchResult::Found(found) => {
407                let idx_search = match side {
408                    SearchSortedSide::Left => search_sorted_side_idx(side_find, 0, found),
409                    SearchSortedSide::Right => search_sorted_side_idx(side_find, found, self.len()),
410                };
411                match idx_search {
412                    SearchResult::NotFound(i) => SearchResult::Found(i),
413                    _ => unreachable!(
414                        "searching amongst equal values should never return Found result"
415                    ),
416                }
417            }
418            s => s,
419        }
420    }
421}
422
423fn search_sorted_side_idx<F: FnMut(usize) -> Ordering>(
425    mut find: F,
426    from: usize,
427    to: usize,
428) -> SearchResult {
429    let mut size = to - from;
430    if size == 0 {
431        return SearchResult::NotFound(0);
432    }
433    let mut base = from;
434
435    while size > 1 {
440        let half = size / 2;
441        let mid = base + half;
442
443        let cmp = find(mid);
447
448        base = if cmp == Greater { base } else { mid };
452
453        size -= half;
462    }
463
464    let cmp = find(base);
466    if cmp == Equal {
467        unsafe { hint::assert_unchecked(base < to) };
469        SearchResult::Found(base)
470    } else {
471        let result = base + (cmp == Less) as usize;
472        unsafe { hint::assert_unchecked(result <= to) };
475        SearchResult::NotFound(result)
476    }
477}
478
479impl IndexOrd<Scalar> for dyn Array + '_ {
480    fn index_cmp(&self, idx: usize, elem: &Scalar) -> Option<Ordering> {
481        let scalar_a = scalar_at(self, idx).ok()?;
482        scalar_a.partial_cmp(elem)
483    }
484
485    fn len(&self) -> usize {
486        Self::len(self)
487    }
488}
489
490impl<T: PartialOrd> IndexOrd<T> for [T] {
491    fn index_cmp(&self, idx: usize, elem: &T) -> Option<Ordering> {
492        unsafe { self.get_unchecked(idx) }.partial_cmp(elem)
494    }
495
496    fn len(&self) -> usize {
497        self.len()
498    }
499}
500
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::compute::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};
    use crate::compute::{search_sorted, search_sorted_many};

    /// A run of equal values in the middle: `Left` lands on the first of them.
    #[test]
    fn left_side_equal() {
        let haystack = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
        let hit = haystack.search_sorted(&2, SearchSortedSide::Left);
        assert_eq!(haystack[hit.to_index()], 2);
        assert_eq!(hit, SearchResult::Found(2));
    }

    /// A run of equal values in the middle: `Right` lands one past the last of them.
    #[test]
    fn right_side_equal() {
        let haystack = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
        let hit = haystack.search_sorted(&2, SearchSortedSide::Right);
        assert_eq!(haystack[hit.to_index() - 1], 2);
        assert_eq!(hit, SearchResult::Found(6));
    }

    /// A run of equal values at the start: `Left` lands on index 0.
    #[test]
    fn left_side_equal_beginning() {
        let haystack = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let hit = haystack.search_sorted(&0, SearchSortedSide::Left);
        assert_eq!(haystack[hit.to_index()], 0);
        assert_eq!(hit, SearchResult::Found(0));
    }

    /// A run of equal values at the start: `Right` lands one past the run.
    #[test]
    fn right_side_equal_beginning() {
        let haystack = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let hit = haystack.search_sorted(&0, SearchSortedSide::Right);
        assert_eq!(haystack[hit.to_index() - 1], 0);
        assert_eq!(hit, SearchResult::Found(4));
    }

    /// A run of equal values at the end: `Left` lands on the first of them.
    #[test]
    fn left_side_equal_end() {
        let haystack = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
        let hit = haystack.search_sorted(&9, SearchSortedSide::Left);
        assert_eq!(haystack[hit.to_index()], 9);
        assert_eq!(hit, SearchResult::Found(9));
    }

    /// A run of equal values at the end: `Right` lands at the length of the input.
    #[test]
    fn right_side_equal_end() {
        let haystack = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
        let hit = haystack.search_sorted(&9, SearchSortedSide::Right);
        assert_eq!(haystack[hit.to_index() - 1], 9);
        assert_eq!(hit, SearchResult::Found(13));
    }

    /// A target that does not fit the array's `u8` dtype is reported as not found at the end.
    #[test]
    fn failed_cast() {
        let array = buffer![0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9].into_array();
        let outcome = search_sorted(&array, 256, SearchSortedSide::Left).unwrap();
        assert_eq!(outcome, SearchResult::NotFound(array.len()));
    }

    /// Same as `failed_cast`, through the batch entry point.
    #[test]
    fn search_sorted_many_failed_cast() {
        let array = buffer![0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9].into_array();
        let outcomes = search_sorted_many(&array, &[256], SearchSortedSide::Left).unwrap();
        assert_eq!(outcomes, vec![SearchResult::NotFound(array.len())]);
    }
}
570}