1use std::io::{Read, Write};
35use std::net::{TcpListener, TcpStream, ToSocketAddrs};
36use std::sync::{
37 Arc,
38 atomic::{AtomicUsize, Ordering},
39};
40use std::time::Duration;
41
42use crate::request::Request;
43use crate::websocket::WebSocket;
44
45pub struct WsProxy {
58 backends: Vec<WsBackend>,
59 counter: Arc<AtomicUsize>,
60 connect_timeout: Duration,
61 read_timeout: Duration,
62}
63
64impl WsProxy {
65 pub fn new<I, S>(backends: I) -> Self
72 where
73 I: IntoIterator<Item = S>,
74 S: Into<String>,
75 {
76 WsProxy {
77 backends: backends
78 .into_iter()
79 .filter_map(|b| WsBackend::parse(&b.into()))
80 .collect(),
81 counter: Arc::new(AtomicUsize::new(0)),
82 connect_timeout: Duration::from_secs(5),
83 read_timeout: Duration::from_secs(30),
84 }
85 }
86
87 pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
89 self.connect_timeout = Duration::from_millis(ms);
90 self
91 }
92
93 pub fn read_timeout_ms(mut self, ms: u64) -> Self {
98 self.read_timeout = Duration::from_millis(ms);
99 self
100 }
101
102 pub fn bind(self, addr: &str) -> Result<(), String> {
104 if self.backends.is_empty() {
105 return Err("WsProxy: no backends configured".to_string());
106 }
107 let listener = TcpListener::bind(addr)
108 .map_err(|e| format!("WsProxy: bind on {} failed: {}", addr, e))?;
109 println!("WsProxy: listening on {}", addr);
110 let proxy = Arc::new(self);
111 for incoming in listener.incoming() {
112 let client = match incoming {
113 Ok(s) => s,
114 Err(e) => {
115 eprintln!("WsProxy: accept error: {}", e);
116 continue;
117 }
118 };
119 let p = Arc::clone(&proxy);
120 std::thread::spawn(move || {
121 if let Err(e) = p.handle(client) {
122 eprintln!("WsProxy: {}", e);
123 }
124 });
125 }
126 Ok(())
127 }
128
129 fn pick_backend(&self) -> &WsBackend {
130 let i = self.counter.fetch_add(1, Ordering::Relaxed) % self.backends.len();
131 &self.backends[i]
132 }
133
134 fn handle(&self, mut client: TcpStream) -> Result<(), String> {
135 client.set_read_timeout(Some(self.read_timeout)).ok();
136
137 let mut buf = vec![0u8; 8192];
139 let n = client.read(&mut buf).map_err(|e| e.to_string())?;
140 if n == 0 {
141 return Ok(());
142 }
143
144 let request = Request::parse(&buf[..n])
145 .map_err(|e| format!("WsProxy: invalid HTTP request: {}", e))?;
146
147 if !WebSocket::is_upgrade_request(&request) {
148 let _ = client.write_all(b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n");
149 return Err(format!(
150 "WsProxy: not a WebSocket upgrade — method={}, uri={}",
151 request.method, request.request_uri
152 ));
153 }
154
155 let backend = self.pick_backend();
156 let addr_str = &backend.addr;
157 let sock_addr = addr_str
158 .to_socket_addrs()
159 .map_err(|e| format!("WsProxy: DNS lookup for {} failed: {}", addr_str, e))?
160 .next()
161 .ok_or_else(|| format!("WsProxy: no address for {}", addr_str))?;
162
163 let tcp = TcpStream::connect_timeout(&sock_addr, self.connect_timeout)
164 .map_err(|e| format!("WsProxy: connect to {} failed: {}", addr_str, e))?;
165
166 let upgrade_req = build_upgrade_request(&request, &backend.host);
168
169 if backend.tls {
170 self.handle_tls(client, tcp, &request, &backend.host, upgrade_req, addr_str)
171 } else {
172 handle_plain(client, tcp, &request, upgrade_req, addr_str)
173 }
174 }
175
176 fn handle_tls(
177 &self,
178 mut client: TcpStream,
179 tcp: TcpStream,
180 request: &Request,
181 host: &str,
182 upgrade_req: Vec<u8>,
183 addr_str: &str,
184 ) -> Result<(), String> {
185 #[cfg(any(feature = "http-client", feature = "http2"))]
186 {
187 use rustls::pki_types::ServerName;
188 use rustls::ClientConfig;
189 use std::sync::Arc;
190
191 let root_store = rustls::RootCertStore::from_iter(
192 webpki_roots::TLS_SERVER_ROOTS.iter().cloned(),
193 );
194 let config = Arc::new(
195 ClientConfig::builder()
196 .with_root_certificates(root_store)
197 .with_no_client_auth(),
198 );
199 let server_name = ServerName::try_from(host)
200 .map_err(|e| format!("WsProxy: invalid hostname '{}': {}", host, e))?
201 .to_owned();
202 let conn = rustls::ClientConnection::new(config, server_name)
203 .map_err(|e| format!("WsProxy: TLS init failed: {}", e))?;
204 let mut tls = rustls::StreamOwned::new(conn, tcp);
205
206 tls.write_all(&upgrade_req)
208 .map_err(|e| format!("WsProxy: write upgrade to {} failed: {}", addr_str, e))?;
209
210 let mut resp_buf = vec![0u8; 4096];
212 let m = tls
213 .read(&mut resp_buf)
214 .map_err(|e| format!("WsProxy: read 101 from {} failed: {}", addr_str, e))?;
215 let preview = &resp_buf[..m.min(20)];
216 if !preview.starts_with(b"HTTP/1.1 101") && !preview.starts_with(b"HTTP/1.0 101") {
217 return Err(format!(
218 "WsProxy: backend {} did not send 101 (got {:?})",
219 addr_str,
220 std::str::from_utf8(&resp_buf[..m.min(80)]).unwrap_or("?")
221 ));
222 }
223
224 let response_101 = WebSocket::handshake_response(request)?;
226 let raw_101 = format_response_head(&response_101);
227 client
228 .write_all(&raw_101)
229 .map_err(|e| format!("WsProxy: write 101 to client failed: {}", e))?;
230
231 tls.sock.set_read_timeout(Some(Duration::from_millis(5))).ok();
234 client.set_read_timeout(Some(Duration::from_millis(5))).ok();
235 relay_tls(client, tls);
236 Ok(())
237 }
238
239 #[cfg(not(any(feature = "http-client", feature = "http2")))]
240 {
241 let _ = (tcp, request, host, upgrade_req, addr_str);
242 let _ = client.write_all(
243 b"HTTP/1.1 502 Bad Gateway\r\nContent-Length: 0\r\n\r\n",
244 );
245 Err("WsProxy: wss:// upstreams require the http-client or http2 Cargo feature".to_string())
246 }
247 }
248}
249
250fn handle_plain(
252 mut client: TcpStream,
253 mut backend: TcpStream,
254 request: &Request,
255 upgrade_req: Vec<u8>,
256 addr_str: &str,
257) -> Result<(), String> {
258 backend
259 .write_all(&upgrade_req)
260 .map_err(|e| format!("WsProxy: write upgrade to {} failed: {}", addr_str, e))?;
261
262 let mut resp_buf = vec![0u8; 4096];
263 let m = backend
264 .read(&mut resp_buf)
265 .map_err(|e| format!("WsProxy: read 101 from {} failed: {}", addr_str, e))?;
266 let preview = &resp_buf[..m.min(20)];
267 if !preview.starts_with(b"HTTP/1.1 101") && !preview.starts_with(b"HTTP/1.0 101") {
268 return Err(format!(
269 "WsProxy: backend {} did not send 101 (got {:?})",
270 addr_str,
271 std::str::from_utf8(&resp_buf[..m.min(80)]).unwrap_or("?")
272 ));
273 }
274
275 let response_101 = WebSocket::handshake_response(request)?;
276 let raw_101 = format_response_head(&response_101);
277 client
278 .write_all(&raw_101)
279 .map_err(|e| format!("WsProxy: write 101 to client failed: {}", e))?;
280
281 let mut client_r = client.try_clone().map_err(|e| e.to_string())?;
283 let mut backend_r = backend.try_clone().map_err(|e| e.to_string())?;
284 let mut client_w = client;
285 let mut backend_w = backend;
286
287 let t1 = std::thread::spawn(move || {
288 std::io::copy(&mut client_r, &mut backend_w).ok();
289 let _ = backend_w.shutdown(std::net::Shutdown::Write);
290 });
291 let t2 = std::thread::spawn(move || {
292 std::io::copy(&mut backend_r, &mut client_w).ok();
293 let _ = client_w.shutdown(std::net::Shutdown::Write);
294 });
295
296 let _ = t1.join();
297 let _ = t2.join();
298 Ok(())
299}
300
#[cfg(any(feature = "http-client", feature = "http2"))]
/// Single-threaded bidirectional relay between a plain client socket and a TLS backend.
///
/// Both underlying sockets are expected to carry short read timeouts (the caller sets
/// 5 ms). Each pass polls the client, then the backend, and forwards whatever arrived.
/// When neither produced data, the loop sleeps 1 ms. The loop ends on EOF, on a non-timeout
/// read error, or on any write error.
fn relay_tls(
    mut client: TcpStream,
    mut backend: rustls::StreamOwned<rustls::ClientConnection, TcpStream>,
) {
    /// Reads once from `src`.
    /// - `Some(n)` with `n > 0`: bytes read.
    /// - `Some(0)`: the read timed out with nothing available.
    /// - `None`: the peer closed or the read failed.
    fn poll_read(src: &mut impl Read, buf: &mut [u8]) -> Option<usize> {
        match src.read(buf) {
            Ok(0) => None,
            Ok(n) => Some(n),
            Err(e)
                if matches!(
                    e.kind(),
                    std::io::ErrorKind::TimedOut | std::io::ErrorKind::WouldBlock
                ) =>
            {
                Some(0)
            }
            Err(_) => None,
        }
    }

    let mut chunk = [0u8; 8192];

    loop {
        let upstream = match poll_read(&mut client, &mut chunk) {
            Some(n) => n,
            None => break,
        };
        if upstream > 0 && backend.write_all(&chunk[..upstream]).is_err() {
            break;
        }

        let downstream = match poll_read(&mut backend, &mut chunk) {
            Some(n) => n,
            None => break,
        };
        if downstream > 0 && client.write_all(&chunk[..downstream]).is_err() {
            break;
        }

        let idle = upstream == 0 && downstream == 0;
        if idle {
            std::thread::sleep(Duration::from_millis(1));
        }
    }
}
353
354fn build_upgrade_request(request: &Request, backend_host: &str) -> Vec<u8> {
357 let mut req = format!(
358 "{} {} HTTP/1.1\r\nHost: {}\r\n",
359 request.method, request.request_uri, backend_host
360 );
361 for header in &request.headers {
362 if header.name.to_lowercase() == "host" {
363 continue;
364 }
365 req.push_str(&format!("{}: {}\r\n", header.name, header.value));
366 }
367 req.push_str("\r\n");
368 req.into_bytes()
369}
370
371fn format_response_head(response: &crate::response::Response) -> Vec<u8> {
372 let mut out = format!(
373 "HTTP/1.1 {} {}\r\n",
374 response.status_code, response.reason_phrase
375 )
376 .into_bytes();
377 for h in &response.headers {
378 out.extend_from_slice(h.name.as_bytes());
379 out.extend_from_slice(b": ");
380 out.extend_from_slice(h.value.as_bytes());
381 out.extend_from_slice(b"\r\n");
382 }
383 out.extend_from_slice(b"\r\n");
384 out
385}
386
/// One upstream target parsed from a `ws://`, `wss://` or bare `host[:port]` URL.
#[derive(Debug, Clone, PartialEq, Eq)]
struct WsBackend {
    /// `host:port` string handed to DNS resolution (IPv6 literals keep their brackets).
    addr: String,
    /// Host without the port; sent as the upstream `Host` header and used as the TLS name.
    host: String,
    /// `true` for `wss://` backends, which are reached over TLS.
    tls: bool,
}

impl WsBackend {
    /// Parses a backend URL into connection parameters.
    ///
    /// Accepted forms: `ws://…`, `wss://…` (scheme matched ASCII-case-insensitively), or a
    /// scheme-less `host[:port]`, which is treated as `ws://`.
    /// - Only the authority (`host[:port]` or `[ipv6][:port]`) is used; any path, query or
    ///   fragment is ignored.
    /// - A missing or empty port defaults to 80 (`ws`) or 443 (`wss`).
    ///
    /// Returns `None` for any of:
    /// - another scheme (e.g. `http://`);
    /// - an empty host;
    /// - a port that is non-numeric, zero or above 65535;
    /// - a malformed bracketed IPv6 literal.
    fn parse(url: &str) -> Option<Self> {
        let (rest, tls, default_port) = if let Some(r) = Self::strip_scheme(url, "wss://") {
            (r, true, 443u16)
        } else if let Some(r) = Self::strip_scheme(url, "ws://") {
            (r, false, 80u16)
        } else if Self::has_scheme(url) {
            // e.g. `http://…` — not an upstream this proxy knows how to reach.
            return None;
        } else {
            (url, false, 80u16)
        };

        // The authority ends at the first path, query or fragment delimiter.
        let authority = rest
            .split(|c: char| matches!(c, '/' | '?' | '#'))
            .next()
            .unwrap_or(rest);

        let (host, port_str) = Self::split_authority(authority)?;
        let port = match port_str {
            None | Some("") => default_port,
            Some(p) if p.bytes().all(|b| b.is_ascii_digit()) => match p.parse::<u16>() {
                Ok(n) if n != 0 => n,
                _ => return None,
            },
            Some(_) => return None,
        };

        if host.is_empty() || host == "[]" {
            return None;
        }

        Some(WsBackend {
            addr: format!("{}:{}", host, port),
            host: host.to_string(),
            tls,
        })
    }

    /// Strips `scheme` (e.g. `"wss://"`) from the front of `url`, ignoring ASCII case.
    fn strip_scheme<'a>(url: &'a str, scheme: &str) -> Option<&'a str> {
        let prefix = url.get(..scheme.len())?;
        if prefix.eq_ignore_ascii_case(scheme) {
            Some(&url[scheme.len()..])
        } else {
            None
        }
    }

    /// True when `url` starts with an RFC 3986 `scheme://` prefix.
    fn has_scheme(url: &str) -> bool {
        url.split_once("://").map_or(false, |(scheme, _)| {
            scheme.starts_with(|c: char| c.is_ascii_alphabetic())
                && scheme
                    .chars()
                    .all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'))
        })
    }

    /// Splits an authority into `(host, port)`; bracketed IPv6 literals keep their
    /// brackets. Returns `None` for an unterminated bracket or junk after `]`.
    fn split_authority(authority: &str) -> Option<(&str, Option<&str>)> {
        if let Some(inner) = authority.strip_prefix('[') {
            let close = inner.find(']')?;
            let host = &authority[..close + 2];
            match &inner[close + 1..] {
                "" => Some((host, None)),
                after => Some((host, Some(after.strip_prefix(':')?))),
            }
        } else {
            Some(match authority.rsplit_once(':') {
                Some((host, port)) => (host, Some(port)),
                None => (authority, None),
            })
        }
    }
}
433
#[cfg(test)]
mod backend_parse_tests {
    //! Unit tests for `WsBackend::parse` URL handling.

    use super::WsBackend;

    /// Parses `url` and flattens the result to `(addr, host, tls)` for comparison.
    fn parse(url: &str) -> Option<(String, String, bool)> {
        let backend = WsBackend::parse(url)?;
        Some((backend.addr, backend.host, backend.tls))
    }

    /// Asserts that `url` parses to exactly the given `addr`, `host` and `tls` flag.
    fn expect_parsed(url: &str, addr: &str, host: &str, tls: bool) {
        let wanted = Some((addr.to_string(), host.to_string(), tls));
        assert_eq!(wanted, parse(url), "unexpected parse result for {:?}", url);
    }

    #[test]
    fn bare_host_port() {
        expect_parsed("chat:9000", "chat:9000", "chat", false);
    }

    #[test]
    fn ws_scheme_plain() {
        expect_parsed("ws://backend:3000", "backend:3000", "backend", false);
    }

    #[test]
    fn ws_scheme_default_port() {
        expect_parsed("ws://api.example.com", "api.example.com:80", "api.example.com", false);
    }

    #[test]
    fn wss_scheme_sets_tls() {
        expect_parsed(
            "wss://secure.example.com",
            "secure.example.com:443",
            "secure.example.com",
            true,
        );
    }

    #[test]
    fn wss_scheme_explicit_port() {
        expect_parsed(
            "wss://secure.example.com:8443",
            "secure.example.com:8443",
            "secure.example.com",
            true,
        );
    }

    #[test]
    fn wss_default_port_is_443() {
        let backend = WsBackend::parse("wss://api.example.com").expect("wss URL should parse");
        assert_eq!("api.example.com:443", backend.addr);
        assert_eq!("api.example.com", backend.host);
        assert!(backend.tls);
    }

    #[test]
    fn ws_default_port_is_80() {
        let backend = WsBackend::parse("ws://api.example.com").expect("ws URL should parse");
        assert_eq!("api.example.com:80", backend.addr);
        assert!(!backend.tls);
    }

    #[test]
    fn empty_host_returns_none() {
        assert_eq!(None, parse("wss://"));
    }

    #[test]
    fn bare_host_no_port_defaults_to_80() {
        expect_parsed("myhost", "myhost:80", "myhost", false);
    }

    #[test]
    fn path_component_is_ignored() {
        let backend = WsBackend::parse("ws://backend:9000/ws").expect("URL with path should parse");
        assert_eq!("backend:9000", backend.addr);
        assert_eq!("backend", backend.host);
    }
}