wsio_server/namespace/operators/
broadcast.rs1use std::{
2 collections::HashSet,
3 sync::Arc,
4};
5
6use anyhow::Result;
7use futures_util::{
8 StreamExt,
9 future::ready,
10 stream::iter,
11};
12use serde::Serialize;
13
14use super::super::{
15 NamespaceStatus,
16 WsIoServerNamespace,
17};
18use crate::{
19 connection::WsIoServerConnection,
20 core::{
21 packet::WsIoPacket,
22 types::hashers::FxHashSet,
23 },
24};
25
26pub struct WsIoServerNamespaceBroadcastOperator {
28 exclude_connection_ids: HashSet<u64>,
29 exclude_rooms: HashSet<String>,
30 include_rooms: HashSet<String>,
31 namespace: Arc<WsIoServerNamespace>,
32}
33
34impl WsIoServerNamespaceBroadcastOperator {
35 #[inline]
36 pub(in super::super) fn new(namespace: Arc<WsIoServerNamespace>) -> Self {
37 Self {
38 exclude_connection_ids: HashSet::new(),
39 exclude_rooms: HashSet::new(),
40 include_rooms: HashSet::new(),
41 namespace,
42 }
43 }
44
45 async fn for_each_target_connections<F, Fut>(&self, f: F)
47 where
48 F: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
49 Fut: Future<Output = Result<()>> + Send + 'static,
50 {
51 let mut target_connection_ids = FxHashSet::default();
52 if self.include_rooms.is_empty() {
53 target_connection_ids.extend(self.namespace.connections.iter().map(|entry| *entry.key()));
54 } else {
55 for room_name in &self.include_rooms {
56 if let Some(room) = self.namespace.rooms.get(room_name) {
57 target_connection_ids.extend(room.iter().map(|entry| *entry.key()));
58 }
59 }
60 };
61
62 for room_name in &self.exclude_rooms {
63 if let Some(room) = self.namespace.rooms.get(room_name) {
64 for entry in room.iter() {
65 target_connection_ids.remove(&entry);
66 }
67 }
68 }
69
70 for exclude_connection_id in &self.exclude_connection_ids {
71 target_connection_ids.remove(&exclude_connection_id);
72 }
73
74 iter(target_connection_ids)
75 .filter_map(|target_connection_id| {
76 ready(
77 self.namespace
78 .connections
79 .get(&target_connection_id)
80 .map(|entry| entry.value().clone()),
81 )
82 })
83 .for_each_concurrent(self.namespace.config.broadcast_concurrency_limit, |connection| async {
84 let _ = f(connection).await;
85 })
86 .await;
87 }
88
89 pub async fn disconnect(&self) -> Result<()> {
91 let message = self.namespace.encode_packet_to_message(&WsIoPacket::new_disconnect())?;
92 self.for_each_target_connections(move |connection| {
93 let message = message.clone();
94 async move { connection.send_message(message).await }
95 })
96 .await;
97
98 Ok(())
99 }
100
101 pub async fn emit<D: Serialize>(&self, event: impl AsRef<str>, data: Option<&D>) -> Result<()> {
102 self.namespace.status.ensure(NamespaceStatus::Running, |status| {
103 format!("Cannot emit in invalid status: {status:?}")
104 })?;
105
106 let message = self.namespace.encode_packet_to_message(&WsIoPacket::new_event(
107 event.as_ref(),
108 data.map(|data| self.namespace.config.packet_codec.encode_data(data))
109 .transpose()?,
110 ))?;
111
112 self.for_each_target_connections(move |connection| {
113 let message = message.clone();
114 async move { connection.emit_event_message(message).await }
115 })
116 .await;
117
118 Ok(())
119 }
120
121 #[inline]
122 pub fn except<I: IntoIterator<Item = S>, S: AsRef<str>>(mut self, room_names: I) -> Self {
123 self.exclude_rooms
124 .extend(room_names.into_iter().map(|room_name| room_name.as_ref().into()));
125
126 self
127 }
128
129 pub fn except_connection_ids<I: IntoIterator<Item = u64>>(mut self, connection_ids: I) -> Self {
130 self.exclude_connection_ids.extend(connection_ids);
131 self
132 }
133
134 #[inline]
135 pub fn to<I: IntoIterator<Item = S>, S: AsRef<str>>(mut self, room_names: I) -> Self {
136 self.include_rooms
137 .extend(room_names.into_iter().map(|room_name| room_name.as_ref().into()));
138
139 self
140 }
141}