// rust_web_server/tcp_proxy/mod.rs
use std::io;
use std::net::{Shutdown, TcpListener, TcpStream, ToSocketAddrs};
use std::sync::{
    Arc,
    atomic::{AtomicUsize, Ordering},
};
use std::thread;
use std::time::Duration;

/// A blocking, thread-per-connection TCP proxy that spreads incoming
/// connections across a fixed list of backends in round-robin order.
///
/// Bytes are copied verbatim in both directions; the stream is never
/// inspected or modified.
#[derive(Debug)]
pub struct TcpProxy {
    /// Backend addresses in `host:port` form, resolved anew for every connection.
    backends: Vec<String>,
    /// Round-robin cursor; the proxy itself is shared behind an `Arc`, so the
    /// counter needs no extra wrapper.
    counter: AtomicUsize,
    /// Timeout for each individual backend connection attempt. Zero means
    /// "no explicit timeout" (a plain blocking connect is used).
    connect_timeout: Duration,
}

impl TcpProxy {
    /// Creates a proxy for the given backend addresses (`host:port`) with a
    /// default connect timeout of 5 seconds.
    ///
    /// An empty backend list is accepted here but rejected by [`TcpProxy::bind`].
    pub fn new<I, S>(backends: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        TcpProxy {
            backends: backends.into_iter().map(Into::into).collect(),
            counter: AtomicUsize::new(0),
            connect_timeout: Duration::from_secs(5),
        }
    }

    /// Sets the timeout, in milliseconds, applied to each backend connection
    /// attempt.
    ///
    /// `0` disables the explicit timeout; connects then block for however long
    /// the operating system allows.
    pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
        self.connect_timeout = Duration::from_millis(ms);
        self
    }

    /// Binds `addr` and serves connections forever, relaying each one to the
    /// next backend on its own thread.
    ///
    /// # Errors
    ///
    /// Returns `Err` if no backends are configured or if binding `addr` fails.
    /// Once listening, per-connection problems (accept errors, thread-spawn
    /// failures, relay errors) are logged to stderr and the loop keeps going.
    pub fn bind(self, addr: &str) -> Result<(), String> {
        if self.backends.is_empty() {
            return Err("TcpProxy: no backends configured".to_string());
        }
        let listener = TcpListener::bind(addr)
            .map_err(|e| format!("TcpProxy: bind on {} failed: {}", addr, e))?;
        println!("TcpProxy: listening on {}", addr);
        let proxy = Arc::new(self);
        for incoming in listener.incoming() {
            let client = match incoming {
                Ok(s) => s,
                Err(e) => {
                    eprintln!("TcpProxy: accept error: {}", e);
                    continue;
                }
            };
            let p = Arc::clone(&proxy);
            // `thread::spawn` panics when the OS refuses a new thread, which would
            // take down this accept loop. `Builder::spawn` reports it instead; the
            // un-run closure (and with it the client socket) is simply dropped.
            let spawned = thread::Builder::new().spawn(move || {
                if let Err(e) = p.relay(client) {
                    eprintln!("TcpProxy: relay error: {}", e);
                }
            });
            if let Err(e) = spawned {
                eprintln!("TcpProxy: could not spawn relay thread: {}", e);
            }
        }
        Ok(())
    }

    /// Returns the next backend in round-robin order.
    ///
    /// # Panics
    ///
    /// Panics (remainder by zero) if no backends are configured; `bind` rejects
    /// that case before any connection is relayed.
    fn pick_backend(&self) -> &str {
        // Relaxed is enough: only the atomicity of the increment matters.
        // `fetch_add` wraps on overflow, which merely restarts the rotation.
        let i = self.counter.fetch_add(1, Ordering::Relaxed) % self.backends.len();
        &self.backends[i]
    }

    /// Resolves `addr` and connects to the first resolved address that accepts.
    ///
    /// Every candidate address gets its own `connect_timeout`, so the worst
    /// case is one timeout per resolved address.
    ///
    /// # Errors
    ///
    /// Returns `Err` if resolution fails, resolves to nothing, or every
    /// candidate address refuses/times out (the last connect error is reported).
    fn connect_backend(&self, addr: &str) -> Result<TcpStream, String> {
        let candidates = addr
            .to_socket_addrs()
            .map_err(|e| format!("DNS lookup for {} failed: {}", addr, e))?;
        let mut last_err = None;
        for sock_addr in candidates {
            // `TcpStream::connect_timeout` rejects a zero duration outright, so a
            // zero timeout falls back to a plain blocking connect.
            let attempt = if self.connect_timeout.is_zero() {
                TcpStream::connect(sock_addr)
            } else {
                TcpStream::connect_timeout(&sock_addr, self.connect_timeout)
            };
            match attempt {
                Ok(stream) => return Ok(stream),
                Err(e) => last_err = Some(e),
            }
        }
        Err(match last_err {
            Some(e) => format!("TcpProxy: connect to {} failed: {}", addr, e),
            None => format!("no address resolved for {}", addr),
        })
    }

    /// Connects `client` to the next backend and copies bytes in both
    /// directions until both directions have finished.
    ///
    /// Client→backend runs on a helper thread; backend→client runs on the
    /// calling thread, so each connection costs one extra thread.
    ///
    /// # Errors
    ///
    /// Returns `Err` if the backend cannot be reached, a socket cannot be
    /// cloned, or the helper thread cannot be spawned. I/O errors while
    /// relaying are not reported; they just end the connection.
    fn relay(&self, client: TcpStream) -> Result<(), String> {
        let backend = self.connect_backend(self.pick_backend())?;

        let client_r = client.try_clone().map_err(|e| e.to_string())?;
        let backend_w = backend.try_clone().map_err(|e| e.to_string())?;

        let upstream = thread::Builder::new()
            .spawn(move || Self::pipe(client_r, backend_w))
            .map_err(|e| e.to_string())?;
        Self::pipe(backend, client);
        // A panic on the helper thread has nothing useful to report to the caller.
        let _ = upstream.join();
        Ok(())
    }

    /// Copies `from` into `to` until EOF or an error.
    ///
    /// On clean EOF only the write half of `to` is shut down, so the peer sees
    /// end-of-stream while the opposite direction keeps flowing (TCP
    /// half-close). On error both sockets are shut down completely so the
    /// opposite direction's blocking read is released instead of hanging.
    fn pipe(mut from: TcpStream, mut to: TcpStream) {
        match io::copy(&mut from, &mut to) {
            Ok(_) => {
                let _ = to.shutdown(Shutdown::Write);
            }
            Err(_) => {
                let _ = from.shutdown(Shutdown::Both);
                let _ = to.shutdown(Shutdown::Both);
            }
        }
    }
}