xitca_web/middleware/
catch_unwind.rs

1//! panic catcher middleware.
2
3use xitca_http::util::middleware::catch_unwind::{self, CatchUnwindError};
4
5use crate::{
6    WebContext,
7    error::{Error, ThreadJoinError},
8    service::{Service, ready::ReadyService},
9};
10
/// middleware that catches a panic raised inside [`Service::call`] and returns a 500 error response.
///
/// # Examples
/// ```rust
/// # use xitca_web::{handler::handler_service, middleware::CatchUnwind, service::ServiceExt, App, WebContext};
/// // handler function that always panics.
/// async fn handler(_: &WebContext<'_>) -> &'static str {
///     panic!("");
/// }
///
/// App::new()
///     // a request to "/" always panics because of the handler function.
///     .at("/", handler_service(handler))
///     // enclose the application with the CatchUnwind middleware.
///     // a panic in the handler function is caught and converted to a 500 internal server error response for the client.
///     .enclosed(CatchUnwind);
///
/// // CatchUnwind can also be used on an individual route service for scoped panic catching:
/// App::new()
///     .at("/", handler_service(handler))
///     // only catch panics on the "/scope" path.
///     .at("/scope", handler_service(handler).enclosed(CatchUnwind));
/// ```
#[derive(Clone)]
pub struct CatchUnwind;
36
37impl<Arg> Service<Arg> for CatchUnwind
38where
39    catch_unwind::CatchUnwind: Service<Arg>,
40{
41    type Response = CatchUnwindService<<catch_unwind::CatchUnwind as Service<Arg>>::Response>;
42    type Error = <catch_unwind::CatchUnwind as Service<Arg>>::Error;
43
44    async fn call(&self, arg: Arg) -> Result<Self::Response, Self::Error> {
45        catch_unwind::CatchUnwind.call(arg).await.map(CatchUnwindService)
46    }
47}
48
/// service type produced by the [`CatchUnwind`] middleware.
///
/// wraps an inner service `S`. when called with a [`WebContext`] it forwards the call to the
/// inner service and converts the inner service's error type into [`Error`].
pub struct CatchUnwindService<S>(S);
50
51impl<'r, C, B, S> Service<WebContext<'r, C, B>> for CatchUnwindService<S>
52where
53    S: Service<WebContext<'r, C, B>>,
54    S::Error: Into<Error>,
55{
56    type Response = S::Response;
57    type Error = Error;
58
59    #[inline]
60    async fn call(&self, ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
61        self.0.call(ctx).await.map_err(Into::into)
62    }
63}
64
65impl<E> From<CatchUnwindError<E>> for Error
66where
67    E: Into<Error>,
68{
69    fn from(e: CatchUnwindError<E>) -> Self {
70        match e {
71            CatchUnwindError::First(e) => Error::from(ThreadJoinError::new(e)),
72            CatchUnwindError::Second(e) => e.into(),
73        }
74    }
75}
76
77impl<S> ReadyService for CatchUnwindService<S>
78where
79    S: ReadyService,
80{
81    type Ready = S::Ready;
82
83    #[inline]
84    async fn ready(&self) -> Self::Ready {
85        self.0.ready().await
86    }
87}
88
#[cfg(test)]
mod test {
    use xitca_unsafe_collection::futures::NowOrPanic;

    use crate::{
        App,
        handler::handler_service,
        http::{Request, StatusCode},
    };

    use super::*;

    /// a panicking handler enclosed by [`CatchUnwind`] must produce a 500 response.
    #[test]
    fn catch_panic() {
        // handler that panics as soon as it runs.
        async fn handler() -> &'static str {
            panic!("");
        }

        // application with the whole app enclosed by the middleware.
        let app = App::new()
            .with_state("996")
            .at("/", handler_service(handler))
            .enclosed(CatchUnwind)
            .finish();

        // build the service from the application.
        let service = app.call(()).now_or_panic().unwrap();

        // send a default request through the service; the handler panics while serving it.
        let res = service.call(Request::default()).now_or_panic().unwrap();

        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }
}