websocket_codec/
message.rs

1use std::convert::TryFrom;
2use std::result;
3use std::str::{self, Utf8Error};
4use std::usize;
5
6use bytes::{Buf, BufMut, Bytes, BytesMut};
7use tokio_util::codec::{Decoder, Encoder};
8
9use crate::frame::FrameHeader;
10use crate::mask::{self, Mask};
11use crate::opcode::Opcode;
12use crate::{Error, Result};
13
/// A text string, a block of binary data or a WebSocket control frame.
///
/// A message is an opcode paired with its payload bytes. Both fields are private, so code outside this
/// module obtains values through the constructor functions on this type.
#[derive(Clone, Debug, PartialEq)]
pub struct Message {
    // The kind of message. `as_text` consults this to decide whether `data` may be viewed as a `str`.
    opcode: Opcode,
    // The payload. For text opcodes it is expected to be valid UTF-8: `text` builds it from a `String`
    // and `new` validates it.
    data: Bytes,
}
20
21impl Message {
22    /// Creates a message from a `Bytes` object.
23    ///
24    /// The message can be tagged as text or binary. When the `opcode` parameter is [`Opcode::Text`](enum.Opcode.html)
25    /// this function validates the bytes in `data` and returns `Err` if they do not contain valid UTF-8 text.
26    pub fn new<B: Into<Bytes>>(opcode: Opcode, data: B) -> result::Result<Self, Utf8Error> {
27        let data = data.into();
28
29        if opcode.is_text() {
30            str::from_utf8(&data)?;
31        }
32
33        Ok(Message { opcode, data })
34    }
35
36    /// Creates a text message from a `String`.
37    pub fn text<S: Into<String>>(data: S) -> Self {
38        Message {
39            opcode: Opcode::Text,
40            data: data.into().into(),
41        }
42    }
43
44    /// Creates a binary message from any type that can be converted to `Bytes`, such as `&[u8]` or `Vec<u8>`.
45    pub fn binary<B: Into<Bytes>>(data: B) -> Self {
46        Message {
47            opcode: Opcode::Binary,
48            data: data.into(),
49        }
50    }
51
52    pub(crate) fn header(&self, mask: Option<Mask>) -> FrameHeader {
53        FrameHeader {
54            fin: true,
55            rsv: 0,
56            opcode: self.opcode.into(),
57            mask,
58            data_len: self.data.len().into(),
59        }
60    }
61
62    /// Creates a message that indicates the connection is about to be closed.
63    ///
64    /// The `reason` parameter is an optional numerical status code and text description. Valid reasons
65    /// may be defined by a particular WebSocket server.
66    pub fn close(reason: Option<(u16, String)>) -> Self {
67        let data = if let Some((code, reason)) = reason {
68            let reason: Bytes = reason.into();
69            let mut buf = BytesMut::new();
70            buf.reserve(2 + reason.len());
71            buf.put_u16(code);
72            buf.put(reason);
73            buf.freeze()
74        } else {
75            Bytes::new()
76        };
77
78        Message {
79            opcode: Opcode::Close,
80            data,
81        }
82    }
83
84    /// Creates a message requesting a pong response.
85    ///
86    /// The client can send one of these to request a pong response from the server.
87    pub fn ping<B: Into<Bytes>>(data: B) -> Self {
88        Message {
89            opcode: Opcode::Ping,
90            data: data.into(),
91        }
92    }
93
94    /// Creates a response to a ping message.
95    ///
96    /// The client can send one of these in response to a ping from the server.
97    pub fn pong<B: Into<Bytes>>(data: B) -> Self {
98        Message {
99            opcode: Opcode::Pong,
100            data: data.into(),
101        }
102    }
103
104    /// Returns this message's WebSocket opcode.
105    pub fn opcode(&self) -> Opcode {
106        self.opcode
107    }
108
109    /// Returns a reference to the data held in this message.
110    pub fn data(&self) -> &Bytes {
111        &self.data
112    }
113
114    /// Consumes the message, returning its data.
115    pub fn into_data(self) -> Bytes {
116        self.data
117    }
118
119    /// For messages with opcode [`Opcode::Text`](enum.Opcode.html), returns a reference to the text.
120    /// Returns `None` otherwise.
121    pub fn as_text(&self) -> Option<&str> {
122        if self.opcode.is_text() {
123            Some(unsafe { str::from_utf8_unchecked(&self.data) })
124        } else {
125            None
126        }
127    }
128}
129
/// Tokio codec for WebSocket messages. This codec can send and receive [`Message`](struct.Message.html) structs.
///
/// The codec is stateful while decoding: a message split across several frames is accumulated inside the
/// codec until its final frame has been decoded.
#[derive(Clone)]
pub struct MessageCodec {
    // Opcode and payload gathered so far for a fragmented message whose final frame has not been decoded
    // yet. `None` when no fragmented message is in progress.
    interrupted_message: Option<(Opcode, BytesMut)>,
    // Whether `encode` generates a mask for each frame and writes the payload masked.
    use_mask: bool,
}
136
137impl MessageCodec {
138    /// Creates a `MessageCodec` for a client.
139    ///
140    /// Encoded messages are masked.
141    pub fn client() -> Self {
142        Self::with_masked_encode(true)
143    }
144
145    /// Creates a `MessageCodec` for a server.
146    ///
147    /// Encoded messages are not masked.
148    pub fn server() -> Self {
149        Self::with_masked_encode(false)
150    }
151
152    /// Creates a `MessageCodec` while specifying whether to use message masking while encoding.
153    pub fn with_masked_encode(use_mask: bool) -> Self {
154        Self {
155            use_mask,
156            interrupted_message: None,
157        }
158    }
159}
160
161impl Decoder for MessageCodec {
162    type Item = Message;
163    type Error = Error;
164
165    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Message>> {
166        let mut state = self.interrupted_message.take();
167        let (opcode, data) = loop {
168            let (header, header_len) = if let Some(tuple) = FrameHeader::parse_slice(src) {
169                tuple
170            } else {
171                // The buffer isn't big enough for the frame header. Reserve additional space for a frame header,
172                // plus reasonable extensions.
173                src.reserve(512);
174                self.interrupted_message = state;
175                return Ok(None);
176            };
177
178            let data_len = usize::try_from(header.data_len)?;
179            let frame_len = header_len + data_len;
180            if frame_len > src.remaining() {
181                // The buffer contains the frame header but it's not big enough for the data. Reserve additional
182                // space for the frame data, plus the next frame header.
183                // Note that we guard against bad data that indicates an unreasonable frame length.
184
185                // If we reserved buffer space for the entire frame data in a single call, would the buffer exceed
186                // usize::MAX bytes in size?
187                // On a 64-bit platform we should not reach here as the usize::try_from line above enforces the
188                // max payload length detailed in the RFC of 2^63 bytes.
189                if frame_len > usize::MAX - src.remaining() {
190                    return Err(format!("frame is too long: {0} bytes ({0:x})", frame_len).into());
191                }
192
193                // We don't really reserve space for the entire frame data in a single call. If somebody is sending
194                // more than a gigabyte of data in a single frame then we'll still try to receive it, we'll just
195                // reserve in 1GB chunks.
196                src.reserve(frame_len.min(0x4000_0000) + 512);
197
198                self.interrupted_message = state;
199                return Ok(None);
200            }
201
202            // The buffer contains the frame header and all of the data. We can parse it and return Ok(Some(...)).
203            let mut data = src.split_to(frame_len);
204            data.advance(header_len);
205
206            let FrameHeader {
207                fin,
208                rsv,
209                opcode,
210                mask,
211                data_len: _data_len,
212            } = header;
213
214            if rsv != 0 {
215                return Err(format!("reserved bits are not supported: 0x{:x}", rsv).into());
216            }
217
218            if let Some(mask) = mask {
219                // Note: clients never need decode masked messages because masking is only used for client -> server frames.
220                // However this code is used to test round tripping of masked messages.
221                mask::mask_slice(&mut data, mask)
222            };
223
224            let opcode = if opcode == 0 {
225                None
226            } else {
227                let opcode = Opcode::try_from(opcode).ok_or_else(|| format!("opcode {} is not supported", opcode))?;
228                if opcode.is_control() && data_len >= 126 {
229                    return Err(format!(
230                        "control frames must be shorter than 126 bytes ({} bytes is too long)",
231                        data_len
232                    )
233                    .into());
234                }
235
236                Some(opcode)
237            };
238
239            state = if let Some((partial_opcode, mut partial_data)) = state {
240                if let Some(opcode) = opcode {
241                    if fin && opcode.is_control() {
242                        self.interrupted_message = Some((partial_opcode, partial_data));
243                        break (opcode, data);
244                    }
245
246                    return Err(format!("continuation frame must have continuation opcode, not {:?}", opcode).into());
247                } else {
248                    partial_data.extend_from_slice(&data);
249
250                    if fin {
251                        break (partial_opcode, partial_data);
252                    }
253
254                    Some((partial_opcode, partial_data))
255                }
256            } else if let Some(opcode) = opcode {
257                if fin {
258                    break (opcode, data);
259                }
260                if opcode.is_control() {
261                    return Err("control frames must not be fragmented".into());
262                }
263                Some((opcode, data))
264            } else {
265                return Err("continuation must not be first frame".into());
266            }
267        };
268
269        Ok(Some(Message::new(opcode, data.freeze())?))
270    }
271}
272
273impl Encoder<Message> for MessageCodec {
274    type Error = Error;
275
276    fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<()> {
277        self.encode(&item, dst)
278    }
279}
280
281impl<'a> Encoder<&'a Message> for MessageCodec {
282    type Error = Error;
283
284    fn encode(&mut self, item: &Message, dst: &mut BytesMut) -> Result<()> {
285        let mask = if self.use_mask { Some(Mask::new()) } else { None };
286        let header = item.header(mask);
287        header.write_to_bytes(dst);
288
289        if let Some(mask) = mask {
290            let offset = dst.len();
291            dst.reserve(item.data.len());
292
293            unsafe {
294                dst.set_len(offset + item.data.len());
295            }
296
297            mask::mask_slice_copy(&mut dst[offset..], &item.data, mask);
298        } else {
299            dst.put_slice(&item.data);
300        }
301
302        Ok(())
303    }
304}
305
// Unit tests for `Message` and `MessageCodec`: round trips through encode/decode, allocation counts,
// incremental decoding and rejection of frames with unrepresentable lengths.
#[cfg(test)]
mod tests {
    use assert_allocations::assert_allocated_bytes;
    use bytes::{BufMut, BytesMut};
    use tokio_util::codec::{Decoder, Encoder};

    use crate::frame::{FrameHeader, FrameHeaderCodec};
    use crate::mask::{self, Mask};
    use crate::message::{Message, MessageCodec};

    // Encodes a text or binary message with a client codec, decodes it again and expects the same message.
    // Also pins the number of bytes allocated by each step.
    #[quickcheck]
    fn round_trips(is_text: bool, data: String) {
        let data_len = data.len();

        // Building the message from an owned `String` / `Vec<u8>` is expected not to allocate.
        let message = assert_allocated_bytes(0, || {
            if is_text {
                Message::text(data)
            } else {
                Message::binary(data.into_bytes())
            }
        });

        // thread_rng performs a one-off memory allocation the first time it is used on a given thread.
        // We make that allocation here, instead of inside the assert_allocated_bytes block below.
        rand::thread_rng();

        // The header is built with a mask only to learn the length of a masked header.
        let header = message.header(Some(Mask::from(0)));
        let frame_len = header.header_len() + data_len;
        let mut bytes = BytesMut::new();
        assert_allocated_bytes(frame_len.max(8), {
            || {
                MessageCodec::client()
                    .encode(&message, &mut bytes)
                    .expect("didn't expect MessageCodec::encode to return an error")
            }
        });

        // We eagerly promote the BytesMut to KIND_ARC. This ensures we make a call to Box::new here,
        // instead of inside the assert_allocated_bytes(0) block below.
        let mut src = bytes.split();

        let message2 = assert_allocated_bytes(0, || {
            MessageCodec::client()
                .decode(&mut src)
                .expect("didn't expect MessageCodec::decode to return an error")
                .expect("expected buffer to contain the full frame")
        });

        assert_eq!(message, message2);
    }

    // Builds the frame by hand (header via `FrameHeaderCodec`, payload masked when a mask is given) and
    // checks that `MessageCodec` decodes it to the expected kind of message and payload.
    #[quickcheck]
    fn round_trips_via_frame_header(is_text: bool, mask: Option<u32>, data: String) {
        let header = assert_allocated_bytes(0, || {
            FrameHeader {
                fin: true, // TODO test messages split across frames
                rsv: 0,
                opcode: if is_text { 1 } else { 2 },
                mask: mask.map(|n| n.into()),
                data_len: data.len().into(),
            }
        });

        // The buffer is sized up front so that writing the frame below needs no further allocation.
        let mut bytes = BytesMut::with_capacity(header.header_len() + data.len());
        assert_allocated_bytes(0, || {
            FrameHeaderCodec.encode(&header, &mut bytes).unwrap();

            if let Some(mask) = header.mask {
                let offset = bytes.len();
                bytes.resize(offset + data.len(), 0);
                mask::mask_slice_copy(&mut bytes[offset..], data.as_bytes(), mask);
            } else {
                bytes.put(data.as_bytes());
            }
        });

        // We eagerly promote the BytesMut to KIND_ARC. This ensures we make a call to Box::new here,
        // instead of inside the assert_allocated_bytes(0) block below.
        let mut src = bytes.split();

        assert_allocated_bytes(0, || {
            let message2 = MessageCodec::client()
                .decode(&mut src)
                .expect("didn't expect MessageCodec::decode to return an error")
                .expect("expected buffer to contain the full frame");

            assert_eq!(is_text, message2.as_text().is_some());
            assert_eq!(data.as_bytes(), message2.data());
        });
    }

    // Feeds an encoded message to the decoder piece by piece, where each piece is limited by
    // `decoder_buf.remaining_mut()`, and expects the original message once enough bytes have arrived.
    #[quickcheck]
    fn reserves_buffer(is_text: bool, data: String) {
        let message = if is_text {
            Message::text(data)
        } else {
            Message::binary(data.into_bytes())
        };

        let mut bytes = BytesMut::new();
        MessageCodec::client()
            .encode(&message, &mut bytes)
            .expect("didn't expect MessageCodec::encode to return an error");

        // We don't check allocations around the MessageCodec::decode call below. We're deliberately
        // supplying a minimal number of source bytes each time, so we expect lots of small
        // allocations as decoder_buf is resized multiple times.

        let mut src = &bytes[..];
        let mut decoder = MessageCodec::client();
        let mut decoder_buf = BytesMut::new();
        let message2 = loop {
            if let Some(result) = decoder
                .decode(&mut decoder_buf)
                .expect("didn't expect MessageCodec::decode to return an error")
            {
                assert_eq!(0, decoder_buf.len(), "expected decoder to consume the whole buffer");
                break result;
            }

            // The decoder asked for more input: hand over the next piece of the encoded frame.
            let n = decoder_buf.remaining_mut().min(src.len());
            assert!(n > 0, "expected decoder to reserve at least one byte");
            decoder_buf.put_slice(&src[..n]);
            src = &src[n..];
        };

        assert_eq!(message, message2);
    }

    // A header announcing a payload of 0xffffffffffffffff bytes must produce an error, not a panic.
    #[test]
    fn frame_bigger_than_2_64_does_not_panic() {
        // A frame with data longer than 2^64 bytes is bigger than the entire address space,
        // when the header is included.
        let data: &[u8] = &[0, 127, 255, 255, 255, 255, 255, 255, 255, 255];
        let mut data = BytesMut::from(data);
        data.resize(4096, 0);

        let message = MessageCodec::client()
            .decode(&mut data)
            .expect_err("expected decoder to return an error given a frame bigger than 2^64 bytes");

        assert_eq!(
            message.to_string(),
            "frame is too long: 18446744073709551615 bytes (ffffffffffffffff)"
        );
    }

    // A header announcing a payload of 0xffffffff000000ff bytes must produce an error, not a panic.
    #[test]
    fn frame_bigger_than_2_40_does_not_panic() {
        // A frame longer than 2^40 bytes causes Vec::extend to trigger an error in
        // the AddressSanitizer.
        let data: &[u8] = &[0, 255, 255, 255, 255, 255, 0, 0, 0, 255, 0, 0, 0, 0];
        let mut data = BytesMut::from(data);
        data.resize(4096, 0);

        let message = MessageCodec::client()
            .decode(&mut data)
            .expect_err("expected decoder to return an error given a frame bigger than 2^40 bytes");

        assert_eq!(
            message.to_string(),
            "frame is too long: 18446744069414584575 bytes (ffffffff000000ff)"
        );
    }

    // Two messages encoded back to back into one buffer must decode separately and in order.
    #[test]
    fn roundtrips_multiple_messages() {
        // According to https://docs.rs/tokio-util/0.7.3/tokio_util/codec/index.html#the-encoder-trait
        // the buffer given to the Encoder may already contain data.
        // Therefore, we check whether writing two messages into the same buffer roundtrips correctly.
        let mut buf = BytesMut::new();
        let mut codec = MessageCodec::server();
        codec.encode(Message::text("A"), &mut buf).unwrap();
        codec.encode(Message::text("B"), &mut buf).unwrap();
        assert_eq!(codec.decode(&mut buf).unwrap().unwrap(), Message::text("A"));
        assert_eq!(codec.decode(&mut buf).unwrap().unwrap(), Message::text("B"));
    }
}