rust_socketio/client/
client.rs

1use std::{
2    sync::{Arc, Mutex, RwLock},
3    time::Duration,
4};
5
6use super::{ClientBuilder, RawClient};
7use crate::{
8    error::Result,
9    packet::{Packet, PacketId},
10    Error,
11};
12pub(crate) use crate::{event::Event, payload::Payload};
13use backoff::ExponentialBackoff;
14use backoff::{backoff::Backoff, ExponentialBackoffBuilder};
15
/// Client handle that wraps a [`RawClient`] and can replace it with a freshly
/// connected one when a reconnect is performed.
///
/// Cloning is cheap for the shared parts: `builder` and `client` live behind
/// `Arc`s, so every clone observes the same configuration and the same
/// underlying connection. The `backoff` state is copied per clone.
#[derive(Clone)]
pub struct Client {
    // Configuration used to create the connection; cloned and connected again
    // on every reconnect attempt.
    builder: Arc<Mutex<ClientBuilder>>,
    // The currently active connection; swapped out under the write lock when a
    // reconnect succeeds.
    client: Arc<RwLock<RawClient>>,
    // Produces the delay that is slept before each reconnect attempt.
    backoff: ExponentialBackoff,
}
22
23impl Client {
24    pub(crate) fn new(builder: ClientBuilder) -> Result<Self> {
25        let builder_clone = builder.clone();
26        let client = builder_clone.connect_raw()?;
27        let backoff = ExponentialBackoffBuilder::new()
28            .with_initial_interval(Duration::from_millis(builder.reconnect_delay_min))
29            .with_max_interval(Duration::from_millis(builder.reconnect_delay_max))
30            .build();
31
32        let s = Self {
33            builder: Arc::new(Mutex::new(builder)),
34            client: Arc::new(RwLock::new(client)),
35            backoff,
36        };
37        s.poll_callback();
38
39        Ok(s)
40    }
41
42    /// Updates the URL the client will connect to when reconnecting.
43    /// This is especially useful for updating query parameters.
44    pub fn set_reconnect_url<T: Into<String>>(&self, address: T) -> Result<()> {
45        self.builder.lock()?.address = address.into();
46        Ok(())
47    }
48
49    /// Sends a message to the server using the underlying `engine.io` protocol.
50    /// This message takes an event, which could either be one of the common
51    /// events like "message" or "error" or a custom event like "foo". But be
52    /// careful, the data string needs to be valid JSON. It's recommended to use
53    /// a library like `serde_json` to serialize the data properly.
54    ///
55    /// # Example
56    /// ```
57    /// use rust_socketio::{ClientBuilder, RawClient, Payload};
58    /// use serde_json::json;
59    ///
60    /// let mut socket = ClientBuilder::new("http://localhost:4200/")
61    ///     .on("test", |payload: Payload, socket: RawClient| {
62    ///         println!("Received: {:#?}", payload);
63    ///         socket.emit("test", json!({"hello": true})).expect("Server unreachable");
64    ///      })
65    ///     .connect()
66    ///     .expect("connection failed");
67    ///
68    /// let json_payload = json!({"token": 123});
69    ///
70    /// let result = socket.emit("foo", json_payload);
71    ///
72    /// assert!(result.is_ok());
73    /// ```
74    pub fn emit<E, D>(&self, event: E, data: D) -> Result<()>
75    where
76        E: Into<Event>,
77        D: Into<Payload>,
78    {
79        let client = self.client.read()?;
80        // TODO(#230): like js client, buffer emit, resend after reconnect
81        client.emit(event, data)
82    }
83
84    /// Sends a message to the server but `alloc`s an `ack` to check whether the
85    /// server responded in a given time span. This message takes an event, which
86    /// could either be one of the common events like "message" or "error" or a
87    /// custom event like "foo", as well as a data parameter. But be careful,
88    /// in case you send a [`Payload::String`], the string needs to be valid JSON.
89    /// It's even recommended to use a library like serde_json to serialize the data properly.
90    /// It also requires a timeout `Duration` in which the client needs to answer.
91    /// If the ack is acked in the correct time span, the specified callback is
92    /// called. The callback consumes a [`Payload`] which represents the data send
93    /// by the server.
94    ///
95    /// # Example
96    /// ```
97    /// use rust_socketio::{ClientBuilder, Payload, RawClient};
98    /// use serde_json::json;
99    /// use std::time::Duration;
100    /// use std::thread::sleep;
101    ///
102    /// let mut socket = ClientBuilder::new("http://localhost:4200/")
103    ///     .on("foo", |payload: Payload, _| println!("Received: {:#?}", payload))
104    ///     .connect()
105    ///     .expect("connection failed");
106    ///
107    /// let ack_callback = |message: Payload, socket: RawClient| {
108    ///     match message {
109    ///         Payload::Text(values) => println!("{:#?}", values),
110    ///         Payload::Binary(bytes) => println!("Received bytes: {:#?}", bytes),
111    ///         // This is deprecated, use Payload::Text instead.
112    ///         Payload::String(str) => println!("{}", str),
113    ///    }
114    /// };
115    ///
116    /// let payload = json!({"token": 123});
117    /// socket.emit_with_ack("foo", payload, Duration::from_secs(2), ack_callback).unwrap();
118    ///
119    /// sleep(Duration::from_secs(2));
120    /// ```
121    pub fn emit_with_ack<F, E, D>(
122        &self,
123        event: E,
124        data: D,
125        timeout: Duration,
126        callback: F,
127    ) -> Result<()>
128    where
129        F: FnMut(Payload, RawClient) + 'static + Send,
130        E: Into<Event>,
131        D: Into<Payload>,
132    {
133        let client = self.client.read()?;
134        // TODO(#230): like js client, buffer emit, resend after reconnect
135        client.emit_with_ack(event, data, timeout, callback)
136    }
137
138    /// Disconnects this client from the server by sending a `socket.io` closing
139    /// packet.
140    /// # Example
141    /// ```rust
142    /// use rust_socketio::{ClientBuilder, Payload, RawClient};
143    /// use serde_json::json;
144    ///
145    /// fn handle_test(payload: Payload, socket: RawClient) {
146    ///     println!("Received: {:#?}", payload);
147    ///     socket.emit("test", json!({"hello": true})).expect("Server unreachable");
148    /// }
149    ///
150    /// let mut socket = ClientBuilder::new("http://localhost:4200/")
151    ///     .on("test", handle_test)
152    ///     .connect()
153    ///     .expect("connection failed");
154    ///
155    /// let json_payload = json!({"token": 123});
156    ///
157    /// socket.emit("foo", json_payload);
158    ///
159    /// // disconnect from the server
160    /// socket.disconnect();
161    ///
162    /// ```
163    pub fn disconnect(&self) -> Result<()> {
164        let client = self.client.read()?;
165        client.disconnect()
166    }
167
168    fn reconnect(&mut self) -> Result<()> {
169        let mut reconnect_attempts = 0;
170        let (reconnect, max_reconnect_attempts) = {
171            let builder = self.builder.lock()?;
172            (builder.reconnect, builder.max_reconnect_attempts)
173        };
174
175        if reconnect {
176            loop {
177                if let Some(max_reconnect_attempts) = max_reconnect_attempts {
178                    reconnect_attempts += 1;
179                    if reconnect_attempts > max_reconnect_attempts {
180                        break;
181                    }
182                }
183
184                if let Some(backoff) = self.backoff.next_backoff() {
185                    std::thread::sleep(backoff);
186                }
187
188                if self.do_reconnect().is_ok() {
189                    break;
190                }
191            }
192        }
193
194        Ok(())
195    }
196
197    fn do_reconnect(&self) -> Result<()> {
198        let builder = self.builder.lock()?;
199        let new_client = builder.clone().connect_raw()?;
200        let mut client = self.client.write()?;
201        *client = new_client;
202
203        Ok(())
204    }
205
206    pub(crate) fn iter(&self) -> Iter {
207        Iter {
208            socket: self.client.clone(),
209        }
210    }
211
212    fn poll_callback(&self) {
213        let mut self_clone = self.clone();
214        // Use thread to consume items in iterator in order to call callbacks
215        std::thread::spawn(move || {
216            // tries to restart a poll cycle whenever a 'normal' error occurs,
217            // it just panics on network errors, in case the poll cycle returned
218            // `Result::Ok`, the server receives a close frame so it's safe to
219            // terminate
220            for packet in self_clone.iter() {
221                let should_reconnect = match packet {
222                    Err(Error::IncompleteResponseFromEngineIo(_)) => {
223                        //TODO: 0.3.X handle errors
224                        //TODO: logging error
225                        true
226                    }
227                    Ok(Packet {
228                        packet_type: PacketId::Disconnect,
229                        ..
230                    }) => match self_clone.builder.lock() {
231                        Ok(builder) => builder.reconnect_on_disconnect,
232                        Err(_) => false,
233                    },
234                    _ => false,
235                };
236                if should_reconnect {
237                    let _ = self_clone.disconnect();
238                    let _ = self_clone.reconnect();
239                }
240            }
241        });
242    }
243}
244
/// Iterator over the packets polled from a shared [`RawClient`].
pub(crate) struct Iter {
    // Shared handle to the connection; a read lock is taken for each poll.
    socket: Arc<RwLock<RawClient>>,
}
248
249impl Iterator for Iter {
250    type Item = Result<Packet>;
251
252    fn next(&mut self) -> Option<Self::Item> {
253        let socket = self.socket.read();
254        match socket {
255            Ok(socket) => match socket.poll() {
256                Err(err) => Some(Err(err)),
257                Ok(Some(packet)) => Some(Ok(packet)),
258                // If the underlying engineIO connection is closed,
259                // throw an error so we know to reconnect
260                Ok(None) => Some(Err(Error::StoppedEngineIoSocket)),
261            },
262            Err(_) => {
263                // Lock is poisoned, our iterator is useless.
264                None
265            }
266        }
267    }
268}
269
270#[cfg(test)]
271mod test {
272    use std::{
273        sync::atomic::{AtomicUsize, Ordering},
274        time::UNIX_EPOCH,
275    };
276
277    use super::*;
278    use crate::error::Result;
279    use crate::ClientBuilder;
280    use serde_json::json;
281    use serial_test::serial;
282    use std::time::{Duration, SystemTime};
283    use url::Url;
284
    /// Checks that the client reconnects after the server is restarted.
    ///
    /// The connect, close and message callbacks each bump a counter. After the
    /// initial connection one connect and one message are expected; after
    /// emitting `restart_server` the counters are expected to reach two
    /// connects, two messages and one close.
    ///
    /// NOTE(review): presumably needs the restart test server behind
    /// `crate::test::socket_io_restart_server()` to be running — confirm.
    #[test]
    #[serial(reconnect)]
    fn socket_io_reconnect_integration() -> Result<()> {
        // Statics so the `'static` callbacks can update them without captures.
        static CONNECT_NUM: AtomicUsize = AtomicUsize::new(0);
        static CLOSE_NUM: AtomicUsize = AtomicUsize::new(0);
        static MESSAGE_NUM: AtomicUsize = AtomicUsize::new(0);

        let url = crate::test::socket_io_restart_server();

        let socket = ClientBuilder::new(url)
            .reconnect(true)
            .max_reconnect_attempts(100)
            .reconnect_delay(100, 100)
            .on(Event::Connect, move |_, socket| {
                CONNECT_NUM.fetch_add(1, Ordering::Release);
                let r = socket.emit_with_ack(
                    "message",
                    json!(""),
                    Duration::from_millis(100),
                    |_, _| {},
                );
                assert!(r.is_ok(), "should emit message success");
            })
            .on(Event::Close, move |_, _| {
                CLOSE_NUM.fetch_add(1, Ordering::Release);
            })
            .on("message", move |_, _socket| {
                // test the iterator implementation and make sure there is a constant
                // stream of packets, even when reconnecting
                MESSAGE_NUM.fetch_add(1, Ordering::Release);
            })
            .connect();

        assert!(socket.is_ok(), "should connect success");
        let socket = socket.unwrap();

        // waiting for server to emit message
        std::thread::sleep(std::time::Duration::from_millis(500));

        assert_eq!(load(&CONNECT_NUM), 1, "should connect once");
        assert_eq!(load(&MESSAGE_NUM), 1, "should receive one");
        assert_eq!(load(&CLOSE_NUM), 0, "should not close");

        let r = socket.emit("restart_server", json!(""));
        assert!(r.is_ok(), "should emit restart success");

        // waiting for server to restart; polls up to 10 times, 400 ms apart,
        // and stops early once the second connect and message were observed
        for _ in 0..10 {
            std::thread::sleep(std::time::Duration::from_millis(400));
            if load(&CONNECT_NUM) == 2 && load(&MESSAGE_NUM) == 2 {
                break;
            }
        }

        assert_eq!(load(&CONNECT_NUM), 2, "should connect twice");
        assert_eq!(load(&MESSAGE_NUM), 2, "should receive two messages");
        assert_eq!(load(&CLOSE_NUM), 1, "should close once");

        socket.disconnect()?;
        Ok(())
    }
346
    /// Checks that a reconnect uses the URL set through `set_reconnect_url`.
    ///
    /// The URL carries the current time as a `timestamp` query parameter. After
    /// the first connection the test waits, stores a URL with a fresh
    /// timestamp, restarts the server and expects a second connect, a second
    /// message and exactly one close.
    ///
    /// NOTE(review): presumably the server behind
    /// `crate::test::socket_io_restart_url_auth_server()` rejects stale
    /// timestamps — confirm against the test server.
    #[test]
    fn socket_io_reconnect_url_auth_integration() -> Result<()> {
        // Statics so the `'static` callbacks can update them without captures.
        static CONNECT_NUM: AtomicUsize = AtomicUsize::new(0);
        static CLOSE_NUM: AtomicUsize = AtomicUsize::new(0);
        static MESSAGE_NUM: AtomicUsize = AtomicUsize::new(0);

        // Builds the server URL with the current time (milliseconds since the
        // Unix epoch) as its `timestamp` query parameter.
        fn get_url() -> Url {
            let timestamp = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_millis();
            let mut url = crate::test::socket_io_restart_url_auth_server();
            url.set_query(Some(&format!("timestamp={timestamp}")));
            url
        }

        let socket = ClientBuilder::new(get_url())
            .reconnect(true)
            .max_reconnect_attempts(100)
            .reconnect_delay(100, 100)
            .on(Event::Connect, move |_, socket| {
                CONNECT_NUM.fetch_add(1, Ordering::Release);
                let result = socket.emit_with_ack(
                    "message",
                    json!(""),
                    Duration::from_millis(100),
                    |_, _| {},
                );
                assert!(result.is_ok(), "should emit message success");
            })
            .on(Event::Close, move |_, _| {
                CLOSE_NUM.fetch_add(1, Ordering::Release);
            })
            .on("message", move |_, _| {
                // test the iterator implementation and make sure there is a constant
                // stream of packets, even when reconnecting
                MESSAGE_NUM.fetch_add(1, Ordering::Release);
            })
            .connect();

        assert!(socket.is_ok(), "should connect success");
        let socket = socket.unwrap();

        // waiting for server to emit message
        std::thread::sleep(std::time::Duration::from_millis(500));

        assert_eq!(load(&CONNECT_NUM), 1, "should connect once");
        assert_eq!(load(&MESSAGE_NUM), 1, "should receive one");
        assert_eq!(load(&CLOSE_NUM), 0, "should not close");

        // waiting for timestamp in url to expire
        std::thread::sleep(std::time::Duration::from_secs(1));

        // store a URL with a fresh timestamp for the upcoming reconnect
        socket.set_reconnect_url(get_url())?;

        let result = socket.emit("restart_server", json!(""));
        assert!(result.is_ok(), "should emit restart success");

        // waiting for server to restart; polls up to 10 times, 400 ms apart,
        // and stops early once the second connect and message were observed
        for _ in 0..10 {
            std::thread::sleep(std::time::Duration::from_millis(400));
            if load(&CONNECT_NUM) == 2 && load(&MESSAGE_NUM) == 2 {
                break;
            }
        }

        assert_eq!(load(&CONNECT_NUM), 2, "should connect twice");
        assert_eq!(load(&MESSAGE_NUM), 2, "should receive two messages");
        assert_eq!(load(&CLOSE_NUM), 1, "should close once");

        socket.disconnect()?;
        Ok(())
    }
420
    /// Checks that an iterator created before a reconnect keeps yielding
    /// packets afterwards.
    ///
    /// A background thread collects every successfully polled packet. The
    /// number collected before a manual disconnect/reconnect must be smaller
    /// than the number collected after it.
    #[test]
    fn socket_io_iterator_integration() -> Result<()> {
        let url = crate::test::socket_io_server();
        let builder = ClientBuilder::new(url);
        let builder_clone = builder.clone();

        // Assemble the client by hand instead of calling `Client::new`, so the
        // polling thread of `poll_callback` is not started and this test's own
        // thread is the one consuming the iterator.
        let client = Arc::new(RwLock::new(builder_clone.connect_raw()?));
        let mut socket = Client {
            builder: Arc::new(Mutex::new(builder)),
            client,
            backoff: Default::default(),
        };
        let socket_clone = socket.clone();

        let packets: Arc<RwLock<Vec<Packet>>> = Default::default();
        let packets_clone = packets.clone();

        // Collect every `Ok` packet; errors yielded by the iterator are skipped.
        std::thread::spawn(move || {
            for packet in socket_clone.iter() {
                {
                    let mut packets = packets_clone.write().unwrap();
                    if let Ok(packet) = packet {
                        (*packets).push(packet);
                    }
                }
            }
        });

        // waiting for client to emit messages
        std::thread::sleep(Duration::from_millis(100));
        let lock = packets.read().unwrap();
        let pre_num = lock.len();
        drop(lock);

        // The disconnect result is ignored; only the reconnect must succeed.
        let _ = socket.disconnect();
        socket.reconnect()?;

        // waiting for client to emit messages
        std::thread::sleep(Duration::from_millis(100));

        let lock = packets.read().unwrap();
        let post_num = lock.len();
        drop(lock);

        assert!(
            pre_num < post_num,
            "pre_num {} should less than post_num {}",
            pre_num,
            post_num
        );

        Ok(())
    }
474
475    fn load(num: &AtomicUsize) -> usize {
476        num.load(Ordering::Acquire)
477    }
478}