Skip to main content

tower_http/timeout/
service.rs

1use crate::timeout::body::TimeoutBody;
2use crate::timeout::deadline_body::DeadlineBody;
3use http::{Request, Response, StatusCode};
4use pin_project_lite::pin_project;
5use std::{
6    future::Future,
7    pin::Pin,
8    task::{ready, Context, Poll},
9    time::Duration,
10};
11use tokio::time::Sleep;
12use tower_layer::Layer;
13use tower_service::Service;
14
15/// Layer that applies the [`Timeout`] middleware which apply a timeout to requests.
16///
17/// See the [module docs](super) for an example.
18#[derive(Debug, Clone, Copy)]
19pub struct TimeoutLayer {
20    timeout: Duration,
21    status_code: StatusCode,
22}
23
24impl TimeoutLayer {
25    /// Creates a new [`TimeoutLayer`].
26    ///
27    /// By default, it will return a `408 Request Timeout` response if the request does not complete within the specified timeout.
28    /// To customize the response status code, use the `with_status_code` method.
29    #[deprecated(since = "0.6.7", note = "Use `TimeoutLayer::with_status_code` instead")]
30    pub fn new(timeout: Duration) -> Self {
31        Self::with_status_code(StatusCode::REQUEST_TIMEOUT, timeout)
32    }
33
34    /// Creates a new [`TimeoutLayer`] with the specified status code for the timeout response.
35    pub fn with_status_code(status_code: StatusCode, timeout: Duration) -> Self {
36        Self {
37            timeout,
38            status_code,
39        }
40    }
41}
42
43impl<S> Layer<S> for TimeoutLayer {
44    type Service = Timeout<S>;
45
46    fn layer(&self, inner: S) -> Self::Service {
47        Timeout::with_status_code(inner, self.status_code, self.timeout)
48    }
49}
50
51/// Middleware which apply a timeout to requests.
52///
53/// See the [module docs](super) for an example.
54#[derive(Debug, Clone, Copy)]
55pub struct Timeout<S> {
56    inner: S,
57    timeout: Duration,
58    status_code: StatusCode,
59}
60
61impl<S> Timeout<S> {
62    /// Creates a new [`Timeout`].
63    ///
64    /// By default, it will return a `408 Request Timeout` response if the request does not complete within the specified timeout.
65    /// To customize the response status code, use the `with_status_code` method.
66    #[deprecated(since = "0.6.7", note = "Use `Timeout::with_status_code` instead")]
67    pub fn new(inner: S, timeout: Duration) -> Self {
68        Self::with_status_code(inner, StatusCode::REQUEST_TIMEOUT, timeout)
69    }
70
71    /// Creates a new [`Timeout`] with the specified status code for the timeout response.
72    pub fn with_status_code(inner: S, status_code: StatusCode, timeout: Duration) -> Self {
73        Self {
74            inner,
75            timeout,
76            status_code,
77        }
78    }
79
80    define_inner_service_accessors!();
81
82    /// Returns a new [`Layer`] that wraps services with a `Timeout` middleware.
83    ///
84    /// [`Layer`]: tower_layer::Layer
85    #[deprecated(
86        since = "0.6.7",
87        note = "Use `Timeout::layer_with_status_code` instead"
88    )]
89    pub fn layer(timeout: Duration) -> TimeoutLayer {
90        TimeoutLayer::with_status_code(StatusCode::REQUEST_TIMEOUT, timeout)
91    }
92
93    /// Returns a new [`Layer`] that wraps services with a `Timeout` middleware with the specified status code.
94    pub fn layer_with_status_code(status_code: StatusCode, timeout: Duration) -> TimeoutLayer {
95        TimeoutLayer::with_status_code(status_code, timeout)
96    }
97}
98
99impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for Timeout<S>
100where
101    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
102    ResBody: Default,
103{
104    type Response = S::Response;
105    type Error = S::Error;
106    type Future = ResponseFuture<S::Future>;
107
108    #[inline]
109    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
110        self.inner.poll_ready(cx)
111    }
112
113    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
114        let sleep = tokio::time::sleep(self.timeout);
115        ResponseFuture {
116            inner: self.inner.call(req),
117            sleep,
118            status_code: self.status_code,
119        }
120    }
121}
122
123pin_project! {
124    /// Response future for [`Timeout`].
125    pub struct ResponseFuture<F> {
126        #[pin]
127        inner: F,
128        #[pin]
129        sleep: Sleep,
130        status_code: StatusCode,
131    }
132}
133
134impl<F, B, E> Future for ResponseFuture<F>
135where
136    F: Future<Output = Result<Response<B>, E>>,
137    B: Default,
138{
139    type Output = Result<Response<B>, E>;
140
141    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
142        let this = self.project();
143
144        if this.sleep.poll(cx).is_ready() {
145            let mut res = Response::new(B::default());
146            *res.status_mut() = *this.status_code;
147            return Poll::Ready(Ok(res));
148        }
149
150        this.inner.poll(cx)
151    }
152}
153
154/// Applies a [`TimeoutBody`] to the request body.
155#[derive(Clone, Debug)]
156pub struct RequestBodyTimeoutLayer {
157    timeout: Duration,
158}
159
160impl RequestBodyTimeoutLayer {
161    /// Creates a new [`RequestBodyTimeoutLayer`].
162    pub fn new(timeout: Duration) -> Self {
163        Self { timeout }
164    }
165}
166
167impl<S> Layer<S> for RequestBodyTimeoutLayer {
168    type Service = RequestBodyTimeout<S>;
169
170    fn layer(&self, inner: S) -> Self::Service {
171        RequestBodyTimeout::new(inner, self.timeout)
172    }
173}
174
175/// Applies a [`TimeoutBody`] to the request body.
176#[derive(Clone, Debug)]
177pub struct RequestBodyTimeout<S> {
178    inner: S,
179    timeout: Duration,
180}
181
182impl<S> RequestBodyTimeout<S> {
183    /// Creates a new [`RequestBodyTimeout`].
184    pub fn new(service: S, timeout: Duration) -> Self {
185        Self {
186            inner: service,
187            timeout,
188        }
189    }
190
191    /// Returns a new [`Layer`] that wraps services with a [`RequestBodyTimeoutLayer`] middleware.
192    ///
193    /// [`Layer`]: tower_layer::Layer
194    pub fn layer(timeout: Duration) -> RequestBodyTimeoutLayer {
195        RequestBodyTimeoutLayer::new(timeout)
196    }
197
198    define_inner_service_accessors!();
199}
200
201impl<S, ReqBody> Service<Request<ReqBody>> for RequestBodyTimeout<S>
202where
203    S: Service<Request<TimeoutBody<ReqBody>>>,
204{
205    type Response = S::Response;
206    type Error = S::Error;
207    type Future = S::Future;
208
209    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
210        self.inner.poll_ready(cx)
211    }
212
213    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
214        let req = req.map(|body| TimeoutBody::new(self.timeout, body));
215        self.inner.call(req)
216    }
217}
218
219/// Applies a [`TimeoutBody`] to the response body.
220#[derive(Clone)]
221pub struct ResponseBodyTimeoutLayer {
222    timeout: Duration,
223}
224
225impl ResponseBodyTimeoutLayer {
226    /// Creates a new [`ResponseBodyTimeoutLayer`].
227    pub fn new(timeout: Duration) -> Self {
228        Self { timeout }
229    }
230}
231
232impl<S> Layer<S> for ResponseBodyTimeoutLayer {
233    type Service = ResponseBodyTimeout<S>;
234
235    fn layer(&self, inner: S) -> Self::Service {
236        ResponseBodyTimeout::new(inner, self.timeout)
237    }
238}
239
240/// Applies a [`TimeoutBody`] to the response body.
241#[derive(Clone)]
242pub struct ResponseBodyTimeout<S> {
243    inner: S,
244    timeout: Duration,
245}
246
247impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for ResponseBodyTimeout<S>
248where
249    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
250{
251    type Response = Response<TimeoutBody<ResBody>>;
252    type Error = S::Error;
253    type Future = ResponseBodyTimeoutFuture<S::Future>;
254
255    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
256        self.inner.poll_ready(cx)
257    }
258
259    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
260        ResponseBodyTimeoutFuture {
261            inner: self.inner.call(req),
262            timeout: self.timeout,
263        }
264    }
265}
266
267impl<S> ResponseBodyTimeout<S> {
268    /// Creates a new [`ResponseBodyTimeout`].
269    pub fn new(service: S, timeout: Duration) -> Self {
270        Self {
271            inner: service,
272            timeout,
273        }
274    }
275
276    /// Returns a new [`Layer`] that wraps services with a [`ResponseBodyTimeoutLayer`] middleware.
277    ///
278    /// [`Layer`]: tower_layer::Layer
279    pub fn layer(timeout: Duration) -> ResponseBodyTimeoutLayer {
280        ResponseBodyTimeoutLayer::new(timeout)
281    }
282
283    define_inner_service_accessors!();
284}
285
286pin_project! {
287    /// Response future for [`ResponseBodyTimeout`].
288    pub struct ResponseBodyTimeoutFuture<Fut> {
289        #[pin]
290        inner: Fut,
291        timeout: Duration,
292    }
293}
294
295impl<Fut, ResBody, E> Future for ResponseBodyTimeoutFuture<Fut>
296where
297    Fut: Future<Output = Result<Response<ResBody>, E>>,
298{
299    type Output = Result<Response<TimeoutBody<ResBody>>, E>;
300
301    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
302        let timeout = self.timeout;
303        let this = self.project();
304        let res = ready!(this.inner.poll(cx))?;
305        Poll::Ready(Ok(res.map(|body| TimeoutBody::new(timeout, body))))
306    }
307}
308
309/// Applies a [`DeadlineBody`] to the request body.
310///
311/// Unlike [`RequestBodyTimeoutLayer`], which resets on each frame, this enforces a hard
312/// deadline on the entire body transfer.
313#[derive(Clone, Debug)]
314pub struct RequestBodyDeadlineLayer {
315    timeout: Duration,
316}
317
318impl RequestBodyDeadlineLayer {
319    /// Creates a new [`RequestBodyDeadlineLayer`].
320    pub fn new(timeout: Duration) -> Self {
321        Self { timeout }
322    }
323}
324
325impl<S> Layer<S> for RequestBodyDeadlineLayer {
326    type Service = RequestBodyDeadline<S>;
327
328    fn layer(&self, inner: S) -> Self::Service {
329        RequestBodyDeadline::new(inner, self.timeout)
330    }
331}
332
333/// Applies a [`DeadlineBody`] to the request body.
334#[derive(Clone, Debug)]
335pub struct RequestBodyDeadline<S> {
336    inner: S,
337    timeout: Duration,
338}
339
340impl<S> RequestBodyDeadline<S> {
341    /// Creates a new [`RequestBodyDeadline`].
342    pub fn new(service: S, timeout: Duration) -> Self {
343        Self {
344            inner: service,
345            timeout,
346        }
347    }
348
349    /// Returns a new [`Layer`] that wraps services with a [`RequestBodyDeadlineLayer`] middleware.
350    ///
351    /// [`Layer`]: tower_layer::Layer
352    pub fn layer(timeout: Duration) -> RequestBodyDeadlineLayer {
353        RequestBodyDeadlineLayer::new(timeout)
354    }
355
356    define_inner_service_accessors!();
357}
358
359impl<S, ReqBody> Service<Request<ReqBody>> for RequestBodyDeadline<S>
360where
361    S: Service<Request<DeadlineBody<ReqBody>>>,
362{
363    type Response = S::Response;
364    type Error = S::Error;
365    type Future = S::Future;
366
367    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
368        self.inner.poll_ready(cx)
369    }
370
371    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
372        let req = req.map(|body| DeadlineBody::new(self.timeout, body));
373        self.inner.call(req)
374    }
375}
376
377/// Applies a [`DeadlineBody`] to the response body.
378///
379/// Unlike [`ResponseBodyTimeoutLayer`], which resets on each frame, this enforces a hard
380/// deadline on the entire body transfer.
381#[derive(Clone)]
382pub struct ResponseBodyDeadlineLayer {
383    timeout: Duration,
384}
385
386impl ResponseBodyDeadlineLayer {
387    /// Creates a new [`ResponseBodyDeadlineLayer`].
388    pub fn new(timeout: Duration) -> Self {
389        Self { timeout }
390    }
391}
392
393impl<S> Layer<S> for ResponseBodyDeadlineLayer {
394    type Service = ResponseBodyDeadline<S>;
395
396    fn layer(&self, inner: S) -> Self::Service {
397        ResponseBodyDeadline::new(inner, self.timeout)
398    }
399}
400
401/// Applies a [`DeadlineBody`] to the response body.
402#[derive(Clone)]
403pub struct ResponseBodyDeadline<S> {
404    inner: S,
405    timeout: Duration,
406}
407
408impl<S> ResponseBodyDeadline<S> {
409    /// Creates a new [`ResponseBodyDeadline`].
410    pub fn new(service: S, timeout: Duration) -> Self {
411        Self {
412            inner: service,
413            timeout,
414        }
415    }
416
417    /// Returns a new [`Layer`] that wraps services with a [`ResponseBodyDeadlineLayer`] middleware.
418    ///
419    /// [`Layer`]: tower_layer::Layer
420    pub fn layer(timeout: Duration) -> ResponseBodyDeadlineLayer {
421        ResponseBodyDeadlineLayer::new(timeout)
422    }
423
424    define_inner_service_accessors!();
425}
426
427impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for ResponseBodyDeadline<S>
428where
429    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
430{
431    type Response = Response<DeadlineBody<ResBody>>;
432    type Error = S::Error;
433    type Future = ResponseBodyDeadlineFuture<S::Future>;
434
435    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
436        self.inner.poll_ready(cx)
437    }
438
439    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
440        ResponseBodyDeadlineFuture {
441            inner: self.inner.call(req),
442            timeout: self.timeout,
443        }
444    }
445}
446
447pin_project! {
448    /// Response future for [`ResponseBodyDeadline`].
449    pub struct ResponseBodyDeadlineFuture<Fut> {
450        #[pin]
451        inner: Fut,
452        timeout: Duration,
453    }
454}
455
456impl<Fut, ResBody, E> Future for ResponseBodyDeadlineFuture<Fut>
457where
458    Fut: Future<Output = Result<Response<ResBody>, E>>,
459{
460    type Output = Result<Response<DeadlineBody<ResBody>>, E>;
461
462    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
463        let timeout = self.timeout;
464        let this = self.project();
465        let res = ready!(this.inner.poll(cx))?;
466        Poll::Ready(Ok(res.map(|body| DeadlineBody::new(timeout, body))))
467    }
468}
469
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::Body;
    use http::{Request, Response, StatusCode};
    use std::time::Duration;
    use tower::{BoxError, ServiceBuilder, ServiceExt};

    #[tokio::test]
    async fn request_completes_within_timeout() {
        let mut service = ServiceBuilder::new()
            .layer(TimeoutLayer::with_status_code(
                StatusCode::GATEWAY_TIMEOUT,
                Duration::from_secs(1),
            ))
            .service_fn(fast_handler);

        let request = Request::get("/").body(Body::empty()).unwrap();
        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn timeout_middleware_with_custom_status_code() {
        // Deliberately not 408: that is what the deprecated constructors use, so asserting on
        // it would not prove that the configured status code is actually honoured.
        let mut service = Timeout::with_status_code(
            tower::service_fn(slow_handler),
            StatusCode::SERVICE_UNAVAILABLE,
            Duration::from_millis(10),
        );

        let request = Request::get("/").body(Body::empty()).unwrap();
        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    #[tokio::test]
    async fn timeout_response_has_empty_body() {
        let mut service = ServiceBuilder::new()
            .layer(TimeoutLayer::with_status_code(
                StatusCode::GATEWAY_TIMEOUT,
                Duration::from_millis(10),
            ))
            .service_fn(slow_handler);

        let request = Request::get("/").body(Body::empty()).unwrap();
        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::GATEWAY_TIMEOUT);

        // Verify the body is empty (default)
        use http_body_util::BodyExt;
        let body = res.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        assert!(bytes.is_empty());
    }

    #[tokio::test]
    async fn deprecated_new_method_compatibility() {
        // The deprecated constructor is exercised on purpose to pin its 408 default.
        #[allow(deprecated)]
        let layer = TimeoutLayer::new(Duration::from_millis(10));

        let mut service = ServiceBuilder::new().layer(layer).service_fn(slow_handler);

        let request = Request::get("/").body(Body::empty()).unwrap();
        let res = service.ready().await.unwrap().call(request).await.unwrap();

        // Should use default 408 status code
        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
    }

    async fn slow_handler(_req: Request<Body>) -> Result<Response<Body>, BoxError> {
        tokio::time::sleep(Duration::from_secs(10)).await;
        Ok(Response::builder()
            .status(StatusCode::OK)
            .body(Body::empty())
            .unwrap())
    }

    async fn fast_handler(_req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::builder()
            .status(StatusCode::OK)
            .body(Body::empty())
            .unwrap())
    }
}