1use tower_layer::Layer;
2
3use crate::Algorithm;
4use crate::classifier::DefaultClassifier;
5use crate::service::ConcurrencyLimit;
6
7#[derive(Debug, Clone)]
24pub struct ConcurrencyLimitLayer<A, C = DefaultClassifier> {
25 algorithm: A,
26 classifier: C,
27}
28
29impl<A> ConcurrencyLimitLayer<A> {
30 pub fn new(algorithm: A) -> Self {
32 Self {
33 algorithm,
34 classifier: DefaultClassifier,
35 }
36 }
37}
38
39impl<A, C> ConcurrencyLimitLayer<A, C> {
40 pub fn with_classifier(algorithm: A, classifier: C) -> Self {
43 Self {
44 algorithm,
45 classifier,
46 }
47 }
48}
49
50impl<S, A, C> Layer<S> for ConcurrencyLimitLayer<A, C>
51where
52 A: Algorithm + Clone,
53 C: Clone,
54{
55 type Service = ConcurrencyLimit<S, A, C>;
56
57 fn layer(&self, service: S) -> Self::Service {
58 ConcurrencyLimit::with_classifier(service, self.algorithm.clone(), self.classifier.clone())
59 }
60}
61
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use std::future::{Ready, poll_fn, ready};
    use std::task::{Context, Poll};
    use std::time::Duration;
    use tower_service::Service;

    /// Test algorithm whose limit is fixed at construction time.
    #[derive(Clone, Debug)]
    struct FixedAlgorithm(usize);

    impl Algorithm for FixedAlgorithm {
        /// Returns the wrapped limit unchanged.
        fn max_concurrency(&self) -> usize {
            let Self(limit) = self;
            *limit
        }

        /// Ignores every sample, so the limit never moves.
        fn update(
            &mut self,
            _rtt: Duration,
            _num_inflight: usize,
            _is_error: bool,
            _is_canceled: bool,
        ) {
            // Intentionally a no-op.
        }
    }

    /// Inner service that is always ready and answers with the request itself.
    #[derive(Clone, Debug)]
    struct EchoService;

    impl Service<&'static str> for EchoService {
        type Response = &'static str;
        type Error = Infallible;
        type Future = Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: &'static str) -> Self::Future {
            let echoed = Ok(req);
            ready(echoed)
        }
    }

    /// The layer's output must expose the wrapped inner service via `get_ref`.
    #[test]
    fn layer_produces_concurrency_limit_service() {
        let limit_layer = ConcurrencyLimitLayer::new(FixedAlgorithm(10));
        let limited = limit_layer.layer(EchoService);
        let inner: &EchoService = limited.get_ref();

        let rendered = format!("{inner:?}");
        assert!(rendered.contains("EchoService"));
    }

    /// A request sent through the layered service reaches the inner service.
    #[tokio::test]
    async fn layered_service_forwards_requests() {
        let limit_layer = ConcurrencyLimitLayer::new(FixedAlgorithm(10));
        let mut limited = limit_layer.layer(EchoService);

        // Drive the service to readiness before issuing the call.
        poll_fn(|cx| limited.poll_ready(cx)).await.unwrap();
        let reply = limited.call("hello").await.unwrap();
        assert_eq!(reply, "hello");
    }

    /// Both the original layer and its clone can be used to wrap a service.
    #[test]
    fn layer_is_clone() {
        let original = ConcurrencyLimitLayer::new(FixedAlgorithm(5));
        let duplicate = original.clone();

        let _ = original.layer(EchoService);
        let _ = duplicate.layer(EchoService);
    }

    /// The derived `Debug` output names the layer type.
    #[test]
    fn layer_is_debug() {
        let limit_layer = ConcurrencyLimitLayer::new(FixedAlgorithm(5));
        let rendered = format!("{limit_layer:?}");
        assert!(rendered.contains("ConcurrencyLimitLayer"));
    }

    /// A layer built with a caller-supplied classifier still forwards requests.
    #[tokio::test]
    async fn layer_with_custom_classifier() {
        // Closure classifier that answers `false` for every result.
        let never = |_result: &Result<&str, Infallible>| false;

        let limit_layer = ConcurrencyLimitLayer::with_classifier(FixedAlgorithm(10), never);
        let mut limited = limit_layer.layer(EchoService);

        // Drive the service to readiness before issuing the call.
        poll_fn(|cx| limited.poll_ready(cx)).await.unwrap();
        let reply = limited.call("hello").await.unwrap();
        assert_eq!(reply, "hello");
    }
}