tower_http/metrics/
in_flight_requests.rs1use http::{Request, Response};
55use http_body::Body;
56use pin_project_lite::pin_project;
57use std::{
58 future::Future,
59 pin::Pin,
60 sync::{
61 atomic::{AtomicUsize, Ordering},
62 Arc,
63 },
64 task::{ready, Context, Poll},
65 time::Duration,
66};
67use tower_layer::Layer;
68use tower_service::Service;
69
70#[derive(Clone, Debug)]
74pub struct InFlightRequestsLayer {
75 counter: InFlightRequestsCounter,
76}
77
78impl InFlightRequestsLayer {
79 pub fn pair() -> (Self, InFlightRequestsCounter) {
81 let counter = InFlightRequestsCounter::new();
82 let layer = Self::new(counter.clone());
83 (layer, counter)
84 }
85
86 pub fn new(counter: InFlightRequestsCounter) -> Self {
88 Self { counter }
89 }
90}
91
92impl<S> Layer<S> for InFlightRequestsLayer {
93 type Service = InFlightRequests<S>;
94
95 fn layer(&self, inner: S) -> Self::Service {
96 InFlightRequests {
97 inner,
98 counter: self.counter.clone(),
99 }
100 }
101}
102
103#[derive(Clone, Debug)]
107pub struct InFlightRequests<S> {
108 inner: S,
109 counter: InFlightRequestsCounter,
110}
111
112impl<S> InFlightRequests<S> {
113 pub fn pair(inner: S) -> (Self, InFlightRequestsCounter) {
115 let counter = InFlightRequestsCounter::new();
116 let service = Self::new(inner, counter.clone());
117 (service, counter)
118 }
119
120 pub fn new(inner: S, counter: InFlightRequestsCounter) -> Self {
122 Self { inner, counter }
123 }
124
125 define_inner_service_accessors!();
126}
127
/// Shared, cloneable count of the requests currently in flight.
///
/// All clones point at the same atomic value. A clone kept outside the
/// middleware therefore observes every increment and decrement it performs.
#[derive(Clone, Debug, Default)]
pub struct InFlightRequestsCounter {
    count: Arc<AtomicUsize>,
}
136
137impl InFlightRequestsCounter {
138 pub fn new() -> Self {
140 Self::default()
141 }
142
143 pub fn get(&self) -> usize {
145 self.count.load(Ordering::Relaxed)
146 }
147
148 fn increment(&self) -> IncrementGuard {
149 self.count.fetch_add(1, Ordering::Relaxed);
150 IncrementGuard {
151 count: self.count.clone(),
152 }
153 }
154
155 pub async fn run_emitter<F, Fut>(mut self, interval: Duration, mut emit: F)
174 where
175 F: FnMut(usize) -> Fut + Send + 'static,
176 Fut: Future<Output = ()> + Send,
177 {
178 let mut interval = tokio::time::interval(interval);
179
180 loop {
181 match Arc::try_unwrap(self.count) {
183 Ok(_) => return,
184 Err(shared_count) => {
185 self = Self {
186 count: shared_count,
187 }
188 }
189 }
190
191 interval.tick().await;
192 emit(self.get()).await;
193 }
194 }
195}
196
/// RAII guard representing one in-flight request.
///
/// `InFlightRequestsCounter::increment` creates it after bumping the shared
/// count. Dropping the guard decrements that count again, so discarding it
/// right away would silently undo the increment.
#[derive(Debug)]
#[must_use = "dropping the guard immediately decrements the in-flight count"]
struct IncrementGuard {
    count: Arc<AtomicUsize>,
}

impl Drop for IncrementGuard {
    fn drop(&mut self) {
        // Relaxed is enough: the count is a standalone gauge and, in this
        // module, it is not used to synchronize access to any other memory.
        self.count.fetch_sub(1, Ordering::Relaxed);
    }
}
206
207impl<S, R, ResBody> Service<Request<R>> for InFlightRequests<S>
208where
209 S: Service<Request<R>, Response = Response<ResBody>>,
210{
211 type Response = Response<ResponseBody<ResBody>>;
212 type Error = S::Error;
213 type Future = ResponseFuture<S::Future>;
214
215 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
216 self.inner.poll_ready(cx)
217 }
218
219 fn call(&mut self, req: Request<R>) -> Self::Future {
220 let guard = self.counter.increment();
221 ResponseFuture {
222 inner: self.inner.call(req),
223 guard: Some(guard),
224 }
225 }
226}
227
228pin_project! {
229 pub struct ResponseFuture<F> {
231 #[pin]
232 inner: F,
233 guard: Option<IncrementGuard>,
234 }
235}
236
237impl<F, B, E> Future for ResponseFuture<F>
238where
239 F: Future<Output = Result<Response<B>, E>>,
240{
241 type Output = Result<Response<ResponseBody<B>>, E>;
242
243 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
244 let this = self.project();
245 let response = ready!(this.inner.poll(cx))?;
246 let guard = this.guard.take().unwrap();
247 let response = response.map(move |body| ResponseBody { inner: body, guard });
248
249 Poll::Ready(Ok(response))
250 }
251}
252
253pin_project! {
254 pub struct ResponseBody<B> {
256 #[pin]
257 inner: B,
258 guard: IncrementGuard,
259 }
260}
261
262impl<B> Body for ResponseBody<B>
263where
264 B: Body,
265{
266 type Data = B::Data;
267 type Error = B::Error;
268
269 #[inline]
270 fn poll_frame(
271 self: Pin<&mut Self>,
272 cx: &mut Context<'_>,
273 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
274 self.project().inner.poll_frame(cx)
275 }
276
277 #[inline]
278 fn is_end_stream(&self) -> bool {
279 self.inner.is_end_stream()
280 }
281
282 #[inline]
283 fn size_hint(&self) -> http_body::SizeHint {
284 self.inner.size_hint()
285 }
286}
287
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use crate::test_helpers::Body;
    use http::Request;
    use tower::{BoxError, ServiceBuilder};

    #[tokio::test]
    async fn basic() {
        let (layer, counter) = InFlightRequestsLayer::pair();
        let mut svc = ServiceBuilder::new().layer(layer).service_fn(echo);

        // Nothing has been sent yet.
        assert_eq!(counter.get(), 0);

        // Waiting for readiness does not count as a request.
        std::future::poll_fn(|cx| svc.poll_ready(cx))
            .await
            .unwrap();
        assert_eq!(counter.get(), 0);

        // The request is counted as soon as `call` hands back its future...
        let pending = svc.call(Request::new(Body::empty()));
        assert_eq!(counter.get(), 1);

        // ...and is still counted once the response itself is available,
        // because the guard now lives in the response body.
        let response = pending.await.unwrap();
        assert_eq!(counter.get(), 1);

        // After the body has been consumed (and dropped), the count returns to zero.
        crate::test_helpers::to_bytes(response.into_body())
            .await
            .unwrap();
        assert_eq!(counter.get(), 0);
    }

    async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}