socketioxide/
client.rs

1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::fmt;
4use std::sync::{Arc, Mutex, OnceLock, RwLock};
5
6use bytes::Bytes;
7use engineioxide::Str;
8use engineioxide::handler::EngineIoHandler;
9use engineioxide::socket::{DisconnectReason as EIoDisconnectReason, Socket as EIoSocket};
10use futures_util::{FutureExt, TryFutureExt};
11
12use matchit::{Match, Router};
13use socketioxide_core::packet::{Packet, PacketData};
14use socketioxide_core::parser::{Parse, ParserState};
15use socketioxide_core::{Sid, Value};
16use tokio::sync::oneshot;
17
18use crate::{
19    ProtocolVersion, SocketIo, SocketIoConfig,
20    adapter::Adapter,
21    errors::Error,
22    handler::ConnectHandler,
23    ns::{Namespace, NamespaceCtr},
24    parser::{ParseError, Parser},
25    socket::DisconnectReason,
26};
27
28pub struct Client<A: Adapter> {
29    pub(crate) config: SocketIoConfig,
30    nsps: RwLock<HashMap<Str, Arc<Namespace<A>>>>,
31    router: RwLock<Router<NamespaceCtr<A>>>,
32    adapter_state: A::State,
33
34    #[cfg(feature = "state")]
35    pub(crate) state: state::TypeMap![Send + Sync],
36}
37
38// ==== impl Client ====
39
40impl<A: Adapter> Client<A> {
41    pub fn new(
42        config: SocketIoConfig,
43        adapter_state: A::State,
44        #[cfg(feature = "state")] mut state: state::TypeMap![Send + Sync],
45    ) -> Self {
46        #[cfg(feature = "state")]
47        state.freeze();
48
49        Self {
50            config,
51            nsps: RwLock::new(HashMap::new()),
52            router: RwLock::new(Router::new()),
53            adapter_state,
54            #[cfg(feature = "state")]
55            state,
56        }
57    }
58
59    /// Called when a socket connects to a new namespace
60    fn sock_connect(
61        self: &Arc<Self>,
62        auth: Option<Value>,
63        ns_path: &str,
64        esocket: &Arc<engineioxide::Socket<SocketData<A>>>,
65    ) {
66        #[cfg(feature = "tracing")]
67        tracing::debug!("auth: {:?}", auth);
68        let protocol: ProtocolVersion = esocket.protocol.into();
69        let connect = async move |ns: Arc<Namespace<A>>, esocket: Arc<EIoSocket<SocketData<A>>>| {
70            if ns.connect(esocket.id, esocket.clone(), auth).await.is_ok() {
71                // cancel the connect timeout task for v5
72                if let Some(tx) = esocket.data.connect_recv_tx.lock().unwrap().take() {
73                    tx.send(()).ok();
74                }
75            }
76        };
77
78        if let Some(ns) = self.get_ns(ns_path) {
79            tokio::spawn(connect(ns, esocket.clone()));
80        } else if let Ok(Match { value: ns_ctr, .. }) = self.router.read().unwrap().at(ns_path) {
81            let path = Str::copy_from_slice(ns_path);
82            let ns = ns_ctr.get_new_ns(path.clone(), &self.adapter_state, &self.config);
83            let this = self.clone();
84            let esocket = esocket.clone();
85            let adapter = ns.adapter.clone();
86            let on_success = move || {
87                this.nsps.write().unwrap().insert(path, ns.clone());
88                tokio::spawn(connect(ns, esocket));
89            };
90            // We "ask" the adapter implementation to manage the init response itself
91            socketioxide_core::adapter::Spawnable::spawn(adapter.init(on_success));
92        } else if protocol == ProtocolVersion::V4 && ns_path == "/" {
93            #[cfg(feature = "tracing")]
94            tracing::error!(
95                "the root namespace \"/\" must be defined before any connection for protocol V4 (legacy)!"
96            );
97            esocket.close(EIoDisconnectReason::TransportClose);
98        } else {
99            let path = Str::copy_from_slice(ns_path);
100            let packet = self
101                .parser()
102                .encode(Packet::connect_error(path, "Invalid namespace"));
103            let _ = match packet {
104                Value::Str(p, _) => esocket.emit(p).map_err(|_e| {
105                    #[cfg(feature = "tracing")]
106                    tracing::error!("error while sending invalid namespace packet: {}", _e);
107                }),
108                Value::Bytes(p) => esocket.emit_binary(p).map_err(|_e| {
109                    #[cfg(feature = "tracing")]
110                    tracing::error!("error while sending invalid namespace packet: {}", _e);
111                }),
112            };
113        }
114    }
115
116    /// Propagate a packet to its target namespace
117    fn sock_propagate_packet(&self, packet: Packet, sid: Sid) -> Result<(), Error> {
118        if let Some(ns) = self.get_ns(&packet.ns) {
119            ns.recv(sid, packet.inner)
120        } else {
121            #[cfg(feature = "tracing")]
122            tracing::debug!(?sid, "invalid namespace requested: {}", packet.ns);
123            Ok(())
124        }
125    }
126
127    /// Spawn a task that will close the socket if it is not connected to a namespace
128    /// after the [`SocketIoConfig::connect_timeout`] duration
129    fn spawn_connect_timeout_task(&self, socket: Arc<EIoSocket<SocketData<A>>>) {
130        #[cfg(feature = "tracing")]
131        tracing::debug!("spawning connect timeout task");
132        let (tx, rx) = oneshot::channel();
133        socket.data.connect_recv_tx.lock().unwrap().replace(tx);
134
135        tokio::spawn(
136            tokio::time::timeout(self.config.connect_timeout, rx).map_err(move |_| {
137                #[cfg(feature = "tracing")]
138                tracing::debug!("connect timeout for socket {}", socket.id);
139                socket.close(EIoDisconnectReason::TransportClose);
140            }),
141        );
142    }
143
144    /// Adds a new namespace handler
145    pub fn add_ns<C, T>(self: Arc<Self>, path: Cow<'static, str>, callback: C) -> A::InitRes
146    where
147        C: ConnectHandler<A, T>,
148        T: Send + Sync + 'static,
149    {
150        #[cfg(feature = "tracing")]
151        tracing::debug!("adding namespace {}", path);
152
153        let ns_path = Str::from(&path);
154        let ns = Namespace::new(ns_path.clone(), callback, &self.adapter_state, &self.config);
155        let adapter = ns.adapter.clone();
156        let on_success = move || {
157            self.nsps.write().unwrap().insert(ns_path, ns);
158        };
159        adapter.init(on_success)
160    }
161
162    pub fn add_dyn_ns<C, T>(&self, path: String, callback: C) -> Result<(), matchit::InsertError>
163    where
164        C: ConnectHandler<A, T>,
165        T: Send + Sync + 'static,
166    {
167        #[cfg(feature = "tracing")]
168        tracing::debug!("adding dynamic namespace {}", &path);
169
170        let ns = NamespaceCtr::new(callback);
171        self.router.write().unwrap().insert(path, ns)
172    }
173
174    /// Deletes a namespace handler and closes all the connections to it
175    pub fn delete_ns(&self, path: &str) {
176        #[cfg(feature = "v4")]
177        if path == "/" {
178            panic!(
179                "the root namespace \"/\" cannot be deleted for the socket.io v4 protocol. See https://socket.io/docs/v3/namespaces/#main-namespace for more info"
180            );
181        }
182
183        #[cfg(feature = "tracing")]
184        tracing::debug!("deleting namespace {}", path);
185        if let Some(ns) = self.nsps.write().unwrap().remove(path) {
186            ns.close(DisconnectReason::ServerNSDisconnect)
187                .now_or_never();
188        }
189    }
190
191    pub fn get_ns(&self, path: &str) -> Option<Arc<Namespace<A>>> {
192        self.nsps.read().unwrap().get(path).cloned()
193    }
194
195    /// Closes all engine.io connections and all clients
196    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self)))]
197    pub(crate) async fn close(&self) {
198        #[cfg(feature = "tracing")]
199        tracing::debug!("closing all namespaces");
200        let ns = { std::mem::take(&mut *self.nsps.write().unwrap()) };
201        futures_util::future::join_all(
202            ns.values()
203                .map(|ns| ns.close(DisconnectReason::ClosingServer)),
204        )
205        .await;
206        #[cfg(feature = "tracing")]
207        tracing::debug!("all namespaces closed");
208    }
209
210    pub(crate) fn parser(&self) -> Parser {
211        self.config.parser
212    }
213}
214
215pub struct SocketData<A: Adapter> {
216    pub parser_state: ParserState,
217    /// Channel used to notify the socket that it has been connected to a namespace for v5
218    pub connect_recv_tx: Mutex<Option<oneshot::Sender<()>>>,
219
220    /// Used to store the [`SocketIo`] instance so it can be accessed by any sockets
221    pub io: OnceLock<SocketIo<A>>,
222}
223impl<A: Adapter> Default for SocketData<A> {
224    fn default() -> Self {
225        Self {
226            parser_state: ParserState::default(),
227            connect_recv_tx: Mutex::new(None),
228            io: OnceLock::new(),
229        }
230    }
231}
232impl<A: Adapter> fmt::Debug for SocketData<A> {
233    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234        f.debug_struct("SocketData")
235            .field("parser_state", &self.parser_state)
236            .field("connect_recv_tx", &self.connect_recv_tx)
237            .finish()
238    }
239}
240
241impl<A: Adapter> EngineIoHandler for Client<A> {
242    type Data = SocketData<A>;
243
244    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, socket), fields(sid = socket.id.to_string())))]
245    fn on_connect(self: Arc<Self>, socket: Arc<EIoSocket<SocketData<A>>>) {
246        socket.data.io.set(SocketIo::from(self.clone())).ok();
247
248        #[cfg(feature = "tracing")]
249        tracing::debug!("eio socket connect");
250
251        let protocol: ProtocolVersion = socket.protocol.into();
252
253        // Connecting the client to the default namespace is mandatory if the SocketIO protocol is v4.
254        // Because we connect by default to the root namespace, we should ensure before that the root namespace is defined
255        match protocol {
256            ProtocolVersion::V4 => {
257                #[cfg(feature = "tracing")]
258                tracing::debug!("connecting to default namespace for v4");
259                self.sock_connect(None, "/", &socket);
260            }
261            ProtocolVersion::V5 => self.spawn_connect_timeout_task(socket),
262        }
263    }
264
265    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, socket), fields(sid = socket.id.to_string())))]
266    fn on_disconnect(&self, socket: Arc<EIoSocket<SocketData<A>>>, reason: EIoDisconnectReason) {
267        #[cfg(feature = "tracing")]
268        tracing::debug!("eio socket disconnected");
269        let socks: Vec<_> = self
270            .nsps
271            .read()
272            .unwrap()
273            .values()
274            .filter_map(|ns| ns.get_socket(socket.id).ok())
275            .collect();
276
277        let _cnt = socks
278            .into_iter()
279            .map(|s| s.close(reason.clone().into()))
280            .count();
281
282        #[cfg(feature = "tracing")]
283        tracing::debug!("disconnect handle spawned for {_cnt} namespaces");
284    }
285
286    fn on_message(self: &Arc<Self>, msg: Str, socket: Arc<EIoSocket<SocketData<A>>>) {
287        #[cfg(feature = "tracing")]
288        tracing::debug!("received message: {:?}", msg);
289        let packet = match self.parser().decode_str(&socket.data.parser_state, msg) {
290            Ok(packet) => packet,
291            Err(ParseError::NeedsMoreBinaryData) => return,
292            Err(_e) => {
293                #[cfg(feature = "tracing")]
294                tracing::debug!("socket deserialization error: {}", _e);
295                socket.close(EIoDisconnectReason::PacketParsingError);
296                return;
297            }
298        };
299        #[cfg(feature = "tracing")]
300        tracing::debug!("Packet: {:?}", packet);
301
302        let res: Result<(), Error> = match packet.inner {
303            PacketData::Connect(auth) => {
304                self.sock_connect(auth, &packet.ns, &socket);
305                Ok(())
306            }
307            _ => self.sock_propagate_packet(packet, socket.id),
308        };
309        if let Err(ref err) = res {
310            #[cfg(feature = "tracing")]
311            tracing::debug!(
312                "error while processing packet to socket {}: {}",
313                socket.id,
314                err
315            );
316            if let Some(reason) = err.into() {
317                socket.close(reason);
318            }
319        }
320    }
321
322    /// When a binary payload is received from a socket, it is applied to the partial binary packet
323    ///
324    /// If the packet is complete, it is propagated to the namespace
325    fn on_binary(self: &Arc<Self>, data: Bytes, socket: Arc<EIoSocket<SocketData<A>>>) {
326        #[cfg(feature = "tracing")]
327        tracing::debug!("received binary: {:?}", &data);
328        let packet = match self.parser().decode_bin(&socket.data.parser_state, data) {
329            Ok(packet) => packet,
330            Err(ParseError::NeedsMoreBinaryData) => return,
331            Err(_e) => {
332                #[cfg(feature = "tracing")]
333                tracing::debug!("socket deserialization error: {}", _e);
334                socket.close(EIoDisconnectReason::PacketParsingError);
335                return;
336            }
337        };
338
339        let res: Result<(), Error> = match packet.inner {
340            PacketData::Connect(auth) => {
341                self.sock_connect(auth, &packet.ns, &socket);
342                Ok(())
343            }
344            _ => self.sock_propagate_packet(packet, socket.id),
345        };
346        if let Err(ref err) = res {
347            #[cfg(feature = "tracing")]
348            tracing::debug!(
349                "error while propagating packet to socket {}: {}",
350                socket.id,
351                err
352            );
353            if let Some(reason) = err.into() {
354                socket.close(reason);
355            }
356        }
357    }
358}
359impl<A: Adapter> std::fmt::Debug for Client<A> {
360    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
361        let mut f = f.debug_struct("Client");
362        f.field("config", &self.config).field("nsps", &self.nsps);
363        #[cfg(feature = "state")]
364        let f = f.field("state", &self.state);
365        f.finish()
366    }
367}
368
#[doc(hidden)]
#[cfg(feature = "__test_harness")]
impl<A: Adapter> Client<A> {
    /// Creates a dummy piped engine.io socket driven by this client and sends a
    /// connect packet for the namespace `ns` carrying `auth`.
    ///
    /// Returns a sender whose packets are fed to the client as if they were received
    /// from the socket, together with the receiving end of the dummy socket pipe.
    ///
    /// # Panics
    /// Panics if `auth` cannot be encoded by the default parser.
    pub async fn new_dummy_sock(
        self: Arc<Self>,
        ns: &'static str,
        auth: impl serde::Serialize,
    ) -> (
        tokio::sync::mpsc::Sender<engineioxide::Packet>,
        tokio::sync::mpsc::Receiver<engineioxide::Packet>,
    ) {
        let buffer_size = self.config.engine_config.max_buffer_size;
        let (esock, out_rx) = EIoSocket::<SocketData<A>>::new_dummy_piped(
            Sid::new(),
            Box::new(|_, _| {}),
            buffer_size,
        );
        let _ = esock.data.io.set(SocketIo::from(self.clone()));

        let (in_tx, mut in_rx) = tokio::sync::mpsc::channel(buffer_size);
        let pump_sock = esock.clone();
        let pump_client = self.clone();
        // Replays every packet pushed by the test on the matching client handler.
        tokio::spawn(async move {
            while let Some(packet) = in_rx.recv().await {
                match packet {
                    engineioxide::Packet::Binary(bin) => {
                        pump_client.on_binary(bin, pump_sock.clone())
                    }
                    engineioxide::Packet::Message(msg) => {
                        pump_client.on_message(msg, pump_sock.clone())
                    }
                    engineioxide::Packet::Close => pump_client
                        .on_disconnect(pump_sock.clone(), EIoDisconnectReason::TransportClose),
                    _ => {}
                }
            }
        });

        let parser = crate::parser::Parser::default();
        let connect_packet = Packet {
            ns: ns.into(),
            inner: PacketData::Connect(Some(parser.encode_default(&auth).unwrap())),
        };
        // Only a text encoded connect packet is forwarded to the client.
        if let Value::Str(payload, _) = parser.encode(connect_packet) {
            self.on_message(payload, esock.clone());
        }

        // wait for the socket to be connected to the namespace
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        (in_tx, out_rx)
    }
}
422
423#[cfg(test)]
424mod test {
425    use super::*;
426    use tokio::sync::mpsc;
427
428    use crate::adapter::LocalAdapter;
429    const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(50);
430
431    fn create_client() -> Arc<super::Client<LocalAdapter>> {
432        let config = crate::SocketIoConfig {
433            connect_timeout: CONNECT_TIMEOUT,
434            ..Default::default()
435        };
436        let client = Client::new(
437            config,
438            (),
439            #[cfg(feature = "state")]
440            Default::default(),
441        );
442        let client = Arc::new(client);
443        client.clone().add_ns("/".into(), || {});
444        client
445    }
446
447    #[tokio::test]
448    async fn get_ns() {
449        let client = create_client();
450        let ns = Namespace::new(Str::from("/"), || {}, &client.adapter_state, &client.config);
451        client.nsps.write().unwrap().insert(Str::from("/"), ns);
452        assert!(client.get_ns("/").is_some());
453    }
454
455    #[tokio::test]
456    async fn io_should_always_be_set() {
457        let client = create_client();
458        let close_fn = Box::new(move |_, _| {});
459        let sock = EIoSocket::new_dummy(Sid::new(), close_fn);
460        client.on_connect(sock.clone());
461        assert!(sock.data.io.get().is_some());
462    }
463
464    #[tokio::test]
465    async fn connect_timeout_fail() {
466        let client = create_client();
467        let (close_tx, mut close_rx) = mpsc::channel(1);
468        let close_fn = Box::new(move |_, reason| close_tx.try_send(reason).unwrap());
469        let sock = EIoSocket::new_dummy(Sid::new(), close_fn);
470        client.on_connect(sock.clone());
471        // The socket is closed
472        let res = tokio::time::timeout(CONNECT_TIMEOUT * 2, close_rx.recv())
473            .await
474            .unwrap();
475        // applied in case of ns timeout
476        assert_eq!(res, Some(EIoDisconnectReason::TransportClose));
477    }
478
479    #[tokio::test]
480    async fn connect_timeout() {
481        let client = create_client();
482        let (close_tx, mut close_rx) = mpsc::channel(1);
483        let close_fn = Box::new(move |_, reason| close_tx.try_send(reason).unwrap());
484        let sock = EIoSocket::new_dummy(Sid::new(), close_fn);
485        client.clone().on_connect(sock.clone());
486        client.on_message("0".into(), sock.clone());
487        // The socket is not closed.
488        tokio::time::timeout(CONNECT_TIMEOUT * 2, close_rx.recv())
489            .await
490            .unwrap_err();
491    }
492}