1use std::{
2 sync::{Arc, Mutex, RwLock},
3 time::Duration,
4};
5
6use super::{ClientBuilder, RawClient};
7use crate::{
8 error::Result,
9 packet::{Packet, PacketId},
10 Error,
11};
12pub(crate) use crate::{event::Event, payload::Payload};
13use backoff::ExponentialBackoff;
14use backoff::{backoff::Backoff, ExponentialBackoffBuilder};
15
/// A socket.io client that replaces its underlying connection when it is lost.
///
/// Clones share the same builder and the same underlying client through `Arc`s;
/// each clone carries its own copy of the backoff state.
#[derive(Clone)]
pub struct Client {
    /// Settings used for every (re)connection; the address may be replaced at
    /// runtime via `set_reconnect_url`.
    builder: Arc<Mutex<ClientBuilder>>,
    /// The currently active connection, swapped out wholesale on reconnect.
    client: Arc<RwLock<RawClient>>,
    /// Delay policy applied between reconnection attempts.
    backoff: ExponentialBackoff,
}
22
23impl Client {
24 pub(crate) fn new(builder: ClientBuilder) -> Result<Self> {
25 let builder_clone = builder.clone();
26 let client = builder_clone.connect_raw()?;
27 let backoff = ExponentialBackoffBuilder::new()
28 .with_initial_interval(Duration::from_millis(builder.reconnect_delay_min))
29 .with_max_interval(Duration::from_millis(builder.reconnect_delay_max))
30 .build();
31
32 let s = Self {
33 builder: Arc::new(Mutex::new(builder)),
34 client: Arc::new(RwLock::new(client)),
35 backoff,
36 };
37 s.poll_callback();
38
39 Ok(s)
40 }
41
42 pub fn set_reconnect_url<T: Into<String>>(&self, address: T) -> Result<()> {
45 self.builder.lock()?.address = address.into();
46 Ok(())
47 }
48
49 pub fn emit<E, D>(&self, event: E, data: D) -> Result<()>
75 where
76 E: Into<Event>,
77 D: Into<Payload>,
78 {
79 let client = self.client.read()?;
80 client.emit(event, data)
82 }
83
84 pub fn emit_with_ack<F, E, D>(
123 &self,
124 event: E,
125 data: D,
126 timeout: Duration,
127 callback: F,
128 ) -> Result<()>
129 where
130 F: FnMut(Payload, RawClient) + 'static + Send,
131 E: Into<Event>,
132 D: Into<Payload>,
133 {
134 let client = self.client.read()?;
135 client.emit_with_ack(event, data, timeout, callback)
137 }
138
139 pub fn disconnect(&self) -> Result<()> {
165 let client = self.client.read()?;
166 client.disconnect()
167 }
168
169 fn reconnect(&mut self) -> Result<()> {
170 let mut reconnect_attempts = 0;
171 let (reconnect, max_reconnect_attempts) = {
172 let builder = self.builder.lock()?;
173 (builder.reconnect, builder.max_reconnect_attempts)
174 };
175
176 if reconnect {
177 loop {
178 if let Some(max_reconnect_attempts) = max_reconnect_attempts {
179 reconnect_attempts += 1;
180 if reconnect_attempts > max_reconnect_attempts {
181 break;
182 }
183 }
184
185 if let Some(backoff) = self.backoff.next_backoff() {
186 std::thread::sleep(backoff);
187 }
188
189 if self.do_reconnect().is_ok() {
190 break;
191 }
192 }
193 }
194
195 Ok(())
196 }
197
198 fn do_reconnect(&self) -> Result<()> {
199 let builder = self.builder.lock()?;
200 let new_client = builder.clone().connect_raw()?;
201 let mut client = self.client.write()?;
202 *client = new_client;
203
204 Ok(())
205 }
206
207 pub(crate) fn iter(&self) -> Iter {
208 Iter {
209 socket: self.client.clone(),
210 }
211 }
212
213 fn poll_callback(&self) {
214 let mut self_clone = self.clone();
215 std::thread::spawn(move || {
217 for packet in self_clone.iter() {
222 let should_reconnect = match packet {
223 Err(Error::IncompleteResponseFromEngineIo(_)) => {
224 true
227 }
228 Ok(Packet {
229 packet_type: PacketId::Disconnect,
230 ..
231 }) => match self_clone.builder.lock() {
232 Ok(builder) => builder.reconnect_on_disconnect,
233 Err(_) => false,
234 },
235 _ => false,
236 };
237 if should_reconnect {
238 let _ = self_clone.disconnect();
239 let _ = self_clone.reconnect();
240 }
241 }
242 });
243 }
244}
245
/// Iterator over packets received by a client shared behind an `RwLock`.
///
/// Each `next` call takes a read lock on the client and polls it once.
pub(crate) struct Iter {
    /// The shared client being polled; the same `Arc` the owning `Client` holds,
    /// so a client swapped in by a reconnect is picked up automatically.
    socket: Arc<RwLock<RawClient>>,
}
249
250impl Iterator for Iter {
251 type Item = Result<Packet>;
252
253 fn next(&mut self) -> Option<Self::Item> {
254 let socket = self.socket.read();
255 match socket {
256 Ok(socket) => match socket.poll() {
257 Err(err) => Some(Err(err)),
258 Ok(Some(packet)) => Some(Ok(packet)),
259 Ok(None) => Some(Err(Error::StoppedEngineIoSocket)),
262 },
263 Err(_) => {
264 None
266 }
267 }
268 }
269}
270
/// Integration tests for the reconnecting client. They talk to external test
/// servers provided by `crate::test` and rely on wall-clock sleeps.
#[cfg(test)]
mod test {
    use std::{
        sync::atomic::{AtomicUsize, Ordering},
        time::UNIX_EPOCH,
    };

    use super::*;
    use crate::error::Result;
    use crate::ClientBuilder;
    use serde_json::json;
    use serial_test::serial;
    use std::time::{Duration, SystemTime};
    use url::Url;

    /// Asks the server to restart and checks that the client reconnects:
    /// exactly one extra connect, one extra message and one close.
    #[test]
    #[serial(reconnect)]
    fn socket_io_reconnect_integration() -> Result<()> {
        // Counters bumped from the event callbacks and read by the test body.
        static CONNECT_NUM: AtomicUsize = AtomicUsize::new(0);
        static CLOSE_NUM: AtomicUsize = AtomicUsize::new(0);
        static MESSAGE_NUM: AtomicUsize = AtomicUsize::new(0);

        let url = crate::test::socket_io_restart_server();

        let socket = ClientBuilder::new(url)
            .reconnect(true)
            .max_reconnect_attempts(100)
            .reconnect_delay(100, 100)
            .on(Event::Connect, move |_, socket| {
                CONNECT_NUM.fetch_add(1, Ordering::Release);
                // Emitted on every (re)connect; the server presumably answers
                // with a "message" event, counted below — TODO confirm.
                let r = socket.emit_with_ack(
                    "message",
                    json!(""),
                    Duration::from_millis(100),
                    |_, _| {},
                );
                assert!(r.is_ok(), "should emit message success");
            })
            .on(Event::Close, move |_, _| {
                CLOSE_NUM.fetch_add(1, Ordering::Release);
            })
            .on("message", move |_, _socket| {
                MESSAGE_NUM.fetch_add(1, Ordering::Release);
            })
            .connect();

        assert!(socket.is_ok(), "should connect success");
        let socket = socket.unwrap();

        // Give the initial connect/message round-trip time to complete.
        std::thread::sleep(std::time::Duration::from_millis(500));

        assert_eq!(load(&CONNECT_NUM), 1, "should connect once");
        assert_eq!(load(&MESSAGE_NUM), 1, "should receive one");
        assert_eq!(load(&CLOSE_NUM), 0, "should not close");

        let r = socket.emit("restart_server", json!(""));
        assert!(r.is_ok(), "should emit restart success");

        // Poll for up to ~4 seconds for the reconnect to be observed.
        for _ in 0..10 {
            std::thread::sleep(std::time::Duration::from_millis(400));
            if load(&CONNECT_NUM) == 2 && load(&MESSAGE_NUM) == 2 {
                break;
            }
        }

        assert_eq!(load(&CONNECT_NUM), 2, "should connect twice");
        assert_eq!(load(&MESSAGE_NUM), 2, "should receive two messages");
        assert_eq!(load(&CLOSE_NUM), 1, "should close once");

        socket.disconnect()?;
        Ok(())
    }

    /// Like `socket_io_reconnect_integration`, but against a server whose URL
    /// carries a timestamp query parameter; the reconnect URL is replaced via
    /// `set_reconnect_url` before the restart.
    #[test]
    fn socket_io_reconnect_url_auth_integration() -> Result<()> {
        static CONNECT_NUM: AtomicUsize = AtomicUsize::new(0);
        static CLOSE_NUM: AtomicUsize = AtomicUsize::new(0);
        static MESSAGE_NUM: AtomicUsize = AtomicUsize::new(0);

        // Builds the auth-server URL with the current time in milliseconds as
        // the `timestamp` query parameter.
        fn get_url() -> Url {
            let timestamp = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_millis();
            let mut url = crate::test::socket_io_restart_url_auth_server();
            url.set_query(Some(&format!("timestamp={timestamp}")));
            url
        }

        let socket = ClientBuilder::new(get_url())
            .reconnect(true)
            .max_reconnect_attempts(100)
            .reconnect_delay(100, 100)
            .on(Event::Connect, move |_, socket| {
                CONNECT_NUM.fetch_add(1, Ordering::Release);
                let result = socket.emit_with_ack(
                    "message",
                    json!(""),
                    Duration::from_millis(100),
                    |_, _| {},
                );
                assert!(result.is_ok(), "should emit message success");
            })
            .on(Event::Close, move |_, _| {
                CLOSE_NUM.fetch_add(1, Ordering::Release);
            })
            .on("message", move |_, _| {
                MESSAGE_NUM.fetch_add(1, Ordering::Release);
            })
            .connect();

        assert!(socket.is_ok(), "should connect success");
        let socket = socket.unwrap();

        // Give the initial connect/message round-trip time to complete.
        std::thread::sleep(std::time::Duration::from_millis(500));

        assert_eq!(load(&CONNECT_NUM), 1, "should connect once");
        assert_eq!(load(&MESSAGE_NUM), 1, "should receive one");
        assert_eq!(load(&CLOSE_NUM), 0, "should not close");

        // NOTE(review): presumably ensures the new URL's timestamp differs
        // from the initial one (server-side check) — confirm.
        std::thread::sleep(std::time::Duration::from_secs(1));

        socket.set_reconnect_url(get_url())?;

        let result = socket.emit("restart_server", json!(""));
        assert!(result.is_ok(), "should emit restart success");

        // Poll for up to ~4 seconds for the reconnect to be observed.
        for _ in 0..10 {
            std::thread::sleep(std::time::Duration::from_millis(400));
            if load(&CONNECT_NUM) == 2 && load(&MESSAGE_NUM) == 2 {
                break;
            }
        }

        assert_eq!(load(&CONNECT_NUM), 2, "should connect twice");
        assert_eq!(load(&MESSAGE_NUM), 2, "should receive two messages");
        assert_eq!(load(&CLOSE_NUM), 1, "should close once");

        socket.disconnect()?;
        Ok(())
    }

    /// Checks that an `Iter` obtained before a manual reconnect keeps yielding
    /// packets from the client swapped in afterwards.
    #[test]
    fn socket_io_iterator_integration() -> Result<()> {
        let url = crate::test::socket_io_server();
        let builder = ClientBuilder::new(url);
        let builder_clone = builder.clone();

        // Build the client by hand (rather than via `Client::new`) so no
        // background polling thread is started; this test drives `iter` itself.
        let client = Arc::new(RwLock::new(builder_clone.connect_raw()?));
        let mut socket = Client {
            builder: Arc::new(Mutex::new(builder)),
            client,
            backoff: Default::default(),
        };
        let socket_clone = socket.clone();

        let packets: Arc<RwLock<Vec<Packet>>> = Default::default();
        let packets_clone = packets.clone();

        // Collect every successfully received packet; errors are skipped.
        std::thread::spawn(move || {
            for packet in socket_clone.iter() {
                {
                    let mut packets = packets_clone.write().unwrap();
                    if let Ok(packet) = packet {
                        (*packets).push(packet);
                    }
                }
            }
        });

        std::thread::sleep(Duration::from_millis(100));
        let lock = packets.read().unwrap();
        let pre_num = lock.len();
        drop(lock);

        let _ = socket.disconnect();
        socket.reconnect()?;

        std::thread::sleep(Duration::from_millis(100));

        let lock = packets.read().unwrap();
        let post_num = lock.len();
        drop(lock);

        // New packets after the reconnect show the iterator follows the new client.
        assert!(
            pre_num < post_num,
            "pre_num {} should less than post_num {}",
            pre_num,
            post_num
        );

        Ok(())
    }

    /// Reads a counter with `Acquire` ordering to pair with the callbacks'
    /// `Release` increments.
    fn load(num: &AtomicUsize) -> usize {
        num.load(Ordering::Acquire)
    }
}