1use std::{
2 ops::{Deref, DerefMut},
3 sync::Arc,
4 time::Duration,
5};
6
7use crate::{
8 socket::Socket as InnerSocket, AckId, ClientBuilder, Error, Event, Packet, Payload, Result,
9};
10
11use backoff::{backoff::Backoff, ExponentialBackoff, ExponentialBackoffBuilder};
12use futures_util::future::BoxFuture;
13use tokio::sync::RwLock;
14use tracing::{trace, warn};
15
16#[derive(Clone)]
17pub struct Client {
18 builder: ClientBuilder,
19 socket: Arc<RwLock<InnerSocket<Socket>>>,
20 backoff: ExponentialBackoff,
21 connected: Arc<RwLock<bool>>,
22}
23
24#[derive(Clone)]
25pub struct Socket {
26 pub(crate) socket: InnerSocket<Self>,
27}
28
29impl From<InnerSocket<Socket>> for Socket {
30 fn from(socket: InnerSocket<Socket>) -> Self {
31 Self { socket }
32 }
33}
34
35impl Client {
36 #[inline]
69 pub async fn emit<E, D>(&self, event: E, data: D) -> Result<()>
70 where
71 E: Into<Event>,
72 D: Into<Payload>,
73 {
74 let socket = self.socket.read().await;
75 socket.emit(event, data).await
76 }
77
78 #[inline]
124 pub async fn emit_with_ack<F, E, D>(
125 &self,
126 event: E,
127 data: D,
128 timeout: Duration,
129 callback: F,
130 ) -> Result<()>
131 where
132 F: for<'a> std::ops::FnMut(
133 Option<Payload>,
134 Socket,
135 Option<AckId>,
136 ) -> BoxFuture<'static, ()>
137 + 'static
138 + Send
139 + Sync,
140 E: Into<Event>,
141 D: Into<Payload>,
142 {
143 let socket = self.socket.read().await;
144 socket.emit_with_ack(event, data, timeout, callback).await
145 }
146
147 pub async fn ack(&self, id: usize, data: Payload) -> Result<()> {
148 let socket = self.socket.read().await;
149 socket.ack(id, data).await
150 }
151
152 pub async fn disconnect(&self) -> Result<()> {
155 trace!("client disconnect");
156 let mut connected = self.connected.write().await;
157 if !*connected {
158 return Ok(());
159 }
160 *connected = false;
161 self.disconnect_socket().await
162 }
163
164 async fn disconnect_socket(&self) -> Result<()> {
165 let socket = self.socket.read().await;
166 socket.disconnect().await
167 }
168
169 pub(crate) async fn new(builder: ClientBuilder) -> Result<Self> {
170 let b = builder.clone();
171 let socket = b.connect_socket().await?;
172 let connected = Arc::new(RwLock::new(true));
173 let backoff = ExponentialBackoffBuilder::new()
174 .with_initial_interval(Duration::from_millis(builder.reconnect_delay_min))
175 .with_max_interval(Duration::from_millis(builder.reconnect_delay_max))
176 .build();
177
178 let s = Self {
179 builder,
180 socket: Arc::new(RwLock::new(socket)),
181 backoff,
182 connected,
183 };
184
185 Ok(s)
186 }
187
188 async fn reconnect(&mut self) {
189 let mut reconnect_attempts = 0;
190 if self.builder.reconnect {
191 loop {
192 if let Some(max_reconnect_attempts) = self.builder.max_reconnect_attempts {
193 if reconnect_attempts > max_reconnect_attempts {
194 break;
195 }
196 }
197 reconnect_attempts += 1;
198
199 if let Some(backoff) = self.backoff.next_backoff() {
200 trace!("reconnect backoff {:?}", backoff);
201 tokio::time::sleep(backoff).await;
202 }
203
204 trace!("client reconnect {}", reconnect_attempts);
205 if self.do_reconnect().await.is_ok() {
206 break;
207 }
208 }
209 }
210 }
211
212 async fn do_reconnect(&self) -> Result<()> {
213 let new_socket = self.builder.clone().connect_socket().await?;
214 let mut socket = self.socket.write().await;
215 *socket = new_socket;
216 Ok(())
217 }
218
219 pub(crate) fn poll_callback(&self) {
220 let mut self_clone = self.clone();
221 tokio::spawn(async move {
223 trace!("start poll_callback ");
224 #[allow(clippy::for_loops_over_fallibles)]
229 loop {
230 let packet = self_clone.poll_packet().await;
231 trace!("poll_callback packet {:?}", packet);
232 if let Some(Err(Error::IncompleteResponseFromEngineIo(_))) = packet {
233 let _ = self_clone.disconnect_socket().await;
235 self_clone.reconnect().await;
236 }
237 if !*self_clone.connected.read().await {
238 break;
239 }
240 }
241 warn!("poll_callback exist");
242 });
243 }
244
245 pub(crate) async fn poll_packet(&self) -> Option<Result<Packet>> {
246 let socket = self.socket.read().await;
247 socket.poll_packet().await
248 }
249}
250
251impl Deref for Socket {
252 type Target = InnerSocket<Self>;
253
254 fn deref(&self) -> &Self::Target {
255 &self.socket
256 }
257}
258
259impl DerefMut for Socket {
260 fn deref_mut(&mut self) -> &mut Self::Target {
261 &mut self.socket
262 }
263}
264
#[cfg(test)]
mod test {
    use std::time::Duration;

    use super::*;
    use crate::{
        test::socket_io_server, AckId, Client, ClientBuilder, Event, Packet, PacketType, Payload,
        Result, ServerBuilder, ServerSocket,
    };

    use bytes::Bytes;
    use futures_util::FutureExt;
    use serde_json::json;
    use tokio::{sync::mpsc::unbounded_channel, time::sleep};
    use tracing::info;

    /// Starts the in-process test server once, then runs the three client
    /// scenarios one after another against it.
    #[tokio::test(flavor = "multi_thread", worker_threads = 3)]
    async fn test_client() -> Result<()> {
        setup_server();

        socket_io_integration().await?;
        socket_io_builder_integration().await?;
        socket_io_builder_integration_iterator().await
    }

    /// Plain emit, emit-with-ack and disconnect on the default namespace.
    async fn socket_io_integration() -> Result<()> {
        let client = ClientBuilder::new(socket_io_server())
            .on("test", |msg, _, _| {
                async move {
                    match msg {
                        Some(Payload::Json(data)) => info!("Received string: {:?}", data),
                        Some(Payload::Binary(bin)) => info!("Received binary data: {:#?}", bin),
                        Some(Payload::Multi(multi)) => info!("Received multi {:?}", multi),
                        _ => {}
                    }
                }
                .boxed()
            })
            .connect()
            .await?;

        let token = json!({"token": 123_i32});

        let emitted = client.emit("test", Payload::Json(token.clone())).await;
        assert!(emitted.is_ok());

        let acked = client
            .emit_with_ack(
                "test",
                Payload::Json(token),
                Duration::from_secs(1),
                |message: Option<Payload>, socket: Socket, _| {
                    async move {
                        let echoed = socket
                            .emit("test", Payload::Json(json!({"got ack": true})))
                            .await;
                        assert!(echoed.is_ok());

                        info!("Yehaa! My ack got acked?");
                        if let Some(Payload::Json(data)) = message {
                            info!("Received string Ack");
                            info!("Ack data: {:?}", data);
                        }
                    }
                    .boxed()
                },
            )
            .await;
        assert!(acked.is_ok());

        // Leave time for the ack round trip before tearing the connection down.
        sleep(Duration::from_secs(2)).await;

        let closed = client.disconnect().await;
        assert!(closed.is_ok());

        Ok(())
    }

    /// Builder options (namespace, opening header, several handlers) followed
    /// by JSON and binary emits and an emit-with-ack on `/admin`.
    async fn socket_io_builder_integration() -> Result<()> {
        let client = ClientBuilder::new(socket_io_server())
            .namespace("/admin")
            .opening_header("accept-encoding", "application/json")
            .on("test", |str, _, _| {
                async move { info!("Received: {:#?}", str) }.boxed()
            })
            .on("message", |payload, _, _| {
                async move { info!("{:#?}", payload) }.boxed()
            })
            .connect()
            .await?;

        let json_emit = client.emit("message", json!("Hello World")).await;
        assert!(json_emit.is_ok());

        let binary_emit = client.emit("binary", Bytes::from_static(&[46, 88])).await;
        assert!(binary_emit.is_ok());

        let ack_emit = client
            .emit_with_ack(
                "binary",
                json!("pls ack"),
                Duration::from_secs(1),
                |payload, _, _| {
                    async move {
                        info!("Yehaa the ack got acked");
                        info!("With data: {:#?}", payload);
                    }
                    .boxed()
                },
            )
            .await;
        assert!(ack_emit.is_ok());

        sleep(Duration::from_secs(2)).await;

        Ok(())
    }

    /// Connects a raw `Client` (via `connect_client`) on `/admin` and hands it
    /// to the packet-level checks.
    async fn socket_io_builder_integration_iterator() -> Result<()> {
        let client = ClientBuilder::new(socket_io_server())
            .namespace("/admin")
            .opening_header("accept-encoding", "application/json")
            .on("test", |str, _, _| {
                async move { info!("Received: {:#?}", str) }.boxed()
            })
            .on("message", |payload, _, _| {
                async move { info!("Received binary {:#?}", payload) }.boxed()
            })
            .connect_client()
            .await?;

        test_socketio_socket(client, "/admin".to_owned()).await
    }

    /// Polls the packets the server sends on connect, checks them, and then
    /// verifies that multi-payload acks come back with both parts.
    async fn test_socketio_socket(socket: Client, nsp: String) -> Result<()> {
        // First packet is only required to arrive and decode.
        let _: Packet = socket.poll_packet().await.unwrap()?;

        let text_event = socket.poll_packet().await.unwrap()?;
        let expected_text_event = Packet::new(
            PacketType::Event,
            nsp.clone(),
            Some(json!(["test", "Hello from the test event!"])),
            None,
            0,
            None,
        );
        assert_eq!(text_event, expected_text_event);

        let binary_event = socket.poll_packet().await.unwrap()?;
        let expected_binary_event = Packet::new(
            PacketType::BinaryEvent,
            nsp.clone(),
            Some(json!(["test", {"_placeholder": true, "num": 0}])),
            None,
            1,
            Some(vec![Bytes::from_static(&[1, 2, 3])]),
        );
        assert_eq!(binary_event, expected_binary_event);

        let multi_event = socket.poll_packet().await.unwrap()?;
        if let Some(serde_json::Value::Array(items)) = multi_event.data {
            assert_eq!(items.len(), 5);
        } else {
            panic!("invlaid emit multi payload");
        }

        // Keep draining incoming packets in the background so that ack
        // responses are processed while this task waits on the channels below.
        let poller = socket.clone();
        tokio::spawn(async move {
            loop {
                let _ = poller.poll_packet().await;
            }
        });

        // Ack carrying two JSON parts.
        let (json_tx, mut json_rx) = unbounded_channel();
        let json_ack_cb = move |message: Option<Payload>, _, _| {
            let report = json_tx.clone();
            async move {
                let has_two_parts =
                    matches!(&message, Some(Payload::Multi(parts)) if parts.len() == 2);
                let _ = report.send(has_two_parts);
            }
            .boxed()
        };

        let json_ack_emit = socket
            .emit_with_ack(
                "client_ack",
                Payload::Multi(vec![json!(1).into(), json!(2).into()]),
                Duration::from_secs(10),
                json_ack_cb,
            )
            .await;
        assert!(json_ack_emit.is_ok());

        if !matches!(json_rx.recv().await, Some(true)) {
            panic!("ACK callback invlaid");
        }

        // Ack carrying one binary and one JSON part.
        let (binary_tx, mut binary_rx) = unbounded_channel();
        let binary_ack_cb = move |message: Option<Payload>, _, _| {
            let report = binary_tx.clone();
            async move {
                let has_two_parts =
                    matches!(&message, Some(Payload::Multi(parts)) if parts.len() == 2);
                let _ = report.send(has_two_parts);
            }
            .boxed()
        };

        let binary_ack_emit = socket
            .emit_with_ack(
                "client_ack",
                Payload::Multi(vec![Bytes::from_static(b"1").into(), json!(2).into()]),
                Duration::from_secs(10),
                binary_ack_cb,
            )
            .await;
        assert!(binary_ack_emit.is_ok());

        if !matches!(binary_rx.recv().await, Some(true)) {
            panic!("BINARY_ACK callback invlaid");
        }

        Ok(())
    }

    /// Builds the test server with its `/admin` handlers and spawns it on the
    /// current runtime.
    fn setup_server() {
        // "echo": reply with an empty "echo" event, ignoring send failures.
        let on_echo =
            move |_payload: Option<Payload>, socket: ServerSocket, _need_ack: Option<AckId>| {
                async move {
                    let _ = socket.emit("echo", json!("")).await;
                }
                .boxed()
            };

        // "client_ack": ack with the received payload, or "ackback" if none.
        let on_client_ack =
            move |payload: Option<Payload>, socket: ServerSocket, need_ack: Option<AckId>| {
                async move {
                    if let Some(ack_id) = need_ack {
                        let reply = payload.unwrap_or_else(|| json!("ackback").into());
                        socket.ack(ack_id, reply).await.expect("success");
                    }
                }
                .boxed()
            };

        // Runs when the client acks "server_ask_ack". It captures nothing, so
        // it can be copied into every invocation of `on_server_ack` below.
        let on_ack_received =
            move |_payload: Option<Payload>, socket: ServerSocket, _need_ack: Option<AckId>| {
                async move {
                    socket
                        .emit("server_recv_ack", json!(""))
                        .await
                        .expect("success");
                }
                .boxed()
            };

        // "server_ack": ask the client to ack a "server_ask_ack" event.
        let on_server_ack = move |message: Option<Payload>, socket: ServerSocket, _| {
            async move {
                let payload = message.unwrap_or_else(|| json!({"ack_back": true}).into());
                socket
                    .emit_with_ack(
                        "server_ask_ack",
                        payload,
                        Duration::from_millis(400),
                        on_ack_received,
                    )
                    .await
                    .expect("success");
            }
            .boxed()
        };

        // On connect: send a text, a binary and a multi-part "test" event.
        let on_connect = move |_payload: Option<Payload>, socket: ServerSocket, _| {
            async move {
                socket
                    .emit("test", json!("Hello from the test event!"))
                    .await
                    .expect("success");

                socket
                    .emit("test", Payload::Binary(Bytes::from_static(&[1, 2, 3])))
                    .await
                    .expect("success");

                let multi = Payload::Multi(vec![
                    json!(1).into(),
                    json!("2").into(),
                    Bytes::from_static(&[3]).into(),
                    Bytes::from_static(b"4").into(),
                ]);
                socket.emit("test", multi).await.expect("success");
            }
            .boxed()
        };

        let port = socket_io_server().port().unwrap();
        let server = ServerBuilder::new(port)
            .on("/admin", "echo", on_echo)
            .on("/admin", "client_ack", on_client_ack)
            .on("/admin", "server_ack", on_server_ack)
            .on("/admin", Event::Connect, on_connect)
            .build();

        tokio::spawn(async move { server.serve().await });
    }
}