utils_atomics/
take.rs

1use crate::{InnerAtomicFlag, FALSE, TRUE};
2use core::{
3    cell::UnsafeCell,
4    mem::{needs_drop, MaybeUninit},
5    sync::atomic::Ordering,
6};
7
8/// Inverse of a `OnceCell`. It initializes with a value, which then can be raced by other threads to take.
9///
10/// Once the value is taken, it can never be taken again.
11pub struct TakeCell<T> {
12    taken: InnerAtomicFlag,
13    v: UnsafeCell<MaybeUninit<T>>,
14}
15
16impl<T> TakeCell<T> {
17    /// Creates a new [`TakeCell`]
18    #[inline]
19    pub const fn new(v: T) -> Self {
20        Self {
21            taken: InnerAtomicFlag::new(FALSE),
22            v: UnsafeCell::new(MaybeUninit::new(v)),
23        }
24    }
25
26    /// Creates a [`TakeCell`] that has already been taken
27    #[inline]
28    pub const fn new_taken() -> Self {
29        Self {
30            taken: InnerAtomicFlag::new(TRUE),
31            v: UnsafeCell::new(MaybeUninit::uninit()),
32        }
33    }
34
35    /// Checks if the cell has alredy been taken
36    #[inline]
37    pub fn is_taken(&self) -> bool {
38        self.taken.load(Ordering::Relaxed) == TRUE
39    }
40
41    /// Attempts to take the value from the cell, returning `None` if the value has already been taken
42    #[inline]
43    pub fn try_take(&self) -> Option<T> {
44        if self
45            .taken
46            .compare_exchange(FALSE, TRUE, Ordering::AcqRel, Ordering::Acquire)
47            .is_ok()
48        {
49            unsafe {
50                let v = &*self.v.get();
51                return Some(v.assume_init_read());
52            }
53        }
54        None
55    }
56
57    /// Attempts to take the value from the cell through non-atomic operations, returning `None` if the value has already been taken
58    ///
59    /// # Safety
60    /// This method is safe because the mutable reference indicates we are the only thread with access to the cell,
61    /// so atomic operations aren't required.
62    #[inline]
63    pub fn try_take_mut(&mut self) -> Option<T> {
64        let taken = self.taken.get_mut();
65        if *taken == FALSE {
66            *taken = TRUE;
67
68            unsafe { return Some(self.v.get_mut().assume_init_read()) }
69        }
70        None
71    }
72}
73
74impl<T> Drop for TakeCell<T> {
75    #[inline]
76    fn drop(&mut self) {
77        if needs_drop::<T>() && *self.taken.get_mut() == FALSE {
78            unsafe { self.v.get_mut().assume_init_drop() }
79        }
80    }
81}
82
83unsafe impl<T: Send> Send for TakeCell<T> {}
84unsafe impl<T: Sync> Sync for TakeCell<T> {}
85
// Thanks ChatGPT!
#[cfg(test)]
mod tests {
    use super::TakeCell;

    /// Single-threaded take semantics through both the atomic and the `&mut` paths.
    #[test]
    fn test_normal_conditions() {
        let cell = TakeCell::new(42);
        assert!(!cell.is_taken());
        assert_eq!(cell.try_take(), Some(42));
        assert!(cell.is_taken());
        assert_eq!(cell.try_take(), None);

        let mut cell = TakeCell::new(42);
        assert_eq!(cell.try_take_mut(), Some(42));
        assert!(cell.is_taken());
        assert_eq!(cell.try_take_mut(), None);
        assert_eq!(cell.try_take(), None);
    }

    /// A cell created empty never yields a value.
    #[test]
    fn test_new_taken() {
        let cell = TakeCell::<i32>::new_taken();
        assert!(cell.is_taken());
        assert_eq!(cell.try_take(), None);

        let mut cell = TakeCell::<i32>::new_taken();
        assert_eq!(cell.try_take_mut(), None);
    }

    /// The stored value is dropped exactly once: by the cell if never taken, otherwise by
    /// whoever took it.
    #[test]
    fn test_drop_runs_exactly_once() {
        use core::sync::atomic::{AtomicUsize, Ordering};

        // Local to this test so parallel test threads cannot interfere with the count.
        static DROPS: AtomicUsize = AtomicUsize::new(0);

        struct Counted;
        impl Drop for Counted {
            fn drop(&mut self) {
                DROPS.fetch_add(1, Ordering::Relaxed);
            }
        }

        // Never taken: the cell drops the value itself.
        drop(TakeCell::new(Counted));
        assert_eq!(DROPS.load(Ordering::Relaxed), 1);

        // Taken: the caller owns the value and the cell must not drop it a second time.
        let cell = TakeCell::new(Counted);
        let taken = cell.try_take();
        assert!(taken.is_some());
        drop(cell);
        assert_eq!(DROPS.load(Ordering::Relaxed), 1);
        drop(taken);
        assert_eq!(DROPS.load(Ordering::Relaxed), 2);

        // Created empty: there is nothing to drop.
        drop(TakeCell::<Counted>::new_taken());
        assert_eq!(DROPS.load(Ordering::Relaxed), 2);
    }

    /// Many threads race on one cell; exactly one of them must win the value.
    #[cfg(feature = "std")]
    #[test]
    fn test_stressed_conditions() {
        use alloc::vec::Vec;
        use std::{
            sync::{Arc, Barrier},
            thread,
        };

        const THREADS: usize = 10;

        let cell = Arc::new(TakeCell::new(42));
        let barrier = Arc::new(Barrier::new(THREADS));

        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let c = Arc::clone(&cell);
                let b = Arc::clone(&barrier);
                thread::spawn(move || {
                    b.wait();
                    c.try_take()
                })
            })
            .collect();

        // A double-take would show up here as more than one winner.
        let winners: Vec<i32> = handles
            .into_iter()
            .filter_map(|handle| handle.join().unwrap())
            .collect();
        assert_eq!(winners, [42]);

        assert!(cell.is_taken());
        assert_eq!(cell.try_take(), None);
    }
}