1use std::ops::{Not, Range};
5
6use vortex_buffer::Buffer;
7use vortex_error::VortexExpect;
8use vortex_mask::Mask;
9
10use crate::row_mask::RowMask;
11
/// A set of rows to select, expressed as absolute row indices.
///
/// A selection is turned into a per-range [`RowMask`] by `Selection::row_mask`, where each
/// index is translated into an offset relative to the start of the requested range.
#[derive(Default, Clone)]
pub enum Selection {
    /// Select every row. This is the default selection.
    #[default]
    All,
    /// Select only the rows at these absolute indices.
    ///
    /// The indices are located with a binary search, so they are expected to be sorted in
    /// ascending order.
    IncludeByIndex(Buffer<u64>),
    /// Select every row except those at these absolute indices.
    ///
    /// The indices are located with a binary search, so they are expected to be sorted in
    /// ascending order.
    ExcludeByIndex(Buffer<u64>),
    /// Select only the rows whose absolute indices are present in the bitmap.
    #[cfg(feature = "roaring")]
    IncludeRoaring(roaring::RoaringTreemap),
    /// Select every row except those whose absolute indices are present in the bitmap.
    #[cfg(feature = "roaring")]
    ExcludeRoaring(roaring::RoaringTreemap),
}
30
31impl Selection {
32 pub(crate) fn row_mask(&self, range: &Range<u64>) -> RowMask {
34 let range_len = usize::try_from(range.end - range.start)
35 .vortex_expect("Range length does not fit into a usize");
36
37 match self {
38 Selection::All => RowMask::new(range.start, Mask::new_true(range_len)),
39 Selection::IncludeByIndex(include) => {
40 let mask = indices_range(range, include)
41 .map(|idx_range| {
42 Mask::from_indices(
43 range_len,
44 include
45 .slice(idx_range)
46 .iter()
47 .map(|idx| *idx - range.start)
48 .map(|idx| {
49 usize::try_from(idx)
50 .vortex_expect("Index does not fit into a usize")
51 })
52 .collect(),
53 )
54 })
55 .unwrap_or_else(|| Mask::new_false(range_len));
56
57 RowMask::new(range.start, mask)
58 }
59 Selection::ExcludeByIndex(exclude) => {
60 let mask = Selection::IncludeByIndex(exclude.clone())
61 .row_mask(range)
62 .mask()
63 .clone();
64 RowMask::new(range.start, mask.not())
65 }
66 #[cfg(feature = "roaring")]
67 Selection::IncludeRoaring(roaring) => {
68 use std::ops::BitAnd;
69
70 let mut range_treemap = roaring::RoaringTreemap::new();
72 range_treemap.insert_range(range.clone());
73
74 if roaring.is_disjoint(&range_treemap) {
75 return RowMask::new(range.start, Mask::new_false(range_len));
76 }
77
78 let roaring = roaring.bitand(range_treemap);
80 let mask = Mask::from_indices(
81 range_len,
82 roaring
83 .iter()
84 .map(|idx| idx - range.start)
85 .map(|idx| {
86 usize::try_from(idx).vortex_expect("Index does not fit into a usize")
87 })
88 .collect(),
89 );
90
91 RowMask::new(range.start, mask)
92 }
93 #[cfg(feature = "roaring")]
94 Selection::ExcludeRoaring(roaring) => {
95 use std::ops::BitAnd;
96
97 let mut range_treemap = roaring::RoaringTreemap::new();
98 range_treemap.insert_range(range.clone());
99
100 if roaring.intersection_len(&range_treemap) == range_len as u64 {
102 return RowMask::new(range.start, Mask::new_true(range_len));
103 }
104
105 let roaring = roaring.bitand(range_treemap);
107 let mask = Mask::from_excluded_indices(
108 range_len,
109 roaring.iter().map(|idx| idx - range.start).map(|idx| {
110 usize::try_from(idx).vortex_expect("Index does not fit into a usize")
111 }),
112 );
113
114 RowMask::new(range.start, mask)
115 }
116 }
117 }
118}
119
/// Returns the positions within `row_indices` of the entries that fall inside `range`, or
/// `None` when no entry does.
///
/// `row_indices` must be sorted in ascending order. The returned range `start..end` covers
/// exactly the entries `x` with `range.start <= x < range.end`.
fn indices_range(range: &Range<u64>, row_indices: &[u64]) -> Option<Range<usize>> {
    // `partition_point` yields the exact lower bound even when values repeat. `binary_search`
    // may return any of several equal elements, which for duplicates of `range.end` could pull
    // out-of-range indices into the returned slice.
    let start_idx = row_indices.partition_point(|&idx| idx < range.start);
    let end_idx = row_indices.partition_point(|&idx| idx < range.end);

    // `<` (rather than `!=`) guarantees the result is never an inverted range.
    (start_idx < end_idx).then_some(start_idx..end_idx)
}
136
#[cfg(test)]
mod tests {
    use vortex_buffer::Buffer;

    use super::Selection;

    #[test]
    fn test_row_mask_all() {
        let selection = Selection::IncludeByIndex(Buffer::from_iter(vec![1, 3, 5, 7]));
        let range = 1..8;
        let row_mask = selection.row_mask(&range);

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]);
    }

    #[test]
    fn test_row_mask_slice() {
        let selection = Selection::IncludeByIndex(Buffer::from_iter(vec![1, 3, 5, 7]));
        let range = 3..6;
        let row_mask = selection.row_mask(&range);

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2]);
    }

    #[test]
    fn test_row_mask_exclusive() {
        let selection = Selection::IncludeByIndex(Buffer::from_iter(vec![1, 3, 5, 7]));
        let range = 3..5;
        let row_mask = selection.row_mask(&range);

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0]);
    }

    #[test]
    fn test_row_mask_all_false() {
        let selection = Selection::IncludeByIndex(Buffer::from_iter(vec![1, 3, 5, 7]));
        let range = 8..10;
        let row_mask = selection.row_mask(&range);

        assert!(row_mask.mask().all_false());
    }

    #[test]
    fn test_row_mask_all_true() {
        let selection = Selection::IncludeByIndex(Buffer::from_iter(vec![1, 3, 4, 5, 6]));
        let range = 3..7;
        let row_mask = selection.row_mask(&range);

        assert!(row_mask.mask().all_true());
    }

    #[test]
    fn test_row_mask_zero() {
        let selection = Selection::IncludeByIndex(Buffer::from_iter(vec![0]));
        let range = 0..5;
        let row_mask = selection.row_mask(&range);

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0]);
    }

    #[test]
    fn test_row_mask_select_everything() {
        let row_mask = Selection::All.row_mask(&(2..6));

        assert!(row_mask.mask().all_true());
    }

    #[test]
    fn test_row_mask_empty_include() {
        let selection = Selection::IncludeByIndex(Buffer::from_iter(Vec::<u64>::new()));
        let row_mask = selection.row_mask(&(0..5));

        assert!(row_mask.mask().all_false());
    }

    #[test]
    fn test_row_mask_exclude_by_index() {
        // Rows 3 and 5 are excluded from 3..6, leaving only row 4 (offset 1).
        let selection = Selection::ExcludeByIndex(Buffer::from_iter(vec![1, 3, 5, 7]));
        let row_mask = selection.row_mask(&(3..6));

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[1]);
    }

    #[test]
    fn test_row_mask_exclude_by_index_everything() {
        let selection = Selection::ExcludeByIndex(Buffer::from_iter(vec![3, 4, 5]));
        let row_mask = selection.row_mask(&(3..6));

        assert!(row_mask.mask().all_false());
    }

    #[test]
    fn test_row_mask_exclude_by_index_nothing() {
        let selection = Selection::ExcludeByIndex(Buffer::from_iter(vec![1, 7]));
        let row_mask = selection.row_mask(&(3..6));

        assert!(row_mask.mask().all_true());
    }

    #[cfg(feature = "roaring")]
    #[test]
    fn test_row_mask_include_roaring() {
        let selection =
            Selection::IncludeRoaring(roaring::RoaringTreemap::from_iter([1u64, 3, 5, 7]));
        let row_mask = selection.row_mask(&(3..6));

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2]);
    }

    #[cfg(feature = "roaring")]
    #[test]
    fn test_row_mask_exclude_roaring_partial() {
        let selection =
            Selection::ExcludeRoaring(roaring::RoaringTreemap::from_iter([1u64, 3, 5, 7]));
        let row_mask = selection.row_mask(&(3..6));

        assert_eq!(row_mask.mask().values().unwrap().indices(), &[1]);
    }

    #[cfg(feature = "roaring")]
    #[test]
    fn test_row_mask_exclude_roaring_disjoint() {
        let selection = Selection::ExcludeRoaring(roaring::RoaringTreemap::from_iter([1u64, 7]));
        let row_mask = selection.row_mask(&(3..6));

        assert!(row_mask.mask().all_true());
    }
}