Skip to main content

zlayer_proxy/
server.rs

1//! HTTP server implementation
2//!
3//! This module provides the HTTP/HTTPS server for the proxy.
4//! Uses `ServiceRegistry` for route resolution instead of the legacy `Router`.
5
6use crate::acme::CertManager;
7use crate::config::ProxyConfig;
8use crate::error::{ProxyError, Result};
9use crate::lb::LoadBalancer;
10use crate::network_policy::NetworkPolicyChecker;
11use crate::routes::ServiceRegistry;
12use crate::service::ReverseProxyService;
13use crate::sni_resolver::SniCertResolver;
14use hyper::body::Incoming;
15use hyper::server::conn::http1;
16use hyper::service::service_fn;
17use hyper::Request;
18use hyper_util::rt::TokioIo;
19use std::net::SocketAddr;
20use std::sync::Arc;
21use tokio::net::TcpListener;
22use tokio::sync::watch;
23use tokio_rustls::TlsAcceptor;
24use tracing::{debug, error, info, warn};
25
26/// The proxy server
27pub struct ProxyServer {
28    /// Server configuration
29    config: Arc<ProxyConfig>,
30    /// Service registry for route resolution
31    registry: Arc<ServiceRegistry>,
32    /// Load balancer for backend selection
33    load_balancer: Arc<LoadBalancer>,
34    /// Shutdown signal sender
35    shutdown_tx: watch::Sender<bool>,
36    /// Shutdown signal receiver
37    shutdown_rx: watch::Receiver<bool>,
38    /// TLS acceptor for HTTPS connections
39    tls_acceptor: Option<TlsAcceptor>,
40    /// Certificate manager for ACME challenge responses
41    cert_manager: Option<Arc<CertManager>>,
42    /// Optional network policy checker for access control enforcement
43    network_policy_checker: Option<NetworkPolicyChecker>,
44}
45
46impl ProxyServer {
47    /// Create a new proxy server
48    pub fn new(
49        config: ProxyConfig,
50        registry: Arc<ServiceRegistry>,
51        load_balancer: Arc<LoadBalancer>,
52    ) -> Self {
53        let (shutdown_tx, shutdown_rx) = watch::channel(false);
54
55        Self {
56            config: Arc::new(config),
57            registry,
58            load_balancer,
59            shutdown_tx,
60            shutdown_rx,
61            tls_acceptor: None,
62            cert_manager: None,
63            network_policy_checker: None,
64        }
65    }
66
67    /// Create a proxy server with an existing registry (alias for `new`)
68    pub fn with_registry(
69        config: ProxyConfig,
70        registry: Arc<ServiceRegistry>,
71        load_balancer: Arc<LoadBalancer>,
72    ) -> Self {
73        Self::new(config, registry, load_balancer)
74    }
75
76    /// Create a proxy server with TLS via SNI resolver
77    pub fn with_tls_resolver(
78        config: ProxyConfig,
79        registry: Arc<ServiceRegistry>,
80        load_balancer: Arc<LoadBalancer>,
81        resolver: Arc<SniCertResolver>,
82    ) -> Self {
83        let tls_config = rustls::ServerConfig::builder()
84            .with_no_client_auth()
85            .with_cert_resolver(resolver);
86        let acceptor = TlsAcceptor::from(Arc::new(tls_config));
87        let (shutdown_tx, shutdown_rx) = watch::channel(false);
88
89        Self {
90            config: Arc::new(config),
91            registry,
92            load_balancer,
93            shutdown_tx,
94            shutdown_rx,
95            tls_acceptor: Some(acceptor),
96            cert_manager: None,
97            network_policy_checker: None,
98        }
99    }
100
101    /// Set the certificate manager for ACME challenge interception
102    #[must_use]
103    pub fn with_cert_manager(mut self, cm: Arc<CertManager>) -> Self {
104        self.cert_manager = Some(cm);
105        self
106    }
107
108    /// Set the network policy checker for access control enforcement
109    #[must_use]
110    pub fn with_network_policy_checker(mut self, checker: NetworkPolicyChecker) -> Self {
111        self.network_policy_checker = Some(checker);
112        self
113    }
114
115    /// Check if TLS is enabled
116    #[must_use]
117    pub fn has_tls(&self) -> bool {
118        self.tls_acceptor.is_some()
119    }
120
121    /// Get the TLS acceptor if configured
122    #[must_use]
123    pub fn tls_acceptor(&self) -> Option<&TlsAcceptor> {
124        self.tls_acceptor.as_ref()
125    }
126
127    /// Get the service registry
128    #[must_use]
129    pub fn registry(&self) -> Arc<ServiceRegistry> {
130        self.registry.clone()
131    }
132
133    /// Get the configuration
134    #[must_use]
135    pub fn config(&self) -> Arc<ProxyConfig> {
136        self.config.clone()
137    }
138
139    /// Signal the server to shut down
140    pub fn shutdown(&self) {
141        let _ = self.shutdown_tx.send(true);
142    }
143
144    /// Run the HTTP server
145    ///
146    /// # Errors
147    ///
148    /// Returns an error if binding to the configured HTTP address fails
149    /// or if the accept loop encounters a fatal error.
150    pub async fn run(&self) -> Result<()> {
151        let addr = self.config.server.http_addr;
152        let listener = TcpListener::bind(addr)
153            .await
154            .map_err(|e| ProxyError::BindFailed {
155                addr,
156                reason: e.to_string(),
157            })?;
158
159        info!(addr = %addr, "HTTP proxy server listening");
160
161        self.accept_loop(listener).await
162    }
163
164    /// Run the server on a specific address
165    ///
166    /// # Errors
167    ///
168    /// Returns an error if binding to the given address fails
169    /// or if the accept loop encounters a fatal error.
170    pub async fn run_on(&self, addr: SocketAddr) -> Result<()> {
171        let listener = TcpListener::bind(addr)
172            .await
173            .map_err(|e| ProxyError::BindFailed {
174                addr,
175                reason: e.to_string(),
176            })?;
177
178        info!(addr = %addr, "HTTP proxy server listening");
179
180        self.accept_loop(listener).await
181    }
182
183    async fn accept_loop(&self, listener: TcpListener) -> Result<()> {
184        let mut shutdown_rx = self.shutdown_rx.clone();
185
186        loop {
187            tokio::select! {
188                // Check for shutdown signal
189                _ = shutdown_rx.changed() => {
190                    if *shutdown_rx.borrow() {
191                        info!("Shutting down proxy server");
192                        break;
193                    }
194                }
195
196                // Accept new connections
197                result = listener.accept() => {
198                    match result {
199                        Ok((stream, remote_addr)) => {
200                            let registry = self.registry.clone();
201                            let load_balancer = self.load_balancer.clone();
202                            let config = self.config.clone();
203                            let cert_manager = self.cert_manager.clone();
204                            let npc = self.network_policy_checker.clone();
205
206                            tokio::spawn(async move {
207                                if let Err(e) = Self::handle_connection(
208                                    stream,
209                                    remote_addr,
210                                    registry,
211                                    load_balancer,
212                                    config,
213                                    cert_manager,
214                                    npc,
215                                ).await {
216                                    debug!(
217                                        error = %e,
218                                        remote_addr = %remote_addr,
219                                        "Connection error"
220                                    );
221                                }
222                            });
223                        }
224                        Err(e) => {
225                            warn!(error = %e, "Failed to accept connection");
226                        }
227                    }
228                }
229            }
230        }
231
232        Ok(())
233    }
234
235    #[allow(clippy::too_many_arguments)]
236    async fn handle_connection(
237        stream: tokio::net::TcpStream,
238        remote_addr: SocketAddr,
239        registry: Arc<ServiceRegistry>,
240        load_balancer: Arc<LoadBalancer>,
241        config: Arc<ProxyConfig>,
242        cert_manager: Option<Arc<CertManager>>,
243        network_policy_checker: Option<NetworkPolicyChecker>,
244    ) -> Result<()> {
245        let io = TokioIo::new(stream);
246
247        let mut service =
248            ReverseProxyService::new(registry, load_balancer, config).with_remote_addr(remote_addr);
249        if let Some(cm) = cert_manager {
250            service = service.with_cert_manager(cm);
251        }
252        if let Some(checker) = network_policy_checker {
253            service = service.with_network_policy_checker(checker);
254        }
255
256        let service = service_fn(move |req: Request<Incoming>| {
257            let svc = service.clone();
258            async move {
259                match svc.proxy_request(req).await {
260                    Ok(response) => Ok::<_, hyper::Error>(response),
261                    Err(e) => {
262                        error!(error = %e, "Proxy error");
263                        Ok(ReverseProxyService::error_response(&e))
264                    }
265                }
266            }
267        });
268
269        http1::Builder::new()
270            .preserve_header_case(true)
271            .title_case_headers(false)
272            .serve_connection(io, service)
273            .with_upgrades()
274            .await
275            .map_err(ProxyError::Hyper)?;
276
277        Ok(())
278    }
279
280    /// Run the HTTPS server
281    ///
282    /// This requires TLS to be configured when creating the `ProxyServer`.
283    ///
284    /// # Errors
285    ///
286    /// Returns an error if TLS is not configured, if binding to the
287    /// configured HTTPS address fails, or if the accept loop encounters a fatal error.
288    pub async fn run_https(&self) -> Result<()> {
289        let acceptor = self
290            .tls_acceptor
291            .as_ref()
292            .ok_or_else(|| ProxyError::Config("TLS not configured".to_string()))?;
293
294        let addr = self.config.server.https_addr;
295        let listener = TcpListener::bind(addr)
296            .await
297            .map_err(|e| ProxyError::BindFailed {
298                addr,
299                reason: e.to_string(),
300            })?;
301
302        info!(addr = %addr, "HTTPS proxy server listening");
303
304        self.accept_loop_tls(listener, acceptor.clone()).await
305    }
306
307    /// Run the HTTPS server on a specific address
308    ///
309    /// # Errors
310    ///
311    /// Returns an error if TLS is not configured, if binding to the
312    /// given address fails, or if the accept loop encounters a fatal error.
313    pub async fn run_https_on(&self, addr: SocketAddr) -> Result<()> {
314        let acceptor = self
315            .tls_acceptor
316            .as_ref()
317            .ok_or_else(|| ProxyError::Config("TLS not configured".to_string()))?;
318
319        let listener = TcpListener::bind(addr)
320            .await
321            .map_err(|e| ProxyError::BindFailed {
322                addr,
323                reason: e.to_string(),
324            })?;
325
326        info!(addr = %addr, "HTTPS proxy server listening");
327
328        self.accept_loop_tls(listener, acceptor.clone()).await
329    }
330
331    /// Run both HTTP and HTTPS servers concurrently
332    ///
333    /// This requires TLS to be configured when creating the `ProxyServer`.
334    ///
335    /// # Errors
336    ///
337    /// Returns an error if TLS is not configured, if binding to either
338    /// the HTTP or HTTPS address fails, or if either accept loop encounters
339    /// a fatal error.
340    #[allow(clippy::similar_names)]
341    pub async fn run_both(&self) -> Result<()> {
342        let http_addr = self.config.server.http_addr;
343        let https_addr = self.config.server.https_addr;
344
345        let acceptor = self
346            .tls_acceptor
347            .as_ref()
348            .ok_or_else(|| ProxyError::Config("TLS not configured".to_string()))?;
349
350        let http_listener =
351            TcpListener::bind(http_addr)
352                .await
353                .map_err(|e| ProxyError::BindFailed {
354                    addr: http_addr,
355                    reason: e.to_string(),
356                })?;
357
358        let https_listener =
359            TcpListener::bind(https_addr)
360                .await
361                .map_err(|e| ProxyError::BindFailed {
362                    addr: https_addr,
363                    reason: e.to_string(),
364                })?;
365
366        info!(http = %http_addr, https = %https_addr, "Proxy server listening");
367
368        // Run both accept loops concurrently
369        let http_future = self.accept_loop(http_listener);
370        let https_future = self.accept_loop_tls(https_listener, acceptor.clone());
371
372        tokio::select! {
373            result = http_future => result,
374            result = https_future => result,
375        }
376    }
377
378    async fn accept_loop_tls(&self, listener: TcpListener, acceptor: TlsAcceptor) -> Result<()> {
379        let mut shutdown_rx = self.shutdown_rx.clone();
380
381        loop {
382            tokio::select! {
383                // Check for shutdown signal
384                _ = shutdown_rx.changed() => {
385                    if *shutdown_rx.borrow() {
386                        info!("Shutting down HTTPS proxy server");
387                        break;
388                    }
389                }
390
391                // Accept new connections
392                result = listener.accept() => {
393                    match result {
394                        Ok((stream, remote_addr)) => {
395                            let registry = self.registry.clone();
396                            let load_balancer = self.load_balancer.clone();
397                            let config = self.config.clone();
398                            let acceptor = acceptor.clone();
399                            let cert_manager = self.cert_manager.clone();
400                            let npc = self.network_policy_checker.clone();
401
402                            tokio::spawn(async move {
403                                if let Err(e) = Self::handle_tls_connection(
404                                    stream,
405                                    remote_addr,
406                                    registry,
407                                    load_balancer,
408                                    config,
409                                    acceptor,
410                                    cert_manager,
411                                    npc,
412                                ).await {
413                                    debug!(
414                                        error = %e,
415                                        remote_addr = %remote_addr,
416                                        "TLS connection error"
417                                    );
418                                }
419                            });
420                        }
421                        Err(e) => {
422                            warn!(error = %e, "Failed to accept TLS connection");
423                        }
424                    }
425                }
426            }
427        }
428
429        Ok(())
430    }
431
432    #[allow(clippy::too_many_arguments)]
433    async fn handle_tls_connection(
434        stream: tokio::net::TcpStream,
435        remote_addr: SocketAddr,
436        registry: Arc<ServiceRegistry>,
437        load_balancer: Arc<LoadBalancer>,
438        config: Arc<ProxyConfig>,
439        acceptor: TlsAcceptor,
440        cert_manager: Option<Arc<CertManager>>,
441        network_policy_checker: Option<NetworkPolicyChecker>,
442    ) -> Result<()> {
443        // Perform TLS handshake
444        let tls_stream = acceptor
445            .accept(stream)
446            .await
447            .map_err(|e| ProxyError::Tls(format!("TLS handshake failed: {e}")))?;
448
449        let io = TokioIo::new(tls_stream);
450
451        let mut service = ReverseProxyService::new(registry, load_balancer, config)
452            .with_remote_addr(remote_addr)
453            .with_tls(true);
454        if let Some(cm) = cert_manager {
455            service = service.with_cert_manager(cm);
456        }
457        if let Some(checker) = network_policy_checker {
458            service = service.with_network_policy_checker(checker);
459        }
460
461        let service = service_fn(move |req: Request<Incoming>| {
462            let svc = service.clone();
463            async move {
464                match svc.proxy_request(req).await {
465                    Ok(response) => Ok::<_, hyper::Error>(response),
466                    Err(e) => {
467                        error!(error = %e, "Proxy error");
468                        Ok(ReverseProxyService::error_response(&e))
469                    }
470                }
471            }
472        });
473
474        http1::Builder::new()
475            .preserve_header_case(true)
476            .title_case_headers(false)
477            .serve_connection(io, service)
478            .with_upgrades()
479            .await
480            .map_err(ProxyError::Hyper)?;
481
482        Ok(())
483    }
484}
485
#[cfg(test)]
mod tests {
    // `super::*` also brings in the parent's imports (Arc, SocketAddr,
    // LoadBalancer, ServiceRegistry, ProxyConfig, ProxyError, ...).
    use super::*;
    use crate::routes::{ResolvedService, RouteEntry};
    use zlayer_spec::{ExposeType, Protocol};

    /// Helper to build a minimal `RouteEntry` for tests.
    fn make_entry(
        service: &str,
        host: Option<&str>,
        path: &str,
        backends: Vec<SocketAddr>,
    ) -> RouteEntry {
        RouteEntry {
            service_name: service.to_string(),
            endpoint_name: "http".to_string(),
            host: host.map(std::string::ToString::to_string),
            path_prefix: path.to_string(),
            resolved: ResolvedService {
                name: service.to_string(),
                backends,
                use_tls: false,
                sni_hostname: String::new(),
                expose: ExposeType::Public,
                protocol: Protocol::Http,
                strip_prefix: false,
                path_prefix: path.to_string(),
                target_port: 8080,
            },
        }
    }

    /// Build a plain-HTTP server with an empty registry.
    fn make_server() -> ProxyServer {
        ProxyServer::new(
            ProxyConfig::default(),
            Arc::new(ServiceRegistry::new()),
            Arc::new(LoadBalancer::new()),
        )
    }

    /// Loopback address with an OS-assigned port, so tests never collide.
    fn ephemeral_addr() -> SocketAddr {
        "127.0.0.1:0".parse().unwrap()
    }

    #[tokio::test]
    async fn test_server_shutdown() {
        let server = make_server();

        // Poll the server and the shutdown trigger on the same task. The yield
        // lets the accept loop start before the signal is sent.
        let (result, ()) = tokio::join!(server.run_on(ephemeral_addr()), async {
            tokio::task::yield_now().await;
            server.shutdown();
        });

        assert!(result.is_ok(), "server should exit cleanly after shutdown");
    }

    #[tokio::test]
    async fn test_shutdown_signalled_before_run() {
        let server = make_server();
        server.shutdown();

        // A signal sent before the accept loop starts must still stop it.
        assert!(server.run_on(ephemeral_addr()).await.is_ok());
    }

    #[tokio::test]
    async fn test_https_requires_tls() {
        let server = make_server();
        assert!(!server.has_tls());

        let result = server.run_https_on(ephemeral_addr()).await;
        assert!(matches!(result, Err(ProxyError::Config(_))));
    }

    #[tokio::test]
    async fn test_registry_integration() {
        let registry = Arc::new(ServiceRegistry::new());

        // Add a route
        registry
            .register(make_entry(
                "test-service",
                None,
                "/api",
                vec!["127.0.0.1:8081".parse().unwrap()],
            ))
            .await;

        let lb = Arc::new(LoadBalancer::new());
        let server = ProxyServer::new(ProxyConfig::default(), registry, lb);

        // Verify registry is accessible
        let reg = server.registry();
        assert_eq!(reg.route_count().await, 1);
    }
}