1use std::sync::Arc;
21use std::sync::atomic::{AtomicU64, Ordering};
22
23use crossbeam_queue::ArrayQueue;
24use parking_lot::Mutex;
25
/// Sizing parameters for an [`InstancePool`].
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Upper bound on the number of instances the pool will count as live
    /// before `acquire` starts reporting a resource-limit error.
    pub max_instances: usize,
    /// Number of instances built eagerly by `InstancePool::new`; the pool
    /// clamps this to `max_instances`.
    pub warm_count: usize,
}

impl Default for PoolConfig {
    /// Returns a configuration allowing four instances, one of them pre-built.
    fn default() -> Self {
        PoolConfig {
            max_instances: 4,
            warm_count: 1,
        }
    }
}
46
/// Atomic counters describing pool activity; every counter starts at zero.
#[derive(Default, Debug)]
pub struct PoolMetrics {
    /// Acquisitions served straight from the idle queue.
    pub hits: AtomicU64,
    /// Acquisitions that had to build a fresh instance through the factory.
    pub misses: AtomicU64,
    /// Acquisitions rejected because the pool was already at capacity.
    pub exhausted: AtomicU64,
    /// Instances the pool currently accounts for, whether idle or handed out.
    pub live: AtomicU64,
}
59
/// Implemented by error types that can express "the pool is at capacity".
///
/// `InstancePool::acquire` builds its capacity error through this trait, so
/// the pool does not need to know the concrete error type of its caller.
pub trait PoolResourceLimit {
    /// Builds the error value that represents a resource-limit failure.
    ///
    /// `msg` is a human-readable description of the limit that was hit.
    #[must_use]
    fn resource_limit(msg: String) -> Self;
}
76
77pub struct InstancePool<T, E>
83where
84 T: Send + 'static,
85 E: PoolResourceLimit + Send + Sync + 'static,
86{
87 cfg: PoolConfig,
88 idle: ArrayQueue<T>,
89 factory: Mutex<Box<dyn Fn() -> Result<T, E> + Send + Sync>>,
90 metrics: Arc<PoolMetrics>,
91}
92
93impl<T, E> std::fmt::Debug for InstancePool<T, E>
94where
95 T: Send + 'static,
96 E: PoolResourceLimit + Send + Sync + 'static,
97{
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 f.debug_struct("InstancePool")
100 .field("cfg", &self.cfg)
101 .field("idle.len", &self.idle.len())
102 .field("metrics.hits", &self.metrics.hits.load(Ordering::Relaxed))
103 .field(
104 "metrics.misses",
105 &self.metrics.misses.load(Ordering::Relaxed),
106 )
107 .finish_non_exhaustive()
108 }
109}
110
111impl<T, E> InstancePool<T, E>
112where
113 T: Send + 'static,
114 E: PoolResourceLimit + Send + Sync + 'static,
115{
116 pub fn new(
126 cfg: PoolConfig,
127 factory: impl Fn() -> Result<T, E> + Send + Sync + 'static,
128 ) -> Result<Self, E> {
129 let idle = ArrayQueue::new(cfg.max_instances.max(1));
130 let factory = Mutex::new(Box::new(factory) as Box<dyn Fn() -> Result<T, E> + Send + Sync>);
131 let metrics = Arc::new(PoolMetrics::default());
132
133 let pool = Self {
134 cfg: cfg.clone(),
135 idle,
136 factory,
137 metrics: Arc::clone(&metrics),
138 };
139
140 for _ in 0..cfg.warm_count.min(cfg.max_instances) {
141 let inst = (pool.factory.lock())()?;
142 let _ = pool.idle.push(inst);
143 metrics.live.fetch_add(1, Ordering::SeqCst);
144 }
145 Ok(pool)
146 }
147
148 pub fn acquire(&self) -> Result<T, E> {
159 if let Some(inst) = self.idle.pop() {
160 self.metrics.hits.fetch_add(1, Ordering::SeqCst);
161 return Ok(inst);
162 }
163 let max = self.cfg.max_instances as u64;
168 loop {
169 let live = self.metrics.live.load(Ordering::SeqCst);
170 if live >= max {
171 self.metrics.exhausted.fetch_add(1, Ordering::SeqCst);
172 return Err(E::resource_limit(format!(
173 "instance pool at capacity ({} live)",
174 self.cfg.max_instances
175 )));
176 }
177 if self
178 .metrics
179 .live
180 .compare_exchange(live, live + 1, Ordering::SeqCst, Ordering::SeqCst)
181 .is_ok()
182 {
183 break;
184 }
185 }
186 let inst = match (self.factory.lock())() {
189 Ok(v) => v,
190 Err(err) => {
191 self.metrics.live.fetch_sub(1, Ordering::SeqCst);
192 return Err(err);
193 }
194 };
195 self.metrics.misses.fetch_add(1, Ordering::SeqCst);
196 Ok(inst)
197 }
198
199 pub fn release(&self, inst: T) {
204 if self.idle.push(inst).is_err() {
205 self.metrics.live.fetch_sub(1, Ordering::SeqCst);
206 }
207 }
208
209 #[must_use]
211 pub fn metrics(&self) -> Arc<PoolMetrics> {
212 Arc::clone(&self.metrics)
213 }
214
215 #[must_use]
217 pub fn config(&self) -> &PoolConfig {
218 &self.cfg
219 }
220
221 #[doc(hidden)]
222 pub fn idle_len(&self) -> usize {
223 self.idle.len()
224 }
225}
226
227pub struct PooledInstance<T, E>
233where
234 T: Send + 'static,
235 E: PoolResourceLimit + Send + Sync + 'static,
236{
237 pool: Arc<InstancePool<T, E>>,
238 inst: Option<T>,
239}
240
241impl<T, E> std::fmt::Debug for PooledInstance<T, E>
242where
243 T: Send + 'static,
244 E: PoolResourceLimit + Send + Sync + 'static,
245{
246 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247 f.debug_struct("PooledInstance")
248 .field("has_inst", &self.inst.is_some())
249 .finish_non_exhaustive()
250 }
251}
252
253impl<T, E> PooledInstance<T, E>
254where
255 T: Send + 'static,
256 E: PoolResourceLimit + Send + Sync + 'static,
257{
258 pub fn acquire(pool: Arc<InstancePool<T, E>>) -> Result<Self, E> {
264 let inst = pool.acquire()?;
265 Ok(Self {
266 pool,
267 inst: Some(inst),
268 })
269 }
270
271 pub fn get_mut(&mut self) -> &mut T {
277 self.inst
278 .as_mut()
279 .expect("PooledInstance accessed after take/drop")
280 }
281
282 pub fn take(mut self) -> T {
286 let inst = self.inst.take().expect("PooledInstance already taken");
287 self.pool.metrics.live.fetch_sub(1, Ordering::SeqCst);
288 inst
289 }
290}
291
292impl<T, E> Drop for PooledInstance<T, E>
293where
294 T: Send + 'static,
295 E: PoolResourceLimit + Send + Sync + 'static,
296{
297 fn drop(&mut self) {
298 if let Some(inst) = self.inst.take() {
299 self.pool.release(inst);
300 }
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal error type so the pool's `E` parameter can be exercised.
    #[derive(Debug, thiserror::Error)]
    enum TestErr {
        #[error("resource limit: {0}")]
        ResourceLimit(String),
    }

    impl PoolResourceLimit for TestErr {
        fn resource_limit(msg: String) -> Self {
            Self::ResourceLimit(msg)
        }
    }

    /// Stand-in pooled value; the numeric payload is never read by the tests.
    #[derive(Debug)]
    #[allow(dead_code)]
    struct Dummy(u32);

    type TestPool = InstancePool<Dummy, TestErr>;

    /// Builds a pool whose factory stamps each instance with a running counter
    /// and returns the pool together with that counter, so tests can check how
    /// many times the factory actually ran.
    fn counting_pool(max_instances: usize, warm_count: usize) -> (TestPool, Arc<AtomicU64>) {
        let created = Arc::new(AtomicU64::new(0));
        let counter = Arc::clone(&created);
        let pool = TestPool::new(
            PoolConfig {
                max_instances,
                warm_count,
            },
            move || Ok(Dummy(counter.fetch_add(1, Ordering::SeqCst) as u32)),
        )
        .unwrap();
        (pool, created)
    }

    #[test]
    fn warmup_populates_idle_queue() {
        let (pool, created) = counting_pool(4, 2);
        // The warm instances must be parked in the queue, built exactly once
        // each, and counted as live.
        assert_eq!(pool.idle_len(), 2);
        assert_eq!(created.load(Ordering::SeqCst), 2);
        assert_eq!(pool.metrics.live.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn acquire_release_round_trip_counts_hits_and_misses() {
        let (pool, _) = counting_pool(2, 1);

        // First acquire is served by the warm instance.
        let a = pool.acquire().unwrap();
        assert_eq!(pool.metrics.hits.load(Ordering::SeqCst), 1);

        // The queue is now empty, so the second acquire has to build one.
        let b = pool.acquire().unwrap();
        assert_eq!(pool.metrics.misses.load(Ordering::SeqCst), 1);

        pool.release(a);
        pool.release(b);

        let _ = pool.acquire().unwrap();
        assert_eq!(pool.metrics.hits.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn exhaustion_returns_resource_limit() {
        let (pool, _) = counting_pool(1, 1);
        let _held = pool.acquire().unwrap();
        let err = pool.acquire().unwrap_err();
        assert!(matches!(err, TestErr::ResourceLimit(_)));
        assert_eq!(pool.metrics.exhausted.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn pooled_instance_releases_on_drop() {
        let pool = Arc::new(counting_pool(2, 1).0);
        assert_eq!(pool.idle_len(), 1);
        {
            let _h = PooledInstance::acquire(Arc::clone(&pool)).unwrap();
            assert_eq!(pool.idle_len(), 0);
        }
        assert_eq!(pool.idle_len(), 1);
    }

    #[test]
    fn pooled_instance_take_does_not_release() {
        let pool = Arc::new(counting_pool(2, 1).0);
        let h = PooledInstance::acquire(Arc::clone(&pool)).unwrap();
        let _taken = h.take();
        assert_eq!(pool.idle_len(), 0);
        assert_eq!(pool.metrics.live.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn config_default_matches_proposal() {
        let c = PoolConfig::default();
        assert_eq!(c.max_instances, 4);
        assert_eq!(c.warm_count, 1);
    }

    #[test]
    fn concurrent_acquire_never_exceeds_max() {
        use std::sync::Barrier;
        use std::thread;

        const MAX: usize = 4;
        const THREADS: usize = 32;

        let (pool, created) = counting_pool(MAX, 0);
        let pool = Arc::new(pool);

        // The barrier lines all threads up so they hit `acquire` together.
        let barrier = Arc::new(Barrier::new(THREADS));
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let p = Arc::clone(&pool);
                let b = Arc::clone(&barrier);
                thread::spawn(move || {
                    b.wait();
                    p.acquire().ok()
                })
            })
            .collect();

        // Keep every acquired instance alive until the assertions have run.
        let mut held = Vec::with_capacity(THREADS);
        for h in handles {
            if let Some(inst) = h.join().unwrap() {
                held.push(inst);
            }
        }

        assert_eq!(held.len(), MAX, "exactly max_instances must be live");
        assert_eq!(
            created.load(Ordering::SeqCst),
            MAX as u64,
            "the factory must not run for rejected acquisitions"
        );
        assert_eq!(pool.metrics.live.load(Ordering::SeqCst), MAX as u64);
        assert_eq!(
            pool.metrics.exhausted.load(Ordering::SeqCst),
            (THREADS - MAX) as u64
        );
    }
}