1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#![feature(test)]
use std::marker::Unpin;
use base64::encode;
use futures::sink::SinkExt;
use sha1::Sha1;
use tokio::io::{ AsyncRead, AsyncWrite, AsyncWriteExt };
use tokio::io::{ split, ReadHalf, WriteHalf };
use tokio::stream::StreamExt;
use tokio::sync::mpsc;
use tokio_util::codec::{ FramedRead, FramedWrite };
use tracing::trace;
pub mod error;
pub mod opcode;
pub mod codec;
pub mod message;
pub use error::{ WebsocketError, WebsocketResult };
pub use codec::WebsocketCodec;
pub use opcode::Opcode;
pub use message::Message;
/// A server-side WebSocket connection over a stream `S`.
///
/// Incoming frames are decoded from the read half of the split stream; outgoing
/// messages are handed to a background writer task through `tx`.
#[derive(Debug)]
pub struct Websocket<S> {
/// Read half of the split stream, decoded into messages by `WebsocketCodec`.
reader: FramedRead<ReadHalf<S>, WebsocketCodec>,
/// Sending side of the channel drained by the writer task spawned in `_create`.
tx: mpsc::Sender<Message>,
/// Slot for the client handshake key; what it holds is decided by `_create`.
/// Not read by any method in this file.
key: Option<String>,
}
/// Server-side operations on an upgraded WebSocket connection.
///
/// Every constructor splits `stream` into two halves. The read half is kept for
/// [`Websocket::next`]. The write half is moved into a background task spawned
/// with `tokio::spawn`, and outgoing messages reach that task through a bounded
/// channel. Constructors must therefore be called from within a Tokio runtime
/// context.
impl<S: 'static> Websocket<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    /// Wraps `stream` without a client key.
    ///
    /// The `101 Switching Protocols` response is sent without a
    /// `Sec-Websocket-Accept` header.
    pub fn new(stream: S) -> Self {
        Self::_create(stream, None)
    }

    /// Wraps `stream` and answers the handshake with a `Sec-Websocket-Accept`
    /// header derived from `key`.
    ///
    /// `key` is presumably the client's `Sec-WebSocket-Key` header value; pass it untrimmed
    /// only if the caller has already stripped surrounding whitespace.
    pub fn new_with_key(stream: S, key: String) -> Self {
        Self::_create(stream, Some(key))
    }

    /// Splits the stream, spawns the writer task (handshake first, then queued
    /// messages) and builds the handle. The handle keeps `key`.
    fn _create(stream: S, key: Option<String>) -> Self {
        let (reader, mut writer) = split(stream);
        let reader = FramedRead::new(reader, WebsocketCodec::default());
        // Bounded queue between callers and the writer task; `send` waits while it is full.
        let (tx, mut rx) = mpsc::channel::<Message>(100);
        // The task gets its own copy so the handle can store the key as well.
        let handshake_key = key.clone();
        tokio::spawn(async move {
            if let Err(err) = Self::_send_handshake(&mut writer, handshake_key).await {
                // The peer is unreachable; end the task instead of panicking.
                trace!("Failed to send handshake: {:?}", err);
                return;
            }
            let mut writer = FramedWrite::new(writer, WebsocketCodec::default());
            while let Some(message) = rx.recv().await {
                let is_close = message.opcode == Opcode::Close;
                // A failed write means the connection is gone: stop rather than panic.
                if let Err(err) = writer.send(message).await {
                    trace!("Failed to send message: {:?}", err);
                    break;
                }
                // RFC 6455 §5.5.1: the Close frame is written back before the writer stops.
                if is_close {
                    break;
                }
            }
            trace!("Client disconnected");
        });
        Self { reader, tx, key }
    }

    /// Writes the `101 Switching Protocols` response to `writer`.
    ///
    /// When `key` is present, a `Sec-Websocket-Accept` header is added. Its value
    /// is the base64 encoding of the SHA-1 digest of `key` followed by the
    /// RFC 6455 GUID.
    ///
    /// # Errors
    /// Returns the I/O error produced while writing the response.
    async fn _send_handshake(
        writer: &mut WriteHalf<S>,
        key: Option<String>,
    ) -> std::io::Result<()> {
        const WEBSOCKET_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
        let mut handshake = vec![
            "HTTP/1.1 101 Switching Protocols".to_string(),
            "Upgrade: websocket".to_string(),
            "Connection: Upgrade".to_string(),
        ];
        if let Some(key) = key {
            let accept_input = [key.as_bytes(), WEBSOCKET_GUID].concat();
            let digest = Sha1::from(accept_input).digest().bytes();
            handshake.push(format!("Sec-Websocket-Accept: {}", encode(digest)));
        }
        // Joining with CRLF after this entry ends the response with the required "\r\n\r\n".
        handshake.push("\r\n".to_string());
        writer.write_all(handshake.join("\r\n").as_bytes()).await
    }

    /// Waits for the next text message from the client.
    ///
    /// Pings are answered with a pong and skipped. Frames with any other
    /// non-text, non-Close opcode are skipped. Returns `None` in these cases:
    /// - the stream ends or a frame fails to decode;
    /// - the client sends Close (a Close message is queued in reply);
    /// - a pong cannot be queued because the writer task has stopped.
    pub async fn next(&mut self) -> Option<Message> {
        while let Some(Ok(msg)) = self.reader.next().await {
            if msg.opcode == Opcode::Text {
                return Some(msg);
            }
            if msg.opcode == Opcode::Close {
                // Best effort: if the writer task already exited there is nothing to reply through.
                let _ = self.tx.send(Message::close()).await;
                return None;
            }
            if msg.opcode == Opcode::Ping && self.tx.send(Message::pong()).await.is_err() {
                // The writer task is gone, so this connection can no longer be answered.
                return None;
            }
        }
        None
    }

    /// Queues `t` as a text message for the writer task.
    ///
    /// # Errors
    /// Fails when the message cannot be queued because the writer task's receiver
    /// has been dropped.
    pub async fn send_text<T: AsRef<str>>(&mut self, t: T) -> WebsocketResult<()> {
        Ok(self.tx.send(Message::text(t.as_ref())).await?)
    }
}