tycho_util/sync/
priority_semaphore.rs

1//! See <https://github.com/tokio-rs/tokio/blob/c9273f1aee9927b16ee3a789a382c99ad600c8b6/tokio/src/sync/batch_semaphore.rs>.
2
3use std::cell::UnsafeCell;
4use std::marker::PhantomPinned;
5use std::pin::Pin;
6use std::ptr::NonNull;
7use std::sync::atomic::{AtomicUsize, Ordering};
8use std::sync::{Arc, Mutex, MutexGuard};
9use std::task::{Context, Poll, Waker};
10
11use futures_util::Future;
12
13use crate::util::linked_list::{Link, LinkedList, Pointers};
14use crate::util::wake_list::WakeList;
15
16pub struct PrioritySemaphore {
17    waiters: Mutex<Waitlist>,
18    permits: AtomicUsize,
19}
20
21impl PrioritySemaphore {
22    const MAX_PERMITS: usize = usize::MAX >> 3;
23    const CLOSED: usize = 1;
24    const PERMIT_SHIFT: usize = 1;
25
26    pub fn new(permits: usize) -> Self {
27        assert!(
28            permits <= Self::MAX_PERMITS,
29            "a semaphore may not have more than MAX_PERMITS permits ({})",
30            Self::MAX_PERMITS
31        );
32
33        Self {
34            permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT),
35            waiters: Mutex::new(Waitlist {
36                ordinary_queue: LinkedList::new(),
37                priority_queue: LinkedList::new(),
38                closed: false,
39            }),
40        }
41    }
42
43    pub const fn const_new(permits: usize) -> Self {
44        assert!(permits <= Self::MAX_PERMITS);
45
46        Self {
47            permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT),
48            waiters: Mutex::new(Waitlist {
49                ordinary_queue: LinkedList::new(),
50                priority_queue: LinkedList::new(),
51                closed: false,
52            }),
53        }
54    }
55
56    pub fn available_permits(&self) -> usize {
57        self.permits.load(Ordering::Acquire) >> Self::PERMIT_SHIFT
58    }
59
60    pub fn close(&self) {
61        fn clear_queue(queue: &mut LinkedList<Waiter, <Waiter as Link>::Target>) {
62            while let Some(mut waiter) = queue.pop_back() {
63                let waker = unsafe { (*waiter.as_mut().waker.get()).take() };
64                if let Some(waker) = waker {
65                    waker.wake();
66                }
67            }
68        }
69
70        let mut waiters = self.waiters.lock().unwrap();
71
72        self.permits.fetch_or(Self::CLOSED, Ordering::Release);
73        waiters.closed = true;
74
75        clear_queue(&mut waiters.ordinary_queue);
76        clear_queue(&mut waiters.priority_queue);
77    }
78
79    pub fn is_closed(&self) -> bool {
80        self.permits.load(Ordering::Acquire) & Self::CLOSED == Self::CLOSED
81    }
82
83    pub fn try_acquire(&self) -> Result<SemaphorePermit<'_>, TryAcquireError> {
84        self.try_acquire_impl(1).map(|()| SemaphorePermit {
85            semaphore: self,
86            permits: 1,
87        })
88    }
89
90    pub fn try_acquire_owned(self: Arc<Self>) -> Result<OwnedSemaphorePermit, TryAcquireError> {
91        self.try_acquire_impl(1).map(|()| OwnedSemaphorePermit {
92            semaphore: self,
93            permits: 1,
94        })
95    }
96
97    pub async fn acquire(&self, priority: bool) -> Result<SemaphorePermit<'_>, AcquireError> {
98        match self.acquire_impl(1, priority).await {
99            Ok(()) => Ok(SemaphorePermit {
100                semaphore: self,
101                permits: 1,
102            }),
103            Err(e) => Err(e),
104        }
105    }
106
107    pub async fn acquire_owned(
108        self: Arc<Self>,
109        priority: bool,
110    ) -> Result<OwnedSemaphorePermit, AcquireError> {
111        match self.acquire_impl(1, priority).await {
112            Ok(()) => Ok(OwnedSemaphorePermit {
113                semaphore: self,
114                permits: 1,
115            }),
116            Err(e) => Err(e),
117        }
118    }
119
120    pub fn add_permits(&self, n: usize) {
121        if n == 0 {
122            return;
123        }
124
125        // Assign permits to the wait queue
126        self.add_permits_locked(n, self.waiters.lock().unwrap());
127    }
128
129    fn try_acquire_impl(&self, num_permits: usize) -> Result<(), TryAcquireError> {
130        assert!(
131            num_permits <= Self::MAX_PERMITS,
132            "a semaphore may not have more than MAX_PERMITS permits ({})",
133            Self::MAX_PERMITS
134        );
135
136        let num_permits = num_permits << Self::PERMIT_SHIFT;
137        let mut curr = self.permits.load(Ordering::Acquire);
138        loop {
139            // Has the semaphore closed?
140            if curr & Self::CLOSED == Self::CLOSED {
141                return Err(TryAcquireError::Closed);
142            }
143
144            // Are there enough permits remaining?
145            if curr < num_permits {
146                return Err(TryAcquireError::NoPermits);
147            }
148
149            let next = curr - num_permits;
150
151            match self
152                .permits
153                .compare_exchange(curr, next, Ordering::AcqRel, Ordering::Acquire)
154            {
155                Ok(_) => return Ok(()),
156                Err(actual) => curr = actual,
157            }
158        }
159    }
160
161    fn acquire_impl(&self, num_permits: usize, priority: bool) -> Acquire<'_> {
162        Acquire::new(self, num_permits, priority)
163    }
164
165    fn add_permits_locked(&self, mut rem: usize, waiters: MutexGuard<'_, Waitlist>) {
166        let mut wakers = WakeList::new();
167        let mut lock = Some(waiters);
168        let mut is_empty = false;
169        while rem > 0 {
170            let mut waiters = lock.take().unwrap_or_else(|| self.waiters.lock().unwrap());
171
172            {
173                let waiters = &mut *waiters;
174                'inner: while wakers.can_push() {
175                    // Was the waiter assigned enough permits to wake it?
176                    let queue = 'queue: {
177                        for queue in [&mut waiters.priority_queue, &mut waiters.ordinary_queue] {
178                            if let Some(waiter) = queue.last() {
179                                if !waiter.assign_permits(&mut rem) {
180                                    continue;
181                                }
182                                break 'queue queue;
183                            }
184                        }
185
186                        is_empty = true;
187                        // If we assigned permits to all the waiters in the queue, and there are
188                        // still permits left over, assign them back to the semaphore.
189                        break 'inner;
190                    };
191
192                    let mut waiter = queue.pop_back().unwrap();
193                    if let Some(waker) = unsafe { (*waiter.as_mut().waker.get()).take() } {
194                        wakers.push(waker);
195                    }
196                }
197            }
198
199            if rem > 0 && is_empty {
200                let permits = rem;
201                assert!(
202                    permits <= Self::MAX_PERMITS,
203                    "cannot add more than MAX_PERMITS permits ({})",
204                    Self::MAX_PERMITS
205                );
206                let prev = self
207                    .permits
208                    .fetch_add(rem << Self::PERMIT_SHIFT, Ordering::Release);
209                let prev = prev >> Self::PERMIT_SHIFT;
210                assert!(
211                    prev + permits <= Self::MAX_PERMITS,
212                    "number of added permits ({}) would overflow MAX_PERMITS ({})",
213                    rem,
214                    Self::MAX_PERMITS
215                );
216
217                rem = 0;
218            }
219
220            drop(waiters); // release the lock
221
222            wakers.wake_all();
223        }
224
225        assert_eq!(rem, 0);
226    }
227
228    fn poll_acquire(
229        &self,
230        cx: &mut Context<'_>,
231        num_permits: usize,
232        node: Pin<&mut Waiter>,
233        queued: bool,
234        priority: bool,
235    ) -> Poll<Result<(), AcquireError>> {
236        let mut acquired = 0;
237
238        let needed = if queued {
239            node.state.load(Ordering::Acquire) << Self::PERMIT_SHIFT
240        } else {
241            num_permits << Self::PERMIT_SHIFT
242        };
243
244        let mut lock = None;
245        // First, try to take the requested number of permits from the
246        // semaphore.
247        let mut curr = self.permits.load(Ordering::Acquire);
248        let mut waiters = loop {
249            // Has the semaphore closed?
250            if curr & Self::CLOSED > 0 {
251                return Poll::Ready(Err(AcquireError(())));
252            }
253
254            let mut remaining = 0;
255            let total = curr
256                .checked_add(acquired)
257                .expect("number of permits must not overflow");
258            let (next, acq) = if total >= needed {
259                let next = curr - (needed - acquired);
260                (next, needed >> Self::PERMIT_SHIFT)
261            } else {
262                remaining = (needed - acquired) - curr;
263                (0, curr >> Self::PERMIT_SHIFT)
264            };
265
266            if remaining > 0 && lock.is_none() {
267                // No permits were immediately available, so this permit will
268                // (probably) need to wait. We'll need to acquire a lock on the
269                // wait queue before continuing. We need to do this _before_ the
270                // CAS that sets the new value of the semaphore's `permits`
271                // counter. Otherwise, if we subtract the permits and then
272                // acquire the lock, we might miss additional permits being
273                // added while waiting for the lock.
274                lock = Some(self.waiters.lock().unwrap());
275            }
276
277            match self
278                .permits
279                .compare_exchange(curr, next, Ordering::AcqRel, Ordering::Acquire)
280            {
281                Ok(_) => {
282                    acquired += acq;
283                    if remaining == 0 {
284                        if !queued {
285                            return Poll::Ready(Ok(()));
286                        } else if lock.is_none() {
287                            break self.waiters.lock().unwrap();
288                        }
289                    }
290                    break lock.expect("lock must be acquired before waiting");
291                }
292                Err(actual) => curr = actual,
293            }
294        };
295
296        if waiters.closed {
297            return Poll::Ready(Err(AcquireError(())));
298        }
299
300        if node.assign_permits(&mut acquired) {
301            self.add_permits_locked(acquired, waiters);
302            return Poll::Ready(Ok(()));
303        }
304
305        assert_eq!(acquired, 0);
306        let mut old_waker = None;
307
308        // Otherwise, register the waker & enqueue the node.
309        {
310            // SAFETY: the wait list is locked, so we may modify the waker.
311            let waker = unsafe { &mut *node.waker.get() };
312
313            // Do we need to register the new waker?
314            if waker
315                .as_ref()
316                .is_none_or(|waker| !waker.will_wake(cx.waker()))
317            {
318                old_waker = waker.replace(cx.waker().clone());
319            }
320        }
321
322        // If the waiter is not already in the wait queue, enqueue it.
323        if !queued {
324            let node = unsafe {
325                let node = Pin::into_inner_unchecked(node) as *mut _;
326                NonNull::new_unchecked(node)
327            };
328
329            waiters.queue_mut(priority).push_front(node);
330        }
331        drop(waiters);
332        drop(old_waker);
333
334        Poll::Pending
335    }
336}
337
338#[must_use]
339#[clippy::has_significant_drop]
340pub struct SemaphorePermit<'a> {
341    semaphore: &'a PrioritySemaphore,
342    permits: u32,
343}
344
345impl Drop for SemaphorePermit<'_> {
346    fn drop(&mut self) {
347        self.semaphore.add_permits(self.permits as usize);
348    }
349}
350
351#[must_use]
352#[clippy::has_significant_drop]
353pub struct OwnedSemaphorePermit {
354    semaphore: Arc<PrioritySemaphore>,
355    permits: u32,
356}
357
358impl Drop for OwnedSemaphorePermit {
359    fn drop(&mut self) {
360        self.semaphore.add_permits(self.permits as usize);
361    }
362}
363
364struct Acquire<'a> {
365    node: Waiter,
366    semaphore: &'a PrioritySemaphore,
367    num_permits: usize,
368    queued: bool,
369    priority: bool,
370}
371
372impl<'a> Acquire<'a> {
373    fn new(semaphore: &'a PrioritySemaphore, num_permits: usize, priority: bool) -> Self {
374        Self {
375            node: Waiter::new(num_permits),
376            semaphore,
377            num_permits,
378            queued: false,
379            priority,
380        }
381    }
382
383    fn project(
384        self: Pin<&mut Self>,
385    ) -> (Pin<&mut Waiter>, &PrioritySemaphore, usize, &mut bool, bool) {
386        fn is_unpin<T: Unpin>() {}
387        unsafe {
388            // SAFETY: all fields other than `node` are `Unpin`
389
390            is_unpin::<&PrioritySemaphore>();
391            is_unpin::<&mut bool>();
392            is_unpin::<usize>();
393
394            let this = self.get_unchecked_mut();
395            (
396                Pin::new_unchecked(&mut this.node),
397                this.semaphore,
398                this.num_permits,
399                &mut this.queued,
400                this.priority,
401            )
402        }
403    }
404}
405
406impl Drop for Acquire<'_> {
407    fn drop(&mut self) {
408        if !self.queued {
409            return;
410        }
411
412        let mut waiters = self.semaphore.waiters.lock().unwrap();
413
414        let node = NonNull::from(&mut self.node);
415        // SAFETY: we have locked the wait list.
416        unsafe { waiters.queue_mut(self.priority).remove(node) };
417
418        let acquired_permits = self.num_permits - self.node.state.load(Ordering::Acquire);
419        if acquired_permits > 0 {
420            self.semaphore.add_permits_locked(acquired_permits, waiters);
421        }
422    }
423}
424
425// SAFETY: the `Acquire` future is not `Sync` automatically because it contains
426// a `Waiter`, which, in turn, contains an `UnsafeCell`. However, the
427// `UnsafeCell` is only accessed when the future is borrowed mutably (either in
428// `poll` or in `drop`). Therefore, it is safe (although not particularly
429// _useful_) for the future to be borrowed immutably across threads.
430unsafe impl Sync for Acquire<'_> {}
431
432impl Future for Acquire<'_> {
433    type Output = Result<(), AcquireError>;
434
435    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
436        let (node, semaphore, needed, queued, priority) = self.project();
437
438        match semaphore.poll_acquire(cx, needed, node, *queued, priority) {
439            Poll::Pending => {
440                *queued = true;
441                Poll::Pending
442            }
443            Poll::Ready(r) => {
444                r?;
445                *queued = false;
446                Poll::Ready(Ok(()))
447            }
448        }
449    }
450}
451
452#[derive(Debug, thiserror::Error)]
453#[error("semaphore closed")]
454pub struct AcquireError(());
455
456#[derive(Debug, PartialEq, Eq, thiserror::Error)]
457pub enum TryAcquireError {
458    /// The semaphore has been [closed] and cannot issue new permits.
459    ///
460    /// [closed]: crate::sync::PrioritySemaphore::close
461    #[error("semaphore closed")]
462    Closed,
463
464    /// The semaphore has no available permits.
465    #[error("no permits available")]
466    NoPermits,
467}
468
469struct Waitlist {
470    ordinary_queue: LinkedList<Waiter, <Waiter as Link>::Target>,
471    priority_queue: LinkedList<Waiter, <Waiter as Link>::Target>,
472    closed: bool,
473}
474
475impl Waitlist {
476    fn queue_mut(&mut self, priority: bool) -> &mut LinkedList<Waiter, <Waiter as Link>::Target> {
477        if priority {
478            &mut self.priority_queue
479        } else {
480            &mut self.ordinary_queue
481        }
482    }
483}
484
485struct Waiter {
486    state: AtomicUsize,
487    waker: UnsafeCell<Option<Waker>>,
488    pointers: Pointers<Waiter>,
489    _pin: PhantomPinned,
490}
491
492impl Waiter {
493    fn new(num_permits: usize) -> Self {
494        Waiter {
495            state: AtomicUsize::new(num_permits),
496            waker: UnsafeCell::new(None),
497            pointers: Pointers::new(),
498            _pin: PhantomPinned,
499        }
500    }
501
502    /// Assign permits to the waiter.
503    ///
504    /// Returns `true` if the waiter should be removed from the queue
505    fn assign_permits(&self, n: &mut usize) -> bool {
506        let mut curr = self.state.load(Ordering::Acquire);
507        loop {
508            let assign = std::cmp::min(curr, *n);
509            let next = curr - assign;
510            match self
511                .state
512                .compare_exchange(curr, next, Ordering::AcqRel, Ordering::Acquire)
513            {
514                Ok(_) => {
515                    *n -= assign;
516                    return next == 0;
517                }
518                Err(actual) => curr = actual,
519            }
520        }
521    }
522
523    unsafe fn addr_of_pointers(target: NonNull<Waiter>) -> NonNull<Pointers<Self>> {
524        let target = target.as_ptr();
525        let field = unsafe { std::ptr::addr_of_mut!((*target).pointers) };
526        unsafe { NonNull::new_unchecked(field) }
527    }
528}
529
530unsafe impl Link for Waiter {
531    type Handle = NonNull<Self>;
532    type Target = Self;
533
534    #[inline]
535    fn as_raw(handle: &Self::Handle) -> NonNull<Self::Target> {
536        *handle
537    }
538
539    #[inline]
540    unsafe fn from_raw(ptr: NonNull<Self::Target>) -> Self::Handle {
541        ptr
542    }
543
544    #[inline]
545    unsafe fn pointers(target: NonNull<Self::Target>) -> NonNull<Pointers<Self::Target>> {
546        unsafe { Self::addr_of_pointers(target) }
547    }
548}
549
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;

    use super::*;

    /// A priority waiter that queues up after an ordinary waiter must still
    /// receive the permit first.
    #[tokio::test(flavor = "multi_thread")]
    async fn priority_semaphore_works() {
        let permits = Arc::new(PrioritySemaphore::new(1));
        let flag = Arc::new(AtomicBool::new(false));

        // Background task that grabs the only permit and holds it for a while,
        // forcing the following tasks into the wait queues.
        let background_permits = permits.clone();
        tokio::spawn(async move {
            println!("BACKGROUND BEFORE");
            let _guard = background_permits.acquire(false).await.unwrap();
            println!("BACKGROUND AFTER");
            tokio::time::sleep(Duration::from_millis(100)).await;
            println!("BACKGROUND FINISH");
        });

        tokio::time::sleep(Duration::from_micros(10)).await;

        // Spawn an ordinary task that acquires a permit.
        let ordinary_permits = permits.clone();
        let ordinary_flag = flag.clone();
        let ordinary_task = tokio::spawn(async move {
            println!("ORDINARY BEFORE");
            let _guard = ordinary_permits.acquire(false).await.unwrap();
            println!("ORDINARY AFTER");
            // Flag must be fired by the priority task after the permit is acquired.
            assert!(ordinary_flag.load(Ordering::Acquire));
        });

        tokio::time::sleep(Duration::from_micros(10)).await;

        // Spawn the priority task last; it takes over the original handle.
        let priority_flag = flag.clone();
        let priority_task = tokio::spawn(async move {
            println!("PRIORITY BEFORE");
            let _guard = permits.acquire(true).await.unwrap();
            println!("PRIORITY");
            priority_flag.store(true, Ordering::Release);
        });

        ordinary_task.await.unwrap();
        priority_task.await.unwrap();
    }

    /// A task that keeps re-acquiring permits must not prevent a sibling
    /// future in the same `join!` from making progress.
    #[tokio::test(flavor = "multi_thread")]
    async fn priority_semaphore_is_fair() {
        let flag = AtomicBool::new(false);
        let permits = Arc::new(PrioritySemaphore::new(10));

        tokio::join!(
            non_cooperative_task(permits, &flag),
            poor_little_task(&flag),
        );
    }

    /// Acquires and releases a permit in a loop until `flag` is set.
    async fn non_cooperative_task(permits: Arc<PrioritySemaphore>, flag: &AtomicBool) {
        loop {
            if flag.load(Ordering::Acquire) {
                break;
            }

            let _permit = permits.acquire(false).await.unwrap();

            // NOTE: This yield is necessary to allow the other task to run.
            tokio::task::yield_now().await;
        }
    }

    /// Sets `flag` after one second.
    async fn poor_little_task(flag: &AtomicBool) {
        let delay = Duration::from_secs(1);
        tokio::time::sleep(delay).await;
        flag.store(true, Ordering::Release);
    }
}