// simple_websocket/handshake.rs

1use crate::connection::WSConnection;
2use crate::error::{HandshakeError, StreamError};
3use crate::frame::Frame;
4use crate::read::{ReadStream, StreamKind};
5use crate::write::WriteStream;
6use base64::prelude::BASE64_STANDARD;
7use base64::prelude::*;
8use bytes::BytesMut;
9use rand::random;
10use sha1::{Digest, Sha1};
11use std::sync::Arc;
12use tokio::io::{split, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
13use tokio::net::TcpStream;
14use tokio::sync::mpsc::channel;
15use tokio::sync::Mutex;
16use tokio::time::{timeout, Duration};
17
/// Request header name, including its trailing colon, that carries the client's handshake key.
const SEC_WEBSOCKETS_KEY: &str = "Sec-WebSocket-Key:";
/// GUID appended to the client key before hashing it into the `Sec-WebSocket-Accept` value.
const UUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// Status code and reason phrase the client expects from a server that accepts the upgrade.
const SWITCHING_PROTOCOLS: &str = "101 Switching Protocols";

/// Server reply accepting the upgrade; the `{}` placeholder receives the computed accept value.
const HTTP_ACCEPT_RESPONSE: &str = concat!(
    "HTTP/1.1 101 Switching Protocols\r\n",
    "Connection: Upgrade\r\n",
    "Upgrade: websocket\r\n",
    "Sec-WebSocket-Accept: {}\r\n",
    "\r\n"
);

/// Client upgrade request; the `{host}` and `{key}` placeholders are substituted before sending.
// NOTE(review): this offers `permessage-deflate`, but the client handshake never inspects the
// server's `Sec-WebSocket-Extensions` reply. Confirm the frame reader can handle compressed
// (RSV1) frames, or drop the offer.
const HTTP_HANDSHAKE_REQUEST: &str = concat!(
    "GET / HTTP/1.1\r\n",
    "Host: {host}\r\n",
    "Connection: Upgrade\r\n",
    "Upgrade: websocket\r\n",
    "Sec-WebSocket-Key: {key}\r\n",
    "Sec-WebSocket-Version: 13\r\n",
    "Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\n",
    "\r\n"
);
36pub type Result = std::result::Result<WSConnection, HandshakeError>;
37
38// Using Send trait because we are going to run the process to read frames from the socket concurrently
39// TCPStream from tokio implements Send
40// Using static, because tokio::spawn returns a JoinHandle, because the spawned task could outilive the
41// lifetime of the function call to tokio::spawn.
42pub async fn perform_handshake<T: AsyncRead + AsyncWrite + Send + 'static>(stream: T) -> Result {
43    let (reader, mut writer) = split(stream);
44    let mut buf_reader = BufReader::new(reader);
45
46    let sec_websockets_accept = header_read(&mut buf_reader).await;
47
48    match sec_websockets_accept {
49        Some(accept_value) => {
50            let response = HTTP_ACCEPT_RESPONSE.replace("{}", &accept_value);
51            writer
52                .write_all(response.as_bytes())
53                .await
54                .map_err(|source| HandshakeError::IOError { source })?
55        }
56        None => Err(HandshakeError::NoSecWebsocketKey)?,
57    }
58
59    second_stage_handshake(StreamKind::Server, buf_reader, writer).await
60}
61
62async fn second_stage_handshake<
63    R: AsyncReadExt + Send + Unpin + 'static,
64    W: AsyncWriteExt + Send + Unpin + 'static,
65>(
66    kind: StreamKind,
67    buf_reader: R,
68    writer: W,
69) -> Result {
70    // We are using tokio async channels to communicate the frames received from the client
71    // and another channel to send messages from server to client
72    // TODO - Check if 20 is a good number for Buffer size, remembering that channel is async, so if it's full
73    // all the callers that are trying to add new data, will be blocked until we have free space (off course, using await in the method)
74    let (write_tx, write_rx) = channel::<Frame>(20);
75
76    let (read_tx, read_rx) = channel::<std::result::Result<Vec<u8>, StreamError>>(20);
77    let read_tx = Arc::new(Mutex::new(read_tx));
78
79    // These internal channels are used to communicate between write and read stream
80    let (internal_tx, internal_rx) = channel::<Frame>(20);
81    let (close_tx, close_rx) = channel::<bool>(1);
82
83    let read_tx_stream = read_tx.clone();
84    // We are separating the stream in read and write, because handling them in the same struct, would need us to
85    // wrap some references with Arc<mutex>, and for the sake of a clean syntax, we selected to split it
86    let mut read_stream = ReadStream::new(kind, buf_reader, read_tx_stream, internal_tx, close_tx);
87    let mut write_stream = WriteStream::new(writer, write_rx, internal_rx);
88
89    let ws_connection = WSConnection::new(read_rx, write_tx, close_rx);
90
91    let read_tx_r = read_tx.clone();
92    // We are spawning poll_messages which is the method for reading the frames from the socket
93    // we need to do it concurrently, because we need this method running, while the end-user can have
94    // a channel returned, for receiving and sending messages
95    // Since ReadHalf and WriteHalf implements Send and Sync, it's ok to send them over spawn
96    // Additionally, since our BufReader doesn't change, we only call read methods from it, there is no
97    // need to wrap it in an Arc<Mutex>, also because poll_messages read frames sequentially.
98    // Also, since this is the only task that holds the ownership of BufReader, if some IO error happens,
99    // poll_messages will return, and since BufReader is only inside the scope of the function, it will be dropped
100    // dropping the WriteHalf, hence, the TCP connection
101
102    tokio::spawn(async move {
103        if let Err(err) = read_stream.poll_messages().await {
104            read_tx_r.lock().await.send(Err(err)).await.unwrap();
105        }
106    });
107
108    let read_tx_w = read_tx.clone();
109    tokio::spawn(async move {
110        if let Err(err) = write_stream.run().await {
111            read_tx_w.lock().await.send(Err(err)).await.unwrap()
112        }
113        drop(read_tx_w);
114    });
115
116    Ok(ws_connection)
117}
118
119pub async fn perform_client_handshake(stream: TcpStream) -> Result {
120    let client_websocket_key = generate_websocket_key();
121    let request = HTTP_HANDSHAKE_REQUEST
122        .replace("{key}", &client_websocket_key)
123        .replace("{host}", &stream.local_addr().unwrap().to_string());
124
125    let (reader, mut writer) = split(stream);
126    let mut buf_reader = BufReader::new(reader);
127
128    writer.write_all(request.as_bytes()).await?;
129
130    // Create a buffer for the server's response, since most of the Websocket won't send a big payload
131    // for the handshake response, defining this size of Vector would be enough, and also will put a limit
132    // to bigger payloads
133    let mut buffer: Vec<u8> = vec![0; 1024];
134
135    // Read the server's response
136    let number_read = buf_reader.read(&mut buffer).await?;
137
138    // Keep only the section of the buffer that was filled.
139    buffer.truncate(number_read);
140
141    // Convert the server's response to a string
142    let response = String::from_utf8(buffer)?;
143
144    // Verify that the server agreed to upgrade the connection
145    if !response.contains(SWITCHING_PROTOCOLS) {
146        return Err(HandshakeError::NoUpgrade);
147    }
148
149    // Generate the server expected accept key using UUID, and checking if it's present in the response
150    let expected_accept_value = generate_websocket_accept_value(client_websocket_key);
151    if !response.contains(&expected_accept_value) {
152        return Err(HandshakeError::InvalidAcceptKey);
153    }
154
155    second_stage_handshake(StreamKind::Client, buf_reader, writer).await
156}
157
158// Here we are using the generic T, and expressing its two tokio traits, to avoiding adding the
159// entire type of the argument in the function signature (BufReader<ReadHalf<TcpStream>>)
160// The Unpin trait in Rust is used when the exact location of an object in memory needs to remain
161// constant after being pinned. In simple terms, it means that the object doesn't move around in memory
162// Here, we need to use Unpin, because the timeout function puts the passed Future into a Pin<Box<dyn Future>>
163async fn header_read<T: AsyncReadExt + Unpin>(buf_reader: &mut T) -> Option<String> {
164    let mut websocket_header: Option<String> = None;
165    let mut websocket_accept: Option<String> = None;
166    let mut header_buf = BytesMut::with_capacity(1024 * 16); // 16 kilobytes
167
168    // Limit the maximum amount of data read to prevent a denial of service attack.
169    while header_buf.len() <= 1024 * 16 {
170        let mut tmp_buf = vec![0; 1024];
171        match timeout(Duration::from_secs(10), buf_reader.read(&mut tmp_buf)).await {
172            Ok(Ok(0)) | Err(_) => break, // EOF reached or Timeout, we stop. In the case of EOF
173            // there is no need to log or return EOF or timeout errors
174            Ok(Ok(n)) => {
175                header_buf.extend_from_slice(&tmp_buf[..n]);
176                let s = String::from_utf8_lossy(&header_buf);
177                if let Some(start) = s.find(SEC_WEBSOCKETS_KEY) {
178                    websocket_header = Some(s[start..].lines().next().unwrap().to_string());
179                    break;
180                }
181            }
182            _ => {}
183        }
184    }
185
186    if let Some(header) = websocket_header {
187        if let Some(key) = parse_websocket_key(header) {
188            websocket_accept = Some(generate_websocket_accept_value(key));
189        }
190    }
191
192    websocket_accept
193}
194
195fn parse_websocket_key(header: String) -> Option<String> {
196    for line in header.lines() {
197        if line.starts_with(SEC_WEBSOCKETS_KEY) {
198            if let Some(stripped) = line.strip_prefix(SEC_WEBSOCKETS_KEY) {
199                return stripped.split_whitespace().next().map(ToOwned::to_owned);
200            }
201        }
202    }
203    None
204}
205
206fn generate_websocket_accept_value(key: String) -> String {
207    let mut sha1 = Sha1::new();
208    sha1.update(key.as_bytes());
209    sha1.update(UUID.as_bytes());
210    BASE64_STANDARD.encode(sha1.finalize())
211}
212
213fn generate_websocket_key() -> String {
214    let random_bytes: [u8; 16] = random();
215    BASE64_STANDARD.encode(random_bytes)
216}