Skip to main content

rustrails_support/
runtime.rs

1use std::{cell::RefCell, future::Future};
2
3use tokio::runtime::{Builder, Handle, Runtime};
4
/// Panic message emitted when a runtime helper is used on a thread that has
/// not registered a runtime via [`init_runtime`].
const RUNTIME_NOT_INITIALIZED: &str = concat!(
    "rustrails_support::runtime::init_runtime() must be called ",
    "on this thread before using runtime helpers"
);
6
7thread_local! {
8    static RT_HANDLE: RefCell<Option<Handle>> = const { RefCell::new(None) };
9}
10
11fn with_handle<R>(f: impl FnOnce(&Handle) -> R) -> R {
12    RT_HANDLE.with(|cell| {
13        let borrow = cell.borrow();
14        let handle = borrow
15            .as_ref()
16            .unwrap_or_else(|| panic!("{RUNTIME_NOT_INITIALIZED}"));
17        f(handle)
18    })
19}
20
21/// Initializes the thread-local Tokio runtime handle for the current thread.
22///
23/// The returned runtime must be kept alive by the caller for as long as the
24/// thread-local helpers are used on this thread.
25pub fn init_runtime() -> Runtime {
26    let runtime = match Builder::new_multi_thread().enable_all().build() {
27        Ok(runtime) => runtime,
28        Err(error) => panic!("failed to build Tokio runtime: {error}"),
29    };
30    let handle = runtime.handle().clone();
31    RT_HANDLE.with(|cell| {
32        *cell.borrow_mut() = Some(handle);
33    });
34    runtime
35}
36
37/// Runs a future to completion on the thread-local Tokio runtime.
38///
39/// Panics when the current thread has not been initialized with
40/// [`init_runtime`]. When called from within the same Tokio runtime, this
41/// temporarily yields the worker thread with `block_in_place` before re-entering
42/// the async context.
43pub fn block_on<F: Future>(future: F) -> F::Output {
44    with_handle(|handle| match Handle::try_current() {
45        Ok(current) if current.id() == handle.id() => {
46            tokio::task::block_in_place(|| handle.block_on(future))
47        }
48        _ => handle.block_on(future),
49    })
50}
51
52/// Spawns a task onto the thread-local Tokio runtime.
53///
54/// Panics when the current thread has not been initialized with [`init_runtime`].
55pub fn spawn<F>(future: F) -> tokio::task::JoinHandle<F::Output>
56where
57    F: Future + Send + 'static,
58    F::Output: Send + 'static,
59{
60    with_handle(|handle| handle.spawn(future))
61}
62
63/// Returns `true` when the current thread has an initialized Tokio runtime handle.
64pub fn is_initialized() -> bool {
65    RT_HANDLE.with(|cell| cell.borrow().is_some())
66}
67
#[cfg(test)]
mod tests {
    use std::{
        any::Any,
        sync::mpsc,
        thread,
        time::{Duration, Instant},
    };

    use super::{block_on, init_runtime, is_initialized, spawn};

    /// Runs `test` on a freshly spawned OS thread, so the thread-local runtime
    /// handle starts out unset. Any panic from `test` is re-raised on the
    /// calling thread.
    fn run_isolated<R>(test: impl FnOnce() -> R + Send + 'static) -> R
    where
        R: Send + 'static,
    {
        match thread::spawn(test).join() {
            Ok(result) => result,
            Err(payload) => std::panic::resume_unwind(payload),
        }
    }

    /// Extracts a human-readable message from a panic payload, which is
    /// either a `String` or a `&str`.
    fn panic_message(payload: Box<dyn Any + Send>) -> String {
        if let Some(message) = payload.downcast_ref::<String>() {
            message.clone()
        } else if let Some(message) = payload.downcast_ref::<&str>() {
            (*message).to_owned()
        } else {
            "non-string panic payload".to_owned()
        }
    }

    #[test]
    fn init_runtime_sets_initialized_to_true() {
        run_isolated(|| {
            assert!(!is_initialized());
            let _runtime = init_runtime();
            assert!(is_initialized());
        });
    }

    #[test]
    fn is_initialized_is_false_before_init() {
        run_isolated(|| {
            assert!(!is_initialized());
        });
    }

    #[test]
    fn initialization_is_per_thread() {
        run_isolated(|| {
            let _runtime = init_runtime();
            assert!(is_initialized());
            // A different thread must not observe this thread's registration.
            let other_thread_initialized = run_isolated(is_initialized);
            assert!(!other_thread_initialized);
            assert!(is_initialized());
        });
    }

    #[test]
    fn block_on_executes_simple_future() {
        run_isolated(|| {
            let _runtime = init_runtime();
            assert_eq!(block_on(async { 42 }), 42);
        });
    }

    #[test]
    fn block_on_propagates_result_errors() {
        run_isolated(|| {
            let _runtime = init_runtime();
            let result = block_on(async { Result::<(), &'static str>::Err("boom") });
            assert_eq!(result, Err("boom"));
        });
    }

    #[test]
    fn block_on_panics_with_clear_message_before_init() {
        let message = run_isolated(|| {
            let panic = std::panic::catch_unwind(|| {
                let _: i32 = block_on(async { 42 });
            })
            .expect_err("block_on should panic before init_runtime");
            panic_message(panic)
        });

        assert!(message.contains("init_runtime() must be called on this thread"));
    }

    #[test]
    fn spawn_panics_with_clear_message_before_init() {
        let message = run_isolated(|| {
            let panic = std::panic::catch_unwind(|| {
                let _join = spawn(async { 42_i32 });
            })
            .expect_err("spawn should panic before init_runtime");
            panic_message(panic)
        });

        assert!(message.contains("init_runtime() must be called on this thread"));
    }

    #[test]
    fn spawn_runs_task_to_completion() {
        run_isolated(|| {
            let _runtime = init_runtime();
            let join = spawn(async { 7_i32 * 6 });
            let value = block_on(async { join.await.expect("task should complete") });
            assert_eq!(value, 42);
        });
    }

    #[test]
    fn multiple_sequential_block_on_calls_work() {
        run_isolated(|| {
            let _runtime = init_runtime();
            assert_eq!(block_on(async { 1 }), 1);
            assert_eq!(block_on(async { 2 }), 2);
            assert_eq!(block_on(async { 3 }), 3);
        });
    }

    #[test]
    fn block_on_supports_sleeping_futures() {
        run_isolated(|| {
            let _runtime = init_runtime();
            let start = Instant::now();
            block_on(async {
                tokio::time::sleep(Duration::from_millis(10)).await;
            });
            assert!(start.elapsed() >= Duration::from_millis(10));
        });
    }

    #[test]
    fn block_on_reenters_the_same_runtime_inside_async_context() {
        run_isolated(|| {
            let runtime = init_runtime();
            let value = runtime.block_on(async { block_on(async { 21 * 2 }) });
            assert_eq!(value, 42);
        });
    }

    #[test]
    fn spawn_can_signal_back_to_sync_code() {
        run_isolated(|| {
            let _runtime = init_runtime();
            let (sender, receiver) = mpsc::channel();
            let join = spawn(async move {
                sender.send("done").expect("channel send should succeed");
            });
            block_on(async { join.await.expect("task should complete") });
            assert_eq!(
                receiver.recv().expect("channel receive should succeed"),
                "done"
            );
        });
    }
}