#![deny(missing_docs, warnings)]
#![cfg_attr(docsrs, feature(doc_cfg))]
pub mod extensions;
#[cfg(feature = "fastwebsockets")]
#[cfg_attr(docsrs, doc(cfg(feature = "fastwebsockets")))]
mod fastwebsockets;
mod packet;
mod sink_unfold;
mod stream;
pub mod ws;
pub use crate::{packet::*, stream::*};
use bytes::Bytes;
use dashmap::DashMap;
use event_listener::Event;
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder};
use flume as mpsc;
use futures::{channel::oneshot, select, Future, FutureExt};
use futures_timer::Delay;
use std::{
sync::{
atomic::{AtomicBool, AtomicU32, Ordering},
Arc,
},
time::Duration,
};
use ws::{AppendingWebSocketRead, LockedWebSocketWrite};
/// Version of the Wisp protocol implemented by this crate (2.0).
pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 };
/// Which side of a Wisp connection a party is on.
///
/// Passed when building protocol extensions and when creating streams.
// `Eq` and `Hash` are derived alongside `PartialEq` so the role can be used as a
// map/set key; all are trivially correct for a fieldless enum.
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
pub enum Role {
    /// The side that opens streams (see `ClientMux::client_new_stream`).
    Client,
    /// The side that accepts streams (see `ServerMux::server_new_stream`).
    Server,
}
/// Errors produced by the Wisp multiplexor, its streams, and its transports.
#[derive(Debug)]
pub enum WispError {
    /// The packet was too small to parse.
    PacketTooSmall,
    /// The packet type was invalid or not allowed at this point of the protocol.
    InvalidPacketType,
    /// The stream ID was invalid (for example unknown, or nonzero where 0 was required).
    InvalidStreamId,
    /// The close reason was invalid.
    InvalidCloseReason,
    /// The URI was invalid.
    InvalidUri,
    /// The URI had no host.
    UriHasNoHost,
    /// The URI had no port.
    UriHasNoPort,
    /// No more stream IDs can be allocated.
    MaxStreamCountReached,
    /// The peer's Wisp protocol version is incompatible.
    IncompatibleProtocolVersion,
    /// The stream has already been closed.
    StreamAlreadyClosed,
    /// A websocket frame had an invalid type.
    WsFrameInvalidType,
    /// A websocket frame was not finished.
    WsFrameNotFinished,
    /// An error reported by the websocket implementation.
    WsImplError(Box<dyn std::error::Error + Sync + Send>),
    /// The websocket implementation reported that the socket is closed.
    WsImplSocketClosed,
    /// The websocket implementation does not support a requested feature.
    WsImplNotSupported,
    /// An error reported by a protocol extension implementation.
    ExtensionImplError(Box<dyn std::error::Error + Sync + Send>),
    /// A protocol extension implementation does not support a requested feature.
    ExtensionImplNotSupported,
    /// The listed protocol extension IDs are required but were not negotiated.
    ExtensionsNotSupported(Vec<u8>),
    /// UTF-8 decoding failed.
    Utf8Error(std::str::Utf8Error),
    /// An integer conversion failed.
    TryFromIntError(std::num::TryFromIntError),
    /// Any other error.
    Other(Box<dyn std::error::Error + Sync + Send>),
    /// Sending a message over one of the multiplexor's internal channels failed.
    MuxMessageFailedToSend,
    /// Receiving a reply over one of the multiplexor's internal channels failed.
    MuxMessageFailedToRecv,
    /// The multiplexor future has already exited.
    MuxTaskEnded,
}
impl From<std::str::Utf8Error> for WispError {
fn from(err: std::str::Utf8Error) -> Self {
Self::Utf8Error(err)
}
}
impl From<std::num::TryFromIntError> for WispError {
fn from(value: std::num::TryFromIntError) -> Self {
Self::TryFromIntError(value)
}
}
impl std::fmt::Display for WispError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
match self {
Self::PacketTooSmall => write!(f, "Packet too small"),
Self::InvalidPacketType => write!(f, "Invalid packet type"),
Self::InvalidStreamId => write!(f, "Invalid stream id"),
Self::InvalidCloseReason => write!(f, "Invalid close reason"),
Self::InvalidUri => write!(f, "Invalid URI"),
Self::UriHasNoHost => write!(f, "URI has no host"),
Self::UriHasNoPort => write!(f, "URI has no port"),
Self::MaxStreamCountReached => write!(f, "Maximum stream count reached"),
Self::IncompatibleProtocolVersion => write!(f, "Incompatible Wisp protocol version"),
Self::StreamAlreadyClosed => write!(f, "Stream already closed"),
Self::WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
Self::WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
Self::WsImplError(err) => write!(f, "Websocket implementation error: {}", err),
Self::WsImplSocketClosed => {
write!(f, "Websocket implementation error: websocket closed")
}
Self::WsImplNotSupported => {
write!(f, "Websocket implementation error: unsupported feature")
}
Self::ExtensionImplError(err) => {
write!(f, "Protocol extension implementation error: {}", err)
}
Self::ExtensionImplNotSupported => {
write!(
f,
"Protocol extension implementation error: unsupported feature"
)
}
Self::ExtensionsNotSupported(list) => {
write!(f, "Protocol extensions {:?} not supported", list)
}
Self::Utf8Error(err) => write!(f, "UTF-8 error: {}", err),
Self::TryFromIntError(err) => write!(f, "Integer conversion error: {}", err),
Self::Other(err) => write!(f, "Other error: {}", err),
Self::MuxMessageFailedToSend => write!(f, "Failed to send multiplexor message"),
Self::MuxMessageFailedToRecv => write!(f, "Failed to receive multiplexor message"),
Self::MuxTaskEnded => write!(f, "Multiplexor task ended"),
}
}
}
impl std::error::Error for WispError {}
/// Mux-side bookkeeping for one open stream.
///
/// The `Arc` fields are shared with the matching `MuxStream` created alongside
/// this value in `MuxInner::create_new_stream`.
struct MuxMapValue {
    // Sender half of the stream's bounded data channel; incoming payloads are pushed here.
    stream: mpsc::Sender<Bytes>,
    // TCP or UDP; only TCP streams touch the flow-control counter in the read loops.
    stream_type: StreamType,
    // Flow-control counter shared with the stream: decremented per data packet on the
    // server, overwritten from Continue packets on the client.
    flow_control: Arc<AtomicU32>,
    // Notified when the client updates flow_control, and when the stream is closed.
    flow_control_event: Arc<Event>,
    // Set to true once the stream has been closed.
    is_closed: Arc<AtomicBool>,
    // Notified after is_closed is set.
    is_closed_event: Arc<Event>,
}
/// State owned by the future that drives a multiplexor (read loop + event loop).
struct MuxInner {
    // Shared websocket writer, also cloned into every stream.
    tx: ws::LockedWebSocketWrite,
    // Open streams keyed by stream ID.
    stream_map: DashMap<u32, MuxMapValue>,
    // Capacity of each stream's data channel and the initial flow-control value.
    buffer_size: u32,
    // Set to true when the mux future exits; shared with the public mux handle.
    fut_exited: Arc<AtomicBool>
}
impl MuxInner {
    /// Runs a server-side multiplexor to completion.
    ///
    /// Races the server read loop against the event loop (see `as_future`). Every
    /// stream the client opens is forwarded to `muxstream_sender`. `close_tx` is handed
    /// to new streams so they can send events back to the event loop.
    pub async fn server_into_future<R>(
        self,
        rx: R,
        extensions: Vec<AnyProtocolExtension>,
        close_rx: mpsc::Receiver<WsEvent>,
        muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>,
        close_tx: mpsc::Sender<WsEvent>,
    ) -> Result<(), WispError>
    where
        R: ws::WebSocketRead + Send,
    {
        self.as_future(
            close_rx,
            close_tx.clone(),
            self.server_loop(rx, extensions, muxstream_sender, close_tx),
        )
        .await
    }

    /// Runs a client-side multiplexor to completion.
    ///
    /// Races the client read loop against the event loop (see `as_future`).
    pub async fn client_into_future<R>(
        self,
        rx: R,
        extensions: Vec<AnyProtocolExtension>,
        close_rx: mpsc::Receiver<WsEvent>,
        close_tx: mpsc::Sender<WsEvent>,
    ) -> Result<(), WispError>
    where
        R: ws::WebSocketRead + Send,
    {
        self.as_future(close_rx, close_tx, self.client_loop(rx, extensions))
            .await
    }

    /// Runs the event loop and `wisp_fut` concurrently until either finishes, then
    /// tears the mux down.
    ///
    /// Teardown marks the mux as exited, closes and wakes every open stream, clears the
    /// stream map, and closes the websocket writer (ignoring any close error). Returns
    /// `wisp_fut`'s result, or `Ok(())` if the event loop finished first.
    async fn as_future(
        &self,
        close_rx: mpsc::Receiver<WsEvent>,
        close_tx: mpsc::Sender<WsEvent>,
        wisp_fut: impl Future<Output = Result<(), WispError>>,
    ) -> Result<(), WispError> {
        let ret = futures::select! {
            _ = self.stream_loop(close_rx, close_tx).fuse() => Ok(()),
            x = wisp_fut.fuse() => x,
        };
        self.fut_exited.store(true, Ordering::Release);
        // Only shared access is needed here; the atomics and events do their own syncing.
        for entry in self.stream_map.iter() {
            // Wake flow-control waiters too (as close_stream does), not only close waiters.
            Self::mark_stream_closed(entry.value());
        }
        self.stream_map.clear();
        let _ = self.tx.close().await;
        ret
    }

    /// Marks a stream as closed and wakes all listeners on both of its events.
    ///
    /// The flow-control counter is saturated so anything gated on it is released
    /// and can then observe `is_closed`.
    fn mark_stream_closed(stream: &MuxMapValue) {
        stream.is_closed.store(true, Ordering::Release);
        stream.is_closed_event.notify(usize::MAX);
        stream.flow_control.store(u32::MAX, Ordering::Release);
        stream.flow_control_event.notify(usize::MAX);
    }

    /// Builds the map entry and `MuxStream` for a new stream.
    ///
    /// The two halves share a bounded data channel (capacity `buffer_size`), a
    /// flow-control counter initialised to `buffer_size`, and the close flag/events.
    async fn create_new_stream(
        &self,
        stream_id: u32,
        stream_type: StreamType,
        role: Role,
        stream_tx: mpsc::Sender<WsEvent>,
        tx: LockedWebSocketWrite,
        target_buffer_size: u32,
    ) -> Result<(MuxMapValue, MuxStream), WispError> {
        let (ch_tx, ch_rx) = mpsc::bounded(self.buffer_size as usize);
        let flow_control_event: Arc<Event> = Event::new().into();
        let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into();
        let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
        let is_closed_event: Arc<Event> = Event::new().into();
        Ok((
            MuxMapValue {
                stream: ch_tx,
                stream_type,
                flow_control: flow_control.clone(),
                flow_control_event: flow_control_event.clone(),
                is_closed: is_closed.clone(),
                is_closed_event: is_closed_event.clone(),
            },
            MuxStream::new(
                stream_id,
                role,
                stream_type,
                ch_rx,
                stream_tx,
                tx,
                is_closed,
                is_closed_event,
                flow_control,
                flow_control_event,
                target_buffer_size,
            ),
        ))
    }

    /// Processes events sent to the mux.
    ///
    /// Handles three kinds of event:
    /// - `CreateStream`: allocates the next stream ID, sends a Connect packet and
    ///   replies with the new stream.
    /// - `Close`: forwards a close packet for a known stream and removes it.
    /// - `EndFut`: optionally sends a close packet on stream 0, then stops.
    ///
    /// Also returns when every sender of `stream_rx` has been dropped.
    async fn stream_loop(
        &self,
        stream_rx: mpsc::Receiver<WsEvent>,
        stream_tx: mpsc::Sender<WsEvent>,
    ) {
        let mut next_free_stream_id: u32 = 1;
        while let Ok(msg) = stream_rx.recv_async().await {
            match msg {
                WsEvent::CreateStream(stream_type, host, port, channel) => {
                    let ret: Result<MuxStream, WispError> = async {
                        let stream_id = next_free_stream_id;
                        let next_stream_id = next_free_stream_id
                            .checked_add(1)
                            .ok_or(WispError::MaxStreamCountReached)?;
                        let (map_value, stream) = self
                            .create_new_stream(
                                stream_id,
                                stream_type,
                                Role::Client,
                                stream_tx.clone(),
                                self.tx.clone(),
                                0,
                            )
                            .await?;
                        self.tx
                            .write_frame(
                                Packet::new_connect(stream_id, stream_type, port, host).into(),
                            )
                            .await?;
                        self.stream_map.insert(stream_id, map_value);
                        // Only consume the ID once the Connect packet was actually sent.
                        next_free_stream_id = next_stream_id;
                        Ok(stream)
                    }
                    .await;
                    let _ = channel.send(ret);
                }
                WsEvent::Close(packet, channel) => {
                    if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
                        let _ = channel.send(self.tx.write_frame(packet.into()).await);
                        drop(stream.stream)
                    } else {
                        let _ = channel.send(Err(WispError::InvalidStreamId));
                    }
                }
                WsEvent::EndFut(x) => {
                    if let Some(reason) = x {
                        let _ = self
                            .tx
                            .write_frame(Packet::new_close(0, reason).into())
                            .await;
                    }
                    break;
                }
            }
        }
    }

    /// Removes a stream the peer closed and wakes everything waiting on it.
    fn close_stream(&self, packet: Packet) {
        if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
            Self::mark_stream_closed(&stream);
            // Dropping the entry drops the map's data sender for this stream.
            drop(stream);
        }
    }

    /// Reads and dispatches packets from a client until the websocket closes, the
    /// client closes stream 0, or an error occurs.
    async fn server_loop<R>(
        &self,
        mut rx: R,
        mut extensions: Vec<AnyProtocolExtension>,
        muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>,
        stream_tx: mpsc::Sender<WsEvent>,
    ) -> Result<(), WispError>
    where
        R: ws::WebSocketRead + Send,
    {
        // 90% of the buffer, passed to server-side streams as their target buffer size
        // (presumably the threshold for sending Continue packets — confirm in MuxStream).
        let target_buffer_size = ((self.buffer_size as u64 * 90) / 100) as u32;
        loop {
            let frame = rx.wisp_read_frame(&self.tx).await?;
            if frame.opcode == ws::OpCode::Close {
                break Ok(());
            }
            if let Some(packet) =
                Packet::maybe_handle_extension(frame, &mut extensions, &mut rx, &self.tx).await?
            {
                use PacketType::*;
                match packet.packet_type {
                    Continue(_) | Info(_) => break Err(WispError::InvalidPacketType),
                    Connect(inner_packet) => {
                        let (map_value, stream) = self
                            .create_new_stream(
                                packet.stream_id,
                                inner_packet.stream_type,
                                Role::Server,
                                stream_tx.clone(),
                                self.tx.clone(),
                                target_buffer_size,
                            )
                            .await?;
                        // Register the stream before handing it out, so it is already
                        // routable by the time the application can see it.
                        self.stream_map.insert(packet.stream_id, map_value);
                        muxstream_sender
                            .send_async((inner_packet, stream))
                            .await
                            .map_err(|_| WispError::MuxMessageFailedToSend)?;
                    }
                    Data(data) => {
                        if let Some(stream) = self.stream_map.get(&packet.stream_id) {
                            // try_send: if the stream's channel is full the payload is dropped.
                            let _ = stream.stream.try_send(data);
                            if stream.stream_type == StreamType::Tcp {
                                // The counter is shared with the MuxStream, so decrement it
                                // atomically; a separate load + store could overwrite a
                                // concurrent update made in between.
                                let _ = stream.flow_control.fetch_update(
                                    Ordering::AcqRel,
                                    Ordering::Acquire,
                                    |remaining| Some(remaining.saturating_sub(1)),
                                );
                            }
                        }
                    }
                    Close(_) => {
                        if packet.stream_id == 0 {
                            break Ok(());
                        }
                        self.close_stream(packet)
                    }
                }
            }
        }
    }

    /// Reads and dispatches packets from a server until the websocket closes, the
    /// server closes stream 0, or an error occurs.
    async fn client_loop<R>(
        &self,
        mut rx: R,
        mut extensions: Vec<AnyProtocolExtension>,
    ) -> Result<(), WispError>
    where
        R: ws::WebSocketRead + Send,
    {
        loop {
            let frame = rx.wisp_read_frame(&self.tx).await?;
            if frame.opcode == ws::OpCode::Close {
                break Ok(());
            }
            if let Some(packet) =
                Packet::maybe_handle_extension(frame, &mut extensions, &mut rx, &self.tx).await?
            {
                use PacketType::*;
                match packet.packet_type {
                    Connect(_) | Info(_) => break Err(WispError::InvalidPacketType),
                    Data(data) => {
                        // Clone the sender out so the DashMap shard guard is released
                        // before awaiting. Holding the guard across the await could
                        // deadlock: stream_loop inserts into and removes from the same map
                        // while running concurrently in the same future.
                        let sender = self
                            .stream_map
                            .get(&packet.stream_id)
                            .map(|stream| stream.stream.clone());
                        if let Some(sender) = sender {
                            let _ = sender.send_async(data).await;
                        }
                    }
                    Continue(inner_packet) => {
                        if let Some(stream) = self.stream_map.get(&packet.stream_id) {
                            if stream.stream_type == StreamType::Tcp {
                                stream
                                    .flow_control
                                    .store(inner_packet.buffer_remaining, Ordering::Release);
                                let _ = stream.flow_control_event.notify(usize::MAX);
                            }
                        }
                    }
                    Close(_) => {
                        if packet.stream_id == 0 {
                            break Ok(());
                        }
                        self.close_stream(packet)
                    }
                }
            }
        }
    }
}
/// Waits for the peer's Wisp v2 `Info` packet and negotiates extensions.
///
/// Waits up to 5 seconds for the first frame:
/// - If it is an `Info` packet, the advertised extensions whose IDs match one of
///   `builders` are kept. Each kept extension then runs its handshake, and the
///   connection is not downgraded.
/// - If any other packet arrives, it is returned as a leftover frame for the caller
///   to process, and the connection is reported as downgraded.
/// - On timeout, there is no leftover frame and the connection is downgraded.
///
/// Returns `(negotiated extensions, leftover frame, downgraded)`.
///
/// NOTE(review): `maybe_parse_info` always receives `Role::Client`, although both
/// client and server call this function — confirm what that argument represents.
async fn maybe_wisp_v2<R>(
    read: &mut R,
    write: &LockedWebSocketWrite,
    builders: &[Box<dyn ProtocolExtensionBuilder + Sync + Send>],
) -> Result<(Vec<AnyProtocolExtension>, Option<ws::Frame>, bool), WispError>
where
    R: ws::WebSocketRead + Send,
{
    let known_ids: Vec<_> = builders.iter().map(|builder| builder.get_id()).collect();

    let first_frame = select! {
        x = read.wisp_read_frame(write).fuse() => Some(x?),
        _ = Delay::new(Duration::from_secs(5)).fuse() => None
    };

    let (mut supported_extensions, extra_packet, downgraded): (
        Vec<AnyProtocolExtension>,
        Option<ws::Frame>,
        bool,
    ) = match first_frame {
        None => (Vec::new(), None, true),
        Some(frame) => {
            let packet = Packet::maybe_parse_info(frame, Role::Client, builders)?;
            if let PacketType::Info(info) = packet.packet_type {
                let accepted = info
                    .extensions
                    .into_iter()
                    .filter(|ext| known_ids.contains(&ext.get_id()))
                    .collect();
                (accepted, None, false)
            } else {
                (Vec::new(), Some(packet.into()), true)
            }
        }
    };

    for extension in &mut supported_extensions {
        extension.handle_handshake(read, write).await?;
    }
    Ok((supported_extensions, extra_packet, downgraded))
}
/// Server-side handle to a Wisp multiplexor.
///
/// Created by `ServerMux::create`. The future returned alongside it must be polled
/// for the multiplexor to make progress.
pub struct ServerMux {
    /// True when no Wisp v2 `Info` packet was negotiated with the client, either
    /// because no extension builders were supplied or because the client sent none.
    pub downgraded: bool,
    /// IDs of the protocol extensions offered by the client that this side also supports.
    pub supported_extension_ids: Vec<u8>,
    // Sender into the mux event loop; used here to request shutdown (EndFut).
    close_tx: mpsc::Sender<WsEvent>,
    // Streams opened by the client, delivered by the mux future.
    muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>,
    // Shared websocket writer.
    tx: ws::LockedWebSocketWrite,
    // Set to true once the mux future has exited.
    fut_exited: Arc<AtomicBool>,
}
impl ServerMux {
    /// Creates a server-side multiplexor over the given websocket halves.
    ///
    /// First sends a `Continue` packet on stream 0 advertising `buffer_size`. Then:
    /// - If `extension_builders` is provided, sends an `Info` packet with this
    ///   server's extensions and waits for the client's `Info` packet via
    ///   `maybe_wisp_v2`.
    /// - Otherwise, the connection is treated as downgraded.
    ///
    /// Returns the handle together with the future that drives the mux. Use the
    /// `ServerMuxResult` methods to check required extensions.
    ///
    /// # Errors
    /// Propagates websocket write/read errors and handshake errors.
    pub async fn create<R, W>(
        mut read: R,
        write: W,
        buffer_size: u32,
        extension_builders: Option<&[Box<dyn ProtocolExtensionBuilder + Send + Sync>]>,
    ) -> Result<ServerMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
    where
        R: ws::WebSocketRead + Send,
        W: ws::WebSocketWrite + Send + 'static,
    {
        let (close_tx, close_rx) = mpsc::bounded::<WsEvent>(256);
        let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
        let write = ws::LockedWebSocketWrite::new(Box::new(write));
        let fut_exited = Arc::new(AtomicBool::new(false));

        write
            .write_frame(Packet::new_continue(0, buffer_size).into())
            .await?;

        let (supported_extensions, extra_packet, downgraded) =
            if let Some(builders) = extension_builders {
                // These are the server's own advertised extensions, so build them for
                // the server role (the client path builds with Role::Client).
                write
                    .write_frame(
                        Packet::new_info(
                            builders
                                .iter()
                                .map(|x| x.build_to_extension(Role::Server))
                                .collect(),
                        )
                        .into(),
                    )
                    .await?;
                maybe_wisp_v2(&mut read, &write, builders).await?
            } else {
                (Vec::new(), None, true)
            };

        Ok(ServerMuxResult(
            Self {
                muxstream_recv: rx,
                close_tx: close_tx.clone(),
                downgraded,
                supported_extension_ids: supported_extensions.iter().map(|x| x.get_id()).collect(),
                tx: write.clone(),
                fut_exited: fut_exited.clone(),
            },
            MuxInner {
                tx: write,
                stream_map: DashMap::new(),
                buffer_size,
                fut_exited,
            }
            .server_into_future(
                AppendingWebSocketRead(extra_packet, read),
                supported_extensions,
                close_rx,
                tx,
                close_tx,
            ),
        ))
    }

    /// Waits for the next stream opened by the client.
    ///
    /// Returns `None` if the mux future has already exited or exits while waiting.
    pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> {
        if self.fut_exited.load(Ordering::Acquire) {
            return None;
        }
        self.muxstream_recv.recv_async().await.ok()
    }

    /// Asks the mux event loop to stop, optionally sending a close packet with `reason`
    /// on stream 0 first.
    async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
        if self.fut_exited.load(Ordering::Acquire) {
            return Err(WispError::MuxTaskEnded);
        }
        self.close_tx
            .send_async(WsEvent::EndFut(reason))
            .await
            .map_err(|_| WispError::MuxMessageFailedToSend)
    }

    /// Closes the multiplexor without sending a close reason.
    ///
    /// # Errors
    /// `MuxTaskEnded` if the mux already exited, `MuxMessageFailedToSend` if the
    /// request could not be delivered.
    pub async fn close(&self) -> Result<(), WispError> {
        self.close_internal(None).await
    }

    /// Closes the multiplexor, telling the client its extensions are incompatible.
    ///
    /// # Errors
    /// Same as `close`.
    pub async fn close_extension_incompat(&self) -> Result<(), WispError> {
        self.close_internal(Some(CloseReason::IncompatibleExtensions))
            .await
    }

    /// Returns a handle for protocol extensions to write on stream 0.
    ///
    /// The handle reports itself closed once the mux future has exited.
    pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
        MuxProtocolExtensionStream {
            stream_id: 0,
            tx: self.tx.clone(),
            is_closed: self.fut_exited.clone(),
        }
    }
}
impl Drop for ServerMux {
    /// Best-effort request for the mux event loop to stop.
    ///
    /// A failed send (e.g. the mux future already exited) is deliberately ignored.
    fn drop(&mut self) {
        let shutdown = WsEvent::EndFut(None);
        let _ = self.close_tx.send(shutdown);
    }
}
/// Result of `ServerMux::create`: the mux handle plus the future that drives it.
///
/// Unwrap it with one of the `with_*` methods, optionally enforcing that certain
/// protocol extensions were negotiated.
pub struct ServerMuxResult<F>(ServerMux, F)
where
    F: Future<Output = Result<(), WispError>> + Send;
impl<F> ServerMuxResult<F>
where
F: Future<Output = Result<(), WispError>> + Send,
{
pub fn with_no_required_extensions(self) -> (ServerMux, F) {
(self.0, self.1)
}
pub async fn with_required_extensions(
self,
extensions: &[u8],
) -> Result<(ServerMux, F), WispError> {
let mut unsupported_extensions = Vec::new();
for extension in extensions {
if !self.0.supported_extension_ids.contains(extension) {
unsupported_extensions.push(*extension);
}
}
if unsupported_extensions.is_empty() {
Ok((self.0, self.1))
} else {
self.0.close_extension_incompat().await?;
self.1.await?;
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
}
}
pub async fn with_udp_extension_required(self) -> Result<(ServerMux, F), WispError> {
self.with_required_extensions(&[UdpProtocolExtension::ID])
.await
}
}
/// Client-side handle to a Wisp multiplexor.
///
/// Created by `ClientMux::create`. The future returned alongside it must be polled
/// for the multiplexor to make progress.
pub struct ClientMux {
    /// True when no Wisp v2 `Info` packet was received from the server, either
    /// because no extension builders were supplied or because the server sent none.
    pub downgraded: bool,
    /// IDs of the protocol extensions offered by the server that this side also supports.
    pub supported_extension_ids: Vec<u8>,
    // Sender into the mux event loop (stream creation and shutdown requests).
    stream_tx: mpsc::Sender<WsEvent>,
    // Shared websocket writer.
    tx: ws::LockedWebSocketWrite,
    // Set to true once the mux future has exited.
    fut_exited: Arc<AtomicBool>,
}
impl ClientMux {
pub async fn create<R, W>(
mut read: R,
write: W,
extension_builders: Option<&[Box<dyn ProtocolExtensionBuilder + Send + Sync>]>,
) -> Result<ClientMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
where
R: ws::WebSocketRead + Send,
W: ws::WebSocketWrite + Send + 'static,
{
let write = ws::LockedWebSocketWrite::new(Box::new(write));
let first_packet = Packet::try_from(read.wisp_read_frame(&write).await?)?;
let fut_exited = Arc::new(AtomicBool::new(false));
if first_packet.stream_id != 0 {
return Err(WispError::InvalidStreamId);
}
if let PacketType::Continue(packet) = first_packet.packet_type {
let (supported_extensions, extra_packet, downgraded) =
if let Some(builders) = extension_builders {
let x = maybe_wisp_v2(&mut read, &write, builders).await?;
write
.write_frame(
Packet::new_info(
builders
.iter()
.map(|x| x.build_to_extension(Role::Client))
.collect(),
)
.into(),
)
.await?;
x
} else {
(Vec::new(), None, true)
};
let (tx, rx) = mpsc::bounded::<WsEvent>(256);
Ok(ClientMuxResult(
Self {
stream_tx: tx.clone(),
downgraded,
supported_extension_ids: supported_extensions
.iter()
.map(|x| x.get_id())
.collect(),
tx: write.clone(),
fut_exited: fut_exited.clone(),
},
MuxInner {
tx: write,
stream_map: DashMap::new(),
buffer_size: packet.buffer_remaining,
fut_exited
}
.client_into_future(
AppendingWebSocketRead(extra_packet, read),
supported_extensions,
rx,
tx,
),
))
} else {
Err(WispError::InvalidPacketType)
}
}
pub async fn client_new_stream(
&self,
stream_type: StreamType,
host: String,
port: u16,
) -> Result<MuxStream, WispError> {
if self.fut_exited.load(Ordering::Acquire) {
return Err(WispError::MuxTaskEnded);
}
if stream_type == StreamType::Udp
&& !self
.supported_extension_ids
.iter()
.any(|x| *x == UdpProtocolExtension::ID)
{
return Err(WispError::ExtensionsNotSupported(vec![
UdpProtocolExtension::ID,
]));
}
let (tx, rx) = oneshot::channel();
self.stream_tx
.send_async(WsEvent::CreateStream(stream_type, host, port, tx))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)?;
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
}
async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
if self.fut_exited.load(Ordering::Acquire) {
return Err(WispError::MuxTaskEnded);
}
self.stream_tx
.send_async(WsEvent::EndFut(reason))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)
}
pub async fn close(&self) -> Result<(), WispError> {
self.close_internal(None).await
}
pub async fn close_extension_incompat(&self) -> Result<(), WispError> {
self.close_internal(Some(CloseReason::IncompatibleExtensions))
.await
}
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
MuxProtocolExtensionStream {
stream_id: 0,
tx: self.tx.clone(),
is_closed: self.fut_exited.clone(),
}
}
}
impl Drop for ClientMux {
    /// Best-effort request for the mux event loop to stop.
    ///
    /// A failed send (e.g. the mux future already exited) is deliberately ignored.
    fn drop(&mut self) {
        let shutdown = WsEvent::EndFut(None);
        let _ = self.stream_tx.send(shutdown);
    }
}
/// Result of `ClientMux::create`: the mux handle plus the future that drives it.
///
/// Unwrap it with one of the `with_*` methods, optionally enforcing that certain
/// protocol extensions were negotiated.
pub struct ClientMuxResult<F>(ClientMux, F)
where
    F: Future<Output = Result<(), WispError>> + Send;
impl<F> ClientMuxResult<F>
where
F: Future<Output = Result<(), WispError>> + Send,
{
pub fn with_no_required_extensions(self) -> (ClientMux, F) {
(self.0, self.1)
}
pub async fn with_required_extensions(
self,
extensions: &[u8],
) -> Result<(ClientMux, F), WispError> {
let mut unsupported_extensions = Vec::new();
for extension in extensions {
if !self.0.supported_extension_ids.contains(extension) {
unsupported_extensions.push(*extension);
}
}
if unsupported_extensions.is_empty() {
Ok((self.0, self.1))
} else {
self.0.close_extension_incompat().await?;
self.1.await?;
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
}
}
pub async fn with_udp_extension_required(self) -> Result<(ClientMux, F), WispError> {
self.with_required_extensions(&[UdpProtocolExtension::ID])
.await
}
}