1#![deny(missing_docs)]
6
7mod bitops;
8mod eq;
9mod intersect_by_rank;
10
11#[cfg(test)]
12mod tests;
13
14use std::cmp::Ordering;
15use std::fmt::Debug;
16use std::fmt::Formatter;
17use std::ops::Bound;
18use std::ops::RangeBounds;
19use std::sync::Arc;
20use std::sync::OnceLock;
21
22use itertools::Itertools;
23use vortex_buffer::BitBuffer;
24use vortex_buffer::BitBufferMut;
25use vortex_error::VortexResult;
26use vortex_error::vortex_panic;
27
/// A tri-state result describing whether something applies to every element (`All`),
/// to no element (`None`), or only to some elements (`Some(T)`).
///
/// `Mask` accessors such as `indices`, `slices`, `bit_buffer` and `threshold_iter`
/// return this type so that the all-true / all-false cases need no materialized payload.
pub enum AllOr<T> {
    /// Applies to every element.
    All,
    /// Applies to no element.
    None,
    /// Applies to a subset of elements, described by the payload.
    Some(T),
}
37
38impl<T> AllOr<T> {
39 #[inline]
41 pub fn unwrap_or_else<F, G>(self, all_true: F, all_false: G) -> T
42 where
43 F: FnOnce() -> T,
44 G: FnOnce() -> T,
45 {
46 match self {
47 Self::Some(v) => v,
48 AllOr::All => all_true(),
49 AllOr::None => all_false(),
50 }
51 }
52}
53
54impl<T> AllOr<&T> {
55 #[inline]
57 pub fn cloned(self) -> AllOr<T>
58 where
59 T: Clone,
60 {
61 match self {
62 Self::All => AllOr::All,
63 Self::None => AllOr::None,
64 Self::Some(v) => AllOr::Some(v.clone()),
65 }
66 }
67}
68
69impl<T> Debug for AllOr<T>
70where
71 T: Debug,
72{
73 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
74 match self {
75 Self::All => f.write_str("All"),
76 Self::None => f.write_str("None"),
77 Self::Some(v) => f.debug_tuple("Some").field(v).finish(),
78 }
79 }
80}
81
82impl<T> PartialEq for AllOr<T>
83where
84 T: PartialEq,
85{
86 fn eq(&self, other: &Self) -> bool {
87 match (self, other) {
88 (Self::All, Self::All) => true,
89 (Self::None, Self::None) => true,
90 (Self::Some(lhs), Self::Some(rhs)) => lhs == rhs,
91 _ => false,
92 }
93 }
94}
95
96impl<T> Eq for AllOr<T> where T: Eq {}
97
/// A boolean selection over a fixed number of positions.
///
/// The all-true and all-false cases are stored as just their length. Any other mask is
/// backed by a shared (`Arc`) [`MaskValues`] holding the bit buffer, so cloning a
/// `Mask` is cheap.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
pub enum Mask {
    /// Every one of the given number of positions is selected.
    AllTrue(usize),
    /// None of the given number of positions is selected.
    AllFalse(usize),
    /// An explicit bit buffer. The constructors in this module only produce this
    /// variant when some, but not all, positions are selected.
    Values(Arc<MaskValues>),
}
113
114impl Debug for Mask {
115 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
116 match self {
117 Self::AllTrue(len) => write!(f, "All true({len})"),
118 Self::AllFalse(len) => write!(f, "All false({len})"),
119 Self::Values(mask) => write!(f, "{mask:?}"),
120 }
121 }
122}
123
124impl Default for Mask {
125 fn default() -> Self {
126 Self::new_true(0)
127 }
128}
129
/// The materialized form of a [`Mask`], backed by an explicit bit buffer.
///
/// `true_count` and `density` are computed when the value is constructed. `indices`
/// and `slices` are lazily derived views of the buffer, cached after first use.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MaskValues {
    // One bit per position; a set bit means the position is selected.
    buffer: BitBuffer,

    // Cached positions of the set bits, in ascending order. Not serialized; it starts
    // empty after deserialization and is rebuilt on demand by `indices()`.
    #[cfg_attr(feature = "serde", serde(skip))]
    indices: OnceLock<Vec<usize>>,
    // Cached half-open `(start, end)` runs of set bits. Not serialized; it is rebuilt
    // on demand by `slices()`.
    #[cfg_attr(feature = "serde", serde(skip))]
    slices: OnceLock<Vec<(usize, usize)>>,

    // Number of set bits in `buffer`.
    true_count: usize,
    // `true_count / buffer.len()`.
    density: f64,
}
147
148impl Debug for MaskValues {
149 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
150 write!(f, "true_count={}, ", self.true_count)?;
151 write!(f, "density={}, ", self.density)?;
152 if let Some(v) = self.indices.get() {
153 write!(f, "indices={v:?}, ")?;
154 }
155 if let Some(v) = self.slices.get() {
156 write!(f, "slices={v:?}, ")?;
157 }
158 if f.alternate() {
159 f.write_str("\n")?;
160 }
161 write!(f, "{}", self.buffer)
162 }
163}
164
165impl Mask {
166 pub fn new(length: usize, value: bool) -> Self {
168 if value {
169 Self::AllTrue(length)
170 } else {
171 Self::AllFalse(length)
172 }
173 }
174
175 #[inline]
177 pub fn new_true(length: usize) -> Self {
178 Self::AllTrue(length)
179 }
180
181 #[inline]
183 pub fn new_false(length: usize) -> Self {
184 Self::AllFalse(length)
185 }
186
187 pub fn from_buffer(buffer: BitBuffer) -> Self {
189 let len = buffer.len();
190 let true_count = buffer.true_count();
191
192 if true_count == 0 {
193 return Self::AllFalse(len);
194 }
195 if true_count == len {
196 return Self::AllTrue(len);
197 }
198
199 Self::Values(Arc::new(MaskValues {
200 buffer,
201 indices: Default::default(),
202 slices: Default::default(),
203 true_count,
204 density: true_count as f64 / len as f64,
205 }))
206 }
207
208 pub fn from_indices(len: usize, indices: Vec<usize>) -> Self {
211 let true_count = indices.len();
212 assert!(indices.is_sorted(), "Mask indices must be sorted");
213 assert!(
214 indices.last().is_none_or(|&idx| idx < len),
215 "Mask indices must be in bounds (len={len})"
216 );
217
218 if true_count == 0 {
219 return Self::AllFalse(len);
220 }
221 if true_count == len {
222 return Self::AllTrue(len);
223 }
224
225 let mut buf = BitBufferMut::new_unset(len);
226 indices.iter().for_each(|&idx| buf.set(idx));
228 debug_assert_eq!(buf.len(), len);
229
230 Self::Values(Arc::new(MaskValues {
231 buffer: buf.freeze(),
232 indices: OnceLock::from(indices),
233 slices: Default::default(),
234 true_count,
235 density: true_count as f64 / len as f64,
236 }))
237 }
238
239 pub fn from_excluded_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> Self {
241 let mut buf = BitBufferMut::new_set(len);
242
243 let mut false_count: usize = 0;
244 indices.into_iter().for_each(|idx| {
245 buf.unset(idx);
246 false_count += 1;
247 });
248 debug_assert_eq!(buf.len(), len);
249 let true_count = len - false_count;
250
251 if false_count == 0 {
253 return Self::AllTrue(len);
254 }
255 if false_count == len {
256 return Self::AllFalse(len);
257 }
258
259 Self::Values(Arc::new(MaskValues {
260 buffer: buf.freeze(),
261 indices: Default::default(),
262 slices: Default::default(),
263 true_count,
264 density: true_count as f64 / len as f64,
265 }))
266 }
267
268 pub fn from_slices(len: usize, vec: Vec<(usize, usize)>) -> Self {
271 Self::check_slices(len, &vec);
272 Self::from_slices_unchecked(len, vec)
273 }
274
275 fn from_slices_unchecked(len: usize, slices: Vec<(usize, usize)>) -> Self {
276 #[cfg(debug_assertions)]
277 Self::check_slices(len, &slices);
278
279 let true_count = slices.iter().map(|(b, e)| e - b).sum();
280 if true_count == 0 {
281 return Self::AllFalse(len);
282 }
283 if true_count == len {
284 return Self::AllTrue(len);
285 }
286
287 let mut buf = BitBufferMut::new_unset(len);
288 for (start, end) in slices.iter().copied() {
289 (start..end).for_each(|idx| buf.set(idx));
290 }
291 debug_assert_eq!(buf.len(), len);
292
293 Self::Values(Arc::new(MaskValues {
294 buffer: buf.freeze(),
295 indices: Default::default(),
296 slices: OnceLock::from(slices),
297 true_count,
298 density: true_count as f64 / len as f64,
299 }))
300 }
301
302 #[inline(always)]
303 fn check_slices(len: usize, vec: &[(usize, usize)]) {
304 assert!(vec.iter().all(|&(b, e)| b < e && e <= len));
305 for (first, second) in vec.iter().tuple_windows() {
306 assert!(
307 first.0 < second.0,
308 "Slices must be sorted, got {first:?} and {second:?}"
309 );
310 assert!(
311 first.1 <= second.0,
312 "Slices must be non-overlapping, got {first:?} and {second:?}"
313 );
314 }
315 }
316
317 pub fn from_intersection_indices(
319 len: usize,
320 lhs: impl Iterator<Item = usize>,
321 rhs: impl Iterator<Item = usize>,
322 ) -> Self {
323 let mut intersection = Vec::with_capacity(len);
324 let mut lhs = lhs.peekable();
325 let mut rhs = rhs.peekable();
326 while let (Some(&l), Some(&r)) = (lhs.peek(), rhs.peek()) {
327 match l.cmp(&r) {
328 Ordering::Less => {
329 lhs.next();
330 }
331 Ordering::Greater => {
332 rhs.next();
333 }
334 Ordering::Equal => {
335 intersection.push(l);
336 lhs.next();
337 rhs.next();
338 }
339 }
340 }
341 Self::from_indices(len, intersection)
342 }
343
344 pub fn clear(&mut self) {
346 *self = Self::new_false(0);
347 }
348
349 #[inline]
351 pub fn len(&self) -> usize {
352 match self {
353 Self::AllTrue(len) => *len,
354 Self::AllFalse(len) => *len,
355 Self::Values(values) => values.len(),
356 }
357 }
358
359 #[inline]
361 pub fn is_empty(&self) -> bool {
362 match self {
363 Self::AllTrue(len) => *len == 0,
364 Self::AllFalse(len) => *len == 0,
365 Self::Values(values) => values.is_empty(),
366 }
367 }
368
369 #[inline]
371 pub fn true_count(&self) -> usize {
372 match &self {
373 Self::AllTrue(len) => *len,
374 Self::AllFalse(_) => 0,
375 Self::Values(values) => values.true_count,
376 }
377 }
378
379 #[inline]
381 pub fn false_count(&self) -> usize {
382 match &self {
383 Self::AllTrue(_) => 0,
384 Self::AllFalse(len) => *len,
385 Self::Values(values) => values.buffer.len() - values.true_count,
386 }
387 }
388
389 #[inline]
391 pub fn all_true(&self) -> bool {
392 match &self {
393 Self::AllTrue(_) => true,
394 Self::AllFalse(0) => true,
395 Self::AllFalse(_) => false,
396 Self::Values(values) => values.buffer.len() == values.true_count,
397 }
398 }
399
400 #[inline]
402 pub fn all_false(&self) -> bool {
403 self.true_count() == 0
404 }
405
406 #[inline]
408 pub fn density(&self) -> f64 {
409 match &self {
410 Self::AllTrue(_) => 1.0,
411 Self::AllFalse(_) => 0.0,
412 Self::Values(values) => values.density,
413 }
414 }
415
416 #[inline]
422 pub fn value(&self, idx: usize) -> bool {
423 match self {
424 Mask::AllTrue(_) => true,
425 Mask::AllFalse(_) => false,
426 Mask::Values(values) => values.buffer.value(idx),
427 }
428 }
429
430 pub fn first(&self) -> Option<usize> {
432 match &self {
433 Self::AllTrue(len) => (*len > 0).then_some(0),
434 Self::AllFalse(_) => None,
435 Self::Values(values) => {
436 if let Some(indices) = values.indices.get() {
437 return indices.first().copied();
438 }
439 if let Some(slices) = values.slices.get() {
440 return slices.first().map(|(start, _)| *start);
441 }
442 values.buffer.set_indices().next()
443 }
444 }
445 }
446
447 pub fn last(&self) -> Option<usize> {
449 match &self {
450 Self::AllTrue(len) => (*len > 0).then_some(*len - 1),
451 Self::AllFalse(_) => None,
452 Self::Values(values) => {
453 if let Some(indices) = values.indices.get() {
454 return indices.last().copied();
455 }
456 if let Some(slices) = values.slices.get() {
457 return slices.last().map(|(_, end)| end - 1);
458 }
459 values.buffer.set_slices().last().map(|(_, end)| end - 1)
460 }
461 }
462 }
463
464 pub fn rank(&self, n: usize) -> usize {
466 if n >= self.true_count() {
467 vortex_panic!(
468 "Rank {n} out of bounds for mask with true count {}",
469 self.true_count()
470 );
471 }
472 match &self {
473 Self::AllTrue(_) => n,
474 Self::AllFalse(_) => unreachable!("no true values in all-false mask"),
475 Self::Values(values) => values.indices()[n],
477 }
478 }
479
480 pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
482 let start = match range.start_bound() {
483 Bound::Included(&s) => s,
484 Bound::Excluded(&s) => s + 1,
485 Bound::Unbounded => 0,
486 };
487 let end = match range.end_bound() {
488 Bound::Included(&e) => e + 1,
489 Bound::Excluded(&e) => e,
490 Bound::Unbounded => self.len(),
491 };
492
493 assert!(start <= end);
494 assert!(start <= self.len());
495 assert!(end <= self.len());
496 let len = end - start;
497
498 match &self {
499 Self::AllTrue(_) => Self::new_true(len),
500 Self::AllFalse(_) => Self::new_false(len),
501 Self::Values(values) => Self::from_buffer(values.buffer.slice(range)),
502 }
503 }
504
505 #[inline]
507 pub fn bit_buffer(&self) -> AllOr<&BitBuffer> {
508 match &self {
509 Self::AllTrue(_) => AllOr::All,
510 Self::AllFalse(_) => AllOr::None,
511 Self::Values(values) => AllOr::Some(&values.buffer),
512 }
513 }
514
515 #[inline]
518 pub fn to_bit_buffer(&self) -> BitBuffer {
519 match self {
520 Self::AllTrue(l) => BitBuffer::new_set(*l),
521 Self::AllFalse(l) => BitBuffer::new_unset(*l),
522 Self::Values(values) => values.bit_buffer().clone(),
523 }
524 }
525
526 #[inline]
529 pub fn into_bit_buffer(self) -> BitBuffer {
530 match self {
531 Self::AllTrue(l) => BitBuffer::new_set(l),
532 Self::AllFalse(l) => BitBuffer::new_unset(l),
533 Self::Values(values) => Arc::try_unwrap(values)
534 .map(|v| v.into_bit_buffer())
535 .unwrap_or_else(|v| v.bit_buffer().clone()),
536 }
537 }
538
539 #[inline]
541 pub fn indices(&self) -> AllOr<&[usize]> {
542 match &self {
543 Self::AllTrue(_) => AllOr::All,
544 Self::AllFalse(_) => AllOr::None,
545 Self::Values(values) => AllOr::Some(values.indices()),
546 }
547 }
548
549 #[inline]
551 pub fn slices(&self) -> AllOr<&[(usize, usize)]> {
552 match &self {
553 Self::AllTrue(_) => AllOr::All,
554 Self::AllFalse(_) => AllOr::None,
555 Self::Values(values) => AllOr::Some(values.slices()),
556 }
557 }
558
559 #[inline]
561 pub fn threshold_iter(&self, threshold: f64) -> AllOr<MaskIter<'_>> {
562 match &self {
563 Self::AllTrue(_) => AllOr::All,
564 Self::AllFalse(_) => AllOr::None,
565 Self::Values(values) => AllOr::Some(values.threshold_iter(threshold)),
566 }
567 }
568
569 #[inline]
571 pub fn values(&self) -> Option<&MaskValues> {
572 if let Self::Values(values) = self {
573 Some(values)
574 } else {
575 None
576 }
577 }
578
579 pub fn valid_counts_for_indices(&self, indices: &[usize]) -> Vec<usize> {
584 match self {
585 Self::AllTrue(_) => indices.to_vec(),
586 Self::AllFalse(_) => vec![0; indices.len()],
587 Self::Values(values) => {
588 let mut bool_iter = values.bit_buffer().iter();
589 let mut valid_counts = Vec::with_capacity(indices.len());
590 let mut valid_count = 0;
591 let mut idx = 0;
592 for &next_idx in indices {
593 while idx < next_idx {
594 idx += 1;
595 valid_count += bool_iter
596 .next()
597 .unwrap_or_else(|| vortex_panic!("Row indices exceed array length"))
598 as usize;
599 }
600 valid_counts.push(valid_count);
601 }
602
603 valid_counts
604 }
605 }
606 }
607
608 pub fn limit(self, limit: usize) -> Self {
610 if self.len() <= limit {
614 return self;
615 }
616
617 match self {
618 Mask::AllTrue(len) => {
619 Self::from_iter([Self::new_true(limit), Self::new_false(len - limit)])
620 }
621 Mask::AllFalse(_) => self,
622 Mask::Values(ref mask_values) => {
623 if limit >= mask_values.true_count() {
624 return self;
625 }
626
627 let existing_buffer = mask_values.bit_buffer();
628
629 let mut new_buffer_builder = BitBufferMut::new_unset(mask_values.len());
630 debug_assert!(limit < mask_values.len());
631
632 for index in existing_buffer.set_indices().take(limit) {
633 unsafe { new_buffer_builder.set_unchecked(index) }
636 }
637
638 Self::from(new_buffer_builder.freeze())
639 }
640 }
641 }
642
643 pub fn concat<'a>(masks: impl Iterator<Item = &'a Self>) -> VortexResult<Self> {
645 let masks: Vec<_> = masks.collect();
646 let len = masks.iter().map(|t| t.len()).sum();
647
648 if masks.iter().all(|t| t.all_true()) {
649 return Ok(Mask::AllTrue(len));
650 }
651
652 if masks.iter().all(|t| t.all_false()) {
653 return Ok(Mask::AllFalse(len));
654 }
655
656 let mut builder = BitBufferMut::with_capacity(len);
657
658 for mask in masks {
659 match mask {
660 Mask::AllTrue(n) => builder.append_n(true, *n),
661 Mask::AllFalse(n) => builder.append_n(false, *n),
662 Mask::Values(v) => builder.append_buffer(v.bit_buffer()),
663 }
664 }
665
666 Ok(Mask::from_buffer(builder.freeze()))
667 }
668}
669
670impl MaskValues {
671 #[inline]
673 pub fn len(&self) -> usize {
674 self.buffer.len()
675 }
676
677 #[inline]
679 pub fn is_empty(&self) -> bool {
680 self.buffer.is_empty()
681 }
682
683 #[inline]
685 pub fn density(&self) -> f64 {
686 self.density
687 }
688
689 #[inline]
691 pub fn true_count(&self) -> usize {
692 self.true_count
693 }
694
695 #[inline]
697 pub fn bit_buffer(&self) -> &BitBuffer {
698 &self.buffer
699 }
700
701 #[inline]
703 pub fn into_bit_buffer(self) -> BitBuffer {
704 self.buffer
705 }
706
707 #[inline]
709 pub fn value(&self, index: usize) -> bool {
710 self.buffer.value(index)
711 }
712
713 pub fn indices(&self) -> &[usize] {
715 self.indices.get_or_init(|| {
716 if self.true_count == 0 {
717 return vec![];
718 }
719
720 if self.true_count == self.len() {
721 return (0..self.len()).collect();
722 }
723
724 if let Some(slices) = self.slices.get() {
725 let mut indices = Vec::with_capacity(self.true_count);
726 indices.extend(slices.iter().flat_map(|(start, end)| *start..*end));
727 debug_assert!(indices.is_sorted());
728 assert_eq!(indices.len(), self.true_count);
729 return indices;
730 }
731
732 let mut indices = Vec::with_capacity(self.true_count);
733 indices.extend(self.buffer.set_indices());
734 debug_assert!(indices.is_sorted());
735 assert_eq!(indices.len(), self.true_count);
736 indices
737 })
738 }
739
740 #[inline]
742 pub fn slices(&self) -> &[(usize, usize)] {
743 self.slices.get_or_init(|| {
744 if self.true_count == self.len() {
745 return vec![(0, self.len())];
746 }
747
748 self.buffer.set_slices().collect()
749 })
750 }
751
752 #[inline]
754 pub fn threshold_iter(&self, threshold: f64) -> MaskIter<'_> {
755 if self.density >= threshold {
756 MaskIter::Slices(self.slices())
757 } else {
758 MaskIter::Indices(self.indices())
759 }
760 }
761}
762
/// A borrowed view of a mask's selected positions, as chosen by `threshold_iter`.
pub enum MaskIter<'a> {
    /// The selected positions, in ascending order.
    Indices(&'a [usize]),
    /// Half-open `(start, end)` runs of selected positions, in ascending order.
    Slices(&'a [(usize, usize)]),
}
770
771impl From<BitBuffer> for Mask {
772 fn from(value: BitBuffer) -> Self {
773 Self::from_buffer(value)
774 }
775}
776
777impl FromIterator<bool> for Mask {
778 #[inline]
779 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
780 Self::from_buffer(BitBuffer::from_iter(iter))
781 }
782}
783
784impl FromIterator<Mask> for Mask {
785 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
786 let masks = iter
787 .into_iter()
788 .filter(|m| !m.is_empty())
789 .collect::<Vec<_>>();
790 let total_length = masks.iter().map(|v| v.len()).sum();
791
792 if masks.iter().all(|v| v.all_true()) {
794 return Self::AllTrue(total_length);
795 }
796 if masks.iter().all(|v| v.all_false()) {
798 return Self::AllFalse(total_length);
799 }
800
801 let mut buffer = BitBufferMut::with_capacity(total_length);
803 for mask in masks {
804 match mask {
805 Mask::AllTrue(count) => buffer.append_n(true, count),
806 Mask::AllFalse(count) => buffer.append_n(false, count),
807 Mask::Values(values) => {
808 buffer.append_buffer(values.bit_buffer());
809 }
810 };
811 }
812 Self::from_buffer(buffer.freeze())
813 }
814}