// sentinel_proxy/errors/mod.rs

//! Error handling module for Sentinel proxy
//!
//! This module provides customizable error page generation for different
//! service types (web, API, static, builtin) and formats (HTML, JSON, text, XML).

6use anyhow::Result;
7use bytes::Bytes;
8use http::{Response, StatusCode};
9use http_body_util::Full;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::Arc;
13use tracing::{debug, warn};
14
15use sentinel_config::{ErrorFormat, ErrorPage, ErrorPageConfig, ServiceType};
16
17/// Error response generator
18pub struct ErrorHandler {
19    /// Service type for this handler
20    service_type: ServiceType,
21    /// Error page configuration
22    config: Option<ErrorPageConfig>,
23    /// Cached error templates
24    templates: Arc<HashMap<u16, String>>,
25}
26
27/// Error response data
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct ErrorResponse {
30    /// HTTP status code
31    pub status: u16,
32    /// Error title
33    pub title: String,
34    /// Error message
35    pub message: String,
36    /// Request ID for tracking
37    pub request_id: String,
38    /// Timestamp
39    pub timestamp: i64,
40    /// Additional details (optional)
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub details: Option<serde_json::Value>,
43    /// Stack trace (development only)
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub stack_trace: Option<Vec<String>>,
46}
47
48impl ErrorHandler {
49    /// Create a new error handler
50    pub fn new(service_type: ServiceType, config: Option<ErrorPageConfig>) -> Self {
51        let templates = if let Some(ref cfg) = config {
52            Self::load_templates(cfg)
53        } else {
54            Arc::new(HashMap::new())
55        };
56
57        Self {
58            service_type,
59            config,
60            templates,
61        }
62    }
63
64    /// Generate an error response
65    pub fn generate_response(
66        &self,
67        status: StatusCode,
68        message: Option<String>,
69        request_id: &str,
70        details: Option<serde_json::Value>,
71    ) -> Result<Response<Full<Bytes>>> {
72        let status_code = status.as_u16();
73        let error_data = ErrorResponse {
74            status: status_code,
75            title: Self::status_title(status),
76            message: message.unwrap_or_else(|| Self::default_message(status)),
77            request_id: request_id.to_string(),
78            timestamp: chrono::Utc::now().timestamp(),
79            details,
80            stack_trace: self.get_stack_trace(),
81        };
82
83        // Determine the format to use
84        let format = self.determine_format(status_code);
85
86        // Generate the response body
87        let (body, content_type) = match format {
88            ErrorFormat::Json => self.generate_json_response(&error_data)?,
89            ErrorFormat::Html => self.generate_html_response(&error_data, status_code)?,
90            ErrorFormat::Text => self.generate_text_response(&error_data)?,
91            ErrorFormat::Xml => self.generate_xml_response(&error_data)?,
92        };
93
94        // Build the response
95        let mut response = Response::builder()
96            .status(status)
97            .header("Content-Type", content_type)
98            .header("X-Request-Id", request_id);
99
100        // Add custom headers if configured
101        if let Some(page) = self.get_error_page(status_code) {
102            for (key, value) in &page.headers {
103                response = response.header(key, value);
104            }
105        }
106
107        Ok(response.body(Full::new(Bytes::from(body)))?)
108    }
109
110    /// Determine the error format based on service type and configuration
111    fn determine_format(&self, status_code: u16) -> ErrorFormat {
112        // Check if there's a specific configuration for this status code
113        if let Some(page) = self.get_error_page(status_code) {
114            return page.format;
115        }
116
117        // Use default format based on service type
118        match self.service_type {
119            ServiceType::Api | ServiceType::Builtin => ErrorFormat::Json,
120            ServiceType::Web | ServiceType::Static => self
121                .config
122                .as_ref()
123                .map(|c| c.default_format)
124                .unwrap_or(ErrorFormat::Html),
125        }
126    }
127
128    /// Get error page configuration for a specific status code
129    fn get_error_page(&self, status_code: u16) -> Option<&ErrorPage> {
130        self.config.as_ref().and_then(|c| c.pages.get(&status_code))
131    }
132
133    /// Generate JSON error response
134    fn generate_json_response(&self, error: &ErrorResponse) -> Result<(Vec<u8>, &'static str)> {
135        let json = serde_json::to_vec_pretty(error)?;
136        Ok((json, "application/json; charset=utf-8"))
137    }
138
139    /// Generate HTML error response
140    fn generate_html_response(
141        &self,
142        error: &ErrorResponse,
143        status_code: u16,
144    ) -> Result<(Vec<u8>, &'static str)> {
145        // Check for custom template
146        if let Some(template) = self.templates.get(&status_code) {
147            let html = self.render_template(template, error)?;
148            return Ok((html.into_bytes(), "text/html; charset=utf-8"));
149        }
150
151        // Generate default HTML
152        let html = self.generate_default_html(error);
153        Ok((html.into_bytes(), "text/html; charset=utf-8"))
154    }
155
156    /// Generate text error response
157    fn generate_text_response(&self, error: &ErrorResponse) -> Result<(Vec<u8>, &'static str)> {
158        let text = format!(
159            "{} {}\n\n{}\n\nRequest ID: {}\nTimestamp: {}",
160            error.status, error.title, error.message, error.request_id, error.timestamp
161        );
162        Ok((text.into_bytes(), "text/plain; charset=utf-8"))
163    }
164
165    /// Generate XML error response
166    fn generate_xml_response(&self, error: &ErrorResponse) -> Result<(Vec<u8>, &'static str)> {
167        let xml = format!(
168            r#"<?xml version="1.0" encoding="UTF-8"?>
169<error>
170    <status>{}</status>
171    <title>{}</title>
172    <message>{}</message>
173    <requestId>{}</requestId>
174    <timestamp>{}</timestamp>
175</error>"#,
176            error.status,
177            Self::escape_xml(&error.title),
178            Self::escape_xml(&error.message),
179            Self::escape_xml(&error.request_id),
180            error.timestamp
181        );
182        Ok((xml.into_bytes(), "application/xml; charset=utf-8"))
183    }
184
185    /// Generate default HTML error page
186    fn generate_default_html(&self, error: &ErrorResponse) -> String {
187        format!(
188            r#"<!DOCTYPE html>
189<html lang="en">
190<head>
191    <meta charset="UTF-8">
192    <meta name="viewport" content="width=device-width, initial-scale=1.0">
193    <title>{} {}</title>
194    <style>
195        body {{
196            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
197            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
198            color: #333;
199            display: flex;
200            align-items: center;
201            justify-content: center;
202            min-height: 100vh;
203            margin: 0;
204            padding: 20px;
205        }}
206        .error-container {{
207            background: white;
208            border-radius: 12px;
209            box-shadow: 0 20px 60px rgba(0,0,0,0.3);
210            padding: 40px;
211            max-width: 600px;
212            width: 100%;
213            text-align: center;
214        }}
215        h1 {{
216            color: #764ba2;
217            font-size: 72px;
218            margin: 0;
219            font-weight: bold;
220        }}
221        h2 {{
222            color: #666;
223            font-size: 24px;
224            margin: 10px 0;
225            font-weight: normal;
226        }}
227        p {{
228            color: #777;
229            font-size: 16px;
230            line-height: 1.6;
231            margin: 20px 0;
232        }}
233        .request-id {{
234            background: #f5f5f5;
235            border-radius: 4px;
236            padding: 8px 12px;
237            font-family: 'Courier New', monospace;
238            font-size: 12px;
239            color: #999;
240            margin-top: 30px;
241        }}
242        .back-link {{
243            display: inline-block;
244            margin-top: 20px;
245            color: #667eea;
246            text-decoration: none;
247            font-weight: 500;
248            transition: color 0.3s;
249        }}
250        .back-link:hover {{
251            color: #764ba2;
252        }}
253    </style>
254</head>
255<body>
256    <div class="error-container">
257        <h1>{}</h1>
258        <h2>{}</h2>
259        <p>{}</p>
260        <div class="request-id">Request ID: {}</div>
261        <a href="/" class="back-link">← Back to Home</a>
262    </div>
263</body>
264</html>"#,
265            error.status, error.title, error.status, error.title, error.message, error.request_id
266        )
267    }
268
269    /// Load custom templates from disk
270    fn load_templates(config: &ErrorPageConfig) -> Arc<HashMap<u16, String>> {
271        let mut templates = HashMap::new();
272
273        if let Some(ref template_dir) = config.template_dir {
274            for (status_code, page) in &config.pages {
275                if let Some(ref template_path) = page.template {
276                    let full_path = if template_path.is_absolute() {
277                        template_path.clone()
278                    } else {
279                        template_dir.join(template_path)
280                    };
281
282                    match std::fs::read_to_string(&full_path) {
283                        Ok(content) => {
284                            templates.insert(*status_code, content);
285                            debug!(
286                                "Loaded error template for status {}: {:?}",
287                                status_code, full_path
288                            );
289                        }
290                        Err(e) => {
291                            warn!("Failed to load error template {:?}: {}", full_path, e);
292                        }
293                    }
294                }
295            }
296        }
297
298        Arc::new(templates)
299    }
300
301    /// Render a template with error data
302    fn render_template(&self, template: &str, error: &ErrorResponse) -> Result<String> {
303        // Simple template rendering - replace placeholders
304        let rendered = template
305            .replace("{{status}}", &error.status.to_string())
306            .replace("{{title}}", &error.title)
307            .replace("{{message}}", &error.message)
308            .replace("{{request_id}}", &error.request_id)
309            .replace("{{timestamp}}", &error.timestamp.to_string());
310
311        Ok(rendered)
312    }
313
314    /// Get stack trace if enabled (development only)
315    fn get_stack_trace(&self) -> Option<Vec<String>> {
316        if let Some(ref config) = self.config {
317            if config.include_stack_trace {
318                // In production, we would capture the actual stack trace
319                // For now, return None
320                return None;
321            }
322        }
323        None
324    }
325
326    /// Get default status title
327    fn status_title(status: StatusCode) -> String {
328        status
329            .canonical_reason()
330            .unwrap_or("Unknown Error")
331            .to_string()
332    }
333
334    /// Get default error message for status code
335    fn default_message(status: StatusCode) -> String {
336        match status {
337            StatusCode::BAD_REQUEST => {
338                "The request could not be understood by the server.".to_string()
339            }
340            StatusCode::UNAUTHORIZED => {
341                "You are not authorized to access this resource.".to_string()
342            }
343            StatusCode::FORBIDDEN => "Access to this resource is forbidden.".to_string(),
344            StatusCode::NOT_FOUND => "The requested resource could not be found.".to_string(),
345            StatusCode::METHOD_NOT_ALLOWED => {
346                "The requested method is not allowed for this resource.".to_string()
347            }
348            StatusCode::REQUEST_TIMEOUT => "The request took too long to process.".to_string(),
349            StatusCode::PAYLOAD_TOO_LARGE => "The request payload is too large.".to_string(),
350            StatusCode::TOO_MANY_REQUESTS => {
351                "Too many requests. Please try again later.".to_string()
352            }
353            StatusCode::INTERNAL_SERVER_ERROR => {
354                "An internal server error occurred. Please try again later.".to_string()
355            }
356            StatusCode::BAD_GATEWAY => {
357                "The gateway received an invalid response from the upstream server.".to_string()
358            }
359            StatusCode::SERVICE_UNAVAILABLE => {
360                "The service is temporarily unavailable. Please try again later.".to_string()
361            }
362            StatusCode::GATEWAY_TIMEOUT => {
363                "The gateway timed out waiting for a response from the upstream server.".to_string()
364            }
365            _ => format!("An error occurred (HTTP {})", status.as_u16()),
366        }
367    }
368
369    /// Escape XML special characters
370    fn escape_xml(s: &str) -> String {
371        s.replace('&', "&amp;")
372            .replace('<', "&lt;")
373            .replace('>', "&gt;")
374            .replace('"', "&quot;")
375            .replace('\'', "&apos;")
376    }
377
378    /// Reload templates (for hot reload)
379    pub fn reload_templates(&mut self) {
380        if let Some(ref config) = self.config {
381            self.templates = Self::load_templates(config);
382            debug!("Reloaded error templates");
383        }
384    }
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the `Content-Type` header of a generated response as text.
    fn content_type(response: &Response<Full<Bytes>>) -> &str {
        response
            .headers()
            .get("Content-Type")
            .expect("every generated error response sets Content-Type")
            .to_str()
            .expect("Content-Type values produced by ErrorHandler are ASCII")
    }

    #[test]
    fn test_error_handler_json() {
        let api_handler = ErrorHandler::new(ServiceType::Api, None);
        let response = api_handler
            .generate_response(
                StatusCode::NOT_FOUND,
                Some("Resource not found".to_string()),
                "test-123",
                None,
            )
            .expect("JSON error response should build");

        assert_eq!(response.status(), StatusCode::NOT_FOUND);
        assert_eq!(content_type(&response), "application/json; charset=utf-8");
    }

    #[test]
    fn test_error_handler_html() {
        let web_handler = ErrorHandler::new(ServiceType::Web, None);
        let response = web_handler
            .generate_response(StatusCode::INTERNAL_SERVER_ERROR, None, "test-456", None)
            .expect("HTML error response should build");

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(content_type(&response), "text/html; charset=utf-8");
    }

    #[test]
    fn test_custom_error_format() {
        let custom_404 = ErrorPage {
            format: ErrorFormat::Text,
            template: None,
            message: Some("Custom 404 message".to_string()),
            headers: HashMap::new(),
        };
        let mut pages = HashMap::new();
        pages.insert(404, custom_404);

        let config = ErrorPageConfig {
            pages,
            default_format: ErrorFormat::Xml,
            include_stack_trace: false,
            template_dir: None,
        };

        let web_handler = ErrorHandler::new(ServiceType::Web, Some(config));
        let response = web_handler
            .generate_response(StatusCode::NOT_FOUND, None, "test-789", None)
            .expect("text error response should build");

        assert_eq!(content_type(&response), "text/plain; charset=utf-8");
    }
}