1use std::fmt;
11
12use crate::http::{status_text, Response};
13
/// Errors a request handler can return; each variant maps to exactly one HTTP status code.
///
/// Use [`Error::into_response`] to turn an error into a client-facing [`Response`].
/// The payload of [`Error::Internal`] is kept for diagnostics only and is replaced by a
/// generic body in the response.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// Maps to `404 Not Found`.
    NotFound,
    /// Maps to `405 Method Not Allowed`.
    MethodNotAllowed,
    /// Maps to `400 Bad Request`; the carried message is sent to the client as the body.
    BadRequest(String),
    /// Maps to `401 Unauthorized`.
    Unauthorized,
    /// Maps to `403 Forbidden`.
    Forbidden,
    /// Maps to `500 Internal Server Error`; the carried detail is NOT sent to the client.
    Internal(String),
}
24
25impl Error {
26 pub fn status(&self) -> u16 {
28 match self {
29 Error::NotFound => 404,
30 Error::MethodNotAllowed => 405,
31 Error::BadRequest(_) => 400,
32 Error::Unauthorized => 401,
33 Error::Forbidden => 403,
34 Error::Internal(_) => 500,
35 }
36 }
37
38 pub fn message(&self) -> &str {
44 match self {
45 Error::NotFound => "Not Found",
46 Error::MethodNotAllowed => "Method Not Allowed",
47 Error::BadRequest(msg) => msg,
48 Error::Unauthorized => "Unauthorized",
49 Error::Forbidden => "Forbidden",
50 Error::Internal(msg) => msg,
51 }
52 }
53
54 pub fn into_response(self) -> Response {
59 let status = self.status();
60 let body = match self {
61 Error::NotFound => String::from("Not Found"),
62 Error::MethodNotAllowed => String::from("Method Not Allowed"),
63 Error::BadRequest(msg) => msg,
64 Error::Unauthorized => String::from("Unauthorized"),
65 Error::Forbidden => String::from("Forbidden"),
66 Error::Internal(_) => String::from("Internal Server Error"),
67 };
68 status_text(status, body)
69 }
70}
71
72impl fmt::Display for Error {
73 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74 write!(f, "{} {}", self.status(), self.message())
75 }
76}
77
78impl std::error::Error for Error {}
79
80impl From<sqlx::Error> for Error {
81 fn from(value: sqlx::Error) -> Self {
82 Error::Internal(value.to_string())
83 }
84}
85
86pub fn resolve(result: Result<Response, Error>) -> Response {
92 result.unwrap_or_else(Error::into_response)
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use http_body_util::BodyExt;

    /// One value of every variant, for checks that must hold across the whole enum.
    fn all_variants() -> [Error; 6] {
        [
            Error::NotFound,
            Error::MethodNotAllowed,
            Error::BadRequest(String::from("bad")),
            Error::Unauthorized,
            Error::Forbidden,
            Error::Internal(String::from("oops")),
        ]
    }

    /// Renders `err` as a response and drains its body into a UTF-8 string.
    async fn body_of(err: Error) -> String {
        let bytes = err
            .into_response()
            .into_body()
            .collect()
            .await
            .unwrap()
            .to_bytes();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[test]
    fn status_codes_match_variant() {
        assert_eq!(Error::NotFound.status(), 404);
        assert_eq!(Error::MethodNotAllowed.status(), 405);
        assert_eq!(Error::BadRequest(String::from("bad")).status(), 400);
        assert_eq!(Error::Unauthorized.status(), 401);
        assert_eq!(Error::Forbidden.status(), 403);
        assert_eq!(Error::Internal(String::from("x")).status(), 500);
    }

    #[test]
    fn parameterless_variants_use_status_phrase_as_message() {
        assert_eq!(Error::NotFound.message(), "Not Found");
        assert_eq!(Error::MethodNotAllowed.message(), "Method Not Allowed");
        assert_eq!(Error::Unauthorized.message(), "Unauthorized");
        assert_eq!(Error::Forbidden.message(), "Forbidden");
    }

    #[test]
    fn parameterised_variants_carry_their_message() {
        assert_eq!(Error::BadRequest(String::from("nope")).message(), "nope");
        assert_eq!(Error::Internal(String::from("oops")).message(), "oops");
    }

    #[test]
    fn into_response_uses_variant_status() {
        // Every variant, so a newly added one cannot silently disagree with `status()`.
        for err in all_variants() {
            let label = format!("{err:?}");
            let expected = err.status();
            assert_eq!(err.into_response().status().as_u16(), expected, "{label}");
        }
    }

    #[test]
    fn display_shows_status_and_message() {
        assert_eq!(format!("{}", Error::NotFound), "404 Not Found");
        assert_eq!(format!("{}", Error::Forbidden), "403 Forbidden");
        assert_eq!(format!("{}", Error::BadRequest(String::from("bad"))), "400 bad");
        assert_eq!(
            format!("{}", Error::Internal(String::from("oops"))),
            "500 oops"
        );
    }

    #[test]
    fn resolve_passes_ok_through() {
        let resp = status_text(204, "");
        let resolved = resolve(Ok(resp));
        assert_eq!(resolved.status().as_u16(), 204);
    }

    #[test]
    fn resolve_converts_err_to_response() {
        let resolved = resolve(Err(Error::Unauthorized));
        assert_eq!(resolved.status().as_u16(), 401);
    }

    #[tokio::test]
    async fn internal_response_body_is_sanitized() {
        let body = body_of(Error::Internal(String::from("db password: hunter2"))).await;
        assert_eq!(body, "Internal Server Error");
        assert!(!body.contains("hunter2"));
    }

    #[tokio::test]
    async fn public_error_bodies_use_status_phrase_or_message() {
        assert_eq!(body_of(Error::NotFound).await, "Not Found");
        assert_eq!(body_of(Error::MethodNotAllowed).await, "Method Not Allowed");
        assert_eq!(body_of(Error::Unauthorized).await, "Unauthorized");
        assert_eq!(body_of(Error::Forbidden).await, "Forbidden");
        assert_eq!(body_of(Error::BadRequest(String::from("bad"))).await, "bad");
    }

    #[tokio::test]
    async fn public_error_bodies_match_message() {
        for err in all_variants() {
            // Internal detail is deliberately not echoed; covered by the sanitization test.
            if matches!(err, Error::Internal(_)) {
                continue;
            }
            let expected = err.message().to_owned();
            assert_eq!(body_of(err).await, expected);
        }
    }

    #[test]
    fn internal_display_and_message_retain_detail_for_logging() {
        let err = Error::Internal(String::from("leaked secret"));
        assert_eq!(err.message(), "leaked secret");
        assert!(format!("{err}").contains("leaked secret"));
    }
}