tokio_websockets/proto/
types.rs

1//! Types required for the WebSocket protocol implementation.
2use std::{fmt, mem::replace, num::NonZeroU16, ops::Deref};
3
4use bytes::{BufMut, Bytes, BytesMut};
5
6use super::error::ProtocolError;
7use crate::utf8;
8
9/// The opcode of a WebSocket frame. It denotes the type of the frame or an
10/// assembled message.
11///
12/// A fully assembled [`Message`] will never have a continuation opcode.
13#[derive(Debug, PartialEq, Eq, Clone, Copy)]
14pub(super) enum OpCode {
15    /// A continuation opcode. This will never be encountered in a full
16    /// [`Message`].
17    Continuation,
18    /// A text opcode.
19    Text,
20    /// A binary opcode.
21    Binary,
22    /// A close opcode.
23    Close,
24    /// A ping opcode.
25    Ping,
26    /// A pong opcode.
27    Pong,
28}
29
30impl OpCode {
31    /// Whether this is a control opcode (i.e. close, ping or pong).
32    pub(super) fn is_control(self) -> bool {
33        matches!(self, Self::Close | Self::Ping | Self::Pong)
34    }
35}
36
37impl TryFrom<u8> for OpCode {
38    type Error = ProtocolError;
39
40    fn try_from(value: u8) -> Result<Self, Self::Error> {
41        match value {
42            0 => Ok(Self::Continuation),
43            1 => Ok(Self::Text),
44            2 => Ok(Self::Binary),
45            8 => Ok(Self::Close),
46            9 => Ok(Self::Ping),
47            10 => Ok(Self::Pong),
48            _ => Err(ProtocolError::InvalidOpcode),
49        }
50    }
51}
52
53impl From<OpCode> for u8 {
54    fn from(value: OpCode) -> Self {
55        match value {
56            OpCode::Continuation => 0,
57            OpCode::Text => 1,
58            OpCode::Binary => 2,
59            OpCode::Close => 8,
60            OpCode::Ping => 9,
61            OpCode::Pong => 10,
62        }
63    }
64}
65
/// Close status code.
///
/// Wraps the numeric status code of a close frame. Only codes in the ranges
/// `1000..=1015` and `3000..=4999` can be represented.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct CloseCode(NonZeroU16);

// rustfmt reorders these alphabetically
#[rustfmt::skip]
impl CloseCode {
    /// Normal closure, meaning that the purpose for which the connection was
    /// established has been fulfilled.
    pub const NORMAL_CLOSURE: Self = Self::constant(1000);
    /// Endpoint is "going away", such as a server going down or a browser
    /// having navigated away from a page.
    pub const GOING_AWAY: Self = Self::constant(1001);
    /// Endpoint is terminating the connection due to a protocol error.
    pub const PROTOCOL_ERROR: Self = Self::constant(1002);
    /// Endpoint is terminating the connection because it has received a type of
    /// data it cannot accept.
    pub const UNSUPPORTED_DATA: Self = Self::constant(1003);
    /// No status code was actually present.
    pub const NO_STATUS_RECEIVED: Self = Self::constant(1005);
    /// Endpoint is terminating the connection because it has received data
    /// within a message that was not consistent with the type of the message.
    pub const INVALID_FRAME_PAYLOAD_DATA: Self = Self::constant(1007);
    /// Endpoint is terminating the connection because it has received a message
    /// that violates its policy.
    pub const POLICY_VIOLATION: Self = Self::constant(1008);
    /// Endpoint is terminating the connection because it has received a message
    /// that is too big for it to process.
    pub const MESSAGE_TOO_BIG: Self = Self::constant(1009);
    /// Client is terminating the connection because it has expected the server
    /// to negotiate one or more extensions, but the server didn't return them
    /// in the response message of the WebSocket handshake.
    pub const MANDATORY_EXTENSION: Self = Self::constant(1010);
    /// Server is terminating the connection because it encountered an
    /// unexpected condition that prevented it from fulfilling the request.
    pub const INTERNAL_SERVER_ERROR: Self = Self::constant(1011);
    /// Service is restarted. A client may reconnect, and if it chooses to do
    /// so, should reconnect using a randomized delay of 5--30s.
    pub const SERVICE_RESTART: Self = Self::constant(1012);
    /// Service is experiencing overload. A client should only connect to a
    /// different IP (when there are multiple for the target) or reconnect to
    /// the same IP upon user action.
    pub const SERVICE_OVERLOAD: Self = Self::constant(1013);
    /// The server was acting as a gateway or proxy and received an invalid
    /// response from the upstream server. This is similar to the HTTP 502
    /// status code.
    pub const BAD_GATEWAY: Self = Self::constant(1014);
}

impl CloseCode {
    /// Try to construct [`CloseCode`] from `u16`.
    ///
    /// Returns `None` if `code` is not a valid `CloseCode`, i.e. if it lies
    /// outside of `1000..=1015` and `3000..=4999`.
    const fn try_from_u16(code: u16) -> Option<Self> {
        match code {
            1000..=1015 | 3000..=4999 => {
                // FIXME: replace with `Some(Self(NonZeroU16::new(code).unwrap()))`
                // once MSRV is bumped to 1.83
                match NonZeroU16::new(code) {
                    Some(code) => Some(Self(code)),
                    // The accepted ranges do not contain zero.
                    None => unreachable!(),
                }
            }
            0..=999 | 1016..=2999 | 5000..=u16::MAX => None,
        }
    }

    /// Construct a [`CloseCode`] from `u16` for use in constant definitions.
    ///
    /// # Panics
    ///
    /// Panics if `code` is not a valid `CloseCode`. When used to initialize a
    /// constant, this turns an invalid code into a compile-time error.
    const fn constant(code: u16) -> Self {
        // FIXME: replace with `Self::try_from_u16(code).unwrap()`
        // once MSRV is bumped to 1.83
        match Self::try_from_u16(code) {
            Some(code) => code,
            None => panic!("invalid close code constant"),
        }
    }

    /// Whether the close code is reserved and cannot be sent over the wire.
    #[must_use]
    pub fn is_reserved(self) -> bool {
        match self.0.get() {
            1004 | 1005 | 1006 | 1015 => true,
            1000..=1003 | 1007..=1014 | 3000..=4999 => false,
            // Mirrors the values rejected by `try_from_u16`, which is the only
            // place where a `CloseCode` is constructed.
            0..=999 | 1016..=2999 | 5000..=u16::MAX => {
                debug_assert!(false, "unexpected CloseCode");
                false
            }
        }
    }
}
159
160impl From<CloseCode> for u16 {
161    fn from(value: CloseCode) -> Self {
162        value.0.get()
163    }
164}
165
166impl TryFrom<u16> for CloseCode {
167    type Error = ProtocolError;
168
169    fn try_from(value: u16) -> Result<Self, Self::Error> {
170        Self::try_from_u16(value).ok_or(ProtocolError::InvalidCloseCode)
171    }
172}
173
174/// The websocket message payload storage.
175///
176/// Payloads can be created by using the `From<T>` implementations.
177///
178/// Sending the payloads or calling [`Into<BytesMut>`] is zero-copy, except when
179/// sending a payload created from a static slice or when the payload buffer is
180/// not unique. All conversions to other types are zero-cost.
181///
182/// [`Into<BytesMut>`]: #impl-From<Payload>-for-BytesMut
183#[derive(Clone)]
184pub struct Payload {
185    /// The raw payload data.
186    data: Bytes,
187    /// Whether the payload data was validated to be valid UTF-8.
188    utf8_validated: bool,
189}
190
191impl Payload {
192    /// Creates a new shared `Payload` from a static slice.
193    const fn from_static(bytes: &'static [u8]) -> Self {
194        Self {
195            data: Bytes::from_static(bytes),
196            utf8_validated: false,
197        }
198    }
199
200    /// Marks whether the payload contents were validated to be valid UTF-8.
201    pub(super) fn set_utf8_validated(&mut self, value: bool) {
202        self.utf8_validated = value;
203    }
204
205    /// Shortens the buffer, keeping the first `len` bytes and dropping the
206    /// rest.
207    pub(super) fn truncate(&mut self, len: usize) {
208        self.data.truncate(len);
209    }
210
211    /// Splits the buffer into two at the given index.
212    fn split_to(&mut self, at: usize) -> Self {
213        // This is only used by the outgoing message frame iterator, so we do not care
214        // about the value of utf8_validated. For the sake of correctness (in case we
215        // split a utf8 codepoint), we set it to false.
216        self.utf8_validated = false;
217        Self {
218            data: self.data.split_to(at),
219            utf8_validated: false,
220        }
221    }
222}
223
224impl Deref for Payload {
225    type Target = [u8];
226
227    fn deref(&self) -> &Self::Target {
228        &self.data
229    }
230}
231
232impl fmt::Debug for Payload {
233    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234        f.debug_tuple("Payload").field(&self.data).finish()
235    }
236}
237
238impl From<Bytes> for Payload {
239    fn from(value: Bytes) -> Self {
240        Self {
241            data: value,
242            utf8_validated: false,
243        }
244    }
245}
246
247impl From<BytesMut> for Payload {
248    fn from(value: BytesMut) -> Self {
249        Self {
250            data: value.freeze(),
251            utf8_validated: false,
252        }
253    }
254}
255
256impl From<Payload> for Bytes {
257    fn from(value: Payload) -> Self {
258        value.data
259    }
260}
261
262impl From<Payload> for BytesMut {
263    fn from(value: Payload) -> Self {
264        value.data.into()
265    }
266}
267
268impl From<Vec<u8>> for Payload {
269    fn from(value: Vec<u8>) -> Self {
270        // BytesMut::from_iter goes through a specialization in std if the iterator is a
271        // Vec, effectively allowing us to use BytesMut::from_vec which isn't
272        // exposed in bytes. See https://github.com/tokio-rs/bytes/issues/723 for details.
273        Self {
274            data: BytesMut::from_iter(value).freeze(),
275            utf8_validated: false,
276        }
277    }
278}
279
280impl From<String> for Payload {
281    fn from(value: String) -> Self {
282        // See From<Vec<u8>> impl for reasoning behind this.
283        Self {
284            data: BytesMut::from_iter(value.into_bytes()).freeze(),
285            utf8_validated: true,
286        }
287    }
288}
289
290impl From<&'static [u8]> for Payload {
291    fn from(value: &'static [u8]) -> Self {
292        Self {
293            data: Bytes::from_static(value),
294            utf8_validated: false,
295        }
296    }
297}
298
299impl From<&'static str> for Payload {
300    fn from(value: &'static str) -> Self {
301        Self {
302            data: Bytes::from_static(value.as_bytes()),
303            utf8_validated: true,
304        }
305    }
306}
307
308/// A WebSocket message. This is cheaply clonable and uses [`Payload`] as the
309/// payload storage underneath.
310///
311/// Received messages are always validated prior to dealing with them, so all
312/// the type casting methods are either almost or fully zero cost.
313#[derive(Debug, Clone)]
314pub struct Message {
315    /// The [`OpCode`] of the message.
316    pub(super) opcode: OpCode,
317    /// The payload of the message.
318    pub(super) payload: Payload,
319}
320
321impl Message {
322    /// Create a new text message. The payload contents must be valid UTF-8.
323    #[must_use]
324    pub fn text<P: Into<Payload>>(payload: P) -> Self {
325        Self {
326            opcode: OpCode::Text,
327            payload: payload.into(),
328        }
329    }
330
331    /// Create a new binary message.
332    #[must_use]
333    pub fn binary<P: Into<Payload>>(payload: P) -> Self {
334        Self {
335            opcode: OpCode::Binary,
336            payload: payload.into(),
337        }
338    }
339
340    /// Create a new close message. If an non-empty reason is specified, a
341    /// [`CloseCode`] must be specified for it to be included.
342    ///
343    /// # Panics
344    /// - If the `code` is reserved so it cannot be sent.
345    /// - If `code` is present and the `reason` exceeds 123 bytes, the
346    ///   protocol-imposed limit.
347    #[must_use]
348    #[track_caller]
349    pub fn close(code: Option<CloseCode>, reason: &str) -> Self {
350        let mut payload = BytesMut::with_capacity((2 + reason.len()) * usize::from(code.is_some()));
351
352        if let Some(code) = code {
353            assert!(!code.is_reserved());
354            payload.put_u16(code.into());
355
356            assert!(reason.len() <= 123);
357            payload.extend_from_slice(reason.as_bytes());
358        }
359
360        Self {
361            opcode: OpCode::Close,
362            payload: payload.into(),
363        }
364    }
365
366    /// Create a new ping message.
367    ///
368    /// # Panics
369    /// If the payload exceeds 125 bytes, the protocol-imposed limit.
370    #[must_use]
371    #[track_caller]
372    pub fn ping<P: Into<Payload>>(payload: P) -> Self {
373        let payload = payload.into();
374        assert!(payload.len() <= 125);
375        Self {
376            opcode: OpCode::Ping,
377            payload,
378        }
379    }
380
381    /// Create a new pong message.
382    ///
383    /// # Panics
384    /// If the payload exceeds 125 bytes, the protocol-imposed limit.
385    #[must_use]
386    #[track_caller]
387    pub fn pong<P: Into<Payload>>(payload: P) -> Self {
388        let payload = payload.into();
389        assert!(payload.len() <= 125);
390        Self {
391            opcode: OpCode::Pong,
392            payload,
393        }
394    }
395
396    /// Whether the message is a text message.
397    #[must_use]
398    pub fn is_text(&self) -> bool {
399        self.opcode == OpCode::Text
400    }
401
402    /// Whether the message is a binary message.
403    #[must_use]
404    pub fn is_binary(&self) -> bool {
405        self.opcode == OpCode::Binary
406    }
407
408    /// Whether the message is a close message.
409    #[must_use]
410    pub fn is_close(&self) -> bool {
411        self.opcode == OpCode::Close
412    }
413
414    /// Whether the message is a ping message.
415    #[must_use]
416    pub fn is_ping(&self) -> bool {
417        self.opcode == OpCode::Ping
418    }
419
420    /// Whether the message is a pong message.
421    #[must_use]
422    pub fn is_pong(&self) -> bool {
423        self.opcode == OpCode::Pong
424    }
425
426    /// Returns the message payload and consumes the message, regardless of
427    /// type.
428    #[must_use]
429    pub fn into_payload(self) -> Payload {
430        self.payload
431    }
432
433    /// Returns a reference to the message payload, regardless of message type.
434    pub fn as_payload(&self) -> &Payload {
435        &self.payload
436    }
437
438    /// Returns a reference to the message payload as a string if it is a text
439    /// message.
440    ///
441    /// # Panics
442    ///
443    /// This method will panic when the message was created via
444    /// [`Message::text`] with invalid UTF-8.
445    pub fn as_text(&self) -> Option<&str> {
446        // SAFETY: Received messages were validated to be valid UTF-8, otherwise
447        // we check if it is valid UTF-8.
448        (self.opcode == OpCode::Text).then(|| {
449            assert!(
450                self.payload.utf8_validated || utf8::parse_str(&self.payload).is_ok(),
451                "called as_text on message created from payload with invalid utf-8"
452            );
453            unsafe { std::str::from_utf8_unchecked(&self.payload) }
454        })
455    }
456
457    /// Returns the [`CloseCode`] and close reason if the message is a close
458    /// message.
459    pub fn as_close(&self) -> Option<(CloseCode, &str)> {
460        (self.opcode == OpCode::Close).then(|| {
461            let code = if self.payload.is_empty() {
462                CloseCode::NO_STATUS_RECEIVED
463            } else {
464                // SAFETY: Opcode is Close with a non-empty payload so it's at least 2 bytes
465                // long
466                unsafe {
467                    CloseCode::try_from(u16::from_be_bytes(
468                        self.payload
469                            .get_unchecked(0..2)
470                            .try_into()
471                            .unwrap_unchecked(),
472                    ))
473                    .unwrap_unchecked()
474                }
475            };
476
477            // SAFETY: Opcode is Close so the rest of the payload is valid UTF-8
478            let reason =
479                unsafe { std::str::from_utf8_unchecked(self.payload.get(2..).unwrap_or_default()) };
480
481            (code, reason)
482        })
483    }
484
485    /// Returns an iterator over frames of `frame_size` length to split this
486    /// message into.
487    pub(super) fn into_frames(self, frame_size: usize) -> MessageFrames {
488        MessageFrames {
489            frame_size,
490            payload: self.payload,
491            opcode: self.opcode,
492        }
493    }
494}
495
496/// Iterator over frames of a chunked message.
497pub(super) struct MessageFrames {
498    /// Iterator over payload chunks.
499    frame_size: usize,
500    /// The full message payload this iterates over.
501    payload: Payload,
502    /// Opcode for the next frame.
503    opcode: OpCode,
504}
505
506impl Iterator for MessageFrames {
507    type Item = Frame;
508
509    fn next(&mut self) -> Option<Self::Item> {
510        let is_empty = self.payload.is_empty() && self.opcode == OpCode::Continuation;
511
512        (!is_empty).then(|| {
513            let payload = self
514                .payload
515                .split_to(self.frame_size.min(self.payload.len()));
516
517            Frame {
518                opcode: replace(&mut self.opcode, OpCode::Continuation),
519                is_final: self.payload.is_empty(),
520                payload,
521            }
522        })
523    }
524}
525
526/// Configuration for limitations on reading of [`Message`]s from a
527/// [`WebSocketStream`] to prevent high memory usage caused by malicious actors.
528///
529/// [`WebSocketStream`]: super::WebSocketStream
530#[derive(Debug, Clone, Copy)]
531pub struct Limits {
532    /// The maximum allowed payload length. The default is 64 MiB.
533    pub(super) max_payload_len: usize,
534}
535
536impl Limits {
537    /// A limit configuration without any limits.
538    #[must_use]
539    pub fn unlimited() -> Self {
540        Self {
541            max_payload_len: usize::MAX,
542        }
543    }
544
545    /// Sets the maximum allowed payload length. `None` equals no limit.
546    ///
547    /// The default is 64 MiB.
548    #[must_use]
549    pub fn max_payload_len(mut self, size: Option<usize>) -> Self {
550        self.set_max_payload_len(size);
551
552        self
553    }
554
555    /// See [`max_payload_len`](Self::max_payload_len).
556    pub fn set_max_payload_len(&mut self, size: Option<usize>) {
557        self.max_payload_len = size.unwrap_or(usize::MAX);
558    }
559}
560
561impl Default for Limits {
562    fn default() -> Self {
563        Self {
564            max_payload_len: 64 * 1024 * 1024,
565        }
566    }
567}
568
569/// Low-level configuration for a [`WebSocketStream`] that allows configuring
570/// behavior for sending and receiving messages.
571///
572/// [`WebSocketStream`]: super::WebSocketStream
573#[derive(Debug, Clone, Copy)]
574pub struct Config {
575    /// Frame payload size to split outgoing messages into.
576    ///
577    /// Consider decreasing this if the remote imposes a limit on the frame
578    /// payload size. The default is 4MiB.
579    pub(super) frame_size: usize,
580    /// Threshold of queued up bytes after which the underlying I/O is flushed
581    /// before the sink is declared ready. The default is 8 KiB.
582    pub(super) flush_threshold: usize,
583}
584
585impl Config {
586    /// Set the frame payload size to split outgoing messages into.
587    ///
588    /// Consider decreasing this if the remote imposes a limit on the frame
589    /// payload size. The default is 4MiB.
590    ///
591    /// # Panics
592    ///
593    /// If `frame_size` is `0`.
594    #[must_use]
595    pub fn frame_size(mut self, frame_size: usize) -> Self {
596        assert_ne!(frame_size, 0, "frame_size must be non-zero");
597        self.frame_size = frame_size;
598
599        self
600    }
601
602    /// Sets the threshold of queued up bytes after which the underlying I/O is
603    /// flushed before the sink is declared ready. The default is 8 KiB.
604    #[must_use]
605    pub fn flush_threshold(mut self, threshold: usize) -> Self {
606        self.flush_threshold = threshold;
607
608        self
609    }
610}
611
612impl Default for Config {
613    fn default() -> Self {
614        Self {
615            frame_size: 4 * 1024 * 1024,
616            flush_threshold: 8 * 1024,
617        }
618    }
619}
620
/// The side of a connection that a [`WebSocketStream`] acts as.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum Role {
    /// The stream is the client end of the connection.
    Client,
    /// The stream is the server end of the connection.
    Server,
}
629
630/// The connection state of the stream.
631#[derive(Debug, PartialEq)]
632pub(super) enum StreamState {
633    /// The connection is fully active and no close has been initiated.
634    Active,
635    /// The connection has been closed by the peer, but not yet acknowledged by
636    /// us.
637    ClosedByPeer,
638    /// The connection has been closed by us, but not yet acknowledged.
639    ClosedByUs,
640    /// The close has been acknowledged by the end that did not initiate the
641    /// close.
642    CloseAcknowledged,
643}
644
645/// A frame of a WebSocket [`Message`].
646#[derive(Clone, Debug)]
647pub(super) struct Frame {
648    /// The [`OpCode`] of the frame.
649    pub opcode: OpCode,
650    /// Whether this is the last frame of a message.
651    pub is_final: bool,
652    /// The payload bytes of the frame.
653    pub payload: Payload,
654}
655
656impl Frame {
657    /// Default close frame.
658    #[allow(clippy::declare_interior_mutable_const)]
659    pub const DEFAULT_CLOSE: Self = Self {
660        opcode: OpCode::Close,
661        is_final: true,
662        payload: Payload::from_static(&CloseCode::NORMAL_CLOSURE.0.get().to_be_bytes()),
663    };
664
665    /// Encode the frame head into `out`, returning a subslice where the mask
666    /// should be written to.
667    pub fn encode<'a>(&self, out: &'a mut [u8; 14]) -> &'a mut [u8; 4] {
668        out[0] = (u8::from(self.is_final) << 7) | u8::from(self.opcode);
669        let mask_slice = if u16::try_from(self.payload.len()).is_err() {
670            out[1] = 127;
671            let len = u64::try_from(self.payload.len()).unwrap();
672            out[2..10].copy_from_slice(&len.to_be_bytes());
673            &mut out[10..14]
674        } else if self.payload.len() > 125 {
675            out[1] = 126;
676            let len = u16::try_from(self.payload.len()).expect("checked by previous branch");
677            out[2..4].copy_from_slice(&len.to_be_bytes());
678            &mut out[4..8]
679        } else {
680            out[1] = u8::try_from(self.payload.len()).expect("checked by previous branch");
681            &mut out[2..6]
682        };
683        mask_slice.try_into().unwrap()
684    }
685}
686
687impl From<Message> for Frame {
688    fn from(value: Message) -> Self {
689        Self {
690            opcode: value.opcode,
691            is_final: true,
692            payload: value.payload,
693        }
694    }
695}
696
697impl From<&ProtocolError> for Frame {
698    fn from(val: &ProtocolError) -> Self {
699        match val {
700            ProtocolError::InvalidUtf8 => {
701                Message::close(Some(CloseCode::INVALID_FRAME_PAYLOAD_DATA), "invalid utf8")
702            }
703            _ => Message::close(Some(CloseCode::PROTOCOL_ERROR), val.as_str()),
704        }
705        .into()
706    }
707}