Skip to main content

rust_web_server/proxy_config/
mod.rs

1//! Config-driven proxy application.
2//!
3//! When `rws.config.toml` contains `[[route]]` or `[[upstream]]` sections,
4//! `ConfigDrivenApp` is used as the top-level `Application` instead of the
5//! hardcoded `build_app()` in `main.rs`.
6//!
7//! # Quick start
8//!
9//! ```toml
10//! # rws.config.toml
11//! [[upstream]]
12//! name = "api"
13//! backends = ["localhost:3000"]
14//!
15//! [[route]]
16//! name = "api-proxy"
17//!
18//! [route.match]
19//! path = "/api/*"
20//!
21//! [route.action]
22//! type = "proxy"
23//!
24//! [route.action.proxy]
25//! upstream = "api"
26//! ```
27
28pub mod parser;
29pub mod health;
30pub mod builder;
31
32#[cfg(test)]
33mod tests;
34
35use std::sync::Arc;
36
37use crate::app::App;
38use crate::application::Application;
39use crate::core::New;
40use crate::request::Request;
41use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
42use crate::server::ConnectionInfo;
43use crate::server_config::ServerConfig;
44
45// ── Public config types ────────────────────────────────────────────────────────
46
47#[derive(Debug, Clone)]
48pub struct ProxyConfig {
49    pub upstreams: Vec<UpstreamConfig>,
50    pub routes: Vec<RouteConfig>,
51    pub tcp_proxies: Vec<TcpProxyConfig>,
52    pub udp_proxies: Vec<UdpProxyConfig>,
53    pub ws_proxies: Vec<WsProxyConfig>,
54    pub global_middleware: MiddlewareConfig,
55}
56
57#[derive(Debug, Clone)]
58pub struct UpstreamConfig {
59    pub name: String,
60    pub backends: Vec<String>,
61    pub strategy: String, // "round_robin" | "random" | "ip_hash"
62    pub health_check: Option<HealthCheckConfig>,
63    /// `true` when all backends use `https://` scheme — connections to the
64    /// upstream are made over TLS. Requires the `http-client` or `http2`
65    /// feature (which bring in `rustls` + `webpki-roots`).
66    pub tls: bool,
67}
68
/// Active health-check settings from an upstream's `health_check` sub-table.
///
/// When a key is missing the parser substitutes: `path = "/health"`,
/// `interval_secs = 30`, `timeout_ms = 5000`, `healthy_threshold = 2`,
/// `unhealthy_threshold = 3`.
#[derive(Clone, Debug)]
pub struct HealthCheckConfig {
    /// Path requested when probing a backend.
    pub path: String,
    /// Seconds between probes.
    pub interval_secs: u64,
    /// Per-probe timeout, in milliseconds.
    pub timeout_ms: u64,
    /// Successful probes needed to mark a backend healthy (presumably
    /// consecutive — see the `health` module).
    pub healthy_threshold: u32,
    /// Failed probes needed to mark a backend unhealthy (presumably
    /// consecutive — see the `health` module).
    pub unhealthy_threshold: u32,
}
77
78#[derive(Debug, Clone)]
79pub struct RouteConfig {
80    pub name: String,
81    pub match_: MatchConfig,
82    pub action: ActionConfig,
83    pub middleware: MiddlewareConfig,
84}
85
/// Raw `[route.match]` criteria as written in the config; `None` means the
/// key was absent or empty.
#[derive(Clone, Debug, Default)]
pub struct MatchConfig {
    /// Expected host name.
    pub host: Option<String>,
    /// Path pattern; a trailing `*` makes it a prefix match.
    pub path: Option<String>,
    /// HTTP method, upper-cased by the parser.
    pub method: Option<String>,
    /// `Content-Type` pattern, matched as a prefix.
    pub content_type: Option<String>,
}
93
/// The `[route.action]` of a route, selected by its `type` key.
#[derive(Clone, Debug)]
pub enum ActionConfig {
    /// `type = "proxy"`: forward to the named upstream. Parser defaults:
    /// `connect_timeout_ms = 5000`, `read_timeout_ms = 30000`; the optional
    /// path-prefix rewrites are `None` when unset.
    Proxy {
        upstream: String,
        connect_timeout_ms: u64,
        read_timeout_ms: u64,
        strip_path_prefix: Option<String>,
        add_path_prefix: Option<String>,
    },
    /// `type = "grpc"`: forward to the named upstream, with the same timeout
    /// defaults as `Proxy`.
    Grpc {
        upstream: String,
        connect_timeout_ms: u64,
        read_timeout_ms: u64,
    },
    /// `type = "static"`: serve files from `root`; `index` presumably lists
    /// index file names to try.
    Static {
        root: String,
        index: Vec<String>,
    },
    /// `type = "redirect"`: redirect to `location` (status defaults to 301).
    Redirect {
        location: String,
        status: u16,
    },
    /// `type = "respond"`: fixed response (status defaults to 200,
    /// `content_type` to `"text/plain"`).
    Respond {
        status: u16,
        body: String,
        content_type: String,
    },
    /// `type = "mcp"`; carries no settings.
    Mcp,
    /// Any other `type` value (including an empty one), kept verbatim.
    Unknown(String),
}
124
125#[derive(Debug, Clone, Default)]
126pub struct MiddlewareConfig {
127    pub rate_limit: Option<RateLimitConfig>,
128    pub cache: Option<CacheConfig>,
129    pub auth: Option<AuthConfig>,
130    pub rewrite_request: Vec<RewriteRuleConfig>,
131    pub rewrite_response: Vec<RewriteRuleConfig>,
132    pub ip_allow: Vec<String>,
133    pub ip_deny: Vec<String>,
134    /// `timeout_ms` — if the route (including all its other middleware)
135    /// doesn't produce a response within this many milliseconds, the client
136    /// gets `504 Gateway Timeout` instead of waiting further. See
137    /// `crate::timeout` for the underlying mechanism and its limitations.
138    pub timeout_ms: Option<u64>,
139}
140
/// `rate_limit` middleware: at most `max_requests` per `window_secs`
/// (parser defaults: 1000 requests per 60 seconds).
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Requests allowed per window.
    pub max_requests: u32,
    /// Window length, in seconds.
    pub window_secs: u64,
}
146
/// `cache` middleware settings (parser default `ttl_secs = 60`).
#[derive(Clone, Debug)]
pub struct CacheConfig {
    /// Lifetime of a cached entry, in seconds.
    pub ttl_secs: u64,
    /// Names the cache key additionally varies by (presumably request headers).
    pub vary_by: Vec<String>,
}
152
/// `auth` middleware, chosen by the sub-table's `type` key.
#[derive(Clone, Debug)]
pub enum AuthConfig {
    /// `type = "basic"`; `users_file` presumably points at the credentials file.
    Basic { users_file: String },
    /// `type = "jwt"`; `secret_env` presumably names the env var holding the secret.
    Jwt { secret_env: String },
    /// `type = "bearer"`; `token_env` presumably names the env var holding the token.
    Bearer { token_env: String },
}
159
/// One request/response rewrite rule. `type_` selects the rule kind; the
/// remaining fields are its optional parameters (`None` when absent or empty).
#[derive(Clone, Debug, Default)]
pub struct RewriteRuleConfig {
    /// Rule kind, from the `type` key.
    pub type_: String,
    /// `name` key.
    pub name: Option<String>,
    /// `value` key.
    pub value: Option<String>,
    /// `prefix` key.
    pub prefix: Option<String>,
    /// `from` key.
    pub from: Option<String>,
    /// `to` key.
    pub to: Option<String>,
    /// `code` key (status code); `None` when absent or zero.
    pub code: Option<u16>,
    /// `reason` key.
    pub reason: Option<String>,
}
171
/// One `[[tcp_proxy]]` entry (parser default `connect_timeout_ms = 5000`).
#[derive(Clone, Debug)]
pub struct TcpProxyConfig {
    /// Value of the `name` key.
    pub name: String,
    /// Address to listen on (presumably `host:port`).
    pub listen: String,
    /// Backend addresses.
    pub backends: Vec<String>,
    /// Backend connect timeout, in milliseconds.
    pub connect_timeout_ms: u64,
}
179
/// One `[[udp_proxy]]` entry (parser defaults: `reply_timeout_ms = 5000`,
/// `buffer_size = 65536`).
#[derive(Clone, Debug)]
pub struct UdpProxyConfig {
    /// Value of the `name` key.
    pub name: String,
    /// Address to listen on (presumably `host:port`).
    pub listen: String,
    /// Backend addresses.
    pub backends: Vec<String>,
    /// How long to wait for a backend reply, in milliseconds.
    pub reply_timeout_ms: u64,
    /// Datagram buffer size, in bytes.
    pub buffer_size: usize,
}
188
/// One `[[ws_proxy]]` entry (parser defaults: `connect_timeout_ms = 5000`,
/// `read_timeout_ms = 30000`).
#[derive(Clone, Debug)]
pub struct WsProxyConfig {
    /// Value of the `name` key.
    pub name: String,
    /// Address to listen on (presumably `host:port`).
    pub listen: String,
    /// Backend addresses.
    pub backends: Vec<String>,
    /// Backend connect timeout, in milliseconds.
    pub connect_timeout_ms: u64,
    /// Backend read timeout, in milliseconds.
    pub read_timeout_ms: u64,
}
197
198// ── ProxyConfig loading ────────────────────────────────────────────────────────
199
200impl ProxyConfig {
201    /// Returns `true` if `rws.config.toml` (or `RWS_CONFIG_FILE`) contains
202    /// `[[route]]` or `[[upstream]]` sections, meaning config-driven mode
203    /// should be used.
204    pub fn is_proxy_mode() -> bool {
205        let path = config_file_path();
206        match std::fs::read_to_string(&path) {
207            Ok(contents) => {
208                contents.contains("[[route]]") || contents.contains("[[upstream]]")
209            }
210            Err(_) => false,
211        }
212    }
213
214    /// Parse the config file and return a `ProxyConfig`.
215    pub fn load() -> Self {
216        let path = config_file_path();
217        let contents = std::fs::read_to_string(&path).unwrap_or_default();
218        Self::from_str(&contents)
219    }
220
221    /// Parse `toml` text directly into a `ProxyConfig`. Used in tests.
222    pub fn from_str(toml: &str) -> Self {
223        use parser::{get_array, get_str, get_u32, get_u64, section_exists};
224
225        let map = parser::parse(toml);
226
227        // ── upstreams ──────────────────────────────────────────────────────────
228        let mut upstreams = Vec::new();
229        let mut i = 0;
230        loop {
231            let sec = format!("upstream[{}]", i);
232            if !section_exists(&map, &sec) {
233                break;
234            }
235            let name = get_str(&map, &sec, "name");
236            let backends = get_array(&map, &sec, "backends");
237            let strategy = {
238                let s = get_str(&map, &sec, "strategy");
239                if s.is_empty() { "round_robin".to_string() } else { s }
240            };
241            let hc_sec = format!("{}.health_check", sec);
242            let health_check = if section_exists(&map, &hc_sec) {
243                Some(HealthCheckConfig {
244                    path: {
245                        let p = get_str(&map, &hc_sec, "path");
246                        if p.is_empty() { "/health".to_string() } else { p }
247                    },
248                    interval_secs: get_u64(&map, &hc_sec, "interval_secs", 30),
249                    timeout_ms: get_u64(&map, &hc_sec, "timeout_ms", 5000),
250                    healthy_threshold: get_u32(&map, &hc_sec, "healthy_threshold", 2),
251                    unhealthy_threshold: get_u32(&map, &hc_sec, "unhealthy_threshold", 3),
252                })
253            } else {
254                None
255            };
256            let tls = backends.iter().any(|b| b.starts_with("https://"));
257            upstreams.push(UpstreamConfig { name, backends, strategy, health_check, tls });
258            i += 1;
259        }
260
261        // ── routes ─────────────────────────────────────────────────────────────
262        let mut routes = Vec::new();
263        let mut i = 0;
264        loop {
265            let sec = format!("route[{}]", i);
266            if !section_exists(&map, &sec) {
267                break;
268            }
269            let name = get_str(&map, &sec, "name");
270
271            // match
272            let m_sec = format!("{}.match", sec);
273            let match_ = MatchConfig {
274                host: {
275                    let h = get_str(&map, &m_sec, "host");
276                    if h.is_empty() { None } else { Some(h) }
277                },
278                path: {
279                    let p = get_str(&map, &m_sec, "path");
280                    if p.is_empty() { None } else { Some(p) }
281                },
282                method: {
283                    let m = get_str(&map, &m_sec, "method");
284                    if m.is_empty() { None } else { Some(m.to_uppercase()) }
285                },
286                content_type: {
287                    let c = get_str(&map, &m_sec, "content_type");
288                    if c.is_empty() { None } else { Some(c) }
289                },
290            };
291
292            // action
293            let a_sec = format!("{}.action", sec);
294            let action_type = get_str(&map, &a_sec, "type");
295            let action = match action_type.as_str() {
296                "proxy" => {
297                    let p_sec = format!("{}.action.proxy", sec);
298                    ActionConfig::Proxy {
299                        upstream: get_str(&map, &p_sec, "upstream"),
300                        connect_timeout_ms: get_u64(&map, &p_sec, "connect_timeout_ms", 5000),
301                        read_timeout_ms: get_u64(&map, &p_sec, "read_timeout_ms", 30000),
302                        strip_path_prefix: {
303                            let v = get_str(&map, &p_sec, "strip_path_prefix");
304                            if v.is_empty() { None } else { Some(v) }
305                        },
306                        add_path_prefix: {
307                            let v = get_str(&map, &p_sec, "add_path_prefix");
308                            if v.is_empty() { None } else { Some(v) }
309                        },
310                    }
311                }
312                "grpc" => {
313                    let p_sec = format!("{}.action.grpc", sec);
314                    ActionConfig::Grpc {
315                        upstream: get_str(&map, &p_sec, "upstream"),
316                        connect_timeout_ms: get_u64(&map, &p_sec, "connect_timeout_ms", 5000),
317                        read_timeout_ms: get_u64(&map, &p_sec, "read_timeout_ms", 30000),
318                    }
319                }
320                "static" => {
321                    let s_sec = format!("{}.action.static", sec);
322                    ActionConfig::Static {
323                        root: get_str(&map, &s_sec, "root"),
324                        index: get_array(&map, &s_sec, "index"),
325                    }
326                }
327                "redirect" => {
328                    let r_sec = format!("{}.action.redirect", sec);
329                    ActionConfig::Redirect {
330                        location: get_str(&map, &r_sec, "location"),
331                        status: get_u64(&map, &r_sec, "status", 301) as u16,
332                    }
333                }
334                "respond" => {
335                    let r_sec = format!("{}.action.respond", sec);
336                    ActionConfig::Respond {
337                        status: get_u64(&map, &r_sec, "status", 200) as u16,
338                        body: get_str(&map, &r_sec, "body"),
339                        content_type: {
340                            let ct = get_str(&map, &r_sec, "content_type");
341                            if ct.is_empty() { "text/plain".to_string() } else { ct }
342                        },
343                    }
344                }
345                "mcp" => ActionConfig::Mcp,
346                other => ActionConfig::Unknown(other.to_string()),
347            };
348
349            // middleware
350            let mw_sec = format!("{}.middleware", sec);
351            let middleware = parse_middleware_config(&map, &mw_sec, i);
352
353            routes.push(RouteConfig { name, match_, action, middleware });
354            i += 1;
355        }
356
357        // ── tcp_proxy ──────────────────────────────────────────────────────────
358        let mut tcp_proxies = Vec::new();
359        let mut i = 0;
360        loop {
361            let sec = format!("tcp_proxy[{}]", i);
362            if !section_exists(&map, &sec) {
363                break;
364            }
365            tcp_proxies.push(TcpProxyConfig {
366                name: get_str(&map, &sec, "name"),
367                listen: get_str(&map, &sec, "listen"),
368                backends: get_array(&map, &sec, "backends"),
369                connect_timeout_ms: get_u64(&map, &sec, "connect_timeout_ms", 5000),
370            });
371            i += 1;
372        }
373
374        // ── udp_proxy ──────────────────────────────────────────────────────────
375        let mut udp_proxies = Vec::new();
376        let mut i = 0;
377        loop {
378            let sec = format!("udp_proxy[{}]", i);
379            if !section_exists(&map, &sec) {
380                break;
381            }
382            udp_proxies.push(UdpProxyConfig {
383                name: get_str(&map, &sec, "name"),
384                listen: get_str(&map, &sec, "listen"),
385                backends: get_array(&map, &sec, "backends"),
386                reply_timeout_ms: get_u64(&map, &sec, "reply_timeout_ms", 5000),
387                buffer_size: get_u64(&map, &sec, "buffer_size", 65536) as usize,
388            });
389            i += 1;
390        }
391
392        // ── ws_proxy ───────────────────────────────────────────────────────────
393        let mut ws_proxies = Vec::new();
394        let mut i = 0;
395        loop {
396            let sec = format!("ws_proxy[{}]", i);
397            if !section_exists(&map, &sec) {
398                break;
399            }
400            ws_proxies.push(WsProxyConfig {
401                name: get_str(&map, &sec, "name"),
402                listen: get_str(&map, &sec, "listen"),
403                backends: get_array(&map, &sec, "backends"),
404                connect_timeout_ms: get_u64(&map, &sec, "connect_timeout_ms", 5000),
405                read_timeout_ms: get_u64(&map, &sec, "read_timeout_ms", 30000),
406            });
407            i += 1;
408        }
409
410        // ── global middleware ──────────────────────────────────────────────────
411        let global_middleware = parse_middleware_config(&map, "middleware", usize::MAX);
412
413        ProxyConfig {
414            upstreams,
415            routes,
416            tcp_proxies,
417            udp_proxies,
418            ws_proxies,
419            global_middleware,
420        }
421    }
422}
423
424/// Parse a `MiddlewareConfig` from the section map at a given base path.
425/// `route_idx` is used only to build inner-array section paths for rewrite rules.
426fn parse_middleware_config(
427    map: &parser::SectionMap,
428    mw_sec: &str,
429    route_idx: usize,
430) -> MiddlewareConfig {
431    use parser::{get_array, get_str, get_u32, get_u64, section_exists};
432
433    let rl_sec = format!("{}.rate_limit", mw_sec);
434    let rate_limit = if section_exists(map, &rl_sec) {
435        Some(RateLimitConfig {
436            max_requests: get_u32(map, &rl_sec, "max_requests", 1000),
437            window_secs: get_u64(map, &rl_sec, "window_secs", 60),
438        })
439    } else {
440        None
441    };
442
443    let c_sec = format!("{}.cache", mw_sec);
444    let cache = if section_exists(map, &c_sec) {
445        Some(CacheConfig {
446            ttl_secs: get_u64(map, &c_sec, "ttl_secs", 60),
447            vary_by: get_array(map, &c_sec, "vary_by"),
448        })
449    } else {
450        None
451    };
452
453    let a_sec = format!("{}.auth", mw_sec);
454    let auth = if section_exists(map, &a_sec) {
455        let auth_type = get_str(map, &a_sec, "type");
456        match auth_type.as_str() {
457            "bearer" => Some(AuthConfig::Bearer {
458                token_env: get_str(map, &a_sec, "token_env"),
459            }),
460            "jwt" => Some(AuthConfig::Jwt {
461                secret_env: get_str(map, &a_sec, "secret_env"),
462            }),
463            "basic" => Some(AuthConfig::Basic {
464                users_file: get_str(map, &a_sec, "users_file"),
465            }),
466            _ => None,
467        }
468    } else {
469        None
470    };
471
472    // Rewrite rules — the section paths use route_idx for route-scoped rules
473    // or a flat path for global middleware. We look for:
474    //   route[N].middleware.rewrite.request[0], [1], …
475    //   route[N].middleware.rewrite.response[0], [1], …
476    // For global: middleware.rewrite.request[0], etc.
477    let rewrite_request = collect_rewrite_rules(map, mw_sec, "request");
478    let rewrite_response = collect_rewrite_rules(map, mw_sec, "response");
479
480    let ip_sec = format!("{}.ip_filter", mw_sec);
481    let ip_allow = if section_exists(map, &ip_sec) {
482        get_array(map, &ip_sec, "allow")
483    } else {
484        vec![]
485    };
486    let ip_deny = if section_exists(map, &ip_sec) {
487        get_array(map, &ip_sec, "deny")
488    } else {
489        vec![]
490    };
491
492    let _ = route_idx; // used implicitly via mw_sec paths
493
494    // Flat scalar directly under [route.middleware] (or the global
495    // [middleware] table), not a nested sub-table like rate_limit/cache —
496    // 0/absent both mean "no timeout configured".
497    let timeout_ms = match get_u64(map, mw_sec, "timeout_ms", 0) {
498        0 => None,
499        ms => Some(ms),
500    };
501
502    MiddlewareConfig { rate_limit, cache, auth, rewrite_request, rewrite_response, ip_allow, ip_deny, timeout_ms }
503}
504
505/// Collect `[[{mw_sec}.rewrite.{direction}]]` entries.
506fn collect_rewrite_rules(
507    map: &parser::SectionMap,
508    mw_sec: &str,
509    direction: &str,
510) -> Vec<RewriteRuleConfig> {
511    use parser::{get_str, get_u64};
512
513    let mut rules = Vec::new();
514    let mut j = 0;
515    loop {
516        let rsec = format!("{}.rewrite.{}[{}]", mw_sec, direction, j);
517        if !parser::section_exists(map, &rsec) {
518            break;
519        }
520        let code_val = get_u64(map, &rsec, "code", 0);
521        rules.push(RewriteRuleConfig {
522            type_: get_str(map, &rsec, "type"),
523            name: {
524                let v = get_str(map, &rsec, "name");
525                if v.is_empty() { None } else { Some(v) }
526            },
527            value: {
528                let v = get_str(map, &rsec, "value");
529                if v.is_empty() { None } else { Some(v) }
530            },
531            prefix: {
532                let v = get_str(map, &rsec, "prefix");
533                if v.is_empty() { None } else { Some(v) }
534            },
535            from: {
536                let v = get_str(map, &rsec, "from");
537                if v.is_empty() { None } else { Some(v) }
538            },
539            to: {
540                let v = get_str(map, &rsec, "to");
541                if v.is_empty() { None } else { Some(v) }
542            },
543            code: if code_val == 0 { None } else { Some(code_val as u16) },
544            reason: {
545                let v = get_str(map, &rsec, "reason");
546                if v.is_empty() { None } else { Some(v) }
547            },
548        });
549        j += 1;
550    }
551    rules
552}
553
/// Path of the proxy config file: the value of `RWS_CONFIG_FILE` when it is
/// set (and valid Unicode), otherwise `rws.config.toml`, which is resolved
/// relative to the working directory.
fn config_file_path() -> String {
    match std::env::var("RWS_CONFIG_FILE") {
        Ok(path) => path,
        Err(_) => String::from("rws.config.toml"),
    }
}
557
558// ── ConfigDrivenApp ────────────────────────────────────────────────────────────
559
560/// A compiled route: a matcher paired with a handler application.
561pub(crate) struct CompiledRoute {
562    pub(crate) matcher: RouteMatcher,
563    /// Shared, type-erased handler. `Arc` makes `Clone` cheap (pointer copy).
564    pub(crate) handler: Arc<dyn Application + Send + Sync>,
565}
566
/// Compiled matching criteria for a single route. A `None` field places no
/// constraint on the request.
#[derive(Default, Clone)]
pub(crate) struct RouteMatcher {
    /// Expected host, compared with the SNI name or else the `Host` header.
    pub(crate) host: Option<String>,
    /// Path prefix (from `path = "/v1/*"`); checked instead of `path_exact`
    /// when both are set.
    pub(crate) path_prefix: Option<String>,
    /// Exact path (from `path = "/v1/ping"`).
    pub(crate) path_exact: Option<String>,
    /// Upper-case HTTP method; `None` accepts any method.
    pub(crate) method: Option<String>,
    /// `Content-Type` prefix, e.g. `"application/grpc"`.
    pub(crate) content_type_prefix: Option<String>,
}
581
582impl RouteMatcher {
583    pub(crate) fn from_match_config(cfg: &MatchConfig) -> Self {
584        let (path_prefix, path_exact) = match &cfg.path {
585            Some(p) if p.ends_with('*') => {
586                // "/v1/*" → prefix "/v1/"
587                let stripped = p.trim_end_matches('*').to_string();
588                (Some(stripped), None)
589            }
590            Some(p) => (None, Some(p.clone())),
591            None => (None, None),
592        };
593        let content_type_prefix = cfg.content_type.as_ref().map(|ct| {
594            if ct.ends_with('*') {
595                ct.trim_end_matches('*').to_string()
596            } else {
597                ct.clone()
598            }
599        });
600        RouteMatcher {
601            host: cfg.host.clone(),
602            path_prefix,
603            path_exact,
604            method: cfg.method.clone(),
605            content_type_prefix,
606        }
607    }
608
609    /// Returns `true` if `request` and `conn` match all configured criteria.
610    pub(crate) fn matches(&self, request: &Request, conn: &ConnectionInfo) -> bool {
611        // Host matching: SNI first, then Host header
612        if let Some(ref expected_host) = self.host {
613            let actual_host = conn
614                .sni_hostname
615                .as_deref()
616                .or_else(|| {
617                    request
618                        .headers
619                        .iter()
620                        .find(|h| h.name.eq_ignore_ascii_case("host"))
621                        .map(|h| h.value.as_str())
622                })
623                .unwrap_or("");
624            if actual_host != expected_host.as_str() {
625                return false;
626            }
627        }
628
629        // Method matching
630        if let Some(ref m) = self.method {
631            if request.method.to_uppercase() != m.as_str() {
632                return false;
633            }
634        }
635
636        // Path matching: strip query string for comparison
637        let path = request.request_uri.split('?').next().unwrap_or(&request.request_uri);
638        if let Some(ref prefix) = self.path_prefix {
639            if !path.starts_with(prefix.as_str()) {
640                return false;
641            }
642        } else if let Some(ref exact) = self.path_exact {
643            if path != exact.as_str() {
644                return false;
645            }
646        }
647
648        // Content-Type prefix matching
649        if let Some(ref ct_prefix) = self.content_type_prefix {
650            let actual_ct = request
651                .headers
652                .iter()
653                .find(|h| h.name.eq_ignore_ascii_case("content-type"))
654                .map(|h| h.value.as_str())
655                .unwrap_or("");
656            if !actual_ct.starts_with(ct_prefix.as_str()) {
657                return false;
658            }
659        }
660
661        true
662    }
663}
664
665/// An `Application` that routes requests based on a parsed `ProxyConfig`.
666///
667/// `Clone` is cheap: `routes` is an `Arc<Vec<...>>` (pointer copy), and
668/// `fallback` is `App`, itself cheap to clone (an `Option<Arc<ServerConfig>>`).
669#[derive(Clone)]
670pub struct ConfigDrivenApp {
671    routes: Arc<Vec<CompiledRoute>>,
672    /// Fallback for unmatched requests — handles /healthz, /readyz, /metrics,
673    /// static files, and the 404 controller. Reads `RWS_CONFIG_*` env vars
674    /// per request (`App::new()`'s default) unless pinned via
675    /// [`ConfigDrivenApp::with_config`].
676    fallback: App,
677}
678
679impl ConfigDrivenApp {
680    pub(crate) fn new(routes: Vec<CompiledRoute>) -> Self {
681        use crate::core::New;
682        ConfigDrivenApp {
683            routes: Arc::new(routes),
684            fallback: App::new(),
685        }
686    }
687
688    /// Pin the fallback [`App`] (used for any request none of the
689    /// config-driven routes match) to an explicit [`ServerConfig`], instead
690    /// of reading `RWS_CONFIG_*` environment variables per request.
691    ///
692    /// Mirrors [`App::with_config`] / [`crate::state::AppWithState::with_config`]
693    /// — same rationale: safe for parallel tests, and lets multiple
694    /// differently-configured proxy instances coexist in one process.
695    ///
696    /// ```rust,no_run
697    /// use rust_web_server::proxy_config::build_from_file;
698    /// use rust_web_server::server_config::ServerConfig;
699    ///
700    /// let (app, _handles) = build_from_file();
701    /// let app = app.with_config(ServerConfig::default());
702    /// ```
703    pub fn with_config(mut self, config: ServerConfig) -> Self {
704        self.fallback = App::with_config(config);
705        self
706    }
707}
708
709impl Application for ConfigDrivenApp {
710    fn execute(&self, request: &Request, conn: &ConnectionInfo) -> Result<Response, String> {
711        for route in self.routes.iter() {
712            if route.matcher.matches(request, conn) {
713                return route.handler.execute(request, conn);
714            }
715        }
716        self.fallback.execute(request, conn)
717    }
718}
719
720// ── NullApp ────────────────────────────────────────────────────────────────────
721
722/// A dead-end `Application` that always returns 404.
723/// Used as the `next` parameter when calling `Middleware::handle` directly.
724#[derive(Clone, Copy)]
725pub(crate) struct NullApp;
726
727impl Application for NullApp {
728    fn execute(&self, _request: &Request, _conn: &ConnectionInfo) -> Result<Response, String> {
729        let mut r = Response::new();
730        r.status_code = *STATUS_CODE_REASON_PHRASE.n404_not_found.status_code;
731        r.reason_phrase = STATUS_CODE_REASON_PHRASE.n404_not_found.reason_phrase.to_string();
732        Ok(r)
733    }
734}
735
736// ── DynamicProxy ──────────────────────────────────────────────────────────────
737
738use std::collections::HashMap;
739use std::collections::hash_map::DefaultHasher;
740use std::hash::{Hash, Hasher};
741use std::sync::atomic::{AtomicUsize, Ordering};
742use std::sync::RwLock;
743use std::time::{Duration, SystemTime, UNIX_EPOCH};
744
/// Backend-selection strategy for `DynamicProxy`, taken from the `strategy`
/// field of `[[upstream]]` in `rws.config.toml`. Unknown or empty values fall
/// back to `RoundRobin`, which is also the parser's default.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum LoadBalanceStrategy {
    RoundRobin,
    Random,
    IpHash,
    LeastConnections,
}

impl LoadBalanceStrategy {
    /// Map a config string to a strategy (exact, case-sensitive names).
    fn parse(name: &str) -> Self {
        const NAMED: [(&str, LoadBalanceStrategy); 3] = [
            ("random", LoadBalanceStrategy::Random),
            ("ip_hash", LoadBalanceStrategy::IpHash),
            ("least_connections", LoadBalanceStrategy::LeastConnections),
        ];
        NAMED
            .iter()
            .find(|(key, _)| *key == name)
            .map_or(LoadBalanceStrategy::RoundRobin, |&(_, strategy)| strategy)
    }
}
766
767/// A proxy adapter that reads its backend list from a shared, health-checker-
768/// maintained live list at request time. Supports dynamic removal/restoration
769/// of backends without restarting.
770///
771/// This type is `Clone + Send + Sync` and implements `Application`.
772#[derive(Clone)]
773pub(crate) struct DynamicProxy {
774    live: Arc<RwLock<Vec<String>>>,
775    counter: Arc<AtomicUsize>,
776    connect_timeout: Duration,
777    read_timeout: Duration,
778    strip_prefix: Option<Arc<String>>,
779    add_prefix: Option<Arc<String>>,
780    tls: bool,
781    strategy: LoadBalanceStrategy,
782    connections: Arc<RwLock<HashMap<String, Arc<AtomicUsize>>>>,
783}
784
785impl DynamicProxy {
786    pub(crate) fn new(
787        live: Arc<RwLock<Vec<String>>>,
788        connect_timeout_ms: u64,
789        read_timeout_ms: u64,
790        strip_prefix: Option<String>,
791        add_prefix: Option<String>,
792        tls: bool,
793        strategy: String,
794    ) -> Self {
795        DynamicProxy {
796            live,
797            counter: Arc::new(AtomicUsize::new(0)),
798            connect_timeout: Duration::from_millis(connect_timeout_ms),
799            read_timeout: Duration::from_millis(read_timeout_ms),
800            strip_prefix: strip_prefix.map(Arc::new),
801            add_prefix: add_prefix.map(Arc::new),
802            tls,
803            strategy: LoadBalanceStrategy::parse(&strategy),
804            connections: Arc::new(RwLock::new(HashMap::new())),
805        }
806    }
807
808    fn next_backend(&self, client_ip: &str) -> Option<String> {
809        let live = self.live.read().unwrap();
810        if live.is_empty() {
811            return None;
812        }
813
814        let idx = match self.strategy {
815            LoadBalanceStrategy::RoundRobin => {
816                self.counter.fetch_add(1, Ordering::Relaxed) % live.len()
817            }
818            LoadBalanceStrategy::Random => {
819                let nanos = SystemTime::now()
820                    .duration_since(UNIX_EPOCH)
821                    .map(|d| d.subsec_nanos())
822                    .unwrap_or(0) as usize;
823                let salt = self.counter.fetch_add(1, Ordering::Relaxed);
824                nanos.wrapping_add(salt) % live.len()
825            }
826            LoadBalanceStrategy::IpHash => {
827                let mut hasher = DefaultHasher::new();
828                client_ip.hash(&mut hasher);
829                (hasher.finish() as usize) % live.len()
830            }
831            LoadBalanceStrategy::LeastConnections => {
832                let connections = self.connections.read().unwrap();
833                live.iter()
834                    .enumerate()
835                    .min_by_key(|(_, backend)| {
836                        connections
837                            .get(*backend)
838                            .map(|c| c.load(Ordering::Relaxed))
839                            .unwrap_or(0)
840                    })
841                    .map(|(i, _)| i)
842                    .unwrap_or(0)
843            }
844        };
845
846        Some(live[idx].clone())
847    }
848
849    /// Returns the shared in-flight connection counter for `backend`,
850    /// creating it on first use. Only consulted under `LeastConnections`.
851    fn connection_counter(&self, backend: &str) -> Arc<AtomicUsize> {
852        if let Some(counter) = self.connections.read().unwrap().get(backend) {
853            return Arc::clone(counter);
854        }
855        let mut connections = self.connections.write().unwrap();
856        Arc::clone(
857            connections
858                .entry(backend.to_string())
859                .or_insert_with(|| Arc::new(AtomicUsize::new(0))),
860        )
861    }
862}
863
/// RAII guard for one in-flight request on a backend.
///
/// Dropping it, whether the request completes normally or returns early,
/// takes the request back off the backend's counter. This keeps the
/// `least_connections` bookkeeping balanced.
struct ConnectionGuard {
    counter: Arc<AtomicUsize>,
}

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        // Undoes the `fetch_add` performed when the guard was handed out.
        AtomicUsize::fetch_sub(&self.counter, 1, Ordering::Relaxed);
    }
}
875
876impl Application for DynamicProxy {
877    fn execute(&self, request: &Request, conn: &ConnectionInfo) -> Result<Response, String> {
878        let backend = match self.next_backend(&conn.client.ip) {
879            Some(b) => b,
880            None => {
881                return Ok(bad_gateway());
882            }
883        };
884
885        let _connection_guard = if self.strategy == LoadBalanceStrategy::LeastConnections {
886            let counter = self.connection_counter(&backend);
887            counter.fetch_add(1, Ordering::Relaxed);
888            Some(ConnectionGuard { counter })
889        } else {
890            None
891        };
892
893        let (host, port, _) = match crate::proxy_config::health::parse_backend_url(&backend) {
894            Some(t) => t,
895            None => return Ok(bad_gateway()),
896        };
897
898        // Apply path rewriting if configured
899        let mut req_clone;
900        let effective_request = if self.strip_prefix.is_some() || self.add_prefix.is_some() {
901            req_clone = request.clone();
902            if let Some(ref sp) = self.strip_prefix {
903                if let Some(stripped) = req_clone.request_uri.strip_prefix(sp.as_str()) {
904                    req_clone.request_uri = if stripped.is_empty() || !stripped.starts_with('/') {
905                        format!("/{}", stripped)
906                    } else {
907                        stripped.to_string()
908                    };
909                }
910            }
911            if let Some(ref ap) = self.add_prefix {
912                req_clone.request_uri = format!("{}{}", ap, req_clone.request_uri);
913            }
914            &req_clone
915        } else {
916            request
917        };
918
919        let result = if self.tls {
920            #[cfg(any(feature = "http-client", feature = "http2"))]
921            {
922                crate::proxy::proxy_https1(
923                    effective_request,
924                    &conn.client.ip,
925                    &host,
926                    port,
927                    self.connect_timeout,
928                    self.read_timeout,
929                )
930            }
931            #[cfg(not(any(feature = "http-client", feature = "http2")))]
932            {
933                eprintln!("[proxy] HTTPS upstream requires http-client or http2 feature");
934                Err("TLS upstream not supported in this build".to_string())
935            }
936        } else {
937            crate::proxy::proxy_http1(
938                effective_request,
939                &conn.client.ip,
940                &host,
941                port,
942                self.connect_timeout,
943                self.read_timeout,
944            )
945        };
946
947        match result {
948            Ok(r) => Ok(r),
949            Err(_) => Ok(bad_gateway()),
950        }
951    }
952}
953
954fn bad_gateway() -> Response {
955    use crate::mime_type::MimeType;
956    use crate::range::Range;
957    let cr = Range::get_content_range(
958        b"502 Bad Gateway".to_vec(),
959        MimeType::TEXT_PLAIN.to_string(),
960    );
961    let mut r = Response::new();
962    r.status_code = *STATUS_CODE_REASON_PHRASE.n502_bad_gateway.status_code;
963    r.reason_phrase = STATUS_CODE_REASON_PHRASE.n502_bad_gateway.reason_phrase.to_string();
964    r.content_range_list = vec![cr];
965    r
966}
967
968// ── RedirectAdapter ────────────────────────────────────────────────────────────
969
970/// Action adapter that issues HTTP redirects.
971///
972/// `$path` in `location_template` is replaced with the request URI at runtime.
973#[derive(Clone)]
974pub(crate) struct RedirectAdapter {
975    location_template: Arc<String>,
976    status: i16,
977    reason: Arc<String>,
978}
979
980impl RedirectAdapter {
981    pub(crate) fn new(location: String, status: u16) -> Self {
982        let (code, reason) = redirect_status(status);
983        RedirectAdapter {
984            location_template: Arc::new(location),
985            status: code,
986            reason: Arc::new(reason),
987        }
988    }
989}
990
991fn redirect_status(code: u16) -> (i16, String) {
992    let phrase = match code {
993        301 => STATUS_CODE_REASON_PHRASE.n301_moved_permanently.reason_phrase,
994        302 => STATUS_CODE_REASON_PHRASE.n302_found.reason_phrase,
995        307 => STATUS_CODE_REASON_PHRASE.n307_temporary_redirect.reason_phrase,
996        308 => STATUS_CODE_REASON_PHRASE.n308_permanent_redirect.reason_phrase,
997        _ => "Redirect",
998    };
999    (code as i16, phrase.to_string())
1000}
1001
1002impl Application for RedirectAdapter {
1003    fn execute(&self, request: &Request, _conn: &ConnectionInfo) -> Result<Response, String> {
1004        let location = self
1005            .location_template
1006            .replace("$path", &request.request_uri);
1007        use crate::header::Header;
1008        let mut r = Response::new();
1009        r.status_code = self.status;
1010        r.reason_phrase = self.reason.as_ref().clone();
1011        r.headers.push(Header { name: "Location".to_string(), value: location });
1012        Ok(r)
1013    }
1014}
1015
1016// ── RespondAdapter ─────────────────────────────────────────────────────────────
1017
1018/// Action adapter that returns a fixed response body.
1019#[derive(Clone)]
1020pub(crate) struct RespondAdapter {
1021    status: i16,
1022    reason: Arc<String>,
1023    body: Arc<Vec<u8>>,
1024    content_type: Arc<String>,
1025}
1026
1027impl RespondAdapter {
1028    pub(crate) fn new(status: u16, body: String, content_type: String) -> Self {
1029        use crate::response::STATUS_CODE_REASON_PHRASE;
1030        let reason = match status {
1031            200 => STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string(),
1032            201 => STATUS_CODE_REASON_PHRASE.n201_created.reason_phrase.to_string(),
1033            204 => STATUS_CODE_REASON_PHRASE.n204_no_content.reason_phrase.to_string(),
1034            400 => STATUS_CODE_REASON_PHRASE.n400_bad_request.reason_phrase.to_string(),
1035            401 => STATUS_CODE_REASON_PHRASE.n401_unauthorized.reason_phrase.to_string(),
1036            403 => STATUS_CODE_REASON_PHRASE.n403_forbidden.reason_phrase.to_string(),
1037            404 => STATUS_CODE_REASON_PHRASE.n404_not_found.reason_phrase.to_string(),
1038            500 => STATUS_CODE_REASON_PHRASE.n500_internal_server_error.reason_phrase.to_string(),
1039            _ => "OK".to_string(),
1040        };
1041        RespondAdapter {
1042            status: status as i16,
1043            reason: Arc::new(reason),
1044            body: Arc::new(body.into_bytes()),
1045            content_type: Arc::new(content_type),
1046        }
1047    }
1048}
1049
1050impl Application for RespondAdapter {
1051    fn execute(&self, _request: &Request, _conn: &ConnectionInfo) -> Result<Response, String> {
1052        use crate::range::Range;
1053        let mut r = Response::new();
1054        r.status_code = self.status;
1055        r.reason_phrase = self.reason.as_ref().clone();
1056        if !self.body.is_empty() {
1057            r.content_range_list = vec![Range::get_content_range(
1058                self.body.as_ref().clone(),
1059                self.content_type.as_ref().clone(),
1060            )];
1061        }
1062        Ok(r)
1063    }
1064}
1065
1066// ── StaticAdapter ──────────────────────────────────────────────────────────────
1067
1068/// Action adapter that serves static files from a configured `root` directory.
1069///
1070/// Unlike `StaticResourceController` (which always resolves paths relative to
1071/// the process's current working directory), this adapter is parameterized
1072/// per-route by `ActionConfig::Static { root, index }` from `rws.config.toml`,
1073/// so a config-driven proxy can serve an arbitrary directory without Rust code.
1074#[derive(Clone)]
1075pub(crate) struct StaticAdapter {
1076    root: Arc<std::path::PathBuf>,
1077    index: Arc<Vec<String>>,
1078}
1079
1080impl StaticAdapter {
1081    pub(crate) fn new(root: String, index: Vec<String>) -> Self {
1082        let index = if index.is_empty() { vec!["index.html".to_string()] } else { index };
1083        StaticAdapter {
1084            root: Arc::new(std::path::PathBuf::from(root)),
1085            index: Arc::new(index),
1086        }
1087    }
1088
1089    /// Resolves `request_uri` against `root`. Returns `None` if the decoded
1090    /// path contains a `..` segment, which would otherwise let a request
1091    /// escape the configured root directory.
1092    fn resolve(&self, request_uri: &str) -> Option<std::path::PathBuf> {
1093        let raw_path = request_uri.split('?').next().unwrap_or(request_uri);
1094        let decoded = crate::url::URL::percent_decode(raw_path);
1095
1096        if decoded.split('/').any(|segment| segment == "..") {
1097            return None;
1098        }
1099
1100        let relative = decoded.trim_start_matches('/');
1101        Some(self.root.join(relative))
1102    }
1103}
1104
1105impl Application for StaticAdapter {
1106    fn execute(&self, request: &Request, _conn: &ConnectionInfo) -> Result<Response, String> {
1107        let mut response = Response::new();
1108
1109        let not_found = |mut response: Response| {
1110            response.status_code = *STATUS_CODE_REASON_PHRASE.n404_not_found.status_code;
1111            response.reason_phrase = STATUS_CODE_REASON_PHRASE.n404_not_found.reason_phrase.to_string();
1112            response
1113        };
1114
1115        let candidate = match self.resolve(&request.request_uri) {
1116            Some(p) => p,
1117            None => {
1118                response.status_code = *STATUS_CODE_REASON_PHRASE.n403_forbidden.status_code;
1119                response.reason_phrase = STATUS_CODE_REASON_PHRASE.n403_forbidden.reason_phrase.to_string();
1120                return Ok(response);
1121            }
1122        };
1123
1124        let mut file_path = candidate;
1125        if file_path.is_dir() {
1126            let indexed = self
1127                .index
1128                .iter()
1129                .map(|name| file_path.join(name))
1130                .find(|p| p.is_file());
1131
1132            file_path = match indexed {
1133                Some(p) => p,
1134                None => return Ok(not_found(response)),
1135            };
1136        }
1137
1138        if !file_path.is_file() {
1139            return Ok(not_found(response));
1140        }
1141
1142        // Defense-in-depth against symlinks inside `root` that point outside it —
1143        // the `..`-segment check above only catches traversal in the request URI.
1144        if let (Ok(root_canon), Ok(file_canon)) =
1145            (self.root.canonicalize(), file_path.canonicalize())
1146        {
1147            if !file_canon.starts_with(&root_canon) {
1148                return Ok(not_found(response));
1149            }
1150        }
1151
1152        let path_str = match file_path.to_str() {
1153            Some(s) => s,
1154            None => return Ok(not_found(response)),
1155        };
1156
1157        match crate::range::Range::get_content_range_of_a_file(path_str) {
1158            Ok(content_range) => {
1159                response.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
1160                response.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
1161                response.content_range_list = vec![content_range];
1162                Ok(response)
1163            }
1164            Err(_) => Ok(not_found(response)),
1165        }
1166    }
1167}
1168
1169// ── PerRouteRateLimit middleware ───────────────────────────────────────────────
1170
1171/// A per-route rate limiter middleware backed by a shared `RateLimiter`.
1172pub(crate) struct PerRouteRateLimit(pub(crate) Arc<crate::rate_limit::RateLimiter>);
1173
1174impl crate::middleware::Middleware for PerRouteRateLimit {
1175    fn handle(
1176        &self,
1177        request: &Request,
1178        conn: &ConnectionInfo,
1179        next: &dyn Application,
1180    ) -> Result<Response, String> {
1181        use crate::error::{AppError, IntoResponse};
1182        if self.0.check(&conn.client.ip) {
1183            next.execute(request, conn)
1184        } else {
1185            Ok(AppError::TooManyRequests.into_response())
1186        }
1187    }
1188}
1189
1190// ── BearerAuthMiddleware ───────────────────────────────────────────────────────
1191
1192/// Bearer token authentication middleware.
1193pub(crate) struct BearerAuthMiddleware {
1194    pub(crate) token: Arc<String>,
1195}
1196
1197impl crate::middleware::Middleware for BearerAuthMiddleware {
1198    fn handle(
1199        &self,
1200        request: &Request,
1201        conn: &ConnectionInfo,
1202        next: &dyn Application,
1203    ) -> Result<Response, String> {
1204        use crate::error::{AppError, IntoResponse};
1205        let expected = format!("Bearer {}", self.token);
1206        let authorized = request
1207            .headers
1208            .iter()
1209            .any(|h| h.name.eq_ignore_ascii_case("authorization") && h.value == expected);
1210        if authorized {
1211            next.execute(request, conn)
1212        } else {
1213            Ok(AppError::Unauthorized.into_response())
1214        }
1215    }
1216}
1217
1218// ── arc_app helper ─────────────────────────────────────────────────────────────
1219
1220/// Box any `Application + Send + Sync + 'static` into an `Arc<dyn …>`.
1221pub(crate) fn arc_app<A: Application + Send + Sync + 'static>(
1222    a: A,
1223) -> Arc<dyn Application + Send + Sync> {
1224    Arc::new(a)
1225}
1226
1227// ── Public entry points ────────────────────────────────────────────────────────
1228
1229/// Build a `ConfigDrivenApp` from `rws.config.toml` and spawn L4/WS proxy
1230/// threads. Returns the app and a list of thread handles.
1231pub fn build_from_file() -> (ConfigDrivenApp, Vec<std::thread::JoinHandle<()>>) {
1232    builder::build_from_file()
1233}