use futures_util::{Sink, Stream, StreamExt};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio_tungstenite::{
self as tg,
tungstenite::{
error::*,
protocol::{frame::coding::Data, CloseFrame},
Message, Result,
},
MaybeTlsStream,
};
pub async fn connect(url: &str) -> crate::Result<WebSocketStream> {
let (inner, _response) = tg::connect_async(url).await?;
let inner = inner.filter_map(msg_conv as fn(_) -> _);
Ok(WebSocketStream { inner })
}
type Ws = tg::WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
type MsgConvFut = futures_util::future::Ready<Option<crate::Result<crate::Message>>>;
pub struct WebSocketStream {
inner: futures_util::stream::FilterMap<
Ws,
MsgConvFut,
fn(tg::tungstenite::Result<Message>) -> MsgConvFut,
>,
}
impl Stream for WebSocketStream {
type Item = crate::Result<crate::Message>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.inner).poll_next(cx)
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}
}
impl Sink<crate::Message> for WebSocketStream {
type Error = crate::Error;
fn poll_ready(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::result::Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_ready(cx).map_err(Into::into)
}
fn start_send(
mut self: Pin<&mut Self>,
item: crate::Message,
) -> std::result::Result<(), Self::Error> {
Pin::new(&mut self.inner)
.start_send(item.into())
.map_err(Into::into)
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::result::Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_flush(cx).map_err(Into::into)
}
fn poll_close(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::result::Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_close(cx).map_err(Into::into)
}
}
fn msg_conv(msg: Result<Message>) -> MsgConvFut {
fn inner(msg: Result<Message>) -> Option<crate::Result<crate::Message>> {
let msg = match msg {
Ok(msg) => match msg {
Message::Text(inner) => Ok(crate::Message::Text(inner)),
Message::Binary(inner) => Ok(crate::Message::Binary(inner)),
Message::Close(inner) => Ok(crate::Message::Close(inner.map(Into::into))),
Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => return None,
},
Err(err) => Err(crate::Error::from(err)),
};
Some(msg)
}
futures_util::future::ready(inner(msg))
}
impl<'a> From<CloseFrame<'a>> for crate::message::CloseFrame<'a> {
fn from(close_frame: CloseFrame<'a>) -> Self {
crate::message::CloseFrame {
code: u16::from(close_frame.code).into(),
reason: close_frame.reason,
}
}
}
impl<'a> From<crate::message::CloseFrame<'a>> for CloseFrame<'a> {
fn from(close_frame: crate::message::CloseFrame<'a>) -> Self {
CloseFrame {
code: u16::from(close_frame.code).into(),
reason: close_frame.reason,
}
}
}
impl From<Message> for crate::Message {
fn from(msg: Message) -> Self {
match msg {
Message::Text(inner) => crate::Message::Text(inner),
Message::Binary(inner) => crate::Message::Binary(inner),
Message::Close(inner) => crate::Message::Close(inner.map(Into::into)),
Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => {
unreachable!("Unsendable via interface.")
}
}
}
}
impl From<crate::Message> for Message {
fn from(msg: crate::Message) -> Self {
match msg {
crate::Message::Text(inner) => Message::Text(inner),
crate::Message::Binary(inner) => Message::Binary(inner),
crate::Message::Close(inner) => Message::Close(inner.map(Into::into)),
}
}
}
impl From<Error> for crate::Error {
fn from(err: Error) -> Self {
match err {
Error::ConnectionClosed => crate::Error::ConnectionClosed,
Error::AlreadyClosed => crate::Error::AlreadyClosed,
Error::Io(inner) => crate::Error::Io(inner),
Error::Tls(inner) => crate::Error::Tls(inner.into()),
Error::Capacity(inner) => crate::Error::Capacity(inner.into()),
Error::Protocol(inner) => crate::Error::Protocol(inner.into()),
Error::WriteBufferFull(inner) => crate::Error::WriteBufferFull(inner.into()),
Error::Utf8 => crate::Error::Utf8,
Error::AttackAttempt => crate::Error::AttackAttempt,
Error::Url(inner) => crate::Error::Url(inner.into()),
Error::Http(inner) => crate::Error::Http(inner),
Error::HttpFormat(inner) => crate::Error::HttpFormat(inner),
}
}
}
impl From<CapacityError> for crate::error::CapacityError {
fn from(err: CapacityError) -> Self {
match err {
CapacityError::TooManyHeaders => crate::error::CapacityError::TooManyHeaders,
CapacityError::MessageTooLong { size, max_size } => {
crate::error::CapacityError::MessageTooLong { size, max_size }
}
}
}
}
impl From<UrlError> for crate::error::UrlError {
fn from(err: UrlError) -> Self {
match err {
UrlError::TlsFeatureNotEnabled => crate::error::UrlError::TlsFeatureNotEnabled,
UrlError::NoHostName => crate::error::UrlError::NoHostName,
UrlError::UnableToConnect(inner) => crate::error::UrlError::UnableToConnect(inner),
UrlError::UnsupportedUrlScheme => crate::error::UrlError::UnsupportedUrlScheme,
UrlError::EmptyHostName => crate::error::UrlError::EmptyHostName,
UrlError::NoPathOrQuery => crate::error::UrlError::NoPathOrQuery,
}
}
}
impl From<ProtocolError> for crate::error::ProtocolError {
fn from(err: ProtocolError) -> Self {
match err {
ProtocolError::WrongHttpMethod => crate::error::ProtocolError::WrongHttpMethod,
ProtocolError::WrongHttpVersion => crate::error::ProtocolError::WrongHttpVersion,
ProtocolError::MissingConnectionUpgradeHeader => {
crate::error::ProtocolError::MissingConnectionUpgradeHeader
}
ProtocolError::MissingUpgradeWebSocketHeader => {
crate::error::ProtocolError::MissingUpgradeWebSocketHeader
}
ProtocolError::MissingSecWebSocketVersionHeader => {
crate::error::ProtocolError::MissingSecWebSocketVersionHeader
}
ProtocolError::MissingSecWebSocketKey => {
crate::error::ProtocolError::MissingSecWebSocketKey
}
ProtocolError::SecWebSocketAcceptKeyMismatch => {
crate::error::ProtocolError::SecWebSocketAcceptKeyMismatch
}
ProtocolError::JunkAfterRequest => crate::error::ProtocolError::JunkAfterRequest,
ProtocolError::CustomResponseSuccessful => {
crate::error::ProtocolError::CustomResponseSuccessful
}
ProtocolError::InvalidHeader(header_name) => {
crate::error::ProtocolError::InvalidHeader(header_name)
}
ProtocolError::HandshakeIncomplete => crate::error::ProtocolError::HandshakeIncomplete,
ProtocolError::HttparseError(inner) => {
crate::error::ProtocolError::HttparseError(inner)
}
ProtocolError::SendAfterClosing => crate::error::ProtocolError::SendAfterClosing,
ProtocolError::ReceivedAfterClosing => {
crate::error::ProtocolError::ReceivedAfterClosing
}
ProtocolError::NonZeroReservedBits => crate::error::ProtocolError::NonZeroReservedBits,
ProtocolError::UnmaskedFrameFromClient => {
crate::error::ProtocolError::UnmaskedFrameFromClient
}
ProtocolError::MaskedFrameFromServer => {
crate::error::ProtocolError::MaskedFrameFromServer
}
ProtocolError::FragmentedControlFrame => {
crate::error::ProtocolError::FragmentedControlFrame
}
ProtocolError::ControlFrameTooBig => crate::error::ProtocolError::ControlFrameTooBig,
ProtocolError::UnknownControlFrameType(inner) => {
crate::error::ProtocolError::UnknownControlFrameType(inner)
}
ProtocolError::UnknownDataFrameType(inner) => {
crate::error::ProtocolError::UnknownDataFrameType(inner)
}
ProtocolError::UnexpectedContinueFrame => {
crate::error::ProtocolError::UnexpectedContinueFrame
}
ProtocolError::ExpectedFragment(inner) => {
crate::error::ProtocolError::ExpectedFragment(inner.into())
}
ProtocolError::ResetWithoutClosingHandshake => {
crate::error::ProtocolError::ResetWithoutClosingHandshake
}
ProtocolError::InvalidOpcode(inner) => {
crate::error::ProtocolError::InvalidOpcode(inner)
}
ProtocolError::InvalidCloseSequence => {
crate::error::ProtocolError::InvalidCloseSequence
}
}
}
}
impl From<TlsError> for crate::error::TlsError {
fn from(err: TlsError) -> Self {
match err {
#[cfg(feature = "native-tls")]
TlsError::Native(inner) => crate::error::TlsError::Native(inner),
#[cfg(feature = "__rustls-tls")]
TlsError::Rustls(inner) => crate::error::TlsError::Rustls(inner),
#[cfg(feature = "__rustls-tls")]
TlsError::Webpki(inner) => crate::error::TlsError::Webpki(inner),
#[cfg(feature = "__rustls-tls")]
TlsError::InvalidDnsName => crate::error::TlsError::InvalidDnsName,
_ => crate::error::TlsError::Unknown,
}
}
}
impl From<Data> for crate::error::Data {
fn from(data: Data) -> Self {
match data {
Data::Continue => crate::error::Data::Continue,
Data::Text => crate::error::Data::Text,
Data::Binary => crate::error::Data::Binary,
Data::Reserved(inner) => crate::error::Data::Reserved(inner),
}
}
}