// veilid_tools/async_tag_lock.rs

1use super::*;
2
3use core::fmt::Debug;
4use core::hash::Hash;
5
6#[derive(Debug)]
7pub struct AsyncTagLockGuard<T>
8where
9    T: Hash + Eq + Clone + Debug,
10{
11    table: AsyncTagLockTable<T>,
12    tag: T,
13    _guard: AsyncMutexGuardArc<()>,
14}
15
16impl<T> AsyncTagLockGuard<T>
17where
18    T: Hash + Eq + Clone + Debug,
19{
20    fn new(table: AsyncTagLockTable<T>, tag: T, guard: AsyncMutexGuardArc<()>) -> Self {
21        Self {
22            table,
23            tag,
24            _guard: guard,
25        }
26    }
27}
28
29impl<T> Drop for AsyncTagLockGuard<T>
30where
31    T: Hash + Eq + Clone + Debug,
32{
33    fn drop(&mut self) {
34        let mut inner = self.table.inner.lock();
35        // Inform the table we're dropping this guard
36        let guards = {
37            // Get the table entry, it must exist since we have a guard locked
38            let entry = inner.table.get_mut(&self.tag).unwrap();
39            // Decrement the number of guards
40            entry.guards -= 1;
41            // Return the number of guards left
42            entry.guards
43        };
44        // If there are no guards left, we remove the tag from the table
45        if guards == 0 {
46            inner.table.remove(&self.tag).unwrap();
47        }
48        // Proceed with releasing _guard, which may cause some concurrent tag lock to acquire
49    }
50}
51
52#[derive(Clone, Debug)]
53struct AsyncTagLockTableEntry {
54    mutex: Arc<AsyncMutex<()>>,
55    guards: usize,
56}
57
58struct AsyncTagLockTableInner<T>
59where
60    T: Hash + Eq + Clone + Debug,
61{
62    table: HashMap<T, AsyncTagLockTableEntry>,
63}
64
65#[derive(Clone)]
66pub struct AsyncTagLockTable<T>
67where
68    T: Hash + Eq + Clone + Debug,
69{
70    inner: Arc<Mutex<AsyncTagLockTableInner<T>>>,
71}
72
73impl<T> fmt::Debug for AsyncTagLockTable<T>
74where
75    T: Hash + Eq + Clone + Debug,
76{
77    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78        f.debug_struct("AsyncTagLockTable").finish()
79    }
80}
81
82impl<T> AsyncTagLockTable<T>
83where
84    T: Hash + Eq + Clone + Debug,
85{
86    pub fn new() -> Self {
87        Self {
88            inner: Arc::new(Mutex::new(AsyncTagLockTableInner {
89                table: HashMap::new(),
90            })),
91        }
92    }
93
94    pub fn is_empty(&self) -> bool {
95        let inner = self.inner.lock();
96        inner.table.is_empty()
97    }
98
99    pub fn len(&self) -> usize {
100        let inner = self.inner.lock();
101        inner.table.len()
102    }
103
104    pub async fn lock_tag(&self, tag: T) -> AsyncTagLockGuard<T> {
105        // Get or create a tag lock entry
106        let mutex = {
107            let mut inner = self.inner.lock();
108
109            // See if this tag is in the table
110            // and if not, add a new mutex for this tag
111            let entry = inner
112                .table
113                .entry(tag.clone())
114                .or_insert_with(|| AsyncTagLockTableEntry {
115                    mutex: Arc::new(AsyncMutex::new(())),
116                    guards: 0,
117                });
118
119            // Increment the number of guards
120            entry.guards += 1;
121
122            // Return the mutex associated with the tag
123            entry.mutex.clone()
124
125            // Drop the table guard
126        };
127
128        // Lock the tag lock
129        let guard = asyncmutex_lock_arc!(mutex);
130
131        // Return the locked guard
132        AsyncTagLockGuard::new(self.clone(), tag, guard)
133    }
134
135    pub fn try_lock_tag(&self, tag: T) -> Option<AsyncTagLockGuard<T>> {
136        // Get or create a tag lock entry
137        let mut inner = self.inner.lock();
138
139        // See if this tag is in the table
140        // and if not, add a new mutex for this tag
141        let entry = inner.table.entry(tag.clone());
142
143        // Lock the tag lock
144        let guard = match entry {
145            std::collections::hash_map::Entry::Occupied(mut o) => {
146                let e = o.get_mut();
147                let guard = asyncmutex_try_lock_arc!(e.mutex)?;
148                e.guards += 1;
149                guard
150            }
151            std::collections::hash_map::Entry::Vacant(v) => {
152                let mutex = Arc::new(AsyncMutex::new(()));
153                let guard = asyncmutex_try_lock_arc!(mutex)?;
154                v.insert(AsyncTagLockTableEntry { mutex, guards: 1 });
155                guard
156            }
157        };
158        // Return guard
159        Some(AsyncTagLockGuard::new(self.clone(), tag, guard))
160    }
161}
162
163impl<T> Default for AsyncTagLockTable<T>
164where
165    T: Hash + Eq + Clone + Debug,
166{
167    fn default() -> Self {
168        Self::new()
169    }
170}