Skip to main content

stygian_proxy/
health.rs

1//! Async background health checker for proxy liveness verification.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Instant;
6
7use tokio::sync::RwLock;
8use tokio::task::{JoinHandle, JoinSet};
9use tokio_util::sync::CancellationToken;
10use uuid::Uuid;
11
12use crate::storage::ProxyStoragePort;
13use crate::types::ProxyConfig;
14
/// Shared health map type.
/// `true` = proxy is currently considered healthy.
///
/// Keys are proxy record ids. Wrapped in `Arc<RwLock<…>>` so the background
/// checker task (writer) and selection code (readers) can share one map.
pub type HealthMap = Arc<RwLock<HashMap<Uuid, bool>>>;
18
/// Continuously verifies proxy liveness and updates the shared [`HealthMap`].
///
/// Run one check cycle with [`check_once`](HealthChecker::check_once) or launch
/// a background task with [`spawn`](HealthChecker::spawn).
#[derive(Clone)]
pub struct HealthChecker {
    /// Health-check settings: target URL, tick interval, per-request timeout.
    config: ProxyConfig,
    /// Source of the proxy records to verify each cycle.
    storage: Arc<dyn ProxyStoragePort>,
    /// Shared with the `ProxyManager`; rewritten after every check cycle.
    health_map: HealthMap,
}
29
30impl HealthChecker {
31    /// Access the shared health map (read it to filter candidates).
32    pub fn health_map(&self) -> &HealthMap {
33        &self.health_map
34    }
35
36    /// Create a new checker.
37    ///
38    /// `health_map` should be the **same** `Arc` held by the `ProxyManager` so
39    /// that selection decisions always see up-to-date health information.
40    pub fn new(
41        config: ProxyConfig,
42        storage: Arc<dyn ProxyStoragePort>,
43        health_map: HealthMap,
44    ) -> Self {
45        Self {
46            config,
47            storage,
48            health_map,
49        }
50    }
51
52    /// Spawn an infinite background task that checks proxies on every
53    /// `config.health_check_interval` tick.
54    ///
55    /// Cancel `token` to stop the task gracefully.  Missed ticks are skipped.
56    pub fn spawn(self, token: CancellationToken) -> JoinHandle<()> {
57        tokio::spawn(async move {
58            let mut interval = tokio::time::interval(self.config.health_check_interval);
59            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
60            loop {
61                tokio::select! {
62                    _ = token.cancelled() => {
63                        tracing::info!("health checker: shutdown requested");
64                        break;
65                    }
66                    _ = interval.tick() => {
67                        self.check_all().await;
68                    }
69                }
70            }
71            tracing::info!("health checker: stopped");
72        })
73    }
74
75    /// Run one full check cycle synchronously (useful for tests).
76    pub async fn check_once(&self) {
77        self.check_all().await;
78    }
79
80    async fn check_all(&self) {
81        let records = match self.storage.list().await {
82            Ok(r) => r,
83            Err(e) => {
84                tracing::error!("health checker: storage list failed: {e}");
85                return;
86            }
87        };
88
89        let health_url = self.config.health_check_url.clone();
90        let timeout = self.config.health_check_timeout;
91
92        let mut set: JoinSet<(Uuid, Result<u64, String>)> = JoinSet::new();
93        for record in records {
94            let proxy_url = record.proxy.url.clone();
95            let username = record.proxy.username.clone();
96            let password = record.proxy.password.clone();
97            let id = record.id;
98            let check_url = health_url.clone();
99            set.spawn(async move {
100                let result = do_check(
101                    &proxy_url,
102                    username.as_deref(),
103                    password.as_deref(),
104                    &check_url,
105                    timeout,
106                )
107                .await;
108                (id, result)
109            });
110        }
111
112        let mut updates: Vec<(Uuid, bool, u64)> = Vec::new();
113        while let Some(task_result) = set.join_next().await {
114            match task_result {
115                Ok((id, Ok(latency_ms))) => updates.push((id, true, latency_ms)),
116                Ok((id, Err(e))) => {
117                    tracing::warn!(proxy = %id, error = %e, "health check failed");
118                    updates.push((id, false, 0));
119                }
120                Err(join_err) => {
121                    tracing::error!("health check task panicked: {join_err}");
122                }
123            }
124        }
125
126        let total = updates.len() as u32;
127        let healthy_count = updates.iter().filter(|(_, h, _)| *h).count() as u32;
128
129        {
130            let mut map = self.health_map.write().await;
131            for (id, healthy, _) in &updates {
132                map.insert(*id, *healthy);
133            }
134        }
135
136        for (id, success, latency) in updates {
137            if let Err(e) = self.storage.update_metrics(id, success, latency).await {
138                tracing::warn!("health checker: metrics update failed for {id}: {e}");
139            }
140        }
141
142        tracing::info!(
143            total,
144            healthy = healthy_count,
145            unhealthy = total - healthy_count,
146            "health check cycle complete"
147        );
148    }
149}
150
151/// Route a GET request through `proxy_url` to `health_url` and return the
152/// elapsed time in milliseconds on success.
153async fn do_check(
154    proxy_url: &str,
155    username: Option<&str>,
156    password: Option<&str>,
157    health_url: &str,
158    timeout: std::time::Duration,
159) -> Result<u64, String> {
160    let mut proxy = reqwest::Proxy::all(proxy_url).map_err(|e| e.to_string())?;
161    if let (Some(user), Some(pass)) = (username, password) {
162        proxy = proxy.basic_auth(user, pass);
163    }
164    let client = reqwest::Client::builder()
165        .proxy(proxy)
166        .timeout(timeout)
167        .build()
168        .map_err(|e| e.to_string())?;
169
170    let start = Instant::now();
171    client
172        .get(health_url)
173        .send()
174        .await
175        .map_err(|e| e.to_string())?
176        .error_for_status()
177        .map_err(|e| e.to_string())?;
178    Ok(start.elapsed().as_millis() as u64)
179}
180
181// ─────────────────────────────────────────────────────────────────────────────
182// Tests
183// ─────────────────────────────────────────────────────────────────────────────
184
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use wiremock::matchers::method;
    use wiremock::{Mock, MockServer, ResponseTemplate};

    use super::*;
    use crate::storage::MemoryProxyStore;
    use crate::types::{Proxy, ProxyType};

    /// Plain HTTP proxy entry: no credentials, unit weight, no tags.
    fn make_proxy(url: &str) -> Proxy {
        Proxy {
            proxy_type: ProxyType::Http,
            url: url.into(),
            username: None,
            password: None,
            weight: 1,
            tags: vec![],
        }
    }

    #[tokio::test]
    async fn healthy_and_unhealthy_proxies() {
        // Mock server acts as both the HTTP proxy and the health-check target.
        // reqwest sends the GET in absolute-form; wiremock responds 200.
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;

        let storage = Arc::new(MemoryProxyStore::default());
        // First entry reaches the mock server → check succeeds; second points
        // at an unroutable TEST-NET-1 address → check fails.
        storage.add(make_proxy(&server.uri())).await.unwrap();
        storage
            .add(make_proxy("http://192.0.2.1:9999"))
            .await
            .unwrap();

        let health_map: HealthMap = Arc::new(RwLock::new(HashMap::new()));
        let config = ProxyConfig {
            health_check_url: format!("{}/", server.uri()),
            health_check_interval: Duration::from_secs(3600),
            health_check_timeout: Duration::from_secs(2),
            ..ProxyConfig::default()
        };

        HealthChecker::new(config, storage.clone(), health_map.clone())
            .check_once()
            .await;

        let snapshot = health_map.read().await;
        let (healthy, unhealthy) = snapshot.values().fold((0, 0), |(up, down), &ok| {
            if ok {
                (up + 1, down)
            } else {
                (up, down + 1)
            }
        });
        assert_eq!(healthy, 1, "expected 1 healthy proxy");
        assert_eq!(unhealthy, 1, "expected 1 unhealthy proxy");
    }

    #[tokio::test]
    async fn graceful_shutdown() {
        let config = ProxyConfig {
            health_check_interval: Duration::from_secs(3600),
            ..ProxyConfig::default()
        };
        let checker = HealthChecker::new(
            config,
            Arc::new(MemoryProxyStore::default()),
            Arc::new(RwLock::new(HashMap::new())),
        );

        let token = CancellationToken::new();
        let handle = checker.spawn(token.clone());
        token.cancel();

        let result = tokio::time::timeout(Duration::from_secs(1), handle).await;
        assert!(
            result.is_ok(),
            "task should exit within 1s after cancellation"
        );
    }
}