sqlx_build_trust_core/rt/
mod.rs1use std::future::Future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7#[cfg(feature = "_rt-async-std")]
8pub mod rt_async_std;
9
10#[cfg(feature = "_rt-tokio")]
11pub mod rt_tokio;
12
13#[derive(Debug, thiserror::Error)]
14#[error("operation timed out")]
15pub struct TimeoutError(());
16
/// Handle to a task started by [`spawn`] or [`spawn_blocking`].
///
/// Awaiting the handle yields the task's output. The variant records which
/// runtime the task was handed to; each runtime variant only exists when the
/// corresponding `_rt-*` feature is enabled.
pub enum JoinHandle<T> {
    /// Task spawned on async-std.
    #[cfg(feature = "_rt-async-std")]
    AsyncStd(async_std::task::JoinHandle<T>),
    /// Task spawned on a Tokio runtime.
    #[cfg(feature = "_rt-tokio")]
    Tokio(tokio::task::JoinHandle<T>),
    /// Internal placeholder; never constructed by [`spawn`] or [`spawn_blocking`].
    ///
    /// It keeps the type parameter `T` in use when no runtime feature is
    /// enabled (an unused type parameter would not compile). `fn() -> T` is
    /// used so the marker is `Send + Sync` regardless of `T` and therefore does
    /// not influence the enum's auto traits.
    #[doc(hidden)]
    _Phantom(PhantomData<fn() -> T>),
}
25
26pub async fn timeout<F: Future>(duration: Duration, f: F) -> Result<F::Output, TimeoutError> {
27 #[cfg(feature = "_rt-tokio")]
28 if rt_tokio::available() {
29 return tokio::time::timeout(duration, f)
30 .await
31 .map_err(|_| TimeoutError(()));
32 }
33
34 #[cfg(feature = "_rt-async-std")]
35 {
36 return async_std::future::timeout(duration, f)
37 .await
38 .map_err(|_| TimeoutError(()));
39 }
40
41 #[cfg(not(feature = "_rt-async-std"))]
42 missing_rt((duration, f))
43}
44
45pub async fn sleep(duration: Duration) {
46 #[cfg(feature = "_rt-tokio")]
47 if rt_tokio::available() {
48 return tokio::time::sleep(duration).await;
49 }
50
51 #[cfg(feature = "_rt-async-std")]
52 {
53 return async_std::task::sleep(duration).await;
54 }
55
56 #[cfg(not(feature = "_rt-async-std"))]
57 missing_rt(duration)
58}
59
60#[track_caller]
61pub fn spawn<F>(fut: F) -> JoinHandle<F::Output>
62where
63 F: Future + Send + 'static,
64 F::Output: Send + 'static,
65{
66 #[cfg(feature = "_rt-tokio")]
67 if let Ok(handle) = tokio::runtime::Handle::try_current() {
68 return JoinHandle::Tokio(handle.spawn(fut));
69 }
70
71 #[cfg(feature = "_rt-async-std")]
72 {
73 return JoinHandle::AsyncStd(async_std::task::spawn(fut));
74 }
75
76 #[cfg(not(feature = "_rt-async-std"))]
77 missing_rt(fut)
78}
79
80#[track_caller]
81pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<R>
82where
83 F: FnOnce() -> R + Send + 'static,
84 R: Send + 'static,
85{
86 #[cfg(feature = "_rt-tokio")]
87 if let Ok(handle) = tokio::runtime::Handle::try_current() {
88 return JoinHandle::Tokio(handle.spawn_blocking(f));
89 }
90
91 #[cfg(feature = "_rt-async-std")]
92 {
93 return JoinHandle::AsyncStd(async_std::task::spawn_blocking(f));
94 }
95
96 #[cfg(not(feature = "_rt-async-std"))]
97 missing_rt(f)
98}
99
100pub async fn yield_now() {
101 #[cfg(feature = "_rt-tokio")]
102 if rt_tokio::available() {
103 return tokio::task::yield_now().await;
104 }
105
106 #[cfg(feature = "_rt-async-std")]
107 {
108 return async_std::task::yield_now().await;
109 }
110
111 #[cfg(not(feature = "_rt-async-std"))]
112 missing_rt(())
113}
114
/// Drives `f` to completion from synchronous code and returns its output.
///
/// Runtime selection:
/// - When Tokio is compiled in, it takes precedence. A new current-thread
///   Tokio runtime with all drivers enabled is built for this call.
/// - Otherwise, async-std's `task::block_on` is used.
///
/// # Panics
///
/// - If the Tokio runtime fails to build.
/// - If neither runtime feature is enabled. In that case `f` is dropped first.
#[track_caller]
pub fn test_block_on<F: Future>(f: F) -> F::Output {
    #[cfg(feature = "_rt-tokio")]
    {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("failed to start Tokio runtime");
        return runtime.block_on(f);
    }

    #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
    {
        return async_std::task::block_on(f);
    }

    #[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))]
    {
        // Consume the future explicitly so it is dropped before panicking.
        std::mem::drop(f);
        panic!("at least one of the `runtime-*` features must be enabled")
    }
}
137
/// Diverges with a diagnostic explaining why no async runtime could be used.
///
/// `_unused` accepts whatever values the caller would otherwise have consumed
/// (futures, closures, durations). Presumably this keeps those bindings "used"
/// in builds where the runtime-specific paths are compiled out.
///
/// # Panics
///
/// Always.
/// - With the `_rt-tokio` feature, the message says a Tokio context is
///   required.
/// - Without it, the message says a runtime feature must be enabled.
#[track_caller]
pub fn missing_rt<T>(_unused: T) -> ! {
    // No Tokio compiled in: the build simply lacks a usable runtime feature.
    if !cfg!(feature = "_rt-tokio") {
        panic!("either the `runtime-async-std` or `runtime-tokio` feature must be enabled")
    }

    // Tokio is compiled in, so the caller was outside a Tokio context.
    panic!("this functionality requires a Tokio context")
}
146
147impl<T: Send + 'static> Future for JoinHandle<T> {
148 type Output = T;
149
150 #[track_caller]
151 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
152 match &mut *self {
153 #[cfg(feature = "_rt-async-std")]
154 Self::AsyncStd(handle) => Pin::new(handle).poll(cx),
155 #[cfg(feature = "_rt-tokio")]
156 Self::Tokio(handle) => Pin::new(handle)
157 .poll(cx)
158 .map(|res| res.expect("spawned task panicked")),
159 Self::_Phantom(_) => {
160 let _ = cx;
161 unreachable!("runtime should have been checked on spawn")
162 }
163 }
164 }
165}