Skip to main content

wisegate_core/
request_handler.rs

//! HTTP request handling and proxying.
//!
//! This module contains the core request handling logic for the reverse proxy,
//! including IP validation, URL pattern and HTTP method blocking, authentication,
//! rate limiting, and request forwarding.
//!
//! # Architecture
//!
//! The request handling flow:
//! 1. Extract and validate client IP from proxy headers
//! 2. Check if IP is blocked
//! 3. Check for blocked URL patterns
//! 4. Check for blocked HTTP methods
//! 5. Enforce per-IP rate limiting (when the client IP is known) and, when
//!    enabled, authentication (Basic Auth and/or Bearer token)
//! 6. Forward request to upstream service
//!
//! # Connection Pooling
//!
//! The module accepts a shared [`reqwest::Client`] for HTTP connection pooling,
//! which should be configured by the caller with appropriate timeouts.
20
21use http_body_util::{BodyExt, Full};
22use hyper::{Request, Response, StatusCode, body::Incoming};
23use std::convert::Infallible;
24use std::sync::Arc;
25
26use crate::error::WiseGateError;
27use crate::types::{ConfigProvider, RateLimiter};
28use crate::{auth, headers, ip_filter, rate_limiter};
29
30/// Handles an incoming HTTP request through the proxy pipeline.
31///
32/// This is the main entry point for request processing. It performs:
33/// - Client IP extraction and validation from proxy headers
34/// - IP blocking checks
35/// - URL pattern blocking (e.g., `.php`, `.env` files)
36/// - HTTP method blocking (e.g., `PUT`, `DELETE`)
37/// - Rate limiting per client IP
38/// - Request forwarding to the upstream service
39///
40/// # Arguments
41///
42/// * `req` - The incoming HTTP request
43/// * `forward_host` - The upstream host to forward requests to
44/// * `forward_port` - The upstream port to forward requests to
45/// * `limiter` - The shared rate limiter instance
46/// * `config` - Configuration provider for all settings
47/// * `http_client` - HTTP client for forwarding requests (with connection pooling)
48///
49/// # Returns
50///
51/// Always returns `Ok` with either:
52/// - A successful proxied response from upstream
53/// - An error response (403, 404, 405, 429, 502, etc.)
54///
55/// # Runtime
56///
57/// This is an async function backed by `reqwest`/`tokio`; it must be awaited
58/// from inside a Tokio runtime (`#[tokio::main]`, `Runtime::block_on`, etc.).
59///
60/// # Security
61///
62/// * **Strict mode** (proxy allowlist configured): both `X-Forwarded-For`
63///   and `Forwarded` headers must be present. The proxy IP is taken from the
64///   `Forwarded` header's `by=` field and matched against the allowlist; the
65///   client IP is taken from the last entry of `X-Forwarded-For`. Requests
66///   missing either header, or whose proxy is not in the allowlist, are
67///   rejected with `400 Bad Request`.
68/// * **Permissive mode** (no allowlist): if a request supplies an
69///   `X-Forwarded-For` or `Forwarded` header, the parsed IP is trusted as the
70///   real client IP and re-emitted to the upstream as `X-Real-IP`. An
71///   attacker can spoof this value by sending fake headers — only enable
72///   permissive mode when the proxy sits behind another layer that strips or
73///   normalises these headers.
74/// * Any client-supplied `X-Real-IP` header is stripped before processing.
75/// * The `Authorization` header is stripped before forwarding whenever
76///   wisegate performed authentication (see
77///   [`crate::AuthenticationProvider::forward_authorization_header`]).
78pub async fn handle_request<C: ConfigProvider>(
79    req: Request<Incoming>,
80    forward_host: Arc<str>,
81    forward_port: u16,
82    limiter: RateLimiter,
83    config: Arc<C>,
84    http_client: reqwest::Client,
85) -> Result<Response<Full<bytes::Bytes>>, Infallible> {
86    // Extract and validate real client IP
87    let real_client_ip: Option<String> =
88        match ip_filter::extract_and_validate_real_ip(req.headers(), config.as_ref()) {
89            Some(ip) => Some(ip),
90            None => {
91                if config.allowed_proxy_ips().is_none() {
92                    // Permissive mode: continue without IP-based security
93                    None
94                } else {
95                    // Strict mode: reject when proxy validation fails
96                    let err = WiseGateError::InvalidIp("missing or invalid proxy headers".into());
97                    return Ok(create_error_response(err.status_code(), err.user_message()));
98                }
99            }
100        };
101
102    // Check if IP is blocked
103    if let Some(ref ip) = real_client_ip
104        && ip_filter::is_ip_blocked(ip, config.as_ref())
105    {
106        let err = WiseGateError::IpBlocked(ip.clone());
107        return Ok(create_error_response(err.status_code(), err.user_message()));
108    }
109
110    // Check for blocked URL patterns
111    let request_path = req.uri().path();
112    if is_url_pattern_blocked(request_path, config.as_ref()) {
113        let err = WiseGateError::PatternBlocked(request_path.to_string());
114        return Ok(create_error_response(err.status_code(), err.user_message()));
115    }
116
117    // Check for blocked HTTP methods
118    let request_method = req.method().as_str();
119    if is_method_blocked(request_method, config.as_ref()) {
120        let err = WiseGateError::MethodBlocked(request_method.to_string());
121        return Ok(create_error_response(err.status_code(), err.user_message()));
122    }
123
124    // Check Authentication if enabled (Basic Auth and/or Bearer Token)
125    // Logic: if both are configured, either one passing is sufficient
126    if config.is_auth_enabled() {
127        let auth_header = req
128            .headers()
129            .get(headers::AUTHORIZATION)
130            .and_then(|v| v.to_str().ok());
131
132        let basic_auth_enabled = config.is_basic_auth_enabled();
133        let bearer_auth_enabled = config.is_bearer_auth_enabled();
134
135        let basic_auth_passed =
136            basic_auth_enabled && auth::check_basic_auth(auth_header, config.auth_credentials());
137        let bearer_auth_passed =
138            bearer_auth_enabled && auth::check_bearer_token(auth_header, config.bearer_token());
139
140        // Authentication fails if neither method passed
141        if !basic_auth_passed && !bearer_auth_passed {
142            return Ok(create_unauthorized_response(config.auth_realm()));
143        }
144    }
145
146    // Apply rate limiting (only when IP is known)
147    if let Some(ref ip) = real_client_ip
148        && !rate_limiter::check_rate_limit(&limiter, ip, config.as_ref()).await
149    {
150        let err = WiseGateError::RateLimitExceeded(ip.clone());
151        return Ok(create_error_response(err.status_code(), err.user_message()));
152    }
153
154    // Strip any client-supplied X-Real-IP first to prevent spoofing, then insert
155    // the validated one (if any) so upstream sees only wisegate's value.
156    let mut req = req;
157    req.headers_mut().remove(headers::X_REAL_IP);
158    if let Some(ref ip) = real_client_ip
159        && let Ok(header_value) = ip.parse()
160    {
161        req.headers_mut().insert(headers::X_REAL_IP, header_value);
162    }
163
164    // Forward the request
165    forward_request(
166        req,
167        &forward_host,
168        forward_port,
169        config.as_ref(),
170        &http_client,
171    )
172    .await
173}
174
175/// Forward request to upstream service
176async fn forward_request(
177    req: Request<Incoming>,
178    host: &str,
179    port: u16,
180    config: &impl ConfigProvider,
181    http_client: &reqwest::Client,
182) -> Result<Response<Full<bytes::Bytes>>, Infallible> {
183    let proxy_config = config.proxy_config();
184    let (parts, body) = req.into_parts();
185
186    // Early rejection based on Content-Length header to prevent memory exhaustion
187    if proxy_config.max_body_size > 0
188        && let Some(content_length) = parts
189            .headers
190            .get(headers::CONTENT_LENGTH)
191            .and_then(|v| v.to_str().ok())
192            .and_then(|v| v.parse::<usize>().ok())
193        && content_length > proxy_config.max_body_size
194    {
195        let err = WiseGateError::BodyTooLarge {
196            size: content_length,
197            max: proxy_config.max_body_size,
198        };
199        return Ok(create_error_response(err.status_code(), err.user_message()));
200    }
201
202    let body_bytes = match body.collect().await {
203        Ok(bytes) => {
204            let collected_bytes = bytes.to_bytes();
205
206            // Check actual body size (Content-Length may be absent or inaccurate)
207            if proxy_config.max_body_size > 0 && collected_bytes.len() > proxy_config.max_body_size
208            {
209                let err = WiseGateError::BodyTooLarge {
210                    size: collected_bytes.len(),
211                    max: proxy_config.max_body_size,
212                };
213                return Ok(create_error_response(err.status_code(), err.user_message()));
214            }
215
216            collected_bytes
217        }
218        Err(e) => {
219            let err = WiseGateError::BodyReadError(e.to_string());
220            return Ok(create_error_response(err.status_code(), err.user_message()));
221        }
222    };
223
224    let strip_auth = config.is_auth_enabled() && !config.forward_authorization_header();
225    forward_with_reqwest(parts, body_bytes, host, port, http_client, strip_auth).await
226}
227
228/// Shared forwarding logic using reqwest with connection pooling
229async fn forward_with_reqwest(
230    parts: hyper::http::request::Parts,
231    body_bytes: bytes::Bytes,
232    host: &str,
233    port: u16,
234    client: &reqwest::Client,
235    strip_auth: bool,
236) -> Result<Response<Full<bytes::Bytes>>, Infallible> {
237    // Construct destination URI
238    let destination_uri = format!(
239        "http://{}:{}{}",
240        host,
241        port,
242        parts.uri.path_and_query().map_or("", |pq| pq.as_str())
243    );
244
245    // Build the request with the original HTTP method
246    let method = match reqwest::Method::from_bytes(parts.method.as_str().as_bytes()) {
247        Ok(m) => m,
248        Err(_) => {
249            let err =
250                WiseGateError::MethodBlocked(format!("{} (unsupported)", parts.method.as_str()));
251            return Ok(create_error_response(err.status_code(), err.user_message()));
252        }
253    };
254    let mut req_builder = client.request(method, &destination_uri);
255
256    // Add headers (excluding host, content-length, and hop-by-hop headers per RFC 7230).
257    // When `strip_auth` is set, also drop Authorization so the upstream cannot
258    // re-validate or log credentials wisegate just consumed.
259    for (name, value) in parts.headers.iter() {
260        if name != headers::HOST
261            && name != headers::CONTENT_LENGTH
262            && !(strip_auth && name == headers::AUTHORIZATION)
263            && !headers::is_hop_by_hop(name.as_str())
264            && let Ok(header_value) = value.to_str()
265        {
266            req_builder = req_builder.header(name.as_str(), header_value);
267        }
268    }
269
270    // Add body if not empty
271    if !body_bytes.is_empty() {
272        req_builder = req_builder.body(body_bytes);
273    }
274
275    // Send request
276    match req_builder.send().await {
277        Ok(response) => {
278            let status = response.status();
279            let resp_headers = response.headers().clone();
280
281            match response.bytes().await {
282                Ok(body_bytes) => {
283                    let mut hyper_response = match Response::builder()
284                        .status(status.as_u16())
285                        .body(Full::new(body_bytes))
286                    {
287                        Ok(resp) => resp,
288                        Err(e) => {
289                            let err = WiseGateError::ProxyError(format!(
290                                "Failed to build response: {}",
291                                e
292                            ));
293                            return Ok(create_error_response(
294                                err.status_code(),
295                                err.user_message(),
296                            ));
297                        }
298                    };
299
300                    // Copy response headers (skip hop-by-hop headers)
301                    for (name, value) in resp_headers.iter() {
302                        // Skip hop-by-hop headers that shouldn't be forwarded
303                        if !headers::is_hop_by_hop(name.as_str())
304                            && let (Ok(hyper_name), Ok(hyper_value)) = (
305                                hyper::header::HeaderName::from_bytes(name.as_str().as_bytes()),
306                                hyper::header::HeaderValue::from_bytes(value.as_bytes()),
307                            )
308                        {
309                            hyper_response.headers_mut().insert(hyper_name, hyper_value);
310                        }
311                    }
312
313                    Ok(hyper_response)
314                }
315                Err(e) => {
316                    let err = WiseGateError::BodyReadError(format!("response: {}", e));
317                    Ok(create_error_response(err.status_code(), err.user_message()))
318                }
319            }
320        }
321        Err(err) => {
322            // More specific error handling using WiseGateError
323            let wise_err = if err.is_timeout() {
324                WiseGateError::UpstreamTimeout(err.to_string())
325            } else if err.is_connect() {
326                WiseGateError::UpstreamConnectionFailed(err.to_string())
327            } else {
328                WiseGateError::ProxyError(err.to_string())
329            };
330            Ok(create_error_response(
331                wise_err.status_code(),
332                wise_err.user_message(),
333            ))
334        }
335    }
336}
337
338/// Creates a standardized error response.
339///
340/// Builds an HTTP response with the given status code and plain text message.
341/// Falls back to a minimal 500 response if building fails (should never happen
342/// with valid StatusCode).
343///
344/// # Arguments
345///
346/// * `status` - The HTTP status code for the response
347/// * `message` - The plain text error message body
348///
349/// # Returns
350///
351/// An HTTP response with `content-type: text/plain` header.
352///
353/// # Example
354///
355/// ```
356/// use wisegate_core::request_handler::create_error_response;
357/// use hyper::StatusCode;
358///
359/// let response = create_error_response(StatusCode::NOT_FOUND, "Resource not found");
360/// assert_eq!(response.status(), StatusCode::NOT_FOUND);
361/// ```
362pub fn create_error_response(status: StatusCode, message: &str) -> Response<Full<bytes::Bytes>> {
363    Response::builder()
364        .status(status)
365        .header(headers::CONTENT_TYPE, "text/plain")
366        .body(Full::new(bytes::Bytes::from(message.to_string())))
367        .unwrap_or_else(|_| {
368            // Fallback response if builder fails (extremely unlikely)
369            Response::new(Full::new(bytes::Bytes::from("Internal Server Error")))
370        })
371}
372
373/// Creates a 401 Unauthorized response with WWW-Authenticate header.
374///
375/// Used when Basic Authentication is enabled and the request is not authenticated
376/// or has invalid credentials.
377///
378/// # Arguments
379///
380/// * `realm` - The authentication realm to display in the browser dialog
381///
382/// # Returns
383///
384/// An HTTP 401 response with `WWW-Authenticate: Basic realm="..."` header.
385pub fn create_unauthorized_response(realm: &str) -> Response<Full<bytes::Bytes>> {
386    // Sanitize realm: escape backslashes and quotes per RFC 7235 quoted-string
387    let sanitized_realm = realm.replace('\\', "\\\\").replace('"', "\\\"");
388    Response::builder()
389        .status(StatusCode::UNAUTHORIZED)
390        .header(
391            headers::WWW_AUTHENTICATE,
392            format!("Basic realm=\"{}\"", sanitized_realm),
393        )
394        .header(headers::CONTENT_TYPE, "text/plain")
395        .body(Full::new(bytes::Bytes::from("401 Unauthorized")))
396        .unwrap_or_else(|_| Response::new(Full::new(bytes::Bytes::from("401 Unauthorized"))))
397}
398
399/// Check if URL path contains any blocked patterns
400/// Decodes URL-encoded characters to prevent bypass via encoding (e.g., .ph%70 for .php)
401fn is_url_pattern_blocked(path: &str, config: &impl ConfigProvider) -> bool {
402    let blocked_patterns = config.blocked_patterns();
403    if blocked_patterns.is_empty() {
404        return false;
405    }
406
407    // Decode URL-encoded path to prevent bypass attacks
408    let decoded_path = url_decode(path);
409    let has_encoding = decoded_path != path;
410
411    // Case-insensitive matching to prevent bypass via case variation
412    let path_lower = path.to_lowercase();
413    // Only allocate decoded lowercase if URL actually contained percent-encoding
414    let decoded_lower = if has_encoding {
415        Some(decoded_path.to_lowercase())
416    } else {
417        None
418    };
419
420    // Patterns are expected to be pre-normalized to lowercase
421    blocked_patterns.iter().any(|pattern| {
422        path_lower.contains(pattern.as_str())
423            || decoded_lower
424                .as_ref()
425                .is_some_and(|dl| dl.contains(pattern.as_str()))
426    })
427}
428
429/// Decode URL-encoded string (percent-encoding)
430/// Handles common bypass attempts like %2e for '.', %70 for 'p', etc.
431/// Properly handles multi-byte UTF-8 sequences.
432fn url_decode(input: &str) -> String {
433    let mut bytes = Vec::with_capacity(input.len());
434    let input_bytes = input.as_bytes();
435    let mut i = 0;
436
437    while i < input_bytes.len() {
438        if input_bytes[i] == b'%' && i + 2 < input_bytes.len() {
439            // Try to decode two hex digits without allocating
440            let hi = hex_digit(input_bytes[i + 1]);
441            let lo = hex_digit(input_bytes[i + 2]);
442            if let (Some(h), Some(l)) = (hi, lo) {
443                bytes.push(h << 4 | l);
444                i += 3;
445                continue;
446            }
447        }
448        bytes.push(input_bytes[i]);
449        i += 1;
450    }
451
452    // Convert bytes to string, replacing invalid UTF-8 with replacement character
453    String::from_utf8_lossy(&bytes).into_owned()
454}
455
/// Returns the value (0-15) of an ASCII hexadecimal digit.
///
/// Accepts both upper- and lower-case letters; returns `None` for any other byte.
fn hex_digit(b: u8) -> Option<u8> {
    // With radix 16, `char::to_digit` only recognises ASCII `0-9`, `a-f` and
    // `A-F`, so bytes >= 0x80 (mapped to U+0080..=U+00FF) still yield `None`.
    // The digit is at most 15, so narrowing it to `u8` cannot truncate.
    char::from(b).to_digit(16).map(|digit| digit as u8)
}
465
466/// Check if HTTP method is blocked
467fn is_method_blocked(method: &str, config: &impl ConfigProvider) -> bool {
468    let blocked_methods = config.blocked_methods();
469    blocked_methods
470        .iter()
471        .any(|blocked_method| blocked_method.eq_ignore_ascii_case(method))
472}
473
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::TestConfig;
    use http_body_util::BodyExt;

    // ===========================================
    // shared helpers
    // ===========================================

    /// Asserts that each `(input, expected)` pair decodes as expected.
    fn assert_decodes(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_eq!(url_decode(input), expected, "url_decode({input:?})");
        }
    }

    /// Asserts that every path in `paths` is (or is not) blocked under `config`.
    fn assert_paths(config: &impl ConfigProvider, paths: &[&str], blocked: bool) {
        for &path in paths {
            assert_eq!(is_url_pattern_blocked(path, config), blocked, "path {path:?}");
        }
    }

    /// Asserts that every method in `methods` is (or is not) blocked under `config`.
    fn assert_methods(config: &impl ConfigProvider, methods: &[&str], blocked: bool) {
        for &method in methods {
            assert_eq!(is_method_blocked(method, config), blocked, "method {method:?}");
        }
    }

    /// Drains a response body into a single `Bytes` buffer.
    async fn read_body(response: Response<Full<bytes::Bytes>>) -> bytes::Bytes {
        response.into_body().collect().await.unwrap().to_bytes()
    }

    /// Returns the `WWW-Authenticate` challenge produced for `realm`.
    fn challenge_for(realm: &str) -> String {
        create_unauthorized_response(realm)
            .headers()
            .get("www-authenticate")
            .unwrap()
            .to_str()
            .unwrap()
            .to_owned()
    }

    // ===========================================
    // url_decode tests
    // ===========================================

    #[test]
    fn test_url_decode_no_encoding() {
        assert_decodes(&[("/path/to/file", "/path/to/file"), ("hello", "hello"), ("", "")]);
    }

    #[test]
    fn test_url_decode_simple_encoding() {
        assert_decodes(&[("%20", " "), ("hello%20world", "hello world"), ("%2F", "/")]);
    }

    #[test]
    fn test_url_decode_dot_encoding() {
        // Common bypass attempts
        assert_decodes(&[("%2e", "."), ("%2E", "."), (".%2ephp", "..php")]);
    }

    #[test]
    fn test_url_decode_php_bypass() {
        // Attacker tries to bypass .php blocking
        assert_decodes(&[(".ph%70", ".php"), ("%2ephp", ".php"), (".%70%68%70", ".php")]);
    }

    #[test]
    fn test_url_decode_env_bypass() {
        // Attacker tries to bypass .env blocking
        assert_decodes(&[(".%65nv", ".env"), ("%2eenv", ".env"), ("%2e%65%6e%76", ".env")]);
    }

    #[test]
    fn test_url_decode_multiple_encodings() {
        assert_decodes(&[("%2F%2e%2e%2Fetc%2Fpasswd", "/../etc/passwd")]);
    }

    #[test]
    fn test_url_decode_invalid_hex() {
        // Malformed escapes are copied through verbatim
        assert_decodes(&[("%GG", "%GG"), ("%", "%"), ("%2", "%2"), ("%ZZ", "%ZZ")]);
    }

    #[test]
    fn test_url_decode_mixed_content() {
        assert_decodes(&[
            ("path%2Fto%2Ffile.txt", "path/to/file.txt"),
            ("hello%20%26%20world", "hello & world"),
        ]);
    }

    #[test]
    fn test_url_decode_unicode() {
        // Multi-byte UTF-8 sequences (é is C3 A9)
        assert_decodes(&[("%C3%A9", "é"), ("caf%C3%A9", "café")]);
    }

    // ===========================================
    // is_url_pattern_blocked tests
    // ===========================================

    #[test]
    fn test_url_pattern_blocked_simple() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php", ".env"]);
        assert_paths(&config, &["/file.php", "/.env", "/path/to/file.php"], true);
    }

    #[test]
    fn test_url_pattern_not_blocked() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php", ".env"]);
        assert_paths(&config, &["/file.html", "/path/to/file.js", "/"], false);
    }

    #[test]
    fn test_url_pattern_blocked_empty_patterns() {
        let config = TestConfig::new();
        assert_paths(&config, &["/file.php", "/.env"], false);
    }

    #[test]
    fn test_url_pattern_blocked_bypass_attempt() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php", ".env", "admin"]);
        // URL-encoded bypass attempts should still be blocked (.php, .env, admin)
        assert_paths(&config, &["/.ph%70", "/%2eenv", "/adm%69n"], true);
    }

    #[test]
    fn test_url_pattern_blocked_double_encoding_attempt() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php"]);
        // Single encoding should be caught
        assert_paths(&config, &["/.ph%70"], true);
    }

    #[test]
    fn test_url_pattern_blocked_case_insensitive() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php"]);
        // Pattern matching is case-insensitive to prevent bypass
        assert_paths(&config, &["/file.PHP", "/file.php", "/file.Php"], true);
    }

    #[test]
    fn test_url_pattern_blocked_partial_match() {
        let config = TestConfig::new().with_blocked_patterns(vec!["admin"]);
        // Substring matching: "/administrator" contains "admin"
        assert_paths(&config, &["/admin/panel", "/path/admin", "/administrator"], true);
    }

    // ===========================================
    // is_method_blocked tests
    // ===========================================

    #[test]
    fn test_method_blocked() {
        let config = TestConfig::new().with_blocked_methods(vec!["TRACE", "CONNECT"]);
        assert_methods(&config, &["TRACE", "CONNECT"], true);
    }

    #[test]
    fn test_method_not_blocked() {
        let config = TestConfig::new().with_blocked_methods(vec!["TRACE", "CONNECT"]);
        assert_methods(&config, &["GET", "POST", "PUT", "DELETE"], false);
    }

    #[test]
    fn test_method_blocked_empty_list() {
        let config = TestConfig::new();
        assert_methods(&config, &["TRACE", "GET"], false);
    }

    #[test]
    fn test_method_blocked_case_insensitive() {
        let config = TestConfig::new().with_blocked_methods(vec!["TRACE"]);
        assert_methods(&config, &["TRACE", "trace", "Trace"], true);
    }

    // ===========================================
    // create_error_response tests
    // ===========================================

    #[test]
    fn test_create_error_response_status() {
        for (status, message) in [
            (StatusCode::NOT_FOUND, "Not Found"),
            (StatusCode::FORBIDDEN, "Forbidden"),
            (StatusCode::TOO_MANY_REQUESTS, "Rate limited"),
        ] {
            assert_eq!(create_error_response(status, message).status(), status);
        }
    }

    #[test]
    fn test_create_error_response_content_type() {
        let response = create_error_response(StatusCode::NOT_FOUND, "Not Found");
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "text/plain"
        );
    }

    #[tokio::test]
    async fn test_create_error_response_body() {
        let body =
            read_body(create_error_response(StatusCode::NOT_FOUND, "Resource not found")).await;
        assert_eq!(body, "Resource not found");
    }

    #[tokio::test]
    async fn test_create_error_response_empty_message() {
        let body = read_body(create_error_response(StatusCode::NO_CONTENT, "")).await;
        assert_eq!(body, "");
    }

    // ===========================================
    // create_unauthorized_response tests
    // ===========================================

    #[test]
    fn test_unauthorized_response_status() {
        let response = create_unauthorized_response("WiseGate");
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_unauthorized_response_www_authenticate_header() {
        assert_eq!(challenge_for("WiseGate"), r#"Basic realm="WiseGate""#);
    }

    #[test]
    fn test_unauthorized_response_realm_with_quotes() {
        assert_eq!(challenge_for(r#"My "Realm""#), r#"Basic realm="My \"Realm\"""#);
    }

    #[test]
    fn test_unauthorized_response_realm_with_backslash() {
        assert_eq!(challenge_for(r"My\Realm"), r#"Basic realm="My\\Realm""#);
    }

    #[test]
    fn test_unauthorized_response_content_type() {
        let response = create_unauthorized_response("WiseGate");
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "text/plain"
        );
    }

    // ===========================================
    // double-encoding test
    // ===========================================

    #[test]
    fn test_url_decode_double_encoding_not_decoded_twice() {
        // %252e decodes to %2e on the single pass — it must NOT become '.'
        assert_decodes(&[("%252e", "%2e"), ("%2565nv", "%65nv")]);
    }
}