ws_rs/client.rs
1use std::fs::File;
2use std::io::BufReader;
3use std::path::Path;
4use std::sync::Arc;
5use std::time::Duration;
6
7use futures_util::{SinkExt, StreamExt};
8use log::{error, info, warn};
9use rustls::RootCertStore;
10use rustls::pki_types::{CertificateDer, PrivateKeyDer};
11use std::collections::HashMap;
12use tokio::sync::{Mutex, mpsc};
13use tokio::task::JoinHandle;
14use tokio::time::timeout;
15use tokio_tungstenite::tungstenite::Message;
16use tokio_tungstenite::{Connector, connect_async_tls_with_config};
17use url::Url;
18
/// A WebSocket data message exchanged with the server: UTF-8 text or raw bytes.
///
/// Values are queued with [`WebSocketClient::send_message`] and delivered by
/// [`WebSocketClient::receive_message`]. Control frames (ping, pong, close) are
/// handled inside the client and are never delivered as a `MessageType`.
///
/// # Example
///
/// ```ignore
/// use std::time::Duration;
/// use ws_rs::client::{MessageType, WebSocketClient};
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     // Create a client with a custom configuration.
///     let mut client = WebSocketClient::builder()
///         .with_channel_capacity(200)
///         .with_connection_timeout(Duration::from_secs(10))
///         .with_auto_reconnect(true)
///         .build();
///
///     // Connect to a WebSocket server, authenticating with a client certificate.
///     client
///         .connect(
///             "wss://127.0.0.1:9000",
///             "./certs",
///             "client_cert.pem",
///             "client_key.pem",
///             "ca_cert.pem",
///         )
///         .await?;
///
///     // Send a text message and a binary message.
///     client.send_message(MessageType::Text("Hello, server!".to_string())).await?;
///     client.send_message(MessageType::Binary(vec![1, 2, 3, 4])).await?;
///
///     // Receive a message.
///     if let Some(response) = client.receive_message().await {
///         match response {
///             MessageType::Text(text) => println!("Received text: {}", text),
///             MessageType::Binary(data) => println!("Received binary data: {} bytes", data.len()),
///         }
///     }
///
///     // Close the connection.
///     client.close().await;
///
///     Ok(())
/// }
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageType {
    /// A UTF-8 text message.
    Text(String),
    /// A binary message carrying raw bytes.
    Binary(Vec<u8>),
}
82
/// Configuration for [`WebSocketClient`].
///
/// Usually assembled through [`WebSocketClient::builder`]; the [`Default`]
/// implementation documents the default values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WSClientConfig {
    /// Capacity of each bounded message queue (outgoing and incoming).
    ///
    /// Should be at least 1: tokio bounded channels cannot have zero capacity.
    pub channel_capacity: usize,
    /// Maximum time allowed for a single connection attempt, including the
    /// TLS and WebSocket handshakes.
    pub connection_timeout: Duration,
    /// Whether `connect` retries a connection attempt that failed or timed out.
    ///
    /// This only affects establishing a connection; a connection that drops
    /// later is not re-established automatically (use `reconnect` for that).
    pub auto_reconnect: bool,
    /// Maximum number of retries after the initial attempt when
    /// `auto_reconnect` is enabled.
    pub max_reconnect_attempts: u32,
    /// Time to wait before each retry.
    pub reconnect_delay: Duration,
}

impl Default for WSClientConfig {
    /// Queue capacity 100, 30 s connection timeout, retries disabled,
    /// at most 5 retries when enabled, 2 s between retries.
    fn default() -> Self {
        Self {
            channel_capacity: 100,
            connection_timeout: Duration::from_secs(30),
            auto_reconnect: false,
            max_reconnect_attempts: 5,
            reconnect_delay: Duration::from_secs(2),
        }
    }
}
109
110/// Builder for WebSocketClient
111pub struct WebSocketClientBuilder {
112 config: WSClientConfig,
113}
114
115impl WebSocketClientBuilder {
116 /// Create a new builder with default configuration
117 pub fn new() -> Self {
118 Self {
119 config: WSClientConfig::default(),
120 }
121 }
122
123 /// Set channel capacity
124 pub fn with_channel_capacity(mut self, capacity: usize) -> Self {
125 self.config.channel_capacity = capacity;
126 self
127 }
128
129 /// Set connection timeout
130 pub fn with_connection_timeout(mut self, timeout: Duration) -> Self {
131 self.config.connection_timeout = timeout;
132 self
133 }
134
135 /// Enable or disable auto-reconnect
136 pub fn with_auto_reconnect(mut self, auto_reconnect: bool) -> Self {
137 self.config.auto_reconnect = auto_reconnect;
138 self
139 }
140
141 /// Set maximum reconnection attempts
142 pub fn with_max_reconnect_attempts(mut self, attempts: u32) -> Self {
143 self.config.max_reconnect_attempts = attempts;
144 self
145 }
146
147 /// Set delay between reconnection attempts
148 pub fn with_reconnect_delay(mut self, delay: Duration) -> Self {
149 self.config.reconnect_delay = delay;
150 self
151 }
152
153 /// Build the WebSocketClient with the configured options
154 pub fn build(self) -> WebSocketClient {
155 WebSocketClient {
156 sender: None,
157 receiver: None,
158 ws_handle: None,
159 is_connected: false,
160 server_url: None,
161 cert_paths: None,
162 config: self.config,
163 cert_cache: Arc::new(Mutex::new(HashMap::new())),
164 }
165 }
166}
167
168pub struct WebSocketClient {
169 sender: Option<mpsc::Sender<MessageType>>,
170 receiver: Option<mpsc::Receiver<MessageType>>,
171 ws_handle: Option<JoinHandle<()>>,
172 is_connected: bool,
173 server_url: Option<Url>,
174 cert_paths: Option<(String, String, String, String, String)>,
175 config: WSClientConfig,
176 cert_cache: Arc<Mutex<HashMap<String, Arc<rustls::ClientConfig>>>>,
177}
178
179impl WebSocketClient {
180 /// Creates a new WebSocketClient instance with default configuration.
181 ///
182 /// The new client is initially disconnected. Use the `connect` method
183 /// to establish a connection to a WebSocket server.
184 ///
185 /// # Returns
186 ///
187 /// A new `WebSocketClient` instance.
188 pub fn new() -> Self {
189 Self::builder().build()
190 }
191
192 /// Creates a builder for configuring a WebSocketClient.
193 ///
194 /// # Returns
195 ///
196 /// A WebSocketClientBuilder instance.
197 pub fn builder() -> WebSocketClientBuilder {
198 WebSocketClientBuilder::new()
199 }
200
201 /// Loads certificates from a PEM file.
202 ///
203 /// # Parameters
204 ///
205 /// * `path` - Path to the certificate file
206 ///
207 /// # Returns
208 ///
209 /// A vector of certificates in DER format.
210 ///
211 /// # Panics
212 ///
213 /// Panics if the certificate file cannot be opened or parsed.
214 fn load_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error>> {
215 let file = File::open(path)?;
216 let mut reader = BufReader::new(file);
217 let certs = rustls_pemfile::certs(&mut reader).collect::<Result<Vec<_>, _>>()?;
218
219 if certs.is_empty() {
220 return Err("No certificates found in file".into());
221 }
222
223 Ok(certs)
224 }
225
226 /// Loads a private key from a PEM file.
227 ///
228 /// # Parameters
229 ///
230 /// * `path` - Path to the private key file
231 ///
232 /// # Returns
233 ///
234 /// The private key in DER format.
235 ///
236 /// # Errors
237 ///
238 /// Returns an error if the private key file cannot be opened, parsed, or if no keys are found.
239 fn load_private_key(path: &Path) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error>> {
240 let file = File::open(path)?;
241 let mut reader = BufReader::new(file);
242 let keys =
243 rustls_pemfile::pkcs8_private_keys(&mut reader).collect::<Result<Vec<_>, _>>()?;
244
245 if keys.is_empty() {
246 return Err("No private key found in file".into());
247 }
248
249 // Use the first private key
250 Ok(PrivateKeyDer::Pkcs8(keys.into_iter().next().unwrap()))
251 }
252
253 /// Creates a TLS client configuration from certificates and keys.
254 ///
255 /// This method attempts to reuse cached configurations when possible.
256 ///
257 /// # Parameters
258 ///
259 /// * `cache_key` - A unique key for caching the configuration
260 /// * `client_cert_path` - Path to client certificate
261 /// * `client_key_path` - Path to client private key
262 /// * `ca_cert_path` - Path to CA certificate
263 ///
264 /// # Returns
265 ///
266 /// A TLS client configuration or an error.
267 async fn create_tls_config(
268 &self,
269 cache_key: &str,
270 client_cert_path: &Path,
271 client_key_path: &Path,
272 ca_cert_path: &Path,
273 ) -> Result<Arc<rustls::ClientConfig>, Box<dyn std::error::Error>> {
274 // Check if we have a cached configuration
275 {
276 let cache = self.cert_cache.lock().await;
277 if let Some(config) = cache.get(cache_key) {
278 info!("Using cached TLS configuration");
279 return Ok(config.clone());
280 }
281 }
282
283 // Load certificates and keys
284 let client_certs = Self::load_certs(client_cert_path)?;
285 let client_key = Self::load_private_key(client_key_path)?;
286 let ca_certs = Self::load_certs(ca_cert_path)?;
287
288 // Create TLS configuration
289 let mut root_store = RootCertStore::empty();
290 for cert in ca_certs {
291 root_store.add(cert)?;
292 }
293
294 let client_config = rustls::ClientConfig::builder()
295 .with_root_certificates(root_store)
296 .with_client_auth_cert(client_certs, client_key)?;
297
298 let config = Arc::new(client_config);
299
300 // Cache the configuration
301 {
302 let mut cache = self.cert_cache.lock().await;
303 cache.insert(cache_key.to_string(), config.clone());
304 }
305
306 Ok(config)
307 }
308
309 /// Connects to a WebSocket server using TLS.
310 ///
311 /// This method establishes a secure WebSocket connection to the specified server URL
312 /// using the provided certificates and keys.
313 ///
314 /// # Parameters
315 ///
316 /// * `server_url` - The WebSocket server URL (e.g., "wss://example.com:9000")
317 /// * `cert_dir` - Directory containing the certificate files
318 /// * `client_cert_file` - Client certificate filename
319 /// * `client_key_file` - Client private key filename
320 /// * `ca_cert_file` - CA certificate filename
321 ///
322 /// # Returns
323 ///
324 /// `Ok(())` on successful connection, or an error if the connection fails.
325 ///
326 /// # Errors
327 ///
328 /// Returns an error if URL parsing fails, certificate loading fails, or connection fails.
329 pub async fn connect(
330 &mut self,
331 server_url: &str,
332 cert_dir: &str,
333 client_cert_file: &str,
334 client_key_file: &str,
335 ca_cert_file: &str,
336 ) -> Result<(), Box<dyn std::error::Error>> {
337 // Parse server URL
338 let server_url = Url::parse(server_url)?;
339
340 // Save connection parameters for potential reconnection
341 self.server_url = Some(server_url.clone());
342 self.cert_paths = Some((
343 server_url.to_string(),
344 cert_dir.to_string(),
345 client_cert_file.to_string(),
346 client_key_file.to_string(),
347 ca_cert_file.to_string(),
348 ));
349
350 // Handle connection with retries
351 let mut current_attempt = 0;
352 loop {
353 // Perform the connection attempt
354 let result = self
355 .connect_internal(
356 &server_url,
357 cert_dir,
358 client_cert_file,
359 client_key_file,
360 ca_cert_file,
361 current_attempt,
362 )
363 .await;
364
365 // Check if we need to retry based on the special error
366 match result {
367 Err(e) => {
368 let err_str = e.to_string();
369 if err_str.starts_with("__RETRY_CONNECTION_") {
370 // Parse the attempt number
371 if let Ok(next_attempt) = err_str
372 .trim_start_matches("__RETRY_CONNECTION_")
373 .parse::<u32>()
374 {
375 current_attempt = next_attempt;
376 // Wait before retrying
377 tokio::time::sleep(self.config.reconnect_delay).await;
378 continue;
379 }
380 }
381 return Err(e);
382 }
383 Ok(_) => return Ok(()),
384 }
385 }
386 }
387
388 /// Internal connect method that handles reconnection attempts
389 ///
390 /// This function uses manual reconnection logic instead of recursive calls
391 /// to avoid boxing issues with async functions.
392 async fn connect_internal(
393 &mut self,
394 server_url: &Url,
395 cert_dir: &str,
396 client_cert_file: &str,
397 client_key_file: &str,
398 ca_cert_file: &str,
399 attempt: u32,
400 ) -> Result<(), Box<dyn std::error::Error>> {
401 // Certificate paths
402 let cert_dir = Path::new(cert_dir);
403 let client_cert = cert_dir.join(client_cert_file);
404 let client_key = cert_dir.join(client_key_file);
405 let ca_cert = cert_dir.join(ca_cert_file);
406
407 info!("Client certificate: {:?}", client_cert);
408 info!("Client private key: {:?}", client_key);
409 info!("CA certificate: {:?}", ca_cert);
410
411 // Create a cache key for the TLS configuration
412 let cache_key = format!(
413 "{}:{}:{}:{}",
414 server_url,
415 client_cert.display(),
416 client_key.display(),
417 ca_cert.display()
418 );
419
420 info!("Loading certificates and keys...");
421 let tls_config = match self
422 .create_tls_config(&cache_key, &client_cert, &client_key, &ca_cert)
423 .await
424 {
425 Ok(config) => config,
426 Err(e) => {
427 error!("Failed to create TLS configuration: {}", e);
428 return Err(e);
429 }
430 };
431
432 // Create TLS connector
433 let connector = Connector::Rustls(tls_config);
434
435 // Connect to WebSocket server with timeout
436 info!("Connecting to WebSocket server: {}", server_url);
437 // Use timeout for connection attempt
438 let connection_attempt =
439 connect_async_tls_with_config(server_url.clone(), None, false, Some(connector));
440 let ws_stream = match timeout(self.config.connection_timeout, connection_attempt).await {
441 Ok(result) => {
442 match result {
443 Ok((stream, _)) => stream,
444 Err(e) => {
445 error!("Connection error: {}", e);
446
447 // Handle reconnection if enabled
448 if self.config.auto_reconnect
449 && attempt < self.config.max_reconnect_attempts
450 {
451 warn!(
452 "Reconnection attempt {}/{} in {}s",
453 attempt + 1,
454 self.config.max_reconnect_attempts,
455 self.config.reconnect_delay.as_secs()
456 );
457
458 // Wait before attempting to reconnect
459 tokio::time::sleep(self.config.reconnect_delay).await;
460
461 // Rather than making a recursive call, we'll return a special error
462 // that indicates we should retry the connection
463 return Err(format!("__RETRY_CONNECTION_{}", attempt + 1).into());
464 }
465
466 return Err(e.into());
467 }
468 }
469 }
470 Err(_) => {
471 let err = format!(
472 "Connection timeout after {:?}",
473 self.config.connection_timeout
474 );
475 error!("{}", err);
476
477 // Handle reconnection if enabled
478 if self.config.auto_reconnect && attempt < self.config.max_reconnect_attempts {
479 warn!(
480 "Reconnection attempt {}/{} in {}s",
481 attempt + 1,
482 self.config.max_reconnect_attempts,
483 self.config.reconnect_delay.as_secs()
484 );
485
486 // Wait before attempting to reconnect
487 tokio::time::sleep(self.config.reconnect_delay).await;
488
489 // Rather than making a recursive call, we'll return a special error
490 // that indicates we should retry the connection
491 return Err(format!("__RETRY_CONNECTION_{}", attempt + 1).into());
492 }
493
494 return Err(err.into());
495 }
496 };
497
498 info!("Connected to WebSocket server");
499
500 // Create channels for message passing with configured capacity
501 let (tx_sender, mut rx_sender) = mpsc::channel::<MessageType>(self.config.channel_capacity);
502 let (tx_receiver, rx_receiver) = mpsc::channel::<MessageType>(self.config.channel_capacity);
503
504 // Split connection into sender and receiver
505 let (mut ws_sender, mut ws_receiver) = ws_stream.split();
506
507 // Task for handling outgoing messages
508 let send_task = tokio::spawn(async move {
509 while let Some(message) = rx_sender.recv().await {
510 let ws_message = match message {
511 MessageType::Text(text) => Message::Text(text),
512 MessageType::Binary(data) => Message::Binary(data),
513 };
514
515 match ws_sender.send(ws_message).await {
516 Ok(_) => info!("Message sent"),
517 Err(e) => {
518 error!("Error sending message: {}", e);
519 break;
520 }
521 }
522 }
523 // Close WebSocket connection
524 let _ = ws_sender.close().await;
525 });
526
527 // Task for handling incoming messages
528 let receive_task = tokio::spawn(async move {
529 while let Some(msg) = ws_receiver.next().await {
530 match msg {
531 Ok(msg) => {
532 let message = match msg {
533 Message::Text(text) => {
534 info!("Received text message: {} bytes", text.len());
535 MessageType::Text(text)
536 }
537 Message::Binary(data) => {
538 info!("Received binary message: {} bytes", data.len());
539 MessageType::Binary(data)
540 }
541 Message::Ping(_) | Message::Pong(_) => {
542 // Handle ping/pong internally
543 continue;
544 }
545 Message::Close(_) => {
546 info!("Received close frame");
547 break;
548 }
549 // Handle other message types if needed
550 _ => continue,
551 };
552
553 if let Err(e) = tx_receiver.send(message).await {
554 error!("Error forwarding to receiver channel: {}", e);
555 break;
556 }
557 }
558 Err(e) => {
559 error!("Error receiving message: {}", e);
560 break;
561 }
562 }
563 }
564 });
565
566 // Combine tasks with select to handle termination
567 let handle = tokio::spawn(async move {
568 tokio::select! {
569 _ = send_task => info!("Send task completed"),
570 _ = receive_task => info!("Receive task completed"),
571 }
572 });
573
574 // Update client state
575 self.sender = Some(tx_sender);
576 self.receiver = Some(rx_receiver);
577 self.ws_handle = Some(handle);
578 self.is_connected = true;
579
580 Ok(())
581 }
582
583 /// Reconnects to the WebSocket server using the last connection parameters.
584 ///
585 /// # Returns
586 ///
587 /// `Ok(())` on successful reconnection, or an error if reconnection fails.
588 ///
589 /// # Errors
590 ///
591 /// Returns an error if no previous connection exists or if reconnection fails.
592 pub async fn reconnect(&mut self) -> Result<(), Box<dyn std::error::Error>> {
593 if let Some((url, cert_dir, client_cert, client_key, ca_cert)) = self.cert_paths.clone() {
594 // Close existing connection if any
595 if self.is_connected {
596 self.close().await;
597 }
598
599 // Connect using saved parameters
600 self.connect(&url, &cert_dir, &client_cert, &client_key, &ca_cert)
601 .await
602 } else {
603 Err("No previous connection parameters available for reconnection".into())
604 }
605 }
606
607 /// Sends a message to the connected WebSocket server.
608 ///
609 /// # Parameters
610 ///
611 /// * `message` - The message to send (text or binary)
612 ///
613 /// # Returns
614 ///
615 /// `Ok(())` if the message was queued for sending, or an error if not connected.
616 ///
617 /// # Errors
618 ///
619 /// Returns an error if the client is not connected or if the message cannot be sent.
620 pub async fn send_message(
621 &self,
622 message: MessageType,
623 ) -> Result<(), Box<dyn std::error::Error>> {
624 if let Some(sender) = &self.sender {
625 sender.send(message).await?;
626 Ok(())
627 } else {
628 Err("Not connected to WebSocket server".into())
629 }
630 }
631
632 /// Sends a text message to the connected WebSocket server.
633 ///
634 /// This is a convenience method that wraps send_message.
635 ///
636 /// # Parameters
637 ///
638 /// * `text` - The text message to send
639 ///
640 /// # Returns
641 ///
642 /// `Ok(())` if the message was queued for sending, or an error if not connected.
643 pub async fn send_text(&self, text: String) -> Result<(), Box<dyn std::error::Error>> {
644 self.send_message(MessageType::Text(text)).await
645 }
646
647 /// Sends a binary message to the connected WebSocket server.
648 ///
649 /// This is a convenience method that wraps send_message.
650 ///
651 /// # Parameters
652 ///
653 /// * `data` - The binary data to send
654 ///
655 /// # Returns
656 ///
657 /// `Ok(())` if the message was queued for sending, or an error if not connected.
658 pub async fn send_binary(&self, data: Vec<u8>) -> Result<(), Box<dyn std::error::Error>> {
659 self.send_message(MessageType::Binary(data)).await
660 }
661
662 /// Receives a message from the WebSocket server.
663 ///
664 /// This method waits for the next message from the server. If no message
665 /// is available or the connection is closed, it returns `None`.
666 ///
667 /// # Returns
668 ///
669 /// * `Some(MessageType)` - The received message (text or binary)
670 /// * `None` - If not connected or the connection was closed
671 pub async fn receive_message(&mut self) -> Option<MessageType> {
672 if let Some(receiver) = &mut self.receiver {
673 receiver.recv().await
674 } else {
675 None
676 }
677 }
678
679 /// Receives a message with timeout.
680 ///
681 /// This method waits for the next message from the server with a timeout.
682 ///
683 /// # Parameters
684 ///
685 /// * `timeout_duration` - Maximum time to wait for a message
686 ///
687 /// # Returns
688 ///
689 /// * `Ok(Some(MessageType))` - A message was received
690 /// * `Ok(None)` - No message received (not connected)
691 /// * `Err(_)` - Timeout occurred
692 pub async fn receive_message_timeout(
693 &mut self,
694 timeout_duration: Duration,
695 ) -> Result<Option<MessageType>, tokio::time::error::Elapsed> {
696 if let Some(receiver) = &mut self.receiver {
697 timeout(timeout_duration, receiver.recv()).await
698 } else {
699 Ok(None)
700 }
701 }
702
703 /// Checks if the client is connected to a WebSocket server.
704 ///
705 /// # Returns
706 ///
707 /// `true` if connected, `false` otherwise.
708 pub fn is_connected(&self) -> bool {
709 self.is_connected
710 }
711
712 /// Closes the WebSocket connection.
713 ///
714 /// This method gracefully shuts down the connection by:
715 /// 1. Dropping the sender channel to trigger closing the WebSocket
716 /// 2. Waiting for the worker task to complete
717 /// 3. Cleaning up resources
718 ///
719 /// The client can be reconnected after closing by calling `connect()` again.
720 pub async fn close(&mut self) {
721 // Drop the sender channel to trigger close operation
722 self.sender = None;
723
724 // Wait for the main task to complete
725 if let Some(handle) = self.ws_handle.take() {
726 let _ = handle.await;
727 }
728
729 self.receiver = None;
730 self.is_connected = false;
731
732 info!("WebSocket connection closed");
733 }
734
735 /// Sends a ping message to check connection health.
736 ///
737 /// This method can be used to keep the connection alive or
738 /// check if the server is still responsive.
739 ///
740 /// # Returns
741 ///
742 /// `Ok(())` if the ping was sent, or an error if not connected.
743 pub async fn ping(&self) -> Result<(), Box<dyn std::error::Error>> {
744 if let Some(sender) = &self.sender {
745 // Use an empty binary message as a ping
746 sender.send(MessageType::Binary(Vec::new())).await?;
747 Ok(())
748 } else {
749 Err("Not connected to WebSocket server".into())
750 }
751 }
752
753 /// Clears the certificate cache.
754 ///
755 /// This method can be useful to force reloading of certificates
756 /// if they have been updated on disk.
757 pub async fn clear_cert_cache(&self) {
758 let mut cache = self.cert_cache.lock().await;
759 cache.clear();
760 info!("Certificate cache cleared");
761 }
762
763 /// Checks if a connection is active and sends a ping to verify connectivity.
764 ///
765 /// Returns true if the connection is active and responsive.
766 pub async fn check_connection(&self) -> bool {
767 if !self.is_connected {
768 return false;
769 }
770
771 match self.ping().await {
772 Ok(_) => true,
773 Err(_) => false,
774 }
775 }
776
777 /// Gets the current configuration.
778 ///
779 /// # Returns
780 ///
781 /// A reference to the current client configuration.
782 pub fn get_config(&self) -> &WSClientConfig {
783 &self.config
784 }
785}
786
787impl Drop for WebSocketClient {
788 fn drop(&mut self) {
789 // If the client is still connected when going out of scope,
790 // drop all channels to allow resources to be cleaned up
791 self.sender = None;
792 self.receiver = None;
793
794 // Drop the task handle, allowing it to complete on its own
795 self.ws_handle = None;
796 }
797}
798
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_type() {
        let text = MessageType::Text("hello".to_string());
        let binary = MessageType::Binary(vec![1, 2, 3]);

        match text {
            MessageType::Text(s) => assert_eq!(s, "hello"),
            _ => panic!("Expected Text variant"),
        }

        match binary {
            MessageType::Binary(b) => assert_eq!(b, vec![1, 2, 3]),
            _ => panic!("Expected Binary variant"),
        }
    }

    #[test]
    fn test_client_config_default() {
        let config = WSClientConfig::default();
        assert_eq!(config.channel_capacity, 100);
        assert_eq!(config.connection_timeout, Duration::from_secs(30));
        assert!(!config.auto_reconnect);
        assert_eq!(config.max_reconnect_attempts, 5);
        assert_eq!(config.reconnect_delay, Duration::from_secs(2));
    }

    #[test]
    fn test_client_builder() {
        let client = WebSocketClient::builder()
            .with_channel_capacity(200)
            .with_connection_timeout(Duration::from_secs(10))
            .with_auto_reconnect(true)
            .with_max_reconnect_attempts(3)
            .with_reconnect_delay(Duration::from_millis(250))
            .build();

        let config = client.get_config();
        assert_eq!(config.channel_capacity, 200);
        assert_eq!(config.connection_timeout, Duration::from_secs(10));
        assert!(config.auto_reconnect);
        assert_eq!(config.max_reconnect_attempts, 3);
        assert_eq!(config.reconnect_delay, Duration::from_millis(250));
    }

    #[test]
    fn test_new_client_is_disconnected_with_default_config() {
        let client = WebSocketClient::new();
        let defaults = WSClientConfig::default();

        assert!(!client.is_connected());
        assert_eq!(client.get_config().channel_capacity, defaults.channel_capacity);
        assert_eq!(client.get_config().connection_timeout, defaults.connection_timeout);
        assert_eq!(client.get_config().auto_reconnect, defaults.auto_reconnect);
    }
}