use std::net::SocketAddr;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::tcp::{ReadHalf, WriteHalf};
use tokio::net::{TcpListener, TcpStream};
use crate::udp::UdpListener;
use super::addr::{each_addr, ToSocketAddrs};
use super::udp::{UdpStream, UdpStreamReadHalf, UdpStreamWriteHalf};
type Result<T, E = std::io::Error> = std::result::Result<T, E>;
/// An async, bidirectional byte stream that can be split into read/write halves,
/// unwrapped into its underlying transport, and queried for its socket addresses.
///
/// This module implements it for the TCP (`TcpStreamImpl`) and UDP (`UdpStreamImpl`)
/// stream wrappers.
pub trait NetworkStream: AsyncReadExt + AsyncWriteExt + Send + Unpin + 'static {
    /// Read half produced by [`NetworkStream::split`]; it may borrow from the stream for `'a`.
    type ReaderRef<'a>: AsyncReadExt + Send + Unpin
    where
        Self: 'a;
    /// Write half produced by [`NetworkStream::split`]; it may borrow from the stream for `'a`.
    type WriterRef<'a>: AsyncWriteExt + Send + Unpin
    where
        Self: 'a;
    /// The wrapped transport stream returned by [`NetworkStream::into_inner_stream`].
    type InnerStream: AsyncReadExt + AsyncWriteExt + Unpin + Send;

    /// Splits the stream into a read half and a write half.
    fn split(&mut self) -> (Self::ReaderRef<'_>, Self::WriterRef<'_>);

    /// Consumes the wrapper and returns the underlying transport stream.
    fn into_inner_stream(self) -> Self::InnerStream;

    /// Returns the local socket address of this stream.
    ///
    /// # Errors
    /// Propagates the I/O error reported by the underlying stream.
    fn local_addr(&self) -> Result<SocketAddr>;

    /// Returns the remote socket address of this stream.
    ///
    /// # Errors
    /// Propagates the I/O error reported by the underlying stream.
    fn peer_addr(&self) -> Result<SocketAddr>;
}
/// Generates a newtype `$struct_name` wrapping `$inner_ty` that forwards [`AsyncRead`]
/// and [`AsyncWrite`] to the wrapped stream.
///
/// `$doc_string` becomes the rustdoc of the generated struct. The inner type must be
/// `Unpin`, because every forwarding method re-pins the field with `Pin::new`.
macro_rules! gen_stream_impl {
    ($struct_name:ident, $inner_ty:ty, $doc_string:literal) => {
        #[doc = $doc_string]
        pub struct $struct_name($inner_ty);

        impl $struct_name {
            /// Wraps an already-established stream.
            pub fn new(stream: $inner_ty) -> Self {
                Self(stream)
            }
        }

        impl AsyncRead for $struct_name {
            fn poll_read(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
                buf: &mut tokio::io::ReadBuf<'_>,
            ) -> std::task::Poll<std::io::Result<()>> {
                std::pin::Pin::new(&mut self.0).poll_read(cx, buf)
            }
        }

        impl AsyncWrite for $struct_name {
            fn poll_write(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
                buf: &[u8],
            ) -> std::task::Poll<std::io::Result<usize>> {
                std::pin::Pin::new(&mut self.0).poll_write(cx, buf)
            }

            // Forward vectored writes so the inner stream's own implementation is used.
            // Without this, the trait default applies, which writes only the first
            // non-empty buffer per call.
            fn poll_write_vectored(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
                bufs: &[std::io::IoSlice<'_>],
            ) -> std::task::Poll<std::io::Result<usize>> {
                std::pin::Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
            }

            // Must agree with `poll_write_vectored` above: report whatever the inner
            // stream reports.
            fn is_write_vectored(&self) -> bool {
                self.0.is_write_vectored()
            }

            fn poll_flush(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<std::io::Result<()>> {
                std::pin::Pin::new(&mut self.0).poll_flush(cx)
            }

            fn poll_shutdown(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<std::io::Result<()>> {
                std::pin::Pin::new(&mut self.0).poll_shutdown(cx)
            }
        }
    };
}
// Each invocation emits the wrapper struct plus its AsyncRead/AsyncWrite forwarding;
// the `NetworkStream` impls for these wrappers are written out by hand below.
gen_stream_impl!(
    TcpStreamImpl,
    TcpStream,
    "Wrapper around a Tokio [`TcpStream`] that implements [`NetworkStream`]."
);
gen_stream_impl!(
    UdpStreamImpl,
    UdpStream,
    "Wrapper around a [`UdpStream`] that implements [`NetworkStream`]."
);
impl NetworkStream for TcpStreamImpl {
type ReaderRef<'a> = ReadHalf<'a>
where
Self: 'a;
type WriterRef<'a> = WriteHalf<'a>
where
Self: 'a;
type InnerStream = TcpStream;
fn split(&mut self) -> (Self::ReaderRef<'_>, Self::WriterRef<'_>) {
self.0.split()
}
fn into_inner_stream(self) -> Self::InnerStream {
self.0
}
fn local_addr(&self) -> Result<SocketAddr> {
self.0.local_addr()
}
fn peer_addr(&self) -> Result<SocketAddr> {
self.0.peer_addr()
}
}
/// UDP-backed [`NetworkStream`]: every method delegates to the wrapped `UdpStream`.
impl NetworkStream for UdpStreamImpl {
    type InnerStream = UdpStream;
    // NOTE(review): the read half is `'static` rather than borrowing from `self`, which
    // suggests `UdpStream::split` hands back an owned/shared receiver; confirm in the
    // udp module.
    type ReaderRef<'a> = UdpStreamReadHalf<'static>;
    type WriterRef<'a> = UdpStreamWriteHalf<'a>
    where
        Self: 'a;

    fn split(&mut self) -> (Self::ReaderRef<'_>, Self::WriterRef<'_>) {
        let Self(stream) = self;
        stream.split()
    }

    fn into_inner_stream(self) -> Self::InnerStream {
        match self {
            Self(stream) => stream,
        }
    }

    fn local_addr(&self) -> Result<SocketAddr> {
        let Self(stream) = self;
        stream.local_addr()
    }

    fn peer_addr(&self) -> Result<SocketAddr> {
        let Self(stream) = self;
        stream.peer_addr()
    }
}
/// Opens outbound connections, yielding one [`NetworkStream`] per successful connect.
pub trait StreamProvider {
    /// Stream type produced by [`StreamProvider::connect`].
    type Item: NetworkStream;

    /// Connects to `addr`, resolving to the connected stream or an I/O error.
    fn connect<A>(addr: A) -> impl std::future::Future<Output = Result<Self::Item>> + Send
    where
        A: ToSocketAddrs + Send;
}
pub struct TcpStreamProvider;
impl StreamProvider for TcpStreamProvider {
type Item = TcpStreamImpl;
async fn connect<A: ToSocketAddrs + Send>(addr: A) -> Result<Self::Item> {
Ok(TcpStreamImpl(each_addr(addr, TcpStream::connect).await?))
}
}
pub struct UdpStreamProvider;
impl StreamProvider for UdpStreamProvider {
type Item = UdpStreamImpl;
async fn connect<A: ToSocketAddrs + Send>(addr: A) -> Result<Self::Item> {
Ok(UdpStreamImpl(UdpStream::connect(addr).await?))
}
}
/// Binds listeners whose accepted connections are exposed through [`StreamAccept`].
pub trait ListenerProvider {
    /// Listener type produced by [`ListenerProvider::bind`].
    type Listener: StreamAccept + 'static;

    /// Binds a listener to `addr`, resolving to the listener or an I/O error.
    fn bind<A>(addr: A) -> impl std::future::Future<Output = Result<Self::Listener>> + Send
    where
        A: ToSocketAddrs + Send;
}
/// Accepts inbound connections from a bound listener.
pub trait StreamAccept {
    /// Stream type yielded for each accepted connection.
    type Item: NetworkStream;

    /// Waits for the next connection and returns it together with the address the
    /// underlying listener reports for it.
    fn accept(
        &self,
    ) -> impl std::future::Future<Output = Result<(Self::Item, SocketAddr)>> + Send;
}
/// [`ListenerProvider`] that binds TCP listeners.
pub struct TcpListenerProvider;

impl ListenerProvider for TcpListenerProvider {
    type Listener = TcpListenerImpl;

    async fn bind<A: ToSocketAddrs + Send>(addr: A) -> Result<Self::Listener> {
        // As with connecting, address handling is delegated to `each_addr`.
        let listener = each_addr(addr, TcpListener::bind).await?;
        Ok(TcpListenerImpl(listener))
    }
}
/// Listener wrapper around a Tokio `TcpListener`.
pub struct TcpListenerImpl(TcpListener);

impl StreamAccept for TcpListenerImpl {
    type Item = TcpStreamImpl;

    async fn accept(&self) -> Result<(Self::Item, SocketAddr)> {
        // `TcpListener::accept` yields `(TcpStream, peer address)`; wrap the stream.
        self.0
            .accept()
            .await
            .map(|(socket, peer)| (TcpStreamImpl::new(socket), peer))
    }
}
/// [`ListenerProvider`] that binds the crate's UDP listener.
pub struct UdpListenerProvider;

impl ListenerProvider for UdpListenerProvider {
    type Listener = UdpListenerImpl;

    async fn bind<A: ToSocketAddrs + Send>(addr: A) -> Result<Self::Listener> {
        let listener = UdpListener::bind(addr).await?;
        Ok(UdpListenerImpl(listener))
    }
}
/// Listener wrapper around the crate's `UdpListener`.
pub struct UdpListenerImpl(UdpListener);

impl StreamAccept for UdpListenerImpl {
    type Item = UdpStreamImpl;

    async fn accept(&self) -> Result<(Self::Item, SocketAddr)> {
        let Self(listener) = self;
        // `?` is kept (rather than `map`) so any error conversion the original relied on
        // is preserved.
        let (stream, remote) = listener.accept().await?;
        Ok((UdpStreamImpl::new(stream), remote))
    }
}