1use schemars::JsonSchema;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use axum::response::IntoResponse;
5
6#[cfg(feature = "openapi")]
8use aide::openapi::{MediaType, Operation, ReferenceOr, Response, SchemaObject, StatusCode};
9
/// Per-variant metadata used to document an error enum's variants as
/// OpenAPI problem responses.
///
/// Implemented for error types passed as `T` to
/// `register_error_response_by_variant`.
#[cfg(feature = "openapi")]
pub trait ProblemDetailsVariantInfo {
    /// Looks up a variant by name and returns
    /// `(status_code, description, optional_schema)`, or `None` if the
    /// implementing type has no variant with that name.
    ///
    /// `register_error_response_by_variant` passes only the last `::`
    /// segment of the path it was given (e.g. `NotFound` for
    /// `MyError::NotFound`).
    fn get_variant_info(
        variant_name: &str,
    ) -> Option<(u16, String, Option<schemars::Schema>)>;
}
15
/// Produces the JSON Schema for [`ProblemDetails`] from a fresh,
/// default-configured schema generator.
///
/// Used as the body schema for every error response registered by
/// `register_error_response_by_variant`.
#[cfg(feature = "openapi")]
pub fn problem_details_schema() -> schemars::Schema {
    let mut generator = schemars::SchemaGenerator::default();
    <crate::problem_details::ProblemDetails as JsonSchema>::json_schema(&mut generator)
}
22
/// Returns the merged description for a response that already carries
/// `existing`, or `None` when `addition` is already listed in it.
///
/// Merged descriptions have the shape `"first\n- second\n- third"`. Each line,
/// with any leading `"- "` removed, counts as one listed entry. Registering the
/// same variant again therefore never appends a duplicate, even after other
/// variants have been merged in.
#[cfg_attr(not(feature = "openapi"), allow(dead_code))]
fn merge_response_description(existing: &str, addition: &str) -> Option<String> {
    let already_listed = existing == addition
        || existing
            .lines()
            .map(|line| line.strip_prefix("- ").unwrap_or(line))
            .any(|entry| entry == addition);

    if already_listed {
        None
    } else {
        Some(format!("{}\n- {}", existing, addition))
    }
}

/// Documents the error response for one variant of `T` on `operation`.
///
/// `variant_path` may be a bare variant name (`NotFound`) or a path
/// (`MyError::NotFound`). Only the last `::` segment, with surrounding
/// whitespace trimmed, is looked up via
/// [`ProblemDetailsVariantInfo::get_variant_info`]. Unknown variants are
/// reported with `tracing::warn!` and otherwise ignored.
///
/// If no response exists yet for the variant's status code, one is inserted.
/// It documents a [`ProblemDetails`] body with a synthetic example under both
/// `application/problem+json` and `application/json`.
///
/// If an inline response for that status code already exists, its
/// description is extended with this variant's description (at most once).
/// A `$ref` response is left untouched.
#[cfg(feature = "openapi")]
pub fn register_error_response_by_variant<T>(
    _ctx: &mut aide::generate::GenContext,
    operation: &mut Operation,
    variant_path: &str,
) where
    T: ProblemDetailsVariantInfo,
{
    // `rsplit` always yields at least one segment. Trimming tolerates paths
    // written with spaces around `::` (e.g. `MyError :: NotFound`).
    let variant_name = variant_path
        .rsplit("::")
        .next()
        .unwrap_or(variant_path)
        .trim();

    // NOTE(review): the per-variant schema is ignored and every variant is
    // documented with the generic ProblemDetails schema. Confirm whether it
    // was meant to override `problem_details_schema()`.
    let Some((status_code, description, _schema_opt)) = T::get_variant_info(variant_name) else {
        tracing::warn!(
            "Variant '{}' not found in error type '{}' when registering OpenAPI responses",
            variant_name,
            std::any::type_name::<T>()
        );
        return;
    };

    let responses = operation.responses.get_or_insert_with(Default::default);
    let status_code_key = StatusCode::Code(status_code);

    match responses.responses.get_mut(&status_code_key) {
        Some(ReferenceOr::Item(existing_response)) => {
            // Several variants can share a status code: list them all in one
            // description instead of letting the last registration win.
            if let Some(merged) =
                merge_response_description(&existing_response.description, &description)
            {
                existing_response.description = merged;
            }
        }
        Some(_) => {
            // A `$ref` response cannot be merged into; leave it as-is.
        }
        None => {
            let response = variant_problem_response(variant_name, status_code, description);
            responses
                .responses
                .insert(status_code_key, ReferenceOr::Item(response));
        }
    }
}

/// Builds the inline OpenAPI response documenting `variant_name`.
///
/// The response uses the [`ProblemDetails`] schema and a synthetic example
/// body, exposed under both `application/problem+json` and `application/json`.
#[cfg(feature = "openapi")]
fn variant_problem_response(variant_name: &str, status_code: u16, description: String) -> Response {
    let problem_type = format!("about:blank/{}", variant_name.to_lowercase());
    let example = serde_json::json!({
        "type": problem_type,
        "title": format!("{} Error", variant_name),
        "status": status_code,
        "detail": format!("{} occurred", variant_name)
    });

    let media_type = MediaType {
        schema: Some(SchemaObject {
            json_schema: problem_details_schema(),
            example: Some(example),
            external_docs: None,
        }),
        ..Default::default()
    };

    let mut content = indexmap::IndexMap::new();
    content.insert("application/problem+json".to_string(), media_type.clone());
    content.insert("application/json".to_string(), media_type);

    Response {
        description,
        content,
        ..Default::default()
    }
}
88
89#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
94pub struct ProblemDetails {
95 #[serde(rename = "type")]
97 pub problem_type: String,
98
99 pub title: String,
101
102 pub status: u16,
104
105 #[serde(skip_serializing_if = "Option::is_none")]
107 pub detail: Option<String>,
108
109 #[serde(skip_serializing_if = "Option::is_none")]
111 pub instance: Option<String>,
112
113 #[serde(flatten)]
115 pub extensions: HashMap<String, serde_json::Value>,
116}
117
118impl ProblemDetails {
119 pub fn new(problem_type: impl Into<String>, title: impl Into<String>, status: u16) -> Self {
121 Self {
122 problem_type: problem_type.into(),
123 title: title.into(),
124 status,
125 detail: None,
126 instance: None,
127 extensions: HashMap::new(),
128 }
129 }
130
131 pub fn with_detail(mut self, detail: impl Into<String>) -> Self {
133 self.detail = Some(detail.into());
134 self
135 }
136
137 pub fn with_instance(mut self, instance: impl Into<String>) -> Self {
139 self.instance = Some(instance.into());
140 self
141 }
142
143 pub fn with_extension(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
145 self.extensions.insert(key.into(), value);
146 self
147 }
148
149 pub fn validation_error(detail: impl Into<String>) -> Self {
151 Self::new(
152 "about:blank",
153 "Validation Error",
154 400,
155 )
156 .with_detail(detail)
157 }
158
159 pub fn authentication_error() -> Self {
161 Self::new(
162 "about:blank",
163 "Authentication Required",
164 401,
165 )
166 .with_detail("Authentication credentials are required to access this resource")
167 }
168
169 pub fn authorization_error() -> Self {
171 Self::new(
172 "about:blank",
173 "Insufficient Permissions",
174 403,
175 )
176 .with_detail("You don't have permission to access this resource")
177 }
178
179 pub fn not_found(resource: impl Into<String>) -> Self {
181 Self::new(
182 "about:blank",
183 "Resource Not Found",
184 404,
185 )
186 .with_detail(format!("The requested {} was not found", resource.into()))
187 }
188
189 pub fn internal_server_error() -> Self {
191 Self::new(
192 "about:blank",
193 "Internal Server Error",
194 500,
195 )
196 .with_detail("An unexpected error occurred while processing your request")
197 }
198
199 pub fn service_unavailable() -> Self {
201 Self::new(
202 "about:blank",
203 "Service Unavailable",
204 503,
205 )
206 .with_detail("The service is temporarily unavailable")
207 }
208
209 pub fn custom_problem(problem_type: impl Into<String>, title: impl Into<String>, status: u16) -> Self {
211 Self::new(
212 problem_type,
213 title,
214 status,
215 )
216 }
217}
218
219impl IntoResponse for ProblemDetails {
220 fn into_response(mut self) -> axum::response::Response {
221 let status = axum::http::StatusCode::from_u16(self.status)
222 .unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR);
223
224 if self.instance.is_none() {
226 if let Some(uri) = get_current_request_uri() {
227 self.instance = Some(uri);
228 }
229 }
230
231 (
233 status,
234 [("content-type", "application/problem+json")],
235 axum::Json(self),
236 ).into_response()
237 }
238}
239
240tokio::task_local! {
242 static CURRENT_REQUEST_URI: String;
243}
244
245fn get_current_request_uri() -> Option<String> {
247 CURRENT_REQUEST_URI.try_with(|uri| uri.clone()).ok()
248}
249
250pub fn set_current_request_uri(uri: String) {
252 CURRENT_REQUEST_URI.scope(uri, async {
253 });
255}
256
257pub async fn capture_request_uri_middleware(
259 req: axum::http::Request<axum::body::Body>,
260 next: axum::middleware::Next,
261) -> axum::response::Response {
262 let uri = req.uri().to_string();
263
264 CURRENT_REQUEST_URI.scope(uri, async move {
266 next.run(req).await
267 }).await
268}
269
270impl ProblemDetails {
272 pub fn status_code(&self) -> axum::http::StatusCode {
273 axum::http::StatusCode::from_u16(self.status)
274 .unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR)
275 }
276}
277
278pub trait ToProblemDetails {
280 fn to_problem_details(&self) -> ProblemDetails;
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_problem_details_creation() {
        let problem = ProblemDetails::new("https://example.com/problems/test", "Test Problem", 400)
            .with_detail("This is a test problem")
            .with_instance("/test/123")
            .with_extension("code", serde_json::Value::String("TEST_001".to_string()));

        assert_eq!(problem.problem_type, "https://example.com/problems/test");
        assert_eq!(problem.title, "Test Problem");
        assert_eq!(problem.status, 400);
        assert_eq!(problem.detail, Some("This is a test problem".to_string()));
        assert_eq!(problem.instance, Some("/test/123".to_string()));
        assert_eq!(
            problem.extensions.get("code"),
            Some(&serde_json::Value::String("TEST_001".to_string()))
        );
    }

    #[test]
    fn test_validation_error() {
        let problem = ProblemDetails::validation_error("Name is required");
        assert_eq!(problem.status, 400);
        assert_eq!(problem.title, "Validation Error");
        assert_eq!(problem.problem_type, "about:blank");
    }

    #[test]
    fn test_into_response() {
        let problem = ProblemDetails::not_found("user");
        let response = problem.into_response();

        assert_eq!(response.status(), axum::http::StatusCode::NOT_FOUND);
    }

    /// The media type is part of the RFC 7807/9457 contract, so pin it.
    #[test]
    fn test_into_response_uses_problem_json_content_type() {
        let response = ProblemDetails::validation_error("bad input").into_response();

        assert_eq!(
            response.headers()[axum::http::header::CONTENT_TYPE],
            "application/problem+json"
        );
    }

    #[test]
    fn test_serialization_uses_standard_member_names() {
        let problem = ProblemDetails::new("about:blank", "Teapot", 418)
            .with_extension("code", json!("TEA_001"));
        let value = serde_json::to_value(&problem).expect("ProblemDetails always serializes");

        assert_eq!(value["type"], json!("about:blank"));
        assert_eq!(value["title"], json!("Teapot"));
        assert_eq!(value["status"], json!(418));
        // Extensions are flattened into the top-level object.
        assert_eq!(value["code"], json!("TEA_001"));
        // The Rust field name must not leak, and `None` members are omitted.
        assert!(value.get("problem_type").is_none());
        assert!(value.get("detail").is_none());
        assert!(value.get("instance").is_none());
    }

    #[test]
    fn test_deserialization_collects_unknown_members_as_extensions() {
        let raw = json!({
            "type": "https://example.com/problems/x",
            "title": "X",
            "status": 400,
            "trace_id": "abc"
        });
        let problem: ProblemDetails =
            serde_json::from_value(raw).expect("valid problem document");

        assert_eq!(problem.problem_type, "https://example.com/problems/x");
        assert_eq!(problem.detail, None);
        assert_eq!(problem.extensions.len(), 1);
        assert_eq!(problem.extensions.get("trace_id"), Some(&json!("abc")));
    }

    #[test]
    fn test_status_code() {
        let problem = ProblemDetails::validation_error("Test error");
        assert_eq!(problem.status_code(), axum::http::StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_no_uri_outside_scope() {
        assert_eq!(get_current_request_uri(), None);
    }

    #[tokio::test]
    async fn test_automatic_uri_capture() {
        let test_uri = "/test/path".to_string();

        CURRENT_REQUEST_URI
            .scope(test_uri.clone(), async {
                let uri = get_current_request_uri();
                assert_eq!(uri, Some(test_uri));
            })
            .await;
    }
}