// vortex_sparse/compute/search_sorted.rs

1use std::cmp::Ordering;
2
3use vortex_array::Array;
4use vortex_array::compute::{
5    SearchResult, SearchSortedFn, SearchSortedSide, SearchSortedUsizeFn, scalar_at,
6};
7use vortex_error::{VortexExpect, VortexResult};
8use vortex_scalar::Scalar;
9
10use crate::{SparseArray, SparseEncoding};
11
12impl SearchSortedFn<&SparseArray> for SparseEncoding {
13    fn search_sorted(
14        &self,
15        array: &SparseArray,
16        value: &Scalar,
17        side: SearchSortedSide,
18    ) -> VortexResult<SearchResult> {
19        // first search result in patches
20        let patches_result = array.patches().search_sorted(value.clone(), side)?;
21        match patches_result {
22            SearchResult::Found(i) => {
23                if value == array.fill_scalar() {
24                    // Find the relevant position of the fill value in the patches
25                    let fill_index = fill_position(array, side)?;
26                    match side {
27                        SearchSortedSide::Left => Ok(SearchResult::Found(i.min(fill_index))),
28                        SearchSortedSide::Right => Ok(SearchResult::Found(i.max(fill_index))),
29                    }
30                } else {
31                    Ok(SearchResult::Found(i))
32                }
33            }
34            SearchResult::NotFound(i) => {
35                // Find the relevant position of the fill value in the patches
36                let fill_index = fill_position(array, side)?;
37
38                // Adjust the position of the search value relative to the position of the fill value
39                match value
40                    .partial_cmp(array.fill_scalar())
41                    .vortex_expect("value and fill scalar must have same dtype")
42                {
43                    Ordering::Less => Ok(SearchResult::NotFound(i.min(fill_index))),
44                    Ordering::Equal => match side {
45                        SearchSortedSide::Left => Ok(SearchResult::Found(i.min(fill_index))),
46                        SearchSortedSide::Right => Ok(SearchResult::Found(i.max(fill_index))),
47                    },
48                    Ordering::Greater => Ok(SearchResult::NotFound(i.max(fill_index))),
49                }
50            }
51        }
52    }
53}
54
55// Find the fill position relative to patches, in case of fill being in between patches we want to find the right most
56// index of the fill relative to patches.
57fn fill_position(array: &SparseArray, side: SearchSortedSide) -> VortexResult<usize> {
58    let fill_result = if array.fill_scalar().is_null() {
59        // For null fill the patches can only ever be after the fill
60        SearchResult::NotFound(array.patches().min_index()?)
61    } else {
62        array
63            .patches()
64            .search_sorted(array.fill_scalar().clone(), side)?
65    };
66    let fill_result_index = fill_result.to_index();
67    // Find the relevant position of the fill value in the patches
68    Ok(if fill_result_index <= array.patches().min_index()? {
69        // [fill, ..., patch]
70        0
71    } else if fill_result_index > array.patches().max_index()? {
72        // [patch, ..., fill]
73        array.len()
74    } else {
75        // [patch, fill, ..., fill, patch]
76        let fill_index = array.patches().search_index(fill_result_index)?.to_index();
77        match fill_result {
78            // If fill value is present in patches this would be the index of the next or previous value after the fill value depending on the side
79            SearchResult::Found(_) => match side {
80                SearchSortedSide::Left => {
81                    usize::try_from(&scalar_at(array.patches().indices(), fill_index - 1)?)? + 1
82                }
83                SearchSortedSide::Right => {
84                    if fill_index < array.patches().num_patches() {
85                        usize::try_from(&scalar_at(array.patches().indices(), fill_index)?)?
86                    } else {
87                        fill_result_index
88                    }
89                }
90            },
91            // If the fill value is not in patches but falls in between two patch values we want to take the right most index of that will match the fill value
92            // This will then be min/maxed with result of searching for value in patches
93            SearchResult::NotFound(_) => {
94                if fill_index < array.patches().num_patches() {
95                    usize::try_from(&scalar_at(array.patches().indices(), fill_index)?)?
96                } else {
97                    fill_result_index
98                }
99            }
100        }
101    })
102}
103
104impl SearchSortedUsizeFn<&SparseArray> for SparseEncoding {
105    fn search_sorted_usize(
106        &self,
107        array: &SparseArray,
108        value: usize,
109        side: SearchSortedSide,
110    ) -> VortexResult<SearchResult> {
111        let Ok(target) = Scalar::from(value).cast(array.dtype()) else {
112            // If the downcast fails, then the target is too large for the dtype.
113            return Ok(SearchResult::NotFound(array.len()));
114        };
115        SearchSortedFn::search_sorted(self, array, &target, side)
116    }
117}
118
#[cfg(test)]
mod tests {
    use vortex_array::compute::conformance::search_sorted::rstest_reuse::apply;
    use vortex_array::compute::conformance::search_sorted::{search_sorted_conformance, *};
    use vortex_array::compute::{SearchResult, SearchSortedSide, search_sorted};
    use vortex_array::{Array, ArrayRef, IntoArray};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_error::VortexUnwrap;
    use vortex_scalar::Scalar;

    use crate::SparseArray;

    /// Logical length shared by every hand-built fixture in this module.
    const FIXTURE_LEN: usize = 20;

    /// Re-runs every shared conformance case against the sparse encoding of its input.
    #[apply(search_sorted_conformance)]
    fn sparse_search_sorted(
        #[case] array: ArrayRef,
        #[case] value: i32,
        #[case] side: SearchSortedSide,
        #[case] expected: SearchResult,
    ) {
        let encoded = SparseArray::encode(&array, None).vortex_unwrap();
        let actual = search_sorted(&encoded, value, side).unwrap();
        assert_eq!(actual, expected);
    }

    /// Builds a non-nullable `i32` sparse array of length [`FIXTURE_LEN`] from the given
    /// patch indices/values and fill value.
    fn sparse_fixture(indices: ArrayRef, values: ArrayRef, fill: i32) -> ArrayRef {
        SparseArray::try_new(
            indices,
            values,
            FIXTURE_LEN,
            Scalar::primitive(fill, Nullability::NonNullable),
        )
        .unwrap()
        .into_array()
    }

    /// Logical contents: `[33 ×18, 44, 55]` (fill 33 also appears as a patch).
    fn high_fill_in_patches() -> ArrayRef {
        sparse_fixture(
            buffer![17u64, 18, 19].into_array(),
            buffer![33_i32, 44, 55].into_array(),
            33,
        )
    }

    /// Logical contents: `[33, 44, 55 ×18]` (fill 55 also appears as a patch).
    fn low_fill_in_patches() -> ArrayRef {
        sparse_fixture(
            buffer![0u64, 1, 2].into_array(),
            buffer![33_i32, 44, 55].into_array(),
            55,
        )
    }

    /// Logical contents: `[11, 22 ×16, 33, 44, 55]` (fill 22 also appears as a patch).
    fn low_high_fill_in_patches_low() -> ArrayRef {
        sparse_fixture(
            buffer![0u64, 1, 17, 18, 19].into_array(),
            buffer![11i32, 22, 33, 44, 55].into_array(),
            22,
        )
    }

    /// Logical contents: `[11, 22, 33 ×16, 44, 55]` (fill 33 also appears as a patch).
    fn low_high_fill_in_patches_high() -> ArrayRef {
        sparse_fixture(
            buffer![0u64, 1, 17, 18, 19].into_array(),
            buffer![11i32, 22, 33, 44, 55].into_array(),
            33,
        )
    }

    /// Searching for the fill value must span the whole fill run, including patched copies.
    #[rstest]
    #[case(high_fill_in_patches(), 33, SearchSortedSide::Left, SearchResult::Found(0))]
    #[case(low_fill_in_patches(), 55, SearchSortedSide::Left, SearchResult::Found(2))]
    #[case(low_high_fill_in_patches_low(), 22, SearchSortedSide::Left, SearchResult::Found(1))]
    #[case(low_high_fill_in_patches_high(), 33, SearchSortedSide::Left, SearchResult::Found(2))]
    #[case(high_fill_in_patches(), 33, SearchSortedSide::Right, SearchResult::Found(18))]
    #[case(low_fill_in_patches(), 55, SearchSortedSide::Right, SearchResult::Found(20))]
    #[case(low_high_fill_in_patches_low(), 22, SearchSortedSide::Right, SearchResult::Found(17))]
    #[case(low_high_fill_in_patches_high(), 33, SearchSortedSide::Right, SearchResult::Found(18))]
    fn search_fill(
        #[case] array: ArrayRef,
        #[case] search: i32,
        #[case] side: SearchSortedSide,
        #[case] expected: SearchResult,
    ) {
        let actual = search_sorted(&array, search, side).vortex_unwrap();
        assert_eq!(actual, expected);
    }
}