1use std::fmt;
11
12use crate::http::{status_text, Response};
13
/// Errors a request handler can return; each variant maps to exactly one HTTP status.
///
/// `#[non_exhaustive]` lets new variants be added later without breaking
/// exhaustive `match`es in downstream crates.
#[derive(Debug)]
#[non_exhaustive]
pub enum Error {
    /// Responds with 404 and the body `Not Found`.
    NotFound,
    /// Responds with 405 and the body `Method Not Allowed`.
    MethodNotAllowed,
    /// Responds with 400. The carried message is sent to the client verbatim
    /// as the response body, so it must be safe for clients to see.
    BadRequest(String),
    /// Responds with 401 and the body `Unauthorized`.
    Unauthorized,
    /// Responds with 403 and the body `Forbidden`.
    Forbidden,
    /// Responds with 413 and the body `Payload Too Large`.
    PayloadTooLarge,
    /// Responds with 429 and the body `Too Many Requests`.
    TooManyRequests,
    /// Responds with 500. The carried detail is kept for `Display` / `message()`
    /// (logging), but the response body is replaced by `Internal Server Error`.
    Internal(String),
}
31
32impl Error {
33 pub fn status(&self) -> u16 {
35 match self {
36 Error::NotFound => 404,
37 Error::MethodNotAllowed => 405,
38 Error::BadRequest(_) => 400,
39 Error::Unauthorized => 401,
40 Error::Forbidden => 403,
41 Error::PayloadTooLarge => 413,
42 Error::TooManyRequests => 429,
43 Error::Internal(_) => 500,
44 }
45 }
46
47 pub fn message(&self) -> &str {
53 match self {
54 Error::NotFound => "Not Found",
55 Error::MethodNotAllowed => "Method Not Allowed",
56 Error::BadRequest(msg) => msg,
57 Error::Unauthorized => "Unauthorized",
58 Error::Forbidden => "Forbidden",
59 Error::PayloadTooLarge => "Payload Too Large",
60 Error::TooManyRequests => "Too Many Requests",
61 Error::Internal(msg) => msg,
62 }
63 }
64
65 pub fn into_response(self) -> Response {
70 let status = self.status();
71 let body = match self {
72 Error::NotFound => String::from("Not Found"),
73 Error::MethodNotAllowed => String::from("Method Not Allowed"),
74 Error::BadRequest(msg) => msg,
75 Error::Unauthorized => String::from("Unauthorized"),
76 Error::Forbidden => String::from("Forbidden"),
77 Error::PayloadTooLarge => String::from("Payload Too Large"),
78 Error::TooManyRequests => String::from("Too Many Requests"),
79 Error::Internal(_) => String::from("Internal Server Error"),
80 };
81 status_text(status, body)
82 }
83}
84
85impl fmt::Display for Error {
86 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87 write!(f, "{} {}", self.status(), self.message())
88 }
89}
90
91impl std::error::Error for Error {}
92
93impl From<sqlx::Error> for Error {
94 fn from(value: sqlx::Error) -> Self {
95 Error::Internal(value.to_string())
96 }
97}
98
99pub fn resolve(result: Result<Response, Error>) -> Response {
105 result.unwrap_or_else(Error::into_response)
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use http_body_util::BodyExt;

    /// Every variant paired with the status it must map to.
    ///
    /// `Error` is not `Clone`, so this builds a fresh list per call. A new
    /// variant added here is automatically covered by the table-driven tests.
    fn all_variants() -> Vec<(Error, u16)> {
        vec![
            (Error::NotFound, 404),
            (Error::MethodNotAllowed, 405),
            (Error::BadRequest(String::from("bad")), 400),
            (Error::Unauthorized, 401),
            (Error::Forbidden, 403),
            (Error::PayloadTooLarge, 413),
            (Error::TooManyRequests, 429),
            (Error::Internal(String::from("x")), 500),
        ]
    }

    /// Renders `err` as a response and drains its body into a UTF-8 string.
    async fn body_of(err: Error) -> String {
        let bytes = err
            .into_response()
            .into_body()
            .collect()
            .await
            .unwrap()
            .to_bytes();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[test]
    fn status_codes_match_variant() {
        for (err, expected) in all_variants() {
            assert_eq!(err.status(), expected, "status of {err:?}");
        }
    }

    #[test]
    fn parameterless_variants_use_status_phrase_as_message() {
        assert_eq!(Error::NotFound.message(), "Not Found");
        assert_eq!(Error::MethodNotAllowed.message(), "Method Not Allowed");
        assert_eq!(Error::Unauthorized.message(), "Unauthorized");
        assert_eq!(Error::Forbidden.message(), "Forbidden");
        assert_eq!(Error::PayloadTooLarge.message(), "Payload Too Large");
        assert_eq!(Error::TooManyRequests.message(), "Too Many Requests");
    }

    #[test]
    fn parameterised_variants_carry_their_message() {
        assert_eq!(Error::BadRequest(String::from("nope")).message(), "nope");
        assert_eq!(Error::Internal(String::from("oops")).message(), "oops");
    }

    #[test]
    fn into_response_uses_variant_status() {
        for (err, expected) in all_variants() {
            // `into_response` consumes the error, so capture its label first.
            let label = format!("{err:?}");
            assert_eq!(
                err.into_response().status().as_u16(),
                expected,
                "response status of {label}"
            );
        }
    }

    #[test]
    fn display_shows_status_and_message() {
        assert_eq!(format!("{}", Error::NotFound), "404 Not Found");
        assert_eq!(format!("{}", Error::Forbidden), "403 Forbidden");
        assert_eq!(
            format!("{}", Error::Internal(String::from("oops"))),
            "500 oops"
        );
    }

    #[test]
    fn resolve_passes_ok_through() {
        let resp = status_text(204, "");
        let resolved = resolve(Ok(resp));
        assert_eq!(resolved.status().as_u16(), 204);
    }

    #[test]
    fn resolve_converts_err_to_response() {
        let resolved = resolve(Err(Error::Unauthorized));
        assert_eq!(resolved.status().as_u16(), 401);
    }

    #[tokio::test]
    async fn internal_response_body_is_sanitized() {
        let body = body_of(Error::Internal(String::from("db password: hunter2"))).await;
        assert_eq!(body, "Internal Server Error");
        assert!(!body.contains("hunter2"));
    }

    #[tokio::test]
    async fn public_error_bodies_use_status_phrase_or_message() {
        assert_eq!(body_of(Error::NotFound).await, "Not Found");
        assert_eq!(body_of(Error::MethodNotAllowed).await, "Method Not Allowed");
        assert_eq!(body_of(Error::Unauthorized).await, "Unauthorized");
        assert_eq!(body_of(Error::Forbidden).await, "Forbidden");
        assert_eq!(body_of(Error::PayloadTooLarge).await, "Payload Too Large");
        assert_eq!(body_of(Error::TooManyRequests).await, "Too Many Requests");
        assert_eq!(body_of(Error::BadRequest(String::from("bad"))).await, "bad");
    }

    #[tokio::test]
    async fn body_matches_message_for_every_non_internal_variant() {
        for (err, _) in all_variants() {
            // `Internal` is intentionally excluded: its body is sanitized while
            // its message keeps the raw detail (covered by the tests around it).
            if matches!(err, Error::Internal(_)) {
                continue;
            }
            let expected = err.message().to_owned();
            assert_eq!(body_of(err).await, expected);
        }
    }

    #[test]
    fn internal_display_and_message_retain_detail_for_logging() {
        let err = Error::Internal(String::from("leaked secret"));
        assert_eq!(err.message(), "leaked secret");
        assert!(format!("{err}").contains("leaked secret"));
    }
}