1#![allow(clippy::needless_lifetimes)]
2
3use alloc::fmt;
4use core::{
5 cell::UnsafeCell,
6 marker::PhantomData,
7 ops::{Deref, DerefMut},
8 ptr::NonNull,
9 sync::atomic::{AtomicUsize, Ordering},
10};
11use lock_api::{GetThreadId, GuardNoSend, RawMutex};
12
13pub struct RawThreadMutex<R: RawMutex, G: GetThreadId> {
17 owner: AtomicUsize,
18 mutex: R,
19 get_thread_id: G,
20}
21
22impl<R: RawMutex, G: GetThreadId> RawThreadMutex<R, G> {
23 #[allow(
24 clippy::declare_interior_mutable_const,
25 reason = "const initializer for lock primitive contains atomics by design"
26 )]
27 pub const INIT: Self = Self {
28 owner: AtomicUsize::new(0),
29 mutex: R::INIT,
30 get_thread_id: G::INIT,
31 };
32
33 #[inline]
34 fn lock_internal<F: FnOnce() -> bool>(&self, try_lock: F) -> Option<bool> {
35 let id = self.get_thread_id.nonzero_thread_id().get();
36 if self.owner.load(Ordering::Relaxed) == id {
37 return None;
38 } else {
39 if !try_lock() {
40 return Some(false);
41 }
42 self.owner.store(id, Ordering::Relaxed);
43 }
44 Some(true)
45 }
46
47 pub fn lock(&self) -> bool {
50 self.lock_internal(|| {
51 self.mutex.lock();
52 true
53 })
54 .is_some()
55 }
56
57 pub fn lock_wrapped<F: FnOnce(&dyn Fn())>(&self, wrap_fn: F) -> bool {
60 let id = self.get_thread_id.nonzero_thread_id().get();
61 if self.owner.load(Ordering::Relaxed) == id {
62 return false;
63 }
64 wrap_fn(&|| self.mutex.lock());
65 self.owner.store(id, Ordering::Relaxed);
66 true
67 }
68
69 pub fn try_lock(&self) -> Option<bool> {
72 self.lock_internal(|| self.mutex.try_lock())
73 }
74
75 pub unsafe fn unlock(&self) {
82 self.owner.store(0, Ordering::Relaxed);
83 unsafe { self.mutex.unlock() };
84 }
85}
86
87impl<R: RawMutex, G: GetThreadId> RawThreadMutex<R, G> {
88 #[cfg(unix)]
95 pub unsafe fn reinit_after_fork(&self) {
96 self.owner.store(0, Ordering::Relaxed);
97 unsafe {
98 let mutex_ptr = &self.mutex as *const R as *mut u8;
99 core::ptr::write_bytes(mutex_ptr, 0, core::mem::size_of::<R>());
100 }
101 }
102}
103
104unsafe impl<R: RawMutex + Send, G: GetThreadId + Send> Send for RawThreadMutex<R, G> {}
105unsafe impl<R: RawMutex + Sync, G: GetThreadId + Sync> Sync for RawThreadMutex<R, G> {}
106
107pub struct ThreadMutex<R: RawMutex, G: GetThreadId, T: ?Sized> {
108 raw: RawThreadMutex<R, G>,
109 data: UnsafeCell<T>,
110}
111
112impl<R: RawMutex, G: GetThreadId, T> ThreadMutex<R, G, T> {
113 pub const fn new(val: T) -> Self {
114 Self {
115 raw: RawThreadMutex::INIT,
116 data: UnsafeCell::new(val),
117 }
118 }
119
120 pub fn into_inner(self) -> T {
121 self.data.into_inner()
122 }
123}
124impl<R: RawMutex, G: GetThreadId, T: Default> Default for ThreadMutex<R, G, T> {
125 fn default() -> Self {
126 Self::new(T::default())
127 }
128}
129impl<R: RawMutex, G: GetThreadId, T> From<T> for ThreadMutex<R, G, T> {
130 fn from(val: T) -> Self {
131 Self::new(val)
132 }
133}
134impl<R: RawMutex, G: GetThreadId, T: ?Sized> ThreadMutex<R, G, T> {
135 pub fn raw(&self) -> &RawThreadMutex<R, G> {
137 &self.raw
138 }
139
140 pub fn lock(&self) -> Option<ThreadMutexGuard<'_, R, G, T>> {
141 if self.raw.lock() {
142 Some(ThreadMutexGuard {
143 mu: self,
144 marker: PhantomData,
145 })
146 } else {
147 None
148 }
149 }
150
151 pub fn lock_wrapped<F: FnOnce(&dyn Fn())>(
154 &self,
155 wrap_fn: F,
156 ) -> Option<ThreadMutexGuard<'_, R, G, T>> {
157 if self.raw.lock_wrapped(wrap_fn) {
158 Some(ThreadMutexGuard {
159 mu: self,
160 marker: PhantomData,
161 })
162 } else {
163 None
164 }
165 }
166
167 pub fn try_lock(&self) -> Result<ThreadMutexGuard<'_, R, G, T>, TryLockThreadError> {
168 match self.raw.try_lock() {
169 Some(true) => Ok(ThreadMutexGuard {
170 mu: self,
171 marker: PhantomData,
172 }),
173 Some(false) => Err(TryLockThreadError::Other),
174 None => Err(TryLockThreadError::Current),
175 }
176 }
177}
178
/// Reason why a non-blocking lock attempt on a `ThreadMutex` failed.
///
/// Implements `Debug` so that `Result::unwrap` / `expect` can be used on the
/// result of `try_lock`, and `PartialEq` / `Eq` so that the reason can be
/// compared directly.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TryLockThreadError {
    /// The mutex could not be acquired and the calling thread is not its
    /// recorded owner.
    Other,
    /// The calling thread is already the recorded owner of the mutex.
    Current,
}
186
/// Stand-in whose `Debug` output is exactly the wrapped text, with no quotes
/// or escaping.
struct LockedPlaceholder(&'static str);

impl fmt::Debug for LockedPlaceholder {
    /// Writes the wrapped text verbatim; formatter flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let LockedPlaceholder(text) = *self;
        f.write_str(text)
    }
}
194
195impl<R: RawMutex, G: GetThreadId, T: ?Sized + fmt::Debug> fmt::Debug for ThreadMutex<R, G, T> {
196 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
197 match self.try_lock() {
198 Ok(guard) => f
199 .debug_struct("ThreadMutex")
200 .field("data", &&*guard)
201 .finish(),
202 Err(e) => {
203 let msg = match e {
204 TryLockThreadError::Other => "<locked on other thread>",
205 TryLockThreadError::Current => "<locked on current thread>",
206 };
207 f.debug_struct("ThreadMutex")
208 .field("data", &LockedPlaceholder(msg))
209 .finish()
210 }
211 }
212 }
213}
214
215unsafe impl<R: RawMutex + Send, G: GetThreadId + Send, T: ?Sized + Send> Send
216 for ThreadMutex<R, G, T>
217{
218}
219unsafe impl<R: RawMutex + Sync, G: GetThreadId + Sync, T: ?Sized + Send> Sync
220 for ThreadMutex<R, G, T>
221{
222}
223
224pub struct ThreadMutexGuard<'a, R: RawMutex, G: GetThreadId, T: ?Sized> {
225 mu: &'a ThreadMutex<R, G, T>,
226 marker: PhantomData<(&'a mut T, GuardNoSend)>,
227}
228impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> ThreadMutexGuard<'a, R, G, T> {
229 pub fn map<U, F: FnOnce(&mut T) -> &mut U>(
230 mut s: Self,
231 f: F,
232 ) -> MappedThreadMutexGuard<'a, R, G, U> {
233 let data = f(&mut s).into();
234 let mu = &s.mu.raw;
235 core::mem::forget(s);
236 MappedThreadMutexGuard {
237 mu,
238 data,
239 marker: PhantomData,
240 }
241 }
242 pub fn try_map<U, F: FnOnce(&mut T) -> Option<&mut U>>(
243 mut s: Self,
244 f: F,
245 ) -> Result<MappedThreadMutexGuard<'a, R, G, U>, Self> {
246 if let Some(data) = f(&mut s) {
247 let data = data.into();
248 let mu = &s.mu.raw;
249 core::mem::forget(s);
250 Ok(MappedThreadMutexGuard {
251 mu,
252 data,
253 marker: PhantomData,
254 })
255 } else {
256 Err(s)
257 }
258 }
259}
260impl<R: RawMutex, G: GetThreadId, T: ?Sized> Deref for ThreadMutexGuard<'_, R, G, T> {
261 type Target = T;
262 fn deref(&self) -> &T {
263 unsafe { &*self.mu.data.get() }
264 }
265}
266impl<R: RawMutex, G: GetThreadId, T: ?Sized> DerefMut for ThreadMutexGuard<'_, R, G, T> {
267 fn deref_mut(&mut self) -> &mut T {
268 unsafe { &mut *self.mu.data.get() }
269 }
270}
271impl<R: RawMutex, G: GetThreadId, T: ?Sized> Drop for ThreadMutexGuard<'_, R, G, T> {
272 fn drop(&mut self) {
273 unsafe { self.mu.raw.unlock() }
274 }
275}
276impl<R: RawMutex, G: GetThreadId, T: ?Sized + fmt::Display> fmt::Display
277 for ThreadMutexGuard<'_, R, G, T>
278{
279 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
280 fmt::Display::fmt(&**self, f)
281 }
282}
283impl<R: RawMutex, G: GetThreadId, T: ?Sized + fmt::Debug> fmt::Debug
284 for ThreadMutexGuard<'_, R, G, T>
285{
286 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
287 fmt::Debug::fmt(&**self, f)
288 }
289}
290pub struct MappedThreadMutexGuard<'a, R: RawMutex, G: GetThreadId, T: ?Sized> {
291 mu: &'a RawThreadMutex<R, G>,
292 data: NonNull<T>,
293 marker: PhantomData<(&'a mut T, GuardNoSend)>,
294}
295impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> MappedThreadMutexGuard<'a, R, G, T> {
296 pub fn map<U, F: FnOnce(&mut T) -> &mut U>(
297 mut s: Self,
298 f: F,
299 ) -> MappedThreadMutexGuard<'a, R, G, U> {
300 let data = f(&mut s).into();
301 let mu = s.mu;
302 core::mem::forget(s);
303 MappedThreadMutexGuard {
304 mu,
305 data,
306 marker: PhantomData,
307 }
308 }
309 pub fn try_map<U, F: FnOnce(&mut T) -> Option<&mut U>>(
310 mut s: Self,
311 f: F,
312 ) -> Result<MappedThreadMutexGuard<'a, R, G, U>, Self> {
313 if let Some(data) = f(&mut s) {
314 let data = data.into();
315 let mu = s.mu;
316 core::mem::forget(s);
317 Ok(MappedThreadMutexGuard {
318 mu,
319 data,
320 marker: PhantomData,
321 })
322 } else {
323 Err(s)
324 }
325 }
326}
327impl<R: RawMutex, G: GetThreadId, T: ?Sized> Deref for MappedThreadMutexGuard<'_, R, G, T> {
328 type Target = T;
329 fn deref(&self) -> &T {
330 unsafe { self.data.as_ref() }
331 }
332}
333impl<R: RawMutex, G: GetThreadId, T: ?Sized> DerefMut for MappedThreadMutexGuard<'_, R, G, T> {
334 fn deref_mut(&mut self) -> &mut T {
335 unsafe { self.data.as_mut() }
336 }
337}
338impl<R: RawMutex, G: GetThreadId, T: ?Sized> Drop for MappedThreadMutexGuard<'_, R, G, T> {
339 fn drop(&mut self) {
340 unsafe { self.mu.unlock() }
341 }
342}
343impl<R: RawMutex, G: GetThreadId, T: ?Sized + fmt::Display> fmt::Display
344 for MappedThreadMutexGuard<'_, R, G, T>
345{
346 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
347 fmt::Display::fmt(&**self, f)
348 }
349}
350impl<R: RawMutex, G: GetThreadId, T: ?Sized + fmt::Debug> fmt::Debug
351 for MappedThreadMutexGuard<'_, R, G, T>
352{
353 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
354 fmt::Debug::fmt(&**self, f)
355 }
356}