1use std::sync::atomic::{AtomicBool, Ordering};
30use std::sync::Arc;
31use std::time::Duration;
32
33use futures_util::StreamExt;
34use tokio::sync::{broadcast, mpsc, watch, RwLock};
35use tokio_util::sync::CancellationToken;
36use tungstenite::protocol::WebSocketConfig;
37use tungstenite::Message;
38
39use crate::connection::ConnectionEvent;
40use crate::connection::{
41 ConnectionSnapshot, ConnectionSupervisor, DefaultConnector, ExponentialBackoff, NoRetry,
42 RetryStrategy,
43};
44use crate::error::{ClientError, DisconnectReason, HandshakeError, SendError, SupervisorError};
45use crate::extension::{Extension, ExtensionHost};
46use crate::handshake::{BoxHandshaker, Handshaker, NoOpHandshaker};
47use crate::message::{DispatcherConfig, MessageDispatcher, ProcessorErrorPolicy, SharedMessage};
48
/// Tunables for a [`WebSocketClient`]: timeouts, socket options, channel sizes and
/// the dispatcher's processor-error policy.
#[derive(Clone)]
pub struct ClientConfig {
    /// Receive timeout forwarded to the message dispatcher (`DispatcherConfig::with_receive_timeout`).
    pub receive_timeout: Duration,
    /// Forwarded to the connection supervisor's `exit_on_first_failure` setting.
    pub exit_on_first_failure: bool,
    /// Forwarded to the connection supervisor's `connect_timeout` setting.
    pub connect_timeout: Duration,
    /// How long the client sleeps before retrying after a retryable handshake failure.
    pub handshake_retry_delay: Duration,
    /// Optional tungstenite configuration; when `Some`, handed to the connector via `with_ws_config`.
    pub ws_config: Option<WebSocketConfig>,
    /// Passed to the connector's `with_nodelay` when the client is built.
    pub disable_nagle: bool,
    /// Capacity of the client-level inbound broadcast channel; also used as the
    /// dispatcher's broadcast capacity.
    pub channel_buffer_size: usize,
    /// Capacity of each session's outbound queue; also used as the dispatcher's send
    /// buffer capacity.
    pub send_queue_capacity: usize,
    /// Policy forwarded to the dispatcher for errors returned by the message processor.
    pub processor_error_policy: ProcessorErrorPolicy,
}
74
75impl Default for ClientConfig {
76 fn default() -> Self {
77 Self {
78 receive_timeout: Duration::from_secs(20),
79 exit_on_first_failure: false,
80 connect_timeout: Duration::from_secs(30),
81 handshake_retry_delay: Duration::from_secs(5),
82 ws_config: None,
83 disable_nagle: false,
84 channel_buffer_size: 256,
85 send_queue_capacity: 256,
86 processor_error_policy: ProcessorErrorPolicy::Ignore,
87 }
88 }
89}
90
91impl ClientConfig {
92 #[must_use]
94 pub fn new() -> Self {
95 Self::default()
96 }
97
98 #[must_use]
100 pub const fn with_receive_timeout(mut self, timeout: Duration) -> Self {
101 self.receive_timeout = timeout;
102 self
103 }
104
105 #[must_use]
107 pub const fn with_exit_on_first_failure(mut self, exit: bool) -> Self {
108 self.exit_on_first_failure = exit;
109 self
110 }
111
112 #[must_use]
114 pub const fn with_connect_timeout(mut self, timeout: Duration) -> Self {
115 self.connect_timeout = timeout;
116 self
117 }
118
119 #[must_use]
121 pub const fn with_handshake_retry_delay(mut self, delay: Duration) -> Self {
122 self.handshake_retry_delay = delay;
123 self
124 }
125
126 #[must_use]
128 #[allow(clippy::missing_const_for_fn)] pub fn with_ws_config(mut self, config: WebSocketConfig) -> Self {
130 self.ws_config = Some(config);
131 self
132 }
133
134 #[must_use]
136 pub const fn with_nodelay(mut self, nodelay: bool) -> Self {
137 self.disable_nagle = nodelay;
138 self
139 }
140
141 #[must_use]
143 pub const fn with_channel_buffer(mut self, size: usize) -> Self {
144 self.channel_buffer_size = size;
145 self
146 }
147
148 #[must_use]
150 pub const fn with_send_queue_capacity(mut self, cap: usize) -> Self {
151 self.send_queue_capacity = cap;
152 self
153 }
154
155 #[must_use]
157 pub const fn with_processor_error_policy(mut self, policy: ProcessorErrorPolicy) -> Self {
158 self.processor_error_policy = policy;
159 self
160 }
161
162 #[must_use]
164 pub const fn fast_reconnect() -> Self {
165 Self {
166 receive_timeout: Duration::from_secs(10),
167 exit_on_first_failure: false,
168 connect_timeout: Duration::from_secs(10),
169 handshake_retry_delay: Duration::from_millis(500),
170 ws_config: None,
171 disable_nagle: true,
172 channel_buffer_size: 512,
173 send_queue_capacity: 512,
174 processor_error_policy: ProcessorErrorPolicy::Ignore,
175 }
176 }
177
178 #[must_use]
180 pub const fn stable_connection() -> Self {
181 Self {
182 receive_timeout: Duration::from_secs(60),
183 exit_on_first_failure: false,
184 connect_timeout: Duration::from_secs(60),
185 handshake_retry_delay: Duration::from_secs(2),
186 ws_config: None,
187 disable_nagle: false,
188 channel_buffer_size: 128,
189 send_queue_capacity: 128,
190 processor_error_policy: ProcessorErrorPolicy::Ignore,
191 }
192 }
193}
194
195impl From<&ClientConfig> for DispatcherConfig {
196 fn from(config: &ClientConfig) -> Self {
197 Self::new()
198 .with_receive_timeout(config.receive_timeout)
199 .with_broadcast_capacity(config.channel_buffer_size)
200 .with_send_buffer_capacity(config.send_queue_capacity)
201 .with_processor_error_policy(config.processor_error_policy)
202 }
203}
204
205#[derive(Clone)]
207pub struct Sender {
208 tx: mpsc::Sender<Message>,
209}
210
211impl Sender {
212 pub fn send(&self, message: Message) -> Result<(), SendError> {
214 match self.tx.try_send(message) {
215 Ok(()) => Ok(()),
216 Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => Err(SendError::ChannelFull),
217 Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => Err(SendError::ChannelClosed),
218 }
219 }
220
221 pub fn send_text(&self, text: impl Into<String>) -> Result<(), SendError> {
223 self.send(Message::Text(text.into().into()))
224 }
225
226 pub fn send_binary(&self, data: impl Into<Vec<u8>>) -> Result<(), SendError> {
228 self.send(Message::Binary(data.into().into()))
229 }
230
231 pub fn ping(&self, data: impl Into<Vec<u8>>) -> Result<(), SendError> {
233 self.send(Message::Ping(data.into().into()))
234 }
235
236 pub async fn send_async(&self, message: Message) -> Result<(), SendError> {
238 self.tx
239 .send(message)
240 .await
241 .map_err(|_| SendError::ChannelClosed)
242 }
243
244 pub async fn send_text_async(&self, text: impl Into<String>) -> Result<(), SendError> {
246 self.send_async(Message::Text(text.into().into())).await
247 }
248
249 pub async fn send_binary_async(&self, data: impl Into<Vec<u8>>) -> Result<(), SendError> {
251 self.send_async(Message::Binary(data.into().into())).await
252 }
253
254 pub async fn ping_async(&self, data: impl Into<Vec<u8>>) -> Result<(), SendError> {
256 self.send_async(Message::Ping(data.into().into())).await
257 }
258
259 pub async fn send_timeout(&self, message: Message, timeout: Duration) -> Result<(), SendError> {
261 match tokio::time::timeout(timeout, self.tx.send(message)).await {
262 Ok(Ok(())) => Ok(()),
263 Ok(Err(_)) => Err(SendError::ChannelClosed),
264 Err(_) => Err(SendError::Timeout(timeout)),
265 }
266 }
267}
268
/// State shared by a [`WebSocketClient`] and the tasks it spawns for each session.
struct ClientRuntime {
    /// Set by `begin_run` and cleared by `finish_run`; used to reject concurrent `run()` calls.
    is_running: AtomicBool,
    /// Cancelled by `cancel()` (from `WebSocketClient::shutdown`). Nothing in this file
    /// resets it, so after a shutdown every later `run()` exits at its first check.
    cancel: CancellationToken,
    /// Client-level broadcast of inbound messages, fed by each session's forwarder task.
    message_tx: broadcast::Sender<SharedMessage>,
    /// Outbound queue of the current session; `None` when no session queue is installed.
    send_tx: Arc<RwLock<Option<mpsc::Sender<Message>>>>,
    /// Dispatcher that is attached to each session's sink and runs its receive loop.
    dispatcher: Arc<MessageDispatcher<crate::connection::DefaultWsStream>>,
    /// Run flag published over `watch` for `shutdown_graceful()` waiters; meant to
    /// track `is_running`.
    run_state: watch::Sender<bool>,
}
278
279impl ClientRuntime {
280 fn new(config: &ClientConfig) -> Self {
281 let (message_tx, _) = broadcast::channel(config.channel_buffer_size);
282 let dispatcher_config = DispatcherConfig::from(config);
283 let (run_state, _rx) = watch::channel(false);
284
285 Self {
286 is_running: AtomicBool::new(false),
287 cancel: CancellationToken::new(),
288 message_tx,
289 send_tx: Arc::new(RwLock::new(None)),
290 dispatcher: Arc::new(MessageDispatcher::new(dispatcher_config)),
291 run_state,
292 }
293 }
294
295 fn begin_run(&self) -> Result<(), ClientError> {
296 if self.is_running.swap(true, Ordering::SeqCst) {
297 Err(ClientError::AlreadyRunning)
298 } else {
299 let _ = self.run_state.send(true);
300 Ok(())
301 }
302 }
303
304 fn finish_run(&self) {
305 self.is_running.store(false, Ordering::SeqCst);
306 let _ = self.run_state.send(false);
307 }
308
309 fn cancel(&self) {
310 self.cancel.cancel();
311 }
312
313 fn cancel_token(&self) -> CancellationToken {
314 self.cancel.clone()
315 }
316
317 fn is_cancelled(&self) -> bool {
318 self.cancel.is_cancelled()
319 }
320
321 fn is_running(&self) -> bool {
322 self.is_running.load(Ordering::SeqCst)
323 }
324
325 fn subscribe(&self) -> broadcast::Receiver<SharedMessage> {
326 self.message_tx.subscribe()
327 }
328
329 fn message_channel(&self) -> broadcast::Sender<SharedMessage> {
330 self.message_tx.clone()
331 }
332
333 fn dispatcher(&self) -> Arc<MessageDispatcher<crate::connection::DefaultWsStream>> {
334 self.dispatcher.clone()
335 }
336
337 async fn sender(&self) -> Option<Sender> {
338 let guard = self.send_tx.read().await;
339 guard.as_ref().map(|tx| Sender { tx: tx.clone() })
340 }
341
342 async fn send(&self, message: Message) -> Result<(), SendError> {
343 let guard = self.send_tx.read().await;
344 guard.as_ref().map_or(Err(SendError::NotConnected), |tx| {
345 match tx.try_send(message) {
346 Ok(()) => Ok(()),
347 Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => Err(SendError::ChannelFull),
348 Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
349 Err(SendError::ChannelClosed)
350 }
351 }
352 })
353 }
354
355 async fn send_async(&self, message: Message) -> Result<(), SendError> {
356 let tx = self
357 .send_tx
358 .read()
359 .await
360 .as_ref()
361 .ok_or(SendError::NotConnected)?
362 .clone();
363 tx.send(message).await.map_err(|_| SendError::ChannelClosed)
364 }
365
366 async fn send_timeout(&self, message: Message, timeout: Duration) -> Result<(), SendError> {
367 let tx = self
368 .send_tx
369 .read()
370 .await
371 .as_ref()
372 .ok_or(SendError::NotConnected)?
373 .clone();
374 match tokio::time::timeout(timeout, tx.send(message)).await {
375 Ok(Ok(())) => Ok(()),
376 Ok(Err(_)) => Err(SendError::ChannelClosed),
377 Err(_) => Err(SendError::Timeout(timeout)),
378 }
379 }
380
381 async fn set_send_channel(&self, tx: mpsc::Sender<Message>) {
382 let mut guard = self.send_tx.write().await;
383 *guard = Some(tx);
384 }
385
386 async fn clear_send_channel(&self) {
387 let mut guard = self.send_tx.write().await;
388 *guard = None;
389 }
390
391 fn run_state_receiver(&self) -> watch::Receiver<bool> {
392 self.run_state.subscribe()
393 }
394}
395
/// WebSocket client that repeatedly connects through a supervisor, runs a handshake,
/// and then pumps messages until the session ends or shutdown is requested.
pub struct WebSocketClient {
    /// URI given to the supervisor; also returned by [`WebSocketClient::uri`].
    uri: String,
    /// Configuration the client was built with.
    config: ClientConfig,
    /// Runs after every successful connect, before the sink is attached to the dispatcher.
    handshaker: BoxHandshaker,
    /// Extensions notified on connect/disconnect and given each inbound message.
    extension_host: Arc<ExtensionHost>,
    /// Connects (using the configured retry strategy) and tracks connection state/events.
    supervisor: ConnectionSupervisor<DefaultConnector>,
    /// State shared between the client and its per-session tasks.
    runtime: Arc<ClientRuntime>,
}
406
407impl WebSocketClient {
408 pub fn builder(uri: impl Into<String>) -> WebSocketClientBuilder {
410 WebSocketClientBuilder::new(uri)
411 }
412
413 pub fn new(uri: impl Into<String>) -> Self {
415 Self::builder(uri).build()
416 }
417
418 #[must_use]
426 pub fn subscribe(&self) -> broadcast::Receiver<SharedMessage> {
427 self.runtime.subscribe()
428 }
429
430 #[must_use]
432 pub fn uri(&self) -> &str {
433 &self.uri
434 }
435
436 #[must_use]
438 pub fn subscribe_events(&self) -> broadcast::Receiver<ConnectionEvent> {
439 self.supervisor.subscribe()
440 }
441
442 pub async fn sender(&self) -> Option<Sender> {
444 self.runtime.sender().await
445 }
446
447 pub async fn send(&self, message: Message) -> Result<(), SendError> {
449 self.runtime.send(message).await
450 }
451 pub async fn send_async(&self, message: Message) -> Result<(), SendError> {
453 self.runtime.send_async(message).await
454 }
455
456 pub async fn send_timeout(&self, message: Message, timeout: Duration) -> Result<(), SendError> {
458 self.runtime.send_timeout(message, timeout).await
459 }
460
461 pub async fn state(&self) -> ConnectionSnapshot {
463 self.supervisor.snapshot().await
464 }
465
466 #[must_use]
468 pub fn is_connected(&self) -> bool {
469 self.supervisor.is_connected()
470 }
471
472 pub async fn register_extension<E: Extension + 'static>(
474 &self,
475 extension: E,
476 ) -> Result<(), ClientError> {
477 self.extension_host
478 .register(extension)
479 .await
480 .map_err(ClientError::Extension)
481 }
482
483 pub async fn run(&self) -> Result<(), ClientError> {
485 self.runtime.begin_run()?;
486 let result = self.run_loop().await;
487 self.runtime.finish_run();
488 result
489 }
490
491 pub fn shutdown(&self) {
493 self.runtime.cancel();
497 self.supervisor.shutdown();
499 }
500
501 pub async fn shutdown_graceful(&self, timeout: Duration) -> Result<(), ClientError> {
515 let mut run_state = self.runtime.run_state_receiver();
516 self.shutdown();
517 if !self.runtime.is_running() || !*run_state.borrow() {
518 return Ok(());
519 }
520
521 let wait_for_shutdown = async {
522 while run_state.changed().await.is_ok() {
523 if !*run_state.borrow() {
524 break;
525 }
526 }
527 };
528
529 match tokio::time::timeout(timeout, wait_for_shutdown).await {
530 Ok(()) => Ok(()),
531 Err(_) => Err(ClientError::ShutdownTimeout(timeout)),
532 }
533 }
534
535 async fn run_loop(&self) -> Result<(), ClientError> {
536 loop {
537 if self.runtime.is_cancelled() {
538 tracing::info!("Shutdown requested");
539 self.extension_host.shutdown().await?;
540 return Ok(());
541 }
542 let (stream, mut send_rx, connection_id) = match self.establish_session().await {
544 Ok(t) => t,
545 Err(ClientError::Supervisor(SupervisorError::Shutdown)) => {
546 tracing::info!("Supervisor shutdown requested");
547 self.extension_host.shutdown().await?;
548 return Ok(());
549 }
550 Err(ClientError::Handshake(_)) => {
551 continue;
553 }
554 Err(e) => {
555 self.extension_host.shutdown().await?;
556 return Err(e);
557 }
558 };
559
560 let (mut recv_task, forward_task) = self.spawn_receiver_and_bridge(stream);
562
563 let disconnect_reason = self.drive_session(&mut send_rx, &mut recv_task).await;
565
566 self.cleanup_session(forward_task, disconnect_reason, connection_id)
568 .await?;
569 }
570 }
571
572 async fn connect_via_supervisor(
573 &self,
574 ) -> Result<crate::connection::DefaultWsStream, ClientError> {
575 match self.supervisor.connect().await {
576 Ok(stream) => Ok(stream),
577 Err(e) => Err(ClientError::Supervisor(e)),
578 }
579 }
580
581 async fn establish_session(
582 &self,
583 ) -> Result<
584 (
585 futures_util::stream::SplitStream<crate::connection::DefaultWsStream>,
586 mpsc::Receiver<Message>,
587 u64,
588 ),
589 ClientError,
590 > {
591 let ws_stream = self.connect_via_supervisor().await?;
592
593 let (mut sink, mut stream) = ws_stream.split();
595
596 let (send_tx, send_rx) = mpsc::channel::<Message>(self.config.send_queue_capacity);
598 self.runtime.set_send_channel(send_tx).await;
599
600 if let Err(e) = self.perform_handshake(&mut sink, &mut stream).await {
602 tracing::error!(error = ?e, "Handshake failed");
603 self.supervisor
604 .mark_disconnected(DisconnectReason::Error(e.to_string()))
605 .await;
606 if self.handshaker.is_retryable(&e) {
608 tokio::time::sleep(self.config.handshake_retry_delay).await;
609 return Err(ClientError::Handshake(e));
610 }
611 self.supervisor
613 .fatal(crate::error::ConnectError::HandshakeFailed(e.to_string()));
614 self.supervisor.shutdown();
615 return Err(ClientError::Supervisor(SupervisorError::Shutdown));
616 }
617
618 self.runtime.dispatcher().attach(sink).await;
620
621 let connection_id = self.supervisor.connection_id();
623 let snapshot = self.supervisor.snapshot().await;
624 self.extension_host
625 .update_context(connection_id, snapshot.reconnect_count)
626 .await;
627 let _ = self.extension_host.notify_connect().await;
628 tracing::info!(connection_id = connection_id, "Connected");
629
630 Ok((stream, send_rx, connection_id))
631 }
632
633 fn spawn_receiver_and_bridge(
634 &self,
635 stream: futures_util::stream::SplitStream<crate::connection::DefaultWsStream>,
636 ) -> (
637 tokio::task::JoinHandle<Result<(), crate::error::ReceiveError>>,
638 tokio::task::JoinHandle<()>,
639 ) {
640 let dispatcher = self.runtime.dispatcher();
642 let ext_host = self.extension_host.clone();
643 let mut disp_rx = dispatcher.subscribe();
644 let client_broadcast = self.runtime.message_channel();
645 let cancel_token = self.runtime.cancel_token();
646 let forward_task = tokio::spawn(async move {
647 loop {
648 tokio::select! {
649 () = cancel_token.cancelled() => break,
650 msg = disp_rx.recv() => {
651 match msg {
652 Ok(m) => { let _ = client_broadcast.send(m); }
653 Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => { }
654 Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
655 }
656 }
657 }
658 }
659 });
660 let activity = self.supervisor.activity_handle();
662 let recv_task = tokio::spawn(async move {
663 dispatcher
664 .receive_loop_with_processor(
665 stream,
666 move || {
667 let activity = activity.clone();
668 async move { activity.update().await }
669 },
670 move |msg| {
671 let ext_host = ext_host.clone();
672 async move { ext_host.process_message(&msg).await }
673 },
674 )
675 .await
676 });
677
678 (recv_task, forward_task)
679 }
680
681 async fn drive_session(
682 &self,
683 send_rx: &mut mpsc::Receiver<Message>,
684 recv_task: &mut tokio::task::JoinHandle<Result<(), crate::error::ReceiveError>>,
685 ) -> Option<DisconnectReason> {
686 let cancel_token = self.runtime.cancel_token();
687 let dispatcher = self.runtime.dispatcher();
688 loop {
689 tokio::select! {
690 () = cancel_token.cancelled() => {
691 recv_task.abort();
692 return Some(DisconnectReason::Shutdown);
693 }
694 res = &mut *recv_task => {
695 return match res {
696 Ok(Ok(())) => Some(DisconnectReason::Normal),
697 Ok(Err(e)) => Some(match e {
698 crate::error::ReceiveError::Timeout(_) => DisconnectReason::Timeout,
699 crate::error::ReceiveError::StreamClosed => DisconnectReason::Normal,
700 crate::error::ReceiveError::WebSocket(err) => DisconnectReason::Error(err),
701 }),
702 Err(_) => Some(DisconnectReason::Error("receiver task aborted".to_string())),
703 }
704 }
705 msg = send_rx.recv() => {
706 if let Some(message) = msg {
707 if let Err(e) = dispatcher.send(message).await {
708 return Some(DisconnectReason::Error(format!("send error: {e:?}")));
709 }
710 } else {
711 return Some(DisconnectReason::Error("send channel closed".to_string()));
712 }
713 }
714 }
715 }
716 }
717
718 async fn cleanup_session(
719 &self,
720 forward_task: tokio::task::JoinHandle<()>,
721 disconnect_reason: Option<DisconnectReason>,
722 connection_id: u64,
723 ) -> Result<(), ClientError> {
724 self.runtime.clear_send_channel().await;
725 forward_task.abort();
726 self.runtime.dispatcher().detach().await;
727
728 let reason = disconnect_reason.unwrap_or(DisconnectReason::Normal);
729 self.supervisor.mark_disconnected(reason.clone()).await;
730 let _ = self.extension_host.notify_disconnect().await;
731
732 tracing::info!(
733 connection_id = connection_id,
734 reason = ?Some(reason.clone()),
735 "Disconnected"
736 );
737
738 if self.runtime.is_cancelled() {
739 tracing::info!("Shutdown requested after disconnect");
740 self.extension_host.shutdown().await?;
741 return Ok(());
742 }
743
744 Ok(())
745 }
746
747 async fn perform_handshake<S, St>(
748 &self,
749 sink: &mut S,
750 stream: &mut St,
751 ) -> Result<(), HandshakeError>
752 where
753 S: futures_util::Sink<Message, Error = tungstenite::Error> + Unpin + Send,
754 St: futures_util::Stream<Item = Result<Message, tungstenite::Error>> + Unpin + Send,
755 {
756 use crate::context::ConnectionContext;
757
758 let snapshot = self.supervisor.snapshot().await;
759 let context =
760 ConnectionContext::new(snapshot.id).with_reconnect_count(snapshot.reconnect_count);
761
762 self.handshaker
763 .handshake_with_timeout(sink, stream, &context)
764 .await
765 }
766
767 }
769
/// Builder for [`WebSocketClient`]; obtain one with [`WebSocketClient::builder`].
pub struct WebSocketClientBuilder {
    /// Target URI.
    uri: String,
    /// Client configuration; starts as `ClientConfig::default()`.
    config: ClientConfig,
    /// Reconnect policy moved into the supervisor config by `build()`; starts as
    /// `ExponentialBackoff::default()`.
    retry_strategy: Box<dyn RetryStrategy>,
    /// Handshake run after every successful connect; starts as `NoOpHandshaker`.
    handshaker: BoxHandshaker,
}
777
778impl WebSocketClientBuilder {
779 pub fn new(uri: impl Into<String>) -> Self {
781 Self {
782 uri: uri.into(),
783 config: ClientConfig::default(),
784 retry_strategy: Box::new(ExponentialBackoff::default()),
785 handshaker: Box::new(NoOpHandshaker),
786 }
787 }
788
789 #[must_use]
791 #[allow(clippy::missing_const_for_fn)] pub fn config(mut self, config: ClientConfig) -> Self {
793 self.config = config;
794 self
795 }
796
797 #[must_use]
799 pub const fn receive_timeout(mut self, timeout: Duration) -> Self {
800 self.config.receive_timeout = timeout;
801 self
802 }
803
804 #[must_use]
806 pub fn retry_strategy<S: RetryStrategy + 'static>(mut self, strategy: S) -> Self {
807 self.retry_strategy = Box::new(strategy);
808 self
809 }
810
811 #[must_use]
813 pub fn handshaker<H: Handshaker + 'static>(mut self, handshaker: H) -> Self {
814 self.handshaker = Box::new(handshaker);
815 self
816 }
817
818 #[must_use]
820 pub fn no_retry(mut self) -> Self {
821 self.retry_strategy = Box::new(NoRetry);
822 self
823 }
824
825 #[must_use]
827 pub fn exponential_backoff(
828 mut self,
829 initial: Duration,
830 max: Duration,
831 multiplier: f64,
832 ) -> Self {
833 self.retry_strategy = Box::new(
834 ExponentialBackoff::new(initial, max)
835 .with_factor(multiplier)
836 .with_jitter(0.1),
837 );
838 self
839 }
840
841 #[must_use]
843 pub fn build(self) -> WebSocketClient {
844 let runtime = Arc::new(ClientRuntime::new(&self.config));
845
846 let connector = DefaultConnector::new()
848 .with_nodelay(self.config.disable_nagle)
849 ;
852 let connector = if let Some(ws_cfg) = self.config.ws_config {
853 connector.with_ws_config(ws_cfg)
854 } else {
855 connector
856 };
857
858 let mut sup_cfg = crate::connection::SupervisorConfig::new();
860 sup_cfg.retry_strategy = self.retry_strategy;
861 sup_cfg.exit_on_first_failure = self.config.exit_on_first_failure;
862 sup_cfg.connect_timeout = self.config.connect_timeout;
863
864 let supervisor =
865 ConnectionSupervisor::with_connector(self.uri.clone(), connector).with_config(sup_cfg);
866
867 WebSocketClient {
868 uri: self.uri,
869 config: self.config,
870 handshaker: self.handshaker,
871 extension_host: Arc::new(ExtensionHost::new()),
872 supervisor,
873 runtime,
874 }
875 }
876}
877
/// Convenience for running a shared client on its own tokio task.
pub trait WebSocketClientExt {
    /// Starts the client on a tokio task and returns that task's join handle.
    /// For [`WebSocketClient`] the task runs `run()`.
    fn spawn(self: Arc<Self>) -> tokio::task::JoinHandle<Result<(), ClientError>>;
}
883
884impl WebSocketClientExt for WebSocketClient {
885 fn spawn(self: Arc<Self>) -> tokio::task::JoinHandle<Result<(), ClientError>> {
886 tokio::spawn(async move { self.run().await })
887 }
888}
889
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_config_defaults() {
        let ClientConfig {
            receive_timeout,
            exit_on_first_failure,
            disable_nagle,
            ..
        } = ClientConfig::default();
        assert_eq!(receive_timeout, Duration::from_secs(20));
        assert!(!exit_on_first_failure);
        assert!(!disable_nagle);
    }

    #[test]
    fn test_client_config_presets() {
        let fast = ClientConfig::fast_reconnect();
        assert_eq!(fast.receive_timeout, Duration::from_secs(10));
        assert!(fast.disable_nagle);

        assert_eq!(
            ClientConfig::stable_connection().receive_timeout,
            Duration::from_secs(60)
        );
    }

    #[test]
    fn test_builder() {
        let timeout = Duration::from_secs(30);
        let client = WebSocketClient::builder("ws://localhost:8080")
            .receive_timeout(timeout)
            .no_retry()
            .build();

        assert_eq!(client.config.receive_timeout, timeout);
    }

    #[tokio::test]
    async fn test_sender_backpressure_full() {
        use tokio::sync::mpsc;

        // Capacity 1: the first message fills the queue, the second must be rejected.
        let (tx, _rx) = mpsc::channel::<Message>(1);
        let sender = Sender { tx };

        assert!(sender.send(Message::Text("a".into())).is_ok());
        let overflow = sender.send(Message::Text("b".into()));
        assert!(matches!(overflow, Err(crate::error::SendError::ChannelFull)));
    }

    #[tokio::test]
    async fn test_client_shutdown_exits_quickly() {
        let client = Arc::new(
            WebSocketClient::builder("wss://example.test/ws")
                .receive_timeout(Duration::from_millis(100))
                .no_retry()
                .build(),
        );

        // Shut down before running: run() must observe the cancellation and return promptly.
        client.shutdown();

        let runner = Arc::clone(&client);
        let handle = tokio::spawn(async move { runner.run().await });

        let joined = tokio::time::timeout(Duration::from_secs(1), handle).await;
        assert!(joined.is_ok(), "run() did not exit in time");
        let run_res = joined.unwrap().unwrap();
        assert!(run_res.is_ok(), "run() returned error: {run_res:?}");
    }
}