// shuttle_common/models/error.rs

1use std::fmt::Display;
2
3use http::{status::InvalidStatusCode, StatusCode};
4use serde::{Deserialize, Serialize};
5
#[cfg(feature = "axum")]
impl axum::response::IntoResponse for ApiError {
    /// Turns the error into an HTTP response: the stored status code becomes the response
    /// status and the error itself is serialized as the JSON body.
    ///
    /// If the stored code is not a valid HTTP status, `500 Internal Server Error` is used
    /// instead. With the `tracing-in-errors` feature enabled, the message is first logged at
    /// `warn` level.
    fn into_response(self) -> axum::response::Response {
        #[cfg(feature = "tracing-in-errors")]
        tracing::warn!("{}", self.message);

        // Read the status before `self` is moved into the JSON body.
        let status = match self.status() {
            Ok(valid) => valid,
            Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
        };
        let json_body = axum::Json(self);

        (status, json_body).into_response()
    }
}
19
/// Error payload returned by the API: a message for the API consumer plus a raw HTTP status code.
///
/// The status is kept as a plain `u16`; use [`ApiError::status`] to convert it back into a
/// typed `StatusCode`.
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[typeshare::typeshare]
pub struct ApiError {
    /// Message exposed to the API consumer.
    message: String,
    /// Raw HTTP status code. Any `u16` is accepted on deserialization, so it may not be a
    /// valid status.
    status_code: u16,
}
27
28impl ApiError {
29    #[inline(always)]
30    pub fn new(message: impl Display, status_code: StatusCode) -> Self {
31        Self {
32            message: message.to_string(),
33            status_code: status_code.as_u16(),
34        }
35    }
36    #[inline(always)]
37    pub fn status(&self) -> Result<StatusCode, InvalidStatusCode> {
38        StatusCode::from_u16(self.status_code)
39    }
40    #[inline(always)]
41    pub fn message(&self) -> &str {
42        self.message.as_str()
43    }
44
45    /// Create a one-off internal error with a string message exposed to the user.
46    #[inline(always)]
47    pub fn internal(message: impl AsRef<str>) -> Self {
48        #[cfg(feature = "tracing-in-errors")]
49        {
50            /// Dummy wrapper to allow logging a string `as &dyn std::error::Error`
51            #[derive(Debug)]
52            struct InternalError(String);
53            impl std::error::Error for InternalError {}
54            impl std::fmt::Display for InternalError {
55                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56                    f.write_str(self.0.as_str())
57                }
58            }
59
60            tracing::error!(
61                error = &InternalError(message.as_ref().to_owned()) as &dyn std::error::Error,
62                "Internal API Error"
63            );
64        }
65
66        Self::_internal(message.as_ref())
67    }
68
69    /// Creates an internal error without exposing sensitive information to the user.
70    #[inline(always)]
71    #[allow(unused_variables)]
72    pub fn internal_safe<E: std::error::Error + 'static>(safe_msg: impl Display, error: E) -> Self {
73        #[cfg(feature = "tracing-in-errors")]
74        tracing::error!(error = &error as &dyn std::error::Error, "{}", safe_msg);
75
76        // Return the raw error during debug builds
77        #[cfg(debug_assertions)]
78        {
79            Self::_internal(error)
80        }
81        // Return the safe message during release builds
82        #[cfg(not(debug_assertions))]
83        {
84            Self::_internal(safe_msg)
85        }
86    }
87
88    // 5xx
89    #[inline(always)]
90    fn _internal(error: impl Display) -> Self {
91        Self::new(error.to_string(), StatusCode::INTERNAL_SERVER_ERROR)
92    }
93    #[inline(always)]
94    pub fn service_unavailable(error: impl Display) -> Self {
95        Self::new(error.to_string(), StatusCode::SERVICE_UNAVAILABLE)
96    }
97    // 4xx
98    #[inline(always)]
99    pub fn bad_request(error: impl Display) -> Self {
100        Self::new(error.to_string(), StatusCode::BAD_REQUEST)
101    }
102    #[inline(always)]
103    pub fn unauthorized(error: impl Display) -> Self {
104        Self::new(error.to_string(), StatusCode::UNAUTHORIZED)
105    }
106    #[inline(always)]
107    pub fn forbidden(error: impl Display) -> Self {
108        Self::new(error.to_string(), StatusCode::FORBIDDEN)
109    }
110    #[inline(always)]
111    pub fn not_found(error: impl Display) -> Self {
112        Self::new(error.to_string(), StatusCode::NOT_FOUND)
113    }
114}
115
116pub trait ErrorContext<T> {
117    /// Make a new internal server error with the given message.
118    fn context_internal_error(self, message: impl Display) -> Result<T, ApiError>;
119
120    /// Make a new internal server error using the given function to create the message.
121    #[inline(always)]
122    fn with_context_internal_error(self, message: impl FnOnce() -> String) -> Result<T, ApiError>
123    where
124        Self: Sized,
125    {
126        self.context_internal_error(message())
127    }
128}
129
130impl<T, E> ErrorContext<T> for Result<T, E>
131where
132    E: std::error::Error + 'static,
133{
134    #[inline(always)]
135    fn context_internal_error(self, message: impl Display) -> Result<T, ApiError> {
136        self.map_err(|error| ApiError::internal_safe(message, error))
137    }
138}
139
140impl std::fmt::Display for ApiError {
141    #[cfg(feature = "display")]
142    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143        use crossterm::style::Stylize;
144        write!(
145            f,
146            "{}\nMessage: {}",
147            self.status()
148                .map(|s| s.to_string())
149                .unwrap_or("Unknown".to_owned())
150                .bold(),
151            self.message.to_string().red()
152        )
153    }
154    #[cfg(not(feature = "display"))]
155    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156        write!(
157            f,
158            "{}\nMessage: {}",
159            self.status()
160                .map(|s| s.to_string())
161                .unwrap_or("Unknown".to_owned()),
162            self.message,
163        )
164    }
165}
166
// `ApiError` derives `Debug` and implements `Display` in this file, so the trait's default
// methods suffice. This lets it be used wherever a `std::error::Error` is expected.
impl std::error::Error for ApiError {}