//! Per-tag asynchronous locking: a shared table that keeps one async mutex per
//! tag value and discards a tag's entry once its last guard has been dropped.

use super::*;

use core::fmt::Debug;
use core::hash::Hash;

/// Guard representing a held lock on a single tag of an [`AsyncTagLockTable`].
///
/// When the guard is dropped, its `Drop` implementation decrements the tag's
/// guard count in the owning table (removing the entry once the count reaches
/// zero), after which the fields are dropped and the per-tag mutex is released.
#[derive(Debug)]
pub struct AsyncTagLockGuard<T>
where
    T: Hash + Eq + Clone + Debug,
{
    /// Handle to the table that issued this guard; `Drop` uses it to update the table
    table: AsyncTagLockTable<T>,
    /// The tag this guard keeps locked
    tag: T,
    /// Owned guard of the tag's async mutex. It is never read; it only keeps the
    /// tag locked, and is released when this struct's fields are dropped.
    _guard: AsyncMutexGuardArc<()>,
}

impl<T> AsyncTagLockGuard<T>
where
    T: Hash + Eq + Clone + Debug,
{
    /// Wraps an already-acquired tag mutex guard together with the table handle
    /// and tag it belongs to.
    ///
    /// This only assembles the struct; the caller is responsible for having
    /// accounted for the guard in the table entry for `tag`.
    fn new(table: AsyncTagLockTable<T>, tag: T, guard: AsyncMutexGuardArc<()>) -> Self {
        // Bind under the field's name so the struct can use shorthand initialization
        let _guard = guard;
        Self {
            table,
            tag,
            _guard,
        }
    }
}

impl<T> Drop for AsyncTagLockGuard<T>
where
    T: Hash + Eq + Clone + Debug,
{
    /// Removes this guard from the table's bookkeeping for its tag.
    ///
    /// The tag's entry is deleted when this was the last outstanding guard.
    /// Panics if the tag has no entry, which would mean the bookkeeping is broken.
    fn drop(&mut self) {
        let mut table_guard = self.table.inner.lock();
        let entries = &mut table_guard.table;

        // A live guard implies its tag still has an entry, so the lookup must succeed
        let entry = entries.get_mut(&self.tag).unwrap();

        // Account for this guard going away
        entry.guards -= 1;
        let was_last_guard = entry.guards == 0;

        // Nobody else references the tag any more: forget it entirely
        if was_last_guard {
            entries.remove(&self.tag).unwrap();
        }

        // The table lock is released when this function returns; only afterwards is
        // `_guard` dropped, which may let some concurrent tag lock acquire
    }
}

/// Bookkeeping stored in the table for a single tag.
#[derive(Clone, Debug)]
struct AsyncTagLockTableEntry {
    /// Async mutex that is locked while a guard for this tag is held
    mutex: Arc<AsyncMutex<()>>,
    /// Number of outstanding claims on this entry. `lock_tag` increments it before
    /// it starts waiting on `mutex`, `try_lock_tag` counts a successful acquisition,
    /// and dropping an `AsyncTagLockGuard` decrements it; the entry is removed from
    /// the table when the count reaches zero.
    guards: usize,
}

/// Mutable state of an [`AsyncTagLockTable`], stored behind the table's `inner` mutex.
struct AsyncTagLockTableInner<T>
where
    T: Hash + Eq + Clone + Debug,
{
    /// One entry per tag that currently has at least one outstanding claim
    table: HashMap<T, AsyncTagLockTableEntry>,
}

/// Table of asynchronous locks keyed by a tag value of type `T`.
///
/// Each distinct tag gets its own async mutex, created on first use and removed
/// from the table again once no guard for that tag remains. The state lives
/// behind an `Arc`, so cloning the table yields another handle to the same locks.
#[derive(Clone)]
pub struct AsyncTagLockTable<T>
where
    T: Hash + Eq + Clone + Debug,
{
    /// Shared tag-to-entry map, protected by a synchronous mutex
    inner: Arc<Mutex<AsyncTagLockTableInner<T>>>,
}

impl<T> fmt::Debug for AsyncTagLockTable<T>
where
    T: Hash + Eq + Clone + Debug,
{
    /// Formats the table as just its type name.
    ///
    /// The table contents are not printed and the internal mutex is not taken.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // A field-less debug struct renders as the bare name, so write it directly
        f.write_str("AsyncTagLockTable")
    }
}

impl<T> AsyncTagLockTable<T>
where
    T: Hash + Eq + Clone + Debug,
{
    pub fn new() -> Self {
        Self {
            inner: Arc::new(Mutex::new(AsyncTagLockTableInner {
                table: HashMap::new(),
            })),
        }
    }

    pub fn len(&self) -> usize {
        let inner = self.inner.lock();
        inner.table.len()
    }

    pub async fn lock_tag(&self, tag: T) -> AsyncTagLockGuard<T> {
        // Get or create a tag lock entry
        let mutex = {
            let mut inner = self.inner.lock();

            // See if this tag is in the table
            // and if not, add a new mutex for this tag
            let entry = inner
                .table
                .entry(tag.clone())
                .or_insert_with(|| AsyncTagLockTableEntry {
                    mutex: Arc::new(AsyncMutex::new(())),
                    guards: 0,
                });

            // Increment the number of guards
            entry.guards += 1;

            // Return the mutex associated with the tag
            entry.mutex.clone()

            // Drop the table guard
        };

        // Lock the tag lock
        let guard = asyncmutex_lock_arc!(mutex);

        // Return the locked guard
        AsyncTagLockGuard::new(self.clone(), tag, guard)
    }

    pub fn try_lock_tag(&self, tag: T) -> Option<AsyncTagLockGuard<T>> {
        // Get or create a tag lock entry
        let mut inner = self.inner.lock();

        // See if this tag is in the table
        // and if not, add a new mutex for this tag
        let entry = inner.table.entry(tag.clone());

        // Lock the tag lock
        let guard = match entry {
            std::collections::hash_map::Entry::Occupied(mut o) => {
                let e = o.get_mut();
                let guard = asyncmutex_try_lock_arc!(e.mutex)?;
                e.guards += 1;
                guard
            }
            std::collections::hash_map::Entry::Vacant(v) => {
                let mutex = Arc::new(AsyncMutex::new(()));
                let guard = asyncmutex_try_lock_arc!(mutex)?;
                v.insert(AsyncTagLockTableEntry { mutex, guards: 1 });
                guard
            }
        };
        // Return guard
        Some(AsyncTagLockGuard::new(self.clone(), tag, guard))
    }
}