spacetimedb_table/
fixed_bit_set.rs1use core::{
2 cmp, fmt,
3 ops::{BitAnd, BitAndAssign, BitOr, Not, Shl},
4 slice::Iter,
5};
6pub use internal_unsafe::FixedBitSet;
7use internal_unsafe::Len;
8use spacetimedb_sats::memory_usage::MemoryUsage;
9
/// A word of bits used as the storage unit ("block") of a [`FixedBitSet`].
///
/// All the bit-twiddling that `FixedBitSet` and `IterSet` perform is expressed
/// through the operator supertraits and the items below.
pub trait BitBlock
where
    Self: Copy
        + Eq
        + Not<Output = Self>
        + BitAnd<Self, Output = Self>
        + BitAndAssign
        + BitOr<Self, Output = Self>
        + Shl<usize, Output = Self>,
{
    /// The number of bits held by one block.
    const BITS: u32;

    /// The block whose only set bit is the lowest one; shifted left to build
    /// single-bit flags.
    const ONE: Self;

    /// The block with no bits set.
    const ZERO: Self;

    /// Subtraction that wraps around instead of overflowing.
    fn wrapping_sub(self, rhs: Self) -> Self;

    /// Counts the zero bits below the lowest set bit.
    ///
    /// `IterSet` treats any result of at least `BITS` as "this block has no
    /// set bits", so implementations must return at least `BITS` for `ZERO`.
    fn trailing_zeros(self) -> u32;
}
35
36type DefaultBitBlock = u64;
37
38impl BitBlock for DefaultBitBlock {
39 const BITS: u32 = Self::BITS;
40 const ONE: Self = 1;
41 const ZERO: Self = 0;
42
43 #[inline]
44 fn wrapping_sub(self, rhs: Self) -> Self {
45 self.wrapping_sub(rhs)
46 }
47
48 #[inline]
49 fn trailing_zeros(self) -> u32 {
50 self.trailing_zeros()
51 }
52}
53
54mod internal_unsafe {
57 use spacetimedb_lib::{de::Deserialize, ser::Serialize};
58 use spacetimedb_sats::{impl_deserialize, impl_serialize};
59
60 use super::{BitBlock, DefaultBitBlock};
61 use crate::{static_assert_align, static_assert_size};
62 use core::{
63 mem,
64 ptr::NonNull,
65 slice::{from_raw_parts, from_raw_parts_mut},
66 };
67
68 pub(super) type Len = u16;
72
73 #[repr(C, packed)]
77 pub struct FixedBitSet<B = DefaultBitBlock> {
78 len: Len,
80 ptr: NonNull<B>,
82 }
83
84 static_assert_align!(FixedBitSet, 1);
85 static_assert_size!(FixedBitSet, mem::size_of::<usize>() + mem::size_of::<Len>());
86
87 unsafe impl<B> Send for FixedBitSet<B> {}
89 unsafe impl<B> Sync for FixedBitSet<B> {}
91
92 impl<B> Drop for FixedBitSet<B> {
93 fn drop(&mut self) {
94 let blocks = self.storage_mut();
95 let _ = unsafe { Box::from_raw(blocks) };
102 }
103 }
104
105 impl_serialize!([B: BitBlock + Serialize] FixedBitSet<B>, (self, ser) => self.storage().serialize(ser));
107 impl_deserialize!([B: BitBlock + Deserialize<'de>] FixedBitSet<B>, de => {
108 let storage = Box::<[B]>::deserialize(de)?;
109 Ok(Self::from_boxed_slice(storage))
110 });
111
112 impl<B: BitBlock> FixedBitSet<B> {
113 pub(super) fn from_boxed_slice(storage: Box<[B]>) -> Self {
114 assert!(storage.len() <= Len::MAX as usize);
117 let len = storage.len() as Len;
118 let ptr = NonNull::from(Box::leak(storage)).cast();
119 Self { ptr, len }
120 }
121 }
122
123 impl<B> FixedBitSet<B> {
124 #[inline]
126 pub(super) const fn blocks(&self) -> usize {
127 self.len as usize
128 }
129
130 pub(crate) const fn storage(&self) -> &[B] {
132 let ptr = self.ptr.as_ptr();
133 let len = self.blocks();
134 unsafe { from_raw_parts(ptr, len) }
142 }
143
144 pub(super) fn storage_mut(&mut self) -> &mut [B] {
146 let ptr = self.ptr.as_ptr();
147 let len = self.blocks();
148 unsafe { from_raw_parts_mut(ptr, len) }
156 }
157 }
158}
159
160impl<B: BitBlock> cmp::Eq for FixedBitSet<B> {}
161impl<B: BitBlock> cmp::PartialEq for FixedBitSet<B> {
162 fn eq(&self, other: &Self) -> bool {
163 self.storage() == other.storage()
164 }
165}
166
167fn blocks_for_bits<B: BitBlock>(bits: usize) -> usize {
169 bits.div_ceil(B::BITS as usize)
171}
172
173impl<B: BitBlock> FixedBitSet<B> {
174 pub fn new(bits: usize) -> Self {
176 Self::new_for_blocks(blocks_for_bits::<B>(bits))
177 }
178
179 #[inline]
181 fn new_for_blocks(nblocks: usize) -> Self {
182 let blocks: Box<[B]> = vec![B::ZERO; nblocks].into_boxed_slice();
184
185 Self::from_boxed_slice(blocks)
186 }
187
188 const fn idx_to_pos(idx: usize) -> (usize, usize) {
190 let bits = B::BITS as usize;
191 (idx / bits, idx % bits)
192 }
193
194 pub fn get(&self, idx: usize) -> bool {
196 let (block_idx, pos_in_block) = Self::idx_to_pos(idx);
197 let block = self.storage()[block_idx];
198 (block & (B::ONE << pos_in_block)) != B::ZERO
199 }
200
201 pub fn set(&mut self, idx: usize, val: bool) {
203 let (block_idx, pos_in_block) = Self::idx_to_pos(idx);
204 let block = &mut self.storage_mut()[block_idx];
205
206 let flag = B::ONE << pos_in_block;
208 *block = if val { *block | flag } else { *block & !flag };
209 }
210
211 pub fn clear(&mut self) {
213 self.storage_mut().fill(B::ZERO);
214 }
215
216 pub fn reset_for(&mut self, bits: usize) {
220 let nblocks = blocks_for_bits::<B>(bits);
222
223 if nblocks == self.blocks() {
225 self.clear();
226 } else {
227 *self = Self::new_for_blocks(nblocks)
228 }
229 }
230
231 #[inline]
233 pub(crate) const fn bits(&self) -> usize {
234 self.blocks() * B::BITS as usize
235 }
236
237 pub fn iter_set(&self) -> IterSet<'_, B> {
239 let mut inner = self.storage().iter();
240
241 let curr = inner.next().copied().unwrap_or(B::ZERO);
244
245 IterSet {
246 inner,
247 curr,
248 block_idx: 0,
249 }
250 }
251
252 pub fn iter_set_from(&self, start_idx: usize) -> IterSet<'_, B> {
254 let (block_idx, pos_in_block) = Self::idx_to_pos(start_idx);
256
257 let mut inner = self.storage()[block_idx..].iter();
259
260 let curr = inner.next().copied().unwrap_or(B::ZERO);
263
264 let zero_preceding_mask = B::ZERO.wrapping_sub(B::ONE << pos_in_block);
269 let curr = curr & zero_preceding_mask;
270
271 IterSet {
272 inner,
273 curr,
274 block_idx: block_idx as Len,
275 }
276 }
277}
278
279impl<B> MemoryUsage for FixedBitSet<B> {
280 fn heap_usage(&self) -> usize {
281 std::mem::size_of_val(self.storage())
282 }
283}
284
285impl<B: BitBlock> fmt::Debug for FixedBitSet<B> {
286 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
287 f.debug_set().entries(self.iter_set()).finish()
288 }
289}
290
291pub struct IterSet<'a, B = DefaultBitBlock> {
293 inner: Iter<'a, B>,
295 curr: B,
297 block_idx: Len,
299}
300
301impl<B: BitBlock> Iterator for IterSet<'_, B> {
302 type Item = usize;
303
304 fn next(&mut self) -> Option<Self::Item> {
305 loop {
306 let tz = self.curr.trailing_zeros();
307 if tz < B::BITS {
308 self.curr &= self.curr.wrapping_sub(B::ONE);
311 let idx = self.block_idx as u32 * B::BITS + tz;
312 return Some(idx as usize);
313 } else {
314 self.curr = *self.inner.next()?;
316 self.block_idx += 1;
317 }
318 }
319 }
320}
321
#[cfg(test)]
pub(crate) mod test {
    use super::*;
    use proptest::bits::bitset::between;
    use proptest::prelude::*;
    use spacetimedb_data_structures::map::HashSet;

    /// Operations that do not address an individual bit must succeed on an
    /// empty set and observe no set bits.
    ///
    /// This is deliberately *not* `#[should_panic]`; otherwise a panic in any
    /// of these calls would make the test pass.
    #[test]
    fn zero_sized_is_ok() {
        let mut set = FixedBitSet::<DefaultBitBlock>::new(0);
        set.clear();
        assert_eq!(set.iter_set_from(0).count(), 0);
        assert_eq!(set.iter_set().count(), 0);
        assert_eq!(set.bits(), 0);
    }

    /// An empty set has no blocks, so addressing any bit is out of bounds.
    #[test]
    #[should_panic]
    fn zero_sized_get_panics() {
        let set = FixedBitSet::<DefaultBitBlock>::new(0);
        let _ = set.get(0);
    }

    const MAX_NBITS: usize = 1000;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(if cfg!(miri) { 8 } else { 2048 }))]

        #[test]
        fn after_new_there_are_no_bits_set(nbits in 0..MAX_NBITS) {
            let set = FixedBitSet::<DefaultBitBlock>::new(nbits);
            for idx in 0..nbits {
                prop_assert!(!set.get(idx));
            }
        }

        #[test]
        fn after_clear_there_are_no_bits_set(choices in between(0, MAX_NBITS)) {
            let nbits = choices.get_ref().len();

            let mut set = FixedBitSet::<DefaultBitBlock>::new(nbits);

            for idx in &choices {
                prop_assert!(!set.get(idx));
                set.set(idx, true);
                prop_assert!(set.get(idx));
            }

            set.clear();

            for idx in 0..nbits {
                prop_assert!(!set.get(idx));
            }
        }

        #[test]
        fn get_set_consistency(choices in between(0, MAX_NBITS)) {
            let nbits = choices.get_ref().len();
            let mut set = FixedBitSet::<DefaultBitBlock>::new(nbits);

            for idx in &choices {
                prop_assert!(!set.get(idx));

                // Setting is idempotent.
                set.set(idx, true);
                prop_assert!(set.get(idx));
                set.set(idx, true);
                prop_assert!(set.get(idx));
            }

            // Bits that were never chosen remain unset.
            let choices: HashSet<_> = choices.into_iter().collect();
            let universe: HashSet<_> = (0..nbits).collect();
            for idx in universe.difference(&choices) {
                prop_assert!(!set.get(*idx));
            }

            // Unsetting is idempotent.
            for idx in &choices {
                set.set(*idx, false);
                prop_assert!(!set.get(*idx));
                set.set(*idx, false);
                prop_assert!(!set.get(*idx));
            }
        }

        #[test]
        fn iter_set_preserves_order_of_original_choices(choices in between(0, MAX_NBITS)) {
            let nbits = choices.get_ref().len();

            let mut set = FixedBitSet::<DefaultBitBlock>::new(nbits);
            for idx in &choices {
                set.set(idx, true);
            }

            let collected = set.iter_set().collect::<Vec<_>>();
            let original = choices.iter().collect::<Vec<_>>();
            prop_assert_eq!(&original, &collected);

            if let [_, second, ..] = &*original {
                let collected = set.iter_set_from(*second).collect::<Vec<_>>();
                prop_assert_eq!(&original[1..], &collected);
            }

            prop_assert_eq!(collected, set.iter_set_from(0).collect::<Vec<_>>());
        }

        #[test]
        fn serde_round_trip(choices in between(0, MAX_NBITS)) {
            let nbits = choices.get_ref().len();

            let mut set = FixedBitSet::<DefaultBitBlock>::new(nbits);
            for idx in &choices {
                set.set(idx, true);
            }

            let ser = spacetimedb_lib::bsatn::to_vec(&set)?;
            let de = spacetimedb_lib::bsatn::from_slice::<FixedBitSet<DefaultBitBlock>>(&ser)?;

            // `prop_assert_eq!` (not `assert!`) so proptest can shrink the
            // failing case and print both values.
            prop_assert_eq!(&set, &de);
        }
    }
}