socketioxide_core/adapter/
mod.rs

1//! The adapter module contains the [`CoreAdapter`] trait and other related types.
2//!
3//! It is used to implement communication between socket.io servers to share messages and state.
4//!
//! The [`CoreLocalAdapter`] provides a local implementation that allows any implementor to apply local
//! operations (`broadcast_with_ack`, `broadcast`, `rooms`, etc.).
7//!
8use std::{
9    borrow::Cow,
10    collections::{HashMap, HashSet, hash_map, hash_set},
11    error::Error as StdError,
12    future::{self, Future},
13    hash::Hash,
14    slice,
15    sync::{Arc, RwLock},
16    time::Duration,
17};
18
19use engineioxide_core::{Sid, Str};
20use futures_core::{FusedStream, Stream};
21use serde::{Deserialize, Serialize, de::DeserializeOwned};
22use smallvec::SmallVec;
23
24use crate::{Uid, Value, packet::Packet, parser::Parse};
25use errors::{AdapterError, BroadcastError, SocketError};
26
27pub mod errors;
28#[cfg(feature = "remote-adapter")]
29pub mod remote_packet;
30
/// A room identifier.
///
/// Stored as a `Cow` so that `&'static str` room names can be used as-is (borrowed), while
/// dynamically built names are stored as owned `String`s.
pub type Room = Cow<'static, str>;
33
/// Flags altering the behavior of the broadcast methods.
///
/// Every variant is a distinct bit, so several flags can be combined in the bitfield stored by
/// [`BroadcastOptions`].
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum BroadcastFlags {
    /// Only broadcast to the sockets of the current server.
    Local = 1 << 0,
    /// Broadcast to every matching client except the sender.
    Broadcast = 1 << 1,
}
42
43/// Options that can be used to modify the behavior of the broadcast methods.
44#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
45pub struct BroadcastOptions {
46    /// The flags to apply to the broadcast represented as a bitflag.
47    flags: u8,
48    /// The rooms to broadcast to.
49    pub rooms: SmallVec<[Room; 4]>,
50    /// The rooms to exclude from the broadcast.
51    pub except: SmallVec<[Room; 4]>,
52    /// The socket id of the sender.
53    pub sid: Option<Sid>,
54    /// The target server id can be used to optimize the broadcast.
55    /// More specifically when we use broadcasting to apply a single action on a remote socket.
56    /// We now the server_id of the remote socket, so we can send the action directly to the server.
57    pub server_id: Option<Uid>,
58}
59impl BroadcastOptions {
60    /// Add any flags to the options.
61    pub fn add_flag(&mut self, flag: BroadcastFlags) {
62        self.flags |= flag as u8;
63    }
64    /// Check if the options have a flag.
65    pub fn has_flag(&self, flag: BroadcastFlags) -> bool {
66        self.flags & flag as u8 == flag as u8
67    }
68
69    /// get the flags of the options.
70    pub fn flags(&self) -> u8 {
71        self.flags
72    }
73
74    /// Set the socket id of the sender.
75    pub fn new(sid: Sid) -> Self {
76        Self {
77            sid: Some(sid),
78            ..Default::default()
79        }
80    }
81    /// Create a new broadcast options from a remote socket data.
82    pub fn new_remote(data: &RemoteSocketData) -> Self {
83        Self {
84            sid: Some(data.id),
85            server_id: Some(data.server_id),
86            ..Default::default()
87        }
88    }
89
90    /// Check if the selected options are local to the current server.
91    #[inline]
92    pub fn is_local(&self, uid: Uid) -> bool {
93        let target_sock_is_local = !self.has_flag(BroadcastFlags::Broadcast)
94            && self.server_id == Some(uid)
95            && self.rooms.is_empty()
96            && self.sid.is_some();
97        self.has_flag(BroadcastFlags::Local) || target_sock_is_local
98    }
99}
100
101/// A trait for types that can be used as a room parameter.
102///
103/// [`String`], [`Vec<String>`], [`Vec<&str>`], [`&'static str`](str) and const arrays are implemented by default.
104pub trait RoomParam: Send + 'static {
105    /// The type of the iterator returned by `into_room_iter`.
106    type IntoIter: Iterator<Item = Room>;
107
108    /// Convert `self` into an iterator of rooms.
109    fn into_room_iter(self) -> Self::IntoIter;
110}
111
112impl RoomParam for Room {
113    type IntoIter = std::iter::Once<Room>;
114    #[inline(always)]
115    fn into_room_iter(self) -> Self::IntoIter {
116        std::iter::once(self)
117    }
118}
119impl RoomParam for String {
120    type IntoIter = std::iter::Once<Room>;
121    #[inline(always)]
122    fn into_room_iter(self) -> Self::IntoIter {
123        std::iter::once(Cow::Owned(self))
124    }
125}
126impl RoomParam for Vec<String> {
127    type IntoIter = std::iter::Map<std::vec::IntoIter<String>, fn(String) -> Room>;
128    #[inline(always)]
129    fn into_room_iter(self) -> Self::IntoIter {
130        self.into_iter().map(Cow::Owned)
131    }
132}
133impl RoomParam for Vec<&'static str> {
134    type IntoIter = std::iter::Map<std::vec::IntoIter<&'static str>, fn(&'static str) -> Room>;
135    #[inline(always)]
136    fn into_room_iter(self) -> Self::IntoIter {
137        self.into_iter().map(Cow::Borrowed)
138    }
139}
140
141impl RoomParam for Vec<Room> {
142    type IntoIter = std::vec::IntoIter<Room>;
143    #[inline(always)]
144    fn into_room_iter(self) -> Self::IntoIter {
145        self.into_iter()
146    }
147}
148impl RoomParam for &'static str {
149    type IntoIter = std::iter::Once<Room>;
150    #[inline(always)]
151    fn into_room_iter(self) -> Self::IntoIter {
152        std::iter::once(Cow::Borrowed(self))
153    }
154}
155impl<const COUNT: usize> RoomParam for [&'static str; COUNT] {
156    type IntoIter =
157        std::iter::Map<std::array::IntoIter<&'static str, COUNT>, fn(&'static str) -> Room>;
158
159    #[inline(always)]
160    fn into_room_iter(self) -> Self::IntoIter {
161        self.into_iter().map(Cow::Borrowed)
162    }
163}
164impl RoomParam for &'static [&'static str] {
165    type IntoIter =
166        std::iter::Map<std::slice::Iter<'static, &'static str>, fn(&'static &'static str) -> Room>;
167
168    #[inline(always)]
169    fn into_room_iter(self) -> Self::IntoIter {
170        self.iter().map(|i| Cow::Borrowed(*i))
171    }
172}
173impl<const COUNT: usize> RoomParam for [String; COUNT] {
174    type IntoIter = std::iter::Map<std::array::IntoIter<String, COUNT>, fn(String) -> Room>;
175    #[inline(always)]
176    fn into_room_iter(self) -> Self::IntoIter {
177        self.into_iter().map(Cow::Owned)
178    }
179}
180impl RoomParam for Sid {
181    type IntoIter = std::iter::Once<Room>;
182    #[inline(always)]
183    fn into_room_iter(self) -> Self::IntoIter {
184        std::iter::once(Cow::Owned(self.to_string()))
185    }
186}
187
/// An item yielded by the ack stream: the id of the acknowledging socket and its ack result.
pub type AckStreamItem<E> = (Sid, Result<Value, E>);
/// The [`SocketEmitter`] will be implemented by the socketioxide library.
/// It is simply used as an abstraction to allow the adapter to communicate
/// with the socket server without the need to depend on the socketioxide lib.
pub trait SocketEmitter: Send + Sync + 'static {
    /// An error that can occur when sending data with an acknowledgment.
    type AckError: StdError + Send + Serialize + DeserializeOwned + 'static;
    /// A stream that emits the acknowledgments of multiple sockets.
    type AckStream: Stream<Item = AckStreamItem<Self::AckError>> + FusedStream + Send + 'static;

    /// Get all the socket ids in the namespace for which `filter` returns `true`.
    fn get_all_sids(&self, filter: impl Fn(&Sid) -> bool) -> Vec<Sid>;
    /// Get the socket data that match the list of socket ids.
    fn get_remote_sockets(&self, sids: BroadcastIter<'_>) -> Vec<RemoteSocketData>;
    /// Send already-encoded data to the list of socket ids.
    fn send_many(&self, sids: BroadcastIter<'_>, data: Value) -> Result<(), Vec<SocketError>>;
    /// Send data to the list of socket ids and get a stream of acks and the number of expected acks.
    fn send_many_with_ack(
        &self,
        sids: BroadcastIter<'_>,
        packet: Packet,
        timeout: Option<Duration>,
    ) -> (Self::AckStream, u32);
    /// Disconnect all the sockets in the list.
    /// TODO: take a [`BroadcastIter`]. Currently it is impossible because it may create deadlocks
    /// with Adapter::del_all call.
    // NOTE: a `BroadcastIter` borrows the local adapter's rooms map under its read lock, while
    // `CoreLocalAdapter::del_all` takes the write lock on that same map.
    fn disconnect_many(&self, sids: Vec<Sid>) -> Result<(), Vec<SocketError>>;
    /// Get the path of the namespace.
    fn path(&self) -> &Str;
    /// Get the parser of the namespace.
    fn parser(&self) -> impl Parse;
    /// Get the unique server id.
    fn server_id(&self) -> Uid;
}
223
/// For static namespaces, the init response will be managed by the user.
/// However, for dynamic namespaces, the socket.io client will manage the response.
/// As it does not know the type of the response, the spawnable trait is used to spawn the
/// response without the client having to know its type.
pub trait Spawnable {
    /// Spawn the response. Implementors should spawn the future with `tokio::spawn` if it is an async function.
    /// They should also print a `tracing::error` log in case of an error.
    fn spawn(self);
}
/// The unit type has nothing to spawn: `spawn` is a no-op.
impl Spawnable for () {
    fn spawn(self) {}
}
236
/// A trait to add a "defined" bound to adapter types.
/// This allows the socket.io library to implement functions for a *defined* adapter
/// rather than for a generic `A: Adapter`.
///
/// This is useful to force the user to handle the potential init response type [`CoreAdapter::InitRes`].
pub trait DefinedAdapter {}
243
/// An adapter is responsible for managing the state of the namespace.
/// This adapter can be implemented to share the state between multiple servers.
///
/// A [`CoreLocalAdapter`] instance will be given when constructing this type. It allows
/// you to manipulate local sockets (emitting, fetching data, broadcasting).
///
/// The default method implementations only act on the local server: they run the local
/// operation synchronously when the method is called and return an already-resolved future.
pub trait CoreAdapter<E: SocketEmitter>: Sized + Send + Sync + 'static {
    /// An error that can occur when using the adapter.
    type Error: StdError + Into<AdapterError> + Send + 'static;
    /// A shared state between all the namespace [`CoreAdapter`]s.
    /// This can be used to share a connection, for example.
    type State: Send + Sync + 'static;
    /// A stream that emits the acknowledgments of multiple sockets.
    type AckStream: Stream<Item = AckStreamItem<E::AckError>> + FusedStream + Send + 'static;
    /// A named result type for the initialization of the adapter.
    type InitRes: Spawnable + Send;

    /// Creates a new adapter with the given state and local adapter.
    ///
    /// The state is used to share a common state between all your adapters, e.g. a connection to a remote system.
    /// The local adapter is used to manipulate the local sockets.
    fn new(state: &Self::State, local: CoreLocalAdapter<E>) -> Self;

    /// Initializes the adapter. The `on_success` callback should be called when the adapter is ready.
    fn init(self: Arc<Self>, on_success: impl FnOnce() + Send + 'static) -> Self::InitRes;

    /// Closes the adapter. The default implementation does nothing and succeeds.
    fn close(&self) -> impl Future<Output = Result<(), Self::Error>> + Send {
        future::ready(Ok(()))
    }

    /// Returns the number of servers. The default implementation reports a single server.
    fn server_count(&self) -> impl Future<Output = Result<u16, Self::Error>> + Send {
        future::ready(Ok(1))
    }

    /// Broadcasts the packet to the sockets that match the [`BroadcastOptions`].
    fn broadcast(
        &self,
        packet: Packet,
        opts: BroadcastOptions,
    ) -> impl Future<Output = Result<(), BroadcastError>> + Send {
        // The local broadcast runs now, at call time; the future only carries its result.
        future::ready(
            self.get_local()
                .broadcast(packet, opts)
                .map_err(BroadcastError::from),
        )
    }

    /// Broadcasts the packet to the sockets that match the [`BroadcastOptions`]
    /// and returns a stream of ack responses.
    ///
    /// This method does not have a default implementation because associated types cannot have
    /// defaults (so `Self::AckStream` cannot default to the local emitter's ack stream).
    /// <https://github.com/rust-lang/rust/issues/29661>
    fn broadcast_with_ack(
        &self,
        packet: Packet,
        opts: BroadcastOptions,
        timeout: Option<Duration>,
    ) -> impl Future<Output = Result<Self::AckStream, Self::Error>> + Send;

    /// Adds the sockets that match the [`BroadcastOptions`] to the rooms.
    fn add_sockets(
        &self,
        opts: BroadcastOptions,
        rooms: impl RoomParam,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
        // Applied eagerly, before the (already-resolved) future is returned.
        self.get_local().add_sockets(opts, rooms);
        future::ready(Ok(()))
    }

    /// Removes the sockets that match the [`BroadcastOptions`] from the rooms.
    fn del_sockets(
        &self,
        opts: BroadcastOptions,
        rooms: impl RoomParam,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
        // Applied eagerly, before the (already-resolved) future is returned.
        self.get_local().del_sockets(opts, rooms);
        future::ready(Ok(()))
    }

    /// Disconnects the sockets that match the [`BroadcastOptions`].
    fn disconnect_socket(
        &self,
        opts: BroadcastOptions,
    ) -> impl Future<Output = Result<(), BroadcastError>> + Send {
        // Local socket errors are reported through the `BroadcastError::Socket` variant.
        future::ready(
            self.get_local()
                .disconnect_socket(opts)
                .map_err(BroadcastError::Socket),
        )
    }

    /// Fetches the rooms that match the [`BroadcastOptions`].
    fn rooms(
        &self,
        opts: BroadcastOptions,
    ) -> impl Future<Output = Result<Vec<Room>, Self::Error>> + Send {
        future::ready(Ok(self.get_local().rooms(opts).into_iter().collect()))
    }

    /// Fetches remote sockets that match the [`BroadcastOptions`].
    fn fetch_sockets(
        &self,
        opts: BroadcastOptions,
    ) -> impl Future<Output = Result<Vec<RemoteSocketData>, Self::Error>> + Send {
        future::ready(Ok(self.get_local().fetch_sockets(opts)))
    }

    /// Returns the local adapter. Used to enable default behaviors.
    fn get_local(&self) -> &CoreLocalAdapter<E>;

    //TODO: implement
    // fn server_side_emit(&self, packet: Packet, opts: BroadcastOptions) -> Result<u64, Error>;
    // fn persist_session(&self, sid: i64);
    // fn restore_session(&self, sid: i64) -> Session;
}
360
361/// The default adapter. Store the state in memory.
362pub struct CoreLocalAdapter<E> {
363    rooms: RwLock<HashMap<Room, HashSet<Sid>>>,
364    sockets: RwLock<HashMap<Sid, HashSet<Room>>>,
365    emitter: E,
366}
367
368impl<E: SocketEmitter> CoreLocalAdapter<E> {
369    /// Create a new local adapter with the given sockets interface.
370    pub fn new(emitter: E) -> Self {
371        Self {
372            rooms: RwLock::new(HashMap::new()),
373            sockets: RwLock::new(HashMap::new()),
374            emitter,
375        }
376    }
377
378    /// Clears all the rooms and sockets.
379    pub fn close(&self) {
380        let mut rooms = self.rooms.write().unwrap();
381        rooms.clear();
382        rooms.shrink_to_fit();
383    }
384
385    /// Adds the socket to all the rooms.
386    pub fn add_all(&self, sid: Sid, rooms: impl RoomParam) {
387        let mut rooms_map = self.rooms.write().unwrap();
388        let mut socket_map = self.sockets.write().unwrap();
389        for room in rooms.into_room_iter() {
390            rooms_map.entry(room.clone()).or_default().insert(sid);
391            socket_map.entry(sid).or_default().insert(room);
392        }
393    }
394
395    /// Removes the socket from the rooms.
396    pub fn del(&self, sid: Sid, rooms: impl RoomParam) {
397        let mut rooms_map = self.rooms.write().unwrap();
398        let mut socket_map = self.sockets.write().unwrap();
399        for room in rooms.into_room_iter() {
400            remove_and_clean_entry(rooms_map.entry(room.clone()), &sid, || {
401                socket_map.entry(sid).and_modify(|r| {
402                    r.remove(&room);
403                });
404            });
405        }
406    }
407
408    /// Removes the socket from all the rooms.
409    pub fn del_all(&self, sid: Sid) {
410        let mut rooms_map = self.rooms.write().unwrap();
411        if let Some(rooms) = self.sockets.write().unwrap().remove(&sid) {
412            for room in rooms {
413                remove_and_clean_entry(rooms_map.entry(room.clone()), &sid, || ());
414            }
415        }
416    }
417
418    /// Broadcasts the packet to the sockets that match the [`BroadcastOptions`].
419    pub fn broadcast(
420        &self,
421        packet: Packet,
422        opts: BroadcastOptions,
423    ) -> Result<(), Vec<SocketError>> {
424        let room_map = self.rooms.read().unwrap();
425        let sids = self.apply_opts(&opts, &room_map);
426
427        if sids.is_empty() {
428            return Ok(());
429        }
430
431        let data = self.emitter.parser().encode(packet);
432        self.emitter.send_many(sids, data)
433    }
434
435    /// Broadcasts the packet to the sockets that match the [`BroadcastOptions`] and return a stream of ack responses.
436    /// Also returns the number of local expected aknowledgements to know when to stop waiting.
437    pub fn broadcast_with_ack(
438        &self,
439        packet: Packet,
440        opts: BroadcastOptions,
441        timeout: Option<Duration>,
442    ) -> (E::AckStream, u32) {
443        let room_map = self.rooms.read().unwrap();
444        let sids = self.apply_opts(&opts, &room_map);
445        // We cannot pre-serialize the packet because we need to change the ack id.
446        self.emitter.send_many_with_ack(sids, packet, timeout)
447    }
448
449    /// Returns the sockets ids that match the [`BroadcastOptions`].
450    pub fn sockets(&self, opts: BroadcastOptions) -> Vec<Sid> {
451        self.apply_opts(&opts, &self.rooms.read().unwrap())
452            .collect()
453    }
454
455    /// Returns the sockets ids that match the [`BroadcastOptions`].
456    pub fn fetch_sockets(&self, opts: BroadcastOptions) -> Vec<RemoteSocketData> {
457        let rooms = self.rooms.read().unwrap();
458        let sids = self.apply_opts(&opts, &rooms);
459        self.emitter.get_remote_sockets(sids)
460    }
461
462    /// Returns the rooms of the socket.
463    pub fn socket_rooms(&self, sid: Sid) -> HashSet<Room> {
464        self.sockets
465            .read()
466            .unwrap()
467            .get(&sid)
468            .cloned()
469            .unwrap_or_default()
470    }
471
472    /// Adds the sockets that match the [`BroadcastOptions`] to the rooms.
473    pub fn add_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) {
474        let rooms: Vec<Room> = rooms.into_room_iter().collect();
475        let mut room_map = self.rooms.write().unwrap();
476        let mut socket_map = self.sockets.write().unwrap();
477        // Here we have to collect sids, because we are going to modify the rooms map.
478        let sids = self.apply_opts(&opts, &room_map).collect::<Vec<_>>();
479        for sid in &sids {
480            let entry = socket_map.entry(*sid).or_default();
481            for room in &rooms {
482                entry.insert(room.clone());
483            }
484        }
485        for room in rooms {
486            let entry = room_map.entry(room).or_default();
487            for sid in &sids {
488                entry.insert(*sid);
489            }
490        }
491    }
492
493    /// Removes the sockets that match the [`BroadcastOptions`] from the rooms.
494    pub fn del_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) {
495        let rooms: Vec<Room> = rooms.into_room_iter().collect();
496        let mut rooms_map = self.rooms.write().unwrap();
497        let mut socket_map = self.sockets.write().unwrap();
498        let sids = self.apply_opts(&opts, &rooms_map).collect::<Vec<_>>();
499        for room in rooms {
500            for sid in &sids {
501                remove_and_clean_entry(socket_map.entry(*sid), &room, || ());
502                remove_and_clean_entry(rooms_map.entry(room.clone()), sid, || ());
503            }
504        }
505    }
506
507    /// Disconnects the sockets that match the [`BroadcastOptions`].
508    pub fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), Vec<SocketError>> {
509        let sids = self
510            .apply_opts(&opts, &self.rooms.read().unwrap())
511            .collect();
512        self.emitter.disconnect_many(sids)
513    }
514
515    /// Returns all the matching rooms
516    pub fn rooms(&self, opts: BroadcastOptions) -> HashSet<Room> {
517        let rooms = self.rooms.read().unwrap();
518        let sockets = self.sockets.read().unwrap();
519        let sids = self.apply_opts(&opts, &rooms);
520        sids.filter_map(|id| sockets.get(&id))
521            .flatten()
522            .cloned()
523            .collect()
524    }
525
526    /// Get the namespace path.
527    pub fn path(&self) -> &Str {
528        self.emitter.path()
529    }
530
531    /// Get the parser of the namespace.
532    pub fn parser(&self) -> impl Parse + '_ {
533        self.emitter.parser()
534    }
535    /// Get the unique server identifier
536    pub fn server_id(&self) -> Uid {
537        self.emitter.server_id()
538    }
539}
540
541/// The default broadcast iterator.
542/// Extract, flatten and filter a list of sid from a room list
543struct BroadcastRooms<'a> {
544    rooms: slice::Iter<'a, Room>,
545    rooms_map: &'a HashMap<Room, HashSet<Sid>>,
546    except: HashSet<Sid>,
547    flatten_iter: Option<hash_set::Iter<'a, Sid>>,
548}
549impl<'a> BroadcastRooms<'a> {
550    fn new(
551        rooms: &'a [Room],
552        rooms_map: &'a HashMap<Room, HashSet<Sid>>,
553        except: HashSet<Sid>,
554    ) -> Self {
555        BroadcastRooms {
556            rooms: rooms.iter(),
557            rooms_map,
558            except,
559            flatten_iter: None,
560        }
561    }
562}
563impl Iterator for BroadcastRooms<'_> {
564    type Item = Sid;
565    fn next(&mut self) -> Option<Self::Item> {
566        loop {
567            match self.flatten_iter.as_mut().and_then(Iterator::next) {
568                Some(sid) if !self.except.contains(sid) => return Some(*sid),
569                Some(_) => continue,
570                None => self.flatten_iter = None,
571            }
572
573            let room = self.rooms.next()?;
574            self.flatten_iter = self.rooms_map.get(room).map(HashSet::iter);
575        }
576    }
577}
578
579impl<E: SocketEmitter> CoreLocalAdapter<E> {
580    /// Applies the given `opts` and return the sockets that match.
581    fn apply_opts<'a>(
582        &self,
583        opts: &'a BroadcastOptions,
584        rooms: &'a HashMap<Room, HashSet<Sid>>,
585    ) -> BroadcastIter<'a> {
586        let is_broadcast = opts.has_flag(BroadcastFlags::Broadcast);
587
588        let mut except = get_except_sids(&opts.except, rooms);
589        // In case of broadcast flag + if the sender is set,
590        // we should not broadcast to it.
591        if is_broadcast && opts.sid.is_some() {
592            except.insert(opts.sid.unwrap());
593        }
594
595        if !opts.rooms.is_empty() {
596            let iter = BroadcastRooms::new(&opts.rooms, rooms, except);
597            InnerBroadcastIter::BroadcastRooms(iter).into()
598        } else if is_broadcast {
599            let sids = self.emitter.get_all_sids(|id| !except.contains(id));
600            InnerBroadcastIter::GlobalBroadcast(sids.into_iter()).into()
601        } else if let Some(id) = opts.sid {
602            InnerBroadcastIter::Single(id).into()
603        } else {
604            InnerBroadcastIter::None.into()
605        }
606    }
607}
608
609#[inline]
610fn get_except_sids(except: &[Room], rooms: &HashMap<Room, HashSet<Sid>>) -> HashSet<Sid> {
611    let mut except_sids = HashSet::new();
612    for room in except {
613        if let Some(sockets) = rooms.get(room) {
614            except_sids.extend(sockets);
615        }
616    }
617    except_sids
618}
619
/// Removes `el` from the set stored in `entry`, and removes the whole map entry once that set
/// is empty.
///
/// `cleanup` is called only when the entry is occupied, whether or not `el` was actually in
/// the set. It is called after the removal. A vacant entry is left untouched and `cleanup`
/// is not called.
#[inline]
fn remove_and_clean_entry<K, T: Hash + Eq>(
    entry: hash_map::Entry<'_, K, HashSet<T>>,
    el: &T,
    cleanup: impl FnOnce(),
) {
    //TODO: use hashmap raw entry when stabilized to avoid entry clone.
    // https://github.com/rust-lang/rust/issues/56167
    let hash_map::Entry::Occupied(mut occupied) = entry else {
        return;
    };
    let set = occupied.get_mut();
    set.remove(el);
    if set.is_empty() {
        occupied.remove();
    }
    cleanup();
}
641
642/// An iterator that yields the socket ids that match the broadcast options.
643/// Used with the [`SocketEmitter`] interface.
644pub struct BroadcastIter<'a> {
645    inner: InnerBroadcastIter<'a>,
646}
647enum InnerBroadcastIter<'a> {
648    BroadcastRooms(BroadcastRooms<'a>),
649    GlobalBroadcast(<Vec<Sid> as IntoIterator>::IntoIter),
650    Single(Sid),
651    None,
652}
653impl BroadcastIter<'_> {
654    fn is_empty(&self) -> bool {
655        matches!(self.inner, InnerBroadcastIter::None)
656    }
657}
658impl<'a> From<InnerBroadcastIter<'a>> for BroadcastIter<'a> {
659    fn from(inner: InnerBroadcastIter<'a>) -> Self {
660        BroadcastIter { inner }
661    }
662}
663
664impl Iterator for BroadcastIter<'_> {
665    type Item = Sid;
666
667    #[inline(always)]
668    fn next(&mut self) -> Option<Self::Item> {
669        self.inner.next()
670    }
671}
672impl Iterator for InnerBroadcastIter<'_> {
673    type Item = Sid;
674
675    fn next(&mut self) -> Option<Self::Item> {
676        match self {
677            InnerBroadcastIter::BroadcastRooms(inner) => inner.next(),
678            InnerBroadcastIter::GlobalBroadcast(inner) => inner.next(),
679            InnerBroadcastIter::Single(sid) => {
680                let sid = *sid;
681                *self = InnerBroadcastIter::None;
682                Some(sid)
683            }
684            InnerBroadcastIter::None => None,
685        }
686    }
687}
688
/// Represents the data of a remote socket.
///
/// Targeted broadcast options for this socket can be built with [`BroadcastOptions::new_remote`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Default, Clone)]
pub struct RemoteSocketData {
    /// The id of the remote socket.
    pub id: Sid,
    /// The id of the server this socket is connected to.
    pub server_id: Uid,
    /// The path of the namespace this socket is connected to.
    pub ns: Str,
}
699
700#[cfg(test)]
701mod test {
702
703    use smallvec::smallvec;
704    use std::{
705        array,
706        pin::Pin,
707        task::{Context, Poll},
708    };
709
710    use super::*;
711
712    struct StubSockets {
713        sockets: HashSet<Sid>,
714        path: Str,
715    }
716    impl StubSockets {
717        fn new(sockets: &[Sid]) -> Self {
718            let sockets = HashSet::from_iter(sockets.iter().copied());
719            Self {
720                sockets,
721                path: Str::from("/"),
722            }
723        }
724    }
725
726    struct StubAckStream;
727    impl Stream for StubAckStream {
728        type Item = (Sid, Result<Value, StubError>);
729        fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
730            Poll::Ready(None)
731        }
732    }
733    impl FusedStream for StubAckStream {
734        fn is_terminated(&self) -> bool {
735            true
736        }
737    }
738    #[derive(Debug, Serialize, Deserialize)]
739    struct StubError;
740    impl std::fmt::Display for StubError {
741        fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
742            Ok(())
743        }
744    }
745    impl std::error::Error for StubError {}
746
747    impl SocketEmitter for StubSockets {
748        type AckError = StubError;
749        type AckStream = StubAckStream;
750        fn get_all_sids(&self, filter: impl Fn(&Sid) -> bool) -> Vec<Sid> {
751            self.sockets
752                .iter()
753                .copied()
754                .filter(|id| filter(id))
755                .collect()
756        }
757
758        fn get_remote_sockets(&self, sids: BroadcastIter<'_>) -> Vec<RemoteSocketData> {
759            sids.map(|id| RemoteSocketData {
760                id,
761                server_id: Uid::ZERO,
762                ns: self.path.clone(),
763            })
764            .collect()
765        }
766
767        fn send_many(&self, _: BroadcastIter<'_>, _: Value) -> Result<(), Vec<SocketError>> {
768            Ok(())
769        }
770
771        fn send_many_with_ack(
772            &self,
773            _: BroadcastIter<'_>,
774            _: Packet,
775            _: Option<Duration>,
776        ) -> (Self::AckStream, u32) {
777            (StubAckStream, 0)
778        }
779
780        fn disconnect_many(&self, _: Vec<Sid>) -> Result<(), Vec<SocketError>> {
781            Ok(())
782        }
783
784        fn path(&self) -> &Str {
785            &self.path
786        }
787        fn parser(&self) -> impl Parse {
788            crate::parser::test::StubParser
789        }
790        fn server_id(&self) -> Uid {
791            Uid::ZERO
792        }
793    }
794
795    fn create_adapter<const S: usize>(sockets: [Sid; S]) -> CoreLocalAdapter<StubSockets> {
796        CoreLocalAdapter::new(StubSockets::new(&sockets))
797    }
798
799    #[test]
800    fn add_all() {
801        let socket = Sid::new();
802        let adapter = create_adapter([socket]);
803        adapter.add_all(socket, ["room1", "room2"]);
804        let rooms_map = adapter.rooms.read().unwrap();
805        let socket_map = adapter.sockets.read().unwrap();
806        assert_eq!(rooms_map.len(), 2);
807        assert_eq!(socket_map.len(), 1);
808        assert_eq!(rooms_map.get("room1").unwrap().len(), 1);
809        assert_eq!(rooms_map.get("room2").unwrap().len(), 1);
810
811        let rooms = socket_map.get(&socket).unwrap();
812        assert!(rooms.contains("room1"));
813        assert!(rooms.contains("room2"));
814    }
815
816    #[test]
817    fn del() {
818        let socket = Sid::new();
819        let adapter = create_adapter([socket]);
820        adapter.add_all(socket, ["room1", "room2"]);
821        {
822            let rooms_map = adapter.rooms.read().unwrap();
823            assert_eq!(rooms_map.len(), 2);
824            assert_eq!(rooms_map.get("room1").unwrap().len(), 1);
825            assert_eq!(rooms_map.get("room2").unwrap().len(), 1);
826            let socket_map = adapter.sockets.read().unwrap();
827            let rooms = socket_map.get(&socket).unwrap();
828            assert!(rooms.contains("room1"));
829            assert!(rooms.contains("room2"));
830        }
831        adapter.del(socket, "room1");
832        let rooms_map = adapter.rooms.read().unwrap();
833        let socket_map = adapter.sockets.read().unwrap();
834        assert_eq!(rooms_map.len(), 1);
835        assert!(rooms_map.get("room1").is_none());
836        assert_eq!(rooms_map.get("room2").unwrap().len(), 1);
837        assert_eq!(socket_map.get(&socket).unwrap().len(), 1);
838    }
839    #[test]
840    fn del_all() {
841        let socket = Sid::new();
842        let adapter = create_adapter([socket]);
843        adapter.add_all(socket, ["room1", "room2"]);
844
845        {
846            let rooms_map = adapter.rooms.read().unwrap();
847            assert_eq!(rooms_map.len(), 2);
848            assert_eq!(rooms_map.get("room1").unwrap().len(), 1);
849            assert_eq!(rooms_map.get("room2").unwrap().len(), 1);
850
851            let socket_map = adapter.sockets.read().unwrap();
852            let rooms = socket_map.get(&socket).unwrap();
853            assert!(rooms.contains("room1"));
854            assert!(rooms.contains("room2"));
855        }
856
857        adapter.del_all(socket);
858
859        {
860            let rooms_map = adapter.rooms.read().unwrap();
861            assert_eq!(rooms_map.len(), 0);
862
863            let socket_map = adapter.sockets.read().unwrap();
864            assert!(socket_map.get(&socket).is_none());
865        }
866    }
867
868    #[test]
869    fn socket_room() {
870        let sid1 = Sid::new();
871        let sid2 = Sid::new();
872        let sid3 = Sid::new();
873        let adapter = create_adapter([sid1, sid2, sid3]);
874        adapter.add_all(sid1, ["room1", "room2"]);
875        adapter.add_all(sid2, ["room1"]);
876        adapter.add_all(sid3, ["room2"]);
877        assert!(adapter.socket_rooms(sid1).contains(&Cow::Borrowed("room1")));
878        assert!(adapter.socket_rooms(sid1).contains(&Cow::Borrowed("room2")));
879        assert_eq!(
880            adapter.socket_rooms(sid2).into_iter().collect::<Vec<_>>(),
881            ["room1"]
882        );
883        assert_eq!(
884            adapter.socket_rooms(sid3).into_iter().collect::<Vec<_>>(),
885            ["room2"]
886        );
887    }
888
889    #[test]
890    fn add_socket() {
891        let socket = Sid::new();
892        let adapter = create_adapter([socket]);
893        adapter.add_all(socket, ["room1"]);
894
895        let mut opts = BroadcastOptions::new(socket);
896        opts.rooms = smallvec!["room1".into()];
897        adapter.add_sockets(opts, "room2");
898        let rooms_map = adapter.rooms.read().unwrap();
899
900        assert_eq!(rooms_map.len(), 2);
901        assert!(rooms_map.get("room1").unwrap().contains(&socket));
902        assert!(rooms_map.get("room2").unwrap().contains(&socket));
903    }
904
905    #[test]
906    fn del_socket() {
907        let socket = Sid::new();
908        let adapter = create_adapter([socket]);
909        adapter.add_all(socket, ["room1"]);
910
911        let mut opts = BroadcastOptions::new(socket);
912        opts.rooms = smallvec!["room1".into()];
913        adapter.add_sockets(opts, "room2");
914
915        {
916            let rooms_map = adapter.rooms.read().unwrap();
917
918            assert_eq!(rooms_map.len(), 2);
919            assert!(rooms_map.get("room1").unwrap().contains(&socket));
920            assert!(rooms_map.get("room2").unwrap().contains(&socket));
921        }
922
923        let mut opts = BroadcastOptions::new(socket);
924        opts.rooms = smallvec!["room1".into()];
925        adapter.del_sockets(opts, "room2");
926
927        {
928            let rooms_map = adapter.rooms.read().unwrap();
929
930            assert_eq!(rooms_map.len(), 1);
931            assert!(rooms_map.get("room1").unwrap().contains(&socket));
932            assert!(rooms_map.get("room2").is_none());
933        }
934    }
935
936    #[test]
937    fn sockets() {
938        let socket0 = Sid::new();
939        let socket1 = Sid::new();
940        let socket2 = Sid::new();
941        let adapter = create_adapter([socket0, socket1, socket2]);
942        adapter.add_all(socket0, ["room1", "room2"]);
943        adapter.add_all(socket1, ["room1", "room3"]);
944        adapter.add_all(socket2, ["room2", "room3"]);
945
946        let mut opts = BroadcastOptions {
947            rooms: smallvec!["room1".into()],
948            ..Default::default()
949        };
950        let sockets = adapter.sockets(opts.clone());
951        assert_eq!(sockets.len(), 2);
952        assert!(sockets.contains(&socket0));
953        assert!(sockets.contains(&socket1));
954
955        opts.rooms = smallvec!["room2".into()];
956        let sockets = adapter.sockets(opts.clone());
957        assert_eq!(sockets.len(), 2);
958        assert!(sockets.contains(&socket0));
959        assert!(sockets.contains(&socket2));
960
961        opts.rooms = smallvec!["room3".into()];
962        let sockets = adapter.sockets(opts.clone());
963        assert_eq!(sockets.len(), 2);
964        assert!(sockets.contains(&socket1));
965        assert!(sockets.contains(&socket2));
966    }
967
968    #[test]
969    fn disconnect_socket() {
970        let socket0 = Sid::new();
971        let socket1 = Sid::new();
972        let socket2 = Sid::new();
973        let adapter = create_adapter([socket0, socket1, socket2]);
974        adapter.add_all(socket0, ["room1", "room2", "room4"]);
975        adapter.add_all(socket1, ["room1", "room3", "room5"]);
976        adapter.add_all(socket2, ["room2", "room3", "room6"]);
977
978        let mut opts = BroadcastOptions::new(socket0);
979        opts.rooms = smallvec!["room5".into()];
980        adapter.disconnect_socket(opts).unwrap();
981
982        let mut opts = BroadcastOptions::default();
983        opts.rooms.push("room2".into());
984        let sockets = adapter.sockets(opts.clone());
985        assert_eq!(sockets.len(), 2);
986        assert!(sockets.contains(&socket2));
987        assert!(sockets.contains(&socket0));
988    }
989    #[test]
990    fn disconnect_empty_opts() {
991        let adapter = create_adapter([]);
992        let opts = BroadcastOptions::default();
993        adapter.disconnect_socket(opts).unwrap();
994    }
995    #[test]
996    fn rooms() {
997        let socket0 = Sid::new();
998        let socket1 = Sid::new();
999        let socket2 = Sid::new();
1000        let adapter = create_adapter([socket0, socket1, socket2]);
1001        adapter.add_all(socket0, ["room1", "room2", "room4"]);
1002        adapter.add_all(socket1, ["room1", "room3", "room5"]);
1003        adapter.add_all(socket2, ["room2", "room3", "room6"]);
1004
1005        let mut opts = BroadcastOptions::new(socket0);
1006        opts.rooms = smallvec!["room5".into()];
1007        opts.add_flag(BroadcastFlags::Broadcast);
1008        let rooms = adapter.rooms(opts);
1009        assert_eq!(rooms.len(), 3);
1010        assert!(rooms.contains(&Cow::Borrowed("room1")));
1011        assert!(rooms.contains(&Cow::Borrowed("room3")));
1012        assert!(rooms.contains(&Cow::Borrowed("room5")));
1013
1014        let mut opts = BroadcastOptions::default();
1015        opts.rooms.push("room2".into());
1016        let rooms = adapter.rooms(opts.clone());
1017        assert_eq!(rooms.len(), 5);
1018        assert!(rooms.contains(&Cow::Borrowed("room1")));
1019        assert!(rooms.contains(&Cow::Borrowed("room2")));
1020        assert!(rooms.contains(&Cow::Borrowed("room3")));
1021        assert!(rooms.contains(&Cow::Borrowed("room4")));
1022        assert!(rooms.contains(&Cow::Borrowed("room6")));
1023    }
1024
1025    #[test]
1026    fn apply_opts() {
1027        let mut sockets: [Sid; 3] = array::from_fn(|_| Sid::new());
1028        sockets.sort();
1029        let adapter = create_adapter(sockets);
1030
1031        adapter.add_all(sockets[0], ["room1", "room2"]);
1032        adapter.add_all(sockets[1], ["room1", "room3"]);
1033        adapter.add_all(sockets[2], ["room1", "room2", "room3"]);
1034
1035        // socket 2 is the sender
1036        let mut opts = BroadcastOptions::new(sockets[2]);
1037        opts.rooms = smallvec!["room1".into()];
1038        opts.except = smallvec!["room2".into()];
1039        let sids = adapter
1040            .apply_opts(&opts, &adapter.rooms.read().unwrap())
1041            .collect::<Vec<_>>();
1042        assert_eq!(sids, [sockets[1]]);
1043
1044        let mut opts = BroadcastOptions::new(sockets[2]);
1045        opts.add_flag(BroadcastFlags::Broadcast);
1046        let mut sids = adapter
1047            .apply_opts(&opts, &adapter.rooms.read().unwrap())
1048            .collect::<Vec<_>>();
1049        sids.sort();
1050        assert_eq!(sids, [sockets[0], sockets[1]]);
1051
1052        let mut opts = BroadcastOptions::new(sockets[2]);
1053        opts.add_flag(BroadcastFlags::Broadcast);
1054        opts.except = smallvec!["room2".into()];
1055        let sids = adapter
1056            .apply_opts(&opts, &adapter.rooms.read().unwrap())
1057            .collect::<Vec<_>>();
1058        assert_eq!(sids.len(), 1);
1059
1060        let opts = BroadcastOptions::new(sockets[2]);
1061        let sids = adapter
1062            .apply_opts(&opts, &adapter.rooms.read().unwrap())
1063            .collect::<Vec<_>>();
1064        assert_eq!(sids.len(), 1);
1065        assert_eq!(sids[0], sockets[2]);
1066
1067        let opts = BroadcastOptions::new(Sid::new());
1068        let sids = adapter
1069            .apply_opts(&opts, &adapter.rooms.read().unwrap())
1070            .collect::<Vec<_>>();
1071        assert_eq!(sids.len(), 1);
1072    }
1073
1074    #[test]
1075    fn test_is_local_opts() {
1076        let server_id = Uid::new();
1077        let remote = RemoteSocketData {
1078            id: Sid::new(),
1079            server_id,
1080            ns: "/".into(),
1081        };
1082        let opts = BroadcastOptions::new_remote(&remote);
1083        assert!(opts.is_local(server_id));
1084        assert!(!opts.is_local(Uid::new()));
1085        let opts = BroadcastOptions::new(Sid::new());
1086        assert!(!opts.is_local(Uid::new()));
1087    }
1088}