Skip to main content

rust_web_server/proxy_config/
mod.rs

1//! Config-driven proxy application.
2//!
3//! When `rws.config.toml` contains `[[route]]` or `[[upstream]]` sections,
4//! `ConfigDrivenApp` is used as the top-level `Application` instead of the
5//! hardcoded `build_app()` in `main.rs`.
6//!
7//! # Quick start
8//!
9//! ```toml
10//! # rws.config.toml
11//! [[upstream]]
12//! name = "api"
13//! backends = ["localhost:3000"]
14//!
15//! [[route]]
16//! name = "api-proxy"
17//!
18//! [route.match]
19//! path = "/api/*"
20//!
21//! [route.action]
22//! type = "proxy"
23//!
24//! [route.action.proxy]
25//! upstream = "api"
26//! ```
27
28pub mod parser;
29pub mod health;
30pub mod builder;
31
32#[cfg(test)]
33mod tests;
34
35use std::sync::Arc;
36
37use crate::app::App;
38use crate::application::Application;
39use crate::core::New;
40use crate::request::Request;
41use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
42use crate::server::ConnectionInfo;
43use crate::server_config::ServerConfig;
44
45// ── Public config types ────────────────────────────────────────────────────────
46
47#[derive(Debug, Clone)]
48pub struct ProxyConfig {
49    pub upstreams: Vec<UpstreamConfig>,
50    pub routes: Vec<RouteConfig>,
51    pub tcp_proxies: Vec<TcpProxyConfig>,
52    pub udp_proxies: Vec<UdpProxyConfig>,
53    pub ws_proxies: Vec<WsProxyConfig>,
54    pub global_middleware: MiddlewareConfig,
55}
56
57#[derive(Debug, Clone)]
58pub struct UpstreamConfig {
59    pub name: String,
60    pub backends: Vec<String>,
61    pub strategy: String, // "round_robin" | "random" | "ip_hash"
62    pub health_check: Option<HealthCheckConfig>,
63    /// `true` when all backends use `https://` scheme — connections to the
64    /// upstream are made over TLS. Requires the `http-client` or `http2`
65    /// feature (which bring in `rustls` + `webpki-roots`).
66    pub tls: bool,
67}
68
/// Health-check settings read from an `upstream[N].health_check` section.
/// The defaults quoted below are the ones `ProxyConfig::from_str` applies.
#[derive(Debug, Clone)]
pub struct HealthCheckConfig {
    /// `path` key; `"/health"` when empty.
    pub path: String,
    /// `interval_secs` key; default 30.
    pub interval_secs: u64,
    /// `timeout_ms` key; default 5000.
    pub timeout_ms: u64,
    /// `healthy_threshold` key; default 2.
    pub healthy_threshold: u32,
    /// `unhealthy_threshold` key; default 3.
    pub unhealthy_threshold: u32,
}
77
78#[derive(Debug, Clone)]
79pub struct RouteConfig {
80    pub name: String,
81    pub match_: MatchConfig,
82    pub action: ActionConfig,
83    pub middleware: MiddlewareConfig,
84}
85
/// Raw match criteria from a `route[N].match` section. A field is `None`
/// when the parser found the corresponding key empty.
#[derive(Debug, Clone, Default)]
pub struct MatchConfig {
    /// `host` key.
    pub host: Option<String>,
    /// `path` key; `RouteMatcher::from_match_config` treats a trailing `*`
    /// as a prefix match.
    pub path: Option<String>,
    /// `method` key, upper-cased by the parser.
    pub method: Option<String>,
    /// `content_type` key.
    pub content_type: Option<String>,
}
93
/// What a route does once it matches, chosen by `route[N].action.type`.
/// Defaults quoted below are those applied by `ProxyConfig::from_str`.
#[derive(Debug, Clone)]
pub enum ActionConfig {
    /// `type = "proxy"`; settings come from `route[N].action.proxy`.
    Proxy {
        /// `upstream` key.
        upstream: String,
        /// `connect_timeout_ms` key; default 5000.
        connect_timeout_ms: u64,
        /// `read_timeout_ms` key; default 30000.
        read_timeout_ms: u64,
        /// `strip_path_prefix` key, `None` when empty.
        strip_path_prefix: Option<String>,
        /// `add_path_prefix` key, `None` when empty.
        add_path_prefix: Option<String>,
    },
    /// `type = "grpc"`; settings come from `route[N].action.grpc`.
    Grpc {
        /// `upstream` key.
        upstream: String,
        /// `connect_timeout_ms` key; default 5000.
        connect_timeout_ms: u64,
        /// `read_timeout_ms` key; default 30000.
        read_timeout_ms: u64,
    },
    /// `type = "static"`; settings come from `route[N].action.static`.
    Static {
        /// `root` key.
        root: String,
        /// Entries of the `index` array.
        index: Vec<String>,
    },
    /// `type = "redirect"`; settings come from `route[N].action.redirect`.
    Redirect {
        /// `location` key.
        location: String,
        /// `status` key; default 301.
        status: u16,
    },
    /// `type = "respond"`; settings come from `route[N].action.respond`.
    Respond {
        /// `status` key; default 200.
        status: u16,
        /// `body` key.
        body: String,
        /// `content_type` key; `"text/plain"` when empty.
        content_type: String,
    },
    /// `type = "mcp"`; carries no settings.
    Mcp,
    /// Any other `type` value, kept verbatim.
    Unknown(String),
}
124
125#[derive(Debug, Clone, Default)]
126pub struct MiddlewareConfig {
127    pub rate_limit: Option<RateLimitConfig>,
128    pub cache: Option<CacheConfig>,
129    pub auth: Option<AuthConfig>,
130    pub rewrite_request: Vec<RewriteRuleConfig>,
131    pub rewrite_response: Vec<RewriteRuleConfig>,
132    pub ip_allow: Vec<String>,
133    pub ip_deny: Vec<String>,
134}
135
/// Settings from a `…middleware.rate_limit` section.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// `max_requests` key; the parser defaults it to 1000.
    pub max_requests: u32,
    /// `window_secs` key; the parser defaults it to 60.
    pub window_secs: u64,
}
141
/// Settings from a `…middleware.cache` section.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// `ttl_secs` key; the parser defaults it to 60.
    pub ttl_secs: u64,
    /// Entries of the `vary_by` array.
    pub vary_by: Vec<String>,
}
147
/// Authentication settings from a `…middleware.auth` section, selected by
/// its `type` key.
#[derive(Debug, Clone)]
pub enum AuthConfig {
    /// `type = "basic"`; carries the `users_file` key.
    Basic { users_file: String },
    /// `type = "jwt"`; carries the `secret_env` key.
    Jwt { secret_env: String },
    /// `type = "bearer"`; carries the `token_env` key.
    Bearer { token_env: String },
}
154
/// One rewrite rule from a `…rewrite.request[N]` or `…rewrite.response[N]`
/// section. Which optional fields are meaningful depends on `type_`; the
/// parser simply copies every key it finds.
#[derive(Debug, Clone, Default)]
pub struct RewriteRuleConfig {
    /// `type` key (trailing underscore avoids the keyword).
    pub type_: String,
    /// `name` key, `None` when empty.
    pub name: Option<String>,
    /// `value` key, `None` when empty.
    pub value: Option<String>,
    /// `prefix` key, `None` when empty.
    pub prefix: Option<String>,
    /// `from` key, `None` when empty.
    pub from: Option<String>,
    /// `to` key, `None` when empty.
    pub to: Option<String>,
    /// `code` key, `None` when absent or zero.
    pub code: Option<u16>,
    /// `reason` key, `None` when empty.
    pub reason: Option<String>,
}
166
/// One `tcp_proxy[N]` section.
#[derive(Debug, Clone)]
pub struct TcpProxyConfig {
    /// `name` key.
    pub name: String,
    /// `listen` key (presumably the local bind address — confirm with the builder).
    pub listen: String,
    /// Entries of the `backends` array.
    pub backends: Vec<String>,
    /// `connect_timeout_ms` key; the parser defaults it to 5000.
    pub connect_timeout_ms: u64,
}
174
/// One `udp_proxy[N]` section.
#[derive(Debug, Clone)]
pub struct UdpProxyConfig {
    /// `name` key.
    pub name: String,
    /// `listen` key (presumably the local bind address — confirm with the builder).
    pub listen: String,
    /// Entries of the `backends` array.
    pub backends: Vec<String>,
    /// `reply_timeout_ms` key; the parser defaults it to 5000.
    pub reply_timeout_ms: u64,
    /// `buffer_size` key; the parser defaults it to 65536.
    pub buffer_size: usize,
}
183
/// One `ws_proxy[N]` section.
#[derive(Debug, Clone)]
pub struct WsProxyConfig {
    /// `name` key.
    pub name: String,
    /// `listen` key (presumably the local bind address — confirm with the builder).
    pub listen: String,
    /// Entries of the `backends` array.
    pub backends: Vec<String>,
    /// `connect_timeout_ms` key; the parser defaults it to 5000.
    pub connect_timeout_ms: u64,
    /// `read_timeout_ms` key; the parser defaults it to 30000.
    pub read_timeout_ms: u64,
}
192
193// ── ProxyConfig loading ────────────────────────────────────────────────────────
194
195impl ProxyConfig {
196    /// Returns `true` if `rws.config.toml` (or `RWS_CONFIG_FILE`) contains
197    /// `[[route]]` or `[[upstream]]` sections, meaning config-driven mode
198    /// should be used.
199    pub fn is_proxy_mode() -> bool {
200        let path = config_file_path();
201        match std::fs::read_to_string(&path) {
202            Ok(contents) => {
203                contents.contains("[[route]]") || contents.contains("[[upstream]]")
204            }
205            Err(_) => false,
206        }
207    }
208
209    /// Parse the config file and return a `ProxyConfig`.
210    pub fn load() -> Self {
211        let path = config_file_path();
212        let contents = std::fs::read_to_string(&path).unwrap_or_default();
213        Self::from_str(&contents)
214    }
215
216    /// Parse `toml` text directly into a `ProxyConfig`. Used in tests.
217    pub fn from_str(toml: &str) -> Self {
218        use parser::{get_array, get_str, get_u32, get_u64, section_exists};
219
220        let map = parser::parse(toml);
221
222        // ── upstreams ──────────────────────────────────────────────────────────
223        let mut upstreams = Vec::new();
224        let mut i = 0;
225        loop {
226            let sec = format!("upstream[{}]", i);
227            if !section_exists(&map, &sec) {
228                break;
229            }
230            let name = get_str(&map, &sec, "name");
231            let backends = get_array(&map, &sec, "backends");
232            let strategy = {
233                let s = get_str(&map, &sec, "strategy");
234                if s.is_empty() { "round_robin".to_string() } else { s }
235            };
236            let hc_sec = format!("{}.health_check", sec);
237            let health_check = if section_exists(&map, &hc_sec) {
238                Some(HealthCheckConfig {
239                    path: {
240                        let p = get_str(&map, &hc_sec, "path");
241                        if p.is_empty() { "/health".to_string() } else { p }
242                    },
243                    interval_secs: get_u64(&map, &hc_sec, "interval_secs", 30),
244                    timeout_ms: get_u64(&map, &hc_sec, "timeout_ms", 5000),
245                    healthy_threshold: get_u32(&map, &hc_sec, "healthy_threshold", 2),
246                    unhealthy_threshold: get_u32(&map, &hc_sec, "unhealthy_threshold", 3),
247                })
248            } else {
249                None
250            };
251            let tls = backends.iter().any(|b| b.starts_with("https://"));
252            upstreams.push(UpstreamConfig { name, backends, strategy, health_check, tls });
253            i += 1;
254        }
255
256        // ── routes ─────────────────────────────────────────────────────────────
257        let mut routes = Vec::new();
258        let mut i = 0;
259        loop {
260            let sec = format!("route[{}]", i);
261            if !section_exists(&map, &sec) {
262                break;
263            }
264            let name = get_str(&map, &sec, "name");
265
266            // match
267            let m_sec = format!("{}.match", sec);
268            let match_ = MatchConfig {
269                host: {
270                    let h = get_str(&map, &m_sec, "host");
271                    if h.is_empty() { None } else { Some(h) }
272                },
273                path: {
274                    let p = get_str(&map, &m_sec, "path");
275                    if p.is_empty() { None } else { Some(p) }
276                },
277                method: {
278                    let m = get_str(&map, &m_sec, "method");
279                    if m.is_empty() { None } else { Some(m.to_uppercase()) }
280                },
281                content_type: {
282                    let c = get_str(&map, &m_sec, "content_type");
283                    if c.is_empty() { None } else { Some(c) }
284                },
285            };
286
287            // action
288            let a_sec = format!("{}.action", sec);
289            let action_type = get_str(&map, &a_sec, "type");
290            let action = match action_type.as_str() {
291                "proxy" => {
292                    let p_sec = format!("{}.action.proxy", sec);
293                    ActionConfig::Proxy {
294                        upstream: get_str(&map, &p_sec, "upstream"),
295                        connect_timeout_ms: get_u64(&map, &p_sec, "connect_timeout_ms", 5000),
296                        read_timeout_ms: get_u64(&map, &p_sec, "read_timeout_ms", 30000),
297                        strip_path_prefix: {
298                            let v = get_str(&map, &p_sec, "strip_path_prefix");
299                            if v.is_empty() { None } else { Some(v) }
300                        },
301                        add_path_prefix: {
302                            let v = get_str(&map, &p_sec, "add_path_prefix");
303                            if v.is_empty() { None } else { Some(v) }
304                        },
305                    }
306                }
307                "grpc" => {
308                    let p_sec = format!("{}.action.grpc", sec);
309                    ActionConfig::Grpc {
310                        upstream: get_str(&map, &p_sec, "upstream"),
311                        connect_timeout_ms: get_u64(&map, &p_sec, "connect_timeout_ms", 5000),
312                        read_timeout_ms: get_u64(&map, &p_sec, "read_timeout_ms", 30000),
313                    }
314                }
315                "static" => {
316                    let s_sec = format!("{}.action.static", sec);
317                    ActionConfig::Static {
318                        root: get_str(&map, &s_sec, "root"),
319                        index: get_array(&map, &s_sec, "index"),
320                    }
321                }
322                "redirect" => {
323                    let r_sec = format!("{}.action.redirect", sec);
324                    ActionConfig::Redirect {
325                        location: get_str(&map, &r_sec, "location"),
326                        status: get_u64(&map, &r_sec, "status", 301) as u16,
327                    }
328                }
329                "respond" => {
330                    let r_sec = format!("{}.action.respond", sec);
331                    ActionConfig::Respond {
332                        status: get_u64(&map, &r_sec, "status", 200) as u16,
333                        body: get_str(&map, &r_sec, "body"),
334                        content_type: {
335                            let ct = get_str(&map, &r_sec, "content_type");
336                            if ct.is_empty() { "text/plain".to_string() } else { ct }
337                        },
338                    }
339                }
340                "mcp" => ActionConfig::Mcp,
341                other => ActionConfig::Unknown(other.to_string()),
342            };
343
344            // middleware
345            let mw_sec = format!("{}.middleware", sec);
346            let middleware = parse_middleware_config(&map, &mw_sec, i);
347
348            routes.push(RouteConfig { name, match_, action, middleware });
349            i += 1;
350        }
351
352        // ── tcp_proxy ──────────────────────────────────────────────────────────
353        let mut tcp_proxies = Vec::new();
354        let mut i = 0;
355        loop {
356            let sec = format!("tcp_proxy[{}]", i);
357            if !section_exists(&map, &sec) {
358                break;
359            }
360            tcp_proxies.push(TcpProxyConfig {
361                name: get_str(&map, &sec, "name"),
362                listen: get_str(&map, &sec, "listen"),
363                backends: get_array(&map, &sec, "backends"),
364                connect_timeout_ms: get_u64(&map, &sec, "connect_timeout_ms", 5000),
365            });
366            i += 1;
367        }
368
369        // ── udp_proxy ──────────────────────────────────────────────────────────
370        let mut udp_proxies = Vec::new();
371        let mut i = 0;
372        loop {
373            let sec = format!("udp_proxy[{}]", i);
374            if !section_exists(&map, &sec) {
375                break;
376            }
377            udp_proxies.push(UdpProxyConfig {
378                name: get_str(&map, &sec, "name"),
379                listen: get_str(&map, &sec, "listen"),
380                backends: get_array(&map, &sec, "backends"),
381                reply_timeout_ms: get_u64(&map, &sec, "reply_timeout_ms", 5000),
382                buffer_size: get_u64(&map, &sec, "buffer_size", 65536) as usize,
383            });
384            i += 1;
385        }
386
387        // ── ws_proxy ───────────────────────────────────────────────────────────
388        let mut ws_proxies = Vec::new();
389        let mut i = 0;
390        loop {
391            let sec = format!("ws_proxy[{}]", i);
392            if !section_exists(&map, &sec) {
393                break;
394            }
395            ws_proxies.push(WsProxyConfig {
396                name: get_str(&map, &sec, "name"),
397                listen: get_str(&map, &sec, "listen"),
398                backends: get_array(&map, &sec, "backends"),
399                connect_timeout_ms: get_u64(&map, &sec, "connect_timeout_ms", 5000),
400                read_timeout_ms: get_u64(&map, &sec, "read_timeout_ms", 30000),
401            });
402            i += 1;
403        }
404
405        // ── global middleware ──────────────────────────────────────────────────
406        let global_middleware = parse_middleware_config(&map, "middleware", usize::MAX);
407
408        ProxyConfig {
409            upstreams,
410            routes,
411            tcp_proxies,
412            udp_proxies,
413            ws_proxies,
414            global_middleware,
415        }
416    }
417}
418
419/// Parse a `MiddlewareConfig` from the section map at a given base path.
420/// `route_idx` is used only to build inner-array section paths for rewrite rules.
421fn parse_middleware_config(
422    map: &parser::SectionMap,
423    mw_sec: &str,
424    route_idx: usize,
425) -> MiddlewareConfig {
426    use parser::{get_array, get_str, get_u32, get_u64, section_exists};
427
428    let rl_sec = format!("{}.rate_limit", mw_sec);
429    let rate_limit = if section_exists(map, &rl_sec) {
430        Some(RateLimitConfig {
431            max_requests: get_u32(map, &rl_sec, "max_requests", 1000),
432            window_secs: get_u64(map, &rl_sec, "window_secs", 60),
433        })
434    } else {
435        None
436    };
437
438    let c_sec = format!("{}.cache", mw_sec);
439    let cache = if section_exists(map, &c_sec) {
440        Some(CacheConfig {
441            ttl_secs: get_u64(map, &c_sec, "ttl_secs", 60),
442            vary_by: get_array(map, &c_sec, "vary_by"),
443        })
444    } else {
445        None
446    };
447
448    let a_sec = format!("{}.auth", mw_sec);
449    let auth = if section_exists(map, &a_sec) {
450        let auth_type = get_str(map, &a_sec, "type");
451        match auth_type.as_str() {
452            "bearer" => Some(AuthConfig::Bearer {
453                token_env: get_str(map, &a_sec, "token_env"),
454            }),
455            "jwt" => Some(AuthConfig::Jwt {
456                secret_env: get_str(map, &a_sec, "secret_env"),
457            }),
458            "basic" => Some(AuthConfig::Basic {
459                users_file: get_str(map, &a_sec, "users_file"),
460            }),
461            _ => None,
462        }
463    } else {
464        None
465    };
466
467    // Rewrite rules — the section paths use route_idx for route-scoped rules
468    // or a flat path for global middleware. We look for:
469    //   route[N].middleware.rewrite.request[0], [1], …
470    //   route[N].middleware.rewrite.response[0], [1], …
471    // For global: middleware.rewrite.request[0], etc.
472    let rewrite_request = collect_rewrite_rules(map, mw_sec, "request");
473    let rewrite_response = collect_rewrite_rules(map, mw_sec, "response");
474
475    let ip_sec = format!("{}.ip_filter", mw_sec);
476    let ip_allow = if section_exists(map, &ip_sec) {
477        get_array(map, &ip_sec, "allow")
478    } else {
479        vec![]
480    };
481    let ip_deny = if section_exists(map, &ip_sec) {
482        get_array(map, &ip_sec, "deny")
483    } else {
484        vec![]
485    };
486
487    let _ = route_idx; // used implicitly via mw_sec paths
488
489    MiddlewareConfig { rate_limit, cache, auth, rewrite_request, rewrite_response, ip_allow, ip_deny }
490}
491
492/// Collect `[[{mw_sec}.rewrite.{direction}]]` entries.
493fn collect_rewrite_rules(
494    map: &parser::SectionMap,
495    mw_sec: &str,
496    direction: &str,
497) -> Vec<RewriteRuleConfig> {
498    use parser::{get_str, get_u64};
499
500    let mut rules = Vec::new();
501    let mut j = 0;
502    loop {
503        let rsec = format!("{}.rewrite.{}[{}]", mw_sec, direction, j);
504        if !parser::section_exists(map, &rsec) {
505            break;
506        }
507        let code_val = get_u64(map, &rsec, "code", 0);
508        rules.push(RewriteRuleConfig {
509            type_: get_str(map, &rsec, "type"),
510            name: {
511                let v = get_str(map, &rsec, "name");
512                if v.is_empty() { None } else { Some(v) }
513            },
514            value: {
515                let v = get_str(map, &rsec, "value");
516                if v.is_empty() { None } else { Some(v) }
517            },
518            prefix: {
519                let v = get_str(map, &rsec, "prefix");
520                if v.is_empty() { None } else { Some(v) }
521            },
522            from: {
523                let v = get_str(map, &rsec, "from");
524                if v.is_empty() { None } else { Some(v) }
525            },
526            to: {
527                let v = get_str(map, &rsec, "to");
528                if v.is_empty() { None } else { Some(v) }
529            },
530            code: if code_val == 0 { None } else { Some(code_val as u16) },
531            reason: {
532                let v = get_str(map, &rsec, "reason");
533                if v.is_empty() { None } else { Some(v) }
534            },
535        });
536        j += 1;
537    }
538    rules
539}
540
/// Path of the configuration file: the value of `RWS_CONFIG_FILE` when that
/// variable can be read, otherwise `rws.config.toml`.
fn config_file_path() -> String {
    match std::env::var("RWS_CONFIG_FILE") {
        Ok(path) => path,
        Err(_) => String::from("rws.config.toml"),
    }
}
544
545// ── ConfigDrivenApp ────────────────────────────────────────────────────────────
546
547/// A compiled route: a matcher paired with a handler application.
548pub(crate) struct CompiledRoute {
549    pub(crate) matcher: RouteMatcher,
550    /// Shared, type-erased handler. `Arc` makes `Clone` cheap (pointer copy).
551    pub(crate) handler: Arc<dyn Application + Send + Sync>,
552}
553
/// Request-time matching criteria for one route. Every field is optional;
/// a `None` field places no constraint on the request.
#[derive(Clone, Default)]
pub(crate) struct RouteMatcher {
    /// Host to match against the SNI hostname or, failing that, the `Host`
    /// header.
    pub(crate) host: Option<String>,
    /// Path prefix (from a configured path such as `"/v1/*"`).
    pub(crate) path_prefix: Option<String>,
    /// Exact path (from a configured path such as `"/v1/ping"`).
    pub(crate) path_exact: Option<String>,
    /// Upper-case HTTP method; `None` accepts any method.
    pub(crate) method: Option<String>,
    /// Prefix the request's `Content-Type` must start with
    /// (e.g. `"application/grpc"`).
    pub(crate) content_type_prefix: Option<String>,
}
568
569impl RouteMatcher {
570    pub(crate) fn from_match_config(cfg: &MatchConfig) -> Self {
571        let (path_prefix, path_exact) = match &cfg.path {
572            Some(p) if p.ends_with('*') => {
573                // "/v1/*" → prefix "/v1/"
574                let stripped = p.trim_end_matches('*').to_string();
575                (Some(stripped), None)
576            }
577            Some(p) => (None, Some(p.clone())),
578            None => (None, None),
579        };
580        let content_type_prefix = cfg.content_type.as_ref().map(|ct| {
581            if ct.ends_with('*') {
582                ct.trim_end_matches('*').to_string()
583            } else {
584                ct.clone()
585            }
586        });
587        RouteMatcher {
588            host: cfg.host.clone(),
589            path_prefix,
590            path_exact,
591            method: cfg.method.clone(),
592            content_type_prefix,
593        }
594    }
595
596    /// Returns `true` if `request` and `conn` match all configured criteria.
597    pub(crate) fn matches(&self, request: &Request, conn: &ConnectionInfo) -> bool {
598        // Host matching: SNI first, then Host header
599        if let Some(ref expected_host) = self.host {
600            let actual_host = conn
601                .sni_hostname
602                .as_deref()
603                .or_else(|| {
604                    request
605                        .headers
606                        .iter()
607                        .find(|h| h.name.eq_ignore_ascii_case("host"))
608                        .map(|h| h.value.as_str())
609                })
610                .unwrap_or("");
611            if actual_host != expected_host.as_str() {
612                return false;
613            }
614        }
615
616        // Method matching
617        if let Some(ref m) = self.method {
618            if request.method.to_uppercase() != m.as_str() {
619                return false;
620            }
621        }
622
623        // Path matching: strip query string for comparison
624        let path = request.request_uri.split('?').next().unwrap_or(&request.request_uri);
625        if let Some(ref prefix) = self.path_prefix {
626            if !path.starts_with(prefix.as_str()) {
627                return false;
628            }
629        } else if let Some(ref exact) = self.path_exact {
630            if path != exact.as_str() {
631                return false;
632            }
633        }
634
635        // Content-Type prefix matching
636        if let Some(ref ct_prefix) = self.content_type_prefix {
637            let actual_ct = request
638                .headers
639                .iter()
640                .find(|h| h.name.eq_ignore_ascii_case("content-type"))
641                .map(|h| h.value.as_str())
642                .unwrap_or("");
643            if !actual_ct.starts_with(ct_prefix.as_str()) {
644                return false;
645            }
646        }
647
648        true
649    }
650}
651
652/// An `Application` that routes requests based on a parsed `ProxyConfig`.
653///
654/// `Clone` is cheap: `routes` is an `Arc<Vec<...>>` (pointer copy), and
655/// `fallback` is `App`, itself cheap to clone (an `Option<Arc<ServerConfig>>`).
656#[derive(Clone)]
657pub struct ConfigDrivenApp {
658    routes: Arc<Vec<CompiledRoute>>,
659    /// Fallback for unmatched requests — handles /healthz, /readyz, /metrics,
660    /// static files, and the 404 controller. Reads `RWS_CONFIG_*` env vars
661    /// per request (`App::new()`'s default) unless pinned via
662    /// [`ConfigDrivenApp::with_config`].
663    fallback: App,
664}
665
666impl ConfigDrivenApp {
667    pub(crate) fn new(routes: Vec<CompiledRoute>) -> Self {
668        use crate::core::New;
669        ConfigDrivenApp {
670            routes: Arc::new(routes),
671            fallback: App::new(),
672        }
673    }
674
675    /// Pin the fallback [`App`] (used for any request none of the
676    /// config-driven routes match) to an explicit [`ServerConfig`], instead
677    /// of reading `RWS_CONFIG_*` environment variables per request.
678    ///
679    /// Mirrors [`App::with_config`] / [`crate::state::AppWithState::with_config`]
680    /// — same rationale: safe for parallel tests, and lets multiple
681    /// differently-configured proxy instances coexist in one process.
682    ///
683    /// ```rust,no_run
684    /// use rust_web_server::proxy_config::build_from_file;
685    /// use rust_web_server::server_config::ServerConfig;
686    ///
687    /// let (app, _handles) = build_from_file();
688    /// let app = app.with_config(ServerConfig::default());
689    /// ```
690    pub fn with_config(mut self, config: ServerConfig) -> Self {
691        self.fallback = App::with_config(config);
692        self
693    }
694}
695
696impl Application for ConfigDrivenApp {
697    fn execute(&self, request: &Request, conn: &ConnectionInfo) -> Result<Response, String> {
698        for route in self.routes.iter() {
699            if route.matcher.matches(request, conn) {
700                return route.handler.execute(request, conn);
701            }
702        }
703        self.fallback.execute(request, conn)
704    }
705}
706
707// ── NullApp ────────────────────────────────────────────────────────────────────
708
709/// A dead-end `Application` that always returns 404.
710/// Used as the `next` parameter when calling `Middleware::handle` directly.
711#[derive(Clone, Copy)]
712pub(crate) struct NullApp;
713
714impl Application for NullApp {
715    fn execute(&self, _request: &Request, _conn: &ConnectionInfo) -> Result<Response, String> {
716        let mut r = Response::new();
717        r.status_code = *STATUS_CODE_REASON_PHRASE.n404_not_found.status_code;
718        r.reason_phrase = STATUS_CODE_REASON_PHRASE.n404_not_found.reason_phrase.to_string();
719        Ok(r)
720    }
721}
722
723// ── DynamicProxy ──────────────────────────────────────────────────────────────
724
725use std::collections::HashMap;
726use std::collections::hash_map::DefaultHasher;
727use std::hash::{Hash, Hasher};
728use std::sync::atomic::{AtomicUsize, Ordering};
729use std::sync::RwLock;
730use std::time::{Duration, SystemTime, UNIX_EPOCH};
731
/// How `DynamicProxy` picks a backend for each request. Set through the
/// `strategy` field on `[[upstream]]` in `rws.config.toml`; an empty or
/// unrecognised value means `RoundRobin`, the same default the parser uses.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum LoadBalanceStrategy {
    /// Cycle through the live backends using a shared counter.
    RoundRobin,
    /// Pick an index derived from the clock's sub-second nanoseconds plus a counter.
    Random,
    /// Pick an index from a hash of the client IP.
    IpHash,
    /// Pick the live backend with the fewest tracked in-flight requests.
    LeastConnections,
}

impl LoadBalanceStrategy {
    /// Maps a configuration string to a strategy. Matching is exact
    /// (case-sensitive); anything not listed resolves to `RoundRobin`.
    fn parse(s: &str) -> Self {
        const NAMED: [(&str, LoadBalanceStrategy); 3] = [
            ("random", LoadBalanceStrategy::Random),
            ("ip_hash", LoadBalanceStrategy::IpHash),
            ("least_connections", LoadBalanceStrategy::LeastConnections),
        ];

        NAMED
            .iter()
            .find(|(name, _)| *name == s)
            .map(|(_, strategy)| *strategy)
            .unwrap_or(LoadBalanceStrategy::RoundRobin)
    }
}
753
754/// A proxy adapter that reads its backend list from a shared, health-checker-
755/// maintained live list at request time. Supports dynamic removal/restoration
756/// of backends without restarting.
757///
758/// This type is `Clone + Send + Sync` and implements `Application`.
759#[derive(Clone)]
760pub(crate) struct DynamicProxy {
761    live: Arc<RwLock<Vec<String>>>,
762    counter: Arc<AtomicUsize>,
763    connect_timeout: Duration,
764    read_timeout: Duration,
765    strip_prefix: Option<Arc<String>>,
766    add_prefix: Option<Arc<String>>,
767    tls: bool,
768    strategy: LoadBalanceStrategy,
769    connections: Arc<RwLock<HashMap<String, Arc<AtomicUsize>>>>,
770}
771
772impl DynamicProxy {
773    pub(crate) fn new(
774        live: Arc<RwLock<Vec<String>>>,
775        connect_timeout_ms: u64,
776        read_timeout_ms: u64,
777        strip_prefix: Option<String>,
778        add_prefix: Option<String>,
779        tls: bool,
780        strategy: String,
781    ) -> Self {
782        DynamicProxy {
783            live,
784            counter: Arc::new(AtomicUsize::new(0)),
785            connect_timeout: Duration::from_millis(connect_timeout_ms),
786            read_timeout: Duration::from_millis(read_timeout_ms),
787            strip_prefix: strip_prefix.map(Arc::new),
788            add_prefix: add_prefix.map(Arc::new),
789            tls,
790            strategy: LoadBalanceStrategy::parse(&strategy),
791            connections: Arc::new(RwLock::new(HashMap::new())),
792        }
793    }
794
795    fn next_backend(&self, client_ip: &str) -> Option<String> {
796        let live = self.live.read().unwrap();
797        if live.is_empty() {
798            return None;
799        }
800
801        let idx = match self.strategy {
802            LoadBalanceStrategy::RoundRobin => {
803                self.counter.fetch_add(1, Ordering::Relaxed) % live.len()
804            }
805            LoadBalanceStrategy::Random => {
806                let nanos = SystemTime::now()
807                    .duration_since(UNIX_EPOCH)
808                    .map(|d| d.subsec_nanos())
809                    .unwrap_or(0) as usize;
810                let salt = self.counter.fetch_add(1, Ordering::Relaxed);
811                nanos.wrapping_add(salt) % live.len()
812            }
813            LoadBalanceStrategy::IpHash => {
814                let mut hasher = DefaultHasher::new();
815                client_ip.hash(&mut hasher);
816                (hasher.finish() as usize) % live.len()
817            }
818            LoadBalanceStrategy::LeastConnections => {
819                let connections = self.connections.read().unwrap();
820                live.iter()
821                    .enumerate()
822                    .min_by_key(|(_, backend)| {
823                        connections
824                            .get(*backend)
825                            .map(|c| c.load(Ordering::Relaxed))
826                            .unwrap_or(0)
827                    })
828                    .map(|(i, _)| i)
829                    .unwrap_or(0)
830            }
831        };
832
833        Some(live[idx].clone())
834    }
835
836    /// Returns the shared in-flight connection counter for `backend`,
837    /// creating it on first use. Only consulted under `LeastConnections`.
838    fn connection_counter(&self, backend: &str) -> Arc<AtomicUsize> {
839        if let Some(counter) = self.connections.read().unwrap().get(backend) {
840            return Arc::clone(counter);
841        }
842        let mut connections = self.connections.write().unwrap();
843        Arc::clone(
844            connections
845                .entry(backend.to_string())
846                .or_insert_with(|| Arc::new(AtomicUsize::new(0))),
847        )
848    }
849}
850
/// Scope guard for one in-flight request to a backend: when it is dropped —
/// on normal completion or on any early return — the backend's counter goes
/// down by one, which keeps `least_connections` accurate.
///
/// The guard itself never increments; the creator is expected to have done
/// so before constructing it.
struct ConnectionGuard {
    /// The backend's shared in-flight counter.
    counter: Arc<AtomicUsize>,
}

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        let in_flight: &AtomicUsize = &self.counter;
        in_flight.fetch_sub(1, Ordering::Relaxed);
    }
}
862
863impl Application for DynamicProxy {
864    fn execute(&self, request: &Request, conn: &ConnectionInfo) -> Result<Response, String> {
865        let backend = match self.next_backend(&conn.client.ip) {
866            Some(b) => b,
867            None => {
868                return Ok(bad_gateway());
869            }
870        };
871
872        let _connection_guard = if self.strategy == LoadBalanceStrategy::LeastConnections {
873            let counter = self.connection_counter(&backend);
874            counter.fetch_add(1, Ordering::Relaxed);
875            Some(ConnectionGuard { counter })
876        } else {
877            None
878        };
879
880        let (host, port, _) = match crate::proxy_config::health::parse_backend_url(&backend) {
881            Some(t) => t,
882            None => return Ok(bad_gateway()),
883        };
884
885        // Apply path rewriting if configured
886        let mut req_clone;
887        let effective_request = if self.strip_prefix.is_some() || self.add_prefix.is_some() {
888            req_clone = request.clone();
889            if let Some(ref sp) = self.strip_prefix {
890                if let Some(stripped) = req_clone.request_uri.strip_prefix(sp.as_str()) {
891                    req_clone.request_uri = if stripped.is_empty() || !stripped.starts_with('/') {
892                        format!("/{}", stripped)
893                    } else {
894                        stripped.to_string()
895                    };
896                }
897            }
898            if let Some(ref ap) = self.add_prefix {
899                req_clone.request_uri = format!("{}{}", ap, req_clone.request_uri);
900            }
901            &req_clone
902        } else {
903            request
904        };
905
906        let result = if self.tls {
907            #[cfg(any(feature = "http-client", feature = "http2"))]
908            {
909                crate::proxy::proxy_https1(
910                    effective_request,
911                    &conn.client.ip,
912                    &host,
913                    port,
914                    self.connect_timeout,
915                    self.read_timeout,
916                )
917            }
918            #[cfg(not(any(feature = "http-client", feature = "http2")))]
919            {
920                eprintln!("[proxy] HTTPS upstream requires http-client or http2 feature");
921                Err("TLS upstream not supported in this build".to_string())
922            }
923        } else {
924            crate::proxy::proxy_http1(
925                effective_request,
926                &conn.client.ip,
927                &host,
928                port,
929                self.connect_timeout,
930                self.read_timeout,
931            )
932        };
933
934        match result {
935            Ok(r) => Ok(r),
936            Err(_) => Ok(bad_gateway()),
937        }
938    }
939}
940
941fn bad_gateway() -> Response {
942    use crate::mime_type::MimeType;
943    use crate::range::Range;
944    let cr = Range::get_content_range(
945        b"502 Bad Gateway".to_vec(),
946        MimeType::TEXT_PLAIN.to_string(),
947    );
948    let mut r = Response::new();
949    r.status_code = *STATUS_CODE_REASON_PHRASE.n502_bad_gateway.status_code;
950    r.reason_phrase = STATUS_CODE_REASON_PHRASE.n502_bad_gateway.reason_phrase.to_string();
951    r.content_range_list = vec![cr];
952    r
953}
954
955// ── RedirectAdapter ────────────────────────────────────────────────────────────
956
957/// Action adapter that issues HTTP redirects.
958///
959/// `$path` in `location_template` is replaced with the request URI at runtime.
960#[derive(Clone)]
961pub(crate) struct RedirectAdapter {
962    location_template: Arc<String>,
963    status: i16,
964    reason: Arc<String>,
965}
966
967impl RedirectAdapter {
968    pub(crate) fn new(location: String, status: u16) -> Self {
969        let (code, reason) = redirect_status(status);
970        RedirectAdapter {
971            location_template: Arc::new(location),
972            status: code,
973            reason: Arc::new(reason),
974        }
975    }
976}
977
978fn redirect_status(code: u16) -> (i16, String) {
979    let phrase = match code {
980        301 => STATUS_CODE_REASON_PHRASE.n301_moved_permanently.reason_phrase,
981        302 => STATUS_CODE_REASON_PHRASE.n302_found.reason_phrase,
982        307 => STATUS_CODE_REASON_PHRASE.n307_temporary_redirect.reason_phrase,
983        308 => STATUS_CODE_REASON_PHRASE.n308_permanent_redirect.reason_phrase,
984        _ => "Redirect",
985    };
986    (code as i16, phrase.to_string())
987}
988
989impl Application for RedirectAdapter {
990    fn execute(&self, request: &Request, _conn: &ConnectionInfo) -> Result<Response, String> {
991        let location = self
992            .location_template
993            .replace("$path", &request.request_uri);
994        use crate::header::Header;
995        let mut r = Response::new();
996        r.status_code = self.status;
997        r.reason_phrase = self.reason.as_ref().clone();
998        r.headers.push(Header { name: "Location".to_string(), value: location });
999        Ok(r)
1000    }
1001}
1002
1003// ── RespondAdapter ─────────────────────────────────────────────────────────────
1004
1005/// Action adapter that returns a fixed response body.
1006#[derive(Clone)]
1007pub(crate) struct RespondAdapter {
1008    status: i16,
1009    reason: Arc<String>,
1010    body: Arc<Vec<u8>>,
1011    content_type: Arc<String>,
1012}
1013
1014impl RespondAdapter {
1015    pub(crate) fn new(status: u16, body: String, content_type: String) -> Self {
1016        use crate::response::STATUS_CODE_REASON_PHRASE;
1017        let reason = match status {
1018            200 => STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string(),
1019            201 => STATUS_CODE_REASON_PHRASE.n201_created.reason_phrase.to_string(),
1020            204 => STATUS_CODE_REASON_PHRASE.n204_no_content.reason_phrase.to_string(),
1021            400 => STATUS_CODE_REASON_PHRASE.n400_bad_request.reason_phrase.to_string(),
1022            401 => STATUS_CODE_REASON_PHRASE.n401_unauthorized.reason_phrase.to_string(),
1023            403 => STATUS_CODE_REASON_PHRASE.n403_forbidden.reason_phrase.to_string(),
1024            404 => STATUS_CODE_REASON_PHRASE.n404_not_found.reason_phrase.to_string(),
1025            500 => STATUS_CODE_REASON_PHRASE.n500_internal_server_error.reason_phrase.to_string(),
1026            _ => "OK".to_string(),
1027        };
1028        RespondAdapter {
1029            status: status as i16,
1030            reason: Arc::new(reason),
1031            body: Arc::new(body.into_bytes()),
1032            content_type: Arc::new(content_type),
1033        }
1034    }
1035}
1036
1037impl Application for RespondAdapter {
1038    fn execute(&self, _request: &Request, _conn: &ConnectionInfo) -> Result<Response, String> {
1039        use crate::range::Range;
1040        let mut r = Response::new();
1041        r.status_code = self.status;
1042        r.reason_phrase = self.reason.as_ref().clone();
1043        if !self.body.is_empty() {
1044            r.content_range_list = vec![Range::get_content_range(
1045                self.body.as_ref().clone(),
1046                self.content_type.as_ref().clone(),
1047            )];
1048        }
1049        Ok(r)
1050    }
1051}
1052
1053// ── StaticAdapter ──────────────────────────────────────────────────────────────
1054
1055/// Action adapter that serves static files from a configured `root` directory.
1056///
1057/// Unlike `StaticResourceController` (which always resolves paths relative to
1058/// the process's current working directory), this adapter is parameterized
1059/// per-route by `ActionConfig::Static { root, index }` from `rws.config.toml`,
1060/// so a config-driven proxy can serve an arbitrary directory without Rust code.
1061#[derive(Clone)]
1062pub(crate) struct StaticAdapter {
1063    root: Arc<std::path::PathBuf>,
1064    index: Arc<Vec<String>>,
1065}
1066
1067impl StaticAdapter {
1068    pub(crate) fn new(root: String, index: Vec<String>) -> Self {
1069        let index = if index.is_empty() { vec!["index.html".to_string()] } else { index };
1070        StaticAdapter {
1071            root: Arc::new(std::path::PathBuf::from(root)),
1072            index: Arc::new(index),
1073        }
1074    }
1075
1076    /// Resolves `request_uri` against `root`. Returns `None` if the decoded
1077    /// path contains a `..` segment, which would otherwise let a request
1078    /// escape the configured root directory.
1079    fn resolve(&self, request_uri: &str) -> Option<std::path::PathBuf> {
1080        let raw_path = request_uri.split('?').next().unwrap_or(request_uri);
1081        let decoded = crate::url::URL::percent_decode(raw_path);
1082
1083        if decoded.split('/').any(|segment| segment == "..") {
1084            return None;
1085        }
1086
1087        let relative = decoded.trim_start_matches('/');
1088        Some(self.root.join(relative))
1089    }
1090}
1091
1092impl Application for StaticAdapter {
1093    fn execute(&self, request: &Request, _conn: &ConnectionInfo) -> Result<Response, String> {
1094        let mut response = Response::new();
1095
1096        let not_found = |mut response: Response| {
1097            response.status_code = *STATUS_CODE_REASON_PHRASE.n404_not_found.status_code;
1098            response.reason_phrase = STATUS_CODE_REASON_PHRASE.n404_not_found.reason_phrase.to_string();
1099            response
1100        };
1101
1102        let candidate = match self.resolve(&request.request_uri) {
1103            Some(p) => p,
1104            None => {
1105                response.status_code = *STATUS_CODE_REASON_PHRASE.n403_forbidden.status_code;
1106                response.reason_phrase = STATUS_CODE_REASON_PHRASE.n403_forbidden.reason_phrase.to_string();
1107                return Ok(response);
1108            }
1109        };
1110
1111        let mut file_path = candidate;
1112        if file_path.is_dir() {
1113            let indexed = self
1114                .index
1115                .iter()
1116                .map(|name| file_path.join(name))
1117                .find(|p| p.is_file());
1118
1119            file_path = match indexed {
1120                Some(p) => p,
1121                None => return Ok(not_found(response)),
1122            };
1123        }
1124
1125        if !file_path.is_file() {
1126            return Ok(not_found(response));
1127        }
1128
1129        // Defense-in-depth against symlinks inside `root` that point outside it —
1130        // the `..`-segment check above only catches traversal in the request URI.
1131        if let (Ok(root_canon), Ok(file_canon)) =
1132            (self.root.canonicalize(), file_path.canonicalize())
1133        {
1134            if !file_canon.starts_with(&root_canon) {
1135                return Ok(not_found(response));
1136            }
1137        }
1138
1139        let path_str = match file_path.to_str() {
1140            Some(s) => s,
1141            None => return Ok(not_found(response)),
1142        };
1143
1144        match crate::range::Range::get_content_range_of_a_file(path_str) {
1145            Ok(content_range) => {
1146                response.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
1147                response.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
1148                response.content_range_list = vec![content_range];
1149                Ok(response)
1150            }
1151            Err(_) => Ok(not_found(response)),
1152        }
1153    }
1154}
1155
1156// ── PerRouteRateLimit middleware ───────────────────────────────────────────────
1157
1158/// A per-route rate limiter middleware backed by a shared `RateLimiter`.
1159pub(crate) struct PerRouteRateLimit(pub(crate) Arc<crate::rate_limit::RateLimiter>);
1160
1161impl crate::middleware::Middleware for PerRouteRateLimit {
1162    fn handle(
1163        &self,
1164        request: &Request,
1165        conn: &ConnectionInfo,
1166        next: &dyn Application,
1167    ) -> Result<Response, String> {
1168        use crate::error::{AppError, IntoResponse};
1169        if self.0.check(&conn.client.ip) {
1170            next.execute(request, conn)
1171        } else {
1172            Ok(AppError::TooManyRequests.into_response())
1173        }
1174    }
1175}
1176
1177// ── BearerAuthMiddleware ───────────────────────────────────────────────────────
1178
1179/// Bearer token authentication middleware.
1180pub(crate) struct BearerAuthMiddleware {
1181    pub(crate) token: Arc<String>,
1182}
1183
1184impl crate::middleware::Middleware for BearerAuthMiddleware {
1185    fn handle(
1186        &self,
1187        request: &Request,
1188        conn: &ConnectionInfo,
1189        next: &dyn Application,
1190    ) -> Result<Response, String> {
1191        use crate::error::{AppError, IntoResponse};
1192        let expected = format!("Bearer {}", self.token);
1193        let authorized = request
1194            .headers
1195            .iter()
1196            .any(|h| h.name.eq_ignore_ascii_case("authorization") && h.value == expected);
1197        if authorized {
1198            next.execute(request, conn)
1199        } else {
1200            Ok(AppError::Unauthorized.into_response())
1201        }
1202    }
1203}
1204
1205// ── arc_app helper ─────────────────────────────────────────────────────────────
1206
1207/// Box any `Application + Send + Sync + 'static` into an `Arc<dyn …>`.
1208pub(crate) fn arc_app<A: Application + Send + Sync + 'static>(
1209    a: A,
1210) -> Arc<dyn Application + Send + Sync> {
1211    Arc::new(a)
1212}
1213
1214// ── Public entry points ────────────────────────────────────────────────────────
1215
1216/// Build a `ConfigDrivenApp` from `rws.config.toml` and spawn L4/WS proxy
1217/// threads. Returns the app and a list of thread handles.
1218pub fn build_from_file() -> (ConfigDrivenApp, Vec<std::thread::JoinHandle<()>>) {
1219    builder::build_from_file()
1220}