thread_cell/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::thread;
4
5/// Messages sent to the manager thread
6enum ThreadCellMessage<T> {
7    Run(Box<dyn FnOnce(&mut T) + Send>),
8    GetSessionSync(crossbeam::channel::Sender<ThreadCellSession<T>>),
9    #[cfg(feature = "tokio")]
10    GetSessionAsync(tokio::sync::oneshot::Sender<ThreadCellSession<T>>),
11}
12
/// A unit of work submitted through a [`ThreadCellSession`]: a boxed closure that the worker
/// thread invokes with mutable access to the resource.
type SessionMsg<T> = Box<dyn FnOnce(&mut T) + Send + 'static>;
15
/// Panic message used by `ThreadCellSession` when the worker thread serving the session is gone
/// (a `send` to it or a `recv` of its reply failed).
///
/// A `const` rather than a `static`: the string needs no fixed address.
const SESSION_ERROR_MESSAGE: &str = "Session thread has panicked or resource was dropped";
17
18/// A session with exclusive access to the resource held by the thread.
19/// While held, this is the only way to access the resource. It is possible to create a "deadlock"
20/// if a `ThreadCellSession` is requested while one is already held.
21pub struct ThreadCellSession<T> {
22    sender: crossbeam::channel::Sender<SessionMsg<T>>,
23}
24
25impl<T> ThreadCellSession<T> {
26    pub fn run_blocking<F, R>(&self, f: F) -> R
27    where
28        F: FnOnce(&mut T) -> R + Send + 'static,
29        R: Send + 'static,
30    {
31        let (tx, rx) = crossbeam::channel::bounded(1);
32        self.sender
33            .send(Box::new(move |resource| {
34                let res = f(resource);
35                let _ = tx.send(res);
36            }))
37            .expect(SESSION_ERROR_MESSAGE);
38        rx.recv().expect(SESSION_ERROR_MESSAGE)
39    }
40
41    #[cfg(feature = "tokio")]
42    pub async fn run<F, R>(&self, f: F) -> R
43    where
44        F: FnOnce(&mut T) -> R + Send + 'static,
45        R: Send + 'static,
46    {
47        let (tx, rx) = tokio::sync::oneshot::channel();
48        self.sender
49            .send(Box::new(move |resource| {
50                let res = f(resource);
51                let _ = tx.send(res);
52            }))
53            .expect(SESSION_ERROR_MESSAGE);
54        rx.await.expect(SESSION_ERROR_MESSAGE)
55    }
56}
57
/// Panic message used by `ThreadCell` when its worker ("manager") thread is gone
/// (a `send` to it or a `recv` of its reply failed).
///
/// A `const` rather than a `static`: the string needs no fixed address.
const MANAGER_ERROR_MESSAGE: &str = "Manager thread has panicked";
59
60/// A cell that holds a value bound to a single thread. Thus `T` can be non-`Send` and/or non-`Sync`,
61/// but `ThreadCell<T>` is always `Send`/`Sync`. Access is provided through message passing, so no
62/// internal locking is used. But a lock-like `ThreadCellSession` can be acquired to gain exclusive
63/// access to the underlying resource while held.
64pub struct ThreadCell<T: 'static> {
65    sender: crossbeam::channel::Sender<ThreadCellMessage<T>>,
66}
67
68impl<T: 'static> Clone for ThreadCell<T> {
69    fn clone(&self) -> Self {
70        Self {
71            sender: self.sender.clone(),
72        }
73    }
74}
75
76impl<T: Send> ThreadCell<T> {
77    /// Creates new
78    pub fn new(mut resource: T) -> Self {
79        let (tx, rx) = crossbeam::channel::unbounded::<ThreadCellMessage<T>>();
80
81        thread::spawn(move || {
82            while let Ok(msg) = rx.recv() {
83                match msg {
84                    ThreadCellMessage::Run(f) => f(&mut resource),
85                    ThreadCellMessage::GetSessionSync(responder) => {
86                        let (stx, srx) = crossbeam::channel::unbounded::<SessionMsg<T>>();
87                        let _ = responder.send(ThreadCellSession { sender: stx });
88                        while let Ok(f) = srx.recv() {
89                            f(&mut resource);
90                        }
91                    }
92                    #[cfg(feature = "tokio")]
93                    ThreadCellMessage::GetSessionAsync(sender) => {
94                        let (stx, srx) = crossbeam::channel::unbounded::<SessionMsg<T>>();
95                        let _ = sender.send(ThreadCellSession { sender: stx });
96                        while let Ok(f) = srx.recv() {
97                            f(&mut resource);
98                        }
99                    }
100                }
101            }
102        });
103
104        Self { sender: tx }
105    }
106}
107
108impl<T> ThreadCell<T> {
109    /// Creates a new when `T` is not `Send` but a function to create `T` is
110    pub fn new_with<F: FnOnce() -> T + Send + 'static>(resource_fn: F) -> Self {
111        let (tx, rx) = crossbeam::channel::unbounded::<ThreadCellMessage<T>>();
112
113        thread::spawn(move || {
114            let mut resource = resource_fn();
115            while let Ok(msg) = rx.recv() {
116                match msg {
117                    ThreadCellMessage::Run(f) => f(&mut resource),
118                    ThreadCellMessage::GetSessionSync(responder) => {
119                        let (stx, srx) = crossbeam::channel::unbounded::<SessionMsg<T>>();
120                        let _ = responder.send(ThreadCellSession { sender: stx });
121                        while let Ok(f) = srx.recv() {
122                            f(&mut resource);
123                        }
124                    }
125                    #[cfg(feature = "tokio")]
126                    ThreadCellMessage::GetSessionAsync(sender) => {
127                        let (stx, srx) = crossbeam::channel::unbounded::<SessionMsg<T>>();
128                        let _ = sender.send(ThreadCellSession { sender: stx });
129                        while let Ok(f) = srx.recv() {
130                            f(&mut resource);
131                        }
132                    }
133                }
134            }
135        });
136
137        Self { sender: tx }
138    }
139
140    pub fn run_blocking<F, R>(&self, f: F) -> R
141    where
142        F: FnOnce(&mut T) -> R + Send + 'static,
143        R: Send + 'static,
144    {
145        let (tx, rx) = crossbeam::channel::bounded(1);
146        self.sender
147            .send(ThreadCellMessage::Run(Box::new(move |resource| {
148                let res = f(resource);
149                let _ = tx.send(res);
150            })))
151            .expect(MANAGER_ERROR_MESSAGE);
152        rx.recv().expect(MANAGER_ERROR_MESSAGE)
153    }
154
155    #[cfg(feature = "tokio")]
156    pub async fn run<F, R>(&self, f: F) -> R
157    where
158        F: FnOnce(&mut T) -> R + Send + 'static,
159        R: Send + 'static,
160    {
161        let (tx, rx) = tokio::sync::oneshot::channel();
162        self.sender
163            .send(ThreadCellMessage::Run(Box::new(move |resource| {
164                let res = f(resource);
165                let _ = tx.send(res);
166            })))
167            .expect(MANAGER_ERROR_MESSAGE);
168        rx.await.expect(MANAGER_ERROR_MESSAGE)
169    }
170
171    pub fn session_blocking(&self) -> ThreadCellSession<T> {
172        let (tx, rx) = crossbeam::channel::bounded(1);
173        self.sender
174            .send(ThreadCellMessage::GetSessionSync(tx))
175            .expect(MANAGER_ERROR_MESSAGE);
176        rx.recv().expect(MANAGER_ERROR_MESSAGE)
177    }
178
179    #[cfg(feature = "tokio")]
180    pub async fn session(&self) -> ThreadCellSession<T> {
181        let (tx, rx) = tokio::sync::oneshot::channel();
182        self.sender
183            .send(ThreadCellMessage::GetSessionAsync(tx))
184            .expect(MANAGER_ERROR_MESSAGE);
185        rx.await.expect(MANAGER_ERROR_MESSAGE)
186    }
187}
188
189impl<T: Send> ThreadCell<T> {
190    /// Set the resource in a blocking manner
191    pub fn set_blocking(&self, new_value: T) {
192        self.run_blocking(|res| *res = new_value);
193    }
194
195    /// Set the resource in an async manner
196    #[cfg(feature = "tokio")]
197    pub async fn set(&self, new_value: T) {
198        self.run(|res| *res = new_value).await;
199    }
200
201    /// Set the resource in a blocking manner, returning the old value
202    pub fn replace_blocking(&self, new_value: T) -> T {
203        self.run_blocking(|res| std::mem::replace(res, new_value))
204    }
205
206    /// Set the resource in an async manner, returning the old value
207    #[cfg(feature = "tokio")]
208    pub async fn replace(&self, new_value: T) -> T {
209        self.run(|res| std::mem::replace(res, new_value)).await
210    }
211}
212
213impl<T: Send + Default> ThreadCell<T> {
214    pub fn take_blocking(&self) -> T {
215        self.run_blocking(|res| std::mem::take(res))
216    }
217
218    #[cfg(feature = "tokio")]
219    pub async fn take(&self) -> T {
220        self.run(|res| std::mem::take(res)).await
221    }
222}
223
224impl<T: Send + Clone> ThreadCell<T> {
225    /// Get a clone of the resource in a blocking manner
226    pub fn get_blocking(&self) -> T {
227        self.run_blocking(|res| res.clone())
228    }
229
230    /// Get a clone of the resource in an async manner
231    #[cfg(feature = "tokio")]
232    pub async fn get(&self) -> T {
233        self.run(|res| res.clone()).await
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use std::rc::Rc;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Minimal resource with observable, order-dependent state.
    #[derive(Default)]
    struct TestResource {
        counter: usize,
    }

    impl TestResource {
        fn increment(&mut self) -> usize {
            self.counter += 1;
            self.counter
        }
    }

    #[test]
    fn basic_run_blocking_works() {
        let cell = ThreadCell::new(TestResource::default());
        let value = cell.run_blocking(|res| {
            res.increment();
            res.increment()
        });
        assert_eq!(value, 2);

        let value = cell.run_blocking(|res| res.increment());
        assert_eq!(value, 3);
    }

    #[test]
    fn can_be_sent_to_another_thread() {
        let cell = ThreadCell::new(TestResource::default());
        let handle = std::thread::spawn(move || cell.run_blocking(|res| res.increment()));
        let result = handle.join().unwrap();
        assert_eq!(result, 1);
    }

    #[cfg(feature = "tokio")]
    #[tokio::test(flavor = "current_thread")]
    async fn async_run_works() {
        let cell = ThreadCell::new(TestResource::default());
        let result = cell.run(|res| res.increment()).await;
        assert_eq!(result, 1);
    }

    #[test]
    fn session_blocking_gives_mutable_access() {
        let cell = ThreadCell::new(TestResource::default());
        let lock = cell.session_blocking();
        let value = lock.run_blocking(|res| {
            res.increment();
            res.increment()
        });
        assert_eq!(value, 2);
    }

    #[cfg(feature = "tokio")]
    #[tokio::test(flavor = "current_thread")]
    async fn async_session_works() {
        let cell = ThreadCell::new(TestResource::default());
        let lock = cell.session().await;
        let value = lock.run(|res| res.increment()).await;
        assert_eq!(value, 1);
    }

    #[test]
    fn dropping_session_releases_the_worker() {
        let cell = ThreadCell::new(TestResource::default());
        {
            let session = cell.session_blocking();
            assert_eq!(session.run_blocking(|res| res.increment()), 1);
        }
        // The worker must be back in its main loop and see the session's changes.
        assert_eq!(cell.run_blocking(|res| res.increment()), 2);
    }

    #[test]
    fn can_hold_non_send_type() {
        #[derive(Default)]
        struct NotSend(Rc<()>); // Rc is !Send
        let cell = ThreadCell::new_with(|| NotSend(Rc::new(())));
        let count = cell.run_blocking(|res| Rc::strong_count(&res.0));
        assert_eq!(count, 1);
    }

    #[test]
    fn new_with_builds_resource_on_worker_thread() {
        let caller = std::thread::current().id();
        let cell = ThreadCell::new_with(|| (std::thread::current().id(), Rc::new(())));
        let (built_on, ran_on) = cell.run_blocking(|res| (res.0, std::thread::current().id()));
        assert_ne!(built_on, caller);
        assert_eq!(built_on, ran_on);
    }

    #[test]
    fn set_replace_take_and_get_blocking_round_trip() {
        let cell = ThreadCell::new(5usize);
        assert_eq!(cell.get_blocking(), 5);
        cell.set_blocking(7);
        assert_eq!(cell.get_blocking(), 7);
        assert_eq!(cell.replace_blocking(9), 7);
        assert_eq!(cell.take_blocking(), 9);
        assert_eq!(cell.get_blocking(), 0);
    }

    #[cfg(feature = "tokio")]
    #[tokio::test(flavor = "current_thread")]
    async fn async_accessors_round_trip() {
        let cell = ThreadCell::new(1usize);
        cell.set(2).await;
        assert_eq!(cell.replace(3).await, 2);
        assert_eq!(cell.take().await, 3);
        assert_eq!(cell.get().await, 0);
    }

    #[test]
    fn concurrent_run_blocking_requests_are_serialized() {
        let cell = ThreadCell::new(TestResource::default());
        let counter = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();
        for _ in 0..10 {
            let cell = cell.clone();
            let counter = counter.clone();
            handles.push(std::thread::spawn(move || {
                cell.run_blocking(move |res| {
                    let val = res.increment();
                    counter.fetch_add(val, Ordering::SeqCst);
                });
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        // Each request saw a distinct counter value 1..=10, so the sum is 55.
        assert_eq!(counter.load(Ordering::SeqCst), 55);
    }

    #[test]
    fn dropping_cell_does_not_panic() {
        let cell = ThreadCell::new(TestResource::default());
        drop(cell);
        // no panic = pass
    }
}