wsio_server/connection/
mod.rs

1use std::sync::{
2    Arc,
3    LazyLock,
4    atomic::{
5        AtomicU64,
6        Ordering,
7    },
8};
9
10use anyhow::{
11    Result,
12    bail,
13};
14use arc_swap::ArcSwap;
15use http::{
16    HeaderMap,
17    Uri,
18};
19use num_enum::{
20    IntoPrimitive,
21    TryFromPrimitive,
22};
23use serde::{
24    Serialize,
25    de::DeserializeOwned,
26};
27use tokio::{
28    spawn,
29    sync::{
30        Mutex,
31        mpsc::{
32            Receiver,
33            Sender,
34            channel,
35        },
36    },
37    task::JoinHandle,
38    time::{
39        sleep,
40        timeout,
41    },
42};
43use tokio_tungstenite::tungstenite::Message;
44use tokio_util::sync::CancellationToken;
45
46#[cfg(feature = "connection-extensions")]
47mod extensions;
48
49#[cfg(feature = "connection-extensions")]
50use self::extensions::ConnectionExtensions;
51use crate::{
52    WsIoServer,
53    core::{
54        atomic::status::AtomicStatus,
55        channel_capacity_from_websocket_config,
56        event::registry::WsIoEventRegistry,
57        packet::{
58            WsIoPacket,
59            WsIoPacketType,
60        },
61        traits::task::spawner::TaskSpawner,
62        types::{
63            BoxAsyncUnaryResultHandler,
64            hashers::FxDashSet,
65        },
66        utils::task::abort_locked_task,
67    },
68    namespace::WsIoServerNamespace,
69};
70
71// Enums
72#[repr(u8)]
73#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
74enum ConnectionStatus {
75    Activating,
76    AwaitingInit,
77    Closed,
78    Closing,
79    Created,
80    Initiating,
81    Ready,
82}
83
84// Structs
85pub struct WsIoServerConnection {
86    cancel_token: ArcSwap<CancellationToken>,
87    event_registry: WsIoEventRegistry<WsIoServerConnection, WsIoServerConnection>,
88    #[cfg(feature = "connection-extensions")]
89    extensions: ConnectionExtensions,
90    headers: HeaderMap,
91    id: u64,
92    init_timeout_task: Mutex<Option<JoinHandle<()>>>,
93    joined_rooms: FxDashSet<String>,
94    message_tx: Sender<Arc<Message>>,
95    namespace: Arc<WsIoServerNamespace>,
96    on_close_handler: Mutex<Option<BoxAsyncUnaryResultHandler<Self>>>,
97    request_uri: Uri,
98    status: AtomicStatus<ConnectionStatus>,
99}
100
101impl TaskSpawner for WsIoServerConnection {
102    #[inline]
103    fn cancel_token(&self) -> Arc<CancellationToken> {
104        self.cancel_token.load_full()
105    }
106}
107
108impl WsIoServerConnection {
109    #[inline]
110    pub(crate) fn new(
111        headers: HeaderMap,
112        namespace: Arc<WsIoServerNamespace>,
113        request_uri: Uri,
114    ) -> (Arc<Self>, Receiver<Arc<Message>>) {
115        let channel_capacity = channel_capacity_from_websocket_config(&namespace.config.websocket_config);
116        let (message_tx, message_rx) = channel(channel_capacity);
117        (
118            Arc::new(Self {
119                cancel_token: ArcSwap::new(Arc::new(CancellationToken::new())),
120                event_registry: WsIoEventRegistry::new(),
121                #[cfg(feature = "connection-extensions")]
122                extensions: ConnectionExtensions::new(),
123                headers,
124                id: NEXT_CONNECTION_ID.fetch_add(1, Ordering::Relaxed),
125                init_timeout_task: Mutex::new(None),
126                joined_rooms: FxDashSet::default(),
127                message_tx,
128                namespace,
129                on_close_handler: Mutex::new(None),
130                request_uri,
131                status: AtomicStatus::new(ConnectionStatus::Created),
132            }),
133            message_rx,
134        )
135    }
136
137    // Private methods
138    #[inline]
139    fn handle_event_packet(self: &Arc<Self>, event: &str, packet_data: Option<Vec<u8>>) -> Result<()> {
140        self.event_registry.dispatch_event_packet(
141            self.clone(),
142            event,
143            &self.namespace.config.packet_codec,
144            packet_data,
145            self,
146        );
147
148        Ok(())
149    }
150
151    async fn handle_init_packet(self: &Arc<Self>, packet_data: Option<&[u8]>) -> Result<()> {
152        // Verify current state; only valid from AwaitingInit → Initiating
153        let status = self.status.get();
154        match status {
155            ConnectionStatus::AwaitingInit => self.status.try_transition(status, ConnectionStatus::Initiating)?,
156            _ => bail!("Received init packet in invalid status: {status:?}"),
157        }
158
159        // Abort init-timeout task if still active
160        abort_locked_task(&self.init_timeout_task).await;
161
162        // Invoke init_response handler with timeout protection if configured
163        if let Some(init_response_handler) = &self.namespace.config.init_response_handler {
164            timeout(
165                self.namespace.config.init_response_handler_timeout,
166                init_response_handler(self.clone(), packet_data, &self.namespace.config.packet_codec),
167            )
168            .await??
169        }
170
171        // Activate connection
172        self.status
173            .try_transition(ConnectionStatus::Initiating, ConnectionStatus::Activating)?;
174
175        // Invoke middleware with timeout protection if configured
176        if let Some(middleware) = &self.namespace.config.middleware {
177            timeout(
178                self.namespace.config.middleware_execution_timeout,
179                middleware(self.clone()),
180            )
181            .await??;
182
183            // Ensure connection is still in Activating state
184            self.status.ensure(ConnectionStatus::Activating, |status| {
185                format!("Cannot activate connection in invalid status: {status:?}")
186            })?;
187        }
188
189        // Invoke on_connect handler with timeout protection if configured
190        if let Some(on_connect_handler) = &self.namespace.config.on_connect_handler {
191            timeout(
192                self.namespace.config.on_connect_handler_timeout,
193                on_connect_handler(self.clone()),
194            )
195            .await??;
196        }
197
198        // Transition state to Ready
199        self.status
200            .try_transition(ConnectionStatus::Activating, ConnectionStatus::Ready)?;
201
202        // Insert connection into namespace
203        self.namespace.insert_connection(self.clone());
204
205        // Send ready packet
206        self.send_packet(&WsIoPacket::new_ready()).await?;
207
208        // Invoke on_ready handler if configured
209        if let Some(on_ready_handler) = self.namespace.config.on_ready_handler.clone() {
210            // Run handler asynchronously in a detached task
211            self.spawn_task(on_ready_handler(self.clone()));
212        }
213
214        Ok(())
215    }
216
217    async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
218        self.send_message(self.namespace.encode_packet_to_message(packet)?)
219            .await
220    }
221
222    // Protected methods
223    pub(crate) async fn cleanup(self: &Arc<Self>) {
224        // Set connection state to Closing
225        self.status.store(ConnectionStatus::Closing);
226
227        // Remove connection from namespace
228        self.namespace.remove_connection(self.id);
229
230        // Leave all joined rooms
231        let joined_rooms = self.joined_rooms.iter().map(|entry| entry.clone()).collect::<Vec<_>>();
232        for room_name in &joined_rooms {
233            self.namespace.remove_connection_id_from_room(room_name, self.id);
234        }
235
236        self.joined_rooms.clear();
237
238        // Abort init-timeout task if still active
239        abort_locked_task(&self.init_timeout_task).await;
240
241        // Cancel all ongoing operations via cancel token
242        self.cancel_token.load().cancel();
243
244        // Invoke on_close handler with timeout protection if configured
245        if let Some(on_close_handler) = self.on_close_handler.lock().await.take() {
246            let _ = timeout(
247                self.namespace.config.on_close_handler_timeout,
248                on_close_handler(self.clone()),
249            )
250            .await;
251        }
252
253        // Set connection state to Closed
254        self.status.store(ConnectionStatus::Closed);
255    }
256
257    #[inline]
258    pub(crate) fn close(&self) {
259        // Skip if connection is already Closing or Closed, otherwise set connection state to Closing
260        match self.status.get() {
261            ConnectionStatus::Closed | ConnectionStatus::Closing => return,
262            _ => self.status.store(ConnectionStatus::Closing),
263        }
264
265        // Send websocket close frame to initiate graceful shutdown
266        let _ = self.message_tx.try_send(Arc::new(Message::Close(None)));
267    }
268
269    pub(crate) async fn emit_event_message(&self, message: Arc<Message>) -> Result<()> {
270        self.status.ensure(ConnectionStatus::Ready, |status| {
271            format!("Cannot emit in invalid status: {status:?}")
272        })?;
273
274        self.send_message(message).await
275    }
276
277    pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, bytes: &[u8]) -> Result<()> {
278        // TODO: lazy load
279        let packet = self.namespace.config.packet_codec.decode(bytes)?;
280        match packet.r#type {
281            WsIoPacketType::Event => {
282                if let Some(event) = packet.key.as_deref() {
283                    self.handle_event_packet(event, packet.data)
284                } else {
285                    bail!("Event packet missing key");
286                }
287            }
288            WsIoPacketType::Init => self.handle_init_packet(packet.data.as_deref()).await,
289            _ => Ok(()),
290        }
291    }
292
293    pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
294        // Verify current state; only valid Created
295        self.status.ensure(ConnectionStatus::Created, |status| {
296            format!("Cannot init connection in invalid status: {status:?}")
297        })?;
298
299        // Generate init request data if init request handler is configured
300        let init_request_data = if let Some(init_request_handler) = &self.namespace.config.init_request_handler {
301            timeout(
302                self.namespace.config.init_request_handler_timeout,
303                init_request_handler(self.clone(), &self.namespace.config.packet_codec),
304            )
305            .await??
306        } else {
307            None
308        };
309
310        // Transition state to AwaitingInit
311        self.status
312            .try_transition(ConnectionStatus::Created, ConnectionStatus::AwaitingInit)?;
313
314        // Spawn init-response-timeout watchdog to close connection if init not received in time
315        let connection = self.clone();
316        *self.init_timeout_task.lock().await = Some(spawn(async move {
317            sleep(connection.namespace.config.init_response_timeout).await;
318            if connection.status.is(ConnectionStatus::AwaitingInit) {
319                connection.close();
320            }
321        }));
322
323        // Send init packet
324        self.send_packet(&WsIoPacket::new_init(init_request_data)).await
325    }
326
327    pub(crate) async fn send_message(&self, message: Arc<Message>) -> Result<()> {
328        Ok(self.message_tx.send(message).await?)
329    }
330
331    // Public methods
332    pub async fn disconnect(&self) {
333        let _ = self.send_packet(&WsIoPacket::new_disconnect()).await;
334        self.close()
335    }
336
337    pub async fn emit<D: Serialize>(&self, event: impl AsRef<str>, data: Option<&D>) -> Result<()> {
338        self.emit_event_message(
339            self.namespace.encode_packet_to_message(&WsIoPacket::new_event(
340                event.as_ref(),
341                data.map(|data| self.namespace.config.packet_codec.encode_data(data))
342                    .transpose()?,
343            ))?,
344        )
345        .await
346    }
347
348    #[cfg(feature = "connection-extensions")]
349    #[inline]
350    pub fn extensions(&self) -> &ConnectionExtensions {
351        &self.extensions
352    }
353
354    #[inline]
355    pub fn headers(&self) -> &HeaderMap {
356        &self.headers
357    }
358
359    #[inline]
360    pub fn id(&self) -> u64 {
361        self.id
362    }
363
364    #[inline]
365    pub fn join<I: IntoIterator<Item = S>, S: AsRef<str>>(self: &Arc<Self>, room_names: I) {
366        for room_name in room_names {
367            let room_name = room_name.as_ref();
368            self.namespace.add_connection_id_to_room(room_name, self.id);
369            self.joined_rooms.insert(room_name.to_string());
370        }
371    }
372
373    #[inline]
374    pub fn leave<I: IntoIterator<Item = S>, S: AsRef<str>>(self: &Arc<Self>, room_names: I) {
375        for room_name in room_names {
376            self.namespace
377                .remove_connection_id_from_room(room_name.as_ref(), self.id);
378
379            self.joined_rooms.remove(room_name.as_ref());
380        }
381    }
382
383    #[inline]
384    pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
385        self.namespace.clone()
386    }
387
388    #[inline]
389    pub fn off(&self, event: impl AsRef<str>) {
390        self.event_registry.off(event.as_ref());
391    }
392
393    #[inline]
394    pub fn off_by_handler_id(&self, event: impl AsRef<str>, handler_id: u32) {
395        self.event_registry.off_by_handler_id(event.as_ref(), handler_id);
396    }
397
398    #[inline]
399    pub fn on<H, Fut, D>(&self, event: impl AsRef<str>, handler: H) -> u32
400    where
401        H: Fn(Arc<WsIoServerConnection>, Arc<D>) -> Fut + Send + Sync + 'static,
402        Fut: Future<Output = Result<()>> + Send + 'static,
403        D: DeserializeOwned + Send + Sync + 'static,
404    {
405        self.event_registry.on(event.as_ref(), handler)
406    }
407
408    pub async fn on_close<H, Fut>(&self, handler: H)
409    where
410        H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
411        Fut: Future<Output = Result<()>> + Send + 'static,
412    {
413        *self.on_close_handler.lock().await = Some(Box::new(move |connection| Box::pin(handler(connection))));
414    }
415
416    #[inline]
417    pub fn request_uri(&self) -> &Uri {
418        &self.request_uri
419    }
420
421    #[inline]
422    pub fn server(&self) -> WsIoServer {
423        self.namespace.server()
424    }
425}
426
// Constants/Statics

/// Process-wide counter that hands out connection ids, starting at 0.
///
/// `AtomicU64::default()` yields an atomic initialized to zero.
static NEXT_CONNECTION_ID: LazyLock<AtomicU64> = LazyLock::new(AtomicU64::default);