1use std::sync::Arc;
5
6use vortex_buffer::BitBufferMut;
7use vortex_error::vortex_panic;
8
9use crate::Mask;
10
11#[derive(Debug, Clone)]
13pub struct MaskMut(Inner);
14
15impl Default for MaskMut {
16 fn default() -> Self {
17 Self::empty()
18 }
19}
20
21#[derive(Debug, Clone)]
22enum Inner {
23 Empty { capacity: usize },
25 Constant {
27 value: bool,
28 len: usize,
29 capacity: usize,
30 },
31 Builder(BitBufferMut),
34}
35
36impl MaskMut {
37 pub fn empty() -> Self {
39 Self::with_capacity(0)
40 }
41
42 pub fn with_capacity(capacity: usize) -> Self {
44 Self(Inner::Empty { capacity })
45 }
46
47 pub fn new(len: usize, value: bool) -> Self {
49 Self(Inner::Constant {
50 value,
51 len,
52 capacity: len,
53 })
54 }
55
56 pub fn new_true(len: usize) -> Self {
58 Self(Inner::Constant {
59 value: true,
60 len,
61 capacity: len,
62 })
63 }
64
65 pub fn new_false(len: usize) -> Self {
67 Self(Inner::Constant {
68 value: false,
69 len,
70 capacity: len,
71 })
72 }
73
74 pub fn from_buffer(bit_buffer: BitBufferMut) -> Self {
76 Self(Inner::Builder(bit_buffer))
77 }
78
79 pub fn value(&self, index: usize) -> bool {
85 match &self.0 {
86 Inner::Empty { .. } => {
87 vortex_panic!("index out of bounds: the length is 0 but the index is {index}")
88 }
89 Inner::Constant { value, len, .. } => {
90 assert!(
91 index < *len,
92 "index out of bounds: the length is {} but the index is {index}",
93 *len
94 );
95
96 *value
97 }
98 Inner::Builder(bit_buffer) => bit_buffer.value(index),
99 }
100 }
101
102 pub fn reserve(&mut self, additional: usize) {
104 match &mut self.0 {
105 Inner::Empty { capacity } => {
106 *capacity += additional;
107 }
108 Inner::Constant { capacity, .. } => {
109 *capacity += additional;
110 }
111 Inner::Builder(bits) => {
112 bits.reserve(additional);
113 }
114 }
115 }
116
117 pub unsafe fn set_len(&mut self, new_len: usize) {
126 debug_assert!(new_len < self.capacity());
127 match &mut self.0 {
128 Inner::Empty { capacity, .. } => {
129 self.0 = Inner::Constant {
130 value: false, len: new_len,
132 capacity: *capacity,
133 }
134 }
135 Inner::Constant { len, .. } => {
136 *len = new_len;
137 }
138 Inner::Builder(bits) => {
139 unsafe { bits.set_len(new_len) };
140 }
141 }
142 }
143
144 pub fn capacity(&self) -> usize {
146 match &self.0 {
147 Inner::Empty { capacity } => *capacity,
148 Inner::Constant { capacity, .. } => *capacity,
149 Inner::Builder(bits) => bits.capacity(),
150 }
151 }
152
153 pub fn clear(&mut self) {
157 match &mut self.0 {
158 Inner::Empty { .. } => {}
159 Inner::Constant { capacity, .. } => {
160 self.0 = Inner::Empty {
161 capacity: *capacity,
162 }
163 }
164 Inner::Builder(bit_buffer) => bit_buffer.clear(),
165 };
166 }
167
168 pub fn truncate(&mut self, len: usize) {
174 let truncated_len = len;
175 if truncated_len > self.len() {
176 return;
177 }
178
179 match &mut self.0 {
180 Inner::Empty { .. } => {}
181 Inner::Constant { len, .. } => *len = truncated_len.min(*len),
182 Inner::Builder(bit_buffer) => bit_buffer.truncate(truncated_len),
183 };
184 }
185
186 pub fn append_n(&mut self, new_value: bool, n: usize) {
188 match &mut self.0 {
189 Inner::Empty { capacity } => {
190 self.0 = Inner::Constant {
191 value: new_value,
192 len: n,
193 capacity: (*capacity).max(n),
194 }
195 }
196 Inner::Constant {
197 value,
198 len,
199 capacity,
200 } => {
201 if *value == new_value {
202 self.0 = Inner::Constant {
204 value: *value,
205 len: *len + n,
206 capacity: (*capacity).max(*len + n),
207 }
208 } else {
209 let bits = self.materialize();
212 bits.append_n(new_value, n);
213 }
214 }
215 Inner::Builder(bits) => {
216 bits.append_n(new_value, n);
217 }
218 }
219 }
220
221 pub fn append_mask(&mut self, other: &Mask) {
223 match other {
224 Mask::AllTrue(len) => self.append_n(true, *len),
225 Mask::AllFalse(len) => self.append_n(false, *len),
226 Mask::Values(values) => {
227 let bitbuffer = values.buffer.clone();
228 self.materialize().append_buffer(&bitbuffer);
229 }
230 }
231 }
232
233 fn materialize(&mut self) -> &mut BitBufferMut {
235 let needs_materialization = !matches!(self.0, Inner::Builder(_));
236
237 if needs_materialization {
238 let new_builder = match &self.0 {
239 Inner::Empty { capacity } => BitBufferMut::with_capacity(*capacity),
240 Inner::Constant {
241 value,
242 len,
243 capacity,
244 } => {
245 let required_capacity = (*capacity).max(*len);
246 let mut bits = BitBufferMut::with_capacity(required_capacity);
247 bits.append_n(*value, *len);
248 bits
249 }
250 Inner::Builder(_) => unreachable!(),
251 };
252 self.0 = Inner::Builder(new_builder);
253 }
254
255 match &mut self.0 {
256 Inner::Builder(bits) => bits,
257 _ => unreachable!(),
258 }
259 }
260
261 pub fn split_off(&mut self, at: usize) -> Self {
265 assert!(at <= self.capacity(), "split_off index out of bounds");
266 match &mut self.0 {
267 Inner::Empty { capacity } => {
268 let new_capacity = *capacity - at;
269 *capacity = at;
270 Self(Inner::Empty {
271 capacity: new_capacity,
272 })
273 }
274 Inner::Constant {
275 value,
276 len,
277 capacity,
278 } => {
279 let new_len = len.saturating_sub(at);
281 let new_capacity = *capacity - at;
282 *len = (*len).min(at);
283 *capacity = at;
284
285 Self(Inner::Constant {
286 value: *value,
287 len: new_len,
288 capacity: new_capacity,
289 })
290 }
291 Inner::Builder(bits) => {
292 let new_bits = bits.split_off(at);
293 Self(Inner::Builder(new_bits))
294 }
295 }
296 }
297
298 pub fn unsplit(&mut self, other: Self) {
300 match other.0 {
301 Inner::Empty { .. } => {
302 }
304 Inner::Constant { value, len, .. } => {
305 self.append_n(value, len);
306 }
307 Inner::Builder(bits) => {
308 self.materialize().unsplit(bits);
309 }
310 }
311 }
312
313 pub fn freeze(self) -> Mask {
315 match self.0 {
316 Inner::Empty { .. } => Mask::new_true(0),
317 Inner::Constant { value, len, .. } => {
318 if value {
319 Mask::new_true(len)
320 } else {
321 Mask::new_false(len)
322 }
323 }
324 Inner::Builder(bits) => Mask::from_buffer(bits.freeze()),
325 }
326 }
327
328 pub fn len(&self) -> usize {
330 match &self.0 {
331 Inner::Empty { .. } => 0,
332 Inner::Constant { len, .. } => *len,
333 Inner::Builder(bits) => bits.len(),
334 }
335 }
336
337 pub fn is_empty(&self) -> bool {
339 self.len() == 0
340 }
341
342 pub fn all_true(&self) -> bool {
344 match &self.0 {
345 Inner::Empty { .. } => true,
346 Inner::Constant { value, .. } => *value,
347 Inner::Builder(bits) => bits.true_count() == bits.len(),
348 }
349 }
350
351 pub fn all_false(&self) -> bool {
353 match &self.0 {
354 Inner::Empty { .. } => true,
355 Inner::Constant { value, .. } => !*value,
356 Inner::Builder(bits) => !bits.is_empty() && bits.true_count() == 0,
357 }
358 }
359
360 pub fn as_bit_buffer_mut(&mut self) -> Option<&mut BitBufferMut> {
362 match &mut self.0 {
363 Inner::Builder(bits) => Some(bits),
364 _ => None,
365 }
366 }
367
368 pub fn set(&mut self, index: usize) {
374 self.set_to(index, true);
375 }
376
377 pub fn unset(&mut self, index: usize) {
383 self.set_to(index, false);
384 }
385
386 pub fn set_to(&mut self, index: usize, value: bool) {
392 match &mut self.0 {
393 Inner::Empty { .. } => {
394 vortex_panic!("index out of bounds: the length is 0 but the index is {index}")
395 }
396 Inner::Constant {
397 value: current_value,
398 len,
399 ..
400 } => {
401 assert!(
402 index < *len,
403 "index out of bounds: the length is {} but the index is {index}",
404 *len
405 );
406
407 if *current_value != value {
408 self.materialize().set_to(index, value);
410 }
411 }
413 Inner::Builder(bit_buffer) => {
414 bit_buffer.set_to(index, value);
415 }
416 }
417 }
418
419 pub unsafe fn set_unchecked(&mut self, index: usize) {
425 unsafe { self.set_to_unchecked(index, true) }
426 }
427
428 pub unsafe fn unset_unchecked(&mut self, index: usize) {
434 unsafe { self.set_to_unchecked(index, false) }
435 }
436
437 pub unsafe fn set_to_unchecked(&mut self, index: usize, value: bool) {
443 unsafe {
444 match &mut self.0 {
445 Inner::Empty { .. } => {
446 debug_assert!(false, "cannot set value in empty mask");
448 }
449 Inner::Constant {
450 value: current_value,
451 len,
452 ..
453 } => {
454 debug_assert!(
455 index < *len,
456 "index out of bounds: the length is {} but the index is {index}",
457 *len
458 );
459
460 if *current_value != value {
461 self.materialize().set_to_unchecked(index, value);
463 }
464 }
466 Inner::Builder(bit_buffer) => {
467 bit_buffer.set_to_unchecked(index, value);
468 }
469 }
470 }
471 }
472}
473
474impl Mask {
475 pub fn try_into_mut(self) -> Result<MaskMut, Self> {
478 match self {
479 Mask::AllTrue(len) => Ok(MaskMut::new_true(len)),
480 Mask::AllFalse(len) => Ok(MaskMut::new_false(len)),
481 Mask::Values(values) => {
482 let owned_values = Arc::try_unwrap(values).map_err(Mask::Values)?;
485 let bit_buffer = owned_values.into_buffer();
486 let mut_buffer = bit_buffer.try_into_mut().map_err(Mask::from_buffer)?;
487
488 Ok(MaskMut(Inner::Builder(mut_buffer)))
489 }
490 }
491 }
492
493 pub fn into_mut(self) -> MaskMut {
496 match self {
497 Mask::AllTrue(len) => MaskMut::new_true(len),
498 Mask::AllFalse(len) => MaskMut::new_false(len),
499 Mask::Values(values) => {
500 let bit_buffer_mut = match Arc::try_unwrap(values) {
501 Ok(mask_values) => {
502 let bit_buffer = mask_values.into_buffer();
503 bit_buffer.into_mut()
504 }
505 Err(arc_mask_values) => {
506 let bit_buffer = arc_mask_values.bit_buffer();
507 BitBufferMut::copy_from(bit_buffer)
508 }
509 };
510
511 MaskMut(Inner::Builder(bit_buffer_mut))
512 }
513 }
514 }
515}
516
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_split_off_empty() {
        let mut mask = MaskMut::with_capacity(10);
        assert_eq!(mask.len(), 0);

        let other = mask.split_off(0);
        assert_eq!(mask.len(), 0);
        assert_eq!(other.len(), 0);
    }

    #[test]
    fn test_split_off_constant_true_at_zero() {
        let mut mask = MaskMut::new_true(10);
        let other = mask.split_off(0);

        assert_eq!(mask.len(), 0);
        assert_eq!(other.len(), 10);

        let frozen = other.freeze();
        assert_eq!(frozen.true_count(), 10);
    }

    #[test]
    fn test_split_off_constant_true_at_end() {
        let mut mask = MaskMut::new_true(10);
        let other = mask.split_off(10);

        assert_eq!(mask.len(), 10);
        assert_eq!(other.len(), 0);

        let frozen = mask.freeze();
        assert_eq!(frozen.true_count(), 10);
    }

    #[test]
    fn test_split_off_constant_true_in_middle() {
        let mut mask = MaskMut::new_true(10);
        let other = mask.split_off(6);

        assert_eq!(mask.len(), 6);
        assert_eq!(other.len(), 4);

        let frozen_first = mask.freeze();
        assert_eq!(frozen_first.true_count(), 6);

        let frozen_second = other.freeze();
        assert_eq!(frozen_second.true_count(), 4);
    }

    #[test]
    fn test_split_off_constant_false() {
        let mut mask = MaskMut::new_false(20);
        let other = mask.split_off(12);

        assert_eq!(mask.len(), 12);
        assert_eq!(other.len(), 8);

        let frozen_first = mask.freeze();
        assert_eq!(frozen_first.true_count(), 0);

        let frozen_second = other.freeze();
        assert_eq!(frozen_second.true_count(), 0);
    }

    /// Splitting and re-joining a buffer-backed mask at a byte boundary should reuse the
    /// original allocation.
    #[test]
    fn test_split_off_builder_at_byte_boundary() {
        let mut mask = MaskMut::with_capacity(16);
        mask.append_n(true, 8);
        mask.append_n(false, 8);

        let mask_ptr = match &mask.0 {
            Inner::Builder(bits) => bits.as_slice().as_ptr(),
            _ => unreachable!(),
        };

        let other = mask.split_off(8);

        assert_eq!(mask.len(), 8);
        assert_eq!(other.len(), 8);

        mask.unsplit(other);
        let new_mask_ptr = match &mask.0 {
            Inner::Builder(bits) => bits.as_slice().as_ptr(),
            _ => unreachable!(),
        };
        assert_eq!(mask_ptr, new_mask_ptr);
    }

    #[test]
    fn test_split_off_builder_not_byte_aligned() {
        let mut mask = MaskMut::with_capacity(20);
        mask.append_n(true, 10);
        mask.append_n(false, 10);

        let other = mask.split_off(10);

        assert_eq!(mask.len(), 10);
        assert_eq!(other.len(), 10);

        let frozen_first = mask.freeze();
        assert_eq!(frozen_first.true_count(), 10);

        let frozen_second = other.freeze();
        assert_eq!(frozen_second.true_count(), 0);
    }

    #[test]
    fn test_split_off_builder_mixed_pattern() {
        let mut mask = MaskMut::with_capacity(15);
        // Alternating pattern starting with `true`: bits 0, 2, 4, ..., 14 are set.
        for i in 0..15 {
            mask.append_n(i % 2 == 0, 1);
        }

        let other = mask.split_off(7);

        assert_eq!(mask.len(), 7);
        assert_eq!(other.len(), 8);

        let frozen_first = mask.freeze();
        assert_eq!(frozen_first.true_count(), 4);

        let frozen_second = other.freeze();
        assert_eq!(frozen_second.true_count(), 4);
    }

    #[test]
    fn test_unsplit_empty_with_empty() {
        let mut mask = MaskMut::with_capacity(10);
        let other = MaskMut::with_capacity(10);

        mask.unsplit(other);
        assert_eq!(mask.len(), 0);
    }

    #[test]
    fn test_unsplit_empty_with_constant() {
        let mut mask = MaskMut::with_capacity(10);
        let other = MaskMut::new_true(5);

        mask.unsplit(other);
        assert_eq!(mask.len(), 5);

        let frozen = mask.freeze();
        assert_eq!(frozen.true_count(), 5);
    }

    #[test]
    fn test_unsplit_constant_with_constant_same() {
        let mut mask = MaskMut::new_true(5);
        let other = MaskMut::new_true(5);

        mask.unsplit(other);
        assert_eq!(mask.len(), 10);

        let frozen = mask.freeze();
        assert_eq!(frozen.true_count(), 10);
    }

    #[test]
    fn test_unsplit_constant_with_constant_different() {
        let mut mask = MaskMut::new_true(5);
        let other = MaskMut::new_false(5);

        mask.unsplit(other);
        assert_eq!(mask.len(), 10);

        let frozen = mask.freeze();
        assert_eq!(frozen.true_count(), 5);
    }

    #[test]
    fn test_unsplit_constant_with_builder() {
        let mut mask = MaskMut::new_true(5);

        let mut other = MaskMut::with_capacity(10);
        other.append_n(true, 3);
        other.append_n(false, 2);

        mask.unsplit(other);
        assert_eq!(mask.len(), 10);

        let frozen = mask.freeze();
        assert_eq!(frozen.true_count(), 8);
    }

    #[test]
    fn test_unsplit_builder_with_constant() {
        let mut mask = MaskMut::with_capacity(10);
        mask.append_n(true, 3);
        mask.append_n(false, 2);

        let other = MaskMut::new_true(5);

        mask.unsplit(other);
        assert_eq!(mask.len(), 10);

        let frozen = mask.freeze();
        assert_eq!(frozen.true_count(), 8);
    }

    #[test]
    fn test_unsplit_builder_with_builder() {
        let mut mask = MaskMut::with_capacity(10);
        mask.append_n(true, 3);
        mask.append_n(false, 2);

        let mut other = MaskMut::with_capacity(10);
        other.append_n(false, 3);
        other.append_n(true, 2);

        mask.unsplit(other);
        assert_eq!(mask.len(), 10);

        let frozen = mask.freeze();
        assert_eq!(frozen.true_count(), 5);
    }

    #[test]
    fn test_round_trip_split_unsplit() {
        let mut original = MaskMut::with_capacity(20);
        original.append_n(true, 10);
        original.append_n(false, 10);

        let original_frozen = original.freeze();
        let original_true_count = original_frozen.true_count();

        let mut mask = original_frozen
            .try_into_mut()
            .expect("uniquely owned mask should convert to MaskMut");

        let other = mask.split_off(10);
        mask.unsplit(other);

        assert_eq!(mask.len(), 20);
        let frozen = mask.freeze();
        assert_eq!(frozen.true_count(), original_true_count);
    }

    #[test]
    #[should_panic(expected = "split_off index out of bounds")]
    fn test_split_off_out_of_bounds() {
        let mut mask = MaskMut::new_true(10);
        mask.split_off(11);
    }

    #[test]
    fn test_split_off_builder_at_bit_1() {
        let mut mask = MaskMut::with_capacity(16);
        mask.append_n(true, 16);

        let other = mask.split_off(1);

        assert_eq!(mask.len(), 1);
        assert_eq!(other.len(), 15);

        let frozen_first = mask.freeze();
        assert_eq!(frozen_first.true_count(), 1);

        let frozen_second = other.freeze();
        assert_eq!(frozen_second.true_count(), 15);
    }

    #[test]
    fn test_multiple_split_unsplit() {
        let mut mask = MaskMut::new_true(30);

        let third = mask.split_off(20);
        let second = mask.split_off(10);
        assert_eq!(mask.len(), 10);
        assert_eq!(second.len(), 10);
        assert_eq!(third.len(), 10);

        mask.unsplit(second);
        mask.unsplit(third);

        assert_eq!(mask.len(), 30);
        let frozen = mask.freeze();
        assert_eq!(frozen.true_count(), 30);
    }

    #[test]
    fn test_try_into_mut_all_variants() {
        let mask_true = Mask::new_true(100);
        let mut_mask_true = mask_true
            .try_into_mut()
            .expect("all-true mask should convert to MaskMut");
        assert_eq!(mut_mask_true.len(), 100);
        assert_eq!(mut_mask_true.freeze().true_count(), 100);

        let mask_false = Mask::new_false(50);
        let mut_mask_false = mask_false
            .try_into_mut()
            .expect("all-false mask should convert to MaskMut");
        assert_eq!(mut_mask_false.len(), 50);
        assert_eq!(mut_mask_false.freeze().true_count(), 0);
    }

    #[test]
    fn test_try_into_mut_with_references() {
        let mut mask_mut = MaskMut::with_capacity(10);
        mask_mut.append_n(true, 5);
        mask_mut.append_n(false, 5);
        let mask = mask_mut.freeze();

        // A uniquely owned mask converts back into a mutable one.
        let mask2 = {
            let mut mask_mut2 = MaskMut::with_capacity(10);
            mask_mut2.append_n(true, 5);
            mask_mut2.append_n(false, 5);
            mask_mut2.freeze()
        };
        let converted = mask2
            .try_into_mut()
            .expect("uniquely owned mask should convert to MaskMut");
        assert_eq!(converted.len(), 10);

        // While a clone is alive the values are shared, so the conversion must fail and hand
        // the mask back intact.
        let _cloned = mask.clone();
        let returned_mask = mask
            .try_into_mut()
            .expect_err("shared mask should not convert to MaskMut");
        assert_eq!(returned_mask.len(), 10);
        assert_eq!(returned_mask.true_count(), 5);
    }

    #[test]
    fn test_try_into_mut_round_trip() {
        let mut original = MaskMut::with_capacity(20);
        original.append_n(true, 10);
        original.append_n(false, 10);

        let frozen = original.freeze();
        assert_eq!(frozen.true_count(), 10);

        let mut mut_mask = frozen
            .try_into_mut()
            .expect("uniquely owned mask should convert to MaskMut");
        mut_mask.append_n(true, 5);
        assert_eq!(mut_mask.len(), 25);

        let frozen_again = mut_mask.freeze();
        assert_eq!(frozen_again.true_count(), 15);
    }

    #[test]
    fn test_constant_append_same_value_stays_constant() {
        let mut mask = MaskMut::new_false(3);
        mask.append_n(false, 2);

        assert_eq!(mask.len(), 5);
        assert!(mask.as_bit_buffer_mut().is_none());
        assert!(mask.all_false());
        assert!(!mask.all_true());
    }

    #[test]
    fn test_set_to_on_constant() {
        let mut mask = MaskMut::new_true(4);
        assert!(mask.value(3));

        // Writing the value the constant already holds keeps it constant.
        mask.set(2);
        assert!(mask.as_bit_buffer_mut().is_none());

        // Writing a different value materializes it into a buffer.
        mask.unset(1);
        assert!(mask.as_bit_buffer_mut().is_some());
        assert!(mask.value(0));
        assert!(!mask.value(1));
        assert!(mask.value(3));
        assert_eq!(mask.freeze().true_count(), 3);
    }

    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn test_value_out_of_bounds_constant() {
        let mask = MaskMut::new_true(3);
        let _ = mask.value(3);
    }

    #[test]
    fn test_truncate_and_clear() {
        let mut mask = MaskMut::new_true(10);

        // Truncating past the end is a no-op.
        mask.truncate(20);
        assert_eq!(mask.len(), 10);

        mask.truncate(4);
        assert_eq!(mask.len(), 4);
        assert_eq!(mask.capacity(), 10);

        mask.clear();
        assert!(mask.is_empty());
        assert_eq!(mask.capacity(), 10);
    }

    #[test]
    fn test_append_mask() {
        let mut mask = MaskMut::new_true(2);
        mask.append_mask(&Mask::new_false(3));
        mask.append_mask(&Mask::new_true(1));

        let mut source = MaskMut::with_capacity(4);
        source.append_n(true, 1);
        source.append_n(false, 3);
        mask.append_mask(&source.freeze());

        assert_eq!(mask.len(), 10);
        assert_eq!(mask.freeze().true_count(), 4);
    }
}