Skip to main content

summer_web/
problem_details.rs

1use axum::response::IntoResponse;
2use schemars::JsonSchema;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6// OpenAPI related imports - only available when openapi feature is enabled
7#[cfg(feature = "openapi")]
8use aide::openapi::{MediaType, Operation, ReferenceOr, Response, SchemaObject, StatusCode};
9
/// Lookup from an error enum's variant names to the metadata used when
/// documenting Problem Details responses in OpenAPI.
///
/// [`register_error_response_by_variant`] consults this trait once per
/// variant it is asked to register.
#[cfg(feature = "openapi")]
pub trait ProblemDetailsVariantInfo {
    /// Describe the variant named `variant_name` (bare name, no `Type::` prefix).
    ///
    /// Returns `None` when the variant is unknown. Otherwise it returns
    /// `(status, description, schema)`: the HTTP status code to document, the
    /// response description text, and an optional schema.
    fn get_variant_info(variant_name: &str) -> Option<(u16, String, Option<schemars::Schema>)>;
}
15
/// Build the JSON Schema of a [`ProblemDetails`] body for OpenAPI response
/// documentation, using a fresh default schema generator.
#[cfg(feature = "openapi")]
pub fn problem_details_schema() -> schemars::Schema {
    let mut generator = schemars::SchemaGenerator::default();
    <ProblemDetails as JsonSchema>::json_schema(&mut generator)
}
22
/// Register the Problem Details response for one error variant on an OpenAPI
/// operation.
///
/// `variant_path` may be a bare variant name or a `::`-separated path such as
/// `MyError::NotFound`. Only the last segment is passed to
/// [`ProblemDetailsVariantInfo::get_variant_info`]. If `T` does not know the
/// variant, a `tracing::warn!` is emitted and the operation is left unchanged.
///
/// The new response is documented under both `application/problem+json` and
/// `application/json`, with the schema from [`problem_details_schema`] and an
/// illustrative example body. If the operation already has an inline response
/// for the same status code, only its description is extended: one `- ` line
/// per additional description, and descriptions already listed are not repeated.
#[cfg(feature = "openapi")]
pub fn register_error_response_by_variant<T>(
    _ctx: &mut aide::generate::GenContext,
    operation: &mut Operation,
    variant_path: &str,
) where
    T: ProblemDetailsVariantInfo,
{
    // `rsplit` always yields at least one segment, so the fallback is never used.
    let variant_name = variant_path.rsplit("::").next().unwrap_or(variant_path);

    let Some((status_code, description, _schema_opt)) = T::get_variant_info(variant_name) else {
        tracing::warn!(
            "Variant '{}' not found in error type '{}' when registering OpenAPI responses",
            variant_name,
            std::any::type_name::<T>()
        );
        return;
    };

    let responses = operation.responses.get_or_insert_with(Default::default);
    let status_code_key = StatusCode::Code(status_code);

    match responses.responses.get_mut(&status_code_key) {
        // Several variants can share one status code; list all of their
        // descriptions on the single response instead of overwriting it.
        Some(ReferenceOr::Item(existing)) => {
            if let Some(merged) = merge_response_description(&existing.description, &description)
            {
                existing.description = merged;
            }
        }
        // A `$ref` response is defined elsewhere; leave it untouched.
        Some(_) => {}
        None => {
            // `variant_name` has no `::` left (it was split off above), so
            // lowercasing alone produces the example's type suffix.
            let example = serde_json::json!({
                "type": format!("about:blank/{}", variant_name.to_lowercase()),
                "title": format!("{} Error", variant_name),
                "status": status_code,
                "detail": format!("{} occurred", variant_name)
            });
            let media_type = MediaType {
                schema: Some(SchemaObject {
                    json_schema: problem_details_schema(),
                    example: Some(example),
                    external_docs: None,
                }),
                ..Default::default()
            };

            let mut content = indexmap::IndexMap::new();
            content.insert("application/problem+json".to_string(), media_type.clone());
            // Also advertised as plain JSON for backward compatibility.
            content.insert("application/json".to_string(), media_type);

            let response = Response {
                description,
                content,
                ..Default::default()
            };
            responses
                .responses
                .insert(status_code_key, ReferenceOr::Item(response));
        }
    }
}

/// Append `addition` to an accumulated response description of the form
/// `first\n- second\n- third`.
///
/// Returns `None` when `addition` already equals the whole description or one
/// of its entries, so repeated or shared descriptions are not listed twice.
#[cfg(any(test, feature = "openapi"))]
fn merge_response_description(existing: &str, addition: &str) -> Option<String> {
    if existing == addition {
        return None;
    }
    let mut entries = existing.split('\n');
    // The first entry carries no "- " bullet; later ones always do.
    let first = entries.next();
    let already_listed =
        first == Some(addition) || entries.any(|line| line.strip_prefix("- ") == Some(addition));
    (!already_listed).then(|| format!("{existing}\n- {addition}"))
}
99
// NOTE(review): the `///` comments below are read by the `JsonSchema` derive
// (schemars) and become title/description text in the generated schema, so
// editing them changes OpenAPI output; the plain `//` notes do not.
// RFC 7807 has been obsoleted by RFC 9457, which keeps the same members.
/// RFC 7807 Problem Details for HTTP APIs
///
/// This struct represents a standardized error response format as defined in RFC 7807.
/// It provides a consistent way to communicate error information in HTTP APIs.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ProblemDetails {
    /// A URI reference that identifies the problem type
    // Renamed because the JSON member is `type`, a Rust keyword.
    #[serde(rename = "type")]
    pub problem_type: String,

    /// A short, human-readable summary of the problem type
    pub title: String,

    /// The HTTP status code generated by the origin server
    // Plain number; `status_code()` maps it to an HTTP status (500 fallback).
    pub status: u16,

    /// A human-readable explanation specific to this occurrence of the problem
    // Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,

    /// A URI reference that identifies the specific occurrence of the problem
    // Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub instance: Option<String>,

    /// Additional problem-specific extension fields
    // Flattened: each entry is emitted as its own top-level JSON member, and
    // serde collects unrecognized members here when deserializing.
    // NOTE(review): an extension named like a standard member (e.g. "type")
    // would be written as a duplicate key; confirm callers avoid that.
    #[serde(flatten)]
    pub extensions: HashMap<String, serde_json::Value>,
}
128
129impl ProblemDetails {
130    /// Create a new ProblemDetails with required fields
131    pub fn new(problem_type: impl Into<String>, title: impl Into<String>, status: u16) -> Self {
132        Self {
133            problem_type: problem_type.into(),
134            title: title.into(),
135            status,
136            detail: None,
137            instance: None,
138            extensions: HashMap::new(),
139        }
140    }
141
142    /// Set the detail field
143    pub fn with_detail(mut self, detail: impl Into<String>) -> Self {
144        self.detail = Some(detail.into());
145        self
146    }
147
148    /// Set the instance field
149    pub fn with_instance(mut self, instance: impl Into<String>) -> Self {
150        self.instance = Some(instance.into());
151        self
152    }
153
154    /// Add an extension field
155    pub fn with_extension(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
156        self.extensions.insert(key.into(), value);
157        self
158    }
159
160    /// Create a validation error problem
161    pub fn validation_error(detail: impl Into<String>) -> Self {
162        Self::new("about:blank", "Validation Error", 400).with_detail(detail)
163    }
164
165    /// Create an authentication error problem
166    pub fn authentication_error() -> Self {
167        Self::new("about:blank", "Authentication Required", 401)
168            .with_detail("Authentication credentials are required to access this resource")
169    }
170
171    /// Create an authorization error problem
172    pub fn authorization_error() -> Self {
173        Self::new("about:blank", "Insufficient Permissions", 403)
174            .with_detail("You don't have permission to access this resource")
175    }
176
177    /// Create a not found error problem
178    pub fn not_found(resource: impl Into<String>) -> Self {
179        Self::new("about:blank", "Resource Not Found", 404)
180            .with_detail(format!("The requested {} was not found", resource.into()))
181    }
182
183    /// Create an internal server error problem
184    pub fn internal_server_error() -> Self {
185        Self::new("about:blank", "Internal Server Error", 500)
186            .with_detail("An unexpected error occurred while processing your request")
187    }
188
189    /// Create a service unavailable error problem
190    pub fn service_unavailable() -> Self {
191        Self::new("about:blank", "Service Unavailable", 503)
192            .with_detail("The service is temporarily unavailable")
193    }
194
195    /// Create a custom problem with explicit URI
196    pub fn custom_problem(
197        problem_type: impl Into<String>,
198        title: impl Into<String>,
199        status: u16,
200    ) -> Self {
201        Self::new(problem_type, title, status)
202    }
203}
204
205impl IntoResponse for ProblemDetails {
206    fn into_response(mut self) -> axum::response::Response {
207        let status = axum::http::StatusCode::from_u16(self.status)
208            .unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR);
209
210        // Try to get the current request URI from task-local storage
211        if self.instance.is_none() {
212            if let Some(uri) = get_current_request_uri() {
213                self.instance = Some(uri);
214            }
215        }
216
217        // Set the correct Content-Type for Problem Details
218        (
219            status,
220            [("content-type", "application/problem+json")],
221            axum::Json(self),
222        )
223            .into_response()
224    }
225}
226
227// Task-local storage for current request URI
228tokio::task_local! {
229    static CURRENT_REQUEST_URI: String;
230}
231
232/// Get the current request URI from task-local storage
233fn get_current_request_uri() -> Option<String> {
234    CURRENT_REQUEST_URI.try_with(|uri| uri.clone()).ok()
235}
236
237/// Set the current request URI in task-local storage
238pub fn set_current_request_uri(uri: String) {
239    CURRENT_REQUEST_URI.scope(uri, async {
240        // This will be available for the duration of the request
241    });
242}
243
244/// Middleware to capture request URI for Problem Details
245pub async fn capture_request_uri_middleware(
246    req: axum::http::Request<axum::body::Body>,
247    next: axum::middleware::Next,
248) -> axum::response::Response {
249    let uri = req.uri().to_string();
250
251    // Run the rest of the request handling with the URI in task-local storage
252    CURRENT_REQUEST_URI
253        .scope(uri, async move { next.run(req).await })
254        .await
255}
256
257/// Get the HTTP status code from ProblemDetails
258impl ProblemDetails {
259    pub fn status_code(&self) -> axum::http::StatusCode {
260        axum::http::StatusCode::from_u16(self.status)
261            .unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR)
262    }
263}
264
// Unit tests for construction, HTTP rendering, wire format and request-URI capture.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_problem_details_creation() {
        let problem = ProblemDetails::new("https://example.com/problems/test", "Test Problem", 400)
            .with_detail("This is a test problem")
            .with_instance("/test/123")
            .with_extension("code", serde_json::Value::String("TEST_001".to_string()));

        assert_eq!(problem.problem_type, "https://example.com/problems/test");
        assert_eq!(problem.title, "Test Problem");
        assert_eq!(problem.status, 400);
        assert_eq!(problem.detail, Some("This is a test problem".to_string()));
        assert_eq!(problem.instance, Some("/test/123".to_string()));
        assert_eq!(
            problem.extensions.get("code"),
            Some(&serde_json::Value::String("TEST_001".to_string()))
        );
    }

    #[test]
    fn test_validation_error() {
        // Test with default about:blank
        let problem = ProblemDetails::validation_error("Name is required");
        assert_eq!(problem.status, 400);
        assert_eq!(problem.title, "Validation Error");
        assert_eq!(problem.problem_type, "about:blank");
    }

    #[test]
    fn test_into_response() {
        let problem = ProblemDetails::not_found("user");
        let response = problem.into_response();

        assert_eq!(response.status(), axum::http::StatusCode::NOT_FOUND);
    }

    #[test]
    fn test_into_response_uses_problem_json_content_type() {
        // `Json` alone would send `application/json`; the explicit header must win.
        let response = ProblemDetails::internal_server_error().into_response();
        assert_eq!(
            response
                .headers()
                .get(axum::http::header::CONTENT_TYPE)
                .expect("content-type header is always set"),
            "application/problem+json"
        );
    }

    #[test]
    fn test_serialization_skips_unset_members_and_flattens_extensions() {
        let problem = ProblemDetails::new("about:blank", "Conflict", 409)
            .with_extension("retry", serde_json::json!(true));
        let value = serde_json::to_value(&problem).expect("ProblemDetails serializes");
        assert_eq!(
            value,
            serde_json::json!({
                "type": "about:blank",
                "title": "Conflict",
                "status": 409,
                "retry": true
            })
        );
    }

    #[test]
    fn test_deserialization_collects_unknown_members_into_extensions() {
        let problem: ProblemDetails = serde_json::from_value(serde_json::json!({
            "type": "https://example.com/out-of-credit",
            "title": "Out of credit",
            "status": 403,
            "balance": 30
        }))
        .expect("valid problem document");
        assert_eq!(problem.problem_type, "https://example.com/out-of-credit");
        assert_eq!(problem.detail, None);
        assert_eq!(problem.instance, None);
        assert_eq!(
            problem.extensions.get("balance"),
            Some(&serde_json::json!(30))
        );
        assert_eq!(problem.extensions.len(), 1);
    }

    #[test]
    fn test_status_code() {
        let problem = ProblemDetails::validation_error("Test error");
        assert_eq!(problem.status_code(), axum::http::StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_request_uri_absent_outside_scope() {
        assert_eq!(get_current_request_uri(), None);
    }

    #[tokio::test]
    async fn test_automatic_uri_capture() {
        // Test that URI is captured in task-local storage
        let test_uri = "/test/path".to_string();

        CURRENT_REQUEST_URI
            .scope(test_uri.clone(), async {
                let uri = get_current_request_uri();
                assert_eq!(uri, Some(test_uri));
            })
            .await;
    }
}