1use std::sync::Arc;
5
6use vortex_buffer::BitBufferMut;
7use vortex_error::vortex_panic;
8
9use crate::Mask;
10
/// A mutable, growable boolean mask.
///
/// The mask starts in a compact representation (empty, or a run of a single constant
/// value) and only allocates a `BitBufferMut` once it has to hold differing values.
/// Call [`MaskMut::freeze`] to obtain an immutable [`Mask`].
#[derive(Debug, Clone)]
pub struct MaskMut(Inner);
14
/// Internal representation of a [`MaskMut`].
#[derive(Debug, Clone)]
enum Inner {
    /// No values yet; `capacity` is the capacity to reserve if the mask is materialized.
    Empty { capacity: usize },
    /// `len` values that are all equal to `value`, with `capacity` values reserved.
    Constant {
        value: bool,
        len: usize,
        capacity: usize,
    },
    /// Materialized per-value storage, used once the mask holds differing values.
    Builder(BitBufferMut),
}
29
30impl MaskMut {
31 pub fn empty() -> Self {
33 Self::with_capacity(0)
34 }
35
36 pub fn with_capacity(capacity: usize) -> Self {
38 Self(Inner::Empty { capacity })
39 }
40
41 pub fn new_true(len: usize) -> Self {
43 Self(Inner::Constant {
44 value: true,
45 len,
46 capacity: len,
47 })
48 }
49
50 pub fn new_false(len: usize) -> Self {
52 Self(Inner::Constant {
53 value: false,
54 len,
55 capacity: len,
56 })
57 }
58
59 pub fn value(&self, index: usize) -> bool {
65 match &self.0 {
66 Inner::Empty { .. } => {
67 vortex_panic!("index out of bounds: the length is 0 but the index is {index}")
68 }
69 Inner::Constant { value, len, .. } => {
70 assert!(
71 index < *len,
72 "index out of bounds: the length is {} but the index is {index}",
73 *len
74 );
75
76 *value
77 }
78 Inner::Builder(bit_buffer) => bit_buffer.value(index),
79 }
80 }
81
82 pub fn reserve(&mut self, additional: usize) {
84 match &mut self.0 {
85 Inner::Empty { capacity } => {
86 *capacity += additional;
87 }
88 Inner::Constant { capacity, .. } => {
89 *capacity += additional;
90 }
91 Inner::Builder(bits) => {
92 bits.reserve(additional);
93 }
94 }
95 }
96
97 pub unsafe fn set_len(&mut self, new_len: usize) {
106 debug_assert!(new_len < self.capacity());
107 match &mut self.0 {
108 Inner::Empty { capacity, .. } => {
109 self.0 = Inner::Constant {
110 value: false, len: new_len,
112 capacity: *capacity,
113 }
114 }
115 Inner::Constant { len, .. } => {
116 *len = new_len;
117 }
118 Inner::Builder(bits) => {
119 unsafe { bits.set_len(new_len) };
120 }
121 }
122 }
123
124 pub fn capacity(&self) -> usize {
126 match &self.0 {
127 Inner::Empty { capacity } => *capacity,
128 Inner::Constant { capacity, .. } => *capacity,
129 Inner::Builder(bits) => bits.capacity(),
130 }
131 }
132
133 pub fn clear(&mut self) {
137 match &mut self.0 {
138 Inner::Empty { .. } => {}
139 Inner::Constant { capacity, .. } => {
140 self.0 = Inner::Empty {
141 capacity: *capacity,
142 }
143 }
144 Inner::Builder(bit_buffer) => bit_buffer.clear(),
145 };
146 }
147
148 pub fn truncate(&mut self, len: usize) {
154 let truncated_len = len;
155 if truncated_len > self.len() {
156 return;
157 }
158
159 match &mut self.0 {
160 Inner::Empty { .. } => {}
161 Inner::Constant { len, .. } => *len = truncated_len.min(*len),
162 Inner::Builder(bit_buffer) => bit_buffer.truncate(truncated_len),
163 };
164 }
165
166 pub fn append_n(&mut self, new_value: bool, n: usize) {
168 match &mut self.0 {
169 Inner::Empty { capacity } => {
170 self.0 = Inner::Constant {
171 value: new_value,
172 len: n,
173 capacity: (*capacity).max(n),
174 }
175 }
176 Inner::Constant {
177 value,
178 len,
179 capacity,
180 } => {
181 if *value == new_value {
182 self.0 = Inner::Constant {
184 value: *value,
185 len: *len + n,
186 capacity: (*capacity).max(*len + n),
187 }
188 } else {
189 let bits = self.materialize();
192 bits.append_n(new_value, n);
193 }
194 }
195 Inner::Builder(bits) => {
196 bits.append_n(new_value, n);
197 }
198 }
199 }
200
201 pub fn append_mask(&mut self, other: &Mask) {
203 match other {
204 Mask::AllTrue(len) => self.append_n(true, *len),
205 Mask::AllFalse(len) => self.append_n(false, *len),
206 Mask::Values(values) => {
207 let bitbuffer = values.buffer.clone();
208 self.materialize().append_buffer(&bitbuffer);
209 }
210 }
211 }
212
213 fn materialize(&mut self) -> &mut BitBufferMut {
215 let needs_materialization = !matches!(self.0, Inner::Builder(_));
216
217 if needs_materialization {
218 let new_builder = match &self.0 {
219 Inner::Empty { capacity } => BitBufferMut::with_capacity(*capacity),
220 Inner::Constant {
221 value,
222 len,
223 capacity,
224 } => {
225 let required_capacity = (*capacity).max(*len);
226 let mut bits = BitBufferMut::with_capacity(required_capacity);
227 bits.append_n(*value, *len);
228 bits
229 }
230 Inner::Builder(_) => unreachable!(),
231 };
232 self.0 = Inner::Builder(new_builder);
233 }
234
235 match &mut self.0 {
236 Inner::Builder(bits) => bits,
237 _ => unreachable!(),
238 }
239 }
240
241 pub fn split_off(&mut self, at: usize) -> Self {
245 assert!(at <= self.capacity(), "split_off index out of bounds");
246 match &mut self.0 {
247 Inner::Empty { capacity } => {
248 let new_capacity = *capacity - at;
249 *capacity = at;
250 Self(Inner::Empty {
251 capacity: new_capacity,
252 })
253 }
254 Inner::Constant {
255 value,
256 len,
257 capacity,
258 } => {
259 let new_len = len.saturating_sub(at);
261 let new_capacity = *capacity - at;
262 *len = (*len).min(at);
263 *capacity = at;
264
265 Self(Inner::Constant {
266 value: *value,
267 len: new_len,
268 capacity: new_capacity,
269 })
270 }
271 Inner::Builder(bits) => {
272 let new_bits = bits.split_off(at);
273 Self(Inner::Builder(new_bits))
274 }
275 }
276 }
277
278 pub fn unsplit(&mut self, other: Self) {
280 match other.0 {
281 Inner::Empty { .. } => {
282 }
284 Inner::Constant { value, len, .. } => {
285 self.append_n(value, len);
286 }
287 Inner::Builder(bits) => {
288 self.materialize().unsplit(bits);
289 }
290 }
291 }
292
293 pub fn freeze(self) -> Mask {
295 match self.0 {
296 Inner::Empty { .. } => Mask::new_true(0),
297 Inner::Constant { value, len, .. } => {
298 if value {
299 Mask::new_true(len)
300 } else {
301 Mask::new_false(len)
302 }
303 }
304 Inner::Builder(bits) => Mask::from_buffer(bits.freeze()),
305 }
306 }
307
308 pub fn len(&self) -> usize {
310 match &self.0 {
311 Inner::Empty { .. } => 0,
312 Inner::Constant { len, .. } => *len,
313 Inner::Builder(bits) => bits.len(),
314 }
315 }
316
317 pub fn is_empty(&self) -> bool {
319 self.len() == 0
320 }
321
322 pub fn all_true(&self) -> bool {
324 match &self.0 {
325 Inner::Empty { .. } => true,
326 Inner::Constant { value, .. } => *value,
327 Inner::Builder(bits) => bits.true_count() == bits.len(),
328 }
329 }
330
331 pub fn all_false(&self) -> bool {
333 match &self.0 {
334 Inner::Empty { .. } => true,
335 Inner::Constant { value, .. } => !*value,
336 Inner::Builder(bits) => !bits.is_empty() && bits.true_count() == 0,
337 }
338 }
339}
340
341impl Mask {
342 pub fn try_into_mut(self) -> Result<MaskMut, Self> {
345 match self {
346 Mask::AllTrue(len) => Ok(MaskMut::new_true(len)),
347 Mask::AllFalse(len) => Ok(MaskMut::new_false(len)),
348 Mask::Values(values) => {
349 let owned_values = Arc::try_unwrap(values).map_err(Mask::Values)?;
352 let bit_buffer = owned_values.into_buffer();
353 let mut_buffer = bit_buffer.try_into_mut().map_err(Mask::from_buffer)?;
354
355 Ok(MaskMut(Inner::Builder(mut_buffer)))
356 }
357 }
358 }
359
360 pub fn into_mut(self) -> MaskMut {
363 match self {
364 Mask::AllTrue(len) => MaskMut::new_true(len),
365 Mask::AllFalse(len) => MaskMut::new_false(len),
366 Mask::Values(values) => {
367 let bit_buffer_mut = match Arc::try_unwrap(values) {
368 Ok(mask_values) => {
369 let bit_buffer = mask_values.into_buffer();
370 bit_buffer.into_mut()
371 }
372 Err(arc_mask_values) => {
373 let bit_buffer = arc_mask_values.bit_buffer();
374 BitBufferMut::copy_from(bit_buffer)
375 }
376 };
377
378 MaskMut(Inner::Builder(bit_buffer_mut))
379 }
380 }
381 }
382}
383
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_split_off_empty() {
        let mut mask = MaskMut::with_capacity(10);
        assert_eq!(mask.len(), 0);

        let other = mask.split_off(0);
        assert_eq!(mask.len(), 0);
        assert_eq!(other.len(), 0);
    }

    #[test]
    fn test_split_off_constant_true_at_zero() {
        let mut mask = MaskMut::new_true(10);
        let other = mask.split_off(0);

        assert_eq!(mask.len(), 0);
        assert_eq!(other.len(), 10);

        let frozen = other.freeze();
        assert_eq!(frozen.true_count(), 10);
    }

    #[test]
    fn test_split_off_constant_true_at_end() {
        let mut mask = MaskMut::new_true(10);
        let other = mask.split_off(10);

        assert_eq!(mask.len(), 10);
        assert_eq!(other.len(), 0);

        let frozen = mask.freeze();
        assert_eq!(frozen.true_count(), 10);
    }

    #[test]
    fn test_split_off_constant_true_in_middle() {
        let mut mask = MaskMut::new_true(10);
        let other = mask.split_off(6);

        assert_eq!(mask.len(), 6);
        assert_eq!(other.len(), 4);

        let frozen_first = mask.freeze();
        assert_eq!(frozen_first.true_count(), 6);

        let frozen_second = other.freeze();
        assert_eq!(frozen_second.true_count(), 4);
    }

    #[test]
    fn test_split_off_constant_false() {
        let mut mask = MaskMut::new_false(20);
        let other = mask.split_off(12);

        assert_eq!(mask.len(), 12);
        assert_eq!(other.len(), 8);

        let frozen_first = mask.freeze();
        assert_eq!(frozen_first.true_count(), 0);

        let frozen_second = other.freeze();
        assert_eq!(frozen_second.true_count(), 0);
    }

    #[test]
    fn test_split_off_builder_at_byte_boundary() {
        let mut mask = MaskMut::with_capacity(16);
        mask.append_n(true, 8);
        // Appending a different value materializes the builder.
        mask.append_n(false, 8);

        let mask_ptr = match &mask.0 {
            Inner::Builder(bits) => bits.as_slice().as_ptr(),
            _ => unreachable!(),
        };

        let other = mask.split_off(8);

        assert_eq!(mask.len(), 8);
        assert_eq!(other.len(), 8);

        mask.unsplit(other);
        // Re-joining the halves should keep using the original allocation.
        let new_mask_ptr = match &mask.0 {
            Inner::Builder(bits) => bits.as_slice().as_ptr(),
            _ => unreachable!(),
        };
        assert_eq!(mask_ptr, new_mask_ptr);
    }

    #[test]
    fn test_split_off_builder_not_byte_aligned() {
        let mut mask = MaskMut::with_capacity(20);
        mask.append_n(true, 10);
        mask.append_n(false, 10);

        let other = mask.split_off(10);

        assert_eq!(mask.len(), 10);
        assert_eq!(other.len(), 10);

        let frozen_first = mask.freeze();
        assert_eq!(frozen_first.true_count(), 10);

        let frozen_second = other.freeze();
        assert_eq!(frozen_second.true_count(), 0);
    }

    #[test]
    fn test_split_off_builder_mixed_pattern() {
        let mut mask = MaskMut::with_capacity(15);
        for i in 0..15 {
            // Alternating true/false, starting with true at index 0.
            mask.append_n(i % 2 == 0, 1);
        }

        let other = mask.split_off(7);

        assert_eq!(mask.len(), 7);
        assert_eq!(other.len(), 8);

        // Indices 0, 2, 4, 6 are true.
        let frozen_first = mask.freeze();
        assert_eq!(frozen_first.true_count(), 4);

        // Indices 8, 10, 12, 14 are true.
        let frozen_second = other.freeze();
        assert_eq!(frozen_second.true_count(), 4);
    }

    #[test]
    fn test_unsplit_empty_with_empty() {
        let mut mask = MaskMut::with_capacity(10);
        let other = MaskMut::with_capacity(10);

        mask.unsplit(other);
        assert_eq!(mask.len(), 0);
    }

    #[test]
    fn test_unsplit_empty_with_constant() {
        let mut mask = MaskMut::with_capacity(10);
        let other = MaskMut::new_true(5);

        mask.unsplit(other);
        assert_eq!(mask.len(), 5);

        let frozen = mask.freeze();
        assert_eq!(frozen.true_count(), 5);
    }

    #[test]
    fn test_unsplit_constant_with_constant_same() {
        let mut mask = MaskMut::new_true(5);
        let other = MaskMut::new_true(5);

        mask.unsplit(other);
        assert_eq!(mask.len(), 10);

        let frozen = mask.freeze();
        assert_eq!(frozen.true_count(), 10);
    }

    #[test]
    fn test_unsplit_constant_with_constant_different() {
        let mut mask = MaskMut::new_true(5);
        let other = MaskMut::new_false(5);

        mask.unsplit(other);
        assert_eq!(mask.len(), 10);

        let frozen = mask.freeze();
        assert_eq!(frozen.true_count(), 5);
    }

    #[test]
    fn test_unsplit_constant_with_builder() {
        let mut mask = MaskMut::new_true(5);

        let mut other = MaskMut::with_capacity(10);
        other.append_n(true, 3);
        other.append_n(false, 2);

        mask.unsplit(other);
        assert_eq!(mask.len(), 10);

        // 5 from `mask` plus 3 from `other`.
        let frozen = mask.freeze();
        assert_eq!(frozen.true_count(), 8);
    }

    #[test]
    fn test_unsplit_builder_with_constant() {
        let mut mask = MaskMut::with_capacity(10);
        mask.append_n(true, 3);
        mask.append_n(false, 2);

        let other = MaskMut::new_true(5);

        mask.unsplit(other);
        assert_eq!(mask.len(), 10);

        // 3 from `mask` plus 5 from `other`.
        let frozen = mask.freeze();
        assert_eq!(frozen.true_count(), 8);
    }

    #[test]
    fn test_unsplit_builder_with_builder() {
        let mut mask = MaskMut::with_capacity(10);
        mask.append_n(true, 3);
        mask.append_n(false, 2);

        let mut other = MaskMut::with_capacity(10);
        other.append_n(false, 3);
        other.append_n(true, 2);

        mask.unsplit(other);
        assert_eq!(mask.len(), 10);

        // 3 from `mask` plus 2 from `other`.
        let frozen = mask.freeze();
        assert_eq!(frozen.true_count(), 5);
    }

    #[test]
    fn test_round_trip_split_unsplit() {
        let mut original = MaskMut::with_capacity(20);
        original.append_n(true, 10);
        original.append_n(false, 10);

        let original_frozen = original.freeze();
        let original_true_count = original_frozen.true_count();

        let mut mask = original_frozen.try_into_mut().unwrap();

        let other = mask.split_off(10);

        mask.unsplit(other);

        assert_eq!(mask.len(), 20);
        let frozen = mask.freeze();
        assert_eq!(frozen.true_count(), original_true_count);
    }

    #[test]
    #[should_panic(expected = "split_off index out of bounds")]
    fn test_split_off_out_of_bounds() {
        let mut mask = MaskMut::new_true(10);
        mask.split_off(11);
    }

    #[test]
    fn test_split_off_builder_at_bit_1() {
        let mut mask = MaskMut::with_capacity(16);
        mask.append_n(true, 16);

        let other = mask.split_off(1);

        assert_eq!(mask.len(), 1);
        assert_eq!(other.len(), 15);

        let frozen_first = mask.freeze();
        assert_eq!(frozen_first.true_count(), 1);

        let frozen_second = other.freeze();
        assert_eq!(frozen_second.true_count(), 15);
    }

    #[test]
    fn test_multiple_split_unsplit() {
        let mut mask = MaskMut::new_true(30);

        let third = mask.split_off(20);
        let second = mask.split_off(10);

        assert_eq!(mask.len(), 10);
        assert_eq!(second.len(), 10);
        assert_eq!(third.len(), 10);

        mask.unsplit(second);
        mask.unsplit(third);

        assert_eq!(mask.len(), 30);
        let frozen = mask.freeze();
        assert_eq!(frozen.true_count(), 30);
    }

    #[test]
    fn test_try_into_mut_all_variants() {
        let mask_true = Mask::new_true(100);
        let mut_mask_true = mask_true.try_into_mut().unwrap();
        assert_eq!(mut_mask_true.len(), 100);
        assert_eq!(mut_mask_true.freeze().true_count(), 100);

        let mask_false = Mask::new_false(50);
        let mut_mask_false = mask_false.try_into_mut().unwrap();
        assert_eq!(mut_mask_false.len(), 50);
        assert_eq!(mut_mask_false.freeze().true_count(), 0);
    }

    #[test]
    fn test_try_into_mut_with_references() {
        let mut mask_mut = MaskMut::with_capacity(10);
        mask_mut.append_n(true, 5);
        mask_mut.append_n(false, 5);
        let mask = mask_mut.freeze();

        // A uniquely-owned values mask converts back without copying.
        let mask2 = {
            let mut mask_mut2 = MaskMut::with_capacity(10);
            mask_mut2.append_n(true, 5);
            mask_mut2.append_n(false, 5);
            mask_mut2.freeze()
        };
        let result = mask2.try_into_mut();
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 10);

        // A shared values mask cannot be converted and is handed back intact.
        let _cloned = mask.clone();
        let result = mask.try_into_mut();
        assert!(result.is_err());
        if let Err(returned_mask) = result {
            assert_eq!(returned_mask.len(), 10);
            assert_eq!(returned_mask.true_count(), 5);
        }
    }

    #[test]
    fn test_try_into_mut_round_trip() {
        let mut original = MaskMut::with_capacity(20);
        original.append_n(true, 10);
        original.append_n(false, 10);

        let frozen = original.freeze();
        assert_eq!(frozen.true_count(), 10);

        let mut mut_mask = frozen.try_into_mut().unwrap();
        mut_mask.append_n(true, 5);
        assert_eq!(mut_mask.len(), 25);

        let frozen_again = mut_mask.freeze();
        assert_eq!(frozen_again.true_count(), 15);
    }

    #[test]
    fn test_append_n_same_value_stays_constant() {
        let mut mask = MaskMut::new_true(3);
        mask.append_n(true, 2);

        assert!(matches!(
            mask.0,
            Inner::Constant {
                value: true,
                len: 5,
                ..
            }
        ));
        assert!(mask.all_true());
        assert!(!mask.all_false());
    }

    #[test]
    fn test_append_n_different_value_materializes() {
        let mut mask = MaskMut::new_true(3);
        mask.append_n(false, 2);

        assert!(matches!(mask.0, Inner::Builder(_)));
        assert_eq!(mask.len(), 5);
        assert!(mask.value(2));
        assert!(!mask.value(3));
        assert!(!mask.all_true());
        assert!(!mask.all_false());
    }

    #[test]
    fn test_append_mask_all_variants() {
        let mut mask = MaskMut::empty();
        mask.append_mask(&Mask::new_true(3));
        mask.append_mask(&Mask::new_false(2));

        let values = {
            let mut mixed = MaskMut::with_capacity(4);
            mixed.append_n(true, 1);
            mixed.append_n(false, 3);
            mixed.freeze()
        };
        mask.append_mask(&values);

        assert_eq!(mask.len(), 9);
        assert_eq!(mask.freeze().true_count(), 4);
    }

    #[test]
    fn test_truncate_and_clear() {
        let mut mask = MaskMut::new_false(10);

        // Truncating past the end is a no-op.
        mask.truncate(20);
        assert_eq!(mask.len(), 10);

        mask.truncate(4);
        assert_eq!(mask.len(), 4);

        mask.clear();
        assert!(mask.is_empty());
        assert_eq!(mask.capacity(), 10);
    }

    #[test]
    fn test_reserve_grows_capacity() {
        let mut empty = MaskMut::with_capacity(4);
        empty.reserve(6);
        assert!(empty.capacity() >= 10);

        let mut constant = MaskMut::new_true(4);
        constant.reserve(6);
        assert!(constant.capacity() >= 10);
    }

    #[test]
    fn test_set_len_on_empty_reads_false() {
        let mut mask = MaskMut::with_capacity(8);
        // SAFETY: 4 is within the reserved capacity of 8.
        unsafe { mask.set_len(4) };

        assert_eq!(mask.len(), 4);
        assert!(!mask.value(3));
        assert!(mask.all_false());
    }

    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn test_value_out_of_bounds_constant() {
        let mask = MaskMut::new_true(3);
        let _ = mask.value(3);
    }
}