Skip to main content

vortex_scan/
selection.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Not;
5use std::ops::Range;
6
7use vortex_buffer::Buffer;
8use vortex_error::vortex_panic;
9use vortex_mask::Mask;
10
11use crate::row_mask::RowMask;
12
/// A selection identifies a set of rows to include in the scan (in addition to applying any
/// filter predicates).
///
/// Row indices stored in a selection are absolute row positions; they are rebased onto the
/// start of a row range only when a mask for that range is extracted. The default selection
/// is [`Selection::All`].
#[derive(Default, Clone, Debug)]
pub enum Selection {
    /// No selection, all rows are included.
    #[default]
    All,
    /// A selection of sorted rows to include by index.
    ///
    /// The buffer is searched by bisection when a range is extracted, so it must be sorted
    /// in ascending order.
    IncludeByIndex(Buffer<u64>),
    /// A selection of sorted rows to exclude by index.
    ///
    /// Same ordering requirement as [`Selection::IncludeByIndex`]: the exclusion mask is
    /// computed as the negation of the corresponding inclusion mask.
    ExcludeByIndex(Buffer<u64>),
    /// A selection of rows to include using a [`roaring::RoaringTreemap`].
    IncludeRoaring(roaring::RoaringTreemap),
    /// A selection of rows to exclude using a [`roaring::RoaringTreemap`].
    ExcludeRoaring(roaring::RoaringTreemap),
}
29
30impl Selection {
31    /// Return the row count for this selection.
32    pub fn row_count(&self, total_rows: u64) -> u64 {
33        match self {
34            Selection::All => total_rows,
35            Selection::IncludeByIndex(include) => include.len() as u64,
36            Selection::ExcludeByIndex(exclude) => total_rows.saturating_sub(exclude.len() as u64),
37            Selection::IncludeRoaring(roaring) => roaring.len(),
38            Selection::ExcludeRoaring(roaring) => total_rows.saturating_sub(roaring.len()),
39        }
40    }
41
42    /// Extract the [`RowMask`] for the given range from this selection.
43    pub(crate) fn row_mask(&self, range: &Range<u64>) -> RowMask {
44        // Saturating subtraction to prevent underflow, though range should be valid
45        let range_diff = range.end.saturating_sub(range.start);
46        let range_len = usize::try_from(range_diff).unwrap_or_else(|_| {
47            // If the range is too large for usize, cap it at usize::MAX
48            // This is a defensive measure; in practice, ranges should be reasonable
49            tracing::warn!(
50                "Range length {} exceeds usize::MAX, capping at usize::MAX",
51                range_diff
52            );
53            usize::MAX
54        });
55
56        match self {
57            Selection::All => RowMask::new(range.start, Mask::new_true(range_len)),
58            Selection::IncludeByIndex(include) => {
59                let mask = indices_range(range, include)
60                    .map(|idx_range| {
61                        Mask::from_indices(
62                            range_len,
63                            include
64                                .slice(idx_range)
65                                .iter()
66                                .map(|idx| {
67                                    idx.checked_sub(range.start).unwrap_or_else(|| {
68                                        vortex_panic!(
69                                            "index underflow, range: {:?}, idx: {:?}",
70                                            range,
71                                            idx
72                                        )
73                                    })
74                                })
75                                .filter_map(|idx| {
76                                    // Only include indices that fit in usize
77                                    usize::try_from(idx).ok()
78                                })
79                                .collect(),
80                        )
81                    })
82                    .unwrap_or_else(|| Mask::new_false(range_len));
83
84                RowMask::new(range.start, mask)
85            }
86            Selection::ExcludeByIndex(exclude) => {
87                let mask = Selection::IncludeByIndex(exclude.clone())
88                    .row_mask(range)
89                    .mask()
90                    .clone();
91                RowMask::new(range.start, mask.not())
92            }
93            Selection::IncludeRoaring(roaring) => {
94                use std::ops::BitAnd;
95
96                // First we perform a cheap is_disjoint check
97                let mut range_treemap = roaring::RoaringTreemap::new();
98                range_treemap.insert_range(range.clone());
99
100                if roaring.is_disjoint(&range_treemap) {
101                    return RowMask::new(range.start, Mask::new_false(range_len));
102                }
103
104                // Otherwise, intersect with the selected range and shift to relativize.
105                let roaring = roaring.bitand(range_treemap);
106                let mask = Mask::from_indices(
107                    range_len,
108                    roaring
109                        .iter()
110                        .map(|idx| {
111                            idx.checked_sub(range.start).unwrap_or_else(|| {
112                                vortex_panic!("index underflow, range: {:?}, idx: {:?}", range, idx)
113                            })
114                        })
115                        .filter_map(|idx| {
116                            // Only include indices that fit in usize
117                            usize::try_from(idx).ok()
118                        })
119                        .collect(),
120                );
121
122                RowMask::new(range.start, mask)
123            }
124            Selection::ExcludeRoaring(roaring) => {
125                use std::ops::BitAnd;
126
127                let mut range_treemap = roaring::RoaringTreemap::new();
128                range_treemap.insert_range(range.clone());
129
130                // If all indices in range are excluded, return all false mask
131                if roaring.intersection_len(&range_treemap) == range_len as u64 {
132                    return RowMask::new(range.start, Mask::new_false(range_len));
133                }
134
135                // Otherwise, intersect with the selected range and shift to relativize.
136                let roaring = roaring.bitand(range_treemap);
137                let mask = Mask::from_excluded_indices(
138                    range_len,
139                    roaring
140                        .iter()
141                        .map(|idx| {
142                            idx.checked_sub(range.start).unwrap_or_else(|| {
143                                vortex_panic!("index underflow, range: {:?}, idx: {:?}", range, idx)
144                            })
145                        })
146                        .filter_map(|idx| usize::try_from(idx).ok()),
147                );
148
149                RowMask::new(range.start, mask)
150            }
151        }
152    }
153}
154
/// Find the positional range within `row_indices` that covers all rows in the given range.
///
/// `row_indices` must be sorted in ascending order; repeated values are tolerated. The
/// result is the half-open span of positions whose values `v` satisfy
/// `range.start <= v < range.end`.
///
/// Returns `None` when no index falls inside `range`, which includes an empty
/// `row_indices`, an empty `range`, and an inverted `range` (`start > end`).
fn indices_range(range: &Range<u64>, row_indices: &[u64]) -> Option<Range<usize>> {
    // O(1) rejection when the whole buffer lies on one side of the range (or is empty).
    let (first, last) = (*row_indices.first()?, *row_indices.last()?);
    if first >= range.end || last < range.start {
        return None;
    }

    // `partition_point` always lands on the first position at or above the bound, whereas
    // `binary_search` may return any matching position when the value is repeated.
    let start_idx = row_indices.partition_point(|&row| row < range.start);
    let end_idx = row_indices.partition_point(|&row| row < range.end);

    // `<` rather than `!=` so an inverted range can never produce a reversed span.
    (start_idx < end_idx).then_some(start_idx..end_idx)
}
171
#[cfg(test)]
mod tests {
    use vortex_buffer::Buffer;

    use super::Selection;

    /// Include-by-index selection over the given (sorted) absolute row indices.
    fn include_rows(rows: Vec<u64>) -> Selection {
        Selection::IncludeByIndex(Buffer::from_iter(rows))
    }

    #[test]
    fn test_row_mask_all() {
        // Range spans every selected row; offsets are relative to the range start.
        let row_mask = include_rows(vec![1, 3, 5, 7]).row_mask(&(1..8));

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]);
    }

    #[test]
    fn test_row_mask_slice() {
        let row_mask = include_rows(vec![1, 3, 5, 7]).row_mask(&(3..6));

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2]);
    }

    #[test]
    fn test_row_mask_exclusive() {
        // Row 5 sits exactly on the exclusive end bound and must not be selected.
        let row_mask = include_rows(vec![1, 3, 5, 7]).row_mask(&(3..5));

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0]);
    }

    #[test]
    fn test_row_mask_all_false() {
        let row_mask = include_rows(vec![1, 3, 5, 7]).row_mask(&(8..10));

        assert!(row_mask.mask().all_false());
    }

    #[test]
    fn test_row_mask_all_true() {
        let row_mask = include_rows(vec![1, 3, 4, 5, 6]).row_mask(&(3..7));

        assert!(row_mask.mask().all_true());
    }

    #[test]
    fn test_row_mask_zero() {
        let row_mask = include_rows(vec![0]).row_mask(&(0..5));

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0]);
    }

    mod roaring_tests {
        use roaring::RoaringTreemap;

        use super::*;

        /// Treemap containing exactly the given rows, inserted one at a time.
        fn treemap(rows: impl IntoIterator<Item = u64>) -> RoaringTreemap {
            let mut bitmap = RoaringTreemap::new();
            for row in rows {
                bitmap.insert(row);
            }
            bitmap
        }

        #[test]
        fn test_roaring_include_basic() {
            let selection = Selection::IncludeRoaring(treemap([1, 3, 5, 7]));
            let row_mask = selection.row_mask(&(1..8));

            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]);
        }

        #[test]
        fn test_roaring_include_slice() {
            let selection = Selection::IncludeRoaring(treemap([1, 3, 5, 7]));
            let row_mask = selection.row_mask(&(3..6));

            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2]);
        }

        #[test]
        fn test_roaring_include_disjoint() {
            let selection = Selection::IncludeRoaring(treemap([1, 3, 5, 7]));
            let row_mask = selection.row_mask(&(8..10));

            assert!(row_mask.mask().all_false());
        }

        #[test]
        fn test_roaring_include_large_range() {
            // Every even row below one million is selected.
            let selection = Selection::IncludeRoaring(treemap((0..1_000_000).step_by(2)));
            let row_mask = selection.row_mask(&(1000..2000));

            // Half of the 1000 rows in the range are even.
            assert_eq!(row_mask.mask().true_count(), 500);
        }

        #[test]
        fn test_roaring_exclude_basic() {
            let selection = Selection::ExcludeRoaring(treemap([1, 3, 5]));
            let row_mask = selection.row_mask(&(0..7));

            // Rows 1, 3 and 5 are dropped, leaving 0, 2, 4 and 6.
            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]);
        }

        #[test]
        fn test_roaring_exclude_all() {
            // The exclusion set covers the whole range.
            let selection = Selection::ExcludeRoaring(treemap(10..20));
            let row_mask = selection.row_mask(&(10..20));

            assert!(row_mask.mask().all_false());
        }

        #[test]
        fn test_roaring_exclude_none() {
            let selection = Selection::ExcludeRoaring(treemap([100, 101]));
            let row_mask = selection.row_mask(&(0..10));

            // The excluded rows lie outside the range, so every row survives.
            assert!(row_mask.mask().all_true());
        }

        #[test]
        fn test_roaring_exclude_partial() {
            // Row 15 is outside the range and must be ignored.
            let selection = Selection::ExcludeRoaring(treemap([5, 6, 7, 15]));
            let row_mask = selection.row_mask(&(5..10));

            // Rows 5, 6, 7 (offsets 0, 1, 2) go; rows 8, 9 (offsets 3, 4) stay.
            assert_eq!(row_mask.mask().values().unwrap().indices(), &[3, 4]);
        }

        #[test]
        fn test_roaring_include_empty() {
            let selection = Selection::IncludeRoaring(RoaringTreemap::new());
            let row_mask = selection.row_mask(&(0..100));

            assert!(row_mask.mask().all_false());
        }

        #[test]
        fn test_roaring_exclude_empty() {
            let selection = Selection::ExcludeRoaring(RoaringTreemap::new());
            let row_mask = selection.row_mask(&(0..100));

            assert!(row_mask.mask().all_true());
        }

        #[test]
        fn test_roaring_include_boundary() {
            // First and last row of the range.
            let selection = Selection::IncludeRoaring(treemap([0, 99]));
            let row_mask = selection.row_mask(&(0..100));

            assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 99]);
        }

        #[test]
        fn test_roaring_include_range_insertion() {
            // Populate with bulk range insertion rather than row by row.
            let mut bitmap = RoaringTreemap::new();
            bitmap.insert_range(10..20);
            bitmap.insert_range(30..40);

            let row_mask = Selection::IncludeRoaring(bitmap).row_mask(&(15..35));

            // Rows 15-19 land on offsets 0-4, rows 30-34 on offsets 15-19.
            let expected: Vec<usize> = (0..5).chain(15..20).collect();
            assert_eq!(row_mask.mask().values().unwrap().indices(), &expected);
        }

        #[test]
        fn test_roaring_overflow_protection() {
            // Indices at the very top of the u64 domain.
            let selection = Selection::IncludeRoaring(treemap([u64::MAX - 1, u64::MAX]));
            let row_mask = selection.row_mask(&(u64::MAX - 10..u64::MAX));

            // The range end is exclusive, so only u64::MAX - 1 is inside it.
            assert_eq!(row_mask.mask().true_count(), 1);
        }

        #[test]
        fn test_roaring_exclude_overflow_protection() {
            let selection = Selection::ExcludeRoaring(treemap([u64::MAX - 1]));
            let row_mask = selection.row_mask(&(u64::MAX - 10..u64::MAX));

            // Ten rows in the range, one of them excluded.
            assert_eq!(row_mask.mask().true_count(), 9);
        }

        #[test]
        fn test_roaring_include_vs_buffer_equivalence() {
            // Both include representations must produce the same mask.
            let rows: Vec<u64> = vec![1, 3, 5, 7, 9];

            let buffer_mask = include_rows(rows.clone()).row_mask(&(0..12));
            let roaring_mask = Selection::IncludeRoaring(treemap(rows)).row_mask(&(0..12));

            assert_eq!(
                buffer_mask.mask().values().unwrap().indices(),
                roaring_mask.mask().values().unwrap().indices()
            );
        }

        #[test]
        fn test_roaring_exclude_vs_buffer_equivalence() {
            // Both exclude representations must produce the same mask.
            let rows: Vec<u64> = vec![2, 4, 6, 8];

            let buffer_selection = Selection::ExcludeByIndex(Buffer::from_iter(rows.clone()));
            let roaring_selection = Selection::ExcludeRoaring(treemap(rows));

            let buffer_mask = buffer_selection.row_mask(&(0..10));
            let roaring_mask = roaring_selection.row_mask(&(0..10));

            assert_eq!(
                buffer_mask.mask().values().unwrap().indices(),
                roaring_mask.mask().values().unwrap().indices()
            );
        }
    }
}