web_socket/
ws.rs

1#![allow(clippy::unusual_byte_groupings)]
2use crate::*;
3use std::io::{IoSlice, Result};
4use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
5
6/// WebSocket implementation for both client and server
7#[derive(Debug)]
8pub struct WebSocket<Stream> {
9    /// it is a low-level abstraction that represents the underlying byte stream over which WebSocket messages are exchanged.
10    pub stream: Stream,
11
12    /// Maximum allowed payload length in bytes.
13    ///
14    /// Default: 16 MB
15    pub max_payload_len: usize,
16
17    role: Role,
18    is_closed: bool,
19    fragment: Option<MessageType>,
20}
21
22impl<IO> WebSocket<IO> {
23    /// Create a new websocket client instance.
24    #[inline]
25    pub fn client(stream: IO) -> Self {
26        Self::from((stream, Role::Client))
27    }
28    /// Create a websocket server instance.
29    #[inline]
30    pub fn server(stream: IO) -> Self {
31        Self::from((stream, Role::Server))
32    }
33}
34
35impl<W> WebSocket<W>
36where
37    W: Unpin + AsyncWrite,
38{
39    #[doc(hidden)]
40    pub async fn send_raw(&mut self, frame: Frame<'_>) -> Result<()> {
41        let buf = match self.role {
42            Role::Server => {
43                if self.stream.is_write_vectored() {
44                    let mut head = [0; 10];
45                    let head_len = unsafe { frame.encode_header_unchecked(head.as_mut_ptr(), 0) };
46                    let total_len = head_len + frame.data.len();
47
48                    let mut bufs = [IoSlice::new(&head[..head_len]), IoSlice::new(frame.data)];
49                    let mut amt = self.stream.write_vectored(&bufs).await?;
50                    if amt == total_len {
51                        return Ok(());
52                    }
53                    while amt < head_len {
54                        bufs[0] = IoSlice::new(&head[amt..head_len]);
55                        amt += self.stream.write_vectored(&bufs).await?;
56                    }
57                    if amt < total_len {
58                        self.stream.write_all(&frame.data[amt - head_len..]).await?;
59                    }
60                    return Ok(());
61                }
62                frame.encode_without_mask()
63            }
64            Role::Client => frame.encode_with_mask(),
65        };
66        self.stream.write_all(&buf).await
67    }
68
69    /// Send message to a endpoint.
70    pub async fn send(&mut self, data: impl Into<Frame<'_>>) -> Result<()> {
71        self.send_raw(data.into()).await
72    }
73
74    /// - The Close frame MAY contain a body that indicates a reason for closing.
75    pub async fn close<T>(mut self, reason: T) -> Result<()>
76    where
77        T: CloseReason,
78        T::Bytes: AsRef<[u8]>,
79    {
80        self.send_raw(Frame {
81            fin: true,
82            opcode: 8,
83            data: reason.to_bytes().as_ref(),
84        })
85        .await?;
86        self.stream.flush().await
87    }
88
89    /// A Ping frame may serve either as a keepalive or as a means to verify that the remote endpoint is still responsive.
90    ///
91    /// It is used to send ping frame.
92    ///
93    /// ### Example
94    ///
95    /// ```no_run
96    /// # use web_socket::*;
97    /// # async {
98    /// let writer = Vec::new();
99    /// let mut ws = WebSocket::client(writer);
100    /// ws.send_ping("Hello!").await;
101    /// # };
102    /// ```
103    pub async fn send_ping(&mut self, data: impl AsRef<[u8]>) -> Result<()> {
104        self.send_raw(Frame {
105            fin: true,
106            opcode: 9,
107            data: data.as_ref(),
108        })
109        .await
110    }
111
112    /// A Pong frame sent in response to a Ping frame must have identical
113    /// "Application data" as found in the message body of the Ping frame being replied to.
114    ///
115    /// A Pong frame MAY be sent unsolicited.  This serves as a unidirectional heartbeat.  A response to an unsolicited Pong frame is not expected.
116    pub async fn send_pong(&mut self, data: impl AsRef<[u8]>) -> Result<()> {
117        self.send_raw(Frame {
118            fin: true,
119            opcode: 10,
120            data: data.as_ref(),
121        })
122        .await
123    }
124
125    /// Flushes this output stream, ensuring that all intermediately buffered contents reach their destination.
126    pub async fn flash(&mut self) -> Result<()> {
127        self.stream.flush().await
128    }
129}
130
131// ------------------------------------------------------------------------
132
/// Early-returns `Ok(Event::Error(msg))` from the enclosing function, which must
/// return `Result<Event>`. Protocol violations are reported this way rather than as I/O errors.
macro_rules! err {
    ($msg:expr) => {
        return Ok(Event::Error($msg))
    };
}
134
135#[inline]
136pub async fn read_buf<const N: usize, R>(stream: &mut R) -> Result<[u8; N]>
137where
138    R: Unpin + AsyncRead,
139{
140    let mut buf = [0; N];
141    stream.read_exact(&mut buf).await?;
142    Ok(buf)
143}
144
145impl<R> WebSocket<R>
146where
147    R: Unpin + AsyncRead,
148{
149    /// reads [Event] from websocket stream.
150    pub async fn recv(&mut self) -> Result<Event> {
151        if self.is_closed {
152            return Err(std::io::Error::new(
153                std::io::ErrorKind::NotConnected,
154                "read after close",
155            ));
156        }
157        let event = self.recv_event().await;
158        if let Ok(Event::Close { .. } | Event::Error(..)) | Err(..) = event {
159            self.is_closed = true;
160        }
161        event
162    }
163
164    // ### WebSocket Frame Header
165    //
166    // ```txt
167    //  0                   1                   2                   3
168    //  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
169    // +-+-+-+-+-------+-+-------------+-------------------------------+
170    // |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
171    // |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
172    // |N|V|V|V|       |S|             |   (if payload len==126/127)   |
173    // | |1|2|3|       |K|             |                               |
174    // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
175    // |     Extended payload length continued, if payload len == 127  |
176    // + - - - - - - - - - - - - - - - +-------------------------------+
177    // |                               |Masking-key, if MASK set to 1  |
178    // +-------------------------------+-------------------------------+
179    // | Masking-key (continued)       |          Payload Data         |
180    // +-------------------------------- - - - - - - - - - - - - - - - +
181    // :                     Payload Data continued ...                :
182    // + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
183    // |                     Payload Data continued ...                |
184    // +---------------------------------------------------------------+
185    // ```
186    /// reads [Event] from websocket stream.
187    pub async fn recv_event(&mut self) -> Result<Event> {
188        let [b1, b2] = read_buf(&mut self.stream).await?;
189
190        let fin = b1 & 0b_1000_0000 != 0;
191        let rsv = b1 & 0b_111_0000;
192        let opcode = b1 & 0b_1111;
193        let len = (b2 & 0b_111_1111) as usize;
194
195        // Defines whether the "Payload data" is masked.  If set to 1, a
196        // masking key is present in masking-key, and this is used to unmask
197        // the "Payload data" as per [Section 5.3](https://datatracker.ietf.org/doc/html/rfc6455#section-5.3).  All frames sent from
198        // client to server have this bit set to 1.
199        let is_masked = b2 & 0b_1000_0000 != 0;
200
201        if rsv != 0 {
202            // MUST be `0` unless an extension is negotiated that defines meanings
203            // for non-zero values.  If a nonzero value is received and none of
204            // the negotiated extensions defines the meaning of such a nonzero
205            // value, the receiving endpoint MUST _Fail the WebSocket Connection_.
206            err!("reserve bit must be `0`");
207        }
208
209        // A client MUST mask all frames that it sends to the server. (Note
210        // that masking is done whether or not the WebSocket Protocol is running
211        // over TLS.)  The server MUST close the connection upon receiving a
212        // frame that is not masked.
213        //
214        // A server MUST NOT mask any frames that it sends to the client.
215        if let Role::Server = self.role {
216            if !is_masked {
217                err!("expected masked frame");
218            }
219        } else if is_masked {
220            err!("expected unmasked frame");
221        }
222
223        // 3-7 are reserved for further non-control frames.
224        if opcode >= 8 {
225            if !fin {
226                err!("control frame must not be fragmented");
227            }
228            if len > 125 {
229                err!("control frame must have a payload length of 125 bytes or less");
230            }
231            let msg = self.read_payload(len).await?;
232            match opcode {
233                8 => Ok(on_close(&msg)),
234                9 => Ok(Event::Ping(msg)),
235                10 => Ok(Event::Pong(msg)),
236                // 11-15 are reserved for further control frames
237                _ => err!("unknown opcode"),
238            }
239        } else {
240            let ty = match (opcode, fin, self.fragment) {
241                (2, true, None) => DataType::Complete(MessageType::Binary),
242                (1, true, None) => DataType::Complete(MessageType::Text),
243                (2, false, None) => {
244                    self.fragment = Some(MessageType::Binary);
245                    DataType::Stream(Stream::Start(MessageType::Binary))
246                }
247                (1, false, None) => {
248                    self.fragment = Some(MessageType::Text);
249                    DataType::Stream(Stream::Start(MessageType::Text))
250                }
251                (0, false, Some(ty)) => DataType::Stream(Stream::Next(ty)),
252                (0, true, Some(ty)) => {
253                    self.fragment = None;
254                    DataType::Stream(Stream::End(ty))
255                }
256                _ => err!("invalid data frame"),
257            };
258            let len = match len {
259                126 => u16::from_be_bytes(read_buf(&mut self.stream).await?) as usize,
260                127 => u64::from_be_bytes(read_buf(&mut self.stream).await?) as usize,
261                len => len,
262            };
263            if len > self.max_payload_len {
264                err!("payload too large");
265            }
266            let data = self.read_payload(len).await?;
267            Ok(Event::Data { ty, data })
268        }
269    }
270
271    async fn read_payload(&mut self, len: usize) -> Result<Box<[u8]>> {
272        let mut data = vec![0; len].into_boxed_slice();
273        match self.role {
274            Role::Server => {
275                let mask: [u8; 4] = read_buf(&mut self.stream).await?;
276                self.stream.read_exact(&mut data).await?;
277                // TODO: Use SIMD wherever possible for best performance
278                for i in 0..data.len() {
279                    data[i] ^= mask[i & 3];
280                }
281            }
282            Role::Client => {
283                self.stream.read_exact(&mut data).await?;
284            }
285        }
286        Ok(data)
287    }
288}
289
290/// - If there is a body, the first two bytes of the body MUST be a 2-byte unsigned integer (in network byte order: Big Endian)
291///   representing a status code with value /code/ defined in [Section 7.4](https:///datatracker.ietf.org/doc/html/rfc6455#section-7.4).
292///   Following the 2-byte integer,
293///
294/// - The application MUST NOT send any more data frames after sending a `Close` frame.
295///
296/// - If an endpoint receives a Close frame and did not previously send a
297///   Close frame, the endpoint MUST send a Close frame in response.  (When
298///   sending a Close frame in response, the endpoint typically echos the
299///   status code it received.)  It SHOULD do so as soon as practical.  An
300///   endpoint MAY delay sending a Close frame until its current message is
301///   sent
302///
303/// - After both sending and receiving a Close message, an endpoint
304///   considers the WebSocket connection closed and MUST close the
305///   underlying TCP connection.
306fn on_close(msg: &[u8]) -> Event {
307    let code = msg
308        .get(..2)
309        .map(|bytes| u16::from_be_bytes([bytes[0], bytes[1]]))
310        .unwrap_or(1000);
311
312    match code {
313        1000..=1003 | 1007..=1011 | 1015 | 3000..=3999 | 4000..=4999 => {
314            match msg.get(2..).map(|data| String::from_utf8(data.to_vec())) {
315                Some(Ok(msg)) => Event::Close {
316                    code,
317                    reason: msg.into_boxed_str(),
318                },
319                None => Event::Close {
320                    code,
321                    reason: "".into(),
322                },
323                Some(Err(_)) => Event::Error("invalid utf-8 payload"),
324            }
325        }
326        _ => Event::Error("invalid close code"),
327    }
328}
329
330impl<IO> From<(IO, Role)> for WebSocket<IO> {
331    #[inline]
332    fn from((stream, role): (IO, Role)) -> Self {
333        Self {
334            stream,
335            max_payload_len: 16 * 1024 * 1024,
336            role,
337            is_closed: false,
338            fragment: None,
339        }
340    }
341}