Skip to main content

veilid_tools/
tag_lock.rs

1use super::*;
2
3use core::fmt::Debug;
4use core::hash::Hash;
5
6#[derive(Clone, Debug)]
7pub struct TagLockGuard<T>
8where
9    T: Hash + Eq + Clone + Debug,
10{
11    inner: Arc<TagLockGuardInner<T>>,
12}
13
14impl<T> TagLockGuard<T>
15where
16    T: Hash + Eq + Clone + Debug,
17{
18    #[must_use]
19    pub fn tag(&self) -> T {
20        self.inner.tag()
21    }
22}
23
24struct TagLockGuardInner<T>
25where
26    T: Hash + Eq + Clone + Debug,
27{
28    table: TagLockTable<T>,
29    tag: T,
30    guard: Option<ArcMutexGuard<RawMutex, ()>>,
31}
32
33impl<T> fmt::Debug for TagLockGuardInner<T>
34where
35    T: Hash + Eq + Clone + Debug,
36{
37    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
38        f.debug_struct("TagLockGuardInner")
39            .field("tag", &self.tag)
40            .finish()
41    }
42}
43
44impl<T> TagLockGuardInner<T>
45where
46    T: Hash + Eq + Clone + Debug,
47{
48    fn new(table: TagLockTable<T>, tag: T, guard: ArcMutexGuard<RawMutex, ()>) -> Self {
49        Self {
50            table,
51            tag,
52            guard: Some(guard),
53        }
54    }
55
56    fn tag(&self) -> T {
57        self.tag.clone()
58    }
59}
60
61impl<T> Drop for TagLockGuardInner<T>
62where
63    T: Hash + Eq + Clone + Debug,
64{
65    fn drop(&mut self) {
66        let mut inner = self.table.inner.lock();
67        // Inform the table we're dropping this guard
68        let guards = {
69            // Get the table entry, it must exist since we have a guard locked
70            let entry = inner.table.get_mut(&self.tag).unwrap_or_log();
71            // Decrement the number of guards
72            entry.guards -= 1;
73            // Return the number of guards left
74            entry.guards
75        };
76        // If there are no guards left, we remove the tag from the table
77        if guards == 0 {
78            inner.table.remove(&self.tag).unwrap_or_log();
79        }
80        // Proceed with releasing guard, which may cause some concurrent tag lock to acquire
81        drop(self.guard.take());
82    }
83}
84
/// Per-tag bookkeeping stored in a `TagLockTable`.
#[derive(Debug, Clone)]
struct TagLockTableEntry {
    // Mutex serializing all holders of this tag.
    mutex: Arc<Mutex<()>>,
    // Number of guards that currently hold, or are waiting to acquire, `mutex`.
    guards: usize,
}
90
91struct TagLockTableInner<T>
92where
93    T: Hash + Eq + Clone + Debug,
94{
95    table: HashMap<T, TagLockTableEntry>,
96}
97
98#[derive(Clone)]
99pub struct TagLockTable<T>
100where
101    T: Hash + Eq + Clone + Debug,
102{
103    inner: Arc<Mutex<TagLockTableInner<T>>>,
104}
105
106impl<T> fmt::Debug for TagLockTable<T>
107where
108    T: Hash + Eq + Clone + Debug,
109{
110    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111        f.debug_struct("TagLockTable").finish()
112    }
113}
114
115impl<T> TagLockTable<T>
116where
117    T: Hash + Eq + Clone + Debug,
118{
119    #[must_use]
120    pub fn new() -> Self {
121        Self {
122            inner: Arc::new(Mutex::new(TagLockTableInner {
123                table: HashMap::new(),
124            })),
125        }
126    }
127
128    #[must_use]
129    pub fn is_empty(&self) -> bool {
130        let inner = self.inner.lock();
131        inner.table.is_empty()
132    }
133
134    #[must_use]
135    pub fn len(&self) -> usize {
136        let inner = self.inner.lock();
137        inner.table.len()
138    }
139
140    pub fn lock_tag(&self, tag: T) -> TagLockGuard<T> {
141        // Get or create a tag lock entry
142        let mutex = {
143            let mut inner = self.inner.lock();
144
145            // See if this tag is in the table
146            // and if not, add a new mutex for this tag
147            let entry = inner
148                .table
149                .entry(tag.clone())
150                .or_insert_with(|| TagLockTableEntry {
151                    mutex: Arc::new(Mutex::new(())),
152                    guards: 0,
153                });
154
155            // Increment the number of guards
156            entry.guards += 1;
157
158            // Return the mutex associated with the tag
159            entry.mutex.clone()
160
161            // Drop the table guard
162        };
163
164        // Lock the tag lock
165        let guard = mutex.lock_arc();
166
167        // Return the locked guard
168        TagLockGuard {
169            inner: Arc::new(TagLockGuardInner::new(self.clone(), tag, guard)),
170        }
171    }
172
173    pub fn try_lock_tag(&self, tag: T) -> Option<TagLockGuard<T>> {
174        // Get or create a tag lock entry
175        let mut inner = self.inner.lock();
176
177        // See if this tag is in the table
178        // and if not, add a new mutex for this tag
179        let entry = inner.table.entry(tag.clone());
180
181        // Lock the tag lock
182        let guard = match entry {
183            std::collections::hash_map::Entry::Occupied(mut o) => {
184                let e = o.get_mut();
185                let guard = e.mutex.try_lock_arc()?;
186                e.guards += 1;
187                guard
188            }
189            std::collections::hash_map::Entry::Vacant(v) => {
190                let mutex = Arc::new(Mutex::new(()));
191                let guard = mutex.try_lock_arc().unwrap_or_log();
192                v.insert(TagLockTableEntry { mutex, guards: 1 });
193                guard
194            }
195        };
196        // Return guard
197        Some(TagLockGuard {
198            inner: Arc::new(TagLockGuardInner::new(self.clone(), tag, guard)),
199        })
200    }
201}
202
203impl<T> Default for TagLockTable<T>
204where
205    T: Hash + Eq + Clone + Debug,
206{
207    fn default() -> Self {
208        Self::new()
209    }
210}