1use core::fmt::{self, Debug, Display, Formatter};
2
3#[cfg(feature = "std")]
4use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
5#[cfg(not(feature = "std"))]
6use spin::{RwLock, RwLockWriteGuard, RwLockReadGuard};
7
8
9use crate::{LockGroup, SortKey, SortableLock};
10
11pub struct SortRwLock<T> {
58 mutex: RwLock<T>,
60 key: SortKey,
62}
63
64impl <T> SortRwLock<T> {
65 pub fn new(value: T) -> Self {
69 Self {
70 mutex: RwLock::new(value),
71 key: SortKey::new()
72 }
73 }
74
75 pub fn read(&self) -> SortReadGuard<T> {
81 SortReadGuard {
82 lock: self
83 }
84 }
85
86 pub fn write(&self) -> SortWriteGuard<T> {
92 SortWriteGuard {
93 lock: self
94 }
95 }
96}
97
98impl <T: Debug> Debug for SortRwLock<T> {
99 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
100 self.read().lock_all().fmt(f)
101 }
102}
103
104impl <T: Display> Display for SortRwLock<T> {
105 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
106 self.read().lock_all().fmt(f)
107 }
108}
109
110impl <T: Default> Default for SortRwLock<T> {
111 fn default() -> Self {
112 Self::new(T::default())
113 }
114}
115
116pub struct SortReadGuard<'l, T> {
118 lock: &'l SortRwLock<T>,
120}
121
122impl <'l, T> SortableLock for SortReadGuard<'l, T> {
123 type Guard = RwLockReadGuard<'l, T>;
124
125 fn sort_key(&self) -> SortKey {
126 self.lock.key
127 }
128
129 #[cfg(feature = "std")]
130 fn lock_presorted(&self) -> Self::Guard {
131 self.lock.mutex.read()
132 .expect("Failed to lock mutex.")
133 }
134
135 #[cfg(not(feature = "std"))]
136 fn lock_presorted(&self) -> Self::Guard {
137 self.lock.mutex.read()
138 }
139}
140
141pub struct SortWriteGuard<'l, T> {
143 lock: &'l SortRwLock<T>,
145}
146
147impl <'l, T> SortableLock for SortWriteGuard<'l, T> {
148 type Guard = RwLockWriteGuard<'l, T>;
149
150 fn sort_key(&self) -> SortKey {
151 self.lock.key
152 }
153
154 #[cfg(feature = "std")]
155 fn lock_presorted(&self) -> Self::Guard {
156 self.lock.mutex.write()
157 .expect("Failed to lock mutex.")
158 }
159
160 #[cfg(not(feature = "std"))]
161 fn lock_presorted(&self) -> Self::Guard {
162 self.lock.mutex.write()
163 }
164}
165
#[cfg(test)]
mod tests {
    use std::{any::Any, sync::Arc, thread};

    use crate::{LockGroup, SortRwLock};

    /// A read request and a write request can be locked together as a tuple.
    #[test]
    fn test_lock2() {
        let first = SortRwLock::new(1);
        let second = SortRwLock::new(2);

        let (shared, exclusive) = (first.read(), second.write()).lock_all();

        println!("{} {}", shared, exclusive);
    }

    /// Two threads repeatedly lock the same two locks, each listing them in the
    /// opposite order (write one, read the other). The test completing shows
    /// the threads do not deadlock, and the final counts show every increment
    /// landed.
    #[test]
    fn test_deadlock() -> Result<(), Box<dyn Any + Send + 'static>> {
        const ROUNDS: i32 = 1_000_000;

        /// Spawns a thread that, `ROUNDS` times, write-locks `target` together
        /// with a read lock on `other` and increments `target`.
        fn spawn_worker(
            target: Arc<SortRwLock<i32>>,
            other: Arc<SortRwLock<i32>>,
        ) -> thread::JoinHandle<()> {
            thread::spawn(move || {
                for _ in 0..ROUNDS {
                    let (mut written, read) = (target.write(), other.read()).lock_all();
                    *written += 1;
                    drop(read);
                }
            })
        }

        let left = Arc::new(SortRwLock::new(0));
        let right = Arc::new(SortRwLock::new(0));

        let forward = spawn_worker(Arc::clone(&left), Arc::clone(&right));
        let backward = spawn_worker(Arc::clone(&right), Arc::clone(&left));

        forward.join()?;
        backward.join()?;

        assert_eq!(ROUNDS, *left.read().lock_all());
        assert_eq!(ROUNDS, *right.read().lock_all());

        Ok(())
    }
}
222