Skip to main content

rust_web_server/ws_proxy/
mod.rs

1//! WebSocket reverse proxy.
2//!
//! [`WsProxy`] listens for incoming TCP connections, reads the initial HTTP
//! request, verifies it is a WebSocket upgrade, and forwards that request to a
//! backend. Once the backend answers `101 Switching Protocols`, the proxy sends
//! the client its own `101` response (computed from the client's request) and
//! then bidirectionally tunnels raw WebSocket bytes between the client and the
//! backend.
7//!
8//! Two threads handle each live connection (one per direction), so neither side
9//! is blocked waiting for the other.
10//!
11//! # Example
12//!
13//! ```rust,no_run
14//! use rust_web_server::ws_proxy::WsProxy;
15//!
16//! // All WebSocket connections on port 8080 are forwarded to a chat backend.
17//! WsProxy::new(["chat-backend:9000", "chat-backend:9001"])
18//!     .connect_timeout_ms(3000)
19//!     .bind("0.0.0.0:8080")
20//!     .unwrap();
21//! ```
22
23use std::io::{Read, Write};
24use std::net::{TcpListener, TcpStream, ToSocketAddrs};
25use std::sync::{
26    Arc,
27    atomic::{AtomicUsize, Ordering},
28};
29use std::time::Duration;
30
31use crate::request::Request;
32use crate::websocket::WebSocket;
33
/// WebSocket reverse proxy with round-robin load balancing.
///
/// Accepts plain HTTP/1.1 WebSocket upgrade requests, forwards the upgrade
/// request to a backend, and relays all subsequent frames bidirectionally.
///
/// For TLS-terminated WebSocket proxying, place a TLS terminator in front (e.g.
/// another rws instance with TLS configured) and point it at this proxy.
///
/// Call [`WsProxy::bind`] to start. It blocks the calling thread indefinitely.
#[derive(Debug)]
pub struct WsProxy {
    /// Backend addresses in `"host:port"` form, selected in round-robin order.
    backends: Vec<String>,
    /// Round-robin cursor: each new connection takes `counter % backends.len()`
    /// and increments it.
    counter: Arc<AtomicUsize>,
    /// Timeout for establishing the TCP connection to a backend.
    connect_timeout: Duration,
    /// Read timeout used by the connection handler (see
    /// [`WsProxy::read_timeout_ms`]).
    read_timeout: Duration,
}
49
50impl WsProxy {
51    /// Create a proxy that distributes connections across `backends` in
52    /// round-robin order. Each entry must be `"host:port"`.
53    pub fn new<I, S>(backends: I) -> Self
54    where
55        I: IntoIterator<Item = S>,
56        S: Into<String>,
57    {
58        WsProxy {
59            backends: backends.into_iter().map(|b| b.into()).collect(),
60            counter: Arc::new(AtomicUsize::new(0)),
61            connect_timeout: Duration::from_secs(5),
62            read_timeout: Duration::from_secs(30),
63        }
64    }
65
66    /// Override the TCP connect timeout to each backend (default: 5 s).
67    pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
68        self.connect_timeout = Duration::from_millis(ms);
69        self
70    }
71
72    /// Override the idle read timeout on client connections (default: 30 s).
73    pub fn read_timeout_ms(mut self, ms: u64) -> Self {
74        self.read_timeout = Duration::from_millis(ms);
75        self
76    }
77
78    /// Bind on `addr` and start proxying WebSocket connections. Blocks indefinitely.
79    pub fn bind(self, addr: &str) -> Result<(), String> {
80        if self.backends.is_empty() {
81            return Err("WsProxy: no backends configured".to_string());
82        }
83        let listener = TcpListener::bind(addr)
84            .map_err(|e| format!("WsProxy: bind on {} failed: {}", addr, e))?;
85        println!("WsProxy: listening on {}", addr);
86        let proxy = Arc::new(self);
87        for incoming in listener.incoming() {
88            let client = match incoming {
89                Ok(s) => s,
90                Err(e) => {
91                    eprintln!("WsProxy: accept error: {}", e);
92                    continue;
93                }
94            };
95            let p = Arc::clone(&proxy);
96            std::thread::spawn(move || {
97                if let Err(e) = p.handle(client) {
98                    eprintln!("WsProxy: {}", e);
99                }
100            });
101        }
102        Ok(())
103    }
104
105    fn pick_backend(&self) -> &str {
106        let i = self.counter.fetch_add(1, Ordering::Relaxed) % self.backends.len();
107        &self.backends[i]
108    }
109
110    fn handle(&self, mut client: TcpStream) -> Result<(), String> {
111        client.set_read_timeout(Some(self.read_timeout)).ok();
112
113        // Read the initial HTTP request
114        let mut buf = vec![0u8; 8192];
115        let n = client.read(&mut buf).map_err(|e| e.to_string())?;
116        if n == 0 {
117            return Ok(());
118        }
119
120        let request = Request::parse(&buf[..n])
121            .map_err(|e| format!("WsProxy: invalid HTTP request: {}", e))?;
122
123        if !WebSocket::is_upgrade_request(&request) {
124            let _ = client.write_all(b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n");
125            return Err(format!(
126                "WsProxy: not a WebSocket upgrade — method={}, uri={}",
127                request.method, request.request_uri
128            ));
129        }
130
131        // Connect to backend
132        let backend_str = self.pick_backend().to_string();
133        let backend_sock = backend_str
134            .to_socket_addrs()
135            .map_err(|e| format!("WsProxy: DNS lookup for {} failed: {}", backend_str, e))?
136            .next()
137            .ok_or_else(|| format!("WsProxy: no address for {}", backend_str))?;
138
139        let mut backend = TcpStream::connect_timeout(&backend_sock, self.connect_timeout)
140            .map_err(|e| format!("WsProxy: connect to {} failed: {}", backend_str, e))?;
141
142        // Forward the HTTP upgrade request to the backend
143        let upgrade_req = build_upgrade_request(&request, &backend_str);
144        backend
145            .write_all(&upgrade_req)
146            .map_err(|e| format!("WsProxy: write upgrade to backend failed: {}", e))?;
147
148        // Read backend's 101 response
149        let mut resp_buf = vec![0u8; 4096];
150        let m = backend
151            .read(&mut resp_buf)
152            .map_err(|e| format!("WsProxy: read 101 from backend failed: {}", e))?;
153        let resp_preview = &resp_buf[..m.min(20)];
154        if !resp_preview.starts_with(b"HTTP/1.1 101") && !resp_preview.starts_with(b"HTTP/1.0 101") {
155            return Err(format!(
156                "WsProxy: backend {} did not send 101 (got {:?})",
157                backend_str,
158                std::str::from_utf8(&resp_buf[..m.min(80)]).unwrap_or("?")
159            ));
160        }
161
162        // Send 101 Switching Protocols to the client
163        let response_101 = WebSocket::handshake_response(&request)?;
164        let raw_101 = format_response_head(&response_101);
165        client
166            .write_all(&raw_101)
167            .map_err(|e| format!("WsProxy: write 101 to client failed: {}", e))?;
168
169        // Bidirectional byte tunnel — two threads, one per direction
170        let mut client_r = client.try_clone().map_err(|e| e.to_string())?;
171        let mut backend_r = backend.try_clone().map_err(|e| e.to_string())?;
172        let mut client_w = client;
173        let mut backend_w = backend;
174
175        let t1 = std::thread::spawn(move || {
176            std::io::copy(&mut client_r, &mut backend_w).ok();
177            let _ = backend_w.shutdown(std::net::Shutdown::Write);
178        });
179        let t2 = std::thread::spawn(move || {
180            std::io::copy(&mut backend_r, &mut client_w).ok();
181            let _ = client_w.shutdown(std::net::Shutdown::Write);
182        });
183
184        let _ = t1.join();
185        let _ = t2.join();
186        Ok(())
187    }
188}
189
190fn build_upgrade_request(request: &Request, backend_host: &str) -> Vec<u8> {
191    let mut req = format!(
192        "{} {} HTTP/1.1\r\nHost: {}\r\n",
193        request.method, request.request_uri, backend_host
194    );
195    for header in &request.headers {
196        if header.name.to_lowercase() == "host" {
197            continue;
198        }
199        req.push_str(&format!("{}: {}\r\n", header.name, header.value));
200    }
201    req.push_str("\r\n");
202    req.into_bytes()
203}
204
205fn format_response_head(response: &crate::response::Response) -> Vec<u8> {
206    let mut out = format!(
207        "HTTP/1.1 {} {}\r\n",
208        response.status_code, response.reason_phrase
209    )
210    .into_bytes();
211    for h in &response.headers {
212        out.extend_from_slice(h.name.as_bytes());
213        out.extend_from_slice(b": ");
214        out.extend_from_slice(h.value.as_bytes());
215        out.extend_from_slice(b"\r\n");
216    }
217    out.extend_from_slice(b"\r\n");
218    out
219}