Skip to main content

trojan_server/
pool.rs

1//! Connection pool for fallback backend.
2//!
3//! Warm pool strategy: pre-connect N fresh connections and hand them out once.
4//! Connections are not returned to the pool after use.
5
6use std::{
7    collections::VecDeque,
8    net::SocketAddr,
9    sync::Arc,
10    time::{Duration, Instant},
11};
12
13use parking_lot::Mutex;
14use tokio::net::TcpStream;
15use tracing::debug;
16use trojan_metrics::{record_fallback_pool_warm_fail, set_fallback_pool_size};
17
18/// A pooled connection with metadata.
19struct PooledConnection {
20    stream: TcpStream,
21    created_at: Instant,
22}
23
24/// Connection pool for a single backend address.
25pub struct ConnectionPool {
26    addr: SocketAddr,
27    connections: Arc<Mutex<VecDeque<PooledConnection>>>,
28    max_idle: usize,
29    max_age: Duration,
30    fill_batch: usize,
31    fill_delay: Duration,
32}
33
34impl ConnectionPool {
35    /// Create a new connection pool.
36    pub fn new(
37        addr: SocketAddr,
38        max_idle: usize,
39        max_age_secs: u64,
40        fill_batch: usize,
41        fill_delay_ms: u64,
42    ) -> Self {
43        let pool = Self {
44            addr,
45            connections: Arc::new(Mutex::new(VecDeque::new())),
46            max_idle,
47            max_age: Duration::from_secs(max_age_secs),
48            fill_batch,
49            fill_delay: Duration::from_millis(fill_delay_ms),
50        };
51        set_fallback_pool_size(0);
52        pool
53    }
54
55    /// Get a fresh connection from the pool or create a new one.
56    pub async fn get(&self) -> std::io::Result<TcpStream> {
57        // Pop one fresh connection if available
58        let pooled = {
59            let mut pool = self.connections.lock();
60            let pooled = pool.pop_front();
61            set_fallback_pool_size(pool.len());
62            pooled
63        };
64        if let Some(pooled) = pooled {
65            if pooled.created_at.elapsed() < self.max_age {
66                debug!(addr = %self.addr, "using pooled connection");
67                return Ok(pooled.stream);
68            }
69            debug!(addr = %self.addr, "discarding expired pooled connection");
70        }
71
72        // No valid pooled connection, create new one
73        debug!(addr = %self.addr, "creating new connection");
74        TcpStream::connect(self.addr).await
75    }
76
77    /// Warm pool maintains fresh connections; used connections are not returned.
78    /// Clean up expired connections.
79    pub fn cleanup(&self) {
80        let mut pool = self.connections.lock();
81        let before = pool.len();
82        pool.retain(|conn| conn.created_at.elapsed() < self.max_age);
83        let removed = before - pool.len();
84        set_fallback_pool_size(pool.len());
85        if removed > 0 {
86            debug!(addr = %self.addr, removed, remaining = pool.len(), "cleaned up expired connections");
87        }
88    }
89
90    /// Start a background warm-fill task.
91    pub fn start_cleanup_task(self: &Arc<Self>, interval: Duration) {
92        let pool = self.clone();
93        tokio::spawn(async move {
94            loop {
95                tokio::time::sleep(interval).await;
96                pool.cleanup();
97                pool.warm_fill().await;
98            }
99        });
100    }
101
102    /// Get current pool size.
103    pub fn size(&self) -> usize {
104        self.connections.lock().len()
105    }
106
107    /// Fill the pool with fresh connections up to max_idle.
108    async fn warm_fill(&self) {
109        let need = {
110            let pool = self.connections.lock();
111            if pool.len() >= self.max_idle {
112                return;
113            }
114            self.max_idle - pool.len()
115        };
116        if need == 0 {
117            return;
118        }
119        let batch = self.fill_batch.min(need);
120        for idx in 0..batch {
121            match TcpStream::connect(self.addr).await {
122                Ok(stream) => {
123                    let mut pool = self.connections.lock();
124                    if pool.len() < self.max_idle {
125                        pool.push_back(PooledConnection {
126                            stream,
127                            created_at: Instant::now(),
128                        });
129                        set_fallback_pool_size(pool.len());
130                        debug!(addr = %self.addr, size = pool.len(), "warm connection added");
131                    }
132                }
133                Err(err) => {
134                    record_fallback_pool_warm_fail();
135                    debug!(addr = %self.addr, error = %err, "warm connection failed");
136                    break;
137                }
138            }
139            if self.fill_delay > Duration::from_millis(0) && idx + 1 < batch {
140                tokio::time::sleep(self.fill_delay).await;
141            }
142        }
143    }
144}
145
#[cfg(test)]
impl ConnectionPool {
    /// Test hook: run exactly one warm-fill pass without the background task.
    async fn warm_fill_once(&self) {
        Self::warm_fill(self).await
    }
}
152
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpListener;

    /// Bind a loopback listener on an ephemeral port. A background thread
    /// accepts (and immediately drops) every incoming connection. Returns the
    /// listener's address.
    fn spawn_sink_listener() -> std::net::SocketAddr {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        std::thread::spawn(move || while let Ok((_, _)) = listener.accept() {});
        addr
    }

    #[tokio::test]
    async fn test_pool_basic() {
        let addr = spawn_sink_listener();
        let pool = ConnectionPool::new(addr, 2, 60, 2, 0);

        // With fill_batch == max_idle, a single pass fills the pool completely.
        pool.warm_fill_once().await;
        assert_eq!(pool.size(), 2);

        // Each get() hands out one pooled connection, and it is not returned.
        let conn1 = pool.get().await.unwrap();
        assert_eq!(pool.size(), 1);
        let conn2 = pool.get().await.unwrap();
        assert_eq!(pool.size(), 0);

        // An empty pool falls back to dialing a new connection.
        let conn3 = pool.get().await.unwrap();
        assert_eq!(pool.size(), 0);

        drop((conn1, conn2, conn3));
    }

    #[tokio::test]
    async fn test_pool_max_idle() {
        let addr = spawn_sink_listener();
        let pool = ConnectionPool::new(addr, 2, 60, 2, 0);

        // Repeated passes never push the pool past max_idle.
        pool.warm_fill_once().await;
        pool.warm_fill_once().await;
        assert_eq!(pool.size(), 2);
    }

    #[tokio::test]
    async fn test_pool_fill_batch_limits_each_pass() {
        let addr = spawn_sink_listener();
        let pool = ConnectionPool::new(addr, 4, 60, 2, 0);

        // Each pass adds at most fill_batch connections.
        pool.warm_fill_once().await;
        assert_eq!(pool.size(), 2);
        pool.warm_fill_once().await;
        assert_eq!(pool.size(), 4);
    }

    #[tokio::test]
    async fn test_cleanup_drops_expired() {
        let addr = spawn_sink_listener();
        // max_age of zero: every pooled connection is already expired.
        let pool = ConnectionPool::new(addr, 2, 0, 2, 0);

        pool.warm_fill_once().await;
        assert_eq!(pool.size(), 2);
        pool.cleanup();
        assert_eq!(pool.size(), 0);
    }

    #[tokio::test]
    async fn test_get_with_only_expired_connections_still_connects() {
        let addr = spawn_sink_listener();
        let pool = ConnectionPool::new(addr, 2, 0, 2, 0);

        pool.warm_fill_once().await;
        // Expired pooled entries are skipped and a new connection is dialed.
        assert!(pool.get().await.is_ok());
    }
}
199}