wsio_server/namespace/mod.rs

1use std::sync::Arc;
2
3use anyhow::Result;
4use futures_util::{
5    SinkExt,
6    StreamExt,
7};
8use http::{
9    HeaderMap,
10    Uri,
11};
12use hyper::upgrade::{
13    OnUpgrade,
14    Upgraded,
15};
16use hyper_util::rt::TokioIo;
17use kikiutils::{
18    atomic::enum_cell::AtomicEnumCell,
19    types::fx_collections::{
20        FxDashMap,
21        FxDashSet,
22    },
23};
24use num_enum::{
25    IntoPrimitive,
26    TryFromPrimitive,
27};
28use serde::Serialize;
29use tokio::{
30    join,
31    select,
32    spawn,
33    sync::Mutex,
34    task::JoinSet,
35};
36use tokio_tungstenite::{
37    WebSocketStream,
38    tungstenite::{
39        Message,
40        protocol::Role,
41    },
42};
43
44pub(crate) mod builder;
45mod config;
46pub mod operators;
47
48use self::{
49    config::WsIoServerNamespaceConfig,
50    operators::broadcast::WsIoServerNamespaceBroadcastOperator,
51};
52use crate::{
53    WsIoServer,
54    connection::WsIoServerConnection,
55    core::packet::WsIoPacket,
56    runtime::{
57        WsIoServerRuntime,
58        WsIoServerRuntimeStatus,
59    },
60};
61
62// Enums
63#[repr(u8)]
64#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
65enum NamespaceStatus {
66    Running,
67    Stopped,
68    Stopping,
69}
70
// Structs

/// A WebSocket namespace bound to a single path (`config.path`).
///
/// It tracks its live connections and room memberships, and owns the tasks that drive
/// each upgraded connection.
// NOTE: fields drop in declaration order; keep the order stable unless that is intended.
pub struct WsIoServerNamespace {
    /// Namespace configuration (path, WebSocket config and packet codec are read in this module).
    pub(crate) config: WsIoServerNamespaceConfig,
    /// Live connections keyed by connection id.
    connections: FxDashMap<u64, Arc<WsIoServerConnection>>,
    /// Tasks serving upgraded connections; drained by `shutdown`.
    connection_task_set: Mutex<JoinSet<()>>,
    /// Room name -> ids of member connections.
    ///
    /// Each member set sits behind an `Arc` so it can be cloned out and used after the
    /// map's shard lock is released.
    rooms: FxDashMap<String, Arc<FxDashSet<u64>>>,
    /// Shared server runtime; also records connection ids server-wide.
    runtime: Arc<WsIoServerRuntime>,
    /// Lifecycle state (`Running` -> `Stopping` -> `Stopped`).
    status: AtomicEnumCell<NamespaceStatus>,
}
80
81impl WsIoServerNamespace {
82    fn new(config: WsIoServerNamespaceConfig, runtime: Arc<WsIoServerRuntime>) -> Arc<Self> {
83        Arc::new(Self {
84            config,
85            connections: FxDashMap::default(),
86            connection_task_set: Mutex::new(JoinSet::new()),
87            rooms: FxDashMap::default(),
88            runtime,
89            status: AtomicEnumCell::new(NamespaceStatus::Running),
90        })
91    }
92
93    // Private methods
94    async fn handle_upgraded_request(
95        self: &Arc<Self>,
96        headers: HeaderMap,
97        request_uri: Uri,
98        upgraded: Upgraded,
99    ) -> Result<()> {
100        // Create ws stream
101        let mut ws_stream =
102            WebSocketStream::from_raw_socket(TokioIo::new(upgraded), Role::Server, Some(self.config.websocket_config))
103                .await;
104
105        // Check runtime and namespace status
106        if !self.runtime.status.is(WsIoServerRuntimeStatus::Running) || !self.status.is(NamespaceStatus::Running) {
107            ws_stream
108                .send((*self.encode_packet_to_message(&WsIoPacket::new_disconnect())?).clone())
109                .await?;
110
111            let _ = ws_stream.close(None).await;
112            return Ok(());
113        }
114
115        // Create connection
116        let (connection, mut message_rx) = WsIoServerConnection::new(headers, self.clone(), request_uri);
117
118        // Split ws stream and spawn read and write tasks
119        let (mut ws_stream_writer, mut ws_stream_reader) = ws_stream.split();
120        let connection_clone = connection.clone();
121        let mut read_ws_stream_task = spawn(async move {
122            while let Some(message) = ws_stream_reader.next().await {
123                if match message {
124                    Ok(Message::Binary(bytes)) => {
125                        // Treat any single-byte binary frame as a client heartbeat and ignore it
126                        if bytes.len() == 1 {
127                            continue;
128                        }
129
130                        connection_clone.handle_incoming_packet(&bytes).await
131                    }
132                    Ok(Message::Close(_)) => break,
133                    Ok(Message::Text(text)) => connection_clone.handle_incoming_packet(text.as_bytes()).await,
134                    Err(_) => break,
135                    _ => Ok(()),
136                }
137                .is_err()
138                {
139                    break;
140                }
141            }
142        });
143
144        let mut write_ws_stream_task = spawn(async move {
145            while let Some(message) = message_rx.recv().await {
146                let message = (*message).clone();
147                let is_close = matches!(message, Message::Close(_));
148                if ws_stream_writer.send(message).await.is_err() {
149                    break;
150                }
151
152                if is_close {
153                    let _ = ws_stream_writer.close().await;
154                    break;
155                }
156            }
157        });
158
159        // Try to init connection
160        match connection.init().await {
161            Ok(_) => {
162                // Wait for either read or write task to finish
163                select! {
164                    _ = &mut read_ws_stream_task => {
165                        write_ws_stream_task.abort();
166                    },
167                    _ = &mut write_ws_stream_task => {
168                        read_ws_stream_task.abort();
169                    },
170                }
171            }
172            Err(_) => {
173                // Close connection
174                read_ws_stream_task.abort();
175                connection.close();
176                let _ = join!(read_ws_stream_task, write_ws_stream_task);
177            }
178        }
179
180        // Cleanup connection
181        connection.cleanup().await;
182        Ok(())
183    }
184
185    // Protected methods
186    #[inline]
187    pub(crate) fn add_connection_id_to_room(&self, room_name: &str, connection_id: u64) {
188        self.rooms
189            .entry(room_name.into())
190            .or_default()
191            .clone()
192            .insert(connection_id);
193    }
194
195    #[inline]
196    pub(crate) fn encode_packet_to_message(&self, packet: &WsIoPacket) -> Result<Arc<Message>> {
197        let bytes = self.config.packet_codec.encode(packet)?;
198        Ok(Arc::new(match self.config.packet_codec.is_text() {
199            true => Message::Text(unsafe { String::from_utf8_unchecked(bytes).into() }),
200            false => Message::Binary(bytes.into()),
201        }))
202    }
203
204    pub(crate) async fn handle_on_upgrade_request(
205        self: &Arc<Self>,
206        headers: HeaderMap,
207        on_upgrade: OnUpgrade,
208        request_uri: Uri,
209    ) {
210        let namespace = self.clone();
211        self.connection_task_set.lock().await.spawn(async move {
212            if let Ok(upgraded) = on_upgrade.await {
213                let _ = namespace.handle_upgraded_request(headers, request_uri, upgraded).await;
214            }
215        });
216    }
217
218    #[inline]
219    pub(crate) fn insert_connection(&self, connection: Arc<WsIoServerConnection>) {
220        self.connections.insert(connection.id(), connection.clone());
221        self.runtime.insert_connection_id(connection.id());
222    }
223
224    #[inline]
225    pub(crate) fn remove_connection(&self, id: u64) {
226        self.connections.remove(&id);
227        self.runtime.remove_connection_id(id);
228    }
229
230    #[inline]
231    pub(crate) fn remove_connection_id_from_room(&self, room_name: &str, connection_id: u64) {
232        if let Some(room) = self.rooms.get(room_name).map(|entry| entry.clone()) {
233            room.remove(&connection_id);
234            if room.is_empty() {
235                self.rooms.remove(room_name);
236            }
237        }
238    }
239
240    // Public methods
241    pub async fn close_all(self: &Arc<Self>) {
242        WsIoServerNamespaceBroadcastOperator::new(self.clone()).close().await;
243    }
244
245    #[inline]
246    pub fn connection_count(&self) -> usize {
247        self.connections.len()
248    }
249
250    pub async fn disconnect_all(self: &Arc<Self>) -> Result<()> {
251        WsIoServerNamespaceBroadcastOperator::new(self.clone())
252            .disconnect()
253            .await
254    }
255
256    pub async fn emit<D: Serialize>(self: &Arc<Self>, event: impl AsRef<str>, data: Option<&D>) -> Result<()> {
257        WsIoServerNamespaceBroadcastOperator::new(self.clone())
258            .emit(event, data)
259            .await
260    }
261
262    #[inline]
263    pub fn except(
264        self: &Arc<Self>,
265        room_names: impl IntoIterator<Item = impl Into<String>>,
266    ) -> WsIoServerNamespaceBroadcastOperator {
267        WsIoServerNamespaceBroadcastOperator::new(self.clone()).except(room_names)
268    }
269
270    #[inline]
271    pub fn path(&self) -> &str {
272        &self.config.path
273    }
274
275    #[inline]
276    pub fn server(&self) -> WsIoServer {
277        WsIoServer(self.runtime.clone())
278    }
279
280    pub async fn shutdown(self: &Arc<Self>) {
281        match self.status.get() {
282            NamespaceStatus::Stopped => return,
283            NamespaceStatus::Running => self.status.store(NamespaceStatus::Stopping),
284            _ => unreachable!(),
285        }
286
287        self.close_all().await;
288        let mut connection_task_set = self.connection_task_set.lock().await;
289        while connection_task_set.join_next().await.is_some() {}
290
291        self.status.store(NamespaceStatus::Stopped);
292    }
293
294    #[inline]
295    pub fn to(
296        self: &Arc<Self>,
297        room_names: impl IntoIterator<Item = impl Into<String>>,
298    ) -> WsIoServerNamespaceBroadcastOperator {
299        WsIoServerNamespaceBroadcastOperator::new(self.clone()).to(room_names)
300    }
301}