wsio_server/namespace/
mod.rs1use std::sync::Arc;
2
3use anyhow::Result;
4use dashmap::DashMap;
5use futures_util::{
6 SinkExt,
7 StreamExt,
8 future::join_all,
9};
10use http::HeaderMap;
11use hyper::upgrade::{
12 OnUpgrade,
13 Upgraded,
14};
15use hyper_util::rt::TokioIo;
16use num_enum::{
17 IntoPrimitive,
18 TryFromPrimitive,
19};
20use serde::Serialize;
21use tokio::{
22 join,
23 select,
24 spawn,
25 sync::Mutex,
26 task::JoinSet,
27};
28use tokio_tungstenite::{
29 WebSocketStream,
30 tungstenite::{
31 Message,
32 protocol::Role,
33 },
34};
35
36pub(crate) mod builder;
37mod config;
38
39use self::config::WsIoServerNamespaceConfig;
40use crate::{
41 WsIoServer,
42 connection::WsIoServerConnection,
43 core::{
44 atomic::status::AtomicStatus,
45 packet::WsIoPacket,
46 },
47 runtime::{
48 WsIoServerRuntime,
49 WsIoServerRuntimeStatus,
50 },
51};
52
53#[repr(u8)]
55#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
56enum NamespaceStatus {
57 Running,
58 Stopped,
59 Stopping,
60}
61
62pub struct WsIoServerNamespace {
64 pub(crate) config: WsIoServerNamespaceConfig,
65 connections: DashMap<u64, Arc<WsIoServerConnection>>,
66 connection_task_set: Mutex<JoinSet<()>>,
67 runtime: Arc<WsIoServerRuntime>,
68 status: AtomicStatus<NamespaceStatus>,
69}
70
71impl WsIoServerNamespace {
72 fn new(config: WsIoServerNamespaceConfig, runtime: Arc<WsIoServerRuntime>) -> Arc<Self> {
73 Arc::new(Self {
74 config,
75 connections: DashMap::new(),
76 connection_task_set: Mutex::new(JoinSet::new()),
77 runtime,
78 status: AtomicStatus::new(NamespaceStatus::Running),
79 })
80 }
81
82 async fn handle_upgraded_request(self: &Arc<Self>, headers: HeaderMap, upgraded: Upgraded) -> Result<()> {
84 let mut ws_stream =
86 WebSocketStream::from_raw_socket(TokioIo::new(upgraded), Role::Server, Some(self.config.websocket_config))
87 .await;
88
89 if !self.runtime.status.is(WsIoServerRuntimeStatus::Running) || !self.status.is(NamespaceStatus::Running) {
91 ws_stream
92 .send(self.encode_packet_to_message(&WsIoPacket::new_disconnect())?)
93 .await?;
94
95 let _ = ws_stream.close(None).await;
96 return Ok(());
97 }
98
99 let (connection, mut message_rx) = WsIoServerConnection::new(headers, self.clone());
101
102 let (mut ws_stream_writer, mut ws_stream_reader) = ws_stream.split();
104 let connection_clone = connection.clone();
105 let mut read_ws_stream_task = spawn(async move {
106 while let Some(message) = ws_stream_reader.next().await {
107 if match message {
108 Ok(Message::Binary(bytes)) => connection_clone.handle_incoming_packet(&bytes).await,
109 Ok(Message::Close(_)) => break,
110 Ok(Message::Text(text)) => connection_clone.handle_incoming_packet(text.as_bytes()).await,
111 Err(_) => break,
112 _ => Ok(()),
113 }
114 .is_err()
115 {
116 break;
117 }
118 }
119 });
120
121 let mut write_ws_stream_task = spawn(async move {
122 while let Some(message) = message_rx.recv().await {
123 let is_close = matches!(message, Message::Close(_));
124 if ws_stream_writer.send(message).await.is_err() {
125 break;
126 }
127
128 if is_close {
129 let _ = ws_stream_writer.close().await;
130 break;
131 }
132 }
133 });
134
135 match connection.init().await {
137 Ok(_) => {
138 select! {
140 _ = &mut read_ws_stream_task => {
141 write_ws_stream_task.abort();
142 },
143 _ = &mut write_ws_stream_task => {
144 read_ws_stream_task.abort();
145 },
146 }
147 }
148 Err(_) => {
149 read_ws_stream_task.abort();
151 connection.close();
152 let _ = join!(read_ws_stream_task, write_ws_stream_task);
153 }
154 }
155
156 connection.cleanup().await;
158 Ok(())
159 }
160
161 #[inline]
163 pub(crate) fn encode_packet_to_message(&self, packet: &WsIoPacket) -> Result<Message> {
164 let bytes = self.config.packet_codec.encode(packet)?;
165 Ok(match self.config.packet_codec.is_text() {
166 true => Message::Text(unsafe { String::from_utf8_unchecked(bytes).into() }),
167 false => Message::Binary(bytes.into()),
168 })
169 }
170
171 pub(crate) async fn handle_on_upgrade_request(self: &Arc<Self>, headers: HeaderMap, on_upgrade: OnUpgrade) {
172 let namespace = self.clone();
173 self.connection_task_set.lock().await.spawn(async move {
174 if let Ok(upgraded) = on_upgrade.await {
175 let _ = namespace.handle_upgraded_request(headers, upgraded).await;
176 }
177 });
178 }
179
180 #[inline]
181 pub(crate) fn insert_connection(&self, connection: Arc<WsIoServerConnection>) {
182 self.connections.insert(connection.id(), connection.clone());
183 self.runtime.insert_connection(&connection);
184 }
185
186 #[inline]
187 pub(crate) fn remove_connection(&self, id: u64) {
188 self.connections.remove(&id);
189 self.runtime.remove_connection(id);
190 }
191
192 #[inline]
194 pub fn connection_count(&self) -> usize {
195 self.connections.len()
196 }
197
198 pub async fn emit<D: Serialize>(&self, event: impl Into<String>, data: Option<&D>) -> Result<()> {
199 self.status.ensure(NamespaceStatus::Running, |status| {
200 format!("Cannot emit event in invalid status: {:#?}", status)
201 })?;
202
203 let message = self.encode_packet_to_message(&WsIoPacket::new_event(
204 event,
205 data.map(|data| self.config.packet_codec.encode_data(data))
206 .transpose()?,
207 ))?;
208
209 join_all(self.connections.iter().map(|entry| {
210 let connection = entry.value().clone();
211 let message = message.clone();
212 async move { connection.emit_message(message).await }
213 }))
214 .await;
215
216 Ok(())
217 }
218
219 #[inline]
220 pub fn path(&self) -> &str {
221 &self.config.path
222 }
223
224 #[inline]
225 pub fn server(&self) -> WsIoServer {
226 WsIoServer(self.runtime.clone())
227 }
228
229 pub async fn shutdown(&self) {
230 match self.status.get() {
231 NamespaceStatus::Stopped => return,
232 NamespaceStatus::Running => self.status.store(NamespaceStatus::Stopping),
233 _ => unreachable!(),
234 }
235
236 join_all(self.connections.iter().map(|entry| {
237 let connection = entry.value().clone();
238 async move { connection.disconnect().await }
239 }))
240 .await;
241
242 let mut connection_task_set = self.connection_task_set.lock().await;
243 while connection_task_set.join_next().await.is_some() {}
244
245 self.status.store(NamespaceStatus::Stopped);
246 }
247}