1use crate::Algorithm;
2use crate::controller::Controller;
3use crate::future::ResponseFuture;
4
5use tokio::sync::OwnedSemaphorePermit;
6use tokio_util::sync::PollSemaphore;
7use tower_service::Service;
8
9use std::{
10 sync::{Arc, Mutex},
11 task::{Context, Poll, ready},
12 time::Instant,
13};
14
/// Middleware that caps the number of in-flight requests to an inner service.
///
/// Each request must first obtain a permit from `semaphore` in `poll_ready`.
/// `call` then moves that permit into the returned response future, together
/// with a handle to the shared controller and the request's start time.
/// The semaphore is the one exposed by the controller built from the
/// user-supplied algorithm. The effective limit is therefore whatever that
/// controller configures; presumably the algorithm adjusts it over time
/// (TODO confirm in `controller.rs`).
pub struct ConcurrencyLimit<S, A> {
    /// The wrapped service that requests are forwarded to.
    inner: S,
    /// Controller built from the algorithm. It is shared (via `Arc`) with every
    /// clone of this service and with each response future.
    controller: Arc<Mutex<Controller<A>>>,
    /// Pollable handle to the controller's semaphore. Permits are acquired from
    /// it in `poll_ready`.
    semaphore: PollSemaphore,
    /// Permit reserved by `poll_ready` that has not yet been consumed by `call`.
    /// It is `None` for a freshly constructed or cloned instance.
    permit: Option<OwnedSemaphorePermit>,
}
35
36impl<S, A> ConcurrencyLimit<S, A>
37where
38 A: Algorithm,
39{
40 pub fn new(inner: S, algorithm: A) -> Self {
42 let controller = Controller::new(algorithm);
43 let semaphore = controller.semaphore();
44
45 Self {
46 inner,
47 controller: Arc::new(Mutex::new(controller)),
48 semaphore: PollSemaphore::new(semaphore),
49 permit: None,
50 }
51 }
52
53 pub fn get_ref(&self) -> &S {
55 &self.inner
56 }
57
58 pub fn get_mut(&mut self) -> &mut S {
60 &mut self.inner
61 }
62
63 pub fn into_inner(self) -> S {
65 self.inner
66 }
67}
68
69impl<S, A, Request> Service<Request> for ConcurrencyLimit<S, A>
70where
71 S: Service<Request>,
72 A: Algorithm,
73{
74 type Response = S::Response;
75 type Error = S::Error;
76 type Future = ResponseFuture<S::Future, A>;
77
78 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
79 if self.permit.is_none() {
80 self.permit = ready!(self.semaphore.poll_acquire(cx));
81 debug_assert!(self.permit.is_some(), "semaphore should never be closed");
82 }
83 self.inner.poll_ready(cx)
86 }
87
88 fn call(&mut self, request: Request) -> Self::Future {
89 let start = Instant::now();
90 let permit = self
92 .permit
93 .take()
94 .expect("`poll_ready` should be called first");
95
96 let future = self.inner.call(request);
98 ResponseFuture::new(future, self.controller.clone(), permit, start)
99 }
100}
101
102impl<S: Clone, A> Clone for ConcurrencyLimit<S, A> {
103 fn clone(&self) -> Self {
104 Self {
108 inner: self.inner.clone(),
109 controller: self.controller.clone(),
110 semaphore: self.semaphore.clone(),
111 permit: None,
112 }
113 }
114}
115
116impl<S: std::fmt::Debug, A> std::fmt::Debug for ConcurrencyLimit<S, A> {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 f.debug_struct("ConcurrencyLimit")
119 .field("inner", &self.inner)
120 .field("permit", &self.permit)
121 .finish_non_exhaustive()
122 }
123}