use super::*;
use core::fmt::Debug;
use core::hash::Hash;
/// Exclusive hold on a single tag of an [`AsyncTagLockTable`].
///
/// Produced by [`AsyncTagLockTable::lock_tag`]. While this value is alive, other
/// `lock_tag` calls for an equal tag keep waiting; dropping it gives the tag back.
#[derive(Debug)]
pub struct AsyncTagLockGuard<T: Hash + Eq + Clone + Debug> {
    /// Table the tag was locked in; its bookkeeping is updated when this guard drops.
    table: AsyncTagLockTable<T>,
    /// The tag this guard holds.
    tag: T,
    /// Kept only for its destructor, which unlocks the tag's async mutex.
    _guard: AsyncMutexGuardArc<()>,
}
impl<T: Hash + Eq + Clone + Debug> AsyncTagLockGuard<T> {
    /// Wraps an already-acquired async mutex guard for `tag` in `table`.
    ///
    /// The guard's `Drop` looks up `tag` in `table`, so the table must still hold an
    /// entry for `tag` that counts this holder among its waiters.
    fn new(table: AsyncTagLockTable<T>, tag: T, guard: AsyncMutexGuardArc<()>) -> Self {
        Self { _guard: guard, tag, table }
    }
}
impl<T> Drop for AsyncTagLockGuard<T>
where
    T: Hash + Eq + Clone + Debug,
{
    /// Gives this guard's waiter slot back to the table and removes the tag's entry
    /// once no other holder or waiter is counted on it.
    ///
    /// The `_guard` field, and with it the async mutex lock, is released only after this
    /// method returns, because Rust drops a value's fields after `Drop::drop` runs.
    fn drop(&mut self) {
        let mut inner = self.table.inner.lock();
        // The entry's count includes this guard, so it cannot have been removed yet.
        let entry = inner
            .table
            .get_mut(&self.tag)
            .expect("tag lock entry must exist while a guard for it is alive");
        entry.waiters -= 1;
        if entry.waiters == 0 {
            // Found above under the same lock, so the removal cannot miss.
            inner.table.remove(&self.tag);
        }
    }
}
/// Per-tag state stored inside an [`AsyncTagLockTable`].
#[derive(Debug, Clone)]
struct AsyncTagLockTableEntry {
    /// Async mutex that serializes holders of this tag; cloned out to each caller that waits on it.
    mutex: Arc<AsyncMutex<()>>,
    /// Number of callers currently holding or waiting for `mutex`; the entry is dropped at zero.
    waiters: usize,
}
/// State behind the table's synchronous mutex.
struct AsyncTagLockTableInner<T: Hash + Eq + Clone + Debug> {
    /// Tags that currently have a holder or waiter, mapped to their per-tag state.
    table: HashMap<T, AsyncTagLockTableEntry>,
}
/// A collection of independent async locks, one per tag.
///
/// Entries are created on demand by `lock_tag` and removed once nobody holds or waits
/// on them. Cloning the table produces another handle to the same shared state.
#[derive(Clone)]
pub struct AsyncTagLockTable<T: Hash + Eq + Clone + Debug> {
    inner: Arc<Mutex<AsyncTagLockTableInner<T>>>,
}
impl<T: Hash + Eq + Clone + Debug> fmt::Debug for AsyncTagLockTable<T> {
    /// Writes only the type name; the lock-protected contents are never inspected.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("AsyncTagLockTable")
    }
}
impl<T> AsyncTagLockTable<T>
where
T: Hash + Eq + Clone + Debug,
{
pub fn new() -> Self {
Self {
inner: Arc::new(Mutex::new(AsyncTagLockTableInner {
table: HashMap::new(),
})),
}
}
pub fn len(&self) -> usize {
let inner = self.inner.lock();
inner.table.len()
}
pub async fn lock_tag(&self, tag: T) -> AsyncTagLockGuard<T> {
let mutex = {
let mut inner = self.inner.lock();
let entry = inner
.table
.entry(tag.clone())
.or_insert_with(|| AsyncTagLockTableEntry {
mutex: Arc::new(AsyncMutex::new(())),
waiters: 0,
});
entry.waiters += 1;
entry.mutex.clone()
};
let guard;
cfg_if! {
if #[cfg(feature="rt-tokio")] {
guard = mutex.lock_owned().await;
} else {
guard = mutex.lock_arc().await;
}
}
AsyncTagLockGuard::new(self.clone(), tag, guard)
}
}