utils_atomics/
bitfield.rs

use crate::traits::{Atomic, AtomicBitAnd, AtomicBitOr, HasAtomicInt};
use crate::AllocError;
use crate::{div_ceil, InnerFlag};
use alloc::boxed::Box;
use bytemuck::Zeroable;
use core::{
    ops::{BitAnd, Not, Shl, Shr},
    sync::atomic::Ordering,
};
use num_traits::Num;
#[cfg(feature = "alloc_api")]
use {alloc::alloc::Global, core::alloc::*};

/// An atomic bitfield with a static size, stored in a boxed slice.
///
/// This struct provides methods for working with atomic bitfields, allowing
/// concurrent access and manipulation of individual bits. It is particularly
/// useful when you need to store a large number of boolean flags and want to
/// minimize memory usage.
///
/// # Example
///
/// ```
/// use utils_atomics::AtomicBitBox;
/// use core::sync::atomic::Ordering;
///
/// let bit_box = AtomicBitBox::<u8>::new(10);
/// assert_eq!(bit_box.get(3, Ordering::Relaxed), Some(false));
/// bit_box.set(3, Ordering::Relaxed);
/// assert_eq!(bit_box.get(3, Ordering::Relaxed), Some(true));
/// ```
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
pub struct AtomicBitBox<
    T: HasAtomicInt = InnerFlag,
    #[cfg(feature = "alloc_api")] A: Allocator = Global,
> {
    #[cfg(feature = "alloc_api")]
    bits: Box<[T::AtomicInt], A>,
    #[cfg(not(feature = "alloc_api"))]
    bits: Box<[T::AtomicInt]>,
    len: usize,
}

impl<T: HasAtomicInt> AtomicBitBox<T>
where
    T: BitFieldAble,
{
    /// Allocates a new bitfield. All values are initialized to `false`.
    ///
    /// # Panics
    /// This method panics if the memory allocation fails.
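    ///
    /// # Example
    ///
    /// A minimal usage sketch:
    ///
    /// ```
    /// use utils_atomics::AtomicBitBox;
    /// use core::sync::atomic::Ordering;
    ///
    /// let bits = AtomicBitBox::<u8>::new(12);
    /// // Freshly allocated bits all read as `false`.
    /// assert_eq!(bits.get(11, Ordering::Relaxed), Some(false));
    /// ```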
    #[inline]
    pub fn new(len: usize) -> Self {
        Self::try_new(len).unwrap()
    }

    /// Allocates a new bitfield. All values are initialized to `false`.
    ///
    /// # Errors
    /// This method returns an error if the memory allocation fails.
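    ///
    /// # Example
    ///
    /// A minimal sketch of handling the fallible allocation:
    ///
    /// ```
    /// use utils_atomics::AtomicBitBox;
    /// use core::sync::atomic::Ordering;
    ///
    /// let bits = AtomicBitBox::<u8>::try_new(1024).expect("allocation failed");
    /// assert_eq!(bits.get(1023, Ordering::Relaxed), Some(false));
    /// ```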
    #[inline]
    pub fn try_new(len: usize) -> Result<Self, AllocError> {
        let count = div_ceil(len, Self::BIT_SIZE);

        let bits;
        unsafe {
            cfg_if::cfg_if! {
                if #[cfg(feature = "nightly")] {
                    // Nightly: allocate the slice directly as zeroed memory.
                    let uninit = Box::<[T::AtomicInt]>::new_zeroed_slice(count);
                    bits = uninit.assume_init();
                } else {
                    // Stable: zero the spare capacity by hand. The all-zero bit
                    // pattern is a valid `T::AtomicInt`, since `T: Zeroable`.
                    let mut tmp = alloc::vec::Vec::with_capacity(count);
                    core::ptr::write_bytes(tmp.as_mut_ptr(), 0, count);
                    tmp.set_len(count);
                    bits = tmp.into_boxed_slice();
                }
            };
        }

        Ok(Self { bits, len })
    }
}

cfg_if::cfg_if! {
    if #[cfg(feature = "alloc_api")] {
        impl<T: HasAtomicInt, A: Allocator> AtomicBitBox<T, A> where T: BitFieldAble {
            const BIT_SIZE: usize = 8 * core::mem::size_of::<T>();

            /// Allocates a new bitfield. All values are initialized to `false`.
            ///
            /// # Panics
            /// This method panics if the memory allocation fails.
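            ///
            /// # Example
            ///
            /// A minimal sketch, assuming the `alloc_api` feature on a nightly toolchain:
            ///
            /// ```
            /// #![feature(allocator_api)]
            /// use utils_atomics::AtomicBitBox;
            /// use std::alloc::System;
            /// use core::sync::atomic::Ordering;
            ///
            /// let bits = AtomicBitBox::<u8, _>::new_in(10, System);
            /// assert_eq!(bits.get(9, Ordering::Relaxed), Some(false));
            /// ```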
            #[inline]
            pub fn new_in(len: usize, alloc: A) -> Self {
                Self::try_new_in(len, alloc).unwrap()
            }

            /// Allocates a new bitfield. All values are initialized to `false`.
            ///
            /// # Errors
            /// This method returns an error if the memory allocation fails.
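            ///
            /// # Example
            ///
            /// A minimal sketch, assuming the `alloc_api` feature on a nightly toolchain:
            ///
            /// ```
            /// #![feature(allocator_api)]
            /// use utils_atomics::AtomicBitBox;
            /// use std::alloc::System;
            /// use core::sync::atomic::Ordering;
            ///
            /// let bits = AtomicBitBox::<u8, _>::try_new_in(64, System).expect("allocation failed");
            /// assert_eq!(bits.get(63, Ordering::Relaxed), Some(false));
            /// ```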
            #[inline]
            pub fn try_new_in(len: usize, alloc: A) -> Result<Self, AllocError> {
                // Number of atomic words needed to hold `len` bits.
                let count = div_ceil(len, Self::BIT_SIZE);
                let bits = unsafe {
                    let uninit = Box::<[T::AtomicInt], _>::new_zeroed_slice_in(count, alloc);
                    uninit.assume_init()
                };

                Ok(Self { bits, len })
            }

            /// Returns the value of the bit at the specified index, or `None` if the index is out of bounds.
            ///
            /// `order` defines the memory ordering for this operation.
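            ///
            /// # Example
            ///
            /// A minimal sketch; out-of-bounds reads yield `None` rather than panicking:
            ///
            /// ```
            /// use utils_atomics::AtomicBitBox;
            /// use core::sync::atomic::Ordering;
            ///
            /// let bits = AtomicBitBox::<u8>::new(8);
            /// assert_eq!(bits.get(5, Ordering::Relaxed), Some(false));
            /// assert_eq!(bits.get(8, Ordering::Relaxed), None);
            /// ```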
            pub fn get(&self, idx: usize, order: Ordering) -> Option<bool> {
                // Split the bit index into a word index and an offset within it.
                let word = idx / Self::BIT_SIZE;
                let bit = idx % Self::BIT_SIZE;

                if !self.check_bounds(word, bit) {
                    return None
                }

                let word = unsafe { <[T::AtomicInt]>::get_unchecked(&self.bits, word) };
                let v = word.load(order);
                let mask = T::one() << bit;
                return Some((v & mask) != T::zero())
            }

            /// Sets the value of the bit at the specified index and returns the previous value, or `None` if the index is out of bounds.
            ///
            /// `order` defines the memory ordering for this operation.
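            ///
            /// # Example
            ///
            /// A minimal sketch; the return value is the bit's previous state:
            ///
            /// ```
            /// use utils_atomics::AtomicBitBox;
            /// use core::sync::atomic::Ordering;
            ///
            /// let bits = AtomicBitBox::<u8>::new(8);
            /// assert_eq!(bits.set_value(true, 1, Ordering::Relaxed), Some(false));
            /// assert_eq!(bits.set_value(false, 1, Ordering::Relaxed), Some(true));
            /// ```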
            #[inline]
            pub fn set_value(&self, v: bool, idx: usize, order: Ordering) -> Option<bool> {
                if v { return self.set(idx, order) }
                self.clear(idx, order)
            }

            /// Sets the bit at the specified index to `true` and returns the previous value, or `None` if the index is out of bounds.
            ///
            /// `order` defines the memory ordering for this operation.
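            ///
            /// # Example
            ///
            /// A minimal sketch; setting an already-set bit reports the previous `true`:
            ///
            /// ```
            /// use utils_atomics::AtomicBitBox;
            /// use core::sync::atomic::Ordering;
            ///
            /// let bits = AtomicBitBox::<u8>::new(8);
            /// assert_eq!(bits.set(3, Ordering::Relaxed), Some(false));
            /// assert_eq!(bits.set(3, Ordering::Relaxed), Some(true));
            /// ```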
            #[inline]
            pub fn set(&self, idx: usize, order: Ordering) -> Option<bool> {
                let word = idx / Self::BIT_SIZE;
                let bit = idx % Self::BIT_SIZE;

                if !self.check_bounds(word, bit) {
                    return None
                }

                let word = unsafe { <[T::AtomicInt]>::get_unchecked(&self.bits, word) };
                let mask = T::one() << bit;
                // Atomically set the bit; `prev` holds the word's previous contents.
                let prev = word.fetch_or(mask, order);
                return Some((prev & mask) != T::zero())
            }

            /// Sets the bit at the specified index to `false` and returns the previous value, or `None` if the index is out of bounds.
            ///
            /// `order` defines the memory ordering for this operation.
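            ///
            /// # Example
            ///
            /// A minimal sketch:
            ///
            /// ```
            /// use utils_atomics::AtomicBitBox;
            /// use core::sync::atomic::Ordering;
            ///
            /// let bits = AtomicBitBox::<u8>::new(8);
            /// bits.set(3, Ordering::Relaxed);
            /// assert_eq!(bits.clear(3, Ordering::Relaxed), Some(true));
            /// assert_eq!(bits.get(3, Ordering::Relaxed), Some(false));
            /// ```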
            #[inline]
            pub fn clear(&self, idx: usize, order: Ordering) -> Option<bool> {
                let word = idx / Self::BIT_SIZE;
                let bit = idx % Self::BIT_SIZE;

                if !self.check_bounds(word, bit) {
                    return None
                }

                let word = unsafe { <[T::AtomicInt]>::get_unchecked(&self.bits, word) };
                let mask = T::one() << bit;
                // Atomically clear the bit via the inverted mask; `prev` holds the
                // word's previous contents.
                let prev = word.fetch_and(!mask, order);
                return Some((prev & mask) != T::zero())
            }

            #[inline]
            fn check_bounds(&self, major: usize, minor: usize) -> bool {
                // A bit is in bounds iff its absolute index is below `len`.
                // Comparing `minor` against `len % BIT_SIZE` alone would reject
                // every bit of the last word when `len` is an exact multiple of
                // `BIT_SIZE`, and accept indices past the end of the slice.
                major * Self::BIT_SIZE + minor < self.len
            }
        }
    } else {
        impl<T: HasAtomicInt> AtomicBitBox<T> where T: BitFieldAble {
            const BIT_SIZE: usize = 8 * core::mem::size_of::<T>();

            /// Returns the value of the bit at the specified index, or `None` if the index is out of bounds.
            ///
            /// `order` defines the memory ordering for this operation.
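            ///
            /// # Example
            ///
            /// A minimal sketch; out-of-bounds reads yield `None` rather than panicking:
            ///
            /// ```
            /// use utils_atomics::AtomicBitBox;
            /// use core::sync::atomic::Ordering;
            ///
            /// let bits = AtomicBitBox::<u8>::new(8);
            /// assert_eq!(bits.get(5, Ordering::Relaxed), Some(false));
            /// assert_eq!(bits.get(8, Ordering::Relaxed), None);
            /// ```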
            pub fn get(&self, idx: usize, order: Ordering) -> Option<bool> {
                // Split the bit index into a word index and an offset within it.
                let word = idx / Self::BIT_SIZE;
                let bit = idx % Self::BIT_SIZE;

                if !self.check_bounds(word, bit) {
                    return None
                }

                let word = unsafe { <[T::AtomicInt]>::get_unchecked(&self.bits, word) };
                let v = word.load(order);
                let mask = T::one() << bit;
                return Some((v & mask) != T::zero())
            }

            /// Sets the value of the bit at the specified index and returns the previous value, or `None` if the index is out of bounds.
            ///
            /// `order` defines the memory ordering for this operation.
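            ///
            /// # Example
            ///
            /// A minimal sketch; the return value is the bit's previous state:
            ///
            /// ```
            /// use utils_atomics::AtomicBitBox;
            /// use core::sync::atomic::Ordering;
            ///
            /// let bits = AtomicBitBox::<u8>::new(8);
            /// assert_eq!(bits.set_value(true, 1, Ordering::Relaxed), Some(false));
            /// assert_eq!(bits.set_value(false, 1, Ordering::Relaxed), Some(true));
            /// ```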
            #[inline]
            pub fn set_value(&self, v: bool, idx: usize, order: Ordering) -> Option<bool> {
                if v { return self.set(idx, order) }
                self.clear(idx, order)
            }

            /// Sets the bit at the specified index to `true` and returns the previous value, or `None` if the index is out of bounds.
            ///
            /// `order` defines the memory ordering for this operation.
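            ///
            /// # Example
            ///
            /// A minimal sketch; setting an already-set bit reports the previous `true`:
            ///
            /// ```
            /// use utils_atomics::AtomicBitBox;
            /// use core::sync::atomic::Ordering;
            ///
            /// let bits = AtomicBitBox::<u8>::new(8);
            /// assert_eq!(bits.set(3, Ordering::Relaxed), Some(false));
            /// assert_eq!(bits.set(3, Ordering::Relaxed), Some(true));
            /// ```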
            #[inline]
            pub fn set(&self, idx: usize, order: Ordering) -> Option<bool> {
                let word = idx / Self::BIT_SIZE;
                let bit = idx % Self::BIT_SIZE;

                if !self.check_bounds(word, bit) {
                    return None
                }

                let word = unsafe { <[T::AtomicInt]>::get_unchecked(&self.bits, word) };
                let mask = T::one() << bit;
                // Atomically set the bit; `prev` holds the word's previous contents.
                let prev = word.fetch_or(mask, order);
                return Some((prev & mask) != T::zero())
            }

            /// Sets the bit at the specified index to `false` and returns the previous value, or `None` if the index is out of bounds.
            ///
            /// `order` defines the memory ordering for this operation.
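            ///
            /// # Example
            ///
            /// A minimal sketch:
            ///
            /// ```
            /// use utils_atomics::AtomicBitBox;
            /// use core::sync::atomic::Ordering;
            ///
            /// let bits = AtomicBitBox::<u8>::new(8);
            /// bits.set(3, Ordering::Relaxed);
            /// assert_eq!(bits.clear(3, Ordering::Relaxed), Some(true));
            /// assert_eq!(bits.get(3, Ordering::Relaxed), Some(false));
            /// ```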
            #[inline]
            pub fn clear(&self, idx: usize, order: Ordering) -> Option<bool> {
                let word = idx / Self::BIT_SIZE;
                let bit = idx % Self::BIT_SIZE;

                if !self.check_bounds(word, bit) {
                    return None
                }

                let word = unsafe { <[T::AtomicInt]>::get_unchecked(&self.bits, word) };
                let mask = T::one() << bit;
                // Atomically clear the bit via the inverted mask; `prev` holds the
                // word's previous contents.
                let prev = word.fetch_and(!mask, order);
                return Some((prev & mask) != T::zero())
            }

            #[inline]
            fn check_bounds(&self, major: usize, minor: usize) -> bool {
                // A bit is in bounds iff its absolute index is below `len`.
                // Comparing `minor` against `len % BIT_SIZE` alone would reject
                // every bit of the last word when `len` is an exact multiple of
                // `BIT_SIZE`, and accept indices past the end of the slice.
                major * Self::BIT_SIZE + minor < self.len
            }
        }
    }
}

/// Marker trait for integer types that can back an [`AtomicBitBox`]: zeroable
/// integers supporting the bitwise operations the bitfield needs. It is
/// blanket-implemented for every qualifying type.
pub trait BitFieldAble:
    Num
    + Copy
    + Zeroable
    + Eq
    + BitAnd<Output = Self>
    + Shl<usize, Output = Self>
    + Shr<usize, Output = Self>
    + Not<Output = Self>
{
}
impl<T> BitFieldAble for T where
    T: Num
        + Copy
        + Zeroable
        + Eq
        + BitAnd<Output = Self>
        + Shl<usize, Output = Self>
        + Shr<usize, Output = Self>
        + Not<Output = Self>
{
}

// Thanks ChatGPT!
#[cfg(test)]
mod tests {
    use core::sync::atomic::Ordering;

    pub type AtomicBitBox = super::AtomicBitBox<u16>;

    #[test]
    fn new_bitbox() {
        let bitbox = AtomicBitBox::new(10);
        for i in 0..10 {
            assert_eq!(bitbox.get(i, Ordering::SeqCst), Some(false));
        }
    }

    #[test]
    fn set_and_get() {
        let bitbox = AtomicBitBox::new(10);

        bitbox.set(2, Ordering::SeqCst);
        bitbox.set(7, Ordering::SeqCst);

        for i in 0..10 {
            let expected = (i == 2) || (i == 7);
            assert_eq!(bitbox.get(i, Ordering::SeqCst), Some(expected));
        }
    }

    #[test]
    fn set_false_and_get() {
        let bitbox = AtomicBitBox::new(10);

        bitbox.set(2, Ordering::SeqCst);
        bitbox.set(7, Ordering::SeqCst);

        bitbox.clear(2, Ordering::SeqCst);

        for i in 0..10 {
            let expected = i == 7;
            assert_eq!(bitbox.get(i, Ordering::SeqCst), Some(expected));
        }
    }

    #[test]
    fn out_of_bounds() {
        let bitbox = AtomicBitBox::new(10);
        assert_eq!(bitbox.get(11, Ordering::SeqCst), None);
        assert_eq!(bitbox.set(11, Ordering::SeqCst), None);
        assert_eq!(bitbox.clear(11, Ordering::SeqCst), None);
    }

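    // A regression test (not in the original suite) for the bounds check:
    // when `len` is an exact multiple of the word size, every bit of the
    // last word must stay addressable, and the next index must be rejected.
    #[test]
    fn last_word_boundary() {
        let bitbox = AtomicBitBox::new(16);
        assert_eq!(bitbox.set(15, Ordering::SeqCst), Some(false));
        assert_eq!(bitbox.get(15, Ordering::SeqCst), Some(true));
        assert_eq!(bitbox.get(16, Ordering::SeqCst), None);
    }
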
    #[cfg(feature = "alloc_api")]
    mod custom_allocator {
        use core::sync::atomic::Ordering;
        use std::alloc::System;

        pub type AtomicBitBox = super::super::AtomicBitBox<u16, System>;

        #[test]
        fn new_bitbox() {
            let bitbox = AtomicBitBox::new_in(10, System);
            for i in 0..10 {
                assert_eq!(bitbox.get(i, Ordering::SeqCst), Some(false));
            }
        }

        #[test]
        fn set_and_get() {
            let bitbox = AtomicBitBox::new_in(10, System);

            bitbox.set(2, Ordering::SeqCst);
            bitbox.set(7, Ordering::SeqCst);

            for i in 0..10 {
                let expected = (i == 2) || (i == 7);
                assert_eq!(bitbox.get(i, Ordering::SeqCst), Some(expected));
            }
        }

        #[test]
        fn set_false_and_get() {
            let bitbox = AtomicBitBox::new_in(10, System);

            bitbox.set(2, Ordering::SeqCst);
            bitbox.set(7, Ordering::SeqCst);

            bitbox.clear(2, Ordering::SeqCst);

            for i in 0..10 {
                let expected = i == 7;
                assert_eq!(bitbox.get(i, Ordering::SeqCst), Some(expected));
            }
        }

        #[test]
        fn out_of_bounds() {
            let bitbox = AtomicBitBox::new_in(10, System);
            assert_eq!(bitbox.get(11, Ordering::SeqCst), None);
            assert_eq!(bitbox.set(11, Ordering::SeqCst), None);
            assert_eq!(bitbox.clear(11, Ordering::SeqCst), None);
        }
    }
}