1use crate::{InnerAtomicFlag, FALSE, TRUE};
2use core::{
3 cell::UnsafeCell,
4 mem::{needs_drop, MaybeUninit},
5 sync::atomic::Ordering,
6};
7
/// A cell holding a value that can be moved out at most once, including
/// through a shared reference (see `try_take`).
///
/// Invariant: `v` holds an initialised `T` exactly while `taken` is `FALSE`.
pub struct TakeCell<T> {
    /// `FALSE` while the value is still stored in `v`; `TRUE` once it has been
    /// taken, or from the start for a cell built with `new_taken`. No method
    /// of this type sets it back to `FALSE`.
    taken: InnerAtomicFlag,
    /// Storage for the value; only initialised while `taken` is `FALSE`, so it
    /// is wrapped in `MaybeUninit` and read/dropped manually.
    v: UnsafeCell<MaybeUninit<T>>,
}
15
16impl<T> TakeCell<T> {
17 #[inline]
19 pub const fn new(v: T) -> Self {
20 Self {
21 taken: InnerAtomicFlag::new(FALSE),
22 v: UnsafeCell::new(MaybeUninit::new(v)),
23 }
24 }
25
26 #[inline]
28 pub const fn new_taken() -> Self {
29 Self {
30 taken: InnerAtomicFlag::new(TRUE),
31 v: UnsafeCell::new(MaybeUninit::uninit()),
32 }
33 }
34
35 #[inline]
37 pub fn is_taken(&self) -> bool {
38 self.taken.load(Ordering::Relaxed) == TRUE
39 }
40
41 #[inline]
43 pub fn try_take(&self) -> Option<T> {
44 if self
45 .taken
46 .compare_exchange(FALSE, TRUE, Ordering::AcqRel, Ordering::Acquire)
47 .is_ok()
48 {
49 unsafe {
50 let v = &*self.v.get();
51 return Some(v.assume_init_read());
52 }
53 }
54 None
55 }
56
57 #[inline]
63 pub fn try_take_mut(&mut self) -> Option<T> {
64 let taken = self.taken.get_mut();
65 if *taken == FALSE {
66 *taken = TRUE;
67
68 unsafe { return Some(self.v.get_mut().assume_init_read()) }
69 }
70 None
71 }
72}
73
74impl<T> Drop for TakeCell<T> {
75 #[inline]
76 fn drop(&mut self) {
77 if needs_drop::<T>() && *self.taken.get_mut() == FALSE {
78 unsafe { self.v.get_mut().assume_init_drop() }
79 }
80 }
81}
82
83unsafe impl<T: Send> Send for TakeCell<T> {}
84unsafe impl<T: Sync> Sync for TakeCell<T> {}
85
#[cfg(test)]
mod tests {
    use super::TakeCell;
    use core::cell::Cell;

    /// Increments a shared counter when dropped, so tests can check that the
    /// cell neither leaks nor double-drops its value.
    struct DropCounter<'a>(&'a Cell<usize>);

    impl Drop for DropCounter<'_> {
        fn drop(&mut self) {
            self.0.set(self.0.get() + 1);
        }
    }

    #[test]
    fn test_normal_conditions() {
        let cell = TakeCell::new(42);
        assert!(!cell.is_taken());
        assert_eq!(cell.try_take(), Some(42));
        assert!(cell.is_taken());
        assert_eq!(cell.try_take(), None);

        let mut cell = TakeCell::new(42);
        assert_eq!(cell.try_take_mut(), Some(42));
        assert!(cell.is_taken());
        assert_eq!(cell.try_take_mut(), None);
    }

    #[test]
    fn test_new_taken() {
        let mut cell = TakeCell::<i32>::new_taken();
        assert!(cell.is_taken());
        assert_eq!(cell.try_take(), None);
        assert_eq!(cell.try_take_mut(), None);
    }

    #[test]
    fn test_drop_exactly_once() {
        let drops = Cell::new(0);

        // Never taken: the cell itself must drop the value.
        drop(TakeCell::new(DropCounter(&drops)));
        assert_eq!(drops.get(), 1);

        // Taken: the caller owns the value, and the cell must not drop it again.
        let cell = TakeCell::new(DropCounter(&drops));
        let taken = cell.try_take();
        assert!(taken.is_some());
        drop(cell);
        assert_eq!(drops.get(), 1);
        drop(taken);
        assert_eq!(drops.get(), 2);

        // Created empty: there is nothing to drop.
        drop(TakeCell::<DropCounter<'_>>::new_taken());
        assert_eq!(drops.get(), 2);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_stressed_conditions() {
        use alloc::vec::Vec;
        use std::{
            sync::{Arc, Barrier},
            thread,
        };

        const THREADS: usize = 10;

        let cell = Arc::new(TakeCell::new(42));
        let barrier = Arc::new(Barrier::new(THREADS));

        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let c = Arc::clone(&cell);
                let b = Arc::clone(&barrier);
                thread::spawn(move || {
                    // Release all threads at once to maximise contention on the flag.
                    b.wait();
                    c.try_take()
                })
            })
            .collect();

        // Exactly one thread must win the race and receive the value.
        let winners: Vec<i32> = handles
            .into_iter()
            .filter_map(|h| h.join().unwrap())
            .collect();
        assert_eq!(winners, [42]);

        assert!(cell.is_taken());
        assert_eq!(cell.try_take(), None);
    }
}