1use std::sync::{
2 Arc,
3 LazyLock,
4 atomic::{
5 AtomicU64,
6 Ordering,
7 },
8};
9
10use anyhow::{
11 Result,
12 bail,
13};
14use arc_swap::ArcSwap;
15use http::{
16 HeaderMap,
17 Uri,
18};
19use num_enum::{
20 IntoPrimitive,
21 TryFromPrimitive,
22};
23use serde::{
24 Serialize,
25 de::DeserializeOwned,
26};
27use tokio::{
28 spawn,
29 sync::{
30 Mutex,
31 mpsc::{
32 Receiver,
33 Sender,
34 channel,
35 },
36 },
37 task::JoinHandle,
38 time::{
39 sleep,
40 timeout,
41 },
42};
43use tokio_tungstenite::tungstenite::Message;
44use tokio_util::sync::CancellationToken;
45
46#[cfg(feature = "connection-extensions")]
47mod extensions;
48
49#[cfg(feature = "connection-extensions")]
50use self::extensions::ConnectionExtensions;
51use crate::{
52 WsIoServer,
53 core::{
54 atomic::status::AtomicStatus,
55 channel_capacity_from_websocket_config,
56 event::registry::WsIoEventRegistry,
57 packet::{
58 WsIoPacket,
59 WsIoPacketType,
60 },
61 traits::task::spawner::TaskSpawner,
62 types::{
63 BoxAsyncUnaryResultHandler,
64 hashers::FxDashSet,
65 },
66 utils::task::abort_locked_task,
67 },
68 namespace::WsIoServerNamespace,
69};
70
71#[repr(u8)]
73#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
74enum ConnectionStatus {
75 Activating,
76 AwaitingInit,
77 Closed,
78 Closing,
79 Created,
80 Initiating,
81 Ready,
82}
83
/// A single client connection inside a [`WsIoServerNamespace`].
///
/// Fields are dropped in declaration order; do not reorder them casually.
pub struct WsIoServerConnection {
    // Token returned by `TaskSpawner::cancel_token`; `cleanup` cancels it.
    cancel_token: ArcSwap<CancellationToken>,
    // Handlers registered through `on`, dispatched for incoming event packets.
    event_registry: WsIoEventRegistry<WsIoServerConnection, WsIoServerConnection>,
    #[cfg(feature = "connection-extensions")]
    extensions: ConnectionExtensions,
    // Headers of the request that opened this connection.
    headers: HeaderMap,
    // Unique id taken from the global `NEXT_CONNECTION_ID` counter.
    id: u64,
    // Timer armed by `init` that closes the connection if no init packet arrives in time.
    init_timeout_task: Mutex<Option<JoinHandle<()>>>,
    // Rooms joined via `join`, kept so `cleanup` can leave them all.
    joined_rooms: FxDashSet<String>,
    // Sending half of the bounded outgoing message channel created in `new`.
    message_tx: Sender<Arc<Message>>,
    namespace: Arc<WsIoServerNamespace>,
    // Handler set by `on_close`; taken and run at most once by `cleanup`.
    on_close_handler: Mutex<Option<BoxAsyncUnaryResultHandler<Self>>>,
    // URI of the request that opened this connection.
    request_uri: Uri,
    // Current lifecycle state.
    status: AtomicStatus<ConnectionStatus>,
}
100
101impl TaskSpawner for WsIoServerConnection {
102 #[inline]
103 fn cancel_token(&self) -> Arc<CancellationToken> {
104 self.cancel_token.load_full()
105 }
106}
107
108impl WsIoServerConnection {
109 #[inline]
110 pub(crate) fn new(
111 headers: HeaderMap,
112 namespace: Arc<WsIoServerNamespace>,
113 request_uri: Uri,
114 ) -> (Arc<Self>, Receiver<Arc<Message>>) {
115 let channel_capacity = channel_capacity_from_websocket_config(&namespace.config.websocket_config);
116 let (message_tx, message_rx) = channel(channel_capacity);
117 (
118 Arc::new(Self {
119 cancel_token: ArcSwap::new(Arc::new(CancellationToken::new())),
120 event_registry: WsIoEventRegistry::new(),
121 #[cfg(feature = "connection-extensions")]
122 extensions: ConnectionExtensions::new(),
123 headers,
124 id: NEXT_CONNECTION_ID.fetch_add(1, Ordering::Relaxed),
125 init_timeout_task: Mutex::new(None),
126 joined_rooms: FxDashSet::default(),
127 message_tx,
128 namespace,
129 on_close_handler: Mutex::new(None),
130 request_uri,
131 status: AtomicStatus::new(ConnectionStatus::Created),
132 }),
133 message_rx,
134 )
135 }
136
137 #[inline]
139 fn handle_event_packet(self: &Arc<Self>, event: &str, packet_data: Option<Vec<u8>>) -> Result<()> {
140 self.event_registry.dispatch_event_packet(
141 self.clone(),
142 event,
143 &self.namespace.config.packet_codec,
144 packet_data,
145 self,
146 );
147
148 Ok(())
149 }
150
151 async fn handle_init_packet(self: &Arc<Self>, packet_data: Option<&[u8]>) -> Result<()> {
152 let status = self.status.get();
154 match status {
155 ConnectionStatus::AwaitingInit => self.status.try_transition(status, ConnectionStatus::Initiating)?,
156 _ => bail!("Received init packet in invalid status: {status:?}"),
157 }
158
159 abort_locked_task(&self.init_timeout_task).await;
161
162 if let Some(init_response_handler) = &self.namespace.config.init_response_handler {
164 timeout(
165 self.namespace.config.init_response_handler_timeout,
166 init_response_handler(self.clone(), packet_data, &self.namespace.config.packet_codec),
167 )
168 .await??
169 }
170
171 self.status
173 .try_transition(ConnectionStatus::Initiating, ConnectionStatus::Activating)?;
174
175 if let Some(middleware) = &self.namespace.config.middleware {
177 timeout(
178 self.namespace.config.middleware_execution_timeout,
179 middleware(self.clone()),
180 )
181 .await??;
182
183 self.status.ensure(ConnectionStatus::Activating, |status| {
185 format!("Cannot activate connection in invalid status: {status:?}")
186 })?;
187 }
188
189 if let Some(on_connect_handler) = &self.namespace.config.on_connect_handler {
191 timeout(
192 self.namespace.config.on_connect_handler_timeout,
193 on_connect_handler(self.clone()),
194 )
195 .await??;
196 }
197
198 self.status
200 .try_transition(ConnectionStatus::Activating, ConnectionStatus::Ready)?;
201
202 self.namespace.insert_connection(self.clone());
204
205 self.send_packet(&WsIoPacket::new_ready()).await?;
207
208 if let Some(on_ready_handler) = self.namespace.config.on_ready_handler.clone() {
210 self.spawn_task(on_ready_handler(self.clone()));
212 }
213
214 Ok(())
215 }
216
217 async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
218 self.send_message(self.namespace.encode_packet_to_message(packet)?)
219 .await
220 }
221
222 pub(crate) async fn cleanup(self: &Arc<Self>) {
224 self.status.store(ConnectionStatus::Closing);
226
227 self.namespace.remove_connection(self.id);
229
230 let joined_rooms = self.joined_rooms.iter().map(|entry| entry.clone()).collect::<Vec<_>>();
232 for room_name in &joined_rooms {
233 self.namespace.remove_connection_id_from_room(room_name, self.id);
234 }
235
236 self.joined_rooms.clear();
237
238 abort_locked_task(&self.init_timeout_task).await;
240
241 self.cancel_token.load().cancel();
243
244 if let Some(on_close_handler) = self.on_close_handler.lock().await.take() {
246 let _ = timeout(
247 self.namespace.config.on_close_handler_timeout,
248 on_close_handler(self.clone()),
249 )
250 .await;
251 }
252
253 self.status.store(ConnectionStatus::Closed);
255 }
256
257 #[inline]
258 pub(crate) fn close(&self) {
259 match self.status.get() {
261 ConnectionStatus::Closed | ConnectionStatus::Closing => return,
262 _ => self.status.store(ConnectionStatus::Closing),
263 }
264
265 let _ = self.message_tx.try_send(Arc::new(Message::Close(None)));
267 }
268
269 pub(crate) async fn emit_event_message(&self, message: Arc<Message>) -> Result<()> {
270 self.status.ensure(ConnectionStatus::Ready, |status| {
271 format!("Cannot emit in invalid status: {status:?}")
272 })?;
273
274 self.send_message(message).await
275 }
276
277 pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, bytes: &[u8]) -> Result<()> {
278 let packet = self.namespace.config.packet_codec.decode(bytes)?;
280 match packet.r#type {
281 WsIoPacketType::Event => {
282 if let Some(event) = packet.key.as_deref() {
283 self.handle_event_packet(event, packet.data)
284 } else {
285 bail!("Event packet missing key");
286 }
287 }
288 WsIoPacketType::Init => self.handle_init_packet(packet.data.as_deref()).await,
289 _ => Ok(()),
290 }
291 }
292
293 pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
294 self.status.ensure(ConnectionStatus::Created, |status| {
296 format!("Cannot init connection in invalid status: {status:?}")
297 })?;
298
299 let init_request_data = if let Some(init_request_handler) = &self.namespace.config.init_request_handler {
301 timeout(
302 self.namespace.config.init_request_handler_timeout,
303 init_request_handler(self.clone(), &self.namespace.config.packet_codec),
304 )
305 .await??
306 } else {
307 None
308 };
309
310 self.status
312 .try_transition(ConnectionStatus::Created, ConnectionStatus::AwaitingInit)?;
313
314 let connection = self.clone();
316 *self.init_timeout_task.lock().await = Some(spawn(async move {
317 sleep(connection.namespace.config.init_response_timeout).await;
318 if connection.status.is(ConnectionStatus::AwaitingInit) {
319 connection.close();
320 }
321 }));
322
323 self.send_packet(&WsIoPacket::new_init(init_request_data)).await
325 }
326
327 pub(crate) async fn send_message(&self, message: Arc<Message>) -> Result<()> {
328 Ok(self.message_tx.send(message).await?)
329 }
330
331 pub async fn disconnect(&self) {
333 let _ = self.send_packet(&WsIoPacket::new_disconnect()).await;
334 self.close()
335 }
336
337 pub async fn emit<D: Serialize>(&self, event: impl AsRef<str>, data: Option<&D>) -> Result<()> {
338 self.emit_event_message(
339 self.namespace.encode_packet_to_message(&WsIoPacket::new_event(
340 event.as_ref(),
341 data.map(|data| self.namespace.config.packet_codec.encode_data(data))
342 .transpose()?,
343 ))?,
344 )
345 .await
346 }
347
348 #[cfg(feature = "connection-extensions")]
349 #[inline]
350 pub fn extensions(&self) -> &ConnectionExtensions {
351 &self.extensions
352 }
353
354 #[inline]
355 pub fn headers(&self) -> &HeaderMap {
356 &self.headers
357 }
358
359 #[inline]
360 pub fn id(&self) -> u64 {
361 self.id
362 }
363
364 #[inline]
365 pub fn join<I: IntoIterator<Item = S>, S: AsRef<str>>(self: &Arc<Self>, room_names: I) {
366 for room_name in room_names {
367 let room_name = room_name.as_ref();
368 self.namespace.add_connection_id_to_room(room_name, self.id);
369 self.joined_rooms.insert(room_name.to_string());
370 }
371 }
372
373 #[inline]
374 pub fn leave<I: IntoIterator<Item = S>, S: AsRef<str>>(self: &Arc<Self>, room_names: I) {
375 for room_name in room_names {
376 self.namespace
377 .remove_connection_id_from_room(room_name.as_ref(), self.id);
378
379 self.joined_rooms.remove(room_name.as_ref());
380 }
381 }
382
383 #[inline]
384 pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
385 self.namespace.clone()
386 }
387
388 #[inline]
389 pub fn off(&self, event: impl AsRef<str>) {
390 self.event_registry.off(event.as_ref());
391 }
392
393 #[inline]
394 pub fn off_by_handler_id(&self, event: impl AsRef<str>, handler_id: u32) {
395 self.event_registry.off_by_handler_id(event.as_ref(), handler_id);
396 }
397
398 #[inline]
399 pub fn on<H, Fut, D>(&self, event: impl AsRef<str>, handler: H) -> u32
400 where
401 H: Fn(Arc<WsIoServerConnection>, Arc<D>) -> Fut + Send + Sync + 'static,
402 Fut: Future<Output = Result<()>> + Send + 'static,
403 D: DeserializeOwned + Send + Sync + 'static,
404 {
405 self.event_registry.on(event.as_ref(), handler)
406 }
407
408 pub async fn on_close<H, Fut>(&self, handler: H)
409 where
410 H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
411 Fut: Future<Output = Result<()>> + Send + 'static,
412 {
413 *self.on_close_handler.lock().await = Some(Box::new(move |connection| Box::pin(handler(connection))));
414 }
415
416 #[inline]
417 pub fn request_uri(&self) -> &Uri {
418 &self.request_uri
419 }
420
421 #[inline]
422 pub fn server(&self) -> WsIoServer {
423 self.namespace.server()
424 }
425}
426
/// Global source of connection ids; starts at 0 and is bumped once per `WsIoServerConnection::new`.
static NEXT_CONNECTION_ID: LazyLock<AtomicU64> = LazyLock::new(AtomicU64::default);