1use std::sync::{
2 Arc,
3 LazyLock,
4 atomic::{
5 AtomicU64,
6 Ordering,
7 },
8};
9
10use anyhow::{
11 Result,
12 bail,
13};
14use arc_swap::ArcSwap;
15use http::HeaderMap;
16use num_enum::{
17 IntoPrimitive,
18 TryFromPrimitive,
19};
20use serde::{
21 Serialize,
22 de::DeserializeOwned,
23};
24use tokio::{
25 spawn,
26 sync::{
27 Mutex,
28 mpsc::{
29 Receiver,
30 Sender,
31 channel,
32 },
33 },
34 task::JoinHandle,
35 time::{
36 sleep,
37 timeout,
38 },
39};
40use tokio_tungstenite::tungstenite::Message;
41use tokio_util::sync::CancellationToken;
42
43#[cfg(feature = "connection-extensions")]
44mod extensions;
45
46#[cfg(feature = "connection-extensions")]
47use self::extensions::ConnectionExtensions;
48use crate::{
49 WsIoServer,
50 core::{
51 atomic::status::AtomicStatus,
52 channel_capacity_from_websocket_config,
53 event::registry::WsIoEventRegistry,
54 packet::{
55 WsIoPacket,
56 WsIoPacketType,
57 },
58 traits::task::spawner::TaskSpawner,
59 types::BoxAsyncUnaryResultHandler,
60 utils::task::abort_locked_task,
61 },
62 namespace::WsIoServerNamespace,
63};
64
65#[repr(u8)]
67#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
68enum ConnectionStatus {
69 Activating,
70 Authenticating,
71 AwaitingAuth,
72 Closed,
73 Closing,
74 Created,
75 Ready,
76}
77
78pub struct WsIoServerConnection {
80 auth_timeout_task: Mutex<Option<JoinHandle<()>>>,
81 cancel_token: ArcSwap<CancellationToken>,
82 event_registry: WsIoEventRegistry<WsIoServerConnection, WsIoServerConnection>,
83 #[cfg(feature = "connection-extensions")]
84 extensions: ConnectionExtensions,
85 headers: HeaderMap,
86 id: u64,
87 message_tx: Sender<Message>,
88 namespace: Arc<WsIoServerNamespace>,
89 on_close_handler: Mutex<Option<BoxAsyncUnaryResultHandler<Self>>>,
90 status: AtomicStatus<ConnectionStatus>,
91}
92
93impl TaskSpawner for WsIoServerConnection {
94 #[inline]
95 fn cancel_token(&self) -> Arc<CancellationToken> {
96 self.cancel_token.load_full()
97 }
98}
99
100impl WsIoServerConnection {
101 #[inline]
102 pub(crate) fn new(headers: HeaderMap, namespace: Arc<WsIoServerNamespace>) -> (Arc<Self>, Receiver<Message>) {
103 let channel_capacity = channel_capacity_from_websocket_config(&namespace.config.websocket_config);
104 let (message_tx, message_rx) = channel(channel_capacity);
105 (
106 Arc::new(Self {
107 auth_timeout_task: Mutex::new(None),
108 cancel_token: ArcSwap::new(Arc::new(CancellationToken::new())),
109 event_registry: WsIoEventRegistry::new(),
110 #[cfg(feature = "connection-extensions")]
111 extensions: ConnectionExtensions::new(),
112 headers,
113 id: NEXT_CONNECTION_ID.fetch_add(1, Ordering::Relaxed),
114 message_tx,
115 namespace,
116 on_close_handler: Mutex::new(None),
117 status: AtomicStatus::new(ConnectionStatus::Created),
118 }),
119 message_rx,
120 )
121 }
122
123 async fn activate(self: &Arc<Self>) -> Result<()> {
125 let status = self.status.get();
127 match status {
128 ConnectionStatus::Authenticating | ConnectionStatus::Created => {
129 self.status.try_transition(status, ConnectionStatus::Activating)?
130 }
131 _ => bail!("Cannot activate connection in invalid status: {:#?}", status),
132 }
133
134 if let Some(middleware) = &self.namespace.config.middleware {
136 timeout(
137 self.namespace.config.middleware_execution_timeout,
138 middleware(self.clone()),
139 )
140 .await??;
141 }
142
143 if let Some(on_connect_handler) = &self.namespace.config.on_connect_handler {
145 timeout(
146 self.namespace.config.on_connect_handler_timeout,
147 on_connect_handler(self.clone()),
148 )
149 .await??;
150 }
151
152 self.namespace.insert_connection(self.clone());
154
155 self.status
157 .try_transition(ConnectionStatus::Activating, ConnectionStatus::Ready)?;
158
159 self.send_packet(&WsIoPacket::new_ready()).await?;
161
162 if let Some(on_ready_handler) = self.namespace.config.on_ready_handler.clone() {
164 self.spawn_task(on_ready_handler(self.clone()));
166 }
167
168 Ok(())
169 }
170
171 #[inline]
172 fn ensure_status_ready(&self) -> Result<()> {
173 self.status.ensure(ConnectionStatus::Ready, |status| {
174 format!("Cannot emit event in invalid status: {:#?}", status)
175 })?;
176
177 Ok(())
178 }
179
180 async fn handle_auth_packet(self: &Arc<Self>, packet_data: &[u8]) -> Result<()> {
181 let status = self.status.get();
183 match status {
184 ConnectionStatus::AwaitingAuth => self.status.try_transition(status, ConnectionStatus::Authenticating)?,
185 _ => bail!("Received auth packet in invalid status: {:#?}", status),
186 }
187
188 abort_locked_task(&self.auth_timeout_task).await;
190
191 if let Some(auth_handler) = &self.namespace.config.auth_handler {
193 timeout(
194 self.namespace.config.auth_handler_timeout,
195 auth_handler(self.clone(), packet_data, &self.namespace.config.packet_codec),
196 )
197 .await??;
198
199 self.activate().await
201 } else {
202 bail!("Auth packet received but no auth handler is configured");
203 }
204 }
205
206 #[inline]
207 fn handle_event_packet(self: &Arc<Self>, event: &str, packet_data: Option<Vec<u8>>) -> Result<()> {
208 self.event_registry.dispatch_event_packet(
209 self.clone(),
210 event,
211 &self.namespace.config.packet_codec,
212 packet_data,
213 self,
214 );
215
216 Ok(())
217 }
218
219 async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
220 Ok(self
221 .message_tx
222 .send(self.namespace.encode_packet_to_message(packet)?)
223 .await?)
224 }
225
226 pub(crate) async fn cleanup(self: &Arc<Self>) {
228 self.status.store(ConnectionStatus::Closing);
230
231 abort_locked_task(&self.auth_timeout_task).await;
233
234 self.namespace.remove_connection(self.id);
236
237 self.cancel_token.load().cancel();
239
240 if let Some(on_close_handler) = self.on_close_handler.lock().await.take() {
242 let _ = timeout(
243 self.namespace.config.on_close_handler_timeout,
244 on_close_handler(self.clone()),
245 )
246 .await;
247 }
248
249 self.status.store(ConnectionStatus::Closed);
251 }
252
253 #[inline]
254 pub(crate) fn close(&self) {
255 match self.status.get() {
257 ConnectionStatus::Closed | ConnectionStatus::Closing => return,
258 _ => self.status.store(ConnectionStatus::Closing),
259 }
260
261 let _ = self.message_tx.try_send(Message::Close(None));
263 }
264
265 pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, bytes: &[u8]) -> Result<()> {
266 let packet = self.namespace.config.packet_codec.decode(bytes)?;
268 match packet.r#type {
269 WsIoPacketType::Auth => {
270 if let Some(packet_data) = packet.data.as_deref() {
271 self.handle_auth_packet(packet_data).await
272 } else {
273 bail!("Auth packet missing data");
274 }
275 }
276 WsIoPacketType::Event => {
277 if let Some(event) = packet.key.as_deref() {
278 self.handle_event_packet(event, packet.data)
279 } else {
280 bail!("Event packet missing key");
281 }
282 }
283 _ => Ok(()),
284 }
285 }
286
287 pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
288 self.status.ensure(ConnectionStatus::Created, |status| {
290 format!("Cannot init connection in invalid status: {:#?}", status)
291 })?;
292
293 let requires_auth = self.namespace.config.auth_handler.is_some();
295
296 let packet = &WsIoPacket::new_init(self.namespace.config.packet_codec.encode_data(&requires_auth)?);
298
299 if requires_auth {
301 self.status
303 .try_transition(ConnectionStatus::Created, ConnectionStatus::AwaitingAuth)?;
304
305 let connection = self.clone();
307 *self.auth_timeout_task.lock().await = Some(spawn(async move {
308 sleep(connection.namespace.config.auth_packet_timeout).await;
309 if connection.status.is(ConnectionStatus::AwaitingAuth) {
310 connection.close();
311 }
312 }));
313
314 self.send_packet(packet).await
316 } else {
317 self.send_packet(packet).await?;
319
320 self.activate().await
322 }
323 }
324
325 pub async fn disconnect(&self) {
327 let _ = self.send_packet(&WsIoPacket::new_disconnect()).await;
328 self.close()
329 }
330
331 pub async fn emit<D: Serialize>(&self, event: impl Into<String>, data: Option<&D>) -> Result<()> {
332 self.ensure_status_ready()?;
333 self.send_packet(&WsIoPacket::new_event(
334 event,
335 data.map(|data| self.namespace.config.packet_codec.encode_data(data))
336 .transpose()?,
337 ))
338 .await
339 }
340
341 pub async fn emit_message(&self, message: Message) -> Result<()> {
342 self.ensure_status_ready()?;
343 Ok(self.message_tx.send(message).await?)
344 }
345
346 #[cfg(feature = "connection-extensions")]
347 #[inline]
348 pub fn extensions(&self) -> &ConnectionExtensions {
349 &self.extensions
350 }
351
352 #[inline]
353 pub fn headers(&self) -> &HeaderMap {
354 &self.headers
355 }
356
357 #[inline]
358 pub fn id(&self) -> u64 {
359 self.id
360 }
361
362 #[inline]
363 pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
364 self.namespace.clone()
365 }
366
367 #[inline]
368 pub fn off(&self, event: impl AsRef<str>) {
369 self.event_registry.off(event);
370 }
371
372 #[inline]
373 pub fn off_by_handler_id(&self, event: impl AsRef<str>, handler_id: u32) {
374 self.event_registry.off_by_handler_id(event, handler_id);
375 }
376
377 #[inline]
378 pub fn on<H, Fut, D>(&self, event: impl Into<String>, handler: H) -> u32
379 where
380 H: Fn(Arc<WsIoServerConnection>, Arc<D>) -> Fut + Send + Sync + 'static,
381 Fut: Future<Output = Result<()>> + Send + 'static,
382 D: DeserializeOwned + Send + Sync + 'static,
383 {
384 self.event_registry.on(event, handler)
385 }
386
387 pub async fn on_close<H, Fut>(&self, handler: H)
388 where
389 H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
390 Fut: Future<Output = Result<()>> + Send + 'static,
391 {
392 *self.on_close_handler.lock().await = Some(Box::new(move |connection| Box::pin(handler(connection))));
393 }
394
395 #[inline]
396 pub fn server(&self) -> WsIoServer {
397 self.namespace.server()
398 }
399}
400
/// Process-wide counter from which every new connection takes its id, starting at 0.
///
/// `fetch_add` on a single atomic never hands out the same value twice, so `Relaxed` ordering
/// is enough for uniqueness.
static NEXT_CONNECTION_ID: LazyLock<AtomicU64> = LazyLock::new(AtomicU64::default);