1#![deny(missing_docs)]
6
7use spargio::{RuntimeError, RuntimeHandle};
8use std::io;
9use std::time::Duration;
10
/// Tuning knobs for the `*_blocking_with_options` helpers in this module.
///
/// The default value has no timeout, so callers wait for the blocking
/// closure for as long as it runs.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct BlockingOptions {
    /// Maximum time to wait for the blocking closure; `None` waits indefinitely.
    timeout: Option<Duration>,
}

impl BlockingOptions {
    /// Sets the maximum time to wait for the blocking closure and returns the
    /// updated options.
    ///
    /// The timeout only bounds how long the caller awaits the result; the
    /// helpers that consume these options do not interrupt the closure itself.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Returns the configured timeout, or `None` when the caller waits indefinitely.
    #[must_use]
    pub fn timeout(self) -> Option<Duration> {
        self.timeout
    }
}
29
30pub async fn tls_blocking<T, F>(handle: &RuntimeHandle, f: F) -> io::Result<T>
32where
33 T: Send + 'static,
34 F: FnOnce() -> io::Result<T> + Send + 'static,
35{
36 tls_blocking_with_options(handle, BlockingOptions::default(), f).await
37}
38
39pub async fn tls_blocking_with_options<T, F>(
41 handle: &RuntimeHandle,
42 options: BlockingOptions,
43 f: F,
44) -> io::Result<T>
45where
46 T: Send + 'static,
47 F: FnOnce() -> io::Result<T> + Send + 'static,
48{
49 run_blocking(
50 handle,
51 options,
52 f,
53 "tls blocking task canceled",
54 "tls blocking task timed out",
55 )
56 .await
57}
58
59pub async fn ws_blocking<T, F>(handle: &RuntimeHandle, f: F) -> io::Result<T>
61where
62 T: Send + 'static,
63 F: FnOnce() -> io::Result<T> + Send + 'static,
64{
65 ws_blocking_with_options(handle, BlockingOptions::default(), f).await
66}
67
68pub async fn ws_blocking_with_options<T, F>(
70 handle: &RuntimeHandle,
71 options: BlockingOptions,
72 f: F,
73) -> io::Result<T>
74where
75 T: Send + 'static,
76 F: FnOnce() -> io::Result<T> + Send + 'static,
77{
78 run_blocking(
79 handle,
80 options,
81 f,
82 "ws blocking task canceled",
83 "ws blocking task timed out",
84 )
85 .await
86}
87
88pub async fn quic_blocking<T, F>(handle: &RuntimeHandle, f: F) -> io::Result<T>
90where
91 T: Send + 'static,
92 F: FnOnce() -> io::Result<T> + Send + 'static,
93{
94 quic_blocking_with_options(handle, BlockingOptions::default(), f).await
95}
96
97pub async fn quic_blocking_with_options<T, F>(
99 handle: &RuntimeHandle,
100 options: BlockingOptions,
101 f: F,
102) -> io::Result<T>
103where
104 T: Send + 'static,
105 F: FnOnce() -> io::Result<T> + Send + 'static,
106{
107 run_blocking(
108 handle,
109 options,
110 f,
111 "quic blocking task canceled",
112 "quic blocking task timed out",
113 )
114 .await
115}
116
117async fn run_blocking<T, F>(
118 handle: &RuntimeHandle,
119 options: BlockingOptions,
120 f: F,
121 canceled_msg: &'static str,
122 timeout_msg: &'static str,
123) -> io::Result<T>
124where
125 T: Send + 'static,
126 F: FnOnce() -> io::Result<T> + Send + 'static,
127{
128 let join = handle
129 .spawn_blocking(f)
130 .map_err(runtime_error_to_io_for_blocking)?;
131 let joined = match options.timeout() {
132 Some(duration) => match spargio::timeout(duration, join).await {
133 Ok(result) => result,
134 Err(_) => return Err(io::Error::new(io::ErrorKind::TimedOut, timeout_msg)),
135 },
136 None => join.await,
137 };
138 joined.map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, canceled_msg))?
139}
140
141fn runtime_error_to_io_for_blocking(err: RuntimeError) -> io::Error {
142 match err {
143 RuntimeError::InvalidConfig(msg) => io::Error::new(io::ErrorKind::InvalidInput, msg),
144 RuntimeError::ThreadSpawn(io) => io,
145 RuntimeError::InvalidShard(shard) => {
146 io::Error::new(io::ErrorKind::NotFound, format!("invalid shard {shard}"))
147 }
148 RuntimeError::Closed => io::Error::new(io::ErrorKind::BrokenPipe, "runtime closed"),
149 RuntimeError::Overloaded => io::Error::new(io::ErrorKind::WouldBlock, "runtime overloaded"),
150 RuntimeError::UnsupportedBackend(msg) => io::Error::new(io::ErrorKind::Unsupported, msg),
151 RuntimeError::IoUringInit(io) => io,
152 }
153}
154
#[cfg(all(feature = "uring-native", target_os = "linux"))]
pub mod io_compat {
    //! Adapters that expose spargio's owned-buffer networking types through the
    //! `futures::io` [`AsyncRead`] / [`AsyncWrite`] traits.

    use futures::io::{AsyncRead, AsyncWrite};
    use spargio::net::TcpStream;
    use std::future::Future;
    use std::io;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// In-flight receive: resolves to the reported byte count and the owned
    /// buffer that was handed to `recv_owned`.
    type ReadOp = Pin<Box<dyn Future<Output = io::Result<(usize, Vec<u8>)>> + Send + 'static>>;
    /// In-flight send: resolves to the number of payload bytes written.
    type WriteOp = Pin<Box<dyn Future<Output = io::Result<usize>> + Send + 'static>>;

    /// Wraps a spargio [`TcpStream`] so it can be driven through
    /// [`AsyncRead`] / [`AsyncWrite`].
    ///
    /// When `poll_read` or `poll_write` finds no operation in flight, it clones
    /// the stream handle and starts one owned-buffer operation (`recv_owned` or
    /// `send_owned`). Later polls drive that same operation until it completes.
    ///
    /// # Caveats
    ///
    /// - A write that returned `Poll::Pending` should be re-polled with the same
    ///   bytes. The in-flight send still carries the payload copied on the first
    ///   call.
    /// - `poll_flush` and `poll_close` complete immediately. They neither wait
    ///   for an in-flight write nor shut the socket down.
    pub struct FuturesTcpStream {
        inner: TcpStream,
        read_op: Option<ReadOp>,
        write_op: Option<WriteOp>,
        /// Bytes from a completed receive that did not fit into the caller's
        /// buffer. They are served to later reads before a new receive starts.
        read_buf: Vec<u8>,
        /// Offset of the first unconsumed byte in `read_buf`
        /// (invariant: `read_pos <= read_buf.len()`).
        read_pos: usize,
    }

    impl std::fmt::Debug for FuturesTcpStream {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("FuturesTcpStream")
                .field("fd", &self.inner.as_raw_fd())
                .field("session_shard", &self.inner.session_shard())
                .finish()
        }
    }

    impl FuturesTcpStream {
        /// Wraps `inner` with no operations in flight and nothing buffered.
        pub fn new(inner: TcpStream) -> Self {
            Self {
                inner,
                read_op: None,
                write_op: None,
                read_buf: Vec::new(),
                read_pos: 0,
            }
        }

        /// Borrows the wrapped stream.
        pub fn get_ref(&self) -> &TcpStream {
            &self.inner
        }

        /// Unwraps the stream.
        ///
        /// Any in-flight read or write operation is dropped. So are bytes that
        /// were received but not yet returned by `poll_read`.
        pub fn into_inner(self) -> TcpStream {
            self.inner
        }

        /// Copies as many leftover received bytes as fit into `dst` and returns
        /// how many were copied. Resets the leftover buffer once it is drained.
        fn drain_buffered(&mut self, dst: &mut [u8]) -> usize {
            let pending = &self.read_buf[self.read_pos..];
            let n = pending.len().min(dst.len());
            dst[..n].copy_from_slice(&pending[..n]);
            self.read_pos += n;
            if self.read_pos == self.read_buf.len() {
                self.read_buf.clear();
                self.read_pos = 0;
            }
            n
        }
    }

    // Every field is movable (boxed futures, plain buffers), so pinning
    // imposes no constraints and `Pin::get_mut` is available in the poll methods.
    impl Unpin for FuturesTcpStream {}

    impl AsyncRead for FuturesTcpStream {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            if buf.is_empty() {
                return Poll::Ready(Ok(0));
            }
            let this = self.get_mut();

            // Hand out leftovers from an earlier receive before touching the socket.
            if this.read_pos < this.read_buf.len() {
                return Poll::Ready(Ok(this.drain_buffered(buf)));
            }

            if this.read_op.is_none() {
                let inner = this.inner.clone();
                let want = buf.len();
                this.read_op = Some(Box::pin(
                    async move { inner.recv_owned(vec![0u8; want]).await },
                ));
            }

            let polled = this
                .read_op
                .as_mut()
                .expect("read op was just installed")
                .as_mut()
                .poll(cx);
            let result = match polled {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(result) => result,
            };
            this.read_op = None;

            let (got, payload) = result?;
            // Never trust the reported count beyond the buffer actually returned.
            let got = got.min(payload.len());
            let n = got.min(buf.len());
            buf[..n].copy_from_slice(&payload[..n]);
            if n < got {
                // The receive was sized for the buffer passed on an earlier poll,
                // and this poll's buffer is smaller. Keep the remainder rather than
                // discarding bytes that were already taken off the socket.
                let mut rest = payload;
                rest.truncate(got);
                this.read_buf = rest;
                this.read_pos = n;
            }
            Poll::Ready(Ok(n))
        }
    }

    impl AsyncWrite for FuturesTcpStream {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            if buf.is_empty() {
                return Poll::Ready(Ok(0));
            }
            let this = self.get_mut();

            if this.write_op.is_none() {
                // `send_owned` needs an owned buffer, so the caller's bytes are copied.
                let inner = this.inner.clone();
                let payload = buf.to_vec();
                let payload_len = payload.len();
                this.write_op = Some(Box::pin(async move {
                    let (written, _) = inner.send_owned(payload).await?;
                    // Never report more bytes than were handed to the send.
                    Ok(written.min(payload_len))
                }));
            }

            let polled = this
                .write_op
                .as_mut()
                .expect("write op was just installed")
                .as_mut()
                .poll(cx);
            if polled.is_ready() {
                this.write_op = None;
            }
            polled
        }

        /// Completes immediately; no write-side buffering is done here.
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        /// Completes immediately without shutting the socket down.
        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }
}
288
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use std::time::Duration;

    #[test]
    fn protocol_blocking_helpers_execute_closure() {
        let runtime = spargio::Runtime::builder()
            .shards(1)
            .build()
            .expect("runtime");
        let handle = runtime.handle();

        // Each protocol wrapper should hand the closure's value straight back.
        let tls = block_on(tls_blocking(&handle, || Ok::<_, io::Error>(11usize))).expect("tls");
        let ws = block_on(ws_blocking(&handle, || Ok::<_, io::Error>(22usize))).expect("ws");
        let quic =
            block_on(quic_blocking(&handle, || Ok::<_, io::Error>(33usize))).expect("quic");

        assert_eq!(tls + ws + quic, 66);
    }

    #[test]
    fn blocking_timeout_returns_timed_out() {
        let runtime = spargio::Runtime::builder()
            .shards(1)
            .build()
            .expect("runtime");
        let handle = runtime.handle();

        // The closure outlives the timeout, so the helper must give up waiting.
        let options = BlockingOptions::default().with_timeout(Duration::from_millis(5));
        let slow_job = || {
            std::thread::sleep(Duration::from_millis(30));
            Ok::<(), io::Error>(())
        };

        let err = block_on(tls_blocking_with_options(&handle, options, slow_job))
            .expect_err("timeout");
        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
    }
}