use std::{ops::Deref, sync::Arc};
use crate::WispError;
use async_trait::async_trait;
use bytes::{Buf, BytesMut};
use futures::lock::Mutex;
/// The bytes carried by a [`Frame`].
///
/// A payload either borrows a byte slice for the lifetime `'a` or owns its data in a
/// [`BytesMut`].
#[derive(Debug)]
pub enum Payload<'a> {
    /// Bytes borrowed from a buffer that must outlive this payload.
    Borrowed(&'a [u8]),
    /// Bytes owned by this payload.
    Bytes(BytesMut),
}
impl From<BytesMut> for Payload<'static> {
fn from(value: BytesMut) -> Self {
Self::Bytes(value)
}
}
impl<'a> From<&'a [u8]> for Payload<'a> {
fn from(value: &'a [u8]) -> Self {
Self::Borrowed(value)
}
}
impl Payload<'_> {
pub fn into_owned(self) -> Self {
match self {
Self::Bytes(x) => Self::Bytes(x),
Self::Borrowed(x) => Self::Bytes(BytesMut::from(x)),
}
}
}
impl From<Payload<'_>> for BytesMut {
fn from(value: Payload<'_>) -> Self {
match value {
Payload::Bytes(x) => x,
Payload::Borrowed(x) => x.into(),
}
}
}
impl Deref for Payload<'_> {
type Target = [u8];
fn deref(&self) -> &Self::Target {
match self {
Self::Bytes(x) => x.deref(),
Self::Borrowed(x) => x,
}
}
}
impl Clone for Payload<'_> {
fn clone(&self) -> Self {
match self {
Self::Bytes(x) => Self::Bytes(x.clone()),
Self::Borrowed(x) => Self::Bytes(BytesMut::from(*x)),
}
}
}
impl Buf for Payload<'_> {
fn remaining(&self) -> usize {
match self {
Self::Bytes(x) => x.remaining(),
Self::Borrowed(x) => x.remaining(),
}
}
fn chunk(&self) -> &[u8] {
match self {
Self::Bytes(x) => x.chunk(),
Self::Borrowed(x) => x.chunk(),
}
}
fn advance(&mut self, cnt: usize) {
match self {
Self::Bytes(x) => x.advance(cnt),
Self::Borrowed(x) => x.advance(cnt),
}
}
}
/// The kind of a [`Frame`].
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so opcodes can be used as keys in hash
/// maps and sets.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum OpCode {
    /// Text frame.
    Text,
    /// Binary frame.
    Binary,
    /// Close frame.
    Close,
    /// Ping frame.
    Ping,
    /// Pong frame.
    Pong,
}
/// A single frame: an opcode, its payload, and a `finished` flag.
///
/// Every constructor in this module sets `finished` to `true`. It presumably marks the last
/// fragment of a message; confirm against readers and writers before relying on that.
#[derive(Debug, Clone)]
pub struct Frame<'a> {
    /// Whether this frame is finished (see the type-level note).
    pub finished: bool,
    /// The frame kind.
    pub opcode: OpCode,
    /// The frame body.
    pub payload: Payload<'a>,
}
impl<'a> Frame<'a> {
pub fn text(payload: Payload<'a>) -> Self {
Self {
finished: true,
opcode: OpCode::Text,
payload,
}
}
pub fn binary(payload: Payload<'a>) -> Self {
Self {
finished: true,
opcode: OpCode::Binary,
payload,
}
}
pub fn close(payload: Payload<'a>) -> Self {
Self {
finished: true,
opcode: OpCode::Close,
payload,
}
}
}
/// An asynchronous source of [`Frame`]s.
#[async_trait]
pub trait WebSocketRead {
    /// Reads the next frame.
    ///
    /// `tx` is the shared write handle; how (or whether) it is used is up to the implementor.
    async fn wisp_read_frame(
        &mut self,
        tx: &LockedWebSocketWrite,
    ) -> Result<Frame<'static>, WispError>;

    /// Reads a frame together with an optional second frame.
    ///
    /// The default reads exactly one frame via [`WebSocketRead::wisp_read_frame`] and never
    /// returns a second frame. Errors from that call are returned unchanged.
    async fn wisp_read_split(
        &mut self,
        tx: &LockedWebSocketWrite,
    ) -> Result<(Frame<'static>, Option<Frame<'static>>), WispError> {
        let frame = self.wisp_read_frame(tx).await?;
        Ok((frame, None))
    }
}
#[async_trait]
pub trait WebSocketWrite {
async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError>;
async fn wisp_close(&mut self) -> Result<(), WispError>;
async fn wisp_write_split(
&mut self,
header: Frame<'_>,
body: Frame<'_>,
) -> Result<(), WispError> {
let mut payload = BytesMut::from(header.payload);
payload.extend_from_slice(&body.payload);
self.wisp_write_frame(Frame::binary(Payload::Bytes(payload)))
.await
}
}
/// A cloneable handle to a shared, boxed [`WebSocketWrite`].
///
/// The writer sits behind an `Arc` and a `futures::lock::Mutex`. Every clone refers to the
/// same writer, and calls made through different clones are serialized by the mutex.
#[derive(Clone)]
pub struct LockedWebSocketWrite(
    Arc<Mutex<Box<dyn WebSocketWrite + Send>>>,
);
impl LockedWebSocketWrite {
pub fn new(ws: Box<dyn WebSocketWrite + Send>) -> Self {
Self(Mutex::new(ws).into())
}
pub async fn write_frame(&self, frame: Frame<'_>) -> Result<(), WispError> {
self.0.lock().await.wisp_write_frame(frame).await
}
pub(crate) async fn write_split(
&self,
header: Frame<'_>,
body: Frame<'_>,
) -> Result<(), WispError> {
self.0.lock().await.wisp_write_split(header, body).await
}
pub async fn close(&self) -> Result<(), WispError> {
self.0.lock().await.wisp_close().await
}
}
pub(crate) struct AppendingWebSocketRead<R>(pub Option<Frame<'static>>, pub R)
where
R: WebSocketRead + Send;
#[async_trait]
impl<R> WebSocketRead for AppendingWebSocketRead<R>
where
    R: WebSocketRead + Send,
{
    /// Returns the buffered frame if one is pending (and clears it); otherwise reads from the
    /// wrapped reader.
    async fn wisp_read_frame(
        &mut self,
        tx: &LockedWebSocketWrite,
    ) -> Result<Frame<'static>, WispError> {
        match self.0.take() {
            Some(pending) => Ok(pending),
            None => self.1.wisp_read_frame(tx).await,
        }
    }

    /// Returns the buffered frame (with no second frame) if one is pending, clearing it.
    /// Otherwise it delegates to the wrapped reader's `wisp_read_split`.
    async fn wisp_read_split(
        &mut self,
        tx: &LockedWebSocketWrite,
    ) -> Result<(Frame<'static>, Option<Frame<'static>>), WispError> {
        match self.0.take() {
            Some(pending) => Ok((pending, None)),
            None => self.1.wisp_read_split(tx).await,
        }
    }
}