1use std::ops::Not;
7use std::ops::Range;
8
9use vortex_buffer::Buffer;
10use vortex_error::vortex_panic;
11use vortex_mask::Mask;
12
13use crate::row_mask::RowMask;
14
15#[derive(Default, Clone, Debug)]
18pub enum Selection {
19 #[default]
21 All,
22 IncludeByIndex(Buffer<u64>),
24 ExcludeByIndex(Buffer<u64>),
26 IncludeRoaring(roaring::RoaringTreemap),
28 ExcludeRoaring(roaring::RoaringTreemap),
30}
31
32impl Selection {
33 pub fn row_count(&self, total_rows: u64) -> u64 {
35 match self {
36 Selection::All => total_rows,
37 Selection::IncludeByIndex(include) => include.len() as u64,
38 Selection::ExcludeByIndex(exclude) => total_rows.saturating_sub(exclude.len() as u64),
39 Selection::IncludeRoaring(roaring) => roaring.len(),
40 Selection::ExcludeRoaring(roaring) => total_rows.saturating_sub(roaring.len()),
41 }
42 }
43
44 pub fn row_mask(&self, range: &Range<u64>) -> RowMask {
46 if range.start >= range.end {
47 return RowMask::new(0, Mask::AllFalse(0));
48 }
49
50 let range_diff = range.end.saturating_sub(range.start);
52 let range_len = usize::try_from(range_diff).unwrap_or_else(|_| {
53 tracing::warn!(
56 "Range length {} exceeds usize::MAX, capping at usize::MAX",
57 range_diff
58 );
59 usize::MAX
60 });
61
62 match self {
63 Selection::All => RowMask::new(range.start, Mask::new_true(range_len)),
64 Selection::IncludeByIndex(include) => {
65 let mask = indices_range(range, include)
66 .map(|idx_range| {
67 Mask::from_indices(
68 range_len,
69 include
70 .slice(idx_range)
71 .iter()
72 .map(|idx| {
73 idx.checked_sub(range.start).unwrap_or_else(|| {
74 vortex_panic!(
75 "index underflow, range: {:?}, idx: {:?}",
76 range,
77 idx
78 )
79 })
80 })
81 .filter_map(|idx| {
82 usize::try_from(idx).ok()
84 }),
85 )
86 })
87 .unwrap_or_else(|| Mask::new_false(range_len));
88
89 RowMask::new(range.start, mask)
90 }
91 Selection::ExcludeByIndex(exclude) => {
92 let mask = Selection::IncludeByIndex(exclude.clone())
93 .row_mask(range)
94 .mask()
95 .clone();
96 RowMask::new(range.start, mask.not())
97 }
98 Selection::IncludeRoaring(roaring) => {
99 use std::ops::BitAnd;
100
101 let mut range_treemap = roaring::RoaringTreemap::new();
103 range_treemap.insert_range(range.clone());
104
105 if roaring.is_disjoint(&range_treemap) {
106 return RowMask::new(range.start, Mask::new_false(range_len));
107 }
108
109 let roaring = roaring.bitand(range_treemap);
111 let mask = Mask::from_indices(
112 range_len,
113 roaring
114 .iter()
115 .map(|idx| {
116 idx.checked_sub(range.start).unwrap_or_else(|| {
117 vortex_panic!("index underflow, range: {:?}, idx: {:?}", range, idx)
118 })
119 })
120 .filter_map(|idx| {
121 usize::try_from(idx).ok()
123 }),
124 );
125
126 RowMask::new(range.start, mask)
127 }
128 Selection::ExcludeRoaring(roaring) => {
129 use std::ops::BitAnd;
130
131 let mut range_treemap = roaring::RoaringTreemap::new();
132 range_treemap.insert_range(range.clone());
133
134 if roaring.intersection_len(&range_treemap) == range_len as u64 {
136 return RowMask::new(range.start, Mask::new_false(range_len));
137 }
138
139 let roaring = roaring.bitand(range_treemap);
141 let mask = Mask::from_excluded_indices(
142 range_len,
143 roaring
144 .iter()
145 .map(|idx| {
146 idx.checked_sub(range.start).unwrap_or_else(|| {
147 vortex_panic!("index underflow, range: {:?}, idx: {:?}", range, idx)
148 })
149 })
150 .filter_map(|idx| usize::try_from(idx).ok()),
151 );
152
153 RowMask::new(range.start, mask)
154 }
155 }
156 }
157}
158
/// Returns the span of positions in `row_indices` whose values fall inside the half-open
/// `range`, or `None` when no index lies in `range`.
///
/// `row_indices` must be sorted ascending. Both boundaries are located with
/// [`slice::partition_point`], so the span is exact even when `row_indices` contains duplicate
/// values. `binary_search` may return any of several equal elements, so it can split a run of
/// duplicates. An empty or inverted `range` yields `None`.
fn indices_range(range: &Range<u64>, row_indices: &[u64]) -> Option<Range<usize>> {
    // First position whose value is >= range.start.
    let start_idx = row_indices.partition_point(|&idx| idx < range.start);
    // First position whose value is >= range.end (the exclusive upper bound).
    let end_idx = row_indices.partition_point(|&idx| idx < range.end);

    // `<` rather than `!=` so an inverted range can never produce a reversed slice range.
    (start_idx < end_idx).then_some(start_idx..end_idx)
}
175
#[cfg(test)]
mod tests {
    use vortex_buffer::Buffer;

    use super::Selection;

    #[test]
    fn test_row_mask_all() {
        let selection = Selection::IncludeByIndex(Buffer::from_iter(vec![1, 3, 5, 7]));
        let range = 1..8;
        let row_mask = selection.row_mask(&range);

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]);
    }

    #[test]
    fn test_row_mask_slice() {
        let selection = Selection::IncludeByIndex(Buffer::from_iter(vec![1, 3, 5, 7]));
        let range = 3..6;
        let row_mask = selection.row_mask(&range);

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2]);
    }

    #[test]
    fn test_row_mask_exclusive() {
        let selection = Selection::IncludeByIndex(Buffer::from_iter(vec![1, 3, 5, 7]));
        let range = 3..5;
        let row_mask = selection.row_mask(&range);

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0]);
    }

    #[test]
    fn test_row_mask_all_false() {
        let selection = Selection::IncludeByIndex(Buffer::from_iter(vec![1, 3, 5, 7]));
        let range = 8..10;
        let row_mask = selection.row_mask(&range);

        assert!(row_mask.mask().all_false());
    }

    #[test]
    fn test_row_mask_all_true() {
        let selection = Selection::IncludeByIndex(Buffer::from_iter(vec![1, 3, 4, 5, 6]));
        let range = 3..7;
        let row_mask = selection.row_mask(&range);

        assert!(row_mask.mask().all_true());
    }

    #[test]
    fn test_row_mask_zero() {
        let selection = Selection::IncludeByIndex(Buffer::from_iter(vec![0]));
        let range = 0..5;
        let row_mask = selection.row_mask(&range);

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0]);
    }

    #[test]
    fn test_row_mask_empty_range() {
        let row_mask = Selection::All.row_mask(&(5..5));

        assert_eq!(row_mask.mask().true_count(), 0);
    }

    #[test]
    fn test_exclude_by_index_basic() {
        let selection = Selection::ExcludeByIndex(Buffer::from_iter(vec![1, 3, 5]));
        let row_mask = selection.row_mask(&(0..7));

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]);
    }

    #[test]
    fn test_exclude_by_index_disjoint() {
        let selection = Selection::ExcludeByIndex(Buffer::from_iter(vec![1, 3, 5]));
        let row_mask = selection.row_mask(&(10..20));

        assert!(row_mask.mask().all_true());
    }

    #[test]
    fn test_exclude_by_index_covers_range() {
        let selection = Selection::ExcludeByIndex(Buffer::from_iter(vec![2, 3, 4]));
        let row_mask = selection.row_mask(&(2..5));

        assert!(row_mask.mask().all_false());
    }

    #[test]
    fn test_row_count_by_index() {
        let include = Selection::IncludeByIndex(Buffer::from_iter(vec![1, 3, 5]));
        let exclude = Selection::ExcludeByIndex(Buffer::from_iter(vec![1, 3, 5]));

        assert_eq!(Selection::All.row_count(10), 10);
        assert_eq!(include.row_count(10), 3);
        assert_eq!(exclude.row_count(10), 7);
        // Exclusion saturates at zero instead of underflowing.
        assert_eq!(exclude.row_count(2), 0);
    }

    mod roaring_tests {
        use roaring::RoaringTreemap;

        use super::*;

        #[test]
        fn test_roaring_include_basic() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert(1);
            roaring.insert(3);
            roaring.insert(5);
            roaring.insert(7);

            let selection = Selection::IncludeRoaring(roaring);
            let range = 1..8;
            let row_mask = selection.row_mask(&range);

            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]);
        }

        #[test]
        fn test_roaring_include_slice() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert(1);
            roaring.insert(3);
            roaring.insert(5);
            roaring.insert(7);

            let selection = Selection::IncludeRoaring(roaring);
            let range = 3..6;
            let row_mask = selection.row_mask(&range);

            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2]);
        }

        #[test]
        fn test_roaring_include_disjoint() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert(1);
            roaring.insert(3);
            roaring.insert(5);
            roaring.insert(7);

            let selection = Selection::IncludeRoaring(roaring);
            let range = 8..10;
            let row_mask = selection.row_mask(&range);

            assert!(row_mask.mask().all_false());
        }

        #[test]
        fn test_roaring_include_large_range() {
            let mut roaring = RoaringTreemap::new();
            for i in (0..1000000).step_by(2) {
                roaring.insert(i);
            }

            let selection = Selection::IncludeRoaring(roaring);
            let range = 1000..2000;
            let row_mask = selection.row_mask(&range);

            assert_eq!(row_mask.mask().true_count(), 500);
        }

        #[test]
        fn test_roaring_exclude_basic() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert(1);
            roaring.insert(3);
            roaring.insert(5);

            let selection = Selection::ExcludeRoaring(roaring);
            let range = 0..7;
            let row_mask = selection.row_mask(&range);

            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]);
        }

        #[test]
        fn test_roaring_exclude_all() {
            let mut roaring = RoaringTreemap::new();
            for i in 10..20 {
                roaring.insert(i);
            }

            let selection = Selection::ExcludeRoaring(roaring);
            let range = 10..20;
            let row_mask = selection.row_mask(&range);

            assert!(row_mask.mask().all_false());
        }

        #[test]
        fn test_roaring_exclude_none() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert(100);
            roaring.insert(101);

            let selection = Selection::ExcludeRoaring(roaring);
            let range = 0..10;
            let row_mask = selection.row_mask(&range);

            assert!(row_mask.mask().all_true());
        }

        #[test]
        fn test_roaring_exclude_partial() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert(5);
            roaring.insert(6);
            roaring.insert(7);
            roaring.insert(15);

            let selection = Selection::ExcludeRoaring(roaring);
            let range = 5..10;
            let row_mask = selection.row_mask(&range);

            assert_eq!(row_mask.mask().values().unwrap().indices(), &[3, 4]);
        }

        #[test]
        fn test_roaring_include_empty() {
            let roaring = RoaringTreemap::new();
            let selection = Selection::IncludeRoaring(roaring);
            let range = 0..100;
            let row_mask = selection.row_mask(&range);

            assert!(row_mask.mask().all_false());
        }

        #[test]
        fn test_roaring_exclude_empty() {
            let roaring = RoaringTreemap::new();
            let selection = Selection::ExcludeRoaring(roaring);
            let range = 0..100;
            let row_mask = selection.row_mask(&range);

            assert!(row_mask.mask().all_true());
        }

        #[test]
        fn test_roaring_include_boundary() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert(0);
            roaring.insert(99);

            let selection = Selection::IncludeRoaring(roaring);
            let range = 0..100;
            let row_mask = selection.row_mask(&range);

            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 99]);
        }

        #[test]
        fn test_roaring_include_range_insertion() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert_range(10..20);
            roaring.insert_range(30..40);

            let selection = Selection::IncludeRoaring(roaring);
            let range = 15..35;
            let row_mask = selection.row_mask(&range);

            let expected: Vec<usize> = (0..5).chain(15..20).collect();
            assert_eq!(row_mask.mask().values().unwrap().indices(), &expected);
        }

        #[test]
        fn test_roaring_overflow_protection() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert(u64::MAX - 1);
            roaring.insert(u64::MAX);

            let selection = Selection::IncludeRoaring(roaring);
            let range = u64::MAX - 10..u64::MAX;
            let row_mask = selection.row_mask(&range);

            // u64::MAX itself lies outside the half-open range.
            assert_eq!(row_mask.mask().true_count(), 1);
        }

        #[test]
        fn test_roaring_exclude_overflow_protection() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert(u64::MAX - 1);

            let selection = Selection::ExcludeRoaring(roaring);
            let range = u64::MAX - 10..u64::MAX;
            let row_mask = selection.row_mask(&range);

            assert_eq!(row_mask.mask().true_count(), 9);
        }

        #[test]
        fn test_roaring_row_count() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert(1);
            roaring.insert(2);

            assert_eq!(Selection::IncludeRoaring(roaring.clone()).row_count(10), 2);
            assert_eq!(Selection::ExcludeRoaring(roaring).row_count(10), 8);
        }

        #[test]
        fn test_roaring_include_vs_buffer_equivalence() {
            let indices = vec![1, 3, 5, 7, 9];

            let buffer_selection = Selection::IncludeByIndex(Buffer::from_iter(indices.clone()));

            let mut roaring = RoaringTreemap::new();
            for idx in &indices {
                roaring.insert(*idx);
            }
            let roaring_selection = Selection::IncludeRoaring(roaring);

            let range = 0..12;
            let buffer_mask = buffer_selection.row_mask(&range);
            let roaring_mask = roaring_selection.row_mask(&range);

            assert_eq!(
                buffer_mask.mask().values().unwrap().indices(),
                roaring_mask.mask().values().unwrap().indices()
            );
        }

        #[test]
        fn test_roaring_exclude_vs_buffer_equivalence() {
            let indices = vec![2, 4, 6, 8];

            let buffer_selection = Selection::ExcludeByIndex(Buffer::from_iter(indices.clone()));

            let mut roaring = RoaringTreemap::new();
            for idx in &indices {
                roaring.insert(*idx);
            }
            let roaring_selection = Selection::ExcludeRoaring(roaring);

            let range = 0..10;
            let buffer_mask = buffer_selection.row_mask(&range);
            let roaring_mask = roaring_selection.row_mask(&range);

            assert_eq!(
                buffer_mask.mask().values().unwrap().indices(),
                roaring_mask.mask().values().unwrap().indices()
            );
        }
    }
}