// Source path: tf_rust_engineio/asynchronous/client/builder.rs
use crate::{
    asynchronous::{
        async_socket::Socket as InnerSocket,
        async_transports::{PollingTransport, WebsocketSecureTransport, WebsocketTransport},
        callback::OptionalCallback,
        transport::AsyncTransport,
    },
    error::Result,
    header::HeaderMap,
    packet::HandshakePacket,
    Error, Packet, ENGINE_IO_VERSION,
};
use bytes::Bytes;
use futures_util::{future::BoxFuture, StreamExt};
use native_tls::TlsConnector;
use url::Url;

use super::Client;

20#[derive(Clone, Debug)]
21pub struct ClientBuilder {
22 url: Url,
23 tls_config: Option<TlsConnector>,
24 headers: Option<HeaderMap>,
25 handshake: Option<HandshakePacket>,
26 on_error: OptionalCallback<String>,
27 on_open: OptionalCallback<()>,
28 on_close: OptionalCallback<()>,
29 on_data: OptionalCallback<Bytes>,
30 on_packet: OptionalCallback<Packet>,
31}
32
33impl ClientBuilder {
34 pub fn new(url: Url) -> Self {
35 let mut url = url;
36 url.query_pairs_mut()
37 .append_pair("EIO", &ENGINE_IO_VERSION.to_string());
38
39 if url.path() == "/" {
41 url.set_path("/engine.io/");
42 }
43 ClientBuilder {
44 url,
45 headers: None,
46 tls_config: None,
47 handshake: None,
48 on_close: OptionalCallback::default(),
49 on_data: OptionalCallback::default(),
50 on_error: OptionalCallback::default(),
51 on_open: OptionalCallback::default(),
52 on_packet: OptionalCallback::default(),
53 }
54 }
55
56 pub fn tls_config(mut self, tls_config: TlsConnector) -> Self {
58 self.tls_config = Some(tls_config);
59 self
60 }
61
62 pub fn headers(mut self, headers: HeaderMap) -> Self {
64 self.headers = Some(headers);
65 self
66 }
67
68 #[cfg(feature = "async-callbacks")]
70 pub fn on_close<T>(mut self, callback: T) -> Self
71 where
72 T: 'static + Send + Sync + Fn(()) -> BoxFuture<'static, ()>,
73 {
74 self.on_close = OptionalCallback::new(callback);
75 self
76 }
77
78 #[cfg(feature = "async-callbacks")]
80 pub fn on_data<T>(mut self, callback: T) -> Self
81 where
82 T: 'static + Send + Sync + Fn(Bytes) -> BoxFuture<'static, ()>,
83 {
84 self.on_data = OptionalCallback::new(callback);
85 self
86 }
87
88 #[cfg(feature = "async-callbacks")]
90 pub fn on_error<T>(mut self, callback: T) -> Self
91 where
92 T: 'static + Send + Sync + Fn(String) -> BoxFuture<'static, ()>,
93 {
94 self.on_error = OptionalCallback::new(callback);
95 self
96 }
97
98 #[cfg(feature = "async-callbacks")]
100 pub fn on_open<T>(mut self, callback: T) -> Self
101 where
102 T: 'static + Send + Sync + Fn(()) -> BoxFuture<'static, ()>,
103 {
104 self.on_open = OptionalCallback::new(callback);
105 self
106 }
107
108 #[cfg(feature = "async-callbacks")]
110 pub fn on_packet<T>(mut self, callback: T) -> Self
111 where
112 T: 'static + Send + Sync + Fn(Packet) -> BoxFuture<'static, ()>,
113 {
114 self.on_packet = OptionalCallback::new(callback);
115 self
116 }
117
118 async fn handshake_with_transport<T: AsyncTransport + Unpin>(
120 &mut self,
121 transport: &mut T,
122 ) -> Result<()> {
123 if self.handshake.is_some() {
125 return Ok(());
126 }
127
128 let mut url = self.url.clone();
129
130 let handshake: HandshakePacket =
131 Packet::try_from(transport.next().await.ok_or(Error::IncompletePacket())??)?
132 .try_into()?;
133
134 url.query_pairs_mut().append_pair("sid", &handshake.sid[..]);
136
137 self.handshake = Some(handshake);
138
139 self.url = url;
140
141 Ok(())
142 }
143
144 async fn handshake(&mut self) -> Result<()> {
145 if self.handshake.is_some() {
146 return Ok(());
147 }
148
149 let headers = if let Some(map) = self.headers.clone() {
150 Some(map.try_into()?)
151 } else {
152 None
153 };
154
155 let mut transport =
157 PollingTransport::new(self.url.clone(), self.tls_config.clone(), headers);
158
159 self.handshake_with_transport(&mut transport).await
160 }
161
162 pub async fn build(mut self) -> Result<Client> {
164 self.handshake().await?;
165
166 if self.websocket_upgrade()? {
167 self.build_websocket_with_upgrade().await
168 } else {
169 self.build_polling().await
170 }
171 }
172
173 pub async fn build_polling(mut self) -> Result<Client> {
175 self.handshake().await?;
176
177 let transport = PollingTransport::new(
179 self.url,
180 self.tls_config,
181 self.headers.map(|v| v.try_into().unwrap()),
182 );
183
184 Ok(Client::new(InnerSocket::new(
186 transport.into(),
187 self.handshake.unwrap(),
188 self.on_close,
189 self.on_data,
190 self.on_error,
191 self.on_open,
192 self.on_packet,
193 )))
194 }
195
196 pub async fn build_websocket_with_upgrade(mut self) -> Result<Client> {
198 self.handshake().await?;
199
200 if self.websocket_upgrade()? {
201 self.build_websocket().await
202 } else {
203 Err(Error::IllegalWebsocketUpgrade())
204 }
205 }
206
207 pub async fn build_websocket(mut self) -> Result<Client> {
209 let headers = if let Some(map) = self.headers.clone() {
210 Some(map.try_into()?)
211 } else {
212 None
213 };
214
215 match self.url.scheme() {
216 "http" | "ws" => {
217 let mut transport = WebsocketTransport::new(self.url.clone(), headers).await?;
218
219 if self.handshake.is_some() {
220 transport.upgrade().await?;
221 } else {
222 self.handshake_with_transport(&mut transport).await?;
223 }
224 Ok(Client::new(InnerSocket::new(
227 transport.into(),
228 self.handshake.unwrap(),
229 self.on_close,
230 self.on_data,
231 self.on_error,
232 self.on_open,
233 self.on_packet,
234 )))
235 }
236 "https" | "wss" => {
237 let mut transport = WebsocketSecureTransport::new(
238 self.url.clone(),
239 self.tls_config.clone(),
240 headers,
241 )
242 .await?;
243
244 if self.handshake.is_some() {
245 transport.upgrade().await?;
246 } else {
247 self.handshake_with_transport(&mut transport).await?;
248 }
249 Ok(Client::new(InnerSocket::new(
252 transport.into(),
253 self.handshake.unwrap(),
254 self.on_close,
255 self.on_data,
256 self.on_error,
257 self.on_open,
258 self.on_packet,
259 )))
260 }
261 _ => Err(Error::InvalidUrlScheme(self.url.scheme().to_string())),
262 }
263 }
264
265 pub async fn build_with_fallback(self) -> Result<Client> {
268 let result = self.clone().build().await;
269 if result.is_err() {
270 self.build_polling().await
271 } else {
272 result
273 }
274 }
275
276 fn websocket_upgrade(&mut self) -> Result<bool> {
278 if self.handshake.is_none() {
279 return Ok(false);
280 }
281
282 Ok(self
283 .handshake
284 .as_ref()
285 .unwrap()
286 .upgrades
287 .iter()
288 .any(|upgrade| upgrade.to_lowercase() == *"websocket"))
289 }
290}