Skip to main content

tower_http/timeout/
deadline_body.rs

1use crate::BoxError;
2use http_body::Body;
3use pin_project_lite::pin_project;
4use std::{
5    future::Future,
6    pin::Pin,
7    task::{ready, Context, Poll},
8    time::Duration,
9};
10use tokio::time::{sleep, Sleep};
11
12pin_project! {
13    /// Wrapper around a [`Body`] that enforces a hard deadline on the entire body transfer.
14    ///
15    /// Unlike [`TimeoutBody`][super::TimeoutBody], which resets its deadline each time a frame is
16    /// received, `DeadlineBody` starts a single timer at construction and returns a
17    /// [`TimeoutError`][super::TimeoutError] if the body is not fully consumed before the deadline.
18    ///
19    /// The deadline is **wall-clock time from construction**, not cumulative poll time. The
20    /// timer continues to count even if the consumer is not actively polling the body. If you
21    /// poll some frames, pause to do other work, and then resume, the elapsed pause time counts
22    /// toward the deadline.
23    ///
24    /// # When to use this
25    ///
26    /// This is primarily useful as middleware on public-facing endpoints where you want to bound
27    /// the total wall-clock time a single request can hold resources (task slots, memory for
28    /// buffering, etc.), regardless of how frequently data trickles in. A slow client sending
29    /// one byte per second will never trip [`TimeoutBody`][super::TimeoutBody]'s idle timeout,
30    /// but will correctly trip `DeadlineBody`.
31    ///
32    /// If you only need to detect stalled connections where no data flows for a period, use
33    /// [`TimeoutBody`][super::TimeoutBody] instead. The two can be stacked if you want both
34    /// an idle timeout and a hard deadline.
35    ///
36    /// # Example
37    ///
38    /// ```
39    /// use http::{Request, Response};
40    /// use bytes::Bytes;
41    /// use http_body_util::Full;
42    /// use std::time::Duration;
43    /// use tower::ServiceBuilder;
44    /// use tower_http::timeout::RequestBodyDeadlineLayer;
45    ///
46    /// async fn handle(_: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, std::convert::Infallible> {
47    ///     // ...
48    ///     # todo!()
49    /// }
50    ///
51    /// # #[tokio::main]
52    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
53    /// let svc = ServiceBuilder::new()
54    ///     // Timeout bodies after 30 seconds total
55    ///     .layer(RequestBodyDeadlineLayer::new(Duration::from_secs(30)))
56    ///     .service_fn(handle);
57    /// # Ok(())
58    /// # }
59    /// ```
60    pub struct DeadlineBody<B> {
61        #[pin]
62        sleep: Sleep,
63        #[pin]
64        body: B,
65    }
66}
67
68impl<B> DeadlineBody<B> {
69    /// Creates a new [`DeadlineBody`].
70    ///
71    /// The timeout starts immediately. If the body is not fully consumed within `timeout`,
72    /// subsequent `poll_frame` calls will return a [`TimeoutError`][super::TimeoutError].
73    pub fn new(timeout: Duration, body: B) -> Self {
74        DeadlineBody {
75            sleep: sleep(timeout),
76            body,
77        }
78    }
79}
80
81impl<B> Body for DeadlineBody<B>
82where
83    B: Body,
84    B::Error: Into<BoxError>,
85{
86    type Data = B::Data;
87    type Error = Box<dyn std::error::Error + Send + Sync>;
88
89    fn poll_frame(
90        self: Pin<&mut Self>,
91        cx: &mut Context<'_>,
92    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
93        let this = self.project();
94
95        // Error if the absolute timeout has expired.
96        if let Poll::Ready(()) = this.sleep.poll(cx) {
97            return Poll::Ready(Some(Err(Box::new(super::TimeoutError(())))));
98        }
99
100        // Check for body data.
101        let frame = ready!(this.body.poll_frame(cx));
102
103        Poll::Ready(frame.transpose().map_err(Into::into).transpose())
104    }
105
106    fn is_end_stream(&self) -> bool {
107        self.body.is_end_stream()
108    }
109
110    fn size_hint(&self) -> http_body::SizeHint {
111        self.body.size_hint()
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    use bytes::Bytes;
    use http_body::Frame;
    use http_body_util::BodyExt;
    use pin_project_lite::pin_project;
    use std::{error::Error, fmt::Display};
    use tokio::time::sleep;

    #[derive(Debug)]
    struct MockError;

    impl Error for MockError {}

    impl Display for MockError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "mock error")
        }
    }

    pin_project! {
        /// A body that yields a frame after a delay.
        struct MockBody {
            #[pin]
            sleep: Sleep,
        }
    }

    impl Body for MockBody {
        type Data = Bytes;
        type Error = MockError;

        fn poll_frame(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            let this = self.project();
            this.sleep
                .poll(cx)
                .map(|_| Some(Ok(Frame::data(vec![].into()))))
        }
    }

    /// A body whose first frame is an error, used to check error propagation.
    struct ErrorBody;

    impl Body for ErrorBody {
        type Data = Bytes;
        type Error = MockError;

        fn poll_frame(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            Poll::Ready(Some(Err(MockError)))
        }
    }

    pin_project! {
        /// A body that yields multiple frames with a delay between each.
        struct MultiFrameBody {
            frames_remaining: usize,
            frame_interval: Duration,
            #[pin]
            sleep: Option<Sleep>,
        }
    }

    impl Body for MultiFrameBody {
        type Data = Bytes;
        type Error = MockError;

        fn poll_frame(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            let mut this = self.project();

            if *this.frames_remaining == 0 {
                return Poll::Ready(None);
            }

            // Start the sleep if not active.
            let sleep_pinned = if let Some(s) = this.sleep.as_mut().as_pin_mut() {
                s
            } else {
                this.sleep.set(Some(sleep(*this.frame_interval)));
                this.sleep.as_mut().as_pin_mut().unwrap()
            };

            ready!(sleep_pinned.poll(cx));
            this.sleep.set(None);
            *this.frames_remaining -= 1;

            Poll::Ready(Some(Ok(Frame::data(Bytes::from("chunk")))))
        }
    }

    #[tokio::test]
    async fn body_completes_within_timeout() {
        let mock_body = MockBody {
            sleep: sleep(Duration::from_millis(50)),
        };
        let timeout_body = DeadlineBody::new(Duration::from_millis(200), mock_body);

        assert!(timeout_body
            .boxed()
            .frame()
            .await
            .expect("no frame")
            .is_ok());
    }

    #[tokio::test]
    async fn body_exceeds_timeout() {
        let mock_body = MockBody {
            sleep: sleep(Duration::from_millis(200)),
        };
        let timeout_body = DeadlineBody::new(Duration::from_millis(50), mock_body);

        let result = timeout_body.boxed().frame().await.unwrap();
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .downcast_ref::<super::super::TimeoutError>()
            .is_some());
    }

    #[tokio::test]
    async fn inner_error_is_propagated() {
        // Errors from the wrapped body must pass through (boxed), not be replaced by a timeout.
        let timeout_body = DeadlineBody::new(Duration::from_millis(200), ErrorBody);

        let err = timeout_body
            .boxed()
            .frame()
            .await
            .expect("no frame")
            .unwrap_err();
        assert!(
            err.downcast_ref::<MockError>().is_some(),
            "expected the inner MockError, got: {}",
            err
        );
    }

    #[tokio::test]
    async fn deadline_fires_despite_steady_frames() {
        // Each frame arrives every 30ms (well within an idle timeout of 100ms),
        // but total transfer takes 5 * 30ms = 150ms, exceeding the 100ms deadline.
        let body = MultiFrameBody {
            frames_remaining: 5,
            frame_interval: Duration::from_millis(30),
            sleep: None,
        };
        let timeout_body = DeadlineBody::new(Duration::from_millis(100), body);

        let mut boxed = timeout_body.boxed();
        let mut frames_received = 0;

        let err = loop {
            match boxed.frame().await {
                Some(Ok(_)) => frames_received += 1,
                Some(Err(e)) => break e,
                None => panic!(
                    "body finished after {} frames without hitting the deadline",
                    frames_received
                ),
            }
        };

        assert!(
            err.downcast_ref::<super::super::TimeoutError>().is_some(),
            "expected a TimeoutError, got: {}",
            err
        );
        // The deadline must cut the transfer short, not merely fire after every frame arrived.
        assert!(
            frames_received < 5,
            "expected timeout error before all frames arrived, got {} frames",
            frames_received
        );
    }

    #[tokio::test]
    async fn all_frames_arrive_within_deadline() {
        // Each frame arrives every 20ms, total = 3 * 20ms = 60ms, within 200ms deadline.
        let body = MultiFrameBody {
            frames_remaining: 3,
            frame_interval: Duration::from_millis(20),
            sleep: None,
        };
        let timeout_body = DeadlineBody::new(Duration::from_millis(200), body);

        let mut boxed = timeout_body.boxed();
        let mut frame_count = 0;

        loop {
            match boxed.frame().await {
                Some(Ok(_)) => frame_count += 1,
                Some(Err(e)) => panic!("unexpected error: {}", e),
                None => break,
            }
        }

        assert_eq!(frame_count, 3);
    }
}