tower_http/metrics/
in_flight_requests.rs1use http::{Request, Response};
55use http_body::Body;
56use pin_project_lite::pin_project;
57use std::{
58    future::Future,
59    pin::Pin,
60    sync::{
61        atomic::{AtomicUsize, Ordering},
62        Arc,
63    },
64    task::{ready, Context, Poll},
65    time::Duration,
66};
67use tower_layer::Layer;
68use tower_service::Service;
69
70#[derive(Clone, Debug)]
74pub struct InFlightRequestsLayer {
75    counter: InFlightRequestsCounter,
76}
77
78impl InFlightRequestsLayer {
79    pub fn pair() -> (Self, InFlightRequestsCounter) {
81        let counter = InFlightRequestsCounter::new();
82        let layer = Self::new(counter.clone());
83        (layer, counter)
84    }
85
86    pub fn new(counter: InFlightRequestsCounter) -> Self {
88        Self { counter }
89    }
90}
91
92impl<S> Layer<S> for InFlightRequestsLayer {
93    type Service = InFlightRequests<S>;
94
95    fn layer(&self, inner: S) -> Self::Service {
96        InFlightRequests {
97            inner,
98            counter: self.counter.clone(),
99        }
100    }
101}
102
103#[derive(Clone, Debug)]
107pub struct InFlightRequests<S> {
108    inner: S,
109    counter: InFlightRequestsCounter,
110}
111
112impl<S> InFlightRequests<S> {
113    pub fn pair(inner: S) -> (Self, InFlightRequestsCounter) {
115        let counter = InFlightRequestsCounter::new();
116        let service = Self::new(inner, counter.clone());
117        (service, counter)
118    }
119
120    pub fn new(inner: S, counter: InFlightRequestsCounter) -> Self {
122        Self { inner, counter }
123    }
124
125    define_inner_service_accessors!();
126}
127
128#[derive(Debug, Clone, Default)]
133pub struct InFlightRequestsCounter {
134    count: Arc<AtomicUsize>,
135}
136
137impl InFlightRequestsCounter {
138    pub fn new() -> Self {
140        Self::default()
141    }
142
143    pub fn get(&self) -> usize {
145        self.count.load(Ordering::Relaxed)
146    }
147
148    fn increment(&self) -> IncrementGuard {
149        self.count.fetch_add(1, Ordering::Relaxed);
150        IncrementGuard {
151            count: self.count.clone(),
152        }
153    }
154
155    pub async fn run_emitter<F, Fut>(mut self, interval: Duration, mut emit: F)
174    where
175        F: FnMut(usize) -> Fut + Send + 'static,
176        Fut: Future<Output = ()> + Send,
177    {
178        let mut interval = tokio::time::interval(interval);
179
180        loop {
181            match Arc::try_unwrap(self.count) {
183                Ok(_) => return,
184                Err(shared_count) => {
185                    self = Self {
186                        count: shared_count,
187                    }
188                }
189            }
190
191            interval.tick().await;
192            emit(self.get()).await;
193        }
194    }
195}
196
/// Token representing one counted in-flight request.
///
/// It is handed out right after the shared count was incremented. Dropping it takes that
/// increment back.
struct IncrementGuard {
    count: Arc<AtomicUsize>,
}

impl Drop for IncrementGuard {
    fn drop(&mut self) {
        // Undo the increment that accompanied this guard's creation.
        let shared: &AtomicUsize = &self.count;
        shared.fetch_sub(1, Ordering::Relaxed);
    }
}
206
207impl<S, R, ResBody> Service<Request<R>> for InFlightRequests<S>
208where
209    S: Service<Request<R>, Response = Response<ResBody>>,
210{
211    type Response = Response<ResponseBody<ResBody>>;
212    type Error = S::Error;
213    type Future = ResponseFuture<S::Future>;
214
215    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
216        self.inner.poll_ready(cx)
217    }
218
219    fn call(&mut self, req: Request<R>) -> Self::Future {
220        let guard = self.counter.increment();
221        ResponseFuture {
222            inner: self.inner.call(req),
223            guard: Some(guard),
224        }
225    }
226}
227
228pin_project! {
229    pub struct ResponseFuture<F> {
231        #[pin]
232        inner: F,
233        guard: Option<IncrementGuard>,
234    }
235}
236
237impl<F, B, E> Future for ResponseFuture<F>
238where
239    F: Future<Output = Result<Response<B>, E>>,
240{
241    type Output = Result<Response<ResponseBody<B>>, E>;
242
243    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
244        let this = self.project();
245        let response = ready!(this.inner.poll(cx))?;
246        let guard = this.guard.take().unwrap();
247        let response = response.map(move |body| ResponseBody { inner: body, guard });
248
249        Poll::Ready(Ok(response))
250    }
251}
252
253pin_project! {
254    pub struct ResponseBody<B> {
256        #[pin]
257        inner: B,
258        guard: IncrementGuard,
259    }
260}
261
262impl<B> Body for ResponseBody<B>
263where
264    B: Body,
265{
266    type Data = B::Data;
267    type Error = B::Error;
268
269    #[inline]
270    fn poll_frame(
271        self: Pin<&mut Self>,
272        cx: &mut Context<'_>,
273    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
274        self.project().inner.poll_frame(cx)
275    }
276
277    #[inline]
278    fn is_end_stream(&self) -> bool {
279        self.inner.is_end_stream()
280    }
281
282    #[inline]
283    fn size_hint(&self) -> http_body::SizeHint {
284        self.inner.size_hint()
285    }
286}
287
#[cfg(test)]
mod tests {
    // The glob brings in the parent's items. Some of its re-exported names (`Body`,
    // `Request`) are shadowed by the explicit imports below, which can make parts of the
    // glob look unused.
    #[allow(unused_imports)]
    use super::*;
    use crate::test_helpers::Body;
    use http::Request;
    use tower::{BoxError, ServiceBuilder};

    /// Walks one request through its lifecycle and checks the counter at each stage.
    #[tokio::test]
    async fn basic() {
        let (layer, counter) = InFlightRequestsLayer::pair();
        let mut svc = ServiceBuilder::new().layer(layer).service_fn(echo);

        // Nothing has been sent yet.
        assert_eq!(counter.get(), 0);

        // Becoming ready does not count as a request.
        std::future::poll_fn(|cx| svc.poll_ready(cx)).await.unwrap();
        assert_eq!(counter.get(), 0);

        // The request is counted as soon as `call` hands back its future...
        let fut = svc.call(Request::new(Body::empty()));
        assert_eq!(counter.get(), 1);

        // ...is still counted once the response head has arrived...
        let res = fut.await.unwrap();
        assert_eq!(counter.get(), 1);

        // ...and is released after the body has been consumed and dropped.
        crate::test_helpers::to_bytes(res.into_body()).await.unwrap();
        assert_eq!(counter.get(), 0);
    }

    /// Test service that answers with the request body unchanged.
    async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}