1use std::sync::Arc;
2
3use anyhow::Result;
4use futures_util::{
5 SinkExt,
6 StreamExt,
7};
8use http::{
9 HeaderMap,
10 Uri,
11};
12use hyper::upgrade::{
13 OnUpgrade,
14 Upgraded,
15};
16use hyper_util::rt::TokioIo;
17use num_enum::{
18 IntoPrimitive,
19 TryFromPrimitive,
20};
21use serde::Serialize;
22use tokio::{
23 join,
24 select,
25 spawn,
26 sync::Mutex,
27 task::JoinSet,
28};
29use tokio_tungstenite::{
30 WebSocketStream,
31 tungstenite::{
32 Message,
33 protocol::Role,
34 },
35};
36
37pub(crate) mod builder;
38mod config;
39pub mod operators;
40
41use self::{
42 config::WsIoServerNamespaceConfig,
43 operators::broadcast::WsIoServerNamespaceBroadcastOperator,
44};
45use crate::{
46 WsIoServer,
47 connection::WsIoServerConnection,
48 core::{
49 atomic::status::AtomicStatus,
50 packet::WsIoPacket,
51 types::hashers::{
52 FxDashMap,
53 FxDashSet,
54 },
55 },
56 runtime::{
57 WsIoServerRuntime,
58 WsIoServerRuntimeStatus,
59 },
60};
61
/// Lifecycle status of a [`WsIoServerNamespace`].
///
/// The enum is `#[repr(u8)]` and derives `IntoPrimitive` / `TryFromPrimitive`, so it
/// converts to and from a `u8`. The declaration order fixes the discriminants
/// (`Running = 0`, `Stopped = 1`, `Stopping = 2`); do not reorder the variants.
#[repr(u8)]
#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
enum NamespaceStatus {
    /// Initial status assigned by `WsIoServerNamespace::new`; new connections are accepted.
    Running,
    /// Stored at the end of `shutdown`, once the connection tasks have been drained.
    Stopped,
    /// Stored at the start of `shutdown`, while connections are being disconnected.
    Stopping,
}
70
/// A namespace of the server: its configuration, live connections, rooms and the
/// tasks that drive each upgraded WebSocket.
pub struct WsIoServerNamespace {
    /// Namespace configuration (this file reads `path`, `packet_codec` and `websocket_config`).
    pub(crate) config: WsIoServerNamespaceConfig,
    /// Registered connections, keyed by connection id.
    connections: FxDashMap<u64, Arc<WsIoServerConnection>>,
    /// One task per upgrade request; drained by `shutdown`.
    connection_task_set: Mutex<JoinSet<()>>,
    /// Room name -> ids of the connections that joined that room.
    rooms: FxDashMap<String, Arc<FxDashSet<u64>>>,
    /// Shared server runtime this namespace belongs to.
    runtime: Arc<WsIoServerRuntime>,
    /// Current lifecycle status of the namespace.
    status: AtomicStatus<NamespaceStatus>,
}
80
81impl WsIoServerNamespace {
82 fn new(config: WsIoServerNamespaceConfig, runtime: Arc<WsIoServerRuntime>) -> Arc<Self> {
83 Arc::new(Self {
84 config,
85 connections: FxDashMap::default(),
86 connection_task_set: Mutex::new(JoinSet::new()),
87 rooms: FxDashMap::default(),
88 runtime,
89 status: AtomicStatus::new(NamespaceStatus::Running),
90 })
91 }
92
    /// Drives one upgraded HTTP connection as a WebSocket session of this namespace.
    ///
    /// Wraps `upgraded` in a server-role `WebSocketStream`. If the runtime or this
    /// namespace is not `Running`, the client is sent a disconnect packet and the
    /// stream is closed. Otherwise a `WsIoServerConnection` is created and frames are
    /// pumped in both directions by two spawned tasks until one of them finishes.
    /// `connection.cleanup()` is awaited before returning on the accepted path.
    ///
    /// # Errors
    ///
    /// Returns an error when the disconnect packet cannot be encoded or sent on the
    /// rejection path. A failing `connection.init()` is handled here and still yields `Ok(())`.
    async fn handle_upgraded_request(
        self: &Arc<Self>,
        headers: HeaderMap,
        request_uri: Uri,
        upgraded: Upgraded,
    ) -> Result<()> {
        let mut ws_stream =
            WebSocketStream::from_raw_socket(TokioIo::new(upgraded), Role::Server, Some(self.config.websocket_config))
                .await;

        // Not accepting connections: send a disconnect packet, then try to close the
        // stream. A close failure is ignored; a send failure is propagated.
        if !self.runtime.status.is(WsIoServerRuntimeStatus::Running) || !self.status.is(NamespaceStatus::Running) {
            ws_stream
                .send((*self.encode_packet_to_message(&WsIoPacket::new_disconnect())?).clone())
                .await?;

            let _ = ws_stream.close(None).await;
            return Ok(());
        }

        let (connection, mut message_rx) = WsIoServerConnection::new(headers, self.clone(), request_uri);

        let (mut ws_stream_writer, mut ws_stream_reader) = ws_stream.split();
        let connection_clone = connection.clone();
        // Reader task: hands binary and text payloads to the connection. It stops on a
        // close frame, a stream error, a handler error, or the end of the stream.
        // Every other message kind is ignored.
        let mut read_ws_stream_task = spawn(async move {
            while let Some(message) = ws_stream_reader.next().await {
                if match message {
                    Ok(Message::Binary(bytes)) => connection_clone.handle_incoming_packet(&bytes).await,
                    Ok(Message::Close(_)) => break,
                    Ok(Message::Text(text)) => connection_clone.handle_incoming_packet(text.as_bytes()).await,
                    Err(_) => break,
                    _ => Ok(()),
                }
                .is_err()
                {
                    break;
                }
            }
        });

        // Writer task: forwards everything received on `message_rx` to the socket. It
        // stops when the channel yields `None`, when a send fails, or right after a
        // close frame has been sent (the sink is then closed, ignoring errors).
        let mut write_ws_stream_task = spawn(async move {
            while let Some(message) = message_rx.recv().await {
                let message = (*message).clone();
                let is_close = matches!(message, Message::Close(_));
                if ws_stream_writer.send(message).await.is_err() {
                    break;
                }

                if is_close {
                    let _ = ws_stream_writer.close().await;
                    break;
                }
            }
        });

        match connection.init().await {
            Ok(_) => {
                // Whichever task finishes first ends the session; the other one is
                // aborted (and not awaited).
                select! {
                    _ = &mut read_ws_stream_task => {
                        write_ws_stream_task.abort();
                    },
                    _ = &mut write_ws_stream_task => {
                        read_ws_stream_task.abort();
                    },
                }
            }
            Err(_) => {
                // Initialisation failed: stop reading, close the connection and wait
                // for both tasks.
                // NOTE(review): this presumably relies on `connection.close()` making
                // the writer task terminate (e.g. by queueing a close frame) - confirm.
                read_ws_stream_task.abort();
                connection.close();
                let _ = join!(read_ws_stream_task, write_ws_stream_task);
            }
        }

        connection.cleanup().await;
        Ok(())
    }
177
178 #[inline]
180 pub(crate) fn add_connection_id_to_room(&self, room_name: &str, connection_id: u64) {
181 self.rooms
182 .entry(room_name.into())
183 .or_default()
184 .clone()
185 .insert(connection_id);
186 }
187
188 #[inline]
189 pub(crate) fn encode_packet_to_message(&self, packet: &WsIoPacket) -> Result<Arc<Message>> {
190 let bytes = self.config.packet_codec.encode(packet)?;
191 Ok(Arc::new(match self.config.packet_codec.is_text() {
192 true => Message::Text(unsafe { String::from_utf8_unchecked(bytes).into() }),
193 false => Message::Binary(bytes.into()),
194 }))
195 }
196
197 pub(crate) async fn handle_on_upgrade_request(
198 self: &Arc<Self>,
199 headers: HeaderMap,
200 on_upgrade: OnUpgrade,
201 request_uri: Uri,
202 ) {
203 let namespace = self.clone();
204 self.connection_task_set.lock().await.spawn(async move {
205 if let Ok(upgraded) = on_upgrade.await {
206 let _ = namespace.handle_upgraded_request(headers, request_uri, upgraded).await;
207 }
208 });
209 }
210
211 #[inline]
212 pub(crate) fn insert_connection(&self, connection: Arc<WsIoServerConnection>) {
213 self.connections.insert(connection.id(), connection.clone());
214 self.runtime.insert_connection_id(connection.id());
215 }
216
217 #[inline]
218 pub(crate) fn remove_connection(&self, id: u64) {
219 self.connections.remove(&id);
220 self.runtime.remove_connection_id(id);
221 }
222
223 #[inline]
224 pub(crate) fn remove_connection_id_from_room(&self, room_name: &str, connection_id: u64) {
225 if let Some(room) = self.rooms.get(room_name).map(|entry| entry.clone()) {
226 room.remove(&connection_id);
227 if room.is_empty() {
228 self.rooms.remove(room_name);
229 }
230 }
231 }
232
    /// Returns the number of connections currently registered in this namespace.
    #[inline]
    pub fn connection_count(&self) -> usize {
        self.connections.len()
    }
238
239 pub async fn emit<D: Serialize>(self: &Arc<Self>, event: impl AsRef<str>, data: Option<&D>) -> Result<()> {
240 WsIoServerNamespaceBroadcastOperator::new(self.clone())
241 .emit(event, data)
242 .await
243 }
244
245 #[inline]
246 pub fn except(
247 self: &Arc<Self>,
248 room_names: impl IntoIterator<Item = impl AsRef<str>>,
249 ) -> WsIoServerNamespaceBroadcastOperator {
250 WsIoServerNamespaceBroadcastOperator::new(self.clone()).except(room_names)
251 }
252
    /// Returns the path stored in this namespace's configuration.
    #[inline]
    pub fn path(&self) -> &str {
        &self.config.path
    }
257
258 #[inline]
259 pub fn server(&self) -> WsIoServer {
260 WsIoServer(self.runtime.clone())
261 }
262
263 pub async fn shutdown(self: &Arc<Self>) {
264 match self.status.get() {
265 NamespaceStatus::Stopped => return,
266 NamespaceStatus::Running => self.status.store(NamespaceStatus::Stopping),
267 _ => unreachable!(),
268 }
269
270 let _ = WsIoServerNamespaceBroadcastOperator::new(self.clone())
271 .disconnect()
272 .await;
273
274 let mut connection_task_set = self.connection_task_set.lock().await;
275 while connection_task_set.join_next().await.is_some() {}
276
277 self.status.store(NamespaceStatus::Stopped);
278 }
279
280 #[inline]
281 pub fn to(
282 self: &Arc<Self>,
283 room_names: impl IntoIterator<Item = impl AsRef<str>>,
284 ) -> WsIoServerNamespaceBroadcastOperator {
285 WsIoServerNamespaceBroadcastOperator::new(self.clone()).to(room_names)
286 }
287}