wsio_server/namespace/operators/broadcast.rs

1use std::{
2    collections::HashSet,
3    sync::Arc,
4};
5
6use anyhow::Result;
7use futures_util::{
8    StreamExt,
9    future::ready,
10    stream::iter,
11};
12use serde::Serialize;
13
14use super::super::{
15    NamespaceStatus,
16    WsIoServerNamespace,
17};
18use crate::{
19    connection::WsIoServerConnection,
20    core::{
21        packet::WsIoPacket,
22        types::hashers::FxHashSet,
23    },
24};
25
26// Structs
27pub struct WsIoServerNamespaceBroadcastOperator {
28    exclude_connection_ids: HashSet<u64>,
29    exclude_rooms: HashSet<String>,
30    include_rooms: HashSet<String>,
31    namespace: Arc<WsIoServerNamespace>,
32}
33
34impl WsIoServerNamespaceBroadcastOperator {
35    #[inline]
36    pub(in super::super) fn new(namespace: Arc<WsIoServerNamespace>) -> Self {
37        Self {
38            exclude_connection_ids: HashSet::new(),
39            exclude_rooms: HashSet::new(),
40            include_rooms: HashSet::new(),
41            namespace,
42        }
43    }
44
45    // Private methods
46    async fn for_each_target_connections<F, Fut>(&self, f: F)
47    where
48        F: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
49        Fut: Future<Output = Result<()>> + Send + 'static,
50    {
51        let mut target_connection_ids = FxHashSet::default();
52        if self.include_rooms.is_empty() {
53            target_connection_ids.extend(self.namespace.connections.iter().map(|entry| *entry.key()));
54        } else {
55            for room_name in &self.include_rooms {
56                if let Some(room) = self.namespace.rooms.get(room_name) {
57                    target_connection_ids.extend(room.iter().map(|entry| *entry.key()));
58                }
59            }
60        };
61
62        for room_name in &self.exclude_rooms {
63            if let Some(room) = self.namespace.rooms.get(room_name) {
64                for entry in room.iter() {
65                    target_connection_ids.remove(&entry);
66                }
67            }
68        }
69
70        for exclude_connection_id in &self.exclude_connection_ids {
71            target_connection_ids.remove(&exclude_connection_id);
72        }
73
74        iter(target_connection_ids)
75            .filter_map(|target_connection_id| {
76                ready(
77                    self.namespace
78                        .connections
79                        .get(&target_connection_id)
80                        .map(|entry| entry.value().clone()),
81                )
82            })
83            .for_each_concurrent(self.namespace.config.broadcast_concurrency_limit, |connection| async {
84                let _ = f(connection).await;
85            })
86            .await;
87    }
88
89    // Public methods
90    pub async fn disconnect(&self) -> Result<()> {
91        let message = self.namespace.encode_packet_to_message(&WsIoPacket::new_disconnect())?;
92        self.for_each_target_connections(move |connection| {
93            let message = message.clone();
94            async move { connection.send_message(message).await }
95        })
96        .await;
97
98        Ok(())
99    }
100
101    pub async fn emit<D: Serialize>(&self, event: impl AsRef<str>, data: Option<&D>) -> Result<()> {
102        self.namespace.status.ensure(NamespaceStatus::Running, |status| {
103            format!("Cannot emit in invalid status: {status:?}")
104        })?;
105
106        let message = self.namespace.encode_packet_to_message(&WsIoPacket::new_event(
107            event.as_ref(),
108            data.map(|data| self.namespace.config.packet_codec.encode_data(data))
109                .transpose()?,
110        ))?;
111
112        self.for_each_target_connections(move |connection| {
113            let message = message.clone();
114            async move { connection.emit_event_message(message).await }
115        })
116        .await;
117
118        Ok(())
119    }
120
121    #[inline]
122    pub fn except(mut self, room_names: impl IntoIterator<Item = impl AsRef<str>>) -> Self {
123        self.exclude_rooms
124            .extend(room_names.into_iter().map(|room_name| room_name.as_ref().into()));
125
126        self
127    }
128
129    pub fn except_connection_ids(mut self, connection_ids: impl IntoIterator<Item = u64>) -> Self {
130        self.exclude_connection_ids.extend(connection_ids);
131        self
132    }
133
134    #[inline]
135    pub fn to(mut self, room_names: impl IntoIterator<Item = impl AsRef<str>>) -> Self {
136        self.include_rooms
137            .extend(room_names.into_iter().map(|room_name| room_name.as_ref().into()));
138
139        self
140    }
141}