Skip to main content

soli_proxy/server/
mod.rs

1// When scripting feature is disabled, OptionalLuaEngine = () and cloning it triggers warnings
2#![allow(clippy::let_unit_value, clippy::clone_on_copy, clippy::unit_arg)]
3
4use crate::acme::ChallengeStore;
5use crate::app::AppManager;
6use crate::auth;
7use crate::circuit_breaker::SharedCircuitBreaker;
8use crate::config::ConfigManager;
9use crate::metrics::SharedMetrics;
10use crate::shutdown::ShutdownCoordinator;
11use anyhow::Result;
12use bytes::Bytes;
13use http_body_util::BodyExt;
14use hyper::body::Incoming;
15use hyper::header::HeaderValue;
16use hyper::service::service_fn;
17use hyper::Request;
18use hyper::Response;
19use hyper_util::client::legacy::connect::HttpConnector;
20use hyper_util::client::legacy::Client;
21use hyper_util::rt::TokioExecutor;
22use hyper_util::rt::TokioIo;
23use socket2::{Domain, Protocol, Socket, Type};
24use std::net::SocketAddr;
25use std::sync::atomic::{AtomicUsize, Ordering};
26use std::sync::Arc;
27use std::time::Duration;
28use tokio::io::AsyncWriteExt;
29use tokio::net::{TcpListener, TcpStream};
30use tokio_rustls::TlsAcceptor;
31
32#[cfg(feature = "scripting")]
33use crate::scripting::{LuaEngine, LuaRequest, RequestHookResult, RouteHookResult};
34
35type ClientType = Client<HttpConnector, Incoming>;
36type BoxBody = http_body_util::combinators::BoxBody<Bytes, std::convert::Infallible>;
37
38#[cfg(feature = "scripting")]
39type OptionalLuaEngine = Option<LuaEngine>;
40#[cfg(not(feature = "scripting"))]
41type OptionalLuaEngine = ();
42
/// Per-rule round-robin counters used to spread requests across a rule's targets.
///
/// One dedicated counter is allocated for each rule known at construction time. Rule
/// indices outside that range (for example if the rule list grows after construction)
/// share a single fallback counter instead of panicking on an out-of-bounds index.
#[derive(Debug)]
pub struct LoadBalancerState {
    counters: Vec<AtomicUsize>,
    /// Shared counter for rule indices that have no dedicated slot.
    fallback: AtomicUsize,
}

impl LoadBalancerState {
    /// Creates state with one zeroed counter for each of `num_rules` rules.
    pub fn new(num_rules: usize) -> Self {
        Self {
            counters: (0..num_rules).map(|_| AtomicUsize::new(0)).collect(),
            fallback: AtomicUsize::new(0),
        }
    }

    /// Returns the next target index for rule `rule_idx`, cycling through `0..num_targets`.
    ///
    /// Returns 0 when `num_targets` is 0 (avoiding a division by zero); callers must
    /// still handle the "no targets" case themselves. An out-of-range `rule_idx` uses
    /// the shared fallback counter rather than panicking.
    pub fn select_index(&self, rule_idx: usize, num_targets: usize) -> usize {
        if num_targets == 0 {
            return 0;
        }
        let counter = self.counters.get(rule_idx).unwrap_or(&self.fallback);
        // Relaxed suffices: the counter only needs atomic increments, it does not
        // order any other memory accesses.
        counter.fetch_add(1, Ordering::Relaxed) % num_targets
    }
}
61
62/// Helper to record app-specific metrics
63fn record_app_metrics(
64    metrics: &SharedMetrics,
65    app_manager: &Option<Arc<AppManager>>,
66    target_url: &str,
67    bytes_in: u64,
68    bytes_out: u64,
69    status: u16,
70    duration: Duration,
71) {
72    if let Some(ref manager) = app_manager {
73        if let Ok(url) = url::Url::parse(target_url) {
74            if let Some(port) = url.port() {
75                if let Some(app_name) = futures::executor::block_on(manager.get_app_name(port)) {
76                    metrics.record_app_request(&app_name, bytes_in, bytes_out, status, duration);
77                }
78            }
79        }
80    }
81}
82
83/// Pre-parsed header value for X-Forwarded-For to avoid parsing on every request
84static X_FORWARDED_FOR_VALUE: std::sync::LazyLock<HeaderValue> =
85    std::sync::LazyLock::new(|| HeaderValue::from_static("127.0.0.1"));
86
87/// Verify Basic Auth credentials against stored hashes
88/// Returns true if credentials are valid, false otherwise
89fn verify_basic_auth(req: &Request<Incoming>, auth_entries: &[crate::auth::BasicAuth]) -> bool {
90    if auth_entries.is_empty() {
91        return true;
92    }
93
94    let auth_header = req.headers().get("authorization");
95    if auth_header.is_none() {
96        return false;
97    }
98
99    let header_value = auth_header.unwrap().to_str().unwrap_or("");
100    if !header_value.starts_with("Basic ") {
101        return false;
102    }
103
104    let encoded = &header_value[6..];
105    let decoded = base64::Engine::decode(&base64::engine::general_purpose::STANDARD, encoded)
106        .unwrap_or_default();
107    let creds = String::from_utf8_lossy(&decoded);
108
109    if let Some((username, password)) = creds.split_once(':') {
110        for entry in auth_entries {
111            if entry.username == username && auth::verify_password(password, &entry.hash) {
112                return true;
113            }
114        }
115    }
116
117    false
118}
119
120/// Create 401 Unauthorized response with WWW-Authenticate header
121fn create_auth_required_response() -> Response<BoxBody> {
122    let body = http_body_util::Full::new(Bytes::from("Authentication required")).boxed();
123    Response::builder()
124        .status(401)
125        .header("WWW-Authenticate", "Basic realm=\"Restricted\"")
126        .body(body)
127        .unwrap()
128}
129
130fn create_listener(addr: SocketAddr) -> Result<TcpListener> {
131    let domain = if addr.is_ipv4() {
132        Domain::IPV4
133    } else {
134        Domain::IPV6
135    };
136    let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
137    socket.set_reuse_address(true)?;
138    socket.set_reuse_port(true)?;
139    socket.set_nonblocking(true)?;
140    socket.bind(&addr.into())?;
141    socket.listen(8192)?;
142    let std_listener: std::net::TcpListener = socket.into();
143    Ok(TcpListener::from_std(std_listener)?)
144}
145
146fn create_client() -> ClientType {
147    let exec = TokioExecutor::new();
148    let mut connector = HttpConnector::new();
149    connector.set_nodelay(true);
150    connector.set_keepalive(Some(Duration::from_secs(30)));
151    connector.set_connect_timeout(Some(Duration::from_secs(5)));
152    Client::builder(exec)
153        .pool_max_idle_per_host(256)
154        .pool_idle_timeout(Duration::from_secs(60))
155        .build(connector)
156}
157
158pub struct ProxyServer {
159    config: Arc<ConfigManager>,
160    shutdown: ShutdownCoordinator,
161    tls_acceptor: Option<TlsAcceptor>,
162    https_addr: Option<SocketAddr>,
163    metrics: SharedMetrics,
164    challenge_store: ChallengeStore,
165    lua_engine: OptionalLuaEngine,
166    circuit_breaker: SharedCircuitBreaker,
167    app_manager: Option<Arc<AppManager>>,
168    load_balancer: Arc<LoadBalancerState>,
169}
170
171impl ProxyServer {
172    pub fn new(
173        config: Arc<ConfigManager>,
174        shutdown: ShutdownCoordinator,
175        metrics: SharedMetrics,
176        challenge_store: ChallengeStore,
177        lua_engine: OptionalLuaEngine,
178        circuit_breaker: SharedCircuitBreaker,
179        app_manager: Option<Arc<AppManager>>,
180    ) -> Result<Self> {
181        let num_rules = config.get_config().rules.len();
182        Ok(Self {
183            config,
184            shutdown,
185            tls_acceptor: None,
186            https_addr: None,
187            metrics,
188            challenge_store,
189            lua_engine,
190            circuit_breaker,
191            app_manager,
192            load_balancer: Arc::new(LoadBalancerState::new(num_rules)),
193        })
194    }
195
196    #[allow(clippy::too_many_arguments)]
197    pub fn with_https(
198        config: Arc<ConfigManager>,
199        shutdown: ShutdownCoordinator,
200        tls_acceptor: TlsAcceptor,
201        https_addr: SocketAddr,
202        metrics: SharedMetrics,
203        challenge_store: ChallengeStore,
204        lua_engine: OptionalLuaEngine,
205        circuit_breaker: SharedCircuitBreaker,
206        app_manager: Option<Arc<AppManager>>,
207    ) -> Result<Self> {
208        let num_rules = config.get_config().rules.len();
209        Ok(Self {
210            config,
211            shutdown,
212            tls_acceptor: Some(tls_acceptor),
213            https_addr: Some(https_addr),
214            metrics,
215            challenge_store,
216            lua_engine,
217            circuit_breaker,
218            app_manager,
219            load_balancer: Arc::new(LoadBalancerState::new(num_rules)),
220        })
221    }
222
223    pub async fn run(&self) -> Result<()> {
224        let cfg = self.config.get_config();
225        let http_addr: SocketAddr = cfg.server.bind.parse()?;
226        let https_addr = self.https_addr;
227
228        let has_https = https_addr.is_some();
229        let num_listeners = std::thread::available_parallelism()
230            .map(|n| n.get())
231            .unwrap_or(4);
232
233        // Spawn N HTTP accept loops with SO_REUSEPORT
234        // Each listener gets its own client with its own connection pool to avoid contention
235        let app_manager = self.app_manager.clone();
236        for i in 0..num_listeners {
237            let config_clone = self.config.clone();
238            let shutdown_clone = self.shutdown.clone();
239            let metrics_clone = self.metrics.clone();
240            let challenge_store_clone = self.challenge_store.clone();
241            let lua_clone = self.lua_engine.clone();
242            let cb_clone = self.circuit_breaker.clone();
243            let am_clone = app_manager.clone();
244            let lb_clone = self.load_balancer.clone();
245
246            tokio::spawn(async move {
247                if let Err(e) = run_http_server(
248                    http_addr,
249                    config_clone,
250                    shutdown_clone,
251                    metrics_clone,
252                    challenge_store_clone,
253                    lua_clone,
254                    cb_clone,
255                    am_clone,
256                    lb_clone,
257                )
258                .await
259                {
260                    tracing::error!("HTTP/1.1 server error (listener {}): {}", i, e);
261                }
262            });
263        }
264
265        if let Some(https_addr) = https_addr {
266            for i in 0..num_listeners {
267                let config_clone = self.config.clone();
268                let shutdown_clone = self.shutdown.clone();
269                let acceptor = self.tls_acceptor.as_ref().unwrap().clone();
270                let metrics_clone = self.metrics.clone();
271                let challenge_store_clone = self.challenge_store.clone();
272                let lua_clone = self.lua_engine.clone();
273                let cb_clone = self.circuit_breaker.clone();
274                let am_clone = app_manager.clone();
275                let lb_clone = self.load_balancer.clone();
276
277                tokio::spawn(async move {
278                    if let Err(e) = run_https_server(
279                        https_addr,
280                        config_clone,
281                        shutdown_clone,
282                        acceptor,
283                        metrics_clone,
284                        challenge_store_clone,
285                        lua_clone,
286                        cb_clone,
287                        am_clone,
288                        lb_clone,
289                    )
290                    .await
291                    {
292                        tracing::error!("HTTPS/2 server error (listener {}): {}", i, e);
293                    }
294                });
295            }
296        }
297
298        tracing::info!(
299            "HTTP/1.1 server listening on {} ({} accept loops)",
300            http_addr,
301            num_listeners
302        );
303        if has_https {
304            tracing::info!(
305                "HTTPS/2 server listening on {} ({} accept loops)",
306                https_addr.unwrap(),
307                num_listeners
308            );
309        }
310
311        loop {
312            if self.shutdown.is_shutting_down() {
313                tracing::info!("Shutting down servers...");
314                break;
315            }
316            tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
317        }
318
319        Ok(())
320    }
321}
322
323#[allow(clippy::too_many_arguments)]
324async fn run_http_server(
325    addr: SocketAddr,
326    config: Arc<ConfigManager>,
327    shutdown: ShutdownCoordinator,
328    metrics: SharedMetrics,
329    challenge_store: ChallengeStore,
330    lua_engine: OptionalLuaEngine,
331    circuit_breaker: SharedCircuitBreaker,
332    app_manager: Option<Arc<AppManager>>,
333    load_balancer: Arc<LoadBalancerState>,
334) -> Result<()> {
335    let listener = create_listener(addr)?;
336    let client = create_client();
337
338    loop {
339        if shutdown.is_shutting_down() {
340            break;
341        }
342
343        match listener.accept().await {
344            Ok((stream, _)) => {
345                let _ = stream.set_nodelay(true);
346                let client = client.clone();
347                let config = config.clone();
348                let metrics = metrics.clone();
349                let cs = challenge_store.clone();
350                let lua = lua_engine.clone();
351                let cb = circuit_breaker.clone();
352                let am = app_manager.clone();
353                let lb = load_balancer.clone();
354                tokio::spawn(async move {
355                    if let Err(e) = handle_http11_connection(
356                        stream, client, config, metrics, cs, lua, cb, am, lb,
357                    )
358                    .await
359                    {
360                        tracing::debug!("HTTP/1.1 connection error: {}", e);
361                    }
362                });
363            }
364            Err(e) => {
365                tracing::error!("HTTP/1.1 accept error: {}", e);
366            }
367        }
368    }
369
370    Ok(())
371}
372
373#[allow(clippy::too_many_arguments)]
374async fn run_https_server(
375    addr: SocketAddr,
376    config: Arc<ConfigManager>,
377    shutdown: ShutdownCoordinator,
378    acceptor: TlsAcceptor,
379    metrics: SharedMetrics,
380    challenge_store: ChallengeStore,
381    lua_engine: OptionalLuaEngine,
382    circuit_breaker: SharedCircuitBreaker,
383    app_manager: Option<Arc<AppManager>>,
384    load_balancer: Arc<LoadBalancerState>,
385) -> Result<()> {
386    let listener = create_listener(addr)?;
387    let client = create_client();
388
389    loop {
390        if shutdown.is_shutting_down() {
391            break;
392        }
393
394        match listener.accept().await {
395            Ok((stream, _)) => {
396                let _ = stream.set_nodelay(true);
397                let client = client.clone();
398                let config = config.clone();
399                let acceptor = acceptor.clone();
400                let metrics = metrics.clone();
401                let cs = challenge_store.clone();
402                let lua = lua_engine.clone();
403                let cb = circuit_breaker.clone();
404                let am = app_manager.clone();
405                let lb = load_balancer.clone();
406                tokio::spawn(async move {
407                    match acceptor.accept(stream).await {
408                        Ok(tls_stream) => {
409                            metrics.inc_tls_connections();
410                            if let Err(e) = handle_https2_connection(
411                                tls_stream, client, config, metrics, cs, lua, cb, am, lb,
412                            )
413                            .await
414                            {
415                                tracing::debug!("HTTPS/2 connection error: {}", e);
416                            }
417                        }
418                        Err(e) => {
419                            tracing::error!("TLS accept error: {}", e);
420                        }
421                    }
422                });
423            }
424            Err(e) => {
425                tracing::error!("HTTPS/2 accept error: {}", e);
426            }
427        }
428    }
429
430    Ok(())
431}
432
433#[allow(clippy::too_many_arguments)]
434async fn handle_http11_connection(
435    stream: tokio::net::TcpStream,
436    client: ClientType,
437    config: Arc<ConfigManager>,
438    metrics: SharedMetrics,
439    challenge_store: ChallengeStore,
440    lua_engine: OptionalLuaEngine,
441    circuit_breaker: SharedCircuitBreaker,
442    app_manager: Option<Arc<AppManager>>,
443    load_balancer: Arc<LoadBalancerState>,
444) -> Result<()> {
445    let io = TokioIo::new(stream);
446    let svc = service_fn(move |req| {
447        handle_request(
448            req,
449            client.clone(),
450            config.clone(),
451            metrics.clone(),
452            challenge_store.clone(),
453            lua_engine.clone(),
454            circuit_breaker.clone(),
455            app_manager.clone(),
456            load_balancer.clone(),
457        )
458    });
459
460    let conn = hyper::server::conn::http1::Builder::new()
461        .keep_alive(true)
462        .pipeline_flush(true)
463        .serve_connection(io, svc)
464        .with_upgrades();
465
466    if let Err(e) = conn.await {
467        tracing::debug!("HTTP/1.1 connection error: {}", e);
468    }
469
470    Ok(())
471}
472
473#[allow(clippy::too_many_arguments)]
474async fn handle_https2_connection(
475    stream: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
476    client: ClientType,
477    config: Arc<ConfigManager>,
478    metrics: SharedMetrics,
479    challenge_store: ChallengeStore,
480    lua_engine: OptionalLuaEngine,
481    circuit_breaker: SharedCircuitBreaker,
482    app_manager: Option<Arc<AppManager>>,
483    load_balancer: Arc<LoadBalancerState>,
484) -> Result<()> {
485    let is_h2 = stream.get_ref().1.alpn_protocol() == Some(b"h2");
486
487    let io = TokioIo::new(stream);
488
489    if is_h2 {
490        let exec = TokioExecutor::new();
491        let svc = service_fn(move |req| {
492            handle_request(
493                req,
494                client.clone(),
495                config.clone(),
496                metrics.clone(),
497                challenge_store.clone(),
498                lua_engine.clone(),
499                circuit_breaker.clone(),
500                app_manager.clone(),
501                load_balancer.clone(),
502            )
503        });
504        let conn = hyper::server::conn::http2::Builder::new(exec)
505            .initial_stream_window_size(1024 * 1024)
506            .initial_connection_window_size(2 * 1024 * 1024)
507            .max_concurrent_streams(250)
508            .serve_connection(io, svc);
509        if let Err(e) = conn.await {
510            tracing::debug!("HTTPS/2 connection error: {}", e);
511        }
512    } else {
513        let svc = service_fn(move |req| {
514            handle_request(
515                req,
516                client.clone(),
517                config.clone(),
518                metrics.clone(),
519                challenge_store.clone(),
520                lua_engine.clone(),
521                circuit_breaker.clone(),
522                app_manager.clone(),
523                load_balancer.clone(),
524            )
525        });
526        let conn = hyper::server::conn::http1::Builder::new()
527            .keep_alive(true)
528            .pipeline_flush(true)
529            .serve_connection(io, svc)
530            .with_upgrades();
531        if let Err(e) = conn.await {
532            tracing::debug!("HTTPS/1.1 connection error: {}", e);
533        }
534    }
535
536    Ok(())
537}
538
/// Converts the incoming request's headers into a name → value map for Lua scripts.
///
/// Request and response headers are both `HeaderMap`s, so this simply delegates to
/// `extract_response_headers` to keep one conversion routine.
#[cfg(feature = "scripting")]
fn extract_headers(req: &Request<Incoming>) -> std::collections::HashMap<String, String> {
    extract_response_headers(req.headers())
}
552
/// Snapshots the parts of an incoming request that Lua hooks can inspect.
///
/// `host` is taken from the request URI when it carries a host, otherwise from the
/// `Host` header, and is empty if neither is usable. `content_length` is parsed from
/// the `Content-Length` header and is 0 when the header is absent or malformed.
#[cfg(feature = "scripting")]
fn build_lua_request(req: &Request<Incoming>) -> LuaRequest {
    let headers = req.headers();

    let host = match req.uri().host() {
        Some(uri_host) => uri_host.to_string(),
        None => headers
            .get("host")
            .and_then(|h| h.to_str().ok())
            .unwrap_or("")
            .to_string(),
    };

    let content_length = headers
        .get("content-length")
        .and_then(|raw| raw.to_str().ok())
        .and_then(|text| text.parse().ok())
        .unwrap_or(0);

    LuaRequest {
        method: req.method().to_string(),
        path: req.uri().path().to_string(),
        headers: extract_headers(req),
        host,
        content_length,
    }
}
578
/// Flattens a header map into a name → value map for Lua scripts.
///
/// Names use hyper's canonical lowercase form. A value that is not valid visible
/// ASCII (`HeaderValue::to_str` fails) contributes an empty string, but its name is
/// still present.
///
/// When a name occurs more than once, its values are joined in order instead of
/// keeping only the last one. `cookie` is joined with `"; "` because HTTP/2 clients may
/// split cookies across several field lines (RFC 9113 §8.2.3). All other names are
/// joined with `", "` (RFC 9110 §5.3). `set-cookie` does not use list syntax, so a
/// joined `set-cookie` entry is only a best-effort view.
#[cfg(feature = "scripting")]
fn extract_response_headers(
    headers: &hyper::HeaderMap,
) -> std::collections::HashMap<String, String> {
    let mut flattened: std::collections::HashMap<String, String> =
        std::collections::HashMap::with_capacity(headers.keys_len());

    for (name, value) in headers {
        let value = value.to_str().unwrap_or("");
        // `HeaderName::as_str` is already lowercase.
        let joined = flattened
            .entry(name.as_str().to_owned())
            .or_insert_with(String::new);
        if value.is_empty() {
            continue;
        }
        if !joined.is_empty() {
            joined.push_str(if *name == hyper::header::COOKIE { "; " } else { ", " });
        }
        joined.push_str(value);
    }

    flattened
}
594
595#[allow(clippy::too_many_arguments)]
596async fn handle_request(
597    req: Request<Incoming>,
598    client: ClientType,
599    config_manager: Arc<ConfigManager>,
600    metrics: SharedMetrics,
601    challenge_store: ChallengeStore,
602    lua_engine: OptionalLuaEngine,
603    circuit_breaker: SharedCircuitBreaker,
604    app_manager: Option<Arc<AppManager>>,
605    load_balancer: Arc<LoadBalancerState>,
606) -> Result<Response<BoxBody>, hyper::Error> {
607    let start_time = std::time::Instant::now();
608    metrics.inc_in_flight();
609    let config = config_manager.get_config();
610
611    // ACME challenge check — must come before all other routing
612    if let Some(response) = handle_acme_challenge(&req, &challenge_store) {
613        metrics.dec_in_flight();
614        return Ok(response);
615    }
616
617    if is_metrics_request(&req) {
618        let duration = start_time.elapsed();
619        metrics.dec_in_flight();
620        let metrics_output = metrics.format_metrics();
621        metrics.record_request(0, metrics_output.len() as u64, 200, duration);
622        let body = http_body_util::Full::new(Bytes::from(metrics_output)).boxed();
623        return Ok(Response::builder()
624            .status(200)
625            .header("Content-Type", "text/plain")
626            .body(body)
627            .unwrap());
628    }
629
630    // --- Lua on_request hook ---
631    #[cfg(feature = "scripting")]
632    if let Some(ref engine) = lua_engine {
633        if engine.has_on_request() {
634            let mut lua_req = build_lua_request(&req);
635            match engine.call_on_request(&mut lua_req) {
636                RequestHookResult::Deny { status, body } => {
637                    metrics.dec_in_flight();
638                    let duration = start_time.elapsed();
639                    metrics.record_request(0, body.len() as u64, status, duration);
640                    let resp_body = http_body_util::Full::new(Bytes::from(body)).boxed();
641                    return Ok(Response::builder().status(status).body(resp_body).unwrap());
642                }
643                RequestHookResult::Continue(updated_req) => {
644                    // Apply any header modifications back to the hyper request
645                    // We can't easily mutate the incoming request headers here since
646                    // we'd need to own it, so we store the lua_req for later use.
647                    // Headers set via set_header in on_request will be applied after
648                    // the request is decomposed into parts.
649                    let _ = updated_req;
650                }
651            }
652        }
653    }
654
655    let is_websocket = is_websocket_request(&req);
656
657    if is_websocket {
658        return handle_websocket_request(
659            req,
660            client,
661            &config,
662            &metrics,
663            start_time,
664            app_manager.clone(),
665        )
666        .await;
667    }
668
669    let result = handle_regular_request(
670        req,
671        client,
672        &config,
673        &lua_engine,
674        &circuit_breaker,
675        app_manager.clone(),
676        load_balancer.clone(),
677    )
678    .await;
679    let duration = start_time.elapsed();
680
681    metrics.dec_in_flight();
682
683    match result {
684        #[allow(unused_variables)]
685        Ok((response, _target_url, route_scripts)) => {
686            let status = response.status().as_u16();
687
688            // --- Lua on_request_end hooks (global + route) ---
689            #[cfg(feature = "scripting")]
690            if let Some(ref engine) = lua_engine {
691                let lua_req = LuaRequest {
692                    method: String::new(),
693                    path: String::new(),
694                    headers: std::collections::HashMap::new(),
695                    host: String::new(),
696                    content_length: 0,
697                };
698                let duration_ms = duration.as_secs_f64() * 1000.0;
699
700                // Global on_request_end
701                if engine.has_on_request_end() {
702                    engine.call_on_request_end(&lua_req, status, duration_ms, &_target_url);
703                }
704
705                // Route-specific on_request_end
706                for script_name in &route_scripts {
707                    engine.call_route_on_request_end(
708                        script_name,
709                        &lua_req,
710                        status,
711                        duration_ms,
712                        &_target_url,
713                    );
714                }
715            }
716
717            metrics.record_request(0, 0, status, duration);
718            record_app_metrics(&metrics, &app_manager, &_target_url, 0, 0, status, duration);
719            let (parts, body) = response.into_parts();
720            let boxed = body.map_err(|_| unreachable!()).boxed();
721            Ok(Response::from_parts(parts, boxed))
722        }
723        Err(e) => {
724            metrics.inc_errors();
725            Err(e)
726        }
727    }
728}
729
730fn is_websocket_request(req: &Request<Incoming>) -> bool {
731    if let Some(upgrade) = req.headers().get("upgrade") {
732        if upgrade == "websocket" {
733            return true;
734        }
735    }
736    false
737}
738
739fn is_metrics_request(req: &Request<Incoming>) -> bool {
740    req.uri().path() == "/metrics"
741}
742
743fn handle_acme_challenge(
744    req: &Request<Incoming>,
745    challenge_store: &ChallengeStore,
746) -> Option<Response<BoxBody>> {
747    let path = req.uri().path();
748    let prefix = "/.well-known/acme-challenge/";
749
750    if !path.starts_with(prefix) {
751        return None;
752    }
753
754    let token = &path[prefix.len()..];
755
756    if let Ok(store) = challenge_store.read() {
757        if let Some(key_auth) = store.get(token) {
758            let body = http_body_util::Full::new(Bytes::from(key_auth.clone())).boxed();
759            return Some(
760                Response::builder()
761                    .status(200)
762                    .header("Content-Type", "text/plain")
763                    .body(body)
764                    .unwrap(),
765            );
766        }
767    }
768
769    let body = http_body_util::Full::new(Bytes::from("Challenge not found")).boxed();
770    Some(Response::builder().status(404).body(body).unwrap())
771}
772
773async fn handle_websocket_request(
774    req: Request<Incoming>,
775    _client: ClientType,
776    config: &crate::config::Config,
777    metrics: &SharedMetrics,
778    _start_time: std::time::Instant,
779    _app_manager: Option<Arc<AppManager>>,
780) -> Result<Response<BoxBody>, hyper::Error> {
781    let target_result = find_target(&req, &config.rules);
782
783    if target_result.is_none() {
784        metrics.inc_errors();
785        let body = http_body_util::Full::new(Bytes::from("Misdirected Request")).boxed();
786        return Ok(Response::builder().status(421).body(body).unwrap());
787    }
788
789    let (target_url, _, _, _) = target_result.unwrap();
790
791    // Extract host:port from target URL (e.g. "http://127.0.0.1:3000/path" -> "127.0.0.1:3000")
792    let backend_addr = match url::Url::parse(&target_url) {
793        Ok(u) => format!(
794            "{}:{}",
795            u.host_str().unwrap_or("127.0.0.1"),
796            u.port().unwrap_or(80)
797        ),
798        Err(_) => {
799            metrics.inc_errors();
800            let body = http_body_util::Full::new(Bytes::from("Bad backend URL")).boxed();
801            return Ok(Response::builder().status(502).body(body).unwrap());
802        }
803    };
804
805    let path = req.uri().path().to_string();
806    let query = req
807        .uri()
808        .query()
809        .map(|q| format!("?{}", q))
810        .unwrap_or_default();
811
812    let ws_key = req
813        .headers()
814        .get("sec-websocket-key")
815        .and_then(|v| v.to_str().ok())
816        .unwrap_or("")
817        .to_string();
818    let ws_version = req
819        .headers()
820        .get("sec-websocket-version")
821        .and_then(|v| v.to_str().ok())
822        .unwrap_or("13")
823        .to_string();
824    let ws_protocol = req
825        .headers()
826        .get("sec-websocket-protocol")
827        .and_then(|v| v.to_str().ok())
828        .map(|s| s.to_string());
829    let host_header = req
830        .headers()
831        .get("host")
832        .and_then(|v| v.to_str().ok())
833        .unwrap_or(&backend_addr)
834        .to_string();
835
836    tracing::info!(
837        "WebSocket upgrade request to {}{}{}",
838        backend_addr,
839        path,
840        query
841    );
842
843    // Connect to the backend
844    let backend = match TcpStream::connect(&backend_addr).await {
845        Ok(s) => s,
846        Err(e) => {
847            tracing::error!("Failed to connect to backend for WebSocket: {}", e);
848            metrics.inc_errors();
849            let body = http_body_util::Full::new(Bytes::from("Backend not reachable")).boxed();
850            return Ok(Response::builder().status(502).body(body).unwrap());
851        }
852    };
853
854    // Send the upgrade request to the backend
855    let mut handshake = format!(
856        "GET {}{} HTTP/1.1\r\n\
857         Host: {}\r\n\
858         Upgrade: websocket\r\n\
859         Connection: Upgrade\r\n\
860         Sec-WebSocket-Key: {}\r\n\
861         Sec-WebSocket-Version: {}\r\n",
862        path, query, host_header, ws_key, ws_version,
863    );
864    if let Some(proto) = &ws_protocol {
865        handshake.push_str(&format!("Sec-WebSocket-Protocol: {}\r\n", proto));
866    }
867    handshake.push_str("\r\n");
868
869    let (mut backend_read, mut backend_write) = backend.into_split();
870    if let Err(e) = backend_write.write_all(handshake.as_bytes()).await {
871        tracing::error!("Failed to send WebSocket handshake to backend: {}", e);
872        metrics.inc_errors();
873        let body =
874            http_body_util::Full::new(Bytes::from("Failed to initiate WebSocket with backend"))
875                .boxed();
876        return Ok(Response::builder().status(502).body(body).unwrap());
877    }
878
879    // Read the backend's 101 response
880    let mut response_buf = vec![0u8; 4096];
881    let n = match tokio::io::AsyncReadExt::read(&mut backend_read, &mut response_buf).await {
882        Ok(n) if n > 0 => n,
883        _ => {
884            tracing::error!("No response from backend for WebSocket upgrade");
885            metrics.inc_errors();
886            let body = http_body_util::Full::new(Bytes::from(
887                "Backend did not respond to WebSocket upgrade",
888            ))
889            .boxed();
890            return Ok(Response::builder().status(502).body(body).unwrap());
891        }
892    };
893
894    let response_str = String::from_utf8_lossy(&response_buf[..n]);
895    if !response_str.contains("101") {
896        tracing::error!(
897            "Backend rejected WebSocket upgrade: {}",
898            response_str.lines().next().unwrap_or("")
899        );
900        metrics.inc_errors();
901        let body =
902            http_body_util::Full::new(Bytes::from("Backend rejected WebSocket upgrade")).boxed();
903        return Ok(Response::builder().status(502).body(body).unwrap());
904    }
905
906    // Extract headers from backend 101 response
907    let mut accept_key = String::new();
908    let mut resp_protocol = None;
909    for line in response_str.lines().skip(1) {
910        if line.trim().is_empty() {
911            break;
912        }
913        if let Some((name, value)) = line.split_once(':') {
914            let name_lower = name.trim().to_lowercase();
915            let value = value.trim().to_string();
916            if name_lower == "sec-websocket-accept" {
917                accept_key = value;
918            } else if name_lower == "sec-websocket-protocol" {
919                resp_protocol = Some(value);
920            }
921        }
922    }
923
924    // Use hyper::upgrade::on to get the client-side stream after we return 101
925    let client_upgrade = hyper::upgrade::on(req);
926
927    // Reunite the backend halves
928    let backend_stream = backend_read.reunite(backend_write).unwrap();
929
930    // Spawn the bidirectional copy task
931    tokio::spawn(async move {
932        match client_upgrade.await {
933            Ok(upgraded) => {
934                let mut client_stream = TokioIo::new(upgraded);
935                let (mut br, mut bw) = tokio::io::split(backend_stream);
936                let (mut cr, mut cw) = tokio::io::split(&mut client_stream);
937                let _ = tokio::join!(
938                    tokio::io::copy(&mut br, &mut cw),
939                    tokio::io::copy(&mut cr, &mut bw),
940                );
941            }
942            Err(e) => {
943                tracing::error!("WebSocket client upgrade failed: {}", e);
944            }
945        }
946    });
947
948    // Return 101 Switching Protocols to the client
949    let mut resp = Response::builder()
950        .status(101)
951        .header("Upgrade", "websocket")
952        .header("Connection", "Upgrade")
953        .header("Sec-WebSocket-Accept", accept_key);
954    if let Some(proto) = resp_protocol {
955        resp = resp.header("Sec-WebSocket-Protocol", proto);
956    }
957    Ok(resp
958        .body(http_body_util::Full::new(Bytes::new()).boxed())
959        .unwrap())
960}
961
962/// Returns (Response, target_url_for_logging, route_scripts)
963async fn handle_regular_request(
964    req: Request<Incoming>,
965    client: ClientType,
966    config: &crate::config::Config,
967    lua_engine: &OptionalLuaEngine,
968    circuit_breaker: &SharedCircuitBreaker,
969    _app_manager: Option<Arc<AppManager>>,
970    load_balancer: Arc<LoadBalancerState>,
971) -> Result<(Response<BoxBody>, String, Vec<String>), hyper::Error> {
972    let route = find_matching_rule(&req, &config.rules);
973
974    match route {
975        #[allow(unused_mut, unused_variables)]
976        Some(matched_route) => {
977            let path = req.uri().path().to_string();
978            let from_domain_rule = matched_route.from_domain_rule;
979            let matched_prefix = matched_route.matched_prefix();
980
981            if !matched_route.auth.is_empty() && !verify_basic_auth(&req, &matched_route.auth) {
982                tracing::debug!("Basic auth failed for {}", req.uri().path());
983                return Ok((create_auth_required_response(), String::new(), vec![]));
984            }
985            let route_scripts = matched_route.route_scripts.clone();
986
987            // Select an available target via circuit breaker
988            let target_selection =
989                select_target(&matched_route, &path, circuit_breaker, &load_balancer);
990            let (mut target_url, base_url) = match target_selection {
991                Some((url, base)) => (url, base),
992                None => {
993                    // All targets are circuit-broken
994                    let body =
995                        http_body_util::Full::new(Bytes::from("Service Unavailable")).boxed();
996                    return Ok((
997                        Response::builder()
998                            .status(503)
999                            .body(body)
1000                            .expect("Failed to build response"),
1001                        String::new(),
1002                        route_scripts,
1003                    ));
1004                }
1005            };
1006            // --- Lua route-specific on_request hooks ---
1007            #[cfg(feature = "scripting")]
1008            if let Some(ref engine) = lua_engine {
1009                for script_name in &route_scripts {
1010                    let mut lua_req = build_lua_request(&req);
1011                    match engine.call_route_on_request(script_name, &mut lua_req) {
1012                        RequestHookResult::Deny { status, body } => {
1013                            let resp_body = http_body_util::Full::new(Bytes::from(body)).boxed();
1014                            return Ok((
1015                                Response::builder().status(status).body(resp_body).unwrap(),
1016                                target_url,
1017                                route_scripts.clone(),
1018                            ));
1019                        }
1020                        RequestHookResult::Continue(_) => {}
1021                    }
1022                }
1023            }
1024
1025            // --- Lua on_route hook (global) ---
1026            #[cfg(feature = "scripting")]
1027            if let Some(ref engine) = lua_engine {
1028                if engine.has_on_route() {
1029                    let lua_req = build_lua_request(&req);
1030                    match engine.call_on_route(&lua_req, &target_url) {
1031                        RouteHookResult::Override(new_url) => {
1032                            target_url = new_url;
1033                        }
1034                        RouteHookResult::Default => {}
1035                    }
1036                }
1037                // Route-specific on_route hooks
1038                for script_name in &route_scripts {
1039                    let lua_req = build_lua_request(&req);
1040                    match engine.call_route_on_route(script_name, &lua_req, &target_url) {
1041                        RouteHookResult::Override(new_url) => {
1042                            target_url = new_url;
1043                        }
1044                        RouteHookResult::Default => {}
1045                    }
1046                }
1047            }
1048
1049            // Only extract host_header when needed (domain rules only)
1050            let host_header = if from_domain_rule {
1051                req.uri()
1052                    .host()
1053                    .or(req.headers().get("host").and_then(|h| h.to_str().ok()))
1054                    .map(|s| s.to_string())
1055            } else {
1056                None
1057            };
1058
1059            let (mut parts, body) = req.into_parts();
1060
1061            // Move headers directly instead of cloning one by one
1062            let uri: hyper::Uri = target_url.parse().expect("valid URI");
1063            parts.uri = uri;
1064            parts.version = http::Version::HTTP_11;
1065            parts.extensions = http::Extensions::new();
1066
1067            let mut request = Request::from_parts(parts, body);
1068
1069            request
1070                .headers_mut()
1071                .insert("X-Forwarded-For", X_FORWARDED_FOR_VALUE.clone());
1072
1073            if from_domain_rule {
1074                if let Some(host) = host_header {
1075                    request
1076                        .headers_mut()
1077                        .insert("X-Forwarded-Host", host.parse().unwrap());
1078                }
1079            }
1080
1081            match client.request(request).await {
1082                Ok(response) => {
1083                    // --- Circuit breaker: record success or failure ---
1084                    let status_code = response.status().as_u16();
1085                    if circuit_breaker.is_failure_status(status_code) {
1086                        circuit_breaker.record_failure(&base_url);
1087                    } else {
1088                        circuit_breaker.record_success(&base_url);
1089                    }
1090
1091                    // --- Lua on_response hooks (global + route) ---
1092                    #[cfg(feature = "scripting")]
1093                    if let Some(ref engine) = lua_engine {
1094                        let has_global = engine.has_on_response();
1095                        let has_route = !route_scripts.is_empty();
1096
1097                        if has_global || has_route {
1098                            use crate::scripting::ResponseMod;
1099
1100                            let lua_req = LuaRequest {
1101                                method: String::new(),
1102                                path: String::new(),
1103                                headers: std::collections::HashMap::new(),
1104                                host: String::new(),
1105                                content_length: 0,
1106                            };
1107                            let resp_headers = extract_response_headers(response.headers());
1108                            let resp_status = response.status().as_u16();
1109
1110                            // Collect all mods: global first, then route scripts
1111                            let mut all_mods: Vec<ResponseMod> = Vec::new();
1112                            if has_global {
1113                                all_mods.push(engine.call_on_response(
1114                                    &lua_req,
1115                                    resp_status,
1116                                    &resp_headers,
1117                                ));
1118                            }
1119                            for script_name in &route_scripts {
1120                                all_mods.push(engine.call_route_on_response(
1121                                    script_name,
1122                                    &lua_req,
1123                                    resp_status,
1124                                    &resp_headers,
1125                                ));
1126                            }
1127
1128                            // Merge all mods
1129                            let mut merged = ResponseMod::default();
1130                            for mods in all_mods {
1131                                merged.set_headers.extend(mods.set_headers);
1132                                merged.remove_headers.extend(mods.remove_headers);
1133                                if mods.replace_body.is_some() {
1134                                    merged.replace_body = mods.replace_body;
1135                                }
1136                                if mods.override_status.is_some() {
1137                                    merged.override_status = mods.override_status;
1138                                }
1139                            }
1140
1141                            // Apply modifications if any
1142                            if !merged.set_headers.is_empty()
1143                                || !merged.remove_headers.is_empty()
1144                                || merged.replace_body.is_some()
1145                                || merged.override_status.is_some()
1146                            {
1147                                let (mut parts, body) = response.into_parts();
1148
1149                                if let Some(status) = merged.override_status {
1150                                    parts.status =
1151                                        hyper::StatusCode::from_u16(status).unwrap_or(parts.status);
1152                                }
1153
1154                                for name in &merged.remove_headers {
1155                                    if let Ok(header_name) =
1156                                        name.parse::<hyper::header::HeaderName>()
1157                                    {
1158                                        parts.headers.remove(header_name);
1159                                    }
1160                                }
1161
1162                                for (name, value) in &merged.set_headers {
1163                                    if let (Ok(header_name), Ok(header_value)) = (
1164                                        name.parse::<hyper::header::HeaderName>(),
1165                                        value.parse::<HeaderValue>(),
1166                                    ) {
1167                                        parts.headers.insert(header_name, header_value);
1168                                    }
1169                                }
1170
1171                                if let Some(new_body) = merged.replace_body {
1172                                    let new_bytes = Bytes::from(new_body);
1173                                    parts.headers.remove("content-length");
1174                                    parts.headers.insert(
1175                                        "content-length",
1176                                        new_bytes.len().to_string().parse().unwrap(),
1177                                    );
1178                                    let boxed = http_body_util::Full::new(new_bytes).boxed();
1179                                    return Ok((
1180                                        Response::from_parts(parts, boxed),
1181                                        target_url,
1182                                        route_scripts.clone(),
1183                                    ));
1184                                }
1185
1186                                let boxed = body.map_err(|_| unreachable!()).boxed();
1187                                return Ok((
1188                                    Response::from_parts(parts, boxed),
1189                                    target_url,
1190                                    route_scripts.clone(),
1191                                ));
1192                            }
1193                        }
1194                    }
1195
1196                    let is_html = response
1197                        .headers()
1198                        .get("content-type")
1199                        .and_then(|v| v.to_str().ok())
1200                        .map(|ct| ct.starts_with("text/html"))
1201                        .unwrap_or(false);
1202
1203                    if is_html {
1204                        if let Some(prefix) = matched_prefix {
1205                            let (parts, body) = response.into_parts();
1206                            let body_bytes = body
1207                                .collect()
1208                                .await
1209                                .map(|collected| collected.to_bytes())
1210                                .unwrap_or_default();
1211
1212                            // Decompress gzip/deflate body before rewriting
1213                            let is_gzip = parts
1214                                .headers
1215                                .get("content-encoding")
1216                                .and_then(|v| v.to_str().ok())
1217                                .map(|v| v.contains("gzip"))
1218                                .unwrap_or(false);
1219                            let is_deflate = parts
1220                                .headers
1221                                .get("content-encoding")
1222                                .and_then(|v| v.to_str().ok())
1223                                .map(|v| v.contains("deflate"))
1224                                .unwrap_or(false);
1225
1226                            let raw_bytes = if is_gzip {
1227                                use std::io::Read;
1228                                let mut decoder = flate2::read::GzDecoder::new(&body_bytes[..]);
1229                                let mut decoded = Vec::new();
1230                                decoder.read_to_end(&mut decoded).unwrap_or_default();
1231                                Bytes::from(decoded)
1232                            } else if is_deflate {
1233                                use std::io::Read;
1234                                let mut decoder =
1235                                    flate2::read::DeflateDecoder::new(&body_bytes[..]);
1236                                let mut decoded = Vec::new();
1237                                decoder.read_to_end(&mut decoded).unwrap_or_default();
1238                                Bytes::from(decoded)
1239                            } else {
1240                                body_bytes
1241                            };
1242
1243                            let html = String::from_utf8_lossy(&raw_bytes);
1244                            let rewritten = html
1245                                .replace("href=\"/", &format!("href=\"{}/", prefix))
1246                                .replace("src=\"/", &format!("src=\"{}/", prefix))
1247                                .replace("action=\"/", &format!("action=\"{}/", prefix));
1248                            let rewritten_bytes = Bytes::from(rewritten);
1249                            let mut parts = parts;
1250                            parts.headers.remove("content-encoding");
1251                            parts.headers.remove("content-length");
1252                            parts.headers.insert(
1253                                "content-length",
1254                                rewritten_bytes.len().to_string().parse().unwrap(),
1255                            );
1256                            let boxed = http_body_util::Full::new(rewritten_bytes).boxed();
1257                            return Ok((
1258                                Response::from_parts(parts, boxed),
1259                                target_url,
1260                                route_scripts.clone(),
1261                            ));
1262                        }
1263                    }
1264
1265                    let (parts, body) = response.into_parts();
1266                    let boxed = body.map_err(|_| unreachable!()).boxed();
1267                    Ok((
1268                        Response::from_parts(parts, boxed),
1269                        target_url,
1270                        route_scripts,
1271                    ))
1272                }
1273                Err(e) => {
1274                    circuit_breaker.record_failure(&base_url);
1275                    tracing::error!("Backend request failed: {} (target: {})", e, target_url);
1276                    let body = http_body_util::Full::new(Bytes::from("Bad Gateway")).boxed();
1277                    Ok((
1278                        Response::builder()
1279                            .status(502)
1280                            .body(body)
1281                            .expect("Failed to build response"),
1282                        target_url,
1283                        route_scripts,
1284                    ))
1285                }
1286            }
1287        }
1288        None => {
1289            // Suppress unused variable warning when scripting feature is disabled
1290            let _ = lua_engine;
1291            let body = http_body_util::Full::new(Bytes::from("Misdirected Request")).boxed();
1292            Ok((
1293                Response::builder()
1294                    .status(421)
1295                    .body(body)
1296                    .expect("Failed to build response"),
1297                String::new(),
1298                vec![],
1299            ))
1300        }
1301    }
1302}
1303
/// Strategy for turning a matched rule's target into the outgoing request URL.
///
/// `find_matching_rule` picks the strategy from the kind of rule that matched.
enum UrlResolution {
    /// Forward to the target URL unchanged; the request path is not used.
    /// Assigned to `Exact`, `Regex` and `Default` rules.
    Identity,
    /// Join the full request path onto the target URL.
    /// Assigned to `Domain` rules.
    AppendPath,
    /// Remove the stored prefix from the request path and join the remainder onto
    /// the target URL. Assigned to `DomainPath` and `Prefix` rules.
    StripPrefix(String),
}
1313
1314/// A matched routing rule with all the info needed to resolve a target URL
1315struct MatchedRoute<'a> {
1316    targets: &'a [crate::config::Target],
1317    from_domain_rule: bool,
1318    resolution: UrlResolution,
1319    route_scripts: Vec<String>,
1320    auth: Vec<crate::auth::BasicAuth>,
1321    load_balancing: &'a crate::config::LoadBalancingStrategy,
1322}
1323
1324impl<'a> MatchedRoute<'a> {
1325    fn matched_prefix(&self) -> Option<String> {
1326        match &self.resolution {
1327            UrlResolution::StripPrefix(prefix) => Some(prefix.trim_end_matches('/').to_string()),
1328            _ => None,
1329        }
1330    }
1331}
1332
1333/// Resolve a target URL based on the resolution strategy
1334fn resolve_target_url(
1335    target: &crate::config::Target,
1336    path: &str,
1337    resolution: &UrlResolution,
1338) -> String {
1339    let target_str = target.url.as_str();
1340    match resolution {
1341        UrlResolution::AppendPath => {
1342            if target_str.ends_with('/') {
1343                format!("{}{}", target_str, &path[1..])
1344            } else {
1345                format!("{}{}", target_str, path)
1346            }
1347        }
1348        UrlResolution::StripPrefix(prefix) => {
1349            let suffix = if path.len() >= prefix.len() {
1350                &path[prefix.len()..]
1351            } else {
1352                ""
1353            };
1354            format!("{}{}", target_str, suffix)
1355        }
1356        UrlResolution::Identity => target_str.to_owned(),
1357    }
1358}
1359
1360/// Pure routing: find which rule matches the request
1361fn find_matching_rule<'a>(
1362    req: &Request<Incoming>,
1363    rules: &'a [crate::config::ProxyRule],
1364) -> Option<MatchedRoute<'a>> {
1365    let host = req
1366        .uri()
1367        .host()
1368        .or(req.headers().get("host").and_then(|h| h.to_str().ok()))
1369        .map(|h| h.split(':').next().unwrap_or(h))?;
1370
1371    let path = req.uri().path();
1372
1373    for rule in rules {
1374        match &rule.matcher {
1375            crate::config::RuleMatcher::Domain(domain) => {
1376                if domain == host && !rule.targets.is_empty() {
1377                    return Some(MatchedRoute {
1378                        targets: &rule.targets,
1379                        from_domain_rule: true,
1380                        resolution: UrlResolution::AppendPath,
1381                        route_scripts: rule.scripts.clone(),
1382                        auth: rule.auth.clone(),
1383                        load_balancing: &rule.load_balancing,
1384                    });
1385                }
1386            }
1387            crate::config::RuleMatcher::DomainPath(domain, path_prefix) => {
1388                if domain == host && !rule.targets.is_empty() {
1389                    let matches = path.starts_with(path_prefix)
1390                        || (path_prefix.ends_with('/')
1391                            && path == path_prefix.trim_end_matches('/'));
1392                    if matches {
1393                        return Some(MatchedRoute {
1394                            targets: &rule.targets,
1395                            from_domain_rule: true,
1396                            resolution: UrlResolution::StripPrefix(path_prefix.clone()),
1397                            route_scripts: rule.scripts.clone(),
1398                            auth: rule.auth.clone(),
1399                            load_balancing: &rule.load_balancing,
1400                        });
1401                    }
1402                }
1403            }
1404            _ => {}
1405        }
1406    }
1407
1408    // Check specific rules (Exact, Prefix, Regex) before Default
1409    for rule in rules {
1410        match &rule.matcher {
1411            crate::config::RuleMatcher::Exact(exact) => {
1412                if path == exact && !rule.targets.is_empty() {
1413                    return Some(MatchedRoute {
1414                        targets: &rule.targets,
1415                        from_domain_rule: false,
1416                        resolution: UrlResolution::Identity,
1417                        route_scripts: rule.scripts.clone(),
1418                        auth: rule.auth.clone(),
1419                        load_balancing: &rule.load_balancing,
1420                    });
1421                }
1422            }
1423            crate::config::RuleMatcher::Prefix(prefix) => {
1424                if !rule.targets.is_empty() {
1425                    // Match /db against prefix /db/ (path without trailing slash)
1426                    let matches = path.starts_with(prefix)
1427                        || (prefix.ends_with('/') && path == prefix.trim_end_matches('/'));
1428                    if matches {
1429                        return Some(MatchedRoute {
1430                            targets: &rule.targets,
1431                            from_domain_rule: false,
1432                            resolution: UrlResolution::StripPrefix(prefix.clone()),
1433                            route_scripts: rule.scripts.clone(),
1434                            auth: rule.auth.clone(),
1435                            load_balancing: &rule.load_balancing,
1436                        });
1437                    }
1438                }
1439            }
1440            crate::config::RuleMatcher::Regex(ref rm) => {
1441                if rm.is_match(path) && !rule.targets.is_empty() {
1442                    return Some(MatchedRoute {
1443                        targets: &rule.targets,
1444                        from_domain_rule: false,
1445                        resolution: UrlResolution::Identity,
1446                        route_scripts: rule.scripts.clone(),
1447                        auth: rule.auth.clone(),
1448                        load_balancing: &rule.load_balancing,
1449                    });
1450                }
1451            }
1452            _ => {}
1453        }
1454    }
1455
1456    // Fall back to Default rule
1457    for rule in rules {
1458        if let crate::config::RuleMatcher::Default = &rule.matcher {
1459            if !rule.targets.is_empty() {
1460                return Some(MatchedRoute {
1461                    targets: &rule.targets,
1462                    from_domain_rule: false,
1463                    resolution: UrlResolution::Identity,
1464                    route_scripts: rule.scripts.clone(),
1465                    auth: rule.auth.clone(),
1466                    load_balancing: &rule.load_balancing,
1467                });
1468            }
1469        }
1470    }
1471
1472    None
1473}
1474
1475/// Select a target based on the load balancing strategy.
1476/// Returns (resolved_url, base_url) for logging and circuit breaker tracking.
1477fn select_target(
1478    route: &MatchedRoute<'_>,
1479    path: &str,
1480    circuit_breaker: &crate::circuit_breaker::CircuitBreaker,
1481    load_balancer: &LoadBalancerState,
1482) -> Option<(String, String)> {
1483    let targets = route.targets;
1484    if targets.is_empty() {
1485        return None;
1486    }
1487
1488    match route.load_balancing {
1489        crate::config::LoadBalancingStrategy::Failover => {
1490            // Failover: use first available target (circuit breaker aware)
1491            for target in targets {
1492                let base_url = target.url.as_str().to_owned();
1493                if circuit_breaker.is_available(&base_url) {
1494                    let resolved = resolve_target_url(target, path, &route.resolution);
1495                    return Some((resolved, base_url));
1496                }
1497            }
1498            None
1499        }
1500        crate::config::LoadBalancingStrategy::RoundRobin => {
1501            // Round-robin: cycle through all targets, skip unhealthy ones
1502            let num_targets = targets.len();
1503            let start_idx = load_balancer.counters[0].load(Ordering::Relaxed) % num_targets;
1504
1505            for i in 0..num_targets {
1506                let idx = (start_idx + i) % num_targets;
1507                let target = &targets[idx];
1508                let base_url = target.url.as_str().to_owned();
1509                if circuit_breaker.is_available(&base_url) {
1510                    load_balancer.counters[0].fetch_add(1, Ordering::Relaxed);
1511                    let resolved = resolve_target_url(target, path, &route.resolution);
1512                    return Some((resolved, base_url));
1513                }
1514            }
1515            None
1516        }
1517        crate::config::LoadBalancingStrategy::Weighted => {
1518            // Weighted: use weights to determine distribution, skip unhealthy
1519            let total_weight: u32 = targets.iter().map(|t| t.weight as u32).sum();
1520            if total_weight == 0 {
1521                return select_target(route, path, circuit_breaker, &LoadBalancerState::new(1));
1522            }
1523
1524            let start_idx =
1525                (load_balancer.counters[0].load(Ordering::Relaxed) % total_weight as usize) as u32;
1526            let mut cumulative = 0u32;
1527
1528            for target in targets.iter() {
1529                cumulative += target.weight as u32;
1530                let base_url = target.url.as_str().to_owned();
1531                if cumulative > start_idx && circuit_breaker.is_available(&base_url) {
1532                    load_balancer.counters[0].fetch_add(1, Ordering::Relaxed);
1533                    let resolved = resolve_target_url(target, path, &route.resolution);
1534                    return Some((resolved, base_url));
1535                }
1536            }
1537
1538            // Fallback: try any available target
1539            for target in targets {
1540                let base_url = target.url.as_str().to_owned();
1541                if circuit_breaker.is_available(&base_url) {
1542                    let resolved = resolve_target_url(target, path, &route.resolution);
1543                    return Some((resolved, base_url));
1544                }
1545            }
1546            None
1547        }
1548    }
1549}
1550
1551/// Backward-compatible wrapper: returns (target_url, from_domain_rule, matched_prefix, route_scripts)
1552fn find_target(
1553    req: &Request<Incoming>,
1554    rules: &[crate::config::ProxyRule],
1555) -> Option<(String, bool, Option<String>, Vec<String>)> {
1556    let route = find_matching_rule(req, rules)?;
1557    let path = req.uri().path();
1558    let target = route.targets.first()?;
1559    let resolved = resolve_target_url(target, path, &route.resolution);
1560    let matched_prefix = route.matched_prefix();
1561    Some((
1562        resolved,
1563        route.from_domain_rule,
1564        matched_prefix,
1565        route.route_scripts,
1566    ))
1567}
1568
#[cfg(test)]
mod tests {
    use super::*;

    /// Successive selections for one rule cycle 0, 1, 2 and then wrap back to 0.
    #[test]
    fn test_load_balancer_state_select_index() {
        let state = LoadBalancerState::new(1);
        for expected in [0usize, 1, 2, 0].iter().copied() {
            assert_eq!(state.select_index(0, 3), expected);
        }
    }

    /// With no targets the selector must not divide by zero; it yields index 0.
    #[test]
    fn test_load_balancer_state_zero_targets() {
        let state = LoadBalancerState::new(1);
        let picked = state.select_index(0, 0);
        assert_eq!(picked, 0);
    }
}