Skip to main content

tf_rust_socketio/client/
builder.rs

1use super::super::{event::Event, payload::Payload};
2use super::callback::Callback;
3use super::client::Client;
4use crate::RawClient;
5use native_tls::TlsConnector;
6use tf_rust_engineio::client::ClientBuilder as EngineIoClientBuilder;
7use tf_rust_engineio::header::{HeaderMap, HeaderValue};
8use url::Url;
9
10use crate::client::callback::{SocketAnyCallback, SocketCallback};
11use crate::error::Result;
12use std::collections::HashMap;
13use std::sync::{Arc, Mutex};
14
15use crate::socket::Socket as InnerSocket;
16
/// Flavor of Engine.IO transport used by [`ClientBuilder::transport_type`].
///
/// Each variant selects the Engine.IO client builder method used when
/// connecting (see [`ClientBuilder::connect_raw`]).
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum TransportType {
    /// Handshakes with polling, upgrades if possible
    /// (built with `build_with_fallback`).
    Any,
    /// Handshakes with websocket. Does not use polling
    /// (built with `build_websocket`).
    Websocket,
    /// Handshakes with polling, errors if upgrade fails
    /// (built with `build_websocket_with_upgrade`).
    WebsocketUpgrade,
    /// Handshakes with polling (built with `build_polling`).
    Polling,
}
29
30/// A builder class for a `socket.io` socket. This handles setting up the client and
31/// configuring the callback, the namespace and metadata of the socket. If no
32/// namespace is specified, the default namespace `/` is taken. The `connect` method
33/// acts the `build` method and returns a connected [`Client`].
34#[derive(Clone)]
35pub struct ClientBuilder {
36    pub(crate) address: String,
37    on: Arc<Mutex<HashMap<Event, Callback<SocketCallback>>>>,
38    on_any: Arc<Mutex<Option<Callback<SocketAnyCallback>>>>,
39    namespace: String,
40    tls_config: Option<TlsConnector>,
41    opening_headers: Option<HeaderMap>,
42    transport_type: TransportType,
43    auth: Option<serde_json::Value>,
44    pub(crate) reconnect: bool,
45    pub(crate) reconnect_on_disconnect: bool,
46    // None reconnect attempts represent infinity.
47    pub(crate) max_reconnect_attempts: Option<u8>,
48    pub(crate) reconnect_delay_min: u64,
49    pub(crate) reconnect_delay_max: u64,
50}
51
52impl ClientBuilder {
53    /// Create as client builder from a URL. URLs must be in the form
54    /// `[ws or wss or http or https]://[domain]:[port]/[path]`. The
55    /// path of the URL is optional and if no port is given, port 80
56    /// will be used.
57    /// # Example
58    /// ```rust
59    /// use tf_rust_socketio::{ClientBuilder, Payload, RawClient};
60    /// use serde_json::json;
61    ///
62    ///
63    /// let callback = |payload: Payload, socket: RawClient| {
64    ///            match payload {
65    ///                Payload::Text(values, _) => println!("Received: {:#?}", values),
66    ///                Payload::Binary(bin_data, _) => println!("Received bytes: {:#?}", bin_data),
67    ///                // This payload type is deprecated, use Payload::Text instead
68    ///                #[allow(deprecated)]
69    ///                Payload::String(str, _) => println!("Received: {}", str),
70    ///            }
71    /// };
72    ///
73    /// let mut socket = ClientBuilder::new("http://localhost:4200")
74    ///     .namespace("/admin")
75    ///     .on("test", callback)
76    ///     .connect()
77    ///     .expect("error while connecting");
78    ///
79    /// // use the socket
80    /// let json_payload = json!({"token": 123});
81    ///
82    /// let result = socket.emit("foo", json_payload);
83    ///
84    /// assert!(result.is_ok());
85    /// ```
86    pub fn new<T: Into<String>>(address: T) -> Self {
87        Self {
88            address: address.into(),
89            on: Arc::new(Mutex::new(HashMap::new())),
90            on_any: Arc::new(Mutex::new(None)),
91            namespace: "/".to_owned(),
92            tls_config: None,
93            opening_headers: None,
94            transport_type: TransportType::Any,
95            auth: None,
96            reconnect: true,
97            reconnect_on_disconnect: false,
98            // None means infinity
99            max_reconnect_attempts: None,
100            reconnect_delay_min: 1000,
101            reconnect_delay_max: 5000,
102        }
103    }
104
105    /// Sets the target namespace of the client. The namespace should start
106    /// with a leading `/`. Valid examples are e.g. `/admin`, `/foo`.
107    pub fn namespace<T: Into<String>>(mut self, namespace: T) -> Self {
108        let mut nsp = namespace.into();
109        if !nsp.starts_with('/') {
110            nsp = "/".to_owned() + &nsp;
111        }
112        self.namespace = nsp;
113        self
114    }
115
116    pub fn reconnect(mut self, reconnect: bool) -> Self {
117        self.reconnect = reconnect;
118        self
119    }
120
121    /// If set to `true` automatically set try to reconnect when the server
122    /// disconnects the client.
123    /// Defaults to `false`.
124    ///
125    /// # Example
126    /// ```rust
127    /// use tf_rust_socketio::ClientBuilder;
128    ///
129    /// let socket = ClientBuilder::new("http://localhost:4200/")
130    ///     .reconnect_on_disconnect(true)
131    ///     .connect();
132    /// ```
133    pub fn reconnect_on_disconnect(mut self, reconnect_on_disconnect: bool) -> Self {
134        self.reconnect_on_disconnect = reconnect_on_disconnect;
135        self
136    }
137
138    pub fn reconnect_delay(mut self, min: u64, max: u64) -> Self {
139        self.reconnect_delay_min = min;
140        self.reconnect_delay_max = max;
141
142        self
143    }
144
145    pub fn max_reconnect_attempts(mut self, reconnect_attempts: u8) -> Self {
146        self.max_reconnect_attempts = Some(reconnect_attempts);
147        self
148    }
149
150    /// Registers a new callback for a certain [`crate::event::Event`]. The event could either be
151    /// one of the common events like `message`, `error`, `open`, `close` or a custom
152    /// event defined by a string, e.g. `onPayment` or `foo`.
153    ///
154    /// # Example
155    /// ```rust
156    /// use tf_rust_socketio::{ClientBuilder, Payload};
157    ///
158    /// let socket = ClientBuilder::new("http://localhost:4200/")
159    ///     .namespace("/admin")
160    ///     .on("test", |payload: Payload, _| {
161    ///            match payload {
162    ///                Payload::Text(values, _) => println!("Received: {:#?}", values),
163    ///                Payload::Binary(bin_data, _) => println!("Received bytes: {:#?}", bin_data),
164    ///                // This payload type is deprecated, use Payload::Text instead
165    ///                #[allow(deprecated)]
166    ///                Payload::String(str, _) => println!("Received: {}", str),
167    ///            }
168    ///     })
169    ///     .on("error", |err, _| eprintln!("Error: {:#?}", err))
170    ///     .connect();
171    ///
172    /// ```
173    // While present implementation doesn't require mut, it's reasonable to require mutability.
174    #[allow(unused_mut)]
175    pub fn on<T: Into<Event>, F>(mut self, event: T, callback: F) -> Self
176    where
177        F: FnMut(Payload, RawClient) + 'static + Send,
178    {
179        let callback = Callback::<SocketCallback>::new(callback);
180        // SAFETY: Lock is held for such amount of time no code paths lead to a panic while lock is held
181        self.on.lock().unwrap().insert(event.into(), callback);
182        self
183    }
184
185    /// Registers a Callback for all [`crate::event::Event::Custom`] and [`crate::event::Event::Message`].
186    ///
187    /// # Example
188    /// ```rust
189    /// use tf_rust_socketio::{ClientBuilder, Payload};
190    ///
191    /// let client = ClientBuilder::new("http://localhost:4200/")
192    ///     .namespace("/admin")
193    ///     .on_any(|event, payload, _client| {
194    ///         #[allow(deprecated)]
195    ///         if let Payload::String(str, _) = payload {
196    ///           println!("{} {}", String::from(event), str);
197    ///         }
198    ///     })
199    ///     .connect();
200    ///
201    /// ```
202    // While present implementation doesn't require mut, it's reasonable to require mutability.
203    #[allow(unused_mut)]
204    pub fn on_any<F>(mut self, callback: F) -> Self
205    where
206        F: FnMut(Event, Payload, RawClient) + 'static + Send,
207    {
208        let callback = Some(Callback::<SocketAnyCallback>::new(callback));
209        // SAFETY: Lock is held for such amount of time no code paths lead to a panic while lock is held
210        *self.on_any.lock().unwrap() = callback;
211        self
212    }
213
214    /// Uses a preconfigured TLS connector for secure communication. This configures
215    /// both the `polling` as well as the `websocket` transport type.
216    /// # Example
217    /// ```rust
218    /// use tf_rust_socketio::{ClientBuilder, Payload};
219    /// use native_tls::TlsConnector;
220    ///
221    /// let tls_connector =  TlsConnector::builder()
222    ///            .use_sni(true)
223    ///            .build()
224    ///            .expect("Found illegal configuration");
225    ///
226    /// let socket = ClientBuilder::new("http://localhost:4200/")
227    ///     .namespace("/admin")
228    ///     .on("error", |err, _| eprintln!("Error: {:#?}", err))
229    ///     .tls_config(tls_connector)
230    ///     .connect();
231    ///
232    /// ```
233    pub fn tls_config(mut self, tls_config: TlsConnector) -> Self {
234        self.tls_config = Some(tls_config);
235        self
236    }
237
238    /// Sets custom http headers for the opening request. The headers will be passed to the underlying
239    /// transport type (either websockets or polling) and then get passed with every request thats made.
240    /// via the transport layer.
241    /// # Example
242    /// ```rust
243    /// use tf_rust_socketio::{ClientBuilder, Payload};
244    ///
245    ///
246    /// let socket = ClientBuilder::new("http://localhost:4200/")
247    ///     .namespace("/admin")
248    ///     .on("error", |err, _| eprintln!("Error: {:#?}", err))
249    ///     .opening_header("accept-encoding", "application/json")
250    ///     .connect();
251    ///
252    /// ```
253    pub fn opening_header<T: Into<HeaderValue>, K: Into<String>>(mut self, key: K, val: T) -> Self {
254        match self.opening_headers {
255            Some(ref mut map) => {
256                map.insert(key.into(), val.into());
257            }
258            None => {
259                let mut map = HeaderMap::default();
260                map.insert(key.into(), val.into());
261                self.opening_headers = Some(map);
262            }
263        }
264        self
265    }
266
267    /// Sets data sent in the opening request.
268    /// # Example
269    /// ```rust
270    /// use tf_rust_socketio::{ClientBuilder};
271    /// use serde_json::json;
272    ///
273    /// let socket = ClientBuilder::new("http://localhost:4204/")
274    ///     .namespace("/admin")
275    ///     .auth(json!({ "password": "1337" }))
276    ///     .on("error", |err, _| eprintln!("Error: {:#?}", err))
277    ///     .connect()
278    ///     .expect("Connection error");
279    ///
280    /// ```
281    pub fn auth(mut self, auth: serde_json::Value) -> Self {
282        self.auth = Some(auth);
283
284        self
285    }
286
287    /// Specifies which EngineIO [`TransportType`] to use.
288    /// # Example
289    /// ```rust
290    /// use tf_rust_socketio::{ClientBuilder, TransportType};
291    /// use serde_json::json;
292    ///
293    /// let socket = ClientBuilder::new("http://localhost:4200/")
294    ///     // Use websockets to handshake and connect.
295    ///     .transport_type(TransportType::Websocket)
296    ///     .connect()
297    ///     .expect("connection failed");
298    ///
299    /// // use the socket
300    /// let json_payload = json!({"token": 123});
301    ///
302    /// let result = socket.emit("foo", json_payload);
303    ///
304    /// assert!(result.is_ok());
305    /// ```
306    pub fn transport_type(mut self, transport_type: TransportType) -> Self {
307        self.transport_type = transport_type;
308
309        self
310    }
311
312    /// Connects the socket to a certain endpoint. This returns a connected
313    /// [`Client`] instance. This method returns an [`std::result::Result::Err`]
314    /// value if something goes wrong during connection. Also starts a separate
315    /// thread to start polling for packets. Used with callbacks.
316    /// # Example
317    /// ```rust
318    /// use tf_rust_socketio::{ClientBuilder, Payload};
319    /// use serde_json::json;
320    ///
321    ///
322    /// let mut socket = ClientBuilder::new("http://localhost:4200/")
323    ///     .namespace("/admin")
324    ///     .on("error", |err, _| eprintln!("Client error!: {:#?}", err))
325    ///     .connect()
326    ///     .expect("connection failed");
327    ///
328    /// // use the socket
329    /// let json_payload = json!({"token": 123});
330    ///
331    /// let result = socket.emit("foo", json_payload);
332    ///
333    /// assert!(result.is_ok());
334    /// ```
335    pub fn connect(self) -> Result<Client> {
336        Client::new(self)
337    }
338
339    pub fn connect_raw(self) -> Result<RawClient> {
340        // Parse url here rather than in new to keep new returning Self.
341        let mut url = Url::parse(&self.address)?;
342
343        if url.path() == "/" {
344            url.set_path("/socket.io/");
345        }
346
347        let mut builder = EngineIoClientBuilder::new(url);
348
349        if let Some(tls_config) = self.tls_config {
350            builder = builder.tls_config(tls_config);
351        }
352        if let Some(headers) = self.opening_headers {
353            builder = builder.headers(headers);
354        }
355
356        let engine_client = match self.transport_type {
357            TransportType::Any => builder.build_with_fallback()?,
358            TransportType::Polling => builder.build_polling()?,
359            TransportType::Websocket => builder.build_websocket()?,
360            TransportType::WebsocketUpgrade => builder.build_websocket_with_upgrade()?,
361        };
362
363        let inner_socket = InnerSocket::new(engine_client)?;
364
365        let socket = RawClient::new(
366            inner_socket,
367            &self.namespace,
368            self.on,
369            self.on_any,
370            self.auth,
371        )?;
372        socket.connect()?;
373
374        Ok(socket)
375    }
376}