1use std::future::Future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use cfg_if::cfg_if;
8
9#[cfg(feature = "_rt-async-io")]
10pub mod rt_async_io;
11
12#[cfg(feature = "_rt-tokio")]
13pub mod rt_tokio;
14
15#[derive(Debug, thiserror::Error)]
16#[error("operation timed out")]
17pub struct TimeoutError;
18
/// Handle to a task started with [`spawn`] or [`spawn_blocking`].
///
/// Each variant wraps the native handle of one runtime backend and only exists
/// when the matching Cargo feature is enabled. Awaiting the handle yields the
/// task's output.
pub enum JoinHandle<T> {
    /// Task running on a Tokio runtime.
    #[cfg(feature = "_rt-tokio")]
    Tokio(tokio::task::JoinHandle<T>),

    /// Task running on `async-std`.
    #[cfg(feature = "_rt-async-std")]
    AsyncStd(async_std::task::JoinHandle<T>),

    /// Task backed by `async_task`; used by both the `smol` and
    /// `async-global-executor` spawn paths.
    ///
    /// Wrapped in an `Option` so that `Drop` can take ownership of the task
    /// and detach it.
    #[cfg(feature = "_rt-async-task")]
    AsyncTask(Option<async_task::Task<T>>),

    /// Keeps `T` in use when no runtime feature is enabled.
    ///
    /// The marker is `fn() -> T` rather than `T`, so it is `Send`, `Sync` and
    /// `Unpin` no matter what `T` is.
    _Phantom(PhantomData<fn() -> T>),
}
33
34pub async fn timeout<F: Future>(duration: Duration, f: F) -> Result<F::Output, TimeoutError> {
35 #[cfg(debug_assertions)]
36 let f = Box::pin(f);
37
38 #[cfg(feature = "_rt-tokio")]
39 if rt_tokio::available() {
40 return tokio::time::timeout(duration, f)
41 .await
42 .map_err(|_| TimeoutError);
43 }
44
45 cfg_if! {
46 if #[cfg(feature = "_rt-async-io")] {
47 rt_async_io::timeout(duration, f).await
48 } else {
49 missing_rt((duration, f))
50 }
51 }
52}
53
54pub async fn sleep(duration: Duration) {
55 #[cfg(feature = "_rt-tokio")]
56 if rt_tokio::available() {
57 return tokio::time::sleep(duration).await;
58 }
59
60 cfg_if! {
61 if #[cfg(feature = "_rt-async-io")] {
62 rt_async_io::sleep(duration).await
63 } else {
64 missing_rt(duration)
65 }
66 }
67}
68
69#[track_caller]
70pub fn spawn<F>(fut: F) -> JoinHandle<F::Output>
71where
72 F: Future + Send + 'static,
73 F::Output: Send + 'static,
74{
75 #[cfg(feature = "_rt-tokio")]
76 if let Ok(handle) = tokio::runtime::Handle::try_current() {
77 return JoinHandle::Tokio(handle.spawn(fut));
78 }
79
80 cfg_if! {
81 if #[cfg(feature = "_rt-async-global-executor")] {
82 JoinHandle::AsyncTask(Some(async_global_executor::spawn(fut)))
83 } else if #[cfg(feature = "_rt-smol")] {
84 JoinHandle::AsyncTask(Some(smol::spawn(fut)))
85 } else if #[cfg(feature = "_rt-async-std")] {
86 JoinHandle::AsyncStd(async_std::task::spawn(fut))
87 } else {
88 missing_rt(fut)
89 }
90 }
91}
92
93#[track_caller]
94pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<R>
95where
96 F: FnOnce() -> R + Send + 'static,
97 R: Send + 'static,
98{
99 #[cfg(feature = "_rt-tokio")]
100 if let Ok(handle) = tokio::runtime::Handle::try_current() {
101 return JoinHandle::Tokio(handle.spawn_blocking(f));
102 }
103
104 cfg_if! {
105 if #[cfg(feature = "_rt-async-global-executor")] {
106 JoinHandle::AsyncTask(Some(async_global_executor::spawn_blocking(f)))
107 } else if #[cfg(feature = "_rt-smol")] {
108 JoinHandle::AsyncTask(Some(smol::unblock(f)))
109 } else if #[cfg(feature = "_rt-async-std")] {
110 JoinHandle::AsyncStd(async_std::task::spawn_blocking(f))
111 } else {
112 missing_rt(f)
113 }
114 }
115}
116
/// Yields control back to the executor once.
///
/// With the `_rt-tokio` feature, if `rt_tokio::available()` returns `true`,
/// this delegates to `tokio::task::yield_now`. Otherwise a runtime-agnostic
/// fallback is used: the first poll wakes the task's own waker and returns
/// `Poll::Pending`, and the following poll returns `Poll::Ready(())`.
pub async fn yield_now() {
    #[cfg(feature = "_rt-tokio")]
    if rt_tokio::available() {
        tokio::task::yield_now().await;
        return;
    }

    // Flips to `true` after the first poll so the second poll completes.
    let mut already_yielded = false;

    std::future::poll_fn(|cx| {
        if already_yielded {
            return Poll::Ready(());
        }

        already_yielded = true;
        // Request an immediate re-poll before reporting `Pending`;
        // without the wake-up the task would never be scheduled again.
        cx.waker().wake_by_ref();
        Poll::Pending
    })
    .await
}
145
146#[track_caller]
147pub fn test_block_on<F: Future>(f: F) -> F::Output {
148 cfg_if! {
149 if #[cfg(feature = "_rt-async-io")] {
150 async_io::block_on(f)
151 } else if #[cfg(feature = "_rt-tokio")] {
152 tokio::runtime::Builder::new_current_thread()
153 .enable_all()
154 .build()
155 .expect("failed to start Tokio runtime")
156 .block_on(f)
157 } else {
158 missing_rt(f)
159 }
160 }
161}
162
/// Aborts with a panic explaining that no usable async runtime was found.
///
/// The message depends on how the crate was compiled: when the `_rt-tokio`
/// feature is enabled it says that a Tokio context is required; otherwise it
/// says that one of the `runtime` features must be enabled.
///
/// `_unused` is accepted and ignored.
/// NOTE(review): presumably this lets call sites pass along arguments that
/// would otherwise be flagged as unused when no runtime is compiled in --
/// confirm.
///
/// # Panics
///
/// Always.
#[track_caller]
pub const fn missing_rt<T>(_unused: T) -> ! {
    if cfg!(feature = "_rt-tokio") {
        panic!("this functionality requires a Tokio context")
    } else {
        panic!("one of the `runtime` features of SQLx must be enabled")
    }
}
171
172impl<T: Send + 'static> Future for JoinHandle<T> {
173 type Output = T;
174
175 #[track_caller]
176 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
177 match &mut *self {
178 #[cfg(feature = "_rt-async-std")]
179 Self::AsyncStd(handle) => Pin::new(handle).poll(cx),
180
181 #[cfg(feature = "_rt-async-task")]
182 Self::AsyncTask(task) => Pin::new(task)
183 .as_pin_mut()
184 .expect("BUG: task taken")
185 .poll(cx),
186
187 #[cfg(feature = "_rt-tokio")]
188 Self::Tokio(handle) => Pin::new(handle)
189 .poll(cx)
190 .map(|res| res.expect("spawned task panicked")),
191
192 Self::_Phantom(_) => {
193 let _ = cx;
194 unreachable!("runtime should have been checked on spawn")
195 }
196 }
197 }
198}
199
200impl<T> Drop for JoinHandle<T> {
201 fn drop(&mut self) {
202 match self {
203 #[cfg(feature = "_rt-async-task")]
206 Self::AsyncTask(task) => {
207 if let Some(task) = task.take() {
208 task.detach();
209 }
210 }
211 _ => (),
212 }
213 }
214}