simple_websocket/
handshake.rs1use crate::connection::WSConnection;
2use crate::error::{HandshakeError, StreamError};
3use crate::frame::Frame;
4use crate::read::{ReadStream, StreamKind};
5use crate::write::WriteStream;
6use base64::prelude::BASE64_STANDARD;
7use base64::prelude::*;
8use bytes::BytesMut;
9use rand::random;
10use sha1::{Digest, Sha1};
11use std::sync::Arc;
12use tokio::io::{split, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
13use tokio::net::TcpStream;
14use tokio::sync::mpsc::channel;
15use tokio::sync::Mutex;
16use tokio::time::{timeout, Duration};
17
/// Header name (with its trailing colon) that carries the client's key in an upgrade request.
const SEC_WEBSOCKETS_KEY: &str = "Sec-WebSocket-Key:";
/// GUID appended to the client key before hashing to form `Sec-WebSocket-Accept`.
const UUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// Status code and reason phrase of a successful upgrade response.
const SWITCHING_PROTOCOLS: &str = "101 Switching Protocols";

/// Server reply to a valid upgrade request. The `{}` placeholder is substituted
/// with the computed `Sec-WebSocket-Accept` value before the reply is written.
const HTTP_ACCEPT_RESPONSE: &str = concat!(
    "HTTP/1.1 101 Switching Protocols\r\n",
    "Connection: Upgrade\r\n",
    "Upgrade: websocket\r\n",
    "Sec-WebSocket-Accept: {}\r\n",
    "\r\n"
);

/// Client upgrade request template. `{host}` and `{key}` are substituted before sending.
// NOTE(review): this advertises `permessage-deflate`, while the server reply above
// never negotiates extensions. Confirm the frame reader can handle compressed (RSV1)
// frames before relying on servers that accept this offer.
const HTTP_HANDSHAKE_REQUEST: &str = concat!(
    "GET / HTTP/1.1\r\n",
    "Host: {host}\r\n",
    "Connection: Upgrade\r\n",
    "Upgrade: websocket\r\n",
    "Sec-WebSocket-Key: {key}\r\n",
    "Sec-WebSocket-Version: 13\r\n",
    "Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\n",
    "\r\n"
);
/// Outcome of a handshake: an established [`WSConnection`], or the [`HandshakeError`]
/// that aborted it. Inside this module the name shadows `std::result::Result`, which is
/// why other result types here are spelled out as `std::result::Result<..>`.
pub type Result = std::result::Result<WSConnection, HandshakeError>;
37
38pub async fn perform_handshake<T: AsyncRead + AsyncWrite + Send + 'static>(stream: T) -> Result {
43 let (reader, mut writer) = split(stream);
44 let mut buf_reader = BufReader::new(reader);
45
46 let sec_websockets_accept = header_read(&mut buf_reader).await;
47
48 match sec_websockets_accept {
49 Some(accept_value) => {
50 let response = HTTP_ACCEPT_RESPONSE.replace("{}", &accept_value);
51 writer
52 .write_all(response.as_bytes())
53 .await
54 .map_err(|source| HandshakeError::IOError { source })?
55 }
56 None => Err(HandshakeError::NoSecWebsocketKey)?,
57 }
58
59 second_stage_handshake(StreamKind::Server, buf_reader, writer).await
60}
61
62async fn second_stage_handshake<
63 R: AsyncReadExt + Send + Unpin + 'static,
64 W: AsyncWriteExt + Send + Unpin + 'static,
65>(
66 kind: StreamKind,
67 buf_reader: R,
68 writer: W,
69) -> Result {
70 let (write_tx, write_rx) = channel::<Frame>(20);
75
76 let (read_tx, read_rx) = channel::<std::result::Result<Vec<u8>, StreamError>>(20);
77 let read_tx = Arc::new(Mutex::new(read_tx));
78
79 let (internal_tx, internal_rx) = channel::<Frame>(20);
81 let (close_tx, close_rx) = channel::<bool>(1);
82
83 let read_tx_stream = read_tx.clone();
84 let mut read_stream = ReadStream::new(kind, buf_reader, read_tx_stream, internal_tx, close_tx);
87 let mut write_stream = WriteStream::new(writer, write_rx, internal_rx);
88
89 let ws_connection = WSConnection::new(read_rx, write_tx, close_rx);
90
91 let read_tx_r = read_tx.clone();
92 tokio::spawn(async move {
103 if let Err(err) = read_stream.poll_messages().await {
104 read_tx_r.lock().await.send(Err(err)).await.unwrap();
105 }
106 });
107
108 let read_tx_w = read_tx.clone();
109 tokio::spawn(async move {
110 if let Err(err) = write_stream.run().await {
111 read_tx_w.lock().await.send(Err(err)).await.unwrap()
112 }
113 drop(read_tx_w);
114 });
115
116 Ok(ws_connection)
117}
118
119pub async fn perform_client_handshake(stream: TcpStream) -> Result {
120 let client_websocket_key = generate_websocket_key();
121 let request = HTTP_HANDSHAKE_REQUEST
122 .replace("{key}", &client_websocket_key)
123 .replace("{host}", &stream.local_addr().unwrap().to_string());
124
125 let (reader, mut writer) = split(stream);
126 let mut buf_reader = BufReader::new(reader);
127
128 writer.write_all(request.as_bytes()).await?;
129
130 let mut buffer: Vec<u8> = vec![0; 1024];
134
135 let number_read = buf_reader.read(&mut buffer).await?;
137
138 buffer.truncate(number_read);
140
141 let response = String::from_utf8(buffer)?;
143
144 if !response.contains(SWITCHING_PROTOCOLS) {
146 return Err(HandshakeError::NoUpgrade);
147 }
148
149 let expected_accept_value = generate_websocket_accept_value(client_websocket_key);
151 if !response.contains(&expected_accept_value) {
152 return Err(HandshakeError::InvalidAcceptKey);
153 }
154
155 second_stage_handshake(StreamKind::Client, buf_reader, writer).await
156}
157
158async fn header_read<T: AsyncReadExt + Unpin>(buf_reader: &mut T) -> Option<String> {
164 let mut websocket_header: Option<String> = None;
165 let mut websocket_accept: Option<String> = None;
166 let mut header_buf = BytesMut::with_capacity(1024 * 16); while header_buf.len() <= 1024 * 16 {
170 let mut tmp_buf = vec![0; 1024];
171 match timeout(Duration::from_secs(10), buf_reader.read(&mut tmp_buf)).await {
172 Ok(Ok(0)) | Err(_) => break, Ok(Ok(n)) => {
175 header_buf.extend_from_slice(&tmp_buf[..n]);
176 let s = String::from_utf8_lossy(&header_buf);
177 if let Some(start) = s.find(SEC_WEBSOCKETS_KEY) {
178 websocket_header = Some(s[start..].lines().next().unwrap().to_string());
179 break;
180 }
181 }
182 _ => {}
183 }
184 }
185
186 if let Some(header) = websocket_header {
187 if let Some(key) = parse_websocket_key(header) {
188 websocket_accept = Some(generate_websocket_accept_value(key));
189 }
190 }
191
192 websocket_accept
193}
194
195fn parse_websocket_key(header: String) -> Option<String> {
196 for line in header.lines() {
197 if line.starts_with(SEC_WEBSOCKETS_KEY) {
198 if let Some(stripped) = line.strip_prefix(SEC_WEBSOCKETS_KEY) {
199 return stripped.split_whitespace().next().map(ToOwned::to_owned);
200 }
201 }
202 }
203 None
204}
205
206fn generate_websocket_accept_value(key: String) -> String {
207 let mut sha1 = Sha1::new();
208 sha1.update(key.as_bytes());
209 sha1.update(UUID.as_bytes());
210 BASE64_STANDARD.encode(sha1.finalize())
211}
212
213fn generate_websocket_key() -> String {
214 let random_bytes: [u8; 16] = random();
215 BASE64_STANDARD.encode(random_bytes)
216}