// spring_web/problem_details.rs

1use schemars::JsonSchema;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use axum::response::IntoResponse;
5
6// OpenAPI related imports - only available when openapi feature is enabled
7#[cfg(feature = "openapi")]
8use aide::openapi::{MediaType, Operation, ReferenceOr, Response, SchemaObject, StatusCode};
9
/// Per-variant metadata lookup used when documenting Problem Details error responses in OpenAPI.
#[cfg(feature = "openapi")]
pub trait ProblemDetailsVariantInfo {
    /// Look up the variant called `variant_name` (a bare name, without any `Type::` prefix).
    ///
    /// Returns `Some((status, description, schema))` for a known variant and `None` otherwise.
    /// `register_error_response_by_variant` keys the OpenAPI response by `status` and uses
    /// `description` as its description; the optional `schema` is not consumed there.
    fn get_variant_info(
        variant_name: &str,
    ) -> Option<(u16, String, Option<schemars::Schema>)>;
}

/// Generate the JSON schema of [`ProblemDetails`] for OpenAPI documentation.
///
/// The schema comes from calling `JsonSchema::json_schema` directly with a fresh
/// `SchemaGenerator`, so it is the type's own schema rather than a `$ref` into a definitions
/// table.
#[cfg(feature = "openapi")]
pub fn problem_details_schema() -> schemars::Schema {
    // `ProblemDetails` lives in this module and `JsonSchema` is imported at the top of the file,
    // so no absolute `crate::…` path (which would break if the module moved) is needed.
    ProblemDetails::json_schema(&mut schemars::SchemaGenerator::default())
}

/// Register the OpenAPI response documenting one variant of a Problem Details error type.
///
/// `variant_path` may be a bare variant name (`NotFound`) or a path (`MyError::NotFound`); only
/// the segment after the last `::` is passed to [`ProblemDetailsVariantInfo::get_variant_info`].
/// When the variant is unknown, a `tracing` warning is emitted and `operation` is left as is.
///
/// The response is keyed by the variant's status code and documents the [`ProblemDetails`]
/// schema, with a generated example, under both `application/problem+json` and
/// `application/json`. If the operation already has an inline response for that status code:
///
/// - the variant's description is appended as a `- ` list item unless it is already listed
///   (an empty existing description is simply replaced);
/// - the Problem Details body is attached only when the existing response declares no content.
///
/// A `$ref` response already registered at that status code is left untouched.
#[cfg(feature = "openapi")]
pub fn register_error_response_by_variant<T>(
    _ctx: &mut aide::generate::GenContext,
    operation: &mut Operation,
    variant_path: &str,
) where
    T: ProblemDetailsVariantInfo,
{
    // `rsplit` always yields at least one item, so this is the last `::` segment.
    let variant_name = variant_path.rsplit("::").next().unwrap_or(variant_path);

    let Some((status_code, description, _schema_opt)) = T::get_variant_info(variant_name) else {
        tracing::warn!(
            "Variant '{}' not found in error type '{}' when registering OpenAPI responses",
            variant_name,
            std::any::type_name::<T>()
        );
        return;
    };

    let response = problem_response(variant_name, status_code, description);

    let responses = operation.responses.get_or_insert_with(Default::default);
    match responses.responses.get_mut(&StatusCode::Code(status_code)) {
        Some(ReferenceOr::Item(existing)) => merge_problem_response(existing, response),
        // A `$ref` response is defined elsewhere; leave it alone.
        Some(_) => {}
        None => {
            responses
                .responses
                .insert(StatusCode::Code(status_code), ReferenceOr::Item(response));
        }
    }
}

/// Build the documented Problem Details response (schema plus generated example) for a variant.
#[cfg(feature = "openapi")]
fn problem_response(variant_name: &str, status_code: u16, description: String) -> Response {
    let example = serde_json::json!({
        "type": format!("about:blank/{}", variant_name.to_lowercase()),
        "title": format!("{} Error", variant_name),
        "status": status_code,
        "detail": format!("{} occurred", variant_name)
    });

    let media_type = MediaType {
        schema: Some(SchemaObject {
            json_schema: problem_details_schema(),
            example: Some(example),
            external_docs: None,
        }),
        ..Default::default()
    };

    let mut content = indexmap::IndexMap::new();
    content.insert("application/problem+json".to_string(), media_type.clone());
    // Also advertised as plain JSON for backward compatibility with existing clients.
    content.insert("application/json".to_string(), media_type);

    Response {
        description,
        content,
        ..Default::default()
    }
}

/// Fold `incoming` into a response already registered for the same status code.
#[cfg(feature = "openapi")]
fn merge_problem_response(existing: &mut Response, incoming: Response) {
    // Compare against every listed line (with any `- ` bullet stripped), not just the whole
    // description, so registering the same description twice does not duplicate it.
    let already_listed = existing
        .description
        .lines()
        .any(|line| line.strip_prefix("- ").unwrap_or(line) == incoming.description);

    if existing.description.is_empty() {
        existing.description = incoming.description;
    } else if !already_listed {
        existing.description = format!("{}\n- {}", existing.description, incoming.description);
    }

    if existing.content.is_empty() {
        existing.content = incoming.content;
    }
}

89/// RFC 7807 Problem Details for HTTP APIs
90/// 
91/// This struct represents a standardized error response format as defined in RFC 7807.
92/// It provides a consistent way to communicate error information in HTTP APIs.
93#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
94pub struct ProblemDetails {
95    /// A URI reference that identifies the problem type
96    #[serde(rename = "type")]
97    pub problem_type: String,
98    
99    /// A short, human-readable summary of the problem type
100    pub title: String,
101    
102    /// The HTTP status code generated by the origin server
103    pub status: u16,
104    
105    /// A human-readable explanation specific to this occurrence of the problem
106    #[serde(skip_serializing_if = "Option::is_none")]
107    pub detail: Option<String>,
108    
109    /// A URI reference that identifies the specific occurrence of the problem
110    #[serde(skip_serializing_if = "Option::is_none")]
111    pub instance: Option<String>,
112    
113    /// Additional problem-specific extension fields
114    #[serde(flatten)]
115    pub extensions: HashMap<String, serde_json::Value>,
116}
117
118impl ProblemDetails {
119    /// Create a new ProblemDetails with required fields
120    pub fn new(problem_type: impl Into<String>, title: impl Into<String>, status: u16) -> Self {
121        Self {
122            problem_type: problem_type.into(),
123            title: title.into(),
124            status,
125            detail: None,
126            instance: None,
127            extensions: HashMap::new(),
128        }
129    }
130    
131    /// Set the detail field
132    pub fn with_detail(mut self, detail: impl Into<String>) -> Self {
133        self.detail = Some(detail.into());
134        self
135    }
136    
137    /// Set the instance field
138    pub fn with_instance(mut self, instance: impl Into<String>) -> Self {
139        self.instance = Some(instance.into());
140        self
141    }
142    
143    /// Add an extension field
144    pub fn with_extension(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
145        self.extensions.insert(key.into(), value);
146        self
147    }
148    
149    /// Create a validation error problem
150    pub fn validation_error(detail: impl Into<String>) -> Self {
151        Self::new(
152            "about:blank",
153            "Validation Error",
154            400,
155        )
156        .with_detail(detail)
157    }
158    
159    /// Create an authentication error problem
160    pub fn authentication_error() -> Self {
161        Self::new(
162            "about:blank",
163            "Authentication Required",
164            401,
165        )
166        .with_detail("Authentication credentials are required to access this resource")
167    }
168    
169    /// Create an authorization error problem
170    pub fn authorization_error() -> Self {
171        Self::new(
172            "about:blank",
173            "Insufficient Permissions",
174            403,
175        )
176        .with_detail("You don't have permission to access this resource")
177    }
178    
179    /// Create a not found error problem
180    pub fn not_found(resource: impl Into<String>) -> Self {
181        Self::new(
182            "about:blank",
183            "Resource Not Found",
184            404,
185        )
186        .with_detail(format!("The requested {} was not found", resource.into()))
187    }
188    
189    /// Create an internal server error problem
190    pub fn internal_server_error() -> Self {
191        Self::new(
192            "about:blank",
193            "Internal Server Error",
194            500,
195        )
196        .with_detail("An unexpected error occurred while processing your request")
197    }
198    
199    /// Create a service unavailable error problem
200    pub fn service_unavailable() -> Self {
201        Self::new(
202            "about:blank",
203            "Service Unavailable",
204            503,
205        )
206        .with_detail("The service is temporarily unavailable")
207    }
208    
209    /// Create a custom problem with explicit URI
210    pub fn custom_problem(problem_type: impl Into<String>, title: impl Into<String>, status: u16) -> Self {
211        Self::new(
212            problem_type,
213            title,
214            status,
215        )
216    }
217}
218
219impl IntoResponse for ProblemDetails {
220    fn into_response(mut self) -> axum::response::Response {
221        let status = axum::http::StatusCode::from_u16(self.status)
222            .unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR);
223        
224        // Try to get the current request URI from task-local storage
225        if self.instance.is_none() {
226            if let Some(uri) = get_current_request_uri() {
227                self.instance = Some(uri);
228            }
229        }
230        
231        // Set the correct Content-Type for Problem Details
232        (
233            status,
234            [("content-type", "application/problem+json")],
235            axum::Json(self),
236        ).into_response()
237    }
238}
239
240// Task-local storage for current request URI
241tokio::task_local! {
242    static CURRENT_REQUEST_URI: String;
243}
244
245/// Get the current request URI from task-local storage
246fn get_current_request_uri() -> Option<String> {
247    CURRENT_REQUEST_URI.try_with(|uri| uri.clone()).ok()
248}
249
250/// Set the current request URI in task-local storage
251pub fn set_current_request_uri(uri: String) {
252    CURRENT_REQUEST_URI.scope(uri, async {
253        // This will be available for the duration of the request
254    });
255}
256
257/// Middleware to capture request URI for Problem Details
258pub async fn capture_request_uri_middleware(
259    req: axum::http::Request<axum::body::Body>,
260    next: axum::middleware::Next,
261) -> axum::response::Response {
262    let uri = req.uri().to_string();
263    
264    // Run the rest of the request handling with the URI in task-local storage
265    CURRENT_REQUEST_URI.scope(uri, async move {
266        next.run(req).await
267    }).await
268}
269
270/// Get the HTTP status code from ProblemDetails
271impl ProblemDetails {
272    pub fn status_code(&self) -> axum::http::StatusCode {
273        axum::http::StatusCode::from_u16(self.status)
274            .unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR)
275    }
276}
277
278/// Trait for converting errors to Problem Details
279pub trait ToProblemDetails {
280    fn to_problem_details(&self) -> ProblemDetails;
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    /// Read a response body to completion and parse it as JSON.
    async fn body_json(response: axum::response::Response) -> serde_json::Value {
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("response body should be readable");
        serde_json::from_slice(&bytes).expect("response body should be JSON")
    }

    #[test]
    fn test_problem_details_creation() {
        let problem = ProblemDetails::new("https://example.com/problems/test", "Test Problem", 400)
            .with_detail("This is a test problem")
            .with_instance("/test/123")
            .with_extension("code", serde_json::Value::String("TEST_001".to_string()));

        assert_eq!(problem.problem_type, "https://example.com/problems/test");
        assert_eq!(problem.title, "Test Problem");
        assert_eq!(problem.status, 400);
        assert_eq!(problem.detail, Some("This is a test problem".to_string()));
        assert_eq!(problem.instance, Some("/test/123".to_string()));
        assert_eq!(problem.extensions.get("code"), Some(&serde_json::Value::String("TEST_001".to_string())));
    }

    #[test]
    fn test_validation_error() {
        // Canned constructors use the `about:blank` problem type.
        let problem = ProblemDetails::validation_error("Name is required");
        assert_eq!(problem.status, 400);
        assert_eq!(problem.title, "Validation Error");
        assert_eq!(problem.problem_type, "about:blank");
        assert_eq!(problem.detail.as_deref(), Some("Name is required"));
    }

    #[test]
    fn test_not_found_detail_names_resource() {
        let problem = ProblemDetails::not_found("user");
        assert_eq!(problem.detail.as_deref(), Some("The requested user was not found"));
    }

    #[test]
    fn test_serialization_skips_unset_members_and_flattens_extensions() {
        let problem = ProblemDetails::new("about:blank", "Conflict", 409)
            .with_extension("code", serde_json::json!("DUPLICATE"));
        assert_eq!(
            serde_json::to_value(&problem).unwrap(),
            serde_json::json!({
                "type": "about:blank",
                "title": "Conflict",
                "status": 409,
                "code": "DUPLICATE"
            })
        );
    }

    #[test]
    fn test_deserialization_collects_unknown_members_as_extensions() {
        let problem: ProblemDetails = serde_json::from_value(serde_json::json!({
            "type": "about:blank",
            "title": "Bad Request",
            "status": 400,
            "trace_id": "abc"
        }))
        .unwrap();
        assert_eq!(problem.detail, None);
        assert_eq!(problem.instance, None);
        assert_eq!(problem.extensions.len(), 1);
        assert_eq!(problem.extensions.get("trace_id"), Some(&serde_json::json!("abc")));
    }

    #[test]
    fn test_into_response() {
        let response = ProblemDetails::not_found("user").into_response();

        assert_eq!(response.status(), axum::http::StatusCode::NOT_FOUND);
        // The Problem Details media type must override the `application/json` set by `Json`.
        let content_type = response
            .headers()
            .get(axum::http::header::CONTENT_TYPE)
            .and_then(|value| value.to_str().ok());
        assert_eq!(content_type, Some("application/problem+json"));
    }

    #[test]
    fn test_status_code() {
        let problem = ProblemDetails::validation_error("Test error");
        assert_eq!(problem.status_code(), axum::http::StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_status_code_falls_back_to_500_for_invalid_values() {
        let problem = ProblemDetails::new("about:blank", "Bogus", 42);
        assert_eq!(problem.status_code(), axum::http::StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn test_automatic_uri_capture() {
        // Test that URI is captured in task-local storage
        let test_uri = "/test/path".to_string();

        CURRENT_REQUEST_URI.scope(test_uri.clone(), async {
            let uri = get_current_request_uri();
            assert_eq!(uri, Some(test_uri));
        }).await;
    }

    #[tokio::test]
    async fn test_into_response_fills_instance_from_request_scope() {
        let response = CURRENT_REQUEST_URI
            .scope("/orders/42".to_string(), async {
                ProblemDetails::not_found("order").into_response()
            })
            .await;
        let body = body_json(response).await;
        assert_eq!(body["instance"], "/orders/42");
    }

    #[tokio::test]
    async fn test_into_response_keeps_explicit_instance() {
        let response = CURRENT_REQUEST_URI
            .scope("/ignored".to_string(), async {
                ProblemDetails::not_found("order")
                    .with_instance("/explicit")
                    .into_response()
            })
            .await;
        let body = body_json(response).await;
        assert_eq!(body["instance"], "/explicit");
    }

    #[tokio::test]
    async fn test_into_response_without_scope_omits_instance() {
        let body = body_json(ProblemDetails::internal_server_error().into_response()).await;
        assert!(body.get("instance").is_none());
        assert_eq!(body["status"], 500);
    }
}