// rustio_core/error.rs

//! Unified error type for the framework.
//!
//! Handlers and middleware return `Result<Response, Error>`. The router
//! converts any unhandled `Err` into an HTTP response as a final safety net.
//!
//! `Error::Internal(msg)` keeps the full message for logging (via
//! [`Display`](std::fmt::Display) and [`Error::message`]) but sanitizes it to a
//! generic `"Internal Server Error"` body when converted into an HTTP response.

10use std::fmt;
11
12use crate::http::{status_text, Response};
13
/// Every failure a handler or middleware can report.
///
/// Each variant corresponds to exactly one HTTP status (see [`Error::status`]).
/// The enum is `#[non_exhaustive]`, so matches outside this crate must include
/// a wildcard arm, and variants can be added without a breaking change.
#[derive(Debug)]
#[non_exhaustive]
pub enum Error {
    /// 404 Not Found.
    NotFound,
    /// 405 Method Not Allowed.
    MethodNotAllowed,
    /// 400 Bad Request. The payload becomes the response body verbatim, so it
    /// must not contain server-side secrets.
    BadRequest(String),
    /// 401 Unauthorized.
    Unauthorized,
    /// 403 Forbidden.
    Forbidden,
    /// 500 Internal Server Error. The payload is kept for logging (via
    /// `Display` / [`Error::message`]), but the HTTP body is always the generic
    /// `"Internal Server Error"`.
    Internal(String),
}

25impl Error {
26    /// HTTP status code associated with this variant.
27    pub fn status(&self) -> u16 {
28        match self {
29            Error::NotFound => 404,
30            Error::MethodNotAllowed => 405,
31            Error::BadRequest(_) => 400,
32            Error::Unauthorized => 401,
33            Error::Forbidden => 403,
34            Error::Internal(_) => 500,
35        }
36    }
37
38    /// Human-readable message carried by the variant.
39    ///
40    /// For `Internal`, this returns the full underlying detail. That detail
41    /// is safe for logs but is *not* sent to HTTP clients — see
42    /// [`Error::into_response`].
43    pub fn message(&self) -> &str {
44        match self {
45            Error::NotFound => "Not Found",
46            Error::MethodNotAllowed => "Method Not Allowed",
47            Error::BadRequest(msg) => msg,
48            Error::Unauthorized => "Unauthorized",
49            Error::Forbidden => "Forbidden",
50            Error::Internal(msg) => msg,
51        }
52    }
53
54    /// Convert this error into an HTTP response.
55    ///
56    /// The body exposed to clients is sanitized for `Internal` — it always
57    /// reads `"Internal Server Error"`, never the original detail.
58    pub fn into_response(self) -> Response {
59        let status = self.status();
60        let body = match self {
61            Error::NotFound => String::from("Not Found"),
62            Error::MethodNotAllowed => String::from("Method Not Allowed"),
63            Error::BadRequest(msg) => msg,
64            Error::Unauthorized => String::from("Unauthorized"),
65            Error::Forbidden => String::from("Forbidden"),
66            Error::Internal(_) => String::from("Internal Server Error"),
67        };
68        status_text(status, body)
69    }
70}
71
72impl fmt::Display for Error {
73    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74        write!(f, "{} {}", self.status(), self.message())
75    }
76}
77
78impl std::error::Error for Error {}
79
80impl From<sqlx::Error> for Error {
81    fn from(value: sqlx::Error) -> Self {
82        Error::Internal(value.to_string())
83    }
84}
85
86/// Convert a handler result into a definite `Response`.
87///
88/// Useful in middleware that needs to observe both success and error paths
89/// before returning — e.g. attaching an `X-Request-Id` header to every
90/// response regardless of outcome.
91pub fn resolve(result: Result<Response, Error>) -> Response {
92    result.unwrap_or_else(Error::into_response)
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use http_body_util::BodyExt;

    /// Drains a response body and decodes it as UTF-8 text.
    async fn body_text(resp: Response) -> String {
        let collected = resp.into_body().collect().await.unwrap();
        String::from_utf8(collected.to_bytes().to_vec()).unwrap()
    }

    #[test]
    fn status_codes_match_variant() {
        let cases = [
            (Error::NotFound, 404),
            (Error::MethodNotAllowed, 405),
            (Error::BadRequest(String::from("bad")), 400),
            (Error::Unauthorized, 401),
            (Error::Forbidden, 403),
            (Error::Internal(String::from("x")), 500),
        ];
        for (err, expected) in cases {
            assert_eq!(err.status(), expected, "{err:?}");
        }
    }

    #[test]
    fn parameterless_variants_use_status_phrase_as_message() {
        let cases = [
            (Error::NotFound, "Not Found"),
            (Error::MethodNotAllowed, "Method Not Allowed"),
            (Error::Unauthorized, "Unauthorized"),
            (Error::Forbidden, "Forbidden"),
        ];
        for (err, phrase) in cases {
            assert_eq!(err.message(), phrase);
        }
    }

    #[test]
    fn parameterised_variants_carry_their_message() {
        let bad = Error::BadRequest("nope".to_owned());
        let internal = Error::Internal("oops".to_owned());
        assert_eq!(bad.message(), "nope");
        assert_eq!(internal.message(), "oops");
    }

    #[test]
    fn into_response_uses_variant_status() {
        let cases = [
            (Error::NotFound, 404),
            (Error::Forbidden, 403),
            (Error::BadRequest("x".to_owned()), 400),
            (Error::Internal("x".to_owned()), 500),
        ];
        for (err, expected) in cases {
            let code = err.into_response().status().as_u16();
            assert_eq!(code, expected);
        }
    }

    #[test]
    fn display_shows_status_and_message() {
        assert_eq!(Error::NotFound.to_string(), "404 Not Found");
        assert_eq!(Error::Forbidden.to_string(), "403 Forbidden");
        assert_eq!(Error::Internal("oops".to_owned()).to_string(), "500 oops");
    }

    #[test]
    fn resolve_passes_ok_through() {
        let resolved = resolve(Ok(status_text(204, "")));
        assert_eq!(resolved.status().as_u16(), 204);
    }

    #[test]
    fn resolve_converts_err_to_response() {
        let resolved = resolve(Err(Error::Unauthorized));
        assert_eq!(resolved.status().as_u16(), 401);
    }

    #[tokio::test]
    async fn internal_response_body_is_sanitized() {
        let secret = "hunter2";
        let err = Error::Internal(format!("db password: {secret}"));
        let body = body_text(err.into_response()).await;
        assert_eq!(body, "Internal Server Error");
        assert!(!body.contains(secret));
    }

    #[tokio::test]
    async fn public_error_bodies_use_status_phrase_or_message() {
        let cases = [
            (Error::NotFound, "Not Found"),
            (Error::Unauthorized, "Unauthorized"),
            (Error::Forbidden, "Forbidden"),
            (Error::BadRequest("bad".to_owned()), "bad"),
        ];
        for (err, expected) in cases {
            assert_eq!(body_text(err.into_response()).await, expected);
        }
    }

    #[test]
    fn internal_display_and_message_retain_detail_for_logging() {
        let err = Error::Internal("leaked secret".to_owned());
        assert_eq!(err.message(), "leaked secret");
        assert!(err.to_string().contains("leaked secret"));
    }
}