thread_safe/
lib.rs

1// MIT/Apache2 License
2
3//! Let's say you have some thread-unsafe data. For whatever reason, it can't be used outside of the thread it
4//! originated in. This thread-unsafe data is a component of a larger data struct that does need to be sent
5//! around between other threads.
6//!
//! A [`ThreadSafe`] contains data that can only be used in the thread it was created in. Whenever a
//! reference to the interior data is requested, it checks that the request comes from that thread.
9//!
10//! # [`ThreadKey`]
11//!
12//! The `ThreadKey` is a wrapper around `ThreadId`, but `!Send`. This allows one to certify that the current
13//! thread has the given `ThreadId`, without having to go through `thread::current().id()`.
14//!
15//! # Example
16//!
17//! ```
18//! use std::{cell::Cell, sync::{Arc, atomic}, thread};
19//! use thread_safe::ThreadSafe;
20//!
21//! #[derive(Debug)]
22//! struct InnerData {
23//!     counter: atomic::AtomicUsize,
24//!     other_counter: ThreadSafe<Cell<usize>>,
25//! }
26//!
27//! fn increment_data(data: &InnerData) {
28//!     data.counter.fetch_add(1, atomic::Ordering::SeqCst);
29//!     if let Ok(counter) = data.other_counter.try_get_ref() {
30//!         counter.set(counter.get() + 1);
31//!     }
32//! }
33//!
34//! let data = Arc::new(InnerData {
35//!     counter: atomic::AtomicUsize::new(0),
36//!     other_counter: ThreadSafe::new(Cell::new(0)),
37//! });
38//!
39//! let mut handles = vec![];
40//!
41//! for _ in 0..10 {
42//!     let data = data.clone();
43//!     handles.push(thread::spawn(move || increment_data(&data)));
44//! }
45//!
46//! increment_data(&data);
47//!
48//! for handle in handles {
49//!     handle.join().unwrap();
50//! }
51//!
52//! let data = Arc::try_unwrap(data).unwrap();
53//! assert_eq!(data.counter.load(atomic::Ordering::Relaxed), 11);
54//! assert_eq!(data.other_counter.get_ref().get(), 1);
55//! ```
56
57use std::{
58    error::Error,
59    fmt,
60    marker::PhantomData,
61    mem::{self, ManuallyDrop},
62    rc::Rc,
63    thread::{self, ThreadId},
64    thread_local,
65};
66
/// A wrapper that makes thread-unsafe data usable inside thread-safe structures.
///
/// `ThreadSafe<T>` records the thread it was created on (its *origin thread*) and only hands out
/// access to the wrapped value when asked from that same thread. Because every access is checked,
/// the wrapper itself is `Send` and `Sync` even when `T` is neither. This lets thread-unsafe data
/// be embedded in structures that are shared across threads, as long as the data itself is only
/// ever used from the origin thread.
///
/// # Panics
///
/// Dropping a `ThreadSafe` on a foreign thread must not run the inner value's destructor, since
/// that would use the value off its origin thread. If `T` needs dropping, the inner value is leaked
/// and the drop is reported with a panic. If `T` has no drop glue, nothing needs to run and the
/// wrapper may be dropped on any thread.
pub struct ThreadSafe<T: ?Sized> {
    // The thread this value was created on; the only thread allowed to touch `inner`.
    origin_thread: ThreadId,
    // `mem::needs_drop::<T>()`. When false, dropping on a foreign thread is harmless, so the
    // drop-time thread check can be skipped.
    handle_drop: bool,
    // The wrapped value. `ManuallyDrop` lets `Drop` decide whether the destructor may run.
    inner: ManuallyDrop<T>,
}

impl<T: Default> Default for ThreadSafe<T> {
    /// Wraps `T::default()`, using the current thread as the origin thread.
    #[inline]
    fn default() -> Self {
        Self {
            origin_thread: thread::current().id(),
            handle_drop: mem::needs_drop::<T>(),
            inner: ManuallyDrop::new(T::default()),
        }
    }
}

impl<T: fmt::Debug + ?Sized> fmt::Debug for ThreadSafe<T> {
    /// Formats the inner value when called on the origin thread. On any other thread the value
    /// must not be inspected, so a placeholder is written instead.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.origin_thread == thread::current().id() {
            // Deref to `T` explicitly: passing `&self.inner` would select the derived
            // `ManuallyDrop<T>` impl and print `ManuallyDrop { value: .. }`.
            fmt::Debug::fmt(&*self.inner, f)
        } else {
            f.write_str("<not in origin thread>")
        }
    }
}

// SAFETY: `inner` is only reachable through methods that first verify the caller is on the origin
// thread, or through `unsafe` methods that make the caller responsible for that check. The
// destructor of `T` only runs on the origin thread, or when `T` has no drop glue. Moving or
// sharing the wrapper itself therefore never lets `T` be used from another thread.
unsafe impl<T> Send for ThreadSafe<T> {}
unsafe impl<T> Sync for ThreadSafe<T> {}
112
113impl<T> ThreadSafe<T> {
114    /// Create a new instance of a `ThreadSafe`.
115    ///
116    /// # Example
117    ///
118    /// ```
119    /// use thread_safe::ThreadSafe;
120    /// let t = ThreadSafe::new(0i32);
121    /// ```
122    #[inline]
123    pub fn new(inner: T) -> ThreadSafe<T> {
124        ThreadSafe {
125            origin_thread: thread::current().id(),
126            handle_drop: mem::needs_drop::<T>(),
127            inner: ManuallyDrop::new(inner),
128        }
129    }
130
131    /// Attempt to convert to the inner type. This errors if it is not in the origin thread.
132    ///
133    /// # Example
134    ///
135    /// ```
136    /// use std::{thread, sync::Arc};
137    /// use thread_safe::ThreadSafe;
138    ///
139    /// let t = ThreadSafe::new(0i32);
140    ///
141    /// let t = thread::spawn(move || match t.try_into_inner() {
142    ///     Ok(_) => panic!(),
143    ///     Err(t) => t,
144    /// }).join().unwrap();
145    ///
146    /// t.try_into_inner().unwrap();
147    /// ```
148    #[inline]
149    pub fn try_into_inner(self) -> Result<T, ThreadSafe<T>> {
150        self.try_into_inner_with_key(ThreadKey::get())
151    }
152
153    /// Attempt to convert to the inner type, using a thread key.
154    #[inline]
155    pub fn try_into_inner_with_key(mut self, key: ThreadKey) -> Result<T, ThreadSafe<T>> {
156        if self.origin_thread == key.id() {
157            // SAFETY: "inner" can be used since we are in the origin thread
158            //         we can take() because we delete the original right after
159            let inner = unsafe { ManuallyDrop::take(&mut self.inner) };
160            // SAFETY: suppress the dropper on this object
161            mem::forget(self);
162            Ok(inner)
163        } else {
164            Err(self)
165        }
166    }
167
168    /// Attempt to convert to the inner type. This panics if it is not in the origin thread.
169    #[inline]
170    pub fn into_inner(self) -> T {
171        match self.try_into_inner() {
172            Ok(i) => i,
173            Err(_) => panic!("Attempted to use a ThreadSafe outside of its origin thread"),
174        }
175    }
176
177    /// Attempt to convert to the inner type, using a thread key.
178    #[inline]
179    pub fn into_inner_with_key(self, key: ThreadKey) -> T {
180        match self.try_into_inner_with_key(key) {
181            Ok(i) => i,
182            Err(_) => panic!("Attempted to use a ThreadSafe outside of its origin thread"),
183        }
184    }
185
186    /// Get the inner object.
187    ///
188    /// # Safety
189    ///
190    /// Behavior is undefined if this is not called in the object's origin thread and the object is `!Send`.
191    #[inline]
192    pub unsafe fn into_inner_unchecked(mut self) -> T {
193        let inner = ManuallyDrop::take(&mut self.inner);
194        mem::forget(self);
195        inner
196    }
197}
198
199impl<T: ?Sized> ThreadSafe<T> {
200    /// Try to get a reference to the inner type. This errors if it is not in the origin thread.
201    #[inline]
202    pub fn try_get_ref(&self) -> Result<&T, NotInOriginThread> {
203        self.try_get_ref_with_key(ThreadKey::get())
204    }
205
206    /// Try to get a reference to the inner type, using a thread key.
207    #[inline]
208    pub fn try_get_ref_with_key(&self, key: ThreadKey) -> Result<&T, NotInOriginThread> {
209        if self.origin_thread == key.id() {
210            // SAFETY: "inner" can be used since we are in the origin thread
211            //         it is unlikely that &T can be sent to another thread
212            Ok(&self.inner)
213        } else {
214            Err(NotInOriginThread)
215        }
216    }
217
218    /// Get a reference to the inner type. This panics if it is not called in the origin thread.
219    #[inline]
220    pub fn get_ref(&self) -> &T {
221        match self.try_get_ref() {
222            Ok(i) => i,
223            Err(NotInOriginThread) => {
224                panic!("Attempted to use a ThreadSafe outside of its origin thread")
225            }
226        }
227    }
228
229    /// Get a reference to the inner type, using a thread key.
230    #[inline]
231    pub fn get_ref_with_key(&self, key: ThreadKey) -> &T {
232        match self.try_get_ref_with_key(key) {
233            Ok(i) => i,
234            Err(NotInOriginThread) => {
235                panic!("Attempted to use a ThreadSafe outside of its origin thread")
236            }
237        }
238    }
239
240    /// Get a reference to the inner type without checking for thread safety.
241    ///
242    /// # Safety
243    ///
244    /// Behavior is undefined if this is not called in the origin thread and if `T` is `!Sync`.
245    #[inline]
246    pub unsafe fn get_ref_unchecked(&self) -> &T {
247        &self.inner
248    }
249
250    /// Try to get a mutable reference to the inner type. This errors if it is not in the origin thread.
251    #[inline]
252    pub fn try_get_mut(&mut self) -> Result<&mut T, NotInOriginThread> {
253        self.try_get_mut_with_key(ThreadKey::get())
254    }
255
256    /// Try to get a mutable reference to the inner type, using a thread key.
257    #[inline]
258    pub fn try_get_mut_with_key(&mut self, key: ThreadKey) -> Result<&mut T, NotInOriginThread> {
259        if self.origin_thread == key.id() {
260            // SAFETY: "inner" can be used since we are in the origin thread
261            //         it is unlikely that &mut T can be sent to another thread
262            Ok(&mut self.inner)
263        } else {
264            Err(NotInOriginThread)
265        }
266    }
267
268    /// Get a mutable reference to the inner type. This panics if it is not called in the origin thread.
269    #[inline]
270    pub fn get_mut(&mut self) -> &mut T {
271        match self.try_get_mut() {
272            Ok(i) => i,
273            Err(NotInOriginThread) => {
274                panic!("Attempted to use a ThreadSafe outside of its origin thread")
275            }
276        }
277    }
278
279    /// Get a mutable reference to the inner type, using a thread key.
280    #[inline]
281    pub fn get_mut_with_key(&mut self, key: ThreadKey) -> &mut T {
282        match self.try_get_mut_with_key(key) {
283            Ok(i) => i,
284            Err(NotInOriginThread) => {
285                panic!("Attempted to use a ThreadSafe outside of its origin thread")
286            }
287        }
288    }
289
290    /// Get a mutable reference to the inner type without checking for thread safety.
291    ///
292    /// # Safety
293    ///
294    /// Behavior is undefined if this is not called in the origin thread and if `T` is `!Send`.
295    #[inline]
296    pub unsafe fn get_mut_unchecked(&mut self) -> &mut T {
297        &mut self.inner
298    }
299}
300
301impl<T: Clone> ThreadSafe<T> {
302    /// Try to clone this value. This errors if we are not in the origin thread.
303    #[inline]
304    pub fn try_clone(&self) -> Result<ThreadSafe<T>, NotInOriginThread> {
305        self.try_clone_with_key(ThreadKey::get())
306    }
307
308    /// Try to clone this value, using a thread key.
309    #[inline]
310    pub fn try_clone_with_key(&self, key: ThreadKey) -> Result<ThreadSafe<T>, NotInOriginThread> {
311        match self.try_get_ref_with_key(key) {
312            Ok(r) => Ok(ThreadSafe {
313                inner: ManuallyDrop::new(r.clone()),
314                handle_drop: self.handle_drop,
315                origin_thread: self.origin_thread,
316            }),
317            Err(NotInOriginThread) => Err(NotInOriginThread),
318        }
319    }
320
321    /// Clone this value, using a thread key.
322    #[inline]
323    pub fn clone_with_key(&self, key: ThreadKey) -> ThreadSafe<T> {
324        ThreadSafe {
325            inner: ManuallyDrop::new(self.get_ref_with_key(key).clone()),
326            handle_drop: self.handle_drop,
327            origin_thread: self.origin_thread,
328        }
329    }
330}
331
332impl<T: Clone> Clone for ThreadSafe<T> {
333    /// Clone this value. This panics if it takes place outside of the origin thread.
334    #[inline]
335    fn clone(&self) -> ThreadSafe<T> {
336        self.clone_with_key(ThreadKey::get())
337    }
338}
339
340impl<T: ?Sized> Drop for ThreadSafe<T> {
341    #[inline]
342    fn drop(&mut self) {
343        // SAFETY: handle_drop is only turned on if the internal type is needs_drop() in some way
344        if self.handle_drop && self.origin_thread != thread::current().id() {
345            // SAFETY: we cannot allow the type to be dropped, as this is thread unsafe
346            panic!("Attempted to drop ThreadSafe<_> outside of its origin thread");
347        } else {
348            // SAFETY: since we are dropping the outer struct, and we're in the origin thread, we can drop the
349            //         inner object
350            unsafe { ManuallyDrop::drop(&mut self.inner) };
351        }
352    }
353}
354
/// Proof that the current thread has a particular `ThreadId`.
///
/// A `ThreadKey` is a `ThreadId` that cannot leave the thread it describes: it is `!Send` and
/// `!Sync`. Holding one therefore certifies which thread you are on without calling
/// `thread::current().id()` again.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ThreadKey {
    id: ThreadId,
    // `Rc` is neither `Send` nor `Sync`; this marker pins the key to its thread.
    _phantom: PhantomData<Rc<ThreadId>>,
}

impl Default for ThreadKey {
    /// Same as [`ThreadKey::get`].
    #[inline]
    fn default() -> Self {
        ThreadKey::get()
    }
}

impl ThreadKey {
    /// Create a `ThreadKey` for the thread this is called on.
    ///
    /// The id is cached in a thread-local. If that storage cannot be accessed (for example while
    /// thread-locals are being destroyed), it falls back to `thread::current().id()`.
    #[inline]
    pub fn get() -> Self {
        thread_local! {
            static CURRENT_ID: ThreadId = thread::current().id();
        }

        let id = match CURRENT_ID.try_with(|cached| *cached) {
            Ok(cached) => cached,
            Err(_) => thread::current().id(),
        };
        ThreadKey {
            id,
            _phantom: PhantomData,
        }
    }

    /// Create a `ThreadKey` from an arbitrary `ThreadId`.
    ///
    /// # Safety
    ///
    /// If this `ThreadKey` is ever used, it can only be used in the thread that the thread id
    /// refers to.
    #[inline]
    pub unsafe fn new(id: ThreadId) -> Self {
        ThreadKey {
            id,
            _phantom: PhantomData,
        }
    }

    /// The `ThreadId` this key stands for.
    #[inline]
    pub fn id(self) -> ThreadId {
        let ThreadKey { id, .. } = self;
        id
    }
}

impl From<ThreadKey> for ThreadId {
    /// Unwraps the key into the `ThreadId` it certifies.
    #[inline]
    fn from(key: ThreadKey) -> ThreadId {
        key.id()
    }
}
412
/// The error returned when a [`ThreadSafe`] is accessed from a thread other than its origin
/// thread.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct NotInOriginThread;

impl fmt::Display for NotInOriginThread {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Attempted to use ThreadSafe<_> outside of its origin thread")
    }
}

impl Error for NotInOriginThread {}