1use std::collections::VecDeque;
2use std::net::SocketAddr;
3use std::sync::{Arc, Mutex};
4use std::time::Duration;
5
6use tokio::sync::{OwnedSemaphorePermit, Semaphore};
7
8use crate::client::S7Client;
9use crate::error::Error;
10use crate::transport::TcpTransport;
11use crate::types::ConnectParams;
12
/// Tuning knobs for [`S7Pool`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PoolConfig {
    /// Maximum number of connections that may be checked out at the same time.
    /// Should be at least 1.
    pub max_size: usize,
    /// Upper bound on how long opening a single new connection may take.
    pub connect_timeout: Duration,
}

impl Default for PoolConfig {
    /// Four concurrent connections and a five-second connect timeout.
    fn default() -> Self {
        Self {
            max_size: 4,
            connect_timeout: Duration::from_secs(5),
        }
    }
}
30
31struct PoolInner {
32 idle: VecDeque<S7Client<TcpTransport>>,
33 addr: SocketAddr,
34 connect_params: ConnectParams,
35 connect_timeout: Duration,
36}
37
38pub struct S7Pool {
40 inner: Arc<Mutex<PoolInner>>,
41 semaphore: Arc<Semaphore>,
42}
43
/// A connection checked out of an [`S7Pool`].
///
/// Dropping it hands the connection back to the pool's idle queue (unless the
/// pool mutex is poisoned) and frees its concurrency slot.
pub struct PooledClient {
    /// `Some` for the whole borrow; taken out on drop to return it to the pool.
    client: Option<S7Client<TcpTransport>>,
    /// Shared pool state, used on drop to reach the idle queue.
    pool: Arc<Mutex<PoolInner>>,
    /// Holds this borrow's concurrency slot; releasing it lets another `acquire` proceed.
    // Declared last on purpose: fields drop in declaration order after `Drop::drop`
    // runs, so the slot is only freed once the connection is back in the idle queue.
    _permit: OwnedSemaphorePermit,
}
50
51impl PooledClient {
52 pub fn client(&self) -> &S7Client<TcpTransport> {
54 self.client
55 .as_ref()
56 .expect("client always present until drop")
57 }
58}
59
60impl Drop for PooledClient {
61 fn drop(&mut self) {
62 if let Some(client) = self.client.take() {
63 if let Ok(mut inner) = self.pool.lock() {
64 inner.idle.push_back(client);
65 }
66 }
68 }
69}
70
71impl S7Pool {
72 pub fn new(addr: SocketAddr, connect_params: ConnectParams, cfg: PoolConfig) -> Self {
74 let max = cfg.max_size;
75 S7Pool {
76 inner: Arc::new(Mutex::new(PoolInner {
77 idle: VecDeque::new(),
78 addr,
79 connect_params,
80 connect_timeout: cfg.connect_timeout,
81 })),
82 semaphore: Arc::new(Semaphore::new(max)),
83 }
84 }
85
86 pub async fn acquire(&self) -> Result<PooledClient, Error> {
89 let permit = self
90 .semaphore
91 .clone()
92 .acquire_owned()
93 .await
94 .expect("semaphore never closed");
95
96 let idle_client = {
98 let mut inner = self.inner.lock().expect("pool mutex not poisoned");
99 inner.idle.pop_front()
100 };
101
102 if let Some(client) = idle_client {
103 return Ok(PooledClient {
104 client: Some(client),
105 pool: self.inner.clone(),
106 _permit: permit,
107 });
108 }
109
110 let (addr, params, connect_timeout) = {
112 let inner = self.inner.lock().expect("pool mutex not poisoned");
113 (
114 inner.addr,
115 inner.connect_params.clone(),
116 inner.connect_timeout,
117 )
118 };
119
120 let client = tokio::time::timeout(
121 connect_timeout,
122 S7Client::<TcpTransport>::connect(addr, params),
123 )
124 .await
125 .map_err(|_| {
126 Error::Io(std::io::Error::new(
127 std::io::ErrorKind::TimedOut,
128 "pool connect timeout",
129 ))
130 })??;
131
132 Ok(PooledClient {
133 client: Some(client),
134 pool: self.inner.clone(),
135 _permit: permit,
136 })
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Config with `max` slots and a short connect timeout so failing tests stay fast.
    fn cfg(max: usize) -> PoolConfig {
        PoolConfig {
            max_size: max,
            connect_timeout: Duration::from_millis(100),
        }
    }

    /// Loopback port 1, where normally nothing listens, so connect attempts fail.
    fn unreachable_addr() -> SocketAddr {
        "127.0.0.1:1".parse().unwrap()
    }

    #[test]
    fn pool_config_defaults_are_sane() {
        let c = PoolConfig::default();
        assert!(c.max_size >= 1);
        assert!(c.connect_timeout.as_millis() > 0);
    }

    #[test]
    fn pool_config_max_size() {
        let c = cfg(4);
        assert_eq!(c.max_size, 4);
    }

    #[tokio::test]
    async fn pool_acquire_returns_err_on_unreachable_host() {
        let pool = S7Pool::new(unreachable_addr(), Default::default(), cfg(2));
        let result = pool.acquire().await;
        assert!(result.is_err(), "expected connection error on port 1");
    }

    #[tokio::test]
    async fn pool_acquire_releases_permit_on_error() {
        let pool = S7Pool::new(unreachable_addr(), Default::default(), cfg(1));
        assert!(pool.acquire().await.is_err());
        // With a single slot, a leaked permit would make this second call hang.
        let result = tokio::time::timeout(Duration::from_secs(2), pool.acquire()).await;
        assert!(
            result.is_ok(),
            "second acquire timed out — permit was leaked"
        );
    }

    #[tokio::test]
    async fn pool_acquire_waits_while_all_permits_are_held() {
        let pool = S7Pool::new(unreachable_addr(), Default::default(), cfg(1));
        // Occupy the only slot directly, as a checked-out `PooledClient` would.
        let held = pool
            .semaphore
            .clone()
            .try_acquire_owned()
            .expect("fresh pool has a free permit");

        let blocked = tokio::time::timeout(Duration::from_millis(50), pool.acquire()).await;
        assert!(
            blocked.is_err(),
            "acquire must wait while every permit is held"
        );

        drop(held);
        let result = tokio::time::timeout(Duration::from_secs(2), pool.acquire()).await;
        assert!(
            result.is_ok(),
            "acquire must proceed once the permit is released"
        );
    }

    #[tokio::test]
    async fn pool_concurrent_acquires_fail_and_release_permits() {
        let pool = Arc::new(S7Pool::new(
            unreachable_addr(),
            Default::default(),
            cfg(1),
        ));

        let pool1 = Arc::clone(&pool);
        let t1 = tokio::spawn(async move { pool1.acquire().await });
        let pool2 = Arc::clone(&pool);
        let t2 = tokio::spawn(async move { pool2.acquire().await });

        let (r1, r2) = tokio::join!(t1, t2);
        assert!(r1.unwrap().is_err());
        assert!(r2.unwrap().is_err());
        // Both failed attempts must have handed their permit back.
        assert_eq!(pool.semaphore.available_permits(), 1);
    }
}