#![ doc = include_str!( concat!( env!( "CARGO_MANIFEST_DIR" ), "/", "README.md" ) ) ]
use parking_lot::Mutex;
use std::collections::BTreeSet;
use std::collections::VecDeque;
use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::ptr::addr_of;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll, Waker};
type ClientId = usize;
struct ResourcePoolGet<'a, T> {
pool: &'a ResourcePool<T>,
queued: bool,
}
impl<'a, T> Future for ResourcePoolGet<'a, T> {
type Output = ResourcePoolGuard<T>;
fn poll(
mut self: Pin<&mut ResourcePoolGet<'a, T>>,
cx: &mut Context<'_>,
) -> Poll<Self::Output> {
let mut holder = self.pool.holder.lock();
if holder.wakers.is_empty() || self.queued {
if let Some(res) = holder.resources.pop() {
self.queued = false;
holder.confirm_get(self.id());
return Poll::Ready(ResourcePoolGuard {
resource: Some(res),
holder: self.pool.holder.clone(),
});
}
}
self.queued = true;
holder.append_callback(cx.waker().clone(), self.id());
Poll::Pending
}
}
impl<'a, T> ResourcePoolGet<'a, T> {
#[inline]
fn id(&self) -> ClientId {
addr_of!(self.queued).cast::<bool>() as ClientId
}
}
impl<'a, T> Drop for ResourcePoolGet<'a, T> {
#[inline]
fn drop(&mut self) {
self.pool.holder.lock().notify_drop(self.id());
}
}
pub struct ResourceHolder<T> {
pub resources: Vec<T>,
wakers: Vec<(Waker, ClientId)>,
pending: BTreeSet<ClientId>,
}
impl<T> ResourceHolder<T> {
fn new(size: usize) -> Self {
Self {
resources: Vec::with_capacity(size),
wakers: <_>::default(),
pending: <_>::default(),
}
}
#[inline]
fn append_resource(&mut self, res: T) {
self.resources.push(res);
self.wake_next();
}
#[inline]
fn wake_next(&mut self) {
if !self.wakers.is_empty() {
let (waker, id) = self.wakers.remove(0);
self.pending.insert(id);
waker.wake();
}
}
#[inline]
fn notify_drop(&mut self, id: ClientId) {
self.wakers.retain(|(_, i)| *i != id);
if self.pending.remove(&id) {
self.wake_next();
}
}
#[inline]
fn confirm_get(&mut self, id: ClientId) {
self.pending.remove(&id);
}
#[inline]
fn append_callback(&mut self, waker: Waker, id: ClientId) {
self.wakers.push((waker, id));
}
}
pub struct ResourcePool<T> {
pub holder: Arc<Mutex<ResourceHolder<T>>>,
}
impl<T> Default for ResourcePool<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> ResourcePool<T> {
pub fn new() -> Self {
Self {
holder: Arc::new(Mutex::new(ResourceHolder::new(0))),
}
}
pub fn with_capacity(size: usize) -> Self {
Self {
holder: Arc::new(Mutex::new(ResourceHolder::new(size))),
}
}
#[inline]
pub fn append(&self, res: T) {
let mut resources = self.holder.lock();
resources.append_resource(res);
}
#[inline]
pub fn get(&self) -> impl Future<Output = ResourcePoolGuard<T>> + '_ {
ResourcePoolGet {
pool: self,
queued: false,
}
}
}
/// Handle to a resource checked out of a [`ResourcePool`].
///
/// Dereferences to the resource. When the guard is dropped, the resource it still
/// holds (if any) is pushed back into the pool, which wakes the oldest waiter.
pub struct ResourcePoolGuard<T> {
    resource: Option<T>,
    holder: Arc<Mutex<ResourceHolder<T>>>,
}

impl<T> ResourcePoolGuard<T> {
    /// Drops the held resource now, so it is not returned to the pool.
    ///
    /// Until [`replace_resource`](Self::replace_resource) installs a new one,
    /// dereferencing the guard panics.
    #[inline]
    pub fn forget_resource(&mut self) {
        self.resource.take();
    }

    /// Installs `resource` in the guard, dropping the previously held one (if any).
    ///
    /// The new resource is the one returned to the pool when the guard is dropped.
    #[inline]
    pub fn replace_resource(&mut self, resource: T) {
        self.resource.replace(resource);
    }
}

impl<T> Deref for ResourcePoolGuard<T> {
    type Target = T;

    /// # Panics
    /// If the resource was removed with `forget_resource` and not replaced.
    #[inline]
    fn deref(&self) -> &Self::Target {
        self.resource
            .as_ref()
            .expect("ResourcePoolGuard: resource has been forgotten")
    }
}

impl<T> DerefMut for ResourcePoolGuard<T> {
    /// # Panics
    /// If the resource was removed with `forget_resource` and not replaced.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.resource
            .as_mut()
            .expect("ResourcePoolGuard: resource has been forgotten")
    }
}

impl<T> Drop for ResourcePoolGuard<T> {
    /// Returns the still-held resource to the pool; a forgotten resource is not returned.
    fn drop(&mut self) {
        if let Some(res) = self.resource.take() {
            self.holder.lock().append_resource(res);
        }
    }
}
#[cfg(test)]
mod test {
    use super::ResourcePool;
    use std::sync::Arc;
    use std::time::Duration;
    use std::time::Instant;
    use tokio::sync::mpsc;
    use tokio::time::sleep;

    /// `test_ordering` aborts every 10th spawned future, starting with the 2nd.
    fn is_canceled(i: usize) -> bool {
        i > 1 && (i - 2) % 10 == 0
    }

    /// Waiters must acquire the single resource strictly in spawn order, with aborted
    /// waiters skipped, and no reservation may be left behind.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_ordering() {
        for _round in 0..5 {
            let pool = Arc::new(ResourcePool::new());
            let started = Instant::now();
            pool.append(());
            let n = 1_000;
            let mut handles = Vec::new();
            let (tx, mut rx) = mpsc::channel(n);
            for i in 1..=n {
                let pool_c = Arc::clone(&pool);
                let tx = tx.clone();
                let handle = tokio::spawn(async move {
                    sleep(Duration::from_millis(1)).await;
                    let _lock = pool_c.get().await;
                    tx.send(i).await.unwrap();
                    println!("future {} locked {}", i, started.elapsed().as_millis());
                    sleep(Duration::from_millis(10)).await;
                });
                sleep(Duration::from_millis(2)).await;
                if is_canceled(i) {
                    println!("future {} canceled", i);
                    handle.abort();
                } else {
                    handles.push(handle);
                }
            }
            for handle in handles {
                tokio::time::timeout(Duration::from_secs(10), handle)
                    .await
                    .unwrap()
                    .unwrap();
            }
            for expected in (1..=n).filter(|&i| !is_canceled(i)) {
                let got = rx.recv().await.unwrap();
                assert_eq!(expected, got);
            }
            assert!(
                pool.holder.lock().pending.is_empty(),
                "pool is poisoned (pendings)",
            );
        }
    }

    /// Aborting a waiter right after a resource is appended must not leave stale
    /// wakers or pending reservations behind.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_no_poisoning() {
        let n = 1_000;
        for _ in 1..=n {
            let pool: Arc<ResourcePool<()>> = Arc::new(ResourcePool::new());
            // Three waiters arriving after 1, 2 and 3 ms respectively.
            let handles: Vec<_> = (1..=3)
                .map(|delay_ms| {
                    let pool_c = Arc::clone(&pool);
                    tokio::spawn(async move {
                        sleep(Duration::from_millis(delay_ms)).await;
                        let _resource = pool_c.get().await;
                    })
                })
                .collect();
            sleep(Duration::from_millis(2)).await;
            pool.append(());
            handles[0].abort();
            sleep(Duration::from_millis(10)).await;
            let holder = pool.holder.lock();
            assert!(
                holder.wakers.is_empty(),
                "pool is poisoned {}/{}",
                holder.wakers.len(),
                holder.resources.len()
            );
            assert!(holder.pending.is_empty(), "pool is poisoned (pendings)",);
        }
    }
}