ws_rs/
client.rs

1use std::fs::File;
2use std::io::BufReader;
3use std::path::Path;
4use std::sync::Arc;
5use std::time::Duration;
6
7use futures_util::{SinkExt, StreamExt};
8use log::{error, info, warn};
9use rustls::RootCertStore;
10use rustls::pki_types::{CertificateDer, PrivateKeyDer};
11use std::collections::HashMap;
12use tokio::sync::{Mutex, mpsc};
13use tokio::task::JoinHandle;
14use tokio::time::timeout;
15use tokio_tungstenite::tungstenite::Message;
16use tokio_tungstenite::{Connector, connect_async_tls_with_config};
17use url::Url;
18
19/// WebSocket client structure for handling secure WebSocket connections.
20///
21/// This client supports TLS/SSL secure connections and provides a simple interface
22/// for sending and receiving messages. It is optimized for performance with features like:
23/// - Binary message support
24/// - Connection timeout handling
25/// - Certificate caching
26/// - Auto-reconnection capabilities
27/// - Optimized memory usage
28///
29/// # Example
30///
31/// ```ignore
32/// use ws_rs::client::WebSocketClient;
33/// use ws_rs::client::MessageType;
34/// use std::time::Duration;
35///
36/// #[tokio::main]
37/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
38///     // Create a client with custom configuration
39///     let mut client = WebSocketClient::builder()
40///         .with_channel_capacity(200)
41///         .with_connection_timeout(Duration::from_secs(10))
42///         .with_auto_reconnect(true)
43///         .build();
44///
45///     // Connect to a WebSocket server
46///     client.connect(
47///         "wss://127.0.0.1:9000",
48///         "./certs",
49///         "client_cert.pem",
50///         "client_key.pem",
51///         "ca_cert.pem"
52///     ).await?;
53///
54///     // Send a text message
55///     client.send_message(MessageType::Text("Hello, server!".to_string())).await?;
56///
57///     // Send a binary message
58///     client.send_message(MessageType::Binary(vec![1, 2, 3, 4])).await?;
59///
60///     // Receive a message
61///     if let Some(response) = client.receive_message().await {
62///         match response {
63///             MessageType::Text(text) => println!("Received text: {}", text),
64///             MessageType::Binary(data) => println!("Received binary data: {} bytes", data.len()),
65///         }
66///     }
67///
68///     // Close the connection
69///     client.close().await;
70///
71///     Ok(())
72/// }
73/// ```
74/// Message type enum for WebSocket communication
75#[derive(Debug, Clone)]
76pub enum MessageType {
77    /// Text message
78    Text(String),
79    /// Binary message
80    Binary(Vec<u8>),
81}
82
83/// Configuration for WebSocketClient
84#[derive(Debug, Clone)]
85pub struct WSClientConfig {
86    /// Channel capacity for message queues
87    pub channel_capacity: usize,
88    /// Connection timeout in seconds
89    pub connection_timeout: Duration,
90    /// Whether to automatically reconnect on connection failure
91    pub auto_reconnect: bool,
92    /// Maximum reconnection attempts
93    pub max_reconnect_attempts: u32,
94    /// Delay between reconnection attempts
95    pub reconnect_delay: Duration,
96}
97
98impl Default for WSClientConfig {
99    fn default() -> Self {
100        Self {
101            channel_capacity: 100,
102            connection_timeout: Duration::from_secs(30),
103            auto_reconnect: false,
104            max_reconnect_attempts: 5,
105            reconnect_delay: Duration::from_secs(2),
106        }
107    }
108}
109
110/// Builder for WebSocketClient
111pub struct WebSocketClientBuilder {
112    config: WSClientConfig,
113}
114
115impl WebSocketClientBuilder {
116    /// Create a new builder with default configuration
117    pub fn new() -> Self {
118        Self {
119            config: WSClientConfig::default(),
120        }
121    }
122
123    /// Set channel capacity
124    pub fn with_channel_capacity(mut self, capacity: usize) -> Self {
125        self.config.channel_capacity = capacity;
126        self
127    }
128
129    /// Set connection timeout
130    pub fn with_connection_timeout(mut self, timeout: Duration) -> Self {
131        self.config.connection_timeout = timeout;
132        self
133    }
134
135    /// Enable or disable auto-reconnect
136    pub fn with_auto_reconnect(mut self, auto_reconnect: bool) -> Self {
137        self.config.auto_reconnect = auto_reconnect;
138        self
139    }
140
141    /// Set maximum reconnection attempts
142    pub fn with_max_reconnect_attempts(mut self, attempts: u32) -> Self {
143        self.config.max_reconnect_attempts = attempts;
144        self
145    }
146
147    /// Set delay between reconnection attempts
148    pub fn with_reconnect_delay(mut self, delay: Duration) -> Self {
149        self.config.reconnect_delay = delay;
150        self
151    }
152
153    /// Build the WebSocketClient with the configured options
154    pub fn build(self) -> WebSocketClient {
155        WebSocketClient {
156            sender: None,
157            receiver: None,
158            ws_handle: None,
159            is_connected: false,
160            server_url: None,
161            cert_paths: None,
162            config: self.config,
163            cert_cache: Arc::new(Mutex::new(HashMap::new())),
164        }
165    }
166}
167
168pub struct WebSocketClient {
169    sender: Option<mpsc::Sender<MessageType>>,
170    receiver: Option<mpsc::Receiver<MessageType>>,
171    ws_handle: Option<JoinHandle<()>>,
172    is_connected: bool,
173    server_url: Option<Url>,
174    cert_paths: Option<(String, String, String, String, String)>,
175    config: WSClientConfig,
176    cert_cache: Arc<Mutex<HashMap<String, Arc<rustls::ClientConfig>>>>,
177}
178
179impl WebSocketClient {
180    /// Creates a new WebSocketClient instance with default configuration.
181    ///
182    /// The new client is initially disconnected. Use the `connect` method
183    /// to establish a connection to a WebSocket server.
184    ///
185    /// # Returns
186    ///
187    /// A new `WebSocketClient` instance.
188    pub fn new() -> Self {
189        Self::builder().build()
190    }
191
192    /// Creates a builder for configuring a WebSocketClient.
193    ///
194    /// # Returns
195    ///
196    /// A WebSocketClientBuilder instance.
197    pub fn builder() -> WebSocketClientBuilder {
198        WebSocketClientBuilder::new()
199    }
200
201    /// Loads certificates from a PEM file.
202    ///
203    /// # Parameters
204    ///
205    /// * `path` - Path to the certificate file
206    ///
207    /// # Returns
208    ///
209    /// A vector of certificates in DER format.
210    ///
211    /// # Panics
212    ///
213    /// Panics if the certificate file cannot be opened or parsed.
214    fn load_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error>> {
215        let file = File::open(path)?;
216        let mut reader = BufReader::new(file);
217        let certs = rustls_pemfile::certs(&mut reader).collect::<Result<Vec<_>, _>>()?;
218
219        if certs.is_empty() {
220            return Err("No certificates found in file".into());
221        }
222
223        Ok(certs)
224    }
225
226    /// Loads a private key from a PEM file.
227    ///
228    /// # Parameters
229    ///
230    /// * `path` - Path to the private key file
231    ///
232    /// # Returns
233    ///
234    /// The private key in DER format.
235    ///
236    /// # Errors
237    ///
238    /// Returns an error if the private key file cannot be opened, parsed, or if no keys are found.
239    fn load_private_key(path: &Path) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error>> {
240        let file = File::open(path)?;
241        let mut reader = BufReader::new(file);
242        let keys =
243            rustls_pemfile::pkcs8_private_keys(&mut reader).collect::<Result<Vec<_>, _>>()?;
244
245        if keys.is_empty() {
246            return Err("No private key found in file".into());
247        }
248
249        // Use the first private key
250        Ok(PrivateKeyDer::Pkcs8(keys.into_iter().next().unwrap()))
251    }
252
253    /// Creates a TLS client configuration from certificates and keys.
254    ///
255    /// This method attempts to reuse cached configurations when possible.
256    ///
257    /// # Parameters
258    ///
259    /// * `cache_key` - A unique key for caching the configuration
260    /// * `client_cert_path` - Path to client certificate
261    /// * `client_key_path` - Path to client private key
262    /// * `ca_cert_path` - Path to CA certificate
263    ///
264    /// # Returns
265    ///
266    /// A TLS client configuration or an error.
267    async fn create_tls_config(
268        &self,
269        cache_key: &str,
270        client_cert_path: &Path,
271        client_key_path: &Path,
272        ca_cert_path: &Path,
273    ) -> Result<Arc<rustls::ClientConfig>, Box<dyn std::error::Error>> {
274        // Check if we have a cached configuration
275        {
276            let cache = self.cert_cache.lock().await;
277            if let Some(config) = cache.get(cache_key) {
278                info!("Using cached TLS configuration");
279                return Ok(config.clone());
280            }
281        }
282
283        // Load certificates and keys
284        let client_certs = Self::load_certs(client_cert_path)?;
285        let client_key = Self::load_private_key(client_key_path)?;
286        let ca_certs = Self::load_certs(ca_cert_path)?;
287
288        // Create TLS configuration
289        let mut root_store = RootCertStore::empty();
290        for cert in ca_certs {
291            root_store.add(cert)?;
292        }
293
294        let client_config = rustls::ClientConfig::builder()
295            .with_root_certificates(root_store)
296            .with_client_auth_cert(client_certs, client_key)?;
297
298        let config = Arc::new(client_config);
299
300        // Cache the configuration
301        {
302            let mut cache = self.cert_cache.lock().await;
303            cache.insert(cache_key.to_string(), config.clone());
304        }
305
306        Ok(config)
307    }
308
309    /// Connects to a WebSocket server using TLS.
310    ///
311    /// This method establishes a secure WebSocket connection to the specified server URL
312    /// using the provided certificates and keys.
313    ///
314    /// # Parameters
315    ///
316    /// * `server_url` - The WebSocket server URL (e.g., "wss://example.com:9000")
317    /// * `cert_dir` - Directory containing the certificate files
318    /// * `client_cert_file` - Client certificate filename
319    /// * `client_key_file` - Client private key filename
320    /// * `ca_cert_file` - CA certificate filename
321    ///
322    /// # Returns
323    ///
324    /// `Ok(())` on successful connection, or an error if the connection fails.
325    ///
326    /// # Errors
327    ///
328    /// Returns an error if URL parsing fails, certificate loading fails, or connection fails.
329    pub async fn connect(
330        &mut self,
331        server_url: &str,
332        cert_dir: &str,
333        client_cert_file: &str,
334        client_key_file: &str,
335        ca_cert_file: &str,
336    ) -> Result<(), Box<dyn std::error::Error>> {
337        // Parse server URL
338        let server_url = Url::parse(server_url)?;
339
340        // Save connection parameters for potential reconnection
341        self.server_url = Some(server_url.clone());
342        self.cert_paths = Some((
343            server_url.to_string(),
344            cert_dir.to_string(),
345            client_cert_file.to_string(),
346            client_key_file.to_string(),
347            ca_cert_file.to_string(),
348        ));
349
350        // Handle connection with retries
351        let mut current_attempt = 0;
352        loop {
353            // Perform the connection attempt
354            let result = self
355                .connect_internal(
356                    &server_url,
357                    cert_dir,
358                    client_cert_file,
359                    client_key_file,
360                    ca_cert_file,
361                    current_attempt,
362                )
363                .await;
364
365            // Check if we need to retry based on the special error
366            match result {
367                Err(e) => {
368                    let err_str = e.to_string();
369                    if err_str.starts_with("__RETRY_CONNECTION_") {
370                        // Parse the attempt number
371                        if let Ok(next_attempt) = err_str
372                            .trim_start_matches("__RETRY_CONNECTION_")
373                            .parse::<u32>()
374                        {
375                            current_attempt = next_attempt;
376                            // Wait before retrying
377                            tokio::time::sleep(self.config.reconnect_delay).await;
378                            continue;
379                        }
380                    }
381                    return Err(e);
382                }
383                Ok(_) => return Ok(()),
384            }
385        }
386    }
387
388    /// Internal connect method that handles reconnection attempts
389    ///
390    /// This function uses manual reconnection logic instead of recursive calls
391    /// to avoid boxing issues with async functions.
392    async fn connect_internal(
393        &mut self,
394        server_url: &Url,
395        cert_dir: &str,
396        client_cert_file: &str,
397        client_key_file: &str,
398        ca_cert_file: &str,
399        attempt: u32,
400    ) -> Result<(), Box<dyn std::error::Error>> {
401        // Certificate paths
402        let cert_dir = Path::new(cert_dir);
403        let client_cert = cert_dir.join(client_cert_file);
404        let client_key = cert_dir.join(client_key_file);
405        let ca_cert = cert_dir.join(ca_cert_file);
406
407        info!("Client certificate: {:?}", client_cert);
408        info!("Client private key: {:?}", client_key);
409        info!("CA certificate: {:?}", ca_cert);
410
411        // Create a cache key for the TLS configuration
412        let cache_key = format!(
413            "{}:{}:{}:{}",
414            server_url,
415            client_cert.display(),
416            client_key.display(),
417            ca_cert.display()
418        );
419
420        info!("Loading certificates and keys...");
421        let tls_config = match self
422            .create_tls_config(&cache_key, &client_cert, &client_key, &ca_cert)
423            .await
424        {
425            Ok(config) => config,
426            Err(e) => {
427                error!("Failed to create TLS configuration: {}", e);
428                return Err(e);
429            }
430        };
431
432        // Create TLS connector
433        let connector = Connector::Rustls(tls_config);
434
435        // Connect to WebSocket server with timeout
436        info!("Connecting to WebSocket server: {}", server_url);
437        // Use timeout for connection attempt
438        let connection_attempt =
439            connect_async_tls_with_config(server_url.clone(), None, false, Some(connector));
440        let ws_stream = match timeout(self.config.connection_timeout, connection_attempt).await {
441            Ok(result) => {
442                match result {
443                    Ok((stream, _)) => stream,
444                    Err(e) => {
445                        error!("Connection error: {}", e);
446
447                        // Handle reconnection if enabled
448                        if self.config.auto_reconnect
449                            && attempt < self.config.max_reconnect_attempts
450                        {
451                            warn!(
452                                "Reconnection attempt {}/{} in {}s",
453                                attempt + 1,
454                                self.config.max_reconnect_attempts,
455                                self.config.reconnect_delay.as_secs()
456                            );
457
458                            // Wait before attempting to reconnect
459                            tokio::time::sleep(self.config.reconnect_delay).await;
460
461                            // Rather than making a recursive call, we'll return a special error
462                            // that indicates we should retry the connection
463                            return Err(format!("__RETRY_CONNECTION_{}", attempt + 1).into());
464                        }
465
466                        return Err(e.into());
467                    }
468                }
469            }
470            Err(_) => {
471                let err = format!(
472                    "Connection timeout after {:?}",
473                    self.config.connection_timeout
474                );
475                error!("{}", err);
476
477                // Handle reconnection if enabled
478                if self.config.auto_reconnect && attempt < self.config.max_reconnect_attempts {
479                    warn!(
480                        "Reconnection attempt {}/{} in {}s",
481                        attempt + 1,
482                        self.config.max_reconnect_attempts,
483                        self.config.reconnect_delay.as_secs()
484                    );
485
486                    // Wait before attempting to reconnect
487                    tokio::time::sleep(self.config.reconnect_delay).await;
488
489                    // Rather than making a recursive call, we'll return a special error
490                    // that indicates we should retry the connection
491                    return Err(format!("__RETRY_CONNECTION_{}", attempt + 1).into());
492                }
493
494                return Err(err.into());
495            }
496        };
497
498        info!("Connected to WebSocket server");
499
500        // Create channels for message passing with configured capacity
501        let (tx_sender, mut rx_sender) = mpsc::channel::<MessageType>(self.config.channel_capacity);
502        let (tx_receiver, rx_receiver) = mpsc::channel::<MessageType>(self.config.channel_capacity);
503
504        // Split connection into sender and receiver
505        let (mut ws_sender, mut ws_receiver) = ws_stream.split();
506
507        // Task for handling outgoing messages
508        let send_task = tokio::spawn(async move {
509            while let Some(message) = rx_sender.recv().await {
510                let ws_message = match message {
511                    MessageType::Text(text) => Message::Text(text),
512                    MessageType::Binary(data) => Message::Binary(data),
513                };
514
515                match ws_sender.send(ws_message).await {
516                    Ok(_) => info!("Message sent"),
517                    Err(e) => {
518                        error!("Error sending message: {}", e);
519                        break;
520                    }
521                }
522            }
523            // Close WebSocket connection
524            let _ = ws_sender.close().await;
525        });
526
527        // Task for handling incoming messages
528        let receive_task = tokio::spawn(async move {
529            while let Some(msg) = ws_receiver.next().await {
530                match msg {
531                    Ok(msg) => {
532                        let message = match msg {
533                            Message::Text(text) => {
534                                info!("Received text message: {} bytes", text.len());
535                                MessageType::Text(text)
536                            }
537                            Message::Binary(data) => {
538                                info!("Received binary message: {} bytes", data.len());
539                                MessageType::Binary(data)
540                            }
541                            Message::Ping(_) | Message::Pong(_) => {
542                                // Handle ping/pong internally
543                                continue;
544                            }
545                            Message::Close(_) => {
546                                info!("Received close frame");
547                                break;
548                            }
549                            // Handle other message types if needed
550                            _ => continue,
551                        };
552
553                        if let Err(e) = tx_receiver.send(message).await {
554                            error!("Error forwarding to receiver channel: {}", e);
555                            break;
556                        }
557                    }
558                    Err(e) => {
559                        error!("Error receiving message: {}", e);
560                        break;
561                    }
562                }
563            }
564        });
565
566        // Combine tasks with select to handle termination
567        let handle = tokio::spawn(async move {
568            tokio::select! {
569                _ = send_task => info!("Send task completed"),
570                _ = receive_task => info!("Receive task completed"),
571            }
572        });
573
574        // Update client state
575        self.sender = Some(tx_sender);
576        self.receiver = Some(rx_receiver);
577        self.ws_handle = Some(handle);
578        self.is_connected = true;
579
580        Ok(())
581    }
582
583    /// Reconnects to the WebSocket server using the last connection parameters.
584    ///
585    /// # Returns
586    ///
587    /// `Ok(())` on successful reconnection, or an error if reconnection fails.
588    ///
589    /// # Errors
590    ///
591    /// Returns an error if no previous connection exists or if reconnection fails.
592    pub async fn reconnect(&mut self) -> Result<(), Box<dyn std::error::Error>> {
593        if let Some((url, cert_dir, client_cert, client_key, ca_cert)) = self.cert_paths.clone() {
594            // Close existing connection if any
595            if self.is_connected {
596                self.close().await;
597            }
598
599            // Connect using saved parameters
600            self.connect(&url, &cert_dir, &client_cert, &client_key, &ca_cert)
601                .await
602        } else {
603            Err("No previous connection parameters available for reconnection".into())
604        }
605    }
606
607    /// Sends a message to the connected WebSocket server.
608    ///
609    /// # Parameters
610    ///
611    /// * `message` - The message to send (text or binary)
612    ///
613    /// # Returns
614    ///
615    /// `Ok(())` if the message was queued for sending, or an error if not connected.
616    ///
617    /// # Errors
618    ///
619    /// Returns an error if the client is not connected or if the message cannot be sent.
620    pub async fn send_message(
621        &self,
622        message: MessageType,
623    ) -> Result<(), Box<dyn std::error::Error>> {
624        if let Some(sender) = &self.sender {
625            sender.send(message).await?;
626            Ok(())
627        } else {
628            Err("Not connected to WebSocket server".into())
629        }
630    }
631
632    /// Sends a text message to the connected WebSocket server.
633    ///
634    /// This is a convenience method that wraps send_message.
635    ///
636    /// # Parameters
637    ///
638    /// * `text` - The text message to send
639    ///
640    /// # Returns
641    ///
642    /// `Ok(())` if the message was queued for sending, or an error if not connected.
643    pub async fn send_text(&self, text: String) -> Result<(), Box<dyn std::error::Error>> {
644        self.send_message(MessageType::Text(text)).await
645    }
646
647    /// Sends a binary message to the connected WebSocket server.
648    ///
649    /// This is a convenience method that wraps send_message.
650    ///
651    /// # Parameters
652    ///
653    /// * `data` - The binary data to send
654    ///
655    /// # Returns
656    ///
657    /// `Ok(())` if the message was queued for sending, or an error if not connected.
658    pub async fn send_binary(&self, data: Vec<u8>) -> Result<(), Box<dyn std::error::Error>> {
659        self.send_message(MessageType::Binary(data)).await
660    }
661
662    /// Receives a message from the WebSocket server.
663    ///
664    /// This method waits for the next message from the server. If no message
665    /// is available or the connection is closed, it returns `None`.
666    ///
667    /// # Returns
668    ///
669    /// * `Some(MessageType)` - The received message (text or binary)
670    /// * `None` - If not connected or the connection was closed
671    pub async fn receive_message(&mut self) -> Option<MessageType> {
672        if let Some(receiver) = &mut self.receiver {
673            receiver.recv().await
674        } else {
675            None
676        }
677    }
678
679    /// Receives a message with timeout.
680    ///
681    /// This method waits for the next message from the server with a timeout.
682    ///
683    /// # Parameters
684    ///
685    /// * `timeout_duration` - Maximum time to wait for a message
686    ///
687    /// # Returns
688    ///
689    /// * `Ok(Some(MessageType))` - A message was received
690    /// * `Ok(None)` - No message received (not connected)
691    /// * `Err(_)` - Timeout occurred
692    pub async fn receive_message_timeout(
693        &mut self,
694        timeout_duration: Duration,
695    ) -> Result<Option<MessageType>, tokio::time::error::Elapsed> {
696        if let Some(receiver) = &mut self.receiver {
697            timeout(timeout_duration, receiver.recv()).await
698        } else {
699            Ok(None)
700        }
701    }
702
703    /// Checks if the client is connected to a WebSocket server.
704    ///
705    /// # Returns
706    ///
707    /// `true` if connected, `false` otherwise.
708    pub fn is_connected(&self) -> bool {
709        self.is_connected
710    }
711
712    /// Closes the WebSocket connection.
713    ///
714    /// This method gracefully shuts down the connection by:
715    /// 1. Dropping the sender channel to trigger closing the WebSocket
716    /// 2. Waiting for the worker task to complete
717    /// 3. Cleaning up resources
718    ///
719    /// The client can be reconnected after closing by calling `connect()` again.
720    pub async fn close(&mut self) {
721        // Drop the sender channel to trigger close operation
722        self.sender = None;
723
724        // Wait for the main task to complete
725        if let Some(handle) = self.ws_handle.take() {
726            let _ = handle.await;
727        }
728
729        self.receiver = None;
730        self.is_connected = false;
731
732        info!("WebSocket connection closed");
733    }
734
735    /// Sends a ping message to check connection health.
736    ///
737    /// This method can be used to keep the connection alive or
738    /// check if the server is still responsive.
739    ///
740    /// # Returns
741    ///
742    /// `Ok(())` if the ping was sent, or an error if not connected.
743    pub async fn ping(&self) -> Result<(), Box<dyn std::error::Error>> {
744        if let Some(sender) = &self.sender {
745            // Use an empty binary message as a ping
746            sender.send(MessageType::Binary(Vec::new())).await?;
747            Ok(())
748        } else {
749            Err("Not connected to WebSocket server".into())
750        }
751    }
752
753    /// Clears the certificate cache.
754    ///
755    /// This method can be useful to force reloading of certificates
756    /// if they have been updated on disk.
757    pub async fn clear_cert_cache(&self) {
758        let mut cache = self.cert_cache.lock().await;
759        cache.clear();
760        info!("Certificate cache cleared");
761    }
762
763    /// Checks if a connection is active and sends a ping to verify connectivity.
764    ///
765    /// Returns true if the connection is active and responsive.
766    pub async fn check_connection(&self) -> bool {
767        if !self.is_connected {
768            return false;
769        }
770
771        match self.ping().await {
772            Ok(_) => true,
773            Err(_) => false,
774        }
775    }
776
777    /// Gets the current configuration.
778    ///
779    /// # Returns
780    ///
781    /// A reference to the current client configuration.
782    pub fn get_config(&self) -> &WSClientConfig {
783        &self.config
784    }
785}
786
787impl Drop for WebSocketClient {
788    fn drop(&mut self) {
789        // If the client is still connected when going out of scope,
790        // drop all channels to allow resources to be cleaned up
791        self.sender = None;
792        self.receiver = None;
793
794        // Drop the task handle, allowing it to complete on its own
795        self.ws_handle = None;
796    }
797}
798
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_type() {
        let text = MessageType::Text("hello".to_string());
        let binary = MessageType::Binary(vec![1, 2, 3]);

        match text {
            MessageType::Text(s) => assert_eq!(s, "hello"),
            _ => panic!("Expected Text variant"),
        }

        match binary {
            MessageType::Binary(b) => assert_eq!(b, vec![1, 2, 3]),
            _ => panic!("Expected Binary variant"),
        }
    }

    #[test]
    fn test_client_config_default() {
        let config = WSClientConfig::default();
        assert_eq!(config.channel_capacity, 100);
        assert_eq!(config.connection_timeout, Duration::from_secs(30));
        assert!(!config.auto_reconnect);
        assert_eq!(config.max_reconnect_attempts, 5);
        assert_eq!(config.reconnect_delay, Duration::from_secs(2));
    }

    #[test]
    fn test_client_builder() {
        let client = WebSocketClient::builder()
            .with_channel_capacity(200)
            .with_connection_timeout(Duration::from_secs(10))
            .with_auto_reconnect(true)
            .with_max_reconnect_attempts(3)
            .with_reconnect_delay(Duration::from_millis(250))
            .build();

        let config = client.get_config();
        assert_eq!(config.channel_capacity, 200);
        assert_eq!(config.connection_timeout, Duration::from_secs(10));
        assert!(config.auto_reconnect);
        assert_eq!(config.max_reconnect_attempts, 3);
        assert_eq!(config.reconnect_delay, Duration::from_millis(250));
        assert!(!client.is_connected());
    }

    #[test]
    fn test_operations_fail_when_disconnected() {
        // None of these calls touch timers or sockets, so a bare runtime suffices.
        let runtime = tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("failed to build test runtime");

        runtime.block_on(async {
            let mut client = WebSocketClient::new();
            assert!(client.send_text("hello".to_string()).await.is_err());
            assert!(client.send_binary(vec![1]).await.is_err());
            assert!(client.ping().await.is_err());
            assert!(!client.check_connection().await);
            assert!(client.receive_message().await.is_none());
            assert!(client.reconnect().await.is_err());
        });
    }
}