1use std::collections::HashMap;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicBool, Ordering};
9
10use serde::Serialize;
11use tokio::sync::RwLock;
12use tokio::task::JoinHandle;
13use tokio_util::sync::CancellationToken;
14use uuid::Uuid;
15
16use crate::circuit_breaker::CircuitBreaker;
17use crate::error::{ProxyError, ProxyResult};
18use crate::health::{HealthChecker, HealthMap};
19use crate::session::{SessionMap, StickyPolicy};
20use crate::storage::ProxyStoragePort;
21use crate::strategy::{
22 BoxedRotationStrategy, LeastUsedStrategy, ProxyCandidate, RandomStrategy, RoundRobinStrategy,
23 WeightedStrategy,
24};
25use crate::types::{Proxy, ProxyConfig};
26
27#[derive(Debug, Serialize)]
33pub struct PoolStats {
34 pub total: usize,
36 pub healthy: usize,
38 pub open: usize,
40 pub active_sessions: usize,
42}
43
44pub struct ProxyHandle {
54 pub proxy_url: String,
56 circuit_breaker: Arc<CircuitBreaker>,
57 succeeded: AtomicBool,
58 session_key: Option<String>,
60 sessions: Option<SessionMap>,
61}
62
63impl ProxyHandle {
64 const fn new(proxy_url: String, circuit_breaker: Arc<CircuitBreaker>) -> Self {
65 Self {
66 proxy_url,
67 circuit_breaker,
68 succeeded: AtomicBool::new(false),
69 session_key: None,
70 sessions: None,
71 }
72 }
73
74 const fn new_sticky(
75 proxy_url: String,
76 circuit_breaker: Arc<CircuitBreaker>,
77 session_key: String,
78 sessions: SessionMap,
79 ) -> Self {
80 Self {
81 proxy_url,
82 circuit_breaker,
83 succeeded: AtomicBool::new(false),
84 session_key: Some(session_key),
85 sessions: Some(sessions),
86 }
87 }
88
89 pub fn direct() -> Self {
94 let noop_cb = Arc::new(CircuitBreaker::new(u32::MAX, u64::MAX));
95 Self {
96 proxy_url: String::new(),
97 circuit_breaker: noop_cb,
98 succeeded: AtomicBool::new(true),
99 session_key: None,
100 sessions: None,
101 }
102 }
103
104 pub fn mark_success(&self) {
106 self.succeeded.store(true, Ordering::Release);
107 }
108}
109
110impl std::fmt::Debug for ProxyHandle {
111 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112 f.debug_struct("ProxyHandle")
113 .field("proxy_url", &self.proxy_url)
114 .finish_non_exhaustive()
115 }
116}
117
118impl Drop for ProxyHandle {
119 fn drop(&mut self) {
120 if self.succeeded.load(Ordering::Acquire) {
121 self.circuit_breaker.record_success();
122 } else {
123 self.circuit_breaker.record_failure();
124 if let (Some(key), Some(sessions)) = (&self.session_key, &self.sessions) {
126 sessions.unbind(key);
127 }
128 }
129 }
130}
131
132pub struct ProxyManager {
169 storage: Arc<dyn ProxyStoragePort>,
170 strategy: BoxedRotationStrategy,
171 health_checker: HealthChecker,
172 circuit_breakers: Arc<RwLock<HashMap<Uuid, Arc<CircuitBreaker>>>>,
173 config: ProxyConfig,
174 sessions: SessionMap,
176}
177
178impl ProxyManager {
179 pub fn builder() -> ProxyManagerBuilder {
181 ProxyManagerBuilder::default()
182 }
183
184 pub fn with_round_robin(
186 storage: Arc<dyn ProxyStoragePort>,
187 config: ProxyConfig,
188 ) -> ProxyResult<Self> {
189 Self::builder()
190 .storage(storage)
191 .strategy(Arc::new(RoundRobinStrategy::default()))
192 .config(config)
193 .build()
194 }
195
196 pub fn with_random(
198 storage: Arc<dyn ProxyStoragePort>,
199 config: ProxyConfig,
200 ) -> ProxyResult<Self> {
201 Self::builder()
202 .storage(storage)
203 .strategy(Arc::new(RandomStrategy))
204 .config(config)
205 .build()
206 }
207
208 pub fn with_weighted(
210 storage: Arc<dyn ProxyStoragePort>,
211 config: ProxyConfig,
212 ) -> ProxyResult<Self> {
213 Self::builder()
214 .storage(storage)
215 .strategy(Arc::new(WeightedStrategy))
216 .config(config)
217 .build()
218 }
219
220 pub fn with_least_used(
222 storage: Arc<dyn ProxyStoragePort>,
223 config: ProxyConfig,
224 ) -> ProxyResult<Self> {
225 Self::builder()
226 .storage(storage)
227 .strategy(Arc::new(LeastUsedStrategy))
228 .config(config)
229 .build()
230 }
231
232 #[allow(clippy::significant_drop_tightening)]
243 pub async fn add_proxy(&self, proxy: Proxy) -> ProxyResult<Uuid> {
244 let mut cb_map = self.circuit_breakers.write().await;
245 let record = self.storage.add(proxy).await?;
246 cb_map.insert(
247 record.id,
248 Arc::new(CircuitBreaker::new(
249 self.config.circuit_open_threshold,
250 u64::try_from(self.config.circuit_half_open_after.as_millis()).unwrap_or(u64::MAX),
251 )),
252 );
253 Ok(record.id)
254 }
255
256 pub async fn remove_proxy(&self, id: Uuid) -> ProxyResult<()> {
258 self.storage.remove(id).await?;
259 self.circuit_breakers.write().await.remove(&id);
260 Ok(())
261 }
262
263 pub fn start(&self) -> (CancellationToken, JoinHandle<()>) {
270 let token = CancellationToken::new();
271 let health_handle = self.health_checker.clone().spawn(token.clone());
272
273 let sessions = self.sessions.clone();
274 let purge_token = token.clone();
275 let purge_handle = tokio::spawn(async move {
276 let mut interval = tokio::time::interval(std::time::Duration::from_secs(60));
277 loop {
278 tokio::select! {
279 _ = interval.tick() => { sessions.purge_expired(); }
280 () = purge_token.cancelled() => break,
281 }
282 }
283 });
284
285 let combined = tokio::spawn(async move {
286 let _ = tokio::join!(health_handle, purge_handle);
287 });
288
289 (token, combined)
290 }
291
292 #[allow(clippy::significant_drop_tightening)]
298 async fn select_proxy_inner(&self) -> ProxyResult<(String, Arc<CircuitBreaker>, Uuid)> {
299 let with_metrics = self.storage.list_with_metrics().await?;
300 if with_metrics.is_empty() {
301 return Err(ProxyError::PoolExhausted);
302 }
303
304 let candidates = {
307 let health_map_ref = Arc::clone(self.health_checker.health_map());
308 let health_map = health_map_ref.read().await;
309 let cb_map_ref = Arc::clone(&self.circuit_breakers);
310 let cb_map = cb_map_ref.read().await;
311 let candidates: Vec<ProxyCandidate> = with_metrics
312 .iter()
313 .map(|(record, metrics)| {
314 let healthy = health_map.get(&record.id).copied().unwrap_or(true);
315 let available = cb_map.get(&record.id).is_none_or(|cb| cb.is_available());
316 ProxyCandidate {
317 id: record.id,
318 weight: record.proxy.weight,
319 metrics: Arc::clone(metrics),
320 healthy: healthy && available,
321 }
322 })
323 .collect();
324 candidates
325 };
327
328 let selected = self.strategy.select(&candidates).await?;
329 let id = selected.id;
330
331 let cb = self
333 .circuit_breakers
334 .read()
335 .await
336 .get(&id)
337 .cloned()
338 .ok_or(ProxyError::PoolExhausted)?;
339 let url = with_metrics
340 .iter()
341 .find(|(r, _)| r.id == id)
342 .map(|(r, _)| r.proxy.url.clone())
343 .unwrap_or_default();
344
345 Ok((url, cb, id))
346 }
347
348 pub async fn acquire_proxy(&self) -> ProxyResult<ProxyHandle> {
354 let (url, cb, _id) = self.select_proxy_inner().await?;
355 Ok(ProxyHandle::new(url, cb))
356 }
357
358 pub async fn acquire_for_domain(&self, domain: &str) -> ProxyResult<ProxyHandle> {
372 let ttl = match &self.config.sticky_policy {
373 StickyPolicy::Disabled => return self.acquire_proxy().await,
374 StickyPolicy::Domain { ttl } => *ttl,
375 };
376
377 if let Some(proxy_id) = self.sessions.lookup(domain) {
379 let cb_map = self.circuit_breakers.read().await;
380 if let Some(cb) = cb_map.get(&proxy_id).cloned()
381 && cb.is_available()
382 {
383 let with_metrics = self.storage.list_with_metrics().await?;
385 if let Some((record, _)) = with_metrics.iter().find(|(r, _)| r.id == proxy_id) {
386 let url = record.proxy.url.clone();
387 drop(cb_map);
388 return Ok(ProxyHandle::new_sticky(
389 url,
390 cb,
391 domain.to_string(),
392 self.sessions.clone(),
393 ));
394 }
395 }
396 drop(cb_map);
398 self.sessions.unbind(domain);
399 }
400
401 let (url, cb, proxy_id) = self.select_proxy_inner().await?;
403 self.sessions.bind(domain, proxy_id, ttl);
404 Ok(ProxyHandle::new_sticky(
405 url,
406 cb,
407 domain.to_string(),
408 self.sessions.clone(),
409 ))
410 }
411
412 pub async fn pool_stats(&self) -> ProxyResult<PoolStats> {
416 let records = self.storage.list().await?;
417 let total = records.len();
418 let health_map = self.health_checker.health_map().read().await;
419 let cb_map = self.circuit_breakers.read().await;
420
421 let mut healthy = 0usize;
422 let mut open = 0usize;
423 for r in &records {
424 if health_map.get(&r.id).copied().unwrap_or(true) {
425 healthy += 1;
426 }
427 if cb_map.get(&r.id).is_some_and(|cb| !cb.is_available()) {
428 open += 1;
429 }
430 }
431 drop(health_map);
432 drop(cb_map);
433 Ok(PoolStats {
434 total,
435 healthy,
436 open,
437 active_sessions: self.sessions.active_count(),
438 })
439 }
440}
441
442#[derive(Default)]
448pub struct ProxyManagerBuilder {
449 storage: Option<Arc<dyn ProxyStoragePort>>,
450 strategy: Option<BoxedRotationStrategy>,
451 config: Option<ProxyConfig>,
452}
453
454impl ProxyManagerBuilder {
455 #[must_use]
456 pub fn storage(mut self, s: Arc<dyn ProxyStoragePort>) -> Self {
457 self.storage = Some(s);
458 self
459 }
460
461 #[must_use]
462 pub fn strategy(mut self, s: BoxedRotationStrategy) -> Self {
463 self.strategy = Some(s);
464 self
465 }
466
467 #[must_use]
468 pub fn config(mut self, c: ProxyConfig) -> Self {
469 self.config = Some(c);
470 self
471 }
472
473 pub fn build(self) -> ProxyResult<ProxyManager> {
479 let storage = self.storage.ok_or_else(|| {
480 ProxyError::ConfigError("ProxyManagerBuilder: storage is required".into())
481 })?;
482 let strategy = self
483 .strategy
484 .unwrap_or_else(|| Arc::new(RoundRobinStrategy::default()));
485 let config = self.config.unwrap_or_default();
486 let health_map: HealthMap = Arc::new(RwLock::new(HashMap::new()));
487 let checker = HealthChecker::new(
488 config.clone(),
489 Arc::clone(&storage),
490 Arc::clone(&health_map),
491 );
492
493 #[cfg(feature = "tls-profiled")]
494 let health_checker = if let Some(mode) = config.profiled_request_mode {
495 checker.with_profiled_mode(mode)?
496 } else {
497 checker
498 };
499
500 #[cfg(not(feature = "tls-profiled"))]
501 let health_checker = checker;
502
503 Ok(ProxyManager {
504 storage,
505 strategy,
506 health_checker,
507 circuit_breakers: Arc::new(RwLock::new(HashMap::new())),
508 config,
509 sessions: SessionMap::new(),
510 })
511 }
512}
513
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::significant_drop_tightening,
    clippy::manual_let_else,
    clippy::panic
)]
mod tests {
    use std::collections::HashSet;
    use std::time::Duration;

    use super::*;
    use crate::circuit_breaker::{STATE_CLOSED, STATE_OPEN};
    use crate::storage::MemoryProxyStore;
    use crate::types::ProxyType;

    /// Plain HTTP proxy with weight 1 and no credentials/tags.
    fn make_proxy(url: &str) -> Proxy {
        Proxy {
            url: url.into(),
            proxy_type: ProxyType::Http,
            username: None,
            password: None,
            weight: 1,
            tags: vec![],
        }
    }

    /// Fresh in-memory store.
    fn storage() -> Arc<MemoryProxyStore> {
        Arc::new(MemoryProxyStore::default())
    }

    /// Round-robin visits every proxy in the pool.
    #[tokio::test]
    async fn round_robin_distribution() {
        let store = storage();
        let mgr = ProxyManager::with_round_robin(store.clone(), ProxyConfig::default()).unwrap();
        mgr.add_proxy(make_proxy("http://a.test:8080"))
            .await
            .unwrap();
        mgr.add_proxy(make_proxy("http://b.test:8080"))
            .await
            .unwrap();
        mgr.add_proxy(make_proxy("http://c.test:8080"))
            .await
            .unwrap();

        let mut seen = HashSet::new();
        for _ in 0..10 {
            let h = mgr.acquire_proxy().await.unwrap();
            h.mark_success();
            seen.insert(h.proxy_url.clone());
        }
        assert_eq!(seen.len(), 3, "all three proxies should have been selected");
    }

    /// With every breaker open, acquisition fails with `AllProxiesUnhealthy`.
    #[tokio::test]
    async fn all_open_returns_error() {
        let store = storage();
        let mgr = ProxyManager::with_round_robin(
            store.clone(),
            ProxyConfig {
                circuit_open_threshold: 1,
                ..ProxyConfig::default()
            },
        )
        .unwrap();
        let id = mgr
            .add_proxy(make_proxy("http://x.test:8080"))
            .await
            .unwrap();

        {
            // One failure trips the breaker at threshold 1.
            let map = mgr.circuit_breakers.read().await;
            let cb = map.get(&id).unwrap();
            cb.record_failure();
        }

        let err = mgr.acquire_proxy().await.unwrap_err();
        assert!(
            matches!(err, ProxyError::AllProxiesUnhealthy),
            "expected AllProxiesUnhealthy, got {err:?}"
        );
    }

    /// Dropping a handle without `mark_success` records a failure.
    #[tokio::test]
    async fn handle_drop_records_failure() {
        let store = storage();
        let mgr = ProxyManager::with_round_robin(
            store.clone(),
            ProxyConfig {
                circuit_open_threshold: 1,
                ..ProxyConfig::default()
            },
        )
        .unwrap();
        let id = mgr
            .add_proxy(make_proxy("http://y.test:8080"))
            .await
            .unwrap();

        {
            let _h = mgr.acquire_proxy().await.unwrap();
        }

        let cb_map = mgr.circuit_breakers.read().await;
        let cb = cb_map.get(&id).unwrap();
        assert_eq!(cb.state(), STATE_OPEN);
    }

    /// A successful handle leaves the breaker closed.
    #[tokio::test]
    async fn handle_success_keeps_closed() {
        let store = storage();
        let mgr = ProxyManager::with_round_robin(store.clone(), ProxyConfig::default()).unwrap();
        let id = mgr
            .add_proxy(make_proxy("http://z.test:8080"))
            .await
            .unwrap();

        let h = mgr.acquire_proxy().await.unwrap();
        h.mark_success();
        drop(h);

        let cb_map = mgr.circuit_breakers.read().await;
        let cb = cb_map.get(&id).unwrap();
        assert_eq!(cb.state(), STATE_CLOSED);
    }

    /// Cancelling the token stops the background tasks promptly.
    #[tokio::test]
    async fn start_and_graceful_shutdown() {
        let store = storage();
        let mgr = ProxyManager::with_round_robin(
            store,
            ProxyConfig {
                health_check_interval: Duration::from_secs(3600),
                ..ProxyConfig::default()
            },
        )
        .unwrap();
        let (token, handle) = mgr.start();
        token.cancel();
        let result = tokio::time::timeout(Duration::from_secs(1), handle).await;
        assert!(result.is_ok(), "health checker task should exit within 1s");
    }

    #[cfg(feature = "tls-profiled")]
    #[tokio::test]
    async fn builder_accepts_profiled_request_mode_preset() {
        let store = storage();
        let cfg = ProxyConfig {
            profiled_request_mode: Some(crate::types::ProfiledRequestMode::Preset),
            ..ProxyConfig::default()
        };

        let result = ProxyManager::builder()
            .storage(store)
            .strategy(Arc::new(RoundRobinStrategy::default()))
            .config(cfg)
            .build();

        assert!(
            result.is_ok(),
            "builder should accept profiled preset mode: {:?}",
            result.err()
        );
    }

    #[cfg(feature = "tls-profiled")]
    #[tokio::test]
    async fn builder_rejects_profiled_request_mode_strict_all_for_chrome() {
        let store = storage();
        let cfg = ProxyConfig {
            profiled_request_mode: Some(crate::types::ProfiledRequestMode::StrictAll),
            ..ProxyConfig::default()
        };

        let result = ProxyManager::builder()
            .storage(store)
            .strategy(Arc::new(RoundRobinStrategy::default()))
            .config(cfg)
            .build();

        let Err(err) = result else {
            panic!("strict_all should fail for default Chrome baseline profile")
        };

        assert!(
            matches!(err, ProxyError::ConfigError(_)),
            "expected ConfigError, got {err:?}"
        );
    }

    /// Default config with domain stickiness enabled at its default TTL.
    fn sticky_config() -> ProxyConfig {
        use crate::session::StickyPolicy;
        ProxyConfig {
            sticky_policy: StickyPolicy::domain_default(),
            ..ProxyConfig::default()
        }
    }

    /// Repeated acquisitions for one domain reuse the bound proxy.
    #[tokio::test]
    async fn sticky_same_domain_returns_same_proxy() {
        let store = storage();
        let mgr = ProxyManager::with_round_robin(store, sticky_config()).unwrap();
        mgr.add_proxy(make_proxy("http://p1.test:8080"))
            .await
            .unwrap();
        mgr.add_proxy(make_proxy("http://p2.test:8080"))
            .await
            .unwrap();

        let h1 = mgr.acquire_for_domain("example.com").await.unwrap();
        let url1 = h1.proxy_url.clone();
        h1.mark_success();

        let h2 = mgr.acquire_for_domain("example.com").await.unwrap();
        let url2 = h2.proxy_url.clone();
        h2.mark_success();

        assert_eq!(url1, url2, "same domain should return the same proxy");
    }

    /// With round-robin and two proxies, two fresh domains land on different proxies.
    #[tokio::test]
    async fn sticky_different_domains_may_differ() {
        let store = storage();
        let mgr = ProxyManager::with_round_robin(store, sticky_config()).unwrap();
        mgr.add_proxy(make_proxy("http://pa.test:8080"))
            .await
            .unwrap();
        mgr.add_proxy(make_proxy("http://pb.test:8080"))
            .await
            .unwrap();

        let ha = mgr.acquire_for_domain("a.com").await.unwrap();
        let url_a = ha.proxy_url.clone();
        ha.mark_success();

        let hb = mgr.acquire_for_domain("b.com").await.unwrap();
        let url_b = hb.proxy_url.clone();
        hb.mark_success();

        assert_ne!(
            url_a, url_b,
            "different domains should differ in this scenario"
        );
    }

    /// An expired binding does not prevent a fresh acquisition.
    #[tokio::test]
    async fn sticky_expired_session_re_acquires() {
        use crate::session::StickyPolicy;
        let store = storage();
        let mgr = ProxyManager::with_round_robin(
            store,
            ProxyConfig {
                sticky_policy: StickyPolicy::domain(Duration::from_millis(1)),
                ..ProxyConfig::default()
            },
        )
        .unwrap();
        mgr.add_proxy(make_proxy("http://x.test:8080"))
            .await
            .unwrap();

        let h1 = mgr.acquire_for_domain("expired.com").await.unwrap();
        h1.mark_success();

        tokio::time::sleep(Duration::from_millis(5)).await;

        let h2 = mgr.acquire_for_domain("expired.com").await.unwrap();
        h2.mark_success();
    }

    /// A failed sticky lease trips the breaker and unbinds the domain. The
    /// next acquisition for that domain then moves to another proxy.
    #[tokio::test]
    async fn sticky_cb_trip_invalidates_session() {
        let store = storage();
        let mgr = ProxyManager::with_round_robin(
            store,
            ProxyConfig {
                circuit_open_threshold: 1,
                sticky_policy: sticky_config().sticky_policy,
                ..ProxyConfig::default()
            },
        )
        .unwrap();
        mgr.add_proxy(make_proxy("http://q1.test:8080"))
            .await
            .unwrap();
        mgr.add_proxy(make_proxy("http://q2.test:8080"))
            .await
            .unwrap();

        let h1 = mgr.acquire_for_domain("cb.com").await.unwrap();
        let url1 = h1.proxy_url.clone();
        // Dropping without `mark_success` records a failure (threshold 1 opens
        // the breaker) and unbinds the domain.
        drop(h1);

        assert!(
            mgr.sessions.lookup("cb.com").is_none(),
            "a failed sticky lease should unbind its domain"
        );

        let h2 = mgr.acquire_for_domain("cb.com").await.unwrap();
        assert_ne!(
            h2.proxy_url, url1,
            "the tripped proxy must not be reused for the domain"
        );
        h2.mark_success();
    }

    /// `purge_expired` removes bindings whose TTL has elapsed.
    #[tokio::test]
    async fn sticky_purge_expired() {
        use crate::session::StickyPolicy;
        let store = storage();
        let mgr = ProxyManager::with_round_robin(
            store,
            ProxyConfig {
                sticky_policy: StickyPolicy::domain(Duration::from_millis(1)),
                ..ProxyConfig::default()
            },
        )
        .unwrap();
        mgr.add_proxy(make_proxy("http://r.test:8080"))
            .await
            .unwrap();

        let h = mgr.acquire_for_domain("purge.com").await.unwrap();
        h.mark_success();

        assert_eq!(mgr.sessions.active_count(), 1);

        tokio::time::sleep(Duration::from_millis(5)).await;
        mgr.sessions.purge_expired();

        assert_eq!(mgr.sessions.active_count(), 0);
    }

    /// `pool_stats` reports the number of active sticky sessions.
    #[tokio::test]
    async fn pool_stats_includes_sessions() {
        let store = storage();
        let mgr = ProxyManager::with_round_robin(store, sticky_config()).unwrap();
        mgr.add_proxy(make_proxy("http://s.test:8080"))
            .await
            .unwrap();

        let stats = mgr.pool_stats().await.unwrap();
        assert_eq!(stats.active_sessions, 0);

        let h = mgr.acquire_for_domain("stats.com").await.unwrap();
        h.mark_success();

        let stats = mgr.pool_stats().await.unwrap();
        assert_eq!(stats.active_sessions, 1);
    }
}