1use std::default::Default;
2use std::fmt;
3use std::io::{Cursor, ErrorKind, Read, Write};
4
5use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
6use rand;
7
8use capped_buffer::CappedBuffer;
9use protocol::{CloseCode, OpCode};
10use result::{Error, Kind, Result};
11use stream::TryReadBuf;
12
/// XORs every byte of `buf` in place with the 4-byte `mask`, repeating the
/// key every four bytes.
///
/// XOR is its own inverse, so calling this twice with the same key restores
/// the original bytes.
fn apply_mask(buf: &mut [u8], mask: &[u8; 4]) {
    for (index, byte) in buf.iter_mut().enumerate() {
        // The key byte for position `index` is `mask[index mod 4]`.
        *byte ^= mask[index % mask.len()];
    }
}
19
20#[derive(Debug, Clone)]
22pub struct Frame {
23 finished: bool,
24 rsv1: bool,
25 rsv2: bool,
26 rsv3: bool,
27 opcode: OpCode,
28
29 mask: Option<[u8; 4]>,
30
31 payload: Vec<u8>,
32}
33
34impl Frame {
35 #[inline]
38 pub fn len(&self) -> usize {
39 let mut header_length = 2;
40 let payload_len = self.payload().len();
41 if payload_len > 125 {
42 if payload_len <= u16::max_value() as usize {
43 header_length += 2;
44 } else {
45 header_length += 8;
46 }
47 }
48
49 if self.is_masked() {
50 header_length += 4;
51 }
52
53 header_length + payload_len
54 }
55
56 #[inline]
58 pub fn is_empty(&self) -> bool {
59 false
60 }
61
62 #[inline]
64 pub fn is_final(&self) -> bool {
65 self.finished
66 }
67
68 #[inline]
70 pub fn has_rsv1(&self) -> bool {
71 self.rsv1
72 }
73
74 #[inline]
76 pub fn has_rsv2(&self) -> bool {
77 self.rsv2
78 }
79
80 #[inline]
82 pub fn has_rsv3(&self) -> bool {
83 self.rsv3
84 }
85
86 #[inline]
88 pub fn opcode(&self) -> OpCode {
89 self.opcode
90 }
91
92 #[inline]
94 pub fn is_control(&self) -> bool {
95 self.opcode.is_control()
96 }
97
98 #[inline]
100 pub fn payload(&self) -> &Vec<u8> {
101 &self.payload
102 }
103
104 #[doc(hidden)]
106 #[inline]
107 pub fn is_masked(&self) -> bool {
108 self.mask.is_some()
109 }
110
111 #[doc(hidden)]
113 #[allow(dead_code)]
114 #[inline]
115 pub fn mask(&self) -> Option<&[u8; 4]> {
116 self.mask.as_ref()
117 }
118
119 #[allow(dead_code)]
121 #[inline]
122 pub fn set_final(&mut self, is_final: bool) -> &mut Frame {
123 self.finished = is_final;
124 self
125 }
126
127 #[inline]
129 pub fn set_rsv1(&mut self, has_rsv1: bool) -> &mut Frame {
130 self.rsv1 = has_rsv1;
131 self
132 }
133
134 #[inline]
136 pub fn set_rsv2(&mut self, has_rsv2: bool) -> &mut Frame {
137 self.rsv2 = has_rsv2;
138 self
139 }
140
141 #[inline]
143 pub fn set_rsv3(&mut self, has_rsv3: bool) -> &mut Frame {
144 self.rsv3 = has_rsv3;
145 self
146 }
147
148 #[allow(dead_code)]
150 #[inline]
151 pub fn set_opcode(&mut self, opcode: OpCode) -> &mut Frame {
152 self.opcode = opcode;
153 self
154 }
155
156 #[allow(dead_code)]
158 #[inline]
159 pub fn payload_mut(&mut self) -> &mut Vec<u8> {
160 &mut self.payload
161 }
162
163 #[doc(hidden)]
169 #[inline]
170 pub fn set_mask(&mut self) -> &mut Frame {
171 self.mask = Some(rand::random());
172 self
173 }
174
175 #[doc(hidden)]
178 #[inline]
179 pub fn remove_mask(&mut self) -> &mut Frame {
180 self.mask
181 .take()
182 .map(|mask| apply_mask(&mut self.payload, &mask));
183 self
184 }
185
186 pub fn into_data(self) -> Vec<u8> {
188 self.payload
189 }
190
191 #[inline]
193 pub fn message(data: Vec<u8>, code: OpCode, finished: bool) -> Frame {
194 debug_assert!(
195 match code {
196 OpCode::Text | OpCode::Binary | OpCode::Continue => true,
197 _ => false,
198 },
199 "Invalid opcode for data frame."
200 );
201
202 Frame {
203 finished,
204 opcode: code,
205 payload: data,
206 ..Frame::default()
207 }
208 }
209
210 #[inline]
212 pub fn pong(data: Vec<u8>) -> Frame {
213 Frame {
214 opcode: OpCode::Pong,
215 payload: data,
216 ..Frame::default()
217 }
218 }
219
220 #[inline]
222 pub fn ping(data: Vec<u8>) -> Frame {
223 Frame {
224 opcode: OpCode::Ping,
225 payload: data,
226 ..Frame::default()
227 }
228 }
229
230 #[inline]
232 pub fn close(code: CloseCode, reason: &str) -> Frame {
233 let payload = if let CloseCode::Empty = code {
234 Vec::new()
235 } else {
236 let u: u16 = code.into();
237 let raw = [(u >> 8) as u8, u as u8];
238 [&raw, reason.as_bytes()].concat()
239 };
240
241 Frame {
242 payload,
243 ..Frame::default()
244 }
245 }
246
247 pub fn parse(cursor: &mut Cursor<CappedBuffer>, max_payload_length: u64) -> Result<Option<Frame>> {
249 let size = cursor.get_ref().len() as u64 - cursor.position();
250 let initial = cursor.position();
251 trace!("Position in buffer {}", initial);
252
253 let mut head = [0u8; 2];
254 if cursor.read(&mut head)? != 2 {
255 cursor.set_position(initial);
256 return Ok(None);
257 }
258
259 trace!("Parsed headers {:?}", head);
260
261 let first = head[0];
262 let second = head[1];
263 trace!("First: {:b}", first);
264 trace!("Second: {:b}", second);
265
266 let finished = first & 0x80 != 0;
267
268 let rsv1 = first & 0x40 != 0;
269 let rsv2 = first & 0x20 != 0;
270 let rsv3 = first & 0x10 != 0;
271
272 let opcode = OpCode::from(first & 0x0F);
273 trace!("Opcode: {:?}", opcode);
274
275 let masked = second & 0x80 != 0;
276 trace!("Masked: {:?}", masked);
277
278 let mut header_length = 2;
279
280 let mut length = u64::from(second & 0x7F);
281
282 if let Some(length_nbytes) = match length {
283 126 => Some(2),
284 127 => Some(8),
285 _ => None,
286 } {
287 match cursor.read_uint::<BigEndian>(length_nbytes) {
288 Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => {
289 cursor.set_position(initial);
290 return Ok(None);
291 }
292 Err(err) => {
293 return Err(Error::from(err));
294 }
295 Ok(read) => {
296 length = read;
297 }
298 };
299 header_length += length_nbytes as u64;
300 }
301 trace!("Payload length: {}", length);
302
303 if length > max_payload_length {
304 return Err(Error::new(
305 Kind::Protocol,
306 format!(
307 "Rejected frame with payload length exceeding defined max: {}.",
308 max_payload_length
309 ),
310 ));
311 }
312
313 let mask = if masked {
314 let mut mask_bytes = [0u8; 4];
315 if cursor.read(&mut mask_bytes)? != 4 {
316 cursor.set_position(initial);
317 return Ok(None);
318 } else {
319 header_length += 4;
320 Some(mask_bytes)
321 }
322 } else {
323 None
324 };
325
326 match length.checked_add(header_length) {
327 Some(l) if size < l => {
328 cursor.set_position(initial);
329 return Ok(None);
330 }
331 Some(_) => (),
332 None => return Ok(None),
333 };
334
335 let mut data = Vec::with_capacity(length as usize);
336 if length > 0 {
337 if let Some(read) = cursor.try_read_buf(&mut data)? {
338 debug_assert!(read == length as usize, "Read incorrect payload length!");
339 }
340 }
341
342 if let OpCode::Bad = opcode {
344 return Err(Error::new(
345 Kind::Protocol,
346 format!("Encountered invalid opcode: {}", first & 0x0F),
347 ));
348 }
349
350 match opcode {
352 OpCode::Ping | OpCode::Pong if length > 125 => {
353 return Err(Error::new(
354 Kind::Protocol,
355 format!(
356 "Rejected WebSocket handshake.Received control frame with length: {}.",
357 length
358 ),
359 ))
360 }
361 OpCode::Close if length > 125 => {
362 debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame.");
363 return Ok(Some(Frame::close(
364 CloseCode::Protocol,
365 "Received close frame with payload length exceeding 125.",
366 )));
367 }
368 _ => (),
369 }
370
371 let frame = Frame {
372 finished,
373 rsv1,
374 rsv2,
375 rsv3,
376 opcode,
377 mask,
378 payload: data,
379 };
380
381 Ok(Some(frame))
382 }
383
384 pub fn format<W>(&mut self, w: &mut W) -> Result<()>
386 where
387 W: Write,
388 {
389 let mut one = 0u8;
390 let code: u8 = self.opcode.into();
391 if self.is_final() {
392 one |= 0x80;
393 }
394 if self.has_rsv1() {
395 one |= 0x40;
396 }
397 if self.has_rsv2() {
398 one |= 0x20;
399 }
400 if self.has_rsv3() {
401 one |= 0x10;
402 }
403 one |= code;
404
405 let mut two = 0u8;
406 if self.is_masked() {
407 two |= 0x80;
408 }
409
410 match self.payload.len() {
411 len if len < 126 => {
412 two |= len as u8;
413 }
414 len if len <= 65535 => {
415 two |= 126;
416 }
417 _ => {
418 two |= 127;
419 }
420 }
421 w.write_all(&[one, two])?;
422
423 if let Some(length_bytes) = match self.payload.len() {
424 len if len < 126 => None,
425 len if len <= 65535 => Some(2),
426 _ => Some(8),
427 } {
428 w.write_uint::<BigEndian>(self.payload.len() as u64, length_bytes)?;
429 }
430
431 if self.is_masked() {
432 let mask = self.mask.take().unwrap();
433 apply_mask(&mut self.payload, &mask);
434 w.write_all(&mask)?;
435 }
436
437 w.write_all(&self.payload)?;
438 Ok(())
439 }
440}
441
442impl Default for Frame {
443 fn default() -> Frame {
444 Frame {
445 finished: true,
446 rsv1: false,
447 rsv2: false,
448 rsv3: false,
449 opcode: OpCode::Close,
450 mask: None,
451 payload: Vec::new(),
452 }
453 }
454}
455
456impl fmt::Display for Frame {
457 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
458 write!(
459 f,
460 "
461<FRAME>
462final: {}
463reserved: {} {} {}
464opcode: {}
465length: {}
466payload length: {}
467payload: 0x{}
468 ",
469 self.finished,
470 self.rsv1,
471 self.rsv2,
472 self.rsv3,
473 self.opcode,
474 self.len(),
476 self.payload.len(),
477 self.payload
478 .iter()
479 .map(|byte| format!("{:x}", byte))
480 .collect::<String>()
481 )
482 }
483}
484
#[cfg(test)]
mod test {
    #![allow(unused_imports, unused_variables, dead_code)]
    use super::*;
    use protocol::OpCode;

    /// The `Display` output of a text frame must contain the payload rendered
    /// as hexadecimal.
    #[test]
    fn display_frame() {
        let f = Frame::message("hi there".into(), OpCode::Text, true);
        let view = format!("{}", f);
        // "hi there" is 68 69 20 74 68 65 72 65 in hex.
        assert!(view.contains("payload: 0x6869207468657265"));
    }
}
496}