Skip to main content

trz_gateway_client/
client.rs

1//! The Terrazzo Gateway [Client].
2
3use std::sync::Arc;
4use std::sync::Mutex;
5use std::sync::atomic::AtomicBool;
6use std::sync::atomic::Ordering::SeqCst;
7use std::time::Instant;
8
9use connect::ConnectError;
10use futures::FutureExt;
11use futures::future::Shared;
12use nameth::NamedEnumValues as _;
13use nameth::nameth;
14use tokio::sync::oneshot;
15use tracing::Instrument;
16use tracing::info;
17use tracing::info_span;
18use tracing::warn;
19use trz_gateway_common::declare_identifier;
20use trz_gateway_common::handle::ServerHandle;
21use trz_gateway_common::id::ClientId;
22use trz_gateway_common::id::ClientName;
23use trz_gateway_common::retry_strategy::RetryStrategy;
24use trz_gateway_common::security_configuration::certificate::CertificateConfig;
25use trz_gateway_common::security_configuration::certificate::tls_server::ToTlsServer as _;
26use trz_gateway_common::security_configuration::certificate::tls_server::ToTlsServerError;
27use trz_gateway_common::security_configuration::custom_server_certificate_verifier::ChainOnlyServerCertificateVerifier;
28use trz_gateway_common::security_configuration::trusted_store::TrustedStoreConfig;
29use trz_gateway_common::security_configuration::trusted_store::tls_client::ToTlsClient as _;
30use trz_gateway_common::security_configuration::trusted_store::tls_client::ToTlsClientError;
31use uuid::Uuid;
32
33use self::config::SniOverrideError;
34use self::config::url;
35use self::service::ClientService;
36use crate::tunnel_config::TunnelConfig;
37
38pub mod certificate;
39pub mod config;
40pub mod connect;
41mod connection;
42mod health;
43pub mod service;
44
45/// The [Client].
46///
47/// It creates a WebSocket tunnel with the Terrazzo Gateway, and then runs a
48/// gRPC server that listens to requests sent or forwarded by the Terrazzo
49/// Gateway over the WebSocket tunnel.
50pub struct Client {
51    /// The client name for troubleshooting purposes
52    pub client_name: ClientName,
53
54    /// The URL used to open the TCP connection.
55    uri: String,
56
57    /// The TLS server name to validate when it differs from [Self::uri].
58    sni_override: Option<String>,
59
60    /// The TLS client is used to create the secure WebSocket tunnel.
61    ///
62    /// (without client certificate auth)
63    tls_client: tokio_tungstenite::Connector,
64
65    /// The TLS server is used to accept gRPC connections through the tunnel.
66    ///
67    /// The client uses its certiticate to authenticate the server side of the connection.
68    tls_server: tokio_rustls::TlsAcceptor,
69
70    /// A callback to configure the [tonic gRPC server](tonic::transport::Server).
71    client_service: Arc<dyn ClientService>,
72
73    /// The strategy to retry creating the WebSocket tunnel when it fails
74    retry_strategy: RetryStrategy,
75
76    /// A global mutable variable that holds the [AuthCode].
77    ///
78    /// This is used to periodically renew the certificate.
79    current_auth_code: Arc<Mutex<AuthCode>>,
80}
81
82declare_identifier!(AuthCode);
83
84impl Client {
85    /// Creates a new [Client].
86    pub fn new<C: TunnelConfig>(config: C) -> Result<Arc<Self>, NewClientError<C>> {
87        let tls_client = config
88            .gateway_pki()
89            .to_tls_client(ChainOnlyServerCertificateVerifier)?;
90        let tls_server = config.client_certificate().to_tls_server()?;
91        let client_name = config.client_name();
92        let tunnel_path = format!("/remote/tunnel/{client_name}");
93        Ok(Arc::new(Client {
94            client_name,
95            uri: url(&config, &tunnel_path)?.to_string(),
96            sni_override: config.sni_override().map(ToOwned::to_owned),
97            tls_client: tokio_tungstenite::Connector::Rustls(tls_client.into()),
98            tls_server: tokio_rustls::TlsAcceptor::from(tls_server),
99            client_service: Arc::new(config.client_service()),
100            retry_strategy: config.retry_strategy(),
101            current_auth_code: config.current_auth_code(),
102        }))
103    }
104
105    /// Runs the client and returns a handle to stop the client.
106    pub async fn run(self: &Arc<Self>) -> Result<ServerHandle<()>, ConnectError> {
107        let this = self.clone();
108        let client_name = &this.client_name;
109        let span = info_span!("Run", %client_name);
110        async move {
111            let client_id = ClientId::from(Uuid::new_v4().to_string());
112            info!(%client_id, "Allocated new client id");
113            let (shutdown_rx, terminated_tx, handle) = ServerHandle::new("Client");
114            let (serving_tx, serving_rx) = oneshot::channel();
115            let task = run_impl(this, client_id, serving_tx, shutdown_rx, terminated_tx);
116            tokio::spawn(task.in_current_span());
117            let _ = serving_rx.await;
118            Ok(handle)
119        }
120        .instrument(span)
121        .await
122    }
123}
124
125async fn run_impl(
126    this: Arc<Client>,
127    client_id: ClientId,
128
129    // Set when the client is serving connections
130    serving_tx: oneshot::Sender<()>,
131
132    // Set when the client should start shutting down
133    shutdown_rx: impl Future<Output = ()> + Send + 'static,
134
135    // Set when the client has shut down
136    terminated_tx: oneshot::Sender<()>,
137) {
138    scopeguard::defer! { let _ = terminated_tx.send(()); };
139    let retry_strategy0 = this.retry_strategy.clone();
140    let mut retry_strategy = retry_strategy0.clone();
141    let shutdown_rx = shutdown_rx.shared();
142
143    let is_shutdown = is_shutdown(shutdown_rx.clone());
144
145    let mut serving_tx: Option<oneshot::Sender<()>> = Some(serving_tx);
146    loop {
147        let start = Instant::now();
148        let result = this
149            .connect(
150                client_id.clone(),
151                shutdown_rx.clone(),
152                retry_strategy.peek() / 2,
153                &mut serving_tx,
154            )
155            .await;
156        if is_shutdown.load(SeqCst) {
157            return;
158        }
159        let uptime = Instant::now() - start;
160        if uptime < retry_strategy0.max_delay() {
161            match result {
162                Ok(()) => {
163                    info! { "Connection closed, retrying in {}...", humantime::format_duration(retry_strategy.peek()) }
164                }
165                Err(error) => {
166                    warn! { %error, "Connection failed, retrying in {}...", humantime::format_duration(retry_strategy.peek()) }
167                }
168            }
169            if let futures::future::Either::Right(((), _retry_strategy_wait)) =
170                futures::future::select(Box::pin(retry_strategy.wait()), shutdown_rx.clone()).await
171            {
172                return;
173            }
174        } else {
175            retry_strategy = retry_strategy0.clone();
176        }
177    }
178}
179
180fn is_shutdown(shutdown_rx: Shared<impl Future<Output = ()> + Send + 'static>) -> Arc<AtomicBool> {
181    let is_shutdown = Arc::new(AtomicBool::new(false));
182    tokio::spawn({
183        let is_shutdown = is_shutdown.clone();
184        async move {
185            let _ = shutdown_rx.await;
186            is_shutdown.store(true, SeqCst);
187        }
188    });
189    return is_shutdown;
190}
191
192#[nameth]
193#[derive(thiserror::Error, Debug)]
194pub enum NewClientError<C: TunnelConfig> {
195    #[error("[{n}] {0}", n = self.name())]
196    SniOverride(#[from] SniOverrideError),
197
198    #[error("[{n}] {0}", n = self.name())]
199    ToTlsClient(#[from] ToTlsClientError<<C::GatewayPki as TrustedStoreConfig>::Error>),
200
201    #[error("[{n}] {0}", n = self.name())]
202    ToTlsServer(#[from] ToTlsServerError<<C::ClientCertificate as CertificateConfig>::Error>),
203}