tetsy_ws/
frame.rs

1use std::default::Default;
2use std::fmt;
3use std::io::{Cursor, ErrorKind, Read, Write};
4
5use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
6use rand;
7
8use capped_buffer::CappedBuffer;
9use protocol::{CloseCode, OpCode};
10use result::{Error, Kind, Result};
11use stream::TryReadBuf;
12
/// XOR every byte of `buf` in place with the 4-byte masking key, repeating the key
/// every four bytes.
///
/// Because XOR is its own inverse, calling this twice with the same key restores the
/// original bytes, so the same routine both masks and unmasks payload data.
fn apply_mask(buf: &mut [u8], mask: &[u8; 4]) {
    for (index, byte) in buf.iter_mut().enumerate() {
        *byte ^= mask[index % 4];
    }
}
19
20/// A struct representing a WebSocket frame.
21#[derive(Debug, Clone)]
22pub struct Frame {
23    finished: bool,
24    rsv1: bool,
25    rsv2: bool,
26    rsv3: bool,
27    opcode: OpCode,
28
29    mask: Option<[u8; 4]>,
30
31    payload: Vec<u8>,
32}
33
34impl Frame {
35    /// Get the length of the frame.
36    /// This is the length of the header + the length of the payload.
37    #[inline]
38    pub fn len(&self) -> usize {
39        let mut header_length = 2;
40        let payload_len = self.payload().len();
41        if payload_len > 125 {
42            if payload_len <= u16::max_value() as usize {
43                header_length += 2;
44            } else {
45                header_length += 8;
46            }
47        }
48
49        if self.is_masked() {
50            header_length += 4;
51        }
52
53        header_length + payload_len
54    }
55
56    /// Return `false`: a frame is never empty since it has a header.
57    #[inline]
58    pub fn is_empty(&self) -> bool {
59        false
60    }
61
62    /// Test whether the frame is a final frame.
63    #[inline]
64    pub fn is_final(&self) -> bool {
65        self.finished
66    }
67
68    /// Test whether the first reserved bit is set.
69    #[inline]
70    pub fn has_rsv1(&self) -> bool {
71        self.rsv1
72    }
73
74    /// Test whether the second reserved bit is set.
75    #[inline]
76    pub fn has_rsv2(&self) -> bool {
77        self.rsv2
78    }
79
80    /// Test whether the third reserved bit is set.
81    #[inline]
82    pub fn has_rsv3(&self) -> bool {
83        self.rsv3
84    }
85
86    /// Get the OpCode of the frame.
87    #[inline]
88    pub fn opcode(&self) -> OpCode {
89        self.opcode
90    }
91
92    /// Test whether this is a control frame.
93    #[inline]
94    pub fn is_control(&self) -> bool {
95        self.opcode.is_control()
96    }
97
98    /// Get a reference to the frame's payload.
99    #[inline]
100    pub fn payload(&self) -> &Vec<u8> {
101        &self.payload
102    }
103
104    // Test whether the frame is masked.
105    #[doc(hidden)]
106    #[inline]
107    pub fn is_masked(&self) -> bool {
108        self.mask.is_some()
109    }
110
111    // Get an optional reference to the frame's mask.
112    #[doc(hidden)]
113    #[allow(dead_code)]
114    #[inline]
115    pub fn mask(&self) -> Option<&[u8; 4]> {
116        self.mask.as_ref()
117    }
118
119    /// Make this frame a final frame.
120    #[allow(dead_code)]
121    #[inline]
122    pub fn set_final(&mut self, is_final: bool) -> &mut Frame {
123        self.finished = is_final;
124        self
125    }
126
127    /// Set the first reserved bit.
128    #[inline]
129    pub fn set_rsv1(&mut self, has_rsv1: bool) -> &mut Frame {
130        self.rsv1 = has_rsv1;
131        self
132    }
133
134    /// Set the second reserved bit.
135    #[inline]
136    pub fn set_rsv2(&mut self, has_rsv2: bool) -> &mut Frame {
137        self.rsv2 = has_rsv2;
138        self
139    }
140
141    /// Set the third reserved bit.
142    #[inline]
143    pub fn set_rsv3(&mut self, has_rsv3: bool) -> &mut Frame {
144        self.rsv3 = has_rsv3;
145        self
146    }
147
148    /// Set the OpCode.
149    #[allow(dead_code)]
150    #[inline]
151    pub fn set_opcode(&mut self, opcode: OpCode) -> &mut Frame {
152        self.opcode = opcode;
153        self
154    }
155
156    /// Edit the frame's payload.
157    #[allow(dead_code)]
158    #[inline]
159    pub fn payload_mut(&mut self) -> &mut Vec<u8> {
160        &mut self.payload
161    }
162
163    // Generate a new mask for this frame.
164    //
165    // This method simply generates and stores the mask. It does not change the payload data.
166    // Instead, the payload data will be masked with the generated mask when the frame is sent
167    // to the other endpoint.
168    #[doc(hidden)]
169    #[inline]
170    pub fn set_mask(&mut self) -> &mut Frame {
171        self.mask = Some(rand::random());
172        self
173    }
174
175    // This method unmasks the payload and should only be called on frames that are actually
176    // masked. In other words, those frames that have just been received from a client endpoint.
177    #[doc(hidden)]
178    #[inline]
179    pub fn remove_mask(&mut self) -> &mut Frame {
180        self.mask
181            .take()
182            .map(|mask| apply_mask(&mut self.payload, &mask));
183        self
184    }
185
186    /// Consume the frame into its payload.
187    pub fn into_data(self) -> Vec<u8> {
188        self.payload
189    }
190
191    /// Create a new data frame.
192    #[inline]
193    pub fn message(data: Vec<u8>, code: OpCode, finished: bool) -> Frame {
194        debug_assert!(
195            match code {
196                OpCode::Text | OpCode::Binary | OpCode::Continue => true,
197                _ => false,
198            },
199            "Invalid opcode for data frame."
200        );
201
202        Frame {
203            finished,
204            opcode: code,
205            payload: data,
206            ..Frame::default()
207        }
208    }
209
210    /// Create a new Pong control frame.
211    #[inline]
212    pub fn pong(data: Vec<u8>) -> Frame {
213        Frame {
214            opcode: OpCode::Pong,
215            payload: data,
216            ..Frame::default()
217        }
218    }
219
220    /// Create a new Ping control frame.
221    #[inline]
222    pub fn ping(data: Vec<u8>) -> Frame {
223        Frame {
224            opcode: OpCode::Ping,
225            payload: data,
226            ..Frame::default()
227        }
228    }
229
230    /// Create a new Close control frame.
231    #[inline]
232    pub fn close(code: CloseCode, reason: &str) -> Frame {
233        let payload = if let CloseCode::Empty = code {
234            Vec::new()
235        } else {
236            let u: u16 = code.into();
237            let raw = [(u >> 8) as u8, u as u8];
238            [&raw, reason.as_bytes()].concat()
239        };
240
241        Frame {
242            payload,
243            ..Frame::default()
244        }
245    }
246
247    /// Parse the input stream into a frame.
248    pub fn parse(cursor: &mut Cursor<CappedBuffer>, max_payload_length: u64) -> Result<Option<Frame>> {
249        let size = cursor.get_ref().len() as u64 - cursor.position();
250        let initial = cursor.position();
251        trace!("Position in buffer {}", initial);
252
253        let mut head = [0u8; 2];
254        if cursor.read(&mut head)? != 2 {
255            cursor.set_position(initial);
256            return Ok(None);
257        }
258
259        trace!("Parsed headers {:?}", head);
260
261        let first = head[0];
262        let second = head[1];
263        trace!("First: {:b}", first);
264        trace!("Second: {:b}", second);
265
266        let finished = first & 0x80 != 0;
267
268        let rsv1 = first & 0x40 != 0;
269        let rsv2 = first & 0x20 != 0;
270        let rsv3 = first & 0x10 != 0;
271
272        let opcode = OpCode::from(first & 0x0F);
273        trace!("Opcode: {:?}", opcode);
274
275        let masked = second & 0x80 != 0;
276        trace!("Masked: {:?}", masked);
277
278        let mut header_length = 2;
279
280        let mut length = u64::from(second & 0x7F);
281
282        if let Some(length_nbytes) = match length {
283            126 => Some(2),
284            127 => Some(8),
285            _ => None,
286        } {
287            match cursor.read_uint::<BigEndian>(length_nbytes) {
288                Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => {
289                    cursor.set_position(initial);
290                    return Ok(None);
291                }
292                Err(err) => {
293                    return Err(Error::from(err));
294                }
295                Ok(read) => {
296                    length = read;
297                }
298            };
299            header_length += length_nbytes as u64;
300        }
301        trace!("Payload length: {}", length);
302
303        if length > max_payload_length {
304            return Err(Error::new(
305                Kind::Protocol,
306                format!(
307                    "Rejected frame with payload length exceeding defined max: {}.",
308                    max_payload_length
309                ),
310            ));
311        }
312
313        let mask = if masked {
314            let mut mask_bytes = [0u8; 4];
315            if cursor.read(&mut mask_bytes)? != 4 {
316                cursor.set_position(initial);
317                return Ok(None);
318            } else {
319                header_length += 4;
320                Some(mask_bytes)
321            }
322        } else {
323            None
324        };
325
326        match length.checked_add(header_length) {
327            Some(l) if size < l => {
328                cursor.set_position(initial);
329                return Ok(None);
330            }
331            Some(_) => (),
332            None => return Ok(None),
333        };
334
335        let mut data = Vec::with_capacity(length as usize);
336        if length > 0 {
337            if let Some(read) = cursor.try_read_buf(&mut data)? {
338                debug_assert!(read == length as usize, "Read incorrect payload length!");
339            }
340        }
341
342        // Disallow bad opcode
343        if let OpCode::Bad = opcode {
344            return Err(Error::new(
345                Kind::Protocol,
346                format!("Encountered invalid opcode: {}", first & 0x0F),
347            ));
348        }
349
350        // control frames must have length <= 125
351        match opcode {
352            OpCode::Ping | OpCode::Pong if length > 125 => {
353                return Err(Error::new(
354                    Kind::Protocol,
355                    format!(
356                        "Rejected WebSocket handshake.Received control frame with length: {}.",
357                        length
358                    ),
359                ))
360            }
361            OpCode::Close if length > 125 => {
362                debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame.");
363                return Ok(Some(Frame::close(
364                    CloseCode::Protocol,
365                    "Received close frame with payload length exceeding 125.",
366                )));
367            }
368            _ => (),
369        }
370
371        let frame = Frame {
372            finished,
373            rsv1,
374            rsv2,
375            rsv3,
376            opcode,
377            mask,
378            payload: data,
379        };
380
381        Ok(Some(frame))
382    }
383
384    /// Write a frame out to a buffer
385    pub fn format<W>(&mut self, w: &mut W) -> Result<()>
386    where
387        W: Write,
388    {
389        let mut one = 0u8;
390        let code: u8 = self.opcode.into();
391        if self.is_final() {
392            one |= 0x80;
393        }
394        if self.has_rsv1() {
395            one |= 0x40;
396        }
397        if self.has_rsv2() {
398            one |= 0x20;
399        }
400        if self.has_rsv3() {
401            one |= 0x10;
402        }
403        one |= code;
404
405        let mut two = 0u8;
406        if self.is_masked() {
407            two |= 0x80;
408        }
409
410        match self.payload.len() {
411            len if len < 126 => {
412                two |= len as u8;
413            }
414            len if len <= 65535 => {
415                two |= 126;
416            }
417            _ => {
418                two |= 127;
419            }
420        }
421        w.write_all(&[one, two])?;
422
423        if let Some(length_bytes) = match self.payload.len() {
424            len if len < 126 => None,
425            len if len <= 65535 => Some(2),
426            _ => Some(8),
427        } {
428            w.write_uint::<BigEndian>(self.payload.len() as u64, length_bytes)?;
429        }
430
431        if self.is_masked() {
432            let mask = self.mask.take().unwrap();
433            apply_mask(&mut self.payload, &mask);
434            w.write_all(&mask)?;
435        }
436
437        w.write_all(&self.payload)?;
438        Ok(())
439    }
440}
441
442impl Default for Frame {
443    fn default() -> Frame {
444        Frame {
445            finished: true,
446            rsv1: false,
447            rsv2: false,
448            rsv3: false,
449            opcode: OpCode::Close,
450            mask: None,
451            payload: Vec::new(),
452        }
453    }
454}
455
456impl fmt::Display for Frame {
457    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
458        write!(
459            f,
460            "
461<FRAME>
462final: {}
463reserved: {} {} {}
464opcode: {}
465length: {}
466payload length: {}
467payload: 0x{}
468            ",
469            self.finished,
470            self.rsv1,
471            self.rsv2,
472            self.rsv3,
473            self.opcode,
474            // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
475            self.len(),
476            self.payload.len(),
477            self.payload
478                .iter()
479                .map(|byte| format!("{:x}", byte))
480                .collect::<String>()
481        )
482    }
483}
484
/// Unit tests for `Frame` construction, sizing, serialization and display.
#[cfg(test)]
mod test {
    #![allow(unused_imports, unused_variables, dead_code)]
    use super::*;
    use protocol::OpCode;

    #[test]
    fn display_frame() {
        let f = Frame::message("hi there".into(), OpCode::Text, true);
        let view = format!("{}", f);
        // Previously the result of `contains` was discarded, so this test could never fail.
        assert!(view.contains("payload:"));
    }

    #[test]
    fn len_accounts_for_extended_length_and_mask() {
        assert_eq!(Frame::message(vec![0; 125], OpCode::Binary, true).len(), 2 + 125);
        assert_eq!(
            Frame::message(vec![0; 126], OpCode::Binary, true).len(),
            2 + 2 + 126
        );
        assert_eq!(
            Frame::message(vec![0; 65536], OpCode::Binary, true).len(),
            2 + 8 + 65536
        );

        let mut masked = Frame::message(vec![0; 10], OpCode::Binary, true);
        masked.set_mask();
        assert_eq!(masked.len(), 2 + 4 + 10);
    }

    #[test]
    fn format_writes_len_bytes() {
        let mut f = Frame::message(vec![7; 300], OpCode::Binary, true);
        f.set_mask();
        // `format` consumes the mask, so measure the expected size first.
        let expected = f.len();
        let mut out = Vec::new();
        assert!(f.format(&mut out).is_ok());
        assert_eq!(out.len(), expected);
    }
}