tower_http/metrics/
in_flight_requests.rs

1//! Measure the number of in-flight requests.
2//!
3//! In-flight requests is the number of requests a service is currently processing. The processing
4//! of a request starts when it is received by the service (`tower::Service::call` is called) and
5//! is considered complete when the response body is consumed, dropped, or an error happens.
6//!
7//! # Example
8//!
9//! ```
10//! use tower::{Service, ServiceExt, ServiceBuilder};
11//! use tower_http::metrics::InFlightRequestsLayer;
12//! use http::{Request, Response};
13//! use bytes::Bytes;
14//! use http_body_util::Full;
15//! use std::{time::Duration, convert::Infallible};
16//!
17//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
18//!     // ...
19//!     # Ok(Response::new(Full::default()))
20//! }
21//!
22//! async fn update_in_flight_requests_metric(count: usize) {
23//!     // ...
24//! }
25//!
26//! # #[tokio::main]
27//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
28//! // Create a `Layer` with an associated counter.
29//! let (in_flight_requests_layer, counter) = InFlightRequestsLayer::pair();
30//!
31//! // Spawn a task that will receive the number of in-flight requests every 10 seconds.
32//! tokio::spawn(
33//!     counter.run_emitter(Duration::from_secs(10), |count| async move {
34//!         update_in_flight_requests_metric(count).await;
35//!     }),
36//! );
37//!
38//! let mut service = ServiceBuilder::new()
39//!     // Keep track of the number of in-flight requests. This will increment and decrement
40//!     // `counter` automatically.
41//!     .layer(in_flight_requests_layer)
42//!     .service_fn(handle);
43//!
44//! // Call the service.
45//! let response = service
46//!     .ready()
47//!     .await?
48//!     .call(Request::new(Full::default()))
49//!     .await?;
50//! # Ok(())
51//! # }
52//! ```
53
54use http::{Request, Response};
55use http_body::Body;
56use pin_project_lite::pin_project;
57use std::{
58    future::Future,
59    pin::Pin,
60    sync::{
61        atomic::{AtomicUsize, Ordering},
62        Arc,
63    },
64    task::{ready, Context, Poll},
65    time::Duration,
66};
67use tower_layer::Layer;
68use tower_service::Service;
69
70/// Layer for applying [`InFlightRequests`] which counts the number of in-flight requests.
71///
72/// See the [module docs](crate::metrics::in_flight_requests) for more details.
73#[derive(Clone, Debug)]
74pub struct InFlightRequestsLayer {
75    counter: InFlightRequestsCounter,
76}
77
78impl InFlightRequestsLayer {
79    /// Create a new `InFlightRequestsLayer` and its associated counter.
80    pub fn pair() -> (Self, InFlightRequestsCounter) {
81        let counter = InFlightRequestsCounter::new();
82        let layer = Self::new(counter.clone());
83        (layer, counter)
84    }
85
86    /// Create a new `InFlightRequestsLayer` that will update the given counter.
87    pub fn new(counter: InFlightRequestsCounter) -> Self {
88        Self { counter }
89    }
90}
91
92impl<S> Layer<S> for InFlightRequestsLayer {
93    type Service = InFlightRequests<S>;
94
95    fn layer(&self, inner: S) -> Self::Service {
96        InFlightRequests {
97            inner,
98            counter: self.counter.clone(),
99        }
100    }
101}
102
103/// Middleware that counts the number of in-flight requests.
104///
105/// See the [module docs](crate::metrics::in_flight_requests) for more details.
106#[derive(Clone, Debug)]
107pub struct InFlightRequests<S> {
108    inner: S,
109    counter: InFlightRequestsCounter,
110}
111
112impl<S> InFlightRequests<S> {
113    /// Create a new `InFlightRequests` and its associated counter.
114    pub fn pair(inner: S) -> (Self, InFlightRequestsCounter) {
115        let counter = InFlightRequestsCounter::new();
116        let service = Self::new(inner, counter.clone());
117        (service, counter)
118    }
119
120    /// Create a new `InFlightRequests` that will update the given counter.
121    pub fn new(inner: S, counter: InFlightRequestsCounter) -> Self {
122        Self { inner, counter }
123    }
124
125    define_inner_service_accessors!();
126}
127
/// An atomic counter that keeps track of the number of in-flight requests.
///
/// This is normally combined with [`InFlightRequestsLayer`] or [`InFlightRequests`], which
/// update the counter as requests arrive.
///
/// Cloning a counter yields another handle to the same underlying count.
#[derive(Clone, Debug, Default)]
pub struct InFlightRequestsCounter {
    // Reference-counted so that all clones observe and update one shared value.
    count: Arc<AtomicUsize>,
}
136
137impl InFlightRequestsCounter {
138    /// Create a new `InFlightRequestsCounter`.
139    pub fn new() -> Self {
140        Self::default()
141    }
142
143    /// Get the current number of in-flight requests.
144    pub fn get(&self) -> usize {
145        self.count.load(Ordering::Relaxed)
146    }
147
148    fn increment(&self) -> IncrementGuard {
149        self.count.fetch_add(1, Ordering::Relaxed);
150        IncrementGuard {
151            count: self.count.clone(),
152        }
153    }
154
155    /// Run a future every `interval` which receives the current number of in-flight requests.
156    ///
157    /// This can be used to send the current count to your metrics system.
158    ///
159    /// This function will loop forever so normally it is called with [`tokio::spawn`]:
160    ///
161    /// ```rust,no_run
162    /// use tower_http::metrics::in_flight_requests::InFlightRequestsCounter;
163    /// use std::time::Duration;
164    ///
165    /// let counter = InFlightRequestsCounter::new();
166    ///
167    /// tokio::spawn(
168    ///     counter.run_emitter(Duration::from_secs(10), |count: usize| async move {
169    ///         // Send `count` to metrics system.
170    ///     }),
171    /// );
172    /// ```
173    pub async fn run_emitter<F, Fut>(mut self, interval: Duration, mut emit: F)
174    where
175        F: FnMut(usize) -> Fut + Send + 'static,
176        Fut: Future<Output = ()> + Send,
177    {
178        let mut interval = tokio::time::interval(interval);
179
180        loop {
181            // if all producers have gone away we don't need to emit anymore
182            match Arc::try_unwrap(self.count) {
183                Ok(_) => return,
184                Err(shared_count) => {
185                    self = Self {
186                        count: shared_count,
187                    }
188                }
189            }
190
191            interval.tick().await;
192            emit(self.get()).await;
193        }
194    }
195}
196
// Guard tied to one counted request: dropping it subtracts one from the shared count.
struct IncrementGuard {
    count: Arc<AtomicUsize>,
}

impl Drop for IncrementGuard {
    fn drop(&mut self) {
        // The previous value returned by `fetch_sub` is not needed.
        let _ = self.count.fetch_sub(1, Ordering::Relaxed);
    }
}
206
207impl<S, R, ResBody> Service<Request<R>> for InFlightRequests<S>
208where
209    S: Service<Request<R>, Response = Response<ResBody>>,
210{
211    type Response = Response<ResponseBody<ResBody>>;
212    type Error = S::Error;
213    type Future = ResponseFuture<S::Future>;
214
215    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
216        self.inner.poll_ready(cx)
217    }
218
219    fn call(&mut self, req: Request<R>) -> Self::Future {
220        let guard = self.counter.increment();
221        ResponseFuture {
222            inner: self.inner.call(req),
223            guard: Some(guard),
224        }
225    }
226}
227
228pin_project! {
229    /// Response future for [`InFlightRequests`].
230    pub struct ResponseFuture<F> {
231        #[pin]
232        inner: F,
233        guard: Option<IncrementGuard>,
234    }
235}
236
237impl<F, B, E> Future for ResponseFuture<F>
238where
239    F: Future<Output = Result<Response<B>, E>>,
240{
241    type Output = Result<Response<ResponseBody<B>>, E>;
242
243    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
244        let this = self.project();
245        let response = ready!(this.inner.poll(cx))?;
246        let guard = this.guard.take().unwrap();
247        let response = response.map(move |body| ResponseBody { inner: body, guard });
248
249        Poll::Ready(Ok(response))
250    }
251}
252
253pin_project! {
254    /// Response body for [`InFlightRequests`].
255    pub struct ResponseBody<B> {
256        #[pin]
257        inner: B,
258        guard: IncrementGuard,
259    }
260}
261
262impl<B> Body for ResponseBody<B>
263where
264    B: Body,
265{
266    type Data = B::Data;
267    type Error = B::Error;
268
269    #[inline]
270    fn poll_frame(
271        self: Pin<&mut Self>,
272        cx: &mut Context<'_>,
273    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
274        self.project().inner.poll_frame(cx)
275    }
276
277    #[inline]
278    fn is_end_stream(&self) -> bool {
279        self.inner.is_end_stream()
280    }
281
282    #[inline]
283    fn size_hint(&self) -> http_body::SizeHint {
284        self.inner.size_hint()
285    }
286}
287
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use crate::test_helpers::Body;
    use http::Request;
    use tower::{BoxError, ServiceBuilder};

    // Handler that returns the request body unchanged as the response body.
    async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        let body = req.into_body();
        Ok(Response::new(body))
    }

    #[tokio::test]
    async fn basic() {
        let (layer, counter) = InFlightRequestsLayer::pair();
        let mut service = ServiceBuilder::new().layer(layer).service_fn(echo);
        assert_eq!(counter.get(), 0);

        // driving the service to ready shouldn't increment the counter
        let readiness = std::future::poll_fn(|cx| service.poll_ready(cx)).await;
        readiness.unwrap();
        assert_eq!(counter.get(), 0);

        // creating the response future should increment the count
        let request = Request::new(Body::empty());
        let response_future = service.call(request);
        assert_eq!(counter.get(), 1);

        // the count shouldn't decrement until the full body has been consumed
        let response = response_future.await.unwrap();
        assert_eq!(counter.get(), 1);

        let body = response.into_body();
        crate::test_helpers::to_bytes(body).await.unwrap();
        assert_eq!(counter.get(), 0);
    }
}