socketioxide_core/
adapter.rs

1//! The adapter module contains the [`CoreAdapter`] trait and other related types.
2//!
3//! It is used to implement communication between socket.io servers to share messages and state.
4//!
//! The [`CoreLocalAdapter`] provides a local implementation that allows any implementor to apply local
//! operations (`broadcast_with_ack`, `broadcast`, `rooms`, etc.).
7use std::{
8    borrow::Cow,
9    collections::{hash_map, hash_set, HashMap, HashSet},
10    error::Error as StdError,
11    future::{self, Future},
12    hash::Hash,
13    slice,
14    sync::{Arc, RwLock},
15    time::Duration,
16};
17
18use engineioxide::{sid::Sid, Str};
19use futures_core::{FusedStream, Stream};
20use serde::{de::DeserializeOwned, Deserialize, Serialize};
21use smallvec::SmallVec;
22
23use crate::{
24    errors::{AdapterError, BroadcastError, SocketError},
25    packet::Packet,
26    parser::Parse,
27    Uid, Value,
28};
29
/// A room identifier.
///
/// Stored as a [`Cow`] so that `&'static str` room names can be kept borrowed while
/// dynamically built names (e.g. from a `String`) are stored owned.
pub type Room = Cow<'static, str>;
32
/// Flags that can be used to modify the behavior of the broadcast methods.
///
/// Each variant's discriminant is a distinct bit, so several flags can be combined in the
/// `u8` bitfield held by [`BroadcastOptions`] (see [`BroadcastOptions::add_flag`]).
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum BroadcastFlags {
    /// Broadcast only to the current server
    Local = 0x01,
    /// Broadcast to all clients except the sender
    Broadcast = 0x02,
}
41
/// Options that can be used to modify the behavior of the broadcast methods.
// NOTE: field order matters for serialization (derived `Serialize`/`Deserialize`).
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct BroadcastOptions {
    /// The flags to apply to the broadcast, stored as a bitfield of [`BroadcastFlags`] values.
    /// Use [`BroadcastOptions::add_flag`] and [`BroadcastOptions::has_flag`] to access it.
    flags: u8,
    /// The rooms to broadcast to.
    pub rooms: SmallVec<[Room; 4]>,
    /// The rooms to exclude from the broadcast.
    pub except: SmallVec<[Room; 4]>,
    /// The socket id of the sender.
    pub sid: Option<Sid>,
    /// The target server id can be used to optimize the broadcast.
    /// More specifically when we use broadcasting to apply a single action on a remote socket.
    /// We know the server_id of the remote socket, so we can send the action directly to the server.
    pub server_id: Option<Uid>,
}
58impl BroadcastOptions {
59    /// Add any flags to the options.
60    pub fn add_flag(&mut self, flag: BroadcastFlags) {
61        self.flags |= flag as u8;
62    }
63    /// Check if the options have a flag.
64    pub fn has_flag(&self, flag: BroadcastFlags) -> bool {
65        self.flags & flag as u8 == flag as u8
66    }
67
68    /// get the flags of the options.
69    pub fn flags(&self) -> u8 {
70        self.flags
71    }
72
73    /// Set the socket id of the sender.
74    pub fn new(sid: Sid) -> Self {
75        Self {
76            sid: Some(sid),
77            ..Default::default()
78        }
79    }
80    /// Create a new broadcast options from a remote socket data.
81    pub fn new_remote(data: &RemoteSocketData) -> Self {
82        Self {
83            sid: Some(data.id),
84            server_id: Some(data.server_id),
85            ..Default::default()
86        }
87    }
88}
89
/// A trait for types that can be used as a room parameter.
///
/// [`String`], [`Vec<String>`], [`Vec<&str>`], [`&'static str`](str) and const arrays
/// (of `&'static str` or `String`) are implemented by default. Implementations are also
/// provided for [`Room`], `Vec<Room>`, `&'static [&'static str]` and [`Sid`] (converted to
/// its string representation).
pub trait RoomParam: Send + 'static {
    /// The type of the iterator returned by `into_room_iter`.
    type IntoIter: Iterator<Item = Room>;

    /// Convert `self` into an iterator of rooms.
    fn into_room_iter(self) -> Self::IntoIter;
}
100
101impl RoomParam for Room {
102    type IntoIter = std::iter::Once<Room>;
103    #[inline(always)]
104    fn into_room_iter(self) -> Self::IntoIter {
105        std::iter::once(self)
106    }
107}
108impl RoomParam for String {
109    type IntoIter = std::iter::Once<Room>;
110    #[inline(always)]
111    fn into_room_iter(self) -> Self::IntoIter {
112        std::iter::once(Cow::Owned(self))
113    }
114}
115impl RoomParam for Vec<String> {
116    type IntoIter = std::iter::Map<std::vec::IntoIter<String>, fn(String) -> Room>;
117    #[inline(always)]
118    fn into_room_iter(self) -> Self::IntoIter {
119        self.into_iter().map(Cow::Owned)
120    }
121}
122impl RoomParam for Vec<&'static str> {
123    type IntoIter = std::iter::Map<std::vec::IntoIter<&'static str>, fn(&'static str) -> Room>;
124    #[inline(always)]
125    fn into_room_iter(self) -> Self::IntoIter {
126        self.into_iter().map(Cow::Borrowed)
127    }
128}
129
130impl RoomParam for Vec<Room> {
131    type IntoIter = std::vec::IntoIter<Room>;
132    #[inline(always)]
133    fn into_room_iter(self) -> Self::IntoIter {
134        self.into_iter()
135    }
136}
137impl RoomParam for &'static str {
138    type IntoIter = std::iter::Once<Room>;
139    #[inline(always)]
140    fn into_room_iter(self) -> Self::IntoIter {
141        std::iter::once(Cow::Borrowed(self))
142    }
143}
144impl<const COUNT: usize> RoomParam for [&'static str; COUNT] {
145    type IntoIter =
146        std::iter::Map<std::array::IntoIter<&'static str, COUNT>, fn(&'static str) -> Room>;
147
148    #[inline(always)]
149    fn into_room_iter(self) -> Self::IntoIter {
150        self.into_iter().map(Cow::Borrowed)
151    }
152}
153impl RoomParam for &'static [&'static str] {
154    type IntoIter =
155        std::iter::Map<std::slice::Iter<'static, &'static str>, fn(&'static &'static str) -> Room>;
156
157    #[inline(always)]
158    fn into_room_iter(self) -> Self::IntoIter {
159        self.iter().map(|i| Cow::Borrowed(*i))
160    }
161}
162impl<const COUNT: usize> RoomParam for [String; COUNT] {
163    type IntoIter = std::iter::Map<std::array::IntoIter<String, COUNT>, fn(String) -> Room>;
164    #[inline(always)]
165    fn into_room_iter(self) -> Self::IntoIter {
166        self.into_iter().map(Cow::Owned)
167    }
168}
169impl RoomParam for Sid {
170    type IntoIter = std::iter::Once<Room>;
171    #[inline(always)]
172    fn into_room_iter(self) -> Self::IntoIter {
173        std::iter::once(Cow::Owned(self.to_string()))
174    }
175}
176
/// An item yielded by an ack stream: the id of the acknowledging socket and either its
/// acknowledgment value or an error.
pub type AckStreamItem<E> = (Sid, Result<Value, E>);
/// The [`SocketEmitter`] will be implemented by the socketioxide library.
/// It is simply used as an abstraction to allow the adapter to communicate
/// with the socket server without the need to depend on the socketioxide lib.
pub trait SocketEmitter: Send + Sync + 'static {
    /// The error type carried by the items of the ack stream (see [`AckStreamItem`]).
    type AckError: StdError + Send + Serialize + DeserializeOwned + 'static;
    /// A stream that emits the acknowledgments of multiple sockets.
    type AckStream: Stream<Item = AckStreamItem<Self::AckError>> + FusedStream + Send + 'static;

    /// Get all the socket ids in the namespace for which `filter` returns `true`.
    fn get_all_sids(&self, filter: impl Fn(&Sid) -> bool) -> Vec<Sid>;
    /// Get the socket data that match the list of socket ids.
    fn get_remote_sockets(&self, sids: BroadcastIter<'_>) -> Vec<RemoteSocketData>;
    /// Send data to the list of socket ids.
    ///
    /// When called from [`CoreLocalAdapter::broadcast`], `data` is the packet already
    /// encoded with the namespace parser ([`SocketEmitter::parser`]).
    fn send_many(&self, sids: BroadcastIter<'_>, data: Value) -> Result<(), Vec<SocketError>>;
    /// Send data to the list of socket ids and get a stream of acks and the number of expected acks.
    ///
    /// The packet is passed unencoded (unlike [`SocketEmitter::send_many`]).
    fn send_many_with_ack(
        &self,
        sids: BroadcastIter<'_>,
        packet: Packet,
        timeout: Option<Duration>,
    ) -> (Self::AckStream, u32);
    /// Disconnect all the sockets in the list.
    /// TODO: take a [`BroadcastIter`]. Currently it is impossible because it may create deadlocks
    /// with Adapter::del_all call.
    fn disconnect_many(&self, sids: Vec<Sid>) -> Result<(), Vec<SocketError>>;
    /// Get the path of the namespace.
    fn path(&self) -> &Str;
    /// Get the parser of the namespace.
    fn parser(&self) -> impl Parse;
    /// Get the unique server id.
    fn server_id(&self) -> Uid;
}
212
/// For static namespaces, the init response will be managed by the user.
/// However, for dynamic namespaces, the socket.io client will manage the response.
/// As it does not know the type of the response, the [`Spawnable`] trait is used to spawn
/// the response without the client having to know its type.
pub trait Spawnable {
    /// Spawn the response. Implementors should spawn the future with `tokio::spawn` if it is an async function.
    /// They should also print a `tracing::error` log in case of an error.
    fn spawn(self);
}
/// The unit type has nothing to spawn: `spawn` is a no-op.
impl Spawnable for () {
    fn spawn(self) {}
}
225
/// A trait to add a "defined" bound to adapter types.
/// This allows the socket io library to implement functions for a *defined* adapter
/// and not a generic `A: Adapter`.
///
/// This is useful to force the user to handle the potential init response type [`CoreAdapter::InitRes`].
pub trait DefinedAdapter {}
232
/// An adapter is responsible for managing the state of the namespace.
/// This adapter can be implemented to share the state between multiple servers.
///
/// A [`CoreLocalAdapter`] instance will be given when constructing this type, it will allow
/// you to manipulate local sockets (emitting, fetching data, broadcasting).
///
/// Apart from `new`, `init`, `broadcast_with_ack` and `get_local`, every method has a default
/// implementation suited to a single server: it either forwards to the local adapter returned
/// by [`CoreAdapter::get_local`] or returns a trivial value (`close`, `server_count`).
pub trait CoreAdapter<E: SocketEmitter>: Sized + Send + Sync + 'static {
    /// An error that can occur when using the adapter.
    type Error: StdError + Into<AdapterError> + Send + 'static;
    /// A shared state between all the namespace [`CoreAdapter`].
    /// This can be used to share a connection for example.
    type State: Send + Sync + 'static;
    /// A stream that emits the acknowledgments of multiple sockets.
    type AckStream: Stream<Item = AckStreamItem<E::AckError>> + FusedStream + Send + 'static;
    /// A named result type for the initialization of the adapter.
    type InitRes: Spawnable + Send;

    /// Creates a new adapter with the given state and local adapter.
    ///
    /// The state is used to share a common state between all your adapters. E.G. a connection to a remote system.
    /// The local adapter is used to manipulate the local sockets.
    fn new(state: &Self::State, local: CoreLocalAdapter<E>) -> Self;

    /// Initializes the adapter. The `on_success` callback should be called once the adapter is ready.
    fn init(self: Arc<Self>, on_success: impl FnOnce() + Send + 'static) -> Self::InitRes;

    /// Closes the adapter.
    ///
    /// The default implementation does nothing and resolves to `Ok(())`.
    fn close(&self) -> impl Future<Output = Result<(), Self::Error>> + Send {
        future::ready(Ok(()))
    }

    /// Returns the number of servers.
    ///
    /// The default implementation always resolves to `Ok(1)`.
    fn server_count(&self) -> impl Future<Output = Result<u16, Self::Error>> + Send {
        future::ready(Ok(1))
    }

    /// Broadcasts the packet to the sockets that match the [`BroadcastOptions`].
    ///
    /// The default implementation only broadcasts to the local sockets and converts their
    /// errors into a [`BroadcastError`].
    fn broadcast(
        &self,
        packet: Packet,
        opts: BroadcastOptions,
    ) -> impl Future<Output = Result<(), BroadcastError>> + Send {
        future::ready(
            self.get_local()
                .broadcast(packet, opts)
                .map_err(BroadcastError::from),
        )
    }

    /// Broadcasts the packet to the sockets that match the [`BroadcastOptions`]
    /// and return a stream of ack responses.
    ///
    /// This method does not have a default implementation because associated types cannot
    /// have default values on stable Rust, so [`CoreAdapter::AckStream`] cannot be given a
    /// default type. <https://github.com/rust-lang/rust/issues/29661>
    fn broadcast_with_ack(
        &self,
        packet: Packet,
        opts: BroadcastOptions,
        timeout: Option<Duration>,
    ) -> impl Future<Output = Result<Self::AckStream, Self::Error>> + Send;

    /// Adds the sockets that match the [`BroadcastOptions`] to the rooms.
    ///
    /// The default implementation only updates the local adapter and resolves to `Ok(())`.
    fn add_sockets(
        &self,
        opts: BroadcastOptions,
        rooms: impl RoomParam,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
        self.get_local().add_sockets(opts, rooms);
        future::ready(Ok(()))
    }

    /// Removes the sockets that match the [`BroadcastOptions`] from the rooms.
    ///
    /// The default implementation only updates the local adapter and resolves to `Ok(())`.
    fn del_sockets(
        &self,
        opts: BroadcastOptions,
        rooms: impl RoomParam,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
        self.get_local().del_sockets(opts, rooms);
        future::ready(Ok(()))
    }

    /// Disconnects the sockets that match the [`BroadcastOptions`].
    ///
    /// The default implementation only disconnects local sockets, wrapping their errors in
    /// [`BroadcastError::Socket`].
    fn disconnect_socket(
        &self,
        opts: BroadcastOptions,
    ) -> impl Future<Output = Result<(), BroadcastError>> + Send {
        future::ready(
            self.get_local()
                .disconnect_socket(opts)
                .map_err(BroadcastError::Socket),
        )
    }

    /// Fetches rooms that match the [`BroadcastOptions`]
    ///
    /// The default implementation returns the rooms of the matching local sockets.
    fn rooms(
        &self,
        opts: BroadcastOptions,
    ) -> impl Future<Output = Result<Vec<Room>, Self::Error>> + Send {
        future::ready(Ok(self.get_local().rooms(opts).into_iter().collect()))
    }

    /// Fetches remote sockets that match the [`BroadcastOptions`].
    ///
    /// The default implementation returns the data of the matching local sockets.
    fn fetch_sockets(
        &self,
        opts: BroadcastOptions,
    ) -> impl Future<Output = Result<Vec<RemoteSocketData>, Self::Error>> + Send {
        future::ready(Ok(self.get_local().fetch_sockets(opts)))
    }

    /// Returns the local adapter. Used to enable default behaviors.
    fn get_local(&self) -> &CoreLocalAdapter<E>;

    //TODO: implement
    // fn server_side_emit(&self, packet: Packet, opts: BroadcastOptions) -> Result<u64, Error>;
    // fn persist_session(&self, sid: i64);
    // fn restore_session(&self, sid: i64) -> Session;
}
349
/// The default adapter. Stores the state in memory.
///
/// Two indexes are maintained: `rooms` (room → ids of its sockets) and `sockets`
/// (socket id → rooms it belongs to).
pub struct CoreLocalAdapter<E> {
    /// Room → ids of the sockets in that room.
    rooms: RwLock<HashMap<Room, HashSet<Sid>>>,
    /// Socket id → rooms that socket belongs to.
    sockets: RwLock<HashMap<Sid, HashSet<Room>>>,
    /// Interface used to reach the actual sockets (send, fetch, disconnect).
    emitter: E,
}
356
357impl<E: SocketEmitter> CoreLocalAdapter<E> {
358    /// Create a new local adapter with the given sockets interface.
359    pub fn new(emitter: E) -> Self {
360        Self {
361            rooms: RwLock::new(HashMap::new()),
362            sockets: RwLock::new(HashMap::new()),
363            emitter,
364        }
365    }
366
367    /// Clears all the rooms and sockets.
368    pub fn close(&self) {
369        let mut rooms = self.rooms.write().unwrap();
370        rooms.clear();
371        rooms.shrink_to_fit();
372    }
373
374    /// Adds the socket to all the rooms.
375    pub fn add_all(&self, sid: Sid, rooms: impl RoomParam) {
376        let mut rooms_map = self.rooms.write().unwrap();
377        let mut socket_map = self.sockets.write().unwrap();
378        for room in rooms.into_room_iter() {
379            rooms_map.entry(room.clone()).or_default().insert(sid);
380            socket_map.entry(sid).or_default().insert(room);
381        }
382    }
383
384    /// Removes the socket from the rooms.
385    pub fn del(&self, sid: Sid, rooms: impl RoomParam) {
386        let mut rooms_map = self.rooms.write().unwrap();
387        let mut socket_map = self.sockets.write().unwrap();
388        for room in rooms.into_room_iter() {
389            remove_and_clean_entry(rooms_map.entry(room.clone()), &sid, || {
390                socket_map.entry(sid).and_modify(|r| {
391                    r.remove(&room);
392                });
393            });
394        }
395    }
396
397    /// Removes the socket from all the rooms.
398    pub fn del_all(&self, sid: Sid) {
399        let mut rooms_map = self.rooms.write().unwrap();
400        if let Some(rooms) = self.sockets.write().unwrap().remove(&sid) {
401            for room in rooms {
402                remove_and_clean_entry(rooms_map.entry(room.clone()), &sid, || ());
403            }
404        }
405    }
406
407    /// Broadcasts the packet to the sockets that match the [`BroadcastOptions`].
408    pub fn broadcast(
409        &self,
410        packet: Packet,
411        opts: BroadcastOptions,
412    ) -> Result<(), Vec<SocketError>> {
413        let room_map = self.rooms.read().unwrap();
414        let sids = self.apply_opts(&opts, &room_map);
415
416        if sids.is_empty() {
417            return Ok(());
418        }
419
420        let data = self.emitter.parser().encode(packet);
421        self.emitter.send_many(sids, data)
422    }
423
424    /// Broadcasts the packet to the sockets that match the [`BroadcastOptions`] and return a stream of ack responses.
425    /// Also returns the number of local expected aknowledgements to know when to stop waiting.
426    pub fn broadcast_with_ack(
427        &self,
428        packet: Packet,
429        opts: BroadcastOptions,
430        timeout: Option<Duration>,
431    ) -> (E::AckStream, u32) {
432        let room_map = self.rooms.read().unwrap();
433        let sids = self.apply_opts(&opts, &room_map);
434        // We cannot pre-serialize the packet because we need to change the ack id.
435        self.emitter.send_many_with_ack(sids, packet, timeout)
436    }
437
438    /// Returns the sockets ids that match the [`BroadcastOptions`].
439    pub fn sockets(&self, opts: BroadcastOptions) -> Vec<Sid> {
440        self.apply_opts(&opts, &self.rooms.read().unwrap())
441            .collect()
442    }
443
444    /// Returns the sockets ids that match the [`BroadcastOptions`].
445    pub fn fetch_sockets(&self, opts: BroadcastOptions) -> Vec<RemoteSocketData> {
446        let rooms = self.rooms.read().unwrap();
447        let sids = self.apply_opts(&opts, &rooms);
448        self.emitter.get_remote_sockets(sids)
449    }
450
451    /// Returns the rooms of the socket.
452    pub fn socket_rooms(&self, sid: Sid) -> HashSet<Room> {
453        self.sockets
454            .read()
455            .unwrap()
456            .get(&sid)
457            .cloned()
458            .unwrap_or_default()
459    }
460
461    /// Adds the sockets that match the [`BroadcastOptions`] to the rooms.
462    pub fn add_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) {
463        let rooms: Vec<Room> = rooms.into_room_iter().collect();
464        let mut room_map = self.rooms.write().unwrap();
465        let mut socket_map = self.sockets.write().unwrap();
466        // Here we have to collect sids, because we are going to modify the rooms map.
467        let sids = self.apply_opts(&opts, &room_map).collect::<Vec<_>>();
468        for sid in &sids {
469            let entry = socket_map.entry(*sid).or_default();
470            for room in &rooms {
471                entry.insert(room.clone());
472            }
473        }
474        for room in rooms {
475            let entry = room_map.entry(room).or_default();
476            for sid in &sids {
477                entry.insert(*sid);
478            }
479        }
480    }
481
482    /// Removes the sockets that match the [`BroadcastOptions`] from the rooms.
483    pub fn del_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) {
484        let rooms: Vec<Room> = rooms.into_room_iter().collect();
485        let mut rooms_map = self.rooms.write().unwrap();
486        let mut socket_map = self.sockets.write().unwrap();
487        let sids = self.apply_opts(&opts, &rooms_map).collect::<Vec<_>>();
488        for room in rooms {
489            for sid in &sids {
490                remove_and_clean_entry(socket_map.entry(*sid), &room, || ());
491                remove_and_clean_entry(rooms_map.entry(room.clone()), sid, || ());
492            }
493        }
494    }
495
496    /// Disconnects the sockets that match the [`BroadcastOptions`].
497    pub fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), Vec<SocketError>> {
498        let sids = self
499            .apply_opts(&opts, &self.rooms.read().unwrap())
500            .collect();
501        self.emitter.disconnect_many(sids)
502    }
503
504    /// Returns all the matching rooms
505    pub fn rooms(&self, opts: BroadcastOptions) -> HashSet<Room> {
506        let rooms = self.rooms.read().unwrap();
507        let sockets = self.sockets.read().unwrap();
508        let sids = self.apply_opts(&opts, &rooms);
509        sids.filter_map(|id| sockets.get(&id))
510            .flatten()
511            .cloned()
512            .collect()
513    }
514
515    /// Get the namespace path.
516    pub fn path(&self) -> &Str {
517        self.emitter.path()
518    }
519
520    /// Get the parser of the namespace.
521    pub fn parser(&self) -> impl Parse + '_ {
522        self.emitter.parser()
523    }
524    /// Get the unique server identifier
525    pub fn server_id(&self) -> Uid {
526        self.emitter.server_id()
527    }
528}
529
530/// The default broadcast iterator.
531/// Extract, flatten and filter a list of sid from a room list
532struct BroadcastRooms<'a> {
533    rooms: slice::Iter<'a, Room>,
534    rooms_map: &'a HashMap<Room, HashSet<Sid>>,
535    except: HashSet<Sid>,
536    flatten_iter: Option<hash_set::Iter<'a, Sid>>,
537}
538impl<'a> BroadcastRooms<'a> {
539    fn new(
540        rooms: &'a [Room],
541        rooms_map: &'a HashMap<Room, HashSet<Sid>>,
542        except: HashSet<Sid>,
543    ) -> Self {
544        BroadcastRooms {
545            rooms: rooms.iter(),
546            rooms_map,
547            except,
548            flatten_iter: None,
549        }
550    }
551}
552impl Iterator for BroadcastRooms<'_> {
553    type Item = Sid;
554    fn next(&mut self) -> Option<Self::Item> {
555        loop {
556            match self.flatten_iter.as_mut().and_then(Iterator::next) {
557                Some(sid) if !self.except.contains(sid) => return Some(*sid),
558                Some(_) => continue,
559                None => self.flatten_iter = None,
560            }
561
562            let room = self.rooms.next()?;
563            self.flatten_iter = self.rooms_map.get(room).map(HashSet::iter);
564        }
565    }
566}
567
568impl<E: SocketEmitter> CoreLocalAdapter<E> {
569    /// Applies the given `opts` and return the sockets that match.
570    fn apply_opts<'a>(
571        &self,
572        opts: &'a BroadcastOptions,
573        rooms: &'a HashMap<Room, HashSet<Sid>>,
574    ) -> BroadcastIter<'a> {
575        let is_broadcast = opts.has_flag(BroadcastFlags::Broadcast);
576
577        let mut except = get_except_sids(&opts.except, rooms);
578        // In case of broadcast flag + if the sender is set,
579        // we should not broadcast to it.
580        if is_broadcast && opts.sid.is_some() {
581            except.insert(opts.sid.unwrap());
582        }
583
584        if !opts.rooms.is_empty() {
585            let iter = BroadcastRooms::new(&opts.rooms, rooms, except);
586            InnerBroadcastIter::BroadcastRooms(iter).into()
587        } else if is_broadcast {
588            let sids = self.emitter.get_all_sids(|id| !except.contains(id));
589            InnerBroadcastIter::GlobalBroadcast(sids.into_iter()).into()
590        } else if let Some(id) = opts.sid {
591            InnerBroadcastIter::Single(id).into()
592        } else {
593            InnerBroadcastIter::None.into()
594        }
595    }
596}
597
598#[inline]
599fn get_except_sids(except: &[Room], rooms: &HashMap<Room, HashSet<Sid>>) -> HashSet<Sid> {
600    let mut except_sids = HashSet::new();
601    for room in except {
602        if let Some(sockets) = rooms.get(room) {
603            except_sids.extend(sockets);
604        }
605    }
606    except_sids
607}
608
/// Removes `el` from the set stored in `entry`, then drops the map entry if its set has
/// become empty.
///
/// `cleanup` runs whenever the entry is occupied, whether or not `el` was actually in the
/// set; for a vacant entry nothing happens and `cleanup` is not called.
#[inline]
fn remove_and_clean_entry<K, T: Hash + Eq>(
    entry: hash_map::Entry<'_, K, HashSet<T>>,
    el: &T,
    cleanup: impl FnOnce(),
) {
    //TODO: use hashmap raw entry when stabilized to avoid entry clone.
    // https://github.com/rust-lang/rust/issues/56167
    let hash_map::Entry::Occupied(mut occupied) = entry else {
        return;
    };
    let set = occupied.get_mut();
    set.remove(el);
    if set.is_empty() {
        occupied.remove();
    }
    cleanup();
}
630
631/// An iterator that yields the socket ids that match the broadcast options.
632/// Used with the [`SocketEmitter`] interface.
633pub struct BroadcastIter<'a> {
634    inner: InnerBroadcastIter<'a>,
635}
636enum InnerBroadcastIter<'a> {
637    BroadcastRooms(BroadcastRooms<'a>),
638    GlobalBroadcast(<Vec<Sid> as IntoIterator>::IntoIter),
639    Single(Sid),
640    None,
641}
642impl BroadcastIter<'_> {
643    fn is_empty(&self) -> bool {
644        matches!(self.inner, InnerBroadcastIter::None)
645    }
646}
647impl<'a> From<InnerBroadcastIter<'a>> for BroadcastIter<'a> {
648    fn from(inner: InnerBroadcastIter<'a>) -> Self {
649        BroadcastIter { inner }
650    }
651}
652
653impl Iterator for BroadcastIter<'_> {
654    type Item = Sid;
655
656    #[inline(always)]
657    fn next(&mut self) -> Option<Self::Item> {
658        self.inner.next()
659    }
660}
661impl Iterator for InnerBroadcastIter<'_> {
662    type Item = Sid;
663
664    fn next(&mut self) -> Option<Self::Item> {
665        match self {
666            InnerBroadcastIter::BroadcastRooms(inner) => inner.next(),
667            InnerBroadcastIter::GlobalBroadcast(inner) => inner.next(),
668            InnerBroadcastIter::Single(sid) => {
669                let sid = *sid;
670                *self = InnerBroadcastIter::None;
671                Some(sid)
672            }
673            InnerBroadcastIter::None => None,
674        }
675    }
676}
677
/// Represents the data of a remote socket.
///
/// [`CoreLocalAdapter::fetch_sockets`] also returns this type for the local sockets.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Default, Clone)]
pub struct RemoteSocketData {
    /// The id of the remote socket.
    pub id: Sid,
    /// The server id this socket is connected to.
    pub server_id: Uid,
    /// The namespace this socket is connected to.
    pub ns: Str,
}
688
689#[cfg(test)]
690mod test {
691
692    use smallvec::smallvec;
693    use std::{
694        array,
695        pin::Pin,
696        task::{Context, Poll},
697    };
698
699    use super::*;
700
701    struct StubSockets {
702        sockets: HashSet<Sid>,
703        path: Str,
704    }
705    impl StubSockets {
706        fn new(sockets: &[Sid]) -> Self {
707            let sockets = HashSet::from_iter(sockets.iter().copied());
708            Self {
709                sockets,
710                path: Str::from("/"),
711            }
712        }
713    }
714
715    struct StubAckStream;
716    impl Stream for StubAckStream {
717        type Item = (Sid, Result<Value, StubError>);
718        fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
719            Poll::Ready(None)
720        }
721    }
722    impl FusedStream for StubAckStream {
723        fn is_terminated(&self) -> bool {
724            true
725        }
726    }
727    #[derive(Debug, Serialize, Deserialize)]
728    struct StubError;
729    impl std::fmt::Display for StubError {
730        fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
731            Ok(())
732        }
733    }
734    impl std::error::Error for StubError {}
735
736    impl SocketEmitter for StubSockets {
737        type AckError = StubError;
738        type AckStream = StubAckStream;
739        fn get_all_sids(&self, filter: impl Fn(&Sid) -> bool) -> Vec<Sid> {
740            self.sockets
741                .iter()
742                .copied()
743                .filter(|id| filter(id))
744                .collect()
745        }
746
747        fn get_remote_sockets(&self, sids: BroadcastIter<'_>) -> Vec<RemoteSocketData> {
748            sids.map(|id| RemoteSocketData {
749                id,
750                server_id: Uid::ZERO,
751                ns: self.path.clone(),
752            })
753            .collect()
754        }
755
756        fn send_many(&self, _: BroadcastIter<'_>, _: Value) -> Result<(), Vec<SocketError>> {
757            Ok(())
758        }
759
760        fn send_many_with_ack(
761            &self,
762            _: BroadcastIter<'_>,
763            _: Packet,
764            _: Option<Duration>,
765        ) -> (Self::AckStream, u32) {
766            (StubAckStream, 0)
767        }
768
769        fn disconnect_many(&self, _: Vec<Sid>) -> Result<(), Vec<SocketError>> {
770            Ok(())
771        }
772
773        fn path(&self) -> &Str {
774            &self.path
775        }
776        fn parser(&self) -> impl Parse {
777            crate::parser::test::StubParser
778        }
779        fn server_id(&self) -> Uid {
780            Uid::ZERO
781        }
782    }
783
784    fn create_adapter<const S: usize>(sockets: [Sid; S]) -> CoreLocalAdapter<StubSockets> {
785        CoreLocalAdapter::new(StubSockets::new(&sockets))
786    }
787
788    #[test]
789    fn add_all() {
790        let socket = Sid::new();
791        let adapter = create_adapter([socket]);
792        adapter.add_all(socket, ["room1", "room2"]);
793        let rooms_map = adapter.rooms.read().unwrap();
794        let socket_map = adapter.sockets.read().unwrap();
795        assert_eq!(rooms_map.len(), 2);
796        assert_eq!(socket_map.len(), 1);
797        assert_eq!(rooms_map.get("room1").unwrap().len(), 1);
798        assert_eq!(rooms_map.get("room2").unwrap().len(), 1);
799
800        let rooms = socket_map.get(&socket).unwrap();
801        assert!(rooms.contains("room1"));
802        assert!(rooms.contains("room2"));
803    }
804
805    #[test]
806    fn del() {
807        let socket = Sid::new();
808        let adapter = create_adapter([socket]);
809        adapter.add_all(socket, ["room1", "room2"]);
810        {
811            let rooms_map = adapter.rooms.read().unwrap();
812            assert_eq!(rooms_map.len(), 2);
813            assert_eq!(rooms_map.get("room1").unwrap().len(), 1);
814            assert_eq!(rooms_map.get("room2").unwrap().len(), 1);
815            let socket_map = adapter.sockets.read().unwrap();
816            let rooms = socket_map.get(&socket).unwrap();
817            assert!(rooms.contains("room1"));
818            assert!(rooms.contains("room2"));
819        }
820        adapter.del(socket, "room1");
821        let rooms_map = adapter.rooms.read().unwrap();
822        let socket_map = adapter.sockets.read().unwrap();
823        assert_eq!(rooms_map.len(), 1);
824        assert!(rooms_map.get("room1").is_none());
825        assert_eq!(rooms_map.get("room2").unwrap().len(), 1);
826        assert_eq!(socket_map.get(&socket).unwrap().len(), 1);
827    }
828    #[test]
829    fn del_all() {
830        let socket = Sid::new();
831        let adapter = create_adapter([socket]);
832        adapter.add_all(socket, ["room1", "room2"]);
833
834        {
835            let rooms_map = adapter.rooms.read().unwrap();
836            assert_eq!(rooms_map.len(), 2);
837            assert_eq!(rooms_map.get("room1").unwrap().len(), 1);
838            assert_eq!(rooms_map.get("room2").unwrap().len(), 1);
839
840            let socket_map = adapter.sockets.read().unwrap();
841            let rooms = socket_map.get(&socket).unwrap();
842            assert!(rooms.contains("room1"));
843            assert!(rooms.contains("room2"));
844        }
845
846        adapter.del_all(socket);
847
848        {
849            let rooms_map = adapter.rooms.read().unwrap();
850            assert_eq!(rooms_map.len(), 0);
851
852            let socket_map = adapter.sockets.read().unwrap();
853            assert!(socket_map.get(&socket).is_none());
854        }
855    }
856
857    #[test]
858    fn socket_room() {
859        let sid1 = Sid::new();
860        let sid2 = Sid::new();
861        let sid3 = Sid::new();
862        let adapter = create_adapter([sid1, sid2, sid3]);
863        adapter.add_all(sid1, ["room1", "room2"]);
864        adapter.add_all(sid2, ["room1"]);
865        adapter.add_all(sid3, ["room2"]);
866        assert!(adapter.socket_rooms(sid1).contains(&Cow::Borrowed("room1")));
867        assert!(adapter.socket_rooms(sid1).contains(&Cow::Borrowed("room2")));
868        assert_eq!(
869            adapter.socket_rooms(sid2).into_iter().collect::<Vec<_>>(),
870            ["room1"]
871        );
872        assert_eq!(
873            adapter.socket_rooms(sid3).into_iter().collect::<Vec<_>>(),
874            ["room2"]
875        );
876    }
877
878    #[test]
879    fn add_socket() {
880        let socket = Sid::new();
881        let adapter = create_adapter([socket]);
882        adapter.add_all(socket, ["room1"]);
883
884        let mut opts = BroadcastOptions::new(socket);
885        opts.rooms = smallvec!["room1".into()];
886        adapter.add_sockets(opts, "room2");
887        let rooms_map = adapter.rooms.read().unwrap();
888
889        assert_eq!(rooms_map.len(), 2);
890        assert!(rooms_map.get("room1").unwrap().contains(&socket));
891        assert!(rooms_map.get("room2").unwrap().contains(&socket));
892    }
893
894    #[test]
895    fn del_socket() {
896        let socket = Sid::new();
897        let adapter = create_adapter([socket]);
898        adapter.add_all(socket, ["room1"]);
899
900        let mut opts = BroadcastOptions::new(socket);
901        opts.rooms = smallvec!["room1".into()];
902        adapter.add_sockets(opts, "room2");
903
904        {
905            let rooms_map = adapter.rooms.read().unwrap();
906
907            assert_eq!(rooms_map.len(), 2);
908            assert!(rooms_map.get("room1").unwrap().contains(&socket));
909            assert!(rooms_map.get("room2").unwrap().contains(&socket));
910        }
911
912        let mut opts = BroadcastOptions::new(socket);
913        opts.rooms = smallvec!["room1".into()];
914        adapter.del_sockets(opts, "room2");
915
916        {
917            let rooms_map = adapter.rooms.read().unwrap();
918
919            assert_eq!(rooms_map.len(), 1);
920            assert!(rooms_map.get("room1").unwrap().contains(&socket));
921            assert!(rooms_map.get("room2").is_none());
922        }
923    }
924
925    #[test]
926    fn sockets() {
927        let socket0 = Sid::new();
928        let socket1 = Sid::new();
929        let socket2 = Sid::new();
930        let adapter = create_adapter([socket0, socket1, socket2]);
931        adapter.add_all(socket0, ["room1", "room2"]);
932        adapter.add_all(socket1, ["room1", "room3"]);
933        adapter.add_all(socket2, ["room2", "room3"]);
934
935        let mut opts = BroadcastOptions {
936            rooms: smallvec!["room1".into()],
937            ..Default::default()
938        };
939        let sockets = adapter.sockets(opts.clone());
940        assert_eq!(sockets.len(), 2);
941        assert!(sockets.contains(&socket0));
942        assert!(sockets.contains(&socket1));
943
944        opts.rooms = smallvec!["room2".into()];
945        let sockets = adapter.sockets(opts.clone());
946        assert_eq!(sockets.len(), 2);
947        assert!(sockets.contains(&socket0));
948        assert!(sockets.contains(&socket2));
949
950        opts.rooms = smallvec!["room3".into()];
951        let sockets = adapter.sockets(opts.clone());
952        assert_eq!(sockets.len(), 2);
953        assert!(sockets.contains(&socket1));
954        assert!(sockets.contains(&socket2));
955    }
956
957    #[test]
958    fn disconnect_socket() {
959        let socket0 = Sid::new();
960        let socket1 = Sid::new();
961        let socket2 = Sid::new();
962        let adapter = create_adapter([socket0, socket1, socket2]);
963        adapter.add_all(socket0, ["room1", "room2", "room4"]);
964        adapter.add_all(socket1, ["room1", "room3", "room5"]);
965        adapter.add_all(socket2, ["room2", "room3", "room6"]);
966
967        let mut opts = BroadcastOptions::new(socket0);
968        opts.rooms = smallvec!["room5".into()];
969        adapter.disconnect_socket(opts).unwrap();
970
971        let mut opts = BroadcastOptions::default();
972        opts.rooms.push("room2".into());
973        let sockets = adapter.sockets(opts.clone());
974        assert_eq!(sockets.len(), 2);
975        assert!(sockets.contains(&socket2));
976        assert!(sockets.contains(&socket0));
977    }
978    #[test]
979    fn disconnect_empty_opts() {
980        let adapter = create_adapter([]);
981        let opts = BroadcastOptions::default();
982        adapter.disconnect_socket(opts).unwrap();
983    }
984    #[test]
985    fn rooms() {
986        let socket0 = Sid::new();
987        let socket1 = Sid::new();
988        let socket2 = Sid::new();
989        let adapter = create_adapter([socket0, socket1, socket2]);
990        adapter.add_all(socket0, ["room1", "room2", "room4"]);
991        adapter.add_all(socket1, ["room1", "room3", "room5"]);
992        adapter.add_all(socket2, ["room2", "room3", "room6"]);
993
994        let mut opts = BroadcastOptions::new(socket0);
995        opts.rooms = smallvec!["room5".into()];
996        opts.add_flag(BroadcastFlags::Broadcast);
997        let rooms = adapter.rooms(opts);
998        assert_eq!(rooms.len(), 3);
999        assert!(rooms.contains(&Cow::Borrowed("room1")));
1000        assert!(rooms.contains(&Cow::Borrowed("room3")));
1001        assert!(rooms.contains(&Cow::Borrowed("room5")));
1002
1003        let mut opts = BroadcastOptions::default();
1004        opts.rooms.push("room2".into());
1005        let rooms = adapter.rooms(opts.clone());
1006        assert_eq!(rooms.len(), 5);
1007        assert!(rooms.contains(&Cow::Borrowed("room1")));
1008        assert!(rooms.contains(&Cow::Borrowed("room2")));
1009        assert!(rooms.contains(&Cow::Borrowed("room3")));
1010        assert!(rooms.contains(&Cow::Borrowed("room4")));
1011        assert!(rooms.contains(&Cow::Borrowed("room6")));
1012    }
1013
1014    #[test]
1015    fn apply_opts() {
1016        let mut sockets: [Sid; 3] = array::from_fn(|_| Sid::new());
1017        sockets.sort();
1018        let adapter = create_adapter(sockets);
1019
1020        adapter.add_all(sockets[0], ["room1", "room2"]);
1021        adapter.add_all(sockets[1], ["room1", "room3"]);
1022        adapter.add_all(sockets[2], ["room1", "room2", "room3"]);
1023
1024        // socket 2 is the sender
1025        let mut opts = BroadcastOptions::new(sockets[2]);
1026        opts.rooms = smallvec!["room1".into()];
1027        opts.except = smallvec!["room2".into()];
1028        let sids = adapter
1029            .apply_opts(&opts, &adapter.rooms.read().unwrap())
1030            .collect::<Vec<_>>();
1031        assert_eq!(sids, [sockets[1]]);
1032
1033        let mut opts = BroadcastOptions::new(sockets[2]);
1034        opts.add_flag(BroadcastFlags::Broadcast);
1035        let mut sids = adapter
1036            .apply_opts(&opts, &adapter.rooms.read().unwrap())
1037            .collect::<Vec<_>>();
1038        sids.sort();
1039        assert_eq!(sids, [sockets[0], sockets[1]]);
1040
1041        let mut opts = BroadcastOptions::new(sockets[2]);
1042        opts.add_flag(BroadcastFlags::Broadcast);
1043        opts.except = smallvec!["room2".into()];
1044        let sids = adapter
1045            .apply_opts(&opts, &adapter.rooms.read().unwrap())
1046            .collect::<Vec<_>>();
1047        assert_eq!(sids.len(), 1);
1048
1049        let opts = BroadcastOptions::new(sockets[2]);
1050        let sids = adapter
1051            .apply_opts(&opts, &adapter.rooms.read().unwrap())
1052            .collect::<Vec<_>>();
1053        assert_eq!(sids.len(), 1);
1054        assert_eq!(sids[0], sockets[2]);
1055
1056        let opts = BroadcastOptions::new(Sid::new());
1057        let sids = adapter
1058            .apply_opts(&opts, &adapter.rooms.read().unwrap())
1059            .collect::<Vec<_>>();
1060        assert_eq!(sids.len(), 1);
1061    }
1062}