1use std::{
2 sync::{Arc, Mutex, RwLock},
3 time::Duration,
4};
5
6use super::{ClientBuilder, RawClient};
7use crate::{
8 error::Result,
9 packet::{Packet, PacketId},
10 Error,
11};
12pub(crate) use crate::{event::Event, payload::Payload};
13use backoff::ExponentialBackoff;
14use backoff::{backoff::Backoff, ExponentialBackoffBuilder};
15
/// Blocking socket.io client wrapper that can replace its underlying
/// [`RawClient`] when the connection drops.
///
/// Clones share the same builder and the same connection (both are behind an
/// `Arc`), while each clone carries its own copy of the backoff state.
#[derive(Clone)]
pub struct Client {
    // Settings used to (re)connect; `set_reconnect_url` updates its address.
    builder: Arc<Mutex<ClientBuilder>>,
    // The currently active connection; replaced in place on reconnect.
    client: Arc<RwLock<RawClient>>,
    // Delay schedule used between reconnection attempts.
    backoff: ExponentialBackoff,
}
22
23impl Client {
24 pub(crate) fn new(builder: ClientBuilder) -> Result<Self> {
25 let builder_clone = builder.clone();
26 let client = builder_clone.connect_raw()?;
27 let backoff = ExponentialBackoffBuilder::new()
28 .with_initial_interval(Duration::from_millis(builder.reconnect_delay_min))
29 .with_max_interval(Duration::from_millis(builder.reconnect_delay_max))
30 .build();
31
32 let s = Self {
33 builder: Arc::new(Mutex::new(builder)),
34 client: Arc::new(RwLock::new(client)),
35 backoff,
36 };
37 s.poll_callback();
38
39 Ok(s)
40 }
41
42 pub fn set_reconnect_url<T: Into<String>>(&self, address: T) -> Result<()> {
45 self.builder.lock()?.address = address.into();
46 Ok(())
47 }
48
49 pub fn emit<E, D>(&self, event: E, data: D) -> Result<()>
75 where
76 E: Into<Event>,
77 D: Into<Payload>,
78 {
79 let client = self.client.read()?;
80 client.emit(event, data)
82 }
83
84 pub fn emit_with_ack<F, E, D>(
122 &self,
123 event: E,
124 data: D,
125 timeout: Duration,
126 callback: F,
127 ) -> Result<()>
128 where
129 F: FnMut(Payload, RawClient) + 'static + Send,
130 E: Into<Event>,
131 D: Into<Payload>,
132 {
133 let client = self.client.read()?;
134 client.emit_with_ack(event, data, timeout, callback)
136 }
137
138 pub fn disconnect(&self) -> Result<()> {
164 let client = self.client.read()?;
165 client.disconnect()
166 }
167
168 fn reconnect(&mut self) -> Result<()> {
169 let mut reconnect_attempts = 0;
170 let (reconnect, max_reconnect_attempts) = {
171 let builder = self.builder.lock()?;
172 (builder.reconnect, builder.max_reconnect_attempts)
173 };
174
175 if reconnect {
176 loop {
177 if let Some(max_reconnect_attempts) = max_reconnect_attempts {
178 reconnect_attempts += 1;
179 if reconnect_attempts > max_reconnect_attempts {
180 break;
181 }
182 }
183
184 if let Some(backoff) = self.backoff.next_backoff() {
185 std::thread::sleep(backoff);
186 }
187
188 if self.do_reconnect().is_ok() {
189 break;
190 }
191 }
192 }
193
194 Ok(())
195 }
196
197 fn do_reconnect(&self) -> Result<()> {
198 let builder = self.builder.lock()?;
199 let new_client = builder.clone().connect_raw()?;
200 let mut client = self.client.write()?;
201 *client = new_client;
202
203 Ok(())
204 }
205
206 pub(crate) fn iter(&self) -> Iter {
207 Iter {
208 socket: self.client.clone(),
209 }
210 }
211
212 fn poll_callback(&self) {
213 let mut self_clone = self.clone();
214 std::thread::spawn(move || {
216 for packet in self_clone.iter() {
221 let should_reconnect = match packet {
222 Err(Error::IncompleteResponseFromEngineIo(_)) => {
223 true
226 }
227 Ok(Packet {
228 packet_type: PacketId::Disconnect,
229 ..
230 }) => match self_clone.builder.lock() {
231 Ok(builder) => builder.reconnect_on_disconnect,
232 Err(_) => false,
233 },
234 _ => false,
235 };
236 if should_reconnect {
237 let _ = self_clone.disconnect();
238 let _ = self_clone.reconnect();
239 }
240 }
241 });
242 }
243}
244
/// Iterator over packets received on a shared [`RawClient`] connection slot.
///
/// `Client::iter` gives it the same `Arc<RwLock<RawClient>>` the client holds,
/// so each call to `next` polls whichever connection is currently installed.
pub(crate) struct Iter {
    socket: Arc<RwLock<RawClient>>,
}
248
249impl Iterator for Iter {
250 type Item = Result<Packet>;
251
252 fn next(&mut self) -> Option<Self::Item> {
253 let socket = self.socket.read();
254 match socket {
255 Ok(socket) => match socket.poll() {
256 Err(err) => Some(Err(err)),
257 Ok(Some(packet)) => Some(Ok(packet)),
258 Ok(None) => Some(Err(Error::StoppedEngineIoSocket)),
261 },
262 Err(_) => {
263 None
265 }
266 }
267 }
268}
269
// Integration tests for `Client` reconnection and iteration. They talk to the
// test servers provided by `crate::test` and rely on wall-clock sleeps, so
// statement order and timing matter.
#[cfg(test)]
mod test {
    use std::{
        sync::atomic::{AtomicUsize, Ordering},
        time::UNIX_EPOCH,
    };

    use super::*;
    use crate::error::Result;
    use crate::ClientBuilder;
    use serde_json::json;
    use serial_test::serial;
    use std::time::{Duration, SystemTime};
    use url::Url;

    // Asks the server to restart mid-session, then checks that the client
    // observes exactly one close and one extra connect (i.e. it reconnected once).
    #[test]
    #[serial(reconnect)]
    fn socket_io_reconnect_integration() -> Result<()> {
        // Event counters updated from the client callbacks.
        static CONNECT_NUM: AtomicUsize = AtomicUsize::new(0);
        static CLOSE_NUM: AtomicUsize = AtomicUsize::new(0);
        static MESSAGE_NUM: AtomicUsize = AtomicUsize::new(0);

        let url = crate::test::socket_io_restart_server();

        let socket = ClientBuilder::new(url)
            .reconnect(true)
            .max_reconnect_attempts(100)
            .reconnect_delay(100, 100)
            .on(Event::Connect, move |_, socket| {
                CONNECT_NUM.fetch_add(1, Ordering::Release);
                // Each (re)connect sends a "message"; presumably the test server
                // answers with a "message" event, counted below — confirm.
                let r = socket.emit_with_ack(
                    "message",
                    json!(""),
                    Duration::from_millis(100),
                    |_, _| {},
                );
                assert!(r.is_ok(), "should emit message success");
            })
            .on(Event::Close, move |_, _| {
                CLOSE_NUM.fetch_add(1, Ordering::Release);
            })
            .on("message", move |_, _socket| {
                MESSAGE_NUM.fetch_add(1, Ordering::Release);
            })
            .connect();

        assert!(socket.is_ok(), "should connect success");
        let socket = socket.unwrap();

        // Give the initial connect/message round trip time to complete.
        std::thread::sleep(std::time::Duration::from_millis(500));

        assert_eq!(load(&CONNECT_NUM), 1, "should connect once");
        assert_eq!(load(&MESSAGE_NUM), 1, "should receive one");
        assert_eq!(load(&CLOSE_NUM), 0, "should not close");

        // NOTE(review): the event name suggests the server restarts itself on
        // receipt, forcing the client to reconnect — confirm against the server.
        let r = socket.emit("restart_server", json!(""));
        assert!(r.is_ok(), "should emit restart success");

        // Poll for up to ~4s until the reconnect and its message have been seen.
        for _ in 0..10 {
            std::thread::sleep(std::time::Duration::from_millis(400));
            if load(&CONNECT_NUM) == 2 && load(&MESSAGE_NUM) == 2 {
                break;
            }
        }

        assert_eq!(load(&CONNECT_NUM), 2, "should connect twice");
        assert_eq!(load(&MESSAGE_NUM), 2, "should receive two messages");
        assert_eq!(load(&CLOSE_NUM), 1, "should close once");

        socket.disconnect()?;
        Ok(())
    }

    // Same flow as above, but the reconnect URL is replaced via
    // `set_reconnect_url` before the restart, so the reconnect uses new query
    // parameters.
    #[test]
    fn socket_io_reconnect_url_auth_integration() -> Result<()> {
        static CONNECT_NUM: AtomicUsize = AtomicUsize::new(0);
        static CLOSE_NUM: AtomicUsize = AtomicUsize::new(0);
        static MESSAGE_NUM: AtomicUsize = AtomicUsize::new(0);

        // Server URL with a `timestamp` query parameter set to the current time in
        // milliseconds. NOTE(review): presumably the url-auth server validates this
        // value — confirm against `crate::test`.
        fn get_url() -> Url {
            let timestamp = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_millis();
            let mut url = crate::test::socket_io_restart_url_auth_server();
            url.set_query(Some(&format!("timestamp={timestamp}")));
            url
        }

        let socket = ClientBuilder::new(get_url())
            .reconnect(true)
            .max_reconnect_attempts(100)
            .reconnect_delay(100, 100)
            .on(Event::Connect, move |_, socket| {
                CONNECT_NUM.fetch_add(1, Ordering::Release);
                let result = socket.emit_with_ack(
                    "message",
                    json!(""),
                    Duration::from_millis(100),
                    |_, _| {},
                );
                assert!(result.is_ok(), "should emit message success");
            })
            .on(Event::Close, move |_, _| {
                CLOSE_NUM.fetch_add(1, Ordering::Release);
            })
            .on("message", move |_, _| {
                MESSAGE_NUM.fetch_add(1, Ordering::Release);
            })
            .connect();

        assert!(socket.is_ok(), "should connect success");
        let socket = socket.unwrap();

        std::thread::sleep(std::time::Duration::from_millis(500));

        assert_eq!(load(&CONNECT_NUM), 1, "should connect once");
        assert_eq!(load(&MESSAGE_NUM), 1, "should receive one");
        assert_eq!(load(&CLOSE_NUM), 0, "should not close");

        // Wait so the next `get_url()` carries a different timestamp than the
        // URL used for the first connection.
        std::thread::sleep(std::time::Duration::from_secs(1));

        socket.set_reconnect_url(get_url())?;

        let result = socket.emit("restart_server", json!(""));
        assert!(result.is_ok(), "should emit restart success");

        for _ in 0..10 {
            std::thread::sleep(std::time::Duration::from_millis(400));
            if load(&CONNECT_NUM) == 2 && load(&MESSAGE_NUM) == 2 {
                break;
            }
        }

        assert_eq!(load(&CONNECT_NUM), 2, "should connect twice");
        assert_eq!(load(&MESSAGE_NUM), 2, "should receive two messages");
        assert_eq!(load(&CLOSE_NUM), 1, "should close once");

        socket.disconnect()?;
        Ok(())
    }

    // Builds a `Client` by hand, which skips `poll_callback`, so the thread below
    // is the only consumer of `iter()`. It checks that the iterator keeps yielding
    // packets after the connection is replaced by `reconnect`.
    #[test]
    fn socket_io_iterator_integration() -> Result<()> {
        let url = crate::test::socket_io_server();
        let builder = ClientBuilder::new(url);
        let builder_clone = builder.clone();

        let client = Arc::new(RwLock::new(builder_clone.connect_raw()?));
        let mut socket = Client {
            builder: Arc::new(Mutex::new(builder)),
            client,
            backoff: Default::default(),
        };
        let socket_clone = socket.clone();

        let packets: Arc<RwLock<Vec<Packet>>> = Default::default();
        let packets_clone = packets.clone();

        // Collect every successfully received packet. Errors (e.g. while the
        // socket is disconnected) are skipped, and the loop keeps polling.
        std::thread::spawn(move || {
            for packet in socket_clone.iter() {
                {
                    let mut packets = packets_clone.write().unwrap();
                    if let Ok(packet) = packet {
                        (*packets).push(packet);
                    }
                }
            }
        });

        std::thread::sleep(Duration::from_millis(100));
        let lock = packets.read().unwrap();
        let pre_num = lock.len();
        drop(lock);

        // Replace the connection; the iterator shares the same slot and should
        // start receiving from the new connection.
        let _ = socket.disconnect();
        socket.reconnect()?;

        std::thread::sleep(Duration::from_millis(100));

        let lock = packets.read().unwrap();
        let post_num = lock.len();
        drop(lock);

        assert!(
            pre_num < post_num,
            "pre_num {} should less than post_num {}",
            pre_num,
            post_num
        );

        Ok(())
    }

    // Reads a callback counter with `Acquire`, pairing with the `Release`
    // increments in the callbacks.
    fn load(num: &AtomicUsize) -> usize {
        num.load(Ordering::Acquire)
    }
}
478}