use crate::{ import::*, tung_websocket::TungWebSocket, WsEvent, WsErr };
/// Byte-stream adapter around a Tungstenite websocket.
///
/// The I/O trait implementations in this file (`AsyncRead`, `AsyncWrite`, `AsyncBufRead`
/// and, behind the `tokio_io` feature, their tokio counterparts) delegate to the wrapped
/// `IoStream`.
pub struct WsStream<S> where S: AsyncRead + AsyncWrite + Send + Unpin
{
/// The wrapped stream that performs the actual reads and writes.
inner: IoStream< TungWebSocket<S>, Vec<u8> >,
/// Upper bound, in bytes, used to cap how much data a single write hands to `inner`.
/// Computed once in `new` from the socket's configuration.
buffer_size: usize,
}
impl<S> WsStream<S> where S: AsyncRead + AsyncWrite + Send + Unpin
{
    /// Wrap an `ATungSocket` in a `WsStream`.
    ///
    /// The per-write size limit is taken from the socket's configuration: it is the
    /// smaller of `max_write_buffer_size` and `max_message_size`, where an absent
    /// `max_message_size` imposes no additional limit.
    pub fn new( inner: ATungSocket<S> ) -> Self
    {
        let config    = inner.get_config();
        let write_cap = config.max_write_buffer_size;

        // `None` means "no message size limit", so only the write buffer cap applies.
        let buffer_size = match config.max_message_size
        {
            Some( max_message ) => write_cap.min( max_message ),
            None                => write_cap,
        };

        Self
        {
            inner: IoStream::new( TungWebSocket::new( inner ) ),
            buffer_size,
        }
    }
}
impl<S> fmt::Debug for WsStream<S> where S: AsyncRead + AsyncWrite + Send + Unpin
{
    /// Emits a fixed description; the wrapped socket is not inspected.
    fn fmt( &self, f: &mut fmt::Formatter<'_> ) -> fmt::Result
    {
        f.write_str( "WsStream over Tungstenite" )
    }
}
impl<S> AsyncWrite for WsStream<S> where S: AsyncRead + AsyncWrite + Send + Unpin
{
    /// Write at most `self.buffer_size` bytes of `buf` to the inner stream.
    ///
    /// The returned count may be smaller than `buf.len()` because the input is
    /// truncated to `self.buffer_size` before being forwarded.
    fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8] ) -> Poll< io::Result<usize> >
    {
        let len = std::cmp::min( self.buffer_size, buf.len() );

        AsyncWrite::poll_write( Pin::new( &mut self.inner ), cx, &buf[..len] )
    }

    /// Write as many leading slices of `bufs` as fit within `self.buffer_size`.
    ///
    /// * If every slice is empty (or `bufs` itself is empty), returns `Ok(0)` without
    ///   touching the inner stream.
    /// * If the first non-empty slice alone exceeds `self.buffer_size`, that slice is
    ///   handed to `poll_write`, which truncates it to the limit.
    /// * Otherwise the longest prefix of `bufs` whose combined length fits in the limit
    ///   is forwarded to the inner stream's vectored write.
    fn poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[ IoSlice<'_> ] ) -> Poll< io::Result<usize> >
    {
        let limit = self.buffer_size;

        // Total number of bytes in `bufs[..next]`.
        let mut take_size = 0usize;

        // Index of the first slice that is NOT part of the accepted prefix.
        let mut next = 0;

        for (i, buf) in bufs.iter().enumerate()
        {
            // `checked_add` so that absurdly large slice lengths cannot wrap around
            // and sneak past the limit.
            match take_size.checked_add( buf.len() )
            {
                Some( total ) if total <= limit =>
                {
                    take_size = total;
                    next      = i + 1;
                }

                _ => break,
            }
        }

        if take_size == 0
        {
            // Only empty slices were accepted, so `bufs[..next]` holds no data. Either
            // we ran off the end (nothing to write at all), or `bufs[next]` is the first
            // non-empty slice and it is too big on its own. Leading empty slices must be
            // skipped here: writing one of them would report a spurious zero-length write.
            return match bufs.get( next )
            {
                None => Poll::Ready( Ok(0) ),

                Some( big ) =>
                {
                    let oversized: &[u8] = big;
                    AsyncWrite::poll_write( self, cx, oversized )
                }
            };
        }

        AsyncWrite::poll_write_vectored( Pin::new( &mut self.inner ), cx, &bufs[..next] )
    }

    /// Flush the inner stream.
    fn poll_flush( mut self: Pin<&mut Self>, cx: &mut Context<'_> ) -> Poll< io::Result<()> >
    {
        AsyncWrite::poll_flush( Pin::new( &mut self.inner ), cx )
    }

    /// Close the inner stream.
    fn poll_close( mut self: Pin<&mut Self>, cx: &mut Context<'_> ) -> Poll< io::Result<()> >
    {
        Pin::new( &mut self.inner ).poll_close( cx )
    }
}
#[ cfg( feature = "tokio_io" ) ]
#[ cfg_attr( nightly, doc(cfg( feature = "tokio_io" )) ) ]
impl<S> TokAsyncWrite for WsStream<S> where S: AsyncRead + AsyncWrite + Send + Unpin
{
    /// Write at most `self.buffer_size` bytes of `buf` to the inner stream.
    ///
    /// The input is truncated to `self.buffer_size` so that the tokio write path
    /// honours the same per-write limit as the futures `AsyncWrite` implementation.
    /// The returned count may therefore be smaller than `buf.len()`.
    fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8] ) -> Poll< io::Result<usize> >
    {
        let len = std::cmp::min( self.buffer_size, buf.len() );

        TokAsyncWrite::poll_write( Pin::new( &mut self.inner ), cx, &buf[..len] )
    }

    /// Flush the inner stream.
    fn poll_flush( mut self: Pin<&mut Self>, cx: &mut Context<'_> ) -> Poll< io::Result<()> >
    {
        TokAsyncWrite::poll_flush( Pin::new( &mut self.inner ), cx )
    }

    /// Shut down by closing the inner stream via its `poll_close`.
    fn poll_shutdown( mut self: Pin<&mut Self>, cx: &mut Context<'_> ) -> Poll< io::Result<()> >
    {
        Pin::new( &mut self.inner ).poll_close( cx )
    }
}
impl<S> AsyncRead for WsStream<S> where S: AsyncRead + AsyncWrite + Send + Unpin
{
    /// Read into `buf` by delegating to the inner stream.
    fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8] ) -> Poll< io::Result<usize> >
    {
        let this = self.get_mut();

        AsyncRead::poll_read( Pin::new( &mut this.inner ), cx, buf )
    }

    /// Vectored read, delegated to the inner stream.
    fn poll_read_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &mut [IoSliceMut<'_>] ) -> Poll< io::Result<usize> >
    {
        let this = self.get_mut();

        AsyncRead::poll_read_vectored( Pin::new( &mut this.inner ), cx, bufs )
    }
}
#[ cfg( feature = "tokio_io" ) ]
#[ cfg_attr( nightly, doc(cfg( feature = "tokio_io" )) ) ]
impl<S> TokAsyncRead for WsStream<S> where S: AsyncRead + AsyncWrite + Send + Unpin
{
    /// Read into the tokio `ReadBuf` by delegating to the inner stream.
    fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_> ) -> Poll< io::Result<()> >
    {
        let this = self.get_mut();

        TokAsyncRead::poll_read( Pin::new( &mut this.inner ), cx, buf )
    }
}
impl<S> AsyncBufRead for WsStream<S> where S: AsyncRead + AsyncWrite + Send + Unpin
{
fn poll_fill_buf( self: Pin<&mut Self>, cx: &mut Context<'_> ) -> Poll< io::Result<&[u8]> >
{
Pin::new( &mut self.get_mut().inner ).poll_fill_buf( cx )
}
fn consume( mut self: Pin<&mut Self>, amount: usize )
{
Pin::new( &mut self.inner ).consume( amount )
}
}
impl<S> Observable< WsEvent > for WsStream<S> where S: AsyncRead + AsyncWrite + Send + Unpin
{
    type Error = WsErr;

    /// Subscribe to `WsEvent`s by forwarding `options` to the inner stream.
    ///
    /// The inner `observe` call happens when the returned future is polled, and its
    /// error is converted into `WsErr`.
    fn observe( &mut self, options: ObserveConfig< WsEvent > ) -> Observe< '_, WsEvent, Self::Error >
    {
        let subscribe = async move
        {
            let outcome = self.inner.observe( options ).await;

            outcome.map_err( Into::into )
        };

        subscribe.boxed()
    }
}