// NOTE(review): these lints appear to be silenced because `OptionalLuaEngine` is `()`
// when the `scripting` feature is disabled; the many `lua_engine.clone()`,
// `let lua_clone = ...` and "unit value passed as an argument" sites below would
// otherwise trip clippy in that configuration.
#![allow(clippy::let_unit_value, clippy::clone_on_copy, clippy::unit_arg)]
3
4use crate::acme::ChallengeStore;
5use crate::app::AppManager;
6use crate::auth;
7use crate::circuit_breaker::SharedCircuitBreaker;
8use crate::config::ConfigManager;
9use crate::metrics::SharedMetrics;
10use crate::shutdown::ShutdownCoordinator;
11use anyhow::Result;
12use bytes::Bytes;
13use http_body_util::BodyExt;
14use hyper::body::Incoming;
15use hyper::header::HeaderValue;
16use hyper::service::service_fn;
17use hyper::Request;
18use hyper::Response;
19use hyper_util::client::legacy::connect::HttpConnector;
20use hyper_util::client::legacy::Client;
21use hyper_util::rt::TokioExecutor;
22use hyper_util::rt::TokioIo;
23use socket2::{Domain, Protocol, Socket, Type};
24use std::net::SocketAddr;
25use std::sync::Arc;
26use std::time::Duration;
27use tokio::io::AsyncWriteExt;
28use tokio::net::{TcpListener, TcpStream};
29use tokio_rustls::TlsAcceptor;
30
31#[cfg(feature = "scripting")]
32use crate::scripting::{LuaEngine, LuaRequest, RequestHookResult, RouteHookResult};
33
34type ClientType = Client<HttpConnector, Incoming>;
35type BoxBody = http_body_util::combinators::BoxBody<Bytes, std::convert::Infallible>;
36
37#[cfg(feature = "scripting")]
38type OptionalLuaEngine = Option<LuaEngine>;
39#[cfg(not(feature = "scripting"))]
40type OptionalLuaEngine = ();
41
42fn record_app_metrics(
44 metrics: &SharedMetrics,
45 app_manager: &Option<Arc<AppManager>>,
46 target_url: &str,
47 bytes_in: u64,
48 bytes_out: u64,
49 status: u16,
50 duration: Duration,
51) {
52 if let Some(ref manager) = app_manager {
53 if let Ok(url) = url::Url::parse(target_url) {
54 if let Some(port) = url.port() {
55 if let Some(app_name) = futures::executor::block_on(manager.get_app_name(port)) {
56 metrics.record_app_request(&app_name, bytes_in, bytes_out, status, duration);
57 }
58 }
59 }
60 }
61}
62
63static X_FORWARDED_FOR_VALUE: std::sync::LazyLock<HeaderValue> =
65 std::sync::LazyLock::new(|| HeaderValue::from_static("127.0.0.1"));
66
67fn verify_basic_auth(req: &Request<Incoming>, auth_entries: &[crate::auth::BasicAuth]) -> bool {
70 if auth_entries.is_empty() {
71 return true;
72 }
73
74 let auth_header = req.headers().get("authorization");
75 if auth_header.is_none() {
76 return false;
77 }
78
79 let header_value = auth_header.unwrap().to_str().unwrap_or("");
80 if !header_value.starts_with("Basic ") {
81 return false;
82 }
83
84 let encoded = &header_value[6..];
85 let decoded = base64::Engine::decode(&base64::engine::general_purpose::STANDARD, encoded)
86 .unwrap_or_default();
87 let creds = String::from_utf8_lossy(&decoded);
88
89 if let Some((username, password)) = creds.split_once(':') {
90 for entry in auth_entries {
91 if entry.username == username && auth::verify_password(password, &entry.hash) {
92 return true;
93 }
94 }
95 }
96
97 false
98}
99
100fn create_auth_required_response() -> Response<BoxBody> {
102 let body = http_body_util::Full::new(Bytes::from("Authentication required")).boxed();
103 Response::builder()
104 .status(401)
105 .header("WWW-Authenticate", "Basic realm=\"Restricted\"")
106 .body(body)
107 .unwrap()
108}
109
110fn create_listener(addr: SocketAddr) -> Result<TcpListener> {
111 let domain = if addr.is_ipv4() {
112 Domain::IPV4
113 } else {
114 Domain::IPV6
115 };
116 let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
117 socket.set_reuse_address(true)?;
118 socket.set_reuse_port(true)?;
119 socket.set_nonblocking(true)?;
120 socket.bind(&addr.into())?;
121 socket.listen(8192)?;
122 let std_listener: std::net::TcpListener = socket.into();
123 Ok(TcpListener::from_std(std_listener)?)
124}
125
126fn create_client() -> ClientType {
127 let exec = TokioExecutor::new();
128 let mut connector = HttpConnector::new();
129 connector.set_nodelay(true);
130 connector.set_keepalive(Some(Duration::from_secs(30)));
131 connector.set_connect_timeout(Some(Duration::from_secs(5)));
132 Client::builder(exec)
133 .pool_max_idle_per_host(256)
134 .pool_idle_timeout(Duration::from_secs(60))
135 .build(connector)
136}
137
/// Reverse-proxy server: holds the shared services and spawns the accept loops in `run`.
pub struct ProxyServer {
    /// Configuration source; `run` reads `server.bind` from it.
    config: Arc<ConfigManager>,
    /// Polled by `run` and by each accept loop to stop serving.
    shutdown: ShutdownCoordinator,
    /// TLS acceptor; `Some` only when constructed with `with_https`.
    tls_acceptor: Option<TlsAcceptor>,
    /// HTTPS listen address; `Some` only when constructed with `with_https`.
    https_addr: Option<SocketAddr>,
    /// Request, error, in-flight and TLS-connection metrics shared by all connections.
    metrics: SharedMetrics,
    /// Tokens served under `/.well-known/acme-challenge/`.
    challenge_store: ChallengeStore,
    /// Lua hook engine (`()` when the `scripting` feature is off).
    lua_engine: OptionalLuaEngine,
    /// Backend failure tracking passed to target selection.
    circuit_breaker: SharedCircuitBreaker,
    /// Optional port -> application-name lookup used for per-app metrics.
    app_manager: Option<Arc<AppManager>>,
}
149
150impl ProxyServer {
151 pub fn new(
152 config: Arc<ConfigManager>,
153 shutdown: ShutdownCoordinator,
154 metrics: SharedMetrics,
155 challenge_store: ChallengeStore,
156 lua_engine: OptionalLuaEngine,
157 circuit_breaker: SharedCircuitBreaker,
158 app_manager: Option<Arc<AppManager>>,
159 ) -> Result<Self> {
160 Ok(Self {
161 config,
162 shutdown,
163 tls_acceptor: None,
164 https_addr: None,
165 metrics,
166 challenge_store,
167 lua_engine,
168 circuit_breaker,
169 app_manager,
170 })
171 }
172
173 #[allow(clippy::too_many_arguments)]
174 pub fn with_https(
175 config: Arc<ConfigManager>,
176 shutdown: ShutdownCoordinator,
177 tls_acceptor: TlsAcceptor,
178 https_addr: SocketAddr,
179 metrics: SharedMetrics,
180 challenge_store: ChallengeStore,
181 lua_engine: OptionalLuaEngine,
182 circuit_breaker: SharedCircuitBreaker,
183 app_manager: Option<Arc<AppManager>>,
184 ) -> Result<Self> {
185 Ok(Self {
186 config,
187 shutdown,
188 tls_acceptor: Some(tls_acceptor),
189 https_addr: Some(https_addr),
190 metrics,
191 challenge_store,
192 lua_engine,
193 circuit_breaker,
194 app_manager,
195 })
196 }
197
198 pub async fn run(&self) -> Result<()> {
199 let cfg = self.config.get_config();
200 let http_addr: SocketAddr = cfg.server.bind.parse()?;
201 let https_addr = self.https_addr;
202
203 let has_https = https_addr.is_some();
204 let num_listeners = std::thread::available_parallelism()
205 .map(|n| n.get())
206 .unwrap_or(4);
207
208 let app_manager = self.app_manager.clone();
211 for i in 0..num_listeners {
212 let config_clone = self.config.clone();
213 let shutdown_clone = self.shutdown.clone();
214 let metrics_clone = self.metrics.clone();
215 let challenge_store_clone = self.challenge_store.clone();
216 let lua_clone = self.lua_engine.clone();
217 let cb_clone = self.circuit_breaker.clone();
218 let am_clone = app_manager.clone();
219
220 tokio::spawn(async move {
221 if let Err(e) = run_http_server(
222 http_addr,
223 config_clone,
224 shutdown_clone,
225 metrics_clone,
226 challenge_store_clone,
227 lua_clone,
228 cb_clone,
229 am_clone,
230 )
231 .await
232 {
233 tracing::error!("HTTP/1.1 server error (listener {}): {}", i, e);
234 }
235 });
236 }
237
238 if let Some(https_addr) = https_addr {
239 for i in 0..num_listeners {
240 let config_clone = self.config.clone();
241 let shutdown_clone = self.shutdown.clone();
242 let acceptor = self.tls_acceptor.as_ref().unwrap().clone();
243 let metrics_clone = self.metrics.clone();
244 let challenge_store_clone = self.challenge_store.clone();
245 let lua_clone = self.lua_engine.clone();
246 let cb_clone = self.circuit_breaker.clone();
247 let am_clone = app_manager.clone();
248
249 tokio::spawn(async move {
250 if let Err(e) = run_https_server(
251 https_addr,
252 config_clone,
253 shutdown_clone,
254 acceptor,
255 metrics_clone,
256 challenge_store_clone,
257 lua_clone,
258 cb_clone,
259 am_clone,
260 )
261 .await
262 {
263 tracing::error!("HTTPS/2 server error (listener {}): {}", i, e);
264 }
265 });
266 }
267 }
268
269 tracing::info!(
270 "HTTP/1.1 server listening on {} ({} accept loops)",
271 http_addr,
272 num_listeners
273 );
274 if has_https {
275 tracing::info!(
276 "HTTPS/2 server listening on {} ({} accept loops)",
277 https_addr.unwrap(),
278 num_listeners
279 );
280 }
281
282 loop {
283 if self.shutdown.is_shutting_down() {
284 tracing::info!("Shutting down servers...");
285 break;
286 }
287 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
288 }
289
290 Ok(())
291 }
292}
293
294#[allow(clippy::too_many_arguments)]
295async fn run_http_server(
296 addr: SocketAddr,
297 config: Arc<ConfigManager>,
298 shutdown: ShutdownCoordinator,
299 metrics: SharedMetrics,
300 challenge_store: ChallengeStore,
301 lua_engine: OptionalLuaEngine,
302 circuit_breaker: SharedCircuitBreaker,
303 app_manager: Option<Arc<AppManager>>,
304) -> Result<()> {
305 let listener = create_listener(addr)?;
306 let client = create_client();
307
308 loop {
309 if shutdown.is_shutting_down() {
310 break;
311 }
312
313 match listener.accept().await {
314 Ok((stream, _)) => {
315 let _ = stream.set_nodelay(true);
316 let client = client.clone();
317 let config = config.clone();
318 let metrics = metrics.clone();
319 let cs = challenge_store.clone();
320 let lua = lua_engine.clone();
321 let cb = circuit_breaker.clone();
322 let am = app_manager.clone();
323 tokio::spawn(async move {
324 if let Err(e) =
325 handle_http11_connection(stream, client, config, metrics, cs, lua, cb, am)
326 .await
327 {
328 tracing::debug!("HTTP/1.1 connection error: {}", e);
329 }
330 });
331 }
332 Err(e) => {
333 tracing::error!("HTTP/1.1 accept error: {}", e);
334 }
335 }
336 }
337
338 Ok(())
339}
340
341#[allow(clippy::too_many_arguments)]
342async fn run_https_server(
343 addr: SocketAddr,
344 config: Arc<ConfigManager>,
345 shutdown: ShutdownCoordinator,
346 acceptor: TlsAcceptor,
347 metrics: SharedMetrics,
348 challenge_store: ChallengeStore,
349 lua_engine: OptionalLuaEngine,
350 circuit_breaker: SharedCircuitBreaker,
351 app_manager: Option<Arc<AppManager>>,
352) -> Result<()> {
353 let listener = create_listener(addr)?;
354 let client = create_client();
355
356 loop {
357 if shutdown.is_shutting_down() {
358 break;
359 }
360
361 match listener.accept().await {
362 Ok((stream, _)) => {
363 let _ = stream.set_nodelay(true);
364 let client = client.clone();
365 let config = config.clone();
366 let acceptor = acceptor.clone();
367 let metrics = metrics.clone();
368 let cs = challenge_store.clone();
369 let lua = lua_engine.clone();
370 let cb = circuit_breaker.clone();
371 let am = app_manager.clone();
372 tokio::spawn(async move {
373 match acceptor.accept(stream).await {
374 Ok(tls_stream) => {
375 metrics.inc_tls_connections();
376 if let Err(e) = handle_https2_connection(
377 tls_stream, client, config, metrics, cs, lua, cb, am,
378 )
379 .await
380 {
381 tracing::debug!("HTTPS/2 connection error: {}", e);
382 }
383 }
384 Err(e) => {
385 tracing::error!("TLS accept error: {}", e);
386 }
387 }
388 });
389 }
390 Err(e) => {
391 tracing::error!("HTTPS/2 accept error: {}", e);
392 }
393 }
394 }
395
396 Ok(())
397}
398
399#[allow(clippy::too_many_arguments)]
400async fn handle_http11_connection(
401 stream: tokio::net::TcpStream,
402 client: ClientType,
403 config: Arc<ConfigManager>,
404 metrics: SharedMetrics,
405 challenge_store: ChallengeStore,
406 lua_engine: OptionalLuaEngine,
407 circuit_breaker: SharedCircuitBreaker,
408 app_manager: Option<Arc<AppManager>>,
409) -> Result<()> {
410 let io = TokioIo::new(stream);
411 let svc = service_fn(move |req| {
412 handle_request(
413 req,
414 client.clone(),
415 config.clone(),
416 metrics.clone(),
417 challenge_store.clone(),
418 lua_engine.clone(),
419 circuit_breaker.clone(),
420 app_manager.clone(),
421 )
422 });
423
424 let conn = hyper::server::conn::http1::Builder::new()
425 .keep_alive(true)
426 .pipeline_flush(true)
427 .serve_connection(io, svc)
428 .with_upgrades();
429
430 if let Err(e) = conn.await {
431 tracing::debug!("HTTP/1.1 connection error: {}", e);
432 }
433
434 Ok(())
435}
436
437#[allow(clippy::too_many_arguments)]
438async fn handle_https2_connection(
439 stream: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
440 client: ClientType,
441 config: Arc<ConfigManager>,
442 metrics: SharedMetrics,
443 challenge_store: ChallengeStore,
444 lua_engine: OptionalLuaEngine,
445 circuit_breaker: SharedCircuitBreaker,
446 app_manager: Option<Arc<AppManager>>,
447) -> Result<()> {
448 let is_h2 = stream.get_ref().1.alpn_protocol() == Some(b"h2");
449
450 let io = TokioIo::new(stream);
451
452 if is_h2 {
453 let exec = TokioExecutor::new();
454 let svc = service_fn(move |req| {
455 handle_request(
456 req,
457 client.clone(),
458 config.clone(),
459 metrics.clone(),
460 challenge_store.clone(),
461 lua_engine.clone(),
462 circuit_breaker.clone(),
463 app_manager.clone(),
464 )
465 });
466 let conn = hyper::server::conn::http2::Builder::new(exec)
467 .initial_stream_window_size(1024 * 1024)
468 .initial_connection_window_size(2 * 1024 * 1024)
469 .max_concurrent_streams(250)
470 .serve_connection(io, svc);
471 if let Err(e) = conn.await {
472 tracing::debug!("HTTPS/2 connection error: {}", e);
473 }
474 } else {
475 let svc = service_fn(move |req| {
476 handle_request(
477 req,
478 client.clone(),
479 config.clone(),
480 metrics.clone(),
481 challenge_store.clone(),
482 lua_engine.clone(),
483 circuit_breaker.clone(),
484 app_manager.clone(),
485 )
486 });
487 let conn = hyper::server::conn::http1::Builder::new()
488 .keep_alive(true)
489 .pipeline_flush(true)
490 .serve_connection(io, svc)
491 .with_upgrades();
492 if let Err(e) = conn.await {
493 tracing::debug!("HTTPS/1.1 connection error: {}", e);
494 }
495 }
496
497 Ok(())
498}
499
/// Snapshot of the request headers for Lua hooks, in the same shape as
/// `extract_response_headers`.
///
/// Names are lower-cased, values that are not visible ASCII become `""`, and for
/// repeated header names the last value wins.
#[cfg(feature = "scripting")]
fn extract_headers(req: &Request<Incoming>) -> std::collections::HashMap<String, String> {
    extract_response_headers(req.headers())
}
513
/// Describes an inbound request to the Lua engine.
///
/// The host comes from the request URI, falling back to the `Host` header and then to
/// `""`. `content_length` is the parsed `Content-Length` header, or 0 when it is absent
/// or unparsable.
#[cfg(feature = "scripting")]
fn build_lua_request(req: &Request<Incoming>) -> LuaRequest {
    let headers = req.headers();

    let host = match req.uri().host() {
        Some(uri_host) => uri_host.to_string(),
        None => headers
            .get("host")
            .and_then(|h| h.to_str().ok())
            .unwrap_or("")
            .to_string(),
    };

    let content_length = headers
        .get("content-length")
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.parse().ok())
        .unwrap_or(0);

    LuaRequest {
        method: req.method().as_str().to_owned(),
        path: req.uri().path().to_owned(),
        headers: extract_headers(req),
        host,
        content_length,
    }
}
539
/// Flattens a header map into `name -> value` strings for Lua hooks.
///
/// Names are lower-cased, values that are not visible ASCII become `""`, and for
/// repeated header names the last value wins.
#[cfg(feature = "scripting")]
fn extract_response_headers(
    headers: &hyper::HeaderMap,
) -> std::collections::HashMap<String, String> {
    let mut snapshot = std::collections::HashMap::new();
    for (name, value) in headers {
        let text = value.to_str().unwrap_or("");
        snapshot.insert(name.as_str().to_lowercase(), text.to_owned());
    }
    snapshot
}
555
556#[allow(clippy::too_many_arguments)]
557async fn handle_request(
558 req: Request<Incoming>,
559 client: ClientType,
560 config_manager: Arc<ConfigManager>,
561 metrics: SharedMetrics,
562 challenge_store: ChallengeStore,
563 lua_engine: OptionalLuaEngine,
564 circuit_breaker: SharedCircuitBreaker,
565 app_manager: Option<Arc<AppManager>>,
566) -> Result<Response<BoxBody>, hyper::Error> {
567 let start_time = std::time::Instant::now();
568 metrics.inc_in_flight();
569 let config = config_manager.get_config();
570
571 if let Some(response) = handle_acme_challenge(&req, &challenge_store) {
573 metrics.dec_in_flight();
574 return Ok(response);
575 }
576
577 if is_metrics_request(&req) {
578 let duration = start_time.elapsed();
579 metrics.dec_in_flight();
580 let metrics_output = metrics.format_metrics();
581 metrics.record_request(0, metrics_output.len() as u64, 200, duration);
582 let body = http_body_util::Full::new(Bytes::from(metrics_output)).boxed();
583 return Ok(Response::builder()
584 .status(200)
585 .header("Content-Type", "text/plain")
586 .body(body)
587 .unwrap());
588 }
589
590 #[cfg(feature = "scripting")]
592 if let Some(ref engine) = lua_engine {
593 if engine.has_on_request() {
594 let mut lua_req = build_lua_request(&req);
595 match engine.call_on_request(&mut lua_req) {
596 RequestHookResult::Deny { status, body } => {
597 metrics.dec_in_flight();
598 let duration = start_time.elapsed();
599 metrics.record_request(0, body.len() as u64, status, duration);
600 let resp_body = http_body_util::Full::new(Bytes::from(body)).boxed();
601 return Ok(Response::builder().status(status).body(resp_body).unwrap());
602 }
603 RequestHookResult::Continue(updated_req) => {
604 let _ = updated_req;
610 }
611 }
612 }
613 }
614
615 let is_websocket = is_websocket_request(&req);
616
617 if is_websocket {
618 return handle_websocket_request(
619 req,
620 client,
621 &config,
622 &metrics,
623 start_time,
624 app_manager.clone(),
625 )
626 .await;
627 }
628
629 let result = handle_regular_request(
630 req,
631 client,
632 &config,
633 &lua_engine,
634 &circuit_breaker,
635 app_manager.clone(),
636 )
637 .await;
638 let duration = start_time.elapsed();
639
640 metrics.dec_in_flight();
641
642 match result {
643 #[allow(unused_variables)]
644 Ok((response, _target_url, route_scripts)) => {
645 let status = response.status().as_u16();
646
647 #[cfg(feature = "scripting")]
649 if let Some(ref engine) = lua_engine {
650 let lua_req = LuaRequest {
651 method: String::new(),
652 path: String::new(),
653 headers: std::collections::HashMap::new(),
654 host: String::new(),
655 content_length: 0,
656 };
657 let duration_ms = duration.as_secs_f64() * 1000.0;
658
659 if engine.has_on_request_end() {
661 engine.call_on_request_end(&lua_req, status, duration_ms, &_target_url);
662 }
663
664 for script_name in &route_scripts {
666 engine.call_route_on_request_end(
667 script_name,
668 &lua_req,
669 status,
670 duration_ms,
671 &_target_url,
672 );
673 }
674 }
675
676 metrics.record_request(0, 0, status, duration);
677 record_app_metrics(&metrics, &app_manager, &_target_url, 0, 0, status, duration);
678 let (parts, body) = response.into_parts();
679 let boxed = body.map_err(|_| unreachable!()).boxed();
680 Ok(Response::from_parts(parts, boxed))
681 }
682 Err(e) => {
683 metrics.inc_errors();
684 Err(e)
685 }
686 }
687}
688
689fn is_websocket_request(req: &Request<Incoming>) -> bool {
690 if let Some(upgrade) = req.headers().get("upgrade") {
691 if upgrade == "websocket" {
692 return true;
693 }
694 }
695 false
696}
697
698fn is_metrics_request(req: &Request<Incoming>) -> bool {
699 req.uri().path() == "/metrics"
700}
701
702fn handle_acme_challenge(
703 req: &Request<Incoming>,
704 challenge_store: &ChallengeStore,
705) -> Option<Response<BoxBody>> {
706 let path = req.uri().path();
707 let prefix = "/.well-known/acme-challenge/";
708
709 if !path.starts_with(prefix) {
710 return None;
711 }
712
713 let token = &path[prefix.len()..];
714
715 if let Ok(store) = challenge_store.read() {
716 if let Some(key_auth) = store.get(token) {
717 let body = http_body_util::Full::new(Bytes::from(key_auth.clone())).boxed();
718 return Some(
719 Response::builder()
720 .status(200)
721 .header("Content-Type", "text/plain")
722 .body(body)
723 .unwrap(),
724 );
725 }
726 }
727
728 let body = http_body_util::Full::new(Bytes::from("Challenge not found")).boxed();
729 Some(Response::builder().status(404).body(body).unwrap())
730}
731
732async fn handle_websocket_request(
733 req: Request<Incoming>,
734 _client: ClientType,
735 config: &crate::config::Config,
736 metrics: &SharedMetrics,
737 _start_time: std::time::Instant,
738 _app_manager: Option<Arc<AppManager>>,
739) -> Result<Response<BoxBody>, hyper::Error> {
740 let target_result = find_target(&req, &config.rules);
741
742 if target_result.is_none() {
743 metrics.inc_errors();
744 let body = http_body_util::Full::new(Bytes::from("Misdirected Request")).boxed();
745 return Ok(Response::builder().status(421).body(body).unwrap());
746 }
747
748 let (target_url, _, _, _) = target_result.unwrap();
749
750 let backend_addr = match url::Url::parse(&target_url) {
752 Ok(u) => format!(
753 "{}:{}",
754 u.host_str().unwrap_or("127.0.0.1"),
755 u.port().unwrap_or(80)
756 ),
757 Err(_) => {
758 metrics.inc_errors();
759 let body = http_body_util::Full::new(Bytes::from("Bad backend URL")).boxed();
760 return Ok(Response::builder().status(502).body(body).unwrap());
761 }
762 };
763
764 let path = req.uri().path().to_string();
765 let query = req
766 .uri()
767 .query()
768 .map(|q| format!("?{}", q))
769 .unwrap_or_default();
770
771 let ws_key = req
772 .headers()
773 .get("sec-websocket-key")
774 .and_then(|v| v.to_str().ok())
775 .unwrap_or("")
776 .to_string();
777 let ws_version = req
778 .headers()
779 .get("sec-websocket-version")
780 .and_then(|v| v.to_str().ok())
781 .unwrap_or("13")
782 .to_string();
783 let ws_protocol = req
784 .headers()
785 .get("sec-websocket-protocol")
786 .and_then(|v| v.to_str().ok())
787 .map(|s| s.to_string());
788 let host_header = req
789 .headers()
790 .get("host")
791 .and_then(|v| v.to_str().ok())
792 .unwrap_or(&backend_addr)
793 .to_string();
794
795 tracing::info!(
796 "WebSocket upgrade request to {}{}{}",
797 backend_addr,
798 path,
799 query
800 );
801
802 let backend = match TcpStream::connect(&backend_addr).await {
804 Ok(s) => s,
805 Err(e) => {
806 tracing::error!("Failed to connect to backend for WebSocket: {}", e);
807 metrics.inc_errors();
808 let body = http_body_util::Full::new(Bytes::from("Backend not reachable")).boxed();
809 return Ok(Response::builder().status(502).body(body).unwrap());
810 }
811 };
812
813 let mut handshake = format!(
815 "GET {}{} HTTP/1.1\r\n\
816 Host: {}\r\n\
817 Upgrade: websocket\r\n\
818 Connection: Upgrade\r\n\
819 Sec-WebSocket-Key: {}\r\n\
820 Sec-WebSocket-Version: {}\r\n",
821 path, query, host_header, ws_key, ws_version,
822 );
823 if let Some(proto) = &ws_protocol {
824 handshake.push_str(&format!("Sec-WebSocket-Protocol: {}\r\n", proto));
825 }
826 handshake.push_str("\r\n");
827
828 let (mut backend_read, mut backend_write) = backend.into_split();
829 if let Err(e) = backend_write.write_all(handshake.as_bytes()).await {
830 tracing::error!("Failed to send WebSocket handshake to backend: {}", e);
831 metrics.inc_errors();
832 let body =
833 http_body_util::Full::new(Bytes::from("Failed to initiate WebSocket with backend"))
834 .boxed();
835 return Ok(Response::builder().status(502).body(body).unwrap());
836 }
837
838 let mut response_buf = vec![0u8; 4096];
840 let n = match tokio::io::AsyncReadExt::read(&mut backend_read, &mut response_buf).await {
841 Ok(n) if n > 0 => n,
842 _ => {
843 tracing::error!("No response from backend for WebSocket upgrade");
844 metrics.inc_errors();
845 let body = http_body_util::Full::new(Bytes::from(
846 "Backend did not respond to WebSocket upgrade",
847 ))
848 .boxed();
849 return Ok(Response::builder().status(502).body(body).unwrap());
850 }
851 };
852
853 let response_str = String::from_utf8_lossy(&response_buf[..n]);
854 if !response_str.contains("101") {
855 tracing::error!(
856 "Backend rejected WebSocket upgrade: {}",
857 response_str.lines().next().unwrap_or("")
858 );
859 metrics.inc_errors();
860 let body =
861 http_body_util::Full::new(Bytes::from("Backend rejected WebSocket upgrade")).boxed();
862 return Ok(Response::builder().status(502).body(body).unwrap());
863 }
864
865 let mut accept_key = String::new();
867 let mut resp_protocol = None;
868 for line in response_str.lines().skip(1) {
869 if line.trim().is_empty() {
870 break;
871 }
872 if let Some((name, value)) = line.split_once(':') {
873 let name_lower = name.trim().to_lowercase();
874 let value = value.trim().to_string();
875 if name_lower == "sec-websocket-accept" {
876 accept_key = value;
877 } else if name_lower == "sec-websocket-protocol" {
878 resp_protocol = Some(value);
879 }
880 }
881 }
882
883 let client_upgrade = hyper::upgrade::on(req);
885
886 let backend_stream = backend_read.reunite(backend_write).unwrap();
888
889 tokio::spawn(async move {
891 match client_upgrade.await {
892 Ok(upgraded) => {
893 let mut client_stream = TokioIo::new(upgraded);
894 let (mut br, mut bw) = tokio::io::split(backend_stream);
895 let (mut cr, mut cw) = tokio::io::split(&mut client_stream);
896 let _ = tokio::join!(
897 tokio::io::copy(&mut br, &mut cw),
898 tokio::io::copy(&mut cr, &mut bw),
899 );
900 }
901 Err(e) => {
902 tracing::error!("WebSocket client upgrade failed: {}", e);
903 }
904 }
905 });
906
907 let mut resp = Response::builder()
909 .status(101)
910 .header("Upgrade", "websocket")
911 .header("Connection", "Upgrade")
912 .header("Sec-WebSocket-Accept", accept_key);
913 if let Some(proto) = resp_protocol {
914 resp = resp.header("Sec-WebSocket-Protocol", proto);
915 }
916 Ok(resp
917 .body(http_body_util::Full::new(Bytes::new()).boxed())
918 .unwrap())
919}
920
921async fn handle_regular_request(
923 req: Request<Incoming>,
924 client: ClientType,
925 config: &crate::config::Config,
926 lua_engine: &OptionalLuaEngine,
927 circuit_breaker: &SharedCircuitBreaker,
928 _app_manager: Option<Arc<AppManager>>,
929) -> Result<(Response<BoxBody>, String, Vec<String>), hyper::Error> {
930 let route = find_matching_rule(&req, &config.rules);
931
932 match route {
933 #[allow(unused_mut, unused_variables)]
934 Some(matched_route) => {
935 let path = req.uri().path().to_string();
936 let from_domain_rule = matched_route.from_domain_rule;
937 let matched_prefix = matched_route.matched_prefix();
938
939 if !matched_route.auth.is_empty() && !verify_basic_auth(&req, &matched_route.auth) {
940 tracing::debug!("Basic auth failed for {}", req.uri().path());
941 return Ok((create_auth_required_response(), String::new(), vec![]));
942 }
943 let route_scripts = matched_route.route_scripts.clone();
944
945 let target_selection = select_target(&matched_route, &path, circuit_breaker);
947 let (mut target_url, base_url) = match target_selection {
948 Some((url, base)) => (url, base),
949 None => {
950 let body =
952 http_body_util::Full::new(Bytes::from("Service Unavailable")).boxed();
953 return Ok((
954 Response::builder()
955 .status(503)
956 .body(body)
957 .expect("Failed to build response"),
958 String::new(),
959 route_scripts,
960 ));
961 }
962 };
963 #[cfg(feature = "scripting")]
965 if let Some(ref engine) = lua_engine {
966 for script_name in &route_scripts {
967 let mut lua_req = build_lua_request(&req);
968 match engine.call_route_on_request(script_name, &mut lua_req) {
969 RequestHookResult::Deny { status, body } => {
970 let resp_body = http_body_util::Full::new(Bytes::from(body)).boxed();
971 return Ok((
972 Response::builder().status(status).body(resp_body).unwrap(),
973 target_url,
974 route_scripts.clone(),
975 ));
976 }
977 RequestHookResult::Continue(_) => {}
978 }
979 }
980 }
981
982 #[cfg(feature = "scripting")]
984 if let Some(ref engine) = lua_engine {
985 if engine.has_on_route() {
986 let lua_req = build_lua_request(&req);
987 match engine.call_on_route(&lua_req, &target_url) {
988 RouteHookResult::Override(new_url) => {
989 target_url = new_url;
990 }
991 RouteHookResult::Default => {}
992 }
993 }
994 for script_name in &route_scripts {
996 let lua_req = build_lua_request(&req);
997 match engine.call_route_on_route(script_name, &lua_req, &target_url) {
998 RouteHookResult::Override(new_url) => {
999 target_url = new_url;
1000 }
1001 RouteHookResult::Default => {}
1002 }
1003 }
1004 }
1005
1006 let host_header = if from_domain_rule {
1008 req.uri()
1009 .host()
1010 .or(req.headers().get("host").and_then(|h| h.to_str().ok()))
1011 .map(|s| s.to_string())
1012 } else {
1013 None
1014 };
1015
1016 let (mut parts, body) = req.into_parts();
1017
1018 let uri: hyper::Uri = target_url.parse().expect("valid URI");
1020 parts.uri = uri;
1021 parts.version = http::Version::HTTP_11;
1022 parts.extensions = http::Extensions::new();
1023
1024 let mut request = Request::from_parts(parts, body);
1025
1026 request
1027 .headers_mut()
1028 .insert("X-Forwarded-For", X_FORWARDED_FOR_VALUE.clone());
1029
1030 if from_domain_rule {
1031 if let Some(host) = host_header {
1032 request
1033 .headers_mut()
1034 .insert("X-Forwarded-Host", host.parse().unwrap());
1035 }
1036 }
1037
1038 match client.request(request).await {
1039 Ok(response) => {
1040 let status_code = response.status().as_u16();
1042 if circuit_breaker.is_failure_status(status_code) {
1043 circuit_breaker.record_failure(&base_url);
1044 } else {
1045 circuit_breaker.record_success(&base_url);
1046 }
1047
1048 #[cfg(feature = "scripting")]
1050 if let Some(ref engine) = lua_engine {
1051 let has_global = engine.has_on_response();
1052 let has_route = !route_scripts.is_empty();
1053
1054 if has_global || has_route {
1055 use crate::scripting::ResponseMod;
1056
1057 let lua_req = LuaRequest {
1058 method: String::new(),
1059 path: String::new(),
1060 headers: std::collections::HashMap::new(),
1061 host: String::new(),
1062 content_length: 0,
1063 };
1064 let resp_headers = extract_response_headers(response.headers());
1065 let resp_status = response.status().as_u16();
1066
1067 let mut all_mods: Vec<ResponseMod> = Vec::new();
1069 if has_global {
1070 all_mods.push(engine.call_on_response(
1071 &lua_req,
1072 resp_status,
1073 &resp_headers,
1074 ));
1075 }
1076 for script_name in &route_scripts {
1077 all_mods.push(engine.call_route_on_response(
1078 script_name,
1079 &lua_req,
1080 resp_status,
1081 &resp_headers,
1082 ));
1083 }
1084
1085 let mut merged = ResponseMod::default();
1087 for mods in all_mods {
1088 merged.set_headers.extend(mods.set_headers);
1089 merged.remove_headers.extend(mods.remove_headers);
1090 if mods.replace_body.is_some() {
1091 merged.replace_body = mods.replace_body;
1092 }
1093 if mods.override_status.is_some() {
1094 merged.override_status = mods.override_status;
1095 }
1096 }
1097
1098 if !merged.set_headers.is_empty()
1100 || !merged.remove_headers.is_empty()
1101 || merged.replace_body.is_some()
1102 || merged.override_status.is_some()
1103 {
1104 let (mut parts, body) = response.into_parts();
1105
1106 if let Some(status) = merged.override_status {
1107 parts.status =
1108 hyper::StatusCode::from_u16(status).unwrap_or(parts.status);
1109 }
1110
1111 for name in &merged.remove_headers {
1112 if let Ok(header_name) =
1113 name.parse::<hyper::header::HeaderName>()
1114 {
1115 parts.headers.remove(header_name);
1116 }
1117 }
1118
1119 for (name, value) in &merged.set_headers {
1120 if let (Ok(header_name), Ok(header_value)) = (
1121 name.parse::<hyper::header::HeaderName>(),
1122 value.parse::<HeaderValue>(),
1123 ) {
1124 parts.headers.insert(header_name, header_value);
1125 }
1126 }
1127
1128 if let Some(new_body) = merged.replace_body {
1129 let new_bytes = Bytes::from(new_body);
1130 parts.headers.remove("content-length");
1131 parts.headers.insert(
1132 "content-length",
1133 new_bytes.len().to_string().parse().unwrap(),
1134 );
1135 let boxed = http_body_util::Full::new(new_bytes).boxed();
1136 return Ok((
1137 Response::from_parts(parts, boxed),
1138 target_url,
1139 route_scripts.clone(),
1140 ));
1141 }
1142
1143 let boxed = body.map_err(|_| unreachable!()).boxed();
1144 return Ok((
1145 Response::from_parts(parts, boxed),
1146 target_url,
1147 route_scripts.clone(),
1148 ));
1149 }
1150 }
1151 }
1152
1153 let is_html = response
1154 .headers()
1155 .get("content-type")
1156 .and_then(|v| v.to_str().ok())
1157 .map(|ct| ct.starts_with("text/html"))
1158 .unwrap_or(false);
1159
1160 if is_html {
1161 if let Some(prefix) = matched_prefix {
1162 let (parts, body) = response.into_parts();
1163 let body_bytes = body
1164 .collect()
1165 .await
1166 .map(|collected| collected.to_bytes())
1167 .unwrap_or_default();
1168
1169 let is_gzip = parts
1171 .headers
1172 .get("content-encoding")
1173 .and_then(|v| v.to_str().ok())
1174 .map(|v| v.contains("gzip"))
1175 .unwrap_or(false);
1176 let is_deflate = parts
1177 .headers
1178 .get("content-encoding")
1179 .and_then(|v| v.to_str().ok())
1180 .map(|v| v.contains("deflate"))
1181 .unwrap_or(false);
1182
1183 let raw_bytes = if is_gzip {
1184 use std::io::Read;
1185 let mut decoder = flate2::read::GzDecoder::new(&body_bytes[..]);
1186 let mut decoded = Vec::new();
1187 decoder.read_to_end(&mut decoded).unwrap_or_default();
1188 Bytes::from(decoded)
1189 } else if is_deflate {
1190 use std::io::Read;
1191 let mut decoder =
1192 flate2::read::DeflateDecoder::new(&body_bytes[..]);
1193 let mut decoded = Vec::new();
1194 decoder.read_to_end(&mut decoded).unwrap_or_default();
1195 Bytes::from(decoded)
1196 } else {
1197 body_bytes
1198 };
1199
1200 let html = String::from_utf8_lossy(&raw_bytes);
1201 let rewritten = html
1202 .replace("href=\"/", &format!("href=\"{}/", prefix))
1203 .replace("src=\"/", &format!("src=\"{}/", prefix))
1204 .replace("action=\"/", &format!("action=\"{}/", prefix));
1205 let rewritten_bytes = Bytes::from(rewritten);
1206 let mut parts = parts;
1207 parts.headers.remove("content-encoding");
1208 parts.headers.remove("content-length");
1209 parts.headers.insert(
1210 "content-length",
1211 rewritten_bytes.len().to_string().parse().unwrap(),
1212 );
1213 let boxed = http_body_util::Full::new(rewritten_bytes).boxed();
1214 return Ok((
1215 Response::from_parts(parts, boxed),
1216 target_url,
1217 route_scripts.clone(),
1218 ));
1219 }
1220 }
1221
1222 let (parts, body) = response.into_parts();
1223 let boxed = body.map_err(|_| unreachable!()).boxed();
1224 Ok((
1225 Response::from_parts(parts, boxed),
1226 target_url,
1227 route_scripts,
1228 ))
1229 }
1230 Err(e) => {
1231 circuit_breaker.record_failure(&base_url);
1232 tracing::error!("Backend request failed: {} (target: {})", e, target_url);
1233 let body = http_body_util::Full::new(Bytes::from("Bad Gateway")).boxed();
1234 Ok((
1235 Response::builder()
1236 .status(502)
1237 .body(body)
1238 .expect("Failed to build response"),
1239 target_url,
1240 route_scripts,
1241 ))
1242 }
1243 }
1244 }
1245 None => {
1246 let _ = lua_engine;
1248 let body = http_body_util::Full::new(Bytes::from("Misdirected Request")).boxed();
1249 Ok((
1250 Response::builder()
1251 .status(421)
1252 .body(body)
1253 .expect("Failed to build response"),
1254 String::new(),
1255 vec![],
1256 ))
1257 }
1258 }
1259}
1260
/// How the request path is mapped onto a target URL once a rule has matched;
/// interpreted by `resolve_target_url`.
enum UrlResolution {
    /// Append the full request path to the target URL (produced for `Domain`
    /// and `Default` rules).
    AppendPath,
    /// Remove the contained, matched prefix from the request path and append
    /// the remainder to the target URL (produced for `DomainPath` and `Prefix`
    /// rules).
    StripPrefix(String),
    /// Use the target URL as-is, ignoring the request path (produced for
    /// `Exact` and `Regex` rules).
    Identity,
}
1270
/// The result of matching a request against the configured proxy rules.
///
/// `find_matching_rule` only produces routes whose `targets` slice is non-empty.
struct MatchedRoute<'a> {
    /// Candidate backends of the matched rule, in configuration order
    /// (borrowed from the rule).
    targets: &'a [crate::config::Target],
    /// `true` when the rule matched on the request host (`Domain` / `DomainPath`).
    from_domain_rule: bool,
    /// How the request path is turned into a backend URL.
    resolution: UrlResolution,
    /// The rule's `scripts` list, cloned out of the rule.
    route_scripts: Vec<String>,
    /// The rule's `auth` entries, cloned out of the rule.
    auth: Vec<crate::auth::BasicAuth>,
}
1279
1280impl<'a> MatchedRoute<'a> {
1281 fn matched_prefix(&self) -> Option<String> {
1282 match &self.resolution {
1283 UrlResolution::StripPrefix(prefix) => Some(prefix.trim_end_matches('/').to_string()),
1284 _ => None,
1285 }
1286 }
1287}
1288
1289fn resolve_target_url(
1291 target: &crate::config::Target,
1292 path: &str,
1293 resolution: &UrlResolution,
1294) -> String {
1295 let target_str = target.url.as_str();
1296 match resolution {
1297 UrlResolution::AppendPath => {
1298 if target_str.ends_with('/') {
1299 format!("{}{}", target_str, &path[1..])
1300 } else {
1301 format!("{}{}", target_str, path)
1302 }
1303 }
1304 UrlResolution::StripPrefix(prefix) => {
1305 let suffix = if path.len() >= prefix.len() {
1306 &path[prefix.len()..]
1307 } else {
1308 ""
1309 };
1310 format!("{}{}", target_str, suffix)
1311 }
1312 UrlResolution::Identity => target_str.to_owned(),
1313 }
1314}
1315
1316fn find_matching_rule<'a>(
1318 req: &Request<Incoming>,
1319 rules: &'a [crate::config::ProxyRule],
1320) -> Option<MatchedRoute<'a>> {
1321 let host = req
1322 .uri()
1323 .host()
1324 .or(req.headers().get("host").and_then(|h| h.to_str().ok()))
1325 .map(|h| h.split(':').next().unwrap_or(h))?;
1326
1327 let path = req.uri().path();
1328
1329 for rule in rules {
1330 match &rule.matcher {
1331 crate::config::RuleMatcher::Domain(domain) => {
1332 if domain == host && !rule.targets.is_empty() {
1333 return Some(MatchedRoute {
1334 targets: &rule.targets,
1335 from_domain_rule: true,
1336 resolution: UrlResolution::AppendPath,
1337 route_scripts: rule.scripts.clone(),
1338 auth: rule.auth.clone(),
1339 });
1340 }
1341 }
1342 crate::config::RuleMatcher::DomainPath(domain, path_prefix) => {
1343 if domain == host && !rule.targets.is_empty() {
1344 let matches = path.starts_with(path_prefix)
1345 || (path_prefix.ends_with('/')
1346 && path == path_prefix.trim_end_matches('/'));
1347 if matches {
1348 return Some(MatchedRoute {
1349 targets: &rule.targets,
1350 from_domain_rule: true,
1351 resolution: UrlResolution::StripPrefix(path_prefix.clone()),
1352 route_scripts: rule.scripts.clone(),
1353 auth: rule.auth.clone(),
1354 });
1355 }
1356 }
1357 }
1358 _ => {}
1359 }
1360 }
1361
1362 for rule in rules {
1364 match &rule.matcher {
1365 crate::config::RuleMatcher::Exact(exact) => {
1366 if path == exact && !rule.targets.is_empty() {
1367 return Some(MatchedRoute {
1368 targets: &rule.targets,
1369 from_domain_rule: false,
1370 resolution: UrlResolution::Identity,
1371 route_scripts: rule.scripts.clone(),
1372 auth: rule.auth.clone(),
1373 });
1374 }
1375 }
1376 crate::config::RuleMatcher::Prefix(prefix) => {
1377 if !rule.targets.is_empty() {
1378 let matches = path.starts_with(prefix)
1380 || (prefix.ends_with('/') && path == prefix.trim_end_matches('/'));
1381 if matches {
1382 return Some(MatchedRoute {
1383 targets: &rule.targets,
1384 from_domain_rule: false,
1385 resolution: UrlResolution::StripPrefix(prefix.clone()),
1386 route_scripts: rule.scripts.clone(),
1387 auth: rule.auth.clone(),
1388 });
1389 }
1390 }
1391 }
1392 crate::config::RuleMatcher::Regex(ref rm) => {
1393 if rm.is_match(path) && !rule.targets.is_empty() {
1394 return Some(MatchedRoute {
1395 targets: &rule.targets,
1396 from_domain_rule: false,
1397 resolution: UrlResolution::Identity,
1398 route_scripts: rule.scripts.clone(),
1399 auth: rule.auth.clone(),
1400 });
1401 }
1402 }
1403 _ => {}
1404 }
1405 }
1406
1407 for rule in rules {
1409 if let crate::config::RuleMatcher::Default = &rule.matcher {
1410 if !rule.targets.is_empty() {
1411 return Some(MatchedRoute {
1412 targets: &rule.targets,
1413 from_domain_rule: false,
1414 resolution: UrlResolution::AppendPath,
1415 route_scripts: rule.scripts.clone(),
1416 auth: rule.auth.clone(),
1417 });
1418 }
1419 }
1420 }
1421
1422 None
1423}
1424
1425fn select_target(
1428 route: &MatchedRoute<'_>,
1429 path: &str,
1430 circuit_breaker: &crate::circuit_breaker::CircuitBreaker,
1431) -> Option<(String, String)> {
1432 for target in route.targets {
1433 let base_url = target.url.as_str().to_owned();
1434 if circuit_breaker.is_available(&base_url) {
1435 let resolved = resolve_target_url(target, path, &route.resolution);
1436 return Some((resolved, base_url));
1437 }
1438 }
1439 None
1440}
1441
1442fn find_target(
1444 req: &Request<Incoming>,
1445 rules: &[crate::config::ProxyRule],
1446) -> Option<(String, bool, Option<String>, Vec<String>)> {
1447 let route = find_matching_rule(req, rules)?;
1448 let path = req.uri().path();
1449 let target = route.targets.first()?;
1450 let resolved = resolve_target_url(target, path, &route.resolution);
1451 let matched_prefix = route.matched_prefix();
1452 Some((
1453 resolved,
1454 route.from_domain_rule,
1455 matched_prefix,
1456 route.route_scripts,
1457 ))
1458}