1use std::collections::HashMap;
2use std::future::Future;
3use std::panic::AssertUnwindSafe;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
6use std::time::Duration;
7
8use futures_util::{FutureExt, SinkExt, StreamExt, future::BoxFuture};
9use serde_json::{Value, json};
10use tokio::sync::{Mutex, RwLock, mpsc, oneshot};
11use tokio::time::{MissedTickBehavior, interval, sleep, timeout};
12use tokio_tungstenite::{connect_async, tungstenite::Message};
13use uuid::Uuid;
14use wscall_protocol::{
15 EncryptionKind, ErrorPayload, FileAttachment, FrameCodec, PacketBody, PacketEnvelope,
16};
17
18use crate::client_types::{
19 ClientConnectionEvent, ClientDisconnectEvent, ClientError, ClientOutbound, EventMessage,
20};
21
22type EventHandler = Arc<dyn Fn(EventMessage) -> BoxFuture<'static, Value> + Send + Sync>;
23type ConnectionHandler = Arc<dyn Fn(ClientConnectionEvent) -> BoxFuture<'static, ()> + Send + Sync>;
24type DisconnectHandler =
25 Arc<dyn Fn(ClientDisconnectEvent) -> BoxFuture<'static, ()> + Send + Sync>;
26type PendingSender = oneshot::Sender<Result<Value, ClientError>>;
27type PendingMap = Arc<Mutex<HashMap<String, PendingSender>>>;
28
/// Interval between client-initiated WebSocket pings.
const CLIENT_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(15);
/// The reader abandons a connection that delivers no frame for this long.
const CLIENT_IDLE_TIMEOUT: Duration = Duration::from_secs(45);
/// Capacity of the bounded per-connection outbound queue.
const CLIENT_OUTBOUND_QUEUE_CAPACITY: usize = 256;
/// Reconnect delay, in seconds, for the first attempt.
const CLIENT_RECONNECT_BASE_DELAY_SECS: u64 = 3;
/// Upper bound, in seconds, on the reconnect delay.
const CLIENT_RECONNECT_MAX_DELAY_SECS: u64 = 30;
34
35#[derive(Clone)]
36pub struct WscallClient {
37 url: Arc<str>,
38 codec: FrameCodec,
39 writer: Arc<RwLock<Option<mpsc::Sender<ClientOutbound>>>>,
40 pending_api: PendingMap,
41 pending_event: PendingMap,
42 event_handlers: Arc<RwLock<HashMap<String, Vec<EventHandler>>>>,
43 connected_handlers: Arc<RwLock<Vec<ConnectionHandler>>>,
44 disconnected_handlers: Arc<RwLock<Vec<DisconnectHandler>>>,
45 default_timeout: Duration,
46 default_encryption: EncryptionKind,
47 is_connected: Arc<AtomicBool>,
48 shutdown: Arc<AtomicBool>,
49 connection_generation: Arc<AtomicU64>,
50}
51
52impl WscallClient {
53 pub async fn connect(url: &str) -> Result<Self, ClientError> {
54 Self::connect_with_settings(url, FrameCodec::plaintext(), EncryptionKind::None).await
55 }
56
57 pub async fn connect_with_chacha20(url: &str, key: [u8; 32]) -> Result<Self, ClientError> {
58 Self::connect_with_settings(
59 url,
60 FrameCodec::plaintext().with_chacha20_key(key),
61 EncryptionKind::ChaCha20,
62 )
63 .await
64 }
65
66 pub async fn connect_with_aes256(url: &str, key: [u8; 32]) -> Result<Self, ClientError> {
67 Self::connect_with_settings(
68 url,
69 FrameCodec::plaintext().with_aes256_key(key),
70 EncryptionKind::Aes256,
71 )
72 .await
73 }
74
75 async fn connect_with_settings(
76 url: &str,
77 codec: FrameCodec,
78 default_encryption: EncryptionKind,
79 ) -> Result<Self, ClientError> {
80 let client = Self {
81 url: Arc::<str>::from(url),
82 codec,
83 writer: Arc::new(RwLock::new(None)),
84 pending_api: Arc::new(Mutex::new(HashMap::new())),
85 pending_event: Arc::new(Mutex::new(HashMap::new())),
86 event_handlers: Arc::new(RwLock::new(HashMap::new())),
87 connected_handlers: Arc::new(RwLock::new(Vec::new())),
88 disconnected_handlers: Arc::new(RwLock::new(Vec::new())),
89 default_timeout: Duration::from_secs(10),
90 default_encryption,
91 is_connected: Arc::new(AtomicBool::new(false)),
92 shutdown: Arc::new(AtomicBool::new(false)),
93 connection_generation: Arc::new(AtomicU64::new(0)),
94 };
95
96 let (ready_tx, ready_rx) = oneshot::channel();
97 let supervisor_client = client.clone();
98 tokio::spawn(async move {
99 supervisor_client.run_connection_supervisor(ready_tx).await;
100 });
101
102 ready_rx.await.map_err(|_| {
103 ClientError::ConnectionClosed("connection setup task stopped unexpectedly".to_string())
104 })??;
105 Ok(client)
106 }
107
108 pub fn is_connected(&self) -> bool {
109 self.is_connected.load(Ordering::SeqCst)
110 }
111
112 pub async fn on_event<F, Fut>(&self, name: impl Into<String>, handler: F)
113 where
114 F: Fn(EventMessage) -> Fut + Send + Sync + 'static,
115 Fut: Future<Output = Value> + Send + 'static,
116 {
117 let handler = Arc::new(move |event: EventMessage| {
118 Box::pin(handler(event)) as BoxFuture<'static, Value>
119 });
120 self.event_handlers
121 .write()
122 .await
123 .entry(name.into())
124 .or_default()
125 .push(handler);
126 }
127
128 pub async fn on_connected<F, Fut>(&self, handler: F)
129 where
130 F: Fn(ClientConnectionEvent) -> Fut + Send + Sync + 'static,
131 Fut: Future<Output = ()> + Send + 'static,
132 {
133 let handler: ConnectionHandler = Arc::new(move |event: ClientConnectionEvent| {
134 Box::pin(handler(event)) as BoxFuture<'static, ()>
135 });
136
137 self.connected_handlers.write().await.push(Arc::clone(&handler));
138
139 if self.is_connected() {
140 self.invoke_connection_handler(
141 handler,
142 ClientConnectionEvent {
143 url: self.url.to_string(),
144 },
145 )
146 .await;
147 }
148 }
149
150 pub async fn on_disconnected<F, Fut>(&self, handler: F)
151 where
152 F: Fn(ClientDisconnectEvent) -> Fut + Send + Sync + 'static,
153 Fut: Future<Output = ()> + Send + 'static,
154 {
155 let handler: DisconnectHandler = Arc::new(move |event: ClientDisconnectEvent| {
156 Box::pin(handler(event)) as BoxFuture<'static, ()>
157 });
158 self.disconnected_handlers.write().await.push(handler);
159 }
160
161 pub async fn call(
162 &self,
163 route: impl Into<String>,
164 params: Value,
165 attachments: Vec<FileAttachment>,
166 ) -> Result<Value, ClientError> {
167 if !self.is_connected.load(Ordering::SeqCst) {
168 return Err(ClientError::Disconnected);
169 }
170
171 let request_id = Uuid::new_v4().to_string();
172 let route = route.into();
173 let (tx, rx) = oneshot::channel();
174 self.pending_api.lock().await.insert(request_id.clone(), tx);
175 if self
176 .send_outbound(ClientOutbound::Packet(PacketEnvelope::with_encryption(
177 PacketBody::ApiRequest {
178 request_id: request_id.clone(),
179 route,
180 params,
181 attachments,
182 metadata: json!({ "client_name": "rust-demo" }),
183 },
184 self.default_encryption,
185 )))
186 .await
187 .is_err()
188 {
189 self.pending_api.lock().await.remove(&request_id);
190 return Err(ClientError::Disconnected);
191 }
192
193 match timeout(self.default_timeout, rx).await {
194 Ok(result) => result.map_err(|_| ClientError::Disconnected)?,
195 Err(_) => {
196 self.pending_api.lock().await.remove(&request_id);
197 Err(ClientError::Timeout)
198 }
199 }
200 }
201
202 pub async fn send_event(
203 &self,
204 name: impl Into<String>,
205 data: Value,
206 attachments: Vec<FileAttachment>,
207 ) -> Result<Value, ClientError> {
208 if !self.is_connected.load(Ordering::SeqCst) {
209 return Err(ClientError::Disconnected);
210 }
211
212 let event_id = Uuid::new_v4().to_string();
213 let (tx, rx) = oneshot::channel();
214 self.pending_event.lock().await.insert(event_id.clone(), tx);
215 if self
216 .send_outbound(ClientOutbound::Packet(PacketEnvelope::with_encryption(
217 PacketBody::EventEmit {
218 event_id: event_id.clone(),
219 name: name.into(),
220 data,
221 attachments,
222 metadata: json!({ "client_name": "rust-demo" }),
223 expect_ack: true,
224 },
225 self.default_encryption,
226 )))
227 .await
228 .is_err()
229 {
230 self.pending_event.lock().await.remove(&event_id);
231 return Err(ClientError::Disconnected);
232 }
233
234 match timeout(self.default_timeout, rx).await {
235 Ok(result) => result.map_err(|_| ClientError::Disconnected)?,
236 Err(_) => {
237 self.pending_event.lock().await.remove(&event_id);
238 Err(ClientError::Timeout)
239 }
240 }
241 }
242
243 pub async fn close(&self) -> Result<(), ClientError> {
244 self.shutdown.store(true, Ordering::SeqCst);
245
246 if let Some(writer) = self.writer.read().await.clone() {
247 let _ = writer.send(ClientOutbound::Close).await;
248 }
249
250 let generation = self.connection_generation.load(Ordering::SeqCst);
251 let (disconnect_tx, _disconnect_rx) = oneshot::channel();
252 self.handle_disconnect(
253 generation,
254 ClientError::Disconnected,
255 Arc::new(Mutex::new(Some(disconnect_tx))),
256 )
257 .await;
258 Ok(())
259 }
260
261 async fn handle_packet(&self, packet: PacketEnvelope) {
262 match packet.body {
263 PacketBody::ApiResponse {
264 request_id,
265 ok,
266 data,
267 error,
268 ..
269 } => {
270 if let Some(tx) = self.pending_api.lock().await.remove(&request_id) {
271 let result = if ok {
272 Ok(data)
273 } else {
274 Err(ClientError::Remote(error.unwrap_or_else(|| ErrorPayload {
275 code: "remote_error".to_string(),
276 message: "missing remote error".to_string(),
277 status: 500,
278 details: None,
279 })))
280 };
281 let _ = tx.send(result);
282 }
283 }
284 PacketBody::EventAck {
285 event_id,
286 ok,
287 receipt,
288 error,
289 } => {
290 if let Some(tx) = self.pending_event.lock().await.remove(&event_id) {
291 let result = if ok {
292 Ok(receipt)
293 } else {
294 Err(ClientError::Remote(error.unwrap_or_else(|| ErrorPayload {
295 code: "remote_error".to_string(),
296 message: "missing remote error".to_string(),
297 status: 500,
298 details: None,
299 })))
300 };
301 let _ = tx.send(result);
302 }
303 }
304 PacketBody::EventEmit {
305 event_id,
306 name,
307 data,
308 attachments,
309 metadata,
310 expect_ack,
311 } => {
312 let event = EventMessage {
313 event_id: event_id.clone(),
314 name: name.clone(),
315 data,
316 attachments,
317 metadata,
318 };
319 let handlers = self
320 .event_handlers
321 .read()
322 .await
323 .get(&name)
324 .cloned()
325 .unwrap_or_default();
326
327 let mut receipt = json!({ "handled": false });
328 for handler in handlers {
329 receipt = handler(event.clone()).await;
330 }
331
332 if expect_ack {
333 let _ = self
334 .send_outbound(ClientOutbound::Packet(PacketEnvelope::with_encryption(
335 PacketBody::EventAck {
336 event_id,
337 ok: true,
338 receipt,
339 error: None,
340 },
341 self.default_encryption,
342 )))
343 .await;
344 }
345 }
346 PacketBody::ApiRequest { .. } => {}
347 }
348 }
349
350 async fn run_connection_supervisor(
351 self,
352 ready_tx: oneshot::Sender<Result<(), ClientError>>,
353 ) {
354 let mut ready_tx = Some(ready_tx);
355 let mut reconnect_attempt = 0_u32;
356
357 loop {
358 if self.shutdown.load(Ordering::SeqCst) {
359 return;
360 }
361
362 let generation = self.connection_generation.fetch_add(1, Ordering::SeqCst) + 1;
363 match self.establish_connection(generation).await {
364 Ok(disconnect_rx) => {
365 if let Some(ready_tx) = ready_tx.take() {
366 let _ = ready_tx.send(Ok(()));
367 }
368 reconnect_attempt = 0;
369 let _ = disconnect_rx.await;
370 }
371 Err(error) => {
372 if let Some(ready_tx) = ready_tx.take() {
373 let _ = ready_tx.send(Err(error));
374 return;
375 }
376 }
377 }
378
379 if self.shutdown.load(Ordering::SeqCst) {
380 return;
381 }
382
383 reconnect_attempt = reconnect_attempt.saturating_add(1);
384 sleep(Self::reconnect_delay(reconnect_attempt)).await;
385 }
386 }
387
388 async fn establish_connection(
389 &self,
390 generation: u64,
391 ) -> Result<oneshot::Receiver<ClientError>, ClientError> {
392 let (socket, _) = connect_async(self.url.as_ref()).await?;
393 let (mut sink, mut stream) = socket.split();
394 let (tx, mut rx) = mpsc::channel::<ClientOutbound>(CLIENT_OUTBOUND_QUEUE_CAPACITY);
395 let (disconnect_tx, disconnect_rx) = oneshot::channel();
396 let disconnect_signal = Arc::new(Mutex::new(Some(disconnect_tx)));
397
398 *self.writer.write().await = Some(tx.clone());
399 self.is_connected.store(true, Ordering::SeqCst);
400 self.emit_connected().await;
401
402 let writer_codec = self.codec.clone();
403 let writer_client = self.clone();
404 let writer_signal = Arc::clone(&disconnect_signal);
405 tokio::spawn(async move {
406 let error = loop {
407 let Some(outbound) = rx.recv().await else {
408 break ClientError::ConnectionClosed("writer loop stopped".to_string());
409 };
410
411 match outbound {
412 ClientOutbound::Packet(packet) => {
413 let encoded = match writer_codec.encode(&packet) {
414 Ok(encoded) => encoded,
415 Err(error) => {
416 eprintln!("failed to encode outbound frame: {error}");
417 continue;
418 }
419 };
420
421 if let Err(error) = sink.send(Message::Binary(encoded)).await {
422 break ClientError::ConnectionClosed(error.to_string());
423 }
424 }
425 ClientOutbound::Ping(payload) => {
426 if let Err(error) = sink.send(Message::Ping(payload)).await {
427 break ClientError::ConnectionClosed(error.to_string());
428 }
429 }
430 ClientOutbound::Pong(payload) => {
431 if let Err(error) = sink.send(Message::Pong(payload)).await {
432 break ClientError::ConnectionClosed(error.to_string());
433 }
434 }
435 ClientOutbound::Close => {
436 let _ = sink.send(Message::Close(None)).await;
437 break ClientError::ConnectionClosed("client closed".to_string());
438 }
439 }
440 };
441
442 writer_client
443 .handle_disconnect(generation, error, writer_signal)
444 .await;
445 });
446
447 let heartbeat_client = self.clone();
448 let heartbeat_tx = tx.clone();
449 let heartbeat_signal = Arc::clone(&disconnect_signal);
450 tokio::spawn(async move {
451 let mut ticker = interval(CLIENT_HEARTBEAT_INTERVAL);
452 ticker.set_missed_tick_behavior(MissedTickBehavior::Delay);
453 loop {
454 ticker.tick().await;
455 if !heartbeat_client.is_connection_generation_active(generation) {
456 break;
457 }
458
459 if heartbeat_tx
460 .send(ClientOutbound::Ping(Vec::new()))
461 .await
462 .is_err()
463 {
464 heartbeat_client
465 .handle_disconnect(
466 generation,
467 ClientError::ConnectionClosed("heartbeat stopped".to_string()),
468 heartbeat_signal,
469 )
470 .await;
471 break;
472 }
473 }
474 });
475
476 let reader_client = self.clone();
477 let reader_tx = tx;
478 let reader_codec = self.codec.clone();
479 let reader_signal = Arc::clone(&disconnect_signal);
480 tokio::spawn(async move {
481 let error = loop {
482 let next_message = timeout(CLIENT_IDLE_TIMEOUT, stream.next()).await;
483 let message = match next_message {
484 Ok(Some(message)) => message,
485 Ok(None) => {
486 break ClientError::ConnectionClosed("reader loop stopped".to_string())
487 }
488 Err(_) => break ClientError::IdleTimeout,
489 };
490
491 match message {
492 Ok(Message::Binary(bytes)) => match reader_codec.decode(&bytes) {
493 Ok(packet) => reader_client.handle_packet(packet).await,
494 Err(error) => eprintln!("failed to decode inbound frame: {error}"),
495 },
496 Ok(Message::Close(_)) => {
497 break ClientError::ConnectionClosed("server closed connection".to_string())
498 }
499 Ok(Message::Ping(payload)) => {
500 if reader_tx
501 .send(ClientOutbound::Pong(payload.to_vec()))
502 .await
503 .is_err()
504 {
505 break ClientError::ConnectionClosed(
506 "failed to queue pong response".to_string(),
507 );
508 }
509 }
510 Ok(Message::Pong(_)) | Ok(Message::Text(_)) | Ok(Message::Frame(_)) => {}
511 Err(error) => {
512 eprintln!("client reader stopped: {error}");
513 break ClientError::ConnectionClosed(error.to_string());
514 }
515 }
516 };
517
518 reader_client
519 .handle_disconnect(generation, error, reader_signal)
520 .await;
521 });
522
523 Ok(disconnect_rx)
524 }
525
526 async fn send_outbound(&self, outbound: ClientOutbound) -> Result<(), ClientError> {
527 let Some(writer) = self.writer.read().await.clone() else {
528 return Err(ClientError::Disconnected);
529 };
530
531 writer.send(outbound).await.map_err(|_| ClientError::Disconnected)
532 }
533
534 async fn handle_disconnect(
535 &self,
536 generation: u64,
537 error: ClientError,
538 disconnect_signal: Arc<Mutex<Option<oneshot::Sender<ClientError>>>>,
539 ) {
540 if !self.is_connection_generation_active(generation) {
541 return;
542 }
543
544 let reason = Self::disconnect_reason(&error);
545
546 if !self.is_connected.swap(false, Ordering::SeqCst) {
547 return;
548 }
549
550 *self.writer.write().await = None;
551
552 let pending_api = std::mem::take(&mut *self.pending_api.lock().await);
553 for sender in pending_api.into_values() {
554 let _ = sender.send(Err(ClientError::ConnectionClosed(reason.clone())));
555 }
556
557 let pending_event = std::mem::take(&mut *self.pending_event.lock().await);
558 for sender in pending_event.into_values() {
559 let _ = sender.send(Err(ClientError::ConnectionClosed(reason.clone())));
560 }
561
562 self.emit_disconnected(ClientDisconnectEvent {
563 url: self.url.to_string(),
564 reason,
565 will_reconnect: !self.shutdown.load(Ordering::SeqCst),
566 retry_after: (!self.shutdown.load(Ordering::SeqCst))
567 .then_some(Self::reconnect_delay(1)),
568 })
569 .await;
570
571 if let Some(sender) = disconnect_signal.lock().await.take() {
572 let _ = sender.send(error);
573 }
574 }
575
576 fn disconnect_reason(error: &ClientError) -> String {
577 match error {
578 ClientError::ConnectionClosed(reason) => reason.clone(),
579 ClientError::IdleTimeout => "idle timeout".to_string(),
580 ClientError::Disconnected => "disconnected".to_string(),
581 other => other.to_string(),
582 }
583 }
584
585 fn is_connection_generation_active(&self, generation: u64) -> bool {
586 self.connection_generation.load(Ordering::SeqCst) == generation
587 }
588
589 fn reconnect_delay(attempt: u32) -> Duration {
590 let seconds = CLIENT_RECONNECT_BASE_DELAY_SECS
591 .saturating_add(u64::from(attempt.saturating_sub(1)))
592 .min(CLIENT_RECONNECT_MAX_DELAY_SECS);
593 Duration::from_secs(seconds)
594 }
595
596 async fn emit_connected(&self) {
597 let event = ClientConnectionEvent {
598 url: self.url.to_string(),
599 };
600 let handlers = self.connected_handlers.read().await.clone();
601 for handler in handlers {
602 self.invoke_connection_handler(handler, event.clone()).await;
603 }
604 }
605
606 async fn emit_disconnected(&self, event: ClientDisconnectEvent) {
607 let handlers = self.disconnected_handlers.read().await.clone();
608 for handler in handlers {
609 self.invoke_disconnect_handler(handler, event.clone()).await;
610 }
611 }
612
613 async fn invoke_connection_handler(
614 &self,
615 handler: ConnectionHandler,
616 event: ClientConnectionEvent,
617 ) {
618 if AssertUnwindSafe(handler(event)).catch_unwind().await.is_err() {
619 eprintln!("client connected handler panicked");
620 }
621 }
622
623 async fn invoke_disconnect_handler(
624 &self,
625 handler: DisconnectHandler,
626 event: ClientDisconnectEvent,
627 ) {
628 if AssertUnwindSafe(handler(event)).catch_unwind().await.is_err() {
629 eprintln!("client disconnected handler panicked");
630 }
631 }
632}