tls_api/
async_as_sync.rs

1//! Utility used in different implementations of TLS API.
2//!
3//! Not to be used by regular users of the library.
4
5use std::fmt;
6use std::io;
7use std::io::Read;
8use std::io::Write;
9use std::marker::PhantomData;
10use std::pin::Pin;
11use std::task::Context;
12use std::task::Poll;
13
14use crate::runtime::AsyncRead;
15use crate::runtime::AsyncWrite;
16use crate::spi::restore_context;
17use crate::spi::save_context;
18use crate::spi::TlsStreamWithUpcastDyn;
19use crate::AsyncSocket;
20use crate::ImplInfo;
21use crate::TlsStreamDyn;
22use crate::TlsStreamWithSocketDyn;
23
/// Async IO object as sync IO.
///
/// Wraps an async stream so that it can be driven through the blocking
/// [`Read`] / [`Write`] traits. Each sync call polls the inner stream once with
/// the task [`Context`] obtained from `restore_context`; if the inner stream is
/// not ready, the sync call fails with [`io::ErrorKind::WouldBlock`].
///
/// Used in API implementations.
#[derive(Debug)]
pub struct AsyncIoAsSyncIo<S: Unpin> {
    inner: S,
}

// No `unsafe impl Send` is needed here: `inner: S` is the only field, so the
// compiler already makes `AsyncIoAsSyncIo<S>` `Send` whenever `S: Send`.

impl<S: Unpin> AsyncIoAsSyncIo<S> {
    /// Get a mutable reference to the wrapped stream.
    pub fn get_inner_mut(&mut self) -> &mut S {
        &mut self.inner
    }

    /// Get a shared reference to the wrapped stream.
    pub fn get_inner_ref(&self) -> &S {
        &self.inner
    }

    /// Wrap an async object in this wrapper.
    pub fn new(inner: S) -> AsyncIoAsSyncIo<S> {
        Self { inner }
    }

    /// Pin the wrapped stream for polling; sound because `S: Unpin`.
    fn get_inner_pin(&mut self) -> Pin<&mut S> {
        Pin::new(&mut self.inner)
    }
}
54
55impl<S: AsyncRead + Unpin> Read for AsyncIoAsSyncIo<S> {
56    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
57        restore_context_poll_to_result(|cx| {
58            #[cfg(feature = "runtime-tokio")]
59            {
60                let mut read_buf = tokio::io::ReadBuf::new(buf);
61                let p = self.get_inner_pin().poll_read(cx, &mut read_buf);
62                p.map_ok(|()| read_buf.filled().len())
63            }
64            #[cfg(feature = "runtime-async-std")]
65            {
66                self.get_inner_pin().poll_read(cx, buf)
67            }
68        })
69    }
70}
71
72impl<S: AsyncWrite + Unpin> Write for AsyncIoAsSyncIo<S> {
73    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
74        restore_context_poll_to_result(|cx| self.get_inner_pin().poll_write(cx, buf))
75    }
76
77    fn flush(&mut self) -> io::Result<()> {
78        restore_context_poll_to_result(|cx| self.get_inner_pin().poll_flush(cx))
79    }
80}
81
/// Translate the outcome of a blocking-style call into a poll result.
///
/// A `WouldBlock` error means "not ready yet" and becomes [`Poll::Pending`];
/// any other outcome, success or failure, is ready as-is.
fn result_to_poll<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    match r {
        Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
        done => Poll::Ready(done),
    }
}
90
91#[derive(Debug, thiserror::Error)]
92#[error("should not return WouldBlock from async API: {}", _0)]
93struct ShouldNotReturnWouldBlockFromAsync(io::Error);
94
95/// Convert nonblocking API to sync result
96fn poll_to_result<T>(r: Poll<io::Result<T>>) -> io::Result<T> {
97    match r {
98        Poll::Ready(Ok(r)) => Ok(r),
99        Poll::Ready(Err(e)) if e.kind() == io::ErrorKind::WouldBlock => Err(io::Error::new(
100            io::ErrorKind::Other,
101            ShouldNotReturnWouldBlockFromAsync(e),
102        )),
103        Poll::Ready(Err(e)) => Err(e),
104        Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
105    }
106}
107
108fn restore_context_poll_to_result<R>(
109    f: impl FnOnce(&mut Context<'_>) -> Poll<io::Result<R>>,
110) -> io::Result<R> {
111    restore_context(|cx| poll_to_result(f(cx)))
112}
113
114/// Used by API implementors.
115pub trait AsyncWrapperOps<A>: fmt::Debug + Unpin + Send + 'static
116where
117    A: Unpin,
118{
119    /// API-implementation of wrapper stream.
120    ///
121    /// Wrapped object is always [`AsyncIoAsSyncIo`].
122    type SyncWrapper: Read + Write + WriteShutdown + Unpin + Send + 'static;
123
124    /// Which crates imlpements this?
125    fn impl_info() -> ImplInfo;
126
127    /// Cast the wrapper to [`fmt::Debug`] or provide substitute debug.
128    /// This is work around not all wrappers implementing [`fmt::Debug`].
129    fn debug(w: &Self::SyncWrapper) -> &dyn fmt::Debug;
130
131    /// Unwrap the wrapper.
132    fn get_mut(w: &mut Self::SyncWrapper) -> &mut AsyncIoAsSyncIo<A>;
133    /// Unwrap the wrapper.
134    fn get_ref(w: &Self::SyncWrapper) -> &AsyncIoAsSyncIo<A>;
135
136    /// Get negotiated ALPN protocol.
137    fn get_alpn_protocol(w: &Self::SyncWrapper) -> anyhow::Result<Option<Vec<u8>>>;
138}
139
/// Notify the writer that there will be no more data written.
///
/// In the context of TLS providers, this is a good time to send the
/// `close_notify` alert.
pub trait WriteShutdown: Write {
    /// Initiates or attempts to shut down this writer, returning when the
    /// shutdown has completed.
    ///
    /// Protocols sometimes need to flush out final pieces of data or otherwise
    /// perform a graceful shutdown handshake, reading/writing more data as
    /// appropriate; this method is the hook for that logic. Wrappers typically
    /// proxy the call through to the wrapped type, and base types either
    /// implement their shutdown logic here or keep the default, which only
    /// flushes.
    ///
    /// This is the blocking counterpart of async shutdown:
    /// [`TlsStreamOverSyncIo`] calls it from `poll_shutdown` (tokio) /
    /// `poll_close` (async-std), where an [`io::ErrorKind::WouldBlock`] error
    /// is reported as [`Poll::Pending`].
    ///
    /// Invocation of `shutdown` implies an invocation of `flush`: callers don't
    /// need to call `flush` first and can rely on pending buffered data being
    /// written out.
    ///
    /// # Errors
    ///
    /// Returns I/O errors from flushing or from the shutdown logic. An
    /// implementation may also render `Write::write` unusable afterwards (e.g.
    /// it may return errors), so `write` should not be called once `shutdown`
    /// has been called.
    fn shutdown(&mut self) -> io::Result<()> {
        self.flush()
    }
}
180
181/// Implementation of `TlsStreamImpl` for APIs using synchronous I/O.
182pub struct TlsStreamOverSyncIo<A, O>
183where
184    A: Unpin,
185    O: AsyncWrapperOps<A>,
186{
187    /// TLS-implementation.
188    pub stream: O::SyncWrapper,
189    _phantom: PhantomData<(A, O)>,
190}
191
192impl<A, O> fmt::Debug for TlsStreamOverSyncIo<A, O>
193where
194    A: Unpin,
195    O: AsyncWrapperOps<A>,
196{
197    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198        f.debug_tuple("TlsStreamOverSyncIo")
199            .field(O::debug(&self.stream))
200            .finish()
201    }
202}
203
204impl<A, O> TlsStreamOverSyncIo<A, O>
205where
206    A: Unpin,
207    O: AsyncWrapperOps<A>,
208{
209    /// Constructor.
210    pub fn new(stream: O::SyncWrapper) -> TlsStreamOverSyncIo<A, O> {
211        TlsStreamOverSyncIo {
212            stream,
213            _phantom: PhantomData,
214        }
215    }
216
217    fn with_context_sync_to_async<F, R>(
218        &mut self,
219        cx: &mut Context<'_>,
220        f: F,
221    ) -> Poll<io::Result<R>>
222    where
223        F: FnOnce(&mut Self) -> io::Result<R>,
224    {
225        result_to_poll(save_context(cx, || f(self)))
226    }
227
228    #[cfg(feature = "runtime-tokio")]
229    fn with_context_sync_to_async_tokio<F>(
230        &mut self,
231        cx: &mut Context<'_>,
232        buf: &mut tokio::io::ReadBuf,
233        f: F,
234    ) -> Poll<io::Result<()>>
235    where
236        F: FnOnce(&mut Self, &mut [u8]) -> io::Result<usize>,
237    {
238        self.with_context_sync_to_async(cx, |s| {
239            let unfilled = buf.initialize_unfilled();
240            let read = f(s, unfilled)?;
241            buf.advance(read);
242            Ok(())
243        })
244    }
245}
246
247impl<A, O> AsyncRead for TlsStreamOverSyncIo<A, O>
248where
249    A: Unpin,
250    O: AsyncWrapperOps<A>,
251{
252    #[cfg(feature = "runtime-tokio")]
253    fn poll_read(
254        self: Pin<&mut Self>,
255        cx: &mut Context<'_>,
256        buf: &mut tokio::io::ReadBuf,
257    ) -> Poll<io::Result<()>> {
258        self.get_mut()
259            .with_context_sync_to_async_tokio(cx, buf, |s, buf| {
260                let result = s.stream.read(buf);
261                match result {
262                    Ok(r) => Ok(r),
263                    Err(e) if e.kind() == io::ErrorKind::ConnectionAborted => {
264                        // rustls returns `ConnectionAborted` on EOF
265                        Ok(0)
266                    }
267                    Err(e) => Err(e),
268                }
269            })
270    }
271
272    #[cfg(feature = "runtime-async-std")]
273    fn poll_read(
274        self: Pin<&mut Self>,
275        cx: &mut Context<'_>,
276        buf: &mut [u8],
277    ) -> Poll<io::Result<usize>> {
278        self.get_mut().with_context_sync_to_async(cx, |s| {
279            let result = s.stream.read(buf);
280            match result {
281                Ok(r) => Ok(r),
282                Err(e) if e.kind() == io::ErrorKind::ConnectionAborted => {
283                    // rustls returns `ConnectionAborted` on EOF
284                    Ok(0)
285                }
286                Err(e) => Err(e),
287            }
288        })
289    }
290}
291
292impl<A, O> AsyncWrite for TlsStreamOverSyncIo<A, O>
293where
294    A: Unpin,
295    O: AsyncWrapperOps<A>,
296{
297    fn poll_write(
298        self: Pin<&mut Self>,
299        cx: &mut Context<'_>,
300        buf: &[u8],
301    ) -> Poll<io::Result<usize>> {
302        self.get_mut()
303            .with_context_sync_to_async(cx, |stream| stream.stream.write(buf))
304    }
305
306    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
307        self.get_mut()
308            .with_context_sync_to_async(cx, |stream| stream.stream.flush())
309    }
310
311    #[cfg(feature = "runtime-tokio")]
312    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
313        self.get_mut()
314            .with_context_sync_to_async(cx, |stream| stream.stream.shutdown())
315    }
316
317    #[cfg(feature = "runtime-async-std")]
318    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
319        self.get_mut()
320            .with_context_sync_to_async(cx, |stream| stream.stream.shutdown())
321    }
322}
323
324impl<A, O> TlsStreamDyn for TlsStreamOverSyncIo<A, O>
325where
326    A: AsyncSocket,
327    O: AsyncWrapperOps<A>,
328{
329    fn impl_info(&self) -> ImplInfo {
330        O::impl_info()
331    }
332
333    fn get_alpn_protocol(&self) -> anyhow::Result<Option<Vec<u8>>> {
334        O::get_alpn_protocol(&self.stream)
335    }
336
337    fn get_socket_dyn_mut(&mut self) -> &mut dyn AsyncSocket {
338        O::get_mut(&mut self.stream).get_inner_mut()
339    }
340
341    fn get_socket_dyn_ref(&self) -> &dyn AsyncSocket {
342        O::get_ref(&self.stream).get_inner_ref()
343    }
344}
345
346impl<A, O> TlsStreamWithSocketDyn<A> for TlsStreamOverSyncIo<A, O>
347where
348    A: AsyncSocket,
349    O: AsyncWrapperOps<A>,
350{
351    fn get_socket_mut(&mut self) -> &mut A {
352        O::get_mut(&mut self.stream).get_inner_mut()
353    }
354
355    fn get_socket_ref(&self) -> &A {
356        O::get_ref(&self.stream).get_inner_ref()
357    }
358}
359
360impl<A, O> TlsStreamWithUpcastDyn<A> for TlsStreamOverSyncIo<A, O>
361where
362    A: AsyncSocket,
363    O: AsyncWrapperOps<A>,
364{
365    fn upcast_box(self: Box<Self>) -> Box<dyn TlsStreamDyn> {
366        self
367    }
368}
369
/// Implement wrapper for [`TlsStreamOverSyncIo`].
///
/// Expands to:
/// - a public tuple struct named `$t`, wrapping
///   `TlsStreamOverSyncIo<A, AsyncWrapperOpsImpl<AsyncIoAsSyncIo<A>, A>>`;
/// - a `pub(crate)` constructor taking `$n<AsyncIoAsSyncIo<A>>`, which therefore
///   has to be the `SyncWrapper` of that `AsyncWrapperOpsImpl`;
/// - delegating impls of [`TlsStreamDyn`], [`TlsStreamWithSocketDyn`] and
///   [`TlsStreamWithUpcastDyn`].
///
/// The expansion names `AsyncSocket`, `ImplInfo`, `TlsStreamOverSyncIo`,
/// `AsyncIoAsSyncIo`, `AsyncWrapperOpsImpl` and `spi_async_socket_impl_delegate!`
/// without a path, so the invoking module must have them in scope, along with
/// the traits whose methods are called on the inner stream.
#[macro_export]
macro_rules! spi_tls_stream_over_sync_io_wrapper {
    ( $t:ident, $n:ident ) => {
        #[derive(Debug)]
        pub struct $t<A: AsyncSocket>(
            pub(crate) TlsStreamOverSyncIo<A, AsyncWrapperOpsImpl<AsyncIoAsSyncIo<A>, A>>,
        );

        impl<A: AsyncSocket> $t<A> {
            pub(crate) fn new(stream: $n<AsyncIoAsSyncIo<A>>) -> $t<A> {
                Self(TlsStreamOverSyncIo::new(stream))
            }

            fn deref_pin_mut_for_impl_socket(
                self: ::std::pin::Pin<&mut Self>,
            ) -> ::std::pin::Pin<
                &mut TlsStreamOverSyncIo<A, AsyncWrapperOpsImpl<AsyncIoAsSyncIo<A>, A>>,
            > {
                ::std::pin::Pin::new(&mut self.get_mut().0)
            }

            fn deref_for_impl_socket(
                &self,
            ) -> &TlsStreamOverSyncIo<A, AsyncWrapperOpsImpl<AsyncIoAsSyncIo<A>, A>> {
                &self.0
            }
        }

        // NOTE(review): `S` is presumably the type parameter introduced by the
        // impls that `spi_async_socket_impl_delegate!` generates — confirm.
        spi_async_socket_impl_delegate!($t<S>);

        impl<A: $crate::AsyncSocket> $crate::TlsStreamDyn for $t<A> {
            fn get_alpn_protocol(&self) -> anyhow::Result<Option<Vec<u8>>> {
                self.0.get_alpn_protocol()
            }

            fn impl_info(&self) -> ImplInfo {
                self.0.impl_info()
            }

            fn get_socket_dyn_mut(&mut self) -> &mut dyn AsyncSocket {
                self.0.get_socket_dyn_mut()
            }

            fn get_socket_dyn_ref(&self) -> &dyn AsyncSocket {
                self.0.get_socket_dyn_ref()
            }
        }

        impl<A: $crate::AsyncSocket> $crate::TlsStreamWithSocketDyn<A> for $t<A> {
            fn get_socket_mut(&mut self) -> &mut A {
                self.0.get_socket_mut()
            }

            fn get_socket_ref(&self) -> &A {
                self.0.get_socket_ref()
            }
        }

        impl<A: $crate::AsyncSocket> $crate::spi::TlsStreamWithUpcastDyn<A> for $t<A> {
            fn upcast_box(self: Box<Self>) -> Box<dyn $crate::TlsStreamDyn> {
                self
            }
        }
    };
}
435}