1#![feature(trusted_len)]
2#![deny(missing_docs)]
4mod bitops;
5mod eq;
6mod intersect_by_rank;
7mod iter_bools;
8
9use std::cmp::Ordering;
10use std::fmt::{Debug, Formatter};
11use std::sync::{Arc, OnceLock};
12
13use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer};
14use itertools::Itertools;
15
/// Three-way result returned by [`Mask`] accessors: every position is selected, no position
/// is selected, or an explicit value of type `T` describes the selection.
pub enum AllOr<T> {
    /// Every position is selected; no explicit value is carried.
    All,
    /// No position is selected; no explicit value is carried.
    None,
    /// An explicit value describing the selection.
    Some(T),
}
25
26impl<T> AllOr<T> {
27 pub fn unwrap_or_else<F, G>(self, all_true: F, all_false: G) -> T
29 where
30 F: FnOnce() -> T,
31 G: FnOnce() -> T,
32 {
33 match self {
34 Self::Some(v) => v,
35 AllOr::All => all_true(),
36 AllOr::None => all_false(),
37 }
38 }
39}
40
41impl<T> AllOr<&T> {
42 pub fn cloned(self) -> AllOr<T>
44 where
45 T: Clone,
46 {
47 match self {
48 Self::All => AllOr::All,
49 Self::None => AllOr::None,
50 Self::Some(v) => AllOr::Some(v.clone()),
51 }
52 }
53}
54
55impl<T> Debug for AllOr<T>
56where
57 T: Debug,
58{
59 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
60 match self {
61 Self::All => f.write_str("All"),
62 Self::None => f.write_str("None"),
63 Self::Some(v) => f.debug_tuple("Some").field(v).finish(),
64 }
65 }
66}
67
68impl<T> PartialEq for AllOr<T>
69where
70 T: PartialEq,
71{
72 fn eq(&self, other: &Self) -> bool {
73 match (self, other) {
74 (Self::All, Self::All) => true,
75 (Self::None, Self::None) => true,
76 (Self::Some(lhs), Self::Some(rhs)) => lhs == rhs,
77 _ => false,
78 }
79 }
80}
81
82impl<T> Eq for AllOr<T> where T: Eq {}
83
/// A boolean selection over a fixed number of positions.
///
/// The two constant cases store only a length. Mixed masks keep their [`MaskValues`] behind
/// an [`Arc`], so cloning a mask clones the pointer rather than the underlying values.
#[derive(Clone, Debug)]
pub enum Mask {
    /// Every position is `true`; the payload is the mask length.
    AllTrue(usize),
    /// Every position is `false`; the payload is the mask length.
    AllFalse(usize),
    /// Explicit per-position values.
    Values(Arc<MaskValues>),
}
97
/// The explicit representation behind [`Mask::Values`]: a boolean buffer together with
/// lazily materialised index/slice views and summary statistics.
#[derive(Debug)]
pub struct MaskValues {
    /// One boolean per position of the mask.
    buffer: BooleanBuffer,

    /// Cached positions whose bit is set; filled on first use or supplied by a constructor.
    indices: OnceLock<Vec<usize>>,
    /// Cached half-open `(start, end)` runs of set bits; filled on first use or supplied by a
    /// constructor.
    slices: OnceLock<Vec<(usize, usize)>>,

    /// Number of set bits recorded for `buffer`.
    true_count: usize,
    /// `true_count / len`, as computed by the constructors.
    density: f64,
}
113
114impl MaskValues {
115 #[inline]
117 #[allow(clippy::len_without_is_empty)]
118 pub fn len(&self) -> usize {
119 self.buffer.len()
120 }
121
122 pub fn true_count(&self) -> usize {
124 self.true_count
125 }
126
127 pub fn boolean_buffer(&self) -> &BooleanBuffer {
129 &self.buffer
130 }
131
132 pub fn value(&self, index: usize) -> bool {
134 self.buffer.value(index)
135 }
136
137 pub fn indices(&self) -> &[usize] {
139 self.indices.get_or_init(|| {
140 if self.true_count == 0 {
141 return vec![];
142 }
143
144 if self.true_count == self.len() {
145 return (0..self.len()).collect();
146 }
147
148 if let Some(slices) = self.slices.get() {
149 let mut indices = Vec::with_capacity(self.true_count);
150 indices.extend(slices.iter().flat_map(|(start, end)| *start..*end));
151 debug_assert!(indices.is_sorted());
152 assert_eq!(indices.len(), self.true_count);
153 return indices;
154 }
155
156 let mut indices = Vec::with_capacity(self.true_count);
157 indices.extend(self.buffer.set_indices());
158 debug_assert!(indices.is_sorted());
159 assert_eq!(indices.len(), self.true_count);
160 indices
161 })
162 }
163
164 #[allow(clippy::cast_possible_truncation)]
166 pub fn slices(&self) -> &[(usize, usize)] {
167 self.slices.get_or_init(|| {
168 if self.true_count == self.len() {
169 return vec![(0, self.len())];
170 }
171
172 self.buffer.set_slices().collect()
173 })
174 }
175
176 pub fn threshold_iter(&self, threshold: f64) -> MaskIter {
178 if self.density >= threshold {
179 MaskIter::Slices(self.slices())
180 } else {
181 MaskIter::Indices(self.indices())
182 }
183 }
184}
185
186impl Mask {
187 pub fn new_true(length: usize) -> Self {
189 Self::AllTrue(length)
190 }
191
192 pub fn new_false(length: usize) -> Self {
194 Self::AllFalse(length)
195 }
196
197 pub fn from_buffer(buffer: BooleanBuffer) -> Self {
199 let len = buffer.len();
200 let true_count = buffer.count_set_bits();
201
202 if true_count == 0 {
203 return Self::AllFalse(len);
204 }
205 if true_count == len {
206 return Self::AllTrue(len);
207 }
208
209 Self::Values(Arc::new(MaskValues {
210 buffer,
211 indices: Default::default(),
212 slices: Default::default(),
213 true_count,
214 density: true_count as f64 / len as f64,
215 }))
216 }
217
218 pub fn from_indices(len: usize, indices: Vec<usize>) -> Self {
221 let true_count = indices.len();
222 assert!(indices.is_sorted(), "Mask indices must be sorted");
223 assert!(
224 indices.last().is_none_or(|&idx| idx < len),
225 "Mask indices must be in bounds (len={len})"
226 );
227
228 if true_count == 0 {
229 return Self::AllFalse(len);
230 }
231 if true_count == len {
232 return Self::AllTrue(len);
233 }
234
235 let mut buf = BooleanBufferBuilder::new(len);
236 buf.append_n(len, false);
238 indices.iter().for_each(|idx| buf.set_bit(*idx, true));
239 debug_assert_eq!(buf.len(), len);
240
241 Self::Values(Arc::new(MaskValues {
242 buffer: buf.finish(),
243 indices: OnceLock::from(indices),
244 slices: Default::default(),
245 true_count,
246 density: true_count as f64 / len as f64,
247 }))
248 }
249
250 pub fn from_excluded_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> Self {
252 let mut buf = BooleanBufferBuilder::new(len);
253 buf.append_n(len, true);
254
255 let mut false_count: usize = 0;
256 indices.into_iter().for_each(|idx| {
257 buf.set_bit(idx, false);
258 false_count += 1;
259 });
260 debug_assert_eq!(buf.len(), len);
261 let true_count = len - false_count;
262
263 Self::Values(Arc::new(MaskValues {
264 buffer: buf.finish(),
265 indices: Default::default(),
266 slices: Default::default(),
267 true_count,
268 density: true_count as f64 / len as f64,
269 }))
270 }
271
272 pub fn from_slices(len: usize, vec: Vec<(usize, usize)>) -> Self {
275 Self::check_slices(len, &vec);
276 Self::from_slices_unchecked(len, vec)
277 }
278
279 fn from_slices_unchecked(len: usize, slices: Vec<(usize, usize)>) -> Self {
280 #[cfg(debug_assertions)]
281 Self::check_slices(len, &slices);
282
283 let true_count = slices.iter().map(|(b, e)| e - b).sum();
284 if true_count == 0 {
285 return Self::AllFalse(len);
286 }
287 if true_count == len {
288 return Self::AllTrue(len);
289 }
290
291 let mut buf = BooleanBufferBuilder::new(len);
292 for (start, end) in slices.iter().copied() {
293 buf.append_n(start - buf.len(), false);
294 buf.append_n(end - start, true);
295 }
296 if let Some((_, end)) = slices.last() {
297 buf.append_n(len - end, false);
298 }
299 debug_assert_eq!(buf.len(), len);
300
301 Self::Values(Arc::new(MaskValues {
302 buffer: buf.finish(),
303 indices: Default::default(),
304 slices: OnceLock::from(slices),
305 true_count,
306 density: true_count as f64 / len as f64,
307 }))
308 }
309
310 #[inline(always)]
311 fn check_slices(len: usize, vec: &[(usize, usize)]) {
312 assert!(vec.iter().all(|&(b, e)| b < e && e <= len));
313 for (first, second) in vec.iter().tuple_windows() {
314 assert!(
315 first.0 < second.0,
316 "Slices must be sorted, got {:?} and {:?}",
317 first,
318 second
319 );
320 assert!(
321 first.1 <= second.0,
322 "Slices must be non-overlapping, got {:?} and {:?}",
323 first,
324 second
325 );
326 }
327 }
328
329 pub fn from_intersection_indices(
331 len: usize,
332 lhs: impl Iterator<Item = usize>,
333 rhs: impl Iterator<Item = usize>,
334 ) -> Self {
335 let mut intersection = Vec::with_capacity(len);
336 let mut lhs = lhs.peekable();
337 let mut rhs = rhs.peekable();
338 while let (Some(&l), Some(&r)) = (lhs.peek(), rhs.peek()) {
339 match l.cmp(&r) {
340 Ordering::Less => {
341 lhs.next();
342 }
343 Ordering::Greater => {
344 rhs.next();
345 }
346 Ordering::Equal => {
347 intersection.push(l);
348 lhs.next();
349 rhs.next();
350 }
351 }
352 }
353 Self::from_indices(len, intersection)
354 }
355
356 #[inline]
358 #[allow(clippy::len_without_is_empty)]
360 pub fn len(&self) -> usize {
361 match &self {
362 Self::AllTrue(len) => *len,
363 Self::AllFalse(len) => *len,
364 Self::Values(values) => values.buffer.len(),
365 }
366 }
367
368 #[inline]
370 pub fn true_count(&self) -> usize {
371 match &self {
372 Self::AllTrue(len) => *len,
373 Self::AllFalse(_) => 0,
374 Self::Values(values) => values.true_count,
375 }
376 }
377
378 #[inline]
380 pub fn false_count(&self) -> usize {
381 match &self {
382 Self::AllTrue(_) => 0,
383 Self::AllFalse(len) => *len,
384 Self::Values(values) => values.buffer.len() - values.true_count,
385 }
386 }
387
388 #[inline]
390 pub fn all_true(&self) -> bool {
391 match &self {
392 Self::AllTrue(_) => true,
393 Self::AllFalse(_) => false,
394 Self::Values(values) => values.buffer.len() == values.true_count,
395 }
396 }
397
398 #[inline]
400 pub fn all_false(&self) -> bool {
401 self.true_count() == 0
402 }
403
404 #[inline]
406 pub fn density(&self) -> f64 {
407 match &self {
408 Self::AllTrue(_) => 1.0,
409 Self::AllFalse(_) => 0.0,
410 Self::Values(values) => values.density,
411 }
412 }
413
414 pub fn value(&self, idx: usize) -> bool {
420 match self {
421 Mask::AllTrue(_) => true,
422 Mask::AllFalse(_) => false,
423 Mask::Values(values) => values.buffer.value(idx),
424 }
425 }
426
427 pub fn first(&self) -> Option<usize> {
429 match &self {
430 Self::AllTrue(len) => (*len > 0).then_some(0),
431 Self::AllFalse(_) => None,
432 Self::Values(values) => {
433 if let Some(indices) = values.indices.get() {
434 return indices.first().copied();
435 }
436 if let Some(slices) = values.slices.get() {
437 return slices.first().map(|(start, _)| *start);
438 }
439 values.buffer.set_indices().next()
440 }
441 }
442 }
443
444 pub fn slice(&self, offset: usize, length: usize) -> Self {
446 assert!(offset + length <= self.len());
447 match &self {
448 Self::AllTrue(_) => Self::new_true(length),
449 Self::AllFalse(_) => Self::new_false(length),
450 Self::Values(values) => Self::from_buffer(values.buffer.slice(offset, length)),
451 }
452 }
453
454 pub fn boolean_buffer(&self) -> AllOr<&BooleanBuffer> {
456 match &self {
457 Self::AllTrue(_) => AllOr::All,
458 Self::AllFalse(_) => AllOr::None,
459 Self::Values(values) => AllOr::Some(&values.buffer),
460 }
461 }
462
463 pub fn to_boolean_buffer(&self) -> BooleanBuffer {
466 match self {
467 Self::AllTrue(l) => BooleanBuffer::new_set(*l),
468 Self::AllFalse(l) => BooleanBuffer::new_unset(*l),
469 Self::Values(values) => values.boolean_buffer().clone(),
470 }
471 }
472
473 pub fn to_null_buffer(&self) -> Option<NullBuffer> {
475 match self {
476 Mask::AllTrue(_) => None,
477 Mask::AllFalse(l) => Some(NullBuffer::new_null(*l)),
478 Mask::Values(values) => Some(NullBuffer::from(values.buffer.clone())),
479 }
480 }
481
482 pub fn indices(&self) -> AllOr<&[usize]> {
484 match &self {
485 Self::AllTrue(_) => AllOr::All,
486 Self::AllFalse(_) => AllOr::None,
487 Self::Values(values) => AllOr::Some(values.indices()),
488 }
489 }
490
491 pub fn slices(&self) -> AllOr<&[(usize, usize)]> {
493 match &self {
494 Self::AllTrue(_) => AllOr::All,
495 Self::AllFalse(_) => AllOr::None,
496 Self::Values(values) => AllOr::Some(values.slices()),
497 }
498 }
499
500 pub fn threshold_iter(&self, threshold: f64) -> AllOr<MaskIter> {
502 match &self {
503 Self::AllTrue(_) => AllOr::All,
504 Self::AllFalse(_) => AllOr::None,
505 Self::Values(values) => AllOr::Some(values.threshold_iter(threshold)),
506 }
507 }
508
509 pub fn values(&self) -> Option<&MaskValues> {
511 match self {
512 Self::Values(values) => Some(values),
513 _ => None,
514 }
515 }
516}
517
/// The two borrowed views over the `true` positions of a mask, as selected by
/// `threshold_iter`.
pub enum MaskIter<'a> {
    /// The individual positions that are `true`.
    Indices(&'a [usize]),
    /// Half-open `(start, end)` runs of `true` positions.
    Slices(&'a [(usize, usize)]),
}
525
526impl From<BooleanBuffer> for Mask {
527 fn from(value: BooleanBuffer) -> Self {
528 Self::from_buffer(value)
529 }
530}
531
532impl FromIterator<bool> for Mask {
533 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
534 Self::from_buffer(BooleanBuffer::from_iter(iter))
535 }
536}
537
538impl FromIterator<Mask> for Mask {
539 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
540 let masks = iter.into_iter().collect::<Vec<_>>();
541 let total_length = masks.iter().map(|v| v.len()).sum();
542
543 if masks.iter().all(|v| v.all_true()) {
545 return Self::AllTrue(total_length);
546 }
547 if masks.iter().all(|v| v.all_false()) {
549 return Self::AllFalse(total_length);
550 }
551
552 let mut buffer = BooleanBufferBuilder::new(total_length);
554 for mask in masks {
555 match mask {
556 Mask::AllTrue(count) => buffer.append_n(count, true),
557 Mask::AllFalse(count) => buffer.append_n(count, false),
558 Mask::Values(values) => {
559 buffer.append_buffer(values.boolean_buffer());
560 }
561 };
562 }
563 Self::from_buffer(buffer.finish())
564 }
565}
566
#[cfg(test)]
mod test {
    use super::*;

    /// An all-true mask reports full density and `AllOr::All` from every view accessor.
    #[test]
    fn mask_all_true() {
        let all_true = Mask::new_true(5);
        assert_eq!(all_true.len(), 5);
        assert_eq!(all_true.true_count(), 5);
        assert_eq!(all_true.density(), 1.0);
        assert_eq!(all_true.indices(), AllOr::All);
        assert_eq!(all_true.slices(), AllOr::All);
        assert_eq!(all_true.boolean_buffer(), AllOr::All);
    }

    /// An all-false mask reports zero density and `AllOr::None` from every view accessor.
    #[test]
    fn mask_all_false() {
        let all_false = Mask::new_false(5);
        assert_eq!(all_false.len(), 5);
        assert_eq!(all_false.true_count(), 0);
        assert_eq!(all_false.density(), 0.0);
        assert_eq!(all_false.indices(), AllOr::None);
        assert_eq!(all_false.slices(), AllOr::None);
        assert_eq!(all_false.boolean_buffer(), AllOr::None);
    }

    /// The index, slice and buffer constructors all describe the same mixed mask.
    #[test]
    fn mask_from() {
        let expected_bits = BooleanBuffer::from_iter([true, false, true, true, false]);
        let expected_indices: &[usize] = &[0, 2, 3];
        let expected_slices: &[(usize, usize)] = &[(0, 1), (2, 4)];

        let masks = [
            Mask::from_indices(5, vec![0, 2, 3]),
            Mask::from_slices(5, vec![(0, 1), (2, 4)]),
            Mask::from_buffer(BooleanBuffer::from_iter([true, false, true, true, false])),
        ];

        for mask in masks.iter() {
            assert_eq!(mask.len(), 5);
            assert_eq!(mask.true_count(), 3);
            assert_eq!(mask.density(), 0.6);
            assert_eq!(mask.indices(), AllOr::Some(expected_indices));
            assert_eq!(mask.slices(), AllOr::Some(expected_slices));
            assert_eq!(mask.boolean_buffer(), AllOr::Some(&expected_bits));
        }
    }
}