1use crate::{
2 ack::AckId, callback::Callback, packet::PacketType, server::Client as ServerSocket,
3 socket::RawSocket, Error, Event, NameSpace, Payload,
4};
5use dashmap::DashMap;
6use engineio_rs::{Event as EngineEvent, Server as EngineServer, Sid as EngineSid};
7use futures_util::future::BoxFuture;
8use serde_json::json;
9use std::{
10 collections::{HashMap, HashSet},
11 sync::{
12 atomic::{AtomicUsize, Ordering},
13 Arc,
14 },
15 time::Duration,
16};
17use tracing::{error, trace, warn};
18
19const CONNECT_TIMEOUT: u64 = 5;
21
22type Sid = Arc<String>;
23type Room = String;
24type Rooms = DashMap<NameSpace, HashMap<Room, HashSet<Sid>>>;
25type On = DashMap<Event, Callback<ServerSocket>>;
26
27pub struct Server {
28 pub(crate) on: DashMap<NameSpace, Arc<On>>,
29 pub(crate) rooms: Rooms,
30 pub(crate) clients: DashMap<EngineSid, DashMap<Sid, HashMap<NameSpace, ServerSocket>>>,
31 pub(crate) engine_server: EngineServer,
32 pub(crate) sid_generator: SidGenerator,
33}
34
35impl Server {
36 pub fn client_count(self: &Arc<Self>) -> usize {
37 self.clients.iter().map(|i| i.iter().count()).sum()
38 }
39
40 #[allow(dead_code)]
41 pub async fn serve(self: Arc<Self>) {
42 self.recv_event();
43 self.engine_server.serve().await
44 }
45
46 pub async fn emit_to<E, D>(self: &Arc<Self>, nsp: &str, rooms: Vec<&str>, event: E, data: D)
47 where
48 E: Into<Event>,
49 D: Into<Payload>,
50 {
51 let event = event.into();
52 let payload = data.into();
53
54 let sids_to_emit = self.sids_to_emit(nsp, rooms).await;
55
56 for sid in sids_to_emit {
57 if let Some(client) = self.client(&sid, nsp).await {
58 let event = event.clone();
59 let payload = payload.clone();
60
61 tokio::spawn(async move {
62 let r = client.emit(event, payload).await;
63 trace!("server emit_to: {}, status: {:?}", sid, r);
64 if r.is_err() {
65 error!("emit_to {} failed {:?}", sid, r);
66 }
67 });
68 }
69 }
70 }
71
72 pub async fn emit_to_with_ack<F, E, D>(
73 &self,
74 nsp: &str,
75 rooms: Vec<&str>,
76 event: E,
77 data: D,
78 timeout: Duration,
79 callback: F,
80 ) where
81 F: for<'a> std::ops::FnMut(
82 Option<Payload>,
83 ServerSocket,
84 Option<AckId>,
85 ) -> BoxFuture<'static, ()>
86 + 'static
87 + Send
88 + Sync
89 + Clone,
90 E: Into<Event>,
91 D: Into<Payload>,
92 {
93 let event = event.into();
94 let payload = data.into();
95
96 for sid in self.sids_to_emit(nsp, rooms).await {
97 if let Some(client) = self.client(&sid, nsp).await {
98 let event = event.clone();
99 let payload = payload.clone();
100
101 let callback_clone = callback.clone();
102
103 tokio::spawn(async move {
104 let r = client
105 .emit_with_ack(
106 event.clone(),
107 payload.clone(),
108 timeout,
109 callback_clone.clone(),
110 )
111 .await;
112 if r.is_err() {
113 error!("emit_with_ack to {} {:?}", sid, r);
114 }
115 });
116 }
117 }
118 }
119
120 async fn sids_to_emit(&self, nsp: &str, rooms: Vec<&str>) -> HashSet<Sid> {
121 let clients = &self.rooms;
122 let mut sids_to_emit = HashSet::new();
123 if let Some(room_clients) = clients.get(nsp) {
124 for room_name in rooms {
125 match room_clients.get(room_name) {
126 Some(room) => {
127 for sid in room {
128 sids_to_emit.insert(sid.clone());
129 }
130 }
131 None => {
133 let _ = sids_to_emit.insert(Arc::new(room_name.to_owned()));
134 }
135 };
136 }
137 }
138 sids_to_emit
139 }
140
141 pub(crate) fn recv_event(self: &Arc<Self>) {
142 let event_rx = self.engine_server.event_rx();
143 let server = self.to_owned();
144 tokio::spawn(async move {
145 let mut event_rx = event_rx.lock().await;
146
147 while let Some(event) = event_rx.recv().await {
148 trace!("server recv_event: {:?}", event);
149 match event {
150 EngineEvent::OnOpen(esid) => server.create_client(esid).await,
151 EngineEvent::OnClose(esid) => server.drop_client(&esid).await,
152 EngineEvent::OnPacket(_esid, _packet) => {
153 }
155 _ => {}
156 };
157 }
158 });
159 }
160
161 pub(crate) async fn client(&self, sid: &Sid, nsp: &str) -> Option<ServerSocket> {
162 let esid = &SidGenerator::decode(sid)?;
163 self.clients.get(esid)?.get(sid)?.get(nsp).cloned()
164 }
165
166 pub(crate) async fn join<T: Into<String>>(
167 self: &Arc<Self>,
168 nsp: &str,
169 rooms: Vec<T>,
170 sid: Sid,
171 ) {
172 for room_name in rooms {
173 let room_name = room_name.into();
174 match self.rooms.get_mut(nsp) {
175 None => {
176 let mut room_sids = HashSet::new();
177 room_sids.insert(sid.clone());
178 let mut rooms = HashMap::new();
179 rooms.insert(room_name, room_sids);
180 self.rooms.insert(nsp.to_owned(), rooms);
181 }
182 Some(mut rooms) => {
183 if let Some(room_sids) = rooms.get_mut(&room_name) {
184 let _ = room_sids.insert(sid.clone());
185 } else {
186 let mut room_sids = HashSet::new();
187 room_sids.insert(sid.clone());
188 rooms.insert(room_name, room_sids);
189 }
190 }
191 };
192 }
193 }
194
195 pub(crate) async fn leave(self: &Arc<Self>, nsp: &str, rooms: Vec<&str>, sid: &Sid) {
196 for room_name in rooms {
197 if let Some(mut nsp_rooms) = self.rooms.get_mut(nsp) {
198 if let Some(room_sids) = nsp_rooms.get_mut(room_name) {
199 room_sids.remove(sid);
200 }
201 };
202 }
203 }
204
205 async fn create_client(self: &Arc<Self>, esid: EngineSid) {
206 if let Some(engine_socket) = self.engine_server.socket(&esid).await {
207 let socket = RawSocket::server_end(engine_socket);
208
209 match self.client_info(&esid).await {
211 Some((sid, nsp)) => self.insert_clients(socket, nsp, esid, sid, false).await,
212 None => self.handle_connect(socket, esid).await,
213 };
214 }
215 }
216
217 async fn client_info(&self, esid: &EngineSid) -> Option<(Sid, String)> {
221 let sid_map = self.clients.get(esid)?;
222 let entry = sid_map.iter().next()?;
223 let (sid, nsp_map) = entry.pair();
224 let (nsp, _) = nsp_map.iter().next()?;
225
226 Some((sid.to_owned(), nsp.to_owned()))
227 }
228
229 async fn handle_connect(self: &Arc<Self>, socket: RawSocket, esid: EngineSid) {
230 trace!("handle_connect: {:?}", esid);
231 let slf = self.clone();
232 tokio::spawn(async move {
233 if tokio::time::timeout(
234 Duration::from_secs(CONNECT_TIMEOUT),
235 slf.do_handle_connect(socket, esid.clone()),
236 )
237 .await
238 .is_err()
239 {
240 warn!("handle_connect timeout, {:?} dropped", esid);
241 slf.drop_client(&esid).await;
242 }
243 });
244 }
245
246 async fn do_handle_connect(self: &Arc<Self>, socket: RawSocket, esid: EngineSid) {
247 let sid = self.sid_generator.generate(&esid);
248 while let Some(Ok(packet)) = socket.poll_packet().await {
249 if packet.ptype == PacketType::Connect {
250 let nsp = packet.nsp.clone();
251 self.insert_clients(socket, nsp, esid, sid, true).await;
252 break;
253 } else {
254 continue;
255 }
256 }
257 }
258
259 async fn insert_clients(
260 self: &Arc<Self>,
261 socket: RawSocket,
262 nsp: String,
263 esid: EngineSid,
264 sid: Sid,
265 handshake: bool,
266 ) {
267 if let Some(on) = self.on.get(&nsp) {
268 let client = ServerSocket::new(
269 socket,
270 nsp.clone(),
271 sid.clone(),
272 on.to_owned(),
273 self.clone(),
274 );
275
276 client.connect_callback().await;
277
278 poll(client.clone());
279
280 if handshake {
281 let _ = client.handshake(json!({ "sid": sid.clone() })).await;
282 }
283
284 let sid_map = self.clients.entry(esid).or_default();
285
286 let mut nsp_map = sid_map.entry(sid).or_default();
287 nsp_map.insert(nsp, client);
288 } else {
289 warn!("unkown nsp {} from client", nsp);
290 }
291 }
292
293 async fn drop_client(self: &Arc<Self>, esid: &EngineSid) {
294 self.engine_server.close_socket(esid).await;
295
296 if self.clients.remove(esid).is_some() {
297 }
299
300 self.rooms.iter_mut().for_each(|mut nsp_clients| {
302 for room_clients in nsp_clients.values_mut() {
303 room_clients.retain(|sid| SidGenerator::decode(sid).as_ref() != Some(esid))
304 }
305 });
306 }
307}
308
/// Allocates socket.io sids that embed the engine.io sid they belong to.
///
/// Every sid produced by one generator takes a distinct value of `seq`. Sids therefore
/// stay unique even when several of them are created for the same engine session.
pub(crate) struct SidGenerator {
    /// Sequence number handed out by the next `generate` call.
    seq: AtomicUsize,
}

impl Default for SidGenerator {
    /// Starts the sequence at zero.
    fn default() -> Self {
        Self {
            seq: AtomicUsize::new(0),
        }
    }
}
313
314impl SidGenerator {
315 pub fn generate(&self, engine_sid: &EngineSid) -> Sid {
316 let seq = self.seq.fetch_add(1, Ordering::SeqCst);
317 Arc::new(base64::encode(format!("{}-{}", engine_sid, seq)))
318 }
319
320 pub fn decode(sid: &Sid) -> Option<EngineSid> {
321 let sid_vec = base64::decode(sid.as_bytes()).ok()?;
322 let esid_sid = std::str::from_utf8(&sid_vec).ok()?;
323 let tokens: Vec<&str> = esid_sid.split('-').collect();
324 Some(Arc::new(tokens[0].to_owned()))
325 }
326}
327
328fn poll(socket: ServerSocket) {
329 tokio::runtime::Handle::current().spawn(async move {
330 loop {
331 let next = socket.poll_packet().await;
336 match next {
337 Some(e @ Err(Error::IncompleteResponseFromEngineIo(_))) => {
338 trace!("Network error occurred: {:?}", e.err());
339 }
340 None => break,
341 _ => {}
342 }
343 }
344 });
345}
346
#[cfg(test)]
mod test {
    use std::{
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        time::Duration,
    };

    use crate::{
        client::ClientBuilder, client::Socket, server::client::Client as ServerClient,
        test::rust_socket_io_server, AckId, Event, Payload, Server, ServerBuilder,
    };

    use super::SidGenerator;
    use futures_util::FutureExt;
    use serde_json::json;
    use tracing::info;

    /// Upper bound on how long a test waits for an asynchronous side effect to show up.
    const WAIT_TIMEOUT: Duration = Duration::from_secs(2);

    /// Polls `cond` every 10 ms until it returns `true` or `timeout` elapses, and returns
    /// whether `cond` was observed to hold.
    ///
    /// Used instead of a fixed sleep, which is slower than necessary when the event
    /// arrives early and flaky when it arrives late.
    async fn wait_until(mut cond: impl FnMut() -> bool, timeout: Duration) -> bool {
        let deadline = tokio::time::Instant::now() + timeout;
        loop {
            if cond() {
                return true;
            }
            if tokio::time::Instant::now() >= deadline {
                return false;
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    }

    #[test]
    fn test_sid_generator() {
        let generator = SidGenerator::default();
        let engine_sid = Arc::new("engine_sid".to_owned());
        let sid = generator.generate(&engine_sid);

        assert_eq!(SidGenerator::decode(&sid), Some(engine_sid));
    }

    #[tokio::test]
    async fn test_server() {
        let server = setup();
        // `#[tokio::test]` uses a current-thread runtime. The server task spawned in
        // `setup` only starts once this task yields, so pause briefly to let it bind
        // its listener before the first client connects.
        tokio::time::sleep(Duration::from_millis(100)).await;

        test_emit().await;
        test_client_count(&server).await;
        test_client_ask_ack().await;
        test_server_ask_ack().await;
    }

    /// The server's "echo" handler emits "echo" back through a room; the client must
    /// receive it.
    async fn test_emit() {
        let is_recv = Arc::new(AtomicBool::default());
        let is_recv_clone = Arc::clone(&is_recv);

        let callback = move |_: Option<Payload>, _: Socket, _: Option<AckId>| {
            let is_recv = is_recv_clone.clone();
            async move {
                is_recv.store(true, Ordering::SeqCst);
            }
            .boxed()
        };

        let url = rust_socket_io_server();
        let socket = ClientBuilder::new(url)
            .namespace("/admin")
            .on("echo", callback)
            .on(Event::Connect, move |_payload, socket, _| {
                async move {
                    socket.emit("echo", json!("data")).await.expect("success");
                }
                .boxed()
            })
            .connect()
            .await;

        assert!(socket.is_ok());

        assert!(wait_until(|| is_recv.load(Ordering::SeqCst), WAIT_TIMEOUT).await);
    }

    /// Three simultaneously connected clients are reported by `client_count`.
    async fn test_client_count(server: &Arc<Server>) {
        let url = rust_socket_io_server();

        let socket = ClientBuilder::new(url.clone())
            .namespace("/admin")
            .on(Event::Connect, move |_payload, socket, _| {
                async move {
                    socket.emit("echo", json!("data")).await.expect("success");
                }
                .boxed()
            })
            .connect()
            .await;

        let socket2 = ClientBuilder::new(url.clone())
            .namespace("/admin")
            .on(Event::Connect, move |_payload, socket, _| {
                async move {
                    socket.emit("echo", json!("data")).await.expect("success");
                }
                .boxed()
            })
            .connect()
            .await;

        let socket3 = ClientBuilder::new(url)
            .namespace("/admin")
            .on(Event::Connect, move |_payload, socket, _| {
                async move {
                    socket.emit("echo", json!("data")).await.expect("success");
                }
                .boxed()
            })
            .connect()
            .await;

        assert!(socket.is_ok());
        assert!(socket2.is_ok());
        assert!(socket3.is_ok());

        // Server-side registration and teardown of earlier clients are asynchronous,
        // so wait for the count to settle before asserting it.
        wait_until(|| server.client_count() == 3, WAIT_TIMEOUT).await;
        assert_eq!(server.client_count(), 3);
    }

    /// A client `emit_with_ack` gets its ack back from the server's "client_ack" handler.
    async fn test_client_ask_ack() {
        let is_client_ack = Arc::new(AtomicBool::default());
        let is_client_ack_clone = Arc::clone(&is_client_ack);

        let client_ack_callback =
            move |_payload: Option<Payload>, _socket: Socket, _need_ack: Option<AckId>| {
                let is_client_ack = is_client_ack_clone.clone();
                async move {
                    is_client_ack.store(true, Ordering::SeqCst);
                }
                .boxed()
            };

        let url = rust_socket_io_server();
        let socket = ClientBuilder::new(url)
            .namespace("/admin")
            .on(Event::Connect, move |_payload, socket, _| {
                let client_ack_callback = client_ack_callback.clone();
                async move {
                    socket
                        .emit_with_ack(
                            "client_ack",
                            json!("data"),
                            Duration::from_millis(200),
                            client_ack_callback,
                        )
                        .await
                        .expect("success");
                }
                .boxed()
            })
            .connect()
            .await;

        assert!(socket.is_ok());

        assert!(wait_until(|| is_client_ack.load(Ordering::SeqCst), WAIT_TIMEOUT).await);
    }

    /// The server asks the client for an ack; the client acks, and the server's ack
    /// callback then emits "server_recv_ack" back to the client.
    async fn test_server_ask_ack() {
        let is_server_ask_ack = Arc::new(AtomicBool::default());
        let is_server_recv_ack = Arc::new(AtomicBool::default());
        let is_server_ask_ack_clone = Arc::clone(&is_server_ask_ack);
        let is_server_recv_ack_clone = Arc::clone(&is_server_recv_ack);

        let server_ask_ack =
            move |_payload: Option<Payload>, socket: Socket, need_ack: Option<AckId>| {
                let is_server_ask_ack = is_server_ask_ack_clone.clone();
                async move {
                    assert!(need_ack.is_some());
                    if let Some(ack_id) = need_ack {
                        socket.ack(ack_id, json!("")).await.expect("success");
                        is_server_ask_ack.store(true, Ordering::SeqCst);
                    }
                }
                .boxed()
            };

        let server_recv_ack =
            move |_payload: Option<Payload>, _socket: Socket, _need_ack: Option<AckId>| {
                let is_server_recv_ack = is_server_recv_ack_clone.clone();
                async move {
                    is_server_recv_ack.store(true, Ordering::SeqCst);
                }
                .boxed()
            };

        let url = rust_socket_io_server();
        let socket = ClientBuilder::new(url)
            .namespace("/admin")
            .on("server_ask_ack", server_ask_ack)
            .on("server_recv_ack", server_recv_ack)
            .on(Event::Connect, move |_payload, socket, _| {
                async move {
                    socket
                        .emit("trigger_server_ack", json!("data"))
                        .await
                        .expect("success");
                }
                .boxed()
            })
            .connect()
            .await;

        assert!(socket.is_ok());

        wait_until(
            || {
                is_server_ask_ack.load(Ordering::SeqCst)
                    && is_server_recv_ack.load(Ordering::SeqCst)
            },
            WAIT_TIMEOUT,
        )
        .await;
        assert!(is_server_ask_ack.load(Ordering::SeqCst));
        assert!(is_server_recv_ack.load(Ordering::SeqCst));
    }

    /// Builds a server with the "/admin" handlers the tests rely on and spawns it.
    fn setup() -> Arc<crate::Server> {
        let echo_callback =
            move |_payload: Option<Payload>, socket: ServerClient, _need_ack: Option<AckId>| {
                async move {
                    info!("server echo callback");
                    socket.join(vec!["room 1"]).await;
                    socket.emit_to(vec!["room 1"], "echo", json!("")).await;
                    socket.leave(vec!["room 1"]).await;
                    info!("server echo callback done");
                }
                .boxed()
            };

        let client_ack =
            move |_payload: Option<Payload>, socket: ServerClient, need_ack: Option<AckId>| {
                async move {
                    if let Some(ack_id) = need_ack {
                        socket
                            .ack(ack_id, json!("ack to client"))
                            .await
                            .expect("success");
                    }
                }
                .boxed()
            };

        let server_recv_ack =
            move |_payload: Option<Payload>, socket: ServerClient, _need_ack: Option<AckId>| {
                async move {
                    socket
                        .emit("server_recv_ack", json!(""))
                        .await
                        .expect("success");
                }
                .boxed()
            };

        let trigger_ack = move |_message: Option<Payload>, socket: ServerClient, _| {
            async move {
                socket.join(vec!["room 2"]).await;
                socket
                    .emit_to_with_ack(
                        vec!["room 2"],
                        "server_ask_ack",
                        json!(true),
                        Duration::from_millis(400),
                        server_recv_ack,
                    )
                    .await;
                socket.leave(vec!["room 2"]).await;
            }
            .boxed()
        };

        let url = rust_socket_io_server();
        let server = ServerBuilder::new(url.port().unwrap())
            .on("/admin", "echo", echo_callback)
            .on("/admin", "client_ack", client_ack)
            .on("/admin", "trigger_server_ack", trigger_ack)
            .build();

        let server_clone = server.clone();

        tokio::spawn(async move { server.serve().await });

        server_clone
    }
}