rustrails_support/
runtime.rs1use std::{cell::RefCell, future::Future};
2
3use tokio::runtime::{Builder, Handle, Runtime};
4
5const RUNTIME_NOT_INITIALIZED: &str = "rustrails_support::runtime::init_runtime() must be called on this thread before using runtime helpers";
6
7thread_local! {
8 static RT_HANDLE: RefCell<Option<Handle>> = const { RefCell::new(None) };
9}
10
11fn with_handle<R>(f: impl FnOnce(&Handle) -> R) -> R {
12 RT_HANDLE.with(|cell| {
13 let borrow = cell.borrow();
14 let handle = borrow
15 .as_ref()
16 .unwrap_or_else(|| panic!("{RUNTIME_NOT_INITIALIZED}"));
17 f(handle)
18 })
19}
20
21pub fn init_runtime() -> Runtime {
26 let runtime = match Builder::new_multi_thread().enable_all().build() {
27 Ok(runtime) => runtime,
28 Err(error) => panic!("failed to build Tokio runtime: {error}"),
29 };
30 let handle = runtime.handle().clone();
31 RT_HANDLE.with(|cell| {
32 *cell.borrow_mut() = Some(handle);
33 });
34 runtime
35}
36
37pub fn block_on<F: Future>(future: F) -> F::Output {
44 with_handle(|handle| match Handle::try_current() {
45 Ok(current) if current.id() == handle.id() => {
46 tokio::task::block_in_place(|| handle.block_on(future))
47 }
48 _ => handle.block_on(future),
49 })
50}
51
52pub fn spawn<F>(future: F) -> tokio::task::JoinHandle<F::Output>
56where
57 F: Future + Send + 'static,
58 F::Output: Send + 'static,
59{
60 with_handle(|handle| handle.spawn(future))
61}
62
63pub fn is_initialized() -> bool {
65 RT_HANDLE.with(|cell| cell.borrow().is_some())
66}
67
#[cfg(test)]
mod tests {
    use std::any::Any;
    use std::sync::mpsc;
    use std::thread;
    use std::time::{Duration, Instant};

    use super::{block_on, init_runtime, is_initialized, spawn};

    /// Runs `test` on a brand-new thread, so it starts with an empty
    /// thread-local runtime slot. Any panic on that thread is re-raised on the
    /// calling thread.
    fn run_isolated<R>(test: impl FnOnce() -> R + Send + 'static) -> R
    where
        R: Send + 'static,
    {
        thread::spawn(test)
            .join()
            .unwrap_or_else(|payload| std::panic::resume_unwind(payload))
    }

    /// Extracts the text of a panic payload (`String` or `&str`), or returns a
    /// placeholder for any other payload type.
    fn panic_message(payload: Box<dyn Any + Send>) -> String {
        match payload.downcast::<String>() {
            Ok(owned) => *owned,
            Err(other) => match other.downcast_ref::<&str>() {
                Some(text) => (*text).to_owned(),
                None => "non-string panic payload".to_owned(),
            },
        }
    }

    #[test]
    fn init_runtime_sets_initialized_to_true() {
        run_isolated(|| {
            assert!(!is_initialized());
            let _runtime = init_runtime();
            assert!(is_initialized());
        });
    }

    #[test]
    fn is_initialized_is_false_before_init() {
        run_isolated(|| assert!(!is_initialized()));
    }

    #[test]
    fn block_on_executes_simple_future() {
        run_isolated(|| {
            let _runtime = init_runtime();
            let answer = block_on(async { 42 });
            assert_eq!(answer, 42);
        });
    }

    #[test]
    fn block_on_propagates_result_errors() {
        run_isolated(|| {
            let _runtime = init_runtime();
            let outcome = block_on(async { Err::<(), &'static str>("boom") });
            assert_eq!(outcome, Err("boom"));
        });
    }

    #[test]
    fn block_on_panics_with_clear_message_before_init() {
        let message = run_isolated(|| {
            let outcome = std::panic::catch_unwind(|| {
                let _: i32 = block_on(async { 42 });
            });
            panic_message(outcome.expect_err("block_on should panic before init_runtime"))
        });

        assert!(message.contains("init_runtime() must be called on this thread"));
    }

    #[test]
    fn spawn_runs_task_to_completion() {
        run_isolated(|| {
            let _runtime = init_runtime();
            let task = spawn(async { 7_i32 * 6 });
            let product = block_on(async move { task.await.expect("task should complete") });
            assert_eq!(product, 42);
        });
    }

    #[test]
    fn multiple_sequential_block_on_calls_work() {
        run_isolated(|| {
            let _runtime = init_runtime();
            for expected in 1..=3 {
                assert_eq!(block_on(async move { expected }), expected);
            }
        });
    }

    #[test]
    fn block_on_supports_sleeping_futures() {
        run_isolated(|| {
            let _runtime = init_runtime();
            let pause = Duration::from_millis(10);
            let started = Instant::now();
            // The sleep future must be created inside the async block: building
            // it needs the runtime context that `block_on` provides.
            block_on(async {
                tokio::time::sleep(pause).await;
            });
            assert!(started.elapsed() >= pause);
        });
    }

    #[test]
    fn block_on_reenters_the_same_runtime_inside_async_context() {
        run_isolated(|| {
            let runtime = init_runtime();
            let nested = runtime.block_on(async { block_on(async { 21 * 2 }) });
            assert_eq!(nested, 42);
        });
    }

    #[test]
    fn spawn_can_signal_back_to_sync_code() {
        run_isolated(|| {
            let _runtime = init_runtime();
            let (tx, rx) = mpsc::channel();
            let task = spawn(async move {
                tx.send("done").expect("channel send should succeed");
            });
            block_on(async move { task.await.expect("task should complete") });
            let received = rx.recv().expect("channel receive should succeed");
            assert_eq!(received, "done");
        });
    }
}