wsio_client/
connection.rs1use std::sync::Arc;
2
3use anyhow::{
4 Result,
5 bail,
6};
7use arc_swap::ArcSwap;
8use num_enum::{
9 IntoPrimitive,
10 TryFromPrimitive,
11};
12use tokio::{
13 spawn,
14 sync::{
15 Mutex,
16 mpsc::{
17 Receiver,
18 Sender,
19 channel,
20 },
21 },
22 task::JoinHandle,
23 time::{
24 sleep,
25 timeout,
26 },
27};
28use tokio_tungstenite::tungstenite::Message;
29use tokio_util::sync::CancellationToken;
30
31use crate::{
32 WsIoClient,
33 core::{
34 atomic::status::AtomicStatus,
35 channel_capacity_from_websocket_config,
36 packet::{
37 WsIoPacket,
38 WsIoPacketType,
39 },
40 traits::task::spawner::TaskSpawner,
41 utils::task::abort_locked_task,
42 },
43 runtime::WsIoClientRuntime,
44};
45
46#[repr(u8)]
48#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
49enum ConnectionStatus {
50 AwaitingInit,
51 AwaitingReady,
52 Closed,
53 Closing,
54 Created,
55 Initiating,
56 Ready,
57 Readying,
58}
59
60pub struct WsIoClientConnection {
62 cancel_token: ArcSwap<CancellationToken>,
63 init_timeout_task: Mutex<Option<JoinHandle<()>>>,
64 message_tx: Sender<Arc<Message>>,
65 ready_timeout_task: Mutex<Option<JoinHandle<()>>>,
66 runtime: Arc<WsIoClientRuntime>,
67 status: AtomicStatus<ConnectionStatus>,
68}
69
70impl TaskSpawner for WsIoClientConnection {
71 #[inline]
72 fn cancel_token(&self) -> Arc<CancellationToken> {
73 self.cancel_token.load_full()
74 }
75}
76
77impl WsIoClientConnection {
78 #[inline]
79 pub(crate) fn new(runtime: Arc<WsIoClientRuntime>) -> (Arc<Self>, Receiver<Arc<Message>>) {
80 let channel_capacity = channel_capacity_from_websocket_config(&runtime.config.websocket_config);
81 let (message_tx, message_rx) = channel(channel_capacity);
82 (
83 Arc::new(Self {
84 cancel_token: ArcSwap::new(Arc::new(CancellationToken::new())),
85 init_timeout_task: Mutex::new(None),
86 message_tx,
87 ready_timeout_task: Mutex::new(None),
88 runtime,
89 status: AtomicStatus::new(ConnectionStatus::Created),
90 }),
91 message_rx,
92 )
93 }
94
95 #[inline]
97 fn handle_disconnect_packet(&self) -> Result<()> {
98 let runtime = self.runtime.clone();
99 spawn(async move { runtime.disconnect().await });
100 Ok(())
101 }
102
103 #[inline]
104 fn handle_event_packet(self: &Arc<Self>, event: &str, packet_data: Option<Vec<u8>>) -> Result<()> {
105 self.runtime.event_registry.dispatch_event_packet(
106 self.clone(),
107 event,
108 &self.runtime.config.packet_codec,
109 packet_data,
110 &self.runtime,
111 );
112
113 Ok(())
114 }
115
116 async fn handle_init_packet(self: &Arc<Self>, packet_data: Option<&[u8]>) -> Result<()> {
117 let status = self.status.get();
119 match status {
120 ConnectionStatus::AwaitingInit => self.status.try_transition(status, ConnectionStatus::Initiating)?,
121 _ => bail!("Received init packet in invalid status: {status:?}"),
122 }
123
124 abort_locked_task(&self.init_timeout_task).await;
126
127 let response_data = if let Some(init_handler) = &self.runtime.config.init_handler {
129 timeout(
130 self.runtime.config.init_handler_timeout,
131 init_handler(self.clone(), packet_data, &self.runtime.config.packet_codec),
132 )
133 .await??
134 } else {
135 None
136 };
137
138 self.status
140 .try_transition(ConnectionStatus::Initiating, ConnectionStatus::AwaitingReady)?;
141
142 let connection = self.clone();
144 *self.ready_timeout_task.lock().await = Some(spawn(async move {
145 sleep(connection.runtime.config.ready_packet_timeout).await;
146 if connection.status.is(ConnectionStatus::AwaitingReady) {
147 connection.close();
148 }
149 }));
150
151 self.send_packet(&WsIoPacket::new_init(response_data)).await
153 }
154
155 async fn handle_ready_packet(self: &Arc<Self>) -> Result<()> {
156 let status = self.status.get();
158 match status {
159 ConnectionStatus::AwaitingReady => self.status.try_transition(status, ConnectionStatus::Ready)?,
160 _ => bail!("Received ready packet in invalid status: {status:?}"),
161 }
162
163 abort_locked_task(&self.ready_timeout_task).await;
165
166 self.runtime.event_message_flush_notify.notify_waiters();
168
169 if let Some(on_connection_ready_handler) = self.runtime.config.on_connection_ready_handler.clone() {
171 self.spawn_task(on_connection_ready_handler(self.clone()));
173 }
174
175 Ok(())
176 }
177
178 async fn send_message(&self, message: Arc<Message>) -> Result<()> {
179 Ok(self.message_tx.send(message).await?)
180 }
181
182 async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
183 self.send_message(self.runtime.encode_packet_to_message(packet)?).await
184 }
185
186 pub(crate) async fn cleanup(self: &Arc<Self>) {
188 self.status.store(ConnectionStatus::Closing);
190
191 abort_locked_task(&self.init_timeout_task).await;
193 abort_locked_task(&self.ready_timeout_task).await;
194
195 self.cancel_token.load().cancel();
197
198 if let Some(on_connection_close_handler) = &self.runtime.config.on_connection_close_handler {
200 let _ = timeout(
201 self.runtime.config.on_connection_close_handler_timeout,
202 on_connection_close_handler(self.clone()),
203 )
204 .await;
205 }
206
207 self.status.store(ConnectionStatus::Closed);
209 }
210
211 #[inline]
212 pub(crate) fn close(&self) {
213 match self.status.get() {
215 ConnectionStatus::Closed | ConnectionStatus::Closing => return,
216 _ => self.status.store(ConnectionStatus::Closing),
217 }
218
219 let _ = self.message_tx.try_send(Arc::new(Message::Close(None)));
221 }
222
223 pub(crate) async fn emit_event_message(&self, message: Arc<Message>) -> Result<()> {
224 self.status.ensure(ConnectionStatus::Ready, |status| {
225 format!("Cannot emit event message in invalid status: {status:?}")
226 })?;
227
228 self.send_message(message).await
229 }
230
231 pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, bytes: &[u8]) -> Result<()> {
232 let packet = self.runtime.config.packet_codec.decode(bytes)?;
234 match packet.r#type {
235 WsIoPacketType::Disconnect => self.handle_disconnect_packet(),
236 WsIoPacketType::Event => {
237 if let Some(event) = packet.key.as_deref() {
238 self.handle_event_packet(event, packet.data)
239 } else {
240 bail!("Event packet missing key");
241 }
242 }
243 WsIoPacketType::Init => self.handle_init_packet(packet.data.as_deref()).await,
244 WsIoPacketType::Ready => self.handle_ready_packet().await,
245 }
246 }
247
248 pub(crate) async fn init(self: &Arc<Self>) {
249 self.status.store(ConnectionStatus::AwaitingInit);
250 let connection = self.clone();
251 *self.init_timeout_task.lock().await = Some(spawn(async move {
252 sleep(connection.runtime.config.init_packet_timeout).await;
253 if connection.status.is(ConnectionStatus::AwaitingInit) {
254 connection.close();
255 }
256 }));
257 }
258
259 pub fn client(&self) -> WsIoClient {
261 WsIoClient(self.runtime.clone())
262 }
263}