1use std::convert::TryFrom;
2use std::result;
3use std::str::{self, Utf8Error};
4use std::usize;
5
6use bytes::{Buf, BufMut, Bytes, BytesMut};
7use tokio_util::codec::{Decoder, Encoder};
8
9use crate::frame::FrameHeader;
10use crate::mask::{self, Mask};
11use crate::opcode::Opcode;
12use crate::{Error, Result};
13
/// A single WebSocket message: an opcode together with its payload bytes.
#[derive(Clone, Debug, PartialEq)]
pub struct Message {
    /// The kind of message (text, binary, close, ping or pong).
    opcode: Opcode,
    /// The payload. [`Message::new`] validates it as UTF-8 when the opcode is a text opcode.
    data: Bytes,
}
20
21impl Message {
22 pub fn new<B: Into<Bytes>>(opcode: Opcode, data: B) -> result::Result<Self, Utf8Error> {
27 let data = data.into();
28
29 if opcode.is_text() {
30 str::from_utf8(&data)?;
31 }
32
33 Ok(Message { opcode, data })
34 }
35
36 pub fn text<S: Into<String>>(data: S) -> Self {
38 Message {
39 opcode: Opcode::Text,
40 data: data.into().into(),
41 }
42 }
43
44 pub fn binary<B: Into<Bytes>>(data: B) -> Self {
46 Message {
47 opcode: Opcode::Binary,
48 data: data.into(),
49 }
50 }
51
52 pub(crate) fn header(&self, mask: Option<Mask>) -> FrameHeader {
53 FrameHeader {
54 fin: true,
55 rsv: 0,
56 opcode: self.opcode.into(),
57 mask,
58 data_len: self.data.len().into(),
59 }
60 }
61
62 pub fn close(reason: Option<(u16, String)>) -> Self {
67 let data = if let Some((code, reason)) = reason {
68 let reason: Bytes = reason.into();
69 let mut buf = BytesMut::new();
70 buf.reserve(2 + reason.len());
71 buf.put_u16(code);
72 buf.put(reason);
73 buf.freeze()
74 } else {
75 Bytes::new()
76 };
77
78 Message {
79 opcode: Opcode::Close,
80 data,
81 }
82 }
83
84 pub fn ping<B: Into<Bytes>>(data: B) -> Self {
88 Message {
89 opcode: Opcode::Ping,
90 data: data.into(),
91 }
92 }
93
94 pub fn pong<B: Into<Bytes>>(data: B) -> Self {
98 Message {
99 opcode: Opcode::Pong,
100 data: data.into(),
101 }
102 }
103
104 pub fn opcode(&self) -> Opcode {
106 self.opcode
107 }
108
109 pub fn data(&self) -> &Bytes {
111 &self.data
112 }
113
114 pub fn into_data(self) -> Bytes {
116 self.data
117 }
118
119 pub fn as_text(&self) -> Option<&str> {
122 if self.opcode.is_text() {
123 Some(unsafe { str::from_utf8_unchecked(&self.data) })
124 } else {
125 None
126 }
127 }
128}
129
/// Codec that decodes frames from a byte buffer into [`Message`]s and encodes [`Message`]s as
/// single frames.
#[derive(Clone)]
pub struct MessageCodec {
    /// Opcode and data accumulated so far for a fragmented message whose final frame has not
    /// been decoded yet; carried over between `decode` calls.
    interrupted_message: Option<(Opcode, BytesMut)>,
    /// Whether `encode` generates a mask and applies it to the payload of outgoing frames.
    use_mask: bool,
}
136
137impl MessageCodec {
138 pub fn client() -> Self {
142 Self::with_masked_encode(true)
143 }
144
145 pub fn server() -> Self {
149 Self::with_masked_encode(false)
150 }
151
152 pub fn with_masked_encode(use_mask: bool) -> Self {
154 Self {
155 use_mask,
156 interrupted_message: None,
157 }
158 }
159}
160
161impl Decoder for MessageCodec {
162 type Item = Message;
163 type Error = Error;
164
165 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Message>> {
166 let mut state = self.interrupted_message.take();
167 let (opcode, data) = loop {
168 let (header, header_len) = if let Some(tuple) = FrameHeader::parse_slice(src) {
169 tuple
170 } else {
171 src.reserve(512);
174 self.interrupted_message = state;
175 return Ok(None);
176 };
177
178 let data_len = usize::try_from(header.data_len)?;
179 let frame_len = header_len + data_len;
180 if frame_len > src.remaining() {
181 if frame_len > usize::MAX - src.remaining() {
190 return Err(format!("frame is too long: {0} bytes ({0:x})", frame_len).into());
191 }
192
193 src.reserve(frame_len.min(0x4000_0000) + 512);
197
198 self.interrupted_message = state;
199 return Ok(None);
200 }
201
202 let mut data = src.split_to(frame_len);
204 data.advance(header_len);
205
206 let FrameHeader {
207 fin,
208 rsv,
209 opcode,
210 mask,
211 data_len: _data_len,
212 } = header;
213
214 if rsv != 0 {
215 return Err(format!("reserved bits are not supported: 0x{:x}", rsv).into());
216 }
217
218 if let Some(mask) = mask {
219 mask::mask_slice(&mut data, mask)
222 };
223
224 let opcode = if opcode == 0 {
225 None
226 } else {
227 let opcode = Opcode::try_from(opcode).ok_or_else(|| format!("opcode {} is not supported", opcode))?;
228 if opcode.is_control() && data_len >= 126 {
229 return Err(format!(
230 "control frames must be shorter than 126 bytes ({} bytes is too long)",
231 data_len
232 )
233 .into());
234 }
235
236 Some(opcode)
237 };
238
239 state = if let Some((partial_opcode, mut partial_data)) = state {
240 if let Some(opcode) = opcode {
241 if fin && opcode.is_control() {
242 self.interrupted_message = Some((partial_opcode, partial_data));
243 break (opcode, data);
244 }
245
246 return Err(format!("continuation frame must have continuation opcode, not {:?}", opcode).into());
247 } else {
248 partial_data.extend_from_slice(&data);
249
250 if fin {
251 break (partial_opcode, partial_data);
252 }
253
254 Some((partial_opcode, partial_data))
255 }
256 } else if let Some(opcode) = opcode {
257 if fin {
258 break (opcode, data);
259 }
260 if opcode.is_control() {
261 return Err("control frames must not be fragmented".into());
262 }
263 Some((opcode, data))
264 } else {
265 return Err("continuation must not be first frame".into());
266 }
267 };
268
269 Ok(Some(Message::new(opcode, data.freeze())?))
270 }
271}
272
273impl Encoder<Message> for MessageCodec {
274 type Error = Error;
275
276 fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<()> {
277 self.encode(&item, dst)
278 }
279}
280
281impl<'a> Encoder<&'a Message> for MessageCodec {
282 type Error = Error;
283
284 fn encode(&mut self, item: &Message, dst: &mut BytesMut) -> Result<()> {
285 let mask = if self.use_mask { Some(Mask::new()) } else { None };
286 let header = item.header(mask);
287 header.write_to_bytes(dst);
288
289 if let Some(mask) = mask {
290 let offset = dst.len();
291 dst.reserve(item.data.len());
292
293 unsafe {
294 dst.set_len(offset + item.data.len());
295 }
296
297 mask::mask_slice_copy(&mut dst[offset..], &item.data, mask);
298 } else {
299 dst.put_slice(&item.data);
300 }
301
302 Ok(())
303 }
304}
305
// Tests for `Message` and `MessageCodec`: encode/decode round trips, buffer reservation,
// oversized frame lengths and back-to-back messages.
#[cfg(test)]
mod tests {
    use assert_allocations::assert_allocated_bytes;
    use bytes::{BufMut, BytesMut};
    use tokio_util::codec::{Decoder, Encoder};

    use crate::frame::{FrameHeader, FrameHeaderCodec};
    use crate::mask::{self, Mask};
    use crate::message::{Message, MessageCodec};

    // A message encoded by a client codec decodes back to an equal message.
    // NOTE(review): `assert_allocated_bytes(n, f)` presumably asserts that `f` allocates exactly
    // `n` bytes; confirm against the `assert_allocations` helper.
    #[quickcheck]
    fn round_trips(is_text: bool, data: String) {
        let data_len = data.len();

        // Building the message is expected to reuse the `String`'s buffer rather than allocate.
        let message = assert_allocated_bytes(0, || {
            if is_text {
                Message::text(data)
            } else {
                Message::binary(data.into_bytes())
            }
        });

        // NOTE(review): presumably primes the thread-local RNG so that any one-off allocation it
        // makes is not counted by the allocation assertion below; confirm that `Mask::new()`
        // draws from it.
        rand::thread_rng();

        // The header is built only to compute the expected encoded frame length.
        let header = message.header(Some(Mask::from(0)));
        let frame_len = header.header_len() + data_len;
        let mut bytes = BytesMut::new();
        assert_allocated_bytes(frame_len.max(8), {
            || {
                MessageCodec::client()
                    .encode(&message, &mut bytes)
                    .expect("didn't expect MessageCodec::encode to return an error")
            }
        });

        // Move the encoded bytes into their own buffer for the decoder to consume.
        let mut src = bytes.split();

        let message2 = assert_allocated_bytes(0, || {
            MessageCodec::client()
                .decode(&mut src)
                .expect("didn't expect MessageCodec::decode to return an error")
                .expect("expected buffer to contain the full frame")
        });

        assert_eq!(message, message2);
    }

    // A frame assembled by hand from a `FrameHeader` (masked or not) decodes to a message with
    // the expected kind and payload.
    #[quickcheck]
    fn round_trips_via_frame_header(is_text: bool, mask: Option<u32>, data: String) {
        let header = assert_allocated_bytes(0, || {
            FrameHeader {
                fin: true, rsv: 0,
                // Raw opcode 1 is used for text and 2 for everything else.
                opcode: if is_text { 1 } else { 2 },
                mask: mask.map(|n| n.into()),
                data_len: data.len().into(),
            }
        });

        // The buffer is sized up front so that writing the frame needs no further allocation.
        let mut bytes = BytesMut::with_capacity(header.header_len() + data.len());
        assert_allocated_bytes(0, || {
            FrameHeaderCodec.encode(&header, &mut bytes).unwrap();

            if let Some(mask) = header.mask {
                let offset = bytes.len();
                bytes.resize(offset + data.len(), 0);
                mask::mask_slice_copy(&mut bytes[offset..], data.as_bytes(), mask);
            } else {
                bytes.put(data.as_bytes());
            }
        });

        let mut src = bytes.split();

        assert_allocated_bytes(0, || {
            let message2 = MessageCodec::client()
                .decode(&mut src)
                .expect("didn't expect MessageCodec::decode to return an error")
                .expect("expected buffer to contain the full frame");

            assert_eq!(is_text, message2.as_text().is_some());
            assert_eq!(data.as_bytes(), message2.data());
        });
    }

    // Feeds the decoder only as many bytes as it has reserved room for, checking that every
    // `Ok(None)` leaves spare capacity in the buffer and that the message eventually decodes.
    #[quickcheck]
    fn reserves_buffer(is_text: bool, data: String) {
        let message = if is_text {
            Message::text(data)
        } else {
            Message::binary(data.into_bytes())
        };

        let mut bytes = BytesMut::new();
        MessageCodec::client()
            .encode(&message, &mut bytes)
            .expect("didn't expect MessageCodec::encode to return an error");

        // `src` is the encoded input that has not yet been handed to the decoder.
        let mut src = &bytes[..];
        let mut decoder = MessageCodec::client();
        let mut decoder_buf = BytesMut::new();
        let message2 = loop {
            if let Some(result) = decoder
                .decode(&mut decoder_buf)
                .expect("didn't expect MessageCodec::decode to return an error")
            {
                assert_eq!(0, decoder_buf.len(), "expected decoder to consume the whole buffer");
                break result;
            }

            // Copy at most as many bytes as the decoder buffer can currently accept.
            let n = decoder_buf.remaining_mut().min(src.len());
            assert!(n > 0, "expected decoder to reserve at least one byte");
            decoder_buf.put_slice(&src[..n]);
            src = &src[n..];
        };

        assert_eq!(message, message2);
    }

    // A header announcing the largest 64-bit payload length yields an error, not a panic.
    #[test]
    fn frame_bigger_than_2_64_does_not_panic() {
        // Second byte 127 selects the 8-byte extended length, which is all 0xff here.
        let data: &[u8] = &[0, 127, 255, 255, 255, 255, 255, 255, 255, 255];
        let mut data = BytesMut::from(data);
        data.resize(4096, 0);

        let message = MessageCodec::client()
            .decode(&mut data)
            .expect_err("expected decoder to return an error given a frame bigger than 2^64 bytes");

        assert_eq!(
            message.to_string(),
            "frame is too long: 18446744073709551615 bytes (ffffffffffffffff)"
        );
    }

    // A masked header announcing a payload length of 0xffffffff000000ff yields an error, not a
    // panic.
    #[test]
    fn frame_bigger_than_2_40_does_not_panic() {
        // Second byte 255 sets the mask bit and selects the 8-byte extended length; the last
        // four bytes are the mask.
        let data: &[u8] = &[0, 255, 255, 255, 255, 255, 0, 0, 0, 255, 0, 0, 0, 0];
        let mut data = BytesMut::from(data);
        data.resize(4096, 0);

        let message = MessageCodec::client()
            .decode(&mut data)
            .expect_err("expected decoder to return an error given a frame bigger than 2^40 bytes");

        assert_eq!(
            message.to_string(),
            "frame is too long: 18446744069414584575 bytes (ffffffff000000ff)"
        );
    }

    // Two messages encoded back to back into one buffer decode one at a time, in order.
    #[test]
    fn roundtrips_multiple_messages() {
        let mut buf = BytesMut::new();
        let mut codec = MessageCodec::server();
        codec.encode(Message::text("A"), &mut buf).unwrap();
        codec.encode(Message::text("B"), &mut buf).unwrap();
        assert_eq!(codec.decode(&mut buf).unwrap().unwrap(), Message::text("A"));
        assert_eq!(codec.decode(&mut buf).unwrap().unwrap(), Message::text("B"));
    }
}