1use std::fmt;
2
3use crate::http::{Response, status_text};
4
/// Failure modes a request handler can report.
///
/// Each variant corresponds to exactly one HTTP status code (see
/// [`Error::status`]). `BadRequest` and `Internal` own a caller-supplied
/// message; every other variant falls back to a fixed phrase (see
/// [`Error::message`]).
#[derive(Debug)]
#[non_exhaustive]
pub enum Error {
    /// Reported with status 404.
    NotFound,
    /// Reported with status 405.
    MethodNotAllowed,
    /// Reported with status 400; the payload is the message to report.
    BadRequest(String),
    /// Reported with status 401.
    Unauthorized,
    /// Reported with status 403.
    Forbidden,
    /// Reported with status 500; the payload is the message to report.
    Internal(String),
}
15
16impl Error {
17 pub fn status(&self) -> u16 {
18 match self {
19 Error::NotFound => 404,
20 Error::MethodNotAllowed => 405,
21 Error::BadRequest(_) => 400,
22 Error::Unauthorized => 401,
23 Error::Forbidden => 403,
24 Error::Internal(_) => 500,
25 }
26 }
27
28 pub fn message(&self) -> &str {
29 match self {
30 Error::NotFound => "Not Found",
31 Error::MethodNotAllowed => "Method Not Allowed",
32 Error::BadRequest(msg) => msg,
33 Error::Unauthorized => "Unauthorized",
34 Error::Forbidden => "Forbidden",
35 Error::Internal(msg) => msg,
36 }
37 }
38
39 pub fn into_response(self) -> Response {
40 let status = self.status();
41 let message = match self {
42 Error::NotFound => String::from("Not Found"),
43 Error::MethodNotAllowed => String::from("Method Not Allowed"),
44 Error::BadRequest(msg) => msg,
45 Error::Unauthorized => String::from("Unauthorized"),
46 Error::Forbidden => String::from("Forbidden"),
47 Error::Internal(msg) => msg,
48 };
49 status_text(status, message)
50 }
51}
52
53impl fmt::Display for Error {
54 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55 write!(f, "{} {}", self.status(), self.message())
56 }
57}
58
59impl std::error::Error for Error {}
60
61impl From<sqlx::Error> for Error {
62 fn from(value: sqlx::Error) -> Self {
63 Error::Internal(value.to_string())
64 }
65}
66
67pub fn resolve(result: Result<Response, Error>) -> Response {
68 result.unwrap_or_else(Error::into_response)
69}
70
#[cfg(test)]
mod tests {
    use super::*;

    // Each test walks a small table of (input, expected) pairs; the failing
    // input is printed alongside the mismatch.

    #[test]
    fn status_codes_match_variant() {
        let cases = vec![
            (Error::NotFound, 404),
            (Error::MethodNotAllowed, 405),
            (Error::BadRequest(String::from("bad")), 400),
            (Error::Unauthorized, 401),
            (Error::Forbidden, 403),
            (Error::Internal(String::from("x")), 500),
        ];
        for (error, expected) in &cases {
            assert_eq!(error.status(), *expected, "status of {:?}", error);
        }
    }

    #[test]
    fn parameterless_variants_use_status_phrase_as_message() {
        let cases = vec![
            (Error::NotFound, "Not Found"),
            (Error::MethodNotAllowed, "Method Not Allowed"),
            (Error::Unauthorized, "Unauthorized"),
            (Error::Forbidden, "Forbidden"),
        ];
        for (error, phrase) in &cases {
            assert_eq!(error.message(), *phrase, "message of {:?}", error);
        }
    }

    #[test]
    fn parameterised_variants_carry_their_message() {
        let bad_request = Error::BadRequest(String::from("nope"));
        assert_eq!(bad_request.message(), "nope");

        let internal = Error::Internal(String::from("oops"));
        assert_eq!(internal.message(), "oops");
    }

    #[test]
    fn into_response_uses_variant_status() {
        let cases = vec![
            (Error::NotFound, 404),
            (Error::Forbidden, 403),
            (Error::BadRequest(String::from("x")), 400),
            (Error::Internal(String::from("x")), 500),
        ];
        // Iterate by value: `into_response` consumes the error.
        for (error, expected) in cases {
            let response = error.into_response();
            assert_eq!(response.status().as_u16(), expected);
        }
    }

    #[test]
    fn display_shows_status_and_message() {
        let cases = vec![
            (Error::NotFound, "404 Not Found"),
            (Error::Forbidden, "403 Forbidden"),
            (Error::Internal(String::from("oops")), "500 oops"),
        ];
        for (error, rendered) in &cases {
            assert_eq!(error.to_string(), *rendered);
        }
    }

    #[test]
    fn resolve_passes_ok_through() {
        let outcome = Ok(status_text(204, ""));
        assert_eq!(resolve(outcome).status().as_u16(), 204);
    }

    #[test]
    fn resolve_converts_err_to_response() {
        let outcome = Err(Error::Unauthorized);
        assert_eq!(resolve(outcome).status().as_u16(), 401);
    }
}