1use std::sync::{
2 Arc,
3 LazyLock,
4 atomic::{
5 AtomicU64,
6 Ordering,
7 },
8};
9
10use anyhow::{
11 Result,
12 bail,
13};
14use arc_swap::ArcSwap;
15use http::HeaderMap;
16use num_enum::{
17 IntoPrimitive,
18 TryFromPrimitive,
19};
20use serde::{
21 Serialize,
22 de::DeserializeOwned,
23};
24use tokio::{
25 spawn,
26 sync::{
27 Mutex,
28 mpsc::{
29 Receiver,
30 Sender,
31 channel,
32 },
33 },
34 task::JoinHandle,
35 time::{
36 sleep,
37 timeout,
38 },
39};
40use tokio_tungstenite::tungstenite::Message;
41use tokio_util::sync::CancellationToken;
42
43#[cfg(feature = "connection-extensions")]
44mod extensions;
45
46#[cfg(feature = "connection-extensions")]
47use self::extensions::ConnectionExtensions;
48use crate::{
49 WsIoServer,
50 core::{
51 atomic::status::AtomicStatus,
52 channel_capacity_from_websocket_config,
53 event::registry::WsIoEventRegistry,
54 packet::{
55 WsIoPacket,
56 WsIoPacketType,
57 },
58 traits::task::spawner::TaskSpawner,
59 types::{
60 BoxAsyncUnaryResultHandler,
61 hashers::FxDashSet,
62 },
63 utils::task::abort_locked_task,
64 },
65 namespace::WsIoServerNamespace,
66};
67
68#[repr(u8)]
70#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
71enum ConnectionStatus {
72 Activating,
73 Authenticating,
74 AwaitingAuth,
75 Closed,
76 Closing,
77 Created,
78 Ready,
79}
80
81pub struct WsIoServerConnection {
83 auth_timeout_task: Mutex<Option<JoinHandle<()>>>,
84 cancel_token: ArcSwap<CancellationToken>,
85 event_registry: WsIoEventRegistry<WsIoServerConnection, WsIoServerConnection>,
86 #[cfg(feature = "connection-extensions")]
87 extensions: ConnectionExtensions,
88 headers: HeaderMap,
89 id: u64,
90 joined_rooms: FxDashSet<String>,
91 message_tx: Sender<Arc<Message>>,
92 namespace: Arc<WsIoServerNamespace>,
93 on_close_handler: Mutex<Option<BoxAsyncUnaryResultHandler<Self>>>,
94 status: AtomicStatus<ConnectionStatus>,
95}
96
97impl TaskSpawner for WsIoServerConnection {
98 #[inline]
99 fn cancel_token(&self) -> Arc<CancellationToken> {
100 self.cancel_token.load_full()
101 }
102}
103
104impl WsIoServerConnection {
105 #[inline]
106 pub(crate) fn new(headers: HeaderMap, namespace: Arc<WsIoServerNamespace>) -> (Arc<Self>, Receiver<Arc<Message>>) {
107 let channel_capacity = channel_capacity_from_websocket_config(&namespace.config.websocket_config);
108 let (message_tx, message_rx) = channel(channel_capacity);
109 (
110 Arc::new(Self {
111 auth_timeout_task: Mutex::new(None),
112 cancel_token: ArcSwap::new(Arc::new(CancellationToken::new())),
113 event_registry: WsIoEventRegistry::new(),
114 #[cfg(feature = "connection-extensions")]
115 extensions: ConnectionExtensions::new(),
116 headers,
117 id: NEXT_CONNECTION_ID.fetch_add(1, Ordering::Relaxed),
118 joined_rooms: FxDashSet::default(),
119 message_tx,
120 namespace,
121 on_close_handler: Mutex::new(None),
122 status: AtomicStatus::new(ConnectionStatus::Created),
123 }),
124 message_rx,
125 )
126 }
127
128 async fn activate(self: &Arc<Self>) -> Result<()> {
130 let status = self.status.get();
132 match status {
133 ConnectionStatus::Authenticating | ConnectionStatus::Created => {
134 self.status.try_transition(status, ConnectionStatus::Activating)?
135 }
136 _ => bail!("Cannot activate connection in invalid status: {:#?}", status),
137 }
138
139 if let Some(middleware) = &self.namespace.config.middleware {
141 timeout(
142 self.namespace.config.middleware_execution_timeout,
143 middleware(self.clone()),
144 )
145 .await??;
146
147 self.status.ensure(ConnectionStatus::Activating, |status| {
149 format!("Cannot activate connection in invalid status: {:#?}", status)
150 })?;
151 }
152
153 if let Some(on_connect_handler) = &self.namespace.config.on_connect_handler {
155 timeout(
156 self.namespace.config.on_connect_handler_timeout,
157 on_connect_handler(self.clone()),
158 )
159 .await??;
160 }
161
162 self.status
164 .try_transition(ConnectionStatus::Activating, ConnectionStatus::Ready)?;
165
166 self.namespace.insert_connection(self.clone());
168
169 self.send_packet(&WsIoPacket::new_ready()).await?;
171
172 if let Some(on_ready_handler) = self.namespace.config.on_ready_handler.clone() {
174 self.spawn_task(on_ready_handler(self.clone()));
176 }
177
178 Ok(())
179 }
180
181 async fn handle_auth_packet(self: &Arc<Self>, packet_data: &[u8]) -> Result<()> {
182 let status = self.status.get();
184 match status {
185 ConnectionStatus::AwaitingAuth => self.status.try_transition(status, ConnectionStatus::Authenticating)?,
186 _ => bail!("Received auth packet in invalid status: {:#?}", status),
187 }
188
189 abort_locked_task(&self.auth_timeout_task).await;
191
192 if let Some(auth_handler) = &self.namespace.config.auth_handler {
194 timeout(
195 self.namespace.config.auth_handler_timeout,
196 auth_handler(self.clone(), packet_data, &self.namespace.config.packet_codec),
197 )
198 .await??;
199
200 self.activate().await
202 } else {
203 bail!("Auth packet received but no auth handler is configured");
204 }
205 }
206
207 #[inline]
208 fn handle_event_packet(self: &Arc<Self>, event: &str, packet_data: Option<Vec<u8>>) -> Result<()> {
209 self.event_registry.dispatch_event_packet(
210 self.clone(),
211 event,
212 &self.namespace.config.packet_codec,
213 packet_data,
214 self,
215 );
216
217 Ok(())
218 }
219
220 async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
221 self.send_message(self.namespace.encode_packet_to_message(packet)?)
222 .await
223 }
224
225 pub(crate) async fn cleanup(self: &Arc<Self>) {
227 self.status.store(ConnectionStatus::Closing);
229
230 self.namespace.remove_connection(self.id);
232
233 let joined_rooms = self.joined_rooms.iter().map(|entry| entry.clone()).collect::<Vec<_>>();
235 for room_name in &joined_rooms {
236 self.namespace.remove_connection_id_from_room(room_name, self.id);
237 }
238
239 self.joined_rooms.clear();
240
241 abort_locked_task(&self.auth_timeout_task).await;
243
244 self.cancel_token.load().cancel();
246
247 if let Some(on_close_handler) = self.on_close_handler.lock().await.take() {
249 let _ = timeout(
250 self.namespace.config.on_close_handler_timeout,
251 on_close_handler(self.clone()),
252 )
253 .await;
254 }
255
256 self.status.store(ConnectionStatus::Closed);
258 }
259
260 #[inline]
261 pub(crate) fn close(&self) {
262 match self.status.get() {
264 ConnectionStatus::Closed | ConnectionStatus::Closing => return,
265 _ => self.status.store(ConnectionStatus::Closing),
266 }
267
268 let _ = self.message_tx.try_send(Arc::new(Message::Close(None)));
270 }
271
272 pub(crate) async fn emit_event_message(&self, message: Arc<Message>) -> Result<()> {
273 self.status.ensure(ConnectionStatus::Ready, |status| {
274 format!("Cannot emit in invalid status: {:#?}", status)
275 })?;
276
277 self.send_message(message).await
278 }
279
280 pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, bytes: &[u8]) -> Result<()> {
281 let packet = self.namespace.config.packet_codec.decode(bytes)?;
283 match packet.r#type {
284 WsIoPacketType::Auth => {
285 if let Some(packet_data) = packet.data.as_deref() {
286 self.handle_auth_packet(packet_data).await
287 } else {
288 bail!("Auth packet missing data");
289 }
290 }
291 WsIoPacketType::Event => {
292 if let Some(event) = packet.key.as_deref() {
293 self.handle_event_packet(event, packet.data)
294 } else {
295 bail!("Event packet missing key");
296 }
297 }
298 _ => Ok(()),
299 }
300 }
301
302 pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
303 self.status.ensure(ConnectionStatus::Created, |status| {
305 format!("Cannot init connection in invalid status: {:#?}", status)
306 })?;
307
308 let requires_auth = self.namespace.config.auth_handler.is_some();
310
311 let packet = &WsIoPacket::new_init(self.namespace.config.packet_codec.encode_data(&requires_auth)?);
313
314 if requires_auth {
316 self.status
318 .try_transition(ConnectionStatus::Created, ConnectionStatus::AwaitingAuth)?;
319
320 let connection = self.clone();
322 *self.auth_timeout_task.lock().await = Some(spawn(async move {
323 sleep(connection.namespace.config.auth_packet_timeout).await;
324 if connection.status.is(ConnectionStatus::AwaitingAuth) {
325 connection.close();
326 }
327 }));
328
329 self.send_packet(packet).await
331 } else {
332 self.send_packet(packet).await?;
334
335 self.activate().await
337 }
338 }
339
340 pub(crate) async fn send_message(&self, message: Arc<Message>) -> Result<()> {
341 Ok(self.message_tx.send(message).await?)
342 }
343
344 pub async fn disconnect(&self) {
346 let _ = self.send_packet(&WsIoPacket::new_disconnect()).await;
347 self.close()
348 }
349
350 pub async fn emit<D: Serialize>(&self, event: impl AsRef<str>, data: Option<&D>) -> Result<()> {
351 self.emit_event_message(
352 self.namespace.encode_packet_to_message(&WsIoPacket::new_event(
353 event.as_ref(),
354 data.map(|data| self.namespace.config.packet_codec.encode_data(data))
355 .transpose()?,
356 ))?,
357 )
358 .await
359 }
360
361 #[cfg(feature = "connection-extensions")]
362 #[inline]
363 pub fn extensions(&self) -> &ConnectionExtensions {
364 &self.extensions
365 }
366
367 #[inline]
368 pub fn headers(&self) -> &HeaderMap {
369 &self.headers
370 }
371
372 #[inline]
373 pub fn id(&self) -> u64 {
374 self.id
375 }
376
377 #[inline]
378 pub fn join<I: IntoIterator<Item = S>, S: AsRef<str>>(self: &Arc<Self>, room_names: I) {
379 for room_name in room_names {
380 let room_name = room_name.as_ref();
381 self.namespace.add_connection_id_to_room(room_name, self.id);
382 self.joined_rooms.insert(room_name.to_string());
383 }
384 }
385
386 #[inline]
387 pub fn leave<I: IntoIterator<Item = S>, S: AsRef<str>>(self: &Arc<Self>, room_names: I) {
388 for room_name in room_names {
389 self.namespace
390 .remove_connection_id_from_room(room_name.as_ref(), self.id);
391
392 self.joined_rooms.remove(room_name.as_ref());
393 }
394 }
395
396 #[inline]
397 pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
398 self.namespace.clone()
399 }
400
401 #[inline]
402 pub fn off(&self, event: impl AsRef<str>) {
403 self.event_registry.off(event.as_ref());
404 }
405
406 #[inline]
407 pub fn off_by_handler_id(&self, event: impl AsRef<str>, handler_id: u32) {
408 self.event_registry.off_by_handler_id(event.as_ref(), handler_id);
409 }
410
411 #[inline]
412 pub fn on<H, Fut, D>(&self, event: impl AsRef<str>, handler: H) -> u32
413 where
414 H: Fn(Arc<WsIoServerConnection>, Arc<D>) -> Fut + Send + Sync + 'static,
415 Fut: Future<Output = Result<()>> + Send + 'static,
416 D: DeserializeOwned + Send + Sync + 'static,
417 {
418 self.event_registry.on(event.as_ref(), handler)
419 }
420
421 pub async fn on_close<H, Fut>(&self, handler: H)
422 where
423 H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
424 Fut: Future<Output = Result<()>> + Send + 'static,
425 {
426 *self.on_close_handler.lock().await = Some(Box::new(move |connection| Box::pin(handler(connection))));
427 }
428
429 #[inline]
430 pub fn server(&self) -> WsIoServer {
431 self.namespace.server()
432 }
433}
434
/// Process-wide source of connection ids; starts at 0 and is advanced with `fetch_add`.
static NEXT_CONNECTION_ID: LazyLock<AtomicU64> = LazyLock::new(AtomicU64::default);