use crate::sync::{futex_wait_fast, NotSend};
use core::cell::UnsafeCell;
use core::fmt;
use core::ops::{Deref, DerefMut};
use core::ptr::NonNull;
use core::sync::atomic::AtomicU32;
use core::sync::atomic::Ordering::{Acquire, Relaxed, Release};
use rusl::futex::futex_wake;

/// A reader-writer lock that parks waiting threads on futexes.
/// The API mirrors `std::sync::RwLock`, but locking never poisons:
/// guards are returned directly rather than wrapped in a `Result`.
pub struct RwLock<T: ?Sized> {
    inner: InnerLock,
    data: UnsafeCell<T>,
}

unsafe impl<T: ?Sized + Send> Send for RwLock<T> {}
unsafe impl<T: ?Sized + Send + Sync> Sync for RwLock<T> {}
/// RAII guard that releases the shared read access of a `RwLock` when dropped.
#[must_use = "if unused the RwLock will immediately unlock"]
#[clippy::has_significant_drop]
pub struct RwLockReadGuard<'a, T: ?Sized + 'a> {
    // `NonNull` is used instead of `&'a T` to avoid `noalias` violations:
    // the data is only guaranteed immutable until the guard drops, not for
    // the whole lifetime of a guard passed by value.
    data: NonNull<T>,
    inner_lock: &'a InnerLock,
    // Opts the guard out of `Send`.
    _not_send: NotSend,
}

unsafe impl<T: ?Sized + Sync> Sync for RwLockReadGuard<'_, T> {}

impl<'rwlock, T: ?Sized> RwLockReadGuard<'rwlock, T> {
    /// SAFETY: may only be called after `lock.inner.read()` or a successful
    /// `lock.inner.try_read()` on the same lock, from the same thread.
    unsafe fn new(lock: &'rwlock RwLock<T>) -> RwLockReadGuard<'rwlock, T> {
        RwLockReadGuard {
            data: NonNull::new_unchecked(lock.data.get()),
            inner_lock: &lock.inner,
            _not_send: NotSend::new(),
        }
    }
}

/// RAII guard that releases the exclusive write access of a `RwLock` when dropped.
#[must_use = "if unused the RwLock will immediately unlock"]
#[clippy::has_significant_drop]
pub struct RwLockWriteGuard<'a, T: ?Sized + 'a> {
    lock: &'a RwLock<T>,
    // Opts the guard out of `Send`.
    _not_send: NotSend,
}

unsafe impl<T: ?Sized + Sync> Sync for RwLockWriteGuard<'_, T> {}

impl<'rwlock, T: ?Sized> RwLockWriteGuard<'rwlock, T> {
    /// SAFETY: may only be called after `lock.inner.write()` or a successful
    /// `lock.inner.try_write()` on the same lock, from the same thread.
    unsafe fn new(lock: &'rwlock RwLock<T>) -> RwLockWriteGuard<'rwlock, T> {
        RwLockWriteGuard {
            lock,
            _not_send: NotSend::new(),
        }
    }
}

impl<T> RwLock<T> {
    #[inline]
    pub const fn new(t: T) -> RwLock<T> {
        RwLock {
            inner: InnerLock::new(),
            data: UnsafeCell::new(t),
        }
    }
}

impl<T: ?Sized> RwLock<T> {
    #[inline]
    pub fn read(&self) -> RwLockReadGuard<'_, T> {
        unsafe {
            self.inner.read();
            RwLockReadGuard::new(self)
        }
    }

    #[inline]
    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
        unsafe { self.inner.try_read().then(|| RwLockReadGuard::new(self)) }
    }

    #[inline]
    pub fn write(&self) -> RwLockWriteGuard<'_, T> {
        unsafe {
            self.inner.write();
            RwLockWriteGuard::new(self)
        }
    }

    #[inline]
    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
        unsafe { self.inner.try_write().then(|| RwLockWriteGuard::new(self)) }
    }

    pub fn into_inner(self) -> T
    where
        T: Sized,
    {
        self.data.into_inner()
    }

    pub fn get_mut(&mut self) -> &mut T {
        self.data.get_mut()
    }
}
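
// A minimal usage sketch (illustrative only; see also the tests at the
// bottom of this file). Guards release their half of the lock when
// dropped, so scoping bounds the critical section:
//
//     let lock = RwLock::new(0u32);
//     {
//         let mut w = lock.write(); // exclusive access
//         *w += 1;
//     } // write lock released when `w` drops
//     let (r1, r2) = (lock.read(), lock.read()); // readers may coexist
//     assert_eq!((*r1, *r2), (1, 1));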

struct InnerLock {
    // Low 30 bits count readers (all-ones meaning write-locked); bits 30
    // and 31 flag waiting readers and writers. See the constants below.
    state: AtomicU32,
    // Notification counter for waiting writers; bumped before every writer
    // wakeup so a writer re-checking it never misses a notification.
    writer_notify: AtomicU32,
}
const READ_LOCKED: u32 = 1;
const MASK: u32 = (1 << 30) - 1;
const WRITE_LOCKED: u32 = MASK;
const MAX_READERS: u32 = MASK - 1;
const READERS_WAITING: u32 = 1 << 30;
const WRITERS_WAITING: u32 = 1 << 31;
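
// Layout of `state`, as implied by the constants above:
//
//     bit 31: WRITERS_WAITING    bit 30: READERS_WAITING
//     bits 0..=29 (MASK): reader count, all-ones meaning write-locked
//
// Compile-time sanity checks, added here for illustration: the waiting bits
// must not overlap the reader-count mask, and `WRITE_LOCKED` must be the
// all-ones pattern of that mask.
const _: () = assert!(MASK & (READERS_WAITING | WRITERS_WAITING) == 0);
const _: () = assert!(WRITE_LOCKED == MASK && MAX_READERS == MASK - 1);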

#[inline]
fn is_unlocked(state: u32) -> bool {
    state & MASK == 0
}

#[inline]
fn is_write_locked(state: u32) -> bool {
    state & MASK == WRITE_LOCKED
}

#[inline]
fn has_readers_waiting(state: u32) -> bool {
    state & READERS_WAITING != 0
}

#[inline]
fn has_writers_waiting(state: u32) -> bool {
    state & WRITERS_WAITING != 0
}

#[inline]
fn is_read_lockable(state: u32) -> bool {
    // Read-locking is refused not only while write-locked or at the reader
    // limit, but also while any thread is waiting, so that waiting writers
    // are not starved by a steady stream of new readers.
    state & MASK < MAX_READERS && !has_readers_waiting(state) && !has_writers_waiting(state)
}

#[inline]
fn has_reached_max_readers(state: u32) -> bool {
    state & MASK == MAX_READERS
}

impl InnerLock {
    #[inline]
    pub const fn new() -> Self {
        Self {
            state: AtomicU32::new(0),
            writer_notify: AtomicU32::new(0),
        }
    }

    #[inline]
    pub fn try_read(&self) -> bool {
        self.state
            .fetch_update(Acquire, Relaxed, |s| {
                is_read_lockable(s).then_some(s + READ_LOCKED)
            })
            .is_ok()
    }

    #[inline]
    pub fn read(&self) {
        let state = self.state.load(Relaxed);
        if !is_read_lockable(state)
            || self
                .state
                .compare_exchange_weak(state, state + READ_LOCKED, Acquire, Relaxed)
                .is_err()
        {
            self.read_contended();
        }
    }

    #[inline]
    pub unsafe fn read_unlock(&self) {
        let state = self.state.fetch_sub(READ_LOCKED, Release) - READ_LOCKED;

        // Readers can only be waiting on a read-locked RwLock if a writer
        // is waiting too, since readers never block each other.
        debug_assert!(!has_readers_waiting(state) || has_writers_waiting(state));

        // Wake a writer if we were the last reader and a writer is waiting.
        if is_unlocked(state) && has_writers_waiting(state) {
            self.wake_writer_or_readers(state);
        }
    }

    #[cold]
    fn read_contended(&self) {
        let mut state = self.spin_read();

        loop {
            // If the lock is read-lockable, try to read-lock it.
            if is_read_lockable(state) {
                match self
                    .state
                    .compare_exchange_weak(state, state + READ_LOCKED, Acquire, Relaxed)
                {
                    Ok(_) => return,
                    Err(s) => {
                        state = s;
                        continue;
                    }
                }
            }

            assert!(
                !has_reached_max_readers(state),
                "too many active read locks on RwLock"
            );

            // Make sure the readers-waiting bit is set before going to sleep.
            if !has_readers_waiting(state) {
                if let Err(s) =
                    self.state
                        .compare_exchange(state, state | READERS_WAITING, Relaxed, Relaxed)
                {
                    state = s;
                    continue;
                }
            }

            // Wait for the state to change.
            futex_wait_fast(&self.state, state | READERS_WAITING);

            // Spin again after waking up.
            state = self.spin_read();
        }
    }

    #[inline]
    pub fn try_write(&self) -> bool {
        self.state
            .fetch_update(Acquire, Relaxed, |s| {
                is_unlocked(s).then_some(s + WRITE_LOCKED)
            })
            .is_ok()
    }

    #[inline]
    pub fn write(&self) {
        if self
            .state
            .compare_exchange_weak(0, WRITE_LOCKED, Acquire, Relaxed)
            .is_err()
        {
            self.write_contended();
        }
    }

    #[inline]
    pub unsafe fn write_unlock(&self) {
        let state = self.state.fetch_sub(WRITE_LOCKED, Release) - WRITE_LOCKED;

        debug_assert!(is_unlocked(state));

        if has_writers_waiting(state) || has_readers_waiting(state) {
            self.wake_writer_or_readers(state);
        }
    }

    #[cold]
    fn write_contended(&self) {
        let mut state = self.spin_write();

        let mut other_writers_waiting = 0;

        loop {
            // If the lock is unlocked, try to write-lock it.
            if is_unlocked(state) {
                match self.state.compare_exchange_weak(
                    state,
                    state | WRITE_LOCKED | other_writers_waiting,
                    Acquire,
                    Relaxed,
                ) {
                    Ok(_) => return,
                    Err(s) => {
                        state = s;
                        continue;
                    }
                }
            }

            // Make sure the writers-waiting bit is set before going to sleep.
            if !has_writers_waiting(state) {
                if let Err(s) =
                    self.state
                        .compare_exchange(state, state | WRITERS_WAITING, Relaxed, Relaxed)
                {
                    state = s;
                    continue;
                }
            }

            // Other writers might be waiting now too, so keep the bit set
            // once we manage to lock the RwLock.
            other_writers_waiting = WRITERS_WAITING;

            // Examine the notification counter before re-checking `state`,
            // so that a notification between the two loads is not missed.
            let seq = self.writer_notify.load(Acquire);

            // Don't go to sleep if the lock has become available, or if the
            // writers-waiting bit was cleared in the meantime.
            state = self.state.load(Relaxed);
            if is_unlocked(state) || !has_writers_waiting(state) {
                continue;
            }

            // Wait for a writer notification.
            futex_wait_fast(&self.writer_notify, seq);

            // Spin again after waking up.
            state = self.spin_write();
        }
    }

    /// Wakes up waiting threads after unlocking.
    ///
    /// If both readers and writers are waiting, only one writer is woken,
    /// falling back to waking the readers if no writer was actually asleep.
    #[cold]
    fn wake_writer_or_readers(&self, mut state: u32) {
        assert!(is_unlocked(state));

        // If only writers are waiting, wake exactly one of them.
        if state == WRITERS_WAITING {
            match self.state.compare_exchange(state, 0, Relaxed, Relaxed) {
                Ok(_) => {
                    self.wake_writer();
                    return;
                }
                Err(s) => {
                    // Readers might have started waiting as well;
                    // fall through to the checks below.
                    state = s;
                }
            }
        }

        // If both readers and writers are waiting, leave the readers waiting
        // and wake only one writer.
        if state == READERS_WAITING + WRITERS_WAITING {
            if self
                .state
                .compare_exchange(state, READERS_WAITING, Relaxed, Relaxed)
                .is_err()
            {
                // The lock got locked again; no longer our responsibility.
                return;
            }
            if self.wake_writer() {
                return;
            }
            // No writer was actually blocked on the futex, so fall back to
            // waking up the readers instead.
            state = READERS_WAITING;
        }

        // If readers are waiting, wake them all up.
        if state == READERS_WAITING
            && self
                .state
                .compare_exchange(state, 0, Relaxed, Relaxed)
                .is_ok()
        {
            let _ = futex_wake(&self.state, i32::MAX);
        }
    }

    /// Wakes one writer, returning `true` if a writer blocked on the futex
    /// was woken. A `false` result may still have notified a writer that had
    /// read the bumped counter just before going to sleep.
    fn wake_writer(&self) -> bool {
        self.writer_notify.fetch_add(1, Release);
        futex_wake(&self.writer_notify, 1).unwrap() != 0
    }

    #[inline]
    fn spin_until(&self, f: impl Fn(u32) -> bool) -> u32 {
        // Bounded spin budget before the caller falls back to a futex wait.
        let mut spin = 100;
        loop {
            let state = self.state.load(Relaxed);
            if f(state) || spin == 0 {
                return state;
            }
            core::hint::spin_loop();
            spin -= 1;
        }
    }

    #[inline]
    fn spin_write(&self) -> u32 {
        // Stop spinning when the lock is unlocked, or when writers are
        // already waiting, to keep things somewhat fair.
        self.spin_until(|state| is_unlocked(state) || has_writers_waiting(state))
    }

    #[inline]
    fn spin_read(&self) -> u32 {
        // Stop spinning when the lock is unlocked or read-locked, or when
        // there are waiting threads.
        self.spin_until(|state| {
            !is_write_locked(state) || has_readers_waiting(state) || has_writers_waiting(state)
        })
    }
}

impl<T: ?Sized + fmt::Debug> fmt::Debug for RwLockReadGuard<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        (**self).fmt(f)
    }
}

impl<T: ?Sized + fmt::Display> fmt::Display for RwLockReadGuard<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        (**self).fmt(f)
    }
}

impl<T: ?Sized + fmt::Debug> fmt::Debug for RwLockWriteGuard<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        (**self).fmt(f)
    }
}

impl<T: ?Sized + fmt::Display> fmt::Display for RwLockWriteGuard<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        (**self).fmt(f)
    }
}

impl<T: ?Sized> Deref for RwLockReadGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        // SAFETY: the read lock is held for the guard's whole lifetime.
        unsafe { self.data.as_ref() }
    }
}

impl<T: ?Sized> Deref for RwLockWriteGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        // SAFETY: the write lock is held for the guard's whole lifetime.
        unsafe { &*self.lock.data.get() }
    }
}

impl<T: ?Sized> DerefMut for RwLockWriteGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut T {
        // SAFETY: the write lock is held for the guard's whole lifetime.
        unsafe { &mut *self.lock.data.get() }
    }
}

impl<T: ?Sized> Drop for RwLockReadGuard<'_, T> {
    fn drop(&mut self) {
        // SAFETY: this guard proves a read lock is held.
        unsafe {
            self.inner_lock.read_unlock();
        }
    }
}

impl<T: ?Sized> Drop for RwLockWriteGuard<'_, T> {
    fn drop(&mut self) {
        // SAFETY: this guard proves the write lock is held.
        unsafe {
            self.lock.inner.write_unlock();
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::sync::RwLock;
    use core::time::Duration;

    #[test]
    fn can_lock() {
        let rw = std::sync::Arc::new(super::RwLock::new(0));
        let rw_c = rw.clone();
        let mut guard = rw.write();
        // The reader blocks until the write guard is dropped below.
        let res = std::thread::spawn(move || *rw_c.read());
        *guard = 15;
        drop(guard);
        let thread_res = res.join().unwrap();
        assert_eq!(15, thread_res);
    }

    #[test]
    fn can_rw_contended() {
        const NUM_THREADS: usize = 32;
        let count = std::sync::Arc::new(RwLock::new(0));
        let mut handles = std::vec::Vec::new();
        for _i in 0..NUM_THREADS {
            let count_c = count.clone();
            let handle = std::thread::spawn(move || {
                // Hold the write lock over a sleep to force contention.
                let mut w_guard = count_c.write();
                let orig = *w_guard;
                std::thread::sleep(Duration::from_millis(1));
                *w_guard += 1;
                drop(w_guard);
                std::thread::sleep(Duration::from_millis(1));
                let r_guard = count_c.read();
                std::thread::sleep(Duration::from_millis(1));
                // At least our own increment must be visible by now.
                assert!(*r_guard > orig);
            });
            handles.push(handle);
        }
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(NUM_THREADS, *count.read());
    }

    #[test]
    fn can_try_rw_single_thread_contended() {
        let rw = std::sync::Arc::new(super::RwLock::new(0));
        let rw_c = rw.clone();
        assert_eq!(0, *rw_c.try_read().unwrap());
        let r_guard = rw.read();
        // Shared access can coexist; exclusive access cannot.
        assert_eq!(0, *rw_c.try_read().unwrap());
        assert!(rw_c.try_write().is_none());
        drop(r_guard);
        assert_eq!(0, *rw_c.try_write().unwrap());
        let _w_guard = rw.write();
        assert!(rw_c.try_read().is_none());
    }
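
    // An extra sanity check, not part of the original suite: `get_mut` and
    // `into_inner` bypass the lock entirely, since they require exclusive
    // ownership of the `RwLock`.
    #[test]
    fn can_access_exclusively_without_locking() {
        let mut rw = super::RwLock::new(5);
        *rw.get_mut() += 1;
        assert_eq!(6, *rw.read());
        assert_eq!(6, rw.into_inner());
    }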
}