rust_web_server/ws_proxy/
mod.rs1use std::io::{Read, Write};
24use std::net::{TcpListener, TcpStream, ToSocketAddrs};
25use std::sync::{
26 Arc,
27 atomic::{AtomicUsize, Ordering},
28};
29use std::time::Duration;
30
31use crate::request::Request;
32use crate::websocket::WebSocket;
33
/// A WebSocket reverse proxy that spreads incoming clients over a fixed list of
/// backends in round-robin order, serving each client connection on its own thread.
#[derive(Debug)]
pub struct WsProxy {
    /// Backend addresses (`host:port` strings), resolved again for every connection.
    backends: Vec<String>,
    /// Round-robin cursor; shared by every connection thread through the proxy's `Arc`.
    counter: Arc<AtomicUsize>,
    /// How long to wait when opening a TCP connection to a backend.
    connect_timeout: Duration,
    /// Socket read timeout used while serving connections.
    read_timeout: Duration,
}
49
50impl WsProxy {
51 pub fn new<I, S>(backends: I) -> Self
54 where
55 I: IntoIterator<Item = S>,
56 S: Into<String>,
57 {
58 WsProxy {
59 backends: backends.into_iter().map(|b| b.into()).collect(),
60 counter: Arc::new(AtomicUsize::new(0)),
61 connect_timeout: Duration::from_secs(5),
62 read_timeout: Duration::from_secs(30),
63 }
64 }
65
66 pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
68 self.connect_timeout = Duration::from_millis(ms);
69 self
70 }
71
72 pub fn read_timeout_ms(mut self, ms: u64) -> Self {
74 self.read_timeout = Duration::from_millis(ms);
75 self
76 }
77
78 pub fn bind(self, addr: &str) -> Result<(), String> {
80 if self.backends.is_empty() {
81 return Err("WsProxy: no backends configured".to_string());
82 }
83 let listener = TcpListener::bind(addr)
84 .map_err(|e| format!("WsProxy: bind on {} failed: {}", addr, e))?;
85 println!("WsProxy: listening on {}", addr);
86 let proxy = Arc::new(self);
87 for incoming in listener.incoming() {
88 let client = match incoming {
89 Ok(s) => s,
90 Err(e) => {
91 eprintln!("WsProxy: accept error: {}", e);
92 continue;
93 }
94 };
95 let p = Arc::clone(&proxy);
96 std::thread::spawn(move || {
97 if let Err(e) = p.handle(client) {
98 eprintln!("WsProxy: {}", e);
99 }
100 });
101 }
102 Ok(())
103 }
104
105 fn pick_backend(&self) -> &str {
106 let i = self.counter.fetch_add(1, Ordering::Relaxed) % self.backends.len();
107 &self.backends[i]
108 }
109
110 fn handle(&self, mut client: TcpStream) -> Result<(), String> {
111 client.set_read_timeout(Some(self.read_timeout)).ok();
112
113 let mut buf = vec![0u8; 8192];
115 let n = client.read(&mut buf).map_err(|e| e.to_string())?;
116 if n == 0 {
117 return Ok(());
118 }
119
120 let request = Request::parse(&buf[..n])
121 .map_err(|e| format!("WsProxy: invalid HTTP request: {}", e))?;
122
123 if !WebSocket::is_upgrade_request(&request) {
124 let _ = client.write_all(b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n");
125 return Err(format!(
126 "WsProxy: not a WebSocket upgrade — method={}, uri={}",
127 request.method, request.request_uri
128 ));
129 }
130
131 let backend_str = self.pick_backend().to_string();
133 let backend_sock = backend_str
134 .to_socket_addrs()
135 .map_err(|e| format!("WsProxy: DNS lookup for {} failed: {}", backend_str, e))?
136 .next()
137 .ok_or_else(|| format!("WsProxy: no address for {}", backend_str))?;
138
139 let mut backend = TcpStream::connect_timeout(&backend_sock, self.connect_timeout)
140 .map_err(|e| format!("WsProxy: connect to {} failed: {}", backend_str, e))?;
141
142 let upgrade_req = build_upgrade_request(&request, &backend_str);
144 backend
145 .write_all(&upgrade_req)
146 .map_err(|e| format!("WsProxy: write upgrade to backend failed: {}", e))?;
147
148 let mut resp_buf = vec![0u8; 4096];
150 let m = backend
151 .read(&mut resp_buf)
152 .map_err(|e| format!("WsProxy: read 101 from backend failed: {}", e))?;
153 let resp_preview = &resp_buf[..m.min(20)];
154 if !resp_preview.starts_with(b"HTTP/1.1 101") && !resp_preview.starts_with(b"HTTP/1.0 101") {
155 return Err(format!(
156 "WsProxy: backend {} did not send 101 (got {:?})",
157 backend_str,
158 std::str::from_utf8(&resp_buf[..m.min(80)]).unwrap_or("?")
159 ));
160 }
161
162 let response_101 = WebSocket::handshake_response(&request)?;
164 let raw_101 = format_response_head(&response_101);
165 client
166 .write_all(&raw_101)
167 .map_err(|e| format!("WsProxy: write 101 to client failed: {}", e))?;
168
169 let mut client_r = client.try_clone().map_err(|e| e.to_string())?;
171 let mut backend_r = backend.try_clone().map_err(|e| e.to_string())?;
172 let mut client_w = client;
173 let mut backend_w = backend;
174
175 let t1 = std::thread::spawn(move || {
176 std::io::copy(&mut client_r, &mut backend_w).ok();
177 let _ = backend_w.shutdown(std::net::Shutdown::Write);
178 });
179 let t2 = std::thread::spawn(move || {
180 std::io::copy(&mut backend_r, &mut client_w).ok();
181 let _ = client_w.shutdown(std::net::Shutdown::Write);
182 });
183
184 let _ = t1.join();
185 let _ = t2.join();
186 Ok(())
187 }
188}
189
190fn build_upgrade_request(request: &Request, backend_host: &str) -> Vec<u8> {
191 let mut req = format!(
192 "{} {} HTTP/1.1\r\nHost: {}\r\n",
193 request.method, request.request_uri, backend_host
194 );
195 for header in &request.headers {
196 if header.name.to_lowercase() == "host" {
197 continue;
198 }
199 req.push_str(&format!("{}: {}\r\n", header.name, header.value));
200 }
201 req.push_str("\r\n");
202 req.into_bytes()
203}
204
205fn format_response_head(response: &crate::response::Response) -> Vec<u8> {
206 let mut out = format!(
207 "HTTP/1.1 {} {}\r\n",
208 response.status_code, response.reason_phrase
209 )
210 .into_bytes();
211 for h in &response.headers {
212 out.extend_from_slice(h.name.as_bytes());
213 out.extend_from_slice(b": ");
214 out.extend_from_slice(h.value.as_bytes());
215 out.extend_from_slice(b"\r\n");
216 }
217 out.extend_from_slice(b"\r\n");
218 out
219}