via/error/
rescue.rs

1use std::borrow::Cow;
2use std::fmt::{self, Display, Formatter};
3
4use http::StatusCode;
5
6use crate::error::{Error, Errors};
7use crate::middleware::{BoxFuture, Middleware};
8use crate::response::{Response, ResponseBuilder};
9use crate::{Next, Pipe, Request};
10
/// Middleware that recovers from errors returned by downstream middleware by
/// converting them into a response.
///
/// Use [`rescue`] to construct one.
///
pub struct Rescue<F> {
    /// Callback given the [`Sanitize`] built for the intercepted error; the
    /// value it returns is what gets turned into the response.
    recover: F,
}
16
17/// Customize how an [`Error`] is converted to a response.
18///
19pub struct Sanitize<'a> {
20    json: bool,
21    error: &'a Error,
22    status: Option<StatusCode>,
23    message: Option<Cow<'a, str>>,
24}
25
26/// Recover from errors that occur in downstream middleware.
27///
28pub fn rescue<F>(recover: F) -> Rescue<F>
29where
30    F: Fn(Sanitize) -> Sanitize + Copy + Send + Sync + 'static,
31{
32    Rescue { recover }
33}
34
35impl<State, F> Middleware<State> for Rescue<F>
36where
37    State: Send + Sync + 'static,
38    F: Fn(Sanitize) -> Sanitize + Copy + Send + Sync + 'static,
39{
40    fn call(&self, request: Request<State>, next: Next<State>) -> BoxFuture {
41        let Self { recover } = *self;
42
43        Box::pin(async move {
44            next.call(request).await.or_else(|error| {
45                let response = Response::build();
46                let sanitize = Sanitize::new(&error);
47
48                recover(sanitize).pipe(response).or_else(|residual| {
49                    if cfg!(debug_assertions) {
50                        eprintln!("warn: a residual error occurred in rescue");
51                        eprintln!("{}", residual);
52                    }
53
54                    Ok(error.into())
55                })
56            })
57        })
58    }
59}
60
61impl<'a> Sanitize<'a> {
62    /// Generate a json response for the error.
63    ///
64    pub fn as_json(self) -> Self {
65        Self { json: true, ..self }
66    }
67
68    /// Sanitize the contained error based on the error source.
69    ///
70    pub fn map<F>(self, f: F) -> Self
71    where
72        F: FnOnce(Self, &(dyn std::error::Error + 'static)) -> Self,
73    {
74        if let Some(source) = self.error.source() {
75            f(self, source)
76        } else {
77            self
78        }
79    }
80
81    /// Use the canonical reason of the status code as the error message.
82    ///
83    pub fn with_canonical_reason(self) -> Self {
84        Self {
85            message: self.status_code().canonical_reason().map(Cow::Borrowed),
86            ..self
87        }
88    }
89
90    /// Provide a custom message to use for the response generated from this
91    /// error.
92    ///
93    pub fn with_message<T>(self, message: T) -> Self
94    where
95        Cow<'a, str>: From<T>,
96    {
97        Self {
98            message: Some(message.into()),
99            ..self
100        }
101    }
102
103    /// Overrides the HTTP status code of the error.
104    ///
105    pub fn with_status_code(self, status: StatusCode) -> Self {
106        Self {
107            status: Some(status),
108            ..self
109        }
110    }
111}
112
113impl<'a> Sanitize<'a> {
114    fn new(error: &'a Error) -> Self {
115        Self {
116            json: false,
117            error,
118            status: None,
119            message: None,
120        }
121    }
122
123    fn status_code(&self) -> StatusCode {
124        self.status.unwrap_or(self.error.status)
125    }
126}
127
128impl Display for Sanitize<'_> {
129    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
130        Display::fmt(self.error, f)
131    }
132}
133
134impl Pipe for Sanitize<'_> {
135    fn pipe(self, response: ResponseBuilder) -> Result<Response, Error> {
136        let status_code = self.status_code();
137        let response = response.status(status_code);
138
139        match self.message {
140            None if self.json => response.json(&self.error.repr_json(status_code)),
141            Some(message) if self.json => response.json(Errors::new(status_code).push(message)),
142
143            None => response.text(self.error.to_string()),
144            Some(message) => response.text(message),
145        }
146    }
147}