1use http_body_util::{BodyExt, Full};
22use hyper::{Request, Response, StatusCode, body::Incoming};
23use std::convert::Infallible;
24use std::sync::Arc;
25
26use crate::error::WiseGateError;
27use crate::types::{ConfigProvider, RateLimiter};
28use crate::{auth, headers, ip_filter, rate_limiter};
29
30pub async fn handle_request<C: ConfigProvider>(
79 req: Request<Incoming>,
80 forward_host: Arc<str>,
81 forward_port: u16,
82 limiter: RateLimiter,
83 config: Arc<C>,
84 http_client: reqwest::Client,
85) -> Result<Response<Full<bytes::Bytes>>, Infallible> {
86 let real_client_ip: Option<String> =
88 match ip_filter::extract_and_validate_real_ip(req.headers(), config.as_ref()) {
89 Some(ip) => Some(ip),
90 None => {
91 if config.allowed_proxy_ips().is_none() {
92 None
94 } else {
95 let err = WiseGateError::InvalidIp("missing or invalid proxy headers".into());
97 return Ok(create_error_response(err.status_code(), err.user_message()));
98 }
99 }
100 };
101
102 if let Some(ref ip) = real_client_ip
104 && ip_filter::is_ip_blocked(ip, config.as_ref())
105 {
106 let err = WiseGateError::IpBlocked(ip.clone());
107 return Ok(create_error_response(err.status_code(), err.user_message()));
108 }
109
110 let request_path = req.uri().path();
112 if is_url_pattern_blocked(request_path, config.as_ref()) {
113 let err = WiseGateError::PatternBlocked(request_path.to_string());
114 return Ok(create_error_response(err.status_code(), err.user_message()));
115 }
116
117 let request_method = req.method().as_str();
119 if is_method_blocked(request_method, config.as_ref()) {
120 let err = WiseGateError::MethodBlocked(request_method.to_string());
121 return Ok(create_error_response(err.status_code(), err.user_message()));
122 }
123
124 if config.is_auth_enabled() {
127 let auth_header = req
128 .headers()
129 .get(headers::AUTHORIZATION)
130 .and_then(|v| v.to_str().ok());
131
132 let basic_auth_enabled = config.is_basic_auth_enabled();
133 let bearer_auth_enabled = config.is_bearer_auth_enabled();
134
135 let basic_auth_passed =
136 basic_auth_enabled && auth::check_basic_auth(auth_header, config.auth_credentials());
137 let bearer_auth_passed =
138 bearer_auth_enabled && auth::check_bearer_token(auth_header, config.bearer_token());
139
140 if !basic_auth_passed && !bearer_auth_passed {
142 return Ok(create_unauthorized_response(config.auth_realm()));
143 }
144 }
145
146 if let Some(ref ip) = real_client_ip
148 && !rate_limiter::check_rate_limit(&limiter, ip, config.as_ref()).await
149 {
150 let err = WiseGateError::RateLimitExceeded(ip.clone());
151 return Ok(create_error_response(err.status_code(), err.user_message()));
152 }
153
154 let mut req = req;
157 req.headers_mut().remove(headers::X_REAL_IP);
158 if let Some(ref ip) = real_client_ip
159 && let Ok(header_value) = ip.parse()
160 {
161 req.headers_mut().insert(headers::X_REAL_IP, header_value);
162 }
163
164 forward_request(
166 req,
167 &forward_host,
168 forward_port,
169 config.as_ref(),
170 &http_client,
171 )
172 .await
173}
174
175async fn forward_request(
177 req: Request<Incoming>,
178 host: &str,
179 port: u16,
180 config: &impl ConfigProvider,
181 http_client: &reqwest::Client,
182) -> Result<Response<Full<bytes::Bytes>>, Infallible> {
183 let proxy_config = config.proxy_config();
184 let (parts, body) = req.into_parts();
185
186 if proxy_config.max_body_size > 0
188 && let Some(content_length) = parts
189 .headers
190 .get(headers::CONTENT_LENGTH)
191 .and_then(|v| v.to_str().ok())
192 .and_then(|v| v.parse::<usize>().ok())
193 && content_length > proxy_config.max_body_size
194 {
195 let err = WiseGateError::BodyTooLarge {
196 size: content_length,
197 max: proxy_config.max_body_size,
198 };
199 return Ok(create_error_response(err.status_code(), err.user_message()));
200 }
201
202 let body_bytes = match body.collect().await {
203 Ok(bytes) => {
204 let collected_bytes = bytes.to_bytes();
205
206 if proxy_config.max_body_size > 0 && collected_bytes.len() > proxy_config.max_body_size
208 {
209 let err = WiseGateError::BodyTooLarge {
210 size: collected_bytes.len(),
211 max: proxy_config.max_body_size,
212 };
213 return Ok(create_error_response(err.status_code(), err.user_message()));
214 }
215
216 collected_bytes
217 }
218 Err(e) => {
219 let err = WiseGateError::BodyReadError(e.to_string());
220 return Ok(create_error_response(err.status_code(), err.user_message()));
221 }
222 };
223
224 let strip_auth = config.is_auth_enabled() && !config.forward_authorization_header();
225 forward_with_reqwest(parts, body_bytes, host, port, http_client, strip_auth).await
226}
227
228async fn forward_with_reqwest(
230 parts: hyper::http::request::Parts,
231 body_bytes: bytes::Bytes,
232 host: &str,
233 port: u16,
234 client: &reqwest::Client,
235 strip_auth: bool,
236) -> Result<Response<Full<bytes::Bytes>>, Infallible> {
237 let destination_uri = format!(
239 "http://{}:{}{}",
240 host,
241 port,
242 parts.uri.path_and_query().map_or("", |pq| pq.as_str())
243 );
244
245 let method = match reqwest::Method::from_bytes(parts.method.as_str().as_bytes()) {
247 Ok(m) => m,
248 Err(_) => {
249 let err =
250 WiseGateError::MethodBlocked(format!("{} (unsupported)", parts.method.as_str()));
251 return Ok(create_error_response(err.status_code(), err.user_message()));
252 }
253 };
254 let mut req_builder = client.request(method, &destination_uri);
255
256 for (name, value) in parts.headers.iter() {
260 if name != headers::HOST
261 && name != headers::CONTENT_LENGTH
262 && !(strip_auth && name == headers::AUTHORIZATION)
263 && !headers::is_hop_by_hop(name.as_str())
264 && let Ok(header_value) = value.to_str()
265 {
266 req_builder = req_builder.header(name.as_str(), header_value);
267 }
268 }
269
270 if !body_bytes.is_empty() {
272 req_builder = req_builder.body(body_bytes);
273 }
274
275 match req_builder.send().await {
277 Ok(response) => {
278 let status = response.status();
279 let resp_headers = response.headers().clone();
280
281 match response.bytes().await {
282 Ok(body_bytes) => {
283 let mut hyper_response = match Response::builder()
284 .status(status.as_u16())
285 .body(Full::new(body_bytes))
286 {
287 Ok(resp) => resp,
288 Err(e) => {
289 let err = WiseGateError::ProxyError(format!(
290 "Failed to build response: {}",
291 e
292 ));
293 return Ok(create_error_response(
294 err.status_code(),
295 err.user_message(),
296 ));
297 }
298 };
299
300 for (name, value) in resp_headers.iter() {
302 if !headers::is_hop_by_hop(name.as_str())
304 && let (Ok(hyper_name), Ok(hyper_value)) = (
305 hyper::header::HeaderName::from_bytes(name.as_str().as_bytes()),
306 hyper::header::HeaderValue::from_bytes(value.as_bytes()),
307 )
308 {
309 hyper_response.headers_mut().insert(hyper_name, hyper_value);
310 }
311 }
312
313 Ok(hyper_response)
314 }
315 Err(e) => {
316 let err = WiseGateError::BodyReadError(format!("response: {}", e));
317 Ok(create_error_response(err.status_code(), err.user_message()))
318 }
319 }
320 }
321 Err(err) => {
322 let wise_err = if err.is_timeout() {
324 WiseGateError::UpstreamTimeout(err.to_string())
325 } else if err.is_connect() {
326 WiseGateError::UpstreamConnectionFailed(err.to_string())
327 } else {
328 WiseGateError::ProxyError(err.to_string())
329 };
330 Ok(create_error_response(
331 wise_err.status_code(),
332 wise_err.user_message(),
333 ))
334 }
335 }
336}
337
338pub fn create_error_response(status: StatusCode, message: &str) -> Response<Full<bytes::Bytes>> {
363 Response::builder()
364 .status(status)
365 .header(headers::CONTENT_TYPE, "text/plain")
366 .body(Full::new(bytes::Bytes::from(message.to_string())))
367 .unwrap_or_else(|_| {
368 Response::new(Full::new(bytes::Bytes::from("Internal Server Error")))
370 })
371}
372
373pub fn create_unauthorized_response(realm: &str) -> Response<Full<bytes::Bytes>> {
386 let sanitized_realm = realm.replace('\\', "\\\\").replace('"', "\\\"");
388 Response::builder()
389 .status(StatusCode::UNAUTHORIZED)
390 .header(
391 headers::WWW_AUTHENTICATE,
392 format!("Basic realm=\"{}\"", sanitized_realm),
393 )
394 .header(headers::CONTENT_TYPE, "text/plain")
395 .body(Full::new(bytes::Bytes::from("401 Unauthorized")))
396 .unwrap_or_else(|_| Response::new(Full::new(bytes::Bytes::from("401 Unauthorized"))))
397}
398
399fn is_url_pattern_blocked(path: &str, config: &impl ConfigProvider) -> bool {
402 let blocked_patterns = config.blocked_patterns();
403 if blocked_patterns.is_empty() {
404 return false;
405 }
406
407 let decoded_path = url_decode(path);
409 let has_encoding = decoded_path != path;
410
411 let path_lower = path.to_lowercase();
413 let decoded_lower = if has_encoding {
415 Some(decoded_path.to_lowercase())
416 } else {
417 None
418 };
419
420 blocked_patterns.iter().any(|pattern| {
422 path_lower.contains(pattern.as_str())
423 || decoded_lower
424 .as_ref()
425 .is_some_and(|dl| dl.contains(pattern.as_str()))
426 })
427}
428
/// Percent-decodes `input` in a single pass.
///
/// Every `%XX` sequence with two valid hexadecimal digits is replaced by the
/// corresponding byte; malformed or truncated sequences are copied through
/// unchanged. The resulting bytes are interpreted as UTF-8, with invalid
/// sequences replaced by U+FFFD.
///
/// Because decoding happens only once, `%252e` yields `%2e`, not `.`.
fn url_decode(input: &str) -> String {
    let mut remaining = input.as_bytes();
    let mut decoded = Vec::with_capacity(remaining.len());

    while let Some((&current, tail)) = remaining.split_first() {
        if current == b'%' {
            // A valid escape needs two more bytes, both hexadecimal digits.
            if let [hi, lo, after @ ..] = tail {
                if let (Some(h), Some(l)) = (hex_digit(*hi), hex_digit(*lo)) {
                    decoded.push((h << 4) | l);
                    remaining = after;
                    continue;
                }
            }
        }
        decoded.push(current);
        remaining = tail;
    }

    String::from_utf8_lossy(&decoded).into_owned()
}

/// Returns the numeric value (0–15) of an ASCII hexadecimal digit, accepting
/// both upper- and lowercase letters, or `None` for any other byte.
fn hex_digit(b: u8) -> Option<u8> {
    // `to_digit(16)` yields values below 16, so the narrowing cast is lossless.
    char::from(b).to_digit(16).map(|value| value as u8)
}
465
466fn is_method_blocked(method: &str, config: &impl ConfigProvider) -> bool {
468 let blocked_methods = config.blocked_methods();
469 blocked_methods
470 .iter()
471 .any(|blocked_method| blocked_method.eq_ignore_ascii_case(method))
472}
473
#[cfg(test)]
mod tests {
    //! Unit tests for the pure helpers of the request handler: URL decoding,
    //! pattern and method blocking, and the canned error responses.

    use super::*;
    use crate::test_utils::TestConfig;
    use http_body_util::BodyExt;

    // ----- url_decode -----

    #[test]
    fn test_url_decode_no_encoding() {
        assert_eq!(url_decode("/path/to/file"), "/path/to/file");
        assert_eq!(url_decode("hello"), "hello");
        assert_eq!(url_decode(""), "");
    }

    #[test]
    fn test_url_decode_simple_encoding() {
        assert_eq!(url_decode("%20"), " ");
        assert_eq!(url_decode("hello%20world"), "hello world");
        assert_eq!(url_decode("%2F"), "/");
    }

    #[test]
    fn test_url_decode_dot_encoding() {
        // Hex digits are accepted in either case.
        assert_eq!(url_decode("%2e"), ".");
        assert_eq!(url_decode("%2E"), ".");
        assert_eq!(url_decode(".%2ephp"), "..php");
    }

    #[test]
    fn test_url_decode_php_bypass() {
        assert_eq!(url_decode(".ph%70"), ".php");
        assert_eq!(url_decode("%2ephp"), ".php");
        assert_eq!(url_decode(".%70%68%70"), ".php");
    }

    #[test]
    fn test_url_decode_env_bypass() {
        assert_eq!(url_decode(".%65nv"), ".env");
        assert_eq!(url_decode("%2eenv"), ".env");
        assert_eq!(url_decode("%2e%65%6e%76"), ".env");
    }

    #[test]
    fn test_url_decode_multiple_encodings() {
        assert_eq!(url_decode("%2F%2e%2e%2Fetc%2Fpasswd"), "/../etc/passwd");
    }

    #[test]
    fn test_url_decode_invalid_hex() {
        // Malformed or truncated escapes are passed through untouched.
        assert_eq!(url_decode("%GG"), "%GG");
        assert_eq!(url_decode("%"), "%");
        assert_eq!(url_decode("%2"), "%2");
        assert_eq!(url_decode("%ZZ"), "%ZZ");
    }

    #[test]
    fn test_url_decode_mixed_content() {
        assert_eq!(url_decode("path%2Fto%2Ffile.txt"), "path/to/file.txt");
        assert_eq!(url_decode("hello%20%26%20world"), "hello & world");
    }

    #[test]
    fn test_url_decode_unicode() {
        // Multi-byte UTF-8 sequences are reassembled from their escaped bytes.
        assert_eq!(url_decode("%C3%A9"), "é");
        assert_eq!(url_decode("caf%C3%A9"), "café");
    }

    // ----- is_url_pattern_blocked -----

    #[test]
    fn test_url_pattern_blocked_simple() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php", ".env"]);

        assert!(is_url_pattern_blocked("/file.php", &config));
        assert!(is_url_pattern_blocked("/.env", &config));
        assert!(is_url_pattern_blocked("/path/to/file.php", &config));
    }

    #[test]
    fn test_url_pattern_not_blocked() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php", ".env"]);

        assert!(!is_url_pattern_blocked("/file.html", &config));
        assert!(!is_url_pattern_blocked("/path/to/file.js", &config));
        assert!(!is_url_pattern_blocked("/", &config));
    }

    #[test]
    fn test_url_pattern_blocked_empty_patterns() {
        let config = TestConfig::new();

        assert!(!is_url_pattern_blocked("/file.php", &config));
        assert!(!is_url_pattern_blocked("/.env", &config));
    }

    #[test]
    fn test_url_pattern_blocked_bypass_attempt() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php", ".env", "admin"]);

        // Percent-encoding part of the pattern must not get past the filter.
        assert!(is_url_pattern_blocked("/.ph%70", &config));
        assert!(is_url_pattern_blocked("/%2eenv", &config));
        assert!(is_url_pattern_blocked("/adm%69n", &config));
    }

    #[test]
    fn test_url_pattern_blocked_double_encoding_attempt() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php"]);

        // Decoding is single-pass: "%2570" becomes "%70", not "p", so the
        // double-encoded path does not match the pattern.
        assert!(!is_url_pattern_blocked("/.ph%2570", &config));
        // The single-encoded form is still caught.
        assert!(is_url_pattern_blocked("/.ph%70", &config));
    }

    #[test]
    fn test_url_pattern_blocked_case_insensitive() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php"]);

        assert!(is_url_pattern_blocked("/file.PHP", &config));
        assert!(is_url_pattern_blocked("/file.php", &config));
        assert!(is_url_pattern_blocked("/file.Php", &config));
    }

    #[test]
    fn test_url_pattern_blocked_partial_match() {
        let config = TestConfig::new().with_blocked_patterns(vec!["admin"]);

        assert!(is_url_pattern_blocked("/admin/panel", &config));
        assert!(is_url_pattern_blocked("/path/admin", &config));
        // Matching is by substring, so longer words containing the pattern
        // are blocked too.
        assert!(is_url_pattern_blocked("/administrator", &config));
    }

    // ----- is_method_blocked -----

    #[test]
    fn test_method_blocked() {
        let config = TestConfig::new().with_blocked_methods(vec!["TRACE", "CONNECT"]);

        assert!(is_method_blocked("TRACE", &config));
        assert!(is_method_blocked("CONNECT", &config));
    }

    #[test]
    fn test_method_not_blocked() {
        let config = TestConfig::new().with_blocked_methods(vec!["TRACE", "CONNECT"]);

        assert!(!is_method_blocked("GET", &config));
        assert!(!is_method_blocked("POST", &config));
        assert!(!is_method_blocked("PUT", &config));
        assert!(!is_method_blocked("DELETE", &config));
    }

    #[test]
    fn test_method_blocked_empty_list() {
        let config = TestConfig::new();

        assert!(!is_method_blocked("TRACE", &config));
        assert!(!is_method_blocked("GET", &config));
    }

    #[test]
    fn test_method_blocked_case_insensitive() {
        let config = TestConfig::new().with_blocked_methods(vec!["TRACE"]);

        assert!(is_method_blocked("TRACE", &config));
        assert!(is_method_blocked("trace", &config));
        assert!(is_method_blocked("Trace", &config));
    }

    // ----- create_error_response -----

    #[test]
    fn test_create_error_response_status() {
        let response = create_error_response(StatusCode::NOT_FOUND, "Not Found");
        assert_eq!(response.status(), StatusCode::NOT_FOUND);

        let response = create_error_response(StatusCode::FORBIDDEN, "Forbidden");
        assert_eq!(response.status(), StatusCode::FORBIDDEN);

        let response = create_error_response(StatusCode::TOO_MANY_REQUESTS, "Rate limited");
        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[test]
    fn test_create_error_response_content_type() {
        let response = create_error_response(StatusCode::NOT_FOUND, "Not Found");
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "text/plain"
        );
    }

    #[tokio::test]
    async fn test_create_error_response_body() {
        let response = create_error_response(StatusCode::NOT_FOUND, "Resource not found");
        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "Resource not found");
    }

    #[tokio::test]
    async fn test_create_error_response_empty_message() {
        let response = create_error_response(StatusCode::NO_CONTENT, "");
        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "");
    }

    // ----- create_unauthorized_response -----

    #[test]
    fn test_unauthorized_response_status() {
        let response = create_unauthorized_response("WiseGate");
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_unauthorized_response_www_authenticate_header() {
        let response = create_unauthorized_response("WiseGate");
        let header = response
            .headers()
            .get("www-authenticate")
            .unwrap()
            .to_str()
            .unwrap();
        assert_eq!(header, "Basic realm=\"WiseGate\"");
    }

    #[test]
    fn test_unauthorized_response_realm_with_quotes() {
        let response = create_unauthorized_response("My \"Realm\"");
        let header = response
            .headers()
            .get("www-authenticate")
            .unwrap()
            .to_str()
            .unwrap();
        assert_eq!(header, "Basic realm=\"My \\\"Realm\\\"\"");
    }

    #[test]
    fn test_unauthorized_response_realm_with_backslash() {
        let response = create_unauthorized_response("My\\Realm");
        let header = response
            .headers()
            .get("www-authenticate")
            .unwrap()
            .to_str()
            .unwrap();
        assert_eq!(header, "Basic realm=\"My\\\\Realm\"");
    }

    #[test]
    fn test_unauthorized_response_content_type() {
        let response = create_unauthorized_response("WiseGate");
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "text/plain"
        );
    }

    // ----- url_decode: double encoding -----

    #[test]
    fn test_url_decode_double_encoding_not_decoded_twice() {
        // "%25" decodes to "%"; the resulting "%2e" must not be decoded again.
        assert_eq!(url_decode("%252e"), "%2e");
        assert_eq!(url_decode("%2565nv"), "%65nv");
    }
}