Skip to main content

rustio_core/
error.rs

1use std::fmt;
2
3use crate::http::{Response, status_text};
4
/// Errors a request handler can return.
///
/// Every variant corresponds to exactly one HTTP status code (see [`Error::status`]),
/// and can be turned into a response with [`Error::into_response`].
#[derive(Debug)]
#[non_exhaustive]
pub enum Error {
    /// `404 Not Found`.
    NotFound,
    /// `405 Method Not Allowed`.
    MethodNotAllowed,
    /// `400 Bad Request`; the string is used as the error message.
    BadRequest(String),
    /// `401 Unauthorized`.
    Unauthorized,
    /// `403 Forbidden`.
    Forbidden,
    /// `500 Internal Server Error`; the string is used as the error message.
    ///
    /// NOTE(review): this text is passed unchanged to `status_text` when converted into a
    /// response, and `From<sqlx::Error>` fills it with database error text — confirm that
    /// exposing it to clients is intended.
    Internal(String),
}
15
16impl Error {
17    pub fn status(&self) -> u16 {
18        match self {
19            Error::NotFound => 404,
20            Error::MethodNotAllowed => 405,
21            Error::BadRequest(_) => 400,
22            Error::Unauthorized => 401,
23            Error::Forbidden => 403,
24            Error::Internal(_) => 500,
25        }
26    }
27
28    pub fn message(&self) -> &str {
29        match self {
30            Error::NotFound => "Not Found",
31            Error::MethodNotAllowed => "Method Not Allowed",
32            Error::BadRequest(msg) => msg,
33            Error::Unauthorized => "Unauthorized",
34            Error::Forbidden => "Forbidden",
35            Error::Internal(msg) => msg,
36        }
37    }
38
39    pub fn into_response(self) -> Response {
40        let status = self.status();
41        let message = match self {
42            Error::NotFound => String::from("Not Found"),
43            Error::MethodNotAllowed => String::from("Method Not Allowed"),
44            Error::BadRequest(msg) => msg,
45            Error::Unauthorized => String::from("Unauthorized"),
46            Error::Forbidden => String::from("Forbidden"),
47            Error::Internal(msg) => msg,
48        };
49        status_text(status, message)
50    }
51}
52
53impl fmt::Display for Error {
54    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55        write!(f, "{} {}", self.status(), self.message())
56    }
57}
58
59impl std::error::Error for Error {}
60
61impl From<sqlx::Error> for Error {
62    fn from(value: sqlx::Error) -> Self {
63        Error::Internal(value.to_string())
64    }
65}
66
67pub fn resolve(result: Result<Response, Error>) -> Response {
68    result.unwrap_or_else(Error::into_response)
69}
70
#[cfg(test)]
mod tests {
    use super::*;

    /// One instance of every variant, paired with the status code it must map to.
    /// Keeping this table in one place means every status-related test covers all variants.
    fn every_variant() -> [(Error, u16); 6] {
        [
            (Error::NotFound, 404),
            (Error::MethodNotAllowed, 405),
            (Error::BadRequest(String::from("bad")), 400),
            (Error::Unauthorized, 401),
            (Error::Forbidden, 403),
            (Error::Internal(String::from("x")), 500),
        ]
    }

    #[test]
    fn status_codes_match_variant() {
        for (error, expected) in every_variant() {
            assert_eq!(error.status(), expected, "status of {:?}", error);
        }
    }

    #[test]
    fn parameterless_variants_use_status_phrase_as_message() {
        assert_eq!(Error::NotFound.message(), "Not Found");
        assert_eq!(Error::MethodNotAllowed.message(), "Method Not Allowed");
        assert_eq!(Error::Unauthorized.message(), "Unauthorized");
        assert_eq!(Error::Forbidden.message(), "Forbidden");
    }

    #[test]
    fn parameterised_variants_carry_their_message() {
        assert_eq!(Error::BadRequest(String::from("nope")).message(), "nope");
        assert_eq!(Error::Internal(String::from("oops")).message(), "oops");
    }

    #[test]
    fn into_response_uses_variant_status() {
        // Covers all variants, not just a sample, so a mapping mistake in any arm is caught.
        for (error, expected) in every_variant() {
            assert_eq!(error.into_response().status().as_u16(), expected);
        }
    }

    #[test]
    fn display_shows_status_and_message() {
        assert_eq!(Error::NotFound.to_string(), "404 Not Found");
        assert_eq!(Error::MethodNotAllowed.to_string(), "405 Method Not Allowed");
        assert_eq!(Error::Unauthorized.to_string(), "401 Unauthorized");
        assert_eq!(Error::Forbidden.to_string(), "403 Forbidden");
        assert_eq!(Error::BadRequest(String::from("nope")).to_string(), "400 nope");
        assert_eq!(Error::Internal(String::from("oops")).to_string(), "500 oops");
    }

    #[test]
    fn display_agrees_with_status_and_message() {
        for (error, code) in every_variant() {
            assert_eq!(error.to_string(), format!("{} {}", code, error.message()));
        }
    }

    #[test]
    fn resolve_passes_ok_through() {
        let resp = status_text(204, "");
        let resolved = resolve(Ok(resp));
        assert_eq!(resolved.status().as_u16(), 204);
    }

    #[test]
    fn resolve_converts_err_to_response() {
        let resolved = resolve(Err(Error::Unauthorized));
        assert_eq!(resolved.status().as_u16(), 401);

        let resolved = resolve(Err(Error::Internal(String::from("db down"))));
        assert_eq!(resolved.status().as_u16(), 500);
    }
}