tower_http/timeout/
deadline_body.rs1use crate::BoxError;
2use http_body::Body;
3use pin_project_lite::pin_project;
4use std::{
5 future::Future,
6 pin::Pin,
7 task::{ready, Context, Poll},
8 time::Duration,
9};
10use tokio::time::{sleep, Sleep};
11
12pin_project! {
13 pub struct DeadlineBody<B> {
61 #[pin]
62 sleep: Sleep,
63 #[pin]
64 body: B,
65 }
66}
67
68impl<B> DeadlineBody<B> {
69 pub fn new(timeout: Duration, body: B) -> Self {
74 DeadlineBody {
75 sleep: sleep(timeout),
76 body,
77 }
78 }
79}
80
81impl<B> Body for DeadlineBody<B>
82where
83 B: Body,
84 B::Error: Into<BoxError>,
85{
86 type Data = B::Data;
87 type Error = Box<dyn std::error::Error + Send + Sync>;
88
89 fn poll_frame(
90 self: Pin<&mut Self>,
91 cx: &mut Context<'_>,
92 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
93 let this = self.project();
94
95 if let Poll::Ready(()) = this.sleep.poll(cx) {
97 return Poll::Ready(Some(Err(Box::new(super::TimeoutError(())))));
98 }
99
100 let frame = ready!(this.body.poll_frame(cx));
102
103 Poll::Ready(frame.transpose().map_err(Into::into).transpose())
104 }
105
106 fn is_end_stream(&self) -> bool {
107 self.body.is_end_stream()
108 }
109
110 fn size_hint(&self) -> http_body::SizeHint {
111 self.body.size_hint()
112 }
113}
114
115#[cfg(test)]
116mod tests {
117 use super::*;
118
119 use bytes::Bytes;
120 use http_body::Frame;
121 use http_body_util::BodyExt;
122 use pin_project_lite::pin_project;
123 use std::{error::Error, fmt::Display};
124 use tokio::time::sleep;
125
126 #[derive(Debug)]
127 struct MockError;
128
129 impl Error for MockError {}
130
131 impl Display for MockError {
132 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133 write!(f, "mock error")
134 }
135 }
136
137 pin_project! {
138 struct MockBody {
140 #[pin]
141 sleep: Sleep,
142 }
143 }
144
145 impl Body for MockBody {
146 type Data = Bytes;
147 type Error = MockError;
148
149 fn poll_frame(
150 self: Pin<&mut Self>,
151 cx: &mut Context<'_>,
152 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
153 let this = self.project();
154 this.sleep
155 .poll(cx)
156 .map(|_| Some(Ok(Frame::data(vec![].into()))))
157 }
158 }
159
160 pin_project! {
161 struct MultiFrameBody {
163 frames_remaining: usize,
164 frame_interval: Duration,
165 #[pin]
166 sleep: Option<Sleep>,
167 }
168 }
169
170 impl Body for MultiFrameBody {
171 type Data = Bytes;
172 type Error = MockError;
173
174 fn poll_frame(
175 self: Pin<&mut Self>,
176 cx: &mut Context<'_>,
177 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
178 let mut this = self.project();
179
180 if *this.frames_remaining == 0 {
181 return Poll::Ready(None);
182 }
183
184 let sleep_pinned = if let Some(s) = this.sleep.as_mut().as_pin_mut() {
186 s
187 } else {
188 this.sleep.set(Some(sleep(*this.frame_interval)));
189 this.sleep.as_mut().as_pin_mut().unwrap()
190 };
191
192 ready!(sleep_pinned.poll(cx));
193 this.sleep.set(None);
194 *this.frames_remaining -= 1;
195
196 Poll::Ready(Some(Ok(Frame::data(Bytes::from("chunk")))))
197 }
198 }
199
200 #[tokio::test]
201 async fn body_completes_within_timeout() {
202 let mock_body = MockBody {
203 sleep: sleep(Duration::from_millis(50)),
204 };
205 let timeout_body = DeadlineBody::new(Duration::from_millis(200), mock_body);
206
207 assert!(timeout_body
208 .boxed()
209 .frame()
210 .await
211 .expect("no frame")
212 .is_ok());
213 }
214
215 #[tokio::test]
216 async fn body_exceeds_timeout() {
217 let mock_body = MockBody {
218 sleep: sleep(Duration::from_millis(200)),
219 };
220 let timeout_body = DeadlineBody::new(Duration::from_millis(50), mock_body);
221
222 let result = timeout_body.boxed().frame().await.unwrap();
223 assert!(result.is_err());
224 assert!(result
225 .unwrap_err()
226 .downcast_ref::<super::super::TimeoutError>()
227 .is_some());
228 }
229
230 #[tokio::test]
231 async fn deadline_fires_despite_steady_frames() {
232 let body = MultiFrameBody {
235 frames_remaining: 5,
236 frame_interval: Duration::from_millis(30),
237 sleep: None,
238 };
239 let timeout_body = DeadlineBody::new(Duration::from_millis(100), body);
240
241 let mut boxed = timeout_body.boxed();
242 let mut got_error = false;
243
244 loop {
245 match boxed.frame().await {
246 Some(Ok(_)) => {}
247 Some(Err(_)) => {
248 got_error = true;
249 break;
250 }
251 None => break,
252 }
253 }
254
255 assert!(
256 got_error,
257 "expected timeout error before all frames arrived"
258 );
259 }
260
261 #[tokio::test]
262 async fn all_frames_arrive_within_deadline() {
263 let body = MultiFrameBody {
265 frames_remaining: 3,
266 frame_interval: Duration::from_millis(20),
267 sleep: None,
268 };
269 let timeout_body = DeadlineBody::new(Duration::from_millis(200), body);
270
271 let mut boxed = timeout_body.boxed();
272 let mut frame_count = 0;
273
274 loop {
275 match boxed.frame().await {
276 Some(Ok(_)) => frame_count += 1,
277 Some(Err(e)) => panic!("unexpected error: {}", e),
278 None => break,
279 }
280 }
281
282 assert_eq!(frame_count, 3);
283 }
284}