1use std::future::Future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use cfg_if::cfg_if;
8
9#[cfg(feature = "_rt-async-io")]
10pub mod rt_async_io;
11
12#[cfg(feature = "_rt-tokio")]
13pub mod rt_tokio;
14
15#[derive(Debug, thiserror::Error)]
16#[error("operation timed out")]
17pub struct TimeoutError;
18
/// Runtime-agnostic handle to a task started by [`spawn`] or [`spawn_blocking`].
///
/// Awaiting the handle yields the task's output (see the `Future` impl below). For the
/// `AsyncTask` variant, dropping the handle explicitly detaches the task (see the `Drop`
/// impl below).
pub enum JoinHandle<T> {
    /// Handle produced by `async_std::task::spawn` / `async_std::task::spawn_blocking`.
    #[cfg(feature = "_rt-async-std")]
    AsyncStd(async_std::task::JoinHandle<T>),

    /// Handle produced by Tokio's `Handle::spawn` / `Handle::spawn_blocking`.
    #[cfg(feature = "_rt-tokio")]
    Tokio(tokio::task::JoinHandle<T>),

    /// Task produced by `async-global-executor` or `smol`, which both return
    /// `async_task::Task`.
    ///
    /// Wrapped in `Option` so `Drop` can move the task out and call `detach()` on it.
    #[cfg(feature = "_rt-async-task")]
    AsyncTask(Option<async_task::Task<T>>),

    /// Not constructed by `spawn` / `spawn_blocking`; it keeps `T` in use when no runtime
    /// feature is enabled (otherwise the type parameter would be unused).
    /// `fn() -> T` is used instead of `T` so the marker is always `Send`, `Sync` and `Unpin`,
    /// independent of `T`.
    _Phantom(PhantomData<fn() -> T>),
}
33
34pub async fn timeout<F: Future>(duration: Duration, f: F) -> Result<F::Output, TimeoutError> {
35    #[cfg(debug_assertions)]
36    let f = Box::pin(f);
37
38    #[cfg(feature = "_rt-tokio")]
39    if rt_tokio::available() {
40        return tokio::time::timeout(duration, f)
41            .await
42            .map_err(|_| TimeoutError);
43    }
44
45    cfg_if! {
46        if #[cfg(feature = "_rt-async-io")] {
47            rt_async_io::timeout(duration, f).await
48        } else {
49            missing_rt((duration, f))
50        }
51    }
52}
53
54pub async fn sleep(duration: Duration) {
55    #[cfg(feature = "_rt-tokio")]
56    if rt_tokio::available() {
57        return tokio::time::sleep(duration).await;
58    }
59
60    cfg_if! {
61        if #[cfg(feature = "_rt-async-io")] {
62            rt_async_io::sleep(duration).await
63        } else {
64            missing_rt(duration)
65        }
66    }
67}
68
69#[track_caller]
70pub fn spawn<F>(fut: F) -> JoinHandle<F::Output>
71where
72    F: Future + Send + 'static,
73    F::Output: Send + 'static,
74{
75    #[cfg(feature = "_rt-tokio")]
76    if let Ok(handle) = tokio::runtime::Handle::try_current() {
77        return JoinHandle::Tokio(handle.spawn(fut));
78    }
79
80    cfg_if! {
81        if #[cfg(feature = "_rt-async-global-executor")] {
82            JoinHandle::AsyncTask(Some(async_global_executor::spawn(fut)))
83        } else if #[cfg(feature = "_rt-smol")] {
84            JoinHandle::AsyncTask(Some(smol::spawn(fut)))
85        } else if #[cfg(feature = "_rt-async-std")] {
86            JoinHandle::AsyncStd(async_std::task::spawn(fut))
87        } else {
88            missing_rt(fut)
89        }
90    }
91}
92
93#[track_caller]
94pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<R>
95where
96    F: FnOnce() -> R + Send + 'static,
97    R: Send + 'static,
98{
99    #[cfg(feature = "_rt-tokio")]
100    if let Ok(handle) = tokio::runtime::Handle::try_current() {
101        return JoinHandle::Tokio(handle.spawn_blocking(f));
102    }
103
104    cfg_if! {
105        if #[cfg(feature = "_rt-async-global-executor")] {
106            JoinHandle::AsyncTask(Some(async_global_executor::spawn_blocking(f)))
107        } else if #[cfg(feature = "_rt-smol")] {
108            JoinHandle::AsyncTask(Some(smol::unblock(f)))
109        } else if #[cfg(feature = "_rt-async-std")] {
110            JoinHandle::AsyncStd(async_std::task::spawn_blocking(f))
111        } else {
112            missing_rt(f)
113        }
114    }
115}
116
/// Yields control back to the executor once.
///
/// When `rt_tokio::available()` reports a Tokio runtime, this delegates to
/// `tokio::task::yield_now`. Otherwise a portable fallback is used: the first poll wakes the
/// current task's waker and returns `Pending`, and the following poll returns `Ready(())`.
pub async fn yield_now() {
    #[cfg(feature = "_rt-tokio")]
    if rt_tokio::available() {
        return tokio::task::yield_now().await;
    }

    // Self-wake and then pend. This asks the executor to reschedule the task, with no
    // runtime-specific integration needed.
    let mut already_yielded = false;

    std::future::poll_fn(|cx| {
        if already_yielded {
            return Poll::Ready(());
        }

        already_yielded = true;
        cx.waker().wake_by_ref();
        Poll::Pending
    })
    .await
}
145
146#[track_caller]
147pub fn test_block_on<F: Future>(f: F) -> F::Output {
148    cfg_if! {
149        if #[cfg(feature = "_rt-async-io")] {
150            async_io::block_on(f)
151        } else if #[cfg(feature = "_rt-tokio")] {
152            tokio::runtime::Builder::new_current_thread()
153                .enable_all()
154                .build()
155                .expect("failed to start Tokio runtime")
156                .block_on(f)
157        } else {
158            missing_rt(f)
159        }
160    }
161}
162
/// Diverges with a message explaining which runtime support is missing.
///
/// The argument is accepted only so callers can hand over values they would otherwise leave
/// unused. If the `_rt-tokio` feature is compiled in, the message says a Tokio context is
/// required. Otherwise it says a `runtime` feature must be enabled.
///
/// # Panics
///
/// Always.
#[track_caller]
pub const fn missing_rt<T>(_unused: T) -> ! {
    if cfg!(feature = "_rt-tokio") {
        panic!("this functionality requires a Tokio context")
    } else {
        panic!("one of the `runtime` features of SQLx must be enabled")
    }
}
171
172impl<T: Send + 'static> Future for JoinHandle<T> {
173    type Output = T;
174
175    #[track_caller]
176    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
177        match &mut *self {
178            #[cfg(feature = "_rt-async-std")]
179            Self::AsyncStd(handle) => Pin::new(handle).poll(cx),
180
181            #[cfg(feature = "_rt-async-task")]
182            Self::AsyncTask(task) => Pin::new(task)
183                .as_pin_mut()
184                .expect("BUG: task taken")
185                .poll(cx),
186
187            #[cfg(feature = "_rt-tokio")]
188            Self::Tokio(handle) => Pin::new(handle)
189                .poll(cx)
190                .map(|res| res.expect("spawned task panicked")),
191
192            Self::_Phantom(_) => {
193                let _ = cx;
194                unreachable!("runtime should have been checked on spawn")
195            }
196        }
197    }
198}
199
200impl<T> Drop for JoinHandle<T> {
201    fn drop(&mut self) {
202        match self {
203            #[cfg(feature = "_rt-async-task")]
206            Self::AsyncTask(task) => {
207                if let Some(task) = task.take() {
208                    task.detach();
209                }
210            }
211            _ => (),
212        }
213    }
214}