Skip to main content

zlayer_proxy/
service.rs

1//! Reverse proxy service implementation
2//!
3//! This module provides the core proxy service that handles request forwarding.
4//! It uses the `ServiceRegistry` for route resolution and backend selection.
5
6use crate::acme::CertManager;
7use crate::config::ProxyConfig;
8use crate::error::{ProxyError, Result};
9use crate::lb::LoadBalancer;
10use crate::network_policy::NetworkPolicyChecker;
11use crate::routes::{transform_path, ResolvedService, ServiceRegistry};
12use bytes::Bytes;
13use http::{header, Request, Response, Uri, Version};
14use http_body_util::{BodyExt, Full};
15use hyper::body::Incoming;
16use hyper::upgrade::OnUpgrade;
17use hyper_util::client::legacy::Client;
18use hyper_util::rt::{TokioExecutor, TokioIo};
19use std::collections::VecDeque;
20use std::net::{IpAddr, SocketAddr};
21use std::sync::Arc;
22use std::task::{Context, Poll};
23use std::time::{Duration, Instant};
24use tokio::net::TcpStream;
25use tokio::sync::Mutex;
26use tower::Service;
27use tracing::{debug, error, info, warn};
28use zlayer_spec::ExposeType;
29
/// Maximum time [`ReverseProxyService::proxy_request`] keeps a request parked
/// while an [`Activator`] wakes a scaled-to-zero service. Once it has passed,
/// the request falls back to the ordinary `503` response.
const ACTIVATE_DEADLINE: Duration = Duration::from_secs(30);

/// Interval between backend-selection retries after [`Activator::activate`]
/// has been invoked: short enough that a parked request resumes promptly once
/// a backend shows up, long enough that the load balancer is not polled in a
/// tight loop.
const ACTIVATE_POLL_STEP: Duration = Duration::from_millis(200);

/// Length of the trailing window that [`RpsRegistry`] averages over when it
/// reports requests-per-second.
const RPS_WINDOW: Duration = Duration::from_secs(10);
43
/// On-demand activation hook for scale-to-zero services.
///
/// When the proxy resolves a route whose backend group currently has **no
/// healthy backends** (the scale-to-zero idle state), it calls
/// [`Activator::activate`] with the resolved load-balancer group name. The
/// implementation is expected to trigger a scale-up (e.g. scale the service to
/// its activation floor) and return once it has *initiated* the scale. The
/// proxy then re-polls backend selection on a bounded loop and forwards the
/// held request the moment a healthy backend appears.
///
/// Returning `Err` is non-fatal: the proxy logs it and still polls for a
/// backend, falling through to the existing no-healthy-backends `503` path
/// once its own deadline passes. A flaky activator therefore never turns into
/// a hard failure by itself.
///
/// NOTE(review): the proxy awaits this call before it begins re-polling for a
/// backend, so implementations should return promptly once the scale-up has
/// been initiated rather than waiting for the service to become ready.
#[async_trait::async_trait]
pub trait Activator: Send + Sync {
    /// Trigger activation (scale-up) for the load-balancer group `service`.
    ///
    /// `service` is the resolved LB group name (the same string passed to
    /// [`LoadBalancer::select`](crate::lb::LoadBalancer::select)), NOT
    /// necessarily the bare service name. Implementations that need the bare
    /// service name should derive it from this key.
    ///
    /// The proxy may call this concurrently for the same `service` when
    /// several requests arrive while it is idle, so implementations should
    /// tolerate repeated activation requests.
    ///
    /// # Errors
    ///
    /// Returns a human-readable error string if activation could not be
    /// initiated. The proxy treats this as non-fatal and falls back to `503`.
    async fn activate(&self, service: &str) -> std::result::Result<(), String>;
}
72
/// Per-service sliding-window request-rate counter.
///
/// Records one timestamp per successfully-routed request and reports the
/// requests-per-second rate over a fixed [`RPS_WINDOW`].
///
/// The type is **not** `Clone`. To share one instance across the proxy's
/// per-connection service clones, wrap it in an `Arc`, as
/// `ReverseProxyService` does with its `Option<Arc<RpsRegistry>>` field. All
/// state sits behind a single async mutex, so every method takes `&self`.
///
/// This is the real per-service RPS signal the autoscaler consumes: the proxy
/// records every routed request via [`RpsRegistry::record`], and the scheduler
/// reads [`RpsRegistry::rps`] for the same service key to drive request-rate
/// scaling.
#[derive(Debug, Default)]
pub struct RpsRegistry {
    /// Per-service ring of recent request timestamps, appended at the back as
    /// requests are recorded.
    ///
    /// Rings are pruned to [`RPS_WINDOW`] lazily, whenever a service is
    /// recorded or read. Service keys themselves are never removed, so the map
    /// grows with the number of distinct keys ever seen.
    services: Mutex<std::collections::HashMap<String, VecDeque<Instant>>>,
}
89
90impl RpsRegistry {
91    /// Create an empty registry.
92    #[must_use]
93    pub fn new() -> Self {
94        Self::default()
95    }
96
97    /// Record a single request against `service` at the current instant,
98    /// pruning timestamps older than the window.
99    pub async fn record(&self, service: &str) {
100        let now = Instant::now();
101        let cutoff = now.checked_sub(RPS_WINDOW).unwrap_or(now);
102        let mut map = self.services.lock().await;
103        let ring = map.entry(service.to_string()).or_default();
104        ring.push_back(now);
105        while ring.front().is_some_and(|t| *t < cutoff) {
106            ring.pop_front();
107        }
108    }
109
110    /// Current requests-per-second for `service`, averaged over the sliding
111    /// window. Returns `0.0` for an unknown or idle service.
112    pub async fn rps(&self, service: &str) -> f64 {
113        let now = Instant::now();
114        let cutoff = now.checked_sub(RPS_WINDOW).unwrap_or(now);
115        let mut map = self.services.lock().await;
116        let Some(ring) = map.get_mut(service) else {
117            return 0.0;
118        };
119        while ring.front().is_some_and(|t| *t < cutoff) {
120            ring.pop_front();
121        }
122        let count = ring.len();
123        #[allow(clippy::cast_precision_loss)]
124        {
125            count as f64 / RPS_WINDOW.as_secs_f64()
126        }
127    }
128
129    /// Snapshot of the current per-service RPS for every service seen within
130    /// the window. Services whose window has fully drained report `0.0`.
131    pub async fn snapshot(&self) -> std::collections::HashMap<String, f64> {
132        let now = Instant::now();
133        let cutoff = now.checked_sub(RPS_WINDOW).unwrap_or(now);
134        let window_secs = RPS_WINDOW.as_secs_f64();
135        let mut map = self.services.lock().await;
136        let mut out = std::collections::HashMap::with_capacity(map.len());
137        for (name, ring) in map.iter_mut() {
138            while ring.front().is_some_and(|t| *t < cutoff) {
139                ring.pop_front();
140            }
141            #[allow(clippy::cast_precision_loss)]
142            let rps = ring.len() as f64 / window_secs;
143            out.insert(name.clone(), rps);
144        }
145        out
146    }
147}
148
/// The overlay network used for internal service communication, expressed as
/// the two leading octets of its IPv4 prefix.
/// Source IPs outside this range are rejected for internal-only routes.
const OVERLAY_NETWORK: (u8, u8) = (10, 200); // 10.200.0.0/16

/// Check whether `ip` belongs to the overlay network (`10.200.0.0/16`).
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are unwrapped first. A
/// dual-stack listener bound to `[::]` reports IPv4 peers in that form, and
/// they must be classified by their embedded IPv4 address; otherwise genuine
/// overlay clients would be rejected. Every other IPv6 address is outside the
/// overlay.
fn is_overlay_ip(ip: IpAddr) -> bool {
    let v4 = match ip {
        IpAddr::V4(v4) => v4,
        IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
            Some(v4) => v4,
            None => return false,
        },
    };
    let [first, second, ..] = v4.octets();
    (first, second) == OVERLAY_NETWORK
}
163
/// Type-erased HTTP body with [`Bytes`] chunks and [`hyper::Error`] errors.
///
/// Used both for responses returned to clients and for requests forwarded
/// through the pooled backend client. Streamed backend bodies and locally
/// built bodies (see `empty_body` / `full_body`) therefore share one type.
pub type BoxBody = http_body_util::combinators::BoxBody<Bytes, hyper::Error>;
166
167/// Empty body utility
168#[must_use]
169pub fn empty_body() -> BoxBody {
170    http_body_util::Empty::<Bytes>::new()
171        .map_err(|never| match never {})
172        .boxed()
173}
174
175/// Full body utility
176pub fn full_body(bytes: impl Into<Bytes>) -> BoxBody {
177    Full::new(bytes.into())
178        .map_err(|never| match never {})
179        .boxed()
180}
181
/// The reverse proxy service.
///
/// Resolves each request's route through the [`ServiceRegistry`], enforces
/// internal-only exposure and (optionally) network policy, and picks a backend
/// from the [`LoadBalancer`], optionally waking scaled-to-zero services via an
/// [`Activator`]. It then forwards the request: over a pooled HTTP client for
/// ordinary requests, or over a dedicated connection for protocol upgrades.
///
/// The registry, load balancer and configuration are shared via `Arc`.
/// Per-connection settings (client address, TLS flag) and optional
/// capabilities are configured through the `with_*` builder methods.
#[derive(Clone)]
pub struct ReverseProxyService {
    /// Service registry for route resolution
    registry: Arc<ServiceRegistry>,
    /// Load balancer for backend selection
    load_balancer: Arc<LoadBalancer>,
    /// Pooled HTTP client for ordinary (non-upgrade) backend requests
    client: Client<hyper_util::client::legacy::connect::HttpConnector, BoxBody>,
    /// Proxy configuration (pool sizing, header/HSTS settings, ...)
    config: Arc<ProxyConfig>,
    /// TCP peer address of the client, if known (set via `with_remote_addr`).
    /// Used for internal-only source checks, network policy and forwarding
    /// headers; when `None`, those source-based checks are skipped.
    remote_addr: Option<SocketAddr>,
    /// Whether the connection is over TLS (controls HSTS emission)
    is_tls: bool,
    /// Certificate manager for ACME challenge responses
    cert_manager: Option<Arc<CertManager>>,
    /// Optional network policy checker for access control enforcement
    network_policy_checker: Option<NetworkPolicyChecker>,
    /// Trusted upstream proxies. Requests whose TCP peer IP is in this list
    /// may set `CF-Connecting-IP` / `X-Forwarded-For` and be believed. When no
    /// explicit list is provided, defaults to `TrustedProxyList::localhost_only()`
    /// — a safe default for nodes that accidentally receive direct requests.
    trusted_proxies: Arc<crate::trust::TrustedProxyList>,
    /// Optional on-demand activator for scale-to-zero services. When set and a
    /// resolved route has no healthy backend, the request is held while the
    /// activator scales the service up (see [`Activator`]).
    activator: Option<Arc<dyn Activator>>,
    /// Optional per-service request-rate counter. When set, every
    /// successfully-routed request is recorded so the autoscaler can read a
    /// real RPS signal (see [`RpsRegistry`]).
    rps_registry: Option<Arc<RpsRegistry>>,
}
215
216impl ReverseProxyService {
217    /// Create a new reverse proxy service
218    pub fn new(
219        registry: Arc<ServiceRegistry>,
220        load_balancer: Arc<LoadBalancer>,
221        config: Arc<ProxyConfig>,
222    ) -> Self {
223        let client = Client::builder(TokioExecutor::new())
224            .pool_max_idle_per_host(config.pool.max_idle_per_backend)
225            .pool_idle_timeout(config.pool.idle_timeout)
226            .pool_timer(hyper_util::rt::TokioTimer::new())
227            .build_http();
228
229        Self {
230            registry,
231            load_balancer,
232            client,
233            config,
234            remote_addr: None,
235            is_tls: false,
236            cert_manager: None,
237            network_policy_checker: None,
238            trusted_proxies: Arc::new(crate::trust::TrustedProxyList::localhost_only()),
239            activator: None,
240            rps_registry: None,
241        }
242    }
243
244    /// Set the remote client address for this request
245    #[must_use]
246    pub fn with_remote_addr(mut self, addr: SocketAddr) -> Self {
247        self.remote_addr = Some(addr);
248        self
249    }
250
251    /// Mark this connection as being over TLS
252    #[must_use]
253    pub fn with_tls(mut self, is_tls: bool) -> Self {
254        self.is_tls = is_tls;
255        self
256    }
257
258    /// Override the trusted-proxy list (default: `localhost_only`).
259    ///
260    /// Peers in this list are believed when they set `CF-Connecting-IP` or
261    /// `X-Forwarded-For` headers identifying the real client IP.
262    #[must_use]
263    pub fn with_trusted_proxies(mut self, trusted: Arc<crate::trust::TrustedProxyList>) -> Self {
264        self.trusted_proxies = trusted;
265        self
266    }
267
268    /// Set the certificate manager for ACME challenge interception
269    #[must_use]
270    pub fn with_cert_manager(mut self, cm: Arc<CertManager>) -> Self {
271        self.cert_manager = Some(cm);
272        self
273    }
274
275    /// Set the network policy checker for access control enforcement
276    #[must_use]
277    pub fn with_network_policy_checker(mut self, checker: NetworkPolicyChecker) -> Self {
278        self.network_policy_checker = Some(checker);
279        self
280    }
281
282    /// Set the on-demand activator for scale-to-zero services.
283    ///
284    /// With an activator installed, a request to a resolved route whose backend
285    /// group has no healthy backend triggers [`Activator::activate`] and is held
286    /// (up to [`ACTIVATE_DEADLINE`]) until a backend appears, instead of
287    /// immediately returning `503`.
288    #[must_use]
289    pub fn with_activator(mut self, activator: Arc<dyn Activator>) -> Self {
290        self.activator = Some(activator);
291        self
292    }
293
294    /// Set the per-service request-rate registry.
295    ///
296    /// With a registry installed, every successfully-routed request is recorded
297    /// so the autoscaler can read a real per-service RPS signal.
298    #[must_use]
299    pub fn with_rps_registry(mut self, rps_registry: Arc<RpsRegistry>) -> Self {
300        self.rps_registry = Some(rps_registry);
301        self
302    }
303
304    /// Check if this connection is over TLS
305    #[must_use]
306    pub fn is_tls(&self) -> bool {
307        self.is_tls
308    }
309
310    /// Handle an incoming HTTP request
311    ///
312    /// # Errors
313    ///
314    /// Returns an error if route resolution fails, no healthy backends are
315    /// available, or the backend request fails.
316    ///
317    /// # Panics
318    ///
319    /// Panics if building a well-formed HTTP response for an ACME challenge
320    /// or upgrade reply fails (indicates a bug, not a runtime condition).
321    #[allow(clippy::too_many_lines)]
322    pub async fn proxy_request(&self, mut req: Request<Incoming>) -> Result<Response<BoxBody>> {
323        let start = std::time::Instant::now();
324        let method = req.method().clone();
325        let uri = req.uri().clone();
326
327        let host = req
328            .headers()
329            .get(header::HOST)
330            .and_then(|h| h.to_str().ok())
331            .or_else(|| uri.host())
332            .map(std::string::ToString::to_string);
333
334        let path = uri.path().to_string();
335
336        // ACME HTTP-01 challenge interception. This is TERMINAL: any request
337        // whose path is under /.well-known/acme-challenge/ is fully handled
338        // here and never falls through to vhost routing (which would return a
339        // confusing 403 Forbidden for an HTTPS-only host). A stored token
340        // returns 200 with the key authorization; an unknown/expired/empty
341        // token (or absent cert manager) returns a clean 404.
342        if let Some(token) = path.strip_prefix("/.well-known/acme-challenge/") {
343            if !token.is_empty() {
344                if let Some(ref cm) = self.cert_manager {
345                    if let Some(auth) = cm.get_challenge_response(token) {
346                        return Ok(Response::builder()
347                            .status(200)
348                            .header("content-type", "text/plain")
349                            .body(full_body(auth))
350                            .unwrap());
351                    }
352                }
353            }
354            tracing::warn!(
355                token = %token,
356                cert_manager = self.cert_manager.is_some(),
357                host = host.as_deref().unwrap_or("<none>"),
358                "ACME HTTP-01 challenge token not found; returning 404"
359            );
360            return Ok(Response::builder()
361                .status(404)
362                .header("content-type", "text/plain")
363                .body(full_body("ACME challenge token not found"))
364                .unwrap());
365        }
366
367        // Check for WebSocket/HTTP upgrade
368        if crate::tunnel::is_upgrade_request(&req) {
369            // Resolve to get backend for upgrade
370            let resolved = self
371                .registry
372                .resolve(host.as_deref(), &path)
373                .await
374                .ok_or_else(|| ProxyError::RouteNotFound {
375                    host: host.as_deref().unwrap_or("<none>").to_string(),
376                    path: path.clone(),
377                })?;
378
379            // Enforce internal endpoints
380            if resolved.expose == ExposeType::Internal {
381                if let Some(addr) = self.remote_addr {
382                    if !is_overlay_ip(addr.ip()) {
383                        return Err(ProxyError::Forbidden(
384                            "endpoint is internal-only".to_string(),
385                        ));
386                    }
387                }
388            }
389
390            // Enforce network policy access rules
391            if let (Some(checker), Some(addr)) = (&self.network_policy_checker, self.remote_addr) {
392                if !checker
393                    .check_access(addr.ip(), &resolved.name, "*", resolved.target_port)
394                    .await
395                {
396                    return Err(ProxyError::Forbidden(format!(
397                        "network policy denied access to service '{}'",
398                        resolved.name
399                    )));
400                }
401            }
402
403            let backend = self
404                .select_or_activate(&resolved.name)
405                .await
406                .ok_or_else(|| ProxyError::NoHealthyBackends {
407                    service: resolved.name.clone(),
408                })?;
409            let _guard = backend.track_connection();
410            let backend_addr = backend.addr;
411
412            // Record the routed request for per-service RPS metrics.
413            if let Some(rps) = &self.rps_registry {
414                rps.record(&resolved.name).await;
415            }
416
417            info!(
418                method = %method,
419                host = ?host,
420                path = %path,
421                backend = %backend_addr,
422                service = %resolved.name,
423                "Forwarding upgrade request"
424            );
425
426            // Extract the client's OnUpgrade future BEFORE consuming the request
427            let client_upgrade: OnUpgrade = hyper::upgrade::on(&mut req);
428
429            // Build the backend URI
430            let original_path = req.uri().path();
431            let transformed_path =
432                transform_path(&resolved.path_prefix, original_path, resolved.strip_prefix);
433            let new_uri = format!(
434                "http://{}{}{}",
435                backend_addr,
436                transformed_path,
437                req.uri()
438                    .query()
439                    .map(|q| format!("?{q}"))
440                    .unwrap_or_default()
441            );
442
443            // Build backend request, preserving upgrade headers
444            let (orig_parts, _body) = req.into_parts();
445            let mut backend_parts = http::request::Builder::new()
446                .method(orig_parts.method.clone())
447                .uri(
448                    new_uri
449                        .parse::<Uri>()
450                        .map_err(|e| ProxyError::InvalidRequest(format!("Invalid URI: {e}")))?,
451                )
452                .body(())
453                .unwrap()
454                .into_parts()
455                .0;
456
457            // Copy all original headers first (preserving Host, etc.)
458            for (name, value) in &orig_parts.headers {
459                backend_parts.headers.insert(name.clone(), value.clone());
460            }
461
462            // Copy upgrade-specific headers (Connection, Upgrade, Sec-WebSocket-*)
463            crate::tunnel::copy_upgrade_headers(&orig_parts, &mut backend_parts);
464
465            // Add forwarding headers
466            self.add_forwarding_headers(&mut backend_parts);
467
468            // Connect directly to backend (bypass connection pool for long-lived upgrades)
469            let tcp_stream = TcpStream::connect(backend_addr).await.map_err(|e| {
470                error!(error = %e, backend = %backend_addr, "Backend upgrade connect failed");
471                ProxyError::BackendConnectionFailed {
472                    backend: backend_addr,
473                    reason: e.to_string(),
474                }
475            })?;
476            let io = TokioIo::new(tcp_stream);
477
478            // Perform HTTP/1.1 handshake preserving header case
479            let (mut sender, conn) = hyper::client::conn::http1::Builder::new()
480                .preserve_header_case(true)
481                .handshake(io)
482                .await
483                .map_err(|e| {
484                    error!(error = %e, backend = %backend_addr, "Backend upgrade handshake failed");
485                    ProxyError::BackendRequestFailed(format!("Upgrade handshake failed: {e}"))
486                })?;
487
488            // Spawn the connection driver
489            tokio::spawn(async move {
490                if let Err(e) = conn.with_upgrades().await {
491                    error!(error = %e, "Backend upgrade connection driver error");
492                }
493            });
494
495            // Send the request to the backend
496            let backend_req =
497                Request::from_parts(backend_parts, http_body_util::Empty::<Bytes>::new());
498            let backend_response = sender.send_request(backend_req).await.map_err(|e| {
499                error!(error = %e, backend = %backend_addr, "Backend upgrade request failed");
500                ProxyError::BackendRequestFailed(e.to_string())
501            })?;
502
503            if backend_response.status() == http::StatusCode::SWITCHING_PROTOCOLS {
504                // Get the server's OnUpgrade future
505                let server_upgrade: OnUpgrade = hyper::upgrade::on(backend_response);
506
507                // Build 101 response to send back to the client
508                let mut resp_builder =
509                    Response::builder().status(http::StatusCode::SWITCHING_PROTOCOLS);
510                // Note: we need to construct the response manually since we consumed
511                // the backend response to get OnUpgrade. Copy relevant headers.
512                // The hyper::upgrade::on() for the response does NOT consume it —
513                // it was consumed. We need to return a 101 with appropriate headers.
514                // Actually, hyper::upgrade::on() takes the response by value, so we
515                // must build our own 101 response for the client.
516
517                // For the client response, set Connection: upgrade and Upgrade headers
518                if let Some(upgrade_val) = orig_parts.headers.get(header::UPGRADE) {
519                    resp_builder = resp_builder.header(header::UPGRADE, upgrade_val.clone());
520                }
521                resp_builder = resp_builder.header(header::CONNECTION, "upgrade");
522
523                let client_response = resp_builder.body(empty_body()).map_err(|e| {
524                    ProxyError::Internal(format!("Failed to build 101 response: {e}"))
525                })?;
526
527                // Spawn background task to bridge the upgraded connections
528                tokio::spawn(async move {
529                    if let Err(e) =
530                        crate::tunnel::proxy_upgrade(client_upgrade, server_upgrade).await
531                    {
532                        debug!(error = %e, "Upgrade tunnel ended");
533                    }
534                });
535
536                // Add timing header to the 101 response
537                let (mut parts, body) = client_response.into_parts();
538                if let Ok(hv) = format!("proxy;dur={}", start.elapsed().as_millis()).parse() {
539                    parts.headers.insert("server-timing", hv);
540                }
541
542                return Ok(Response::from_parts(parts, body));
543            }
544
545            // Backend didn't upgrade — stream the response as-is
546            let (mut parts, body) = backend_response.into_parts();
547            let streaming_body: BoxBody = body.map_err(|e: hyper::Error| e).boxed();
548
549            // Add HSTS header for TLS connections
550            if self.is_tls && self.config.headers.hsts {
551                let value = if self.config.headers.hsts_subdomains {
552                    format!(
553                        "max-age={}; includeSubDomains",
554                        self.config.headers.hsts_max_age
555                    )
556                } else {
557                    format!("max-age={}", self.config.headers.hsts_max_age)
558                };
559                if let Ok(hv) = value.parse() {
560                    parts.headers.insert("strict-transport-security", hv);
561                }
562            }
563
564            // Add Server-Timing header
565            if let Ok(hv) = format!("proxy;dur={}", start.elapsed().as_millis()).parse() {
566                parts.headers.insert("server-timing", hv);
567            }
568
569            return Ok(Response::from_parts(parts, streaming_body));
570        }
571
572        debug!(method = %method, host = ?host, path = %path, "Routing request");
573
574        // Resolve route
575        let resolved = self
576            .registry
577            .resolve(host.as_deref(), &path)
578            .await
579            .ok_or_else(|| ProxyError::RouteNotFound {
580                host: host.as_deref().unwrap_or("<none>").to_string(),
581                path: path.clone(),
582            })?;
583
584        // Enforce internal endpoints
585        if resolved.expose == ExposeType::Internal {
586            match self.remote_addr {
587                Some(addr) if !is_overlay_ip(addr.ip()) => {
588                    warn!(
589                        source = %addr.ip(),
590                        service = %resolved.name,
591                        "Rejected non-overlay source for internal endpoint"
592                    );
593                    return Err(ProxyError::Forbidden(
594                        "endpoint is internal-only".to_string(),
595                    ));
596                }
597                None => {
598                    debug!(
599                        service = %resolved.name,
600                        "No remote_addr available; skipping overlay source check"
601                    );
602                }
603                _ => {}
604            }
605        }
606
607        // Enforce network policy access rules
608        if let (Some(checker), Some(addr)) = (&self.network_policy_checker, self.remote_addr) {
609            if !checker
610                .check_access(addr.ip(), &resolved.name, "*", resolved.target_port)
611                .await
612            {
613                return Err(ProxyError::Forbidden(format!(
614                    "network policy denied access to service '{}'",
615                    resolved.name
616                )));
617            }
618        }
619
620        // Select backend via load balancer, activating a scaled-to-zero
621        // service on demand if needed (and an activator is installed).
622        let backend = self
623            .select_or_activate(&resolved.name)
624            .await
625            .ok_or_else(|| ProxyError::NoHealthyBackends {
626                service: resolved.name.clone(),
627            })?;
628        let _guard = backend.track_connection();
629        let backend_addr = backend.addr;
630
631        // Record the routed request for per-service RPS metrics.
632        if let Some(rps) = &self.rps_registry {
633            rps.record(&resolved.name).await;
634        }
635
636        info!(
637            method = %method,
638            host = ?host,
639            path = %path,
640            backend = %backend_addr,
641            service = %resolved.name,
642            "Forwarding request"
643        );
644
645        // Build forwarded request
646        let forwarded_req = self.build_forwarded_request(req, &backend_addr, &resolved)?;
647
648        // Forward to backend
649        let response = self.client.request(forwarded_req).await.map_err(|e| {
650            error!(error = %e, backend = %backend_addr, "Backend request failed");
651            ProxyError::BackendRequestFailed(e.to_string())
652        })?;
653
654        let (mut parts, body) = response.into_parts();
655        let streaming_body: BoxBody = body.map_err(|e: hyper::Error| e).boxed();
656
657        // Add HSTS header for TLS connections
658        if self.is_tls && self.config.headers.hsts {
659            let value = if self.config.headers.hsts_subdomains {
660                format!(
661                    "max-age={}; includeSubDomains",
662                    self.config.headers.hsts_max_age
663                )
664            } else {
665                format!("max-age={}", self.config.headers.hsts_max_age)
666            };
667            if let Ok(hv) = value.parse() {
668                parts.headers.insert("strict-transport-security", hv);
669            }
670        }
671
672        // Add Server-Timing header
673        if let Ok(hv) = format!("proxy;dur={}", start.elapsed().as_millis()).parse() {
674            parts.headers.insert("server-timing", hv);
675        }
676
677        Ok(Response::from_parts(parts, streaming_body))
678    }
679
680    /// Select a healthy backend for `service`, activating a scaled-to-zero
681    /// service on demand if no backend is available and an [`Activator`] is
682    /// installed.
683    ///
684    /// Behavior:
685    /// - If [`LoadBalancer::select`](crate::lb::LoadBalancer::select) returns a
686    ///   backend, it is returned immediately (the common path; zero added cost).
687    /// - Otherwise, with no activator installed, `None` is returned at once so
688    ///   the caller's existing `503` path is preserved unchanged.
689    /// - With an activator installed, [`Activator::activate`] is called once and
690    ///   then backend selection is re-polled every [`ACTIVATE_POLL_STEP`] until a
691    ///   backend appears or [`ACTIVATE_DEADLINE`] elapses, at which point `None`
692    ///   is returned and the caller falls back to `503`.
693    ///
694    /// The caller still holds the (un-consumed) request body across this await,
695    /// so a successful activation forwards the original request normally.
696    async fn select_or_activate(&self, service: &str) -> Option<Arc<crate::lb::Backend>> {
697        if let Some(backend) = self.load_balancer.select(service) {
698            return Some(backend);
699        }
700
701        let Some(activator) = &self.activator else {
702            return None;
703        };
704
705        info!(
706            service = %service,
707            "No healthy backend; invoking activator (scale-to-zero wake-up)"
708        );
709        if let Err(e) = activator.activate(service).await {
710            // Non-fatal: log and fall through to the bounded re-poll. The
711            // service may still be coming up from a concurrent activation.
712            warn!(service = %service, error = %e, "Activator returned an error; will still poll for a backend");
713        }
714
715        let deadline = Instant::now() + ACTIVATE_DEADLINE;
716        loop {
717            if let Some(backend) = self.load_balancer.select(service) {
718                info!(service = %service, "Backend became available after activation");
719                return Some(backend);
720            }
721            if Instant::now() >= deadline {
722                warn!(
723                    service = %service,
724                    "Activation deadline elapsed without a healthy backend; falling back to 503"
725                );
726                return None;
727            }
728            tokio::time::sleep(ACTIVATE_POLL_STEP).await;
729        }
730    }
731
732    fn build_forwarded_request(
733        &self,
734        req: Request<Incoming>,
735        backend: &SocketAddr,
736        resolved: &ResolvedService,
737    ) -> Result<Request<BoxBody>> {
738        let (mut parts, body) = req.into_parts();
739
740        // Transform the path if needed
741        let original_path = parts.uri.path();
742        let transformed_path =
743            transform_path(&resolved.path_prefix, original_path, resolved.strip_prefix);
744
745        // Build new URI for backend
746        let new_uri = format!(
747            "http://{}{}{}",
748            backend,
749            transformed_path,
750            parts
751                .uri
752                .query()
753                .map(|q| format!("?{q}"))
754                .unwrap_or_default()
755        );
756
757        parts.uri = new_uri
758            .parse::<Uri>()
759            .map_err(|e| ProxyError::InvalidRequest(format!("Invalid URI: {e}")))?;
760
761        // Add forwarding headers
762        self.add_forwarding_headers(&mut parts);
763
764        // Remove hop-by-hop headers
765        Self::remove_hop_by_hop_headers(&mut parts);
766
767        let streaming_body: BoxBody = body.map_err(|e: hyper::Error| e).boxed();
768
769        let req = Request::from_parts(parts, streaming_body);
770        Ok(req)
771    }
772
773    fn add_forwarding_headers(&self, parts: &mut http::request::Parts) {
774        let config = &self.config.headers;
775
776        // Determine whether the immediate TCP peer is a trusted upstream proxy
777        // that may dictate the real client IP via CF-Connecting-IP or XFF.
778        let peer_is_trusted = self
779            .remote_addr
780            .is_some_and(|addr| self.trusted_proxies.is_trusted(addr.ip()));
781
782        // Compute the effective client IP:
783        //   - Trusted peer + CF-Connecting-IP (parseable) -> use CF header
784        //   - Trusted peer + leftmost X-Forwarded-For (parseable) -> use XFF
785        //   - Otherwise -> fall back to the TCP peer IP
786        let effective_client_ip: Option<IpAddr> = if peer_is_trusted {
787            let cf_ip = parts
788                .headers
789                .get("cf-connecting-ip")
790                .and_then(|h| h.to_str().ok())
791                .and_then(|s| s.trim().parse::<IpAddr>().ok());
792
793            let xff_leftmost = parts
794                .headers
795                .get("x-forwarded-for")
796                .and_then(|h| h.to_str().ok())
797                .and_then(|s| s.split(',').next())
798                .and_then(|s| s.trim().parse::<IpAddr>().ok());
799
800            cf_ip
801                .or(xff_leftmost)
802                .or_else(|| self.remote_addr.map(|a| a.ip()))
803        } else {
804            self.remote_addr.map(|a| a.ip())
805        };
806
807        // X-Forwarded-For
808        if config.x_forwarded_for {
809            if let Some(addr) = self.remote_addr {
810                let existing_xff = parts
811                    .headers
812                    .get("x-forwarded-for")
813                    .and_then(|h| h.to_str().ok())
814                    .map(std::string::ToString::to_string);
815
816                let new_value = if peer_is_trusted {
817                    // Trusted proxy: prepend the real client IP (from CF /
818                    // leftmost XFF / peer) to any existing chain so downstream
819                    // sees [real_client, ...upstream_chain].
820                    let real = effective_client_ip.unwrap_or_else(|| addr.ip()).to_string();
821                    match existing_xff {
822                        Some(chain) if !chain.trim().is_empty() => format!("{real}, {chain}"),
823                        _ => real,
824                    }
825                } else {
826                    // Untrusted peer: preserve existing behavior — append the
827                    // peer IP to any existing chain.
828                    match existing_xff {
829                        Some(chain) => format!("{}, {}", chain, addr.ip()),
830                        None => addr.ip().to_string(),
831                    }
832                };
833
834                if let Ok(value) = new_value.parse() {
835                    parts.headers.insert("x-forwarded-for", value);
836                }
837            }
838        }
839
840        // X-Forwarded-Proto
841        if config.x_forwarded_proto && parts.headers.get("x-forwarded-proto").is_none() {
842            let proto = if self.is_tls { "https" } else { "http" };
843            if let Ok(value) = proto.parse() {
844                parts.headers.insert("x-forwarded-proto", value);
845            }
846        }
847
848        // X-Forwarded-Host
849        if config.x_forwarded_host {
850            if let Some(host) = parts.headers.get(header::HOST).cloned() {
851                if parts.headers.get("x-forwarded-host").is_none() {
852                    parts.headers.insert("x-forwarded-host", host);
853                }
854            }
855        }
856
857        // X-Real-IP — set to the effective client IP only if the header is
858        // currently absent (conservative: do not overwrite a value set by an
859        // upstream component).
860        if config.x_real_ip {
861            if let Some(ip) = effective_client_ip {
862                if parts.headers.get("x-real-ip").is_none() {
863                    if let Ok(value) = ip.to_string().parse() {
864                        parts.headers.insert("x-real-ip", value);
865                    }
866                }
867            }
868        }
869
870        // Via header
871        if config.via {
872            let proto_version = match parts.version {
873                Version::HTTP_09 => "0.9",
874                Version::HTTP_10 => "1.0",
875                Version::HTTP_2 => "2.0",
876                Version::HTTP_3 => "3.0",
877                _ => "1.1",
878            };
879
880            let via_value = format!("{} {}", proto_version, config.server_name);
881            let existing = parts
882                .headers
883                .get(header::VIA)
884                .and_then(|h| h.to_str().ok())
885                .map(|s| format!("{s}, {via_value}"))
886                .unwrap_or(via_value);
887
888            if let Ok(value) = existing.parse() {
889                parts.headers.insert(header::VIA, value);
890            }
891        }
892    }
893
894    fn remove_hop_by_hop_headers(parts: &mut http::request::Parts) {
895        // Standard hop-by-hop headers that should not be forwarded
896        const HOP_BY_HOP: &[&str] = &[
897            "connection",
898            "keep-alive",
899            "proxy-authenticate",
900            "proxy-authorization",
901            "te",
902            "trailer",
903            "transfer-encoding",
904            "upgrade",
905        ];
906
907        // First, collect headers listed in the Connection header before we remove it
908        let connection_headers: Vec<String> = parts
909            .headers
910            .get(header::CONNECTION)
911            .and_then(|h| h.to_str().ok())
912            .map(|value| value.split(',').map(|s| s.trim().to_lowercase()).collect())
913            .unwrap_or_default();
914
915        for header_name in HOP_BY_HOP {
916            parts.headers.remove(*header_name);
917        }
918
919        // Also remove headers that were listed in the Connection header
920        for header_name in connection_headers {
921            parts.headers.remove(header_name.as_str());
922        }
923    }
924
925    /// Build a client-facing error response with a **generic** body.
926    ///
927    /// This is the default-deny safety boundary for the ingress proxy. The
928    /// proxy binds `0.0.0.0:80`/`:443`, so it MUST NOT leak internal details
929    /// (the requested Host/path, a backend address, or the internal
930    /// load-balancer group name) to an unauthenticated caller. The full,
931    /// detailed [`ProxyError`] is logged by the caller (`error!(error = %e)`
932    /// in `server.rs`); the body returned here carries only a minimal,
933    /// status-appropriate phrase so an unmatched / no-target request gets a
934    /// clean deny rather than an internal echo.
935    ///
936    /// # Panics
937    ///
938    /// Panics if building a valid HTTP response with a plain-text body fails,
939    /// which should never occur with well-formed status codes.
940    pub fn error_response(error: &ProxyError) -> Response<BoxBody> {
941        let status = error.status_code();
942        // Generic, non-leaking body keyed purely off the status code. We
943        // deliberately do NOT interpolate `error` (which can contain the Host,
944        // path, backend address, or LB group name) into the client-visible
945        // body.
946        let body = status.canonical_reason().map_or_else(
947            || status.as_str().to_string(),
948            |reason| format!("{} {reason}", status.as_u16()),
949        );
950
951        Response::builder()
952            .status(status)
953            .header(header::CONTENT_TYPE, "text/plain; charset=utf-8")
954            .body(full_body(body))
955            .unwrap()
956    }
957}
958
959impl Service<Request<Incoming>> for ReverseProxyService {
960    type Response = Response<BoxBody>;
961    type Error = ProxyError;
962    type Future = std::pin::Pin<
963        Box<
964            dyn std::future::Future<Output = std::result::Result<Self::Response, Self::Error>>
965                + Send,
966        >,
967    >;
968
969    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
970        Poll::Ready(Ok(()))
971    }
972
973    fn call(&mut self, req: Request<Incoming>) -> Self::Future {
974        let this = self.clone();
975        Box::pin(async move { this.proxy_request(req).await })
976    }
977}
978
#[cfg(test)]
mod tests {
    use super::*;

    /// Drain a (small, in-memory) response body into a `String`.
    async fn body_text(response: Response<BoxBody>) -> String {
        let bytes = response
            .into_body()
            .collect()
            .await
            .expect("error response body should collect without error")
            .to_bytes();
        String::from_utf8(bytes.to_vec()).expect("error response body should be UTF-8")
    }

    #[tokio::test]
    async fn test_error_response() {
        let error = ProxyError::RouteNotFound {
            host: "example.com".to_string(),
            path: "/api".to_string(),
        };

        let response = ReverseProxyService::error_response(&error);
        assert_eq!(response.status(), http::StatusCode::NOT_FOUND);
        assert_eq!(
            response.headers().get(header::CONTENT_TYPE).unwrap(),
            "text/plain; charset=utf-8"
        );

        // The body is a safety boundary: it must be generic and must not echo
        // the requested Host or path back to the client.
        let body = body_text(response).await;
        assert_eq!(body, "404 Not Found");
        assert!(!body.contains("example.com"), "body leaked host: {body}");
        assert!(!body.contains("/api"), "body leaked path: {body}");
    }

    #[test]
    fn test_hop_by_hop_headers() {
        let mut parts = http::request::Builder::new()
            .method("GET")
            .uri("/test")
            .header("connection", "keep-alive, x-custom")
            .header("keep-alive", "timeout=5")
            .header("x-custom", "value")
            .header("x-other", "value")
            .body(())
            .unwrap()
            .into_parts()
            .0;

        ReverseProxyService::remove_hop_by_hop_headers(&mut parts);

        assert!(parts.headers.get("connection").is_none());
        assert!(parts.headers.get("keep-alive").is_none());
        assert!(parts.headers.get("x-custom").is_none());
        // x-other should remain
        assert!(parts.headers.get("x-other").is_some());
    }

    #[test]
    fn test_is_overlay_ip_accepts_overlay_range() {
        // 10.200.x.x should be recognized as overlay
        assert!(is_overlay_ip("10.200.0.1".parse().unwrap()));
        assert!(is_overlay_ip("10.200.255.254".parse().unwrap()));
        assert!(is_overlay_ip("10.200.1.100".parse().unwrap()));
    }

    #[test]
    fn test_is_overlay_ip_rejects_non_overlay() {
        // Non-overlay addresses
        assert!(!is_overlay_ip("192.168.1.1".parse().unwrap()));
        assert!(!is_overlay_ip("10.0.0.1".parse().unwrap()));
        assert!(!is_overlay_ip("10.201.0.1".parse().unwrap()));
        assert!(!is_overlay_ip("172.16.0.1".parse().unwrap()));
        assert!(!is_overlay_ip("8.8.8.8".parse().unwrap()));
    }

    #[test]
    fn test_is_overlay_ip_rejects_ipv6() {
        assert!(!is_overlay_ip("::1".parse().unwrap()));
        assert!(!is_overlay_ip("fe80::1".parse().unwrap()));
    }

    #[tokio::test]
    async fn test_forbidden_error_response() {
        let error = ProxyError::Forbidden("endpoint 'ws' is internal-only".to_string());
        let response = ReverseProxyService::error_response(&error);
        assert_eq!(response.status(), http::StatusCode::FORBIDDEN);

        // The internal reason must not reach the client.
        let body = body_text(response).await;
        assert_eq!(body, "403 Forbidden");
        assert!(!body.contains("internal-only"), "body leaked detail: {body}");
        assert!(!body.contains("'ws'"), "body leaked endpoint name: {body}");
    }

    // --- Tests for CF-Connecting-IP / X-Forwarded-For trust handling ------

    use crate::trust::TrustedProxyList;

    fn build_svc(peer: SocketAddr, trusted: TrustedProxyList) -> ReverseProxyService {
        let registry = Arc::new(ServiceRegistry::new());
        let load_balancer = Arc::new(LoadBalancer::new());
        let config = Arc::new(ProxyConfig::default());
        ReverseProxyService::new(registry, load_balancer, config)
            .with_remote_addr(peer)
            .with_trusted_proxies(Arc::new(trusted))
    }

    fn parts_with_headers(headers: &[(&str, &str)]) -> http::request::Parts {
        let mut builder = http::request::Builder::new().method("GET").uri("/");
        for (k, v) in headers {
            builder = builder.header(*k, *v);
        }
        builder.body(()).unwrap().into_parts().0
    }

    #[test]
    fn trusted_peer_cf_connecting_ip_is_honored() {
        // Peer 203.0.113.50 is inside the trusted /24. Its CF-Connecting-IP
        // should become X-Real-IP and be prepended to X-Forwarded-For.
        let peer: SocketAddr = "203.0.113.50:443".parse().unwrap();
        let trusted = TrustedProxyList::new(vec!["203.0.113.0/24".parse().unwrap()], None);
        let svc = build_svc(peer, trusted);

        let mut parts = parts_with_headers(&[("cf-connecting-ip", "198.51.100.7")]);
        svc.add_forwarding_headers(&mut parts);

        assert_eq!(parts.headers.get("x-real-ip").unwrap(), "198.51.100.7");
        let xff = parts
            .headers
            .get("x-forwarded-for")
            .unwrap()
            .to_str()
            .unwrap();
        assert!(
            xff.starts_with("198.51.100.7"),
            "XFF should start with real client IP, got {xff}"
        );
    }

    #[test]
    fn trusted_peer_xff_leftmost_is_honored_when_no_cf_header() {
        // Peer is trusted; no CF header but XFF chain is present. The leftmost
        // XFF entry is treated as the real client IP.
        let peer: SocketAddr = "203.0.113.50:443".parse().unwrap();
        let trusted = TrustedProxyList::new(vec!["203.0.113.0/24".parse().unwrap()], None);
        let svc = build_svc(peer, trusted);

        let mut parts = parts_with_headers(&[("x-forwarded-for", "198.51.100.9, 10.0.0.1")]);
        svc.add_forwarding_headers(&mut parts);

        assert_eq!(parts.headers.get("x-real-ip").unwrap(), "198.51.100.9");
        let xff = parts
            .headers
            .get("x-forwarded-for")
            .unwrap()
            .to_str()
            .unwrap();
        // Real client prepended, original chain preserved after.
        assert!(
            xff.starts_with("198.51.100.9"),
            "XFF should start with leftmost real client, got {xff}"
        );
        assert!(
            xff.contains("10.0.0.1"),
            "original chain should survive: {xff}"
        );
    }

    #[test]
    fn untrusted_peer_cf_connecting_ip_is_ignored() {
        // Peer 8.8.8.8 is NOT in the trusted list. The CF header must be
        // ignored and X-Real-IP must reflect the TCP peer.
        let peer: SocketAddr = "8.8.8.8:443".parse().unwrap();
        let trusted = TrustedProxyList::new(vec!["203.0.113.0/24".parse().unwrap()], None);
        let svc = build_svc(peer, trusted);

        let mut parts = parts_with_headers(&[("cf-connecting-ip", "198.51.100.7")]);
        svc.add_forwarding_headers(&mut parts);

        assert_eq!(parts.headers.get("x-real-ip").unwrap(), "8.8.8.8");
        let xff = parts
            .headers
            .get("x-forwarded-for")
            .unwrap()
            .to_str()
            .unwrap();
        // Untrusted peer: XFF should end with the peer IP (append behavior).
        assert!(
            xff.ends_with("8.8.8.8"),
            "XFF for untrusted peer should end with peer IP, got {xff}"
        );
    }

    #[test]
    fn no_headers_uses_peer_ip() {
        // No CF, no XFF. Any peer (trusted or not) should yield X-Real-IP ==
        // peer IP.
        let peer: SocketAddr = "198.51.100.250:443".parse().unwrap();
        let trusted = TrustedProxyList::localhost_only();
        let svc = build_svc(peer, trusted);

        let mut parts = parts_with_headers(&[]);
        svc.add_forwarding_headers(&mut parts);

        assert_eq!(parts.headers.get("x-real-ip").unwrap(), "198.51.100.250");
        assert_eq!(
            parts.headers.get("x-forwarded-for").unwrap(),
            "198.51.100.250"
        );
    }

    // --- RpsRegistry --------------------------------------------------------

    #[tokio::test]
    async fn rps_registry_counts_recorded_requests() {
        let reg = RpsRegistry::new();
        // Unknown service is 0.
        assert!((reg.rps("svc").await - 0.0).abs() < f64::EPSILON);

        // Record N requests; rps == N / window_secs.
        let n = 30;
        for _ in 0..n {
            reg.record("svc").await;
        }
        let expected = f64::from(n) / RPS_WINDOW.as_secs_f64();
        let got = reg.rps("svc").await;
        assert!(
            (got - expected).abs() < 1e-9,
            "expected {expected}, got {got}"
        );
    }

    #[tokio::test]
    async fn rps_registry_isolates_services() {
        let reg = RpsRegistry::new();
        reg.record("a").await;
        reg.record("a").await;
        reg.record("b").await;

        let snap = reg.snapshot().await;
        let a = snap.get("a").copied().unwrap_or_default();
        let b = snap.get("b").copied().unwrap_or_default();
        assert!(
            a > b,
            "service a ({a}) should have a higher rate than b ({b})"
        );
        // b recorded exactly once.
        assert!((b - 1.0 / RPS_WINDOW.as_secs_f64()).abs() < 1e-9);
    }

    #[tokio::test]
    async fn rps_registry_prunes_old_timestamps() {
        let reg = RpsRegistry::new();
        // Inject a timestamp well outside the window directly so the test does
        // not have to sleep for the full window.
        {
            let mut map = reg.services.lock().await;
            let ring = map.entry("svc".to_string()).or_default();
            let stale = Instant::now()
                .checked_sub(RPS_WINDOW + Duration::from_secs(5))
                .expect("instant underflow in test");
            ring.push_back(stale);
        }
        // The stale entry must be pruned on read.
        assert!((reg.rps("svc").await - 0.0).abs() < f64::EPSILON);
    }

    // --- Activator ----------------------------------------------------------

    use crate::lb::LbStrategy;

    /// Activator that flips a flag and registers a healthy backend on the
    /// shared load balancer, simulating a scale-from-zero wake-up.
    struct TestActivator {
        lb: Arc<LoadBalancer>,
        called: std::sync::atomic::AtomicBool,
    }

    #[async_trait::async_trait]
    impl Activator for TestActivator {
        async fn activate(&self, service: &str) -> std::result::Result<(), String> {
            self.called.store(true, std::sync::atomic::Ordering::SeqCst);
            // "Scale up": give the group a backend so the re-poll succeeds.
            self.lb.register(
                service,
                vec!["127.0.0.1:9".parse().unwrap()],
                LbStrategy::RoundRobin,
            );
            Ok(())
        }
    }

    fn build_svc_with_lb(lb: Arc<LoadBalancer>) -> ReverseProxyService {
        let registry = Arc::new(ServiceRegistry::new());
        let config = Arc::new(ProxyConfig::default());
        ReverseProxyService::new(registry, lb, config)
    }

    #[tokio::test]
    async fn select_or_activate_returns_none_without_activator() {
        let lb = Arc::new(LoadBalancer::new());
        let svc = build_svc_with_lb(Arc::clone(&lb));
        // No backends, no activator: immediate None (preserves 503 path).
        assert!(svc.select_or_activate("svc").await.is_none());
    }

    #[tokio::test]
    async fn select_or_activate_wakes_scaled_to_zero_service() {
        let lb = Arc::new(LoadBalancer::new());
        // Group exists but has no healthy backend (scale-to-zero idle).
        lb.register("svc", vec![], LbStrategy::RoundRobin);

        let activator = Arc::new(TestActivator {
            lb: Arc::clone(&lb),
            called: std::sync::atomic::AtomicBool::new(false),
        });
        let svc = build_svc_with_lb(Arc::clone(&lb)).with_activator(activator.clone());

        let backend = svc.select_or_activate("svc").await;
        assert!(
            backend.is_some(),
            "activator should have produced a backend"
        );
        assert!(
            activator.called.load(std::sync::atomic::Ordering::SeqCst),
            "activator must have been invoked"
        );
    }

    #[tokio::test]
    async fn select_or_activate_returns_existing_backend_without_calling_activator() {
        let lb = Arc::new(LoadBalancer::new());
        lb.register(
            "svc",
            vec!["127.0.0.1:9".parse().unwrap()],
            LbStrategy::RoundRobin,
        );
        let activator = Arc::new(TestActivator {
            lb: Arc::clone(&lb),
            called: std::sync::atomic::AtomicBool::new(false),
        });
        let svc = build_svc_with_lb(Arc::clone(&lb)).with_activator(activator.clone());

        assert!(svc.select_or_activate("svc").await.is_some());
        assert!(
            !activator.called.load(std::sync::atomic::Ordering::SeqCst),
            "activator must NOT be called when a backend already exists"
        );
    }
}