1use std::sync::atomic::Ordering::SeqCst;
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::sync::{Arc, Mutex, MutexGuard};
4
5use more_asserts::debug_assert_le;
6use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore};
7
8const PERMIT_LIMIT: u64 = {
11 let sem_max = Semaphore::MAX_PERMITS as u64;
12 let u32_max = u32::MAX as u64;
13 if sem_max < u32_max { sem_max } else { u32_max }
14};
15
/// A semaphore whose total capacity can be raised or lowered at runtime within a fixed range.
///
/// Callers work in *logical* permits. Internally every count is kept in *physical* permits,
/// where one physical permit stands for `basis` logical permits; logical amounts are rounded
/// up to whole physical permits.
#[derive(Debug)]
pub struct AdjustableSemaphore {
    /// Underlying tokio semaphore; it counts physical permits.
    semaphore: Arc<Semaphore>,
    /// Current total capacity, in physical permits.
    total_permits: AtomicU64,
    /// Physical permits that have already been subtracted from `total_permits` but could not be
    /// taken out of `semaphore` at the time. They are absorbed when permits are dropped, or
    /// cancelled by a later capacity increase.
    enqueued_permit_decreases: AtomicU64,
    /// Lower bound for `total_permits`, in physical permits.
    min_physical_permits: u64,
    /// Upper bound for `total_permits`, in physical permits.
    max_physical_permits: u64,
    /// Number of logical permits represented by one physical permit.
    basis: u64,

    /// Held for the duration of every capacity increase or decrease so that adjustments do not
    /// interleave with each other. It guards no data of its own.
    adjustment_lock: Mutex<()>,
}
41
42pub struct AdjustableSemaphorePermit {
45 permit: Option<OwnedSemaphorePermit>,
46 num_physical_permits: u32,
47 parent: Arc<AdjustableSemaphore>,
48}
49
50impl AdjustableSemaphorePermit {
51 pub fn num_permits(&self) -> u64 {
53 self.num_physical_permits as u64 * self.parent.basis
54 }
55
56 pub fn num_physical_permits(&self) -> u32 {
58 self.num_physical_permits
59 }
60
61 pub fn split(&mut self, n: u64) -> Option<AdjustableSemaphorePermit> {
65 let physical_n = n.div_ceil(self.parent.basis);
66 if physical_n > self.num_physical_permits as u64 {
67 return None;
68 }
69 let physical_n = physical_n as u32;
70
71 self.num_physical_permits -= physical_n;
72
73 if physical_n > 0 {
74 let permit = self.permit.as_mut().and_then(|p| p.split(physical_n as usize));
75 Some(AdjustableSemaphorePermit {
76 permit,
77 num_physical_permits: physical_n,
78 parent: self.parent.clone(),
79 })
80 } else {
81 None
82 }
83 }
84}
85
86impl Drop for AdjustableSemaphorePermit {
87 fn drop(&mut self) {
88 let parent = &self.parent;
89 let num_permits = self.num_physical_permits as u64;
90
91 let decreases_resolved = attempt_sub(&parent.enqueued_permit_decreases, num_permits, 0);
92
93 if let Some(mut permit) = self.permit.take() {
94 if decreases_resolved > 0 {
95 if let Some(p) = permit.split(decreases_resolved as usize) {
98 p.forget();
99 } else {
100 debug_assert!(false, "Failed to split permit; mismatch in self.num_permits.");
101 }
102 }
103 } else {
104 debug_assert_le!(decreases_resolved, num_permits);
106
107 let to_return = (num_permits - decreases_resolved) as usize;
110 if to_return > 0 {
111 parent.semaphore.add_permits(to_return);
112 }
113 }
114 }
115}
116
117impl AdjustableSemaphore {
118 pub fn new(initial_permits: u64, permit_range: (u64, u64)) -> Arc<Self> {
119 debug_assert!(permit_range.0 <= permit_range.1);
120 debug_assert!(permit_range.0 <= initial_permits);
121 debug_assert!(initial_permits <= permit_range.1);
122
123 let basis = Self::compute_basis(permit_range.1);
124 let min_physical = permit_range.0.div_ceil(basis);
125 let max_physical = permit_range.1.div_ceil(basis);
126 let initial_physical = initial_permits.div_ceil(basis).clamp(min_physical, max_physical);
127
128 Arc::new(Self {
129 semaphore: Arc::new(Semaphore::new(initial_physical as usize)),
130 total_permits: AtomicU64::new(initial_physical),
131 enqueued_permit_decreases: AtomicU64::new(0),
132 min_physical_permits: min_physical,
133 max_physical_permits: max_physical,
134 basis,
135 adjustment_lock: Mutex::new(()),
136 })
137 }
138
139 pub fn total_permits(&self) -> u64 {
140 self.total_permits.load(Ordering::Relaxed) * self.basis
141 }
142
143 pub fn available_permits(&self) -> u64 {
144 self.semaphore.available_permits() as u64 * self.basis
145 }
146
147 pub fn active_permits(&self) -> u64 {
148 (self.total_permits.load(Ordering::Relaxed) + self.enqueued_permit_decreases.load(Ordering::Relaxed))
149 .saturating_sub(self.semaphore.available_permits() as u64)
150 * self.basis
151 }
152
153 pub fn basis(&self) -> u64 {
156 self.basis
157 }
158
159 pub async fn acquire(self: &Arc<Self>) -> Result<AdjustableSemaphorePermit, AcquireError> {
161 self.acquire_many(1).await
162 }
163
164 pub async fn acquire_many(self: &Arc<Self>, n: u64) -> Result<AdjustableSemaphorePermit, AcquireError> {
168 let physical = self.to_physical_acquire(n);
169 let permit = self.semaphore.clone().acquire_many_owned(physical).await?;
170 Ok(AdjustableSemaphorePermit {
171 permit: Some(permit),
172 num_physical_permits: physical,
173 parent: self.clone(),
174 })
175 }
176
177 pub fn decrement_total_permits(&self, n: u64) -> Option<u64> {
184 let lock = self.adjustment_lock.lock().unwrap();
185 self.decrement_total_permits_impl(lock, n)
186 }
187
188 pub fn decrement_permits_to_target(&self, target: u64) -> Option<u64> {
196 let lock = self.adjustment_lock.lock().unwrap();
197 let current = self.total_permits();
198 if target >= current {
199 return None;
200 }
201 let requested_decrease = current - target;
202 self.decrement_total_permits_impl(lock, requested_decrease)
203 }
204
205 fn decrement_total_permits_impl(&self, _lock: MutexGuard<'_, ()>, n: u64) -> Option<u64> {
206 let physical_n = n.div_ceil(self.basis);
207 if physical_n == 0 {
208 return None;
209 }
210
211 let removed = attempt_sub(&self.total_permits, physical_n, self.min_physical_permits);
212 if removed == 0 {
213 return None;
214 }
215
216 if let Ok(permit) = self.semaphore.clone().try_acquire_many_owned(removed as u32) {
217 permit.forget();
218 } else {
219 self.enqueued_permit_decreases.fetch_add(removed, Ordering::Relaxed);
220 }
221
222 Some(removed * self.basis)
223 }
224
225 pub fn increment_total_permits(self: &Arc<Self>, n: u64) -> Option<AdjustableSemaphorePermit> {
235 let lock = self.adjustment_lock.lock().unwrap();
236 self.increment_total_permits_impl(lock, n)
237 }
238
239 pub fn increment_permits_to_target(self: &Arc<Self>, target: u64) -> Option<AdjustableSemaphorePermit> {
247 let lock = self.adjustment_lock.lock().unwrap();
248 let current = self.total_permits();
249 if target <= current {
250 return None;
251 }
252 self.increment_total_permits_impl(lock, target - current)
253 }
254
255 fn increment_total_permits_impl(
256 self: &Arc<Self>,
257 _lock: MutexGuard<'_, ()>,
258 n: u64,
259 ) -> Option<AdjustableSemaphorePermit> {
260 let physical_n = n.div_ceil(self.basis);
261 if physical_n == 0 {
262 return None;
263 }
264
265 let added = attempt_add(&self.total_permits, physical_n, self.max_physical_permits);
266 if added == 0 {
267 return None;
268 }
269
270 let cancelled = attempt_sub(&self.enqueued_permit_decreases, added, 0);
271 let to_hold = (added - cancelled) as u32;
272
273 Some(AdjustableSemaphorePermit {
274 permit: None,
275 num_physical_permits: to_hold,
276 parent: self.clone(),
277 })
278 }
279
280 fn compute_basis(max_permits: u64) -> u64 {
283 let mut basis: u64 = 1;
284 while max_permits.div_ceil(basis) > PERMIT_LIMIT {
285 basis *= 2;
286 }
287 basis
288 }
289
290 fn to_physical_acquire(&self, n: u64) -> u32 {
293 let total = self.total_permits.load(Ordering::Relaxed).max(1);
294 n.div_ceil(self.basis).clamp(1, total) as u32
295 }
296
297 #[cfg(test)]
300 fn with_forced_basis(initial: u64, min: u64, max: u64, basis: u64) -> Arc<Self> {
301 assert!(basis > 0, "basis must be greater than zero");
302 let min_physical_permits = min.div_ceil(basis);
303 let max_physical_permits = max.div_ceil(basis).min(PERMIT_LIMIT);
304 let initial_physical = initial.div_ceil(basis).clamp(min_physical_permits, max_physical_permits);
305
306 Arc::new(Self {
307 semaphore: Arc::new(Semaphore::new(initial_physical as usize)),
308 total_permits: AtomicU64::new(initial_physical),
309 enqueued_permit_decreases: AtomicU64::new(0),
310 min_physical_permits,
311 max_physical_permits,
312 basis,
313 adjustment_lock: Mutex::new(()),
314 })
315 }
316}
317
/// Atomically adds up to `n` to `v` without letting it exceed `max_value`.
///
/// Returns the amount actually added. Nothing is written, and `0` is returned, when `v` is
/// already at or above `max_value`.
#[inline]
fn attempt_add(v: &AtomicU64, n: u64, max_value: u64) -> u64 {
    // Value `v` should hold after the addition, given what it held before.
    let raised = |current: u64| current.saturating_add(n).min(max_value);

    v.fetch_update(SeqCst, SeqCst, |current| (current < max_value).then(|| raised(current)))
        .map_or(0, |previous| raised(previous) - previous)
}
333
/// Atomically subtracts up to `n` from `v` without letting it drop below `min_value`.
///
/// Returns the amount actually subtracted. Nothing is written, and `0` is returned, when `v`
/// is already at or below `min_value`.
#[inline]
fn attempt_sub(v: &AtomicU64, n: u64, min_value: u64) -> u64 {
    // Value `v` should hold after the subtraction, given what it held before.
    let lowered = |current: u64| current.saturating_sub(n).max(min_value);

    v.fetch_update(SeqCst, SeqCst, |current| (current > min_value).then(|| lowered(current)))
        .map_or(0, |previous| previous - lowered(previous))
}
349
#[cfg(test)]
mod tests {

    use std::time::Duration;

    use more_asserts::{assert_ge, assert_le};
    use rand::prelude::*;
    use tokio::sync::Barrier;
    use tokio::task::JoinSet;

    use super::*;

    /// Bases used by the deterministic tests: one-to-one, and two logical permits per physical.
    const BASES: [u64; 2] = [1, 2];

    /// Capacity adjustments stay inside the configured range and report what they changed.
    #[tokio::test]
    async fn test_bounds_and_adjustment() {
        for basis in BASES {
            let sem = AdjustableSemaphore::with_forced_basis(6, 2, 12, basis);
            assert_eq!(sem.total_permits(), 6);

            // Increments stop at the maximum of 12.
            assert!(sem.increment_total_permits(4).is_some());
            assert_eq!(sem.total_permits(), 10);

            assert!(sem.increment_total_permits(100).is_some());
            assert_eq!(sem.total_permits(), 12);
            assert!(sem.increment_total_permits(2).is_none());

            // Decrements stop at the minimum of 2.
            assert_eq!(sem.decrement_total_permits(4), Some(4));
            assert_eq!(sem.total_permits(), 8);
            assert_eq!(sem.decrement_total_permits(100), Some(6));
            assert_eq!(sem.total_permits(), 2);
            assert!(sem.decrement_total_permits(2).is_none());

            assert!(sem.increment_total_permits(4).is_some());
            assert_eq!(sem.total_permits(), 6);

            // Target-based variants; reaching the target again is a no-op.
            assert!(sem.increment_permits_to_target(10).is_some());
            assert_eq!(sem.total_permits(), 10);
            assert!(sem.increment_permits_to_target(10).is_none());

            assert_eq!(sem.decrement_permits_to_target(6), Some(4));
            assert_eq!(sem.total_permits(), 6);
            assert!(sem.decrement_permits_to_target(6).is_none());

            // A target below the minimum only goes as far as the minimum.
            assert_eq!(sem.decrement_permits_to_target(0), Some(4));
            assert_eq!(sem.total_permits(), 2);

            assert!(sem.increment_permits_to_target(12).is_some());
            assert_eq!(sem.total_permits(), 12);
        }
    }

    /// Acquiring lowers the available count and dropping a permit restores it.
    #[tokio::test]
    async fn test_acquire_and_release() {
        for basis in BASES {
            let sem = AdjustableSemaphore::with_forced_basis(1024, 0, 1024, basis);
            assert_eq!(sem.available_permits(), 1024);

            let quarter = sem.acquire_many(256).await.unwrap();
            assert_eq!(quarter.num_permits(), 256);
            assert_eq!(sem.available_permits(), 768);

            let half = sem.acquire_many(512).await.unwrap();
            assert_eq!(sem.available_permits(), 256);

            drop(quarter);
            assert_eq!(sem.available_permits(), 512);
            drop(half);
            assert_eq!(sem.available_permits(), 1024);

            {
                let _everything = sem.acquire_many(1024).await.unwrap();
                assert_eq!(sem.available_permits(), 0);
            }
            assert_eq!(sem.available_permits(), 1024);

            // A request beyond the total capacity is clamped to the total.
            let _oversized = sem.acquire_many(5000).await.unwrap();
            assert_eq!(sem.available_permits(), 0);
        }
    }

    /// A decrease issued while every permit is held is settled as the permits are dropped.
    #[tokio::test]
    async fn test_enqueued_decrease_resolution() {
        for basis in BASES {
            let sem = AdjustableSemaphore::with_forced_basis(4, 2, 6, basis);

            let first = sem.acquire_many(2).await.unwrap();
            let second = sem.acquire_many(2).await.unwrap();
            assert_eq!(sem.available_permits(), 0);

            assert!(sem.decrement_total_permits(2).is_some());
            assert_eq!(sem.total_permits(), 2);

            // The first release is consumed entirely by the pending decrease.
            drop(first);
            assert_eq!(sem.available_permits(), 0);

            drop(second);
            assert_eq!(sem.available_permits(), 2);

            // One permit covering the whole capacity: half of it pays off the decrease.
            let sem = AdjustableSemaphore::with_forced_basis(1024, 0, 1024, basis);
            let whole = sem.acquire_many(1024).await.unwrap();
            assert!(sem.decrement_total_permits(512).is_some());
            assert_eq!(sem.total_permits(), 512);

            drop(whole);
            assert_eq!(sem.available_permits(), 512);
        }
    }

    /// An increase cancels a pending decrease instead of handing out new permits.
    #[tokio::test]
    async fn test_increment_cancels_enqueued() {
        for basis in BASES {
            let sem = AdjustableSemaphore::with_forced_basis(4, 0, 10, basis);

            let first = sem.acquire_many(2).await.unwrap();
            let second = sem.acquire_many(2).await.unwrap();

            assert!(sem.decrement_total_permits(2).is_some());
            assert_eq!(sem.total_permits(), 2);

            // The added capacity is used up by the pending decrease, so the handle is empty.
            let cancelling = sem.increment_total_permits(2).unwrap();
            assert_eq!(cancelling.num_permits(), 0);
            assert_eq!(sem.total_permits(), 4);
            drop(cancelling);

            drop(first);
            assert_eq!(sem.available_permits(), 2);
            drop(second);
            assert_eq!(sem.available_permits(), 4);
        }
    }

    /// Permits added by an increase only become available once the returned handle is dropped.
    #[tokio::test]
    async fn test_virtual_permit() {
        for basis in BASES {
            let sem = AdjustableSemaphore::with_forced_basis(4, 0, 20, basis);

            let added = sem.increment_total_permits(6).unwrap();
            assert_eq!(sem.total_permits(), 10);
            assert_eq!(added.num_permits(), 6);
            assert_eq!(sem.available_permits(), 4);

            drop(added);
            assert_eq!(sem.available_permits(), 10);

            // Grow an empty semaphore step by step, acquiring everything that gets added.
            let sem = AdjustableSemaphore::with_forced_basis(0, 0, 22, basis);
            let mut held = Vec::new();
            for step in 0..10u64 {
                assert_eq!(sem.available_permits(), 0);
                assert_eq!(sem.total_permits(), step * 2);
                // The returned handle is dropped right away, releasing the new permits.
                sem.increment_total_permits(2);
                held.push(sem.acquire_many(2).await.unwrap());
            }
            for step in 0..10u64 {
                assert_eq!(sem.available_permits(), step * 2);
                held.pop();
            }
        }
    }

    /// Splitting an acquired permit moves permits between handles without losing any.
    #[tokio::test]
    async fn test_permit_split() {
        for basis in BASES {
            let sem = AdjustableSemaphore::with_forced_basis(10, 0, 10, basis);

            // Partial split.
            let mut source = sem.acquire_many(6).await.unwrap();
            let piece = source.split(2).unwrap();
            assert_eq!(source.num_permits(), 4);
            assert_eq!(piece.num_permits(), 2);
            drop(piece);
            assert_eq!(sem.available_permits(), 6);
            drop(source);
            assert_eq!(sem.available_permits(), 10);

            // Splitting everything off leaves an empty source handle.
            let mut source = sem.acquire_many(6).await.unwrap();
            let piece = source.split(6).unwrap();
            assert_eq!(source.num_permits(), 0);
            assert_eq!(piece.num_permits(), 6);
            drop(source);
            assert_eq!(sem.available_permits(), 4);
            drop(piece);
            assert_eq!(sem.available_permits(), 10);

            // Asking for more than is held fails and changes nothing.
            let mut source = sem.acquire_many(4).await.unwrap();
            assert!(source.split(6).is_none());
            assert_eq!(source.num_permits(), 4);
            drop(source);
        }
    }

    /// Handles returned by an increase can be split as well.
    #[tokio::test]
    async fn test_virtual_permit_split() {
        for basis in BASES {
            let sem = AdjustableSemaphore::with_forced_basis(4, 0, 20, basis);

            let mut added = sem.increment_total_permits(8).unwrap();
            assert_eq!(sem.total_permits(), 12);
            assert_eq!(sem.available_permits(), 4);
            assert_eq!(added.num_permits(), 8);

            let piece = added.split(2).unwrap();
            assert_eq!(added.num_permits(), 6);
            assert_eq!(piece.num_permits(), 2);

            drop(piece);
            assert_eq!(sem.available_permits(), 6);

            drop(added);
            assert_eq!(sem.available_permits(), 12);
        }
    }

    /// The basis stays at 1 until the maximum exceeds `PERMIT_LIMIT`, then doubles.
    #[test]
    fn test_basis_computation() {
        assert_eq!(AdjustableSemaphore::new(1024, (0, 1024)).basis(), 1);
        assert_eq!(AdjustableSemaphore::new(PERMIT_LIMIT, (0, PERMIT_LIMIT)).basis(), 1);
        assert_eq!(AdjustableSemaphore::new(PERMIT_LIMIT + 1, (0, PERMIT_LIMIT + 1)).basis(), 2);
    }

    /// Amounts that are not a multiple of the basis are rounded up to the next multiple.
    #[test]
    fn test_forced_basis_rounding() {
        // 1000 with a basis of 300 needs four physical permits, reported as 1200.
        let sem = AdjustableSemaphore::with_forced_basis(1000, 0, 1000, 300);
        assert_eq!(sem.total_permits(), 1200);

        // 900 is an exact multiple of 300.
        let sem = AdjustableSemaphore::with_forced_basis(900, 0, 900, 300);
        assert_eq!(sem.total_permits(), 900);
    }

    /// Logical requests are rounded up to whole physical permits in every operation.
    #[tokio::test]
    async fn test_rounding_and_physical_permits() {
        // One logical permit still costs one whole physical permit.
        let sem = AdjustableSemaphore::with_forced_basis(1024, 0, 1024, 256);
        let single = sem.acquire_many(1).await.unwrap();
        assert_eq!(single.num_permits(), 256);
        assert_eq!(single.num_physical_permits(), 1);
        assert_eq!(sem.available_permits(), 768);
        drop(single);

        // 250 rounds up to three physical permits of 100.
        let sem = AdjustableSemaphore::with_forced_basis(1000, 0, 1000, 100);
        let rounded = sem.acquire_many(250).await.unwrap();
        assert_eq!(rounded.num_permits(), 300);
        assert_eq!(rounded.num_physical_permits(), 3);
        drop(rounded);

        // An increase by an exact multiple of the basis.
        let sem = AdjustableSemaphore::with_forced_basis(1024, 0, 2048, 256);
        let added = sem.increment_total_permits(512).unwrap();
        assert_eq!(added.num_permits(), 512);
        assert_eq!(added.num_physical_permits(), 2);
        drop(added);

        // Splitting one logical permit takes one whole physical permit.
        let sem = AdjustableSemaphore::with_forced_basis(500, 0, 500, 100);
        let mut source = sem.acquire_many(500).await.unwrap();
        let piece = source.split(1).unwrap();
        assert_eq!(piece.num_permits(), 100);
        assert_eq!(piece.num_physical_permits(), 1);
        assert_eq!(source.num_permits(), 400);
        assert_eq!(source.num_physical_permits(), 4);
        drop(piece);
        drop(source);

        // A decrease cannot go below the minimum, however small the request.
        let sem = AdjustableSemaphore::with_forced_basis(500, 300, 500, 100);
        assert!(sem.decrement_total_permits(300).is_some());
        assert_eq!(sem.total_permits(), 300);
        assert!(sem.decrement_total_permits(1).is_none());
    }

    /// A semaphore whose range is `(0, 0)` reports no capacity at all.
    #[test]
    fn test_zero_capacity() {
        let sem = AdjustableSemaphore::new(0, (0, 0));
        assert_eq!(sem.total_permits(), 0);
        assert_eq!(sem.available_permits(), 0);
    }

    /// Many tasks adjust, acquire and release single permits concurrently; the capacity must
    /// stay in range throughout and every permit must be back once all tasks have finished.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    #[cfg_attr(feature = "smoke-test", ignore)]
    async fn test_concurrent_stress() {
        const TASKS: usize = 50;
        const OPS_PER_TASK: usize = 1000;

        const MIN_PERMITS: u64 = 10;
        const MAX_PERMITS: u64 = 50;

        let sem = AdjustableSemaphore::new(30, (MIN_PERMITS, MAX_PERMITS));

        let mut workers = JoinSet::new();
        // One extra slot so this test body releases all workers at the same moment.
        let start_line = Arc::new(Barrier::new(TASKS + 1));

        for task_id in 0..TASKS {
            let sem = sem.clone();
            let mut rng = SmallRng::seed_from_u64(task_id as u64);
            let start_line = start_line.clone();

            workers.spawn(async move {
                start_line.wait().await;
                for _ in 0..OPS_PER_TASK {
                    if rng.random_bool(0.1) {
                        sem.increment_total_permits(1);
                    }

                    if rng.random_bool(0.1) {
                        let _ = sem.decrement_total_permits(1);
                    }

                    let held = sem.acquire().await;
                    tokio::time::sleep(Duration::from_micros(100)).await;
                    drop(held);

                    assert!(sem.total_permits() >= MIN_PERMITS);
                    assert!(sem.total_permits() <= MAX_PERMITS);
                    assert!(sem.available_permits() <= MAX_PERMITS);
                }
            });
        }

        start_line.wait().await;

        workers.join_all().await;

        let final_permits = sem.total_permits();
        assert_le!(final_permits, MAX_PERMITS);
        assert_ge!(final_permits, MIN_PERMITS);
        let avail_permits = sem.available_permits();
        assert_eq!(avail_permits, final_permits);
    }

    /// Same as the single-permit stress test, but with randomly sized adjustments and acquires.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    #[cfg_attr(feature = "smoke-test", ignore)]
    async fn test_concurrent_stress_acquire_many() {
        const TASKS: usize = 30;
        const OPS_PER_TASK: usize = 500;

        const MIN_PERMITS: u64 = 100;
        const MAX_PERMITS: u64 = 500;

        let sem = AdjustableSemaphore::new(300, (MIN_PERMITS, MAX_PERMITS));

        let mut workers = JoinSet::new();
        let start_line = Arc::new(Barrier::new(TASKS + 1));

        for task_id in 0..TASKS {
            let sem = sem.clone();
            let mut rng = SmallRng::seed_from_u64(task_id as u64);
            let start_line = start_line.clone();

            workers.spawn(async move {
                start_line.wait().await;
                for _ in 0..OPS_PER_TASK {
                    if rng.random_bool(0.05) {
                        sem.increment_total_permits(rng.random_range(1..=10));
                    }

                    if rng.random_bool(0.05) {
                        let _ = sem.decrement_total_permits(rng.random_range(1..=10));
                    }

                    let amount = rng.random_range(1..=50);
                    let held = sem.acquire_many(amount).await;
                    tokio::time::sleep(Duration::from_micros(50)).await;
                    drop(held);

                    assert!(sem.total_permits() >= MIN_PERMITS);
                    assert!(sem.total_permits() <= MAX_PERMITS);
                }
            });
        }

        start_line.wait().await;

        workers.join_all().await;

        let final_permits = sem.total_permits();
        assert_le!(final_permits, MAX_PERMITS);
        assert_ge!(final_permits, MIN_PERMITS);
    }
}