1use http_body_util::{BodyExt, Full};
22use hyper::{Request, Response, StatusCode, body::Incoming};
23use std::convert::Infallible;
24use std::sync::Arc;
25
26use crate::error::WiseGateError;
27use crate::types::{ConfigProvider, RateLimiter};
28use crate::{auth, headers, ip_filter, rate_limiter};
29
30pub async fn handle_request<C: ConfigProvider>(
60 req: Request<Incoming>,
61 forward_host: Arc<str>,
62 forward_port: u16,
63 limiter: RateLimiter,
64 config: Arc<C>,
65 http_client: reqwest::Client,
66) -> Result<Response<Full<bytes::Bytes>>, Infallible> {
67 let real_client_ip: Option<String> =
69 match ip_filter::extract_and_validate_real_ip(req.headers(), config.as_ref()) {
70 Some(ip) => Some(ip),
71 None => {
72 if config.allowed_proxy_ips().is_none() {
73 None
75 } else {
76 let err = WiseGateError::InvalidIp("missing or invalid proxy headers".into());
78 return Ok(create_error_response(err.status_code(), err.user_message()));
79 }
80 }
81 };
82
83 if let Some(ref ip) = real_client_ip
85 && ip_filter::is_ip_blocked(ip, config.as_ref())
86 {
87 let err = WiseGateError::IpBlocked(ip.clone());
88 return Ok(create_error_response(err.status_code(), err.user_message()));
89 }
90
91 let request_path = req.uri().path();
93 if is_url_pattern_blocked(request_path, config.as_ref()) {
94 let err = WiseGateError::PatternBlocked(request_path.to_string());
95 return Ok(create_error_response(err.status_code(), err.user_message()));
96 }
97
98 let request_method = req.method().as_str();
100 if is_method_blocked(request_method, config.as_ref()) {
101 let err = WiseGateError::MethodBlocked(request_method.to_string());
102 return Ok(create_error_response(err.status_code(), err.user_message()));
103 }
104
105 if config.is_auth_enabled() {
108 let auth_header = req
109 .headers()
110 .get(headers::AUTHORIZATION)
111 .and_then(|v| v.to_str().ok());
112
113 let basic_auth_enabled = config.is_basic_auth_enabled();
114 let bearer_auth_enabled = config.is_bearer_auth_enabled();
115
116 let basic_auth_passed =
117 basic_auth_enabled && auth::check_basic_auth(auth_header, config.auth_credentials());
118 let bearer_auth_passed =
119 bearer_auth_enabled && auth::check_bearer_token(auth_header, config.bearer_token());
120
121 if !basic_auth_passed && !bearer_auth_passed {
123 return Ok(create_unauthorized_response(config.auth_realm()));
124 }
125 }
126
127 if let Some(ref ip) = real_client_ip
129 && !rate_limiter::check_rate_limit(&limiter, ip, config.as_ref()).await
130 {
131 let err = WiseGateError::RateLimitExceeded(ip.clone());
132 return Ok(create_error_response(err.status_code(), err.user_message()));
133 }
134
135 let mut req = req;
137 if let Some(ref ip) = real_client_ip
138 && let Ok(header_value) = ip.parse()
139 {
140 req.headers_mut().insert(headers::X_REAL_IP, header_value);
141 }
142
143 forward_request(
145 req,
146 &forward_host,
147 forward_port,
148 config.as_ref(),
149 &http_client,
150 )
151 .await
152}
153
154async fn forward_request(
156 req: Request<Incoming>,
157 host: &str,
158 port: u16,
159 config: &impl ConfigProvider,
160 http_client: &reqwest::Client,
161) -> Result<Response<Full<bytes::Bytes>>, Infallible> {
162 let proxy_config = config.proxy_config();
163 let (parts, body) = req.into_parts();
164
165 if proxy_config.max_body_size > 0
167 && let Some(content_length) = parts
168 .headers
169 .get(headers::CONTENT_LENGTH)
170 .and_then(|v| v.to_str().ok())
171 .and_then(|v| v.parse::<usize>().ok())
172 && content_length > proxy_config.max_body_size
173 {
174 let err = WiseGateError::BodyTooLarge {
175 size: content_length,
176 max: proxy_config.max_body_size,
177 };
178 return Ok(create_error_response(err.status_code(), err.user_message()));
179 }
180
181 let body_bytes = match body.collect().await {
182 Ok(bytes) => {
183 let collected_bytes = bytes.to_bytes();
184
185 if proxy_config.max_body_size > 0 && collected_bytes.len() > proxy_config.max_body_size
187 {
188 let err = WiseGateError::BodyTooLarge {
189 size: collected_bytes.len(),
190 max: proxy_config.max_body_size,
191 };
192 return Ok(create_error_response(err.status_code(), err.user_message()));
193 }
194
195 collected_bytes
196 }
197 Err(e) => {
198 let err = WiseGateError::BodyReadError(e.to_string());
199 return Ok(create_error_response(err.status_code(), err.user_message()));
200 }
201 };
202
203 forward_with_reqwest(parts, body_bytes, host, port, http_client).await
204}
205
206async fn forward_with_reqwest(
208 parts: hyper::http::request::Parts,
209 body_bytes: bytes::Bytes,
210 host: &str,
211 port: u16,
212 client: &reqwest::Client,
213) -> Result<Response<Full<bytes::Bytes>>, Infallible> {
214 let destination_uri = format!(
216 "http://{}:{}{}",
217 host,
218 port,
219 parts.uri.path_and_query().map_or("", |pq| pq.as_str())
220 );
221
222 let method = match reqwest::Method::from_bytes(parts.method.as_str().as_bytes()) {
224 Ok(m) => m,
225 Err(_) => {
226 let err =
227 WiseGateError::MethodBlocked(format!("{} (unsupported)", parts.method.as_str()));
228 return Ok(create_error_response(err.status_code(), err.user_message()));
229 }
230 };
231 let mut req_builder = client.request(method, &destination_uri);
232
233 for (name, value) in parts.headers.iter() {
235 if name != headers::HOST
236 && name != headers::CONTENT_LENGTH
237 && !headers::is_hop_by_hop(name.as_str())
238 && let Ok(header_value) = value.to_str()
239 {
240 req_builder = req_builder.header(name.as_str(), header_value);
241 }
242 }
243
244 if !body_bytes.is_empty() {
246 req_builder = req_builder.body(body_bytes);
247 }
248
249 match req_builder.send().await {
251 Ok(response) => {
252 let status = response.status();
253 let resp_headers = response.headers().clone();
254
255 match response.bytes().await {
256 Ok(body_bytes) => {
257 let mut hyper_response = match Response::builder()
258 .status(status.as_u16())
259 .body(Full::new(body_bytes))
260 {
261 Ok(resp) => resp,
262 Err(e) => {
263 let err = WiseGateError::ProxyError(format!(
264 "Failed to build response: {}",
265 e
266 ));
267 return Ok(create_error_response(
268 err.status_code(),
269 err.user_message(),
270 ));
271 }
272 };
273
274 for (name, value) in resp_headers.iter() {
276 if !headers::is_hop_by_hop(name.as_str())
278 && let (Ok(hyper_name), Ok(hyper_value)) = (
279 hyper::header::HeaderName::from_bytes(name.as_str().as_bytes()),
280 hyper::header::HeaderValue::from_bytes(value.as_bytes()),
281 )
282 {
283 hyper_response.headers_mut().insert(hyper_name, hyper_value);
284 }
285 }
286
287 Ok(hyper_response)
288 }
289 Err(e) => {
290 let err = WiseGateError::BodyReadError(format!("response: {}", e));
291 Ok(create_error_response(err.status_code(), err.user_message()))
292 }
293 }
294 }
295 Err(err) => {
296 let wise_err = if err.is_timeout() {
298 WiseGateError::UpstreamTimeout(err.to_string())
299 } else if err.is_connect() {
300 WiseGateError::UpstreamConnectionFailed(err.to_string())
301 } else {
302 WiseGateError::ProxyError(err.to_string())
303 };
304 Ok(create_error_response(
305 wise_err.status_code(),
306 wise_err.user_message(),
307 ))
308 }
309 }
310}
311
312pub fn create_error_response(status: StatusCode, message: &str) -> Response<Full<bytes::Bytes>> {
337 Response::builder()
338 .status(status)
339 .header(headers::CONTENT_TYPE, "text/plain")
340 .body(Full::new(bytes::Bytes::from(message.to_string())))
341 .unwrap_or_else(|_| {
342 Response::new(Full::new(bytes::Bytes::from("Internal Server Error")))
344 })
345}
346
347pub fn create_unauthorized_response(realm: &str) -> Response<Full<bytes::Bytes>> {
360 let sanitized_realm = realm.replace('\\', "\\\\").replace('"', "\\\"");
362 Response::builder()
363 .status(StatusCode::UNAUTHORIZED)
364 .header(
365 headers::WWW_AUTHENTICATE,
366 format!("Basic realm=\"{}\"", sanitized_realm),
367 )
368 .header(headers::CONTENT_TYPE, "text/plain")
369 .body(Full::new(bytes::Bytes::from("401 Unauthorized")))
370 .unwrap_or_else(|_| Response::new(Full::new(bytes::Bytes::from("401 Unauthorized"))))
371}
372
373fn is_url_pattern_blocked(path: &str, config: &impl ConfigProvider) -> bool {
376 let blocked_patterns = config.blocked_patterns();
377 if blocked_patterns.is_empty() {
378 return false;
379 }
380
381 let decoded_path = url_decode(path);
383 let has_encoding = decoded_path != path;
384
385 let path_lower = path.to_lowercase();
387 let decoded_lower = if has_encoding {
389 Some(decoded_path.to_lowercase())
390 } else {
391 None
392 };
393
394 blocked_patterns.iter().any(|pattern| {
396 path_lower.contains(pattern.as_str())
397 || decoded_lower
398 .as_ref()
399 .is_some_and(|dl| dl.contains(pattern.as_str()))
400 })
401}
402
403fn url_decode(input: &str) -> String {
407 let mut bytes = Vec::with_capacity(input.len());
408 let input_bytes = input.as_bytes();
409 let mut i = 0;
410
411 while i < input_bytes.len() {
412 if input_bytes[i] == b'%' && i + 2 < input_bytes.len() {
413 let hi = hex_digit(input_bytes[i + 1]);
415 let lo = hex_digit(input_bytes[i + 2]);
416 if let (Some(h), Some(l)) = (hi, lo) {
417 bytes.push(h << 4 | l);
418 i += 3;
419 continue;
420 }
421 }
422 bytes.push(input_bytes[i]);
423 i += 1;
424 }
425
426 String::from_utf8_lossy(&bytes).into_owned()
428}
429
/// Maps one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value in `0..=15`.
/// Any other byte yields `None`.
fn hex_digit(b: u8) -> Option<u8> {
    char::from(b)
        .to_digit(16)
        .and_then(|value| u8::try_from(value).ok())
}
439
440fn is_method_blocked(method: &str, config: &impl ConfigProvider) -> bool {
442 let blocked_methods = config.blocked_methods();
443 blocked_methods
444 .iter()
445 .any(|blocked_method| blocked_method.eq_ignore_ascii_case(method))
446}
447
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::TestConfig;
    use http_body_util::BodyExt;

    // ----- helpers -------------------------------------------------------

    /// Asserts `url_decode(input) == expected` for every `(input, expected)` pair.
    fn check_decoding(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_eq!(url_decode(input), expected, "url_decode({input:?})");
        }
    }

    /// Asserts `is_url_pattern_blocked` yields `expected` for every path.
    fn check_paths(config: &TestConfig, paths: &[&str], expected: bool) {
        for &path in paths {
            assert_eq!(
                is_url_pattern_blocked(path, config),
                expected,
                "is_url_pattern_blocked({path:?})"
            );
        }
    }

    /// Asserts `is_method_blocked` yields `expected` for every method.
    fn check_methods(config: &TestConfig, methods: &[&str], expected: bool) {
        for &method in methods {
            assert_eq!(
                is_method_blocked(method, config),
                expected,
                "is_method_blocked({method:?})"
            );
        }
    }

    /// Reads a response body to completion.
    async fn read_body(response: Response<Full<bytes::Bytes>>) -> bytes::Bytes {
        response.into_body().collect().await.unwrap().to_bytes()
    }

    /// Returns the `WWW-Authenticate` header produced for `realm`.
    fn challenge_for(realm: &str) -> String {
        let response = create_unauthorized_response(realm);
        response
            .headers()
            .get("www-authenticate")
            .unwrap()
            .to_str()
            .unwrap()
            .to_owned()
    }

    // ----- url_decode ----------------------------------------------------

    #[test]
    fn test_url_decode_no_encoding() {
        check_decoding(&[
            ("/path/to/file", "/path/to/file"),
            ("hello", "hello"),
            ("", ""),
        ]);
    }

    #[test]
    fn test_url_decode_simple_encoding() {
        check_decoding(&[("%20", " "), ("hello%20world", "hello world"), ("%2F", "/")]);
    }

    #[test]
    fn test_url_decode_dot_encoding() {
        check_decoding(&[("%2e", "."), ("%2E", "."), (".%2ephp", "..php")]);
    }

    #[test]
    fn test_url_decode_php_bypass() {
        check_decoding(&[(".ph%70", ".php"), ("%2ephp", ".php"), (".%70%68%70", ".php")]);
    }

    #[test]
    fn test_url_decode_env_bypass() {
        check_decoding(&[(".%65nv", ".env"), ("%2eenv", ".env"), ("%2e%65%6e%76", ".env")]);
    }

    #[test]
    fn test_url_decode_multiple_encodings() {
        check_decoding(&[("%2F%2e%2e%2Fetc%2Fpasswd", "/../etc/passwd")]);
    }

    #[test]
    fn test_url_decode_invalid_hex() {
        check_decoding(&[("%GG", "%GG"), ("%", "%"), ("%2", "%2"), ("%ZZ", "%ZZ")]);
    }

    #[test]
    fn test_url_decode_mixed_content() {
        check_decoding(&[
            ("path%2Fto%2Ffile.txt", "path/to/file.txt"),
            ("hello%20%26%20world", "hello & world"),
        ]);
    }

    #[test]
    fn test_url_decode_unicode() {
        check_decoding(&[("%C3%A9", "é"), ("caf%C3%A9", "café")]);
    }

    // ----- is_url_pattern_blocked ----------------------------------------

    #[test]
    fn test_url_pattern_blocked_simple() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php", ".env"]);
        check_paths(&config, &["/file.php", "/.env", "/path/to/file.php"], true);
    }

    #[test]
    fn test_url_pattern_not_blocked() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php", ".env"]);
        check_paths(&config, &["/file.html", "/path/to/file.js", "/"], false);
    }

    #[test]
    fn test_url_pattern_blocked_empty_patterns() {
        let config = TestConfig::new();
        check_paths(&config, &["/file.php", "/.env"], false);
    }

    #[test]
    fn test_url_pattern_blocked_bypass_attempt() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php", ".env", "admin"]);
        check_paths(&config, &["/.ph%70", "/%2eenv", "/adm%69n"], true);
    }

    #[test]
    fn test_url_pattern_blocked_double_encoding_attempt() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php"]);
        check_paths(&config, &["/.ph%70"], true);
    }

    #[test]
    fn test_url_pattern_blocked_case_insensitive() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php"]);
        check_paths(&config, &["/file.PHP", "/file.php", "/file.Php"], true);
    }

    #[test]
    fn test_url_pattern_blocked_partial_match() {
        let config = TestConfig::new().with_blocked_patterns(vec!["admin"]);
        check_paths(&config, &["/admin/panel", "/path/admin", "/administrator"], true);
    }

    // ----- is_method_blocked ---------------------------------------------

    #[test]
    fn test_method_blocked() {
        let config = TestConfig::new().with_blocked_methods(vec!["TRACE", "CONNECT"]);
        check_methods(&config, &["TRACE", "CONNECT"], true);
    }

    #[test]
    fn test_method_not_blocked() {
        let config = TestConfig::new().with_blocked_methods(vec!["TRACE", "CONNECT"]);
        check_methods(&config, &["GET", "POST", "PUT", "DELETE"], false);
    }

    #[test]
    fn test_method_blocked_empty_list() {
        let config = TestConfig::new();
        check_methods(&config, &["TRACE", "GET"], false);
    }

    #[test]
    fn test_method_blocked_case_insensitive() {
        let config = TestConfig::new().with_blocked_methods(vec!["TRACE"]);
        check_methods(&config, &["TRACE", "trace", "Trace"], true);
    }

    // ----- create_error_response -----------------------------------------

    #[test]
    fn test_create_error_response_status() {
        let cases = [
            (StatusCode::NOT_FOUND, "Not Found"),
            (StatusCode::FORBIDDEN, "Forbidden"),
            (StatusCode::TOO_MANY_REQUESTS, "Rate limited"),
        ];
        for (status, message) in cases {
            assert_eq!(create_error_response(status, message).status(), status);
        }
    }

    #[test]
    fn test_create_error_response_content_type() {
        let response = create_error_response(StatusCode::NOT_FOUND, "Not Found");
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "text/plain"
        );
    }

    #[tokio::test]
    async fn test_create_error_response_body() {
        let body = read_body(create_error_response(
            StatusCode::NOT_FOUND,
            "Resource not found",
        ))
        .await;
        assert_eq!(body, "Resource not found");
    }

    #[tokio::test]
    async fn test_create_error_response_empty_message() {
        let body = read_body(create_error_response(StatusCode::NO_CONTENT, "")).await;
        assert_eq!(body, "");
    }

    // ----- create_unauthorized_response ----------------------------------

    #[test]
    fn test_unauthorized_response_status() {
        let response = create_unauthorized_response("WiseGate");
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_unauthorized_response_www_authenticate_header() {
        assert_eq!(challenge_for("WiseGate"), r#"Basic realm="WiseGate""#);
    }

    #[test]
    fn test_unauthorized_response_realm_with_quotes() {
        assert_eq!(
            challenge_for(r#"My "Realm""#),
            r#"Basic realm="My \"Realm\"""#
        );
    }

    #[test]
    fn test_unauthorized_response_realm_with_backslash() {
        assert_eq!(challenge_for(r"My\Realm"), r#"Basic realm="My\\Realm""#);
    }

    #[test]
    fn test_unauthorized_response_content_type() {
        let response = create_unauthorized_response("WiseGate");
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "text/plain"
        );
    }

    // ----- single-pass decoding guarantee --------------------------------

    #[test]
    fn test_url_decode_double_encoding_not_decoded_twice() {
        check_decoding(&[("%252e", "%2e"), ("%2565nv", "%65nv")]);
    }
}