Skip to main content

tower_acc/
layer.rs

1use tower_layer::Layer;
2
3use crate::Algorithm;
4use crate::classifier::DefaultClassifier;
5use crate::service::ConcurrencyLimit;
6
7/// A [`Layer`] that wraps services with an adaptive [`ConcurrencyLimit`].
8///
9/// # Example
10///
11/// ```rust,no_run
12/// use tower::ServiceBuilder;
13/// use tower_acc::{ConcurrencyLimitLayer, Vegas};
14/// # fn wrap<S>(my_service: S) -> impl tower_service::Service<()>
15/// # where S: tower_service::Service<(), Error = std::convert::Infallible> {
16///
17/// let service = ServiceBuilder::new()
18///     .layer(ConcurrencyLimitLayer::new(Vegas::default()))
19///     .service(my_service);
20/// # service
21/// # }
22/// ```
23#[derive(Debug, Clone)]
24pub struct ConcurrencyLimitLayer<A, C = DefaultClassifier> {
25    algorithm: A,
26    classifier: C,
27}
28
29impl<A> ConcurrencyLimitLayer<A> {
30    /// Creates a new `ConcurrencyLimitLayer` with the given algorithm.
31    pub fn new(algorithm: A) -> Self {
32        Self {
33            algorithm,
34            classifier: DefaultClassifier,
35        }
36    }
37}
38
39impl<A, C> ConcurrencyLimitLayer<A, C> {
40    /// Creates a new `ConcurrencyLimitLayer` with the given algorithm and
41    /// [`Classifier`](crate::Classifier).
42    pub fn with_classifier(algorithm: A, classifier: C) -> Self {
43        Self {
44            algorithm,
45            classifier,
46        }
47    }
48}
49
50impl<S, A, C> Layer<S> for ConcurrencyLimitLayer<A, C>
51where
52    A: Algorithm + Clone,
53    C: Clone,
54{
55    type Service = ConcurrencyLimit<S, A, C>;
56
57    fn layer(&self, service: S) -> Self::Service {
58        ConcurrencyLimit::with_classifier(service, self.algorithm.clone(), self.classifier.clone())
59    }
60}
61
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use std::future::{Ready, poll_fn, ready};
    use std::task::{Context, Poll};
    use std::time::Duration;
    use tower_service::Service;

    /// Algorithm stub whose limit never changes; `update` ignores all feedback.
    #[derive(Clone, Debug)]
    struct FixedAlgorithm(usize);

    impl Algorithm for FixedAlgorithm {
        fn max_concurrency(&self) -> usize {
            self.0
        }

        fn update(&mut self, _: Duration, _: usize, _: bool, _: bool) {}
    }

    /// Service stub that answers every request with the request itself.
    #[derive(Clone, Debug)]
    struct EchoService;

    impl Service<&'static str> for EchoService {
        type Response = &'static str;
        type Error = Infallible;
        type Future = Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: &'static str) -> Self::Future {
            ready(Ok(req))
        }
    }

    /// Waits for `svc` to report readiness, then sends one request and awaits
    /// its response. An error from either step is returned to the caller.
    async fn ready_and_call<S>(svc: &mut S, req: &'static str) -> Result<S::Response, S::Error>
    where
        S: Service<&'static str>,
    {
        poll_fn(|cx| svc.poll_ready(cx)).await?;
        svc.call(req).await
    }

    #[test]
    fn layer_produces_concurrency_limit_service() {
        let svc = ConcurrencyLimitLayer::new(FixedAlgorithm(10)).layer(EchoService);
        // The annotated type is the real check: the inner service is EchoService.
        let inner: &EchoService = svc.get_ref();
        let rendered = format!("{inner:?}");
        assert!(rendered.contains("EchoService"));
    }

    #[tokio::test]
    async fn layered_service_forwards_requests() {
        let mut svc = ConcurrencyLimitLayer::new(FixedAlgorithm(10)).layer(EchoService);
        let resp = ready_and_call(&mut svc, "hello").await.unwrap();
        assert_eq!(resp, "hello");
    }

    #[test]
    fn layer_is_clone() {
        let original = ConcurrencyLimitLayer::new(FixedAlgorithm(5));
        let copy = original.clone();
        // The original and its clone must both be able to wrap a service.
        for layer in [&original, &copy] {
            let _ = layer.layer(EchoService);
        }
    }

    #[test]
    fn layer_is_debug() {
        let rendered = format!("{:?}", ConcurrencyLimitLayer::new(FixedAlgorithm(5)));
        assert!(rendered.contains("ConcurrencyLimitLayer"));
    }

    #[tokio::test]
    async fn layer_with_custom_classifier() {
        // Reports every outcome as a success.
        let never_error = |_: &Result<&str, Infallible>| false;
        let layer = ConcurrencyLimitLayer::with_classifier(FixedAlgorithm(10), never_error);
        let mut svc = layer.layer(EchoService);

        let resp = ready_and_call(&mut svc, "hello").await.unwrap();
        assert_eq!(resp, "hello");
    }
}