1#![feature(test)]
2
3use std::marker::Unpin;
4
5use base64::encode;
6use futures::sink::SinkExt;
7use sha1::Sha1;
8use tokio::io::{ AsyncRead, AsyncWrite, AsyncWriteExt };
9use tokio::io::{ split, ReadHalf, WriteHalf };
10use tokio::stream::StreamExt;
11use tokio::sync::mpsc;
12use tokio_util::codec::{ FramedRead, FramedWrite };
13use tracing::trace;
14
15pub mod error;
16pub mod opcode;
17pub mod codec;
18pub mod message;
19
20pub use error::{ WebsocketError, WebsocketResult };
21pub use codec::WebsocketCodec;
22pub use opcode::Opcode;
23pub use message::Message;
24
/// A server-side WebSocket connection over an async byte stream `S`.
///
/// Incoming frames are decoded from the read half of the stream. Outgoing
/// messages are queued on an mpsc channel. A background task that owns the
/// write half of the stream drains that channel.
#[derive(Debug)]
pub struct Websocket<S> {
    /// Decodes incoming frames from the read half of the underlying stream.
    reader: FramedRead<ReadHalf<S>, WebsocketCodec>,
    /// Sending side of the queue consumed by the background writer task.
    tx: mpsc::Sender<Message>,
    /// Optional handshake key associated with this connection.
    /// NOTE(review): not read by any code visible in this file.
    key: Option<String>,
}
31
32impl<S: 'static> Websocket<S>
33where
34 S: AsyncRead + AsyncWrite + Unpin + Send,
35{
36 pub fn new(stream: S) -> Self {
39 Self::_create(stream, None)
40 }
41
42 pub fn new_with_key(stream: S, key: String) -> Self {
46 Self::_create(stream, Some(key))
47 }
48
49 fn _create(stream: S, key: Option<String>) -> Self {
50 let (reader, mut writer) = split(stream);
51 let reader = FramedRead::new(reader, WebsocketCodec::default());
52 let (tx, mut rx) = mpsc::channel::<Message>(100);
53
54 tokio::spawn(async move {
55 Self::_send_handshake(&mut writer, key).await;
56
57 let mut writer = FramedWrite::new(writer, WebsocketCodec::default());
58
59 while let Some(message) = rx.recv().await {
60 if message.opcode == Opcode::Close { break }
61 writer.send(message).await.unwrap();
62 }
63
64 trace!("Client disconnected");
65 });
66
67 Self { reader, tx, key: None }
68 }
69
70 async fn _send_handshake(writer: &mut WriteHalf<S>, key: Option<String>) {
72 let mut handshake = vec![
73 "HTTP/1.1 101 Switching Protocols".to_string(),
74 "Upgrade: websocket".to_string(),
75 "Connection: Upgrade".to_string(),
76 ];
77
78 if let Some(key) = key {
79 let guid = [key.as_bytes(), b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"].concat();
80 let sha1 = Sha1::from(guid).digest().bytes();
81 let key = format!("Sec-Websocket-Accept: {}", encode(sha1));
82
83 handshake.push(key);
84 }
85
86 handshake.push("\r\n".to_string());
87
88 writer.write_all(handshake.join("\r\n").as_bytes()).await.unwrap();
89 }
90
91 pub async fn next(&mut self) -> Option<Message> {
93 while let Some(Ok(msg)) = self.reader.next().await {
94 if msg.opcode == Opcode::Text {
95 return Some(msg)
96 }
97
98 if msg.opcode == Opcode::Close {
99 self.tx.send(Message::close()).await.unwrap();
100 return None
101 }
102
103 else if msg.opcode == Opcode::Ping {
104 self.tx.send(Message::pong()).await.unwrap();
105 }
106 }
107
108 None
109 }
110
111 pub async fn send_text<T: AsRef<str>>(&mut self, t: T) -> WebsocketResult<()> {
113 Ok(self.tx.send(Message::text(t.as_ref())).await?)
114 }
115}