1use std::ops::Not;
5use std::ops::Range;
6
7use vortex_buffer::Buffer;
8use vortex_error::vortex_panic;
9use vortex_mask::Mask;
10
11use crate::row_mask::RowMask;
12
/// Describes which rows of a file or scan are selected, in terms of absolute `u64` row indices.
///
/// A selection is turned into a per-range row mask by `Selection::row_mask`, and
/// `Selection::row_count` reports how many rows it keeps.
#[derive(Default, Clone, Debug)]
pub enum Selection {
    /// Every row is selected. This is the default.
    #[default]
    All,
    /// Only the rows whose indices are listed in the buffer are selected.
    ///
    /// The buffer is searched with a binary search when a row mask is built, so the
    /// indices must be sorted in ascending order.
    IncludeByIndex(Buffer<u64>),
    /// Every row except those whose indices are listed in the buffer is selected.
    ///
    /// The mask is computed as the negation of the corresponding `IncludeByIndex` mask,
    /// so the same ordering requirement applies.
    ExcludeByIndex(Buffer<u64>),
    /// Only the rows whose indices are present in the roaring treemap are selected.
    IncludeRoaring(roaring::RoaringTreemap),
    /// Every row except those whose indices are present in the roaring treemap is selected.
    ExcludeRoaring(roaring::RoaringTreemap),
}
29
30impl Selection {
31 pub fn row_count(&self, total_rows: u64) -> u64 {
33 match self {
34 Selection::All => total_rows,
35 Selection::IncludeByIndex(include) => include.len() as u64,
36 Selection::ExcludeByIndex(exclude) => total_rows.saturating_sub(exclude.len() as u64),
37 Selection::IncludeRoaring(roaring) => roaring.len(),
38 Selection::ExcludeRoaring(roaring) => total_rows.saturating_sub(roaring.len()),
39 }
40 }
41
42 pub(crate) fn row_mask(&self, range: &Range<u64>) -> RowMask {
44 let range_diff = range.end.saturating_sub(range.start);
46 let range_len = usize::try_from(range_diff).unwrap_or_else(|_| {
47 tracing::warn!(
50 "Range length {} exceeds usize::MAX, capping at usize::MAX",
51 range_diff
52 );
53 usize::MAX
54 });
55
56 match self {
57 Selection::All => RowMask::new(range.start, Mask::new_true(range_len)),
58 Selection::IncludeByIndex(include) => {
59 let mask = indices_range(range, include)
60 .map(|idx_range| {
61 Mask::from_indices(
62 range_len,
63 include
64 .slice(idx_range)
65 .iter()
66 .map(|idx| {
67 idx.checked_sub(range.start).unwrap_or_else(|| {
68 vortex_panic!(
69 "index underflow, range: {:?}, idx: {:?}",
70 range,
71 idx
72 )
73 })
74 })
75 .filter_map(|idx| {
76 usize::try_from(idx).ok()
78 })
79 .collect(),
80 )
81 })
82 .unwrap_or_else(|| Mask::new_false(range_len));
83
84 RowMask::new(range.start, mask)
85 }
86 Selection::ExcludeByIndex(exclude) => {
87 let mask = Selection::IncludeByIndex(exclude.clone())
88 .row_mask(range)
89 .mask()
90 .clone();
91 RowMask::new(range.start, mask.not())
92 }
93 Selection::IncludeRoaring(roaring) => {
94 use std::ops::BitAnd;
95
96 let mut range_treemap = roaring::RoaringTreemap::new();
98 range_treemap.insert_range(range.clone());
99
100 if roaring.is_disjoint(&range_treemap) {
101 return RowMask::new(range.start, Mask::new_false(range_len));
102 }
103
104 let roaring = roaring.bitand(range_treemap);
106 let mask = Mask::from_indices(
107 range_len,
108 roaring
109 .iter()
110 .map(|idx| {
111 idx.checked_sub(range.start).unwrap_or_else(|| {
112 vortex_panic!("index underflow, range: {:?}, idx: {:?}", range, idx)
113 })
114 })
115 .filter_map(|idx| {
116 usize::try_from(idx).ok()
118 })
119 .collect(),
120 );
121
122 RowMask::new(range.start, mask)
123 }
124 Selection::ExcludeRoaring(roaring) => {
125 use std::ops::BitAnd;
126
127 let mut range_treemap = roaring::RoaringTreemap::new();
128 range_treemap.insert_range(range.clone());
129
130 if roaring.intersection_len(&range_treemap) == range_len as u64 {
132 return RowMask::new(range.start, Mask::new_false(range_len));
133 }
134
135 let roaring = roaring.bitand(range_treemap);
137 let mask = Mask::from_excluded_indices(
138 range_len,
139 roaring
140 .iter()
141 .map(|idx| {
142 idx.checked_sub(range.start).unwrap_or_else(|| {
143 vortex_panic!("index underflow, range: {:?}, idx: {:?}", range, idx)
144 })
145 })
146 .filter_map(|idx| usize::try_from(idx).ok()),
147 );
148
149 RowMask::new(range.start, mask)
150 }
151 }
152 }
153}
154
/// Locates the positions in `row_indices` whose values fall inside `range`.
///
/// `row_indices` must be sorted in ascending order. Duplicate values are tolerated:
/// the returned bounds always enclose every element `v` with
/// `range.start <= v < range.end` and nothing else.
///
/// Returns `None` when no element lies inside `range`. This includes an empty
/// `row_indices`, an empty range, and an inverted range (`range.start > range.end`),
/// so callers never receive an inverted position range that would fail when used
/// for slicing.
fn indices_range(range: &Range<u64>, row_indices: &[u64]) -> Option<Range<usize>> {
    // Constant-time rejection when the whole slice lies on one side of the range.
    let starts_after = row_indices.first().is_some_and(|&first| first >= range.end);
    let ends_before = row_indices.last().is_some_and(|&last| last < range.start);
    if starts_after || ends_before {
        return None;
    }

    // `partition_point` yields the first position not below the bound, which is
    // well-defined even when the bound value occurs several times in the slice.
    let start_idx = row_indices.partition_point(|&idx| idx < range.start);
    let end_idx = row_indices.partition_point(|&idx| idx < range.end);

    // `<` rather than `!=`: an inverted range gives `start_idx > end_idx`.
    (start_idx < end_idx).then_some(start_idx..end_idx)
}
171
#[cfg(test)]
mod tests {
    use vortex_buffer::Buffer;

    use super::Selection;

    /// Builds an `IncludeByIndex` selection over the given row indices.
    fn include_by_index(indices: Vec<u64>) -> Selection {
        Selection::IncludeByIndex(Buffer::from_iter(indices))
    }

    // The range covers every selected index.
    #[test]
    fn test_row_mask_all() {
        let row_mask = include_by_index(vec![1, 3, 5, 7]).row_mask(&(1..8));

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]);
    }

    // The range cuts through the middle of the selected indices.
    #[test]
    fn test_row_mask_slice() {
        let row_mask = include_by_index(vec![1, 3, 5, 7]).row_mask(&(3..6));

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2]);
    }

    // The end of the range is exclusive, so index 5 is not part of `3..5`.
    #[test]
    fn test_row_mask_exclusive() {
        let row_mask = include_by_index(vec![1, 3, 5, 7]).row_mask(&(3..5));

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0]);
    }

    // The range lies entirely past the selected indices.
    #[test]
    fn test_row_mask_all_false() {
        let row_mask = include_by_index(vec![1, 3, 5, 7]).row_mask(&(8..10));

        assert!(row_mask.mask().all_false());
    }

    // Every row of the range is selected.
    #[test]
    fn test_row_mask_all_true() {
        let row_mask = include_by_index(vec![1, 3, 4, 5, 6]).row_mask(&(3..7));

        assert!(row_mask.mask().all_true());
    }

    // Row zero at the very start of the range.
    #[test]
    fn test_row_mask_zero() {
        let row_mask = include_by_index(vec![0]).row_mask(&(0..5));

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0]);
    }

    mod roaring_tests {
        use roaring::RoaringTreemap;

        use super::*;

        /// Builds a treemap holding exactly the given row indices.
        fn treemap_of(rows: impl IntoIterator<Item = u64>) -> RoaringTreemap {
            let mut treemap = RoaringTreemap::new();
            for row in rows {
                treemap.insert(row);
            }
            treemap
        }

        #[test]
        fn test_roaring_include_basic() {
            let selection = Selection::IncludeRoaring(treemap_of([1, 3, 5, 7]));
            let row_mask = selection.row_mask(&(1..8));

            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]);
        }

        #[test]
        fn test_roaring_include_slice() {
            let selection = Selection::IncludeRoaring(treemap_of([1, 3, 5, 7]));
            let row_mask = selection.row_mask(&(3..6));

            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2]);
        }

        #[test]
        fn test_roaring_include_disjoint() {
            let selection = Selection::IncludeRoaring(treemap_of([1, 3, 5, 7]));
            let row_mask = selection.row_mask(&(8..10));

            assert!(row_mask.mask().all_false());
        }

        #[test]
        fn test_roaring_include_large_range() {
            // Every even row below one million is selected.
            let selection = Selection::IncludeRoaring(treemap_of((0..1000000).step_by(2)));
            let row_mask = selection.row_mask(&(1000..2000));

            // Half of the 1000 rows in the range are even.
            assert_eq!(row_mask.mask().true_count(), 500);
        }

        #[test]
        fn test_roaring_exclude_basic() {
            let selection = Selection::ExcludeRoaring(treemap_of([1, 3, 5]));
            let row_mask = selection.row_mask(&(0..7));

            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]);
        }

        #[test]
        fn test_roaring_exclude_all() {
            // The exclusion set is exactly the queried range.
            let selection = Selection::ExcludeRoaring(treemap_of(10..20));
            let row_mask = selection.row_mask(&(10..20));

            assert!(row_mask.mask().all_false());
        }

        #[test]
        fn test_roaring_exclude_none() {
            // The excluded rows lie outside the queried range.
            let selection = Selection::ExcludeRoaring(treemap_of([100, 101]));
            let row_mask = selection.row_mask(&(0..10));

            assert!(row_mask.mask().all_true());
        }

        #[test]
        fn test_roaring_exclude_partial() {
            // Row 15 is outside the range and must not affect the mask.
            let selection = Selection::ExcludeRoaring(treemap_of([5, 6, 7, 15]));
            let row_mask = selection.row_mask(&(5..10));

            assert_eq!(row_mask.mask().values().unwrap().indices(), &[3, 4]);
        }

        #[test]
        fn test_roaring_include_empty() {
            let selection = Selection::IncludeRoaring(RoaringTreemap::new());
            let row_mask = selection.row_mask(&(0..100));

            assert!(row_mask.mask().all_false());
        }

        #[test]
        fn test_roaring_exclude_empty() {
            let selection = Selection::ExcludeRoaring(RoaringTreemap::new());
            let row_mask = selection.row_mask(&(0..100));

            assert!(row_mask.mask().all_true());
        }

        #[test]
        fn test_roaring_include_boundary() {
            // First and last rows of the range.
            let selection = Selection::IncludeRoaring(treemap_of([0, 99]));
            let row_mask = selection.row_mask(&(0..100));

            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 99]);
        }

        #[test]
        fn test_roaring_include_range_insertion() {
            let mut treemap = RoaringTreemap::new();
            treemap.insert_range(10..20);
            treemap.insert_range(30..40);

            let selection = Selection::IncludeRoaring(treemap);
            let row_mask = selection.row_mask(&(15..35));

            // Rows 15..20 map to offsets 0..5 and rows 30..35 map to offsets 15..20.
            let expected: Vec<usize> = (0..5).chain(15..20).collect();
            assert_eq!(row_mask.mask().values().unwrap().indices(), &expected);
        }

        #[test]
        fn test_roaring_overflow_protection() {
            let selection = Selection::IncludeRoaring(treemap_of([u64::MAX - 1, u64::MAX]));
            let row_mask = selection.row_mask(&((u64::MAX - 10)..u64::MAX));

            // `u64::MAX` itself is outside the half-open range, so only one row matches.
            assert_eq!(row_mask.mask().true_count(), 1);
        }

        #[test]
        fn test_roaring_exclude_overflow_protection() {
            let selection = Selection::ExcludeRoaring(treemap_of([u64::MAX - 1]));
            let row_mask = selection.row_mask(&((u64::MAX - 10)..u64::MAX));

            // One of the ten rows in the range is excluded.
            assert_eq!(row_mask.mask().true_count(), 9);
        }

        #[test]
        fn test_roaring_include_vs_buffer_equivalence() {
            let indices: Vec<u64> = vec![1, 3, 5, 7, 9];

            let by_buffer = Selection::IncludeByIndex(Buffer::from_iter(indices.clone()));
            let by_roaring = Selection::IncludeRoaring(treemap_of(indices.iter().copied()));

            let range = 0..12;
            let buffer_mask = by_buffer.row_mask(&range);
            let roaring_mask = by_roaring.row_mask(&range);

            assert_eq!(
                buffer_mask.mask().values().unwrap().indices(),
                roaring_mask.mask().values().unwrap().indices()
            );
        }

        #[test]
        fn test_roaring_exclude_vs_buffer_equivalence() {
            let indices: Vec<u64> = vec![2, 4, 6, 8];

            let by_buffer = Selection::ExcludeByIndex(Buffer::from_iter(indices.clone()));
            let by_roaring = Selection::ExcludeRoaring(treemap_of(indices.iter().copied()));

            let range = 0..10;
            let buffer_mask = by_buffer.row_mask(&range);
            let roaring_mask = by_roaring.row_mask(&range);

            assert_eq!(
                buffer_mask.mask().values().unwrap().indices(),
                roaring_mask.mask().values().unwrap().indices()
            );
        }
    }
}