wsio_server/connection/
mod.rs

1use std::sync::{
2    Arc,
3    LazyLock,
4    atomic::{
5        AtomicU64,
6        Ordering,
7    },
8};
9
10use anyhow::{
11    Result,
12    bail,
13};
14use arc_swap::ArcSwap;
15use http::{
16    HeaderMap,
17    Uri,
18};
19use num_enum::{
20    IntoPrimitive,
21    TryFromPrimitive,
22};
23use serde::{
24    Serialize,
25    de::DeserializeOwned,
26};
27use tokio::{
28    spawn,
29    sync::{
30        Mutex,
31        mpsc::{
32            Receiver,
33            Sender,
34            channel,
35        },
36    },
37    task::JoinHandle,
38    time::{
39        sleep,
40        timeout,
41    },
42};
43use tokio_tungstenite::tungstenite::Message;
44use tokio_util::sync::CancellationToken;
45
46#[cfg(feature = "connection-extensions")]
47mod extensions;
48
49#[cfg(feature = "connection-extensions")]
50use self::extensions::ConnectionExtensions;
51use crate::{
52    WsIoServer,
53    core::{
54        atomic::status::AtomicStatus,
55        channel_capacity_from_websocket_config,
56        event::registry::WsIoEventRegistry,
57        packet::{
58            WsIoPacket,
59            WsIoPacketType,
60        },
61        traits::task::spawner::TaskSpawner,
62        types::{
63            BoxAsyncUnaryResultHandler,
64            hashers::FxDashSet,
65        },
66        utils::task::abort_locked_task,
67    },
68    namespace::{
69        WsIoServerNamespace,
70        operators::broadcast::WsIoServerNamespaceBroadcastOperator,
71    },
72};
73
74// Enums
75#[repr(u8)]
76#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
77enum ConnectionStatus {
78    Activating,
79    AwaitingInit,
80    Closed,
81    Closing,
82    Created,
83    Initiating,
84    Ready,
85}
86
// Structs

/// Server-side state of a single websocket connection belonging to a namespace.
pub struct WsIoServerConnection {
    /// Token returned by `TaskSpawner::cancel_token`; cancelled during `cleanup`.
    cancel_token: ArcSwap<CancellationToken>,
    /// Event handlers registered via `on` / `off` and dispatched for incoming event packets.
    event_registry: WsIoEventRegistry<WsIoServerConnection, WsIoServerConnection>,
    /// Per-connection extensions, only present with the `connection-extensions` feature.
    #[cfg(feature = "connection-extensions")]
    extensions: ConnectionExtensions,
    /// Headers handed to `new` (presumably from the upgrade request — confirm at the call site).
    headers: HeaderMap,
    /// Id drawn from the process-wide `NEXT_CONNECTION_ID` counter.
    id: u64,
    /// Init watchdog spawned by `init`; aborted once the init packet arrives or on `cleanup`.
    init_timeout_task: Mutex<Option<JoinHandle<()>>>,
    /// Rooms joined through `join`, tracked so `cleanup` can leave them all.
    joined_rooms: FxDashSet<String>,
    /// Sending half of the outgoing message channel; the receiver is returned from `new`.
    message_tx: Sender<Arc<Message>>,
    /// Namespace this connection belongs to.
    namespace: Arc<WsIoServerNamespace>,
    /// Handler set via `on_close`; taken and invoked at most once by `cleanup`.
    on_close_handler: Mutex<Option<BoxAsyncUnaryResultHandler<Self>>>,
    /// Request URI handed to `new`.
    request_uri: Uri,
    /// Current lifecycle state.
    status: AtomicStatus<ConnectionStatus>,
}
103
104impl TaskSpawner for WsIoServerConnection {
105    #[inline]
106    fn cancel_token(&self) -> Arc<CancellationToken> {
107        self.cancel_token.load_full()
108    }
109}
110
111impl WsIoServerConnection {
112    #[inline]
113    pub(crate) fn new(
114        headers: HeaderMap,
115        namespace: Arc<WsIoServerNamespace>,
116        request_uri: Uri,
117    ) -> (Arc<Self>, Receiver<Arc<Message>>) {
118        let channel_capacity = channel_capacity_from_websocket_config(&namespace.config.websocket_config);
119        let (message_tx, message_rx) = channel(channel_capacity);
120        (
121            Arc::new(Self {
122                cancel_token: ArcSwap::new(Arc::new(CancellationToken::new())),
123                event_registry: WsIoEventRegistry::new(),
124                #[cfg(feature = "connection-extensions")]
125                extensions: ConnectionExtensions::new(),
126                headers,
127                id: NEXT_CONNECTION_ID.fetch_add(1, Ordering::Relaxed),
128                init_timeout_task: Mutex::new(None),
129                joined_rooms: FxDashSet::default(),
130                message_tx,
131                namespace,
132                on_close_handler: Mutex::new(None),
133                request_uri,
134                status: AtomicStatus::new(ConnectionStatus::Created),
135            }),
136            message_rx,
137        )
138    }
139
140    // Private methods
141    #[inline]
142    fn handle_event_packet(self: &Arc<Self>, event: &str, packet_data: Option<Vec<u8>>) -> Result<()> {
143        self.event_registry.dispatch_event_packet(
144            self.clone(),
145            event,
146            &self.namespace.config.packet_codec,
147            packet_data,
148            self,
149        );
150
151        Ok(())
152    }
153
154    async fn handle_init_packet(self: &Arc<Self>, packet_data: Option<&[u8]>) -> Result<()> {
155        // Verify current state; only valid from AwaitingInit → Initiating
156        let status = self.status.get();
157        match status {
158            ConnectionStatus::AwaitingInit => self.status.try_transition(status, ConnectionStatus::Initiating)?,
159            _ => bail!("Received init packet in invalid status: {status:?}"),
160        }
161
162        // Abort init-timeout task if still active
163        abort_locked_task(&self.init_timeout_task).await;
164
165        // Invoke init_response_handler with timeout protection if configured
166        if let Some(init_response_handler) = &self.namespace.config.init_response_handler {
167            timeout(
168                self.namespace.config.init_response_handler_timeout,
169                init_response_handler(self.clone(), packet_data, &self.namespace.config.packet_codec),
170            )
171            .await??
172        }
173
174        // Activate connection
175        self.status
176            .try_transition(ConnectionStatus::Initiating, ConnectionStatus::Activating)?;
177
178        // Invoke middleware with timeout protection if configured
179        if let Some(middleware) = &self.namespace.config.middleware {
180            timeout(
181                self.namespace.config.middleware_execution_timeout,
182                middleware(self.clone()),
183            )
184            .await??;
185
186            // Ensure connection is still in Activating state
187            self.status.ensure(ConnectionStatus::Activating, |status| {
188                format!("Cannot activate connection in invalid status: {status:?}")
189            })?;
190        }
191
192        // Invoke on_connect_handler with timeout protection if configured
193        if let Some(on_connect_handler) = &self.namespace.config.on_connect_handler {
194            timeout(
195                self.namespace.config.on_connect_handler_timeout,
196                on_connect_handler(self.clone()),
197            )
198            .await??;
199        }
200
201        // Transition state to Ready
202        self.status
203            .try_transition(ConnectionStatus::Activating, ConnectionStatus::Ready)?;
204
205        // Insert connection into namespace
206        self.namespace.insert_connection(self.clone());
207
208        // Send ready packet
209        self.send_packet(&WsIoPacket::new_ready()).await?;
210
211        // Invoke on_ready_handler if configured
212        if let Some(on_ready_handler) = self.namespace.config.on_ready_handler.clone() {
213            // Run handler asynchronously in a detached task
214            self.spawn_task(on_ready_handler(self.clone()));
215        }
216
217        Ok(())
218    }
219
220    async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
221        self.send_message(self.namespace.encode_packet_to_message(packet)?)
222            .await
223    }
224
225    // Protected methods
226    pub(crate) async fn cleanup(self: &Arc<Self>) {
227        // Set connection state to Closing
228        self.status.store(ConnectionStatus::Closing);
229
230        // Remove connection from namespace
231        self.namespace.remove_connection(self.id);
232
233        // Leave all joined rooms
234        let joined_rooms = self.joined_rooms.iter().map(|entry| entry.clone()).collect::<Vec<_>>();
235        for room_name in &joined_rooms {
236            self.namespace.remove_connection_id_from_room(room_name, self.id);
237        }
238
239        self.joined_rooms.clear();
240
241        // Abort init-timeout task if still active
242        abort_locked_task(&self.init_timeout_task).await;
243
244        // Cancel all ongoing operations via cancel token
245        self.cancel_token.load().cancel();
246
247        // Invoke on_close_handler with timeout protection if configured
248        if let Some(on_close_handler) = self.on_close_handler.lock().await.take() {
249            let _ = timeout(
250                self.namespace.config.on_close_handler_timeout,
251                on_close_handler(self.clone()),
252            )
253            .await;
254        }
255
256        // Set connection state to Closed
257        self.status.store(ConnectionStatus::Closed);
258    }
259
260    #[inline]
261    pub(crate) fn close(&self) {
262        // Skip if connection is already Closing or Closed, otherwise set connection state to Closing
263        match self.status.get() {
264            ConnectionStatus::Closed | ConnectionStatus::Closing => return,
265            _ => self.status.store(ConnectionStatus::Closing),
266        }
267
268        // Send websocket close frame to initiate graceful shutdown
269        let _ = self.message_tx.try_send(Arc::new(Message::Close(None)));
270    }
271
272    pub(crate) async fn emit_event_message(&self, message: Arc<Message>) -> Result<()> {
273        self.status.ensure(ConnectionStatus::Ready, |status| {
274            format!("Cannot emit in invalid status: {status:?}")
275        })?;
276
277        self.send_message(message).await
278    }
279
280    pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, encoded_packet: &[u8]) -> Result<()> {
281        // TODO: lazy load
282        let packet = self.namespace.config.packet_codec.decode(encoded_packet)?;
283        match packet.r#type {
284            WsIoPacketType::Event => {
285                if let Some(event) = packet.key.as_deref() {
286                    self.handle_event_packet(event, packet.data)
287                } else {
288                    bail!("Event packet missing key");
289                }
290            }
291            WsIoPacketType::Init => self.handle_init_packet(packet.data.as_deref()).await,
292            _ => Ok(()),
293        }
294    }
295
296    pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
297        // Verify current state; only valid Created
298        self.status.ensure(ConnectionStatus::Created, |status| {
299            format!("Cannot init connection in invalid status: {status:?}")
300        })?;
301
302        // Generate init request data if init request handler is configured
303        let init_request_data = if let Some(init_request_handler) = &self.namespace.config.init_request_handler {
304            timeout(
305                self.namespace.config.init_request_handler_timeout,
306                init_request_handler(self.clone(), &self.namespace.config.packet_codec),
307            )
308            .await??
309        } else {
310            None
311        };
312
313        // Transition state to AwaitingInit
314        self.status
315            .try_transition(ConnectionStatus::Created, ConnectionStatus::AwaitingInit)?;
316
317        // Spawn init-response-timeout watchdog to close connection if init not received in time
318        let connection = self.clone();
319        *self.init_timeout_task.lock().await = Some(spawn(async move {
320            sleep(connection.namespace.config.init_response_timeout).await;
321            if connection.status.is(ConnectionStatus::AwaitingInit) {
322                connection.close();
323            }
324        }));
325
326        // Send init packet
327        self.send_packet(&WsIoPacket::new_init(init_request_data)).await
328    }
329
330    pub(crate) async fn send_message(&self, message: Arc<Message>) -> Result<()> {
331        Ok(self.message_tx.send(message).await?)
332    }
333
334    // Public methods
335    pub async fn disconnect(&self) {
336        let _ = self.send_packet(&WsIoPacket::new_disconnect()).await;
337        self.close()
338    }
339
340    pub async fn emit<D: Serialize>(&self, event: impl AsRef<str>, data: Option<&D>) -> Result<()> {
341        self.emit_event_message(
342            self.namespace.encode_packet_to_message(&WsIoPacket::new_event(
343                event.as_ref(),
344                data.map(|data| self.namespace.config.packet_codec.encode_data(data))
345                    .transpose()?,
346            ))?,
347        )
348        .await
349    }
350
351    #[inline]
352    pub fn except(
353        self: &Arc<Self>,
354        room_names: impl IntoIterator<Item = impl AsRef<str>>,
355    ) -> WsIoServerNamespaceBroadcastOperator {
356        self.namespace.except(room_names).except_connection_ids(vec![self.id])
357    }
358
359    #[cfg(feature = "connection-extensions")]
360    #[inline]
361    pub fn extensions(&self) -> &ConnectionExtensions {
362        &self.extensions
363    }
364
365    #[inline]
366    pub fn headers(&self) -> &HeaderMap {
367        &self.headers
368    }
369
370    #[inline]
371    pub fn id(&self) -> u64 {
372        self.id
373    }
374
375    #[inline]
376    pub fn join(self: &Arc<Self>, room_names: impl IntoIterator<Item = impl AsRef<str>>) {
377        for room_name in room_names {
378            let room_name = room_name.as_ref();
379            self.namespace.add_connection_id_to_room(room_name, self.id);
380            self.joined_rooms.insert(room_name.into());
381        }
382    }
383
384    #[inline]
385    pub fn leave(self: &Arc<Self>, room_names: impl IntoIterator<Item = impl AsRef<str>>) {
386        for room_name in room_names {
387            self.namespace
388                .remove_connection_id_from_room(room_name.as_ref(), self.id);
389
390            self.joined_rooms.remove(room_name.as_ref());
391        }
392    }
393
394    #[inline]
395    pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
396        self.namespace.clone()
397    }
398
399    #[inline]
400    pub fn off(&self, event: impl AsRef<str>) {
401        self.event_registry.off(event.as_ref());
402    }
403
404    #[inline]
405    pub fn off_by_handler_id(&self, event: impl AsRef<str>, handler_id: u32) {
406        self.event_registry.off_by_handler_id(event.as_ref(), handler_id);
407    }
408
409    #[inline]
410    pub fn on<H, Fut, D>(&self, event: impl AsRef<str>, handler: H) -> u32
411    where
412        H: Fn(Arc<WsIoServerConnection>, Arc<D>) -> Fut + Send + Sync + 'static,
413        Fut: Future<Output = Result<()>> + Send + 'static,
414        D: DeserializeOwned + Send + Sync + 'static,
415    {
416        self.event_registry.on(event.as_ref(), handler)
417    }
418
419    pub async fn on_close<H, Fut>(&self, handler: H)
420    where
421        H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
422        Fut: Future<Output = Result<()>> + Send + 'static,
423    {
424        *self.on_close_handler.lock().await = Some(Box::new(move |connection| Box::pin(handler(connection))));
425    }
426
427    #[inline]
428    pub fn request_uri(&self) -> &Uri {
429        &self.request_uri
430    }
431
432    #[inline]
433    pub fn server(&self) -> WsIoServer {
434        self.namespace.server()
435    }
436
437    #[inline]
438    pub fn to(
439        self: &Arc<Self>,
440        room_names: impl IntoIterator<Item = impl AsRef<str>>,
441    ) -> WsIoServerNamespaceBroadcastOperator {
442        self.namespace.to(room_names).except_connection_ids(vec![self.id])
443    }
444}
445
// Constants/Statics

/// Process-wide source of connection ids; starts at 0 and is bumped once per `new` call.
static NEXT_CONNECTION_ID: LazyLock<AtomicU64> = LazyLock::new(AtomicU64::default);