1use std::{
9 borrow::Cow,
10 collections::{HashMap, HashSet, hash_map, hash_set},
11 error::Error as StdError,
12 future::{self, Future},
13 hash::Hash,
14 slice,
15 sync::{Arc, RwLock},
16 time::Duration,
17};
18
19use engineioxide_core::{Sid, Str};
20use futures_core::{FusedStream, Stream};
21use serde::{Deserialize, Serialize, de::DeserializeOwned};
22use smallvec::SmallVec;
23
24use crate::{Uid, Value, packet::Packet, parser::Parse};
25use errors::{AdapterError, BroadcastError, SocketError};
26
27pub mod errors;
28#[cfg(feature = "remote-adapter")]
29pub mod remote_packet;
30
/// A room name.
///
/// `Cow<'static, str>` lets `&'static str` names be used without allocating
/// (`Cow::Borrowed`) while still accepting runtime-built `String`s (`Cow::Owned`).
pub type Room = Cow<'static, str>;
33
/// Flags stored in a [`BroadcastOptions`] bitset.
///
/// Each variant is a distinct bit, so several flags can be combined with
/// [`BroadcastOptions::add_flag`].
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum BroadcastFlags {
    /// Marks the operation as local: [`BroadcastOptions::is_local`] returns `true`
    /// whenever this flag is set.
    Local = 0x01,
    /// Broadcast mode: the socket in `BroadcastOptions::sid` is excluded from the targets
    /// and, when no room is given, every socket known to the emitter is targeted.
    Broadcast = 0x02,
}
42
/// Describes which sockets an adapter operation targets.
///
/// The local adapter resolves targets as follows:
/// - the members of `rooms`, if any room is given;
/// - otherwise every socket, when [`BroadcastFlags::Broadcast`] is set;
/// - otherwise the single socket `sid`, if set.
///
/// Members of the `except` rooms are skipped for room-based and global targets. They are
/// NOT applied to the single-socket case.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct BroadcastOptions {
    /// Bitset of [`BroadcastFlags`]. It is private so it only changes through `add_flag`.
    flags: u8,
    /// Rooms whose members are targeted.
    pub rooms: SmallVec<[Room; 4]>,
    /// Rooms whose members are excluded from the targets.
    pub except: SmallVec<[Room; 4]>,
    /// In broadcast mode, the socket to exclude. Otherwise, when `rooms` is empty, the
    /// single targeted socket.
    pub sid: Option<Sid>,
    /// Server hosting `sid`, when known (set by `new_remote`). It is compared by `is_local`.
    pub server_id: Option<Uid>,
}
59impl BroadcastOptions {
60 pub fn add_flag(&mut self, flag: BroadcastFlags) {
62 self.flags |= flag as u8;
63 }
64 pub fn has_flag(&self, flag: BroadcastFlags) -> bool {
66 self.flags & flag as u8 == flag as u8
67 }
68
69 pub fn flags(&self) -> u8 {
71 self.flags
72 }
73
74 pub fn new(sid: Sid) -> Self {
76 Self {
77 sid: Some(sid),
78 ..Default::default()
79 }
80 }
81 pub fn new_remote(data: &RemoteSocketData) -> Self {
83 Self {
84 sid: Some(data.id),
85 server_id: Some(data.server_id),
86 ..Default::default()
87 }
88 }
89
90 #[inline]
92 pub fn is_local(&self, uid: Uid) -> bool {
93 let target_sock_is_local = !self.has_flag(BroadcastFlags::Broadcast)
94 && self.server_id == Some(uid)
95 && self.rooms.is_empty()
96 && self.sid.is_some();
97 self.has_flag(BroadcastFlags::Local) || target_sock_is_local
98 }
99}
100
101pub trait RoomParam: Send + 'static {
105 type IntoIter: Iterator<Item = Room>;
107
108 fn into_room_iter(self) -> Self::IntoIter;
110}
111
112impl RoomParam for Room {
113 type IntoIter = std::iter::Once<Room>;
114 #[inline(always)]
115 fn into_room_iter(self) -> Self::IntoIter {
116 std::iter::once(self)
117 }
118}
119impl RoomParam for String {
120 type IntoIter = std::iter::Once<Room>;
121 #[inline(always)]
122 fn into_room_iter(self) -> Self::IntoIter {
123 std::iter::once(Cow::Owned(self))
124 }
125}
126impl RoomParam for Vec<String> {
127 type IntoIter = std::iter::Map<std::vec::IntoIter<String>, fn(String) -> Room>;
128 #[inline(always)]
129 fn into_room_iter(self) -> Self::IntoIter {
130 self.into_iter().map(Cow::Owned)
131 }
132}
133impl RoomParam for Vec<&'static str> {
134 type IntoIter = std::iter::Map<std::vec::IntoIter<&'static str>, fn(&'static str) -> Room>;
135 #[inline(always)]
136 fn into_room_iter(self) -> Self::IntoIter {
137 self.into_iter().map(Cow::Borrowed)
138 }
139}
140
141impl RoomParam for Vec<Room> {
142 type IntoIter = std::vec::IntoIter<Room>;
143 #[inline(always)]
144 fn into_room_iter(self) -> Self::IntoIter {
145 self.into_iter()
146 }
147}
148impl RoomParam for &'static str {
149 type IntoIter = std::iter::Once<Room>;
150 #[inline(always)]
151 fn into_room_iter(self) -> Self::IntoIter {
152 std::iter::once(Cow::Borrowed(self))
153 }
154}
155impl<const COUNT: usize> RoomParam for [&'static str; COUNT] {
156 type IntoIter =
157 std::iter::Map<std::array::IntoIter<&'static str, COUNT>, fn(&'static str) -> Room>;
158
159 #[inline(always)]
160 fn into_room_iter(self) -> Self::IntoIter {
161 self.into_iter().map(Cow::Borrowed)
162 }
163}
164impl RoomParam for &'static [&'static str] {
165 type IntoIter =
166 std::iter::Map<std::slice::Iter<'static, &'static str>, fn(&'static &'static str) -> Room>;
167
168 #[inline(always)]
169 fn into_room_iter(self) -> Self::IntoIter {
170 self.iter().map(|i| Cow::Borrowed(*i))
171 }
172}
173impl<const COUNT: usize> RoomParam for [String; COUNT] {
174 type IntoIter = std::iter::Map<std::array::IntoIter<String, COUNT>, fn(String) -> Room>;
175 #[inline(always)]
176 fn into_room_iter(self) -> Self::IntoIter {
177 self.into_iter().map(Cow::Owned)
178 }
179}
180impl RoomParam for Sid {
181 type IntoIter = std::iter::Once<Room>;
182 #[inline(always)]
183 fn into_room_iter(self) -> Self::IntoIter {
184 std::iter::once(Cow::Owned(self.to_string()))
185 }
186}
187
/// Item yielded by acknowledgement streams: a socket id paired with the result of its
/// acknowledgement.
pub type AckStreamItem<E> = (Sid, Result<Value, E>);
/// The socket layer used by [`CoreLocalAdapter`] to reach the actual sockets.
pub trait SocketEmitter: Send + Sync + 'static {
    /// Error reported for an individual acknowledgement.
    type AckError: StdError + Send + Serialize + DeserializeOwned + 'static;
    /// Stream returned by [`SocketEmitter::send_many_with_ack`].
    type AckStream: Stream<Item = AckStreamItem<Self::AckError>> + FusedStream + Send + 'static;

    /// Returns the ids of every socket known to the emitter for which `filter` returns `true`.
    fn get_all_sids(&self, filter: impl Fn(&Sid) -> bool) -> Vec<Sid>;
    /// Returns a [`RemoteSocketData`] for every socket in `sids`.
    fn get_remote_sockets(&self, sids: BroadcastIter<'_>) -> Vec<RemoteSocketData>;
    /// Sends already-encoded `data` to every socket in `sids`.
    ///
    /// The data has been produced by the emitter's parser (see `CoreLocalAdapter::broadcast`).
    /// Per-socket failures are collected in the error vector.
    fn send_many(&self, sids: BroadcastIter<'_>, data: Value) -> Result<(), Vec<SocketError>>;
    /// Sends `packet` to every socket in `sids`, waiting for acknowledgements up to `timeout`.
    ///
    /// NOTE(review): the returned `u32` is presumably the number of sockets the packet was
    /// sent to. Confirm against the implementations.
    fn send_many_with_ack(
        &self,
        sids: BroadcastIter<'_>,
        packet: Packet,
        timeout: Option<Duration>,
    ) -> (Self::AckStream, u32);
    /// Disconnects every socket in `sids`. Per-socket failures are collected in the error
    /// vector.
    fn disconnect_many(&self, sids: Vec<Sid>) -> Result<(), Vec<SocketError>>;
    /// Returns the emitter's path (presumably the namespace path; confirm).
    fn path(&self) -> &Str;
    /// Returns the parser used to encode packets before `send_many`.
    fn parser(&self) -> impl Parse;
    /// Returns the id of the server this emitter belongs to.
    fn server_id(&self) -> Uid;
}
223
/// Value returned by [`CoreAdapter::init`] that can be `spawn`ed.
pub trait Spawnable {
    /// Consumes `self` and starts the work it represents.
    /// NOTE(review): presumably this spawns a background task. Confirm with the implementors.
    fn spawn(self);
}
/// Nothing to run: `spawn` is a no-op.
impl Spawnable for () {
    fn spawn(self) {}
}

/// Marker trait without methods.
/// NOTE(review): its purpose is not visible in this file. Document it once confirmed.
pub trait DefinedAdapter {}
243
/// Adapter interface built around a [`CoreLocalAdapter`].
///
/// Every operation except `new`, `init`, `broadcast_with_ack` and `get_local` has a default
/// that only acts on the local adapter. Those defaults do their work immediately when called
/// and return an already-completed future (`future::ready`). Nothing is deferred until the
/// future is polled.
pub trait CoreAdapter<E: SocketEmitter>: Sized + Send + Sync + 'static {
    /// Error type of the adapter operations; convertible into [`AdapterError`].
    type Error: StdError + Into<AdapterError> + Send + 'static;
    /// Shared state handed to [`CoreAdapter::new`].
    type State: Send + Sync + 'static;
    /// Stream returned by [`CoreAdapter::broadcast_with_ack`].
    type AckStream: Stream<Item = AckStreamItem<E::AckError>> + FusedStream + Send + 'static;
    /// Value returned by [`CoreAdapter::init`].
    type InitRes: Spawnable + Send;

    /// Builds the adapter from the shared `state` and the `local` adapter it wraps.
    fn new(state: &Self::State, local: CoreLocalAdapter<E>) -> Self;

    /// Initializes the adapter.
    /// NOTE(review): presumably `on_success` runs once initialization succeeds. Confirm.
    fn init(self: Arc<Self>, on_success: impl FnOnce() + Send + 'static) -> Self::InitRes;

    /// Closes the adapter. The default does nothing and succeeds.
    fn close(&self) -> impl Future<Output = Result<(), Self::Error>> + Send {
        future::ready(Ok(()))
    }

    /// Returns the number of servers. The default reports `1`.
    fn server_count(&self) -> impl Future<Output = Result<u16, Self::Error>> + Send {
        future::ready(Ok(1))
    }

    /// Sends `packet` to the sockets selected by `opts`. The default only reaches local
    /// sockets.
    fn broadcast(
        &self,
        packet: Packet,
        opts: BroadcastOptions,
    ) -> impl Future<Output = Result<(), BroadcastError>> + Send {
        future::ready(
            self.get_local()
                .broadcast(packet, opts)
                .map_err(BroadcastError::from),
        )
    }

    /// Sends `packet` to the sockets selected by `opts` and returns a stream of their
    /// acknowledgements, bounded by `timeout`. This method has no default.
    fn broadcast_with_ack(
        &self,
        packet: Packet,
        opts: BroadcastOptions,
        timeout: Option<Duration>,
    ) -> impl Future<Output = Result<Self::AckStream, Self::Error>> + Send;

    /// Adds the sockets selected by `opts` to `rooms`. The default only updates the local
    /// adapter.
    fn add_sockets(
        &self,
        opts: BroadcastOptions,
        rooms: impl RoomParam,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
        self.get_local().add_sockets(opts, rooms);
        future::ready(Ok(()))
    }

    /// Removes the sockets selected by `opts` from `rooms`. The default only updates the local
    /// adapter.
    fn del_sockets(
        &self,
        opts: BroadcastOptions,
        rooms: impl RoomParam,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
        self.get_local().del_sockets(opts, rooms);
        future::ready(Ok(()))
    }

    /// Disconnects the sockets selected by `opts`. The default only reaches local sockets.
    fn disconnect_socket(
        &self,
        opts: BroadcastOptions,
    ) -> impl Future<Output = Result<(), BroadcastError>> + Send {
        future::ready(
            self.get_local()
                .disconnect_socket(opts)
                .map_err(BroadcastError::Socket),
        )
    }

    /// Returns the rooms joined by the sockets selected by `opts`, without duplicates and in
    /// no particular order. This is collected from a `HashSet`. The default is local only.
    fn rooms(
        &self,
        opts: BroadcastOptions,
    ) -> impl Future<Output = Result<Vec<Room>, Self::Error>> + Send {
        future::ready(Ok(self.get_local().rooms(opts).into_iter().collect()))
    }

    /// Returns a description of every socket selected by `opts`. The default is local only.
    fn fetch_sockets(
        &self,
        opts: BroadcastOptions,
    ) -> impl Future<Output = Result<Vec<RemoteSocketData>, Self::Error>> + Send {
        future::ready(Ok(self.get_local().fetch_sockets(opts)))
    }

    /// Returns the wrapped local adapter.
    fn get_local(&self) -> &CoreLocalAdapter<E>;

}
360
/// In-process adapter state: a two-way room/socket membership index plus the emitter used to
/// reach the sockets.
///
/// `rooms` and `sockets` are meant to mirror each other. Methods that lock both take `rooms`
/// first and then `sockets`.
pub struct CoreLocalAdapter<E> {
    /// Room -> member sockets.
    rooms: RwLock<HashMap<Room, HashSet<Sid>>>,
    /// Socket -> joined rooms (reverse index of `rooms`).
    sockets: RwLock<HashMap<Sid, HashSet<Room>>>,
    /// Delivers packets and disconnections to the sockets.
    emitter: E,
}
367
368impl<E: SocketEmitter> CoreLocalAdapter<E> {
369 pub fn new(emitter: E) -> Self {
371 Self {
372 rooms: RwLock::new(HashMap::new()),
373 sockets: RwLock::new(HashMap::new()),
374 emitter,
375 }
376 }
377
378 pub fn close(&self) {
380 let mut rooms = self.rooms.write().unwrap();
381 rooms.clear();
382 rooms.shrink_to_fit();
383 }
384
385 pub fn add_all(&self, sid: Sid, rooms: impl RoomParam) {
387 let mut rooms_map = self.rooms.write().unwrap();
388 let mut socket_map = self.sockets.write().unwrap();
389 for room in rooms.into_room_iter() {
390 rooms_map.entry(room.clone()).or_default().insert(sid);
391 socket_map.entry(sid).or_default().insert(room);
392 }
393 }
394
395 pub fn del(&self, sid: Sid, rooms: impl RoomParam) {
397 let mut rooms_map = self.rooms.write().unwrap();
398 let mut socket_map = self.sockets.write().unwrap();
399 for room in rooms.into_room_iter() {
400 remove_and_clean_entry(rooms_map.entry(room.clone()), &sid, || {
401 socket_map.entry(sid).and_modify(|r| {
402 r.remove(&room);
403 });
404 });
405 }
406 }
407
408 pub fn del_all(&self, sid: Sid) {
410 let mut rooms_map = self.rooms.write().unwrap();
411 if let Some(rooms) = self.sockets.write().unwrap().remove(&sid) {
412 for room in rooms {
413 remove_and_clean_entry(rooms_map.entry(room.clone()), &sid, || ());
414 }
415 }
416 }
417
418 pub fn broadcast(
420 &self,
421 packet: Packet,
422 opts: BroadcastOptions,
423 ) -> Result<(), Vec<SocketError>> {
424 let room_map = self.rooms.read().unwrap();
425 let sids = self.apply_opts(&opts, &room_map);
426
427 if sids.is_empty() {
428 return Ok(());
429 }
430
431 let data = self.emitter.parser().encode(packet);
432 self.emitter.send_many(sids, data)
433 }
434
435 pub fn broadcast_with_ack(
438 &self,
439 packet: Packet,
440 opts: BroadcastOptions,
441 timeout: Option<Duration>,
442 ) -> (E::AckStream, u32) {
443 let room_map = self.rooms.read().unwrap();
444 let sids = self.apply_opts(&opts, &room_map);
445 self.emitter.send_many_with_ack(sids, packet, timeout)
447 }
448
449 pub fn sockets(&self, opts: BroadcastOptions) -> Vec<Sid> {
451 self.apply_opts(&opts, &self.rooms.read().unwrap())
452 .collect()
453 }
454
455 pub fn fetch_sockets(&self, opts: BroadcastOptions) -> Vec<RemoteSocketData> {
457 let rooms = self.rooms.read().unwrap();
458 let sids = self.apply_opts(&opts, &rooms);
459 self.emitter.get_remote_sockets(sids)
460 }
461
462 pub fn socket_rooms(&self, sid: Sid) -> HashSet<Room> {
464 self.sockets
465 .read()
466 .unwrap()
467 .get(&sid)
468 .cloned()
469 .unwrap_or_default()
470 }
471
472 pub fn add_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) {
474 let rooms: Vec<Room> = rooms.into_room_iter().collect();
475 let mut room_map = self.rooms.write().unwrap();
476 let mut socket_map = self.sockets.write().unwrap();
477 let sids = self.apply_opts(&opts, &room_map).collect::<Vec<_>>();
479 for sid in &sids {
480 let entry = socket_map.entry(*sid).or_default();
481 for room in &rooms {
482 entry.insert(room.clone());
483 }
484 }
485 for room in rooms {
486 let entry = room_map.entry(room).or_default();
487 for sid in &sids {
488 entry.insert(*sid);
489 }
490 }
491 }
492
493 pub fn del_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) {
495 let rooms: Vec<Room> = rooms.into_room_iter().collect();
496 let mut rooms_map = self.rooms.write().unwrap();
497 let mut socket_map = self.sockets.write().unwrap();
498 let sids = self.apply_opts(&opts, &rooms_map).collect::<Vec<_>>();
499 for room in rooms {
500 for sid in &sids {
501 remove_and_clean_entry(socket_map.entry(*sid), &room, || ());
502 remove_and_clean_entry(rooms_map.entry(room.clone()), sid, || ());
503 }
504 }
505 }
506
507 pub fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), Vec<SocketError>> {
509 let sids = self
510 .apply_opts(&opts, &self.rooms.read().unwrap())
511 .collect();
512 self.emitter.disconnect_many(sids)
513 }
514
515 pub fn rooms(&self, opts: BroadcastOptions) -> HashSet<Room> {
517 let rooms = self.rooms.read().unwrap();
518 let sockets = self.sockets.read().unwrap();
519 let sids = self.apply_opts(&opts, &rooms);
520 sids.filter_map(|id| sockets.get(&id))
521 .flatten()
522 .cloned()
523 .collect()
524 }
525
526 pub fn path(&self) -> &Str {
528 self.emitter.path()
529 }
530
531 pub fn parser(&self) -> impl Parse + '_ {
533 self.emitter.parser()
534 }
535 pub fn server_id(&self) -> Uid {
537 self.emitter.server_id()
538 }
539}
540
/// Iterator over the members of a list of rooms, skipping the socket ids in `except`.
///
/// Target rooms are visited in slice order. `flatten_iter` holds the member iterator of the
/// room currently being walked and is `None` between rooms.
struct BroadcastRooms<'a> {
    /// Target rooms not visited yet.
    rooms: slice::Iter<'a, Room>,
    /// Room index used to look up the members of each target room.
    rooms_map: &'a HashMap<Room, HashSet<Sid>>,
    /// Socket ids that must not be yielded.
    except: HashSet<Sid>,
    /// Members of the room currently being walked.
    flatten_iter: Option<hash_set::Iter<'a, Sid>>,
}
549impl<'a> BroadcastRooms<'a> {
550 fn new(
551 rooms: &'a [Room],
552 rooms_map: &'a HashMap<Room, HashSet<Sid>>,
553 except: HashSet<Sid>,
554 ) -> Self {
555 BroadcastRooms {
556 rooms: rooms.iter(),
557 rooms_map,
558 except,
559 flatten_iter: None,
560 }
561 }
562}
563impl Iterator for BroadcastRooms<'_> {
564 type Item = Sid;
565 fn next(&mut self) -> Option<Self::Item> {
566 loop {
567 match self.flatten_iter.as_mut().and_then(Iterator::next) {
568 Some(sid) if !self.except.contains(sid) => return Some(*sid),
569 Some(_) => continue,
570 None => self.flatten_iter = None,
571 }
572
573 let room = self.rooms.next()?;
574 self.flatten_iter = self.rooms_map.get(room).map(HashSet::iter);
575 }
576 }
577}
578
579impl<E: SocketEmitter> CoreLocalAdapter<E> {
580 fn apply_opts<'a>(
582 &self,
583 opts: &'a BroadcastOptions,
584 rooms: &'a HashMap<Room, HashSet<Sid>>,
585 ) -> BroadcastIter<'a> {
586 let is_broadcast = opts.has_flag(BroadcastFlags::Broadcast);
587
588 let mut except = get_except_sids(&opts.except, rooms);
589 if is_broadcast && opts.sid.is_some() {
592 except.insert(opts.sid.unwrap());
593 }
594
595 if !opts.rooms.is_empty() {
596 let iter = BroadcastRooms::new(&opts.rooms, rooms, except);
597 InnerBroadcastIter::BroadcastRooms(iter).into()
598 } else if is_broadcast {
599 let sids = self.emitter.get_all_sids(|id| !except.contains(id));
600 InnerBroadcastIter::GlobalBroadcast(sids.into_iter()).into()
601 } else if let Some(id) = opts.sid {
602 InnerBroadcastIter::Single(id).into()
603 } else {
604 InnerBroadcastIter::None.into()
605 }
606 }
607}
608
609#[inline]
610fn get_except_sids(except: &[Room], rooms: &HashMap<Room, HashSet<Sid>>) -> HashSet<Sid> {
611 let mut except_sids = HashSet::new();
612 for room in except {
613 if let Some(sockets) = rooms.get(room) {
614 except_sids.extend(sockets);
615 }
616 }
617 except_sids
618}
619
/// Removes `el` from the set stored in `entry`, dropping the entry when the set becomes empty.
///
/// `cleanup` runs whenever the entry existed, whether or not `el` was in the set. Nothing
/// happens for a vacant entry.
#[inline]
fn remove_and_clean_entry<K, T: Hash + Eq>(
    entry: hash_map::Entry<'_, K, HashSet<T>>,
    el: &T,
    cleanup: impl FnOnce(),
) {
    let hash_map::Entry::Occupied(mut occupied) = entry else {
        return;
    };
    let set = occupied.get_mut();
    set.remove(el);
    if set.is_empty() {
        occupied.remove();
    }
    cleanup();
}
641
/// Iterator over the socket ids targeted by a [`BroadcastOptions`], handed to the
/// [`SocketEmitter`] methods.
///
/// The selection strategy is kept private in `InnerBroadcastIter`.
pub struct BroadcastIter<'a> {
    inner: InnerBroadcastIter<'a>,
}
/// Target-selection strategies produced by `CoreLocalAdapter::apply_opts`.
enum InnerBroadcastIter<'a> {
    /// Members of the target rooms, minus excluded sockets.
    BroadcastRooms(BroadcastRooms<'a>),
    /// Precomputed list of every matching socket (broadcast without target rooms).
    GlobalBroadcast(<Vec<Sid> as IntoIterator>::IntoIter),
    /// Exactly one socket. The iterator turns into `None` once it has been yielded.
    Single(Sid),
    /// No target.
    None,
}
653impl BroadcastIter<'_> {
654 fn is_empty(&self) -> bool {
655 matches!(self.inner, InnerBroadcastIter::None)
656 }
657}
658impl<'a> From<InnerBroadcastIter<'a>> for BroadcastIter<'a> {
659 fn from(inner: InnerBroadcastIter<'a>) -> Self {
660 BroadcastIter { inner }
661 }
662}
663
664impl Iterator for BroadcastIter<'_> {
665 type Item = Sid;
666
667 #[inline(always)]
668 fn next(&mut self) -> Option<Self::Item> {
669 self.inner.next()
670 }
671}
672impl Iterator for InnerBroadcastIter<'_> {
673 type Item = Sid;
674
675 fn next(&mut self) -> Option<Self::Item> {
676 match self {
677 InnerBroadcastIter::BroadcastRooms(inner) => inner.next(),
678 InnerBroadcastIter::GlobalBroadcast(inner) => inner.next(),
679 InnerBroadcastIter::Single(sid) => {
680 let sid = *sid;
681 *self = InnerBroadcastIter::None;
682 Some(sid)
683 }
684 InnerBroadcastIter::None => None,
685 }
686 }
687}
688
/// Serializable description of a socket, which may be hosted on another server.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Default, Clone)]
pub struct RemoteSocketData {
    /// Socket id.
    pub id: Sid,
    /// Id of the server hosting the socket (compared by `BroadcastOptions::is_local`).
    pub server_id: Uid,
    /// Namespace of the socket. NOTE(review): the test stub fills it with the emitter path.
    pub ns: Str,
}
699
#[cfg(test)]
mod test {
    // Unit tests for `CoreLocalAdapter`, driven through an in-memory `SocketEmitter` stub.

    use smallvec::smallvec;
    use std::{
        array,
        pin::Pin,
        task::{Context, Poll},
    };

    use super::*;

    /// Emitter stub that knows a fixed set of socket ids. Sends and disconnections are no-ops.
    struct StubSockets {
        sockets: HashSet<Sid>,
        path: Str,
    }
    impl StubSockets {
        /// Builds a stub that knows exactly `sockets` and reports the `/` path.
        fn new(sockets: &[Sid]) -> Self {
            let sockets = HashSet::from_iter(sockets.iter().copied());
            Self {
                sockets,
                path: Str::from("/"),
            }
        }
    }

    /// Ack stream that is terminated from the start and never yields.
    struct StubAckStream;
    impl Stream for StubAckStream {
        type Item = (Sid, Result<Value, StubError>);
        fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            Poll::Ready(None)
        }
    }
    impl FusedStream for StubAckStream {
        fn is_terminated(&self) -> bool {
            true
        }
    }
    /// Placeholder ack error required by `SocketEmitter::AckError`. The stub never builds it.
    #[derive(Debug, Serialize, Deserialize)]
    struct StubError;
    impl std::fmt::Display for StubError {
        fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            Ok(())
        }
    }
    impl std::error::Error for StubError {}

    impl SocketEmitter for StubSockets {
        type AckError = StubError;
        type AckStream = StubAckStream;
        fn get_all_sids(&self, filter: impl Fn(&Sid) -> bool) -> Vec<Sid> {
            self.sockets
                .iter()
                .copied()
                .filter(|id| filter(id))
                .collect()
        }

        fn get_remote_sockets(&self, sids: BroadcastIter<'_>) -> Vec<RemoteSocketData> {
            sids.map(|id| RemoteSocketData {
                id,
                server_id: Uid::ZERO,
                ns: self.path.clone(),
            })
            .collect()
        }

        fn send_many(&self, _: BroadcastIter<'_>, _: Value) -> Result<(), Vec<SocketError>> {
            Ok(())
        }

        fn send_many_with_ack(
            &self,
            _: BroadcastIter<'_>,
            _: Packet,
            _: Option<Duration>,
        ) -> (Self::AckStream, u32) {
            (StubAckStream, 0)
        }

        // No-op: "disconnected" sockets keep their room memberships in these tests.
        fn disconnect_many(&self, _: Vec<Sid>) -> Result<(), Vec<SocketError>> {
            Ok(())
        }

        fn path(&self) -> &Str {
            &self.path
        }
        fn parser(&self) -> impl Parse {
            crate::parser::test::StubParser
        }
        fn server_id(&self) -> Uid {
            Uid::ZERO
        }
    }

    /// Creates a local adapter whose emitter knows exactly `sockets`.
    fn create_adapter<const S: usize>(sockets: [Sid; S]) -> CoreLocalAdapter<StubSockets> {
        CoreLocalAdapter::new(StubSockets::new(&sockets))
    }

    // `add_all` updates both the room index and the socket index.
    #[test]
    fn add_all() {
        let socket = Sid::new();
        let adapter = create_adapter([socket]);
        adapter.add_all(socket, ["room1", "room2"]);
        let rooms_map = adapter.rooms.read().unwrap();
        let socket_map = adapter.sockets.read().unwrap();
        assert_eq!(rooms_map.len(), 2);
        assert_eq!(socket_map.len(), 1);
        assert_eq!(rooms_map.get("room1").unwrap().len(), 1);
        assert_eq!(rooms_map.get("room2").unwrap().len(), 1);

        let rooms = socket_map.get(&socket).unwrap();
        assert!(rooms.contains("room1"));
        assert!(rooms.contains("room2"));
    }

    // `del` removes one membership from both indexes and drops the emptied room.
    #[test]
    fn del() {
        let socket = Sid::new();
        let adapter = create_adapter([socket]);
        adapter.add_all(socket, ["room1", "room2"]);
        // Scoped so the read guards are released before `del` takes the write locks.
        {
            let rooms_map = adapter.rooms.read().unwrap();
            assert_eq!(rooms_map.len(), 2);
            assert_eq!(rooms_map.get("room1").unwrap().len(), 1);
            assert_eq!(rooms_map.get("room2").unwrap().len(), 1);
            let socket_map = adapter.sockets.read().unwrap();
            let rooms = socket_map.get(&socket).unwrap();
            assert!(rooms.contains("room1"));
            assert!(rooms.contains("room2"));
        }
        adapter.del(socket, "room1");
        let rooms_map = adapter.rooms.read().unwrap();
        let socket_map = adapter.sockets.read().unwrap();
        assert_eq!(rooms_map.len(), 1);
        assert!(rooms_map.get("room1").is_none());
        assert_eq!(rooms_map.get("room2").unwrap().len(), 1);
        assert_eq!(socket_map.get(&socket).unwrap().len(), 1);
    }
    // `del_all` removes the socket everywhere and drops the rooms it leaves empty.
    #[test]
    fn del_all() {
        let socket = Sid::new();
        let adapter = create_adapter([socket]);
        adapter.add_all(socket, ["room1", "room2"]);

        {
            let rooms_map = adapter.rooms.read().unwrap();
            assert_eq!(rooms_map.len(), 2);
            assert_eq!(rooms_map.get("room1").unwrap().len(), 1);
            assert_eq!(rooms_map.get("room2").unwrap().len(), 1);

            let socket_map = adapter.sockets.read().unwrap();
            let rooms = socket_map.get(&socket).unwrap();
            assert!(rooms.contains("room1"));
            assert!(rooms.contains("room2"));
        }

        adapter.del_all(socket);

        {
            let rooms_map = adapter.rooms.read().unwrap();
            assert_eq!(rooms_map.len(), 0);

            let socket_map = adapter.sockets.read().unwrap();
            assert!(socket_map.get(&socket).is_none());
        }
    }

    // `socket_rooms` returns exactly the rooms each socket joined.
    #[test]
    fn socket_room() {
        let sid1 = Sid::new();
        let sid2 = Sid::new();
        let sid3 = Sid::new();
        let adapter = create_adapter([sid1, sid2, sid3]);
        adapter.add_all(sid1, ["room1", "room2"]);
        adapter.add_all(sid2, ["room1"]);
        adapter.add_all(sid3, ["room2"]);
        assert!(adapter.socket_rooms(sid1).contains(&Cow::Borrowed("room1")));
        assert!(adapter.socket_rooms(sid1).contains(&Cow::Borrowed("room2")));
        assert_eq!(
            adapter.socket_rooms(sid2).into_iter().collect::<Vec<_>>(),
            ["room1"]
        );
        assert_eq!(
            adapter.socket_rooms(sid3).into_iter().collect::<Vec<_>>(),
            ["room2"]
        );
    }

    // `add_sockets` adds the members of the targeted room to a new room.
    #[test]
    fn add_socket() {
        let socket = Sid::new();
        let adapter = create_adapter([socket]);
        adapter.add_all(socket, ["room1"]);

        let mut opts = BroadcastOptions::new(socket);
        opts.rooms = smallvec!["room1".into()];
        adapter.add_sockets(opts, "room2");
        let rooms_map = adapter.rooms.read().unwrap();

        assert_eq!(rooms_map.len(), 2);
        assert!(rooms_map.get("room1").unwrap().contains(&socket));
        assert!(rooms_map.get("room2").unwrap().contains(&socket));
    }

    // `del_sockets` undoes `add_sockets` and drops the emptied room.
    #[test]
    fn del_socket() {
        let socket = Sid::new();
        let adapter = create_adapter([socket]);
        adapter.add_all(socket, ["room1"]);

        let mut opts = BroadcastOptions::new(socket);
        opts.rooms = smallvec!["room1".into()];
        adapter.add_sockets(opts, "room2");

        {
            let rooms_map = adapter.rooms.read().unwrap();

            assert_eq!(rooms_map.len(), 2);
            assert!(rooms_map.get("room1").unwrap().contains(&socket));
            assert!(rooms_map.get("room2").unwrap().contains(&socket));
        }

        let mut opts = BroadcastOptions::new(socket);
        opts.rooms = smallvec!["room1".into()];
        adapter.del_sockets(opts, "room2");

        {
            let rooms_map = adapter.rooms.read().unwrap();

            assert_eq!(rooms_map.len(), 1);
            assert!(rooms_map.get("room1").unwrap().contains(&socket));
            assert!(rooms_map.get("room2").is_none());
        }
    }

    // `sockets` lists the members of a single target room.
    #[test]
    fn sockets() {
        let socket0 = Sid::new();
        let socket1 = Sid::new();
        let socket2 = Sid::new();
        let adapter = create_adapter([socket0, socket1, socket2]);
        adapter.add_all(socket0, ["room1", "room2"]);
        adapter.add_all(socket1, ["room1", "room3"]);
        adapter.add_all(socket2, ["room2", "room3"]);

        let mut opts = BroadcastOptions {
            rooms: smallvec!["room1".into()],
            ..Default::default()
        };
        let sockets = adapter.sockets(opts.clone());
        assert_eq!(sockets.len(), 2);
        assert!(sockets.contains(&socket0));
        assert!(sockets.contains(&socket1));

        opts.rooms = smallvec!["room2".into()];
        let sockets = adapter.sockets(opts.clone());
        assert_eq!(sockets.len(), 2);
        assert!(sockets.contains(&socket0));
        assert!(sockets.contains(&socket2));

        opts.rooms = smallvec!["room3".into()];
        let sockets = adapter.sockets(opts.clone());
        assert_eq!(sockets.len(), 2);
        assert!(sockets.contains(&socket1));
        assert!(sockets.contains(&socket2));
    }

    // `disconnect_socket` succeeds. The stub's `disconnect_many` is a no-op, so memberships
    // are unchanged afterwards.
    #[test]
    fn disconnect_socket() {
        let socket0 = Sid::new();
        let socket1 = Sid::new();
        let socket2 = Sid::new();
        let adapter = create_adapter([socket0, socket1, socket2]);
        adapter.add_all(socket0, ["room1", "room2", "room4"]);
        adapter.add_all(socket1, ["room1", "room3", "room5"]);
        adapter.add_all(socket2, ["room2", "room3", "room6"]);

        let mut opts = BroadcastOptions::new(socket0);
        opts.rooms = smallvec!["room5".into()];
        adapter.disconnect_socket(opts).unwrap();

        let mut opts = BroadcastOptions::default();
        opts.rooms.push("room2".into());
        let sockets = adapter.sockets(opts.clone());
        assert_eq!(sockets.len(), 2);
        assert!(sockets.contains(&socket2));
        assert!(sockets.contains(&socket0));
    }
    // Options selecting nothing must not fail.
    #[test]
    fn disconnect_empty_opts() {
        let adapter = create_adapter([]);
        let opts = BroadcastOptions::default();
        adapter.disconnect_socket(opts).unwrap();
    }
    // `rooms` returns the union of the rooms of the selected sockets.
    #[test]
    fn rooms() {
        let socket0 = Sid::new();
        let socket1 = Sid::new();
        let socket2 = Sid::new();
        let adapter = create_adapter([socket0, socket1, socket2]);
        adapter.add_all(socket0, ["room1", "room2", "room4"]);
        adapter.add_all(socket1, ["room1", "room3", "room5"]);
        adapter.add_all(socket2, ["room2", "room3", "room6"]);

        let mut opts = BroadcastOptions::new(socket0);
        opts.rooms = smallvec!["room5".into()];
        opts.add_flag(BroadcastFlags::Broadcast);
        let rooms = adapter.rooms(opts);
        assert_eq!(rooms.len(), 3);
        assert!(rooms.contains(&Cow::Borrowed("room1")));
        assert!(rooms.contains(&Cow::Borrowed("room3")));
        assert!(rooms.contains(&Cow::Borrowed("room5")));

        let mut opts = BroadcastOptions::default();
        opts.rooms.push("room2".into());
        let rooms = adapter.rooms(opts.clone());
        assert_eq!(rooms.len(), 5);
        assert!(rooms.contains(&Cow::Borrowed("room1")));
        assert!(rooms.contains(&Cow::Borrowed("room2")));
        assert!(rooms.contains(&Cow::Borrowed("room3")));
        assert!(rooms.contains(&Cow::Borrowed("room4")));
        assert!(rooms.contains(&Cow::Borrowed("room6")));
    }

    // Target resolution: rooms minus except rooms, global broadcast minus the origin socket,
    // and the single-socket fallback (which also applies to unknown sockets).
    #[test]
    fn apply_opts() {
        let mut sockets: [Sid; 3] = array::from_fn(|_| Sid::new());
        sockets.sort();
        let adapter = create_adapter(sockets);

        adapter.add_all(sockets[0], ["room1", "room2"]);
        adapter.add_all(sockets[1], ["room1", "room3"]);
        adapter.add_all(sockets[2], ["room1", "room2", "room3"]);

        let mut opts = BroadcastOptions::new(sockets[2]);
        opts.rooms = smallvec!["room1".into()];
        opts.except = smallvec!["room2".into()];
        let sids = adapter
            .apply_opts(&opts, &adapter.rooms.read().unwrap())
            .collect::<Vec<_>>();
        assert_eq!(sids, [sockets[1]]);

        let mut opts = BroadcastOptions::new(sockets[2]);
        opts.add_flag(BroadcastFlags::Broadcast);
        let mut sids = adapter
            .apply_opts(&opts, &adapter.rooms.read().unwrap())
            .collect::<Vec<_>>();
        sids.sort();
        assert_eq!(sids, [sockets[0], sockets[1]]);

        let mut opts = BroadcastOptions::new(sockets[2]);
        opts.add_flag(BroadcastFlags::Broadcast);
        opts.except = smallvec!["room2".into()];
        let sids = adapter
            .apply_opts(&opts, &adapter.rooms.read().unwrap())
            .collect::<Vec<_>>();
        assert_eq!(sids.len(), 1);

        let opts = BroadcastOptions::new(sockets[2]);
        let sids = adapter
            .apply_opts(&opts, &adapter.rooms.read().unwrap())
            .collect::<Vec<_>>();
        assert_eq!(sids.len(), 1);
        assert_eq!(sids[0], sockets[2]);

        let opts = BroadcastOptions::new(Sid::new());
        let sids = adapter
            .apply_opts(&opts, &adapter.rooms.read().unwrap())
            .collect::<Vec<_>>();
        assert_eq!(sids.len(), 1);
    }

    // `is_local` holds for single-socket options only on the server hosting that socket.
    #[test]
    fn test_is_local_opts() {
        let server_id = Uid::new();
        let remote = RemoteSocketData {
            id: Sid::new(),
            server_id,
            ns: "/".into(),
        };
        let opts = BroadcastOptions::new_remote(&remote);
        assert!(opts.is_local(server_id));
        assert!(!opts.is_local(Uid::new()));
        let opts = BroadcastOptions::new(Sid::new());
        assert!(!opts.is_local(Uid::new()));
    }
}