Skip to main content

rust_web_server/ws_proxy/
mod.rs

1//! WebSocket reverse proxy.
2//!
3//! [`WsProxy`] listens for incoming TCP connections, reads the initial HTTP
4//! request, verifies it is a WebSocket upgrade, connects to a backend, performs
5//! the WebSocket handshake end-to-end, and then bidirectionally tunnels raw
6//! WebSocket bytes between the client and the backend.
7//!
//! Plain (`ws://`) backends use two threads (one per direction), each copying
//! bytes with `std::io::copy`.
10//!
11//! TLS (`wss://`) backends use a single-thread polling loop: both streams are
12//! set to a 5 ms read timeout and the loop alternates between the two
13//! directions, sleeping 1 ms when neither side has data.  This avoids the
14//! deadlock that arises when trying to share a `rustls::StreamOwned` between
15//! two blocking threads.
16//!
17//! # Example
18//!
19//! ```rust,no_run
20//! use rust_web_server::ws_proxy::WsProxy;
21//!
22//! // Plain WebSocket — two backends, round-robin.
23//! WsProxy::new(["ws://chat-backend:9000", "ws://chat-backend:9001"])
24//!     .connect_timeout_ms(3000)
25//!     .bind("0.0.0.0:8080")
26//!     .unwrap();
27//!
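//! // TLS WebSocket (requires the `http-client` or `http2` feature). `bind`
//! // blocks while it listens, so a real program runs each proxy on its own
//! // thread and address; this second call only illustrates the `wss://` form.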
29//! WsProxy::new(["wss://chat-backend.internal:443"])
30//!     .bind("0.0.0.0:8080")
31//!     .unwrap();
32//! ```
33
34use std::io::{Read, Write};
35use std::net::{TcpListener, TcpStream, ToSocketAddrs};
36use std::sync::{
37    Arc,
38    atomic::{AtomicUsize, Ordering},
39};
40use std::time::Duration;
41
42use crate::request::Request;
43use crate::websocket::WebSocket;
44
45/// WebSocket reverse proxy with round-robin load balancing.
46///
47/// Accepts HTTP/1.1 WebSocket upgrade requests and tunnels traffic to one of the
48/// configured backends.
49///
50/// Backend URL schemes:
51/// - `"host:port"` — plain TCP (no scheme)
52/// - `"ws://host:port"` — plain TCP (port defaults to 80)
53/// - `"wss://host:port"` — TLS (port defaults to 443); requires the
54///   `http-client` or `http2` Cargo feature
55///
56/// Call [`WsProxy::bind`] to start. It blocks the calling thread indefinitely.
57pub struct WsProxy {
58    backends: Vec<WsBackend>,
59    counter: Arc<AtomicUsize>,
60    connect_timeout: Duration,
61    read_timeout: Duration,
62}
63
64impl WsProxy {
65    /// Create a proxy that distributes connections across `backends` in
66    /// round-robin order.
67    ///
68    /// Each entry may be `"host:port"`, `"ws://host:port"`, or
69    /// `"wss://host:port"`.  `wss://` requires the `http-client` or `http2`
70    /// Cargo feature.
71    pub fn new<I, S>(backends: I) -> Self
72    where
73        I: IntoIterator<Item = S>,
74        S: Into<String>,
75    {
76        WsProxy {
77            backends: backends
78                .into_iter()
79                .filter_map(|b| WsBackend::parse(&b.into()))
80                .collect(),
81            counter: Arc::new(AtomicUsize::new(0)),
82            connect_timeout: Duration::from_secs(5),
83            read_timeout: Duration::from_secs(30),
84        }
85    }
86
87    /// Override the TCP connect timeout to each backend (default: 5 s).
88    pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
89        self.connect_timeout = Duration::from_millis(ms);
90        self
91    }
92
93    /// Override the idle read timeout on client connections (default: 30 s).
94    ///
95    /// For `wss://` backends this controls the outer idle timeout on the
96    /// client side; the internal polling interval is fixed at 5 ms.
97    pub fn read_timeout_ms(mut self, ms: u64) -> Self {
98        self.read_timeout = Duration::from_millis(ms);
99        self
100    }
101
102    /// Bind on `addr` and start proxying WebSocket connections. Blocks indefinitely.
103    pub fn bind(self, addr: &str) -> Result<(), String> {
104        if self.backends.is_empty() {
105            return Err("WsProxy: no backends configured".to_string());
106        }
107        let listener = TcpListener::bind(addr)
108            .map_err(|e| format!("WsProxy: bind on {} failed: {}", addr, e))?;
109        println!("WsProxy: listening on {}", addr);
110        let proxy = Arc::new(self);
111        for incoming in listener.incoming() {
112            let client = match incoming {
113                Ok(s) => s,
114                Err(e) => {
115                    eprintln!("WsProxy: accept error: {}", e);
116                    continue;
117                }
118            };
119            let p = Arc::clone(&proxy);
120            std::thread::spawn(move || {
121                if let Err(e) = p.handle(client) {
122                    eprintln!("WsProxy: {}", e);
123                }
124            });
125        }
126        Ok(())
127    }
128
129    fn pick_backend(&self) -> &WsBackend {
130        let i = self.counter.fetch_add(1, Ordering::Relaxed) % self.backends.len();
131        &self.backends[i]
132    }
133
134    fn handle(&self, mut client: TcpStream) -> Result<(), String> {
135        client.set_read_timeout(Some(self.read_timeout)).ok();
136
137        // Read the initial HTTP request.
138        let mut buf = vec![0u8; 8192];
139        let n = client.read(&mut buf).map_err(|e| e.to_string())?;
140        if n == 0 {
141            return Ok(());
142        }
143
144        let request = Request::parse(&buf[..n])
145            .map_err(|e| format!("WsProxy: invalid HTTP request: {}", e))?;
146
147        if !WebSocket::is_upgrade_request(&request) {
148            let _ = client.write_all(b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n");
149            return Err(format!(
150                "WsProxy: not a WebSocket upgrade — method={}, uri={}",
151                request.method, request.request_uri
152            ));
153        }
154
155        let backend = self.pick_backend();
156        let addr_str = &backend.addr;
157        let sock_addr = addr_str
158            .to_socket_addrs()
159            .map_err(|e| format!("WsProxy: DNS lookup for {} failed: {}", addr_str, e))?
160            .next()
161            .ok_or_else(|| format!("WsProxy: no address for {}", addr_str))?;
162
163        let tcp = TcpStream::connect_timeout(&sock_addr, self.connect_timeout)
164            .map_err(|e| format!("WsProxy: connect to {} failed: {}", addr_str, e))?;
165
166        // Use backend.host (no port) for the Host header and TLS SNI.
167        let upgrade_req = build_upgrade_request(&request, &backend.host);
168
169        if backend.tls {
170            self.handle_tls(client, tcp, &request, &backend.host, upgrade_req, addr_str)
171        } else {
172            handle_plain(client, tcp, &request, upgrade_req, addr_str)
173        }
174    }
175
176    fn handle_tls(
177        &self,
178        mut client: TcpStream,
179        tcp: TcpStream,
180        request: &Request,
181        host: &str,
182        upgrade_req: Vec<u8>,
183        addr_str: &str,
184    ) -> Result<(), String> {
185        #[cfg(any(feature = "http-client", feature = "http2"))]
186        {
187            use rustls::pki_types::ServerName;
188            use rustls::ClientConfig;
189            use std::sync::Arc;
190
191            let root_store = rustls::RootCertStore::from_iter(
192                webpki_roots::TLS_SERVER_ROOTS.iter().cloned(),
193            );
194            let config = Arc::new(
195                ClientConfig::builder()
196                    .with_root_certificates(root_store)
197                    .with_no_client_auth(),
198            );
199            let server_name = ServerName::try_from(host)
200                .map_err(|e| format!("WsProxy: invalid hostname '{}': {}", host, e))?
201                .to_owned();
202            let conn = rustls::ClientConnection::new(config, server_name)
203                .map_err(|e| format!("WsProxy: TLS init failed: {}", e))?;
204            let mut tls = rustls::StreamOwned::new(conn, tcp);
205
206            // Send WebSocket upgrade request over TLS.
207            tls.write_all(&upgrade_req)
208                .map_err(|e| format!("WsProxy: write upgrade to {} failed: {}", addr_str, e))?;
209
210            // Read backend's 101 Switching Protocols.
211            let mut resp_buf = vec![0u8; 4096];
212            let m = tls
213                .read(&mut resp_buf)
214                .map_err(|e| format!("WsProxy: read 101 from {} failed: {}", addr_str, e))?;
215            let preview = &resp_buf[..m.min(20)];
216            if !preview.starts_with(b"HTTP/1.1 101") && !preview.starts_with(b"HTTP/1.0 101") {
217                return Err(format!(
218                    "WsProxy: backend {} did not send 101 (got {:?})",
219                    addr_str,
220                    std::str::from_utf8(&resp_buf[..m.min(80)]).unwrap_or("?")
221                ));
222            }
223
224            // Forward 101 to client.
225            let response_101 = WebSocket::handshake_response(request)?;
226            let raw_101 = format_response_head(&response_101);
227            client
228                .write_all(&raw_101)
229                .map_err(|e| format!("WsProxy: write 101 to client failed: {}", e))?;
230
231            // Bidirectional relay via single-thread poll loop.
232            // Set both sides to 5 ms polling timeout.
233            tls.sock.set_read_timeout(Some(Duration::from_millis(5))).ok();
234            client.set_read_timeout(Some(Duration::from_millis(5))).ok();
235            relay_tls(client, tls);
236            Ok(())
237        }
238
239        #[cfg(not(any(feature = "http-client", feature = "http2")))]
240        {
241            let _ = (tcp, request, host, upgrade_req, addr_str);
242            let _ = client.write_all(
243                b"HTTP/1.1 502 Bad Gateway\r\nContent-Length: 0\r\n\r\n",
244            );
245            Err("WsProxy: wss:// upstreams require the http-client or http2 Cargo feature".to_string())
246        }
247    }
248}
249
250/// Relay over a plain TCP backend using two blocking threads.
251fn handle_plain(
252    mut client: TcpStream,
253    mut backend: TcpStream,
254    request: &Request,
255    upgrade_req: Vec<u8>,
256    addr_str: &str,
257) -> Result<(), String> {
258    backend
259        .write_all(&upgrade_req)
260        .map_err(|e| format!("WsProxy: write upgrade to {} failed: {}", addr_str, e))?;
261
262    let mut resp_buf = vec![0u8; 4096];
263    let m = backend
264        .read(&mut resp_buf)
265        .map_err(|e| format!("WsProxy: read 101 from {} failed: {}", addr_str, e))?;
266    let preview = &resp_buf[..m.min(20)];
267    if !preview.starts_with(b"HTTP/1.1 101") && !preview.starts_with(b"HTTP/1.0 101") {
268        return Err(format!(
269            "WsProxy: backend {} did not send 101 (got {:?})",
270            addr_str,
271            std::str::from_utf8(&resp_buf[..m.min(80)]).unwrap_or("?")
272        ));
273    }
274
275    let response_101 = WebSocket::handshake_response(request)?;
276    let raw_101 = format_response_head(&response_101);
277    client
278        .write_all(&raw_101)
279        .map_err(|e| format!("WsProxy: write 101 to client failed: {}", e))?;
280
281    // Bidirectional tunnel — one thread per direction.
282    let mut client_r = client.try_clone().map_err(|e| e.to_string())?;
283    let mut backend_r = backend.try_clone().map_err(|e| e.to_string())?;
284    let mut client_w = client;
285    let mut backend_w = backend;
286
287    let t1 = std::thread::spawn(move || {
288        std::io::copy(&mut client_r, &mut backend_w).ok();
289        let _ = backend_w.shutdown(std::net::Shutdown::Write);
290    });
291    let t2 = std::thread::spawn(move || {
292        std::io::copy(&mut backend_r, &mut client_w).ok();
293        let _ = client_w.shutdown(std::net::Shutdown::Write);
294    });
295
296    let _ = t1.join();
297    let _ = t2.join();
298    Ok(())
299}
300
/// Bidirectional relay between `client` (plain TCP) and a TLS backend.
///
/// Uses a single-thread polling loop to avoid the deadlock that arises when
/// sharing a `rustls::StreamOwned` between two blocking threads (the reader
/// thread would hold the TLS lock while waiting for data, blocking the writer).
///
/// Both streams are expected to carry a short read timeout (the caller sets
/// 5 ms), so each read returns promptly when nothing is pending.  The loop
/// reads from each side in turn and sleeps 1 ms when neither had data.
///
/// The loop ends when either side closes or a read/write fails.  The TLS
/// session is then closed with a `close_notify` alert before both sockets are
/// dropped.
#[cfg(any(feature = "http-client", feature = "http2"))]
fn relay_tls(
    mut client: TcpStream,
    mut backend: rustls::StreamOwned<rustls::ClientConnection, TcpStream>,
) {
    use std::io::ErrorKind::{TimedOut, WouldBlock};
    let mut buf = [0u8; 8192];

    loop {
        let mut active = false;

        // client → TLS backend
        let cn = match client.read(&mut buf) {
            Ok(0) => break, // client closed
            Ok(n) => n,
            Err(ref e) if e.kind() == TimedOut || e.kind() == WouldBlock => 0,
            Err(_) => break,
        };
        if cn > 0 {
            if backend.write_all(&buf[..cn]).is_err() {
                break;
            }
            active = true;
        }

        // TLS backend → client
        let bn = match backend.read(&mut buf) {
            Ok(0) => break, // backend closed
            Ok(n) => n,
            Err(ref e) if e.kind() == TimedOut || e.kind() == WouldBlock => 0,
            Err(_) => break,
        };
        if bn > 0 {
            if client.write_all(&buf[..bn]).is_err() {
                break;
            }
            active = true;
        }

        if !active {
            std::thread::sleep(Duration::from_millis(1));
        }
    }

    // TLS requires a close_notify alert before the write side is closed
    // (RFC 8446 §6.1); without it the backend cannot tell a clean shutdown
    // from a truncated session.  Best effort: the backend may already be gone.
    backend.conn.send_close_notify();
    let _ = backend.flush();
}
353
354// ── helpers ───────────────────────────────────────────────────────────────────
355
356fn build_upgrade_request(request: &Request, backend_host: &str) -> Vec<u8> {
357    let mut req = format!(
358        "{} {} HTTP/1.1\r\nHost: {}\r\n",
359        request.method, request.request_uri, backend_host
360    );
361    for header in &request.headers {
362        if header.name.to_lowercase() == "host" {
363            continue;
364        }
365        req.push_str(&format!("{}: {}\r\n", header.name, header.value));
366    }
367    req.push_str("\r\n");
368    req.into_bytes()
369}
370
371fn format_response_head(response: &crate::response::Response) -> Vec<u8> {
372    let mut out = format!(
373        "HTTP/1.1 {} {}\r\n",
374        response.status_code, response.reason_phrase
375    )
376    .into_bytes();
377    for h in &response.headers {
378        out.extend_from_slice(h.name.as_bytes());
379        out.extend_from_slice(b": ");
380        out.extend_from_slice(h.value.as_bytes());
381        out.extend_from_slice(b"\r\n");
382    }
383    out.extend_from_slice(b"\r\n");
384    out
385}
386
387// ── Backend URL parsing ───────────────────────────────────────────────────────
388
struct WsBackend {
    /// `"host:port"` — passed to `to_socket_addrs()` for TCP connect.
    addr: String,
    /// Host without the port — used for the `Host` header and TLS SNI.
    /// IPv6 literals keep their surrounding brackets (`[::1]`).
    host: String,
    /// `true` when the URL scheme was `wss://`.
    tls: bool,
}

impl WsBackend {
    /// Parse one backend URL of the form `[ws://|wss://]host[:port][/path]`.
    ///
    /// Parsing rules:
    /// - The scheme is matched case-insensitively; without a scheme the
    ///   backend is plain TCP.
    /// - The port defaults to 443 for `wss://` and 80 otherwise.  An empty
    ///   port (`host:`) also selects the default.
    /// - Any path, query or fragment after the host and port is ignored.
    /// - IPv6 literals may be bracketed (`[::1]:9000`) or, when no port is
    ///   given, bare (`::1`).
    ///
    /// Returns `None` in these cases:
    /// - the host is empty;
    /// - the port is not a valid `u16`;
    /// - a bracketed literal is malformed;
    /// - an unbracketed host still contains a `:`.
    fn parse(url: &str) -> Option<Self> {
        let (rest, tls, default_port) = if let Some(r) = Self::strip_scheme(url, "wss://") {
            (r, true, 443u16)
        } else if let Some(r) = Self::strip_scheme(url, "ws://") {
            (r, false, 80u16)
        } else {
            (url, false, 80u16)
        };

        // Only the authority matters: drop any path, query or fragment.
        let authority = rest
            .split(|c: char| matches!(c, '/' | '?' | '#'))
            .next()
            .unwrap_or(rest);

        let (host, port) = Self::split_host_port(authority)?;
        let port = match port {
            // `host` and `host:` both mean "use the scheme's default port".
            None | Some("") => default_port,
            Some(digits) => digits.parse::<u16>().ok()?,
        };

        if host.is_empty() {
            return None;
        }

        Some(WsBackend {
            addr: format!("{}:{}", host, port),
            host,
            tls,
        })
    }

    /// Case-insensitively strip `scheme` (e.g. `"wss://"`) from the start of `url`.
    fn strip_scheme<'a>(url: &'a str, scheme: &str) -> Option<&'a str> {
        let prefix = url.get(..scheme.len())?;
        if prefix.eq_ignore_ascii_case(scheme) {
            Some(&url[scheme.len()..])
        } else {
            None
        }
    }

    /// Split an authority into its host (IPv6 literals bracketed) and the raw
    /// port text, if any.  Returns `None` for malformed authorities.
    fn split_host_port(authority: &str) -> Option<(String, Option<&str>)> {
        if let Some(inner) = authority.strip_prefix('[') {
            // "[addr]" or "[addr]:port"
            let close = inner.find(']')?;
            if close == 0 {
                return None;
            }
            let host = format!("[{}]", &inner[..close]);
            let after = &inner[close + 1..];
            if after.is_empty() {
                return Some((host, None));
            }
            return after.strip_prefix(':').map(|port| (host, Some(port)));
        }
        if authority.parse::<std::net::Ipv6Addr>().is_ok() {
            // Bare IPv6 literal: every colon belongs to the address, so there
            // is no port.  Bracket it so `addr` stays a valid socket address.
            return Some((format!("[{}]", authority), None));
        }
        match authority.rfind(':') {
            Some(colon) => {
                let host = &authority[..colon];
                if host.contains(':') {
                    return None;
                }
                Some((host.to_string(), Some(&authority[colon + 1..])))
            }
            None => Some((authority.to_string(), None)),
        }
    }
}
433
434// ── WsBackend::parse unit tests ───────────────────────────────────────────────
435
#[cfg(test)]
mod backend_parse_tests {
    use super::WsBackend;

    /// Flatten a parsed backend into `(addr, host, tls)` so a whole result can
    /// be compared with a single `assert_eq!`.
    fn parse(url: &str) -> Option<(String, String, bool)> {
        let backend = WsBackend::parse(url)?;
        Some((backend.addr, backend.host, backend.tls))
    }

    /// Expected `(addr, host, tls)` triple built from string literals.
    fn triple(addr: &str, host: &str, tls: bool) -> Option<(String, String, bool)> {
        Some((addr.to_owned(), host.to_owned(), tls))
    }

    #[test]
    fn bare_host_port() {
        assert_eq!(triple("chat:9000", "chat", false), parse("chat:9000"));
    }

    #[test]
    fn ws_scheme_plain() {
        assert_eq!(
            triple("backend:3000", "backend", false),
            parse("ws://backend:3000")
        );
    }

    #[test]
    fn ws_scheme_default_port() {
        assert_eq!(
            triple("api.example.com:80", "api.example.com", false),
            parse("ws://api.example.com")
        );
    }

    #[test]
    fn wss_scheme_sets_tls() {
        assert_eq!(
            triple("secure.example.com:443", "secure.example.com", true),
            parse("wss://secure.example.com")
        );
    }

    #[test]
    fn wss_scheme_explicit_port() {
        assert_eq!(
            triple("secure.example.com:8443", "secure.example.com", true),
            parse("wss://secure.example.com:8443")
        );
    }

    #[test]
    fn wss_default_port_is_443() {
        let backend = WsBackend::parse("wss://api.example.com").expect("valid wss URL");
        assert_eq!(backend.addr, "api.example.com:443");
        assert_eq!(backend.host, "api.example.com");
        assert!(backend.tls);
    }

    #[test]
    fn ws_default_port_is_80() {
        let backend = WsBackend::parse("ws://api.example.com").expect("valid ws URL");
        assert_eq!(backend.addr, "api.example.com:80");
        assert!(!backend.tls);
    }

    #[test]
    fn empty_host_returns_none() {
        assert!(parse("wss://").is_none());
    }

    #[test]
    fn bare_host_no_port_defaults_to_80() {
        assert_eq!(triple("myhost:80", "myhost", false), parse("myhost"));
    }

    #[test]
    fn path_component_is_ignored() {
        // Only host:port is kept; the `/ws` path plays no part in routing.
        let backend =
            WsBackend::parse("ws://backend:9000/ws").expect("valid ws URL with path");
        assert_eq!(backend.addr, "backend:9000");
        assert_eq!(backend.host, "backend");
    }
}