// ws_tool/simple.rs

1use http::Uri;
2use crate::{
3    codec::{PMDConfig, WindowBit},
4    connector::{get_host, get_scheme},
5    errors::WsError,
6    protocol::Mode,
7    ClientBuilder,
8};
9use std::{collections::HashMap, path::PathBuf};
10
11/// client connection config
12pub struct ClientConfig {
13    /// read buffer size
14    pub read_buf: usize,
15    /// write buffer size
16    pub write_buf: usize,
17    /// custom certification path
18    pub certs: Vec<PathBuf>,
19    /// deflate window size, if none, deflate will be disabled
20    pub window: Option<WindowBit>,
21    /// enable/disable deflate context taker over parameter
22    pub context_take_over: bool,
23    /// extra header when perform websocket protocol handshake
24    pub extra_headers: HashMap<String, String>,
25    /// modified socket option after create tcp socket, this function will be applied
26    /// before start tls session
27    pub set_socket_fn: Box<dyn FnMut(&std::net::TcpStream) -> Result<(), WsError> + Send + 'static>,
28}
29
30impl Default for ClientConfig {
31    fn default() -> Self {
32        Self {
33            read_buf: Default::default(),
34            write_buf: Default::default(),
35            certs: Default::default(),
36            window: Default::default(),
37            context_take_over: Default::default(),
38            extra_headers: Default::default(),
39            set_socket_fn: Box::new(|_| Ok(())),
40        }
41    }
42}
43
44impl ClientConfig {
45    /// use default buffer size 8192
46    pub fn buffered() -> Self {
47        Self {
48            read_buf: 8192,
49            write_buf: 8192,
50            ..Default::default()
51        }
52    }
53
54    /// perform websocket handshake, use custom codec
55    pub fn connect_with<C, F>(
56        &mut self,
57        uri: impl TryInto<Uri, Error = http::uri::InvalidUri>,
58        mut check_fn: F,
59    ) -> Result<C, WsError>
60    where
61        F: FnMut(
62            String,
63            http::Response<()>,
64            crate::stream::BufStream<crate::stream::SyncStream>,
65        ) -> Result<C, WsError>,
66    {
67        let (uri, mode, builder) = self.prepare(uri)?;
68        let stream = crate::connector::tcp_connect(&uri)?;
69        (self.set_socket_fn)(&stream)?;
70        let check_fn = |key, resp, stream| {
71            let stream =
72                crate::stream::BufStream::with_capacity(self.read_buf, self.write_buf, stream);
73            check_fn(key, resp, stream)
74        };
75        match mode {
76            Mode::WS => builder.with_stream(uri, crate::stream::SyncStream::Raw(stream), check_fn),
77            Mode::WSS => {
78                let host = get_host(&uri)?;
79                if cfg!(feature = "sync_tls_rustls") {
80                    #[cfg(feature = "sync_tls_rustls")]
81                    {
82                        let stream =
83                            crate::connector::wrap_rustls(stream, host, self.certs.clone())?;
84                        builder.with_stream(
85                            uri,
86                            crate::stream::SyncStream::Rustls(stream),
87                            check_fn,
88                        )
89                    }
90                    #[cfg(not(feature = "sync_tls_rustls"))]
91                    {
92                        panic!("")
93                    }
94                } else if cfg!(feature = "sync_tls_native") {
95                    #[cfg(feature = "sync_tls_native")]
96                    {
97                        let stream =
98                            crate::connector::wrap_native_tls(stream, host, self.certs.clone())?;
99                        builder.with_stream(
100                            uri,
101                            crate::stream::SyncStream::NativeTls(stream),
102                            check_fn,
103                        )
104                    }
105                    #[cfg(not(feature = "sync_tls_native"))]
106                    {
107                        panic!("")
108                    }
109                } else {
110                    panic!("for ssl connection, sync_tls_native or sync_tls_rustls feature is required")
111                }
112            }
113        }
114    }
115
116    /// perform websocket handshake
117    #[cfg(feature = "sync")]
118    pub fn connect(
119        &mut self,
120        uri: impl TryInto<Uri, Error = http::uri::InvalidUri>,
121    ) -> Result<
122        crate::codec::DeflateCodec<crate::stream::BufStream<crate::stream::SyncStream>>,
123        WsError,
124    > {
125        self.connect_with(uri, crate::codec::DeflateCodec::check_fn)
126    }
127
128    /// perform websocket handshake
129    #[cfg(feature = "async")]
130    #[allow(unused)]
131    pub async fn async_connect_with<C, F>(
132        &mut self,
133        uri: impl TryInto<Uri, Error = http::uri::InvalidUri>,
134        mut check_fn: F,
135    ) -> Result<C, WsError>
136    where
137        F: FnMut(
138            String,
139            http::Response<()>,
140            tokio::io::BufStream<crate::stream::AsyncStream>,
141        ) -> Result<C, WsError>,
142    {
143        let (uri, mode, builder) = self.prepare(uri)?;
144        let stream = crate::connector::async_tcp_connect(&uri).await?;
145        let stream = stream.into_std()?;
146        (self.set_socket_fn)(&stream)?;
147        let stream = tokio::net::TcpStream::from_std(stream)?;
148        let check_fn = |key, resp, stream: crate::stream::AsyncStream| {
149            let stream = tokio::io::BufStream::with_capacity(self.read_buf, self.write_buf, stream);
150            check_fn(key, resp, stream)
151        };
152        match mode {
153            Mode::WS => {
154                builder
155                    .async_with_stream(uri, crate::stream::AsyncStream::Raw(stream), check_fn)
156                    .await
157            }
158            Mode::WSS => {
159                let host = get_host(&uri)?;
160                if cfg!(feature = "async_tls_rustls") {
161                    #[cfg(feature = "async_tls_rustls")]
162                    {
163                        let stream =
164                            crate::connector::async_wrap_rustls(stream, host, self.certs.clone())
165                                .await?;
166                        builder
167                            .async_with_stream(
168                                uri,
169                                crate::stream::AsyncStream::Rustls(
170                                    tokio_rustls::TlsStream::Client(stream),
171                                ),
172                                check_fn,
173                            )
174                            .await
175                    }
176                    #[cfg(not(feature = "async_tls_rustls"))]
177                    {
178                        panic!("")
179                    }
180                } else if cfg!(feature = "async_tls_native") {
181                    #[cfg(feature = "async_tls_native")]
182                    {
183                        let stream = crate::connector::async_wrap_native_tls(
184                            stream,
185                            host,
186                            self.certs.clone(),
187                        )
188                        .await?;
189                        builder
190                            .async_with_stream(
191                                uri,
192                                crate::stream::AsyncStream::NativeTls(stream),
193                                check_fn,
194                            )
195                            .await
196                    }
197                    #[cfg(not(feature = "async_tls_native"))]
198                    {
199                        panic!("")
200                    }
201                } else {
202                    panic!("for ssl connection, async_tls_native or async_tls_rustls feature is required")
203                }
204            }
205        }
206    }
207
208    /// perform websocket handshake
209    #[cfg(feature = "async")]
210    pub async fn async_connect(
211        &mut self,
212        uri: impl TryInto<Uri, Error = http::uri::InvalidUri>,
213    ) -> Result<
214        crate::codec::AsyncDeflateCodec<tokio::io::BufStream<crate::stream::AsyncStream>>,
215        WsError,
216    > {
217        self.async_connect_with(uri, crate::codec::AsyncDeflateCodec::check_fn)
218            .await
219    }
220
221    fn prepare(
222        &mut self,
223        uri: impl TryInto<Uri, Error = http::uri::InvalidUri>,
224    ) -> Result<(Uri, Mode, ClientBuilder), WsError> {
225        let uri = uri
226            .try_into()
227            .map_err(|e| WsError::InvalidUri(e.to_string()))?;
228        let mode = get_scheme(&uri)?;
229        let mut builder = ClientBuilder::new();
230        let pmd_conf = self.window.map(|w| PMDConfig {
231            server_no_context_takeover: self.context_take_over,
232            client_no_context_takeover: self.context_take_over,
233            server_max_window_bits: w,
234            client_max_window_bits: w,
235        });
236        if let Some(conf) = pmd_conf {
237            builder = builder.extension(conf.ext_string())
238        }
239        for (k, v) in &self.extra_headers {
240            builder = builder.header(k, v);
241        }
242        Ok((uri, mode, builder))
243    }
244}