Skip to main content

session_rs/
session.rs

1use std::{collections::HashMap, sync::Arc};
2
3use serde::{Deserialize, Serialize};
4use tokio::sync::Mutex;
5use tokio::sync::broadcast;
6
7use crate::{GenericMethod, Method, MethodHandler, ws::WebSocket};
8
9#[derive(Debug, Serialize, Deserialize)]
10#[serde(rename_all = "lowercase", tag = "type")]
11pub enum Message<M: Method> {
12    Request {
13        id: u32,
14        method: String,
15        data: M::Request,
16    },
17    Response {
18        id: u32,
19        result: M::Response,
20    },
21    ErrorResponse {
22        id: u32,
23        error: M::Error,
24    },
25    Notification {
26        method: String,
27        data: M::Request,
28    },
29}
30
/// A request/response session layered over a [`WebSocket`].
///
/// `id`, `methods` and `tx` are `Arc`-wrapped or channel handles, so clones of a
/// `Session` share the same id counter, handler table and reply channel.
pub struct Session {
    /// Socket used for all reads and writes of this session.
    pub ws: WebSocket,
    /// Counter behind `use_id`: the last request id handed out (0 before the first request).
    id: Arc<Mutex<u32>>,
    /// Request handlers registered via `on`, keyed by method name.
    methods: Arc<Mutex<HashMap<String, MethodHandler>>>,
    /// Fans incoming replies out to pending `request` calls as `(id, is_error, payload)`.
    tx: broadcast::Sender<(u32, bool, serde_json::Value)>,
}
37
38impl Session {
39    pub fn clone(&self) -> Self {
40        Self {
41            ws: self.ws.clone(),
42            id: self.id.clone(),
43            methods: self.methods.clone(),
44            tx: self.tx.clone(),
45        }
46    }
47}
48
49impl Session {
50    pub fn from_ws(ws: WebSocket) -> Self {
51        Self {
52            ws,
53            id: Arc::new(Mutex::new(0)),
54            methods: Arc::new(Mutex::new(HashMap::new())),
55            tx: broadcast::channel(8192).0,
56        }
57    }
58
59    pub async fn connect(addr: &str, path: &str) -> crate::Result<Self> {
60        Ok(Self::from_ws(WebSocket::connect(addr, path).await?))
61    }
62}
63
64impl Session {
65    pub fn start_receiver(&self) {
66        let s = self.clone();
67        tokio::spawn(async move {
68            loop {
69                match s.ws.read().await {
70                    Ok(crate::ws::Frame::Text(text)) => {
71                        let Ok(msg) = serde_json::from_str::<Message<GenericMethod>>(&text) else {
72                            continue;
73                        };
74
75                        match msg {
76                            Message::Request { id, method, data } => {
77                                if let Some(m) = s.methods.lock().await.get(&method) {
78                                    if let Some((err, res)) = (m)(id, data).await {
79                                        if err {
80                                            s.respond_error(id, res)
81                                                .await
82                                                .expect("Failed to respond");
83                                        } else {
84                                            s.respond(id, res).await.expect("Failed to respond");
85                                        }
86                                    }
87                                }
88                            }
89                            Message::Response { id, result } => {
90                                s.tx.send((id, false, result)).unwrap();
91                            }
92                            Message::ErrorResponse { id, error } => {
93                                s.tx.send((id, true, error)).unwrap();
94                            }
95                            _ => {}
96                        }
97                    }
98                    Ok(_) => {}
99                    Err(_) => break,
100                }
101            }
102        });
103    }
104
105    pub async fn on<
106        M: Method,
107        Fut: Future<Output = Result<M::Response, M::Error>> + Send + 'static,
108    >(
109        &self,
110        handler: impl Fn(u32, M::Request) -> Fut + Send + Sync + 'static,
111    ) {
112        let handler = Arc::new(handler);
113
114        self.methods.lock().await.insert(
115            M::NAME.to_string(),
116            Box::new(move |id, value| {
117                let handler = Arc::clone(&handler);
118
119                Box::pin(async move {
120                    Some(
121                        match handler(id, serde_json::from_value(value).ok()?).await {
122                            Ok(v) => (false, serde_json::to_value(v).ok()?),
123                            Err(v) => (true, serde_json::to_value(v).ok()?),
124                        },
125                    )
126                })
127            }),
128        );
129    }
130}
131
132impl Session {
133    pub async fn send<M: Method>(&self, data: &Message<M>) -> crate::Result<()> {
134        self.ws
135            .send_text_payload(&serde_json::to_vec(&data)?)
136            .await?;
137        Ok(())
138    }
139
140    pub async fn use_id(&self) -> u32 {
141        let mut id = self.id.lock().await;
142        *id += 1;
143        *id
144    }
145
146    pub async fn request<M: Method>(
147        &self,
148        req: M::Request,
149    ) -> crate::Result<std::result::Result<M::Response, M::Error>> {
150        let id = self.use_id().await;
151
152        self.send::<M>(&Message::Request {
153            id,
154            method: M::NAME.to_string(),
155            data: req,
156        })
157        .await?;
158
159        let mut rx = self.tx.subscribe();
160
161        loop {
162            let r = rx.recv().await?;
163
164            if r.0 == id {
165                break Ok(if r.1 {
166                    Err(serde_json::from_value(r.2)?)
167                } else {
168                    Ok(serde_json::from_value(r.2)?)
169                });
170            }
171        }
172    }
173
174    pub async fn respond(&self, to: u32, val: serde_json::Value) -> crate::Result<()> {
175        self.send::<GenericMethod>(&Message::Response {
176            id: to,
177            result: val,
178        })
179        .await
180    }
181
182    pub async fn respond_error(&self, to: u32, val: serde_json::Value) -> crate::Result<()> {
183        self.send::<GenericMethod>(&Message::ErrorResponse { id: to, error: val })
184            .await
185    }
186
187    pub async fn notify<M: Method>(&self, data: M::Request) -> crate::Result<()> {
188        self.send::<M>(&Message::Notification {
189            method: M::NAME.to_string(),
190            data,
191        })
192        .await
193    }
194
195    pub async fn close(&self) -> crate::Result<()> {
196        Ok(self.ws.close().await?)
197    }
198}