wsio_client/
session.rs

1use std::sync::Arc;
2
3use anyhow::{
4    Result,
5    bail,
6};
7use arc_swap::ArcSwap;
8use num_enum::{
9    IntoPrimitive,
10    TryFromPrimitive,
11};
12use tokio::{
13    spawn,
14    sync::{
15        Mutex,
16        mpsc::{
17            Receiver,
18            Sender,
19            channel,
20        },
21    },
22    task::JoinHandle,
23    time::{
24        sleep,
25        timeout,
26    },
27};
28use tokio_tungstenite::tungstenite::Message;
29use tokio_util::sync::CancellationToken;
30
31use crate::{
32    WsIoClient,
33    core::{
34        atomic::status::AtomicStatus,
35        channel_capacity_from_websocket_config,
36        packet::{
37            WsIoPacket,
38            WsIoPacketType,
39        },
40        traits::task::spawner::TaskSpawner,
41        utils::task::abort_locked_task,
42    },
43    runtime::WsIoClientRuntime,
44};
45
46// Enums
47#[repr(u8)]
48#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
49enum SessionStatus {
50    AwaitingInit,
51    AwaitingReady,
52    Closed,
53    Closing,
54    Created,
55    Initiating,
56    Ready,
57    Readying,
58}
59
60// Structs
/// State for one client-side session: its lifecycle status, the outbound
/// message channel, and the watchdog tasks guarding the init/ready handshake.
pub struct WsIoClientSession {
    /// Cancellation token handed out through `TaskSpawner::cancel_token` and
    /// cancelled during `cleanup()`.
    cancel_token: ArcSwap<CancellationToken>,
    /// Handle of the watchdog spawned by `init()`; aborted when an init packet
    /// is handled or when the session is cleaned up.
    init_timeout_task: Mutex<Option<JoinHandle<()>>>,
    /// Sending half of the bounded outbound message channel created in `new()`.
    message_tx: Sender<Arc<Message>>,
    /// Handle of the watchdog spawned after the init packet is handled; aborted
    /// when a ready packet is handled or when the session is cleaned up.
    ready_timeout_task: Mutex<Option<JoinHandle<()>>>,
    /// Shared client runtime (configuration, event registry, packet encoding).
    runtime: Arc<WsIoClientRuntime>,
    /// Current lifecycle status of the session.
    status: AtomicStatus<SessionStatus>,
}
69
70impl TaskSpawner for WsIoClientSession {
71    #[inline]
72    fn cancel_token(&self) -> Arc<CancellationToken> {
73        self.cancel_token.load_full()
74    }
75}
76
77impl WsIoClientSession {
78    #[inline]
79    pub(crate) fn new(runtime: Arc<WsIoClientRuntime>) -> (Arc<Self>, Receiver<Arc<Message>>) {
80        let channel_capacity = channel_capacity_from_websocket_config(&runtime.config.websocket_config);
81        let (message_tx, message_rx) = channel(channel_capacity);
82        (
83            Arc::new(Self {
84                cancel_token: ArcSwap::new(Arc::new(CancellationToken::new())),
85                init_timeout_task: Mutex::new(None),
86                message_tx,
87                ready_timeout_task: Mutex::new(None),
88                runtime,
89                status: AtomicStatus::new(SessionStatus::Created),
90            }),
91            message_rx,
92        )
93    }
94
95    // Private methods
96    #[inline]
97    fn handle_disconnect_packet(&self) -> Result<()> {
98        let runtime = self.runtime.clone();
99        spawn(async move { runtime.disconnect().await });
100        Ok(())
101    }
102
103    #[inline]
104    fn handle_event_packet(self: &Arc<Self>, event: &str, packet_data: Option<Vec<u8>>) -> Result<()> {
105        self.runtime.event_registry.dispatch_event_packet(
106            self.clone(),
107            event,
108            &self.runtime.config.packet_codec,
109            packet_data,
110            &self.runtime,
111        );
112
113        Ok(())
114    }
115
116    async fn handle_init_packet(self: &Arc<Self>, packet_data: Option<&[u8]>) -> Result<()> {
117        // Verify current state; only valid from AwaitingInit → Initiating
118        let status = self.status.get();
119        match status {
120            SessionStatus::AwaitingInit => self.status.try_transition(status, SessionStatus::Initiating)?,
121            _ => bail!("Received init packet in invalid status: {status:?}"),
122        }
123
124        // Abort init-timeout task if still active
125        abort_locked_task(&self.init_timeout_task).await;
126
127        // Invoke init_handler with timeout protection if configured
128        let response_data = if let Some(init_handler) = &self.runtime.config.init_handler {
129            timeout(
130                self.runtime.config.init_handler_timeout,
131                init_handler(self.clone(), packet_data, &self.runtime.config.packet_codec),
132            )
133            .await??
134        } else {
135            None
136        };
137
138        // Transition state to AwaitingReady
139        self.status
140            .try_transition(SessionStatus::Initiating, SessionStatus::AwaitingReady)?;
141
142        // Spawn ready-timeout watchdog to close session if Ready is not received in time
143        let session = self.clone();
144        *self.ready_timeout_task.lock().await = Some(spawn(async move {
145            sleep(session.runtime.config.ready_packet_timeout).await;
146            if session.status.is(SessionStatus::AwaitingReady) {
147                session.close();
148            }
149        }));
150
151        // Send init packet
152        self.send_packet(&WsIoPacket::new_init(response_data)).await
153    }
154
155    async fn handle_ready_packet(self: &Arc<Self>) -> Result<()> {
156        // Verify current state; only valid from AwaitingReady → Ready
157        let status = self.status.get();
158        match status {
159            SessionStatus::AwaitingReady => self.status.try_transition(status, SessionStatus::Ready)?,
160            _ => bail!("Received ready packet in invalid status: {status:?}"),
161        }
162
163        // Abort ready-timeout task if still active
164        abort_locked_task(&self.ready_timeout_task).await;
165
166        // Wake event message flush task
167        self.runtime.event_message_flush_notify.notify_waiters();
168
169        // Invoke on_session_ready handler if configured
170        if let Some(on_session_ready_handler) = self.runtime.config.on_session_ready_handler.clone() {
171            // Run handler asynchronously in a detached task
172            self.spawn_task(on_session_ready_handler(self.clone()));
173        }
174
175        Ok(())
176    }
177
178    async fn send_message(&self, message: Arc<Message>) -> Result<()> {
179        Ok(self.message_tx.send(message).await?)
180    }
181
182    async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
183        self.send_message(self.runtime.encode_packet_to_message(packet)?).await
184    }
185
186    // Protected methods
187    pub(crate) async fn cleanup(self: &Arc<Self>) {
188        // Set state to Closing
189        self.status.store(SessionStatus::Closing);
190
191        // Abort timeout tasks if still active
192        abort_locked_task(&self.init_timeout_task).await;
193        abort_locked_task(&self.ready_timeout_task).await;
194
195        // Cancel all ongoing operations via cancel token
196        self.cancel_token.load().cancel();
197
198        // Invoke on_session_close handler with timeout protection if configured
199        if let Some(on_session_close_handler) = &self.runtime.config.on_session_close_handler {
200            let _ = timeout(
201                self.runtime.config.on_session_close_handler_timeout,
202                on_session_close_handler(self.clone()),
203            )
204            .await;
205        }
206
207        // Set state to Closed
208        self.status.store(SessionStatus::Closed);
209    }
210
211    #[inline]
212    pub(crate) fn close(&self) {
213        // Skip if session is already Closing or Closed, otherwise set state to Closing
214        match self.status.get() {
215            SessionStatus::Closed | SessionStatus::Closing => return,
216            _ => self.status.store(SessionStatus::Closing),
217        }
218
219        // Send websocket close frame to initiate graceful shutdown
220        let _ = self.message_tx.try_send(Arc::new(Message::Close(None)));
221    }
222
223    pub(crate) async fn emit_event_message(&self, message: Arc<Message>) -> Result<()> {
224        self.status.ensure(SessionStatus::Ready, |status| {
225            format!("Cannot emit event message in invalid status: {status:?}")
226        })?;
227
228        self.send_message(message).await
229    }
230
231    pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, encoded_packet: &[u8]) -> Result<()> {
232        // TODO: lazy load
233        let packet = self.runtime.config.packet_codec.decode(encoded_packet)?;
234        match packet.r#type {
235            WsIoPacketType::Disconnect => self.handle_disconnect_packet(),
236            WsIoPacketType::Event => {
237                if let Some(event) = packet.key.as_deref() {
238                    self.handle_event_packet(event, packet.data)
239                } else {
240                    bail!("Event packet missing key");
241                }
242            }
243            WsIoPacketType::Init => self.handle_init_packet(packet.data.as_deref()).await,
244            WsIoPacketType::Ready => self.handle_ready_packet().await,
245        }
246    }
247
248    pub(crate) async fn init(self: &Arc<Self>) {
249        self.status.store(SessionStatus::AwaitingInit);
250        let session = self.clone();
251        *self.init_timeout_task.lock().await = Some(spawn(async move {
252            sleep(session.runtime.config.init_packet_timeout).await;
253            if session.status.is(SessionStatus::AwaitingInit) {
254                session.close();
255            }
256        }));
257    }
258
259    // Public methods
260    #[inline]
261    pub fn client(&self) -> WsIoClient {
262        WsIoClient(self.runtime.clone())
263    }
264
265    #[inline]
266    pub fn is_ready(&self) -> bool {
267        self.status.is(SessionStatus::Ready)
268    }
269}