Skip to main content

snap7_client/
pool.rs

1use std::collections::VecDeque;
2use std::net::SocketAddr;
3use std::sync::{Arc, Mutex};
4use std::time::Duration;
5
6use tokio::sync::{OwnedSemaphorePermit, Semaphore};
7
8use crate::client::S7Client;
9use crate::error::Error;
10use crate::transport::TcpTransport;
11use crate::types::ConnectParams;
12
/// Tuning knobs for an `S7Pool`.
///
/// `S7Pool::new` sizes its semaphore from `max_size` and stores
/// `connect_timeout` for later connection attempts.
#[derive(Clone, Debug)]
pub struct PoolConfig {
    /// Upper bound on the number of connections borrowed at the same time.
    pub max_size: usize,
    /// Time limit applied to each attempt to open a new connection.
    pub connect_timeout: Duration,
}
21
22impl Default for PoolConfig {
23    fn default() -> Self {
24        PoolConfig {
25            max_size: 4,
26            connect_timeout: Duration::from_secs(5),
27        }
28    }
29}
30
31struct PoolInner {
32    idle: VecDeque<S7Client<TcpTransport>>,
33    addr: SocketAddr,
34    connect_params: ConnectParams,
35    connect_timeout: Duration,
36}
37
38/// A bounded pool of `S7Client<TcpTransport>` connections.
39pub struct S7Pool {
40    inner: Arc<Mutex<PoolInner>>,
41    semaphore: Arc<Semaphore>,
42}
43
44/// RAII guard — returns the connection to the pool on drop.
45pub struct PooledClient {
46    client: Option<S7Client<TcpTransport>>,
47    pool: Arc<Mutex<PoolInner>>,
48    _permit: OwnedSemaphorePermit,
49}
50
51impl PooledClient {
52    /// Access the underlying `S7Client`.
53    pub fn client(&self) -> &S7Client<TcpTransport> {
54        self.client
55            .as_ref()
56            .expect("client always present until drop")
57    }
58}
59
60impl Drop for PooledClient {
61    fn drop(&mut self) {
62        if let Some(client) = self.client.take() {
63            if let Ok(mut inner) = self.pool.lock() {
64                inner.idle.push_back(client);
65            }
66            // If the mutex is poisoned, the connection is dropped — acceptable.
67        }
68    }
69}
70
71impl S7Pool {
72    /// Create a new pool targeting `addr` with `connect_params` and `cfg`.
73    pub fn new(addr: SocketAddr, connect_params: ConnectParams, cfg: PoolConfig) -> Self {
74        let max = cfg.max_size;
75        S7Pool {
76            inner: Arc::new(Mutex::new(PoolInner {
77                idle: VecDeque::new(),
78                addr,
79                connect_params,
80                connect_timeout: cfg.connect_timeout,
81            })),
82            semaphore: Arc::new(Semaphore::new(max)),
83        }
84    }
85
86    /// Borrow a connection from the pool, opening a new one if none are idle.
87    /// Blocks until a semaphore permit is available (bounded by `max_size`).
88    pub async fn acquire(&self) -> Result<PooledClient, Error> {
89        let permit = self
90            .semaphore
91            .clone()
92            .acquire_owned()
93            .await
94            .expect("semaphore never closed");
95
96        // Check for an idle connection — hold the lock only briefly.
97        let idle_client = {
98            let mut inner = self.inner.lock().expect("pool mutex not poisoned");
99            inner.idle.pop_front()
100        };
101
102        if let Some(client) = idle_client {
103            return Ok(PooledClient {
104                client: Some(client),
105                pool: self.inner.clone(),
106                _permit: permit,
107            });
108        }
109
110        // No idle connection — extract params (brief lock scope), then connect.
111        let (addr, params, connect_timeout) = {
112            let inner = self.inner.lock().expect("pool mutex not poisoned");
113            (
114                inner.addr,
115                inner.connect_params.clone(),
116                inner.connect_timeout,
117            )
118        };
119
120        let client = tokio::time::timeout(
121            connect_timeout,
122            S7Client::<TcpTransport>::connect(addr, params),
123        )
124        .await
125        .map_err(|_| {
126            Error::Io(std::io::Error::new(
127                std::io::ErrorKind::TimedOut,
128                "pool connect timeout",
129            ))
130        })??;
131
132        Ok(PooledClient {
133            client: Some(client),
134            pool: self.inner.clone(),
135            _permit: permit,
136        })
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Config with the given size and a short (100 ms) connect timeout.
    fn cfg(max: usize) -> PoolConfig {
        PoolConfig {
            connect_timeout: Duration::from_millis(100),
            max_size: max,
        }
    }

    /// Pool aimed at 127.0.0.1:1, where the tests expect nothing to be listening.
    fn unreachable_pool(max: usize) -> S7Pool {
        S7Pool::new("127.0.0.1:1".parse().unwrap(), Default::default(), cfg(max))
    }

    #[test]
    fn pool_config_defaults_are_sane() {
        let defaults = PoolConfig::default();
        assert!(defaults.max_size >= 1);
        assert!(defaults.connect_timeout.as_millis() > 0);
    }

    #[test]
    fn pool_config_max_size() {
        assert_eq!(cfg(4).max_size, 4);
    }

    #[tokio::test]
    async fn pool_acquire_returns_err_on_unreachable_host() {
        let pool = unreachable_pool(2);
        let outcome = pool.acquire().await;
        assert!(outcome.is_err(), "expected connection error on port 1");
    }

    #[tokio::test]
    async fn pool_acquire_releases_permit_on_error() {
        let pool = unreachable_pool(1);
        // The first attempt cannot connect and must give its permit back.
        assert!(pool.acquire().await.is_err());
        // With a leaked permit the single-slot pool would never let this finish.
        let second = tokio::time::timeout(Duration::from_secs(2), pool.acquire()).await;
        assert!(
            second.is_ok(),
            "second acquire timed out — permit was leaked"
        );
    }

    #[tokio::test]
    async fn pool_semaphore_limits_concurrent_borrows() {
        // Single-slot pool shared by two tasks: both connection attempts fail
        // (port 1 is not listening), and both tasks must finish without panicking,
        // which shows the permit is released on error for the other task to use.
        let shared = Arc::new(unreachable_pool(1));

        let first = {
            let pool = Arc::clone(&shared);
            tokio::spawn(async move { pool.acquire().await })
        };
        let second = tokio::spawn(async move { shared.acquire().await });

        let (r1, r2) = tokio::join!(first, second);
        // `unwrap` fails if a task panicked; the acquire itself must be an error.
        assert!(r1.unwrap().is_err());
        assert!(r2.unwrap().is_err());
    }
}