1use std::ops::Not;
7use std::ops::Range;
8
9use vortex_buffer::Buffer;
10use vortex_error::vortex_panic;
11use vortex_mask::Mask;
12
13use crate::row_mask::RowMask;
14
15#[derive(Default, Clone, Debug)]
18pub enum Selection {
19 #[default]
21 All,
22 IncludeByIndex(Buffer<u64>),
24 ExcludeByIndex(Buffer<u64>),
26 IncludeRoaring(roaring::RoaringTreemap),
28 ExcludeRoaring(roaring::RoaringTreemap),
30}
31
32impl Selection {
33 pub fn row_count(&self, total_rows: u64) -> u64 {
35 match self {
36 Selection::All => total_rows,
37 Selection::IncludeByIndex(include) => include.len() as u64,
38 Selection::ExcludeByIndex(exclude) => total_rows.saturating_sub(exclude.len() as u64),
39 Selection::IncludeRoaring(roaring) => roaring.len(),
40 Selection::ExcludeRoaring(roaring) => total_rows.saturating_sub(roaring.len()),
41 }
42 }
43
44 pub fn row_mask(&self, range: &Range<u64>) -> RowMask {
46 let range_diff = range.end.saturating_sub(range.start);
48 let range_len = usize::try_from(range_diff).unwrap_or_else(|_| {
49 tracing::warn!(
52 "Range length {} exceeds usize::MAX, capping at usize::MAX",
53 range_diff
54 );
55 usize::MAX
56 });
57
58 match self {
59 Selection::All => RowMask::new(range.start, Mask::new_true(range_len)),
60 Selection::IncludeByIndex(include) => {
61 let mask = indices_range(range, include)
62 .map(|idx_range| {
63 Mask::from_indices(
64 range_len,
65 include
66 .slice(idx_range)
67 .iter()
68 .map(|idx| {
69 idx.checked_sub(range.start).unwrap_or_else(|| {
70 vortex_panic!(
71 "index underflow, range: {:?}, idx: {:?}",
72 range,
73 idx
74 )
75 })
76 })
77 .filter_map(|idx| {
78 usize::try_from(idx).ok()
80 })
81 .collect(),
82 )
83 })
84 .unwrap_or_else(|| Mask::new_false(range_len));
85
86 RowMask::new(range.start, mask)
87 }
88 Selection::ExcludeByIndex(exclude) => {
89 let mask = Selection::IncludeByIndex(exclude.clone())
90 .row_mask(range)
91 .mask()
92 .clone();
93 RowMask::new(range.start, mask.not())
94 }
95 Selection::IncludeRoaring(roaring) => {
96 use std::ops::BitAnd;
97
98 let mut range_treemap = roaring::RoaringTreemap::new();
100 range_treemap.insert_range(range.clone());
101
102 if roaring.is_disjoint(&range_treemap) {
103 return RowMask::new(range.start, Mask::new_false(range_len));
104 }
105
106 let roaring = roaring.bitand(range_treemap);
108 let mask = Mask::from_indices(
109 range_len,
110 roaring
111 .iter()
112 .map(|idx| {
113 idx.checked_sub(range.start).unwrap_or_else(|| {
114 vortex_panic!("index underflow, range: {:?}, idx: {:?}", range, idx)
115 })
116 })
117 .filter_map(|idx| {
118 usize::try_from(idx).ok()
120 })
121 .collect(),
122 );
123
124 RowMask::new(range.start, mask)
125 }
126 Selection::ExcludeRoaring(roaring) => {
127 use std::ops::BitAnd;
128
129 let mut range_treemap = roaring::RoaringTreemap::new();
130 range_treemap.insert_range(range.clone());
131
132 if roaring.intersection_len(&range_treemap) == range_len as u64 {
134 return RowMask::new(range.start, Mask::new_false(range_len));
135 }
136
137 let roaring = roaring.bitand(range_treemap);
139 let mask = Mask::from_excluded_indices(
140 range_len,
141 roaring
142 .iter()
143 .map(|idx| {
144 idx.checked_sub(range.start).unwrap_or_else(|| {
145 vortex_panic!("index underflow, range: {:?}, idx: {:?}", range, idx)
146 })
147 })
148 .filter_map(|idx| usize::try_from(idx).ok()),
149 );
150
151 RowMask::new(range.start, mask)
152 }
153 }
154 }
155}
156
/// Finds the positions in `row_indices` whose values fall inside `range`.
///
/// `row_indices` must be sorted in ascending order. The result is the half-open range of
/// positions `p` such that `range.start <= row_indices[p] < range.end`, or `None` when no
/// index lies inside `range` (including when `row_indices` is empty or `range` is empty or
/// inverted).
fn indices_range(range: &Range<u64>, row_indices: &[u64]) -> Option<Range<usize>> {
    // Cheap rejection: the whole (sorted) index list lies after or before the range.
    if row_indices.first().is_some_and(|&first| first >= range.end)
        || row_indices.last().is_some_and(|&last| range.start > last)
    {
        return None;
    }

    // `partition_point` always yields the first position at or above the bound, so repeated
    // values are handled deterministically (unlike `binary_search`, which may land on any of
    // several equal elements).
    let start_idx = row_indices.partition_point(|&idx| idx < range.start);
    let end_idx = row_indices.partition_point(|&idx| idx < range.end);

    // Strict `<` also rejects inverted ranges, which would otherwise produce a backwards
    // `Range<usize>` that cannot be used to slice.
    (start_idx < end_idx).then_some(start_idx..end_idx)
}
173
#[cfg(test)]
mod tests {
    use std::ops::Range;

    use vortex_buffer::Buffer;

    use super::Selection;

    /// Builds an `IncludeByIndex` selection from the given absolute row indices.
    fn include_rows(rows: Vec<u64>) -> Selection {
        Selection::IncludeByIndex(Buffer::from_iter(rows))
    }

    /// Evaluates `selection` over `range` and returns the range-relative positions that are set.
    ///
    /// Panics (through `unwrap`) when the resulting mask exposes no values.
    fn selected_positions(selection: &Selection, range: Range<u64>) -> Vec<usize> {
        let row_mask = selection.row_mask(&range);
        row_mask.mask().values().unwrap().indices().to_vec()
    }

    #[test]
    fn test_row_mask_all() {
        let selection = include_rows(vec![1, 3, 5, 7]);

        assert_eq!(selected_positions(&selection, 1..8), [0, 2, 4, 6]);
    }

    #[test]
    fn test_row_mask_slice() {
        let selection = include_rows(vec![1, 3, 5, 7]);

        assert_eq!(selected_positions(&selection, 3..6), [0, 2]);
    }

    #[test]
    fn test_row_mask_exclusive() {
        // The end of the range is exclusive, so row 5 must not be part of the mask.
        let selection = include_rows(vec![1, 3, 5, 7]);

        assert_eq!(selected_positions(&selection, 3..5), [0]);
    }

    #[test]
    fn test_row_mask_all_false() {
        let selection = include_rows(vec![1, 3, 5, 7]);

        assert!(selection.row_mask(&(8..10)).mask().all_false());
    }

    #[test]
    fn test_row_mask_all_true() {
        let selection = include_rows(vec![1, 3, 4, 5, 6]);

        assert!(selection.row_mask(&(3..7)).mask().all_true());
    }

    #[test]
    fn test_row_mask_zero() {
        let selection = include_rows(vec![0]);

        assert_eq!(selected_positions(&selection, 0..5), [0]);
    }

    mod roaring_tests {
        use roaring::RoaringTreemap;

        use super::*;

        /// Builds a treemap containing exactly the given rows.
        fn treemap_of(rows: impl IntoIterator<Item = u64>) -> RoaringTreemap {
            let mut treemap = RoaringTreemap::new();
            for row in rows {
                treemap.insert(row);
            }
            treemap
        }

        #[test]
        fn test_roaring_include_basic() {
            let selection = Selection::IncludeRoaring(treemap_of([1, 3, 5, 7]));

            assert_eq!(selected_positions(&selection, 1..8), [0, 2, 4, 6]);
        }

        #[test]
        fn test_roaring_include_slice() {
            let selection = Selection::IncludeRoaring(treemap_of([1, 3, 5, 7]));

            assert_eq!(selected_positions(&selection, 3..6), [0, 2]);
        }

        #[test]
        fn test_roaring_include_disjoint() {
            let selection = Selection::IncludeRoaring(treemap_of([1, 3, 5, 7]));

            assert!(selection.row_mask(&(8..10)).mask().all_false());
        }

        #[test]
        fn test_roaring_include_large_range() {
            // Every even row below one million is selected.
            let selection = Selection::IncludeRoaring(treemap_of((0..1000000).step_by(2)));

            // Half of the thousand rows in the window are even.
            assert_eq!(selection.row_mask(&(1000..2000)).mask().true_count(), 500);
        }

        #[test]
        fn test_roaring_exclude_basic() {
            let selection = Selection::ExcludeRoaring(treemap_of([1, 3, 5]));

            assert_eq!(selected_positions(&selection, 0..7), [0, 2, 4, 6]);
        }

        #[test]
        fn test_roaring_exclude_all() {
            // The excluded rows cover the whole requested range.
            let selection = Selection::ExcludeRoaring(treemap_of(10..20));

            assert!(selection.row_mask(&(10..20)).mask().all_false());
        }

        #[test]
        fn test_roaring_exclude_none() {
            // The excluded rows lie entirely outside the requested range.
            let selection = Selection::ExcludeRoaring(treemap_of([100, 101]));

            assert!(selection.row_mask(&(0..10)).mask().all_true());
        }

        #[test]
        fn test_roaring_exclude_partial() {
            // Row 15 is outside the range and must not influence the mask.
            let selection = Selection::ExcludeRoaring(treemap_of([5, 6, 7, 15]));

            assert_eq!(selected_positions(&selection, 5..10), [3, 4]);
        }

        #[test]
        fn test_roaring_include_empty() {
            let selection = Selection::IncludeRoaring(RoaringTreemap::new());

            assert!(selection.row_mask(&(0..100)).mask().all_false());
        }

        #[test]
        fn test_roaring_exclude_empty() {
            let selection = Selection::ExcludeRoaring(RoaringTreemap::new());

            assert!(selection.row_mask(&(0..100)).mask().all_true());
        }

        #[test]
        fn test_roaring_include_boundary() {
            // First and last row of the range.
            let selection = Selection::IncludeRoaring(treemap_of([0, 99]));

            assert_eq!(selected_positions(&selection, 0..100), [0, 99]);
        }

        #[test]
        fn test_roaring_include_range_insertion() {
            let mut treemap = RoaringTreemap::new();
            treemap.insert_range(10..20);
            treemap.insert_range(30..40);
            let selection = Selection::IncludeRoaring(treemap);

            // Rows 15..20 map to positions 0..5 and rows 30..35 map to positions 15..20.
            let expected: Vec<usize> = (0..5).chain(15..20).collect();
            assert_eq!(selected_positions(&selection, 15..35), expected);
        }

        #[test]
        fn test_roaring_overflow_protection() {
            let selection = Selection::IncludeRoaring(treemap_of([u64::MAX - 1, u64::MAX]));

            // `u64::MAX` itself is outside the half-open range, so only one row matches.
            let row_mask = selection.row_mask(&((u64::MAX - 10)..u64::MAX));
            assert_eq!(row_mask.mask().true_count(), 1);
        }

        #[test]
        fn test_roaring_exclude_overflow_protection() {
            let selection = Selection::ExcludeRoaring(treemap_of([u64::MAX - 1]));

            // Ten rows in the range, one of them excluded.
            let row_mask = selection.row_mask(&((u64::MAX - 10)..u64::MAX));
            assert_eq!(row_mask.mask().true_count(), 9);
        }

        #[test]
        fn test_roaring_include_vs_buffer_equivalence() {
            let rows: Vec<u64> = vec![1, 3, 5, 7, 9];

            let by_bitmap = Selection::IncludeRoaring(treemap_of(rows.iter().copied()));
            let by_index = include_rows(rows);

            assert_eq!(
                selected_positions(&by_index, 0..12),
                selected_positions(&by_bitmap, 0..12)
            );
        }

        #[test]
        fn test_roaring_exclude_vs_buffer_equivalence() {
            let rows: Vec<u64> = vec![2, 4, 6, 8];

            let by_bitmap = Selection::ExcludeRoaring(treemap_of(rows.iter().copied()));
            let by_index = Selection::ExcludeByIndex(Buffer::from_iter(rows));

            assert_eq!(
                selected_positions(&by_index, 0..10),
                selected_positions(&by_bitmap, 0..10)
            );
        }
    }
}