wsio_client/
session.rs

1use std::sync::Arc;
2
3use anyhow::{
4    Result,
5    bail,
6};
7use arc_swap::ArcSwap;
8use num_enum::{
9    IntoPrimitive,
10    TryFromPrimitive,
11};
12use tokio::{
13    spawn,
14    sync::{
15        Mutex,
16        mpsc::{
17            Receiver,
18            Sender,
19            channel,
20        },
21    },
22    task::JoinHandle,
23    time::{
24        sleep,
25        timeout,
26    },
27};
28use tokio_tungstenite::tungstenite::Message;
29use tokio_util::sync::CancellationToken;
30
31use crate::{
32    WsIoClient,
33    core::{
34        atomic::status::AtomicStatus,
35        channel_capacity_from_websocket_config,
36        packet::{
37            WsIoPacket,
38            WsIoPacketType,
39        },
40        traits::task::spawner::TaskSpawner,
41        utils::task::abort_locked_task,
42    },
43    runtime::WsIoClientRuntime,
44};
45
46// Enums
47#[repr(u8)]
48#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
49enum SessionStatus {
50    AwaitingInit,
51    AwaitingReady,
52    Closed,
53    Closing,
54    Created,
55    Initiating,
56    Ready,
57    Readying,
58}
59
60// Structs
61pub struct WsIoClientSession {
62    cancel_token: ArcSwap<CancellationToken>,
63    init_timeout_task: Mutex<Option<JoinHandle<()>>>,
64    message_tx: Sender<Arc<Message>>,
65    ready_timeout_task: Mutex<Option<JoinHandle<()>>>,
66    runtime: Arc<WsIoClientRuntime>,
67    status: AtomicStatus<SessionStatus>,
68}
69
70impl TaskSpawner for WsIoClientSession {
71    #[inline]
72    fn cancel_token(&self) -> Arc<CancellationToken> {
73        self.cancel_token.load_full()
74    }
75}
76
77impl WsIoClientSession {
78    #[inline]
79    pub(crate) fn new(runtime: Arc<WsIoClientRuntime>) -> (Arc<Self>, Receiver<Arc<Message>>) {
80        let channel_capacity = channel_capacity_from_websocket_config(&runtime.config.websocket_config);
81        let (message_tx, message_rx) = channel(channel_capacity);
82        (
83            Arc::new(Self {
84                cancel_token: ArcSwap::new(Arc::new(CancellationToken::new())),
85                init_timeout_task: Mutex::new(None),
86                message_tx,
87                ready_timeout_task: Mutex::new(None),
88                runtime,
89                status: AtomicStatus::new(SessionStatus::Created),
90            }),
91            message_rx,
92        )
93    }
94
95    // Private methods
96    #[inline]
97    fn handle_disconnect_packet(&self) -> Result<()> {
98        let runtime = self.runtime.clone();
99        spawn(async move { runtime.disconnect().await });
100        Ok(())
101    }
102
103    #[inline]
104    fn handle_event_packet(self: &Arc<Self>, event: &str, packet_data: Option<Vec<u8>>) -> Result<()> {
105        if self.is_ready() {
106            self.runtime.event_registry.dispatch_event_packet(
107                self.clone(),
108                event,
109                &self.runtime.config.packet_codec,
110                packet_data,
111                &self.runtime,
112            );
113        }
114
115        Ok(())
116    }
117
118    async fn handle_init_packet(self: &Arc<Self>, packet_data: Option<&[u8]>) -> Result<()> {
119        // Verify current state; only valid from AwaitingInit → Initiating
120        let status = self.status.get();
121        match status {
122            SessionStatus::AwaitingInit => self.status.try_transition(status, SessionStatus::Initiating)?,
123            _ => bail!("Received init packet in invalid status: {status:?}"),
124        }
125
126        // Abort init-timeout task
127        abort_locked_task(&self.init_timeout_task).await;
128
129        // Invoke init_handler with timeout protection if configured
130        let response_data = if let Some(init_handler) = &self.runtime.config.init_handler {
131            timeout(
132                self.runtime.config.init_handler_timeout,
133                init_handler(self.clone(), packet_data, &self.runtime.config.packet_codec),
134            )
135            .await??
136        } else {
137            None
138        };
139
140        // Transition state to AwaitingReady
141        self.status
142            .try_transition(SessionStatus::Initiating, SessionStatus::AwaitingReady)?;
143
144        // Spawn ready-timeout watchdog to close session if Ready is not received in time
145        let session = self.clone();
146        *self.ready_timeout_task.lock().await = Some(spawn(async move {
147            sleep(session.runtime.config.ready_packet_timeout).await;
148            if session.status.is(SessionStatus::AwaitingReady) {
149                session.close();
150            }
151        }));
152
153        // Send init packet
154        self.send_packet(&WsIoPacket::new_init(response_data)).await
155    }
156
157    async fn handle_ready_packet(self: &Arc<Self>) -> Result<()> {
158        // Verify current state; only valid from AwaitingReady → Ready
159        let status = self.status.get();
160        match status {
161            SessionStatus::AwaitingReady => self.status.try_transition(status, SessionStatus::Ready)?,
162            _ => bail!("Received ready packet in invalid status: {status:?}"),
163        }
164
165        // Abort ready-timeout task
166        abort_locked_task(&self.ready_timeout_task).await;
167
168        // Wake send event message task
169        self.runtime.wake_send_event_message_task_notify.notify_waiters();
170
171        // Invoke on_session_ready_handler if configured
172        if let Some(on_session_ready_handler) = self.runtime.config.on_session_ready_handler.clone() {
173            // Run handler asynchronously in a detached task
174            self.spawn_task(on_session_ready_handler(self.clone()));
175        }
176
177        Ok(())
178    }
179
180    async fn send_message(&self, message: Arc<Message>) -> Result<()> {
181        Ok(self.message_tx.send(message).await?)
182    }
183
184    async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
185        self.send_message(self.runtime.encode_packet_to_message(packet)?).await
186    }
187
188    // Protected methods
189    pub(crate) async fn cleanup(self: &Arc<Self>) {
190        // Set state to Closing
191        self.status.store(SessionStatus::Closing);
192
193        // Abort timeout tasks
194        abort_locked_task(&self.init_timeout_task).await;
195        abort_locked_task(&self.ready_timeout_task).await;
196
197        // Cancel all ongoing operations via cancel token
198        self.cancel_token.load().cancel();
199
200        // Invoke on_session_close_handler with timeout protection if configured
201        if let Some(on_session_close_handler) = &self.runtime.config.on_session_close_handler {
202            let _ = timeout(
203                self.runtime.config.on_session_close_handler_timeout,
204                on_session_close_handler(self.clone()),
205            )
206            .await;
207        }
208
209        // Set state to Closed
210        self.status.store(SessionStatus::Closed);
211    }
212
213    #[inline]
214    pub(crate) fn close(&self) {
215        // Skip if session is already Closing or Closed, otherwise set state to Closing
216        match self.status.get() {
217            SessionStatus::Closed | SessionStatus::Closing => return,
218            _ => self.status.store(SessionStatus::Closing),
219        }
220
221        // Send websocket close frame to initiate graceful shutdown
222        let _ = self.message_tx.try_send(Arc::new(Message::Close(None)));
223    }
224
225    pub(crate) async fn emit_event_message(&self, message: Arc<Message>) -> Result<()> {
226        self.status.ensure(SessionStatus::Ready, |status| {
227            format!("Cannot emit event message in invalid status: {status:?}")
228        })?;
229
230        self.send_message(message).await
231    }
232
233    pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, encoded_packet: &[u8]) -> Result<()> {
234        // TODO: lazy load
235        let packet = self.runtime.config.packet_codec.decode(encoded_packet)?;
236        match packet.r#type {
237            WsIoPacketType::Disconnect => self.handle_disconnect_packet(),
238            WsIoPacketType::Event => {
239                if let Some(event) = packet.key.as_deref() {
240                    self.handle_event_packet(event, packet.data)
241                } else {
242                    bail!("Event packet missing key");
243                }
244            }
245            WsIoPacketType::Init => self.handle_init_packet(packet.data.as_deref()).await,
246            WsIoPacketType::Ready => self.handle_ready_packet().await,
247        }
248    }
249
250    pub(crate) async fn init(self: &Arc<Self>) {
251        self.status.store(SessionStatus::AwaitingInit);
252        let session = self.clone();
253        *self.init_timeout_task.lock().await = Some(spawn(async move {
254            sleep(session.runtime.config.init_packet_timeout).await;
255            if session.status.is(SessionStatus::AwaitingInit) {
256                session.close();
257            }
258        }));
259    }
260
261    // Public methods
262    #[inline]
263    pub fn client(&self) -> WsIoClient {
264        WsIoClient(self.runtime.clone())
265    }
266
267    #[inline]
268    pub fn is_ready(&self) -> bool {
269        self.status.is(SessionStatus::Ready)
270    }
271}