Skip to main content

sh_layer2/session_manager/
lock.rs

1//! # ReadWriteLock
2//!
3//! 读写分离锁实现,支持并发读取和互斥写入。
4//!
5//! 特性:
6//! - 读操作可并发执行(共享锁)
7//! - 写操作需互斥执行(排他锁)
8//! - 写优先:当有写者等待时,新的读者会被阻塞
9
10use parking_lot::{Condvar, Mutex};
11use std::time::Duration;
12
/// Bookkeeping shared by every reader and writer of a [`ReadWriteLock`].
///
/// The derived `Default` is the unlocked state: every counter is zero and the
/// preference flag is cleared.
#[derive(Debug, Default)]
struct LockState {
    /// Readers currently running their closure / holding a read guard.
    readers: i32,
    /// Writers currently running their closure / holding a write guard.
    writers: i32,
    /// Writers that have announced themselves but have not entered yet.
    waiting_writers: i32,
    /// Raised whenever a writer announces itself; lowered by a writer that
    /// leaves (or gives up) while no other writer is waiting.
    write_preferred: bool,
}
21
22/// 读写分离锁
23///
24/// 使用 parking_lot 实现的读写分离锁,比标准库 RwLock 提供更细粒度的控制。
25pub struct ReadWriteLock {
26    state: Mutex<LockState>,
27    read_cond: Condvar,
28    write_cond: Condvar,
29}
30
31impl ReadWriteLock {
32    /// 创建新的读写锁
33    pub fn new() -> Self {
34        Self {
35            state: Mutex::new(LockState::default()),
36            read_cond: Condvar::new(),
37            write_cond: Condvar::new(),
38        }
39    }
40
41    /// 获取读锁
42    ///
43    /// 多个读者可以同时持有读锁。
44    /// 当有写者活跃或等待时,读者会被阻塞。
45    pub fn read<F, T>(&self, f: F) -> T
46    where
47        F: FnOnce() -> T,
48    {
49        let mut state = self.state.lock();
50
51        // 等待条件:没有活跃写者,且没有写者在等待(写优先)
52        while state.writers > 0 || (state.write_preferred && state.waiting_writers > 0) {
53            self.read_cond.wait(&mut state);
54        }
55
56        state.readers += 1;
57        // 释放锁后执行读操作
58        drop(state);
59
60        let result = f();
61
62        let mut state = self.state.lock();
63        state.readers -= 1;
64
65        // 如果没有读者了,通知等待的写者
66        if state.readers == 0 {
67            self.write_cond.notify_all();
68            self.read_cond.notify_all();
69        }
70
71        result
72    }
73
74    /// 获取写锁
75    ///
76    /// 写锁是排他的,同一时间只能有一个写者。
77    /// 当有读者或写者活跃时,新的写者会被阻塞。
78    pub fn write<F, T>(&self, f: F) -> T
79    where
80        F: FnOnce() -> T,
81    {
82        let mut state = self.state.lock();
83        state.waiting_writers += 1;
84        state.write_preferred = true;
85
86        // 等待条件:没有活跃读者和写者
87        while state.readers > 0 || state.writers > 0 {
88            self.write_cond.wait(&mut state);
89        }
90
91        state.waiting_writers -= 1;
92        state.writers += 1;
93
94        // 释放锁后执行写操作
95        drop(state);
96
97        let result = f();
98
99        let mut state = self.state.lock();
100        state.writers -= 1;
101
102        // 如果没有等待的写者了,清除写优先标志
103        if state.waiting_writers == 0 {
104            state.write_preferred = false;
105        }
106
107        // 通知所有等待的线程
108        self.write_cond.notify_all();
109        self.read_cond.notify_all();
110
111        result
112    }
113
114    /// 尝试获取读锁(带超时)
115    ///
116    /// # Returns
117    /// 成功返回 Some(result),超时返回 None
118    pub fn try_read_timeout<F, T>(&self, f: F, timeout: Duration) -> Option<T>
119    where
120        F: FnOnce() -> T,
121    {
122        let mut state = self.state.lock();
123
124        let deadline = std::time::Instant::now() + timeout;
125        while state.writers > 0 || (state.write_preferred && state.waiting_writers > 0) {
126            if self.read_cond.wait_until(&mut state, deadline).timed_out() {
127                return None;
128            }
129        }
130
131        state.readers += 1;
132        drop(state);
133
134        let result = f();
135
136        let mut state = self.state.lock();
137        state.readers -= 1;
138
139        if state.readers == 0 {
140            self.write_cond.notify_all();
141            self.read_cond.notify_all();
142        }
143
144        Some(result)
145    }
146
147    /// 尝试获取写锁(带超时)
148    ///
149    /// # Returns
150    /// 成功返回 Some(result),超时返回 None
151    pub fn try_write_timeout<F, T>(&self, f: F, timeout: Duration) -> Option<T>
152    where
153        F: FnOnce() -> T,
154    {
155        let mut state = self.state.lock();
156        state.waiting_writers += 1;
157        state.write_preferred = true;
158
159        let deadline = std::time::Instant::now() + timeout;
160        while state.readers > 0 || state.writers > 0 {
161            if self.write_cond.wait_until(&mut state, deadline).timed_out() {
162                state.waiting_writers -= 1;
163                if state.waiting_writers == 0 {
164                    state.write_preferred = false;
165                }
166                return None;
167            }
168        }
169
170        state.waiting_writers -= 1;
171        state.writers += 1;
172        drop(state);
173
174        let result = f();
175
176        let mut state = self.state.lock();
177        state.writers -= 1;
178
179        if state.waiting_writers == 0 {
180            state.write_preferred = false;
181        }
182
183        self.write_cond.notify_all();
184        self.read_cond.notify_all();
185
186        Some(result)
187    }
188
189    /// 获取锁状态(用于调试)
190    pub fn state(&self) -> LockStateInfo {
191        let state = self.state.lock();
192        LockStateInfo {
193            readers: state.readers,
194            writers: state.writers,
195            waiting_writers: state.waiting_writers,
196            write_preferred: state.write_preferred,
197        }
198    }
199}
200
201impl Default for ReadWriteLock {
202    fn default() -> Self {
203        Self::new()
204    }
205}
206
/// Point-in-time copy of a lock's internal counters (intended for debugging).
///
/// This is plain data, so it can be copied and compared freely; the default
/// value describes a lock with nothing active and nothing waiting.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct LockStateInfo {
    /// Number of active readers when the snapshot was taken.
    pub readers: i32,
    /// Number of active writers when the snapshot was taken.
    pub writers: i32,
    /// Number of writers queued for entry when the snapshot was taken.
    pub waiting_writers: i32,
    /// Value of the write-preference flag when the snapshot was taken.
    pub write_preferred: bool,
}
215
216/// RAII 读锁守卫
217#[allow(dead_code)]
218pub struct ReadGuard<'a> {
219    lock: &'a ReadWriteLock,
220}
221
222#[allow(dead_code)]
223impl<'a> ReadGuard<'a> {
224    pub fn new(lock: &'a ReadWriteLock) -> Self {
225        let mut state = lock.state.lock();
226        while state.writers > 0 || (state.write_preferred && state.waiting_writers > 0) {
227            lock.read_cond.wait(&mut state);
228        }
229        state.readers += 1;
230        Self { lock }
231    }
232}
233
234impl<'a> Drop for ReadGuard<'a> {
235    fn drop(&mut self) {
236        let mut state = self.lock.state.lock();
237        state.readers -= 1;
238        if state.readers == 0 {
239            self.lock.write_cond.notify_all();
240            self.lock.read_cond.notify_all();
241        }
242    }
243}
244
245/// RAII 写锁守卫
246#[allow(dead_code)]
247pub struct WriteGuard<'a> {
248    lock: &'a ReadWriteLock,
249}
250
251#[allow(dead_code)]
252impl<'a> WriteGuard<'a> {
253    pub fn new(lock: &'a ReadWriteLock) -> Self {
254        let mut state = lock.state.lock();
255        state.waiting_writers += 1;
256        state.write_preferred = true;
257
258        while state.readers > 0 || state.writers > 0 {
259            lock.write_cond.wait(&mut state);
260        }
261
262        state.waiting_writers -= 1;
263        state.writers += 1;
264        Self { lock }
265    }
266}
267
268impl<'a> Drop for WriteGuard<'a> {
269    fn drop(&mut self) {
270        let mut state = self.lock.state.lock();
271        state.writers -= 1;
272
273        if state.waiting_writers == 0 {
274            state.write_preferred = false;
275        }
276
277        self.lock.write_cond.notify_all();
278        self.lock.read_cond.notify_all();
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Barrier};
    use std::thread;

    /// Number of threads used by the multi-threaded tests.
    const WORKERS: usize = 10;

    /// Closures run and their return values are passed through.
    #[test]
    fn test_read_write_lock_basic() {
        let lock = ReadWriteLock::new();

        let result = lock.read(|| 42);
        assert_eq!(result, 42);

        let result = lock.write(|| 100);
        assert_eq!(result, 100);
    }

    /// Readers really overlap: every reader waits on a barrier *inside* its
    /// read section, and the barrier only opens once all `WORKERS` readers
    /// hold the lock simultaneously. (If reads were serialised this test
    /// would hang rather than pass by accident.)
    #[test]
    fn test_concurrent_reads() {
        let lock = Arc::new(ReadWriteLock::new());
        let barrier = Arc::new(Barrier::new(WORKERS));
        let entered = Arc::new(AtomicUsize::new(0));

        let handles: Vec<_> = (0..WORKERS)
            .map(|_| {
                let lock = Arc::clone(&lock);
                let barrier = Arc::clone(&barrier);
                let entered = Arc::clone(&entered);
                thread::spawn(move || {
                    lock.read(|| {
                        entered.fetch_add(1, Ordering::SeqCst);
                        barrier.wait();
                    });
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(entered.load(Ordering::SeqCst), WORKERS);
        assert_eq!(lock.state().readers, 0);
    }

    /// No writer ever observes another writer inside the write section.
    #[test]
    fn test_writes_are_exclusive() {
        let lock = Arc::new(ReadWriteLock::new());
        let inside = Arc::new(AtomicUsize::new(0));

        let handles: Vec<_> = (0..WORKERS)
            .map(|_| {
                let lock = Arc::clone(&lock);
                let inside = Arc::clone(&inside);
                thread::spawn(move || {
                    lock.write(|| {
                        let already_inside = inside.fetch_add(1, Ordering::SeqCst);
                        thread::yield_now();
                        inside.fetch_sub(1, Ordering::SeqCst);
                        already_inside
                    })
                })
            })
            .collect();

        for handle in handles {
            assert_eq!(handle.join().unwrap(), 0);
        }
        assert_eq!(lock.state().writers, 0);
    }

    /// A timed write attempt gives up while a reader is active and leaves the
    /// bookkeeping clean afterwards.
    #[test]
    fn test_write_timeout_while_reading() {
        let lock = ReadWriteLock::new();

        let outcome = lock.read(|| lock.try_write_timeout(|| (), Duration::from_millis(10)));
        assert!(outcome.is_none());

        let state = lock.state();
        assert_eq!(state.readers, 0);
        assert_eq!(state.writers, 0);
        assert_eq!(state.waiting_writers, 0);
        assert!(!state.write_preferred);
    }

    /// A timed read attempt gives up while a writer is active.
    #[test]
    fn test_read_timeout_while_writing() {
        let lock = ReadWriteLock::new();

        let outcome = lock.write(|| lock.try_read_timeout(|| (), Duration::from_millis(10)));
        assert!(outcome.is_none());
        assert_eq!(lock.state().writers, 0);
    }

    /// Timed attempts succeed immediately on an uncontended lock.
    #[test]
    fn test_timeouts_succeed_when_free() {
        let lock = ReadWriteLock::new();

        assert_eq!(lock.try_read_timeout(|| 1, Duration::from_millis(10)), Some(1));
        assert_eq!(lock.try_write_timeout(|| 2, Duration::from_millis(10)), Some(2));
    }

    /// The snapshot reflects an idle lock and sees holders while they are active.
    #[test]
    fn test_state_info() {
        let lock = ReadWriteLock::new();

        let state = lock.state();
        assert_eq!(state.readers, 0);
        assert_eq!(state.writers, 0);
        assert_eq!(state.waiting_writers, 0);
        assert!(!state.write_preferred);

        assert_eq!(lock.read(|| lock.state().readers), 1);
        assert_eq!(lock.write(|| lock.state().writers), 1);
    }
}