1use anyhow::Error;
2use bytes::Bytes;
3use http_body_util::{BodyExt, Full};
4use hyper::body::Incoming;
5use hyper::header;
6use hyper::{Request, Response, StatusCode, Uri};
7use hyper_rustls::HttpsConnector;
8use hyper_util::client::legacy::connect::HttpConnector;
9use hyper_util::client::legacy::Client;
10use std::sync::Arc;
11use tokio::time::{timeout, Duration};
12use tracing::{error, info};
13
14#[cfg(feature = "logging")]
15use tracing::info_span;
16#[cfg(feature = "logging")]
17use tracing::Instrument;
18
19use crate::config::{extract_hostname, Config, SiteConfig};
20#[cfg(feature = "logging")]
21use crate::proxy::access_log::AccessLogGuard;
22use crate::proxy::access_log::{ensure_request_id, final_request_id};
23use crate::proxy::ActionResult;
24
25use crate::proxy::directives::{
26 handle_header, handle_method, handle_redirect, handle_respond, handle_reverse_proxy,
27 handle_strip_prefix, handle_uri_replace,
28};
29
/// Boxed body type used for every response produced by this module: both the
/// in-memory `Full` bodies (redirects, fixed responses, error pages) and the
/// streamed backend bodies, with their error types erased into a boxed
/// `std::error::Error`.
type ResponseBody =
    http_body_util::combinators::BoxBody<Bytes, Box<dyn std::error::Error + Send + Sync>>;
34
35fn is_hop_header(name: &header::HeaderName) -> bool {
41 matches!(
42 name,
43 &header::CONNECTION
44 | &header::UPGRADE
45 | &header::TE
46 | &header::TRAILER
47 | &header::PROXY_AUTHENTICATE
48 | &header::PROXY_AUTHORIZATION
49 )
50}
51
52pub fn process_directives(
59 directives: &[crate::config::Directive],
60 req: &mut Request<Incoming>,
61 current_path: &str,
62) -> Result<ActionResult, String> {
63 let mut modified_path = current_path.to_string();
64
65 for directive in directives {
66 match directive {
67 crate::config::Directive::Header { name, value } => {
68 if let Err(e) = handle_header(name, value.as_deref(), req) {
69 info!(" Failed to apply header {}: {}", name, e);
70 }
71 }
72
73 crate::config::Directive::UriReplace { find, replace } => {
74 handle_uri_replace(find, replace, &mut modified_path);
75 }
76
77 crate::config::Directive::StripPrefix { prefix } => {
78 handle_strip_prefix(prefix, &mut modified_path);
79 }
80
81 crate::config::Directive::HandlePath {
82 pattern,
83 directives: nested_directives,
84 } => {
85 if let Some(remaining_path) = match_pattern(pattern, &modified_path) {
86 info!(" Matched handle_path: {}", pattern);
87 return process_directives(nested_directives, req, &remaining_path);
88 }
89 }
90
91 crate::config::Directive::Method {
92 methods,
93 directives: nested_directives,
94 } => {
95 if handle_method(methods, req) {
96 info!(" Matched method directive");
97 return process_directives(nested_directives, req, &modified_path);
98 }
99 }
100
101 crate::config::Directive::Redirect { status, url } => {
102 return Ok(handle_redirect(status, url));
103 }
104
105 crate::config::Directive::Respond { status, body } => {
106 return Ok(handle_respond(status, body));
107 }
108
109 crate::config::Directive::ReverseProxy {
110 to,
111 connect_timeout,
112 read_timeout,
113 } => {
114 return Ok(handle_reverse_proxy(
115 to,
116 &modified_path,
117 *connect_timeout,
118 *read_timeout,
119 ));
120 }
121 }
122 }
123
124 Err(format!(
125 "No action directive (respond or reverse_proxy) found in configuration for path: {}",
126 current_path
127 ))
128}
129
130pub async fn proxy(
141 mut req: Request<Incoming>,
142 client: Client<HttpsConnector<HttpConnector>, Incoming>,
143 config: Arc<Config>,
144 remote_addr: std::net::SocketAddr,
145 is_tls: bool,
146) -> Result<Response<ResponseBody>, Error> {
147 let initial_request_id = ensure_request_id(&mut req);
149
150 #[cfg(feature = "logging")]
152 let method = req.method().clone().to_string();
153 let path = req.uri().path().to_string();
154 let host = req
155 .headers()
156 .get(hyper::header::HOST)
157 .and_then(|h| h.to_str().ok())
158 .unwrap_or("localhost")
159 .to_string();
160
161 #[cfg(feature = "logging")]
162 let span = info_span!("request", req_id = %initial_request_id);
163
164 #[allow(unused_variables)]
165 let future = async move {
166 #[cfg(feature = "logging")]
167 let mut log_guard = AccessLogGuard::new(
168 initial_request_id.clone(),
169 remote_addr,
170 method,
171 path.clone(),
172 host.clone(),
173 );
174
175 let site_config = match find_site(&config, &host, is_tls) {
179 Some(config) => config,
180 None => {
181 error!("No configuration found for host: {}", host);
182 let (response, _body_len) = error_response_with_id(
183 StatusCode::NOT_FOUND,
184 &format!("No configuration found for host: {}", host),
185 &initial_request_id,
186 );
187 #[cfg(feature = "logging")]
188 {
189 log_guard.set_bytes_sent(_body_len);
190 log_guard.finish(404);
191 }
192 return Ok(response);
193 }
194 };
195
196 let action_result = match process_directives(&site_config.directives, &mut req, &path) {
198 Ok(result) => result,
199 Err(e) => {
200 error!("Directive processing error: {}", e);
201 let final_id = final_request_id(&req, &initial_request_id);
202 #[cfg(feature = "logging")]
203 {
204 log_guard.set_request_id(final_id.clone());
205 tracing::Span::current().record("req_id", final_id.as_str());
206 }
207 let (response, _body_len) =
208 error_response_with_id(StatusCode::INTERNAL_SERVER_ERROR, &e, &final_id);
209 #[cfg(feature = "logging")]
210 {
211 log_guard.set_bytes_sent(_body_len);
212 log_guard.finish(500);
213 }
214 return Ok(response);
215 }
216 };
217
218 let request_id = final_request_id(&req, &initial_request_id);
220 #[cfg(feature = "logging")]
221 {
222 log_guard.set_request_id(request_id.clone());
223 tracing::Span::current().record("req_id", request_id.as_str());
225 }
226
227 match action_result {
229 ActionResult::Redirect { status, url } => {
230 let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::FOUND);
231
232 let boxed: ResponseBody = Full::new(Bytes::new())
233 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
234 .boxed();
235 let response = Response::builder()
236 .status(status_code)
237 .header("Location", &url)
238 .header("X-Request-ID", &request_id)
239 .body(boxed)?;
240 #[cfg(feature = "logging")]
241 {
242 log_guard.set_bytes_sent(0);
243 log_guard.finish(status_code.as_u16());
244 }
245 Ok(response)
246 }
247 ActionResult::Respond { status, body } => {
248 let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::OK);
249 let _body_len = body.len();
250
251 let boxed: ResponseBody = Full::new(Bytes::from(body))
252 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
253 .boxed();
254 let response = Response::builder()
255 .status(status_code)
256 .header("X-Request-ID", &request_id)
257 .body(boxed)?;
258 #[cfg(feature = "logging")]
259 {
260 log_guard.set_bytes_sent(_body_len);
261 log_guard.finish(status_code.as_u16());
262 }
263 Ok(response)
264 }
265 ActionResult::ReverseProxy {
266 backend_url,
267 path_to_send,
268 connect_timeout: _,
269 read_timeout,
270 } => {
271 let backend_with_proto =
273 if backend_url.starts_with("http://") || backend_url.starts_with("https://") {
274 backend_url
275 } else {
276 format!("http://{}", backend_url)
277 };
278
279 let mut parts = backend_with_proto.parse::<Uri>()?.into_parts();
281 parts.path_and_query = Some(path_to_send.parse()?);
282 let new_uri = Uri::from_parts(parts)?;
283
284 *req.uri_mut() = new_uri.clone();
285
286 let original_host_header = req.headers().get(hyper::header::HOST).cloned();
288
289 req.headers_mut().remove(hyper::header::HOST);
291 if let Some(authority) = new_uri.authority() {
292 if let Ok(host_value) = authority.as_str().parse::<hyper::header::HeaderValue>()
293 {
294 req.headers_mut().insert(hyper::header::HOST, host_value);
295 }
296 }
297
298 if let Some(host_value) = original_host_header.clone() {
300 req.headers_mut().insert("X-Forwarded-Host", host_value);
301 }
302
303 req.headers_mut().insert(
305 "X-Forwarded-Proto",
306 hyper::header::HeaderValue::from_static(if is_tls { "https" } else { "http" }),
307 );
308
309 if let Ok(ip_value) =
311 hyper::header::HeaderValue::from_str(&remote_addr.ip().to_string())
312 {
313 req.headers_mut().insert("X-Forwarded-For", ip_value);
314 }
315
316 req.headers_mut().remove(header::CONNECTION);
318 req.headers_mut().remove("accept-encoding");
319
320 let backend_timeout = read_timeout.unwrap_or(30);
322 match timeout(Duration::from_secs(backend_timeout), client.request(req)).await {
323 Ok(Ok(response)) => {
324 let status = response.status();
325 let headers = response.headers().clone();
326
327 let mut builder = Response::builder().status(status);
329
330 for (name, value) in headers.iter() {
332 if !is_hop_header(name) && name != header::CONTENT_LENGTH {
333 builder = builder.header(name, value);
334 }
335 }
336
337 let (_, incoming_body) = response.into_parts();
338 let boxed: ResponseBody = incoming_body
339 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
340 .boxed();
341
342 let response = builder.header("X-Request-ID", &request_id).body(boxed)?;
343 #[cfg(feature = "logging")]
344 log_guard.finish(status.as_u16());
345 Ok(response)
346 }
347 Ok(Err(e)) => {
348 error!("Backend connection failed: {:?}", e);
349 if e.is_connect() {
350 error!(" Reason: Connection refused - backend unavailable");
351 } else {
352 error!(" Reason: Other connection error");
353 }
354
355 let (response, _body_len) = error_response_with_id(
356 StatusCode::BAD_GATEWAY,
357 "Backend service unavailable",
358 &request_id,
359 );
360 #[cfg(feature = "logging")]
361 {
362 log_guard.set_bytes_sent(_body_len);
363 log_guard.finish(502);
364 }
365 Ok(response)
366 }
367 Err(_) => {
368 error!(
369 "Backend request timed out after {} seconds",
370 backend_timeout
371 );
372
373 let (response, _body_len) = error_response_with_id(
374 StatusCode::GATEWAY_TIMEOUT,
375 "Backend request timed out",
376 &request_id,
377 );
378 #[cfg(feature = "logging")]
379 {
380 log_guard.set_bytes_sent(_body_len);
381 log_guard.finish(504);
382 }
383 Ok(response)
384 }
385 }
386 }
387 }
388 };
389
390 #[cfg(feature = "logging")]
391 let future = future.instrument(span);
392
393 future.await
394}
395
396fn error_response_with_id(
400 status: StatusCode,
401 message: &str,
402 request_id: &str,
403) -> (Response<ResponseBody>, usize) {
404 let body = format!(
405 r#"<!DOCTYPE html>
406 <html>
407 <head><title>{} {}</title></head>
408 <body>
409 <h1>{} {}</h1>
410 <p>{}</p>
411 <hr>
412 <p><em>Rust Proxy Server</em></p>
413 </body>
414 </html>"#,
415 status.as_u16(),
416 status.canonical_reason().unwrap_or("Error"),
417 status.as_u16(),
418 status.canonical_reason().unwrap_or("Error"),
419 message
420 );
421
422 let body_len = body.len();
423 let full = Full::new(Bytes::from(body));
424 let boxed: ResponseBody = full
425 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
426 .boxed();
427
428 let mut builder = Response::builder()
429 .status(status)
430 .header("Content-Type", "text/html; charset=utf-8");
431
432 if let Ok(val) = hyper::header::HeaderValue::from_str(request_id) {
433 builder = builder.header("X-Request-ID", val);
434 }
435
436 let response = builder.body(boxed).unwrap_or_else(|_| {
437 Response::new(
438 Full::new(Bytes::from("Internal Server Error"))
439 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
440 .boxed(),
441 )
442 });
443
444 (response, body_len)
445}
446
/// Matches `path` against a `handle_path` pattern and returns the path that
/// remains after the matched prefix.
///
/// * A pattern of the form `"/prefix/*"` matches `/prefix` itself and anything
///   below `/prefix/`. The prefix must end on a segment boundary, so `/api/*`
///   does not match `/apiary`. The remainder keeps its leading `/`; an exact
///   hit on the prefix yields `"/"`.
/// * Any other pattern must equal `path` exactly and yields `"/"`.
///
/// Returns `None` when the pattern does not match.
pub fn match_pattern(pattern: &str, path: &str) -> Option<String> {
    match pattern.strip_suffix("/*") {
        Some(prefix) => {
            let rest = path.strip_prefix(prefix)?;
            if rest.is_empty() {
                Some("/".to_string())
            } else if rest.starts_with('/') {
                Some(rest.to_string())
            } else {
                // Prefix matched mid-segment (e.g. `/api` vs `/apiary`).
                None
            }
        }
        None if pattern == path => Some("/".to_string()),
        None => None,
    }
}
463
464pub fn find_site<'a>(config: &'a Config, host: &str, is_tls: bool) -> Option<&'a SiteConfig> {
485 if let Some(site) = config.sites.get(host) {
487 return Some(site);
488 }
489
490 let has_port = if host.starts_with('[') {
494 if let Some(bracket_end) = host.find(']') {
496 host[bracket_end..].contains(':')
497 } else {
498 false
499 }
500 } else {
501 host.contains(':')
502 };
503
504 if !has_port {
505 let default_port = if is_tls { 443 } else { 80 };
507 let candidate = format!("{}:{}", host, default_port);
508 if let Some(site) = config.sites.get(&candidate) {
509 return Some(site);
510 }
511
512 if is_tls {
514 let mut matches = config.sites.values().filter(|s| {
515 s.tls.is_some() && extract_hostname(&s.address).eq_ignore_ascii_case(host)
516 });
517 if let Some(site) = matches.next() {
518 if matches.next().is_none() {
519 return Some(site);
520 }
521 }
522 }
523 } else {
524 let hostname = if host.starts_with('[') {
526 let end = host.find(']').unwrap_or(host.len());
528 host[1..end].to_string()
529 } else {
530 host.rsplit(':').next_back().unwrap_or(host).to_string()
531 };
532 if let Some(site) = config.sites.get(&hostname) {
533 return Some(site);
534 }
535 }
536
537 None
538}
539
#[cfg(test)]
mod find_site_tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a `Config` whose site keys equal their addresses; `true` attaches
    /// a dummy TLS configuration to that site.
    fn make_config(sites: Vec<(&str, bool)>) -> Config {
        let mut map = HashMap::new();
        for (addr, has_tls) in sites {
            map.insert(
                addr.to_string(),
                crate::config::SiteConfig {
                    address: addr.to_string(),
                    directives: vec![],
                    tls: if has_tls {
                        Some(crate::config::TlsConfig {
                            cert_path: "/fake/cert.pem".to_string(),
                            key_path: "/fake/key.pem".to_string(),
                        })
                    } else {
                        None
                    },
                },
            );
        }
        Config { sites: map }
    }

    #[test]
    fn test_exact_match() {
        let config = make_config(vec![("example.com:443", true)]);
        assert!(find_site(&config, "example.com:443", true).is_some());
    }

    #[test]
    fn test_tls_host_without_port_finds_443() {
        let config = make_config(vec![("example.com:443", true)]);
        assert!(
            find_site(&config, "example.com", true).is_some(),
            "Should find example.com:443 when Host has no port and is_tls=true"
        );
    }

    #[test]
    fn test_http_host_without_port_finds_80() {
        let config = make_config(vec![("example.com:80", false)]);
        assert!(
            find_site(&config, "example.com", false).is_some(),
            "Should find example.com:80 when Host has no port and is_tls=false"
        );
    }

    #[test]
    fn test_tls_host_without_port_no_match_on_80() {
        let config = make_config(vec![("example.com:80", false)]);
        assert!(
            find_site(&config, "example.com", true).is_none(),
            "TLS on port 443 should not find :80 site"
        );
    }

    #[test]
    fn test_host_with_port_strips_port_fallback() {
        let config = make_config(vec![("example.com", false)]);
        assert!(
            find_site(&config, "example.com:8080", false).is_some(),
            "Should strip port from Host and find config without port"
        );
    }

    #[test]
    fn test_tls_host_without_port_finds_non_standard_port() {
        let config = make_config(vec![("alpha.local:8443", true)]);
        assert!(
            find_site(&config, "alpha.local", true).is_some(),
            "Should find alpha.local:8443 when Host has no port on TLS"
        );
    }

    #[test]
    fn test_no_match() {
        let config = make_config(vec![("other.com:443", true)]);
        assert!(find_site(&config, "example.com", true).is_none());
    }

    #[test]
    fn test_tls_fallback_is_case_insensitive() {
        let config = make_config(vec![("alpha.local:8443", true)]);
        assert!(
            find_site(&config, "ALPHA.LOCAL", true).is_some(),
            "TLS hostname fallback should ignore ASCII case"
        );
    }

    #[test]
    fn test_tls_fallback_ambiguous_hostname_returns_none() {
        let config = make_config(vec![("alpha.local:8443", true), ("alpha.local:9443", true)]);
        assert!(
            find_site(&config, "alpha.local", true).is_none(),
            "Two TLS sites for the same hostname must not be guessed between"
        );
    }

    #[test]
    fn test_tls_fallback_ignores_non_tls_sites() {
        let config = make_config(vec![("alpha.local:8443", true), ("alpha.local:8080", false)]);
        assert_eq!(
            find_site(&config, "alpha.local", true).map(|s| s.address.as_str()),
            Some("alpha.local:8443"),
            "Only TLS-enabled sites take part in the TLS hostname fallback"
        );
    }

    #[test]
    fn test_http_host_without_port_has_no_hostname_fallback() {
        let config = make_config(vec![("alpha.local:8080", false)]);
        assert!(
            find_site(&config, "alpha.local", false).is_none(),
            "Plain HTTP only tries the :80 default, not other ports"
        );
    }

    #[test]
    fn test_bracketed_ipv6_without_port_uses_default_port() {
        let config = make_config(vec![("[::1]:443", true)]);
        assert!(
            find_site(&config, "[::1]", true).is_some(),
            "Colons inside an IPv6 literal must not be mistaken for a port"
        );
    }
}