zino_http/response/
rejection.rs

1use self::RejectionKind::*;
2use super::Response;
3use crate::request::{Context, RequestContext};
4use std::sync::Arc;
5use zino_core::{SharedString, error::Error, trace::TraceContext, validation::Validation, warn};
6
/// A rejection response type.
///
/// A rejection records why a request failed (its kind, which fixes the HTTP status code)
/// together with optional request and trace contexts. It is turned into a `Response`
/// through the `From<Rejection>` implementations generated by `impl_from_rejection!`.
#[derive(Debug)]
pub struct Rejection {
    /// Rejection kind: selects the status code and whether the response carries
    /// validation data or an error message.
    kind: RejectionKind,
    /// Optional request context, set by [`Rejection::context`]; when present, its
    /// instance, start time and request id are copied into the response.
    context: Option<Arc<Context>>,
    /// Optional trace context, set by [`Rejection::context`] and forwarded to the response.
    trace_context: Option<TraceContext>,
}
17
/// Rejection kind.
///
/// Each variant corresponds to exactly one HTTP status code (see `Rejection::status_code`).
/// `BadRequest` carries validation details; every other variant carries an [`Error`].
///
/// NOTE(review): `#[non_exhaustive]` only restricts code in other crates, and this enum is
/// private, so the attribute currently has no effect.
#[derive(Debug)]
#[non_exhaustive]
enum RejectionKind {
    /// 400 Bad Request
    BadRequest(Validation),
    /// 401 Unauthorized
    Unauthorized(Error),
    /// 403 Forbidden
    Forbidden(Error),
    /// 404 Not Found
    NotFound(Error),
    /// 405 Method Not Allowed
    MethodNotAllowed(Error),
    /// 409 Conflict
    Conflict(Error),
    /// 500 Internal Server Error
    InternalServerError(Error),
    /// 503 Service Unavailable
    ServiceUnavailable(Error),
}
39
40impl Rejection {
41    /// Creates a `400 Bad Request` rejection.
42    #[inline]
43    pub fn bad_request(validation: Validation) -> Self {
44        Self {
45            kind: BadRequest(validation),
46            context: None,
47            trace_context: None,
48        }
49    }
50
51    /// Creates a `401 Unauthorized` rejection.
52    #[inline]
53    pub fn unauthorized(err: impl Into<Error>) -> Self {
54        Self {
55            kind: Unauthorized(err.into()),
56            context: None,
57            trace_context: None,
58        }
59    }
60
61    /// Creates a `403 Forbidden` rejection.
62    #[inline]
63    pub fn forbidden(err: impl Into<Error>) -> Self {
64        Self {
65            kind: Forbidden(err.into()),
66            context: None,
67            trace_context: None,
68        }
69    }
70
71    /// Creates a `404 Not Found` rejection.
72    #[inline]
73    pub fn not_found(err: impl Into<Error>) -> Self {
74        Self {
75            kind: NotFound(err.into()),
76            context: None,
77            trace_context: None,
78        }
79    }
80
81    /// Creates a `405 Method Not Allowed` rejection.
82    #[inline]
83    pub fn method_not_allowed(err: impl Into<Error>) -> Self {
84        Self {
85            kind: MethodNotAllowed(err.into()),
86            context: None,
87            trace_context: None,
88        }
89    }
90
91    /// Creates a `409 Conflict` rejection.
92    #[inline]
93    pub fn conflict(err: impl Into<Error>) -> Self {
94        Self {
95            kind: Conflict(err.into()),
96            context: None,
97            trace_context: None,
98        }
99    }
100
101    /// Creates a `500 Internal Server Error` rejection.
102    #[inline]
103    pub fn internal_server_error(err: impl Into<Error>) -> Self {
104        Self {
105            kind: InternalServerError(err.into()),
106            context: None,
107            trace_context: None,
108        }
109    }
110
111    /// Creates a `503 Service Unavailable` rejection.
112    #[inline]
113    pub fn service_unavailable(err: impl Into<Error>) -> Self {
114        Self {
115            kind: ServiceUnavailable(err.into()),
116            context: None,
117            trace_context: None,
118        }
119    }
120
121    /// Creates a new instance with the validation entry.
122    #[inline]
123    pub fn from_validation_entry(key: impl Into<SharedString>, err: impl Into<Error>) -> Self {
124        let validation = Validation::from_entry(key, err);
125        Self::bad_request(validation)
126    }
127
128    /// Creates a new instance from an error classified by the error message.
129    pub fn from_error(err: impl Into<Error>) -> Self {
130        fn inner(err: Error) -> Rejection {
131            let message = err.message();
132            if message.starts_with("401 Unauthorized") {
133                Rejection::unauthorized(err)
134            } else if message.starts_with("403 Forbidden") {
135                Rejection::forbidden(err)
136            } else if message.starts_with("404 Not Found") {
137                Rejection::not_found(err)
138            } else if message.starts_with("405 Method Not Allowed") {
139                Rejection::method_not_allowed(err)
140            } else if message.starts_with("409 Conflict") {
141                Rejection::conflict(err)
142            } else if message.starts_with("503 Service Unavailable") {
143                Rejection::service_unavailable(err)
144            } else {
145                Rejection::internal_server_error(err)
146            }
147        }
148        inner(err.into())
149    }
150
151    /// Creates a new instance with the error message.
152    #[inline]
153    pub fn with_message(message: impl Into<SharedString>) -> Self {
154        Self::from_error(Error::new(message))
155    }
156
157    /// Provides the request context for the rejection.
158    #[inline]
159    pub fn context<T: RequestContext + ?Sized>(mut self, ctx: &T) -> Self {
160        self.context = ctx.get_context();
161        self.trace_context = Some(ctx.new_trace_context());
162        self
163    }
164
165    /// Returns the status code as `u16`.
166    #[inline]
167    pub fn status_code(&self) -> u16 {
168        match &self.kind {
169            BadRequest(_) => 400,
170            Unauthorized(_) => 401,
171            Forbidden(_) => 403,
172            NotFound(_) => 404,
173            MethodNotAllowed(_) => 405,
174            Conflict(_) => 409,
175            InternalServerError(_) => 500,
176            ServiceUnavailable(_) => 503,
177        }
178    }
179}
180
/// Implements `From<Rejection>` for `Response<$Ty>`, where `$Ty` is a status code type
/// providing the associated constants used below (`BAD_REQUEST`, `NOT_FOUND`, ...).
macro_rules! impl_from_rejection {
    ($Ty:ty) => {
        impl From<Rejection> for Response<$Ty> {
            fn from(rejection: Rejection) -> Self {
                let Rejection {
                    kind,
                    context,
                    trace_context,
                } = rejection;

                // Shared builder for every kind that carries an `Error` rather than
                // validation data.
                let error_response = |status: $Ty, err: Error| -> Response<$Ty> {
                    let mut response = Response::new(status);
                    response.set_error_message(err);
                    response
                };

                let mut response = match kind {
                    BadRequest(validation) => {
                        let mut response = Response::new(<$Ty>::BAD_REQUEST);
                        response.set_validation_data(validation);
                        response
                    }
                    Unauthorized(err) => error_response(<$Ty>::UNAUTHORIZED, err),
                    Forbidden(err) => error_response(<$Ty>::FORBIDDEN, err),
                    NotFound(err) => error_response(<$Ty>::NOT_FOUND, err),
                    MethodNotAllowed(err) => error_response(<$Ty>::METHOD_NOT_ALLOWED, err),
                    Conflict(err) => error_response(<$Ty>::CONFLICT, err),
                    InternalServerError(err) => {
                        error_response(<$Ty>::INTERNAL_SERVER_ERROR, err)
                    }
                    ServiceUnavailable(err) => error_response(<$Ty>::SERVICE_UNAVAILABLE, err),
                };

                // Copy request metadata only when a request context was attached.
                if let Some(ctx) = context {
                    response.set_instance(ctx.instance().to_owned());
                    response.set_start_time(ctx.start_time());
                    response.set_request_id(ctx.request_id());
                }
                response.set_trace_context(trace_context);
                response
            }
        }
    };
}
238
// `From<Rejection>` for responses whose status code type is `http::StatusCode`.
impl_from_rejection!(http::StatusCode);

// The same conversion for `http02::StatusCode` (presumably the 0.2 line of the `http`
// crate), compiled only when the `http02` feature is enabled.
#[cfg(feature = "http02")]
impl_from_rejection!(http02::StatusCode);
243
244/// Trait for extracting rejections.
245pub trait ExtractRejection<T> {
246    /// Extracts a rejection with the request context.
247    fn extract<Ctx: RequestContext>(self, ctx: &Ctx) -> Result<T, Rejection>;
248}
249
250impl<T> ExtractRejection<T> for Option<T> {
251    #[inline]
252    fn extract<Ctx: RequestContext>(self, ctx: &Ctx) -> Result<T, Rejection> {
253        self.ok_or_else(|| Rejection::not_found(warn!("resource does not exist")).context(ctx))
254    }
255}
256
257impl<T, E: Into<Error>> ExtractRejection<T> for Result<T, E> {
258    #[inline]
259    fn extract<Ctx: RequestContext>(self, ctx: &Ctx) -> Result<T, Rejection> {
260        self.map_err(|err| Rejection::from_error(err).context(ctx))
261    }
262}
263
264impl<T, E: Into<Error>> ExtractRejection<T> for Result<Option<T>, E> {
265    #[inline]
266    fn extract<Ctx: RequestContext>(self, ctx: &Ctx) -> Result<T, Rejection> {
267        self.map_err(|err| Rejection::from_error(err).context(ctx))?
268            .ok_or_else(|| Rejection::not_found(warn!("resource does not exist")).context(ctx))
269    }
270}
271
/// Returns early with a [`Rejection`].
///
/// Every arm expands to `return Err(... .into())`, so the enclosing function must return a
/// `Result` whose error type implements `From<Rejection>`. The first argument is the
/// request context identifier; it is passed by reference to [`Rejection::context`].
///
/// The names `Rejection`, `Error` and `warn!` are not `$crate`-qualified, so they are
/// resolved at the invocation site and must be in scope there.
///
/// Arms are tried top to bottom, so a string-literal second argument selects the
/// validation-entry forms and an identifier selects a `Rejection` constructor:
///
/// - `reject!(ctx, validation)`: `400 Bad Request` from a `Validation` value.
/// - `reject!(ctx, "key", "message")`: `400 Bad Request` with one validation entry whose
///   error is `Error::new("message")`; `warn!` is also invoked with the key and message.
/// - `reject!(ctx, "key", err)`: `400 Bad Request` with one validation entry for `err`.
/// - `reject!(ctx, kind, "message")`: `Rejection::kind` applied to `warn!("message")`.
/// - `reject!(ctx, kind, err)`: `Rejection::kind(err)`.
/// - `reject!(ctx, kind, "fmt", args...)`: `Rejection::kind` applied to a formatted `warn!`.
#[macro_export]
macro_rules! reject {
    // 400 from a ready-made `Validation`.
    ($ctx:ident, $validation:expr $(,)?) => {{
        return Err(Rejection::bad_request($validation).context(&$ctx).into());
    }};
    // 400 with a single validation entry built from literal key and message.
    // NOTE(review): the value produced by `warn!` is discarded here; presumably it is
    // invoked only for its logging side effect — confirm.
    ($ctx:ident, $key:literal, $message:literal $(,)?) => {{
        let err = Error::new($message);
        warn!("invalid value for `{}`: {}", $key, $message);
        return Err(Rejection::from_validation_entry($key, err).context(&$ctx).into());
    }};
    // 400 with a single validation entry for an arbitrary error expression.
    ($ctx:ident, $key:literal, $err:expr $(,)?) => {{
        return Err(Rejection::from_validation_entry($key, $err).context(&$ctx).into());
    }};
    // `Rejection::$kind` (e.g. `not_found`) with an error produced by `warn!`.
    ($ctx:ident, $kind:ident, $message:literal $(,)?) => {{
        let err = warn!($message);
        return Err(Rejection::$kind(err).context(&$ctx).into());
    }};
    // `Rejection::$kind` with an arbitrary error expression.
    ($ctx:ident, $kind:ident, $err:expr $(,)?) => {{
        return Err(Rejection::$kind($err).context(&$ctx).into());
    }};
    // `Rejection::$kind` with an error produced by a formatted `warn!`.
    ($ctx:ident, $kind:ident, $fmt:expr, $($arg:tt)+) => {{
        let err = warn!($fmt, $($arg)+);
        return Err(Rejection::$kind(err).context(&$ctx).into());
    }};
}