volo_http/server/layer/
timeout.rs

1use std::time::Duration;
2
3use motore::{Service, layer::Layer};
4
5use crate::{context::ServerContext, request::Request, response::Response, server::IntoResponse};
6
/// A [`Layer`] that bounds how long the wrapped service may spend on a single request.
///
/// Build one with [`TimeoutLayer::new`], which documents the expected handler shape.
#[derive(Clone)]
pub struct TimeoutLayer<H> {
    /// Time budget granted to the inner service for each request.
    duration: Duration,
    /// Fallback used to produce the response once `duration` has elapsed.
    handler: H,
}

impl<H> TimeoutLayer<H> {
    /// Builds a [`TimeoutLayer`] from a [`Duration`] and a timeout handler.
    ///
    /// `handler` must be a synchronous function (or closure) that receives a
    /// [`&ServerContext`](ServerContext) and returns any type implementing [`IntoResponse`].
    /// Its output becomes the response whenever the wrapped service takes longer than `duration`.
    ///
    /// # Examples
    ///
    /// ```
    /// use std::time::Duration;
    ///
    /// use http::status::StatusCode;
    /// use volo_http::{
    ///     context::ServerContext,
    ///     server::{
    ///         layer::TimeoutLayer,
    ///         route::{Router, get},
    ///     },
    /// };
    ///
    /// async fn index() -> &'static str {
    ///     "Hello, World"
    /// }
    ///
    /// fn timeout_handler(_: &ServerContext) -> StatusCode {
    ///     StatusCode::REQUEST_TIMEOUT
    /// }
    ///
    /// let router: Router = Router::new()
    ///     .route("/", get(index))
    ///     .layer(TimeoutLayer::new(Duration::from_secs(1), timeout_handler));
    /// ```
    pub fn new(duration: Duration, handler: H) -> Self {
        TimeoutLayer { handler, duration }
    }
}
52
53impl<S, H> Layer<S> for TimeoutLayer<H>
54where
55    S: Send + Sync + 'static,
56{
57    type Service = Timeout<S, H>;
58
59    fn layer(self, inner: S) -> Self::Service {
60        Timeout {
61            service: inner,
62            duration: self.duration,
63            handler: self.handler,
64        }
65    }
66}
67
68trait TimeoutHandler<'r> {
69    fn call(self, cx: &'r ServerContext) -> Response;
70}
71
72impl<'r, F, R> TimeoutHandler<'r> for F
73where
74    F: FnOnce(&'r ServerContext) -> R + 'r,
75    R: IntoResponse + 'r,
76{
77    fn call(self, cx: &'r ServerContext) -> Response {
78        self(cx).into_response()
79    }
80}
81
/// The [`Service`] that [`TimeoutLayer`] wraps around an inner service.
///
/// Refer to [`TimeoutLayer`] for how the timeout and handler are configured.
#[derive(Clone)]
pub struct Timeout<S, H> {
    /// The wrapped service whose calls are time-limited.
    service: S,
    /// How long a single call to `service` may run.
    duration: Duration,
    /// Fallback used to build the response after `duration` elapses.
    handler: H,
}
91
92impl<S, B, H> Service<ServerContext, Request<B>> for Timeout<S, H>
93where
94    S: Service<ServerContext, Request<B>> + Send + Sync + 'static,
95    S::Response: IntoResponse,
96    S::Error: IntoResponse,
97    B: Send,
98    H: for<'r> TimeoutHandler<'r> + Clone + Sync,
99{
100    type Response = Response;
101    type Error = S::Error;
102
103    async fn call(
104        &self,
105        cx: &mut ServerContext,
106        req: Request<B>,
107    ) -> Result<Self::Response, Self::Error> {
108        let fut_service = self.service.call(cx, req);
109        let fut_timeout = tokio::time::sleep(self.duration);
110
111        tokio::select! {
112            resp = fut_service => resp.map(IntoResponse::into_response),
113            _ = fut_timeout => {
114                Ok(self.handler.clone().call(cx))
115            },
116        }
117    }
118}
119
#[cfg(test)]
mod timeout_tests {
    use std::time::Duration;

    use http::{Method, StatusCode};
    use motore::{Service, layer::Layer};

    use crate::{
        body::BodyConversion,
        context::ServerContext,
        server::{
            layer::TimeoutLayer,
            route::{Route, get},
            test_helpers::empty_cx,
        },
        utils::test_helpers::simple_req,
    };

    /// Answers immediately, well inside the timeout window.
    async fn fast_handler() -> &'static str {
        "Hello, World"
    }

    /// Sleeps past the one-second deadline before answering.
    async fn slow_handler() -> &'static str {
        tokio::time::sleep(Duration::from_secs_f64(1.5)).await;
        "Hello, World"
    }

    /// Response the layer should produce once the deadline fires.
    fn on_timeout(_: &ServerContext) -> StatusCode {
        StatusCode::REQUEST_TIMEOUT
    }

    #[tokio::test]
    async fn test_timeout_layer() {
        let timeout = TimeoutLayer::new(Duration::from_secs(1), on_timeout);
        let mut cx = empty_cx();

        // The handler outlives the deadline, so the timeout response is expected.
        let slow: Route<&str> = Route::new(get(slow_handler));
        let timed_out = timeout
            .clone()
            .layer(slow)
            .call(&mut cx, simple_req(Method::GET, "/", ""))
            .await
            .unwrap();
        assert_eq!(timed_out.status(), StatusCode::REQUEST_TIMEOUT);

        // The handler answers in time, so its body must pass through untouched.
        let fast: Route<&str> = Route::new(get(fast_handler));
        let answered = timeout
            .clone()
            .layer(fast)
            .call(&mut cx, simple_req(Method::GET, "/", ""))
            .await
            .unwrap();
        let body = answered.into_body().into_string().await.unwrap();
        assert_eq!(body, "Hello, World");
    }
}