1use anyhow::Result;
7use bytes::Bytes;
8use http::{Response, StatusCode};
9use http_body_util::Full;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::Arc;
13use tracing::{debug, warn};
14
15use sentinel_config::{ErrorFormat, ErrorPage, ErrorPageConfig, ServiceType};
16
17pub struct ErrorHandler {
19 service_type: ServiceType,
21 config: Option<ErrorPageConfig>,
23 templates: Arc<HashMap<u16, String>>,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct ErrorResponse {
30 pub status: u16,
32 pub title: String,
34 pub message: String,
36 pub request_id: String,
38 pub timestamp: i64,
40 #[serde(skip_serializing_if = "Option::is_none")]
42 pub details: Option<serde_json::Value>,
43 #[serde(skip_serializing_if = "Option::is_none")]
45 pub stack_trace: Option<Vec<String>>,
46}
47
48impl ErrorHandler {
49 pub fn new(service_type: ServiceType, config: Option<ErrorPageConfig>) -> Self {
51 let templates = if let Some(ref cfg) = config {
52 Self::load_templates(cfg)
53 } else {
54 Arc::new(HashMap::new())
55 };
56
57 Self {
58 service_type,
59 config,
60 templates,
61 }
62 }
63
64 pub fn generate_response(
66 &self,
67 status: StatusCode,
68 message: Option<String>,
69 request_id: &str,
70 details: Option<serde_json::Value>,
71 ) -> Result<Response<Full<Bytes>>> {
72 let status_code = status.as_u16();
73 let error_data = ErrorResponse {
74 status: status_code,
75 title: Self::status_title(status),
76 message: message.unwrap_or_else(|| Self::default_message(status)),
77 request_id: request_id.to_string(),
78 timestamp: chrono::Utc::now().timestamp(),
79 details,
80 stack_trace: self.get_stack_trace(),
81 };
82
83 let format = self.determine_format(status_code);
85
86 let (body, content_type) = match format {
88 ErrorFormat::Json => self.generate_json_response(&error_data)?,
89 ErrorFormat::Html => self.generate_html_response(&error_data, status_code)?,
90 ErrorFormat::Text => self.generate_text_response(&error_data)?,
91 ErrorFormat::Xml => self.generate_xml_response(&error_data)?,
92 };
93
94 let mut response = Response::builder()
96 .status(status)
97 .header("Content-Type", content_type)
98 .header("X-Request-Id", request_id);
99
100 if let Some(page) = self.get_error_page(status_code) {
102 for (key, value) in &page.headers {
103 response = response.header(key, value);
104 }
105 }
106
107 Ok(response.body(Full::new(Bytes::from(body)))?)
108 }
109
110 fn determine_format(&self, status_code: u16) -> ErrorFormat {
112 if let Some(page) = self.get_error_page(status_code) {
114 return page.format;
115 }
116
117 match self.service_type {
119 ServiceType::Api | ServiceType::Builtin => ErrorFormat::Json,
120 ServiceType::Web | ServiceType::Static => self
121 .config
122 .as_ref()
123 .map(|c| c.default_format)
124 .unwrap_or(ErrorFormat::Html),
125 }
126 }
127
128 fn get_error_page(&self, status_code: u16) -> Option<&ErrorPage> {
130 self.config.as_ref().and_then(|c| c.pages.get(&status_code))
131 }
132
133 fn generate_json_response(&self, error: &ErrorResponse) -> Result<(Vec<u8>, &'static str)> {
135 let json = serde_json::to_vec_pretty(error)?;
136 Ok((json, "application/json; charset=utf-8"))
137 }
138
139 fn generate_html_response(
141 &self,
142 error: &ErrorResponse,
143 status_code: u16,
144 ) -> Result<(Vec<u8>, &'static str)> {
145 if let Some(template) = self.templates.get(&status_code) {
147 let html = self.render_template(template, error)?;
148 return Ok((html.into_bytes(), "text/html; charset=utf-8"));
149 }
150
151 let html = self.generate_default_html(error);
153 Ok((html.into_bytes(), "text/html; charset=utf-8"))
154 }
155
156 fn generate_text_response(&self, error: &ErrorResponse) -> Result<(Vec<u8>, &'static str)> {
158 let text = format!(
159 "{} {}\n\n{}\n\nRequest ID: {}\nTimestamp: {}",
160 error.status, error.title, error.message, error.request_id, error.timestamp
161 );
162 Ok((text.into_bytes(), "text/plain; charset=utf-8"))
163 }
164
165 fn generate_xml_response(&self, error: &ErrorResponse) -> Result<(Vec<u8>, &'static str)> {
167 let xml = format!(
168 r#"<?xml version="1.0" encoding="UTF-8"?>
169<error>
170 <status>{}</status>
171 <title>{}</title>
172 <message>{}</message>
173 <requestId>{}</requestId>
174 <timestamp>{}</timestamp>
175</error>"#,
176 error.status,
177 Self::escape_xml(&error.title),
178 Self::escape_xml(&error.message),
179 Self::escape_xml(&error.request_id),
180 error.timestamp
181 );
182 Ok((xml.into_bytes(), "application/xml; charset=utf-8"))
183 }
184
185 fn generate_default_html(&self, error: &ErrorResponse) -> String {
187 format!(
188 r#"<!DOCTYPE html>
189<html lang="en">
190<head>
191 <meta charset="UTF-8">
192 <meta name="viewport" content="width=device-width, initial-scale=1.0">
193 <title>{} {}</title>
194 <style>
195 body {{
196 font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
197 background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
198 color: #333;
199 display: flex;
200 align-items: center;
201 justify-content: center;
202 min-height: 100vh;
203 margin: 0;
204 padding: 20px;
205 }}
206 .error-container {{
207 background: white;
208 border-radius: 12px;
209 box-shadow: 0 20px 60px rgba(0,0,0,0.3);
210 padding: 40px;
211 max-width: 600px;
212 width: 100%;
213 text-align: center;
214 }}
215 h1 {{
216 color: #764ba2;
217 font-size: 72px;
218 margin: 0;
219 font-weight: bold;
220 }}
221 h2 {{
222 color: #666;
223 font-size: 24px;
224 margin: 10px 0;
225 font-weight: normal;
226 }}
227 p {{
228 color: #777;
229 font-size: 16px;
230 line-height: 1.6;
231 margin: 20px 0;
232 }}
233 .request-id {{
234 background: #f5f5f5;
235 border-radius: 4px;
236 padding: 8px 12px;
237 font-family: 'Courier New', monospace;
238 font-size: 12px;
239 color: #999;
240 margin-top: 30px;
241 }}
242 .back-link {{
243 display: inline-block;
244 margin-top: 20px;
245 color: #667eea;
246 text-decoration: none;
247 font-weight: 500;
248 transition: color 0.3s;
249 }}
250 .back-link:hover {{
251 color: #764ba2;
252 }}
253 </style>
254</head>
255<body>
256 <div class="error-container">
257 <h1>{}</h1>
258 <h2>{}</h2>
259 <p>{}</p>
260 <div class="request-id">Request ID: {}</div>
261 <a href="/" class="back-link">← Back to Home</a>
262 </div>
263</body>
264</html>"#,
265 error.status, error.title, error.status, error.title, error.message, error.request_id
266 )
267 }
268
269 fn load_templates(config: &ErrorPageConfig) -> Arc<HashMap<u16, String>> {
271 let mut templates = HashMap::new();
272
273 if let Some(ref template_dir) = config.template_dir {
274 for (status_code, page) in &config.pages {
275 if let Some(ref template_path) = page.template {
276 let full_path = if template_path.is_absolute() {
277 template_path.clone()
278 } else {
279 template_dir.join(template_path)
280 };
281
282 match std::fs::read_to_string(&full_path) {
283 Ok(content) => {
284 templates.insert(*status_code, content);
285 debug!(
286 "Loaded error template for status {}: {:?}",
287 status_code, full_path
288 );
289 }
290 Err(e) => {
291 warn!("Failed to load error template {:?}: {}", full_path, e);
292 }
293 }
294 }
295 }
296 }
297
298 Arc::new(templates)
299 }
300
301 fn render_template(&self, template: &str, error: &ErrorResponse) -> Result<String> {
303 let rendered = template
305 .replace("{{status}}", &error.status.to_string())
306 .replace("{{title}}", &error.title)
307 .replace("{{message}}", &error.message)
308 .replace("{{request_id}}", &error.request_id)
309 .replace("{{timestamp}}", &error.timestamp.to_string());
310
311 Ok(rendered)
312 }
313
314 fn get_stack_trace(&self) -> Option<Vec<String>> {
316 if let Some(ref config) = self.config {
317 if config.include_stack_trace {
318 return None;
321 }
322 }
323 None
324 }
325
326 fn status_title(status: StatusCode) -> String {
328 status
329 .canonical_reason()
330 .unwrap_or("Unknown Error")
331 .to_string()
332 }
333
334 fn default_message(status: StatusCode) -> String {
336 match status {
337 StatusCode::BAD_REQUEST => {
338 "The request could not be understood by the server.".to_string()
339 }
340 StatusCode::UNAUTHORIZED => {
341 "You are not authorized to access this resource.".to_string()
342 }
343 StatusCode::FORBIDDEN => "Access to this resource is forbidden.".to_string(),
344 StatusCode::NOT_FOUND => "The requested resource could not be found.".to_string(),
345 StatusCode::METHOD_NOT_ALLOWED => {
346 "The requested method is not allowed for this resource.".to_string()
347 }
348 StatusCode::REQUEST_TIMEOUT => "The request took too long to process.".to_string(),
349 StatusCode::PAYLOAD_TOO_LARGE => "The request payload is too large.".to_string(),
350 StatusCode::TOO_MANY_REQUESTS => {
351 "Too many requests. Please try again later.".to_string()
352 }
353 StatusCode::INTERNAL_SERVER_ERROR => {
354 "An internal server error occurred. Please try again later.".to_string()
355 }
356 StatusCode::BAD_GATEWAY => {
357 "The gateway received an invalid response from the upstream server.".to_string()
358 }
359 StatusCode::SERVICE_UNAVAILABLE => {
360 "The service is temporarily unavailable. Please try again later.".to_string()
361 }
362 StatusCode::GATEWAY_TIMEOUT => {
363 "The gateway timed out waiting for a response from the upstream server.".to_string()
364 }
365 _ => format!("An error occurred (HTTP {})", status.as_u16()),
366 }
367 }
368
369 fn escape_xml(s: &str) -> String {
371 s.replace('&', "&")
372 .replace('<', "<")
373 .replace('>', ">")
374 .replace('"', """)
375 .replace('\'', "'")
376 }
377
378 pub fn reload_templates(&mut self) {
380 if let Some(ref config) = self.config {
381 self.templates = Self::load_templates(config);
382 debug!("Reloaded error templates");
383 }
384 }
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `response` advertises `expected` as its `Content-Type`.
    #[track_caller]
    fn assert_content_type(response: &Response<Full<Bytes>>, expected: &str) {
        let actual = response
            .headers()
            .get("Content-Type")
            .expect("error responses always set Content-Type");
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_error_handler_json() {
        let response = ErrorHandler::new(ServiceType::Api, None)
            .generate_response(
                StatusCode::NOT_FOUND,
                Some("Resource not found".to_string()),
                "test-123",
                None,
            )
            .expect("JSON error response should build");

        assert_eq!(response.status(), StatusCode::NOT_FOUND);
        assert_content_type(&response, "application/json; charset=utf-8");
    }

    #[test]
    fn test_error_handler_html() {
        let response = ErrorHandler::new(ServiceType::Web, None)
            .generate_response(StatusCode::INTERNAL_SERVER_ERROR, None, "test-456", None)
            .expect("HTML error response should build");

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_content_type(&response, "text/html; charset=utf-8");
    }

    #[test]
    fn test_custom_error_format() {
        // A per-status page override must beat both the service default (HTML)
        // and the configured default format (XML).
        let not_found_page = ErrorPage {
            format: ErrorFormat::Text,
            template: None,
            message: Some("Custom 404 message".to_string()),
            headers: HashMap::new(),
        };
        let config = ErrorPageConfig {
            pages: std::iter::once((404, not_found_page)).collect(),
            default_format: ErrorFormat::Xml,
            include_stack_trace: false,
            template_dir: None,
        };

        let response = ErrorHandler::new(ServiceType::Web, Some(config))
            .generate_response(StatusCode::NOT_FOUND, None, "test-789", None)
            .expect("text error response should build");

        assert_content_type(&response, "text/plain; charset=utf-8");
    }
}