Skip to main content

rust_web_server/tcp_proxy/
mod.rs

1//! Layer-4 TCP proxy.
2//!
3//! [`TcpProxy`] accepts raw TCP connections and forwards them to backend servers,
4//! bidirectionally tunneling bytes with one thread per direction.
5//!
6//! Unlike [`crate::proxy::ReverseProxy`] (which operates at the HTTP layer),
7//! `TcpProxy` is protocol-agnostic: any TCP-based protocol (database wire formats,
8//! custom binary protocols, raw TLS passthrough) is forwarded unchanged.
9//!
10//! # Example
11//!
12//! ```rust,no_run
13//! use rust_web_server::tcp_proxy::TcpProxy;
14//!
15//! // Proxy raw TCP on port 5432 across two PostgreSQL backends.
16//! TcpProxy::new(["backend-1:5432", "backend-2:5432"])
17//!     .connect_timeout_ms(3000)
18//!     .bind("0.0.0.0:5432")
19//!     .unwrap();
20//! ```
21
22use std::io;
23use std::net::{TcpListener, TcpStream, ToSocketAddrs};
24use std::sync::{
25    Arc,
26    atomic::{AtomicUsize, Ordering},
27};
28use std::time::Duration;
29
/// Layer-4 (raw TCP) reverse proxy with round-robin load balancing.
///
/// Build one with [`TcpProxy::new`], optionally tune it with the builder
/// methods, then call [`TcpProxy::bind`] to start accepting connections.
/// `bind` runs the accept loop on the calling thread and therefore blocks it
/// indefinitely; every accepted connection is served on its own threads, with
/// one copy loop per direction.
#[derive(Debug)]
pub struct TcpProxy {
    /// Backend addresses as `"host:port"` strings, visited in round-robin order.
    backends: Vec<String>,
    /// Running connection counter; taken modulo the number of backends to
    /// choose the backend for the next connection.
    counter: Arc<AtomicUsize>,
    /// Timeout applied when opening the TCP connection to a backend.
    connect_timeout: Duration,
}
40
41impl TcpProxy {
42    /// Create a proxy that distributes connections across `backends` in
43    /// round-robin order. Each entry must be `"host:port"`.
44    pub fn new<I, S>(backends: I) -> Self
45    where
46        I: IntoIterator<Item = S>,
47        S: Into<String>,
48    {
49        TcpProxy {
50            backends: backends.into_iter().map(|b| b.into()).collect(),
51            counter: Arc::new(AtomicUsize::new(0)),
52            connect_timeout: Duration::from_secs(5),
53        }
54    }
55
56    /// Override the TCP connect timeout to each backend (default: 5 s).
57    pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
58        self.connect_timeout = Duration::from_millis(ms);
59        self
60    }
61
62    /// Bind on `addr` and start proxying. Blocks until the listener is closed.
63    pub fn bind(self, addr: &str) -> Result<(), String> {
64        if self.backends.is_empty() {
65            return Err("TcpProxy: no backends configured".to_string());
66        }
67        let listener = TcpListener::bind(addr)
68            .map_err(|e| format!("TcpProxy: bind on {} failed: {}", addr, e))?;
69        println!("TcpProxy: listening on {}", addr);
70        let proxy = Arc::new(self);
71        for incoming in listener.incoming() {
72            let client = match incoming {
73                Ok(s) => s,
74                Err(e) => {
75                    eprintln!("TcpProxy: accept error: {}", e);
76                    continue;
77                }
78            };
79            let p = Arc::clone(&proxy);
80            std::thread::spawn(move || {
81                if let Err(e) = p.relay(client) {
82                    eprintln!("TcpProxy: relay error: {}", e);
83                }
84            });
85        }
86        Ok(())
87    }
88
89    fn pick_backend(&self) -> &str {
90        let i = self.counter.fetch_add(1, Ordering::Relaxed) % self.backends.len();
91        &self.backends[i]
92    }
93
94    fn relay(&self, client: TcpStream) -> Result<(), String> {
95        let addr_str = self.pick_backend().to_string();
96        let sock_addr = addr_str
97            .to_socket_addrs()
98            .map_err(|e| format!("DNS lookup for {} failed: {}", addr_str, e))?
99            .next()
100            .ok_or_else(|| format!("no address resolved for {}", addr_str))?;
101
102        let backend = TcpStream::connect_timeout(&sock_addr, self.connect_timeout)
103            .map_err(|e| format!("TcpProxy: connect to {} failed: {}", addr_str, e))?;
104
105        let mut client_r = client.try_clone().map_err(|e| e.to_string())?;
106        let mut backend_r = backend.try_clone().map_err(|e| e.to_string())?;
107        let mut client_w = client;
108        let mut backend_w = backend;
109
110        let t1 = std::thread::spawn(move || {
111            io::copy(&mut client_r, &mut backend_w).ok();
112            let _ = backend_w.shutdown(std::net::Shutdown::Write);
113        });
114        let t2 = std::thread::spawn(move || {
115            io::copy(&mut backend_r, &mut client_w).ok();
116            let _ = client_w.shutdown(std::net::Shutdown::Write);
117        });
118
119        let _ = t1.join();
120        let _ = t2.join();
121        Ok(())
122    }
123}