use std::error;
use std::fmt;
use std::io;
use std::io::Read;
use std::io::Write;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use crate::runtime::AsyncRead;
use crate::runtime::AsyncWrite;
use crate::spi::restore_context;
use crate::spi::save_context;
use crate::spi::TlsStreamWithUpcastDyn;
use crate::AsyncSocket;
use crate::ImplInfo;
use crate::TlsStreamDyn;
use crate::TlsStreamWithSocketDyn;
/// Adapter that exposes an asynchronous I/O object through the blocking
/// `std::io::Read` / `std::io::Write` traits.
///
/// The trait impls in this file poll the wrapped object once per call and
/// report a pending poll as `io::ErrorKind::WouldBlock`.
#[derive(Debug)]
pub struct AsyncIoAsSyncIo<S>
where
    S: Unpin,
{
    /// The wrapped asynchronous I/O object.
    inner: S,
}
// SAFETY: the only field of `AsyncIoAsSyncIo<S>` is `inner: S`, and this impl
// is restricted to `S: Send`, so moving the wrapper to another thread moves
// nothing but an `S`.
// NOTE(review): the auto-trait rules would presumably derive the same impl
// without `unsafe` — confirm before removing this line.
unsafe impl<S: Unpin + Send> Send for AsyncIoAsSyncIo<S> {}
impl<S: Unpin> AsyncIoAsSyncIo<S> {
pub fn get_inner_mut(&mut self) -> &mut S {
&mut self.inner
}
pub fn get_inner_ref(&self) -> &S {
&self.inner
}
pub fn new(inner: S) -> AsyncIoAsSyncIo<S> {
AsyncIoAsSyncIo { inner }
}
fn get_inner_pin(&mut self) -> Pin<&mut S> {
Pin::new(&mut self.inner)
}
}
/// Blocking-style reads implemented by polling the wrapped [`AsyncRead`].
impl<S: AsyncRead + Unpin> Read for AsyncIoAsSyncIo<S> {
    /// Polls the wrapped reader once and converts the outcome to an
    /// `io::Result`.
    ///
    /// The poll runs inside `restore_context_poll_to_result`, which supplies
    /// the task context and maps a pending poll to
    /// `io::ErrorKind::WouldBlock`.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        restore_context_poll_to_result(|cx| {
            #[cfg(feature = "runtime-tokio")]
            {
                // tokio reports progress through the `ReadBuf`, so the byte
                // count is the length of its filled part after the poll.
                let mut dst = tokio::io::ReadBuf::new(buf);
                let polled = self.get_inner_pin().poll_read(cx, &mut dst);
                match polled {
                    Poll::Ready(Ok(())) => Poll::Ready(Ok(dst.filled().len())),
                    Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
                    Poll::Pending => Poll::Pending,
                }
            }
            #[cfg(feature = "runtime-async-std")]
            {
                // This `poll_read` already yields the byte count directly.
                let reader = self.get_inner_pin();
                reader.poll_read(cx, buf)
            }
        })
    }
}
/// Blocking-style writes implemented by polling the wrapped [`AsyncWrite`].
impl<S: AsyncWrite + Unpin> Write for AsyncIoAsSyncIo<S> {
    /// Polls `poll_write` once; a pending poll surfaces as
    /// `io::ErrorKind::WouldBlock` via `restore_context_poll_to_result`.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        restore_context_poll_to_result(|cx| {
            let writer = self.get_inner_pin();
            writer.poll_write(cx, buf)
        })
    }

    /// Polls `poll_flush` once; a pending poll surfaces as
    /// `io::ErrorKind::WouldBlock` via `restore_context_poll_to_result`.
    fn flush(&mut self) -> io::Result<()> {
        restore_context_poll_to_result(|cx| {
            let writer = self.get_inner_pin();
            writer.poll_flush(cx)
        })
    }
}
/// Converts the result of a blocking-style I/O call into a `Poll`.
///
/// An error whose kind is `WouldBlock` becomes `Poll::Pending`; every other
/// outcome (success or error) is passed through as `Poll::Ready`.
fn result_to_poll<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    let would_block = matches!(&r, Err(e) if e.kind() == io::ErrorKind::WouldBlock);
    if would_block {
        Poll::Pending
    } else {
        Poll::Ready(r)
    }
}
/// Error payload carrying an `io::Error` that should not have been returned
/// from an async API; the wrapped error is included in the display message.
#[derive(Debug)]
struct ShouldNotReturnWouldBlockFromAsync(io::Error);

impl fmt::Display for ShouldNotReturnWouldBlockFromAsync {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ShouldNotReturnWouldBlockFromAsync(cause) = self;
        write!(f, "should not return WouldBlock from async API: {}", cause)
    }
}

impl error::Error for ShouldNotReturnWouldBlockFromAsync {}
/// Converts the outcome of a single poll into a blocking-style `io::Result`.
///
/// * `Poll::Pending` becomes an error of kind `WouldBlock`.
/// * A ready error that itself has kind `WouldBlock` is re-wrapped as an
///   `Other` error holding `ShouldNotReturnWouldBlockFromAsync`, so it cannot
///   be mistaken for a pending poll.
/// * Everything else is returned unchanged.
fn poll_to_result<T>(r: Poll<io::Result<T>>) -> io::Result<T> {
    let outcome = match r {
        Poll::Pending => return Err(io::Error::from(io::ErrorKind::WouldBlock)),
        Poll::Ready(outcome) => outcome,
    };
    outcome.map_err(|e| {
        if e.kind() == io::ErrorKind::WouldBlock {
            io::Error::new(
                io::ErrorKind::Other,
                ShouldNotReturnWouldBlockFromAsync(e),
            )
        } else {
            e
        }
    })
}
/// Runs the poll closure `f` with the task context handed out by
/// `restore_context` and converts the poll outcome with `poll_to_result`.
///
/// NOTE(review): presumably the context is the one previously stored by
/// `save_context` — confirm in `crate::spi`.
fn restore_context_poll_to_result<R>(
    f: impl FnOnce(&mut Context<'_>) -> Poll<io::Result<R>>,
) -> io::Result<R> {
    restore_context(|cx| {
        let polled = f(cx);
        poll_to_result(polled)
    })
}
/// Operations `TlsStreamOverSyncIo` needs from a TLS stream type that does
/// blocking-style I/O over an [`AsyncIoAsSyncIo<A>`].
///
/// All functions are associated functions (no `self`): the implementor is
/// used purely as a type-level table of operations on `SyncWrapper`.
pub trait AsyncWrapperOps<A: Unpin>: fmt::Debug + Unpin + Send + 'static {
    /// The blocking-style stream type the operations below act on.
    type SyncWrapper: Read + Write + Unpin + Send + 'static;

    /// Returns the `ImplInfo` for this implementation.
    fn impl_info() -> ImplInfo;

    /// Returns the value used to render `w` in `Debug` output.
    fn debug(w: &Self::SyncWrapper) -> &dyn fmt::Debug;

    /// Returns exclusive access to the I/O adapter held by `w`.
    fn get_mut(w: &mut Self::SyncWrapper) -> &mut AsyncIoAsSyncIo<A>;

    /// Returns shared access to the I/O adapter held by `w`.
    fn get_ref(w: &Self::SyncWrapper) -> &AsyncIoAsSyncIo<A>;

    /// Returns the ALPN protocol for `w`, if any.
    // NOTE(review): presumably the protocol negotiated during the handshake —
    // confirm against the implementations.
    fn get_alpn_protocol(w: &Self::SyncWrapper) -> crate::Result<Option<Vec<u8>>>;
}
/// Async stream built on top of a blocking-style stream
/// (`O::SyncWrapper`).
///
/// `A` is the type of the wrapped async object and `O` supplies the
/// operations on the wrapper; neither is stored directly, hence the
/// `PhantomData` marker.
pub struct TlsStreamOverSyncIo<A: Unpin, O: AsyncWrapperOps<A>> {
    /// The blocking-style stream all reads and writes are forwarded to.
    pub stream: O::SyncWrapper,
    /// Ties the otherwise unused type parameters `A` and `O` to the struct.
    _phantom: PhantomData<(A, O)>,
}
/// Formats as the tuple `TlsStreamOverSyncIo(..)`, using `O::debug` to render
/// the inner stream (no `Debug` bound is required on `O::SyncWrapper`).
impl<A: Unpin, O: AsyncWrapperOps<A>> fmt::Debug for TlsStreamOverSyncIo<A, O> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = f.debug_tuple("TlsStreamOverSyncIo");
        tuple.field(O::debug(&self.stream));
        tuple.finish()
    }
}
impl<A: Unpin, O: AsyncWrapperOps<A>> TlsStreamOverSyncIo<A, O> {
    /// Wraps an already constructed blocking-style stream.
    pub fn new(stream: O::SyncWrapper) -> TlsStreamOverSyncIo<A, O> {
        Self {
            stream,
            _phantom: PhantomData,
        }
    }

    /// Runs the blocking-style operation `f` inside `save_context(cx, ..)`
    /// and converts its result to a `Poll` with `result_to_poll`
    /// (`WouldBlock` becomes `Poll::Pending`).
    ///
    /// NOTE(review): presumably `save_context` makes `cx` available to the
    /// `restore_context` calls made by `AsyncIoAsSyncIo` — confirm in
    /// `crate::spi`.
    fn with_context_sync_to_async<F, R>(
        &mut self,
        cx: &mut Context<'_>,
        f: F,
    ) -> Poll<io::Result<R>>
    where
        F: FnOnce(&mut Self) -> io::Result<R>,
    {
        let outcome = save_context(cx, || f(self));
        result_to_poll(outcome)
    }

    /// tokio flavour of [`Self::with_context_sync_to_async`] for reads.
    ///
    /// `f` receives the initialized unfilled part of `buf` as a plain byte
    /// slice and returns how many bytes it wrote; `buf` is advanced by that
    /// amount.
    #[cfg(feature = "runtime-tokio")]
    fn with_context_sync_to_async_tokio<F>(
        &mut self,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf,
        f: F,
    ) -> Poll<io::Result<()>>
    where
        F: FnOnce(&mut Self, &mut [u8]) -> io::Result<usize>,
    {
        self.with_context_sync_to_async(cx, |this| {
            let dst = buf.initialize_unfilled();
            let count = f(this, dst)?;
            buf.advance(count);
            Ok(())
        })
    }
}
/// Async reads forwarded to the blocking-style `stream`.
///
/// In both runtime flavours a read error of kind `ConnectionAborted` is
/// reported to the caller as a successful zero-byte read instead of an error.
// NOTE(review): presumably some TLS backends signal a closed connection this
// way and it is meant to look like end-of-stream — confirm.
impl<A, O> AsyncRead for TlsStreamOverSyncIo<A, O>
where
    A: Unpin,
    O: AsyncWrapperOps<A>,
{
    #[cfg(feature = "runtime-tokio")]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        this.with_context_sync_to_async_tokio(cx, buf, |s, dst| {
            s.stream.read(dst).or_else(|e| {
                if e.kind() == io::ErrorKind::ConnectionAborted {
                    Ok(0)
                } else {
                    Err(e)
                }
            })
        })
    }

    #[cfg(feature = "runtime-async-std")]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        this.with_context_sync_to_async(cx, |s| {
            s.stream.read(buf).or_else(|e| {
                if e.kind() == io::ErrorKind::ConnectionAborted {
                    Ok(0)
                } else {
                    Err(e)
                }
            })
        })
    }
}
/// Async writes forwarded to the blocking-style `stream`.
impl<A, O> AsyncWrite for TlsStreamOverSyncIo<A, O>
where
    A: Unpin,
    O: AsyncWrapperOps<A>,
{
    /// Forwards to `stream.write`; `WouldBlock` becomes `Poll::Pending`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        this.with_context_sync_to_async(cx, |s| s.stream.write(buf))
    }

    /// Forwards to `stream.flush`; `WouldBlock` becomes `Poll::Pending`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        this.with_context_sync_to_async(cx, |s| s.stream.flush())
    }

    /// Only flushes the stream.
    // NOTE(review): no TLS-level or socket-level shutdown is performed here —
    // confirm that flushing alone is the intended behavior.
    #[cfg(feature = "runtime-tokio")]
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        this.with_context_sync_to_async(cx, |s| s.stream.flush())
    }

    /// Only flushes the stream.
    // NOTE(review): no TLS-level or socket-level close is performed here —
    // confirm that flushing alone is the intended behavior.
    #[cfg(feature = "runtime-async-std")]
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        this.with_context_sync_to_async(cx, |s| s.stream.flush())
    }
}
/// Type-erased stream accessors, all answered by delegating to `O`.
impl<A: AsyncSocket, O: AsyncWrapperOps<A>> TlsStreamDyn for TlsStreamOverSyncIo<A, O> {
    /// Returns `O::impl_info()`.
    fn impl_info(&self) -> ImplInfo {
        <O as AsyncWrapperOps<A>>::impl_info()
    }

    /// Returns whatever `O::get_alpn_protocol` reports for the inner stream.
    fn get_alpn_protocol(&self) -> crate::Result<Option<Vec<u8>>> {
        let wrapper = &self.stream;
        O::get_alpn_protocol(wrapper)
    }

    /// Returns the underlying socket as an exclusive trait object.
    fn get_socket_dyn_mut(&mut self) -> &mut dyn AsyncSocket {
        let io_adapter = O::get_mut(&mut self.stream);
        io_adapter.get_inner_mut()
    }

    /// Returns the underlying socket as a shared trait object.
    fn get_socket_dyn_ref(&self) -> &dyn AsyncSocket {
        let io_adapter = O::get_ref(&self.stream);
        io_adapter.get_inner_ref()
    }
}
/// Typed access to the underlying socket `A`, reached through the
/// `AsyncIoAsSyncIo<A>` adapter that `O` extracts from the inner stream.
impl<A: AsyncSocket, O: AsyncWrapperOps<A>> TlsStreamWithSocketDyn<A>
    for TlsStreamOverSyncIo<A, O>
{
    /// Returns exclusive access to the underlying socket.
    fn get_socket_mut(&mut self) -> &mut A {
        let io_adapter = O::get_mut(&mut self.stream);
        io_adapter.get_inner_mut()
    }

    /// Returns shared access to the underlying socket.
    fn get_socket_ref(&self) -> &A {
        let io_adapter = O::get_ref(&self.stream);
        io_adapter.get_inner_ref()
    }
}
impl<A: AsyncSocket, O: AsyncWrapperOps<A>> TlsStreamWithUpcastDyn<A>
    for TlsStreamOverSyncIo<A, O>
{
    /// Converts the boxed concrete stream into a boxed `dyn TlsStreamDyn`
    /// (a plain unsizing coercion; the allocation is reused).
    fn upcast_box(self: Box<Self>) -> Box<dyn TlsStreamDyn> {
        let upcast: Box<dyn TlsStreamDyn> = self;
        upcast
    }
}
/// Generates a public newtype `$t<A>` around
/// `TlsStreamOverSyncIo<A, AsyncWrapperOpsImpl<AsyncIoAsSyncIo<A>, A>>`
/// together with its stream trait impls.
///
/// * `$t` — name of the generated stream type.
/// * `$n` — name of the blocking-style stream type; the generated
///   `$t::new` accepts a `$n<AsyncIoAsSyncIo<A>>`.
///
/// The expansion refers to `AsyncSocket`, `ImplInfo`, `TlsStreamOverSyncIo`,
/// `AsyncIoAsSyncIo`, `AsyncWrapperOpsImpl`, the `tls_api` crate and the
/// `spi_async_socket_impl_delegate!` macro by the paths written below, so all
/// of them must be resolvable at the call site.
///
/// The struct, its inherent impl and its constructor are named with `$t`,
/// the same as every trait impl, so the expansion is self-consistent for any
/// `$t` (it is unchanged when `$t` is `TlsStream`).
#[macro_export]
macro_rules! spi_tls_stream_over_sync_io_wrapper {
    ( $t:ident, $n:ident ) => {
        #[derive(Debug)]
        pub struct $t<A: AsyncSocket>(
            pub(crate) TlsStreamOverSyncIo<A, AsyncWrapperOpsImpl<AsyncIoAsSyncIo<A>, A>>,
        );

        impl<A: AsyncSocket> $t<A> {
            pub(crate) fn new(stream: $n<AsyncIoAsSyncIo<A>>) -> $t<A> {
                $t(TlsStreamOverSyncIo::new(stream))
            }

            // NOTE(review): presumably called by the impls that
            // `spi_async_socket_impl_delegate!` generates — confirm there.
            fn deref_pin_mut_for_impl_socket(
                self: std::pin::Pin<&mut Self>,
            ) -> std::pin::Pin<
                &mut TlsStreamOverSyncIo<A, AsyncWrapperOpsImpl<AsyncIoAsSyncIo<A>, A>>,
            > {
                std::pin::Pin::new(&mut self.get_mut().0)
            }

            // NOTE(review): presumably called by the impls that
            // `spi_async_socket_impl_delegate!` generates — confirm there.
            fn deref_for_impl_socket(
                &self,
            ) -> &TlsStreamOverSyncIo<A, AsyncWrapperOpsImpl<AsyncIoAsSyncIo<A>, A>> {
                &self.0
            }
        }

        spi_async_socket_impl_delegate!($t<S>);

        // The impls below forward every method to the wrapped
        // `TlsStreamOverSyncIo` in field `0`.
        impl<A: tls_api::AsyncSocket> tls_api::TlsStreamDyn for $t<A> {
            fn get_alpn_protocol(&self) -> $crate::Result<Option<Vec<u8>>> {
                self.0.get_alpn_protocol()
            }

            fn impl_info(&self) -> ImplInfo {
                self.0.impl_info()
            }

            fn get_socket_dyn_mut(&mut self) -> &mut dyn AsyncSocket {
                self.0.get_socket_dyn_mut()
            }

            fn get_socket_dyn_ref(&self) -> &dyn AsyncSocket {
                self.0.get_socket_dyn_ref()
            }
        }

        impl<A: tls_api::AsyncSocket> tls_api::TlsStreamWithSocketDyn<A> for $t<A> {
            fn get_socket_mut(&mut self) -> &mut A {
                self.0.get_socket_mut()
            }

            fn get_socket_ref(&self) -> &A {
                self.0.get_socket_ref()
            }
        }

        impl<A: tls_api::AsyncSocket> tls_api::spi::TlsStreamWithUpcastDyn<A> for $t<A> {
            fn upcast_box(self: Box<Self>) -> Box<dyn tls_api::TlsStreamDyn> {
                self
            }
        }
    };
}