socketio_rs/client/
builder.rs

1use std::sync::Arc;
2
3use super::client::{Client, Socket as ClientSocket};
4use crate::socket::RawSocket;
5use crate::{ack::AckId, socket::Socket};
6use crate::{callback::Callback, error::Result, Event, Payload};
7
8use dashmap::DashMap;
9use engineio_rs::{HeaderMap, HeaderValue, SocketBuilder as EngineSocketBuilder};
10use futures_util::future::BoxFuture;
11use tracing::trace;
12use url::Url;
13
/// Flavor of Engine.IO transport, selected via `ClientBuilder::transport_type`.
///
/// The builder defaults to [`TransportType::Any`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum TransportType {
    /// Handshakes with polling, upgrades if possible.
    Any,
    /// Handshakes with websocket. Does not use polling.
    Websocket,
    /// Handshakes with polling, errors if upgrade fails.
    WebsocketUpgrade,
    /// Handshakes with polling.
    Polling,
}
26
27/// A builder class for a `socket.io` socket. This handles setting up the client and
28/// configuring the callback, the namespace and metadata of the socket. If no
29/// namespace is specified, the default namespace `/` is taken. The `connect` method
30/// acts the `build` method and returns a connected [`Client`].
31#[derive(Clone)]
32pub struct ClientBuilder {
33    address: String,
34    on: Arc<DashMap<Event, Callback<ClientSocket>>>,
35    namespace: String,
36    opening_headers: Option<HeaderMap>,
37    transport_type: TransportType,
38    pub(crate) reconnect: bool,
39    // None reconnect attempts represent infinity.
40    pub(crate) max_reconnect_attempts: Option<usize>,
41    pub(crate) reconnect_delay_min: u64,
42    pub(crate) reconnect_delay_max: u64,
43}
44
45impl ClientBuilder {
46    /// Create as client builder from a URL. URLs must be in the form
47    /// `[ws or wss or http or https]://[domain]:[port]/[path]`. The
48    /// path of the URL is optional and if no port is given, port 80
49    /// will be used.
50    /// # Example
51    /// ```no_run
52    /// use socketio_rs::{Payload, ClientBuilder, Socket, AckId};
53    /// use serde_json::json;
54    /// use futures_util::future::FutureExt;
55    ///
56    ///
57    /// #[tokio::main]
58    /// async fn main() {
59    ///     let callback = |payload: Option<Payload>, socket: Socket, need_ack: Option<AckId>| {
60    ///         async move {
61    ///             match payload {
62    ///                 Some(Payload::Json(data)) => println!("Received: {:?}", data),
63    ///                 Some(Payload::Binary(bin)) => println!("Received bytes: {:#?}", bin),
64    ///                 Some(Payload::Multi(multi)) => println!("Received multi: {:?}", multi),
65    ///                 _ => {},
66    ///             }
67    ///         }.boxed()
68    ///     };
69    ///
70    ///     let mut socket = ClientBuilder::new("http://localhost:4200")
71    ///         .namespace("/admin")
72    ///         .on("test", callback)
73    ///         .connect()
74    ///         .await
75    ///         .expect("error while connecting");
76    ///
77    ///     // use the socket
78    ///     let json_payload = json!({"token": 123});
79    ///
80    ///     let result = socket.emit("foo", json_payload).await;
81    ///
82    ///     assert!(result.is_ok());
83    /// }
84    /// ```
85    pub fn new<T: Into<String>>(address: T) -> Self {
86        Self {
87            address: address.into(),
88            on: Default::default(),
89            namespace: "/".to_owned(),
90            opening_headers: None,
91            transport_type: TransportType::Any,
92            reconnect: true,
93            // None means infinity
94            max_reconnect_attempts: None,
95            reconnect_delay_min: 1000,
96            reconnect_delay_max: 5000,
97        }
98    }
99
100    /// Sets the target namespace of the client. The namespace should start
101    /// with a leading `/`. Valid examples are e.g. `/admin`, `/foo`.
102    /// If the String provided doesn't start with a leading `/`, it is
103    /// added manually.
104    pub fn namespace<T: Into<String>>(mut self, namespace: T) -> Self {
105        let mut nsp = namespace.into();
106        if !nsp.starts_with('/') {
107            nsp = "/".to_owned() + &nsp;
108            trace!("Added `/` to the given namespace: {}", nsp);
109        }
110        self.namespace = nsp;
111        self
112    }
113
114    /// Registers a new callback for a certain [`crate::event::Event`]. The event could either be
115    /// one of the common events like `message`, `error`, `connect`, `close` or a custom
116    /// event defined by a string, e.g. `onPayment` or `foo`.
117    ///
118    /// # Example
119    /// ```rust
120    /// use socketio_rs::{ClientBuilder, Payload};
121    /// use futures_util::FutureExt;
122    ///
123    ///  #[tokio::main]
124    /// async fn main() {
125    ///     let socket = ClientBuilder::new("http://localhost:4200/")
126    ///         .namespace("/admin")
127    ///         .on("test", |payload: Option<Payload>, _, _| {
128    ///             async move {
129    ///                 match payload {
130    ///                     Some(Payload::Json(data)) => println!("Received: {:?}", data),
131    ///                     Some(Payload::Binary(bin)) => println!("Received bytes: {:#?}", bin),
132    ///                     Some(Payload::Multi(multi)) => println!("Received multi: {:?}", multi),
133    ///                     _ => {},
134    ///                 }
135    ///             }
136    ///             .boxed()
137    ///         })
138    ///         .on("error", |err, _, _| async move { eprintln!("Error: {:#?}", err) }.boxed())
139    ///         .connect()
140    ///         .await;
141    /// }
142    /// ```
143    ///
144    /// # Issues with type inference for the callback method
145    ///
146    /// Currently stable Rust does not contain types like `AsyncFnMut`.
147    /// That is why this library uses the type `FnMut(..) -> BoxFuture<_>`,
148    /// which basically represents a closure or function that returns a
149    /// boxed future that can be executed in an async executor.
150    /// The complicated constraints for the callback function
151    /// bring the Rust compiler to it's limits, resulting in confusing error
152    /// messages when passing in a variable that holds a closure (to the `on` method).
153    /// In order to make sure type inference goes well, the [`futures_util::FutureExt::boxed`]
154    /// method can be used on an async block (the future) to make sure the return type
155    /// is conform with the generic requirements. An example can be found here:
156    ///
157    /// ```rust
158    /// use socketio_rs::{ClientBuilder, Payload};
159    /// use futures_util::FutureExt;
160    ///
161    /// #[tokio::main]
162    /// async fn main() {
163    ///     let callback = |payload: Option<Payload>, _, _| {
164    ///             async move {
165    ///                 match payload {
166    ///                     Some(Payload::Json(data)) => println!("Received: {:?}", data),
167    ///                     Some(Payload::Binary(bin)) => println!("Received bytes: {:#?}", bin),
168    ///                     Some(Payload::Multi(multi)) => println!("Received multi: {:?}", multi),
169    ///                     _ => {},
170    ///                 }
171    ///             }
172    ///             .boxed() // <-- this makes sure we end up with a `BoxFuture<_>`
173    ///         };
174    ///
175    ///     let socket = ClientBuilder::new("http://localhost:4200/")
176    ///         .namespace("/admin")
177    ///         .on("test", callback)
178    ///         .connect()
179    ///         .await;
180    /// }
181    /// ```
182    ///
183    pub fn on<T: Into<Event>, F>(self, event: T, callback: F) -> Self
184    where
185        F: for<'a> std::ops::FnMut(
186                Option<Payload>,
187                ClientSocket,
188                Option<AckId>,
189            ) -> BoxFuture<'static, ()>
190            + 'static
191            + Send
192            + Sync,
193    {
194        let callback = Callback::new(callback);
195        let event = event.into();
196        // SAFETY: Lock is held for such amount of time no code paths lead to a panic while lock is held
197        let on = self.on.clone();
198        tokio::spawn(async move {
199            on.insert(event, callback);
200        });
201        self
202    }
203
204    /// Sets custom http headers for the opening request. The headers will be passed to the underlying
205    /// transport type (either websockets or polling) and then get passed with every request thats made.
206    /// via the transport layer.
207    /// # Example
208    /// ```rust
209    /// use socketio_rs::{ClientBuilder, Payload};
210    /// use futures_util::future::FutureExt;
211    ///
212    /// #[tokio::main]
213    /// async fn main() {
214    ///     let socket = ClientBuilder::new("http://localhost:4200/")
215    ///         .namespace("/admin")
216    ///         .on("error", |err, _, _| async move { eprintln!("Error: {:#?}", err) }.boxed())
217    ///         .opening_header("accept-encoding", "application/json")
218    ///         .connect()
219    ///         .await;
220    /// }
221    /// ```
222    pub fn opening_header<T: Into<HeaderValue>, K: Into<String>>(mut self, key: K, val: T) -> Self {
223        match self.opening_headers {
224            Some(ref mut map) => {
225                map.insert(key.into(), val.into());
226            }
227            None => {
228                let mut map = HeaderMap::default();
229                map.insert(key.into(), val.into());
230                self.opening_headers = Some(map);
231            }
232        }
233        self
234    }
235
236    /// Specifies which EngineIO [`TransportType`] to use.
237    ///
238    /// # Example
239    /// ```no_run
240    /// use socketio_rs::{ClientBuilder, TransportType};
241    ///
242    /// #[tokio::main]
243    /// async fn main() {
244    ///     let socket = ClientBuilder::new("http://localhost:4200/")
245    ///         // Use websockets to handshake and connect.
246    ///         .transport_type(TransportType::Websocket)
247    ///         .connect()
248    ///         .await
249    ///         .expect("connection failed");
250    /// }
251    /// ```
252    pub fn transport_type(mut self, transport_type: TransportType) -> Self {
253        self.transport_type = transport_type;
254
255        self
256    }
257
258    /// Connects the socket to a certain endpoint. This returns a connected
259    /// [`Client`] instance. This method returns an [`std::result::Result::Err`]
260    /// value if something goes wrong during connection. Also starts a separate
261    /// thread to start polling for packets. Used with callbacks.
262    /// # Example
263    /// ```no_run
264    /// use socketio_rs::{ClientBuilder, Payload};
265    /// use serde_json::json;
266    /// use futures_util::future::FutureExt;
267    ///
268    /// #[tokio::main]
269    /// async fn main() {
270    ///     let mut socket = ClientBuilder::new("http://localhost:4200/")
271    ///         .namespace("/admin")
272    ///         .on("error", |err, _, _| async move { eprintln!("Error: {:#?}", err) }.boxed())
273    ///         .connect()
274    ///         .await
275    ///         .expect("connection failed");
276    ///
277    ///     // use the socket
278    ///     let json_payload = json!({"token": 123});
279    ///
280    ///     let result = socket.emit("foo", json_payload).await;
281    ///
282    ///     assert!(result.is_ok());
283    /// }
284    /// ```
285    pub async fn connect(self) -> Result<Client> {
286        let client = Client::new(self).await;
287        if let Ok(c) = &client {
288            c.poll_callback();
289        }
290        client
291    }
292
293    pub fn reconnect(mut self, reconnect: bool) -> Self {
294        self.reconnect = reconnect;
295        self
296    }
297
298    pub fn reconnect_delay(mut self, min: u64, max: u64) -> Self {
299        self.reconnect_delay_min = min;
300        self.reconnect_delay_max = max;
301
302        self
303    }
304
305    pub fn max_reconnect_attempts(mut self, reconnect_attempts: usize) -> Self {
306        self.max_reconnect_attempts = Some(reconnect_attempts);
307        self
308    }
309
310    #[cfg(test)]
311    pub(crate) async fn connect_client(self) -> Result<Client> {
312        Client::new(self.clone()).await
313    }
314
315    pub(crate) async fn connect_socket(&self) -> Result<Socket<ClientSocket>> {
316        // Parse url here rather than in new to keep new returning Self.
317        let mut url = Url::parse(&self.address)?;
318
319        if url.path() == "/" {
320            url.set_path("/socket.io/");
321        }
322
323        let mut builder = EngineSocketBuilder::new(url);
324
325        if let Some(headers) = &self.opening_headers {
326            builder = builder.headers(headers.clone());
327        }
328
329        let engine_client = match self.transport_type {
330            TransportType::Any => builder.build_with_fallback().await?,
331            TransportType::Polling => builder.build_polling().await?,
332            TransportType::Websocket => builder.build_websocket().await?,
333            TransportType::WebsocketUpgrade => builder.build_websocket_with_upgrade().await?,
334        };
335
336        let inner_socket = RawSocket::client_end(engine_client);
337        let socket = Socket::<ClientSocket>::new(
338            inner_socket,
339            self.namespace.clone(),
340            self.on.clone(),
341            Arc::new(|s| s.into()),
342        );
343
344        socket.connect().await?;
345        Ok(socket)
346    }
347}