1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
use std::cmp::Ordering;
use std::cmp::Ordering::Greater;

use vortex_dtype::{match_each_native_ptype, NativePType};
use vortex_error::VortexResult;
use vortex_scalar::Scalar;

use crate::array::primitive::PrimitiveArray;
use crate::compute::{IndexOrd, Len, SearchResult, SearchSorted, SearchSortedFn, SearchSortedSide};
use crate::validity::Validity;

impl SearchSortedFn for PrimitiveArray {
    /// Binary-searches this array for `value`, returning where it was found or
    /// where it would be inserted on the requested `side`.
    ///
    /// Null slots are ordered after every non-null value ("nulls last"). An array
    /// whose values are all invalid therefore always yields `NotFound(0)`, without
    /// converting `value`.
    ///
    /// # Errors
    /// Propagates the error from converting `value` into this array's native type.
    fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult<SearchResult> {
        match_each_native_ptype!(self.ptype(), |$T| {
            // Decide up front whether the null mask has to be consulted per element.
            let null_aware = match self.validity() {
                // Nothing valid to compare against: the front is the only candidate.
                Validity::AllInvalid => return Ok(SearchResult::NotFound(0)),
                Validity::NonNullable | Validity::AllValid => false,
                Validity::Array(_) => true,
            };

            let needle: $T = value.try_into()?;
            let position = if null_aware {
                SearchSortedNullsLast::<$T>::new(self).search_sorted(&needle, side)
            } else {
                self.maybe_null_slice::<$T>().search_sorted(&needle, side)
            };
            Ok(position)
        })
    }
}

/// Search adapter over a primitive value buffer plus its validity, in which every
/// null slot compares as greater than any value (nulls sort last).
struct SearchSortedNullsLast<'a, T> {
    /// Validity of each slot in `values`; consulted before any value comparison.
    validity: Validity,
    /// Raw values of the array; slots reported null by `validity` are never compared.
    values: &'a [T],
}

impl<'a, T: NativePType> SearchSortedNullsLast<'a, T> {
    /// Builds the adapter from `array`, borrowing its values as `T` and taking its
    /// validity.
    pub fn new(array: &'a PrimitiveArray) -> Self {
        let values = array.maybe_null_slice::<T>();
        let validity = array.validity();
        Self { values, validity }
    }
}

impl<T: NativePType> IndexOrd<T> for SearchSortedNullsLast<'_, T> {
    /// Orders slot `idx` relative to `elem`.
    ///
    /// A null slot is always `Greater`, so nulls sort after all values. Valid slots
    /// defer to the plain slice comparison.
    fn index_cmp(&self, idx: usize, elem: &T) -> Option<Ordering> {
        if self.validity.is_null(idx) {
            Some(Greater)
        } else {
            self.values.index_cmp(idx, elem)
        }
    }
}

impl<T> Len for SearchSortedNullsLast<'_, T> {
    /// Total number of slots being searched, valid and null alike.
    fn len(&self) -> usize {
        // Inherent slice length of the underlying value buffer.
        <[T]>::len(self.values)
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::compute::search_sorted;
    use crate::IntoArray;

    #[test]
    fn test_searchsorted_primitive() {
        let values = vec![1u16, 2, 3].into_array();

        // (needle, side, expected result), checked in order.
        let cases = [
            (0, SearchSortedSide::Left, SearchResult::NotFound(0)),
            (1, SearchSortedSide::Left, SearchResult::Found(0)),
            (1, SearchSortedSide::Right, SearchResult::Found(1)),
            (4, SearchSortedSide::Left, SearchResult::NotFound(3)),
        ];

        for (needle, side, expected) in cases {
            assert_eq!(
                search_sorted(&values, needle, side).unwrap(),
                expected,
                "needle {needle}"
            );
        }
    }
}