1#![allow(clippy::let_unit_value, clippy::clone_on_copy, clippy::unit_arg)]
3
4use crate::acme::ChallengeStore;
5use crate::app::AppManager;
6use crate::auth;
7use crate::circuit_breaker::SharedCircuitBreaker;
8use crate::config::ConfigManager;
9use crate::metrics::SharedMetrics;
10use crate::shutdown::ShutdownCoordinator;
11use anyhow::Result;
12use bytes::Bytes;
13use http_body_util::BodyExt;
14use hyper::body::Incoming;
15use hyper::header::HeaderValue;
16use hyper::service::service_fn;
17use hyper::Request;
18use hyper::Response;
19use hyper_util::client::legacy::connect::HttpConnector;
20use hyper_util::client::legacy::Client;
21use hyper_util::rt::TokioExecutor;
22use hyper_util::rt::TokioIo;
23use socket2::{Domain, Protocol, Socket, Type};
24use std::net::SocketAddr;
25use std::sync::atomic::{AtomicUsize, Ordering};
26use std::sync::Arc;
27use std::time::Duration;
28use tokio::io::AsyncWriteExt;
29use tokio::net::{TcpListener, TcpStream};
30use tokio_rustls::TlsAcceptor;
31
32#[cfg(feature = "scripting")]
33use crate::scripting::{LuaEngine, LuaRequest, RequestHookResult, RouteHookResult};
34
35type ClientType = Client<HttpConnector, Incoming>;
36type BoxBody = http_body_util::combinators::BoxBody<Bytes, std::convert::Infallible>;
37
38#[cfg(feature = "scripting")]
39type OptionalLuaEngine = Option<LuaEngine>;
40#[cfg(not(feature = "scripting"))]
41type OptionalLuaEngine = ();
42
/// Per-rule round-robin counters used to spread requests across a rule's targets.
pub struct LoadBalancerState {
    /// One counter per routing rule, indexed by the rule's position.
    counters: Vec<AtomicUsize>,
}

impl LoadBalancerState {
    /// Creates state with one zeroed counter for each of `num_rules` rules.
    pub fn new(num_rules: usize) -> Self {
        Self {
            counters: (0..num_rules).map(|_| AtomicUsize::new(0)).collect(),
        }
    }

    /// Returns the next round-robin target index in `0..num_targets` for rule `rule_idx`.
    ///
    /// Returns `0` when `num_targets` is zero. It also returns `0`, instead of
    /// panicking on an out-of-bounds index, when `rule_idx` has no counter. That
    /// happens if the caller's rule list is longer than the `num_rules` this
    /// state was built with.
    pub fn select_index(&self, rule_idx: usize, num_targets: usize) -> usize {
        if num_targets == 0 {
            return 0;
        }
        // Relaxed is enough: the counter only spreads load and does not order
        // any other memory accesses. `fetch_add` wraps on overflow, so this
        // never panics.
        self.counters
            .get(rule_idx)
            .map_or(0, |counter| counter.fetch_add(1, Ordering::Relaxed) % num_targets)
    }
}
61
62fn record_app_metrics(
64 metrics: &SharedMetrics,
65 app_manager: &Option<Arc<AppManager>>,
66 target_url: &str,
67 bytes_in: u64,
68 bytes_out: u64,
69 status: u16,
70 duration: Duration,
71) {
72 if let Some(ref manager) = app_manager {
73 if let Ok(url) = url::Url::parse(target_url) {
74 if let Some(port) = url.port() {
75 if let Some(app_name) = futures::executor::block_on(manager.get_app_name(port)) {
76 metrics.record_app_request(&app_name, bytes_in, bytes_out, status, duration);
77 }
78 }
79 }
80 }
81}
82
83static X_FORWARDED_FOR_VALUE: std::sync::LazyLock<HeaderValue> =
85 std::sync::LazyLock::new(|| HeaderValue::from_static("127.0.0.1"));
86
87fn verify_basic_auth(req: &Request<Incoming>, auth_entries: &[crate::auth::BasicAuth]) -> bool {
90 if auth_entries.is_empty() {
91 return true;
92 }
93
94 let auth_header = req.headers().get("authorization");
95 if auth_header.is_none() {
96 return false;
97 }
98
99 let header_value = auth_header.unwrap().to_str().unwrap_or("");
100 if !header_value.starts_with("Basic ") {
101 return false;
102 }
103
104 let encoded = &header_value[6..];
105 let decoded = base64::Engine::decode(&base64::engine::general_purpose::STANDARD, encoded)
106 .unwrap_or_default();
107 let creds = String::from_utf8_lossy(&decoded);
108
109 if let Some((username, password)) = creds.split_once(':') {
110 for entry in auth_entries {
111 if entry.username == username && auth::verify_password(password, &entry.hash) {
112 return true;
113 }
114 }
115 }
116
117 false
118}
119
120fn create_auth_required_response() -> Response<BoxBody> {
122 let body = http_body_util::Full::new(Bytes::from("Authentication required")).boxed();
123 Response::builder()
124 .status(401)
125 .header("WWW-Authenticate", "Basic realm=\"Restricted\"")
126 .body(body)
127 .unwrap()
128}
129
130fn create_listener(addr: SocketAddr) -> Result<TcpListener> {
131 let domain = if addr.is_ipv4() {
132 Domain::IPV4
133 } else {
134 Domain::IPV6
135 };
136 let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
137 socket.set_reuse_address(true)?;
138 socket.set_reuse_port(true)?;
139 socket.set_nonblocking(true)?;
140 socket.bind(&addr.into())?;
141 socket.listen(8192)?;
142 let std_listener: std::net::TcpListener = socket.into();
143 Ok(TcpListener::from_std(std_listener)?)
144}
145
146fn create_client() -> ClientType {
147 let exec = TokioExecutor::new();
148 let mut connector = HttpConnector::new();
149 connector.set_nodelay(true);
150 connector.set_keepalive(Some(Duration::from_secs(30)));
151 connector.set_connect_timeout(Some(Duration::from_secs(5)));
152 Client::builder(exec)
153 .pool_max_idle_per_host(256)
154 .pool_idle_timeout(Duration::from_secs(60))
155 .build(connector)
156}
157
158pub struct ProxyServer {
159 config: Arc<ConfigManager>,
160 shutdown: ShutdownCoordinator,
161 tls_acceptor: Option<TlsAcceptor>,
162 https_addr: Option<SocketAddr>,
163 metrics: SharedMetrics,
164 challenge_store: ChallengeStore,
165 lua_engine: OptionalLuaEngine,
166 circuit_breaker: SharedCircuitBreaker,
167 app_manager: Option<Arc<AppManager>>,
168 load_balancer: Arc<LoadBalancerState>,
169}
170
171impl ProxyServer {
172 pub fn new(
173 config: Arc<ConfigManager>,
174 shutdown: ShutdownCoordinator,
175 metrics: SharedMetrics,
176 challenge_store: ChallengeStore,
177 lua_engine: OptionalLuaEngine,
178 circuit_breaker: SharedCircuitBreaker,
179 app_manager: Option<Arc<AppManager>>,
180 ) -> Result<Self> {
181 let num_rules = config.get_config().rules.len();
182 Ok(Self {
183 config,
184 shutdown,
185 tls_acceptor: None,
186 https_addr: None,
187 metrics,
188 challenge_store,
189 lua_engine,
190 circuit_breaker,
191 app_manager,
192 load_balancer: Arc::new(LoadBalancerState::new(num_rules)),
193 })
194 }
195
196 #[allow(clippy::too_many_arguments)]
197 pub fn with_https(
198 config: Arc<ConfigManager>,
199 shutdown: ShutdownCoordinator,
200 tls_acceptor: TlsAcceptor,
201 https_addr: SocketAddr,
202 metrics: SharedMetrics,
203 challenge_store: ChallengeStore,
204 lua_engine: OptionalLuaEngine,
205 circuit_breaker: SharedCircuitBreaker,
206 app_manager: Option<Arc<AppManager>>,
207 ) -> Result<Self> {
208 let num_rules = config.get_config().rules.len();
209 Ok(Self {
210 config,
211 shutdown,
212 tls_acceptor: Some(tls_acceptor),
213 https_addr: Some(https_addr),
214 metrics,
215 challenge_store,
216 lua_engine,
217 circuit_breaker,
218 app_manager,
219 load_balancer: Arc::new(LoadBalancerState::new(num_rules)),
220 })
221 }
222
223 pub async fn run(&self) -> Result<()> {
224 let cfg = self.config.get_config();
225 let http_addr: SocketAddr = cfg.server.bind.parse()?;
226 let https_addr = self.https_addr;
227
228 let has_https = https_addr.is_some();
229 let num_listeners = std::thread::available_parallelism()
230 .map(|n| n.get())
231 .unwrap_or(4);
232
233 let app_manager = self.app_manager.clone();
236 for i in 0..num_listeners {
237 let config_clone = self.config.clone();
238 let shutdown_clone = self.shutdown.clone();
239 let metrics_clone = self.metrics.clone();
240 let challenge_store_clone = self.challenge_store.clone();
241 let lua_clone = self.lua_engine.clone();
242 let cb_clone = self.circuit_breaker.clone();
243 let am_clone = app_manager.clone();
244 let lb_clone = self.load_balancer.clone();
245
246 tokio::spawn(async move {
247 if let Err(e) = run_http_server(
248 http_addr,
249 config_clone,
250 shutdown_clone,
251 metrics_clone,
252 challenge_store_clone,
253 lua_clone,
254 cb_clone,
255 am_clone,
256 lb_clone,
257 )
258 .await
259 {
260 tracing::error!("HTTP/1.1 server error (listener {}): {}", i, e);
261 }
262 });
263 }
264
265 if let Some(https_addr) = https_addr {
266 for i in 0..num_listeners {
267 let config_clone = self.config.clone();
268 let shutdown_clone = self.shutdown.clone();
269 let acceptor = self.tls_acceptor.as_ref().unwrap().clone();
270 let metrics_clone = self.metrics.clone();
271 let challenge_store_clone = self.challenge_store.clone();
272 let lua_clone = self.lua_engine.clone();
273 let cb_clone = self.circuit_breaker.clone();
274 let am_clone = app_manager.clone();
275 let lb_clone = self.load_balancer.clone();
276
277 tokio::spawn(async move {
278 if let Err(e) = run_https_server(
279 https_addr,
280 config_clone,
281 shutdown_clone,
282 acceptor,
283 metrics_clone,
284 challenge_store_clone,
285 lua_clone,
286 cb_clone,
287 am_clone,
288 lb_clone,
289 )
290 .await
291 {
292 tracing::error!("HTTPS/2 server error (listener {}): {}", i, e);
293 }
294 });
295 }
296 }
297
298 tracing::info!(
299 "HTTP/1.1 server listening on {} ({} accept loops)",
300 http_addr,
301 num_listeners
302 );
303 if has_https {
304 tracing::info!(
305 "HTTPS/2 server listening on {} ({} accept loops)",
306 https_addr.unwrap(),
307 num_listeners
308 );
309 }
310
311 loop {
312 if self.shutdown.is_shutting_down() {
313 tracing::info!("Shutting down servers...");
314 break;
315 }
316 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
317 }
318
319 Ok(())
320 }
321}
322
323#[allow(clippy::too_many_arguments)]
324async fn run_http_server(
325 addr: SocketAddr,
326 config: Arc<ConfigManager>,
327 shutdown: ShutdownCoordinator,
328 metrics: SharedMetrics,
329 challenge_store: ChallengeStore,
330 lua_engine: OptionalLuaEngine,
331 circuit_breaker: SharedCircuitBreaker,
332 app_manager: Option<Arc<AppManager>>,
333 load_balancer: Arc<LoadBalancerState>,
334) -> Result<()> {
335 let listener = create_listener(addr)?;
336 let client = create_client();
337
338 loop {
339 if shutdown.is_shutting_down() {
340 break;
341 }
342
343 match listener.accept().await {
344 Ok((stream, _)) => {
345 let _ = stream.set_nodelay(true);
346 let client = client.clone();
347 let config = config.clone();
348 let metrics = metrics.clone();
349 let cs = challenge_store.clone();
350 let lua = lua_engine.clone();
351 let cb = circuit_breaker.clone();
352 let am = app_manager.clone();
353 let lb = load_balancer.clone();
354 tokio::spawn(async move {
355 if let Err(e) = handle_http11_connection(
356 stream, client, config, metrics, cs, lua, cb, am, lb,
357 )
358 .await
359 {
360 tracing::debug!("HTTP/1.1 connection error: {}", e);
361 }
362 });
363 }
364 Err(e) => {
365 tracing::error!("HTTP/1.1 accept error: {}", e);
366 }
367 }
368 }
369
370 Ok(())
371}
372
373#[allow(clippy::too_many_arguments)]
374async fn run_https_server(
375 addr: SocketAddr,
376 config: Arc<ConfigManager>,
377 shutdown: ShutdownCoordinator,
378 acceptor: TlsAcceptor,
379 metrics: SharedMetrics,
380 challenge_store: ChallengeStore,
381 lua_engine: OptionalLuaEngine,
382 circuit_breaker: SharedCircuitBreaker,
383 app_manager: Option<Arc<AppManager>>,
384 load_balancer: Arc<LoadBalancerState>,
385) -> Result<()> {
386 let listener = create_listener(addr)?;
387 let client = create_client();
388
389 loop {
390 if shutdown.is_shutting_down() {
391 break;
392 }
393
394 match listener.accept().await {
395 Ok((stream, _)) => {
396 let _ = stream.set_nodelay(true);
397 let client = client.clone();
398 let config = config.clone();
399 let acceptor = acceptor.clone();
400 let metrics = metrics.clone();
401 let cs = challenge_store.clone();
402 let lua = lua_engine.clone();
403 let cb = circuit_breaker.clone();
404 let am = app_manager.clone();
405 let lb = load_balancer.clone();
406 tokio::spawn(async move {
407 match acceptor.accept(stream).await {
408 Ok(tls_stream) => {
409 metrics.inc_tls_connections();
410 if let Err(e) = handle_https2_connection(
411 tls_stream, client, config, metrics, cs, lua, cb, am, lb,
412 )
413 .await
414 {
415 tracing::debug!("HTTPS/2 connection error: {}", e);
416 }
417 }
418 Err(e) => {
419 tracing::error!("TLS accept error: {}", e);
420 }
421 }
422 });
423 }
424 Err(e) => {
425 tracing::error!("HTTPS/2 accept error: {}", e);
426 }
427 }
428 }
429
430 Ok(())
431}
432
433#[allow(clippy::too_many_arguments)]
434async fn handle_http11_connection(
435 stream: tokio::net::TcpStream,
436 client: ClientType,
437 config: Arc<ConfigManager>,
438 metrics: SharedMetrics,
439 challenge_store: ChallengeStore,
440 lua_engine: OptionalLuaEngine,
441 circuit_breaker: SharedCircuitBreaker,
442 app_manager: Option<Arc<AppManager>>,
443 load_balancer: Arc<LoadBalancerState>,
444) -> Result<()> {
445 let io = TokioIo::new(stream);
446 let svc = service_fn(move |req| {
447 handle_request(
448 req,
449 client.clone(),
450 config.clone(),
451 metrics.clone(),
452 challenge_store.clone(),
453 lua_engine.clone(),
454 circuit_breaker.clone(),
455 app_manager.clone(),
456 load_balancer.clone(),
457 )
458 });
459
460 let conn = hyper::server::conn::http1::Builder::new()
461 .keep_alive(true)
462 .pipeline_flush(true)
463 .serve_connection(io, svc)
464 .with_upgrades();
465
466 if let Err(e) = conn.await {
467 tracing::debug!("HTTP/1.1 connection error: {}", e);
468 }
469
470 Ok(())
471}
472
473#[allow(clippy::too_many_arguments)]
474async fn handle_https2_connection(
475 stream: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
476 client: ClientType,
477 config: Arc<ConfigManager>,
478 metrics: SharedMetrics,
479 challenge_store: ChallengeStore,
480 lua_engine: OptionalLuaEngine,
481 circuit_breaker: SharedCircuitBreaker,
482 app_manager: Option<Arc<AppManager>>,
483 load_balancer: Arc<LoadBalancerState>,
484) -> Result<()> {
485 let is_h2 = stream.get_ref().1.alpn_protocol() == Some(b"h2");
486
487 let io = TokioIo::new(stream);
488
489 if is_h2 {
490 let exec = TokioExecutor::new();
491 let svc = service_fn(move |req| {
492 handle_request(
493 req,
494 client.clone(),
495 config.clone(),
496 metrics.clone(),
497 challenge_store.clone(),
498 lua_engine.clone(),
499 circuit_breaker.clone(),
500 app_manager.clone(),
501 load_balancer.clone(),
502 )
503 });
504 let conn = hyper::server::conn::http2::Builder::new(exec)
505 .initial_stream_window_size(1024 * 1024)
506 .initial_connection_window_size(2 * 1024 * 1024)
507 .max_concurrent_streams(250)
508 .serve_connection(io, svc);
509 if let Err(e) = conn.await {
510 tracing::debug!("HTTPS/2 connection error: {}", e);
511 }
512 } else {
513 let svc = service_fn(move |req| {
514 handle_request(
515 req,
516 client.clone(),
517 config.clone(),
518 metrics.clone(),
519 challenge_store.clone(),
520 lua_engine.clone(),
521 circuit_breaker.clone(),
522 app_manager.clone(),
523 load_balancer.clone(),
524 )
525 });
526 let conn = hyper::server::conn::http1::Builder::new()
527 .keep_alive(true)
528 .pipeline_flush(true)
529 .serve_connection(io, svc)
530 .with_upgrades();
531 if let Err(e) = conn.await {
532 tracing::debug!("HTTPS/1.1 connection error: {}", e);
533 }
534 }
535
536 Ok(())
537}
538
/// Copies the request headers into a `lowercase name -> value` map for Lua hooks.
///
/// Values that are not visible ASCII become empty strings. When a header name
/// repeats, the last value wins.
#[cfg(feature = "scripting")]
fn extract_headers(req: &Request<Incoming>) -> std::collections::HashMap<String, String> {
    req.headers().iter().fold(
        std::collections::HashMap::new(),
        |mut acc, (name, value)| {
            let text = value.to_str().unwrap_or("");
            acc.insert(name.as_str().to_lowercase(), text.to_string());
            acc
        },
    )
}
552
/// Snapshots the parts of `req` that Lua hooks can inspect.
///
/// The host comes from the request URI's authority when present, otherwise
/// from the `Host` header, otherwise it is empty. `content_length` is the parsed
/// `Content-Length` header, or 0 when the header is absent or unparsable.
#[cfg(feature = "scripting")]
fn build_lua_request(req: &Request<Incoming>) -> LuaRequest {
    let headers = req.headers();

    let host = match req.uri().host() {
        Some(uri_host) => uri_host.to_string(),
        None => headers
            .get("host")
            .and_then(|h| h.to_str().ok())
            .unwrap_or("")
            .to_string(),
    };

    let content_length = headers
        .get("content-length")
        .and_then(|raw| raw.to_str().ok()?.parse().ok())
        .unwrap_or(0);

    LuaRequest {
        method: req.method().to_string(),
        path: req.uri().path().to_string(),
        headers: extract_headers(req),
        host,
        content_length,
    }
}
578
/// Copies upstream response headers into a `lowercase name -> value` map for Lua hooks.
///
/// Values that are not visible ASCII become empty strings. When a header name
/// repeats, the last value wins.
#[cfg(feature = "scripting")]
fn extract_response_headers(
    headers: &hyper::HeaderMap,
) -> std::collections::HashMap<String, String> {
    let mut out = std::collections::HashMap::with_capacity(headers.len());
    for (name, value) in headers.iter() {
        let text = value.to_str().unwrap_or("");
        out.insert(name.as_str().to_lowercase(), text.to_owned());
    }
    out
}
594
595#[allow(clippy::too_many_arguments)]
596async fn handle_request(
597 req: Request<Incoming>,
598 client: ClientType,
599 config_manager: Arc<ConfigManager>,
600 metrics: SharedMetrics,
601 challenge_store: ChallengeStore,
602 lua_engine: OptionalLuaEngine,
603 circuit_breaker: SharedCircuitBreaker,
604 app_manager: Option<Arc<AppManager>>,
605 load_balancer: Arc<LoadBalancerState>,
606) -> Result<Response<BoxBody>, hyper::Error> {
607 let start_time = std::time::Instant::now();
608 metrics.inc_in_flight();
609 let config = config_manager.get_config();
610
611 if let Some(response) = handle_acme_challenge(&req, &challenge_store) {
613 metrics.dec_in_flight();
614 return Ok(response);
615 }
616
617 if is_metrics_request(&req) {
618 let duration = start_time.elapsed();
619 metrics.dec_in_flight();
620 let metrics_output = metrics.format_metrics();
621 metrics.record_request(0, metrics_output.len() as u64, 200, duration);
622 let body = http_body_util::Full::new(Bytes::from(metrics_output)).boxed();
623 return Ok(Response::builder()
624 .status(200)
625 .header("Content-Type", "text/plain")
626 .body(body)
627 .unwrap());
628 }
629
630 #[cfg(feature = "scripting")]
632 if let Some(ref engine) = lua_engine {
633 if engine.has_on_request() {
634 let mut lua_req = build_lua_request(&req);
635 match engine.call_on_request(&mut lua_req) {
636 RequestHookResult::Deny { status, body } => {
637 metrics.dec_in_flight();
638 let duration = start_time.elapsed();
639 metrics.record_request(0, body.len() as u64, status, duration);
640 let resp_body = http_body_util::Full::new(Bytes::from(body)).boxed();
641 return Ok(Response::builder().status(status).body(resp_body).unwrap());
642 }
643 RequestHookResult::Continue(updated_req) => {
644 let _ = updated_req;
650 }
651 }
652 }
653 }
654
655 let is_websocket = is_websocket_request(&req);
656
657 if is_websocket {
658 return handle_websocket_request(
659 req,
660 client,
661 &config,
662 &metrics,
663 start_time,
664 app_manager.clone(),
665 )
666 .await;
667 }
668
669 let result = handle_regular_request(
670 req,
671 client,
672 &config,
673 &lua_engine,
674 &circuit_breaker,
675 app_manager.clone(),
676 load_balancer.clone(),
677 )
678 .await;
679 let duration = start_time.elapsed();
680
681 metrics.dec_in_flight();
682
683 match result {
684 #[allow(unused_variables)]
685 Ok((response, _target_url, route_scripts)) => {
686 let status = response.status().as_u16();
687
688 #[cfg(feature = "scripting")]
690 if let Some(ref engine) = lua_engine {
691 let lua_req = LuaRequest {
692 method: String::new(),
693 path: String::new(),
694 headers: std::collections::HashMap::new(),
695 host: String::new(),
696 content_length: 0,
697 };
698 let duration_ms = duration.as_secs_f64() * 1000.0;
699
700 if engine.has_on_request_end() {
702 engine.call_on_request_end(&lua_req, status, duration_ms, &_target_url);
703 }
704
705 for script_name in &route_scripts {
707 engine.call_route_on_request_end(
708 script_name,
709 &lua_req,
710 status,
711 duration_ms,
712 &_target_url,
713 );
714 }
715 }
716
717 metrics.record_request(0, 0, status, duration);
718 record_app_metrics(&metrics, &app_manager, &_target_url, 0, 0, status, duration);
719 let (parts, body) = response.into_parts();
720 let boxed = body.map_err(|_| unreachable!()).boxed();
721 Ok(Response::from_parts(parts, boxed))
722 }
723 Err(e) => {
724 metrics.inc_errors();
725 Err(e)
726 }
727 }
728}
729
730fn is_websocket_request(req: &Request<Incoming>) -> bool {
731 if let Some(upgrade) = req.headers().get("upgrade") {
732 if upgrade == "websocket" {
733 return true;
734 }
735 }
736 false
737}
738
739fn is_metrics_request(req: &Request<Incoming>) -> bool {
740 req.uri().path() == "/metrics"
741}
742
743fn handle_acme_challenge(
744 req: &Request<Incoming>,
745 challenge_store: &ChallengeStore,
746) -> Option<Response<BoxBody>> {
747 let path = req.uri().path();
748 let prefix = "/.well-known/acme-challenge/";
749
750 if !path.starts_with(prefix) {
751 return None;
752 }
753
754 let token = &path[prefix.len()..];
755
756 if let Ok(store) = challenge_store.read() {
757 if let Some(key_auth) = store.get(token) {
758 let body = http_body_util::Full::new(Bytes::from(key_auth.clone())).boxed();
759 return Some(
760 Response::builder()
761 .status(200)
762 .header("Content-Type", "text/plain")
763 .body(body)
764 .unwrap(),
765 );
766 }
767 }
768
769 let body = http_body_util::Full::new(Bytes::from("Challenge not found")).boxed();
770 Some(Response::builder().status(404).body(body).unwrap())
771}
772
773async fn handle_websocket_request(
774 req: Request<Incoming>,
775 _client: ClientType,
776 config: &crate::config::Config,
777 metrics: &SharedMetrics,
778 _start_time: std::time::Instant,
779 _app_manager: Option<Arc<AppManager>>,
780) -> Result<Response<BoxBody>, hyper::Error> {
781 let target_result = find_target(&req, &config.rules);
782
783 if target_result.is_none() {
784 metrics.inc_errors();
785 let body = http_body_util::Full::new(Bytes::from("Misdirected Request")).boxed();
786 return Ok(Response::builder().status(421).body(body).unwrap());
787 }
788
789 let (target_url, _, _, _) = target_result.unwrap();
790
791 let backend_addr = match url::Url::parse(&target_url) {
793 Ok(u) => format!(
794 "{}:{}",
795 u.host_str().unwrap_or("127.0.0.1"),
796 u.port().unwrap_or(80)
797 ),
798 Err(_) => {
799 metrics.inc_errors();
800 let body = http_body_util::Full::new(Bytes::from("Bad backend URL")).boxed();
801 return Ok(Response::builder().status(502).body(body).unwrap());
802 }
803 };
804
805 let path = req.uri().path().to_string();
806 let query = req
807 .uri()
808 .query()
809 .map(|q| format!("?{}", q))
810 .unwrap_or_default();
811
812 let ws_key = req
813 .headers()
814 .get("sec-websocket-key")
815 .and_then(|v| v.to_str().ok())
816 .unwrap_or("")
817 .to_string();
818 let ws_version = req
819 .headers()
820 .get("sec-websocket-version")
821 .and_then(|v| v.to_str().ok())
822 .unwrap_or("13")
823 .to_string();
824 let ws_protocol = req
825 .headers()
826 .get("sec-websocket-protocol")
827 .and_then(|v| v.to_str().ok())
828 .map(|s| s.to_string());
829 let host_header = req
830 .headers()
831 .get("host")
832 .and_then(|v| v.to_str().ok())
833 .unwrap_or(&backend_addr)
834 .to_string();
835
836 tracing::info!(
837 "WebSocket upgrade request to {}{}{}",
838 backend_addr,
839 path,
840 query
841 );
842
843 let backend = match TcpStream::connect(&backend_addr).await {
845 Ok(s) => s,
846 Err(e) => {
847 tracing::error!("Failed to connect to backend for WebSocket: {}", e);
848 metrics.inc_errors();
849 let body = http_body_util::Full::new(Bytes::from("Backend not reachable")).boxed();
850 return Ok(Response::builder().status(502).body(body).unwrap());
851 }
852 };
853
854 let mut handshake = format!(
856 "GET {}{} HTTP/1.1\r\n\
857 Host: {}\r\n\
858 Upgrade: websocket\r\n\
859 Connection: Upgrade\r\n\
860 Sec-WebSocket-Key: {}\r\n\
861 Sec-WebSocket-Version: {}\r\n",
862 path, query, host_header, ws_key, ws_version,
863 );
864 if let Some(proto) = &ws_protocol {
865 handshake.push_str(&format!("Sec-WebSocket-Protocol: {}\r\n", proto));
866 }
867 handshake.push_str("\r\n");
868
869 let (mut backend_read, mut backend_write) = backend.into_split();
870 if let Err(e) = backend_write.write_all(handshake.as_bytes()).await {
871 tracing::error!("Failed to send WebSocket handshake to backend: {}", e);
872 metrics.inc_errors();
873 let body =
874 http_body_util::Full::new(Bytes::from("Failed to initiate WebSocket with backend"))
875 .boxed();
876 return Ok(Response::builder().status(502).body(body).unwrap());
877 }
878
879 let mut response_buf = vec![0u8; 4096];
881 let n = match tokio::io::AsyncReadExt::read(&mut backend_read, &mut response_buf).await {
882 Ok(n) if n > 0 => n,
883 _ => {
884 tracing::error!("No response from backend for WebSocket upgrade");
885 metrics.inc_errors();
886 let body = http_body_util::Full::new(Bytes::from(
887 "Backend did not respond to WebSocket upgrade",
888 ))
889 .boxed();
890 return Ok(Response::builder().status(502).body(body).unwrap());
891 }
892 };
893
894 let response_str = String::from_utf8_lossy(&response_buf[..n]);
895 if !response_str.contains("101") {
896 tracing::error!(
897 "Backend rejected WebSocket upgrade: {}",
898 response_str.lines().next().unwrap_or("")
899 );
900 metrics.inc_errors();
901 let body =
902 http_body_util::Full::new(Bytes::from("Backend rejected WebSocket upgrade")).boxed();
903 return Ok(Response::builder().status(502).body(body).unwrap());
904 }
905
906 let mut accept_key = String::new();
908 let mut resp_protocol = None;
909 for line in response_str.lines().skip(1) {
910 if line.trim().is_empty() {
911 break;
912 }
913 if let Some((name, value)) = line.split_once(':') {
914 let name_lower = name.trim().to_lowercase();
915 let value = value.trim().to_string();
916 if name_lower == "sec-websocket-accept" {
917 accept_key = value;
918 } else if name_lower == "sec-websocket-protocol" {
919 resp_protocol = Some(value);
920 }
921 }
922 }
923
924 let client_upgrade = hyper::upgrade::on(req);
926
927 let backend_stream = backend_read.reunite(backend_write).unwrap();
929
930 tokio::spawn(async move {
932 match client_upgrade.await {
933 Ok(upgraded) => {
934 let mut client_stream = TokioIo::new(upgraded);
935 let (mut br, mut bw) = tokio::io::split(backend_stream);
936 let (mut cr, mut cw) = tokio::io::split(&mut client_stream);
937 let _ = tokio::join!(
938 tokio::io::copy(&mut br, &mut cw),
939 tokio::io::copy(&mut cr, &mut bw),
940 );
941 }
942 Err(e) => {
943 tracing::error!("WebSocket client upgrade failed: {}", e);
944 }
945 }
946 });
947
948 let mut resp = Response::builder()
950 .status(101)
951 .header("Upgrade", "websocket")
952 .header("Connection", "Upgrade")
953 .header("Sec-WebSocket-Accept", accept_key);
954 if let Some(proto) = resp_protocol {
955 resp = resp.header("Sec-WebSocket-Protocol", proto);
956 }
957 Ok(resp
958 .body(http_body_util::Full::new(Bytes::new()).boxed())
959 .unwrap())
960}
961
962async fn handle_regular_request(
964 req: Request<Incoming>,
965 client: ClientType,
966 config: &crate::config::Config,
967 lua_engine: &OptionalLuaEngine,
968 circuit_breaker: &SharedCircuitBreaker,
969 _app_manager: Option<Arc<AppManager>>,
970 load_balancer: Arc<LoadBalancerState>,
971) -> Result<(Response<BoxBody>, String, Vec<String>), hyper::Error> {
972 let route = find_matching_rule(&req, &config.rules);
973
974 match route {
975 #[allow(unused_mut, unused_variables)]
976 Some(matched_route) => {
977 let path = req.uri().path().to_string();
978 let from_domain_rule = matched_route.from_domain_rule;
979 let matched_prefix = matched_route.matched_prefix();
980
981 if !matched_route.auth.is_empty() && !verify_basic_auth(&req, &matched_route.auth) {
982 tracing::debug!("Basic auth failed for {}", req.uri().path());
983 return Ok((create_auth_required_response(), String::new(), vec![]));
984 }
985 let route_scripts = matched_route.route_scripts.clone();
986
987 let target_selection =
989 select_target(&matched_route, &path, circuit_breaker, &load_balancer);
990 let (mut target_url, base_url) = match target_selection {
991 Some((url, base)) => (url, base),
992 None => {
993 let body =
995 http_body_util::Full::new(Bytes::from("Service Unavailable")).boxed();
996 return Ok((
997 Response::builder()
998 .status(503)
999 .body(body)
1000 .expect("Failed to build response"),
1001 String::new(),
1002 route_scripts,
1003 ));
1004 }
1005 };
1006 #[cfg(feature = "scripting")]
1008 if let Some(ref engine) = lua_engine {
1009 for script_name in &route_scripts {
1010 let mut lua_req = build_lua_request(&req);
1011 match engine.call_route_on_request(script_name, &mut lua_req) {
1012 RequestHookResult::Deny { status, body } => {
1013 let resp_body = http_body_util::Full::new(Bytes::from(body)).boxed();
1014 return Ok((
1015 Response::builder().status(status).body(resp_body).unwrap(),
1016 target_url,
1017 route_scripts.clone(),
1018 ));
1019 }
1020 RequestHookResult::Continue(_) => {}
1021 }
1022 }
1023 }
1024
1025 #[cfg(feature = "scripting")]
1027 if let Some(ref engine) = lua_engine {
1028 if engine.has_on_route() {
1029 let lua_req = build_lua_request(&req);
1030 match engine.call_on_route(&lua_req, &target_url) {
1031 RouteHookResult::Override(new_url) => {
1032 target_url = new_url;
1033 }
1034 RouteHookResult::Default => {}
1035 }
1036 }
1037 for script_name in &route_scripts {
1039 let lua_req = build_lua_request(&req);
1040 match engine.call_route_on_route(script_name, &lua_req, &target_url) {
1041 RouteHookResult::Override(new_url) => {
1042 target_url = new_url;
1043 }
1044 RouteHookResult::Default => {}
1045 }
1046 }
1047 }
1048
1049 let host_header = if from_domain_rule {
1051 req.uri()
1052 .host()
1053 .or(req.headers().get("host").and_then(|h| h.to_str().ok()))
1054 .map(|s| s.to_string())
1055 } else {
1056 None
1057 };
1058
1059 let (mut parts, body) = req.into_parts();
1060
1061 let uri: hyper::Uri = target_url.parse().expect("valid URI");
1063 parts.uri = uri;
1064 parts.version = http::Version::HTTP_11;
1065 parts.extensions = http::Extensions::new();
1066
1067 let mut request = Request::from_parts(parts, body);
1068
1069 request
1070 .headers_mut()
1071 .insert("X-Forwarded-For", X_FORWARDED_FOR_VALUE.clone());
1072
1073 if from_domain_rule {
1074 if let Some(host) = host_header {
1075 request
1076 .headers_mut()
1077 .insert("X-Forwarded-Host", host.parse().unwrap());
1078 }
1079 }
1080
1081 match client.request(request).await {
1082 Ok(response) => {
1083 let status_code = response.status().as_u16();
1085 if circuit_breaker.is_failure_status(status_code) {
1086 circuit_breaker.record_failure(&base_url);
1087 } else {
1088 circuit_breaker.record_success(&base_url);
1089 }
1090
1091 #[cfg(feature = "scripting")]
1093 if let Some(ref engine) = lua_engine {
1094 let has_global = engine.has_on_response();
1095 let has_route = !route_scripts.is_empty();
1096
1097 if has_global || has_route {
1098 use crate::scripting::ResponseMod;
1099
1100 let lua_req = LuaRequest {
1101 method: String::new(),
1102 path: String::new(),
1103 headers: std::collections::HashMap::new(),
1104 host: String::new(),
1105 content_length: 0,
1106 };
1107 let resp_headers = extract_response_headers(response.headers());
1108 let resp_status = response.status().as_u16();
1109
1110 let mut all_mods: Vec<ResponseMod> = Vec::new();
1112 if has_global {
1113 all_mods.push(engine.call_on_response(
1114 &lua_req,
1115 resp_status,
1116 &resp_headers,
1117 ));
1118 }
1119 for script_name in &route_scripts {
1120 all_mods.push(engine.call_route_on_response(
1121 script_name,
1122 &lua_req,
1123 resp_status,
1124 &resp_headers,
1125 ));
1126 }
1127
1128 let mut merged = ResponseMod::default();
1130 for mods in all_mods {
1131 merged.set_headers.extend(mods.set_headers);
1132 merged.remove_headers.extend(mods.remove_headers);
1133 if mods.replace_body.is_some() {
1134 merged.replace_body = mods.replace_body;
1135 }
1136 if mods.override_status.is_some() {
1137 merged.override_status = mods.override_status;
1138 }
1139 }
1140
1141 if !merged.set_headers.is_empty()
1143 || !merged.remove_headers.is_empty()
1144 || merged.replace_body.is_some()
1145 || merged.override_status.is_some()
1146 {
1147 let (mut parts, body) = response.into_parts();
1148
1149 if let Some(status) = merged.override_status {
1150 parts.status =
1151 hyper::StatusCode::from_u16(status).unwrap_or(parts.status);
1152 }
1153
1154 for name in &merged.remove_headers {
1155 if let Ok(header_name) =
1156 name.parse::<hyper::header::HeaderName>()
1157 {
1158 parts.headers.remove(header_name);
1159 }
1160 }
1161
1162 for (name, value) in &merged.set_headers {
1163 if let (Ok(header_name), Ok(header_value)) = (
1164 name.parse::<hyper::header::HeaderName>(),
1165 value.parse::<HeaderValue>(),
1166 ) {
1167 parts.headers.insert(header_name, header_value);
1168 }
1169 }
1170
1171 if let Some(new_body) = merged.replace_body {
1172 let new_bytes = Bytes::from(new_body);
1173 parts.headers.remove("content-length");
1174 parts.headers.insert(
1175 "content-length",
1176 new_bytes.len().to_string().parse().unwrap(),
1177 );
1178 let boxed = http_body_util::Full::new(new_bytes).boxed();
1179 return Ok((
1180 Response::from_parts(parts, boxed),
1181 target_url,
1182 route_scripts.clone(),
1183 ));
1184 }
1185
1186 let boxed = body.map_err(|_| unreachable!()).boxed();
1187 return Ok((
1188 Response::from_parts(parts, boxed),
1189 target_url,
1190 route_scripts.clone(),
1191 ));
1192 }
1193 }
1194 }
1195
1196 let is_html = response
1197 .headers()
1198 .get("content-type")
1199 .and_then(|v| v.to_str().ok())
1200 .map(|ct| ct.starts_with("text/html"))
1201 .unwrap_or(false);
1202
1203 if is_html {
1204 if let Some(prefix) = matched_prefix {
1205 let (parts, body) = response.into_parts();
1206 let body_bytes = body
1207 .collect()
1208 .await
1209 .map(|collected| collected.to_bytes())
1210 .unwrap_or_default();
1211
1212 let is_gzip = parts
1214 .headers
1215 .get("content-encoding")
1216 .and_then(|v| v.to_str().ok())
1217 .map(|v| v.contains("gzip"))
1218 .unwrap_or(false);
1219 let is_deflate = parts
1220 .headers
1221 .get("content-encoding")
1222 .and_then(|v| v.to_str().ok())
1223 .map(|v| v.contains("deflate"))
1224 .unwrap_or(false);
1225
1226 let raw_bytes = if is_gzip {
1227 use std::io::Read;
1228 let mut decoder = flate2::read::GzDecoder::new(&body_bytes[..]);
1229 let mut decoded = Vec::new();
1230 decoder.read_to_end(&mut decoded).unwrap_or_default();
1231 Bytes::from(decoded)
1232 } else if is_deflate {
1233 use std::io::Read;
1234 let mut decoder =
1235 flate2::read::DeflateDecoder::new(&body_bytes[..]);
1236 let mut decoded = Vec::new();
1237 decoder.read_to_end(&mut decoded).unwrap_or_default();
1238 Bytes::from(decoded)
1239 } else {
1240 body_bytes
1241 };
1242
1243 let html = String::from_utf8_lossy(&raw_bytes);
1244 let rewritten = html
1245 .replace("href=\"/", &format!("href=\"{}/", prefix))
1246 .replace("src=\"/", &format!("src=\"{}/", prefix))
1247 .replace("action=\"/", &format!("action=\"{}/", prefix));
1248 let rewritten_bytes = Bytes::from(rewritten);
1249 let mut parts = parts;
1250 parts.headers.remove("content-encoding");
1251 parts.headers.remove("content-length");
1252 parts.headers.insert(
1253 "content-length",
1254 rewritten_bytes.len().to_string().parse().unwrap(),
1255 );
1256 let boxed = http_body_util::Full::new(rewritten_bytes).boxed();
1257 return Ok((
1258 Response::from_parts(parts, boxed),
1259 target_url,
1260 route_scripts.clone(),
1261 ));
1262 }
1263 }
1264
1265 let (parts, body) = response.into_parts();
1266 let boxed = body.map_err(|_| unreachable!()).boxed();
1267 Ok((
1268 Response::from_parts(parts, boxed),
1269 target_url,
1270 route_scripts,
1271 ))
1272 }
1273 Err(e) => {
1274 circuit_breaker.record_failure(&base_url);
1275 tracing::error!("Backend request failed: {} (target: {})", e, target_url);
1276 let body = http_body_util::Full::new(Bytes::from("Bad Gateway")).boxed();
1277 Ok((
1278 Response::builder()
1279 .status(502)
1280 .body(body)
1281 .expect("Failed to build response"),
1282 target_url,
1283 route_scripts,
1284 ))
1285 }
1286 }
1287 }
1288 None => {
1289 let _ = lua_engine;
1291 let body = http_body_util::Full::new(Bytes::from("Misdirected Request")).boxed();
1292 Ok((
1293 Response::builder()
1294 .status(421)
1295 .body(body)
1296 .expect("Failed to build response"),
1297 String::new(),
1298 vec![],
1299 ))
1300 }
1301 }
1302}
1303
/// How `resolve_target_url` combines a request path with a target URL.
#[derive(Debug, Clone, PartialEq, Eq)]
enum UrlResolution {
    /// Append the full request path to the target URL (host-only `Domain` rules).
    AppendPath,
    /// Remove this rule prefix from the front of the request path and append the
    /// remainder (`DomainPath` and `Prefix` rules).
    StripPrefix(String),
    /// Use the target URL unchanged (`Exact`, `Regex` and `Default` rules).
    Identity,
}
1313
1314struct MatchedRoute<'a> {
1316 targets: &'a [crate::config::Target],
1317 from_domain_rule: bool,
1318 resolution: UrlResolution,
1319 route_scripts: Vec<String>,
1320 auth: Vec<crate::auth::BasicAuth>,
1321 load_balancing: &'a crate::config::LoadBalancingStrategy,
1322}
1323
1324impl<'a> MatchedRoute<'a> {
1325 fn matched_prefix(&self) -> Option<String> {
1326 match &self.resolution {
1327 UrlResolution::StripPrefix(prefix) => Some(prefix.trim_end_matches('/').to_string()),
1328 _ => None,
1329 }
1330 }
1331}
1332
1333fn resolve_target_url(
1335 target: &crate::config::Target,
1336 path: &str,
1337 resolution: &UrlResolution,
1338) -> String {
1339 let target_str = target.url.as_str();
1340 match resolution {
1341 UrlResolution::AppendPath => {
1342 if target_str.ends_with('/') {
1343 format!("{}{}", target_str, &path[1..])
1344 } else {
1345 format!("{}{}", target_str, path)
1346 }
1347 }
1348 UrlResolution::StripPrefix(prefix) => {
1349 let suffix = if path.len() >= prefix.len() {
1350 &path[prefix.len()..]
1351 } else {
1352 ""
1353 };
1354 format!("{}{}", target_str, suffix)
1355 }
1356 UrlResolution::Identity => target_str.to_owned(),
1357 }
1358}
1359
1360fn find_matching_rule<'a>(
1362 req: &Request<Incoming>,
1363 rules: &'a [crate::config::ProxyRule],
1364) -> Option<MatchedRoute<'a>> {
1365 let host = req
1366 .uri()
1367 .host()
1368 .or(req.headers().get("host").and_then(|h| h.to_str().ok()))
1369 .map(|h| h.split(':').next().unwrap_or(h))?;
1370
1371 let path = req.uri().path();
1372
1373 for rule in rules {
1374 match &rule.matcher {
1375 crate::config::RuleMatcher::Domain(domain) => {
1376 if domain == host && !rule.targets.is_empty() {
1377 return Some(MatchedRoute {
1378 targets: &rule.targets,
1379 from_domain_rule: true,
1380 resolution: UrlResolution::AppendPath,
1381 route_scripts: rule.scripts.clone(),
1382 auth: rule.auth.clone(),
1383 load_balancing: &rule.load_balancing,
1384 });
1385 }
1386 }
1387 crate::config::RuleMatcher::DomainPath(domain, path_prefix) => {
1388 if domain == host && !rule.targets.is_empty() {
1389 let matches = path.starts_with(path_prefix)
1390 || (path_prefix.ends_with('/')
1391 && path == path_prefix.trim_end_matches('/'));
1392 if matches {
1393 return Some(MatchedRoute {
1394 targets: &rule.targets,
1395 from_domain_rule: true,
1396 resolution: UrlResolution::StripPrefix(path_prefix.clone()),
1397 route_scripts: rule.scripts.clone(),
1398 auth: rule.auth.clone(),
1399 load_balancing: &rule.load_balancing,
1400 });
1401 }
1402 }
1403 }
1404 _ => {}
1405 }
1406 }
1407
1408 for rule in rules {
1410 match &rule.matcher {
1411 crate::config::RuleMatcher::Exact(exact) => {
1412 if path == exact && !rule.targets.is_empty() {
1413 return Some(MatchedRoute {
1414 targets: &rule.targets,
1415 from_domain_rule: false,
1416 resolution: UrlResolution::Identity,
1417 route_scripts: rule.scripts.clone(),
1418 auth: rule.auth.clone(),
1419 load_balancing: &rule.load_balancing,
1420 });
1421 }
1422 }
1423 crate::config::RuleMatcher::Prefix(prefix) => {
1424 if !rule.targets.is_empty() {
1425 let matches = path.starts_with(prefix)
1427 || (prefix.ends_with('/') && path == prefix.trim_end_matches('/'));
1428 if matches {
1429 return Some(MatchedRoute {
1430 targets: &rule.targets,
1431 from_domain_rule: false,
1432 resolution: UrlResolution::StripPrefix(prefix.clone()),
1433 route_scripts: rule.scripts.clone(),
1434 auth: rule.auth.clone(),
1435 load_balancing: &rule.load_balancing,
1436 });
1437 }
1438 }
1439 }
1440 crate::config::RuleMatcher::Regex(ref rm) => {
1441 if rm.is_match(path) && !rule.targets.is_empty() {
1442 return Some(MatchedRoute {
1443 targets: &rule.targets,
1444 from_domain_rule: false,
1445 resolution: UrlResolution::Identity,
1446 route_scripts: rule.scripts.clone(),
1447 auth: rule.auth.clone(),
1448 load_balancing: &rule.load_balancing,
1449 });
1450 }
1451 }
1452 _ => {}
1453 }
1454 }
1455
1456 for rule in rules {
1458 if let crate::config::RuleMatcher::Default = &rule.matcher {
1459 if !rule.targets.is_empty() {
1460 return Some(MatchedRoute {
1461 targets: &rule.targets,
1462 from_domain_rule: false,
1463 resolution: UrlResolution::Identity,
1464 route_scripts: rule.scripts.clone(),
1465 auth: rule.auth.clone(),
1466 load_balancing: &rule.load_balancing,
1467 });
1468 }
1469 }
1470 }
1471
1472 None
1473}
1474
1475fn select_target(
1478 route: &MatchedRoute<'_>,
1479 path: &str,
1480 circuit_breaker: &crate::circuit_breaker::CircuitBreaker,
1481 load_balancer: &LoadBalancerState,
1482) -> Option<(String, String)> {
1483 let targets = route.targets;
1484 if targets.is_empty() {
1485 return None;
1486 }
1487
1488 match route.load_balancing {
1489 crate::config::LoadBalancingStrategy::Failover => {
1490 for target in targets {
1492 let base_url = target.url.as_str().to_owned();
1493 if circuit_breaker.is_available(&base_url) {
1494 let resolved = resolve_target_url(target, path, &route.resolution);
1495 return Some((resolved, base_url));
1496 }
1497 }
1498 None
1499 }
1500 crate::config::LoadBalancingStrategy::RoundRobin => {
1501 let num_targets = targets.len();
1503 let start_idx = load_balancer.counters[0].load(Ordering::Relaxed) % num_targets;
1504
1505 for i in 0..num_targets {
1506 let idx = (start_idx + i) % num_targets;
1507 let target = &targets[idx];
1508 let base_url = target.url.as_str().to_owned();
1509 if circuit_breaker.is_available(&base_url) {
1510 load_balancer.counters[0].fetch_add(1, Ordering::Relaxed);
1511 let resolved = resolve_target_url(target, path, &route.resolution);
1512 return Some((resolved, base_url));
1513 }
1514 }
1515 None
1516 }
1517 crate::config::LoadBalancingStrategy::Weighted => {
1518 let total_weight: u32 = targets.iter().map(|t| t.weight as u32).sum();
1520 if total_weight == 0 {
1521 return select_target(route, path, circuit_breaker, &LoadBalancerState::new(1));
1522 }
1523
1524 let start_idx =
1525 (load_balancer.counters[0].load(Ordering::Relaxed) % total_weight as usize) as u32;
1526 let mut cumulative = 0u32;
1527
1528 for target in targets.iter() {
1529 cumulative += target.weight as u32;
1530 let base_url = target.url.as_str().to_owned();
1531 if cumulative > start_idx && circuit_breaker.is_available(&base_url) {
1532 load_balancer.counters[0].fetch_add(1, Ordering::Relaxed);
1533 let resolved = resolve_target_url(target, path, &route.resolution);
1534 return Some((resolved, base_url));
1535 }
1536 }
1537
1538 for target in targets {
1540 let base_url = target.url.as_str().to_owned();
1541 if circuit_breaker.is_available(&base_url) {
1542 let resolved = resolve_target_url(target, path, &route.resolution);
1543 return Some((resolved, base_url));
1544 }
1545 }
1546 None
1547 }
1548 }
1549}
1550
1551fn find_target(
1553 req: &Request<Incoming>,
1554 rules: &[crate::config::ProxyRule],
1555) -> Option<(String, bool, Option<String>, Vec<String>)> {
1556 let route = find_matching_rule(req, rules)?;
1557 let path = req.uri().path();
1558 let target = route.targets.first()?;
1559 let resolved = resolve_target_url(target, path, &route.resolution);
1560 let matched_prefix = route.matched_prefix();
1561 Some((
1562 resolved,
1563 route.from_domain_rule,
1564 matched_prefix,
1565 route.route_scripts,
1566 ))
1567}
1568
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_load_balancer_state_select_index() {
        let balancer = LoadBalancerState::new(1);
        // With three targets the index cycles 0, 1, 2 and then wraps back to 0.
        for expected in [0, 1, 2, 0] {
            assert_eq!(balancer.select_index(0, 3), expected);
        }
    }

    #[test]
    fn test_load_balancer_state_zero_targets() {
        let balancer = LoadBalancerState::new(1);
        // No targets: the guard returns 0 instead of dividing by zero.
        assert_eq!(balancer.select_index(0, 0), 0);
    }
}