1use crate::config::{ClientConfig, WebSocketConfig};
2use crate::connection::WSConnection;
3use crate::decoder::Decoder;
4use crate::encoder::Encoder;
5use crate::error::Error;
6use crate::extensions::{add_extension_headers, merge_extensions, parse_extensions, Extensions};
7use crate::message::Message;
8use crate::read::ReadStream;
9use crate::request::{construct_http_request, HttpRequest};
10use crate::split::{WSReader, WSWriter};
11use crate::stream::SocketFlowStream;
12use crate::utils::{generate_websocket_accept_value, generate_websocket_key};
13use crate::write::{Writer, WriterKind};
14use std::fs::File;
15use std::io::BufReader as SyncBufReader;
16use std::path::Path;
17use std::sync::Arc;
18use tokio::io::{split, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
19use tokio::net::TcpStream;
20use tokio::sync::mpsc::channel;
21use tokio::sync::Mutex;
22use tokio_rustls::{TlsConnector, TlsStream};
23use tokio_stream::wrappers::ReceiverStream;
24
/// Template for the server's `101 Switching Protocols` handshake response.
///
/// The `{}` placeholder is replaced with the computed `Sec-WebSocket-Accept`
/// value. The template stops right after that header's CRLF so more headers
/// (negotiated extensions) can be appended; the blank line that terminates
/// the header block is presumably added by `add_extension_headers` — confirm.
pub(crate) const HTTP_ACCEPT_RESPONSE: &str = concat!(
    "HTTP/1.1 101 Switching Protocols\r\n",
    "Connection: Upgrade\r\n",
    "Upgrade: websocket\r\n",
    "Sec-WebSocket-Accept: {}\r\n",
);

/// The only HTTP method accepted for an opening handshake.
const HTTP_METHOD: &str = "GET";

// Header names used for lookups; they are lower-case, presumably because the
// request parser normalises header names — confirm in `HttpRequest`.
pub(crate) const SEC_WEBSOCKET_KEY: &str = "sec-websocket-key";
pub(crate) const SEC_WEBSOCKET_EXTENSIONS: &str = "sec-websocket-extensions";
pub(crate) const SEC_WEBSOCKET_ACCEPT: &str = "sec-websocket-accept";
const HOST: &str = "host";
36
37pub type Result = std::result::Result<WSConnection, Error>;
38
39pub async fn accept_async(stream: SocketFlowStream) -> Result {
45 accept_async_with_config(stream, None).await
46}
47
48pub async fn accept_async_with_config(
50 stream: SocketFlowStream,
51 config: Option<WebSocketConfig>,
52) -> Result {
53 let (reader, mut write_half) = split(stream);
54 let mut buf_reader = BufReader::new(reader);
55
56 let mut config = config.unwrap_or_default();
57 let parsed_extensions =
58 parse_handshake_server(&mut buf_reader, &mut write_half, config.extensions).await?;
59 config.extensions = parsed_extensions;
60
61 let decoder_extensions = config.extensions.clone().unwrap_or_default();
62 let decoder = Decoder::new(
65 decoder_extensions
66 .client_no_context_takeover
67 .unwrap_or_default(),
68 decoder_extensions.client_max_window_bits,
69 );
70
71 let encoder_extensions = config.extensions.clone().unwrap_or_default();
72 let encoder = Encoder::new(
73 encoder_extensions
74 .server_no_context_takeover
75 .unwrap_or_default(),
76 encoder_extensions.server_max_window_bits,
77 );
78
79 second_stage_handshake(
81 buf_reader,
82 write_half,
83 WriterKind::Server,
84 config,
85 decoder,
86 encoder,
87 )
88 .await
89}
90
91async fn second_stage_handshake(
92 buf_reader: BufReader<ReadHalf<SocketFlowStream>>,
93 write_half: WriteHalf<SocketFlowStream>,
94 kind: WriterKind,
95 config: WebSocketConfig,
96 decoder: Decoder,
97 encoder: Encoder,
98) -> Result {
99 let writer = Arc::new(Mutex::new(Writer::new(write_half, kind)));
102
103 let stream_writer = writer.clone();
104
105 let (read_tx, read_rx) = channel::<std::result::Result<Message, Error>>(20);
108 let mut read_stream =
109 ReadStream::new(buf_reader, read_tx, stream_writer, config.clone(), decoder);
110
111 let connection_writer = writer.clone();
112 let receiver_stream = ReceiverStream::new(read_rx);
115
116 let ws_connection = WSConnection::new(
120 WSWriter::new(connection_writer, config, encoder),
121 WSReader::new(receiver_stream),
122 );
123
124 tokio::spawn(async move {
131 if let Err(err) = read_stream.poll_messages().await {
132 let _ = read_stream.read_tx.send(Err(err)).await;
133 }
134 });
135
136 Ok(ws_connection)
137}
138
139pub async fn connect_async(addr: &str) -> Result {
145 connect_async_with_config(addr, None).await
146}
147
148pub async fn connect_async_with_config(addr: &str, client_config: Option<ClientConfig>) -> Result {
150 let client_websocket_key = generate_websocket_key();
151
152 let client_extensions = client_config.clone().unwrap_or_default().web_socket_config.extensions;
153
154 let (request, hostname, host, use_tls) = construct_http_request(addr, &client_websocket_key, client_extensions)?;
155
156 let stream = TcpStream::connect(hostname).await?;
157
158 let maybe_ca_file = client_config.clone().unwrap_or_default().ca_file;
159 let maybe_tls = if use_tls {
160 let mut root_cert_store = rustls::RootCertStore::empty();
162
163 if let Some(file) = maybe_ca_file {
169 let mut pem = SyncBufReader::new(File::open(Path::new(file.as_str()))?);
170 for cert in rustls_pemfile::certs(&mut pem) {
171 root_cert_store.add(cert?).unwrap();
172 }
173 } else {
174 root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
181 }
182
183 let config = rustls::ClientConfig::builder()
184 .with_root_certificates(root_cert_store)
185 .with_no_client_auth();
186 let connector = TlsConnector::from(Arc::new(config));
187
188 let domain = pki_types::ServerName::try_from(host)?;
189 let tls_stream = connector.connect(domain, stream).await?;
190 SocketFlowStream::Secure(TlsStream::from(tls_stream))
191 } else {
192 SocketFlowStream::Plain(stream)
193 };
194
195 let (reader, mut write_half) = split(maybe_tls);
196 let mut buf_reader = BufReader::new(reader);
197
198 write_half.write_all(request.as_bytes()).await?;
199
200 let mut config = client_config.unwrap_or_default().web_socket_config;
201 let extensions = parse_handshake_client(&mut buf_reader, client_websocket_key).await?;
202 config.extensions = extensions;
203
204 let decoder_extensions = config.extensions.clone().unwrap_or_default();
205 let decoder = Decoder::new(
208 decoder_extensions
209 .client_no_context_takeover
210 .unwrap_or_default(),
211 decoder_extensions.client_max_window_bits,
212 );
213
214 let encoder_extensions = config.extensions.clone().unwrap_or_default();
215 let encoder = Encoder::new(
216 encoder_extensions
217 .server_no_context_takeover
218 .unwrap_or_default(),
219 encoder_extensions.server_max_window_bits,
220 );
221
222 second_stage_handshake(
223 buf_reader,
224 write_half,
225 WriterKind::Client,
226 config,
227 decoder,
228 encoder,
229 )
230 .await
231}
232
233async fn parse_handshake_server(
234 buf_reader: &mut BufReader<ReadHalf<SocketFlowStream>>,
235 write_half: &mut WriteHalf<SocketFlowStream>,
236 server_extensions: Option<Extensions>,
237) -> std::result::Result<Option<Extensions>, Error> {
238 let mut req = HttpRequest::parse_http_request(buf_reader).await?;
239
240 if !req.method.eq(HTTP_METHOD) {
242 return Err(Error::InvalidHTTPHandshake);
243 }
244
245 if req.get_header_value(HOST).is_none() {
246 return Err(Error::NoHostHeaderPresent);
247 }
248
249 let sec_websocket_key = match req.get_header_value(SEC_WEBSOCKET_KEY) {
250 Some(key) => key.to_string(),
251 None => Err(Error::NoSecWebsocketKey)?,
252 };
253
254 let client_extensions = parse_extensions(
255 req.get_header_value(SEC_WEBSOCKET_EXTENSIONS)
256 .unwrap_or_default(),
257 );
258 let agreed_extensions = merge_extensions(server_extensions, client_extensions);
259
260 let accept_key = generate_websocket_accept_value(sec_websocket_key);
261
262 let mut response = HTTP_ACCEPT_RESPONSE.replace("{}", &accept_key);
263 add_extension_headers(&mut response, agreed_extensions.clone());
264
265 write_half
266 .write_all(response.as_bytes())
267 .await
268 .map_err(|source| Error::IOError { source })?;
269 write_half.flush().await?;
270
271 Ok(agreed_extensions)
272}
273
274async fn parse_handshake_client(
275 buf_reader: &mut BufReader<ReadHalf<SocketFlowStream>>,
276 client_websocket_key: String,
277) -> std::result::Result<Option<Extensions>, Error> {
278 let mut req = HttpRequest::parse_http_request(buf_reader).await?;
279
280 let expected_accept_value = generate_websocket_accept_value(client_websocket_key);
281
282 let sec_websocket_accept = req.get_header_value(SEC_WEBSOCKET_ACCEPT).unwrap_or_default();
286 if !sec_websocket_accept.contains(&expected_accept_value) {
287 return Err(Error::InvalidAcceptKey);
288 }
289
290 let extensions = parse_extensions(
291 req.get_header_value(SEC_WEBSOCKET_EXTENSIONS)
292 .unwrap_or_default(),
293 );
294
295 Ok(extensions)
296}