1#![doc = include_str!("../README.md")]
2
3use std::thread;
4
5enum ThreadCellMessage<T> {
7 Run(Box<dyn FnOnce(&mut T) + Send>),
8 GetSessionSync(crossbeam::channel::Sender<ThreadCellSession<T>>),
9 #[cfg(feature = "tokio")]
10 GetSessionAsync(tokio::sync::oneshot::Sender<ThreadCellSession<T>>),
11}
12
/// A boxed job that the manager thread executes with mutable access to the resource.
type SessionMsg<T> = Box<dyn FnOnce(&mut T) + Send + 'static>;
15
/// Panic message used when a session's job queue or result channel is disconnected.
const SESSION_ERROR_MESSAGE: &str = "Session thread has panicked or resource was dropped";
17
18pub struct ThreadCellSession<T> {
22 sender: crossbeam::channel::Sender<SessionMsg<T>>,
23}
24
25impl<T> ThreadCellSession<T> {
26 pub fn run_blocking<F, R>(&self, f: F) -> R
27 where
28 F: FnOnce(&mut T) -> R + Send + 'static,
29 R: Send + 'static,
30 {
31 let (tx, rx) = crossbeam::channel::bounded(1);
32 self.sender
33 .send(Box::new(move |resource| {
34 let res = f(resource);
35 let _ = tx.send(res);
36 }))
37 .expect(SESSION_ERROR_MESSAGE);
38 rx.recv().expect(SESSION_ERROR_MESSAGE)
39 }
40
41 #[cfg(feature = "tokio")]
42 pub async fn run<F, R>(&self, f: F) -> R
43 where
44 F: FnOnce(&mut T) -> R + Send + 'static,
45 R: Send + 'static,
46 {
47 let (tx, rx) = tokio::sync::oneshot::channel();
48 self.sender
49 .send(Box::new(move |resource| {
50 let res = f(resource);
51 let _ = tx.send(res);
52 }))
53 .expect(SESSION_ERROR_MESSAGE);
54 rx.await.expect(SESSION_ERROR_MESSAGE)
55 }
56}
57
/// Panic message used when the manager thread's request queue or a reply channel
/// is disconnected.
const MANAGER_ERROR_MESSAGE: &str = "Manager thread has panicked";
59
60pub struct ThreadCell<T: 'static> {
65 sender: crossbeam::channel::Sender<ThreadCellMessage<T>>,
66}
67
68impl<T: 'static> Clone for ThreadCell<T> {
69 fn clone(&self) -> Self {
70 Self {
71 sender: self.sender.clone(),
72 }
73 }
74}
75
76impl<T: Send> ThreadCell<T> {
77 pub fn new(mut resource: T) -> Self {
79 let (tx, rx) = crossbeam::channel::unbounded::<ThreadCellMessage<T>>();
80
81 thread::spawn(move || {
82 while let Ok(msg) = rx.recv() {
83 match msg {
84 ThreadCellMessage::Run(f) => f(&mut resource),
85 ThreadCellMessage::GetSessionSync(responder) => {
86 let (stx, srx) = crossbeam::channel::unbounded::<SessionMsg<T>>();
87 let _ = responder.send(ThreadCellSession { sender: stx });
88 while let Ok(f) = srx.recv() {
89 f(&mut resource);
90 }
91 }
92 #[cfg(feature = "tokio")]
93 ThreadCellMessage::GetSessionAsync(sender) => {
94 let (stx, srx) = crossbeam::channel::unbounded::<SessionMsg<T>>();
95 let _ = sender.send(ThreadCellSession { sender: stx });
96 while let Ok(f) = srx.recv() {
97 f(&mut resource);
98 }
99 }
100 }
101 }
102 });
103
104 Self { sender: tx }
105 }
106}
107
108impl<T> ThreadCell<T> {
109 pub fn new_with<F: FnOnce() -> T + Send + 'static>(resource_fn: F) -> Self {
111 let (tx, rx) = crossbeam::channel::unbounded::<ThreadCellMessage<T>>();
112
113 thread::spawn(move || {
114 let mut resource = resource_fn();
115 while let Ok(msg) = rx.recv() {
116 match msg {
117 ThreadCellMessage::Run(f) => f(&mut resource),
118 ThreadCellMessage::GetSessionSync(responder) => {
119 let (stx, srx) = crossbeam::channel::unbounded::<SessionMsg<T>>();
120 let _ = responder.send(ThreadCellSession { sender: stx });
121 while let Ok(f) = srx.recv() {
122 f(&mut resource);
123 }
124 }
125 #[cfg(feature = "tokio")]
126 ThreadCellMessage::GetSessionAsync(sender) => {
127 let (stx, srx) = crossbeam::channel::unbounded::<SessionMsg<T>>();
128 let _ = sender.send(ThreadCellSession { sender: stx });
129 while let Ok(f) = srx.recv() {
130 f(&mut resource);
131 }
132 }
133 }
134 }
135 });
136
137 Self { sender: tx }
138 }
139
140 pub fn run_blocking<F, R>(&self, f: F) -> R
141 where
142 F: FnOnce(&mut T) -> R + Send + 'static,
143 R: Send + 'static,
144 {
145 let (tx, rx) = crossbeam::channel::bounded(1);
146 self.sender
147 .send(ThreadCellMessage::Run(Box::new(move |resource| {
148 let res = f(resource);
149 let _ = tx.send(res);
150 })))
151 .expect(MANAGER_ERROR_MESSAGE);
152 rx.recv().expect(MANAGER_ERROR_MESSAGE)
153 }
154
155 #[cfg(feature = "tokio")]
156 pub async fn run<F, R>(&self, f: F) -> R
157 where
158 F: FnOnce(&mut T) -> R + Send + 'static,
159 R: Send + 'static,
160 {
161 let (tx, rx) = tokio::sync::oneshot::channel();
162 self.sender
163 .send(ThreadCellMessage::Run(Box::new(move |resource| {
164 let res = f(resource);
165 let _ = tx.send(res);
166 })))
167 .expect(MANAGER_ERROR_MESSAGE);
168 rx.await.expect(MANAGER_ERROR_MESSAGE)
169 }
170
171 pub fn session_blocking(&self) -> ThreadCellSession<T> {
172 let (tx, rx) = crossbeam::channel::bounded(1);
173 self.sender
174 .send(ThreadCellMessage::GetSessionSync(tx))
175 .expect(MANAGER_ERROR_MESSAGE);
176 rx.recv().expect(MANAGER_ERROR_MESSAGE)
177 }
178
179 #[cfg(feature = "tokio")]
180 pub async fn session(&self) -> ThreadCellSession<T> {
181 let (tx, rx) = tokio::sync::oneshot::channel();
182 self.sender
183 .send(ThreadCellMessage::GetSessionAsync(tx))
184 .expect(MANAGER_ERROR_MESSAGE);
185 rx.await.expect(MANAGER_ERROR_MESSAGE)
186 }
187}
188
189impl<T: Send> ThreadCell<T> {
190 pub fn set_blocking(&self, new_value: T) {
192 self.run_blocking(|res| *res = new_value);
193 }
194
195 #[cfg(feature = "tokio")]
197 pub async fn set(&self, new_value: T) {
198 self.run(|res| *res = new_value).await;
199 }
200
201 pub fn replace_blocking(&self, new_value: T) -> T {
203 self.run_blocking(|res| std::mem::replace(res, new_value))
204 }
205
206 #[cfg(feature = "tokio")]
208 pub async fn replace(&self, new_value: T) -> T {
209 self.run(|res| std::mem::replace(res, new_value)).await
210 }
211}
212
213impl<T: Send + Default> ThreadCell<T> {
214 pub fn take_blocking(&self) -> T {
215 self.run_blocking(|res| std::mem::take(res))
216 }
217
218 #[cfg(feature = "tokio")]
219 pub async fn take(&self) -> T {
220 self.run(|res| std::mem::take(res)).await
221 }
222}
223
224impl<T: Send + Clone> ThreadCell<T> {
225 pub fn get_blocking(&self) -> T {
227 self.run_blocking(|res| res.clone())
228 }
229
230 #[cfg(feature = "tokio")]
232 pub async fn get(&self) -> T {
233 self.run(|res| res.clone()).await
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use std::rc::Rc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Simple counter resource used across the tests.
    #[derive(Default)]
    struct TestResource {
        counter: usize,
    }

    impl TestResource {
        /// Increments the counter and returns its new value.
        fn increment(&mut self) -> usize {
            self.counter += 1;
            self.counter
        }
    }

    #[test]
    fn basic_run_blocking_works() {
        let cell = ThreadCell::new(TestResource::default());
        let value = cell.run_blocking(|res| {
            res.increment();
            res.increment()
        });
        assert_eq!(value, 2);

        let value = cell.run_blocking(|res| res.increment());
        assert_eq!(value, 3);
    }

    #[test]
    fn can_be_sent_to_another_thread() {
        let cell = ThreadCell::new(TestResource::default());
        let handle = std::thread::spawn(move || cell.run_blocking(|res| res.increment()));
        let result = handle.join().unwrap();
        assert_eq!(result, 1);
    }

    #[cfg(feature = "tokio")]
    #[tokio::test(flavor = "current_thread")]
    async fn async_run_works() {
        let cell = ThreadCell::new(TestResource::default());
        let result = cell.run(|res| res.increment()).await;
        assert_eq!(result, 1);
    }

    #[test]
    fn session_blocking_gives_mutable_access() {
        let cell = ThreadCell::new(TestResource::default());
        let lock = cell.session_blocking();
        let value = lock.run_blocking(|res| {
            res.increment();
            res.increment()
        });
        assert_eq!(value, 2);
    }

    #[test]
    fn session_has_exclusive_access() {
        let cell = ThreadCell::new(TestResource::default());
        let session = cell.session_blocking();

        let other = cell.clone();
        let handle = std::thread::spawn(move || other.run_blocking(|res| res.increment()));

        // The other thread's job cannot run while the session is alive, so the session
        // observes an untouched resource regardless of thread scheduling.
        let value = session.run_blocking(|res| {
            res.increment();
            res.increment()
        });
        assert_eq!(value, 2);

        drop(session);
        assert_eq!(handle.join().unwrap(), 3);
    }

    #[cfg(feature = "tokio")]
    #[tokio::test(flavor = "current_thread")]
    async fn async_session_works() {
        let cell = ThreadCell::new(TestResource::default());
        let lock = cell.session().await;
        let value = lock.run(|res| res.increment()).await;
        assert_eq!(value, 1);
    }

    #[test]
    fn can_hold_non_send_type() {
        // `Rc` is not `Send`; this only works because the value is created on, and
        // never leaves, the manager thread.
        struct NotSend(Rc<()>);
        let cell = ThreadCell::new_with(|| NotSend(Rc::new(())));
        let count = cell.run_blocking(|res| Rc::strong_count(&res.0));
        assert_eq!(count, 1);
    }

    #[test]
    fn value_helpers_work() {
        let cell = ThreadCell::new(5_u32);
        assert_eq!(cell.get_blocking(), 5);
        cell.set_blocking(7);
        assert_eq!(cell.replace_blocking(9), 7);
        assert_eq!(cell.take_blocking(), 9);
        assert_eq!(cell.get_blocking(), 0);
    }

    #[cfg(feature = "tokio")]
    #[tokio::test(flavor = "current_thread")]
    async fn async_value_helpers_work() {
        let cell = ThreadCell::new(1_i32);
        cell.set(2).await;
        assert_eq!(cell.get().await, 2);
        assert_eq!(cell.replace(3).await, 2);
        assert_eq!(cell.take().await, 3);
        assert_eq!(cell.get().await, 0);
    }

    #[test]
    #[should_panic(expected = "Manager thread has panicked")]
    fn panicking_job_surfaces_as_manager_error() {
        let cell = ThreadCell::new(TestResource::default());
        let _: usize = cell.run_blocking(|_res| panic!("job failed"));
    }

    #[test]
    fn concurrent_run_blocking_requests_are_serialized() {
        let cell = ThreadCell::new(TestResource::default());
        let counter = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();
        for _ in 0..10 {
            let cell = cell.clone();
            let counter = counter.clone();
            handles.push(std::thread::spawn(move || {
                cell.run_blocking(move |res| {
                    let val = res.increment();
                    counter.fetch_add(val, Ordering::SeqCst);
                });
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        // Each job saw a distinct counter value 1..=10, so the sum is 55.
        assert_eq!(counter.load(Ordering::SeqCst), 55);
    }

    #[test]
    fn dropping_cell_does_not_panic() {
        let cell = ThreadCell::new(TestResource::default());
        drop(cell);
    }
}