wisegate_core/
request_handler.rs

1//! HTTP request handling and proxying.
2//!
3//! This module contains the core request handling logic for the reverse proxy,
4//! including IP validation, rate limiting, URL pattern blocking, and request forwarding.
5//!
6//! # Architecture
7//!
8//! The request handling flow:
9//! 1. Extract and validate client IP from proxy headers
10//! 2. Check if IP is blocked
11//! 3. Check for blocked URL patterns
12//! 4. Check for blocked HTTP methods
13//! 5. Apply rate limiting
14//! 6. Forward request to upstream service
15//!
16//! # Connection Pooling
17//!
18//! The module accepts a shared [`reqwest::Client`] for HTTP connection pooling,
19//! which should be configured by the caller with appropriate timeouts.
20
21use http_body_util::{BodyExt, Full};
22use hyper::{Request, Response, StatusCode, body::Incoming};
23use std::convert::Infallible;
24use std::sync::Arc;
25
26use crate::types::{ConfigProvider, RateLimiter};
27use crate::{headers, ip_filter, rate_limiter};
28
29/// Handles an incoming HTTP request through the proxy pipeline.
30///
31/// This is the main entry point for request processing. It performs:
32/// - Client IP extraction and validation from proxy headers
33/// - IP blocking checks
34/// - URL pattern blocking (e.g., `.php`, `.env` files)
35/// - HTTP method blocking (e.g., `PUT`, `DELETE`)
36/// - Rate limiting per client IP
37/// - Request forwarding to the upstream service
38///
39/// # Arguments
40///
41/// * `req` - The incoming HTTP request
42/// * `forward_host` - The upstream host to forward requests to
43/// * `forward_port` - The upstream port to forward requests to
44/// * `limiter` - The shared rate limiter instance
45/// * `config` - Configuration provider for all settings
46/// * `http_client` - HTTP client for forwarding requests (with connection pooling)
47///
48/// # Returns
49///
50/// Always returns `Ok` with either:
51/// - A successful proxied response from upstream
52/// - An error response (403, 404, 405, 429, 502, etc.)
53///
54/// # Security
55///
56/// In strict mode (when proxy allowlist is configured), requests
57/// without valid proxy headers are rejected with 403 Forbidden.
58pub async fn handle_request<C: ConfigProvider>(
59    req: Request<Incoming>,
60    forward_host: String,
61    forward_port: u16,
62    limiter: RateLimiter,
63    config: Arc<C>,
64    http_client: reqwest::Client,
65) -> Result<Response<Full<bytes::Bytes>>, Infallible> {
66    // Extract and validate real client IP
67    let real_client_ip =
68        match ip_filter::extract_and_validate_real_ip(req.headers(), config.as_ref()) {
69            Some(ip) => ip,
70            None => {
71                // In permissive mode (no allowlist configured), we couldn't extract IP from headers
72                // Use placeholder IP and continue with non-IP-based security features only
73                if config.allowed_proxy_ips().is_none() {
74                    "unknown".to_string()
75                } else {
76                    // If allowlist is configured but validation failed, reject the request
77                    return Ok(create_error_response(
78                        StatusCode::FORBIDDEN,
79                        "Invalid request: missing or invalid proxy headers",
80                    ));
81                }
82            }
83        };
84
85    // Check if IP is blocked (skip if IP is unknown)
86    if real_client_ip != "unknown" && ip_filter::is_ip_blocked(&real_client_ip, config.as_ref()) {
87        return Ok(create_error_response(
88            StatusCode::FORBIDDEN,
89            "IP address is blocked",
90        ));
91    }
92
93    // Check for blocked URL patterns
94    let request_path = req.uri().path();
95    if is_url_pattern_blocked(request_path, config.as_ref()) {
96        return Ok(create_error_response(StatusCode::NOT_FOUND, "Not Found"));
97    }
98
99    // Check for blocked HTTP methods
100    let request_method = req.method().as_str();
101    if is_method_blocked(request_method, config.as_ref()) {
102        return Ok(create_error_response(
103            StatusCode::METHOD_NOT_ALLOWED,
104            "HTTP method not allowed",
105        ));
106    }
107
108    // Apply rate limiting (skip if IP is unknown)
109    if real_client_ip != "unknown"
110        && !rate_limiter::check_rate_limit(&limiter, &real_client_ip, config.as_ref()).await
111    {
112        return Ok(create_error_response(
113            StatusCode::TOO_MANY_REQUESTS,
114            "Rate limit exceeded",
115        ));
116    }
117
118    // Add X-Real-IP header for upstream service (only if we have a real IP)
119    let mut req = req;
120    if real_client_ip != "unknown"
121        && let Ok(header_value) = real_client_ip.parse()
122    {
123        req.headers_mut().insert("x-real-ip", header_value);
124    }
125
126    // Forward the request
127    forward_request(
128        req,
129        &forward_host,
130        forward_port,
131        config.as_ref(),
132        &http_client,
133    )
134    .await
135}
136
137/// Forward request to upstream service
138async fn forward_request(
139    req: Request<Incoming>,
140    host: &str,
141    port: u16,
142    config: &impl ConfigProvider,
143    http_client: &reqwest::Client,
144) -> Result<Response<Full<bytes::Bytes>>, Infallible> {
145    let proxy_config = config.proxy_config();
146    let (parts, body) = req.into_parts();
147    let body_bytes = match body.collect().await {
148        Ok(bytes) => {
149            let collected_bytes = bytes.to_bytes();
150
151            // Check body size limit
152            if proxy_config.max_body_size > 0 && collected_bytes.len() > proxy_config.max_body_size
153            {
154                return Ok(create_error_response(
155                    StatusCode::PAYLOAD_TOO_LARGE,
156                    "Request body too large",
157                ));
158            }
159
160            collected_bytes
161        }
162        Err(_) => {
163            return Ok(create_error_response(
164                StatusCode::BAD_REQUEST,
165                "Failed to read request body",
166            ));
167        }
168    };
169
170    forward_with_reqwest(parts, body_bytes, host, port, http_client).await
171}
172
173/// Shared forwarding logic using reqwest with connection pooling
174async fn forward_with_reqwest(
175    parts: hyper::http::request::Parts,
176    body_bytes: bytes::Bytes,
177    host: &str,
178    port: u16,
179    client: &reqwest::Client,
180) -> Result<Response<Full<bytes::Bytes>>, Infallible> {
181    // Construct destination URI
182    let destination_uri = format!(
183        "http://{}:{}{}",
184        host,
185        port,
186        parts.uri.path_and_query().map_or("", |pq| pq.as_str())
187    );
188
189    // Build the request with method support for all HTTP verbs
190    let mut req_builder = match parts.method.as_str() {
191        "GET" => client.get(&destination_uri),
192        "POST" => client.post(&destination_uri),
193        "PUT" => client.put(&destination_uri),
194        "DELETE" => client.delete(&destination_uri),
195        "HEAD" => client.head(&destination_uri),
196        "PATCH" => client.patch(&destination_uri),
197        "OPTIONS" => client.request(reqwest::Method::OPTIONS, &destination_uri),
198        method => {
199            // Try to parse custom methods
200            match reqwest::Method::from_bytes(method.as_bytes()) {
201                Ok(custom_method) => client.request(custom_method, &destination_uri),
202                Err(_) => {
203                    return Ok(create_error_response(
204                        StatusCode::METHOD_NOT_ALLOWED,
205                        "HTTP method not supported",
206                    ));
207                }
208            }
209        }
210    };
211
212    // Add headers (excluding host and content-length)
213    for (name, value) in parts.headers.iter() {
214        if name != "host"
215            && name != "content-length"
216            && let Ok(header_value) = value.to_str()
217        {
218            req_builder = req_builder.header(name.as_str(), header_value);
219        }
220    }
221
222    // Add body if not empty
223    if !body_bytes.is_empty() {
224        req_builder = req_builder.body(body_bytes.to_vec());
225    }
226
227    // Send request
228    match req_builder.send().await {
229        Ok(response) => {
230            let status = response.status();
231            let resp_headers = response.headers().clone();
232
233            match response.bytes().await {
234                Ok(body_bytes) => {
235                    let mut hyper_response = match Response::builder()
236                        .status(status.as_u16())
237                        .body(Full::new(body_bytes))
238                    {
239                        Ok(resp) => resp,
240                        Err(_) => {
241                            return Ok(create_error_response(
242                                StatusCode::INTERNAL_SERVER_ERROR,
243                                "Failed to build response",
244                            ));
245                        }
246                    };
247
248                    // Copy response headers (skip hop-by-hop headers)
249                    for (name, value) in resp_headers.iter() {
250                        let header_name = name.as_str().to_lowercase();
251                        // Skip hop-by-hop headers that shouldn't be forwarded
252                        if !headers::is_hop_by_hop(&header_name)
253                            && let (Ok(hyper_name), Ok(hyper_value)) = (
254                                hyper::header::HeaderName::from_bytes(name.as_str().as_bytes()),
255                                hyper::header::HeaderValue::from_bytes(value.as_bytes()),
256                            )
257                        {
258                            hyper_response.headers_mut().insert(hyper_name, hyper_value);
259                        }
260                    }
261
262                    Ok(hyper_response)
263                }
264                Err(_) => Ok(create_error_response(
265                    StatusCode::BAD_GATEWAY,
266                    "Failed to read response body",
267                )),
268            }
269        }
270        Err(err) => {
271            // More specific error handling
272            if err.is_timeout() {
273                Ok(create_error_response(
274                    StatusCode::GATEWAY_TIMEOUT,
275                    "Upstream service timeout",
276                ))
277            } else if err.is_connect() {
278                Ok(create_error_response(
279                    StatusCode::BAD_GATEWAY,
280                    "Could not connect to upstream service",
281                ))
282            } else {
283                Ok(create_error_response(
284                    StatusCode::BAD_GATEWAY,
285                    "Upstream service error",
286                ))
287            }
288        }
289    }
290}
291
292/// Creates a standardized error response.
293///
294/// Builds an HTTP response with the given status code and plain text message.
295/// Falls back to a minimal 500 response if building fails (should never happen
296/// with valid StatusCode).
297///
298/// # Arguments
299///
300/// * `status` - The HTTP status code for the response
301/// * `message` - The plain text error message body
302///
303/// # Returns
304///
305/// An HTTP response with `content-type: text/plain` header.
306///
307/// # Example
308///
309/// ```
310/// use wisegate_core::request_handler::create_error_response;
311/// use hyper::StatusCode;
312///
313/// let response = create_error_response(StatusCode::NOT_FOUND, "Resource not found");
314/// assert_eq!(response.status(), StatusCode::NOT_FOUND);
315/// ```
316pub fn create_error_response(status: StatusCode, message: &str) -> Response<Full<bytes::Bytes>> {
317    Response::builder()
318        .status(status)
319        .header("content-type", "text/plain")
320        .body(Full::new(bytes::Bytes::from(message.to_string())))
321        .unwrap_or_else(|_| {
322            // Fallback response if builder fails (extremely unlikely)
323            Response::new(Full::new(bytes::Bytes::from("Internal Server Error")))
324        })
325}
326
327/// Check if URL path contains any blocked patterns
328/// Decodes URL-encoded characters to prevent bypass via encoding (e.g., .ph%70 for .php)
329fn is_url_pattern_blocked(path: &str, config: &impl ConfigProvider) -> bool {
330    let blocked_patterns = config.blocked_patterns();
331    if blocked_patterns.is_empty() {
332        return false;
333    }
334
335    // Decode URL-encoded path to prevent bypass attacks
336    let decoded_path = url_decode(path);
337
338    // Check against both original and decoded path
339    blocked_patterns
340        .iter()
341        .any(|pattern| path.contains(pattern) || decoded_path.contains(pattern))
342}
343
/// Decodes a percent-encoded string.
///
/// Every `%XY` sequence, where `X` and `Y` are ASCII hex digits, is replaced
/// by the byte it encodes. This handles common bypass attempts like `%2e` for
/// `.` or `%70` for `p`.
///
/// A `%` that is not followed by two hex digits is kept as a literal `%`, and
/// decoding resumes with the very next character. An invalid escape therefore
/// cannot hide a valid one that overlaps it: `%%2e` decodes to `%.`.
///
/// Decoding is done on bytes, so multi-byte UTF-8 sequences such as `%C3%A9`
/// are reassembled correctly. Byte sequences that are not valid UTF-8 are
/// replaced with the Unicode replacement character.
fn url_decode(input: &str) -> String {
    // Value of a single ASCII hex digit. Signs, whitespace and non-ASCII
    // bytes are rejected.
    fn hex_value(byte: u8) -> Option<u8> {
        match byte {
            b'0'..=b'9' => Some(byte - b'0'),
            b'a'..=b'f' => Some(byte - b'a' + 10),
            b'A'..=b'F' => Some(byte - b'A' + 10),
            _ => None,
        }
    }

    let src = input.as_bytes();
    let mut decoded = Vec::with_capacity(src.len());
    let mut pos = 0;

    while pos < src.len() {
        let byte = src[pos];
        if byte == b'%' {
            let high = src.get(pos + 1).copied().and_then(hex_value);
            let low = src.get(pos + 2).copied().and_then(hex_value);
            if let (Some(high), Some(low)) = (high, low) {
                decoded.push((high << 4) | low);
                pos += 3;
                continue;
            }
            // Not a valid escape: fall through and keep the '%' literally,
            // without consuming the characters that follow it.
        }
        decoded.push(byte);
        pos += 1;
    }

    // Convert bytes to string, replacing invalid UTF-8 with replacement character
    String::from_utf8_lossy(&decoded).into_owned()
}
374
375/// Check if HTTP method is blocked
376fn is_method_blocked(method: &str, config: &impl ConfigProvider) -> bool {
377    let blocked_methods = config.blocked_methods();
378    blocked_methods
379        .iter()
380        .any(|blocked_method| blocked_method == &method.to_uppercase())
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::TestConfig;
    use http_body_util::BodyExt;

    /// Asserts that every `(encoded, expected)` pair decodes as expected.
    fn assert_decodes(cases: &[(&str, &str)]) {
        for &(encoded, expected) in cases {
            assert_eq!(url_decode(encoded), expected, "input: {:?}", encoded);
        }
    }

    // ===========================================
    // url_decode tests
    // ===========================================

    #[test]
    fn test_url_decode_no_encoding() {
        assert_decodes(&[
            ("/path/to/file", "/path/to/file"),
            ("hello", "hello"),
            ("", ""),
        ]);
    }

    #[test]
    fn test_url_decode_simple_encoding() {
        assert_decodes(&[("%20", " "), ("hello%20world", "hello world"), ("%2F", "/")]);
    }

    #[test]
    fn test_url_decode_dot_encoding() {
        // Encoded dots, in both hex cases, are a common bypass attempt.
        assert_decodes(&[("%2e", "."), ("%2E", "."), (".%2ephp", "..php")]);
    }

    #[test]
    fn test_url_decode_php_bypass() {
        // Different ways of spelling ".php" to dodge the pattern filter.
        assert_decodes(&[(".ph%70", ".php"), ("%2ephp", ".php"), (".%70%68%70", ".php")]);
    }

    #[test]
    fn test_url_decode_env_bypass() {
        // Different ways of spelling ".env" to dodge the pattern filter.
        assert_decodes(&[(".%65nv", ".env"), ("%2eenv", ".env"), ("%2e%65%6e%76", ".env")]);
    }

    #[test]
    fn test_url_decode_multiple_encodings() {
        assert_decodes(&[("%2F%2e%2e%2Fetc%2Fpasswd", "/../etc/passwd")]);
    }

    #[test]
    fn test_url_decode_invalid_hex() {
        // Malformed escapes must come back unchanged.
        for malformed in ["%GG", "%", "%2", "%ZZ"] {
            assert_eq!(url_decode(malformed), malformed);
        }
    }

    #[test]
    fn test_url_decode_mixed_content() {
        assert_decodes(&[
            ("path%2Fto%2Ffile.txt", "path/to/file.txt"),
            ("hello%20%26%20world", "hello & world"),
        ]);
    }

    #[test]
    fn test_url_decode_unicode() {
        // Percent-encoded multi-byte UTF-8 (é is C3 A9).
        assert_decodes(&[("%C3%A9", "é"), ("caf%C3%A9", "café")]);
    }

    // ===========================================
    // is_url_pattern_blocked tests
    // ===========================================

    #[test]
    fn test_url_pattern_blocked_simple() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php", ".env"]);

        for path in ["/file.php", "/.env", "/path/to/file.php"] {
            assert!(
                is_url_pattern_blocked(path, &config),
                "expected {:?} to be blocked",
                path
            );
        }
    }

    #[test]
    fn test_url_pattern_not_blocked() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php", ".env"]);

        for path in ["/file.html", "/path/to/file.js", "/"] {
            assert!(
                !is_url_pattern_blocked(path, &config),
                "expected {:?} to be allowed",
                path
            );
        }
    }

    #[test]
    fn test_url_pattern_blocked_empty_patterns() {
        // With no patterns configured nothing is blocked.
        let config = TestConfig::new();

        for path in ["/file.php", "/.env"] {
            assert!(
                !is_url_pattern_blocked(path, &config),
                "expected {:?} to be allowed",
                path
            );
        }
    }

    #[test]
    fn test_url_pattern_blocked_bypass_attempt() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php", ".env", "admin"]);

        // Percent-encoded spellings of .php, .env and admin must still be blocked.
        for path in ["/.ph%70", "/%2eenv", "/adm%69n"] {
            assert!(
                is_url_pattern_blocked(path, &config),
                "expected {:?} to be blocked",
                path
            );
        }
    }

    #[test]
    fn test_url_pattern_blocked_double_encoding_attempt() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php"]);

        // Only a singly-encoded path is exercised here; it must be caught.
        let single_encoded = "/.ph%70";
        assert!(is_url_pattern_blocked(single_encoded, &config));
    }

    #[test]
    fn test_url_pattern_blocked_case_sensitive() {
        let config = TestConfig::new().with_blocked_patterns(vec![".PHP"]);

        // Matching is case-sensitive: only the exact-case spelling is blocked.
        let same_case = is_url_pattern_blocked("/file.PHP", &config);
        let other_case = is_url_pattern_blocked("/file.php", &config);
        assert!(same_case);
        assert!(!other_case);
    }

    #[test]
    fn test_url_pattern_blocked_partial_match() {
        let config = TestConfig::new().with_blocked_patterns(vec!["admin"]);

        // Substring matching: "/administrator" contains "admin" too.
        for path in ["/admin/panel", "/path/admin", "/administrator"] {
            assert!(
                is_url_pattern_blocked(path, &config),
                "expected {:?} to be blocked",
                path
            );
        }
    }

    // ===========================================
    // is_method_blocked tests
    // ===========================================

    #[test]
    fn test_method_blocked() {
        let config = TestConfig::new().with_blocked_methods(vec!["TRACE", "CONNECT"]);

        for method in ["TRACE", "CONNECT"] {
            assert!(
                is_method_blocked(method, &config),
                "expected {} to be blocked",
                method
            );
        }
    }

    #[test]
    fn test_method_not_blocked() {
        let config = TestConfig::new().with_blocked_methods(vec!["TRACE", "CONNECT"]);

        for method in ["GET", "POST", "PUT", "DELETE"] {
            assert!(
                !is_method_blocked(method, &config),
                "expected {} to be allowed",
                method
            );
        }
    }

    #[test]
    fn test_method_blocked_empty_list() {
        let config = TestConfig::new();

        for method in ["TRACE", "GET"] {
            assert!(
                !is_method_blocked(method, &config),
                "expected {} to be allowed",
                method
            );
        }
    }

    #[test]
    fn test_method_blocked_case_insensitive() {
        let config = TestConfig::new().with_blocked_methods(vec!["TRACE"]);

        for spelling in ["TRACE", "trace", "Trace"] {
            assert!(
                is_method_blocked(spelling, &config),
                "expected {} to be blocked",
                spelling
            );
        }
    }

    // ===========================================
    // create_error_response tests
    // ===========================================

    #[test]
    fn test_create_error_response_status() {
        let cases = [
            (StatusCode::NOT_FOUND, "Not Found"),
            (StatusCode::FORBIDDEN, "Forbidden"),
            (StatusCode::TOO_MANY_REQUESTS, "Rate limited"),
        ];

        for (status, message) in cases {
            let response = create_error_response(status, message);
            assert_eq!(response.status(), status);
        }
    }

    #[test]
    fn test_create_error_response_content_type() {
        let response = create_error_response(StatusCode::NOT_FOUND, "Not Found");
        let content_type = response
            .headers()
            .get("content-type")
            .and_then(|value| value.to_str().ok());
        assert_eq!(content_type, Some("text/plain"));
    }

    #[tokio::test]
    async fn test_create_error_response_body() {
        let body = create_error_response(StatusCode::NOT_FOUND, "Resource not found")
            .into_body()
            .collect()
            .await
            .unwrap()
            .to_bytes();
        assert_eq!(&body[..], b"Resource not found");
    }

    #[tokio::test]
    async fn test_create_error_response_empty_message() {
        let body = create_error_response(StatusCode::NO_CONTENT, "")
            .into_body()
            .collect()
            .await
            .unwrap()
            .to_bytes();
        assert!(body.is_empty());
    }
}