spacetimedb_table/
fixed_bit_set.rs1use core::{
2 cmp, fmt,
3 ops::{BitAnd, BitAndAssign, BitOr, Not, Shl},
4 slice::Iter,
5};
6pub use internal_unsafe::FixedBitSet;
7use internal_unsafe::Len;
8
9use crate::MemoryUsage;
10
/// A fixed-width, unsigned-integer-like word in which a [`FixedBitSet`] stores its bits.
///
/// Bit `i` of a block is addressed as `ONE << i`, and a block holds `BITS` such bits.
pub trait BitBlock
where
    Self: Copy
        + Eq
        + Not<Output = Self>
        + BitAnd<Self, Output = Self>
        + BitAndAssign
        + BitOr<Self, Output = Self>
        + Shl<usize, Output = Self>,
{
    /// Number of bits in one block.
    const BITS: u32;

    /// The block with only the lowest bit set.
    const ONE: Self;

    /// The block with no bits set.
    const ZERO: Self;

    /// Subtraction that wraps around on underflow.
    fn wrapping_sub(self, rhs: Self) -> Self;

    /// Number of zero bits below the lowest set bit.
    ///
    /// Iteration over set bits treats a result of `BITS` or more as "block is empty",
    /// so implementations must return at least `BITS` for an all-zero block.
    fn trailing_zeros(self) -> u32;
}
36
37type DefaultBitBlock = u64;
38
39impl BitBlock for DefaultBitBlock {
40 const BITS: u32 = Self::BITS;
41 const ONE: Self = 1;
42 const ZERO: Self = 0;
43
44 #[inline]
45 fn wrapping_sub(self, rhs: Self) -> Self {
46 self.wrapping_sub(rhs)
47 }
48
49 #[inline]
50 fn trailing_zeros(self) -> u32 {
51 self.trailing_zeros()
52 }
53}
54
55mod internal_unsafe {
58 use spacetimedb_lib::{de::Deserialize, ser::Serialize};
59 use spacetimedb_sats::{impl_deserialize, impl_serialize};
60
61 use super::{BitBlock, DefaultBitBlock};
62 use crate::{static_assert_align, static_assert_size};
63 use core::{
64 mem,
65 ptr::NonNull,
66 slice::{from_raw_parts, from_raw_parts_mut},
67 };
68
69 pub(super) type Len = u16;
73
74 #[repr(C, packed)]
78 pub struct FixedBitSet<B = DefaultBitBlock> {
79 len: Len,
81 ptr: NonNull<B>,
83 }
84
85 static_assert_align!(FixedBitSet, 1);
86 static_assert_size!(FixedBitSet, mem::size_of::<usize>() + mem::size_of::<Len>());
87
88 unsafe impl<B> Send for FixedBitSet<B> {}
90 unsafe impl<B> Sync for FixedBitSet<B> {}
92
93 impl<B> Drop for FixedBitSet<B> {
94 fn drop(&mut self) {
95 let blocks = self.storage_mut();
96 let _ = unsafe { Box::from_raw(blocks) };
103 }
104 }
105
106 impl_serialize!([B: BitBlock + Serialize] FixedBitSet<B>, (self, ser) => self.storage().serialize(ser));
108 impl_deserialize!([B: BitBlock + Deserialize<'de>] FixedBitSet<B>, de => {
109 let storage = Box::<[B]>::deserialize(de)?;
110 Ok(Self::from_boxed_slice(storage))
111 });
112
113 impl<B: BitBlock> FixedBitSet<B> {
114 pub(super) fn from_boxed_slice(storage: Box<[B]>) -> Self {
115 assert!(storage.len() <= Len::MAX as usize);
118 let len = storage.len() as Len;
119 let ptr = NonNull::from(Box::leak(storage)).cast();
120 Self { ptr, len }
121 }
122 }
123
124 impl<B> FixedBitSet<B> {
125 #[inline]
127 pub(super) const fn blocks(&self) -> usize {
128 self.len as usize
129 }
130
131 pub(crate) const fn storage(&self) -> &[B] {
133 let ptr = self.ptr.as_ptr();
134 let len = self.blocks();
135 unsafe { from_raw_parts(ptr, len) }
143 }
144
145 pub(super) fn storage_mut(&mut self) -> &mut [B] {
147 let ptr = self.ptr.as_ptr();
148 let len = self.blocks();
149 unsafe { from_raw_parts_mut(ptr, len) }
157 }
158 }
159}
160
161impl<B: BitBlock> cmp::Eq for FixedBitSet<B> {}
162impl<B: BitBlock> cmp::PartialEq for FixedBitSet<B> {
163 fn eq(&self, other: &Self) -> bool {
164 self.storage() == other.storage()
165 }
166}
167
168fn blocks_for_bits<B: BitBlock>(bits: usize) -> usize {
170 bits.div_ceil(B::BITS as usize)
172}
173
174impl<B: BitBlock> FixedBitSet<B> {
175 pub fn new(bits: usize) -> Self {
177 Self::new_for_blocks(blocks_for_bits::<B>(bits))
178 }
179
180 #[inline]
182 fn new_for_blocks(nblocks: usize) -> Self {
183 let blocks: Box<[B]> = vec![B::ZERO; nblocks].into_boxed_slice();
185
186 Self::from_boxed_slice(blocks)
187 }
188
189 const fn idx_to_pos(idx: usize) -> (usize, usize) {
191 let bits = B::BITS as usize;
192 (idx / bits, idx % bits)
193 }
194
195 pub fn get(&self, idx: usize) -> bool {
197 let (block_idx, pos_in_block) = Self::idx_to_pos(idx);
198 let block = self.storage()[block_idx];
199 (block & (B::ONE << pos_in_block)) != B::ZERO
200 }
201
202 pub fn set(&mut self, idx: usize, val: bool) {
204 let (block_idx, pos_in_block) = Self::idx_to_pos(idx);
205 let block = &mut self.storage_mut()[block_idx];
206
207 let flag = B::ONE << pos_in_block;
209 *block = if val { *block | flag } else { *block & !flag };
210 }
211
212 pub fn clear(&mut self) {
214 self.storage_mut().fill(B::ZERO);
215 }
216
217 pub fn reset_for(&mut self, bits: usize) {
221 let nblocks = blocks_for_bits::<B>(bits);
223
224 if nblocks == self.blocks() {
226 self.clear();
227 } else {
228 *self = Self::new_for_blocks(nblocks)
229 }
230 }
231
232 #[inline]
234 pub(crate) const fn bits(&self) -> usize {
235 self.blocks() * B::BITS as usize
236 }
237
238 pub fn iter_set(&self) -> IterSet<'_, B> {
240 let mut inner = self.storage().iter();
241
242 let curr = inner.next().copied().unwrap_or(B::ZERO);
245
246 IterSet {
247 inner,
248 curr,
249 block_idx: 0,
250 }
251 }
252
253 pub fn iter_set_from(&self, start_idx: usize) -> IterSet<'_, B> {
255 let (block_idx, pos_in_block) = Self::idx_to_pos(start_idx);
257
258 let mut inner = self.storage()[block_idx..].iter();
260
261 let curr = inner.next().copied().unwrap_or(B::ZERO);
264
265 let zero_preceding_mask = B::ZERO.wrapping_sub(B::ONE << pos_in_block);
270 let curr = curr & zero_preceding_mask;
271
272 IterSet {
273 inner,
274 curr,
275 block_idx: block_idx as Len,
276 }
277 }
278}
279
280impl<B> MemoryUsage for FixedBitSet<B> {
281 fn heap_usage(&self) -> usize {
282 std::mem::size_of_val(self.storage())
283 }
284}
285
286impl<B: BitBlock> fmt::Debug for FixedBitSet<B> {
287 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
288 f.debug_set().entries(self.iter_set()).finish()
289 }
290}
291
292pub struct IterSet<'a, B = DefaultBitBlock> {
294 inner: Iter<'a, B>,
296 curr: B,
298 block_idx: Len,
300}
301
302impl<B: BitBlock> Iterator for IterSet<'_, B> {
303 type Item = usize;
304
305 fn next(&mut self) -> Option<Self::Item> {
306 loop {
307 let tz = self.curr.trailing_zeros();
308 if tz < B::BITS {
309 self.curr &= self.curr.wrapping_sub(B::ONE);
312 let idx = self.block_idx as u32 * B::BITS + tz;
313 return Some(idx as usize);
314 } else {
315 self.curr = *self.inner.next()?;
317 self.block_idx += 1;
318 }
319 }
320 }
321}
322
#[cfg(test)]
pub(crate) mod test {
    use super::*;
    use proptest::bits::bitset::between;
    use proptest::prelude::*;
    use spacetimedb_data_structures::map::HashSet;

    /// A zero-capacity set must support clearing and iteration without panicking.
    ///
    /// The expected out-of-bounds panic is isolated in `zero_sized_get_panics`.
    /// An unexpected panic here therefore fails the test instead of being absorbed
    /// by `#[should_panic]`.
    #[test]
    fn zero_sized_is_ok() {
        let mut set = FixedBitSet::<DefaultBitBlock>::new(0);
        set.clear();
        assert_eq!(set.iter_set_from(0).count(), 0);
        assert_eq!(set.iter_set().count(), 0);
    }

    /// A zero-capacity set has no bits, so reading bit 0 is out of bounds.
    #[test]
    #[should_panic]
    fn zero_sized_get_panics() {
        let set = FixedBitSet::<DefaultBitBlock>::new(0);
        let _ = set.get(0);
    }

    const MAX_NBITS: usize = 1000;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(if cfg!(miri) { 8 } else { 2048 }))]

        #[test]
        fn after_new_there_are_no_bits_set(nbits in 0..MAX_NBITS) {
            let set = FixedBitSet::<DefaultBitBlock>::new(nbits);
            for idx in 0..nbits {
                prop_assert!(!set.get(idx));
            }
        }

        #[test]
        fn after_clear_there_are_no_bits_set(choices in between(0, MAX_NBITS)) {
            let nbits = choices.get_ref().len();

            let mut set = FixedBitSet::<DefaultBitBlock>::new(nbits);

            for idx in &choices {
                // Sanity check: each chosen bit starts unset and becomes set.
                prop_assert!(!set.get(idx));
                set.set(idx, true);
                prop_assert!(set.get(idx));
            }

            set.clear();

            for idx in 0..nbits {
                prop_assert!(!set.get(idx));
            }
        }

        #[test]
        fn get_set_consistency(choices in between(0, MAX_NBITS)) {
            let nbits = choices.get_ref().len();
            let mut set = FixedBitSet::<DefaultBitBlock>::new(nbits);

            for idx in &choices {
                prop_assert!(!set.get(idx));

                set.set(idx, true);
                prop_assert!(set.get(idx));
                // Setting an already-set bit is idempotent.
                set.set(idx, true);
                prop_assert!(set.get(idx));
            }

            // Bits that were never chosen must still be unset.
            let choices: HashSet<_> = choices.into_iter().collect();
            let universe: HashSet<_> = (0..nbits).collect();
            for idx in universe.difference(&choices) {
                prop_assert!(!set.get(*idx));
            }

            for idx in &choices {
                set.set(*idx, false);
                prop_assert!(!set.get(*idx));
                // Unsetting an already-unset bit is idempotent.
                set.set(*idx, false);
                prop_assert!(!set.get(*idx));
            }
        }

        #[test]
        fn iter_set_preserves_order_of_original_choices(choices in between(0, MAX_NBITS)) {
            let nbits = choices.get_ref().len();

            let mut set = FixedBitSet::<DefaultBitBlock>::new(nbits);
            for idx in &choices {
                set.set(idx, true);
            }

            let collected = set.iter_set().collect::<Vec<_>>();
            let original = choices.iter().collect::<Vec<_>>();
            prop_assert_eq!(&original, &collected);

            // Starting from the second chosen index must skip exactly the first one.
            if let [_, second, ..] = &*original {
                let collected = set.iter_set_from(*second).collect::<Vec<_>>();
                prop_assert_eq!(&original[1..], &collected);
            }

            // `iter_set_from(0)` must agree with `iter_set()`.
            prop_assert_eq!(collected, set.iter_set_from(0).collect::<Vec<_>>());
        }

        #[test]
        fn reset_for_yields_empty_set_of_requested_size(old_bits in 0..MAX_NBITS, new_bits in 0..MAX_NBITS) {
            let mut set = FixedBitSet::<DefaultBitBlock>::new(old_bits);
            for idx in 0..old_bits {
                set.set(idx, true);
            }

            set.reset_for(new_bits);

            // Capacity must match a freshly created set, whether or not it reallocated.
            prop_assert_eq!(set.bits(), FixedBitSet::<DefaultBitBlock>::new(new_bits).bits());
            prop_assert_eq!(set.iter_set().count(), 0);
        }

        #[test]
        fn serde_round_trip(choices in between(0, MAX_NBITS)) {
            let nbits = choices.get_ref().len();

            let mut set = FixedBitSet::<DefaultBitBlock>::new(nbits);
            for idx in &choices {
                set.set(idx, true);
            }

            let ser = spacetimedb_lib::bsatn::to_vec(&set)?;
            let de = spacetimedb_lib::bsatn::from_slice::<FixedBitSet<DefaultBitBlock>>(&ser)?;

            // `prop_assert_eq!` reports a proper proptest failure with both values.
            prop_assert_eq!(set, de);
        }
    }
}