1use std::sync::Arc;
2
3use anyhow::Result;
4use futures_util::{
5 SinkExt,
6 StreamExt,
7};
8use http::HeaderMap;
9use hyper::upgrade::{
10 OnUpgrade,
11 Upgraded,
12};
13use hyper_util::rt::TokioIo;
14use num_enum::{
15 IntoPrimitive,
16 TryFromPrimitive,
17};
18use serde::Serialize;
19use tokio::{
20 join,
21 select,
22 spawn,
23 sync::Mutex,
24 task::JoinSet,
25};
26use tokio_tungstenite::{
27 WebSocketStream,
28 tungstenite::{
29 Message,
30 protocol::Role,
31 },
32};
33
34pub(crate) mod builder;
35mod config;
36pub mod operators;
37
38use self::{
39 config::WsIoServerNamespaceConfig,
40 operators::broadcast::WsIoServerNamespaceBroadcastOperator,
41};
42use crate::{
43 WsIoServer,
44 connection::WsIoServerConnection,
45 core::{
46 atomic::status::AtomicStatus,
47 packet::WsIoPacket,
48 types::hashers::{
49 FxDashMap,
50 FxDashSet,
51 },
52 },
53 runtime::{
54 WsIoServerRuntime,
55 WsIoServerRuntimeStatus,
56 },
57};
58
59#[repr(u8)]
61#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
62enum NamespaceStatus {
63 Running,
64 Stopped,
65 Stopping,
66}
67
/// One namespace of the server: its configuration, the connections registered
/// with it, the rooms those connections belong to, and its lifecycle status.
pub struct WsIoServerNamespace {
    /// Namespace configuration; this file reads `path`, `packet_codec` and
    /// `websocket_config` from it.
    pub(crate) config: WsIoServerNamespaceConfig,
    /// Registered connections, keyed by `WsIoServerConnection::id()`.
    connections: FxDashMap<u64, Arc<WsIoServerConnection>>,
    /// Tasks spawned by `handle_on_upgrade_request`; `shutdown` joins all of them.
    connection_task_set: Mutex<JoinSet<()>>,
    /// Room name -> set of member connection ids.
    rooms: FxDashMap<String, Arc<FxDashSet<u64>>>,
    /// Shared server runtime; connection ids are mirrored into it and its
    /// status is checked before a new connection is accepted.
    runtime: Arc<WsIoServerRuntime>,
    /// Current lifecycle state of this namespace.
    status: AtomicStatus<NamespaceStatus>,
}
77
78impl WsIoServerNamespace {
79 fn new(config: WsIoServerNamespaceConfig, runtime: Arc<WsIoServerRuntime>) -> Arc<Self> {
80 Arc::new(Self {
81 config,
82 connections: FxDashMap::default(),
83 connection_task_set: Mutex::new(JoinSet::new()),
84 rooms: FxDashMap::default(),
85 runtime,
86 status: AtomicStatus::new(NamespaceStatus::Running),
87 })
88 }
89
90 async fn handle_upgraded_request(self: &Arc<Self>, headers: HeaderMap, upgraded: Upgraded) -> Result<()> {
92 let mut ws_stream =
94 WebSocketStream::from_raw_socket(TokioIo::new(upgraded), Role::Server, Some(self.config.websocket_config))
95 .await;
96
97 if !self.runtime.status.is(WsIoServerRuntimeStatus::Running) || !self.status.is(NamespaceStatus::Running) {
99 ws_stream
100 .send((*self.encode_packet_to_message(&WsIoPacket::new_disconnect())?).clone())
101 .await?;
102
103 let _ = ws_stream.close(None).await;
104 return Ok(());
105 }
106
107 let (connection, mut message_rx) = WsIoServerConnection::new(headers, self.clone());
109
110 let (mut ws_stream_writer, mut ws_stream_reader) = ws_stream.split();
112 let connection_clone = connection.clone();
113 let mut read_ws_stream_task = spawn(async move {
114 while let Some(message) = ws_stream_reader.next().await {
115 if match message {
116 Ok(Message::Binary(bytes)) => connection_clone.handle_incoming_packet(&bytes).await,
117 Ok(Message::Close(_)) => break,
118 Ok(Message::Text(text)) => connection_clone.handle_incoming_packet(text.as_bytes()).await,
119 Err(_) => break,
120 _ => Ok(()),
121 }
122 .is_err()
123 {
124 break;
125 }
126 }
127 });
128
129 let mut write_ws_stream_task = spawn(async move {
130 while let Some(message) = message_rx.recv().await {
131 let message = (*message).clone();
132 let is_close = matches!(message, Message::Close(_));
133 if ws_stream_writer.send(message).await.is_err() {
134 break;
135 }
136
137 if is_close {
138 let _ = ws_stream_writer.close().await;
139 break;
140 }
141 }
142 });
143
144 match connection.init().await {
146 Ok(_) => {
147 select! {
149 _ = &mut read_ws_stream_task => {
150 write_ws_stream_task.abort();
151 },
152 _ = &mut write_ws_stream_task => {
153 read_ws_stream_task.abort();
154 },
155 }
156 }
157 Err(_) => {
158 read_ws_stream_task.abort();
160 connection.close();
161 let _ = join!(read_ws_stream_task, write_ws_stream_task);
162 }
163 }
164
165 connection.cleanup().await;
167 Ok(())
168 }
169
170 #[inline]
172 pub(crate) fn add_connection_id_to_room(&self, room_name: &str, connection_id: u64) {
173 self.rooms
174 .entry(room_name.to_string())
175 .or_default()
176 .clone()
177 .insert(connection_id);
178 }
179
180 #[inline]
181 pub(crate) fn encode_packet_to_message(&self, packet: &WsIoPacket) -> Result<Arc<Message>> {
182 let bytes = self.config.packet_codec.encode(packet)?;
183 Ok(Arc::new(match self.config.packet_codec.is_text() {
184 true => Message::Text(unsafe { String::from_utf8_unchecked(bytes).into() }),
185 false => Message::Binary(bytes.into()),
186 }))
187 }
188
189 pub(crate) async fn handle_on_upgrade_request(self: &Arc<Self>, headers: HeaderMap, on_upgrade: OnUpgrade) {
190 let namespace = self.clone();
191 self.connection_task_set.lock().await.spawn(async move {
192 if let Ok(upgraded) = on_upgrade.await {
193 let _ = namespace.handle_upgraded_request(headers, upgraded).await;
194 }
195 });
196 }
197
198 #[inline]
199 pub(crate) fn insert_connection(&self, connection: Arc<WsIoServerConnection>) {
200 self.connections.insert(connection.id(), connection.clone());
201 self.runtime.insert_connection_id(connection.id());
202 }
203
204 #[inline]
205 pub(crate) fn remove_connection(&self, id: u64) {
206 self.connections.remove(&id);
207 self.runtime.remove_connection_id(id);
208 }
209
210 #[inline]
211 pub(crate) fn remove_connection_id_from_room(&self, room_name: &str, connection_id: u64) {
212 if let Some(room) = self.rooms.get(room_name).map(|entry| entry.clone()) {
213 room.remove(&connection_id);
214 if room.is_empty() {
215 self.rooms.remove(room_name);
216 }
217 }
218 }
219
220 #[inline]
222 pub fn connection_count(&self) -> usize {
223 self.connections.len()
224 }
225
226 pub async fn emit<D: Serialize>(self: &Arc<Self>, event: impl AsRef<str>, data: Option<&D>) -> Result<()> {
227 WsIoServerNamespaceBroadcastOperator::new(self.clone())
228 .emit(event, data)
229 .await
230 }
231
232 #[inline]
233 pub fn except<I: IntoIterator<Item = S>, S: AsRef<str>>(
234 self: &Arc<Self>,
235 room_names: I,
236 ) -> WsIoServerNamespaceBroadcastOperator {
237 WsIoServerNamespaceBroadcastOperator::new(self.clone()).except(room_names)
238 }
239
240 #[inline]
241 pub fn path(&self) -> &str {
242 &self.config.path
243 }
244
245 #[inline]
246 pub fn server(&self) -> WsIoServer {
247 WsIoServer(self.runtime.clone())
248 }
249
250 pub async fn shutdown(self: &Arc<Self>) {
251 match self.status.get() {
252 NamespaceStatus::Stopped => return,
253 NamespaceStatus::Running => self.status.store(NamespaceStatus::Stopping),
254 _ => unreachable!(),
255 }
256
257 let _ = WsIoServerNamespaceBroadcastOperator::new(self.clone())
258 .disconnect()
259 .await;
260
261 let mut connection_task_set = self.connection_task_set.lock().await;
262 while connection_task_set.join_next().await.is_some() {}
263
264 self.status.store(NamespaceStatus::Stopped);
265 }
266
267 #[inline]
268 pub fn to<I: IntoIterator<Item = S>, S: AsRef<str>>(
269 self: &Arc<Self>,
270 room_names: I,
271 ) -> WsIoServerNamespaceBroadcastOperator {
272 WsIoServerNamespaceBroadcastOperator::new(self.clone()).to(room_names)
273 }
274}