1use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11use tokio::sync::RwLock;
12use tracing;
13
14use crate::error::Result;
15use crate::fingerprint::FingerprintProfile;
16use crate::transport::connector::MaybeHttpsStream;
17use crate::transport::h2::PseudoHeaderOrder;
18use crate::version::HttpVersion;
19
20#[derive(Debug, Clone, Hash, Eq, PartialEq)]
22pub struct PoolKey {
23 pub host: String,
24 pub port: u16,
25 pub is_https: bool,
26 pub fingerprint: FingerprintProfile,
27 pub pseudo_order: PseudoHeaderOrder,
28}
29
30impl PoolKey {
31 pub fn new(
33 host: String,
34 port: u16,
35 is_https: bool,
36 fingerprint: FingerprintProfile,
37 pseudo_order: PseudoHeaderOrder,
38 ) -> Self {
39 Self {
40 host,
41 port,
42 is_https,
43 fingerprint,
44 pseudo_order,
45 }
46 }
47}
48
49#[derive(Debug)]
51pub struct H1PoolEntry {
52 pub stream: MaybeHttpsStream,
53 pub last_used: Instant,
54}
55
56impl H1PoolEntry {
57 pub fn new(stream: MaybeHttpsStream) -> Self {
58 Self {
59 stream,
60 last_used: Instant::now(),
61 }
62 }
63
64 pub fn is_expired(&self, max_idle: Duration) -> bool {
65 self.last_used.elapsed() >= max_idle
66 }
67}
68
69#[derive(Debug, Clone)]
71pub struct PoolEntry {
72 pub version: HttpVersion,
73 pub established_at: Instant,
74 pub last_used: Instant,
75 pub active_streams: u32,
77 pub max_streams: u32,
79 pub is_valid: bool,
81}
82
83impl PoolEntry {
84 pub fn new(version: HttpVersion, max_streams: u32) -> Self {
86 let now = Instant::now();
87 Self {
88 version,
89 established_at: now,
90 last_used: now,
91 active_streams: 0,
92 max_streams,
93 is_valid: true,
94 }
95 }
96
97 pub fn can_multiplex(&self) -> bool {
99 matches!(
100 self.version,
101 HttpVersion::Http2 | HttpVersion::Http3 | HttpVersion::Http3Only
102 ) && self.active_streams < self.max_streams
103 && self.is_valid
104 }
105
106 pub fn acquire_stream(&mut self) -> bool {
108 if self.can_multiplex() {
109 self.active_streams += 1;
110 self.last_used = Instant::now();
111 true
112 } else {
113 false
114 }
115 }
116
117 pub fn release_stream(&mut self) {
119 if self.active_streams > 0 {
120 self.active_streams -= 1;
121 self.last_used = Instant::now();
122 }
123 }
124
125 pub fn invalidate(&mut self) {
127 self.is_valid = false;
128 }
129
130 pub fn is_expired(&self, max_idle: Duration) -> bool {
132 let age = Instant::now().duration_since(self.last_used);
133 age >= max_idle
134 }
135}
136
137pub struct ConnectionPool {
139 entries: Arc<RwLock<HashMap<PoolKey, PoolEntry>>>,
140 h1_idle: Arc<RwLock<HashMap<PoolKey, Vec<H1PoolEntry>>>>,
141 max_idle_duration: Duration,
142 #[allow(dead_code)] max_connections_per_host: usize,
144 default_max_streams: u32,
145}
146
147impl ConnectionPool {
148 const DEFAULT_MAX_IDLE: Duration = Duration::from_secs(30);
150
151 const DEFAULT_MAX_PER_HOST: usize = 6;
153
154 const DEFAULT_MAX_STREAMS: u32 = 100;
156
157 pub fn new() -> Self {
159 Self {
160 entries: Arc::new(RwLock::new(HashMap::new())),
161 h1_idle: Arc::new(RwLock::new(HashMap::new())),
162 max_idle_duration: Self::DEFAULT_MAX_IDLE,
163 max_connections_per_host: Self::DEFAULT_MAX_PER_HOST,
164 default_max_streams: Self::DEFAULT_MAX_STREAMS,
165 }
166 }
167
168 pub fn with_config(max_idle: Duration, max_per_host: usize, max_streams: u32) -> Self {
170 Self {
171 entries: Arc::new(RwLock::new(HashMap::new())),
172 h1_idle: Arc::new(RwLock::new(HashMap::new())),
173 max_idle_duration: max_idle,
174 max_connections_per_host: max_per_host,
175 default_max_streams: max_streams,
176 }
177 }
178
179 pub async fn get_h1(&self, key: &PoolKey) -> Option<MaybeHttpsStream> {
181 let start = Instant::now();
182 let mut pool = self.h1_idle.write().await;
183 if let Some(entries) = pool.get_mut(key) {
184 tracing::debug!("H1 Pool: {} entries for key {:?}", entries.len(), key);
185 let initial_count = entries.len();
186 while let Some(entry) = entries.pop() {
187 if !entry.is_expired(self.max_idle_duration) {
188 tracing::debug!(
189 "H1 Pool: Reusing connection for {:?} (checked {} entries, took {:?})",
190 key,
191 initial_count - entries.len(),
192 start.elapsed()
193 );
194 return Some(entry.stream);
195 }
196 tracing::debug!(
197 "H1 Pool: Connection expired for {:?} (age: {:?})",
198 key,
199 entry.last_used.elapsed()
200 );
201 }
202 } else {
203 tracing::debug!("H1 Pool: No entries for key {:?}", key);
204 }
205 tracing::debug!(
206 "H1 Pool: No reusable connection found for {:?} (took {:?})",
207 key,
208 start.elapsed()
209 );
210 None
211 }
212
213 pub async fn put_h1(&self, key: PoolKey, stream: MaybeHttpsStream) {
215 if self.max_connections_per_host == 0 {
216 return;
217 }
218 let start = Instant::now();
219 tracing::debug!("H1 Pool: Returning connection for {:?}", key);
220 let mut pool = self.h1_idle.write().await;
221 let entries = pool.entry(key.clone()).or_default();
222 let count_before = entries.len();
223 while entries.len() >= self.max_connections_per_host {
224 entries.remove(0);
225 }
226 entries.push(H1PoolEntry::new(stream));
227 tracing::debug!(
228 "H1 Pool: Returned connection for {:?} (pool size: {} -> {}, took {:?})",
229 key,
230 count_before,
231 entries.len(),
232 start.elapsed()
233 );
234 }
235
236 pub async fn get_or_create(
242 &self,
243 key: &PoolKey,
244 version: HttpVersion,
245 ) -> Result<Option<PoolEntry>> {
246 let start = Instant::now();
247 let mut entries = self.entries.write().await;
248
249 if version == HttpVersion::Http1_1 {
251 return Ok(None);
252 }
253
254 if let Some(entry) = entries.get_mut(key) {
256 let active_before = entry.active_streams;
257 if entry.acquire_stream() {
258 tracing::debug!(
259 "H2/H3 Pool: Reusing connection for {:?} (active streams: {} -> {}, took {:?})",
260 key,
261 active_before,
262 entry.active_streams,
263 start.elapsed()
264 );
265 return Ok(Some(entry.clone()));
266 } else {
267 tracing::debug!(
268 "H2/H3 Pool: Connection exists for {:?} but cannot multiplex (active: {}/{}, valid: {}, took {:?})",
269 key,
270 active_before,
271 entry.max_streams,
272 entry.is_valid,
273 start.elapsed()
274 );
275 }
276 } else {
277 tracing::debug!("H2/H3 Pool: No existing connection for {:?}", key);
278 }
279
280 tracing::debug!(
282 "H2/H3 Pool: Creating new connection entry for {:?} (took {:?})",
283 key,
284 start.elapsed()
285 );
286 let entry = PoolEntry::new(version, self.default_max_streams);
287 entries.insert(key.clone(), entry.clone());
288
289 Ok(Some(entry))
290 }
291
292 pub async fn release(&self, key: &PoolKey) {
294 let mut entries = self.entries.write().await;
295 if let Some(entry) = entries.get_mut(key) {
296 let active_before = entry.active_streams;
297 entry.release_stream();
298 tracing::debug!(
299 "H2/H3 Pool: Released stream for {:?} (active streams: {} -> {})",
300 key,
301 active_before,
302 entry.active_streams
303 );
304 } else {
305 tracing::warn!(
306 "H2/H3 Pool: Attempted to release stream for non-existent connection {:?}",
307 key
308 );
309 }
310 }
311
312 pub async fn invalidate(&self, key: &PoolKey) {
314 let mut entries = self.entries.write().await;
315 if let Some(entry) = entries.get_mut(key) {
316 entry.invalidate();
317 }
318 }
319
320 pub async fn cleanup(&self) {
322 {
324 let mut entries = self.entries.write().await;
325 entries
326 .retain(|_key, entry| entry.is_valid && !entry.is_expired(self.max_idle_duration));
327 }
328
329 {
331 let mut h1_pool = self.h1_idle.write().await;
332 for entries in h1_pool.values_mut() {
333 entries.retain(|e| !e.is_expired(self.max_idle_duration));
334 }
335 h1_pool.retain(|_, entries| !entries.is_empty());
336 }
337 }
338
339 pub fn spawn_cleanup_task(self: Arc<Self>, interval: Duration) -> tokio::task::JoinHandle<()> {
343 tokio::spawn(async move {
344 let mut interval_timer = tokio::time::interval(interval);
345 loop {
346 interval_timer.tick().await;
347 self.cleanup().await;
348 }
349 })
350 }
351
352 pub async fn stats(&self) -> PoolStats {
354 let entries = self.entries.read().await;
355 let h1_pool = self.h1_idle.read().await;
356
357 let h1_idle_count = h1_pool.values().map(|v| v.len()).sum();
358
359 PoolStats {
360 total_connections: entries.len() + h1_idle_count,
361 active_streams: entries.values().map(|e| e.active_streams).sum(),
362 http2_connections: entries
363 .values()
364 .filter(|e| matches!(e.version, HttpVersion::Http2))
365 .count(),
366 http3_connections: entries
367 .values()
368 .filter(|e| matches!(e.version, HttpVersion::Http3 | HttpVersion::Http3Only))
369 .count(),
370 http1_idle_connections: h1_idle_count,
371 }
372 }
373}
374
375impl Default for ConnectionPool {
376 fn default() -> Self {
377 Self::new()
378 }
379}
380
/// Point-in-time counters describing pool occupancy.
#[derive(Clone, Debug)]
pub struct PoolStats {
    /// Multiplexed entries plus idle HTTP/1.1 streams.
    pub total_connections: usize,
    /// Sum of checked-out streams across multiplexed entries.
    pub active_streams: u32,
    /// Multiplexed entries speaking HTTP/2.
    pub http2_connections: usize,
    /// Multiplexed entries speaking HTTP/3.
    pub http3_connections: usize,
    /// Idle HTTP/1.1 streams parked in the pool.
    pub http1_idle_connections: usize,
}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `example.com` key with the Chrome fingerprint and pseudo-header order.
    fn chrome_key(port: u16, is_https: bool) -> PoolKey {
        PoolKey::new(
            "example.com".to_string(),
            port,
            is_https,
            FingerprintProfile::Chrome142,
            PseudoHeaderOrder::Chrome,
        )
    }

    /// The HTTPS key shared by the pool-level tests.
    fn https_key() -> PoolKey {
        chrome_key(443, true)
    }

    #[test]
    fn test_pool_key_equality() {
        let a = https_key();
        let b = https_key();
        let plain = chrome_key(80, false);

        assert_eq!(a, b);
        assert_ne!(a, plain);
    }

    #[test]
    fn test_pool_entry_multiplexing() {
        let mut entry = PoolEntry::new(HttpVersion::Http2, 100);
        assert!(entry.can_multiplex());

        assert!(entry.acquire_stream());
        assert_eq!(entry.active_streams, 1);

        entry.release_stream();
        assert_eq!(entry.active_streams, 0);
    }

    #[test]
    fn test_pool_entry_max_streams() {
        let mut entry = PoolEntry::new(HttpVersion::Http2, 2);

        assert!(entry.acquire_stream());
        assert!(entry.acquire_stream());
        // The limit is reached; a third acquire must be refused without side effects.
        assert!(!entry.acquire_stream());
        assert_eq!(entry.active_streams, 2);
    }

    #[test]
    fn test_pool_entry_invalidation() {
        let mut entry = PoolEntry::new(HttpVersion::Http2, 100);
        assert!(entry.can_multiplex());

        entry.invalidate();
        assert!(!entry.can_multiplex());
    }

    #[test]
    fn test_pool_entry_expiration() {
        let entry = PoolEntry::new(HttpVersion::Http2, 100);

        assert!(!entry.is_expired(Duration::from_secs(30)));
        // A zero timeout expires immediately because the comparison is `>=`.
        assert!(entry.is_expired(Duration::from_secs(0)));
    }

    #[tokio::test]
    async fn test_connection_pool_http11() {
        let pool = ConnectionPool::new();
        let key = https_key();

        let outcome = pool
            .get_or_create(&key, HttpVersion::Http1_1)
            .await
            .unwrap();
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn test_connection_pool_http2_multiplexing() {
        let pool = ConnectionPool::new();
        let key = https_key();

        let first = pool.get_or_create(&key, HttpVersion::Http2).await.unwrap();
        assert!(first.is_some());

        let second = pool.get_or_create(&key, HttpVersion::Http2).await.unwrap();
        assert!(second.is_some());

        let stats = pool.stats().await;
        assert_eq!(stats.total_connections, 1);
        assert_eq!(stats.http2_connections, 1);
    }

    #[tokio::test]
    async fn test_connection_pool_release() {
        let pool = ConnectionPool::new();
        let key = https_key();

        let _entry = pool.get_or_create(&key, HttpVersion::Http2).await.unwrap();
        pool.release(&key).await;

        let stats = pool.stats().await;
        assert_eq!(stats.total_connections, 1);
    }

    #[tokio::test]
    async fn test_connection_pool_invalidation() {
        let pool = ConnectionPool::new();
        let key = https_key();

        let _entry = pool.get_or_create(&key, HttpVersion::Http2).await.unwrap();
        pool.invalidate(&key).await;
        pool.cleanup().await;

        let stats = pool.stats().await;
        assert_eq!(stats.total_connections, 0);
    }
}