rust_web_server/proxy_config/
health.rs1use std::io::{Read, Write};
8use std::net::{TcpStream, ToSocketAddrs};
9use std::sync::{Arc, RwLock};
10use std::time::Duration;
11
12use crate::proxy_config::HealthCheckConfig;
13
14pub(crate) fn start_health_checker(
20 upstream_name: String,
21 backends: Vec<String>,
22 live: Arc<RwLock<Vec<String>>>,
23 config: HealthCheckConfig,
24) {
25 std::thread::Builder::new()
26 .name(format!("health-{}", upstream_name))
27 .spawn(move || {
28 let interval = Duration::from_secs(config.interval_secs);
29 let timeout = Duration::from_millis(config.timeout_ms);
30 let mut successes: Vec<u32> = vec![0; backends.len()];
32 let mut failures: Vec<u32> = vec![0; backends.len()];
33 let mut is_live: Vec<bool> = vec![true; backends.len()];
35
36 loop {
37 std::thread::sleep(interval);
38
39 for (i, backend) in backends.iter().enumerate() {
40 let ok = check_backend(backend, &config.path, timeout);
41 if ok {
42 successes[i] += 1;
43 failures[i] = 0;
44 if !is_live[i] && successes[i] >= config.healthy_threshold {
46 is_live[i] = true;
47 eprintln!(
48 "[health] upstream={} backend={} restored ({}x ok)",
49 upstream_name, backend, successes[i]
50 );
51 }
52 } else {
53 failures[i] += 1;
54 successes[i] = 0;
55 if is_live[i] && failures[i] >= config.unhealthy_threshold {
57 is_live[i] = false;
58 eprintln!(
59 "[health] upstream={} backend={} removed ({}x fail)",
60 upstream_name, backend, failures[i]
61 );
62 }
63 }
64 }
65
66 let live_list: Vec<String> = backends
68 .iter()
69 .enumerate()
70 .filter(|(i, _)| is_live[*i])
71 .map(|(_, b)| b.clone())
72 .collect();
73 if let Ok(mut guard) = live.write() {
74 *guard = live_list;
75 }
76 }
77 })
78 .ok();
79}
80
81fn check_backend(backend: &str, path: &str, timeout: Duration) -> bool {
84 let (host, port, tls) = match parse_backend_url(backend) {
85 Some(t) => t,
86 None => return false,
87 };
88
89 let addr_str = format!("{}:{}", host, port);
90 let sock_addr = match addr_str.to_socket_addrs().ok().and_then(|mut a| a.next()) {
91 Some(a) => a,
92 None => return false,
93 };
94
95 let stream = match TcpStream::connect_timeout(&sock_addr, timeout) {
96 Ok(s) => s,
97 Err(_) => return false,
98 };
99 let _ = stream.set_read_timeout(Some(timeout));
100 let _ = stream.set_write_timeout(Some(timeout));
101
102 let req = format!(
103 "GET {} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
104 path, host
105 );
106
107 if tls {
108 check_via_tls(stream, &host, req.as_bytes())
109 } else {
110 let mut stream = stream;
111 if stream.write_all(req.as_bytes()).is_err() {
112 return false;
113 }
114 let mut buf = [0u8; 16];
115 if stream.read(&mut buf).is_err() {
116 return false;
117 }
118 buf.starts_with(b"HTTP/1.1 2") || buf.starts_with(b"HTTP/1.0 2")
119 }
120}
121
122#[cfg(any(feature = "http-client", feature = "http2"))]
123fn check_via_tls(stream: TcpStream, host: &str, req: &[u8]) -> bool {
124 use rustls::pki_types::ServerName;
125 use rustls::ClientConfig;
126 use std::sync::Arc;
127
128 let root_store =
129 rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
130 let config = Arc::new(
131 ClientConfig::builder()
132 .with_root_certificates(root_store)
133 .with_no_client_auth(),
134 );
135 let server_name = match ServerName::try_from(host.to_string()) {
136 Ok(n) => n,
137 Err(_) => return false,
138 };
139 let conn = match rustls::ClientConnection::new(config, server_name) {
140 Ok(c) => c,
141 Err(_) => return false,
142 };
143 let mut tls = rustls::StreamOwned::new(conn, stream);
144 if tls.write_all(req).is_err() {
145 return false;
146 }
147 let mut buf = [0u8; 16];
148 if tls.read(&mut buf).is_err() {
149 return false;
150 }
151 buf.starts_with(b"HTTP/1.1 2") || buf.starts_with(b"HTTP/1.0 2")
152}
153
154#[cfg(not(any(feature = "http-client", feature = "http2")))]
156fn check_via_tls(_stream: TcpStream, _host: &str, _req: &[u8]) -> bool {
157 false
158}
159
/// Splits a backend URL into `(host, port, tls)`.
///
/// Recognised schemes:
/// * `https://` → TLS, default port 443;
/// * `http://` and `h2://` → plain, default port 80;
/// * no scheme → treated as plain HTTP.
///
/// Everything from the first `/` onwards is ignored. IPv6 literals must be bracketed
/// (`[::1]:8080`), and the brackets are stripped from the returned host.
///
/// A port that does not parse as `u16` falls back to the scheme default. In the bracketed form,
/// the text after `]` is then discarded. In the unbracketed form, it stays part of the host.
///
/// Returns `None` when:
/// * the authority part is empty;
/// * an opening `[` has no matching `]`;
/// * the resulting host is empty.
pub(crate) fn parse_backend_url(backend: &str) -> Option<(String, u16, bool)> {
    const SCHEMES: [(&str, bool, u16); 3] = [
        ("https://", true, 443),
        ("http://", false, 80),
        ("h2://", false, 80),
    ];

    let (rest, tls, default_port) = SCHEMES
        .iter()
        .find_map(|&(prefix, is_tls, port)| {
            backend.strip_prefix(prefix).map(|tail| (tail, is_tls, port))
        })
        .unwrap_or((backend, false, 80));

    let authority = match rest.find('/') {
        Some(slash) => &rest[..slash],
        None => rest,
    };
    if authority.is_empty() {
        return None;
    }

    let (host, port): (&str, u16) = if let Some(inside) = authority.strip_prefix('[') {
        let close = inside.find(']')?;
        let port = inside[close + 1..]
            .strip_prefix(':')
            .and_then(|digits| digits.parse::<u16>().ok())
            .unwrap_or(default_port);
        (&inside[..close], port)
    } else {
        match authority.rsplit_once(':') {
            Some((name, digits)) => match digits.parse::<u16>() {
                Ok(explicit) => (name, explicit),
                Err(_) => (authority, default_port),
            },
            None => (authority, default_port),
        }
    };

    if host.is_empty() {
        None
    } else {
        Some((host.to_string(), port, tls))
    }
}
209