1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#![feature(test)]
use std::marker::Unpin;
use base64::encode;
use futures::sink::SinkExt;
use http::HeaderValue;
use sha1::Sha1;
use tokio::io::{ AsyncRead, AsyncWrite, AsyncWriteExt };
use tokio::io::{ split, ReadHalf, WriteHalf };
use tokio::stream::StreamExt;
use tokio::sync::mpsc;
use tokio_util::codec::{ FramedRead, FramedWrite };
use tracing::trace;
pub mod error;
pub use error::{ WebsocketError, WebsocketResult };
pub mod codec;
pub use codec::WebsocketCodec;
pub mod opcode;
pub use opcode::Opcode;
pub mod message;
pub use message::Message;
/// Handle to one WebSocket connection running over a stream of type `S`.
///
/// Incoming frames are read directly through `reader`; outgoing messages are
/// not written by this handle itself but are pushed onto `tx` and written by
/// the background task spawned in `Websocket::new`.
#[derive(Debug)]
pub struct Websocket<S> {
/// Read half of the underlying stream, framed with `WebsocketCodec`.
reader: FramedRead<ReadHalf<S>, WebsocketCodec>,
/// Sending side of the channel that feeds the background writer task.
tx: mpsc::Sender<Message>,
}
impl<S> Websocket<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    /// Finishes the server side of the opening handshake on `stream` and
    /// returns a handle to the established connection.
    ///
    /// `key` is the value of the client's `Sec-WebSocket-Key` request header.
    /// After the handshake response has been written, the stream is split:
    /// the read half stays with the returned handle, the write half is moved
    /// into a spawned task that writes every message received over an
    /// internal channel (capacity 100).
    ///
    /// The writer task stops when
    /// * it has forwarded a `Close` message to the peer,
    /// * a write to the socket fails, or
    /// * every sender has been dropped (i.e. the handle is gone).
    ///
    /// # Panics
    ///
    /// Panics if the handshake response cannot be written to `stream`.
    pub async fn new(key: &HeaderValue, stream: S) -> WebsocketResult<Self> {
        let (reader, mut writer) = split(stream);
        Self::_send_handshake(key, &mut writer).await;

        let reader = FramedRead::new(reader, WebsocketCodec::default());
        let mut writer = FramedWrite::new(writer, WebsocketCodec::default());
        let (tx, mut rx) = mpsc::channel::<Message>(100);

        tokio::spawn(async move {
            while let Some(message) = rx.recv().await {
                // Remember the opcode before `message` is moved into the sink.
                let is_close = message.opcode == Opcode::Close;

                // A failed write means the peer is gone; stop quietly instead
                // of panicking inside the background task.
                if let Err(err) = writer.send(message).await {
                    trace!("Failed to write frame: {:?}", err);
                    break;
                }

                // The Close frame has to reach the peer (RFC 6455, 5.5.1), so
                // the loop only ends *after* it has been written.
                if is_close {
                    break;
                }
            }
            trace!("Client disconnected");
        });

        Ok(Self { reader, tx })
    }

    /// Writes the `101 Switching Protocols` response for the given
    /// `Sec-WebSocket-Key` to `writer`.
    ///
    /// The accept token is the base64 encoding of the SHA-1 digest of the key
    /// concatenated with the fixed WebSocket GUID.
    ///
    /// # Panics
    ///
    /// Panics if writing to `writer` fails.
    // NOTE(review): this I/O failure should be returned from `new` rather than
    // panic; doing so needs `WebsocketError: From<std::io::Error>`, which is
    // not visible from this file -- confirm and switch to `?`.
    async fn _send_handshake(key: &HeaderValue, writer: &mut WriteHalf<S>) {
        let accept_input = [key.as_bytes(), b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"].concat();
        let digest = Sha1::from(accept_input).digest().bytes();
        let accept_header = format!("Sec-WebSocket-Accept: {}", encode(digest));

        // The trailing "\r\n" element yields the blank line that terminates
        // the header section once the lines are joined with CRLF.
        let response = [
            "HTTP/1.1 101 Switching Protocols",
            "Upgrade: websocket",
            "Connection: Upgrade",
            accept_header.as_str(),
            "\r\n",
        ]
        .join("\r\n");

        writer
            .write_all(response.as_bytes())
            .await
            .expect("failed to write WebSocket handshake response");
    }

    /// Waits for the next text message from the peer.
    ///
    /// Control frames are handled internally:
    /// * `Ping` is answered by queueing a pong for the writer task,
    /// * `Close` is answered by queueing a close message, after which `None`
    ///   is returned.
    ///
    /// Frames with any other opcode are skipped.
    ///
    /// Returns `None` when the peer closed the connection, when the stream
    /// ended or produced an error, or when the writer task is no longer
    /// running and a reply therefore cannot be delivered.
    pub async fn next(&mut self) -> Option<Message> {
        while let Some(Ok(msg)) = self.reader.next().await {
            if msg.opcode == Opcode::Text {
                return Some(msg);
            }

            if msg.opcode == Opcode::Close {
                // Best effort: if the writer task has already stopped there is
                // nobody left to deliver the reply, and we are closing anyway.
                let _ = self.tx.send(Message::close()).await;
                return None;
            }

            // A pong that cannot be queued means the write side is gone, so
            // the connection is treated as finished.
            if msg.opcode == Opcode::Ping && self.tx.send(Message::pong()).await.is_err() {
                return None;
            }
        }
        None
    }

    /// Queues `t` to be sent to the peer as a text message.
    ///
    /// # Errors
    ///
    /// Returns an error if the message cannot be handed to the writer task,
    /// i.e. the receiving end of the internal channel has been closed.
    pub async fn send_text<T: AsRef<str>>(&mut self, t: T) -> WebsocketResult<()> {
        self.tx.send(Message::text(t.as_ref())).await?;
        Ok(())
    }
}