1use std::hash::Hash;
2use std::{collections::HashMap, sync::Arc};
3
4use serde::{Deserialize, Serialize};
5use tokio::sync::Mutex;
6use tokio::sync::broadcast;
7use tokio::time::timeout;
8
9use crate::BoxFuture;
10use crate::{GenericMethod, Method, MethodHandler, ws::WebSocket};
11
12#[derive(Debug, Serialize, Deserialize)]
13#[serde(rename_all = "lowercase", tag = "type")]
14pub enum Message<M: Method> {
15 Request {
16 id: u32,
17 method: String,
18 data: M::Request,
19 },
20 Response {
21 id: u32,
22 result: M::Response,
23 },
24 ErrorResponse {
25 id: u32,
26 error: M::Error,
27 },
28 Notification {
29 method: String,
30 data: M::Request,
31 },
32}
33
34pub struct Session {
35 pub ws: WebSocket,
36 id: Arc<Mutex<u32>>,
37 methods: Arc<Mutex<HashMap<String, MethodHandler>>>,
38 on_close_fn:
39 Arc<Mutex<Option<Box<dyn Fn() -> BoxFuture<'static, Result<(), String>> + Send + Sync>>>>,
40 tx: broadcast::Sender<(u32, bool, serde_json::Value)>,
41 pong_tx: broadcast::Sender<()>,
42}
43
44impl Session {
45 pub fn clone(&self) -> Self {
46 Self {
47 ws: self.ws.clone(),
48 id: self.id.clone(),
49 methods: self.methods.clone(),
50 on_close_fn: self.on_close_fn.clone(),
51 tx: self.tx.clone(),
52 pong_tx: self.pong_tx.clone(),
53 }
54 }
55}
56
57impl Session {
58 pub fn from_ws(ws: WebSocket) -> Self {
59 let (tx, _) = broadcast::channel(8192);
60 let (pong_tx, _) = broadcast::channel(16);
61
62 Self {
63 ws,
64 id: Arc::new(Mutex::new(0)),
65 methods: Arc::new(Mutex::new(HashMap::new())),
66 on_close_fn: Arc::new(Mutex::new(None)),
67 tx,
68 pong_tx,
69 }
70 }
71
72 pub async fn connect(addr: &str, path: &str) -> crate::Result<Self> {
73 Ok(Self::from_ws(WebSocket::connect(addr, path).await?))
74 }
75}
76
77impl Session {
78 pub fn start_receiver(&self) {
79 let s = self.clone();
80 tokio::spawn(async move {
81 loop {
82 match s.ws.read().await {
83 Ok(crate::ws::Frame::Text(text)) => {
84 let Ok(msg) = serde_json::from_str::<Message<GenericMethod>>(&text) else {
85 continue;
86 };
87
88 match msg {
89 Message::Request { id, method, data } => {
90 if let Some(m) = s.methods.lock().await.get(&method) {
91 if let Some((err, res)) = (m)(id, data).await {
92 if err {
93 s.respond_error(id, res)
94 .await
95 .expect("Failed to respond");
96 } else {
97 s.respond(id, res).await.expect("Failed to respond");
98 }
99 }
100 }
101 }
102 Message::Response { id, result } => {
103 s.tx.send((id, false, result)).unwrap();
104 }
105 Message::ErrorResponse { id, error } => {
106 s.tx.send((id, true, error)).unwrap();
107 }
108 _ => {}
109 }
110 }
111 Ok(crate::ws::Frame::Pong) => {
112 let _ = s.pong_tx.send(());
113 }
114 Ok(_) => {}
115 Err(_) => {
116 s.trigger_close().await;
117 break;
118 }
119 }
120 }
121 });
122 }
123 pub fn start_ping(&self, interval: tokio::time::Duration, timeout_dur: tokio::time::Duration) {
124 let s = self.clone();
125
126 tokio::spawn(async move {
127 let mut pong_rx = s.pong_tx.subscribe();
128
129 loop {
130 tokio::time::sleep(interval).await;
131
132 if s.ws.send_ping().await.is_err() {
133 s.trigger_close().await;
134 break;
135 }
136
137 let result = timeout(timeout_dur, pong_rx.recv()).await;
138
139 if result.is_err() {
140 let _ = s.close().await;
142 s.trigger_close().await;
143 break;
144 }
145 }
146 });
147 }
148
149 pub async fn on_request<
150 M: Method,
151 Fut: Future<Output = Result<M::Response, M::Error>> + Send + 'static,
152 >(
153 &self,
154 handler: impl Fn(u32, M::Request) -> Fut + Send + Sync + 'static,
155 ) {
156 let handler = Arc::new(handler);
157
158 self.methods.lock().await.insert(
159 M::NAME.to_string(),
160 Box::new(move |id, value| {
161 let handler = Arc::clone(&handler);
162
163 Box::pin(async move {
164 Some(
165 match handler(id, serde_json::from_value(value).ok()?).await {
166 Ok(v) => (false, serde_json::to_value(v).ok()?),
167 Err(v) => (true, serde_json::to_value(v).ok()?),
168 },
169 )
170 })
171 }),
172 );
173 }
174
175 pub async fn on_close<Fut>(&self, handler: impl Fn() -> Fut + Send + Sync + 'static)
176 where
177 Fut: Future<Output = Result<(), String>> + Send + 'static,
178 {
179 let handler = Arc::new(handler);
180
181 *self.on_close_fn.lock().await = Some(Box::new(move || {
182 let handler = handler.clone();
183 Box::pin(async move { handler().await })
184 }));
185 }
186}
187
188impl Session {
189 pub async fn send<M: Method>(&self, data: &Message<M>) -> crate::Result<()> {
190 self.ws
191 .send_text_payload(&serde_json::to_vec(&data)?)
192 .await?;
193 Ok(())
194 }
195
196 pub async fn use_id(&self) -> u32 {
197 let mut id = self.id.lock().await;
198 *id += 1;
199 *id
200 }
201
202 pub async fn request<M: Method>(
203 &self,
204 req: M::Request,
205 ) -> crate::Result<std::result::Result<M::Response, M::Error>> {
206 let id = self.use_id().await;
207
208 self.send::<M>(&Message::Request {
209 id,
210 method: M::NAME.to_string(),
211 data: req,
212 })
213 .await?;
214
215 let mut rx = self.tx.subscribe();
216
217 loop {
218 let r = rx.recv().await?;
219
220 if r.0 == id {
221 break Ok(if r.1 {
222 Err(serde_json::from_value(r.2)?)
223 } else {
224 Ok(serde_json::from_value(r.2)?)
225 });
226 }
227 }
228 }
229
230 pub async fn respond(&self, to: u32, val: serde_json::Value) -> crate::Result<()> {
231 self.send::<GenericMethod>(&Message::Response {
232 id: to,
233 result: val,
234 })
235 .await
236 }
237
238 pub async fn respond_error(&self, to: u32, val: serde_json::Value) -> crate::Result<()> {
239 self.send::<GenericMethod>(&Message::ErrorResponse { id: to, error: val })
240 .await
241 }
242
243 pub async fn notify<M: Method>(&self, data: M::Request) -> crate::Result<()> {
244 self.send::<M>(&Message::Notification {
245 method: M::NAME.to_string(),
246 data,
247 })
248 .await
249 }
250
251 async fn trigger_close(&self) {
252 if let Some(handler) = self.on_close_fn.lock().await.as_ref() {
253 let _ = handler().await;
254 }
255 }
256
257 pub async fn close(&self) -> crate::Result<()> {
258 let res = self.ws.close().await;
259 self.trigger_close().await;
260 Ok(res?)
261 }
262}
263
264impl Hash for Session {
265 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
266 self.ws.id.hash(state);
267 }
268}
269
270impl PartialEq for Session {
271 fn eq(&self, other: &Self) -> bool {
272 self.ws.id == other.ws.id
273 }
274}
275
276impl Eq for Session {}