1#[cfg(feature = "no-std")]
62#[no_std]
63#[cfg(not(feature = "no-std"))]
64use std::thread::ThreadId;
65
66pub struct ThreadOwnedLock<T: ?Sized, P: ThreadIdProvider> {
78 thread_id: P::Id,
79 guard: DoubleLockGuard,
80 data: core::cell::UnsafeCell<T>,
81}
82
83unsafe impl<T: ?Sized + Send, P: ThreadIdProvider> Send for ThreadOwnedLock<T, P> {}
84unsafe impl<T: ?Sized + Send, P: ThreadIdProvider> Sync for ThreadOwnedLock<T, P> {}
85
86#[must_use = "if unused the ThreadOwnedLock will immediately unlock"]
98pub struct ThreadOwnedLockGuard<'l, T: ?Sized + 'l, P: ThreadIdProvider> {
99 lock: &'l ThreadOwnedLock<T, P>,
100 p: core::marker::PhantomData<*mut ()>, }
102
/// Reasons a `ThreadOwnedLock` could not be acquired.
///
/// The type is a fieldless enum, so it is cheap to copy and can be compared directly
/// (e.g. `assert_eq!(err, ThreadOwnedMutexError::AlreadyLocked)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ThreadOwnedMutexError {
    /// The calling thread is not the thread the lock is currently bound to.
    InvalidThread,
    /// The lock is already held by a guard that has not been dropped yet.
    AlreadyLocked,
}
110
111impl<T, P: ThreadIdProvider> ThreadOwnedLock<T, P> {
112 #[inline]
114 pub fn new(value: T) -> Self {
115 Self {
116 data: core::cell::UnsafeCell::new(value),
117 thread_id: P::current_thread_id(),
118 guard: DoubleLockGuard::new(),
119 }
120 }
121
122 pub fn rebind(mut self) -> Self {
135 self.thread_id = P::current_thread_id();
136 self
137 }
138}
139
140impl<T: ?Sized, P: ThreadIdProvider> ThreadOwnedLock<T, P> {
141 #[inline]
147 pub fn lock(&self) -> ThreadOwnedLockGuard<'_, T, P> {
148 match self.try_lock() {
149 Ok(v) => v,
150 Err(e) => panic!("{}", e),
151 }
152 }
153
154 #[inline]
161 pub fn try_lock(&self) -> Result<ThreadOwnedLockGuard<'_, T, P>, ThreadOwnedMutexError> {
162 let current_thread_id = P::current_thread_id();
163 if current_thread_id != self.thread_id {
164 return Err(ThreadOwnedMutexError::InvalidThread);
165 }
166 if self.guard.try_enter() {
167 return Err(ThreadOwnedMutexError::AlreadyLocked);
168 }
169 Ok(ThreadOwnedLockGuard {
170 lock: self,
171 p: core::marker::PhantomData,
172 })
173 }
174}
175
/// Identifies the currently running thread for `ThreadOwnedLock`.
///
/// A `ThreadOwnedLock` records `current_thread_id()` when it is created or rebound. On every
/// `try_lock`, it compares that recorded id against the caller's id.
///
/// # Soundness requirement
///
/// Threads that are alive at the same time must never receive equal ids. Once the id check
/// passes, the lock's re-entrancy flag and its data are accessed without atomics. A provider
/// that hands the same id to concurrently running threads (for example, a constant) therefore
/// lets safe code create data races.
///
/// NOTE(review): because a safe impl can violate this, consider making this an `unsafe trait`
/// in the next breaking release.
pub trait ThreadIdProvider {
    /// Opaque thread identifier. It is copied into the lock and compared with `==` / `!=`
    /// (`Eq` implies `PartialEq`).
    type Id: Eq + Copy;

    /// Returns the id of the calling thread.
    fn current_thread_id() -> Self::Id;
}
183
184#[cfg(not(feature = "no-std"))]
186pub struct StdThreadIdProvider {}
187
188#[cfg(not(feature = "no-std"))]
189impl ThreadIdProvider for StdThreadIdProvider {
190 type Id = std::thread::ThreadId;
191
192 fn current_thread_id() -> Self::Id {
193 std::thread::current().id()
194 }
195}
196#[cfg(not(feature = "no-std"))]
197pub type StdThreadOwnedLock<T> = ThreadOwnedLock<T, StdThreadIdProvider>;
198
199impl<T, P: ThreadIdProvider> From<T> for ThreadOwnedLock<T, P> {
200 fn from(value: T) -> Self {
201 Self::new(value)
202 }
203}
204impl<T: Default, P: ThreadIdProvider> Default for ThreadOwnedLock<T, P> {
205 fn default() -> Self {
206 Self::new(T::default())
207 }
208}
209
210impl<T: ?Sized, P: ThreadIdProvider> Drop for ThreadOwnedLockGuard<'_, T, P> {
211 fn drop(&mut self) {
212 self.lock.guard.exit();
213 }
214}
215
216impl<T: ?Sized, P: ThreadIdProvider> core::ops::Deref for ThreadOwnedLockGuard<'_, T, P> {
217 type Target = T;
218 fn deref(&self) -> &Self::Target {
219 unsafe { &*self.lock.data.get() }
220 }
221}
222
223impl<T: ?Sized, P: ThreadIdProvider> core::ops::DerefMut for ThreadOwnedLockGuard<'_, T, P> {
224 fn deref_mut(&mut self) -> &mut Self::Target {
225 unsafe { &mut *self.lock.data.get() }
226 }
227}
228
229impl<T: ?Sized + core::fmt::Debug, P: ThreadIdProvider> core::fmt::Debug
230 for ThreadOwnedLockGuard<'_, T, P>
231{
232 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
233 core::fmt::Debug::fmt(&**self, f)
234 }
235}
236
237impl<T: ?Sized + core::fmt::Display, P: ThreadIdProvider> core::fmt::Display
238 for ThreadOwnedLockGuard<'_, T, P>
239{
240 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
241 (**self).fmt(f)
242 }
243}
244
245impl core::fmt::Display for ThreadOwnedMutexError {
246 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
247 match self {
248 ThreadOwnedMutexError::InvalidThread => {
249 f.write_str("Current thread does not own this lock")
250 }
251 ThreadOwnedMutexError::AlreadyLocked => f.write_str("Already locked"),
252 }
253 }
254}
255
256impl std::error::Error for ThreadOwnedMutexError {}
257
/// Re-entrancy flag recording whether a guard from the owning lock is currently alive.
///
/// The flag is a plain, non-atomic `Cell<bool>`. That is only sound because
/// `ThreadOwnedLock::try_lock` checks thread ownership before touching it, so only the owning
/// thread ever reads or writes the flag.
#[doc(hidden)]
struct DoubleLockGuard(core::cell::Cell<bool>);

impl DoubleLockGuard {
    /// Creates the flag in the "not entered" state.
    fn new() -> Self {
        Self(core::cell::Cell::new(false))
    }

    /// Marks the flag as entered and returns its *previous* state.
    ///
    /// `false` means the caller has just entered; `true` means it was already entered.
    #[inline]
    fn try_enter(&self) -> bool {
        self.0.replace(true)
    }

    /// Marks the flag as no longer entered.
    #[inline]
    fn exit(&self) {
        self.0.set(false);
    }
}
286
#[cfg(all(test, not(feature = "no-std")))]
mod test {
    use super::*;

    #[test]
    fn test_lock() {
        let lock = StdThreadOwnedLock::new(20);
        {
            let guard = lock.try_lock().expect("failed to acquire lock");
            assert_eq!(*guard, 20);
        }
        let h = std::thread::spawn(move || {
            let err = lock.try_lock().expect_err("Should fail");
            assert!(matches!(err, ThreadOwnedMutexError::InvalidThread));
        });
        h.join().unwrap();
    }

    #[test]
    fn test_double_lock_fails() {
        let lock = StdThreadOwnedLock::new(20);
        {
            let _guard = lock.try_lock().expect("failed to acquire lock");
            let err = lock.try_lock().expect_err("Should fail");
            assert!(matches!(err, ThreadOwnedMutexError::AlreadyLocked));
        }
        let _guard = lock.try_lock().expect("failed to acquire lock");
    }

    #[test]
    fn test_lock_rebind() {
        let lock = StdThreadOwnedLock::new(20);
        assert_eq!(lock.thread_id, std::thread::current().id());
        let h = std::thread::spawn(move || {
            let lock = lock.rebind();
            assert_eq!(lock.thread_id, std::thread::current().id());
        });
        h.join().unwrap();
    }

    /// After `rebind`, the new owner thread can actually lock and mutate the value.
    #[test]
    fn test_rebound_lock_is_usable_on_new_thread() {
        let lock = StdThreadOwnedLock::new(20);
        let h = std::thread::spawn(move || {
            let lock = lock.rebind();
            let mut guard = lock.try_lock().expect("rebound lock should be lockable");
            *guard += 1;
            assert_eq!(*guard, 21);
        });
        h.join().unwrap();
    }

    /// Writes made through `DerefMut` are visible to later guards.
    #[test]
    fn test_mutation_through_guard_persists() {
        let lock = StdThreadOwnedLock::new(String::from("a"));
        lock.lock().push('b');
        assert_eq!(*lock.lock(), "ab");
    }

    /// `lock` reports `AlreadyLocked` by panicking with its `Display` message.
    #[test]
    #[should_panic(expected = "Already locked")]
    fn test_lock_panics_when_already_locked() {
        let lock = StdThreadOwnedLock::new(0);
        let _first = lock.lock();
        let _second = lock.lock();
    }

    /// `lock` panics when called from a thread that does not own the lock.
    #[test]
    fn test_lock_panics_on_foreign_thread() {
        let lock = StdThreadOwnedLock::new(0);
        let result = std::thread::spawn(move || {
            let _guard = lock.lock();
        })
        .join();
        assert!(result.is_err());
    }

    #[test]
    fn test_from_and_default() {
        let lock = StdThreadOwnedLock::from(5u8);
        assert_eq!(*lock.lock(), 5);
        let lock: StdThreadOwnedLock<Vec<u8>> = Default::default();
        assert!(lock.lock().is_empty());
    }
}