Skip to main content

spargio_protocols/
lib.rs

1//! Protocol integration companion APIs for spargio runtimes.
2//!
3//! These helpers provide explicit blocking bridges intended for TLS/WS/QUIC
4//! ecosystem integrations that do not natively target spargio executors.
5#![deny(missing_docs)]
6
7use spargio::{RuntimeError, RuntimeHandle};
8use std::io;
9use std::time::Duration;
10
/// Shared options for protocol blocking bridge helpers.
///
/// The default value carries no timeout, so the helpers wait for the blocking
/// closure for as long as it takes.
#[derive(Debug, Clone, Copy, Default)]
pub struct BlockingOptions {
    timeout: Option<Duration>,
}

impl BlockingOptions {
    /// Sets a timeout for the blocking operation.
    ///
    /// The timeout bounds how long the caller waits for the result. It cannot
    /// interrupt a closure that is already running; it only stops waiting for it.
    ///
    /// This consumes and returns `self`; the original value is not modified, so
    /// the returned options must be used.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Returns the configured timeout, or `None` when no timeout is set.
    #[must_use]
    pub fn timeout(self) -> Option<Duration> {
        self.timeout
    }
}
29
30/// Executes a TLS-related blocking closure on Spargio's blocking lane.
31pub async fn tls_blocking<T, F>(handle: &RuntimeHandle, f: F) -> io::Result<T>
32where
33    T: Send + 'static,
34    F: FnOnce() -> io::Result<T> + Send + 'static,
35{
36    tls_blocking_with_options(handle, BlockingOptions::default(), f).await
37}
38
39/// Executes a TLS-related blocking closure with explicit options.
40pub async fn tls_blocking_with_options<T, F>(
41    handle: &RuntimeHandle,
42    options: BlockingOptions,
43    f: F,
44) -> io::Result<T>
45where
46    T: Send + 'static,
47    F: FnOnce() -> io::Result<T> + Send + 'static,
48{
49    run_blocking(
50        handle,
51        options,
52        f,
53        "tls blocking task canceled",
54        "tls blocking task timed out",
55    )
56    .await
57}
58
59/// Executes a WebSocket-related blocking closure on Spargio's blocking lane.
60pub async fn ws_blocking<T, F>(handle: &RuntimeHandle, f: F) -> io::Result<T>
61where
62    T: Send + 'static,
63    F: FnOnce() -> io::Result<T> + Send + 'static,
64{
65    ws_blocking_with_options(handle, BlockingOptions::default(), f).await
66}
67
68/// Executes a WebSocket-related blocking closure with explicit options.
69pub async fn ws_blocking_with_options<T, F>(
70    handle: &RuntimeHandle,
71    options: BlockingOptions,
72    f: F,
73) -> io::Result<T>
74where
75    T: Send + 'static,
76    F: FnOnce() -> io::Result<T> + Send + 'static,
77{
78    run_blocking(
79        handle,
80        options,
81        f,
82        "ws blocking task canceled",
83        "ws blocking task timed out",
84    )
85    .await
86}
87
88/// Executes a QUIC-related blocking closure on Spargio's blocking lane.
89pub async fn quic_blocking<T, F>(handle: &RuntimeHandle, f: F) -> io::Result<T>
90where
91    T: Send + 'static,
92    F: FnOnce() -> io::Result<T> + Send + 'static,
93{
94    quic_blocking_with_options(handle, BlockingOptions::default(), f).await
95}
96
97/// Executes a QUIC-related blocking closure with explicit options.
98pub async fn quic_blocking_with_options<T, F>(
99    handle: &RuntimeHandle,
100    options: BlockingOptions,
101    f: F,
102) -> io::Result<T>
103where
104    T: Send + 'static,
105    F: FnOnce() -> io::Result<T> + Send + 'static,
106{
107    run_blocking(
108        handle,
109        options,
110        f,
111        "quic blocking task canceled",
112        "quic blocking task timed out",
113    )
114    .await
115}
116
117async fn run_blocking<T, F>(
118    handle: &RuntimeHandle,
119    options: BlockingOptions,
120    f: F,
121    canceled_msg: &'static str,
122    timeout_msg: &'static str,
123) -> io::Result<T>
124where
125    T: Send + 'static,
126    F: FnOnce() -> io::Result<T> + Send + 'static,
127{
128    let join = handle
129        .spawn_blocking(f)
130        .map_err(runtime_error_to_io_for_blocking)?;
131    let joined = match options.timeout() {
132        Some(duration) => match spargio::timeout(duration, join).await {
133            Ok(result) => result,
134            Err(_) => return Err(io::Error::new(io::ErrorKind::TimedOut, timeout_msg)),
135        },
136        None => join.await,
137    };
138    joined.map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, canceled_msg))?
139}
140
141fn runtime_error_to_io_for_blocking(err: RuntimeError) -> io::Error {
142    match err {
143        RuntimeError::InvalidConfig(msg) => io::Error::new(io::ErrorKind::InvalidInput, msg),
144        RuntimeError::ThreadSpawn(io) => io,
145        RuntimeError::InvalidShard(shard) => {
146            io::Error::new(io::ErrorKind::NotFound, format!("invalid shard {shard}"))
147        }
148        RuntimeError::Closed => io::Error::new(io::ErrorKind::BrokenPipe, "runtime closed"),
149        RuntimeError::Overloaded => io::Error::new(io::ErrorKind::WouldBlock, "runtime overloaded"),
150        RuntimeError::UnsupportedBackend(msg) => io::Error::new(io::ErrorKind::Unsupported, msg),
151        RuntimeError::IoUringInit(io) => io,
152    }
153}
154
#[cfg(all(feature = "uring-native", target_os = "linux"))]
/// Compatibility adapters between Spargio sockets and `futures::io` traits.
pub mod io_compat {
    use futures::io::{AsyncRead, AsyncWrite};
    use spargio::net::TcpStream;
    use std::future::Future;
    use std::io;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// In-flight receive: resolves to the byte count and the buffer it filled.
    type ReadOp = Pin<Box<dyn Future<Output = io::Result<(usize, Vec<u8>)>> + Send + 'static>>;
    /// In-flight send: resolves to the number of bytes written.
    type WriteOp = Pin<Box<dyn Future<Output = io::Result<usize>> + Send + 'static>>;

    /// `futures::io` compatible wrapper over [`spargio::net::TcpStream`].
    ///
    /// Each read or write is issued as an owned-buffer operation on a clone of the
    /// stream and kept alive across polls until it completes.
    pub struct FuturesTcpStream {
        inner: TcpStream,
        read_op: Option<ReadOp>,
        write_op: Option<WriteOp>,
        /// Received bytes that did not fit into the caller's buffer. Later reads
        /// drain these before any new receive is issued.
        read_pending: Vec<u8>,
        /// Index of the first unconsumed byte in `read_pending`.
        read_pending_pos: usize,
    }

    impl std::fmt::Debug for FuturesTcpStream {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("FuturesTcpStream")
                .field("fd", &self.inner.as_raw_fd())
                .field("session_shard", &self.inner.session_shard())
                .field(
                    "buffered_read_bytes",
                    &(self.read_pending.len() - self.read_pending_pos),
                )
                .finish()
        }
    }

    impl FuturesTcpStream {
        /// Wraps a Spargio TCP stream.
        pub fn new(inner: TcpStream) -> Self {
            Self {
                inner,
                read_op: None,
                write_op: None,
                read_pending: Vec::new(),
                read_pending_pos: 0,
            }
        }

        /// Returns a shared reference to the underlying Spargio stream.
        pub fn get_ref(&self) -> &TcpStream {
            &self.inner
        }

        /// Unwraps this adapter into the underlying Spargio stream.
        ///
        /// Any in-flight operation is dropped. Bytes that were already received
        /// but not yet returned from `poll_read` are discarded.
        pub fn into_inner(self) -> TcpStream {
            self.inner
        }

        /// Copies buffered receive bytes into `buf` and returns how many were copied.
        ///
        /// Resets the buffer once it is fully consumed, so `read_pending_pos`
        /// never exceeds `read_pending.len()`.
        fn drain_pending(&mut self, buf: &mut [u8]) -> usize {
            let available = &self.read_pending[self.read_pending_pos..];
            let n = available.len().min(buf.len());
            buf[..n].copy_from_slice(&available[..n]);
            self.read_pending_pos += n;
            if self.read_pending_pos == self.read_pending.len() {
                self.read_pending.clear();
                self.read_pending_pos = 0;
            }
            n
        }
    }

    impl Unpin for FuturesTcpStream {}

    impl AsyncRead for FuturesTcpStream {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            let this = self.get_mut();
            if buf.is_empty() {
                return Poll::Ready(Ok(0));
            }

            // Serve leftovers from an earlier receive before issuing a new one.
            if this.read_pending_pos < this.read_pending.len() {
                return Poll::Ready(Ok(this.drain_pending(buf)));
            }

            if this.read_op.is_none() {
                let inner = this.inner.clone();
                // The pending buffer is fully drained here; reuse its allocation.
                let mut scratch = std::mem::take(&mut this.read_pending);
                scratch.clear();
                scratch.resize(buf.len(), 0);
                this.read_op = Some(Box::pin(async move { inner.recv_owned(scratch).await }));
            }

            let result = match this
                .read_op
                .as_mut()
                .expect("read op set")
                .as_mut()
                .poll(cx)
            {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(result) => result,
            };
            this.read_op = None;

            let (got, mut payload) = result?;
            // The receive was sized from the `buf` of the poll that started it. A
            // later poll may pass a smaller `buf`, so keep any excess bytes instead
            // of truncating them away.
            payload.truncate(got);
            this.read_pending = payload;
            this.read_pending_pos = 0;
            Poll::Ready(Ok(this.drain_pending(buf)))
        }
    }

    impl AsyncWrite for FuturesTcpStream {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let this = self.get_mut();
            if buf.is_empty() {
                return Poll::Ready(Ok(0));
            }

            // The payload is copied when the send is first issued. A later call that
            // finds a send in flight completes that send and reports its byte count,
            // so callers should retry with the same `buf` after a `Pending`.
            if this.write_op.is_none() {
                let inner = this.inner.clone();
                let payload = buf.to_vec();
                let payload_len = payload.len();
                this.write_op = Some(Box::pin(async move {
                    let (written, _) = inner.send_owned(payload).await?;
                    Ok(written.min(payload_len))
                }));
            }

            match this
                .write_op
                .as_mut()
                .expect("write op set")
                .as_mut()
                .poll(cx)
            {
                Poll::Pending => Poll::Pending,
                Poll::Ready(result) => {
                    this.write_op = None;
                    Poll::Ready(result)
                }
            }
        }

        /// No-op: this adapter buffers no outgoing data of its own.
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        /// No-op: completes immediately without shutting down the socket.
        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }
}
288
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use std::time::Duration;

    #[test]
    fn protocol_blocking_helpers_execute_closure() {
        let rt = spargio::Runtime::builder()
            .shards(1)
            .build()
            .expect("runtime");
        let handle = rt.handle();

        let tls = block_on(async { tls_blocking(&handle, || Ok::<_, io::Error>(11usize)).await })
            .expect("tls");
        let ws = block_on(async { ws_blocking(&handle, || Ok::<_, io::Error>(22usize)).await })
            .expect("ws");
        let quic = block_on(async { quic_blocking(&handle, || Ok::<_, io::Error>(33usize)).await })
            .expect("quic");

        assert_eq!(tls + ws + quic, 66);
    }

    #[test]
    fn blocking_helpers_propagate_closure_error() {
        let rt = spargio::Runtime::builder()
            .shards(1)
            .build()
            .expect("runtime");
        let handle = rt.handle();

        let err = block_on(async {
            ws_blocking(&handle, || {
                Err::<(), _>(io::Error::new(io::ErrorKind::Other, "boom"))
            })
            .await
        })
        .expect_err("closure error");

        // The closure's error must reach the caller unchanged.
        assert_eq!(err.kind(), io::ErrorKind::Other);
        assert_eq!(err.to_string(), "boom");
    }

    #[test]
    fn blocking_with_generous_timeout_returns_value() {
        let rt = spargio::Runtime::builder()
            .shards(1)
            .build()
            .expect("runtime");
        let handle = rt.handle();

        let value = block_on(async {
            quic_blocking_with_options(
                &handle,
                BlockingOptions::default().with_timeout(Duration::from_secs(5)),
                || Ok::<_, io::Error>(7usize),
            )
            .await
        })
        .expect("quic with timeout");

        assert_eq!(value, 7);
    }

    #[test]
    fn blocking_options_default_has_no_timeout() {
        assert_eq!(BlockingOptions::default().timeout(), None);
        let options = BlockingOptions::default().with_timeout(Duration::from_millis(7));
        assert_eq!(options.timeout(), Some(Duration::from_millis(7)));
    }

    #[test]
    fn blocking_timeout_returns_timed_out() {
        let rt = spargio::Runtime::builder()
            .shards(1)
            .build()
            .expect("runtime");
        let err = block_on(async {
            tls_blocking_with_options(
                &rt.handle(),
                BlockingOptions::default().with_timeout(Duration::from_millis(5)),
                || {
                    std::thread::sleep(Duration::from_millis(30));
                    Ok::<(), io::Error>(())
                },
            )
            .await
            .expect_err("timeout")
        });
        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
    }
}