Skip to main content

tower_acc/
service.rs

1use crate::Algorithm;
2use crate::classifier::{Classifier, DefaultClassifier};
3use crate::controller::Controller;
4use crate::future::ResponseFuture;
5
6use tokio::sync::OwnedSemaphorePermit;
7use tokio_util::sync::PollSemaphore;
8use tower_service::Service;
9
10use std::{
11    sync::{Arc, Mutex},
12    task::{Context, Poll, ready},
13    time::Instant,
14};
15
16/// Enforces an adaptive limit on the concurrent number of requests the
17/// underlying service can handle.
18///
19/// Unlike a static concurrency limit, `ConcurrencyLimit` continuously observes
20/// request latency and adjusts the number of allowed in-flight requests using
21/// the configured [`Algorithm`].
22///
23/// Use [`ConcurrencyLimitLayer`](crate::ConcurrencyLimitLayer) to integrate
24/// with `tower::ServiceBuilder`.
25pub struct ConcurrencyLimit<S, A, C = DefaultClassifier> {
26    inner: S,
27    classifier: C,
28    controller: Arc<Mutex<Controller<A>>>,
29    semaphore: PollSemaphore,
30    /// The currently acquired semaphore permit, if there is sufficient
31    /// concurrency to send a new request.
32    ///
33    /// The permit is acquired in `poll_ready`, and taken in `call` when sending
34    /// a new request.
35    permit: Option<OwnedSemaphorePermit>,
36}
37
38impl<S, A> ConcurrencyLimit<S, A>
39where
40    A: Algorithm,
41{
42    /// Creates a new concurrency limiter.
43    pub fn new(inner: S, algorithm: A) -> Self {
44        Self::with_classifier(inner, algorithm, DefaultClassifier)
45    }
46}
47
48impl<S, A, C> ConcurrencyLimit<S, A, C>
49where
50    A: Algorithm,
51{
52    /// Creates a new concurrency limiter with a custom [`Classifier`].
53    pub fn with_classifier(inner: S, algorithm: A, classifier: C) -> Self {
54        let controller = Controller::new(algorithm);
55        let semaphore = controller.semaphore();
56
57        Self {
58            inner,
59            classifier,
60            controller: Arc::new(Mutex::new(controller)),
61            semaphore: PollSemaphore::new(semaphore),
62            permit: None,
63        }
64    }
65
66    /// Gets a reference to the inner service.
67    pub fn get_ref(&self) -> &S {
68        &self.inner
69    }
70
71    /// Gets a mutable reference to the inner service.
72    pub fn get_mut(&mut self) -> &mut S {
73        &mut self.inner
74    }
75
76    /// Consumes `self`, returning the inner service.
77    pub fn into_inner(self) -> S {
78        self.inner
79    }
80}
81
82impl<S, A, C, Request> Service<Request> for ConcurrencyLimit<S, A, C>
83where
84    S: Service<Request>,
85    A: Algorithm,
86    C: Classifier<S::Response, S::Error> + Clone,
87{
88    type Response = S::Response;
89    type Error = S::Error;
90    type Future = ResponseFuture<S::Future, A, C>;
91
92    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
93        if self.permit.is_none() {
94            self.permit = ready!(self.semaphore.poll_acquire(cx));
95            debug_assert!(self.permit.is_some(), "semaphore should never be closed");
96        }
97        // Once we've acquired a permit (or if we already had one), poll the
98        // inner service.
99        self.inner.poll_ready(cx)
100    }
101
102    fn call(&mut self, request: Request) -> Self::Future {
103        let start = Instant::now();
104        // Take the permit
105        let permit = self
106            .permit
107            .take()
108            .expect("`poll_ready` should be called first");
109
110        // Call the inner service
111        let future = self.inner.call(request);
112        ResponseFuture::new(
113            future,
114            self.controller.clone(),
115            permit,
116            start,
117            self.classifier.clone(),
118        )
119    }
120}
121
122impl<S: Clone, A, C: Clone> Clone for ConcurrencyLimit<S, A, C> {
123    fn clone(&self) -> Self {
124        // Since we hold an `OwnedSemaphorePermit`, we can't derive `Clone`.
125        // Instead, when cloning the service, create a new service with the
126        // same semaphore, but with the permit in the un-acquired state.
127        Self {
128            inner: self.inner.clone(),
129            classifier: self.classifier.clone(),
130            controller: self.controller.clone(),
131            semaphore: self.semaphore.clone(),
132            permit: None,
133        }
134    }
135}
136
137impl<S: std::fmt::Debug, A, C> std::fmt::Debug for ConcurrencyLimit<S, A, C> {
138    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
139        f.debug_struct("ConcurrencyLimit")
140            .field("inner", &self.inner)
141            .field("permit", &self.permit)
142            .finish_non_exhaustive()
143    }
144}