vortex_mask/
mask_mut.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use vortex_buffer::BitBufferMut;
7use vortex_error::vortex_panic;
8
9use crate::Mask;
10
11/// A mutable mask, used for lazily allocating the bit buffer as required.
12#[derive(Debug, Clone)]
13pub struct MaskMut(Inner);
14
15#[derive(Debug, Clone)]
16enum Inner {
17    /// Initially, the mask is empty but may have some capacity.
18    Empty { capacity: usize },
19    /// When the first value is pushed, the mask becomes constant.
20    Constant {
21        value: bool,
22        len: usize,
23        capacity: usize,
24    },
25    /// When the first non-constant value is written, we allocate the bit buffer and switch
26    /// into the builder state.
27    Builder(BitBufferMut),
28}
29
30impl MaskMut {
31    /// Creates a new empty mask.
32    pub fn empty() -> Self {
33        Self::with_capacity(0)
34    }
35
36    /// Creates a new empty mask with the default capacity.
37    pub fn with_capacity(capacity: usize) -> Self {
38        Self(Inner::Empty { capacity })
39    }
40
41    /// Creates a new mask with all values set to `true`.
42    pub fn new_true(len: usize) -> Self {
43        Self(Inner::Constant {
44            value: true,
45            len,
46            capacity: len,
47        })
48    }
49
50    /// Creates a new mask with all values set to `false`.
51    pub fn new_false(len: usize) -> Self {
52        Self(Inner::Constant {
53            value: false,
54            len,
55            capacity: len,
56        })
57    }
58
59    /// Returns the boolean value at a given index.
60    ///
61    /// # Panics
62    ///
63    /// Panics if the index is out of bounds.
64    pub fn value(&self, index: usize) -> bool {
65        match &self.0 {
66            Inner::Empty { .. } => {
67                vortex_panic!("index out of bounds: the length is 0 but the index is {index}")
68            }
69            Inner::Constant { value, len, .. } => {
70                assert!(
71                    index < *len,
72                    "index out of bounds: the length is {} but the index is {index}",
73                    *len
74                );
75
76                *value
77            }
78            Inner::Builder(bit_buffer) => bit_buffer.value(index),
79        }
80    }
81
82    /// Reserve capacity for at least `additional` more values to be appended.
83    pub fn reserve(&mut self, additional: usize) {
84        match &mut self.0 {
85            Inner::Empty { capacity } => {
86                *capacity += additional;
87            }
88            Inner::Constant { capacity, .. } => {
89                *capacity += additional;
90            }
91            Inner::Builder(bits) => {
92                bits.reserve(additional);
93            }
94        }
95    }
96
97    /// Set the length of the mask.
98    ///
99    /// # Safety
100    ///
101    /// - `new_len` must be less than or equal to [`capacity()`].
102    /// - The elements at `old_len..new_len` must be initialized.
103    ///
104    /// [`capacity()`]: Self::capacity
105    pub unsafe fn set_len(&mut self, new_len: usize) {
106        debug_assert!(new_len < self.capacity());
107        match &mut self.0 {
108            Inner::Empty { capacity, .. } => {
109                self.0 = Inner::Constant {
110                    value: false, // Pick any value
111                    len: new_len,
112                    capacity: *capacity,
113                }
114            }
115            Inner::Constant { len, .. } => {
116                *len = new_len;
117            }
118            Inner::Builder(bits) => {
119                unsafe { bits.set_len(new_len) };
120            }
121        }
122    }
123
124    /// Returns the capacity of the mask.
125    pub fn capacity(&self) -> usize {
126        match &self.0 {
127            Inner::Empty { capacity } => *capacity,
128            Inner::Constant { capacity, .. } => *capacity,
129            Inner::Builder(bits) => bits.capacity(),
130        }
131    }
132
133    /// Clears the mask.
134    ///
135    /// Note that this method has no effect on the allocated capacity of the mask.
136    pub fn clear(&mut self) {
137        match &mut self.0 {
138            Inner::Empty { .. } => {}
139            Inner::Constant { capacity, .. } => {
140                self.0 = Inner::Empty {
141                    capacity: *capacity,
142                }
143            }
144            Inner::Builder(bit_buffer) => bit_buffer.clear(),
145        };
146    }
147
148    /// Shortens the mask, keeping the first `len` bits.
149    ///
150    /// If `len` is greater or equal to the vector’s current length, this has no effect.
151    ///
152    /// Note that this method has no effect on the allocated capacity of the mask.
153    pub fn truncate(&mut self, len: usize) {
154        let truncated_len = len;
155        if truncated_len > self.len() {
156            return;
157        }
158
159        match &mut self.0 {
160            Inner::Empty { .. } => {}
161            Inner::Constant { len, .. } => *len = truncated_len.min(*len),
162            Inner::Builder(bit_buffer) => bit_buffer.truncate(truncated_len),
163        };
164    }
165
166    /// Append n values to the mask.
167    pub fn append_n(&mut self, new_value: bool, n: usize) {
168        match &mut self.0 {
169            Inner::Empty { capacity } => {
170                self.0 = Inner::Constant {
171                    value: new_value,
172                    len: n,
173                    capacity: (*capacity).max(n),
174                }
175            }
176            Inner::Constant {
177                value,
178                len,
179                capacity,
180            } => {
181                if *value == new_value {
182                    // Same value, just increase length.
183                    self.0 = Inner::Constant {
184                        value: *value,
185                        len: *len + n,
186                        capacity: (*capacity).max(*len + n),
187                    }
188                } else {
189                    // Different value, need to allocate the bit buffer.
190                    // Note: materialize() already appends the existing constant values
191                    let bits = self.materialize();
192                    bits.append_n(new_value, n);
193                }
194            }
195            Inner::Builder(bits) => {
196                bits.append_n(new_value, n);
197            }
198        }
199    }
200
201    /// Append a [`Mask`] to this mutable mask.
202    pub fn append_mask(&mut self, other: &Mask) {
203        match other {
204            Mask::AllTrue(len) => self.append_n(true, *len),
205            Mask::AllFalse(len) => self.append_n(false, *len),
206            Mask::Values(values) => {
207                let bitbuffer = values.buffer.clone();
208                self.materialize().append_buffer(&bitbuffer);
209            }
210        }
211    }
212
213    /// Ensures that the internal bit buffer is allocated and returns a mutable reference to it.
214    fn materialize(&mut self) -> &mut BitBufferMut {
215        let needs_materialization = !matches!(self.0, Inner::Builder(_));
216
217        if needs_materialization {
218            let new_builder = match &self.0 {
219                Inner::Empty { capacity } => BitBufferMut::with_capacity(*capacity),
220                Inner::Constant {
221                    value,
222                    len,
223                    capacity,
224                } => {
225                    let required_capacity = (*capacity).max(*len);
226                    let mut bits = BitBufferMut::with_capacity(required_capacity);
227                    bits.append_n(*value, *len);
228                    bits
229                }
230                Inner::Builder(_) => unreachable!(),
231            };
232            self.0 = Inner::Builder(new_builder);
233        }
234
235        match &mut self.0 {
236            Inner::Builder(bits) => bits,
237            _ => unreachable!(),
238        }
239    }
240
241    /// Split-off the mask at the given index, returning a new mask with the
242    /// values from `at` to the end, and leaving `self` with the values from
243    /// the start to `at`.
244    pub fn split_off(&mut self, at: usize) -> Self {
245        assert!(at <= self.capacity(), "split_off index out of bounds");
246        match &mut self.0 {
247            Inner::Empty { capacity } => {
248                let new_capacity = *capacity - at;
249                *capacity = at;
250                Self(Inner::Empty {
251                    capacity: new_capacity,
252                })
253            }
254            Inner::Constant {
255                value,
256                len,
257                capacity,
258            } => {
259                // Adjust the lengths, given that length may be < at
260                let new_len = len.saturating_sub(at);
261                let new_capacity = *capacity - at;
262                *len = (*len).min(at);
263                *capacity = at;
264
265                Self(Inner::Constant {
266                    value: *value,
267                    len: new_len,
268                    capacity: new_capacity,
269                })
270            }
271            Inner::Builder(bits) => {
272                let new_bits = bits.split_off(at);
273                Self(Inner::Builder(new_bits))
274            }
275        }
276    }
277
278    /// Absorb another mask into this one, appending its values.
279    pub fn unsplit(&mut self, other: Self) {
280        match other.0 {
281            Inner::Empty { .. } => {
282                // No work to do
283            }
284            Inner::Constant { value, len, .. } => {
285                self.append_n(value, len);
286            }
287            Inner::Builder(bits) => {
288                self.materialize().unsplit(bits);
289            }
290        }
291    }
292
293    /// Freezes the mutable mask into an immutable one.
294    pub fn freeze(self) -> Mask {
295        match self.0 {
296            Inner::Empty { .. } => Mask::new_true(0),
297            Inner::Constant { value, len, .. } => {
298                if value {
299                    Mask::new_true(len)
300                } else {
301                    Mask::new_false(len)
302                }
303            }
304            Inner::Builder(bits) => Mask::from_buffer(bits.freeze()),
305        }
306    }
307
308    /// Returns the logical length of the mask.
309    pub fn len(&self) -> usize {
310        match &self.0 {
311            Inner::Empty { .. } => 0,
312            Inner::Constant { len, .. } => *len,
313            Inner::Builder(bits) => bits.len(),
314        }
315    }
316
317    /// Returns true if the mask is empty.
318    pub fn is_empty(&self) -> bool {
319        self.len() == 0
320    }
321
322    /// Returns true if all values in the mask are true.
323    pub fn all_true(&self) -> bool {
324        match &self.0 {
325            Inner::Empty { .. } => true,
326            Inner::Constant { value, .. } => *value,
327            Inner::Builder(bits) => bits.true_count() == bits.len(),
328        }
329    }
330
331    /// Returns true if all values in the mask are false.
332    pub fn all_false(&self) -> bool {
333        match &self.0 {
334            Inner::Empty { .. } => true,
335            Inner::Constant { value, .. } => !*value,
336            Inner::Builder(bits) => !bits.is_empty() && bits.true_count() == 0,
337        }
338    }
339}
340
341impl Mask {
342    /// Attempts to convert an immutable mask into a mutable one, returning an error of `Self` if
343    /// the underlying [`BitBuffer`](crate::BitBuffer) data if there are any other references.
344    pub fn try_into_mut(self) -> Result<MaskMut, Self> {
345        match self {
346            Mask::AllTrue(len) => Ok(MaskMut::new_true(len)),
347            Mask::AllFalse(len) => Ok(MaskMut::new_false(len)),
348            Mask::Values(values) => {
349                // We need to check for uniqueness twice, first for the `Arc` with `try_unwrap`,
350                // then for the internal `BitBuffer` with `try_into_mut`.
351                let owned_values = Arc::try_unwrap(values).map_err(Mask::Values)?;
352                let bit_buffer = owned_values.into_buffer();
353                let mut_buffer = bit_buffer.try_into_mut().map_err(Mask::from_buffer)?;
354
355                Ok(MaskMut(Inner::Builder(mut_buffer)))
356            }
357        }
358    }
359
360    /// Convert an immutable mask into a mutable one, cloning the underlying
361    /// [`BitBuffer`](crate::BitBuffer) data if there are any other references.
362    pub fn into_mut(self) -> MaskMut {
363        match self {
364            Mask::AllTrue(len) => MaskMut::new_true(len),
365            Mask::AllFalse(len) => MaskMut::new_false(len),
366            Mask::Values(values) => {
367                let bit_buffer_mut = match Arc::try_unwrap(values) {
368                    Ok(mask_values) => {
369                        let bit_buffer = mask_values.into_buffer();
370                        bit_buffer.into_mut()
371                    }
372                    Err(arc_mask_values) => {
373                        let bit_buffer = arc_mask_values.bit_buffer();
374                        BitBufferMut::copy_from(bit_buffer)
375                    }
376                };
377
378                MaskMut(Inner::Builder(bit_buffer_mut))
379            }
380        }
381    }
382}
383
384#[cfg(test)]
385mod tests {
386    use super::*;
387
388    #[test]
389    fn test_split_off_empty() {
390        let mut mask = MaskMut::with_capacity(10);
391        assert_eq!(mask.len(), 0);
392
393        let other = mask.split_off(0);
394        assert_eq!(mask.len(), 0);
395        assert_eq!(other.len(), 0);
396    }
397
398    #[test]
399    fn test_split_off_constant_true_at_zero() {
400        let mut mask = MaskMut::new_true(10);
401        let other = mask.split_off(0);
402
403        assert_eq!(mask.len(), 0);
404        assert_eq!(other.len(), 10);
405
406        let frozen = other.freeze();
407        assert_eq!(frozen.true_count(), 10);
408    }
409
410    #[test]
411    fn test_split_off_constant_true_at_end() {
412        let mut mask = MaskMut::new_true(10);
413        let other = mask.split_off(10);
414
415        assert_eq!(mask.len(), 10);
416        assert_eq!(other.len(), 0);
417
418        let frozen = mask.freeze();
419        assert_eq!(frozen.true_count(), 10);
420    }
421
422    #[test]
423    fn test_split_off_constant_true_in_middle() {
424        let mut mask = MaskMut::new_true(10);
425        let other = mask.split_off(6);
426
427        assert_eq!(mask.len(), 6);
428        assert_eq!(other.len(), 4);
429
430        let frozen_first = mask.freeze();
431        assert_eq!(frozen_first.true_count(), 6);
432
433        let frozen_second = other.freeze();
434        assert_eq!(frozen_second.true_count(), 4);
435    }
436
437    #[test]
438    fn test_split_off_constant_false() {
439        let mut mask = MaskMut::new_false(20);
440        let other = mask.split_off(12);
441
442        assert_eq!(mask.len(), 12);
443        assert_eq!(other.len(), 8);
444
445        let frozen_first = mask.freeze();
446        assert_eq!(frozen_first.true_count(), 0);
447
448        let frozen_second = other.freeze();
449        assert_eq!(frozen_second.true_count(), 0);
450    }
451
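    // Note: Tests using BitBuffer operations are meant to be ignored under miri
    // because bitvec uses raw pointer operations that miri cannot verify.
    // NOTE(review): no test in this module currently carries an attribute that ignores it
    // under miri — confirm whether the attribute should be added or this note removed.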
454    #[test]
455    fn test_split_off_builder_at_byte_boundary() {
456        let mut mask = MaskMut::with_capacity(16);
457        // Create a pattern: 8 true, 8 false
458        mask.append_n(true, 8);
459        mask.append_n(false, 8);
460
461        let mask_ptr = match &mask.0 {
462            Inner::Builder(bits) => bits.as_slice().as_ptr(),
463            _ => unreachable!(),
464        };
465
466        let other = mask.split_off(8);
467
468        assert_eq!(mask.len(), 8);
469        assert_eq!(other.len(), 8);
470
471        // Ensure the unsplit was zero-copy.
472        mask.unsplit(other);
473        let new_mask_ptr = match &mask.0 {
474            Inner::Builder(bits) => bits.as_slice().as_ptr(),
475            _ => unreachable!(),
476        };
477        assert_eq!(mask_ptr, new_mask_ptr);
478    }
479
480    #[test]
481    fn test_split_off_builder_not_byte_aligned() {
482        let mut mask = MaskMut::with_capacity(20);
483        // Create a pattern: 10 true, 10 false
484        mask.append_n(true, 10);
485        mask.append_n(false, 10);
486
487        let other = mask.split_off(10);
488
489        assert_eq!(mask.len(), 10);
490        assert_eq!(other.len(), 10);
491
492        let frozen_first = mask.freeze();
493        assert_eq!(frozen_first.true_count(), 10);
494
495        let frozen_second = other.freeze();
496        assert_eq!(frozen_second.true_count(), 0);
497    }
498
499    #[test]
500    fn test_split_off_builder_mixed_pattern() {
501        let mut mask = MaskMut::with_capacity(15);
502        // Create pattern: TFTFTFTFTFTFTFT (alternating)
503        for i in 0..15 {
504            mask.append_n(i % 2 == 0, 1);
505        }
506
507        let other = mask.split_off(7);
508
509        assert_eq!(mask.len(), 7);
510        assert_eq!(other.len(), 8);
511
512        let frozen_first = mask.freeze();
513        assert_eq!(frozen_first.true_count(), 4); // positions 0,2,4,6
514
515        let frozen_second = other.freeze();
516        assert_eq!(frozen_second.true_count(), 4); // positions 7,9,11,13 => 0,2,4,6 in split
517    }
518
519    #[test]
520    fn test_unsplit_empty_with_empty() {
521        let mut mask = MaskMut::with_capacity(10);
522        let other = MaskMut::with_capacity(10);
523
524        mask.unsplit(other);
525        assert_eq!(mask.len(), 0);
526    }
527
528    #[test]
529    fn test_unsplit_empty_with_constant() {
530        let mut mask = MaskMut::with_capacity(10);
531        let other = MaskMut::new_true(5);
532
533        mask.unsplit(other);
534        assert_eq!(mask.len(), 5);
535
536        let frozen = mask.freeze();
537        assert_eq!(frozen.true_count(), 5);
538    }
539
540    #[test]
541    fn test_unsplit_constant_with_constant_same() {
542        let mut mask = MaskMut::new_true(5);
543        let other = MaskMut::new_true(5);
544
545        mask.unsplit(other);
546        assert_eq!(mask.len(), 10);
547
548        let frozen = mask.freeze();
549        assert_eq!(frozen.true_count(), 10);
550    }
551
552    #[test]
553    fn test_unsplit_constant_with_constant_different() {
554        let mut mask = MaskMut::new_true(5);
555        let other = MaskMut::new_false(5);
556
557        mask.unsplit(other);
558        assert_eq!(mask.len(), 10);
559
560        let frozen = mask.freeze();
561        assert_eq!(frozen.true_count(), 5);
562    }
563
564    #[test]
565    fn test_unsplit_constant_with_builder() {
566        let mut mask = MaskMut::new_true(5);
567
568        let mut other = MaskMut::with_capacity(10);
569        other.append_n(true, 3);
570        other.append_n(false, 2);
571
572        mask.unsplit(other);
573        assert_eq!(mask.len(), 10);
574
575        let frozen = mask.freeze();
576        assert_eq!(frozen.true_count(), 8); // 5 from first + 3 from second
577    }
578
579    #[test]
580    fn test_unsplit_builder_with_constant() {
581        let mut mask = MaskMut::with_capacity(10);
582        mask.append_n(true, 3);
583        mask.append_n(false, 2);
584
585        let other = MaskMut::new_true(5);
586
587        mask.unsplit(other);
588        assert_eq!(mask.len(), 10);
589
590        let frozen = mask.freeze();
591        assert_eq!(frozen.true_count(), 8); // 3 from first + 5 from second
592    }
593
594    #[test]
595    fn test_unsplit_builder_with_builder() {
596        let mut mask = MaskMut::with_capacity(10);
597        mask.append_n(true, 3);
598        mask.append_n(false, 2);
599
600        let mut other = MaskMut::with_capacity(10);
601        other.append_n(false, 3);
602        other.append_n(true, 2);
603
604        mask.unsplit(other);
605        assert_eq!(mask.len(), 10);
606
607        let frozen = mask.freeze();
608        assert_eq!(frozen.true_count(), 5); // 3 from first + 2 from second
609    }
610
611    #[test]
612    fn test_round_trip_split_unsplit() {
613        let mut original = MaskMut::with_capacity(20);
614        // Pattern: 10 true, 10 false
615        original.append_n(true, 10);
616        original.append_n(false, 10);
617
618        let original_frozen = original.freeze();
619        let original_true_count = original_frozen.true_count();
620
621        // Convert back to mutable for split
622        let mut mask = original_frozen.try_into_mut().unwrap();
623
624        // Split at 10
625        let other = mask.split_off(10);
626
627        // Unsplit back together
628        mask.unsplit(other);
629
630        assert_eq!(mask.len(), 20);
631        let frozen = mask.freeze();
632        assert_eq!(frozen.true_count(), original_true_count);
633    }
634
635    #[test]
636    #[should_panic(expected = "split_off index out of bounds")]
637    fn test_split_off_out_of_bounds() {
638        let mut mask = MaskMut::new_true(10);
639        mask.split_off(11);
640    }
641
642    #[test]
643    fn test_split_off_builder_at_bit_1() {
644        let mut mask = MaskMut::with_capacity(16);
645        mask.append_n(true, 16);
646
647        let other = mask.split_off(1);
648
649        assert_eq!(mask.len(), 1);
650        assert_eq!(other.len(), 15);
651
652        let frozen_first = mask.freeze();
653        assert_eq!(frozen_first.true_count(), 1);
654
655        let frozen_second = other.freeze();
656        assert_eq!(frozen_second.true_count(), 15);
657    }
658
659    #[test]
660    fn test_multiple_split_unsplit() {
661        let mut mask = MaskMut::new_true(30);
662
663        // Split into 3 parts
664        let third = mask.split_off(20); // 20-30
665        let second = mask.split_off(10); // 10-20
666        // first is 0-10
667
668        assert_eq!(mask.len(), 10);
669        assert_eq!(second.len(), 10);
670        assert_eq!(third.len(), 10);
671
672        // Recombine in order
673        mask.unsplit(second);
674        mask.unsplit(third);
675
676        assert_eq!(mask.len(), 30);
677        let frozen = mask.freeze();
678        assert_eq!(frozen.true_count(), 30);
679    }
680
681    #[test]
682    fn test_try_into_mut_all_variants() {
683        // Test AllTrue and AllFalse variants.
684        let mask_true = Mask::new_true(100);
685        let mut_mask_true = mask_true.try_into_mut().unwrap();
686        assert_eq!(mut_mask_true.len(), 100);
687        assert_eq!(mut_mask_true.freeze().true_count(), 100);
688
689        let mask_false = Mask::new_false(50);
690        let mut_mask_false = mask_false.try_into_mut().unwrap();
691        assert_eq!(mut_mask_false.len(), 50);
692        assert_eq!(mut_mask_false.freeze().true_count(), 0);
693    }
694
695    #[test]
696    fn test_try_into_mut_with_references() {
697        // Create a MaskValues variant.
698        let mut mask_mut = MaskMut::with_capacity(10);
699        mask_mut.append_n(true, 5);
700        mask_mut.append_n(false, 5);
701        let mask = mask_mut.freeze();
702
703        // Should succeed with unique reference (no clones).
704        let mask2 = {
705            let mut mask_mut2 = MaskMut::with_capacity(10);
706            mask_mut2.append_n(true, 5);
707            mask_mut2.append_n(false, 5);
708            mask_mut2.freeze()
709        };
710        let result = mask2.try_into_mut();
711        assert!(result.is_ok());
712        assert_eq!(result.unwrap().len(), 10);
713
714        // Should fail with shared references.
715        let _cloned = mask.clone();
716        let result = mask.try_into_mut();
717        assert!(result.is_err());
718        if let Err(returned_mask) = result {
719            assert_eq!(returned_mask.len(), 10);
720            assert_eq!(returned_mask.true_count(), 5);
721        }
722    }
723
724    #[test]
725    fn test_try_into_mut_round_trip() {
726        // Test freeze -> try_into_mut -> modify -> freeze cycle.
727        let mut original = MaskMut::with_capacity(20);
728        original.append_n(true, 10);
729        original.append_n(false, 10);
730
731        let frozen = original.freeze();
732        assert_eq!(frozen.true_count(), 10);
733
734        let mut mut_mask = frozen.try_into_mut().unwrap();
735        mut_mask.append_n(true, 5);
736        assert_eq!(mut_mask.len(), 25);
737
738        let frozen_again = mut_mask.freeze();
739        assert_eq!(frozen_again.true_count(), 15);
740    }
741}