wsio_server/
connection.rs1use std::sync::Arc;
2
3use anyhow::Result;
4use bson::oid::ObjectId;
5use http::HeaderMap;
6use tokio::{
7 spawn,
8 sync::{
9 Mutex,
10 RwLock,
11 mpsc::{
12 UnboundedReceiver,
13 UnboundedSender,
14 unbounded_channel,
15 },
16 },
17 task::JoinHandle,
18 time::sleep,
19};
20use tokio_tungstenite::tungstenite::Message;
21
22use crate::{
23 core::packet::{
24 WsIoPacket,
25 WsIoPacketType,
26 },
27 namespace::WsIoServerNamespace,
28 types::handler::WsIoServerConnectionOnDisconnectHandler,
29};
30
/// Lifecycle state of a [`WsIoServerConnection`].
///
/// The derives make the state cheap to copy out of its lock, comparable with
/// `==`, and printable in logs and assertions.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum WsIoServerConnectionStatus {
    /// Set at the start of `activate`, before the connection is inserted into
    /// its namespace.
    Activating,
    /// Set by `init` when the namespace has an auth handler configured; the
    /// auth-timeout task closes the connection if it is still in this state
    /// once the timeout elapses.
    AwaitingAuth,
    /// Set at the very end of `cleanup`.
    Closed,
    /// Set at the very start of `cleanup`.
    Closing,
    /// Initial state assigned by `new`.
    Created,
    /// Set by `activate` once the connection has been inserted into its
    /// namespace.
    Ready,
}
39
40pub struct WsIoServerConnection {
41 headers: HeaderMap,
42 namespace: Arc<WsIoServerNamespace>,
43 on_disconnect_handler: Mutex<Option<WsIoServerConnectionOnDisconnectHandler>>,
44 sid: String,
45 status: RwLock<WsIoServerConnectionStatus>,
46 tx: UnboundedSender<Message>,
47 wait_auth_timeout_task: Mutex<Option<JoinHandle<()>>>,
48}
49
50impl WsIoServerConnection {
51 pub(crate) fn new(headers: HeaderMap, namespace: Arc<WsIoServerNamespace>) -> (Self, UnboundedReceiver<Message>) {
52 let (tx, rx) = unbounded_channel();
53 (
54 Self {
55 headers,
56 namespace,
57 on_disconnect_handler: Mutex::new(None),
58 sid: ObjectId::new().to_string(),
59 status: RwLock::new(WsIoServerConnectionStatus::Created),
60 tx,
61 wait_auth_timeout_task: Mutex::new(None),
62 },
63 rx,
64 )
65 }
66
67 pub(crate) async fn activate(self: &Arc<Self>) -> Result<()> {
69 *self.status.write().await = WsIoServerConnectionStatus::Activating;
70 self.namespace.insert_connection(self.clone());
72 *self.status.write().await = WsIoServerConnectionStatus::Ready;
73 let packet = WsIoPacket {
74 data: None,
75 key: None,
76 r#type: WsIoPacketType::Ready,
77 };
78
79 self.send_packet(&packet)?;
80 (self.namespace.config.on_connect_handler)(self.clone()).await
81 }
82
83 pub(crate) async fn cleanup(self: &Arc<Self>) {
84 *self.status.write().await = WsIoServerConnectionStatus::Closing;
85 if let Some(wait_auth_timeout_task) = self.wait_auth_timeout_task.lock().await.take() {
86 wait_auth_timeout_task.abort();
87 }
88
89 self.namespace.cleanup_connection(&self.sid);
90 if let Some(on_disconnect_handler) = self.on_disconnect_handler.lock().await.take() {
91 let _ = on_disconnect_handler(self.clone()).await;
92 }
93
94 *self.status.write().await = WsIoServerConnectionStatus::Closed;
95 }
96
97 pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
98 self.send(Message::Text(format!("c{}", self.namespace.config.packet_codec).into()))?;
99 let require_auth = self.namespace.config.auth_handler.is_some();
100 let packet = WsIoPacket {
101 data: Some(self.namespace.config.packet_codec.encode_data(&require_auth)?),
102 key: Some(self.sid.clone()),
103 r#type: WsIoPacketType::Init,
104 };
105
106 if require_auth {
107 *self.status.write().await = WsIoServerConnectionStatus::AwaitingAuth;
108 let connection = self.clone();
109 self.wait_auth_timeout_task.lock().await.replace(spawn(async move {
110 sleep(connection.namespace.config.auth_timeout).await;
111 if matches!(
112 *connection.status.read().await,
113 WsIoServerConnectionStatus::AwaitingAuth
114 ) {
115 connection.close();
116 }
117 }));
118
119 self.send_packet(&packet)?;
120 } else {
121 self.send_packet(&packet)?;
122 self.activate().await?;
123 }
124
125 Ok(())
126 }
127
128 pub(crate) async fn on_message(&self, _message: Message) {}
129
130 #[inline]
131 pub(crate) fn send(&self, message: Message) -> Result<()> {
132 Ok(self.tx.send(message)?)
133 }
134
135 #[inline]
136 pub(crate) fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
137 self.send(self.namespace.encode_packet_to_message(packet)?)
138 }
139
140 #[inline]
143 pub fn close(&self) {
144 let _ = self.send(Message::Close(None));
145 }
146
147 #[inline]
148 pub fn headers(&self) -> &HeaderMap {
149 &self.headers
150 }
151
152 #[inline]
153 pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
154 self.namespace.clone()
155 }
156
157 pub async fn on_disconnect<H, Fut>(&self, handler: H)
158 where
159 H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
160 Fut: Future<Output = Result<()>> + Send + 'static,
161 {
162 self.on_disconnect_handler
163 .lock()
164 .await
165 .replace(Box::new(move |connection| Box::pin(handler(connection))));
166 }
167
168 #[inline]
169 pub fn sid(&self) -> &str {
170 &self.sid
171 }
172}