websocket_codec/
frame.rs

1use std::convert::TryFrom;
2use std::mem;
3use std::usize;
4
5use byteorder::{BigEndian, ByteOrder, NativeEndian};
6use bytes::BytesMut;
7use tokio_util::codec::{Decoder, Encoder};
8
9use crate::mask::Mask;
10use crate::{Error, Result};
11
/// Describes the length of the payload data within an individual WebSocket frame.
///
/// Each variant corresponds to one of the three length encodings of the frame header. The
/// variants do not enforce their ranges on their own; the `TryFrom<DataLength>` conversions reject
/// non-minimal `Medium`/`Large` encodings and over-long `Large` values.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum DataLength {
    /// A 7-bit length, used for payloads of 0 to 125 bytes.
    Small(u8),
    /// A 16-bit extended length, used for payloads of 126 to 65,535 bytes.
    Medium(u16),
    /// A 64-bit extended length, used for payloads of 65,536 to 2^63 - 1 bytes.
    Large(u64),
}
22
23impl From<u64> for DataLength {
24    fn from(n: u64) -> Self {
25        if n <= 125 {
26            Self::Small(n as u8)
27        } else if n <= 65535 {
28            Self::Medium(n as u16)
29        } else {
30            Self::Large(n)
31        }
32    }
33}
34
35impl TryFrom<DataLength> for u64 {
36    type Error = Error;
37
38    fn try_from(len: DataLength) -> Result<Self> {
39        match len {
40            DataLength::Small(n) => Ok(n as u64),
41            DataLength::Medium(n) => {
42                if n <= 125 {
43                    return Err(format!("payload length {} should not be represented using 16 bits", n).into());
44                }
45
46                Ok(n as u64)
47            }
48            DataLength::Large(n) => {
49                if n <= 65535 {
50                    return Err(format!("payload length {} should not be represented using 64 bits", n).into());
51                }
52
53                if n >= 0x8000_0000_0000_0000 {
54                    return Err(format!("frame is too long: {} bytes ({:x})", n, n).into());
55                }
56
57                Ok(n as u64)
58            }
59        }
60    }
61}
62
63impl From<usize> for DataLength {
64    fn from(n: usize) -> Self {
65        Self::from(n as u64)
66    }
67}
68
69impl TryFrom<DataLength> for usize {
70    type Error = Error;
71
72    fn try_from(len: DataLength) -> Result<Self> {
73        let len = u64::try_from(len)?;
74        if len > usize::MAX as u64 {
75            return Err(format!(
76                "frame of {} bytes can't be parsed on a {}-bit platform",
77                len,
78                mem::size_of::<usize>() / 8
79            )
80            .into());
81        }
82
83        Ok(len as usize)
84    }
85}
86
87/// Describes an individual frame within a WebSocket message at a low level.
88///
89/// The frame header is a lower level detail of the WebSocket protocol. At the application level,
90/// use [`Message`](struct.Message.html) structs and the [`MessageCodec`](struct.MessageCodec.html).
91#[derive(Clone, Debug, PartialEq)]
92pub struct FrameHeader {
93    pub(crate) fin: bool,
94    pub(crate) rsv: u8,
95    pub(crate) opcode: u8,
96    pub(crate) mask: Option<Mask>,
97    pub(crate) data_len: DataLength,
98}
99
100impl FrameHeader {
101    /// Returns a `FrameHeader` struct.
102    pub fn new(fin: bool, rsv: u8, opcode: u8, mask: Option<Mask>, data_len: DataLength) -> Self {
103        Self {
104            fin,
105            rsv,
106            opcode,
107            mask,
108            data_len,
109        }
110    }
111
112    /// Returns the WebSocket FIN bit, which indicates that this is the last frame in the message.
113    pub fn fin(&self) -> bool {
114        self.fin
115    }
116
117    /// Returns the WebSocket RSV1, RSV2 and RSV3 bits.
118    ///
119    /// The RSV bits may be used by extensions to the WebSocket protocol not exposed by this crate.
120    pub fn rsv(&self) -> u8 {
121        self.rsv
122    }
123
124    /// Returns the WebSocket opcode, which defines the interpretation of the frame payload data.
125    pub fn opcode(&self) -> u8 {
126        self.opcode
127    }
128
129    /// Returns the frame's mask.
130    pub fn mask(&self) -> Option<Mask> {
131        self.mask
132    }
133
134    /// Returns the length of the payload data that follows this header.
135    pub fn data_len(&self) -> DataLength {
136        self.data_len
137    }
138
139    /// Returns the total length of the frame header.
140    ///
141    /// The frame header is between 2 bytes and 10 bytes in length, depending on the presence of a mask
142    /// and the length of the payload data.
143    pub fn header_len(&self) -> usize {
144        let mut len = 1 /* fin|opcode */ + 1 /* mask|len1 */;
145        len += match self.data_len {
146            DataLength::Small(_) => 0,
147            DataLength::Medium(_) => 2,
148            DataLength::Large(_) => 8,
149        };
150
151        if self.mask.is_some() {
152            len += 4;
153        }
154
155        len
156    }
157
158    pub(crate) fn parse_slice(buf: &[u8]) -> Option<(Self, usize)> {
159        if buf.len() < 2 {
160            return None;
161        }
162
163        let fin_opcode = buf[0];
164        let mask_data_len = buf[1];
165        let mut header_len = 2;
166        let fin = (fin_opcode & 0x80) != 0;
167        let rsv = (fin_opcode & 0xf0) & !0x80;
168        let opcode = fin_opcode & 0x0f;
169
170        let (buf, data_len) = match mask_data_len & 0x7f {
171            127 => {
172                if buf.len() < 10 {
173                    return None;
174                }
175
176                header_len += 8;
177
178                (&buf[10..], DataLength::Large(BigEndian::read_u64(&buf[2..10])))
179            }
180            126 => {
181                if buf.len() < 4 {
182                    return None;
183                }
184
185                header_len += 2;
186
187                (&buf[4..], DataLength::Medium(BigEndian::read_u16(&buf[2..4])))
188            }
189            n => {
190                assert!(n < 126);
191                (&buf[2..], DataLength::Small(n))
192            }
193        };
194
195        let mask = if mask_data_len & 0x80 == 0 {
196            None
197        } else {
198            if buf.len() < 4 {
199                return None;
200            }
201
202            header_len += 4;
203            Some(NativeEndian::read_u32(buf).into())
204        };
205
206        let header = Self {
207            fin,
208            rsv,
209            opcode,
210            mask,
211            data_len,
212        };
213
214        debug_assert_eq!(header.header_len(), header_len);
215        Some((header, header_len))
216    }
217
218    pub(crate) fn write_to_slice(&self, dst: &mut [u8]) {
219        let FrameHeader {
220            fin,
221            rsv,
222            opcode,
223            mask,
224            data_len,
225        } = *self;
226
227        let mut fin_opcode = rsv | opcode;
228        if fin {
229            fin_opcode |= 0x80
230        };
231
232        dst[0] = fin_opcode;
233
234        let mask_bit = if mask.is_some() { 0x80 } else { 0 };
235
236        let dst = match data_len {
237            DataLength::Small(n) => {
238                dst[1] = mask_bit | n;
239                &mut dst[2..]
240            }
241            DataLength::Medium(n) => {
242                let (dst, rest) = dst.split_at_mut(4);
243                dst[1] = mask_bit | 126;
244                BigEndian::write_u16(&mut dst[2..4], n);
245                rest
246            }
247            DataLength::Large(n) => {
248                let (dst, rest) = dst.split_at_mut(10);
249                dst[1] = mask_bit | 127;
250                BigEndian::write_u64(&mut dst[2..10], n);
251                rest
252            }
253        };
254
255        if let Some(mask) = mask {
256            NativeEndian::write_u32(dst, mask.into());
257        }
258    }
259
260    pub(crate) fn write_to_bytes(&self, dst: &mut BytesMut) {
261        let data_len = match self.data_len {
262            DataLength::Small(n) => n as usize,
263            DataLength::Medium(n) => n as usize,
264            DataLength::Large(n) => n as usize,
265        };
266
267        let initial_len = dst.len();
268        let header_len = self.header_len();
269        dst.reserve(header_len + data_len);
270
271        unsafe {
272            dst.set_len(initial_len + header_len);
273        }
274
275        let dst_slice = &mut dst[initial_len..(initial_len + header_len)];
276        self.write_to_slice(dst_slice);
277    }
278}
279
/// Tokio codec for the low-level header portion of WebSocket frames.
/// This codec can send and receive [`FrameHeader`](struct.FrameHeader.html) structs.
///
/// The codec carries no state, so it is freely copyable and constructible via `Default`.
///
/// The frame header is a lower level detail of the WebSocket protocol. At the application level,
/// use [`Message`](struct.Message.html) structs and the [`MessageCodec`](struct.MessageCodec.html).
#[derive(Clone, Copy, Debug, Default)]
pub struct FrameHeaderCodec;
286
287impl Decoder for FrameHeaderCodec {
288    type Item = FrameHeader;
289    type Error = Error;
290
291    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<FrameHeader>> {
292        use bytes::Buf;
293
294        Ok(FrameHeader::parse_slice(src.chunk()).map(|(header, header_len)| {
295            src.advance(header_len);
296            header
297        }))
298    }
299}
300
301impl Encoder<FrameHeader> for FrameHeaderCodec {
302    type Error = Error;
303
304    fn encode(&mut self, item: FrameHeader, dst: &mut BytesMut) -> Result<()> {
305        self.encode(&item, dst)
306    }
307}
308
309impl<'a> Encoder<&'a FrameHeader> for FrameHeaderCodec {
310    type Error = Error;
311
312    fn encode(&mut self, item: &'a FrameHeader, dst: &mut BytesMut) -> Result<()> {
313        item.write_to_bytes(dst);
314        Ok(())
315    }
316}
317
#[cfg(test)]
mod tests {
    use assert_allocations::assert_allocated_bytes;
    use bytes::BytesMut;
    use tokio_util::codec::{Decoder, Encoder};

    use crate::frame::{FrameHeader, FrameHeaderCodec};

    /// Encoding a header and decoding it again yields an identical header. The whole cycle
    /// allocates exactly the header plus the announced payload length (at least 8 bytes).
    #[quickcheck]
    fn round_trips(fin: bool, is_text: bool, mask: Option<u32>, data_len: u16) {
        let opcode = if is_text { 1 } else { 2 };
        let original = assert_allocated_bytes(0, || FrameHeader {
            fin,
            rsv: 0,
            opcode,
            mask: mask.map(Into::into),
            data_len: u64::from(data_len).into(),
        });

        let expected_header_len = original.header_len();
        let expected_alloc = (expected_header_len + usize::from(data_len)).max(8);

        assert_allocated_bytes(expected_alloc, || {
            let mut codec = FrameHeaderCodec;
            let mut buf = BytesMut::new();

            codec.encode(&original, &mut buf).unwrap();
            assert_eq!(buf.len(), expected_header_len);

            let decoded = codec.decode(&mut buf).unwrap().unwrap();
            assert_eq!(decoded.header_len(), expected_header_len);
            assert!(buf.is_empty());
            assert_eq!(original, decoded)
        })
    }
}