use crate::{
extensions::{AnyProtocolExtension, ProtocolExtensionBuilder},
ws::{self, Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead},
Role, WispError, WISP_VERSION,
};
use bytes::{Buf, BufMut, Bytes, BytesMut};
/// Transport kind requested for a stream.
///
/// The type converts to and from a single byte through its `From` impls.
/// Bytes that have no dedicated variant are preserved in
/// [`StreamType::Unknown`] instead of being rejected.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum StreamType {
    /// TCP stream (byte `0x01`).
    Tcp,
    /// UDP stream (byte `0x02`).
    Udp,
    /// Any other type byte, stored exactly as received.
    Unknown(u8),
}
impl From<u8> for StreamType {
fn from(value: u8) -> Self {
use StreamType as S;
match value {
0x01 => S::Tcp,
0x02 => S::Udp,
x => S::Unknown(x),
}
}
}
impl From<StreamType> for u8 {
fn from(value: StreamType) -> Self {
use StreamType as S;
match value {
S::Tcp => 0x01,
S::Udp => 0x02,
S::Unknown(x) => x,
}
}
}
/// Reason byte carried by a close packet.
///
/// Each discriminant is the value that is written to and read from the
/// wire (encoding casts the variant with `as u8`).
///
/// NOTE(review): the per-variant descriptions below are derived from the
/// variant names only; confirm the exact meaning against the protocol
/// specification.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum CloseReason {
    /// Unknown or unspecified reason.
    Unknown = 0x01,
    /// Voluntary closure.
    Voluntary = 0x02,
    /// Unexpected closure.
    Unexpected = 0x03,
    /// Incompatible protocol extensions.
    IncompatibleExtensions = 0x04,
    /// Server side: the stream information was invalid.
    ServerStreamInvalidInfo = 0x41,
    /// Server side: the stream destination was unreachable.
    ServerStreamUnreachable = 0x42,
    /// Server side: the stream connection timed out.
    ServerStreamConnectionTimedOut = 0x43,
    /// Server side: the stream connection was refused.
    ServerStreamConnectionRefused = 0x44,
    /// Server side: the stream timed out.
    ServerStreamTimedOut = 0x47,
    /// Server side: the stream address is blocked.
    ServerStreamBlockedAddress = 0x48,
    /// Server side: the stream was throttled.
    ServerStreamThrottled = 0x49,
    /// Client side: unexpected error.
    ClientUnexpected = 0x81,
}
impl TryFrom<u8> for CloseReason {
type Error = WispError;
fn try_from(close_reason: u8) -> Result<Self, Self::Error> {
use CloseReason as R;
match close_reason {
0x01 => Ok(R::Unknown),
0x02 => Ok(R::Voluntary),
0x03 => Ok(R::Unexpected),
0x04 => Ok(R::IncompatibleExtensions),
0x41 => Ok(R::ServerStreamInvalidInfo),
0x42 => Ok(R::ServerStreamUnreachable),
0x43 => Ok(R::ServerStreamConnectionTimedOut),
0x44 => Ok(R::ServerStreamConnectionRefused),
0x47 => Ok(R::ServerStreamTimedOut),
0x48 => Ok(R::ServerStreamBlockedAddress),
0x49 => Ok(R::ServerStreamThrottled),
0x81 => Ok(R::ClientUnexpected),
_ => Err(Self::Error::InvalidCloseReason),
}
}
}
/// Private serialisation hook implemented by the packet types in this file.
trait Encode {
/// Consumes `self` and appends its encoded bytes to `bytes`.
fn encode(self, bytes: &mut BytesMut);
}
/// Body of a connect packet.
///
/// On the wire this is one stream-type byte, a little-endian `u16` port and
/// then the hostname as UTF-8 filling the rest of the packet.
#[derive(Clone, Debug)]
pub struct ConnectPacket {
    /// Kind of stream to open.
    pub stream_type: StreamType,
    /// Destination port.
    pub destination_port: u16,
    /// Destination hostname.
    pub destination_hostname: String,
}
impl ConnectPacket {
pub fn new(
stream_type: StreamType,
destination_port: u16,
destination_hostname: String,
) -> Self {
Self {
stream_type,
destination_port,
destination_hostname,
}
}
}
impl TryFrom<Payload<'_>> for ConnectPacket {
type Error = WispError;
fn try_from(mut bytes: Payload<'_>) -> Result<Self, Self::Error> {
if bytes.remaining() < (1 + 2) {
return Err(Self::Error::PacketTooSmall);
}
Ok(Self {
stream_type: bytes.get_u8().into(),
destination_port: bytes.get_u16_le(),
destination_hostname: std::str::from_utf8(&bytes)?.to_string(),
})
}
}
impl Encode for ConnectPacket {
    /// Appends the stream-type byte, the little-endian port and the hostname
    /// bytes to `bytes`.
    fn encode(self, bytes: &mut BytesMut) {
        let type_byte = u8::from(self.stream_type);
        bytes.put_u8(type_byte);
        bytes.put_u16_le(self.destination_port);
        bytes.extend_from_slice(self.destination_hostname.as_bytes());
    }
}
/// Body of a continue packet.
///
/// On the wire this is a single little-endian `u32`.
#[derive(Clone, Copy, Debug)]
pub struct ContinuePacket {
    /// Remaining buffer value carried by the packet.
    pub buffer_remaining: u32,
}
impl ContinuePacket {
pub fn new(buffer_remaining: u32) -> Self {
Self { buffer_remaining }
}
}
impl TryFrom<Payload<'_>> for ContinuePacket {
type Error = WispError;
fn try_from(mut bytes: Payload<'_>) -> Result<Self, Self::Error> {
if bytes.remaining() < 4 {
return Err(Self::Error::PacketTooSmall);
}
Ok(Self {
buffer_remaining: bytes.get_u32_le(),
})
}
}
impl Encode for ContinuePacket {
fn encode(self, bytes: &mut BytesMut) {
bytes.put_u32_le(self.buffer_remaining);
}
}
/// Body of a close packet: a single reason byte.
#[derive(Clone, Copy, Debug)]
pub struct ClosePacket {
    /// Why the stream is being closed.
    pub reason: CloseReason,
}
impl ClosePacket {
pub fn new(reason: CloseReason) -> Self {
Self { reason }
}
}
impl TryFrom<Payload<'_>> for ClosePacket {
type Error = WispError;
fn try_from(mut bytes: Payload<'_>) -> Result<Self, Self::Error> {
if bytes.remaining() < 1 {
return Err(Self::Error::PacketTooSmall);
}
Ok(Self {
reason: bytes.get_u8().try_into()?,
})
}
}
impl Encode for ClosePacket {
    /// Appends the reason's discriminant to `bytes` as a single byte.
    fn encode(self, bytes: &mut BytesMut) {
        let reason_byte = self.reason as u8;
        bytes.put_u8(reason_byte);
    }
}
/// Protocol version as a major/minor pair, one byte each.
#[derive(Clone, Debug)]
pub struct WispVersion {
    /// Major version component.
    pub major: u8,
    /// Minor version component.
    pub minor: u8,
}
/// Body of an info packet: the sender's protocol version followed by its
/// protocol extensions.
#[derive(Clone, Debug)]
pub struct InfoPacket {
    /// Protocol version announced by the sender.
    pub version: WispVersion,
    /// Extensions carried in the packet.
    pub extensions: Vec<AnyProtocolExtension>,
}
impl Encode for InfoPacket {
fn encode(self, bytes: &mut BytesMut) {
bytes.put_u8(self.version.major);
bytes.put_u8(self.version.minor);
for extension in self.extensions {
bytes.extend_from_slice(&Bytes::from(extension));
}
}
}
/// The type-specific part of a packet.
///
/// The numeric type byte of each variant is given by `PacketType::as_u8`.
#[derive(Clone, Debug)]
pub enum PacketType<'a> {
    /// Connect packet (type `0x01`).
    Connect(ConnectPacket),
    /// Data packet (type `0x02`) carrying the raw payload.
    Data(Payload<'a>),
    /// Continue packet (type `0x03`).
    Continue(ContinuePacket),
    /// Close packet (type `0x04`).
    Close(ClosePacket),
    /// Info packet (type `0x05`).
    Info(InfoPacket),
}
impl PacketType<'_> {
pub fn as_u8(&self) -> u8 {
use PacketType as P;
match self {
P::Connect(_) => 0x01,
P::Data(_) => 0x02,
P::Continue(_) => 0x03,
P::Close(_) => 0x04,
P::Info(_) => 0x05,
}
}
pub(crate) fn get_packet_size(&self) -> usize {
use PacketType as P;
match self {
P::Connect(p) => 1 + 2 + p.destination_hostname.len(),
P::Data(p) => p.len(),
P::Continue(_) => 4,
P::Close(_) => 1,
P::Info(_) => 2,
}
}
}
impl Encode for PacketType<'_> {
    /// Appends the variant's body to `bytes`.
    ///
    /// Data payloads are copied verbatim; every other variant delegates to
    /// its own `Encode` implementation. The type byte is not written here.
    fn encode(self, bytes: &mut BytesMut) {
        match self {
            Self::Connect(connect) => connect.encode(bytes),
            Self::Data(payload) => bytes.extend_from_slice(&payload),
            Self::Continue(cont) => cont.encode(bytes),
            Self::Close(close) => close.encode(bytes),
            Self::Info(info) => info.encode(bytes),
        }
    }
}
/// A complete packet: a stream id plus the type-specific body.
///
/// Encoded as the type byte, the stream id as a little-endian `u32`, and
/// then the body.
#[derive(Clone, Debug)]
pub struct Packet<'a> {
    /// Id of the stream this packet belongs to.
    pub stream_id: u32,
    /// Type-specific body.
    pub packet_type: PacketType<'a>,
}
impl<'a> Packet<'a> {
    /// Creates a packet for `stream_id` with the given body.
    pub fn new(stream_id: u32, packet: PacketType<'a>) -> Self {
        Self {
            stream_id,
            packet_type: packet,
        }
    }

    /// Creates a connect packet for `stream_id`.
    pub fn new_connect(
        stream_id: u32,
        stream_type: StreamType,
        destination_port: u16,
        destination_hostname: String,
    ) -> Self {
        Self {
            stream_id,
            packet_type: PacketType::Connect(ConnectPacket::new(
                stream_type,
                destination_port,
                destination_hostname,
            )),
        }
    }

    /// Creates a data packet for `stream_id` carrying `data`.
    pub fn new_data(stream_id: u32, data: Payload<'a>) -> Self {
        Self {
            stream_id,
            packet_type: PacketType::Data(data),
        }
    }

    /// Creates a continue packet for `stream_id`.
    pub fn new_continue(stream_id: u32, buffer_remaining: u32) -> Self {
        Self {
            stream_id,
            packet_type: PacketType::Continue(ContinuePacket::new(buffer_remaining)),
        }
    }

    /// Creates a close packet for `stream_id` with the given reason.
    pub fn new_close(stream_id: u32, reason: CloseReason) -> Self {
        Self {
            stream_id,
            packet_type: PacketType::Close(ClosePacket::new(reason)),
        }
    }

    /// Creates an info packet on stream 0 advertising `WISP_VERSION` and the
    /// given extensions.
    pub(crate) fn new_info(extensions: Vec<AnyProtocolExtension>) -> Self {
        Self {
            stream_id: 0,
            packet_type: PacketType::Info(InfoPacket {
                version: WISP_VERSION,
                extensions,
            }),
        }
    }

    /// Parses the stream id and body of a packet whose type byte has already
    /// been consumed.
    ///
    /// # Errors
    ///
    /// * [`WispError::PacketTooSmall`] if fewer than four bytes remain for
    ///   the stream id, or the body is too short for its type.
    /// * [`WispError::InvalidPacketType`] if `packet_type` is not `0x01`–`0x04`.
    /// * Any error produced while decoding the body.
    fn parse_packet(packet_type: u8, mut bytes: Payload<'a>) -> Result<Self, WispError> {
        use PacketType as P;
        // `get_u32_le` panics when fewer than four bytes remain, and callers
        // only guarantee that the type byte was present. Reject truncated
        // input here so a malformed packet cannot crash the process.
        if bytes.remaining() < 4 {
            return Err(WispError::PacketTooSmall);
        }
        let stream_id = bytes.get_u32_le();
        let packet_type = match packet_type {
            0x01 => P::Connect(ConnectPacket::try_from(bytes)?),
            0x02 => P::Data(bytes),
            0x03 => P::Continue(ContinuePacket::try_from(bytes)?),
            0x04 => P::Close(ClosePacket::try_from(bytes)?),
            _ => return Err(WispError::InvalidPacketType),
        };
        Ok(Self {
            stream_id,
            packet_type,
        })
    }

    /// Parses a websocket frame that may contain an info packet.
    ///
    /// Type `0x05` is decoded as an info packet using `extension_builders`;
    /// every other type is handled as a regular packet.
    ///
    /// # Errors
    ///
    /// * [`WispError::WsFrameNotFinished`] if the frame is not final.
    /// * [`WispError::WsFrameInvalidType`] if the frame is not binary.
    /// * [`WispError::PacketTooSmall`] if the payload is empty or truncated.
    /// * Any error from decoding the packet itself.
    pub(crate) fn maybe_parse_info(
        frame: Frame<'a>,
        role: Role,
        extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
    ) -> Result<Self, WispError> {
        if !frame.finished {
            return Err(WispError::WsFrameNotFinished);
        }
        if frame.opcode != OpCode::Binary {
            return Err(WispError::WsFrameInvalidType);
        }
        let mut bytes = frame.payload;
        if bytes.remaining() < 1 {
            return Err(WispError::PacketTooSmall);
        }
        let packet_type = bytes.get_u8();
        if packet_type == 0x05 {
            Self::parse_info(bytes, role, extension_builders)
        } else {
            Self::parse_packet(packet_type, bytes)
        }
    }

    /// Parses a websocket frame, handing unknown packet types to the first
    /// extension that lists the type among its supported packets.
    ///
    /// Returns `Ok(Some(packet))` for types `0x01`–`0x04`. Returns `Ok(None)`
    /// for an info packet (`0x05`, ignored here) and for a packet that an
    /// extension handled.
    ///
    /// # Errors
    ///
    /// * [`WispError::WsFrameNotFinished`] if the frame is not final.
    /// * [`WispError::WsFrameInvalidType`] if the frame is not binary.
    /// * [`WispError::PacketTooSmall`] if the payload is shorter than the
    ///   type byte plus stream id, or the body is truncated.
    /// * [`WispError::InvalidPacketType`] if no extension supports the type.
    /// * Any error returned by the extension's packet handler.
    pub(crate) async fn maybe_handle_extension(
        frame: Frame<'a>,
        extensions: &mut [AnyProtocolExtension],
        read: &mut (dyn WebSocketRead + Send),
        write: &LockedWebSocketWrite,
    ) -> Result<Option<Self>, WispError> {
        if !frame.finished {
            return Err(WispError::WsFrameNotFinished);
        }
        if frame.opcode != OpCode::Binary {
            return Err(WispError::WsFrameInvalidType);
        }
        let mut bytes = frame.payload;
        // Type byte + four-byte stream id.
        if bytes.remaining() < 5 {
            return Err(WispError::PacketTooSmall);
        }
        let packet_type = bytes.get_u8();
        match packet_type {
            // Regular packets share the decoding path used everywhere else.
            0x01..=0x04 => Self::parse_packet(packet_type, bytes).map(Some),
            0x05 => Ok(None),
            packet_type => {
                if let Some(extension) = extensions.iter_mut().find(|extension| {
                    extension
                        .get_supported_packets()
                        .iter()
                        .any(|supported| *supported == packet_type)
                }) {
                    // The extension receives everything after the type byte,
                    // including the stream id.
                    extension
                        .handle_packet(BytesMut::from(bytes).freeze(), read, write)
                        .await?;
                    Ok(None)
                } else {
                    Err(WispError::InvalidPacketType)
                }
            }
        }
    }

    /// Parses the remainder of an info packet (after the type byte).
    ///
    /// The layout is a stream id that must be 0, the major and minor version
    /// bytes, and then zero or more extension records. Each record is an id
    /// byte, a little-endian `u32` length and that many payload bytes.
    ///
    /// Records whose id has no matching builder are skipped, and records
    /// whose builder fails are dropped. Fewer than five trailing bytes are
    /// ignored.
    ///
    /// # Errors
    ///
    /// * [`WispError::PacketTooSmall`] if the header is truncated or a record
    ///   declares more bytes than remain.
    /// * [`WispError::InvalidStreamId`] if the stream id is not 0.
    /// * [`WispError::IncompatibleProtocolVersion`] if the major version
    ///   differs from `WISP_VERSION.major`.
    fn parse_info(
        mut bytes: Payload<'a>,
        role: Role,
        extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
    ) -> Result<Self, WispError> {
        // Four-byte stream id + two version bytes.
        if bytes.remaining() < 4 + 2 {
            return Err(WispError::PacketTooSmall);
        }
        if bytes.get_u32_le() != 0 {
            return Err(WispError::InvalidStreamId);
        }
        let version = WispVersion {
            major: bytes.get_u8(),
            minor: bytes.get_u8(),
        };
        if version.major != WISP_VERSION.major {
            return Err(WispError::IncompatibleProtocolVersion);
        }

        let mut extensions = Vec::new();
        // A record header is five bytes: id + length.
        while bytes.remaining() > 4 {
            let id = bytes.get_u8();
            let length = usize::try_from(bytes.get_u32_le())?;
            if bytes.remaining() < length {
                return Err(WispError::PacketTooSmall);
            }
            if let Some(builder) = extension_builders
                .iter()
                .find(|builder| builder.get_id() == id)
            {
                // The payload is consumed before building, so a builder
                // failure still leaves the cursor at the next record.
                if let Ok(extension) = builder.build_from_bytes(bytes.copy_to_bytes(length), role)
                {
                    extensions.push(extension);
                }
            } else {
                bytes.advance(length);
            }
        }

        Ok(Self {
            stream_id: 0,
            packet_type: PacketType::Info(InfoPacket {
                version,
                extensions,
            }),
        })
    }
}
impl Encode for Packet<'_> {
fn encode(self, bytes: &mut BytesMut) {
bytes.put_u8(self.packet_type.as_u8());
bytes.put_u32_le(self.stream_id);
self.packet_type.encode(bytes);
}
}
impl<'a> TryFrom<Payload<'a>> for Packet<'a> {
    type Error = WispError;

    /// Decodes a regular (non-info) packet from a raw payload.
    ///
    /// # Errors
    ///
    /// Returns [`WispError::PacketTooSmall`] when the payload cannot hold the
    /// type byte and the four-byte stream id, and otherwise any error from
    /// decoding the packet body.
    fn try_from(mut bytes: Payload<'a>) -> Result<Self, Self::Error> {
        // `parse_packet` reads a four-byte stream id right after the type
        // byte, so both must be present or the read would panic.
        if bytes.remaining() < 1 + 4 {
            return Err(Self::Error::PacketTooSmall);
        }
        let packet_type = bytes.get_u8();
        Self::parse_packet(packet_type, bytes)
    }
}
impl From<Packet<'_>> for BytesMut {
    /// Encodes `packet` into a freshly allocated buffer.
    ///
    /// The buffer is pre-sized for the type byte, the stream id and the body
    /// size reported by `get_packet_size`.
    fn from(packet: Packet) -> Self {
        let capacity = 1 + 4 + packet.packet_type.get_packet_size();
        let mut buffer = BytesMut::with_capacity(capacity);
        packet.encode(&mut buffer);
        buffer
    }
}
impl<'a> TryFrom<ws::Frame<'a>> for Packet<'a> {
type Error = WispError;
fn try_from(frame: ws::Frame<'a>) -> Result<Self, Self::Error> {
if !frame.finished {
return Err(Self::Error::WsFrameNotFinished);
}
if frame.opcode != ws::OpCode::Binary {
return Err(Self::Error::WsFrameInvalidType);
}
Packet::try_from(frame.payload)
}
}
impl From<Packet<'_>> for ws::Frame<'static> {
    /// Encodes `packet` and wraps the bytes in a binary websocket frame.
    fn from(packet: Packet) -> Self {
        let encoded = BytesMut::from(packet);
        let payload = Payload::Bytes(encoded);
        Self::binary(payload)
    }
}