1use http::Uri;
2use crate::{
3 codec::{PMDConfig, WindowBit},
4 connector::{get_host, get_scheme},
5 errors::WsError,
6 protocol::Mode,
7 ClientBuilder,
8};
9use std::{collections::HashMap, path::PathBuf};
10
11pub struct ClientConfig {
13 pub read_buf: usize,
15 pub write_buf: usize,
17 pub certs: Vec<PathBuf>,
19 pub window: Option<WindowBit>,
21 pub context_take_over: bool,
23 pub extra_headers: HashMap<String, String>,
25 pub set_socket_fn: Box<dyn FnMut(&std::net::TcpStream) -> Result<(), WsError> + Send + 'static>,
28}
29
30impl Default for ClientConfig {
31 fn default() -> Self {
32 Self {
33 read_buf: Default::default(),
34 write_buf: Default::default(),
35 certs: Default::default(),
36 window: Default::default(),
37 context_take_over: Default::default(),
38 extra_headers: Default::default(),
39 set_socket_fn: Box::new(|_| Ok(())),
40 }
41 }
42}
43
44impl ClientConfig {
45 pub fn buffered() -> Self {
47 Self {
48 read_buf: 8192,
49 write_buf: 8192,
50 ..Default::default()
51 }
52 }
53
54 pub fn connect_with<C, F>(
56 &mut self,
57 uri: impl TryInto<Uri, Error = http::uri::InvalidUri>,
58 mut check_fn: F,
59 ) -> Result<C, WsError>
60 where
61 F: FnMut(
62 String,
63 http::Response<()>,
64 crate::stream::BufStream<crate::stream::SyncStream>,
65 ) -> Result<C, WsError>,
66 {
67 let (uri, mode, builder) = self.prepare(uri)?;
68 let stream = crate::connector::tcp_connect(&uri)?;
69 (self.set_socket_fn)(&stream)?;
70 let check_fn = |key, resp, stream| {
71 let stream =
72 crate::stream::BufStream::with_capacity(self.read_buf, self.write_buf, stream);
73 check_fn(key, resp, stream)
74 };
75 match mode {
76 Mode::WS => builder.with_stream(uri, crate::stream::SyncStream::Raw(stream), check_fn),
77 Mode::WSS => {
78 let host = get_host(&uri)?;
79 if cfg!(feature = "sync_tls_rustls") {
80 #[cfg(feature = "sync_tls_rustls")]
81 {
82 let stream =
83 crate::connector::wrap_rustls(stream, host, self.certs.clone())?;
84 builder.with_stream(
85 uri,
86 crate::stream::SyncStream::Rustls(stream),
87 check_fn,
88 )
89 }
90 #[cfg(not(feature = "sync_tls_rustls"))]
91 {
92 panic!("")
93 }
94 } else if cfg!(feature = "sync_tls_native") {
95 #[cfg(feature = "sync_tls_native")]
96 {
97 let stream =
98 crate::connector::wrap_native_tls(stream, host, self.certs.clone())?;
99 builder.with_stream(
100 uri,
101 crate::stream::SyncStream::NativeTls(stream),
102 check_fn,
103 )
104 }
105 #[cfg(not(feature = "sync_tls_native"))]
106 {
107 panic!("")
108 }
109 } else {
110 panic!("for ssl connection, sync_tls_native or sync_tls_rustls feature is required")
111 }
112 }
113 }
114 }
115
116 #[cfg(feature = "sync")]
118 pub fn connect(
119 &mut self,
120 uri: impl TryInto<Uri, Error = http::uri::InvalidUri>,
121 ) -> Result<
122 crate::codec::DeflateCodec<crate::stream::BufStream<crate::stream::SyncStream>>,
123 WsError,
124 > {
125 self.connect_with(uri, crate::codec::DeflateCodec::check_fn)
126 }
127
128 #[cfg(feature = "async")]
130 #[allow(unused)]
131 pub async fn async_connect_with<C, F>(
132 &mut self,
133 uri: impl TryInto<Uri, Error = http::uri::InvalidUri>,
134 mut check_fn: F,
135 ) -> Result<C, WsError>
136 where
137 F: FnMut(
138 String,
139 http::Response<()>,
140 tokio::io::BufStream<crate::stream::AsyncStream>,
141 ) -> Result<C, WsError>,
142 {
143 let (uri, mode, builder) = self.prepare(uri)?;
144 let stream = crate::connector::async_tcp_connect(&uri).await?;
145 let stream = stream.into_std()?;
146 (self.set_socket_fn)(&stream)?;
147 let stream = tokio::net::TcpStream::from_std(stream)?;
148 let check_fn = |key, resp, stream: crate::stream::AsyncStream| {
149 let stream = tokio::io::BufStream::with_capacity(self.read_buf, self.write_buf, stream);
150 check_fn(key, resp, stream)
151 };
152 match mode {
153 Mode::WS => {
154 builder
155 .async_with_stream(uri, crate::stream::AsyncStream::Raw(stream), check_fn)
156 .await
157 }
158 Mode::WSS => {
159 let host = get_host(&uri)?;
160 if cfg!(feature = "async_tls_rustls") {
161 #[cfg(feature = "async_tls_rustls")]
162 {
163 let stream =
164 crate::connector::async_wrap_rustls(stream, host, self.certs.clone())
165 .await?;
166 builder
167 .async_with_stream(
168 uri,
169 crate::stream::AsyncStream::Rustls(
170 tokio_rustls::TlsStream::Client(stream),
171 ),
172 check_fn,
173 )
174 .await
175 }
176 #[cfg(not(feature = "async_tls_rustls"))]
177 {
178 panic!("")
179 }
180 } else if cfg!(feature = "async_tls_native") {
181 #[cfg(feature = "async_tls_native")]
182 {
183 let stream = crate::connector::async_wrap_native_tls(
184 stream,
185 host,
186 self.certs.clone(),
187 )
188 .await?;
189 builder
190 .async_with_stream(
191 uri,
192 crate::stream::AsyncStream::NativeTls(stream),
193 check_fn,
194 )
195 .await
196 }
197 #[cfg(not(feature = "async_tls_native"))]
198 {
199 panic!("")
200 }
201 } else {
202 panic!("for ssl connection, async_tls_native or async_tls_rustls feature is required")
203 }
204 }
205 }
206 }
207
208 #[cfg(feature = "async")]
210 pub async fn async_connect(
211 &mut self,
212 uri: impl TryInto<Uri, Error = http::uri::InvalidUri>,
213 ) -> Result<
214 crate::codec::AsyncDeflateCodec<tokio::io::BufStream<crate::stream::AsyncStream>>,
215 WsError,
216 > {
217 self.async_connect_with(uri, crate::codec::AsyncDeflateCodec::check_fn)
218 .await
219 }
220
221 fn prepare(
222 &mut self,
223 uri: impl TryInto<Uri, Error = http::uri::InvalidUri>,
224 ) -> Result<(Uri, Mode, ClientBuilder), WsError> {
225 let uri = uri
226 .try_into()
227 .map_err(|e| WsError::InvalidUri(e.to_string()))?;
228 let mode = get_scheme(&uri)?;
229 let mut builder = ClientBuilder::new();
230 let pmd_conf = self.window.map(|w| PMDConfig {
231 server_no_context_takeover: self.context_take_over,
232 client_no_context_takeover: self.context_take_over,
233 server_max_window_bits: w,
234 client_max_window_bits: w,
235 });
236 if let Some(conf) = pmd_conf {
237 builder = builder.extension(conf.ext_string())
238 }
239 for (k, v) in &self.extra_headers {
240 builder = builder.header(k, v);
241 }
242 Ok((uri, mode, builder))
243 }
244}