1use bytes::Bytes;
88use futures_util::future::{CatchUnwind, FutureExt};
89use http::{HeaderValue, Request, Response, StatusCode};
90use http_body::Body;
91use http_body_util::BodyExt;
92use pin_project_lite::pin_project;
93use std::{
94 any::Any,
95 future::Future,
96 panic::AssertUnwindSafe,
97 pin::Pin,
98 task::{ready, Context, Poll},
99};
100use tower_layer::Layer;
101use tower_service::Service;
102
103use crate::{
104 body::{Full, UnsyncBoxBody},
105 BoxError,
106};
107
/// Layer that wraps services in the [`CatchPanic`] middleware.
///
/// `T` is the panic handler; a clone of it is handed to every service
/// produced by this layer.
#[derive(Clone, Copy, Debug, Default)]
pub struct CatchPanicLayer<T> {
    /// Handler used to build the response once a panic has been caught.
    panic_handler: T,
}
116
117impl CatchPanicLayer<DefaultResponseForPanic> {
118 pub fn new() -> Self {
120 CatchPanicLayer {
121 panic_handler: DefaultResponseForPanic,
122 }
123 }
124}
125
126impl<T> CatchPanicLayer<T> {
127 pub fn custom(panic_handler: T) -> Self
129 where
130 T: ResponseForPanic,
131 {
132 Self { panic_handler }
133 }
134}
135
136impl<T, S> Layer<S> for CatchPanicLayer<T>
137where
138 T: Clone,
139{
140 type Service = CatchPanic<S, T>;
141
142 fn layer(&self, inner: S) -> Self::Service {
143 CatchPanic {
144 inner,
145 panic_handler: self.panic_handler.clone(),
146 }
147 }
148}
149
/// Middleware that catches panics raised by the wrapped service and turns
/// them into a response built by the panic handler `T`.
///
/// Panics are caught both while the inner service's `call` runs and while
/// the future it returns is being polled.
#[derive(Clone, Copy, Debug)]
pub struct CatchPanic<S, T> {
    /// The wrapped service.
    inner: S,
    /// Handler used to build the response once a panic has been caught.
    panic_handler: T,
}
158
159impl<S> CatchPanic<S, DefaultResponseForPanic> {
160 pub fn new(inner: S) -> Self {
162 Self {
163 inner,
164 panic_handler: DefaultResponseForPanic,
165 }
166 }
167}
168
169impl<S, T> CatchPanic<S, T> {
170 define_inner_service_accessors!();
171
172 pub fn custom(inner: S, panic_handler: T) -> Self
174 where
175 T: ResponseForPanic,
176 {
177 Self {
178 inner,
179 panic_handler,
180 }
181 }
182}
183
184impl<S, T, ReqBody, ResBody> Service<Request<ReqBody>> for CatchPanic<S, T>
185where
186 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
187 ResBody: Body<Data = Bytes> + Send + 'static,
188 ResBody::Error: Into<BoxError>,
189 T: ResponseForPanic + Clone,
190 T::ResponseBody: Body<Data = Bytes> + Send + 'static,
191 <T::ResponseBody as Body>::Error: Into<BoxError>,
192{
193 type Response = Response<UnsyncBoxBody<Bytes, BoxError>>;
194 type Error = S::Error;
195 type Future = ResponseFuture<S::Future, T>;
196
197 #[inline]
198 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
199 self.inner.poll_ready(cx)
200 }
201
202 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
203 match std::panic::catch_unwind(AssertUnwindSafe(|| self.inner.call(req))) {
204 Ok(future) => ResponseFuture {
205 kind: Kind::Future {
206 future: AssertUnwindSafe(future).catch_unwind(),
207 panic_handler: Some(self.panic_handler.clone()),
208 },
209 },
210 Err(panic_err) => ResponseFuture {
211 kind: Kind::Panicked {
212 panic_err: Some(panic_err),
213 panic_handler: Some(self.panic_handler.clone()),
214 },
215 },
216 }
217 }
218}
219
220pin_project! {
221 pub struct ResponseFuture<F, T> {
223 #[pin]
224 kind: Kind<F, T>,
225 }
226}
227
228pin_project! {
229 #[project = KindProj]
230 enum Kind<F, T> {
231 Panicked {
232 panic_err: Option<Box<dyn Any + Send + 'static>>,
233 panic_handler: Option<T>,
234 },
235 Future {
236 #[pin]
237 future: CatchUnwind<AssertUnwindSafe<F>>,
238 panic_handler: Option<T>,
239 }
240 }
241}
242
243impl<F, ResBody, E, T> Future for ResponseFuture<F, T>
244where
245 F: Future<Output = Result<Response<ResBody>, E>>,
246 ResBody: Body<Data = Bytes> + Send + 'static,
247 ResBody::Error: Into<BoxError>,
248 T: ResponseForPanic,
249 T::ResponseBody: Body<Data = Bytes> + Send + 'static,
250 <T::ResponseBody as Body>::Error: Into<BoxError>,
251{
252 type Output = Result<Response<UnsyncBoxBody<Bytes, BoxError>>, E>;
253
254 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
255 match self.project().kind.project() {
256 KindProj::Panicked {
257 panic_err,
258 panic_handler,
259 } => {
260 let panic_handler = panic_handler
261 .take()
262 .expect("future polled after completion");
263 let panic_err = panic_err.take().expect("future polled after completion");
264 Poll::Ready(Ok(response_for_panic(panic_handler, panic_err)))
265 }
266 KindProj::Future {
267 future,
268 panic_handler,
269 } => match ready!(future.poll(cx)) {
270 Ok(Ok(res)) => Poll::Ready(Ok(res.map(|body| {
271 UnsyncBoxBody::from_inner(body.map_err(Into::into).boxed_unsync())
272 }))),
273 Ok(Err(svc_err)) => Poll::Ready(Err(svc_err)),
274 Err(panic_err) => Poll::Ready(Ok(response_for_panic(
275 panic_handler
276 .take()
277 .expect("future polled after completion"),
278 panic_err,
279 ))),
280 },
281 }
282 }
283}
284
285fn response_for_panic<T>(
286 mut panic_handler: T,
287 err: Box<dyn Any + Send + 'static>,
288) -> Response<UnsyncBoxBody<Bytes, BoxError>>
289where
290 T: ResponseForPanic,
291 T::ResponseBody: Body<Data = Bytes> + Send + 'static,
292 <T::ResponseBody as Body>::Error: Into<BoxError>,
293{
294 panic_handler
295 .response_for_panic(err)
296 .map(|body| UnsyncBoxBody::from_inner(body.map_err(Into::into).boxed_unsync()))
297}
298
299pub trait ResponseForPanic: Clone {
301 type ResponseBody;
303
304 fn response_for_panic(
306 &mut self,
307 err: Box<dyn Any + Send + 'static>,
308 ) -> Response<Self::ResponseBody>;
309}
310
311impl<F, B> ResponseForPanic for F
312where
313 F: FnMut(Box<dyn Any + Send + 'static>) -> Response<B> + Clone,
314{
315 type ResponseBody = B;
316
317 fn response_for_panic(
318 &mut self,
319 err: Box<dyn Any + Send + 'static>,
320 ) -> Response<Self::ResponseBody> {
321 self(err)
322 }
323}
324
/// The panic handler used by `CatchPanic::new` and `CatchPanicLayer::new`.
///
/// It logs the panic through `tracing::error!` and answers with a
/// `500 Internal Server Error` whose plain-text body is `Service panicked`.
#[derive(Clone, Copy, Debug, Default)]
#[non_exhaustive]
pub struct DefaultResponseForPanic;
332
333impl ResponseForPanic for DefaultResponseForPanic {
334 type ResponseBody = Full;
335
336 fn response_for_panic(
337 &mut self,
338 err: Box<dyn Any + Send + 'static>,
339 ) -> Response<Self::ResponseBody> {
340 if let Some(s) = err.downcast_ref::<String>() {
341 tracing::error!("Service panicked: {}", s);
342 } else if let Some(s) = err.downcast_ref::<&str>() {
343 tracing::error!("Service panicked: {}", s);
344 } else {
345 tracing::error!(
346 "Service panicked but `CatchPanic` was unable to downcast the panic info"
347 );
348 };
349
350 let mut res = Response::new(Full::new(http_body_util::Full::from("Service panicked")));
351 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
352
353 #[allow(clippy::declare_interior_mutable_const)]
354 const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");
355 res.headers_mut()
356 .insert(http::header::CONTENT_TYPE, TEXT_PLAIN);
357
358 res
359 }
360}
361
#[cfg(test)]
mod tests {
    // The handlers below panic before reaching their tail expressions; those
    // unreachable expressions exist only to pin down the return types.
    #![allow(unreachable_code)]

    use super::*;
    use crate::test_helpers::Body;
    use http::Response;
    use std::convert::Infallible;
    use tower::{ServiceBuilder, ServiceExt};

    /// A panic raised synchronously inside `call`, before any future is
    /// returned, is converted into the default 500 response.
    #[tokio::test]
    async fn panic_before_returning_future() {
        let service = ServiceBuilder::new()
            .layer(CatchPanicLayer::new())
            .service_fn(|_: Request<Body>| {
                panic!("service panic");
                async { Ok::<_, Infallible>(Response::new(Body::empty())) }
            });

        let response = service
            .oneshot(Request::new(Body::empty()))
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        let bytes = crate::test_helpers::to_bytes(response).await.unwrap();
        assert_eq!(&bytes[..], b"Service panicked");
    }

    /// A panic raised while the returned future is polled is converted into
    /// the default 500 response.
    #[tokio::test]
    async fn panic_in_future() {
        let service = ServiceBuilder::new()
            .layer(CatchPanicLayer::new())
            .service_fn(|_: Request<Body>| async {
                panic!("future panic");
                Ok::<_, Infallible>(Response::new(Body::empty()))
            });

        let response = service
            .oneshot(Request::new(Body::empty()))
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        let bytes = crate::test_helpers::to_bytes(response).await.unwrap();
        assert_eq!(&bytes[..], b"Service panicked");
    }
}