sh_layer2/session_manager/
lock.rs1use parking_lot::{Condvar, Mutex};
11use std::time::Duration;
12
/// Bookkeeping shared by every user of a `ReadWriteLock`; always accessed under its mutex.
#[derive(Default, Debug)]
struct LockState {
    /// Number of callers currently holding shared (read) access.
    readers: i32,
    /// Number of callers currently holding exclusive (write) access.
    writers: i32,
    /// Number of writers that have announced themselves but not yet acquired the lock.
    waiting_writers: i32,
    /// Set whenever a writer starts waiting; while it is set and `waiting_writers > 0`,
    /// newly arriving readers hold back so queued writers go first.
    write_preferred: bool,
}
21
22pub struct ReadWriteLock {
26 state: Mutex<LockState>,
27 read_cond: Condvar,
28 write_cond: Condvar,
29}
30
31impl ReadWriteLock {
32 pub fn new() -> Self {
34 Self {
35 state: Mutex::new(LockState::default()),
36 read_cond: Condvar::new(),
37 write_cond: Condvar::new(),
38 }
39 }
40
41 pub fn read<F, T>(&self, f: F) -> T
46 where
47 F: FnOnce() -> T,
48 {
49 let mut state = self.state.lock();
50
51 while state.writers > 0 || (state.write_preferred && state.waiting_writers > 0) {
53 self.read_cond.wait(&mut state);
54 }
55
56 state.readers += 1;
57 drop(state);
59
60 let result = f();
61
62 let mut state = self.state.lock();
63 state.readers -= 1;
64
65 if state.readers == 0 {
67 self.write_cond.notify_all();
68 self.read_cond.notify_all();
69 }
70
71 result
72 }
73
74 pub fn write<F, T>(&self, f: F) -> T
79 where
80 F: FnOnce() -> T,
81 {
82 let mut state = self.state.lock();
83 state.waiting_writers += 1;
84 state.write_preferred = true;
85
86 while state.readers > 0 || state.writers > 0 {
88 self.write_cond.wait(&mut state);
89 }
90
91 state.waiting_writers -= 1;
92 state.writers += 1;
93
94 drop(state);
96
97 let result = f();
98
99 let mut state = self.state.lock();
100 state.writers -= 1;
101
102 if state.waiting_writers == 0 {
104 state.write_preferred = false;
105 }
106
107 self.write_cond.notify_all();
109 self.read_cond.notify_all();
110
111 result
112 }
113
114 pub fn try_read_timeout<F, T>(&self, f: F, timeout: Duration) -> Option<T>
119 where
120 F: FnOnce() -> T,
121 {
122 let mut state = self.state.lock();
123
124 let deadline = std::time::Instant::now() + timeout;
125 while state.writers > 0 || (state.write_preferred && state.waiting_writers > 0) {
126 if self.read_cond.wait_until(&mut state, deadline).timed_out() {
127 return None;
128 }
129 }
130
131 state.readers += 1;
132 drop(state);
133
134 let result = f();
135
136 let mut state = self.state.lock();
137 state.readers -= 1;
138
139 if state.readers == 0 {
140 self.write_cond.notify_all();
141 self.read_cond.notify_all();
142 }
143
144 Some(result)
145 }
146
147 pub fn try_write_timeout<F, T>(&self, f: F, timeout: Duration) -> Option<T>
152 where
153 F: FnOnce() -> T,
154 {
155 let mut state = self.state.lock();
156 state.waiting_writers += 1;
157 state.write_preferred = true;
158
159 let deadline = std::time::Instant::now() + timeout;
160 while state.readers > 0 || state.writers > 0 {
161 if self.write_cond.wait_until(&mut state, deadline).timed_out() {
162 state.waiting_writers -= 1;
163 if state.waiting_writers == 0 {
164 state.write_preferred = false;
165 }
166 return None;
167 }
168 }
169
170 state.waiting_writers -= 1;
171 state.writers += 1;
172 drop(state);
173
174 let result = f();
175
176 let mut state = self.state.lock();
177 state.writers -= 1;
178
179 if state.waiting_writers == 0 {
180 state.write_preferred = false;
181 }
182
183 self.write_cond.notify_all();
184 self.read_cond.notify_all();
185
186 Some(result)
187 }
188
189 pub fn state(&self) -> LockStateInfo {
191 let state = self.state.lock();
192 LockStateInfo {
193 readers: state.readers,
194 writers: state.writers,
195 waiting_writers: state.waiting_writers,
196 write_preferred: state.write_preferred,
197 }
198 }
199}
200
201impl Default for ReadWriteLock {
202 fn default() -> Self {
203 Self::new()
204 }
205}
206
/// Public, point-in-time copy of a `ReadWriteLock`'s internal counters.
///
/// Snapshots compare by value, so two observations can be checked with `==`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LockStateInfo {
    /// Callers holding shared (read) access when the snapshot was taken.
    pub readers: i32,
    /// Callers holding exclusive (write) access when the snapshot was taken.
    pub writers: i32,
    /// Writers that were queued for the lock when the snapshot was taken.
    pub waiting_writers: i32,
    /// Whether the write-preference flag was set when the snapshot was taken.
    pub write_preferred: bool,
}
215
216#[allow(dead_code)]
218pub struct ReadGuard<'a> {
219 lock: &'a ReadWriteLock,
220}
221
222#[allow(dead_code)]
223impl<'a> ReadGuard<'a> {
224 pub fn new(lock: &'a ReadWriteLock) -> Self {
225 let mut state = lock.state.lock();
226 while state.writers > 0 || (state.write_preferred && state.waiting_writers > 0) {
227 lock.read_cond.wait(&mut state);
228 }
229 state.readers += 1;
230 Self { lock }
231 }
232}
233
234impl<'a> Drop for ReadGuard<'a> {
235 fn drop(&mut self) {
236 let mut state = self.lock.state.lock();
237 state.readers -= 1;
238 if state.readers == 0 {
239 self.lock.write_cond.notify_all();
240 self.lock.read_cond.notify_all();
241 }
242 }
243}
244
245#[allow(dead_code)]
247pub struct WriteGuard<'a> {
248 lock: &'a ReadWriteLock,
249}
250
251#[allow(dead_code)]
252impl<'a> WriteGuard<'a> {
253 pub fn new(lock: &'a ReadWriteLock) -> Self {
254 let mut state = lock.state.lock();
255 state.waiting_writers += 1;
256 state.write_preferred = true;
257
258 while state.readers > 0 || state.writers > 0 {
259 lock.write_cond.wait(&mut state);
260 }
261
262 state.waiting_writers -= 1;
263 state.writers += 1;
264 Self { lock }
265 }
266}
267
268impl<'a> Drop for WriteGuard<'a> {
269 fn drop(&mut self) {
270 let mut state = self.lock.state.lock();
271 state.writers -= 1;
272
273 if state.waiting_writers == 0 {
274 state.write_preferred = false;
275 }
276
277 self.lock.write_cond.notify_all();
278 self.lock.read_cond.notify_all();
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;
    use std::thread;

    /// `read` and `write` hand back whatever the closure returns.
    #[test]
    fn test_read_write_lock_basic() {
        let lock = ReadWriteLock::new();

        assert_eq!(lock.read(|| 42), 42);
        assert_eq!(lock.write(|| 100), 100);
    }

    /// Ten threads each run one read closure; every closure must have executed by the
    /// time all threads are joined.
    #[test]
    fn test_concurrent_reads() {
        let lock = Arc::new(ReadWriteLock::new());
        let counter = Arc::new(AtomicU32::new(0));

        let workers: Vec<_> = (0..10)
            .map(|_| {
                let lock = Arc::clone(&lock);
                let counter = Arc::clone(&counter);
                thread::spawn(move || {
                    lock.read(|| {
                        counter.fetch_add(1, Ordering::SeqCst);
                        thread::sleep(Duration::from_millis(10));
                    });
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    /// A freshly created lock reports no readers and no writers.
    #[test]
    fn test_state_info() {
        let snapshot = ReadWriteLock::new().state();
        assert_eq!(snapshot.readers, 0);
        assert_eq!(snapshot.writers, 0);
    }
}