1use std::collections::HashMap;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicBool, Ordering};
9
10use serde::Serialize;
11use tokio::sync::RwLock;
12use tokio::task::JoinHandle;
13use tokio_util::sync::CancellationToken;
14use uuid::Uuid;
15
16use crate::circuit_breaker::CircuitBreaker;
17use crate::error::{ProxyError, ProxyResult};
18use crate::health::{HealthChecker, HealthMap};
19use crate::storage::ProxyStoragePort;
20use crate::strategy::{
21 BoxedRotationStrategy, LeastUsedStrategy, ProxyCandidate, RandomStrategy, RoundRobinStrategy,
22 WeightedStrategy,
23};
24use crate::types::{Proxy, ProxyConfig};
25
26#[derive(Debug, Serialize)]
32pub struct PoolStats {
33 pub total: usize,
35 pub healthy: usize,
37 pub open: usize,
39}
40
41pub struct ProxyHandle {
51 pub proxy_url: String,
53 circuit_breaker: Arc<CircuitBreaker>,
54 succeeded: AtomicBool,
55}
56
57impl ProxyHandle {
58 fn new(proxy_url: String, circuit_breaker: Arc<CircuitBreaker>) -> Self {
59 Self {
60 proxy_url,
61 circuit_breaker,
62 succeeded: AtomicBool::new(false),
63 }
64 }
65
66 pub fn direct() -> Self {
71 let noop_cb = Arc::new(CircuitBreaker::new(u32::MAX, u64::MAX));
72 Self {
73 proxy_url: String::new(),
74 circuit_breaker: noop_cb,
75 succeeded: AtomicBool::new(true),
76 }
77 }
78
79 pub fn mark_success(&self) {
81 self.succeeded.store(true, Ordering::Release);
82 }
83}
84
85impl std::fmt::Debug for ProxyHandle {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 f.debug_struct("ProxyHandle")
88 .field("proxy_url", &self.proxy_url)
89 .finish_non_exhaustive()
90 }
91}
92
93impl Drop for ProxyHandle {
94 fn drop(&mut self) {
95 if self.succeeded.load(Ordering::Acquire) {
96 self.circuit_breaker.record_success();
97 } else {
98 self.circuit_breaker.record_failure();
99 }
100 }
101}
102
/// Proxy pool combining a storage backend, a pluggable rotation strategy, a background
/// health checker, and one circuit breaker per proxy.
pub struct ProxyManager {
    /// Backing store of proxy records and their usage metrics.
    storage: Arc<dyn ProxyStoragePort>,
    /// Strategy used by `acquire_proxy` to choose among candidates.
    strategy: BoxedRotationStrategy,
    /// Health checker (spawned by `start`) whose health map feeds candidate selection.
    health_checker: HealthChecker,
    /// Circuit breaker per proxy id; the `Arc`s are also held by outstanding `ProxyHandle`s.
    circuit_breakers: Arc<RwLock<HashMap<Uuid, Arc<CircuitBreaker>>>>,
    /// Manager configuration; supplies the circuit-breaker thresholds.
    config: ProxyConfig,
}
146
147impl ProxyManager {
148 pub fn builder() -> ProxyManagerBuilder {
150 ProxyManagerBuilder::default()
151 }
152
153 pub fn with_round_robin(
155 storage: Arc<dyn ProxyStoragePort>,
156 config: ProxyConfig,
157 ) -> ProxyResult<Self> {
158 Self::builder()
159 .storage(storage)
160 .strategy(Arc::new(RoundRobinStrategy::default()))
161 .config(config)
162 .build()
163 }
164
165 pub fn with_random(
167 storage: Arc<dyn ProxyStoragePort>,
168 config: ProxyConfig,
169 ) -> ProxyResult<Self> {
170 Self::builder()
171 .storage(storage)
172 .strategy(Arc::new(RandomStrategy))
173 .config(config)
174 .build()
175 }
176
177 pub fn with_weighted(
179 storage: Arc<dyn ProxyStoragePort>,
180 config: ProxyConfig,
181 ) -> ProxyResult<Self> {
182 Self::builder()
183 .storage(storage)
184 .strategy(Arc::new(WeightedStrategy))
185 .config(config)
186 .build()
187 }
188
189 pub fn with_least_used(
191 storage: Arc<dyn ProxyStoragePort>,
192 config: ProxyConfig,
193 ) -> ProxyResult<Self> {
194 Self::builder()
195 .storage(storage)
196 .strategy(Arc::new(LeastUsedStrategy))
197 .config(config)
198 .build()
199 }
200
201 pub async fn add_proxy(&self, proxy: Proxy) -> ProxyResult<Uuid> {
212 let mut cb_map = self.circuit_breakers.write().await;
213 let record = self.storage.add(proxy).await?;
214 cb_map.insert(
215 record.id,
216 Arc::new(CircuitBreaker::new(
217 self.config.circuit_open_threshold,
218 self.config.circuit_half_open_after.as_millis() as u64,
219 )),
220 );
221 Ok(record.id)
222 }
223
224 pub async fn remove_proxy(&self, id: Uuid) -> ProxyResult<()> {
226 self.storage.remove(id).await?;
227 self.circuit_breakers.write().await.remove(&id);
228 Ok(())
229 }
230
231 pub fn start(&self) -> (CancellationToken, JoinHandle<()>) {
238 let token = CancellationToken::new();
239 let handle = self.health_checker.clone().spawn(token.clone());
240 (token, handle)
241 }
242
243 pub async fn acquire_proxy(&self) -> ProxyResult<ProxyHandle> {
251 let with_metrics = self.storage.list_with_metrics().await?;
252 if with_metrics.is_empty() {
253 return Err(ProxyError::PoolExhausted);
254 }
255
256 let health_map: tokio::sync::RwLockReadGuard<'_, _> =
257 self.health_checker.health_map().read().await;
258 let cb_map = self.circuit_breakers.read().await;
259
260 let candidates: Vec<ProxyCandidate> = with_metrics
261 .iter()
262 .map(|(record, metrics)| {
263 let healthy = health_map.get(&record.id).copied().unwrap_or(true);
265 let available = cb_map
266 .get(&record.id)
267 .map(|cb| cb.is_available())
268 .unwrap_or(true);
269 ProxyCandidate {
270 id: record.id,
271 weight: record.proxy.weight,
272 metrics: Arc::clone(metrics),
273 healthy: healthy && available,
274 }
275 })
276 .collect();
277
278 drop(health_map);
279 let selected = self.strategy.select(&candidates).await?;
280 let id = selected.id;
281
282 let cb = cb_map.get(&id).cloned().ok_or(ProxyError::PoolExhausted)?;
286
287 let url = with_metrics
288 .iter()
289 .find(|(r, _)| r.id == id)
290 .map(|(r, _)| r.proxy.url.clone())
291 .unwrap_or_default();
292
293 Ok(ProxyHandle::new(url, cb))
294 }
295
296 pub async fn pool_stats(&self) -> ProxyResult<PoolStats> {
300 let records = self.storage.list().await?;
301 let total = records.len();
302 let health_map = self.health_checker.health_map().read().await;
303 let cb_map = self.circuit_breakers.read().await;
304
305 let mut healthy = 0usize;
306 let mut open = 0usize;
307 for r in &records {
308 if health_map.get(&r.id).copied().unwrap_or(true) {
309 healthy += 1;
310 }
311 if cb_map
312 .get(&r.id)
313 .map(|cb| !cb.is_available())
314 .unwrap_or(false)
315 {
316 open += 1;
317 }
318 }
319 Ok(PoolStats {
320 total,
321 healthy,
322 open,
323 })
324 }
325}
326
327#[derive(Default)]
333pub struct ProxyManagerBuilder {
334 storage: Option<Arc<dyn ProxyStoragePort>>,
335 strategy: Option<BoxedRotationStrategy>,
336 config: Option<ProxyConfig>,
337}
338
339impl ProxyManagerBuilder {
340 pub fn storage(mut self, s: Arc<dyn ProxyStoragePort>) -> Self {
341 self.storage = Some(s);
342 self
343 }
344
345 pub fn strategy(mut self, s: BoxedRotationStrategy) -> Self {
346 self.strategy = Some(s);
347 self
348 }
349
350 pub fn config(mut self, c: ProxyConfig) -> Self {
351 self.config = Some(c);
352 self
353 }
354
355 pub fn build(self) -> ProxyResult<ProxyManager> {
361 let storage = self.storage.ok_or_else(|| {
362 ProxyError::ConfigError("ProxyManagerBuilder: storage is required".into())
363 })?;
364 let strategy = self
365 .strategy
366 .unwrap_or_else(|| Arc::new(RoundRobinStrategy::default()));
367 let config = self.config.unwrap_or_default();
368 let health_map: HealthMap = Arc::new(RwLock::new(HashMap::new()));
369 let health_checker = HealthChecker::new(
370 config.clone(),
371 Arc::clone(&storage),
372 Arc::clone(&health_map),
373 );
374 Ok(ProxyManager {
375 storage,
376 strategy,
377 health_checker,
378 circuit_breakers: Arc::new(RwLock::new(HashMap::new())),
379 config,
380 })
381 }
382}
383
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::time::Duration;

    use super::*;
    use crate::circuit_breaker::{STATE_CLOSED, STATE_OPEN};
    use crate::storage::MemoryProxyStore;
    use crate::types::ProxyType;

    /// Plain HTTP proxy at `url` with weight 1 and no credentials or tags.
    fn make_proxy(url: &str) -> Proxy {
        Proxy {
            url: url.into(),
            proxy_type: ProxyType::Http,
            username: None,
            password: None,
            weight: 1,
            tags: vec![],
        }
    }

    /// Fresh, empty in-memory store.
    fn storage() -> Arc<MemoryProxyStore> {
        Arc::new(MemoryProxyStore::default())
    }

    /// Round-robin manager over an empty store with `circuit_open_threshold = threshold`.
    fn manager_with_threshold(threshold: u32) -> ProxyManager {
        let config = ProxyConfig {
            circuit_open_threshold: threshold,
            ..ProxyConfig::default()
        };
        ProxyManager::with_round_robin(storage(), config).unwrap()
    }

    /// Every proxy in the pool is handed out over repeated acquisitions.
    #[tokio::test]
    async fn round_robin_distribution() {
        let mgr = ProxyManager::with_round_robin(storage(), ProxyConfig::default()).unwrap();
        for url in ["http://a.test:8080", "http://b.test:8080", "http://c.test:8080"] {
            mgr.add_proxy(make_proxy(url)).await.unwrap();
        }

        let mut seen = HashSet::new();
        for _ in 0..10 {
            let h = mgr.acquire_proxy().await.unwrap();
            h.mark_success();
            seen.insert(h.proxy_url.clone());
        }
        assert_eq!(seen.len(), 3, "all three proxies should have been selected");
    }

    /// When the only proxy's breaker is open, acquisition reports `AllProxiesUnhealthy`.
    #[tokio::test]
    async fn all_open_returns_error() {
        let mgr = manager_with_threshold(1);
        let id = mgr
            .add_proxy(make_proxy("http://x.test:8080"))
            .await
            .unwrap();

        // Trip the breaker directly; the guard is released before acquiring.
        {
            let map = mgr.circuit_breakers.read().await;
            map.get(&id).unwrap().record_failure();
        }

        let err = mgr.acquire_proxy().await.unwrap_err();
        assert!(
            matches!(err, ProxyError::AllProxiesUnhealthy),
            "expected AllProxiesUnhealthy, got {err:?}"
        );
    }

    /// Dropping a handle without `mark_success` counts as a failure.
    #[tokio::test]
    async fn handle_drop_records_failure() {
        let mgr = manager_with_threshold(1);
        let id = mgr
            .add_proxy(make_proxy("http://y.test:8080"))
            .await
            .unwrap();

        {
            let _h = mgr.acquire_proxy().await.unwrap();
        }

        let cb_map = mgr.circuit_breakers.read().await;
        assert_eq!(cb_map.get(&id).unwrap().state(), STATE_OPEN);
    }

    /// A handle marked successful leaves the breaker closed.
    #[tokio::test]
    async fn handle_success_keeps_closed() {
        let mgr = ProxyManager::with_round_robin(storage(), ProxyConfig::default()).unwrap();
        let id = mgr
            .add_proxy(make_proxy("http://z.test:8080"))
            .await
            .unwrap();

        let h = mgr.acquire_proxy().await.unwrap();
        h.mark_success();
        drop(h);

        let cb_map = mgr.circuit_breakers.read().await;
        assert_eq!(cb_map.get(&id).unwrap().state(), STATE_CLOSED);
    }

    /// Cancelling the token stops the health checker promptly and without panicking.
    #[tokio::test]
    async fn start_and_graceful_shutdown() {
        let mgr = ProxyManager::with_round_robin(
            storage(),
            ProxyConfig {
                health_check_interval: Duration::from_secs(3600),
                ..ProxyConfig::default()
            },
        )
        .unwrap();

        let (token, handle) = mgr.start();
        token.cancel();

        let joined = tokio::time::timeout(Duration::from_secs(1), handle)
            .await
            .expect("health checker task should exit within 1s");
        // A timeout-only check would also pass if the task had panicked on shutdown.
        assert!(
            joined.is_ok(),
            "health checker task did not exit cleanly: {joined:?}"
        );
    }
}