1use http_body_util::{BodyExt, Full};
22use hyper::{Request, Response, StatusCode, body::Incoming};
23use std::convert::Infallible;
24use std::sync::Arc;
25
26use crate::error::WiseGateError;
27use crate::types::{ConfigProvider, RateLimiter};
28use crate::{auth, headers, ip_filter, rate_limiter};
29
30pub async fn handle_request<C: ConfigProvider>(
60 req: Request<Incoming>,
61 forward_host: String,
62 forward_port: u16,
63 limiter: RateLimiter,
64 config: Arc<C>,
65 http_client: reqwest::Client,
66) -> Result<Response<Full<bytes::Bytes>>, Infallible> {
67 let real_client_ip =
69 match ip_filter::extract_and_validate_real_ip(req.headers(), config.as_ref()) {
70 Some(ip) => ip,
71 None => {
72 if config.allowed_proxy_ips().is_none() {
75 "unknown".to_string()
76 } else {
77 let err = WiseGateError::InvalidIp("missing or invalid proxy headers".into());
79 return Ok(create_error_response(err.status_code(), err.user_message()));
80 }
81 }
82 };
83
84 if real_client_ip != "unknown" && ip_filter::is_ip_blocked(&real_client_ip, config.as_ref()) {
86 let err = WiseGateError::IpBlocked(real_client_ip);
87 return Ok(create_error_response(err.status_code(), err.user_message()));
88 }
89
90 let request_path = req.uri().path();
92 if is_url_pattern_blocked(request_path, config.as_ref()) {
93 let err = WiseGateError::PatternBlocked(request_path.to_string());
94 return Ok(create_error_response(err.status_code(), err.user_message()));
95 }
96
97 let request_method = req.method().as_str();
99 if is_method_blocked(request_method, config.as_ref()) {
100 let err = WiseGateError::MethodBlocked(request_method.to_string());
101 return Ok(create_error_response(err.status_code(), err.user_message()));
102 }
103
104 if config.is_auth_enabled() {
107 let auth_header = req
108 .headers()
109 .get(headers::AUTHORIZATION)
110 .and_then(|v| v.to_str().ok());
111
112 let basic_auth_enabled = config.is_basic_auth_enabled();
113 let bearer_auth_enabled = config.is_bearer_auth_enabled();
114
115 let basic_auth_passed =
116 basic_auth_enabled && auth::check_basic_auth(auth_header, config.auth_credentials());
117 let bearer_auth_passed =
118 bearer_auth_enabled && auth::check_bearer_token(auth_header, config.bearer_token());
119
120 if !basic_auth_passed && !bearer_auth_passed {
122 return Ok(create_unauthorized_response(config.auth_realm()));
123 }
124 }
125
126 if real_client_ip != "unknown"
128 && !rate_limiter::check_rate_limit(&limiter, &real_client_ip, config.as_ref()).await
129 {
130 let err = WiseGateError::RateLimitExceeded(real_client_ip);
131 return Ok(create_error_response(err.status_code(), err.user_message()));
132 }
133
134 let mut req = req;
136 if real_client_ip != "unknown"
137 && let Ok(header_value) = real_client_ip.parse()
138 {
139 req.headers_mut().insert("x-real-ip", header_value);
140 }
141
142 forward_request(
144 req,
145 &forward_host,
146 forward_port,
147 config.as_ref(),
148 &http_client,
149 )
150 .await
151}
152
153async fn forward_request(
155 req: Request<Incoming>,
156 host: &str,
157 port: u16,
158 config: &impl ConfigProvider,
159 http_client: &reqwest::Client,
160) -> Result<Response<Full<bytes::Bytes>>, Infallible> {
161 let proxy_config = config.proxy_config();
162 let (parts, body) = req.into_parts();
163 let body_bytes = match body.collect().await {
164 Ok(bytes) => {
165 let collected_bytes = bytes.to_bytes();
166
167 if proxy_config.max_body_size > 0 && collected_bytes.len() > proxy_config.max_body_size
169 {
170 let err = WiseGateError::BodyTooLarge {
171 size: collected_bytes.len(),
172 max: proxy_config.max_body_size,
173 };
174 return Ok(create_error_response(err.status_code(), err.user_message()));
175 }
176
177 collected_bytes
178 }
179 Err(e) => {
180 let err = WiseGateError::BodyReadError(e.to_string());
181 return Ok(create_error_response(err.status_code(), err.user_message()));
182 }
183 };
184
185 forward_with_reqwest(parts, body_bytes, host, port, http_client).await
186}
187
188async fn forward_with_reqwest(
190 parts: hyper::http::request::Parts,
191 body_bytes: bytes::Bytes,
192 host: &str,
193 port: u16,
194 client: &reqwest::Client,
195) -> Result<Response<Full<bytes::Bytes>>, Infallible> {
196 let destination_uri = format!(
198 "http://{}:{}{}",
199 host,
200 port,
201 parts.uri.path_and_query().map_or("", |pq| pq.as_str())
202 );
203
204 let mut req_builder = match parts.method.as_str() {
206 "GET" => client.get(&destination_uri),
207 "POST" => client.post(&destination_uri),
208 "PUT" => client.put(&destination_uri),
209 "DELETE" => client.delete(&destination_uri),
210 "HEAD" => client.head(&destination_uri),
211 "PATCH" => client.patch(&destination_uri),
212 "OPTIONS" => client.request(reqwest::Method::OPTIONS, &destination_uri),
213 method => {
214 match reqwest::Method::from_bytes(method.as_bytes()) {
216 Ok(custom_method) => client.request(custom_method, &destination_uri),
217 Err(_) => {
218 let err = WiseGateError::MethodBlocked(format!("{} (unsupported)", method));
219 return Ok(create_error_response(err.status_code(), err.user_message()));
220 }
221 }
222 }
223 };
224
225 for (name, value) in parts.headers.iter() {
227 if name != "host"
228 && name != "content-length"
229 && let Ok(header_value) = value.to_str()
230 {
231 req_builder = req_builder.header(name.as_str(), header_value);
232 }
233 }
234
235 if !body_bytes.is_empty() {
237 req_builder = req_builder.body(body_bytes.to_vec());
238 }
239
240 match req_builder.send().await {
242 Ok(response) => {
243 let status = response.status();
244 let resp_headers = response.headers().clone();
245
246 match response.bytes().await {
247 Ok(body_bytes) => {
248 let mut hyper_response = match Response::builder()
249 .status(status.as_u16())
250 .body(Full::new(body_bytes))
251 {
252 Ok(resp) => resp,
253 Err(e) => {
254 let err = WiseGateError::ProxyError(format!(
255 "Failed to build response: {}",
256 e
257 ));
258 return Ok(create_error_response(
259 err.status_code(),
260 err.user_message(),
261 ));
262 }
263 };
264
265 for (name, value) in resp_headers.iter() {
267 let header_name = name.as_str().to_lowercase();
268 if !headers::is_hop_by_hop(&header_name)
270 && let (Ok(hyper_name), Ok(hyper_value)) = (
271 hyper::header::HeaderName::from_bytes(name.as_str().as_bytes()),
272 hyper::header::HeaderValue::from_bytes(value.as_bytes()),
273 )
274 {
275 hyper_response.headers_mut().insert(hyper_name, hyper_value);
276 }
277 }
278
279 Ok(hyper_response)
280 }
281 Err(e) => {
282 let err = WiseGateError::BodyReadError(format!("response: {}", e));
283 Ok(create_error_response(err.status_code(), err.user_message()))
284 }
285 }
286 }
287 Err(err) => {
288 let wise_err = if err.is_timeout() {
290 WiseGateError::UpstreamTimeout(err.to_string())
291 } else if err.is_connect() {
292 WiseGateError::UpstreamConnectionFailed(err.to_string())
293 } else {
294 WiseGateError::ProxyError(err.to_string())
295 };
296 Ok(create_error_response(
297 wise_err.status_code(),
298 wise_err.user_message(),
299 ))
300 }
301 }
302}
303
304pub fn create_error_response(status: StatusCode, message: &str) -> Response<Full<bytes::Bytes>> {
329 Response::builder()
330 .status(status)
331 .header("content-type", "text/plain")
332 .body(Full::new(bytes::Bytes::from(message.to_string())))
333 .unwrap_or_else(|_| {
334 Response::new(Full::new(bytes::Bytes::from("Internal Server Error")))
336 })
337}
338
339pub fn create_unauthorized_response(realm: &str) -> Response<Full<bytes::Bytes>> {
352 Response::builder()
353 .status(StatusCode::UNAUTHORIZED)
354 .header(
355 headers::WWW_AUTHENTICATE,
356 format!("Basic realm=\"{}\"", realm),
357 )
358 .header("content-type", "text/plain")
359 .body(Full::new(bytes::Bytes::from("401 Unauthorized")))
360 .unwrap_or_else(|_| Response::new(Full::new(bytes::Bytes::from("401 Unauthorized"))))
361}
362
363fn is_url_pattern_blocked(path: &str, config: &impl ConfigProvider) -> bool {
366 let blocked_patterns = config.blocked_patterns();
367 if blocked_patterns.is_empty() {
368 return false;
369 }
370
371 let decoded_path = url_decode(path);
373
374 blocked_patterns
376 .iter()
377 .any(|pattern| path.contains(pattern) || decoded_path.contains(pattern))
378}
379
/// Percent-decodes `input` (`%XX` escapes) for blocked-pattern matching.
///
/// Decoding rules:
/// - A `%` is decoded only when it is followed by two ASCII hex digits.
/// - Otherwise the `%` is kept literally and scanning resumes at the very next
///   byte, so `"%%41"` becomes `"%A"` and `"%2"` stays `"%2"`.
/// - Sign characters are not hex digits: `"%+F"` is left untouched.
/// - Decoded bytes that are not valid UTF-8 are replaced with U+FFFD.
/// - `+` is not translated to a space (path semantics, not form encoding).
fn url_decode(input: &str) -> String {
    // Value of one ASCII hex digit (`0-9`, `a-f`, `A-F`), or `None`.
    fn hex_value(byte: u8) -> Option<u8> {
        // `to_digit(16)` yields at most 15, so the cast cannot truncate.
        char::from(byte).to_digit(16).map(|d| d as u8)
    }

    // Decodes a valid `%XX` escape at the start of `rest`, if there is one.
    fn decode_escape(rest: &[u8]) -> Option<u8> {
        match rest {
            [b'%', hi, lo, ..] => Some((hex_value(*hi)? << 4) | hex_value(*lo)?),
            _ => None,
        }
    }

    // Working on bytes is safe: `%` and hex digits are ASCII, and no byte of a
    // multi-byte UTF-8 sequence can be mistaken for them.
    let bytes = input.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match decode_escape(&bytes[i..]) {
            Some(byte) => {
                decoded.push(byte);
                i += 3;
            }
            None => {
                decoded.push(bytes[i]);
                i += 1;
            }
        }
    }

    String::from_utf8_lossy(&decoded).into_owned()
}
410
411fn is_method_blocked(method: &str, config: &impl ConfigProvider) -> bool {
413 let blocked_methods = config.blocked_methods();
414 blocked_methods
415 .iter()
416 .any(|blocked_method| blocked_method == &method.to_uppercase())
417}
418
// Unit tests for the request-filtering helpers and error-response builders.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::TestConfig;
    use http_body_util::BodyExt;

    // ---- url_decode ----

    // Input with no `%` escapes is returned unchanged.
    #[test]
    fn test_url_decode_no_encoding() {
        assert_eq!(url_decode("/path/to/file"), "/path/to/file");
        assert_eq!(url_decode("hello"), "hello");
        assert_eq!(url_decode(""), "");
    }

    // Single `%XX` escapes decode to the corresponding character.
    #[test]
    fn test_url_decode_simple_encoding() {
        assert_eq!(url_decode("%20"), " ");
        assert_eq!(url_decode("hello%20world"), "hello world");
        assert_eq!(url_decode("%2F"), "/");
    }

    // Hex digits are accepted in either case.
    #[test]
    fn test_url_decode_dot_encoding() {
        assert_eq!(url_decode("%2e"), ".");
        assert_eq!(url_decode("%2E"), ".");
        assert_eq!(url_decode(".%2ephp"), "..php");
    }

    // Encoded forms of `.php` that a raw substring check would miss.
    #[test]
    fn test_url_decode_php_bypass() {
        assert_eq!(url_decode(".ph%70"), ".php");
        assert_eq!(url_decode("%2ephp"), ".php");
        assert_eq!(url_decode(".%70%68%70"), ".php");
    }

    // Encoded forms of `.env` that a raw substring check would miss.
    #[test]
    fn test_url_decode_env_bypass() {
        assert_eq!(url_decode(".%65nv"), ".env");
        assert_eq!(url_decode("%2eenv"), ".env");
        assert_eq!(url_decode("%2e%65%6e%76"), ".env");
    }

    // Path-traversal sequence built entirely from escapes.
    #[test]
    fn test_url_decode_multiple_encodings() {
        assert_eq!(url_decode("%2F%2e%2e%2Fetc%2Fpasswd"), "/../etc/passwd");
    }

    // Malformed or truncated escapes are kept literally.
    #[test]
    fn test_url_decode_invalid_hex() {
        assert_eq!(url_decode("%GG"), "%GG");
        assert_eq!(url_decode("%"), "%");
        assert_eq!(url_decode("%2"), "%2");
        assert_eq!(url_decode("%ZZ"), "%ZZ");
    }

    #[test]
    fn test_url_decode_mixed_content() {
        assert_eq!(url_decode("path%2Fto%2Ffile.txt"), "path/to/file.txt");
        assert_eq!(url_decode("hello%20%26%20world"), "hello & world");
    }

    // Multi-byte UTF-8 sequences are reassembled from their escaped bytes.
    #[test]
    fn test_url_decode_unicode() {
        assert_eq!(url_decode("%C3%A9"), "é");
        assert_eq!(url_decode("caf%C3%A9"), "café");
    }

    // ---- is_url_pattern_blocked ----

    #[test]
    fn test_url_pattern_blocked_simple() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php", ".env"]);

        assert!(is_url_pattern_blocked("/file.php", &config));
        assert!(is_url_pattern_blocked("/.env", &config));
        assert!(is_url_pattern_blocked("/path/to/file.php", &config));
    }

    #[test]
    fn test_url_pattern_not_blocked() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php", ".env"]);

        assert!(!is_url_pattern_blocked("/file.html", &config));
        assert!(!is_url_pattern_blocked("/path/to/file.js", &config));
        assert!(!is_url_pattern_blocked("/", &config));
    }

    // With no patterns configured, nothing is blocked.
    #[test]
    fn test_url_pattern_blocked_empty_patterns() {
        let config = TestConfig::new();

        assert!(!is_url_pattern_blocked("/file.php", &config));
        assert!(!is_url_pattern_blocked("/.env", &config));
    }

    // Percent-encoded variants are matched against their decoded form.
    #[test]
    fn test_url_pattern_blocked_bypass_attempt() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php", ".env", "admin"]);

        assert!(is_url_pattern_blocked("/.ph%70", &config));
        assert!(is_url_pattern_blocked("/%2eenv", &config));
        assert!(is_url_pattern_blocked("/adm%69n", &config));
    }

    // NOTE(review): despite its name, this only exercises a single level of
    // encoding. `url_decode` decodes once, so a doubly-encoded path such as
    // `/.ph%2570` decodes to `/.ph%70` and is not matched against `.php`.
    #[test]
    fn test_url_pattern_blocked_double_encoding_attempt() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php"]);

        assert!(is_url_pattern_blocked("/.ph%70", &config));
    }

    // Pattern matching is case-sensitive.
    #[test]
    fn test_url_pattern_blocked_case_sensitive() {
        let config = TestConfig::new().with_blocked_patterns(vec![".PHP"]);

        assert!(is_url_pattern_blocked("/file.PHP", &config));
        assert!(!is_url_pattern_blocked("/file.php", &config));
    }

    // Patterns are substring matches, so they also hit longer path segments.
    #[test]
    fn test_url_pattern_blocked_partial_match() {
        let config = TestConfig::new().with_blocked_patterns(vec!["admin"]);

        assert!(is_url_pattern_blocked("/admin/panel", &config));
        assert!(is_url_pattern_blocked("/path/admin", &config));
        assert!(is_url_pattern_blocked("/administrator", &config));
    }

    // ---- is_method_blocked ----

    #[test]
    fn test_method_blocked() {
        let config = TestConfig::new().with_blocked_methods(vec!["TRACE", "CONNECT"]);

        assert!(is_method_blocked("TRACE", &config));
        assert!(is_method_blocked("CONNECT", &config));
    }

    #[test]
    fn test_method_not_blocked() {
        let config = TestConfig::new().with_blocked_methods(vec!["TRACE", "CONNECT"]);

        assert!(!is_method_blocked("GET", &config));
        assert!(!is_method_blocked("POST", &config));
        assert!(!is_method_blocked("PUT", &config));
        assert!(!is_method_blocked("DELETE", &config));
    }

    #[test]
    fn test_method_blocked_empty_list() {
        let config = TestConfig::new();

        assert!(!is_method_blocked("TRACE", &config));
        assert!(!is_method_blocked("GET", &config));
    }

    // The request method is matched regardless of its casing.
    #[test]
    fn test_method_blocked_case_insensitive() {
        let config = TestConfig::new().with_blocked_methods(vec!["TRACE"]);

        assert!(is_method_blocked("TRACE", &config));
        assert!(is_method_blocked("trace", &config));
        assert!(is_method_blocked("Trace", &config));
    }

    // ---- create_error_response ----

    #[test]
    fn test_create_error_response_status() {
        let response = create_error_response(StatusCode::NOT_FOUND, "Not Found");
        assert_eq!(response.status(), StatusCode::NOT_FOUND);

        let response = create_error_response(StatusCode::FORBIDDEN, "Forbidden");
        assert_eq!(response.status(), StatusCode::FORBIDDEN);

        let response = create_error_response(StatusCode::TOO_MANY_REQUESTS, "Rate limited");
        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[test]
    fn test_create_error_response_content_type() {
        let response = create_error_response(StatusCode::NOT_FOUND, "Not Found");
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "text/plain"
        );
    }

    // The message is returned verbatim as the response body.
    #[tokio::test]
    async fn test_create_error_response_body() {
        let response = create_error_response(StatusCode::NOT_FOUND, "Resource not found");
        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "Resource not found");
    }

    #[tokio::test]
    async fn test_create_error_response_empty_message() {
        let response = create_error_response(StatusCode::NO_CONTENT, "");
        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "");
    }
}