Skip to main content

wscall_client/
client_runtime.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::panic::AssertUnwindSafe;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
6use std::time::Duration;
7
8use futures_util::{FutureExt, SinkExt, StreamExt, future::BoxFuture};
9use serde_json::{Value, json};
10use tokio::sync::{Mutex, RwLock, mpsc, oneshot};
11use tokio::time::{MissedTickBehavior, interval, sleep, timeout};
12use tokio_tungstenite::{connect_async, tungstenite::Message};
13use uuid::Uuid;
14use wscall_protocol::{
15    EncryptionKind, ErrorPayload, FileAttachment, FrameCodec, PacketBody, PacketEnvelope,
16};
17
18use crate::client_types::{
19    ClientConnectionEvent, ClientDisconnectEvent, ClientError, ClientOutbound, EventMessage,
20};
21
/// Type-erased handler for an inbound event; the value it resolves to is used as the ack receipt.
type EventHandler = Arc<dyn Fn(EventMessage) -> BoxFuture<'static, Value> + Send + Sync>;
/// Type-erased callback invoked when a connection has been established.
type ConnectionHandler = Arc<dyn Fn(ClientConnectionEvent) -> BoxFuture<'static, ()> + Send + Sync>;
/// Type-erased callback invoked when the active connection is torn down.
type DisconnectHandler =
    Arc<dyn Fn(ClientDisconnectEvent) -> BoxFuture<'static, ()> + Send + Sync>;
/// One-shot channel end used to deliver the reply (or failure) of a single in-flight request.
type PendingSender = oneshot::Sender<Result<Value, ClientError>>;
/// In-flight requests keyed by their request/event id.
type PendingMap = Arc<Mutex<HashMap<String, PendingSender>>>;

/// The reader loop reports `ClientError::IdleTimeout` if no message arrives within this window.
const CLIENT_IDLE_TIMEOUT: Duration = Duration::from_secs(45);
/// Period of the heartbeat task that queues WebSocket pings.
const CLIENT_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(15);
/// Capacity of the bounded queue feeding the writer task.
const CLIENT_OUTBOUND_QUEUE_CAPACITY: usize = 256;
/// Delay, in seconds, before the first reconnect attempt.
const CLIENT_RECONNECT_BASE_DELAY_SECS: u64 = 3;
/// Upper bound, in seconds, for the reconnect delay.
const CLIENT_RECONNECT_MAX_DELAY_SECS: u64 = 30;
34
/// Handle to a WebSocket client whose connection is maintained by a background supervisor task.
///
/// Every field is either immutable or wrapped in an `Arc`, so clones share the same
/// connection state.
#[derive(Clone)]
pub struct WscallClient {
    /// Server URL passed to `connect_async` on every (re)connection attempt.
    url: Arc<str>,
    /// Codec used to encode outbound packets and decode inbound binary frames.
    codec: FrameCodec,
    /// Sender for the current connection's outbound queue; `None` after a disconnect.
    writer: Arc<RwLock<Option<mpsc::Sender<ClientOutbound>>>>,
    /// API requests awaiting an `ApiResponse`, keyed by request id.
    pending_api: PendingMap,
    /// Emitted events awaiting an `EventAck`, keyed by event id.
    pending_event: PendingMap,
    /// Handlers registered through `on_event`, grouped by event name.
    event_handlers: Arc<RwLock<HashMap<String, Vec<EventHandler>>>>,
    /// Handlers registered through `on_connected`.
    connected_handlers: Arc<RwLock<Vec<ConnectionHandler>>>,
    /// Handlers registered through `on_disconnected`.
    disconnected_handlers: Arc<RwLock<Vec<DisconnectHandler>>>,
    /// How long `call` and `send_event` wait for a reply before returning `Timeout`.
    default_timeout: Duration,
    /// Encryption kind applied to every packet envelope this client builds.
    default_encryption: EncryptionKind,
    /// Set when a connection is established and cleared by `handle_disconnect`.
    is_connected: Arc<AtomicBool>,
    /// Set by `close`; makes the supervisor stop instead of reconnecting.
    shutdown: Arc<AtomicBool>,
    /// Incremented by the supervisor for every connection attempt; disconnects reported
    /// with a different generation are ignored.
    connection_generation: Arc<AtomicU64>,
}
51
52impl WscallClient {
53    pub async fn connect(url: &str) -> Result<Self, ClientError> {
54        Self::connect_with_settings(url, FrameCodec::plaintext(), EncryptionKind::None).await
55    }
56
57    pub async fn connect_with_chacha20(url: &str, key: [u8; 32]) -> Result<Self, ClientError> {
58        Self::connect_with_settings(
59            url,
60            FrameCodec::plaintext().with_chacha20_key(key),
61            EncryptionKind::ChaCha20,
62        )
63        .await
64    }
65
66    pub async fn connect_with_aes256(url: &str, key: [u8; 32]) -> Result<Self, ClientError> {
67        Self::connect_with_settings(
68            url,
69            FrameCodec::plaintext().with_aes256_key(key),
70            EncryptionKind::Aes256,
71        )
72        .await
73    }
74
75    async fn connect_with_settings(
76        url: &str,
77        codec: FrameCodec,
78        default_encryption: EncryptionKind,
79    ) -> Result<Self, ClientError> {
80        let client = Self {
81            url: Arc::<str>::from(url),
82            codec,
83            writer: Arc::new(RwLock::new(None)),
84            pending_api: Arc::new(Mutex::new(HashMap::new())),
85            pending_event: Arc::new(Mutex::new(HashMap::new())),
86            event_handlers: Arc::new(RwLock::new(HashMap::new())),
87            connected_handlers: Arc::new(RwLock::new(Vec::new())),
88            disconnected_handlers: Arc::new(RwLock::new(Vec::new())),
89            default_timeout: Duration::from_secs(10),
90            default_encryption,
91            is_connected: Arc::new(AtomicBool::new(false)),
92            shutdown: Arc::new(AtomicBool::new(false)),
93            connection_generation: Arc::new(AtomicU64::new(0)),
94        };
95
96        let (ready_tx, ready_rx) = oneshot::channel();
97        let supervisor_client = client.clone();
98        tokio::spawn(async move {
99            supervisor_client.run_connection_supervisor(ready_tx).await;
100        });
101
102        ready_rx.await.map_err(|_| {
103            ClientError::ConnectionClosed("connection setup task stopped unexpectedly".to_string())
104        })??;
105        Ok(client)
106    }
107
108    pub fn is_connected(&self) -> bool {
109        self.is_connected.load(Ordering::SeqCst)
110    }
111
112    pub async fn on_event<F, Fut>(&self, name: impl Into<String>, handler: F)
113    where
114        F: Fn(EventMessage) -> Fut + Send + Sync + 'static,
115        Fut: Future<Output = Value> + Send + 'static,
116    {
117        let handler = Arc::new(move |event: EventMessage| {
118            Box::pin(handler(event)) as BoxFuture<'static, Value>
119        });
120        self.event_handlers
121            .write()
122            .await
123            .entry(name.into())
124            .or_default()
125            .push(handler);
126    }
127
128    pub async fn on_connected<F, Fut>(&self, handler: F)
129    where
130        F: Fn(ClientConnectionEvent) -> Fut + Send + Sync + 'static,
131        Fut: Future<Output = ()> + Send + 'static,
132    {
133        let handler: ConnectionHandler = Arc::new(move |event: ClientConnectionEvent| {
134            Box::pin(handler(event)) as BoxFuture<'static, ()>
135        });
136
137        self.connected_handlers.write().await.push(Arc::clone(&handler));
138
139        if self.is_connected() {
140            self.invoke_connection_handler(
141                handler,
142                ClientConnectionEvent {
143                    url: self.url.to_string(),
144                },
145            )
146            .await;
147        }
148    }
149
150    pub async fn on_disconnected<F, Fut>(&self, handler: F)
151    where
152        F: Fn(ClientDisconnectEvent) -> Fut + Send + Sync + 'static,
153        Fut: Future<Output = ()> + Send + 'static,
154    {
155        let handler: DisconnectHandler = Arc::new(move |event: ClientDisconnectEvent| {
156            Box::pin(handler(event)) as BoxFuture<'static, ()>
157        });
158        self.disconnected_handlers.write().await.push(handler);
159    }
160
161    pub async fn call(
162        &self,
163        route: impl Into<String>,
164        params: Value,
165        attachments: Vec<FileAttachment>,
166    ) -> Result<Value, ClientError> {
167        if !self.is_connected.load(Ordering::SeqCst) {
168            return Err(ClientError::Disconnected);
169        }
170
171        let request_id = Uuid::new_v4().to_string();
172        let route = route.into();
173        let (tx, rx) = oneshot::channel();
174        self.pending_api.lock().await.insert(request_id.clone(), tx);
175        if self
176            .send_outbound(ClientOutbound::Packet(PacketEnvelope::with_encryption(
177                PacketBody::ApiRequest {
178                    request_id: request_id.clone(),
179                    route,
180                    params,
181                    attachments,
182                    metadata: json!({ "client_name": "rust-demo" }),
183                },
184                self.default_encryption,
185            )))
186            .await
187            .is_err()
188        {
189            self.pending_api.lock().await.remove(&request_id);
190            return Err(ClientError::Disconnected);
191        }
192
193        match timeout(self.default_timeout, rx).await {
194            Ok(result) => result.map_err(|_| ClientError::Disconnected)?,
195            Err(_) => {
196                self.pending_api.lock().await.remove(&request_id);
197                Err(ClientError::Timeout)
198            }
199        }
200    }
201
202    pub async fn send_event(
203        &self,
204        name: impl Into<String>,
205        data: Value,
206        attachments: Vec<FileAttachment>,
207    ) -> Result<Value, ClientError> {
208        if !self.is_connected.load(Ordering::SeqCst) {
209            return Err(ClientError::Disconnected);
210        }
211
212        let event_id = Uuid::new_v4().to_string();
213        let (tx, rx) = oneshot::channel();
214        self.pending_event.lock().await.insert(event_id.clone(), tx);
215        if self
216            .send_outbound(ClientOutbound::Packet(PacketEnvelope::with_encryption(
217                PacketBody::EventEmit {
218                    event_id: event_id.clone(),
219                    name: name.into(),
220                    data,
221                    attachments,
222                    metadata: json!({ "client_name": "rust-demo" }),
223                    expect_ack: true,
224                },
225                self.default_encryption,
226            )))
227            .await
228            .is_err()
229        {
230            self.pending_event.lock().await.remove(&event_id);
231            return Err(ClientError::Disconnected);
232        }
233
234        match timeout(self.default_timeout, rx).await {
235            Ok(result) => result.map_err(|_| ClientError::Disconnected)?,
236            Err(_) => {
237                self.pending_event.lock().await.remove(&event_id);
238                Err(ClientError::Timeout)
239            }
240        }
241    }
242
243    pub async fn close(&self) -> Result<(), ClientError> {
244        self.shutdown.store(true, Ordering::SeqCst);
245
246        if let Some(writer) = self.writer.read().await.clone() {
247            let _ = writer.send(ClientOutbound::Close).await;
248        }
249
250        let generation = self.connection_generation.load(Ordering::SeqCst);
251        let (disconnect_tx, _disconnect_rx) = oneshot::channel();
252        self.handle_disconnect(
253            generation,
254            ClientError::Disconnected,
255            Arc::new(Mutex::new(Some(disconnect_tx))),
256        )
257        .await;
258        Ok(())
259    }
260
261    async fn handle_packet(&self, packet: PacketEnvelope) {
262        match packet.body {
263            PacketBody::ApiResponse {
264                request_id,
265                ok,
266                data,
267                error,
268                ..
269            } => {
270                if let Some(tx) = self.pending_api.lock().await.remove(&request_id) {
271                    let result = if ok {
272                        Ok(data)
273                    } else {
274                        Err(ClientError::Remote(error.unwrap_or_else(|| ErrorPayload {
275                            code: "remote_error".to_string(),
276                            message: "missing remote error".to_string(),
277                            status: 500,
278                            details: None,
279                        })))
280                    };
281                    let _ = tx.send(result);
282                }
283            }
284            PacketBody::EventAck {
285                event_id,
286                ok,
287                receipt,
288                error,
289            } => {
290                if let Some(tx) = self.pending_event.lock().await.remove(&event_id) {
291                    let result = if ok {
292                        Ok(receipt)
293                    } else {
294                        Err(ClientError::Remote(error.unwrap_or_else(|| ErrorPayload {
295                            code: "remote_error".to_string(),
296                            message: "missing remote error".to_string(),
297                            status: 500,
298                            details: None,
299                        })))
300                    };
301                    let _ = tx.send(result);
302                }
303            }
304            PacketBody::EventEmit {
305                event_id,
306                name,
307                data,
308                attachments,
309                metadata,
310                expect_ack,
311            } => {
312                let event = EventMessage {
313                    event_id: event_id.clone(),
314                    name: name.clone(),
315                    data,
316                    attachments,
317                    metadata,
318                };
319                let handlers = self
320                    .event_handlers
321                    .read()
322                    .await
323                    .get(&name)
324                    .cloned()
325                    .unwrap_or_default();
326
327                let mut receipt = json!({ "handled": false });
328                for handler in handlers {
329                    receipt = handler(event.clone()).await;
330                }
331
332                if expect_ack {
333                    let _ = self
334                        .send_outbound(ClientOutbound::Packet(PacketEnvelope::with_encryption(
335                            PacketBody::EventAck {
336                                event_id,
337                                ok: true,
338                                receipt,
339                                error: None,
340                            },
341                            self.default_encryption,
342                        )))
343                        .await;
344                }
345            }
346            PacketBody::ApiRequest { .. } => {}
347        }
348    }
349
350    async fn run_connection_supervisor(
351        self,
352        ready_tx: oneshot::Sender<Result<(), ClientError>>,
353    ) {
354        let mut ready_tx = Some(ready_tx);
355        let mut reconnect_attempt = 0_u32;
356
357        loop {
358            if self.shutdown.load(Ordering::SeqCst) {
359                return;
360            }
361
362            let generation = self.connection_generation.fetch_add(1, Ordering::SeqCst) + 1;
363            match self.establish_connection(generation).await {
364                Ok(disconnect_rx) => {
365                    if let Some(ready_tx) = ready_tx.take() {
366                        let _ = ready_tx.send(Ok(()));
367                    }
368                    reconnect_attempt = 0;
369                    let _ = disconnect_rx.await;
370                }
371                Err(error) => {
372                    if let Some(ready_tx) = ready_tx.take() {
373                        let _ = ready_tx.send(Err(error));
374                        return;
375                    }
376                }
377            }
378
379            if self.shutdown.load(Ordering::SeqCst) {
380                return;
381            }
382
383            reconnect_attempt = reconnect_attempt.saturating_add(1);
384            sleep(Self::reconnect_delay(reconnect_attempt)).await;
385        }
386    }
387
388    async fn establish_connection(
389        &self,
390        generation: u64,
391    ) -> Result<oneshot::Receiver<ClientError>, ClientError> {
392        let (socket, _) = connect_async(self.url.as_ref()).await?;
393        let (mut sink, mut stream) = socket.split();
394        let (tx, mut rx) = mpsc::channel::<ClientOutbound>(CLIENT_OUTBOUND_QUEUE_CAPACITY);
395        let (disconnect_tx, disconnect_rx) = oneshot::channel();
396        let disconnect_signal = Arc::new(Mutex::new(Some(disconnect_tx)));
397
398        *self.writer.write().await = Some(tx.clone());
399        self.is_connected.store(true, Ordering::SeqCst);
400        self.emit_connected().await;
401
402        let writer_codec = self.codec.clone();
403        let writer_client = self.clone();
404        let writer_signal = Arc::clone(&disconnect_signal);
405        tokio::spawn(async move {
406            let error = loop {
407                let Some(outbound) = rx.recv().await else {
408                    break ClientError::ConnectionClosed("writer loop stopped".to_string());
409                };
410
411                match outbound {
412                    ClientOutbound::Packet(packet) => {
413                        let encoded = match writer_codec.encode(&packet) {
414                            Ok(encoded) => encoded,
415                            Err(error) => {
416                                eprintln!("failed to encode outbound frame: {error}");
417                                continue;
418                            }
419                        };
420
421                        if let Err(error) = sink.send(Message::Binary(encoded)).await {
422                            break ClientError::ConnectionClosed(error.to_string());
423                        }
424                    }
425                    ClientOutbound::Ping(payload) => {
426                        if let Err(error) = sink.send(Message::Ping(payload)).await {
427                            break ClientError::ConnectionClosed(error.to_string());
428                        }
429                    }
430                    ClientOutbound::Pong(payload) => {
431                        if let Err(error) = sink.send(Message::Pong(payload)).await {
432                            break ClientError::ConnectionClosed(error.to_string());
433                        }
434                    }
435                    ClientOutbound::Close => {
436                        let _ = sink.send(Message::Close(None)).await;
437                        break ClientError::ConnectionClosed("client closed".to_string());
438                    }
439                }
440            };
441
442            writer_client
443                .handle_disconnect(generation, error, writer_signal)
444                .await;
445        });
446
447        let heartbeat_client = self.clone();
448        let heartbeat_tx = tx.clone();
449        let heartbeat_signal = Arc::clone(&disconnect_signal);
450        tokio::spawn(async move {
451            let mut ticker = interval(CLIENT_HEARTBEAT_INTERVAL);
452            ticker.set_missed_tick_behavior(MissedTickBehavior::Delay);
453            loop {
454                ticker.tick().await;
455                if !heartbeat_client.is_connection_generation_active(generation) {
456                    break;
457                }
458
459                if heartbeat_tx
460                    .send(ClientOutbound::Ping(Vec::new()))
461                    .await
462                    .is_err()
463                {
464                    heartbeat_client
465                        .handle_disconnect(
466                            generation,
467                            ClientError::ConnectionClosed("heartbeat stopped".to_string()),
468                            heartbeat_signal,
469                        )
470                        .await;
471                    break;
472                }
473            }
474        });
475
476        let reader_client = self.clone();
477        let reader_tx = tx;
478        let reader_codec = self.codec.clone();
479        let reader_signal = Arc::clone(&disconnect_signal);
480        tokio::spawn(async move {
481            let error = loop {
482                let next_message = timeout(CLIENT_IDLE_TIMEOUT, stream.next()).await;
483                let message = match next_message {
484                    Ok(Some(message)) => message,
485                    Ok(None) => {
486                        break ClientError::ConnectionClosed("reader loop stopped".to_string())
487                    }
488                    Err(_) => break ClientError::IdleTimeout,
489                };
490
491                match message {
492                    Ok(Message::Binary(bytes)) => match reader_codec.decode(&bytes) {
493                        Ok(packet) => reader_client.handle_packet(packet).await,
494                        Err(error) => eprintln!("failed to decode inbound frame: {error}"),
495                    },
496                    Ok(Message::Close(_)) => {
497                        break ClientError::ConnectionClosed("server closed connection".to_string())
498                    }
499                    Ok(Message::Ping(payload)) => {
500                        if reader_tx
501                            .send(ClientOutbound::Pong(payload.to_vec()))
502                            .await
503                            .is_err()
504                        {
505                            break ClientError::ConnectionClosed(
506                                "failed to queue pong response".to_string(),
507                            );
508                        }
509                    }
510                    Ok(Message::Pong(_)) | Ok(Message::Text(_)) | Ok(Message::Frame(_)) => {}
511                    Err(error) => {
512                        eprintln!("client reader stopped: {error}");
513                        break ClientError::ConnectionClosed(error.to_string());
514                    }
515                }
516            };
517
518            reader_client
519                .handle_disconnect(generation, error, reader_signal)
520                .await;
521        });
522
523        Ok(disconnect_rx)
524    }
525
526    async fn send_outbound(&self, outbound: ClientOutbound) -> Result<(), ClientError> {
527        let Some(writer) = self.writer.read().await.clone() else {
528            return Err(ClientError::Disconnected);
529        };
530
531        writer.send(outbound).await.map_err(|_| ClientError::Disconnected)
532    }
533
534    async fn handle_disconnect(
535        &self,
536        generation: u64,
537        error: ClientError,
538        disconnect_signal: Arc<Mutex<Option<oneshot::Sender<ClientError>>>>,
539    ) {
540        if !self.is_connection_generation_active(generation) {
541            return;
542        }
543
544        let reason = Self::disconnect_reason(&error);
545
546        if !self.is_connected.swap(false, Ordering::SeqCst) {
547            return;
548        }
549
550        *self.writer.write().await = None;
551
552        let pending_api = std::mem::take(&mut *self.pending_api.lock().await);
553        for sender in pending_api.into_values() {
554            let _ = sender.send(Err(ClientError::ConnectionClosed(reason.clone())));
555        }
556
557        let pending_event = std::mem::take(&mut *self.pending_event.lock().await);
558        for sender in pending_event.into_values() {
559            let _ = sender.send(Err(ClientError::ConnectionClosed(reason.clone())));
560        }
561
562        self.emit_disconnected(ClientDisconnectEvent {
563            url: self.url.to_string(),
564            reason,
565            will_reconnect: !self.shutdown.load(Ordering::SeqCst),
566            retry_after: (!self.shutdown.load(Ordering::SeqCst))
567                .then_some(Self::reconnect_delay(1)),
568        })
569        .await;
570
571        if let Some(sender) = disconnect_signal.lock().await.take() {
572            let _ = sender.send(error);
573        }
574    }
575
576    fn disconnect_reason(error: &ClientError) -> String {
577        match error {
578            ClientError::ConnectionClosed(reason) => reason.clone(),
579            ClientError::IdleTimeout => "idle timeout".to_string(),
580            ClientError::Disconnected => "disconnected".to_string(),
581            other => other.to_string(),
582        }
583    }
584
585    fn is_connection_generation_active(&self, generation: u64) -> bool {
586        self.connection_generation.load(Ordering::SeqCst) == generation
587    }
588
589    fn reconnect_delay(attempt: u32) -> Duration {
590        let seconds = CLIENT_RECONNECT_BASE_DELAY_SECS
591            .saturating_add(u64::from(attempt.saturating_sub(1)))
592            .min(CLIENT_RECONNECT_MAX_DELAY_SECS);
593        Duration::from_secs(seconds)
594    }
595
596    async fn emit_connected(&self) {
597        let event = ClientConnectionEvent {
598            url: self.url.to_string(),
599        };
600        let handlers = self.connected_handlers.read().await.clone();
601        for handler in handlers {
602            self.invoke_connection_handler(handler, event.clone()).await;
603        }
604    }
605
606    async fn emit_disconnected(&self, event: ClientDisconnectEvent) {
607        let handlers = self.disconnected_handlers.read().await.clone();
608        for handler in handlers {
609            self.invoke_disconnect_handler(handler, event.clone()).await;
610        }
611    }
612
613    async fn invoke_connection_handler(
614        &self,
615        handler: ConnectionHandler,
616        event: ClientConnectionEvent,
617    ) {
618        if AssertUnwindSafe(handler(event)).catch_unwind().await.is_err() {
619            eprintln!("client connected handler panicked");
620        }
621    }
622
623    async fn invoke_disconnect_handler(
624        &self,
625        handler: DisconnectHandler,
626        event: ClientDisconnectEvent,
627    ) {
628        if AssertUnwindSafe(handler(event)).catch_unwind().await.is_err() {
629            eprintln!("client disconnected handler panicked");
630        }
631    }
632}