tower_http/timeout/
body.rs

1use crate::BoxError;
2use http_body::Body;
3use pin_project_lite::pin_project;
4use std::{
5    future::Future,
6    pin::Pin,
7    task::{ready, Context, Poll},
8    time::Duration,
9};
10use tokio::time::{sleep, Sleep};
11
12pin_project! {
13    /// Middleware that applies a timeout to request and response bodies.
14    ///
15    /// Wrapper around a [`http_body::Body`] to time out if data is not ready within the specified duration.
16    ///
17    /// Bodies must produce data at most within the specified timeout.
18    /// If the body does not produce a requested data frame within the timeout period, it will return an error.
19    ///
20    /// # Differences from [`Timeout`][crate::timeout::Timeout]
21    ///
22    /// [`Timeout`][crate::timeout::Timeout] applies a timeout to the request future, not body.
23    /// That timeout is not reset when bytes are handled, whether the request is active or not.
24    /// Bodies are handled asynchronously outside of the tower stack's future and thus needs an additional timeout.
25    ///
26    /// This middleware will return a [`TimeoutError`].
27    ///
28    /// # Example
29    ///
30    /// ```
31    /// use http::{Request, Response};
32    /// use bytes::Bytes;
33    /// use http_body_util::Full;
34    /// use std::time::Duration;
35    /// use tower::ServiceBuilder;
36    /// use tower_http::timeout::RequestBodyTimeoutLayer;
37    ///
38    /// async fn handle(_: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, std::convert::Infallible> {
39    ///     // ...
40    ///     # todo!()
41    /// }
42    ///
43    /// # #[tokio::main]
44    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
45    /// let svc = ServiceBuilder::new()
46    ///     // Timeout bodies after 30 seconds of inactivity
47    ///     .layer(RequestBodyTimeoutLayer::new(Duration::from_secs(30)))
48    ///     .service_fn(handle);
49    /// # Ok(())
50    /// # }
51    /// ```
52    pub struct TimeoutBody<B> {
53        timeout: Duration,
54        #[pin]
55        sleep: Option<Sleep>,
56        #[pin]
57        body: B,
58    }
59}
60
61impl<B> TimeoutBody<B> {
62    /// Creates a new [`TimeoutBody`].
63    pub fn new(timeout: Duration, body: B) -> Self {
64        TimeoutBody {
65            timeout,
66            sleep: None,
67            body,
68        }
69    }
70}
71
72impl<B> Body for TimeoutBody<B>
73where
74    B: Body,
75    B::Error: Into<BoxError>,
76{
77    type Data = B::Data;
78    type Error = Box<dyn std::error::Error + Send + Sync>;
79
80    fn poll_frame(
81        self: Pin<&mut Self>,
82        cx: &mut Context<'_>,
83    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
84        let mut this = self.project();
85
86        // Start the `Sleep` if not active.
87        let sleep_pinned = if let Some(some) = this.sleep.as_mut().as_pin_mut() {
88            some
89        } else {
90            this.sleep.set(Some(sleep(*this.timeout)));
91            this.sleep.as_mut().as_pin_mut().unwrap()
92        };
93
94        // Error if the timeout has expired.
95        if let Poll::Ready(()) = sleep_pinned.poll(cx) {
96            return Poll::Ready(Some(Err(Box::new(TimeoutError(())))));
97        }
98
99        // Check for body data.
100        let frame = ready!(this.body.poll_frame(cx));
101        // A frame is ready. Reset the `Sleep`...
102        this.sleep.set(None);
103
104        Poll::Ready(frame.transpose().map_err(Into::into).transpose())
105    }
106}
107
/// Error for [`TimeoutBody`].
///
/// Produced when the wrapped body fails to yield a frame before the configured timeout
/// elapses. It carries no payload. Its private field prevents construction outside this
/// module.
#[derive(Debug)]
pub struct TimeoutError(());

impl std::fmt::Display for TimeoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("data was not received within the designated timeout")
    }
}

impl std::error::Error for TimeoutError {}
#[cfg(test)]
mod tests {
    use super::*;

    use bytes::Bytes;
    use http_body::Frame;
    use http_body_util::BodyExt;
    use pin_project_lite::pin_project;
    use std::{error::Error, fmt::Display};

    /// Error type of [`MockBody`]; the mock never actually produces it.
    #[derive(Debug)]
    struct MockError;

    impl Display for MockError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("mock error")
        }
    }

    impl Error for MockError {}

    pin_project! {
        /// Body that stays pending until `sleep` completes, then yields empty data frames.
        struct MockBody {
            #[pin]
            sleep: Sleep
        }
    }

    impl MockBody {
        /// Builds a mock whose first frame becomes available after `delay`.
        fn ready_after(delay: Duration) -> Self {
            Self {
                sleep: sleep(delay),
            }
        }
    }

    impl Body for MockBody {
        type Data = Bytes;
        type Error = MockError;

        fn poll_frame(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
            let this = self.project();
            match this.sleep.poll(cx) {
                Poll::Ready(()) => Poll::Ready(Some(Ok(Frame::data(Bytes::new())))),
                Poll::Pending => Poll::Pending,
            }
        }
    }

    // NOTE: these tests use real wall-clock timers, so each takes one to two seconds.

    #[tokio::test]
    async fn test_body_available_within_timeout() {
        // The mock becomes ready after 1s, well inside the 2s timeout.
        let body = TimeoutBody::new(
            Duration::from_secs(2),
            MockBody::ready_after(Duration::from_secs(1)),
        );

        let frame = body.boxed().frame().await.expect("no frame");
        assert!(frame.is_ok());
    }

    #[tokio::test]
    async fn test_body_unavailable_within_timeout_error() {
        // The mock needs 2s but the timeout is only 1s, so the timer must fire first.
        let body = TimeoutBody::new(
            Duration::from_secs(1),
            MockBody::ready_after(Duration::from_secs(2)),
        );

        let frame = body.boxed().frame().await.unwrap();
        assert!(frame.is_err());
    }
}
192}