1use anyhow::Result;
4use futures_util::{SinkExt, StreamExt};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::net::SocketAddr;
8use std::sync::Arc;
9use tokio::net::{TcpListener, TcpStream};
10use tokio::sync::{broadcast, RwLock};
11use tokio_tungstenite::{accept_async, tungstenite::Message};
12use uuid::Uuid;
13
14use crate::progress::{ProgressManager, ProgressUpdate};
15
/// JSON message exchanged between the server and WebSocket clients.
///
/// Internally tagged: the variant name is written to a `"type"` field of the
/// JSON object (for example `Ping` is encoded as `{"type":"Ping"}`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum WebSocketMessage {
    /// Client-to-server request to follow an operation. The connection handler
    /// in this module records nothing when `operation_id` is `None`.
    Subscribe { operation_id: Option<Uuid> },
    /// Client-to-server request to stop following an operation, or every
    /// operation the client follows when `operation_id` is `None`.
    Unsubscribe { operation_id: Option<Uuid> },
    /// Carries a [`ProgressUpdate`]. If a client sends it, the connection
    /// handler only logs it as unexpected.
    ProgressUpdate(ProgressUpdate),
    /// Carries an error description.
    Error { message: String },
    /// Client-to-server keep-alive; the connection handler replies with `Pong`.
    Ping,
    /// Server-to-client reply to `Ping`.
    Pong,
}
33
/// Server-side state for one WebSocket connection.
#[derive(Debug)]
pub struct WebSocketClient {
    /// Random id assigned in [`WebSocketClient::new`].
    id: Uuid,
    // Handle to the progress manager's channel. Nothing in this module reads it
    // yet, hence the `allow`. NOTE(review): wire it up or remove it.
    #[allow(dead_code)]
    sender: crossbeam_channel::Sender<ProgressUpdate>,
    /// Operation ids this connection has subscribed to.
    subscriptions: Arc<RwLock<Vec<Uuid>>>,
}
42
43impl WebSocketClient {
44 #[must_use]
46 pub fn new(sender: crossbeam_channel::Sender<ProgressUpdate>) -> Self {
47 Self {
48 id: Uuid::new_v4(),
49 sender,
50 subscriptions: Arc::new(RwLock::new(Vec::new())),
51 }
52 }
53
54 pub async fn handle_connection(&self, stream: TcpStream, addr: SocketAddr) -> Result<()> {
59 let ws_stream = accept_async(stream).await?;
60 let (ws_sender, mut ws_receiver) = ws_stream.split();
61
62 let subscriptions = self.subscriptions.clone();
63 let client_id = self.id;
64
65 log::info!("New WebSocket connection from {addr}");
66
67 let subscriptions_clone = subscriptions.clone();
69 let ws_sender = Arc::new(tokio::sync::Mutex::new(ws_sender));
70
71 tokio::spawn(async move {
72 while let Some(msg) = ws_receiver.next().await {
73 match msg {
74 Ok(Message::Text(text)) => {
75 if let Ok(ws_msg) = serde_json::from_str::<WebSocketMessage>(&text) {
76 match ws_msg {
77 WebSocketMessage::Subscribe { operation_id } => {
78 let mut subs = subscriptions_clone.write().await;
79 if let Some(op_id) = operation_id {
80 if !subs.contains(&op_id) {
81 subs.push(op_id);
82 }
83 }
84 log::debug!("Client {client_id} subscribed to operation {operation_id:?}");
85 }
86 WebSocketMessage::Unsubscribe { operation_id } => {
87 let mut subs = subscriptions_clone.write().await;
88 if let Some(op_id) = operation_id {
89 subs.retain(|&id| id != op_id);
90 } else {
91 subs.clear();
92 }
93 log::debug!("Client {client_id} unsubscribed from operation {operation_id:?}");
94 }
95 WebSocketMessage::Ping => {
96 let pong = WebSocketMessage::Pong;
98 if let Ok(pong_text) = serde_json::to_string(&pong) {
99 let mut sender = ws_sender.lock().await;
100 let _ = sender.send(Message::Text(pong_text)).await;
101 }
102 }
103 _ => {
104 log::warn!(
105 "Client {client_id} sent unexpected message: {ws_msg:?}"
106 );
107 }
108 }
109 } else {
110 log::warn!("Client {client_id} sent invalid JSON: {text}");
111 }
112 }
113 Ok(Message::Close(_)) => {
114 log::info!("Client {client_id} disconnected");
115 break;
116 }
117 Ok(Message::Ping(data)) => {
118 let mut sender = ws_sender.lock().await;
119 if let Err(e) = sender.send(Message::Pong(data)).await {
120 log::error!("Failed to send pong to client {client_id}: {e}");
121 break;
122 }
123 }
124 Err(e) => {
125 log::error!("WebSocket error for client {client_id}: {e}");
126 break;
127 }
128 _ => {}
129 }
130 }
131 });
132
133 Ok(())
134 }
135}
136
137#[derive(Debug)]
139pub struct WebSocketServer {
140 progress_manager: Arc<ProgressManager>,
141 clients: Arc<RwLock<HashMap<Uuid, WebSocketClient>>>,
142 port: u16,
143}
144
145impl WebSocketServer {
146 #[must_use]
148 pub fn new(port: u16) -> Self {
149 Self {
150 progress_manager: Arc::new(ProgressManager::new()),
151 clients: Arc::new(RwLock::new(HashMap::new())),
152 port,
153 }
154 }
155
156 #[must_use]
158 pub fn progress_manager(&self) -> Arc<ProgressManager> {
159 self.progress_manager.clone()
160 }
161
162 pub async fn start(&self) -> Result<()> {
167 let addr = format!("127.0.0.1:{}", self.port);
168 let listener = TcpListener::bind(&addr).await?;
169
170 log::info!("WebSocket server listening on {addr}");
171
172 let progress_manager = self.progress_manager.clone();
174 tokio::spawn(async move {
175 let _ = progress_manager.run();
176 });
177
178 let clients = self.clients.clone();
179 let progress_sender = self.progress_manager.sender();
180
181 while let Ok((stream, addr)) = listener.accept().await {
182 let client = WebSocketClient::new(progress_sender.clone());
183 let client_id = client.id;
184
185 {
187 let mut clients = clients.write().await;
188 clients.insert(client_id, client);
189 }
190
191 let clients_clone = clients.clone();
193 tokio::spawn(async move {
194 if let Some(client) = clients_clone.read().await.get(&client_id) {
195 if let Err(e) = client.handle_connection(stream, addr).await {
196 log::error!("Error handling WebSocket connection from {addr}: {e}");
197 }
198 }
199
200 clients_clone.write().await.remove(&client_id);
202 });
203 }
204
205 Ok(())
206 }
207
208 pub async fn client_count(&self) -> usize {
210 self.clients.read().await.len()
211 }
212
213 pub async fn broadcast(&self, message: WebSocketMessage) -> Result<()> {
218 let clients = self.clients.read().await;
219 let _message_text = serde_json::to_string(&message)?;
220
221 for client in clients.values() {
222 log::debug!("Broadcasting message to client {}", client.id);
225 }
226
227 Ok(())
228 }
229}
230
/// In-process fan-out of [`ProgressUpdate`]s over a `tokio` broadcast channel.
#[derive(Debug)]
pub struct WebSocketClientConnection {
    sender: broadcast::Sender<ProgressUpdate>,
    // Never read directly. Holding one receiver keeps `send_update` from failing
    // when nobody has subscribed yet: `broadcast::Sender::send` returns an error
    // when there are zero receivers.
    #[allow(dead_code)]
    receiver: broadcast::Receiver<ProgressUpdate>,
}
238
239impl Default for WebSocketClientConnection {
240 fn default() -> Self {
241 Self::new()
242 }
243}
244
245impl WebSocketClientConnection {
246 #[must_use]
248 pub fn new() -> Self {
249 let (sender, receiver) = broadcast::channel(1000);
250 Self { sender, receiver }
251 }
252
253 #[must_use]
255 pub fn subscribe(&self) -> broadcast::Receiver<ProgressUpdate> {
256 self.sender.subscribe()
257 }
258
259 pub fn send_update(&self, update: ProgressUpdate) -> Result<()> {
264 self.sender.send(update)?;
265 Ok(())
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration as StdDuration;

    /// Builds an in-progress update with placeholder values; tests override the
    /// fields they care about with struct-update syntax.
    fn sample_update(name: &str) -> ProgressUpdate {
        ProgressUpdate {
            operation_id: Uuid::new_v4(),
            operation_name: name.to_string(),
            current: 0,
            total: Some(100),
            message: Some("test message".to_string()),
            timestamp: chrono::Utc::now(),
            status: crate::progress::ProgressStatus::InProgress,
        }
    }

    /// One instance of every `WebSocketMessage` variant.
    fn all_message_variants() -> Vec<WebSocketMessage> {
        vec![
            WebSocketMessage::Subscribe {
                operation_id: Some(Uuid::new_v4()),
            },
            WebSocketMessage::Unsubscribe {
                operation_id: Some(Uuid::new_v4()),
            },
            WebSocketMessage::Ping,
            WebSocketMessage::Pong,
            WebSocketMessage::ProgressUpdate(sample_update("test")),
            WebSocketMessage::Error {
                message: "test error".to_string(),
            },
        ]
    }

    /// Creates a client; the receiving end of its progress channel is dropped.
    fn new_client() -> WebSocketClient {
        let (sender, _receiver) = crossbeam_channel::unbounded();
        WebSocketClient::new(sender)
    }

    /// Asserts that `message` survives a JSON round trip unchanged.
    fn assert_roundtrip(message: &WebSocketMessage) {
        let json = serde_json::to_string(message).unwrap();
        let deserialized: WebSocketMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(*message, deserialized);
    }

    /// Waits up to 100 ms for the next update on `receiver`.
    async fn recv_next(receiver: &mut broadcast::Receiver<ProgressUpdate>) -> ProgressUpdate {
        tokio::time::timeout(StdDuration::from_millis(100), receiver.recv())
            .await
            .expect("timed out waiting for a progress update")
            .expect("broadcast channel closed or lagged")
    }

    #[test]
    fn test_websocket_message_serialization() {
        let msg = WebSocketMessage::Subscribe {
            operation_id: Some(Uuid::new_v4()),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: WebSocketMessage = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            deserialized,
            WebSocketMessage::Subscribe {
                operation_id: Some(_)
            }
        ));
    }

    #[test]
    fn test_websocket_message_wire_format() {
        // The `type` tag is part of the client protocol; pin its exact shape.
        assert_eq!(
            serde_json::to_string(&WebSocketMessage::Ping).unwrap(),
            r#"{"type":"Ping"}"#
        );
        assert_eq!(
            serde_json::to_string(&WebSocketMessage::Error {
                message: "boom".to_string()
            })
            .unwrap(),
            r#"{"type":"Error","message":"boom"}"#
        );
        let parsed: WebSocketMessage =
            serde_json::from_str(r#"{"type":"Unsubscribe","operation_id":null}"#).unwrap();
        assert_eq!(parsed, WebSocketMessage::Unsubscribe { operation_id: None });
    }

    #[test]
    fn test_websocket_client_creation() {
        assert!(!new_client().id.is_nil());
    }

    #[test]
    fn test_websocket_server_creation() {
        let server = WebSocketServer::new(8080);
        assert_eq!(server.port, 8080);
    }

    #[tokio::test]
    async fn test_websocket_client_connection() {
        let connection = WebSocketClientConnection::new();
        let mut receiver = connection.subscribe();
        let update = ProgressUpdate {
            current: 10,
            ..sample_update("test")
        };

        connection.send_update(update.clone()).unwrap();

        let received = recv_next(&mut receiver).await;
        assert_eq!(received.operation_name, update.operation_name);
    }

    #[tokio::test]
    async fn test_websocket_server_creation_with_port() {
        let server = WebSocketServer::new(8080);
        assert_eq!(server.port, 8080);
    }

    #[tokio::test]
    async fn test_websocket_server_progress_manager() {
        let server = WebSocketServer::new(8080);
        // Every call hands out the same shared manager.
        assert!(Arc::ptr_eq(
            &server.progress_manager(),
            &server.progress_manager()
        ));
    }

    #[tokio::test]
    async fn test_websocket_client_creation_async() {
        assert!(!new_client().id.is_nil());
    }

    #[tokio::test]
    async fn test_websocket_client_connection_default() {
        let connection = WebSocketClientConnection::default();
        assert!(connection.send_update(sample_update("default")).is_ok());
    }

    #[tokio::test]
    async fn test_websocket_client_connection_subscribe() {
        let connection = WebSocketClientConnection::new();
        // The connection holds one receiver of its own.
        assert_eq!(connection.sender.receiver_count(), 1);
        let _receiver = connection.subscribe();
        assert_eq!(connection.sender.receiver_count(), 2);
    }

    #[tokio::test]
    async fn test_websocket_client_connection_send_update() {
        let connection = WebSocketClientConnection::new();
        // No external subscriber exists; sending must still succeed.
        let result = connection.send_update(ProgressUpdate {
            current: 50,
            ..sample_update("test")
        });
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_websocket_message_serialization_async() {
        assert_roundtrip(&WebSocketMessage::Subscribe {
            operation_id: Some(Uuid::new_v4()),
        });
    }

    #[tokio::test]
    // `ping_*` / `pong_*` are deliberately parallel names.
    #[allow(clippy::similar_names)]
    async fn test_websocket_message_ping_pong() {
        let ping_json = serde_json::to_string(&WebSocketMessage::Ping).unwrap();
        let pong_json = serde_json::to_string(&WebSocketMessage::Pong).unwrap();

        let ping_deserialized: WebSocketMessage = serde_json::from_str(&ping_json).unwrap();
        let pong_deserialized: WebSocketMessage = serde_json::from_str(&pong_json).unwrap();

        assert!(matches!(ping_deserialized, WebSocketMessage::Ping));
        assert!(matches!(pong_deserialized, WebSocketMessage::Pong));
    }

    #[tokio::test]
    async fn test_websocket_message_unsubscribe() {
        assert_roundtrip(&WebSocketMessage::Unsubscribe {
            operation_id: Some(Uuid::new_v4()),
        });
    }

    #[tokio::test]
    async fn test_websocket_message_progress_update() {
        let update = ProgressUpdate {
            current: 75,
            message: Some("Almost done".to_string()),
            ..sample_update("test_operation")
        };
        let json =
            serde_json::to_string(&WebSocketMessage::ProgressUpdate(update.clone())).unwrap();
        let deserialized: WebSocketMessage = serde_json::from_str(&json).unwrap();

        match deserialized {
            WebSocketMessage::ProgressUpdate(deserialized_update) => {
                assert_eq!(update.operation_id, deserialized_update.operation_id);
                assert_eq!(update.operation_name, deserialized_update.operation_name);
                assert_eq!(update.current, deserialized_update.current);
            }
            other => panic!("Expected ProgressUpdate message, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_websocket_message_error() {
        let message = WebSocketMessage::Error {
            message: "Test error".to_string(),
        };
        let json = serde_json::to_string(&message).unwrap();
        let deserialized: WebSocketMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(
            deserialized,
            WebSocketMessage::Error {
                message: "Test error".to_string()
            }
        );
    }

    #[tokio::test]
    async fn test_websocket_client_connection_multiple_updates() {
        let connection = WebSocketClientConnection::new();
        let mut receiver = connection.subscribe();

        for i in 0..5 {
            connection
                .send_update(ProgressUpdate {
                    operation_name: format!("test_{i}"),
                    current: i * 20,
                    message: Some(format!("Update {i}")),
                    ..sample_update("")
                })
                .unwrap();
        }

        // Updates arrive in the order they were sent.
        for i in 0..5 {
            let received = recv_next(&mut receiver).await;
            assert_eq!(received.operation_name, format!("test_{i}"));
        }
    }

    #[tokio::test]
    async fn test_websocket_client_connection_timeout() {
        let connection = WebSocketClientConnection::new();
        let mut receiver = connection.subscribe();

        // Nothing was sent, so the receive must time out.
        let result = tokio::time::timeout(StdDuration::from_millis(50), receiver.recv()).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_websocket_server_start() {
        let server = WebSocketServer::new(8080);
        assert_eq!(server.port, 8080);
        // Constructing the server binds nothing and tracks no clients.
        assert_eq!(server.client_count().await, 0);
    }

    #[tokio::test]
    async fn test_websocket_server_broadcast() {
        let server = WebSocketServer::new(8080);
        let update = ProgressUpdate {
            current: 50,
            message: Some("Test message".to_string()),
            ..sample_update("test_operation")
        };

        let result = server
            .broadcast(WebSocketMessage::ProgressUpdate(update))
            .await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_websocket_message_debug() {
        let debug_str = format!("{:?}", WebSocketMessage::Ping);
        assert!(debug_str.contains("Ping"));
    }

    #[test]
    fn test_websocket_message_clone() {
        let message = WebSocketMessage::Ping;
        let cloned = message.clone();
        assert_eq!(message, cloned);
    }

    #[test]
    fn test_websocket_message_partial_eq() {
        assert_eq!(WebSocketMessage::Ping, WebSocketMessage::Ping);
        assert_ne!(WebSocketMessage::Ping, WebSocketMessage::Pong);
    }

    #[test]
    fn test_websocket_client_debug() {
        let client = new_client();
        assert!(format!("{client:?}").contains("WebSocketClient"));
    }

    #[test]
    fn test_websocket_client_connection_debug() {
        let connection = WebSocketClientConnection::new();
        assert!(format!("{connection:?}").contains("WebSocketClientConnection"));
    }

    #[test]
    fn test_websocket_server_debug() {
        let server = WebSocketServer::new(8080);
        assert!(format!("{server:?}").contains("WebSocketServer"));
    }

    #[test]
    fn test_websocket_message_subscribe_serialization() {
        assert_roundtrip(&WebSocketMessage::Subscribe {
            operation_id: Some(Uuid::new_v4()),
        });
    }

    #[test]
    fn test_websocket_message_unsubscribe_serialization() {
        assert_roundtrip(&WebSocketMessage::Unsubscribe {
            operation_id: Some(Uuid::new_v4()),
        });
    }

    #[test]
    fn test_websocket_message_progress_update_serialization() {
        assert_roundtrip(&WebSocketMessage::ProgressUpdate(ProgressUpdate {
            current: 50,
            message: Some("Test message".to_string()),
            ..sample_update("test_operation")
        }));
    }

    #[test]
    fn test_websocket_message_error_serialization() {
        assert_roundtrip(&WebSocketMessage::Error {
            message: "Test error".to_string(),
        });
    }

    #[tokio::test]
    async fn test_websocket_server_multiple_broadcasts() {
        let server = WebSocketServer::new(8080);

        for (name, current) in [("operation1", 25), ("operation2", 50)] {
            let update = ProgressUpdate {
                current,
                ..sample_update(name)
            };
            let result = server
                .broadcast(WebSocketMessage::ProgressUpdate(update))
                .await;
            assert!(result.is_ok());
        }
    }

    #[test]
    fn test_websocket_server_port_access() {
        let server = WebSocketServer::new(8080);
        assert_eq!(server.port, 8080);
    }

    #[test]
    fn test_websocket_client_id_generation() {
        let client1 = new_client();
        let client2 = new_client();

        assert_ne!(client1.id, client2.id);
        assert!(!client1.id.is_nil());
        assert!(!client2.id.is_nil());
    }

    #[tokio::test]
    async fn test_websocket_message_roundtrip_all_types() {
        for message in all_message_variants() {
            assert_roundtrip(&message);
        }
    }

    #[tokio::test]
    async fn test_websocket_server_client_count() {
        let server = WebSocketServer::new(8080);
        assert_eq!(server.client_count().await, 0);
    }

    #[tokio::test]
    async fn test_websocket_server_broadcast_error_handling() {
        let server = WebSocketServer::new(8080);
        assert!(server.broadcast(WebSocketMessage::Ping).await.is_ok());
    }

    #[tokio::test]
    async fn test_websocket_server_creation_with_different_ports() {
        let server1 = WebSocketServer::new(8080);
        let server2 = WebSocketServer::new(8081);

        assert_eq!(server1.port, 8080);
        assert_eq!(server2.port, 8081);
    }

    #[tokio::test]
    async fn test_websocket_server_progress_manager_access() {
        let server = WebSocketServer::new(8080);
        let progress_manager = server.progress_manager();
        // One reference is held by the server and one by this test.
        assert_eq!(Arc::strong_count(&progress_manager), 2);
    }

    #[tokio::test]
    async fn test_websocket_client_creation_with_sender() {
        assert!(!new_client().id.is_nil());
    }

    #[tokio::test]
    async fn test_websocket_client_connection_creation() {
        let connection = WebSocketClientConnection::new();
        // A fresh connection only has its own internal receiver.
        assert_eq!(connection.sender.receiver_count(), 1);
    }

    #[tokio::test]
    async fn test_websocket_message_error_creation() {
        let error_msg = WebSocketMessage::Error {
            message: "Test error".to_string(),
        };

        match error_msg {
            WebSocketMessage::Error { message: msg } => assert_eq!(msg, "Test error"),
            other => panic!("Expected Error variant, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_websocket_message_progress_update_creation() {
        let update = ProgressUpdate {
            current: 5,
            total: Some(10),
            message: Some("Test".to_string()),
            ..sample_update("test")
        };

        match WebSocketMessage::ProgressUpdate(update) {
            WebSocketMessage::ProgressUpdate(update) => {
                assert_eq!(update.operation_name, "test");
                assert_eq!(update.current, 5);
                assert_eq!(update.total, Some(10));
            }
            other => panic!("Expected ProgressUpdate variant, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_websocket_message_subscribe_creation() {
        let operation_id = Some(Uuid::new_v4());

        match (WebSocketMessage::Subscribe { operation_id }) {
            WebSocketMessage::Subscribe { operation_id: id } => assert_eq!(id, operation_id),
            other => panic!("Expected Subscribe variant, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_websocket_message_unsubscribe_creation() {
        let operation_id = Some(Uuid::new_v4());

        match (WebSocketMessage::Unsubscribe { operation_id }) {
            WebSocketMessage::Unsubscribe { operation_id: id } => assert_eq!(id, operation_id),
            other => panic!("Expected Unsubscribe variant, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_websocket_message_serialization_all_variants() {
        all_message_variants().iter().for_each(assert_roundtrip);
    }

    #[tokio::test]
    async fn test_websocket_server_client_count_multiple_clients() {
        let server = WebSocketServer::new(8080);
        // Repeated reads on an idle server are stable.
        assert_eq!(server.client_count().await, 0);
        assert_eq!(server.client_count().await, 0);
    }

    #[tokio::test]
    async fn test_websocket_server_broadcast_different_message_types() {
        let server = WebSocketServer::new(8080);

        for message in all_message_variants() {
            assert!(server.broadcast(message).await.is_ok());
        }
    }

    #[tokio::test]
    async fn test_websocket_client_connection_receive_update() {
        let connection = WebSocketClientConnection::new();
        // Subscribe *before* sending: a broadcast receiver only observes values
        // sent after it was created.
        let mut receiver = connection.subscribe();
        let update = ProgressUpdate {
            current: 5,
            total: Some(10),
            message: Some("Test".to_string()),
            ..sample_update("test")
        };

        connection.send_update(update.clone()).unwrap();

        let received = recv_next(&mut receiver).await;
        assert_eq!(received.operation_name, update.operation_name);
        assert_eq!(received.current, update.current);
        assert_eq!(received.total, update.total);
    }

    #[tokio::test]
    async fn test_websocket_server_handle_connection_error_handling() {
        // A peer that closes the TCP connection without sending a handshake
        // request must make `handle_connection` fail rather than hang.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let peer = TcpStream::connect(listener.local_addr().unwrap())
            .await
            .unwrap();
        let (stream, addr) = listener.accept().await.unwrap();
        drop(peer);

        let client = new_client();
        let result = tokio::time::timeout(
            StdDuration::from_secs(5),
            client.handle_connection(stream, addr),
        )
        .await
        .expect("handle_connection should fail promptly");
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_websocket_server_start_error_handling() {
        // Occupy a port so that `start` cannot bind it.
        let occupied = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let port = occupied.local_addr().unwrap().port();
        let server = WebSocketServer::new(port);

        let result = tokio::time::timeout(StdDuration::from_secs(5), server.start())
            .await
            .expect("start should fail immediately when the port is taken");
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_websocket_message_debug_formatting() {
        assert!(format!("{:?}", WebSocketMessage::Ping).contains("Ping"));
    }

    #[tokio::test]
    async fn test_websocket_server_debug_formatting() {
        let server = WebSocketServer::new(8080);
        assert!(format!("{server:?}").contains("8080"));
    }

    #[tokio::test]
    async fn test_websocket_client_debug_formatting() {
        let client = new_client();
        assert!(format!("{client:?}").contains("WebSocketClient"));
    }

    #[tokio::test]
    async fn test_websocket_client_connection_debug_formatting() {
        let connection = WebSocketClientConnection::new();
        assert!(format!("{connection:?}").contains("WebSocketClientConnection"));
    }

    #[tokio::test]
    async fn test_websocket_server_multiple_ports() {
        for port in [8080, 8081, 8082] {
            assert_eq!(WebSocketServer::new(port).port, port);
        }
    }

    #[tokio::test]
    async fn test_websocket_server_port_edge_cases() {
        assert_eq!(WebSocketServer::new(1).port, 1);
        assert_eq!(WebSocketServer::new(65535).port, 65535);
    }

    #[tokio::test]
    async fn test_websocket_message_all_variants() {
        // All variants reference the same operation id.
        let operation_id = Uuid::new_v4();
        let messages = vec![
            WebSocketMessage::Subscribe {
                operation_id: Some(operation_id),
            },
            WebSocketMessage::Unsubscribe {
                operation_id: Some(operation_id),
            },
            WebSocketMessage::Ping,
            WebSocketMessage::Pong,
            WebSocketMessage::ProgressUpdate(ProgressUpdate {
                operation_id,
                current: 50,
                message: Some("Testing".to_string()),
                ..sample_update("test")
            }),
            WebSocketMessage::Error {
                message: "Test error".to_string(),
            },
        ];

        for message in &messages {
            assert_roundtrip(message);
        }
    }

    #[tokio::test]
    async fn test_websocket_message_serialization_edge_cases() {
        assert_roundtrip(&WebSocketMessage::Subscribe { operation_id: None });
        assert_roundtrip(&WebSocketMessage::Error {
            message: String::new(),
        });
    }

    #[tokio::test]
    async fn test_websocket_client_id_uniqueness() {
        assert_ne!(new_client().id, new_client().id);
    }

    #[tokio::test]
    async fn test_websocket_client_connection_subscription() {
        let connection = WebSocketClientConnection::new();
        let mut subscriber = connection.subscribe();

        // A new subscriber starts with nothing queued.
        assert!(matches!(
            subscriber.try_recv(),
            Err(broadcast::error::TryRecvError::Empty)
        ));
    }

    #[tokio::test]
    async fn test_websocket_server_error_handling() {
        let error_msg = WebSocketMessage::Error {
            message: "Test error".to_string(),
        };

        match error_msg {
            WebSocketMessage::Error { message } => assert_eq!(message, "Test error"),
            other => panic!("Expected Error variant, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_websocket_server_connection_handling() {
        // Every subscriber receives its own copy of each update.
        let connection = WebSocketClientConnection::new();
        let mut first = connection.subscribe();
        let mut second = connection.subscribe();
        let update = sample_update("fan_out");

        connection.send_update(update.clone()).unwrap();

        assert_eq!(recv_next(&mut first).await.operation_id, update.operation_id);
        assert_eq!(recv_next(&mut second).await.operation_id, update.operation_id);
    }

    #[tokio::test]
    async fn test_websocket_server_message_serialization_edge_cases() {
        assert_roundtrip(&WebSocketMessage::Subscribe { operation_id: None });
        assert_roundtrip(&WebSocketMessage::Ping);
    }

    #[tokio::test]
    async fn test_websocket_server_large_messages() {
        assert_roundtrip(&WebSocketMessage::Error {
            message: "x".repeat(64 * 1024),
        });
    }

    #[tokio::test]
    async fn test_websocket_server_concurrent_operations() {
        let server = Arc::new(WebSocketServer::new(8080));
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let server = Arc::clone(&server);
                tokio::spawn(async move { server.broadcast(WebSocketMessage::Ping).await })
            })
            .collect();

        for handle in handles {
            assert!(handle.await.unwrap().is_ok());
        }
    }

    #[tokio::test]
    async fn test_websocket_server_message_roundtrip_all_variants() {
        let mut variants = all_message_variants();
        variants.push(WebSocketMessage::Unsubscribe { operation_id: None });

        for variant in &variants {
            assert_roundtrip(variant);
        }
    }

    #[tokio::test]
    async fn test_websocket_server_edge_cases() {
        assert_roundtrip(&WebSocketMessage::Ping);
        // Characters that JSON must escape.
        assert_roundtrip(&WebSocketMessage::Error {
            message: "quote \" backslash \\ newline \n tab \t nul \u{0}".to_string(),
        });
    }

    #[tokio::test]
    async fn test_websocket_server_performance() {
        // Coarse smoke check only; serializing 1000 pings should be far below this bound.
        let start = std::time::Instant::now();

        for _ in 0..1000 {
            let _json = serde_json::to_string(&WebSocketMessage::Ping).unwrap();
        }

        assert!(start.elapsed().as_millis() < 1000);
    }

    #[tokio::test]
    async fn test_websocket_server_memory_usage() {
        let messages = vec![WebSocketMessage::Ping; 100];
        assert_eq!(messages.len(), 100);

        for message in &messages {
            let _json = serde_json::to_string(message).unwrap();
        }
    }

    #[tokio::test]
    async fn test_websocket_server_error_recovery() {
        assert!(serde_json::from_str::<WebSocketMessage>(r#"{"invalid": "json"}"#).is_err());
        assert!(serde_json::from_str::<WebSocketMessage>("{}").is_err());
        assert!(serde_json::from_str::<WebSocketMessage>(r#"{"type":"Unknown"}"#).is_err());
        assert!(serde_json::from_str::<WebSocketMessage>("not json").is_err());
    }

    #[tokio::test]
    async fn test_websocket_server_unicode_handling() {
        assert_roundtrip(&WebSocketMessage::Error {
            message: "\u{8fdb}\u{5ea6} \u{2713} \u{fc}n\u{ef}c\u{f6}d\u{e9} \u{1f680}".to_string(),
        });
        assert_roundtrip(&WebSocketMessage::ProgressUpdate(ProgressUpdate {
            message: Some("\u{417}\u{430}\u{433}\u{440}\u{443}\u{437}\u{43a}\u{430}\u{2026}".to_string()),
            ..sample_update("\u{64d}\u{4f5c}")
        }));
    }

    #[tokio::test]
    async fn test_websocket_server_nested_data() {
        let update = ProgressUpdate {
            current: 3,
            total: Some(7),
            message: Some("nested".to_string()),
            ..sample_update("nested_op")
        };
        let json =
            serde_json::to_string(&WebSocketMessage::ProgressUpdate(update.clone())).unwrap();
        assert!(json.contains(r#""type":"ProgressUpdate""#));

        let deserialized: WebSocketMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, WebSocketMessage::ProgressUpdate(update));
    }
}