// veilid_tools/async_tag_lock.rs

1use super::*;
2
3use core::fmt::Debug;
4use core::hash::Hash;
5
6#[derive(Clone, Debug)]
7pub struct AsyncTagLockGuard<T>
8where
9    T: Hash + Eq + Clone + Debug,
10{
11    inner: Arc<AsyncTagLockGuardInner<T>>,
12}
13
14impl<T> AsyncTagLockGuard<T>
15where
16    T: Hash + Eq + Clone + Debug,
17{
18    #[must_use]
19    pub fn tag(&self) -> T {
20        self.inner.tag()
21    }
22}
23
24#[derive(Debug)]
25struct AsyncTagLockGuardInner<T>
26where
27    T: Hash + Eq + Clone + Debug,
28{
29    table: AsyncTagLockTable<T>,
30    tag: T,
31    guard: Option<AsyncMutexGuardArc<()>>,
32}
33
34impl<T> AsyncTagLockGuardInner<T>
35where
36    T: Hash + Eq + Clone + Debug,
37{
38    fn new(table: AsyncTagLockTable<T>, tag: T, guard: AsyncMutexGuardArc<()>) -> Self {
39        Self {
40            table,
41            tag,
42            guard: Some(guard),
43        }
44    }
45
46    fn tag(&self) -> T {
47        self.tag.clone()
48    }
49}
50
51impl<T> Drop for AsyncTagLockGuardInner<T>
52where
53    T: Hash + Eq + Clone + Debug,
54{
55    fn drop(&mut self) {
56        let mut inner = self.table.inner.lock();
57        // Inform the table we're dropping this guard
58        let guards = {
59            // Get the table entry, it must exist since we have a guard locked
60            let entry = inner.table.get_mut(&self.tag).unwrap();
61            // Decrement the number of guards
62            entry.guards -= 1;
63            // Return the number of guards left
64            entry.guards
65        };
66        // If there are no guards left, we remove the tag from the table
67        if guards == 0 {
68            inner.table.remove(&self.tag).unwrap();
69        }
70        // Proceed with releasing guard, which may cause some concurrent tag lock to acquire
71        drop(self.guard.take());
72    }
73}
74
75#[derive(Clone, Debug)]
76struct AsyncTagLockTableEntry {
77    mutex: Arc<AsyncMutex<()>>,
78    guards: usize,
79}
80
81struct AsyncTagLockTableInner<T>
82where
83    T: Hash + Eq + Clone + Debug,
84{
85    table: HashMap<T, AsyncTagLockTableEntry>,
86}
87
88#[derive(Clone)]
89pub struct AsyncTagLockTable<T>
90where
91    T: Hash + Eq + Clone + Debug,
92{
93    inner: Arc<Mutex<AsyncTagLockTableInner<T>>>,
94}
95
96impl<T> fmt::Debug for AsyncTagLockTable<T>
97where
98    T: Hash + Eq + Clone + Debug,
99{
100    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101        f.debug_struct("AsyncTagLockTable").finish()
102    }
103}
104
105impl<T> AsyncTagLockTable<T>
106where
107    T: Hash + Eq + Clone + Debug,
108{
109    #[must_use]
110    pub fn new() -> Self {
111        Self {
112            inner: Arc::new(Mutex::new(AsyncTagLockTableInner {
113                table: HashMap::new(),
114            })),
115        }
116    }
117
118    #[must_use]
119    pub fn is_empty(&self) -> bool {
120        let inner = self.inner.lock();
121        inner.table.is_empty()
122    }
123
124    #[must_use]
125    pub fn len(&self) -> usize {
126        let inner = self.inner.lock();
127        inner.table.len()
128    }
129
130    pub async fn lock_tag(&self, tag: T) -> AsyncTagLockGuard<T> {
131        // Get or create a tag lock entry
132        let mutex = {
133            let mut inner = self.inner.lock();
134
135            // See if this tag is in the table
136            // and if not, add a new mutex for this tag
137            let entry = inner
138                .table
139                .entry(tag.clone())
140                .or_insert_with(|| AsyncTagLockTableEntry {
141                    mutex: Arc::new(AsyncMutex::new(())),
142                    guards: 0,
143                });
144
145            // Increment the number of guards
146            entry.guards += 1;
147
148            // Return the mutex associated with the tag
149            entry.mutex.clone()
150
151            // Drop the table guard
152        };
153
154        // Lock the tag lock
155        let guard = asyncmutex_lock_arc!(mutex);
156
157        // Return the locked guard
158        AsyncTagLockGuard {
159            inner: Arc::new(AsyncTagLockGuardInner::new(self.clone(), tag, guard)),
160        }
161    }
162
163    pub fn try_lock_tag(&self, tag: T) -> Option<AsyncTagLockGuard<T>> {
164        // Get or create a tag lock entry
165        let mut inner = self.inner.lock();
166
167        // See if this tag is in the table
168        // and if not, add a new mutex for this tag
169        let entry = inner.table.entry(tag.clone());
170
171        // Lock the tag lock
172        let guard = match entry {
173            std::collections::hash_map::Entry::Occupied(mut o) => {
174                let e = o.get_mut();
175                let guard = asyncmutex_try_lock_arc!(e.mutex)?;
176                e.guards += 1;
177                guard
178            }
179            std::collections::hash_map::Entry::Vacant(v) => {
180                let mutex = Arc::new(AsyncMutex::new(()));
181                let guard = asyncmutex_try_lock_arc!(mutex).unwrap();
182                v.insert(AsyncTagLockTableEntry { mutex, guards: 1 });
183                guard
184            }
185        };
186        // Return guard
187        Some(AsyncTagLockGuard {
188            inner: Arc::new(AsyncTagLockGuardInner::new(self.clone(), tag, guard)),
189        })
190    }
191}
192
193impl<T> Default for AsyncTagLockTable<T>
194where
195    T: Hash + Eq + Clone + Debug,
196{
197    fn default() -> Self {
198        Self::new()
199    }
200}