wsio_server/namespace/
mod.rs

1use std::sync::Arc;
2
3use anyhow::Result;
4use futures_util::{
5    SinkExt,
6    StreamExt,
7};
8use http::{
9    HeaderMap,
10    Uri,
11};
12use hyper::upgrade::{
13    OnUpgrade,
14    Upgraded,
15};
16use hyper_util::rt::TokioIo;
17use num_enum::{
18    IntoPrimitive,
19    TryFromPrimitive,
20};
21use serde::Serialize;
22use tokio::{
23    join,
24    select,
25    spawn,
26    sync::Mutex,
27    task::JoinSet,
28};
29use tokio_tungstenite::{
30    WebSocketStream,
31    tungstenite::{
32        Message,
33        protocol::Role,
34    },
35};
36
37pub(crate) mod builder;
38mod config;
39pub mod operators;
40
41use self::{
42    config::WsIoServerNamespaceConfig,
43    operators::broadcast::WsIoServerNamespaceBroadcastOperator,
44};
45use crate::{
46    WsIoServer,
47    connection::WsIoServerConnection,
48    core::{
49        atomic::status::AtomicStatus,
50        packet::WsIoPacket,
51        types::hashers::{
52            FxDashMap,
53            FxDashSet,
54        },
55    },
56    runtime::{
57        WsIoServerRuntime,
58        WsIoServerRuntimeStatus,
59    },
60};
61
62// Enums
63#[repr(u8)]
64#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
65enum NamespaceStatus {
66    Running,
67    Stopped,
68    Stopping,
69}
70
71// Structs
72pub struct WsIoServerNamespace {
73    pub(crate) config: WsIoServerNamespaceConfig,
74    connections: FxDashMap<u64, Arc<WsIoServerConnection>>,
75    connection_task_set: Mutex<JoinSet<()>>,
76    rooms: FxDashMap<String, Arc<FxDashSet<u64>>>,
77    runtime: Arc<WsIoServerRuntime>,
78    status: AtomicStatus<NamespaceStatus>,
79}
80
81impl WsIoServerNamespace {
82    fn new(config: WsIoServerNamespaceConfig, runtime: Arc<WsIoServerRuntime>) -> Arc<Self> {
83        Arc::new(Self {
84            config,
85            connections: FxDashMap::default(),
86            connection_task_set: Mutex::new(JoinSet::new()),
87            rooms: FxDashMap::default(),
88            runtime,
89            status: AtomicStatus::new(NamespaceStatus::Running),
90        })
91    }
92
93    // Private methods
94    async fn handle_upgraded_request(
95        self: &Arc<Self>,
96        headers: HeaderMap,
97        request_uri: Uri,
98        upgraded: Upgraded,
99    ) -> Result<()> {
100        // Create ws stream
101        let mut ws_stream =
102            WebSocketStream::from_raw_socket(TokioIo::new(upgraded), Role::Server, Some(self.config.websocket_config))
103                .await;
104
105        // Check runtime and namespace status
106        if !self.runtime.status.is(WsIoServerRuntimeStatus::Running) || !self.status.is(NamespaceStatus::Running) {
107            ws_stream
108                .send((*self.encode_packet_to_message(&WsIoPacket::new_disconnect())?).clone())
109                .await?;
110
111            let _ = ws_stream.close(None).await;
112            return Ok(());
113        }
114
115        // Create connection
116        let (connection, mut message_rx) = WsIoServerConnection::new(headers, self.clone(), request_uri);
117
118        // Split ws stream and spawn read and write tasks
119        let (mut ws_stream_writer, mut ws_stream_reader) = ws_stream.split();
120        let connection_clone = connection.clone();
121        let mut read_ws_stream_task = spawn(async move {
122            while let Some(message) = ws_stream_reader.next().await {
123                if match message {
124                    Ok(Message::Binary(bytes)) => connection_clone.handle_incoming_packet(&bytes).await,
125                    Ok(Message::Close(_)) => break,
126                    Ok(Message::Text(text)) => connection_clone.handle_incoming_packet(text.as_bytes()).await,
127                    Err(_) => break,
128                    _ => Ok(()),
129                }
130                .is_err()
131                {
132                    break;
133                }
134            }
135        });
136
137        let mut write_ws_stream_task = spawn(async move {
138            while let Some(message) = message_rx.recv().await {
139                let message = (*message).clone();
140                let is_close = matches!(message, Message::Close(_));
141                if ws_stream_writer.send(message).await.is_err() {
142                    break;
143                }
144
145                if is_close {
146                    let _ = ws_stream_writer.close().await;
147                    break;
148                }
149            }
150        });
151
152        // Try to init connection
153        match connection.init().await {
154            Ok(_) => {
155                // Wait for either read or write task to finish
156                select! {
157                    _ = &mut read_ws_stream_task => {
158                        write_ws_stream_task.abort();
159                    },
160                    _ = &mut write_ws_stream_task => {
161                        read_ws_stream_task.abort();
162                    },
163                }
164            }
165            Err(_) => {
166                // Close connection
167                read_ws_stream_task.abort();
168                connection.close();
169                let _ = join!(read_ws_stream_task, write_ws_stream_task);
170            }
171        }
172
173        // Cleanup connection
174        connection.cleanup().await;
175        Ok(())
176    }
177
178    // Protected methods
179    #[inline]
180    pub(crate) fn add_connection_id_to_room(&self, room_name: &str, connection_id: u64) {
181        self.rooms
182            .entry(room_name.into())
183            .or_default()
184            .clone()
185            .insert(connection_id);
186    }
187
188    #[inline]
189    pub(crate) fn encode_packet_to_message(&self, packet: &WsIoPacket) -> Result<Arc<Message>> {
190        let bytes = self.config.packet_codec.encode(packet)?;
191        Ok(Arc::new(match self.config.packet_codec.is_text() {
192            true => Message::Text(unsafe { String::from_utf8_unchecked(bytes).into() }),
193            false => Message::Binary(bytes.into()),
194        }))
195    }
196
197    pub(crate) async fn handle_on_upgrade_request(
198        self: &Arc<Self>,
199        headers: HeaderMap,
200        on_upgrade: OnUpgrade,
201        request_uri: Uri,
202    ) {
203        let namespace = self.clone();
204        self.connection_task_set.lock().await.spawn(async move {
205            if let Ok(upgraded) = on_upgrade.await {
206                let _ = namespace.handle_upgraded_request(headers, request_uri, upgraded).await;
207            }
208        });
209    }
210
211    #[inline]
212    pub(crate) fn insert_connection(&self, connection: Arc<WsIoServerConnection>) {
213        self.connections.insert(connection.id(), connection.clone());
214        self.runtime.insert_connection_id(connection.id());
215    }
216
217    #[inline]
218    pub(crate) fn remove_connection(&self, id: u64) {
219        self.connections.remove(&id);
220        self.runtime.remove_connection_id(id);
221    }
222
223    #[inline]
224    pub(crate) fn remove_connection_id_from_room(&self, room_name: &str, connection_id: u64) {
225        if let Some(room) = self.rooms.get(room_name).map(|entry| entry.clone()) {
226            room.remove(&connection_id);
227            if room.is_empty() {
228                self.rooms.remove(room_name);
229            }
230        }
231    }
232
233    // Public methods
234    #[inline]
235    pub fn connection_count(&self) -> usize {
236        self.connections.len()
237    }
238
239    pub async fn emit<D: Serialize>(self: &Arc<Self>, event: impl AsRef<str>, data: Option<&D>) -> Result<()> {
240        WsIoServerNamespaceBroadcastOperator::new(self.clone())
241            .emit(event, data)
242            .await
243    }
244
245    #[inline]
246    pub fn except<I: IntoIterator<Item = S>, S: AsRef<str>>(
247        self: &Arc<Self>,
248        room_names: I,
249    ) -> WsIoServerNamespaceBroadcastOperator {
250        WsIoServerNamespaceBroadcastOperator::new(self.clone()).except(room_names)
251    }
252
253    #[inline]
254    pub fn path(&self) -> &str {
255        &self.config.path
256    }
257
258    #[inline]
259    pub fn server(&self) -> WsIoServer {
260        WsIoServer(self.runtime.clone())
261    }
262
263    pub async fn shutdown(self: &Arc<Self>) {
264        match self.status.get() {
265            NamespaceStatus::Stopped => return,
266            NamespaceStatus::Running => self.status.store(NamespaceStatus::Stopping),
267            _ => unreachable!(),
268        }
269
270        let _ = WsIoServerNamespaceBroadcastOperator::new(self.clone())
271            .disconnect()
272            .await;
273
274        let mut connection_task_set = self.connection_task_set.lock().await;
275        while connection_task_set.join_next().await.is_some() {}
276
277        self.status.store(NamespaceStatus::Stopped);
278    }
279
280    #[inline]
281    pub fn to<I: IntoIterator<Item = S>, S: AsRef<str>>(
282        self: &Arc<Self>,
283        room_names: I,
284    ) -> WsIoServerNamespaceBroadcastOperator {
285        WsIoServerNamespaceBroadcastOperator::new(self.clone()).to(room_names)
286    }
287}