sync_utils/
blocking_async.rs

1use std::{
2    cell::UnsafeCell,
3    future::Future,
4    mem::transmute,
5    sync::{Arc, Condvar, Mutex},
6};
7
8use tokio::runtime::Runtime;
9
/// State shared between a thread waiting for a result and the task that produces it.
///
/// `res` is the hand-off slot for the produced value. It lives in an `UnsafeCell`, so every
/// access has to be serialised through the `done` mutex: the producer fills the slot and raises
/// the flag while holding that lock, and the consumer empties the slot while holding it too.
struct BlockingFutureInner<R> {
    /// Produced value; `None` until published and `None` again once it has been taken.
    res: UnsafeCell<Option<R>>,
    /// Signalled after `done` has been set to `true`.
    cond: Condvar,
    /// `true` once a value has been stored in `res`.
    done: Mutex<bool>,
}

impl<R> BlockingFutureInner<R> {
    /// Publishes `r` and wakes the thread waiting on `cond`.
    ///
    /// The slot is written and the flag raised under the `done` lock, so a waiter that observes
    /// `true` under that lock is guaranteed to see the stored value.
    ///
    /// # Panics
    ///
    /// Panics if the `done` mutex is poisoned.
    fn done(&self, r: R) {
        let mut finished = self.done.lock().unwrap();
        // SAFETY: the `done` lock is held for the whole write, and the consumer only touches
        // `res` while holding the same lock, so no other reference to the slot is live.
        unsafe { *self.res.get() = Some(r) };
        *finished = true;
        self.cond.notify_one();
    }

    /// Moves the published value out of the slot, leaving it empty.
    ///
    /// The caller must hold the `done` lock and must have observed the flag as `true`. This
    /// method cannot take the lock itself because callers invoke it with the guard still alive.
    ///
    /// # Panics
    ///
    /// Panics if the slot is empty.
    fn take_res(&self) -> R {
        // SAFETY: per the contract above the caller holds the `done` lock, which is the same
        // lock `done()` holds while writing, so this is the only live reference to the slot.
        let slot = unsafe { &mut *self.res.get() };
        slot.take()
            .expect("take_res called while no result is stored")
    }
}

// `Send` is derived automatically from the fields whenever `R: Send`.

// SAFETY: the stored value is only ever moved in through `done()` and moved out through
// `take_res()`, both serialised by the `done` mutex, and no `&R` is ever handed out. Sharing
// the cell between threads therefore needs `R: Send` only, the same requirement `Mutex<R>` has.
unsafe impl<R: Send> Sync for BlockingFutureInner<R> {}
42
/// Runs a future on a tokio runtime from blocking (non-async) code and waits for its output.
///
/// [`BlockingFuture::block_on`] spawns the future onto the given [`Runtime`] and blocks the
/// calling thread on a condition variable until the spawned task hands back its result.
///
/// # Examples
///
/// ```rust
/// use tokio::time::*;
/// use sync_utils::blocking_async::BlockingFuture;
/// let rt = tokio::runtime::Builder::new_multi_thread()
///     .enable_all()
///     .worker_threads(1)
///     .build()
///     .unwrap();
/// let res = BlockingFuture::new().block_on(&rt, async move {
///     println!("exec future");
///     sleep(Duration::from_secs(1)).await;
///     return "hello world".to_string();
/// });
/// assert_eq!(res, "hello world");
/// ```
pub struct BlockingFuture<R: Sync + Send + 'static>(Arc<BlockingFutureInner<R>>);
63
64impl<R> BlockingFuture<R>
65where
66    R: Sync + Send + Clone + 'static,
67{
68    #[inline(always)]
69    pub fn new() -> Self {
70        Self(Arc::new(BlockingFutureInner {
71            res: UnsafeCell::new(None),
72            cond: Condvar::new(),
73            done: Mutex::new(false),
74        }))
75    }
76
77    pub fn block_on<F>(&mut self, rt: &Runtime, f: F) -> R
78    where
79        F: Future<Output = R> + Send + Sync + 'static,
80    {
81        let _self = self.0.clone();
82        let _ = rt.spawn(async move {
83            let res = f.await;
84            _self.done(res);
85        });
86        let _self = self.0.as_ref();
87        let mut guard = _self.done.lock().unwrap();
88        loop {
89            if *guard {
90                return _self.take_res();
91            }
92            guard = _self.cond.wait(guard).unwrap();
93        }
94    }
95}
96
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::time::sleep;

    use super::*;

    /// A future that suspends twice on a single-worker runtime must still deliver its output to
    /// the blocked caller.
    #[test]
    fn test_spawn() {
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(1)
            .build()
            .unwrap();

        let mut bf = BlockingFuture::new();
        let res = bf.block_on(&rt, async move {
            // Short sleeps are enough to force the task to yield and be resumed.
            sleep(Duration::from_millis(50)).await;
            println!("exec future");
            sleep(Duration::from_millis(50)).await;
            "hello world".to_string()
        });
        println!("got res {}", res);
        assert_eq!(res, "hello world");
    }
}