socket_flow/
handshake.rs

1use crate::config::{ClientConfig, WebSocketConfig};
2use crate::connection::WSConnection;
3use crate::decoder::Decoder;
4use crate::encoder::Encoder;
5use crate::error::Error;
6use crate::extensions::{add_extension_headers, merge_extensions, parse_extensions, Extensions};
7use crate::message::Message;
8use crate::read::ReadStream;
9use crate::request::{construct_http_request, HttpRequest};
10use crate::split::{WSReader, WSWriter};
11use crate::stream::SocketFlowStream;
12use crate::utils::{generate_websocket_accept_value, generate_websocket_key};
13use crate::write::{Writer, WriterKind};
14use std::fs::File;
15use std::io::BufReader as SyncBufReader;
16use std::path::Path;
17use std::sync::Arc;
18use tokio::io::{split, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
19use tokio::net::TcpStream;
20use tokio::sync::mpsc::channel;
21use tokio::sync::Mutex;
22use tokio_rustls::{TlsConnector, TlsStream};
23use tokio_stream::wrappers::ReceiverStream;
24
/// Status line and fixed headers of the server's `101 Switching Protocols` reply.
///
/// The `{}` placeholder is substituted with the computed `Sec-WebSocket-Accept` value.
/// The literal ends after a single CRLF, without the blank line that terminates an HTTP
/// header block. The rest of the response is appended by the caller (presumably by
/// `add_extension_headers` — confirm there).
pub(crate) const HTTP_ACCEPT_RESPONSE: &str = concat!(
    "HTTP/1.1 101 Switching Protocols\r\n",
    "Connection: Upgrade\r\n",
    "Upgrade: websocket\r\n",
    "Sec-WebSocket-Accept: {}\r\n",
);

/// The only HTTP method accepted for an opening handshake.
const HTTP_METHOD: &str = "GET";
/// Lowercase header names used when looking up handshake headers.
pub(crate) const SEC_WEBSOCKET_KEY: &str = "sec-websocket-key";
pub(crate) const SEC_WEBSOCKET_EXTENSIONS: &str = "sec-websocket-extensions";
pub(crate) const SEC_WEBSOCKET_ACCEPT: &str = "sec-websocket-accept";
const HOST: &str = "host";
36
37pub type Result = std::result::Result<WSConnection, Error>;
38
39/// Used for accepting websocket connections as a server.
40///
41/// It basically does the first step of verifying the client key in the request
42/// going to the second step, which is sending the acceptance response,
43/// finally creating the connection, and returning a `WSConnection`.
44pub async fn accept_async(stream: SocketFlowStream) -> Result {
45    accept_async_with_config(stream, None).await
46}
47
48/// Same as accept_async, with an additional argument for custom websocket connection configurations.
49pub async fn accept_async_with_config(
50    stream: SocketFlowStream,
51    config: Option<WebSocketConfig>,
52) -> Result {
53    let (reader, mut write_half) = split(stream);
54    let mut buf_reader = BufReader::new(reader);
55
56    let mut config = config.unwrap_or_default();
57    let parsed_extensions =
58        parse_handshake_server(&mut buf_reader, &mut write_half, config.extensions).await?;
59    config.extensions = parsed_extensions;
60
61    let decoder_extensions = config.extensions.clone().unwrap_or_default();
62    // The decoder will be reading and decompressing all client messages,
63    // so we need to pass all the client extensions to it
64    let decoder = Decoder::new(
65        decoder_extensions
66            .client_no_context_takeover
67            .unwrap_or_default(),
68        decoder_extensions.client_max_window_bits,
69    );
70
71    let encoder_extensions = config.extensions.clone().unwrap_or_default();
72    let encoder = Encoder::new(
73        encoder_extensions
74            .server_no_context_takeover
75            .unwrap_or_default(),
76        encoder_extensions.server_max_window_bits,
77    );
78
79    // Identify permessage-deflate for enabling compression
80    second_stage_handshake(
81        buf_reader,
82        write_half,
83        WriterKind::Server,
84        config,
85        decoder,
86        encoder,
87    )
88        .await
89}
90
91async fn second_stage_handshake(
92    buf_reader: BufReader<ReadHalf<SocketFlowStream>>,
93    write_half: WriteHalf<SocketFlowStream>,
94    kind: WriterKind,
95    config: WebSocketConfig,
96    decoder: Decoder,
97    encoder: Encoder,
98) -> Result {
99    // This writer instance would be used for writing frames into the socket.
100    // Since it's going to be used by two different instances, we need to wrap it through an Arc
101    let writer = Arc::new(Mutex::new(Writer::new(write_half, kind)));
102
103    let stream_writer = writer.clone();
104
105    // ReadStream will be running on a separate task, capturing all the incoming frames from the connection, and broadcasting them through this
106    // tokio mpsc channel. Therefore, it can be consumed by the end-user of this library
107    let (read_tx, read_rx) = channel::<std::result::Result<Message, Error>>(20);
108    let mut read_stream =
109        ReadStream::new(buf_reader, read_tx, stream_writer, config.clone(), decoder);
110
111    let connection_writer = writer.clone();
112    // Transforming the receiver of the channel into a Stream, so we could leverage using
113    // next() method, for processing the values from this channel
114    let receiver_stream = ReceiverStream::new(read_rx);
115
116    // The WSConnection is the structure that will be delivered to the end-user, which contains
117    // a stream of frames, for consuming the incoming frames, and methods for writing frames into
118    // the socket
119    let ws_connection = WSConnection::new(
120        WSWriter::new(connection_writer, config, encoder),
121        WSReader::new(receiver_stream),
122    );
123
124    // Spawning poll_messages which is the method for reading the frames from the socket concurrently,
125    // because we need this method running, while the end-user can have
126    // a connection returned, for receiving and sending messages.
127    // Since this is the only task that holds the ownership of BufReader, if some IO error happens,
128    // poll_messages will return.
129    // BufReader will be dropped, hence, the writeHalf and TCP connection
130    tokio::spawn(async move {
131        if let Err(err) = read_stream.poll_messages().await {
132            let _ = read_stream.read_tx.send(Err(err)).await;
133        }
134    });
135
136    Ok(ws_connection)
137}
138
139/// Used for connecting as a client to a websocket endpoint.
140///
141/// It basically does the first step of generating the client key
142/// going to the second step, which is parsing the server response,
143/// finally creating the connection, and returning a `WSConnection`.
144pub async fn connect_async(addr: &str) -> Result {
145    connect_async_with_config(addr, None).await
146}
147
148/// Same as connect_async, with an additional argument for custom websocket connection configurations.
149pub async fn connect_async_with_config(addr: &str, client_config: Option<ClientConfig>) -> Result {
150    let client_websocket_key = generate_websocket_key();
151
152    let client_extensions = client_config.clone().unwrap_or_default().web_socket_config.extensions;
153
154    let (request, hostname, host, use_tls) = construct_http_request(addr, &client_websocket_key, client_extensions)?;
155
156    let stream = TcpStream::connect(hostname).await?;
157
158    let maybe_ca_file = client_config.clone().unwrap_or_default().ca_file;
159    let maybe_tls = if use_tls {
160        // Creating a cert store, to inject the TLS certificates
161        let mut root_cert_store = rustls::RootCertStore::empty();
162
163        // In the case you are using self-signed certificates on the server
164        // you are trying to connect, you must indicate a CA certificate of this server
165        // when connecting to it.
166        // Since the server has a self-signed cert, the only way of this library validating
167        // the cert is adding as an argument of the connect_async function
168        if let Some(file) = maybe_ca_file {
169            let mut pem = SyncBufReader::new(File::open(Path::new(file.as_str()))?);
170            for cert in rustls_pemfile::certs(&mut pem) {
171                root_cert_store.add(cert?).unwrap();
172            }
173        } else {
174            // Here we are adding TLS_SERVER_ROOTS to the certificate store,
175            // which is basically a reference to a list of trusted root certificates
176            // issue by a CA.
177            // In the case, you are establishing a connection with a server
178            // that has a valid trusted certificate.
179            // You won't need a CA file
180            root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
181        }
182
183        let config = rustls::ClientConfig::builder()
184            .with_root_certificates(root_cert_store)
185            .with_no_client_auth();
186        let connector = TlsConnector::from(Arc::new(config));
187
188        let domain = pki_types::ServerName::try_from(host)?;
189        let tls_stream = connector.connect(domain, stream).await?;
190        SocketFlowStream::Secure(TlsStream::from(tls_stream))
191    } else {
192        SocketFlowStream::Plain(stream)
193    };
194
195    let (reader, mut write_half) = split(maybe_tls);
196    let mut buf_reader = BufReader::new(reader);
197
198    write_half.write_all(request.as_bytes()).await?;
199
200    let mut config = client_config.unwrap_or_default().web_socket_config;
201    let extensions = parse_handshake_client(&mut buf_reader, client_websocket_key).await?;
202    config.extensions = extensions;
203
204    let decoder_extensions = config.extensions.clone().unwrap_or_default();
205    // The decoder will be reading and decompressing all client messages,
206    // so we need to pass all the client extensions to it
207    let decoder = Decoder::new(
208        decoder_extensions
209            .client_no_context_takeover
210            .unwrap_or_default(),
211        decoder_extensions.client_max_window_bits,
212    );
213
214    let encoder_extensions = config.extensions.clone().unwrap_or_default();
215    let encoder = Encoder::new(
216        encoder_extensions
217            .server_no_context_takeover
218            .unwrap_or_default(),
219        encoder_extensions.server_max_window_bits,
220    );
221
222    second_stage_handshake(
223        buf_reader,
224        write_half,
225        WriterKind::Client,
226        config,
227        decoder,
228        encoder,
229    )
230        .await
231}
232
233async fn parse_handshake_server(
234    buf_reader: &mut BufReader<ReadHalf<SocketFlowStream>>,
235    write_half: &mut WriteHalf<SocketFlowStream>,
236    server_extensions: Option<Extensions>,
237) -> std::result::Result<Option<Extensions>, Error> {
238    let mut req = HttpRequest::parse_http_request(buf_reader).await?;
239
240    // Validate the WebSocket handshake
241    if !req.method.eq(HTTP_METHOD) {
242        return Err(Error::InvalidHTTPHandshake);
243    }
244
245    if req.get_header_value(HOST).is_none() {
246        return Err(Error::NoHostHeaderPresent);
247    }
248
249    let sec_websocket_key = match req.get_header_value(SEC_WEBSOCKET_KEY) {
250        Some(key) => key.to_string(),
251        None => Err(Error::NoSecWebsocketKey)?,
252    };
253
254    let client_extensions = parse_extensions(
255        req.get_header_value(SEC_WEBSOCKET_EXTENSIONS)
256            .unwrap_or_default(),
257    );
258    let agreed_extensions = merge_extensions(server_extensions, client_extensions);
259
260    let accept_key = generate_websocket_accept_value(sec_websocket_key);
261
262    let mut response = HTTP_ACCEPT_RESPONSE.replace("{}", &accept_key);
263    add_extension_headers(&mut response, agreed_extensions.clone());
264
265    write_half
266        .write_all(response.as_bytes())
267        .await
268        .map_err(|source| Error::IOError { source })?;
269    write_half.flush().await?;
270
271    Ok(agreed_extensions)
272}
273
274async fn parse_handshake_client(
275    buf_reader: &mut BufReader<ReadHalf<SocketFlowStream>>,
276    client_websocket_key: String,
277) -> std::result::Result<Option<Extensions>, Error> {
278    let mut req = HttpRequest::parse_http_request(buf_reader).await?;
279
280    let expected_accept_value = generate_websocket_accept_value(client_websocket_key);
281
282    // Some websockets server returns the SEC_WEBSOCKET_ACCEPT header, as lowercase.
283    // Therefore, we need to cover both cases, for the sake of having support, even though it's
284    // out of RFC standards
285    let sec_websocket_accept =  req.get_header_value(SEC_WEBSOCKET_ACCEPT).unwrap_or_default();
286    if !sec_websocket_accept.contains(&expected_accept_value) {
287        return Err(Error::InvalidAcceptKey);
288    }
289
290    let extensions = parse_extensions(
291        req.get_header_value(SEC_WEBSOCKET_EXTENSIONS)
292            .unwrap_or_default(),
293    );
294
295    Ok(extensions)
296}