Skip to main content

rustauth_core/api/
error.rs

1use http::{header, Response, StatusCode};
2use serde::{Deserialize, Serialize};
3
4use crate::error::RustAuthError;
5use crate::error_codes::ErrorCode;
6use crate::rate_limit::RateLimitRejection;
7
8use super::endpoint::{ApiResponse, Body};
9
/// Error codes produced by the API layer.
///
/// Every variant has a stable wire identifier (see [`ApiErrorCode::as_str`]) and a
/// human-readable description (see [`ApiErrorCode::message`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApiErrorCode {
    /// Wire code `NOT_FOUND`.
    NotFound,
    /// Wire code `INVALID_ORIGIN`.
    InvalidOrigin,
    /// Wire code `INVALID_CALLBACK_URL`.
    InvalidCallbackUrl,
    /// Wire code `INVALID_REDIRECT_URL`.
    InvalidRedirectUrl,
    /// Wire code `INVALID_ERROR_CALLBACK_URL`.
    InvalidErrorCallbackUrl,
    /// Wire code `INVALID_NEW_USER_CALLBACK_URL`.
    InvalidNewUserCallbackUrl,
    /// Wire code `MISSING_OR_NULL_ORIGIN`.
    MissingOrNullOrigin,
    /// Wire code `CROSS_SITE_NAVIGATION_LOGIN_BLOCKED`.
    CrossSiteNavigationLoginBlocked,
    /// Wire code `TOO_MANY_REQUESTS`.
    TooManyRequests,
}

impl ApiErrorCode {
    /// Returns the upper-snake-case identifier for this code.
    ///
    /// This is a `const fn`, so the identifier can be used to initialise
    /// `const` and `static` items.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::NotFound => "NOT_FOUND",
            Self::InvalidOrigin => "INVALID_ORIGIN",
            Self::InvalidCallbackUrl => "INVALID_CALLBACK_URL",
            Self::InvalidRedirectUrl => "INVALID_REDIRECT_URL",
            Self::InvalidErrorCallbackUrl => "INVALID_ERROR_CALLBACK_URL",
            Self::InvalidNewUserCallbackUrl => "INVALID_NEW_USER_CALLBACK_URL",
            Self::MissingOrNullOrigin => "MISSING_OR_NULL_ORIGIN",
            Self::CrossSiteNavigationLoginBlocked => "CROSS_SITE_NAVIGATION_LOGIN_BLOCKED",
            Self::TooManyRequests => "TOO_MANY_REQUESTS",
        }
    }

    /// Returns the human-readable description paired with this code.
    ///
    /// Like [`ApiErrorCode::as_str`], this is a `const fn`.
    pub const fn message(self) -> &'static str {
        match self {
            Self::NotFound => "Not Found",
            Self::InvalidOrigin => "Invalid origin",
            Self::InvalidCallbackUrl => "Invalid callbackURL",
            Self::InvalidRedirectUrl => "Invalid redirectURL",
            Self::InvalidErrorCallbackUrl => "Invalid errorCallbackURL",
            Self::InvalidNewUserCallbackUrl => "Invalid newUserCallbackURL",
            Self::MissingOrNullOrigin => "Missing or null Origin",
            Self::CrossSiteNavigationLoginBlocked => {
                "Cross-site navigation login blocked. This request appears to be a CSRF attack."
            }
            Self::TooManyRequests => "Too many requests. Please try again later.",
        }
    }
}
54
55impl ErrorCode for ApiErrorCode {
56    fn as_str(&self) -> &str {
57        (*self).as_str()
58    }
59
60    fn message(&self) -> &str {
61        (*self).message()
62    }
63}
64
65#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
66pub struct ApiErrorResponse {
67    pub code: String,
68    pub message: String,
69    #[serde(default)]
70    #[serde(skip_serializing_if = "Option::is_none")]
71    #[serde(rename = "originalMessage")]
72    pub original_message: Option<String>,
73}
74
75impl ApiErrorResponse {
76    pub fn from_error_code(code: impl ErrorCode) -> Self {
77        Self {
78            code: code.as_str().to_owned(),
79            message: code.message().to_owned(),
80            original_message: None,
81        }
82    }
83}
84
85pub fn response(status: StatusCode, body: Body) -> Result<ApiResponse, RustAuthError> {
86    Response::builder()
87        .status(status)
88        .body(body)
89        .map_err(|error| RustAuthError::Serialization {
90            context: "building API response",
91            message: error.to_string(),
92        })
93}
94
95pub fn api_error(status: StatusCode, code: ApiErrorCode) -> Result<ApiResponse, RustAuthError> {
96    let body = serde_json::to_vec(&ApiErrorResponse::from_error_code(code)).map_err(|error| {
97        RustAuthError::Serialization {
98            context: "serializing API error response",
99            message: error.to_string(),
100        }
101    })?;
102
103    Response::builder()
104        .status(status)
105        .header(header::CONTENT_TYPE, "application/json")
106        .body(body)
107        .map_err(|error| RustAuthError::Serialization {
108            context: "building API error response",
109            message: error.to_string(),
110        })
111}
112
113pub(super) fn rate_limit_response(
114    rejection: RateLimitRejection,
115) -> Result<ApiResponse, RustAuthError> {
116    let mut response = api_error(StatusCode::TOO_MANY_REQUESTS, ApiErrorCode::TooManyRequests)?;
117    response.headers_mut().insert(
118        "X-Retry-After",
119        http::HeaderValue::from_str(&rejection.retry_after.to_string()).map_err(|error| {
120            RustAuthError::Serialization {
121                context: "building rate limit response headers",
122                message: error.to_string(),
123            }
124        })?,
125    );
126    Ok(response)
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error_codes::ErrorCode;

    /// Checks a value's code and message through the `ErrorCode` trait bound.
    fn assert_error_code(code: impl ErrorCode, expected_code: &str, expected_message: &str) {
        let (actual_code, actual_message) = (code.as_str(), code.message());
        assert_eq!(actual_code, expected_code);
        assert_eq!(actual_message, expected_message);
    }

    /// `ApiErrorCode` can be passed wherever an `ErrorCode` is expected.
    #[test]
    fn api_error_code_implements_error_code_trait() {
        let code = ApiErrorCode::InvalidOrigin;
        assert_error_code(code, "INVALID_ORIGIN", "Invalid origin");
    }

    /// The payload built from a code mirrors the code's inherent lookups.
    #[test]
    fn api_error_response_from_error_code_matches_inherent_helpers() {
        let code = ApiErrorCode::TooManyRequests;
        let ApiErrorResponse {
            code: wire_code,
            message,
            original_message,
        } = ApiErrorResponse::from_error_code(code);

        assert_eq!(wire_code, code.as_str());
        assert_eq!(message, code.message());
        assert!(original_message.is_none());
    }
}