1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::fmt;
4use std::sync::{Arc, Mutex, OnceLock, RwLock};
5
6use bytes::Bytes;
7use engineioxide::Str;
8use engineioxide::handler::EngineIoHandler;
9use engineioxide::socket::{DisconnectReason as EIoDisconnectReason, Socket as EIoSocket};
10use futures_util::{FutureExt, TryFutureExt};
11
12use matchit::{Match, Router};
13use socketioxide_core::packet::{Packet, PacketData};
14use socketioxide_core::parser::{Parse, ParserState};
15use socketioxide_core::{Sid, Value};
16use tokio::sync::oneshot;
17
18use crate::{
19 ProtocolVersion, SocketIo, SocketIoConfig,
20 adapter::Adapter,
21 errors::Error,
22 handler::ConnectHandler,
23 ns::{Namespace, NamespaceCtr},
24 parser::{ParseError, Parser},
25 socket::DisconnectReason,
26};
27
28pub struct Client<A: Adapter> {
29 pub(crate) config: SocketIoConfig,
30 nsps: RwLock<HashMap<Str, Arc<Namespace<A>>>>,
31 router: RwLock<Router<NamespaceCtr<A>>>,
32 adapter_state: A::State,
33
34 #[cfg(feature = "state")]
35 pub(crate) state: state::TypeMap![Send + Sync],
36}
37
38impl<A: Adapter> Client<A> {
41 pub fn new(
42 config: SocketIoConfig,
43 adapter_state: A::State,
44 #[cfg(feature = "state")] mut state: state::TypeMap![Send + Sync],
45 ) -> Self {
46 #[cfg(feature = "state")]
47 state.freeze();
48
49 Self {
50 config,
51 nsps: RwLock::new(HashMap::new()),
52 router: RwLock::new(Router::new()),
53 adapter_state,
54 #[cfg(feature = "state")]
55 state,
56 }
57 }
58
59 fn sock_connect(
61 self: &Arc<Self>,
62 auth: Option<Value>,
63 ns_path: &str,
64 esocket: &Arc<engineioxide::Socket<SocketData<A>>>,
65 ) {
66 #[cfg(feature = "tracing")]
67 tracing::debug!("auth: {:?}", auth);
68 let protocol: ProtocolVersion = esocket.protocol.into();
69 let connect = async move |ns: Arc<Namespace<A>>, esocket: Arc<EIoSocket<SocketData<A>>>| {
70 if ns.connect(esocket.id, esocket.clone(), auth).await.is_ok() {
71 if let Some(tx) = esocket.data.connect_recv_tx.lock().unwrap().take() {
73 tx.send(()).ok();
74 }
75 }
76 };
77
78 if let Some(ns) = self.get_ns(ns_path) {
79 tokio::spawn(connect(ns, esocket.clone()));
80 } else if let Ok(Match { value: ns_ctr, .. }) = self.router.read().unwrap().at(ns_path) {
81 let path = Str::copy_from_slice(ns_path);
82 let ns = ns_ctr.get_new_ns(path.clone(), &self.adapter_state, &self.config);
83 let this = self.clone();
84 let esocket = esocket.clone();
85 let adapter = ns.adapter.clone();
86 let on_success = move || {
87 this.nsps.write().unwrap().insert(path, ns.clone());
88 tokio::spawn(connect(ns, esocket));
89 };
90 socketioxide_core::adapter::Spawnable::spawn(adapter.init(on_success));
92 } else if protocol == ProtocolVersion::V4 && ns_path == "/" {
93 #[cfg(feature = "tracing")]
94 tracing::error!(
95 "the root namespace \"/\" must be defined before any connection for protocol V4 (legacy)!"
96 );
97 esocket.close(EIoDisconnectReason::TransportClose);
98 } else {
99 let path = Str::copy_from_slice(ns_path);
100 let packet = self
101 .parser()
102 .encode(Packet::connect_error(path, "Invalid namespace"));
103 let _ = match packet {
104 Value::Str(p, _) => esocket.emit(p).map_err(|_e| {
105 #[cfg(feature = "tracing")]
106 tracing::error!("error while sending invalid namespace packet: {}", _e);
107 }),
108 Value::Bytes(p) => esocket.emit_binary(p).map_err(|_e| {
109 #[cfg(feature = "tracing")]
110 tracing::error!("error while sending invalid namespace packet: {}", _e);
111 }),
112 };
113 }
114 }
115
116 fn sock_propagate_packet(&self, packet: Packet, sid: Sid) -> Result<(), Error> {
118 if let Some(ns) = self.get_ns(&packet.ns) {
119 ns.recv(sid, packet.inner)
120 } else {
121 #[cfg(feature = "tracing")]
122 tracing::debug!(?sid, "invalid namespace requested: {}", packet.ns);
123 Ok(())
124 }
125 }
126
127 fn spawn_connect_timeout_task(&self, socket: Arc<EIoSocket<SocketData<A>>>) {
130 #[cfg(feature = "tracing")]
131 tracing::debug!("spawning connect timeout task");
132 let (tx, rx) = oneshot::channel();
133 socket.data.connect_recv_tx.lock().unwrap().replace(tx);
134
135 tokio::spawn(
136 tokio::time::timeout(self.config.connect_timeout, rx).map_err(move |_| {
137 #[cfg(feature = "tracing")]
138 tracing::debug!("connect timeout for socket {}", socket.id);
139 socket.close(EIoDisconnectReason::TransportClose);
140 }),
141 );
142 }
143
144 pub fn add_ns<C, T>(self: Arc<Self>, path: Cow<'static, str>, callback: C) -> A::InitRes
146 where
147 C: ConnectHandler<A, T>,
148 T: Send + Sync + 'static,
149 {
150 #[cfg(feature = "tracing")]
151 tracing::debug!("adding namespace {}", path);
152
153 let ns_path = Str::from(&path);
154 let ns = Namespace::new(ns_path.clone(), callback, &self.adapter_state, &self.config);
155 let adapter = ns.adapter.clone();
156 let on_success = move || {
157 self.nsps.write().unwrap().insert(ns_path, ns);
158 };
159 adapter.init(on_success)
160 }
161
162 pub fn add_dyn_ns<C, T>(&self, path: String, callback: C) -> Result<(), matchit::InsertError>
163 where
164 C: ConnectHandler<A, T>,
165 T: Send + Sync + 'static,
166 {
167 #[cfg(feature = "tracing")]
168 tracing::debug!("adding dynamic namespace {}", &path);
169
170 let ns = NamespaceCtr::new(callback);
171 self.router.write().unwrap().insert(path, ns)
172 }
173
174 pub fn delete_ns(&self, path: &str) {
176 #[cfg(feature = "v4")]
177 if path == "/" {
178 panic!(
179 "the root namespace \"/\" cannot be deleted for the socket.io v4 protocol. See https://socket.io/docs/v3/namespaces/#main-namespace for more info"
180 );
181 }
182
183 #[cfg(feature = "tracing")]
184 tracing::debug!("deleting namespace {}", path);
185 if let Some(ns) = self.nsps.write().unwrap().remove(path) {
186 ns.close(DisconnectReason::ServerNSDisconnect)
187 .now_or_never();
188 }
189 }
190
191 pub fn get_ns(&self, path: &str) -> Option<Arc<Namespace<A>>> {
192 self.nsps.read().unwrap().get(path).cloned()
193 }
194
195 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self)))]
197 pub(crate) async fn close(&self) {
198 #[cfg(feature = "tracing")]
199 tracing::debug!("closing all namespaces");
200 let ns = { std::mem::take(&mut *self.nsps.write().unwrap()) };
201 futures_util::future::join_all(
202 ns.values()
203 .map(|ns| ns.close(DisconnectReason::ClosingServer)),
204 )
205 .await;
206 #[cfg(feature = "tracing")]
207 tracing::debug!("all namespaces closed");
208 }
209
210 pub(crate) fn parser(&self) -> Parser {
211 self.config.parser
212 }
213}
214
215pub struct SocketData<A: Adapter> {
216 pub parser_state: ParserState,
217 pub connect_recv_tx: Mutex<Option<oneshot::Sender<()>>>,
219
220 pub io: OnceLock<SocketIo<A>>,
222}
223impl<A: Adapter> Default for SocketData<A> {
224 fn default() -> Self {
225 Self {
226 parser_state: ParserState::default(),
227 connect_recv_tx: Mutex::new(None),
228 io: OnceLock::new(),
229 }
230 }
231}
232impl<A: Adapter> fmt::Debug for SocketData<A> {
233 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234 f.debug_struct("SocketData")
235 .field("parser_state", &self.parser_state)
236 .field("connect_recv_tx", &self.connect_recv_tx)
237 .finish()
238 }
239}
240
241impl<A: Adapter> EngineIoHandler for Client<A> {
242 type Data = SocketData<A>;
243
244 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, socket), fields(sid = socket.id.to_string())))]
245 fn on_connect(self: Arc<Self>, socket: Arc<EIoSocket<SocketData<A>>>) {
246 socket.data.io.set(SocketIo::from(self.clone())).ok();
247
248 #[cfg(feature = "tracing")]
249 tracing::debug!("eio socket connect");
250
251 let protocol: ProtocolVersion = socket.protocol.into();
252
253 match protocol {
256 ProtocolVersion::V4 => {
257 #[cfg(feature = "tracing")]
258 tracing::debug!("connecting to default namespace for v4");
259 self.sock_connect(None, "/", &socket);
260 }
261 ProtocolVersion::V5 => self.spawn_connect_timeout_task(socket),
262 }
263 }
264
265 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, socket), fields(sid = socket.id.to_string())))]
266 fn on_disconnect(&self, socket: Arc<EIoSocket<SocketData<A>>>, reason: EIoDisconnectReason) {
267 #[cfg(feature = "tracing")]
268 tracing::debug!("eio socket disconnected");
269 let socks: Vec<_> = self
270 .nsps
271 .read()
272 .unwrap()
273 .values()
274 .filter_map(|ns| ns.get_socket(socket.id).ok())
275 .collect();
276
277 let _cnt = socks
278 .into_iter()
279 .map(|s| s.close(reason.clone().into()))
280 .count();
281
282 #[cfg(feature = "tracing")]
283 tracing::debug!("disconnect handle spawned for {_cnt} namespaces");
284 }
285
286 fn on_message(self: &Arc<Self>, msg: Str, socket: Arc<EIoSocket<SocketData<A>>>) {
287 #[cfg(feature = "tracing")]
288 tracing::debug!("received message: {:?}", msg);
289 let packet = match self.parser().decode_str(&socket.data.parser_state, msg) {
290 Ok(packet) => packet,
291 Err(ParseError::NeedsMoreBinaryData) => return,
292 Err(_e) => {
293 #[cfg(feature = "tracing")]
294 tracing::debug!("socket deserialization error: {}", _e);
295 socket.close(EIoDisconnectReason::PacketParsingError);
296 return;
297 }
298 };
299 #[cfg(feature = "tracing")]
300 tracing::debug!("Packet: {:?}", packet);
301
302 let res: Result<(), Error> = match packet.inner {
303 PacketData::Connect(auth) => {
304 self.sock_connect(auth, &packet.ns, &socket);
305 Ok(())
306 }
307 _ => self.sock_propagate_packet(packet, socket.id),
308 };
309 if let Err(ref err) = res {
310 #[cfg(feature = "tracing")]
311 tracing::debug!(
312 "error while processing packet to socket {}: {}",
313 socket.id,
314 err
315 );
316 if let Some(reason) = err.into() {
317 socket.close(reason);
318 }
319 }
320 }
321
322 fn on_binary(self: &Arc<Self>, data: Bytes, socket: Arc<EIoSocket<SocketData<A>>>) {
326 #[cfg(feature = "tracing")]
327 tracing::debug!("received binary: {:?}", &data);
328 let packet = match self.parser().decode_bin(&socket.data.parser_state, data) {
329 Ok(packet) => packet,
330 Err(ParseError::NeedsMoreBinaryData) => return,
331 Err(_e) => {
332 #[cfg(feature = "tracing")]
333 tracing::debug!("socket deserialization error: {}", _e);
334 socket.close(EIoDisconnectReason::PacketParsingError);
335 return;
336 }
337 };
338
339 let res: Result<(), Error> = match packet.inner {
340 PacketData::Connect(auth) => {
341 self.sock_connect(auth, &packet.ns, &socket);
342 Ok(())
343 }
344 _ => self.sock_propagate_packet(packet, socket.id),
345 };
346 if let Err(ref err) = res {
347 #[cfg(feature = "tracing")]
348 tracing::debug!(
349 "error while propagating packet to socket {}: {}",
350 socket.id,
351 err
352 );
353 if let Some(reason) = err.into() {
354 socket.close(reason);
355 }
356 }
357 }
358}
359impl<A: Adapter> std::fmt::Debug for Client<A> {
360 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
361 let mut f = f.debug_struct("Client");
362 f.field("config", &self.config).field("nsps", &self.nsps);
363 #[cfg(feature = "state")]
364 let f = f.field("state", &self.state);
365 f.finish()
366 }
367}
368
#[doc(hidden)]
#[cfg(feature = "__test_harness")]
impl<A: Adapter> Client<A> {
    /// Test helper: creates a piped dummy engine.io socket and connects it to `ns`.
    ///
    /// The returned sender feeds engine.io packets to this client as if they came from the
    /// remote peer (message, binary and close packets are dispatched to the handler), and
    /// the returned receiver yields the packets the server emits on the socket.
    ///
    /// The CONNECT packet (carrying `auth`) is encoded with the client's configured parser
    /// and delivered through `on_message` or `on_binary` depending on the encoding, so
    /// binary parsers are supported too.
    ///
    /// # Panics
    /// Panics if `auth` cannot be encoded by the configured parser.
    pub async fn new_dummy_sock(
        self: Arc<Self>,
        ns: &'static str,
        auth: impl serde::Serialize,
    ) -> (
        tokio::sync::mpsc::Sender<engineioxide::Packet>,
        tokio::sync::mpsc::Receiver<engineioxide::Packet>,
    ) {
        let buffer_size = self.config.engine_config.max_buffer_size;
        let sid = Sid::new();
        let (esock, rx) =
            EIoSocket::<SocketData<A>>::new_dummy_piped(sid, Box::new(|_, _| {}), buffer_size);
        esock.data.io.set(SocketIo::from(self.clone())).ok();
        let (tx1, mut rx1) = tokio::sync::mpsc::channel(buffer_size);
        tokio::spawn({
            let esock = esock.clone();
            let client = self.clone();
            async move {
                while let Some(packet) = rx1.recv().await {
                    match packet {
                        engineioxide::Packet::Message(msg) => {
                            client.on_message(msg, esock.clone());
                        }
                        engineioxide::Packet::Close => {
                            client
                                .on_disconnect(esock.clone(), EIoDisconnectReason::TransportClose);
                        }
                        engineioxide::Packet::Binary(bin) => {
                            client.on_binary(bin, esock.clone());
                        }
                        _ => {}
                    }
                }
            }
        });
        // Use the parser the client decodes with; `Parser::default()` would produce a packet
        // a non-default (e.g. binary) parser cannot decode.
        let parser = self.parser();
        let val = parser.encode(Packet {
            ns: ns.into(),
            inner: PacketData::Connect(Some(parser.encode_default(&auth).unwrap())),
        });
        match val {
            Value::Str(s, _) => self.on_message(s, esock.clone()),
            Value::Bytes(b) => self.on_binary(b, esock.clone()),
        }

        // Give the spawned namespace connection task a chance to run before returning.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        (tx1, rx)
    }
}
422
#[cfg(test)]
mod test {
    use super::*;
    use tokio::sync::mpsc;

    use crate::adapter::LocalAdapter;

    /// Connect timeout configured on every client built by these tests.
    const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(50);

    /// Builds a `LocalAdapter` client with a short connect timeout and a root namespace.
    fn create_client() -> Arc<super::Client<LocalAdapter>> {
        let client = Arc::new(Client::new(
            crate::SocketIoConfig {
                connect_timeout: CONNECT_TIMEOUT,
                ..Default::default()
            },
            (),
            #[cfg(feature = "state")]
            Default::default(),
        ));
        client.clone().add_ns("/".into(), || {});
        client
    }

    #[tokio::test]
    async fn get_ns() {
        let client = create_client();
        let root = Str::from("/");
        let namespace =
            Namespace::new(root.clone(), || {}, &client.adapter_state, &client.config);
        client.nsps.write().unwrap().insert(root, namespace);
        assert!(client.get_ns("/").is_some());
    }

    #[tokio::test]
    async fn io_should_always_be_set() {
        let client = create_client();
        let on_close = Box::new(move |_, _| {});
        let eio_sock = EIoSocket::new_dummy(Sid::new(), on_close);
        client.on_connect(eio_sock.clone());
        assert!(eio_sock.data.io.get().is_some());
    }

    #[tokio::test]
    async fn connect_timeout_fail() {
        let client = create_client();
        let (close_tx, mut close_rx) = mpsc::channel(1);
        let on_close = Box::new(move |_, reason| close_tx.try_send(reason).unwrap());
        let eio_sock = EIoSocket::new_dummy(Sid::new(), on_close);
        client.on_connect(eio_sock.clone());
        // Without a CONNECT packet the timeout task must close the transport.
        let closed_with = tokio::time::timeout(CONNECT_TIMEOUT * 2, close_rx.recv())
            .await
            .expect("socket should be closed once the connect timeout elapses");
        assert_eq!(closed_with, Some(EIoDisconnectReason::TransportClose));
    }

    #[tokio::test]
    async fn connect_timeout() {
        let client = create_client();
        let (close_tx, mut close_rx) = mpsc::channel(1);
        let on_close = Box::new(move |_, reason| close_tx.try_send(reason).unwrap());
        let eio_sock = EIoSocket::new_dummy(Sid::new(), on_close);
        client.clone().on_connect(eio_sock.clone());
        // Connect to the root namespace before the timeout fires.
        client.on_message("0".into(), eio_sock.clone());
        let waited = tokio::time::timeout(CONNECT_TIMEOUT * 2, close_rx.recv()).await;
        assert!(
            waited.is_err(),
            "socket must stay open after a successful namespace connection"
        );
    }
}