1use std::sync::Arc;
2
3use anyhow::{
4 Result,
5 bail,
6};
7use arc_swap::ArcSwap;
8use num_enum::{
9 IntoPrimitive,
10 TryFromPrimitive,
11};
12use tokio::{
13 spawn,
14 sync::{
15 Mutex,
16 mpsc::{
17 Receiver,
18 Sender,
19 channel,
20 },
21 },
22 task::JoinHandle,
23 time::{
24 sleep,
25 timeout,
26 },
27};
28use tokio_tungstenite::tungstenite::Message;
29use tokio_util::sync::CancellationToken;
30
31use crate::{
32 WsIoClient,
33 core::{
34 atomic::status::AtomicStatus,
35 channel_capacity_from_websocket_config,
36 packet::{
37 WsIoPacket,
38 WsIoPacketType,
39 },
40 traits::task::spawner::TaskSpawner,
41 utils::task::abort_locked_task,
42 },
43 runtime::WsIoClientRuntime,
44};
45
46#[repr(u8)]
48#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
49enum SessionStatus {
50 AwaitingInit,
51 AwaitingReady,
52 Closed,
53 Closing,
54 Created,
55 Initiating,
56 Ready,
57 Readying,
58}
59
60pub struct WsIoClientSession {
62 cancel_token: ArcSwap<CancellationToken>,
63 init_timeout_task: Mutex<Option<JoinHandle<()>>>,
64 message_tx: Sender<Arc<Message>>,
65 ready_timeout_task: Mutex<Option<JoinHandle<()>>>,
66 runtime: Arc<WsIoClientRuntime>,
67 status: AtomicStatus<SessionStatus>,
68}
69
70impl TaskSpawner for WsIoClientSession {
71 #[inline]
72 fn cancel_token(&self) -> Arc<CancellationToken> {
73 self.cancel_token.load_full()
74 }
75}
76
77impl WsIoClientSession {
78 #[inline]
79 pub(crate) fn new(runtime: Arc<WsIoClientRuntime>) -> (Arc<Self>, Receiver<Arc<Message>>) {
80 let channel_capacity = channel_capacity_from_websocket_config(&runtime.config.websocket_config);
81 let (message_tx, message_rx) = channel(channel_capacity);
82 (
83 Arc::new(Self {
84 cancel_token: ArcSwap::new(Arc::new(CancellationToken::new())),
85 init_timeout_task: Mutex::new(None),
86 message_tx,
87 ready_timeout_task: Mutex::new(None),
88 runtime,
89 status: AtomicStatus::new(SessionStatus::Created),
90 }),
91 message_rx,
92 )
93 }
94
95 #[inline]
97 fn handle_disconnect_packet(&self) -> Result<()> {
98 let runtime = self.runtime.clone();
99 spawn(async move { runtime.disconnect().await });
100 Ok(())
101 }
102
103 #[inline]
104 fn handle_event_packet(self: &Arc<Self>, event: &str, packet_data: Option<Vec<u8>>) -> Result<()> {
105 if self.is_ready() {
106 self.runtime.event_registry.dispatch_event_packet(
107 self.clone(),
108 event,
109 &self.runtime.config.packet_codec,
110 packet_data,
111 &self.runtime,
112 );
113 }
114
115 Ok(())
116 }
117
118 async fn handle_init_packet(self: &Arc<Self>, packet_data: Option<&[u8]>) -> Result<()> {
119 let status = self.status.get();
121 match status {
122 SessionStatus::AwaitingInit => self.status.try_transition(status, SessionStatus::Initiating)?,
123 _ => bail!("Received init packet in invalid status: {status:?}"),
124 }
125
126 abort_locked_task(&self.init_timeout_task).await;
128
129 let response_data = if let Some(init_handler) = &self.runtime.config.init_handler {
131 timeout(
132 self.runtime.config.init_handler_timeout,
133 init_handler(self.clone(), packet_data, &self.runtime.config.packet_codec),
134 )
135 .await??
136 } else {
137 None
138 };
139
140 self.status
142 .try_transition(SessionStatus::Initiating, SessionStatus::AwaitingReady)?;
143
144 let session = self.clone();
146 *self.ready_timeout_task.lock().await = Some(spawn(async move {
147 sleep(session.runtime.config.ready_packet_timeout).await;
148 if session.status.is(SessionStatus::AwaitingReady) {
149 session.close();
150 }
151 }));
152
153 self.send_packet(&WsIoPacket::new_init(response_data)).await
155 }
156
157 async fn handle_ready_packet(self: &Arc<Self>) -> Result<()> {
158 let status = self.status.get();
160 match status {
161 SessionStatus::AwaitingReady => self.status.try_transition(status, SessionStatus::Ready)?,
162 _ => bail!("Received ready packet in invalid status: {status:?}"),
163 }
164
165 abort_locked_task(&self.ready_timeout_task).await;
167
168 self.runtime.wake_send_event_message_task_notify.notify_waiters();
170
171 if let Some(on_session_ready_handler) = self.runtime.config.on_session_ready_handler.clone() {
173 self.spawn_task(on_session_ready_handler(self.clone()));
175 }
176
177 Ok(())
178 }
179
180 async fn send_message(&self, message: Arc<Message>) -> Result<()> {
181 Ok(self.message_tx.send(message).await?)
182 }
183
184 async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
185 self.send_message(self.runtime.encode_packet_to_message(packet)?).await
186 }
187
188 pub(crate) async fn cleanup(self: &Arc<Self>) {
190 self.status.store(SessionStatus::Closing);
192
193 abort_locked_task(&self.init_timeout_task).await;
195 abort_locked_task(&self.ready_timeout_task).await;
196
197 self.cancel_token.load().cancel();
199
200 if let Some(on_session_close_handler) = &self.runtime.config.on_session_close_handler {
202 let _ = timeout(
203 self.runtime.config.on_session_close_handler_timeout,
204 on_session_close_handler(self.clone()),
205 )
206 .await;
207 }
208
209 self.status.store(SessionStatus::Closed);
211 }
212
213 #[inline]
214 pub(crate) fn close(&self) {
215 match self.status.get() {
217 SessionStatus::Closed | SessionStatus::Closing => return,
218 _ => self.status.store(SessionStatus::Closing),
219 }
220
221 let _ = self.message_tx.try_send(Arc::new(Message::Close(None)));
223 }
224
225 pub(crate) async fn emit_event_message(&self, message: Arc<Message>) -> Result<()> {
226 self.status.ensure(SessionStatus::Ready, |status| {
227 format!("Cannot emit event message in invalid status: {status:?}")
228 })?;
229
230 self.send_message(message).await
231 }
232
233 pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, encoded_packet: &[u8]) -> Result<()> {
234 let packet = self.runtime.config.packet_codec.decode(encoded_packet)?;
236 match packet.r#type {
237 WsIoPacketType::Disconnect => self.handle_disconnect_packet(),
238 WsIoPacketType::Event => {
239 if let Some(event) = packet.key.as_deref() {
240 self.handle_event_packet(event, packet.data)
241 } else {
242 bail!("Event packet missing key");
243 }
244 }
245 WsIoPacketType::Init => self.handle_init_packet(packet.data.as_deref()).await,
246 WsIoPacketType::Ready => self.handle_ready_packet().await,
247 }
248 }
249
250 pub(crate) async fn init(self: &Arc<Self>) {
251 self.status.store(SessionStatus::AwaitingInit);
252 let session = self.clone();
253 *self.init_timeout_task.lock().await = Some(spawn(async move {
254 sleep(session.runtime.config.init_packet_timeout).await;
255 if session.status.is(SessionStatus::AwaitingInit) {
256 session.close();
257 }
258 }));
259 }
260
261 #[inline]
263 pub fn client(&self) -> WsIoClient {
264 WsIoClient(self.runtime.clone())
265 }
266
267 #[inline]
268 pub fn is_ready(&self) -> bool {
269 self.status.is(SessionStatus::Ready)
270 }
271}