Skip to main content

socketioxide_core/adapter/
mod.rs

//! The adapter module contains the [`CoreAdapter`] trait and other related types.
//!
//! It is used to implement communication between socket.io servers to share messages and state.
//!
//! The [`CoreLocalAdapter`] provides a local implementation that allows any implementor to apply local
//! operations (`broadcast_with_ack`, `broadcast`, `rooms`, etc.).
//!
7//!
8use std::{
9    borrow::Cow,
10    collections::{HashMap, HashSet, hash_map, hash_set},
11    error::Error as StdError,
12    future::{self, Future},
13    hash::Hash,
14    slice,
15    sync::{Arc, RwLock},
16    time::Duration,
17};
18
19use engineioxide_core::{Sid, Str};
20use futures_core::{FusedStream, Stream};
21use serde::{Deserialize, Serialize, de::DeserializeOwned};
22use smallvec::SmallVec;
23
24use crate::{Uid, Value, packet::Packet, parser::Parse};
25use errors::{AdapterError, BroadcastError, SocketError};
26
27pub mod errors;
28#[cfg(feature = "remote-adapter")]
29pub mod remote_packet;
30
/// A room identifier.
///
/// Either a borrowed `'static` string or an owned string, so statically known room
/// names can be used without allocating.
pub type Room = Cow<'static, str>;
33
/// Flags that can be used to modify the behavior of the broadcast methods.
///
/// Each variant's discriminant is a distinct bit; flags are OR-ed into the bitset stored in
/// [`BroadcastOptions`] (see [`BroadcastOptions::add_flag`]).
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum BroadcastFlags {
    /// Broadcast only to the current server
    Local = 0x01,
    /// Broadcast to all clients except the sender
    Broadcast = 0x02,
}
42
/// Options that can be used to modify the behavior of the broadcast methods.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct BroadcastOptions {
    /// The flags to apply to the broadcast represented as a bitflag.
    /// Manipulated through [`BroadcastOptions::add_flag`] and [`BroadcastOptions::has_flag`].
    flags: u8,
    /// The rooms to broadcast to.
    pub rooms: SmallVec<[Room; 4]>,
    /// The rooms to exclude from the broadcast.
    pub except: SmallVec<[Room; 4]>,
    /// The socket id of the sender.
    pub sid: Option<Sid>,
    /// The target server id can be used to optimize the broadcast.
    /// More specifically when we use broadcasting to apply a single action on a remote socket.
    /// We know the server_id of the remote socket, so we can send the action directly to the server.
    pub server_id: Option<Uid>,
}
59impl BroadcastOptions {
60    /// Add any flags to the options.
61    pub fn add_flag(&mut self, flag: BroadcastFlags) {
62        self.flags |= flag as u8;
63    }
64    /// Check if the options have a flag.
65    pub fn has_flag(&self, flag: BroadcastFlags) -> bool {
66        self.flags & flag as u8 == flag as u8
67    }
68
69    /// get the flags of the options.
70    pub fn flags(&self) -> u8 {
71        self.flags
72    }
73
74    /// Set the socket id of the sender.
75    pub fn new(sid: Sid) -> Self {
76        Self {
77            sid: Some(sid),
78            ..Default::default()
79        }
80    }
81    /// Create a new broadcast options from a remote socket data.
82    pub fn new_remote(data: &RemoteSocketData) -> Self {
83        Self {
84            sid: Some(data.id),
85            server_id: Some(data.server_id),
86            ..Default::default()
87        }
88    }
89
90    /// Check if the selected options are local to the current server.
91    #[inline]
92    pub fn is_local(&self, uid: Uid) -> bool {
93        let target_sock_is_local = !self.has_flag(BroadcastFlags::Broadcast)
94            && self.server_id == Some(uid)
95            && self.rooms.is_empty()
96            && self.sid.is_some();
97        self.has_flag(BroadcastFlags::Local) || target_sock_is_local
98    }
99}
100
/// A trait for types that can be used as a room parameter.
///
/// Implemented by default for [`Room`], [`String`], [`Vec<String>`], [`Vec<&str>`],
/// [`Vec<Room>`], [`&'static str`](str), `&'static [&'static str]`, const arrays of
/// `&'static str` or [`String`], and [`Sid`] (converted with its `to_string` representation).
pub trait RoomParam: Send + 'static {
    /// The type of the iterator returned by `into_room_iter`.
    type IntoIter: Iterator<Item = Room>;

    /// Convert `self` into an iterator of rooms.
    fn into_room_iter(self) -> Self::IntoIter;
}
111
112impl RoomParam for Room {
113    type IntoIter = std::iter::Once<Room>;
114    #[inline(always)]
115    fn into_room_iter(self) -> Self::IntoIter {
116        std::iter::once(self)
117    }
118}
119impl RoomParam for String {
120    type IntoIter = std::iter::Once<Room>;
121    #[inline(always)]
122    fn into_room_iter(self) -> Self::IntoIter {
123        std::iter::once(Cow::Owned(self))
124    }
125}
126impl RoomParam for Vec<String> {
127    type IntoIter = std::iter::Map<std::vec::IntoIter<String>, fn(String) -> Room>;
128    #[inline(always)]
129    fn into_room_iter(self) -> Self::IntoIter {
130        self.into_iter().map(Cow::Owned)
131    }
132}
133impl RoomParam for Vec<&'static str> {
134    type IntoIter = std::iter::Map<std::vec::IntoIter<&'static str>, fn(&'static str) -> Room>;
135    #[inline(always)]
136    fn into_room_iter(self) -> Self::IntoIter {
137        self.into_iter().map(Cow::Borrowed)
138    }
139}
140
141impl RoomParam for Vec<Room> {
142    type IntoIter = std::vec::IntoIter<Room>;
143    #[inline(always)]
144    fn into_room_iter(self) -> Self::IntoIter {
145        self.into_iter()
146    }
147}
148impl RoomParam for &'static str {
149    type IntoIter = std::iter::Once<Room>;
150    #[inline(always)]
151    fn into_room_iter(self) -> Self::IntoIter {
152        std::iter::once(Cow::Borrowed(self))
153    }
154}
155impl<const COUNT: usize> RoomParam for [&'static str; COUNT] {
156    type IntoIter =
157        std::iter::Map<std::array::IntoIter<&'static str, COUNT>, fn(&'static str) -> Room>;
158
159    #[inline(always)]
160    fn into_room_iter(self) -> Self::IntoIter {
161        self.into_iter().map(Cow::Borrowed)
162    }
163}
164impl RoomParam for &'static [&'static str] {
165    type IntoIter =
166        std::iter::Map<std::slice::Iter<'static, &'static str>, fn(&'static &'static str) -> Room>;
167
168    #[inline(always)]
169    fn into_room_iter(self) -> Self::IntoIter {
170        self.iter().map(|i| Cow::Borrowed(*i))
171    }
172}
173impl<const COUNT: usize> RoomParam for [String; COUNT] {
174    type IntoIter = std::iter::Map<std::array::IntoIter<String, COUNT>, fn(String) -> Room>;
175    #[inline(always)]
176    fn into_room_iter(self) -> Self::IntoIter {
177        self.into_iter().map(Cow::Owned)
178    }
179}
180impl RoomParam for Sid {
181    type IntoIter = std::iter::Once<Room>;
182    #[inline(always)]
183    fn into_room_iter(self) -> Self::IntoIter {
184        std::iter::once(Cow::Owned(self.to_string()))
185    }
186}
187
/// An item yielded by the ack stream: a socket id paired with the result of its acknowledgment.
pub type AckStreamItem<E> = (Sid, Result<Value, E>);
/// The [`SocketEmitter`] will be implemented by the socketioxide library.
/// It is simply used as an abstraction to allow the adapter to communicate
/// with the socket server without the need to depend on the socketioxide lib.
pub trait SocketEmitter: Send + Sync + 'static {
    /// The error reported, per socket, in the ack stream when an acknowledgment fails.
    type AckError: StdError + Send + Serialize + DeserializeOwned + 'static;
    /// A stream that emits the acknowledgments of multiple sockets.
    type AckStream: Stream<Item = AckStreamItem<Self::AckError>> + FusedStream + Send + 'static;

    /// Get the ids of the sockets in the namespace for which `filter` returns `true`.
    fn get_all_sids(&self, filter: impl Fn(&Sid) -> bool) -> Vec<Sid>;
    /// Get the socket data that match the list of socket ids.
    fn get_remote_sockets(&self, sids: BroadcastIter<'_>) -> Vec<RemoteSocketData>;
    /// Send already-encoded data to the list of socket ids.
    fn send_many(&self, sids: BroadcastIter<'_>, data: Value) -> Result<(), Vec<SocketError>>;
    /// Send data to the list of socket ids and get a stream of acks and the number of expected acks.
    fn send_many_with_ack(
        &self,
        sids: BroadcastIter<'_>,
        packet: Packet,
        timeout: Option<Duration>,
    ) -> (Self::AckStream, u32);
    /// Disconnect all the sockets in the list.
    /// TODO: take a [`BroadcastIter`]. Currently it is impossible because it may create deadlocks
    /// with Adapter::del_all call.
    fn disconnect_many(&self, sids: Vec<Sid>) -> Result<(), Vec<SocketError>>;
    /// Get the path of the namespace.
    fn path(&self) -> &Str;
    /// Get the parser of the namespace.
    fn parser(&self) -> impl Parse;
    /// Get the unique server id.
    fn server_id(&self) -> Uid;
    /// Get the default configured ack timeout.
    fn ack_timeout(&self) -> Duration;
}
225
/// For static namespaces, the init response will be managed by the user.
/// However, for dynamic namespaces, the socket.io client will manage the response.
/// As it does not know the type of the response, the spawnable trait is used to spawn the response
/// without the client having to know its concrete type.
pub trait Spawnable {
    /// Spawn the response. Implementors should spawn the future with `tokio::spawn` if it is an async function.
    /// They should also print a `tracing::error` log in case of an error.
    fn spawn(self);
}
/// No-op implementation: there is nothing to spawn.
impl Spawnable for () {
    fn spawn(self) {}
}
238
/// A trait to add a "defined" bound to adapter types.
/// This allows the socket.io library to implement functions for a *defined* adapter
/// rather than for a generic `A: Adapter`.
///
/// This is useful to force the user to handle the potential init response type [`CoreAdapter::InitRes`].
pub trait DefinedAdapter {}
245
/// An adapter is responsible for managing the state of the namespace.
/// This adapter can be implemented to share the state between multiple servers.
///
/// A [`CoreLocalAdapter`] instance will be given when constructing this type, it will allow
/// you to manipulate local sockets (emitting, fetching data, broadcasting).
///
/// The default method implementations only delegate to [`CoreAdapter::get_local`], so they
/// only reach the sockets connected to the current server.
pub trait CoreAdapter<E: SocketEmitter>: Sized + Send + Sync + 'static {
    /// An error that can occur when using the adapter.
    type Error: StdError + Into<AdapterError> + Send + 'static;
    /// A shared state between all the namespace [`CoreAdapter`].
    /// This can be used to share a connection for example.
    type State: Send + Sync + 'static;
    /// A stream that emits the acknowledgments of multiple sockets.
    type AckStream: Stream<Item = AckStreamItem<E::AckError>> + FusedStream + Send + 'static;
    /// A named result type for the initialization of the adapter.
    type InitRes: Spawnable + Send;

    /// Creates a new adapter with the given state and local adapter.
    ///
    /// The state is used to share a common state between all your adapters. E.G. a connection to a remote system.
    /// The local adapter is used to manipulate the local sockets.
    fn new(state: &Self::State, local: CoreLocalAdapter<E>) -> Self;

    /// Initializes the adapter. The `on_success` callback should be called once the adapter is ready.
    fn init(self: Arc<Self>, on_success: impl FnOnce() + Send + 'static) -> Self::InitRes;

    /// Closes the adapter. The default implementation does nothing.
    fn close(&self) -> impl Future<Output = Result<(), Self::Error>> + Send {
        future::ready(Ok(()))
    }

    /// Returns the number of servers. The default implementation always reports `1`.
    fn server_count(&self) -> impl Future<Output = Result<u16, Self::Error>> + Send {
        future::ready(Ok(1))
    }

    /// Broadcasts the packet to the sockets that match the [`BroadcastOptions`].
    fn broadcast(
        &self,
        packet: Packet,
        opts: BroadcastOptions,
    ) -> impl Future<Output = Result<(), BroadcastError>> + Send {
        future::ready(
            self.get_local()
                .broadcast(packet, opts)
                .map_err(BroadcastError::from),
        )
    }

    /// Broadcasts the packet to the sockets that match the [`BroadcastOptions`]
    /// and return a stream of ack responses.
    ///
    /// This method has no default implementation because it would need a default for the
    /// associated `AckStream` type, and associated type defaults are not stable.
    /// <https://github.com/rust-lang/rust/issues/29661>
    fn broadcast_with_ack(
        &self,
        packet: Packet,
        opts: BroadcastOptions,
        timeout: Option<Duration>,
    ) -> impl Future<Output = Result<Self::AckStream, Self::Error>> + Send;

    /// Adds the sockets that match the [`BroadcastOptions`] to the rooms.
    fn add_sockets(
        &self,
        opts: BroadcastOptions,
        rooms: impl RoomParam,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
        self.get_local().add_sockets(opts, rooms);
        future::ready(Ok(()))
    }

    /// Removes the sockets that match the [`BroadcastOptions`] from the rooms.
    fn del_sockets(
        &self,
        opts: BroadcastOptions,
        rooms: impl RoomParam,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
        self.get_local().del_sockets(opts, rooms);
        future::ready(Ok(()))
    }

    /// Disconnects the sockets that match the [`BroadcastOptions`].
    fn disconnect_socket(
        &self,
        opts: BroadcastOptions,
    ) -> impl Future<Output = Result<(), BroadcastError>> + Send {
        future::ready(
            self.get_local()
                .disconnect_socket(opts)
                .map_err(BroadcastError::Socket),
        )
    }

    /// Fetches rooms that match the [`BroadcastOptions`]
    fn rooms(
        &self,
        opts: BroadcastOptions,
    ) -> impl Future<Output = Result<Vec<Room>, Self::Error>> + Send {
        future::ready(Ok(self.get_local().rooms(opts).into_iter().collect()))
    }

    /// Fetches remote sockets that match the [`BroadcastOptions`].
    fn fetch_sockets(
        &self,
        opts: BroadcastOptions,
    ) -> impl Future<Output = Result<Vec<RemoteSocketData>, Self::Error>> + Send {
        future::ready(Ok(self.get_local().fetch_sockets(opts)))
    }

    /// Returns the local adapter. Used to enable default behaviors.
    fn get_local(&self) -> &CoreLocalAdapter<E>;

    //TODO: implement
    // fn server_side_emit(&self, packet: Packet, opts: BroadcastOptions) -> Result<u64, Error>;
    // fn persist_session(&self, sid: i64);
    // fn restore_session(&self, sid: i64) -> Session;
}
362
/// The default adapter. Stores the state in memory.
pub struct CoreLocalAdapter<E> {
    /// Room -> ids of the sockets that joined it.
    rooms: RwLock<HashMap<Room, HashSet<Sid>>>,
    /// Socket id -> rooms it joined (reverse index of `rooms`, kept in sync by the methods).
    sockets: RwLock<HashMap<Sid, HashSet<Room>>>,
    /// Interface used to reach the actual sockets of the namespace.
    emitter: E,
}
369
370impl<E: SocketEmitter> CoreLocalAdapter<E> {
371    /// Create a new local adapter with the given sockets interface.
372    pub fn new(emitter: E) -> Self {
373        Self {
374            rooms: RwLock::new(HashMap::new()),
375            sockets: RwLock::new(HashMap::new()),
376            emitter,
377        }
378    }
379
380    /// Clears all the rooms and sockets.
381    pub fn close(&self) {
382        let mut rooms = self.rooms.write().unwrap();
383        rooms.clear();
384        rooms.shrink_to_fit();
385    }
386
387    /// Adds the socket to all the rooms.
388    pub fn add_all(&self, sid: Sid, rooms: impl RoomParam) {
389        let mut rooms_map = self.rooms.write().unwrap();
390        let mut socket_map = self.sockets.write().unwrap();
391        for room in rooms.into_room_iter() {
392            rooms_map.entry(room.clone()).or_default().insert(sid);
393            socket_map.entry(sid).or_default().insert(room);
394        }
395    }
396
397    /// Removes the socket from the rooms.
398    pub fn del(&self, sid: Sid, rooms: impl RoomParam) {
399        let mut rooms_map = self.rooms.write().unwrap();
400        let mut socket_map = self.sockets.write().unwrap();
401        for room in rooms.into_room_iter() {
402            remove_and_clean_entry(rooms_map.entry(room.clone()), &sid, || {
403                socket_map.entry(sid).and_modify(|r| {
404                    r.remove(&room);
405                });
406            });
407        }
408    }
409
410    /// Removes the socket from all the rooms.
411    pub fn del_all(&self, sid: Sid) {
412        let mut rooms_map = self.rooms.write().unwrap();
413        if let Some(rooms) = self.sockets.write().unwrap().remove(&sid) {
414            for room in rooms {
415                remove_and_clean_entry(rooms_map.entry(room.clone()), &sid, || ());
416            }
417        }
418    }
419
420    /// Broadcasts the packet to the sockets that match the [`BroadcastOptions`].
421    pub fn broadcast(
422        &self,
423        packet: Packet,
424        opts: BroadcastOptions,
425    ) -> Result<(), Vec<SocketError>> {
426        let room_map = self.rooms.read().unwrap();
427        let sids = self.apply_opts(&opts, &room_map);
428
429        if sids.is_empty() {
430            return Ok(());
431        }
432
433        let data = self.emitter.parser().encode(packet);
434        self.emitter.send_many(sids, data)
435    }
436
437    /// Broadcasts the packet to the sockets that match the [`BroadcastOptions`] and return a stream of ack responses.
438    /// Also returns the number of local expected aknowledgements to know when to stop waiting.
439    pub fn broadcast_with_ack(
440        &self,
441        packet: Packet,
442        opts: BroadcastOptions,
443        timeout: Option<Duration>,
444    ) -> (E::AckStream, u32) {
445        let room_map = self.rooms.read().unwrap();
446        let sids = self.apply_opts(&opts, &room_map);
447        // We cannot pre-serialize the packet because we need to change the ack id.
448        self.emitter.send_many_with_ack(sids, packet, timeout)
449    }
450
451    /// Returns the sockets ids that match the [`BroadcastOptions`].
452    pub fn sockets(&self, opts: BroadcastOptions) -> Vec<Sid> {
453        self.apply_opts(&opts, &self.rooms.read().unwrap())
454            .collect()
455    }
456
457    /// Returns the sockets ids that match the [`BroadcastOptions`].
458    pub fn fetch_sockets(&self, opts: BroadcastOptions) -> Vec<RemoteSocketData> {
459        let rooms = self.rooms.read().unwrap();
460        let sids = self.apply_opts(&opts, &rooms);
461        self.emitter.get_remote_sockets(sids)
462    }
463
464    /// Returns the rooms of the socket.
465    pub fn socket_rooms(&self, sid: Sid) -> HashSet<Room> {
466        self.sockets
467            .read()
468            .unwrap()
469            .get(&sid)
470            .cloned()
471            .unwrap_or_default()
472    }
473
474    /// Adds the sockets that match the [`BroadcastOptions`] to the rooms.
475    pub fn add_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) {
476        let rooms: Vec<Room> = rooms.into_room_iter().collect();
477        let mut room_map = self.rooms.write().unwrap();
478        let mut socket_map = self.sockets.write().unwrap();
479        // Here we have to collect sids, because we are going to modify the rooms map.
480        let sids = self.apply_opts(&opts, &room_map).collect::<Vec<_>>();
481        for sid in &sids {
482            let entry = socket_map.entry(*sid).or_default();
483            for room in &rooms {
484                entry.insert(room.clone());
485            }
486        }
487        for room in rooms {
488            let entry = room_map.entry(room).or_default();
489            for sid in &sids {
490                entry.insert(*sid);
491            }
492        }
493    }
494
495    /// Removes the sockets that match the [`BroadcastOptions`] from the rooms.
496    pub fn del_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) {
497        let rooms: Vec<Room> = rooms.into_room_iter().collect();
498        let mut rooms_map = self.rooms.write().unwrap();
499        let mut socket_map = self.sockets.write().unwrap();
500        let sids = self.apply_opts(&opts, &rooms_map).collect::<Vec<_>>();
501        for room in rooms {
502            for sid in &sids {
503                remove_and_clean_entry(socket_map.entry(*sid), &room, || ());
504                remove_and_clean_entry(rooms_map.entry(room.clone()), sid, || ());
505            }
506        }
507    }
508
509    /// Disconnects the sockets that match the [`BroadcastOptions`].
510    pub fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), Vec<SocketError>> {
511        let sids = self
512            .apply_opts(&opts, &self.rooms.read().unwrap())
513            .collect();
514        self.emitter.disconnect_many(sids)
515    }
516
517    /// Returns all the matching rooms
518    pub fn rooms(&self, opts: BroadcastOptions) -> HashSet<Room> {
519        let rooms = self.rooms.read().unwrap();
520        let sockets = self.sockets.read().unwrap();
521        let sids = self.apply_opts(&opts, &rooms);
522        sids.filter_map(|id| sockets.get(&id))
523            .flatten()
524            .cloned()
525            .collect()
526    }
527
528    /// Get the namespace path.
529    pub fn path(&self) -> &Str {
530        self.emitter.path()
531    }
532
533    /// Get the parser of the namespace.
534    pub fn parser(&self) -> impl Parse + '_ {
535        self.emitter.parser()
536    }
537    /// Get the unique server identifier
538    pub fn server_id(&self) -> Uid {
539        self.emitter.server_id()
540    }
541    /// Get the default configured ack timeout.
542    pub fn ack_timeout(&self) -> Duration {
543        self.emitter.ack_timeout()
544    }
545}
546
547/// The default broadcast iterator.
548/// Extract, flatten and filter a list of sid from a room list
549struct BroadcastRooms<'a> {
550    rooms: slice::Iter<'a, Room>,
551    rooms_map: &'a HashMap<Room, HashSet<Sid>>,
552    except: HashSet<Sid>,
553    flatten_iter: Option<hash_set::Iter<'a, Sid>>,
554}
555impl<'a> BroadcastRooms<'a> {
556    fn new(
557        rooms: &'a [Room],
558        rooms_map: &'a HashMap<Room, HashSet<Sid>>,
559        except: HashSet<Sid>,
560    ) -> Self {
561        BroadcastRooms {
562            rooms: rooms.iter(),
563            rooms_map,
564            except,
565            flatten_iter: None,
566        }
567    }
568}
569impl Iterator for BroadcastRooms<'_> {
570    type Item = Sid;
571    fn next(&mut self) -> Option<Self::Item> {
572        loop {
573            match self.flatten_iter.as_mut().and_then(Iterator::next) {
574                Some(sid) if !self.except.contains(sid) => return Some(*sid),
575                Some(_) => continue,
576                None => self.flatten_iter = None,
577            }
578
579            let room = self.rooms.next()?;
580            self.flatten_iter = self.rooms_map.get(room).map(HashSet::iter);
581        }
582    }
583}
584
585impl<E: SocketEmitter> CoreLocalAdapter<E> {
586    /// Applies the given `opts` and return the sockets that match.
587    fn apply_opts<'a>(
588        &self,
589        opts: &'a BroadcastOptions,
590        rooms: &'a HashMap<Room, HashSet<Sid>>,
591    ) -> BroadcastIter<'a> {
592        let is_broadcast = opts.has_flag(BroadcastFlags::Broadcast);
593
594        let mut except = get_except_sids(&opts.except, rooms);
595        // In case of broadcast flag + if the sender is set,
596        // we should not broadcast to it.
597        if is_broadcast {
598            //FIXME(1.88): switch to if let chains when available
599            if let Some(sid) = opts.sid {
600                except.insert(sid);
601            }
602        }
603
604        if !opts.rooms.is_empty() {
605            let iter = BroadcastRooms::new(&opts.rooms, rooms, except);
606            InnerBroadcastIter::BroadcastRooms(iter).into()
607        } else if is_broadcast {
608            let sids = self.emitter.get_all_sids(|id| !except.contains(id));
609            InnerBroadcastIter::GlobalBroadcast(sids.into_iter()).into()
610        } else if let Some(id) = opts.sid {
611            InnerBroadcastIter::Single(id).into()
612        } else {
613            InnerBroadcastIter::None.into()
614        }
615    }
616}
617
618#[inline]
619fn get_except_sids(except: &[Room], rooms: &HashMap<Room, HashSet<Sid>>) -> HashSet<Sid> {
620    let mut except_sids = HashSet::new();
621    for room in except {
622        if let Some(sockets) = rooms.get(room) {
623            except_sids.extend(sockets);
624        }
625    }
626    except_sids
627}
628
/// Removes `el` from the set stored in an occupied `entry`, dropping the whole map entry
/// when the set becomes empty, and then runs `cleanup`.
///
/// A vacant `entry` is left untouched and `cleanup` is not called.
#[inline]
fn remove_and_clean_entry<K, T: Hash + Eq>(
    entry: hash_map::Entry<'_, K, HashSet<T>>,
    el: &T,
    cleanup: impl FnOnce(),
) {
    //TODO: use hashmap raw entry when stabilized to avoid entry clone.
    // https://github.com/rust-lang/rust/issues/56167
    if let hash_map::Entry::Occupied(mut occupied) = entry {
        let set = occupied.get_mut();
        set.remove(el);
        if set.is_empty() {
            occupied.remove();
        }
        cleanup();
    }
}
650
651/// An iterator that yields the socket ids that match the broadcast options.
652/// Used with the [`SocketEmitter`] interface.
653pub struct BroadcastIter<'a> {
654    inner: InnerBroadcastIter<'a>,
655}
656enum InnerBroadcastIter<'a> {
657    BroadcastRooms(BroadcastRooms<'a>),
658    GlobalBroadcast(<Vec<Sid> as IntoIterator>::IntoIter),
659    Single(Sid),
660    None,
661}
662impl BroadcastIter<'_> {
663    fn is_empty(&self) -> bool {
664        matches!(self.inner, InnerBroadcastIter::None)
665    }
666}
667impl<'a> From<InnerBroadcastIter<'a>> for BroadcastIter<'a> {
668    fn from(inner: InnerBroadcastIter<'a>) -> Self {
669        BroadcastIter { inner }
670    }
671}
672
673impl Iterator for BroadcastIter<'_> {
674    type Item = Sid;
675
676    #[inline(always)]
677    fn next(&mut self) -> Option<Self::Item> {
678        self.inner.next()
679    }
680}
681impl Iterator for InnerBroadcastIter<'_> {
682    type Item = Sid;
683
684    fn next(&mut self) -> Option<Self::Item> {
685        match self {
686            InnerBroadcastIter::BroadcastRooms(inner) => inner.next(),
687            InnerBroadcastIter::GlobalBroadcast(inner) => inner.next(),
688            InnerBroadcastIter::Single(sid) => {
689                let sid = *sid;
690                *self = InnerBroadcastIter::None;
691                Some(sid)
692            }
693            InnerBroadcastIter::None => None,
694        }
695    }
696}
697
/// Represents the data of a (possibly remote) socket.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Default, Clone)]
pub struct RemoteSocketData {
    /// The id of the remote socket.
    pub id: Sid,
    /// The server id this socket is connected to.
    pub server_id: Uid,
    /// The namespace this socket is connected to.
    pub ns: Str,
}
708
709#[cfg(test)]
710mod test {
711
712    use smallvec::smallvec;
713    use std::{
714        array,
715        pin::Pin,
716        task::{Context, Poll},
717    };
718
719    use super::*;
720
721    struct StubSockets {
722        sockets: HashSet<Sid>,
723        path: Str,
724    }
725    impl StubSockets {
726        fn new(sockets: &[Sid]) -> Self {
727            let sockets = HashSet::from_iter(sockets.iter().copied());
728            Self {
729                sockets,
730                path: Str::from("/"),
731            }
732        }
733    }
734
735    struct StubAckStream;
736    impl Stream for StubAckStream {
737        type Item = (Sid, Result<Value, StubError>);
738        fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
739            Poll::Ready(None)
740        }
741    }
742    impl FusedStream for StubAckStream {
743        fn is_terminated(&self) -> bool {
744            true
745        }
746    }
747    #[derive(Debug, Serialize, Deserialize)]
748    struct StubError;
749    impl std::fmt::Display for StubError {
750        fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
751            Ok(())
752        }
753    }
754    impl std::error::Error for StubError {}
755
756    impl SocketEmitter for StubSockets {
757        type AckError = StubError;
758        type AckStream = StubAckStream;
759        fn get_all_sids(&self, filter: impl Fn(&Sid) -> bool) -> Vec<Sid> {
760            self.sockets
761                .iter()
762                .copied()
763                .filter(|id| filter(id))
764                .collect()
765        }
766
767        fn get_remote_sockets(&self, sids: BroadcastIter<'_>) -> Vec<RemoteSocketData> {
768            sids.map(|id| RemoteSocketData {
769                id,
770                server_id: Uid::ZERO,
771                ns: self.path.clone(),
772            })
773            .collect()
774        }
775
776        fn send_many(&self, _: BroadcastIter<'_>, _: Value) -> Result<(), Vec<SocketError>> {
777            Ok(())
778        }
779
780        fn send_many_with_ack(
781            &self,
782            _: BroadcastIter<'_>,
783            _: Packet,
784            _: Option<Duration>,
785        ) -> (Self::AckStream, u32) {
786            (StubAckStream, 0)
787        }
788
789        fn disconnect_many(&self, _: Vec<Sid>) -> Result<(), Vec<SocketError>> {
790            Ok(())
791        }
792
793        fn path(&self) -> &Str {
794            &self.path
795        }
796        fn parser(&self) -> impl Parse {
797            crate::parser::test::StubParser
798        }
799        fn server_id(&self) -> Uid {
800            Uid::ZERO
801        }
802        fn ack_timeout(&self) -> Duration {
803            Duration::ZERO
804        }
805    }
806
807    fn create_adapter<const S: usize>(sockets: [Sid; S]) -> CoreLocalAdapter<StubSockets> {
808        CoreLocalAdapter::new(StubSockets::new(&sockets))
809    }
810
811    #[test]
812    fn add_all() {
813        let socket = Sid::new();
814        let adapter = create_adapter([socket]);
815        adapter.add_all(socket, ["room1", "room2"]);
816        let rooms_map = adapter.rooms.read().unwrap();
817        let socket_map = adapter.sockets.read().unwrap();
818        assert_eq!(rooms_map.len(), 2);
819        assert_eq!(socket_map.len(), 1);
820        assert_eq!(rooms_map.get("room1").unwrap().len(), 1);
821        assert_eq!(rooms_map.get("room2").unwrap().len(), 1);
822
823        let rooms = socket_map.get(&socket).unwrap();
824        assert!(rooms.contains("room1"));
825        assert!(rooms.contains("room2"));
826    }
827
828    #[test]
829    fn del() {
830        let socket = Sid::new();
831        let adapter = create_adapter([socket]);
832        adapter.add_all(socket, ["room1", "room2"]);
833        {
834            let rooms_map = adapter.rooms.read().unwrap();
835            assert_eq!(rooms_map.len(), 2);
836            assert_eq!(rooms_map.get("room1").unwrap().len(), 1);
837            assert_eq!(rooms_map.get("room2").unwrap().len(), 1);
838            let socket_map = adapter.sockets.read().unwrap();
839            let rooms = socket_map.get(&socket).unwrap();
840            assert!(rooms.contains("room1"));
841            assert!(rooms.contains("room2"));
842        }
843        adapter.del(socket, "room1");
844        let rooms_map = adapter.rooms.read().unwrap();
845        let socket_map = adapter.sockets.read().unwrap();
846        assert_eq!(rooms_map.len(), 1);
847        assert!(rooms_map.get("room1").is_none());
848        assert_eq!(rooms_map.get("room2").unwrap().len(), 1);
849        assert_eq!(socket_map.get(&socket).unwrap().len(), 1);
850    }
851    #[test]
852    fn del_all() {
853        let socket = Sid::new();
854        let adapter = create_adapter([socket]);
855        adapter.add_all(socket, ["room1", "room2"]);
856
857        {
858            let rooms_map = adapter.rooms.read().unwrap();
859            assert_eq!(rooms_map.len(), 2);
860            assert_eq!(rooms_map.get("room1").unwrap().len(), 1);
861            assert_eq!(rooms_map.get("room2").unwrap().len(), 1);
862
863            let socket_map = adapter.sockets.read().unwrap();
864            let rooms = socket_map.get(&socket).unwrap();
865            assert!(rooms.contains("room1"));
866            assert!(rooms.contains("room2"));
867        }
868
869        adapter.del_all(socket);
870
871        {
872            let rooms_map = adapter.rooms.read().unwrap();
873            assert_eq!(rooms_map.len(), 0);
874
875            let socket_map = adapter.sockets.read().unwrap();
876            assert!(socket_map.get(&socket).is_none());
877        }
878    }
879
880    #[test]
881    fn socket_room() {
882        let sid1 = Sid::new();
883        let sid2 = Sid::new();
884        let sid3 = Sid::new();
885        let adapter = create_adapter([sid1, sid2, sid3]);
886        adapter.add_all(sid1, ["room1", "room2"]);
887        adapter.add_all(sid2, ["room1"]);
888        adapter.add_all(sid3, ["room2"]);
889        assert!(adapter.socket_rooms(sid1).contains(&Cow::Borrowed("room1")));
890        assert!(adapter.socket_rooms(sid1).contains(&Cow::Borrowed("room2")));
891        assert_eq!(
892            adapter.socket_rooms(sid2).into_iter().collect::<Vec<_>>(),
893            ["room1"]
894        );
895        assert_eq!(
896            adapter.socket_rooms(sid3).into_iter().collect::<Vec<_>>(),
897            ["room2"]
898        );
899    }
900
901    #[test]
902    fn add_socket() {
903        let socket = Sid::new();
904        let adapter = create_adapter([socket]);
905        adapter.add_all(socket, ["room1"]);
906
907        let mut opts = BroadcastOptions::new(socket);
908        opts.rooms = smallvec!["room1".into()];
909        adapter.add_sockets(opts, "room2");
910        let rooms_map = adapter.rooms.read().unwrap();
911
912        assert_eq!(rooms_map.len(), 2);
913        assert!(rooms_map.get("room1").unwrap().contains(&socket));
914        assert!(rooms_map.get("room2").unwrap().contains(&socket));
915    }
916
917    #[test]
918    fn del_socket() {
919        let socket = Sid::new();
920        let adapter = create_adapter([socket]);
921        adapter.add_all(socket, ["room1"]);
922
923        let mut opts = BroadcastOptions::new(socket);
924        opts.rooms = smallvec!["room1".into()];
925        adapter.add_sockets(opts, "room2");
926
927        {
928            let rooms_map = adapter.rooms.read().unwrap();
929
930            assert_eq!(rooms_map.len(), 2);
931            assert!(rooms_map.get("room1").unwrap().contains(&socket));
932            assert!(rooms_map.get("room2").unwrap().contains(&socket));
933        }
934
935        let mut opts = BroadcastOptions::new(socket);
936        opts.rooms = smallvec!["room1".into()];
937        adapter.del_sockets(opts, "room2");
938
939        {
940            let rooms_map = adapter.rooms.read().unwrap();
941
942            assert_eq!(rooms_map.len(), 1);
943            assert!(rooms_map.get("room1").unwrap().contains(&socket));
944            assert!(rooms_map.get("room2").is_none());
945        }
946    }
947
948    #[test]
949    fn sockets() {
950        let socket0 = Sid::new();
951        let socket1 = Sid::new();
952        let socket2 = Sid::new();
953        let adapter = create_adapter([socket0, socket1, socket2]);
954        adapter.add_all(socket0, ["room1", "room2"]);
955        adapter.add_all(socket1, ["room1", "room3"]);
956        adapter.add_all(socket2, ["room2", "room3"]);
957
958        let mut opts = BroadcastOptions {
959            rooms: smallvec!["room1".into()],
960            ..Default::default()
961        };
962        let sockets = adapter.sockets(opts.clone());
963        assert_eq!(sockets.len(), 2);
964        assert!(sockets.contains(&socket0));
965        assert!(sockets.contains(&socket1));
966
967        opts.rooms = smallvec!["room2".into()];
968        let sockets = adapter.sockets(opts.clone());
969        assert_eq!(sockets.len(), 2);
970        assert!(sockets.contains(&socket0));
971        assert!(sockets.contains(&socket2));
972
973        opts.rooms = smallvec!["room3".into()];
974        let sockets = adapter.sockets(opts.clone());
975        assert_eq!(sockets.len(), 2);
976        assert!(sockets.contains(&socket1));
977        assert!(sockets.contains(&socket2));
978    }
979
980    #[test]
981    fn disconnect_socket() {
982        let socket0 = Sid::new();
983        let socket1 = Sid::new();
984        let socket2 = Sid::new();
985        let adapter = create_adapter([socket0, socket1, socket2]);
986        adapter.add_all(socket0, ["room1", "room2", "room4"]);
987        adapter.add_all(socket1, ["room1", "room3", "room5"]);
988        adapter.add_all(socket2, ["room2", "room3", "room6"]);
989
990        let mut opts = BroadcastOptions::new(socket0);
991        opts.rooms = smallvec!["room5".into()];
992        adapter.disconnect_socket(opts).unwrap();
993
994        let mut opts = BroadcastOptions::default();
995        opts.rooms.push("room2".into());
996        let sockets = adapter.sockets(opts.clone());
997        assert_eq!(sockets.len(), 2);
998        assert!(sockets.contains(&socket2));
999        assert!(sockets.contains(&socket0));
1000    }
1001    #[test]
1002    fn disconnect_empty_opts() {
1003        let adapter = create_adapter([]);
1004        let opts = BroadcastOptions::default();
1005        adapter.disconnect_socket(opts).unwrap();
1006    }
1007    #[test]
1008    fn rooms() {
1009        let socket0 = Sid::new();
1010        let socket1 = Sid::new();
1011        let socket2 = Sid::new();
1012        let adapter = create_adapter([socket0, socket1, socket2]);
1013        adapter.add_all(socket0, ["room1", "room2", "room4"]);
1014        adapter.add_all(socket1, ["room1", "room3", "room5"]);
1015        adapter.add_all(socket2, ["room2", "room3", "room6"]);
1016
1017        let mut opts = BroadcastOptions::new(socket0);
1018        opts.rooms = smallvec!["room5".into()];
1019        opts.add_flag(BroadcastFlags::Broadcast);
1020        let rooms = adapter.rooms(opts);
1021        assert_eq!(rooms.len(), 3);
1022        assert!(rooms.contains(&Cow::Borrowed("room1")));
1023        assert!(rooms.contains(&Cow::Borrowed("room3")));
1024        assert!(rooms.contains(&Cow::Borrowed("room5")));
1025
1026        let mut opts = BroadcastOptions::default();
1027        opts.rooms.push("room2".into());
1028        let rooms = adapter.rooms(opts.clone());
1029        assert_eq!(rooms.len(), 5);
1030        assert!(rooms.contains(&Cow::Borrowed("room1")));
1031        assert!(rooms.contains(&Cow::Borrowed("room2")));
1032        assert!(rooms.contains(&Cow::Borrowed("room3")));
1033        assert!(rooms.contains(&Cow::Borrowed("room4")));
1034        assert!(rooms.contains(&Cow::Borrowed("room6")));
1035    }
1036
1037    #[test]
1038    fn apply_opts() {
1039        let mut sockets: [Sid; 3] = array::from_fn(|_| Sid::new());
1040        sockets.sort();
1041        let adapter = create_adapter(sockets);
1042
1043        adapter.add_all(sockets[0], ["room1", "room2"]);
1044        adapter.add_all(sockets[1], ["room1", "room3"]);
1045        adapter.add_all(sockets[2], ["room1", "room2", "room3"]);
1046
1047        // socket 2 is the sender
1048        let mut opts = BroadcastOptions::new(sockets[2]);
1049        opts.rooms = smallvec!["room1".into()];
1050        opts.except = smallvec!["room2".into()];
1051        let sids = adapter
1052            .apply_opts(&opts, &adapter.rooms.read().unwrap())
1053            .collect::<Vec<_>>();
1054        assert_eq!(sids, [sockets[1]]);
1055
1056        let mut opts = BroadcastOptions::new(sockets[2]);
1057        opts.add_flag(BroadcastFlags::Broadcast);
1058        let mut sids = adapter
1059            .apply_opts(&opts, &adapter.rooms.read().unwrap())
1060            .collect::<Vec<_>>();
1061        sids.sort();
1062        assert_eq!(sids, [sockets[0], sockets[1]]);
1063
1064        let mut opts = BroadcastOptions::new(sockets[2]);
1065        opts.add_flag(BroadcastFlags::Broadcast);
1066        opts.except = smallvec!["room2".into()];
1067        let sids = adapter
1068            .apply_opts(&opts, &adapter.rooms.read().unwrap())
1069            .collect::<Vec<_>>();
1070        assert_eq!(sids.len(), 1);
1071
1072        let opts = BroadcastOptions::new(sockets[2]);
1073        let sids = adapter
1074            .apply_opts(&opts, &adapter.rooms.read().unwrap())
1075            .collect::<Vec<_>>();
1076        assert_eq!(sids.len(), 1);
1077        assert_eq!(sids[0], sockets[2]);
1078
1079        let opts = BroadcastOptions::new(Sid::new());
1080        let sids = adapter
1081            .apply_opts(&opts, &adapter.rooms.read().unwrap())
1082            .collect::<Vec<_>>();
1083        assert_eq!(sids.len(), 1);
1084    }
1085
1086    #[test]
1087    fn test_is_local_opts() {
1088        let server_id = Uid::new();
1089        let remote = RemoteSocketData {
1090            id: Sid::new(),
1091            server_id,
1092            ns: "/".into(),
1093        };
1094        let opts = BroadcastOptions::new_remote(&remote);
1095        assert!(opts.is_local(server_id));
1096        assert!(!opts.is_local(Uid::new()));
1097        let opts = BroadcastOptions::new(Sid::new());
1098        assert!(!opts.is_local(Uid::new()));
1099    }
1100}