1use bytes::Bytes;
88use futures_util::future::{CatchUnwind, FutureExt};
89use http::{HeaderValue, Request, Response, StatusCode};
90use http_body::Body;
91use http_body_util::BodyExt;
92use pin_project_lite::pin_project;
93use std::{
94    any::Any,
95    future::Future,
96    panic::AssertUnwindSafe,
97    pin::Pin,
98    task::{ready, Context, Poll},
99};
100use tower_layer::Layer;
101use tower_service::Service;
102
103use crate::{
104    body::{Full, UnsyncBoxBody},
105    BoxError,
106};
107
/// Layer that wraps services in [`CatchPanic`], converting panics into responses.
///
/// `T` is the panic handler; the layer clones it into every service it produces.
#[derive(Debug, Clone, Copy, Default)]
pub struct CatchPanicLayer<T> {
    // Handler used to build the response when a wrapped service panics.
    panic_handler: T,
}
116
117impl CatchPanicLayer<DefaultResponseForPanic> {
118    pub fn new() -> Self {
120        CatchPanicLayer {
121            panic_handler: DefaultResponseForPanic,
122        }
123    }
124}
125
126impl<T> CatchPanicLayer<T> {
127    pub fn custom(panic_handler: T) -> Self
129    where
130        T: ResponseForPanic,
131    {
132        Self { panic_handler }
133    }
134}
135
136impl<T, S> Layer<S> for CatchPanicLayer<T>
137where
138    T: Clone,
139{
140    type Service = CatchPanic<S, T>;
141
142    fn layer(&self, inner: S) -> Self::Service {
143        CatchPanic {
144            inner,
145            panic_handler: self.panic_handler.clone(),
146        }
147    }
148}
149
/// Middleware that catches panics from the inner service and turns them into responses.
///
/// Panics are caught both while calling the inner service (before it returns a future)
/// and while polling the future it returns.
#[derive(Debug, Clone, Copy)]
pub struct CatchPanic<S, T> {
    // The wrapped service.
    inner: S,
    // Handler invoked with the panic payload to build the fallback response.
    panic_handler: T,
}
158
159impl<S> CatchPanic<S, DefaultResponseForPanic> {
160    pub fn new(inner: S) -> Self {
162        Self {
163            inner,
164            panic_handler: DefaultResponseForPanic,
165        }
166    }
167}
168
169impl<S, T> CatchPanic<S, T> {
170    define_inner_service_accessors!();
171
172    pub fn custom(inner: S, panic_handler: T) -> Self
174    where
175        T: ResponseForPanic,
176    {
177        Self {
178            inner,
179            panic_handler,
180        }
181    }
182}
183
184impl<S, T, ReqBody, ResBody> Service<Request<ReqBody>> for CatchPanic<S, T>
185where
186    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
187    ResBody: Body<Data = Bytes> + Send + 'static,
188    ResBody::Error: Into<BoxError>,
189    T: ResponseForPanic + Clone,
190    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
191    <T::ResponseBody as Body>::Error: Into<BoxError>,
192{
193    type Response = Response<UnsyncBoxBody<Bytes, BoxError>>;
194    type Error = S::Error;
195    type Future = ResponseFuture<S::Future, T>;
196
197    #[inline]
198    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
199        self.inner.poll_ready(cx)
200    }
201
202    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
203        match std::panic::catch_unwind(AssertUnwindSafe(|| self.inner.call(req))) {
204            Ok(future) => ResponseFuture {
205                kind: Kind::Future {
206                    future: AssertUnwindSafe(future).catch_unwind(),
207                    panic_handler: Some(self.panic_handler.clone()),
208                },
209            },
210            Err(panic_err) => ResponseFuture {
211                kind: Kind::Panicked {
212                    panic_err: Some(panic_err),
213                    panic_handler: Some(self.panic_handler.clone()),
214                },
215            },
216        }
217    }
218}
219
pin_project! {
    /// Response future for [`CatchPanic`].
    ///
    /// Resolves to the inner service's response (with its body boxed), the inner service's
    /// error, or a response built by the panic handler if a panic was caught.
    pub struct ResponseFuture<F, T> {
        // Current state of the future; see `Kind`.
        #[pin]
        kind: Kind<F, T>,
    }
}
227
pin_project! {
    // Internal state of `ResponseFuture`. The `Option`s let `poll` move values out
    // through the pinned projection with `take()`.
    #[project = KindProj]
    enum Kind<F, T> {
        // Calling the inner service panicked; the payload and handler are consumed on
        // the first poll.
        Panicked {
            panic_err: Option<Box<dyn Any + Send + 'static>>,
            panic_handler: Option<T>,
        },
        // The inner service returned a future, which is polled with panics caught. The
        // handler is only taken if that future panics.
        Future {
            #[pin]
            future: CatchUnwind<AssertUnwindSafe<F>>,
            panic_handler: Option<T>,
        }
    }
}
242
243impl<F, ResBody, E, T> Future for ResponseFuture<F, T>
244where
245    F: Future<Output = Result<Response<ResBody>, E>>,
246    ResBody: Body<Data = Bytes> + Send + 'static,
247    ResBody::Error: Into<BoxError>,
248    T: ResponseForPanic,
249    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
250    <T::ResponseBody as Body>::Error: Into<BoxError>,
251{
252    type Output = Result<Response<UnsyncBoxBody<Bytes, BoxError>>, E>;
253
254    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
255        match self.project().kind.project() {
256            KindProj::Panicked {
257                panic_err,
258                panic_handler,
259            } => {
260                let panic_handler = panic_handler
261                    .take()
262                    .expect("future polled after completion");
263                let panic_err = panic_err.take().expect("future polled after completion");
264                Poll::Ready(Ok(response_for_panic(panic_handler, panic_err)))
265            }
266            KindProj::Future {
267                future,
268                panic_handler,
269            } => match ready!(future.poll(cx)) {
270                Ok(Ok(res)) => {
271                    Poll::Ready(Ok(res.map(|body| {
272                        UnsyncBoxBody::new(body.map_err(Into::into).boxed_unsync())
273                    })))
274                }
275                Ok(Err(svc_err)) => Poll::Ready(Err(svc_err)),
276                Err(panic_err) => Poll::Ready(Ok(response_for_panic(
277                    panic_handler
278                        .take()
279                        .expect("future polled after completion"),
280                    panic_err,
281                ))),
282            },
283        }
284    }
285}
286
287fn response_for_panic<T>(
288    mut panic_handler: T,
289    err: Box<dyn Any + Send + 'static>,
290) -> Response<UnsyncBoxBody<Bytes, BoxError>>
291where
292    T: ResponseForPanic,
293    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
294    <T::ResponseBody as Body>::Error: Into<BoxError>,
295{
296    panic_handler
297        .response_for_panic(err)
298        .map(|body| UnsyncBoxBody::new(body.map_err(Into::into).boxed_unsync()))
299}
300
/// Builds the response returned by [`CatchPanic`] when the inner service panics.
///
/// Implemented by [`DefaultResponseForPanic`] and by every cloneable
/// `FnMut(Box<dyn Any + Send + 'static>) -> Response<B>` closure.
pub trait ResponseForPanic: Clone {
    /// Body type of the produced response.
    type ResponseBody;

    /// Converts a caught panic payload into a response.
    ///
    /// `err` is the value the panic was raised with; for `panic!` it is typically a
    /// `String` or a `&'static str`.
    fn response_for_panic(
        &mut self,
        err: Box<dyn Any + Send + 'static>,
    ) -> Response<Self::ResponseBody>;
}
312
313impl<F, B> ResponseForPanic for F
314where
315    F: FnMut(Box<dyn Any + Send + 'static>) -> Response<B> + Clone,
316{
317    type ResponseBody = B;
318
319    fn response_for_panic(
320        &mut self,
321        err: Box<dyn Any + Send + 'static>,
322    ) -> Response<Self::ResponseBody> {
323        self(err)
324    }
325}
326
/// Panic handler used by [`CatchPanicLayer::new`] and [`CatchPanic::new`].
///
/// Logs the panic message at error level via `tracing` (when the payload is a `String` or
/// `&str`). It then responds with `500 Internal Server Error`, a
/// `text/plain; charset=utf-8` content type and the body `Service panicked`.
#[derive(Debug, Default, Clone, Copy)]
#[non_exhaustive]
pub struct DefaultResponseForPanic;
334
335impl ResponseForPanic for DefaultResponseForPanic {
336    type ResponseBody = Full;
337
338    fn response_for_panic(
339        &mut self,
340        err: Box<dyn Any + Send + 'static>,
341    ) -> Response<Self::ResponseBody> {
342        if let Some(s) = err.downcast_ref::<String>() {
343            tracing::error!("Service panicked: {}", s);
344        } else if let Some(s) = err.downcast_ref::<&str>() {
345            tracing::error!("Service panicked: {}", s);
346        } else {
347            tracing::error!(
348                "Service panicked but `CatchPanic` was unable to downcast the panic info"
349            );
350        };
351
352        let mut res = Response::new(Full::new(http_body_util::Full::from("Service panicked")));
353        *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
354
355        #[allow(clippy::declare_interior_mutable_const)]
356        const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");
357        res.headers_mut()
358            .insert(http::header::CONTENT_TYPE, TEXT_PLAIN);
359
360        res
361    }
362}
363
#[cfg(test)]
mod tests {
    #![allow(unreachable_code)]

    use super::*;
    use crate::test_helpers::Body;
    use http::Response;
    use std::convert::Infallible;
    use tower::{ServiceBuilder, ServiceExt};

    /// Asserts that `res` is the response produced by `DefaultResponseForPanic`.
    async fn assert_default_panic_response(res: Response<UnsyncBoxBody<Bytes, BoxError>>) {
        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
        let body = crate::test_helpers::to_bytes(res).await.unwrap();
        assert_eq!(&body[..], b"Service panicked");
    }

    #[tokio::test]
    async fn panic_before_returning_future() {
        let svc = ServiceBuilder::new()
            .layer(CatchPanicLayer::new())
            .service_fn(|_: Request<Body>| {
                panic!("service panic");
                async { Ok::<_, Infallible>(Response::new(Body::empty())) }
            });

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();
        assert_default_panic_response(res).await;
    }

    #[tokio::test]
    async fn panic_in_future() {
        let svc = ServiceBuilder::new()
            .layer(CatchPanicLayer::new())
            .service_fn(|_: Request<Body>| async {
                panic!("future panic");
                Ok::<_, Infallible>(Response::new(Body::empty()))
            });

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();
        assert_default_panic_response(res).await;
    }
}