vortex_scan/
selection.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Not;
5use std::ops::Range;
6
7use vortex_buffer::Buffer;
8use vortex_error::vortex_panic;
9use vortex_mask::Mask;
10
11use crate::row_mask::RowMask;
12
13/// A selection identifies a set of rows to include in the scan (in addition to applying any
14/// filter predicates).
15#[derive(Default, Clone)]
16pub enum Selection {
17    /// No selection, all rows are included.
18    #[default]
19    All,
20    /// A selection of sorted rows to include by index.
21    IncludeByIndex(Buffer<u64>),
22    /// A selection of sorted rows to exclude by index.
23    ExcludeByIndex(Buffer<u64>),
24    /// A selection of rows to include using a [`roaring::RoaringTreemap`].
25    #[cfg(feature = "roaring")]
26    IncludeRoaring(roaring::RoaringTreemap),
27    /// A selection of rows to exclude using a [`roaring::RoaringTreemap`].
28    #[cfg(feature = "roaring")]
29    ExcludeRoaring(roaring::RoaringTreemap),
30}
31
32impl Selection {
33    /// Extract the [`RowMask`] for the given range from this selection.
34    pub(crate) fn row_mask(&self, range: &Range<u64>) -> RowMask {
35        // Saturating subtraction to prevent underflow, though range should be valid
36        let range_diff = range.end.saturating_sub(range.start);
37        let range_len = usize::try_from(range_diff).unwrap_or_else(|_| {
38            // If the range is too large for usize, cap it at usize::MAX
39            // This is a defensive measure; in practice, ranges should be reasonable
40            tracing::warn!(
41                "Range length {} exceeds usize::MAX, capping at usize::MAX",
42                range_diff
43            );
44            usize::MAX
45        });
46
47        match self {
48            Selection::All => RowMask::new(range.start, Mask::new_true(range_len)),
49            Selection::IncludeByIndex(include) => {
50                let mask = indices_range(range, include)
51                    .map(|idx_range| {
52                        Mask::from_indices(
53                            range_len,
54                            include
55                                .slice(idx_range)
56                                .iter()
57                                .map(|idx| {
58                                    idx.checked_sub(range.start).unwrap_or_else(|| {
59                                        vortex_panic!(
60                                            "index underflow, range: {:?}, idx: {:?}",
61                                            range,
62                                            idx
63                                        )
64                                    })
65                                })
66                                .filter_map(|idx| {
67                                    // Only include indices that fit in usize
68                                    usize::try_from(idx).ok()
69                                })
70                                .collect(),
71                        )
72                    })
73                    .unwrap_or_else(|| Mask::new_false(range_len));
74
75                RowMask::new(range.start, mask)
76            }
77            Selection::ExcludeByIndex(exclude) => {
78                let mask = Selection::IncludeByIndex(exclude.clone())
79                    .row_mask(range)
80                    .mask()
81                    .clone();
82                RowMask::new(range.start, mask.not())
83            }
84            #[cfg(feature = "roaring")]
85            Selection::IncludeRoaring(roaring) => {
86                use std::ops::BitAnd;
87
88                // First we perform a cheap is_disjoint check
89                let mut range_treemap = roaring::RoaringTreemap::new();
90                range_treemap.insert_range(range.clone());
91
92                if roaring.is_disjoint(&range_treemap) {
93                    return RowMask::new(range.start, Mask::new_false(range_len));
94                }
95
96                // Otherwise, intersect with the selected range and shift to relativize.
97                let roaring = roaring.bitand(range_treemap);
98                let mask = Mask::from_indices(
99                    range_len,
100                    roaring
101                        .iter()
102                        .map(|idx| {
103                            idx.checked_sub(range.start).unwrap_or_else(|| {
104                                vortex_panic!("index underflow, range: {:?}, idx: {:?}", range, idx)
105                            })
106                        })
107                        .filter_map(|idx| {
108                            // Only include indices that fit in usize
109                            usize::try_from(idx).ok()
110                        })
111                        .collect(),
112                );
113
114                RowMask::new(range.start, mask)
115            }
116            #[cfg(feature = "roaring")]
117            Selection::ExcludeRoaring(roaring) => {
118                use std::ops::BitAnd;
119
120                let mut range_treemap = roaring::RoaringTreemap::new();
121                range_treemap.insert_range(range.clone());
122
123                // If all indices in range are excluded, return all false mask
124                if roaring.intersection_len(&range_treemap) == range_len as u64 {
125                    return RowMask::new(range.start, Mask::new_false(range_len));
126                }
127
128                // Otherwise, intersect with the selected range and shift to relativize.
129                let roaring = roaring.bitand(range_treemap);
130                let mask = Mask::from_excluded_indices(
131                    range_len,
132                    roaring
133                        .iter()
134                        .map(|idx| {
135                            idx.checked_sub(range.start).unwrap_or_else(|| {
136                                vortex_panic!("index underflow, range: {:?}, idx: {:?}", range, idx)
137                            })
138                        })
139                        .filter_map(|idx| usize::try_from(idx).ok()),
140                );
141
142                RowMask::new(range.start, mask)
143            }
144        }
145    }
146}
147
/// Find the positional range within `row_indices` that covers all rows in the given `range`.
///
/// `row_indices` must be sorted in ascending order (duplicate values are tolerated). Returns
/// the half-open range of positions `p` such that `range.start <= row_indices[p] < range.end`.
/// Returns `None` when no index falls inside `range`, including when `range` is empty or
/// inverted.
fn indices_range(range: &Range<u64>, row_indices: &[u64]) -> Option<Range<usize>> {
    // An empty or inverted range contains no rows. Without this guard an inverted range could
    // produce `start > end` positions, which panic when used to slice.
    if range.start >= range.end {
        return None;
    }

    // `partition_point` returns the *first* position past the boundary, whereas
    // `binary_search` may land on any of several equal elements. This matters for duplicates
    // equal to `range.end`: all of them must be excluded. Otherwise their relative offset
    // equals the range length and falls outside the mask.
    let start_idx = row_indices.partition_point(|&idx| idx < range.start);
    let end_idx = start_idx + row_indices[start_idx..].partition_point(|&idx| idx < range.end);

    (start_idx < end_idx).then_some(start_idx..end_idx)
}
164
#[cfg(test)]
mod tests {
    use vortex_buffer::Buffer;

    use super::Selection;

    #[test]
    fn test_row_mask_all() {
        let selection = Selection::IncludeByIndex(Buffer::from_iter([1, 3, 5, 7]));
        let range = 1..8;
        let row_mask = selection.row_mask(&range);

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]);
    }

    #[test]
    fn test_row_mask_slice() {
        let selection = Selection::IncludeByIndex(Buffer::from_iter([1, 3, 5, 7]));
        let range = 3..6;
        let row_mask = selection.row_mask(&range);

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2]);
    }

    #[test]
    fn test_row_mask_exclusive() {
        let selection = Selection::IncludeByIndex(Buffer::from_iter([1, 3, 5, 7]));
        let range = 3..5;
        let row_mask = selection.row_mask(&range);

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0]);
    }

    #[test]
    fn test_row_mask_all_false() {
        let selection = Selection::IncludeByIndex(Buffer::from_iter([1, 3, 5, 7]));
        let range = 8..10;
        let row_mask = selection.row_mask(&range);

        assert!(row_mask.mask().all_false());
    }

    #[test]
    fn test_row_mask_all_true() {
        let selection = Selection::IncludeByIndex(Buffer::from_iter([1, 3, 4, 5, 6]));
        let range = 3..7;
        let row_mask = selection.row_mask(&range);

        assert!(row_mask.mask().all_true());
    }

    #[test]
    fn test_row_mask_zero() {
        let selection = Selection::IncludeByIndex(Buffer::from_iter([0]));
        let range = 0..5;
        let row_mask = selection.row_mask(&range);

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0]);
    }

    #[test]
    fn test_row_mask_select_all() {
        // `Selection::All` selects every row of the requested range.
        let row_mask = Selection::All.row_mask(&(10..15));

        assert!(row_mask.mask().all_true());
        assert_eq!(row_mask.mask().true_count(), 5);
    }

    #[test]
    fn test_row_mask_exclude_by_index() {
        let selection = Selection::ExcludeByIndex(Buffer::from_iter([1, 3, 5, 7]));
        let row_mask = selection.row_mask(&(0..8));

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]);
    }

    #[test]
    fn test_row_mask_exclude_by_index_disjoint() {
        // Nothing excluded inside the range, so every row survives.
        let selection = Selection::ExcludeByIndex(Buffer::from_iter([1, 3, 5, 7]));
        let row_mask = selection.row_mask(&(8..10));

        assert!(row_mask.mask().all_true());
    }

    #[test]
    fn test_row_mask_empty_range() {
        let selection = Selection::IncludeByIndex(Buffer::from_iter([1, 3, 5, 7]));
        let row_mask = selection.row_mask(&(3..3));

        assert_eq!(row_mask.mask().true_count(), 0);
    }

    #[cfg(feature = "roaring")]
    mod roaring_tests {
        use roaring::RoaringTreemap;

        use super::*;

        #[test]
        fn test_roaring_include_basic() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert(1);
            roaring.insert(3);
            roaring.insert(5);
            roaring.insert(7);

            let selection = Selection::IncludeRoaring(roaring);
            let range = 1..8;
            let row_mask = selection.row_mask(&range);

            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]);
        }

        #[test]
        fn test_roaring_include_slice() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert(1);
            roaring.insert(3);
            roaring.insert(5);
            roaring.insert(7);

            let selection = Selection::IncludeRoaring(roaring);
            let range = 3..6;
            let row_mask = selection.row_mask(&range);

            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2]);
        }

        #[test]
        fn test_roaring_include_disjoint() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert(1);
            roaring.insert(3);
            roaring.insert(5);
            roaring.insert(7);

            let selection = Selection::IncludeRoaring(roaring);
            let range = 8..10;
            let row_mask = selection.row_mask(&range);

            assert!(row_mask.mask().all_false());
        }

        #[test]
        fn test_roaring_include_large_range() {
            let mut roaring = RoaringTreemap::new();
            // Insert a large number of indices
            for i in (0..1000000).step_by(2) {
                roaring.insert(i);
            }

            let selection = Selection::IncludeRoaring(roaring);
            let range = 1000..2000;
            let row_mask = selection.row_mask(&range);

            // Should have 500 selected indices (every even number)
            assert_eq!(row_mask.mask().true_count(), 500);
        }

        #[test]
        fn test_roaring_exclude_basic() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert(1);
            roaring.insert(3);
            roaring.insert(5);

            let selection = Selection::ExcludeRoaring(roaring);
            let range = 0..7;
            let row_mask = selection.row_mask(&range);

            // Should exclude indices 1, 3, 5, so we get 0, 2, 4, 6
            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]);
        }

        #[test]
        fn test_roaring_exclude_all() {
            let mut roaring = RoaringTreemap::new();
            // Exclude all indices in range
            for i in 10..20 {
                roaring.insert(i);
            }

            let selection = Selection::ExcludeRoaring(roaring);
            let range = 10..20;
            let row_mask = selection.row_mask(&range);

            assert!(row_mask.mask().all_false());
        }

        #[test]
        fn test_roaring_exclude_none() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert(100);
            roaring.insert(101);

            let selection = Selection::ExcludeRoaring(roaring);
            let range = 0..10;
            let row_mask = selection.row_mask(&range);

            // Nothing to exclude in this range
            assert!(row_mask.mask().all_true());
        }

        #[test]
        fn test_roaring_exclude_partial() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert(5);
            roaring.insert(6);
            roaring.insert(7);
            roaring.insert(15); // Outside range

            let selection = Selection::ExcludeRoaring(roaring);
            let range = 5..10;
            let row_mask = selection.row_mask(&range);

            // Should exclude 5, 6, 7 (mapped to 0, 1, 2), keep 8, 9 (mapped to 3, 4)
            assert_eq!(row_mask.mask().values().unwrap().indices(), &[3, 4]);
        }

        #[test]
        fn test_roaring_include_empty() {
            let roaring = RoaringTreemap::new();
            let selection = Selection::IncludeRoaring(roaring);
            let range = 0..100;
            let row_mask = selection.row_mask(&range);

            assert!(row_mask.mask().all_false());
        }

        #[test]
        fn test_roaring_exclude_empty() {
            let roaring = RoaringTreemap::new();
            let selection = Selection::ExcludeRoaring(roaring);
            let range = 0..100;
            let row_mask = selection.row_mask(&range);

            assert!(row_mask.mask().all_true());
        }

        #[test]
        fn test_roaring_include_boundary() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert(0);
            roaring.insert(99);

            let selection = Selection::IncludeRoaring(roaring);
            let range = 0..100;
            let row_mask = selection.row_mask(&range);

            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 99]);
        }

        #[test]
        fn test_roaring_include_range_insertion() {
            let mut roaring = RoaringTreemap::new();
            // Use insert_range for efficiency
            roaring.insert_range(10..20);
            roaring.insert_range(30..40);

            let selection = Selection::IncludeRoaring(roaring);
            let range = 15..35;
            let row_mask = selection.row_mask(&range);

            // Should include 15-19 (mapped to 0-4) and 30-34 (mapped to 15-19)
            let expected: Vec<usize> = (0..5).chain(15..20).collect();
            assert_eq!(row_mask.mask().values().unwrap().indices(), &expected);
        }

        #[test]
        fn test_roaring_overflow_protection() {
            let mut roaring = RoaringTreemap::new();
            // Insert very large indices
            roaring.insert(u64::MAX - 1);
            roaring.insert(u64::MAX);

            let selection = Selection::IncludeRoaring(roaring);
            let range = u64::MAX - 10..u64::MAX;
            let row_mask = selection.row_mask(&range);

            // Should handle overflow gracefully
            assert_eq!(row_mask.mask().true_count(), 1); // Only u64::MAX - 1 is in range
        }

        #[test]
        fn test_roaring_exclude_overflow_protection() {
            let mut roaring = RoaringTreemap::new();
            roaring.insert(u64::MAX - 1);

            let selection = Selection::ExcludeRoaring(roaring);
            let range = u64::MAX - 10..u64::MAX;
            let row_mask = selection.row_mask(&range);

            // Should handle overflow gracefully, excluding index u64::MAX - 1
            assert_eq!(row_mask.mask().true_count(), 9); // All except one
        }

        #[test]
        fn test_roaring_include_vs_buffer_equivalence() {
            // Test that RoaringTreemap and Buffer produce same results
            let indices = vec![1, 3, 5, 7, 9];

            let buffer_selection = Selection::IncludeByIndex(Buffer::from_iter(indices.clone()));

            let mut roaring = RoaringTreemap::new();
            for idx in &indices {
                roaring.insert(*idx);
            }
            let roaring_selection = Selection::IncludeRoaring(roaring);

            let range = 0..12;
            let buffer_mask = buffer_selection.row_mask(&range);
            let roaring_mask = roaring_selection.row_mask(&range);

            assert_eq!(
                buffer_mask.mask().values().unwrap().indices(),
                roaring_mask.mask().values().unwrap().indices()
            );
        }

        #[test]
        fn test_roaring_exclude_vs_buffer_equivalence() {
            // Test that ExcludeRoaring and ExcludeByIndex produce same results
            let indices = vec![2, 4, 6, 8];

            let buffer_selection = Selection::ExcludeByIndex(Buffer::from_iter(indices.clone()));

            let mut roaring = RoaringTreemap::new();
            for idx in &indices {
                roaring.insert(*idx);
            }
            let roaring_selection = Selection::ExcludeRoaring(roaring);

            let range = 0..10;
            let buffer_mask = buffer_selection.row_mask(&range);
            let roaring_mask = roaring_selection.row_mask(&range);

            assert_eq!(
                buffer_mask.mask().values().unwrap().indices(),
                roaring_mask.mask().values().unwrap().indices()
            );
        }
    }
}