1#![allow(dead_code)]
2
3use super::cell::CausalCell;
14use super::waker::AtomicWaker;
15use std::{
16 sync::atomic::{AtomicPtr, AtomicUsize},
17 thread,
18};
19
20use std::fmt;
21use std::ptr::{self, NonNull};
22use std::sync::atomic::Ordering::{self, AcqRel, Acquire, Relaxed, Release};
23use std::sync::Arc;
24use std::task::Poll::{Pending, Ready};
25use std::task::{Context, Poll};
26use std::usize;
27
28pub(crate) struct Semaphore {
30 state: AtomicUsize,
33
34 head: CausalCell<NonNull<WaiterNode>>,
36
37 rx_lock: AtomicUsize,
39
40 stub: Box<WaiterNode>,
42}
43
44#[derive(Debug)]
56pub(crate) struct Permit {
57 waiter: Option<Arc<WaiterNode>>,
58 state: PermitState,
59}
60
/// Error produced by `Permit::poll_acquire` once the semaphore is closed.
#[derive(Debug)]
pub(crate) struct AcquireError(());

/// Error produced by `Permit::try_acquire`; see `is_closed` / `is_no_permits`.
#[derive(Debug)]
pub(crate) struct TryAcquireError {
    kind: ErrorKind,
}

/// Reason a non-blocking acquisition failed.
#[derive(Debug)]
enum ErrorKind {
    /// The semaphore has been closed.
    Closed,
    /// No permit could be taken at the time of the attempt.
    NoPermits,
}
76
77#[derive(Debug)]
79struct WaiterNode {
80 state: AtomicUsize,
84
85 waker: AtomicWaker,
87
88 next: AtomicPtr<WaiterNode>,
90}
91
/// Decoded view of `Semaphore::state`.
///
/// If `NUM_FLAG` is set, the bits above `NUM_SHIFT` count available permits.
/// Otherwise the word is the address of the wait-queue tail (the stub node
/// when the queue is empty). `CLOSED_FLAG` may be set in either form.
#[derive(Clone, Copy)]
struct SemState(usize);
104
/// Progress of a `Permit`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum PermitState {
    /// Holds nothing and is not queued.
    Idle,

    /// Its waiter node has been handed to the semaphore; a permit may be
    /// assigned to it later.
    Waiting,

    /// Holds one permit, which must be released or forgotten.
    Acquired,
}
118
/// State of a `WaiterNode`, stored as its `usize` discriminant in an atomic.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(usize)]
enum NodeState {
    /// Not waiting and holding no assigned permit.
    Idle = 0,

    /// In the queue, but the owner has withdrawn interest; `notify` resets it
    /// to `Idle` without spending a permit on it.
    Queued = 1,

    /// In the queue (or being pushed) and waiting for a permit.
    QueuedWaiting = 2,

    /// A permit has been handed to the node and not yet consumed.
    Assigned = 3,

    /// The semaphore was closed while the node was waiting.
    Closed = 4,
}
146
147impl Semaphore {
150 pub(crate) fn new(permits: usize) -> Semaphore {
156 let stub = Box::new(WaiterNode::new());
157 let ptr = NonNull::new(&*stub as *const _ as *mut _).unwrap();
158
159 debug_assert!(ptr.as_ptr() as usize & NUM_FLAG == 0);
161
162 let state = SemState::new(permits, &stub);
163
164 Semaphore {
165 state: AtomicUsize::new(state.to_usize()),
166 head: CausalCell::new(ptr),
167 rx_lock: AtomicUsize::new(0),
168 stub,
169 }
170 }
171
172 pub(crate) fn available_permits(&self) -> usize {
174 let curr = SemState::load(&self.state, Acquire);
175 curr.available_permits()
176 }
177
178 fn poll_permit(
180 &self,
181 mut permit: Option<(&mut Context<'_>, &mut Permit)>,
182 ) -> Poll<Result<(), AcquireError>> {
183 let mut curr = SemState::load(&self.state, Acquire);
185
186 let mut maybe_strong: Option<NonNull<WaiterNode>> = None;
190
191 macro_rules! undo_strong {
192 () => {
193 if let Some(waiter) = maybe_strong {
194 let waiter = unsafe { Arc::from_raw(waiter.as_ptr()) };
199 waiter.revert_to_idle();
200 }
201 };
202 }
203
204 loop {
205 let mut next = curr;
206
207 if curr.is_closed() {
208 undo_strong!();
209 return Ready(Err(AcquireError::closed()));
210 }
211
212 if !next.acquire_permit(&self.stub) {
213 debug_assert!(curr.waiter().is_some());
214
215 if maybe_strong.is_none() {
216 if let Some((ref mut cx, ref mut permit)) = permit {
217 let waiter = permit
219 .waiter
220 .get_or_insert_with(|| Arc::new(WaiterNode::new()));
221
222 waiter.register(cx);
223
224 if !waiter.to_queued_waiting() {
225 return Pending;
228 }
229
230 maybe_strong = Some(WaiterNode::into_non_null(waiter.clone()));
231 } else {
232 return Pending;
235 }
236 }
237
238 next.set_waiter(maybe_strong.unwrap());
239 }
240
241 debug_assert_ne!(curr.0, 0);
242 debug_assert_ne!(next.0, 0);
243
244 match next.compare_exchange(&self.state, curr, AcqRel, Acquire) {
245 Ok(_) => {
246 match curr.waiter() {
247 Some(prev_waiter) => {
248 let waiter = maybe_strong.unwrap();
249
250 unsafe {
252 prev_waiter.as_ref().next.store(waiter.as_ptr(), Release);
253 }
254
255 return Pending;
256 }
257 None => {
258 undo_strong!();
259
260 return Ready(Ok(()));
261 }
262 }
263 }
264 Err(actual) => {
265 curr = actual;
266 }
267 }
268 }
269 }
270
271 pub(crate) fn close(&self) {
274 let prev = self.rx_lock.fetch_or(1, AcqRel);
276
277 if prev != 0 {
278 return;
281 }
282
283 self.add_permits_locked(0, true);
284 }
285
286 pub(crate) fn add_permits(&self, n: usize) {
288 if n == 0 {
289 return;
290 }
291
292 let prev = self.rx_lock.fetch_add(n << 1, AcqRel);
295
296 if prev != 0 {
297 return;
300 }
301
302 self.add_permits_locked(n, false);
303 }
304
305 fn add_permits_locked(&self, mut rem: usize, mut closed: bool) {
306 while rem > 0 || closed {
307 if closed {
308 SemState::fetch_set_closed(&self.state, AcqRel);
309 }
310
311 self.add_permits_locked2(rem, closed);
313
314 let n = rem << 1;
315
316 let actual = if closed {
317 let actual = self.rx_lock.fetch_sub(n | 1, AcqRel);
318 closed = false;
319 actual
320 } else {
321 let actual = self.rx_lock.fetch_sub(n, AcqRel);
322 closed = actual & 1 == 1;
323 actual
324 };
325
326 rem = (actual >> 1) - rem;
327 }
328 }
329
330 fn add_permits_locked2(&self, mut n: usize, closed: bool) {
335 while n > 0 || closed {
336 let waiter = match self.pop(n, closed) {
337 Some(waiter) => waiter,
338 None => {
339 return;
340 }
341 };
342
343 if waiter.notify(closed) {
344 n = n.saturating_sub(1);
345 }
346 }
347 }
348
349 fn pop(&self, rem: usize, closed: bool) -> Option<Arc<WaiterNode>> {
355 'outer: loop {
356 unsafe {
357 let mut head = self.head.with(|head| *head);
358 let mut next_ptr = head.as_ref().next.load(Acquire);
359
360 let stub = self.stub();
361
362 if head == stub {
363 let next = match NonNull::new(next_ptr) {
364 Some(next) => next,
365 None => {
366 let mut curr = SemState::load(&self.state, Acquire);
384
385 loop {
386 if curr.has_waiter(&self.stub) {
387 thread::yield_now();
389 continue 'outer;
390 }
391
392 if rem == 0 {
397 debug_assert!(curr.is_closed(), "state = {:?}", curr);
398 return None;
399 }
400
401 let mut next = curr;
402 next.release_permits(rem, &self.stub);
403
404 match next.compare_exchange(&self.state, curr, AcqRel, Acquire) {
405 Ok(_) => return None,
406 Err(actual) => {
407 curr = actual;
408 }
409 }
410 }
411 }
412 };
413
414 self.head.with_mut(|head| *head = next);
415 head = next;
416 next_ptr = next.as_ref().next.load(Acquire);
417 }
418
419 if let Some(next) = NonNull::new(next_ptr) {
420 self.head.with_mut(|head| *head = next);
421
422 return Some(Arc::from_raw(head.as_ptr()));
423 }
424
425 let state = SemState::load(&self.state, Acquire);
426
427 let tail = state.waiter().unwrap();
429
430 if tail != head {
431 thread::yield_now();
433 continue 'outer;
434 }
435
436 self.push_stub(closed);
437
438 next_ptr = head.as_ref().next.load(Acquire);
439
440 if let Some(next) = NonNull::new(next_ptr) {
441 self.head.with_mut(|head| *head = next);
442
443 return Some(Arc::from_raw(head.as_ptr()));
444 }
445
446 thread::yield_now();
448 }
449 }
450 }
451
452 unsafe fn push_stub(&self, closed: bool) {
453 let stub = self.stub();
454
455 stub.as_ref().next.store(ptr::null_mut(), Relaxed);
459
460 let prev = SemState::new_ptr(stub, closed).swap(&self.state, AcqRel);
464
465 debug_assert_eq!(closed, prev.is_closed());
466
467 let prev = prev.waiter().unwrap();
470
471 debug_assert_ne!(prev, stub);
473
474 prev.as_ref().next.store(stub.as_ptr(), Release);
476 }
477
478 fn stub(&self) -> NonNull<WaiterNode> {
479 unsafe { NonNull::new_unchecked(&*self.stub as *const _ as *mut _) }
480 }
481}
482
483impl fmt::Debug for Semaphore {
484 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
485 fmt.debug_struct("Semaphore")
486 .field("state", &SemState::load(&self.state, Relaxed))
487 .field("head", &self.head.with(|ptr| ptr))
488 .field("rx_lock", &self.rx_lock.load(Relaxed))
489 .field("stub", &self.stub)
490 .finish()
491 }
492}
493
494unsafe impl Send for Semaphore {}
495unsafe impl Sync for Semaphore {}
496
497impl Permit {
500 pub(crate) fn new() -> Permit {
504 Permit {
505 waiter: None,
506 state: PermitState::Idle,
507 }
508 }
509
510 pub(crate) fn is_acquired(&self) -> bool {
512 self.state == PermitState::Acquired
513 }
514
515 pub(crate) fn poll_acquire(
518 &mut self,
519 cx: &mut Context<'_>,
520 semaphore: &Semaphore,
521 ) -> Poll<Result<(), AcquireError>> {
522 match self.state {
523 PermitState::Idle => {}
524 PermitState::Waiting => {
525 let waiter = self.waiter.as_ref().unwrap();
526
527 if waiter.acquire(cx)? {
528 self.state = PermitState::Acquired;
529 return Ready(Ok(()));
530 } else {
531 return Pending;
532 }
533 }
534 PermitState::Acquired => {
535 return Ready(Ok(()));
536 }
537 }
538
539 match semaphore.poll_permit(Some((cx, self)))? {
540 Ready(()) => {
541 self.state = PermitState::Acquired;
542 Ready(Ok(()))
543 }
544 Pending => {
545 self.state = PermitState::Waiting;
546 Pending
547 }
548 }
549 }
550
551 pub(crate) fn try_acquire(&mut self, semaphore: &Semaphore) -> Result<(), TryAcquireError> {
553 match self.state {
554 PermitState::Idle => {}
555 PermitState::Waiting => {
556 let waiter = self.waiter.as_ref().unwrap();
557
558 if waiter.acquire2().map_err(to_try_acquire)? {
559 self.state = PermitState::Acquired;
560 return Ok(());
561 } else {
562 return Err(TryAcquireError::no_permits());
563 }
564 }
565 PermitState::Acquired => {
566 return Ok(());
567 }
568 }
569
570 match semaphore.poll_permit(None).map_err(to_try_acquire)? {
571 Ready(()) => {
572 self.state = PermitState::Acquired;
573 Ok(())
574 }
575 Pending => Err(TryAcquireError::no_permits()),
576 }
577 }
578
579 pub(crate) fn release(&mut self, semaphore: &Semaphore) {
581 if self.forget2() {
582 semaphore.add_permits(1);
583 }
584 }
585
586 pub(crate) fn forget(&mut self) {
594 self.forget2();
595 }
596
597 fn forget2(&mut self) -> bool {
599 match self.state {
600 PermitState::Idle => false,
601 PermitState::Waiting => {
602 let ret = self.waiter.as_ref().unwrap().cancel_interest();
603 self.state = PermitState::Idle;
604 ret
605 }
606 PermitState::Acquired => {
607 self.state = PermitState::Idle;
608 true
609 }
610 }
611 }
612}
613
614impl Default for Permit {
615 fn default() -> Self {
616 Self::new()
617 }
618}
619
620impl AcquireError {
623 fn closed() -> AcquireError {
624 AcquireError(())
625 }
626}
627
628fn to_try_acquire(_: AcquireError) -> TryAcquireError {
629 TryAcquireError::closed()
630}
631
632impl fmt::Display for AcquireError {
633 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
634 write!(fmt, "semaphore closed")
635 }
636}
637
638impl ::std::error::Error for AcquireError {}
639
640impl TryAcquireError {
643 fn closed() -> TryAcquireError {
644 TryAcquireError {
645 kind: ErrorKind::Closed,
646 }
647 }
648
649 fn no_permits() -> TryAcquireError {
650 TryAcquireError {
651 kind: ErrorKind::NoPermits,
652 }
653 }
654
655 pub(crate) fn is_closed(&self) -> bool {
657 match self.kind {
658 ErrorKind::Closed => true,
659 _ => false,
660 }
661 }
662
663 pub(crate) fn is_no_permits(&self) -> bool {
666 match self.kind {
667 ErrorKind::NoPermits => true,
668 _ => false,
669 }
670 }
671}
672
673impl fmt::Display for TryAcquireError {
674 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
675 let descr = match self.kind {
676 ErrorKind::Closed => "semaphore closed",
677 ErrorKind::NoPermits => "no permits available",
678 };
679 write!(fmt, "{}", descr)
680 }
681}
682
683impl ::std::error::Error for TryAcquireError {}
684
685impl WaiterNode {
688 fn new() -> WaiterNode {
689 WaiterNode {
690 state: AtomicUsize::new(NodeState::new().to_usize()),
691 waker: AtomicWaker::new(),
692 next: AtomicPtr::new(ptr::null_mut()),
693 }
694 }
695
696 fn acquire(&self, cx: &mut Context<'_>) -> Result<bool, AcquireError> {
697 if self.acquire2()? {
698 return Ok(true);
699 }
700
701 self.waker.register_by_ref(cx.waker());
702
703 self.acquire2()
704 }
705
706 fn acquire2(&self) -> Result<bool, AcquireError> {
707 use self::NodeState::*;
708
709 match Idle.compare_exchange(&self.state, Assigned, AcqRel, Acquire) {
710 Ok(_) => Ok(true),
711 Err(Closed) => Err(AcquireError::closed()),
712 Err(_) => Ok(false),
713 }
714 }
715
716 fn register(&self, cx: &mut Context<'_>) {
717 self.waker.register_by_ref(cx.waker())
718 }
719
720 fn cancel_interest(&self) -> bool {
722 use self::NodeState::*;
723
724 match Queued.compare_exchange(&self.state, QueuedWaiting, AcqRel, Acquire) {
725 Ok(_) => false,
728 Err(Closed) => false,
731 Err(Assigned) => {
734 match self.acquire2() {
735 Ok(true) => true,
736 Ok(false) => panic!(),
738 Err(_) => false,
740 }
741 }
742 Err(state) => panic!("unexpected state = {:?}", state),
743 }
744 }
745
746 fn to_queued_waiting(&self) -> bool {
752 use self::NodeState::*;
753
754 let mut curr = NodeState::load(&self.state, Acquire);
755
756 loop {
757 debug_assert!(curr == Idle || curr == Queued, "actual = {:?}", curr);
758 let next = QueuedWaiting;
759
760 match next.compare_exchange(&self.state, curr, AcqRel, Acquire) {
761 Ok(_) => {
762 if curr.is_queued() {
763 return false;
764 } else {
765 self.next.store(ptr::null_mut(), Relaxed);
767 return true;
768 }
769 }
770 Err(actual) => {
771 curr = actual;
772 }
773 }
774 }
775 }
776
777 fn notify(&self, closed: bool) -> bool {
781 use self::NodeState::*;
782
783 let mut curr = QueuedWaiting;
785
786 loop {
787 let next = match curr {
788 Queued => Idle,
789 QueuedWaiting => {
790 if closed {
791 Closed
792 } else {
793 Assigned
794 }
795 }
796 actual => panic!("actual = {:?}", actual),
797 };
798
799 match next.compare_exchange(&self.state, curr, AcqRel, Acquire) {
800 Ok(_) => match curr {
801 QueuedWaiting => {
802 self.waker.wake();
803 return true;
804 }
805 _ => return false,
806 },
807 Err(actual) => curr = actual,
808 }
809 }
810 }
811
812 fn revert_to_idle(&self) {
813 use self::NodeState::Idle;
814
815 NodeState::store(&self.state, Idle, Relaxed);
817 }
818
819 #[allow(clippy::wrong_self_convention)] fn into_non_null(self: Arc<WaiterNode>) -> NonNull<WaiterNode> {
821 let ptr = Arc::into_raw(self);
822 unsafe { NonNull::new_unchecked(ptr as *mut _) }
823 }
824}
825
/// Set in `Semaphore::state` when it holds a permit count rather than a
/// waiter pointer (node addresses always have this bit clear).
const NUM_FLAG: usize = 1 << 0;

/// Set in `Semaphore::state` once the semaphore has been closed.
const CLOSED_FLAG: usize = 1 << 1;

/// Number of low flag bits; the permit count is stored above them.
const NUM_SHIFT: usize = 2;

/// Largest permit count that fits above the flag bits.
const MAX_PERMITS: usize = usize::MAX >> NUM_SHIFT;
842
843impl SemState {
844 fn new(permits: usize, stub: &WaiterNode) -> SemState {
846 assert!(permits <= MAX_PERMITS);
847
848 if permits > 0 {
849 SemState((permits << NUM_SHIFT) | NUM_FLAG)
850 } else {
851 SemState(stub as *const _ as usize)
852 }
853 }
854
855 fn new_ptr(tail: NonNull<WaiterNode>, closed: bool) -> SemState {
857 let mut val = tail.as_ptr() as usize;
858
859 if closed {
860 val |= CLOSED_FLAG;
861 }
862
863 SemState(val)
864 }
865
866 fn available_permits(self) -> usize {
868 if !self.has_available_permits() {
869 return 0;
870 }
871
872 self.0 >> NUM_SHIFT
873 }
874
875 fn has_available_permits(self) -> bool {
877 self.0 & NUM_FLAG == NUM_FLAG
878 }
879
880 fn has_waiter(self, stub: &WaiterNode) -> bool {
881 !self.has_available_permits() && !self.is_stub(stub)
882 }
883
884 fn acquire_permit(&mut self, stub: &WaiterNode) -> bool {
892 if !self.has_available_permits() {
893 return false;
894 }
895
896 debug_assert!(self.waiter().is_none());
897
898 self.0 -= 1 << NUM_SHIFT;
899
900 if self.0 == NUM_FLAG {
901 self.0 = stub as *const _ as usize;
903 }
904
905 true
906 }
907
908 fn release_permits(&mut self, permits: usize, stub: &WaiterNode) {
912 debug_assert!(permits > 0);
913
914 if self.is_stub(stub) {
915 self.0 = (permits << NUM_SHIFT) | NUM_FLAG | (self.0 & CLOSED_FLAG);
916 return;
917 }
918
919 debug_assert!(self.has_available_permits());
920
921 self.0 += permits << NUM_SHIFT;
922 }
923
924 fn is_waiter(self) -> bool {
925 self.0 & NUM_FLAG == 0
926 }
927
928 fn waiter(self) -> Option<NonNull<WaiterNode>> {
930 if self.is_waiter() {
931 let waiter = NonNull::new(self.as_ptr()).expect("null pointer stored");
932
933 Some(waiter)
934 } else {
935 None
936 }
937 }
938
939 fn as_ptr(self) -> *mut WaiterNode {
941 (self.0 & !CLOSED_FLAG) as *mut WaiterNode
942 }
943
944 fn set_waiter(&mut self, waiter: NonNull<WaiterNode>) {
948 let waiter = waiter.as_ptr() as usize;
949 debug_assert!(waiter & NUM_FLAG == 0);
950 debug_assert!(!self.is_closed());
951
952 self.0 = waiter;
953 }
954
955 fn is_stub(self, stub: &WaiterNode) -> bool {
956 self.as_ptr() as usize == stub as *const _ as usize
957 }
958
959 fn load(cell: &AtomicUsize, ordering: Ordering) -> SemState {
961 let value = cell.load(ordering);
962 SemState(value)
963 }
964
965 fn swap(self, cell: &AtomicUsize, ordering: Ordering) -> SemState {
967 let prev = SemState(cell.swap(self.to_usize(), ordering));
968 debug_assert_eq!(prev.is_closed(), self.is_closed());
969 prev
970 }
971
972 fn compare_exchange(
974 self,
975 cell: &AtomicUsize,
976 prev: SemState,
977 success: Ordering,
978 failure: Ordering,
979 ) -> Result<SemState, SemState> {
980 debug_assert_eq!(prev.is_closed(), self.is_closed());
981
982 let res = cell.compare_exchange(prev.to_usize(), self.to_usize(), success, failure);
983
984 res.map(SemState).map_err(SemState)
985 }
986
987 fn fetch_set_closed(cell: &AtomicUsize, ordering: Ordering) -> SemState {
988 let value = cell.fetch_or(CLOSED_FLAG, ordering);
989 SemState(value)
990 }
991
992 fn is_closed(self) -> bool {
993 self.0 & CLOSED_FLAG == CLOSED_FLAG
994 }
995
996 fn to_usize(self) -> usize {
998 self.0
999 }
1000}
1001
1002impl fmt::Debug for SemState {
1003 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
1004 let mut fmt = fmt.debug_struct("SemState");
1005
1006 if self.is_waiter() {
1007 fmt.field("state", &"<waiter>");
1008 } else {
1009 fmt.field("permits", &self.available_permits());
1010 }
1011
1012 fmt.finish()
1013 }
1014}
1015
1016impl NodeState {
1019 fn new() -> NodeState {
1020 NodeState::Idle
1021 }
1022
1023 fn from_usize(value: usize) -> NodeState {
1024 use self::NodeState::*;
1025
1026 match value {
1027 0 => Idle,
1028 1 => Queued,
1029 2 => QueuedWaiting,
1030 3 => Assigned,
1031 4 => Closed,
1032 _ => panic!(),
1033 }
1034 }
1035
1036 fn load(cell: &AtomicUsize, ordering: Ordering) -> NodeState {
1037 NodeState::from_usize(cell.load(ordering))
1038 }
1039
1040 fn store(cell: &AtomicUsize, value: NodeState, ordering: Ordering) {
1042 cell.store(value.to_usize(), ordering);
1043 }
1044
1045 fn compare_exchange(
1046 self,
1047 cell: &AtomicUsize,
1048 prev: NodeState,
1049 success: Ordering,
1050 failure: Ordering,
1051 ) -> Result<NodeState, NodeState> {
1052 cell.compare_exchange(prev.to_usize(), self.to_usize(), success, failure)
1053 .map(NodeState::from_usize)
1054 .map_err(NodeState::from_usize)
1055 }
1056
1057 fn is_queued(self) -> bool {
1059 use self::NodeState::*;
1060
1061 match self {
1062 Queued | QueuedWaiting => true,
1063 _ => false,
1064 }
1065 }
1066
1067 fn to_usize(self) -> usize {
1068 self as usize
1069 }
1070}