thread_safe/lib.rs
1// MIT/Apache2 License
2
//! Let's say you have some thread-unsafe data. For whatever reason, it can't be used outside of the thread it
//! originated in. This thread-unsafe data is a component of a larger data structure that does need to be sent
//! to, or shared with, other threads.
6//!
//! A `ThreadSafe` contains data that can only be used in the thread it was created in. Whenever a reference
//! to the inner data is requested, it checks that the request comes from that thread.
9//!
10//! # [`ThreadKey`]
11//!
//! A `ThreadKey` is a wrapper around `ThreadId` that is `!Send` and `!Sync`. Holding one certifies that the
//! current thread has that `ThreadId`, without having to call `thread::current().id()` again.
14//!
15//! # Example
16//!
17//! ```
18//! use std::{cell::Cell, sync::{Arc, atomic}, thread};
19//! use thread_safe::ThreadSafe;
20//!
21//! #[derive(Debug)]
22//! struct InnerData {
23//! counter: atomic::AtomicUsize,
24//! other_counter: ThreadSafe<Cell<usize>>,
25//! }
26//!
27//! fn increment_data(data: &InnerData) {
28//! data.counter.fetch_add(1, atomic::Ordering::SeqCst);
29//! if let Ok(counter) = data.other_counter.try_get_ref() {
30//! counter.set(counter.get() + 1);
31//! }
32//! }
33//!
34//! let data = Arc::new(InnerData {
35//! counter: atomic::AtomicUsize::new(0),
36//! other_counter: ThreadSafe::new(Cell::new(0)),
37//! });
38//!
39//! let mut handles = vec![];
40//!
41//! for _ in 0..10 {
42//! let data = data.clone();
43//! handles.push(thread::spawn(move || increment_data(&data)));
44//! }
45//!
46//! increment_data(&data);
47//!
48//! for handle in handles {
49//! handle.join().unwrap();
50//! }
51//!
52//! let data = Arc::try_unwrap(data).unwrap();
53//! assert_eq!(data.counter.load(atomic::Ordering::Relaxed), 11);
54//! assert_eq!(data.other_counter.get_ref().get(), 1);
55//! ```
56
57use std::{
58 error::Error,
59 fmt,
60 marker::PhantomData,
61 mem::{self, ManuallyDrop},
62 rc::Rc,
63 thread::{self, ThreadId},
64 thread_local,
65};
66
/// The whole point.
///
/// This structure wraps around thread-unsafe data and only allows access if it comes from the thread that the
/// data originated from. This allows thread-unsafe data to be used in thread-safe structures, as long as
/// the data is only used from the originating thread.
///
/// # Panics
///
/// If the `ThreadSafe` is dropped in a foreign thread and the inner type needs to be dropped (see
/// `std::mem::needs_drop`), it will panic. This is because running the drop handler for the inner data is
/// considered to be using it in a thread-unsafe context. Inner types whose drop is a no-op (for example
/// `Cell<usize>`) may be dropped on any thread.
pub struct ThreadSafe<T: ?Sized> {
    // thread that we originated in; every checked access to `inner` compares against it
    origin_thread: ThreadId,
    // whether dropping `inner` actually runs code; set from `mem::needs_drop::<T>()` when the
    // wrapper is constructed, and consulted by `Drop` to decide whether a foreign-thread drop is allowed
    handle_drop: bool,
    // inner object, wrapped in `ManuallyDrop` so that `Drop` controls if and when its destructor runs.
    // This must remain the last field so that `T` may be unsized.
    inner: ManuallyDrop<T>,
}
85
86impl<T: Default> Default for ThreadSafe<T> {
87 #[inline]
88 fn default() -> Self {
89 Self {
90 inner: ManuallyDrop::new(T::default()),
91 handle_drop: mem::needs_drop::<T>(),
92 origin_thread: thread::current().id(),
93 }
94 }
95}
96
97impl<T: fmt::Debug + ?Sized> fmt::Debug for ThreadSafe<T> {
98 #[inline]
99 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100 if self.origin_thread == thread::current().id() {
101 // SAFETY: self.inner can be accessed since we are on the origin thread
102 fmt::Debug::fmt(&self.inner, f)
103 } else {
104 f.write_str("<not in origin thread>")
105 }
106 }
107}
108
109// SAFETY: we check each and every use of "inner" in the below functions. Using "inner" is considered unsafe.
110unsafe impl<T> Send for ThreadSafe<T> {}
111unsafe impl<T> Sync for ThreadSafe<T> {}
112
113impl<T> ThreadSafe<T> {
114 /// Create a new instance of a `ThreadSafe`.
115 ///
116 /// # Example
117 ///
118 /// ```
119 /// use thread_safe::ThreadSafe;
120 /// let t = ThreadSafe::new(0i32);
121 /// ```
122 #[inline]
123 pub fn new(inner: T) -> ThreadSafe<T> {
124 ThreadSafe {
125 origin_thread: thread::current().id(),
126 handle_drop: mem::needs_drop::<T>(),
127 inner: ManuallyDrop::new(inner),
128 }
129 }
130
131 /// Attempt to convert to the inner type. This errors if it is not in the origin thread.
132 ///
133 /// # Example
134 ///
135 /// ```
136 /// use std::{thread, sync::Arc};
137 /// use thread_safe::ThreadSafe;
138 ///
139 /// let t = ThreadSafe::new(0i32);
140 ///
141 /// let t = thread::spawn(move || match t.try_into_inner() {
142 /// Ok(_) => panic!(),
143 /// Err(t) => t,
144 /// }).join().unwrap();
145 ///
146 /// t.try_into_inner().unwrap();
147 /// ```
148 #[inline]
149 pub fn try_into_inner(self) -> Result<T, ThreadSafe<T>> {
150 self.try_into_inner_with_key(ThreadKey::get())
151 }
152
153 /// Attempt to convert to the inner type, using a thread key.
154 #[inline]
155 pub fn try_into_inner_with_key(mut self, key: ThreadKey) -> Result<T, ThreadSafe<T>> {
156 if self.origin_thread == key.id() {
157 // SAFETY: "inner" can be used since we are in the origin thread
158 // we can take() because we delete the original right after
159 let inner = unsafe { ManuallyDrop::take(&mut self.inner) };
160 // SAFETY: suppress the dropper on this object
161 mem::forget(self);
162 Ok(inner)
163 } else {
164 Err(self)
165 }
166 }
167
168 /// Attempt to convert to the inner type. This panics if it is not in the origin thread.
169 #[inline]
170 pub fn into_inner(self) -> T {
171 match self.try_into_inner() {
172 Ok(i) => i,
173 Err(_) => panic!("Attempted to use a ThreadSafe outside of its origin thread"),
174 }
175 }
176
177 /// Attempt to convert to the inner type, using a thread key.
178 #[inline]
179 pub fn into_inner_with_key(self, key: ThreadKey) -> T {
180 match self.try_into_inner_with_key(key) {
181 Ok(i) => i,
182 Err(_) => panic!("Attempted to use a ThreadSafe outside of its origin thread"),
183 }
184 }
185
186 /// Get the inner object.
187 ///
188 /// # Safety
189 ///
190 /// Behavior is undefined if this is not called in the object's origin thread and the object is `!Send`.
191 #[inline]
192 pub unsafe fn into_inner_unchecked(mut self) -> T {
193 let inner = ManuallyDrop::take(&mut self.inner);
194 mem::forget(self);
195 inner
196 }
197}
198
199impl<T: ?Sized> ThreadSafe<T> {
200 /// Try to get a reference to the inner type. This errors if it is not in the origin thread.
201 #[inline]
202 pub fn try_get_ref(&self) -> Result<&T, NotInOriginThread> {
203 self.try_get_ref_with_key(ThreadKey::get())
204 }
205
206 /// Try to get a reference to the inner type, using a thread key.
207 #[inline]
208 pub fn try_get_ref_with_key(&self, key: ThreadKey) -> Result<&T, NotInOriginThread> {
209 if self.origin_thread == key.id() {
210 // SAFETY: "inner" can be used since we are in the origin thread
211 // it is unlikely that &T can be sent to another thread
212 Ok(&self.inner)
213 } else {
214 Err(NotInOriginThread)
215 }
216 }
217
218 /// Get a reference to the inner type. This panics if it is not called in the origin thread.
219 #[inline]
220 pub fn get_ref(&self) -> &T {
221 match self.try_get_ref() {
222 Ok(i) => i,
223 Err(NotInOriginThread) => {
224 panic!("Attempted to use a ThreadSafe outside of its origin thread")
225 }
226 }
227 }
228
229 /// Get a reference to the inner type, using a thread key.
230 #[inline]
231 pub fn get_ref_with_key(&self, key: ThreadKey) -> &T {
232 match self.try_get_ref_with_key(key) {
233 Ok(i) => i,
234 Err(NotInOriginThread) => {
235 panic!("Attempted to use a ThreadSafe outside of its origin thread")
236 }
237 }
238 }
239
240 /// Get a reference to the inner type without checking for thread safety.
241 ///
242 /// # Safety
243 ///
244 /// Behavior is undefined if this is not called in the origin thread and if `T` is `!Sync`.
245 #[inline]
246 pub unsafe fn get_ref_unchecked(&self) -> &T {
247 &self.inner
248 }
249
250 /// Try to get a mutable reference to the inner type. This errors if it is not in the origin thread.
251 #[inline]
252 pub fn try_get_mut(&mut self) -> Result<&mut T, NotInOriginThread> {
253 self.try_get_mut_with_key(ThreadKey::get())
254 }
255
256 /// Try to get a mutable reference to the inner type, using a thread key.
257 #[inline]
258 pub fn try_get_mut_with_key(&mut self, key: ThreadKey) -> Result<&mut T, NotInOriginThread> {
259 if self.origin_thread == key.id() {
260 // SAFETY: "inner" can be used since we are in the origin thread
261 // it is unlikely that &mut T can be sent to another thread
262 Ok(&mut self.inner)
263 } else {
264 Err(NotInOriginThread)
265 }
266 }
267
268 /// Get a mutable reference to the inner type. This panics if it is not called in the origin thread.
269 #[inline]
270 pub fn get_mut(&mut self) -> &mut T {
271 match self.try_get_mut() {
272 Ok(i) => i,
273 Err(NotInOriginThread) => {
274 panic!("Attempted to use a ThreadSafe outside of its origin thread")
275 }
276 }
277 }
278
279 /// Get a mutable reference to the inner type, using a thread key.
280 #[inline]
281 pub fn get_mut_with_key(&mut self, key: ThreadKey) -> &mut T {
282 match self.try_get_mut_with_key(key) {
283 Ok(i) => i,
284 Err(NotInOriginThread) => {
285 panic!("Attempted to use a ThreadSafe outside of its origin thread")
286 }
287 }
288 }
289
290 /// Get a mutable reference to the inner type without checking for thread safety.
291 ///
292 /// # Safety
293 ///
294 /// Behavior is undefined if this is not called in the origin thread and if `T` is `!Send`.
295 #[inline]
296 pub unsafe fn get_mut_unchecked(&mut self) -> &mut T {
297 &mut self.inner
298 }
299}
300
301impl<T: Clone> ThreadSafe<T> {
302 /// Try to clone this value. This errors if we are not in the origin thread.
303 #[inline]
304 pub fn try_clone(&self) -> Result<ThreadSafe<T>, NotInOriginThread> {
305 self.try_clone_with_key(ThreadKey::get())
306 }
307
308 /// Try to clone this value, using a thread key.
309 #[inline]
310 pub fn try_clone_with_key(&self, key: ThreadKey) -> Result<ThreadSafe<T>, NotInOriginThread> {
311 match self.try_get_ref_with_key(key) {
312 Ok(r) => Ok(ThreadSafe {
313 inner: ManuallyDrop::new(r.clone()),
314 handle_drop: self.handle_drop,
315 origin_thread: self.origin_thread,
316 }),
317 Err(NotInOriginThread) => Err(NotInOriginThread),
318 }
319 }
320
321 /// Clone this value, using a thread key.
322 #[inline]
323 pub fn clone_with_key(&self, key: ThreadKey) -> ThreadSafe<T> {
324 ThreadSafe {
325 inner: ManuallyDrop::new(self.get_ref_with_key(key).clone()),
326 handle_drop: self.handle_drop,
327 origin_thread: self.origin_thread,
328 }
329 }
330}
331
332impl<T: Clone> Clone for ThreadSafe<T> {
333 /// Clone this value. This panics if it takes place outside of the origin thread.
334 #[inline]
335 fn clone(&self) -> ThreadSafe<T> {
336 self.clone_with_key(ThreadKey::get())
337 }
338}
339
340impl<T: ?Sized> Drop for ThreadSafe<T> {
341 #[inline]
342 fn drop(&mut self) {
343 // SAFETY: handle_drop is only turned on if the internal type is needs_drop() in some way
344 if self.handle_drop && self.origin_thread != thread::current().id() {
345 // SAFETY: we cannot allow the type to be dropped, as this is thread unsafe
346 panic!("Attempted to drop ThreadSafe<_> outside of its origin thread");
347 } else {
348 // SAFETY: since we are dropping the outer struct, and we're in the origin thread, we can drop the
349 // inner object
350 unsafe { ManuallyDrop::drop(&mut self.inner) };
351 }
352 }
353}
354
/// A `ThreadId` that is guaranteed to refer to the current thread, since this is `!Send`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct ThreadKey {
    id: ThreadId,
    // `Rc` is neither `Send` nor `Sync`, so this marker keeps the key on the thread that made it
    _phantom: PhantomData<Rc<ThreadId>>,
}

impl Default for ThreadKey {
    /// Same as [`ThreadKey::get`].
    #[inline]
    fn default() -> Self {
        ThreadKey::get()
    }
}

impl ThreadKey {
    /// Create a new `ThreadKey` based on the current thread.
    ///
    /// The id is cached in a thread-local. If that thread-local can no longer be accessed, the id
    /// is read from `thread::current()` instead.
    #[inline]
    pub fn get() -> Self {
        thread_local! {
            static CACHED_ID: ThreadId = thread::current().id();
        }

        let id = match CACHED_ID.try_with(|cached| *cached) {
            Ok(cached) => cached,
            Err(_) => thread::current().id(),
        };

        ThreadKey {
            id,
            _phantom: PhantomData,
        }
    }

    /// Create a new `ThreadKey` using a `ThreadId`.
    ///
    /// # Safety
    ///
    /// If this `ThreadKey` is ever used, it can only be used in the thread that the thread id refers to.
    #[inline]
    pub unsafe fn new(id: ThreadId) -> Self {
        ThreadKey {
            id,
            _phantom: PhantomData,
        }
    }

    /// Get the `ThreadId` for this `ThreadKey`.
    #[inline]
    pub fn id(self) -> ThreadId {
        let ThreadKey { id, .. } = self;
        id
    }
}

impl From<ThreadKey> for ThreadId {
    #[inline]
    fn from(key: ThreadKey) -> ThreadId {
        key.id()
    }
}
412
/// Error returned when a `ThreadSafe` is accessed from a thread other than its origin thread.
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct NotInOriginThread;

impl fmt::Display for NotInOriginThread {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Attempted to use ThreadSafe<_> outside of its origin thread")
    }
}

impl Error for NotInOriginThread {}