1use std::cmp::Ordering;
2
3use vortex_array::Array;
4use vortex_array::compute::{
5 SearchResult, SearchSortedFn, SearchSortedSide, SearchSortedUsizeFn, scalar_at,
6};
7use vortex_error::{VortexExpect, VortexResult};
8use vortex_scalar::Scalar;
9
10use crate::{SparseArray, SparseEncoding};
11
12impl SearchSortedFn<&SparseArray> for SparseEncoding {
13 fn search_sorted(
14 &self,
15 array: &SparseArray,
16 value: &Scalar,
17 side: SearchSortedSide,
18 ) -> VortexResult<SearchResult> {
19 let patches_result = array.patches().search_sorted(value.clone(), side)?;
21 match patches_result {
22 SearchResult::Found(i) => {
23 if value == array.fill_scalar() {
24 let fill_index = fill_position(array, side)?;
26 match side {
27 SearchSortedSide::Left => Ok(SearchResult::Found(i.min(fill_index))),
28 SearchSortedSide::Right => Ok(SearchResult::Found(i.max(fill_index))),
29 }
30 } else {
31 Ok(SearchResult::Found(i))
32 }
33 }
34 SearchResult::NotFound(i) => {
35 let fill_index = fill_position(array, side)?;
37
38 match value
40 .partial_cmp(array.fill_scalar())
41 .vortex_expect("value and fill scalar must have same dtype")
42 {
43 Ordering::Less => Ok(SearchResult::NotFound(i.min(fill_index))),
44 Ordering::Equal => match side {
45 SearchSortedSide::Left => Ok(SearchResult::Found(i.min(fill_index))),
46 SearchSortedSide::Right => Ok(SearchResult::Found(i.max(fill_index))),
47 },
48 Ordering::Greater => Ok(SearchResult::NotFound(i.max(fill_index))),
49 }
50 }
51 }
52 }
53}
54
55fn fill_position(array: &SparseArray, side: SearchSortedSide) -> VortexResult<usize> {
58 let fill_result = if array.fill_scalar().is_null() {
59 SearchResult::NotFound(array.patches().min_index()?)
61 } else {
62 array
63 .patches()
64 .search_sorted(array.fill_scalar().clone(), side)?
65 };
66 let fill_result_index = fill_result.to_index();
67 Ok(if fill_result_index <= array.patches().min_index()? {
69 0
71 } else if fill_result_index > array.patches().max_index()? {
72 array.len()
74 } else {
75 let fill_index = array.patches().search_index(fill_result_index)?.to_index();
77 match fill_result {
78 SearchResult::Found(_) => match side {
80 SearchSortedSide::Left => {
81 usize::try_from(&scalar_at(array.patches().indices(), fill_index - 1)?)? + 1
82 }
83 SearchSortedSide::Right => {
84 if fill_index < array.patches().num_patches() {
85 usize::try_from(&scalar_at(array.patches().indices(), fill_index)?)?
86 } else {
87 fill_result_index
88 }
89 }
90 },
91 SearchResult::NotFound(_) => {
94 if fill_index < array.patches().num_patches() {
95 usize::try_from(&scalar_at(array.patches().indices(), fill_index)?)?
96 } else {
97 fill_result_index
98 }
99 }
100 }
101 })
102}
103
104impl SearchSortedUsizeFn<&SparseArray> for SparseEncoding {
105 fn search_sorted_usize(
106 &self,
107 array: &SparseArray,
108 value: usize,
109 side: SearchSortedSide,
110 ) -> VortexResult<SearchResult> {
111 let Ok(target) = Scalar::from(value).cast(array.dtype()) else {
112 return Ok(SearchResult::NotFound(array.len()));
114 };
115 SearchSortedFn::search_sorted(self, array, &target, side)
116 }
117}
118
#[cfg(test)]
mod tests {
    use vortex_array::compute::conformance::search_sorted::rstest_reuse::apply;
    use vortex_array::compute::conformance::search_sorted::{search_sorted_conformance, *};
    use vortex_array::compute::{SearchResult, SearchSortedSide, search_sorted};
    use vortex_array::{Array, ArrayRef, IntoArray};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_error::VortexUnwrap;
    use vortex_scalar::Scalar;

    use crate::SparseArray;

    /// Runs the shared `search_sorted` conformance cases against a sparse re-encoding of each input.
    #[apply(search_sorted_conformance)]
    fn sparse_search_sorted(
        #[case] array: ArrayRef,
        #[case] value: i32,
        #[case] side: SearchSortedSide,
        #[case] expected: SearchResult,
    ) {
        let encoded = SparseArray::encode(&array, None).vortex_unwrap();
        assert_eq!(search_sorted(&encoded, value, side).unwrap(), expected);
    }

    /// Builds a length-20, non-nullable `i32` sparse array from the given patches and fill value.
    fn sparse_i32(indices: ArrayRef, values: ArrayRef, fill: i32) -> ArrayRef {
        SparseArray::try_new(
            indices,
            values,
            20,
            Scalar::primitive(fill, Nullability::NonNullable),
        )
        .unwrap()
        .into_array()
    }

    /// Trailing patches at 17..=19; the fill (33) equals the first patch value.
    fn high_fill_in_patches() -> ArrayRef {
        sparse_i32(
            buffer![17u64, 18, 19].into_array(),
            buffer![33_i32, 44, 55].into_array(),
            33,
        )
    }

    /// Leading patches at 0..=2; the fill (55) equals the last patch value.
    fn low_fill_in_patches() -> ArrayRef {
        sparse_i32(
            buffer![0u64, 1, 2].into_array(),
            buffer![33_i32, 44, 55].into_array(),
            55,
        )
    }

    /// Patches at both ends; the fill (22) equals the last leading patch value.
    fn low_high_fill_in_patches_low() -> ArrayRef {
        sparse_i32(
            buffer![0u64, 1, 17, 18, 19].into_array(),
            buffer![11i32, 22, 33, 44, 55].into_array(),
            22,
        )
    }

    /// Patches at both ends; the fill (33) equals the first trailing patch value.
    fn low_high_fill_in_patches_high() -> ArrayRef {
        sparse_i32(
            buffer![0u64, 1, 17, 18, 19].into_array(),
            buffer![11i32, 22, 33, 44, 55].into_array(),
            33,
        )
    }

    /// Searching for the fill value when it also appears among the patches.
    #[rstest]
    #[case(high_fill_in_patches(), 33, SearchSortedSide::Left, SearchResult::Found(0))]
    #[case(low_fill_in_patches(), 55, SearchSortedSide::Left, SearchResult::Found(2))]
    #[case(low_high_fill_in_patches_low(), 22, SearchSortedSide::Left, SearchResult::Found(1))]
    #[case(low_high_fill_in_patches_high(), 33, SearchSortedSide::Left, SearchResult::Found(2))]
    #[case(high_fill_in_patches(), 33, SearchSortedSide::Right, SearchResult::Found(18))]
    #[case(low_fill_in_patches(), 55, SearchSortedSide::Right, SearchResult::Found(20))]
    #[case(low_high_fill_in_patches_low(), 22, SearchSortedSide::Right, SearchResult::Found(17))]
    #[case(low_high_fill_in_patches_high(), 33, SearchSortedSide::Right, SearchResult::Found(18))]
    fn search_fill(
        #[case] array: ArrayRef,
        #[case] search: i32,
        #[case] side: SearchSortedSide,
        #[case] expected: SearchResult,
    ) {
        assert_eq!(search_sorted(&array, search, side).vortex_unwrap(), expected);
    }
}