1#![warn(missing_docs)]
4#![cfg_attr(docrs, feature(doc_auto_cfg))]
5
6use std::collections::HashMap;
7
8pub mod errors;
10pub mod frame;
12pub mod protocol;
14
15pub mod codec;
17pub mod connector;
19
20mod message;
22pub use message::*;
23#[cfg(feature = "simple")]
24pub mod simple;
26#[cfg(feature = "simple")]
27pub use simple::ClientConfig;
28
29pub mod stream;
31
32pub mod extension;
34
/// Configuration for the client side of a websocket handshake.
///
/// Values are set through chained, by-value setter methods and are read by the
/// feature-gated connect methods (`connect`, `with_stream`, `async_connect`, ...).
#[derive(Debug, Clone)]
pub struct ClientBuilder {
    /// Subprotocol names passed to the handshake request builder.
    /// NOTE(review): presumably emitted as `Sec-WebSocket-Protocol` — confirm in `protocol`.
    protocols: Vec<String>,
    /// Extension names passed to the handshake request builder.
    /// NOTE(review): presumably emitted as `Sec-WebSocket-Extensions` — confirm in `protocol`.
    extensions: Vec<String>,
    /// Handshake version number passed to the request builder; `Default` sets 13.
    // Only read by handshake code compiled under the `sync`/`async` features, so
    // without either feature this field would be reported as dead code.
    #[cfg_attr(not(any(feature = "sync", feature = "async")), allow(dead_code))]
    version: u8,
    /// Extra headers; a clone of this map is passed to every handshake request.
    headers: HashMap<String, String>,
}
44
45impl Default for ClientBuilder {
46 fn default() -> Self {
47 Self {
48 protocols: vec![],
49 extensions: vec![],
50 headers: HashMap::new(),
51 version: 13,
52 }
53 }
54}
55
56impl ClientBuilder {
57 pub fn new() -> Self {
59 Default::default()
60 }
61
62 pub fn protocol(mut self, protocol: String) -> Self {
64 self.protocols.push(protocol);
65 self
66 }
67
68 pub fn protocols(self, protocols: Vec<String>) -> Self {
72 Self { protocols, ..self }
73 }
74
75 pub fn extension(mut self, extension: String) -> Self {
77 self.extensions.push(extension);
78 self
79 }
80
81 pub fn extensions(self, extensions: Vec<String>) -> Self {
85 Self { extensions, ..self }
86 }
87
88 pub fn version(self, version: u8) -> Self {
90 Self { version, ..self }
91 }
92
93 pub fn header<K: ToString, V: ToString>(mut self, name: K, value: V) -> Self {
95 self.headers.insert(name.to_string(), value.to_string());
96 self
97 }
98
99 pub fn headers(self, headers: HashMap<String, String>) -> Self {
103 Self { headers, ..self }
104 }
105}
106
#[cfg(feature = "sync")]
mod blocking {
    use std::{
        io::{Read, Write},
        net::TcpStream,
    };

    use crate::{
        connector::{get_scheme, tcp_connect},
        errors::WsError,
        protocol::{handle_handshake, req_handshake},
        ClientBuilder, ServerBuilder,
    };

    impl ClientBuilder {
        /// Connects to `uri` over plain TCP and runs the client handshake.
        ///
        /// `check_fn` receives the key string and the response returned by the
        /// handshake request together with the connected stream; its result is
        /// returned as-is.
        ///
        /// # Errors
        ///
        /// Returns an error if `get_scheme` rejects `uri`, if the TCP connection
        /// or the handshake fails, or if `check_fn` fails.
        ///
        /// # Panics
        ///
        /// Panics when `uri` is classified as `Mode::WSS`; secure URIs must go
        /// through `rustls_connect` or `native_tls_connect`.
        pub fn connect<C, F>(&self, uri: http::Uri, check_fn: F) -> Result<C, WsError>
        where
            F: FnMut(String, http::Response<()>, TcpStream) -> Result<C, WsError>,
        {
            let mode = get_scheme(&uri)?;
            if matches!(mode, crate::protocol::Mode::WSS) {
                panic!("can not perform ssl connection, use `rustls_connect` or `native_tls_connect` instead");
            }
            let stream = tcp_connect(&uri)?;
            self.with_stream(uri, stream, check_fn)
        }

        /// Connects to `uri`, wraps the socket in rustls TLS and runs the client
        /// handshake over it. `check_fn` is used as in `connect`.
        ///
        /// # Errors
        ///
        /// Same as `connect`, plus failures resolving the host or setting up TLS.
        ///
        /// # Panics
        ///
        /// Panics unless `uri` is classified as `Mode::WSS`; plain URIs must go
        /// through `connect`.
        #[cfg(feature = "sync_tls_rustls")]
        pub fn rustls_connect<C, F>(&self, uri: http::Uri, check_fn: F) -> Result<C, WsError>
        where
            F: FnMut(
                String,
                http::Response<()>,
                rustls_connector::rustls::StreamOwned<
                    rustls_connector::rustls::ClientConnection,
                    TcpStream,
                >,
            ) -> Result<C, WsError>,
        {
            use crate::connector::{get_host, wrap_rustls};
            let mode = get_scheme(&uri)?;
            // The socket is always wrapped in TLS below, so only secure URIs
            // belong here; plain ones are handled by `connect`.
            if !matches!(mode, crate::protocol::Mode::WSS) {
                panic!("can not perform not ssl connection, use `connect` instead");
            }
            let stream = tcp_connect(&uri)?;
            let stream = wrap_rustls(stream, get_host(&uri)?, vec![])?;
            self.with_stream(uri, stream, check_fn)
        }

        /// Connects to `uri`, wraps the socket in native-tls and runs the client
        /// handshake over it. `check_fn` is used as in `connect`.
        ///
        /// # Errors
        ///
        /// Same as `connect`, plus failures resolving the host or setting up TLS.
        ///
        /// # Panics
        ///
        /// Panics unless `uri` is classified as `Mode::WSS`; plain URIs must go
        /// through `connect`.
        #[cfg(feature = "sync_tls_native")]
        pub fn native_tls_connect<C, F>(&self, uri: http::Uri, check_fn: F) -> Result<C, WsError>
        where
            F: FnMut(
                String,
                http::Response<()>,
                native_tls::TlsStream<TcpStream>,
            ) -> Result<C, WsError>,
        {
            use crate::connector::{get_host, wrap_native_tls};
            let mode = get_scheme(&uri)?;
            // The socket is always wrapped in TLS below, so only secure URIs
            // belong here; plain ones are handled by `connect`.
            if !matches!(mode, crate::protocol::Mode::WSS) {
                panic!("can not perform not ssl connection, use `connect` instead");
            }
            let stream = tcp_connect(&uri)?;
            let stream = wrap_native_tls(stream, get_host(&uri)?, vec![])?;
            self.with_stream(uri, stream, check_fn)
        }

        /// Runs the client handshake over an already-connected `stream`.
        ///
        /// The builder's protocols, extensions, version and a clone of its headers
        /// are passed to `req_handshake`. The resulting key, the response and the
        /// stream are then handed to `check_fn`, whose result is returned.
        ///
        /// # Errors
        ///
        /// Returns an error if `get_scheme` rejects `uri`, if the handshake fails,
        /// or if `check_fn` fails.
        pub fn with_stream<C, F, S>(
            &self,
            uri: http::Uri,
            mut stream: S,
            mut check_fn: F,
        ) -> Result<C, WsError>
        where
            S: Read + Write,
            F: FnMut(String, http::Response<()>, S) -> Result<C, WsError>,
        {
            // Called only for its error path: a URI `get_scheme` cannot classify
            // is rejected before anything is written to the stream.
            get_scheme(&uri)?;
            let (key, resp) = req_handshake(
                &mut stream,
                &uri,
                &self.protocols,
                &self.extensions,
                self.version,
                self.headers.clone(),
            )?;
            check_fn(key, resp, stream)
        }
    }

    impl ServerBuilder {
        /// Runs the server side of a handshake on an already-accepted `stream`.
        ///
        /// The request parsed by `handle_handshake` is passed to
        /// `handshake_handler`:
        /// * `Ok((req, resp))`: `resp` is written to the stream. If its status is
        ///   `101 Switching Protocols`, `req` and the stream are handed to
        ///   `codec_factory`. Otherwise `WsError::HandShakeFailed` carrying the
        ///   stringified response body is returned.
        /// * `Err((resp, e))`: `resp` is written to the stream and `e` is returned.
        ///
        /// # Errors
        ///
        /// Also fails if reading the request or writing the response fails.
        pub fn accept<F1, F2, T, C, S>(
            mut stream: S,
            mut handshake_handler: F1,
            mut codec_factory: F2,
        ) -> Result<C, WsError>
        where
            S: Read + Write,
            F1: FnMut(
                http::Request<()>,
            ) -> Result<
                (http::Request<()>, http::Response<T>),
                (http::Response<T>, WsError),
            >,
            F2: FnMut(http::Request<()>, S) -> Result<C, WsError>,
            T: ToString + std::fmt::Debug,
        {
            let req = handle_handshake(&mut stream)?;
            match handshake_handler(req) {
                Ok((req, resp)) => {
                    write_resp(resp, &mut stream)?;
                    codec_factory(req, stream)
                }
                Err((resp, e)) => {
                    // A rejection is expected to carry a non-101 status, so skip
                    // the status check and report the handler's own error.
                    send_resp(&resp, &mut stream)?;
                    Err(e)
                }
            }
        }
    }

    /// Serializes the status line and headers of `resp` as an HTTP/1.x response
    /// head terminated by an empty line. The body is not included.
    fn encode_resp<T>(resp: &http::Response<T>) -> Vec<u8> {
        let status = resp.status();
        // Use `canonical_reason` directly: `StatusCode`'s `Display` substitutes
        // "<unknown status code>" for non-standard codes instead of a real
        // reason phrase.
        let mut head = format!(
            "{:?} {} {}\r\n",
            resp.version(),
            status.as_str(),
            status.canonical_reason().unwrap_or("")
        )
        .into_bytes();
        for (name, value) in resp.headers() {
            head.extend_from_slice(name.as_str().as_bytes());
            head.extend_from_slice(b": ");
            // Raw bytes: `HeaderValue::to_str` fails on non-visible-ASCII values,
            // which would otherwise be silently sent as empty.
            head.extend_from_slice(value.as_bytes());
            head.extend_from_slice(b"\r\n");
        }
        head.extend_from_slice(b"\r\n");
        head
    }

    /// Writes the head of `resp` to `stream` and flushes it. Flushing ensures the
    /// peer receives the response even if `stream` buffers writes.
    fn send_resp<S, T>(resp: &http::Response<T>, stream: &mut S) -> Result<(), WsError>
    where
        S: Write,
        T: std::fmt::Debug,
    {
        stream.write_all(&encode_resp(resp))?;
        stream.flush()?;
        tracing::debug!("{:?}", resp);
        Ok(())
    }

    /// Writes `resp` to `stream`. Any status other than `101 Switching Protocols`
    /// is turned into `WsError::HandShakeFailed` carrying the stringified body.
    fn write_resp<S, T>(resp: http::Response<T>, stream: &mut S) -> Result<(), WsError>
    where
        S: Read + Write,
        T: ToString + std::fmt::Debug,
    {
        send_resp(&resp, stream)?;
        if resp.status() != http::StatusCode::SWITCHING_PROTOCOLS {
            return Err(WsError::HandShakeFailed(resp.body().to_string()));
        }
        Ok(())
    }
}
253
#[cfg(feature = "async")]
mod non_blocking {
    use std::fmt::Debug;

    use tokio::{
        io::{AsyncRead, AsyncWrite, AsyncWriteExt},
        net::TcpStream,
    };

    use crate::{
        connector::async_tcp_connect,
        errors::WsError,
        protocol::{async_handle_handshake, async_req_handshake},
        ServerBuilder,
    };

    use super::ClientBuilder;

    impl ClientBuilder {
        /// Connects to `uri` over plain TCP and runs the client handshake.
        ///
        /// `check_fn` receives the key string and the response returned by the
        /// handshake request together with the connected stream; its result is
        /// returned as-is. No TLS wrapping is performed.
        ///
        /// NOTE(review): unlike the blocking `connect`, this function does not
        /// itself reject `Mode::WSS` URIs.
        ///
        /// # Errors
        ///
        /// Returns an error if the TCP connection or the handshake fails, or if
        /// `check_fn` fails.
        pub async fn async_connect<C, F>(&self, uri: http::Uri, check_fn: F) -> Result<C, WsError>
        where
            F: FnMut(String, http::Response<()>, TcpStream) -> Result<C, WsError>,
        {
            let stream = async_tcp_connect(&uri).await?;
            self.async_with_stream(uri, stream, check_fn).await
        }

        /// Connects to `uri`, wraps the socket in rustls TLS and runs the client
        /// handshake over it. `check_fn` is used as in `async_connect`.
        ///
        /// # Errors
        ///
        /// Same as `async_connect`, plus scheme, host-resolution and TLS failures.
        ///
        /// # Panics
        ///
        /// Panics unless `uri` is classified as `Mode::WSS`; plain URIs must go
        /// through `async_connect`.
        #[cfg(feature = "async_tls_rustls")]
        pub async fn async_rustls_connect<C, F>(
            &self,
            uri: http::Uri,
            check_fn: F,
        ) -> Result<C, WsError>
        where
            F: FnMut(
                String,
                http::Response<()>,
                tokio_rustls::client::TlsStream<tokio::net::TcpStream>,
            ) -> Result<C, WsError>,
        {
            use crate::connector::{async_wrap_rustls, get_host};
            let mode = crate::connector::get_scheme(&uri)?;
            // The socket is always wrapped in TLS below, so only secure URIs
            // belong here; plain ones are handled by `async_connect`.
            if !matches!(mode, crate::protocol::Mode::WSS) {
                panic!("can not perform not ssl connection, use `async_connect` instead");
            }
            let stream = async_tcp_connect(&uri).await?;
            let stream = async_wrap_rustls(stream, get_host(&uri)?, vec![]).await?;
            self.async_with_stream(uri, stream, check_fn).await
        }

        /// Connects to `uri`, wraps the socket in native-tls and runs the client
        /// handshake over it. `check_fn` is used as in `async_connect`.
        ///
        /// # Errors
        ///
        /// Same as `async_connect`, plus scheme, host-resolution and TLS failures.
        ///
        /// # Panics
        ///
        /// Panics unless `uri` is classified as `Mode::WSS`; plain URIs must go
        /// through `async_connect`.
        #[cfg(feature = "async_tls_native")]
        pub async fn async_native_tls_connect<C, F>(
            &self,
            uri: http::Uri,
            check_fn: F,
        ) -> Result<C, WsError>
        where
            F: FnMut(
                String,
                http::Response<()>,
                tokio_native_tls::TlsStream<TcpStream>,
            ) -> Result<C, WsError>,
        {
            use crate::connector::{async_wrap_native_tls, get_host};
            let mode = crate::connector::get_scheme(&uri)?;
            // The socket is always wrapped in TLS below, so only secure URIs
            // belong here; plain ones are handled by `async_connect`.
            if !matches!(mode, crate::protocol::Mode::WSS) {
                panic!("can not perform not ssl connection, use `async_connect` instead");
            }
            let stream = async_tcp_connect(&uri).await?;
            let stream = async_wrap_native_tls(stream, get_host(&uri)?, vec![]).await?;
            self.async_with_stream(uri, stream, check_fn).await
        }

        /// Runs the client handshake over an already-connected `stream`.
        ///
        /// The builder's protocols, extensions, version and a clone of its headers
        /// are passed to `async_req_handshake`. The resulting key, the response and
        /// the stream are then handed to `check_fn`, whose result is returned.
        ///
        /// # Errors
        ///
        /// Returns an error if the handshake fails or if `check_fn` fails.
        pub async fn async_with_stream<C, F, S>(
            &self,
            uri: http::Uri,
            mut stream: S,
            mut check_fn: F,
        ) -> Result<C, WsError>
        where
            S: AsyncRead + AsyncWrite + Unpin,
            F: FnMut(String, http::Response<()>, S) -> Result<C, WsError>,
        {
            let (key, resp) = async_req_handshake(
                &mut stream,
                &uri,
                &self.protocols,
                &self.extensions,
                self.version,
                self.headers.clone(),
            )
            .await?;
            check_fn(key, resp, stream)
        }
    }

    impl ServerBuilder {
        /// Runs the server side of a handshake on an already-accepted `stream`.
        ///
        /// The request parsed by `async_handle_handshake` is passed to
        /// `handshake_handler`:
        /// * `Ok((req, resp))`: `resp` is written to the stream. If its status is
        ///   `101 Switching Protocols`, `req` and the stream are handed to
        ///   `codec_factory`. Otherwise `WsError::HandShakeFailed` carrying the
        ///   stringified response body is returned.
        /// * `Err((resp, e))`: `resp` is written to the stream and `e` is returned.
        ///
        /// # Errors
        ///
        /// Also fails if reading the request or writing the response fails.
        pub async fn async_accept<F1, F2, T, C, S>(
            mut stream: S,
            mut handshake_handler: F1,
            mut codec_factory: F2,
        ) -> Result<C, WsError>
        where
            S: AsyncRead + AsyncWrite + Unpin,
            F1: FnMut(
                http::Request<()>,
            ) -> Result<
                (http::Request<()>, http::Response<T>),
                (http::Response<T>, WsError),
            >,
            F2: FnMut(http::Request<()>, S) -> Result<C, WsError>,
            T: ToString + Debug,
        {
            let req = async_handle_handshake(&mut stream).await?;
            match handshake_handler(req) {
                Ok((req, resp)) => {
                    async_write_resp(resp, &mut stream).await?;
                    codec_factory(req, stream)
                }
                Err((resp, e)) => {
                    // A rejection is expected to carry a non-101 status, so skip
                    // the status check and report the handler's own error.
                    async_send_resp(resp, &mut stream).await?;
                    Err(e)
                }
            }
        }
    }

    /// Serializes the status line and headers of `resp` as an HTTP/1.x response
    /// head terminated by an empty line. The body is not included.
    fn encode_resp<T>(resp: &http::Response<T>) -> Vec<u8> {
        let status = resp.status();
        // Use `canonical_reason` directly: `StatusCode`'s `Display` substitutes
        // "<unknown status code>" for non-standard codes instead of a real
        // reason phrase.
        let mut head = format!(
            "{:?} {} {}\r\n",
            resp.version(),
            status.as_str(),
            status.canonical_reason().unwrap_or("")
        )
        .into_bytes();
        for (name, value) in resp.headers() {
            head.extend_from_slice(name.as_str().as_bytes());
            head.extend_from_slice(b": ");
            // Raw bytes: `HeaderValue::to_str` fails on non-visible-ASCII values,
            // which would otherwise be silently sent as empty.
            head.extend_from_slice(value.as_bytes());
            head.extend_from_slice(b"\r\n");
        }
        head.extend_from_slice(b"\r\n");
        head
    }

    /// Writes the head of `resp` to `stream`, flushes it, and hands `resp` back.
    ///
    /// `resp` is taken by value rather than borrowed: a borrow held across the
    /// `.await`s would make the future `Send` only when `T: Sync`. By value, it
    /// needs only `T: Send`.
    async fn async_send_resp<S, T>(
        resp: http::Response<T>,
        stream: &mut S,
    ) -> Result<http::Response<T>, WsError>
    where
        S: AsyncWrite + Unpin,
        T: Debug,
    {
        let head = encode_resp(&resp);
        stream.write_all(&head).await?;
        // Flush so the peer receives the response even if `stream` buffers writes.
        stream.flush().await?;
        tracing::debug!("{:?}", &resp);
        Ok(resp)
    }

    /// Writes `resp` to `stream`. Any status other than `101 Switching Protocols`
    /// is turned into `WsError::HandShakeFailed` carrying the stringified body.
    async fn async_write_resp<S, T>(resp: http::Response<T>, stream: &mut S) -> Result<(), WsError>
    where
        S: AsyncRead + AsyncWrite + Unpin,
        T: ToString + Debug,
    {
        let resp = async_send_resp(resp, stream).await?;
        if resp.status() != http::StatusCode::SWITCHING_PROTOCOLS {
            return Err(WsError::HandShakeFailed(resp.body().to_string()));
        }
        Ok(())
    }
}
409
/// Stateless namespace for the server-side handshake entry points: `accept`
/// (under the `sync` feature) and `async_accept` (under the `async` feature).
pub struct ServerBuilder {}