wisegate_core/
request_handler.rs

1//! HTTP request handling and proxying.
2//!
3//! This module contains the core request handling logic for the reverse proxy,
4//! including IP validation, rate limiting, URL pattern blocking, and request forwarding.
5//!
6//! # Architecture
7//!
8//! The request handling flow:
//! 1. Extract and validate client IP from proxy headers
//! 2. Check if IP is blocked
//! 3. Check for blocked URL patterns
//! 4. Check for blocked HTTP methods
//! 5. Check authentication (Basic Auth and/or Bearer Token), if enabled
//! 6. Apply rate limiting
//! 7. Forward request to upstream service
15//!
16//! # Connection Pooling
17//!
18//! The module accepts a shared [`reqwest::Client`] for HTTP connection pooling,
19//! which should be configured by the caller with appropriate timeouts.
20
21use http_body_util::{BodyExt, Full};
22use hyper::{Request, Response, StatusCode, body::Incoming};
23use std::convert::Infallible;
24use std::sync::Arc;
25
26use crate::error::WiseGateError;
27use crate::types::{ConfigProvider, RateLimiter};
28use crate::{auth, headers, ip_filter, rate_limiter};
29
30/// Handles an incoming HTTP request through the proxy pipeline.
31///
32/// This is the main entry point for request processing. It performs:
33/// - Client IP extraction and validation from proxy headers
34/// - IP blocking checks
35/// - URL pattern blocking (e.g., `.php`, `.env` files)
36/// - HTTP method blocking (e.g., `PUT`, `DELETE`)
37/// - Rate limiting per client IP
38/// - Request forwarding to the upstream service
39///
40/// # Arguments
41///
42/// * `req` - The incoming HTTP request
43/// * `forward_host` - The upstream host to forward requests to
44/// * `forward_port` - The upstream port to forward requests to
45/// * `limiter` - The shared rate limiter instance
46/// * `config` - Configuration provider for all settings
47/// * `http_client` - HTTP client for forwarding requests (with connection pooling)
48///
49/// # Returns
50///
51/// Always returns `Ok` with either:
52/// - A successful proxied response from upstream
53/// - An error response (403, 404, 405, 429, 502, etc.)
54///
55/// # Security
56///
57/// In strict mode (when proxy allowlist is configured), requests
58/// without valid proxy headers are rejected with 403 Forbidden.
59pub async fn handle_request<C: ConfigProvider>(
60    req: Request<Incoming>,
61    forward_host: String,
62    forward_port: u16,
63    limiter: RateLimiter,
64    config: Arc<C>,
65    http_client: reqwest::Client,
66) -> Result<Response<Full<bytes::Bytes>>, Infallible> {
67    // Extract and validate real client IP
68    let real_client_ip =
69        match ip_filter::extract_and_validate_real_ip(req.headers(), config.as_ref()) {
70            Some(ip) => ip,
71            None => {
72                // In permissive mode (no allowlist configured), we couldn't extract IP from headers
73                // Use placeholder IP and continue with non-IP-based security features only
74                if config.allowed_proxy_ips().is_none() {
75                    "unknown".to_string()
76                } else {
77                    // If allowlist is configured but validation failed, reject the request
78                    let err = WiseGateError::InvalidIp("missing or invalid proxy headers".into());
79                    return Ok(create_error_response(err.status_code(), err.user_message()));
80                }
81            }
82        };
83
84    // Check if IP is blocked (skip if IP is unknown)
85    if real_client_ip != "unknown" && ip_filter::is_ip_blocked(&real_client_ip, config.as_ref()) {
86        let err = WiseGateError::IpBlocked(real_client_ip);
87        return Ok(create_error_response(err.status_code(), err.user_message()));
88    }
89
90    // Check for blocked URL patterns
91    let request_path = req.uri().path();
92    if is_url_pattern_blocked(request_path, config.as_ref()) {
93        let err = WiseGateError::PatternBlocked(request_path.to_string());
94        return Ok(create_error_response(err.status_code(), err.user_message()));
95    }
96
97    // Check for blocked HTTP methods
98    let request_method = req.method().as_str();
99    if is_method_blocked(request_method, config.as_ref()) {
100        let err = WiseGateError::MethodBlocked(request_method.to_string());
101        return Ok(create_error_response(err.status_code(), err.user_message()));
102    }
103
104    // Check Authentication if enabled (Basic Auth and/or Bearer Token)
105    // Logic: if both are configured, either one passing is sufficient
106    if config.is_auth_enabled() {
107        let auth_header = req
108            .headers()
109            .get(headers::AUTHORIZATION)
110            .and_then(|v| v.to_str().ok());
111
112        let basic_auth_enabled = config.is_basic_auth_enabled();
113        let bearer_auth_enabled = config.is_bearer_auth_enabled();
114
115        let basic_auth_passed =
116            basic_auth_enabled && auth::check_basic_auth(auth_header, config.auth_credentials());
117        let bearer_auth_passed =
118            bearer_auth_enabled && auth::check_bearer_token(auth_header, config.bearer_token());
119
120        // Authentication fails if neither method passed
121        if !basic_auth_passed && !bearer_auth_passed {
122            return Ok(create_unauthorized_response(config.auth_realm()));
123        }
124    }
125
126    // Apply rate limiting (skip if IP is unknown)
127    if real_client_ip != "unknown"
128        && !rate_limiter::check_rate_limit(&limiter, &real_client_ip, config.as_ref()).await
129    {
130        let err = WiseGateError::RateLimitExceeded(real_client_ip);
131        return Ok(create_error_response(err.status_code(), err.user_message()));
132    }
133
134    // Add X-Real-IP header for upstream service (only if we have a real IP)
135    let mut req = req;
136    if real_client_ip != "unknown"
137        && let Ok(header_value) = real_client_ip.parse()
138    {
139        req.headers_mut().insert("x-real-ip", header_value);
140    }
141
142    // Forward the request
143    forward_request(
144        req,
145        &forward_host,
146        forward_port,
147        config.as_ref(),
148        &http_client,
149    )
150    .await
151}
152
153/// Forward request to upstream service
154async fn forward_request(
155    req: Request<Incoming>,
156    host: &str,
157    port: u16,
158    config: &impl ConfigProvider,
159    http_client: &reqwest::Client,
160) -> Result<Response<Full<bytes::Bytes>>, Infallible> {
161    let proxy_config = config.proxy_config();
162    let (parts, body) = req.into_parts();
163    let body_bytes = match body.collect().await {
164        Ok(bytes) => {
165            let collected_bytes = bytes.to_bytes();
166
167            // Check body size limit
168            if proxy_config.max_body_size > 0 && collected_bytes.len() > proxy_config.max_body_size
169            {
170                let err = WiseGateError::BodyTooLarge {
171                    size: collected_bytes.len(),
172                    max: proxy_config.max_body_size,
173                };
174                return Ok(create_error_response(err.status_code(), err.user_message()));
175            }
176
177            collected_bytes
178        }
179        Err(e) => {
180            let err = WiseGateError::BodyReadError(e.to_string());
181            return Ok(create_error_response(err.status_code(), err.user_message()));
182        }
183    };
184
185    forward_with_reqwest(parts, body_bytes, host, port, http_client).await
186}
187
188/// Shared forwarding logic using reqwest with connection pooling
189async fn forward_with_reqwest(
190    parts: hyper::http::request::Parts,
191    body_bytes: bytes::Bytes,
192    host: &str,
193    port: u16,
194    client: &reqwest::Client,
195) -> Result<Response<Full<bytes::Bytes>>, Infallible> {
196    // Construct destination URI
197    let destination_uri = format!(
198        "http://{}:{}{}",
199        host,
200        port,
201        parts.uri.path_and_query().map_or("", |pq| pq.as_str())
202    );
203
204    // Build the request with method support for all HTTP verbs
205    let mut req_builder = match parts.method.as_str() {
206        "GET" => client.get(&destination_uri),
207        "POST" => client.post(&destination_uri),
208        "PUT" => client.put(&destination_uri),
209        "DELETE" => client.delete(&destination_uri),
210        "HEAD" => client.head(&destination_uri),
211        "PATCH" => client.patch(&destination_uri),
212        "OPTIONS" => client.request(reqwest::Method::OPTIONS, &destination_uri),
213        method => {
214            // Try to parse custom methods
215            match reqwest::Method::from_bytes(method.as_bytes()) {
216                Ok(custom_method) => client.request(custom_method, &destination_uri),
217                Err(_) => {
218                    let err = WiseGateError::MethodBlocked(format!("{} (unsupported)", method));
219                    return Ok(create_error_response(err.status_code(), err.user_message()));
220                }
221            }
222        }
223    };
224
225    // Add headers (excluding host and content-length)
226    for (name, value) in parts.headers.iter() {
227        if name != "host"
228            && name != "content-length"
229            && let Ok(header_value) = value.to_str()
230        {
231            req_builder = req_builder.header(name.as_str(), header_value);
232        }
233    }
234
235    // Add body if not empty
236    if !body_bytes.is_empty() {
237        req_builder = req_builder.body(body_bytes.to_vec());
238    }
239
240    // Send request
241    match req_builder.send().await {
242        Ok(response) => {
243            let status = response.status();
244            let resp_headers = response.headers().clone();
245
246            match response.bytes().await {
247                Ok(body_bytes) => {
248                    let mut hyper_response = match Response::builder()
249                        .status(status.as_u16())
250                        .body(Full::new(body_bytes))
251                    {
252                        Ok(resp) => resp,
253                        Err(e) => {
254                            let err = WiseGateError::ProxyError(format!(
255                                "Failed to build response: {}",
256                                e
257                            ));
258                            return Ok(create_error_response(
259                                err.status_code(),
260                                err.user_message(),
261                            ));
262                        }
263                    };
264
265                    // Copy response headers (skip hop-by-hop headers)
266                    for (name, value) in resp_headers.iter() {
267                        let header_name = name.as_str().to_lowercase();
268                        // Skip hop-by-hop headers that shouldn't be forwarded
269                        if !headers::is_hop_by_hop(&header_name)
270                            && let (Ok(hyper_name), Ok(hyper_value)) = (
271                                hyper::header::HeaderName::from_bytes(name.as_str().as_bytes()),
272                                hyper::header::HeaderValue::from_bytes(value.as_bytes()),
273                            )
274                        {
275                            hyper_response.headers_mut().insert(hyper_name, hyper_value);
276                        }
277                    }
278
279                    Ok(hyper_response)
280                }
281                Err(e) => {
282                    let err = WiseGateError::BodyReadError(format!("response: {}", e));
283                    Ok(create_error_response(err.status_code(), err.user_message()))
284                }
285            }
286        }
287        Err(err) => {
288            // More specific error handling using WiseGateError
289            let wise_err = if err.is_timeout() {
290                WiseGateError::UpstreamTimeout(err.to_string())
291            } else if err.is_connect() {
292                WiseGateError::UpstreamConnectionFailed(err.to_string())
293            } else {
294                WiseGateError::ProxyError(err.to_string())
295            };
296            Ok(create_error_response(
297                wise_err.status_code(),
298                wise_err.user_message(),
299            ))
300        }
301    }
302}
303
304/// Creates a standardized error response.
305///
306/// Builds an HTTP response with the given status code and plain text message.
307/// Falls back to a minimal 500 response if building fails (should never happen
308/// with valid StatusCode).
309///
310/// # Arguments
311///
312/// * `status` - The HTTP status code for the response
313/// * `message` - The plain text error message body
314///
315/// # Returns
316///
317/// An HTTP response with `content-type: text/plain` header.
318///
319/// # Example
320///
321/// ```
322/// use wisegate_core::request_handler::create_error_response;
323/// use hyper::StatusCode;
324///
325/// let response = create_error_response(StatusCode::NOT_FOUND, "Resource not found");
326/// assert_eq!(response.status(), StatusCode::NOT_FOUND);
327/// ```
328pub fn create_error_response(status: StatusCode, message: &str) -> Response<Full<bytes::Bytes>> {
329    Response::builder()
330        .status(status)
331        .header("content-type", "text/plain")
332        .body(Full::new(bytes::Bytes::from(message.to_string())))
333        .unwrap_or_else(|_| {
334            // Fallback response if builder fails (extremely unlikely)
335            Response::new(Full::new(bytes::Bytes::from("Internal Server Error")))
336        })
337}
338
339/// Creates a 401 Unauthorized response with WWW-Authenticate header.
340///
341/// Used when Basic Authentication is enabled and the request is not authenticated
342/// or has invalid credentials.
343///
344/// # Arguments
345///
346/// * `realm` - The authentication realm to display in the browser dialog
347///
348/// # Returns
349///
350/// An HTTP 401 response with `WWW-Authenticate: Basic realm="..."` header.
351pub fn create_unauthorized_response(realm: &str) -> Response<Full<bytes::Bytes>> {
352    Response::builder()
353        .status(StatusCode::UNAUTHORIZED)
354        .header(
355            headers::WWW_AUTHENTICATE,
356            format!("Basic realm=\"{}\"", realm),
357        )
358        .header("content-type", "text/plain")
359        .body(Full::new(bytes::Bytes::from("401 Unauthorized")))
360        .unwrap_or_else(|_| Response::new(Full::new(bytes::Bytes::from("401 Unauthorized"))))
361}
362
363/// Check if URL path contains any blocked patterns
364/// Decodes URL-encoded characters to prevent bypass via encoding (e.g., .ph%70 for .php)
365fn is_url_pattern_blocked(path: &str, config: &impl ConfigProvider) -> bool {
366    let blocked_patterns = config.blocked_patterns();
367    if blocked_patterns.is_empty() {
368        return false;
369    }
370
371    // Decode URL-encoded path to prevent bypass attacks
372    let decoded_path = url_decode(path);
373
374    // Check against both original and decoded path
375    blocked_patterns
376        .iter()
377        .any(|pattern| path.contains(pattern) || decoded_path.contains(pattern))
378}
379
/// Decodes a percent-encoded (URL-encoded) string.
///
/// Handles common bypass attempts like `%2e` for `.`, `%70` for `p`, etc.
///
/// * `%XY` is decoded only when `X` and `Y` are both ASCII hex digits
///   (either case). Anything else (`%`, `%2`, `%GG`, `%+f`) is malformed and
///   the `%` is kept literally.
/// * A malformed escape never consumes the characters after it, so a valid
///   escape that directly follows a stray `%` is still decoded
///   (`%%41` becomes `%A`).
/// * Decoded bytes are assembled first and converted to a string afterwards,
///   so multi-byte UTF-8 sequences such as `%C3%A9` decode correctly. Byte
///   sequences that are not valid UTF-8 are replaced with U+FFFD.
fn url_decode(input: &str) -> String {
    /// Numeric value of a single ASCII hex digit, or `None` for other bytes.
    fn hex_value(byte: u8) -> Option<u8> {
        match byte {
            b'0'..=b'9' => Some(byte - b'0'),
            b'a'..=b'f' => Some(byte - b'a' + 10),
            b'A'..=b'F' => Some(byte - b'A' + 10),
            _ => None,
        }
    }

    // Working on bytes is safe: `%` and hex digits are ASCII and can never be
    // part of a multi-byte UTF-8 sequence in `input`.
    let raw = input.as_bytes();
    let mut decoded = Vec::with_capacity(raw.len());
    let mut pos = 0;

    while pos < raw.len() {
        if raw[pos] == b'%' {
            let high = raw.get(pos + 1).copied().and_then(hex_value);
            let low = raw.get(pos + 2).copied().and_then(hex_value);
            if let (Some(high), Some(low)) = (high, low) {
                decoded.push((high << 4) | low);
                pos += 3;
                continue;
            }
            // Malformed escape: fall through and keep the '%' itself; the
            // following bytes are processed normally on the next iterations.
        }
        decoded.push(raw[pos]);
        pos += 1;
    }

    // Convert bytes to string, replacing invalid UTF-8 with replacement character
    String::from_utf8_lossy(&decoded).into_owned()
}
410
411/// Check if HTTP method is blocked
412fn is_method_blocked(method: &str, config: &impl ConfigProvider) -> bool {
413    let blocked_methods = config.blocked_methods();
414    blocked_methods
415        .iter()
416        .any(|blocked_method| blocked_method == &method.to_uppercase())
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::TestConfig;
    use http_body_util::BodyExt;

    // ===========================================
    // shared helpers
    // ===========================================

    /// Asserts that `url_decode` maps each `(input, expected)` pair as given.
    fn assert_decodes(cases: &[(&str, &str)]) {
        for &(encoded, expected) in cases {
            assert_eq!(url_decode(encoded), expected);
        }
    }

    /// Asserts that every path is rejected by the URL pattern filter.
    fn assert_all_blocked(config: &TestConfig, paths: &[&str]) {
        for &path in paths {
            assert!(is_url_pattern_blocked(path, config));
        }
    }

    /// Asserts that no path is rejected by the URL pattern filter.
    fn assert_none_blocked(config: &TestConfig, paths: &[&str]) {
        for &path in paths {
            assert!(!is_url_pattern_blocked(path, config));
        }
    }

    /// Collects a response body into a single `Bytes` buffer.
    async fn body_of(response: Response<Full<bytes::Bytes>>) -> bytes::Bytes {
        response.into_body().collect().await.unwrap().to_bytes()
    }

    // ===========================================
    // url_decode tests
    // ===========================================

    #[test]
    fn test_url_decode_no_encoding() {
        assert_decodes(&[
            ("/path/to/file", "/path/to/file"),
            ("hello", "hello"),
            ("", ""),
        ]);
    }

    #[test]
    fn test_url_decode_simple_encoding() {
        assert_decodes(&[
            ("%20", " "),
            ("hello%20world", "hello world"),
            ("%2F", "/"),
        ]);
    }

    #[test]
    fn test_url_decode_dot_encoding() {
        // Common bypass attempts
        assert_decodes(&[("%2e", "."), ("%2E", "."), (".%2ephp", "..php")]);
    }

    #[test]
    fn test_url_decode_php_bypass() {
        // Attacker tries to bypass .php blocking
        assert_decodes(&[
            (".ph%70", ".php"),
            ("%2ephp", ".php"),
            (".%70%68%70", ".php"),
        ]);
    }

    #[test]
    fn test_url_decode_env_bypass() {
        // Attacker tries to bypass .env blocking
        assert_decodes(&[
            (".%65nv", ".env"),
            ("%2eenv", ".env"),
            ("%2e%65%6e%76", ".env"),
        ]);
    }

    #[test]
    fn test_url_decode_multiple_encodings() {
        assert_decodes(&[("%2F%2e%2e%2Fetc%2Fpasswd", "/../etc/passwd")]);
    }

    #[test]
    fn test_url_decode_invalid_hex() {
        // Invalid hex should be preserved
        assert_decodes(&[("%GG", "%GG"), ("%", "%"), ("%2", "%2"), ("%ZZ", "%ZZ")]);
    }

    #[test]
    fn test_url_decode_mixed_content() {
        assert_decodes(&[
            ("path%2Fto%2Ffile.txt", "path/to/file.txt"),
            ("hello%20%26%20world", "hello & world"),
        ]);
    }

    #[test]
    fn test_url_decode_unicode() {
        // UTF-8 encoded characters (é is the two bytes C3 A9)
        assert_decodes(&[("%C3%A9", "é"), ("caf%C3%A9", "café")]);
    }

    // ===========================================
    // is_url_pattern_blocked tests
    // ===========================================

    #[test]
    fn test_url_pattern_blocked_simple() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php", ".env"]);

        assert_all_blocked(&config, &["/file.php", "/.env", "/path/to/file.php"]);
    }

    #[test]
    fn test_url_pattern_not_blocked() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php", ".env"]);

        assert_none_blocked(&config, &["/file.html", "/path/to/file.js", "/"]);
    }

    #[test]
    fn test_url_pattern_blocked_empty_patterns() {
        let config = TestConfig::new();

        assert_none_blocked(&config, &["/file.php", "/.env"]);
    }

    #[test]
    fn test_url_pattern_blocked_bypass_attempt() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php", ".env", "admin"]);

        // URL-encoded bypass attempts should still be blocked:
        // `.php`, `.env` and `admin` respectively once decoded.
        assert_all_blocked(&config, &["/.ph%70", "/%2eenv", "/adm%69n"]);
    }

    #[test]
    fn test_url_pattern_blocked_double_encoding_attempt() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php"]);

        // Only the single-encoded form is exercised here; it must be caught.
        assert_all_blocked(&config, &["/.ph%70"]);
    }

    #[test]
    fn test_url_pattern_blocked_case_sensitive() {
        let config = TestConfig::new().with_blocked_patterns(vec![".PHP"]);

        // Pattern matching is case-sensitive
        assert_all_blocked(&config, &["/file.PHP"]);
        assert_none_blocked(&config, &["/file.php"]); // Different case
    }

    #[test]
    fn test_url_pattern_blocked_partial_match() {
        let config = TestConfig::new().with_blocked_patterns(vec!["admin"]);

        // "/administrator" is blocked too because it contains "admin"
        assert_all_blocked(&config, &["/admin/panel", "/path/admin", "/administrator"]);
    }

    // ===========================================
    // is_method_blocked tests
    // ===========================================

    #[test]
    fn test_method_blocked() {
        let config = TestConfig::new().with_blocked_methods(vec!["TRACE", "CONNECT"]);

        for method in ["TRACE", "CONNECT"] {
            assert!(is_method_blocked(method, &config));
        }
    }

    #[test]
    fn test_method_not_blocked() {
        let config = TestConfig::new().with_blocked_methods(vec!["TRACE", "CONNECT"]);

        for method in ["GET", "POST", "PUT", "DELETE"] {
            assert!(!is_method_blocked(method, &config));
        }
    }

    #[test]
    fn test_method_blocked_empty_list() {
        let config = TestConfig::new();

        for method in ["TRACE", "GET"] {
            assert!(!is_method_blocked(method, &config));
        }
    }

    #[test]
    fn test_method_blocked_case_insensitive() {
        let config = TestConfig::new().with_blocked_methods(vec!["TRACE"]);

        for method in ["TRACE", "trace", "Trace"] {
            assert!(is_method_blocked(method, &config));
        }
    }

    // ===========================================
    // create_error_response tests
    // ===========================================

    #[test]
    fn test_create_error_response_status() {
        let cases = [
            (StatusCode::NOT_FOUND, "Not Found"),
            (StatusCode::FORBIDDEN, "Forbidden"),
            (StatusCode::TOO_MANY_REQUESTS, "Rate limited"),
        ];

        for (status, message) in cases {
            let response = create_error_response(status, message);
            assert_eq!(response.status(), status);
        }
    }

    #[test]
    fn test_create_error_response_content_type() {
        let response = create_error_response(StatusCode::NOT_FOUND, "Not Found");
        let content_type = response.headers().get("content-type").unwrap();
        assert_eq!(content_type, "text/plain");
    }

    #[tokio::test]
    async fn test_create_error_response_body() {
        let response = create_error_response(StatusCode::NOT_FOUND, "Resource not found");
        let body = body_of(response).await;
        assert_eq!(body, "Resource not found");
    }

    #[tokio::test]
    async fn test_create_error_response_empty_message() {
        let response = create_error_response(StatusCode::NO_CONTENT, "");
        let body = body_of(response).await;
        assert_eq!(body, "");
    }
}