1#![feature(trusted_len)]
2#![deny(missing_docs)]
4mod bitops;
5mod eq;
6mod intersect_by_rank;
7mod iter_bools;
8
9use std::cmp::Ordering;
10use std::fmt::{Debug, Formatter};
11use std::sync::{Arc, OnceLock};
12
13use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer};
14use itertools::Itertools;
15
/// A value that is either "all", "none", or a specific payload in between.
///
/// [`Mask`] accessors return this so callers can take a fast path when the mask is
/// uniformly true ([`AllOr::All`]) or uniformly false ([`AllOr::None`]) and only touch
/// materialized data in the [`AllOr::Some`] case.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AllOr<T> {
    /// Every element is selected.
    All,
    /// No element is selected.
    None,
    /// Some elements are selected; the payload describes which.
    Some(T),
}

impl<T> AllOr<T> {
    /// Extracts the payload, computing a fallback for the uniform cases.
    ///
    /// `all_true` is called only for [`AllOr::All`] and `all_false` only for
    /// [`AllOr::None`]; for [`AllOr::Some`] the payload is returned and neither closure runs.
    pub fn unwrap_or_else<F, G>(self, all_true: F, all_false: G) -> T
    where
        F: FnOnce() -> T,
        G: FnOnce() -> T,
    {
        match self {
            Self::All => all_true(),
            Self::None => all_false(),
            Self::Some(v) => v,
        }
    }
}

impl<T> AllOr<&T> {
    /// Converts an `AllOr<&T>` into an owned `AllOr<T>` by cloning the payload, if any.
    pub fn cloned(self) -> AllOr<T>
    where
        T: Clone,
    {
        match self {
            Self::All => AllOr::All,
            Self::None => AllOr::None,
            Self::Some(v) => AllOr::Some(v.clone()),
        }
    }
}
83
84#[derive(Clone, Debug)]
89pub enum Mask {
90 AllTrue(usize),
92 AllFalse(usize),
94 Values(Arc<MaskValues>),
96}
97
98#[derive(Debug)]
100pub struct MaskValues {
101 buffer: BooleanBuffer,
102
103 indices: OnceLock<Vec<usize>>,
106 slices: OnceLock<Vec<(usize, usize)>>,
107
108 true_count: usize,
110 density: f64,
112}
113
114impl MaskValues {
115 #[inline]
117 #[allow(clippy::len_without_is_empty)]
118 pub fn len(&self) -> usize {
119 self.buffer.len()
120 }
121
122 pub fn true_count(&self) -> usize {
124 self.true_count
125 }
126
127 pub fn boolean_buffer(&self) -> &BooleanBuffer {
129 &self.buffer
130 }
131
132 pub fn value(&self, index: usize) -> bool {
134 self.buffer.value(index)
135 }
136
137 pub fn indices(&self) -> &[usize] {
139 self.indices.get_or_init(|| {
140 if self.true_count == 0 {
141 return vec![];
142 }
143
144 if self.true_count == self.len() {
145 return (0..self.len()).collect();
146 }
147
148 if let Some(slices) = self.slices.get() {
149 let mut indices = Vec::with_capacity(self.true_count);
150 indices.extend(slices.iter().flat_map(|(start, end)| *start..*end));
151 debug_assert!(indices.is_sorted());
152 assert_eq!(indices.len(), self.true_count);
153 return indices;
154 }
155
156 let mut indices = Vec::with_capacity(self.true_count);
157 indices.extend(self.buffer.set_indices());
158 debug_assert!(indices.is_sorted());
159 assert_eq!(indices.len(), self.true_count);
160 indices
161 })
162 }
163
164 #[allow(clippy::cast_possible_truncation)]
166 pub fn slices(&self) -> &[(usize, usize)] {
167 self.slices.get_or_init(|| {
168 if self.true_count == self.len() {
169 return vec![(0, self.len())];
170 }
171
172 self.buffer.set_slices().collect()
173 })
174 }
175
176 pub fn threshold_iter(&self, threshold: f64) -> MaskIter {
178 if self.density >= threshold {
179 MaskIter::Slices(self.slices())
180 } else {
181 MaskIter::Indices(self.indices())
182 }
183 }
184}
185
186impl Mask {
187 pub fn new_true(length: usize) -> Self {
189 Self::AllTrue(length)
190 }
191
192 pub fn new_false(length: usize) -> Self {
194 Self::AllFalse(length)
195 }
196
197 pub fn from_buffer(buffer: BooleanBuffer) -> Self {
199 let len = buffer.len();
200 let true_count = buffer.count_set_bits();
201
202 if true_count == 0 {
203 return Self::AllFalse(len);
204 }
205 if true_count == len {
206 return Self::AllTrue(len);
207 }
208
209 Self::Values(Arc::new(MaskValues {
210 buffer,
211 indices: Default::default(),
212 slices: Default::default(),
213 true_count,
214 density: true_count as f64 / len as f64,
215 }))
216 }
217
218 pub fn from_indices(len: usize, indices: Vec<usize>) -> Self {
221 let true_count = indices.len();
222 assert!(indices.is_sorted(), "Mask indices must be sorted");
223 assert!(
224 indices.last().is_none_or(|&idx| idx < len),
225 "Mask indices must be in bounds (len={len})"
226 );
227
228 if true_count == 0 {
229 return Self::AllFalse(len);
230 }
231 if true_count == len {
232 return Self::AllTrue(len);
233 }
234
235 let mut buf = BooleanBufferBuilder::new(len);
236 buf.append_n(len, false);
238 indices.iter().for_each(|idx| buf.set_bit(*idx, true));
239 debug_assert_eq!(buf.len(), len);
240
241 Self::Values(Arc::new(MaskValues {
242 buffer: buf.finish(),
243 indices: OnceLock::from(indices),
244 slices: Default::default(),
245 true_count,
246 density: true_count as f64 / len as f64,
247 }))
248 }
249
250 pub fn from_excluded_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> Self {
252 let mut buf = BooleanBufferBuilder::new(len);
253 buf.append_n(len, true);
254
255 let mut false_count: usize = 0;
256 indices.into_iter().for_each(|idx| {
257 buf.set_bit(idx, false);
258 false_count += 1;
259 });
260 debug_assert_eq!(buf.len(), len);
261 let true_count = len - false_count;
262
263 Self::Values(Arc::new(MaskValues {
264 buffer: buf.finish(),
265 indices: Default::default(),
266 slices: Default::default(),
267 true_count,
268 density: true_count as f64 / len as f64,
269 }))
270 }
271
272 pub fn from_slices(len: usize, vec: Vec<(usize, usize)>) -> Self {
275 Self::check_slices(len, &vec);
276 Self::from_slices_unchecked(len, vec)
277 }
278
279 fn from_slices_unchecked(len: usize, slices: Vec<(usize, usize)>) -> Self {
280 #[cfg(debug_assertions)]
281 Self::check_slices(len, &slices);
282
283 let true_count = slices.iter().map(|(b, e)| e - b).sum();
284 if true_count == 0 {
285 return Self::AllFalse(len);
286 }
287 if true_count == len {
288 return Self::AllTrue(len);
289 }
290
291 let mut buf = BooleanBufferBuilder::new(len);
292 for (start, end) in slices.iter().copied() {
293 buf.append_n(start - buf.len(), false);
294 buf.append_n(end - start, true);
295 }
296 if let Some((_, end)) = slices.last() {
297 buf.append_n(len - end, false);
298 }
299 debug_assert_eq!(buf.len(), len);
300
301 Self::Values(Arc::new(MaskValues {
302 buffer: buf.finish(),
303 indices: Default::default(),
304 slices: OnceLock::from(slices),
305 true_count,
306 density: true_count as f64 / len as f64,
307 }))
308 }
309
310 #[inline(always)]
311 fn check_slices(len: usize, vec: &[(usize, usize)]) {
312 assert!(vec.iter().all(|&(b, e)| b < e && e <= len));
313 for (first, second) in vec.iter().tuple_windows() {
314 assert!(
315 first.0 < second.0,
316 "Slices must be sorted, got {first:?} and {second:?}"
317 );
318 assert!(
319 first.1 <= second.0,
320 "Slices must be non-overlapping, got {first:?} and {second:?}"
321 );
322 }
323 }
324
325 pub fn from_intersection_indices(
327 len: usize,
328 lhs: impl Iterator<Item = usize>,
329 rhs: impl Iterator<Item = usize>,
330 ) -> Self {
331 let mut intersection = Vec::with_capacity(len);
332 let mut lhs = lhs.peekable();
333 let mut rhs = rhs.peekable();
334 while let (Some(&l), Some(&r)) = (lhs.peek(), rhs.peek()) {
335 match l.cmp(&r) {
336 Ordering::Less => {
337 lhs.next();
338 }
339 Ordering::Greater => {
340 rhs.next();
341 }
342 Ordering::Equal => {
343 intersection.push(l);
344 lhs.next();
345 rhs.next();
346 }
347 }
348 }
349 Self::from_indices(len, intersection)
350 }
351
352 #[inline]
354 #[allow(clippy::len_without_is_empty)]
356 pub fn len(&self) -> usize {
357 match &self {
358 Self::AllTrue(len) => *len,
359 Self::AllFalse(len) => *len,
360 Self::Values(values) => values.buffer.len(),
361 }
362 }
363
364 #[inline]
366 pub fn true_count(&self) -> usize {
367 match &self {
368 Self::AllTrue(len) => *len,
369 Self::AllFalse(_) => 0,
370 Self::Values(values) => values.true_count,
371 }
372 }
373
374 #[inline]
376 pub fn false_count(&self) -> usize {
377 match &self {
378 Self::AllTrue(_) => 0,
379 Self::AllFalse(len) => *len,
380 Self::Values(values) => values.buffer.len() - values.true_count,
381 }
382 }
383
384 #[inline]
386 pub fn all_true(&self) -> bool {
387 match &self {
388 Self::AllTrue(_) => true,
389 Self::AllFalse(_) => false,
390 Self::Values(values) => values.buffer.len() == values.true_count,
391 }
392 }
393
394 #[inline]
396 pub fn all_false(&self) -> bool {
397 self.true_count() == 0
398 }
399
400 #[inline]
402 pub fn density(&self) -> f64 {
403 match &self {
404 Self::AllTrue(_) => 1.0,
405 Self::AllFalse(_) => 0.0,
406 Self::Values(values) => values.density,
407 }
408 }
409
410 pub fn value(&self, idx: usize) -> bool {
416 match self {
417 Mask::AllTrue(_) => true,
418 Mask::AllFalse(_) => false,
419 Mask::Values(values) => values.buffer.value(idx),
420 }
421 }
422
423 pub fn first(&self) -> Option<usize> {
425 match &self {
426 Self::AllTrue(len) => (*len > 0).then_some(0),
427 Self::AllFalse(_) => None,
428 Self::Values(values) => {
429 if let Some(indices) = values.indices.get() {
430 return indices.first().copied();
431 }
432 if let Some(slices) = values.slices.get() {
433 return slices.first().map(|(start, _)| *start);
434 }
435 values.buffer.set_indices().next()
436 }
437 }
438 }
439
440 pub fn slice(&self, offset: usize, length: usize) -> Self {
442 assert!(offset + length <= self.len());
443 match &self {
444 Self::AllTrue(_) => Self::new_true(length),
445 Self::AllFalse(_) => Self::new_false(length),
446 Self::Values(values) => Self::from_buffer(values.buffer.slice(offset, length)),
447 }
448 }
449
450 pub fn boolean_buffer(&self) -> AllOr<&BooleanBuffer> {
452 match &self {
453 Self::AllTrue(_) => AllOr::All,
454 Self::AllFalse(_) => AllOr::None,
455 Self::Values(values) => AllOr::Some(&values.buffer),
456 }
457 }
458
459 pub fn to_boolean_buffer(&self) -> BooleanBuffer {
462 match self {
463 Self::AllTrue(l) => BooleanBuffer::new_set(*l),
464 Self::AllFalse(l) => BooleanBuffer::new_unset(*l),
465 Self::Values(values) => values.boolean_buffer().clone(),
466 }
467 }
468
469 pub fn to_null_buffer(&self) -> Option<NullBuffer> {
471 match self {
472 Mask::AllTrue(_) => None,
473 Mask::AllFalse(l) => Some(NullBuffer::new_null(*l)),
474 Mask::Values(values) => Some(NullBuffer::from(values.buffer.clone())),
475 }
476 }
477
478 pub fn indices(&self) -> AllOr<&[usize]> {
480 match &self {
481 Self::AllTrue(_) => AllOr::All,
482 Self::AllFalse(_) => AllOr::None,
483 Self::Values(values) => AllOr::Some(values.indices()),
484 }
485 }
486
487 pub fn slices(&self) -> AllOr<&[(usize, usize)]> {
489 match &self {
490 Self::AllTrue(_) => AllOr::All,
491 Self::AllFalse(_) => AllOr::None,
492 Self::Values(values) => AllOr::Some(values.slices()),
493 }
494 }
495
496 pub fn threshold_iter(&self, threshold: f64) -> AllOr<MaskIter> {
498 match &self {
499 Self::AllTrue(_) => AllOr::All,
500 Self::AllFalse(_) => AllOr::None,
501 Self::Values(values) => AllOr::Some(values.threshold_iter(threshold)),
502 }
503 }
504
505 pub fn values(&self) -> Option<&MaskValues> {
507 match self {
508 Self::Values(values) => Some(values),
509 _ => None,
510 }
511 }
512}
513
/// A borrowed view of the selected positions of a mask in one of two representations.
#[derive(Clone, Copy, Debug)]
pub enum MaskIter<'a> {
    /// Sorted positions of the selected entries.
    Indices(&'a [usize]),
    /// Sorted, non-overlapping half-open `(start, end)` runs of selected entries.
    Slices(&'a [(usize, usize)]),
}
521
522impl From<BooleanBuffer> for Mask {
523 fn from(value: BooleanBuffer) -> Self {
524 Self::from_buffer(value)
525 }
526}
527
528impl FromIterator<bool> for Mask {
529 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
530 Self::from_buffer(BooleanBuffer::from_iter(iter))
531 }
532}
533
534impl FromIterator<Mask> for Mask {
535 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
536 let masks = iter.into_iter().collect::<Vec<_>>();
537 let total_length = masks.iter().map(|v| v.len()).sum();
538
539 if masks.iter().all(|v| v.all_true()) {
541 return Self::AllTrue(total_length);
542 }
543 if masks.iter().all(|v| v.all_false()) {
545 return Self::AllFalse(total_length);
546 }
547
548 let mut buffer = BooleanBufferBuilder::new(total_length);
550 for mask in masks {
551 match mask {
552 Mask::AllTrue(count) => buffer.append_n(count, true),
553 Mask::AllFalse(count) => buffer.append_n(count, false),
554 Mask::Values(values) => {
555 buffer.append_buffer(values.boolean_buffer());
556 }
557 };
558 }
559 Self::from_buffer(buffer.finish())
560 }
561}
562
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn mask_all_true() {
        let all = Mask::new_true(5);
        assert_eq!(all.len(), 5);
        assert_eq!(all.true_count(), 5);
        assert_eq!(all.density(), 1.0);
        assert_eq!(all.indices(), AllOr::All);
        assert_eq!(all.slices(), AllOr::All);
        assert_eq!(all.boolean_buffer(), AllOr::All);
    }

    #[test]
    fn mask_all_false() {
        let none = Mask::new_false(5);
        assert_eq!(none.len(), 5);
        assert_eq!(none.true_count(), 0);
        assert_eq!(none.density(), 0.0);
        assert_eq!(none.indices(), AllOr::None);
        assert_eq!(none.slices(), AllOr::None);
        assert_eq!(none.boolean_buffer(), AllOr::None);
    }

    /// Each constructor describes the same pattern `[T, F, T, T, F]` and must agree on
    /// every derived view.
    #[test]
    fn mask_from() {
        let expected_buffer = BooleanBuffer::from_iter([true, false, true, true, false]);
        let candidates = [
            Mask::from_indices(5, vec![0, 2, 3]),
            Mask::from_slices(5, vec![(0, 1), (2, 4)]),
            Mask::from_buffer(expected_buffer.clone()),
        ];

        for mask in candidates.iter() {
            assert_eq!(mask.len(), 5);
            assert_eq!(mask.true_count(), 3);
            assert_eq!(mask.density(), 0.6);
            assert_eq!(mask.indices(), AllOr::Some(&[0, 2, 3][..]));
            assert_eq!(mask.slices(), AllOr::Some(&[(0, 1), (2, 4)][..]));
            assert_eq!(mask.boolean_buffer(), AllOr::Some(&expected_buffer));
        }
    }
}