1#![deny(missing_docs)]
6
7mod bitops;
8mod eq;
9mod intersect_by_rank;
10
11#[cfg(test)]
12mod tests;
13
14use std::cmp::Ordering;
15use std::fmt::Debug;
16use std::fmt::Formatter;
17use std::ops::Bound;
18use std::ops::RangeBounds;
19use std::sync::Arc;
20use std::sync::OnceLock;
21
22use itertools::Itertools;
23use vortex_buffer::BitBuffer;
24use vortex_buffer::BitBufferMut;
25use vortex_error::VortexResult;
26use vortex_error::vortex_panic;
27
/// A three-way result used by mask accessors: everything, nothing, or an explicit payload.
pub enum AllOr<T> {
    /// Stands in for the "all true" case; no payload is carried.
    All,
    /// Stands in for the "all false" case; no payload is carried.
    None,
    /// An explicit payload describing a mixed selection.
    Some(T),
}

impl<T> AllOr<T> {
    /// Extracts the payload, or builds one with `all_true` for [`AllOr::All`] and
    /// `all_false` for [`AllOr::None`]. Only the closure that is needed gets called.
    #[inline]
    pub fn unwrap_or_else<F, G>(self, all_true: F, all_false: G) -> T
    where
        F: FnOnce() -> T,
        G: FnOnce() -> T,
    {
        match self {
            AllOr::All => all_true(),
            AllOr::None => all_false(),
            AllOr::Some(payload) => payload,
        }
    }
}

impl<T> AllOr<&T> {
    /// Converts an `AllOr<&T>` into an `AllOr<T>` by cloning the referenced payload, if any.
    #[inline]
    pub fn cloned(self) -> AllOr<T>
    where
        T: Clone,
    {
        match self {
            AllOr::Some(borrowed) => AllOr::Some(T::clone(borrowed)),
            AllOr::All => AllOr::All,
            AllOr::None => AllOr::None,
        }
    }
}

impl<T> Debug for AllOr<T>
where
    T: Debug,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // The payload-free variants print as their bare name.
        if let Self::Some(payload) = self {
            return f.debug_tuple("Some").field(payload).finish();
        }
        f.write_str(if matches!(self, Self::All) { "All" } else { "None" })
    }
}

impl<T> PartialEq for AllOr<T>
where
    T: PartialEq,
{
    fn eq(&self, other: &Self) -> bool {
        // Variants must match; payloads are compared only for `Some`.
        match self {
            Self::All => matches!(other, Self::All),
            Self::None => matches!(other, Self::None),
            Self::Some(lhs) => matches!(other, Self::Some(rhs) if lhs == rhs),
        }
    }
}

impl<T: Eq> Eq for AllOr<T> {}
97
98#[derive(Clone)]
104#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
105pub enum Mask {
106 AllTrue(usize),
108 AllFalse(usize),
110 Values(Arc<MaskValues>),
112}
113
114impl Debug for Mask {
115 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
116 match self {
117 Self::AllTrue(len) => write!(f, "All true({len})"),
118 Self::AllFalse(len) => write!(f, "All false({len})"),
119 Self::Values(mask) => write!(f, "{mask:?}"),
120 }
121 }
122}
123
124impl Default for Mask {
125 fn default() -> Self {
126 Self::new_true(0)
127 }
128}
129
130#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
132pub struct MaskValues {
133 buffer: BitBuffer,
134
135 #[cfg_attr(feature = "serde", serde(skip))]
138 indices: OnceLock<Vec<usize>>,
139 #[cfg_attr(feature = "serde", serde(skip))]
140 slices: OnceLock<Vec<(usize, usize)>>,
141
142 true_count: usize,
144 density: f64,
146}
147
148impl Debug for MaskValues {
149 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
150 write!(f, "true_count={}, ", self.true_count)?;
151 write!(f, "density={}, ", self.density)?;
152 if let Some(v) = self.indices.get() {
153 write!(f, "indices={v:?}, ")?;
154 }
155 if let Some(v) = self.slices.get() {
156 write!(f, "slices={v:?}, ")?;
157 }
158 if f.alternate() {
159 f.write_str("\n")?;
160 }
161 write!(f, "{}", self.buffer)
162 }
163}
164
165impl Mask {
166 pub fn new(length: usize, value: bool) -> Self {
168 if value {
169 Self::AllTrue(length)
170 } else {
171 Self::AllFalse(length)
172 }
173 }
174
175 #[inline]
177 pub fn new_true(length: usize) -> Self {
178 Self::AllTrue(length)
179 }
180
181 #[inline]
183 pub fn new_false(length: usize) -> Self {
184 Self::AllFalse(length)
185 }
186
187 pub fn from_buffer(buffer: BitBuffer) -> Self {
189 let len = buffer.len();
190 let true_count = buffer.true_count();
191
192 if true_count == 0 {
193 return Self::AllFalse(len);
194 }
195 if true_count == len {
196 return Self::AllTrue(len);
197 }
198
199 Self::Values(Arc::new(MaskValues {
200 buffer,
201 indices: Default::default(),
202 slices: Default::default(),
203 true_count,
204 density: true_count as f64 / len as f64,
205 }))
206 }
207
208 pub fn from_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> Self {
210 let indices = indices.into_iter().collect::<Vec<_>>();
211 assert!(indices.is_sorted(), "Mask indices must be sorted");
212 assert!(
213 indices.windows(2).all(|w| w[0] != w[1]),
214 "Mask indices must be unique"
215 );
216 let buffer = BitBuffer::from_indices(len, indices.iter().copied());
217 debug_assert_eq!(buffer.len(), len);
218 let true_count = buffer.true_count();
219
220 if true_count == 0 {
221 return Self::AllFalse(len);
222 }
223 if true_count == len {
224 return Self::AllTrue(len);
225 }
226
227 Self::Values(Arc::new(MaskValues {
228 buffer,
229 indices: OnceLock::from(indices),
230 slices: Default::default(),
231 true_count,
232 density: true_count as f64 / len as f64,
233 }))
234 }
235
236 pub fn from_excluded_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> Self {
238 let mut buf = BitBufferMut::new_set(len);
239
240 let mut false_count: usize = 0;
241 indices.into_iter().for_each(|idx| {
242 buf.unset(idx);
243 false_count += 1;
244 });
245 debug_assert_eq!(buf.len(), len);
246 let true_count = len - false_count;
247
248 if false_count == 0 {
250 return Self::AllTrue(len);
251 }
252 if false_count == len {
253 return Self::AllFalse(len);
254 }
255
256 Self::Values(Arc::new(MaskValues {
257 buffer: buf.freeze(),
258 indices: Default::default(),
259 slices: Default::default(),
260 true_count,
261 density: true_count as f64 / len as f64,
262 }))
263 }
264
265 pub fn from_slices(len: usize, vec: Vec<(usize, usize)>) -> Self {
268 Self::check_slices(len, &vec);
269 Self::from_slices_unchecked(len, vec)
270 }
271
272 fn from_slices_unchecked(len: usize, slices: Vec<(usize, usize)>) -> Self {
273 #[cfg(debug_assertions)]
274 Self::check_slices(len, &slices);
275
276 let true_count = slices.iter().map(|(b, e)| e - b).sum();
277 if true_count == 0 {
278 return Self::AllFalse(len);
279 }
280 if true_count == len {
281 return Self::AllTrue(len);
282 }
283
284 let mut buf = BitBufferMut::with_capacity(len);
285 let mut cursor = 0;
286 for (start, end) in slices.iter().copied() {
287 buf.append_n(false, start - cursor);
288 buf.append_n(true, end - start);
289 cursor = end;
290 }
291 buf.append_n(false, len - cursor);
292 debug_assert_eq!(buf.len(), len);
293
294 Self::Values(Arc::new(MaskValues {
295 buffer: buf.freeze(),
296 indices: Default::default(),
297 slices: OnceLock::from(slices),
298 true_count,
299 density: true_count as f64 / len as f64,
300 }))
301 }
302
303 #[inline(always)]
304 fn check_slices(len: usize, vec: &[(usize, usize)]) {
305 assert!(vec.iter().all(|&(b, e)| b < e && e <= len));
306 for (first, second) in vec.iter().tuple_windows() {
307 assert!(
308 first.0 < second.0,
309 "Slices must be sorted, got {first:?} and {second:?}"
310 );
311 assert!(
312 first.1 <= second.0,
313 "Slices must be non-overlapping, got {first:?} and {second:?}"
314 );
315 }
316 }
317
318 pub fn from_intersection_indices(
320 len: usize,
321 lhs: impl Iterator<Item = usize>,
322 rhs: impl Iterator<Item = usize>,
323 ) -> Self {
324 let mut intersection = Vec::with_capacity(len);
325 let mut lhs = lhs.peekable();
326 let mut rhs = rhs.peekable();
327 while let (Some(&l), Some(&r)) = (lhs.peek(), rhs.peek()) {
328 match l.cmp(&r) {
329 Ordering::Less => {
330 lhs.next();
331 }
332 Ordering::Greater => {
333 rhs.next();
334 }
335 Ordering::Equal => {
336 intersection.push(l);
337 lhs.next();
338 rhs.next();
339 }
340 }
341 }
342 Self::from_indices(len, intersection)
343 }
344
345 pub fn clear(&mut self) {
347 *self = Self::new_false(0);
348 }
349
350 #[inline]
352 pub fn len(&self) -> usize {
353 match self {
354 Self::AllTrue(len) => *len,
355 Self::AllFalse(len) => *len,
356 Self::Values(values) => values.len(),
357 }
358 }
359
360 #[inline]
362 pub fn is_empty(&self) -> bool {
363 match self {
364 Self::AllTrue(len) => *len == 0,
365 Self::AllFalse(len) => *len == 0,
366 Self::Values(values) => values.is_empty(),
367 }
368 }
369
370 #[inline]
372 pub fn true_count(&self) -> usize {
373 match &self {
374 Self::AllTrue(len) => *len,
375 Self::AllFalse(_) => 0,
376 Self::Values(values) => values.true_count,
377 }
378 }
379
380 #[inline]
382 pub fn false_count(&self) -> usize {
383 match &self {
384 Self::AllTrue(_) => 0,
385 Self::AllFalse(len) => *len,
386 Self::Values(values) => values.buffer.len() - values.true_count,
387 }
388 }
389
390 #[inline]
392 pub fn all_true(&self) -> bool {
393 match &self {
394 Self::AllTrue(_) => true,
395 Self::AllFalse(0) => true,
396 Self::AllFalse(_) => false,
397 Self::Values(values) => values.buffer.len() == values.true_count,
398 }
399 }
400
401 #[inline]
403 pub fn all_false(&self) -> bool {
404 self.true_count() == 0
405 }
406
407 #[inline]
409 pub fn density(&self) -> f64 {
410 match &self {
411 Self::AllTrue(_) => 1.0,
412 Self::AllFalse(_) => 0.0,
413 Self::Values(values) => values.density,
414 }
415 }
416
417 #[inline]
423 pub fn value(&self, idx: usize) -> bool {
424 match self {
425 Mask::AllTrue(_) => true,
426 Mask::AllFalse(_) => false,
427 Mask::Values(values) => values.buffer.value(idx),
428 }
429 }
430
431 pub fn first(&self) -> Option<usize> {
433 match &self {
434 Self::AllTrue(len) => (*len > 0).then_some(0),
435 Self::AllFalse(_) => None,
436 Self::Values(values) => {
437 if let Some(indices) = values.indices.get() {
438 return indices.first().copied();
439 }
440 if let Some(slices) = values.slices.get() {
441 return slices.first().map(|(start, _)| *start);
442 }
443 values.buffer.set_indices().next()
444 }
445 }
446 }
447
448 pub fn last(&self) -> Option<usize> {
450 match &self {
451 Self::AllTrue(len) => (*len > 0).then_some(*len - 1),
452 Self::AllFalse(_) => None,
453 Self::Values(values) => {
454 if let Some(indices) = values.indices.get() {
455 return indices.last().copied();
456 }
457 if let Some(slices) = values.slices.get() {
458 return slices.last().map(|(_, end)| end - 1);
459 }
460 values.buffer.set_slices().last().map(|(_, end)| end - 1)
461 }
462 }
463 }
464
465 pub fn rank(&self, n: usize) -> usize {
467 if n >= self.true_count() {
468 vortex_panic!(
469 "Rank {n} out of bounds for mask with true count {}",
470 self.true_count()
471 );
472 }
473 match &self {
474 Self::AllTrue(_) => n,
475 Self::AllFalse(_) => unreachable!("no true values in all-false mask"),
476 Self::Values(values) => values.indices()[n],
478 }
479 }
480
481 pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
483 let start = match range.start_bound() {
484 Bound::Included(&s) => s,
485 Bound::Excluded(&s) => s + 1,
486 Bound::Unbounded => 0,
487 };
488 let end = match range.end_bound() {
489 Bound::Included(&e) => e + 1,
490 Bound::Excluded(&e) => e,
491 Bound::Unbounded => self.len(),
492 };
493
494 assert!(start <= end);
495 assert!(start <= self.len());
496 assert!(end <= self.len());
497 let len = end - start;
498
499 match &self {
500 Self::AllTrue(_) => Self::new_true(len),
501 Self::AllFalse(_) => Self::new_false(len),
502 Self::Values(values) => Self::from_buffer(values.buffer.slice(range)),
503 }
504 }
505
506 #[inline]
508 pub fn bit_buffer(&self) -> AllOr<&BitBuffer> {
509 match &self {
510 Self::AllTrue(_) => AllOr::All,
511 Self::AllFalse(_) => AllOr::None,
512 Self::Values(values) => AllOr::Some(&values.buffer),
513 }
514 }
515
516 #[inline]
519 pub fn to_bit_buffer(&self) -> BitBuffer {
520 match self {
521 Self::AllTrue(l) => BitBuffer::new_set(*l),
522 Self::AllFalse(l) => BitBuffer::new_unset(*l),
523 Self::Values(values) => values.bit_buffer().clone(),
524 }
525 }
526
527 #[inline]
530 pub fn into_bit_buffer(self) -> BitBuffer {
531 match self {
532 Self::AllTrue(l) => BitBuffer::new_set(l),
533 Self::AllFalse(l) => BitBuffer::new_unset(l),
534 Self::Values(values) => Arc::try_unwrap(values)
535 .map(|v| v.into_bit_buffer())
536 .unwrap_or_else(|v| v.bit_buffer().clone()),
537 }
538 }
539
540 #[inline]
542 pub fn indices(&self) -> AllOr<&[usize]> {
543 match &self {
544 Self::AllTrue(_) => AllOr::All,
545 Self::AllFalse(_) => AllOr::None,
546 Self::Values(values) => AllOr::Some(values.indices()),
547 }
548 }
549
550 #[inline]
552 pub fn slices(&self) -> AllOr<&[(usize, usize)]> {
553 match &self {
554 Self::AllTrue(_) => AllOr::All,
555 Self::AllFalse(_) => AllOr::None,
556 Self::Values(values) => AllOr::Some(values.slices()),
557 }
558 }
559
560 #[inline]
562 pub fn threshold_iter(&self, threshold: f64) -> AllOr<MaskIter<'_>> {
563 match &self {
564 Self::AllTrue(_) => AllOr::All,
565 Self::AllFalse(_) => AllOr::None,
566 Self::Values(values) => AllOr::Some(values.threshold_iter(threshold)),
567 }
568 }
569
570 #[inline]
572 pub fn values(&self) -> Option<&MaskValues> {
573 if let Self::Values(values) = self {
574 Some(values)
575 } else {
576 None
577 }
578 }
579
580 pub fn valid_counts_for_indices(&self, indices: &[usize]) -> Vec<usize> {
585 match self {
586 Self::AllTrue(_) => indices.to_vec(),
587 Self::AllFalse(_) => vec![0; indices.len()],
588 Self::Values(values) => {
589 let mut bool_iter = values.bit_buffer().iter();
590 let mut valid_counts = Vec::with_capacity(indices.len());
591 let mut valid_count = 0;
592 let mut idx = 0;
593 for &next_idx in indices {
594 while idx < next_idx {
595 idx += 1;
596 valid_count += bool_iter
597 .next()
598 .unwrap_or_else(|| vortex_panic!("Row indices exceed array length"))
599 as usize;
600 }
601 valid_counts.push(valid_count);
602 }
603
604 valid_counts
605 }
606 }
607 }
608
609 pub fn limit(self, limit: usize) -> Self {
611 if self.len() <= limit {
615 return self;
616 }
617
618 match &self {
619 Mask::AllTrue(len) => {
620 Self::from_iter([Self::new_true(limit), Self::new_false(len - limit)])
621 }
622 Mask::AllFalse(_) => self,
623 Mask::Values(mask_values) => {
624 if limit >= mask_values.true_count() {
625 return self;
626 }
627
628 let existing_buffer = mask_values.bit_buffer();
629
630 let mut new_buffer_builder = BitBufferMut::new_unset(mask_values.len());
631 debug_assert!(limit < mask_values.len());
632
633 for index in existing_buffer.set_indices().take(limit) {
634 unsafe { new_buffer_builder.set_unchecked(index) }
637 }
638
639 Self::from(new_buffer_builder.freeze())
640 }
641 }
642 }
643
644 pub fn concat<'a>(masks: impl Iterator<Item = &'a Self>) -> VortexResult<Self> {
646 let masks: Vec<_> = masks.collect();
647 let len = masks.iter().map(|t| t.len()).sum();
648
649 if masks.iter().all(|t| t.all_true()) {
650 return Ok(Mask::AllTrue(len));
651 }
652
653 if masks.iter().all(|t| t.all_false()) {
654 return Ok(Mask::AllFalse(len));
655 }
656
657 let mut builder = BitBufferMut::with_capacity(len);
658
659 for mask in masks {
660 match mask {
661 Mask::AllTrue(n) => builder.append_n(true, *n),
662 Mask::AllFalse(n) => builder.append_n(false, *n),
663 Mask::Values(v) => builder.append_buffer(v.bit_buffer()),
664 }
665 }
666
667 Ok(Mask::from_buffer(builder.freeze()))
668 }
669}
670
671impl MaskValues {
672 #[inline]
674 pub fn len(&self) -> usize {
675 self.buffer.len()
676 }
677
678 #[inline]
680 pub fn is_empty(&self) -> bool {
681 self.buffer.is_empty()
682 }
683
684 #[inline]
686 pub fn density(&self) -> f64 {
687 self.density
688 }
689
690 #[inline]
692 pub fn true_count(&self) -> usize {
693 self.true_count
694 }
695
696 #[inline]
698 pub fn bit_buffer(&self) -> &BitBuffer {
699 &self.buffer
700 }
701
702 #[inline]
704 pub fn into_bit_buffer(self) -> BitBuffer {
705 self.buffer
706 }
707
708 #[inline]
710 pub fn value(&self, index: usize) -> bool {
711 self.buffer.value(index)
712 }
713
714 pub fn indices(&self) -> &[usize] {
716 self.indices.get_or_init(|| {
717 if self.true_count == 0 {
718 return vec![];
719 }
720
721 if self.true_count == self.len() {
722 return (0..self.len()).collect();
723 }
724
725 if let Some(slices) = self.slices.get() {
726 let mut indices = Vec::with_capacity(self.true_count);
727 indices.extend(slices.iter().flat_map(|(start, end)| *start..*end));
728 debug_assert!(indices.is_sorted());
729 assert_eq!(indices.len(), self.true_count);
730 return indices;
731 }
732
733 let mut indices = Vec::with_capacity(self.true_count);
734 indices.extend(self.buffer.set_indices());
735 debug_assert!(indices.is_sorted());
736 assert_eq!(indices.len(), self.true_count);
737 indices
738 })
739 }
740
741 #[inline]
743 pub fn slices(&self) -> &[(usize, usize)] {
744 self.slices.get_or_init(|| {
745 if self.true_count == self.len() {
746 return vec![(0, self.len())];
747 }
748
749 self.buffer.set_slices().collect()
750 })
751 }
752
753 #[inline]
755 pub fn threshold_iter(&self, threshold: f64) -> MaskIter<'_> {
756 if self.density >= threshold {
757 MaskIter::Slices(self.slices())
758 } else {
759 MaskIter::Indices(self.indices())
760 }
761 }
762}
763
/// The two ways of walking the true values of a mixed mask, as selected by
/// `threshold_iter`.
pub enum MaskIter<'a> {
    /// The position of every true value.
    Indices(&'a [usize]),
    /// Half-open `(start, end)` runs of consecutive true values.
    Slices(&'a [(usize, usize)]),
}
771
772impl From<BitBuffer> for Mask {
773 fn from(value: BitBuffer) -> Self {
774 Self::from_buffer(value)
775 }
776}
777
778impl FromIterator<bool> for Mask {
779 #[inline]
780 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
781 Self::from_buffer(BitBuffer::from_iter(iter))
782 }
783}
784
785impl FromIterator<Mask> for Mask {
786 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
787 let masks = iter
788 .into_iter()
789 .filter(|m| !m.is_empty())
790 .collect::<Vec<_>>();
791 let total_length = masks.iter().map(|v| v.len()).sum();
792
793 if masks.iter().all(|v| v.all_true()) {
795 return Self::AllTrue(total_length);
796 }
797 if masks.iter().all(|v| v.all_false()) {
799 return Self::AllFalse(total_length);
800 }
801
802 let mut buffer = BitBufferMut::with_capacity(total_length);
804 for mask in masks {
805 match mask {
806 Mask::AllTrue(count) => buffer.append_n(true, count),
807 Mask::AllFalse(count) => buffer.append_n(false, count),
808 Mask::Values(values) => {
809 buffer.append_buffer(values.bit_buffer());
810 }
811 };
812 }
813 Self::from_buffer(buffer.freeze())
814 }
815}