// spacetimedb_table/fixed_bit_set.rs

use core::{
    cmp, fmt,
    iter::FusedIterator,
    ops::{BitAnd, BitAndAssign, BitOr, Not, Shl},
    slice::Iter,
};
pub use internal_unsafe::FixedBitSet;
use internal_unsafe::Len;
use spacetimedb_sats::memory_usage::MemoryUsage;
9
/// A type used to represent blocks in a bit set.
///
/// A smaller type, compared to `usize`,
/// means taking less advantage of native operations.
/// A larger type means we might over-allocate more.
///
/// Implementors must behave like an unsigned machine integer of [`BitBlock::BITS`] bits.
/// The masking and iteration logic of [`FixedBitSet`] and [`IterSet`] relies on
/// [`BitBlock::wrapping_sub`] and [`BitBlock::trailing_zeros`] behaving as documented below.
pub trait BitBlock:
    Copy
    + Eq
    + Not<Output = Self>
    + BitAnd<Self, Output = Self>
    + BitAndAssign
    + BitOr<Self, Output = Self>
    + Shl<usize, Output = Self>
{
    /// The number of bits that [`Self`] can represent.
    const BITS: u32;

    /// Only the first (least significant) bit is set.
    const ONE: Self;

    /// No bits are set.
    const ZERO: Self;

    /// Subtracts `rhs` from `self`, wrapping around at the boundary of the type,
    /// as in two's complement arithmetic.
    ///
    /// For example, `ZERO.wrapping_sub(ONE << k)` must yield a block with exactly
    /// the bits `k..BITS` set.
    /// And `x & x.wrapping_sub(ONE)` must clear the lowest set bit of `x`.
    fn wrapping_sub(self, rhs: Self) -> Self;

    /// Returns the number of trailing zero bits, i.e., the position of the lowest set bit.
    ///
    /// Must return [`BitBlock::BITS`] when `self == ZERO`.
    /// [`IterSet`] uses this to detect that a block has no bits left to yield.
    fn trailing_zeros(self) -> u32;
}

/// The block type that [`FixedBitSet`] and [`IterSet`] use when none is specified.
type DefaultBitBlock = u64;

impl BitBlock for DefaultBitBlock {
    // Inherent associated items take precedence over trait items.
    // So `Self::BITS`, `self.wrapping_sub` and `self.trailing_zeros` below
    // resolve to the primitive's own items rather than recursing into this impl.
    const BITS: u32 = Self::BITS;
    const ONE: Self = 1;
    const ZERO: Self = 0;

    #[inline]
    fn wrapping_sub(self, rhs: Self) -> Self {
        self.wrapping_sub(rhs)
    }

    #[inline]
    fn trailing_zeros(self) -> u32 {
        self.trailing_zeros()
    }
}
53
54/// The internals of `FixedBitSet`.
55/// Separated from the higher level APIs to contain the safety boundary.
56mod internal_unsafe {
57    use spacetimedb_lib::{de::Deserialize, ser::Serialize};
58    use spacetimedb_sats::{impl_deserialize, impl_serialize};
59
60    use super::{BitBlock, DefaultBitBlock};
61    use crate::{static_assert_align, static_assert_size};
62    use core::{
63        mem,
64        ptr::NonNull,
65        slice::{from_raw_parts, from_raw_parts_mut},
66    };
67
68    /// The type used to represent the number of bits the set can hold.
69    ///
70    /// Currently `u16` to keep `mem::size_of::<FixedBitSet>()` small.
71    pub(super) type Len = u16;
72
73    /// A bit set that, once created, has a fixed size over time.
74    ///
75    /// The set can store at most `u16::MAX` number of bits.
76    #[repr(C, packed)]
77    pub struct FixedBitSet<B = DefaultBitBlock> {
78        /// The size of the heap allocation in number of elements.
79        len: Len,
80        /// A pointer to a heap allocation of `[B]` of `self.len`.
81        ptr: NonNull<B>,
82    }
83
84    static_assert_align!(FixedBitSet, 1);
85    static_assert_size!(FixedBitSet, mem::size_of::<usize>() + mem::size_of::<Len>());
86
87    // SAFETY: `FixedBitSet` owns its data.
88    unsafe impl<B> Send for FixedBitSet<B> {}
89    // SAFETY: `FixedBitSet` owns its data.
90    unsafe impl<B> Sync for FixedBitSet<B> {}
91
92    impl<B> Drop for FixedBitSet<B> {
93        fn drop(&mut self) {
94            let blocks = self.storage_mut();
95            // SAFETY: We own the memory region pointed to by `blocks`,
96            // and as we have `&'0 mut self`, we also have exclusive access to it.
97            // So, and since we are in `drop`,
98            // we can deallocate the memory as we are the last referent to it.
99            // Moreover, the memory was allocated in `Self::new(..)` using `vec![..]`,
100            // which will allocate using `Global`, so we can convert it back to a `Box`.
101            let _ = unsafe { Box::from_raw(blocks) };
102        }
103    }
104
105    // We need to be able to serialize and deserialize `FixedBitSet` because they appear in the `PageHeader`.
106    impl_serialize!([B: BitBlock + Serialize] FixedBitSet<B>, (self, ser) => self.storage().serialize(ser));
107    impl_deserialize!([B: BitBlock + Deserialize<'de>] FixedBitSet<B>, de => {
108        let storage = Box::<[B]>::deserialize(de)?;
109        Ok(Self::from_boxed_slice(storage))
110    });
111
112    impl<B: BitBlock> FixedBitSet<B> {
113        pub(super) fn from_boxed_slice(storage: Box<[B]>) -> Self {
114            // SAFETY: required for the soundness of `Drop` as
115            // `dealloc` must receive the same layout as it was `alloc`ated with.
116            assert!(storage.len() <= Len::MAX as usize);
117            let len = storage.len() as Len;
118            let ptr = NonNull::from(Box::leak(storage)).cast();
119            Self { ptr, len }
120        }
121    }
122
123    impl<B> FixedBitSet<B> {
124        /// Returns the capacity of the bitset.
125        #[inline]
126        pub(super) const fn blocks(&self) -> usize {
127            self.len as usize
128        }
129
130        /// Returns the backing `[B]` slice for shared access.
131        pub(crate) const fn storage(&self) -> &[B] {
132            let ptr = self.ptr.as_ptr();
133            let len = self.blocks();
134            // SAFETY:
135            // - `self.ptr` is a `NonNull` so `ptr` cannot be null.
136            // - `self.ptr` is properly aligned for `BitBlock`s.
137            // - `self.ptr` is valid for reads as we have `&self` and we own the memory
138            //   which we know is `blocks` elements long.
139            // - As we have `&'0 self`, elsewhere cannot mutate the memory during `'0`
140            //   except through an `UnsafeCell`.
141            unsafe { from_raw_parts(ptr, len) }
142        }
143
144        /// Returns the backing `[B]` slice for mutation.
145        pub(super) fn storage_mut(&mut self) -> &mut [B] {
146            let ptr = self.ptr.as_ptr();
147            let len = self.blocks();
148            // SAFETY:
149            // - `self.ptr` is a `NonNull` so `ptr` cannot be null.
150            // - `self.ptr` is properly aligned for `BitBlock`s.
151            // - `self.ptr` is valid for reads and writes as we have `&mut self` and we own the memory
152            //   which we know is `blocks` elements long.
153            // - As we have `&'0 mut self`, we have exclusive access for `'0`
154            //   so the memory cannot be accessed elsewhere during `'0`.
155            unsafe { from_raw_parts_mut(ptr, len) }
156        }
157    }
158}
159
160impl<B: BitBlock> cmp::Eq for FixedBitSet<B> {}
161impl<B: BitBlock> cmp::PartialEq for FixedBitSet<B> {
162    fn eq(&self, other: &Self) -> bool {
163        self.storage() == other.storage()
164    }
165}
166
167/// Computes how many blocks are needed to store that many bits.
168fn blocks_for_bits<B: BitBlock>(bits: usize) -> usize {
169    // Must round e.g., 31 / 32 to 1 and 32 / 32 to 1 as well.
170    bits.div_ceil(B::BITS as usize)
171}
172
173impl<B: BitBlock> FixedBitSet<B> {
174    /// Allocates a new bit set capable of holding `bits` number of bits.
175    pub fn new(bits: usize) -> Self {
176        Self::new_for_blocks(blocks_for_bits::<B>(bits))
177    }
178
179    /// Allocates a new bit set that will have a capacity of `nblocks` blocks.
180    #[inline]
181    fn new_for_blocks(nblocks: usize) -> Self {
182        // Allocate the blocks and extract the pointer to the heap region.
183        let blocks: Box<[B]> = vec![B::ZERO; nblocks].into_boxed_slice();
184
185        Self::from_boxed_slice(blocks)
186    }
187
188    /// Converts `idx` to its block index and the index within the block.
189    const fn idx_to_pos(idx: usize) -> (usize, usize) {
190        let bits = B::BITS as usize;
191        (idx / bits, idx % bits)
192    }
193
194    /// Returns whether `idx` is set or not.
195    pub fn get(&self, idx: usize) -> bool {
196        let (block_idx, pos_in_block) = Self::idx_to_pos(idx);
197        let block = self.storage()[block_idx];
198        (block & (B::ONE << pos_in_block)) != B::ZERO
199    }
200
201    /// Sets bit at position `idx` to `val`.
202    pub fn set(&mut self, idx: usize, val: bool) {
203        let (block_idx, pos_in_block) = Self::idx_to_pos(idx);
204        let block = &mut self.storage_mut()[block_idx];
205
206        // Update the block.
207        let flag = B::ONE << pos_in_block;
208        *block = if val { *block | flag } else { *block & !flag };
209    }
210
211    /// Clears every bit in the vec.
212    pub fn clear(&mut self) {
213        self.storage_mut().fill(B::ZERO);
214    }
215
216    /// Resets the bit set so that it's capable of holding `bits` number of bits.
217    ///
218    /// Every bit in the set will be zero after this.
219    pub fn reset_for(&mut self, bits: usize) {
220        // Compute the number of blocks needed.
221        let nblocks = blocks_for_bits::<B>(bits);
222
223        // Either clear the existing set, reusing it, or make a new one.
224        if nblocks == self.blocks() {
225            self.clear();
226        } else {
227            *self = Self::new_for_blocks(nblocks)
228        }
229    }
230
231    /// Returns the capacity of the bitset in bits.
232    #[inline]
233    pub(crate) const fn bits(&self) -> usize {
234        self.blocks() * B::BITS as usize
235    }
236
237    /// Returns all the set indices.
238    pub fn iter_set(&self) -> IterSet<'_, B> {
239        let mut inner = self.storage().iter();
240
241        // Fetch the first block; if it isn't there, use an all-zero one.
242        // This will cause the iterator to terminate immediately.
243        let curr = inner.next().copied().unwrap_or(B::ZERO);
244
245        IterSet {
246            inner,
247            curr,
248            block_idx: 0,
249        }
250    }
251
252    /// Returns all the set indices from `start_idx` inclusive.
253    pub fn iter_set_from(&self, start_idx: usize) -> IterSet<'_, B> {
254        // Translate the index to its block and position within it.
255        let (block_idx, pos_in_block) = Self::idx_to_pos(start_idx);
256
257        // We want our iteration to start from the block that includes `start_idx`.
258        let mut inner = self.storage()[block_idx..].iter();
259
260        // Fetch the first block; if it isn't there, use an all-zero one.
261        // This will cause the iterator to terminate immediately.
262        let curr = inner.next().copied().unwrap_or(B::ZERO);
263
264        // Our `start_idx` might be in the middle of the `curr` block.
265        // To resolve this, we must zero out any preceding bits.
266        // So e.g., for `B = u8`,
267        // we must transform `0000_1011` to `0000_1000` for `start_idx = 3`.
268        let zero_preceding_mask = B::ZERO.wrapping_sub(B::ONE << pos_in_block);
269        let curr = curr & zero_preceding_mask;
270
271        IterSet {
272            inner,
273            curr,
274            block_idx: block_idx as Len,
275        }
276    }
277}
278
279impl<B> MemoryUsage for FixedBitSet<B> {
280    fn heap_usage(&self) -> usize {
281        std::mem::size_of_val(self.storage())
282    }
283}
284
285impl<B: BitBlock> fmt::Debug for FixedBitSet<B> {
286    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
287        f.debug_set().entries(self.iter_set()).finish()
288    }
289}
290
291/// An iterator that yields the set indices of a [`FixedBitSet`].
292pub struct IterSet<'a, B = DefaultBitBlock> {
293    /// The block iterator.
294    inner: Iter<'a, B>,
295    /// The current block being processed, taken from `self.inner`.
296    curr: B,
297    /// What the index of `self.curr` is.
298    block_idx: Len,
299}
300
301impl<B: BitBlock> Iterator for IterSet<'_, B> {
302    type Item = usize;
303
304    fn next(&mut self) -> Option<Self::Item> {
305        loop {
306            let tz = self.curr.trailing_zeros();
307            if tz < B::BITS {
308                // Some bit was set; so yield the index of that
309                // and zero the bit out so we don't yield it again.
310                self.curr &= self.curr.wrapping_sub(B::ONE);
311                let idx = self.block_idx as u32 * B::BITS + tz;
312                return Some(idx as usize);
313            } else {
314                // No bit is set; advance to the next block, or quit if none left.
315                self.curr = *self.inner.next()?;
316                self.block_idx += 1;
317            }
318        }
319    }
320}
321
#[cfg(test)]
pub(crate) mod test {
    use super::*;
    use proptest::bits::bitset::between;
    use proptest::prelude::*;
    use spacetimedb_data_structures::map::HashSet;

    #[test]
    fn zero_sized_is_ok() {
        // A zero-sized set has no blocks, but every non-indexing operation still works.
        let mut set = FixedBitSet::<DefaultBitBlock>::new(0);
        set.clear();
        assert_eq!(set.bits(), 0);
        assert_eq!(set.iter_set_from(0).count(), 0);
        assert_eq!(set.iter_set().count(), 0);
    }

    #[test]
    #[should_panic]
    fn zero_sized_get_is_out_of_bounds() {
        // There is no block to hold index 0, so looking it up must panic.
        let set = FixedBitSet::<DefaultBitBlock>::new(0);
        set.get(0);
    }

    const MAX_NBITS: usize = 1000;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(if cfg!(miri) { 8 } else { 2048 }))]

        #[test]
        fn after_new_there_are_no_bits_set(nbits in 0..MAX_NBITS) {
            let set = FixedBitSet::<DefaultBitBlock>::new(nbits);
            for idx in 0..nbits {
                prop_assert!(!set.get(idx));
            }
        }

        #[test]
        fn after_clear_there_are_no_bits_set(choices in between(0, MAX_NBITS)) {
            let nbits = choices.get_ref().len();

            let mut set = FixedBitSet::<DefaultBitBlock>::new(nbits);

            // Set all the bits chosen.
            for idx in &choices {
                prop_assert!(!set.get(idx));
                set.set(idx, true);
                prop_assert!(set.get(idx));
            }

            // Clear!
            set.clear();

            // After clearing, all bits should be unset.
            for idx in 0..nbits {
                prop_assert!(!set.get(idx));
            }
        }

        #[test]
        fn get_set_consistency(choices in between(0, MAX_NBITS)) {
            let nbits = choices.get_ref().len();
            let mut set = FixedBitSet::<DefaultBitBlock>::new(nbits);

            // Set all the bits chosen.
            for idx in &choices {
                prop_assert!(!set.get(idx));

                // After setting, it's true.
                set.set(idx, true);
                prop_assert!(set.get(idx));
                // And this is idempotent.
                set.set(idx, true);
                prop_assert!(set.get(idx));
            }

            // Build the "complement" of `choices`.
            let choices: HashSet<_> = choices.into_iter().collect();
            let universe: HashSet<_> = (0..nbits).collect();
            for idx in universe.difference(&choices) {
                prop_assert!(!set.get(*idx));
            }

            // Unset all the bits chosen.
            for idx in &choices {
                // After unsetting, it's false.
                set.set(*idx, false);
                prop_assert!(!set.get(*idx));
                // And this is idempotent.
                set.set(*idx, false);
                prop_assert!(!set.get(*idx));
            }
        }

        #[test]
        fn iter_set_preserves_order_of_original_choices(choices in between(0, MAX_NBITS)) {
            let nbits = choices.get_ref().len();

            // Set all the bits chosen.
            let mut set = FixedBitSet::<DefaultBitBlock>::new(nbits);
            for idx in &choices {
                set.set(idx, true);
            }

            // `iter_set` produces the same list `choices`.
            let collected = set.iter_set().collect::<Vec<_>>();
            let original = choices.iter().collect::<Vec<_>>();
            prop_assert_eq!(&original, &collected);

            if let [_, second, ..] = &*original {
                // Starting from the second yields the same list as `choices[1..]`.
                let collected = set.iter_set_from(*second).collect::<Vec<_>>();
                prop_assert_eq!(&original[1..], &collected);
            }

            // `iter_set_from` and `iter_set` produce the same list.
            prop_assert_eq!(collected, set.iter_set_from(0).collect::<Vec<_>>());
        }

        #[test]
        fn iter_set_from_skips_preceding_indices(choices in between(0, MAX_NBITS), start in 0..=MAX_NBITS) {
            let nbits = choices.get_ref().len();
            // Stay within the capacity of the set.
            let start = start.min(nbits);

            // Set all the bits chosen.
            let mut set = FixedBitSet::<DefaultBitBlock>::new(nbits);
            for idx in &choices {
                set.set(idx, true);
            }

            // Starting anywhere, even mid-block, yields exactly the chosen indices `>= start`.
            let expected = choices.iter().filter(|&idx| idx >= start).collect::<Vec<_>>();
            let collected = set.iter_set_from(start).collect::<Vec<_>>();
            prop_assert_eq!(expected, collected);
        }

        #[test]
        fn serde_round_trip(choices in between(0, MAX_NBITS)) {
            let nbits = choices.get_ref().len();

            // Set all the bits chosen.
            let mut set = FixedBitSet::<DefaultBitBlock>::new(nbits);
            for idx in &choices {
                set.set(idx, true);
            }

            let ser = spacetimedb_lib::bsatn::to_vec(&set)?;
            let de = spacetimedb_lib::bsatn::from_slice::<FixedBitSet<DefaultBitBlock>>(&ser)?;

            prop_assert_eq!(&set, &de);
        }
    }
}