sentinel_proxy/errors/
mod.rs

1//! Error handling module for Sentinel proxy
2//!
3//! This module provides customizable error page generation for different
4//! service types (web, API, static) and formats (HTML, JSON, text, XML).
5
6use anyhow::Result;
7use bytes::Bytes;
8use http::{Response, StatusCode};
9use http_body_util::Full;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::Arc;
13use tracing::{debug, warn};
14
15use sentinel_config::{ErrorFormat, ErrorPage, ErrorPageConfig, ServiceType};
16
/// Error response generator.
///
/// Builds HTTP error responses whose format is chosen from the per-status
/// page configuration, the configured default format, or — when no
/// configuration is present — the service type.
pub struct ErrorHandler {
    /// Service type for this handler; only consulted for the fallback format
    /// when no configuration is set
    service_type: ServiceType,
    /// Error page configuration (`None` means service-type defaults only)
    config: Option<ErrorPageConfig>,
    /// Custom template contents keyed by HTTP status code, read from disk
    /// when the handler is created or its templates are reloaded
    templates: Arc<HashMap<u16, String>>,
}
26
/// Error response data.
///
/// This is the payload serialized for JSON responses; the other formats read
/// individual fields from it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorResponse {
    /// HTTP status code as a plain number (e.g. `404`)
    pub status: u16,
    /// Short error title
    pub title: String,
    /// Human-readable error message
    pub message: String,
    /// Request ID for tracking
    pub request_id: String,
    /// Timestamp as an integer; the handler fills it with
    /// `chrono::Utc::now().timestamp()`
    pub timestamp: i64,
    /// Additional details (optional); left out of serialized output when `None`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<serde_json::Value>,
    /// Stack trace (development only); left out of serialized output when `None`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stack_trace: Option<Vec<String>>,
}
47
48impl ErrorHandler {
49    /// Create a new error handler
50    pub fn new(service_type: ServiceType, config: Option<ErrorPageConfig>) -> Self {
51        let templates = if let Some(ref cfg) = config {
52            Self::load_templates(cfg)
53        } else {
54            Arc::new(HashMap::new())
55        };
56
57        Self {
58            service_type,
59            config,
60            templates,
61        }
62    }
63
64    /// Generate an error response
65    pub fn generate_response(
66        &self,
67        status: StatusCode,
68        message: Option<String>,
69        request_id: &str,
70        details: Option<serde_json::Value>,
71    ) -> Result<Response<Full<Bytes>>> {
72        let status_code = status.as_u16();
73        let error_data = ErrorResponse {
74            status: status_code,
75            title: Self::status_title(status),
76            message: message.unwrap_or_else(|| Self::default_message(status)),
77            request_id: request_id.to_string(),
78            timestamp: chrono::Utc::now().timestamp(),
79            details,
80            stack_trace: self.get_stack_trace(),
81        };
82
83        // Determine the format to use
84        let format = self.determine_format(status_code);
85
86        // Generate the response body
87        let (body, content_type) = match format {
88            ErrorFormat::Json => self.generate_json_response(&error_data)?,
89            ErrorFormat::Html => self.generate_html_response(&error_data, status_code)?,
90            ErrorFormat::Text => self.generate_text_response(&error_data)?,
91            ErrorFormat::Xml => self.generate_xml_response(&error_data)?,
92        };
93
94        // Build the response
95        let mut response = Response::builder()
96            .status(status)
97            .header("Content-Type", content_type)
98            .header("X-Request-Id", request_id);
99
100        // Add custom headers if configured
101        if let Some(page) = self.get_error_page(status_code) {
102            for (key, value) in &page.headers {
103                response = response.header(key, value);
104            }
105        }
106
107        Ok(response.body(Full::new(Bytes::from(body)))?)
108    }
109
110    /// Determine the error format based on service type and configuration
111    fn determine_format(&self, status_code: u16) -> ErrorFormat {
112        // Check if there's a specific configuration for this status code
113        if let Some(page) = self.get_error_page(status_code) {
114            return page.format;
115        }
116
117        // Check if there's a default format configured
118        if let Some(ref config) = self.config {
119            return config.default_format;
120        }
121
122        // Fall back to service type default
123        match self.service_type {
124            ServiceType::Api | ServiceType::Builtin => ErrorFormat::Json,
125            ServiceType::Web | ServiceType::Static => ErrorFormat::Html,
126        }
127    }
128
129    /// Get error page configuration for a specific status code
130    fn get_error_page(&self, status_code: u16) -> Option<&ErrorPage> {
131        self.config.as_ref().and_then(|c| c.pages.get(&status_code))
132    }
133
134    /// Generate JSON error response
135    fn generate_json_response(&self, error: &ErrorResponse) -> Result<(Vec<u8>, &'static str)> {
136        let json = serde_json::to_vec_pretty(error)?;
137        Ok((json, "application/json; charset=utf-8"))
138    }
139
140    /// Generate HTML error response
141    fn generate_html_response(
142        &self,
143        error: &ErrorResponse,
144        status_code: u16,
145    ) -> Result<(Vec<u8>, &'static str)> {
146        // Check for custom template
147        if let Some(template) = self.templates.get(&status_code) {
148            let html = self.render_template(template, error)?;
149            return Ok((html.into_bytes(), "text/html; charset=utf-8"));
150        }
151
152        // Generate default HTML
153        let html = self.generate_default_html(error);
154        Ok((html.into_bytes(), "text/html; charset=utf-8"))
155    }
156
157    /// Generate text error response
158    fn generate_text_response(&self, error: &ErrorResponse) -> Result<(Vec<u8>, &'static str)> {
159        let text = format!(
160            "{} {}\n\n{}\n\nRequest ID: {}\nTimestamp: {}",
161            error.status, error.title, error.message, error.request_id, error.timestamp
162        );
163        Ok((text.into_bytes(), "text/plain; charset=utf-8"))
164    }
165
166    /// Generate XML error response
167    fn generate_xml_response(&self, error: &ErrorResponse) -> Result<(Vec<u8>, &'static str)> {
168        let xml = format!(
169            r#"<?xml version="1.0" encoding="UTF-8"?>
170<error>
171    <status>{}</status>
172    <title>{}</title>
173    <message>{}</message>
174    <requestId>{}</requestId>
175    <timestamp>{}</timestamp>
176</error>"#,
177            error.status,
178            Self::escape_xml(&error.title),
179            Self::escape_xml(&error.message),
180            Self::escape_xml(&error.request_id),
181            error.timestamp
182        );
183        Ok((xml.into_bytes(), "application/xml; charset=utf-8"))
184    }
185
186    /// Generate default HTML error page
187    fn generate_default_html(&self, error: &ErrorResponse) -> String {
188        format!(
189            r#"<!DOCTYPE html>
190<html lang="en">
191<head>
192    <meta charset="UTF-8">
193    <meta name="viewport" content="width=device-width, initial-scale=1.0">
194    <title>{} {}</title>
195    <style>
196        body {{
197            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
198            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
199            color: #333;
200            display: flex;
201            align-items: center;
202            justify-content: center;
203            min-height: 100vh;
204            margin: 0;
205            padding: 20px;
206        }}
207        .error-container {{
208            background: white;
209            border-radius: 12px;
210            box-shadow: 0 20px 60px rgba(0,0,0,0.3);
211            padding: 40px;
212            max-width: 600px;
213            width: 100%;
214            text-align: center;
215        }}
216        h1 {{
217            color: #764ba2;
218            font-size: 72px;
219            margin: 0;
220            font-weight: bold;
221        }}
222        h2 {{
223            color: #666;
224            font-size: 24px;
225            margin: 10px 0;
226            font-weight: normal;
227        }}
228        p {{
229            color: #777;
230            font-size: 16px;
231            line-height: 1.6;
232            margin: 20px 0;
233        }}
234        .request-id {{
235            background: #f5f5f5;
236            border-radius: 4px;
237            padding: 8px 12px;
238            font-family: 'Courier New', monospace;
239            font-size: 12px;
240            color: #999;
241            margin-top: 30px;
242        }}
243        .back-link {{
244            display: inline-block;
245            margin-top: 20px;
246            color: #667eea;
247            text-decoration: none;
248            font-weight: 500;
249            transition: color 0.3s;
250        }}
251        .back-link:hover {{
252            color: #764ba2;
253        }}
254    </style>
255</head>
256<body>
257    <div class="error-container">
258        <h1>{}</h1>
259        <h2>{}</h2>
260        <p>{}</p>
261        <div class="request-id">Request ID: {}</div>
262        <a href="/" class="back-link">← Back to Home</a>
263    </div>
264</body>
265</html>"#,
266            error.status, error.title, error.status, error.title, error.message, error.request_id
267        )
268    }
269
270    /// Load custom templates from disk
271    fn load_templates(config: &ErrorPageConfig) -> Arc<HashMap<u16, String>> {
272        let mut templates = HashMap::new();
273
274        if let Some(ref template_dir) = config.template_dir {
275            for (status_code, page) in &config.pages {
276                if let Some(ref template_path) = page.template {
277                    let full_path = if template_path.is_absolute() {
278                        template_path.clone()
279                    } else {
280                        template_dir.join(template_path)
281                    };
282
283                    match std::fs::read_to_string(&full_path) {
284                        Ok(content) => {
285                            templates.insert(*status_code, content);
286                            debug!(
287                                "Loaded error template for status {}: {:?}",
288                                status_code, full_path
289                            );
290                        }
291                        Err(e) => {
292                            warn!("Failed to load error template {:?}: {}", full_path, e);
293                        }
294                    }
295                }
296            }
297        }
298
299        Arc::new(templates)
300    }
301
302    /// Render a template with error data
303    fn render_template(&self, template: &str, error: &ErrorResponse) -> Result<String> {
304        // Simple template rendering - replace placeholders
305        let rendered = template
306            .replace("{{status}}", &error.status.to_string())
307            .replace("{{title}}", &error.title)
308            .replace("{{message}}", &error.message)
309            .replace("{{request_id}}", &error.request_id)
310            .replace("{{timestamp}}", &error.timestamp.to_string());
311
312        Ok(rendered)
313    }
314
315    /// Get stack trace if enabled (development only)
316    fn get_stack_trace(&self) -> Option<Vec<String>> {
317        if let Some(ref config) = self.config {
318            if config.include_stack_trace {
319                // In production, we would capture the actual stack trace
320                // For now, return None
321                return None;
322            }
323        }
324        None
325    }
326
327    /// Get default status title
328    fn status_title(status: StatusCode) -> String {
329        status
330            .canonical_reason()
331            .unwrap_or("Unknown Error")
332            .to_string()
333    }
334
335    /// Get default error message for status code
336    fn default_message(status: StatusCode) -> String {
337        match status {
338            StatusCode::BAD_REQUEST => {
339                "The request could not be understood by the server.".to_string()
340            }
341            StatusCode::UNAUTHORIZED => {
342                "You are not authorized to access this resource.".to_string()
343            }
344            StatusCode::FORBIDDEN => "Access to this resource is forbidden.".to_string(),
345            StatusCode::NOT_FOUND => "The requested resource could not be found.".to_string(),
346            StatusCode::METHOD_NOT_ALLOWED => {
347                "The requested method is not allowed for this resource.".to_string()
348            }
349            StatusCode::REQUEST_TIMEOUT => "The request took too long to process.".to_string(),
350            StatusCode::PAYLOAD_TOO_LARGE => "The request payload is too large.".to_string(),
351            StatusCode::TOO_MANY_REQUESTS => {
352                "Too many requests. Please try again later.".to_string()
353            }
354            StatusCode::INTERNAL_SERVER_ERROR => {
355                "An internal server error occurred. Please try again later.".to_string()
356            }
357            StatusCode::BAD_GATEWAY => {
358                "The gateway received an invalid response from the upstream server.".to_string()
359            }
360            StatusCode::SERVICE_UNAVAILABLE => {
361                "The service is temporarily unavailable. Please try again later.".to_string()
362            }
363            StatusCode::GATEWAY_TIMEOUT => {
364                "The gateway timed out waiting for a response from the upstream server.".to_string()
365            }
366            _ => format!("An error occurred (HTTP {})", status.as_u16()),
367        }
368    }
369
370    /// Escape XML special characters
371    fn escape_xml(s: &str) -> String {
372        s.replace('&', "&amp;")
373            .replace('<', "&lt;")
374            .replace('>', "&gt;")
375            .replace('"', "&quot;")
376            .replace('\'', "&apos;")
377    }
378
379    /// Reload templates (for hot reload)
380    pub fn reload_templates(&mut self) {
381        if let Some(ref config) = self.config {
382            self.templates = Self::load_templates(config);
383            debug!("Reloaded error templates");
384        }
385    }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// API services without configuration answer with JSON.
    #[test]
    fn test_error_handler_json() {
        let handler = ErrorHandler::new(ServiceType::Api, None);
        let message = Some("Resource not found".to_string());

        let response = handler
            .generate_response(StatusCode::NOT_FOUND, message, "test-123", None)
            .unwrap();

        assert_eq!(response.status(), StatusCode::NOT_FOUND);
        let content_type = response.headers().get("Content-Type").unwrap();
        assert_eq!(content_type, "application/json; charset=utf-8");
    }

    /// Web services without configuration answer with HTML.
    #[test]
    fn test_error_handler_html() {
        let handler = ErrorHandler::new(ServiceType::Web, None);

        let response = handler
            .generate_response(StatusCode::INTERNAL_SERVER_ERROR, None, "test-456", None)
            .unwrap();

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        let content_type = response.headers().get("Content-Type").unwrap();
        assert_eq!(content_type, "text/html; charset=utf-8");
    }

    /// A page configured for a status overrides the configured default format.
    #[test]
    fn test_custom_error_format() {
        let not_found_page = ErrorPage {
            format: ErrorFormat::Text,
            template: None,
            message: Some("Custom 404 message".to_string()),
            headers: HashMap::new(),
        };

        let mut pages = HashMap::new();
        pages.insert(404, not_found_page);

        let config = ErrorPageConfig {
            pages,
            default_format: ErrorFormat::Xml,
            include_stack_trace: false,
            template_dir: None,
        };

        let handler = ErrorHandler::new(ServiceType::Web, Some(config));
        let response = handler
            .generate_response(StatusCode::NOT_FOUND, None, "test-789", None)
            .unwrap();

        let content_type = response.headers().get("Content-Type").unwrap();
        assert_eq!(content_type, "text/plain; charset=utf-8");
    }
}
456}