volo_http/server/layer/
timeout.rs1use std::time::Duration;
2
3use motore::{Service, layer::Layer};
4
5use crate::{context::ServerContext, request::Request, response::Response, server::IntoResponse};
6
/// A [`Layer`] that bounds how long the wrapped service may take to produce a response.
///
/// When the inner service does not finish within `duration`, the request is answered by
/// invoking `handler` instead (see [`Timeout`]).
#[derive(Clone)]
pub struct TimeoutLayer<H> {
    /// Maximum time the inner service is given before `handler` takes over.
    duration: Duration,
    /// Callback used to build the response once `duration` has elapsed.
    handler: H,
}
15
16impl<H> TimeoutLayer<H> {
17 pub fn new(duration: Duration, handler: H) -> Self {
49 Self { duration, handler }
50 }
51}
52
53impl<S, H> Layer<S> for TimeoutLayer<H>
54where
55 S: Send + Sync + 'static,
56{
57 type Service = Timeout<S, H>;
58
59 fn layer(self, inner: S) -> Self::Service {
60 Timeout {
61 service: inner,
62 duration: self.duration,
63 handler: self.handler,
64 }
65 }
66}
67
68trait TimeoutHandler<'r> {
69 fn call(self, cx: &'r ServerContext) -> Response;
70}
71
72impl<'r, F, R> TimeoutHandler<'r> for F
73where
74 F: FnOnce(&'r ServerContext) -> R + 'r,
75 R: IntoResponse + 'r,
76{
77 fn call(self, cx: &'r ServerContext) -> Response {
78 self(cx).into_response()
79 }
80}
81
/// The [`Service`] produced by [`TimeoutLayer`].
///
/// It forwards each request to the wrapped service and, if no result arrives within
/// `duration`, replies with the output of `handler` instead.
#[derive(Clone)]
pub struct Timeout<S, H> {
    /// The wrapped (inner) service.
    service: S,
    /// Time limit applied to every call of `service`.
    duration: Duration,
    /// Produces the response when `duration` elapses first.
    handler: H,
}
91
92impl<S, B, H> Service<ServerContext, Request<B>> for Timeout<S, H>
93where
94 S: Service<ServerContext, Request<B>> + Send + Sync + 'static,
95 S::Response: IntoResponse,
96 S::Error: IntoResponse,
97 B: Send,
98 H: for<'r> TimeoutHandler<'r> + Clone + Sync,
99{
100 type Response = Response;
101 type Error = S::Error;
102
103 async fn call(
104 &self,
105 cx: &mut ServerContext,
106 req: Request<B>,
107 ) -> Result<Self::Response, Self::Error> {
108 let fut_service = self.service.call(cx, req);
109 let fut_timeout = tokio::time::sleep(self.duration);
110
111 tokio::select! {
112 resp = fut_service => resp.map(IntoResponse::into_response),
113 _ = fut_timeout => {
114 Ok(self.handler.clone().call(cx))
115 },
116 }
117 }
118}
119
#[cfg(test)]
mod timeout_tests {
    use http::{Method, StatusCode};
    use motore::{Service, layer::Layer};

    use crate::{
        body::BodyConversion,
        context::ServerContext,
        server::{
            route::{Route, get},
            test_helpers::empty_cx,
        },
        utils::test_helpers::simple_req,
    };

    #[tokio::test]
    async fn test_timeout_layer() {
        use std::time::Duration;

        use crate::server::layer::TimeoutLayer;

        const GREETING: &str = "Hello, World";

        async fn index_handler() -> &'static str {
            GREETING
        }

        // Sleeps far longer than the layer's timeout. The layer drops this future once the
        // timeout fires, so the test only waits for the (short) timeout itself.
        async fn index_timeout_handler() -> &'static str {
            tokio::time::sleep(Duration::from_secs(5)).await;
            GREETING
        }

        fn timeout_handler(_: &ServerContext) -> StatusCode {
            StatusCode::REQUEST_TIMEOUT
        }

        let timeout_layer = TimeoutLayer::new(Duration::from_millis(100), timeout_handler);

        let mut cx = empty_cx();

        // A handler slower than the timeout is answered by `timeout_handler`.
        let route: Route<&str> = Route::new(get(index_timeout_handler));
        let service = timeout_layer.clone().layer(route);
        let req = simple_req(Method::GET, "/", "");
        let resp = service.call(&mut cx, req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::REQUEST_TIMEOUT);

        // A handler that answers immediately passes through untouched. This is the last use
        // of `timeout_layer`, so it is consumed rather than cloned.
        let route: Route<&str> = Route::new(get(index_handler));
        let service = timeout_layer.layer(route);
        let req = simple_req(Method::GET, "/", "");
        let resp = service.call(&mut cx, req).await.unwrap();
        assert_eq!(resp.into_body().into_string().await.unwrap(), GREETING);
    }
}