ws_tool/
lib.rs

1//! rust websocket toolkit
2
3#![warn(missing_docs)]
4#![cfg_attr(docrs, feature(doc_auto_cfg))]
5
6use std::collections::HashMap;
7
8/// websocket error definitions
9pub mod errors;
10/// websocket transport unit
11pub mod frame;
12/// build connection & read/write frame utils
13pub mod protocol;
14
15/// frame codec impl
16pub mod codec;
17/// connection helper function
18pub mod connector;
19
20/// helper message definition
21mod message;
22pub use message::*;
23#[cfg(feature = "simple")]
24/// simple api to create websocket connection
25pub mod simple;
26#[cfg(feature = "simple")]
27pub use simple::ClientConfig;
28
29/// helper stream definition
30pub mod stream;
31
32/// some helper extension
33pub mod extension;
34
35/// helper builder to construct websocket client
/// Builder that configures the client side of a websocket opening handshake.
///
/// It collects the sub-protocols, extensions, protocol version and extra HTTP
/// headers that are passed to the handshake request.
#[derive(Debug, Clone)]
pub struct ClientBuilder {
    // requested sub-protocols, passed to the handshake request
    protocols: Vec<String>,
    // requested extensions, passed to the handshake request
    extensions: Vec<String>,
    // websocket protocol version; only read by the `sync` / `async` connect paths
    #[cfg_attr(not(any(feature = "sync", feature = "async")), allow(dead_code))]
    version: u8,
    // extra HTTP headers added to the handshake request
    headers: HashMap<String, String>,
}

impl Default for ClientBuilder {
    /// No protocols, no extensions, no extra headers, websocket version 13.
    fn default() -> Self {
        Self {
            protocols: Vec::new(),
            extensions: Vec::new(),
            version: 13,
            headers: HashMap::new(),
        }
    }
}

impl ClientBuilder {
    /// Create a builder with default settings (version 13, nothing else set).
    ///
    /// The target URI is supplied later, to the connect / `with_stream` methods.
    pub fn new() -> Self {
        Self::default()
    }

    /// Append one sub-protocol to the list requested during the handshake.
    pub fn protocol(mut self, protocol: String) -> Self {
        self.protocols.push(protocol);
        self
    }

    /// Replace the whole list of sub-protocols requested during the handshake.
    ///
    /// **NOTE** it discards protocols previously added with [`ClientBuilder::protocol`].
    pub fn protocols(self, protocols: Vec<String>) -> Self {
        Self { protocols, ..self }
    }

    /// Append one extension to the list requested during the handshake.
    pub fn extension(mut self, extension: String) -> Self {
        self.extensions.push(extension);
        self
    }

    /// Replace the whole list of extensions requested during the handshake.
    ///
    /// **NOTE** it discards extensions previously added with [`ClientBuilder::extension`].
    pub fn extensions(self, extensions: Vec<String>) -> Self {
        Self { extensions, ..self }
    }

    /// Set the websocket protocol version (default 13).
    pub fn version(self, version: u8) -> Self {
        Self { version, ..self }
    }

    /// Add one header to the initial handshake request.
    ///
    /// Headers are stored in a map, so setting the same name twice keeps the last value.
    pub fn header<K: ToString, V: ToString>(mut self, name: K, value: V) -> Self {
        self.headers.insert(name.to_string(), value.to_string());
        self
    }

    /// Replace all headers of the initial handshake request.
    ///
    /// **NOTE** it discards headers previously added with [`ClientBuilder::header`].
    pub fn headers(self, headers: HashMap<String, String>) -> Self {
        Self { headers, ..self }
    }
}
106
#[cfg(feature = "sync")]
mod blocking {
    use std::{
        io::{Read, Write},
        net::TcpStream,
    };

    use crate::{
        connector::{get_scheme, tcp_connect},
        errors::WsError,
        protocol::{handle_handshake, req_handshake, Mode},
        ClientBuilder, ServerBuilder,
    };

    impl ClientBuilder {
        /// Open a plain TCP connection to `uri`, perform the protocol handshake and
        /// pass the outcome to `check_fn`.
        ///
        /// `check_fn` receives the key and server response returned by the handshake
        /// request, together with the connected stream.
        ///
        /// # Panics
        /// Panics if `uri` uses the `wss` scheme; use `rustls_connect` or
        /// `native_tls_connect` for TLS connections.
        pub fn connect<C, F>(&self, uri: http::Uri, check_fn: F) -> Result<C, WsError>
        where
            F: FnMut(String, http::Response<()>, TcpStream) -> Result<C, WsError>,
        {
            if matches!(get_scheme(&uri)?, Mode::WSS) {
                panic!("can not perform ssl connection, use `rustls_connect` or `native_tls_connect` instead");
            }
            let stream = tcp_connect(&uri)?;
            self.with_stream(uri, stream, check_fn)
        }

        #[cfg(feature = "sync_tls_rustls")]
        /// Open a TLS (rustls, default certs) connection to `uri`, perform the
        /// protocol handshake and pass the outcome to `check_fn`.
        ///
        /// # Panics
        /// Panics if `uri` does not use the `wss` scheme; use `connect` for plain
        /// connections.
        pub fn rustls_connect<C, F>(&self, uri: http::Uri, check_fn: F) -> Result<C, WsError>
        where
            F: FnMut(
                String,
                http::Response<()>,
                rustls_connector::rustls::StreamOwned<
                    rustls_connector::rustls::ClientConnection,
                    TcpStream,
                >,
            ) -> Result<C, WsError>,
        {
            use crate::connector::{get_host, wrap_rustls};
            // TLS only makes sense for `wss`; the check used to be inverted and
            // rejected exactly the URIs this method exists for.
            if !matches!(get_scheme(&uri)?, Mode::WSS) {
                panic!("can not perform not ssl connection, use `connect` instead");
            }
            let stream = tcp_connect(&uri)?;
            let stream = wrap_rustls(stream, get_host(&uri)?, vec![])?;
            self.with_stream(uri, stream, check_fn)
        }

        #[cfg(feature = "sync_tls_native")]
        /// Open a TLS (native-tls, default certs) connection to `uri`, perform the
        /// protocol handshake and pass the outcome to `check_fn`.
        ///
        /// # Panics
        /// Panics if `uri` does not use the `wss` scheme; use `connect` for plain
        /// connections.
        pub fn native_tls_connect<C, F>(&self, uri: http::Uri, check_fn: F) -> Result<C, WsError>
        where
            F: FnMut(
                String,
                http::Response<()>,
                native_tls::TlsStream<TcpStream>,
            ) -> Result<C, WsError>,
        {
            use crate::connector::{get_host, wrap_native_tls};
            // TLS only makes sense for `wss` (this check used to be inverted).
            if !matches!(get_scheme(&uri)?, Mode::WSS) {
                panic!("can not perform not ssl connection, use `connect` instead");
            }
            let stream = tcp_connect(&uri)?;
            let stream = wrap_native_tls(stream, get_host(&uri)?, vec![])?;
            self.with_stream(uri, stream, check_fn)
        }

        /// ## Low level api
        /// Perform the protocol handshake over an already connected `stream` and pass
        /// the key, the server response and the stream to `check_fn`.
        ///
        /// The scheme of `uri` is validated first, but no scheme/stream-type
        /// consistency check is made here.
        pub fn with_stream<C, F, S>(
            &self,
            uri: http::Uri,
            mut stream: S,
            mut check_fn: F,
        ) -> Result<C, WsError>
        where
            S: Read + Write,
            F: FnMut(String, http::Response<()>, S) -> Result<C, WsError>,
        {
            get_scheme(&uri)?;
            let (key, resp) = req_handshake(
                &mut stream,
                &uri,
                &self.protocols,
                &self.extensions,
                self.version,
                self.headers.clone(),
            )?;
            check_fn(key, resp, stream)
        }
    }

    impl ServerBuilder {
        /// Wait for the client's handshake request, let `handshake_handler` decide on
        /// it, send its response, and build the server with `codec_factory`.
        ///
        /// # Errors
        /// - I/O or parse errors while reading the request or writing the response.
        /// - The handler's own error when it rejects the request (the rejection
        ///   response is still sent to the client first).
        /// - `WsError::HandShakeFailed` (carrying the response body) when the handler
        ///   accepts but its response status is not `101 Switching Protocols`.
        pub fn accept<F1, F2, T, C, S>(
            mut stream: S,
            mut handshake_handler: F1,
            mut codec_factory: F2,
        ) -> Result<C, WsError>
        where
            S: Read + Write,
            F1: FnMut(
                http::Request<()>,
            ) -> Result<
                (http::Request<()>, http::Response<T>),
                (http::Response<T>, WsError),
            >,
            F2: FnMut(http::Request<()>, S) -> Result<C, WsError>,
            T: ToString + std::fmt::Debug,
        {
            let req = handle_handshake(&mut stream)?;
            match handshake_handler(req) {
                Ok((req, resp)) => {
                    write_resp(&resp, &mut stream)?;
                    ensure_switching_protocols(&resp)?;
                    codec_factory(req, stream)
                }
                Err((resp, e)) => {
                    // Report the handler's error; previously a non-101 rejection was
                    // turned into `HandShakeFailed` here and `e` was lost.
                    write_resp(&resp, &mut stream)?;
                    Err(e)
                }
            }
        }
    }

    /// Write the head (status line and headers) of `resp` to `stream`, flush it and
    /// log the response. The body is not written.
    fn write_resp<S, T>(resp: &http::Response<T>, stream: &mut S) -> Result<(), WsError>
    where
        S: Write,
        T: std::fmt::Debug,
    {
        stream.write_all(&encode_resp_head(resp))?;
        // Push the handshake out of any buffering layer before the stream is handed
        // over to the codec.
        stream.flush()?;
        tracing::debug!("{:?}", resp);
        Ok(())
    }

    /// `Ok(())` for `101 Switching Protocols`, otherwise `HandShakeFailed(body)`.
    fn ensure_switching_protocols<T: ToString>(resp: &http::Response<T>) -> Result<(), WsError> {
        if resp.status() == http::StatusCode::SWITCHING_PROTOCOLS {
            Ok(())
        } else {
            Err(WsError::HandShakeFailed(resp.body().to_string()))
        }
    }

    /// Serialize `"<version> <status>\r\n"`, one `"<name>: <value>\r\n"` line per
    /// header, and the terminating blank line.
    fn encode_resp_head<T>(resp: &http::Response<T>) -> Vec<u8> {
        let mut head = format!("{:?} {}\r\n", resp.version(), resp.status()).into_bytes();
        for (name, value) in resp.headers() {
            head.extend_from_slice(name.as_str().as_bytes());
            head.extend_from_slice(b": ");
            // Raw bytes: `to_str()` fails on non-visible-ASCII values, which used to
            // be written as an empty value.
            head.extend_from_slice(value.as_bytes());
            head.extend_from_slice(b"\r\n");
        }
        head.extend_from_slice(b"\r\n");
        head
    }
}
253
#[cfg(feature = "async")]
mod non_blocking {
    use std::fmt::Debug;

    use tokio::{
        io::{AsyncRead, AsyncWrite, AsyncWriteExt},
        net::TcpStream,
    };

    use crate::{
        connector::{async_tcp_connect, get_scheme},
        errors::WsError,
        protocol::{async_handle_handshake, async_req_handshake, Mode},
        ServerBuilder,
    };

    use super::ClientBuilder;

    impl ClientBuilder {
        /// Open a plain TCP connection to `uri`, perform the protocol handshake and
        /// pass the outcome to `check_fn`.
        ///
        /// # Panics
        /// Panics if `uri` uses the `wss` scheme (matching the blocking `connect`);
        /// use `async_rustls_connect` or `async_native_tls_connect` for TLS.
        pub async fn async_connect<C, F>(&self, uri: http::Uri, check_fn: F) -> Result<C, WsError>
        where
            F: FnMut(String, http::Response<()>, TcpStream) -> Result<C, WsError>,
        {
            let is_tls = matches!(get_scheme(&uri)?, Mode::WSS);
            if is_tls {
                panic!("can not perform ssl connection, use `async_rustls_connect` or `async_native_tls_connect` instead");
            }
            let stream = async_tcp_connect(&uri).await?;
            self.async_with_stream(uri, stream, check_fn).await
        }

        #[cfg(feature = "async_tls_rustls")]
        /// Open a TLS (rustls, default certs) connection to `uri`, perform the
        /// protocol handshake and pass the outcome to `check_fn`.
        ///
        /// # Panics
        /// Panics if `uri` does not use the `wss` scheme; use `async_connect` for
        /// plain connections.
        pub async fn async_rustls_connect<C, F>(
            &self,
            uri: http::Uri,
            check_fn: F,
        ) -> Result<C, WsError>
        where
            F: FnMut(
                String,
                http::Response<()>,
                tokio_rustls::client::TlsStream<tokio::net::TcpStream>,
            ) -> Result<C, WsError>,
        {
            use crate::connector::{async_wrap_rustls, get_host};
            // TLS only makes sense for `wss`; the check used to be inverted.
            let is_tls = matches!(get_scheme(&uri)?, Mode::WSS);
            if !is_tls {
                panic!("can not perform not ssl connection, use `async_connect` instead");
            }
            let stream = async_tcp_connect(&uri).await?;
            let stream = async_wrap_rustls(stream, get_host(&uri)?, vec![]).await?;
            self.async_with_stream(uri, stream, check_fn).await
        }

        #[cfg(feature = "async_tls_native")]
        /// Open a TLS (native-tls, default certs) connection to `uri`, perform the
        /// protocol handshake and pass the outcome to `check_fn`.
        ///
        /// # Panics
        /// Panics if `uri` does not use the `wss` scheme; use `async_connect` for
        /// plain connections.
        pub async fn async_native_tls_connect<C, F>(
            &self,
            uri: http::Uri,
            check_fn: F,
        ) -> Result<C, WsError>
        where
            F: FnMut(
                String,
                http::Response<()>,
                tokio_native_tls::TlsStream<TcpStream>,
            ) -> Result<C, WsError>,
        {
            use crate::connector::{async_wrap_native_tls, get_host};
            // TLS only makes sense for `wss`; the check used to be inverted.
            let is_tls = matches!(get_scheme(&uri)?, Mode::WSS);
            if !is_tls {
                panic!("can not perform not ssl connection, use `async_connect` instead");
            }
            let stream = async_tcp_connect(&uri).await?;
            let stream = async_wrap_native_tls(stream, get_host(&uri)?, vec![]).await?;
            self.async_with_stream(uri, stream, check_fn).await
        }

        /// Async version of `with_stream`.
        ///
        /// Perform the protocol handshake over an already connected `stream` and pass
        /// the key, the server response and the stream to `check_fn`. The scheme of
        /// `uri` is validated first, as in the blocking version.
        pub async fn async_with_stream<C, F, S>(
            &self,
            uri: http::Uri,
            mut stream: S,
            mut check_fn: F,
        ) -> Result<C, WsError>
        where
            S: AsyncRead + AsyncWrite + Unpin,
            F: FnMut(String, http::Response<()>, S) -> Result<C, WsError>,
        {
            get_scheme(&uri)?;
            let (key, resp) = async_req_handshake(
                &mut stream,
                &uri,
                &self.protocols,
                &self.extensions,
                self.version,
                self.headers.clone(),
            )
            .await?;
            check_fn(key, resp, stream)
        }
    }

    impl ServerBuilder {
        /// Async version of `accept`.
        ///
        /// Wait for the client's handshake request, let `handshake_handler` decide on
        /// it, send its response, and build the server with `codec_factory`.
        ///
        /// # Errors
        /// - I/O or parse errors while reading the request or writing the response.
        /// - The handler's own error when it rejects the request (the rejection
        ///   response is still sent to the client first).
        /// - `WsError::HandShakeFailed` (carrying the response body) when the handler
        ///   accepts but its response status is not `101 Switching Protocols`.
        pub async fn async_accept<F1, F2, T, C, S>(
            mut stream: S,
            mut handshake_handler: F1,
            mut codec_factory: F2,
        ) -> Result<C, WsError>
        where
            S: AsyncRead + AsyncWrite + Unpin,
            F1: FnMut(
                http::Request<()>,
            ) -> Result<
                (http::Request<()>, http::Response<T>),
                (http::Response<T>, WsError),
            >,
            F2: FnMut(http::Request<()>, S) -> Result<C, WsError>,
            T: ToString + Debug,
        {
            let req = async_handle_handshake(&mut stream).await?;
            match handshake_handler(req) {
                Ok((req, resp)) => {
                    let resp = async_write_resp(resp, &mut stream).await?;
                    ensure_switching_protocols(&resp)?;
                    codec_factory(req, stream)
                }
                Err((resp, e)) => {
                    // Report the handler's error; previously a non-101 rejection was
                    // turned into `HandShakeFailed` here and `e` was lost.
                    async_write_resp(resp, &mut stream).await?;
                    Err(e)
                }
            }
        }
    }

    /// Write the head (status line and headers) of `resp` to `stream`, flush it and
    /// log the response. The body is not written. The response is handed back so the
    /// caller can still inspect it.
    ///
    /// `resp` is taken by value (not borrowed across `.await`) so the future keeps
    /// owning the response, as before.
    async fn async_write_resp<S, T>(
        resp: http::Response<T>,
        stream: &mut S,
    ) -> Result<http::Response<T>, WsError>
    where
        S: AsyncWrite + Unpin,
        T: Debug,
    {
        let head = encode_resp_head(&resp);
        stream.write_all(&head).await?;
        // Push the handshake out of any buffering layer before the stream is handed
        // over to the codec.
        stream.flush().await?;
        tracing::debug!("{:?}", &resp);
        Ok(resp)
    }

    /// `Ok(())` for `101 Switching Protocols`, otherwise `HandShakeFailed(body)`.
    fn ensure_switching_protocols<T: ToString>(resp: &http::Response<T>) -> Result<(), WsError> {
        if resp.status() == http::StatusCode::SWITCHING_PROTOCOLS {
            Ok(())
        } else {
            Err(WsError::HandShakeFailed(resp.body().to_string()))
        }
    }

    /// Serialize `"<version> <status>\r\n"`, one `"<name>: <value>\r\n"` line per
    /// header, and the terminating blank line.
    fn encode_resp_head<T>(resp: &http::Response<T>) -> Vec<u8> {
        let mut head = format!("{:?} {}\r\n", resp.version(), resp.status()).into_bytes();
        for (name, value) in resp.headers() {
            head.extend_from_slice(name.as_str().as_bytes());
            head.extend_from_slice(b": ");
            // Raw bytes: `to_str()` fails on non-visible-ASCII values, which used to
            // be written as an empty value.
            head.extend_from_slice(value.as_bytes());
            head.extend_from_slice(b"\r\n");
        }
        head.extend_from_slice(b"\r\n");
        head
    }
}
409
/// Entry point for server-side websocket handshakes.
///
/// It currently carries no configuration; the handshake is driven by its
/// associated functions (`accept` with the `sync` feature, `async_accept` with
/// the `async` feature).
#[derive(Debug, Clone, Default)]
pub struct ServerBuilder {}