Skip to main content

tower_http/
catch_panic.rs

1//! Convert panics into responses.
2//!
3//! Note that using panics for error handling is _not_ recommended. Prefer instead to use `Result`
4//! whenever possible.
5//!
6//! # Example
7//!
8//! ```rust
9//! use http::{Request, Response, header::HeaderName};
10//! use std::convert::Infallible;
11//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn};
12//! use tower_http::catch_panic::CatchPanicLayer;
13//! use http_body_util::Full;
14//! use bytes::Bytes;
15//!
16//! # #[tokio::main]
17//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
18//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
19//!     panic!("something went wrong...")
20//! }
21//!
22//! let mut svc = ServiceBuilder::new()
23//!     // Catch panics and convert them into responses.
24//!     .layer(CatchPanicLayer::new())
25//!     .service_fn(handle);
26//!
27//! // Call the service.
28//! let request = Request::new(Full::default());
29//!
30//! let response = svc.ready().await?.call(request).await?;
31//!
32//! assert_eq!(response.status(), 500);
33//! #
34//! # Ok(())
35//! # }
36//! ```
37//!
38//! Using a custom panic handler:
39//!
40//! ```rust
41//! use http::{Request, StatusCode, Response, header::{self, HeaderName}};
42//! use std::{any::Any, convert::Infallible};
43//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn};
44//! use tower_http::catch_panic::CatchPanicLayer;
45//! use bytes::Bytes;
46//! use http_body_util::Full;
47//!
48//! # #[tokio::main]
49//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
50//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
51//!     panic!("something went wrong...")
52//! }
53//!
54//! fn handle_panic(err: Box<dyn Any + Send + 'static>) -> Response<Full<Bytes>> {
55//!     let details = if let Some(s) = err.downcast_ref::<String>() {
56//!         s.clone()
57//!     } else if let Some(s) = err.downcast_ref::<&str>() {
58//!         s.to_string()
59//!     } else {
60//!         "Unknown panic message".to_string()
61//!     };
62//!
63//!     let body = serde_json::json!({
64//!         "error": {
65//!             "kind": "panic",
66//!             "details": details,
67//!         }
68//!     });
69//!     let body = serde_json::to_string(&body).unwrap();
70//!
71//!     Response::builder()
72//!         .status(StatusCode::INTERNAL_SERVER_ERROR)
73//!         .header(header::CONTENT_TYPE, "application/json")
74//!         .body(Full::from(body))
75//!         .unwrap()
76//! }
77//!
78//! let svc = ServiceBuilder::new()
79//!     // Use `handle_panic` to create the response.
80//!     .layer(CatchPanicLayer::custom(handle_panic))
81//!     .service_fn(handle);
82//! #
83//! # Ok(())
84//! # }
85//! ```
86
87use bytes::Bytes;
88use futures_util::future::{CatchUnwind, FutureExt};
89use http::{HeaderValue, Request, Response, StatusCode};
90use http_body::Body;
91use http_body_util::BodyExt;
92use pin_project_lite::pin_project;
93use std::{
94    any::Any,
95    future::Future,
96    panic::AssertUnwindSafe,
97    pin::Pin,
98    task::{ready, Context, Poll},
99};
100use tower_layer::Layer;
101use tower_service::Service;
102
103use crate::{
104    body::{Full, UnsyncBoxBody},
105    BoxError,
106};
107
/// Layer that applies the [`CatchPanic`] middleware, which catches panics and converts them into
/// `500 Internal Server Error` responses.
///
/// The type parameter `T` is the panic handler; a clone of it is handed to every [`CatchPanic`]
/// service this layer produces.
///
/// See the [module docs](self) for an example.
#[derive(Debug, Clone, Copy, Default)]
pub struct CatchPanicLayer<T> {
    // Handler that turns a caught panic payload into a response.
    panic_handler: T,
}
116
117impl CatchPanicLayer<DefaultResponseForPanic> {
118    /// Create a new `CatchPanicLayer` with the default panic handler.
119    pub fn new() -> Self {
120        CatchPanicLayer {
121            panic_handler: DefaultResponseForPanic,
122        }
123    }
124}
125
126impl<T> CatchPanicLayer<T> {
127    /// Create a new `CatchPanicLayer` with a custom panic handler.
128    pub fn custom(panic_handler: T) -> Self
129    where
130        T: ResponseForPanic,
131    {
132        Self { panic_handler }
133    }
134}
135
136impl<T, S> Layer<S> for CatchPanicLayer<T>
137where
138    T: Clone,
139{
140    type Service = CatchPanic<S, T>;
141
142    fn layer(&self, inner: S) -> Self::Service {
143        CatchPanic {
144            inner,
145            panic_handler: self.panic_handler.clone(),
146        }
147    }
148}
149
/// Middleware that catches panics and converts them into `500 Internal Server Error` responses.
///
/// `S` is the wrapped service and `T` is the panic handler used to build the response when the
/// wrapped service panics.
///
/// See the [module docs](self) for an example.
#[derive(Debug, Clone, Copy)]
pub struct CatchPanic<S, T> {
    // The wrapped service whose panics are caught.
    inner: S,
    // Handler that turns a caught panic payload into a response.
    panic_handler: T,
}
158
159impl<S> CatchPanic<S, DefaultResponseForPanic> {
160    /// Create a new `CatchPanic` with the default panic handler.
161    pub fn new(inner: S) -> Self {
162        Self {
163            inner,
164            panic_handler: DefaultResponseForPanic,
165        }
166    }
167}
168
169impl<S, T> CatchPanic<S, T> {
170    define_inner_service_accessors!();
171
172    /// Create a new `CatchPanic` with a custom panic handler.
173    pub fn custom(inner: S, panic_handler: T) -> Self
174    where
175        T: ResponseForPanic,
176    {
177        Self {
178            inner,
179            panic_handler,
180        }
181    }
182}
183
184impl<S, T, ReqBody, ResBody> Service<Request<ReqBody>> for CatchPanic<S, T>
185where
186    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
187    ResBody: Body<Data = Bytes> + Send + 'static,
188    ResBody::Error: Into<BoxError>,
189    T: ResponseForPanic + Clone,
190    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
191    <T::ResponseBody as Body>::Error: Into<BoxError>,
192{
193    type Response = Response<UnsyncBoxBody<Bytes, BoxError>>;
194    type Error = S::Error;
195    type Future = ResponseFuture<S::Future, T>;
196
197    #[inline]
198    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
199        self.inner.poll_ready(cx)
200    }
201
202    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
203        match std::panic::catch_unwind(AssertUnwindSafe(|| self.inner.call(req))) {
204            Ok(future) => ResponseFuture {
205                kind: Kind::Future {
206                    future: AssertUnwindSafe(future).catch_unwind(),
207                    panic_handler: Some(self.panic_handler.clone()),
208                },
209            },
210            Err(panic_err) => ResponseFuture {
211                kind: Kind::Panicked {
212                    panic_err: Some(panic_err),
213                    panic_handler: Some(self.panic_handler.clone()),
214                },
215            },
216        }
217    }
218}
219
220pin_project! {
221    /// Response future for [`CatchPanic`].
222    pub struct ResponseFuture<F, T> {
223        #[pin]
224        kind: Kind<F, T>,
225    }
226}
227
228pin_project! {
229    #[project = KindProj]
230    enum Kind<F, T> {
231        Panicked {
232            panic_err: Option<Box<dyn Any + Send + 'static>>,
233            panic_handler: Option<T>,
234        },
235        Future {
236            #[pin]
237            future: CatchUnwind<AssertUnwindSafe<F>>,
238            panic_handler: Option<T>,
239        }
240    }
241}
242
243impl<F, ResBody, E, T> Future for ResponseFuture<F, T>
244where
245    F: Future<Output = Result<Response<ResBody>, E>>,
246    ResBody: Body<Data = Bytes> + Send + 'static,
247    ResBody::Error: Into<BoxError>,
248    T: ResponseForPanic,
249    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
250    <T::ResponseBody as Body>::Error: Into<BoxError>,
251{
252    type Output = Result<Response<UnsyncBoxBody<Bytes, BoxError>>, E>;
253
254    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
255        match self.project().kind.project() {
256            KindProj::Panicked {
257                panic_err,
258                panic_handler,
259            } => {
260                let panic_handler = panic_handler
261                    .take()
262                    .expect("future polled after completion");
263                let panic_err = panic_err.take().expect("future polled after completion");
264                Poll::Ready(Ok(response_for_panic(panic_handler, panic_err)))
265            }
266            KindProj::Future {
267                future,
268                panic_handler,
269            } => match ready!(future.poll(cx)) {
270                Ok(Ok(res)) => Poll::Ready(Ok(res.map(|body| {
271                    UnsyncBoxBody::from_inner(body.map_err(Into::into).boxed_unsync())
272                }))),
273                Ok(Err(svc_err)) => Poll::Ready(Err(svc_err)),
274                Err(panic_err) => Poll::Ready(Ok(response_for_panic(
275                    panic_handler
276                        .take()
277                        .expect("future polled after completion"),
278                    panic_err,
279                ))),
280            },
281        }
282    }
283}
284
285fn response_for_panic<T>(
286    mut panic_handler: T,
287    err: Box<dyn Any + Send + 'static>,
288) -> Response<UnsyncBoxBody<Bytes, BoxError>>
289where
290    T: ResponseForPanic,
291    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
292    <T::ResponseBody as Body>::Error: Into<BoxError>,
293{
294    panic_handler
295        .response_for_panic(err)
296        .map(|body| UnsyncBoxBody::from_inner(body.map_err(Into::into).boxed_unsync()))
297}
298
299/// Trait for creating responses from panics.
300pub trait ResponseForPanic: Clone {
301    /// The body type used for responses to panics.
302    type ResponseBody;
303
304    /// Create a response from the panic error.
305    fn response_for_panic(
306        &mut self,
307        err: Box<dyn Any + Send + 'static>,
308    ) -> Response<Self::ResponseBody>;
309}
310
311impl<F, B> ResponseForPanic for F
312where
313    F: FnMut(Box<dyn Any + Send + 'static>) -> Response<B> + Clone,
314{
315    type ResponseBody = B;
316
317    fn response_for_panic(
318        &mut self,
319        err: Box<dyn Any + Send + 'static>,
320    ) -> Response<Self::ResponseBody> {
321        self(err)
322    }
323}
324
/// The default `ResponseForPanic` used by `CatchPanic`.
///
/// It logs the panic message via `tracing` and returns a `500 Internal Server Error` response
/// whose body is the fixed plain-text string `Service panicked`
/// (`Content-Type: text/plain; charset=utf-8`). The panic message itself is only logged; it is
/// not included in the response.
#[derive(Debug, Default, Clone, Copy)]
#[non_exhaustive]
pub struct DefaultResponseForPanic;
332
333impl ResponseForPanic for DefaultResponseForPanic {
334    type ResponseBody = Full;
335
336    fn response_for_panic(
337        &mut self,
338        err: Box<dyn Any + Send + 'static>,
339    ) -> Response<Self::ResponseBody> {
340        if let Some(s) = err.downcast_ref::<String>() {
341            tracing::error!("Service panicked: {}", s);
342        } else if let Some(s) = err.downcast_ref::<&str>() {
343            tracing::error!("Service panicked: {}", s);
344        } else {
345            tracing::error!(
346                "Service panicked but `CatchPanic` was unable to downcast the panic info"
347            );
348        };
349
350        let mut res = Response::new(Full::new(http_body_util::Full::from("Service panicked")));
351        *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
352
353        #[allow(clippy::declare_interior_mutable_const)]
354        const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");
355        res.headers_mut()
356            .insert(http::header::CONTENT_TYPE, TEXT_PLAIN);
357
358        res
359    }
360}
361
#[cfg(test)]
mod tests {
    // Each handler below panics before its final expression, which the compiler reports as
    // unreachable; the trailing expressions are kept only so the return types can be inferred.
    #![allow(unreachable_code)]

    use super::*;
    use crate::test_helpers::Body;
    use http::Response;
    use std::convert::Infallible;
    use tower::{ServiceBuilder, ServiceExt};

    /// A panic raised inside `call`, before any future exists, becomes the default 500 response.
    #[tokio::test]
    async fn panic_before_returning_future() {
        let service = ServiceBuilder::new()
            .layer(CatchPanicLayer::new())
            .service_fn(|_: Request<Body>| {
                panic!("service panic");
                async { Ok::<_, Infallible>(Response::new(Body::empty())) }
            });

        let response = service
            .oneshot(Request::new(Body::empty()))
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        let bytes = crate::test_helpers::to_bytes(response).await.unwrap();
        assert_eq!(&bytes[..], b"Service panicked");
    }

    /// A panic raised while the returned future is polled becomes the default 500 response.
    #[tokio::test]
    async fn panic_in_future() {
        let service = ServiceBuilder::new()
            .layer(CatchPanicLayer::new())
            .service_fn(|_: Request<Body>| async {
                panic!("future panic");
                Ok::<_, Infallible>(Response::new(Body::empty()))
            });

        let response = service
            .oneshot(Request::new(Body::empty()))
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        let bytes = crate::test_helpers::to_bytes(response).await.unwrap();
        assert_eq!(&bytes[..], b"Service panicked");
    }
}