vortex_scan/
selection.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::{Not, Range};
5
6use vortex_buffer::Buffer;
7use vortex_error::VortexExpect;
8use vortex_mask::Mask;
9
10use crate::row_mask::RowMask;
11
/// A selection identifies a set of rows to include in the scan (in addition to applying any
/// filter predicates).
///
/// Row indices stored in a selection are absolute: `row_mask` subtracts the start of the
/// requested range from each index to obtain positions relative to that range.
///
/// The default selection is [`Selection::All`].
#[derive(Default, Clone)]
pub enum Selection {
    /// No selection, all rows are included.
    #[default]
    All,
    /// A selection of rows to include by index.
    ///
    /// The buffer is searched with a binary search when a row mask is extracted, so the
    /// indices are expected to be sorted in ascending order.
    IncludeByIndex(Buffer<u64>),
    /// A selection of rows to exclude by index.
    ///
    /// Evaluated as the complement of the equivalent [`Selection::IncludeByIndex`], so the
    /// same ascending-order expectation applies to the buffer.
    ExcludeByIndex(Buffer<u64>),
    /// A selection of rows to include using a [`roaring::RoaringTreemap`].
    ///
    /// Only available when the `roaring` feature is enabled.
    #[cfg(feature = "roaring")]
    IncludeRoaring(roaring::RoaringTreemap),
    /// A selection of rows to exclude using a [`roaring::RoaringTreemap`].
    ///
    /// Only available when the `roaring` feature is enabled.
    #[cfg(feature = "roaring")]
    ExcludeRoaring(roaring::RoaringTreemap),
}
30
31impl Selection {
32    /// Extract the [`RowMask`] for the given range from this selection.
33    pub(crate) fn row_mask(&self, range: &Range<u64>) -> RowMask {
34        let range_len = usize::try_from(range.end - range.start)
35            .vortex_expect("Range length does not fit into a usize");
36
37        match self {
38            Selection::All => RowMask::new(range.start, Mask::new_true(range_len)),
39            Selection::IncludeByIndex(include) => {
40                let mask = indices_range(range, include)
41                    .map(|idx_range| {
42                        Mask::from_indices(
43                            range_len,
44                            include
45                                .slice(idx_range)
46                                .iter()
47                                .map(|idx| *idx - range.start)
48                                .map(|idx| {
49                                    usize::try_from(idx)
50                                        .vortex_expect("Index does not fit into a usize")
51                                })
52                                .collect(),
53                        )
54                    })
55                    .unwrap_or_else(|| Mask::new_false(range_len));
56
57                RowMask::new(range.start, mask)
58            }
59            Selection::ExcludeByIndex(exclude) => {
60                let mask = Selection::IncludeByIndex(exclude.clone())
61                    .row_mask(range)
62                    .mask()
63                    .clone();
64                RowMask::new(range.start, mask.not())
65            }
66            #[cfg(feature = "roaring")]
67            Selection::IncludeRoaring(roaring) => {
68                use std::ops::BitAnd;
69
70                // First we perform a cheap is_disjoint check
71                let mut range_treemap = roaring::RoaringTreemap::new();
72                range_treemap.insert_range(range.clone());
73
74                if roaring.is_disjoint(&range_treemap) {
75                    return RowMask::new(range.start, Mask::new_false(range_len));
76                }
77
78                // Otherwise, intersect with the selected range and shift to relativize.
79                let roaring = roaring.bitand(range_treemap);
80                let mask = Mask::from_indices(
81                    range_len,
82                    roaring
83                        .iter()
84                        .map(|idx| idx - range.start)
85                        .map(|idx| {
86                            usize::try_from(idx).vortex_expect("Index does not fit into a usize")
87                        })
88                        .collect(),
89                );
90
91                RowMask::new(range.start, mask)
92            }
93            #[cfg(feature = "roaring")]
94            Selection::ExcludeRoaring(roaring) => {
95                use std::ops::BitAnd;
96
97                let mut range_treemap = roaring::RoaringTreemap::new();
98                range_treemap.insert_range(range.clone());
99
100                // If there are no deletions in the intersection, then we have an all true mask.
101                if roaring.intersection_len(&range_treemap) == range_len as u64 {
102                    return RowMask::new(range.start, Mask::new_true(range_len));
103                }
104
105                // Otherwise, intersect with the selected range and shift to relativize.
106                let roaring = roaring.bitand(range_treemap);
107                let mask = Mask::from_excluded_indices(
108                    range_len,
109                    roaring.iter().map(|idx| idx - range.start).map(|idx| {
110                        usize::try_from(idx).vortex_expect("Index does not fit into a usize")
111                    }),
112                );
113
114                RowMask::new(range.start, mask)
115            }
116        }
117    }
118}
119
/// Find the positional range within `row_indices` that covers all rows in the given range.
///
/// `row_indices` is expected to be sorted in ascending order. The result is the half-open
/// range of positions `p` for which `range.start <= row_indices[p] < range.end`, or `None`
/// when no index falls inside `range` (including when `row_indices` is empty or `range` is
/// empty or inverted).
fn indices_range(range: &Range<u64>, row_indices: &[u64]) -> Option<Range<usize>> {
    // Constant-time rejection when the whole slice lies before or after the range.
    if row_indices.first().is_some_and(|&first| first >= range.end)
        || row_indices.last().is_some_and(|&last| range.start > last)
    {
        return None;
    }

    // Lower bounds for both ends. Unlike `binary_search`, `partition_point` always lands on
    // the first position that is not below the probe, so repeated indices can never cause a
    // row equal to `range.end` to be included.
    let start_idx = row_indices.partition_point(|&row| row < range.start);
    let end_idx = row_indices.partition_point(|&row| row < range.end);

    // `<` rather than `!=` so an inverted range cannot produce a backwards position range.
    (start_idx < end_idx).then_some(start_idx..end_idx)
}
136
#[cfg(test)]
mod tests {
    use vortex_buffer::Buffer;

    use super::Selection;

    /// Builds an include-by-index selection over the given absolute row indices.
    fn include_rows(rows: &[u64]) -> Selection {
        Selection::IncludeByIndex(Buffer::from_iter(rows.to_vec()))
    }

    /// Every selected row lies inside the range, so each one shows up in the mask.
    #[test]
    fn test_row_mask_all() {
        let row_mask = include_rows(&[1, 3, 5, 7]).row_mask(&(1..8));

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]);
    }

    /// Only the selected rows inside the range are kept, relative to the range start.
    #[test]
    fn test_row_mask_slice() {
        let row_mask = include_rows(&[1, 3, 5, 7]).row_mask(&(3..6));

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2]);
    }

    /// The range end is exclusive: a selected row equal to it is not part of the mask.
    #[test]
    fn test_row_mask_exclusive() {
        let row_mask = include_rows(&[1, 3, 5, 7]).row_mask(&(3..5));

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0]);
    }

    /// A range past the last selected row yields an all-false mask.
    #[test]
    fn test_row_mask_all_false() {
        let row_mask = include_rows(&[1, 3, 5, 7]).row_mask(&(8..10));

        assert!(row_mask.mask().all_false());
    }

    /// A range whose rows are all selected yields an all-true mask.
    #[test]
    fn test_row_mask_all_true() {
        let row_mask = include_rows(&[1, 3, 4, 5, 6]).row_mask(&(3..7));

        assert!(row_mask.mask().all_true());
    }

    /// Row zero at the very start of the range is selected.
    #[test]
    fn test_row_mask_zero() {
        let row_mask = include_rows(&[0]).row_mask(&(0..5));

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0]);
    }
}