Skip to main content

rust_web_server/proxy_config/
health.rs

//! Background health-checker for upstream backends.
//!
//! Each `[[upstream]]` with a `[upstream.health_check]` section gets a
//! dedicated background thread that periodically sends `GET {path}` to every
//! backend and updates the shared `Arc<RwLock<Vec<String>>>` live-backend list.
6
7use std::io::{Read, Write};
8use std::net::{TcpStream, ToSocketAddrs};
9use std::sync::{Arc, RwLock};
10use std::time::Duration;
11
12use crate::proxy_config::HealthCheckConfig;
13
14/// Start a background health-checker thread.
15///
16/// The thread runs until the process exits. It periodically checks every
17/// backend in `backends` by sending `GET {config.path} HTTP/1.1` and tracking
18/// consecutive successes/failures. The `live` list is updated accordingly.
19pub(crate) fn start_health_checker(
20    upstream_name: String,
21    backends: Vec<String>,
22    live: Arc<RwLock<Vec<String>>>,
23    config: HealthCheckConfig,
24) {
25    std::thread::Builder::new()
26        .name(format!("health-{}", upstream_name))
27        .spawn(move || {
28            let interval = Duration::from_secs(config.interval_secs);
29            let timeout = Duration::from_millis(config.timeout_ms);
30            // Per-backend consecutive success/failure counters
31            let mut successes: Vec<u32> = vec![0; backends.len()];
32            let mut failures: Vec<u32> = vec![0; backends.len()];
33            // Initial state: all backends considered alive
34            let mut is_live: Vec<bool> = vec![true; backends.len()];
35
36            loop {
37                std::thread::sleep(interval);
38
39                for (i, backend) in backends.iter().enumerate() {
40                    let ok = check_backend(backend, &config.path, timeout);
41                    if ok {
42                        successes[i] += 1;
43                        failures[i] = 0;
44                        // Restore if we have enough consecutive successes
45                        if !is_live[i] && successes[i] >= config.healthy_threshold {
46                            is_live[i] = true;
47                            eprintln!(
48                                "[health] upstream={} backend={} restored ({}x ok)",
49                                upstream_name, backend, successes[i]
50                            );
51                        }
52                    } else {
53                        failures[i] += 1;
54                        successes[i] = 0;
55                        // Remove if we have enough consecutive failures
56                        if is_live[i] && failures[i] >= config.unhealthy_threshold {
57                            is_live[i] = false;
58                            eprintln!(
59                                "[health] upstream={} backend={} removed ({}x fail)",
60                                upstream_name, backend, failures[i]
61                            );
62                        }
63                    }
64                }
65
66                // Update the live list
67                let live_list: Vec<String> = backends
68                    .iter()
69                    .enumerate()
70                    .filter(|(i, _)| is_live[*i])
71                    .map(|(_, b)| b.clone())
72                    .collect();
73                if let Ok(mut guard) = live.write() {
74                    *guard = live_list;
75                }
76            }
77        })
78        .ok();
79}
80
81/// Send a minimal HTTP(S)/1.1 GET request to `backend` at `path` with the
82/// given `timeout`. Returns `true` on a 2xx response.
83fn check_backend(backend: &str, path: &str, timeout: Duration) -> bool {
84    let (host, port, tls) = match parse_backend_url(backend) {
85        Some(t) => t,
86        None => return false,
87    };
88
89    let addr_str = format!("{}:{}", host, port);
90    let sock_addr = match addr_str.to_socket_addrs().ok().and_then(|mut a| a.next()) {
91        Some(a) => a,
92        None => return false,
93    };
94
95    let stream = match TcpStream::connect_timeout(&sock_addr, timeout) {
96        Ok(s) => s,
97        Err(_) => return false,
98    };
99    let _ = stream.set_read_timeout(Some(timeout));
100    let _ = stream.set_write_timeout(Some(timeout));
101
102    let req = format!(
103        "GET {} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
104        path, host
105    );
106
107    if tls {
108        check_via_tls(stream, &host, req.as_bytes())
109    } else {
110        let mut stream = stream;
111        if stream.write_all(req.as_bytes()).is_err() {
112            return false;
113        }
114        let mut buf = [0u8; 16];
115        if stream.read(&mut buf).is_err() {
116            return false;
117        }
118        buf.starts_with(b"HTTP/1.1 2") || buf.starts_with(b"HTTP/1.0 2")
119    }
120}
121
/// Run the health probe over TLS (rustls) on an already-connected `stream`.
///
/// `host` is used as the TLS server name and `req` is the complete,
/// pre-formatted HTTP request to send. The server certificate is verified
/// against the `webpki_roots` trust anchors only, and no client certificate is
/// presented; a backend whose certificate does not chain to one of those roots
/// therefore fails the check.
///
/// Returns `true` only when the response starts with `HTTP/1.1 2` or
/// `HTTP/1.0 2`; any TLS, write or read error yields `false`.
#[cfg(any(feature = "http-client", feature = "http2"))]
fn check_via_tls(stream: TcpStream, host: &str, req: &[u8]) -> bool {
    use rustls::pki_types::ServerName;
    use rustls::ClientConfig;
    use std::sync::Arc;

    // Length of the status-line prefix we need to see ("HTTP/1.x 2").
    const STATUS_PREFIX_LEN: usize = 10;

    let root_store =
        rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    let config = Arc::new(
        ClientConfig::builder()
            .with_root_certificates(root_store)
            .with_no_client_auth(),
    );
    let server_name = match ServerName::try_from(host.to_string()) {
        Ok(name) => name,
        Err(_) => return false,
    };
    let conn = match rustls::ClientConnection::new(config, server_name) {
        Ok(c) => c,
        Err(_) => return false,
    };

    // The handshake is driven implicitly by the first write/read on `tls`.
    let mut tls = rustls::StreamOwned::new(conn, stream);
    if tls.write_all(req).is_err() {
        return false;
    }

    // A single `read` may return fewer bytes than the status prefix, so keep
    // reading until enough data is available or the peer closes.
    let mut buf = [0u8; 16];
    let mut filled = 0;
    while filled < STATUS_PREFIX_LEN {
        match tls.read(&mut buf[filled..]) {
            Ok(0) => break,
            Ok(n) => filled += n,
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(_) => return false,
        }
    }
    // `buf` is zero-initialised, so a short response can never match.
    buf.starts_with(b"HTTP/1.1 2") || buf.starts_with(b"HTTP/1.0 2")
}
153
/// Fallback used when neither the `http-client` nor the `http2` feature is
/// enabled, i.e. when the rustls-based implementation is compiled out.
///
/// No TLS probe is attempted: the arguments are dropped and the check is
/// reported as failed. As a consequence `check_backend` returns `false` for
/// every `https://` backend in such builds, which counts as a failed check
/// rather than a skipped one.
#[cfg(not(any(feature = "http-client", feature = "http2")))]
fn check_via_tls(_stream: TcpStream, _host: &str, _req: &[u8]) -> bool {
    false
}
159
/// Parse a backend address that may include a scheme prefix.
///
/// Returns `Some((host, port, tls))`:
/// - `https://` → `tls = true`, default port 443
/// - `http://`, `h2://`, or no scheme → `tls = false`, default port 80
///
/// Everything from the first `/` after the scheme is ignored. IPv6 literals
/// may be written bracketed (`[::1]`, `[::1]:8080`) or, without a port, bare
/// (`::1`); the returned host never contains brackets. A port that is missing
/// or not a valid `u16` falls back to the scheme default.
///
/// Returns `None` when the host part is empty or a `[` is never closed.
pub(crate) fn parse_backend_url(backend: &str) -> Option<(String, u16, bool)> {
    let (rest, tls, default_port) = if let Some(r) = backend.strip_prefix("https://") {
        (r, true, 443u16)
    } else if let Some(r) = backend
        .strip_prefix("http://")
        .or_else(|| backend.strip_prefix("h2://"))
    {
        (r, false, 80u16)
    } else {
        (backend, false, 80u16)
    };

    // Drop any path component.
    let host_port = rest.split('/').next().unwrap_or(rest);
    if host_port.is_empty() {
        return None;
    }

    let (host, port) = if let Some(bracketed) = host_port.strip_prefix('[') {
        // Bracketed IPv6 literal: `[host]:port` or `[host]`.
        let close = bracketed.find(']')?;
        let port = bracketed[close + 1..]
            .strip_prefix(':')
            .and_then(|p| p.parse::<u16>().ok())
            .unwrap_or(default_port);
        (bracketed[..close].to_string(), port)
    } else {
        match host_port.rsplit_once(':') {
            Some((name, port_str)) => match port_str.parse::<u16>() {
                Ok(port) => {
                    // A bare IPv6 literal such as `::1` or `fe80::1` ends in a
                    // group that also parses as a port. If splitting leaves a
                    // colon-containing host that is not itself an IPv6 address
                    // while the whole string is one, the "port" was really the
                    // last address group. Inputs like `::1:8080`, where the
                    // split host is valid IPv6, keep their host/port reading.
                    let split_is_bogus = name.contains(':')
                        && name.parse::<std::net::Ipv6Addr>().is_err()
                        && host_port.parse::<std::net::Ipv6Addr>().is_ok();
                    if split_is_bogus {
                        (host_port.to_string(), default_port)
                    } else {
                        (name.to_string(), port)
                    }
                }
                // Not a port: keep the whole string as the host.
                Err(_) => (host_port.to_string(), default_port),
            },
            None => (host_port.to_string(), default_port),
        }
    };

    if host.is_empty() {
        return None;
    }
    Some((host, port, tls))
}
209