// tf_rust_socketio/client/client.rs

1use std::{
2    sync::{Arc, Mutex, RwLock},
3    time::Duration,
4};
5
6use super::{ClientBuilder, RawClient};
7use crate::{
8    error::Result,
9    packet::{Packet, PacketId},
10    Error,
11};
12pub(crate) use crate::{event::Event, payload::Payload};
13use backoff::ExponentialBackoff;
14use backoff::{backoff::Backoff, ExponentialBackoffBuilder};
15
16#[derive(Clone)]
17pub struct Client {
18    builder: Arc<Mutex<ClientBuilder>>,
19    client: Arc<RwLock<RawClient>>,
20    backoff: ExponentialBackoff,
21}
22
23impl Client {
24    pub(crate) fn new(builder: ClientBuilder) -> Result<Self> {
25        let builder_clone = builder.clone();
26        let client = builder_clone.connect_raw()?;
27        let backoff = ExponentialBackoffBuilder::new()
28            .with_initial_interval(Duration::from_millis(builder.reconnect_delay_min))
29            .with_max_interval(Duration::from_millis(builder.reconnect_delay_max))
30            .build();
31
32        let s = Self {
33            builder: Arc::new(Mutex::new(builder)),
34            client: Arc::new(RwLock::new(client)),
35            backoff,
36        };
37        s.poll_callback();
38
39        Ok(s)
40    }
41
42    /// Updates the URL the client will connect to when reconnecting.
43    /// This is especially useful for updating query parameters.
44    pub fn set_reconnect_url<T: Into<String>>(&self, address: T) -> Result<()> {
45        self.builder.lock()?.address = address.into();
46        Ok(())
47    }
48
49    /// Sends a message to the server using the underlying `engine.io` protocol.
50    /// This message takes an event, which could either be one of the common
51    /// events like "message" or "error" or a custom event like "foo". But be
52    /// careful, the data string needs to be valid JSON. It's recommended to use
53    /// a library like `serde_json` to serialize the data properly.
54    ///
55    /// # Example
56    /// ```
57    /// use tf_rust_socketio::{ClientBuilder, RawClient, Payload};
58    /// use serde_json::json;
59    ///
60    /// let mut socket = ClientBuilder::new("http://localhost:4200/")
61    ///     .on("test", |payload: Payload, socket: RawClient| {
62    ///         println!("Received: {:#?}", payload);
63    ///         socket.emit("test", json!({"hello": true})).expect("Server unreachable");
64    ///      })
65    ///     .connect()
66    ///     .expect("connection failed");
67    ///
68    /// let json_payload = json!({"token": 123});
69    ///
70    /// let result = socket.emit("foo", json_payload);
71    ///
72    /// assert!(result.is_ok());
73    /// ```
74    pub fn emit<E, D>(&self, event: E, data: D) -> Result<()>
75    where
76        E: Into<Event>,
77        D: Into<Payload>,
78    {
79        let client = self.client.read()?;
80        // TODO(#230): like js client, buffer emit, resend after reconnect
81        client.emit(event, data)
82    }
83
84    /// Sends a message to the server but `alloc`s an `ack` to check whether the
85    /// server responded in a given time span. This message takes an event, which
86    /// could either be one of the common events like "message" or "error" or a
87    /// custom event like "foo", as well as a data parameter. But be careful,
88    /// in case you send a [`Payload::String`], the string needs to be valid JSON.
89    /// It's even recommended to use a library like serde_json to serialize the data properly.
90    /// It also requires a timeout `Duration` in which the client needs to answer.
91    /// If the ack is acked in the correct time span, the specified callback is
92    /// called. The callback consumes a [`Payload`] which represents the data send
93    /// by the server.
94    ///
95    /// # Example
96    /// ```
97    /// use tf_rust_socketio::{ClientBuilder, Payload, RawClient};
98    /// use serde_json::json;
99    /// use std::time::Duration;
100    /// use std::thread::sleep;
101    ///
102    /// let mut socket = ClientBuilder::new("http://localhost:4200/")
103    ///     .on("foo", |payload: Payload, _| println!("Received: {:#?}", payload))
104    ///     .connect()
105    ///     .expect("connection failed");
106    ///
107    /// let ack_callback = |message: Payload, socket: RawClient| {
108    ///     match message {
109    ///         Payload::Text(values, _) => println!("{:#?}", values),
110    ///         Payload::Binary(bytes, _) => println!("Received bytes: {:#?}", bytes),
111    ///         // This is deprecated, use Payload::Text instead.
112    ///         #[allow(deprecated)]
113    ///         Payload::String(str, _) => println!("{}", str),
114    ///    }
115    /// };
116    ///
117    /// let payload = json!({"token": 123});
118    /// socket.emit_with_ack("foo", payload, Duration::from_secs(2), ack_callback).unwrap();
119    ///
120    /// sleep(Duration::from_secs(2));
121    /// ```
122    pub fn emit_with_ack<F, E, D>(
123        &self,
124        event: E,
125        data: D,
126        timeout: Duration,
127        callback: F,
128    ) -> Result<()>
129    where
130        F: FnMut(Payload, RawClient) + 'static + Send,
131        E: Into<Event>,
132        D: Into<Payload>,
133    {
134        let client = self.client.read()?;
135        // TODO(#230): like js client, buffer emit, resend after reconnect
136        client.emit_with_ack(event, data, timeout, callback)
137    }
138
139    /// Disconnects this client from the server by sending a `socket.io` closing
140    /// packet.
141    /// # Example
142    /// ```rust
143    /// use tf_rust_socketio::{ClientBuilder, Payload, RawClient};
144    /// use serde_json::json;
145    ///
146    /// fn handle_test(payload: Payload, socket: RawClient) {
147    ///     println!("Received: {:#?}", payload);
148    ///     socket.emit("test", json!({"hello": true})).expect("Server unreachable");
149    /// }
150    ///
151    /// let mut socket = ClientBuilder::new("http://localhost:4200/")
152    ///     .on("test", handle_test)
153    ///     .connect()
154    ///     .expect("connection failed");
155    ///
156    /// let json_payload = json!({"token": 123});
157    ///
158    /// socket.emit("foo", json_payload);
159    ///
160    /// // disconnect from the server
161    /// socket.disconnect();
162    ///
163    /// ```
164    pub fn disconnect(&self) -> Result<()> {
165        let client = self.client.read()?;
166        client.disconnect()
167    }
168
169    fn reconnect(&mut self) -> Result<()> {
170        let mut reconnect_attempts = 0;
171        let (reconnect, max_reconnect_attempts) = {
172            let builder = self.builder.lock()?;
173            (builder.reconnect, builder.max_reconnect_attempts)
174        };
175
176        if reconnect {
177            loop {
178                if let Some(max_reconnect_attempts) = max_reconnect_attempts {
179                    reconnect_attempts += 1;
180                    if reconnect_attempts > max_reconnect_attempts {
181                        break;
182                    }
183                }
184
185                if let Some(backoff) = self.backoff.next_backoff() {
186                    std::thread::sleep(backoff);
187                }
188
189                if self.do_reconnect().is_ok() {
190                    break;
191                }
192            }
193        }
194
195        Ok(())
196    }
197
198    fn do_reconnect(&self) -> Result<()> {
199        let builder = self.builder.lock()?;
200        let new_client = builder.clone().connect_raw()?;
201        let mut client = self.client.write()?;
202        *client = new_client;
203
204        Ok(())
205    }
206
207    pub(crate) fn iter(&self) -> Iter {
208        Iter {
209            socket: self.client.clone(),
210        }
211    }
212
213    fn poll_callback(&self) {
214        let mut self_clone = self.clone();
215        // Use thread to consume items in iterator in order to call callbacks
216        std::thread::spawn(move || {
217            // tries to restart a poll cycle whenever a 'normal' error occurs,
218            // it just panics on network errors, in case the poll cycle returned
219            // `Result::Ok`, the server receives a close frame so it's safe to
220            // terminate
221            for packet in self_clone.iter() {
222                let should_reconnect = match packet {
223                    Err(Error::IncompleteResponseFromEngineIo(_)) => {
224                        //TODO: 0.3.X handle errors
225                        //TODO: logging error
226                        true
227                    }
228                    Ok(Packet {
229                        packet_type: PacketId::Disconnect,
230                        ..
231                    }) => match self_clone.builder.lock() {
232                        Ok(builder) => builder.reconnect_on_disconnect,
233                        Err(_) => false,
234                    },
235                    _ => false,
236                };
237                if should_reconnect {
238                    let _ = self_clone.disconnect();
239                    let _ = self_clone.reconnect();
240                }
241            }
242        });
243    }
244}
245
246pub(crate) struct Iter {
247    socket: Arc<RwLock<RawClient>>,
248}
249
250impl Iterator for Iter {
251    type Item = Result<Packet>;
252
253    fn next(&mut self) -> Option<Self::Item> {
254        let socket = self.socket.read();
255        match socket {
256            Ok(socket) => match socket.poll() {
257                Err(err) => Some(Err(err)),
258                Ok(Some(packet)) => Some(Ok(packet)),
259                // If the underlying engineIO connection is closed,
260                // throw an error so we know to reconnect
261                Ok(None) => Some(Err(Error::StoppedEngineIoSocket)),
262            },
263            Err(_) => {
264                // Lock is poisoned, our iterator is useless.
265                None
266            }
267        }
268    }
269}
270
#[cfg(test)]
mod test {
    use std::{
        sync::atomic::{AtomicUsize, Ordering},
        time::UNIX_EPOCH,
    };

    use super::*;
    use crate::error::Result;
    use crate::ClientBuilder;
    use serde_json::json;
    use serial_test::serial;
    use std::time::{Duration, SystemTime};
    use url::Url;

    /// Restarting the server should produce exactly one close and one
    /// reconnect, after which the `message` handler fires a second time.
    #[test]
    #[serial(reconnect)]
    fn socket_io_reconnect_integration() -> Result<()> {
        static CONNECT_NUM: AtomicUsize = AtomicUsize::new(0);
        static CLOSE_NUM: AtomicUsize = AtomicUsize::new(0);
        static MESSAGE_NUM: AtomicUsize = AtomicUsize::new(0);

        let url = crate::test::socket_io_restart_server();

        let connection = ClientBuilder::new(url)
            .reconnect(true)
            .max_reconnect_attempts(100)
            .reconnect_delay(100, 100)
            .on(Event::Connect, move |_, socket| {
                CONNECT_NUM.fetch_add(1, Ordering::Release);
                // The test server is expected to answer with a `message` event
                // (see the MESSAGE_NUM assertions below).
                let emitted = socket.emit_with_ack(
                    "message",
                    json!(""),
                    Duration::from_millis(100),
                    |_, _| {},
                );
                assert!(emitted.is_ok(), "should emit message success");
            })
            .on(Event::Close, move |_, _| {
                CLOSE_NUM.fetch_add(1, Ordering::Release);
            })
            .on("message", move |_, _socket| {
                // test the iterator implementation and make sure there is a constant
                // stream of packets, even when reconnecting
                MESSAGE_NUM.fetch_add(1, Ordering::Release);
            })
            .connect();

        assert!(connection.is_ok(), "should connect success");
        let socket = connection.unwrap();

        // waiting for server to emit message
        std::thread::sleep(Duration::from_millis(500));

        assert_eq!(load(&CONNECT_NUM), 1, "should connect once");
        assert_eq!(load(&MESSAGE_NUM), 1, "should receive one");
        assert_eq!(load(&CLOSE_NUM), 0, "should not close");

        let restart = socket.emit("restart_server", json!(""));
        assert!(restart.is_ok(), "should emit restart success");

        // waiting for server to restart
        wait_for_second_connection(&CONNECT_NUM, &MESSAGE_NUM);

        assert_eq!(load(&CONNECT_NUM), 2, "should connect twice");
        assert_eq!(load(&MESSAGE_NUM), 2, "should receive two messages");
        assert_eq!(load(&CLOSE_NUM), 1, "should close once");

        socket.disconnect()?;
        Ok(())
    }

    /// Like `socket_io_reconnect_integration`, but the server authenticates
    /// via a timestamp in the URL. The reconnect must therefore use the URL
    /// supplied through `set_reconnect_url`.
    #[test]
    fn socket_io_reconnect_url_auth_integration() -> Result<()> {
        static CONNECT_NUM: AtomicUsize = AtomicUsize::new(0);
        static CLOSE_NUM: AtomicUsize = AtomicUsize::new(0);
        static MESSAGE_NUM: AtomicUsize = AtomicUsize::new(0);

        /// Builds the auth server URL with the current time in milliseconds as
        /// the `timestamp` query parameter.
        fn timestamped_url() -> Url {
            let millis = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_millis();
            let mut address = crate::test::socket_io_restart_url_auth_server();
            address.set_query(Some(&format!("timestamp={}", millis)));
            address
        }

        let connection = ClientBuilder::new(timestamped_url())
            .reconnect(true)
            .max_reconnect_attempts(100)
            .reconnect_delay(100, 100)
            .on(Event::Connect, move |_, socket| {
                CONNECT_NUM.fetch_add(1, Ordering::Release);
                let emitted = socket.emit_with_ack(
                    "message",
                    json!(""),
                    Duration::from_millis(100),
                    |_, _| {},
                );
                assert!(emitted.is_ok(), "should emit message success");
            })
            .on(Event::Close, move |_, _| {
                CLOSE_NUM.fetch_add(1, Ordering::Release);
            })
            .on("message", move |_, _| {
                // test the iterator implementation and make sure there is a constant
                // stream of packets, even when reconnecting
                MESSAGE_NUM.fetch_add(1, Ordering::Release);
            })
            .connect();

        assert!(connection.is_ok(), "should connect success");
        let socket = connection.unwrap();

        // waiting for server to emit message
        std::thread::sleep(Duration::from_millis(500));

        assert_eq!(load(&CONNECT_NUM), 1, "should connect once");
        assert_eq!(load(&MESSAGE_NUM), 1, "should receive one");
        assert_eq!(load(&CLOSE_NUM), 0, "should not close");

        // waiting for timestamp in url to expire
        std::thread::sleep(Duration::from_secs(1));

        socket.set_reconnect_url(timestamped_url())?;

        let restart = socket.emit("restart_server", json!(""));
        assert!(restart.is_ok(), "should emit restart success");

        // waiting for server to restart
        wait_for_second_connection(&CONNECT_NUM, &MESSAGE_NUM);

        assert_eq!(load(&CONNECT_NUM), 2, "should connect twice");
        assert_eq!(load(&MESSAGE_NUM), 2, "should receive two messages");
        assert_eq!(load(&CLOSE_NUM), 1, "should close once");

        socket.disconnect()?;
        Ok(())
    }

    /// Packets must keep flowing through `Client::iter` across a manual
    /// disconnect + reconnect.
    #[test]
    fn socket_io_iterator_integration() -> Result<()> {
        let url = crate::test::socket_io_server();
        let builder = ClientBuilder::new(url);
        let raw_client = builder.clone().connect_raw()?;

        // Built by hand (not via `Client::new`) so no background poll thread
        // competes with the reader thread below.
        let mut socket = Client {
            builder: Arc::new(Mutex::new(builder)),
            client: Arc::new(RwLock::new(raw_client)),
            backoff: Default::default(),
        };
        let reader = socket.clone();

        let packets: Arc<RwLock<Vec<Packet>>> = Default::default();
        let sink = Arc::clone(&packets);

        std::thread::spawn(move || {
            for packet in reader.iter().filter_map(|item| item.ok()) {
                sink.write().unwrap().push(packet);
            }
        });

        // waiting for client to emit messages
        std::thread::sleep(Duration::from_millis(100));
        let pre_num = packets.read().unwrap().len();

        let _ = socket.disconnect();
        socket.reconnect()?;

        // waiting for client to emit messages
        std::thread::sleep(Duration::from_millis(100));
        let post_num = packets.read().unwrap().len();

        assert!(
            pre_num < post_num,
            "pre_num {} should less than post_num {}",
            pre_num,
            post_num
        );

        Ok(())
    }

    /// Polls every 400ms, at most 10 times, until both counters have reached 2.
    fn wait_for_second_connection(connects: &AtomicUsize, messages: &AtomicUsize) {
        for _ in 0..10 {
            std::thread::sleep(Duration::from_millis(400));
            if load(connects) == 2 && load(messages) == 2 {
                break;
            }
        }
    }

    /// Reads a counter with `Acquire` ordering to pair with the handlers'
    /// `Release` increments.
    fn load(num: &AtomicUsize) -> usize {
        num.load(Ordering::Acquire)
    }
}
475
476    fn load(num: &AtomicUsize) -> usize {
477        num.load(Ordering::Acquire)
478    }
479}