// wsio_server/namespace/mod.rs

1use std::sync::Arc;
2
3use anyhow::Result;
4use dashmap::DashMap;
5use futures_util::{
6    SinkExt,
7    StreamExt,
8    future::join_all,
9};
10use http::HeaderMap;
11use hyper::upgrade::{
12    OnUpgrade,
13    Upgraded,
14};
15use hyper_util::rt::TokioIo;
16use num_enum::{
17    IntoPrimitive,
18    TryFromPrimitive,
19};
20use serde::Serialize;
21use tokio::{
22    join,
23    select,
24    spawn,
25    sync::Mutex,
26    task::JoinSet,
27};
28use tokio_tungstenite::{
29    WebSocketStream,
30    tungstenite::{
31        Message,
32        protocol::Role,
33    },
34};
35
36pub(crate) mod builder;
37mod config;
38
39use self::config::WsIoServerNamespaceConfig;
40use crate::{
41    WsIoServer,
42    connection::WsIoServerConnection,
43    core::{
44        atomic::status::AtomicStatus,
45        packet::WsIoPacket,
46    },
47    runtime::{
48        WsIoServerRuntime,
49        WsIoServerRuntimeStatus,
50    },
51};
52
53// Enums
54#[repr(u8)]
55#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
56enum NamespaceStatus {
57    Running,
58    Stopped,
59    Stopping,
60}
61
62// Structs
63pub struct WsIoServerNamespace {
64    pub(crate) config: WsIoServerNamespaceConfig,
65    connections: DashMap<u64, Arc<WsIoServerConnection>>,
66    connection_task_set: Mutex<JoinSet<()>>,
67    runtime: Arc<WsIoServerRuntime>,
68    status: AtomicStatus<NamespaceStatus>,
69}
70
71impl WsIoServerNamespace {
72    fn new(config: WsIoServerNamespaceConfig, runtime: Arc<WsIoServerRuntime>) -> Arc<Self> {
73        Arc::new(Self {
74            config,
75            connections: DashMap::new(),
76            connection_task_set: Mutex::new(JoinSet::new()),
77            runtime,
78            status: AtomicStatus::new(NamespaceStatus::Running),
79        })
80    }
81
82    // Private methods
83    async fn handle_upgraded_request(self: &Arc<Self>, headers: HeaderMap, upgraded: Upgraded) -> Result<()> {
84        // Create ws stream
85        let mut ws_stream =
86            WebSocketStream::from_raw_socket(TokioIo::new(upgraded), Role::Server, Some(self.config.websocket_config))
87                .await;
88
89        // Check runtime and namespace status
90        if !self.runtime.status.is(WsIoServerRuntimeStatus::Running) || !self.status.is(NamespaceStatus::Running) {
91            ws_stream
92                .send(self.encode_packet_to_message(&WsIoPacket::new_disconnect())?)
93                .await?;
94
95            let _ = ws_stream.close(None).await;
96            return Ok(());
97        }
98
99        // Create connection
100        let (connection, mut message_rx) = WsIoServerConnection::new(headers, self.clone());
101
102        // Split ws stream and spawn read and write tasks
103        let (mut ws_stream_writer, mut ws_stream_reader) = ws_stream.split();
104        let connection_clone = connection.clone();
105        let mut read_ws_stream_task = spawn(async move {
106            while let Some(message) = ws_stream_reader.next().await {
107                if match message {
108                    Ok(Message::Binary(bytes)) => connection_clone.handle_incoming_packet(&bytes).await,
109                    Ok(Message::Close(_)) => break,
110                    Ok(Message::Text(text)) => connection_clone.handle_incoming_packet(text.as_bytes()).await,
111                    Err(_) => break,
112                    _ => Ok(()),
113                }
114                .is_err()
115                {
116                    break;
117                }
118            }
119        });
120
121        let mut write_ws_stream_task = spawn(async move {
122            while let Some(message) = message_rx.recv().await {
123                let is_close = matches!(message, Message::Close(_));
124                if ws_stream_writer.send(message).await.is_err() {
125                    break;
126                }
127
128                if is_close {
129                    let _ = ws_stream_writer.close().await;
130                    break;
131                }
132            }
133        });
134
135        // Try to init connection
136        match connection.init().await {
137            Ok(_) => {
138                // Wait for either read or write task to finish
139                select! {
140                    _ = &mut read_ws_stream_task => {
141                        write_ws_stream_task.abort();
142                    },
143                    _ = &mut write_ws_stream_task => {
144                        read_ws_stream_task.abort();
145                    },
146                }
147            }
148            Err(_) => {
149                // Close connection
150                read_ws_stream_task.abort();
151                connection.close();
152                let _ = join!(read_ws_stream_task, write_ws_stream_task);
153            }
154        }
155
156        // Cleanup connection
157        connection.cleanup().await;
158        Ok(())
159    }
160
161    // Protected methods
162    #[inline]
163    pub(crate) fn encode_packet_to_message(&self, packet: &WsIoPacket) -> Result<Message> {
164        let bytes = self.config.packet_codec.encode(packet)?;
165        Ok(match self.config.packet_codec.is_text() {
166            true => Message::Text(unsafe { String::from_utf8_unchecked(bytes).into() }),
167            false => Message::Binary(bytes.into()),
168        })
169    }
170
171    pub(crate) async fn handle_on_upgrade_request(self: &Arc<Self>, headers: HeaderMap, on_upgrade: OnUpgrade) {
172        let namespace = self.clone();
173        self.connection_task_set.lock().await.spawn(async move {
174            if let Ok(upgraded) = on_upgrade.await {
175                let _ = namespace.handle_upgraded_request(headers, upgraded).await;
176            }
177        });
178    }
179
180    #[inline]
181    pub(crate) fn insert_connection(&self, connection: Arc<WsIoServerConnection>) {
182        self.connections.insert(connection.id(), connection.clone());
183        self.runtime.insert_connection(&connection);
184    }
185
186    #[inline]
187    pub(crate) fn remove_connection(&self, id: u64) {
188        self.connections.remove(&id);
189        self.runtime.remove_connection(id);
190    }
191
192    // Public methods
193    #[inline]
194    pub fn connection_count(&self) -> usize {
195        self.connections.len()
196    }
197
198    pub async fn emit<D: Serialize>(&self, event: impl Into<String>, data: Option<&D>) -> Result<()> {
199        self.status.ensure(NamespaceStatus::Running, |status| {
200            format!("Cannot emit event in invalid status: {:#?}", status)
201        })?;
202
203        let message = self.encode_packet_to_message(&WsIoPacket::new_event(
204            event,
205            data.map(|data| self.config.packet_codec.encode_data(data))
206                .transpose()?,
207        ))?;
208
209        join_all(self.connections.iter().map(|entry| {
210            let connection = entry.value().clone();
211            let message = message.clone();
212            async move { connection.emit_message(message).await }
213        }))
214        .await;
215
216        Ok(())
217    }
218
219    #[inline]
220    pub fn path(&self) -> &str {
221        &self.config.path
222    }
223
224    #[inline]
225    pub fn server(&self) -> WsIoServer {
226        WsIoServer(self.runtime.clone())
227    }
228
229    pub async fn shutdown(&self) {
230        match self.status.get() {
231            NamespaceStatus::Stopped => return,
232            NamespaceStatus::Running => self.status.store(NamespaceStatus::Stopping),
233            _ => unreachable!(),
234        }
235
236        join_all(self.connections.iter().map(|entry| {
237            let connection = entry.value().clone();
238            async move { connection.disconnect().await }
239        }))
240        .await;
241
242        let mut connection_task_set = self.connection_task_set.lock().await;
243        while connection_task_set.join_next().await.is_some() {}
244
245        self.status.store(NamespaceStatus::Stopped);
246    }
247}