1use crate::timeout::body::TimeoutBody;
2use crate::timeout::deadline_body::DeadlineBody;
3use http::{Request, Response, StatusCode};
4use pin_project_lite::pin_project;
5use std::{
6 future::Future,
7 pin::Pin,
8 task::{ready, Context, Poll},
9 time::Duration,
10};
11use tokio::time::Sleep;
12use tower_layer::Layer;
13use tower_service::Service;
14
15#[derive(Debug, Clone, Copy)]
19pub struct TimeoutLayer {
20 timeout: Duration,
21 status_code: StatusCode,
22}
23
24impl TimeoutLayer {
25 #[deprecated(since = "0.6.7", note = "Use `TimeoutLayer::with_status_code` instead")]
30 pub fn new(timeout: Duration) -> Self {
31 Self::with_status_code(StatusCode::REQUEST_TIMEOUT, timeout)
32 }
33
34 pub fn with_status_code(status_code: StatusCode, timeout: Duration) -> Self {
36 Self {
37 timeout,
38 status_code,
39 }
40 }
41}
42
43impl<S> Layer<S> for TimeoutLayer {
44 type Service = Timeout<S>;
45
46 fn layer(&self, inner: S) -> Self::Service {
47 Timeout::with_status_code(inner, self.status_code, self.timeout)
48 }
49}
50
51#[derive(Debug, Clone, Copy)]
55pub struct Timeout<S> {
56 inner: S,
57 timeout: Duration,
58 status_code: StatusCode,
59}
60
61impl<S> Timeout<S> {
62 #[deprecated(since = "0.6.7", note = "Use `Timeout::with_status_code` instead")]
67 pub fn new(inner: S, timeout: Duration) -> Self {
68 Self::with_status_code(inner, StatusCode::REQUEST_TIMEOUT, timeout)
69 }
70
71 pub fn with_status_code(inner: S, status_code: StatusCode, timeout: Duration) -> Self {
73 Self {
74 inner,
75 timeout,
76 status_code,
77 }
78 }
79
80 define_inner_service_accessors!();
81
82 #[deprecated(
86 since = "0.6.7",
87 note = "Use `Timeout::layer_with_status_code` instead"
88 )]
89 pub fn layer(timeout: Duration) -> TimeoutLayer {
90 TimeoutLayer::with_status_code(StatusCode::REQUEST_TIMEOUT, timeout)
91 }
92
93 pub fn layer_with_status_code(status_code: StatusCode, timeout: Duration) -> TimeoutLayer {
95 TimeoutLayer::with_status_code(status_code, timeout)
96 }
97}
98
99impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for Timeout<S>
100where
101 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
102 ResBody: Default,
103{
104 type Response = S::Response;
105 type Error = S::Error;
106 type Future = ResponseFuture<S::Future>;
107
108 #[inline]
109 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
110 self.inner.poll_ready(cx)
111 }
112
113 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
114 let sleep = tokio::time::sleep(self.timeout);
115 ResponseFuture {
116 inner: self.inner.call(req),
117 sleep,
118 status_code: self.status_code,
119 }
120 }
121}
122
pin_project! {
    /// Response future for [`Timeout`].
    ///
    /// Resolves to the inner future's result, or to a response with a
    /// `Default` body and `status_code` once `sleep` completes.
    pub struct ResponseFuture<F> {
        // Future returned by the inner service's `call`.
        #[pin]
        inner: F,
        // Timer created in `Timeout::call` from the configured duration.
        #[pin]
        sleep: Sleep,
        // Status code placed on the synthesized response when `sleep` fires.
        status_code: StatusCode,
    }
}
133
134impl<F, B, E> Future for ResponseFuture<F>
135where
136 F: Future<Output = Result<Response<B>, E>>,
137 B: Default,
138{
139 type Output = Result<Response<B>, E>;
140
141 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
142 let this = self.project();
143
144 if this.sleep.poll(cx).is_ready() {
145 let mut res = Response::new(B::default());
146 *res.status_mut() = *this.status_code;
147 return Poll::Ready(Ok(res));
148 }
149
150 this.inner.poll(cx)
151 }
152}
153
/// Layer that wraps every request body in a [`TimeoutBody`] built from a fixed
/// duration (see [`TimeoutBody`] for how the duration is enforced).
#[derive(Clone, Debug)]
pub struct RequestBodyTimeoutLayer {
    timeout: Duration,
}

impl RequestBodyTimeoutLayer {
    /// Creates a layer whose request bodies are wrapped with `timeout`.
    pub fn new(timeout: Duration) -> Self {
        RequestBodyTimeoutLayer { timeout }
    }
}
166
167impl<S> Layer<S> for RequestBodyTimeoutLayer {
168 type Service = RequestBodyTimeout<S>;
169
170 fn layer(&self, inner: S) -> Self::Service {
171 RequestBodyTimeout::new(inner, self.timeout)
172 }
173}
174
175#[derive(Clone, Debug)]
177pub struct RequestBodyTimeout<S> {
178 inner: S,
179 timeout: Duration,
180}
181
182impl<S> RequestBodyTimeout<S> {
183 pub fn new(service: S, timeout: Duration) -> Self {
185 Self {
186 inner: service,
187 timeout,
188 }
189 }
190
191 pub fn layer(timeout: Duration) -> RequestBodyTimeoutLayer {
195 RequestBodyTimeoutLayer::new(timeout)
196 }
197
198 define_inner_service_accessors!();
199}
200
201impl<S, ReqBody> Service<Request<ReqBody>> for RequestBodyTimeout<S>
202where
203 S: Service<Request<TimeoutBody<ReqBody>>>,
204{
205 type Response = S::Response;
206 type Error = S::Error;
207 type Future = S::Future;
208
209 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
210 self.inner.poll_ready(cx)
211 }
212
213 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
214 let req = req.map(|body| TimeoutBody::new(self.timeout, body));
215 self.inner.call(req)
216 }
217}
218
/// Layer that wraps every response body in a [`TimeoutBody`] built from a fixed
/// duration (see [`TimeoutBody`] for how the duration is enforced).
///
/// `Debug` is derived for parity with `RequestBodyTimeoutLayer`, so the layer
/// can be embedded in `Debug` types and printed in diagnostics.
#[derive(Clone, Debug)]
pub struct ResponseBodyTimeoutLayer {
    timeout: Duration,
}

impl ResponseBodyTimeoutLayer {
    /// Creates a layer whose response bodies are wrapped with `timeout`.
    pub fn new(timeout: Duration) -> Self {
        Self { timeout }
    }
}
231
232impl<S> Layer<S> for ResponseBodyTimeoutLayer {
233 type Service = ResponseBodyTimeout<S>;
234
235 fn layer(&self, inner: S) -> Self::Service {
236 ResponseBodyTimeout::new(inner, self.timeout)
237 }
238}
239
/// Middleware that wraps the inner service's response body in a [`TimeoutBody`].
///
/// `Debug` is derived (requiring `S: Debug`) for parity with
/// `RequestBodyTimeout`.
#[derive(Clone, Debug)]
pub struct ResponseBodyTimeout<S> {
    inner: S,
    timeout: Duration,
}
246
247impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for ResponseBodyTimeout<S>
248where
249 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
250{
251 type Response = Response<TimeoutBody<ResBody>>;
252 type Error = S::Error;
253 type Future = ResponseBodyTimeoutFuture<S::Future>;
254
255 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
256 self.inner.poll_ready(cx)
257 }
258
259 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
260 ResponseBodyTimeoutFuture {
261 inner: self.inner.call(req),
262 timeout: self.timeout,
263 }
264 }
265}
266
267impl<S> ResponseBodyTimeout<S> {
268 pub fn new(service: S, timeout: Duration) -> Self {
270 Self {
271 inner: service,
272 timeout,
273 }
274 }
275
276 pub fn layer(timeout: Duration) -> ResponseBodyTimeoutLayer {
280 ResponseBodyTimeoutLayer::new(timeout)
281 }
282
283 define_inner_service_accessors!();
284}
285
pin_project! {
    /// Response future for [`ResponseBodyTimeout`].
    ///
    /// Resolves to the inner response with its body wrapped in a [`TimeoutBody`];
    /// errors from the inner future are passed through unchanged.
    pub struct ResponseBodyTimeoutFuture<Fut> {
        // Future returned by the inner service's `call`.
        #[pin]
        inner: Fut,
        // Duration handed to `TimeoutBody::new` once the response is available.
        timeout: Duration,
    }
}
294
295impl<Fut, ResBody, E> Future for ResponseBodyTimeoutFuture<Fut>
296where
297 Fut: Future<Output = Result<Response<ResBody>, E>>,
298{
299 type Output = Result<Response<TimeoutBody<ResBody>>, E>;
300
301 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
302 let timeout = self.timeout;
303 let this = self.project();
304 let res = ready!(this.inner.poll(cx))?;
305 Poll::Ready(Ok(res.map(|body| TimeoutBody::new(timeout, body))))
306 }
307}
308
/// Layer that wraps every request body in a [`DeadlineBody`] built from a fixed
/// duration (see [`DeadlineBody`] for how the duration is enforced).
#[derive(Clone, Debug)]
pub struct RequestBodyDeadlineLayer {
    timeout: Duration,
}

impl RequestBodyDeadlineLayer {
    /// Creates a layer whose request bodies are wrapped with `timeout`.
    pub fn new(timeout: Duration) -> Self {
        RequestBodyDeadlineLayer { timeout }
    }
}
324
325impl<S> Layer<S> for RequestBodyDeadlineLayer {
326 type Service = RequestBodyDeadline<S>;
327
328 fn layer(&self, inner: S) -> Self::Service {
329 RequestBodyDeadline::new(inner, self.timeout)
330 }
331}
332
333#[derive(Clone, Debug)]
335pub struct RequestBodyDeadline<S> {
336 inner: S,
337 timeout: Duration,
338}
339
340impl<S> RequestBodyDeadline<S> {
341 pub fn new(service: S, timeout: Duration) -> Self {
343 Self {
344 inner: service,
345 timeout,
346 }
347 }
348
349 pub fn layer(timeout: Duration) -> RequestBodyDeadlineLayer {
353 RequestBodyDeadlineLayer::new(timeout)
354 }
355
356 define_inner_service_accessors!();
357}
358
359impl<S, ReqBody> Service<Request<ReqBody>> for RequestBodyDeadline<S>
360where
361 S: Service<Request<DeadlineBody<ReqBody>>>,
362{
363 type Response = S::Response;
364 type Error = S::Error;
365 type Future = S::Future;
366
367 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
368 self.inner.poll_ready(cx)
369 }
370
371 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
372 let req = req.map(|body| DeadlineBody::new(self.timeout, body));
373 self.inner.call(req)
374 }
375}
376
/// Layer that wraps every response body in a [`DeadlineBody`] built from a fixed
/// duration (see [`DeadlineBody`] for how the duration is enforced).
///
/// `Debug` is derived for parity with `RequestBodyDeadlineLayer`.
#[derive(Clone, Debug)]
pub struct ResponseBodyDeadlineLayer {
    timeout: Duration,
}

impl ResponseBodyDeadlineLayer {
    /// Creates a layer whose response bodies are wrapped with `timeout`.
    pub fn new(timeout: Duration) -> Self {
        Self { timeout }
    }
}
392
393impl<S> Layer<S> for ResponseBodyDeadlineLayer {
394 type Service = ResponseBodyDeadline<S>;
395
396 fn layer(&self, inner: S) -> Self::Service {
397 ResponseBodyDeadline::new(inner, self.timeout)
398 }
399}
400
/// Middleware that wraps the inner service's response body in a [`DeadlineBody`].
///
/// `Debug` is derived (requiring `S: Debug`) for parity with
/// `RequestBodyDeadline`.
#[derive(Clone, Debug)]
pub struct ResponseBodyDeadline<S> {
    inner: S,
    timeout: Duration,
}
407
408impl<S> ResponseBodyDeadline<S> {
409 pub fn new(service: S, timeout: Duration) -> Self {
411 Self {
412 inner: service,
413 timeout,
414 }
415 }
416
417 pub fn layer(timeout: Duration) -> ResponseBodyDeadlineLayer {
421 ResponseBodyDeadlineLayer::new(timeout)
422 }
423
424 define_inner_service_accessors!();
425}
426
427impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for ResponseBodyDeadline<S>
428where
429 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
430{
431 type Response = Response<DeadlineBody<ResBody>>;
432 type Error = S::Error;
433 type Future = ResponseBodyDeadlineFuture<S::Future>;
434
435 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
436 self.inner.poll_ready(cx)
437 }
438
439 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
440 ResponseBodyDeadlineFuture {
441 inner: self.inner.call(req),
442 timeout: self.timeout,
443 }
444 }
445}
446
pin_project! {
    /// Response future for [`ResponseBodyDeadline`].
    ///
    /// Resolves to the inner response with its body wrapped in a
    /// [`DeadlineBody`]; errors from the inner future are passed through unchanged.
    pub struct ResponseBodyDeadlineFuture<Fut> {
        // Future returned by the inner service's `call`.
        #[pin]
        inner: Fut,
        // Duration handed to `DeadlineBody::new` once the response is available.
        timeout: Duration,
    }
}
455
456impl<Fut, ResBody, E> Future for ResponseBodyDeadlineFuture<Fut>
457where
458 Fut: Future<Output = Result<Response<ResBody>, E>>,
459{
460 type Output = Result<Response<DeadlineBody<ResBody>>, E>;
461
462 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
463 let timeout = self.timeout;
464 let this = self.project();
465 let res = ready!(this.inner.poll(cx))?;
466 Poll::Ready(Ok(res.map(|body| DeadlineBody::new(timeout, body))))
467 }
468}
469
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::Body;
    use http::{Request, Response, StatusCode};
    use std::time::Duration;
    use tower::{BoxError, ServiceBuilder, ServiceExt};

    /// Builds the empty `GET /` request sent by every test.
    fn get_root() -> Request<Body> {
        Request::get("/").body(Body::empty()).unwrap()
    }

    /// Builds an empty `200 OK` response, the success path of both handlers.
    fn ok_response() -> Response<Body> {
        Response::builder()
            .status(StatusCode::OK)
            .body(Body::empty())
            .unwrap()
    }

    #[tokio::test]
    async fn request_completes_within_timeout() {
        let layer =
            TimeoutLayer::with_status_code(StatusCode::GATEWAY_TIMEOUT, Duration::from_secs(1));
        let service = ServiceBuilder::new().layer(layer).service_fn(fast_handler);

        let res = service.oneshot(get_root()).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn timeout_middleware_with_custom_status_code() {
        let timeout_service = Timeout::with_status_code(
            tower::service_fn(slow_handler),
            StatusCode::REQUEST_TIMEOUT,
            Duration::from_millis(10),
        );
        let service = ServiceBuilder::new().service(timeout_service);

        let res = service.oneshot(get_root()).await.unwrap();

        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
    }

    #[tokio::test]
    async fn timeout_response_has_empty_body() {
        use http_body_util::BodyExt;

        let layer =
            TimeoutLayer::with_status_code(StatusCode::GATEWAY_TIMEOUT, Duration::from_millis(10));
        let service = ServiceBuilder::new().layer(layer).service_fn(slow_handler);

        let res = service.oneshot(get_root()).await.unwrap();
        assert_eq!(res.status(), StatusCode::GATEWAY_TIMEOUT);

        let bytes = res.into_body().collect().await.unwrap().to_bytes();
        assert!(bytes.is_empty());
    }

    #[tokio::test]
    async fn deprecated_new_method_compatibility() {
        // `TimeoutLayer::new` is deprecated but must keep defaulting to 408.
        #[allow(deprecated)]
        let layer = TimeoutLayer::new(Duration::from_millis(10));
        let service = ServiceBuilder::new().layer(layer).service_fn(slow_handler);

        let res = service.oneshot(get_root()).await.unwrap();

        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
    }

    /// Handler that takes far longer than any timeout configured in these tests.
    async fn slow_handler(_req: Request<Body>) -> Result<Response<Body>, BoxError> {
        tokio::time::sleep(Duration::from_secs(10)).await;
        Ok(ok_response())
    }

    /// Handler that responds immediately.
    async fn fast_handler(_req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(ok_response())
    }
}
558}