1#![deny(missing_docs)]
6
7mod bitops;
8mod eq;
9mod intersect_by_rank;
10mod iter_bools;
11mod mask_mut;
12
13#[cfg(test)]
14mod tests;
15
16use std::cmp::Ordering;
17use std::fmt::{Debug, Formatter};
18use std::ops::{Bound, RangeBounds};
19use std::sync::{Arc, OnceLock};
20
21use itertools::Itertools;
22pub use mask_mut::*;
23use vortex_buffer::{BitBuffer, BitBufferMut, set_bit_unchecked};
24use vortex_error::{VortexResult, vortex_panic};
25
/// A tri-state result used by [`Mask`] accessors: every value is true,
/// every value is false, or an explicit payload of mixed values.
///
/// This lets callers skip materializing buffers/indices for the two
/// homogeneous cases.
//
// The previous hand-written `Debug`, `PartialEq`, and `Eq` impls were
// byte-for-byte equivalent to the derived ones (unit variants print their
// name; `Some(v)` prints as a tuple; structural equality with the same
// `T` bounds), so they are replaced with derives.
#[derive(Debug, PartialEq, Eq)]
pub enum AllOr<T> {
    /// All values are true.
    All,
    /// All values are false.
    None,
    /// Mixed values, carried as `T`.
    Some(T),
}

impl<T> AllOr<T> {
    /// Return the payload, or compute a fallback for a homogeneous case.
    ///
    /// `all_true` runs only for [`AllOr::All`]; `all_false` runs only for
    /// [`AllOr::None`]. Neither closure is invoked when a payload is present.
    #[inline]
    pub fn unwrap_or_else<F, G>(self, all_true: F, all_false: G) -> T
    where
        F: FnOnce() -> T,
        G: FnOnce() -> T,
    {
        match self {
            Self::Some(v) => v,
            Self::All => all_true(),
            Self::None => all_false(),
        }
    }
}

impl<T> AllOr<&T> {
    /// Map an `AllOr<&T>` to an `AllOr<T>` by cloning the payload,
    /// mirroring [`Option::cloned`].
    #[inline]
    pub fn cloned(self) -> AllOr<T>
    where
        T: Clone,
    {
        match self {
            Self::All => AllOr::All,
            Self::None => AllOr::None,
            Self::Some(v) => AllOr::Some(v.clone()),
        }
    }
}
95
/// A boolean mask over a fixed number of positions, with compact
/// representations for the all-true and all-false cases.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
pub enum Mask {
    /// Every value is true; the payload is the mask's length.
    AllTrue(usize),
    /// Every value is false; the payload is the mask's length.
    AllFalse(usize),
    /// Mixed values, backed by shared [`MaskValues`].
    Values(Arc<MaskValues>),
}
110
/// The backing state of a mixed-value [`Mask`]: the bit buffer plus lazily
/// computed alternative representations and cached statistics.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MaskValues {
    // One bit per mask position.
    buffer: BitBuffer,

    // Sorted positions of the true bits, computed on first use.
    // Skipped by serde: recomputed lazily after deserialization.
    #[cfg_attr(feature = "serde", serde(skip))]
    indices: OnceLock<Vec<usize>>,
    // Sorted, non-overlapping (start, end) runs of true bits, computed on
    // first use. Skipped by serde: recomputed lazily after deserialization.
    #[cfg_attr(feature = "serde", serde(skip))]
    slices: OnceLock<Vec<(usize, usize)>>,

    // Number of true bits in `buffer`.
    true_count: usize,
    // Cached `true_count / len`.
    density: f64,
}
129
impl Mask {
    /// Create a new mask of the given length in which every value is true.
    #[inline]
    pub fn new_true(length: usize) -> Self {
        Self::AllTrue(length)
    }

    /// Create a new mask of the given length in which every value is false.
    #[inline]
    pub fn new_false(length: usize) -> Self {
        Self::AllFalse(length)
    }

    /// Create a mask from a bit buffer, collapsing to the `AllTrue` /
    /// `AllFalse` variants when every bit is set or unset.
    pub fn from_buffer(buffer: BitBuffer) -> Self {
        let len = buffer.len();
        let true_count = buffer.true_count();

        if true_count == 0 {
            return Self::AllFalse(len);
        }
        if true_count == len {
            return Self::AllTrue(len);
        }

        Self::Values(Arc::new(MaskValues {
            buffer,
            indices: Default::default(),
            slices: Default::default(),
            true_count,
            density: true_count as f64 / len as f64,
        }))
    }

    /// Create a mask of length `len` whose true positions are exactly
    /// `indices`.
    ///
    /// # Panics
    ///
    /// Panics if `indices` is not sorted or contains an index `>= len`.
    pub fn from_indices(len: usize, indices: Vec<usize>) -> Self {
        // NOTE(review): `is_sorted` permits duplicate indices, which would
        // inflate `true_count` — presumably callers pass distinct indices;
        // confirm.
        let true_count = indices.len();
        assert!(indices.is_sorted(), "Mask indices must be sorted");
        assert!(
            indices.last().is_none_or(|&idx| idx < len),
            "Mask indices must be in bounds (len={len})"
        );

        if true_count == 0 {
            return Self::AllFalse(len);
        }
        if true_count == len {
            return Self::AllTrue(len);
        }

        let mut buf = BitBufferMut::new_unset(len);
        indices.iter().for_each(|&idx| buf.set(idx));
        debug_assert_eq!(buf.len(), len);

        Self::Values(Arc::new(MaskValues {
            buffer: buf.freeze(),
            // Seed the lazy cache: the sorted indices are already known.
            indices: OnceLock::from(indices),
            slices: Default::default(),
            true_count,
            density: true_count as f64 / len as f64,
        }))
    }

    /// Create a mask of length `len` that is true everywhere except at the
    /// given indices.
    pub fn from_excluded_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> Self {
        let mut buf = BitBufferMut::new_set(len);

        let mut false_count: usize = 0;
        indices.into_iter().for_each(|idx| {
            buf.unset(idx);
            false_count += 1;
        });
        debug_assert_eq!(buf.len(), len);
        // NOTE(review): a duplicate excluded index is counted twice here,
        // understating `true_count` — confirm callers pass distinct indices.
        let true_count = len - false_count;

        if false_count == 0 {
            return Self::AllTrue(len);
        }
        if false_count == len {
            return Self::AllFalse(len);
        }

        Self::Values(Arc::new(MaskValues {
            buffer: buf.freeze(),
            indices: Default::default(),
            slices: Default::default(),
            true_count,
            density: true_count as f64 / len as f64,
        }))
    }

    /// Create a mask of length `len` from half-open `(start, end)` ranges of
    /// true values.
    ///
    /// # Panics
    ///
    /// Panics if any range is empty, out of bounds, unsorted, or overlapping.
    pub fn from_slices(len: usize, vec: Vec<(usize, usize)>) -> Self {
        Self::check_slices(len, &vec);
        Self::from_slices_unchecked(len, vec)
    }

    // As `from_slices`, but only validates the ranges in debug builds.
    fn from_slices_unchecked(len: usize, slices: Vec<(usize, usize)>) -> Self {
        #[cfg(debug_assertions)]
        Self::check_slices(len, &slices);

        let true_count = slices.iter().map(|(b, e)| e - b).sum();
        if true_count == 0 {
            return Self::AllFalse(len);
        }
        if true_count == len {
            return Self::AllTrue(len);
        }

        let mut buf = BitBufferMut::new_unset(len);
        for (start, end) in slices.iter().copied() {
            (start..end).for_each(|idx| buf.set(idx));
        }
        debug_assert_eq!(buf.len(), len);

        Self::Values(Arc::new(MaskValues {
            buffer: buf.freeze(),
            indices: Default::default(),
            // Seed the lazy cache: the sorted slices are already known.
            slices: OnceLock::from(slices),
            true_count,
            density: true_count as f64 / len as f64,
        }))
    }

    // Assert that every `(start, end)` range is non-empty, in bounds,
    // sorted, and non-overlapping.
    #[inline(always)]
    fn check_slices(len: usize, vec: &[(usize, usize)]) {
        assert!(vec.iter().all(|&(b, e)| b < e && e <= len));
        for (first, second) in vec.iter().tuple_windows() {
            assert!(
                first.0 < second.0,
                "Slices must be sorted, got {first:?} and {second:?}"
            );
            assert!(
                first.1 <= second.0,
                "Slices must be non-overlapping, got {first:?} and {second:?}"
            );
        }
    }

    /// Create a mask of length `len` whose true positions are the
    /// intersection of two sorted iterators of indices.
    pub fn from_intersection_indices(
        len: usize,
        lhs: impl Iterator<Item = usize>,
        rhs: impl Iterator<Item = usize>,
    ) -> Self {
        let mut intersection = Vec::with_capacity(len);
        let mut lhs = lhs.peekable();
        let mut rhs = rhs.peekable();
        // Sorted-merge intersection: advance whichever side is behind, and
        // keep the index when both sides agree.
        while let (Some(&l), Some(&r)) = (lhs.peek(), rhs.peek()) {
            match l.cmp(&r) {
                Ordering::Less => {
                    lhs.next();
                }
                Ordering::Greater => {
                    rhs.next();
                }
                Ordering::Equal => {
                    intersection.push(l);
                    lhs.next();
                    rhs.next();
                }
            }
        }
        Self::from_indices(len, intersection)
    }

    /// The length of the mask in values.
    #[inline]
    pub fn len(&self) -> usize {
        match self {
            Self::AllTrue(len) => *len,
            Self::AllFalse(len) => *len,
            Self::Values(values) => values.len(),
        }
    }

    /// Whether the mask has length zero.
    #[inline]
    pub fn is_empty(&self) -> bool {
        match self {
            Self::AllTrue(len) => *len == 0,
            Self::AllFalse(len) => *len == 0,
            Self::Values(values) => values.is_empty(),
        }
    }

    /// The number of true values in the mask.
    #[inline]
    pub fn true_count(&self) -> usize {
        match &self {
            Self::AllTrue(len) => *len,
            Self::AllFalse(_) => 0,
            Self::Values(values) => values.true_count,
        }
    }

    /// The number of false values in the mask.
    #[inline]
    pub fn false_count(&self) -> usize {
        match &self {
            Self::AllTrue(_) => 0,
            Self::AllFalse(len) => *len,
            Self::Values(values) => values.buffer.len() - values.true_count,
        }
    }

    /// Whether every value is true. Vacuously true for an empty mask.
    #[inline]
    pub fn all_true(&self) -> bool {
        match &self {
            Self::AllTrue(_) => true,
            // An empty mask is simultaneously all-true and all-false.
            Self::AllFalse(0) => true,
            Self::AllFalse(_) => false,
            Self::Values(values) => values.buffer.len() == values.true_count,
        }
    }

    /// Whether every value is false. Vacuously true for an empty mask.
    #[inline]
    pub fn all_false(&self) -> bool {
        self.true_count() == 0
    }

    /// The fraction of true values (`true_count / len`).
    #[inline]
    pub fn density(&self) -> f64 {
        match &self {
            Self::AllTrue(_) => 1.0,
            Self::AllFalse(_) => 0.0,
            Self::Values(values) => values.density,
        }
    }

    /// The value of the mask at the given index.
    ///
    /// Note: only the `Values` variant bounds-checks `idx`; the homogeneous
    /// variants return their constant without checking.
    #[inline]
    pub fn value(&self, idx: usize) -> bool {
        match self {
            Mask::AllTrue(_) => true,
            Mask::AllFalse(_) => false,
            Mask::Values(values) => values.buffer.value(idx),
        }
    }

    /// The position of the first true value, if any.
    pub fn first(&self) -> Option<usize> {
        match &self {
            Self::AllTrue(len) => (*len > 0).then_some(0),
            Self::AllFalse(_) => None,
            Self::Values(values) => {
                // Prefer whichever lazy representation is already computed
                // before falling back to scanning the buffer.
                if let Some(indices) = values.indices.get() {
                    return indices.first().copied();
                }
                if let Some(slices) = values.slices.get() {
                    return slices.first().map(|(start, _)| *start);
                }
                values.buffer.set_indices().next()
            }
        }
    }

    /// Return a new mask covering only the given range of this mask.
    ///
    /// # Panics
    ///
    /// Panics if the range is inverted or out of bounds.
    pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
        let start = match range.start_bound() {
            Bound::Included(&s) => s,
            Bound::Excluded(&s) => s + 1,
            Bound::Unbounded => 0,
        };
        let end = match range.end_bound() {
            Bound::Included(&e) => e + 1,
            Bound::Excluded(&e) => e,
            Bound::Unbounded => self.len(),
        };

        assert!(start <= end);
        assert!(start <= self.len());
        assert!(end <= self.len());
        let len = end - start;

        match &self {
            Self::AllTrue(_) => Self::new_true(len),
            Self::AllFalse(_) => Self::new_false(len),
            Self::Values(values) => Self::from_buffer(values.buffer.slice(range)),
        }
    }

    /// Borrow the underlying bit buffer, or `All` / `None` for the
    /// homogeneous variants.
    #[inline]
    pub fn bit_buffer(&self) -> AllOr<&BitBuffer> {
        match &self {
            Self::AllTrue(_) => AllOr::All,
            Self::AllFalse(_) => AllOr::None,
            Self::Values(values) => AllOr::Some(&values.buffer),
        }
    }

    /// Materialize the mask as a bit buffer, allocating a constant buffer
    /// for the homogeneous variants.
    #[inline]
    pub fn to_bit_buffer(&self) -> BitBuffer {
        match self {
            Self::AllTrue(l) => BitBuffer::new_set(*l),
            Self::AllFalse(l) => BitBuffer::new_unset(*l),
            Self::Values(values) => values.bit_buffer().clone(),
        }
    }

    /// Consume the mask into a bit buffer, avoiding a copy when the values
    /// are uniquely owned.
    #[inline]
    pub fn into_bit_buffer(self) -> BitBuffer {
        match self {
            Self::AllTrue(l) => BitBuffer::new_set(l),
            Self::AllFalse(l) => BitBuffer::new_unset(l),
            // Clone only when the Arc is shared with other owners.
            Self::Values(values) => Arc::try_unwrap(values)
                .map(|v| v.into_bit_buffer())
                .unwrap_or_else(|v| v.bit_buffer().clone()),
        }
    }

    /// The sorted positions of the true values, or `All` / `None` for the
    /// homogeneous variants. Computed lazily and cached.
    #[inline]
    pub fn indices(&self) -> AllOr<&[usize]> {
        match &self {
            Self::AllTrue(_) => AllOr::All,
            Self::AllFalse(_) => AllOr::None,
            Self::Values(values) => AllOr::Some(values.indices()),
        }
    }

    /// The sorted `(start, end)` runs of true values (end exclusive), or
    /// `All` / `None` for the homogeneous variants. Computed lazily and
    /// cached.
    #[inline]
    pub fn slices(&self) -> AllOr<&[(usize, usize)]> {
        match &self {
            Self::AllTrue(_) => AllOr::All,
            Self::AllFalse(_) => AllOr::None,
            Self::Values(values) => AllOr::Some(values.slices()),
        }
    }

    /// A view of the true values, as slices when the density is at least
    /// `threshold` and as indices otherwise.
    #[inline]
    pub fn threshold_iter(&self, threshold: f64) -> AllOr<MaskIter<'_>> {
        match &self {
            Self::AllTrue(_) => AllOr::All,
            Self::AllFalse(_) => AllOr::None,
            Self::Values(values) => AllOr::Some(values.threshold_iter(threshold)),
        }
    }

    /// The underlying [`MaskValues`], if this is the `Values` variant.
    #[inline]
    pub fn values(&self) -> Option<&MaskValues> {
        if let Self::Values(values) = self {
            Some(values)
        } else {
            None
        }
    }

    /// For each index, the number of true values strictly before it.
    ///
    /// `indices` must be non-decreasing: the mask's bits are consumed by a
    /// forward-only iterator.
    ///
    /// # Panics
    ///
    /// Panics if an index exceeds the mask's length.
    pub fn valid_counts_for_indices(&self, indices: &[usize]) -> Vec<usize> {
        match self {
            // Every value is true, so the count before index i is i itself.
            Self::AllTrue(_) => indices.to_vec(),
            Self::AllFalse(_) => vec![0; indices.len()],
            Self::Values(values) => {
                let mut bool_iter = values.bit_buffer().iter();
                let mut valid_counts = Vec::with_capacity(indices.len());
                let mut valid_count = 0;
                let mut idx = 0;
                for &next_idx in indices {
                    // Advance to `next_idx`, accumulating set bits on the way.
                    while idx < next_idx {
                        idx += 1;
                        valid_count += bool_iter
                            .next()
                            .unwrap_or_else(|| vortex_panic!("Row indices exceed array length"))
                            as usize;
                    }
                    valid_counts.push(valid_count);
                }

                valid_counts
            }
        }
    }

    /// Retain at most the first `limit` true values, unsetting any further
    /// ones. The mask's length is unchanged.
    pub fn limit(self, limit: usize) -> Self {
        // A mask of length <= limit cannot hold more than `limit` true values.
        if self.len() <= limit {
            return self;
        }

        match self {
            Mask::AllTrue(len) => {
                Self::from_iter([Self::new_true(limit), Self::new_false(len - limit)])
            }
            Mask::AllFalse(_) => self,
            Mask::Values(ref mask_values) => {
                if limit >= mask_values.true_count() {
                    return self;
                }

                let existing_buffer = mask_values.bit_buffer();

                let mut new_buffer_builder = BitBufferMut::new_unset(mask_values.len());
                debug_assert!(limit < mask_values.len());

                let ptr = new_buffer_builder.as_mut_ptr();
                for index in existing_buffer.set_indices().take(limit) {
                    // SAFETY: `index` comes from `set_indices()` over a
                    // buffer with the same length as `new_buffer_builder`,
                    // so it is in bounds.
                    unsafe { set_bit_unchecked(ptr, index) }
                }

                Self::from(new_buffer_builder.freeze())
            }
        }
    }

    /// Concatenate a sequence of masks end-to-end.
    ///
    /// An empty sequence yields `AllTrue(0)`. The visible implementation has
    /// no failure paths and always returns `Ok`.
    pub fn concat<'a>(masks: impl Iterator<Item = &'a Self>) -> VortexResult<Self> {
        let masks: Vec<_> = masks.collect();
        let len = masks.iter().map(|t| t.len()).sum();

        if masks.iter().all(|t| t.all_true()) {
            return Ok(Mask::AllTrue(len));
        }

        if masks.iter().all(|t| t.all_false()) {
            return Ok(Mask::AllFalse(len));
        }

        let mut builder = BitBufferMut::with_capacity(len);

        for mask in masks {
            match mask {
                Mask::AllTrue(n) => builder.append_n(true, *n),
                Mask::AllFalse(n) => builder.append_n(false, *n),
                Mask::Values(v) => builder.append_buffer(v.bit_buffer()),
            }
        }

        Ok(Mask::from_buffer(builder.freeze()))
    }
}
588
impl MaskValues {
    /// The length of the mask in values.
    #[inline]
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Whether the mask has length zero.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }

    /// The number of true values in the mask.
    #[inline]
    pub fn true_count(&self) -> usize {
        self.true_count
    }

    /// Borrow the underlying bit buffer.
    #[inline]
    pub fn bit_buffer(&self) -> &BitBuffer {
        &self.buffer
    }

    /// Consume into the underlying bit buffer.
    #[inline]
    pub fn into_bit_buffer(self) -> BitBuffer {
        self.buffer
    }

    /// The value at the given index.
    #[inline]
    pub fn value(&self, index: usize) -> bool {
        self.buffer.value(index)
    }

    /// The sorted positions of the true values, computed on first use and
    /// cached.
    pub fn indices(&self) -> &[usize] {
        self.indices.get_or_init(|| {
            if self.true_count == 0 {
                return vec![];
            }

            if self.true_count == self.len() {
                return (0..self.len()).collect();
            }

            // If the slice representation is already cached, expand it
            // instead of re-scanning the buffer.
            if let Some(slices) = self.slices.get() {
                let mut indices = Vec::with_capacity(self.true_count);
                indices.extend(slices.iter().flat_map(|(start, end)| *start..*end));
                debug_assert!(indices.is_sorted());
                assert_eq!(indices.len(), self.true_count);
                return indices;
            }

            let mut indices = Vec::with_capacity(self.true_count);
            indices.extend(self.buffer.set_indices());
            debug_assert!(indices.is_sorted());
            assert_eq!(indices.len(), self.true_count);
            indices
        })
    }

    /// The sorted, non-overlapping `(start, end)` runs of true values (end
    /// exclusive), computed on first use and cached.
    #[allow(clippy::cast_possible_truncation)]
    #[inline]
    pub fn slices(&self) -> &[(usize, usize)] {
        self.slices.get_or_init(|| {
            if self.true_count == self.len() {
                return vec![(0, self.len())];
            }

            self.buffer.set_slices().collect()
        })
    }

    /// A view of the true values, as slices when the density is at least
    /// `threshold` and as indices otherwise.
    #[inline]
    pub fn threshold_iter(&self, threshold: f64) -> MaskIter<'_> {
        if self.density >= threshold {
            MaskIter::Slices(self.slices())
        } else {
            MaskIter::Indices(self.indices())
        }
    }

    // Consume into the underlying bit buffer (crate-internal alias of
    // `into_bit_buffer`).
    pub(crate) fn into_buffer(self) -> BitBuffer {
        self.buffer
    }
}
681
/// A borrowed view of a mask's true positions, in one of two layouts chosen
/// by density (see [`MaskValues::threshold_iter`]).
pub enum MaskIter<'a> {
    /// Sorted positions of the true values.
    Indices(&'a [usize]),
    /// Sorted `(start, end)` runs of true values, end exclusive.
    Slices(&'a [(usize, usize)]),
}
689
690impl From<BitBuffer> for Mask {
691 fn from(value: BitBuffer) -> Self {
692 Self::from_buffer(value)
693 }
694}
695
696impl FromIterator<bool> for Mask {
697 #[inline]
698 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
699 Self::from_buffer(BitBuffer::from_iter(iter))
700 }
701}
702
703impl FromIterator<Mask> for Mask {
704 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
705 let masks = iter
706 .into_iter()
707 .filter(|m| !m.is_empty())
708 .collect::<Vec<_>>();
709 let total_length = masks.iter().map(|v| v.len()).sum();
710
711 if masks.iter().all(|v| v.all_true()) {
713 return Self::AllTrue(total_length);
714 }
715 if masks.iter().all(|v| v.all_false()) {
717 return Self::AllFalse(total_length);
718 }
719
720 let mut buffer = BitBufferMut::with_capacity(total_length);
722 for mask in masks {
723 match mask {
724 Mask::AllTrue(count) => buffer.append_n(true, count),
725 Mask::AllFalse(count) => buffer.append_n(false, count),
726 Mask::Values(values) => {
727 buffer.append_buffer(values.bit_buffer());
728 }
729 };
730 }
731 Self::from_buffer(buffer.freeze())
732 }
733}