spacetimedb_table/fixed_bit_set.rs

1use core::{
2    cmp, fmt,
3    ops::{BitAnd, BitAndAssign, BitOr, Not, Shl},
4    slice::Iter,
5};
6pub use internal_unsafe::FixedBitSet;
7use internal_unsafe::Len;
8
9use crate::MemoryUsage;
10
/// A type used to represent blocks in a bit set.
///
/// A smaller type, compared to `usize`,
/// means taking less advantage of native operations.
/// A larger type means we might over-allocate more.
///
/// Implementors are expected to behave like an unsigned integer that is `BITS` bits wide.
/// In particular, iterating the set bits relies on `trailing_zeros`
/// returning `BITS` for a block with no bit set.
pub trait BitBlock:
    Copy
    + Eq
    + Not<Output = Self>
    + BitOr<Self, Output = Self>
    + BitAnd<Self, Output = Self>
    + BitAndAssign
    + Shl<usize, Output = Self>
{
    /// A block with no bits set.
    const ZERO: Self;

    /// A block with only the lowest bit (position 0) set.
    const ONE: Self;

    /// The width of [`Self`] in bits.
    const BITS: u32;

    /// Computes `self - rhs`, wrapping around on underflow.
    fn wrapping_sub(self, rhs: Self) -> Self;

    /// Counts the zero bits below the lowest set bit.
    ///
    /// Expected to return `Self::BITS` when no bit is set.
    fn trailing_zeros(self) -> u32;
}

/// The block type used by `FixedBitSet` and `IterSet` unless another is specified.
type DefaultBitBlock = u64;

impl BitBlock for DefaultBitBlock {
    const ZERO: Self = 0;
    const ONE: Self = 1;
    const BITS: u32 = Self::BITS;

    #[inline]
    fn wrapping_sub(self, rhs: Self) -> Self {
        // Resolves to the inherent integer method, not this trait method.
        Self::wrapping_sub(self, rhs)
    }

    #[inline]
    fn trailing_zeros(self) -> u32 {
        // Resolves to the inherent integer method, not this trait method.
        Self::trailing_zeros(self)
    }
}
54
55/// The internals of `FixedBitSet`.
56/// Separated from the higher level APIs to contain the safety boundary.
57mod internal_unsafe {
58    use spacetimedb_lib::{de::Deserialize, ser::Serialize};
59    use spacetimedb_sats::{impl_deserialize, impl_serialize};
60
61    use super::{BitBlock, DefaultBitBlock};
62    use crate::{static_assert_align, static_assert_size};
63    use core::{
64        mem,
65        ptr::NonNull,
66        slice::{from_raw_parts, from_raw_parts_mut},
67    };
68
69    /// The type used to represent the number of bits the set can hold.
70    ///
71    /// Currently `u16` to keep `mem::size_of::<FixedBitSet>()` small.
72    pub(super) type Len = u16;
73
74    /// A bit set that, once created, has a fixed size over time.
75    ///
76    /// The set can store at most `u16::MAX` number of bits.
77    #[repr(C, packed)]
78    pub struct FixedBitSet<B = DefaultBitBlock> {
79        /// The size of the heap allocation in number of elements.
80        len: Len,
81        /// A pointer to a heap allocation of `[B]` of `self.len`.
82        ptr: NonNull<B>,
83    }
84
85    static_assert_align!(FixedBitSet, 1);
86    static_assert_size!(FixedBitSet, mem::size_of::<usize>() + mem::size_of::<Len>());
87
88    // SAFETY: `FixedBitSet` owns its data.
89    unsafe impl<B> Send for FixedBitSet<B> {}
90    // SAFETY: `FixedBitSet` owns its data.
91    unsafe impl<B> Sync for FixedBitSet<B> {}
92
93    impl<B> Drop for FixedBitSet<B> {
94        fn drop(&mut self) {
95            let blocks = self.storage_mut();
96            // SAFETY: We own the memory region pointed to by `blocks`,
97            // and as we have `&'0 mut self`, we also have exclusive access to it.
98            // So, and since we are in `drop`,
99            // we can deallocate the memory as we are the last referent to it.
100            // Moreover, the memory was allocated in `Self::new(..)` using `vec![..]`,
101            // which will allocate using `Global`, so we can convert it back to a `Box`.
102            let _ = unsafe { Box::from_raw(blocks) };
103        }
104    }
105
106    // We need to be able to serialize and deserialize `FixedBitSet` because they appear in the `PageHeader`.
107    impl_serialize!([B: BitBlock + Serialize] FixedBitSet<B>, (self, ser) => self.storage().serialize(ser));
108    impl_deserialize!([B: BitBlock + Deserialize<'de>] FixedBitSet<B>, de => {
109        let storage = Box::<[B]>::deserialize(de)?;
110        Ok(Self::from_boxed_slice(storage))
111    });
112
113    impl<B: BitBlock> FixedBitSet<B> {
114        pub(super) fn from_boxed_slice(storage: Box<[B]>) -> Self {
115            // SAFETY: required for the soundness of `Drop` as
116            // `dealloc` must receive the same layout as it was `alloc`ated with.
117            assert!(storage.len() <= Len::MAX as usize);
118            let len = storage.len() as Len;
119            let ptr = NonNull::from(Box::leak(storage)).cast();
120            Self { ptr, len }
121        }
122    }
123
124    impl<B> FixedBitSet<B> {
125        /// Returns the backing `[B]` slice for shared access.
126        pub(crate) const fn storage(&self) -> &[B] {
127            let ptr = self.ptr.as_ptr();
128            let len = self.len as usize;
129            // SAFETY:
130            // - `self.ptr` is a `NonNull` so `ptr` cannot be null.
131            // - `self.ptr` is properly aligned for `BitBlock`s.
132            // - `self.ptr` is valid for reads as we have `&self` and we own the memory
133            //   which we know is `blocks` elements long.
134            // - As we have `&'0 self`, elsewhere cannot mutate the memory during `'0`
135            //   except through an `UnsafeCell`.
136            unsafe { from_raw_parts(ptr, len) }
137        }
138
139        /// Returns the backing `[B]` slice for mutation.
140        pub(super) fn storage_mut(&mut self) -> &mut [B] {
141            let ptr = self.ptr.as_ptr();
142            let len = self.len as usize;
143            // SAFETY:
144            // - `self.ptr` is a `NonNull` so `ptr` cannot be null.
145            // - `self.ptr` is properly aligned for `BitBlock`s.
146            // - `self.ptr` is valid for reads and writes as we have `&mut self` and we own the memory
147            //   which we know is `blocks` elements long.
148            // - As we have `&'0 mut self`, we have exclusive access for `'0`
149            //   so the memory cannot be accessed elsewhere during `'0`.
150            unsafe { from_raw_parts_mut(ptr, len) }
151        }
152    }
153}
154
155impl<B: BitBlock> cmp::Eq for FixedBitSet<B> {}
156impl<B: BitBlock> cmp::PartialEq for FixedBitSet<B> {
157    fn eq(&self, other: &Self) -> bool {
158        self.storage() == other.storage()
159    }
160}
161
162/// Computes how many blocks are needed to store that many bits.
163fn blocks_for_bits<B: BitBlock>(bits: usize) -> usize {
164    // Must round e.g., 31 / 32 to 1 and 32 / 32 to 1 as well.
165    bits.div_ceil(B::BITS as usize)
166}
167
168impl<B: BitBlock> FixedBitSet<B> {
169    /// Allocates a new bit set capable of holding `bits` number of bits.
170    pub fn new(bits: usize) -> Self {
171        // Compute the number of blocks needed.
172        let nblocks = blocks_for_bits::<B>(bits);
173
174        // Allocate the blocks and extract the pointer to the heap region.
175        let blocks: Box<[B]> = vec![B::ZERO; nblocks].into_boxed_slice();
176
177        Self::from_boxed_slice(blocks)
178    }
179
180    /// Converts `idx` to its block index and the index within the block.
181    const fn idx_to_pos(idx: usize) -> (usize, usize) {
182        let bits = B::BITS as usize;
183        (idx / bits, idx % bits)
184    }
185
186    /// Returns whether `idx` is set or not.
187    pub fn get(&self, idx: usize) -> bool {
188        let (block_idx, pos_in_block) = Self::idx_to_pos(idx);
189        let block = self.storage()[block_idx];
190        (block & (B::ONE << pos_in_block)) != B::ZERO
191    }
192
193    /// Sets bit at position `idx` to `val`.
194    pub fn set(&mut self, idx: usize, val: bool) {
195        let (block_idx, pos_in_block) = Self::idx_to_pos(idx);
196        let block = &mut self.storage_mut()[block_idx];
197
198        // Update the block.
199        let flag = B::ONE << pos_in_block;
200        *block = if val { *block | flag } else { *block & !flag };
201    }
202
203    /// Clears every bit in the vec.
204    pub fn clear(&mut self) {
205        self.storage_mut().fill(B::ZERO);
206    }
207
208    /// Returns all the set indices.
209    pub fn iter_set(&self) -> IterSet<'_, B> {
210        let mut inner = self.storage().iter();
211
212        // Fetch the first block; if it isn't there, use an all-zero one.
213        // This will cause the iterator to terminate immediately.
214        let curr = inner.next().copied().unwrap_or(B::ZERO);
215
216        IterSet {
217            inner,
218            curr,
219            block_idx: 0,
220        }
221    }
222
223    /// Returns all the set indices from `start_idx` inclusive.
224    pub fn iter_set_from(&self, start_idx: usize) -> IterSet<'_, B> {
225        // Translate the index to its block and position within it.
226        let (block_idx, pos_in_block) = Self::idx_to_pos(start_idx);
227
228        // We want our iteration to start from the block that includes `start_idx`.
229        let mut inner = self.storage()[block_idx..].iter();
230
231        // Fetch the first block; if it isn't there, use an all-zero one.
232        // This will cause the iterator to terminate immediately.
233        let curr = inner.next().copied().unwrap_or(B::ZERO);
234
235        // Our `start_idx` might be in the middle of the `curr` block.
236        // To resolve this, we must zero out any preceding bits.
237        // So e.g., for `B = u8`,
238        // we must transform `0000_1011` to `0000_1000` for `start_idx = 3`.
239        let zero_preceding_mask = B::ZERO.wrapping_sub(B::ONE << pos_in_block);
240        let curr = curr & zero_preceding_mask;
241
242        IterSet {
243            inner,
244            curr,
245            block_idx: block_idx as Len,
246        }
247    }
248}
249
250impl<B> MemoryUsage for FixedBitSet<B> {
251    fn heap_usage(&self) -> usize {
252        std::mem::size_of_val(self.storage())
253    }
254}
255
256impl<B: BitBlock> fmt::Debug for FixedBitSet<B> {
257    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
258        f.debug_set().entries(self.iter_set()).finish()
259    }
260}
261
262/// An iterator that yields the set indices of a [`FixedBitSet`].
263pub struct IterSet<'a, B = DefaultBitBlock> {
264    /// The block iterator.
265    inner: Iter<'a, B>,
266    /// The current block being processed, taken from `self.inner`.
267    curr: B,
268    /// What the index of `self.curr` is.
269    block_idx: Len,
270}
271
272impl<B: BitBlock> Iterator for IterSet<'_, B> {
273    type Item = usize;
274
275    fn next(&mut self) -> Option<Self::Item> {
276        loop {
277            let tz = self.curr.trailing_zeros();
278            if tz < B::BITS {
279                // Some bit was set; so yield the index of that
280                // and zero the bit out so we don't yield it again.
281                self.curr &= self.curr.wrapping_sub(B::ONE);
282                let idx = self.block_idx as u32 * B::BITS + tz;
283                return Some(idx as usize);
284            } else {
285                // No bit is set; advance to the next block, or quit if none left.
286                self.curr = *self.inner.next()?;
287                self.block_idx += 1;
288            }
289        }
290    }
291}
292
#[cfg(test)]
pub(crate) mod test {
    use super::*;
    use proptest::bits::bitset::between;
    use proptest::prelude::*;
    use spacetimedb_data_structures::map::HashSet;

    /// An empty set can be cleared and iterated without panicking.
    #[test]
    fn zero_sized_is_ok() {
        let mut set = FixedBitSet::<DefaultBitBlock>::new(0);
        set.clear();
        assert_eq!(set.iter_set_from(0).count(), 0);
        assert_eq!(set.iter_set().count(), 0);
    }

    /// An empty set has no blocks, so looking up any index is out of bounds.
    /// Kept separate so that a panic anywhere else cannot satisfy `should_panic`.
    #[test]
    #[should_panic]
    fn zero_sized_get_panics() {
        let set = FixedBitSet::<DefaultBitBlock>::new(0);
        let _ = set.get(0);
    }

    const MAX_NBITS: usize = 1000;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(if cfg!(miri) { 8 } else { 2048 }))]

        #[test]
        fn after_new_there_are_no_bits_set(nbits in 0..MAX_NBITS) {
            let set = FixedBitSet::<DefaultBitBlock>::new(nbits);
            for idx in 0..nbits {
                prop_assert!(!set.get(idx));
            }
        }

        #[test]
        fn after_clear_there_are_no_bits_set(choices in between(0, MAX_NBITS)) {
            let nbits = choices.get_ref().len();

            let mut set = FixedBitSet::<DefaultBitBlock>::new(nbits);

            // Set all the bits chosen.
            for idx in &choices {
                prop_assert!(!set.get(idx));
                set.set(idx, true);
                prop_assert!(set.get(idx));
            }

            // Clear!
            set.clear();

            // After clearing, all bits should be unset.
            for idx in 0..nbits {
                prop_assert!(!set.get(idx));
            }
        }

        #[test]
        fn get_set_consistency(choices in between(0, MAX_NBITS)) {
            let nbits = choices.get_ref().len();
            let mut set = FixedBitSet::<DefaultBitBlock>::new(nbits);

            // Set all the bits chosen.
            for idx in &choices {
                prop_assert!(!set.get(idx));

                // After setting, it's true.
                set.set(idx, true);
                prop_assert!(set.get(idx));
                // And this is idempotent.
                set.set(idx, true);
                prop_assert!(set.get(idx));
            }

            // Build the "complement" of `choices`.
            let choices: HashSet<_> = choices.into_iter().collect();
            let universe: HashSet<_> = (0..nbits).collect();
            for idx in universe.difference(&choices) {
                prop_assert!(!set.get(*idx));
            }

            // Unset all the bits chosen.
            for idx in &choices {
                // After unsetting, it's false.
                set.set(*idx, false);
                prop_assert!(!set.get(*idx));
                // And this is idempotent.
                set.set(*idx, false);
                prop_assert!(!set.get(*idx));
            }
        }

        #[test]
        fn iter_set_preserves_order_of_original_choices(choices in between(0, MAX_NBITS)) {
            let nbits = choices.get_ref().len();

            // Set all the bits chosen.
            let mut set = FixedBitSet::<DefaultBitBlock>::new(nbits);
            for idx in &choices {
                set.set(idx, true);
            }

            // `iter_set` produces the same list `choices`.
            let collected = set.iter_set().collect::<Vec<_>>();
            let original = choices.iter().collect::<Vec<_>>();
            prop_assert_eq!(&original, &collected);

            if let [_, second, ..] = &*original {
                // Starting from the second yields the same list as `choices[1..]`.
                let collected = set.iter_set_from(*second).collect::<Vec<_>>();
                prop_assert_eq!(&original[1..], &collected);
            }

            // `iter_set_from` and `iter_set` produce the same list.
            prop_assert_eq!(collected, set.iter_set_from(0).collect::<Vec<_>>());
        }

        #[test]
        fn serde_round_trip(choices in between(0, MAX_NBITS)) {
            let nbits = choices.get_ref().len();

            // Set all the bits chosen.
            let mut set = FixedBitSet::<DefaultBitBlock>::new(nbits);
            for idx in &choices {
                set.set(idx, true);
            }

            let ser = spacetimedb_lib::bsatn::to_vec(&set)?;
            let de = spacetimedb_lib::bsatn::from_slice::<FixedBitSet<DefaultBitBlock>>(&ser)?;

            // Use the proptest assertion so failures are reported and shrunk by proptest.
            prop_assert_eq!(&set, &de);
        }
    }
}