Skip to main content

tf_rust_socketio/asynchronous/client/
builder.rs

1use futures_util::future::BoxFuture;
2use log::trace;
3use native_tls::TlsConnector;
4use std::collections::HashMap;
5use tf_rust_engineio::{
6    asynchronous::ClientBuilder as EngineIoClientBuilder,
7    header::{HeaderMap, HeaderValue},
8};
9use url::Url;
10
11use crate::{error::Result, Event, Payload, TransportType};
12
13use super::{
14    callback::{
15        Callback, DynAsyncAnyCallback, DynAsyncCallback, DynAsyncReconnectSettingsCallback,
16    },
17    client::{Client, ReconnectSettings},
18};
19use crate::asynchronous::socket::Socket as InnerSocket;
20
21/// A builder class for a `socket.io` socket. This handles setting up the client and
22/// configuring the callback, the namespace and metadata of the socket. If no
23/// namespace is specified, the default namespace `/` is taken. The `connect` method
24/// acts the `build` method and returns a connected [`Client`].
25pub struct ClientBuilder {
26    pub(crate) address: String,
27    pub(crate) on: HashMap<Event, Callback<DynAsyncCallback>>,
28    pub(crate) on_any: Option<Callback<DynAsyncAnyCallback>>,
29    pub(crate) on_reconnect: Option<Callback<DynAsyncReconnectSettingsCallback>>,
30    pub(crate) namespace: String,
31    tls_config: Option<TlsConnector>,
32    pub(crate) opening_headers: Option<HeaderMap>,
33    transport_type: TransportType,
34    pub(crate) auth: Option<serde_json::Value>,
35    pub(crate) reconnect: bool,
36    pub(crate) reconnect_on_disconnect: bool,
37    // None implies infinite attempts
38    pub(crate) max_reconnect_attempts: Option<u8>,
39    pub(crate) reconnect_delay_min: u64,
40    pub(crate) reconnect_delay_max: u64,
41}
42
43impl ClientBuilder {
44    /// Create as client builder from a URL. URLs must be in the form
45    /// `[ws or wss or http or https]://[domain]:[port]/[path]`. The
46    /// path of the URL is optional and if no port is given, port 80
47    /// will be used.
48    /// # Example
49    /// ```rust
50    /// use tf_rust_socketio::{Payload, asynchronous::{ClientBuilder, Client}};
51    /// use serde_json::json;
52    /// use futures_util::future::FutureExt;
53    ///
54    ///
55    /// #[tokio::main]
56    /// async fn main() {
57    ///     let callback = |payload: Payload, socket: Client| {
58    ///         async move {
59    ///             match payload {
60    ///                 Payload::Text(values, _) => println!("Received: {:#?}", values),
61    ///                 Payload::Binary(bin_data, _) => println!("Received bytes: {:#?}", bin_data),
62    ///                 // This is deprecated, use Payload::Text instead
63    ///                 #[allow(deprecated)]
64    ///                 Payload::String(str, _) => println!("Received: {}", str),
65    ///             }
66    ///         }.boxed()
67    ///     };
68    ///
69    ///     let mut socket = ClientBuilder::new("http://localhost:4200")
70    ///         .namespace("/admin")
71    ///         .on("test", callback)
72    ///         .connect()
73    ///         .await
74    ///         .expect("error while connecting");
75    ///
76    ///     // use the socket
77    ///     let json_payload = json!({"token": 123});
78    ///
79    ///     let result = socket.emit("foo", json_payload).await;
80    ///
81    ///     assert!(result.is_ok());
82    /// }
83    /// ```
84    pub fn new<T: Into<String>>(address: T) -> Self {
85        Self {
86            address: address.into(),
87            on: HashMap::new(),
88            on_any: None,
89            on_reconnect: None,
90            namespace: "/".to_owned(),
91            tls_config: None,
92            opening_headers: None,
93            transport_type: TransportType::Any,
94            auth: None,
95            reconnect: true,
96            reconnect_on_disconnect: false,
97            // None implies infinite attempts
98            max_reconnect_attempts: None,
99            reconnect_delay_min: 1000,
100            reconnect_delay_max: 5000,
101        }
102    }
103
104    /// Sets the target namespace of the client. The namespace should start
105    /// with a leading `/`. Valid examples are e.g. `/admin`, `/foo`.
106    /// If the String provided doesn't start with a leading `/`, it is
107    /// added manually.
108    pub fn namespace<T: Into<String>>(mut self, namespace: T) -> Self {
109        let mut nsp = namespace.into();
110        if !nsp.starts_with('/') {
111            nsp = "/".to_owned() + &nsp;
112            trace!("Added `/` to the given namespace: {}", nsp);
113        }
114        self.namespace = nsp;
115        self
116    }
117
118    /// Registers a new callback for a certain [`crate::event::Event`]. The event could either be
119    /// one of the common events like `message`, `error`, `open`, `close` or a custom
120    /// event defined by a string, e.g. `onPayment` or `foo`.
121    ///
122    /// # Example
123    /// ```rust
124    /// use tf_rust_socketio::{asynchronous::ClientBuilder, Payload};
125    /// use futures_util::FutureExt;
126    ///
127    ///  #[tokio::main]
128    /// async fn main() {
129    ///     let socket = ClientBuilder::new("http://localhost:4200/")
130    ///         .namespace("/admin")
131    ///         .on("test", |payload: Payload, _| {
132    ///             async move {
133    ///                 match payload {
134    ///                     Payload::Text(values, _) => println!("Received: {:#?}", values),
135    ///                     Payload::Binary(bin_data, _) => println!("Received bytes: {:#?}", bin_data),
136    ///                     // This is deprecated, use Payload::Text instead
137    ///                     #[allow(deprecated)]
138    ///                     Payload::String(str, _) => println!("Received: {}", str),
139    ///                 }
140    ///             }
141    ///             .boxed()
142    ///         })
143    ///         .on("error", |err, _| async move { eprintln!("Error: {:#?}", err) }.boxed())
144    ///         .connect()
145    ///         .await;
146    /// }
147    /// ```
148    ///
149    /// # Issues with type inference for the callback method
150    ///
151    /// Currently stable Rust does not contain types like `AsyncFnMut`.
152    /// That is why this library uses the type `FnMut(..) -> BoxFuture<_>`,
153    /// which basically represents a closure or function that returns a
154    /// boxed future that can be executed in an async executor.
155    /// The complicated constraints for the callback function
156    /// bring the Rust compiler to it's limits, resulting in confusing error
157    /// messages when passing in a variable that holds a closure (to the `on` method).
158    /// In order to make sure type inference goes well, the [`futures_util::FutureExt::boxed`]
159    /// method can be used on an async block (the future) to make sure the return type
160    /// is conform with the generic requirements. An example can be found here:
161    ///
162    /// ```rust
163    /// use tf_rust_socketio::{asynchronous::ClientBuilder, Payload};
164    /// use futures_util::FutureExt;
165    ///
166    /// #[tokio::main]
167    /// async fn main() {
168    ///     let callback = |payload: Payload, _| {
169    ///             async move {
170    ///                 match payload {
171    ///                     Payload::Text(values, _) => println!("Received: {:#?}", values),
172    ///                     Payload::Binary(bin_data, _) => println!("Received bytes: {:#?}", bin_data),
173    ///                     // This is deprecated use Payload::Text instead
174    ///                     #[allow(deprecated)]
175    ///                     Payload::String(str, _) => println!("Received: {}", str),
176    ///                 }
177    ///             }
178    ///             .boxed() // <-- this makes sure we end up with a `BoxFuture<_>`
179    ///         };
180    ///
181    ///     let socket = ClientBuilder::new("http://localhost:4200/")
182    ///         .namespace("/admin")
183    ///         .on("test", callback)
184    ///         .connect()
185    ///         .await;
186    /// }
187    /// ```
188    ///
189    #[cfg(feature = "async-callbacks")]
190    pub fn on<T: Into<Event>, F>(mut self, event: T, callback: F) -> Self
191    where
192        F: for<'a> std::ops::FnMut(Payload, Client) -> BoxFuture<'static, ()>
193            + 'static
194            + Send
195            + Sync,
196    {
197        self.on
198            .insert(event.into(), Callback::<DynAsyncCallback>::new(callback));
199        self
200    }
201
202    /// Registers a callback for reconnect events. The event handler must return
203    /// a [ReconnectSettings] struct with the settings that should be updated.
204    ///
205    /// # Example
206    /// ```rust
207    /// use tf_rust_socketio::{asynchronous::{ClientBuilder, ReconnectSettings}};
208    /// use futures_util::future::FutureExt;
209    /// use serde_json::json;
210    ///
211    /// #[tokio::main]
212    /// async fn main() {
213    ///     let client = ClientBuilder::new("http://localhost:4200/")
214    ///         .namespace("/admin")
215    ///         .on_reconnect(|| {
216    ///             async {
217    ///                 let mut settings = ReconnectSettings::new();
218    ///                 settings.address("http://server?test=123");
219    ///                 settings.auth(json!({ "token": "abc" }));
220    ///                 settings.opening_header("TRAIL", "abc-123");
221    ///                 settings
222    ///             }.boxed()
223    ///         })
224    ///         .connect()
225    ///         .await;
226    /// }
227    /// ```
228    pub fn on_reconnect<F>(mut self, callback: F) -> Self
229    where
230        F: for<'a> std::ops::FnMut() -> BoxFuture<'static, ReconnectSettings>
231            + 'static
232            + Send
233            + Sync,
234    {
235        self.on_reconnect = Some(Callback::<DynAsyncReconnectSettingsCallback>::new(callback));
236        self
237    }
238
239    /// Registers a Callback for all [`crate::event::Event::Custom`] and [`crate::event::Event::Message`].
240    ///
241    /// # Example
242    /// ```rust
243    /// use tf_rust_socketio::{asynchronous::ClientBuilder, Payload};
244    /// use futures_util::future::FutureExt;
245    ///
246    /// #[tokio::main]
247    /// async fn main() {
248    ///     let client = ClientBuilder::new("http://localhost:4200/")
249    ///         .namespace("/admin")
250    ///         .on_any(|event, payload, _client| {
251    ///             async {
252    ///                 #[allow(deprecated)]
253    ///                 if let Payload::String(str, _) = payload {
254    ///                     println!("{}: {}", String::from(event), str);
255    ///                 }
256    ///             }.boxed()
257    ///         })
258    ///         .connect()
259    ///         .await;
260    /// }
261    /// ```
262    pub fn on_any<F>(mut self, callback: F) -> Self
263    where
264        F: for<'a> FnMut(Event, Payload, Client) -> BoxFuture<'static, ()> + 'static + Send + Sync,
265    {
266        self.on_any = Some(Callback::<DynAsyncAnyCallback>::new(callback));
267        self
268    }
269
270    /// Uses a preconfigured TLS connector for secure communication. This configures
271    /// both the `polling` as well as the `websocket` transport type.
272    /// # Example
273    /// ```rust
274    /// use tf_rust_socketio::{asynchronous::ClientBuilder, Payload};
275    /// use native_tls::TlsConnector;
276    /// use futures_util::future::FutureExt;
277    ///
278    /// #[tokio::main]
279    /// async fn main() {
280    ///     let tls_connector =  TlsConnector::builder()
281    ///                .use_sni(true)
282    ///                .build()
283    ///             .expect("Found illegal configuration");
284    ///
285    ///     let socket = ClientBuilder::new("http://localhost:4200/")
286    ///         .namespace("/admin")
287    ///         .on("error", |err, _| async move { eprintln!("Error: {:#?}", err) }.boxed())
288    ///         .tls_config(tls_connector)
289    ///         .connect()
290    ///         .await;
291    /// }
292    /// ```
293    pub fn tls_config(mut self, tls_config: TlsConnector) -> Self {
294        self.tls_config = Some(tls_config);
295        self
296    }
297
298    /// Sets custom http headers for the opening request. The headers will be passed to the underlying
299    /// transport type (either websockets or polling) and then get passed with every request thats made.
300    /// via the transport layer.
301    /// # Example
302    /// ```rust
303    /// use tf_rust_socketio::{asynchronous::ClientBuilder, Payload};
304    /// use futures_util::future::FutureExt;
305    ///
306    /// #[tokio::main]
307    /// async fn main() {
308    ///     let socket = ClientBuilder::new("http://localhost:4200/")
309    ///         .namespace("/admin")
310    ///         .on("error", |err, _| async move { eprintln!("Error: {:#?}", err) }.boxed())
311    ///         .opening_header("accept-encoding", "application/json")
312    ///         .connect()
313    ///         .await;
314    /// }
315    /// ```
316    pub fn opening_header<T: Into<HeaderValue>, K: Into<String>>(mut self, key: K, val: T) -> Self {
317        match self.opening_headers {
318            Some(ref mut map) => {
319                map.insert(key.into(), val.into());
320            }
321            None => {
322                let mut map = HeaderMap::default();
323                map.insert(key.into(), val.into());
324                self.opening_headers = Some(map);
325            }
326        }
327        self
328    }
329
330    /// Sets authentification data sent in the opening request.
331    /// # Example
332    /// ```rust
333    /// use tf_rust_socketio::{asynchronous::ClientBuilder};
334    /// use serde_json::json;
335    /// use futures_util::future::FutureExt;
336    ///
337    /// #[tokio::main]
338    /// async fn main() {
339    ///     let socket = ClientBuilder::new("http://localhost:4204/")
340    ///         .namespace("/admin")
341    ///         .auth(json!({ "password": "1337" }))
342    ///         .on("error", |err, _| async move { eprintln!("Error: {:#?}", err) }.boxed())
343    ///         .connect()
344    ///         .await;
345    /// }
346    /// ```
347    pub fn auth<T: Into<serde_json::Value>>(mut self, auth: T) -> Self {
348        self.auth = Some(auth.into());
349
350        self
351    }
352
353    /// Specifies which EngineIO [`TransportType`] to use.
354    ///
355    /// # Example
356    /// ```rust
357    /// use tf_rust_socketio::{asynchronous::ClientBuilder, TransportType};
358    ///
359    /// #[tokio::main]
360    /// async fn main() {
361    ///     let socket = ClientBuilder::new("http://localhost:4200/")
362    ///         // Use websockets to handshake and connect.
363    ///         .transport_type(TransportType::Websocket)
364    ///         .connect()
365    ///         .await
366    ///         .expect("connection failed");
367    /// }
368    /// ```
369    pub fn transport_type(mut self, transport_type: TransportType) -> Self {
370        self.transport_type = transport_type;
371
372        self
373    }
374
375    /// If set to `false` do not try to reconnect on network errors. Defaults to
376    /// `true`
377    pub fn reconnect(mut self, reconnect: bool) -> Self {
378        self.reconnect = reconnect;
379        self
380    }
381
382    /// If set to `true` try to reconnect when the server disconnects the
383    /// client. Defaults to `false`
384    pub fn reconnect_on_disconnect(mut self, reconnect_on_disconnect: bool) -> Self {
385        self.reconnect_on_disconnect = reconnect_on_disconnect;
386        self
387    }
388
389    /// Sets the minimum and maximum delay between reconnection attempts
390    pub fn reconnect_delay(mut self, min: u64, max: u64) -> Self {
391        self.reconnect_delay_min = min;
392        self.reconnect_delay_max = max;
393        self
394    }
395
396    /// Sets the maximum number of times to attempt reconnections. Defaults to
397    /// an infinite number of attempts
398    pub fn max_reconnect_attempts(mut self, reconnect_attempts: u8) -> Self {
399        self.max_reconnect_attempts = Some(reconnect_attempts);
400        self
401    }
402
403    /// Connects the socket to a certain endpoint. This returns a connected
404    /// [`Client`] instance. This method returns an [`std::result::Result::Err`]
405    /// value if something goes wrong during connection. Also starts a separate
406    /// thread to start polling for packets. Used with callbacks.
407    /// # Example
408    /// ```rust
409    /// use tf_rust_socketio::{asynchronous::ClientBuilder, Payload};
410    /// use serde_json::json;
411    /// use futures_util::future::FutureExt;
412    ///
413    /// #[tokio::main]
414    /// async fn main() {
415    ///     let mut socket = ClientBuilder::new("http://localhost:4200/")
416    ///         .namespace("/admin")
417    ///         .on("error", |err, _| async move { eprintln!("Error: {:#?}", err) }.boxed())
418    ///         .connect()
419    ///         .await
420    ///         .expect("connection failed");
421    ///
422    ///     // use the socket
423    ///     let json_payload = json!({"token": 123});
424    ///
425    ///     let result = socket.emit("foo", json_payload).await;
426    ///
427    ///     assert!(result.is_ok());
428    /// }
429    /// ```
430    pub async fn connect(self) -> Result<Client> {
431        let mut socket = self.connect_manual().await?;
432        socket.poll_stream().await?;
433
434        Ok(socket)
435    }
436
437    /// Creates a new Socket that can be used for reconnections
438    pub(crate) async fn inner_create(&self) -> Result<InnerSocket> {
439        let mut url = Url::parse(&self.address)?;
440
441        if url.path() == "/" {
442            url.set_path("/socket.io/");
443        }
444
445        let mut builder = EngineIoClientBuilder::new(url);
446
447        if let Some(tls_config) = &self.tls_config {
448            builder = builder.tls_config(tls_config.to_owned());
449        }
450        if let Some(headers) = &self.opening_headers {
451            builder = builder.headers(headers.to_owned());
452        }
453
454        let engine_client = match self.transport_type {
455            TransportType::Any => builder.build_with_fallback().await?,
456            TransportType::Polling => builder.build_polling().await?,
457            TransportType::Websocket => builder.build_websocket().await?,
458            TransportType::WebsocketUpgrade => builder.build_websocket_with_upgrade().await?,
459        };
460
461        let inner_socket = InnerSocket::new(engine_client)?;
462        Ok(inner_socket)
463    }
464
465    //TODO: 0.3.X stabilize
466    pub(crate) async fn connect_manual(self) -> Result<Client> {
467        let inner_socket = self.inner_create().await?;
468
469        let socket = Client::new(inner_socket, self)?;
470        socket.connect().await?;
471
472        Ok(socket)
473    }
474}
475
#[cfg(test)]
mod test {
    use super::*;
    use crate::error::Error;
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use tf_rust_engineio::Error as EngineError;

    /// Starts a background HTTP server on an ephemeral local port and returns its
    /// base URL. The server answers each of up to four connections with the given
    /// status and JSON body. This mirrors the engineio crate's test helper, which
    /// is `pub(crate)` there and so cannot be reused from this crate.
    fn spawn_http_error_mock(status: u16, body: &'static str) -> String {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let base_url = format!("http://127.0.0.1:{}/", listener.local_addr().unwrap().port());

        // The reply is identical for every connection, so build it once.
        let reply = format!(
            "HTTP/1.1 {status} ERR\r\nContent-Type: application/json\r\nContent-Length: {len}\r\nConnection: close\r\n\r\n{body}",
            len = body.len(),
        );

        std::thread::spawn(move || {
            // At most four `accept` calls; give up at the first accept error.
            for incoming in listener.incoming().take(4) {
                let Ok(mut conn) = incoming else {
                    break;
                };
                // Read (at most 2048 bytes of) the request before replying; I/O errors are ignored.
                let mut request = [0u8; 2048];
                let _ = conn.read(&mut request);
                let _ = conn.write_all(reply.as_bytes());
            }
        });

        base_url
    }

    /// Checks that an HTTP error with a JSON body from the Engine.IO handshake
    /// reaches callers of `ClientBuilder::connect()` unchanged. An A2C-SMCP
    /// protocol-version mismatch is one example: it returns 400 +
    /// `{"code":4008,...}`. Callers can then match on the structured error.
    #[tokio::test]
    async fn connect_surfaces_http_error_body_through_socketio_error() {
        let expected_body = r#"{"code":4008,"message":"Protocol version mismatch"}"#;
        let url = spawn_http_error_mock(400, expected_body);

        let Err(err) = ClientBuilder::new(url)
            .transport_type(TransportType::Polling)
            .connect()
            .await
        else {
            panic!("connect should fail when handshake server returns 400");
        };

        match err {
            Error::IncompleteResponseFromEngineIo(EngineError::HttpErrorWithBody {
                status,
                body,
            }) => {
                assert_eq!(status, 400);
                assert_eq!(body, expected_body);
            }
            other => {
                panic!("expected IncompleteResponseFromEngineIo(HttpErrorWithBody), got: {other:?}")
            }
        }
    }
}