socketio_rs/client/
client.rs

1use std::{
2    ops::{Deref, DerefMut},
3    sync::Arc,
4    time::Duration,
5};
6
7use crate::{
8    socket::Socket as InnerSocket, AckId, ClientBuilder, Error, Event, Packet, Payload, Result,
9};
10
11use backoff::{backoff::Backoff, ExponentialBackoff, ExponentialBackoffBuilder};
12use futures_util::future::BoxFuture;
13use tokio::sync::RwLock;
14use tracing::{trace, warn};
15
16#[derive(Clone)]
17pub struct Client {
18    builder: ClientBuilder,
19    socket: Arc<RwLock<InnerSocket<Socket>>>,
20    backoff: ExponentialBackoff,
21    connected: Arc<RwLock<bool>>,
22}
23
24#[derive(Clone)]
25pub struct Socket {
26    pub(crate) socket: InnerSocket<Self>,
27}
28
29impl From<InnerSocket<Socket>> for Socket {
30    fn from(socket: InnerSocket<Socket>) -> Self {
31        Self { socket }
32    }
33}
34
35impl Client {
36    /// Sends a message to the server using the underlying `engine.io` protocol.
37    /// This message takes an event, which could either be one of the common
38    /// events like "message" or "error" or a custom event like "foo". But be
39    /// careful, the data string needs to be valid JSON. It's recommended to use
40    /// a library like `serde_json` to serialize the data properly.
41    ///
42    /// # Example
43    /// ```no_run
44    /// use socketio_rs::{ClientBuilder, Socket, AckId, Payload};
45    /// use serde_json::json;
46    /// use futures_util::FutureExt;
47    ///
48    /// #[tokio::main]
49    /// async fn main() {
50    ///     let mut socket = ClientBuilder::new("http://localhost:4200/")
51    ///         .on("test", |payload: Option<Payload>, socket: Socket, need_ack: Option<AckId>| {
52    ///             async move {
53    ///                 println!("Received: {:?}", payload);
54    ///                 socket.emit("test", json!({"hello": true})).await.expect("Server unreachable");
55    ///             }.boxed()
56    ///         })
57    ///         .connect()
58    ///         .await
59    ///         .expect("connection failed");
60    ///
61    ///     let json_payload = json!({"token": 123});
62    ///
63    ///     let result = socket.emit("foo", json_payload).await;
64    ///
65    ///     assert!(result.is_ok());
66    /// }
67    /// ```
68    #[inline]
69    pub async fn emit<E, D>(&self, event: E, data: D) -> Result<()>
70    where
71        E: Into<Event>,
72        D: Into<Payload>,
73    {
74        let socket = self.socket.read().await;
75        socket.emit(event, data).await
76    }
77
78    /// Sends a message to the server but `alloc`s an `ack` to check whether the
79    /// server responded in a given time span. This message takes an event, which
80    /// could either be one of the common events like "message" or "error" or a
81    /// custom event like "foo", as well as a data parameter.
82    /// It also requires a timeout `Duration` in which the client needs to answer.
83    /// If the ack is acked in the correct time span, the specified callback is
84    /// called. The callback consumes a [`Payload`] which represents the data send
85    /// by the server.
86    ///
87    /// Please note that the requirements on the provided callbacks are similar to the ones
88    /// for [`crate::asynchronous::ClientBuilder::on`].
89    /// # Example
90    /// ```no_run
91    /// use socketio_rs::{ClientBuilder, Socket, Payload};
92    /// use serde_json::json;
93    /// use std::time::Duration;
94    /// use std::thread::sleep;
95    /// use futures_util::FutureExt;
96    ///
97    /// #[tokio::main]
98    /// async fn main() {
99    ///     let mut socket = ClientBuilder::new("http://localhost:4200/")
100    ///         .on("foo", |payload: Option<Payload>, _, _| async move { println!("Received: {:#?}", payload) }.boxed())
101    ///         .connect()
102    ///         .await
103    ///         .expect("connection failed");
104    ///
105    ///     let ack_callback = |message: Option<Payload>, socket: Socket, _| {
106    ///         async move {
107    ///             match message {
108    ///                 Some(Payload::Json(data)) => println!("{:?}", data),
109    ///                 Some(Payload::Binary(bytes)) => println!("Received bytes: {:#?}", bytes),
110    ///                 Some(Payload::Multi(multi)) => println!("Received multi: {:?}", multi),
111    ///                 _ => {}
112    ///             }
113    ///         }.boxed()
114    ///     };    
115    ///
116    ///
117    ///     let payload = json!({"token": 123});
118    ///     socket.emit_with_ack("foo", payload, Duration::from_secs(2), ack_callback).await.unwrap();
119    ///
120    ///     sleep(Duration::from_secs(2));
121    /// }
122    /// ```
123    #[inline]
124    pub async fn emit_with_ack<F, E, D>(
125        &self,
126        event: E,
127        data: D,
128        timeout: Duration,
129        callback: F,
130    ) -> Result<()>
131    where
132        F: for<'a> std::ops::FnMut(
133                Option<Payload>,
134                Socket,
135                Option<AckId>,
136            ) -> BoxFuture<'static, ()>
137            + 'static
138            + Send
139            + Sync,
140        E: Into<Event>,
141        D: Into<Payload>,
142    {
143        let socket = self.socket.read().await;
144        socket.emit_with_ack(event, data, timeout, callback).await
145    }
146
147    pub async fn ack(&self, id: usize, data: Payload) -> Result<()> {
148        let socket = self.socket.read().await;
149        socket.ack(id, data).await
150    }
151
152    /// Disconnects from the server by sending a socket.io `Disconnect` packet. This results
153    /// in the underlying engine.io transport to get closed as well.
154    pub async fn disconnect(&self) -> Result<()> {
155        trace!("client disconnect");
156        let mut connected = self.connected.write().await;
157        if !*connected {
158            return Ok(());
159        }
160        *connected = false;
161        self.disconnect_socket().await
162    }
163
164    async fn disconnect_socket(&self) -> Result<()> {
165        let socket = self.socket.read().await;
166        socket.disconnect().await
167    }
168
169    pub(crate) async fn new(builder: ClientBuilder) -> Result<Self> {
170        let b = builder.clone();
171        let socket = b.connect_socket().await?;
172        let connected = Arc::new(RwLock::new(true));
173        let backoff = ExponentialBackoffBuilder::new()
174            .with_initial_interval(Duration::from_millis(builder.reconnect_delay_min))
175            .with_max_interval(Duration::from_millis(builder.reconnect_delay_max))
176            .build();
177
178        let s = Self {
179            builder,
180            socket: Arc::new(RwLock::new(socket)),
181            backoff,
182            connected,
183        };
184
185        Ok(s)
186    }
187
188    async fn reconnect(&mut self) {
189        let mut reconnect_attempts = 0;
190        if self.builder.reconnect {
191            loop {
192                if let Some(max_reconnect_attempts) = self.builder.max_reconnect_attempts {
193                    if reconnect_attempts > max_reconnect_attempts {
194                        break;
195                    }
196                }
197                reconnect_attempts += 1;
198
199                if let Some(backoff) = self.backoff.next_backoff() {
200                    trace!("reconnect backoff {:?}", backoff);
201                    tokio::time::sleep(backoff).await;
202                }
203
204                trace!("client reconnect {}", reconnect_attempts);
205                if self.do_reconnect().await.is_ok() {
206                    break;
207                }
208            }
209        }
210    }
211
212    async fn do_reconnect(&self) -> Result<()> {
213        let new_socket = self.builder.clone().connect_socket().await?;
214        let mut socket = self.socket.write().await;
215        *socket = new_socket;
216        Ok(())
217    }
218
219    pub(crate) fn poll_callback(&self) {
220        let mut self_clone = self.clone();
221        // Use thread to consume items in iterator in order to call callbacks
222        tokio::spawn(async move {
223            trace!("start poll_callback ");
224            // tries to restart a poll cycle whenever a 'normal' error occurs,
225            // it just panics on network errors, in case the poll cycle returned
226            // `Result::Ok`, the server receives a close frame so it's safe to
227            // terminate
228            #[allow(clippy::for_loops_over_fallibles)]
229            loop {
230                let packet = self_clone.poll_packet().await;
231                trace!("poll_callback packet {:?}", packet);
232                if let Some(Err(Error::IncompleteResponseFromEngineIo(_))) = packet {
233                    //TODO: logging error
234                    let _ = self_clone.disconnect_socket().await;
235                    self_clone.reconnect().await;
236                }
237                if !*self_clone.connected.read().await {
238                    break;
239                }
240            }
241            warn!("poll_callback exist");
242        });
243    }
244
245    pub(crate) async fn poll_packet(&self) -> Option<Result<Packet>> {
246        let socket = self.socket.read().await;
247        socket.poll_packet().await
248    }
249}
250
251impl Deref for Socket {
252    type Target = InnerSocket<Self>;
253
254    fn deref(&self) -> &Self::Target {
255        &self.socket
256    }
257}
258
259impl DerefMut for Socket {
260    fn deref_mut(&mut self) -> &mut Self::Target {
261        &mut self.socket
262    }
263}
264
265#[cfg(test)]
266mod test {
267    use std::time::Duration;
268
269    use super::*;
270    use crate::{
271        test::socket_io_server, AckId, Client, ClientBuilder, Event, Packet, PacketType, Payload,
272        Result, ServerBuilder, ServerSocket,
273    };
274
275    use bytes::Bytes;
276    use futures_util::FutureExt;
277    use serde_json::json;
278    use tokio::{sync::mpsc::unbounded_channel, time::sleep};
279    use tracing::info;
280
281    #[tokio::test(flavor = "multi_thread", worker_threads = 3)]
282    async fn test_client() -> Result<()> {
283        // tracing_subscriber::fmt()
284        //     .with_env_filter("engineio=trace,socketio=trace")
285        //     .init();
286        setup_server();
287
288        socket_io_integration().await?;
289        socket_io_builder_integration().await?;
290        socket_io_builder_integration_iterator().await?;
291        Ok(())
292    }
293
294    async fn socket_io_integration() -> Result<()> {
295        let url = socket_io_server();
296
297        let socket = ClientBuilder::new(url)
298            .on("test", |msg, _, _| {
299                async {
300                    match msg {
301                        Some(Payload::Json(data)) => info!("Received string: {:?}", data),
302                        Some(Payload::Binary(bin)) => info!("Received binary data: {:#?}", bin),
303                        Some(Payload::Multi(multi)) => info!("Received multi {:?}", multi),
304                        _ => {}
305                    }
306                }
307                .boxed()
308            })
309            .connect()
310            .await?;
311
312        let payload = json!({"token": 123_i32});
313        let result = socket.emit("test", Payload::Json(payload.clone())).await;
314
315        assert!(result.is_ok());
316
317        let ack = socket
318            .emit_with_ack(
319                "test",
320                Payload::Json(payload),
321                Duration::from_secs(1),
322                |message: Option<Payload>, socket: Socket, _| {
323                    async move {
324                        let result = socket
325                            .emit("test", Payload::Json(json!({"got ack": true})))
326                            .await;
327                        assert!(result.is_ok());
328
329                        info!("Yehaa! My ack got acked?");
330                        if let Some(Payload::Json(data)) = message {
331                            info!("Received string Ack");
332                            info!("Ack data: {:?}", data);
333                        }
334                    }
335                    .boxed()
336                },
337            )
338            .await;
339        assert!(ack.is_ok());
340
341        sleep(Duration::from_secs(2)).await;
342
343        assert!(socket.disconnect().await.is_ok());
344
345        Ok(())
346    }
347
348    async fn socket_io_builder_integration() -> Result<()> {
349        let url = socket_io_server();
350
351        // test socket build logic
352        let socket_builder = ClientBuilder::new(url);
353
354        let socket = socket_builder
355            .namespace("/admin")
356            .opening_header("accept-encoding", "application/json")
357            .on("test", |str, _, _| {
358                async move { info!("Received: {:#?}", str) }.boxed()
359            })
360            .on("message", |payload, _, _| {
361                async move { info!("{:#?}", payload) }.boxed()
362            })
363            .connect()
364            .await?;
365
366        assert!(socket.emit("message", json!("Hello World")).await.is_ok());
367
368        assert!(socket
369            .emit("binary", Bytes::from_static(&[46, 88]))
370            .await
371            .is_ok());
372
373        assert!(socket
374            .emit_with_ack(
375                "binary",
376                json!("pls ack"),
377                Duration::from_secs(1),
378                |payload, _, _| async move {
379                    info!("Yehaa the ack got acked");
380                    info!("With data: {:#?}", payload);
381                }
382                .boxed()
383            )
384            .await
385            .is_ok());
386
387        sleep(Duration::from_secs(2)).await;
388
389        Ok(())
390    }
391
392    async fn socket_io_builder_integration_iterator() -> Result<()> {
393        let url = socket_io_server();
394
395        // test socket build logic
396        let socket_builder = ClientBuilder::new(url);
397
398        let socket = socket_builder
399            .namespace("/admin")
400            .opening_header("accept-encoding", "application/json")
401            .on("test", |str, _, _| {
402                async move { info!("Received: {:#?}", str) }.boxed()
403            })
404            .on("message", |payload, _, _| {
405                async move { info!("Received binary {:#?}", payload) }.boxed()
406            })
407            .connect_client()
408            .await?;
409
410        test_socketio_socket(socket, "/admin".to_owned()).await
411    }
412
413    async fn test_socketio_socket(socket: Client, nsp: String) -> Result<()> {
414        // ignore connect packet
415        let _: Option<Packet> = Some(socket.poll_packet().await.unwrap()?);
416
417        let packet: Option<Packet> = Some(socket.poll_packet().await.unwrap()?);
418        assert!(packet.is_some());
419
420        let packet = packet.unwrap();
421        assert_eq!(
422            packet,
423            Packet::new(
424                PacketType::Event,
425                nsp.clone(),
426                Some(json!(["test", "Hello from the test event!"])),
427                None,
428                0,
429                None
430            )
431        );
432
433        let packet: Option<Packet> = Some(socket.poll_packet().await.unwrap()?);
434        assert!(packet.is_some());
435
436        let packet = packet.unwrap();
437        assert_eq!(
438            packet,
439            Packet::new(
440                PacketType::BinaryEvent,
441                nsp.clone(),
442                Some(json!(["test", {"_placeholder": true, "num": 0}])),
443                None,
444                1,
445                Some(vec![Bytes::from_static(&[1, 2, 3])]),
446            )
447        );
448
449        let packet: Option<Packet> = Some(socket.poll_packet().await.unwrap()?);
450        assert!(packet.is_some());
451
452        let packet = packet.unwrap();
453        match packet.data {
454            Some(serde_json::Value::Array(array)) => assert_eq!(array.len(), 5),
455            _ => panic!("invlaid emit multi payload"),
456        }
457
458        let socket_clone = socket.clone();
459        // continue poll cycle
460        tokio::spawn(async move {
461            loop {
462                let _ = socket_clone.poll_packet().await;
463            }
464        });
465
466        let (tx, mut rx) = unbounded_channel();
467        let tx = Arc::new(tx);
468
469        let cb = move |message: Option<Payload>, _, _| {
470            let tx = tx.clone();
471            async move {
472                match message {
473                    Some(Payload::Multi(vec)) => {
474                        let _ = tx.send(vec.len() == 2);
475                    }
476                    _ => {
477                        let _ = tx.send(false);
478                    }
479                };
480            }
481            .boxed()
482        };
483
484        assert!(socket
485            .emit_with_ack(
486                "client_ack",
487                Payload::Multi(vec![json!(1).into(), json!(2).into()]),
488                Duration::from_secs(10),
489                cb
490            )
491            .await
492            .is_ok());
493
494        match rx.recv().await {
495            Some(true) => {}
496            _ => panic!("ACK callback invlaid"),
497        };
498
499        let (tx, mut rx) = unbounded_channel();
500        let cb = move |message: Option<Payload>, _, _| {
501            let tx = tx.clone();
502            async move {
503                match message {
504                    Some(Payload::Multi(vec)) => {
505                        let _ = tx.send(vec.len() == 2);
506                    }
507                    _ => {
508                        let _ = tx.send(false);
509                    }
510                };
511            }
512            .boxed()
513        };
514
515        assert!(socket
516            .emit_with_ack(
517                "client_ack",
518                Payload::Multi(vec![Bytes::from_static(b"1").into(), json!(2).into()]),
519                Duration::from_secs(10),
520                cb
521            )
522            .await
523            .is_ok());
524
525        match rx.recv().await {
526            Some(true) => {}
527            _ => panic!("BINARY_ACK callback invlaid"),
528        };
529
530        Ok(())
531    }
532
533    fn setup_server() {
534        let echo_callback =
535            move |_payload: Option<Payload>, socket: ServerSocket, _need_ack: Option<AckId>| {
536                async move {
537                    let _ = socket.emit("echo", json!("")).await;
538                }
539                .boxed()
540            };
541
542        let client_ack =
543            move |payload: Option<Payload>, socket: ServerSocket, need_ack: Option<AckId>| {
544                async move {
545                    if let Some(ack_id) = need_ack {
546                        socket
547                            .ack(ack_id, payload.unwrap_or_else(|| json!("ackback").into()))
548                            .await
549                            .expect("success");
550                    }
551                }
552                .boxed()
553            };
554
555        let server_recv_ack =
556            move |_payload: Option<Payload>, socket: ServerSocket, _need_ack: Option<AckId>| {
557                async move {
558                    socket
559                        .emit("server_recv_ack", json!(""))
560                        .await
561                        .expect("success");
562                }
563                .boxed()
564            };
565
566        let trigger_ack = move |message: Option<Payload>, socket: ServerSocket, _| {
567            async move {
568                let payload = message.unwrap_or_else(|| json!({"ack_back": true}).into());
569                socket
570                    .emit_with_ack(
571                        "server_ask_ack",
572                        payload,
573                        Duration::from_millis(400),
574                        server_recv_ack,
575                    )
576                    .await
577                    .expect("success");
578            }
579            .boxed()
580        };
581
582        let connect_cb = move |_payload: Option<Payload>, socket: ServerSocket, _| {
583            async move {
584                socket
585                    .emit("test", json!("Hello from the test event!"))
586                    .await
587                    .expect("success");
588
589                socket
590                    .emit("test", Payload::Binary(Bytes::from_static(&[1, 2, 3])))
591                    .await
592                    .expect("success");
593
594                socket
595                    .emit(
596                        "test",
597                        Payload::Multi(vec![
598                            json!(1).into(),
599                            json!("2").into(),
600                            Bytes::from_static(&[3]).into(),
601                            Bytes::from_static(b"4").into(),
602                        ]),
603                    )
604                    .await
605                    .expect("success");
606            }
607            .boxed()
608        };
609
610        let url = socket_io_server();
611        let server = ServerBuilder::new(url.port().unwrap())
612            .on("/admin", "echo", echo_callback)
613            .on("/admin", "client_ack", client_ack)
614            .on("/admin", "server_ack", trigger_ack)
615            .on("/admin", Event::Connect, connect_cb)
616            .build();
617
618        tokio::spawn(async move { server.serve().await });
619    }
620}