spacetimedb_table/
fixed_bit_set.rs1use core::{
2 cmp, fmt,
3 ops::{BitAnd, BitAndAssign, BitOr, Not, Shl},
4 slice::Iter,
5};
6pub use internal_unsafe::FixedBitSet;
7use internal_unsafe::Len;
8
9use crate::MemoryUsage;
10
/// A machine word used as one storage block of a [`FixedBitSet`].
///
/// The bit set only relies on the operations listed here,
/// so any `Copy` integer-like type that provides them can serve as a block.
pub trait BitBlock
where
    Self: Copy
        + Eq
        + Not<Output = Self>
        + BitAnd<Self, Output = Self>
        + BitAndAssign
        + BitOr<Self, Output = Self>
        + Shl<usize, Output = Self>,
{
    /// How many bits a single block holds.
    const BITS: u32;

    /// The block in which only the lowest bit is set.
    const ONE: Self;

    /// The block in which no bit is set.
    const ZERO: Self;

    /// Subtracts `rhs` from `self`, wrapping around instead of overflowing.
    fn wrapping_sub(self, rhs: Self) -> Self;

    /// Counts the unset bits below the lowest set bit.
    ///
    /// Iteration over set bits treats a result of `Self::BITS` or more as "no bit is set".
    fn trailing_zeros(self) -> u32;
}
36
37type DefaultBitBlock = u64;
38
39impl BitBlock for DefaultBitBlock {
40 const BITS: u32 = Self::BITS;
41 const ONE: Self = 1;
42 const ZERO: Self = 0;
43
44 #[inline]
45 fn wrapping_sub(self, rhs: Self) -> Self {
46 self.wrapping_sub(rhs)
47 }
48
49 #[inline]
50 fn trailing_zeros(self) -> u32 {
51 self.trailing_zeros()
52 }
53}
54
55mod internal_unsafe {
58 use spacetimedb_lib::{de::Deserialize, ser::Serialize};
59 use spacetimedb_sats::{impl_deserialize, impl_serialize};
60
61 use super::{BitBlock, DefaultBitBlock};
62 use crate::{static_assert_align, static_assert_size};
63 use core::{
64 mem,
65 ptr::NonNull,
66 slice::{from_raw_parts, from_raw_parts_mut},
67 };
68
69 pub(super) type Len = u16;
73
74 #[repr(C, packed)]
78 pub struct FixedBitSet<B = DefaultBitBlock> {
79 len: Len,
81 ptr: NonNull<B>,
83 }
84
85 static_assert_align!(FixedBitSet, 1);
86 static_assert_size!(FixedBitSet, mem::size_of::<usize>() + mem::size_of::<Len>());
87
88 unsafe impl<B> Send for FixedBitSet<B> {}
90 unsafe impl<B> Sync for FixedBitSet<B> {}
92
93 impl<B> Drop for FixedBitSet<B> {
94 fn drop(&mut self) {
95 let blocks = self.storage_mut();
96 let _ = unsafe { Box::from_raw(blocks) };
103 }
104 }
105
106 impl_serialize!([B: BitBlock + Serialize] FixedBitSet<B>, (self, ser) => self.storage().serialize(ser));
108 impl_deserialize!([B: BitBlock + Deserialize<'de>] FixedBitSet<B>, de => {
109 let storage = Box::<[B]>::deserialize(de)?;
110 Ok(Self::from_boxed_slice(storage))
111 });
112
113 impl<B: BitBlock> FixedBitSet<B> {
114 pub(super) fn from_boxed_slice(storage: Box<[B]>) -> Self {
115 assert!(storage.len() <= Len::MAX as usize);
118 let len = storage.len() as Len;
119 let ptr = NonNull::from(Box::leak(storage)).cast();
120 Self { ptr, len }
121 }
122 }
123
124 impl<B> FixedBitSet<B> {
125 pub(crate) const fn storage(&self) -> &[B] {
127 let ptr = self.ptr.as_ptr();
128 let len = self.len as usize;
129 unsafe { from_raw_parts(ptr, len) }
137 }
138
139 pub(super) fn storage_mut(&mut self) -> &mut [B] {
141 let ptr = self.ptr.as_ptr();
142 let len = self.len as usize;
143 unsafe { from_raw_parts_mut(ptr, len) }
151 }
152 }
153}
154
155impl<B: BitBlock> cmp::Eq for FixedBitSet<B> {}
156impl<B: BitBlock> cmp::PartialEq for FixedBitSet<B> {
157 fn eq(&self, other: &Self) -> bool {
158 self.storage() == other.storage()
159 }
160}
161
162fn blocks_for_bits<B: BitBlock>(bits: usize) -> usize {
164 bits.div_ceil(B::BITS as usize)
166}
167
168impl<B: BitBlock> FixedBitSet<B> {
169 pub fn new(bits: usize) -> Self {
171 let nblocks = blocks_for_bits::<B>(bits);
173
174 let blocks: Box<[B]> = vec![B::ZERO; nblocks].into_boxed_slice();
176
177 Self::from_boxed_slice(blocks)
178 }
179
180 const fn idx_to_pos(idx: usize) -> (usize, usize) {
182 let bits = B::BITS as usize;
183 (idx / bits, idx % bits)
184 }
185
186 pub fn get(&self, idx: usize) -> bool {
188 let (block_idx, pos_in_block) = Self::idx_to_pos(idx);
189 let block = self.storage()[block_idx];
190 (block & (B::ONE << pos_in_block)) != B::ZERO
191 }
192
193 pub fn set(&mut self, idx: usize, val: bool) {
195 let (block_idx, pos_in_block) = Self::idx_to_pos(idx);
196 let block = &mut self.storage_mut()[block_idx];
197
198 let flag = B::ONE << pos_in_block;
200 *block = if val { *block | flag } else { *block & !flag };
201 }
202
203 pub fn clear(&mut self) {
205 self.storage_mut().fill(B::ZERO);
206 }
207
208 pub fn iter_set(&self) -> IterSet<'_, B> {
210 let mut inner = self.storage().iter();
211
212 let curr = inner.next().copied().unwrap_or(B::ZERO);
215
216 IterSet {
217 inner,
218 curr,
219 block_idx: 0,
220 }
221 }
222
223 pub fn iter_set_from(&self, start_idx: usize) -> IterSet<'_, B> {
225 let (block_idx, pos_in_block) = Self::idx_to_pos(start_idx);
227
228 let mut inner = self.storage()[block_idx..].iter();
230
231 let curr = inner.next().copied().unwrap_or(B::ZERO);
234
235 let zero_preceding_mask = B::ZERO.wrapping_sub(B::ONE << pos_in_block);
240 let curr = curr & zero_preceding_mask;
241
242 IterSet {
243 inner,
244 curr,
245 block_idx: block_idx as Len,
246 }
247 }
248}
249
250impl<B> MemoryUsage for FixedBitSet<B> {
251 fn heap_usage(&self) -> usize {
252 std::mem::size_of_val(self.storage())
253 }
254}
255
256impl<B: BitBlock> fmt::Debug for FixedBitSet<B> {
257 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
258 f.debug_set().entries(self.iter_set()).finish()
259 }
260}
261
262pub struct IterSet<'a, B = DefaultBitBlock> {
264 inner: Iter<'a, B>,
266 curr: B,
268 block_idx: Len,
270}
271
272impl<B: BitBlock> Iterator for IterSet<'_, B> {
273 type Item = usize;
274
275 fn next(&mut self) -> Option<Self::Item> {
276 loop {
277 let tz = self.curr.trailing_zeros();
278 if tz < B::BITS {
279 self.curr &= self.curr.wrapping_sub(B::ONE);
282 let idx = self.block_idx as u32 * B::BITS + tz;
283 return Some(idx as usize);
284 } else {
285 self.curr = *self.inner.next()?;
287 self.block_idx += 1;
288 }
289 }
290 }
291}
292
#[cfg(test)]
pub(crate) mod test {
    use super::*;
    use proptest::bits::bitset::between;
    use proptest::prelude::*;
    use spacetimedb_data_structures::map::HashSet;

    /// The operations that do not address a particular bit must work on a set without blocks.
    #[test]
    fn zero_sized_is_ok() {
        let mut set = FixedBitSet::<DefaultBitBlock>::new(0);
        set.clear();
        assert_eq!(set.iter_set_from(0).count(), 0);
        assert_eq!(set.iter_set().count(), 0);
    }

    /// Reading a bit of a set without blocks is out of bounds.
    ///
    /// Kept apart from `zero_sized_is_ok` so that `should_panic`
    /// cannot hide a panic in one of the operations that are supposed to succeed.
    #[test]
    #[should_panic]
    fn zero_sized_get_panics() {
        let set = FixedBitSet::<DefaultBitBlock>::new(0);
        set.get(0);
    }

    const MAX_NBITS: usize = 1000;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(if cfg!(miri) { 8 } else { 2048 }))]

        #[test]
        fn after_new_there_are_no_bits_set(nbits in 0..MAX_NBITS) {
            let set = FixedBitSet::<DefaultBitBlock>::new(nbits);
            for idx in 0..nbits {
                prop_assert!(!set.get(idx));
            }
        }

        #[test]
        fn after_clear_there_are_no_bits_set(choices in between(0, MAX_NBITS)) {
            let nbits = choices.get_ref().len();

            let mut set = FixedBitSet::<DefaultBitBlock>::new(nbits);

            // Set all the chosen bits.
            for idx in &choices {
                prop_assert!(!set.get(idx));
                set.set(idx, true);
                prop_assert!(set.get(idx));
            }

            set.clear();

            // Now none of the bits may be set.
            for idx in 0..nbits {
                prop_assert!(!set.get(idx));
            }
        }

        #[test]
        fn get_set_consistency(choices in between(0, MAX_NBITS)) {
            let nbits = choices.get_ref().len();
            let mut set = FixedBitSet::<DefaultBitBlock>::new(nbits);

            for idx in &choices {
                // A fresh bit is unset.
                prop_assert!(!set.get(idx));

                // Setting it makes it set, and doing so again changes nothing.
                set.set(idx, true);
                prop_assert!(set.get(idx));
                set.set(idx, true);
                prop_assert!(set.get(idx));
            }

            // The bits that were not chosen must still be unset.
            let choices: HashSet<_> = choices.into_iter().collect();
            let universe: HashSet<_> = (0..nbits).collect();
            for idx in universe.difference(&choices) {
                prop_assert!(!set.get(*idx));
            }

            for idx in &choices {
                // Unsetting the bit makes it unset, and doing so again changes nothing.
                set.set(*idx, false);
                prop_assert!(!set.get(*idx));
                set.set(*idx, false);
                prop_assert!(!set.get(*idx));
            }
        }

        #[test]
        fn iter_set_preserves_order_of_original_choices(choices in between(0, MAX_NBITS)) {
            let nbits = choices.get_ref().len();

            let mut set = FixedBitSet::<DefaultBitBlock>::new(nbits);
            for idx in &choices {
                set.set(idx, true);
            }

            // Iterating the set must produce the chosen indices in the same order.
            let collected = set.iter_set().collect::<Vec<_>>();
            let original = choices.iter().collect::<Vec<_>>();
            prop_assert_eq!(&original, &collected);

            // Starting at the second chosen index must skip exactly the first one.
            if let [_, second, ..] = &*original {
                let collected = set.iter_set_from(*second).collect::<Vec<_>>();
                prop_assert_eq!(&original[1..], &collected);
            }

            // Starting at index 0 is the same as iterating everything.
            prop_assert_eq!(collected, set.iter_set_from(0).collect::<Vec<_>>());
        }

        #[test]
        fn serde_round_trip(choices in between(0, MAX_NBITS)) {
            let nbits = choices.get_ref().len();

            let mut set = FixedBitSet::<DefaultBitBlock>::new(nbits);
            for idx in &choices {
                set.set(idx, true);
            }

            let ser = spacetimedb_lib::bsatn::to_vec(&set)?;
            let de = spacetimedb_lib::bsatn::from_slice::<FixedBitSet<DefaultBitBlock>>(&ser)?;

            // Report a mismatch as a property failure, like the other properties do.
            prop_assert!(set == de);
        }
    }
}