Skip to main content

rust_web_server/proxy_config/
mod.rs

1//! Config-driven proxy application.
2//!
3//! When `rws.config.toml` contains `[[route]]` or `[[upstream]]` sections,
4//! `ConfigDrivenApp` is used as the top-level `Application` instead of the
5//! hardcoded `build_app()` in `main.rs`.
6//!
7//! # Quick start
8//!
9//! ```toml
10//! # rws.config.toml
11//! [[upstream]]
12//! name = "api"
13//! backends = ["localhost:3000"]
14//!
15//! [[route]]
16//! name = "api-proxy"
17//!
18//! [route.match]
19//! path = "/api/*"
20//!
21//! [route.action]
22//! type = "proxy"
23//!
24//! [route.action.proxy]
25//! upstream = "api"
26//! ```
27
28pub mod parser;
29pub mod health;
30pub mod builder;
31
32#[cfg(test)]
33mod tests;
34
35use std::sync::Arc;
36
37use crate::app::App;
38use crate::application::Application;
39use crate::core::New;
40use crate::request::Request;
41use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
42use crate::server::ConnectionInfo;
43use crate::server_config::ServerConfig;
44
45// ── Public config types ────────────────────────────────────────────────────────
46
/// Everything `ProxyConfig::from_str` extracts from the proxy config file.
#[derive(Debug, Clone)]
pub struct ProxyConfig {
    /// `[[upstream]]` tables, collected as `upstream[0]`, `upstream[1]`, …
    pub upstreams: Vec<UpstreamConfig>,
    /// `[[route]]` tables, collected as `route[0]`, `route[1]`, …
    pub routes: Vec<RouteConfig>,
    /// `[[tcp_proxy]]` tables.
    pub tcp_proxies: Vec<TcpProxyConfig>,
    /// `[[udp_proxy]]` tables.
    pub udp_proxies: Vec<UdpProxyConfig>,
    /// `[[ws_proxy]]` tables.
    pub ws_proxies: Vec<WsProxyConfig>,
    /// The top-level `[middleware]` table (not tied to any route).
    pub global_middleware: MiddlewareConfig,
}
56
/// One `[[upstream]]` table: a named pool of backends that proxy routes
/// refer to via `upstream = "<name>"`.
#[derive(Debug, Clone)]
pub struct UpstreamConfig {
    pub name: String,
    /// Backend addresses as written in the config, e.g. `"localhost:3000"`.
    pub backends: Vec<String>,
    pub strategy: String, // "round_robin" (default) | "random" | "ip_hash" | "least_connections"; unknown values act as round_robin
    /// Present only when an `[upstream.health_check]` sub-table exists.
    pub health_check: Option<HealthCheckConfig>,
    /// `true` when ANY backend uses the `https://` scheme (computed by
    /// `ProxyConfig::from_str`) — connections to the upstream are then made
    /// over TLS. Requires the `http-client` or `http2` feature (which bring
    /// in `rustls` + `webpki-roots`).
    /// NOTE(review): this is a single flag for the whole upstream, so mixing
    /// `http://` and `https://` backends in one upstream presumably does not
    /// work — confirm.
    pub tls: bool,
}
68
/// `[upstream.health_check]` settings. Defaults applied by
/// `ProxyConfig::from_str` when a key is absent are noted per field.
#[derive(Debug, Clone)]
pub struct HealthCheckConfig {
    /// Probe path; default `"/health"`.
    pub path: String,
    /// Seconds between probes; default 30.
    pub interval_secs: u64,
    /// Per-probe timeout in milliseconds; default 5000.
    pub timeout_ms: u64,
    /// Default 2. Presumably consecutive successful probes before a backend
    /// is restored — TODO confirm against `health.rs`.
    pub healthy_threshold: u32,
    /// Default 3. Presumably consecutive failed probes before a backend is
    /// removed — TODO confirm against `health.rs`.
    pub unhealthy_threshold: u32,
}
77
/// One `[[route]]` table: what to match, what to do, and which middleware
/// to wrap the action in.
#[derive(Debug, Clone)]
pub struct RouteConfig {
    pub name: String,
    /// `[route.match]` (trailing underscore because `match` is a keyword).
    pub match_: MatchConfig,
    /// `[route.action]`, selected by its `type` key.
    pub action: ActionConfig,
    /// `[route.middleware]`.
    pub middleware: MiddlewareConfig,
}
85
/// `[route.match]` criteria. Each field is `None` when the key is absent or
/// empty; a `None` criterion matches everything.
#[derive(Debug, Clone, Default)]
pub struct MatchConfig {
    /// Expected host (compared against SNI or the `Host` header).
    pub host: Option<String>,
    /// Exact path, or a prefix when it ends in `*` (e.g. `"/api/*"`).
    pub path: Option<String>,
    /// HTTP method; `ProxyConfig::from_str` stores it uppercased.
    pub method: Option<String>,
    /// `Content-Type` prefix; a trailing `*` is allowed and stripped.
    pub content_type: Option<String>,
}
93
/// `[route.action]`, selected by its `type` key. Each variant's settings
/// are read from `[route.action.<type>]`; defaults applied by
/// `ProxyConfig::from_str` are noted per variant.
#[derive(Debug, Clone)]
pub enum ActionConfig {
    /// `type = "proxy"`: forward to a named upstream.
    /// Defaults: `connect_timeout_ms` 5000, `read_timeout_ms` 30000.
    Proxy {
        upstream: String,
        connect_timeout_ms: u64,
        read_timeout_ms: u64,
        strip_path_prefix: Option<String>,
        add_path_prefix: Option<String>,
    },
    /// `type = "grpc"`: forward to a named upstream.
    /// Defaults: `connect_timeout_ms` 5000, `read_timeout_ms` 30000.
    Grpc {
        upstream: String,
        connect_timeout_ms: u64,
        read_timeout_ms: u64,
    },
    /// `type = "static"`: serve files under `root`.
    Static {
        root: String,
        index: Vec<String>,
    },
    /// `type = "redirect"`: default `status` 301.
    Redirect {
        location: String,
        status: u16,
    },
    /// `type = "respond"`: fixed response. Defaults: `status` 200,
    /// `content_type` `"text/plain"`.
    Respond {
        status: u16,
        body: String,
        content_type: String,
    },
    /// `type = "mcp"` (no settings table is read).
    Mcp,
    /// Any other `type` value, including an empty/missing one, kept verbatim.
    Unknown(String),
}
124
/// Middleware settings for a route (`[route.middleware]`) or for the global
/// `[middleware]` table. Absent sub-tables leave fields `None` / empty.
#[derive(Debug, Clone, Default)]
pub struct MiddlewareConfig {
    /// `[….middleware.rate_limit]`.
    pub rate_limit: Option<RateLimitConfig>,
    /// `[….middleware.cache]`.
    pub cache: Option<CacheConfig>,
    /// `[….middleware.auth]`; `None` also when its `type` is unrecognised.
    pub auth: Option<AuthConfig>,
    /// `[[….middleware.rewrite.request]]` entries, in index order.
    pub rewrite_request: Vec<RewriteRuleConfig>,
    /// `[[….middleware.rewrite.response]]` entries, in index order.
    pub rewrite_response: Vec<RewriteRuleConfig>,
    /// `allow` array of the `[….middleware.ip_filter]` sub-table.
    pub ip_allow: Vec<String>,
    /// `deny` array of the `[….middleware.ip_filter]` sub-table.
    pub ip_deny: Vec<String>,
    /// `timeout_ms` — if the route (including all its other middleware)
    /// doesn't produce a response within this many milliseconds, the client
    /// gets `504 Gateway Timeout` instead of waiting further. See
    /// `crate::timeout` for the underlying mechanism and its limitations.
    pub timeout_ms: Option<u64>,
}
140
/// `[….middleware.rate_limit]`. Defaults applied by the parser:
/// `max_requests` 1000, `window_secs` 60.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    pub max_requests: u32,
    pub window_secs: u64,
}
146
/// `[….middleware.cache]`. Default `ttl_secs` is 60.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    pub ttl_secs: u64,
    /// Presumably request header names that become part of the cache key —
    /// TODO confirm against the cache middleware.
    pub vary_by: Vec<String>,
}
152
/// `[….middleware.auth]`, selected by its `type` key (`"basic"`, `"jwt"`
/// or `"bearer"`; exact, case-sensitive match).
#[derive(Debug, Clone)]
pub enum AuthConfig {
    /// `auth = { type = "basic", htpasswd_file = ".htpasswd" }`. See
    /// `crate::auth::BasicAuthLayer::from_htpasswd_file` for the supported
    /// file format (plain text and `{SHA256}` only — not Apache's `{SHA}`,
    /// `$apr1$`, or bcrypt).
    Basic { htpasswd_file: String },
    /// `auth = { type = "jwt", secret_env = "JWT_SECRET" }`. Requires the
    /// `auth` feature; verifies HS256 JWTs via `crate::auth::JwtLayer`.
    Jwt { secret_env: String },
    /// `auth = { type = "bearer", token_env = "…" }`. Presumably names the
    /// environment variable holding the expected token — TODO confirm.
    Bearer { token_env: String },
}
165
/// One `[[….middleware.rewrite.request]]` / `…response` entry. Which
/// optional fields matter depends on `type_`. Empty strings and a `code`
/// of 0 are stored as `None`.
#[derive(Debug, Clone, Default)]
pub struct RewriteRuleConfig {
    /// The rule's `type` key (trailing underscore because `type` is a keyword).
    pub type_: String,
    pub name: Option<String>,
    pub value: Option<String>,
    pub prefix: Option<String>,
    pub from: Option<String>,
    pub to: Option<String>,
    pub code: Option<u16>,
    pub reason: Option<String>,
}
177
/// One `[[tcp_proxy]]` table. Default `connect_timeout_ms` is 5000.
#[derive(Debug, Clone)]
pub struct TcpProxyConfig {
    pub name: String,
    /// Listen address, as written in the config.
    pub listen: String,
    pub backends: Vec<String>,
    pub connect_timeout_ms: u64,
}
185
/// One `[[udp_proxy]]` table. Defaults: `reply_timeout_ms` 5000,
/// `buffer_size` 65536.
#[derive(Debug, Clone)]
pub struct UdpProxyConfig {
    pub name: String,
    /// Listen address, as written in the config.
    pub listen: String,
    pub backends: Vec<String>,
    pub reply_timeout_ms: u64,
    /// Presumably the datagram buffer size in bytes — TODO confirm.
    pub buffer_size: usize,
}
194
/// One `[[ws_proxy]]` table. Defaults: `connect_timeout_ms` 5000,
/// `read_timeout_ms` 30000.
#[derive(Debug, Clone)]
pub struct WsProxyConfig {
    pub name: String,
    /// Listen address, as written in the config.
    pub listen: String,
    pub backends: Vec<String>,
    pub connect_timeout_ms: u64,
    pub read_timeout_ms: u64,
}
203
204// ── ProxyConfig loading ────────────────────────────────────────────────────────
205
206impl ProxyConfig {
207    /// Returns `true` if `rws.config.toml` (or `RWS_CONFIG_FILE`) contains
208    /// `[[route]]` or `[[upstream]]` sections, meaning config-driven mode
209    /// should be used.
210    pub fn is_proxy_mode() -> bool {
211        let path = config_file_path();
212        match std::fs::read_to_string(&path) {
213            Ok(contents) => {
214                contents.contains("[[route]]") || contents.contains("[[upstream]]")
215            }
216            Err(_) => false,
217        }
218    }
219
220    /// Parse the config file and return a `ProxyConfig`.
221    pub fn load() -> Self {
222        let path = config_file_path();
223        let contents = std::fs::read_to_string(&path).unwrap_or_default();
224        Self::from_str(&contents)
225    }
226
227    /// Parse `toml` text directly into a `ProxyConfig`. Used in tests.
228    pub fn from_str(toml: &str) -> Self {
229        use parser::{get_array, get_str, get_u32, get_u64, section_exists};
230
231        let map = parser::parse(toml);
232
233        // ── upstreams ──────────────────────────────────────────────────────────
234        let mut upstreams = Vec::new();
235        let mut i = 0;
236        loop {
237            let sec = format!("upstream[{}]", i);
238            if !section_exists(&map, &sec) {
239                break;
240            }
241            let name = get_str(&map, &sec, "name");
242            let backends = get_array(&map, &sec, "backends");
243            let strategy = {
244                let s = get_str(&map, &sec, "strategy");
245                if s.is_empty() { "round_robin".to_string() } else { s }
246            };
247            let hc_sec = format!("{}.health_check", sec);
248            let health_check = if section_exists(&map, &hc_sec) {
249                Some(HealthCheckConfig {
250                    path: {
251                        let p = get_str(&map, &hc_sec, "path");
252                        if p.is_empty() { "/health".to_string() } else { p }
253                    },
254                    interval_secs: get_u64(&map, &hc_sec, "interval_secs", 30),
255                    timeout_ms: get_u64(&map, &hc_sec, "timeout_ms", 5000),
256                    healthy_threshold: get_u32(&map, &hc_sec, "healthy_threshold", 2),
257                    unhealthy_threshold: get_u32(&map, &hc_sec, "unhealthy_threshold", 3),
258                })
259            } else {
260                None
261            };
262            let tls = backends.iter().any(|b| b.starts_with("https://"));
263            upstreams.push(UpstreamConfig { name, backends, strategy, health_check, tls });
264            i += 1;
265        }
266
267        // ── routes ─────────────────────────────────────────────────────────────
268        let mut routes = Vec::new();
269        let mut i = 0;
270        loop {
271            let sec = format!("route[{}]", i);
272            if !section_exists(&map, &sec) {
273                break;
274            }
275            let name = get_str(&map, &sec, "name");
276
277            // match
278            let m_sec = format!("{}.match", sec);
279            let match_ = MatchConfig {
280                host: {
281                    let h = get_str(&map, &m_sec, "host");
282                    if h.is_empty() { None } else { Some(h) }
283                },
284                path: {
285                    let p = get_str(&map, &m_sec, "path");
286                    if p.is_empty() { None } else { Some(p) }
287                },
288                method: {
289                    let m = get_str(&map, &m_sec, "method");
290                    if m.is_empty() { None } else { Some(m.to_uppercase()) }
291                },
292                content_type: {
293                    let c = get_str(&map, &m_sec, "content_type");
294                    if c.is_empty() { None } else { Some(c) }
295                },
296            };
297
298            // action
299            let a_sec = format!("{}.action", sec);
300            let action_type = get_str(&map, &a_sec, "type");
301            let action = match action_type.as_str() {
302                "proxy" => {
303                    let p_sec = format!("{}.action.proxy", sec);
304                    ActionConfig::Proxy {
305                        upstream: get_str(&map, &p_sec, "upstream"),
306                        connect_timeout_ms: get_u64(&map, &p_sec, "connect_timeout_ms", 5000),
307                        read_timeout_ms: get_u64(&map, &p_sec, "read_timeout_ms", 30000),
308                        strip_path_prefix: {
309                            let v = get_str(&map, &p_sec, "strip_path_prefix");
310                            if v.is_empty() { None } else { Some(v) }
311                        },
312                        add_path_prefix: {
313                            let v = get_str(&map, &p_sec, "add_path_prefix");
314                            if v.is_empty() { None } else { Some(v) }
315                        },
316                    }
317                }
318                "grpc" => {
319                    let p_sec = format!("{}.action.grpc", sec);
320                    ActionConfig::Grpc {
321                        upstream: get_str(&map, &p_sec, "upstream"),
322                        connect_timeout_ms: get_u64(&map, &p_sec, "connect_timeout_ms", 5000),
323                        read_timeout_ms: get_u64(&map, &p_sec, "read_timeout_ms", 30000),
324                    }
325                }
326                "static" => {
327                    let s_sec = format!("{}.action.static", sec);
328                    ActionConfig::Static {
329                        root: get_str(&map, &s_sec, "root"),
330                        index: get_array(&map, &s_sec, "index"),
331                    }
332                }
333                "redirect" => {
334                    let r_sec = format!("{}.action.redirect", sec);
335                    ActionConfig::Redirect {
336                        location: get_str(&map, &r_sec, "location"),
337                        status: get_u64(&map, &r_sec, "status", 301) as u16,
338                    }
339                }
340                "respond" => {
341                    let r_sec = format!("{}.action.respond", sec);
342                    ActionConfig::Respond {
343                        status: get_u64(&map, &r_sec, "status", 200) as u16,
344                        body: get_str(&map, &r_sec, "body"),
345                        content_type: {
346                            let ct = get_str(&map, &r_sec, "content_type");
347                            if ct.is_empty() { "text/plain".to_string() } else { ct }
348                        },
349                    }
350                }
351                "mcp" => ActionConfig::Mcp,
352                other => ActionConfig::Unknown(other.to_string()),
353            };
354
355            // middleware
356            let mw_sec = format!("{}.middleware", sec);
357            let middleware = parse_middleware_config(&map, &mw_sec, i);
358
359            routes.push(RouteConfig { name, match_, action, middleware });
360            i += 1;
361        }
362
363        // ── tcp_proxy ──────────────────────────────────────────────────────────
364        let mut tcp_proxies = Vec::new();
365        let mut i = 0;
366        loop {
367            let sec = format!("tcp_proxy[{}]", i);
368            if !section_exists(&map, &sec) {
369                break;
370            }
371            tcp_proxies.push(TcpProxyConfig {
372                name: get_str(&map, &sec, "name"),
373                listen: get_str(&map, &sec, "listen"),
374                backends: get_array(&map, &sec, "backends"),
375                connect_timeout_ms: get_u64(&map, &sec, "connect_timeout_ms", 5000),
376            });
377            i += 1;
378        }
379
380        // ── udp_proxy ──────────────────────────────────────────────────────────
381        let mut udp_proxies = Vec::new();
382        let mut i = 0;
383        loop {
384            let sec = format!("udp_proxy[{}]", i);
385            if !section_exists(&map, &sec) {
386                break;
387            }
388            udp_proxies.push(UdpProxyConfig {
389                name: get_str(&map, &sec, "name"),
390                listen: get_str(&map, &sec, "listen"),
391                backends: get_array(&map, &sec, "backends"),
392                reply_timeout_ms: get_u64(&map, &sec, "reply_timeout_ms", 5000),
393                buffer_size: get_u64(&map, &sec, "buffer_size", 65536) as usize,
394            });
395            i += 1;
396        }
397
398        // ── ws_proxy ───────────────────────────────────────────────────────────
399        let mut ws_proxies = Vec::new();
400        let mut i = 0;
401        loop {
402            let sec = format!("ws_proxy[{}]", i);
403            if !section_exists(&map, &sec) {
404                break;
405            }
406            ws_proxies.push(WsProxyConfig {
407                name: get_str(&map, &sec, "name"),
408                listen: get_str(&map, &sec, "listen"),
409                backends: get_array(&map, &sec, "backends"),
410                connect_timeout_ms: get_u64(&map, &sec, "connect_timeout_ms", 5000),
411                read_timeout_ms: get_u64(&map, &sec, "read_timeout_ms", 30000),
412            });
413            i += 1;
414        }
415
416        // ── global middleware ──────────────────────────────────────────────────
417        let global_middleware = parse_middleware_config(&map, "middleware", usize::MAX);
418
419        ProxyConfig {
420            upstreams,
421            routes,
422            tcp_proxies,
423            udp_proxies,
424            ws_proxies,
425            global_middleware,
426        }
427    }
428}
429
430/// Parse a `MiddlewareConfig` from the section map at a given base path.
431/// `route_idx` is used only to build inner-array section paths for rewrite rules.
432fn parse_middleware_config(
433    map: &parser::SectionMap,
434    mw_sec: &str,
435    route_idx: usize,
436) -> MiddlewareConfig {
437    use parser::{get_array, get_str, get_u32, get_u64, section_exists};
438
439    let rl_sec = format!("{}.rate_limit", mw_sec);
440    let rate_limit = if section_exists(map, &rl_sec) {
441        Some(RateLimitConfig {
442            max_requests: get_u32(map, &rl_sec, "max_requests", 1000),
443            window_secs: get_u64(map, &rl_sec, "window_secs", 60),
444        })
445    } else {
446        None
447    };
448
449    let c_sec = format!("{}.cache", mw_sec);
450    let cache = if section_exists(map, &c_sec) {
451        Some(CacheConfig {
452            ttl_secs: get_u64(map, &c_sec, "ttl_secs", 60),
453            vary_by: get_array(map, &c_sec, "vary_by"),
454        })
455    } else {
456        None
457    };
458
459    let a_sec = format!("{}.auth", mw_sec);
460    let auth = if section_exists(map, &a_sec) {
461        let auth_type = get_str(map, &a_sec, "type");
462        match auth_type.as_str() {
463            "bearer" => Some(AuthConfig::Bearer {
464                token_env: get_str(map, &a_sec, "token_env"),
465            }),
466            "jwt" => Some(AuthConfig::Jwt {
467                secret_env: get_str(map, &a_sec, "secret_env"),
468            }),
469            "basic" => Some(AuthConfig::Basic {
470                htpasswd_file: get_str(map, &a_sec, "htpasswd_file"),
471            }),
472            _ => None,
473        }
474    } else {
475        None
476    };
477
478    // Rewrite rules — the section paths use route_idx for route-scoped rules
479    // or a flat path for global middleware. We look for:
480    //   route[N].middleware.rewrite.request[0], [1], …
481    //   route[N].middleware.rewrite.response[0], [1], …
482    // For global: middleware.rewrite.request[0], etc.
483    let rewrite_request = collect_rewrite_rules(map, mw_sec, "request");
484    let rewrite_response = collect_rewrite_rules(map, mw_sec, "response");
485
486    let ip_sec = format!("{}.ip_filter", mw_sec);
487    let ip_allow = if section_exists(map, &ip_sec) {
488        get_array(map, &ip_sec, "allow")
489    } else {
490        vec![]
491    };
492    let ip_deny = if section_exists(map, &ip_sec) {
493        get_array(map, &ip_sec, "deny")
494    } else {
495        vec![]
496    };
497
498    let _ = route_idx; // used implicitly via mw_sec paths
499
500    // Flat scalar directly under [route.middleware] (or the global
501    // [middleware] table), not a nested sub-table like rate_limit/cache —
502    // 0/absent both mean "no timeout configured".
503    let timeout_ms = match get_u64(map, mw_sec, "timeout_ms", 0) {
504        0 => None,
505        ms => Some(ms),
506    };
507
508    MiddlewareConfig { rate_limit, cache, auth, rewrite_request, rewrite_response, ip_allow, ip_deny, timeout_ms }
509}
510
511/// Collect `[[{mw_sec}.rewrite.{direction}]]` entries.
512fn collect_rewrite_rules(
513    map: &parser::SectionMap,
514    mw_sec: &str,
515    direction: &str,
516) -> Vec<RewriteRuleConfig> {
517    use parser::{get_str, get_u64};
518
519    let mut rules = Vec::new();
520    let mut j = 0;
521    loop {
522        let rsec = format!("{}.rewrite.{}[{}]", mw_sec, direction, j);
523        if !parser::section_exists(map, &rsec) {
524            break;
525        }
526        let code_val = get_u64(map, &rsec, "code", 0);
527        rules.push(RewriteRuleConfig {
528            type_: get_str(map, &rsec, "type"),
529            name: {
530                let v = get_str(map, &rsec, "name");
531                if v.is_empty() { None } else { Some(v) }
532            },
533            value: {
534                let v = get_str(map, &rsec, "value");
535                if v.is_empty() { None } else { Some(v) }
536            },
537            prefix: {
538                let v = get_str(map, &rsec, "prefix");
539                if v.is_empty() { None } else { Some(v) }
540            },
541            from: {
542                let v = get_str(map, &rsec, "from");
543                if v.is_empty() { None } else { Some(v) }
544            },
545            to: {
546                let v = get_str(map, &rsec, "to");
547                if v.is_empty() { None } else { Some(v) }
548            },
549            code: if code_val == 0 { None } else { Some(code_val as u16) },
550            reason: {
551                let v = get_str(map, &rsec, "reason");
552                if v.is_empty() { None } else { Some(v) }
553            },
554        });
555        j += 1;
556    }
557    rules
558}
559
/// Path of the proxy config file.
///
/// Returns the value of `RWS_CONFIG_FILE` when it is set and valid Unicode;
/// otherwise returns the relative path `rws.config.toml`.
fn config_file_path() -> String {
    match std::env::var("RWS_CONFIG_FILE") {
        Ok(path) => path,
        Err(_) => String::from("rws.config.toml"),
    }
}
563
564// ── ConfigDrivenApp ────────────────────────────────────────────────────────────
565
/// A compiled route: a matcher paired with a handler application.
/// `ConfigDrivenApp` dispatches a request to the first route whose
/// `matcher` accepts it.
pub(crate) struct CompiledRoute {
    pub(crate) matcher: RouteMatcher,
    /// Shared, type-erased handler. `Arc` makes `Clone` cheap (pointer copy).
    pub(crate) handler: Arc<dyn Application + Send + Sync>,
}
572
/// Matching criteria for a single route, built from a `MatchConfig` by
/// `RouteMatcher::from_match_config`. Every `None` criterion matches
/// anything; all `Some` criteria must match.
#[derive(Clone, Default)]
pub(crate) struct RouteMatcher {
    /// Optional SNI hostname / `Host` header match.
    pub(crate) host: Option<String>,
    /// Path prefix to match (derived from `path = "/v1/*"`). Takes
    /// precedence over `path_exact` when both are set.
    pub(crate) path_prefix: Option<String>,
    /// Exact path to match (derived from `path = "/v1/ping"`).
    pub(crate) path_exact: Option<String>,
    /// Uppercase HTTP method, or `None` for any.
    pub(crate) method: Option<String>,
    /// `Content-Type` prefix (e.g. `"application/grpc"`).
    pub(crate) content_type_prefix: Option<String>,
}
587
588impl RouteMatcher {
589    pub(crate) fn from_match_config(cfg: &MatchConfig) -> Self {
590        let (path_prefix, path_exact) = match &cfg.path {
591            Some(p) if p.ends_with('*') => {
592                // "/v1/*" → prefix "/v1/"
593                let stripped = p.trim_end_matches('*').to_string();
594                (Some(stripped), None)
595            }
596            Some(p) => (None, Some(p.clone())),
597            None => (None, None),
598        };
599        let content_type_prefix = cfg.content_type.as_ref().map(|ct| {
600            if ct.ends_with('*') {
601                ct.trim_end_matches('*').to_string()
602            } else {
603                ct.clone()
604            }
605        });
606        RouteMatcher {
607            host: cfg.host.clone(),
608            path_prefix,
609            path_exact,
610            method: cfg.method.clone(),
611            content_type_prefix,
612        }
613    }
614
615    /// Returns `true` if `request` and `conn` match all configured criteria.
616    pub(crate) fn matches(&self, request: &Request, conn: &ConnectionInfo) -> bool {
617        // Host matching: SNI first, then Host header
618        if let Some(ref expected_host) = self.host {
619            let actual_host = conn
620                .sni_hostname
621                .as_deref()
622                .or_else(|| {
623                    request
624                        .headers
625                        .iter()
626                        .find(|h| h.name.eq_ignore_ascii_case("host"))
627                        .map(|h| h.value.as_str())
628                })
629                .unwrap_or("");
630            if actual_host != expected_host.as_str() {
631                return false;
632            }
633        }
634
635        // Method matching
636        if let Some(ref m) = self.method {
637            if request.method.to_uppercase() != m.as_str() {
638                return false;
639            }
640        }
641
642        // Path matching: strip query string for comparison
643        let path = request.request_uri.split('?').next().unwrap_or(&request.request_uri);
644        if let Some(ref prefix) = self.path_prefix {
645            if !path.starts_with(prefix.as_str()) {
646                return false;
647            }
648        } else if let Some(ref exact) = self.path_exact {
649            if path != exact.as_str() {
650                return false;
651            }
652        }
653
654        // Content-Type prefix matching
655        if let Some(ref ct_prefix) = self.content_type_prefix {
656            let actual_ct = request
657                .headers
658                .iter()
659                .find(|h| h.name.eq_ignore_ascii_case("content-type"))
660                .map(|h| h.value.as_str())
661                .unwrap_or("");
662            if !actual_ct.starts_with(ct_prefix.as_str()) {
663                return false;
664            }
665        }
666
667        true
668    }
669}
670
/// An `Application` that routes requests based on a parsed `ProxyConfig`.
/// Routes are tried in order; the first match handles the request, and
/// unmatched requests go to `fallback`.
///
/// `Clone` is cheap: `routes` is an `Arc<Vec<...>>` (pointer copy), and
/// `fallback` is `App`, itself cheap to clone (an `Option<Arc<ServerConfig>>`).
#[derive(Clone)]
pub struct ConfigDrivenApp {
    /// Compiled routes, shared between clones.
    routes: Arc<Vec<CompiledRoute>>,
    /// Fallback for unmatched requests — handles /healthz, /readyz, /metrics,
    /// static files, and the 404 controller. Reads `RWS_CONFIG_*` env vars
    /// per request (`App::new()`'s default) unless pinned via
    /// [`ConfigDrivenApp::with_config`].
    fallback: App,
}
684
685impl ConfigDrivenApp {
686    pub(crate) fn new(routes: Vec<CompiledRoute>) -> Self {
687        use crate::core::New;
688        ConfigDrivenApp {
689            routes: Arc::new(routes),
690            fallback: App::new(),
691        }
692    }
693
694    /// Pin the fallback [`App`] (used for any request none of the
695    /// config-driven routes match) to an explicit [`ServerConfig`], instead
696    /// of reading `RWS_CONFIG_*` environment variables per request.
697    ///
698    /// Mirrors [`App::with_config`] / [`crate::state::AppWithState::with_config`]
699    /// — same rationale: safe for parallel tests, and lets multiple
700    /// differently-configured proxy instances coexist in one process.
701    ///
702    /// ```rust,no_run
703    /// use rust_web_server::proxy_config::build_from_file;
704    /// use rust_web_server::server_config::ServerConfig;
705    ///
706    /// let (app, _handles) = build_from_file();
707    /// let app = app.with_config(ServerConfig::default());
708    /// ```
709    pub fn with_config(mut self, config: ServerConfig) -> Self {
710        self.fallback = App::with_config(config);
711        self
712    }
713}
714
715impl Application for ConfigDrivenApp {
716    fn execute(&self, request: &Request, conn: &ConnectionInfo) -> Result<Response, String> {
717        for route in self.routes.iter() {
718            if route.matcher.matches(request, conn) {
719                return route.handler.execute(request, conn);
720            }
721        }
722        self.fallback.execute(request, conn)
723    }
724}
725
726// ── NullApp ────────────────────────────────────────────────────────────────────
727
728/// A dead-end `Application` that always returns 404.
729/// Used as the `next` parameter when calling `Middleware::handle` directly.
730#[derive(Clone, Copy)]
731pub(crate) struct NullApp;
732
733impl Application for NullApp {
734    fn execute(&self, _request: &Request, _conn: &ConnectionInfo) -> Result<Response, String> {
735        let mut r = Response::new();
736        r.status_code = *STATUS_CODE_REASON_PHRASE.n404_not_found.status_code;
737        r.reason_phrase = STATUS_CODE_REASON_PHRASE.n404_not_found.reason_phrase.to_string();
738        Ok(r)
739    }
740}
741
742// ── DynamicProxy ──────────────────────────────────────────────────────────────
743
744use std::collections::HashMap;
745use std::collections::hash_map::DefaultHasher;
746use std::hash::{Hash, Hasher};
747use std::sync::atomic::{AtomicUsize, Ordering};
748use std::sync::RwLock;
749use std::time::{Duration, SystemTime, UNIX_EPOCH};
750
/// Backend-selection strategy for `DynamicProxy`.
///
/// It is configured via the `strategy` field on `[[upstream]]` in
/// `rws.config.toml`. Unknown or empty values fall back to `RoundRobin`,
/// matching the parser's own default.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum LoadBalanceStrategy {
    RoundRobin,
    Random,
    IpHash,
    LeastConnections,
}

impl LoadBalanceStrategy {
    /// Map a config `strategy` string to a variant by exact, case-sensitive
    /// comparison. Anything not in the table yields `RoundRobin`, including
    /// `"round_robin"` itself and `""`.
    fn parse(s: &str) -> Self {
        const NAMED: [(&str, LoadBalanceStrategy); 3] = [
            ("random", LoadBalanceStrategy::Random),
            ("ip_hash", LoadBalanceStrategy::IpHash),
            ("least_connections", LoadBalanceStrategy::LeastConnections),
        ];
        NAMED
            .iter()
            .find(|(name, _)| *name == s)
            .map_or(LoadBalanceStrategy::RoundRobin, |&(_, strategy)| strategy)
    }
}
772
/// A proxy adapter that reads its backend list from a shared, health-checker-
/// maintained live list at request time. Supports dynamic removal/restoration
/// of backends without restarting.
///
/// This type is `Clone + Send + Sync` and implements `Application`. All
/// shared state sits behind `Arc`, so clones share the live list and the
/// counters.
#[derive(Clone)]
pub(crate) struct DynamicProxy {
    /// Currently-live backends, read under a read lock on every request.
    live: Arc<RwLock<Vec<String>>>,
    /// Round-robin cursor; also used as a salt by the `Random` strategy.
    counter: Arc<AtomicUsize>,
    connect_timeout: Duration,
    read_timeout: Duration,
    /// Prefix removed from the request URI before forwarding, if set.
    strip_prefix: Option<Arc<String>>,
    /// Prefix added to the request URI before forwarding, if set.
    add_prefix: Option<Arc<String>>,
    /// Presumably whether upstream connections use TLS (see
    /// `UpstreamConfig::tls`) — TODO confirm in `execute`.
    tls: bool,
    strategy: LoadBalanceStrategy,
    /// In-flight request count per backend; only maintained under
    /// `LeastConnections`.
    connections: Arc<RwLock<HashMap<String, Arc<AtomicUsize>>>>,
}
790
791impl DynamicProxy {
792    pub(crate) fn new(
793        live: Arc<RwLock<Vec<String>>>,
794        connect_timeout_ms: u64,
795        read_timeout_ms: u64,
796        strip_prefix: Option<String>,
797        add_prefix: Option<String>,
798        tls: bool,
799        strategy: String,
800    ) -> Self {
801        DynamicProxy {
802            live,
803            counter: Arc::new(AtomicUsize::new(0)),
804            connect_timeout: Duration::from_millis(connect_timeout_ms),
805            read_timeout: Duration::from_millis(read_timeout_ms),
806            strip_prefix: strip_prefix.map(Arc::new),
807            add_prefix: add_prefix.map(Arc::new),
808            tls,
809            strategy: LoadBalanceStrategy::parse(&strategy),
810            connections: Arc::new(RwLock::new(HashMap::new())),
811        }
812    }
813
814    fn next_backend(&self, client_ip: &str) -> Option<String> {
815        let live = self.live.read().unwrap();
816        if live.is_empty() {
817            return None;
818        }
819
820        let idx = match self.strategy {
821            LoadBalanceStrategy::RoundRobin => {
822                self.counter.fetch_add(1, Ordering::Relaxed) % live.len()
823            }
824            LoadBalanceStrategy::Random => {
825                let nanos = SystemTime::now()
826                    .duration_since(UNIX_EPOCH)
827                    .map(|d| d.subsec_nanos())
828                    .unwrap_or(0) as usize;
829                let salt = self.counter.fetch_add(1, Ordering::Relaxed);
830                nanos.wrapping_add(salt) % live.len()
831            }
832            LoadBalanceStrategy::IpHash => {
833                let mut hasher = DefaultHasher::new();
834                client_ip.hash(&mut hasher);
835                (hasher.finish() as usize) % live.len()
836            }
837            LoadBalanceStrategy::LeastConnections => {
838                let connections = self.connections.read().unwrap();
839                live.iter()
840                    .enumerate()
841                    .min_by_key(|(_, backend)| {
842                        connections
843                            .get(*backend)
844                            .map(|c| c.load(Ordering::Relaxed))
845                            .unwrap_or(0)
846                    })
847                    .map(|(i, _)| i)
848                    .unwrap_or(0)
849            }
850        };
851
852        Some(live[idx].clone())
853    }
854
855    /// Returns the shared in-flight connection counter for `backend`,
856    /// creating it on first use. Only consulted under `LeastConnections`.
857    fn connection_counter(&self, backend: &str) -> Arc<AtomicUsize> {
858        if let Some(counter) = self.connections.read().unwrap().get(backend) {
859            return Arc::clone(counter);
860        }
861        let mut connections = self.connections.write().unwrap();
862        Arc::clone(
863            connections
864                .entry(backend.to_string())
865                .or_insert_with(|| Arc::new(AtomicUsize::new(0))),
866        )
867    }
868}
869
/// RAII handle for one in-flight request against a backend.
///
/// Dropping the guard decrements the backend's shared in-flight counter. This
/// keeps the count accurate on every exit path of the request handler,
/// including early returns, which the `least_connections` strategy relies on.
struct ConnectionGuard {
    /// Shared per-backend counter. Whoever builds the guard is expected to have
    /// already incremented it.
    counter: Arc<AtomicUsize>,
}

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        let _previous = self.counter.fetch_sub(1, Ordering::Relaxed);
    }
}
881
882impl Application for DynamicProxy {
883    fn execute(&self, request: &Request, conn: &ConnectionInfo) -> Result<Response, String> {
884        let backend = match self.next_backend(&conn.client.ip) {
885            Some(b) => b,
886            None => {
887                return Ok(bad_gateway());
888            }
889        };
890
891        let _connection_guard = if self.strategy == LoadBalanceStrategy::LeastConnections {
892            let counter = self.connection_counter(&backend);
893            counter.fetch_add(1, Ordering::Relaxed);
894            Some(ConnectionGuard { counter })
895        } else {
896            None
897        };
898
899        let (host, port, _) = match crate::proxy_config::health::parse_backend_url(&backend) {
900            Some(t) => t,
901            None => return Ok(bad_gateway()),
902        };
903
904        // Apply path rewriting if configured
905        let mut req_clone;
906        let effective_request = if self.strip_prefix.is_some() || self.add_prefix.is_some() {
907            req_clone = request.clone();
908            if let Some(ref sp) = self.strip_prefix {
909                if let Some(stripped) = req_clone.request_uri.strip_prefix(sp.as_str()) {
910                    req_clone.request_uri = if stripped.is_empty() || !stripped.starts_with('/') {
911                        format!("/{}", stripped)
912                    } else {
913                        stripped.to_string()
914                    };
915                }
916            }
917            if let Some(ref ap) = self.add_prefix {
918                req_clone.request_uri = format!("{}{}", ap, req_clone.request_uri);
919            }
920            &req_clone
921        } else {
922            request
923        };
924
925        let result = if self.tls {
926            #[cfg(any(feature = "http-client", feature = "http2"))]
927            {
928                crate::proxy::proxy_https1(
929                    effective_request,
930                    &conn.client.ip,
931                    &host,
932                    port,
933                    self.connect_timeout,
934                    self.read_timeout,
935                )
936            }
937            #[cfg(not(any(feature = "http-client", feature = "http2")))]
938            {
939                eprintln!("[proxy] HTTPS upstream requires http-client or http2 feature");
940                Err("TLS upstream not supported in this build".to_string())
941            }
942        } else {
943            crate::proxy::proxy_http1(
944                effective_request,
945                &conn.client.ip,
946                &host,
947                port,
948                self.connect_timeout,
949                self.read_timeout,
950            )
951        };
952
953        match result {
954            Ok(r) => Ok(r),
955            Err(_) => Ok(bad_gateway()),
956        }
957    }
958}
959
960fn bad_gateway() -> Response {
961    use crate::mime_type::MimeType;
962    use crate::range::Range;
963    let cr = Range::get_content_range(
964        b"502 Bad Gateway".to_vec(),
965        MimeType::TEXT_PLAIN.to_string(),
966    );
967    let mut r = Response::new();
968    r.status_code = *STATUS_CODE_REASON_PHRASE.n502_bad_gateway.status_code;
969    r.reason_phrase = STATUS_CODE_REASON_PHRASE.n502_bad_gateway.reason_phrase.to_string();
970    r.content_range_list = vec![cr];
971    r
972}
973
974// ── RedirectAdapter ────────────────────────────────────────────────────────────
975
976/// Action adapter that issues HTTP redirects.
977///
978/// `$path` in `location_template` is replaced with the request URI at runtime.
979#[derive(Clone)]
980pub(crate) struct RedirectAdapter {
981    location_template: Arc<String>,
982    status: i16,
983    reason: Arc<String>,
984}
985
986impl RedirectAdapter {
987    pub(crate) fn new(location: String, status: u16) -> Self {
988        let (code, reason) = redirect_status(status);
989        RedirectAdapter {
990            location_template: Arc::new(location),
991            status: code,
992            reason: Arc::new(reason),
993        }
994    }
995}
996
997fn redirect_status(code: u16) -> (i16, String) {
998    let phrase = match code {
999        301 => STATUS_CODE_REASON_PHRASE.n301_moved_permanently.reason_phrase,
1000        302 => STATUS_CODE_REASON_PHRASE.n302_found.reason_phrase,
1001        307 => STATUS_CODE_REASON_PHRASE.n307_temporary_redirect.reason_phrase,
1002        308 => STATUS_CODE_REASON_PHRASE.n308_permanent_redirect.reason_phrase,
1003        _ => "Redirect",
1004    };
1005    (code as i16, phrase.to_string())
1006}
1007
1008impl Application for RedirectAdapter {
1009    fn execute(&self, request: &Request, _conn: &ConnectionInfo) -> Result<Response, String> {
1010        let location = self
1011            .location_template
1012            .replace("$path", &request.request_uri);
1013        use crate::header::Header;
1014        let mut r = Response::new();
1015        r.status_code = self.status;
1016        r.reason_phrase = self.reason.as_ref().clone();
1017        r.headers.push(Header { name: "Location".to_string(), value: location });
1018        Ok(r)
1019    }
1020}
1021
1022// ── RespondAdapter ─────────────────────────────────────────────────────────────
1023
1024/// Action adapter that returns a fixed response body.
1025#[derive(Clone)]
1026pub(crate) struct RespondAdapter {
1027    status: i16,
1028    reason: Arc<String>,
1029    body: Arc<Vec<u8>>,
1030    content_type: Arc<String>,
1031}
1032
1033impl RespondAdapter {
1034    pub(crate) fn new(status: u16, body: String, content_type: String) -> Self {
1035        use crate::response::STATUS_CODE_REASON_PHRASE;
1036        let reason = match status {
1037            200 => STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string(),
1038            201 => STATUS_CODE_REASON_PHRASE.n201_created.reason_phrase.to_string(),
1039            204 => STATUS_CODE_REASON_PHRASE.n204_no_content.reason_phrase.to_string(),
1040            400 => STATUS_CODE_REASON_PHRASE.n400_bad_request.reason_phrase.to_string(),
1041            401 => STATUS_CODE_REASON_PHRASE.n401_unauthorized.reason_phrase.to_string(),
1042            403 => STATUS_CODE_REASON_PHRASE.n403_forbidden.reason_phrase.to_string(),
1043            404 => STATUS_CODE_REASON_PHRASE.n404_not_found.reason_phrase.to_string(),
1044            500 => STATUS_CODE_REASON_PHRASE.n500_internal_server_error.reason_phrase.to_string(),
1045            _ => "OK".to_string(),
1046        };
1047        RespondAdapter {
1048            status: status as i16,
1049            reason: Arc::new(reason),
1050            body: Arc::new(body.into_bytes()),
1051            content_type: Arc::new(content_type),
1052        }
1053    }
1054}
1055
1056impl Application for RespondAdapter {
1057    fn execute(&self, _request: &Request, _conn: &ConnectionInfo) -> Result<Response, String> {
1058        use crate::range::Range;
1059        let mut r = Response::new();
1060        r.status_code = self.status;
1061        r.reason_phrase = self.reason.as_ref().clone();
1062        if !self.body.is_empty() {
1063            r.content_range_list = vec![Range::get_content_range(
1064                self.body.as_ref().clone(),
1065                self.content_type.as_ref().clone(),
1066            )];
1067        }
1068        Ok(r)
1069    }
1070}
1071
1072// ── StaticAdapter ──────────────────────────────────────────────────────────────
1073
1074/// Action adapter that serves static files from a configured `root` directory.
1075///
1076/// Unlike `StaticResourceController` (which always resolves paths relative to
1077/// the process's current working directory), this adapter is parameterized
1078/// per-route by `ActionConfig::Static { root, index }` from `rws.config.toml`,
1079/// so a config-driven proxy can serve an arbitrary directory without Rust code.
1080#[derive(Clone)]
1081pub(crate) struct StaticAdapter {
1082    root: Arc<std::path::PathBuf>,
1083    index: Arc<Vec<String>>,
1084}
1085
1086impl StaticAdapter {
1087    pub(crate) fn new(root: String, index: Vec<String>) -> Self {
1088        let index = if index.is_empty() { vec!["index.html".to_string()] } else { index };
1089        StaticAdapter {
1090            root: Arc::new(std::path::PathBuf::from(root)),
1091            index: Arc::new(index),
1092        }
1093    }
1094
1095    /// Resolves `request_uri` against `root`. Returns `None` if the decoded
1096    /// path contains a `..` segment, which would otherwise let a request
1097    /// escape the configured root directory.
1098    fn resolve(&self, request_uri: &str) -> Option<std::path::PathBuf> {
1099        let raw_path = request_uri.split('?').next().unwrap_or(request_uri);
1100        let decoded = crate::url::URL::percent_decode(raw_path);
1101
1102        if decoded.split('/').any(|segment| segment == "..") {
1103            return None;
1104        }
1105
1106        let relative = decoded.trim_start_matches('/');
1107        Some(self.root.join(relative))
1108    }
1109}
1110
1111impl Application for StaticAdapter {
1112    fn execute(&self, request: &Request, _conn: &ConnectionInfo) -> Result<Response, String> {
1113        let mut response = Response::new();
1114
1115        let not_found = |mut response: Response| {
1116            response.status_code = *STATUS_CODE_REASON_PHRASE.n404_not_found.status_code;
1117            response.reason_phrase = STATUS_CODE_REASON_PHRASE.n404_not_found.reason_phrase.to_string();
1118            response
1119        };
1120
1121        let candidate = match self.resolve(&request.request_uri) {
1122            Some(p) => p,
1123            None => {
1124                response.status_code = *STATUS_CODE_REASON_PHRASE.n403_forbidden.status_code;
1125                response.reason_phrase = STATUS_CODE_REASON_PHRASE.n403_forbidden.reason_phrase.to_string();
1126                return Ok(response);
1127            }
1128        };
1129
1130        let mut file_path = candidate;
1131        if file_path.is_dir() {
1132            let indexed = self
1133                .index
1134                .iter()
1135                .map(|name| file_path.join(name))
1136                .find(|p| p.is_file());
1137
1138            file_path = match indexed {
1139                Some(p) => p,
1140                None => return Ok(not_found(response)),
1141            };
1142        }
1143
1144        if !file_path.is_file() {
1145            return Ok(not_found(response));
1146        }
1147
1148        // Defense-in-depth against symlinks inside `root` that point outside it —
1149        // the `..`-segment check above only catches traversal in the request URI.
1150        if let (Ok(root_canon), Ok(file_canon)) =
1151            (self.root.canonicalize(), file_path.canonicalize())
1152        {
1153            if !file_canon.starts_with(&root_canon) {
1154                return Ok(not_found(response));
1155            }
1156        }
1157
1158        let path_str = match file_path.to_str() {
1159            Some(s) => s,
1160            None => return Ok(not_found(response)),
1161        };
1162
1163        match crate::range::Range::get_content_range_of_a_file(path_str) {
1164            Ok(content_range) => {
1165                response.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
1166                response.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
1167                response.content_range_list = vec![content_range];
1168                Ok(response)
1169            }
1170            Err(_) => Ok(not_found(response)),
1171        }
1172    }
1173}
1174
1175// ── PerRouteRateLimit middleware ───────────────────────────────────────────────
1176
1177/// A per-route rate limiter middleware backed by a shared `RateLimiter`.
1178pub(crate) struct PerRouteRateLimit(pub(crate) Arc<crate::rate_limit::RateLimiter>);
1179
1180impl crate::middleware::Middleware for PerRouteRateLimit {
1181    fn handle(
1182        &self,
1183        request: &Request,
1184        conn: &ConnectionInfo,
1185        next: &dyn Application,
1186    ) -> Result<Response, String> {
1187        use crate::error::{AppError, IntoResponse};
1188        if self.0.check(&conn.client.ip) {
1189            next.execute(request, conn)
1190        } else {
1191            Ok(AppError::TooManyRequests.into_response())
1192        }
1193    }
1194}
1195
1196// ── BearerAuthMiddleware ───────────────────────────────────────────────────────
1197
1198/// Bearer token authentication middleware.
1199pub(crate) struct BearerAuthMiddleware {
1200    pub(crate) token: Arc<String>,
1201}
1202
1203impl crate::middleware::Middleware for BearerAuthMiddleware {
1204    fn handle(
1205        &self,
1206        request: &Request,
1207        conn: &ConnectionInfo,
1208        next: &dyn Application,
1209    ) -> Result<Response, String> {
1210        use crate::error::{AppError, IntoResponse};
1211        let expected = format!("Bearer {}", self.token);
1212        let authorized = request
1213            .headers
1214            .iter()
1215            .any(|h| h.name.eq_ignore_ascii_case("authorization") && h.value == expected);
1216        if authorized {
1217            next.execute(request, conn)
1218        } else {
1219            Ok(AppError::Unauthorized.into_response())
1220        }
1221    }
1222}
1223
1224// ── arc_app helper ─────────────────────────────────────────────────────────────
1225
1226/// Box any `Application + Send + Sync + 'static` into an `Arc<dyn …>`.
1227pub(crate) fn arc_app<A: Application + Send + Sync + 'static>(
1228    a: A,
1229) -> Arc<dyn Application + Send + Sync> {
1230    Arc::new(a)
1231}
1232
1233// ── Public entry points ────────────────────────────────────────────────────────
1234
1235/// Build a `ConfigDrivenApp` from `rws.config.toml` and spawn L4/WS proxy
1236/// threads. Returns the app and a list of thread handles.
1237pub fn build_from_file() -> (ConfigDrivenApp, Vec<std::thread::JoinHandle<()>>) {
1238    builder::build_from_file()
1239}