1#![feature(trusted_len)]
2#![deny(missing_docs)]
4mod bitops;
5mod eq;
6mod intersect_by_rank;
7mod iter_bools;
8
9use std::cmp::Ordering;
10use std::fmt::{Debug, Formatter};
11use std::sync::{Arc, OnceLock};
12
13use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer};
14use itertools::Itertools;
15
/// A three-way result: either one of two degenerate answers (`All` / `None`), or an
/// explicit payload describing a partial answer.
///
/// In this file it is returned by the accessors on `Mask`, where `All` and `None`
/// correspond to the all-true and all-false masks respectively.
pub enum AllOr<T>
{
    /// The degenerate "everything" case; no payload is carried.
    All,
    /// The degenerate "nothing" case; no payload is carried.
    None,
    /// A partial answer described by the wrapped value.
    Some(T),
}
25
26impl<T> AllOr<T> {
27 pub fn unwrap_or_else<F, G>(self, all_true: F, all_false: G) -> T
29 where
30 F: FnOnce() -> T,
31 G: FnOnce() -> T,
32 {
33 match self {
34 Self::Some(v) => v,
35 AllOr::All => all_true(),
36 AllOr::None => all_false(),
37 }
38 }
39}
40
41impl<T> AllOr<&T> {
42 pub fn cloned(self) -> AllOr<T>
44 where
45 T: Clone,
46 {
47 match self {
48 Self::All => AllOr::All,
49 Self::None => AllOr::None,
50 Self::Some(v) => AllOr::Some(v.clone()),
51 }
52 }
53}
54
55impl<T> Debug for AllOr<T>
56where
57 T: Debug,
58{
59 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
60 match self {
61 Self::All => f.write_str("All"),
62 Self::None => f.write_str("None"),
63 Self::Some(v) => f.debug_tuple("Some").field(v).finish(),
64 }
65 }
66}
67
68impl<T> PartialEq for AllOr<T>
69where
70 T: PartialEq,
71{
72 fn eq(&self, other: &Self) -> bool {
73 match (self, other) {
74 (Self::All, Self::All) => true,
75 (Self::None, Self::None) => true,
76 (Self::Some(lhs), Self::Some(rhs)) => lhs == rhs,
77 _ => false,
78 }
79 }
80}
81
82impl<T> Eq for AllOr<T> where T: Eq {}
83
84#[derive(Clone, Debug)]
89pub enum Mask {
90 AllTrue(usize),
92 AllFalse(usize),
94 Values(Arc<MaskValues>),
96}
97
98#[derive(Debug)]
100pub struct MaskValues {
101 buffer: BooleanBuffer,
102
103 indices: OnceLock<Vec<usize>>,
106 slices: OnceLock<Vec<(usize, usize)>>,
107
108 true_count: usize,
110 density: f64,
112}
113
114impl MaskValues {
115 #[inline]
117 #[allow(clippy::len_without_is_empty)]
118 pub fn len(&self) -> usize {
119 self.buffer.len()
120 }
121
122 pub fn true_count(&self) -> usize {
124 self.true_count
125 }
126
127 pub fn boolean_buffer(&self) -> &BooleanBuffer {
129 &self.buffer
130 }
131
132 pub fn value(&self, index: usize) -> bool {
134 self.buffer.value(index)
135 }
136
137 pub fn indices(&self) -> &[usize] {
139 self.indices.get_or_init(|| {
140 if self.true_count == 0 {
141 return vec![];
142 }
143
144 if self.true_count == self.len() {
145 return (0..self.len()).collect();
146 }
147
148 if let Some(slices) = self.slices.get() {
149 let mut indices = Vec::with_capacity(self.true_count);
150 indices.extend(slices.iter().flat_map(|(start, end)| *start..*end));
151 debug_assert!(indices.is_sorted());
152 assert_eq!(indices.len(), self.true_count);
153 return indices;
154 }
155
156 let mut indices = Vec::with_capacity(self.true_count);
157 indices.extend(self.buffer.set_indices());
158 debug_assert!(indices.is_sorted());
159 assert_eq!(indices.len(), self.true_count);
160 indices
161 })
162 }
163
164 #[allow(clippy::cast_possible_truncation)]
166 pub fn slices(&self) -> &[(usize, usize)] {
167 self.slices.get_or_init(|| {
168 if self.true_count == self.len() {
169 return vec![(0, self.len())];
170 }
171
172 self.buffer.set_slices().collect()
173 })
174 }
175
176 pub fn threshold_iter(&self, threshold: f64) -> MaskIter {
178 if self.density >= threshold {
179 MaskIter::Slices(self.slices())
180 } else {
181 MaskIter::Indices(self.indices())
182 }
183 }
184}
185
186impl Mask {
187 pub fn new_true(length: usize) -> Self {
189 Self::AllTrue(length)
190 }
191
192 pub fn new_false(length: usize) -> Self {
194 Self::AllFalse(length)
195 }
196
197 pub fn from_buffer(buffer: BooleanBuffer) -> Self {
199 let len = buffer.len();
200 let true_count = buffer.count_set_bits();
201
202 if true_count == 0 {
203 return Self::AllFalse(len);
204 }
205 if true_count == len {
206 return Self::AllTrue(len);
207 }
208
209 Self::Values(Arc::new(MaskValues {
210 buffer,
211 indices: Default::default(),
212 slices: Default::default(),
213 true_count,
214 density: true_count as f64 / len as f64,
215 }))
216 }
217
218 pub fn from_indices(len: usize, indices: Vec<usize>) -> Self {
220 let true_count = indices.len();
221 assert!(indices.is_sorted(), "Mask indices must be sorted");
222 assert!(
223 indices.last().is_none_or(|&idx| idx < len),
224 "Mask indices must be in bounds (len={len})"
225 );
226
227 if true_count == 0 {
228 return Self::AllFalse(len);
229 }
230 if true_count == len {
231 return Self::AllTrue(len);
232 }
233
234 let mut buf = BooleanBufferBuilder::new(len);
235 buf.append_n(len, false);
237 indices.iter().for_each(|idx| buf.set_bit(*idx, true));
238 debug_assert_eq!(buf.len(), len);
239
240 Self::Values(Arc::new(MaskValues {
241 buffer: buf.finish(),
242 indices: OnceLock::from(indices),
243 slices: Default::default(),
244 true_count,
245 density: true_count as f64 / len as f64,
246 }))
247 }
248
249 pub fn from_slices(len: usize, vec: Vec<(usize, usize)>) -> Self {
252 Self::check_slices(len, &vec);
253 Self::from_slices_unchecked(len, vec)
254 }
255
256 fn from_slices_unchecked(len: usize, slices: Vec<(usize, usize)>) -> Self {
257 #[cfg(debug_assertions)]
258 Self::check_slices(len, &slices);
259
260 let true_count = slices.iter().map(|(b, e)| e - b).sum();
261 if true_count == 0 {
262 return Self::AllFalse(len);
263 }
264 if true_count == len {
265 return Self::AllTrue(len);
266 }
267
268 let mut buf = BooleanBufferBuilder::new(len);
269 for (start, end) in slices.iter().copied() {
270 buf.append_n(start - buf.len(), false);
271 buf.append_n(end - start, true);
272 }
273 if let Some((_, end)) = slices.last() {
274 buf.append_n(len - end, false);
275 }
276 debug_assert_eq!(buf.len(), len);
277
278 Self::Values(Arc::new(MaskValues {
279 buffer: buf.finish(),
280 indices: Default::default(),
281 slices: OnceLock::from(slices),
282 true_count,
283 density: true_count as f64 / len as f64,
284 }))
285 }
286
287 #[inline(always)]
288 fn check_slices(len: usize, vec: &[(usize, usize)]) {
289 assert!(vec.iter().all(|&(b, e)| b < e && e <= len));
290 for (first, second) in vec.iter().tuple_windows() {
291 assert!(
292 first.0 < second.0,
293 "Slices must be sorted, got {:?} and {:?}",
294 first,
295 second
296 );
297 assert!(
298 first.1 <= second.0,
299 "Slices must be non-overlapping, got {:?} and {:?}",
300 first,
301 second
302 );
303 }
304 }
305
306 pub fn from_intersection_indices(
308 len: usize,
309 lhs: impl Iterator<Item = usize>,
310 rhs: impl Iterator<Item = usize>,
311 ) -> Self {
312 let mut intersection = Vec::with_capacity(len);
313 let mut lhs = lhs.peekable();
314 let mut rhs = rhs.peekable();
315 while let (Some(&l), Some(&r)) = (lhs.peek(), rhs.peek()) {
316 match l.cmp(&r) {
317 Ordering::Less => {
318 lhs.next();
319 }
320 Ordering::Greater => {
321 rhs.next();
322 }
323 Ordering::Equal => {
324 intersection.push(l);
325 lhs.next();
326 rhs.next();
327 }
328 }
329 }
330 Self::from_indices(len, intersection)
331 }
332
333 #[inline]
335 #[allow(clippy::len_without_is_empty)]
337 pub fn len(&self) -> usize {
338 match &self {
339 Self::AllTrue(len) => *len,
340 Self::AllFalse(len) => *len,
341 Self::Values(values) => values.buffer.len(),
342 }
343 }
344
345 #[inline]
347 pub fn true_count(&self) -> usize {
348 match &self {
349 Self::AllTrue(len) => *len,
350 Self::AllFalse(_) => 0,
351 Self::Values(values) => values.true_count,
352 }
353 }
354
355 #[inline]
357 pub fn false_count(&self) -> usize {
358 match &self {
359 Self::AllTrue(_) => 0,
360 Self::AllFalse(len) => *len,
361 Self::Values(values) => values.buffer.len() - values.true_count,
362 }
363 }
364
365 #[inline]
367 pub fn all_true(&self) -> bool {
368 match &self {
369 Self::AllTrue(_) => true,
370 Self::AllFalse(_) => false,
371 Self::Values(values) => values.buffer.len() == values.true_count,
372 }
373 }
374
375 #[inline]
377 pub fn all_false(&self) -> bool {
378 self.true_count() == 0
379 }
380
381 #[inline]
383 pub fn density(&self) -> f64 {
384 match &self {
385 Self::AllTrue(_) => 1.0,
386 Self::AllFalse(_) => 0.0,
387 Self::Values(values) => values.density,
388 }
389 }
390
391 pub fn value(&self, idx: usize) -> bool {
397 match self {
398 Mask::AllTrue(_) => true,
399 Mask::AllFalse(_) => false,
400 Mask::Values(values) => values.buffer.value(idx),
401 }
402 }
403
404 pub fn first(&self) -> Option<usize> {
406 match &self {
407 Self::AllTrue(len) => (*len > 0).then_some(0),
408 Self::AllFalse(_) => None,
409 Self::Values(values) => {
410 if let Some(indices) = values.indices.get() {
411 return indices.first().copied();
412 }
413 if let Some(slices) = values.slices.get() {
414 return slices.first().map(|(start, _)| *start);
415 }
416 values.buffer.set_indices().next()
417 }
418 }
419 }
420
421 pub fn slice(&self, offset: usize, length: usize) -> Self {
423 assert!(offset + length <= self.len());
424 match &self {
425 Self::AllTrue(_) => Self::new_true(length),
426 Self::AllFalse(_) => Self::new_false(length),
427 Self::Values(values) => Self::from_buffer(values.buffer.slice(offset, length)),
428 }
429 }
430
431 pub fn boolean_buffer(&self) -> AllOr<&BooleanBuffer> {
433 match &self {
434 Self::AllTrue(_) => AllOr::All,
435 Self::AllFalse(_) => AllOr::None,
436 Self::Values(values) => AllOr::Some(&values.buffer),
437 }
438 }
439
440 pub fn to_boolean_buffer(&self) -> BooleanBuffer {
443 match self {
444 Self::AllTrue(l) => BooleanBuffer::new_set(*l),
445 Self::AllFalse(l) => BooleanBuffer::new_unset(*l),
446 Self::Values(values) => values.boolean_buffer().clone(),
447 }
448 }
449
450 pub fn to_null_buffer(&self) -> Option<NullBuffer> {
452 match self {
453 Mask::AllTrue(_) => None,
454 Mask::AllFalse(l) => Some(NullBuffer::new_null(*l)),
455 Mask::Values(values) => Some(NullBuffer::from(values.buffer.clone())),
456 }
457 }
458
459 pub fn indices(&self) -> AllOr<&[usize]> {
461 match &self {
462 Self::AllTrue(_) => AllOr::All,
463 Self::AllFalse(_) => AllOr::None,
464 Self::Values(values) => AllOr::Some(values.indices()),
465 }
466 }
467
468 pub fn slices(&self) -> AllOr<&[(usize, usize)]> {
470 match &self {
471 Self::AllTrue(_) => AllOr::All,
472 Self::AllFalse(_) => AllOr::None,
473 Self::Values(values) => AllOr::Some(values.slices()),
474 }
475 }
476
477 pub fn threshold_iter(&self, threshold: f64) -> AllOr<MaskIter> {
479 match &self {
480 Self::AllTrue(_) => AllOr::All,
481 Self::AllFalse(_) => AllOr::None,
482 Self::Values(values) => AllOr::Some(values.threshold_iter(threshold)),
483 }
484 }
485
486 pub fn values(&self) -> Option<&MaskValues> {
488 match self {
489 Self::Values(values) => Some(values),
490 _ => None,
491 }
492 }
493}
494
/// A borrowed view over the `true` positions of a mask, in one of two shapes.
///
/// Produced by `threshold_iter`, which selects the shape from the mask's density.
pub enum MaskIter<'a>
{
    /// The individual positions that are `true`.
    Indices(&'a [usize]),
    /// Runs of `true` positions as `(start, end)` pairs, `end` exclusive.
    Slices(&'a [(usize, usize)]),
}
502
503impl From<BooleanBuffer> for Mask {
504 fn from(value: BooleanBuffer) -> Self {
505 Self::from_buffer(value)
506 }
507}
508
509impl FromIterator<bool> for Mask {
510 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
511 Self::from_buffer(BooleanBuffer::from_iter(iter))
512 }
513}
514
#[cfg(test)]
mod test {
    use super::*;

    /// An all-true mask reports full density and the `All` marker from every view.
    #[test]
    fn mask_all_true() {
        let all = Mask::new_true(5);

        assert_eq!(all.len(), 5);
        assert_eq!(all.true_count(), 5);
        assert_eq!(all.density(), 1.0);

        assert_eq!(all.indices(), AllOr::All);
        assert_eq!(all.slices(), AllOr::All);
        assert_eq!(all.boolean_buffer(), AllOr::All);
    }

    /// An all-false mask reports zero density and the `None` marker from every view.
    #[test]
    fn mask_all_false() {
        let none = Mask::new_false(5);

        assert_eq!(none.len(), 5);
        assert_eq!(none.true_count(), 0);
        assert_eq!(none.density(), 0.0);

        assert_eq!(none.indices(), AllOr::None);
        assert_eq!(none.slices(), AllOr::None);
        assert_eq!(none.boolean_buffer(), AllOr::None);
    }

    /// The three constructors agree on every view of the same mixed mask.
    #[test]
    fn mask_from() {
        let expected_bits = BooleanBuffer::from_iter([true, false, true, true, false]);
        let expected_indices: &[usize] = &[0, 2, 3];
        let expected_slices: &[(usize, usize)] = &[(0, 1), (2, 4)];

        let candidates = [
            Mask::from_indices(5, vec![0, 2, 3]),
            Mask::from_slices(5, vec![(0, 1), (2, 4)]),
            Mask::from_buffer(expected_bits.clone()),
        ];

        for candidate in candidates.iter() {
            assert_eq!(candidate.len(), 5);
            assert_eq!(candidate.true_count(), 3);
            assert_eq!(candidate.density(), 0.6);
            assert_eq!(candidate.indices(), AllOr::Some(expected_indices));
            assert_eq!(candidate.slices(), AllOr::Some(expected_slices));
            assert_eq!(candidate.boolean_buffer(), AllOr::Some(&expected_bits));
        }
    }
}