1use std::{
2 collections::HashMap,
3 sync::{
4 Arc,
5 atomic::{AtomicU64, Ordering},
6 },
7 time::Duration,
8};
9
10use futures_util::{
11 SinkExt, StreamExt,
12 stream::{SplitSink, SplitStream},
13};
14use serde_json::Value;
15use tokio::net::TcpStream;
16use tokio::{
17 sync::{Mutex, mpsc, oneshot},
18 time::{interval, timeout},
19};
20use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async, tungstenite::Message};
21use uuid::Uuid;
22
23use crate::protocol::{ClientFrame, DEFAULT_EVENT, ErrorPayload, ServerFrame};
24
25use super::ClientConfig;
26
/// Result type used by the client API; errors are human-readable strings.
pub type ClientResult<T> = std::result::Result<T, String>;
/// Identifier handed out when a handler is registered; `RealtimeClient::off` takes it back.
pub type SubscriptionId = u64;
/// Handler for one channel, called with the event payload.
type ChannelHandler = Arc<dyn Fn(Value) + Send + Sync>;
/// Handler for every channel, called with `(channel, payload)`.
type GlobalHandler = Arc<dyn Fn(String, Value) + Send + Sync>;
/// Handler for one channel, called with `(event, payload)`.
type ChannelEventHandler = Arc<dyn Fn(String, Value) + Send + Sync>;
/// Handler for every channel, called with `(channel, event, payload)`.
type GlobalEventHandler = Arc<dyn Fn(String, String, Value) + Send + Sync>;
/// Requests still waiting for an ack: request id -> sender that resolves the waiting caller.
type PendingAcks = Arc<Mutex<HashMap<String, oneshot::Sender<ClientResult<()>>>>>;
/// Payload handlers grouped as channel name -> subscription id -> handler.
type ChannelHandlers =
    Arc<std::sync::Mutex<HashMap<String, HashMap<SubscriptionId, ChannelHandler>>>>;
/// Payload handlers registered for all channels, keyed by subscription id.
type GlobalHandlers = Arc<std::sync::Mutex<HashMap<SubscriptionId, GlobalHandler>>>;
/// Event-aware handlers grouped as channel name -> subscription id -> handler.
type ChannelEventHandlers =
    Arc<std::sync::Mutex<HashMap<String, HashMap<SubscriptionId, ChannelEventHandler>>>>;
/// Event-aware handlers registered for all channels, keyed by subscription id.
type GlobalEventHandlers = Arc<std::sync::Mutex<HashMap<SubscriptionId, GlobalEventHandler>>>;
/// Client-side websocket stream type produced by `connect_async`.
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Write half of the split websocket stream.
type WsWriter = SplitSink<WsStream, Message>;
/// Read half of the split websocket stream.
type WsReader = SplitStream<WsStream>;
43
/// Handle to a realtime websocket connection.
///
/// Clones share the same outbound queue, pending-ack table, handler registries
/// and subscription-id counter, because those fields are `Arc`s or channel senders.
#[derive(Clone)]
pub struct RealtimeClient {
    /// Queue of frames consumed by the background writer task.
    outbound_tx: mpsc::Sender<ClientFrame>,
    /// Requests waiting for a server ack, keyed by request id.
    pending_acks: PendingAcks,
    /// Per-channel payload handlers.
    channel_handlers: ChannelHandlers,
    /// Payload handlers invoked for every channel.
    global_handlers: GlobalHandlers,
    /// Per-channel handlers that also receive the event name.
    channel_event_handlers: ChannelEventHandlers,
    /// Handlers invoked for every channel with channel and event name.
    global_event_handlers: GlobalEventHandlers,
    /// Source of subscription ids; a single counter is shared by all registries.
    next_subscription_id: Arc<AtomicU64>,
    /// Configuration supplied at connect time (buffer size, timeouts, ping interval).
    cfg: ClientConfig,
}
55
56impl RealtimeClient {
57 pub async fn connect(base_url: &str, token: &str) -> ClientResult<Self> {
58 Self::connect_with_config(base_url, token, ClientConfig::default()).await
59 }
60
61 pub async fn connect_with_config(
62 base_url: &str,
63 token: &str,
64 cfg: ClientConfig,
65 ) -> ClientResult<Self> {
66 let ws = Self::open_socket(base_url, token).await?;
67 let (write, read) = ws.split();
68 let (outbound_tx, outbound_rx) = mpsc::channel::<ClientFrame>(cfg.outbound_buffer);
69
70 let pending_acks: PendingAcks = Arc::new(Mutex::new(HashMap::new()));
71 let channel_handlers: ChannelHandlers = Arc::new(std::sync::Mutex::new(HashMap::new()));
72 let global_handlers: GlobalHandlers = Arc::new(std::sync::Mutex::new(HashMap::new()));
73 let channel_event_handlers: ChannelEventHandlers =
74 Arc::new(std::sync::Mutex::new(HashMap::new()));
75 let global_event_handlers: GlobalEventHandlers =
76 Arc::new(std::sync::Mutex::new(HashMap::new()));
77
78 Self::spawn_writer_task(write, outbound_rx);
79 Self::spawn_reader_task(
80 read,
81 Arc::clone(&pending_acks),
82 Arc::clone(&channel_handlers),
83 Arc::clone(&global_handlers),
84 Arc::clone(&channel_event_handlers),
85 Arc::clone(&global_event_handlers),
86 );
87 Self::spawn_ping_task(outbound_tx.clone(), cfg.ping_interval);
88
89 Ok(Self {
90 outbound_tx,
91 pending_acks,
92 channel_handlers,
93 global_handlers,
94 channel_event_handlers,
95 global_event_handlers,
96 next_subscription_id: Arc::new(AtomicU64::new(1)),
97 cfg,
98 })
99 }
100
101 pub async fn join(&self, channel: &str) -> ClientResult<()> {
102 self.request_ack(
103 ClientFrame::ChannelJoin {
104 id: Uuid::new_v4().to_string(),
105 channel: channel.to_string(),
106 ts: None,
107 },
108 self.cfg.request_timeout,
109 )
110 .await
111 }
112
113 pub async fn leave(&self, channel: &str) -> ClientResult<()> {
114 self.request_ack(
115 ClientFrame::ChannelLeave {
116 id: Uuid::new_v4().to_string(),
117 channel: channel.to_string(),
118 ts: None,
119 },
120 self.cfg.request_timeout,
121 )
122 .await
123 }
124
125 pub async fn send(&self, channel: &str, message: Value) -> ClientResult<()> {
126 self.send_event(channel, DEFAULT_EVENT, message).await
127 }
128
129 pub async fn send_event(&self, channel: &str, event: &str, message: Value) -> ClientResult<()> {
130 self.request_ack(
131 ClientFrame::ChannelEmit {
132 id: Uuid::new_v4().to_string(),
133 channel: channel.to_string(),
134 event: event.to_string(),
135 data: message,
136 ts: None,
137 },
138 self.cfg.request_timeout,
139 )
140 .await
141 }
142
143 pub fn on_message<F>(&self, channel: &str, handler: F) -> SubscriptionId
144 where
145 F: Fn(Value) + Send + Sync + 'static,
146 {
147 let id = self.next_subscription_id.fetch_add(1, Ordering::Relaxed);
148 let mut guard = self
149 .channel_handlers
150 .lock()
151 .expect("channel handler mutex poisoned");
152 guard
153 .entry(channel.to_string())
154 .or_default()
155 .insert(id, Arc::new(handler));
156 id
157 }
158
159 pub fn on_messages<F>(&self, handler: F) -> SubscriptionId
160 where
161 F: Fn(String, Value) + Send + Sync + 'static,
162 {
163 let id = self.next_subscription_id.fetch_add(1, Ordering::Relaxed);
164 self.global_handlers
165 .lock()
166 .expect("global handler mutex poisoned")
167 .insert(id, Arc::new(handler));
168 id
169 }
170
171 pub fn on_channel_event<F>(&self, channel: &str, handler: F) -> SubscriptionId
172 where
173 F: Fn(String, Value) + Send + Sync + 'static,
174 {
175 let id = self.next_subscription_id.fetch_add(1, Ordering::Relaxed);
176 let mut guard = self
177 .channel_event_handlers
178 .lock()
179 .expect("channel event handler mutex poisoned");
180 guard
181 .entry(channel.to_string())
182 .or_default()
183 .insert(id, Arc::new(handler));
184 id
185 }
186
187 pub fn on_events<F>(&self, handler: F) -> SubscriptionId
188 where
189 F: Fn(String, String, Value) + Send + Sync + 'static,
190 {
191 let id = self.next_subscription_id.fetch_add(1, Ordering::Relaxed);
192 self.global_event_handlers
193 .lock()
194 .expect("global event handler mutex poisoned")
195 .insert(id, Arc::new(handler));
196 id
197 }
198
199 pub fn off(&self, id: SubscriptionId) -> bool {
200 let mut removed = false;
201
202 let mut global = self
203 .global_handlers
204 .lock()
205 .expect("global handler mutex poisoned");
206 if global.remove(&id).is_some() {
207 removed = true;
208 }
209 drop(global);
210
211 let mut channels = self
212 .channel_handlers
213 .lock()
214 .expect("channel handler mutex poisoned");
215 for handlers in channels.values_mut() {
216 if handlers.remove(&id).is_some() {
217 removed = true;
218 }
219 }
220
221 let mut global_events = self
222 .global_event_handlers
223 .lock()
224 .expect("global event handler mutex poisoned");
225 if global_events.remove(&id).is_some() {
226 removed = true;
227 }
228 drop(global_events);
229
230 let mut channel_events = self
231 .channel_event_handlers
232 .lock()
233 .expect("channel event handler mutex poisoned");
234 for handlers in channel_events.values_mut() {
235 if handlers.remove(&id).is_some() {
236 removed = true;
237 }
238 }
239
240 removed
241 }
242
243 async fn open_socket(base_url: &str, token: &str) -> ClientResult<WsStream> {
244 let url = with_query_token(base_url, token);
245 let (ws, _) = connect_async(&url)
246 .await
247 .map_err(|err| format!("failed to connect to {url}: {err}"))?;
248 Ok(ws)
249 }
250
251 fn spawn_writer_task(mut write: WsWriter, mut outbound_rx: mpsc::Receiver<ClientFrame>) {
252 tokio::spawn(async move {
253 while let Some(frame) = outbound_rx.recv().await {
254 let text = match serde_json::to_string(&frame) {
255 Ok(text) => text,
256 Err(err) => {
257 eprintln!("failed to serialize outbound frame: {err}");
258 continue;
259 }
260 };
261
262 if write.send(Message::Text(text.into())).await.is_err() {
263 break;
264 }
265 }
266 });
267 }
268
269 fn spawn_reader_task(
270 mut read: WsReader,
271 pending_acks: PendingAcks,
272 channel_handlers: ChannelHandlers,
273 global_handlers: GlobalHandlers,
274 channel_event_handlers: ChannelEventHandlers,
275 global_event_handlers: GlobalEventHandlers,
276 ) {
277 tokio::spawn(async move {
278 while let Some(next) = read.next().await {
279 let msg = match next {
280 Ok(msg) => msg,
281 Err(err) => {
282 eprintln!("websocket read error: {err}");
283 break;
284 }
285 };
286
287 let keep_reading = Self::handle_incoming_message(
288 msg,
289 &pending_acks,
290 &channel_handlers,
291 &global_handlers,
292 &channel_event_handlers,
293 &global_event_handlers,
294 )
295 .await;
296 if !keep_reading {
297 break;
298 }
299 }
300
301 Self::fail_pending_acks(&pending_acks).await;
302 });
303 }
304
305 async fn handle_incoming_message(
306 msg: Message,
307 pending_acks: &PendingAcks,
308 channel_handlers: &ChannelHandlers,
309 global_handlers: &GlobalHandlers,
310 channel_event_handlers: &ChannelEventHandlers,
311 global_event_handlers: &GlobalEventHandlers,
312 ) -> bool {
313 let text = match msg {
314 Message::Text(text) => text,
315 Message::Close(_) => return false,
316 _ => return true,
317 };
318
319 let frame = match serde_json::from_str::<ServerFrame>(&text) {
320 Ok(frame) => frame,
321 Err(err) => {
322 eprintln!("invalid server frame: {err}");
323 return true;
324 }
325 };
326
327 Self::handle_server_frame(
328 frame,
329 pending_acks,
330 channel_handlers,
331 global_handlers,
332 channel_event_handlers,
333 global_event_handlers,
334 )
335 .await;
336 true
337 }
338
339 async fn handle_server_frame(
340 frame: ServerFrame,
341 pending_acks: &PendingAcks,
342 channel_handlers: &ChannelHandlers,
343 global_handlers: &GlobalHandlers,
344 channel_event_handlers: &ChannelEventHandlers,
345 global_event_handlers: &GlobalEventHandlers,
346 ) {
347 match frame {
348 ServerFrame::Connected {
349 conn_id, user_id, ..
350 } => {
351 println!("connected: conn_id={conn_id} user_id={user_id}");
352 }
353 ServerFrame::Joined { channel, .. } => {
354 println!("joined channel={channel}");
355 }
356 ServerFrame::Left { channel, .. } => {
357 println!("left channel={channel}");
358 }
359 ServerFrame::Event {
360 channel,
361 event,
362 data,
363 ..
364 } => {
365 dispatch_channel_handlers(channel_handlers, &channel, &data);
366 dispatch_global_handlers(global_handlers, &channel, &data);
367 dispatch_channel_event_handlers(channel_event_handlers, &channel, &event, &data);
368 dispatch_global_event_handlers(global_event_handlers, &channel, &event, &data);
369 }
370 ServerFrame::Ack {
371 for_id, ok, error, ..
372 } => {
373 Self::resolve_ack(pending_acks, for_id, ok, error).await;
374 }
375 ServerFrame::Pong { id, .. } => {
376 println!("pong id={id}");
377 }
378 ServerFrame::Error { error, .. } => {
379 eprintln!("server error {}: {}", error.code, error.message);
380 }
381 }
382 }
383
384 async fn resolve_ack(
385 pending_acks: &PendingAcks,
386 for_id: String,
387 ok: bool,
388 error: Option<ErrorPayload>,
389 ) {
390 let Some(tx) = pending_acks.lock().await.remove(&for_id) else {
391 return;
392 };
393
394 let result = if ok {
395 Ok(())
396 } else {
397 let message = error
398 .map(|e| format!("{}: {}", e.code, e.message))
399 .unwrap_or_else(|| "request rejected".to_string());
400 Err(message)
401 };
402 let _ = tx.send(result);
403 }
404
405 async fn fail_pending_acks(pending_acks: &PendingAcks) {
406 let mut pending = pending_acks.lock().await;
407 for (_, tx) in pending.drain() {
408 let _ = tx.send(Err("websocket connection closed".to_string()));
409 }
410 }
411
412 fn spawn_ping_task(outbound_tx: mpsc::Sender<ClientFrame>, ping_interval: Duration) {
413 tokio::spawn(async move {
414 let mut ticker = interval(ping_interval);
415 loop {
416 ticker.tick().await;
417 if outbound_tx
418 .send(ClientFrame::Ping {
419 id: Uuid::new_v4().to_string(),
420 ts: None,
421 })
422 .await
423 .is_err()
424 {
425 break;
426 }
427 }
428 });
429 }
430
431 async fn request_ack(&self, frame: ClientFrame, timeout_dur: Duration) -> ClientResult<()> {
432 let req_id = frame_id(&frame).to_string();
433 let (tx, rx) = oneshot::channel();
434 self.pending_acks.lock().await.insert(req_id.clone(), tx);
435
436 if let Err(err) = self.outbound_tx.send(frame).await {
437 self.pending_acks.lock().await.remove(&req_id);
438 return Err(format!("failed to send request: {err}"));
439 }
440
441 match timeout(timeout_dur, rx).await {
442 Ok(Ok(result)) => result,
443 Ok(Err(_)) => Err("ack wait channel dropped".to_string()),
444 Err(_) => {
445 self.pending_acks.lock().await.remove(&req_id);
446 Err(format!("ack timeout for request {req_id}"))
447 }
448 }
449 }
450}
451
452fn dispatch_channel_handlers(handlers: &ChannelHandlers, channel: &str, message: &Value) {
453 let callbacks: Vec<ChannelHandler> = {
454 let guard = handlers.lock().expect("channel handler mutex poisoned");
455 guard
456 .get(channel)
457 .map(|entries| entries.values().cloned().collect())
458 .unwrap_or_default()
459 };
460
461 for callback in callbacks {
462 callback(message.clone());
463 }
464}
465
466fn dispatch_global_handlers(handlers: &GlobalHandlers, channel: &str, message: &Value) {
467 let callbacks: Vec<GlobalHandler> = {
468 let guard = handlers.lock().expect("global handler mutex poisoned");
469 guard.values().cloned().collect()
470 };
471
472 for callback in callbacks {
473 callback(channel.to_string(), message.clone());
474 }
475}
476
477fn dispatch_channel_event_handlers(
478 handlers: &ChannelEventHandlers,
479 channel: &str,
480 event: &str,
481 message: &Value,
482) {
483 let callbacks: Vec<ChannelEventHandler> = {
484 let guard = handlers
485 .lock()
486 .expect("channel event handler mutex poisoned");
487 guard
488 .get(channel)
489 .map(|entries| entries.values().cloned().collect())
490 .unwrap_or_default()
491 };
492
493 for callback in callbacks {
494 callback(event.to_string(), message.clone());
495 }
496}
497
498fn dispatch_global_event_handlers(
499 handlers: &GlobalEventHandlers,
500 channel: &str,
501 event: &str,
502 message: &Value,
503) {
504 let callbacks: Vec<GlobalEventHandler> = {
505 let guard = handlers
506 .lock()
507 .expect("global event handler mutex poisoned");
508 guard.values().cloned().collect()
509 };
510
511 for callback in callbacks {
512 callback(channel.to_string(), event.to_string(), message.clone());
513 }
514}
515
516fn frame_id(frame: &ClientFrame) -> &str {
517 match frame {
518 ClientFrame::ChannelJoin { id, .. } => id,
519 ClientFrame::ChannelLeave { id, .. } => id,
520 ClientFrame::ChannelEmit { id, .. } => id,
521 ClientFrame::Ping { id, .. } => id,
522 }
523}
524
/// Appends `token` to `base_url` as a `token` query parameter.
///
/// `&` is used as the separator when `base_url` already contains a `?`,
/// otherwise `?`. Pass the raw token: every byte outside the RFC 3986
/// unreserved set (`A-Z a-z 0-9 - . _ ~`) is percent-encoded, so characters
/// such as `&`, `=`, `#`, `+` or spaces cannot corrupt the query string.
fn with_query_token(base_url: &str, token: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let separator = if base_url.contains('?') { '&' } else { '?' };

    let mut url = String::with_capacity(base_url.len() + "?token=".len() + token.len());
    url.push_str(base_url);
    url.push(separator);
    url.push_str("token=");

    for byte in token.bytes() {
        match byte {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'~' => {
                url.push(char::from(byte));
            }
            _ => {
                url.push('%');
                url.push(char::from(HEX[usize::from(byte >> 4)]));
                url.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
        }
    }

    url
}