salvo_extra/
catch_panic.rs

//! Middleware for catching panics in handlers.
//!
//! This middleware catches panics raised by handlers and writes a `500 Internal Server Error`
//! into the response. It should be registered as the first middleware.
//!
//! # Example
//!
//! ```no_run
//! use salvo_core::prelude::*;
//! use salvo_extra::catch_panic::CatchPanic;
//!
//! #[handler]
//! async fn hello() {
//!     panic!("panic error!");
//! }
//!
//! #[tokio::main]
//! async fn main() {
//!     let router = Router::new().hoop(CatchPanic::new()).get(hello);
//!     let acceptor = TcpListener::new("0.0.0.0:5800").bind().await;
//!     Server::new(acceptor).serve(router).await;
//! }
//! ```
24
25use std::panic::AssertUnwindSafe;
26
27use futures_util::FutureExt;
28
29use salvo_core::http::{Request, Response, StatusError};
30use salvo_core::{async_trait, Depot, FlowCtrl, Error, Handler};
31
32
/// Middleware that intercepts panics raised further down the handler chain and answers
/// them with an HTTP 500 response.
///
/// Register it as the first middleware (hoop) of the router so that every handler and
/// middleware after it executes inside its panic guard.
///
/// See the [module level documentation](index.html) for a usage example.
#[derive(Default, Debug)]
pub struct CatchPanic {}

impl CatchPanic {
    /// Builds a `CatchPanic` middleware; equivalent to `CatchPanic::default()`.
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }
}
48
49#[async_trait]
50impl Handler for CatchPanic {
51    async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) {
52        if let Err(e) = AssertUnwindSafe(ctrl.call_next(req, depot, res)).catch_unwind().await {
53            tracing::error!(error = ?e, "panic occurred");
54            res.render(
55                StatusError::internal_server_error()
56                    .brief("Panic occurred on server.")
57                    .cause(Error::other(format!("{e:#?}"))),
58            );
59        }
60    }
61}
62
#[cfg(test)]
mod tests {
    use salvo_core::http::StatusCode;
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};
    use tracing_test::traced_test;

    use super::*;

    #[tokio::test]
    #[traced_test]
    async fn test_catch_panic() {
        #[handler]
        async fn hello() -> &'static str {
            panic!("panic error!");
        }

        let router = Router::new()
            .hoop(CatchPanic::new())
            .push(Router::with_path("hello").get(hello));

        let mut res = TestClient::get("http://127.0.0.1:5801/hello")
            .send(router)
            .await;
        // The panic must surface to the client as a 500 response, not just as a log line.
        assert_eq!(res.status_code, Some(StatusCode::INTERNAL_SERVER_ERROR));
        res.take_string().await.unwrap();
        assert!(logs_contain("panic occurred"));
    }
}