// socketio_rs/server/server.rs

1use crate::{
2    ack::AckId, callback::Callback, packet::PacketType, server::Client as ServerSocket,
3    socket::RawSocket, Error, Event, NameSpace, Payload,
4};
5use dashmap::DashMap;
6use engineio_rs::{Event as EngineEvent, Server as EngineServer, Sid as EngineSid};
7use futures_util::future::BoxFuture;
8use serde_json::json;
9use std::{
10    collections::{HashMap, HashSet},
11    sync::{
12        atomic::{AtomicUsize, Ordering},
13        Arc,
14    },
15    time::Duration,
16};
17use tracing::{error, trace, warn};
18
19// TODO: read from config
20const CONNECT_TIMEOUT: u64 = 5;
21
22type Sid = Arc<String>;
23type Room = String;
24type Rooms = DashMap<NameSpace, HashMap<Room, HashSet<Sid>>>;
25type On = DashMap<Event, Callback<ServerSocket>>;
26
27pub struct Server {
28    pub(crate) on: DashMap<NameSpace, Arc<On>>,
29    pub(crate) rooms: Rooms,
30    pub(crate) clients: DashMap<EngineSid, DashMap<Sid, HashMap<NameSpace, ServerSocket>>>,
31    pub(crate) engine_server: EngineServer,
32    pub(crate) sid_generator: SidGenerator,
33}
34
35impl Server {
36    pub fn client_count(self: &Arc<Self>) -> usize {
37        self.clients.iter().map(|i| i.iter().count()).sum()
38    }
39
40    #[allow(dead_code)]
41    pub async fn serve(self: Arc<Self>) {
42        self.recv_event();
43        self.engine_server.serve().await
44    }
45
46    pub async fn emit_to<E, D>(self: &Arc<Self>, nsp: &str, rooms: Vec<&str>, event: E, data: D)
47    where
48        E: Into<Event>,
49        D: Into<Payload>,
50    {
51        let event = event.into();
52        let payload = data.into();
53
54        let sids_to_emit = self.sids_to_emit(nsp, rooms).await;
55
56        for sid in sids_to_emit {
57            if let Some(client) = self.client(&sid, nsp).await {
58                let event = event.clone();
59                let payload = payload.clone();
60
61                tokio::spawn(async move {
62                    let r = client.emit(event, payload).await;
63                    trace!("server emit_to: {}, status: {:?}", sid, r);
64                    if r.is_err() {
65                        error!("emit_to {} failed {:?}", sid, r);
66                    }
67                });
68            }
69        }
70    }
71
72    pub async fn emit_to_with_ack<F, E, D>(
73        &self,
74        nsp: &str,
75        rooms: Vec<&str>,
76        event: E,
77        data: D,
78        timeout: Duration,
79        callback: F,
80    ) where
81        F: for<'a> std::ops::FnMut(
82                Option<Payload>,
83                ServerSocket,
84                Option<AckId>,
85            ) -> BoxFuture<'static, ()>
86            + 'static
87            + Send
88            + Sync
89            + Clone,
90        E: Into<Event>,
91        D: Into<Payload>,
92    {
93        let event = event.into();
94        let payload = data.into();
95
96        for sid in self.sids_to_emit(nsp, rooms).await {
97            if let Some(client) = self.client(&sid, nsp).await {
98                let event = event.clone();
99                let payload = payload.clone();
100
101                let callback_clone = callback.clone();
102
103                tokio::spawn(async move {
104                    let r = client
105                        .emit_with_ack(
106                            event.clone(),
107                            payload.clone(),
108                            timeout,
109                            callback_clone.clone(),
110                        )
111                        .await;
112                    if r.is_err() {
113                        error!("emit_with_ack to {} {:?}", sid, r);
114                    }
115                });
116            }
117        }
118    }
119
120    async fn sids_to_emit(&self, nsp: &str, rooms: Vec<&str>) -> HashSet<Sid> {
121        let clients = &self.rooms;
122        let mut sids_to_emit = HashSet::new();
123        if let Some(room_clients) = clients.get(nsp) {
124            for room_name in rooms {
125                match room_clients.get(room_name) {
126                    Some(room) => {
127                        for sid in room {
128                            sids_to_emit.insert(sid.clone());
129                        }
130                    }
131                    // room may be sid
132                    None => {
133                        let _ = sids_to_emit.insert(Arc::new(room_name.to_owned()));
134                    }
135                };
136            }
137        }
138        sids_to_emit
139    }
140
141    pub(crate) fn recv_event(self: &Arc<Self>) {
142        let event_rx = self.engine_server.event_rx();
143        let server = self.to_owned();
144        tokio::spawn(async move {
145            let mut event_rx = event_rx.lock().await;
146
147            while let Some(event) = event_rx.recv().await {
148                trace!("server recv_event: {:?}", event);
149                match event {
150                    EngineEvent::OnOpen(esid) => server.create_client(esid).await,
151                    EngineEvent::OnClose(esid) => server.drop_client(&esid).await,
152                    EngineEvent::OnPacket(_esid, _packet) => {
153                        // TODO: watch new namespace packet
154                    }
155                    _ => {}
156                };
157            }
158        });
159    }
160
161    pub(crate) async fn client(&self, sid: &Sid, nsp: &str) -> Option<ServerSocket> {
162        let esid = &SidGenerator::decode(sid)?;
163        self.clients.get(esid)?.get(sid)?.get(nsp).cloned()
164    }
165
166    pub(crate) async fn join<T: Into<String>>(
167        self: &Arc<Self>,
168        nsp: &str,
169        rooms: Vec<T>,
170        sid: Sid,
171    ) {
172        for room_name in rooms {
173            let room_name = room_name.into();
174            match self.rooms.get_mut(nsp) {
175                None => {
176                    let mut room_sids = HashSet::new();
177                    room_sids.insert(sid.clone());
178                    let mut rooms = HashMap::new();
179                    rooms.insert(room_name, room_sids);
180                    self.rooms.insert(nsp.to_owned(), rooms);
181                }
182                Some(mut rooms) => {
183                    if let Some(room_sids) = rooms.get_mut(&room_name) {
184                        let _ = room_sids.insert(sid.clone());
185                    } else {
186                        let mut room_sids = HashSet::new();
187                        room_sids.insert(sid.clone());
188                        rooms.insert(room_name, room_sids);
189                    }
190                }
191            };
192        }
193    }
194
195    pub(crate) async fn leave(self: &Arc<Self>, nsp: &str, rooms: Vec<&str>, sid: &Sid) {
196        for room_name in rooms {
197            if let Some(mut nsp_rooms) = self.rooms.get_mut(nsp) {
198                if let Some(room_sids) = nsp_rooms.get_mut(room_name) {
199                    room_sids.remove(sid);
200                }
201            };
202        }
203    }
204
205    async fn create_client(self: &Arc<Self>, esid: EngineSid) {
206        if let Some(engine_socket) = self.engine_server.socket(&esid).await {
207            let socket = RawSocket::server_end(engine_socket);
208
209            // TODO: support multiple namespace
210            match self.client_info(&esid).await {
211                Some((sid, nsp)) => self.insert_clients(socket, nsp, esid, sid, false).await,
212                None => self.handle_connect(socket, esid).await,
213            };
214        }
215    }
216
217    // TODO: support multiple nsp
218    // currently one esid mapping to one sid,
219    // one sid mapping one nsp
220    async fn client_info(&self, esid: &EngineSid) -> Option<(Sid, String)> {
221        let sid_map = self.clients.get(esid)?;
222        let entry = sid_map.iter().next()?;
223        let (sid, nsp_map) = entry.pair();
224        let (nsp, _) = nsp_map.iter().next()?;
225
226        Some((sid.to_owned(), nsp.to_owned()))
227    }
228
229    async fn handle_connect(self: &Arc<Self>, socket: RawSocket, esid: EngineSid) {
230        trace!("handle_connect: {:?}", esid);
231        let slf = self.clone();
232        tokio::spawn(async move {
233            if tokio::time::timeout(
234                Duration::from_secs(CONNECT_TIMEOUT),
235                slf.do_handle_connect(socket, esid.clone()),
236            )
237            .await
238            .is_err()
239            {
240                warn!("handle_connect timeout, {:?} dropped", esid);
241                slf.drop_client(&esid).await;
242            }
243        });
244    }
245
246    async fn do_handle_connect(self: &Arc<Self>, socket: RawSocket, esid: EngineSid) {
247        let sid = self.sid_generator.generate(&esid);
248        while let Some(Ok(packet)) = socket.poll_packet().await {
249            if packet.ptype == PacketType::Connect {
250                let nsp = packet.nsp.clone();
251                self.insert_clients(socket, nsp, esid, sid, true).await;
252                break;
253            } else {
254                continue;
255            }
256        }
257    }
258
259    async fn insert_clients(
260        self: &Arc<Self>,
261        socket: RawSocket,
262        nsp: String,
263        esid: EngineSid,
264        sid: Sid,
265        handshake: bool,
266    ) {
267        if let Some(on) = self.on.get(&nsp) {
268            let client = ServerSocket::new(
269                socket,
270                nsp.clone(),
271                sid.clone(),
272                on.to_owned(),
273                self.clone(),
274            );
275
276            client.connect_callback().await;
277
278            poll(client.clone());
279
280            if handshake {
281                let _ = client.handshake(json!({ "sid": sid.clone() })).await;
282            }
283
284            let sid_map = self.clients.entry(esid).or_default();
285
286            let mut nsp_map = sid_map.entry(sid).or_default();
287            nsp_map.insert(nsp, client);
288        } else {
289            warn!("unkown nsp {} from client", nsp);
290        }
291    }
292
293    async fn drop_client(self: &Arc<Self>, esid: &EngineSid) {
294        self.engine_server.close_socket(esid).await;
295
296        if self.clients.remove(esid).is_some() {
297            //TODO: disconnect
298        }
299
300        // FIXME: performance will be low if too many nsp and rooms
301        self.rooms.iter_mut().for_each(|mut nsp_clients| {
302            for room_clients in nsp_clients.values_mut() {
303                room_clients.retain(|sid| SidGenerator::decode(sid).as_ref() != Some(esid))
304            }
305        });
306    }
307}
308
309#[derive(Default)]
310pub(crate) struct SidGenerator {
311    seq: AtomicUsize,
312}
313
314impl SidGenerator {
315    pub fn generate(&self, engine_sid: &EngineSid) -> Sid {
316        let seq = self.seq.fetch_add(1, Ordering::SeqCst);
317        Arc::new(base64::encode(format!("{}-{}", engine_sid, seq)))
318    }
319
320    pub fn decode(sid: &Sid) -> Option<EngineSid> {
321        let sid_vec = base64::decode(sid.as_bytes()).ok()?;
322        let esid_sid = std::str::from_utf8(&sid_vec).ok()?;
323        let tokens: Vec<&str> = esid_sid.split('-').collect();
324        Some(Arc::new(tokens[0].to_owned()))
325    }
326}
327
328fn poll(socket: ServerSocket) {
329    tokio::runtime::Handle::current().spawn(async move {
330        loop {
331            // tries to restart a poll cycle whenever a 'normal' error occurs,
332            // it just logs on network errors, in case the poll cycle returned
333            // `Result::Ok`, the server receives a close frame so it's safe to
334            // terminate
335            let next = socket.poll_packet().await;
336            match next {
337                Some(e @ Err(Error::IncompleteResponseFromEngineIo(_))) => {
338                    trace!("Network error occurred: {:?}", e.err());
339                }
340                None => break,
341                _ => {}
342            }
343        }
344    });
345}
346
347#[cfg(test)]
348mod test {
349    use std::{
350        sync::{
351            atomic::{AtomicBool, Ordering},
352            Arc,
353        },
354        time::Duration,
355    };
356
357    use crate::{
358        client::ClientBuilder, client::Socket, server::client::Client as ServerClient,
359        test::rust_socket_io_server, AckId, Event, Payload, Server, ServerBuilder,
360    };
361
362    use super::SidGenerator;
363    use futures_util::FutureExt;
364    use serde_json::json;
365    use tracing::info;
366
367    #[test]
368    fn test_sid_generator() {
369        let generator = SidGenerator::default();
370        let engine_sid = Arc::new("engine_sid".to_owned());
371        let sid = generator.generate(&engine_sid);
372
373        assert_eq!(SidGenerator::decode(&sid), Some(engine_sid));
374    }
375
376    #[tokio::test]
377    async fn test_server() {
378        // tracing_subscriber::fmt()
379        //     .with_env_filter("engineio=trace,socketio=trace")
380        //     .init();
381        let server = setup();
382        test_emit().await;
383        test_client_count(&server).await;
384        test_client_ask_ack().await;
385        test_server_ask_ack().await;
386    }
387
388    async fn test_emit() {
389        let is_recv = Arc::new(AtomicBool::default());
390        let is_recv_clone = Arc::clone(&is_recv);
391
392        let callback = move |_: Option<Payload>, _: Socket, _: Option<AckId>| {
393            let is_recv = is_recv_clone.clone();
394            async move {
395                tracing::info!("1");
396                is_recv.store(true, Ordering::SeqCst);
397                tracing::info!("2");
398            }
399            .boxed()
400        };
401
402        let url = rust_socket_io_server();
403        let socket = ClientBuilder::new(url)
404            .namespace("/admin")
405            .on("echo", callback)
406            .on(Event::Connect, move |_payload, socket, _| {
407                async move {
408                    socket.emit("echo", json!("data")).await.expect("success");
409                }
410                .boxed()
411            })
412            .connect()
413            .await;
414
415        assert!(socket.is_ok());
416
417        // wait recv data
418        tokio::time::sleep(Duration::from_millis(100)).await;
419
420        assert!(is_recv.load(Ordering::SeqCst));
421    }
422
423    async fn test_client_count(server: &Arc<Server>) {
424        let url = rust_socket_io_server();
425
426        let socket = ClientBuilder::new(url.clone())
427            .namespace("/admin")
428            .on(Event::Connect, move |_payload, socket, _| {
429                async move {
430                    socket.emit("echo", json!("data")).await.expect("success");
431                }
432                .boxed()
433            })
434            .connect()
435            .await;
436
437        let socket2 = ClientBuilder::new(url.clone())
438            .namespace("/admin")
439            .on(Event::Connect, move |_payload, socket, _| {
440                async move {
441                    socket.emit("echo", json!("data")).await.expect("success");
442                }
443                .boxed()
444            })
445            .connect()
446            .await;
447
448        let socket3 = ClientBuilder::new(url)
449            .namespace("/admin")
450            .on(Event::Connect, move |_payload, socket, _| {
451                async move {
452                    socket.emit("echo", json!("data")).await.expect("success");
453                }
454                .boxed()
455            })
456            .connect()
457            .await;
458
459        assert!(socket.is_ok());
460        assert!(socket2.is_ok());
461        assert!(socket3.is_ok());
462
463        assert_eq!(server.client_count(), 3);
464    }
465
466    async fn test_client_ask_ack() {
467        let is_client_ack = Arc::new(AtomicBool::default());
468        let is_client_ack_clone = Arc::clone(&is_client_ack);
469
470        let client_ack_callback =
471            move |_payload: Option<Payload>, _socket: Socket, _need_ack: Option<AckId>| {
472                let is_client_ack = is_client_ack_clone.clone();
473                async move {
474                    is_client_ack.store(true, Ordering::SeqCst);
475                }
476                .boxed()
477            };
478
479        let url = rust_socket_io_server();
480        let socket = ClientBuilder::new(url)
481            .namespace("/admin")
482            .on(Event::Connect, move |_payload, socket, _| {
483                let client_ack_callback = client_ack_callback.clone();
484                async move {
485                    socket
486                        .emit_with_ack(
487                            "client_ack",
488                            json!("data"),
489                            Duration::from_millis(200),
490                            client_ack_callback,
491                        )
492                        .await
493                        .expect("success");
494                }
495                .boxed()
496            })
497            .connect()
498            .await;
499
500        assert!(socket.is_ok());
501
502        // wait recv data
503        tokio::time::sleep(Duration::from_millis(100)).await;
504
505        assert!(is_client_ack.load(Ordering::SeqCst));
506    }
507
508    async fn test_server_ask_ack() {
509        let is_server_ask_ack = Arc::new(AtomicBool::default());
510        let is_server_recv_ack = Arc::new(AtomicBool::default());
511        let is_server_ask_ack_clone = Arc::clone(&is_server_ask_ack);
512        let is_server_recv_ack_clone = Arc::clone(&is_server_recv_ack);
513
514        let server_ask_ack =
515            move |_payload: Option<Payload>, socket: Socket, need_ack: Option<AckId>| {
516                let is_server_ask_ack = is_server_ask_ack_clone.clone();
517                async move {
518                    assert!(need_ack.is_some());
519                    if let Some(ack_id) = need_ack {
520                        socket.ack(ack_id, json!("")).await.expect("success");
521                        is_server_ask_ack.store(true, Ordering::SeqCst);
522                    }
523                }
524                .boxed()
525            };
526
527        let server_recv_ack =
528            move |_payload: Option<Payload>, _socket: Socket, _need_ack: Option<AckId>| {
529                let is_server_recv_ack = is_server_recv_ack_clone.clone();
530                async move {
531                    is_server_recv_ack.store(true, Ordering::SeqCst);
532                }
533                .boxed()
534            };
535
536        let url = rust_socket_io_server();
537        let socket = ClientBuilder::new(url)
538            .namespace("/admin")
539            .on("server_ask_ack", server_ask_ack)
540            .on("server_recv_ack", server_recv_ack)
541            .on(Event::Connect, move |_payload, socket, _| {
542                async move {
543                    socket
544                        .emit("trigger_server_ack", json!("data"))
545                        .await
546                        .expect("success");
547                }
548                .boxed()
549            })
550            .connect()
551            .await;
552
553        assert!(socket.is_ok());
554
555        // wait recv data
556        tokio::time::sleep(Duration::from_millis(100)).await;
557
558        assert!(is_server_ask_ack.load(Ordering::SeqCst));
559        assert!(is_server_recv_ack.load(Ordering::SeqCst));
560    }
561
562    fn setup() -> Arc<crate::Server> {
563        let echo_callback =
564            move |_payload: Option<Payload>, socket: ServerClient, _need_ack: Option<AckId>| {
565                async move {
566                    info!("server echo callback");
567                    socket.join(vec!["room 1"]).await;
568                    socket.emit_to(vec!["room 1"], "echo", json!("")).await;
569                    socket.leave(vec!["room 1"]).await;
570                    info!("server echo callback done");
571                }
572                .boxed()
573            };
574
575        let client_ack =
576            move |_payload: Option<Payload>, socket: ServerClient, need_ack: Option<AckId>| {
577                async move {
578                    if let Some(ack_id) = need_ack {
579                        socket
580                            .ack(ack_id, json!("ack to client"))
581                            .await
582                            .expect("success");
583                    }
584                }
585                .boxed()
586            };
587
588        let server_recv_ack =
589            move |_payload: Option<Payload>, socket: ServerClient, _need_ack: Option<AckId>| {
590                async move {
591                    socket
592                        .emit("server_recv_ack", json!(""))
593                        .await
594                        .expect("success");
595                }
596                .boxed()
597            };
598
599        let trigger_ack = move |_message: Option<Payload>, socket: ServerClient, _| {
600            async move {
601                socket.join(vec!["room 2"]).await;
602                socket
603                    .emit_to_with_ack(
604                        vec!["room 2"],
605                        "server_ask_ack",
606                        json!(true),
607                        Duration::from_millis(400),
608                        server_recv_ack,
609                    )
610                    .await;
611                socket.leave(vec!["room 2"]).await;
612            }
613            .boxed()
614        };
615
616        let url = rust_socket_io_server();
617        let server = ServerBuilder::new(url.port().unwrap())
618            .on("/admin", "echo", echo_callback)
619            .on("/admin", "client_ack", client_ack)
620            .on("/admin", "trigger_server_ack", trigger_ack)
621            .build();
622
623        let server_clone = server.clone();
624
625        tokio::spawn(async move { server.serve().await });
626
627        server_clone
628    }
629}