1use lock_api::{GetThreadId, GuardNoSend, RawMutex};
2use std::{
3 cell::UnsafeCell,
4 fmt,
5 marker::PhantomData,
6 ops::{Deref, DerefMut},
7 ptr::NonNull,
8 sync::atomic::{AtomicUsize, Ordering},
9};
10
11pub struct RawThreadMutex<R: RawMutex, G: GetThreadId> {
15 owner: AtomicUsize,
16 mutex: R,
17 get_thread_id: G,
18}
19
20impl<R: RawMutex, G: GetThreadId> RawThreadMutex<R, G> {
21 #[allow(clippy::declare_interior_mutable_const)]
22 pub const INIT: Self = RawThreadMutex {
23 owner: AtomicUsize::new(0),
24 mutex: R::INIT,
25 get_thread_id: G::INIT,
26 };
27
28 #[inline]
29 fn lock_internal<F: FnOnce() -> bool>(&self, try_lock: F) -> Option<bool> {
30 let id = self.get_thread_id.nonzero_thread_id().get();
31 if self.owner.load(Ordering::Relaxed) == id {
32 return None;
33 } else {
34 if !try_lock() {
35 return Some(false);
36 }
37 self.owner.store(id, Ordering::Relaxed);
38 }
39 Some(true)
40 }
41
42 pub fn lock(&self) -> bool {
45 self.lock_internal(|| {
46 self.mutex.lock();
47 true
48 })
49 .is_some()
50 }
51
52 pub fn try_lock(&self) -> Option<bool> {
55 self.lock_internal(|| self.mutex.try_lock())
56 }
57
58 pub unsafe fn unlock(&self) {
65 self.owner.store(0, Ordering::Relaxed);
66 self.mutex.unlock();
67 }
68}
69
70unsafe impl<R: RawMutex + Send, G: GetThreadId + Send> Send for RawThreadMutex<R, G> {}
71unsafe impl<R: RawMutex + Sync, G: GetThreadId + Sync> Sync for RawThreadMutex<R, G> {}
72
73pub struct ThreadMutex<R: RawMutex, G: GetThreadId, T: ?Sized> {
74 raw: RawThreadMutex<R, G>,
75 data: UnsafeCell<T>,
76}
77
78impl<R: RawMutex, G: GetThreadId, T> ThreadMutex<R, G, T> {
79 pub fn new(val: T) -> Self {
80 ThreadMutex {
81 raw: RawThreadMutex::INIT,
82 data: UnsafeCell::new(val),
83 }
84 }
85
86 pub fn into_inner(self) -> T {
87 self.data.into_inner()
88 }
89}
90impl<R: RawMutex, G: GetThreadId, T: Default> Default for ThreadMutex<R, G, T> {
91 fn default() -> Self {
92 Self::new(T::default())
93 }
94}
95impl<R: RawMutex, G: GetThreadId, T: ?Sized> ThreadMutex<R, G, T> {
96 pub fn lock(&self) -> Option<ThreadMutexGuard<R, G, T>> {
97 if self.raw.lock() {
98 Some(ThreadMutexGuard {
99 mu: self,
100 marker: PhantomData,
101 })
102 } else {
103 None
104 }
105 }
106 pub fn try_lock(&self) -> Result<ThreadMutexGuard<R, G, T>, TryLockThreadError> {
107 match self.raw.try_lock() {
108 Some(true) => Ok(ThreadMutexGuard {
109 mu: self,
110 marker: PhantomData,
111 }),
112 Some(false) => Err(TryLockThreadError::Other),
113 None => Err(TryLockThreadError::Current),
114 }
115 }
116}
/// Reason why [`ThreadMutex::try_lock`] did not return a guard.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TryLockThreadError {
    /// The underlying raw lock could not be acquired (held on another thread).
    Other,
    /// The calling thread already holds the mutex.
    Current,
}

impl fmt::Display for TryLockThreadError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            TryLockThreadError::Other => "mutex is locked on another thread",
            TryLockThreadError::Current => "mutex is already locked on the current thread",
        })
    }
}

impl std::error::Error for TryLockThreadError {}
123
/// `Debug` stand-in that writes its text verbatim (no quotes, no padding);
/// used in place of data that could not be locked for formatting.
struct LockedPlaceholder(&'static str);

impl fmt::Debug for LockedPlaceholder {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let LockedPlaceholder(text) = *self;
        f.write_str(text)
    }
}
130impl<R: RawMutex, G: GetThreadId, T: ?Sized + fmt::Debug> fmt::Debug for ThreadMutex<R, G, T> {
131 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132 match self.try_lock() {
133 Ok(guard) => f
134 .debug_struct("ThreadMutex")
135 .field("data", &&*guard)
136 .finish(),
137 Err(e) => {
138 let msg = match e {
139 TryLockThreadError::Other => "<locked on other thread>",
140 TryLockThreadError::Current => "<locked on current thread>",
141 };
142 f.debug_struct("ThreadMutex")
143 .field("data", &LockedPlaceholder(msg))
144 .finish()
145 }
146 }
147 }
148}
149
150unsafe impl<R: RawMutex + Send, G: GetThreadId + Send, T: ?Sized + Send> Send
151 for ThreadMutex<R, G, T>
152{
153}
154unsafe impl<R: RawMutex + Sync, G: GetThreadId + Sync, T: ?Sized + Send> Sync
155 for ThreadMutex<R, G, T>
156{
157}
158
159pub struct ThreadMutexGuard<'a, R: RawMutex, G: GetThreadId, T: ?Sized> {
160 mu: &'a ThreadMutex<R, G, T>,
161 marker: PhantomData<(&'a mut T, GuardNoSend)>,
162}
163impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> ThreadMutexGuard<'a, R, G, T> {
164 pub fn map<U, F: FnOnce(&mut T) -> &mut U>(
165 mut s: Self,
166 f: F,
167 ) -> MappedThreadMutexGuard<'a, R, G, U> {
168 let data = f(&mut s).into();
169 let mu = &s.mu.raw;
170 std::mem::forget(s);
171 MappedThreadMutexGuard {
172 mu,
173 data,
174 marker: PhantomData,
175 }
176 }
177 pub fn try_map<U, F: FnOnce(&mut T) -> Option<&mut U>>(
178 mut s: Self,
179 f: F,
180 ) -> Result<MappedThreadMutexGuard<'a, R, G, U>, Self> {
181 if let Some(data) = f(&mut s) {
182 let data = data.into();
183 let mu = &s.mu.raw;
184 std::mem::forget(s);
185 Ok(MappedThreadMutexGuard {
186 mu,
187 data,
188 marker: PhantomData,
189 })
190 } else {
191 Err(s)
192 }
193 }
194}
195impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> Deref for ThreadMutexGuard<'a, R, G, T> {
196 type Target = T;
197 fn deref(&self) -> &T {
198 unsafe { &*self.mu.data.get() }
199 }
200}
201impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> DerefMut for ThreadMutexGuard<'a, R, G, T> {
202 fn deref_mut(&mut self) -> &mut T {
203 unsafe { &mut *self.mu.data.get() }
204 }
205}
206impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> Drop for ThreadMutexGuard<'a, R, G, T> {
207 fn drop(&mut self) {
208 unsafe { self.mu.raw.unlock() }
209 }
210}
211impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized + fmt::Display> fmt::Display
212 for ThreadMutexGuard<'a, R, G, T>
213{
214 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
215 fmt::Display::fmt(&**self, f)
216 }
217}
218impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized + fmt::Debug> fmt::Debug
219 for ThreadMutexGuard<'a, R, G, T>
220{
221 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
222 fmt::Debug::fmt(&**self, f)
223 }
224}
225pub struct MappedThreadMutexGuard<'a, R: RawMutex, G: GetThreadId, T: ?Sized> {
226 mu: &'a RawThreadMutex<R, G>,
227 data: NonNull<T>,
228 marker: PhantomData<(&'a mut T, GuardNoSend)>,
229}
230impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> MappedThreadMutexGuard<'a, R, G, T> {
231 pub fn map<U, F: FnOnce(&mut T) -> &mut U>(
232 mut s: Self,
233 f: F,
234 ) -> MappedThreadMutexGuard<'a, R, G, U> {
235 let data = f(&mut s).into();
236 let mu = s.mu;
237 std::mem::forget(s);
238 MappedThreadMutexGuard {
239 mu,
240 data,
241 marker: PhantomData,
242 }
243 }
244 pub fn try_map<U, F: FnOnce(&mut T) -> Option<&mut U>>(
245 mut s: Self,
246 f: F,
247 ) -> Result<MappedThreadMutexGuard<'a, R, G, U>, Self> {
248 if let Some(data) = f(&mut s) {
249 let data = data.into();
250 let mu = s.mu;
251 std::mem::forget(s);
252 Ok(MappedThreadMutexGuard {
253 mu,
254 data,
255 marker: PhantomData,
256 })
257 } else {
258 Err(s)
259 }
260 }
261}
262impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> Deref for MappedThreadMutexGuard<'a, R, G, T> {
263 type Target = T;
264 fn deref(&self) -> &T {
265 unsafe { self.data.as_ref() }
266 }
267}
268impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> DerefMut for MappedThreadMutexGuard<'a, R, G, T> {
269 fn deref_mut(&mut self) -> &mut T {
270 unsafe { self.data.as_mut() }
271 }
272}
273impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> Drop for MappedThreadMutexGuard<'a, R, G, T> {
274 fn drop(&mut self) {
275 unsafe { self.mu.unlock() }
276 }
277}
278impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized + fmt::Display> fmt::Display
279 for MappedThreadMutexGuard<'a, R, G, T>
280{
281 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
282 fmt::Display::fmt(&**self, f)
283 }
284}
285impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized + fmt::Debug> fmt::Debug
286 for MappedThreadMutexGuard<'a, R, G, T>
287{
288 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
289 fmt::Debug::fmt(&**self, f)
290 }
291}