1use http_body_util::{BodyExt, Full};
22use hyper::{Request, Response, StatusCode, body::Incoming};
23use std::convert::Infallible;
24use std::sync::Arc;
25
26use crate::types::{ConfigProvider, RateLimiter};
27use crate::{headers, ip_filter, rate_limiter};
28
29pub async fn handle_request<C: ConfigProvider>(
59 req: Request<Incoming>,
60 forward_host: String,
61 forward_port: u16,
62 limiter: RateLimiter,
63 config: Arc<C>,
64 http_client: reqwest::Client,
65) -> Result<Response<Full<bytes::Bytes>>, Infallible> {
66 let real_client_ip =
68 match ip_filter::extract_and_validate_real_ip(req.headers(), config.as_ref()) {
69 Some(ip) => ip,
70 None => {
71 if config.allowed_proxy_ips().is_none() {
74 "unknown".to_string()
75 } else {
76 return Ok(create_error_response(
78 StatusCode::FORBIDDEN,
79 "Invalid request: missing or invalid proxy headers",
80 ));
81 }
82 }
83 };
84
85 if real_client_ip != "unknown" && ip_filter::is_ip_blocked(&real_client_ip, config.as_ref()) {
87 return Ok(create_error_response(
88 StatusCode::FORBIDDEN,
89 "IP address is blocked",
90 ));
91 }
92
93 let request_path = req.uri().path();
95 if is_url_pattern_blocked(request_path, config.as_ref()) {
96 return Ok(create_error_response(StatusCode::NOT_FOUND, "Not Found"));
97 }
98
99 let request_method = req.method().as_str();
101 if is_method_blocked(request_method, config.as_ref()) {
102 return Ok(create_error_response(
103 StatusCode::METHOD_NOT_ALLOWED,
104 "HTTP method not allowed",
105 ));
106 }
107
108 if real_client_ip != "unknown"
110 && !rate_limiter::check_rate_limit(&limiter, &real_client_ip, config.as_ref()).await
111 {
112 return Ok(create_error_response(
113 StatusCode::TOO_MANY_REQUESTS,
114 "Rate limit exceeded",
115 ));
116 }
117
118 let mut req = req;
120 if real_client_ip != "unknown"
121 && let Ok(header_value) = real_client_ip.parse()
122 {
123 req.headers_mut().insert("x-real-ip", header_value);
124 }
125
126 forward_request(
128 req,
129 &forward_host,
130 forward_port,
131 config.as_ref(),
132 &http_client,
133 )
134 .await
135}
136
137async fn forward_request(
139 req: Request<Incoming>,
140 host: &str,
141 port: u16,
142 config: &impl ConfigProvider,
143 http_client: &reqwest::Client,
144) -> Result<Response<Full<bytes::Bytes>>, Infallible> {
145 let proxy_config = config.proxy_config();
146 let (parts, body) = req.into_parts();
147 let body_bytes = match body.collect().await {
148 Ok(bytes) => {
149 let collected_bytes = bytes.to_bytes();
150
151 if proxy_config.max_body_size > 0 && collected_bytes.len() > proxy_config.max_body_size
153 {
154 return Ok(create_error_response(
155 StatusCode::PAYLOAD_TOO_LARGE,
156 "Request body too large",
157 ));
158 }
159
160 collected_bytes
161 }
162 Err(_) => {
163 return Ok(create_error_response(
164 StatusCode::BAD_REQUEST,
165 "Failed to read request body",
166 ));
167 }
168 };
169
170 forward_with_reqwest(parts, body_bytes, host, port, http_client).await
171}
172
173async fn forward_with_reqwest(
175 parts: hyper::http::request::Parts,
176 body_bytes: bytes::Bytes,
177 host: &str,
178 port: u16,
179 client: &reqwest::Client,
180) -> Result<Response<Full<bytes::Bytes>>, Infallible> {
181 let destination_uri = format!(
183 "http://{}:{}{}",
184 host,
185 port,
186 parts.uri.path_and_query().map_or("", |pq| pq.as_str())
187 );
188
189 let mut req_builder = match parts.method.as_str() {
191 "GET" => client.get(&destination_uri),
192 "POST" => client.post(&destination_uri),
193 "PUT" => client.put(&destination_uri),
194 "DELETE" => client.delete(&destination_uri),
195 "HEAD" => client.head(&destination_uri),
196 "PATCH" => client.patch(&destination_uri),
197 "OPTIONS" => client.request(reqwest::Method::OPTIONS, &destination_uri),
198 method => {
199 match reqwest::Method::from_bytes(method.as_bytes()) {
201 Ok(custom_method) => client.request(custom_method, &destination_uri),
202 Err(_) => {
203 return Ok(create_error_response(
204 StatusCode::METHOD_NOT_ALLOWED,
205 "HTTP method not supported",
206 ));
207 }
208 }
209 }
210 };
211
212 for (name, value) in parts.headers.iter() {
214 if name != "host"
215 && name != "content-length"
216 && let Ok(header_value) = value.to_str()
217 {
218 req_builder = req_builder.header(name.as_str(), header_value);
219 }
220 }
221
222 if !body_bytes.is_empty() {
224 req_builder = req_builder.body(body_bytes.to_vec());
225 }
226
227 match req_builder.send().await {
229 Ok(response) => {
230 let status = response.status();
231 let resp_headers = response.headers().clone();
232
233 match response.bytes().await {
234 Ok(body_bytes) => {
235 let mut hyper_response = match Response::builder()
236 .status(status.as_u16())
237 .body(Full::new(body_bytes))
238 {
239 Ok(resp) => resp,
240 Err(_) => {
241 return Ok(create_error_response(
242 StatusCode::INTERNAL_SERVER_ERROR,
243 "Failed to build response",
244 ));
245 }
246 };
247
248 for (name, value) in resp_headers.iter() {
250 let header_name = name.as_str().to_lowercase();
251 if !headers::is_hop_by_hop(&header_name)
253 && let (Ok(hyper_name), Ok(hyper_value)) = (
254 hyper::header::HeaderName::from_bytes(name.as_str().as_bytes()),
255 hyper::header::HeaderValue::from_bytes(value.as_bytes()),
256 )
257 {
258 hyper_response.headers_mut().insert(hyper_name, hyper_value);
259 }
260 }
261
262 Ok(hyper_response)
263 }
264 Err(_) => Ok(create_error_response(
265 StatusCode::BAD_GATEWAY,
266 "Failed to read response body",
267 )),
268 }
269 }
270 Err(err) => {
271 if err.is_timeout() {
273 Ok(create_error_response(
274 StatusCode::GATEWAY_TIMEOUT,
275 "Upstream service timeout",
276 ))
277 } else if err.is_connect() {
278 Ok(create_error_response(
279 StatusCode::BAD_GATEWAY,
280 "Could not connect to upstream service",
281 ))
282 } else {
283 Ok(create_error_response(
284 StatusCode::BAD_GATEWAY,
285 "Upstream service error",
286 ))
287 }
288 }
289 }
290}
291
292pub fn create_error_response(status: StatusCode, message: &str) -> Response<Full<bytes::Bytes>> {
317 Response::builder()
318 .status(status)
319 .header("content-type", "text/plain")
320 .body(Full::new(bytes::Bytes::from(message.to_string())))
321 .unwrap_or_else(|_| {
322 Response::new(Full::new(bytes::Bytes::from("Internal Server Error")))
324 })
325}
326
327fn is_url_pattern_blocked(path: &str, config: &impl ConfigProvider) -> bool {
330 let blocked_patterns = config.blocked_patterns();
331 if blocked_patterns.is_empty() {
332 return false;
333 }
334
335 let decoded_path = url_decode(path);
337
338 blocked_patterns
340 .iter()
341 .any(|pattern| path.contains(pattern) || decoded_path.contains(pattern))
342}
343
/// Percent-decodes `input` (`%XX` escapes) so blocked patterns can be matched
/// against the decoded path.
///
/// A `%` is decoded only when it is immediately followed by two ASCII hex
/// digits (`0-9`, `a-f`, `A-F`). Any other `%` is kept literally, and decoding
/// resumes at the very next byte, so a malformed escape cannot hide a valid
/// escape right after it. For example, `"%%2e"` decodes to `"%."`, as common
/// server-side decoders do. Decoded bytes that do not form valid UTF-8 are
/// replaced with U+FFFD.
fn url_decode(input: &str) -> String {
    let raw = input.as_bytes();
    let mut decoded = Vec::with_capacity(raw.len());
    // `to_digit(16)` accepts only hex digits; unlike `u8::from_str_radix`
    // it rejects a leading `+`.
    let hex_at = |pos: usize| raw.get(pos).and_then(|&b| char::from(b).to_digit(16));
    let mut i = 0;

    while i < raw.len() {
        let escaped = if raw[i] == b'%' {
            hex_at(i + 1).zip(hex_at(i + 2))
        } else {
            None
        };

        match escaped {
            Some((hi, lo)) => {
                // Both digits are < 16, so the value always fits in a u8.
                decoded.push((hi * 16 + lo) as u8);
                i += 3;
            }
            None => {
                decoded.push(raw[i]);
                i += 1;
            }
        }
    }

    String::from_utf8_lossy(&decoded).into_owned()
}
374
375fn is_method_blocked(method: &str, config: &impl ConfigProvider) -> bool {
377 let blocked_methods = config.blocked_methods();
378 blocked_methods
379 .iter()
380 .any(|blocked_method| blocked_method == &method.to_uppercase())
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::TestConfig;
    use http_body_util::BodyExt;

    /// Asserts that `url_decode` maps every `(encoded, expected)` pair as
    /// listed, naming the offending input on failure.
    fn check_decoding(cases: &[(&str, &str)]) {
        for &(encoded, expected) in cases {
            assert_eq!(url_decode(encoded), expected, "url_decode({encoded:?})");
        }
    }

    /// Reads an error response's body into memory for comparison.
    async fn body_bytes_of(response: Response<Full<bytes::Bytes>>) -> bytes::Bytes {
        response.into_body().collect().await.unwrap().to_bytes()
    }

    // ---- url_decode -------------------------------------------------------

    #[test]
    fn test_url_decode_no_encoding() {
        check_decoding(&[
            ("/path/to/file", "/path/to/file"),
            ("hello", "hello"),
            ("", ""),
        ]);
    }

    #[test]
    fn test_url_decode_simple_encoding() {
        check_decoding(&[("%20", " "), ("hello%20world", "hello world"), ("%2F", "/")]);
    }

    #[test]
    fn test_url_decode_dot_encoding() {
        // Lower- and upper-case hex digits must decode identically.
        check_decoding(&[("%2e", "."), ("%2E", "."), (".%2ephp", "..php")]);
    }

    #[test]
    fn test_url_decode_php_bypass() {
        // Each partially encoded spelling must still reveal `.php`.
        check_decoding(&[
            (".ph%70", ".php"),
            ("%2ephp", ".php"),
            (".%70%68%70", ".php"),
        ]);
    }

    #[test]
    fn test_url_decode_env_bypass() {
        check_decoding(&[
            (".%65nv", ".env"),
            ("%2eenv", ".env"),
            ("%2e%65%6e%76", ".env"),
        ]);
    }

    #[test]
    fn test_url_decode_multiple_encodings() {
        check_decoding(&[("%2F%2e%2e%2Fetc%2Fpasswd", "/../etc/passwd")]);
    }

    #[test]
    fn test_url_decode_invalid_hex() {
        // Malformed or truncated escapes pass through untouched.
        check_decoding(&[("%GG", "%GG"), ("%", "%"), ("%2", "%2"), ("%ZZ", "%ZZ")]);
    }

    #[test]
    fn test_url_decode_mixed_content() {
        check_decoding(&[
            ("path%2Fto%2Ffile.txt", "path/to/file.txt"),
            ("hello%20%26%20world", "hello & world"),
        ]);
    }

    #[test]
    fn test_url_decode_unicode() {
        // Multi-byte UTF-8 sequences are reassembled from their escaped bytes.
        check_decoding(&[("%C3%A9", "é"), ("caf%C3%A9", "café")]);
    }

    // ---- is_url_pattern_blocked ------------------------------------------

    #[test]
    fn test_url_pattern_blocked_simple() {
        let cfg = TestConfig::new().with_blocked_patterns(vec![".php", ".env"]);

        for path in ["/file.php", "/.env", "/path/to/file.php"] {
            assert!(is_url_pattern_blocked(path, &cfg), "{path} should be blocked");
        }
    }

    #[test]
    fn test_url_pattern_not_blocked() {
        let cfg = TestConfig::new().with_blocked_patterns(vec![".php", ".env"]);

        for path in ["/file.html", "/path/to/file.js", "/"] {
            assert!(!is_url_pattern_blocked(path, &cfg), "{path} should pass");
        }
    }

    #[test]
    fn test_url_pattern_blocked_empty_patterns() {
        let cfg = TestConfig::new();

        for path in ["/file.php", "/.env"] {
            assert!(!is_url_pattern_blocked(path, &cfg), "{path} should pass");
        }
    }

    #[test]
    fn test_url_pattern_blocked_bypass_attempt() {
        let cfg = TestConfig::new().with_blocked_patterns(vec![".php", ".env", "admin"]);

        // Percent-encoding part of a pattern must not evade the filter.
        for path in ["/.ph%70", "/%2eenv", "/adm%69n"] {
            assert!(is_url_pattern_blocked(path, &cfg), "{path} should be blocked");
        }
    }

    #[test]
    fn test_url_pattern_blocked_double_encoding_attempt() {
        let cfg = TestConfig::new().with_blocked_patterns(vec![".php"]);

        assert!(is_url_pattern_blocked("/.ph%70", &cfg));
    }

    #[test]
    fn test_url_pattern_blocked_case_sensitive() {
        let cfg = TestConfig::new().with_blocked_patterns(vec![".PHP"]);

        // Matching is case-sensitive: only the exact spelling is caught.
        assert!(is_url_pattern_blocked("/file.PHP", &cfg));
        assert!(!is_url_pattern_blocked("/file.php", &cfg));
    }

    #[test]
    fn test_url_pattern_blocked_partial_match() {
        let cfg = TestConfig::new().with_blocked_patterns(vec!["admin"]);

        // A pattern matches anywhere in the path, including inside a segment.
        for path in ["/admin/panel", "/path/admin", "/administrator"] {
            assert!(is_url_pattern_blocked(path, &cfg), "{path} should be blocked");
        }
    }

    // ---- is_method_blocked -----------------------------------------------

    #[test]
    fn test_method_blocked() {
        let cfg = TestConfig::new().with_blocked_methods(vec!["TRACE", "CONNECT"]);

        for method in ["TRACE", "CONNECT"] {
            assert!(is_method_blocked(method, &cfg), "{method} should be blocked");
        }
    }

    #[test]
    fn test_method_not_blocked() {
        let cfg = TestConfig::new().with_blocked_methods(vec!["TRACE", "CONNECT"]);

        for method in ["GET", "POST", "PUT", "DELETE"] {
            assert!(!is_method_blocked(method, &cfg), "{method} should pass");
        }
    }

    #[test]
    fn test_method_blocked_empty_list() {
        let cfg = TestConfig::new();

        for method in ["TRACE", "GET"] {
            assert!(!is_method_blocked(method, &cfg), "{method} should pass");
        }
    }

    #[test]
    fn test_method_blocked_case_insensitive() {
        let cfg = TestConfig::new().with_blocked_methods(vec!["TRACE"]);

        for method in ["TRACE", "trace", "Trace"] {
            assert!(is_method_blocked(method, &cfg), "{method} should be blocked");
        }
    }

    // ---- create_error_response -------------------------------------------

    #[test]
    fn test_create_error_response_status() {
        let cases = [
            (StatusCode::NOT_FOUND, "Not Found"),
            (StatusCode::FORBIDDEN, "Forbidden"),
            (StatusCode::TOO_MANY_REQUESTS, "Rate limited"),
        ];

        for (status, message) in cases {
            assert_eq!(create_error_response(status, message).status(), status);
        }
    }

    #[test]
    fn test_create_error_response_content_type() {
        let resp = create_error_response(StatusCode::NOT_FOUND, "Not Found");
        let content_type = resp.headers().get("content-type").unwrap();

        assert_eq!(content_type, "text/plain");
    }

    #[tokio::test]
    async fn test_create_error_response_body() {
        let body =
            body_bytes_of(create_error_response(StatusCode::NOT_FOUND, "Resource not found"))
                .await;

        assert_eq!(body, "Resource not found");
    }

    #[tokio::test]
    async fn test_create_error_response_empty_message() {
        let body = body_bytes_of(create_error_response(StatusCode::NO_CONTENT, "")).await;

        assert_eq!(body, "");
    }
}