Skip to main content

rust_pipe/transport/
websocket.rs

1use async_trait::async_trait;
2use futures_util::{SinkExt, StreamExt};
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::net::TcpListener;
6use tokio::sync::{mpsc, RwLock};
7use tokio_tungstenite::accept_async;
8
9use super::{Message, Transport, TransportConfig, TransportError};
10
11type WorkerSender = mpsc::UnboundedSender<Message>;
12
13pub struct WebSocketTransport {
14    config: TransportConfig,
15    workers: Arc<RwLock<HashMap<String, WorkerSender>>>,
16    on_message: Arc<dyn Fn(String, Message) + Send + Sync>,
17}
18
19impl WebSocketTransport {
20    pub fn new(
21        config: TransportConfig,
22        on_message: impl Fn(String, Message) + Send + Sync + 'static,
23    ) -> Self {
24        Self {
25            config,
26            workers: Arc::new(RwLock::new(HashMap::new())),
27            on_message: Arc::new(on_message),
28        }
29    }
30}
31
32#[async_trait]
33impl Transport for WebSocketTransport {
34    async fn start(&self) -> Result<(), TransportError> {
35        let addr = format!("{}:{}", self.config.host, self.config.port);
36        let listener = TcpListener::bind(&addr)
37            .await
38            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
39
40        tracing::info!("rust-pipe transport listening on {}", addr);
41
42        let workers = self.workers.clone();
43        let on_message = self.on_message.clone();
44
45        tokio::spawn(async move {
46            while let Ok((stream, peer_addr)) = listener.accept().await {
47                let workers = workers.clone();
48                let on_message = on_message.clone();
49
50                tokio::spawn(async move {
51                    let ws_stream = match accept_async(stream).await {
52                        Ok(ws) => ws,
53                        Err(e) => {
54                            tracing::error!("WebSocket handshake failed from {}: {}", peer_addr, e);
55                            return;
56                        }
57                    };
58
59                    let (mut ws_sender, mut ws_receiver) = ws_stream.split();
60                    let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
61
62                    let mut worker_id = String::new();
63
64                    // Outbound message forwarder
65                    let send_task = tokio::spawn(async move {
66                        while let Some(msg) = rx.recv().await {
67                            let json = match serde_json::to_string(&msg) {
68                                Ok(j) => j,
69                                Err(e) => {
70                                    tracing::error!(error = %e, "Failed to serialize outbound message");
71                                    continue;
72                                }
73                            };
74                            if ws_sender
75                                .send(tokio_tungstenite::tungstenite::Message::Text(json))
76                                .await
77                                .is_err()
78                            {
79                                break;
80                            }
81                        }
82                    });
83
84                    // Inbound message handler
85                    while let Some(Ok(msg)) = ws_receiver.next().await {
86                        if let tokio_tungstenite::tungstenite::Message::Text(text) = msg {
87                            let message = match serde_json::from_str::<Message>(&text) {
88                                Ok(m) => m,
89                                Err(e) => {
90                                    tracing::warn!(
91                                        error = %e,
92                                        raw = %&text[..text.len().min(200)],
93                                        "Failed to parse inbound message"
94                                    );
95                                    continue;
96                                }
97                            };
98                            if let Message::WorkerRegister {
99                                registration: ref reg,
100                            } = message
101                            {
102                                worker_id = reg.worker_id.clone();
103                                workers.write().await.insert(worker_id.clone(), tx.clone());
104                                tracing::info!(
105                                    worker_id = %worker_id,
106                                    language = ?reg.language,
107                                    tasks = ?reg.supported_tasks,
108                                    "Worker registered"
109                                );
110                            }
111                            on_message(worker_id.clone(), message);
112                        }
113                    }
114
115                    // Cleanup on disconnect
116                    if !worker_id.is_empty() {
117                        workers.write().await.remove(&worker_id);
118                        tracing::info!(worker_id = %worker_id, "Worker disconnected");
119                    }
120                    send_task.abort();
121                });
122            }
123        });
124
125        Ok(())
126    }
127
128    async fn stop(&self) -> Result<(), TransportError> {
129        let workers = self.workers.read().await;
130        for (_, sender) in workers.iter() {
131            let _ = sender.send(Message::Shutdown { graceful: true });
132        }
133        Ok(())
134    }
135
136    async fn send(&self, worker_id: &str, message: Message) -> Result<(), TransportError> {
137        let workers = self.workers.read().await;
138        let sender = workers
139            .get(worker_id)
140            .ok_or_else(|| TransportError::WorkerNotFound(worker_id.to_string()))?;
141
142        sender
143            .send(message)
144            .map_err(|e| TransportError::SendFailed(e.to_string()))
145    }
146
147    async fn broadcast(&self, message: Message) -> Result<(), TransportError> {
148        let workers = self.workers.read().await;
149        for (_, sender) in workers.iter() {
150            let _ = sender.send(message.clone());
151        }
152        Ok(())
153    }
154}