1#![deny(missing_docs)]
6
7mod bitops;
8mod eq;
9mod intersect_by_rank;
10
11#[cfg(test)]
12mod tests;
13
14use std::cmp::Ordering;
15use std::fmt::Debug;
16use std::fmt::Formatter;
17use std::ops::Bound;
18use std::ops::RangeBounds;
19use std::sync::Arc;
20use std::sync::OnceLock;
21
22use itertools::Itertools;
23use vortex_buffer::BitBuffer;
24use vortex_buffer::BitBufferMut;
25use vortex_error::VortexResult;
26use vortex_error::vortex_panic;
27
/// A tri-state answer for queries that may cover every element, no element,
/// or be described by an explicit value of type `T`.
///
/// [`Mask`] returns this from accessors such as [`Mask::bit_buffer`],
/// [`Mask::indices`] and [`Mask::slices`], so callers can short-circuit the
/// all-true and all-false cases without materializing anything.
pub enum AllOr<T> {
    /// The answer covers every element (e.g. an all-true mask).
    All,
    /// The answer covers no element (e.g. an all-false mask).
    None,
    /// The answer is given by an explicit value.
    Some(T),
}
37
38impl<T> AllOr<T> {
39 #[inline]
41 pub fn unwrap_or_else<F, G>(self, all_true: F, all_false: G) -> T
42 where
43 F: FnOnce() -> T,
44 G: FnOnce() -> T,
45 {
46 match self {
47 Self::Some(v) => v,
48 AllOr::All => all_true(),
49 AllOr::None => all_false(),
50 }
51 }
52}
53
54impl<T> AllOr<&T> {
55 #[inline]
57 pub fn cloned(self) -> AllOr<T>
58 where
59 T: Clone,
60 {
61 match self {
62 Self::All => AllOr::All,
63 Self::None => AllOr::None,
64 Self::Some(v) => AllOr::Some(v.clone()),
65 }
66 }
67}
68
69impl<T> Debug for AllOr<T>
70where
71 T: Debug,
72{
73 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
74 match self {
75 Self::All => f.write_str("All"),
76 Self::None => f.write_str("None"),
77 Self::Some(v) => f.debug_tuple("Some").field(v).finish(),
78 }
79 }
80}
81
82impl<T> PartialEq for AllOr<T>
83where
84 T: PartialEq,
85{
86 fn eq(&self, other: &Self) -> bool {
87 match (self, other) {
88 (Self::All, Self::All) => true,
89 (Self::None, Self::None) => true,
90 (Self::Some(lhs), Self::Some(rhs)) => lhs == rhs,
91 _ => false,
92 }
93 }
94}
95
96impl<T> Eq for AllOr<T> where T: Eq {}
97
/// A boolean selection over a sequence of [`Mask::len`] elements.
///
/// The uniform cases store only a length. Anything else is backed by a shared
/// [`MaskValues`], so cloning a `Values` mask only bumps the `Arc` reference
/// count. The constructors in this module collapse uniform buffers into
/// `AllTrue` / `AllFalse`.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
pub enum Mask {
    /// Every one of the contained number of elements is `true`.
    AllTrue(usize),
    /// Every one of the contained number of elements is `false`.
    AllFalse(usize),
    /// A bit-buffer-backed mask, normally a mix of `true` and `false` elements.
    Values(Arc<MaskValues>),
}
113
114impl Debug for Mask {
115 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
116 match self {
117 Self::AllTrue(len) => write!(f, "All true({len})"),
118 Self::AllFalse(len) => write!(f, "All false({len})"),
119 Self::Values(mask) => write!(f, "{mask:?}"),
120 }
121 }
122}
123
124impl Default for Mask {
125 fn default() -> Self {
126 Self::new_true(0)
127 }
128}
129
/// The bit-buffer-backed contents of a [`Mask::Values`] mask, together with
/// lazily computed views (indices and slices) of its `true` positions.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MaskValues {
    /// One bit per element; a set bit means the element is `true`.
    buffer: BitBuffer,

    /// Lazily computed ascending indices of the set bits (see `indices()`).
    /// Not serialized, so a deserialized value starts with this cache empty.
    #[cfg_attr(feature = "serde", serde(skip))]
    indices: OnceLock<Vec<usize>>,
    /// Lazily computed half-open `(start, end)` runs of set bits (see
    /// `slices()`). Not serialized either.
    #[cfg_attr(feature = "serde", serde(skip))]
    slices: OnceLock<Vec<(usize, usize)>>,

    /// Number of set bits in `buffer`, fixed at construction.
    true_count: usize,
    /// `true_count / buffer.len()`, fixed at construction.
    density: f64,
}
147
148impl Debug for MaskValues {
149 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
150 write!(f, "true_count={}, ", self.true_count)?;
151 write!(f, "density={}, ", self.density)?;
152 if let Some(v) = self.indices.get() {
153 write!(f, "indices={v:?}, ")?;
154 }
155 if let Some(v) = self.slices.get() {
156 write!(f, "slices={v:?}, ")?;
157 }
158 if f.alternate() {
159 f.write_str("\n")?;
160 }
161 write!(f, "{}", self.buffer)
162 }
163}
164
165impl Mask {
166 pub fn new(length: usize, value: bool) -> Self {
168 if value {
169 Self::AllTrue(length)
170 } else {
171 Self::AllFalse(length)
172 }
173 }
174
175 #[inline]
177 pub fn new_true(length: usize) -> Self {
178 Self::AllTrue(length)
179 }
180
181 #[inline]
183 pub fn new_false(length: usize) -> Self {
184 Self::AllFalse(length)
185 }
186
187 pub fn from_buffer(buffer: BitBuffer) -> Self {
189 let len = buffer.len();
190 let true_count = buffer.true_count();
191
192 if true_count == 0 {
193 return Self::AllFalse(len);
194 }
195 if true_count == len {
196 return Self::AllTrue(len);
197 }
198
199 Self::Values(Arc::new(MaskValues {
200 buffer,
201 indices: Default::default(),
202 slices: Default::default(),
203 true_count,
204 density: true_count as f64 / len as f64,
205 }))
206 }
207
208 pub fn from_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> Self {
210 let indices = indices.into_iter().collect::<Vec<_>>();
211 assert!(indices.is_sorted(), "Mask indices must be sorted");
212 assert!(
213 indices.windows(2).all(|w| w[0] != w[1]),
214 "Mask indices must be unique"
215 );
216 let buffer = BitBuffer::from_indices(len, indices.iter().copied());
217 debug_assert_eq!(buffer.len(), len);
218 let true_count = buffer.true_count();
219
220 if true_count == 0 {
221 return Self::AllFalse(len);
222 }
223 if true_count == len {
224 return Self::AllTrue(len);
225 }
226
227 Self::Values(Arc::new(MaskValues {
228 buffer,
229 indices: OnceLock::from(indices),
230 slices: Default::default(),
231 true_count,
232 density: true_count as f64 / len as f64,
233 }))
234 }
235
236 pub fn from_excluded_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> Self {
238 let mut buf = BitBufferMut::new_set(len);
239
240 let mut false_count: usize = 0;
241 indices.into_iter().for_each(|idx| {
242 buf.unset(idx);
243 false_count += 1;
244 });
245 debug_assert_eq!(buf.len(), len);
246 let true_count = len - false_count;
247
248 if false_count == 0 {
250 return Self::AllTrue(len);
251 }
252 if false_count == len {
253 return Self::AllFalse(len);
254 }
255
256 Self::Values(Arc::new(MaskValues {
257 buffer: buf.freeze(),
258 indices: Default::default(),
259 slices: Default::default(),
260 true_count,
261 density: true_count as f64 / len as f64,
262 }))
263 }
264
265 pub fn from_slices(len: usize, vec: Vec<(usize, usize)>) -> Self {
268 Self::check_slices(len, &vec);
269 Self::from_slices_unchecked(len, vec)
270 }
271
272 fn from_slices_unchecked(len: usize, slices: Vec<(usize, usize)>) -> Self {
273 #[cfg(debug_assertions)]
274 Self::check_slices(len, &slices);
275
276 let true_count = slices.iter().map(|(b, e)| e - b).sum();
277 if true_count == 0 {
278 return Self::AllFalse(len);
279 }
280 if true_count == len {
281 return Self::AllTrue(len);
282 }
283
284 let mut buf = BitBufferMut::with_capacity(len);
285 let mut cursor = 0;
286 for (start, end) in slices.iter().copied() {
287 buf.append_n(false, start - cursor);
288 buf.append_n(true, end - start);
289 cursor = end;
290 }
291 buf.append_n(false, len - cursor);
292 debug_assert_eq!(buf.len(), len);
293
294 Self::Values(Arc::new(MaskValues {
295 buffer: buf.freeze(),
296 indices: Default::default(),
297 slices: OnceLock::from(slices),
298 true_count,
299 density: true_count as f64 / len as f64,
300 }))
301 }
302
303 #[inline(always)]
304 fn check_slices(len: usize, vec: &[(usize, usize)]) {
305 assert!(vec.iter().all(|&(b, e)| b < e && e <= len));
306 for (first, second) in vec.iter().tuple_windows() {
307 assert!(
308 first.0 < second.0,
309 "Slices must be sorted, got {first:?} and {second:?}"
310 );
311 assert!(
312 first.1 <= second.0,
313 "Slices must be non-overlapping, got {first:?} and {second:?}"
314 );
315 }
316 }
317
318 pub fn from_intersection_indices(
320 len: usize,
321 lhs: impl Iterator<Item = usize>,
322 rhs: impl Iterator<Item = usize>,
323 ) -> Self {
324 let mut intersection = Vec::with_capacity(len);
325 let mut lhs = lhs.peekable();
326 let mut rhs = rhs.peekable();
327 while let (Some(&l), Some(&r)) = (lhs.peek(), rhs.peek()) {
328 match l.cmp(&r) {
329 Ordering::Less => {
330 lhs.next();
331 }
332 Ordering::Greater => {
333 rhs.next();
334 }
335 Ordering::Equal => {
336 intersection.push(l);
337 lhs.next();
338 rhs.next();
339 }
340 }
341 }
342 Self::from_indices(len, intersection)
343 }
344
345 pub fn clear(&mut self) {
347 *self = Self::new_false(0);
348 }
349
350 #[inline]
352 pub fn len(&self) -> usize {
353 match self {
354 Self::AllTrue(len) => *len,
355 Self::AllFalse(len) => *len,
356 Self::Values(values) => values.len(),
357 }
358 }
359
360 #[inline]
362 pub fn is_empty(&self) -> bool {
363 match self {
364 Self::AllTrue(len) => *len == 0,
365 Self::AllFalse(len) => *len == 0,
366 Self::Values(values) => values.is_empty(),
367 }
368 }
369
370 #[inline]
372 pub fn true_count(&self) -> usize {
373 match &self {
374 Self::AllTrue(len) => *len,
375 Self::AllFalse(_) => 0,
376 Self::Values(values) => values.true_count,
377 }
378 }
379
380 #[inline]
382 pub fn false_count(&self) -> usize {
383 match &self {
384 Self::AllTrue(_) => 0,
385 Self::AllFalse(len) => *len,
386 Self::Values(values) => values.buffer.len() - values.true_count,
387 }
388 }
389
390 #[inline]
392 pub fn all_true(&self) -> bool {
393 match &self {
394 Self::AllTrue(_) => true,
395 Self::AllFalse(0) => true,
396 Self::AllFalse(_) => false,
397 Self::Values(values) => values.buffer.len() == values.true_count,
398 }
399 }
400
401 #[inline]
403 pub fn all_false(&self) -> bool {
404 self.true_count() == 0
405 }
406
407 #[inline]
409 pub fn density(&self) -> f64 {
410 match &self {
411 Self::AllTrue(_) => 1.0,
412 Self::AllFalse(_) => 0.0,
413 Self::Values(values) => values.density,
414 }
415 }
416
417 #[inline]
423 pub fn value(&self, idx: usize) -> bool {
424 match self {
425 Mask::AllTrue(_) => true,
426 Mask::AllFalse(_) => false,
427 Mask::Values(values) => values.buffer.value(idx),
428 }
429 }
430
431 pub fn first(&self) -> Option<usize> {
433 match &self {
434 Self::AllTrue(len) => (*len > 0).then_some(0),
435 Self::AllFalse(_) => None,
436 Self::Values(values) => {
437 if let Some(indices) = values.indices.get() {
438 return indices.first().copied();
439 }
440 if let Some(slices) = values.slices.get() {
441 return slices.first().map(|(start, _)| *start);
442 }
443 values.buffer.set_indices().next()
444 }
445 }
446 }
447
448 pub fn last(&self) -> Option<usize> {
450 match &self {
451 Self::AllTrue(len) => (*len > 0).then_some(*len - 1),
452 Self::AllFalse(_) => None,
453 Self::Values(values) => {
454 if let Some(indices) = values.indices.get() {
455 return indices.last().copied();
456 }
457 if let Some(slices) = values.slices.get() {
458 return slices.last().map(|(_, end)| end - 1);
459 }
460
461 if values.true_count == 0 {
462 return None;
463 }
464
465 Some(
466 values
467 .buffer
468 .select(values.true_count - 1)
469 .unwrap_or_else(|| {
470 vortex_panic!(
471 "Rank {} out of bounds for mask with true count {}",
472 values.true_count - 1,
473 values.true_count
474 )
475 }),
476 )
477 }
478 }
479 }
480
481 pub fn rank(&self, n: usize) -> usize {
483 if n >= self.true_count() {
484 vortex_panic!(
485 "Rank {n} out of bounds for mask with true count {}",
486 self.true_count()
487 );
488 }
489 match &self {
490 Self::AllTrue(_) => n,
491 Self::AllFalse(_) => unreachable!("no true values in all-false mask"),
492 Self::Values(values) => {
493 if let Some(indices) = values.indices.get() {
494 return indices[n];
495 }
496
497 values.buffer.select(n).unwrap_or_else(|| {
498 vortex_panic!(
499 "Rank {} out of bounds for mask with true count {}",
500 values.true_count - 1,
501 values.true_count
502 )
503 })
504 }
505 }
506 }
507
508 pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
510 let start = match range.start_bound() {
511 Bound::Included(&s) => s,
512 Bound::Excluded(&s) => s + 1,
513 Bound::Unbounded => 0,
514 };
515 let end = match range.end_bound() {
516 Bound::Included(&e) => e + 1,
517 Bound::Excluded(&e) => e,
518 Bound::Unbounded => self.len(),
519 };
520
521 assert!(start <= end);
522 assert!(start <= self.len());
523 assert!(end <= self.len());
524 let len = end - start;
525
526 match &self {
527 Self::AllTrue(_) => Self::new_true(len),
528 Self::AllFalse(_) => Self::new_false(len),
529 Self::Values(values) => Self::from_buffer(values.buffer.slice(range)),
530 }
531 }
532
533 #[inline]
535 pub fn bit_buffer(&self) -> AllOr<&BitBuffer> {
536 match &self {
537 Self::AllTrue(_) => AllOr::All,
538 Self::AllFalse(_) => AllOr::None,
539 Self::Values(values) => AllOr::Some(&values.buffer),
540 }
541 }
542
543 #[inline]
546 pub fn to_bit_buffer(&self) -> BitBuffer {
547 match self {
548 Self::AllTrue(l) => BitBuffer::new_set(*l),
549 Self::AllFalse(l) => BitBuffer::new_unset(*l),
550 Self::Values(values) => values.bit_buffer().clone(),
551 }
552 }
553
554 #[inline]
557 pub fn into_bit_buffer(self) -> BitBuffer {
558 match self {
559 Self::AllTrue(l) => BitBuffer::new_set(l),
560 Self::AllFalse(l) => BitBuffer::new_unset(l),
561 Self::Values(values) => Arc::try_unwrap(values)
562 .map(|v| v.into_bit_buffer())
563 .unwrap_or_else(|v| v.bit_buffer().clone()),
564 }
565 }
566
567 #[inline]
569 pub fn indices(&self) -> AllOr<&[usize]> {
570 match &self {
571 Self::AllTrue(_) => AllOr::All,
572 Self::AllFalse(_) => AllOr::None,
573 Self::Values(values) => AllOr::Some(values.indices()),
574 }
575 }
576
577 #[inline]
579 pub fn slices(&self) -> AllOr<&[(usize, usize)]> {
580 match &self {
581 Self::AllTrue(_) => AllOr::All,
582 Self::AllFalse(_) => AllOr::None,
583 Self::Values(values) => AllOr::Some(values.slices()),
584 }
585 }
586
587 #[inline]
589 pub fn threshold_iter(&self, threshold: f64) -> AllOr<MaskIter<'_>> {
590 match &self {
591 Self::AllTrue(_) => AllOr::All,
592 Self::AllFalse(_) => AllOr::None,
593 Self::Values(values) => AllOr::Some(values.threshold_iter(threshold)),
594 }
595 }
596
597 #[inline]
599 pub fn values(&self) -> Option<&MaskValues> {
600 if let Self::Values(values) = self {
601 Some(values)
602 } else {
603 None
604 }
605 }
606
607 pub fn valid_counts_for_indices(&self, indices: &[usize]) -> Vec<usize> {
612 match self {
613 Self::AllTrue(_) => indices.to_vec(),
614 Self::AllFalse(_) => vec![0; indices.len()],
615 Self::Values(values) => {
616 let mut bool_iter = values.bit_buffer().iter();
617 let mut valid_counts = Vec::with_capacity(indices.len());
618 let mut valid_count = 0;
619 let mut idx = 0;
620 for &next_idx in indices {
621 while idx < next_idx {
622 idx += 1;
623 valid_count += bool_iter
624 .next()
625 .unwrap_or_else(|| vortex_panic!("Row indices exceed array length"))
626 as usize;
627 }
628 valid_counts.push(valid_count);
629 }
630
631 valid_counts
632 }
633 }
634 }
635
636 pub fn limit(self, limit: usize) -> Self {
638 if self.len() <= limit {
642 return self;
643 }
644
645 match &self {
646 Mask::AllTrue(len) => {
647 Self::from_iter([Self::new_true(limit), Self::new_false(len - limit)])
648 }
649 Mask::AllFalse(_) => self,
650 Mask::Values(mask_values) => {
651 if limit >= mask_values.true_count() {
652 return self;
653 }
654
655 let existing_buffer = mask_values.bit_buffer();
656
657 let mut new_buffer_builder = BitBufferMut::new_unset(mask_values.len());
658 debug_assert!(limit < mask_values.len());
659
660 for index in existing_buffer.set_indices().take(limit) {
661 unsafe { new_buffer_builder.set_unchecked(index) }
664 }
665
666 Self::from(new_buffer_builder.freeze())
667 }
668 }
669 }
670
671 pub fn concat<'a>(masks: impl Iterator<Item = &'a Self>) -> VortexResult<Self> {
673 let masks: Vec<_> = masks.collect();
674 let len = masks.iter().map(|t| t.len()).sum();
675
676 if masks.iter().all(|t| t.all_true()) {
677 return Ok(Mask::AllTrue(len));
678 }
679
680 if masks.iter().all(|t| t.all_false()) {
681 return Ok(Mask::AllFalse(len));
682 }
683
684 let mut builder = BitBufferMut::with_capacity(len);
685
686 for mask in masks {
687 match mask {
688 Mask::AllTrue(n) => builder.append_n(true, *n),
689 Mask::AllFalse(n) => builder.append_n(false, *n),
690 Mask::Values(v) => builder.append_buffer(v.bit_buffer()),
691 }
692 }
693
694 Ok(Mask::from_buffer(builder.freeze()))
695 }
696}
697
698impl MaskValues {
699 #[inline]
701 pub fn len(&self) -> usize {
702 self.buffer.len()
703 }
704
705 #[inline]
707 pub fn is_empty(&self) -> bool {
708 self.buffer.is_empty()
709 }
710
711 #[inline]
713 pub fn density(&self) -> f64 {
714 self.density
715 }
716
717 #[inline]
719 pub fn true_count(&self) -> usize {
720 self.true_count
721 }
722
723 #[inline]
725 pub fn bit_buffer(&self) -> &BitBuffer {
726 &self.buffer
727 }
728
729 #[inline]
731 pub fn into_bit_buffer(self) -> BitBuffer {
732 self.buffer
733 }
734
735 #[inline]
737 pub fn value(&self, index: usize) -> bool {
738 self.buffer.value(index)
739 }
740
741 pub fn indices(&self) -> &[usize] {
743 self.indices.get_or_init(|| {
744 if self.true_count == 0 {
745 return vec![];
746 }
747
748 if self.true_count == self.len() {
749 return (0..self.len()).collect();
750 }
751
752 if let Some(slices) = self.slices.get() {
753 let mut indices = Vec::with_capacity(self.true_count);
754 indices.extend(slices.iter().flat_map(|(start, end)| *start..*end));
755 debug_assert!(indices.is_sorted());
756 assert_eq!(indices.len(), self.true_count);
757 return indices;
758 }
759
760 let mut indices = Vec::with_capacity(self.true_count);
761 indices.extend(self.buffer.set_indices());
762 debug_assert!(indices.is_sorted());
763 assert_eq!(indices.len(), self.true_count);
764 indices
765 })
766 }
767
768 #[inline]
773 pub fn cached_indices(&self) -> Option<&[usize]> {
774 self.indices.get().map(Vec::as_slice)
775 }
776
777 #[inline]
779 pub fn slices(&self) -> &[(usize, usize)] {
780 self.slices.get_or_init(|| {
781 if self.true_count == self.len() {
782 return vec![(0, self.len())];
783 }
784
785 self.buffer.set_slices().collect()
786 })
787 }
788
789 #[inline]
794 pub fn cached_slices(&self) -> Option<&[(usize, usize)]> {
795 self.slices.get().map(Vec::as_slice)
796 }
797
798 #[inline]
800 pub fn threshold_iter(&self, threshold: f64) -> MaskIter<'_> {
801 if self.density >= threshold {
802 MaskIter::Slices(self.slices())
803 } else {
804 MaskIter::Indices(self.indices())
805 }
806 }
807}
808
/// A borrowed view of the `true` positions of a mask, in whichever form
/// [`MaskValues::threshold_iter`] selected.
pub enum MaskIter<'a> {
    /// Ascending indices of the `true` elements.
    Indices(&'a [usize]),
    /// Half-open `(start, end)` runs of `true` elements.
    Slices(&'a [(usize, usize)]),
}
816
817impl From<BitBuffer> for Mask {
818 fn from(value: BitBuffer) -> Self {
819 Self::from_buffer(value)
820 }
821}
822
823impl FromIterator<bool> for Mask {
824 #[inline]
825 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
826 Self::from_buffer(BitBuffer::from_iter(iter))
827 }
828}
829
830impl FromIterator<Mask> for Mask {
831 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
832 let masks = iter
833 .into_iter()
834 .filter(|m| !m.is_empty())
835 .collect::<Vec<_>>();
836 let total_length = masks.iter().map(|v| v.len()).sum();
837
838 if masks.iter().all(|v| v.all_true()) {
840 return Self::AllTrue(total_length);
841 }
842 if masks.iter().all(|v| v.all_false()) {
844 return Self::AllFalse(total_length);
845 }
846
847 let mut buffer = BitBufferMut::with_capacity(total_length);
849 for mask in masks {
850 match mask {
851 Mask::AllTrue(count) => buffer.append_n(true, count),
852 Mask::AllFalse(count) => buffer.append_n(false, count),
853 Mask::Values(values) => {
854 buffer.append_buffer(values.bit_buffer());
855 }
856 };
857 }
858 Self::from_buffer(buffer.freeze())
859 }
860}