thread_checked_lock/
mutex.rs1use std::{
2 fmt::{Display, Formatter, Result as FmtResult},
3 ops::{Deref, DerefMut},
4 sync::{Mutex, MutexGuard, PoisonError, TryLockError as StdTryLockError},
5};
6
7#[cfg(feature = "serde")]
8use serde::{Deserialize, Serialize};
9
10use crate::{locked_mutexes, mutex_id};
11use crate::mutex_id::MutexID;
12use crate::error::{AccessResult, LockError, LockResult, TryLockError, TryLockResult};
13
14
15#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
23#[derive(Debug)]
24pub struct ThreadCheckedMutex<T: ?Sized> {
25 mutex_id: MutexID,
26 mutex: Mutex<T>,
27}
28
29impl<T> ThreadCheckedMutex<T> {
30 #[inline]
32 #[must_use]
33 pub fn new(t: T) -> Self {
34 Self {
35 mutex_id: mutex_id::next_id(),
36 mutex: Mutex::new(t),
37 }
38 }
39}
40
41impl<T: ?Sized> ThreadCheckedMutex<T> {
42 #[inline]
44 const fn new_guard<'a>(&self, guard: MutexGuard<'a, T>) -> ThreadCheckedMutexGuard<'a, T> {
45 ThreadCheckedMutexGuard {
46 mutex_id: self.mutex_id,
47 guard,
48 }
49 }
50
51 #[inline]
54 fn poisoned_guard<'a>(
55 &self,
56 poison: PoisonError<MutexGuard<'a, T>>,
57 ) -> PoisonError<ThreadCheckedMutexGuard<'a, T>> {
58 PoisonError::new(self.new_guard(poison.into_inner()))
59 }
60}
61
62impl<T: ?Sized> ThreadCheckedMutex<T> {
63 pub fn lock(&self) -> LockResult<ThreadCheckedMutexGuard<'_, T>> {
83 if locked_mutexes::register_locked(self.mutex_id) {
84 match self.mutex.lock() {
85 Ok(guard) => Ok(self.new_guard(guard)),
86 Err(poison) => {
87 let poison = self.poisoned_guard(poison);
88 Err(LockError::Poisoned(poison))
89 }
90 }
91 } else {
92 Err(LockError::LockedByCurrentThread)
93 }
94 }
95
96 pub fn try_lock(&self) -> TryLockResult<ThreadCheckedMutexGuard<'_, T>> {
117 if self.locked_by_current_thread() {
118 return Err(TryLockError::LockedByCurrentThread);
119 }
120
121 match self.mutex.try_lock() {
122 Ok(guard) => {
123 #[expect(
124 clippy::let_underscore_must_use,
125 clippy::redundant_type_annotations,
126 reason = "We already checked that the current thread hasn't locked the mutex, \
127 so this always returns true.",
128 )]
129 let _: bool = locked_mutexes::register_locked(self.mutex_id);
130 Ok(self.new_guard(guard))
131 }
132 Err(StdTryLockError::Poisoned(poison)) => {
133 #[expect(
134 clippy::let_underscore_must_use,
135 clippy::redundant_type_annotations,
136 reason = "We already checked that the current thread hasn't locked the mutex, \
137 so this always returns true.",
138 )]
139 let _: bool = locked_mutexes::register_locked(self.mutex_id);
140 let poison = self.poisoned_guard(poison);
141 Err(TryLockError::Poisoned(poison))
142 }
143 Err(StdTryLockError::WouldBlock) => Err(TryLockError::WouldBlock),
144 }
145 }
146
147 #[inline]
149 #[must_use]
150 pub fn locked_by_current_thread(&self) -> bool {
151 locked_mutexes::locked_by_current_thread(self.mutex_id)
152 }
153
154 #[inline]
162 #[must_use]
163 pub fn is_poisoned(&self) -> bool {
164 self.mutex.is_poisoned()
165 }
166
167 #[inline]
174 pub fn clear_poison(&self) {
175 self.mutex.clear_poison();
176 }
177
178 #[inline]
186 pub fn into_inner(self) -> AccessResult<T>
187 where
188 T: Sized,
189 {
190 self.mutex.into_inner().map_err(Into::into)
191 }
192
193 #[inline]
201 pub fn get_mut(&mut self) -> AccessResult<&mut T> {
202 self.mutex.get_mut().map_err(Into::into)
203 }
204}
205
206impl<T: Default> Default for ThreadCheckedMutex<T> {
207 #[inline]
208 fn default() -> Self {
209 Self::new(T::default())
210 }
211}
212
213#[must_use = "if unused the ThreadCheckedMutex will immediately unlock"]
224#[clippy::has_significant_drop]
225#[derive(Debug)]
226pub struct ThreadCheckedMutexGuard<'a, T: ?Sized> {
227 mutex_id: MutexID,
228 guard: MutexGuard<'a, T>,
229}
230
231impl<T: ?Sized> Drop for ThreadCheckedMutexGuard<'_, T> {
232 #[inline]
233 fn drop(&mut self) {
234 let was_locked = locked_mutexes::register_unlocked(self.mutex_id);
235
236 debug_assert!(
238 was_locked,
239 "a ThreadCheckedMutexGuard was dropped in a thread which it was not locked in",
240 );
241 }
242}
243
244impl<T: ?Sized> Deref for ThreadCheckedMutexGuard<'_, T> {
245 type Target = T;
246
247 #[inline]
248 fn deref(&self) -> &Self::Target {
249 &self.guard
250 }
251}
252
253impl<T: ?Sized> DerefMut for ThreadCheckedMutexGuard<'_, T> {
254 #[inline]
255 fn deref_mut(&mut self) -> &mut Self::Target {
256 &mut self.guard
257 }
258}
259
260impl<T: ?Sized + Display> Display for ThreadCheckedMutexGuard<'_, T> {
261 #[inline]
262 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
263 Display::fmt(&*self.guard, f)
264 }
265}
266
267
#[cfg(test)]
mod tests {
    #![expect(clippy::unwrap_used, reason = "these are tests")]

    use std::{
        sync::{mpsc, Arc},
        thread,
        time::Duration,
    };

    use super::*;
    use crate::mutex_id::run_this_before_each_test_that_creates_a_mutex_id;

    /// A successful `lock` flips `locked_by_current_thread` from `false` to `true`.
    #[test]
    fn lock_then_is_locked() {
        run_this_before_each_test_that_creates_a_mutex_id();

        let checked = ThreadCheckedMutex::new(0_u8);
        assert!(!checked.locked_by_current_thread());

        let _held = checked.lock().unwrap();
        assert!(checked.locked_by_current_thread());
    }

    /// Dropping the guard makes the mutex count as unlocked again.
    #[test]
    fn lock_unlock_isnt_locked() {
        run_this_before_each_test_that_creates_a_mutex_id();

        let checked = ThreadCheckedMutex::new(0_u8);

        {
            let _held = checked.lock().unwrap();
            assert!(checked.locked_by_current_thread());
            // `_held` is dropped at the end of this scope.
        }

        assert!(!checked.locked_by_current_thread());
    }

    /// The same thread can lock again once the previous guard is gone.
    #[test]
    fn lock_unlock_lock() {
        run_this_before_each_test_that_creates_a_mutex_id();

        let checked = ThreadCheckedMutex::new(0_u8);

        // Lock and immediately release.
        drop(checked.lock().unwrap());
        assert!(!checked.locked_by_current_thread());

        let _held = checked.lock().unwrap();
        assert!(checked.locked_by_current_thread());
    }

    /// A second `lock` while holding the guard is an error, and that failed attempt
    /// does not prevent locking again after the first guard is dropped.
    #[test]
    fn lock_lock_unlock_lock() {
        run_this_before_each_test_that_creates_a_mutex_id();

        let checked = ThreadCheckedMutex::new(0_u8);
        let first = checked.lock().unwrap();

        // Re-locking from the same thread must be reported, not deadlock.
        assert!(matches!(
            checked.lock(),
            Err(LockError::LockedByCurrentThread),
        ));

        drop(first);

        // The failed attempt above must not have left the mutex registered as locked.
        let _relocked = checked.lock().unwrap();
    }

    /// `try_lock` reports `LockedByCurrentThread` for the thread holding the guard,
    /// after another thread has already locked and released the same mutex.
    #[test]
    fn locked_by_current_thread() {
        run_this_before_each_test_that_creates_a_mutex_id();

        let shared = Arc::new(ThreadCheckedMutex::new(()));
        let (done_tx, done_rx) = mpsc::channel();

        let for_worker = Arc::clone(&shared);
        thread::spawn(move || {
            // Lock and release on the worker, then tell the main thread we are done.
            drop(for_worker.try_lock().unwrap());
            done_tx.send(()).unwrap();
        });

        // Wait until the worker has released the mutex.
        done_rx.recv().unwrap();

        // The worker's lock/unlock must not affect this thread's ability to lock.
        let _held = shared.try_lock().unwrap();

        // Now this thread holds it, so a second attempt is an error.
        assert!(matches!(
            shared.try_lock(),
            Err(TryLockError::LockedByCurrentThread),
        ));
    }

    /// `try_lock` reports `WouldBlock` while another thread holds the mutex, and
    /// `lock` succeeds once that thread releases it.
    #[test]
    fn would_block() {
        run_this_before_each_test_that_creates_a_mutex_id();

        let shared = Arc::new(ThreadCheckedMutex::new(()));
        let (locked_tx, locked_rx) = mpsc::channel();
        let (release_tx, release_rx) = mpsc::channel();

        let for_worker = Arc::clone(&shared);
        thread::spawn(move || {
            let held = for_worker.try_lock().unwrap();

            // Signal that the lock is held, then wait for permission to release it.
            locked_tx.send(()).unwrap();
            release_rx.recv().unwrap();

            // Presumably gives the main thread time to start blocking in `lock`
            // before the guard is released.
            thread::sleep(Duration::from_millis(50));

            drop(held);
        });

        // Wait until the worker holds the lock.
        locked_rx.recv().unwrap();

        // Held by another thread, not by this one: the attempt must report
        // `WouldBlock` rather than `LockedByCurrentThread`.
        assert!(matches!(
            shared.try_lock(),
            Err(TryLockError::WouldBlock),
        ));

        release_tx.send(()).unwrap();

        // Blocks until the worker drops its guard.
        let _held = shared.lock().unwrap();
    }
}