1use anyhow::Result;
7use bytes::Bytes;
8use http::{Response, StatusCode};
9use http_body_util::Full;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::Arc;
13use tracing::{debug, warn};
14
15use sentinel_config::{ErrorFormat, ErrorPage, ErrorPageConfig, ServiceType};
16
17pub struct ErrorHandler {
19 service_type: ServiceType,
21 config: Option<ErrorPageConfig>,
23 templates: Arc<HashMap<u16, String>>,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct ErrorResponse {
30 pub status: u16,
32 pub title: String,
34 pub message: String,
36 pub request_id: String,
38 pub timestamp: i64,
40 #[serde(skip_serializing_if = "Option::is_none")]
42 pub details: Option<serde_json::Value>,
43 #[serde(skip_serializing_if = "Option::is_none")]
45 pub stack_trace: Option<Vec<String>>,
46}
47
48impl ErrorHandler {
49 pub fn new(service_type: ServiceType, config: Option<ErrorPageConfig>) -> Self {
51 let templates = if let Some(ref cfg) = config {
52 Self::load_templates(cfg)
53 } else {
54 Arc::new(HashMap::new())
55 };
56
57 Self {
58 service_type,
59 config,
60 templates,
61 }
62 }
63
64 pub fn generate_response(
66 &self,
67 status: StatusCode,
68 message: Option<String>,
69 request_id: &str,
70 details: Option<serde_json::Value>,
71 ) -> Result<Response<Full<Bytes>>> {
72 let status_code = status.as_u16();
73 let error_data = ErrorResponse {
74 status: status_code,
75 title: Self::status_title(status),
76 message: message.unwrap_or_else(|| Self::default_message(status)),
77 request_id: request_id.to_string(),
78 timestamp: chrono::Utc::now().timestamp(),
79 details,
80 stack_trace: self.get_stack_trace(),
81 };
82
83 let format = self.determine_format(status_code);
85
86 let (body, content_type) = match format {
88 ErrorFormat::Json => self.generate_json_response(&error_data)?,
89 ErrorFormat::Html => self.generate_html_response(&error_data, status_code)?,
90 ErrorFormat::Text => self.generate_text_response(&error_data)?,
91 ErrorFormat::Xml => self.generate_xml_response(&error_data)?,
92 };
93
94 let mut response = Response::builder()
96 .status(status)
97 .header("Content-Type", content_type)
98 .header("X-Request-Id", request_id);
99
100 if let Some(page) = self.get_error_page(status_code) {
102 for (key, value) in &page.headers {
103 response = response.header(key, value);
104 }
105 }
106
107 Ok(response.body(Full::new(Bytes::from(body)))?)
108 }
109
110 fn determine_format(&self, status_code: u16) -> ErrorFormat {
112 if let Some(page) = self.get_error_page(status_code) {
114 return page.format;
115 }
116
117 if let Some(ref config) = self.config {
119 return config.default_format;
120 }
121
122 match self.service_type {
124 ServiceType::Api | ServiceType::Builtin => ErrorFormat::Json,
125 ServiceType::Web | ServiceType::Static => ErrorFormat::Html,
126 }
127 }
128
129 fn get_error_page(&self, status_code: u16) -> Option<&ErrorPage> {
131 self.config.as_ref().and_then(|c| c.pages.get(&status_code))
132 }
133
134 fn generate_json_response(&self, error: &ErrorResponse) -> Result<(Vec<u8>, &'static str)> {
136 let json = serde_json::to_vec_pretty(error)?;
137 Ok((json, "application/json; charset=utf-8"))
138 }
139
140 fn generate_html_response(
142 &self,
143 error: &ErrorResponse,
144 status_code: u16,
145 ) -> Result<(Vec<u8>, &'static str)> {
146 if let Some(template) = self.templates.get(&status_code) {
148 let html = self.render_template(template, error)?;
149 return Ok((html.into_bytes(), "text/html; charset=utf-8"));
150 }
151
152 let html = self.generate_default_html(error);
154 Ok((html.into_bytes(), "text/html; charset=utf-8"))
155 }
156
157 fn generate_text_response(&self, error: &ErrorResponse) -> Result<(Vec<u8>, &'static str)> {
159 let text = format!(
160 "{} {}\n\n{}\n\nRequest ID: {}\nTimestamp: {}",
161 error.status, error.title, error.message, error.request_id, error.timestamp
162 );
163 Ok((text.into_bytes(), "text/plain; charset=utf-8"))
164 }
165
166 fn generate_xml_response(&self, error: &ErrorResponse) -> Result<(Vec<u8>, &'static str)> {
168 let xml = format!(
169 r#"<?xml version="1.0" encoding="UTF-8"?>
170<error>
171 <status>{}</status>
172 <title>{}</title>
173 <message>{}</message>
174 <requestId>{}</requestId>
175 <timestamp>{}</timestamp>
176</error>"#,
177 error.status,
178 Self::escape_xml(&error.title),
179 Self::escape_xml(&error.message),
180 Self::escape_xml(&error.request_id),
181 error.timestamp
182 );
183 Ok((xml.into_bytes(), "application/xml; charset=utf-8"))
184 }
185
186 fn generate_default_html(&self, error: &ErrorResponse) -> String {
188 format!(
189 r#"<!DOCTYPE html>
190<html lang="en">
191<head>
192 <meta charset="UTF-8">
193 <meta name="viewport" content="width=device-width, initial-scale=1.0">
194 <title>{} {}</title>
195 <style>
196 body {{
197 font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
198 background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
199 color: #333;
200 display: flex;
201 align-items: center;
202 justify-content: center;
203 min-height: 100vh;
204 margin: 0;
205 padding: 20px;
206 }}
207 .error-container {{
208 background: white;
209 border-radius: 12px;
210 box-shadow: 0 20px 60px rgba(0,0,0,0.3);
211 padding: 40px;
212 max-width: 600px;
213 width: 100%;
214 text-align: center;
215 }}
216 h1 {{
217 color: #764ba2;
218 font-size: 72px;
219 margin: 0;
220 font-weight: bold;
221 }}
222 h2 {{
223 color: #666;
224 font-size: 24px;
225 margin: 10px 0;
226 font-weight: normal;
227 }}
228 p {{
229 color: #777;
230 font-size: 16px;
231 line-height: 1.6;
232 margin: 20px 0;
233 }}
234 .request-id {{
235 background: #f5f5f5;
236 border-radius: 4px;
237 padding: 8px 12px;
238 font-family: 'Courier New', monospace;
239 font-size: 12px;
240 color: #999;
241 margin-top: 30px;
242 }}
243 .back-link {{
244 display: inline-block;
245 margin-top: 20px;
246 color: #667eea;
247 text-decoration: none;
248 font-weight: 500;
249 transition: color 0.3s;
250 }}
251 .back-link:hover {{
252 color: #764ba2;
253 }}
254 </style>
255</head>
256<body>
257 <div class="error-container">
258 <h1>{}</h1>
259 <h2>{}</h2>
260 <p>{}</p>
261 <div class="request-id">Request ID: {}</div>
262 <a href="/" class="back-link">← Back to Home</a>
263 </div>
264</body>
265</html>"#,
266 error.status, error.title, error.status, error.title, error.message, error.request_id
267 )
268 }
269
270 fn load_templates(config: &ErrorPageConfig) -> Arc<HashMap<u16, String>> {
272 let mut templates = HashMap::new();
273
274 if let Some(ref template_dir) = config.template_dir {
275 for (status_code, page) in &config.pages {
276 if let Some(ref template_path) = page.template {
277 let full_path = if template_path.is_absolute() {
278 template_path.clone()
279 } else {
280 template_dir.join(template_path)
281 };
282
283 match std::fs::read_to_string(&full_path) {
284 Ok(content) => {
285 templates.insert(*status_code, content);
286 debug!(
287 "Loaded error template for status {}: {:?}",
288 status_code, full_path
289 );
290 }
291 Err(e) => {
292 warn!("Failed to load error template {:?}: {}", full_path, e);
293 }
294 }
295 }
296 }
297 }
298
299 Arc::new(templates)
300 }
301
302 fn render_template(&self, template: &str, error: &ErrorResponse) -> Result<String> {
304 let rendered = template
306 .replace("{{status}}", &error.status.to_string())
307 .replace("{{title}}", &error.title)
308 .replace("{{message}}", &error.message)
309 .replace("{{request_id}}", &error.request_id)
310 .replace("{{timestamp}}", &error.timestamp.to_string());
311
312 Ok(rendered)
313 }
314
315 fn get_stack_trace(&self) -> Option<Vec<String>> {
317 if let Some(ref config) = self.config {
318 if config.include_stack_trace {
319 return None;
322 }
323 }
324 None
325 }
326
327 fn status_title(status: StatusCode) -> String {
329 status
330 .canonical_reason()
331 .unwrap_or("Unknown Error")
332 .to_string()
333 }
334
335 fn default_message(status: StatusCode) -> String {
337 match status {
338 StatusCode::BAD_REQUEST => {
339 "The request could not be understood by the server.".to_string()
340 }
341 StatusCode::UNAUTHORIZED => {
342 "You are not authorized to access this resource.".to_string()
343 }
344 StatusCode::FORBIDDEN => "Access to this resource is forbidden.".to_string(),
345 StatusCode::NOT_FOUND => "The requested resource could not be found.".to_string(),
346 StatusCode::METHOD_NOT_ALLOWED => {
347 "The requested method is not allowed for this resource.".to_string()
348 }
349 StatusCode::REQUEST_TIMEOUT => "The request took too long to process.".to_string(),
350 StatusCode::PAYLOAD_TOO_LARGE => "The request payload is too large.".to_string(),
351 StatusCode::TOO_MANY_REQUESTS => {
352 "Too many requests. Please try again later.".to_string()
353 }
354 StatusCode::INTERNAL_SERVER_ERROR => {
355 "An internal server error occurred. Please try again later.".to_string()
356 }
357 StatusCode::BAD_GATEWAY => {
358 "The gateway received an invalid response from the upstream server.".to_string()
359 }
360 StatusCode::SERVICE_UNAVAILABLE => {
361 "The service is temporarily unavailable. Please try again later.".to_string()
362 }
363 StatusCode::GATEWAY_TIMEOUT => {
364 "The gateway timed out waiting for a response from the upstream server.".to_string()
365 }
366 _ => format!("An error occurred (HTTP {})", status.as_u16()),
367 }
368 }
369
370 fn escape_xml(s: &str) -> String {
372 s.replace('&', "&")
373 .replace('<', "<")
374 .replace('>', ">")
375 .replace('"', """)
376 .replace('\'', "'")
377 }
378
379 pub fn reload_templates(&mut self) {
381 if let Some(ref config) = self.config {
382 self.templates = Self::load_templates(config);
383 debug!("Reloaded error templates");
384 }
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the `Content-Type` header of `response`, panicking if absent.
    fn content_type(response: &Response<Full<Bytes>>) -> &str {
        response
            .headers()
            .get("Content-Type")
            .expect("response must carry a Content-Type header")
            .to_str()
            .expect("Content-Type header must be visible ASCII")
    }

    #[test]
    fn test_error_handler_json() {
        let response = ErrorHandler::new(ServiceType::Api, None)
            .generate_response(
                StatusCode::NOT_FOUND,
                Some("Resource not found".to_string()),
                "test-123",
                None,
            )
            .expect("JSON error response should build");

        assert_eq!(response.status(), StatusCode::NOT_FOUND);
        assert_eq!(content_type(&response), "application/json; charset=utf-8");
    }

    #[test]
    fn test_error_handler_html() {
        let response = ErrorHandler::new(ServiceType::Web, None)
            .generate_response(StatusCode::INTERNAL_SERVER_ERROR, None, "test-456", None)
            .expect("HTML error response should build");

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(content_type(&response), "text/html; charset=utf-8");
    }

    #[test]
    fn test_custom_error_format() {
        // A page-level format must override the config-wide default (XML).
        let mut pages = HashMap::new();
        pages.insert(
            404,
            ErrorPage {
                format: ErrorFormat::Text,
                template: None,
                message: Some("Custom 404 message".to_string()),
                headers: HashMap::new(),
            },
        );

        let config = ErrorPageConfig {
            pages,
            default_format: ErrorFormat::Xml,
            include_stack_trace: false,
            template_dir: None,
        };

        let response = ErrorHandler::new(ServiceType::Web, Some(config))
            .generate_response(StatusCode::NOT_FOUND, None, "test-789", None)
            .expect("custom-format error response should build");

        assert_eq!(content_type(&response), "text/plain; charset=utf-8");
    }
}