Skip to main content

vortex_scan/
selection.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Defines a selection mask over a scan.
5
6use std::ops::Not;
7use std::ops::Range;
8
9use vortex_buffer::Buffer;
10use vortex_error::vortex_panic;
11use vortex_mask::Mask;
12
13use crate::row_mask::RowMask;
14
15/// A selection identifies a set of rows to include in the scan (in addition to applying any
16/// filter predicates).
17#[derive(Default, Clone, Debug)]
18pub enum Selection {
19    /// No selection, all rows are included.
20    #[default]
21    All,
22    /// A selection of sorted rows to include by index.
23    IncludeByIndex(Buffer<u64>),
24    /// A selection of sorted rows to exclude by index.
25    ExcludeByIndex(Buffer<u64>),
26    /// A selection of rows to include using a [`roaring::RoaringTreemap`].
27    IncludeRoaring(roaring::RoaringTreemap),
28    /// A selection of rows to exclude using a [`roaring::RoaringTreemap`].
29    ExcludeRoaring(roaring::RoaringTreemap),
30}
31
32impl Selection {
33    /// Return the row count for this selection.
34    pub fn row_count(&self, total_rows: u64) -> u64 {
35        match self {
36            Selection::All => total_rows,
37            Selection::IncludeByIndex(include) => include.len() as u64,
38            Selection::ExcludeByIndex(exclude) => total_rows.saturating_sub(exclude.len() as u64),
39            Selection::IncludeRoaring(roaring) => roaring.len(),
40            Selection::ExcludeRoaring(roaring) => total_rows.saturating_sub(roaring.len()),
41        }
42    }
43
44    /// Extract the [`RowMask`] for the given range from this selection.
45    pub fn row_mask(&self, range: &Range<u64>) -> RowMask {
46        // Saturating subtraction to prevent underflow, though range should be valid
47        let range_diff = range.end.saturating_sub(range.start);
48        let range_len = usize::try_from(range_diff).unwrap_or_else(|_| {
49            // If the range is too large for usize, cap it at usize::MAX
50            // This is a defensive measure; in practice, ranges should be reasonable
51            tracing::warn!(
52                "Range length {} exceeds usize::MAX, capping at usize::MAX",
53                range_diff
54            );
55            usize::MAX
56        });
57
58        match self {
59            Selection::All => RowMask::new(range.start, Mask::new_true(range_len)),
60            Selection::IncludeByIndex(include) => {
61                let mask = indices_range(range, include)
62                    .map(|idx_range| {
63                        Mask::from_indices(
64                            range_len,
65                            include
66                                .slice(idx_range)
67                                .iter()
68                                .map(|idx| {
69                                    idx.checked_sub(range.start).unwrap_or_else(|| {
70                                        vortex_panic!(
71                                            "index underflow, range: {:?}, idx: {:?}",
72                                            range,
73                                            idx
74                                        )
75                                    })
76                                })
77                                .filter_map(|idx| {
78                                    // Only include indices that fit in usize
79                                    usize::try_from(idx).ok()
80                                })
81                                .collect(),
82                        )
83                    })
84                    .unwrap_or_else(|| Mask::new_false(range_len));
85
86                RowMask::new(range.start, mask)
87            }
88            Selection::ExcludeByIndex(exclude) => {
89                let mask = Selection::IncludeByIndex(exclude.clone())
90                    .row_mask(range)
91                    .mask()
92                    .clone();
93                RowMask::new(range.start, mask.not())
94            }
95            Selection::IncludeRoaring(roaring) => {
96                use std::ops::BitAnd;
97
98                // First we perform a cheap is_disjoint check
99                let mut range_treemap = roaring::RoaringTreemap::new();
100                range_treemap.insert_range(range.clone());
101
102                if roaring.is_disjoint(&range_treemap) {
103                    return RowMask::new(range.start, Mask::new_false(range_len));
104                }
105
106                // Otherwise, intersect with the selected range and shift to relativize.
107                let roaring = roaring.bitand(range_treemap);
108                let mask = Mask::from_indices(
109                    range_len,
110                    roaring
111                        .iter()
112                        .map(|idx| {
113                            idx.checked_sub(range.start).unwrap_or_else(|| {
114                                vortex_panic!("index underflow, range: {:?}, idx: {:?}", range, idx)
115                            })
116                        })
117                        .filter_map(|idx| {
118                            // Only include indices that fit in usize
119                            usize::try_from(idx).ok()
120                        })
121                        .collect(),
122                );
123
124                RowMask::new(range.start, mask)
125            }
126            Selection::ExcludeRoaring(roaring) => {
127                use std::ops::BitAnd;
128
129                let mut range_treemap = roaring::RoaringTreemap::new();
130                range_treemap.insert_range(range.clone());
131
132                // If all indices in range are excluded, return all false mask
133                if roaring.intersection_len(&range_treemap) == range_len as u64 {
134                    return RowMask::new(range.start, Mask::new_false(range_len));
135                }
136
137                // Otherwise, intersect with the selected range and shift to relativize.
138                let roaring = roaring.bitand(range_treemap);
139                let mask = Mask::from_excluded_indices(
140                    range_len,
141                    roaring
142                        .iter()
143                        .map(|idx| {
144                            idx.checked_sub(range.start).unwrap_or_else(|| {
145                                vortex_panic!("index underflow, range: {:?}, idx: {:?}", range, idx)
146                            })
147                        })
148                        .filter_map(|idx| usize::try_from(idx).ok()),
149                );
150
151                RowMask::new(range.start, mask)
152            }
153        }
154    }
155}
156
/// Find the positional range within `row_indices` that covers all rows in the given range.
///
/// `row_indices` is expected to be sorted in ascending order. The result is a range of
/// *positions* into `row_indices` (not row numbers) whose entries all satisfy
/// `range.start <= row < range.end`.
///
/// Returns `None` when no entry of `row_indices` falls inside `range`. This includes an empty
/// `row_indices`, an empty `range`, and an inverted `range` (`range.start > range.end`), so the
/// returned positional range is always non-empty and safe to slice with.
fn indices_range(range: &Range<u64>, row_indices: &[u64]) -> Option<Range<usize>> {
    // Cheap rejection: every index lies at or after the end, or entirely before the start.
    if row_indices.first().is_some_and(|&first| first >= range.end)
        || row_indices.last().is_some_and(|&last| range.start > last)
    {
        return None;
    }

    // Lower bounds of `range.start` and `range.end` within the sorted indices. Unlike
    // `binary_search`, `partition_point` is deterministic even when an index is repeated.
    let start_idx = row_indices.partition_point(|&idx| idx < range.start);
    let end_idx = row_indices.partition_point(|&idx| idx < range.end);

    // `<` rather than `!=`: an inverted input range can produce `start_idx > end_idx`, which
    // must not be handed back as a (panicking) slice range.
    (start_idx < end_idx).then_some(start_idx..end_idx)
}
173
#[cfg(test)]
mod tests {
    use vortex_buffer::Buffer;

    use super::Selection;

    /// Builds an `IncludeByIndex` selection over the given (sorted) row indices.
    fn include_rows(rows: Vec<u64>) -> Selection {
        Selection::IncludeByIndex(Buffer::from_iter(rows))
    }

    #[test]
    fn test_row_mask_all() {
        // The range spans every selected index.
        let row_mask = include_rows(vec![1, 3, 5, 7]).row_mask(&(1..8));

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]);
    }

    #[test]
    fn test_row_mask_slice() {
        // Only 3 and 5 fall inside the range.
        let row_mask = include_rows(vec![1, 3, 5, 7]).row_mask(&(3..6));

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2]);
    }

    #[test]
    fn test_row_mask_exclusive() {
        // The range end is exclusive, so 5 is not selected.
        let row_mask = include_rows(vec![1, 3, 5, 7]).row_mask(&(3..5));

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0]);
    }

    #[test]
    fn test_row_mask_all_false() {
        // The range lies past every selected index.
        let row_mask = include_rows(vec![1, 3, 5, 7]).row_mask(&(8..10));

        assert!(row_mask.mask().all_false());
    }

    #[test]
    fn test_row_mask_all_true() {
        // Every row of the range is selected.
        let row_mask = include_rows(vec![1, 3, 4, 5, 6]).row_mask(&(3..7));

        assert!(row_mask.mask().all_true());
    }

    #[test]
    fn test_row_mask_zero() {
        // Row zero sits exactly on the start of the range.
        let row_mask = include_rows(vec![0]).row_mask(&(0..5));

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0]);
    }

    mod roaring_tests {
        use roaring::RoaringTreemap;

        use super::*;

        /// Builds a treemap by inserting each of the given row indices one by one.
        fn treemap_of(rows: impl IntoIterator<Item = u64>) -> RoaringTreemap {
            let mut treemap = RoaringTreemap::new();
            for row in rows {
                treemap.insert(row);
            }
            treemap
        }

        #[test]
        fn test_roaring_include_basic() {
            let selection = Selection::IncludeRoaring(treemap_of([1, 3, 5, 7]));
            let row_mask = selection.row_mask(&(1..8));

            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]);
        }

        #[test]
        fn test_roaring_include_slice() {
            let selection = Selection::IncludeRoaring(treemap_of([1, 3, 5, 7]));
            let row_mask = selection.row_mask(&(3..6));

            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2]);
        }

        #[test]
        fn test_roaring_include_disjoint() {
            let selection = Selection::IncludeRoaring(treemap_of([1, 3, 5, 7]));
            let row_mask = selection.row_mask(&(8..10));

            assert!(row_mask.mask().all_false());
        }

        #[test]
        fn test_roaring_include_large_range() {
            // Every even index below one million is selected.
            let selection = Selection::IncludeRoaring(treemap_of((0..1000000).step_by(2)));
            let row_mask = selection.row_mask(&(1000..2000));

            // Half of the 1000 rows in the range are even.
            assert_eq!(row_mask.mask().true_count(), 500);
        }

        #[test]
        fn test_roaring_exclude_basic() {
            let selection = Selection::ExcludeRoaring(treemap_of([1, 3, 5]));
            let row_mask = selection.row_mask(&(0..7));

            // Rows 1, 3 and 5 are dropped, leaving 0, 2, 4 and 6.
            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]);
        }

        #[test]
        fn test_roaring_exclude_all() {
            // The exclusion set covers the whole range.
            let selection = Selection::ExcludeRoaring(treemap_of(10..20));
            let row_mask = selection.row_mask(&(10..20));

            assert!(row_mask.mask().all_false());
        }

        #[test]
        fn test_roaring_exclude_none() {
            let selection = Selection::ExcludeRoaring(treemap_of([100, 101]));
            let row_mask = selection.row_mask(&(0..10));

            // None of the excluded rows fall inside the range.
            assert!(row_mask.mask().all_true());
        }

        #[test]
        fn test_roaring_exclude_partial() {
            // Row 15 lies outside the range and must have no effect.
            let selection = Selection::ExcludeRoaring(treemap_of([5, 6, 7, 15]));
            let row_mask = selection.row_mask(&(5..10));

            // Rows 5, 6, 7 (positions 0, 1, 2) are dropped; rows 8, 9 (positions 3, 4) remain.
            assert_eq!(row_mask.mask().values().unwrap().indices(), &[3, 4]);
        }

        #[test]
        fn test_roaring_include_empty() {
            let selection = Selection::IncludeRoaring(RoaringTreemap::new());
            let row_mask = selection.row_mask(&(0..100));

            assert!(row_mask.mask().all_false());
        }

        #[test]
        fn test_roaring_exclude_empty() {
            let selection = Selection::ExcludeRoaring(RoaringTreemap::new());
            let row_mask = selection.row_mask(&(0..100));

            assert!(row_mask.mask().all_true());
        }

        #[test]
        fn test_roaring_include_boundary() {
            // First and last row of the range.
            let selection = Selection::IncludeRoaring(treemap_of([0, 99]));
            let row_mask = selection.row_mask(&(0..100));

            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 99]);
        }

        #[test]
        fn test_roaring_include_range_insertion() {
            // Populate the treemap with whole ranges rather than single rows.
            let mut treemap = RoaringTreemap::new();
            treemap.insert_range(10..20);
            treemap.insert_range(30..40);

            let selection = Selection::IncludeRoaring(treemap);
            let row_mask = selection.row_mask(&(15..35));

            // Rows 15-19 map to positions 0-4 and rows 30-34 map to positions 15-19.
            let expected: Vec<usize> = (0..5).chain(15..20).collect();
            assert_eq!(row_mask.mask().values().unwrap().indices(), &expected);
        }

        #[test]
        fn test_roaring_overflow_protection() {
            // Indices at the very top of the u64 domain.
            let selection = Selection::IncludeRoaring(treemap_of([u64::MAX - 1, u64::MAX]));
            let row_mask = selection.row_mask(&((u64::MAX - 10)..u64::MAX));

            // The range end is exclusive, so only u64::MAX - 1 is inside it.
            assert_eq!(row_mask.mask().true_count(), 1);
        }

        #[test]
        fn test_roaring_exclude_overflow_protection() {
            let selection = Selection::ExcludeRoaring(treemap_of([u64::MAX - 1]));
            let row_mask = selection.row_mask(&((u64::MAX - 10)..u64::MAX));

            // Ten rows in the range, one of which (u64::MAX - 1) is excluded.
            assert_eq!(row_mask.mask().true_count(), 9);
        }

        #[test]
        fn test_roaring_include_vs_buffer_equivalence() {
            // The buffer-backed and treemap-backed include selections must agree.
            let rows: Vec<u64> = vec![1, 3, 5, 7, 9];

            let buffer_selection = include_rows(rows.clone());
            let roaring_selection = Selection::IncludeRoaring(treemap_of(rows));

            let range = 0..12;
            let from_buffer = buffer_selection.row_mask(&range);
            let from_roaring = roaring_selection.row_mask(&range);

            assert_eq!(
                from_buffer.mask().values().unwrap().indices(),
                from_roaring.mask().values().unwrap().indices()
            );
        }

        #[test]
        fn test_roaring_exclude_vs_buffer_equivalence() {
            // The buffer-backed and treemap-backed exclude selections must agree.
            let rows: Vec<u64> = vec![2, 4, 6, 8];

            let buffer_selection = Selection::ExcludeByIndex(Buffer::from_iter(rows.clone()));
            let roaring_selection = Selection::ExcludeRoaring(treemap_of(rows));

            let range = 0..10;
            let from_buffer = buffer_selection.row_mask(&range);
            let from_roaring = roaring_selection.row_mask(&range);

            assert_eq!(
                from_buffer.mask().values().unwrap().indices(),
                from_roaring.mask().values().unwrap().indices()
            );
        }
    }
}