1use std::sync::Arc;
2
3use anyhow::{
4 Result,
5 bail,
6};
7use arc_swap::ArcSwap;
8use num_enum::{
9 IntoPrimitive,
10 TryFromPrimitive,
11};
12use tokio::{
13 spawn,
14 sync::{
15 Mutex,
16 mpsc::{
17 Receiver,
18 Sender,
19 channel,
20 },
21 },
22 task::JoinHandle,
23 time::{
24 sleep,
25 timeout,
26 },
27};
28use tokio_tungstenite::tungstenite::Message;
29use tokio_util::sync::CancellationToken;
30
31use crate::{
32 WsIoClient,
33 core::{
34 atomic::status::AtomicStatus,
35 channel_capacity_from_websocket_config,
36 packet::{
37 WsIoPacket,
38 WsIoPacketType,
39 },
40 traits::task::spawner::TaskSpawner,
41 utils::task::abort_locked_task,
42 },
43 runtime::WsIoClientRuntime,
44};
45
46#[repr(u8)]
48#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
49enum SessionStatus {
50 AwaitingInit,
51 AwaitingReady,
52 Closed,
53 Closing,
54 Created,
55 Initiating,
56 Ready,
57 Readying,
58}
59
/// Client-side state for a single WebSocket session.
///
/// Tracks the session's lifecycle status, owns the sending half of its outgoing-message
/// channel, and keeps the watchdog tasks that close the session when the init or ready
/// packet does not arrive in time.
pub struct WsIoClientSession {
    /// Token returned by `TaskSpawner::cancel_token` and cancelled by `cleanup`.
    cancel_token: ArcSwap<CancellationToken>,
    /// Watchdog armed by `init`: closes the session if it is still awaiting the init packet
    /// after `init_packet_timeout`. Aborted once the init packet arrives or on cleanup.
    init_timeout_task: Mutex<Option<JoinHandle<()>>>,
    /// Sending half of the outgoing-message channel created in `new`; the receiving half is
    /// returned to the caller of `new`.
    message_tx: Sender<Arc<Message>>,
    /// Watchdog armed after the init handler completes: closes the session if it is still
    /// awaiting the ready packet after `ready_packet_timeout`. Aborted on ready or cleanup.
    ready_timeout_task: Mutex<Option<JoinHandle<()>>>,
    /// Shared client runtime providing configuration, codec, and event registry.
    runtime: Arc<WsIoClientRuntime>,
    /// Current lifecycle state of the session.
    status: AtomicStatus<SessionStatus>,
}
69
70impl TaskSpawner for WsIoClientSession {
71 #[inline]
72 fn cancel_token(&self) -> Arc<CancellationToken> {
73 self.cancel_token.load_full()
74 }
75}
76
77impl WsIoClientSession {
78 #[inline]
79 pub(crate) fn new(runtime: Arc<WsIoClientRuntime>) -> (Arc<Self>, Receiver<Arc<Message>>) {
80 let channel_capacity = channel_capacity_from_websocket_config(&runtime.config.websocket_config);
81 let (message_tx, message_rx) = channel(channel_capacity);
82 (
83 Arc::new(Self {
84 cancel_token: ArcSwap::new(Arc::new(CancellationToken::new())),
85 init_timeout_task: Mutex::new(None),
86 message_tx,
87 ready_timeout_task: Mutex::new(None),
88 runtime,
89 status: AtomicStatus::new(SessionStatus::Created),
90 }),
91 message_rx,
92 )
93 }
94
95 #[inline]
97 fn handle_disconnect_packet(&self) -> Result<()> {
98 let runtime = self.runtime.clone();
99 spawn(async move { runtime.disconnect().await });
100 Ok(())
101 }
102
103 #[inline]
104 fn handle_event_packet(self: &Arc<Self>, event: &str, packet_data: Option<Vec<u8>>) -> Result<()> {
105 self.runtime.event_registry.dispatch_event_packet(
106 self.clone(),
107 event,
108 &self.runtime.config.packet_codec,
109 packet_data,
110 &self.runtime,
111 );
112
113 Ok(())
114 }
115
116 async fn handle_init_packet(self: &Arc<Self>, packet_data: Option<&[u8]>) -> Result<()> {
117 let status = self.status.get();
119 match status {
120 SessionStatus::AwaitingInit => self.status.try_transition(status, SessionStatus::Initiating)?,
121 _ => bail!("Received init packet in invalid status: {status:?}"),
122 }
123
124 abort_locked_task(&self.init_timeout_task).await;
126
127 let response_data = if let Some(init_handler) = &self.runtime.config.init_handler {
129 timeout(
130 self.runtime.config.init_handler_timeout,
131 init_handler(self.clone(), packet_data, &self.runtime.config.packet_codec),
132 )
133 .await??
134 } else {
135 None
136 };
137
138 self.status
140 .try_transition(SessionStatus::Initiating, SessionStatus::AwaitingReady)?;
141
142 let session = self.clone();
144 *self.ready_timeout_task.lock().await = Some(spawn(async move {
145 sleep(session.runtime.config.ready_packet_timeout).await;
146 if session.status.is(SessionStatus::AwaitingReady) {
147 session.close();
148 }
149 }));
150
151 self.send_packet(&WsIoPacket::new_init(response_data)).await
153 }
154
155 async fn handle_ready_packet(self: &Arc<Self>) -> Result<()> {
156 let status = self.status.get();
158 match status {
159 SessionStatus::AwaitingReady => self.status.try_transition(status, SessionStatus::Ready)?,
160 _ => bail!("Received ready packet in invalid status: {status:?}"),
161 }
162
163 abort_locked_task(&self.ready_timeout_task).await;
165
166 self.runtime.event_message_flush_notify.notify_waiters();
168
169 if let Some(on_session_ready_handler) = self.runtime.config.on_session_ready_handler.clone() {
171 self.spawn_task(on_session_ready_handler(self.clone()));
173 }
174
175 Ok(())
176 }
177
178 async fn send_message(&self, message: Arc<Message>) -> Result<()> {
179 Ok(self.message_tx.send(message).await?)
180 }
181
182 async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
183 self.send_message(self.runtime.encode_packet_to_message(packet)?).await
184 }
185
186 pub(crate) async fn cleanup(self: &Arc<Self>) {
188 self.status.store(SessionStatus::Closing);
190
191 abort_locked_task(&self.init_timeout_task).await;
193 abort_locked_task(&self.ready_timeout_task).await;
194
195 self.cancel_token.load().cancel();
197
198 if let Some(on_session_close_handler) = &self.runtime.config.on_session_close_handler {
200 let _ = timeout(
201 self.runtime.config.on_session_close_handler_timeout,
202 on_session_close_handler(self.clone()),
203 )
204 .await;
205 }
206
207 self.status.store(SessionStatus::Closed);
209 }
210
211 #[inline]
212 pub(crate) fn close(&self) {
213 match self.status.get() {
215 SessionStatus::Closed | SessionStatus::Closing => return,
216 _ => self.status.store(SessionStatus::Closing),
217 }
218
219 let _ = self.message_tx.try_send(Arc::new(Message::Close(None)));
221 }
222
223 pub(crate) async fn emit_event_message(&self, message: Arc<Message>) -> Result<()> {
224 self.status.ensure(SessionStatus::Ready, |status| {
225 format!("Cannot emit event message in invalid status: {status:?}")
226 })?;
227
228 self.send_message(message).await
229 }
230
231 pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, encoded_packet: &[u8]) -> Result<()> {
232 let packet = self.runtime.config.packet_codec.decode(encoded_packet)?;
234 match packet.r#type {
235 WsIoPacketType::Disconnect => self.handle_disconnect_packet(),
236 WsIoPacketType::Event => {
237 if let Some(event) = packet.key.as_deref() {
238 self.handle_event_packet(event, packet.data)
239 } else {
240 bail!("Event packet missing key");
241 }
242 }
243 WsIoPacketType::Init => self.handle_init_packet(packet.data.as_deref()).await,
244 WsIoPacketType::Ready => self.handle_ready_packet().await,
245 }
246 }
247
248 pub(crate) async fn init(self: &Arc<Self>) {
249 self.status.store(SessionStatus::AwaitingInit);
250 let session = self.clone();
251 *self.init_timeout_task.lock().await = Some(spawn(async move {
252 sleep(session.runtime.config.init_packet_timeout).await;
253 if session.status.is(SessionStatus::AwaitingInit) {
254 session.close();
255 }
256 }));
257 }
258
259 #[inline]
261 pub fn client(&self) -> WsIoClient {
262 WsIoClient(self.runtime.clone())
263 }
264
265 #[inline]
266 pub fn is_ready(&self) -> bool {
267 self.status.is(SessionStatus::Ready)
268 }
269}