use std::cell::UnsafeCell;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, MutexGuard};
use std::task::{Context, Poll, Waker};

use futures_util::Future;

use crate::util::linked_list::{Link, LinkedList, Pointers};
use crate::util::wake_list::WakeList;

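/// A counting semaphore with two FIFO wait queues. Waiters that acquire with
/// `priority == true` are queued separately and receive released permits
/// before ordinary waiters.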
pub struct PrioritySemaphore {
    waiters: Mutex<Waitlist>,
    permits: AtomicUsize,
}

impl PrioritySemaphore {
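    // The low bit of `permits` is the closed flag; the remaining bits hold the
    // number of available permits, shifted left by `PERMIT_SHIFT`.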
    const MAX_PERMITS: usize = usize::MAX >> 3;
    const CLOSED: usize = 1;
    const PERMIT_SHIFT: usize = 1;

    pub fn new(permits: usize) -> Self {
        assert!(
            permits <= Self::MAX_PERMITS,
            "a semaphore may not have more than MAX_PERMITS permits ({})",
            Self::MAX_PERMITS
        );

        Self {
            permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT),
            waiters: Mutex::new(Waitlist {
                ordinary_queue: LinkedList::new(),
                priority_queue: LinkedList::new(),
                closed: false,
            }),
        }
    }

    pub const fn const_new(permits: usize) -> Self {
        assert!(permits <= Self::MAX_PERMITS);

        Self {
            permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT),
            waiters: Mutex::new(Waitlist {
                ordinary_queue: LinkedList::new(),
                priority_queue: LinkedList::new(),
                closed: false,
            }),
        }
    }

    pub fn available_permits(&self) -> usize {
        self.permits.load(Ordering::Acquire) >> Self::PERMIT_SHIFT
    }

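    /// Closes the semaphore: all queued waiters are woken and every pending or
    /// future acquire fails with an error.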
    pub fn close(&self) {
        fn clear_queue(queue: &mut LinkedList<Waiter, <Waiter as Link>::Target>) {
            while let Some(mut waiter) = queue.pop_back() {
                let waker = unsafe { (*waiter.as_mut().waker.get()).take() };
                if let Some(waker) = waker {
                    waker.wake();
                }
            }
        }

        let mut waiters = self.waiters.lock().unwrap();

        self.permits.fetch_or(Self::CLOSED, Ordering::Release);
        waiters.closed = true;

        clear_queue(&mut waiters.ordinary_queue);
        clear_queue(&mut waiters.priority_queue);
    }

    pub fn is_closed(&self) -> bool {
        self.permits.load(Ordering::Acquire) & Self::CLOSED == Self::CLOSED
    }

    pub fn try_acquire(&self) -> Result<SemaphorePermit<'_>, TryAcquireError> {
        self.try_acquire_impl(1).map(|()| SemaphorePermit {
            semaphore: self,
            permits: 1,
        })
    }

    pub fn try_acquire_owned(self: Arc<Self>) -> Result<OwnedSemaphorePermit, TryAcquireError> {
        self.try_acquire_impl(1).map(|()| OwnedSemaphorePermit {
            semaphore: self,
            permits: 1,
        })
    }

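    /// Acquires a single permit, waiting until one is available.
    ///
    /// Waiters with `priority == true` are placed on the priority queue and are
    /// served before ordinary waiters; within each queue permits are handed out
    /// in FIFO order. Returns an error if the semaphore is closed.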
    pub async fn acquire(&self, priority: bool) -> Result<SemaphorePermit<'_>, AcquireError> {
        match self.acquire_impl(1, priority).await {
            Ok(()) => Ok(SemaphorePermit {
                semaphore: self,
                permits: 1,
            }),
            Err(e) => Err(e),
        }
    }

    pub async fn acquire_owned(
        self: Arc<Self>,
        priority: bool,
    ) -> Result<OwnedSemaphorePermit, AcquireError> {
        match self.acquire_impl(1, priority).await {
            Ok(()) => Ok(OwnedSemaphorePermit {
                semaphore: self,
                permits: 1,
            }),
            Err(e) => Err(e),
        }
    }

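    /// Adds `n` permits to the semaphore, waking any queued waiters that can
    /// now be satisfied.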
    pub fn add_permits(&self, n: usize) {
        if n == 0 {
            return;
        }

        self.add_permits_locked(n, self.waiters.lock().unwrap());
    }

    fn try_acquire_impl(&self, num_permits: usize) -> Result<(), TryAcquireError> {
        assert!(
            num_permits <= Self::MAX_PERMITS,
            "a semaphore may not have more than MAX_PERMITS permits ({})",
            Self::MAX_PERMITS
        );

        let num_permits = num_permits << Self::PERMIT_SHIFT;
        let mut curr = self.permits.load(Ordering::Acquire);
        loop {
            if curr & Self::CLOSED == Self::CLOSED {
                return Err(TryAcquireError::Closed);
            }

            if curr < num_permits {
                return Err(TryAcquireError::NoPermits);
            }

            let next = curr - num_permits;

            match self
                .permits
                .compare_exchange(curr, next, Ordering::AcqRel, Ordering::Acquire)
            {
                Ok(_) => return Ok(()),
                Err(actual) => curr = actual,
            }
        }
    }

    fn acquire_impl(&self, num_permits: usize, priority: bool) -> Acquire<'_> {
        Acquire::new(self, num_permits, priority)
    }

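    // Hands `rem` released permits to queued waiters, draining the priority
    // queue before the ordinary queue. Fully satisfied waiters are dequeued and
    // woken in batches (outside the lock); any permits left over once both
    // queues are empty go back to the atomic counter.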
    fn add_permits_locked(&self, mut rem: usize, waiters: MutexGuard<'_, Waitlist>) {
        let mut wakers = WakeList::new();
        let mut lock = Some(waiters);
        let mut is_empty = false;
        while rem > 0 {
            let mut waiters = lock.take().unwrap_or_else(|| self.waiters.lock().unwrap());

            {
                let waiters = &mut *waiters;
                'inner: while wakers.can_push() {
                    let queue = 'queue: {
                        for queue in [&mut waiters.priority_queue, &mut waiters.ordinary_queue] {
                            if let Some(waiter) = queue.last() {
                                if !waiter.assign_permits(&mut rem) {
                                    continue;
                                }
                                break 'queue queue;
                            }
                        }

                        is_empty = true;
                        break 'inner;
                    };

                    let mut waiter = queue.pop_back().unwrap();
                    if let Some(waker) = unsafe { (*waiter.as_mut().waker.get()).take() } {
                        wakers.push(waker);
                    }
                }
            }

            if rem > 0 && is_empty {
                let permits = rem;
                assert!(
                    permits <= Self::MAX_PERMITS,
                    "cannot add more than MAX_PERMITS permits ({})",
                    Self::MAX_PERMITS
                );
                let prev = self
                    .permits
                    .fetch_add(rem << Self::PERMIT_SHIFT, Ordering::Release);
                let prev = prev >> Self::PERMIT_SHIFT;
                assert!(
                    prev + permits <= Self::MAX_PERMITS,
                    "number of added permits ({}) would overflow MAX_PERMITS ({})",
                    rem,
                    Self::MAX_PERMITS
                );

                rem = 0;
            }

            drop(waiters);
            wakers.wake_all();
        }

        assert_eq!(rem, 0);
    }

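    // Tries to take permits straight from the atomic counter. If not enough are
    // available, records the remaining demand in `node`, stores the task's
    // waker, and pushes the node onto the priority or ordinary queue (both
    // guarded by the `waiters` mutex). Returns `Ready` once the node's demand
    // reaches zero, or an error if the semaphore is closed.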
    fn poll_acquire(
        &self,
        cx: &mut Context<'_>,
        num_permits: usize,
        node: Pin<&mut Waiter>,
        queued: bool,
        priority: bool,
    ) -> Poll<Result<(), AcquireError>> {
        let mut acquired = 0;

        let needed = if queued {
            node.state.load(Ordering::Acquire) << Self::PERMIT_SHIFT
        } else {
            num_permits << Self::PERMIT_SHIFT
        };

        let mut lock = None;
        let mut curr = self.permits.load(Ordering::Acquire);
        let mut waiters = loop {
            if curr & Self::CLOSED > 0 {
                return Poll::Ready(Err(AcquireError(())));
            }

            let mut remaining = 0;
            let total = curr
                .checked_add(acquired)
                .expect("number of permits must not overflow");
            let (next, acq) = if total >= needed {
                let next = curr - (needed - acquired);
                (next, needed >> Self::PERMIT_SHIFT)
            } else {
                remaining = (needed - acquired) - curr;
                (0, curr >> Self::PERMIT_SHIFT)
            };

            if remaining > 0 && lock.is_none() {
                lock = Some(self.waiters.lock().unwrap());
            }

            match self
                .permits
                .compare_exchange(curr, next, Ordering::AcqRel, Ordering::Acquire)
            {
                Ok(_) => {
                    acquired += acq;
                    if remaining == 0 {
                        if !queued {
                            return Poll::Ready(Ok(()));
                        } else if lock.is_none() {
                            break self.waiters.lock().unwrap();
                        }
                    }
                    break lock.expect("lock must be acquired before waiting");
                }
                Err(actual) => curr = actual,
            }
        };

        if waiters.closed {
            return Poll::Ready(Err(AcquireError(())));
        }

        if node.assign_permits(&mut acquired) {
            self.add_permits_locked(acquired, waiters);
            return Poll::Ready(Ok(()));
        }

        assert_eq!(acquired, 0);
        let mut old_waker = None;

        {
            let waker = unsafe { &mut *node.waker.get() };

            if waker
                .as_ref()
                .is_none_or(|waker| !waker.will_wake(cx.waker()))
            {
                old_waker = waker.replace(cx.waker().clone());
            }
        }

        if !queued {
            let node = unsafe {
                let node = Pin::into_inner_unchecked(node) as *mut _;
                NonNull::new_unchecked(node)
            };

            waiters.queue_mut(priority).push_front(node);
        }
        drop(waiters);
        drop(old_waker);

        Poll::Pending
    }
}

#[must_use]
#[clippy::has_significant_drop]
pub struct SemaphorePermit<'a> {
    semaphore: &'a PrioritySemaphore,
    permits: u32,
}

impl Drop for SemaphorePermit<'_> {
    fn drop(&mut self) {
        self.semaphore.add_permits(self.permits as usize);
    }
}

#[must_use]
#[clippy::has_significant_drop]
pub struct OwnedSemaphorePermit {
    semaphore: Arc<PrioritySemaphore>,
    permits: u32,
}

impl Drop for OwnedSemaphorePermit {
    fn drop(&mut self) {
        self.semaphore.add_permits(self.permits as usize);
    }
}

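/// Future returned by `acquire`/`acquire_owned`. The `Waiter` node is stored
/// inline, so the future must stay pinned while it is queued; dropping it
/// removes the node from its queue and returns any partially assigned permits.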
struct Acquire<'a> {
    node: Waiter,
    semaphore: &'a PrioritySemaphore,
    num_permits: usize,
    queued: bool,
    priority: bool,
}

impl<'a> Acquire<'a> {
    fn new(semaphore: &'a PrioritySemaphore, num_permits: usize, priority: bool) -> Self {
        Self {
            node: Waiter::new(num_permits),
            semaphore,
            num_permits,
            queued: false,
            priority,
        }
    }

    fn project(
        self: Pin<&mut Self>,
    ) -> (Pin<&mut Waiter>, &PrioritySemaphore, usize, &mut bool, bool) {
        fn is_unpin<T: Unpin>() {}
        unsafe {
            is_unpin::<&PrioritySemaphore>();
            is_unpin::<&mut bool>();
            is_unpin::<usize>();

            let this = self.get_unchecked_mut();
            (
                Pin::new_unchecked(&mut this.node),
                this.semaphore,
                this.num_permits,
                &mut this.queued,
                this.priority,
            )
        }
    }
}

impl Drop for Acquire<'_> {
    fn drop(&mut self) {
        if !self.queued {
            return;
        }

        let mut waiters = self.semaphore.waiters.lock().unwrap();

        let node = NonNull::from(&mut self.node);
        unsafe { waiters.queue_mut(self.priority).remove(node) };

        let acquired_permits = self.num_permits - self.node.state.load(Ordering::Acquire);
        if acquired_permits > 0 {
            self.semaphore.add_permits_locked(acquired_permits, waiters);
        }
    }
}

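// SAFETY: `Acquire` is not automatically `Sync` because `Waiter` contains an
// `UnsafeCell<Option<Waker>>`. That cell is only touched from `poll`/`drop`
// (which require exclusive access to the future) and from the semaphore's
// queues, where every access happens while the `waiters` mutex is held, so
// sharing `&Acquire` across threads cannot cause a data race.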
unsafe impl Sync for Acquire<'_> {}

impl Future for Acquire<'_> {
    type Output = Result<(), AcquireError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let (node, semaphore, needed, queued, priority) = self.project();

        match semaphore.poll_acquire(cx, needed, node, *queued, priority) {
            Poll::Pending => {
                *queued = true;
                Poll::Pending
            }
            Poll::Ready(r) => {
                r?;
                *queued = false;
                Poll::Ready(Ok(()))
            }
        }
    }
}

#[derive(Debug, thiserror::Error)]
#[error("semaphore closed")]
pub struct AcquireError(());

#[derive(Debug, PartialEq, Eq, thiserror::Error)]
pub enum TryAcquireError {
    #[error("semaphore closed")]
    Closed,

    #[error("no permits available")]
    NoPermits,
}

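/// The two intrusive wait queues, protected by the semaphore's mutex.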
struct Waitlist {
    ordinary_queue: LinkedList<Waiter, <Waiter as Link>::Target>,
    priority_queue: LinkedList<Waiter, <Waiter as Link>::Target>,
    closed: bool,
}

impl Waitlist {
    fn queue_mut(&mut self, priority: bool) -> &mut LinkedList<Waiter, <Waiter as Link>::Target> {
        if priority {
            &mut self.priority_queue
        } else {
            &mut self.ordinary_queue
        }
    }
}

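/// One queued acquire operation: `state` holds the number of permits still
/// needed, `waker` the task to wake once the request is satisfied, and
/// `pointers` links the node into an intrusive wait queue.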
struct Waiter {
    state: AtomicUsize,
    waker: UnsafeCell<Option<Waker>>,
    pointers: Pointers<Waiter>,
    _pin: PhantomPinned,
}

impl Waiter {
    fn new(num_permits: usize) -> Self {
        Waiter {
            state: AtomicUsize::new(num_permits),
            waker: UnsafeCell::new(None),
            pointers: Pointers::new(),
            _pin: PhantomPinned,
        }
    }

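    /// Moves up to `*n` permits into this waiter, decrementing `*n` by the
    /// amount taken; returns `true` once the waiter needs no more permits.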
    fn assign_permits(&self, n: &mut usize) -> bool {
        let mut curr = self.state.load(Ordering::Acquire);
        loop {
            let assign = std::cmp::min(curr, *n);
            let next = curr - assign;
            match self
                .state
                .compare_exchange(curr, next, Ordering::AcqRel, Ordering::Acquire)
            {
                Ok(_) => {
                    *n -= assign;
                    return next == 0;
                }
                Err(actual) => curr = actual,
            }
        }
    }

    unsafe fn addr_of_pointers(target: NonNull<Waiter>) -> NonNull<Pointers<Self>> {
        let target = target.as_ptr();
        let field = unsafe { std::ptr::addr_of_mut!((*target).pointers) };
        unsafe { NonNull::new_unchecked(field) }
    }
}

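// Intrusive-list glue: a `Waiter` is its own list node, addressed by raw
// pointer, with the link pointers stored in its `pointers` field.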
unsafe impl Link for Waiter {
    type Handle = NonNull<Self>;
    type Target = Self;

    #[inline]
    fn as_raw(handle: &Self::Handle) -> NonNull<Self::Target> {
        *handle
    }

    #[inline]
    unsafe fn from_raw(ptr: NonNull<Self::Target>) -> Self::Handle {
        ptr
    }

    #[inline]
    unsafe fn pointers(target: NonNull<Self::Target>) -> NonNull<Pointers<Self::Target>> {
        unsafe { Self::addr_of_pointers(target) }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;

    use super::*;

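    // A task holding the only permit blocks both an ordinary and a priority
    // waiter; once the permit is released, the priority waiter must run first.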
    #[tokio::test(flavor = "multi_thread")]
    async fn priority_semaphore_works() {
        let permits = Arc::new(PrioritySemaphore::new(1));

        let flag = Arc::new(AtomicBool::new(false));

        tokio::spawn({
            let permits = permits.clone();
            async move {
                println!("BACKGROUND BEFORE");
                let _guard = permits.acquire(false).await.unwrap();
                println!("BACKGROUND AFTER");
                tokio::time::sleep(Duration::from_millis(100)).await;
                println!("BACKGROUND FINISH");
            }
        });

        tokio::time::sleep(Duration::from_micros(10)).await;

        let ordinary_task = tokio::spawn({
            let permits = permits.clone();
            let flag = flag.clone();
            async move {
                println!("ORDINARY BEFORE");
                let _guard = permits.acquire(false).await.unwrap();
                println!("ORDINARY AFTER");
                assert!(flag.load(Ordering::Acquire));
            }
        });

        tokio::time::sleep(Duration::from_micros(10)).await;

        let priority_task = tokio::spawn({
            let flag = flag.clone();
            async move {
                println!("PRIORITY BEFORE");
                let _guard = permits.acquire(true).await.unwrap();
                println!("PRIORITY");
                flag.store(true, Ordering::Release);
            }
        });

        ordinary_task.await.unwrap();
        priority_task.await.unwrap();
    }

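    // A loop that repeatedly acquires and releases a permit must not starve an
    // unrelated task running concurrently on the same runtime.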
    #[tokio::test(flavor = "multi_thread")]
    async fn priority_semaphore_is_fair() {
        let permits = Arc::new(PrioritySemaphore::new(10));

        let flag = AtomicBool::new(false);
        tokio::join!(
            non_cooperative_task(permits, &flag),
            poor_little_task(&flag),
        );
    }

    async fn non_cooperative_task(permits: Arc<PrioritySemaphore>, flag: &AtomicBool) {
        while !flag.load(Ordering::Acquire) {
            let _permit = permits.acquire(false).await.unwrap();

            tokio::task::yield_now().await;
        }
    }

    async fn poor_little_task(flag: &AtomicBool) {
        tokio::time::sleep(Duration::from_secs(1)).await;
        flag.store(true, Ordering::Release);
    }
}