Skip to main content

vortex_scan/
selection.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Defines a selection mask over a scan.
5
6use std::ops::Not;
7use std::ops::Range;
8
9use vortex_buffer::Buffer;
10use vortex_error::vortex_panic;
11use vortex_mask::Mask;
12
13use crate::row_mask::RowMask;
14
/// A selection identifies a set of rows to include in the scan (in addition to applying any
/// filter predicates).
///
/// Row indices stored in the variants are absolute row positions:
/// [`Selection::row_mask`] rebases them onto the start of the range it is asked about.
#[derive(Default, Clone, Debug)]
pub enum Selection {
    /// No selection, all rows are included. This is the [`Default`] variant.
    #[default]
    All,
    /// A selection of sorted rows to include by index.
    ///
    /// The buffer must be sorted in ascending order, because range lookups search it
    /// rather than scanning it.
    IncludeByIndex(Buffer<u64>),
    /// A selection of sorted rows to exclude by index.
    ///
    /// Same ordering requirement as [`Selection::IncludeByIndex`]; every row that is not
    /// listed is kept.
    ExcludeByIndex(Buffer<u64>),
    /// A selection of rows to include using a [`roaring::RoaringTreemap`].
    ///
    /// Only rows present in the treemap are kept.
    IncludeRoaring(roaring::RoaringTreemap),
    /// A selection of rows to exclude using a [`roaring::RoaringTreemap`].
    ///
    /// Rows present in the treemap are dropped; all other rows are kept.
    ExcludeRoaring(roaring::RoaringTreemap),
}
31
32impl Selection {
33    /// Return the row count for this selection.
34    pub fn row_count(&self, total_rows: u64) -> u64 {
35        match self {
36            Selection::All => total_rows,
37            Selection::IncludeByIndex(include) => include.len() as u64,
38            Selection::ExcludeByIndex(exclude) => total_rows.saturating_sub(exclude.len() as u64),
39            Selection::IncludeRoaring(roaring) => roaring.len(),
40            Selection::ExcludeRoaring(roaring) => total_rows.saturating_sub(roaring.len()),
41        }
42    }
43
44    /// Extract the [`RowMask`] for the given range from this selection.
45    pub fn row_mask(&self, range: &Range<u64>) -> RowMask {
46        if range.start >= range.end {
47            return RowMask::new(0, Mask::AllFalse(0));
48        }
49
50        // Saturating subtraction to prevent underflow, though range should be valid
51        let range_diff = range.end.saturating_sub(range.start);
52        let range_len = usize::try_from(range_diff).unwrap_or_else(|_| {
53            // If the range is too large for usize, cap it at usize::MAX
54            // This is a defensive measure; in practice, ranges should be reasonable
55            tracing::warn!(
56                "Range length {} exceeds usize::MAX, capping at usize::MAX",
57                range_diff
58            );
59            usize::MAX
60        });
61
62        match self {
63            Selection::All => RowMask::new(range.start, Mask::new_true(range_len)),
64            Selection::IncludeByIndex(include) => {
65                let mask = indices_range(range, include)
66                    .map(|idx_range| {
67                        Mask::from_indices(
68                            range_len,
69                            include
70                                .slice(idx_range)
71                                .iter()
72                                .map(|idx| {
73                                    idx.checked_sub(range.start).unwrap_or_else(|| {
74                                        vortex_panic!(
75                                            "index underflow, range: {:?}, idx: {:?}",
76                                            range,
77                                            idx
78                                        )
79                                    })
80                                })
81                                .filter_map(|idx| {
82                                    // Only include indices that fit in usize
83                                    usize::try_from(idx).ok()
84                                }),
85                        )
86                    })
87                    .unwrap_or_else(|| Mask::new_false(range_len));
88
89                RowMask::new(range.start, mask)
90            }
91            Selection::ExcludeByIndex(exclude) => {
92                let mask = Selection::IncludeByIndex(exclude.clone())
93                    .row_mask(range)
94                    .mask()
95                    .clone();
96                RowMask::new(range.start, mask.not())
97            }
98            Selection::IncludeRoaring(roaring) => {
99                use std::ops::BitAnd;
100
101                // First we perform a cheap is_disjoint check
102                let mut range_treemap = roaring::RoaringTreemap::new();
103                range_treemap.insert_range(range.clone());
104
105                if roaring.is_disjoint(&range_treemap) {
106                    return RowMask::new(range.start, Mask::new_false(range_len));
107                }
108
109                // Otherwise, intersect with the selected range and shift to relativize.
110                let roaring = roaring.bitand(range_treemap);
111                let mask = Mask::from_indices(
112                    range_len,
113                    roaring
114                        .iter()
115                        .map(|idx| {
116                            idx.checked_sub(range.start).unwrap_or_else(|| {
117                                vortex_panic!("index underflow, range: {:?}, idx: {:?}", range, idx)
118                            })
119                        })
120                        .filter_map(|idx| {
121                            // Only include indices that fit in usize
122                            usize::try_from(idx).ok()
123                        }),
124                );
125
126                RowMask::new(range.start, mask)
127            }
128            Selection::ExcludeRoaring(roaring) => {
129                use std::ops::BitAnd;
130
131                let mut range_treemap = roaring::RoaringTreemap::new();
132                range_treemap.insert_range(range.clone());
133
134                // If all indices in range are excluded, return all false mask
135                if roaring.intersection_len(&range_treemap) == range_len as u64 {
136                    return RowMask::new(range.start, Mask::new_false(range_len));
137                }
138
139                // Otherwise, intersect with the selected range and shift to relativize.
140                let roaring = roaring.bitand(range_treemap);
141                let mask = Mask::from_excluded_indices(
142                    range_len,
143                    roaring
144                        .iter()
145                        .map(|idx| {
146                            idx.checked_sub(range.start).unwrap_or_else(|| {
147                                vortex_panic!("index underflow, range: {:?}, idx: {:?}", range, idx)
148                            })
149                        })
150                        .filter_map(|idx| usize::try_from(idx).ok()),
151                );
152
153                RowMask::new(range.start, mask)
154            }
155        }
156    }
157}
158
/// Find the positional range within `row_indices` that covers all rows in the given range.
///
/// `row_indices` must be sorted in ascending order. The result is `Some(start..end)` such
/// that `row_indices[start..end]` holds exactly the entries `r` with
/// `range.start <= r < range.end`, or `None` when no entry falls inside `range` (including
/// when `row_indices` is empty or `range` is empty or inverted).
///
/// Duplicate entries are handled deterministically: every copy of a value inside the range
/// is covered, and no copy of `range.end` is ever included.
fn indices_range(range: &Range<u64>, row_indices: &[u64]) -> Option<Range<usize>> {
    // Fast path: the whole slice lies entirely after or entirely before the range.
    if row_indices.first().is_some_and(|&first| first >= range.end)
        || row_indices.last().is_some_and(|&last| range.start > last)
    {
        return None;
    }

    // Lower bounds via `partition_point` rather than `binary_search`: the latter may return
    // any matching position when a boundary value is duplicated, which could drop entries
    // equal to `range.start` or let entries equal to `range.end` slip in.
    let start_idx = row_indices.partition_point(|&row| row < range.start);
    let end_idx = row_indices.partition_point(|&row| row < range.end);

    // `<` (not `!=`) so an inverted range can never produce a reversed slice range.
    (start_idx < end_idx).then_some(start_idx..end_idx)
}
175
#[cfg(test)]
mod tests {
    use vortex_buffer::Buffer;

    use super::Selection;

    /// Builds a [`Selection::IncludeByIndex`] over the given sorted row indices.
    fn include_rows(rows: &[u64]) -> Selection {
        Selection::IncludeByIndex(Buffer::from_iter(rows.to_vec()))
    }

    #[test]
    fn test_row_mask_all() {
        let row_mask = include_rows(&[1, 3, 5, 7]).row_mask(&(1..8));

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]);
    }

    #[test]
    fn test_row_mask_slice() {
        let row_mask = include_rows(&[1, 3, 5, 7]).row_mask(&(3..6));

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2]);
    }

    #[test]
    fn test_row_mask_exclusive() {
        let row_mask = include_rows(&[1, 3, 5, 7]).row_mask(&(3..5));

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0]);
    }

    #[test]
    fn test_row_mask_all_false() {
        let row_mask = include_rows(&[1, 3, 5, 7]).row_mask(&(8..10));

        assert!(row_mask.mask().all_false());
    }

    #[test]
    fn test_row_mask_all_true() {
        let row_mask = include_rows(&[1, 3, 4, 5, 6]).row_mask(&(3..7));

        assert!(row_mask.mask().all_true());
    }

    #[test]
    fn test_row_mask_zero() {
        let row_mask = include_rows(&[0]).row_mask(&(0..5));

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0]);
    }

    mod roaring_tests {
        use roaring::RoaringTreemap;

        use super::*;

        /// Builds a treemap by inserting each of the given rows one at a time.
        fn treemap_of(rows: impl IntoIterator<Item = u64>) -> RoaringTreemap {
            let mut bitmap = RoaringTreemap::new();
            for row in rows {
                bitmap.insert(row);
            }
            bitmap
        }

        #[test]
        fn test_roaring_include_basic() {
            let selection = Selection::IncludeRoaring(treemap_of([1, 3, 5, 7]));
            let row_mask = selection.row_mask(&(1..8));

            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]);
        }

        #[test]
        fn test_roaring_include_slice() {
            let selection = Selection::IncludeRoaring(treemap_of([1, 3, 5, 7]));
            let row_mask = selection.row_mask(&(3..6));

            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2]);
        }

        #[test]
        fn test_roaring_include_disjoint() {
            let selection = Selection::IncludeRoaring(treemap_of([1, 3, 5, 7]));
            let row_mask = selection.row_mask(&(8..10));

            assert!(row_mask.mask().all_false());
        }

        #[test]
        fn test_roaring_include_large_range() {
            // A large number of indices: every even row below one million.
            let selection = Selection::IncludeRoaring(treemap_of((0..1000000).step_by(2)));
            let row_mask = selection.row_mask(&(1000..2000));

            // Half of the 1000 rows in the range are even.
            assert_eq!(row_mask.mask().true_count(), 500);
        }

        #[test]
        fn test_roaring_exclude_basic() {
            let selection = Selection::ExcludeRoaring(treemap_of([1, 3, 5]));
            let row_mask = selection.row_mask(&(0..7));

            // Rows 1, 3, 5 are dropped, leaving 0, 2, 4, 6.
            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]);
        }

        #[test]
        fn test_roaring_exclude_all() {
            // Every row of the range is excluded.
            let selection = Selection::ExcludeRoaring(treemap_of(10..20));
            let row_mask = selection.row_mask(&(10..20));

            assert!(row_mask.mask().all_false());
        }

        #[test]
        fn test_roaring_exclude_none() {
            let selection = Selection::ExcludeRoaring(treemap_of([100, 101]));
            let row_mask = selection.row_mask(&(0..10));

            // None of the exclusions fall inside the range.
            assert!(row_mask.mask().all_true());
        }

        #[test]
        fn test_roaring_exclude_partial() {
            // Row 15 lies outside the range and must be ignored.
            let selection = Selection::ExcludeRoaring(treemap_of([5, 6, 7, 15]));
            let row_mask = selection.row_mask(&(5..10));

            // Rows 5, 6, 7 (offsets 0, 1, 2) are dropped; rows 8, 9 (offsets 3, 4) remain.
            assert_eq!(row_mask.mask().values().unwrap().indices(), &[3, 4]);
        }

        #[test]
        fn test_roaring_include_empty() {
            let selection = Selection::IncludeRoaring(RoaringTreemap::new());
            let row_mask = selection.row_mask(&(0..100));

            assert!(row_mask.mask().all_false());
        }

        #[test]
        fn test_roaring_exclude_empty() {
            let selection = Selection::ExcludeRoaring(RoaringTreemap::new());
            let row_mask = selection.row_mask(&(0..100));

            assert!(row_mask.mask().all_true());
        }

        #[test]
        fn test_roaring_include_boundary() {
            let selection = Selection::IncludeRoaring(treemap_of([0, 99]));
            let row_mask = selection.row_mask(&(0..100));

            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 99]);
        }

        #[test]
        fn test_roaring_include_range_insertion() {
            // Populate through insert_range instead of row-by-row inserts.
            let mut roaring = RoaringTreemap::new();
            roaring.insert_range(10..20);
            roaring.insert_range(30..40);

            let selection = Selection::IncludeRoaring(roaring);
            let row_mask = selection.row_mask(&(15..35));

            // Rows 15-19 map to offsets 0-4 and rows 30-34 map to offsets 15-19.
            let expected: Vec<usize> = (0..5).chain(15..20).collect();
            assert_eq!(row_mask.mask().values().unwrap().indices(), &expected);
        }

        #[test]
        fn test_roaring_overflow_protection() {
            // Rows at the very top of the u64 domain.
            let selection = Selection::IncludeRoaring(treemap_of([u64::MAX - 1, u64::MAX]));
            let row_mask = selection.row_mask(&(u64::MAX - 10..u64::MAX));

            // The range end is exclusive, so only u64::MAX - 1 is inside it.
            assert_eq!(row_mask.mask().true_count(), 1);
        }

        #[test]
        fn test_roaring_exclude_overflow_protection() {
            let selection = Selection::ExcludeRoaring(treemap_of([u64::MAX - 1]));
            let row_mask = selection.row_mask(&(u64::MAX - 10..u64::MAX));

            // Ten rows in the range, one of them (u64::MAX - 1) excluded.
            assert_eq!(row_mask.mask().true_count(), 9);
        }

        #[test]
        fn test_roaring_include_vs_buffer_equivalence() {
            // The buffer-backed and treemap-backed include variants must agree.
            let indices = vec![1, 3, 5, 7, 9];

            let buffer_selection = Selection::IncludeByIndex(Buffer::from_iter(indices.clone()));
            let roaring_selection = Selection::IncludeRoaring(treemap_of(indices));

            let range = 0..12;
            let buffer_mask = buffer_selection.row_mask(&range);
            let roaring_mask = roaring_selection.row_mask(&range);

            assert_eq!(
                buffer_mask.mask().values().unwrap().indices(),
                roaring_mask.mask().values().unwrap().indices()
            );
        }

        #[test]
        fn test_roaring_exclude_vs_buffer_equivalence() {
            // The buffer-backed and treemap-backed exclude variants must agree.
            let indices = vec![2, 4, 6, 8];

            let buffer_selection = Selection::ExcludeByIndex(Buffer::from_iter(indices.clone()));
            let roaring_selection = Selection::ExcludeRoaring(treemap_of(indices));

            let range = 0..10;
            let buffer_mask = buffer_selection.row_mask(&range);
            let roaring_mask = roaring_selection.row_mask(&range);

            assert_eq!(
                buffer_mask.mask().values().unwrap().indices(),
                roaring_mask.mask().values().unwrap().indices()
            );
        }
    }
}