// simple_stream/frame/websocket.rs

// Copyright 2015 Nathan Sizemore <nathanrsizemore@gmail.com>
//
// This Source Code Form is subject to the terms of the
// Mozilla Public License, v. 2.0. If a copy of the MPL was not
// distributed with this file, You can obtain one at
// http://mozilla.org/MPL/2.0/.
7
//! The `frame::websocket` module provides [RFC-6455][rfc-6455] support for websocket-based
//! streams. This module provides no support for the handshake part of the protocol, or any
//! smarts about handling fragmented messages. It simply encodes/decodes complete websocket
//! frames.
//!
//! [rfc-6455]: https://tools.ietf.org/html/rfc6455
14
15use std::{fmt, mem};
16
17use super::{Frame, FrameBuilder};
18
19bitflags! {
20    #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
21    struct OpCode: u8 {
22        const CONTINUATION  = 0b0000_0000;
23        const TEXT          = 0b0000_0001;
24        const BINARY        = 0b0000_0010;
25        const CLOSE         = 0b0000_1000;
26        const PING          = 0b0000_1001;
27        const PONG          = 0b0000_1010;
28    }
29}
30
/// Broad classification of a websocket frame.
///
/// When decoding, `WebSocketFrameBuilder` tags close, ping and pong frames as
/// `Control`, and continuation, text and binary frames as `Data`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FrameType {
    /// Close, ping or pong.
    Control,
    /// Continuation, text or binary.
    Data,
}
36
/// Public view of a frame's opcode (RFC 6455, section 5.2).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OpType {
    /// Continuation of a fragmented message (opcode 0x0).
    Continuation,
    /// Text data (opcode 0x1).
    Text,
    /// Binary data (opcode 0x2).
    Binary,
    /// Connection close (opcode 0x8).
    Close,
    /// Ping (opcode 0x9).
    Ping,
    /// Pong (opcode 0xA).
    Pong,
}
46
47#[derive(Clone)]
48struct Header {
49    op_code: OpCode,
50    mask: bool,
51    payload_len: u64,
52    masking_key: [u8; 4],
53}
54
/// Owned payload bytes of a frame.
#[derive(Clone)]
struct Payload {
    /// Payload bytes as they appear on the wire; a masked frame keeps its masked form.
    data: Vec<u8>,
}
59
60#[derive(Clone)]
61pub struct WebSocketFrame {
62    frame_type: FrameType,
63    header: Header,
64    payload: Payload,
65}
66
/// Stateless decoder that produces `WebSocketFrame`s through `FrameBuilder::from_bytes`.
#[derive(Clone)]
pub struct WebSocketFrameBuilder;
69
70impl FrameBuilder for WebSocketFrameBuilder {
71    fn from_bytes(buf: &mut Vec<u8>) -> Option<Box<dyn Frame>> {
72        if buf.len() < 5 {
73            return None;
74        }
75
76        let mut frame: WebSocketFrame = Default::default();
77
78        // OpCode and FrameType
79        const FIN_CLEAR_MASK: u8 = 0b0000_1111;
80        let op_byte = buf[0] & FIN_CLEAR_MASK;
81        match OpCode::from_bits(op_byte) {
82            Some(op_code) => {
83                if op_code == OpCode::CONTINUATION {
84                    frame.frame_type = FrameType::Data;
85                } else if op_code == OpCode::TEXT {
86                    frame.frame_type = FrameType::Data;
87                } else if op_code == OpCode::BINARY {
88                    frame.frame_type = FrameType::Data;
89                } else if op_code == OpCode::CLOSE {
90                    frame.frame_type = FrameType::Control;
91                } else if op_code == OpCode::PING {
92                    frame.frame_type = FrameType::Control;
93                } else if op_code == OpCode::PONG {
94                    frame.frame_type = FrameType::Control;
95                } else {
96                    unreachable!();
97                }
98
99                frame.header.op_code = op_code;
100            }
101            None => {
102                error!("Invalid OpCode bits: {:#b}", buf[0]);
103                return None;
104            }
105        }
106
107        trace!("{}", frame.op_type());
108
109        // Payload masked (If from client, must always be true)
110        let mask_bit = 0b1000_0000 & buf[1];
111        frame.header.mask = mask_bit > 0;
112
113        trace!("Frame masked: {}", frame.header.mask);
114
115        // Payload data length
116        let payload_len = 0b0111_1111 & buf[1];
117        let mut next_offset: usize = 2;
118        if payload_len <= 125 {
119            frame.header.payload_len = payload_len as u64;
120        } else if payload_len == 126 {
121            let mut len = (buf[2] as u16) << 8;
122            len |= buf[3] as u16;
123            frame.header.payload_len = len as u64;
124            next_offset = 4;
125        } else {
126            // We don't want to cause a panic
127            if buf.len() < 10 {
128                return None;
129            }
130
131            let mut len = (buf[2] as u64) << 56;
132            len |= (buf[3] as u64) << 48;
133            len |= (buf[4] as u64) << 40;
134            len |= (buf[5] as u64) << 32;
135            len |= (buf[6] as u64) << 24;
136            len |= (buf[7] as u64) << 16;
137            len |= (buf[8] as u64) << 8;
138            len |= buf[9] as u64;
139            frame.header.payload_len = len;
140            next_offset = 10;
141        }
142
143        trace!("Payload length: {}", frame.header.payload_len);
144
145        // Optional masking key
146        if frame.header.mask {
147            if buf.len() <= next_offset + 4 {
148                return None;
149            }
150            frame.header.masking_key[0] = buf[next_offset];
151            frame.header.masking_key[1] = buf[next_offset + 1];
152            frame.header.masking_key[2] = buf[next_offset + 2];
153            frame.header.masking_key[3] = buf[next_offset + 3];
154            next_offset += 4;
155        }
156
157        if buf.len() < next_offset + frame.header.payload_len as usize {
158            return None;
159        }
160
161        // Payload data
162        let len = frame.header.payload_len as usize;
163        frame
164            .payload
165            .data
166            .extend_from_slice(&buf[next_offset..(len + next_offset)]);
167
168        // Remove from buffer
169        let mut remainder = Vec::<u8>::with_capacity(buf.len() - frame.len_as_vec());
170        remainder.extend_from_slice(&buf[frame.len_as_vec()..buf.len()]);
171        mem::swap(buf, &mut remainder);
172
173        return Some(Box::new(frame));
174    }
175}
176
177impl WebSocketFrame {
178    pub fn new(buf: &[u8], frame_type: FrameType, op_type: OpType) -> WebSocketFrame {
179        WebSocketFrame {
180            frame_type,
181            header: Header {
182                op_code: match op_type {
183                    OpType::Continuation => OpCode::CONTINUATION,
184                    OpType::Text => OpCode::TEXT,
185                    OpType::Binary => OpCode::BINARY,
186                    OpType::Close => OpCode::CLOSE,
187                    OpType::Ping => OpCode::PING,
188                    OpType::Pong => OpCode::PONG,
189                },
190                mask: false,
191                payload_len: buf.len() as u64,
192                masking_key: [0u8; 4],
193            },
194            payload: Payload { data: buf.to_vec() },
195        }
196    }
197
198    pub fn op_type(&self) -> OpType {
199        match self.header.op_code {
200            OpCode::CONTINUATION => OpType::Continuation,
201            OpCode::TEXT => OpType::Text,
202            OpCode::BINARY => OpType::Binary,
203            OpCode::CLOSE => OpType::Close,
204            OpCode::PING => OpType::Ping,
205            OpCode::PONG => OpType::Pong,
206            _ => unreachable!(),
207        }
208    }
209
210    pub fn frame_type(&self) -> FrameType {
211        self.frame_type.clone()
212    }
213
214    pub fn is_masked(&self) -> bool {
215        self.header.mask
216    }
217
218    pub fn payload_unmasked(&self) -> Vec<u8> {
219        let len = self.payload.data.len();
220        let mut buf = Vec::<u8>::with_capacity(len);
221        for x in 0..len {
222            buf.push(self.payload.data[x] ^ self.header.masking_key[x % 4]);
223        }
224
225        buf
226    }
227}
228
229impl Frame for WebSocketFrame {
230    fn payload(&self) -> Vec<u8> {
231        if self.header.mask {
232            self.payload_unmasked()
233        } else {
234            self.payload.data.clone()
235        }
236    }
237
238    fn to_bytes(&self) -> Vec<u8> {
239        let mut buf = Vec::<u8>::with_capacity(self.len_as_vec());
240
241        // OpCode
242        const FIN: u8 = 0b1000_0000;
243        let op_code_with_fin = FIN | self.header.op_code.bits();
244        buf.push(op_code_with_fin);
245
246        // Mask and Payload len
247        let mask_bit: u8 = if self.header.mask {
248            0b1000_0000
249        } else {
250            0b0000_0000
251        };
252        let next_7_bits: u8 = if self.header.payload_len <= 125 {
253            self.header.payload_len as u8
254        } else if self.header.payload_len <= u16::MAX as u64 {
255            126u8
256        } else {
257            127u8
258        };
259        let next_byte: u8 = mask_bit | next_7_bits;
260        buf.push(next_byte);
261
262        // Optional payload len
263        if next_byte == 126 {
264            buf.push(((self.header.payload_len as u16) >> 8) as u8);
265            buf.push(self.header.payload_len as u8);
266        } else if next_byte == 127 {
267            buf.push((self.header.payload_len >> 56) as u8);
268            buf.push((self.header.payload_len >> 48) as u8);
269            buf.push((self.header.payload_len >> 40) as u8);
270            buf.push((self.header.payload_len >> 32) as u8);
271            buf.push((self.header.payload_len >> 24) as u8);
272            buf.push((self.header.payload_len >> 16) as u8);
273            buf.push((self.header.payload_len >> 8) as u8);
274            buf.push(self.header.payload_len as u8);
275        }
276
277        // Optional masking key
278        if self.header.mask {
279            buf.push(self.header.masking_key[0]);
280            buf.push(self.header.masking_key[1]);
281            buf.push(self.header.masking_key[2]);
282            buf.push(self.header.masking_key[3]);
283        }
284
285        // Payload data
286        buf.extend_from_slice(&self.payload.data[..]);
287
288        buf
289    }
290
291    fn len_as_vec(&self) -> usize {
292        let mut len = 0usize;
293
294        // OpCode
295        len += 1;
296
297        // Mask and paylaod length
298        len += 1;
299
300        // Extended Payload length
301        if self.header.payload_len > 125 && self.header.payload_len < u16::MAX as u64 {
302            len += 2;
303        } else if self.header.payload_len > u16::MAX as u64 {
304            len += 8;
305        }
306
307        // Optional masking key
308        if self.header.mask {
309            len += 4;
310        }
311
312        // Payload data
313        len += self.header.payload_len as usize;
314
315        len
316    }
317
318    fn as_mut_raw_erased(&self) -> *mut () {
319        let dup = Box::new(self.clone());
320        return Box::into_raw(dup) as *mut _ as *mut ();
321    }
322}
323
324impl Default for WebSocketFrame {
325    fn default() -> WebSocketFrame {
326        WebSocketFrame {
327            frame_type: FrameType::Control,
328            header: Header {
329                op_code: OpCode::CONTINUATION,
330                mask: false,
331                payload_len: 0u64,
332                masking_key: [0u8; 4],
333            },
334            payload: Payload {
335                data: Vec::<u8>::new(),
336            },
337        }
338    }
339}
340
341impl fmt::Display for FrameType {
342    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
343        match *self {
344            FrameType::Control => write!(f, "FrameType::Control"),
345            FrameType::Data => write!(f, "FrameType::Data"),
346        }
347    }
348}
349
350impl fmt::Display for OpType {
351    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
352        match *self {
353            OpType::Continuation => write!(f, "OpType::Continuation"),
354            OpType::Text => write!(f, "OpType::Text"),
355            OpType::Binary => write!(f, "OpType::Binary"),
356            OpType::Close => write!(f, "OpType::Close"),
357            OpType::Ping => write!(f, "OpType::Ping"),
358            OpType::Pong => write!(f, "OpType::Pong"),
359        }
360    }
361}