rustauth_core/api/
error.rs1use http::{header, Response, StatusCode};
2use serde::{Deserialize, Serialize};
3
4use crate::error::RustAuthError;
5use crate::error_codes::ErrorCode;
6use crate::rate_limit::RateLimitRejection;
7
8use super::endpoint::{ApiResponse, Body};
9
/// Error codes produced by the API layer.
///
/// Each variant has a stable machine-readable identifier ([`ApiErrorCode::as_str`])
/// and a default human-readable description ([`ApiErrorCode::message`]).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ApiErrorCode {
    /// The requested resource was not found.
    NotFound,
    /// The request origin was rejected.
    InvalidOrigin,
    /// The supplied `callbackURL` was rejected.
    InvalidCallbackUrl,
    /// The supplied `redirectURL` was rejected.
    InvalidRedirectUrl,
    /// The supplied `errorCallbackURL` was rejected.
    InvalidErrorCallbackUrl,
    /// The supplied `newUserCallbackURL` was rejected.
    InvalidNewUserCallbackUrl,
    /// The request carried no usable origin (missing or `null`).
    MissingOrNullOrigin,
    /// A login arriving via cross-site navigation was blocked as a suspected CSRF attack.
    CrossSiteNavigationLoginBlocked,
    /// The caller exceeded a rate limit.
    TooManyRequests,
}
22
23impl ApiErrorCode {
24 pub fn as_str(self) -> &'static str {
25 match self {
26 Self::NotFound => "NOT_FOUND",
27 Self::InvalidOrigin => "INVALID_ORIGIN",
28 Self::InvalidCallbackUrl => "INVALID_CALLBACK_URL",
29 Self::InvalidRedirectUrl => "INVALID_REDIRECT_URL",
30 Self::InvalidErrorCallbackUrl => "INVALID_ERROR_CALLBACK_URL",
31 Self::InvalidNewUserCallbackUrl => "INVALID_NEW_USER_CALLBACK_URL",
32 Self::MissingOrNullOrigin => "MISSING_OR_NULL_ORIGIN",
33 Self::CrossSiteNavigationLoginBlocked => "CROSS_SITE_NAVIGATION_LOGIN_BLOCKED",
34 Self::TooManyRequests => "TOO_MANY_REQUESTS",
35 }
36 }
37
38 pub fn message(self) -> &'static str {
39 match self {
40 Self::NotFound => "Not Found",
41 Self::InvalidOrigin => "Invalid origin",
42 Self::InvalidCallbackUrl => "Invalid callbackURL",
43 Self::InvalidRedirectUrl => "Invalid redirectURL",
44 Self::InvalidErrorCallbackUrl => "Invalid errorCallbackURL",
45 Self::InvalidNewUserCallbackUrl => "Invalid newUserCallbackURL",
46 Self::MissingOrNullOrigin => "Missing or null Origin",
47 Self::CrossSiteNavigationLoginBlocked => {
48 "Cross-site navigation login blocked. This request appears to be a CSRF attack."
49 }
50 Self::TooManyRequests => "Too many requests. Please try again later.",
51 }
52 }
53}
54
55impl ErrorCode for ApiErrorCode {
56 fn as_str(&self) -> &str {
57 (*self).as_str()
58 }
59
60 fn message(&self) -> &str {
61 (*self).message()
62 }
63}
64
65#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
66pub struct ApiErrorResponse {
67 pub code: String,
68 pub message: String,
69 #[serde(default)]
70 #[serde(skip_serializing_if = "Option::is_none")]
71 #[serde(rename = "originalMessage")]
72 pub original_message: Option<String>,
73}
74
75impl ApiErrorResponse {
76 pub fn from_error_code(code: impl ErrorCode) -> Self {
77 Self {
78 code: code.as_str().to_owned(),
79 message: code.message().to_owned(),
80 original_message: None,
81 }
82 }
83}
84
85pub fn response(status: StatusCode, body: Body) -> Result<ApiResponse, RustAuthError> {
86 Response::builder()
87 .status(status)
88 .body(body)
89 .map_err(|error| RustAuthError::Serialization {
90 context: "building API response",
91 message: error.to_string(),
92 })
93}
94
95pub fn api_error(status: StatusCode, code: ApiErrorCode) -> Result<ApiResponse, RustAuthError> {
96 let body = serde_json::to_vec(&ApiErrorResponse::from_error_code(code)).map_err(|error| {
97 RustAuthError::Serialization {
98 context: "serializing API error response",
99 message: error.to_string(),
100 }
101 })?;
102
103 Response::builder()
104 .status(status)
105 .header(header::CONTENT_TYPE, "application/json")
106 .body(body)
107 .map_err(|error| RustAuthError::Serialization {
108 context: "building API error response",
109 message: error.to_string(),
110 })
111}
112
113pub(super) fn rate_limit_response(
114 rejection: RateLimitRejection,
115) -> Result<ApiResponse, RustAuthError> {
116 let mut response = api_error(StatusCode::TOO_MANY_REQUESTS, ApiErrorCode::TooManyRequests)?;
117 response.headers_mut().insert(
118 "X-Retry-After",
119 http::HeaderValue::from_str(&rejection.retry_after.to_string()).map_err(|error| {
120 RustAuthError::Serialization {
121 context: "building rate limit response headers",
122 message: error.to_string(),
123 }
124 })?,
125 );
126 Ok(response)
127}
128
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;
    use crate::error_codes::ErrorCode;

    /// Every `ApiErrorCode` variant; extend this when adding a variant so the
    /// per-variant tests below cover it.
    const ALL_CODES: &[ApiErrorCode] = &[
        ApiErrorCode::NotFound,
        ApiErrorCode::InvalidOrigin,
        ApiErrorCode::InvalidCallbackUrl,
        ApiErrorCode::InvalidRedirectUrl,
        ApiErrorCode::InvalidErrorCallbackUrl,
        ApiErrorCode::InvalidNewUserCallbackUrl,
        ApiErrorCode::MissingOrNullOrigin,
        ApiErrorCode::CrossSiteNavigationLoginBlocked,
        ApiErrorCode::TooManyRequests,
    ];

    /// Checks an `ErrorCode` through the trait, not the inherent methods.
    fn assert_error_code(code: impl ErrorCode, expected_code: &str, expected_message: &str) {
        assert_eq!(code.as_str(), expected_code);
        assert_eq!(code.message(), expected_message);
    }

    #[test]
    fn api_error_code_implements_error_code_trait() {
        assert_error_code(
            ApiErrorCode::InvalidOrigin,
            "INVALID_ORIGIN",
            "Invalid origin",
        );
    }

    #[test]
    fn trait_methods_match_inherent_helpers_for_every_variant() {
        for &code in ALL_CODES {
            assert_error_code(code, code.as_str(), code.message());
        }
    }

    #[test]
    fn error_identifiers_are_unique_screaming_snake_case() {
        let mut seen = HashSet::new();
        for &code in ALL_CODES {
            let identifier = code.as_str();
            assert!(!identifier.is_empty());
            assert!(
                identifier
                    .bytes()
                    .all(|b| b.is_ascii_uppercase() || b == b'_'),
                "identifier {identifier} is not SCREAMING_SNAKE_CASE"
            );
            assert!(seen.insert(identifier), "duplicate identifier {identifier}");
            assert!(!code.message().is_empty(), "empty message for {identifier}");
        }
    }

    #[test]
    fn api_error_response_from_error_code_matches_inherent_helpers() {
        let code = ApiErrorCode::TooManyRequests;
        let response = ApiErrorResponse::from_error_code(code);
        assert_eq!(response.code, code.as_str());
        assert_eq!(response.message, code.message());
        assert_eq!(response.original_message, None);
    }

    #[test]
    fn api_error_response_omits_absent_original_message() {
        let value = serde_json::to_value(ApiErrorResponse::from_error_code(
            ApiErrorCode::InvalidOrigin,
        ))
        .expect("ApiErrorResponse serializes");
        assert_eq!(
            value,
            serde_json::json!({ "code": "INVALID_ORIGIN", "message": "Invalid origin" })
        );
    }

    #[test]
    fn api_error_response_uses_camel_case_original_message() {
        let body = ApiErrorResponse {
            original_message: Some("upstream failure".to_owned()),
            ..ApiErrorResponse::from_error_code(ApiErrorCode::NotFound)
        };
        let value = serde_json::to_value(&body).expect("ApiErrorResponse serializes");
        assert_eq!(value["originalMessage"], "upstream failure");
        assert!(value.get("original_message").is_none());

        let round_trip: ApiErrorResponse =
            serde_json::from_value(value).expect("ApiErrorResponse deserializes");
        assert_eq!(round_trip, body);
    }

    #[test]
    fn api_error_response_deserializes_without_original_message() {
        let parsed: ApiErrorResponse =
            serde_json::from_str(r#"{"code":"NOT_FOUND","message":"Not Found"}"#)
                .expect("ApiErrorResponse deserializes");
        assert_eq!(parsed, ApiErrorResponse::from_error_code(ApiErrorCode::NotFound));
    }

    #[test]
    fn api_error_builds_json_response() {
        let built = api_error(StatusCode::NOT_FOUND, ApiErrorCode::NotFound)
            .expect("api_error builds a response");
        assert_eq!(built.status(), StatusCode::NOT_FOUND);
        assert_eq!(
            built
                .headers()
                .get(header::CONTENT_TYPE)
                .and_then(|value| value.to_str().ok()),
            Some("application/json")
        );

        let parsed: ApiErrorResponse =
            serde_json::from_slice(built.body()).expect("body is an ApiErrorResponse");
        assert_eq!(parsed, ApiErrorResponse::from_error_code(ApiErrorCode::NotFound));
    }

    #[test]
    fn response_sets_status_and_body_without_headers() {
        let built = response(StatusCode::NO_CONTENT, Vec::new()).expect("response builds");
        assert_eq!(built.status(), StatusCode::NO_CONTENT);
        assert!(built.headers().is_empty());
        assert!(built.body().is_empty());
    }
}