trz_gateway_server/server/mod.rs

1use std::future::Future;
2use std::net::SocketAddr;
3use std::net::ToSocketAddrs;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use axum_server::Handle;
8use axum_server::tls_rustls::RustlsConfig;
9use futures::FutureExt;
10use futures::future::Shared;
11use nameth::NamedEnumValues as _;
12use nameth::nameth;
13use tokio::sync::oneshot;
14use tokio_rustls::TlsConnector;
15use tracing::Instrument as _;
16use tracing::Span;
17use tracing::debug;
18use tracing::warn;
19use trz_gateway_common::certificate_info::X509CertificateInfo;
20use trz_gateway_common::handle::ServerHandle;
21use trz_gateway_common::security_configuration::certificate::CertificateConfig;
22use trz_gateway_common::security_configuration::certificate::tls_server::ToTlsServer;
23use trz_gateway_common::security_configuration::certificate::tls_server::ToTlsServerError;
24use trz_gateway_common::security_configuration::custom_server_certificate_verifier::SignedExtensionCertificateVerifier;
25use trz_gateway_common::security_configuration::trusted_store::TrustedStoreConfig;
26use trz_gateway_common::security_configuration::trusted_store::cache::CachedTrustedStoreConfig;
27use trz_gateway_common::security_configuration::trusted_store::tls_client::ToTlsClient;
28use trz_gateway_common::security_configuration::trusted_store::tls_client::ToTlsClientError;
29use trz_gateway_common::tracing::EnableTracingError;
30
31use self::gateway_config::AppConfig;
32use self::gateway_config::GatewayConfig;
33use self::issuer_config::IssuerConfig;
34use self::issuer_config::IssuerConfigError;
35use crate::connection::Connections;
36
37mod app;
38mod certificate;
39pub mod gateway_config;
40mod issuer_config;
41pub mod root_ca_configuration;
42mod tunnel;
43
44#[cfg(test)]
45mod tests;
46
47pub struct Server {
48    shutdown: Shared<Pin<Box<dyn Future<Output = ()> + Send + Sync>>>,
49    root_ca: Arc<X509CertificateInfo>,
50    tls_server: RustlsConfig,
51    tls_client: TlsConnector,
52    connections: Arc<Connections>,
53    issuer_config: IssuerConfig,
54    app_config: Box<dyn AppConfig>,
55}
56
57impl Server {
58    pub async fn run<C: GatewayConfig>(
59        config: C,
60    ) -> Result<(Arc<Self>, ServerHandle<()>), GatewayError<C>> {
61        if config.enable_tracing() {
62            trz_gateway_common::tracing::enable_tracing()?;
63        }
64
65        let (shutdown_rx, terminated_tx, handle) = ServerHandle::new();
66        let shutdown_rx: Pin<Box<dyn Future<Output = ()> + Send + Sync>> = Box::pin(shutdown_rx);
67
68        let root_ca = config
69            .root_ca()
70            .certificate()
71            .map_err(|error| GatewayError::RootCa(error.into()))?;
72        debug!("Got Root CA: {}", root_ca.display());
73
74        let client_certificate_issuer = config.client_certificate_issuer();
75        let issuer_config = IssuerConfig::new(&client_certificate_issuer)?;
76
77        let tls_server = config.tls().to_tls_server().await?;
78        debug!("Got TLS server config");
79
80        let tls_client = config
81            .root_ca()
82            .to_tls_client(SignedExtensionCertificateVerifier {
83                store: CachedTrustedStoreConfig::new(client_certificate_issuer)
84                    .map_err(GatewayError::CachedTrustedStoreConfig)?,
85                signer_name: issuer_config.signer_name.clone(),
86            })
87            .await?;
88        debug!("Got TLS client config");
89
90        let server = Arc::new(Self {
91            shutdown: shutdown_rx.shared(),
92            root_ca,
93            tls_server: RustlsConfig::from_config(Arc::from(tls_server)),
94            tls_client: TlsConnector::from(Arc::new(tls_client)),
95            connections: Arc::new(Connections::default()),
96            issuer_config,
97            app_config: Box::new(config.app_config()),
98        });
99
100        let (host, port) = (config.host(), config.port());
101        let socket_addrs = (host, port).to_socket_addrs();
102        let socket_addrs = socket_addrs.map_err(|error| GatewayError::ToSocketAddrs {
103            host: host.to_owned(),
104            port,
105            error,
106        })?;
107
108        let mut terminated = vec![];
109
110        for socket_addr in socket_addrs {
111            debug!("Setup server on {socket_addr}");
112            let task = server.clone().run_endpoint(socket_addr, Span::current());
113            let (terminated_tx, terminated_rx) = oneshot::channel();
114            terminated.push(terminated_rx);
115            tokio::spawn(async move {
116                match task.await {
117                    Ok(()) => (),
118                    Err(error) => warn!("Failed {error}"),
119                }
120                let _: Result<(), ()> = terminated_tx.send(());
121            });
122        }
123
124        {
125            use futures::future::join_all;
126            let all_terminated = join_all(terminated);
127            tokio::spawn(
128                async move {
129                    let _: Vec<Result<(), oneshot::error::RecvError>> = all_terminated.await;
130                    let _: Result<(), ()> = terminated_tx.send(());
131                }
132                .in_current_span(),
133            );
134        }
135        Ok((server, handle))
136    }
137
138    async fn run_endpoint(
139        self: Arc<Self>,
140        socket_addr: SocketAddr,
141        span: Span,
142    ) -> Result<(), RunGatewayError> {
143        let app = self.make_app(span);
144
145        let handle = Handle::new();
146        let axum_server =
147            axum_server::bind_rustls(socket_addr, self.tls_server.clone()).handle(handle.clone());
148
149        let shutdown = self.shutdown.clone();
150        tokio::spawn(
151            async move {
152                let () = shutdown.await;
153                handle.shutdown();
154            }
155            .in_current_span(),
156        );
157
158        debug!("Serving...");
159        let () = axum_server
160            .serve(app.into_make_service_with_connect_info::<SocketAddr>())
161            .await
162            .map_err(RunGatewayError::Serve)?;
163        debug!("Serving: done");
164        Ok(())
165    }
166
167    pub fn connections(&self) -> &Connections {
168        &self.connections
169    }
170}
171
/// Errors returned by [`Server::run`] while setting up the gateway.
///
/// The enum is generic over the [`GatewayConfig`] so that variants can carry the
/// error types of the configured Root CA, TLS certificate and client-certificate
/// issuer. Every message is prefixed with `[{n}]`, where `n` is `self.name()`
/// from `#[nameth]` (presumably the variant name).
#[nameth]
#[derive(thiserror::Error, Debug)]
pub enum GatewayError<C: GatewayConfig> {
    /// Tracing was requested by the configuration but could not be enabled.
    #[error("[{n}] {0}", n = self.name())]
    EnableTracing(#[from] EnableTracingError),

    /// The Root CA certificate could not be obtained from the configuration.
    // NOTE(review): this boxed error is not `Send + Sync`, so `GatewayError` is
    // neither `Send` nor `Sync`. Consider tightening the bound if every Root CA
    // error type allows it.
    #[error("[{n}] Failed to get Root CA: {0}", n = self.name())]
    RootCa(Box<dyn std::error::Error>),

    /// The client-certificate issuer configuration could not be turned into an
    /// `IssuerConfig`.
    #[error("[{n}] Failed to get the client certificate issuer configuration: {0}", n = self.name())]
    IssuerConfig(#[from] IssuerConfigError<C::ClientCertificateIssuerConfig>),

    /// The configured host and port could not be resolved to socket addresses.
    #[error("[{n}] Failed to get socket address for {host}:{port}: {error}", n = self.name())]
    ToSocketAddrs {
        /// Host name from the configuration.
        host: String,
        /// Port from the configuration.
        port: u16,
        /// Underlying resolution error.
        error: std::io::Error,
    },

    /// The TLS server configuration could not be built from the TLS certificate
    /// configuration.
    #[error("[{n}] {0}", n = self.name())]
    ToTlsServerConfig(#[from] ToTlsServerError<<C::TlsConfig as CertificateConfig>::Error>),

    /// The TLS client configuration could not be built from the Root CA trust store.
    #[error("[{n}] {0}", n = self.name())]
    ToTlsClientConfig(#[from] ToTlsClientError<<C::RootCaConfig as TrustedStoreConfig>::Error>),

    /// The cached trusted store for the client-certificate issuer could not be created.
    #[error("[{n}] {0}", n = self.name())]
    CachedTrustedStoreConfig(<C::ClientCertificateIssuerConfig as TrustedStoreConfig>::Error),

    /// Wraps an endpoint error. `Server::run` itself logs endpoint failures
    /// rather than returning them.
    #[error("[{n}] {0}", n = self.name())]
    RunGatewayError(#[from] RunGatewayError),
}
203
/// Errors produced while serving a single gateway endpoint.
///
/// Messages are prefixed with `[{n}]`, where `n` is `self.name()` from
/// `#[nameth]` (presumably the variant name).
#[nameth]
#[derive(thiserror::Error, Debug)]
pub enum RunGatewayError {
    /// `axum_server` failed to serve the endpoint with an I/O error.
    #[error("[{n}] {0}", n = self.name())]
    Serve(std::io::Error),
}