1use std::sync::{
2 Arc,
3 LazyLock,
4 atomic::{
5 AtomicU64,
6 Ordering,
7 },
8};
9
10use anyhow::{
11 Result,
12 bail,
13};
14use arc_swap::ArcSwap;
15use http::{
16 HeaderMap,
17 Uri,
18};
19use num_enum::{
20 IntoPrimitive,
21 TryFromPrimitive,
22};
23use serde::{
24 Serialize,
25 de::DeserializeOwned,
26};
27use tokio::{
28 spawn,
29 sync::{
30 Mutex,
31 mpsc::{
32 Receiver,
33 Sender,
34 channel,
35 },
36 },
37 task::JoinHandle,
38 time::{
39 sleep,
40 timeout,
41 },
42};
43use tokio_tungstenite::tungstenite::Message;
44use tokio_util::sync::CancellationToken;
45
46#[cfg(feature = "connection-extensions")]
47mod extensions;
48
49#[cfg(feature = "connection-extensions")]
50use self::extensions::ConnectionExtensions;
51use crate::{
52 WsIoServer,
53 core::{
54 atomic::status::AtomicStatus,
55 channel_capacity_from_websocket_config,
56 event::registry::WsIoEventRegistry,
57 packet::{
58 WsIoPacket,
59 WsIoPacketType,
60 },
61 traits::task::spawner::TaskSpawner,
62 types::{
63 BoxAsyncUnaryResultHandler,
64 hashers::FxDashSet,
65 },
66 utils::task::abort_locked_task,
67 },
68 namespace::{
69 WsIoServerNamespace,
70 operators::broadcast::WsIoServerNamespaceBroadcastOperator,
71 },
72};
73
74#[repr(u8)]
76#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
77enum ConnectionStatus {
78 Activating,
79 AwaitingInit,
80 Closed,
81 Closing,
82 Created,
83 Initiating,
84 Ready,
85}
86
/// Server-side state for one client connection within a [`WsIoServerNamespace`].
///
/// Instances are created by `WsIoServerConnection::new`, which returns them wrapped in an `Arc`
/// together with the receiving half of the outbound message queue.
pub struct WsIoServerConnection {
    /// Token handed out by the `TaskSpawner` impl and cancelled in `cleanup`. Held in an
    /// `ArcSwap`, presumably so it can be replaced elsewhere (nothing in this file swaps it).
    cancel_token: ArcSwap<CancellationToken>,
    /// Per-connection event handlers managed through `on`, `off` and `off_by_handler_id`.
    event_registry: WsIoEventRegistry<WsIoServerConnection, WsIoServerConnection>,
    /// Per-connection extension storage exposed through `extensions()`.
    #[cfg(feature = "connection-extensions")]
    extensions: ConnectionExtensions,
    /// Request headers passed to `new` (presumably from the WebSocket upgrade request).
    headers: HeaderMap,
    /// Unique id drawn from the `NEXT_CONNECTION_ID` counter.
    id: u64,
    /// Task spawned by `init` that calls `close` if the connection is still `AwaitingInit`
    /// after `init_response_timeout`; stopped via `abort_locked_task` when the init packet
    /// arrives and during `cleanup`.
    init_timeout_task: Mutex<Option<JoinHandle<()>>>,
    /// Names of the rooms joined via `join`, so `cleanup` can remove every membership.
    joined_rooms: FxDashSet<String>,
    /// Sending half of the outbound message queue; the receiver is returned by `new`.
    message_tx: Sender<Arc<Message>>,
    /// Namespace this connection belongs to.
    namespace: Arc<WsIoServerNamespace>,
    /// Handler registered via `on_close`; taken and run at most once by `cleanup`.
    on_close_handler: Mutex<Option<BoxAsyncUnaryResultHandler<Self>>>,
    /// Request URI passed to `new`.
    request_uri: Uri,
    /// Current lifecycle state (see `ConnectionStatus`).
    status: AtomicStatus<ConnectionStatus>,
}
103
104impl TaskSpawner for WsIoServerConnection {
105 #[inline]
106 fn cancel_token(&self) -> Arc<CancellationToken> {
107 self.cancel_token.load_full()
108 }
109}
110
111impl WsIoServerConnection {
112 #[inline]
113 pub(crate) fn new(
114 headers: HeaderMap,
115 namespace: Arc<WsIoServerNamespace>,
116 request_uri: Uri,
117 ) -> (Arc<Self>, Receiver<Arc<Message>>) {
118 let channel_capacity = channel_capacity_from_websocket_config(&namespace.config.websocket_config);
119 let (message_tx, message_rx) = channel(channel_capacity);
120 (
121 Arc::new(Self {
122 cancel_token: ArcSwap::new(Arc::new(CancellationToken::new())),
123 event_registry: WsIoEventRegistry::new(),
124 #[cfg(feature = "connection-extensions")]
125 extensions: ConnectionExtensions::new(),
126 headers,
127 id: NEXT_CONNECTION_ID.fetch_add(1, Ordering::Relaxed),
128 init_timeout_task: Mutex::new(None),
129 joined_rooms: FxDashSet::default(),
130 message_tx,
131 namespace,
132 on_close_handler: Mutex::new(None),
133 request_uri,
134 status: AtomicStatus::new(ConnectionStatus::Created),
135 }),
136 message_rx,
137 )
138 }
139
140 #[inline]
142 fn handle_event_packet(self: &Arc<Self>, event: &str, packet_data: Option<Vec<u8>>) -> Result<()> {
143 self.event_registry.dispatch_event_packet(
144 self.clone(),
145 event,
146 &self.namespace.config.packet_codec,
147 packet_data,
148 self,
149 );
150
151 Ok(())
152 }
153
154 async fn handle_init_packet(self: &Arc<Self>, packet_data: Option<&[u8]>) -> Result<()> {
155 let status = self.status.get();
157 match status {
158 ConnectionStatus::AwaitingInit => self.status.try_transition(status, ConnectionStatus::Initiating)?,
159 _ => bail!("Received init packet in invalid status: {status:?}"),
160 }
161
162 abort_locked_task(&self.init_timeout_task).await;
164
165 if let Some(init_response_handler) = &self.namespace.config.init_response_handler {
167 timeout(
168 self.namespace.config.init_response_handler_timeout,
169 init_response_handler(self.clone(), packet_data, &self.namespace.config.packet_codec),
170 )
171 .await??
172 }
173
174 self.status
176 .try_transition(ConnectionStatus::Initiating, ConnectionStatus::Activating)?;
177
178 if let Some(middleware) = &self.namespace.config.middleware {
180 timeout(
181 self.namespace.config.middleware_execution_timeout,
182 middleware(self.clone()),
183 )
184 .await??;
185
186 self.status.ensure(ConnectionStatus::Activating, |status| {
188 format!("Cannot activate connection in invalid status: {status:?}")
189 })?;
190 }
191
192 if let Some(on_connect_handler) = &self.namespace.config.on_connect_handler {
194 timeout(
195 self.namespace.config.on_connect_handler_timeout,
196 on_connect_handler(self.clone()),
197 )
198 .await??;
199 }
200
201 self.status
203 .try_transition(ConnectionStatus::Activating, ConnectionStatus::Ready)?;
204
205 self.namespace.insert_connection(self.clone());
207
208 self.send_packet(&WsIoPacket::new_ready()).await?;
210
211 if let Some(on_ready_handler) = self.namespace.config.on_ready_handler.clone() {
213 self.spawn_task(on_ready_handler(self.clone()));
215 }
216
217 Ok(())
218 }
219
220 async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
221 self.send_message(self.namespace.encode_packet_to_message(packet)?)
222 .await
223 }
224
225 pub(crate) async fn cleanup(self: &Arc<Self>) {
227 self.status.store(ConnectionStatus::Closing);
229
230 self.namespace.remove_connection(self.id);
232
233 let joined_rooms = self.joined_rooms.iter().map(|entry| entry.clone()).collect::<Vec<_>>();
235 for room_name in &joined_rooms {
236 self.namespace.remove_connection_id_from_room(room_name, self.id);
237 }
238
239 self.joined_rooms.clear();
240
241 abort_locked_task(&self.init_timeout_task).await;
243
244 self.cancel_token.load().cancel();
246
247 if let Some(on_close_handler) = self.on_close_handler.lock().await.take() {
249 let _ = timeout(
250 self.namespace.config.on_close_handler_timeout,
251 on_close_handler(self.clone()),
252 )
253 .await;
254 }
255
256 self.status.store(ConnectionStatus::Closed);
258 }
259
260 #[inline]
261 pub(crate) fn close(&self) {
262 match self.status.get() {
264 ConnectionStatus::Closed | ConnectionStatus::Closing => return,
265 _ => self.status.store(ConnectionStatus::Closing),
266 }
267
268 let _ = self.message_tx.try_send(Arc::new(Message::Close(None)));
270 }
271
272 pub(crate) async fn emit_event_message(&self, message: Arc<Message>) -> Result<()> {
273 self.status.ensure(ConnectionStatus::Ready, |status| {
274 format!("Cannot emit in invalid status: {status:?}")
275 })?;
276
277 self.send_message(message).await
278 }
279
280 pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, encoded_packet: &[u8]) -> Result<()> {
281 let packet = self.namespace.config.packet_codec.decode(encoded_packet)?;
283 match packet.r#type {
284 WsIoPacketType::Event => {
285 if let Some(event) = packet.key.as_deref() {
286 self.handle_event_packet(event, packet.data)
287 } else {
288 bail!("Event packet missing key");
289 }
290 }
291 WsIoPacketType::Init => self.handle_init_packet(packet.data.as_deref()).await,
292 _ => Ok(()),
293 }
294 }
295
296 pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
297 self.status.ensure(ConnectionStatus::Created, |status| {
299 format!("Cannot init connection in invalid status: {status:?}")
300 })?;
301
302 let init_request_data = if let Some(init_request_handler) = &self.namespace.config.init_request_handler {
304 timeout(
305 self.namespace.config.init_request_handler_timeout,
306 init_request_handler(self.clone(), &self.namespace.config.packet_codec),
307 )
308 .await??
309 } else {
310 None
311 };
312
313 self.status
315 .try_transition(ConnectionStatus::Created, ConnectionStatus::AwaitingInit)?;
316
317 let connection = self.clone();
319 *self.init_timeout_task.lock().await = Some(spawn(async move {
320 sleep(connection.namespace.config.init_response_timeout).await;
321 if connection.status.is(ConnectionStatus::AwaitingInit) {
322 connection.close();
323 }
324 }));
325
326 self.send_packet(&WsIoPacket::new_init(init_request_data)).await
328 }
329
330 pub(crate) async fn send_message(&self, message: Arc<Message>) -> Result<()> {
331 Ok(self.message_tx.send(message).await?)
332 }
333
334 pub async fn disconnect(&self) {
336 let _ = self.send_packet(&WsIoPacket::new_disconnect()).await;
337 self.close()
338 }
339
340 pub async fn emit<D: Serialize>(&self, event: impl AsRef<str>, data: Option<&D>) -> Result<()> {
341 self.emit_event_message(
342 self.namespace.encode_packet_to_message(&WsIoPacket::new_event(
343 event.as_ref(),
344 data.map(|data| self.namespace.config.packet_codec.encode_data(data))
345 .transpose()?,
346 ))?,
347 )
348 .await
349 }
350
351 #[inline]
352 pub fn except(
353 self: &Arc<Self>,
354 room_names: impl IntoIterator<Item = impl AsRef<str>>,
355 ) -> WsIoServerNamespaceBroadcastOperator {
356 self.namespace.except(room_names).except_connection_ids(vec![self.id])
357 }
358
359 #[cfg(feature = "connection-extensions")]
360 #[inline]
361 pub fn extensions(&self) -> &ConnectionExtensions {
362 &self.extensions
363 }
364
365 #[inline]
366 pub fn headers(&self) -> &HeaderMap {
367 &self.headers
368 }
369
370 #[inline]
371 pub fn id(&self) -> u64 {
372 self.id
373 }
374
375 #[inline]
376 pub fn join(self: &Arc<Self>, room_names: impl IntoIterator<Item = impl AsRef<str>>) {
377 for room_name in room_names {
378 let room_name = room_name.as_ref();
379 self.namespace.add_connection_id_to_room(room_name, self.id);
380 self.joined_rooms.insert(room_name.into());
381 }
382 }
383
384 #[inline]
385 pub fn leave(self: &Arc<Self>, room_names: impl IntoIterator<Item = impl AsRef<str>>) {
386 for room_name in room_names {
387 self.namespace
388 .remove_connection_id_from_room(room_name.as_ref(), self.id);
389
390 self.joined_rooms.remove(room_name.as_ref());
391 }
392 }
393
394 #[inline]
395 pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
396 self.namespace.clone()
397 }
398
399 #[inline]
400 pub fn off(&self, event: impl AsRef<str>) {
401 self.event_registry.off(event.as_ref());
402 }
403
404 #[inline]
405 pub fn off_by_handler_id(&self, event: impl AsRef<str>, handler_id: u32) {
406 self.event_registry.off_by_handler_id(event.as_ref(), handler_id);
407 }
408
409 #[inline]
410 pub fn on<H, Fut, D>(&self, event: impl AsRef<str>, handler: H) -> u32
411 where
412 H: Fn(Arc<WsIoServerConnection>, Arc<D>) -> Fut + Send + Sync + 'static,
413 Fut: Future<Output = Result<()>> + Send + 'static,
414 D: DeserializeOwned + Send + Sync + 'static,
415 {
416 self.event_registry.on(event.as_ref(), handler)
417 }
418
419 pub async fn on_close<H, Fut>(&self, handler: H)
420 where
421 H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
422 Fut: Future<Output = Result<()>> + Send + 'static,
423 {
424 *self.on_close_handler.lock().await = Some(Box::new(move |connection| Box::pin(handler(connection))));
425 }
426
427 #[inline]
428 pub fn request_uri(&self) -> &Uri {
429 &self.request_uri
430 }
431
432 #[inline]
433 pub fn server(&self) -> WsIoServer {
434 self.namespace.server()
435 }
436
437 #[inline]
438 pub fn to(
439 self: &Arc<Self>,
440 room_names: impl IntoIterator<Item = impl AsRef<str>>,
441 ) -> WsIoServerNamespaceBroadcastOperator {
442 self.namespace.to(room_names).except_connection_ids(vec![self.id])
443 }
444}
445
/// Process-wide counter used by `WsIoServerConnection::new` to assign connection ids.
///
/// Ids start at 0. Each id comes from an atomic `fetch_add`, so ids are unique even with
/// `Relaxed` ordering.
static NEXT_CONNECTION_ID: LazyLock<AtomicU64> = LazyLock::new(AtomicU64::default);