1use crate::config::Config;
11use crate::error::Error;
12use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
13use hyper::header::{CONNECTION, UPGRADE, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION};
14use hyper::{Request, Response, StatusCode, Method};
15use http_body_util::Full;
16use hyper::body::Bytes;
17use sha1::{Digest, Sha1};
18use std::collections::HashMap;
19use std::sync::Arc;
20use tokio::sync::{mpsc, RwLock};
21use tokio_tungstenite::tungstenite::protocol::Message;
22use tokio_tungstenite::WebSocketStream;
23use futures::{StreamExt, SinkExt};
24
25#[derive(Debug, Clone, PartialEq)]
27pub enum WebSocketMessage {
28 Text(String),
30 Binary(tungstenite::Bytes),
32 Ping(tungstenite::Bytes),
34 Pong(tungstenite::Bytes),
36 Close { code: u16, reason: String },
38}
39
/// Descriptive metadata about one WebSocket client connection.
///
/// NOTE(review): this type is not referenced by the server or handler code in
/// the visible part of this file; confirm where it is populated.
#[derive(Clone, Debug)]
pub struct WebSocketConnection {
    /// Identifier of the connection.
    pub id: String,
    /// Peer address rendered as a string (presumably `ip:port` — confirm with the producer).
    pub remote_addr: String,
    /// Monotonic timestamp, presumably taken when the connection was established.
    pub connected_at: std::time::Instant,
}
50
51pub struct WebSocketServer {
53 connections: Arc<RwLock<HashMap<String, mpsc::UnboundedSender<WebSocketMessage>>>>,
54}
55
56impl WebSocketServer {
57 pub fn new(_config: Config) -> Self {
59 Self {
60 connections: Arc::new(RwLock::new(HashMap::new())),
61 }
62 }
63
64 pub fn is_websocket_upgrade<B>(req: &Request<B>) -> bool {
66 req.method() == Method::GET
67 && req.headers().get(UPGRADE)
68 .and_then(|v| v.to_str().ok())
69 .map(|v| v.to_lowercase() == "websocket")
70 .unwrap_or(false)
71 && req.headers().get(CONNECTION)
72 .and_then(|v| v.to_str().ok())
73 .map(|v| v.to_lowercase().contains("upgrade"))
74 .unwrap_or(false)
75 }
76
77 pub fn handshake<B>(req: &Request<B>) -> Result<Response<Full<Bytes>>, Error> {
79 let ws_version = req.headers()
81 .get(SEC_WEBSOCKET_VERSION)
82 .and_then(|v| v.to_str().ok())
83 .ok_or_else(|| Error::Http("Missing WebSocket-Version header".to_string()))?;
84
85 if ws_version != "13" {
86 return Err(Error::Http("Unsupported WebSocket version".to_string()));
87 }
88
89 let ws_key = req.headers()
91 .get(SEC_WEBSOCKET_KEY)
92 .and_then(|v| v.to_str().ok())
93 .ok_or_else(|| Error::Http("Missing WebSocket-Key header".to_string()))?;
94
95 let accept_key = Self::generate_accept_key(ws_key)?;
97
98 let response = Response::builder()
100 .status(StatusCode::SWITCHING_PROTOCOLS)
101 .header(UPGRADE, "websocket")
102 .header(CONNECTION, "Upgrade")
103 .header(SEC_WEBSOCKET_ACCEPT, accept_key)
104 .body(Full::new(Bytes::new()))
105 .map_err(|e| Error::Internal(format!("Failed to build response: {}", e)))?;
106
107 Ok(response)
108 }
109
110 fn generate_accept_key(ws_key: &str) -> Result<String, Error> {
112 let magic_guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
113 let combined = format!("{}{}", ws_key, magic_guid);
114
115 let mut hasher = Sha1::new();
116 hasher.update(combined.as_bytes());
117 let hash = hasher.finalize();
118
119 let accept_key = BASE64.encode(&hash);
120 Ok(accept_key)
121 }
122
123 pub async fn add_connection(&self, id: String, sender: mpsc::UnboundedSender<WebSocketMessage>) {
125 let mut connections = self.connections.write().await;
126 connections.insert(id, sender);
127 }
128
129 pub async fn remove_connection(&self, id: &str) {
131 let mut connections = self.connections.write().await;
132 connections.remove(id);
133 }
134
135 pub async fn broadcast(&self, message: WebSocketMessage) {
137 let connections = self.connections.read().await;
138 let mut failed_connections = Vec::new();
139
140 for (id, sender) in connections.iter() {
141 if sender.send(message.clone()).is_err() {
142 failed_connections.push(id.clone());
143 }
144 }
145
146 if !failed_connections.is_empty() {
148 drop(connections);
149 let mut connections = self.connections.write().await;
150 for id in failed_connections {
151 connections.remove(&id);
152 }
153 }
154 }
155
156 pub async fn send_to(&self, id: &str, message: WebSocketMessage) -> Result<(), Error> {
158 let connections = self.connections.read().await;
159 if let Some(sender) = connections.get(id) {
160 sender.send(message)
161 .map_err(|_| Error::Http("Connection closed".to_string()))?;
162 } else {
163 return Err(Error::Http("Connection not found".to_string()));
164 }
165 Ok(())
166 }
167
168 pub async fn connection_count(&self) -> usize {
170 let connections = self.connections.read().await;
171 connections.len()
172 }
173
174 pub async fn list_connections(&self) -> Vec<String> {
176 let connections = self.connections.read().await;
177 connections.keys().cloned().collect()
178 }
179
180 pub fn handle_message(&self, message: Message) -> Result<Option<WebSocketMessage>, Error> {
182 match message {
183 Message::Text(text) => Ok(Some(WebSocketMessage::Text(text.to_string()))),
184 Message::Binary(data) => Ok(Some(WebSocketMessage::Binary(data))),
185 Message::Ping(data) => {
186 Ok(Some(WebSocketMessage::Pong(data)))
188 }
189 Message::Pong(data) => Ok(Some(WebSocketMessage::Pong(data))),
190 Message::Close(frame) => {
191 if let Some(frame) = frame {
192 Ok(Some(WebSocketMessage::Close {
193 code: frame.code.into(),
194 reason: frame.reason.to_string(),
195 }))
196 } else {
197 Ok(Some(WebSocketMessage::Close {
198 code: 1000,
199 reason: String::new(),
200 }))
201 }
202 }
203 Message::Frame(_) => {
204 Ok(None)
206 }
207 }
208 }
209
210 pub fn to_tungstenite_message(&self, message: &WebSocketMessage) -> Message {
212 match message {
213 WebSocketMessage::Text(text) => Message::Text(text.clone().into()),
214 WebSocketMessage::Binary(data) => Message::Binary(data.clone()),
215 WebSocketMessage::Ping(data) => Message::Ping(data.clone()),
216 WebSocketMessage::Pong(data) => Message::Pong(data.clone()),
217 WebSocketMessage::Close { code, reason } => {
218 Message::Close(Some(tungstenite::protocol::frame::CloseFrame {
219 code: tungstenite::protocol::frame::coding::CloseCode::from(*code),
220 reason: reason.clone().into(),
221 }))
222 }
223 }
224 }
225}
226
227pub struct WebSocketHandler {
229 server: Arc<WebSocketServer>,
230 connection_id: String,
231 receiver: Option<mpsc::UnboundedReceiver<WebSocketMessage>>,
232}
233
234impl WebSocketHandler {
235 pub fn new(server: Arc<WebSocketServer>, connection_id: String) -> (Self, mpsc::UnboundedSender<WebSocketMessage>) {
237 let (sender, receiver) = mpsc::unbounded_channel();
238
239 let handler = Self {
240 server,
241 connection_id,
242 receiver: Some(receiver),
243 };
244
245 (handler, sender)
246 }
247
248 pub async fn handle_connection<T>(&mut self, mut ws_stream: WebSocketStream<T>) -> Result<(), Error>
250 where
251 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + std::marker::Send + 'static,
252 {
253 let mut receiver = self.receiver.take().ok_or_else(|| Error::Internal("Receiver already taken".to_string()))?;
254 let server = self.server.clone();
255
256 loop {
258 tokio::select! {
259 msg = receiver.recv() => {
261 match msg {
262 Some(message) => {
263 let tungstenite_msg = match message {
264 WebSocketMessage::Text(text) => Message::Text(text.into()),
265 WebSocketMessage::Binary(data) => Message::Binary(data),
266 WebSocketMessage::Ping(data) => Message::Ping(data),
267 WebSocketMessage::Pong(data) => Message::Pong(data),
268 WebSocketMessage::Close { code, reason } => Message::Close(Some(
269 tungstenite::protocol::frame::CloseFrame {
270 code: tungstenite::protocol::frame::coding::CloseCode::from(code),
271 reason: reason.into(),
272 }
273 )),
274 };
275
276 if tungstenite_msg.is_close() || SinkExt::send(&mut ws_stream, tungstenite_msg).await.is_err() {
277 break;
278 }
279 }
280 None => break,
281 }
282 }
283 result = ws_stream.next() => {
285 match result {
286 Some(Ok(message)) => {
287 if let Some(ws_message) = server.handle_message(message)? {
288 if matches!(ws_message, WebSocketMessage::Close { .. }) {
290 break;
291 }
292 }
293 }
294 Some(Err(e)) => {
295 tracing::error!("WebSocket error: {}", e);
296 break;
297 }
298 None => break,
299 }
300 }
301 }
302 }
303
304 server.remove_connection(&self.connection_id).await;
306 tracing::info!("WebSocket connection {} closed", self.connection_id);
307
308 Ok(())
309 }
310}
311
312#[cfg(test)]
313mod tests {
314 use super::*;
315 use hyper::Request;
316
317 #[test]
318 fn test_websocket_upgrade_detection() {
319 let config = Config::default();
320 let server = WebSocketServer::new(config);
321
322 let req = Request::builder()
323 .method("GET")
324 .uri("/ws")
325 .header("Upgrade", "websocket")
326 .header("Connection", "Upgrade")
327 .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
328 .header("Sec-WebSocket-Version", "13")
329 .body(Bytes::new())
330 .unwrap();
331
332 assert!(WebSocketServer::is_websocket_upgrade(&req));
333 }
334
335 #[test]
336 fn test_non_websocket_request() {
337 let config = Config::default();
338 let server = WebSocketServer::new(config);
339
340 let req = Request::builder()
341 .method("GET")
342 .uri("/")
343 .body(Bytes::new())
344 .unwrap();
345
346 assert!(!WebSocketServer::is_websocket_upgrade(&req));
347 }
348
349 #[test]
350 fn test_websocket_version_validation() {
351 let config = Config::default();
352 let server = WebSocketServer::new(config);
353
354 let req = Request::builder()
356 .method("GET")
357 .uri("/ws")
358 .header("Upgrade", "websocket")
359 .header("Connection", "Upgrade")
360 .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
361 .body(Bytes::new())
362 .unwrap();
363
364 assert!(WebSocketServer::handshake(&req).is_err());
365
366 let req = Request::builder()
368 .method("GET")
369 .uri("/ws")
370 .header("Upgrade", "websocket")
371 .header("Connection", "Upgrade")
372 .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
373 .header("Sec-WebSocket-Version", "12")
374 .body(Bytes::new())
375 .unwrap();
376
377 assert!(WebSocketServer::handshake(&req).is_err());
378 }
379
380 #[test]
381 fn test_generate_accept_key() {
382 let ws_key = "dGhlIHNhbXBsZSBub25jZQ==";
383 let accept_key = WebSocketServer::generate_accept_key(ws_key).unwrap();
384
385 assert_eq!(accept_key, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
387 }
388
389 #[test]
390 fn test_message_conversion() {
391 let config = Config::default();
392 let server = WebSocketServer::new(config);
393
394 let text_msg = WebSocketMessage::Text("Hello".to_string());
396 let converted = server.to_tungstenite_message(&text_msg);
397 assert!(matches!(converted, Message::Text(_)));
398
399 let binary_msg = WebSocketMessage::Binary(tungstenite::Bytes::from(vec![1, 2, 3]));
401 let converted = server.to_tungstenite_message(&binary_msg);
402 assert!(matches!(converted, Message::Binary(_)));
403
404 let close_msg = WebSocketMessage::Close {
406 code: 1000,
407 reason: "Normal closure".to_string(),
408 };
409 let converted = server.to_tungstenite_message(&close_msg);
410 assert!(matches!(converted, Message::Close(_)));
411 }
412
413 #[test]
414 fn test_connection_management() {
415 let config = Config::default();
416 let server = WebSocketServer::new(config);
417
418 futures::executor::block_on(async {
420 assert_eq!(server.connection_count().await, 0);
421
422 let (sender, _) = mpsc::unbounded_channel();
424 server.add_connection("test_conn".to_string(), sender).await;
425 assert_eq!(server.connection_count().await, 1);
426
427 server.remove_connection("test_conn").await;
429 assert_eq!(server.connection_count().await, 0);
430
431 let (sender1, _) = mpsc::unbounded_channel();
433 let (sender2, _) = mpsc::unbounded_channel();
434 server.add_connection("conn1".to_string(), sender1).await;
435 server.add_connection("conn2".to_string(), sender2).await;
436
437 let connections = server.list_connections().await;
438 assert_eq!(connections.len(), 2);
439 assert!(connections.contains(&"conn1".to_string()));
440 assert!(connections.contains(&"conn2".to_string()));
441 });
442 }
443
444 #[test]
445 fn test_send_to_connection() {
446 let config = Config::default();
447 let server = WebSocketServer::new(config);
448
449 futures::executor::block_on(async {
450 let (sender, mut receiver) = mpsc::unbounded_channel();
452 server.add_connection("test_conn".to_string(), sender).await;
453
454 let message = WebSocketMessage::Text("Hello".to_string());
456 let result = server.send_to("test_conn", message.clone()).await;
457 assert!(result.is_ok());
458
459 let received = receiver.recv().await;
461 assert!(received.is_some());
462 assert_eq!(received.unwrap(), WebSocketMessage::Text("Hello".to_string()));
463
464 let result = server.send_to("nonexistent", message).await;
466 assert!(result.is_err());
467 });
468 }
469
470 #[test]
471 fn test_broadcast_message() {
472 let config = Config::default();
473 let server = WebSocketServer::new(config);
474
475 futures::executor::block_on(async {
476 let (sender1, mut receiver1) = mpsc::unbounded_channel();
478 let (sender2, mut receiver2) = mpsc::unbounded_channel();
479 let (sender3, mut receiver3) = mpsc::unbounded_channel();
480
481 server.add_connection("conn1".to_string(), sender1).await;
482 server.add_connection("conn2".to_string(), sender2).await;
483 server.add_connection("conn3".to_string(), sender3).await;
484
485 let message = WebSocketMessage::Text("Broadcast message".to_string());
487 server.broadcast(message).await;
488
489 assert_eq!(
491 receiver1.recv().await.unwrap(),
492 WebSocketMessage::Text("Broadcast message".to_string())
493 );
494 assert_eq!(
495 receiver2.recv().await.unwrap(),
496 WebSocketMessage::Text("Broadcast message".to_string())
497 );
498 assert_eq!(
499 receiver3.recv().await.unwrap(),
500 WebSocketMessage::Text("Broadcast message".to_string())
501 );
502 });
503 }
504
505 #[test]
506 fn test_handle_text_message() {
507 let config = Config::default();
508 let server = WebSocketServer::new(config);
509
510 let message = Message::Text("Hello".into());
511 let result = server.handle_message(message).unwrap();
512
513 assert!(result.is_some());
514 assert_eq!(result.unwrap(), WebSocketMessage::Text("Hello".to_string()));
515 }
516
517 #[test]
518 fn test_handle_binary_message() {
519 let config = Config::default();
520 let server = WebSocketServer::new(config);
521
522 let message = Message::Binary(tungstenite::Bytes::from(vec![1, 2, 3]));
523 let result = server.handle_message(message).unwrap();
524
525 assert!(result.is_some());
526 let ws_message = result.unwrap();
527 assert!(matches!(ws_message, WebSocketMessage::Binary(_)));
528 }
529
530 #[test]
531 fn test_handle_ping_message() {
532 let config = Config::default();
533 let server = WebSocketServer::new(config);
534
535 let message = Message::Ping(tungstenite::Bytes::from(vec![1, 2, 3]));
536 let result = server.handle_message(message).unwrap();
537
538 assert!(result.is_some());
539 let ws_message = result.unwrap();
540 assert!(matches!(ws_message, WebSocketMessage::Pong(_)));
541 }
542
543 #[test]
544 fn test_handle_pong_message() {
545 let config = Config::default();
546 let server = WebSocketServer::new(config);
547
548 let message = Message::Pong(tungstenite::Bytes::from(vec![1, 2, 3]));
549 let result = server.handle_message(message).unwrap();
550
551 assert!(result.is_some());
552 let ws_message = result.unwrap();
553 assert!(matches!(ws_message, WebSocketMessage::Pong(_)));
554 }
555
556 #[test]
557 fn test_handle_close_message_with_frame() {
558 let config = Config::default();
559 let server = WebSocketServer::new(config);
560
561 let message = Message::Close(Some(tungstenite::protocol::frame::CloseFrame {
563 code: tungstenite::protocol::frame::coding::CloseCode::Normal,
564 reason: "Normal closure".into(),
565 }));
566 let result = server.handle_message(message).unwrap();
567
568 assert!(result.is_some());
569 let ws_message = result.unwrap();
570 match ws_message {
571 WebSocketMessage::Close { code, reason } => {
572 assert_eq!(code, 1000);
573 assert_eq!(reason, "Normal closure");
574 }
575 _ => panic!("Expected Close message"),
576 }
577 }
578
579 #[test]
580 fn test_handle_close_message_without_frame() {
581 let config = Config::default();
582 let server = WebSocketServer::new(config);
583
584 let message = Message::Close(None);
585 let result = server.handle_message(message).unwrap();
586
587 assert!(result.is_some());
588 let ws_message = result.unwrap();
589 match ws_message {
590 WebSocketMessage::Close { code, reason } => {
591 assert_eq!(code, 1000);
592 assert_eq!(reason, "");
593 }
594 _ => panic!("Expected Close message"),
595 }
596 }
597
598 #[test]
599 fn test_to_tungstenite_text() {
600 let config = Config::default();
601 let server = WebSocketServer::new(config);
602
603 let message = WebSocketMessage::Text("Hello".to_string());
604 let converted = server.to_tungstenite_message(&message);
605
606 assert!(matches!(converted, Message::Text(_)));
607 if let Message::Text(text) = converted {
608 assert_eq!(text, "Hello");
609 }
610 }
611
612 #[test]
613 fn test_to_tungstenite_binary() {
614 let config = Config::default();
615 let server = WebSocketServer::new(config);
616
617 let message = WebSocketMessage::Binary(tungstenite::Bytes::from(vec![1, 2, 3]));
618 let converted = server.to_tungstenite_message(&message);
619
620 assert!(matches!(converted, Message::Binary(_)));
621 }
622
623 #[test]
624 fn test_to_tungstenite_ping() {
625 let config = Config::default();
626 let server = WebSocketServer::new(config);
627
628 let message = WebSocketMessage::Ping(tungstenite::Bytes::from(vec![1, 2, 3]));
629 let converted = server.to_tungstenite_message(&message);
630
631 assert!(matches!(converted, Message::Ping(_)));
632 }
633
634 #[test]
635 fn test_to_tungstenite_pong() {
636 let config = Config::default();
637 let server = WebSocketServer::new(config);
638
639 let message = WebSocketMessage::Pong(tungstenite::Bytes::from(vec![1, 2, 3]));
640 let converted = server.to_tungstenite_message(&message);
641
642 assert!(matches!(converted, Message::Pong(_)));
643 }
644
645 #[test]
646 fn test_to_tungstenite_close() {
647 let config = Config::default();
648 let server = WebSocketServer::new(config);
649
650 let message = WebSocketMessage::Close {
651 code: 1000,
652 reason: "Normal closure".to_string(),
653 };
654 let converted = server.to_tungstenite_message(&message);
655
656 assert!(matches!(converted, Message::Close(_)));
657 if let Message::Close(Some(frame)) = converted {
658 assert_eq!(frame.code, tungstenite::protocol::frame::coding::CloseCode::Normal);
659 assert_eq!(frame.reason, std::borrow::Cow::from("Normal closure"));
660 } else {
661 panic!("Expected Close message with frame");
662 }
663 }
664
665 #[test]
666 fn test_broadcast_with_closed_connection() {
667 let config = Config::default();
668 let server = WebSocketServer::new(config);
669
670 futures::executor::block_on(async {
671 let (sender, receiver) = mpsc::unbounded_channel();
673 drop(receiver); server.add_connection("closed_conn".to_string(), sender).await;
676
677 let (sender2, mut receiver2) = mpsc::unbounded_channel();
679 server.add_connection("normal_conn".to_string(), sender2).await;
680
681 let message = WebSocketMessage::Text("Test".to_string());
683 server.broadcast(message).await;
684
685 let connections = server.list_connections().await;
687 assert!(!connections.contains(&"closed_conn".to_string()));
688 assert!(connections.contains(&"normal_conn".to_string()));
689
690 assert!(receiver2.recv().await.is_some());
692 });
693 }
694
695 #[test]
696 fn test_send_to_connection_success() {
697 let config = Config::default();
698 let server = WebSocketServer::new(config);
699
700 futures::executor::block_on(async {
701 let (sender, mut receiver) = mpsc::unbounded_channel();
702 server.add_connection("test_conn".to_string(), sender).await;
703
704 let message = WebSocketMessage::Text("Hello".to_string());
705 let result = server.send_to("test_conn", message).await;
706 assert!(result.is_ok());
707
708 let received = receiver.recv().await;
709 assert!(received.is_some());
710 });
711 }
712
713 #[test]
714 fn test_send_to_connection_not_found() {
715 let config = Config::default();
716 let server = WebSocketServer::new(config);
717
718 futures::executor::block_on(async {
719 let message = WebSocketMessage::Text("Hello".to_string());
720 let result = server.send_to("nonexistent", message).await;
721 assert!(result.is_err());
722 assert!(result.unwrap_err().to_string().contains("not found"));
723 });
724 }
725
726 #[test]
727 fn test_send_to_connection_closed() {
728 let config = Config::default();
729 let server = WebSocketServer::new(config);
730
731 futures::executor::block_on(async {
732 let (sender, receiver) = mpsc::unbounded_channel();
733 drop(receiver); server.add_connection("test_conn".to_string(), sender).await;
736
737 let message = WebSocketMessage::Text("Hello".to_string());
738 let result = server.send_to("test_conn", message).await;
739 assert!(result.is_err());
740 assert!(result.unwrap_err().to_string().contains("closed"));
741 });
742 }
743
744 #[test]
745 fn test_websocket_upgrade_missing_headers() {
746 let req = Request::builder()
747 .method("GET")
748 .uri("/ws")
749 .body(Full::new(Bytes::new()))
750 .unwrap();
751
752 assert!(!WebSocketServer::is_websocket_upgrade(&req));
754 }
755
756 #[test]
757 fn test_websocket_upgrade_only_upgrade_header() {
758 let req = Request::builder()
759 .method("GET")
760 .uri("/ws")
761 .header("Upgrade", "websocket")
762 .body(Full::new(Bytes::new()))
763 .unwrap();
764
765 assert!(!WebSocketServer::is_websocket_upgrade(&req));
767 }
768
769 #[test]
770 fn test_websocket_upgrade_wrong_method() {
771 let req = Request::builder()
772 .method("POST")
773 .uri("/ws")
774 .header("Upgrade", "websocket")
775 .header("Connection", "Upgrade")
776 .body(Full::new(Bytes::new()))
777 .unwrap();
778
779 assert!(!WebSocketServer::is_websocket_upgrade(&req));
781 }
782
783 #[test]
784 fn test_handshake_missing_key() {
785 let req = Request::builder()
786 .method("GET")
787 .uri("/ws")
788 .header("Upgrade", "websocket")
789 .header("Connection", "Upgrade")
790 .header("Sec-WebSocket-Version", "13")
791 .body(Full::new(Bytes::new()))
792 .unwrap();
793
794 let result = WebSocketServer::handshake(&req);
795 assert!(result.is_err());
796 }
797
798 #[test]
799 fn test_handshake_missing_version() {
800 let req = Request::builder()
801 .method("GET")
802 .uri("/ws")
803 .header("Upgrade", "websocket")
804 .header("Connection", "Upgrade")
805 .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
806 .body(Full::new(Bytes::new()))
807 .unwrap();
808
809 let result = WebSocketServer::handshake(&req);
810 assert!(result.is_err());
811 }
812
813 #[test]
814 fn test_handshake_wrong_version() {
815 let req = Request::builder()
816 .method("GET")
817 .uri("/ws")
818 .header("Upgrade", "websocket")
819 .header("Connection", "Upgrade")
820 .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
821 .header("Sec-WebSocket-Version", "12")
822 .body(Full::new(Bytes::new()))
823 .unwrap();
824
825 let result = WebSocketServer::handshake(&req);
826 assert!(result.is_err());
827 }
828
829 #[tokio::test]
830 async fn test_add_and_remove_connection() {
831 let config = Config::default();
832 let server = WebSocketServer::new(config);
833
834 let (sender1, _) = mpsc::unbounded_channel();
835 let (sender2, _) = mpsc::unbounded_channel();
836
837 server.add_connection("conn1".to_string(), sender1).await;
838 server.add_connection("conn2".to_string(), sender2).await;
839
840 let count = server.connection_count().await;
841 assert_eq!(count, 2);
842
843 server.remove_connection("conn1").await;
844 let count = server.connection_count().await;
845 assert_eq!(count, 1);
846
847 server.remove_connection("conn2").await;
848 let count = server.connection_count().await;
849 assert_eq!(count, 0);
850 }
851
852 #[tokio::test]
853 async fn test_broadcast_empty_connections() {
854 let config = Config::default();
855 let server = WebSocketServer::new(config);
856
857 let message = WebSocketMessage::Text("Hello".to_string());
859 server.broadcast(message).await;
860 }
862
863 #[test]
864 fn test_websocket_upgrade_case_insensitive() {
865 let req = Request::builder()
867 .method("GET")
868 .uri("/ws")
869 .header("Upgrade", "WebSocket") .header("Connection", "upgrade") .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
872 .header("Sec-WebSocket-Version", "13")
873 .body(Full::new(Bytes::new()))
874 .unwrap();
875
876 assert!(WebSocketServer::is_websocket_upgrade(&req));
877
878 let req2 = Request::builder()
880 .method("GET")
881 .uri("/ws")
882 .header("Upgrade", "WEBSOCKET")
883 .header("Connection", "UPGRADE")
884 .body(Full::new(Bytes::new()))
885 .unwrap();
886
887 assert!(WebSocketServer::is_websocket_upgrade(&req2));
888 }
889
890 #[test]
891 fn test_websocket_upgrade_non_utf8_upgrade_header() {
892 use hyper::header::HeaderValue;
894
895 let req = Request::builder()
896 .method("GET")
897 .uri("/ws")
898 .header("Upgrade", HeaderValue::from_bytes(&[0xFF, 0xFE]).unwrap()) .header("Connection", "Upgrade")
900 .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
901 .header("Sec-WebSocket-Version", "13")
902 .body(Full::new(Bytes::new()))
903 .unwrap();
904
905 assert!(!WebSocketServer::is_websocket_upgrade(&req));
907 }
908
909 #[test]
910 fn test_websocket_upgrade_non_utf8_connection_header() {
911 use hyper::header::HeaderValue;
913
914 let req = Request::builder()
915 .method("GET")
916 .uri("/ws")
917 .header("Upgrade", "websocket")
918 .header("Connection", HeaderValue::from_bytes(&[0xFF, 0xFE]).unwrap()) .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
920 .header("Sec-WebSocket-Version", "13")
921 .body(Full::new(Bytes::new()))
922 .unwrap();
923
924 assert!(!WebSocketServer::is_websocket_upgrade(&req));
926 }
927
928 #[test]
929 fn test_generate_accept_key_empty() {
930 let accept_key = WebSocketServer::generate_accept_key("").unwrap();
932 assert_eq!(accept_key.len(), 28); assert!(!accept_key.is_empty());
935 }
936
937 #[test]
938 fn test_generate_accept_key_long_key() {
939 let long_key = "a".repeat(100);
941 let accept_key = WebSocketServer::generate_accept_key(&long_key).unwrap();
942 assert_eq!(accept_key.len(), 28); }
944
945 #[test]
946 fn test_generate_accept_key_special_chars() {
947 let special_key = "+/=special&chars%$#@!";
949 let accept_key = WebSocketServer::generate_accept_key(special_key).unwrap();
950 assert_eq!(accept_key.len(), 28);
951 }
952
953 #[test]
954 fn test_handle_frame_message() {
955 let config = Config::default();
956 let server = WebSocketServer::new(config);
957
958 let header = tungstenite::protocol::frame::FrameHeader::default();
961 let payload = tungstenite::Bytes::from_static(&[0x01, 0x02, 0x03]);
962 let frame = tungstenite::protocol::frame::Frame::from_payload(header, payload);
963 let frame_msg = Message::Frame(frame);
964 let result = server.handle_message(frame_msg).unwrap();
965
966 assert!(result.is_none());
968 }
969
970 #[test]
971 fn test_handshake_success_response_headers() {
972 let req = Request::builder()
973 .method("GET")
974 .uri("/ws")
975 .header("Upgrade", "websocket")
976 .header("Connection", "Upgrade")
977 .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
978 .header("Sec-WebSocket-Version", "13")
979 .body(Full::new(Bytes::new()))
980 .unwrap();
981
982 let response = WebSocketServer::handshake(&req).unwrap();
983
984 assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
986
987 assert_eq!(
989 response.headers().get(UPGRADE).unwrap().to_str().unwrap(),
990 "websocket"
991 );
992
993 assert_eq!(
995 response.headers().get(CONNECTION).unwrap().to_str().unwrap(),
996 "Upgrade"
997 );
998
999 assert_eq!(
1001 response.headers().get(SEC_WEBSOCKET_ACCEPT).unwrap().to_str().unwrap(),
1002 "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
1003 );
1004 }
1005
1006 #[test]
1007 fn test_websocket_message_debug() {
1008 let text_msg = WebSocketMessage::Text("Hello".to_string());
1010 let debug_str = format!("{:?}", text_msg);
1011 assert!(debug_str.contains("Text"));
1012 assert!(debug_str.contains("Hello"));
1013
1014 let binary_msg = WebSocketMessage::Binary(tungstenite::Bytes::from(vec![1, 2, 3]));
1015 let debug_str = format!("{:?}", binary_msg);
1016 assert!(debug_str.contains("Binary"));
1017
1018 let close_msg = WebSocketMessage::Close {
1019 code: 1000,
1020 reason: "Normal".to_string(),
1021 };
1022 let debug_str = format!("{:?}", close_msg);
1023 assert!(debug_str.contains("Close"));
1024 assert!(debug_str.contains("1000"));
1025 }
1026
1027 #[test]
1028 fn test_websocket_message_clone() {
1029 let msg = WebSocketMessage::Text("Hello".to_string());
1031 let cloned = msg.clone();
1032 assert_eq!(msg, cloned);
1033
1034 let close_msg = WebSocketMessage::Close {
1035 code: 1001,
1036 reason: "Going away".to_string(),
1037 };
1038 let cloned_close = close_msg.clone();
1039 assert_eq!(close_msg, cloned_close);
1040 }
1041
1042 #[tokio::test]
1043 async fn test_websocket_handler_new() {
1044 let config = Config::default();
1045 let server = Arc::new(WebSocketServer::new(config));
1046
1047 let (handler, sender) = WebSocketHandler::new(server, "test-conn-123".to_string());
1048
1049 assert_eq!(handler.connection_id, "test-conn-123");
1050 assert!(handler.receiver.is_some());
1051
1052 let msg = WebSocketMessage::Text("Test".to_string());
1054 assert!(sender.send(msg).is_ok());
1055 }
1056
1057 #[tokio::test]
1058 async fn test_broadcast_multiple_messages() {
1059 let config = Config::default();
1060 let server = WebSocketServer::new(config);
1061
1062 let (sender, mut receiver) = mpsc::unbounded_channel();
1063 server.add_connection("conn1".to_string(), sender).await;
1064
1065 for i in 0..10 {
1067 let msg = WebSocketMessage::Text(format!("Message {}", i));
1068 server.broadcast(msg).await;
1069 }
1070
1071 for i in 0..10 {
1073 let received = receiver.recv().await.unwrap();
1074 assert_eq!(received, WebSocketMessage::Text(format!("Message {}", i)));
1075 }
1076 }
1077
1078 #[tokio::test]
1079 async fn test_send_to_various_message_types() {
1080 let config = Config::default();
1081 let server = WebSocketServer::new(config);
1082
1083 let (sender, mut receiver) = mpsc::unbounded_channel();
1084 server.add_connection("test_conn".to_string(), sender).await;
1085
1086 let text_msg = WebSocketMessage::Text("Hello".to_string());
1088 server.send_to("test_conn", text_msg.clone()).await.unwrap();
1089 assert_eq!(receiver.recv().await.unwrap(), text_msg);
1090
1091 let binary_msg = WebSocketMessage::Binary(tungstenite::Bytes::from(vec![1, 2, 3]));
1093 server.send_to("test_conn", binary_msg.clone()).await.unwrap();
1094 assert_eq!(receiver.recv().await.unwrap(), binary_msg);
1095
1096 let ping_msg = WebSocketMessage::Ping(tungstenite::Bytes::from(vec![4, 5, 6]));
1098 server.send_to("test_conn", ping_msg.clone()).await.unwrap();
1099 assert_eq!(receiver.recv().await.unwrap(), ping_msg);
1100
1101 let pong_msg = WebSocketMessage::Pong(tungstenite::Bytes::from(vec![7, 8, 9]));
1103 server.send_to("test_conn", pong_msg.clone()).await.unwrap();
1104 assert_eq!(receiver.recv().await.unwrap(), pong_msg);
1105
1106 let close_msg = WebSocketMessage::Close {
1108 code: 1000,
1109 reason: "Normal".to_string(),
1110 };
1111 server.send_to("test_conn", close_msg.clone()).await.unwrap();
1112 assert_eq!(receiver.recv().await.unwrap(), close_msg);
1113 }
1114
#[test]
fn test_handle_close_with_different_codes() {
    // Each (code, reason) pair carried by a tungstenite Close frame must come out
    // of `handle_message` as a WebSocketMessage::Close with identical fields.
    let server = WebSocketServer::new(Config::default());

    let cases = [
        (1000u16, "Normal closure"),
        (1001, "Going away"),
        (1002, "Protocol error"),
        (1003, "Unsupported data"),
        (1006, "Abnormal closure"),
        (1008, "Policy violation"),
        (1009, "Message too big"),
        (1010, "Mandatory extension"),
        (1011, "Internal error"),
        (1015, "TLS handshake"),
    ];

    for &(expected_code, expected_reason) in cases.iter() {
        let frame = tungstenite::protocol::frame::CloseFrame {
            code: tungstenite::protocol::frame::coding::CloseCode::from(expected_code),
            reason: expected_reason.into(),
        };
        let handled = server.handle_message(Message::Close(Some(frame))).unwrap();

        let Some(WebSocketMessage::Close { code, reason }) = handled else {
            panic!("Expected Close message with code {}", expected_code);
        };
        assert_eq!(code, expected_code);
        assert_eq!(reason, expected_reason);
    }
}
1151
#[test]
fn test_handle_empty_binary_message() {
    // A zero-length binary frame is still surfaced as a Binary message whose
    // payload is empty, rather than being dropped.
    let server = WebSocketServer::new(Config::default());
    let empty_frame = Message::Binary(tungstenite::Bytes::from(vec![]));

    let handled = server.handle_message(empty_frame).unwrap();
    assert!(handled.is_some());

    if let Some(WebSocketMessage::Binary(payload)) = handled {
        assert!(payload.is_empty());
    } else {
        panic!("Expected Binary message");
    }
}
1168
#[test]
fn test_handle_empty_text_message() {
    // An empty text frame is surfaced as a Text message holding an empty string.
    let server = WebSocketServer::new(Config::default());
    let empty_frame = Message::Text("".into());

    let handled = server.handle_message(empty_frame).unwrap();
    assert!(handled.is_some());

    if let Some(WebSocketMessage::Text(body)) = handled {
        assert!(body.is_empty());
    } else {
        panic!("Expected Text message");
    }
}
1185
#[test]
fn test_handle_empty_ping_pong() {
    // Both an empty Ping and an empty Pong frame are expected to yield a Pong
    // message with an empty payload from `handle_message`.
    let server = WebSocketServer::new(Config::default());

    let frames = [
        Message::Ping(tungstenite::Bytes::from(vec![])),
        Message::Pong(tungstenite::Bytes::from(vec![])),
    ];

    for frame in frames {
        match server.handle_message(frame).unwrap() {
            Some(WebSocketMessage::Pong(payload)) => assert!(payload.is_empty()),
            _ => panic!("Expected Pong with empty data"),
        }
    }
}
1211
1212 #[tokio::test]
1213 async fn test_list_connections_empty() {
1214 let config = Config::default();
1215 let server = WebSocketServer::new(config);
1216
1217 let connections = server.list_connections().await;
1218 assert!(connections.is_empty());
1219 }
1220
1221 #[tokio::test]
1222 async fn test_list_connections_order() {
1223 let config = Config::default();
1224 let server = WebSocketServer::new(config);
1225
1226 for i in 0..5 {
1228 let (sender, _) = mpsc::unbounded_channel();
1229 server.add_connection(format!("conn_{}", i), sender).await;
1230 }
1231
1232 let connections = server.list_connections().await;
1233 assert_eq!(connections.len(), 5);
1234
1235 for i in 0..5 {
1237 assert!(connections.contains(&format!("conn_{}", i)));
1238 }
1239 }
1240
1241 #[tokio::test]
1242 async fn test_remove_nonexistent_connection() {
1243 let config = Config::default();
1244 let server = WebSocketServer::new(config);
1245
1246 server.remove_connection("nonexistent").await;
1248
1249 assert_eq!(server.connection_count().await, 0);
1251 }
1252
1253 #[tokio::test]
1254 async fn test_add_duplicate_connection() {
1255 let config = Config::default();
1256 let server = WebSocketServer::new(config);
1257
1258 let (sender1, _receiver1) = mpsc::unbounded_channel();
1259 let (sender2, mut receiver2) = mpsc::unbounded_channel();
1260
1261 server.add_connection("duplicate_conn".to_string(), sender1).await;
1263 assert_eq!(server.connection_count().await, 1);
1264
1265 server.add_connection("duplicate_conn".to_string(), sender2).await;
1267 assert_eq!(server.connection_count().await, 1);
1268
1269 let msg = WebSocketMessage::Text("Test".to_string());
1271 server.send_to("duplicate_conn", msg.clone()).await.unwrap();
1272 assert_eq!(receiver2.recv().await.unwrap(), msg);
1273 }
1274
#[test]
fn test_is_websocket_upgrade_variations() {
    // `Connection` may list several tokens; the upgrade token must be detected
    // whether it appears first or last in that list.
    for connection in ["keep-alive, Upgrade", "Upgrade, keep-alive"] {
        let req = Request::get("/ws")
            .header("Upgrade", "websocket")
            .header("Connection", connection)
            .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
            .header("Sec-WebSocket-Version", "13")
            .body(Full::new(Bytes::new()))
            .unwrap();

        assert!(WebSocketServer::is_websocket_upgrade(&req));
    }
}
1303
#[test]
fn test_websocket_server_new() {
    // Construction alone registers no connections. `connection_count` is
    // async, so it is driven to completion with the lightweight futures executor.
    let server = WebSocketServer::new(Config::default());
    let count = futures::executor::block_on(server.connection_count());
    assert_eq!(count, 0);
}
1315
#[test]
fn test_to_tungstenite_close_various_codes() {
    // Converting a WebSocketMessage::Close to a tungstenite Message must keep
    // both the numeric close code and the reason text.
    let server = WebSocketServer::new(Config::default());

    let cases = [
        (1000u16, "Normal"),
        (1001, "Going away"),
        (1006, "Abnormal"),
        (1011, "Error"),
    ];

    for &(code, reason) in cases.iter() {
        let close = WebSocketMessage::Close {
            code,
            reason: reason.to_string(),
        };

        let Message::Close(Some(frame)) = server.to_tungstenite_message(&close) else {
            panic!("Expected Close message with frame for code {}", code);
        };
        assert_eq!(u16::from(frame.code), code);
        assert_eq!(frame.reason, std::borrow::Cow::from(reason));
    }
}
1347
1348 #[tokio::test]
1349 async fn test_broadcast_binary_message() {
1350 let config = Config::default();
1351 let server = WebSocketServer::new(config);
1352
1353 let (sender, mut receiver) = mpsc::unbounded_channel();
1354 server.add_connection("conn1".to_string(), sender).await;
1355
1356 let binary_data = vec![0u8, 1, 2, 3, 255, 254, 253];
1358 let msg = WebSocketMessage::Binary(tungstenite::Bytes::from(binary_data.clone()));
1359 server.broadcast(msg).await;
1360
1361 let received = receiver.recv().await.unwrap();
1362 match received {
1363 WebSocketMessage::Binary(data) => {
1364 assert_eq!(data.to_vec(), binary_data);
1365 }
1366 _ => panic!("Expected Binary message"),
1367 }
1368 }
1369
1370 #[tokio::test]
1371 async fn test_broadcast_ping_pong() {
1372 let config = Config::default();
1373 let server = WebSocketServer::new(config);
1374
1375 let (sender, mut receiver) = mpsc::unbounded_channel();
1376 server.add_connection("conn1".to_string(), sender).await;
1377
1378 let ping_data = vec![1, 2, 3];
1380 let ping_msg = WebSocketMessage::Ping(tungstenite::Bytes::from(ping_data.clone()));
1381 server.broadcast(ping_msg).await;
1382
1383 let received = receiver.recv().await.unwrap();
1384 match received {
1385 WebSocketMessage::Ping(data) => {
1386 assert_eq!(data.to_vec(), ping_data);
1387 }
1388 _ => panic!("Expected Ping message"),
1389 }
1390
1391 let pong_data = vec![4, 5, 6];
1393 let pong_msg = WebSocketMessage::Pong(tungstenite::Bytes::from(pong_data.clone()));
1394 server.broadcast(pong_msg).await;
1395
1396 let received = receiver.recv().await.unwrap();
1397 match received {
1398 WebSocketMessage::Pong(data) => {
1399 assert_eq!(data.to_vec(), pong_data);
1400 }
1401 _ => panic!("Expected Pong message"),
1402 }
1403 }
1404
1405 #[tokio::test]
1406 async fn test_broadcast_close_message() {
1407 let config = Config::default();
1408 let server = WebSocketServer::new(config);
1409
1410 let (sender, mut receiver) = mpsc::unbounded_channel();
1411 server.add_connection("conn1".to_string(), sender).await;
1412
1413 let close_msg = WebSocketMessage::Close {
1415 code: 1000,
1416 reason: "Server shutting down".to_string(),
1417 };
1418 server.broadcast(close_msg.clone()).await;
1419
1420 let received = receiver.recv().await.unwrap();
1421 assert_eq!(received, close_msg);
1422 }
1423
#[test]
fn test_websocket_connection_struct() {
    // Exercises field access plus the derived `Clone` and `Debug` impls of
    // WebSocketConnection.
    let original = WebSocketConnection {
        id: "test-conn-123".to_string(),
        remote_addr: "192.168.1.1:12345".to_string(),
        connected_at: std::time::Instant::now(),
    };
    assert_eq!(original.id, "test-conn-123");
    assert_eq!(original.remote_addr, "192.168.1.1:12345");

    let duplicate = original.clone();
    assert_eq!(duplicate.id, original.id);
    assert_eq!(duplicate.remote_addr, original.remote_addr);

    let rendered = format!("{:?}", original);
    for needle in ["test-conn-123", "192.168.1.1:12345"] {
        assert!(rendered.contains(needle));
    }
}
1447
#[test]
fn test_websocket_handler_receiver_already_taken() {
    // The handler's `receiver` slot can be taken only once. After the first
    // `take()` it is empty, and a second `take()` yields nothing.
    let server = Arc::new(WebSocketServer::new(Config::default()));
    let (mut handler, _sender) = WebSocketHandler::new(server, "test_conn".to_string());

    let _receiver = handler
        .receiver
        .take()
        .expect("a freshly created handler should still own its receiver");
    assert!(handler.receiver.is_none());

    // Taking again is the "already taken" case this test is named after.
    assert!(handler.receiver.take().is_none());
}
1461
1462 #[tokio::test]
1463 async fn test_websocket_handler_drop_receiver() {
1464 let config = Config::default();
1465 let server = Arc::new(WebSocketServer::new(config));
1466
1467 let (handler, sender) = WebSocketHandler::new(server, "test_conn".to_string());
1468
1469 drop(handler);
1471
1472 let msg = WebSocketMessage::Text("Test".to_string());
1474 assert!(sender.send(msg).is_err());
1476 }
1477
#[test]
fn test_websocket_message_variants() {
    // The derived `PartialEq` compares both variant and payload, and a clone
    // compares equal to its source.
    let close = WebSocketMessage::Close {
        code: 1000,
        reason: "Normal".to_string(),
    };
    assert_eq!(close.clone(), close);

    let text = WebSocketMessage::Text("hello".to_string());
    assert_eq!(text, WebSocketMessage::Text("hello".to_string()));
    assert_ne!(text, WebSocketMessage::Text("world".to_string()));

    let binary = WebSocketMessage::Binary(tungstenite::Bytes::from(vec![1, 2, 3]));
    assert_eq!(binary, WebSocketMessage::Binary(tungstenite::Bytes::from(vec![1, 2, 3])));
    assert_ne!(binary, WebSocketMessage::Binary(tungstenite::Bytes::from(vec![4, 5, 6])));

    let ping = WebSocketMessage::Ping(tungstenite::Bytes::from(vec![4, 5, 6]));
    assert_eq!(ping, WebSocketMessage::Ping(tungstenite::Bytes::from(vec![4, 5, 6])));

    let pong = WebSocketMessage::Pong(tungstenite::Bytes::from(vec![7, 8, 9]));
    assert_eq!(pong, WebSocketMessage::Pong(tungstenite::Bytes::from(vec![7, 8, 9])));
}
1501
#[test]
fn test_is_websocket_upgrade_no_headers() {
    // A bare GET with neither `Upgrade` nor `Connection` headers is plain HTTP.
    let plain_get = Request::get("/ws")
        .body(Full::new(Bytes::new()))
        .unwrap();

    let detected = WebSocketServer::is_websocket_upgrade(&plain_get);
    assert!(!detected);
}
1513
#[test]
fn test_is_websocket_upgrade_no_upgrade_header() {
    // `Connection: Upgrade` alone is insufficient; the `Upgrade: websocket`
    // header is also required.
    let request = Request::get("/ws")
        .header("Connection", "Upgrade")
        .body(Full::new(Bytes::new()))
        .unwrap();

    let detected = WebSocketServer::is_websocket_upgrade(&request);
    assert!(!detected);
}
1526
#[test]
fn test_is_websocket_upgrade_no_connection_header() {
    // `Upgrade: websocket` without a `Connection` header naming the upgrade
    // token is rejected.
    let request = Request::get("/ws")
        .header("Upgrade", "websocket")
        .body(Full::new(Bytes::new()))
        .unwrap();

    let detected = WebSocketServer::is_websocket_upgrade(&request);
    assert!(!detected);
}
1539
#[test]
fn test_is_websocket_upgrade_wrong_upgrade_value() {
    // An upgrade to some other protocol (here "http2") is not a WebSocket upgrade.
    let request = Request::get("/ws")
        .header("Upgrade", "http2")
        .header("Connection", "Upgrade")
        .body(Full::new(Bytes::new()))
        .unwrap();

    let detected = WebSocketServer::is_websocket_upgrade(&request);
    assert!(!detected);
}
1553
#[test]
fn test_is_websocket_upgrade_connection_without_upgrade() {
    // A `Connection` header that lacks the upgrade token (just "keep-alive")
    // disqualifies the request even though `Upgrade: websocket` is present.
    let request = Request::get("/ws")
        .header("Upgrade", "websocket")
        .header("Connection", "keep-alive")
        .body(Full::new(Bytes::new()))
        .unwrap();

    let detected = WebSocketServer::is_websocket_upgrade(&request);
    assert!(!detected);
}
1567
#[test]
fn test_handshake_response_body_empty() {
    // A successful handshake answers with 101 Switching Protocols and carries no
    // payload. The body's size hint must therefore report exactly zero bytes.
    // Previously only the status was checked, despite the test's name.
    use hyper::body::Body as _;

    let req = Request::builder()
        .method("GET")
        .uri("/ws")
        .header("Upgrade", "websocket")
        .header("Connection", "Upgrade")
        .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
        .header("Sec-WebSocket-Version", "13")
        .body(Full::new(Bytes::new()))
        .unwrap();

    let response = WebSocketServer::handshake(&req).unwrap();

    assert_eq!(response.status(), hyper::StatusCode::SWITCHING_PROTOCOLS);
    assert_eq!(response.body().size_hint().exact(), Some(0));
}
1586
1587 #[tokio::test]
1588 async fn test_connection_count_after_operations() {
1589 let config = Config::default();
1590 let server = WebSocketServer::new(config);
1591
1592 assert_eq!(server.connection_count().await, 0);
1594
1595 for i in 0..10 {
1597 let (sender, _) = mpsc::unbounded_channel();
1598 server.add_connection(format!("conn{}", i), sender).await;
1599 assert_eq!(server.connection_count().await, i + 1);
1600 }
1601
1602 for i in (0..10).rev() {
1604 server.remove_connection(&format!("conn{}", i)).await;
1605 assert_eq!(server.connection_count().await, i);
1606 }
1607
1608 assert_eq!(server.connection_count().await, 0);
1610 }
1611
1612 #[tokio::test]
1613 async fn test_broadcast_after_remove_all() {
1614 let config = Config::default();
1615 let server = WebSocketServer::new(config);
1616
1617 let (sender1, _receiver1) = mpsc::unbounded_channel();
1619 let (sender2, _receiver2) = mpsc::unbounded_channel();
1620 server.add_connection("conn1".to_string(), sender1).await;
1621 server.add_connection("conn2".to_string(), sender2).await;
1622
1623 server.remove_connection("conn1").await;
1625 server.remove_connection("conn2").await;
1626
1627 let msg = WebSocketMessage::Text("Test".to_string());
1629 server.broadcast(msg).await;
1630
1631 assert_eq!(server.connection_count().await, 0);
1633 }
1634
#[test]
fn test_generate_accept_key_unicode() {
    // Arbitrary (even non-ASCII) keys must still produce a 28-character accept
    // key. That is the base64 length of a 20-byte SHA-1 digest, which RFC 6455
    // prescribes; presumably generate_accept_key follows it, given the sha1 and
    // base64 imports.
    //
    // The fixtures use `\u{..}` escapes so the multi-byte characters cannot be
    // mangled by re-encoding this file. The previous literals had degraded into
    // 2-byte Greek letters, so the "emoji" case exercised no 4-byte sequences.
    let unicode_key = "\u{1F511}\u{1F510}\u{1F5DD}\u{FE0F}\u{5BC6}\u{94A5}";
    let emoji_key = "\u{1F600}\u{1F680}\u{1F389}";

    // Guard the fixtures themselves: they really are multi-byte UTF-8.
    assert!(unicode_key.len() > unicode_key.chars().count());
    assert!(emoji_key.chars().all(|c| c.len_utf8() == 4));

    let unicode_accept = WebSocketServer::generate_accept_key(unicode_key).unwrap();
    assert_eq!(unicode_accept.len(), 28);

    let emoji_accept = WebSocketServer::generate_accept_key(emoji_key).unwrap();
    assert_eq!(emoji_accept.len(), 28);

    // Distinct keys must not collapse to the same accept value, and the
    // derivation must be deterministic.
    assert_ne!(unicode_accept, emoji_accept);
    assert_eq!(
        WebSocketServer::generate_accept_key(emoji_key).unwrap(),
        emoji_accept
    );
}
1647
#[test]
fn test_to_tungstenite_message_preserves_data() {
    // Conversion to tungstenite must preserve text containing multi-byte UTF-8
    // (3-byte CJK, 4-byte emoji) and binary data covering every byte value.
    // Escapes keep the fixture immune to file re-encoding; the previous literal
    // had degraded into 2-byte Greek letters.
    let server = WebSocketServer::new(Config::default());

    let text_data = "Hello, \u{4E16}\u{754C}! \u{1F30D}";
    let text_msg = WebSocketMessage::Text(text_data.to_string());
    let Message::Text(text) = server.to_tungstenite_message(&text_msg) else {
        panic!("Expected Text message");
    };
    assert_eq!(text, text_data);

    let binary_data: Vec<u8> = (0..=u8::MAX).collect();
    let binary_msg = WebSocketMessage::Binary(tungstenite::Bytes::from(binary_data.clone()));
    let Message::Binary(data) = server.to_tungstenite_message(&binary_msg) else {
        panic!("Expected Binary message");
    };
    assert_eq!(data.to_vec(), binary_data);
}
1673
#[test]
fn test_websocket_message_close_edge_codes() {
    // Codes on either side of the close-code range boundaries
    // (0/999/1000, 2999/3000, 3999/4000, 4999) must come out of
    // `handle_message` with their numeric value unchanged.
    let server = WebSocketServer::new(Config::default());
    let edge_codes: [u16; 9] = [0, 999, 1000, 1001, 2999, 3000, 3999, 4000, 4999];

    for &code in edge_codes.iter() {
        let frame = tungstenite::protocol::frame::CloseFrame {
            code: tungstenite::protocol::frame::coding::CloseCode::from(code),
            reason: "Test".into(),
        };

        match server.handle_message(Message::Close(Some(frame))).unwrap() {
            Some(WebSocketMessage::Close { code: got, .. }) => assert_eq!(got, code),
            _ => panic!("Expected Close message with code {}", code),
        }
    }
}
1708
1709 #[tokio::test]
1710 async fn test_send_to_all_message_types() {
1711 let config = Config::default();
1712 let server = WebSocketServer::new(config);
1713
1714 let (sender, mut receiver) = mpsc::unbounded_channel();
1715 server.add_connection("test_conn".to_string(), sender).await;
1716
1717 let test_cases = vec![
1719 WebSocketMessage::Text("Hello".to_string()),
1720 WebSocketMessage::Binary(tungstenite::Bytes::from(vec![1, 2, 3])),
1721 WebSocketMessage::Ping(tungstenite::Bytes::from(vec![4, 5, 6])),
1722 WebSocketMessage::Pong(tungstenite::Bytes::from(vec![7, 8, 9])),
1723 WebSocketMessage::Close {
1724 code: 1000,
1725 reason: "Normal".to_string(),
1726 },
1727 ];
1728
1729 for msg in test_cases {
1730 let msg_clone = msg.clone();
1731 server.send_to("test_conn", msg).await.unwrap();
1732 let received = receiver.recv().await.unwrap();
1733 assert_eq!(received, msg_clone);
1734 }
1735 }
1736}