wisp_mux/
ws.rs

//! Abstraction over WebSocket implementations.
//!
//! Use the [`fastwebsockets`] implementation of these traits as an example when implementing them
//! for other WebSocket implementations.
//!
//! [`fastwebsockets`]: https://github.com/MercuryWorkshop/epoxy-tls/blob/multiplexed/wisp/src/fastwebsockets.rs
7use std::{ops::Deref, sync::Arc};
8
9use crate::WispError;
10use async_trait::async_trait;
11use bytes::{Buf, BytesMut};
12use futures::lock::Mutex;
13
14/// Payload of the websocket frame.
15#[derive(Debug)]
16pub enum Payload<'a> {
17	/// Borrowed payload. Currently used when writing data.
18	Borrowed(&'a [u8]),
19	/// BytesMut payload. Currently used when reading data.
20	Bytes(BytesMut),
21}
22
23impl From<BytesMut> for Payload<'static> {
24	fn from(value: BytesMut) -> Self {
25		Self::Bytes(value)
26	}
27}
28
29impl<'a> From<&'a [u8]> for Payload<'a> {
30	fn from(value: &'a [u8]) -> Self {
31		Self::Borrowed(value)
32	}
33}
34
35impl Payload<'_> {
36	/// Turn a Payload<'a> into a Payload<'static> by copying the data.
37	pub fn into_owned(self) -> Self {
38		match self {
39			Self::Bytes(x) => Self::Bytes(x),
40			Self::Borrowed(x) => Self::Bytes(BytesMut::from(x)),
41		}
42	}
43}
44
45impl From<Payload<'_>> for BytesMut {
46	fn from(value: Payload<'_>) -> Self {
47		match value {
48			Payload::Bytes(x) => x,
49			Payload::Borrowed(x) => x.into(),
50		}
51	}
52}
53
54impl Deref for Payload<'_> {
55	type Target = [u8];
56	fn deref(&self) -> &Self::Target {
57		match self {
58			Self::Bytes(x) => x.deref(),
59			Self::Borrowed(x) => x,
60		}
61	}
62}
63
64impl Clone for Payload<'_> {
65	fn clone(&self) -> Self {
66		match self {
67			Self::Bytes(x) => Self::Bytes(x.clone()),
68			Self::Borrowed(x) => Self::Bytes(BytesMut::from(*x)),
69		}
70	}
71}
72
73impl Buf for Payload<'_> {
74	fn remaining(&self) -> usize {
75		match self {
76			Self::Bytes(x) => x.remaining(),
77			Self::Borrowed(x) => x.remaining(),
78		}
79	}
80
81	fn chunk(&self) -> &[u8] {
82		match self {
83			Self::Bytes(x) => x.chunk(),
84			Self::Borrowed(x) => x.chunk(),
85		}
86	}
87
88	fn advance(&mut self, cnt: usize) {
89		match self {
90			Self::Bytes(x) => x.advance(cnt),
91			Self::Borrowed(x) => x.advance(cnt),
92		}
93	}
94}
95
/// Opcode of a WebSocket frame.
///
/// There is no variant for continuation frames. Beyond `PartialEq`, the type also derives `Eq`
/// and `Hash`, so opcodes can be used as keys in hashed collections and in `Eq`-bounded code.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum OpCode {
	/// Text frame.
	Text,
	/// Binary frame.
	Binary,
	/// Close frame.
	Close,
	/// Ping frame.
	Ping,
	/// Pong frame.
	Pong,
}
110
111/// WebSocket frame.
112#[derive(Debug, Clone)]
113pub struct Frame<'a> {
114	/// Whether the frame is finished or not.
115	pub finished: bool,
116	/// Opcode of the WebSocket frame.
117	pub opcode: OpCode,
118	/// Payload of the WebSocket frame.
119	pub payload: Payload<'a>,
120}
121
122impl<'a> Frame<'a> {
123	/// Create a new text frame.
124	pub fn text(payload: Payload<'a>) -> Self {
125		Self {
126			finished: true,
127			opcode: OpCode::Text,
128			payload,
129		}
130	}
131
132	/// Create a new binary frame.
133	pub fn binary(payload: Payload<'a>) -> Self {
134		Self {
135			finished: true,
136			opcode: OpCode::Binary,
137			payload,
138		}
139	}
140
141	/// Create a new close frame.
142	pub fn close(payload: Payload<'a>) -> Self {
143		Self {
144			finished: true,
145			opcode: OpCode::Close,
146			payload,
147		}
148	}
149}
150
151/// Generic WebSocket read trait.
152#[async_trait]
153pub trait WebSocketRead {
154	/// Read a frame from the socket.
155	async fn wisp_read_frame(
156		&mut self,
157		tx: &LockedWebSocketWrite,
158	) -> Result<Frame<'static>, WispError>;
159
160	/// Read a split frame from the socket.
161	async fn wisp_read_split(
162		&mut self,
163		tx: &LockedWebSocketWrite,
164	) -> Result<(Frame<'static>, Option<Frame<'static>>), WispError> {
165		self.wisp_read_frame(tx).await.map(|x| (x, None))
166	}
167}
168
169/// Generic WebSocket write trait.
170#[async_trait]
171pub trait WebSocketWrite {
172	/// Write a frame to the socket.
173	async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError>;
174
175	/// Close the socket.
176	async fn wisp_close(&mut self) -> Result<(), WispError>;
177
178	/// Write a split frame to the socket.
179	async fn wisp_write_split(
180		&mut self,
181		header: Frame<'_>,
182		body: Frame<'_>,
183	) -> Result<(), WispError> {
184		let mut payload = BytesMut::from(header.payload);
185		payload.extend_from_slice(&body.payload);
186		self.wisp_write_frame(Frame::binary(Payload::Bytes(payload)))
187			.await
188	}
189}
190
191/// Locked WebSocket.
192#[derive(Clone)]
193pub struct LockedWebSocketWrite(Arc<Mutex<Box<dyn WebSocketWrite + Send>>>);
194
195impl LockedWebSocketWrite {
196	/// Create a new locked websocket.
197	pub fn new(ws: Box<dyn WebSocketWrite + Send>) -> Self {
198		Self(Mutex::new(ws).into())
199	}
200
201	/// Write a frame to the websocket.
202	pub async fn write_frame(&self, frame: Frame<'_>) -> Result<(), WispError> {
203		self.0.lock().await.wisp_write_frame(frame).await
204	}
205
206	pub(crate) async fn write_split(
207		&self,
208		header: Frame<'_>,
209		body: Frame<'_>,
210	) -> Result<(), WispError> {
211		self.0.lock().await.wisp_write_split(header, body).await
212	}
213
214	/// Close the websocket.
215	pub async fn close(&self) -> Result<(), WispError> {
216		self.0.lock().await.wisp_close().await
217	}
218}
219
220pub(crate) struct AppendingWebSocketRead<R>(pub Option<Frame<'static>>, pub R)
221where
222	R: WebSocketRead + Send;
223
224#[async_trait]
225impl<R> WebSocketRead for AppendingWebSocketRead<R>
226where
227	R: WebSocketRead + Send,
228{
229	async fn wisp_read_frame(
230		&mut self,
231		tx: &LockedWebSocketWrite,
232	) -> Result<Frame<'static>, WispError> {
233		if let Some(x) = self.0.take() {
234			return Ok(x);
235		}
236		self.1.wisp_read_frame(tx).await
237	}
238
239	async fn wisp_read_split(
240		&mut self,
241		tx: &LockedWebSocketWrite,
242	) -> Result<(Frame<'static>, Option<Frame<'static>>), WispError> {
243		if let Some(x) = self.0.take() {
244			return Ok((x, None));
245		}
246		self.1.wisp_read_split(tx).await
247	}
248}