vortex_scan/
selection.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::{Not, Range};
5
6use vortex_buffer::Buffer;
7use vortex_error::vortex_panic;
8use vortex_mask::Mask;
9
10use crate::row_mask::RowMask;
11
12/// A selection identifies a set of rows to include in the scan (in addition to applying any
13/// filter predicates).
14#[derive(Default, Clone)]
15pub enum Selection {
16    /// No selection, all rows are included.
17    #[default]
18    All,
19    /// A selection of rows to include by index.
20    IncludeByIndex(Buffer<u64>),
21    /// A selection of rows to exclude by index.
22    ExcludeByIndex(Buffer<u64>),
23    /// A selection of rows to include using a [`roaring::RoaringTreemap`].
24    #[cfg(feature = "roaring")]
25    IncludeRoaring(roaring::RoaringTreemap),
26    /// A selection of rows to exclude using a [`roaring::RoaringTreemap`].
27    #[cfg(feature = "roaring")]
28    ExcludeRoaring(roaring::RoaringTreemap),
29}
30
31impl Selection {
32    /// Extract the [`RowMask`] for the given range from this selection.
33    pub(crate) fn row_mask(&self, range: &Range<u64>) -> RowMask {
34        // Saturating subtraction to prevent underflow, though range should be valid
35        let range_diff = range.end.saturating_sub(range.start);
36        let range_len = usize::try_from(range_diff).unwrap_or_else(|_| {
37            // If the range is too large for usize, cap it at usize::MAX
38            // This is a defensive measure; in practice, ranges should be reasonable
39            log::warn!(
40                "Range length {} exceeds usize::MAX, capping at usize::MAX",
41                range_diff
42            );
43            usize::MAX
44        });
45
46        match self {
47            Selection::All => RowMask::new(range.start, Mask::new_true(range_len)),
48            Selection::IncludeByIndex(include) => {
49                let mask = indices_range(range, include)
50                    .map(|idx_range| {
51                        Mask::from_indices(
52                            range_len,
53                            include
54                                .slice(idx_range)
55                                .iter()
56                                .map(|idx| {
57                                    idx.checked_sub(range.start).unwrap_or_else(|| {
58                                        vortex_panic!(
59                                            "index underflow, range: {:?}, idx: {:?}",
60                                            range,
61                                            idx
62                                        )
63                                    })
64                                })
65                                .filter_map(|idx| {
66                                    // Only include indices that fit in usize
67                                    usize::try_from(idx).ok()
68                                })
69                                .collect(),
70                        )
71                    })
72                    .unwrap_or_else(|| Mask::new_false(range_len));
73
74                RowMask::new(range.start, mask)
75            }
76            Selection::ExcludeByIndex(exclude) => {
77                let mask = Selection::IncludeByIndex(exclude.clone())
78                    .row_mask(range)
79                    .mask()
80                    .clone();
81                RowMask::new(range.start, mask.not())
82            }
83            #[cfg(feature = "roaring")]
84            Selection::IncludeRoaring(roaring) => {
85                use std::ops::BitAnd;
86
87                // First we perform a cheap is_disjoint check
88                let mut range_treemap = roaring::RoaringTreemap::new();
89                range_treemap.insert_range(range.clone());
90
91                if roaring.is_disjoint(&range_treemap) {
92                    return RowMask::new(range.start, Mask::new_false(range_len));
93                }
94
95                // Otherwise, intersect with the selected range and shift to relativize.
96                let roaring = roaring.bitand(range_treemap);
97                let mask = Mask::from_indices(
98                    range_len,
99                    roaring
100                        .iter()
101                        .map(|idx| {
102                            idx.checked_sub(range.start).unwrap_or_else(|| {
103                                vortex_panic!("index underflow, range: {:?}, idx: {:?}", range, idx)
104                            })
105                        })
106                        .filter_map(|idx| {
107                            // Only include indices that fit in usize
108                            usize::try_from(idx).ok()
109                        })
110                        .collect(),
111                );
112
113                RowMask::new(range.start, mask)
114            }
115            #[cfg(feature = "roaring")]
116            Selection::ExcludeRoaring(roaring) => {
117                use std::ops::BitAnd;
118
119                let mut range_treemap = roaring::RoaringTreemap::new();
120                range_treemap.insert_range(range.clone());
121
122                // If all indices in range are excluded, return all false mask
123                if roaring.intersection_len(&range_treemap) == range_len as u64 {
124                    return RowMask::new(range.start, Mask::new_false(range_len));
125                }
126
127                // Otherwise, intersect with the selected range and shift to relativize.
128                let roaring = roaring.bitand(range_treemap);
129                let mask = Mask::from_excluded_indices(
130                    range_len,
131                    roaring
132                        .iter()
133                        .map(|idx| {
134                            idx.checked_sub(range.start).unwrap_or_else(|| {
135                                vortex_panic!("index underflow, range: {:?}, idx: {:?}", range, idx)
136                            })
137                        })
138                        .filter_map(|idx| usize::try_from(idx).ok()),
139                );
140
141                RowMask::new(range.start, mask)
142            }
143        }
144    }
145}
146
/// Find the positional range within `row_indices` that covers all rows in the given range.
///
/// `row_indices` must be sorted in ascending order. The returned positions select exactly the
/// entries `x` with `range.start <= x < range.end`, including every duplicate of a boundary
/// value. Returns `None` when no entry falls inside `range`, which also covers empty and
/// reversed (end < start) ranges.
fn indices_range(range: &Range<u64>, row_indices: &[u64]) -> Option<Range<usize>> {
    // Fast path: the whole index list lies entirely before or after the range.
    if row_indices.first().is_some_and(|&first| first >= range.end)
        || row_indices.last().is_some_and(|&last| last < range.start)
    {
        return None;
    }

    // `partition_point` returns the first position whose value is NOT below the bound. Unlike
    // `binary_search`, whose match among duplicates is unspecified, this is exact: copies of
    // `range.start` are never skipped and copies of `range.end` are never included (which would
    // otherwise relativize to an out-of-bounds offset of `range.end - range.start`).
    let start_idx = row_indices.partition_point(|&idx| idx < range.start);
    let end_idx = row_indices.partition_point(|&idx| idx < range.end);

    // `<` rather than `!=` so a reversed range never yields a descending (panicking) slice range.
    (start_idx < end_idx).then_some(start_idx..end_idx)
}
163
#[cfg(test)]
mod tests {
    use vortex_buffer::Buffer;

    #[test]
    fn test_row_mask_all() {
        let selection = super::Selection::IncludeByIndex(Buffer::from_iter(vec![1, 3, 5, 7]));
        let range = 1..8;
        let row_mask = selection.row_mask(&range);

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]);
    }

    #[test]
    fn test_row_mask_slice() {
        let selection = super::Selection::IncludeByIndex(Buffer::from_iter(vec![1, 3, 5, 7]));
        let range = 3..6;
        let row_mask = selection.row_mask(&range);

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2]);
    }

    #[test]
    fn test_row_mask_exclusive() {
        let selection = super::Selection::IncludeByIndex(Buffer::from_iter(vec![1, 3, 5, 7]));
        let range = 3..5;
        let row_mask = selection.row_mask(&range);

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0]);
    }

    #[test]
    fn test_row_mask_all_false() {
        let selection = super::Selection::IncludeByIndex(Buffer::from_iter(vec![1, 3, 5, 7]));
        let range = 8..10;
        let row_mask = selection.row_mask(&range);

        assert!(row_mask.mask().all_false());
    }

    #[test]
    fn test_row_mask_all_true() {
        let selection = super::Selection::IncludeByIndex(Buffer::from_iter(vec![1, 3, 4, 5, 6]));
        let range = 3..7;
        let row_mask = selection.row_mask(&range);

        assert!(row_mask.mask().all_true());
    }

    #[test]
    fn test_row_mask_zero() {
        let selection = super::Selection::IncludeByIndex(Buffer::from_iter(vec![0]));
        let range = 0..5;
        let row_mask = selection.row_mask(&range);

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0]);
    }

    #[test]
    fn test_row_mask_empty_range() {
        // An empty range yields an empty mask with nothing selected.
        let selection = super::Selection::IncludeByIndex(Buffer::from_iter(vec![1, 3, 5, 7]));
        let range = 3..3;
        let row_mask = selection.row_mask(&range);

        assert_eq!(row_mask.mask().true_count(), 0);
    }

    #[test]
    fn test_row_mask_select_all() {
        // `Selection::All` keeps every row in the requested range.
        let selection = super::Selection::All;
        let range = 10..15;
        let row_mask = selection.row_mask(&range);

        assert!(row_mask.mask().all_true());
        assert_eq!(row_mask.mask().true_count(), 5);
    }

    #[test]
    fn test_row_mask_exclude_by_index() {
        let selection = super::Selection::ExcludeByIndex(Buffer::from_iter(vec![1, 3, 5, 7]));
        let range = 0..8;
        let row_mask = selection.row_mask(&range);

        // Rows 1, 3, 5, 7 are excluded, leaving 0, 2, 4, 6.
        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]);
    }

    #[test]
    fn test_row_mask_exclude_by_index_disjoint() {
        // No excluded index falls in the range, so every row is kept.
        let selection = super::Selection::ExcludeByIndex(Buffer::from_iter(vec![1, 3, 5, 7]));
        let range = 8..10;
        let row_mask = selection.row_mask(&range);

        assert!(row_mask.mask().all_true());
    }

    #[cfg(feature = "roaring")]
    mod roaring_tests {
        use roaring::RoaringTreemap;

        use super::*;

        #[test]
        fn test_roaring_include_basic() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert(1);
            roaring.insert(3);
            roaring.insert(5);
            roaring.insert(7);

            let selection = super::super::Selection::IncludeRoaring(roaring);
            let range = 1..8;
            let row_mask = selection.row_mask(&range);

            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]);
        }

        #[test]
        fn test_roaring_include_slice() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert(1);
            roaring.insert(3);
            roaring.insert(5);
            roaring.insert(7);

            let selection = super::super::Selection::IncludeRoaring(roaring);
            let range = 3..6;
            let row_mask = selection.row_mask(&range);

            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2]);
        }

        #[test]
        fn test_roaring_include_disjoint() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert(1);
            roaring.insert(3);
            roaring.insert(5);
            roaring.insert(7);

            let selection = super::super::Selection::IncludeRoaring(roaring);
            let range = 8..10;
            let row_mask = selection.row_mask(&range);

            assert!(row_mask.mask().all_false());
        }

        #[test]
        fn test_roaring_include_large_range() {
            let mut roaring = RoaringTreemap::new();
            // Insert a large number of indices
            for i in (0..1000000).step_by(2) {
                roaring.insert(i);
            }

            let selection = super::super::Selection::IncludeRoaring(roaring);
            let range = 1000..2000;
            let row_mask = selection.row_mask(&range);

            // Should have 500 selected indices (every even number)
            assert_eq!(row_mask.mask().true_count(), 500);
        }

        #[test]
        fn test_roaring_exclude_basic() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert(1);
            roaring.insert(3);
            roaring.insert(5);

            let selection = super::super::Selection::ExcludeRoaring(roaring);
            let range = 0..7;
            let row_mask = selection.row_mask(&range);

            // Should exclude indices 1, 3, 5, so we get 0, 2, 4, 6
            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]);
        }

        #[test]
        fn test_roaring_exclude_all() {
            let mut roaring = RoaringTreemap::new();
            // Exclude all indices in range
            for i in 10..20 {
                roaring.insert(i);
            }

            let selection = super::super::Selection::ExcludeRoaring(roaring);
            let range = 10..20;
            let row_mask = selection.row_mask(&range);

            assert!(row_mask.mask().all_false());
        }

        #[test]
        fn test_roaring_exclude_none() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert(100);
            roaring.insert(101);

            let selection = super::super::Selection::ExcludeRoaring(roaring);
            let range = 0..10;
            let row_mask = selection.row_mask(&range);

            // Nothing to exclude in this range
            assert!(row_mask.mask().all_true());
        }

        #[test]
        fn test_roaring_exclude_partial() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert(5);
            roaring.insert(6);
            roaring.insert(7);
            roaring.insert(15); // Outside range

            let selection = super::super::Selection::ExcludeRoaring(roaring);
            let range = 5..10;
            let row_mask = selection.row_mask(&range);

            // Should exclude 5, 6, 7 (mapped to 0, 1, 2), keep 8, 9 (mapped to 3, 4)
            assert_eq!(row_mask.mask().values().unwrap().indices(), &[3, 4]);
        }

        #[test]
        fn test_roaring_include_empty() {
            let roaring = RoaringTreemap::new();
            let selection = super::super::Selection::IncludeRoaring(roaring);
            let range = 0..100;
            let row_mask = selection.row_mask(&range);

            assert!(row_mask.mask().all_false());
        }

        #[test]
        fn test_roaring_exclude_empty() {
            let roaring = RoaringTreemap::new();
            let selection = super::super::Selection::ExcludeRoaring(roaring);
            let range = 0..100;
            let row_mask = selection.row_mask(&range);

            assert!(row_mask.mask().all_true());
        }

        #[test]
        fn test_roaring_include_boundary() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert(0);
            roaring.insert(99);

            let selection = super::super::Selection::IncludeRoaring(roaring);
            let range = 0..100;
            let row_mask = selection.row_mask(&range);

            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 99]);
        }

        #[test]
        fn test_roaring_include_range_insertion() {
            let mut roaring = RoaringTreemap::new();
            // Use insert_range for efficiency
            roaring.insert_range(10..20);
            roaring.insert_range(30..40);

            let selection = super::super::Selection::IncludeRoaring(roaring);
            let range = 15..35;
            let row_mask = selection.row_mask(&range);

            // Should include 15-19 (mapped to 0-4) and 30-34 (mapped to 15-19)
            let expected: Vec<usize> = (0..5).chain(15..20).collect();
            assert_eq!(row_mask.mask().values().unwrap().indices(), &expected);
        }

        #[test]
        fn test_roaring_overflow_protection() {
            let mut roaring = RoaringTreemap::new();
            // Insert very large indices
            roaring.insert(u64::MAX - 1);
            roaring.insert(u64::MAX);

            let selection = super::super::Selection::IncludeRoaring(roaring);
            let range = u64::MAX - 10..u64::MAX;
            let row_mask = selection.row_mask(&range);

            // Should handle overflow gracefully
            assert_eq!(row_mask.mask().true_count(), 1); // Only u64::MAX - 1 is in range
        }

        #[test]
        fn test_roaring_exclude_overflow_protection() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert(u64::MAX - 1);

            let selection = super::super::Selection::ExcludeRoaring(roaring);
            let range = u64::MAX - 10..u64::MAX;
            let row_mask = selection.row_mask(&range);

            // Should handle overflow gracefully, excluding index u64::MAX - 1
            assert_eq!(row_mask.mask().true_count(), 9); // All except one
        }

        #[test]
        fn test_roaring_include_vs_buffer_equivalence() {
            // Test that RoaringTreemap and Buffer produce same results
            let indices = vec![1, 3, 5, 7, 9];

            let buffer_selection =
                super::super::Selection::IncludeByIndex(Buffer::from_iter(indices.clone()));

            let mut roaring = RoaringTreemap::new();
            for idx in &indices {
                roaring.insert(*idx);
            }
            let roaring_selection = super::super::Selection::IncludeRoaring(roaring);

            let range = 0..12;
            let buffer_mask = buffer_selection.row_mask(&range);
            let roaring_mask = roaring_selection.row_mask(&range);

            assert_eq!(
                buffer_mask.mask().values().unwrap().indices(),
                roaring_mask.mask().values().unwrap().indices()
            );
        }

        #[test]
        fn test_roaring_exclude_vs_buffer_equivalence() {
            // Test that ExcludeRoaring and ExcludeByIndex produce same results
            let indices = vec![2, 4, 6, 8];

            let buffer_selection =
                super::super::Selection::ExcludeByIndex(Buffer::from_iter(indices.clone()));

            let mut roaring = RoaringTreemap::new();
            for idx in &indices {
                roaring.insert(*idx);
            }
            let roaring_selection = super::super::Selection::ExcludeRoaring(roaring);

            let range = 0..10;
            let buffer_mask = buffer_selection.row_mask(&range);
            let roaring_mask = roaring_selection.row_mask(&range);

            assert_eq!(
                buffer_mask.mask().values().unwrap().indices(),
                roaring_mask.mask().values().unwrap().indices()
            );
        }
    }
}