1use crate::traits::{Atomic, AtomicBitAnd, AtomicBitOr, HasAtomicInt};
2use crate::AllocError;
3use crate::{div_ceil, InnerFlag};
4use alloc::boxed::Box;
5use bytemuck::Zeroable;
6use core::{
7 ops::{BitAnd, Not, Shl, Shr},
8 sync::atomic::Ordering,
9};
10use num_traits::Num;
11#[cfg(feature = "alloc_api")]
12use {alloc::alloc::Global, core::alloc::*};
13
/// A fixed-length array of bits stored in a boxed slice of atomic integers.
///
/// Bits are packed `8 * size_of::<T>()` to each element of the backing slice.
/// Individual bits are read, set and cleared through a shared reference (`&self`),
/// with the caller choosing the memory ordering of every access.
///
/// When the `alloc_api` feature is enabled, the backing slice lives in the allocator
/// `A` (which defaults to `Global`).
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
pub struct AtomicBitBox<
    T: HasAtomicInt = InnerFlag,
    #[cfg(feature = "alloc_api")] A: Allocator = Global,
> {
    // Backing words, allocated in the custom allocator `A`.
    #[cfg(feature = "alloc_api")]
    bits: Box<[T::AtomicInt], A>,
    // Backing words, allocated in the global allocator.
    #[cfg(not(feature = "alloc_api"))]
    bits: Box<[T::AtomicInt]>,
    // Number of addressable bits. The constructors size `bits` to hold exactly
    // `ceil(len / BIT_SIZE)` words, so the last word may contain unused tail bits.
    len: usize,
}
43
44impl<T: HasAtomicInt> AtomicBitBox<T>
45where
46 T: BitFieldAble,
47{
48 #[inline]
53 pub fn new(len: usize) -> Self {
54 Self::try_new(len).unwrap()
55 }
56
57 #[inline]
62 pub fn try_new(len: usize) -> Result<Self, AllocError> {
63 let count = div_ceil(len, Self::BIT_SIZE);
64
65 let bits;
66 unsafe {
67 cfg_if::cfg_if! {
68 if #[cfg(feature = "nightly")] {
69 let uninit = Box::<[T::AtomicInt]>::new_zeroed_slice(count);
70 bits = uninit.assume_init()
71 } else {
72 let mut tmp = alloc::vec::Vec::with_capacity(count);
73 core::ptr::write_bytes(tmp.as_mut_ptr(), 0, count);
74 tmp.set_len(count);
75 bits = tmp.into_boxed_slice();
76 }
77 };
78 }
79
80 Ok(Self { bits, len })
81 }
82}
83
84cfg_if::cfg_if! {
85 if #[cfg(feature = "alloc_api")] {
86 impl<T: HasAtomicInt, A: Allocator> AtomicBitBox<T, A> where T: BitFieldAble {
87 const BIT_SIZE: usize = 8 * core::mem::size_of::<T>();
88
89 #[inline]
94 pub fn new_in (len: usize, alloc: A) -> Self {
95 Self::try_new_in(len, alloc).unwrap()
96 }
97
98 #[inline]
103 pub fn try_new_in (len: usize, alloc: A) -> Result<Self, AllocError> {
104 let bytes = len.div_ceil(Self::BIT_SIZE);
105 let bits = unsafe {
106 let uninit = Box::<[T::AtomicInt], _>::new_zeroed_slice_in(bytes, alloc);
107 uninit.assume_init()
108 };
109
110 Ok(Self { bits, len })
111 }
112
113 pub fn get(&self, idx: usize, order: Ordering) -> Option<bool> {
117 let byte = idx / Self::BIT_SIZE;
118 let idx = idx % Self::BIT_SIZE;
119
120 if !self.check_bounds(byte, idx) {
121 return None
122 }
123
124 let byte = unsafe { <[T::AtomicInt]>::get_unchecked(&self.bits, byte) };
125 let v = byte.load(order);
126 let mask = T::one() << idx;
127 return Some((v & mask) != T::zero())
128 }
129
130 #[inline]
134 pub fn set_value (&self, v: bool, idx: usize, order: Ordering) -> Option<bool> {
135 if v { return self.set(idx, order) }
136 self.clear(idx, order)
137 }
138
139 #[inline]
143 pub fn set (&self, idx: usize, order: Ordering) -> Option<bool> {
144 let byte = idx / Self::BIT_SIZE;
145 let idx = idx % Self::BIT_SIZE;
146
147 if !self.check_bounds(byte, idx) {
148 return None
149 }
150
151 let byte = unsafe { <[T::AtomicInt]>::get_unchecked(&self.bits, byte) };
152 let mask = T::one() << idx;
153 let prev = byte.fetch_or(mask, order);
154 return Some((prev & mask) != T::zero())
155 }
156
157 #[inline]
161 pub fn clear (&self, idx: usize, order: Ordering) -> Option<bool> {
162 let byte = idx / Self::BIT_SIZE;
163 let idx = idx % Self::BIT_SIZE;
164
165 if !self.check_bounds(byte, idx) {
166 return None
167 }
168
169 let byte = unsafe { <[T::AtomicInt]>::get_unchecked(&self.bits, byte) };
170 let mask = T::one() << idx;
171 let prev = byte.fetch_and(!mask, order);
172 return Some((prev & mask) != T::zero())
173 }
174
175 #[inline]
176 fn check_bounds (&self, major: usize, minor: usize) -> bool {
177 if major < self.bits.len() - 1 {
178 return minor < Self::BIT_SIZE
179 }
180 return minor < self.len % Self::BIT_SIZE
181 }
182 }
183 } else {
184 impl<T: HasAtomicInt> AtomicBitBox<T> where T: BitFieldAble {
185 const BIT_SIZE: usize = 8 * core::mem::size_of::<T>();
186
187 pub fn get(&self, idx: usize, order: Ordering) -> Option<bool> {
191 let byte = idx / Self::BIT_SIZE;
192 let idx = idx % Self::BIT_SIZE;
193
194 if !self.check_bounds(byte, idx) {
195 return None
196 }
197
198 let byte = unsafe { <[T::AtomicInt]>::get_unchecked(&self.bits, byte) };
199 let v = byte.load(order);
200 let mask = T::one() << idx;
201 return Some((v & mask) != T::zero())
202 }
203
204 #[inline]
208 pub fn set_value (&self, v: bool, idx: usize, order: Ordering) -> Option<bool> {
209 if v { return self.set(idx, order) }
210 self.clear(idx, order)
211 }
212
213 #[inline]
217 pub fn set (&self, idx: usize, order: Ordering) -> Option<bool> {
218 let byte = idx / Self::BIT_SIZE;
219 let idx = idx % Self::BIT_SIZE;
220
221 if !self.check_bounds(byte, idx) {
222 return None
223 }
224
225 let byte = unsafe { <[T::AtomicInt]>::get_unchecked(&self.bits, byte) };
226 let mask = T::one() << idx;
227 let prev = byte.fetch_or(mask, order);
228 return Some((prev & mask) != T::zero())
229 }
230
231 #[inline]
235 pub fn clear (&self, idx: usize, order: Ordering) -> Option<bool> {
236 let byte = idx / Self::BIT_SIZE;
237 let idx = idx % Self::BIT_SIZE;
238
239 if !self.check_bounds(byte, idx) {
240 return None
241 }
242
243 let byte = unsafe { <[T::AtomicInt]>::get_unchecked(&self.bits, byte) };
244 let mask = T::one() << idx;
245 let prev = byte.fetch_and(!mask, order);
246 return Some((prev & mask) != T::zero())
247 }
248
249 #[inline]
250 fn check_bounds (&self, major: usize, minor: usize) -> bool {
251 if major < self.bits.len() - 1 {
252 return minor < Self::BIT_SIZE
253 }
254 return minor < self.len % Self::BIT_SIZE
255 }
256 }
257 }
258}
259
260pub trait BitFieldAble:
261 Num
262 + Copy
263 + Zeroable
264 + Eq
265 + BitAnd<Output = Self>
266 + Shl<usize, Output = Self>
267 + Shr<usize, Output = Self>
268 + Not<Output = Self>
269{
270}
271impl<T> BitFieldAble for T where
272 T: Num
273 + Copy
274 + Zeroable
275 + Eq
276 + BitAnd<Output = Self>
277 + Shl<usize, Output = Self>
278 + Shr<usize, Output = Self>
279 + Not<Output = Self>
280{
281}
282
#[cfg(test)]
mod tests {
    use core::sync::atomic::Ordering;

    pub type AtomicBitBox = super::AtomicBitBox<u16>;

    #[test]
    fn new_bitbox() {
        let bitbox = AtomicBitBox::new(10);
        for i in 0..10 {
            assert_eq!(bitbox.get(i, Ordering::SeqCst), Some(false));
        }
    }

    #[test]
    fn set_and_get() {
        let bitbox = AtomicBitBox::new(10);

        bitbox.set(2, Ordering::SeqCst);
        bitbox.set(7, Ordering::SeqCst);

        for i in 0..10 {
            let expected = (i == 2) || (i == 7);
            assert_eq!(bitbox.get(i, Ordering::SeqCst), Some(expected));
        }
    }

    #[test]
    fn set_false_and_get() {
        let bitbox = AtomicBitBox::new(10);

        bitbox.set(2, Ordering::SeqCst);
        bitbox.set(7, Ordering::SeqCst);

        bitbox.clear(2, Ordering::SeqCst);

        for i in 0..10 {
            let expected = i == 7;
            assert_eq!(bitbox.get(i, Ordering::SeqCst), Some(expected));
        }
    }

    // `set` and `clear` report the bit's value from before the operation.
    #[test]
    fn set_and_clear_return_previous_value() {
        let bitbox = AtomicBitBox::new(10);
        assert_eq!(bitbox.set(3, Ordering::SeqCst), Some(false));
        assert_eq!(bitbox.set(3, Ordering::SeqCst), Some(true));
        assert_eq!(bitbox.clear(3, Ordering::SeqCst), Some(true));
        assert_eq!(bitbox.clear(3, Ordering::SeqCst), Some(false));
    }

    // `set_value` dispatches to `set` / `clear` and propagates out-of-bounds `None`.
    #[test]
    fn set_value_dispatches() {
        let bitbox = AtomicBitBox::new(10);
        assert_eq!(bitbox.set_value(true, 4, Ordering::SeqCst), Some(false));
        assert_eq!(bitbox.get(4, Ordering::SeqCst), Some(true));
        assert_eq!(bitbox.set_value(false, 4, Ordering::SeqCst), Some(true));
        assert_eq!(bitbox.get(4, Ordering::SeqCst), Some(false));
        assert_eq!(bitbox.set_value(true, 11, Ordering::SeqCst), None);
    }

    // 20 bits need two u16 words; bits on both sides of the word boundary must work.
    #[test]
    fn spans_multiple_words() {
        let bitbox = AtomicBitBox::new(20);
        for i in [0, 15, 16, 19] {
            assert_eq!(bitbox.set(i, Ordering::SeqCst), Some(false));
        }
        for i in 0..20 {
            let expected = matches!(i, 0 | 15 | 16 | 19);
            assert_eq!(bitbox.get(i, Ordering::SeqCst), Some(expected));
        }
        assert_eq!(bitbox.get(20, Ordering::SeqCst), None);
    }

    #[test]
    fn out_of_bounds() {
        let bitbox = AtomicBitBox::new(10);
        assert_eq!(bitbox.get(11, Ordering::SeqCst), None);
        assert_eq!(bitbox.set(11, Ordering::SeqCst), None);
        assert_eq!(bitbox.clear(11, Ordering::SeqCst), None);
    }

    #[cfg(feature = "alloc_api")]
    mod custom_allocator {
        use core::sync::atomic::Ordering;
        use std::alloc::System;

        pub type AtomicBitBox = super::super::AtomicBitBox<u16, System>;

        #[test]
        fn new_bitbox() {
            let bitbox = AtomicBitBox::new_in(10, System);
            for i in 0..10 {
                assert_eq!(bitbox.get(i, Ordering::SeqCst), Some(false));
            }
        }

        #[test]
        fn set_and_get() {
            let bitbox = AtomicBitBox::new_in(10, System);

            bitbox.set(2, Ordering::SeqCst);
            bitbox.set(7, Ordering::SeqCst);

            for i in 0..10 {
                let expected = (i == 2) || (i == 7);
                assert_eq!(bitbox.get(i, Ordering::SeqCst), Some(expected));
            }
        }

        #[test]
        fn set_false_and_get() {
            let bitbox = AtomicBitBox::new_in(10, System);

            bitbox.set(2, Ordering::SeqCst);
            bitbox.set(7, Ordering::SeqCst);

            bitbox.clear(2, Ordering::SeqCst);

            for i in 0..10 {
                let expected = i == 7;
                assert_eq!(bitbox.get(i, Ordering::SeqCst), Some(expected));
            }
        }

        // `set` and `clear` report the bit's value from before the operation.
        #[test]
        fn set_and_clear_return_previous_value() {
            let bitbox = AtomicBitBox::new_in(10, System);
            assert_eq!(bitbox.set(3, Ordering::SeqCst), Some(false));
            assert_eq!(bitbox.set(3, Ordering::SeqCst), Some(true));
            assert_eq!(bitbox.clear(3, Ordering::SeqCst), Some(true));
            assert_eq!(bitbox.clear(3, Ordering::SeqCst), Some(false));
        }

        #[test]
        fn out_of_bounds() {
            let bitbox = AtomicBitBox::new_in(10, System);
            assert_eq!(bitbox.get(11, Ordering::SeqCst), None);
            assert_eq!(bitbox.set(11, Ordering::SeqCst), None);
            assert_eq!(bitbox.clear(11, Ordering::SeqCst), None);
        }
    }
}