1extern crate libc;
28extern crate errno;
29
30use std::{mem, ptr, fmt};
31use std::result::Result;
32use std::net::TcpStream;
33use std::os::unix::io::AsRawFd;
34
35use errno::errno;
36use libc::{size_t, c_void, c_int, ssize_t};
37
38use util::*;
39pub mod util;
40
41
42extern "system" {
43 fn read(fd: c_int, buffer: *mut c_void, count: size_t) -> ssize_t;
44 fn write(fd: c_int, buffer: *const c_void, cout: size_t) -> ssize_t;
45}
46
47pub type NewResult = Result<WebsocketStream, SetFdError>;
49
50pub type SetFdResult = Result<(), SetFdError>;
52
53pub type ReadResult = Result<(OpCode, Vec<u8>), ReadError>;
55
56pub type WriteResult = Result<u64, WriteError>;
58
59type SysReadResult = Result<(), ReadError>;
61
62type SysWriteResult = Result<u64, WriteError>;
64
65type OpCodeResult = Result<OpCode, ReadError>;
68
69type PayloadKeyResult = Result<u8, ReadError>;
72
73type PayloadLenResult = Result<u64, ReadError>;
76
77
78pub struct WebsocketStream {
80 mode: Mode,
81 stream: TcpStream,
82 state: State,
83 msg: Message,
84 buffer: Buffer
85}
86
/// Scratch storage for the frame section currently being read.
#[derive(Clone)]
struct Buffer {
    /// How many bytes of the current section are still expected.
    remaining: usize,
    /// The bytes of the current section collected so far.
    buf: Vec<u8>,
}
93
94#[derive(Clone)]
96struct Message {
97 op_code: OpCode,
98 payload_key: u8,
99 payload_len: u64,
100 masking_key: [u8; 4],
101 payload: Vec<u8>
102}
103
/// Blocking behaviour of the socket wrapped by a `WebsocketStream`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Mode {
    /// Reads and writes wait until they can make progress.
    Block,
    /// The socket has `O_NONBLOCK` set, so reads and writes report `EAGAIN`
    /// instead of waiting.
    NonBlock,
}
112
/// The section of a websocket frame that `WebsocketStream::read` is reading.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum State {
    /// First header byte, carrying the opcode.
    OpCode,
    /// Second header byte, carrying the payload length indicator.
    PayloadKey,
    /// Optional 2- or 8-byte extended payload length.
    PayloadLength,
    /// The four masking-key bytes.
    MaskingKey,
    /// The payload itself.
    Payload,
}
122
123impl WebsocketStream {
124
125 pub fn new(stream: TcpStream, mode: Mode) -> NewResult {
127 match mode {
128 Mode::Block => {
129 Ok(WebsocketStream {
130 stream: stream,
131 mode: Mode::Block,
132 state: State::OpCode,
133 msg: Message {
134 op_code: OpCode::Text,
135 payload_key: 0u8,
136 payload_len: 0u64,
137 masking_key: [0u8; 4],
138 payload: Vec::new()
139 },
140 buffer: Buffer {
141 remaining: 1,
142 buf: Vec::new()
143 }
144 })
145 }
146 Mode::NonBlock => {
147 match WebsocketStream::set_non_block(&stream) {
148 Ok(()) => Ok(WebsocketStream {
149 stream: stream,
150 mode: Mode::NonBlock,
151 state: State::OpCode,
152 msg: Message {
153 op_code: OpCode::Text,
154 payload_key: 0u8,
155 payload_len: 0u64,
156 masking_key: [0u8; 4],
157 payload: Vec::new()
158 },
159 buffer: Buffer {
160 remaining: 1,
161 buf: Vec::new()
162 }
163 }),
164 Err(e) => Err(e)
165 }
166 }
167 }
168 }
169
170 pub fn set_mode(&mut self, mode: Mode) -> SetFdResult {
172 if self.mode == mode {
174 return Ok(())
175 }
176
177 match mode {
178 Mode::Block => {
179 WebsocketStream::set_block(&self.stream)
180 }
181 Mode::NonBlock => {
182 WebsocketStream::set_non_block(&self.stream)
183 }
184 }
185 }
186
187 fn set_block(stream: &TcpStream) -> SetFdResult {
188 let fd = stream.as_raw_fd();
189
190 let flags;
192 unsafe {
193 flags = libc::fcntl(fd, libc::F_GETFL);
194 }
195
196 if flags < 0 {
198 let errno = errno().0 as i32;
199 return match errno {
200 libc::EACCES => Err(SetFdError::EACCES),
201 libc::EAGAIN => Err(SetFdError::EAGAIN),
202 libc::EBADF => Err(SetFdError::EBADF),
203 libc::EDEADLK => Err(SetFdError::EDEADLK),
204 libc::EFAULT => Err(SetFdError::EFAULT),
205 libc::EINTR => Err(SetFdError::EINTR),
206 libc::EINVAL => Err(SetFdError::EINVAL),
207 libc::EMFILE => Err(SetFdError::EMFILE),
208 libc::ENOLCK => Err(SetFdError::ENOLCK),
209 libc::EPERM => Err(SetFdError::EPERM),
210 _ => panic!("Unexpected errno: {}", errno)
211 };
212 }
213
214 let response;
216 unsafe {
217 response = libc::fcntl(
218 fd,
219 libc::F_SETFL,
220 flags & !libc::O_NONBLOCK);
221 }
222
223 if response < 0 {
225 let errno = errno().0 as i32;
226 return match errno {
227 libc::EACCES => Err(SetFdError::EACCES),
228 libc::EAGAIN => Err(SetFdError::EAGAIN),
229 libc::EBADF => Err(SetFdError::EBADF),
230 libc::EDEADLK => Err(SetFdError::EDEADLK),
231 libc::EFAULT => Err(SetFdError::EFAULT),
232 libc::EINTR => Err(SetFdError::EINTR),
233 libc::EINVAL => Err(SetFdError::EINVAL),
234 libc::EMFILE => Err(SetFdError::EMFILE),
235 libc::ENOLCK => Err(SetFdError::ENOLCK),
236 libc::EPERM => Err(SetFdError::EPERM),
237 _ => panic!("Unexpected errno: {}", errno)
238 };
239 } else {
240 Ok(())
241 }
242 }
243
244 fn set_non_block(stream: &TcpStream) -> SetFdResult {
246 let fd = stream.as_raw_fd();
247 let response;
248 unsafe {
249 response = libc::fcntl(
250 fd,
251 libc::F_SETFL,
252 libc::O_NONBLOCK);
253 }
254
255 if response < 0 {
256 let errno = errno().0 as i32;
257 return match errno {
258 libc::EACCES => Err(SetFdError::EACCES),
259 libc::EAGAIN => Err(SetFdError::EAGAIN),
260 libc::EBADF => Err(SetFdError::EBADF),
261 libc::EDEADLK => Err(SetFdError::EDEADLK),
262 libc::EFAULT => Err(SetFdError::EFAULT),
263 libc::EINTR => Err(SetFdError::EINTR),
264 libc::EINVAL => Err(SetFdError::EINVAL),
265 libc::EMFILE => Err(SetFdError::EMFILE),
266 libc::ENOLCK => Err(SetFdError::ENOLCK),
267 libc::EPERM => Err(SetFdError::EPERM),
268 _ => panic!("Unexpected errno: {}", errno)
269 };
270 } else {
271 Ok(())
272 }
273 }
274
275 pub fn read(&mut self) -> ReadResult {
286 if self.state == State::OpCode {
288 if self.buffer.remaining == 0 {
289 self.buffer.remaining = 1;
290 self.buffer.buf = Vec::<u8>::with_capacity(1);
291 }
292
293 let result = self.read_op_code();
294 if !result.is_ok() {
295 return Err(result.unwrap_err());
296 }
297 self.msg.op_code = result.unwrap();
298
299 self.state = State::PayloadKey;
301 self.buffer.remaining = 1;
302 self.buffer.buf = Vec::<u8>::with_capacity(1);
303 }
304
305 if self.state == State::PayloadKey {
307 let result = self.read_payload_key();
308 if !result.is_ok() {
309 return Err(result.unwrap_err());
310 }
311 self.msg.payload_key = result.unwrap();
312
313 self.state = State::PayloadLength;
315 self.buffer.remaining = match self.msg.payload_key {
316 127 => 8,
317 126 => 2,
318 _ => {
319 self.msg.payload_len = self.msg.payload_key as u64;
320 0
321 }
322 };
323 self.buffer.buf = Vec::<u8>::with_capacity(self.buffer.remaining);
324 }
325
326 if self.state == State::PayloadLength {
328 if self.buffer.remaining > 0 {
329 let result = self.read_payload_length();
330 if !result.is_ok() {
331 let err = result.unwrap_err();
332 match err {
333 ReadError::EAGAIN => {
334 self.buffer.remaining = (self.msg.payload_len -
336 self.buffer.buf.len() as u64) as usize;
337 }
338 _ => { }
339 }
340 return Err(err);
341 }
342
343 self.msg.payload_len = result.unwrap();
345
346 let bytes_needed = match self.msg.payload_key {
348 127 => 8,
349 126 => 2,
350 _ => 0
351 };
352 self.buffer.remaining = (bytes_needed -
353 self.buffer.buf.len() as u64) as usize;
354 } else {
355 self.state = State::MaskingKey;
357 self.buffer.remaining = 4;
358 self.buffer.buf = Vec::<u8>::with_capacity(4);
359 }
360 }
361
362 if self.state == State::MaskingKey {
364 let result = self.read_masking_key();
365 if !result.is_ok() {
366 let err = result.unwrap_err();
367 match err {
368 ReadError::EAGAIN => {
369 self.buffer.remaining = 4 - self.buffer.buf.len();
371 }
372 _ => { }
373 }
374 return Err(err);
375 }
376
377 self.msg.masking_key[0] = self.buffer.buf[0];
379 self.msg.masking_key[1] = self.buffer.buf[1];
380 self.msg.masking_key[2] = self.buffer.buf[2];
381 self.msg.masking_key[3] = self.buffer.buf[3];
382
383 self.state = State::Payload;
385 self.buffer.remaining = self.msg.payload_len as usize;
386 self.buffer.buf = Vec::<u8>::with_capacity(
387 self.msg.payload_len as usize);
388 }
389
390 if self.state == State::Payload {
392 let result = self.read_payload();
393 if !result.is_ok() {
394 let err = result.unwrap_err();
395 match err {
396 ReadError::EAGAIN => {
397 self.buffer.remaining = (self.msg.payload_len -
399 self.buffer.buf.len() as u64) as usize;
400 }
401 _ => { }
402 }
403 return Err(err);
404 }
405
406 self.msg.payload = Vec::<u8>::with_capacity(
408 self.msg.payload_len as usize);
409 for x in 0..self.buffer.buf.len() {
410 self.msg.payload.push(
411 self.buffer.buf[x] ^ self.msg.masking_key[x % 4]);
412 }
413
414 self.state = State::OpCode;
415 self.buffer.remaining = 1;
416 self.buffer.buf = Vec::<u8>::with_capacity(1);
417
418 return Ok((self.msg.op_code.clone(), self.msg.payload.clone()))
420 }
421
422 Err(ReadError::EAGAIN)
424 }
425
426 fn read_op_code(&mut self) -> OpCodeResult {
428 match self.read_num_bytes(1) {
429 Ok(()) => { }
430 Err(e) => return Err(e)
431 };
432
433 let op_code = self.buffer.buf[0] & OP_CODE_UN_MASK;
435 let valid_op = match op_code {
436 OP_CONTINUATION => true,
437 OP_TEXT => true,
438 OP_BINARY => true,
439 OP_CLOSE => true,
440 OP_PING => true,
441 OP_PONG => true,
442 _ => false
443 };
444 if !valid_op {
445 return Err(ReadError::OpCode);
446 }
447
448 let op = match op_code {
450 OP_CONTINUATION => OpCode::Continuation,
451 OP_TEXT => OpCode::Text,
452 OP_BINARY => OpCode::Binary,
453 OP_CLOSE => OpCode::Close,
454 OP_PING => OpCode::Ping,
455 OP_PONG => OpCode::Pong,
456 _ => unimplemented!()
457 };
458 Ok(op)
459 }
460
461 fn read_payload_key(&mut self) -> PayloadKeyResult {
463 match self.read_num_bytes(1) {
464 Ok(()) => Ok(self.buffer.buf[0] & PAYLOAD_KEY_UN_MASK),
465 Err(e) => Err(e)
466 }
467 }
468
469 fn read_payload_length(&mut self) -> PayloadLenResult {
471 let count = self.buffer.remaining;
472 match self.read_num_bytes(count) {
473 Ok(()) => {
474 if self.msg.payload_key == 126 {
475 let mut len = (self.buffer.buf[0] as u16) << 8;
476 len = len | (self.buffer.buf[1] as u16);
477 Ok(len as u64)
478 } else {
479 let mut len = (self.buffer.buf[0] as u64) << 56;
480 len = len | ((self.buffer.buf[1] as u64) << 48);
481 len = len | ((self.buffer.buf[2] as u64) << 40);
482 len = len | ((self.buffer.buf[3] as u64) << 32);
483 len = len | ((self.buffer.buf[4] as u64) << 24);
484 len = len | ((self.buffer.buf[5] as u64) << 16);
485 len = len | ((self.buffer.buf[6] as u64) << 8);
486 len = len | (self.buffer.buf[7] as u64);
487 Ok(len)
488 }
489 }
490 Err(e) => Err(e)
491 }
492 }
493
494 fn read_masking_key(&mut self) -> SysReadResult {
496 let count = self.buffer.remaining;
497 match self.read_num_bytes(count) {
498 Ok(()) => Ok(()),
499 Err(e) => Err(e)
500 }
501 }
502
503 fn read_payload(&mut self) -> SysReadResult {
505 let count = self.buffer.remaining;
506 match self.read_num_bytes(count) {
507 Ok(()) => Ok(()),
508 Err(e) => Err(e)
509 }
510 }
511
512 fn read_num_bytes(&mut self, count: usize) -> SysReadResult {
514 let fd = self.stream.as_raw_fd();
515
516 let buffer;
518 unsafe {
519 buffer = libc::calloc(count as size_t,
520 mem::size_of::<u8>() as size_t);
521 }
522
523 if buffer.is_null() {
525 return Err(ReadError::ENOMEM)
526 }
527
528 let num_read;
530 unsafe {
531 num_read = read(fd, buffer, count as size_t);
532 }
533
534 if num_read < 0 {
536 unsafe { libc::free(buffer); }
537 let errno = errno().0 as i32;
538 return match errno {
539 libc::EBADF => Err(ReadError::EBADF),
540 libc::EFAULT => Err(ReadError::EFAULT),
541 libc::EINTR => Err(ReadError::EINTR),
542 libc::EINVAL => Err(ReadError::EINVAL),
543 libc::EIO => Err(ReadError::EIO),
544 libc::EISDIR => Err(ReadError::EISDIR),
545 libc::EAGAIN => Err(ReadError::EAGAIN),
546 _ => panic!("Unexpected errno during read: {}", errno)
547 };
548 }
549
550 if num_read == 0 {
552 unsafe { libc::free(buffer); }
553 return Err(ReadError::EAGAIN);
554 }
555
556 for x in 0..num_read as isize {
558 unsafe {
559 self.buffer.buf.push(ptr::read(buffer.offset(x)) as u8);
560 }
561 }
562
563 unsafe { libc::free(buffer); }
565 Ok(())
566 }
567
568 pub fn write(&mut self, op: OpCode, payload: &mut Vec<u8>) -> WriteResult {
570 let mut out_buf: Vec<u8> = Vec::with_capacity(payload.len() + 9);
571
572 self.set_op_code(&op, &mut out_buf);
573 self.set_payload_info(payload.len(), &mut out_buf);
574
575 for byte in payload.iter() {
578 out_buf.push(*byte);
579 }
580
581 self.write_bytes(&out_buf)
582 }
583
584 fn set_op_code(&self, op: &OpCode, buf: &mut Vec<u8>) {
585 let op_code = match *op {
586 OpCode::Continuation => OP_CONTINUATION,
587 OpCode::Text => OP_TEXT,
588 OpCode::Binary => OP_BINARY,
589 OpCode::Close => OP_CLOSE,
590 OpCode::Ping => OP_PING,
591 OpCode::Pong => OP_PONG
592 };
593 buf.push(op_code | OP_CODE_MASK);
594 }
595
596 fn set_payload_info(&self, len: usize, buf: &mut Vec<u8>) {
597 if len <= 125 {
598 buf.push(len as u8);
599 } else if len <= 65535 {
600 let mut len_buf = [0u8; 2];
601 len_buf[0] = ((len as u16) >> 8) as u8;
602 len_buf[1] = len as u8;
603
604 buf.push(126u8); buf.push(len_buf[0]);
606 buf.push(len_buf[1]);
607 } else {
608 let mut len_buf = [0u8; 8];
609 len_buf[0] = ((len as u64) >> 56) as u8;
610 len_buf[1] = ((len as u64) >> 48) as u8;
611 len_buf[2] = ((len as u64) >> 40) as u8;
612 len_buf[3] = ((len as u64) >> 32) as u8;
613 len_buf[4] = ((len as u64) >> 24) as u8;
614 len_buf[5] = ((len as u64) >> 16) as u8;
615 len_buf[6] = ((len as u64) >> 8) as u8;
616 len_buf[7] = len as u8;
617
618 buf.push(127u8); buf.push(len_buf[0]);
620 buf.push(len_buf[1]);
621 buf.push(len_buf[2]);
622 buf.push(len_buf[3]);
623 buf.push(len_buf[4]);
624 buf.push(len_buf[5]);
625 buf.push(len_buf[6]);
626 buf.push(len_buf[7]);
627 }
628 }
629
630 fn write_bytes(&mut self, buf: &Vec<u8>) -> SysWriteResult {
631 let buffer = &buf[..];
632 let fd = self.stream.as_raw_fd();
633 let count = buf.len() as size_t;
634
635 let num_written;
636 unsafe {
637 let buff_ptr = buffer.as_ptr();
638 let void_buff_ptr: *const c_void = mem::transmute(buff_ptr);
639 num_written = write(fd, void_buff_ptr, count);
640 }
641
642 if num_written < 0 {
643 let errno = errno().0 as i32;
644 return match errno {
645 libc::EAGAIN => Err(WriteError::EAGAIN),
646 libc::EBADF => Err(WriteError::EBADF),
647 libc::EFAULT => Err(WriteError::EFAULT),
648 libc::EFBIG => Err(WriteError::EFBIG),
649 libc::EINTR => Err(WriteError::EINTR),
650 libc::EINVAL => Err(WriteError::EINVAL),
651 libc::EIO => Err(WriteError::EIO),
652 libc::ENOSPC => Err(WriteError::ENOSPC),
653 libc::EPIPE => Err(WriteError::EPIPE),
654 _ => panic!("Unknown errno during write: {}", errno),
655 }
656 }
657 Ok(num_written as u64)
658 }
659}
660
661impl Clone for WebsocketStream {
662 fn clone(&self) -> WebsocketStream {
663 WebsocketStream {
664 mode: self.mode.clone(),
665 stream: self.stream.try_clone().unwrap(),
666 state: self.state.clone(),
667 msg: self.msg.clone(),
668 buffer: self.buffer.clone()
669 }
670 }
671}
672
673impl fmt::Display for State {
674 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
675 match *self {
676 State::OpCode => "OpCode".fmt(f),
677 State::PayloadKey => "PayloadKey".fmt(f),
678 State::PayloadLength => "PayloadLength".fmt(f),
679 State::MaskingKey => "MaskingKey".fmt(f),
680 State::Payload => "Payload".fmt(f)
681 }
682 }
683}