websocket_rs/
lib.rs

1#![feature(test)]
2
3use std::marker::Unpin;
4
5use base64::encode;
6use futures::sink::SinkExt;
7use sha1::Sha1;
8use tokio::io::{ AsyncRead, AsyncWrite, AsyncWriteExt };
9use tokio::io::{ split, ReadHalf, WriteHalf };
10use tokio::stream::StreamExt;
11use tokio::sync::mpsc;
12use tokio_util::codec::{ FramedRead, FramedWrite };
13use tracing::trace;
14
15pub mod error;
16pub mod opcode;
17pub mod codec;
18pub mod message;
19
20pub use error::{ WebsocketError, WebsocketResult };
21pub use codec::WebsocketCodec;
22pub use opcode::Opcode;
23pub use message::Message;
24
25#[derive(Debug)]
26pub struct Websocket<S> {
27    reader: FramedRead<ReadHalf<S>, WebsocketCodec>,
28    tx: mpsc::Sender<Message>,
29    key: Option<String>,
30}
31
32impl<S: 'static> Websocket<S>
33where
34    S: AsyncRead + AsyncWrite + Unpin + Send,
35{
36    /// Create a new websocket instance, given any type that implements
37    /// `AsyncRead + AsyncWrite` like `tokio::net::TcpStream` or `tokio_native_tls::TlsStream`
38    pub fn new(stream: S) -> Self {
39        Self::_create(stream, None)
40    }
41
42    /// Same as `Websocket::new` except that it also accept a key that represent
43    /// the value of `Sec-Websocket-Key` for client that requires a valid
44    /// `Sec-Websocket-Accept` in response headers.
45    pub fn new_with_key(stream: S, key: String) -> Self {
46        Self::_create(stream, Some(key))
47    }
48
49    fn _create(stream: S, key: Option<String>) -> Self {
50        let (reader, mut writer) = split(stream);
51        let reader = FramedRead::new(reader, WebsocketCodec::default());
52        let (tx, mut rx) = mpsc::channel::<Message>(100);
53
54        tokio::spawn(async move {
55            Self::_send_handshake(&mut writer, key).await;
56
57            let mut writer = FramedWrite::new(writer, WebsocketCodec::default());
58
59            while let Some(message) = rx.recv().await {
60                if message.opcode == Opcode::Close { break }
61                writer.send(message).await.unwrap();
62            }
63
64            trace!("Client disconnected");
65        });
66
67        Self { reader, tx, key: None }
68    }
69
70    /// Send handshake.
71    async fn _send_handshake(writer: &mut WriteHalf<S>, key: Option<String>) {
72        let mut handshake = vec![
73            "HTTP/1.1 101 Switching Protocols".to_string(),
74            "Upgrade: websocket".to_string(),
75            "Connection: Upgrade".to_string(),
76        ];
77
78        if let Some(key) = key {
79            let guid = [key.as_bytes(), b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"].concat();
80            let sha1 = Sha1::from(guid).digest().bytes();
81            let key = format!("Sec-Websocket-Accept: {}", encode(sha1));
82
83            handshake.push(key);
84        }
85
86        handshake.push("\r\n".to_string());
87
88        writer.write_all(handshake.join("\r\n").as_bytes()).await.unwrap();
89    }
90
91    /// Wait for next frame to come. (support incoming fragmented frames)
92    pub async fn next(&mut self) -> Option<Message> {
93        while let Some(Ok(msg)) = self.reader.next().await {
94            if msg.opcode == Opcode::Text {
95                return Some(msg)
96            }
97
98            if msg.opcode == Opcode::Close {
99                self.tx.send(Message::close()).await.unwrap();
100                return None
101            }
102
103            else if msg.opcode == Opcode::Ping {
104                self.tx.send(Message::pong()).await.unwrap();
105            }
106        }
107
108        None
109    }
110
111    /// Send a text frame
112    pub async fn send_text<T: AsRef<str>>(&mut self, t: T) -> WebsocketResult<()> {
113        Ok(self.tx.send(Message::text(t.as_ref())).await?)
114    }
115}