1use crate::acme::CertManager;
7use crate::config::ProxyConfig;
8use crate::error::{ProxyError, Result};
9use crate::lb::LoadBalancer;
10use crate::network_policy::NetworkPolicyChecker;
11use crate::routes::{transform_path, ResolvedService, ServiceRegistry};
12use bytes::Bytes;
13use http::{header, Request, Response, Uri, Version};
14use http_body_util::{BodyExt, Full};
15use hyper::body::Incoming;
16use hyper::upgrade::OnUpgrade;
17use hyper_util::client::legacy::Client;
18use hyper_util::rt::{TokioExecutor, TokioIo};
19use std::net::{IpAddr, SocketAddr};
20use std::sync::Arc;
21use std::task::{Context, Poll};
22use tokio::net::TcpStream;
23use tower::Service;
24use tracing::{debug, error, info, warn};
25use zlayer_spec::ExposeType;
26
/// First two octets of the IPv4 overlay network, i.e. `10.200.0.0/16`.
const OVERLAY_NETWORK: (u8, u8) = (10, 200);

/// Returns `true` when `ip` lies inside the overlay network ([`OVERLAY_NETWORK`], a /16).
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are unwrapped first. A dual-stack
/// listener reports IPv4 peers in that form, and they must be judged exactly like
/// the plain IPv4 address. Every other IPv6 address is outside the overlay.
fn is_overlay_ip(ip: IpAddr) -> bool {
    let v4 = match ip {
        IpAddr::V4(v4) => v4,
        IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
            Some(v4) => v4,
            None => return false,
        },
    };
    let [first, second, ..] = v4.octets();
    (first, second) == OVERLAY_NETWORK
}
41
/// Type-erased body used for every request and response the proxy builds:
/// `Bytes` chunks, with `hyper::Error` as the error type.
pub type BoxBody = http_body_util::combinators::BoxBody<Bytes, hyper::Error>;
44
45#[must_use]
47pub fn empty_body() -> BoxBody {
48 http_body_util::Empty::<Bytes>::new()
49 .map_err(|never| match never {})
50 .boxed()
51}
52
53pub fn full_body(bytes: impl Into<Bytes>) -> BoxBody {
55 Full::new(bytes.into())
56 .map_err(|never| match never {})
57 .boxed()
58}
59
/// Reverse-proxy `tower::Service`: resolves a route for each request, applies
/// the access checks, and forwards the request to a backend picked by the load
/// balancer.
///
/// Shared state is held behind `Arc`s, so cloning is cheap; the `Service` impl
/// clones it for every call. Per-connection state (peer address, TLS flag) is
/// set through the `with_*` builders.
#[derive(Clone)]
pub struct ReverseProxyService {
    /// Route table used to resolve `(host, path)` to a service.
    registry: Arc<ServiceRegistry>,
    /// Picks a backend for a resolved service name.
    load_balancer: Arc<LoadBalancer>,
    /// Pooled HTTP/1 client used for ordinary (non-upgrade) requests.
    client: Client<hyper_util::client::legacy::connect::HttpConnector, BoxBody>,
    /// Proxy configuration (connection pool and forwarding/HSTS header settings).
    config: Arc<ProxyConfig>,
    /// Downstream peer address. When `None`, the overlay and network-policy
    /// checks are skipped.
    remote_addr: Option<SocketAddr>,
    /// Whether the downstream connection is TLS (affects `X-Forwarded-Proto` and HSTS).
    is_tls: bool,
    /// When set, answers ACME challenges under `/.well-known/acme-challenge/`.
    cert_manager: Option<Arc<CertManager>>,
    /// Optional per-request access check against the destination service.
    network_policy_checker: Option<NetworkPolicyChecker>,
    /// Peers whose `CF-Connecting-IP` / `X-Forwarded-For` headers are believed.
    trusted_proxies: Arc<crate::trust::TrustedProxyList>,
}
85
86impl ReverseProxyService {
87 pub fn new(
89 registry: Arc<ServiceRegistry>,
90 load_balancer: Arc<LoadBalancer>,
91 config: Arc<ProxyConfig>,
92 ) -> Self {
93 let client = Client::builder(TokioExecutor::new())
94 .pool_max_idle_per_host(config.pool.max_idle_per_backend)
95 .pool_idle_timeout(config.pool.idle_timeout)
96 .pool_timer(hyper_util::rt::TokioTimer::new())
97 .build_http();
98
99 Self {
100 registry,
101 load_balancer,
102 client,
103 config,
104 remote_addr: None,
105 is_tls: false,
106 cert_manager: None,
107 network_policy_checker: None,
108 trusted_proxies: Arc::new(crate::trust::TrustedProxyList::localhost_only()),
109 }
110 }
111
112 #[must_use]
114 pub fn with_remote_addr(mut self, addr: SocketAddr) -> Self {
115 self.remote_addr = Some(addr);
116 self
117 }
118
119 #[must_use]
121 pub fn with_tls(mut self, is_tls: bool) -> Self {
122 self.is_tls = is_tls;
123 self
124 }
125
126 #[must_use]
131 pub fn with_trusted_proxies(mut self, trusted: Arc<crate::trust::TrustedProxyList>) -> Self {
132 self.trusted_proxies = trusted;
133 self
134 }
135
136 #[must_use]
138 pub fn with_cert_manager(mut self, cm: Arc<CertManager>) -> Self {
139 self.cert_manager = Some(cm);
140 self
141 }
142
143 #[must_use]
145 pub fn with_network_policy_checker(mut self, checker: NetworkPolicyChecker) -> Self {
146 self.network_policy_checker = Some(checker);
147 self
148 }
149
150 #[must_use]
152 pub fn is_tls(&self) -> bool {
153 self.is_tls
154 }
155
156 #[allow(clippy::too_many_lines)]
168 pub async fn proxy_request(&self, mut req: Request<Incoming>) -> Result<Response<BoxBody>> {
169 let start = std::time::Instant::now();
170 let method = req.method().clone();
171 let uri = req.uri().clone();
172
173 let host = req
174 .headers()
175 .get(header::HOST)
176 .and_then(|h| h.to_str().ok())
177 .or_else(|| uri.host())
178 .map(std::string::ToString::to_string);
179
180 let path = uri.path().to_string();
181
182 if path.starts_with("/.well-known/acme-challenge/") {
184 if let Some(token) = path.strip_prefix("/.well-known/acme-challenge/") {
185 if !token.is_empty() {
186 if let Some(ref cm) = self.cert_manager {
187 if let Some(auth) = cm.get_challenge_response(token) {
188 return Ok(Response::builder()
189 .status(200)
190 .header("content-type", "text/plain")
191 .body(full_body(auth))
192 .unwrap());
193 }
194 }
195 }
196 }
197 }
198
199 if crate::tunnel::is_upgrade_request(&req) {
201 let resolved = self
203 .registry
204 .resolve(host.as_deref(), &path)
205 .await
206 .ok_or_else(|| ProxyError::RouteNotFound {
207 host: host.as_deref().unwrap_or("<none>").to_string(),
208 path: path.clone(),
209 })?;
210
211 if resolved.expose == ExposeType::Internal {
213 if let Some(addr) = self.remote_addr {
214 if !is_overlay_ip(addr.ip()) {
215 return Err(ProxyError::Forbidden(
216 "endpoint is internal-only".to_string(),
217 ));
218 }
219 }
220 }
221
222 if let (Some(checker), Some(addr)) = (&self.network_policy_checker, self.remote_addr) {
224 if !checker
225 .check_access(addr.ip(), &resolved.name, "*", resolved.target_port)
226 .await
227 {
228 return Err(ProxyError::Forbidden(format!(
229 "network policy denied access to service '{}'",
230 resolved.name
231 )));
232 }
233 }
234
235 let backend = self.load_balancer.select(&resolved.name).ok_or_else(|| {
236 ProxyError::NoHealthyBackends {
237 service: resolved.name.clone(),
238 }
239 })?;
240 let _guard = backend.track_connection();
241 let backend_addr = backend.addr;
242
243 info!(
244 method = %method,
245 host = ?host,
246 path = %path,
247 backend = %backend_addr,
248 service = %resolved.name,
249 "Forwarding upgrade request"
250 );
251
252 let client_upgrade: OnUpgrade = hyper::upgrade::on(&mut req);
254
255 let original_path = req.uri().path();
257 let transformed_path =
258 transform_path(&resolved.path_prefix, original_path, resolved.strip_prefix);
259 let new_uri = format!(
260 "http://{}{}{}",
261 backend_addr,
262 transformed_path,
263 req.uri()
264 .query()
265 .map(|q| format!("?{q}"))
266 .unwrap_or_default()
267 );
268
269 let (orig_parts, _body) = req.into_parts();
271 let mut backend_parts = http::request::Builder::new()
272 .method(orig_parts.method.clone())
273 .uri(
274 new_uri
275 .parse::<Uri>()
276 .map_err(|e| ProxyError::InvalidRequest(format!("Invalid URI: {e}")))?,
277 )
278 .body(())
279 .unwrap()
280 .into_parts()
281 .0;
282
283 for (name, value) in &orig_parts.headers {
285 backend_parts.headers.insert(name.clone(), value.clone());
286 }
287
288 crate::tunnel::copy_upgrade_headers(&orig_parts, &mut backend_parts);
290
291 self.add_forwarding_headers(&mut backend_parts);
293
294 let tcp_stream = TcpStream::connect(backend_addr).await.map_err(|e| {
296 error!(error = %e, backend = %backend_addr, "Backend upgrade connect failed");
297 ProxyError::BackendConnectionFailed {
298 backend: backend_addr,
299 reason: e.to_string(),
300 }
301 })?;
302 let io = TokioIo::new(tcp_stream);
303
304 let (mut sender, conn) = hyper::client::conn::http1::Builder::new()
306 .preserve_header_case(true)
307 .handshake(io)
308 .await
309 .map_err(|e| {
310 error!(error = %e, backend = %backend_addr, "Backend upgrade handshake failed");
311 ProxyError::BackendRequestFailed(format!("Upgrade handshake failed: {e}"))
312 })?;
313
314 tokio::spawn(async move {
316 if let Err(e) = conn.with_upgrades().await {
317 error!(error = %e, "Backend upgrade connection driver error");
318 }
319 });
320
321 let backend_req =
323 Request::from_parts(backend_parts, http_body_util::Empty::<Bytes>::new());
324 let backend_response = sender.send_request(backend_req).await.map_err(|e| {
325 error!(error = %e, backend = %backend_addr, "Backend upgrade request failed");
326 ProxyError::BackendRequestFailed(e.to_string())
327 })?;
328
329 if backend_response.status() == http::StatusCode::SWITCHING_PROTOCOLS {
330 let server_upgrade: OnUpgrade = hyper::upgrade::on(backend_response);
332
333 let mut resp_builder =
335 Response::builder().status(http::StatusCode::SWITCHING_PROTOCOLS);
336 if let Some(upgrade_val) = orig_parts.headers.get(header::UPGRADE) {
345 resp_builder = resp_builder.header(header::UPGRADE, upgrade_val.clone());
346 }
347 resp_builder = resp_builder.header(header::CONNECTION, "upgrade");
348
349 let client_response = resp_builder.body(empty_body()).map_err(|e| {
350 ProxyError::Internal(format!("Failed to build 101 response: {e}"))
351 })?;
352
353 tokio::spawn(async move {
355 if let Err(e) =
356 crate::tunnel::proxy_upgrade(client_upgrade, server_upgrade).await
357 {
358 debug!(error = %e, "Upgrade tunnel ended");
359 }
360 });
361
362 let (mut parts, body) = client_response.into_parts();
364 if let Ok(hv) = format!("proxy;dur={}", start.elapsed().as_millis()).parse() {
365 parts.headers.insert("server-timing", hv);
366 }
367
368 return Ok(Response::from_parts(parts, body));
369 }
370
371 let (mut parts, body) = backend_response.into_parts();
373 let streaming_body: BoxBody = body.map_err(|e: hyper::Error| e).boxed();
374
375 if self.is_tls && self.config.headers.hsts {
377 let value = if self.config.headers.hsts_subdomains {
378 format!(
379 "max-age={}; includeSubDomains",
380 self.config.headers.hsts_max_age
381 )
382 } else {
383 format!("max-age={}", self.config.headers.hsts_max_age)
384 };
385 if let Ok(hv) = value.parse() {
386 parts.headers.insert("strict-transport-security", hv);
387 }
388 }
389
390 if let Ok(hv) = format!("proxy;dur={}", start.elapsed().as_millis()).parse() {
392 parts.headers.insert("server-timing", hv);
393 }
394
395 return Ok(Response::from_parts(parts, streaming_body));
396 }
397
398 debug!(method = %method, host = ?host, path = %path, "Routing request");
399
400 let resolved = self
402 .registry
403 .resolve(host.as_deref(), &path)
404 .await
405 .ok_or_else(|| ProxyError::RouteNotFound {
406 host: host.as_deref().unwrap_or("<none>").to_string(),
407 path: path.clone(),
408 })?;
409
410 if resolved.expose == ExposeType::Internal {
412 match self.remote_addr {
413 Some(addr) if !is_overlay_ip(addr.ip()) => {
414 warn!(
415 source = %addr.ip(),
416 service = %resolved.name,
417 "Rejected non-overlay source for internal endpoint"
418 );
419 return Err(ProxyError::Forbidden(
420 "endpoint is internal-only".to_string(),
421 ));
422 }
423 None => {
424 debug!(
425 service = %resolved.name,
426 "No remote_addr available; skipping overlay source check"
427 );
428 }
429 _ => {}
430 }
431 }
432
433 if let (Some(checker), Some(addr)) = (&self.network_policy_checker, self.remote_addr) {
435 if !checker
436 .check_access(addr.ip(), &resolved.name, "*", resolved.target_port)
437 .await
438 {
439 return Err(ProxyError::Forbidden(format!(
440 "network policy denied access to service '{}'",
441 resolved.name
442 )));
443 }
444 }
445
446 let backend = self.load_balancer.select(&resolved.name).ok_or_else(|| {
448 ProxyError::NoHealthyBackends {
449 service: resolved.name.clone(),
450 }
451 })?;
452 let _guard = backend.track_connection();
453 let backend_addr = backend.addr;
454
455 info!(
456 method = %method,
457 host = ?host,
458 path = %path,
459 backend = %backend_addr,
460 service = %resolved.name,
461 "Forwarding request"
462 );
463
464 let forwarded_req = self.build_forwarded_request(req, &backend_addr, &resolved)?;
466
467 let response = self.client.request(forwarded_req).await.map_err(|e| {
469 error!(error = %e, backend = %backend_addr, "Backend request failed");
470 ProxyError::BackendRequestFailed(e.to_string())
471 })?;
472
473 let (mut parts, body) = response.into_parts();
474 let streaming_body: BoxBody = body.map_err(|e: hyper::Error| e).boxed();
475
476 if self.is_tls && self.config.headers.hsts {
478 let value = if self.config.headers.hsts_subdomains {
479 format!(
480 "max-age={}; includeSubDomains",
481 self.config.headers.hsts_max_age
482 )
483 } else {
484 format!("max-age={}", self.config.headers.hsts_max_age)
485 };
486 if let Ok(hv) = value.parse() {
487 parts.headers.insert("strict-transport-security", hv);
488 }
489 }
490
491 if let Ok(hv) = format!("proxy;dur={}", start.elapsed().as_millis()).parse() {
493 parts.headers.insert("server-timing", hv);
494 }
495
496 Ok(Response::from_parts(parts, streaming_body))
497 }
498
499 fn build_forwarded_request(
500 &self,
501 req: Request<Incoming>,
502 backend: &SocketAddr,
503 resolved: &ResolvedService,
504 ) -> Result<Request<BoxBody>> {
505 let (mut parts, body) = req.into_parts();
506
507 let original_path = parts.uri.path();
509 let transformed_path =
510 transform_path(&resolved.path_prefix, original_path, resolved.strip_prefix);
511
512 let new_uri = format!(
514 "http://{}{}{}",
515 backend,
516 transformed_path,
517 parts
518 .uri
519 .query()
520 .map(|q| format!("?{q}"))
521 .unwrap_or_default()
522 );
523
524 parts.uri = new_uri
525 .parse::<Uri>()
526 .map_err(|e| ProxyError::InvalidRequest(format!("Invalid URI: {e}")))?;
527
528 self.add_forwarding_headers(&mut parts);
530
531 Self::remove_hop_by_hop_headers(&mut parts);
533
534 let streaming_body: BoxBody = body.map_err(|e: hyper::Error| e).boxed();
535
536 let req = Request::from_parts(parts, streaming_body);
537 Ok(req)
538 }
539
540 fn add_forwarding_headers(&self, parts: &mut http::request::Parts) {
541 let config = &self.config.headers;
542
543 let peer_is_trusted = self
546 .remote_addr
547 .is_some_and(|addr| self.trusted_proxies.is_trusted(addr.ip()));
548
549 let effective_client_ip: Option<IpAddr> = if peer_is_trusted {
554 let cf_ip = parts
555 .headers
556 .get("cf-connecting-ip")
557 .and_then(|h| h.to_str().ok())
558 .and_then(|s| s.trim().parse::<IpAddr>().ok());
559
560 let xff_leftmost = parts
561 .headers
562 .get("x-forwarded-for")
563 .and_then(|h| h.to_str().ok())
564 .and_then(|s| s.split(',').next())
565 .and_then(|s| s.trim().parse::<IpAddr>().ok());
566
567 cf_ip
568 .or(xff_leftmost)
569 .or_else(|| self.remote_addr.map(|a| a.ip()))
570 } else {
571 self.remote_addr.map(|a| a.ip())
572 };
573
574 if config.x_forwarded_for {
576 if let Some(addr) = self.remote_addr {
577 let existing_xff = parts
578 .headers
579 .get("x-forwarded-for")
580 .and_then(|h| h.to_str().ok())
581 .map(std::string::ToString::to_string);
582
583 let new_value = if peer_is_trusted {
584 let real = effective_client_ip.unwrap_or_else(|| addr.ip()).to_string();
588 match existing_xff {
589 Some(chain) if !chain.trim().is_empty() => format!("{real}, {chain}"),
590 _ => real,
591 }
592 } else {
593 match existing_xff {
596 Some(chain) => format!("{}, {}", chain, addr.ip()),
597 None => addr.ip().to_string(),
598 }
599 };
600
601 if let Ok(value) = new_value.parse() {
602 parts.headers.insert("x-forwarded-for", value);
603 }
604 }
605 }
606
607 if config.x_forwarded_proto && parts.headers.get("x-forwarded-proto").is_none() {
609 let proto = if self.is_tls { "https" } else { "http" };
610 if let Ok(value) = proto.parse() {
611 parts.headers.insert("x-forwarded-proto", value);
612 }
613 }
614
615 if config.x_forwarded_host {
617 if let Some(host) = parts.headers.get(header::HOST).cloned() {
618 if parts.headers.get("x-forwarded-host").is_none() {
619 parts.headers.insert("x-forwarded-host", host);
620 }
621 }
622 }
623
624 if config.x_real_ip {
628 if let Some(ip) = effective_client_ip {
629 if parts.headers.get("x-real-ip").is_none() {
630 if let Ok(value) = ip.to_string().parse() {
631 parts.headers.insert("x-real-ip", value);
632 }
633 }
634 }
635 }
636
637 if config.via {
639 let proto_version = match parts.version {
640 Version::HTTP_09 => "0.9",
641 Version::HTTP_10 => "1.0",
642 Version::HTTP_2 => "2.0",
643 Version::HTTP_3 => "3.0",
644 _ => "1.1",
645 };
646
647 let via_value = format!("{} {}", proto_version, config.server_name);
648 let existing = parts
649 .headers
650 .get(header::VIA)
651 .and_then(|h| h.to_str().ok())
652 .map(|s| format!("{s}, {via_value}"))
653 .unwrap_or(via_value);
654
655 if let Ok(value) = existing.parse() {
656 parts.headers.insert(header::VIA, value);
657 }
658 }
659 }
660
661 fn remove_hop_by_hop_headers(parts: &mut http::request::Parts) {
662 const HOP_BY_HOP: &[&str] = &[
664 "connection",
665 "keep-alive",
666 "proxy-authenticate",
667 "proxy-authorization",
668 "te",
669 "trailer",
670 "transfer-encoding",
671 "upgrade",
672 ];
673
674 let connection_headers: Vec<String> = parts
676 .headers
677 .get(header::CONNECTION)
678 .and_then(|h| h.to_str().ok())
679 .map(|value| value.split(',').map(|s| s.trim().to_lowercase()).collect())
680 .unwrap_or_default();
681
682 for header_name in HOP_BY_HOP {
683 parts.headers.remove(*header_name);
684 }
685
686 for header_name in connection_headers {
688 parts.headers.remove(header_name.as_str());
689 }
690 }
691
692 pub fn error_response(error: &ProxyError) -> Response<BoxBody> {
708 let status = error.status_code();
709 let body = status.canonical_reason().map_or_else(
714 || status.as_str().to_string(),
715 |reason| format!("{} {reason}", status.as_u16()),
716 );
717
718 Response::builder()
719 .status(status)
720 .header(header::CONTENT_TYPE, "text/plain; charset=utf-8")
721 .body(full_body(body))
722 .unwrap()
723 }
724}
725
726impl Service<Request<Incoming>> for ReverseProxyService {
727 type Response = Response<BoxBody>;
728 type Error = ProxyError;
729 type Future = std::pin::Pin<
730 Box<
731 dyn std::future::Future<Output = std::result::Result<Self::Response, Self::Error>>
732 + Send,
733 >,
734 >;
735
736 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
737 Poll::Ready(Ok(()))
738 }
739
740 fn call(&mut self, req: Request<Incoming>) -> Self::Future {
741 let this = self.clone();
742 Box::pin(async move { this.proxy_request(req).await })
743 }
744}
745
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trust::TrustedProxyList;

    /// Parses an address literal; a typo in a fixture panics immediately.
    fn ip(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    /// Service with default config whose only interesting state is the peer and trust list.
    fn build_svc(peer: SocketAddr, trusted: TrustedProxyList) -> ReverseProxyService {
        ReverseProxyService::new(
            Arc::new(ServiceRegistry::new()),
            Arc::new(LoadBalancer::new()),
            Arc::new(ProxyConfig::default()),
        )
        .with_remote_addr(peer)
        .with_trusted_proxies(Arc::new(trusted))
    }

    /// Request parts for `GET /` carrying `headers` in the given order.
    fn parts_with_headers(headers: &[(&str, &str)]) -> http::request::Parts {
        headers
            .iter()
            .fold(
                http::request::Builder::new().method("GET").uri("/"),
                |builder, (name, value)| builder.header(*name, *value),
            )
            .body(())
            .unwrap()
            .into_parts()
            .0
    }

    /// Value of header `name` as a string; panics if it is missing.
    fn header_str<'a>(parts: &'a http::request::Parts, name: &str) -> &'a str {
        parts
            .headers
            .get(name)
            .unwrap_or_else(|| panic!("missing header {name}"))
            .to_str()
            .unwrap()
    }

    /// Trust list covering the documentation range 203.0.113.0/24.
    fn doc_range_trusted() -> TrustedProxyList {
        TrustedProxyList::new(vec!["203.0.113.0/24".parse().unwrap()], None)
    }

    #[test]
    fn test_error_response() {
        let response = ReverseProxyService::error_response(&ProxyError::RouteNotFound {
            host: "example.com".to_string(),
            path: "/api".to_string(),
        });
        assert_eq!(response.status(), http::StatusCode::NOT_FOUND);
    }

    #[test]
    fn test_hop_by_hop_headers() {
        let mut parts = parts_with_headers(&[
            ("connection", "keep-alive, x-custom"),
            ("keep-alive", "timeout=5"),
            ("x-custom", "value"),
            ("x-other", "value"),
        ]);

        ReverseProxyService::remove_hop_by_hop_headers(&mut parts);

        for stripped in ["connection", "keep-alive", "x-custom"] {
            assert!(parts.headers.get(stripped).is_none(), "{stripped} should be removed");
        }
        assert!(parts.headers.get("x-other").is_some());
    }

    #[test]
    fn test_is_overlay_ip_accepts_overlay_range() {
        for addr in ["10.200.0.1", "10.200.255.254", "10.200.1.100"] {
            assert!(is_overlay_ip(ip(addr)), "{addr} is inside the overlay");
        }
    }

    #[test]
    fn test_is_overlay_ip_rejects_non_overlay() {
        for addr in ["192.168.1.1", "10.0.0.1", "10.201.0.1", "172.16.0.1", "8.8.8.8"] {
            assert!(!is_overlay_ip(ip(addr)), "{addr} is outside the overlay");
        }
    }

    #[test]
    fn test_is_overlay_ip_rejects_ipv6() {
        for addr in ["::1", "fe80::1"] {
            assert!(!is_overlay_ip(ip(addr)), "{addr} is not an overlay address");
        }
    }

    #[test]
    fn test_forbidden_error_response() {
        let forbidden = ProxyError::Forbidden("endpoint 'ws' is internal-only".to_string());
        let response = ReverseProxyService::error_response(&forbidden);
        assert_eq!(response.status(), http::StatusCode::FORBIDDEN);
    }

    #[test]
    fn trusted_peer_cf_connecting_ip_is_honored() {
        let svc = build_svc("203.0.113.50:443".parse().unwrap(), doc_range_trusted());
        let mut parts = parts_with_headers(&[("cf-connecting-ip", "198.51.100.7")]);

        svc.add_forwarding_headers(&mut parts);

        assert_eq!(header_str(&parts, "x-real-ip"), "198.51.100.7");
        let xff = header_str(&parts, "x-forwarded-for");
        assert!(
            xff.starts_with("198.51.100.7"),
            "XFF should start with real client IP, got {xff}"
        );
    }

    #[test]
    fn trusted_peer_xff_leftmost_is_honored_when_no_cf_header() {
        let svc = build_svc("203.0.113.50:443".parse().unwrap(), doc_range_trusted());
        let mut parts = parts_with_headers(&[("x-forwarded-for", "198.51.100.9, 10.0.0.1")]);

        svc.add_forwarding_headers(&mut parts);

        assert_eq!(header_str(&parts, "x-real-ip"), "198.51.100.9");
        let xff = header_str(&parts, "x-forwarded-for");
        assert!(
            xff.starts_with("198.51.100.9"),
            "XFF should start with leftmost real client, got {xff}"
        );
        assert!(
            xff.contains("10.0.0.1"),
            "original chain should survive: {xff}"
        );
    }

    #[test]
    fn untrusted_peer_cf_connecting_ip_is_ignored() {
        let svc = build_svc("8.8.8.8:443".parse().unwrap(), doc_range_trusted());
        let mut parts = parts_with_headers(&[("cf-connecting-ip", "198.51.100.7")]);

        svc.add_forwarding_headers(&mut parts);

        assert_eq!(header_str(&parts, "x-real-ip"), "8.8.8.8");
        let xff = header_str(&parts, "x-forwarded-for");
        assert!(
            xff.ends_with("8.8.8.8"),
            "XFF for untrusted peer should end with peer IP, got {xff}"
        );
    }

    #[test]
    fn no_headers_uses_peer_ip() {
        let svc = build_svc(
            "198.51.100.250:443".parse().unwrap(),
            TrustedProxyList::localhost_only(),
        );
        let mut parts = parts_with_headers(&[]);

        svc.add_forwarding_headers(&mut parts);

        assert_eq!(header_str(&parts, "x-real-ip"), "198.51.100.250");
        assert_eq!(header_str(&parts, "x-forwarded-for"), "198.51.100.250");
    }
}