tower_async_http/
catch_panic.rs

1//! Convert panics into responses.
2//!
3//! Note that using panics for error handling is _not_ recommended. Prefer instead to use `Result`
4//! whenever possible.
5//!
6//! # Example
7//!
8//! ```rust
9//! use http::{Request, Response, header::HeaderName};
10//! use std::convert::Infallible;
11//! use tower_async::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError};
12//! use tower_async_http::catch_panic::CatchPanicLayer;
13//! use http_body_util::Full;
14//! use bytes::Bytes;
15//!
16//! # #[tokio::main]
17//! # async fn main() -> Result<(), BoxError> {
18//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
19//!     panic!("something went wrong...")
20//! }
21//!
//! let svc = ServiceBuilder::new()
23//!     // Catch panics and convert them into responses.
24//!     .layer(CatchPanicLayer::new())
25//!     .service_fn(handle);
26//!
27//! // Call the service.
28//! let request = Request::new(Full::<Bytes>::default());
29//!
30//! let response = svc.call(request).await?;
31//!
32//! assert_eq!(response.status(), 500);
33//! #
34//! # Ok(())
35//! # }
36//! ```
37//!
38//! Using a custom panic handler:
39//!
40//! ```rust
41//! use http::{Request, StatusCode, Response, header::{self, HeaderName}};
42//! use std::{any::Any, convert::Infallible};
43//! use tower_async::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError};
44//! use tower_async_http::catch_panic::CatchPanicLayer;
45//! use http_body_util::Full;
46//! use bytes::Bytes;
47//!
48//! # #[tokio::main]
49//! # async fn main() -> Result<(), BoxError> {
50//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
51//!     panic!("something went wrong...")
52//! }
53//!
54//! fn handle_panic(err: Box<dyn Any + Send + 'static>) -> Response<Full<Bytes>> {
55//!     let details = if let Some(s) = err.downcast_ref::<String>() {
56//!         s.clone()
57//!     } else if let Some(s) = err.downcast_ref::<&str>() {
58//!         s.to_string()
59//!     } else {
60//!         "Unknown panic message".to_string()
61//!     };
62//!
63//!     let body = serde_json::json!({
64//!         "error": {
65//!             "kind": "panic",
66//!             "details": details,
67//!         }
68//!     });
69//!     let body = serde_json::to_string(&body).unwrap();
70//!
71//!     Response::builder()
72//!         .status(StatusCode::INTERNAL_SERVER_ERROR)
73//!         .header(header::CONTENT_TYPE, "application/json")
74//!         .body(Full::from(body))
75//!         .unwrap()
76//! }
77//!
78//! let svc = ServiceBuilder::new()
79//!     // Use `handle_panic` to create the response.
80//!     .layer(CatchPanicLayer::custom(handle_panic))
81//!     .service_fn(handle);
82//! #
83//! # Ok(())
84//! # }
85//! ```
86
87use bytes::Bytes;
88use futures_util::future::FutureExt;
89use http::{HeaderValue, Request, Response, StatusCode};
90use http_body::Body;
91use http_body_util::{combinators::UnsyncBoxBody, BodyExt, Full};
92use std::{any::Any, panic::AssertUnwindSafe};
93use tower_async_layer::Layer;
94use tower_async_service::Service;
95
96use crate::BoxError;
97
/// Layer that applies the [`CatchPanic`] middleware, which catches panics and converts them into
/// `500 Internal Server Error` responses (or whatever response a custom panic handler builds).
///
/// See the [module docs](self) for an example.
#[derive(Debug, Clone, Copy, Default)]
pub struct CatchPanicLayer<T> {
    /// Handler cloned into every [`CatchPanic`] service produced by this layer.
    panic_handler: T,
}
106
107impl CatchPanicLayer<DefaultResponseForPanic> {
108    /// Create a new `CatchPanicLayer` with the default panic handler.
109    pub fn new() -> Self {
110        CatchPanicLayer {
111            panic_handler: DefaultResponseForPanic,
112        }
113    }
114}
115
116impl<T> CatchPanicLayer<T> {
117    /// Create a new `CatchPanicLayer` with a custom panic handler.
118    pub fn custom(panic_handler: T) -> Self
119    where
120        T: ResponseForPanic,
121    {
122        Self { panic_handler }
123    }
124}
125
126impl<T, S> Layer<S> for CatchPanicLayer<T>
127where
128    T: Clone,
129{
130    type Service = CatchPanic<S, T>;
131
132    fn layer(&self, inner: S) -> Self::Service {
133        CatchPanic {
134            inner,
135            panic_handler: self.panic_handler.clone(),
136        }
137    }
138}
139
/// Middleware that catches panics and converts them into `500 Internal Server Error` responses
/// (or whatever response a custom panic handler builds).
///
/// Panics raised while the inner service creates its future, and panics raised while that
/// future is being polled, are both caught.
///
/// See the [module docs](self) for an example.
#[derive(Debug, Clone, Copy)]
pub struct CatchPanic<S, T> {
    /// The wrapped service.
    inner: S,
    /// Builds the response returned in place of a panic.
    panic_handler: T,
}
148
149impl<S> CatchPanic<S, DefaultResponseForPanic> {
150    /// Create a new `CatchPanic` with the default panic handler.
151    pub fn new(inner: S) -> Self {
152        Self {
153            inner,
154            panic_handler: DefaultResponseForPanic,
155        }
156    }
157}
158
159impl<S, T> CatchPanic<S, T> {
160    define_inner_service_accessors!();
161
162    /// Create a new `CatchPanic` with a custom panic handler.
163    pub fn custom(inner: S, panic_handler: T) -> Self
164    where
165        T: ResponseForPanic,
166    {
167        Self {
168            inner,
169            panic_handler,
170        }
171    }
172}
173
174impl<S, T, ReqBody, ResBody> Service<Request<ReqBody>> for CatchPanic<S, T>
175where
176    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
177    ResBody: Body<Data = Bytes> + Send + 'static,
178    ResBody::Error: Into<BoxError>,
179    T: ResponseForPanic + Clone,
180    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
181    <T::ResponseBody as Body>::Error: Into<BoxError>,
182{
183    type Response = Response<UnsyncBoxBody<Bytes, BoxError>>;
184    type Error = S::Error;
185
186    async fn call(&self, req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
187        let future = match std::panic::catch_unwind(AssertUnwindSafe(|| self.inner.call(req))) {
188            Ok(future) => future,
189            Err(panic_err) => {
190                return Ok(self
191                    .panic_handler
192                    .response_for_panic(panic_err)
193                    .map(|body| body.map_err(Into::into).boxed_unsync()))
194            }
195        };
196        match AssertUnwindSafe(future).catch_unwind().await {
197            Ok(res) => match res {
198                Ok(res) => Ok(res.map(|body| body.map_err(Into::into).boxed_unsync())),
199                Err(err) => Err(err),
200            },
201            Err(panic_err) => Ok(self
202                .panic_handler
203                .response_for_panic(panic_err)
204                .map(|body| body.map_err(Into::into).boxed_unsync())),
205        }
206    }
207}
208
209/// Trait for creating responses from panics.
210pub trait ResponseForPanic: Clone {
211    /// The body type used for responses to panics.
212    type ResponseBody;
213
214    /// Create a response from the panic error.
215    fn response_for_panic(
216        &self,
217        err: Box<dyn Any + Send + 'static>,
218    ) -> Response<Self::ResponseBody>;
219}
220
221impl<F, B> ResponseForPanic for F
222where
223    F: Fn(Box<dyn Any + Send + 'static>) -> Response<B> + Clone,
224{
225    type ResponseBody = B;
226
227    fn response_for_panic(
228        &self,
229        err: Box<dyn Any + Send + 'static>,
230    ) -> Response<Self::ResponseBody> {
231        self(err)
232    }
233}
234
235/// The default `ResponseForPanic` used by `CatchPanic`.
236///
237/// It will log the panic message and return a `500 Internal Server` error response with an empty
238/// body.
239#[derive(Debug, Default, Clone, Copy)]
240#[non_exhaustive]
241pub struct DefaultResponseForPanic;
242
243impl ResponseForPanic for DefaultResponseForPanic {
244    type ResponseBody = Full<Bytes>;
245
246    fn response_for_panic(
247        &self,
248        err: Box<dyn Any + Send + 'static>,
249    ) -> Response<Self::ResponseBody> {
250        if let Some(s) = err.downcast_ref::<String>() {
251            tracing::error!("Service panicked: {}", s);
252        } else if let Some(s) = err.downcast_ref::<&str>() {
253            tracing::error!("Service panicked: {}", s);
254        } else {
255            tracing::error!(
256                "Service panicked but `CatchPanic` was unable to downcast the panic info"
257            );
258        };
259
260        let mut res = Response::new(Full::from("Service panicked"));
261        *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
262
263        #[allow(clippy::declare_interior_mutable_const)]
264        const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");
265        res.headers_mut()
266            .insert(http::header::CONTENT_TYPE, TEXT_PLAIN);
267
268        res
269    }
270}
271
#[cfg(test)]
mod tests {
    // The `panic!` calls below deliberately make the code after them unreachable.
    #![allow(unreachable_code)]

    use super::*;

    use crate::test_helpers::{self, Body};

    use hyper::Response;
    use std::convert::Infallible;
    use tower_async::{ServiceBuilder, ServiceExt};

    /// The service closure panics synchronously, before any future exists.
    #[tokio::test]
    async fn panic_before_returning_future() {
        let svc = ServiceBuilder::new()
            .layer(CatchPanicLayer::new())
            .service_fn(|_: Request<Body>| {
                panic!("service panic");
                async { Ok::<_, Infallible>(Response::new(Body::empty())) }
            });

        let response = svc.oneshot(Request::new(Body::empty())).await.unwrap();
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);

        let bytes = test_helpers::to_bytes(response).await.unwrap();
        assert_eq!(&bytes[..], b"Service panicked");
    }

    /// The returned future panics while it is being polled.
    #[tokio::test]
    async fn panic_in_future() {
        let svc = ServiceBuilder::new()
            .layer(CatchPanicLayer::new())
            .service_fn(|_: Request<Body>| async {
                panic!("future panic");
                Ok::<_, Infallible>(Response::new(Body::empty()))
            });

        let response = svc.oneshot(Request::new(Body::empty())).await.unwrap();
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);

        let bytes = test_helpers::to_bytes(response).await.unwrap();
        assert_eq!(&bytes[..], b"Service panicked");
    }
}