Skip to main content

vcl_protocol/
pool.rs

//! # VCL Connection Pool
//!
//! [`VCLPool`] owns multiple [`VCLConnection`]s and addresses each one by an ID.
//!
//! Useful when you need to handle many peers simultaneously —
//! for example a server accepting connections from multiple clients,
//! or a client maintaining connections to multiple servers.
//!
//! ## Example
//!
//! ```no_run
//! use vcl_protocol::pool::VCLPool;
//!
//! #[tokio::main]
//! async fn main() {
//!     let mut pool = VCLPool::new(10);
//!
//!     let id = pool.bind("127.0.0.1:0").await.unwrap();
//!     pool.connect(id, "127.0.0.1:8080").await.unwrap();
//!     pool.send(id, b"Hello from pool!").await.unwrap();
//!
//!     let packet = pool.recv(id).await.unwrap();
//!     println!("{}", String::from_utf8_lossy(&packet.payload));
//!
//!     pool.close(id).unwrap();
//! }
//! ```
28
29use crate::connection::VCLConnection;
30use crate::error::VCLError;
31use crate::packet::VCLPacket;
32use std::collections::HashMap;
33use tracing::{debug, info, warn};
34
35/// A unique identifier for a connection inside a [`VCLPool`].
36pub type ConnectionId = u64;
37
38/// Manages multiple [`VCLConnection`]s under a single pool.
39///
40/// Each connection gets a unique [`ConnectionId`] assigned at `bind()`.
41/// The pool enforces a maximum connection limit set at construction.
42pub struct VCLPool {
43    connections: HashMap<ConnectionId, VCLConnection>,
44    next_id: ConnectionId,
45    max_connections: usize,
46}
47
48impl VCLPool {
49    /// Create a new pool with a maximum number of concurrent connections.
50    ///
51    /// # Example
52    /// ```
53    /// use vcl_protocol::pool::VCLPool;
54    /// let pool = VCLPool::new(10);
55    /// ```
56    pub fn new(max_connections: usize) -> Self {
57        info!(max_connections, "VCLPool created");
58        VCLPool {
59            connections: HashMap::new(),
60            next_id: 0,
61            max_connections,
62        }
63    }
64
65    /// Bind a new connection to a local UDP address and add it to the pool.
66    ///
67    /// Returns the [`ConnectionId`] assigned to this connection.
68    ///
69    /// # Errors
70    /// - [`VCLError::InvalidPacket`] — pool is at maximum capacity
71    /// - [`VCLError::IoError`] — socket bind failed
72    pub async fn bind(&mut self, addr: &str) -> Result<ConnectionId, VCLError> {
73        if self.connections.len() >= self.max_connections {
74            warn!(
75                current = self.connections.len(),
76                max = self.max_connections,
77                "Pool is at maximum capacity"
78            );
79            return Err(VCLError::InvalidPacket(format!(
80                "Pool is full: max {} connections",
81                self.max_connections
82            )));
83        }
84
85        let conn = VCLConnection::bind(addr).await?;
86        let id = self.next_id;
87        self.next_id += 1;
88        self.connections.insert(id, conn);
89        info!(id, addr, "Connection added to pool");
90        Ok(id)
91    }
92
93    /// Connect a pooled connection to a remote peer (client side handshake).
94    ///
95    /// # Errors
96    /// - [`VCLError::InvalidPacket`] — connection ID not found
97    /// - [`VCLError::HandshakeFailed`] — handshake failed
98    pub async fn connect(&mut self, id: ConnectionId, addr: &str) -> Result<(), VCLError> {
99        let conn = self.get_mut(id)?;
100        debug!(id, peer = %addr, "Pool: connecting");
101        conn.connect(addr).await
102    }
103
104    /// Accept an incoming handshake on a pooled connection (server side).
105    ///
106    /// # Errors
107    /// - [`VCLError::InvalidPacket`] — connection ID not found
108    /// - [`VCLError::HandshakeFailed`] — handshake failed
109    pub async fn accept_handshake(&mut self, id: ConnectionId) -> Result<(), VCLError> {
110        let conn = self.get_mut(id)?;
111        debug!(id, "Pool: accepting handshake");
112        conn.accept_handshake().await
113    }
114
115    /// Send data on a pooled connection.
116    ///
117    /// # Errors
118    /// - [`VCLError::InvalidPacket`] — connection ID not found
119    /// - Any error from [`VCLConnection::send`]
120    pub async fn send(&mut self, id: ConnectionId, data: &[u8]) -> Result<(), VCLError> {
121        let conn = self.get_mut(id)?;
122        debug!(id, size = data.len(), "Pool: sending");
123        conn.send(data).await
124    }
125
126    /// Receive the next data packet on a pooled connection.
127    ///
128    /// # Errors
129    /// - [`VCLError::InvalidPacket`] — connection ID not found
130    /// - Any error from [`VCLConnection::recv`]
131    pub async fn recv(&mut self, id: ConnectionId) -> Result<VCLPacket, VCLError> {
132        let conn = self.get_mut(id)?;
133        debug!(id, "Pool: waiting for packet");
134        conn.recv().await
135    }
136
137    /// Send a ping on a pooled connection.
138    ///
139    /// # Errors
140    /// - [`VCLError::InvalidPacket`] — connection ID not found
141    /// - Any error from [`VCLConnection::ping`]
142    pub async fn ping(&mut self, id: ConnectionId) -> Result<(), VCLError> {
143        let conn = self.get_mut(id)?;
144        debug!(id, "Pool: sending ping");
145        conn.ping().await
146    }
147
148    /// Rotate keys on a pooled connection.
149    ///
150    /// # Errors
151    /// - [`VCLError::InvalidPacket`] — connection ID not found
152    /// - Any error from [`VCLConnection::rotate_keys`]
153    pub async fn rotate_keys(&mut self, id: ConnectionId) -> Result<(), VCLError> {
154        let conn = self.get_mut(id)?;
155        debug!(id, "Pool: rotating keys");
156        conn.rotate_keys().await
157    }
158
159    /// Close a specific connection and remove it from the pool.
160    ///
161    /// # Errors
162    /// - [`VCLError::InvalidPacket`] — connection ID not found
163    /// - [`VCLError::ConnectionClosed`] — already closed
164    pub fn close(&mut self, id: ConnectionId) -> Result<(), VCLError> {
165        match self.connections.get_mut(&id) {
166            Some(conn) => {
167                conn.close()?;
168                self.connections.remove(&id);
169                info!(id, "Connection removed from pool");
170                Ok(())
171            }
172            None => {
173                warn!(id, "close() called with unknown connection ID");
174                Err(VCLError::InvalidPacket(format!(
175                    "Connection ID {} not found in pool",
176                    id
177                )))
178            }
179        }
180    }
181
182    /// Close all connections and clear the pool.
183    pub fn close_all(&mut self) {
184        info!(count = self.connections.len(), "Closing all pool connections");
185        for (id, conn) in self.connections.iter_mut() {
186            if let Err(e) = conn.close() {
187                warn!(id, error = %e, "Error closing connection during close_all");
188            }
189        }
190        self.connections.clear();
191    }
192
193    /// Returns the number of active connections in the pool.
194    pub fn len(&self) -> usize {
195        self.connections.len()
196    }
197
198    /// Returns `true` if the pool has no active connections.
199    pub fn is_empty(&self) -> bool {
200        self.connections.is_empty()
201    }
202
203    /// Returns `true` if the pool has reached its maximum capacity.
204    pub fn is_full(&self) -> bool {
205        self.connections.len() >= self.max_connections
206    }
207
208    /// Returns a list of all active [`ConnectionId`]s in the pool.
209    pub fn connection_ids(&self) -> Vec<ConnectionId> {
210        self.connections.keys().copied().collect()
211    }
212
213    /// Returns `true` if a connection with the given ID exists in the pool.
214    pub fn contains(&self, id: ConnectionId) -> bool {
215        self.connections.contains_key(&id)
216    }
217
218    /// Get a reference to a connection by ID.
219    ///
220    /// # Errors
221    /// Returns [`VCLError::InvalidPacket`] if the ID is not found.
222    pub fn get(&self, id: ConnectionId) -> Result<&VCLConnection, VCLError> {
223        self.connections.get(&id).ok_or_else(|| {
224            VCLError::InvalidPacket(format!("Connection ID {} not found in pool", id))
225        })
226    }
227
228    fn get_mut(&mut self, id: ConnectionId) -> Result<&mut VCLConnection, VCLError> {
229        self.connections.get_mut(&id).ok_or_else(|| {
230            VCLError::InvalidPacket(format!("Connection ID {} not found in pool", id))
231        })
232    }
233}
234
235impl Drop for VCLPool {
236    fn drop(&mut self) {
237        if !self.connections.is_empty() {
238            self.close_all();
239        }
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_new() {
        let pool = VCLPool::new(5);
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());
        assert!(!pool.is_full());
    }

    #[tokio::test]
    async fn test_pool_bind() {
        let mut pool = VCLPool::new(5);
        let id = pool.bind("127.0.0.1:0").await.unwrap();
        assert_eq!(pool.len(), 1);
        assert!(pool.contains(id));
        assert!(!pool.is_empty());
        assert!(pool.get(id).is_ok());
    }

    #[tokio::test]
    async fn test_pool_max_capacity() {
        let mut pool = VCLPool::new(2);
        pool.bind("127.0.0.1:0").await.unwrap();
        pool.bind("127.0.0.1:0").await.unwrap();
        assert!(pool.is_full());
        let result = pool.bind("127.0.0.1:0").await;
        assert!(result.is_err());
        // A refused bind must not change the pool.
        assert_eq!(pool.len(), 2);
    }

    #[tokio::test]
    async fn test_pool_close() {
        let mut pool = VCLPool::new(5);
        let id = pool.bind("127.0.0.1:0").await.unwrap();
        assert_eq!(pool.len(), 1);
        pool.close(id).unwrap();
        assert_eq!(pool.len(), 0);
        assert!(!pool.contains(id));
        // Closing the same ID again is an unknown-ID error.
        assert!(pool.close(id).is_err());
    }

    #[tokio::test]
    async fn test_pool_close_frees_capacity() {
        let mut pool = VCLPool::new(1);
        let id = pool.bind("127.0.0.1:0").await.unwrap();
        assert!(pool.is_full());
        pool.close(id).unwrap();
        assert!(!pool.is_full());
        pool.bind("127.0.0.1:0").await.unwrap();
        assert!(pool.is_full());
    }

    #[tokio::test]
    async fn test_pool_ids_not_reused() {
        let mut pool = VCLPool::new(5);
        let first = pool.bind("127.0.0.1:0").await.unwrap();
        pool.close(first).unwrap();
        let second = pool.bind("127.0.0.1:0").await.unwrap();
        assert_ne!(first, second);
    }

    #[tokio::test]
    async fn test_pool_close_all() {
        let mut pool = VCLPool::new(5);
        pool.bind("127.0.0.1:0").await.unwrap();
        pool.bind("127.0.0.1:0").await.unwrap();
        pool.bind("127.0.0.1:0").await.unwrap();
        assert_eq!(pool.len(), 3);
        pool.close_all();
        assert!(pool.is_empty());
        assert!(pool.connection_ids().is_empty());
    }

    #[tokio::test]
    async fn test_pool_drop_with_open_connections() {
        let mut pool = VCLPool::new(3);
        pool.bind("127.0.0.1:0").await.unwrap();
        pool.bind("127.0.0.1:0").await.unwrap();
        // The Drop impl closes the remaining connections; this must not panic.
        drop(pool);
    }

    #[test]
    fn test_pool_unknown_id() {
        let mut pool = VCLPool::new(5);
        assert!(pool.close(999).is_err());
        assert!(pool.get(999).is_err());
        assert!(!pool.contains(999));
    }

    #[tokio::test]
    async fn test_pool_unknown_id_async_ops() {
        // Every operation looks the ID up first, so these fail before any I/O.
        let mut pool = VCLPool::new(5);
        assert!(pool.connect(999, "127.0.0.1:1").await.is_err());
        assert!(pool.accept_handshake(999).await.is_err());
        assert!(pool.send(999, b"data").await.is_err());
        assert!(pool.recv(999).await.is_err());
        assert!(pool.ping(999).await.is_err());
        assert!(pool.rotate_keys(999).await.is_err());
    }

    #[tokio::test]
    async fn test_pool_connection_ids() {
        let mut pool = VCLPool::new(5);
        let id1 = pool.bind("127.0.0.1:0").await.unwrap();
        let id2 = pool.bind("127.0.0.1:0").await.unwrap();
        let mut ids = pool.connection_ids();
        ids.sort_unstable();
        assert_eq!(ids, vec![id1, id2]);
    }
}