websocket_stream/
lib.rs

1// The MIT License (MIT)
2//
3// Copyright (c) 2015 Nathan Sizemore <nathanrsizemore@gmail.com>
4//
5// Permission is hereby granted, free of charge, to any person obtaining a copy
6// of this software and associated documentation files (the "Software"), to deal
7// in the Software without restriction, including without limitation the rights
8// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9// copies of the Software, and to permit persons to whom the Software is
10// furnished to do so, subject to the following conditions:
11//
12// The above copyright notice and this permission notice shall be included in all
13// copies or substantial portions of the Software.
14//
15// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21// SOFTWARE.
22
23
//! RFC 6455 WebSocket frame reading and writing over a Unix `TcpStream`, usable in blocking or non-blocking mode.
25
26
27extern crate libc;
28extern crate errno;
29
30use std::{mem, ptr, fmt};
31use std::result::Result;
32use std::net::TcpStream;
33use std::os::unix::io::AsRawFd;
34
35use errno::errno;
36use libc::{size_t, c_void, c_int, ssize_t};
37
38use util::*;
39pub mod util;
40
41
42extern "system" {
43    fn read(fd: c_int, buffer: *mut c_void, count: size_t) -> ssize_t;
44    fn write(fd: c_int, buffer: *const c_void, cout: size_t) -> ssize_t;
45}
46
47/// Represents the result of trying to create a WebsocketStream
48pub type NewResult = Result<WebsocketStream, SetFdError>;
49
50/// Represents the result of setting a flag to file descriptor through syscalls
51pub type SetFdResult = Result<(), SetFdError>;
52
53/// Represents the result of attempting a syscall read on the file descriptor
54pub type ReadResult = Result<(OpCode, Vec<u8>), ReadError>;
55
56/// Represents the result of attempting a syscall write on the file descriptor
57pub type WriteResult = Result<u64, WriteError>;
58
59/// Internal result of read syscall on file descriptor
60type SysReadResult = Result<(), ReadError>;
61
62/// Internal result of write syscall on file descriptor
63type SysWriteResult = Result<u64, WriteError>;
64
65/// Internal result of attempting to retrieve the OpCode
66/// https://tools.ietf.org/html/rfc6455#page-29
67type OpCodeResult = Result<OpCode, ReadError>;
68
69/// Internal result of attempting to retrieve the payload key
70/// https://tools.ietf.org/html/rfc6455#page-29
71type PayloadKeyResult = Result<u8, ReadError>;
72
73/// Internal result of attempting to read the payload length
74/// https://tools.ietf.org/html/rfc6455#page-29
75type PayloadLenResult = Result<u64, ReadError>;
76
77
78/// RFC-6455 Protocol stream
79pub struct WebsocketStream {
80    mode: Mode,
81    stream: TcpStream,
82    state: State,
83    msg: Message,
84    buffer: Buffer
85}
86
/// Scratch storage for the bytes of the frame section currently being read
/// (and reused as the output staging area).
#[derive(Clone)]
struct Buffer {
    /// How many more bytes the current section needs.
    remaining: usize,
    /// Bytes received so far for the current section.
    buf: Vec<u8>,
}
93
94/// Websocket Frame
95#[derive(Clone)]
96struct Message {
97    op_code: OpCode,
98    payload_key: u8,
99    payload_len: u64,
100    masking_key: [u8; 4],
101    payload: Vec<u8>
102}
103
/// I/O mode of the underlying socket.
///
/// A plain fieldless enum, so it is `Copy` and can be printed with `{:?}`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Mode {
    /// Blocking I/O: reads and writes wait until they can make progress.
    Block,
    /// Non-blocking I/O (`O_NONBLOCK`): reads and writes fail with `EAGAIN`
    /// instead of waiting.
    NonBlock,
}
112
/// Which section of an incoming frame the reader expects next.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum State {
    /// The first header byte, carrying the opcode.
    OpCode,
    /// The second header byte, carrying the payload-length key.
    PayloadKey,
    /// The optional 2- or 8-byte extended payload length.
    PayloadLength,
    /// The 4-byte masking key.
    MaskingKey,
    /// The masked payload bytes.
    Payload,
}
122
123impl WebsocketStream {
124
125    /// Attempts to create a new stream in specified mode
126    pub fn new(stream: TcpStream, mode: Mode) -> NewResult {
127        match mode {
128            Mode::Block => {
129                Ok(WebsocketStream {
130                    stream: stream,
131                    mode: Mode::Block,
132                    state: State::OpCode,
133                    msg: Message {
134                        op_code: OpCode::Text,
135                        payload_key: 0u8,
136                        payload_len: 0u64,
137                        masking_key: [0u8; 4],
138                        payload: Vec::new()
139                    },
140                    buffer: Buffer {
141                        remaining: 1,
142                        buf: Vec::new()
143                    }
144                })
145            }
146            Mode::NonBlock => {
147                match WebsocketStream::set_non_block(&stream) {
148                    Ok(()) => Ok(WebsocketStream {
149                        stream: stream,
150                        mode: Mode::NonBlock,
151                        state: State::OpCode,
152                        msg: Message {
153                            op_code: OpCode::Text,
154                            payload_key: 0u8,
155                            payload_len: 0u64,
156                            masking_key: [0u8; 4],
157                            payload: Vec::new()
158                        },
159                        buffer: Buffer {
160                            remaining: 1,
161                            buf: Vec::new()
162                        }
163                    }),
164                    Err(e) => Err(e)
165                }
166            }
167        }
168    }
169
170    /// Sets the socket to the specified mode
171    pub fn set_mode(&mut self, mode: Mode) -> SetFdResult {
172        // If we're already in the mode, no need to for a syscall
173        if self.mode == mode {
174            return Ok(())
175        }
176
177        match mode {
178            Mode::Block => {
179                WebsocketStream::set_block(&self.stream)
180            }
181            Mode::NonBlock => {
182                WebsocketStream::set_non_block(&self.stream)
183            }
184        }
185    }
186
187    fn set_block(stream: &TcpStream) -> SetFdResult {
188        let fd = stream.as_raw_fd();
189
190        // Get the flags currently set on the fd
191        let flags;
192        unsafe {
193            flags = libc::fcntl(fd, libc::F_GETFL);
194        }
195
196        // Ensure we were able to get the current set flags
197        if flags < 0 {
198            let errno = errno().0 as i32;
199            return match errno {
200                libc::EACCES     => Err(SetFdError::EACCES),
201                libc::EAGAIN     => Err(SetFdError::EAGAIN),
202                libc::EBADF      => Err(SetFdError::EBADF),
203                libc::EDEADLK    => Err(SetFdError::EDEADLK),
204                libc::EFAULT     => Err(SetFdError::EFAULT),
205                libc::EINTR      => Err(SetFdError::EINTR),
206                libc::EINVAL     => Err(SetFdError::EINVAL),
207                libc::EMFILE     => Err(SetFdError::EMFILE),
208                libc::ENOLCK     => Err(SetFdError::ENOLCK),
209                libc::EPERM      => Err(SetFdError::EPERM),
210                _ => panic!("Unexpected errno: {}", errno)
211            };
212        }
213
214        // Remove non-blocking flag
215        let response;
216        unsafe {
217            response = libc::fcntl(
218                fd,
219                libc::F_SETFL,
220                flags & !libc::O_NONBLOCK);
221        }
222
223        // Ensure removal was successful
224        if response < 0 {
225            let errno = errno().0 as i32;
226            return match errno {
227                libc::EACCES     => Err(SetFdError::EACCES),
228                libc::EAGAIN     => Err(SetFdError::EAGAIN),
229                libc::EBADF      => Err(SetFdError::EBADF),
230                libc::EDEADLK    => Err(SetFdError::EDEADLK),
231                libc::EFAULT     => Err(SetFdError::EFAULT),
232                libc::EINTR      => Err(SetFdError::EINTR),
233                libc::EINVAL     => Err(SetFdError::EINVAL),
234                libc::EMFILE     => Err(SetFdError::EMFILE),
235                libc::ENOLCK     => Err(SetFdError::ENOLCK),
236                libc::EPERM      => Err(SetFdError::EPERM),
237                _ => panic!("Unexpected errno: {}", errno)
238            };
239        } else {
240            Ok(())
241        }
242    }
243
244    /// Sets the stream to non-blocking mode
245    fn set_non_block(stream: &TcpStream) -> SetFdResult {
246        let fd = stream.as_raw_fd();
247        let response;
248        unsafe {
249            response = libc::fcntl(
250                fd,
251                libc::F_SETFL,
252                libc::O_NONBLOCK);
253        }
254
255        if response < 0 {
256            let errno = errno().0 as i32;
257            return match errno {
258                libc::EACCES     => Err(SetFdError::EACCES),
259                libc::EAGAIN     => Err(SetFdError::EAGAIN),
260                libc::EBADF      => Err(SetFdError::EBADF),
261                libc::EDEADLK    => Err(SetFdError::EDEADLK),
262                libc::EFAULT     => Err(SetFdError::EFAULT),
263                libc::EINTR      => Err(SetFdError::EINTR),
264                libc::EINVAL     => Err(SetFdError::EINVAL),
265                libc::EMFILE     => Err(SetFdError::EMFILE),
266                libc::ENOLCK     => Err(SetFdError::ENOLCK),
267                libc::EPERM      => Err(SetFdError::EPERM),
268                _ => panic!("Unexpected errno: {}", errno)
269            };
270        } else {
271            Ok(())
272        }
273    }
274
275    /// Attempts to read data from the socket.
276    ///
277    /// If stream is in Mode::Block, this will block forever until
278    /// data is received
279    ///
280    /// If socket is in Mode::NonBlock and data is available,
281    /// it will read until a complete message is received.  If the buffer
282    /// has run out, and it is still waiting on the remaining payload, it
283    /// will adjust the remaining needed in it's buffer and will adjust on
284    /// the next call to this function.
285    pub fn read(&mut self) -> ReadResult {
286        // Read the OpCode
287        if self.state == State::OpCode {
288            if self.buffer.remaining == 0 {
289                self.buffer.remaining = 1;
290                self.buffer.buf = Vec::<u8>::with_capacity(1);
291            }
292
293            let result = self.read_op_code();
294            if !result.is_ok() {
295                return Err(result.unwrap_err());
296            }
297            self.msg.op_code = result.unwrap();
298
299            // Set state to next stage
300            self.state = State::PayloadKey;
301            self.buffer.remaining = 1;
302            self.buffer.buf = Vec::<u8>::with_capacity(1);
303        }
304
305        // Read the Payload Key
306        if self.state == State::PayloadKey {
307            let result = self.read_payload_key();
308            if !result.is_ok() {
309                return Err(result.unwrap_err());
310            }
311            self.msg.payload_key = result.unwrap();
312
313            // Set next state
314            self.state = State::PayloadLength;
315            self.buffer.remaining = match self.msg.payload_key {
316                127 => 8,
317                126 => 2,
318                _ => {
319                    self.msg.payload_len = self.msg.payload_key as u64;
320                    0
321                }
322            };
323            self.buffer.buf = Vec::<u8>::with_capacity(self.buffer.remaining);
324        }
325
326        // Read the payload length, if needed
327        if self.state == State::PayloadLength {
328            if self.buffer.remaining > 0 {
329                let result = self.read_payload_length();
330                if !result.is_ok() {
331                    let err = result.unwrap_err();
332                    match err {
333                        ReadError::EAGAIN => {
334                            // Update bytes remaining
335                            self.buffer.remaining = (self.msg.payload_len -
336                                self.buffer.buf.len() as u64) as usize;
337                        }
338                        _ => { }
339                    }
340                    return Err(err);
341                }
342
343                // Grab result
344                self.msg.payload_len = result.unwrap();
345
346                // Update bytes remaining
347                let bytes_needed = match self.msg.payload_key {
348                    127 => 8,
349                    126 => 2,
350                    _ => 0
351                };
352                self.buffer.remaining = (bytes_needed -
353                    self.buffer.buf.len() as u64) as usize;
354            } else {
355                // If buffer.remaining == 0, len was the key
356                self.state = State::MaskingKey;
357                self.buffer.remaining = 4;
358                self.buffer.buf = Vec::<u8>::with_capacity(4);
359            }
360        }
361
362        // Read the masking key
363        if self.state == State::MaskingKey {
364            let result = self.read_masking_key();
365            if !result.is_ok() {
366                let err = result.unwrap_err();
367                match err {
368                    ReadError::EAGAIN => {
369                        // Update bytes remaining
370                        self.buffer.remaining = 4 - self.buffer.buf.len();
371                    }
372                    _ => { }
373                }
374                return Err(err);
375            }
376
377            // Copy the masking key
378            self.msg.masking_key[0] = self.buffer.buf[0];
379            self.msg.masking_key[1] = self.buffer.buf[1];
380            self.msg.masking_key[2] = self.buffer.buf[2];
381            self.msg.masking_key[3] = self.buffer.buf[3];
382
383            // Change state and update buffer
384            self.state = State::Payload;
385            self.buffer.remaining = self.msg.payload_len as usize;
386            self.buffer.buf = Vec::<u8>::with_capacity(
387                self.msg.payload_len as usize);
388        }
389
390        // Read the payload
391        if self.state == State::Payload {
392            let result = self.read_payload();
393            if !result.is_ok() {
394                let err = result.unwrap_err();
395                match err {
396                    ReadError::EAGAIN => {
397                        // Update bytes remaining
398                        self.buffer.remaining = (self.msg.payload_len -
399                            self.buffer.buf.len() as u64) as usize;
400                    }
401                    _ => { }
402                }
403                return Err(err);
404            }
405
406            // Unmask the payload
407            self.msg.payload = Vec::<u8>::with_capacity(
408                self.msg.payload_len as usize);
409            for x in 0..self.buffer.buf.len() {
410                self.msg.payload.push(
411                    self.buffer.buf[x] ^ self.msg.masking_key[x % 4]);
412            }
413
414            self.state = State::OpCode;
415            self.buffer.remaining = 1;
416            self.buffer.buf = Vec::<u8>::with_capacity(1);
417
418            // Return the OpCode and Payload
419            return Ok((self.msg.op_code.clone(), self.msg.payload.clone()))
420        }
421
422        // Default return value
423        Err(ReadError::EAGAIN)
424    }
425
426    /// Attempts to read and unmask the OpCode frame
427    fn read_op_code(&mut self) -> OpCodeResult {
428        match self.read_num_bytes(1) {
429            Ok(()) => { }
430            Err(e) => return Err(e)
431        };
432
433        // Ensure opcode is valid
434        let op_code = self.buffer.buf[0] & OP_CODE_UN_MASK;
435        let valid_op = match op_code {
436            OP_CONTINUATION => true,
437            OP_TEXT         => true,
438            OP_BINARY       => true,
439            OP_CLOSE        => true,
440            OP_PING         => true,
441            OP_PONG         => true,
442            _ => false
443        };
444        if !valid_op {
445            return Err(ReadError::OpCode);
446        }
447
448        // Assign OpCode
449        let op = match op_code {
450            OP_CONTINUATION => OpCode::Continuation,
451            OP_TEXT         => OpCode::Text,
452            OP_BINARY       => OpCode::Binary,
453            OP_CLOSE        => OpCode::Close,
454            OP_PING         => OpCode::Ping,
455            OP_PONG         => OpCode::Pong,
456            _ => unimplemented!()
457        };
458        Ok(op)
459    }
460
461    /// Attempts to read the payload key from the socket
462    fn read_payload_key(&mut self) -> PayloadKeyResult {
463        match self.read_num_bytes(1) {
464            Ok(()) => Ok(self.buffer.buf[0] & PAYLOAD_KEY_UN_MASK),
465            Err(e) => Err(e)
466        }
467    }
468
469    /// Attempts to read the payload length from the socket
470    fn read_payload_length(&mut self) -> PayloadLenResult {
471        let count = self.buffer.remaining;
472        match self.read_num_bytes(count) {
473            Ok(()) => {
474                if self.msg.payload_key == 126 {
475                    let mut len = (self.buffer.buf[0] as u16) << 8;
476                    len = len | (self.buffer.buf[1] as u16);
477                    Ok(len as u64)
478                } else {
479                    let mut len = (self.buffer.buf[0] as u64) << 56;
480                    len = len | ((self.buffer.buf[1] as u64) << 48);
481                    len = len | ((self.buffer.buf[2] as u64) << 40);
482                    len = len | ((self.buffer.buf[3] as u64) << 32);
483                    len = len | ((self.buffer.buf[4] as u64) << 24);
484                    len = len | ((self.buffer.buf[5] as u64) << 16);
485                    len = len | ((self.buffer.buf[6] as u64) << 8);
486                    len = len | (self.buffer.buf[7] as u64);
487                    Ok(len)
488                }
489            }
490            Err(e) => Err(e)
491        }
492    }
493
494    /// Attempts to read the masking key from the stream
495    fn read_masking_key(&mut self) -> SysReadResult {
496        let count = self.buffer.remaining;
497        match self.read_num_bytes(count) {
498            Ok(()) => Ok(()),
499            Err(e) => Err(e)
500        }
501    }
502
503    /// Attempts to read the payload from the stream
504    fn read_payload(&mut self) -> SysReadResult {
505        let count = self.buffer.remaining;
506        match self.read_num_bytes(count) {
507            Ok(()) => Ok(()),
508            Err(e) => Err(e)
509        }
510    }
511
512    /// Attempts to read count bytes from the stream
513    fn read_num_bytes(&mut self, count: usize) -> SysReadResult {
514        let fd = self.stream.as_raw_fd();
515
516        // Create a buffer for the total of bytes still needed
517        let buffer;
518        unsafe {
519            buffer = libc::calloc(count as size_t,
520                mem::size_of::<u8>() as size_t);
521        }
522
523        // Ensure system gave up the mem
524        if buffer.is_null() {
525            return Err(ReadError::ENOMEM)
526        }
527
528        // Read data into buffer
529        let num_read;
530        unsafe {
531            num_read = read(fd, buffer, count as size_t);
532        }
533
534        // Report and exit on any thrown errors
535        if num_read < 0 {
536            unsafe { libc::free(buffer); }
537            let errno = errno().0 as i32;
538            return match errno {
539                libc::EBADF      => Err(ReadError::EBADF),
540                libc::EFAULT     => Err(ReadError::EFAULT),
541                libc::EINTR      => Err(ReadError::EINTR),
542                libc::EINVAL     => Err(ReadError::EINVAL),
543                libc::EIO        => Err(ReadError::EIO),
544                libc::EISDIR     => Err(ReadError::EISDIR),
545                libc::EAGAIN     => Err(ReadError::EAGAIN),
546                _ => panic!("Unexpected errno during read: {}", errno)
547            };
548        }
549
550        // Check for EOF
551        if num_read == 0 {
552            unsafe { libc::free(buffer); }
553            return Err(ReadError::EAGAIN);
554        }
555
556        // Add bytes to msg buffer
557        for x in 0..num_read as isize {
558            unsafe {
559                self.buffer.buf.push(ptr::read(buffer.offset(x)) as u8);
560            }
561        }
562
563        // Free buffer and return Ok
564        unsafe { libc::free(buffer); }
565        Ok(())
566    }
567
568    /// Attempts to write data to the socket
569    pub fn write(&mut self, op: OpCode, payload: &mut Vec<u8>) -> WriteResult {
570        let mut out_buf: Vec<u8> = Vec::with_capacity(payload.len() + 9);
571
572        self.set_op_code(&op, &mut out_buf);
573        self.set_payload_info(payload.len(), &mut out_buf);
574
575        // TODO - Fix with Vec.append() once stable
576        // out_buf.append(payload);
577        for byte in payload.iter() {
578            out_buf.push(*byte);
579        }
580
581        self.write_bytes(&out_buf)
582    }
583
584    fn set_op_code(&self, op: &OpCode, buf: &mut Vec<u8>) {
585        let op_code = match *op {
586            OpCode::Continuation    => OP_CONTINUATION,
587            OpCode::Text            => OP_TEXT,
588            OpCode::Binary          => OP_BINARY,
589            OpCode::Close           => OP_CLOSE,
590            OpCode::Ping            => OP_PING,
591            OpCode::Pong            => OP_PONG
592        };
593        buf.push(op_code | OP_CODE_MASK);
594    }
595
596    fn set_payload_info(&self, len: usize, buf: &mut Vec<u8>) {
597        if len <= 125 {
598            buf.push(len as u8);
599        } else if len <= 65535 {
600            let mut len_buf = [0u8; 2];
601            len_buf[0] = ((len as u16) >> 8) as u8;
602            len_buf[1] = len as u8;
603
604            buf.push(126u8); // 16 bit prelude
605            buf.push(len_buf[0]);
606            buf.push(len_buf[1]);
607        } else {
608            let mut len_buf = [0u8; 8];
609            len_buf[0] = ((len as u64) >> 56) as u8;
610            len_buf[1] = ((len as u64) >> 48) as u8;
611            len_buf[2] = ((len as u64) >> 40) as u8;
612            len_buf[3] = ((len as u64) >> 32) as u8;
613            len_buf[4] = ((len as u64) >> 24) as u8;
614            len_buf[5] = ((len as u64) >> 16) as u8;
615            len_buf[6] = ((len as u64) >> 8) as u8;
616            len_buf[7] = len as u8;
617
618            buf.push(127u8); // 64 bit prelude
619            buf.push(len_buf[0]);
620            buf.push(len_buf[1]);
621            buf.push(len_buf[2]);
622            buf.push(len_buf[3]);
623            buf.push(len_buf[4]);
624            buf.push(len_buf[5]);
625            buf.push(len_buf[6]);
626            buf.push(len_buf[7]);
627        }
628    }
629
630    fn write_bytes(&mut self, buf: &Vec<u8>) -> SysWriteResult {
631        let buffer = &buf[..];
632        let fd = self.stream.as_raw_fd();
633        let count = buf.len() as size_t;
634
635        let num_written;
636        unsafe {
637            let buff_ptr = buffer.as_ptr();
638            let void_buff_ptr: *const c_void = mem::transmute(buff_ptr);
639            num_written = write(fd, void_buff_ptr, count);
640        }
641
642        if num_written < 0 {
643            let errno = errno().0 as i32;
644            return match errno {
645                libc::EAGAIN     => Err(WriteError::EAGAIN),
646                libc::EBADF      => Err(WriteError::EBADF),
647                libc::EFAULT     => Err(WriteError::EFAULT),
648                libc::EFBIG      => Err(WriteError::EFBIG),
649                libc::EINTR      => Err(WriteError::EINTR),
650                libc::EINVAL     => Err(WriteError::EINVAL),
651                libc::EIO        => Err(WriteError::EIO),
652                libc::ENOSPC     => Err(WriteError::ENOSPC),
653                libc::EPIPE      => Err(WriteError::EPIPE),
654                _ => panic!("Unknown errno during write: {}", errno),
655            }
656        }
657        Ok(num_written as u64)
658    }
659}
660
661impl Clone for WebsocketStream {
662    fn clone(&self) -> WebsocketStream {
663        WebsocketStream {
664            mode: self.mode.clone(),
665            stream: self.stream.try_clone().unwrap(),
666            state: self.state.clone(),
667            msg: self.msg.clone(),
668            buffer: self.buffer.clone()
669        }
670    }
671}
672
673impl fmt::Display for State {
674    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
675        match *self {
676            State::OpCode => "OpCode".fmt(f),
677            State::PayloadKey => "PayloadKey".fmt(f),
678            State::PayloadLength => "PayloadLength".fmt(f),
679            State::MaskingKey => "MaskingKey".fmt(f),
680            State::Payload => "Payload".fmt(f)
681        }
682    }
683}