tower_async/limit/
mod.rs

1//! A middleware that limits the number of in-flight requests.
2//!
3//! See [`Limit`].
4
5use tower_async_service::Service;
6
7use crate::BoxError;
8
9pub mod policy;
10pub use policy::{Policy, PolicyOutput};
11
12mod layer;
13pub use layer::LimitLayer;
14
/// Limit requests based on a policy.
///
/// Each request is first checked against `policy`. Only once the policy
/// allows it is the request forwarded to the wrapped `inner` service (see the
/// `Service` implementation for [`Limit`]).
///
/// `Limit` is `Clone` whenever both the inner service and the policy are.
#[derive(Debug, Clone)]
pub struct Limit<T, P> {
    /// The wrapped service that handles requests the policy lets through.
    inner: T,
    /// The policy consulted before every request.
    policy: P,
}

impl<T, P> Limit<T, P> {
    /// Creates a new [`Limit`] from a limit policy,
    /// wrapping the given service.
    pub fn new(inner: T, policy: P) -> Self {
        Limit { inner, policy }
    }
}
42
43impl<T, P, Request> Service<Request> for Limit<T, P>
44where
45    T: Service<Request>,
46    T::Error: Into<BoxError>,
47    P: policy::Policy<Request>,
48    P::Error: Into<BoxError>,
49{
50    type Response = T::Response;
51    type Error = BoxError;
52
53    async fn call(&self, request: Request) -> Result<Self::Response, Self::Error> {
54        let mut request = request;
55        loop {
56            match self.policy.check(&mut request).await {
57                policy::PolicyOutput::Ready(guard) => {
58                    let _ = guard;
59                    return self.inner.call(request).await.map_err(Into::into);
60                }
61                policy::PolicyOutput::Abort(err) => return Err(err.into()),
62                policy::PolicyOutput::Retry => (),
63            }
64        }
65    }
66}
67
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use crate::limit::policy::ConcurrentPolicy;
    use crate::service_fn;

    use super::*;

    use futures_util::future::join_all;
    use tower_async_layer::Layer;
    use tower_async_service::Service;

    /// With a concurrency limit of one, two overlapping calls must resolve as
    /// exactly one success (echoing its input) and exactly one failure.
    #[tokio::test]
    async fn test_limit() {
        // Echo handler; the sleep keeps each call in flight long enough for the
        // two calls to overlap.
        async fn handle_request<Request>(req: Request) -> Result<Request, Infallible> {
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
            Ok(req)
        }

        let layer: LimitLayer<ConcurrentPolicy<()>> = LimitLayer::new(ConcurrentPolicy::new(1));
        let first = layer.layer(service_fn(handle_request));
        let second = layer.layer(service_fn(handle_request));

        let outcomes = join_all(vec![first.call("Hello"), second.call("Hello")]).await;
        let (succeeded, failed): (Vec<_>, Vec<_>) =
            outcomes.into_iter().partition(|outcome| outcome.is_ok());

        let [Ok(echoed)] = succeeded.as_slice() else {
            panic!("expected exactly one request to succeed, got {}", succeeded.len());
        };
        assert_eq!(*echoed, "Hello");
        assert_eq!(failed.len(), 1, "expected exactly one request to fail");
    }
}