rustpython_common/lock/thread_mutex.rs

1use lock_api::{GetThreadId, GuardNoSend, RawMutex};
2use std::{
3    cell::UnsafeCell,
4    fmt,
5    marker::PhantomData,
6    ops::{Deref, DerefMut},
7    ptr::NonNull,
8    sync::atomic::{AtomicUsize, Ordering},
9};
10
// Based on `ReentrantMutex` from lock_api.
12
13/// A mutex type that knows when it would deadlock
14pub struct RawThreadMutex<R: RawMutex, G: GetThreadId> {
15    owner: AtomicUsize,
16    mutex: R,
17    get_thread_id: G,
18}
19
20impl<R: RawMutex, G: GetThreadId> RawThreadMutex<R, G> {
21    #[allow(clippy::declare_interior_mutable_const)]
22    pub const INIT: Self = RawThreadMutex {
23        owner: AtomicUsize::new(0),
24        mutex: R::INIT,
25        get_thread_id: G::INIT,
26    };
27
28    #[inline]
29    fn lock_internal<F: FnOnce() -> bool>(&self, try_lock: F) -> Option<bool> {
30        let id = self.get_thread_id.nonzero_thread_id().get();
31        if self.owner.load(Ordering::Relaxed) == id {
32            return None;
33        } else {
34            if !try_lock() {
35                return Some(false);
36            }
37            self.owner.store(id, Ordering::Relaxed);
38        }
39        Some(true)
40    }
41
42    /// Blocks for the mutex to be available, and returns true if the mutex isn't already
43    /// locked on the current thread.
44    pub fn lock(&self) -> bool {
45        self.lock_internal(|| {
46            self.mutex.lock();
47            true
48        })
49        .is_some()
50    }
51
52    /// Returns `Some(true)` if able to successfully lock without blocking, `Some(false)`
53    /// otherwise, and `None` when the mutex is already locked on the current thread.
54    pub fn try_lock(&self) -> Option<bool> {
55        self.lock_internal(|| self.mutex.try_lock())
56    }
57
58    /// Unlocks this mutex. The inner mutex may not be unlocked if
59    /// this mutex was acquired previously in the current thread.
60    ///
61    /// # Safety
62    ///
63    /// This method may only be called if the mutex is held by the current thread.
64    pub unsafe fn unlock(&self) {
65        self.owner.store(0, Ordering::Relaxed);
66        self.mutex.unlock();
67    }
68}
69
70unsafe impl<R: RawMutex + Send, G: GetThreadId + Send> Send for RawThreadMutex<R, G> {}
71unsafe impl<R: RawMutex + Sync, G: GetThreadId + Sync> Sync for RawThreadMutex<R, G> {}
72
73pub struct ThreadMutex<R: RawMutex, G: GetThreadId, T: ?Sized> {
74    raw: RawThreadMutex<R, G>,
75    data: UnsafeCell<T>,
76}
77
78impl<R: RawMutex, G: GetThreadId, T> ThreadMutex<R, G, T> {
79    pub fn new(val: T) -> Self {
80        ThreadMutex {
81            raw: RawThreadMutex::INIT,
82            data: UnsafeCell::new(val),
83        }
84    }
85
86    pub fn into_inner(self) -> T {
87        self.data.into_inner()
88    }
89}
90impl<R: RawMutex, G: GetThreadId, T: Default> Default for ThreadMutex<R, G, T> {
91    fn default() -> Self {
92        Self::new(T::default())
93    }
94}
95impl<R: RawMutex, G: GetThreadId, T: ?Sized> ThreadMutex<R, G, T> {
96    pub fn lock(&self) -> Option<ThreadMutexGuard<R, G, T>> {
97        if self.raw.lock() {
98            Some(ThreadMutexGuard {
99                mu: self,
100                marker: PhantomData,
101            })
102        } else {
103            None
104        }
105    }
106    pub fn try_lock(&self) -> Result<ThreadMutexGuard<R, G, T>, TryLockThreadError> {
107        match self.raw.try_lock() {
108            Some(true) => Ok(ThreadMutexGuard {
109                mu: self,
110                marker: PhantomData,
111            }),
112            Some(false) => Err(TryLockThreadError::Other),
113            None => Err(TryLockThreadError::Current),
114        }
115    }
116}
/// Why [`ThreadMutex::try_lock`] failed: the mutex was already locked either on another
/// thread or on the current thread.
///
/// Derives `Debug` so callers can `unwrap`/`expect` the `Result` or log the error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TryLockThreadError {
    /// The lock could not be acquired without blocking (it is held elsewhere).
    Other,
    /// The current thread already holds the mutex.
    Current,
}
123
/// A value whose `Debug` output is its message verbatim.
///
/// Used in place of the locked data: unlike `&str`'s own `Debug`, no quotes or escapes are
/// added.
struct LockedPlaceholder(&'static str);

impl fmt::Debug for LockedPlaceholder {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(message) = self;
        f.write_str(message)
    }
}
130impl<R: RawMutex, G: GetThreadId, T: ?Sized + fmt::Debug> fmt::Debug for ThreadMutex<R, G, T> {
131    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132        match self.try_lock() {
133            Ok(guard) => f
134                .debug_struct("ThreadMutex")
135                .field("data", &&*guard)
136                .finish(),
137            Err(e) => {
138                let msg = match e {
139                    TryLockThreadError::Other => "<locked on other thread>",
140                    TryLockThreadError::Current => "<locked on current thread>",
141                };
142                f.debug_struct("ThreadMutex")
143                    .field("data", &LockedPlaceholder(msg))
144                    .finish()
145            }
146        }
147    }
148}
149
150unsafe impl<R: RawMutex + Send, G: GetThreadId + Send, T: ?Sized + Send> Send
151    for ThreadMutex<R, G, T>
152{
153}
154unsafe impl<R: RawMutex + Sync, G: GetThreadId + Sync, T: ?Sized + Send> Sync
155    for ThreadMutex<R, G, T>
156{
157}
158
159pub struct ThreadMutexGuard<'a, R: RawMutex, G: GetThreadId, T: ?Sized> {
160    mu: &'a ThreadMutex<R, G, T>,
161    marker: PhantomData<(&'a mut T, GuardNoSend)>,
162}
163impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> ThreadMutexGuard<'a, R, G, T> {
164    pub fn map<U, F: FnOnce(&mut T) -> &mut U>(
165        mut s: Self,
166        f: F,
167    ) -> MappedThreadMutexGuard<'a, R, G, U> {
168        let data = f(&mut s).into();
169        let mu = &s.mu.raw;
170        std::mem::forget(s);
171        MappedThreadMutexGuard {
172            mu,
173            data,
174            marker: PhantomData,
175        }
176    }
177    pub fn try_map<U, F: FnOnce(&mut T) -> Option<&mut U>>(
178        mut s: Self,
179        f: F,
180    ) -> Result<MappedThreadMutexGuard<'a, R, G, U>, Self> {
181        if let Some(data) = f(&mut s) {
182            let data = data.into();
183            let mu = &s.mu.raw;
184            std::mem::forget(s);
185            Ok(MappedThreadMutexGuard {
186                mu,
187                data,
188                marker: PhantomData,
189            })
190        } else {
191            Err(s)
192        }
193    }
194}
195impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> Deref for ThreadMutexGuard<'a, R, G, T> {
196    type Target = T;
197    fn deref(&self) -> &T {
198        unsafe { &*self.mu.data.get() }
199    }
200}
201impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> DerefMut for ThreadMutexGuard<'a, R, G, T> {
202    fn deref_mut(&mut self) -> &mut T {
203        unsafe { &mut *self.mu.data.get() }
204    }
205}
206impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> Drop for ThreadMutexGuard<'a, R, G, T> {
207    fn drop(&mut self) {
208        unsafe { self.mu.raw.unlock() }
209    }
210}
211impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized + fmt::Display> fmt::Display
212    for ThreadMutexGuard<'a, R, G, T>
213{
214    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
215        fmt::Display::fmt(&**self, f)
216    }
217}
218impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized + fmt::Debug> fmt::Debug
219    for ThreadMutexGuard<'a, R, G, T>
220{
221    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
222        fmt::Debug::fmt(&**self, f)
223    }
224}
225pub struct MappedThreadMutexGuard<'a, R: RawMutex, G: GetThreadId, T: ?Sized> {
226    mu: &'a RawThreadMutex<R, G>,
227    data: NonNull<T>,
228    marker: PhantomData<(&'a mut T, GuardNoSend)>,
229}
230impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> MappedThreadMutexGuard<'a, R, G, T> {
231    pub fn map<U, F: FnOnce(&mut T) -> &mut U>(
232        mut s: Self,
233        f: F,
234    ) -> MappedThreadMutexGuard<'a, R, G, U> {
235        let data = f(&mut s).into();
236        let mu = s.mu;
237        std::mem::forget(s);
238        MappedThreadMutexGuard {
239            mu,
240            data,
241            marker: PhantomData,
242        }
243    }
244    pub fn try_map<U, F: FnOnce(&mut T) -> Option<&mut U>>(
245        mut s: Self,
246        f: F,
247    ) -> Result<MappedThreadMutexGuard<'a, R, G, U>, Self> {
248        if let Some(data) = f(&mut s) {
249            let data = data.into();
250            let mu = s.mu;
251            std::mem::forget(s);
252            Ok(MappedThreadMutexGuard {
253                mu,
254                data,
255                marker: PhantomData,
256            })
257        } else {
258            Err(s)
259        }
260    }
261}
262impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> Deref for MappedThreadMutexGuard<'a, R, G, T> {
263    type Target = T;
264    fn deref(&self) -> &T {
265        unsafe { self.data.as_ref() }
266    }
267}
268impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> DerefMut for MappedThreadMutexGuard<'a, R, G, T> {
269    fn deref_mut(&mut self) -> &mut T {
270        unsafe { self.data.as_mut() }
271    }
272}
273impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> Drop for MappedThreadMutexGuard<'a, R, G, T> {
274    fn drop(&mut self) {
275        unsafe { self.mu.unlock() }
276    }
277}
278impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized + fmt::Display> fmt::Display
279    for MappedThreadMutexGuard<'a, R, G, T>
280{
281    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
282        fmt::Display::fmt(&**self, f)
283    }
284}
285impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized + fmt::Debug> fmt::Debug
286    for MappedThreadMutexGuard<'a, R, G, T>
287{
288    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
289        fmt::Debug::fmt(&**self, f)
290    }
291}