1use std::{
8 borrow::Cow,
9 collections::{hash_map, hash_set, HashMap, HashSet},
10 error::Error as StdError,
11 future::{self, Future},
12 hash::Hash,
13 slice,
14 sync::{Arc, RwLock},
15 time::Duration,
16};
17
18use engineioxide::{sid::Sid, Str};
19use futures_core::{FusedStream, Stream};
20use serde::{de::DeserializeOwned, Deserialize, Serialize};
21use smallvec::SmallVec;
22
23use crate::{
24 errors::{AdapterError, BroadcastError, SocketError},
25 packet::Packet,
26 parser::Parse,
27 Uid, Value,
28};
29
/// A room name.
///
/// Uses `Cow` so that `&'static str` names are stored borrowed (no allocation),
/// while dynamically built names are stored owned.
pub type Room = Cow<'static, str>;
32
/// Flags that modify how a broadcast is resolved.
///
/// Stored in [`BroadcastOptions`] as a `u8` bitfield through
/// `BroadcastOptions::add_flag` / `BroadcastOptions::has_flag`, so each variant's
/// discriminant is a distinct bit.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum BroadcastFlags {
    /// Bit `0x01`. Not read by the code in this file. NOTE(review): presumably
    /// restricts the operation to the current server only; confirm against the
    /// adapter implementations.
    Local = 0x01,
    /// Bit `0x02`. Broadcast mode: when no rooms are targeted every known socket is
    /// selected, and the originating socket (`BroadcastOptions::sid`) is always
    /// excluded.
    Broadcast = 0x02,
}
41
/// Selection criteria for a broadcast or bulk socket operation.
///
/// The local adapter resolves these options into concrete socket ids. Non-empty
/// `rooms` take precedence. Otherwise the `Broadcast` flag selects every socket.
/// Otherwise `sid` alone is selected. Sockets in any `except` room are skipped for
/// room and broadcast selections.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct BroadcastOptions {
    /// Bitfield of [`BroadcastFlags`]; read and written through `add_flag` /
    /// `has_flag` / `flags`.
    flags: u8,
    /// Rooms whose member sockets are targeted.
    pub rooms: SmallVec<[Room; 4]>,
    /// Rooms whose member sockets are excluded.
    pub except: SmallVec<[Room; 4]>,
    /// The originating socket, if any. It is targeted alone when there are no rooms
    /// and no `Broadcast` flag, and it is excluded when the `Broadcast` flag is set.
    pub sid: Option<Sid>,
    /// Server id of the originating socket, if known (set by `new_remote`). Not read
    /// in this file. NOTE(review): presumably consumed by multi-server adapters.
    pub server_id: Option<Uid>,
}
58impl BroadcastOptions {
59 pub fn add_flag(&mut self, flag: BroadcastFlags) {
61 self.flags |= flag as u8;
62 }
63 pub fn has_flag(&self, flag: BroadcastFlags) -> bool {
65 self.flags & flag as u8 == flag as u8
66 }
67
68 pub fn flags(&self) -> u8 {
70 self.flags
71 }
72
73 pub fn new(sid: Sid) -> Self {
75 Self {
76 sid: Some(sid),
77 ..Default::default()
78 }
79 }
80 pub fn new_remote(data: &RemoteSocketData) -> Self {
82 Self {
83 sid: Some(data.id),
84 server_id: Some(data.server_id),
85 ..Default::default()
86 }
87 }
88}
89
/// A value that names one or more rooms.
///
/// Implemented for single room names, collections of names and socket ids, so
/// room-related APIs can accept any of them.
pub trait RoomParam: Send + 'static {
    /// Iterator over the rooms named by this value.
    type IntoIter: Iterator<Item = Room>;

    /// Converts `self` into an iterator over the [`Room`]s it names.
    fn into_room_iter(self) -> Self::IntoIter;
}
100
/// A single room, passed through as-is.
impl RoomParam for Room {
    type IntoIter = std::iter::Once<Room>;
    #[inline(always)]
    fn into_room_iter(self) -> Self::IntoIter {
        std::iter::once(self)
    }
}
/// A single owned name, stored as an owned room.
impl RoomParam for String {
    type IntoIter = std::iter::Once<Room>;
    #[inline(always)]
    fn into_room_iter(self) -> Self::IntoIter {
        std::iter::once(Cow::Owned(self))
    }
}
/// Several owned names, each stored as an owned room.
impl RoomParam for Vec<String> {
    type IntoIter = std::iter::Map<std::vec::IntoIter<String>, fn(String) -> Room>;
    #[inline(always)]
    fn into_room_iter(self) -> Self::IntoIter {
        self.into_iter().map(Cow::Owned)
    }
}
/// Several static names, each stored borrowed (no allocation per name).
impl RoomParam for Vec<&'static str> {
    type IntoIter = std::iter::Map<std::vec::IntoIter<&'static str>, fn(&'static str) -> Room>;
    #[inline(always)]
    fn into_room_iter(self) -> Self::IntoIter {
        self.into_iter().map(Cow::Borrowed)
    }
}

/// Several rooms, passed through as-is.
impl RoomParam for Vec<Room> {
    type IntoIter = std::vec::IntoIter<Room>;
    #[inline(always)]
    fn into_room_iter(self) -> Self::IntoIter {
        self.into_iter()
    }
}
/// A single static name, stored borrowed.
impl RoomParam for &'static str {
    type IntoIter = std::iter::Once<Room>;
    #[inline(always)]
    fn into_room_iter(self) -> Self::IntoIter {
        std::iter::once(Cow::Borrowed(self))
    }
}
/// A fixed-size array of static names, each stored borrowed.
impl<const COUNT: usize> RoomParam for [&'static str; COUNT] {
    type IntoIter =
        std::iter::Map<std::array::IntoIter<&'static str, COUNT>, fn(&'static str) -> Room>;

    #[inline(always)]
    fn into_room_iter(self) -> Self::IntoIter {
        self.into_iter().map(Cow::Borrowed)
    }
}
/// A static slice of static names, each stored borrowed.
impl RoomParam for &'static [&'static str] {
    type IntoIter =
        std::iter::Map<std::slice::Iter<'static, &'static str>, fn(&'static &'static str) -> Room>;

    #[inline(always)]
    fn into_room_iter(self) -> Self::IntoIter {
        // The closure captures nothing, so it coerces to the declared fn pointer.
        self.iter().map(|i| Cow::Borrowed(*i))
    }
}
/// A fixed-size array of owned names, each stored as an owned room.
impl<const COUNT: usize> RoomParam for [String; COUNT] {
    type IntoIter = std::iter::Map<std::array::IntoIter<String, COUNT>, fn(String) -> Room>;
    #[inline(always)]
    fn into_room_iter(self) -> Self::IntoIter {
        self.into_iter().map(Cow::Owned)
    }
}
/// A socket id names the room whose name is the id's string form.
impl RoomParam for Sid {
    type IntoIter = std::iter::Once<Room>;
    #[inline(always)]
    fn into_room_iter(self) -> Self::IntoIter {
        std::iter::once(Cow::Owned(self.to_string()))
    }
}
176
/// One acknowledgement: the id of the answering socket and either its ack payload
/// or an error of type `E`.
pub type AckStreamItem<E> = (Sid, Result<Value, E>);
/// Backend through which the local adapter acts on actual sockets.
pub trait SocketEmitter: Send + Sync + 'static {
    /// Error type carried by ack stream items.
    type AckError: StdError + Send + Serialize + DeserializeOwned + 'static;
    /// Stream of acknowledgements returned by `send_many_with_ack`.
    type AckStream: Stream<Item = AckStreamItem<Self::AckError>> + FusedStream + Send + 'static;

    /// Returns the ids of all sockets for which `filter` returns `true`.
    fn get_all_sids(&self, filter: impl Fn(&Sid) -> bool) -> Vec<Sid>;
    /// Returns a [`RemoteSocketData`] description for each socket id yielded by `sids`.
    fn get_remote_sockets(&self, sids: BroadcastIter<'_>) -> Vec<RemoteSocketData>;
    /// Sends `data`, already encoded with `parser()`, to every socket in `sids`.
    /// Failures are reported as a list of [`SocketError`]s.
    fn send_many(&self, sids: BroadcastIter<'_>, data: Value) -> Result<(), Vec<SocketError>>;
    /// Sends `packet` to every socket in `sids` and returns the stream of their
    /// acknowledgements plus a count. NOTE(review): the count is presumably the
    /// number of sockets the packet was sent to, and `timeout` presumably bounds
    /// the wait for acks; confirm against implementations.
    fn send_many_with_ack(
        &self,
        sids: BroadcastIter<'_>,
        packet: Packet,
        timeout: Option<Duration>,
    ) -> (Self::AckStream, u32);
    /// Disconnects every socket in `sids`; failures are reported as a list of
    /// [`SocketError`]s.
    fn disconnect_many(&self, sids: Vec<Sid>) -> Result<(), Vec<SocketError>>;
    /// Path of the namespace these sockets belong to.
    fn path(&self) -> &Str;
    /// Parser used to encode packets before they are handed to `send_many`.
    fn parser(&self) -> impl Parse;
    /// Id of the current server.
    fn server_id(&self) -> Uid;
}
212
/// Work produced by an adapter's `init` that can be started by the caller.
pub trait Spawnable {
    /// Starts the work represented by `self`. What "spawning" means (for example,
    /// onto an async runtime) is left to the implementation.
    fn spawn(self);
}
/// No-op implementation for adapters that have nothing to run after init.
impl Spawnable for () {
    fn spawn(self) {}
}
225
/// Marker trait with no methods.
///
/// NOTE(review): its intended use is not visible in this file. It presumably tags
/// concrete adapter types exposed to users; confirm before relying on it.
pub trait DefinedAdapter {}
232
/// Adapter interface for broadcasting and room management.
///
/// Implementations wrap a [`CoreLocalAdapter`] (see `get_local`). Every provided
/// default method operates only on that local adapter. The default methods do
/// their work eagerly, when the method is called, and return an already-ready
/// future.
pub trait CoreAdapter<E: SocketEmitter>: Sized + Send + Sync + 'static {
    /// Error returned by adapter operations; must convert into [`AdapterError`].
    type Error: StdError + Into<AdapterError> + Send + 'static;
    /// State passed to `new` when the adapter is constructed.
    type State: Send + Sync + 'static;
    /// Stream of acknowledgements returned by `broadcast_with_ack`.
    type AckStream: Stream<Item = AckStreamItem<E::AckError>> + FusedStream + Send + 'static;
    /// Value returned by `init`; it can be spawned via [`Spawnable`].
    type InitRes: Spawnable + Send;

    /// Builds an adapter from the shared `state` and the `local` adapter it wraps.
    fn new(state: &Self::State, local: CoreLocalAdapter<E>) -> Self;

    /// Initializes the adapter. NOTE(review): `on_success` is presumably invoked
    /// once initialization succeeds; confirm with implementations.
    fn init(self: Arc<Self>, on_success: impl FnOnce() + Send + 'static) -> Self::InitRes;

    /// Closes the adapter. The default does nothing and succeeds.
    fn close(&self) -> impl Future<Output = Result<(), Self::Error>> + Send {
        future::ready(Ok(()))
    }

    /// Number of servers. The default always reports `1`.
    fn server_count(&self) -> impl Future<Output = Result<u16, Self::Error>> + Send {
        future::ready(Ok(1))
    }

    /// Sends `packet` to the sockets selected by `opts`. By default this targets
    /// local sockets only; per-socket errors are converted into a
    /// [`BroadcastError`].
    fn broadcast(
        &self,
        packet: Packet,
        opts: BroadcastOptions,
    ) -> impl Future<Output = Result<(), BroadcastError>> + Send {
        future::ready(
            self.get_local()
                .broadcast(packet, opts)
                .map_err(BroadcastError::from),
        )
    }

    /// Sends `packet` to the sockets selected by `opts` and returns a stream of
    /// their acknowledgements, optionally bounded by `timeout`. No default is
    /// provided.
    fn broadcast_with_ack(
        &self,
        packet: Packet,
        opts: BroadcastOptions,
        timeout: Option<Duration>,
    ) -> impl Future<Output = Result<Self::AckStream, Self::Error>> + Send;

    /// Adds the sockets selected by `opts` to every room in `rooms`. The default
    /// applies the change locally, immediately, and always succeeds.
    fn add_sockets(
        &self,
        opts: BroadcastOptions,
        rooms: impl RoomParam,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
        self.get_local().add_sockets(opts, rooms);
        future::ready(Ok(()))
    }

    /// Removes the sockets selected by `opts` from every room in `rooms`. The
    /// default applies the change locally, immediately, and always succeeds.
    fn del_sockets(
        &self,
        opts: BroadcastOptions,
        rooms: impl RoomParam,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
        self.get_local().del_sockets(opts, rooms);
        future::ready(Ok(()))
    }

    /// Disconnects the sockets selected by `opts`. By default this targets local
    /// sockets only; per-socket errors are wrapped in `BroadcastError::Socket`.
    fn disconnect_socket(
        &self,
        opts: BroadcastOptions,
    ) -> impl Future<Output = Result<(), BroadcastError>> + Send {
        future::ready(
            self.get_local()
                .disconnect_socket(opts)
                .map_err(BroadcastError::Socket),
        )
    }

    /// Returns the rooms joined by the sockets selected by `opts` (local sockets
    /// only by default), without duplicates.
    fn rooms(
        &self,
        opts: BroadcastOptions,
    ) -> impl Future<Output = Result<Vec<Room>, Self::Error>> + Send {
        future::ready(Ok(self.get_local().rooms(opts).into_iter().collect()))
    }

    /// Returns a description of each socket selected by `opts` (local sockets only
    /// by default).
    fn fetch_sockets(
        &self,
        opts: BroadcastOptions,
    ) -> impl Future<Output = Result<Vec<RemoteSocketData>, Self::Error>> + Send {
        future::ready(Ok(self.get_local().fetch_sockets(opts)))
    }

    /// Returns the local adapter wrapped by this adapter.
    fn get_local(&self) -> &CoreLocalAdapter<E>;
}
349
/// In-memory room bookkeeping for the sockets reachable through an emitter `E`.
///
/// Two indexes are maintained: room → member sockets, and socket → joined rooms.
/// Operations on the sockets themselves (sending, disconnecting, ...) are
/// delegated to `emitter`.
pub struct CoreLocalAdapter<E> {
    /// Room name → ids of the sockets currently in that room.
    rooms: RwLock<HashMap<Room, HashSet<Sid>>>,
    /// Socket id → rooms that socket has joined (reverse index of `rooms`).
    sockets: RwLock<HashMap<Sid, HashSet<Room>>>,
    /// Backend used to act on the sockets.
    emitter: E,
}
356
357impl<E: SocketEmitter> CoreLocalAdapter<E> {
358 pub fn new(emitter: E) -> Self {
360 Self {
361 rooms: RwLock::new(HashMap::new()),
362 sockets: RwLock::new(HashMap::new()),
363 emitter,
364 }
365 }
366
367 pub fn close(&self) {
369 let mut rooms = self.rooms.write().unwrap();
370 rooms.clear();
371 rooms.shrink_to_fit();
372 }
373
374 pub fn add_all(&self, sid: Sid, rooms: impl RoomParam) {
376 let mut rooms_map = self.rooms.write().unwrap();
377 let mut socket_map = self.sockets.write().unwrap();
378 for room in rooms.into_room_iter() {
379 rooms_map.entry(room.clone()).or_default().insert(sid);
380 socket_map.entry(sid).or_default().insert(room);
381 }
382 }
383
384 pub fn del(&self, sid: Sid, rooms: impl RoomParam) {
386 let mut rooms_map = self.rooms.write().unwrap();
387 let mut socket_map = self.sockets.write().unwrap();
388 for room in rooms.into_room_iter() {
389 remove_and_clean_entry(rooms_map.entry(room.clone()), &sid, || {
390 socket_map.entry(sid).and_modify(|r| {
391 r.remove(&room);
392 });
393 });
394 }
395 }
396
397 pub fn del_all(&self, sid: Sid) {
399 let mut rooms_map = self.rooms.write().unwrap();
400 if let Some(rooms) = self.sockets.write().unwrap().remove(&sid) {
401 for room in rooms {
402 remove_and_clean_entry(rooms_map.entry(room.clone()), &sid, || ());
403 }
404 }
405 }
406
407 pub fn broadcast(
409 &self,
410 packet: Packet,
411 opts: BroadcastOptions,
412 ) -> Result<(), Vec<SocketError>> {
413 let room_map = self.rooms.read().unwrap();
414 let sids = self.apply_opts(&opts, &room_map);
415
416 if sids.is_empty() {
417 return Ok(());
418 }
419
420 let data = self.emitter.parser().encode(packet);
421 self.emitter.send_many(sids, data)
422 }
423
424 pub fn broadcast_with_ack(
427 &self,
428 packet: Packet,
429 opts: BroadcastOptions,
430 timeout: Option<Duration>,
431 ) -> (E::AckStream, u32) {
432 let room_map = self.rooms.read().unwrap();
433 let sids = self.apply_opts(&opts, &room_map);
434 self.emitter.send_many_with_ack(sids, packet, timeout)
436 }
437
438 pub fn sockets(&self, opts: BroadcastOptions) -> Vec<Sid> {
440 self.apply_opts(&opts, &self.rooms.read().unwrap())
441 .collect()
442 }
443
444 pub fn fetch_sockets(&self, opts: BroadcastOptions) -> Vec<RemoteSocketData> {
446 let rooms = self.rooms.read().unwrap();
447 let sids = self.apply_opts(&opts, &rooms);
448 self.emitter.get_remote_sockets(sids)
449 }
450
451 pub fn socket_rooms(&self, sid: Sid) -> HashSet<Room> {
453 self.sockets
454 .read()
455 .unwrap()
456 .get(&sid)
457 .cloned()
458 .unwrap_or_default()
459 }
460
461 pub fn add_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) {
463 let rooms: Vec<Room> = rooms.into_room_iter().collect();
464 let mut room_map = self.rooms.write().unwrap();
465 let mut socket_map = self.sockets.write().unwrap();
466 let sids = self.apply_opts(&opts, &room_map).collect::<Vec<_>>();
468 for sid in &sids {
469 let entry = socket_map.entry(*sid).or_default();
470 for room in &rooms {
471 entry.insert(room.clone());
472 }
473 }
474 for room in rooms {
475 let entry = room_map.entry(room).or_default();
476 for sid in &sids {
477 entry.insert(*sid);
478 }
479 }
480 }
481
482 pub fn del_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) {
484 let rooms: Vec<Room> = rooms.into_room_iter().collect();
485 let mut rooms_map = self.rooms.write().unwrap();
486 let mut socket_map = self.sockets.write().unwrap();
487 let sids = self.apply_opts(&opts, &rooms_map).collect::<Vec<_>>();
488 for room in rooms {
489 for sid in &sids {
490 remove_and_clean_entry(socket_map.entry(*sid), &room, || ());
491 remove_and_clean_entry(rooms_map.entry(room.clone()), sid, || ());
492 }
493 }
494 }
495
496 pub fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), Vec<SocketError>> {
498 let sids = self
499 .apply_opts(&opts, &self.rooms.read().unwrap())
500 .collect();
501 self.emitter.disconnect_many(sids)
502 }
503
504 pub fn rooms(&self, opts: BroadcastOptions) -> HashSet<Room> {
506 let rooms = self.rooms.read().unwrap();
507 let sockets = self.sockets.read().unwrap();
508 let sids = self.apply_opts(&opts, &rooms);
509 sids.filter_map(|id| sockets.get(&id))
510 .flatten()
511 .cloned()
512 .collect()
513 }
514
515 pub fn path(&self) -> &Str {
517 self.emitter.path()
518 }
519
520 pub fn parser(&self) -> impl Parse + '_ {
522 self.emitter.parser()
523 }
524 pub fn server_id(&self) -> Uid {
526 self.emitter.server_id()
527 }
528}
529
530struct BroadcastRooms<'a> {
533 rooms: slice::Iter<'a, Room>,
534 rooms_map: &'a HashMap<Room, HashSet<Sid>>,
535 except: HashSet<Sid>,
536 flatten_iter: Option<hash_set::Iter<'a, Sid>>,
537}
538impl<'a> BroadcastRooms<'a> {
539 fn new(
540 rooms: &'a [Room],
541 rooms_map: &'a HashMap<Room, HashSet<Sid>>,
542 except: HashSet<Sid>,
543 ) -> Self {
544 BroadcastRooms {
545 rooms: rooms.iter(),
546 rooms_map,
547 except,
548 flatten_iter: None,
549 }
550 }
551}
552impl Iterator for BroadcastRooms<'_> {
553 type Item = Sid;
554 fn next(&mut self) -> Option<Self::Item> {
555 loop {
556 match self.flatten_iter.as_mut().and_then(Iterator::next) {
557 Some(sid) if !self.except.contains(sid) => return Some(*sid),
558 Some(_) => continue,
559 None => self.flatten_iter = None,
560 }
561
562 let room = self.rooms.next()?;
563 self.flatten_iter = self.rooms_map.get(room).map(HashSet::iter);
564 }
565 }
566}
567
568impl<E: SocketEmitter> CoreLocalAdapter<E> {
569 fn apply_opts<'a>(
571 &self,
572 opts: &'a BroadcastOptions,
573 rooms: &'a HashMap<Room, HashSet<Sid>>,
574 ) -> BroadcastIter<'a> {
575 let is_broadcast = opts.has_flag(BroadcastFlags::Broadcast);
576
577 let mut except = get_except_sids(&opts.except, rooms);
578 if is_broadcast && opts.sid.is_some() {
581 except.insert(opts.sid.unwrap());
582 }
583
584 if !opts.rooms.is_empty() {
585 let iter = BroadcastRooms::new(&opts.rooms, rooms, except);
586 InnerBroadcastIter::BroadcastRooms(iter).into()
587 } else if is_broadcast {
588 let sids = self.emitter.get_all_sids(|id| !except.contains(id));
589 InnerBroadcastIter::GlobalBroadcast(sids.into_iter()).into()
590 } else if let Some(id) = opts.sid {
591 InnerBroadcastIter::Single(id).into()
592 } else {
593 InnerBroadcastIter::None.into()
594 }
595 }
596}
597
598#[inline]
599fn get_except_sids(except: &[Room], rooms: &HashMap<Room, HashSet<Sid>>) -> HashSet<Sid> {
600 let mut except_sids = HashSet::new();
601 for room in except {
602 if let Some(sockets) = rooms.get(room) {
603 except_sids.extend(sockets);
604 }
605 }
606 except_sids
607}
608
/// Removes `el` from the set stored in `entry`, drops the map entry when its set
/// becomes empty, then runs `cleanup`.
///
/// A vacant entry is left untouched and `cleanup` is not called. An occupied entry
/// always triggers `cleanup`, even if `el` was not in the set.
#[inline]
fn remove_and_clean_entry<K, T: Hash + Eq>(
    entry: hash_map::Entry<'_, K, HashSet<T>>,
    el: &T,
    cleanup: impl FnOnce(),
) {
    let hash_map::Entry::Occupied(mut slot) = entry else {
        return;
    };
    let set = slot.get_mut();
    set.remove(el);
    if set.is_empty() {
        slot.remove();
    }
    cleanup();
}
630
631pub struct BroadcastIter<'a> {
634 inner: InnerBroadcastIter<'a>,
635}
636enum InnerBroadcastIter<'a> {
637 BroadcastRooms(BroadcastRooms<'a>),
638 GlobalBroadcast(<Vec<Sid> as IntoIterator>::IntoIter),
639 Single(Sid),
640 None,
641}
642impl BroadcastIter<'_> {
643 fn is_empty(&self) -> bool {
644 matches!(self.inner, InnerBroadcastIter::None)
645 }
646}
647impl<'a> From<InnerBroadcastIter<'a>> for BroadcastIter<'a> {
648 fn from(inner: InnerBroadcastIter<'a>) -> Self {
649 BroadcastIter { inner }
650 }
651}
652
653impl Iterator for BroadcastIter<'_> {
654 type Item = Sid;
655
656 #[inline(always)]
657 fn next(&mut self) -> Option<Self::Item> {
658 self.inner.next()
659 }
660}
661impl Iterator for InnerBroadcastIter<'_> {
662 type Item = Sid;
663
664 fn next(&mut self) -> Option<Self::Item> {
665 match self {
666 InnerBroadcastIter::BroadcastRooms(inner) => inner.next(),
667 InnerBroadcastIter::GlobalBroadcast(inner) => inner.next(),
668 InnerBroadcastIter::Single(sid) => {
669 let sid = *sid;
670 *self = InnerBroadcastIter::None;
671 Some(sid)
672 }
673 InnerBroadcastIter::None => None,
674 }
675 }
676}
677
/// Serializable description of a socket, as returned by `fetch_sockets`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Default, Clone)]
pub struct RemoteSocketData {
    /// The socket id.
    pub id: Sid,
    /// Id of the server the socket lives on. `BroadcastOptions::new_remote` copies
    /// it into `BroadcastOptions::server_id`.
    pub server_id: Uid,
    /// Namespace of the socket. NOTE(review): presumably the namespace path; the
    /// test emitter fills it with its `path()`.
    pub ns: Str,
}
688
// Unit tests for `CoreLocalAdapter`, run against the in-memory `StubSockets`
// emitter defined below.
#[cfg(test)]
mod test {

    use smallvec::smallvec;
    use std::{
        array,
        pin::Pin,
        task::{Context, Poll},
    };

    use super::*;

    /// Fake emitter that knows a fixed set of socket ids and performs no I/O.
    struct StubSockets {
        sockets: HashSet<Sid>,
        path: Str,
    }
    impl StubSockets {
        /// Builds a stub knowing exactly `sockets`, with path `/`.
        fn new(sockets: &[Sid]) -> Self {
            let sockets = HashSet::from_iter(sockets.iter().copied());
            Self {
                sockets,
                path: Str::from("/"),
            }
        }
    }

    /// Ack stream that is immediately exhausted.
    struct StubAckStream;
    impl Stream for StubAckStream {
        type Item = (Sid, Result<Value, StubError>);
        fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            Poll::Ready(None)
        }
    }
    impl FusedStream for StubAckStream {
        fn is_terminated(&self) -> bool {
            true
        }
    }
    /// Placeholder ack error type; the stub ack stream never yields one.
    #[derive(Debug, Serialize, Deserialize)]
    struct StubError;
    impl std::fmt::Display for StubError {
        fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            Ok(())
        }
    }
    impl std::error::Error for StubError {}

    impl SocketEmitter for StubSockets {
        type AckError = StubError;
        type AckStream = StubAckStream;
        fn get_all_sids(&self, filter: impl Fn(&Sid) -> bool) -> Vec<Sid> {
            self.sockets
                .iter()
                .copied()
                .filter(|id| filter(id))
                .collect()
        }

        fn get_remote_sockets(&self, sids: BroadcastIter<'_>) -> Vec<RemoteSocketData> {
            sids.map(|id| RemoteSocketData {
                id,
                server_id: Uid::ZERO,
                ns: self.path.clone(),
            })
            .collect()
        }

        // Sending, acking and disconnecting are no-ops that always succeed.
        fn send_many(&self, _: BroadcastIter<'_>, _: Value) -> Result<(), Vec<SocketError>> {
            Ok(())
        }

        fn send_many_with_ack(
            &self,
            _: BroadcastIter<'_>,
            _: Packet,
            _: Option<Duration>,
        ) -> (Self::AckStream, u32) {
            (StubAckStream, 0)
        }

        fn disconnect_many(&self, _: Vec<Sid>) -> Result<(), Vec<SocketError>> {
            Ok(())
        }

        fn path(&self) -> &Str {
            &self.path
        }
        fn parser(&self) -> impl Parse {
            crate::parser::test::StubParser
        }
        fn server_id(&self) -> Uid {
            Uid::ZERO
        }
    }

    /// Builds a local adapter whose emitter knows exactly `sockets`.
    fn create_adapter<const S: usize>(sockets: [Sid; S]) -> CoreLocalAdapter<StubSockets> {
        CoreLocalAdapter::new(StubSockets::new(&sockets))
    }

    // `add_all` must update both the room → sockets and socket → rooms indexes.
    #[test]
    fn add_all() {
        let socket = Sid::new();
        let adapter = create_adapter([socket]);
        adapter.add_all(socket, ["room1", "room2"]);
        let rooms_map = adapter.rooms.read().unwrap();
        let socket_map = adapter.sockets.read().unwrap();
        assert_eq!(rooms_map.len(), 2);
        assert_eq!(socket_map.len(), 1);
        assert_eq!(rooms_map.get("room1").unwrap().len(), 1);
        assert_eq!(rooms_map.get("room2").unwrap().len(), 1);

        let rooms = socket_map.get(&socket).unwrap();
        assert!(rooms.contains("room1"));
        assert!(rooms.contains("room2"));
    }

    // `del` removes one room: the emptied room disappears, the other stays.
    #[test]
    fn del() {
        let socket = Sid::new();
        let adapter = create_adapter([socket]);
        adapter.add_all(socket, ["room1", "room2"]);
        {
            let rooms_map = adapter.rooms.read().unwrap();
            assert_eq!(rooms_map.len(), 2);
            assert_eq!(rooms_map.get("room1").unwrap().len(), 1);
            assert_eq!(rooms_map.get("room2").unwrap().len(), 1);
            let socket_map = adapter.sockets.read().unwrap();
            let rooms = socket_map.get(&socket).unwrap();
            assert!(rooms.contains("room1"));
            assert!(rooms.contains("room2"));
        }
        adapter.del(socket, "room1");
        let rooms_map = adapter.rooms.read().unwrap();
        let socket_map = adapter.sockets.read().unwrap();
        assert_eq!(rooms_map.len(), 1);
        assert!(rooms_map.get("room1").is_none());
        assert_eq!(rooms_map.get("room2").unwrap().len(), 1);
        assert_eq!(socket_map.get(&socket).unwrap().len(), 1);
    }
    // `del_all` removes the socket everywhere and drops the emptied rooms.
    #[test]
    fn del_all() {
        let socket = Sid::new();
        let adapter = create_adapter([socket]);
        adapter.add_all(socket, ["room1", "room2"]);

        {
            let rooms_map = adapter.rooms.read().unwrap();
            assert_eq!(rooms_map.len(), 2);
            assert_eq!(rooms_map.get("room1").unwrap().len(), 1);
            assert_eq!(rooms_map.get("room2").unwrap().len(), 1);

            let socket_map = adapter.sockets.read().unwrap();
            let rooms = socket_map.get(&socket).unwrap();
            assert!(rooms.contains("room1"));
            assert!(rooms.contains("room2"));
        }

        adapter.del_all(socket);

        {
            let rooms_map = adapter.rooms.read().unwrap();
            assert_eq!(rooms_map.len(), 0);

            let socket_map = adapter.sockets.read().unwrap();
            assert!(socket_map.get(&socket).is_none());
        }
    }

    // `socket_rooms` reports exactly the rooms each socket joined.
    #[test]
    fn socket_room() {
        let sid1 = Sid::new();
        let sid2 = Sid::new();
        let sid3 = Sid::new();
        let adapter = create_adapter([sid1, sid2, sid3]);
        adapter.add_all(sid1, ["room1", "room2"]);
        adapter.add_all(sid2, ["room1"]);
        adapter.add_all(sid3, ["room2"]);
        assert!(adapter.socket_rooms(sid1).contains(&Cow::Borrowed("room1")));
        assert!(adapter.socket_rooms(sid1).contains(&Cow::Borrowed("room2")));
        assert_eq!(
            adapter.socket_rooms(sid2).into_iter().collect::<Vec<_>>(),
            ["room1"]
        );
        assert_eq!(
            adapter.socket_rooms(sid3).into_iter().collect::<Vec<_>>(),
            ["room2"]
        );
    }

    // `add_sockets` adds the members of room1 to room2.
    #[test]
    fn add_socket() {
        let socket = Sid::new();
        let adapter = create_adapter([socket]);
        adapter.add_all(socket, ["room1"]);

        let mut opts = BroadcastOptions::new(socket);
        opts.rooms = smallvec!["room1".into()];
        adapter.add_sockets(opts, "room2");
        let rooms_map = adapter.rooms.read().unwrap();

        assert_eq!(rooms_map.len(), 2);
        assert!(rooms_map.get("room1").unwrap().contains(&socket));
        assert!(rooms_map.get("room2").unwrap().contains(&socket));
    }

    // `del_sockets` removes the members of room1 from room2, dropping room2.
    #[test]
    fn del_socket() {
        let socket = Sid::new();
        let adapter = create_adapter([socket]);
        adapter.add_all(socket, ["room1"]);

        let mut opts = BroadcastOptions::new(socket);
        opts.rooms = smallvec!["room1".into()];
        adapter.add_sockets(opts, "room2");

        {
            let rooms_map = adapter.rooms.read().unwrap();

            assert_eq!(rooms_map.len(), 2);
            assert!(rooms_map.get("room1").unwrap().contains(&socket));
            assert!(rooms_map.get("room2").unwrap().contains(&socket));
        }

        let mut opts = BroadcastOptions::new(socket);
        opts.rooms = smallvec!["room1".into()];
        adapter.del_sockets(opts, "room2");

        {
            let rooms_map = adapter.rooms.read().unwrap();

            assert_eq!(rooms_map.len(), 1);
            assert!(rooms_map.get("room1").unwrap().contains(&socket));
            assert!(rooms_map.get("room2").is_none());
        }
    }

    // `sockets` returns the members of the targeted room.
    #[test]
    fn sockets() {
        let socket0 = Sid::new();
        let socket1 = Sid::new();
        let socket2 = Sid::new();
        let adapter = create_adapter([socket0, socket1, socket2]);
        adapter.add_all(socket0, ["room1", "room2"]);
        adapter.add_all(socket1, ["room1", "room3"]);
        adapter.add_all(socket2, ["room2", "room3"]);

        let mut opts = BroadcastOptions {
            rooms: smallvec!["room1".into()],
            ..Default::default()
        };
        let sockets = adapter.sockets(opts.clone());
        assert_eq!(sockets.len(), 2);
        assert!(sockets.contains(&socket0));
        assert!(sockets.contains(&socket1));

        opts.rooms = smallvec!["room2".into()];
        let sockets = adapter.sockets(opts.clone());
        assert_eq!(sockets.len(), 2);
        assert!(sockets.contains(&socket0));
        assert!(sockets.contains(&socket2));

        opts.rooms = smallvec!["room3".into()];
        let sockets = adapter.sockets(opts.clone());
        assert_eq!(sockets.len(), 2);
        assert!(sockets.contains(&socket1));
        assert!(sockets.contains(&socket2));
    }

    // The stub emitter's disconnect is a no-op, so room membership is unchanged.
    #[test]
    fn disconnect_socket() {
        let socket0 = Sid::new();
        let socket1 = Sid::new();
        let socket2 = Sid::new();
        let adapter = create_adapter([socket0, socket1, socket2]);
        adapter.add_all(socket0, ["room1", "room2", "room4"]);
        adapter.add_all(socket1, ["room1", "room3", "room5"]);
        adapter.add_all(socket2, ["room2", "room3", "room6"]);

        let mut opts = BroadcastOptions::new(socket0);
        opts.rooms = smallvec!["room5".into()];
        adapter.disconnect_socket(opts).unwrap();

        let mut opts = BroadcastOptions::default();
        opts.rooms.push("room2".into());
        let sockets = adapter.sockets(opts.clone());
        assert_eq!(sockets.len(), 2);
        assert!(sockets.contains(&socket2));
        assert!(sockets.contains(&socket0));
    }
    // Options selecting nothing must not fail.
    #[test]
    fn disconnect_empty_opts() {
        let adapter = create_adapter([]);
        let opts = BroadcastOptions::default();
        adapter.disconnect_socket(opts).unwrap();
    }
    // `rooms` returns the union of rooms of the selected sockets.
    #[test]
    fn rooms() {
        let socket0 = Sid::new();
        let socket1 = Sid::new();
        let socket2 = Sid::new();
        let adapter = create_adapter([socket0, socket1, socket2]);
        adapter.add_all(socket0, ["room1", "room2", "room4"]);
        adapter.add_all(socket1, ["room1", "room3", "room5"]);
        adapter.add_all(socket2, ["room2", "room3", "room6"]);

        let mut opts = BroadcastOptions::new(socket0);
        opts.rooms = smallvec!["room5".into()];
        opts.add_flag(BroadcastFlags::Broadcast);
        let rooms = adapter.rooms(opts);
        assert_eq!(rooms.len(), 3);
        assert!(rooms.contains(&Cow::Borrowed("room1")));
        assert!(rooms.contains(&Cow::Borrowed("room3")));
        assert!(rooms.contains(&Cow::Borrowed("room5")));

        let mut opts = BroadcastOptions::default();
        opts.rooms.push("room2".into());
        let rooms = adapter.rooms(opts.clone());
        assert_eq!(rooms.len(), 5);
        assert!(rooms.contains(&Cow::Borrowed("room1")));
        assert!(rooms.contains(&Cow::Borrowed("room2")));
        assert!(rooms.contains(&Cow::Borrowed("room3")));
        assert!(rooms.contains(&Cow::Borrowed("room4")));
        assert!(rooms.contains(&Cow::Borrowed("room6")));
    }

    // Covers each `apply_opts` resolution path: rooms + except, global broadcast,
    // broadcast + except, and single sid (known or unknown to the emitter).
    #[test]
    fn apply_opts() {
        let mut sockets: [Sid; 3] = array::from_fn(|_| Sid::new());
        sockets.sort();
        let adapter = create_adapter(sockets);

        adapter.add_all(sockets[0], ["room1", "room2"]);
        adapter.add_all(sockets[1], ["room1", "room3"]);
        adapter.add_all(sockets[2], ["room1", "room2", "room3"]);

        let mut opts = BroadcastOptions::new(sockets[2]);
        opts.rooms = smallvec!["room1".into()];
        opts.except = smallvec!["room2".into()];
        let sids = adapter
            .apply_opts(&opts, &adapter.rooms.read().unwrap())
            .collect::<Vec<_>>();
        assert_eq!(sids, [sockets[1]]);

        let mut opts = BroadcastOptions::new(sockets[2]);
        opts.add_flag(BroadcastFlags::Broadcast);
        let mut sids = adapter
            .apply_opts(&opts, &adapter.rooms.read().unwrap())
            .collect::<Vec<_>>();
        sids.sort();
        assert_eq!(sids, [sockets[0], sockets[1]]);

        let mut opts = BroadcastOptions::new(sockets[2]);
        opts.add_flag(BroadcastFlags::Broadcast);
        opts.except = smallvec!["room2".into()];
        let sids = adapter
            .apply_opts(&opts, &adapter.rooms.read().unwrap())
            .collect::<Vec<_>>();
        assert_eq!(sids.len(), 1);

        let opts = BroadcastOptions::new(sockets[2]);
        let sids = adapter
            .apply_opts(&opts, &adapter.rooms.read().unwrap())
            .collect::<Vec<_>>();
        assert_eq!(sids.len(), 1);
        assert_eq!(sids[0], sockets[2]);

        // A single sid is returned as-is, even if the emitter does not know it.
        let opts = BroadcastOptions::new(Sid::new());
        let sids = adapter
            .apply_opts(&opts, &adapter.rooms.read().unwrap())
            .collect::<Vec<_>>();
        assert_eq!(sids.len(), 1);
    }
}
1062}