tf_rust_socketio/asynchronous/client/builder.rs
1use futures_util::future::BoxFuture;
2use log::trace;
3use native_tls::TlsConnector;
4use std::collections::HashMap;
5use tf_rust_engineio::{
6 asynchronous::ClientBuilder as EngineIoClientBuilder,
7 header::{HeaderMap, HeaderValue},
8};
9use url::Url;
10
11use crate::{error::Result, Event, Payload, TransportType};
12
13use super::{
14 callback::{
15 Callback, DynAsyncAnyCallback, DynAsyncCallback, DynAsyncReconnectSettingsCallback,
16 },
17 client::{Client, ReconnectSettings},
18};
19use crate::asynchronous::socket::Socket as InnerSocket;
20
21/// A builder class for a `socket.io` socket. This handles setting up the client and
22/// configuring the callback, the namespace and metadata of the socket. If no
23/// namespace is specified, the default namespace `/` is taken. The `connect` method
24/// acts the `build` method and returns a connected [`Client`].
25pub struct ClientBuilder {
26 pub(crate) address: String,
27 pub(crate) on: HashMap<Event, Callback<DynAsyncCallback>>,
28 pub(crate) on_any: Option<Callback<DynAsyncAnyCallback>>,
29 pub(crate) on_reconnect: Option<Callback<DynAsyncReconnectSettingsCallback>>,
30 pub(crate) namespace: String,
31 tls_config: Option<TlsConnector>,
32 pub(crate) opening_headers: Option<HeaderMap>,
33 transport_type: TransportType,
34 pub(crate) auth: Option<serde_json::Value>,
35 pub(crate) reconnect: bool,
36 pub(crate) reconnect_on_disconnect: bool,
37 // None implies infinite attempts
38 pub(crate) max_reconnect_attempts: Option<u8>,
39 pub(crate) reconnect_delay_min: u64,
40 pub(crate) reconnect_delay_max: u64,
41}
42
43impl ClientBuilder {
44 /// Create as client builder from a URL. URLs must be in the form
45 /// `[ws or wss or http or https]://[domain]:[port]/[path]`. The
46 /// path of the URL is optional and if no port is given, port 80
47 /// will be used.
48 /// # Example
49 /// ```rust
50 /// use tf_rust_socketio::{Payload, asynchronous::{ClientBuilder, Client}};
51 /// use serde_json::json;
52 /// use futures_util::future::FutureExt;
53 ///
54 ///
55 /// #[tokio::main]
56 /// async fn main() {
57 /// let callback = |payload: Payload, socket: Client| {
58 /// async move {
59 /// match payload {
60 /// Payload::Text(values, _) => println!("Received: {:#?}", values),
61 /// Payload::Binary(bin_data, _) => println!("Received bytes: {:#?}", bin_data),
62 /// // This is deprecated, use Payload::Text instead
63 /// #[allow(deprecated)]
64 /// Payload::String(str, _) => println!("Received: {}", str),
65 /// }
66 /// }.boxed()
67 /// };
68 ///
69 /// let mut socket = ClientBuilder::new("http://localhost:4200")
70 /// .namespace("/admin")
71 /// .on("test", callback)
72 /// .connect()
73 /// .await
74 /// .expect("error while connecting");
75 ///
76 /// // use the socket
77 /// let json_payload = json!({"token": 123});
78 ///
79 /// let result = socket.emit("foo", json_payload).await;
80 ///
81 /// assert!(result.is_ok());
82 /// }
83 /// ```
84 pub fn new<T: Into<String>>(address: T) -> Self {
85 Self {
86 address: address.into(),
87 on: HashMap::new(),
88 on_any: None,
89 on_reconnect: None,
90 namespace: "/".to_owned(),
91 tls_config: None,
92 opening_headers: None,
93 transport_type: TransportType::Any,
94 auth: None,
95 reconnect: true,
96 reconnect_on_disconnect: false,
97 // None implies infinite attempts
98 max_reconnect_attempts: None,
99 reconnect_delay_min: 1000,
100 reconnect_delay_max: 5000,
101 }
102 }
103
104 /// Sets the target namespace of the client. The namespace should start
105 /// with a leading `/`. Valid examples are e.g. `/admin`, `/foo`.
106 /// If the String provided doesn't start with a leading `/`, it is
107 /// added manually.
108 pub fn namespace<T: Into<String>>(mut self, namespace: T) -> Self {
109 let mut nsp = namespace.into();
110 if !nsp.starts_with('/') {
111 nsp = "/".to_owned() + &nsp;
112 trace!("Added `/` to the given namespace: {}", nsp);
113 }
114 self.namespace = nsp;
115 self
116 }
117
118 /// Registers a new callback for a certain [`crate::event::Event`]. The event could either be
119 /// one of the common events like `message`, `error`, `open`, `close` or a custom
120 /// event defined by a string, e.g. `onPayment` or `foo`.
121 ///
122 /// # Example
123 /// ```rust
124 /// use tf_rust_socketio::{asynchronous::ClientBuilder, Payload};
125 /// use futures_util::FutureExt;
126 ///
127 /// #[tokio::main]
128 /// async fn main() {
129 /// let socket = ClientBuilder::new("http://localhost:4200/")
130 /// .namespace("/admin")
131 /// .on("test", |payload: Payload, _| {
132 /// async move {
133 /// match payload {
134 /// Payload::Text(values, _) => println!("Received: {:#?}", values),
135 /// Payload::Binary(bin_data, _) => println!("Received bytes: {:#?}", bin_data),
136 /// // This is deprecated, use Payload::Text instead
137 /// #[allow(deprecated)]
138 /// Payload::String(str, _) => println!("Received: {}", str),
139 /// }
140 /// }
141 /// .boxed()
142 /// })
143 /// .on("error", |err, _| async move { eprintln!("Error: {:#?}", err) }.boxed())
144 /// .connect()
145 /// .await;
146 /// }
147 /// ```
148 ///
149 /// # Issues with type inference for the callback method
150 ///
151 /// Currently stable Rust does not contain types like `AsyncFnMut`.
152 /// That is why this library uses the type `FnMut(..) -> BoxFuture<_>`,
153 /// which basically represents a closure or function that returns a
154 /// boxed future that can be executed in an async executor.
155 /// The complicated constraints for the callback function
156 /// bring the Rust compiler to it's limits, resulting in confusing error
157 /// messages when passing in a variable that holds a closure (to the `on` method).
158 /// In order to make sure type inference goes well, the [`futures_util::FutureExt::boxed`]
159 /// method can be used on an async block (the future) to make sure the return type
160 /// is conform with the generic requirements. An example can be found here:
161 ///
162 /// ```rust
163 /// use tf_rust_socketio::{asynchronous::ClientBuilder, Payload};
164 /// use futures_util::FutureExt;
165 ///
166 /// #[tokio::main]
167 /// async fn main() {
168 /// let callback = |payload: Payload, _| {
169 /// async move {
170 /// match payload {
171 /// Payload::Text(values, _) => println!("Received: {:#?}", values),
172 /// Payload::Binary(bin_data, _) => println!("Received bytes: {:#?}", bin_data),
173 /// // This is deprecated use Payload::Text instead
174 /// #[allow(deprecated)]
175 /// Payload::String(str, _) => println!("Received: {}", str),
176 /// }
177 /// }
178 /// .boxed() // <-- this makes sure we end up with a `BoxFuture<_>`
179 /// };
180 ///
181 /// let socket = ClientBuilder::new("http://localhost:4200/")
182 /// .namespace("/admin")
183 /// .on("test", callback)
184 /// .connect()
185 /// .await;
186 /// }
187 /// ```
188 ///
189 #[cfg(feature = "async-callbacks")]
190 pub fn on<T: Into<Event>, F>(mut self, event: T, callback: F) -> Self
191 where
192 F: for<'a> std::ops::FnMut(Payload, Client) -> BoxFuture<'static, ()>
193 + 'static
194 + Send
195 + Sync,
196 {
197 self.on
198 .insert(event.into(), Callback::<DynAsyncCallback>::new(callback));
199 self
200 }
201
202 /// Registers a callback for reconnect events. The event handler must return
203 /// a [ReconnectSettings] struct with the settings that should be updated.
204 ///
205 /// # Example
206 /// ```rust
207 /// use tf_rust_socketio::{asynchronous::{ClientBuilder, ReconnectSettings}};
208 /// use futures_util::future::FutureExt;
209 /// use serde_json::json;
210 ///
211 /// #[tokio::main]
212 /// async fn main() {
213 /// let client = ClientBuilder::new("http://localhost:4200/")
214 /// .namespace("/admin")
215 /// .on_reconnect(|| {
216 /// async {
217 /// let mut settings = ReconnectSettings::new();
218 /// settings.address("http://server?test=123");
219 /// settings.auth(json!({ "token": "abc" }));
220 /// settings.opening_header("TRAIL", "abc-123");
221 /// settings
222 /// }.boxed()
223 /// })
224 /// .connect()
225 /// .await;
226 /// }
227 /// ```
228 pub fn on_reconnect<F>(mut self, callback: F) -> Self
229 where
230 F: for<'a> std::ops::FnMut() -> BoxFuture<'static, ReconnectSettings>
231 + 'static
232 + Send
233 + Sync,
234 {
235 self.on_reconnect = Some(Callback::<DynAsyncReconnectSettingsCallback>::new(callback));
236 self
237 }
238
239 /// Registers a Callback for all [`crate::event::Event::Custom`] and [`crate::event::Event::Message`].
240 ///
241 /// # Example
242 /// ```rust
243 /// use tf_rust_socketio::{asynchronous::ClientBuilder, Payload};
244 /// use futures_util::future::FutureExt;
245 ///
246 /// #[tokio::main]
247 /// async fn main() {
248 /// let client = ClientBuilder::new("http://localhost:4200/")
249 /// .namespace("/admin")
250 /// .on_any(|event, payload, _client| {
251 /// async {
252 /// #[allow(deprecated)]
253 /// if let Payload::String(str, _) = payload {
254 /// println!("{}: {}", String::from(event), str);
255 /// }
256 /// }.boxed()
257 /// })
258 /// .connect()
259 /// .await;
260 /// }
261 /// ```
262 pub fn on_any<F>(mut self, callback: F) -> Self
263 where
264 F: for<'a> FnMut(Event, Payload, Client) -> BoxFuture<'static, ()> + 'static + Send + Sync,
265 {
266 self.on_any = Some(Callback::<DynAsyncAnyCallback>::new(callback));
267 self
268 }
269
270 /// Uses a preconfigured TLS connector for secure communication. This configures
271 /// both the `polling` as well as the `websocket` transport type.
272 /// # Example
273 /// ```rust
274 /// use tf_rust_socketio::{asynchronous::ClientBuilder, Payload};
275 /// use native_tls::TlsConnector;
276 /// use futures_util::future::FutureExt;
277 ///
278 /// #[tokio::main]
279 /// async fn main() {
280 /// let tls_connector = TlsConnector::builder()
281 /// .use_sni(true)
282 /// .build()
283 /// .expect("Found illegal configuration");
284 ///
285 /// let socket = ClientBuilder::new("http://localhost:4200/")
286 /// .namespace("/admin")
287 /// .on("error", |err, _| async move { eprintln!("Error: {:#?}", err) }.boxed())
288 /// .tls_config(tls_connector)
289 /// .connect()
290 /// .await;
291 /// }
292 /// ```
293 pub fn tls_config(mut self, tls_config: TlsConnector) -> Self {
294 self.tls_config = Some(tls_config);
295 self
296 }
297
298 /// Sets custom http headers for the opening request. The headers will be passed to the underlying
299 /// transport type (either websockets or polling) and then get passed with every request thats made.
300 /// via the transport layer.
301 /// # Example
302 /// ```rust
303 /// use tf_rust_socketio::{asynchronous::ClientBuilder, Payload};
304 /// use futures_util::future::FutureExt;
305 ///
306 /// #[tokio::main]
307 /// async fn main() {
308 /// let socket = ClientBuilder::new("http://localhost:4200/")
309 /// .namespace("/admin")
310 /// .on("error", |err, _| async move { eprintln!("Error: {:#?}", err) }.boxed())
311 /// .opening_header("accept-encoding", "application/json")
312 /// .connect()
313 /// .await;
314 /// }
315 /// ```
316 pub fn opening_header<T: Into<HeaderValue>, K: Into<String>>(mut self, key: K, val: T) -> Self {
317 match self.opening_headers {
318 Some(ref mut map) => {
319 map.insert(key.into(), val.into());
320 }
321 None => {
322 let mut map = HeaderMap::default();
323 map.insert(key.into(), val.into());
324 self.opening_headers = Some(map);
325 }
326 }
327 self
328 }
329
330 /// Sets authentification data sent in the opening request.
331 /// # Example
332 /// ```rust
333 /// use tf_rust_socketio::{asynchronous::ClientBuilder};
334 /// use serde_json::json;
335 /// use futures_util::future::FutureExt;
336 ///
337 /// #[tokio::main]
338 /// async fn main() {
339 /// let socket = ClientBuilder::new("http://localhost:4204/")
340 /// .namespace("/admin")
341 /// .auth(json!({ "password": "1337" }))
342 /// .on("error", |err, _| async move { eprintln!("Error: {:#?}", err) }.boxed())
343 /// .connect()
344 /// .await;
345 /// }
346 /// ```
347 pub fn auth<T: Into<serde_json::Value>>(mut self, auth: T) -> Self {
348 self.auth = Some(auth.into());
349
350 self
351 }
352
353 /// Specifies which EngineIO [`TransportType`] to use.
354 ///
355 /// # Example
356 /// ```rust
357 /// use tf_rust_socketio::{asynchronous::ClientBuilder, TransportType};
358 ///
359 /// #[tokio::main]
360 /// async fn main() {
361 /// let socket = ClientBuilder::new("http://localhost:4200/")
362 /// // Use websockets to handshake and connect.
363 /// .transport_type(TransportType::Websocket)
364 /// .connect()
365 /// .await
366 /// .expect("connection failed");
367 /// }
368 /// ```
369 pub fn transport_type(mut self, transport_type: TransportType) -> Self {
370 self.transport_type = transport_type;
371
372 self
373 }
374
375 /// If set to `false` do not try to reconnect on network errors. Defaults to
376 /// `true`
377 pub fn reconnect(mut self, reconnect: bool) -> Self {
378 self.reconnect = reconnect;
379 self
380 }
381
382 /// If set to `true` try to reconnect when the server disconnects the
383 /// client. Defaults to `false`
384 pub fn reconnect_on_disconnect(mut self, reconnect_on_disconnect: bool) -> Self {
385 self.reconnect_on_disconnect = reconnect_on_disconnect;
386 self
387 }
388
389 /// Sets the minimum and maximum delay between reconnection attempts
390 pub fn reconnect_delay(mut self, min: u64, max: u64) -> Self {
391 self.reconnect_delay_min = min;
392 self.reconnect_delay_max = max;
393 self
394 }
395
396 /// Sets the maximum number of times to attempt reconnections. Defaults to
397 /// an infinite number of attempts
398 pub fn max_reconnect_attempts(mut self, reconnect_attempts: u8) -> Self {
399 self.max_reconnect_attempts = Some(reconnect_attempts);
400 self
401 }
402
403 /// Connects the socket to a certain endpoint. This returns a connected
404 /// [`Client`] instance. This method returns an [`std::result::Result::Err`]
405 /// value if something goes wrong during connection. Also starts a separate
406 /// thread to start polling for packets. Used with callbacks.
407 /// # Example
408 /// ```rust
409 /// use tf_rust_socketio::{asynchronous::ClientBuilder, Payload};
410 /// use serde_json::json;
411 /// use futures_util::future::FutureExt;
412 ///
413 /// #[tokio::main]
414 /// async fn main() {
415 /// let mut socket = ClientBuilder::new("http://localhost:4200/")
416 /// .namespace("/admin")
417 /// .on("error", |err, _| async move { eprintln!("Error: {:#?}", err) }.boxed())
418 /// .connect()
419 /// .await
420 /// .expect("connection failed");
421 ///
422 /// // use the socket
423 /// let json_payload = json!({"token": 123});
424 ///
425 /// let result = socket.emit("foo", json_payload).await;
426 ///
427 /// assert!(result.is_ok());
428 /// }
429 /// ```
430 pub async fn connect(self) -> Result<Client> {
431 let mut socket = self.connect_manual().await?;
432 socket.poll_stream().await?;
433
434 Ok(socket)
435 }
436
437 /// Creates a new Socket that can be used for reconnections
438 pub(crate) async fn inner_create(&self) -> Result<InnerSocket> {
439 let mut url = Url::parse(&self.address)?;
440
441 if url.path() == "/" {
442 url.set_path("/socket.io/");
443 }
444
445 let mut builder = EngineIoClientBuilder::new(url);
446
447 if let Some(tls_config) = &self.tls_config {
448 builder = builder.tls_config(tls_config.to_owned());
449 }
450 if let Some(headers) = &self.opening_headers {
451 builder = builder.headers(headers.to_owned());
452 }
453
454 let engine_client = match self.transport_type {
455 TransportType::Any => builder.build_with_fallback().await?,
456 TransportType::Polling => builder.build_polling().await?,
457 TransportType::Websocket => builder.build_websocket().await?,
458 TransportType::WebsocketUpgrade => builder.build_websocket_with_upgrade().await?,
459 };
460
461 let inner_socket = InnerSocket::new(engine_client)?;
462 Ok(inner_socket)
463 }
464
465 //TODO: 0.3.X stabilize
466 pub(crate) async fn connect_manual(self) -> Result<Client> {
467 let inner_socket = self.inner_create().await?;
468
469 let socket = Client::new(inner_socket, self)?;
470 socket.connect().await?;
471
472 Ok(socket)
473 }
474}
475
#[cfg(test)]
mod test {
    use super::*;
    use crate::error::Error;
    use std::io::{ErrorKind, Read, Write};
    use std::net::TcpListener;
    use std::time::Duration;
    use tf_rust_engineio::Error as EngineError;

    /// Number of connections the mock server answers before its thread exits.
    const MOCK_MAX_CONNECTIONS: usize = 4;

    /// Upper bound on request bytes the mock reads from a single connection.
    const MOCK_MAX_REQUEST_BYTES: usize = 64 * 1024;

    /// How long the mock waits for more request bytes before replying anyway.
    const MOCK_READ_TIMEOUT: Duration = Duration::from_secs(5);

    /// Reads from `stream` until it sees the end of the HTTP request headers.
    ///
    /// Reading stops at the first of:
    /// - the header terminator (`\r\n\r\n`);
    /// - EOF or a non-retryable read error (including a read timeout);
    /// - at least `limit` bytes consumed.
    ///
    /// Draining the request before replying matters. Closing a TCP socket that
    /// still holds unread inbound data can make the OS send a RST, and the
    /// client may then see a connection reset instead of the canned response.
    fn drain_request_headers<R: Read>(stream: &mut R, limit: usize) {
        let mut request = Vec::new();
        let mut chunk = [0u8; 1024];
        while request.len() < limit {
            match stream.read(&mut chunk) {
                Ok(0) => break,
                Ok(n) => {
                    request.extend_from_slice(&chunk[..n]);
                    if request.windows(4).any(|w| w == b"\r\n\r\n") {
                        break;
                    }
                }
                Err(e) if e.kind() == ErrorKind::Interrupted => continue,
                // Timeouts and other errors: stop reading and reply with what we have.
                Err(_) => break,
            }
        }
    }

    /// Spawns a small HTTP server that replies to every connection with the
    /// given status and JSON body, then closes it.
    ///
    /// It serves at most `MOCK_MAX_CONNECTIONS` connections. This mirrors the
    /// engineio test helper and is kept local because the engineio version is
    /// `pub(crate)` to that crate.
    ///
    /// Returns the base URL (`http://127.0.0.1:<port>/`) of the server.
    fn spawn_http_error_mock(status: u16, body: &'static str) -> String {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();

        // The response never changes, so format it once outside the accept loop.
        let response = format!(
            "HTTP/1.1 {status} ERR\r\nContent-Type: application/json\r\nContent-Length: {len}\r\nConnection: close\r\n\r\n{body}",
            status = status,
            len = body.len(),
            body = body
        );

        std::thread::spawn(move || {
            for _ in 0..MOCK_MAX_CONNECTIONS {
                let Ok((mut stream, _)) = listener.accept() else {
                    break;
                };
                // Best effort: if the timeout cannot be set, fall back to blocking reads.
                let _ = stream.set_read_timeout(Some(MOCK_READ_TIMEOUT));
                drain_request_headers(&mut stream, MOCK_MAX_REQUEST_BYTES);
                let _ = stream.write_all(response.as_bytes());
                let _ = stream.flush();
            }
        });

        format!("http://127.0.0.1:{}/", port)
    }

    /// Checks that a rejected engine.io handshake surfaces the HTTP error body.
    ///
    /// The scenario is a server answering with an HTTP error plus a JSON body,
    /// e.g. an A2C-SMCP protocol-version mismatch returning
    /// 400 + `{"code":4008,...}`. The body must propagate all the way out of
    /// `ClientBuilder::connect()` so downstream callers can match on the
    /// structured error.
    #[tokio::test]
    async fn connect_surfaces_http_error_body_through_socketio_error() {
        let body = r#"{"code":4008,"message":"Protocol version mismatch"}"#;
        let url = spawn_http_error_mock(400, body);

        let result = ClientBuilder::new(url)
            .transport_type(TransportType::Polling)
            .connect()
            .await;

        let err = match result {
            Ok(_) => panic!("connect should fail when handshake server returns 400"),
            Err(e) => e,
        };

        match err {
            Error::IncompleteResponseFromEngineIo(EngineError::HttpErrorWithBody {
                status,
                body: got,
            }) => {
                assert_eq!(status, 400);
                assert_eq!(got, body);
            }
            other => {
                panic!("expected IncompleteResponseFromEngineIo(HttpErrorWithBody), got: {other:?}")
            }
        }
    }
}