rust_pipe/transport/
websocket.rs1use async_trait::async_trait;
2use futures_util::{SinkExt, StreamExt};
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::net::TcpListener;
6use tokio::sync::{mpsc, RwLock};
7use tokio_tungstenite::accept_async;
8
9use super::{Message, Transport, TransportConfig, TransportError};
10
/// Sending half of the unbounded channel that queues outbound [`Message`]s for a
/// single worker connection; the connection's writer task drains the receiving half.
type WorkerSender = mpsc::UnboundedSender<Message>;
12
/// WebSocket-based [`Transport`] that accepts worker connections and routes
/// [`Message`]s to and from them.
pub struct WebSocketTransport {
    /// Transport settings; `host` and `port` are read when binding the listener.
    config: TransportConfig,
    /// Registry of connected workers, keyed by the worker id taken from the
    /// worker's registration message. Shared with the connection tasks.
    workers: Arc<RwLock<HashMap<String, WorkerSender>>>,
    /// Callback invoked with `(worker_id, message)` for every inbound message that
    /// parses successfully. `worker_id` is empty until the connection has registered.
    on_message: Arc<dyn Fn(String, Message) + Send + Sync>,
}
18
19impl WebSocketTransport {
20 pub fn new(
21 config: TransportConfig,
22 on_message: impl Fn(String, Message) + Send + Sync + 'static,
23 ) -> Self {
24 Self {
25 config,
26 workers: Arc::new(RwLock::new(HashMap::new())),
27 on_message: Arc::new(on_message),
28 }
29 }
30}
31
32#[async_trait]
33impl Transport for WebSocketTransport {
34 async fn start(&self) -> Result<(), TransportError> {
35 let addr = format!("{}:{}", self.config.host, self.config.port);
36 let listener = TcpListener::bind(&addr)
37 .await
38 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
39
40 tracing::info!("rust-pipe transport listening on {}", addr);
41
42 let workers = self.workers.clone();
43 let on_message = self.on_message.clone();
44
45 tokio::spawn(async move {
46 while let Ok((stream, peer_addr)) = listener.accept().await {
47 let workers = workers.clone();
48 let on_message = on_message.clone();
49
50 tokio::spawn(async move {
51 let ws_stream = match accept_async(stream).await {
52 Ok(ws) => ws,
53 Err(e) => {
54 tracing::error!("WebSocket handshake failed from {}: {}", peer_addr, e);
55 return;
56 }
57 };
58
59 let (mut ws_sender, mut ws_receiver) = ws_stream.split();
60 let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
61
62 let mut worker_id = String::new();
63
64 let send_task = tokio::spawn(async move {
66 while let Some(msg) = rx.recv().await {
67 let json = match serde_json::to_string(&msg) {
68 Ok(j) => j,
69 Err(e) => {
70 tracing::error!(error = %e, "Failed to serialize outbound message");
71 continue;
72 }
73 };
74 if ws_sender
75 .send(tokio_tungstenite::tungstenite::Message::Text(json))
76 .await
77 .is_err()
78 {
79 break;
80 }
81 }
82 });
83
84 while let Some(Ok(msg)) = ws_receiver.next().await {
86 if let tokio_tungstenite::tungstenite::Message::Text(text) = msg {
87 let message = match serde_json::from_str::<Message>(&text) {
88 Ok(m) => m,
89 Err(e) => {
90 tracing::warn!(
91 error = %e,
92 raw = %&text[..text.len().min(200)],
93 "Failed to parse inbound message"
94 );
95 continue;
96 }
97 };
98 if let Message::WorkerRegister {
99 registration: ref reg,
100 } = message
101 {
102 worker_id = reg.worker_id.clone();
103 workers.write().await.insert(worker_id.clone(), tx.clone());
104 tracing::info!(
105 worker_id = %worker_id,
106 language = ?reg.language,
107 tasks = ?reg.supported_tasks,
108 "Worker registered"
109 );
110 }
111 on_message(worker_id.clone(), message);
112 }
113 }
114
115 if !worker_id.is_empty() {
117 workers.write().await.remove(&worker_id);
118 tracing::info!(worker_id = %worker_id, "Worker disconnected");
119 }
120 send_task.abort();
121 });
122 }
123 });
124
125 Ok(())
126 }
127
128 async fn stop(&self) -> Result<(), TransportError> {
129 let workers = self.workers.read().await;
130 for (_, sender) in workers.iter() {
131 let _ = sender.send(Message::Shutdown { graceful: true });
132 }
133 Ok(())
134 }
135
136 async fn send(&self, worker_id: &str, message: Message) -> Result<(), TransportError> {
137 let workers = self.workers.read().await;
138 let sender = workers
139 .get(worker_id)
140 .ok_or_else(|| TransportError::WorkerNotFound(worker_id.to_string()))?;
141
142 sender
143 .send(message)
144 .map_err(|e| TransportError::SendFailed(e.to_string()))
145 }
146
147 async fn broadcast(&self, message: Message) -> Result<(), TransportError> {
148 let workers = self.workers.read().await;
149 for (_, sender) in workers.iter() {
150 let _ = sender.send(message.clone());
151 }
152 Ok(())
153 }
154}