1use crate::timeout::body::TimeoutBody;
2use http::{Request, Response, StatusCode};
3use pin_project_lite::pin_project;
4use std::{
5 future::Future,
6 pin::Pin,
7 task::{ready, Context, Poll},
8 time::Duration,
9};
10use tokio::time::Sleep;
11use tower_layer::Layer;
12use tower_service::Service;
13
14#[derive(Debug, Clone, Copy)]
18pub struct TimeoutLayer {
19 timeout: Duration,
20 status_code: StatusCode,
21}
22
23impl TimeoutLayer {
24 #[deprecated(since = "0.6.7", note = "Use `TimeoutLayer::with_status_code` instead")]
29 pub fn new(timeout: Duration) -> Self {
30 Self::with_status_code(StatusCode::REQUEST_TIMEOUT, timeout)
31 }
32
33 pub fn with_status_code(status_code: StatusCode, timeout: Duration) -> Self {
35 Self {
36 timeout,
37 status_code,
38 }
39 }
40}
41
42impl<S> Layer<S> for TimeoutLayer {
43 type Service = Timeout<S>;
44
45 fn layer(&self, inner: S) -> Self::Service {
46 Timeout::with_status_code(inner, self.status_code, self.timeout)
47 }
48}
49
50#[derive(Debug, Clone, Copy)]
54pub struct Timeout<S> {
55 inner: S,
56 timeout: Duration,
57 status_code: StatusCode,
58}
59
60impl<S> Timeout<S> {
61 #[deprecated(since = "0.6.7", note = "Use `Timeout::with_status_code` instead")]
66 pub fn new(inner: S, timeout: Duration) -> Self {
67 Self::with_status_code(inner, StatusCode::REQUEST_TIMEOUT, timeout)
68 }
69
70 pub fn with_status_code(inner: S, status_code: StatusCode, timeout: Duration) -> Self {
72 Self {
73 inner,
74 timeout,
75 status_code,
76 }
77 }
78
79 define_inner_service_accessors!();
80
81 #[deprecated(
85 since = "0.6.7",
86 note = "Use `Timeout::layer_with_status_code` instead"
87 )]
88 pub fn layer(timeout: Duration) -> TimeoutLayer {
89 TimeoutLayer::with_status_code(StatusCode::REQUEST_TIMEOUT, timeout)
90 }
91
92 pub fn layer_with_status_code(status_code: StatusCode, timeout: Duration) -> TimeoutLayer {
94 TimeoutLayer::with_status_code(status_code, timeout)
95 }
96}
97
98impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for Timeout<S>
99where
100 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
101 ResBody: Default,
102{
103 type Response = S::Response;
104 type Error = S::Error;
105 type Future = ResponseFuture<S::Future>;
106
107 #[inline]
108 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
109 self.inner.poll_ready(cx)
110 }
111
112 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
113 let sleep = tokio::time::sleep(self.timeout);
114 ResponseFuture {
115 inner: self.inner.call(req),
116 sleep,
117 status_code: self.status_code,
118 }
119 }
120}
121
122pin_project! {
123 pub struct ResponseFuture<F> {
125 #[pin]
126 inner: F,
127 #[pin]
128 sleep: Sleep,
129 status_code: StatusCode,
130 }
131}
132
133impl<F, B, E> Future for ResponseFuture<F>
134where
135 F: Future<Output = Result<Response<B>, E>>,
136 B: Default,
137{
138 type Output = Result<Response<B>, E>;
139
140 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
141 let this = self.project();
142
143 if this.sleep.poll(cx).is_ready() {
144 let mut res = Response::new(B::default());
145 *res.status_mut() = *this.status_code;
146 return Poll::Ready(Ok(res));
147 }
148
149 this.inner.poll(cx)
150 }
151}
152
153#[derive(Clone, Debug)]
155pub struct RequestBodyTimeoutLayer {
156 timeout: Duration,
157}
158
159impl RequestBodyTimeoutLayer {
160 pub fn new(timeout: Duration) -> Self {
162 Self { timeout }
163 }
164}
165
166impl<S> Layer<S> for RequestBodyTimeoutLayer {
167 type Service = RequestBodyTimeout<S>;
168
169 fn layer(&self, inner: S) -> Self::Service {
170 RequestBodyTimeout::new(inner, self.timeout)
171 }
172}
173
174#[derive(Clone, Debug)]
176pub struct RequestBodyTimeout<S> {
177 inner: S,
178 timeout: Duration,
179}
180
181impl<S> RequestBodyTimeout<S> {
182 pub fn new(service: S, timeout: Duration) -> Self {
184 Self {
185 inner: service,
186 timeout,
187 }
188 }
189
190 pub fn layer(timeout: Duration) -> RequestBodyTimeoutLayer {
194 RequestBodyTimeoutLayer::new(timeout)
195 }
196
197 define_inner_service_accessors!();
198}
199
200impl<S, ReqBody> Service<Request<ReqBody>> for RequestBodyTimeout<S>
201where
202 S: Service<Request<TimeoutBody<ReqBody>>>,
203{
204 type Response = S::Response;
205 type Error = S::Error;
206 type Future = S::Future;
207
208 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
209 self.inner.poll_ready(cx)
210 }
211
212 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
213 let req = req.map(|body| TimeoutBody::new(self.timeout, body));
214 self.inner.call(req)
215 }
216}
217
// `Debug` is derived for parity with the other public layers in this module
// (`TimeoutLayer`, `RequestBodyTimeoutLayer`).
/// Layer that wraps services in [`ResponseBodyTimeout`].
#[derive(Clone, Debug)]
pub struct ResponseBodyTimeoutLayer {
    // Timeout handed to every `TimeoutBody` wrapping a response body.
    timeout: Duration,
}
223
224impl ResponseBodyTimeoutLayer {
225 pub fn new(timeout: Duration) -> Self {
227 Self { timeout }
228 }
229}
230
231impl<S> Layer<S> for ResponseBodyTimeoutLayer {
232 type Service = ResponseBodyTimeout<S>;
233
234 fn layer(&self, inner: S) -> Self::Service {
235 ResponseBodyTimeout::new(inner, self.timeout)
236 }
237}
238
// `Debug` is derived for parity with `RequestBodyTimeout<S>` and `Timeout<S>`.
// Like those types, it is only implemented when `S: Debug`.
/// Middleware that wraps every response body produced by the inner service in
/// a [`TimeoutBody`].
#[derive(Clone, Debug)]
pub struct ResponseBodyTimeout<S> {
    inner: S,
    timeout: Duration,
}
245
246impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for ResponseBodyTimeout<S>
247where
248 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
249{
250 type Response = Response<TimeoutBody<ResBody>>;
251 type Error = S::Error;
252 type Future = ResponseBodyTimeoutFuture<S::Future>;
253
254 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
255 self.inner.poll_ready(cx)
256 }
257
258 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
259 ResponseBodyTimeoutFuture {
260 inner: self.inner.call(req),
261 timeout: self.timeout,
262 }
263 }
264}
265
266impl<S> ResponseBodyTimeout<S> {
267 pub fn new(service: S, timeout: Duration) -> Self {
269 Self {
270 inner: service,
271 timeout,
272 }
273 }
274
275 pub fn layer(timeout: Duration) -> ResponseBodyTimeoutLayer {
279 ResponseBodyTimeoutLayer::new(timeout)
280 }
281
282 define_inner_service_accessors!();
283}
284
285pin_project! {
286 pub struct ResponseBodyTimeoutFuture<Fut> {
288 #[pin]
289 inner: Fut,
290 timeout: Duration,
291 }
292}
293
294impl<Fut, ResBody, E> Future for ResponseBodyTimeoutFuture<Fut>
295where
296 Fut: Future<Output = Result<Response<ResBody>, E>>,
297{
298 type Output = Result<Response<TimeoutBody<ResBody>>, E>;
299
300 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
301 let timeout = self.timeout;
302 let this = self.project();
303 let res = ready!(this.inner.poll(cx))?;
304 Poll::Ready(Ok(res.map(|body| TimeoutBody::new(timeout, body))))
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::Body;
    use http::{Request, Response, StatusCode};
    use std::time::Duration;
    use tower::{BoxError, ServiceBuilder, ServiceExt};

    /// `GET /` with an empty body; every test below sends this request.
    fn empty_get() -> Request<Body> {
        Request::get("/").body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn request_completes_within_timeout() {
        let layer =
            TimeoutLayer::with_status_code(StatusCode::GATEWAY_TIMEOUT, Duration::from_secs(1));
        let mut svc = ServiceBuilder::new().layer(layer).service_fn(fast_handler);

        let response = svc.ready().await.unwrap().call(empty_get()).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn timeout_middleware_with_custom_status_code() {
        let wrapped = Timeout::with_status_code(
            tower::service_fn(slow_handler),
            StatusCode::REQUEST_TIMEOUT,
            Duration::from_millis(10),
        );
        let mut svc = ServiceBuilder::new().service(wrapped);

        let response = svc.ready().await.unwrap().call(empty_get()).await.unwrap();

        assert_eq!(response.status(), StatusCode::REQUEST_TIMEOUT);
    }

    #[tokio::test]
    async fn timeout_response_has_empty_body() {
        use http_body_util::BodyExt;

        let layer =
            TimeoutLayer::with_status_code(StatusCode::GATEWAY_TIMEOUT, Duration::from_millis(10));
        let mut svc = ServiceBuilder::new().layer(layer).service_fn(slow_handler);

        let response = svc.ready().await.unwrap().call(empty_get()).await.unwrap();
        assert_eq!(response.status(), StatusCode::GATEWAY_TIMEOUT);

        let collected = response.into_body().collect().await.unwrap();
        assert!(collected.to_bytes().is_empty());
    }

    #[tokio::test]
    async fn deprecated_new_method_compatibility() {
        #[allow(deprecated)]
        let layer = TimeoutLayer::new(Duration::from_millis(10));
        let mut svc = ServiceBuilder::new().layer(layer).service_fn(slow_handler);

        let response = svc.ready().await.unwrap().call(empty_get()).await.unwrap();

        assert_eq!(response.status(), StatusCode::REQUEST_TIMEOUT);
    }

    /// Handler that takes far longer than any timeout configured above.
    async fn slow_handler(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        tokio::time::sleep(Duration::from_secs(10)).await;
        fast_handler(req).await
    }

    /// Handler that answers `200 OK` with an empty body immediately.
    async fn fast_handler(_req: Request<Body>) -> Result<Response<Body>, BoxError> {
        let response = Response::builder()
            .status(StatusCode::OK)
            .body(Body::empty())
            .unwrap();
        Ok(response)
    }
}