Skip to main content

tf_rust_socketio/asynchronous/client/
client.rs

1use std::{ops::DerefMut, pin::Pin, sync::Arc};
2
3use backoff::{backoff::Backoff, ExponentialBackoffBuilder};
4use futures_util::{future::BoxFuture, stream, Stream, StreamExt};
5use log::{error, trace};
6use rand::{thread_rng, Rng};
7use serde_json::Value;
8use tf_rust_engineio::header::{HeaderMap, HeaderValue};
9use tokio::{
10    sync::RwLock,
11    time::{sleep, Duration, Instant},
12};
13
14use super::{
15    ack::Ack,
16    builder::ClientBuilder,
17    callback::{Callback, DynAsyncCallback},
18};
19use crate::{
20    asynchronous::socket::Socket as InnerSocket,
21    error::{Error, Result},
22    packet::{Packet, PacketId},
23    CloseReason, Event, Payload,
24};
25
/// Why the most recent connection ended.
///
/// The background polling task reads this after the packet stream closes to
/// decide whether it should try to reconnect.
#[derive(Default)]
enum DisconnectReason {
    /// No explicit disconnect was observed, most likely a network error.
    /// This is the initial state and is restored whenever a Connect packet arrives.
    #[default]
    Unknown,
    /// The user called `Client::disconnect`.
    Manual,
    /// The server sent a Disconnect packet.
    Server,
}
36
37/// Settings that can be updated before reconnecting to a server
38#[derive(Default)]
39pub struct ReconnectSettings {
40    address: Option<String>,
41    auth: Option<serde_json::Value>,
42    headers: Option<HeaderMap>,
43}
44
45impl ReconnectSettings {
46    pub fn new() -> Self {
47        Self::default()
48    }
49
50    /// Sets the URL that will be used when reconnecting to the server
51    pub fn address<T>(&mut self, address: T) -> &mut Self
52    where
53        T: Into<String>,
54    {
55        self.address = Some(address.into());
56        self
57    }
58
59    /// Sets the authentication data that will be send in the opening request
60    pub fn auth(&mut self, auth: serde_json::Value) {
61        self.auth = Some(auth);
62    }
63
64    /// Adds an http header to a container that is going to completely replace opening headers on reconnect.
65    /// If there are no headers set in `ReconnectSettings`, client will use headers initially set via the builder.
66    pub fn opening_header<T: Into<HeaderValue>, K: Into<String>>(
67        &mut self,
68        key: K,
69        val: T,
70    ) -> &mut Self {
71        self.headers
72            .get_or_insert_with(HeaderMap::default)
73            .insert(key.into(), val.into());
74        self
75    }
76}
77
78/// A socket which handles communication with the server. It's initialized with
79/// a specific address as well as an optional namespace to connect to. If `None`
80/// is given the client will connect to the default namespace `"/"`.
81#[derive(Clone)]
82pub struct Client {
83    /// The inner socket client to delegate the methods to.
84    socket: Arc<RwLock<InnerSocket>>,
85    outstanding_acks: Arc<RwLock<Vec<Ack>>>,
86    // namespace, for multiplexing messages
87    nsp: String,
88    // Data send in the opening packet (commonly used as for auth)
89    auth: Option<serde_json::Value>,
90    builder: Arc<RwLock<ClientBuilder>>,
91    disconnect_reason: Arc<RwLock<DisconnectReason>>,
92}
93
94impl Client {
95    /// Creates a socket with a certain address to connect to as well as a
96    /// namespace. If `None` is passed in as namespace, the default namespace
97    /// `"/"` is taken.
98    /// ```
99    pub(crate) fn new(socket: InnerSocket, builder: ClientBuilder) -> Result<Self> {
100        Ok(Client {
101            socket: Arc::new(RwLock::new(socket)),
102            nsp: builder.namespace.to_owned(),
103            outstanding_acks: Arc::new(RwLock::new(Vec::new())),
104            auth: builder.auth.clone(),
105            builder: Arc::new(RwLock::new(builder)),
106            disconnect_reason: Arc::new(RwLock::new(DisconnectReason::default())),
107        })
108    }
109
110    /// Connects the client to a server. Afterwards the `emit_*` methods can be
111    /// called to interact with the server.
112    pub(crate) async fn connect(&self) -> Result<()> {
113        // Connect the underlying socket
114        self.socket.read().await.connect().await?;
115
116        // construct the opening packet
117        let auth = self.auth.as_ref().map(|data| data.to_string());
118        let open_packet = Packet::new(PacketId::Connect, self.nsp.clone(), auth, None, 0, None);
119
120        self.socket.read().await.send(open_packet).await?;
121
122        Ok(())
123    }
124
125    pub(crate) async fn reconnect(&mut self) -> Result<()> {
126        let mut builder = self.builder.write().await;
127
128        if let Some(config) = builder.on_reconnect.as_mut() {
129            let reconnect_settings = config().await;
130
131            if let Some(address) = reconnect_settings.address {
132                builder.address = address;
133            }
134
135            if let Some(auth) = reconnect_settings.auth {
136                self.auth = Some(auth);
137            }
138
139            if reconnect_settings.headers.is_some() {
140                builder.opening_headers = reconnect_settings.headers;
141            }
142        }
143
144        let socket = builder.inner_create().await?;
145
146        // New inner socket that can be connected
147        let mut client_socket = self.socket.write().await;
148        *client_socket = socket;
149
150        // Now that we have replaced `self.socket`, we drop the write lock
151        // because the `connect` method we call below will need to use it
152        drop(client_socket);
153
154        self.connect().await?;
155
156        Ok(())
157    }
158
159    /// Drives the stream using a thread so messages are processed
160    pub(crate) async fn poll_stream(&mut self) -> Result<()> {
161        let builder = self.builder.read().await;
162        let reconnect_delay_min = builder.reconnect_delay_min;
163        let reconnect_delay_max = builder.reconnect_delay_max;
164        let max_reconnect_attempts = builder.max_reconnect_attempts;
165        let reconnect = builder.reconnect;
166        let reconnect_on_disconnect = builder.reconnect_on_disconnect;
167        drop(builder);
168
169        let mut client_clone = self.clone();
170
171        tokio::runtime::Handle::current().spawn(async move {
172            loop {
173                let mut stream = client_clone.as_stream().await;
174                // Consume the stream until it returns None and the stream is closed.
175                while let Some(item) = stream.next().await {
176                    if let Err(e) = item {
177                        trace!("Network error occurred: {}", e);
178                    }
179                }
180
181                // Drop the stream so we can once again use `socket_clone` as mutable
182                drop(stream);
183
184                let should_reconnect = match *(client_clone.disconnect_reason.read().await) {
185                    DisconnectReason::Unknown => {
186                        // If we disconnected for an unknown reason, the client might not have noticed
187                        // the closure yet. Hence, fire a transport close event to notify it.
188                        // We don't need to do that in the other cases, since proper server close
189                        // and manual client close are handled explicitly.
190                        if let Some(err) = client_clone
191                            .callback(&Event::Close, CloseReason::TransportClose.as_str(), None)
192                            .await
193                            .err()
194                        {
195                            error!("Error while notifying client of transport close: {err}")
196                        }
197
198                        reconnect
199                    }
200                    DisconnectReason::Manual => false,
201                    DisconnectReason::Server => reconnect_on_disconnect,
202                };
203
204                if should_reconnect {
205                    let mut reconnect_attempts = 0;
206                    let mut backoff = ExponentialBackoffBuilder::new()
207                        .with_initial_interval(Duration::from_millis(reconnect_delay_min))
208                        .with_max_interval(Duration::from_millis(reconnect_delay_max))
209                        .build();
210
211                    loop {
212                        if let Some(max_reconnect_attempts) = max_reconnect_attempts {
213                            reconnect_attempts += 1;
214                            if reconnect_attempts > max_reconnect_attempts {
215                                trace!("Max reconnect attempts reached without success");
216                                break;
217                            }
218                        }
219                        match client_clone.reconnect().await {
220                            Ok(_) => {
221                                trace!("Reconnected after {reconnect_attempts} attempts");
222                                break;
223                            }
224                            Err(e) => {
225                                trace!("Failed to reconnect: {e:?}");
226                                if let Some(delay) = backoff.next_backoff() {
227                                    let delay_ms = delay.as_millis();
228                                    trace!("Waiting for {delay_ms}ms before reconnecting");
229                                    sleep(delay).await;
230                                }
231                            }
232                        }
233                    }
234                } else {
235                    break;
236                }
237            }
238        });
239
240        Ok(())
241    }
242
243    /// Sends a message to the server using the underlying `engine.io` protocol.
244    /// This message takes an event, which could either be one of the common
245    /// events like "message" or "error" or a custom event like "foo". But be
246    /// careful, the data string needs to be valid JSON. It's recommended to use
247    /// a library like `serde_json` to serialize the data properly.
248    ///
249    /// # Example
250    /// ```
251    /// use tf_rust_socketio::{asynchronous::{ClientBuilder, Client}, Payload};
252    /// use serde_json::json;
253    /// use futures_util::FutureExt;
254    ///
255    /// #[tokio::main]
256    /// async fn main() {
257    ///     let mut socket = ClientBuilder::new("http://localhost:4200/")
258    ///         .on("test", |payload: Payload, socket: Client| {
259    ///             async move {
260    ///                 println!("Received: {:#?}", payload);
261    ///                 socket.emit("test", json!({"hello": true})).await.expect("Server unreachable");
262    ///             }.boxed()
263    ///         })
264    ///         .connect()
265    ///         .await
266    ///         .expect("connection failed");
267    ///
268    ///     let json_payload = json!({"token": 123});
269    ///
270    ///     let result = socket.emit("foo", json_payload).await;
271    ///
272    ///     assert!(result.is_ok());
273    /// }
274    /// ```
275    #[inline]
276    pub async fn emit<E, D>(&self, event: E, data: D) -> Result<()>
277    where
278        E: Into<Event>,
279        D: Into<Payload>,
280    {
281        self.socket
282            .read()
283            .await
284            .emit(&self.nsp, event.into(), data.into())
285            .await
286    }
287
288    /// When receive server's emitwithack callback event, invoke socket.ack(..) function can react to server with ack signal
289    /// use futures_util::FutureExt;
290    ///
291    /// # Example
292    /// ```
293    /// use futures_util::FutureExt;
294    /// use tf_rust_socketio::{asynchronous::{ClientBuilder, Client}, Payload};
295    /// use serde_json::json;
296    /// use std::time::Duration;
297    /// use std::thread;
298    /// use bytes::Bytes;
299    ///
300    /// #[tokio::main]
301    /// async fn main() {
302    ///
303    ///     let callback = |payload: Payload, socket: Client| {
304    ///        async move {
305    ///           match payload {
306    ///               Payload::Text(values, ack_id) => {
307    ///                   println!("{:#?}", values);
308    ///                   if let Some(id) = ack_id {
309    ///                       let _ = socket.ack_with_id(id, json!({"status": "received"})).await;
310    ///                   }
311    ///               },
312    ///               Payload::Binary(bytes, ack_id) => {
313    ///                   println!("Received bytes: {:#?}", bytes);
314    ///                   if let Some(id) = ack_id {
315    ///                       let _ = socket.ack_with_id(id, vec![4, 5, 6]).await;
316    ///                   }
317    ///               },
318    ///               Payload::String(str, ack_id) => {
319    ///                   println!("{}", str);
320    ///                   if let Some(id) = ack_id {
321    ///                       let _ = socket.ack_with_id(id, "response").await;
322    ///                   }
323    ///               },
324    ///           }
325    ///        }.boxed()
326    ///     };
327    ///
328    ///     // get a socket that is connected to the admin namespace
329    ///     let socket = ClientBuilder::new("http://localhost:4200")
330    ///         .namespace("/")
331    ///         .on("foo", callback)
332    ///         .on("error", |err, _| {
333    ///             async move { eprintln!("Error: {:#?}", err) }.boxed()
334    ///         })
335    ///         .connect()
336    ///         .await
337    ///         .expect("Connection failed");
338    ///     
339    ///
340    ///     thread::sleep(Duration::from_millis(30000));
341    ///     socket.disconnect().await.expect("Disconnect failed");
342    /// }
343    /// ```
344    #[inline]
345    pub async fn ack<D>(&self, data: D) -> Result<()>
346    where
347        D: Into<Payload>,
348    {
349        // For backward compatibility, this method doesn't specify an ack_id
350        // It should only be used when there's only one pending ack
351        let socket = self.socket.read().await;
352        socket.ack(&self.nsp, data.into(), None).await
353    }
354
355    /// Acknowledge a message with a specific ack_id
356    pub async fn ack_with_id<D>(&self, ack_id: i32, data: D) -> Result<()>
357    where
358        D: Into<Payload>,
359    {
360        let socket = self.socket.read().await;
361        socket.ack(&self.nsp, data.into(), Some(ack_id)).await
362    }
363
364    /// Disconnects this client from the server by sending a `socket.io` closing
365    /// packet.
366    /// # Example
367    /// ```rust
368    /// use tf_rust_socketio::{asynchronous::{ClientBuilder, Client}, Payload};
369    /// use serde_json::json;
370    /// use futures_util::{FutureExt, future::BoxFuture};
371    ///
372    /// #[tokio::main]
373    /// async fn main() {
374    ///     // apparently the syntax for functions is a bit verbose as rust currently doesn't
375    ///     // support an `AsyncFnMut` type that conform with async functions
376    ///     fn handle_test(payload: Payload, socket: Client) -> BoxFuture<'static, ()> {
377    ///         async move {
378    ///             println!("Received: {:#?}", payload);
379    ///             socket.emit("test", json!({"hello": true})).await.expect("Server unreachable");
380    ///         }.boxed()
381    ///     }
382    ///
383    ///     let mut socket = ClientBuilder::new("http://localhost:4200/")
384    ///         .on("test", handle_test)
385    ///         .connect()
386    ///         .await
387    ///         .expect("connection failed");
388    ///
389    ///     let json_payload = json!({"token": 123});
390    ///
391    ///     socket.emit("foo", json_payload).await;
392    ///
393    ///     // disconnect from the server
394    ///     socket.disconnect().await;
395    /// }
396    /// ```
397    pub async fn disconnect(&self) -> Result<()> {
398        *(self.disconnect_reason.write().await) = DisconnectReason::Manual;
399
400        let disconnect_packet =
401            Packet::new(PacketId::Disconnect, self.nsp.clone(), None, None, 0, None);
402
403        self.socket.read().await.send(disconnect_packet).await?;
404        self.socket.read().await.disconnect().await?;
405
406        Ok(())
407    }
408
409    /// Sends a message to the server but `alloc`s an `ack` to check whether the
410    /// server responded in a given time span. This message takes an event, which
411    /// could either be one of the common events like "message" or "error" or a
412    /// custom event like "foo", as well as a data parameter. But be careful,
413    /// in case you send a [`Payload::String`], the string needs to be valid JSON.
414    /// It's even recommended to use a library like serde_json to serialize the data properly.
415    /// It also requires a timeout `Duration` in which the client needs to answer.
416    /// If the ack is acked in the correct time span, the specified callback is
417    /// called. The callback consumes a [`Payload`] which represents the data send
418    /// by the server.
419    ///
420    /// Please note that the requirements on the provided callbacks are similar to the ones
421    /// for [`crate::asynchronous::ClientBuilder::on`].
422    /// # Example
423    /// ```
424    /// use tf_rust_socketio::{asynchronous::{ClientBuilder, Client}, Payload};
425    /// use serde_json::json;
426    /// use std::time::Duration;
427    /// use std::thread::sleep;
428    /// use futures_util::FutureExt;
429    ///
430    /// #[tokio::main]
431    /// async fn main() {
432    ///     let mut socket = ClientBuilder::new("http://localhost:4200/")
433    ///         .on("foo", |payload: Payload, _| async move { println!("Received: {:#?}", payload) }.boxed())
434    ///         .connect()
435    ///         .await
436    ///         .expect("connection failed");
437    ///
438    ///     let ack_callback = |message: Payload, socket: Client| {
439    ///         async move {
440    ///             match message {
441    ///                 Payload::Text(values, _) => println!("{:#?}", values),
442    ///                 Payload::Binary(bytes, _) => println!("Received bytes: {:#?}", bytes),
443    ///                 // This is deprecated use Payload::Text instead
444    ///                 #[allow(deprecated)]
445    ///                 Payload::String(str, _) => println!("{}", str),
446    ///             }
447    ///         }.boxed()
448    ///     };
449    ///
450    ///
451    ///     let payload = json!({"token": 123});
452    ///     socket.emit_with_ack("foo", payload, Duration::from_secs(2), ack_callback).await.unwrap();
453    ///
454    ///     sleep(Duration::from_secs(2));
455    /// }
456    /// ```
457    #[inline]
458    pub async fn emit_with_ack<F, E, D>(
459        &self,
460        event: E,
461        data: D,
462        timeout: Duration,
463        callback: F,
464    ) -> Result<()>
465    where
466        F: for<'a> std::ops::FnMut(Payload, Client) -> BoxFuture<'static, ()>
467            + 'static
468            + Send
469            + Sync,
470        E: Into<Event>,
471        D: Into<Payload>,
472    {
473        let id = thread_rng().gen_range(0..999);
474        let socket_packet =
475            Packet::new_from_payload(data.into(), event.into(), &self.nsp, Some(id))?;
476
477        let ack = Ack {
478            id,
479            time_started: Instant::now(),
480            timeout,
481            callback: Callback::<DynAsyncCallback>::new(callback),
482        };
483
484        // add the ack to the tuple of outstanding acks
485        self.outstanding_acks.write().await.push(ack);
486
487        self.socket.read().await.send(socket_packet).await
488    }
489
490    async fn callback<P: Into<Payload>>(
491        &self,
492        event: &Event,
493        payload: P,
494        ack_id: Option<i32>,
495    ) -> Result<()> {
496        let mut builder = self.builder.write().await;
497        let mut payload = payload.into();
498        payload.set_ack_id(ack_id);
499
500        if let Some(callback) = builder.on.get_mut(event) {
501            callback(payload.clone(), self.clone()).await;
502        }
503
504        // Call on_any for all common and custom events.
505        match event {
506            Event::Message | Event::Custom(_) => {
507                if let Some(callback) = builder.on_any.as_mut() {
508                    callback(event.clone(), payload, self.clone()).await;
509                }
510            }
511            _ => (),
512        }
513
514        Ok(())
515    }
516
517    /// Handles the incoming acks and classifies what callbacks to call and how.
518    #[inline]
519    async fn handle_ack(&self, socket_packet: &Packet) -> Result<()> {
520        let mut to_be_removed = Vec::new();
521        if let Some(id) = socket_packet.id {
522            for (index, ack) in self.outstanding_acks.write().await.iter_mut().enumerate() {
523                if ack.id == id {
524                    to_be_removed.push(index);
525
526                    if ack.time_started.elapsed() < ack.timeout {
527                        if let Some(ref payload) = socket_packet.data {
528                            let mut payload = Payload::from(payload.to_owned());
529                            payload.set_ack_id(socket_packet.id);
530                            ack.callback.deref_mut()(payload, self.clone()).await;
531                        }
532                        if let Some(ref attachments) = socket_packet.attachments {
533                            if let Some(payload) = attachments.first() {
534                                let payload = Payload::Binary(payload.to_owned(), socket_packet.id);
535                                ack.callback.deref_mut()(payload, self.clone()).await;
536                            }
537                        }
538                    } else {
539                        trace!("Received an Ack that is now timed out (elapsed time was longer than specified duration)");
540                    }
541                }
542            }
543            for index in to_be_removed {
544                self.outstanding_acks.write().await.remove(index);
545            }
546        }
547        Ok(())
548    }
549
550    /// Handles a binary event.
551    #[inline]
552    async fn handle_binary_event(&self, packet: &Packet) -> Result<()> {
553        let event = if let Some(string_data) = &packet.data {
554            string_data.replace('\"', "").into()
555        } else {
556            Event::Message
557        };
558
559        if let Some(attachments) = &packet.attachments {
560            if let Some(binary_payload) = attachments.first() {
561                self.callback(
562                    &event,
563                    Payload::Binary(binary_payload.to_owned(), packet.id),
564                    packet.id,
565                )
566                .await?;
567            }
568        }
569        Ok(())
570    }
571
572    /// A method that parses a packet and eventually calls the corresponding
573    /// callback with the supplied data.
574    async fn handle_event(&self, packet: &Packet) -> Result<()> {
575        let Some(ref data) = packet.data else {
576            return Ok(());
577        };
578
579        // a socketio message always comes in one of the following two flavors (both JSON):
580        // 1: `["event", "msg", ...]`
581        // 2: `["msg"]`
582        // in case 2, the message is ment for the default message event, in case 1 the event
583        // is specified
584        if let Ok(Value::Array(contents)) = serde_json::from_str::<Value>(data) {
585            let (event, payloads) = match contents.len() {
586                0 => return Err(Error::IncompletePacket()),
587                // Incorrect packet, ignore it
588                1 => (Event::Message, contents.as_slice()),
589                // it's a message event
590                _ => match contents.first() {
591                    Some(Value::String(ev)) => (Event::from(ev.as_str()), &contents[1..]),
592                    // get rest(1..) of them as data, not just take the 2nd element
593                    _ => (Event::Message, contents.as_slice()),
594                    // take them all as data
595                },
596            };
597
598            // call the correct callback
599            self.callback(&event, payloads.to_vec(), packet.id).await?;
600        }
601
602        Ok(())
603    }
604
605    /// Handles the incoming messages and classifies what callbacks to call and how.
606    /// This method is later registered as the callback for the `on_data` event of the
607    /// engineio client.
608    #[inline]
609    async fn handle_socketio_packet(&self, packet: &Packet) -> Result<()> {
610        if packet.nsp == self.nsp {
611            match packet.packet_type {
612                PacketId::Ack | PacketId::BinaryAck => {
613                    if let Err(err) = self.handle_ack(packet).await {
614                        self.callback(&Event::Error, err.to_string(), None).await?;
615                        return Err(err);
616                    }
617                }
618                PacketId::BinaryEvent => {
619                    if let Err(err) = self.handle_binary_event(packet).await {
620                        self.callback(&Event::Error, err.to_string(), None).await?;
621                    }
622                }
623                PacketId::Connect => {
624                    *(self.disconnect_reason.write().await) = DisconnectReason::default();
625                    self.callback(&Event::Connect, "", None).await?;
626                }
627                PacketId::Disconnect => {
628                    *(self.disconnect_reason.write().await) = DisconnectReason::Server;
629                    self.callback(
630                        &Event::Close,
631                        CloseReason::IOServerDisconnect.as_str(),
632                        None,
633                    )
634                    .await?;
635                }
636                PacketId::ConnectError => {
637                    self.callback(
638                        &Event::Error,
639                        String::from("Received an ConnectError frame: ")
640                            + packet
641                                .data
642                                .as_ref()
643                                .unwrap_or(&String::from("\"No error message provided\"")),
644                        None,
645                    )
646                    .await?;
647                }
648                PacketId::Event => {
649                    if let Err(err) = self.handle_event(packet).await {
650                        self.callback(&Event::Error, err.to_string(), None).await?;
651                    }
652                }
653            }
654        }
655        Ok(())
656    }
657
658    /// Returns the packet stream for the client.
659    pub(crate) async fn as_stream<'a>(
660        &'a self,
661    ) -> Pin<Box<dyn Stream<Item = Result<Packet>> + Send + 'a>> {
662        let socket_clone = (*self.socket.read().await).clone();
663
664        stream::unfold(socket_clone, |mut socket| async {
665            // wait for the next payload
666            let packet: Option<std::result::Result<Packet, Error>> = socket.next().await;
667            match packet {
668                // end the stream if the underlying one is closed
669                None => None,
670                Some(Err(err)) => {
671                    // call the error callback
672                    match self.callback(&Event::Error, err.to_string(), None).await {
673                        Err(callback_err) => Some((Err(callback_err), socket)),
674                        Ok(_) => Some((Err(err), socket)),
675                    }
676                }
677                Some(Ok(packet)) => match self.handle_socketio_packet(&packet).await {
678                    Err(callback_err) => Some((Err(callback_err), socket)),
679                    Ok(_) => Some((Ok(packet), socket)),
680                },
681            }
682        })
683        .boxed()
684    }
685}
686
687#[cfg(test)]
688mod test {
689
690    use std::{
691        sync::{
692            atomic::{AtomicUsize, Ordering},
693            Arc,
694        },
695        time::Duration,
696    };
697
698    use bytes::Bytes;
699    use futures_util::{FutureExt, StreamExt};
700    use native_tls::TlsConnector;
701    use serde_json::json;
702    use tokio::{
703        sync::{mpsc, Mutex},
704        time::{sleep, timeout},
705    };
706
707    use serial_test::serial;
708
709    use crate::{
710        asynchronous::{
711            client::{builder::ClientBuilder, client::Client},
712            ReconnectSettings,
713        },
714        error::Result,
715        packet::{Packet, PacketId},
716        CloseReason, Event, Payload, TransportType,
717    };
718
719    #[tokio::test]
720    async fn socket_io_integration() -> Result<()> {
721        let url = crate::test::socket_io_server();
722
723        let socket = ClientBuilder::new(url)
724            .on("test", |msg, _| {
725                async {
726                    match msg {
727                        Payload::Text(values, _) => println!("Received json: {:#?}", values),
728                        #[allow(deprecated)]
729                        Payload::String(str, _) => println!("Received string: {}", str),
730                        Payload::Binary(bin, _) => println!("Received binary data: {:#?}", bin),
731                    }
732                }
733                .boxed()
734            })
735            .connect()
736            .await?;
737
738        let payload = json!({"token": 123_i32});
739        let result = socket.emit("test", Payload::from(payload.clone())).await;
740
741        assert!(result.is_ok());
742
743        let ack = socket
744            .emit_with_ack(
745                "test",
746                Payload::from(payload),
747                Duration::from_secs(1),
748                |message: Payload, socket: Client| {
749                    async move {
750                        let result = socket
751                            .emit("test", Payload::from(json!({"got ack": true})))
752                            .await;
753                        assert!(result.is_ok());
754
755                        println!("Yehaa! My ack got acked?");
756                        if let Payload::Text(json, _) = message {
757                            println!("Received json Ack");
758                            println!("Ack data: {:#?}", json);
759                        }
760                    }
761                    .boxed()
762                },
763            )
764            .await;
765        assert!(ack.is_ok());
766
767        sleep(Duration::from_secs(2)).await;
768
769        assert!(socket.disconnect().await.is_ok());
770
771        Ok(())
772    }
773
774    #[tokio::test]
775    async fn socket_io_async_callback() -> Result<()> {
776        // Test whether asynchronous callbacks are fully executed.
777        let url = crate::test::socket_io_server();
778
779        // This synchronization mechanism is used to let the test know that the end of the
780        // async callback was reached.
781        let notify = Arc::new(tokio::sync::Notify::new());
782        let notify_clone = notify.clone();
783
784        let socket = ClientBuilder::new(url)
785            .on("test", move |_, _| {
786                let cl = notify_clone.clone();
787                async move {
788                    sleep(Duration::from_secs(1)).await;
789                    // The async callback should be awaited and not aborted.
790                    // Thus, the notification should be called.
791                    cl.notify_one();
792                }
793                .boxed()
794            })
795            .connect()
796            .await?;
797
798        let payload = json!({"token": 123_i32});
799        let result = socket.emit("test", Payload::from(payload)).await;
800
801        assert!(result.is_ok());
802        // If the timeout did not trigger, the async callback was fully executed.
803        let timeout = timeout(Duration::from_secs(5), notify.notified()).await;
804        assert!(timeout.is_ok());
805
806        Ok(())
807    }
808
809    #[tokio::test]
810    async fn socket_io_builder_integration() -> Result<()> {
811        let url = crate::test::socket_io_server();
812
813        // test socket build logic
814        let socket_builder = ClientBuilder::new(url);
815
816        let tls_connector = TlsConnector::builder()
817            .use_sni(true)
818            .build()
819            .expect("Found illegal configuration");
820
821        let socket = socket_builder
822            .namespace("/admin")
823            .tls_config(tls_connector)
824            .opening_header("accept-encoding", "application/json")
825            .on("test", |str, _| {
826                async move { println!("Received: {:#?}", str) }.boxed()
827            })
828            .on("message", |payload, _| {
829                async move { println!("{:#?}", payload) }.boxed()
830            })
831            .connect()
832            .await?;
833
834        assert!(socket.emit("message", json!("Hello World")).await.is_ok());
835
836        assert!(socket
837            .emit("binary", Bytes::from_static(&[46, 88]))
838            .await
839            .is_ok());
840
841        assert!(socket
842            .emit_with_ack(
843                "binary",
844                json!("pls ack"),
845                Duration::from_secs(1),
846                |payload, _| async move {
847                    println!("Yehaa the ack got acked");
848                    println!("With data: {:#?}", payload);
849                }
850                .boxed()
851            )
852            .await
853            .is_ok());
854
855        sleep(Duration::from_secs(2)).await;
856
857        Ok(())
858    }
859
860    #[tokio::test]
861    #[serial(reconnect)]
862    async fn socket_io_reconnect_integration() -> Result<()> {
863        static CONNECT_NUM: AtomicUsize = AtomicUsize::new(0);
864        static MESSAGE_NUM: AtomicUsize = AtomicUsize::new(0);
865        static ON_RECONNECT_CALLED: AtomicUsize = AtomicUsize::new(0);
866        let latest_message = Arc::new(Mutex::new(String::new()));
867        let handler_latest_message = latest_message.clone();
868
869        let url = crate::test::socket_io_restart_server();
870
871        let socket = ClientBuilder::new(url.clone())
872            .reconnect(true)
873            .max_reconnect_attempts(100)
874            .reconnect_delay(100, 100)
875            .on_reconnect(move || {
876                let url = url.clone();
877                async move {
878                    ON_RECONNECT_CALLED.fetch_add(1, Ordering::Release);
879
880                    let mut settings = ReconnectSettings::new();
881
882                    // Try setting the address to what we already have, just
883                    // to test. This is not strictly necessary in real usage.
884                    settings.address(url.to_string());
885                    settings.opening_header("MESSAGE_BACK", "updated");
886                    settings
887                }
888                .boxed()
889            })
890            .on("open", |_, socket| {
891                async move {
892                    CONNECT_NUM.fetch_add(1, Ordering::Release);
893                    let r = socket.emit_with_ack(
894                        "message",
895                        json!(""),
896                        Duration::from_millis(100),
897                        |_, _| async move {}.boxed(),
898                    );
899                    assert!(r.await.is_ok(), "should emit message success");
900                }
901                .boxed()
902            })
903            .on("message", move |payload, _socket| {
904                let latest_message = handler_latest_message.clone();
905                async move {
906                    // test the iterator implementation and make sure there is a constant
907                    // stream of packets, even when reconnecting
908                    MESSAGE_NUM.fetch_add(1, Ordering::Release);
909
910                    let msg = match payload {
911                        Payload::Text(msg, _) => msg
912                            .into_iter()
913                            .next()
914                            .expect("there should be one text payload"),
915                        _ => panic!(),
916                    };
917
918                    let msg = serde_json::from_value(msg).expect("payload should be json string");
919
920                    *latest_message.lock().await = msg;
921                }
922                .boxed()
923            })
924            .connect()
925            .await;
926
927        assert!(socket.is_ok(), "should connect success");
928        let socket = socket.unwrap();
929
930        // waiting for server to emit message
931        sleep(Duration::from_millis(500)).await;
932
933        assert_eq!(load(&CONNECT_NUM), 1, "should connect once");
934        assert_eq!(load(&MESSAGE_NUM), 1, "should receive one");
935        assert_eq!(
936            *latest_message.lock().await,
937            "test",
938            "should receive test message"
939        );
940
941        let r = socket.emit("restart_server", json!("")).await;
942        assert!(r.is_ok(), "should emit restart success");
943
944        // waiting for server to restart
945        for _ in 0..10 {
946            sleep(Duration::from_millis(400)).await;
947            if load(&CONNECT_NUM) == 2 && load(&MESSAGE_NUM) == 2 {
948                break;
949            }
950        }
951
952        assert_eq!(load(&CONNECT_NUM), 2, "should connect twice");
953        assert_eq!(load(&MESSAGE_NUM), 2, "should receive two messages");
954        assert!(
955            load(&ON_RECONNECT_CALLED) > 1,
956            "should call on_reconnect at least once"
957        );
958        assert_eq!(
959            *latest_message.lock().await,
960            "updated",
961            "should receive updated message"
962        );
963
964        socket.disconnect().await?;
965        Ok(())
966    }
967
968    #[tokio::test]
969    async fn socket_io_builder_integration_iterator() -> Result<()> {
970        let url = crate::test::socket_io_server();
971
972        // test socket build logic
973        let socket_builder = ClientBuilder::new(url);
974
975        let tls_connector = TlsConnector::builder()
976            .use_sni(true)
977            .build()
978            .expect("Found illegal configuration");
979
980        let socket = socket_builder
981            .namespace("/admin")
982            .tls_config(tls_connector)
983            .opening_header("accept-encoding", "application/json")
984            .on("test", |str, _| {
985                async move { println!("Received: {:#?}", str) }.boxed()
986            })
987            .on("message", |payload, _| {
988                async move { println!("{:#?}", payload) }.boxed()
989            })
990            .connect_manual()
991            .await?;
992
993        assert!(socket.emit("message", json!("Hello World")).await.is_ok());
994
995        assert!(socket
996            .emit("binary", Bytes::from_static(&[46, 88]))
997            .await
998            .is_ok());
999
1000        assert!(socket
1001            .emit_with_ack(
1002                "binary",
1003                json!("pls ack"),
1004                Duration::from_secs(1),
1005                |payload, _| async move {
1006                    println!("Yehaa the ack got acked");
1007                    println!("With data: {:#?}", payload);
1008                }
1009                .boxed()
1010            )
1011            .await
1012            .is_ok());
1013
1014        test_socketio_socket(socket, "/admin".to_owned()).await
1015    }
1016
1017    #[tokio::test]
1018    async fn socket_io_on_any_integration() -> Result<()> {
1019        let url = crate::test::socket_io_server();
1020
1021        let (tx, mut rx) = mpsc::channel(2);
1022
1023        let mut _socket = ClientBuilder::new(url)
1024            .namespace("/")
1025            .auth(json!({ "password": "123" }))
1026            .on_any(move |event, payload, _| {
1027                let clone_tx = tx.clone();
1028                async move {
1029                    if let Payload::Text(values, _) = payload {
1030                        println!("{event}: {values:#?}");
1031                    }
1032                    clone_tx.send(String::from(event)).await.unwrap();
1033                }
1034                .boxed()
1035            })
1036            .connect()
1037            .await?;
1038
1039        let event = rx.recv().await.unwrap();
1040        assert_eq!(event, "message");
1041
1042        let event = rx.recv().await.unwrap();
1043        assert_eq!(event, "test");
1044
1045        Ok(())
1046    }
1047
1048    #[tokio::test]
1049    async fn socket_io_auth_builder_integration() -> Result<()> {
1050        let url = crate::test::socket_io_auth_server();
1051        let nsp = String::from("/admin");
1052        let socket = ClientBuilder::new(url)
1053            .namespace(nsp.clone())
1054            .auth(json!({ "password": "123" }))
1055            .connect_manual()
1056            .await?;
1057
1058        // open packet
1059        let mut socket_stream = socket.as_stream().await;
1060        let _ = socket_stream.next().await.unwrap()?;
1061
1062        let packet = socket_stream.next().await.unwrap()?;
1063        assert_eq!(
1064            packet,
1065            Packet::new(
1066                PacketId::Event,
1067                nsp,
1068                Some("[\"auth\",\"success\"]".to_owned()),
1069                None,
1070                0,
1071                None
1072            )
1073        );
1074
1075        Ok(())
1076    }
1077
1078    #[tokio::test]
1079    async fn socket_io_transport_close() -> Result<()> {
1080        let url = crate::test::socket_io_server();
1081
1082        let (tx, mut rx) = mpsc::channel(1);
1083
1084        let notify = Arc::new(tokio::sync::Notify::new());
1085        let notify_clone = notify.clone();
1086
1087        let socket = ClientBuilder::new(url)
1088            .on(Event::Connect, move |_, _| {
1089                let cl = notify_clone.clone();
1090                async move {
1091                    cl.notify_one();
1092                }
1093                .boxed()
1094            })
1095            .on(Event::Close, move |payload, _| {
1096                let clone_tx = tx.clone();
1097                async move { clone_tx.send(payload).await.unwrap() }.boxed()
1098            })
1099            .connect()
1100            .await?;
1101
1102        // Wait until socket is connected
1103        let connect_timeout = timeout(Duration::from_secs(1), notify.notified()).await;
1104        assert!(connect_timeout.is_ok());
1105
1106        // Instruct server to close transport
1107        let result = socket.emit("close_transport", Payload::from("")).await;
1108        assert!(result.is_ok());
1109
1110        // Wait for Event::Close
1111        let rx_timeout = timeout(Duration::from_secs(1), rx.recv()).await;
1112        assert!(rx_timeout.is_ok());
1113
1114        assert_eq!(
1115            rx_timeout.unwrap(),
1116            Some(Payload::from(CloseReason::TransportClose.as_str()))
1117        );
1118
1119        Ok(())
1120    }
1121
1122    #[tokio::test]
1123    async fn socketio_polling_integration() -> Result<()> {
1124        let url = crate::test::socket_io_server();
1125        let socket = ClientBuilder::new(url.clone())
1126            .transport_type(TransportType::Polling)
1127            .connect_manual()
1128            .await?;
1129        test_socketio_socket(socket, "/".to_owned()).await
1130    }
1131
1132    #[tokio::test]
1133    async fn socket_io_websocket_integration() -> Result<()> {
1134        let url = crate::test::socket_io_server();
1135        let socket = ClientBuilder::new(url.clone())
1136            .transport_type(TransportType::Websocket)
1137            .connect_manual()
1138            .await?;
1139        test_socketio_socket(socket, "/".to_owned()).await
1140    }
1141
1142    #[tokio::test]
1143    async fn socket_io_websocket_upgrade_integration() -> Result<()> {
1144        let url = crate::test::socket_io_server();
1145        let socket = ClientBuilder::new(url)
1146            .transport_type(TransportType::WebsocketUpgrade)
1147            .connect_manual()
1148            .await?;
1149        test_socketio_socket(socket, "/".to_owned()).await
1150    }
1151
1152    #[tokio::test]
1153    async fn socket_io_any_integration() -> Result<()> {
1154        let url = crate::test::socket_io_server();
1155        let socket = ClientBuilder::new(url)
1156            .transport_type(TransportType::Any)
1157            .connect_manual()
1158            .await?;
1159        test_socketio_socket(socket, "/".to_owned()).await
1160    }
1161
1162    async fn test_socketio_socket(socket: Client, nsp: String) -> Result<()> {
1163        // open packet
1164        let mut socket_stream = socket.as_stream().await;
1165        let _: Option<Packet> = Some(socket_stream.next().await.unwrap()?);
1166
1167        let packet: Option<Packet> = Some(socket_stream.next().await.unwrap()?);
1168
1169        assert!(packet.is_some());
1170
1171        let packet = packet.unwrap();
1172
1173        assert_eq!(
1174            packet,
1175            Packet::new(
1176                PacketId::Event,
1177                nsp.clone(),
1178                Some("[\"Hello from the message event!\"]".to_owned()),
1179                None,
1180                0,
1181                None,
1182            )
1183        );
1184
1185        let packet: Option<Packet> = Some(socket_stream.next().await.unwrap()?);
1186
1187        assert!(packet.is_some());
1188
1189        let packet = packet.unwrap();
1190
1191        assert_eq!(
1192            packet,
1193            Packet::new(
1194                PacketId::Event,
1195                nsp.clone(),
1196                Some("[\"test\",\"Hello from the test event!\"]".to_owned()),
1197                None,
1198                0,
1199                None
1200            )
1201        );
1202        let packet: Option<Packet> = Some(socket_stream.next().await.unwrap()?);
1203
1204        assert!(packet.is_some());
1205
1206        let packet = packet.unwrap();
1207        assert_eq!(
1208            packet,
1209            Packet::new(
1210                PacketId::BinaryEvent,
1211                nsp.clone(),
1212                None,
1213                None,
1214                1,
1215                Some(vec![Bytes::from_static(&[4, 5, 6])]),
1216            )
1217        );
1218
1219        let packet: Option<Packet> = Some(socket_stream.next().await.unwrap()?);
1220
1221        assert!(packet.is_some());
1222
1223        let packet = packet.unwrap();
1224        assert_eq!(
1225            packet,
1226            Packet::new(
1227                PacketId::BinaryEvent,
1228                nsp.clone(),
1229                Some("\"test\"".to_owned()),
1230                None,
1231                1,
1232                Some(vec![Bytes::from_static(&[1, 2, 3])]),
1233            )
1234        );
1235
1236        let packet: Option<Packet> = Some(socket_stream.next().await.unwrap()?);
1237
1238        assert!(packet.is_some());
1239
1240        let packet = packet.unwrap();
1241        assert_eq!(
1242            packet,
1243            Packet::new(
1244                PacketId::Event,
1245                nsp.clone(),
1246                Some(
1247                    serde_json::Value::Array(vec![
1248                        serde_json::Value::from("This is the first argument"),
1249                        serde_json::Value::from("This is the second argument"),
1250                        serde_json::json!({"argCount":3})
1251                    ])
1252                    .to_string()
1253                ),
1254                None,
1255                0,
1256                None,
1257            )
1258        );
1259
1260        let packet: Option<Packet> = Some(socket_stream.next().await.unwrap()?);
1261
1262        assert!(packet.is_some());
1263
1264        let packet = packet.unwrap();
1265        assert_eq!(
1266            packet,
1267            Packet::new(
1268                PacketId::Event,
1269                nsp.clone(),
1270                Some(
1271                    serde_json::json!([
1272                        "on_abc_event",
1273                        "",
1274                        {
1275                        "abc": 0,
1276                        "some_other": "value",
1277                        }
1278                    ])
1279                    .to_string()
1280                ),
1281                None,
1282                0,
1283                None,
1284            )
1285        );
1286
1287        let cb = |message: Payload, _| {
1288            async {
1289                println!("Yehaa! My ack got acked?");
1290                if let Payload::Text(values, _) = message {
1291                    println!("Received json ack");
1292                    println!("Ack data: {:#?}", values);
1293                }
1294            }
1295            .boxed()
1296        };
1297
1298        assert!(socket
1299            .emit_with_ack(
1300                "test",
1301                Payload::from("123".to_owned()),
1302                Duration::from_secs(10),
1303                cb
1304            )
1305            .await
1306            .is_ok());
1307
1308        let packet: Option<Packet> = Some(socket_stream.next().await.unwrap()?);
1309
1310        assert!(packet.is_some());
1311        let packet = packet.unwrap();
1312        assert_eq!(
1313            packet,
1314            Packet::new(
1315                PacketId::Event,
1316                nsp.clone(),
1317                Some("[\"test-received\",123]".to_owned()),
1318                None,
1319                0,
1320                None,
1321            )
1322        );
1323
1324        let packet: Option<Packet> = Some(socket_stream.next().await.unwrap()?);
1325
1326        assert!(packet.is_some());
1327        let packet = packet.unwrap();
1328        assert!(matches!(
1329            packet,
1330            Packet {
1331                packet_type: PacketId::Ack,
1332                nsp: _,
1333                data: Some(_),
1334                id: Some(_),
1335                attachment_count: 0,
1336                attachments: None,
1337            }
1338        ));
1339
1340        Ok(())
1341    }
1342
1343    fn load(num: &AtomicUsize) -> usize {
1344        num.load(Ordering::Acquire)
1345    }
1346}