Skip to main content

rust_serv/server/
websocket.rs

//! WebSocket server implementation.
//!
//! This module implements WebSocket functionality:
//! - WebSocket protocol handshake (RFC 6455 `Sec-WebSocket-Accept` derivation)
//! - Message handling and conversion (frame parsing itself is done by `tungstenite`)
//! - Connection management
//! - Message broadcasting
//! - Heartbeat and ping/pong message handling
9
10use crate::config::Config;
11use crate::error::Error;
12use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
13use hyper::header::{CONNECTION, UPGRADE, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION};
14use hyper::{Request, Response, StatusCode, Method};
15use http_body_util::Full;
16use hyper::body::Bytes;
17use sha1::{Digest, Sha1};
18use std::collections::HashMap;
19use std::sync::Arc;
20use tokio::sync::{mpsc, RwLock};
21use tokio_tungstenite::tungstenite::protocol::Message;
22use tokio_tungstenite::WebSocketStream;
23use futures::{StreamExt, SinkExt};
24
25/// WebSocket message types
26#[derive(Debug, Clone, PartialEq)]
27pub enum WebSocketMessage {
28    /// Text message
29    Text(String),
30    /// Binary message
31    Binary(tungstenite::Bytes),
32    /// Ping message
33    Ping(tungstenite::Bytes),
34    /// Pong message
35    Pong(tungstenite::Bytes),
36    /// Close message
37    Close { code: u16, reason: String },
38}
39
40/// WebSocket connection info
41#[derive(Debug, Clone)]
42pub struct WebSocketConnection {
43    /// Connection ID
44    pub id: String,
45    /// Remote address
46    pub remote_addr: String,
47    /// Connected timestamp
48    pub connected_at: std::time::Instant,
49}
50
51/// WebSocket server with connection management
52pub struct WebSocketServer {
53    connections: Arc<RwLock<HashMap<String, mpsc::UnboundedSender<WebSocketMessage>>>>,
54}
55
56impl WebSocketServer {
57    /// Create a new WebSocket server
58    pub fn new(_config: Config) -> Self {
59        Self {
60            connections: Arc::new(RwLock::new(HashMap::new())),
61        }
62    }
63
64    /// Check if request is a WebSocket upgrade request
65    pub fn is_websocket_upgrade<B>(req: &Request<B>) -> bool {
66        req.method() == Method::GET
67            && req.headers().get(UPGRADE)
68                .and_then(|v| v.to_str().ok())
69                .map(|v| v.to_lowercase() == "websocket")
70                .unwrap_or(false)
71            && req.headers().get(CONNECTION)
72                .and_then(|v| v.to_str().ok())
73                .map(|v| v.to_lowercase().contains("upgrade"))
74                .unwrap_or(false)
75    }
76
77    /// Perform WebSocket handshake
78    pub fn handshake<B>(req: &Request<B>) -> Result<Response<Full<Bytes>>, Error> {
79        // Validate WebSocket version
80        let ws_version = req.headers()
81            .get(SEC_WEBSOCKET_VERSION)
82            .and_then(|v| v.to_str().ok())
83            .ok_or_else(|| Error::Http("Missing WebSocket-Version header".to_string()))?;
84
85        if ws_version != "13" {
86            return Err(Error::Http("Unsupported WebSocket version".to_string()));
87        }
88
89        // Get WebSocket key
90        let ws_key = req.headers()
91            .get(SEC_WEBSOCKET_KEY)
92            .and_then(|v| v.to_str().ok())
93            .ok_or_else(|| Error::Http("Missing WebSocket-Key header".to_string()))?;
94
95        // Generate accept key
96        let accept_key = Self::generate_accept_key(ws_key)?;
97
98        // Build handshake response
99        let response = Response::builder()
100            .status(StatusCode::SWITCHING_PROTOCOLS)
101            .header(UPGRADE, "websocket")
102            .header(CONNECTION, "Upgrade")
103            .header(SEC_WEBSOCKET_ACCEPT, accept_key)
104            .body(Full::new(Bytes::new()))
105            .map_err(|e| Error::Internal(format!("Failed to build response: {}", e)))?;
106
107        Ok(response)
108    }
109
110    /// Generate WebSocket accept key
111    fn generate_accept_key(ws_key: &str) -> Result<String, Error> {
112        let magic_guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
113        let combined = format!("{}{}", ws_key, magic_guid);
114
115        let mut hasher = Sha1::new();
116        hasher.update(combined.as_bytes());
117        let hash = hasher.finalize();
118
119        let accept_key = BASE64.encode(&hash);
120        Ok(accept_key)
121    }
122
123    /// Add a new WebSocket connection
124    pub async fn add_connection(&self, id: String, sender: mpsc::UnboundedSender<WebSocketMessage>) {
125        let mut connections = self.connections.write().await;
126        connections.insert(id, sender);
127    }
128
129    /// Remove a WebSocket connection
130    pub async fn remove_connection(&self, id: &str) {
131        let mut connections = self.connections.write().await;
132        connections.remove(id);
133    }
134
135    /// Broadcast a message to all connected clients
136    pub async fn broadcast(&self, message: WebSocketMessage) {
137        let connections = self.connections.read().await;
138        let mut failed_connections = Vec::new();
139
140        for (id, sender) in connections.iter() {
141            if sender.send(message.clone()).is_err() {
142                failed_connections.push(id.clone());
143            }
144        }
145
146        // Remove failed connections
147        if !failed_connections.is_empty() {
148            drop(connections);
149            let mut connections = self.connections.write().await;
150            for id in failed_connections {
151                connections.remove(&id);
152            }
153        }
154    }
155
156    /// Send a message to a specific connection
157    pub async fn send_to(&self, id: &str, message: WebSocketMessage) -> Result<(), Error> {
158        let connections = self.connections.read().await;
159        if let Some(sender) = connections.get(id) {
160            sender.send(message)
161                .map_err(|_| Error::Http("Connection closed".to_string()))?;
162        } else {
163            return Err(Error::Http("Connection not found".to_string()));
164        }
165        Ok(())
166    }
167
168    /// Get the number of active connections
169    pub async fn connection_count(&self) -> usize {
170        let connections = self.connections.read().await;
171        connections.len()
172    }
173
174    /// Get a list of all active connections
175    pub async fn list_connections(&self) -> Vec<String> {
176        let connections = self.connections.read().await;
177        connections.keys().cloned().collect()
178    }
179
180    /// Handle incoming WebSocket message
181    pub fn handle_message(&self, message: Message) -> Result<Option<WebSocketMessage>, Error> {
182        match message {
183            Message::Text(text) => Ok(Some(WebSocketMessage::Text(text.to_string()))),
184            Message::Binary(data) => Ok(Some(WebSocketMessage::Binary(data))),
185            Message::Ping(data) => {
186                // Auto-respond to ping with pong
187                Ok(Some(WebSocketMessage::Pong(data)))
188            }
189            Message::Pong(data) => Ok(Some(WebSocketMessage::Pong(data))),
190            Message::Close(frame) => {
191                if let Some(frame) = frame {
192                    Ok(Some(WebSocketMessage::Close {
193                        code: frame.code.into(),
194                        reason: frame.reason.to_string(),
195                    }))
196                } else {
197                    Ok(Some(WebSocketMessage::Close {
198                        code: 1000,
199                        reason: String::new(),
200                    }))
201                }
202            }
203            Message::Frame(_) => {
204                // Raw frames - ignore for now
205                Ok(None)
206            }
207        }
208    }
209
210    /// Convert WebSocket message to tungstenite message
211    pub fn to_tungstenite_message(&self, message: &WebSocketMessage) -> Message {
212        match message {
213            WebSocketMessage::Text(text) => Message::Text(text.clone().into()),
214            WebSocketMessage::Binary(data) => Message::Binary(data.clone()),
215            WebSocketMessage::Ping(data) => Message::Ping(data.clone()),
216            WebSocketMessage::Pong(data) => Message::Pong(data.clone()),
217            WebSocketMessage::Close { code, reason } => {
218                Message::Close(Some(tungstenite::protocol::frame::CloseFrame {
219                    code: tungstenite::protocol::frame::coding::CloseCode::from(*code),
220                    reason: reason.clone().into(),
221                }))
222            }
223        }
224    }
225}
226
227/// WebSocket connection handler
228pub struct WebSocketHandler {
229    server: Arc<WebSocketServer>,
230    connection_id: String,
231    receiver: Option<mpsc::UnboundedReceiver<WebSocketMessage>>,
232}
233
234impl WebSocketHandler {
235    /// Create a new WebSocket handler
236    pub fn new(server: Arc<WebSocketServer>, connection_id: String) -> (Self, mpsc::UnboundedSender<WebSocketMessage>) {
237        let (sender, receiver) = mpsc::unbounded_channel();
238
239        let handler = Self {
240            server,
241            connection_id,
242            receiver: Some(receiver),
243        };
244
245        (handler, sender)
246    }
247
248    /// Handle WebSocket connection
249    pub async fn handle_connection<T>(&mut self, mut ws_stream: WebSocketStream<T>) -> Result<(), Error>
250    where
251        T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + std::marker::Send + 'static,
252    {
253        let mut receiver = self.receiver.take().ok_or_else(|| Error::Internal("Receiver already taken".to_string()))?;
254        let server = self.server.clone();
255
256        // Handle incoming messages
257        loop {
258            tokio::select! {
259                // Receive from server broadcast
260                msg = receiver.recv() => {
261                    match msg {
262                        Some(message) => {
263                            let tungstenite_msg = match message {
264                                WebSocketMessage::Text(text) => Message::Text(text.into()),
265                                WebSocketMessage::Binary(data) => Message::Binary(data),
266                                WebSocketMessage::Ping(data) => Message::Ping(data),
267                                WebSocketMessage::Pong(data) => Message::Pong(data),
268                                WebSocketMessage::Close { code, reason } => Message::Close(Some(
269                                    tungstenite::protocol::frame::CloseFrame {
270                                        code: tungstenite::protocol::frame::coding::CloseCode::from(code),
271                                        reason: reason.into(),
272                                    }
273                                )),
274                            };
275
276                            if tungstenite_msg.is_close() || SinkExt::send(&mut ws_stream, tungstenite_msg).await.is_err() {
277                                break;
278                            }
279                        }
280                        None => break,
281                    }
282                }
283                // Receive from WebSocket
284                result = ws_stream.next() => {
285                    match result {
286                        Some(Ok(message)) => {
287                            if let Some(ws_message) = server.handle_message(message)? {
288                                // Handle special messages
289                                if matches!(ws_message, WebSocketMessage::Close { .. }) {
290                                    break;
291                                }
292                            }
293                        }
294                        Some(Err(e)) => {
295                            tracing::error!("WebSocket error: {}", e);
296                            break;
297                        }
298                        None => break,
299                    }
300                }
301            }
302        }
303
304        // Clean up
305        server.remove_connection(&self.connection_id).await;
306        tracing::info!("WebSocket connection {} closed", self.connection_id);
307
308        Ok(())
309    }
310}
311
312#[cfg(test)]
313mod tests {
314    use super::*;
315    use hyper::Request;
316
317    #[test]
318    fn test_websocket_upgrade_detection() {
319        let config = Config::default();
320        let server = WebSocketServer::new(config);
321
322        let req = Request::builder()
323            .method("GET")
324            .uri("/ws")
325            .header("Upgrade", "websocket")
326            .header("Connection", "Upgrade")
327            .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
328            .header("Sec-WebSocket-Version", "13")
329            .body(Bytes::new())
330            .unwrap();
331
332        assert!(WebSocketServer::is_websocket_upgrade(&req));
333    }
334
335    #[test]
336    fn test_non_websocket_request() {
337        let config = Config::default();
338        let server = WebSocketServer::new(config);
339
340        let req = Request::builder()
341            .method("GET")
342            .uri("/")
343            .body(Bytes::new())
344            .unwrap();
345
346        assert!(!WebSocketServer::is_websocket_upgrade(&req));
347    }
348
349    #[test]
350    fn test_websocket_version_validation() {
351        let config = Config::default();
352        let server = WebSocketServer::new(config);
353
354        // Missing WebSocket-Version header
355        let req = Request::builder()
356            .method("GET")
357            .uri("/ws")
358            .header("Upgrade", "websocket")
359            .header("Connection", "Upgrade")
360            .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
361            .body(Bytes::new())
362            .unwrap();
363
364        assert!(WebSocketServer::handshake(&req).is_err());
365
366        // Invalid WebSocket-Version
367        let req = Request::builder()
368            .method("GET")
369            .uri("/ws")
370            .header("Upgrade", "websocket")
371            .header("Connection", "Upgrade")
372            .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
373            .header("Sec-WebSocket-Version", "12")
374            .body(Bytes::new())
375            .unwrap();
376
377        assert!(WebSocketServer::handshake(&req).is_err());
378    }
379
380    #[test]
381    fn test_generate_accept_key() {
382        let ws_key = "dGhlIHNhbXBsZSBub25jZQ==";
383        let accept_key = WebSocketServer::generate_accept_key(ws_key).unwrap();
384
385        // Known test value from RFC 6455
386        assert_eq!(accept_key, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
387    }
388
389    #[test]
390    fn test_message_conversion() {
391        let config = Config::default();
392        let server = WebSocketServer::new(config);
393
394        // Test text message
395        let text_msg = WebSocketMessage::Text("Hello".to_string());
396        let converted = server.to_tungstenite_message(&text_msg);
397        assert!(matches!(converted, Message::Text(_)));
398
399        // Test binary message
400        let binary_msg = WebSocketMessage::Binary(tungstenite::Bytes::from(vec![1, 2, 3]));
401        let converted = server.to_tungstenite_message(&binary_msg);
402        assert!(matches!(converted, Message::Binary(_)));
403
404        // Test close message
405        let close_msg = WebSocketMessage::Close {
406            code: 1000,
407            reason: "Normal closure".to_string(),
408        };
409        let converted = server.to_tungstenite_message(&close_msg);
410        assert!(matches!(converted, Message::Close(_)));
411    }
412
413    #[test]
414    fn test_connection_management() {
415        let config = Config::default();
416        let server = WebSocketServer::new(config);
417
418        // Test connection count
419        futures::executor::block_on(async {
420            assert_eq!(server.connection_count().await, 0);
421
422            // Add a connection
423            let (sender, _) = mpsc::unbounded_channel();
424            server.add_connection("test_conn".to_string(), sender).await;
425            assert_eq!(server.connection_count().await, 1);
426
427            // Remove connection
428            server.remove_connection("test_conn").await;
429            assert_eq!(server.connection_count().await, 0);
430
431            // List connections
432            let (sender1, _) = mpsc::unbounded_channel();
433            let (sender2, _) = mpsc::unbounded_channel();
434            server.add_connection("conn1".to_string(), sender1).await;
435            server.add_connection("conn2".to_string(), sender2).await;
436
437            let connections = server.list_connections().await;
438            assert_eq!(connections.len(), 2);
439            assert!(connections.contains(&"conn1".to_string()));
440            assert!(connections.contains(&"conn2".to_string()));
441        });
442    }
443
444    #[test]
445    fn test_send_to_connection() {
446        let config = Config::default();
447        let server = WebSocketServer::new(config);
448
449        futures::executor::block_on(async {
450            // Create a connection with a receiver
451            let (sender, mut receiver) = mpsc::unbounded_channel();
452            server.add_connection("test_conn".to_string(), sender).await;
453
454            // Send a message to the connection
455            let message = WebSocketMessage::Text("Hello".to_string());
456            let result = server.send_to("test_conn", message.clone()).await;
457            assert!(result.is_ok());
458
459            // Verify the message was received
460            let received = receiver.recv().await;
461            assert!(received.is_some());
462            assert_eq!(received.unwrap(), WebSocketMessage::Text("Hello".to_string()));
463
464            // Try to send to non-existent connection
465            let result = server.send_to("nonexistent", message).await;
466            assert!(result.is_err());
467        });
468    }
469
470    #[test]
471    fn test_broadcast_message() {
472        let config = Config::default();
473        let server = WebSocketServer::new(config);
474
475        futures::executor::block_on(async {
476            // Create multiple connections
477            let (sender1, mut receiver1) = mpsc::unbounded_channel();
478            let (sender2, mut receiver2) = mpsc::unbounded_channel();
479            let (sender3, mut receiver3) = mpsc::unbounded_channel();
480
481            server.add_connection("conn1".to_string(), sender1).await;
482            server.add_connection("conn2".to_string(), sender2).await;
483            server.add_connection("conn3".to_string(), sender3).await;
484
485            // Broadcast a message
486            let message = WebSocketMessage::Text("Broadcast message".to_string());
487            server.broadcast(message).await;
488
489            // Verify all connections received the message
490            assert_eq!(
491                receiver1.recv().await.unwrap(),
492                WebSocketMessage::Text("Broadcast message".to_string())
493            );
494            assert_eq!(
495                receiver2.recv().await.unwrap(),
496                WebSocketMessage::Text("Broadcast message".to_string())
497            );
498            assert_eq!(
499                receiver3.recv().await.unwrap(),
500                WebSocketMessage::Text("Broadcast message".to_string())
501            );
502        });
503    }
504
505    #[test]
506    fn test_handle_text_message() {
507        let config = Config::default();
508        let server = WebSocketServer::new(config);
509
510        let message = Message::Text("Hello".into());
511        let result = server.handle_message(message).unwrap();
512
513        assert!(result.is_some());
514        assert_eq!(result.unwrap(), WebSocketMessage::Text("Hello".to_string()));
515    }
516
517    #[test]
518    fn test_handle_binary_message() {
519        let config = Config::default();
520        let server = WebSocketServer::new(config);
521
522        let message = Message::Binary(tungstenite::Bytes::from(vec![1, 2, 3]));
523        let result = server.handle_message(message).unwrap();
524
525        assert!(result.is_some());
526        let ws_message = result.unwrap();
527        assert!(matches!(ws_message, WebSocketMessage::Binary(_)));
528    }
529
530    #[test]
531    fn test_handle_ping_message() {
532        let config = Config::default();
533        let server = WebSocketServer::new(config);
534
535        let message = Message::Ping(tungstenite::Bytes::from(vec![1, 2, 3]));
536        let result = server.handle_message(message).unwrap();
537
538        assert!(result.is_some());
539        let ws_message = result.unwrap();
540        assert!(matches!(ws_message, WebSocketMessage::Pong(_)));
541    }
542
543    #[test]
544    fn test_handle_pong_message() {
545        let config = Config::default();
546        let server = WebSocketServer::new(config);
547
548        let message = Message::Pong(tungstenite::Bytes::from(vec![1, 2, 3]));
549        let result = server.handle_message(message).unwrap();
550
551        assert!(result.is_some());
552        let ws_message = result.unwrap();
553        assert!(matches!(ws_message, WebSocketMessage::Pong(_)));
554    }
555
556    #[test]
557    fn test_handle_close_message_with_frame() {
558        let config = Config::default();
559        let server = WebSocketServer::new(config);
560
561        // Create a close message with a frame using Message::Close
562        let message = Message::Close(Some(tungstenite::protocol::frame::CloseFrame {
563            code: tungstenite::protocol::frame::coding::CloseCode::Normal,
564            reason: "Normal closure".into(),
565        }));
566        let result = server.handle_message(message).unwrap();
567
568        assert!(result.is_some());
569        let ws_message = result.unwrap();
570        match ws_message {
571            WebSocketMessage::Close { code, reason } => {
572                assert_eq!(code, 1000);
573                assert_eq!(reason, "Normal closure");
574            }
575            _ => panic!("Expected Close message"),
576        }
577    }
578
579    #[test]
580    fn test_handle_close_message_without_frame() {
581        let config = Config::default();
582        let server = WebSocketServer::new(config);
583
584        let message = Message::Close(None);
585        let result = server.handle_message(message).unwrap();
586
587        assert!(result.is_some());
588        let ws_message = result.unwrap();
589        match ws_message {
590            WebSocketMessage::Close { code, reason } => {
591                assert_eq!(code, 1000);
592                assert_eq!(reason, "");
593            }
594            _ => panic!("Expected Close message"),
595        }
596    }
597
598    #[test]
599    fn test_to_tungstenite_text() {
600        let config = Config::default();
601        let server = WebSocketServer::new(config);
602
603        let message = WebSocketMessage::Text("Hello".to_string());
604        let converted = server.to_tungstenite_message(&message);
605
606        assert!(matches!(converted, Message::Text(_)));
607        if let Message::Text(text) = converted {
608            assert_eq!(text, "Hello");
609        }
610    }
611
612    #[test]
613    fn test_to_tungstenite_binary() {
614        let config = Config::default();
615        let server = WebSocketServer::new(config);
616
617        let message = WebSocketMessage::Binary(tungstenite::Bytes::from(vec![1, 2, 3]));
618        let converted = server.to_tungstenite_message(&message);
619
620        assert!(matches!(converted, Message::Binary(_)));
621    }
622
623    #[test]
624    fn test_to_tungstenite_ping() {
625        let config = Config::default();
626        let server = WebSocketServer::new(config);
627
628        let message = WebSocketMessage::Ping(tungstenite::Bytes::from(vec![1, 2, 3]));
629        let converted = server.to_tungstenite_message(&message);
630
631        assert!(matches!(converted, Message::Ping(_)));
632    }
633
634    #[test]
635    fn test_to_tungstenite_pong() {
636        let config = Config::default();
637        let server = WebSocketServer::new(config);
638
639        let message = WebSocketMessage::Pong(tungstenite::Bytes::from(vec![1, 2, 3]));
640        let converted = server.to_tungstenite_message(&message);
641
642        assert!(matches!(converted, Message::Pong(_)));
643    }
644
645    #[test]
646    fn test_to_tungstenite_close() {
647        let config = Config::default();
648        let server = WebSocketServer::new(config);
649
650        let message = WebSocketMessage::Close {
651            code: 1000,
652            reason: "Normal closure".to_string(),
653        };
654        let converted = server.to_tungstenite_message(&message);
655
656        assert!(matches!(converted, Message::Close(_)));
657        if let Message::Close(Some(frame)) = converted {
658            assert_eq!(frame.code, tungstenite::protocol::frame::coding::CloseCode::Normal);
659            assert_eq!(frame.reason, std::borrow::Cow::from("Normal closure"));
660        } else {
661            panic!("Expected Close message with frame");
662        }
663    }
664
665    #[test]
666    fn test_broadcast_with_closed_connection() {
667        let config = Config::default();
668        let server = WebSocketServer::new(config);
669
670        futures::executor::block_on(async {
671            // Create a connection with a closed sender
672            let (sender, receiver) = mpsc::unbounded_channel();
673            drop(receiver); // Close the receiver
674
675            server.add_connection("closed_conn".to_string(), sender).await;
676
677            // Create a normal connection
678            let (sender2, mut receiver2) = mpsc::unbounded_channel();
679            server.add_connection("normal_conn".to_string(), sender2).await;
680
681            // Broadcast a message
682            let message = WebSocketMessage::Text("Test".to_string());
683            server.broadcast(message).await;
684
685            // The closed connection should be removed
686            let connections = server.list_connections().await;
687            assert!(!connections.contains(&"closed_conn".to_string()));
688            assert!(connections.contains(&"normal_conn".to_string()));
689
690            // The normal connection should still receive the message
691            assert!(receiver2.recv().await.is_some());
692        });
693    }
694
695    #[test]
696    fn test_send_to_connection_success() {
697        let config = Config::default();
698        let server = WebSocketServer::new(config);
699
700        futures::executor::block_on(async {
701            let (sender, mut receiver) = mpsc::unbounded_channel();
702            server.add_connection("test_conn".to_string(), sender).await;
703
704            let message = WebSocketMessage::Text("Hello".to_string());
705            let result = server.send_to("test_conn", message).await;
706            assert!(result.is_ok());
707
708            let received = receiver.recv().await;
709            assert!(received.is_some());
710        });
711    }
712
713    #[test]
714    fn test_send_to_connection_not_found() {
715        let config = Config::default();
716        let server = WebSocketServer::new(config);
717
718        futures::executor::block_on(async {
719            let message = WebSocketMessage::Text("Hello".to_string());
720            let result = server.send_to("nonexistent", message).await;
721            assert!(result.is_err());
722            assert!(result.unwrap_err().to_string().contains("not found"));
723        });
724    }
725
726    #[test]
727    fn test_send_to_connection_closed() {
728        let config = Config::default();
729        let server = WebSocketServer::new(config);
730
731        futures::executor::block_on(async {
732            let (sender, receiver) = mpsc::unbounded_channel();
733            drop(receiver); // Close the receiver immediately
734
735            server.add_connection("test_conn".to_string(), sender).await;
736
737            let message = WebSocketMessage::Text("Hello".to_string());
738            let result = server.send_to("test_conn", message).await;
739            assert!(result.is_err());
740            assert!(result.unwrap_err().to_string().contains("closed"));
741        });
742    }
743
744    #[test]
745    fn test_websocket_upgrade_missing_headers() {
746        let req = Request::builder()
747            .method("GET")
748            .uri("/ws")
749            .body(Full::new(Bytes::new()))
750            .unwrap();
751
752        // Should return false for non-WebSocket request
753        assert!(!WebSocketServer::is_websocket_upgrade(&req));
754    }
755
756    #[test]
757    fn test_websocket_upgrade_only_upgrade_header() {
758        let req = Request::builder()
759            .method("GET")
760            .uri("/ws")
761            .header("Upgrade", "websocket")
762            .body(Full::new(Bytes::new()))
763            .unwrap();
764
765        // Should return false because Connection header is missing
766        assert!(!WebSocketServer::is_websocket_upgrade(&req));
767    }
768
769    #[test]
770    fn test_websocket_upgrade_wrong_method() {
771        let req = Request::builder()
772            .method("POST")
773            .uri("/ws")
774            .header("Upgrade", "websocket")
775            .header("Connection", "Upgrade")
776            .body(Full::new(Bytes::new()))
777            .unwrap();
778
779        // Should return false because method is not GET
780        assert!(!WebSocketServer::is_websocket_upgrade(&req));
781    }
782
783    #[test]
784    fn test_handshake_missing_key() {
785        let req = Request::builder()
786            .method("GET")
787            .uri("/ws")
788            .header("Upgrade", "websocket")
789            .header("Connection", "Upgrade")
790            .header("Sec-WebSocket-Version", "13")
791            .body(Full::new(Bytes::new()))
792            .unwrap();
793
794        let result = WebSocketServer::handshake(&req);
795        assert!(result.is_err());
796    }
797
798    #[test]
799    fn test_handshake_missing_version() {
800        let req = Request::builder()
801            .method("GET")
802            .uri("/ws")
803            .header("Upgrade", "websocket")
804            .header("Connection", "Upgrade")
805            .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
806            .body(Full::new(Bytes::new()))
807            .unwrap();
808
809        let result = WebSocketServer::handshake(&req);
810        assert!(result.is_err());
811    }
812
813    #[test]
814    fn test_handshake_wrong_version() {
815        let req = Request::builder()
816            .method("GET")
817            .uri("/ws")
818            .header("Upgrade", "websocket")
819            .header("Connection", "Upgrade")
820            .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
821            .header("Sec-WebSocket-Version", "12")
822            .body(Full::new(Bytes::new()))
823            .unwrap();
824
825        let result = WebSocketServer::handshake(&req);
826        assert!(result.is_err());
827    }
828
829    #[tokio::test]
830    async fn test_add_and_remove_connection() {
831        let config = Config::default();
832        let server = WebSocketServer::new(config);
833
834        let (sender1, _) = mpsc::unbounded_channel();
835        let (sender2, _) = mpsc::unbounded_channel();
836
837        server.add_connection("conn1".to_string(), sender1).await;
838        server.add_connection("conn2".to_string(), sender2).await;
839
840        let count = server.connection_count().await;
841        assert_eq!(count, 2);
842
843        server.remove_connection("conn1").await;
844        let count = server.connection_count().await;
845        assert_eq!(count, 1);
846
847        server.remove_connection("conn2").await;
848        let count = server.connection_count().await;
849        assert_eq!(count, 0);
850    }
851
852    #[tokio::test]
853    async fn test_broadcast_empty_connections() {
854        let config = Config::default();
855        let server = WebSocketServer::new(config);
856
857        // Broadcast to empty connections should not fail
858        let message = WebSocketMessage::Text("Hello".to_string());
859        server.broadcast(message).await;
860        // No assertion needed, just ensure no panic
861    }
862
863    #[test]
864    fn test_websocket_upgrade_case_insensitive() {
865        // Test case-insensitive upgrade header
866        let req = Request::builder()
867            .method("GET")
868            .uri("/ws")
869            .header("Upgrade", "WebSocket") // Uppercase
870            .header("Connection", "upgrade") // lowercase
871            .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
872            .header("Sec-WebSocket-Version", "13")
873            .body(Full::new(Bytes::new()))
874            .unwrap();
875
876        assert!(WebSocketServer::is_websocket_upgrade(&req));
877
878        // Test mixed case
879        let req2 = Request::builder()
880            .method("GET")
881            .uri("/ws")
882            .header("Upgrade", "WEBSOCKET")
883            .header("Connection", "UPGRADE")
884            .body(Full::new(Bytes::new()))
885            .unwrap();
886
887        assert!(WebSocketServer::is_websocket_upgrade(&req2));
888    }
889
890    #[test]
891    fn test_websocket_upgrade_non_utf8_upgrade_header() {
892        // Test with invalid UTF-8 in Upgrade header using HeaderValue
893        use hyper::header::HeaderValue;
894        
895        let req = Request::builder()
896            .method("GET")
897            .uri("/ws")
898            .header("Upgrade", HeaderValue::from_bytes(&[0xFF, 0xFE]).unwrap()) // Invalid UTF-8
899            .header("Connection", "Upgrade")
900            .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
901            .header("Sec-WebSocket-Version", "13")
902            .body(Full::new(Bytes::new()))
903            .unwrap();
904
905        // Should return false due to invalid UTF-8
906        assert!(!WebSocketServer::is_websocket_upgrade(&req));
907    }
908
909    #[test]
910    fn test_websocket_upgrade_non_utf8_connection_header() {
911        // Test with invalid UTF-8 in Connection header
912        use hyper::header::HeaderValue;
913        
914        let req = Request::builder()
915            .method("GET")
916            .uri("/ws")
917            .header("Upgrade", "websocket")
918            .header("Connection", HeaderValue::from_bytes(&[0xFF, 0xFE]).unwrap()) // Invalid UTF-8
919            .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
920            .header("Sec-WebSocket-Version", "13")
921            .body(Full::new(Bytes::new()))
922            .unwrap();
923
924        // Should return false due to invalid UTF-8
925        assert!(!WebSocketServer::is_websocket_upgrade(&req));
926    }
927
928    #[test]
929    fn test_generate_accept_key_empty() {
930        // Test with empty key
931        let accept_key = WebSocketServer::generate_accept_key("").unwrap();
932        // Known value: SHA1 of empty + magic GUID, base64 encoded
933        assert_eq!(accept_key.len(), 28); // Base64 encoded SHA1 is always 28 chars
934        assert!(!accept_key.is_empty());
935    }
936
937    #[test]
938    fn test_generate_accept_key_long_key() {
939        // Test with a long key
940        let long_key = "a".repeat(100);
941        let accept_key = WebSocketServer::generate_accept_key(&long_key).unwrap();
942        assert_eq!(accept_key.len(), 28); // Base64 encoded SHA1 is always 28 chars
943    }
944
945    #[test]
946    fn test_generate_accept_key_special_chars() {
947        // Test with special characters
948        let special_key = "+/=special&chars%$#@!";
949        let accept_key = WebSocketServer::generate_accept_key(special_key).unwrap();
950        assert_eq!(accept_key.len(), 28);
951    }
952
953    #[test]
954    fn test_handle_frame_message() {
955        let config = Config::default();
956        let server = WebSocketServer::new(config);
957
958        // Test handling of raw Frame message - should return Ok(None)
959        // Create a frame using Frame::from_payload
960        let header = tungstenite::protocol::frame::FrameHeader::default();
961        let payload = tungstenite::Bytes::from_static(&[0x01, 0x02, 0x03]);
962        let frame = tungstenite::protocol::frame::Frame::from_payload(header, payload);
963        let frame_msg = Message::Frame(frame);
964        let result = server.handle_message(frame_msg).unwrap();
965
966        // Frame messages should be ignored (return None)
967        assert!(result.is_none());
968    }
969
970    #[test]
971    fn test_handshake_success_response_headers() {
972        let req = Request::builder()
973            .method("GET")
974            .uri("/ws")
975            .header("Upgrade", "websocket")
976            .header("Connection", "Upgrade")
977            .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
978            .header("Sec-WebSocket-Version", "13")
979            .body(Full::new(Bytes::new()))
980            .unwrap();
981
982        let response = WebSocketServer::handshake(&req).unwrap();
983
984        // Verify status code
985        assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
986
987        // Verify Upgrade header
988        assert_eq!(
989            response.headers().get(UPGRADE).unwrap().to_str().unwrap(),
990            "websocket"
991        );
992
993        // Verify Connection header
994        assert_eq!(
995            response.headers().get(CONNECTION).unwrap().to_str().unwrap(),
996            "Upgrade"
997        );
998
999        // Verify Sec-WebSocket-Accept header
1000        assert_eq!(
1001            response.headers().get(SEC_WEBSOCKET_ACCEPT).unwrap().to_str().unwrap(),
1002            "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
1003        );
1004    }
1005
1006    #[test]
1007    fn test_websocket_message_debug() {
1008        // Test Debug implementation for WebSocketMessage
1009        let text_msg = WebSocketMessage::Text("Hello".to_string());
1010        let debug_str = format!("{:?}", text_msg);
1011        assert!(debug_str.contains("Text"));
1012        assert!(debug_str.contains("Hello"));
1013
1014        let binary_msg = WebSocketMessage::Binary(tungstenite::Bytes::from(vec![1, 2, 3]));
1015        let debug_str = format!("{:?}", binary_msg);
1016        assert!(debug_str.contains("Binary"));
1017
1018        let close_msg = WebSocketMessage::Close {
1019            code: 1000,
1020            reason: "Normal".to_string(),
1021        };
1022        let debug_str = format!("{:?}", close_msg);
1023        assert!(debug_str.contains("Close"));
1024        assert!(debug_str.contains("1000"));
1025    }
1026
1027    #[test]
1028    fn test_websocket_message_clone() {
1029        // Test Clone implementation
1030        let msg = WebSocketMessage::Text("Hello".to_string());
1031        let cloned = msg.clone();
1032        assert_eq!(msg, cloned);
1033
1034        let close_msg = WebSocketMessage::Close {
1035            code: 1001,
1036            reason: "Going away".to_string(),
1037        };
1038        let cloned_close = close_msg.clone();
1039        assert_eq!(close_msg, cloned_close);
1040    }
1041
1042    #[tokio::test]
1043    async fn test_websocket_handler_new() {
1044        let config = Config::default();
1045        let server = Arc::new(WebSocketServer::new(config));
1046
1047        let (handler, sender) = WebSocketHandler::new(server, "test-conn-123".to_string());
1048
1049        assert_eq!(handler.connection_id, "test-conn-123");
1050        assert!(handler.receiver.is_some());
1051
1052        // Test that sender works
1053        let msg = WebSocketMessage::Text("Test".to_string());
1054        assert!(sender.send(msg).is_ok());
1055    }
1056
1057    #[tokio::test]
1058    async fn test_broadcast_multiple_messages() {
1059        let config = Config::default();
1060        let server = WebSocketServer::new(config);
1061
1062        let (sender, mut receiver) = mpsc::unbounded_channel();
1063        server.add_connection("conn1".to_string(), sender).await;
1064
1065        // Broadcast multiple messages
1066        for i in 0..10 {
1067            let msg = WebSocketMessage::Text(format!("Message {}", i));
1068            server.broadcast(msg).await;
1069        }
1070
1071        // Verify all messages received
1072        for i in 0..10 {
1073            let received = receiver.recv().await.unwrap();
1074            assert_eq!(received, WebSocketMessage::Text(format!("Message {}", i)));
1075        }
1076    }
1077
1078    #[tokio::test]
1079    async fn test_send_to_various_message_types() {
1080        let config = Config::default();
1081        let server = WebSocketServer::new(config);
1082
1083        let (sender, mut receiver) = mpsc::unbounded_channel();
1084        server.add_connection("test_conn".to_string(), sender).await;
1085
1086        // Send text message
1087        let text_msg = WebSocketMessage::Text("Hello".to_string());
1088        server.send_to("test_conn", text_msg.clone()).await.unwrap();
1089        assert_eq!(receiver.recv().await.unwrap(), text_msg);
1090
1091        // Send binary message
1092        let binary_msg = WebSocketMessage::Binary(tungstenite::Bytes::from(vec![1, 2, 3]));
1093        server.send_to("test_conn", binary_msg.clone()).await.unwrap();
1094        assert_eq!(receiver.recv().await.unwrap(), binary_msg);
1095
1096        // Send ping message
1097        let ping_msg = WebSocketMessage::Ping(tungstenite::Bytes::from(vec![4, 5, 6]));
1098        server.send_to("test_conn", ping_msg.clone()).await.unwrap();
1099        assert_eq!(receiver.recv().await.unwrap(), ping_msg);
1100
1101        // Send pong message
1102        let pong_msg = WebSocketMessage::Pong(tungstenite::Bytes::from(vec![7, 8, 9]));
1103        server.send_to("test_conn", pong_msg.clone()).await.unwrap();
1104        assert_eq!(receiver.recv().await.unwrap(), pong_msg);
1105
1106        // Send close message
1107        let close_msg = WebSocketMessage::Close {
1108            code: 1000,
1109            reason: "Normal".to_string(),
1110        };
1111        server.send_to("test_conn", close_msg.clone()).await.unwrap();
1112        assert_eq!(receiver.recv().await.unwrap(), close_msg);
1113    }
1114
1115    #[test]
1116    fn test_handle_close_with_different_codes() {
1117        let config = Config::default();
1118        let server = WebSocketServer::new(config);
1119
1120        // Test different close codes
1121        let codes = vec![
1122            (1000u16, "Normal closure"),
1123            (1001, "Going away"),
1124            (1002, "Protocol error"),
1125            (1003, "Unsupported data"),
1126            (1006, "Abnormal closure"),
1127            (1008, "Policy violation"),
1128            (1009, "Message too big"),
1129            (1010, "Mandatory extension"),
1130            (1011, "Internal error"),
1131            (1015, "TLS handshake"),
1132        ];
1133
1134        for (code, reason) in codes {
1135            let close_frame = tungstenite::protocol::frame::CloseFrame {
1136                code: tungstenite::protocol::frame::coding::CloseCode::from(code),
1137                reason: reason.into(),
1138            };
1139            let message = Message::Close(Some(close_frame));
1140            let result = server.handle_message(message).unwrap();
1141
1142            match result {
1143                Some(WebSocketMessage::Close { code: c, reason: r }) => {
1144                    assert_eq!(c, code);
1145                    assert_eq!(r, reason);
1146                }
1147                _ => panic!("Expected Close message with code {}", code),
1148            }
1149        }
1150    }
1151
1152    #[test]
1153    fn test_handle_empty_binary_message() {
1154        let config = Config::default();
1155        let server = WebSocketServer::new(config);
1156
1157        let message = Message::Binary(tungstenite::Bytes::from(vec![]));
1158        let result = server.handle_message(message).unwrap();
1159
1160        assert!(result.is_some());
1161        match result.unwrap() {
1162            WebSocketMessage::Binary(data) => {
1163                assert!(data.is_empty());
1164            }
1165            _ => panic!("Expected Binary message"),
1166        }
1167    }
1168
1169    #[test]
1170    fn test_handle_empty_text_message() {
1171        let config = Config::default();
1172        let server = WebSocketServer::new(config);
1173
1174        let message = Message::Text("".into());
1175        let result = server.handle_message(message).unwrap();
1176
1177        assert!(result.is_some());
1178        match result.unwrap() {
1179            WebSocketMessage::Text(text) => {
1180                assert!(text.is_empty());
1181            }
1182            _ => panic!("Expected Text message"),
1183        }
1184    }
1185
1186    #[test]
1187    fn test_handle_empty_ping_pong() {
1188        let config = Config::default();
1189        let server = WebSocketServer::new(config);
1190
1191        // Empty ping
1192        let ping_msg = Message::Ping(tungstenite::Bytes::from(vec![]));
1193        let result = server.handle_message(ping_msg).unwrap();
1194        match result {
1195            Some(WebSocketMessage::Pong(data)) => {
1196                assert!(data.is_empty());
1197            }
1198            _ => panic!("Expected Pong with empty data"),
1199        }
1200
1201        // Empty pong
1202        let pong_msg = Message::Pong(tungstenite::Bytes::from(vec![]));
1203        let result = server.handle_message(pong_msg).unwrap();
1204        match result {
1205            Some(WebSocketMessage::Pong(data)) => {
1206                assert!(data.is_empty());
1207            }
1208            _ => panic!("Expected Pong with empty data"),
1209        }
1210    }
1211
1212    #[tokio::test]
1213    async fn test_list_connections_empty() {
1214        let config = Config::default();
1215        let server = WebSocketServer::new(config);
1216
1217        let connections = server.list_connections().await;
1218        assert!(connections.is_empty());
1219    }
1220
1221    #[tokio::test]
1222    async fn test_list_connections_order() {
1223        let config = Config::default();
1224        let server = WebSocketServer::new(config);
1225
1226        // Add connections in specific order
1227        for i in 0..5 {
1228            let (sender, _) = mpsc::unbounded_channel();
1229            server.add_connection(format!("conn_{}", i), sender).await;
1230        }
1231
1232        let connections = server.list_connections().await;
1233        assert_eq!(connections.len(), 5);
1234
1235        // Verify all connections are present
1236        for i in 0..5 {
1237            assert!(connections.contains(&format!("conn_{}", i)));
1238        }
1239    }
1240
1241    #[tokio::test]
1242    async fn test_remove_nonexistent_connection() {
1243        let config = Config::default();
1244        let server = WebSocketServer::new(config);
1245
1246        // Remove a connection that doesn't exist - should not panic
1247        server.remove_connection("nonexistent").await;
1248
1249        // Count should still be 0
1250        assert_eq!(server.connection_count().await, 0);
1251    }
1252
1253    #[tokio::test]
1254    async fn test_add_duplicate_connection() {
1255        let config = Config::default();
1256        let server = WebSocketServer::new(config);
1257
1258        let (sender1, _receiver1) = mpsc::unbounded_channel();
1259        let (sender2, mut receiver2) = mpsc::unbounded_channel();
1260
1261        // Add first connection
1262        server.add_connection("duplicate_conn".to_string(), sender1).await;
1263        assert_eq!(server.connection_count().await, 1);
1264
1265        // Add with same ID - should replace
1266        server.add_connection("duplicate_conn".to_string(), sender2).await;
1267        assert_eq!(server.connection_count().await, 1);
1268
1269        // New sender should work
1270        let msg = WebSocketMessage::Text("Test".to_string());
1271        server.send_to("duplicate_conn", msg.clone()).await.unwrap();
1272        assert_eq!(receiver2.recv().await.unwrap(), msg);
1273    }
1274
1275    #[test]
1276    fn test_is_websocket_upgrade_variations() {
1277        // Test with "keep-alive, Upgrade" connection header
1278        let req = Request::builder()
1279            .method("GET")
1280            .uri("/ws")
1281            .header("Upgrade", "websocket")
1282            .header("Connection", "keep-alive, Upgrade")
1283            .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
1284            .header("Sec-WebSocket-Version", "13")
1285            .body(Full::new(Bytes::new()))
1286            .unwrap();
1287
1288        assert!(WebSocketServer::is_websocket_upgrade(&req));
1289
1290        // Test with "Upgrade, keep-alive" connection header
1291        let req = Request::builder()
1292            .method("GET")
1293            .uri("/ws")
1294            .header("Upgrade", "websocket")
1295            .header("Connection", "Upgrade, keep-alive")
1296            .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
1297            .header("Sec-WebSocket-Version", "13")
1298            .body(Full::new(Bytes::new()))
1299            .unwrap();
1300
1301        assert!(WebSocketServer::is_websocket_upgrade(&req));
1302    }
1303
1304    #[test]
1305    fn test_websocket_server_new() {
1306        let config = Config::default();
1307        let server = WebSocketServer::new(config);
1308
1309        // Server should be created successfully
1310        // We can't directly access private fields, but we can test functionality
1311        futures::executor::block_on(async {
1312            assert_eq!(server.connection_count().await, 0);
1313        });
1314    }
1315
1316    #[test]
1317    fn test_to_tungstenite_close_various_codes() {
1318        let config = Config::default();
1319        let server = WebSocketServer::new(config);
1320
1321        // Test various close codes
1322        let test_cases = vec![
1323            (1000u16, "Normal"),
1324            (1001, "Going away"),
1325            (1006, "Abnormal"),
1326            (1011, "Error"),
1327        ];
1328
1329        for (code, reason) in test_cases {
1330            let msg = WebSocketMessage::Close {
1331                code,
1332                reason: reason.to_string(),
1333            };
1334            let converted = server.to_tungstenite_message(&msg);
1335
1336            match converted {
1337                Message::Close(Some(frame)) => {
1338                    // Compare the close codes by converting to u16
1339                    let frame_code: u16 = frame.code.into();
1340                    assert_eq!(frame_code, code);
1341                    assert_eq!(frame.reason, std::borrow::Cow::from(reason));
1342                }
1343                _ => panic!("Expected Close message with frame for code {}", code),
1344            }
1345        }
1346    }
1347
1348    #[tokio::test]
1349    async fn test_broadcast_binary_message() {
1350        let config = Config::default();
1351        let server = WebSocketServer::new(config);
1352
1353        let (sender, mut receiver) = mpsc::unbounded_channel();
1354        server.add_connection("conn1".to_string(), sender).await;
1355
1356        // Broadcast binary message
1357        let binary_data = vec![0u8, 1, 2, 3, 255, 254, 253];
1358        let msg = WebSocketMessage::Binary(tungstenite::Bytes::from(binary_data.clone()));
1359        server.broadcast(msg).await;
1360
1361        let received = receiver.recv().await.unwrap();
1362        match received {
1363            WebSocketMessage::Binary(data) => {
1364                assert_eq!(data.to_vec(), binary_data);
1365            }
1366            _ => panic!("Expected Binary message"),
1367        }
1368    }
1369
1370    #[tokio::test]
1371    async fn test_broadcast_ping_pong() {
1372        let config = Config::default();
1373        let server = WebSocketServer::new(config);
1374
1375        let (sender, mut receiver) = mpsc::unbounded_channel();
1376        server.add_connection("conn1".to_string(), sender).await;
1377
1378        // Broadcast ping
1379        let ping_data = vec![1, 2, 3];
1380        let ping_msg = WebSocketMessage::Ping(tungstenite::Bytes::from(ping_data.clone()));
1381        server.broadcast(ping_msg).await;
1382
1383        let received = receiver.recv().await.unwrap();
1384        match received {
1385            WebSocketMessage::Ping(data) => {
1386                assert_eq!(data.to_vec(), ping_data);
1387            }
1388            _ => panic!("Expected Ping message"),
1389        }
1390
1391        // Broadcast pong
1392        let pong_data = vec![4, 5, 6];
1393        let pong_msg = WebSocketMessage::Pong(tungstenite::Bytes::from(pong_data.clone()));
1394        server.broadcast(pong_msg).await;
1395
1396        let received = receiver.recv().await.unwrap();
1397        match received {
1398            WebSocketMessage::Pong(data) => {
1399                assert_eq!(data.to_vec(), pong_data);
1400            }
1401            _ => panic!("Expected Pong message"),
1402        }
1403    }
1404
1405    #[tokio::test]
1406    async fn test_broadcast_close_message() {
1407        let config = Config::default();
1408        let server = WebSocketServer::new(config);
1409
1410        let (sender, mut receiver) = mpsc::unbounded_channel();
1411        server.add_connection("conn1".to_string(), sender).await;
1412
1413        // Broadcast close message
1414        let close_msg = WebSocketMessage::Close {
1415            code: 1000,
1416            reason: "Server shutting down".to_string(),
1417        };
1418        server.broadcast(close_msg.clone()).await;
1419
1420        let received = receiver.recv().await.unwrap();
1421        assert_eq!(received, close_msg);
1422    }
1423
1424    #[test]
1425    fn test_websocket_connection_struct() {
1426        // Test WebSocketConnection struct creation and field access
1427        let conn = WebSocketConnection {
1428            id: "test-conn-123".to_string(),
1429            remote_addr: "192.168.1.1:12345".to_string(),
1430            connected_at: std::time::Instant::now(),
1431        };
1432
1433        // Test field access
1434        assert_eq!(conn.id, "test-conn-123");
1435        assert_eq!(conn.remote_addr, "192.168.1.1:12345");
1436        
1437        // Test Clone
1438        let cloned = conn.clone();
1439        assert_eq!(cloned.id, conn.id);
1440        assert_eq!(cloned.remote_addr, conn.remote_addr);
1441        
1442        // Test Debug
1443        let debug_str = format!("{:?}", conn);
1444        assert!(debug_str.contains("test-conn-123"));
1445        assert!(debug_str.contains("192.168.1.1:12345"));
1446    }
1447
1448    #[test]
1449    fn test_websocket_handler_receiver_already_taken() {
1450        let config = Config::default();
1451        let server = Arc::new(WebSocketServer::new(config));
1452
1453        let (mut handler, _sender) = WebSocketHandler::new(server, "test_conn".to_string());
1454
1455        // Take the receiver once
1456        let _receiver = handler.receiver.take().unwrap();
1457
1458        // Verify receiver is now None
1459        assert!(handler.receiver.is_none());
1460    }
1461
1462    #[tokio::test]
1463    async fn test_websocket_handler_drop_receiver() {
1464        let config = Config::default();
1465        let server = Arc::new(WebSocketServer::new(config));
1466
1467        let (handler, sender) = WebSocketHandler::new(server, "test_conn".to_string());
1468
1469        // Drop the handler, sender should still work
1470        drop(handler);
1471
1472        // Sender should still be able to send (though receiver is gone)
1473        let msg = WebSocketMessage::Text("Test".to_string());
1474        // This will fail because receiver is dropped
1475        assert!(sender.send(msg).is_err());
1476    }
1477
1478    #[test]
1479    fn test_websocket_message_variants() {
1480        // Test all WebSocketMessage variants
1481        let text = WebSocketMessage::Text("hello".to_string());
1482        let binary = WebSocketMessage::Binary(tungstenite::Bytes::from(vec![1, 2, 3]));
1483        let ping = WebSocketMessage::Ping(tungstenite::Bytes::from(vec![4, 5, 6]));
1484        let pong = WebSocketMessage::Pong(tungstenite::Bytes::from(vec![7, 8, 9]));
1485        let close = WebSocketMessage::Close {
1486            code: 1000,
1487            reason: "Normal".to_string(),
1488        };
1489
1490        // Test equality
1491        assert_eq!(text, WebSocketMessage::Text("hello".to_string()));
1492        assert_eq!(binary, WebSocketMessage::Binary(tungstenite::Bytes::from(vec![1, 2, 3])));
1493        assert_eq!(ping, WebSocketMessage::Ping(tungstenite::Bytes::from(vec![4, 5, 6])));
1494        assert_eq!(pong, WebSocketMessage::Pong(tungstenite::Bytes::from(vec![7, 8, 9])));
1495        assert_eq!(close.clone(), close);
1496
1497        // Test inequality
1498        assert_ne!(text, WebSocketMessage::Text("world".to_string()));
1499        assert_ne!(binary, WebSocketMessage::Binary(tungstenite::Bytes::from(vec![4, 5, 6])));
1500    }
1501
1502    #[test]
1503    fn test_is_websocket_upgrade_no_headers() {
1504        // Test with no headers at all
1505        let req = Request::builder()
1506            .method("GET")
1507            .uri("/ws")
1508            .body(Full::new(Bytes::new()))
1509            .unwrap();
1510
1511        assert!(!WebSocketServer::is_websocket_upgrade(&req));
1512    }
1513
1514    #[test]
1515    fn test_is_websocket_upgrade_no_upgrade_header() {
1516        // Test with Connection header but no Upgrade header
1517        let req = Request::builder()
1518            .method("GET")
1519            .uri("/ws")
1520            .header("Connection", "Upgrade")
1521            .body(Full::new(Bytes::new()))
1522            .unwrap();
1523
1524        assert!(!WebSocketServer::is_websocket_upgrade(&req));
1525    }
1526
1527    #[test]
1528    fn test_is_websocket_upgrade_no_connection_header() {
1529        // Test with Upgrade header but no Connection header
1530        let req = Request::builder()
1531            .method("GET")
1532            .uri("/ws")
1533            .header("Upgrade", "websocket")
1534            .body(Full::new(Bytes::new()))
1535            .unwrap();
1536
1537        assert!(!WebSocketServer::is_websocket_upgrade(&req));
1538    }
1539
1540    #[test]
1541    fn test_is_websocket_upgrade_wrong_upgrade_value() {
1542        // Test with wrong Upgrade value
1543        let req = Request::builder()
1544            .method("GET")
1545            .uri("/ws")
1546            .header("Upgrade", "http2")
1547            .header("Connection", "Upgrade")
1548            .body(Full::new(Bytes::new()))
1549            .unwrap();
1550
1551        assert!(!WebSocketServer::is_websocket_upgrade(&req));
1552    }
1553
1554    #[test]
1555    fn test_is_websocket_upgrade_connection_without_upgrade() {
1556        // Test with Connection header that doesn't contain "upgrade"
1557        let req = Request::builder()
1558            .method("GET")
1559            .uri("/ws")
1560            .header("Upgrade", "websocket")
1561            .header("Connection", "keep-alive")
1562            .body(Full::new(Bytes::new()))
1563            .unwrap();
1564
1565        assert!(!WebSocketServer::is_websocket_upgrade(&req));
1566    }
1567
1568    #[test]
1569    fn test_handshake_response_body_empty() {
1570        let req = Request::builder()
1571            .method("GET")
1572            .uri("/ws")
1573            .header("Upgrade", "websocket")
1574            .header("Connection", "Upgrade")
1575            .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
1576            .header("Sec-WebSocket-Version", "13")
1577            .body(Full::new(Bytes::new()))
1578            .unwrap();
1579
1580        let response = WebSocketServer::handshake(&req).unwrap();
1581
1582        // Verify body is empty - Full<Bytes> doesn't have is_empty
1583        // Just verify the response status is correct (Switching Protocols)
1584        assert_eq!(response.status(), 101);
1585    }
1586
1587    #[tokio::test]
1588    async fn test_connection_count_after_operations() {
1589        let config = Config::default();
1590        let server = WebSocketServer::new(config);
1591
1592        // Initially 0
1593        assert_eq!(server.connection_count().await, 0);
1594
1595        // Add multiple connections
1596        for i in 0..10 {
1597            let (sender, _) = mpsc::unbounded_channel();
1598            server.add_connection(format!("conn{}", i), sender).await;
1599            assert_eq!(server.connection_count().await, i + 1);
1600        }
1601
1602        // Remove all connections one by one
1603        for i in (0..10).rev() {
1604            server.remove_connection(&format!("conn{}", i)).await;
1605            assert_eq!(server.connection_count().await, i);
1606        }
1607
1608        // Should be 0 again
1609        assert_eq!(server.connection_count().await, 0);
1610    }
1611
1612    #[tokio::test]
1613    async fn test_broadcast_after_remove_all() {
1614        let config = Config::default();
1615        let server = WebSocketServer::new(config);
1616
1617        // Add connections
1618        let (sender1, _receiver1) = mpsc::unbounded_channel();
1619        let (sender2, _receiver2) = mpsc::unbounded_channel();
1620        server.add_connection("conn1".to_string(), sender1).await;
1621        server.add_connection("conn2".to_string(), sender2).await;
1622
1623        // Remove all
1624        server.remove_connection("conn1").await;
1625        server.remove_connection("conn2").await;
1626
1627        // Broadcast should work without panicking
1628        let msg = WebSocketMessage::Text("Test".to_string());
1629        server.broadcast(msg).await;
1630
1631        // Count should be 0
1632        assert_eq!(server.connection_count().await, 0);
1633    }
1634
1635    #[test]
1636    fn test_generate_accept_key_unicode() {
1637        // Test with unicode characters
1638        let unicode_key = "πŸ”πŸ”‘πŸ—οΈε―†ι’₯";
1639        let accept_key = WebSocketServer::generate_accept_key(unicode_key).unwrap();
1640        assert_eq!(accept_key.len(), 28);
1641
1642        // Test with emoji only
1643        let emoji_key = "πŸš€πŸš€πŸš€";
1644        let accept_key = WebSocketServer::generate_accept_key(emoji_key).unwrap();
1645        assert_eq!(accept_key.len(), 28);
1646    }
1647
1648    #[test]
1649    fn test_to_tungstenite_message_preserves_data() {
1650        let config = Config::default();
1651        let server = WebSocketServer::new(config);
1652
1653        // Test text preservation
1654        let text_data = "Hello, δΈ–η•Œ! 🌍";
1655        let text_msg = WebSocketMessage::Text(text_data.to_string());
1656        let converted = server.to_tungstenite_message(&text_msg);
1657        if let Message::Text(text) = converted {
1658            assert_eq!(text, text_data);
1659        } else {
1660            panic!("Expected Text message");
1661        }
1662
1663        // Test binary preservation
1664        let binary_data: Vec<u8> = (0..256).map(|i| i as u8).collect();
1665        let binary_msg = WebSocketMessage::Binary(tungstenite::Bytes::from(binary_data.clone()));
1666        let converted = server.to_tungstenite_message(&binary_msg);
1667        if let Message::Binary(data) = converted {
1668            assert_eq!(data.to_vec(), binary_data);
1669        } else {
1670            panic!("Expected Binary message");
1671        }
1672    }
1673
1674    #[test]
1675    fn test_websocket_message_close_edge_codes() {
1676        let config = Config::default();
1677        let server = WebSocketServer::new(config);
1678
1679        // Test edge close codes
1680        let edge_codes = vec![
1681            0u16,     // Minimum
1682            999,      // Just below normal
1683            1000,     // Normal closure
1684            1001,     // Going away
1685            2999,     // Application specific range
1686            3000,     // Application specific
1687            3999,     // Max application specific
1688            4000,     // Private use
1689            4999,     // Max valid code
1690        ];
1691
1692        for code in edge_codes {
1693            let close_frame = tungstenite::protocol::frame::CloseFrame {
1694                code: tungstenite::protocol::frame::coding::CloseCode::from(code),
1695                reason: "Test".into(),
1696            };
1697            let message = Message::Close(Some(close_frame));
1698            let result = server.handle_message(message).unwrap();
1699
1700            match result {
1701                Some(WebSocketMessage::Close { code: c, .. }) => {
1702                    assert_eq!(c, code);
1703                }
1704                _ => panic!("Expected Close message with code {}", code),
1705            }
1706        }
1707    }
1708
1709    #[tokio::test]
1710    async fn test_send_to_all_message_types() {
1711        let config = Config::default();
1712        let server = WebSocketServer::new(config);
1713
1714        let (sender, mut receiver) = mpsc::unbounded_channel();
1715        server.add_connection("test_conn".to_string(), sender).await;
1716
1717        // Test all message types via send_to
1718        let test_cases = vec![
1719            WebSocketMessage::Text("Hello".to_string()),
1720            WebSocketMessage::Binary(tungstenite::Bytes::from(vec![1, 2, 3])),
1721            WebSocketMessage::Ping(tungstenite::Bytes::from(vec![4, 5, 6])),
1722            WebSocketMessage::Pong(tungstenite::Bytes::from(vec![7, 8, 9])),
1723            WebSocketMessage::Close {
1724                code: 1000,
1725                reason: "Normal".to_string(),
1726            },
1727        ];
1728
1729        for msg in test_cases {
1730            let msg_clone = msg.clone();
1731            server.send_to("test_conn", msg).await.unwrap();
1732            let received = receiver.recv().await.unwrap();
1733            assert_eq!(received, msg_clone);
1734        }
1735    }
1736}