salvo_core/
catcher.rs

1//! Catch and handle errors.
2//!
3//! If the status code of [`Response`] is an error, and the body of [`Response`] is empty, then salvo
4//! will try to use `Catcher` to catch the error and display a friendly error page.
5//!
//! You can create the default [`Catcher`] with [`Catcher::default()`] and then add it to
//! [`Service`](crate::Service):
8//!
9//! # Example
10//!
11//! ```
12//! use salvo_core::prelude::*;
13//! use salvo_core::catcher::Catcher;
14//!
15//! #[handler]
16//! async fn handle404(&self, res: &mut Response, ctrl: &mut FlowCtrl) {
17//!     if let Some(StatusCode::NOT_FOUND) = res.status_code {
18//!         res.render("Custom 404 Error Page");
19//!         ctrl.skip_rest();
20//!     }
21//! }
22//!
23//! #[tokio::main]
24//! async fn main() {
25//!     Service::new(Router::new()).catcher(Catcher::default().hoop(handle404));
26//! }
27//! ```
28//!
29//! The default [`Catcher`] supports sending error pages in `XML`, `JSON`, `HTML`, `Text` formats.
30//!
//! You can add a custom error handler to [`Catcher`] by adding a `hoop` to the default `Catcher`.
//! The error handler is an ordinary [`Handler`].
33//!
//! You can add multiple custom error-catching handlers to [`Catcher`] through [`Catcher::hoop`]. A custom error
//! handler can call the [`FlowCtrl::skip_rest()`] method to skip the remaining error handlers and return early.
36
37use std::borrow::Cow;
38use std::fmt::{self, Debug, Formatter};
39use std::sync::{Arc, LazyLock};
40
41use async_trait::async_trait;
42use bytes::Bytes;
43use mime::Mime;
44use serde::Serialize;
45
46use crate::handler::{Handler, WhenHoop};
47use crate::http::mime::guess_accept_mime;
48use crate::http::{Request, ResBody, Response, StatusCode, StatusError, header};
49use crate::{Depot, FlowCtrl};
50
51static SUPPORTED_FORMATS: LazyLock<Vec<mime::Name>> =
52    LazyLock::new(|| vec![mime::JSON, mime::HTML, mime::XML, mime::PLAIN]);
53const SALVO_LINK: &str = r#"<a href="https://salvo.rs" target="_blank">salvo</a>"#;
54
55/// `Catcher` is used to catch errors.
56///
57/// View [module level documentation](index.html) for more details.
58pub struct Catcher {
59    goal: Arc<dyn Handler>,
60    hoops: Vec<Arc<dyn Handler>>,
61}
62impl Default for Catcher {
63    /// Create new `Catcher` with its goal handler is [`DefaultGoal`].
64    fn default() -> Self {
65        Self {
66            goal: Arc::new(DefaultGoal::new()),
67            hoops: vec![],
68        }
69    }
70}
71impl Debug for Catcher {
72    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
73        f.debug_struct("Catcher").finish()
74    }
75}
76impl Catcher {
77    /// Create new `Catcher`.
78    pub fn new<H: Handler>(goal: H) -> Self {
79        Self {
80            goal: Arc::new(goal),
81            hoops: vec![],
82        }
83    }
84
85    /// Get current catcher's middlewares reference.
86    #[inline]
87    #[must_use]
88    pub fn hoops(&self) -> &Vec<Arc<dyn Handler>> {
89        &self.hoops
90    }
91    /// Get current catcher's middlewares mutable reference.
92    #[inline]
93    pub fn hoops_mut(&mut self) -> &mut Vec<Arc<dyn Handler>> {
94        &mut self.hoops
95    }
96
97    /// Add a handler as middleware, it will run the handler when error caught.
98    #[inline]
99    #[must_use]
100    pub fn hoop<H: Handler>(mut self, hoop: H) -> Self {
101        self.hoops.push(Arc::new(hoop));
102        self
103    }
104
105    /// Add a handler as middleware, it will run the handler when error caught.
106    ///
107    /// This middleware is only effective when the filter returns true..
108    #[inline]
109    #[must_use]
110    pub fn hoop_when<H, F>(mut self, hoop: H, filter: F) -> Self
111    where
112        H: Handler,
113        F: Fn(&Request, &Depot) -> bool + Send + Sync + 'static,
114    {
115        self.hoops.push(Arc::new(WhenHoop {
116            inner: hoop,
117            filter,
118        }));
119        self
120    }
121
122    /// Catch error and send error page.
123    pub async fn catch(&self, req: &mut Request, depot: &mut Depot, res: &mut Response) {
124        let mut ctrl = FlowCtrl::new(self.hoops.iter().chain([&self.goal]).cloned().collect());
125        ctrl.call_next(req, depot, res).await;
126    }
127}
128
/// Default [`Handler`] used as goal for [`Catcher`].
///
/// If the HTTP status is an error and none of the custom handlers has caught it and
/// written a body, `DefaultGoal` is used to catch it.
///
/// `DefaultGoal` supports sending error pages in `XML`, `JSON`, `HTML`, `Text` formats.
#[derive(Default, Debug)]
pub struct DefaultGoal {
    /// Custom footer for the HTML error page; `None` means the built-in salvo link is used.
    footer: Option<Cow<'static, str>>,
}
139impl DefaultGoal {
140    /// Create new `DefaultGoal`.
141    #[must_use]
142    pub fn new() -> Self {
143        Self { footer: None }
144    }
145    /// Create new `DefaultGoal` with custom footer.
146    #[inline]
147    #[must_use]
148    pub fn with_footer(footer: impl Into<Cow<'static, str>>) -> Self {
149        Self::new().footer(footer)
150    }
151
152    /// Set custom footer which is only used in html error page.
153    ///
154    /// If footer is `None`, then use default footer.
155    /// Default footer is `<a href="https://salvo.rs" target="_blank">salvo</a>`.
156    #[must_use]
157    pub fn footer(mut self, footer: impl Into<Cow<'static, str>>) -> Self {
158        self.footer = Some(footer.into());
159        self
160    }
161}
162#[async_trait]
163impl Handler for DefaultGoal {
164    async fn handle(
165        &self,
166        req: &mut Request,
167        _depot: &mut Depot,
168        res: &mut Response,
169        _ctrl: &mut FlowCtrl,
170    ) {
171        let status = res.status_code.unwrap_or(StatusCode::NOT_FOUND);
172        if (status.is_server_error() || status.is_client_error())
173            && (res.body.is_none() || res.body.is_error())
174        {
175            write_error_default(req, res, self.footer.as_deref());
176        }
177    }
178}
179
180fn status_error_html(
181    code: StatusCode,
182    name: &str,
183    brief: &str,
184    detail: Option<&str>,
185    cause: Option<&str>,
186    footer: Option<&str>,
187) -> String {
188    format!(
189        r#"<!DOCTYPE html>
190<html>
191<head>
192    <meta charset="utf-8">
193    <meta name="viewport" content="width=device-width">
194    <title>{0}: {1}</title>
195    <style>
196    :root {{
197        --bg-color: #fff;
198        --text-color: #222;
199    }}
200    body {{
201        background: var(--bg-color);
202        color: var(--text-color);
203        text-align: center;
204    }}
205    pre {{ text-align: left; padding: 0 1rem; }}
206    footer{{text-align:center;}}
207    @media (prefers-color-scheme: dark) {{
208        :root {{
209            --bg-color: #222;
210            --text-color: #ddd;
211        }}
212        a:link {{ color: red; }}
213        a:visited {{ color: #a8aeff; }}
214        a:hover {{color: #a8aeff;}}
215        a:active {{color: #a8aeff;}}
216    }}
217    </style>
218</head>
219<body>
220    <div><h1>{}: {}</h1><h3>{}</h3>{}{}<hr><footer>{}</footer></div>
221</body>
222</html>"#,
223        code.as_u16(),
224        name,
225        brief,
226        detail
227            .map(|detail| format!("<pre>{detail}</pre>"))
228            .unwrap_or_default(),
229        cause
230            .map(|cause| format!("<pre>{cause:#?}</pre>"))
231            .unwrap_or_default(),
232        footer.unwrap_or(SALVO_LINK)
233    )
234}
235
236#[inline]
237fn status_error_json(
238    code: StatusCode,
239    name: &str,
240    brief: &str,
241    detail: Option<&str>,
242    cause: Option<&str>,
243) -> String {
244    #[derive(Serialize)]
245    struct Data<'a> {
246        error: Error<'a>,
247    }
248    #[derive(Serialize)]
249    struct Error<'a> {
250        code: u16,
251        name: &'a str,
252        brief: &'a str,
253        #[serde(skip_serializing_if = "Option::is_none")]
254        detail: Option<&'a str>,
255        #[serde(skip_serializing_if = "Option::is_none")]
256        cause: Option<&'a str>,
257    }
258    let data = Data {
259        error: Error {
260            code: code.as_u16(),
261            name,
262            brief,
263            detail,
264            cause,
265        },
266    };
267    serde_json::to_string(&data).unwrap_or_default()
268}
269
270fn status_error_plain(
271    code: StatusCode,
272    name: &str,
273    brief: &str,
274    detail: Option<&str>,
275    cause: Option<&str>,
276) -> String {
277    format!(
278        "code: {}\n\nname: {}\n\nbrief: {}{}{}",
279        code.as_u16(),
280        name,
281        brief,
282        detail
283            .map(|detail| format!("\n\ndetail: {detail}"))
284            .unwrap_or_default(),
285        cause
286            .map(|cause| format!("\n\ncause: {cause:#?}"))
287            .unwrap_or_default(),
288    )
289}
290
291fn status_error_xml(
292    code: StatusCode,
293    name: &str,
294    brief: &str,
295    detail: Option<&str>,
296    cause: Option<&str>,
297) -> String {
298    #[derive(Serialize)]
299    struct Data<'a> {
300        code: u16,
301        name: &'a str,
302        brief: &'a str,
303        #[serde(skip_serializing_if = "Option::is_none")]
304        detail: Option<&'a str>,
305        #[serde(skip_serializing_if = "Option::is_none")]
306        cause: Option<&'a str>,
307    }
308
309    let data = Data {
310        code: code.as_u16(),
311        name,
312        brief,
313        detail,
314        cause,
315    };
316    serde_xml_rs::to_string(&data).unwrap_or_default()
317}
318
319/// Create bytes from `StatusError`.
320#[doc(hidden)]
321#[inline]
322pub fn status_error_bytes(
323    err: &StatusError,
324    prefer_format: &Mime,
325    footer: Option<&str>,
326) -> (Mime, Bytes) {
327    let format = if !SUPPORTED_FORMATS.contains(&prefer_format.subtype()) {
328        mime::TEXT_HTML
329    } else {
330        prefer_format.clone()
331    };
332    #[cfg(debug_assertions)]
333    let cause = err.cause.as_ref().map(|e| format!("{e:#?}"));
334    #[cfg(not(debug_assertions))]
335    let cause: Option<&str> = None;
336    #[cfg(debug_assertions)]
337    let detail = err.detail.as_deref();
338    #[cfg(not(debug_assertions))]
339    let detail: Option<&str> = None;
340    let content = match format.subtype().as_ref() {
341        "plain" => status_error_plain(err.code, &err.name, &err.brief, detail, cause.as_deref()),
342        "json" => status_error_json(err.code, &err.name, &err.brief, detail, cause.as_deref()),
343        "xml" => status_error_xml(err.code, &err.name, &err.brief, detail, cause.as_deref()),
344        _ => status_error_html(
345            err.code,
346            &err.name,
347            &err.brief,
348            detail,
349            cause.as_deref(),
350            footer,
351        ),
352    };
353    (format, Bytes::from(content))
354}
355
356#[doc(hidden)]
357pub fn write_error_default(req: &Request, res: &mut Response, footer: Option<&str>) {
358    let format = guess_accept_mime(req, None);
359    let (format, data) = if let ResBody::Error(body) = &res.body {
360        status_error_bytes(body, &format, footer)
361    } else {
362        let status = res.status_code.unwrap_or(StatusCode::NOT_FOUND);
363        status_error_bytes(
364            &StatusError::from_code(status).unwrap_or_else(StatusError::internal_server_error),
365            &format,
366            footer,
367        )
368    };
369    res.headers_mut().insert(
370        header::CONTENT_TYPE,
371        format.to_string().parse().expect("invalid `Content-Type`"),
372    );
373    let _ = res.write_body(data);
374}
375
#[cfg(test)]
mod tests {
    use crate::prelude::*;
    use crate::test::{ResponseExt, TestClient};

    use super::*;

    /// Writer that always answers `500 Internal Server Error` with the body "custom error".
    struct CustomError;
    #[async_trait]
    impl Writer for CustomError {
        async fn write(self, _req: &mut Request, _depot: &mut Depot, res: &mut Response) {
            res.status_code = Some(StatusCode::INTERNAL_SERVER_ERROR);
            res.render("custom error");
        }
    }

    /// Catcher hoop that replaces a not-found (or status-less) response with a custom page.
    #[handler]
    async fn handle404(
        &self,
        _req: &Request,
        _depot: &Depot,
        res: &mut Response,
        ctrl: &mut FlowCtrl,
    ) {
        if matches!(res.status_code, None | Some(StatusCode::NOT_FOUND)) {
            res.render("Custom 404 Error Page");
            ctrl.skip_rest();
        }
    }

    /// Send `GET /{name}` through `service` and return the response body as a string.
    async fn access(service: &Service, name: &str) -> String {
        TestClient::get(format!("http://127.0.0.1:8698/{name}"))
            .send(service)
            .await
            .take_string()
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_handle_error() {
        #[handler]
        async fn handle_custom() -> Result<(), CustomError> {
            Err(CustomError)
        }
        let router = Router::new().push(Router::with_path("custom").get(handle_custom));
        let service = Service::new(router);

        assert_eq!(access(&service, "custom").await, "custom error");
    }

    #[tokio::test]
    async fn test_custom_catcher() {
        #[handler]
        async fn hello() -> &'static str {
            "Hello World"
        }
        let service =
            Service::new(Router::new().get(hello)).catcher(Catcher::default().hoop(handle404));

        assert_eq!(access(&service, "notfound").await, "Custom 404 Error Page");
    }
}