1use std::borrow::Cow;
38use std::fmt::{self, Debug, Formatter};
39use std::sync::{Arc, LazyLock};
40
41use async_trait::async_trait;
42use bytes::Bytes;
43use mime::Mime;
44use serde::Serialize;
45
46use crate::handler::{Handler, WhenHoop};
47use crate::http::mime::guess_accept_mime;
48use crate::http::{Request, ResBody, Response, StatusCode, StatusError, header};
49use crate::{Depot, FlowCtrl};
50
51static SUPPORTED_FORMATS: LazyLock<Vec<mime::Name>> =
52 LazyLock::new(|| vec![mime::JSON, mime::HTML, mime::XML, mime::PLAIN]);
53const SALVO_LINK: &str = r#"<a href="https://salvo.rs" target="_blank">salvo</a>"#;
54
55pub struct Catcher {
59 goal: Arc<dyn Handler>,
60 hoops: Vec<Arc<dyn Handler>>,
61}
62impl Default for Catcher {
63 fn default() -> Self {
65 Self {
66 goal: Arc::new(DefaultGoal::new()),
67 hoops: vec![],
68 }
69 }
70}
71impl Debug for Catcher {
72 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
73 f.debug_struct("Catcher").finish()
74 }
75}
76impl Catcher {
77 pub fn new<H: Handler>(goal: H) -> Self {
79 Self {
80 goal: Arc::new(goal),
81 hoops: vec![],
82 }
83 }
84
85 #[inline]
87 #[must_use]
88 pub fn hoops(&self) -> &Vec<Arc<dyn Handler>> {
89 &self.hoops
90 }
91 #[inline]
93 pub fn hoops_mut(&mut self) -> &mut Vec<Arc<dyn Handler>> {
94 &mut self.hoops
95 }
96
97 #[inline]
99 #[must_use]
100 pub fn hoop<H: Handler>(mut self, hoop: H) -> Self {
101 self.hoops.push(Arc::new(hoop));
102 self
103 }
104
105 #[inline]
109 #[must_use]
110 pub fn hoop_when<H, F>(mut self, hoop: H, filter: F) -> Self
111 where
112 H: Handler,
113 F: Fn(&Request, &Depot) -> bool + Send + Sync + 'static,
114 {
115 self.hoops.push(Arc::new(WhenHoop {
116 inner: hoop,
117 filter,
118 }));
119 self
120 }
121
122 pub async fn catch(&self, req: &mut Request, depot: &mut Depot, res: &mut Response) {
124 let mut ctrl = FlowCtrl::new(self.hoops.iter().chain([&self.goal]).cloned().collect());
125 ctrl.call_next(req, depot, res).await;
126 }
127}
128
/// Fallback goal handler that writes a generic error page.
///
/// The rendered HTML page ends with `footer` when one is set; otherwise the
/// built-in salvo link is used.
#[derive(Default, Debug)]
pub struct DefaultGoal {
    /// Custom HTML footer for the error page; `None` means "use the default".
    footer: Option<Cow<'static, str>>,
}

impl DefaultGoal {
    /// Creates a goal with no custom footer.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a goal that uses `footer` on the HTML error page.
    #[inline]
    #[must_use]
    pub fn with_footer(footer: impl Into<Cow<'static, str>>) -> Self {
        Self {
            footer: Some(footer.into()),
        }
    }

    /// Sets (or replaces) the HTML footer and returns the updated goal.
    #[must_use]
    pub fn footer(mut self, footer: impl Into<Cow<'static, str>>) -> Self {
        let footer: Cow<'static, str> = footer.into();
        self.footer = Some(footer);
        self
    }
}
162#[async_trait]
163impl Handler for DefaultGoal {
164 async fn handle(
165 &self,
166 req: &mut Request,
167 _depot: &mut Depot,
168 res: &mut Response,
169 _ctrl: &mut FlowCtrl,
170 ) {
171 let status = res.status_code.unwrap_or(StatusCode::NOT_FOUND);
172 if (status.is_server_error() || status.is_client_error())
173 && (res.body.is_none() || res.body.is_error())
174 {
175 write_error_default(req, res, self.footer.as_deref());
176 }
177 }
178}
179
180fn status_error_html(
181 code: StatusCode,
182 name: &str,
183 brief: &str,
184 detail: Option<&str>,
185 cause: Option<&str>,
186 footer: Option<&str>,
187) -> String {
188 format!(
189 r#"<!DOCTYPE html>
190<html>
191<head>
192 <meta charset="utf-8">
193 <meta name="viewport" content="width=device-width">
194 <title>{0}: {1}</title>
195 <style>
196 :root {{
197 --bg-color: #fff;
198 --text-color: #222;
199 }}
200 body {{
201 background: var(--bg-color);
202 color: var(--text-color);
203 text-align: center;
204 }}
205 pre {{ text-align: left; padding: 0 1rem; }}
206 footer{{text-align:center;}}
207 @media (prefers-color-scheme: dark) {{
208 :root {{
209 --bg-color: #222;
210 --text-color: #ddd;
211 }}
212 a:link {{ color: red; }}
213 a:visited {{ color: #a8aeff; }}
214 a:hover {{color: #a8aeff;}}
215 a:active {{color: #a8aeff;}}
216 }}
217 </style>
218</head>
219<body>
220 <div><h1>{}: {}</h1><h3>{}</h3>{}{}<hr><footer>{}</footer></div>
221</body>
222</html>"#,
223 code.as_u16(),
224 name,
225 brief,
226 detail
227 .map(|detail| format!("<pre>{detail}</pre>"))
228 .unwrap_or_default(),
229 cause
230 .map(|cause| format!("<pre>{cause:#?}</pre>"))
231 .unwrap_or_default(),
232 footer.unwrap_or(SALVO_LINK)
233 )
234}
235
236#[inline]
237fn status_error_json(
238 code: StatusCode,
239 name: &str,
240 brief: &str,
241 detail: Option<&str>,
242 cause: Option<&str>,
243) -> String {
244 #[derive(Serialize)]
245 struct Data<'a> {
246 error: Error<'a>,
247 }
248 #[derive(Serialize)]
249 struct Error<'a> {
250 code: u16,
251 name: &'a str,
252 brief: &'a str,
253 #[serde(skip_serializing_if = "Option::is_none")]
254 detail: Option<&'a str>,
255 #[serde(skip_serializing_if = "Option::is_none")]
256 cause: Option<&'a str>,
257 }
258 let data = Data {
259 error: Error {
260 code: code.as_u16(),
261 name,
262 brief,
263 detail,
264 cause,
265 },
266 };
267 serde_json::to_string(&data).unwrap_or_default()
268}
269
270fn status_error_plain(
271 code: StatusCode,
272 name: &str,
273 brief: &str,
274 detail: Option<&str>,
275 cause: Option<&str>,
276) -> String {
277 format!(
278 "code: {}\n\nname: {}\n\nbrief: {}{}{}",
279 code.as_u16(),
280 name,
281 brief,
282 detail
283 .map(|detail| format!("\n\ndetail: {detail}"))
284 .unwrap_or_default(),
285 cause
286 .map(|cause| format!("\n\ncause: {cause:#?}"))
287 .unwrap_or_default(),
288 )
289}
290
291fn status_error_xml(
292 code: StatusCode,
293 name: &str,
294 brief: &str,
295 detail: Option<&str>,
296 cause: Option<&str>,
297) -> String {
298 #[derive(Serialize)]
299 struct Data<'a> {
300 code: u16,
301 name: &'a str,
302 brief: &'a str,
303 #[serde(skip_serializing_if = "Option::is_none")]
304 detail: Option<&'a str>,
305 #[serde(skip_serializing_if = "Option::is_none")]
306 cause: Option<&'a str>,
307 }
308
309 let data = Data {
310 code: code.as_u16(),
311 name,
312 brief,
313 detail,
314 cause,
315 };
316 serde_xml_rs::to_string(&data).unwrap_or_default()
317}
318
319#[doc(hidden)]
321#[inline]
322pub fn status_error_bytes(
323 err: &StatusError,
324 prefer_format: &Mime,
325 footer: Option<&str>,
326) -> (Mime, Bytes) {
327 let format = if !SUPPORTED_FORMATS.contains(&prefer_format.subtype()) {
328 mime::TEXT_HTML
329 } else {
330 prefer_format.clone()
331 };
332 #[cfg(debug_assertions)]
333 let cause = err.cause.as_ref().map(|e| format!("{e:#?}"));
334 #[cfg(not(debug_assertions))]
335 let cause: Option<&str> = None;
336 #[cfg(debug_assertions)]
337 let detail = err.detail.as_deref();
338 #[cfg(not(debug_assertions))]
339 let detail: Option<&str> = None;
340 let content = match format.subtype().as_ref() {
341 "plain" => status_error_plain(err.code, &err.name, &err.brief, detail, cause.as_deref()),
342 "json" => status_error_json(err.code, &err.name, &err.brief, detail, cause.as_deref()),
343 "xml" => status_error_xml(err.code, &err.name, &err.brief, detail, cause.as_deref()),
344 _ => status_error_html(
345 err.code,
346 &err.name,
347 &err.brief,
348 detail,
349 cause.as_deref(),
350 footer,
351 ),
352 };
353 (format, Bytes::from(content))
354}
355
356#[doc(hidden)]
357pub fn write_error_default(req: &Request, res: &mut Response, footer: Option<&str>) {
358 let format = guess_accept_mime(req, None);
359 let (format, data) = if let ResBody::Error(body) = &res.body {
360 status_error_bytes(body, &format, footer)
361 } else {
362 let status = res.status_code.unwrap_or(StatusCode::NOT_FOUND);
363 status_error_bytes(
364 &StatusError::from_code(status).unwrap_or_else(StatusError::internal_server_error),
365 &format,
366 footer,
367 )
368 };
369 res.headers_mut().insert(
370 header::CONTENT_TYPE,
371 format.to_string().parse().expect("invalid `Content-Type`"),
372 );
373 let _ = res.write_body(data);
374}
375
#[cfg(test)]
mod tests {
    use crate::prelude::*;
    use crate::test::{ResponseExt, TestClient};

    use super::*;

    /// Sends `GET /{path}` to `service` and returns the response body as a string.
    async fn get_body(service: &Service, path: &str) -> String {
        TestClient::get(format!("http://127.0.0.1:8698/{path}"))
            .send(service)
            .await
            .take_string()
            .await
            .unwrap()
    }

    /// Writer that turns a handler failure into a 500 with a fixed text body.
    struct CustomError;

    #[async_trait]
    impl Writer for CustomError {
        async fn write(self, _req: &mut Request, _depot: &mut Depot, res: &mut Response) {
            res.status_code = Some(StatusCode::INTERNAL_SERVER_ERROR);
            res.render("custom error");
        }
    }

    /// Catcher hoop: renders a custom page when the status is unset or 404,
    /// then skips the remaining handlers.
    #[handler]
    async fn handle404(
        &self,
        _req: &Request,
        _depot: &Depot,
        res: &mut Response,
        ctrl: &mut FlowCtrl,
    ) {
        let is_not_found =
            res.status_code.is_none() || res.status_code == Some(StatusCode::NOT_FOUND);
        if is_not_found {
            res.render("Custom 404 Error Page");
            ctrl.skip_rest();
        }
    }

    #[tokio::test]
    async fn test_handle_error() {
        #[handler]
        async fn handle_custom() -> Result<(), CustomError> {
            Err(CustomError)
        }
        let service =
            Service::new(Router::new().push(Router::with_path("custom").get(handle_custom)));

        assert_eq!(get_body(&service, "custom").await, "custom error");
    }

    #[tokio::test]
    async fn test_custom_catcher() {
        #[handler]
        async fn hello() -> &'static str {
            "Hello World"
        }
        let catcher = Catcher::default().hoop(handle404);
        let service = Service::new(Router::new().get(hello)).catcher(catcher);

        assert_eq!(get_body(&service, "notfound").await, "Custom 404 Error Page");
    }
}