websocket_simple/
frame.rs

1use std::fmt;
2use std::mem::transmute;
3use std::io::{Cursor, Read, Write};
4use std::default::Default;
5use std::iter::FromIterator;
6
7use rand;
8
9use result::{Result, Error, Kind};
10use protocol::{OpCode, CloseCode};
11
/// XOR every byte of `buf` in place with the 4-byte `mask`, repeating the key every four
/// bytes. Applying the same mask twice restores the original bytes.
fn apply_mask(buf: &mut [u8], mask: &[u8; 4]) {
    for (index, byte) in buf.iter_mut().enumerate() {
        *byte ^= mask[index % 4];
    }
}
18
19#[inline]
20fn generate_mask() -> [u8; 4] {
21    unsafe { transmute(rand::random::<u32>()) }
22}
23
24/// A struct representing a WebSocket frame.
25#[derive(Debug, Clone)]
26pub struct Frame {
27    finished: bool,
28    rsv1: bool,
29    rsv2: bool,
30    rsv3: bool,
31    opcode: OpCode,
32
33    mask: Option<[u8; 4]>,
34
35    payload: Vec<u8>,
36}
37
38impl Frame {
39
40    /// Get the length of the frame.
41    /// This is the length of the header + the length of the payload.
42    #[inline]
43    pub fn len(&self) -> usize {
44        let mut header_length = 2;
45        let payload_len = self.payload().len();
46        if payload_len > 125 {
47            if payload_len <= u16::max_value() as usize {
48                header_length += 2;
49            } else {
50                header_length += 8;
51            }
52        }
53
54        if self.is_masked() {
55            header_length += 4;
56        }
57
58        header_length + payload_len
59    }
60
61    /// Test whether the frame is a final frame.
62    #[inline]
63    pub fn is_final(&self) -> bool {
64        self.finished
65    }
66
67    /// Test whether the first reserved bit is set.
68    #[inline]
69    pub fn has_rsv1(&self) -> bool {
70        self.rsv1
71    }
72
73    /// Test whether the second reserved bit is set.
74    #[inline]
75    pub fn has_rsv2(&self) -> bool {
76        self.rsv2
77    }
78
79    /// Test whether the third reserved bit is set.
80    #[inline]
81    pub fn has_rsv3(&self) -> bool {
82        self.rsv3
83    }
84
85    /// Get the OpCode of the frame.
86    #[inline]
87    pub fn opcode(&self) -> OpCode {
88        self.opcode
89    }
90
91    /// Test whether this is a control frame.
92    #[inline]
93    pub fn is_control(&self) -> bool {
94        self.opcode.is_control()
95    }
96
97    /// Get a reference to the frame's payload.
98    #[inline]
99    pub fn payload(&self) -> &Vec<u8> {
100        &self.payload
101    }
102
103    // Test whether the frame is masked.
104    #[doc(hidden)]
105    #[inline]
106    pub fn is_masked(&self) -> bool {
107        self.mask.is_some()
108    }
109
110    // Get an optional reference to the frame's mask.
111    #[doc(hidden)]
112    #[allow(dead_code)]
113    #[inline]
114    pub fn mask(&self) -> Option<&[u8; 4]> {
115        self.mask.as_ref()
116    }
117
118    /// Make this frame a final frame.
119    #[allow(dead_code)]
120    #[inline]
121    pub fn set_final(&mut self, is_final: bool) -> &mut Frame {
122        self.finished = is_final;
123        self
124    }
125
126    /// Set the first reserved bit.
127    #[inline]
128    pub fn set_rsv1(&mut self, has_rsv1: bool) -> &mut Frame {
129        self.rsv1 = has_rsv1;
130        self
131    }
132
133    /// Set the second reserved bit.
134    #[inline]
135    pub fn set_rsv2(&mut self, has_rsv2: bool) -> &mut Frame {
136        self.rsv2 = has_rsv2;
137        self
138    }
139
140    /// Set the third reserved bit.
141    #[inline]
142    pub fn set_rsv3(&mut self, has_rsv3: bool) -> &mut Frame {
143        self.rsv3 = has_rsv3;
144        self
145    }
146
147    /// Set the OpCode.
148    #[allow(dead_code)]
149    #[inline]
150    pub fn set_opcode(&mut self, opcode: OpCode) -> &mut Frame {
151        self.opcode = opcode;
152        self
153    }
154
155    /// Edit the frame's payload.
156    #[allow(dead_code)]
157    #[inline]
158    pub fn payload_mut(&mut self) -> &mut Vec<u8> {
159        &mut self.payload
160    }
161
162    // Generate a new mask for this frame.
163    //
164    // This method simply generates and stores the mask. It does not change the payload data.
165    // Instead, the payload data will be masked with the generated mask when the frame is sent
166    // to the other endpoint.
167    #[doc(hidden)]
168    #[inline]
169    pub fn set_mask(&mut self) -> &mut Frame {
170        self.mask = Some(generate_mask());
171        self
172    }
173
174    // This method unmasks the payload and should only be called on frames that are actually
175    // masked. In other words, those frames that have just been received from a client endpoint.
176    #[doc(hidden)]
177    #[inline]
178    pub fn remove_mask(&mut self) -> &mut Frame {
179        self.mask.and_then(|mask| {
180            Some(apply_mask(&mut self.payload, &mask))
181        });
182        self.mask = None;
183        self
184    }
185
186    /// Consume the frame into its payload.
187    pub fn into_data(self) -> Vec<u8> {
188        self.payload
189    }
190
191    /// Create a new data frame.
192    #[inline]
193    pub fn message(data: Vec<u8>, code: OpCode, finished: bool) -> Frame {
194        debug_assert!(match code {
195            OpCode::Text | OpCode::Binary | OpCode::Continue => true,
196            _ => false,
197        }, "Invalid opcode for data frame.");
198
199        Frame {
200            finished: finished,
201            opcode: code,
202            payload: data,
203            .. Frame::default()
204        }
205    }
206
207    /// Create a new Pong control frame.
208    #[inline]
209    pub fn pong(data: Vec<u8>) -> Frame {
210        Frame {
211            opcode: OpCode::Pong,
212            payload: data,
213            .. Frame::default()
214        }
215    }
216
217    /// Create a new Ping control frame.
218    #[inline]
219    pub fn ping(data: Vec<u8>) -> Frame {
220        Frame {
221            opcode: OpCode::Ping,
222            payload: data,
223            .. Frame::default()
224        }
225    }
226
227    /// Create a new Close control frame.
228    #[inline]
229    pub fn close(code: CloseCode, reason: &str) -> Frame {
230        let raw: [u8; 2] = unsafe {
231            let u: u16 = code.into();
232            transmute(u.to_be())
233        };
234
235        let payload = if let CloseCode::Empty = code {
236            Vec::new()
237        } else {
238            Vec::from_iter(
239                raw[..].iter()
240                       .chain(reason.as_bytes().iter())
241                       .map(|&b| b))
242        };
243
244        Frame {
245            payload: payload,
246            .. Frame::default()
247        }
248    }
249
250    /// Parse the input stream into a frame.
251    pub fn parse(cursor: &mut Cursor<Vec<u8>>) -> Result<Option<Frame>> {
252        let size = cursor.get_ref().len() as u64 - cursor.position();
253        let initial = cursor.position();
254        trace!("Position in buffer {}", initial);
255
256        let mut head = [0u8; 2];
257        if try!(cursor.read(&mut head)) != 2 {
258            cursor.set_position(initial);
259            return Ok(None)
260        }
261
262        trace!("Parsed headers {:?}", head);
263
264        let first = head[0];
265        let second = head[1];
266        trace!("First: {:b}", first);
267        trace!("Second: {:b}", second);
268
269        let finished = first & 0x80 != 0;
270
271        let rsv1 = first & 0x40 != 0;
272        let rsv2 = first & 0x20 != 0;
273        let rsv3 = first & 0x10 != 0;
274
275        let opcode = OpCode::from(first & 0x0F);
276        trace!("Opcode: {:?}", opcode);
277
278        let masked = second & 0x80 != 0;
279        trace!("Masked: {:?}", masked);
280
281        let mut header_length = 2;
282
283        let mut length = (second & 0x7F) as u64;
284
285        if length == 126 {
286            let mut length_bytes = [0u8; 2];
287            if try!(cursor.read(&mut length_bytes)) != 2 {
288                cursor.set_position(initial);
289                return Ok(None)
290            }
291
292            length = unsafe {
293                let mut wide: u16 = transmute(length_bytes);
294                wide = u16::from_be(wide);
295                wide
296            } as u64;
297            header_length += 2;
298        } else if length == 127 {
299            let mut length_bytes = [0u8; 8];
300            if try!(cursor.read(&mut length_bytes)) != 8 {
301                cursor.set_position(initial);
302                return Ok(None)
303            }
304
305            unsafe { length = transmute(length_bytes); }
306            length = u64::from_be(length);
307            header_length += 8;
308        }
309        trace!("Payload length: {}", length);
310
311        let mask = if masked {
312            let mut mask_bytes = [0u8; 4];
313            if try!(cursor.read(&mut mask_bytes)) != 4 {
314                cursor.set_position(initial);
315                return Ok(None)
316            } else {
317                header_length += 4;
318                Some(mask_bytes)
319            }
320        } else {
321            None
322        };
323
324        if size < length + header_length {
325            cursor.set_position(initial);
326            return Ok(None)
327        }
328
329        let mut data = vec![0; length as usize];
330        try!(cursor.read(&mut data[..]));
331
332        // Disallow bad opcode
333        if let OpCode::Bad = opcode {
334            return Err(Error::new(Kind::Protocol, format!("Encountered invalid opcode: {}", first & 0x0F)))
335        }
336
337        // control frames must have length <= 125
338        match opcode {
339            OpCode::Ping | OpCode::Pong if length > 125 => {
340                return Err(Error::new(Kind::Protocol, format!("Rejected WebSocket handshake.Received control frame with length: {}.", length)))
341            }
342            OpCode::Close if length > 125 => {
343                debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame.");
344                return Ok(Some(Frame::close(CloseCode::Protocol, "Received close frame with payload length exceeding 125.")))
345            }
346            _ => ()
347        }
348
349        let frame = Frame {
350            finished: finished,
351            rsv1: rsv1,
352            rsv2: rsv2,
353            rsv3: rsv3,
354            opcode: opcode,
355            mask: mask,
356            payload: data,
357        };
358
359
360        Ok(Some(frame))
361    }
362
363    /// Write a frame out to a buffer
364    pub fn format<W>(&mut self, w: &mut W) -> Result<()>
365        where W: Write
366    {
367        let mut one = 0u8;
368        let code: u8 = self.opcode.into();
369        if self.is_final() {
370            one |= 0x80;
371        }
372        if self.has_rsv1() {
373            one |= 0x40;
374        }
375        if self.has_rsv2() {
376            one |= 0x20;
377        }
378        if self.has_rsv3() {
379            one |= 0x10;
380        }
381        one |= code;
382
383        let mut two = 0u8;
384
385        if self.is_masked() {
386            two |= 0x80;
387        }
388
389        if self.payload.len() < 126 {
390            two |= self.payload.len() as u8;
391            let headers = [one, two];
392            try!(w.write(&headers));
393        } else if self.payload.len() <= 65535 {
394            two |= 126;
395            let length_bytes: [u8; 2] = unsafe {
396                let short = self.payload.len() as u16;
397                transmute(short.to_be())
398            };
399            let headers = [one, two, length_bytes[0], length_bytes[1]];
400            try!(w.write(&headers));
401        } else {
402            two |= 127;
403            let length_bytes: [u8; 8] = unsafe {
404                let long = self.payload.len() as u64;
405                transmute(long.to_be())
406            };
407            let headers = [
408                one,
409                two,
410                length_bytes[0],
411                length_bytes[1],
412                length_bytes[2],
413                length_bytes[3],
414                length_bytes[4],
415                length_bytes[5],
416                length_bytes[6],
417                length_bytes[7],
418            ];
419            try!(w.write(&headers));
420        }
421
422        if self.is_masked() {
423            let mask = self.mask.take().unwrap();
424            apply_mask(&mut self.payload, &mask);
425            try!(w.write(&mask));
426        }
427
428        try!(w.write(&self.payload));
429        Ok(())
430    }
431}
432
433impl Default for Frame {
434    fn default() -> Frame {
435        Frame {
436            finished: true,
437            rsv1: false,
438            rsv2: false,
439            rsv3: false,
440            opcode: OpCode::Close,
441            mask: None,
442            payload: Vec::new(),
443        }
444    }
445}
446
447impl fmt::Display for Frame {
448    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
449        write!(f,
450            "
451<FRAME>
452final: {}
453reserved: {} {} {}
454opcode: {}
455length: {}
456payload length: {}
457payload: 0x{}
458            ",
459            self.finished,
460            self.rsv1,
461            self.rsv2,
462            self.rsv3,
463            self.opcode,
464            // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
465            self.len(),
466            self.payload.len(),
467            self.payload.iter().map(|byte| format!("{:x}", byte)).collect::<String>())
468    }
469}
470
#[cfg(test)]
mod test {
    #![allow(unused_imports, unused_variables, dead_code)]
    use super::*;
    use std::io::Cursor;
    use protocol::OpCode;

    /// Serialize `frame` with `format` and parse the bytes back with `parse`.
    ///
    /// Also checks that `len()` predicted the serialized size and that the parser
    /// consumed exactly the bytes that were written.
    fn roundtrip(mut frame: Frame) -> Frame {
        let expected_len = frame.len();
        let mut buf = Vec::new();
        assert!(frame.format(&mut buf).is_ok());
        assert_eq!(buf.len(), expected_len);
        let written = buf.len();
        let mut cursor = Cursor::new(buf);
        let parsed = match Frame::parse(&mut cursor) {
            Ok(Some(parsed)) => parsed,
            _ => panic!("expected a complete frame"),
        };
        assert_eq!(cursor.position() as usize, written);
        parsed
    }

    #[test]
    fn display_frame() {
        let f = Frame::message("hi there".into(), OpCode::Text, true);
        let view = format!("{}", f);
        // Assert on the result: a bare `contains` call would let the test pass vacuously.
        assert!(view.contains("payload:"));
    }

    #[test]
    fn roundtrip_each_length_encoding() {
        // Empty and largest 7-bit, smallest and largest 16-bit, and smallest 64-bit lengths.
        for &len in &[0usize, 125, 126, 65535, 65536] {
            let data = vec![0xABu8; len];
            let parsed = roundtrip(Frame::message(data.clone(), OpCode::Binary, true));
            assert!(parsed.is_final());
            assert!(!parsed.is_masked());
            assert!(parsed.payload() == &data);
        }
    }

    #[test]
    fn roundtrip_masked_frame() {
        let data: Vec<u8> = (0..200u32).map(|i| i as u8).collect();
        let mut frame = Frame::message(data.clone(), OpCode::Binary, false);
        frame.set_mask();
        let mut parsed = roundtrip(frame);
        assert!(!parsed.is_final());
        assert!(parsed.is_masked());
        parsed.remove_mask();
        assert!(!parsed.is_masked());
        assert!(parsed.payload() == &data);
    }

    #[test]
    fn incomplete_frame_rewinds_cursor() {
        let mut frame = Frame::message(vec![1, 2, 3], OpCode::Binary, true);
        let mut buf = Vec::new();
        assert!(frame.format(&mut buf).is_ok());
        buf.pop();
        let mut cursor = Cursor::new(buf);
        match Frame::parse(&mut cursor) {
            Ok(None) => (),
            _ => panic!("a truncated frame must not parse"),
        }
        assert_eq!(cursor.position(), 0);
    }
}