// sqlx_build_trust_core/rt/mod.rs

1use std::future::Future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7#[cfg(feature = "_rt-async-std")]
8pub mod rt_async_std;
9
10#[cfg(feature = "_rt-tokio")]
11pub mod rt_tokio;
12
13#[derive(Debug, thiserror::Error)]
14#[error("operation timed out")]
15pub struct TimeoutError(());
16
/// Handle to a task started by [`spawn`] or [`spawn_blocking`], tagged with the
/// runtime the task was actually spawned on.
///
/// Awaiting the handle yields the task's output (`T`).
pub enum JoinHandle<T> {
    /// The task runs on the async-std runtime.
    #[cfg(feature = "_rt-async-std")]
    AsyncStd(async_std::task::JoinHandle<T>),
    /// The task runs on a Tokio runtime.
    #[cfg(feature = "_rt-tokio")]
    Tokio(tokio::task::JoinHandle<T>),
    /// Placeholder that keeps `T` in use even when no runtime variant is
    /// compiled in. The spawn functions never construct it.
    ///
    /// The marker is `fn() -> T` rather than `T` so that `JoinHandle<T>` is
    /// `Unpin` for every `T`; a plain `PhantomData<T>` would only be `Unpin`
    /// when `T: Unpin`.
    _Phantom(PhantomData<fn() -> T>),
}
25
26pub async fn timeout<F: Future>(duration: Duration, f: F) -> Result<F::Output, TimeoutError> {
27    #[cfg(feature = "_rt-tokio")]
28    if rt_tokio::available() {
29        return tokio::time::timeout(duration, f)
30            .await
31            .map_err(|_| TimeoutError(()));
32    }
33
34    #[cfg(feature = "_rt-async-std")]
35    {
36        return async_std::future::timeout(duration, f)
37            .await
38            .map_err(|_| TimeoutError(()));
39    }
40
41    #[cfg(not(feature = "_rt-async-std"))]
42    missing_rt((duration, f))
43}
44
45pub async fn sleep(duration: Duration) {
46    #[cfg(feature = "_rt-tokio")]
47    if rt_tokio::available() {
48        return tokio::time::sleep(duration).await;
49    }
50
51    #[cfg(feature = "_rt-async-std")]
52    {
53        return async_std::task::sleep(duration).await;
54    }
55
56    #[cfg(not(feature = "_rt-async-std"))]
57    missing_rt(duration)
58}
59
60#[track_caller]
61pub fn spawn<F>(fut: F) -> JoinHandle<F::Output>
62where
63    F: Future + Send + 'static,
64    F::Output: Send + 'static,
65{
66    #[cfg(feature = "_rt-tokio")]
67    if let Ok(handle) = tokio::runtime::Handle::try_current() {
68        return JoinHandle::Tokio(handle.spawn(fut));
69    }
70
71    #[cfg(feature = "_rt-async-std")]
72    {
73        return JoinHandle::AsyncStd(async_std::task::spawn(fut));
74    }
75
76    #[cfg(not(feature = "_rt-async-std"))]
77    missing_rt(fut)
78}
79
80#[track_caller]
81pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<R>
82where
83    F: FnOnce() -> R + Send + 'static,
84    R: Send + 'static,
85{
86    #[cfg(feature = "_rt-tokio")]
87    if let Ok(handle) = tokio::runtime::Handle::try_current() {
88        return JoinHandle::Tokio(handle.spawn_blocking(f));
89    }
90
91    #[cfg(feature = "_rt-async-std")]
92    {
93        return JoinHandle::AsyncStd(async_std::task::spawn_blocking(f));
94    }
95
96    #[cfg(not(feature = "_rt-async-std"))]
97    missing_rt(f)
98}
99
100pub async fn yield_now() {
101    #[cfg(feature = "_rt-tokio")]
102    if rt_tokio::available() {
103        return tokio::task::yield_now().await;
104    }
105
106    #[cfg(feature = "_rt-async-std")]
107    {
108        return async_std::task::yield_now().await;
109    }
110
111    #[cfg(not(feature = "_rt-async-std"))]
112    missing_rt(())
113}
114
/// Blocks the current thread until `f` completes. Intended for tests.
///
/// Exactly one of the three branches below is compiled in:
/// - with `_rt-tokio`: a fresh current-thread Tokio runtime with all drivers
///   enabled (`enable_all`) runs `f`;
/// - with only `_rt-async-std`: `async_std::task::block_on` runs `f`;
/// - with neither: `f` is dropped and the call panics.
///
/// # Panics
///
/// Panics if the Tokio runtime fails to build, or if no runtime feature is enabled.
#[track_caller]
pub fn test_block_on<F: Future>(f: F) -> F::Output {
    #[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))]
    {
        // Nothing can poll the future; release it before reporting the problem.
        drop(f);
        panic!("at least one of the `runtime-*` features must be enabled")
    }

    #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
    {
        return async_std::task::block_on(f);
    }

    #[cfg(feature = "_rt-tokio")]
    {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("failed to start Tokio runtime");
        return runtime.block_on(f);
    }
}
137
/// Panics with a message explaining which async runtime support is missing.
///
/// The argument is taken by value and discarded. Callers pass in whatever they
/// received (futures, closures, durations) so those values still count as used
/// in builds where every runtime-specific branch is compiled out.
///
/// # Panics
///
/// Always. When the `_rt-tokio` feature is compiled in, the message says a
/// Tokio context is required. Otherwise it says a runtime feature must be
/// enabled.
#[track_caller]
pub fn missing_rt<T>(_unused: T) -> ! {
    if cfg!(feature = "_rt-tokio") {
        // Tokio support exists, but the caller is outside any Tokio runtime.
        panic!("this functionality requires a Tokio context")
    } else {
        panic!("either the `runtime-async-std` or `runtime-tokio` feature must be enabled")
    }
}
146
147impl<T: Send + 'static> Future for JoinHandle<T> {
148    type Output = T;
149
150    #[track_caller]
151    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
152        match &mut *self {
153            #[cfg(feature = "_rt-async-std")]
154            Self::AsyncStd(handle) => Pin::new(handle).poll(cx),
155            #[cfg(feature = "_rt-tokio")]
156            Self::Tokio(handle) => Pin::new(handle)
157                .poll(cx)
158                .map(|res| res.expect("spawned task panicked")),
159            Self::_Phantom(_) => {
160                let _ = cx;
161                unreachable!("runtime should have been checked on spawn")
162            }
163        }
164    }
165}