web_socket/ws.rs
1#![allow(clippy::unusual_byte_groupings)]
2use crate::*;
3use std::io::{IoSlice, Result};
4use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
5
6/// WebSocket implementation for both client and server
7#[derive(Debug)]
8pub struct WebSocket<Stream> {
9 /// it is a low-level abstraction that represents the underlying byte stream over which WebSocket messages are exchanged.
10 pub stream: Stream,
11
12 /// Maximum allowed payload length in bytes.
13 ///
14 /// Default: 16 MB
15 pub max_payload_len: usize,
16
17 role: Role,
18 is_closed: bool,
19 fragment: Option<MessageType>,
20}
21
22impl<IO> WebSocket<IO> {
23 /// Create a new websocket client instance.
24 #[inline]
25 pub fn client(stream: IO) -> Self {
26 Self::from((stream, Role::Client))
27 }
28 /// Create a websocket server instance.
29 #[inline]
30 pub fn server(stream: IO) -> Self {
31 Self::from((stream, Role::Server))
32 }
33}
34
35impl<W> WebSocket<W>
36where
37 W: Unpin + AsyncWrite,
38{
39 #[doc(hidden)]
40 pub async fn send_raw(&mut self, frame: Frame<'_>) -> Result<()> {
41 let buf = match self.role {
42 Role::Server => {
43 if self.stream.is_write_vectored() {
44 let mut head = [0; 10];
45 let head_len = unsafe { frame.encode_header_unchecked(head.as_mut_ptr(), 0) };
46 let total_len = head_len + frame.data.len();
47
48 let mut bufs = [IoSlice::new(&head[..head_len]), IoSlice::new(frame.data)];
49 let mut amt = self.stream.write_vectored(&bufs).await?;
50 if amt == total_len {
51 return Ok(());
52 }
53 while amt < head_len {
54 bufs[0] = IoSlice::new(&head[amt..head_len]);
55 amt += self.stream.write_vectored(&bufs).await?;
56 }
57 if amt < total_len {
58 self.stream.write_all(&frame.data[amt - head_len..]).await?;
59 }
60 return Ok(());
61 }
62 frame.encode_without_mask()
63 }
64 Role::Client => frame.encode_with_mask(),
65 };
66 self.stream.write_all(&buf).await
67 }
68
69 /// Send message to a endpoint.
70 pub async fn send(&mut self, data: impl Into<Frame<'_>>) -> Result<()> {
71 self.send_raw(data.into()).await
72 }
73
74 /// - The Close frame MAY contain a body that indicates a reason for closing.
75 pub async fn close<T>(mut self, reason: T) -> Result<()>
76 where
77 T: CloseReason,
78 T::Bytes: AsRef<[u8]>,
79 {
80 self.send_raw(Frame {
81 fin: true,
82 opcode: 8,
83 data: reason.to_bytes().as_ref(),
84 })
85 .await?;
86 self.stream.flush().await
87 }
88
89 /// A Ping frame may serve either as a keepalive or as a means to verify that the remote endpoint is still responsive.
90 ///
91 /// It is used to send ping frame.
92 ///
93 /// ### Example
94 ///
95 /// ```no_run
96 /// # use web_socket::*;
97 /// # async {
98 /// let writer = Vec::new();
99 /// let mut ws = WebSocket::client(writer);
100 /// ws.send_ping("Hello!").await;
101 /// # };
102 /// ```
103 pub async fn send_ping(&mut self, data: impl AsRef<[u8]>) -> Result<()> {
104 self.send_raw(Frame {
105 fin: true,
106 opcode: 9,
107 data: data.as_ref(),
108 })
109 .await
110 }
111
112 /// A Pong frame sent in response to a Ping frame must have identical
113 /// "Application data" as found in the message body of the Ping frame being replied to.
114 ///
115 /// A Pong frame MAY be sent unsolicited. This serves as a unidirectional heartbeat. A response to an unsolicited Pong frame is not expected.
116 pub async fn send_pong(&mut self, data: impl AsRef<[u8]>) -> Result<()> {
117 self.send_raw(Frame {
118 fin: true,
119 opcode: 10,
120 data: data.as_ref(),
121 })
122 .await
123 }
124
125 /// Flushes this output stream, ensuring that all intermediately buffered contents reach their destination.
126 pub async fn flash(&mut self) -> Result<()> {
127 self.stream.flush().await
128 }
129}
130
131// ------------------------------------------------------------------------
132
// Returns `Ok(Event::Error(msg))` from the enclosing function. The frame parser
// uses it to report protocol violations as an error *event* rather than as an
// I/O error.
macro_rules! err {
    ($msg:expr) => {
        return Ok(Event::Error($msg))
    };
}
134
135#[inline]
136pub async fn read_buf<const N: usize, R>(stream: &mut R) -> Result<[u8; N]>
137where
138 R: Unpin + AsyncRead,
139{
140 let mut buf = [0; N];
141 stream.read_exact(&mut buf).await?;
142 Ok(buf)
143}
144
145impl<R> WebSocket<R>
146where
147 R: Unpin + AsyncRead,
148{
149 /// reads [Event] from websocket stream.
150 pub async fn recv(&mut self) -> Result<Event> {
151 if self.is_closed {
152 return Err(std::io::Error::new(
153 std::io::ErrorKind::NotConnected,
154 "read after close",
155 ));
156 }
157 let event = self.recv_event().await;
158 if let Ok(Event::Close { .. } | Event::Error(..)) | Err(..) = event {
159 self.is_closed = true;
160 }
161 event
162 }
163
164 // ### WebSocket Frame Header
165 //
166 // ```txt
167 // 0 1 2 3
168 // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
169 // +-+-+-+-+-------+-+-------------+-------------------------------+
170 // |F|R|R|R| opcode|M| Payload len | Extended payload length |
171 // |I|S|S|S| (4) |A| (7) | (16/64) |
172 // |N|V|V|V| |S| | (if payload len==126/127) |
173 // | |1|2|3| |K| | |
174 // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
175 // | Extended payload length continued, if payload len == 127 |
176 // + - - - - - - - - - - - - - - - +-------------------------------+
177 // | |Masking-key, if MASK set to 1 |
178 // +-------------------------------+-------------------------------+
179 // | Masking-key (continued) | Payload Data |
180 // +-------------------------------- - - - - - - - - - - - - - - - +
181 // : Payload Data continued ... :
182 // + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
183 // | Payload Data continued ... |
184 // +---------------------------------------------------------------+
185 // ```
186 /// reads [Event] from websocket stream.
187 pub async fn recv_event(&mut self) -> Result<Event> {
188 let [b1, b2] = read_buf(&mut self.stream).await?;
189
190 let fin = b1 & 0b_1000_0000 != 0;
191 let rsv = b1 & 0b_111_0000;
192 let opcode = b1 & 0b_1111;
193 let len = (b2 & 0b_111_1111) as usize;
194
195 // Defines whether the "Payload data" is masked. If set to 1, a
196 // masking key is present in masking-key, and this is used to unmask
197 // the "Payload data" as per [Section 5.3](https://datatracker.ietf.org/doc/html/rfc6455#section-5.3). All frames sent from
198 // client to server have this bit set to 1.
199 let is_masked = b2 & 0b_1000_0000 != 0;
200
201 if rsv != 0 {
202 // MUST be `0` unless an extension is negotiated that defines meanings
203 // for non-zero values. If a nonzero value is received and none of
204 // the negotiated extensions defines the meaning of such a nonzero
205 // value, the receiving endpoint MUST _Fail the WebSocket Connection_.
206 err!("reserve bit must be `0`");
207 }
208
209 // A client MUST mask all frames that it sends to the server. (Note
210 // that masking is done whether or not the WebSocket Protocol is running
211 // over TLS.) The server MUST close the connection upon receiving a
212 // frame that is not masked.
213 //
214 // A server MUST NOT mask any frames that it sends to the client.
215 if let Role::Server = self.role {
216 if !is_masked {
217 err!("expected masked frame");
218 }
219 } else if is_masked {
220 err!("expected unmasked frame");
221 }
222
223 // 3-7 are reserved for further non-control frames.
224 if opcode >= 8 {
225 if !fin {
226 err!("control frame must not be fragmented");
227 }
228 if len > 125 {
229 err!("control frame must have a payload length of 125 bytes or less");
230 }
231 let msg = self.read_payload(len).await?;
232 match opcode {
233 8 => Ok(on_close(&msg)),
234 9 => Ok(Event::Ping(msg)),
235 10 => Ok(Event::Pong(msg)),
236 // 11-15 are reserved for further control frames
237 _ => err!("unknown opcode"),
238 }
239 } else {
240 let ty = match (opcode, fin, self.fragment) {
241 (2, true, None) => DataType::Complete(MessageType::Binary),
242 (1, true, None) => DataType::Complete(MessageType::Text),
243 (2, false, None) => {
244 self.fragment = Some(MessageType::Binary);
245 DataType::Stream(Stream::Start(MessageType::Binary))
246 }
247 (1, false, None) => {
248 self.fragment = Some(MessageType::Text);
249 DataType::Stream(Stream::Start(MessageType::Text))
250 }
251 (0, false, Some(ty)) => DataType::Stream(Stream::Next(ty)),
252 (0, true, Some(ty)) => {
253 self.fragment = None;
254 DataType::Stream(Stream::End(ty))
255 }
256 _ => err!("invalid data frame"),
257 };
258 let len = match len {
259 126 => u16::from_be_bytes(read_buf(&mut self.stream).await?) as usize,
260 127 => u64::from_be_bytes(read_buf(&mut self.stream).await?) as usize,
261 len => len,
262 };
263 if len > self.max_payload_len {
264 err!("payload too large");
265 }
266 let data = self.read_payload(len).await?;
267 Ok(Event::Data { ty, data })
268 }
269 }
270
271 async fn read_payload(&mut self, len: usize) -> Result<Box<[u8]>> {
272 let mut data = vec![0; len].into_boxed_slice();
273 match self.role {
274 Role::Server => {
275 let mask: [u8; 4] = read_buf(&mut self.stream).await?;
276 self.stream.read_exact(&mut data).await?;
277 // TODO: Use SIMD wherever possible for best performance
278 for i in 0..data.len() {
279 data[i] ^= mask[i & 3];
280 }
281 }
282 Role::Client => {
283 self.stream.read_exact(&mut data).await?;
284 }
285 }
286 Ok(data)
287 }
288}
289
290/// - If there is a body, the first two bytes of the body MUST be a 2-byte unsigned integer (in network byte order: Big Endian)
291/// representing a status code with value /code/ defined in [Section 7.4](https:///datatracker.ietf.org/doc/html/rfc6455#section-7.4).
292/// Following the 2-byte integer,
293///
294/// - The application MUST NOT send any more data frames after sending a `Close` frame.
295///
296/// - If an endpoint receives a Close frame and did not previously send a
297/// Close frame, the endpoint MUST send a Close frame in response. (When
298/// sending a Close frame in response, the endpoint typically echos the
299/// status code it received.) It SHOULD do so as soon as practical. An
300/// endpoint MAY delay sending a Close frame until its current message is
301/// sent
302///
303/// - After both sending and receiving a Close message, an endpoint
304/// considers the WebSocket connection closed and MUST close the
305/// underlying TCP connection.
306fn on_close(msg: &[u8]) -> Event {
307 let code = msg
308 .get(..2)
309 .map(|bytes| u16::from_be_bytes([bytes[0], bytes[1]]))
310 .unwrap_or(1000);
311
312 match code {
313 1000..=1003 | 1007..=1011 | 1015 | 3000..=3999 | 4000..=4999 => {
314 match msg.get(2..).map(|data| String::from_utf8(data.to_vec())) {
315 Some(Ok(msg)) => Event::Close {
316 code,
317 reason: msg.into_boxed_str(),
318 },
319 None => Event::Close {
320 code,
321 reason: "".into(),
322 },
323 Some(Err(_)) => Event::Error("invalid utf-8 payload"),
324 }
325 }
326 _ => Event::Error("invalid close code"),
327 }
328}
329
330impl<IO> From<(IO, Role)> for WebSocket<IO> {
331 #[inline]
332 fn from((stream, role): (IO, Role)) -> Self {
333 Self {
334 stream,
335 max_payload_len: 16 * 1024 * 1024,
336 role,
337 is_closed: false,
338 fragment: None,
339 }
340 }
341}