spacetimedb_table/
fixed_bit_set.rs

1use core::{
2    cmp, fmt,
3    ops::{BitAnd, BitAndAssign, BitOr, Not, Shl},
4    slice::Iter,
5};
6pub use internal_unsafe::FixedBitSet;
7use internal_unsafe::Len;
8
9use crate::MemoryUsage;
10
/// The unit of storage ("block") that a bit set is assembled from.
///
/// Choosing a block narrower than `usize`
/// gives up some of the benefit of native word-sized operations,
/// whereas a wider block means a set may over-allocate more.
pub trait BitBlock:
    Copy
    + Eq
    + Shl<usize, Output = Self>
    + Not<Output = Self>
    + BitOr<Self, Output = Self>
    + BitAnd<Self, Output = Self>
    + BitAndAssign
{
    /// A block in which no bits are set.
    const ZERO: Self;

    /// A block in which only the first bit is set.
    const ONE: Self;

    /// The number of bits that a single [`Self`] can represent.
    const BITS: u32;

    /// Returns the number of zero bits preceding the first set bit of `self`.
    ///
    /// The set-bit iterator in this file treats any result that is not
    /// strictly less than [`Self::BITS`] as "no bit is set in this block".
    fn trailing_zeros(self) -> u32;

    /// Returns `self - rhs`, wrapping around instead of overflowing.
    fn wrapping_sub(self, rhs: Self) -> Self;
}
36
37type DefaultBitBlock = u64;
38
39impl BitBlock for DefaultBitBlock {
40    const BITS: u32 = Self::BITS;
41    const ONE: Self = 1;
42    const ZERO: Self = 0;
43
44    #[inline]
45    fn wrapping_sub(self, rhs: Self) -> Self {
46        self.wrapping_sub(rhs)
47    }
48
49    #[inline]
50    fn trailing_zeros(self) -> u32 {
51        self.trailing_zeros()
52    }
53}
54
/// The internals of `FixedBitSet`.
/// Separated from the higher level APIs to contain the safety boundary.
///
/// Only this module can see the raw `len`/`ptr` fields;
/// code outside of it goes through `from_boxed_slice`, `blocks`, `storage` and `storage_mut`.
mod internal_unsafe {
    use spacetimedb_lib::{de::Deserialize, ser::Serialize};
    use spacetimedb_sats::{impl_deserialize, impl_serialize};

    use super::{BitBlock, DefaultBitBlock};
    use crate::{static_assert_align, static_assert_size};
    use core::{
        mem,
        ptr::NonNull,
        slice::{from_raw_parts, from_raw_parts_mut},
    };

    /// The type used to represent the number of blocks backing the set.
    ///
    /// Currently `u16` to keep `mem::size_of::<FixedBitSet>()` small.
    pub(super) type Len = u16;

    /// A bit set that, once created, has a fixed size over time.
    ///
    /// The set is backed by at most `Len::MAX` (i.e., `u16::MAX`) blocks,
    /// so it can store at most `u16::MAX * B::BITS` bits.
    ///
    /// The struct is `packed` so that no padding is inserted between `len` and `ptr`;
    /// the static assertions below pin the resulting alignment and size.
    #[repr(C, packed)]
    pub struct FixedBitSet<B = DefaultBitBlock> {
        /// The size of the heap allocation in number of elements.
        len: Len,
        /// A pointer to a heap allocation of `[B]` of `self.len`.
        ptr: NonNull<B>,
    }

    // The layout that `packed` is meant to achieve: byte-aligned, pointer + `Len` and nothing else.
    static_assert_align!(FixedBitSet, 1);
    static_assert_size!(FixedBitSet, mem::size_of::<usize>() + mem::size_of::<Len>());

    // NOTE(review): both impls below are unconditional in `B`.
    // Presumably this is fine because blocks are plain integers in practice;
    // confirm whether `B: Send` / `B: Sync` bounds should be required.

    // SAFETY: `FixedBitSet` owns its data.
    unsafe impl<B> Send for FixedBitSet<B> {}
    // SAFETY: `FixedBitSet` owns its data.
    unsafe impl<B> Sync for FixedBitSet<B> {}

    /// Frees the heap allocation by turning it back into a `Box<[B]>` and dropping that.
    impl<B> Drop for FixedBitSet<B> {
        fn drop(&mut self) {
            let blocks = self.storage_mut();
            // SAFETY: We own the memory region pointed to by `blocks`,
            // and as we have `&'0 mut self`, we also have exclusive access to it.
            // So, and since we are in `drop`,
            // we can deallocate the memory as we are the last referent to it.
            // Moreover, the memory was allocated in `Self::new(..)` using `vec![..]`,
            // which will allocate using `Global`, so we can convert it back to a `Box`.
            let _ = unsafe { Box::from_raw(blocks) };
        }
    }

    // We need to be able to serialize and deserialize `FixedBitSet` because they appear in the `PageHeader`.
    // The serialized form is that of the block slice returned by `storage()`;
    // deserialization reads a `Box<[B]>` back and adopts it via `from_boxed_slice`.
    impl_serialize!([B: BitBlock + Serialize] FixedBitSet<B>, (self, ser) => self.storage().serialize(ser));
    impl_deserialize!([B: BitBlock + Deserialize<'de>] FixedBitSet<B>, de => {
        let storage = Box::<[B]>::deserialize(de)?;
        Ok(Self::from_boxed_slice(storage))
    });

    impl<B: BitBlock> FixedBitSet<B> {
        /// Takes ownership of `storage` and uses it as the blocks of a new bit set.
        ///
        /// # Panics
        ///
        /// Panics if `storage` has more than `Len::MAX` elements.
        pub(super) fn from_boxed_slice(storage: Box<[B]>) -> Self {
            // SAFETY: required for the soundness of `Drop` as
            // `dealloc` must receive the same layout as it was `alloc`ated with.
            // Without this check, the cast below could truncate the length,
            // and `Drop` would then rebuild a shorter slice than what was allocated.
            assert!(storage.len() <= Len::MAX as usize);
            let len = storage.len() as Len;
            // Leak the box so the allocation outlives this function; `Drop` reclaims it.
            let ptr = NonNull::from(Box::leak(storage)).cast();
            Self { ptr, len }
        }
    }

    impl<B> FixedBitSet<B> {
        /// Returns the capacity of the bitset in number of blocks
        /// (not bits), i.e., the length of the backing `[B]` slice.
        #[inline]
        pub(super) const fn blocks(&self) -> usize {
            self.len as usize
        }

        /// Returns the backing `[B]` slice for shared access.
        pub(crate) const fn storage(&self) -> &[B] {
            let ptr = self.ptr.as_ptr();
            let len = self.blocks();
            // SAFETY:
            // - `self.ptr` is a `NonNull` so `ptr` cannot be null.
            // - `self.ptr` is properly aligned for `BitBlock`s.
            // - `self.ptr` is valid for reads as we have `&self` and we own the memory
            //   which we know is `blocks` elements long.
            // - As we have `&'0 self`, elsewhere cannot mutate the memory during `'0`
            //   except through an `UnsafeCell`.
            unsafe { from_raw_parts(ptr, len) }
        }

        /// Returns the backing `[B]` slice for mutation.
        pub(super) fn storage_mut(&mut self) -> &mut [B] {
            let ptr = self.ptr.as_ptr();
            let len = self.blocks();
            // SAFETY:
            // - `self.ptr` is a `NonNull` so `ptr` cannot be null.
            // - `self.ptr` is properly aligned for `BitBlock`s.
            // - `self.ptr` is valid for reads and writes as we have `&mut self` and we own the memory
            //   which we know is `blocks` elements long.
            // - As we have `&'0 mut self`, we have exclusive access for `'0`
            //   so the memory cannot be accessed elsewhere during `'0`.
            unsafe { from_raw_parts_mut(ptr, len) }
        }
    }
}
160
161impl<B: BitBlock> cmp::Eq for FixedBitSet<B> {}
162impl<B: BitBlock> cmp::PartialEq for FixedBitSet<B> {
163    fn eq(&self, other: &Self) -> bool {
164        self.storage() == other.storage()
165    }
166}
167
168/// Computes how many blocks are needed to store that many bits.
169fn blocks_for_bits<B: BitBlock>(bits: usize) -> usize {
170    // Must round e.g., 31 / 32 to 1 and 32 / 32 to 1 as well.
171    bits.div_ceil(B::BITS as usize)
172}
173
174impl<B: BitBlock> FixedBitSet<B> {
175    /// Allocates a new bit set capable of holding `bits` number of bits.
176    pub fn new(bits: usize) -> Self {
177        Self::new_for_blocks(blocks_for_bits::<B>(bits))
178    }
179
180    /// Allocates a new bit set that will have a capacity of `nblocks` blocks.
181    #[inline]
182    fn new_for_blocks(nblocks: usize) -> Self {
183        // Allocate the blocks and extract the pointer to the heap region.
184        let blocks: Box<[B]> = vec![B::ZERO; nblocks].into_boxed_slice();
185
186        Self::from_boxed_slice(blocks)
187    }
188
189    /// Converts `idx` to its block index and the index within the block.
190    const fn idx_to_pos(idx: usize) -> (usize, usize) {
191        let bits = B::BITS as usize;
192        (idx / bits, idx % bits)
193    }
194
195    /// Returns whether `idx` is set or not.
196    pub fn get(&self, idx: usize) -> bool {
197        let (block_idx, pos_in_block) = Self::idx_to_pos(idx);
198        let block = self.storage()[block_idx];
199        (block & (B::ONE << pos_in_block)) != B::ZERO
200    }
201
202    /// Sets bit at position `idx` to `val`.
203    pub fn set(&mut self, idx: usize, val: bool) {
204        let (block_idx, pos_in_block) = Self::idx_to_pos(idx);
205        let block = &mut self.storage_mut()[block_idx];
206
207        // Update the block.
208        let flag = B::ONE << pos_in_block;
209        *block = if val { *block | flag } else { *block & !flag };
210    }
211
212    /// Clears every bit in the vec.
213    pub fn clear(&mut self) {
214        self.storage_mut().fill(B::ZERO);
215    }
216
217    /// Resets the bit set so that it's capable of holding `bits` number of bits.
218    ///
219    /// Every bit in the set will be zero after this.
220    pub fn reset_for(&mut self, bits: usize) {
221        // Compute the number of blocks needed.
222        let nblocks = blocks_for_bits::<B>(bits);
223
224        // Either clear the existing set, reusing it, or make a new one.
225        if nblocks == self.blocks() {
226            self.clear();
227        } else {
228            *self = Self::new_for_blocks(nblocks)
229        }
230    }
231
232    /// Returns the capacity of the bitset in bits.
233    #[inline]
234    pub(crate) const fn bits(&self) -> usize {
235        self.blocks() * B::BITS as usize
236    }
237
238    /// Returns all the set indices.
239    pub fn iter_set(&self) -> IterSet<'_, B> {
240        let mut inner = self.storage().iter();
241
242        // Fetch the first block; if it isn't there, use an all-zero one.
243        // This will cause the iterator to terminate immediately.
244        let curr = inner.next().copied().unwrap_or(B::ZERO);
245
246        IterSet {
247            inner,
248            curr,
249            block_idx: 0,
250        }
251    }
252
253    /// Returns all the set indices from `start_idx` inclusive.
254    pub fn iter_set_from(&self, start_idx: usize) -> IterSet<'_, B> {
255        // Translate the index to its block and position within it.
256        let (block_idx, pos_in_block) = Self::idx_to_pos(start_idx);
257
258        // We want our iteration to start from the block that includes `start_idx`.
259        let mut inner = self.storage()[block_idx..].iter();
260
261        // Fetch the first block; if it isn't there, use an all-zero one.
262        // This will cause the iterator to terminate immediately.
263        let curr = inner.next().copied().unwrap_or(B::ZERO);
264
265        // Our `start_idx` might be in the middle of the `curr` block.
266        // To resolve this, we must zero out any preceding bits.
267        // So e.g., for `B = u8`,
268        // we must transform `0000_1011` to `0000_1000` for `start_idx = 3`.
269        let zero_preceding_mask = B::ZERO.wrapping_sub(B::ONE << pos_in_block);
270        let curr = curr & zero_preceding_mask;
271
272        IterSet {
273            inner,
274            curr,
275            block_idx: block_idx as Len,
276        }
277    }
278}
279
280impl<B> MemoryUsage for FixedBitSet<B> {
281    fn heap_usage(&self) -> usize {
282        std::mem::size_of_val(self.storage())
283    }
284}
285
286impl<B: BitBlock> fmt::Debug for FixedBitSet<B> {
287    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
288        f.debug_set().entries(self.iter_set()).finish()
289    }
290}
291
292/// An iterator that yields the set indices of a [`FixedBitSet`].
293pub struct IterSet<'a, B = DefaultBitBlock> {
294    /// The block iterator.
295    inner: Iter<'a, B>,
296    /// The current block being processed, taken from `self.inner`.
297    curr: B,
298    /// What the index of `self.curr` is.
299    block_idx: Len,
300}
301
302impl<B: BitBlock> Iterator for IterSet<'_, B> {
303    type Item = usize;
304
305    fn next(&mut self) -> Option<Self::Item> {
306        loop {
307            let tz = self.curr.trailing_zeros();
308            if tz < B::BITS {
309                // Some bit was set; so yield the index of that
310                // and zero the bit out so we don't yield it again.
311                self.curr &= self.curr.wrapping_sub(B::ONE);
312                let idx = self.block_idx as u32 * B::BITS + tz;
313                return Some(idx as usize);
314            } else {
315                // No bit is set; advance to the next block, or quit if none left.
316                self.curr = *self.inner.next()?;
317                self.block_idx += 1;
318            }
319        }
320    }
321}
322
#[cfg(test)]
pub(crate) mod test {
    use super::*;
    use proptest::bits::bitset::between;
    use proptest::prelude::*;
    use spacetimedb_data_structures::map::HashSet;

    /// A set made for zero bits has no blocks,
    /// but every operation that doesn't address a specific bit must work regardless.
    ///
    /// This test must not panic, so it is kept apart from the out-of-bounds `get` below.
    #[test]
    fn zero_sized_is_ok() {
        let mut set = FixedBitSet::<DefaultBitBlock>::new(0);
        set.clear();
        assert_eq!(set.iter_set_from(0).count(), 0);
        assert_eq!(set.iter_set().count(), 0);
    }

    /// A zero-sized set has no block to read bit 0 from, so `get` panics.
    #[test]
    #[should_panic]
    fn zero_sized_get_panics() {
        let set = FixedBitSet::<DefaultBitBlock>::new(0);
        set.get(0);
    }

    /// Upper bound on the number of bits in the sets generated by the properties below.
    const MAX_NBITS: usize = 1000;

    proptest! {
        // Fewer cases under miri, where each case is far more expensive.
        #![proptest_config(ProptestConfig::with_cases(if cfg!(miri) { 8 } else { 2048 }))]

        #[test]
        fn after_new_there_are_no_bits_set(nbits in 0..MAX_NBITS) {
            let set = FixedBitSet::<DefaultBitBlock>::new(nbits);
            for idx in 0..nbits {
                prop_assert!(!set.get(idx));
            }
        }

        #[test]
        fn after_clear_there_are_no_bits_set(choices in between(0, MAX_NBITS)) {
            let nbits = choices.get_ref().len();

            let mut set = FixedBitSet::<DefaultBitBlock>::new(nbits);

            // Set all the bits chosen.
            for idx in &choices {
                prop_assert!(!set.get(idx));
                set.set(idx, true);
                prop_assert!(set.get(idx));
            }

            // Clear!
            set.clear();

            // After clearing, all bits should be unset.
            for idx in 0..nbits {
                prop_assert!(!set.get(idx));
            }
        }

        #[test]
        fn get_set_consistency(choices in between(0, MAX_NBITS)) {
            let nbits = choices.get_ref().len();
            let mut set = FixedBitSet::<DefaultBitBlock>::new(nbits);

            // Set all the bits chosen.
            for idx in &choices {
                prop_assert!(!set.get(idx));

                // After setting, it's true.
                set.set(idx, true);
                prop_assert!(set.get(idx));
                // And this is idempotent.
                set.set(idx, true);
                prop_assert!(set.get(idx));
            }

            // Build the "complement" of `choices`.
            let choices: HashSet<_> = choices.into_iter().collect();
            let universe: HashSet<_> = (0..nbits).collect();
            for idx in universe.difference(&choices) {
                prop_assert!(!set.get(*idx));
            }

            // Unset all the bits chosen.
            for idx in &choices {
                // After unsetting, it's false.
                set.set(*idx, false);
                prop_assert!(!set.get(*idx));
                // And this is idempotent.
                set.set(*idx, false);
                prop_assert!(!set.get(*idx));
            }
        }

        #[test]
        fn iter_set_preserves_order_of_original_choices(choices in between(0, MAX_NBITS)) {
            let nbits = choices.get_ref().len();

            // Set all the bits chosen.
            let mut set = FixedBitSet::<DefaultBitBlock>::new(nbits);
            for idx in &choices {
                set.set(idx, true);
            }

            // `iter_set` produces the same list `choices`.
            let collected = set.iter_set().collect::<Vec<_>>();
            let original = choices.iter().collect::<Vec<_>>();
            prop_assert_eq!(&original, &collected);

            if let [_, second, ..] = &*original {
                // Starting from the second yields the same list as `choices[1..]`.
                let collected = set.iter_set_from(*second).collect::<Vec<_>>();
                prop_assert_eq!(&original[1..], &collected);
            }

            // `iter_set_from` and `iter_set` produce the same list.
            prop_assert_eq!(collected, set.iter_set_from(0).collect::<Vec<_>>());
        }

        #[test]
        fn serde_round_trip(choices in between(0, MAX_NBITS)) {
            let nbits = choices.get_ref().len();

            // Set all the bits chosen.
            let mut set = FixedBitSet::<DefaultBitBlock>::new(nbits);
            for idx in &choices {
                set.set(idx, true);
            }

            let ser = spacetimedb_lib::bsatn::to_vec(&set)?;
            let de = spacetimedb_lib::bsatn::from_slice::<FixedBitSet<DefaultBitBlock>>(&ser)?;

            // Use the proptest assertion, like the other properties,
            // so that a mismatch is reported as a test-case failure showing both sets.
            prop_assert_eq!(set, de);
        }
    }
}