sortlock/
rwlock.rs

1use core::fmt::{self, Debug, Display, Formatter};
2
3#[cfg(feature = "std")]
4use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
5#[cfg(not(feature = "std"))]
6use spin::{RwLock, RwLockWriteGuard, RwLockReadGuard};
7
8
9use crate::{LockGroup, SortKey, SortableLock};
10
11/// A sortable lock that allows either exclusive write access or shared read access. 
12/// This is a sortable version of rust's `RwLock` type.
13///
14/// Locking looks a little different to `RwLock`, as this lock allows sorting with other locks
15/// through the use of `lock_all`. Locking for reading can be performed with `read` while locking
16/// for writing can be performed with `write`.
17/// ```
18/// use sortlock::{SortRwLock, LockGroup};
19///
20/// let lock = SortRwLock::new("some value");
21///
22/// let guard = lock.read().lock_all();
23/// println!("{}", *guard);
24/// ```
25/// ```
26/// use sortlock::{SortRwLock, LockGroup};
27///
28/// let lock = SortRwLock::new(1);
29///
30/// let mut guard = lock.write().lock_all();
31/// *guard += 1;
32/// assert_eq!(2, *guard);
33/// ```
34///
35/// With multiple locks this ensures that locks are always locked in the same order.
36/// This occurs regardless of whether the lock was locked for reading or writing.
37/// ```
38/// use sortlock::{SortRwLock, LockGroup};
39///
40/// let lock1 = SortRwLock::new(100);
41/// let lock2 = SortRwLock::new(200);
42///
43/// // Here lock1 is locked then lock2.
44/// let (guard1, mut guard2) = (lock1.read(), lock2.write()).lock_all();
45/// println!("{}", *guard1);
46/// *guard2 += 1;
47///
48/// // Unlock so we can lock again.
49/// drop(guard1);
50/// drop(guard2);
51///
52/// // Despite the order change the same is true here.
53/// let (guard2, mut guard1) = (lock2.read(), lock1.write()).lock_all();
54/// *guard1 += 1;
55/// println!("{}", *guard2);
56/// ```
57pub struct SortRwLock<T> {
58    /// The internal lock.
59    mutex: RwLock<T>,
60    /// The sort key for this lock.
61    key: SortKey,
62}
63
64impl <T> SortRwLock<T> {
65    /// Creates a new `SortRwLock`.
66    ///
67    /// - `value` - The value of the lock.
68    pub fn new(value: T) -> Self {
69        Self {
70            mutex: RwLock::new(value),
71            key: SortKey::new()
72        }
73    }
74
75    /// Requests to lock this lock for reading.
76    /// This method returns a guard which can be used with `lock_all` to perform a sorted lock.
77    ///
78    /// # Panicking
79    /// The guard will panic when locked if this lock becomes poisoned.
80    pub fn read(&self) -> SortReadGuard<T> {
81        SortReadGuard {
82            lock: self
83        }
84    }
85    
86    /// Requests to lock this lock for writing.
87    /// This method returns a guard which can be used with `lock_all` to perform a sorted lock.
88    ///
89    /// # Panicking
90    /// The guard will panic when locked if this lock becomes poisoned.
91    pub fn write(&self) -> SortWriteGuard<T> {
92        SortWriteGuard {
93            lock: self
94        }
95    }
96}
97
98impl <T: Debug> Debug for SortRwLock<T> {
99    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
100        self.read().lock_all().fmt(f)
101    }
102}
103
104impl <T: Display> Display for SortRwLock<T> {
105    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
106        self.read().lock_all().fmt(f)
107    }
108}
109
110impl <T: Default> Default for SortRwLock<T> {
111    fn default() -> Self {
112        Self::new(T::default())
113    }
114}
115
116/// A read guard for a `SortRwLock`.
117pub struct SortReadGuard<'l, T> {
118    /// The lock this request references.
119    lock: &'l SortRwLock<T>,
120}
121
122impl <'l, T> SortableLock for SortReadGuard<'l, T> {
123    type Guard = RwLockReadGuard<'l, T>;
124
125    fn sort_key(&self) -> SortKey {
126        self.lock.key
127    }
128
129    #[cfg(feature = "std")]
130    fn lock_presorted(&self) -> Self::Guard {
131        self.lock.mutex.read()
132            .expect("Failed to lock mutex.")
133    }
134    
135    #[cfg(not(feature = "std"))]
136    fn lock_presorted(&self) -> Self::Guard {
137        self.lock.mutex.read()
138    }
139}
140
141/// A write guard for a `SortRwLock`.
142pub struct SortWriteGuard<'l, T> {
143    /// The lock this request references.
144    lock: &'l SortRwLock<T>,
145}
146
147impl <'l, T> SortableLock for SortWriteGuard<'l, T> {
148    type Guard = RwLockWriteGuard<'l, T>;
149
150    fn sort_key(&self) -> SortKey {
151        self.lock.key
152    }
153
154    #[cfg(feature = "std")]
155    fn lock_presorted(&self) -> Self::Guard {
156        self.lock.mutex.write()
157            .expect("Failed to lock mutex.")
158    }
159    
160    #[cfg(not(feature = "std"))]
161    fn lock_presorted(&self) -> Self::Guard {
162        self.lock.mutex.write()
163    }
164}
165
#[cfg(test)]
mod tests {
    use std::{any::Any, sync::Arc, thread};

    use crate::{LockGroup, SortRwLock};

    /// Locking two locks as a group (one for reading, one for writing) yields guards that
    /// each see their own lock's value, and the write guard can mutate its value.
    #[test]
    fn test_lock2() {
        let lock1 = SortRwLock::new(1);
        let lock2 = SortRwLock::new(2);

        let (guard1, mut guard2) = (lock1.read(), lock2.write()).lock_all();
        assert_eq!(1, *guard1);
        assert_eq!(2, *guard2);

        *guard2 += 1;
        assert_eq!(3, *guard2);
    }

    /// `Display` and `Debug` of an unlocked `SortRwLock` format the inner value directly.
    #[test]
    fn test_format_unlocked() {
        let lock = SortRwLock::new(42);
        assert_eq!("42", format!("{}", lock));
        assert_eq!("42", format!("{:?}", lock));
    }

    /// Two threads repeatedly lock the same pair of locks, listing them in opposite order.
    /// Without sorting this ordering pattern can deadlock; with sorting both threads must
    /// finish, and each counter must have been incremented exactly `count` times.
    #[test]
    fn test_deadlock() -> Result<(), Box<dyn Any + Send + 'static>> {
        let lock1 = Arc::new(SortRwLock::new(0));
        let lock2 = Arc::new(SortRwLock::new(0));

        let lock1b = Arc::clone(&lock1);
        let lock2b = Arc::clone(&lock2);

        let lock1c = Arc::clone(&lock1);
        let lock2c = Arc::clone(&lock2);

        let count = 1_000_000;

        let thread1 = thread::spawn(move || {
            for _ in 0..count {
                let (mut guard1, guard2) = (lock1.write(), lock2.read()).lock_all();

                *guard1 += 1;

                drop(guard2);
            }
        });
        let thread2 = thread::spawn(move || {
            for _ in 0..count {
                let (mut guard2, guard1) = (lock2b.write(), lock1b.read()).lock_all();

                *guard2 += 1;

                drop(guard1);
            }
        });
        thread1.join()?;
        thread2.join()?;

        assert_eq!(count, *lock1c.read().lock_all());
        assert_eq!(count, *lock2c.read().lock_all());

        Ok(())
    }
}
222