shuttle_common/models/error.rs

1use std::fmt::{Display, Formatter};
2
3use http::StatusCode;
4use serde::{Deserialize, Serialize};
5
6#[cfg(feature = "display")]
7use crossterm::style::Stylize;
8
/// Converts an [`ApiError`] into an axum response.
///
/// The error message is logged at `warn` level first. The response then pairs the status
/// reported by [`ApiError::status`] with a JSON body containing the serialized error.
#[cfg(feature = "axum")]
impl axum::response::IntoResponse for ApiError {
    fn into_response(self) -> axum::response::Response {
        tracing::warn!("{}", self.message);

        // Read the status before `self` is moved into the JSON body.
        let status = self.status();
        let body = axum::Json(self);
        (status, body).into_response()
    }
}
17
18#[derive(Serialize, Deserialize, Debug)]
19#[typeshare::typeshare]
20pub struct ApiError {
21    pub message: String,
22    pub status_code: u16,
23}
24
25impl ApiError {
26    pub fn internal(message: &str) -> Self {
27        Self {
28            message: message.to_string(),
29            status_code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
30        }
31    }
32
33    /// Creates an internal error without exposing sensitive information to the user.
34    #[inline(always)]
35    #[allow(unused_variables)]
36    pub fn internal_safe<E>(message: &str, error: E) -> Self
37    where
38        E: std::error::Error + 'static,
39    {
40        tracing::error!(error = &error as &dyn std::error::Error, "{message}");
41
42        // Return the raw error during debug builds
43        #[cfg(debug_assertions)]
44        {
45            ApiError::internal(&error.to_string())
46        }
47        // Return the safe message during release builds
48        #[cfg(not(debug_assertions))]
49        {
50            ApiError::internal(message)
51        }
52    }
53
54    pub fn unavailable(error: impl std::error::Error) -> Self {
55        Self {
56            message: error.to_string(),
57            status_code: StatusCode::SERVICE_UNAVAILABLE.as_u16(),
58        }
59    }
60
61    pub fn bad_request(error: impl std::error::Error) -> Self {
62        Self {
63            message: error.to_string(),
64            status_code: StatusCode::BAD_REQUEST.as_u16(),
65        }
66    }
67
68    pub fn unauthorized() -> Self {
69        Self {
70            message: "Unauthorized".to_string(),
71            status_code: StatusCode::UNAUTHORIZED.as_u16(),
72        }
73    }
74
75    pub fn forbidden() -> Self {
76        Self {
77            message: "Forbidden".to_string(),
78            status_code: StatusCode::FORBIDDEN.as_u16(),
79        }
80    }
81
82    pub fn status(&self) -> StatusCode {
83        StatusCode::from_u16(self.status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
84    }
85}
86
87pub trait ErrorContext<T> {
88    /// Make a new internal server error with the given message.
89    #[inline(always)]
90    fn context_internal_error(self, message: &str) -> Result<T, ApiError>
91    where
92        Self: Sized,
93    {
94        self.with_context_internal_error(move || message.to_string())
95    }
96
97    /// Make a new internal server error using the given function to create the message.
98    fn with_context_internal_error(self, message: impl FnOnce() -> String) -> Result<T, ApiError>;
99
100    /// Make a new bad request error with the given message.
101    #[inline(always)]
102    fn context_bad_request(self, message: &str) -> Result<T, ApiError>
103    where
104        Self: Sized,
105    {
106        self.with_context_bad_request(move || message.to_string())
107    }
108
109    /// Make a new bad request error using the given function to create the message.
110    fn with_context_bad_request(self, message: impl FnOnce() -> String) -> Result<T, ApiError>;
111
112    /// Make a new not found error with the given message.
113    #[inline(always)]
114    fn context_not_found(self, message: &str) -> Result<T, ApiError>
115    where
116        Self: Sized,
117    {
118        self.with_context_not_found(move || message.to_string())
119    }
120
121    /// Make a new not found error using the given function to create the message.
122    fn with_context_not_found(self, message: impl FnOnce() -> String) -> Result<T, ApiError>;
123}
124
125impl<T, E> ErrorContext<T> for Result<T, E>
126where
127    E: std::error::Error + 'static,
128{
129    #[inline(always)]
130    fn with_context_internal_error(self, message: impl FnOnce() -> String) -> Result<T, ApiError> {
131        match self {
132            Ok(value) => Ok(value),
133            Err(error) => Err(ApiError::internal_safe(message().as_ref(), error)),
134        }
135    }
136
137    #[inline(always)]
138    fn with_context_bad_request(self, message: impl FnOnce() -> String) -> Result<T, ApiError> {
139        match self {
140            Ok(value) => Ok(value),
141            Err(error) => Err({
142                let message = message();
143                tracing::warn!(
144                    error = &error as &dyn std::error::Error,
145                    "bad request: {message}"
146                );
147
148                ApiError {
149                    message,
150                    status_code: StatusCode::BAD_REQUEST.as_u16(),
151                }
152            }),
153        }
154    }
155
156    #[inline(always)]
157    fn with_context_not_found(self, message: impl FnOnce() -> String) -> Result<T, ApiError> {
158        match self {
159            Ok(value) => Ok(value),
160            Err(error) => Err({
161                let message = message();
162                tracing::warn!(
163                    error = &error as &dyn std::error::Error,
164                    "not found: {message}"
165                );
166
167                ApiError {
168                    message,
169                    status_code: StatusCode::NOT_FOUND.as_u16(),
170                }
171            }),
172        }
173    }
174}
175
176impl<T> ErrorContext<T> for Option<T> {
177    #[inline]
178    fn with_context_internal_error(self, message: impl FnOnce() -> String) -> Result<T, ApiError> {
179        match self {
180            Some(value) => Ok(value),
181            None => Err(ApiError::internal(message().as_ref())),
182        }
183    }
184
185    #[inline]
186    fn with_context_bad_request(self, message: impl FnOnce() -> String) -> Result<T, ApiError> {
187        match self {
188            Some(value) => Ok(value),
189            None => Err({
190                ApiError {
191                    message: message(),
192                    status_code: StatusCode::BAD_REQUEST.as_u16(),
193                }
194            }),
195        }
196    }
197
198    #[inline]
199    fn with_context_not_found(self, message: impl FnOnce() -> String) -> Result<T, ApiError> {
200        match self {
201            Some(value) => Ok(value),
202            None => Err({
203                ApiError {
204                    message: message(),
205                    status_code: StatusCode::NOT_FOUND.as_u16(),
206                }
207            }),
208        }
209    }
210}
211
212impl Display for ApiError {
213    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
214        #[cfg(feature = "display")]
215        return write!(
216            f,
217            "{}\nMessage: {}",
218            self.status().to_string().bold(),
219            self.message.to_string().red()
220        );
221        #[cfg(not(feature = "display"))]
222        return write!(f, "{}\nMessage: {}", self.status(), self.message);
223    }
224}
225
226impl std::error::Error for ApiError {}