1use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11use tokio::sync::RwLock;
12use tracing;
13
14use crate::error::Result;
15use crate::transport::connector::MaybeHttpsStream;
16use crate::version::HttpVersion;
17
/// Identifies a pooled connection target: host, port, and whether the
/// connection is made over HTTPS.
///
/// `ConnectionPool` uses this as the key of both its HTTP/1.1 idle map and
/// its HTTP/2 / HTTP/3 entry map.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct PoolKey {
    /// Host name as supplied by the caller (no normalization is applied).
    pub host: String,
    /// Destination port.
    pub port: u16,
    /// `true` when the connection uses HTTPS.
    pub is_https: bool,
}
25
26impl PoolKey {
27 pub fn new(host: String, port: u16, is_https: bool) -> Self {
29 Self {
30 host,
31 port,
32 is_https,
33 }
34 }
35}
36
37#[derive(Debug)]
39pub struct H1PoolEntry {
40 pub stream: MaybeHttpsStream,
41 pub last_used: Instant,
42}
43
44impl H1PoolEntry {
45 pub fn new(stream: MaybeHttpsStream) -> Self {
46 Self {
47 stream,
48 last_used: Instant::now(),
49 }
50 }
51
52 pub fn is_expired(&self, max_idle: Duration) -> bool {
53 self.last_used.elapsed() >= max_idle
54 }
55}
56
57#[derive(Debug, Clone)]
59pub struct PoolEntry {
60 pub version: HttpVersion,
61 pub established_at: Instant,
62 pub last_used: Instant,
63 pub active_streams: u32,
65 pub max_streams: u32,
67 pub is_valid: bool,
69}
70
71impl PoolEntry {
72 pub fn new(version: HttpVersion, max_streams: u32) -> Self {
74 let now = Instant::now();
75 Self {
76 version,
77 established_at: now,
78 last_used: now,
79 active_streams: 0,
80 max_streams,
81 is_valid: true,
82 }
83 }
84
85 pub fn can_multiplex(&self) -> bool {
87 matches!(
88 self.version,
89 HttpVersion::Http2 | HttpVersion::Http3 | HttpVersion::Http3Only
90 ) && self.active_streams < self.max_streams
91 && self.is_valid
92 }
93
94 pub fn acquire_stream(&mut self) -> bool {
96 if self.can_multiplex() {
97 self.active_streams += 1;
98 self.last_used = Instant::now();
99 true
100 } else {
101 false
102 }
103 }
104
105 pub fn release_stream(&mut self) {
107 if self.active_streams > 0 {
108 self.active_streams -= 1;
109 self.last_used = Instant::now();
110 }
111 }
112
113 pub fn invalidate(&mut self) {
115 self.is_valid = false;
116 }
117
118 pub fn is_expired(&self, max_idle: Duration) -> bool {
120 let age = Instant::now().duration_since(self.last_used);
121 age >= max_idle
122 }
123}
124
125pub struct ConnectionPool {
127 entries: Arc<RwLock<HashMap<PoolKey, PoolEntry>>>,
128 h1_idle: Arc<RwLock<HashMap<PoolKey, Vec<H1PoolEntry>>>>,
129 max_idle_duration: Duration,
130 #[allow(dead_code)] max_connections_per_host: usize,
132 default_max_streams: u32,
133}
134
135impl ConnectionPool {
136 const DEFAULT_MAX_IDLE: Duration = Duration::from_secs(30);
138
139 const DEFAULT_MAX_PER_HOST: usize = 6;
141
142 const DEFAULT_MAX_STREAMS: u32 = 100;
144
145 pub fn new() -> Self {
147 Self {
148 entries: Arc::new(RwLock::new(HashMap::new())),
149 h1_idle: Arc::new(RwLock::new(HashMap::new())),
150 max_idle_duration: Self::DEFAULT_MAX_IDLE,
151 max_connections_per_host: Self::DEFAULT_MAX_PER_HOST,
152 default_max_streams: Self::DEFAULT_MAX_STREAMS,
153 }
154 }
155
156 pub fn with_config(max_idle: Duration, max_per_host: usize, max_streams: u32) -> Self {
158 Self {
159 entries: Arc::new(RwLock::new(HashMap::new())),
160 h1_idle: Arc::new(RwLock::new(HashMap::new())),
161 max_idle_duration: max_idle,
162 max_connections_per_host: max_per_host,
163 default_max_streams: max_streams,
164 }
165 }
166
167 pub async fn get_h1(&self, key: &PoolKey) -> Option<MaybeHttpsStream> {
169 let start = Instant::now();
170 let mut pool = self.h1_idle.write().await;
171 if let Some(entries) = pool.get_mut(key) {
172 tracing::debug!("H1 Pool: {} entries for key {:?}", entries.len(), key);
173 let initial_count = entries.len();
174 while let Some(entry) = entries.pop() {
175 if !entry.is_expired(self.max_idle_duration) {
176 tracing::debug!(
177 "H1 Pool: Reusing connection for {:?} (checked {} entries, took {:?})",
178 key,
179 initial_count - entries.len(),
180 start.elapsed()
181 );
182 return Some(entry.stream);
183 }
184 tracing::debug!(
185 "H1 Pool: Connection expired for {:?} (age: {:?})",
186 key,
187 entry.last_used.elapsed()
188 );
189 }
190 } else {
191 tracing::debug!("H1 Pool: No entries for key {:?}", key);
192 }
193 tracing::debug!(
194 "H1 Pool: No reusable connection found for {:?} (took {:?})",
195 key,
196 start.elapsed()
197 );
198 None
199 }
200
201 pub async fn put_h1(&self, key: PoolKey, stream: MaybeHttpsStream) {
203 let start = Instant::now();
204 tracing::debug!("H1 Pool: Returning connection for {:?}", key);
205 let mut pool = self.h1_idle.write().await;
206 let entries = pool.entry(key.clone()).or_default();
207 let count_before = entries.len();
208 entries.push(H1PoolEntry::new(stream));
209 tracing::debug!(
210 "H1 Pool: Returned connection for {:?} (pool size: {} -> {}, took {:?})",
211 key,
212 count_before,
213 entries.len(),
214 start.elapsed()
215 );
216 }
217
218 pub async fn get_or_create(
224 &self,
225 key: &PoolKey,
226 version: HttpVersion,
227 ) -> Result<Option<PoolEntry>> {
228 let start = Instant::now();
229 let mut entries = self.entries.write().await;
230
231 if version == HttpVersion::Http1_1 {
233 return Ok(None);
234 }
235
236 if let Some(entry) = entries.get_mut(key) {
238 let active_before = entry.active_streams;
239 if entry.acquire_stream() {
240 tracing::debug!(
241 "H2/H3 Pool: Reusing connection for {:?} (active streams: {} -> {}, took {:?})",
242 key,
243 active_before,
244 entry.active_streams,
245 start.elapsed()
246 );
247 return Ok(Some(entry.clone()));
248 } else {
249 tracing::debug!(
250 "H2/H3 Pool: Connection exists for {:?} but cannot multiplex (active: {}/{}, valid: {}, took {:?})",
251 key,
252 active_before,
253 entry.max_streams,
254 entry.is_valid,
255 start.elapsed()
256 );
257 }
258 } else {
259 tracing::debug!("H2/H3 Pool: No existing connection for {:?}", key);
260 }
261
262 tracing::debug!(
264 "H2/H3 Pool: Creating new connection entry for {:?} (took {:?})",
265 key,
266 start.elapsed()
267 );
268 let entry = PoolEntry::new(version, self.default_max_streams);
269 entries.insert(key.clone(), entry.clone());
270
271 Ok(Some(entry))
272 }
273
274 pub async fn release(&self, key: &PoolKey) {
276 let mut entries = self.entries.write().await;
277 if let Some(entry) = entries.get_mut(key) {
278 let active_before = entry.active_streams;
279 entry.release_stream();
280 tracing::debug!(
281 "H2/H3 Pool: Released stream for {:?} (active streams: {} -> {})",
282 key,
283 active_before,
284 entry.active_streams
285 );
286 } else {
287 tracing::warn!(
288 "H2/H3 Pool: Attempted to release stream for non-existent connection {:?}",
289 key
290 );
291 }
292 }
293
294 pub async fn invalidate(&self, key: &PoolKey) {
296 let mut entries = self.entries.write().await;
297 if let Some(entry) = entries.get_mut(key) {
298 entry.invalidate();
299 }
300 }
301
302 pub async fn cleanup(&self) {
304 {
306 let mut entries = self.entries.write().await;
307 entries
308 .retain(|_key, entry| entry.is_valid && !entry.is_expired(self.max_idle_duration));
309 }
310
311 {
313 let mut h1_pool = self.h1_idle.write().await;
314 for entries in h1_pool.values_mut() {
315 entries.retain(|e| !e.is_expired(self.max_idle_duration));
316 }
317 h1_pool.retain(|_, entries| !entries.is_empty());
318 }
319 }
320
321 pub fn spawn_cleanup_task(self: Arc<Self>, interval: Duration) -> tokio::task::JoinHandle<()> {
325 tokio::spawn(async move {
326 let mut interval_timer = tokio::time::interval(interval);
327 loop {
328 interval_timer.tick().await;
329 self.cleanup().await;
330 }
331 })
332 }
333
334 pub async fn stats(&self) -> PoolStats {
336 let entries = self.entries.read().await;
337 let h1_pool = self.h1_idle.read().await;
338
339 let h1_idle_count = h1_pool.values().map(|v| v.len()).sum();
340
341 PoolStats {
342 total_connections: entries.len() + h1_idle_count,
343 active_streams: entries.values().map(|e| e.active_streams).sum(),
344 http2_connections: entries
345 .values()
346 .filter(|e| matches!(e.version, HttpVersion::Http2))
347 .count(),
348 http3_connections: entries
349 .values()
350 .filter(|e| matches!(e.version, HttpVersion::Http3 | HttpVersion::Http3Only))
351 .count(),
352 http1_idle_connections: h1_idle_count,
353 }
354 }
355}
356
357impl Default for ConnectionPool {
358 fn default() -> Self {
359 Self::new()
360 }
361}
362
/// Point-in-time snapshot of pool occupancy, as produced by `ConnectionPool::stats`.
///
/// `Default` yields the all-zero snapshot of an empty pool. `PartialEq`/`Eq` let
/// callers and tests compare whole snapshots.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PoolStats {
    /// Multiplexed entries plus idle HTTP/1.1 connections.
    pub total_connections: usize,
    /// Sum of active streams across all multiplexed entries.
    pub active_streams: u32,
    /// Multiplexed entries whose version is HTTP/2.
    pub http2_connections: usize,
    /// Multiplexed entries whose version is HTTP/3 or HTTP/3-only.
    pub http3_connections: usize,
    /// Idle HTTP/1.1 connections currently parked in the pool.
    pub http1_idle_connections: usize,
}
372
#[cfg(test)]
mod tests {
    use super::*;

    // Keys built from identical parts compare equal; a different port/scheme differs.
    #[test]
    fn test_pool_key_equality() {
        let key1 = PoolKey::new("example.com".to_string(), 443, true);
        let key2 = PoolKey::new("example.com".to_string(), 443, true);
        let key3 = PoolKey::new("example.com".to_string(), 80, false);

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    // Hashing agrees with equality, so keys work in the pool's HashMaps.
    #[test]
    fn test_pool_key_as_map_key() {
        let mut map = HashMap::new();
        map.insert(PoolKey::new("example.com".to_string(), 443, true), 1);

        assert_eq!(
            map.get(&PoolKey::new("example.com".to_string(), 443, true)),
            Some(&1)
        );
        assert_eq!(
            map.get(&PoolKey::new("example.com".to_string(), 443, false)),
            None
        );
    }

    #[test]
    fn test_pool_entry_multiplexing() {
        let mut entry = PoolEntry::new(HttpVersion::Http2, 100);

        assert!(entry.can_multiplex());
        assert!(entry.acquire_stream());
        assert_eq!(entry.active_streams, 1);

        entry.release_stream();
        assert_eq!(entry.active_streams, 0);
    }

    #[test]
    fn test_pool_entry_max_streams() {
        let mut entry = PoolEntry::new(HttpVersion::Http2, 2);

        assert!(entry.acquire_stream());
        assert!(entry.acquire_stream());
        assert!(!entry.acquire_stream());
        assert_eq!(entry.active_streams, 2);
    }

    // Releasing with nothing active must not wrap the counter around.
    #[test]
    fn test_pool_entry_release_never_underflows() {
        let mut entry = PoolEntry::new(HttpVersion::Http2, 100);

        entry.release_stream();
        assert_eq!(entry.active_streams, 0);
    }

    // HTTP/1.1 entries are never multiplexed onto.
    #[test]
    fn test_http11_entry_never_multiplexes() {
        let mut entry = PoolEntry::new(HttpVersion::Http1_1, 100);

        assert!(!entry.can_multiplex());
        assert!(!entry.acquire_stream());
        assert_eq!(entry.active_streams, 0);
    }

    // HTTP/3 entries multiplex and respect their stream limit.
    #[test]
    fn test_http3_entry_respects_limit() {
        let mut entry = PoolEntry::new(HttpVersion::Http3, 1);

        assert!(entry.acquire_stream());
        assert!(!entry.acquire_stream());
        assert_eq!(entry.active_streams, 1);
    }

    #[test]
    fn test_pool_entry_invalidation() {
        let mut entry = PoolEntry::new(HttpVersion::Http2, 100);

        assert!(entry.can_multiplex());
        entry.invalidate();
        assert!(!entry.can_multiplex());
    }

    #[test]
    fn test_pool_entry_expiration() {
        let entry = PoolEntry::new(HttpVersion::Http2, 100);

        assert!(!entry.is_expired(Duration::from_secs(30)));
        // The comparison is inclusive, so a zero idle limit expires immediately.
        assert!(entry.is_expired(Duration::from_secs(0)));
    }

    // HTTP/1.1 is not tracked in the multiplexed map.
    #[tokio::test]
    async fn test_connection_pool_http11() {
        let pool = ConnectionPool::new();
        let key = PoolKey::new("example.com".to_string(), 443, true);

        let result = pool
            .get_or_create(&key, HttpVersion::Http1_1)
            .await
            .unwrap();
        assert!(result.is_none());
        assert_eq!(pool.stats().await.total_connections, 0);
    }

    // A second HTTP/2 request for the same key reuses the single entry.
    #[tokio::test]
    async fn test_connection_pool_http2_multiplexing() {
        let pool = ConnectionPool::new();
        let key = PoolKey::new("example.com".to_string(), 443, true);

        let entry1 = pool.get_or_create(&key, HttpVersion::Http2).await.unwrap();
        assert!(entry1.is_some());

        let entry2 = pool.get_or_create(&key, HttpVersion::Http2).await.unwrap();
        assert!(entry2.is_some());

        let stats = pool.stats().await;
        assert_eq!(stats.total_connections, 1);
        assert_eq!(stats.http2_connections, 1);
    }

    #[tokio::test]
    async fn test_connection_pool_release() {
        let pool = ConnectionPool::new();
        let key = PoolKey::new("example.com".to_string(), 443, true);

        let _entry = pool.get_or_create(&key, HttpVersion::Http2).await.unwrap();
        pool.release(&key).await;

        let stats = pool.stats().await;
        assert_eq!(stats.total_connections, 1);
    }

    // Releasing or invalidating an unknown key neither panics nor creates an entry.
    #[tokio::test]
    async fn test_release_unknown_key_is_noop() {
        let pool = ConnectionPool::new();
        let key = PoolKey::new("example.com".to_string(), 443, true);

        pool.release(&key).await;
        pool.invalidate(&key).await;

        let stats = pool.stats().await;
        assert_eq!(stats.total_connections, 0);
        assert_eq!(stats.active_streams, 0);
    }

    // `with_config` propagates its stream limit to new entries.
    #[tokio::test]
    async fn test_with_config_sets_stream_limit() {
        let pool = ConnectionPool::with_config(Duration::from_secs(30), 6, 7);
        let key = PoolKey::new("example.com".to_string(), 443, true);

        let entry = pool
            .get_or_create(&key, HttpVersion::Http2)
            .await
            .unwrap()
            .expect("HTTP/2 requests are tracked by the pool");
        assert_eq!(entry.max_streams, 7);
    }

    #[tokio::test]
    async fn test_connection_pool_invalidation() {
        let pool = ConnectionPool::new();
        let key = PoolKey::new("example.com".to_string(), 443, true);

        let _entry = pool.get_or_create(&key, HttpVersion::Http2).await.unwrap();
        pool.invalidate(&key).await;
        pool.cleanup().await;

        let stats = pool.stats().await;
        assert_eq!(stats.total_connections, 0);
    }

    // With a zero idle limit, a fully released entry is expired and cleaned up.
    #[tokio::test]
    async fn test_cleanup_drops_expired_idle_entries() {
        let pool = ConnectionPool::with_config(Duration::from_secs(0), 6, 100);
        let key = PoolKey::new("example.com".to_string(), 443, true);

        let _entry = pool.get_or_create(&key, HttpVersion::Http2).await.unwrap();
        pool.release(&key).await;
        pool.cleanup().await;

        assert_eq!(pool.stats().await.total_connections, 0);
    }
}