Skip to main content

rustpython_common/lock/
thread_mutex.rs

1#![allow(clippy::needless_lifetimes)]
2
3use alloc::fmt;
4use core::{
5    cell::UnsafeCell,
6    marker::PhantomData,
7    ops::{Deref, DerefMut},
8    ptr::NonNull,
9    sync::atomic::{AtomicUsize, Ordering},
10};
11use lock_api::{GetThreadId, GuardNoSend, RawMutex};
12
13// based off ReentrantMutex from lock_api
14
15/// A mutex type that knows when it would deadlock
16pub struct RawThreadMutex<R: RawMutex, G: GetThreadId> {
17    owner: AtomicUsize,
18    mutex: R,
19    get_thread_id: G,
20}
21
22impl<R: RawMutex, G: GetThreadId> RawThreadMutex<R, G> {
23    #[allow(
24        clippy::declare_interior_mutable_const,
25        reason = "const initializer for lock primitive contains atomics by design"
26    )]
27    pub const INIT: Self = Self {
28        owner: AtomicUsize::new(0),
29        mutex: R::INIT,
30        get_thread_id: G::INIT,
31    };
32
33    #[inline]
34    fn lock_internal<F: FnOnce() -> bool>(&self, try_lock: F) -> Option<bool> {
35        let id = self.get_thread_id.nonzero_thread_id().get();
36        if self.owner.load(Ordering::Relaxed) == id {
37            return None;
38        } else {
39            if !try_lock() {
40                return Some(false);
41            }
42            self.owner.store(id, Ordering::Relaxed);
43        }
44        Some(true)
45    }
46
47    /// Blocks for the mutex to be available, and returns true if the mutex isn't already
48    /// locked on the current thread.
49    pub fn lock(&self) -> bool {
50        self.lock_internal(|| {
51            self.mutex.lock();
52            true
53        })
54        .is_some()
55    }
56
57    /// Like `lock()` but wraps the blocking wait in `wrap_fn`.
58    /// The caller can use this to detach thread state while waiting.
59    pub fn lock_wrapped<F: FnOnce(&dyn Fn())>(&self, wrap_fn: F) -> bool {
60        let id = self.get_thread_id.nonzero_thread_id().get();
61        if self.owner.load(Ordering::Relaxed) == id {
62            return false;
63        }
64        wrap_fn(&|| self.mutex.lock());
65        self.owner.store(id, Ordering::Relaxed);
66        true
67    }
68
69    /// Returns `Some(true)` if able to successfully lock without blocking, `Some(false)`
70    /// otherwise, and `None` when the mutex is already locked on the current thread.
71    pub fn try_lock(&self) -> Option<bool> {
72        self.lock_internal(|| self.mutex.try_lock())
73    }
74
75    /// Unlocks this mutex. The inner mutex may not be unlocked if
76    /// this mutex was acquired previously in the current thread.
77    ///
78    /// # Safety
79    ///
80    /// This method may only be called if the mutex is held by the current thread.
81    pub unsafe fn unlock(&self) {
82        self.owner.store(0, Ordering::Relaxed);
83        unsafe { self.mutex.unlock() };
84    }
85}
86
87impl<R: RawMutex, G: GetThreadId> RawThreadMutex<R, G> {
88    /// Reset this mutex to its initial (unlocked, unowned) state after `fork()`.
89    ///
90    /// # Safety
91    ///
92    /// Must only be called from the single-threaded child process immediately
93    /// after `fork()`, before any other thread is created.
94    #[cfg(unix)]
95    pub unsafe fn reinit_after_fork(&self) {
96        self.owner.store(0, Ordering::Relaxed);
97        unsafe {
98            let mutex_ptr = &self.mutex as *const R as *mut u8;
99            core::ptr::write_bytes(mutex_ptr, 0, core::mem::size_of::<R>());
100        }
101    }
102}
103
104unsafe impl<R: RawMutex + Send, G: GetThreadId + Send> Send for RawThreadMutex<R, G> {}
105unsafe impl<R: RawMutex + Sync, G: GetThreadId + Sync> Sync for RawThreadMutex<R, G> {}
106
107pub struct ThreadMutex<R: RawMutex, G: GetThreadId, T: ?Sized> {
108    raw: RawThreadMutex<R, G>,
109    data: UnsafeCell<T>,
110}
111
112impl<R: RawMutex, G: GetThreadId, T> ThreadMutex<R, G, T> {
113    pub const fn new(val: T) -> Self {
114        Self {
115            raw: RawThreadMutex::INIT,
116            data: UnsafeCell::new(val),
117        }
118    }
119
120    pub fn into_inner(self) -> T {
121        self.data.into_inner()
122    }
123}
124impl<R: RawMutex, G: GetThreadId, T: Default> Default for ThreadMutex<R, G, T> {
125    fn default() -> Self {
126        Self::new(T::default())
127    }
128}
129impl<R: RawMutex, G: GetThreadId, T> From<T> for ThreadMutex<R, G, T> {
130    fn from(val: T) -> Self {
131        Self::new(val)
132    }
133}
134impl<R: RawMutex, G: GetThreadId, T: ?Sized> ThreadMutex<R, G, T> {
135    /// Access the underlying raw thread mutex.
136    pub fn raw(&self) -> &RawThreadMutex<R, G> {
137        &self.raw
138    }
139
140    pub fn lock(&self) -> Option<ThreadMutexGuard<'_, R, G, T>> {
141        if self.raw.lock() {
142            Some(ThreadMutexGuard {
143                mu: self,
144                marker: PhantomData,
145            })
146        } else {
147            None
148        }
149    }
150
151    /// Like `lock()` but wraps the blocking wait in `wrap_fn`.
152    /// The caller can use this to detach thread state while waiting.
153    pub fn lock_wrapped<F: FnOnce(&dyn Fn())>(
154        &self,
155        wrap_fn: F,
156    ) -> Option<ThreadMutexGuard<'_, R, G, T>> {
157        if self.raw.lock_wrapped(wrap_fn) {
158            Some(ThreadMutexGuard {
159                mu: self,
160                marker: PhantomData,
161            })
162        } else {
163            None
164        }
165    }
166
167    pub fn try_lock(&self) -> Result<ThreadMutexGuard<'_, R, G, T>, TryLockThreadError> {
168        match self.raw.try_lock() {
169            Some(true) => Ok(ThreadMutexGuard {
170                mu: self,
171                marker: PhantomData,
172            }),
173            Some(false) => Err(TryLockThreadError::Other),
174            None => Err(TryLockThreadError::Current),
175        }
176    }
177}
178
/// Reason a non-blocking lock attempt on a `ThreadMutex` failed.
///
/// Derives `Debug`/`PartialEq`/`Eq` so callers can print and compare it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TryLockThreadError {
    /// Failed to lock because mutex was already locked on another thread.
    Other,
    /// Failed to lock because mutex was already locked on current thread.
    Current,
}
186
/// Stand-in shown in `Debug` output when the protected value cannot be borrowed;
/// it formats as its message verbatim (no quotes).
struct LockedPlaceholder(&'static str);

impl fmt::Debug for LockedPlaceholder {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(text) = self;
        f.write_str(text)
    }
}
194
195impl<R: RawMutex, G: GetThreadId, T: ?Sized + fmt::Debug> fmt::Debug for ThreadMutex<R, G, T> {
196    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
197        match self.try_lock() {
198            Ok(guard) => f
199                .debug_struct("ThreadMutex")
200                .field("data", &&*guard)
201                .finish(),
202            Err(e) => {
203                let msg = match e {
204                    TryLockThreadError::Other => "<locked on other thread>",
205                    TryLockThreadError::Current => "<locked on current thread>",
206                };
207                f.debug_struct("ThreadMutex")
208                    .field("data", &LockedPlaceholder(msg))
209                    .finish()
210            }
211        }
212    }
213}
214
215unsafe impl<R: RawMutex + Send, G: GetThreadId + Send, T: ?Sized + Send> Send
216    for ThreadMutex<R, G, T>
217{
218}
219unsafe impl<R: RawMutex + Sync, G: GetThreadId + Sync, T: ?Sized + Send> Sync
220    for ThreadMutex<R, G, T>
221{
222}
223
224pub struct ThreadMutexGuard<'a, R: RawMutex, G: GetThreadId, T: ?Sized> {
225    mu: &'a ThreadMutex<R, G, T>,
226    marker: PhantomData<(&'a mut T, GuardNoSend)>,
227}
228impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> ThreadMutexGuard<'a, R, G, T> {
229    pub fn map<U, F: FnOnce(&mut T) -> &mut U>(
230        mut s: Self,
231        f: F,
232    ) -> MappedThreadMutexGuard<'a, R, G, U> {
233        let data = f(&mut s).into();
234        let mu = &s.mu.raw;
235        core::mem::forget(s);
236        MappedThreadMutexGuard {
237            mu,
238            data,
239            marker: PhantomData,
240        }
241    }
242    pub fn try_map<U, F: FnOnce(&mut T) -> Option<&mut U>>(
243        mut s: Self,
244        f: F,
245    ) -> Result<MappedThreadMutexGuard<'a, R, G, U>, Self> {
246        if let Some(data) = f(&mut s) {
247            let data = data.into();
248            let mu = &s.mu.raw;
249            core::mem::forget(s);
250            Ok(MappedThreadMutexGuard {
251                mu,
252                data,
253                marker: PhantomData,
254            })
255        } else {
256            Err(s)
257        }
258    }
259}
260impl<R: RawMutex, G: GetThreadId, T: ?Sized> Deref for ThreadMutexGuard<'_, R, G, T> {
261    type Target = T;
262    fn deref(&self) -> &T {
263        unsafe { &*self.mu.data.get() }
264    }
265}
266impl<R: RawMutex, G: GetThreadId, T: ?Sized> DerefMut for ThreadMutexGuard<'_, R, G, T> {
267    fn deref_mut(&mut self) -> &mut T {
268        unsafe { &mut *self.mu.data.get() }
269    }
270}
271impl<R: RawMutex, G: GetThreadId, T: ?Sized> Drop for ThreadMutexGuard<'_, R, G, T> {
272    fn drop(&mut self) {
273        unsafe { self.mu.raw.unlock() }
274    }
275}
276impl<R: RawMutex, G: GetThreadId, T: ?Sized + fmt::Display> fmt::Display
277    for ThreadMutexGuard<'_, R, G, T>
278{
279    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
280        fmt::Display::fmt(&**self, f)
281    }
282}
283impl<R: RawMutex, G: GetThreadId, T: ?Sized + fmt::Debug> fmt::Debug
284    for ThreadMutexGuard<'_, R, G, T>
285{
286    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
287        fmt::Debug::fmt(&**self, f)
288    }
289}
290pub struct MappedThreadMutexGuard<'a, R: RawMutex, G: GetThreadId, T: ?Sized> {
291    mu: &'a RawThreadMutex<R, G>,
292    data: NonNull<T>,
293    marker: PhantomData<(&'a mut T, GuardNoSend)>,
294}
295impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> MappedThreadMutexGuard<'a, R, G, T> {
296    pub fn map<U, F: FnOnce(&mut T) -> &mut U>(
297        mut s: Self,
298        f: F,
299    ) -> MappedThreadMutexGuard<'a, R, G, U> {
300        let data = f(&mut s).into();
301        let mu = s.mu;
302        core::mem::forget(s);
303        MappedThreadMutexGuard {
304            mu,
305            data,
306            marker: PhantomData,
307        }
308    }
309    pub fn try_map<U, F: FnOnce(&mut T) -> Option<&mut U>>(
310        mut s: Self,
311        f: F,
312    ) -> Result<MappedThreadMutexGuard<'a, R, G, U>, Self> {
313        if let Some(data) = f(&mut s) {
314            let data = data.into();
315            let mu = s.mu;
316            core::mem::forget(s);
317            Ok(MappedThreadMutexGuard {
318                mu,
319                data,
320                marker: PhantomData,
321            })
322        } else {
323            Err(s)
324        }
325    }
326}
327impl<R: RawMutex, G: GetThreadId, T: ?Sized> Deref for MappedThreadMutexGuard<'_, R, G, T> {
328    type Target = T;
329    fn deref(&self) -> &T {
330        unsafe { self.data.as_ref() }
331    }
332}
333impl<R: RawMutex, G: GetThreadId, T: ?Sized> DerefMut for MappedThreadMutexGuard<'_, R, G, T> {
334    fn deref_mut(&mut self) -> &mut T {
335        unsafe { self.data.as_mut() }
336    }
337}
338impl<R: RawMutex, G: GetThreadId, T: ?Sized> Drop for MappedThreadMutexGuard<'_, R, G, T> {
339    fn drop(&mut self) {
340        unsafe { self.mu.unlock() }
341    }
342}
343impl<R: RawMutex, G: GetThreadId, T: ?Sized + fmt::Display> fmt::Display
344    for MappedThreadMutexGuard<'_, R, G, T>
345{
346    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
347        fmt::Display::fmt(&**self, f)
348    }
349}
350impl<R: RawMutex, G: GetThreadId, T: ?Sized + fmt::Debug> fmt::Debug
351    for MappedThreadMutexGuard<'_, R, G, T>
352{
353    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
354        fmt::Debug::fmt(&**self, f)
355    }
356}