socketio_rs/client/builder.rs
1use std::sync::Arc;
2
3use super::client::{Client, Socket as ClientSocket};
4use crate::socket::RawSocket;
5use crate::{ack::AckId, socket::Socket};
6use crate::{callback::Callback, error::Result, Event, Payload};
7
8use dashmap::DashMap;
9use engineio_rs::{HeaderMap, HeaderValue, SocketBuilder as EngineSocketBuilder};
10use futures_util::future::BoxFuture;
11use tracing::trace;
12use url::Url;
13
/// Flavor of Engine.IO transport a [`ClientBuilder`] connects with.
///
/// [`ClientBuilder::new`] starts out with [`TransportType::Any`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TransportType {
    /// Handshakes with polling, upgrades if possible.
    Any,
    /// Handshakes with websocket. Does not use polling.
    Websocket,
    /// Handshakes with polling, errors if upgrade fails.
    WebsocketUpgrade,
    /// Handshakes with polling.
    Polling,
}
26
27/// A builder class for a `socket.io` socket. This handles setting up the client and
28/// configuring the callback, the namespace and metadata of the socket. If no
29/// namespace is specified, the default namespace `/` is taken. The `connect` method
30/// acts the `build` method and returns a connected [`Client`].
31#[derive(Clone)]
32pub struct ClientBuilder {
33 address: String,
34 on: Arc<DashMap<Event, Callback<ClientSocket>>>,
35 namespace: String,
36 opening_headers: Option<HeaderMap>,
37 transport_type: TransportType,
38 pub(crate) reconnect: bool,
39 // None reconnect attempts represent infinity.
40 pub(crate) max_reconnect_attempts: Option<usize>,
41 pub(crate) reconnect_delay_min: u64,
42 pub(crate) reconnect_delay_max: u64,
43}
44
45impl ClientBuilder {
46 /// Create as client builder from a URL. URLs must be in the form
47 /// `[ws or wss or http or https]://[domain]:[port]/[path]`. The
48 /// path of the URL is optional and if no port is given, port 80
49 /// will be used.
50 /// # Example
51 /// ```no_run
52 /// use socketio_rs::{Payload, ClientBuilder, Socket, AckId};
53 /// use serde_json::json;
54 /// use futures_util::future::FutureExt;
55 ///
56 ///
57 /// #[tokio::main]
58 /// async fn main() {
59 /// let callback = |payload: Option<Payload>, socket: Socket, need_ack: Option<AckId>| {
60 /// async move {
61 /// match payload {
62 /// Some(Payload::Json(data)) => println!("Received: {:?}", data),
63 /// Some(Payload::Binary(bin)) => println!("Received bytes: {:#?}", bin),
64 /// Some(Payload::Multi(multi)) => println!("Received multi: {:?}", multi),
65 /// _ => {},
66 /// }
67 /// }.boxed()
68 /// };
69 ///
70 /// let mut socket = ClientBuilder::new("http://localhost:4200")
71 /// .namespace("/admin")
72 /// .on("test", callback)
73 /// .connect()
74 /// .await
75 /// .expect("error while connecting");
76 ///
77 /// // use the socket
78 /// let json_payload = json!({"token": 123});
79 ///
80 /// let result = socket.emit("foo", json_payload).await;
81 ///
82 /// assert!(result.is_ok());
83 /// }
84 /// ```
85 pub fn new<T: Into<String>>(address: T) -> Self {
86 Self {
87 address: address.into(),
88 on: Default::default(),
89 namespace: "/".to_owned(),
90 opening_headers: None,
91 transport_type: TransportType::Any,
92 reconnect: true,
93 // None means infinity
94 max_reconnect_attempts: None,
95 reconnect_delay_min: 1000,
96 reconnect_delay_max: 5000,
97 }
98 }
99
100 /// Sets the target namespace of the client. The namespace should start
101 /// with a leading `/`. Valid examples are e.g. `/admin`, `/foo`.
102 /// If the String provided doesn't start with a leading `/`, it is
103 /// added manually.
104 pub fn namespace<T: Into<String>>(mut self, namespace: T) -> Self {
105 let mut nsp = namespace.into();
106 if !nsp.starts_with('/') {
107 nsp = "/".to_owned() + &nsp;
108 trace!("Added `/` to the given namespace: {}", nsp);
109 }
110 self.namespace = nsp;
111 self
112 }
113
114 /// Registers a new callback for a certain [`crate::event::Event`]. The event could either be
115 /// one of the common events like `message`, `error`, `connect`, `close` or a custom
116 /// event defined by a string, e.g. `onPayment` or `foo`.
117 ///
118 /// # Example
119 /// ```rust
120 /// use socketio_rs::{ClientBuilder, Payload};
121 /// use futures_util::FutureExt;
122 ///
123 /// #[tokio::main]
124 /// async fn main() {
125 /// let socket = ClientBuilder::new("http://localhost:4200/")
126 /// .namespace("/admin")
127 /// .on("test", |payload: Option<Payload>, _, _| {
128 /// async move {
129 /// match payload {
130 /// Some(Payload::Json(data)) => println!("Received: {:?}", data),
131 /// Some(Payload::Binary(bin)) => println!("Received bytes: {:#?}", bin),
132 /// Some(Payload::Multi(multi)) => println!("Received multi: {:?}", multi),
133 /// _ => {},
134 /// }
135 /// }
136 /// .boxed()
137 /// })
138 /// .on("error", |err, _, _| async move { eprintln!("Error: {:#?}", err) }.boxed())
139 /// .connect()
140 /// .await;
141 /// }
142 /// ```
143 ///
144 /// # Issues with type inference for the callback method
145 ///
146 /// Currently stable Rust does not contain types like `AsyncFnMut`.
147 /// That is why this library uses the type `FnMut(..) -> BoxFuture<_>`,
148 /// which basically represents a closure or function that returns a
149 /// boxed future that can be executed in an async executor.
150 /// The complicated constraints for the callback function
151 /// bring the Rust compiler to it's limits, resulting in confusing error
152 /// messages when passing in a variable that holds a closure (to the `on` method).
153 /// In order to make sure type inference goes well, the [`futures_util::FutureExt::boxed`]
154 /// method can be used on an async block (the future) to make sure the return type
155 /// is conform with the generic requirements. An example can be found here:
156 ///
157 /// ```rust
158 /// use socketio_rs::{ClientBuilder, Payload};
159 /// use futures_util::FutureExt;
160 ///
161 /// #[tokio::main]
162 /// async fn main() {
163 /// let callback = |payload: Option<Payload>, _, _| {
164 /// async move {
165 /// match payload {
166 /// Some(Payload::Json(data)) => println!("Received: {:?}", data),
167 /// Some(Payload::Binary(bin)) => println!("Received bytes: {:#?}", bin),
168 /// Some(Payload::Multi(multi)) => println!("Received multi: {:?}", multi),
169 /// _ => {},
170 /// }
171 /// }
172 /// .boxed() // <-- this makes sure we end up with a `BoxFuture<_>`
173 /// };
174 ///
175 /// let socket = ClientBuilder::new("http://localhost:4200/")
176 /// .namespace("/admin")
177 /// .on("test", callback)
178 /// .connect()
179 /// .await;
180 /// }
181 /// ```
182 ///
183 pub fn on<T: Into<Event>, F>(self, event: T, callback: F) -> Self
184 where
185 F: for<'a> std::ops::FnMut(
186 Option<Payload>,
187 ClientSocket,
188 Option<AckId>,
189 ) -> BoxFuture<'static, ()>
190 + 'static
191 + Send
192 + Sync,
193 {
194 let callback = Callback::new(callback);
195 let event = event.into();
196 // SAFETY: Lock is held for such amount of time no code paths lead to a panic while lock is held
197 let on = self.on.clone();
198 tokio::spawn(async move {
199 on.insert(event, callback);
200 });
201 self
202 }
203
204 /// Sets custom http headers for the opening request. The headers will be passed to the underlying
205 /// transport type (either websockets or polling) and then get passed with every request thats made.
206 /// via the transport layer.
207 /// # Example
208 /// ```rust
209 /// use socketio_rs::{ClientBuilder, Payload};
210 /// use futures_util::future::FutureExt;
211 ///
212 /// #[tokio::main]
213 /// async fn main() {
214 /// let socket = ClientBuilder::new("http://localhost:4200/")
215 /// .namespace("/admin")
216 /// .on("error", |err, _, _| async move { eprintln!("Error: {:#?}", err) }.boxed())
217 /// .opening_header("accept-encoding", "application/json")
218 /// .connect()
219 /// .await;
220 /// }
221 /// ```
222 pub fn opening_header<T: Into<HeaderValue>, K: Into<String>>(mut self, key: K, val: T) -> Self {
223 match self.opening_headers {
224 Some(ref mut map) => {
225 map.insert(key.into(), val.into());
226 }
227 None => {
228 let mut map = HeaderMap::default();
229 map.insert(key.into(), val.into());
230 self.opening_headers = Some(map);
231 }
232 }
233 self
234 }
235
236 /// Specifies which EngineIO [`TransportType`] to use.
237 ///
238 /// # Example
239 /// ```no_run
240 /// use socketio_rs::{ClientBuilder, TransportType};
241 ///
242 /// #[tokio::main]
243 /// async fn main() {
244 /// let socket = ClientBuilder::new("http://localhost:4200/")
245 /// // Use websockets to handshake and connect.
246 /// .transport_type(TransportType::Websocket)
247 /// .connect()
248 /// .await
249 /// .expect("connection failed");
250 /// }
251 /// ```
252 pub fn transport_type(mut self, transport_type: TransportType) -> Self {
253 self.transport_type = transport_type;
254
255 self
256 }
257
258 /// Connects the socket to a certain endpoint. This returns a connected
259 /// [`Client`] instance. This method returns an [`std::result::Result::Err`]
260 /// value if something goes wrong during connection. Also starts a separate
261 /// thread to start polling for packets. Used with callbacks.
262 /// # Example
263 /// ```no_run
264 /// use socketio_rs::{ClientBuilder, Payload};
265 /// use serde_json::json;
266 /// use futures_util::future::FutureExt;
267 ///
268 /// #[tokio::main]
269 /// async fn main() {
270 /// let mut socket = ClientBuilder::new("http://localhost:4200/")
271 /// .namespace("/admin")
272 /// .on("error", |err, _, _| async move { eprintln!("Error: {:#?}", err) }.boxed())
273 /// .connect()
274 /// .await
275 /// .expect("connection failed");
276 ///
277 /// // use the socket
278 /// let json_payload = json!({"token": 123});
279 ///
280 /// let result = socket.emit("foo", json_payload).await;
281 ///
282 /// assert!(result.is_ok());
283 /// }
284 /// ```
285 pub async fn connect(self) -> Result<Client> {
286 let client = Client::new(self).await;
287 if let Ok(c) = &client {
288 c.poll_callback();
289 }
290 client
291 }
292
293 pub fn reconnect(mut self, reconnect: bool) -> Self {
294 self.reconnect = reconnect;
295 self
296 }
297
298 pub fn reconnect_delay(mut self, min: u64, max: u64) -> Self {
299 self.reconnect_delay_min = min;
300 self.reconnect_delay_max = max;
301
302 self
303 }
304
305 pub fn max_reconnect_attempts(mut self, reconnect_attempts: usize) -> Self {
306 self.max_reconnect_attempts = Some(reconnect_attempts);
307 self
308 }
309
310 #[cfg(test)]
311 pub(crate) async fn connect_client(self) -> Result<Client> {
312 Client::new(self.clone()).await
313 }
314
315 pub(crate) async fn connect_socket(&self) -> Result<Socket<ClientSocket>> {
316 // Parse url here rather than in new to keep new returning Self.
317 let mut url = Url::parse(&self.address)?;
318
319 if url.path() == "/" {
320 url.set_path("/socket.io/");
321 }
322
323 let mut builder = EngineSocketBuilder::new(url);
324
325 if let Some(headers) = &self.opening_headers {
326 builder = builder.headers(headers.clone());
327 }
328
329 let engine_client = match self.transport_type {
330 TransportType::Any => builder.build_with_fallback().await?,
331 TransportType::Polling => builder.build_polling().await?,
332 TransportType::Websocket => builder.build_websocket().await?,
333 TransportType::WebsocketUpgrade => builder.build_websocket_with_upgrade().await?,
334 };
335
336 let inner_socket = RawSocket::client_end(engine_client);
337 let socket = Socket::<ClientSocket>::new(
338 inner_socket,
339 self.namespace.clone(),
340 self.on.clone(),
341 Arc::new(|s| s.into()),
342 );
343
344 socket.connect().await?;
345 Ok(socket)
346 }
347}