tower_async_http/
catch_panic.rs1use bytes::Bytes;
88use futures_util::future::FutureExt;
89use http::{HeaderValue, Request, Response, StatusCode};
90use http_body::Body;
91use http_body_util::{combinators::UnsyncBoxBody, BodyExt, Full};
92use std::{any::Any, panic::AssertUnwindSafe};
93use tower_async_layer::Layer;
94use tower_async_service::Service;
95
96use crate::BoxError;
97
/// Layer that wraps services in [`CatchPanic`], turning panics into responses.
///
/// [`CatchPanicLayer::new`] uses [`DefaultResponseForPanic`]; use
/// [`CatchPanicLayer::custom`] to supply a different [`ResponseForPanic`].
#[derive(Default, Debug, Copy, Clone)]
pub struct CatchPanicLayer<T> {
    /// Handler that is cloned into every [`CatchPanic`] this layer produces.
    panic_handler: T,
}
106
107impl CatchPanicLayer<DefaultResponseForPanic> {
108 pub fn new() -> Self {
110 CatchPanicLayer {
111 panic_handler: DefaultResponseForPanic,
112 }
113 }
114}
115
116impl<T> CatchPanicLayer<T> {
117 pub fn custom(panic_handler: T) -> Self
119 where
120 T: ResponseForPanic,
121 {
122 Self { panic_handler }
123 }
124}
125
126impl<T, S> Layer<S> for CatchPanicLayer<T>
127where
128 T: Clone,
129{
130 type Service = CatchPanic<S, T>;
131
132 fn layer(&self, inner: S) -> Self::Service {
133 CatchPanic {
134 inner,
135 panic_handler: self.panic_handler.clone(),
136 }
137 }
138}
139
/// Middleware that catches panics from the inner service and converts them into
/// responses via a [`ResponseForPanic`] handler.
///
/// Panics are caught both while calling the inner service and while awaiting the
/// future that call returns.
#[derive(Copy, Clone, Debug)]
pub struct CatchPanic<S, T> {
    /// The wrapped service.
    inner: S,
    /// Produces the response returned in place of a panic.
    panic_handler: T,
}
148
149impl<S> CatchPanic<S, DefaultResponseForPanic> {
150 pub fn new(inner: S) -> Self {
152 Self {
153 inner,
154 panic_handler: DefaultResponseForPanic,
155 }
156 }
157}
158
159impl<S, T> CatchPanic<S, T> {
160 define_inner_service_accessors!();
161
162 pub fn custom(inner: S, panic_handler: T) -> Self
164 where
165 T: ResponseForPanic,
166 {
167 Self {
168 inner,
169 panic_handler,
170 }
171 }
172}
173
174impl<S, T, ReqBody, ResBody> Service<Request<ReqBody>> for CatchPanic<S, T>
175where
176 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
177 ResBody: Body<Data = Bytes> + Send + 'static,
178 ResBody::Error: Into<BoxError>,
179 T: ResponseForPanic + Clone,
180 T::ResponseBody: Body<Data = Bytes> + Send + 'static,
181 <T::ResponseBody as Body>::Error: Into<BoxError>,
182{
183 type Response = Response<UnsyncBoxBody<Bytes, BoxError>>;
184 type Error = S::Error;
185
186 async fn call(&self, req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
187 let future = match std::panic::catch_unwind(AssertUnwindSafe(|| self.inner.call(req))) {
188 Ok(future) => future,
189 Err(panic_err) => {
190 return Ok(self
191 .panic_handler
192 .response_for_panic(panic_err)
193 .map(|body| body.map_err(Into::into).boxed_unsync()))
194 }
195 };
196 match AssertUnwindSafe(future).catch_unwind().await {
197 Ok(res) => match res {
198 Ok(res) => Ok(res.map(|body| body.map_err(Into::into).boxed_unsync())),
199 Err(err) => Err(err),
200 },
201 Err(panic_err) => Ok(self
202 .panic_handler
203 .response_for_panic(panic_err)
204 .map(|body| body.map_err(Into::into).boxed_unsync())),
205 }
206 }
207}
208
209pub trait ResponseForPanic: Clone {
211 type ResponseBody;
213
214 fn response_for_panic(
216 &self,
217 err: Box<dyn Any + Send + 'static>,
218 ) -> Response<Self::ResponseBody>;
219}
220
221impl<F, B> ResponseForPanic for F
222where
223 F: Fn(Box<dyn Any + Send + 'static>) -> Response<B> + Clone,
224{
225 type ResponseBody = B;
226
227 fn response_for_panic(
228 &self,
229 err: Box<dyn Any + Send + 'static>,
230 ) -> Response<Self::ResponseBody> {
231 self(err)
232 }
233}
234
/// The [`ResponseForPanic`] used by [`CatchPanicLayer::new`] and [`CatchPanic::new`].
///
/// It logs the panic through `tracing` and responds with a plain-text
/// `500 Internal Server Error`.
#[derive(Clone, Copy, Debug, Default)]
#[non_exhaustive]
pub struct DefaultResponseForPanic;
242
243impl ResponseForPanic for DefaultResponseForPanic {
244 type ResponseBody = Full<Bytes>;
245
246 fn response_for_panic(
247 &self,
248 err: Box<dyn Any + Send + 'static>,
249 ) -> Response<Self::ResponseBody> {
250 if let Some(s) = err.downcast_ref::<String>() {
251 tracing::error!("Service panicked: {}", s);
252 } else if let Some(s) = err.downcast_ref::<&str>() {
253 tracing::error!("Service panicked: {}", s);
254 } else {
255 tracing::error!(
256 "Service panicked but `CatchPanic` was unable to downcast the panic info"
257 );
258 };
259
260 let mut res = Response::new(Full::from("Service panicked"));
261 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
262
263 #[allow(clippy::declare_interior_mutable_const)]
264 const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");
265 res.headers_mut()
266 .insert(http::header::CONTENT_TYPE, TEXT_PLAIN);
267
268 res
269 }
270}
271
#[cfg(test)]
mod tests {
    // Each service below panics before producing its `Ok(..)` value, which is
    // therefore intentionally unreachable.
    #![allow(unreachable_code)]

    use super::*;

    use crate::test_helpers::{self, Body};

    use hyper::Response;
    use std::convert::Infallible;
    use tower_async::{ServiceBuilder, ServiceExt};

    /// A panic in the service closure, before any future is built, becomes the default 500.
    #[tokio::test]
    async fn panic_before_returning_future() {
        let service = ServiceBuilder::new()
            .layer(CatchPanicLayer::new())
            .service_fn(|_: Request<Body>| {
                panic!("service panic");
                async { Ok::<_, Infallible>(Response::new(Body::empty())) }
            });

        let response = service.oneshot(Request::new(Body::empty())).await.unwrap();
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);

        let bytes = test_helpers::to_bytes(response).await.unwrap();
        assert_eq!(&bytes[..], b"Service panicked");
    }

    /// A panic inside the returned future also becomes the default 500.
    #[tokio::test]
    async fn panic_in_future() {
        let service = ServiceBuilder::new()
            .layer(CatchPanicLayer::new())
            .service_fn(|_: Request<Body>| async {
                panic!("future panic");
                Ok::<_, Infallible>(Response::new(Body::empty()))
            });

        let response = service.oneshot(Request::new(Body::empty())).await.unwrap();
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);

        let bytes = test_helpers::to_bytes(response).await.unwrap();
        assert_eq!(&bytes[..], b"Service panicked");
    }
}